This is page 7 of 17. Use http://codebase.md/stanfordnlp/dspy?lines=true&page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── .internal_dspyai │ │ ├── internals │ │ │ ├── build-and-release.md │ │ │ └── release-checklist.md │ │ └── pyproject.toml │ ├── .tmp │ │ └── .generated-actions │ │ └── run-pypi-publish-in-docker-container │ │ └── action.yml │ ├── ISSUE_TEMPLATE │ │ ├── bug_report.yml │ │ └── feature_request.yml │ ├── PULL_REQUEST_TEMPLATE │ │ └── pull_request_template.md │ ├── workflow_scripts │ │ └── install_testpypi_pkg.sh │ └── workflows │ ├── build_and_release.yml │ ├── build_utils │ │ └── test_version.py │ ├── docs-push.yml │ ├── precommits_check.yml │ └── run_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── docs │ ├── .gitignore │ ├── docs │ │ ├── api │ │ │ ├── adapters │ │ │ │ ├── Adapter.md │ │ │ │ ├── ChatAdapter.md │ │ │ │ ├── JSONAdapter.md │ │ │ │ └── TwoStepAdapter.md │ │ │ ├── evaluation │ │ │ │ ├── answer_exact_match.md │ │ │ │ ├── answer_passage_match.md │ │ │ │ ├── CompleteAndGrounded.md │ │ │ │ ├── Evaluate.md │ │ │ │ ├── EvaluationResult.md │ │ │ │ └── SemanticF1.md │ │ │ ├── experimental │ │ │ │ ├── Citations.md │ │ │ │ └── Document.md │ │ │ ├── index.md │ │ │ ├── models │ │ │ │ ├── Embedder.md │ │ │ │ └── LM.md │ │ │ ├── modules │ │ │ │ ├── BestOfN.md │ │ │ │ ├── ChainOfThought.md │ │ │ │ ├── CodeAct.md │ │ │ │ ├── Module.md │ │ │ │ ├── MultiChainComparison.md │ │ │ │ ├── Parallel.md │ │ │ │ ├── Predict.md │ │ │ │ ├── ProgramOfThought.md │ │ │ │ ├── ReAct.md │ │ │ │ └── Refine.md │ │ │ ├── optimizers │ │ │ │ ├── BetterTogether.md │ │ │ │ ├── BootstrapFewShot.md │ │ │ │ ├── BootstrapFewShotWithRandomSearch.md │ │ │ │ ├── BootstrapFinetune.md │ │ │ │ ├── BootstrapRS.md │ │ │ │ ├── COPRO.md │ │ │ │ ├── Ensemble.md │ │ │ │ ├── GEPA │ │ │ │ │ ├── GEPA_Advanced.md │ │ │ │ │ └── overview.md │ │ │ │ ├── InferRules.md │ │ │ │ ├── KNN.md │ │ │ │ ├── KNNFewShot.md │ │ │ │ ├── LabeledFewShot.md │ │ │ │ ├── MIPROv2.md │ │ │ │ └── SIMBA.md │ │ │ ├── primitives │ │ │ │ ├── Audio.md │ │ │ │ ├── Code.md │ │ │ │ ├── Example.md │ │ │ │ ├── History.md │ │ │ │ ├── Image.md │ │ │ │ ├── Prediction.md │ │ │ │ ├── Tool.md │ │ │ │ └── ToolCalls.md │ │ │ ├── signatures │ │ │ │ ├── InputField.md │ │ │ │ ├── OutputField.md │ │ │ │ └── Signature.md │ │ │ ├── tools │ │ │ │ ├── ColBERTv2.md │ │ │ │ ├── Embeddings.md │ │ │ │ └── PythonInterpreter.md │ │ │ └── utils │ │ │ ├── asyncify.md │ │ │ ├── configure_cache.md │ │ │ ├── disable_litellm_logging.md │ │ │ ├── disable_logging.md │ │ │ ├── enable_litellm_logging.md │ │ │ ├── enable_logging.md │ │ │ ├── inspect_history.md │ │ │ ├── load.md │ │ │ ├── StatusMessage.md │ │ │ ├── StatusMessageProvider.md │ │ │ ├── streamify.md │ │ │ └── StreamListener.md │ │ ├── cheatsheet.md │ │ ├── community │ │ │ ├── community-resources.md │ │ │ ├── how-to-contribute.md │ │ │ └── use-cases.md │ │ ├── deep-dive │ │ │ └── data-handling │ │ │ ├── built-in-datasets.md │ │ │ ├── examples.md │ │ │ ├── img │ │ │ │ └── data-loading.png │ │ │ └── loading-custom-data.md │ │ ├── faqs.md │ │ ├── index.md │ │ ├── js │ │ │ └── runllm-widget.js │ │ ├── learn │ │ │ ├── evaluation │ │ │ │ ├── data.md │ │ │ │ ├── metrics.md │ │ │ │ └── overview.md │ │ │ ├── figures │ │ │ │ ├── native_tool_call.png │ │ │ │ └── teleprompter-classes.png │ │ │ ├── index.md │ │ │ ├── optimization │ │ │ │ ├── optimizers.md │ │ │ │ └── overview.md │ │ │ └── programming │ │ │ ├── 7-assertions.md │ │ │ ├── adapters.md │ │ │ ├── language_models.md │ │ │ 
├── mcp.md │ │ │ ├── modules.md │ │ │ ├── overview.md │ │ │ ├── signatures.md │ │ │ └── tools.md │ │ ├── production │ │ │ └── index.md │ │ ├── roadmap.md │ │ ├── static │ │ │ ├── .nojekyll │ │ │ └── img │ │ │ ├── dspy_logo.png │ │ │ ├── logo.png │ │ │ ├── mlflow-tracing-rag.png │ │ │ ├── modular.png │ │ │ ├── optimize.png │ │ │ ├── undraw_docusaurus_mountain.svg │ │ │ ├── undraw_docusaurus_react.svg │ │ │ ├── undraw_docusaurus_tree.svg │ │ │ └── universal_compatibility.png │ │ ├── stylesheets │ │ │ └── extra.css │ │ └── tutorials │ │ ├── agents │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── ai_text_game │ │ │ └── index.md │ │ ├── async │ │ │ └── index.md │ │ ├── audio │ │ │ └── index.ipynb │ │ ├── build_ai_program │ │ │ └── index.md │ │ ├── cache │ │ │ └── index.md │ │ ├── classification │ │ │ └── index.md │ │ ├── classification_finetuning │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-classification.png │ │ ├── conversation_history │ │ │ └── index.md │ │ ├── core_development │ │ │ └── index.md │ │ ├── custom_module │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-custom-module.png │ │ ├── customer_service_agent │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-customer-service-agent.png │ │ ├── deployment │ │ │ ├── dspy_mlflow_ui.png │ │ │ └── index.md │ │ ├── email_extraction │ │ │ ├── index.md │ │ │ └── mlflow-tracing-email-extraction.png │ │ ├── entity_extraction │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-entity-extraction.png │ │ ├── games │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── gepa_ai_program │ │ │ └── index.md │ │ ├── gepa_aime │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-aime.png │ │ │ └── mlflow-tracking-gepa-aime-optimization.png │ │ ├── gepa_facilitysupportanalyzer │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-support.png │ │ │ └── mlflow-tracking-gepa-support-optimization.png │ │ ├── gepa_papillon │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-papilon.png │ │ │ └── mlflow-tracking-gepa-papilon-optimization.png │ │ ├── image_generation_prompting │ │ │ └── index.ipynb │ │ ├── index.md │ │ ├── llms_txt_generation │ │ │ └── index.md │ │ ├── math │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-math.png │ │ ├── mcp │ │ │ └── index.md │ │ ├── mem0_react_agent │ │ │ └── index.md │ │ ├── multihop_search │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-multi-hop.png │ │ ├── observability │ │ │ ├── index.md │ │ │ ├── mlflow_trace_ui_navigation.gif │ │ │ ├── mlflow_trace_ui.png │ │ │ └── mlflow_trace_view.png │ │ ├── optimize_ai_program │ │ │ └── index.md │ │ ├── optimizer_tracking │ │ │ ├── child_run.png │ │ │ ├── experiment.png │ │ │ ├── index.md │ │ │ └── parent_run.png │ │ ├── output_refinement │ │ │ └── best-of-n-and-refine.md │ │ ├── papillon │ │ │ └── index.md │ │ ├── program_of_thought │ │ │ └── index.ipynb │ │ ├── rag │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-rag.png │ │ ├── real_world_examples │ │ │ └── index.md │ │ ├── rl_ai_program │ │ │ └── index.md │ │ ├── rl_multihop │ │ │ └── index.ipynb │ │ ├── rl_papillon │ │ │ └── index.ipynb │ │ ├── sample_code_generation │ │ │ └── index.md │ │ ├── saving │ │ │ └── index.md │ │ ├── streaming │ │ │ └── index.md │ │ ├── tool_use │ │ │ └── index.ipynb │ │ └── yahoo_finance_react │ │ └── index.md │ ├── mkdocs.yml │ ├── overrides │ │ ├── home.html │ │ ├── main.html │ │ └── partials │ │ └── tabs.html │ ├── Pipfile │ ├── Pipfile.lock │ ├── README.md │ ├── requirements.txt │ ├── scripts │ │ ├── generate_api_docs.py │ │ └── generate_api_summary.py │ └── vercel.json ├── dspy │ ├── 
__init__.py │ ├── __metadata__.py │ ├── adapters │ │ ├── __init__.py │ │ ├── baml_adapter.py │ │ ├── base.py │ │ ├── chat_adapter.py │ │ ├── json_adapter.py │ │ ├── two_step_adapter.py │ │ ├── types │ │ │ ├── __init__.py │ │ │ ├── audio.py │ │ │ ├── base_type.py │ │ │ ├── citation.py │ │ │ ├── code.py │ │ │ ├── document.py │ │ │ ├── history.py │ │ │ ├── image.py │ │ │ └── tool.py │ │ ├── utils.py │ │ └── xml_adapter.py │ ├── clients │ │ ├── __init__.py │ │ ├── base_lm.py │ │ ├── cache.py │ │ ├── databricks.py │ │ ├── embedding.py │ │ ├── lm_local_arbor.py │ │ ├── lm_local.py │ │ ├── lm.py │ │ ├── openai.py │ │ ├── provider.py │ │ └── utils_finetune.py │ ├── datasets │ │ ├── __init__.py │ │ ├── alfworld │ │ │ ├── __init__.py │ │ │ ├── alfworld.py │ │ │ └── base_config.yml │ │ ├── colors.py │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── gsm8k.py │ │ ├── hotpotqa.py │ │ └── math.py │ ├── dsp │ │ ├── __init__.py │ │ ├── colbertv2.py │ │ └── utils │ │ ├── __init__.py │ │ ├── dpr.py │ │ ├── settings.py │ │ └── utils.py │ ├── evaluate │ │ ├── __init__.py │ │ ├── auto_evaluation.py │ │ ├── evaluate.py │ │ └── metrics.py │ ├── experimental │ │ └── __init__.py │ ├── predict │ │ ├── __init__.py │ │ ├── aggregation.py │ │ ├── avatar │ │ │ ├── __init__.py │ │ │ ├── avatar.py │ │ │ ├── models.py │ │ │ └── signatures.py │ │ ├── best_of_n.py │ │ ├── chain_of_thought.py │ │ ├── code_act.py │ │ ├── knn.py │ │ ├── multi_chain_comparison.py │ │ ├── parallel.py │ │ ├── parameter.py │ │ ├── predict.py │ │ ├── program_of_thought.py │ │ ├── react.py │ │ ├── refine.py │ │ └── retry.py │ ├── primitives │ │ ├── __init__.py │ │ ├── base_module.py │ │ ├── example.py │ │ ├── module.py │ │ ├── prediction.py │ │ ├── python_interpreter.py │ │ └── runner.js │ ├── propose │ │ ├── __init__.py │ │ ├── dataset_summary_generator.py │ │ ├── grounded_proposer.py │ │ ├── propose_base.py │ │ └── utils.py │ ├── retrievers │ │ ├── __init__.py │ │ ├── databricks_rm.py │ │ ├── embeddings.py │ │ ├── retrieve.py │ │ └── weaviate_rm.py │ ├── signatures │ │ ├── __init__.py │ │ ├── field.py │ │ ├── signature.py │ │ └── utils.py │ ├── streaming │ │ ├── __init__.py │ │ ├── messages.py │ │ ├── streamify.py │ │ └── streaming_listener.py │ ├── teleprompt │ │ ├── __init__.py │ │ ├── avatar_optimizer.py │ │ ├── bettertogether.py │ │ ├── bootstrap_finetune.py │ │ ├── bootstrap_trace.py │ │ ├── bootstrap.py │ │ ├── copro_optimizer.py │ │ ├── ensemble.py │ │ ├── gepa │ │ │ ├── __init__.py │ │ │ ├── gepa_utils.py │ │ │ ├── gepa.py │ │ │ └── instruction_proposal.py │ │ ├── grpo.py │ │ ├── infer_rules.py │ │ ├── knn_fewshot.py │ │ ├── mipro_optimizer_v2.py │ │ ├── random_search.py │ │ ├── signature_opt.py │ │ ├── simba_utils.py │ │ ├── simba.py │ │ ├── teleprompt_optuna.py │ │ ├── teleprompt.py │ │ ├── utils.py │ │ └── vanilla.py │ └── utils │ ├── __init__.py │ ├── annotation.py │ ├── asyncify.py │ ├── caching.py │ ├── callback.py │ ├── dummies.py │ ├── exceptions.py │ ├── hasher.py │ ├── inspect_history.py │ ├── langchain_tool.py │ ├── logging_utils.py │ ├── mcp.py │ ├── parallelizer.py │ ├── saving.py │ ├── syncify.py │ ├── unbatchify.py │ └── usage_tracker.py ├── LICENSE ├── pyproject.toml ├── README.md ├── tests │ ├── __init__.py │ ├── adapters │ │ ├── test_adapter_utils.py │ │ ├── test_baml_adapter.py │ │ ├── test_base_type.py │ │ ├── test_chat_adapter.py │ │ ├── test_citation.py │ │ ├── test_code.py │ │ ├── test_document.py │ │ ├── test_json_adapter.py │ │ ├── test_tool.py │ │ ├── test_two_step_adapter.py │ │ └── test_xml_adapter.py │ ├── 
callback │ │ └── test_callback.py │ ├── clients │ │ ├── test_cache.py │ │ ├── test_databricks.py │ │ ├── test_embedding.py │ │ ├── test_inspect_global_history.py │ │ └── test_lm.py │ ├── conftest.py │ ├── datasets │ │ └── test_dataset.py │ ├── docs │ │ └── test_mkdocs_links.py │ ├── evaluate │ │ ├── test_evaluate.py │ │ └── test_metrics.py │ ├── examples │ │ └── test_baleen.py │ ├── metadata │ │ └── test_metadata.py │ ├── predict │ │ ├── test_aggregation.py │ │ ├── test_best_of_n.py │ │ ├── test_chain_of_thought.py │ │ ├── test_code_act.py │ │ ├── test_knn.py │ │ ├── test_multi_chain_comparison.py │ │ ├── test_parallel.py │ │ ├── test_predict.py │ │ ├── test_program_of_thought.py │ │ ├── test_react.py │ │ ├── test_refine.py │ │ └── test_retry.py │ ├── primitives │ │ ├── resources │ │ │ └── saved_program.json │ │ ├── test_base_module.py │ │ ├── test_example.py │ │ ├── test_module.py │ │ └── test_python_interpreter.py │ ├── propose │ │ └── test_grounded_proposer.py │ ├── README.md │ ├── reliability │ │ ├── __init__.py │ │ ├── complex_types │ │ │ └── generated │ │ │ ├── test_many_types_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ ├── test_nesting_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ └── test_nesting_2 │ │ │ ├── inputs │ │ │ │ └── input1.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── conftest.py │ │ ├── generate │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ └── utils.py │ │ ├── input_formats │ │ │ └── generated │ │ │ └── test_markdown_1 │ │ │ ├── inputs │ │ │ │ ├── input1.json │ │ │ │ └── input2.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── README.md │ │ ├── reliability_conf.yaml │ │ ├── test_generated.py │ │ ├── test_pydantic_models.py │ │ └── utils.py │ ├── retrievers │ │ └── test_embeddings.py │ ├── signatures │ │ ├── test_adapter_image.py │ │ ├── test_custom_types.py │ │ └── test_signature.py │ ├── streaming │ │ └── test_streaming.py │ ├── teleprompt │ │ ├── gepa_dummy_lm_custom_component_selector_custom_instruction_proposer.json │ │ ├── gepa_dummy_lm.json │ │ ├── test_bootstrap_finetune.py │ │ ├── test_bootstrap_trace.py │ │ ├── test_bootstrap.py │ │ ├── test_copro_optimizer.py │ │ ├── test_ensemble.py │ │ ├── test_finetune.py │ │ ├── test_gepa_instruction_proposer.py │ │ ├── test_gepa.py │ │ ├── test_grpo.py │ │ ├── test_knn_fewshot.py │ │ ├── test_random_search.py │ │ ├── test_teleprompt.py │ │ └── test_utils.py │ ├── test_utils │ │ ├── __init__.py │ │ └── server │ │ ├── __init__.py │ │ ├── litellm_server_config.yaml │ │ └── litellm_server.py │ └── utils │ ├── __init__.py │ ├── resources │ │ └── mcp_server.py │ ├── test_annotation.py │ ├── test_asyncify.py │ ├── test_exceptions.py │ ├── test_langchain_tool.py │ ├── test_mcp.py │ ├── test_parallelizer.py │ ├── test_saving.py │ ├── test_settings.py │ ├── test_syncify.py │ ├── test_unbatchify.py │ └── test_usage_tracker.py └── uv.lock ``` # Files -------------------------------------------------------------------------------- /docs/docs/learn/programming/signatures.md: -------------------------------------------------------------------------------- ```markdown 1 | --- 2 | sidebar_position: 2 3 | --- 4 | 5 | # Signatures 6 | 7 | When we assign tasks to LMs in DSPy, we specify the behavior we need as a Signature. 
8 | 9 | **A signature is a declarative specification of input/output behavior of a DSPy module.** Signatures allow you to tell the LM _what_ it needs to do, rather than specify _how_ we should ask the LM to do it. 10 | 11 | You're probably familiar with function signatures, which specify the input and output arguments and their types. DSPy signatures are similar, but with a couple of differences. While typical function signatures just _describe_ things, DSPy Signatures _declare and initialize the behavior_ of modules. Moreover, the field names matter in DSPy Signatures. You express semantic roles in plain English: a `question` is different from an `answer`, a `sql_query` is different from `python_code`. 12 | 13 | ## Why should I use a DSPy Signature? 14 | 15 | For modular and clean code, in which LM calls can be optimized into high-quality prompts (or automatic finetunes). Most people coerce LMs to do tasks by hacking long, brittle prompts. Or by collecting/generating data for fine-tuning. Writing signatures is far more modular, adaptive, and reproducible than hacking at prompts or finetunes. The DSPy compiler will figure out how to build a highly-optimized prompt for your LM (or finetune your small LM) for your signature, on your data, and within your pipeline. In many cases, we found that compiling leads to better prompts than humans write. Not because DSPy optimizers are more creative than humans, but simply because they can try more things and tune the metrics directly. 16 | 17 | ## **Inline** DSPy Signatures 18 | 19 | Signatures can be defined as a short string, with argument names and optional types that define semantic roles for inputs/outputs. 20 | 21 | 1. Question Answering: `"question -> answer"`, which is equivalent to `"question: str -> answer: str"` as the default type is always `str` 22 | 23 | 2. Sentiment Classification: `"sentence -> sentiment: bool"`, e.g. `True` if positive 24 | 25 | 3. Summarization: `"document -> summary"` 26 | 27 | Your signatures can also have multiple input/output fields with types: 28 | 29 | 4. Retrieval-Augmented Question Answering: `"context: list[str], question: str -> answer: str"` 30 | 31 | 5. Multiple-Choice Question Answering with Reasoning: `"question, choices: list[str] -> reasoning: str, selection: int"` 32 | 33 | **Tip:** For fields, any valid variable names work! Field names should be semantically meaningful, but start simple and don't prematurely optimize keywords! Leave that kind of hacking to the DSPy compiler. For example, for summarization, it's probably fine to say `"document -> summary"`, `"text -> gist"`, or `"long_context -> tldr"`. 34 | 35 | You can also add instructions to your inline signature, which can use variables at runtime. Use the `instructions` keyword argument to add instructions to your signature. 36 | 37 | ```python 38 | toxicity = dspy.Predict( 39 | dspy.Signature( 40 | "comment -> toxic: bool", 41 | instructions="Mark as 'toxic' if the comment includes insults, harassment, or sarcastic derogatory remarks.", 42 | ) 43 | ) 44 | comment = "you are beautiful." 45 | toxicity(comment=comment).toxic 46 | ``` 47 | 48 | **Output:** 49 | ```text 50 | False 51 | ``` 52 | 53 | 54 | ### Example A: Sentiment Classification 55 | 56 | ```python 57 | sentence = "it's a charming and often affecting journey." # example from the SST-2 dataset. 
58 | 59 | classify = dspy.Predict('sentence -> sentiment: bool') # we'll see an example with Literal[] later 60 | classify(sentence=sentence).sentiment 61 | ``` 62 | **Output:** 63 | ```text 64 | True 65 | ``` 66 | 67 | ### Example B: Summarization 68 | 69 | ```python 70 | # Example from the XSum dataset. 71 | document = """The 21-year-old made seven appearances for the Hammers and netted his only goal for them in a Europa League qualification round match against Andorran side FC Lustrains last season. Lee had two loan spells in League One last term, with Blackpool and then Colchester United. He scored twice for the U's but was unable to save them from relegation. The length of Lee's contract with the promoted Tykes has not been revealed. Find all the latest football transfers on our dedicated page.""" 72 | 73 | summarize = dspy.ChainOfThought('document -> summary') 74 | response = summarize(document=document) 75 | 76 | print(response.summary) 77 | ``` 78 | **Possible Output:** 79 | ```text 80 | The 21-year-old Lee made seven appearances and scored one goal for West Ham last season. He had loan spells in League One with Blackpool and Colchester United, scoring twice for the latter. He has now signed a contract with Barnsley, but the length of the contract has not been revealed. 81 | ``` 82 | 83 | Many DSPy modules (except `dspy.Predict`) return auxiliary information by expanding your signature under the hood. 84 | 85 | For example, `dspy.ChainOfThought` also adds a `reasoning` field that includes the LM's reasoning before it generates the output `summary`. 86 | 87 | ```python 88 | print("Reasoning:", response.reasoning) 89 | ``` 90 | **Possible Output:** 91 | ```text 92 | Reasoning: We need to highlight Lee's performance for West Ham, his loan spells in League One, and his new contract with Barnsley. We also need to mention that his contract length has not been disclosed. 93 | ``` 94 | 95 | ## **Class-based** DSPy Signatures 96 | 97 | For some advanced tasks, you need more verbose signatures. This is typically to: 98 | 99 | 1. Clarify something about the nature of the task (expressed below as a `docstring`). 100 | 101 | 2. Supply hints on the nature of an input field, expressed as a `desc` keyword argument for `dspy.InputField`. 102 | 103 | 3. Supply constraints on an output field, expressed as a `desc` keyword argument for `dspy.OutputField`. 104 | 105 | ### Example C: Classification 106 | 107 | ```python 108 | from typing import Literal 109 | 110 | class Emotion(dspy.Signature): 111 | """Classify emotion.""" 112 | 113 | sentence: str = dspy.InputField() 114 | sentiment: Literal['sadness', 'joy', 'love', 'anger', 'fear', 'surprise'] = dspy.OutputField() 115 | 116 | sentence = "i started feeling a little vulnerable when the giant spotlight started blinding me" # from dair-ai/emotion 117 | 118 | classify = dspy.Predict(Emotion) 119 | classify(sentence=sentence) 120 | ``` 121 | **Possible Output:** 122 | ```text 123 | Prediction( 124 | sentiment='fear' 125 | ) 126 | ``` 127 | 128 | **Tip:** There's nothing wrong with specifying your requests to the LM more clearly. Class-based Signatures help you with that. However, don't prematurely tune the keywords of your signature by hand. The DSPy optimizers will likely do a better job (and will transfer better across LMs). 
129 | 130 | ### Example D: A metric that evaluates faithfulness to citations 131 | 132 | ```python 133 | class CheckCitationFaithfulness(dspy.Signature): 134 | """Verify that the text is based on the provided context.""" 135 | 136 | context: str = dspy.InputField(desc="facts here are assumed to be true") 137 | text: str = dspy.InputField() 138 | faithfulness: bool = dspy.OutputField() 139 | evidence: dict[str, list[str]] = dspy.OutputField(desc="Supporting evidence for claims") 140 | 141 | context = "The 21-year-old made seven appearances for the Hammers and netted his only goal for them in a Europa League qualification round match against Andorran side FC Lustrains last season. Lee had two loan spells in League One last term, with Blackpool and then Colchester United. He scored twice for the U's but was unable to save them from relegation. The length of Lee's contract with the promoted Tykes has not been revealed. Find all the latest football transfers on our dedicated page." 142 | 143 | text = "Lee scored 3 goals for Colchester United." 144 | 145 | faithfulness = dspy.ChainOfThought(CheckCitationFaithfulness) 146 | faithfulness(context=context, text=text) 147 | ``` 148 | **Possible Output:** 149 | ```text 150 | Prediction( 151 | reasoning="Let's check the claims against the context. The text states Lee scored 3 goals for Colchester United, but the context clearly states 'He scored twice for the U's'. This is a direct contradiction.", 152 | faithfulness=False, 153 | evidence={'goal_count': ["scored twice for the U's"]} 154 | ) 155 | ``` 156 | 157 | ### Example E: Multi-modal image classification 158 | 159 | ```python 160 | class DogPictureSignature(dspy.Signature): 161 | """Output the dog breed of the dog in the image.""" 162 | image_1: dspy.Image = dspy.InputField(desc="An image of a dog") 163 | answer: str = dspy.OutputField(desc="The dog breed of the dog in the image") 164 | 165 | image_url = "https://picsum.photos/id/237/200/300" 166 | classify = dspy.Predict(DogPictureSignature) 167 | classify(image_1=dspy.Image.from_url(image_url)) 168 | ``` 169 | 170 | **Possible Output:** 171 | 172 | ```text 173 | Prediction( 174 | answer='Labrador Retriever' 175 | ) 176 | ``` 177 | 178 | ## Type Resolution in Signatures 179 | 180 | DSPy signatures support various annotation types: 181 | 182 | 1. **Basic types** like `str`, `int`, `bool` 183 | 2. **Typing module types** like `list[str]`, `dict[str, int]`, `Optional[float]`, `Union[str, int]` 184 | 3. **Custom types** defined in your code 185 | 4. **Dot notation** for nested types with proper configuration 186 | 5. **Special data types** like `dspy.Image`, `dspy.History` 187 | 188 | ### Working with Custom Types 189 | 190 | ```python 191 | # Simple custom type (assumes `import pydantic`) 192 | class QueryResult(pydantic.BaseModel): 193 | text: str 194 | score: float 195 | 196 | signature = dspy.Signature("query: str -> result: QueryResult") 197 | 198 | class MyContainer: 199 | class Query(pydantic.BaseModel): 200 | text: str 201 | class Score(pydantic.BaseModel): 202 | score: float 203 | 204 | signature = dspy.Signature("query: MyContainer.Query -> score: MyContainer.Score") 205 | ``` 206 | 207 | ## Using signatures to build modules & compiling them 208 | 209 | While signatures are convenient for prototyping with structured inputs/outputs, that's not the only reason to use them! 210 | 211 | You should compose multiple signatures into bigger [DSPy modules](modules.md) and [compile these modules into optimized prompts](../optimization/optimizers.md) and finetunes.
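For a preview of what that composition looks like, here is a minimal sketch (the `search` retrieval function is a hypothetical placeholder you would supply, not part of DSPy):

```python
import dspy

class SearchThenAnswer(dspy.Module):
    def __init__(self):
        super().__init__()
        # Two signatures, each driving its own sub-module.
        self.generate_query = dspy.Predict("question -> search_query")
        self.respond = dspy.ChainOfThought("context: list[str], question: str -> answer")

    def forward(self, question):
        query = self.generate_query(question=question).search_query
        context = search(query)  # hypothetical retriever returning a list[str]
        return self.respond(context=context, question=question)
```

A DSPy optimizer can then tune the prompts behind both signatures jointly, on your data.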
212 | ``` -------------------------------------------------------------------------------- /dspy/predict/refine.py: -------------------------------------------------------------------------------- ```python 1 | import inspect 2 | import textwrap 3 | from typing import Callable 4 | 5 | import orjson 6 | 7 | import dspy 8 | from dspy.adapters.utils import get_field_description_string 9 | from dspy.predict.predict import Prediction 10 | from dspy.signatures import InputField, OutputField, Signature 11 | 12 | from .predict import Module 13 | 14 | 15 | class OfferFeedback(Signature): 16 | """ 17 | In the discussion, assign blame to each module that contributed to the final reward being below the threshold, if 18 | any. Then, prescribe concrete advice of how the module should act on its future input when we retry the process, if 19 | it were to receive the same or similar inputs. If a module is not to blame, the advice should be N/A. 20 | The module will not see its own history, so it needs to rely on entirely concrete and actionable advice from you 21 | to avoid the same mistake on the same or similar inputs. 22 | """ 23 | 24 | program_code: str = InputField(desc="The code of the program that we are analyzing") 25 | modules_defn: str = InputField(desc="The definition of each module in the program, including its I/O") 26 | program_inputs: str = InputField(desc="The inputs to the program that we are analyzing") 27 | program_trajectory: str = InputField(desc="The trajectory of the program's execution, showing each module's I/O") 28 | program_outputs: str = InputField(desc="The outputs of the program that we are analyzing") 29 | reward_code: str = InputField(desc="The code of the reward function that we are analyzing") 30 | target_threshold: float = InputField(desc="The target threshold for the reward function") 31 | reward_value: float = InputField(desc="The reward value assigned to the program's outputs") 32 | module_names: list[str] = InputField(desc="The names of the modules in the program, for which we seek advice") 33 | discussion: str = OutputField(desc="Discussing blame of where each module went wrong, if it did") 34 | advice: dict[str, str] = OutputField( 35 | desc="For each module, describe very concretely, in this order: the specific scenarios in which it has made " 36 | "mistakes in the past and what each mistake was, followed by what it should do differently in that kind of " 37 | "scenario in the future. If the module is not to blame, write N/A." 38 | ) 39 | 40 | 41 | class Refine(Module): 42 | def __init__( 43 | self, 44 | module: Module, 45 | N: int, # noqa: N803 46 | reward_fn: Callable[[dict, Prediction], float], 47 | threshold: float, 48 | fail_count: int | None = None, 49 | ): 50 | """ 51 | Refines a module by running it up to N times with different rollout IDs at `temperature=1.0` 52 | and returns the best prediction. 53 | 54 | This module runs the provided module multiple times with varying rollout identifiers and selects 55 | either the first prediction that exceeds the specified threshold or the one with the highest reward. 56 | If no prediction meets the threshold, it automatically generates feedback to improve future predictions. 57 | 58 | 59 | Args: 60 | module (Module): The module to refine. 61 | N (int): The number of times to run the module. 62 | reward_fn (Callable): The reward function. 63 | threshold (float): The threshold for the reward function.
64 | fail_count (Optional[int], optional): The number of times the module can fail before raising an error. Defaults to N if not provided. 65 | 66 | Example: 67 | ```python 68 | import dspy 69 | 70 | dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini")) 71 | 72 | # Define a QA module with chain of thought 73 | qa = dspy.ChainOfThought("question -> answer") 74 | 75 | # Define a reward function that checks for one-word answers 76 | def one_word_answer(args, pred): 77 | return 1.0 if len(pred.answer.split()) == 1 else 0.0 78 | 79 | # Create a refined module that tries up to 3 times 80 | best_of_3 = dspy.Refine(module=qa, N=3, reward_fn=one_word_answer, threshold=1.0) 81 | 82 | # Use the refined module 83 | result = best_of_3(question="What is the capital of Belgium?").answer 84 | # Returns: Brussels 85 | ``` 86 | """ 87 | self.module = module 88 | self.reward_fn = lambda *args: reward_fn(*args) # to prevent this from becoming a parameter 89 | self.threshold = threshold 90 | self.N = N 91 | self.fail_count = fail_count or N # default to N if fail_count is not provided 92 | self.module_code = inspect.getsource(module.__class__) 93 | try: 94 | self.reward_fn_code = inspect.getsource(reward_fn) 95 | except TypeError: 96 | self.reward_fn_code = inspect.getsource(reward_fn.__class__) 97 | 98 | def forward(self, **kwargs): 99 | lm = self.module.get_lm() or dspy.settings.lm 100 | start = lm.kwargs.get("rollout_id", 0) 101 | rollout_ids = [start + i for i in range(self.N)] 102 | best_pred, best_trace, best_reward = None, None, -float("inf") 103 | advice = None 104 | adapter = dspy.settings.adapter or dspy.ChatAdapter() 105 | 106 | for idx, rid in enumerate(rollout_ids): 107 | lm_ = lm.copy(rollout_id=rid, temperature=1.0) 108 | mod = self.module.deepcopy() 109 | mod.set_lm(lm_) 110 | 111 | predictor2name = {predictor: name for name, predictor in mod.named_predictors()} 112 | signature2name = {predictor.signature: name for name, predictor in mod.named_predictors()} 113 | module_names = [name for name, _ in mod.named_predictors()] 114 | 115 | try: 116 | with dspy.context(trace=[]): 117 | if not advice: 118 | outputs = mod(**kwargs) 119 | else: 120 | 121 | class WrapperAdapter(adapter.__class__): 122 | def __call__(self, lm, lm_kwargs, signature, demos, inputs): 123 | inputs["hint_"] = advice.get(signature2name[signature], "N/A") # noqa: B023 124 | signature = signature.append( 125 | "hint_", InputField(desc="A hint to the module from an earlier run") 126 | ) 127 | return adapter(lm, lm_kwargs, signature, demos, inputs) 128 | 129 | with dspy.context(adapter=WrapperAdapter()): 130 | outputs = mod(**kwargs) 131 | 132 | trace = dspy.settings.trace.copy() 133 | 134 | # TODO: Remove the hint from the trace, if it's there. 135 | 136 | # NOTE: Not including the trace of reward_fn.
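# Score this attempt: keep the best prediction seen so far and stop early once the reward clears the threshold.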
137 | reward = self.reward_fn(kwargs, outputs) 138 | 139 | if reward > best_reward: 140 | best_reward, best_pred, best_trace = reward, outputs, trace 141 | 142 | if self.threshold is not None and reward >= self.threshold: 143 | break 144 | 145 | if idx == self.N - 1: 146 | break 147 | 148 | modules = {"program_code": self.module_code, "modules_defn": inspect_modules(mod)} 149 | trajectory = [{"module_name": predictor2name[p], "inputs": i, "outputs": dict(o)} for p, i, o in trace] 150 | trajectory = { 151 | "program_inputs": kwargs, 152 | "program_trajectory": trajectory, 153 | "program_outputs": dict(outputs), 154 | } 155 | reward = { 156 | "reward_code": self.reward_fn_code, 157 | "target_threshold": self.threshold, 158 | "reward_value": reward, 159 | } 160 | 161 | advise_kwargs = dict(**modules, **trajectory, **reward, module_names=module_names) 162 | # only dumps if it's a list or dict 163 | advise_kwargs = { 164 | k: v if isinstance(v, str) else orjson.dumps(recursive_mask(v), option=orjson.OPT_INDENT_2).decode() 165 | for k, v in advise_kwargs.items() 166 | } 167 | advice = dspy.Predict(OfferFeedback)(**advise_kwargs).advice 168 | # print(f"Advice for each module: {advice}") 169 | 170 | except Exception as e: 171 | print(f"Refine: Attempt failed with rollout id {rid}: {e}") 172 | if idx > self.fail_count: 173 | raise e 174 | self.fail_count -= 1 175 | if best_trace: 176 | dspy.settings.trace.extend(best_trace) 177 | return best_pred 178 | 179 | 180 | def inspect_modules(program): 181 | separator = "-" * 80 182 | output = [separator] 183 | 184 | for _, (name, predictor) in enumerate(program.named_predictors()): 185 | signature = predictor.signature 186 | instructions = textwrap.dedent(signature.instructions) 187 | instructions = ("\n" + "\t" * 2).join([""] + instructions.splitlines()) 188 | 189 | output.append(f"Module {name}") 190 | output.append("\n\tInput Fields:") 191 | output.append(("\n" + "\t" * 2).join([""] + get_field_description_string(signature.input_fields).splitlines())) 192 | output.append("\tOutput Fields:") 193 | output.append(("\n" + "\t" * 2).join([""] + get_field_description_string(signature.output_fields).splitlines())) 194 | output.append(f"\tOriginal Instructions: {instructions}") 195 | output.append(separator) 196 | 197 | return "\n".join([o.strip("\n") for o in output]) 198 | 199 | 200 | def recursive_mask(o): 201 | # If the object is already serializable, return it. 202 | try: 203 | orjson.dumps(o) 204 | return o 205 | except TypeError: 206 | pass 207 | 208 | # If it's a dictionary, apply recursively to its values. 209 | if isinstance(o, dict): 210 | return {k: recursive_mask(v) for k, v in o.items()} 211 | # If it's a list, apply recursively. 212 | elif isinstance(o, list): 213 | return [recursive_mask(v) for v in o] 214 | # If it's a tuple, apply recursively. 215 | elif isinstance(o, tuple): 216 | return tuple(recursive_mask(v) for v in o) 217 | # Otherwise, replace it with a placeholder string (or use repr(o)). 
218 | else: 219 | return f"<non-serializable: {type(o).__name__}>" 220 | ``` -------------------------------------------------------------------------------- /dspy/clients/base_lm.py: -------------------------------------------------------------------------------- ```python 1 | import datetime 2 | import uuid 3 | 4 | from dspy.dsp.utils import settings 5 | from dspy.utils.callback import with_callbacks 6 | from dspy.utils.inspect_history import pretty_print_history 7 | 8 | MAX_HISTORY_SIZE = 10_000 9 | GLOBAL_HISTORY = [] 10 | 11 | 12 | class BaseLM: 13 | """Base class for handling LLM calls. 14 | 15 | Most users can directly use the `dspy.LM` class, which is a subclass of `BaseLM`. Users can also implement their 16 | own subclasses of `BaseLM` to support custom LLM providers and inject custom logic. To do so, simply override the 17 | `forward` method and make sure the return format is identical to the 18 | [OpenAI response format](https://platform.openai.com/docs/api-reference/responses/object). 19 | 20 | Example: 21 | 22 | ```python 23 | from openai import OpenAI 24 | 25 | import dspy 26 | 27 | 28 | class MyLM(dspy.BaseLM): 29 | def forward(self, prompt, messages=None, **kwargs): 30 | client = OpenAI() 31 | return client.chat.completions.create( 32 | model=self.model, 33 | messages=messages or [{"role": "user", "content": prompt}], 34 | **self.kwargs, 35 | ) 36 | 37 | 38 | lm = MyLM(model="gpt-4o-mini") 39 | dspy.configure(lm=lm) 40 | print(dspy.Predict("q->a")(q="Why did the chicken cross the kitchen?")) 41 | ``` 42 | """ 43 | 44 | def __init__(self, model, model_type="chat", temperature=0.0, max_tokens=1000, cache=True, **kwargs): 45 | self.model = model 46 | self.model_type = model_type 47 | self.cache = cache 48 | self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) 49 | self.history = [] 50 | 51 | def _process_lm_response(self, response, prompt, messages, **kwargs): 52 | merged_kwargs = {**self.kwargs, **kwargs} 53 | 54 | if self.model_type == "responses": 55 | outputs = self._process_response(response) 56 | else: 57 | outputs = self._process_completion(response, merged_kwargs) 58 | 59 | if settings.disable_history: 60 | return outputs 61 | 62 | # Logging, with removed api key & where `cost` is None on cache hit. 63 | kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")} 64 | entry = { 65 | "prompt": prompt, 66 | "messages": messages, 67 | "kwargs": kwargs, 68 | "response": response, 69 | "outputs": outputs, 70 | "usage": dict(response.usage), 71 | "cost": getattr(response, "_hidden_params", {}).get("response_cost"), 72 | "timestamp": datetime.datetime.now().isoformat(), 73 | "uuid": str(uuid.uuid4()), 74 | "model": self.model, 75 | "response_model": response.model, 76 | "model_type": self.model_type, 77 | } 78 | 79 | self.update_history(entry) 80 | 81 | return outputs 82 | 83 | @with_callbacks 84 | def __call__(self, prompt=None, messages=None, **kwargs): 85 | response = self.forward(prompt=prompt, messages=messages, **kwargs) 86 | outputs = self._process_lm_response(response, prompt, messages, **kwargs) 87 | 88 | return outputs 89 | 90 | @with_callbacks 91 | async def acall(self, prompt=None, messages=None, **kwargs): 92 | response = await self.aforward(prompt=prompt, messages=messages, **kwargs) 93 | outputs = self._process_lm_response(response, prompt, messages, **kwargs) 94 | return outputs 95 | 96 | def forward(self, prompt=None, messages=None, **kwargs): 97 | """Forward pass for the language model. 
98 | 99 | Subclasses must implement this method, and the response should be identical to 100 | [OpenAI response format](https://platform.openai.com/docs/api-reference/responses/object). 101 | """ 102 | raise NotImplementedError("Subclasses must implement this method.") 103 | 104 | async def aforward(self, prompt=None, messages=None, **kwargs): 105 | """Async forward pass for the language model. 106 | 107 | Subclasses that support async should implement this method, and the response should be identical to 108 | [OpenAI response format](https://platform.openai.com/docs/api-reference/responses/object). 109 | """ 110 | raise NotImplementedError("Subclasses must implement this method.") 111 | 112 | def copy(self, **kwargs): 113 | """Returns a copy of the language model with possibly updated parameters. 114 | 115 | Any provided keyword arguments update the corresponding attributes or LM kwargs of 116 | the copy. For example, ``lm.copy(rollout_id=1, temperature=1.0)`` returns an LM whose 117 | requests use a different rollout ID at non-zero temperature to bypass cache collisions. 118 | """ 119 | 120 | import copy 121 | 122 | new_instance = copy.deepcopy(self) 123 | new_instance.history = [] 124 | 125 | for key, value in kwargs.items(): 126 | if hasattr(self, key): 127 | setattr(new_instance, key, value) 128 | if (key in self.kwargs) or (not hasattr(self, key)): 129 | if value is None: 130 | new_instance.kwargs.pop(key, None) 131 | else: 132 | new_instance.kwargs[key] = value 133 | if hasattr(new_instance, "_warned_zero_temp_rollout"): 134 | new_instance._warned_zero_temp_rollout = False 135 | 136 | return new_instance 137 | 138 | def inspect_history(self, n: int = 1): 139 | return pretty_print_history(self.history, n) 140 | 141 | def update_history(self, entry): 142 | if settings.disable_history: 143 | return 144 | 145 | # Global LM history 146 | if len(GLOBAL_HISTORY) >= MAX_HISTORY_SIZE: 147 | GLOBAL_HISTORY.pop(0) 148 | 149 | GLOBAL_HISTORY.append(entry) 150 | 151 | if settings.max_history_size == 0: 152 | return 153 | 154 | # dspy.LM.history 155 | if len(self.history) >= settings.max_history_size: 156 | self.history.pop(0) 157 | 158 | self.history.append(entry) 159 | 160 | # Per-module history 161 | caller_modules = settings.caller_modules or [] 162 | for module in caller_modules: 163 | if len(module.history) >= settings.max_history_size: 164 | module.history.pop(0) 165 | module.history.append(entry) 166 | 167 | def _process_completion(self, response, merged_kwargs): 168 | """Process the response of OpenAI chat completion API and extract outputs. 
169 | 170 | Args: 171 | response: The OpenAI chat completion response 172 | https://platform.openai.com/docs/api-reference/chat/object 173 | merged_kwargs: Merged kwargs from self.kwargs and method kwargs 174 | 175 | Returns: 176 | List of processed outputs 177 | """ 178 | outputs = [] 179 | for c in response.choices: 180 | output = {} 181 | output["text"] = c.message.content if hasattr(c, "message") else c["text"] 182 | if merged_kwargs.get("logprobs"): 183 | output["logprobs"] = c.logprobs if hasattr(c, "logprobs") else c["logprobs"] 184 | if hasattr(c, "message") and getattr(c.message, "tool_calls", None): 185 | output["tool_calls"] = c.message.tool_calls 186 | 187 | # Extract citations from LiteLLM response if available 188 | citations = self._extract_citations_from_response(c) 189 | if citations: 190 | output["citations"] = citations 191 | 192 | outputs.append(output) 193 | 194 | if all(len(output) == 1 for output in outputs): 195 | # Return a list if every output only has "text" key 196 | outputs = [output["text"] for output in outputs] 197 | 198 | return outputs 199 | 200 | def _extract_citations_from_response(self, choice): 201 | """Extract citations from LiteLLM response if available. 202 | Reference: https://docs.litellm.ai/docs/providers/anthropic#beta-citations-api 203 | 204 | Args: 205 | choice: The choice object from response.choices 206 | 207 | Returns: 208 | A list of citation dictionaries or None if no citations found 209 | """ 210 | try: 211 | # Check for citations in LiteLLM provider_specific_fields 212 | citations_data = choice.message.provider_specific_fields.get("citations") 213 | if isinstance(citations_data, list): 214 | return [citation for citations in citations_data for citation in citations] 215 | except Exception: 216 | return None 217 | 218 | def _process_response(self, response): 219 | """Process the response of OpenAI Response API and extract outputs. 220 | 221 | Args: 222 | response: OpenAI Response API response 223 | https://platform.openai.com/docs/api-reference/responses/object 224 | 225 | Returns: 226 | List of processed outputs, which is always of size 1 because the Response API only supports one output. 227 | """ 228 | text_outputs = [] 229 | tool_calls = [] 230 | reasoning_contents = [] 231 | 232 | for output_item in response.output: 233 | output_item_type = output_item.type 234 | if output_item_type == "message": 235 | for content_item in output_item.content: 236 | text_outputs.append(content_item.text) 237 | elif output_item_type == "function_call": 238 | tool_calls.append(output_item.model_dump()) 239 | elif output_item_type == "reasoning": 240 | if getattr(output_item, "content", None) and len(output_item.content) > 0: 241 | for content_item in output_item.content: 242 | reasoning_contents.append(content_item.text) 243 | elif getattr(output_item, "summary", None) and len(output_item.summary) > 0: 244 | for summary_item in output_item.summary: 245 | reasoning_contents.append(summary_item.text) 246 | 247 | result = {} 248 | if len(text_outputs) > 0: 249 | result["text"] = "".join(text_outputs) 250 | if len(tool_calls) > 0: 251 | result["tool_calls"] = tool_calls 252 | if len(reasoning_contents) > 0: 253 | result["reasoning_content"] = "".join(reasoning_contents) 254 | # All `response.output` items map to one answer, so we return a list of size 1. 
255 | return [result] 256 | 257 | 258 | def inspect_history(n: int = 1): 259 | """The global history shared across all LMs.""" 260 | return pretty_print_history(GLOBAL_HISTORY, n) 261 | ``` -------------------------------------------------------------------------------- /dspy/primitives/python_interpreter.py: -------------------------------------------------------------------------------- ```python 1 | import json 2 | import os 3 | import subprocess 4 | from os import PathLike 5 | from types import TracebackType 6 | from typing import Any 7 | 8 | 9 | class InterpreterError(RuntimeError): 10 | pass 11 | 12 | 13 | class PythonInterpreter: 14 | r""" 15 | PythonInterpreter that runs code in a sandboxed environment using Deno and Pyodide. 16 | 17 | Prerequisites: 18 | - Deno (https://docs.deno.com/runtime/getting_started/installation/). 19 | 20 | Example Usage: 21 | ```python 22 | code_string = "print('Hello'); 1 + 2" 23 | with PythonInterpreter() as interp: 24 | output = interp(code_string) # If final statement is non-None, prints the numeric result, else prints captured output 25 | ``` 26 | """ 27 | 28 | def __init__( 29 | self, 30 | deno_command: list[str] | None = None, 31 | enable_read_paths: list[PathLike | str] | None = None, 32 | enable_write_paths: list[PathLike | str] | None = None, 33 | enable_env_vars: list[str] | None = None, 34 | enable_network_access: list[str] | None = None, 35 | sync_files: bool = True, 36 | ) -> None: 37 | """ 38 | Args: 39 | deno_command: command list to launch Deno. 40 | enable_read_paths: Files or directories to allow reading from in the sandbox. 41 | enable_write_paths: Files or directories to allow writing to in the sandbox. 42 | enable_env_vars: Environment variable names to allow in the sandbox. 43 | enable_network_access: Domains or IPs to allow network access in the sandbox. 44 | sync_files: If set, syncs changes within the sandbox back to original files after execution. 
45 | """ 46 | if isinstance(deno_command, dict): 47 | deno_command = None # no-op, just a guard in case someone passes a dict 48 | 49 | self.enable_read_paths = enable_read_paths or [] 50 | self.enable_write_paths = enable_write_paths or [] 51 | self.enable_env_vars = enable_env_vars or [] 52 | self.enable_network_access = enable_network_access or [] 53 | self.sync_files = sync_files 54 | # TODO later on add enable_run (--allow-run) by proxying subprocess.run through Deno.run() to fix 'emscripten does not support processes' error 55 | 56 | if deno_command: 57 | self.deno_command = list(deno_command) 58 | else: 59 | args = ["deno", "run", "--allow-read"] 60 | self._env_arg = "" 61 | if self.enable_env_vars: 62 | user_vars = [str(v).strip() for v in self.enable_env_vars] 63 | args.append("--allow-env=" + ",".join(user_vars)) 64 | self._env_arg = ",".join(user_vars) 65 | if self.enable_network_access: 66 | args.append(f"--allow-net={','.join(str(x) for x in self.enable_network_access)}") 67 | if self.enable_write_paths: 68 | args.append(f"--allow-write={','.join(str(x) for x in self.enable_write_paths)}") 69 | 70 | args.append(self._get_runner_path()) 71 | 72 | # For runner.js to load in env vars 73 | if self._env_arg: 74 | args.append(self._env_arg) 75 | self.deno_command = args 76 | 77 | self.deno_process = None 78 | self._mounted_files = False 79 | 80 | def _get_runner_path(self) -> str: 81 | current_dir = os.path.dirname(os.path.abspath(__file__)) 82 | return os.path.join(current_dir, "runner.js") 83 | 84 | def _mount_files(self): 85 | if self._mounted_files: 86 | return 87 | paths_to_mount = [] 88 | if self.enable_read_paths: 89 | paths_to_mount.extend(self.enable_read_paths) 90 | if self.enable_write_paths: 91 | paths_to_mount.extend(self.enable_write_paths) 92 | if not paths_to_mount: 93 | return 94 | for path in paths_to_mount: 95 | if not path: 96 | continue 97 | if not os.path.exists(path): 98 | if self.enable_write_paths and path in self.enable_write_paths: 99 | open(path, "a").close() 100 | else: 101 | raise FileNotFoundError(f"Cannot mount non-existent file: {path}") 102 | virtual_path = f"/sandbox/{os.path.basename(path)}" 103 | mount_msg = json.dumps({"mount_file": str(path), "virtual_path": virtual_path}) 104 | self.deno_process.stdin.write(mount_msg + "\n") 105 | self.deno_process.stdin.flush() 106 | self._mounted_files = True 107 | 108 | def _sync_files(self): 109 | if not self.enable_write_paths or not self.sync_files: 110 | return 111 | for path in self.enable_write_paths: 112 | virtual_path = f"/sandbox/{os.path.basename(path)}" 113 | sync_msg = json.dumps({ 114 | "sync_file": virtual_path, 115 | "host_file": str(path) 116 | }) 117 | self.deno_process.stdin.write(sync_msg + "\n") 118 | self.deno_process.stdin.flush() 119 | 120 | 121 | def _ensure_deno_process(self) -> None: 122 | if self.deno_process is None or self.deno_process.poll() is not None: 123 | try: 124 | self.deno_process = subprocess.Popen( 125 | self.deno_command, 126 | stdin=subprocess.PIPE, 127 | stdout=subprocess.PIPE, 128 | stderr=subprocess.PIPE, 129 | text=True, 130 | encoding="UTF-8", 131 | env=os.environ.copy() 132 | ) 133 | except FileNotFoundError as e: 134 | install_instructions = ( 135 | "Deno executable not found. 
Please install Deno to proceed.\n" 136 | "Installation instructions:\n" 137 | "> curl -fsSL https://deno.land/install.sh | sh\n" 138 | "*or*, on macOS with Homebrew:\n" 139 | "> brew install deno\n" 140 | "For additional configurations: https://docs.deno.com/runtime/getting_started/installation/" 141 | ) 142 | raise InterpreterError(install_instructions) from e 143 | 144 | def _inject_variables(self, code: str, variables: dict[str, Any]) -> str: 145 | # Insert Python assignments for each variable at the top of the code 146 | injected_lines = [] 147 | for key, value in variables.items(): 148 | if not key.isidentifier(): 149 | raise InterpreterError(f"Invalid variable name: '{key}'") 150 | python_value = self._serialize_value(value) 151 | injected_lines.append(f"{key} = {python_value}") 152 | injected_code = "\n".join(injected_lines) + "\n" + code 153 | return injected_code 154 | 155 | def _serialize_value(self, value: Any) -> str: 156 | # Basic safe serialization 157 | if isinstance(value, str): 158 | return repr(value) 159 | elif isinstance(value, (int, float, bool)): 160 | return str(value) 161 | elif value is None: 162 | return "None" 163 | elif isinstance(value, list) or isinstance(value, dict): 164 | return json.dumps(value) 165 | else: 166 | raise InterpreterError(f"Unsupported value type: {type(value).__name__}") 167 | 168 | def execute( 169 | self, 170 | code: str, 171 | variables: dict[str, Any] | None = None, 172 | ) -> Any: 173 | variables = variables or {} 174 | code = self._inject_variables(code, variables) 175 | self._ensure_deno_process() 176 | self._mount_files() 177 | 178 | # Send the code as JSON 179 | input_data = json.dumps({"code": code}) 180 | try: 181 | self.deno_process.stdin.write(input_data + "\n") 182 | self.deno_process.stdin.flush() 183 | except BrokenPipeError: 184 | # If the process died, restart and try again once 185 | self._ensure_deno_process() 186 | self.deno_process.stdin.write(input_data + "\n") 187 | self.deno_process.stdin.flush() 188 | 189 | # Read one JSON line from stdout 190 | output_line = self.deno_process.stdout.readline().strip() 191 | if not output_line: 192 | # Possibly the subprocess died or gave no output 193 | err_output = self.deno_process.stderr.read() 194 | raise InterpreterError(f"No output from Deno subprocess. Stderr: {err_output}") 195 | 196 | # Parse that line as JSON 197 | try: 198 | result = json.loads(output_line) 199 | except json.JSONDecodeError: 200 | # If not valid JSON, just return raw text 201 | result = {"output": output_line} 202 | 203 | # If we have an error, determine if it's a SyntaxError or other error using error.errorType. 204 | if "error" in result: 205 | error_msg = result["error"] 206 | error_type = result.get("errorType", "Sandbox Error") 207 | if error_type == "FinalAnswer": 208 | # The `FinalAnswer` trick to receive output from the sandbox interpreter, 209 | # just simply replace the output with the arguments. 210 | result["output"] = result.get("errorArgs", None) 211 | elif error_type == "SyntaxError": 212 | raise SyntaxError(f"Invalid Python syntax. 
message: {error_msg}") 213 | else: 214 | raise InterpreterError(f"{error_type}: {result.get('errorArgs') or error_msg}") 215 | 216 | # If there's no error or got `FinalAnswer`, return the "output" field 217 | self._sync_files() 218 | return result.get("output", None) 219 | 220 | def __enter__(self): 221 | return self 222 | 223 | # All exception fields are ignored and the runtime will automatically re-raise the exception 224 | def __exit__( 225 | self, 226 | _exc_type: type[BaseException] | None, 227 | _exc_val: BaseException | None, 228 | _exc_tb: TracebackType | None, 229 | ): 230 | self.shutdown() 231 | 232 | def __call__( 233 | self, 234 | code: str, 235 | variables: dict[str, Any] | None = None, 236 | ) -> Any: 237 | return self.execute(code, variables) 238 | 239 | def shutdown(self) -> None: 240 | if self.deno_process and self.deno_process.poll() is None: 241 | shutdown_message = json.dumps({"shutdown": True}) + "\n" 242 | self.deno_process.stdin.write(shutdown_message) 243 | self.deno_process.stdin.flush() 244 | self.deno_process.stdin.close() 245 | self.deno_process.wait() 246 | self.deno_process = None 247 | ``` -------------------------------------------------------------------------------- /tests/reliability/test_pydantic_models.py: -------------------------------------------------------------------------------- ```python 1 | from enum import Enum 2 | from typing import Any, List, Literal 3 | 4 | import pydantic 5 | import pytest 6 | 7 | import dspy 8 | from tests.reliability.utils import assert_program_output_correct, known_failing_models 9 | 10 | @pytest.mark.reliability 11 | def test_qa_with_pydantic_answer_model(): 12 | class Answer(pydantic.BaseModel): 13 | value: str 14 | certainty: float = pydantic.Field( 15 | description="A value between 0 and 1 indicating the model's confidence in the answer." 16 | ) 17 | comments: list[str] = pydantic.Field( 18 | description="At least two comments providing additional details about the answer." 19 | ) 20 | 21 | class QA(dspy.Signature): 22 | question: str = dspy.InputField() 23 | answer: Answer = dspy.OutputField() 24 | 25 | program = dspy.Predict(QA) 26 | question = "What is the capital of France?" 27 | answer = program(question=question).answer 28 | 29 | assert_program_output_correct( 30 | program_input=question, 31 | program_output=answer.value, 32 | grading_guidelines="The answer should be Paris. Answer should not contain extraneous information.", 33 | ) 34 | assert_program_output_correct( 35 | program_input=question, 36 | program_output=answer.comments, 37 | grading_guidelines=( 38 | "The comments should be relevant to the answer. They don't need to restate the answer explicitly." 39 | ), 40 | ) 41 | assert answer.certainty >= 0 42 | assert answer.certainty <= 1 43 | assert len(answer.comments) >= 2 44 | 45 | 46 | @pytest.mark.parametrize("module", [dspy.Predict, dspy.ChainOfThought]) 47 | @pytest.mark.reliability 48 | def test_color_classification_using_enum(module): 49 | Color = Enum("Color", ["RED", "GREEN", "BLUE"]) 50 | 51 | class Colorful(dspy.Signature): 52 | text: str = dspy.InputField() 53 | color: Color = dspy.OutputField() 54 | 55 | program = module(Colorful) 56 | # Note: The precise text, including the trailing period, is important here for ensuring that 57 | # the program is correctly extracting the color from the text; previous implementations have 58 | # produced invalid enum responses for "The sky is blue.", but they have produced valid enum 59 | # responses for "The sky is blue" (without the period). 
60 | color = program(text="The sky is blue.").color 61 | 62 | assert color == Color.BLUE 63 | 64 | 65 | @pytest.mark.reliability 66 | def test_entity_extraction_with_multiple_primitive_outputs(): 67 | class ExtractEntityFromDescriptionOutput(pydantic.BaseModel): 68 | entity_hu: str = pydantic.Field(description="The extracted entity in Hungarian, cleaned and lowercased.") 69 | entity_en: str = pydantic.Field(description="The English translation of the extracted Hungarian entity.") 70 | is_inverted: bool = pydantic.Field( 71 | description="Boolean flag indicating if the input is connected in an inverted way." 72 | ) 73 | categories: str = pydantic.Field(description="English categories separated by '|' to which the entity belongs.") 74 | review: bool = pydantic.Field( 75 | description="Boolean flag indicating low confidence or uncertainty in the extraction." 76 | ) 77 | 78 | class ExtractEntityFromDescription(dspy.Signature): 79 | """Extract an entity from a Hungarian description, provide its English translation, categories, and an inverted flag.""" 80 | 81 | description: str = dspy.InputField(description="The input description in Hungarian.") 82 | entity: ExtractEntityFromDescriptionOutput = dspy.OutputField( 83 | description="The extracted entity and its properties." 84 | ) 85 | 86 | program = dspy.ChainOfThought(ExtractEntityFromDescription) 87 | description = "A kávé egy növényi eredetű ital, amelyet a kávébabból készítenek." 88 | 89 | extracted_entity = program(description=description).entity 90 | assert_program_output_correct( 91 | program_input=description, 92 | program_output=extracted_entity.entity_hu, 93 | grading_guidelines="The translation of the extracted entity into English should be equivalent to 'coffee'", 94 | ) 95 | assert_program_output_correct( 96 | program_input=description, 97 | program_output=extracted_entity.entity_en, 98 | grading_guidelines="The extracted entity should be equivalent to 'coffee'", 99 | ) 100 | assert_program_output_correct( 101 | program_input=description, 102 | program_output=extracted_entity.categories, 103 | grading_guidelines=( 104 | "The extracted entity should be associated with English language categories that apply to the word 'coffee'." 105 | " The categories should be separated by the character '|'." 106 | ), 107 | ) 108 | 109 | 110 | @pytest.mark.parametrize("module", [dspy.Predict, dspy.ChainOfThought]) 111 | @pytest.mark.reliability 112 | def test_tool_calling_with_literals(module): 113 | next_tool_names = [ 114 | "get_docs", 115 | "finish", 116 | "search_policy", 117 | "notify_manager", 118 | "calculate_accrual", 119 | "combine_leave", 120 | "review_seniority_rules", 121 | "fetch_calendar", 122 | "verify_compensation", 123 | "check_carryover_policy", 124 | ] 125 | 126 | class ToolCalling(dspy.Signature): 127 | """ 128 | Given the fields question, produce the fields response. 129 | You will be given question and your goal is to finish with response. 130 | To do this, you will interleave Thought, Tool Name, and Tool Args, and receive a resulting Observation. 
131 | """ 132 | 133 | question: str = dspy.InputField() 134 | trajectory: str = dspy.InputField() 135 | next_thought: str = dspy.OutputField() 136 | next_tool_name: Literal[ 137 | "get_docs", 138 | "finish", 139 | "search_policy", 140 | "notify_manager", 141 | "calculate_accrual", 142 | "combine_leave", 143 | "review_seniority_rules", 144 | "fetch_calendar", 145 | "verify_compensation", 146 | "check_carryover_policy", 147 | ] = dspy.OutputField() 148 | next_tool_args: dict[str, Any] = dspy.OutputField() 149 | response_status: Literal["success", "error", "pending"] = dspy.OutputField() 150 | user_intent: Literal["informational", "transactional", "exploratory"] = dspy.OutputField() 151 | 152 | program = dspy.Predict(ToolCalling) 153 | prediction = program( 154 | question=( 155 | "Tell me more about the company's internal policy for paid time off (PTO), " 156 | "including as many details as possible. I want to know how PTO is accrued—are " 157 | "there fixed rates, and do they vary by employee seniority or length of service? " 158 | "Are there specific rules about carrying unused PTO into the next calendar year, " 159 | "or is it a 'use it or lose it' system? Additionally, if an employee plans to take " 160 | "extended leave for a vacation or personal reasons, what is the process for submitting " 161 | "a request, and how far in advance should they notify their manager? Is there any overlap " 162 | "or interaction between PTO and other forms of leave, such as sick leave or parental leave? " 163 | "For example, can PTO be combined with those leave types to create a longer absence, or are " 164 | "they strictly separate? I’d also like to know if there are any restrictions on when PTO can " 165 | "be used, such as during critical business periods or holidays. Finally, what is the official " 166 | "policy if an employee resigns or is terminated—are they compensated for unused PTO days, and if " 167 | "so, at what rate?" 
168 | ), 169 | trajectory=( 170 | "[" 171 | "{'thought': 'Start by understanding PTO accrual rules.', 'tool_name': 'search_policy', 'tool_args': {'topic': 'PTO accrual rates'}}, " 172 | "{'thought': 'Clarify whether PTO accrual rates vary by seniority.', 'tool_name': 'review_seniority_rules', 'tool_args': {}}, " 173 | "{'thought': 'Identify carryover rules for unused PTO.', 'tool_name': 'check_carryover_policy', 'tool_args': {'year': 'current year'}}, " 174 | "{'thought': 'Determine policies on extended leave requests.', 'tool_name': 'search_policy', 'tool_args': {'topic': 'PTO leave request process'}}, " 175 | "{'thought': 'Check the notification requirements for extended PTO.', 'tool_name': 'notify_manager', 'tool_args': {'type': 'extended leave'}}, " 176 | "{'thought': 'Investigate overlap between PTO and sick leave.', 'tool_name': 'combine_leave', 'tool_args': {'types': ['PTO', 'sick leave']}}, " 177 | "{'thought': 'Explore how PTO interacts with parental leave.', 'tool_name': 'combine_leave', 'tool_args': {'types': ['PTO', 'parental leave']}}, " 178 | "{'thought': 'Fetch the company calendar to determine critical business periods.', 'tool_name': 'fetch_calendar', 'tool_args': {'year': 'current year'}}, " 179 | "{'thought': 'Verify restrictions on PTO usage during holidays.', 'tool_name': 'search_policy', 'tool_args': {'topic': 'holiday restrictions on PTO'}}, " 180 | "{'thought': 'Confirm whether unused PTO is compensated upon termination.', 'tool_name': 'verify_compensation', 'tool_args': {'scenario': 'termination'}}, " 181 | "{'thought': 'Check if PTO is compensated differently upon resignation.', 'tool_name': 'verify_compensation', 'tool_args': {'scenario': 'resignation'}}, " 182 | "{'thought': 'Review if accrual caps limit PTO earnings.', 'tool_name': 'calculate_accrual', 'tool_args': {'cap': True}}, " 183 | "{'thought': 'Investigate whether senior employees receive additional PTO benefits.', 'tool_name': 'review_seniority_rules', 'tool_args': {'seniority_level': 'high'}}, " 184 | "{'thought': 'Assess policy transparency in PTO documentation.', 'tool_name': 'search_policy', 'tool_args': {'topic': 'PTO documentation clarity'}}, " 185 | "{'thought': 'Explore how leave types can be optimized together.', 'tool_name': 'combine_leave', 'tool_args': {'types': ['PTO', 'other leave']}}, " 186 | "{'thought': 'Check historical trends in PTO policy changes.', 'tool_name': 'get_docs', 'tool_args': {'document': 'PTO history'}}" 187 | "]" 188 | ), 189 | ) 190 | assert prediction.next_tool_name in next_tool_names 191 | ``` -------------------------------------------------------------------------------- /docs/docs/tutorials/observability/index.md: -------------------------------------------------------------------------------- ```markdown 1 | # Tutorial: Debugging and Observability in DSPy 2 | 3 | This guide demonstrates how to debug problems and improve observability in DSPy. Modern AI programs often involve multiple components, such as language models, retrievers, and tools. DSPy allows you to build and optimize such complex AI systems in a clean and modular way. 4 | 5 | However, as systems grow more sophisticated, the ability to **understand what your system is doing** becomes critical. Without transparency, the prediction process can easily become a black box, making failures or quality issues difficult to diagnose and production maintenance challenging. 6 | 7 | By the end of this tutorial, you'll understand how to debug an issue and improve observability using [MLflow Tracing](#tracing). 
You'll also explore how to build a custom logging solution using callbacks.
8 | 
9 | 
10 | 
11 | ## Define a Program
12 | 
13 | We'll start by creating a simple ReAct agent that uses ColBERTv2's Wikipedia dataset as a retrieval source. You can replace this with a more sophisticated program.
14 | 
15 | ```python
16 | import dspy
17 | import os
18 | 
19 | os.environ["OPENAI_API_KEY"] = "{your_openai_api_key}"
20 | 
21 | lm = dspy.LM("openai/gpt-4o-mini")
22 | colbert = dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts")
23 | dspy.configure(lm=lm)
24 | 
25 | 
26 | def retrieve(query: str):
27 |     """Retrieve the top 3 relevant passages from ColBERTv2"""
28 |     results = colbert(query, k=3)
29 |     return [x["text"] for x in results]
30 | 
31 | 
32 | agent = dspy.ReAct("question -> answer", tools=[retrieve], max_iters=3)
33 | ```
34 | 
35 | Now, let's ask the agent a simple question:
36 | 
37 | ```python
38 | prediction = agent(question="Which baseball team does Shohei Ohtani play for in June 2025?")
39 | print(prediction.answer)
40 | ```
41 | 
42 | ```
43 | Shohei Ohtani is expected to play for the Hokkaido Nippon-Ham Fighters in June 2025, based on the available information.
44 | ```
45 | 
46 | Oh, this is incorrect. He no longer plays for the Hokkaido Nippon-Ham Fighters; he moved to the Dodgers and won the World Series in 2024! Let's debug the program and explore potential fixes.
47 | 
48 | ## Using `inspect_history`
49 | 
50 | DSPy provides the `inspect_history()` utility, which prints out all LLM invocations made so far:
51 | 
52 | ```python
53 | # Print out 5 LLM calls
54 | dspy.inspect_history(n=5)
55 | ```
56 | 
57 | ```
58 | [2024-12-01T10:23:29.144257]
59 | 
60 | System message:
61 | 
62 | Your input fields are:
63 | 1. `question` (str)
64 | 
65 | ...
66 | 
67 | Response:
68 | 
69 | 
70 | 
71 | [[ ## reasoning ## ]]
72 | The search for information regarding Shohei Ohtani's team in June 2025 did not yield any specific results. The retrieved data consistently mentioned that he plays for the Hokkaido Nippon-Ham Fighters, but there was no indication of any changes or updates regarding his team for the specified date. Given the lack of information, it is reasonable to conclude that he may still be with the Hokkaido Nippon-Ham Fighters unless there are future developments that are not captured in the current data.
73 | 
74 | [[ ## answer ## ]]
75 | Shohei Ohtani is expected to play for the Hokkaido Nippon-Ham Fighters in June 2025, based on the available information.
76 | 
77 | [[ ## completed ## ]]
78 | 
79 | ```
80 | The log reveals that the agent could not retrieve helpful information from the search tool. However, what exactly did the retriever return? While useful, `inspect_history` has some limitations:
81 | 
82 | * In real-world systems, other components like retrievers, tools, and custom modules play significant roles, but `inspect_history` only logs LLM calls.
83 | * DSPy programs often make multiple LLM calls within a single prediction. A monolithic log history makes it hard to organize logs, especially when handling multiple questions.
84 | * Metadata such as parameters, latency, and the relationship between modules is not captured.
85 | 
86 | **Tracing** addresses these limitations and provides a more comprehensive solution.
87 | 
88 | ## Tracing
89 | 
90 | [MLflow](https://mlflow.org/docs/latest/llms/tracing/index.html) is an end-to-end machine learning platform that is integrated seamlessly with DSPy to support best practices in LLMOps. Using MLflow's automatic tracing capability with DSPy is straightforward; **no sign-up for services or an API key is required**. You just need to install MLflow and call `mlflow.dspy.autolog()` in your notebook or script.
91 | 
92 | ```bash
93 | pip install -U mlflow>=2.18.0
94 | ```
95 | 
96 | After installation, spin up your server via the command below.
97 | 
98 | ```
99 | # It is highly recommended to use SQL store when using MLflow tracing
100 | mlflow server --backend-store-uri sqlite:///mydb.sqlite
101 | ```
102 | 
103 | If you don't specify a different port via the `--port` flag, your MLflow server will be hosted at port 5000.
104 | 
105 | Now let's change our code snippet to enable MLflow tracing. We need to:
106 | 
107 | - Tell MLflow where the server is hosted.
108 | - Call `mlflow.dspy.autolog()` so that DSPy tracing is automatically captured.
109 | 
110 | The full code is below; now let's run it again!
111 | 
112 | ```python
113 | import dspy
114 | import os
115 | import mlflow
116 | 
117 | os.environ["OPENAI_API_KEY"] = "{your_openai_api_key}"
118 | 
119 | # Tell MLflow about the server URI.
120 | mlflow.set_tracking_uri("http://127.0.0.1:5000")
121 | # Create a unique name for your experiment.
122 | mlflow.set_experiment("DSPy")
123 | mlflow.dspy.autolog()  # Capture DSPy traces automatically.
124 | lm = dspy.LM("openai/gpt-4o-mini")
125 | colbert = dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts")
126 | dspy.configure(lm=lm)
127 | 
128 | 
129 | def retrieve(query: str):
130 |     """Retrieve the top 3 relevant passages from ColBERTv2"""
131 |     results = colbert(query, k=3)
132 |     return [x["text"] for x in results]
133 | 
134 | 
135 | agent = dspy.ReAct("question -> answer", tools=[retrieve], max_iters=3)
136 | print(agent(question="Which baseball team does Shohei Ohtani play for?"))
137 | ```
138 | 
139 | 
140 | MLflow automatically generates a **trace** for each prediction and records it within your experiment. To explore these traces visually, open `http://127.0.0.1:5000/`
141 | in your browser, then select the experiment you just created and navigate to the Traces tab:
142 | 
143 | 
144 | 
145 | Click on the most recent trace to view its detailed breakdown:
146 | 
147 | 
148 | 
149 | Here, you can examine the input and output of every step in your workflow. For example, the screenshot above shows the `retrieve` function's input and output. By inspecting the retriever's output, you can see that it returned outdated information, which is not sufficient to determine which team Shohei Ohtani plays for in June 2025. You can also inspect
150 | other steps, e.g., the language model's input, output, and configuration.
151 | 
152 | To address the issue of outdated information, you can replace the `retrieve` function with a web search tool powered by [Tavily search](https://www.tavily.com/).
153 | 
154 | ```python
155 | from tavily import TavilyClient
156 | import dspy
157 | import mlflow
158 | 
159 | # Tell MLflow about the server URI.
160 | mlflow.set_tracking_uri("http://127.0.0.1:5000")
161 | # Create a unique name for your experiment.
162 | mlflow.set_experiment("DSPy")
163 | 
164 | search_client = TavilyClient(api_key="<YOUR_TAVILY_API_KEY>")
165 | 
166 | def web_search(query: str) -> list[str]:
167 |     """Run a web search and return the content from the top 5 search results"""
168 |     response = search_client.search(query)
169 |     return [r["content"] for r in response["results"]]
170 | 
171 | agent = dspy.ReAct("question -> answer", tools=[web_search])
172 | 
173 | prediction = agent(question="Which baseball team does Shohei Ohtani play for?")
174 | print(prediction.answer)
175 | ```
176 | 
177 | ```
178 | Los Angeles Dodgers
179 | ```
180 | 
181 | Below is a GIF demonstrating how to navigate through the MLflow UI:
182 | 
183 | 
184 | 
185 | 
186 | For a complete guide on how to use MLflow tracing, please refer to
187 | the [MLflow Tracing Guide](https://mlflow.org/docs/3.0.0rc0/tracing).
188 | 
189 | 
190 | 
191 | !!! info "Learn more about MLflow"
192 | 
193 |     MLflow is an end-to-end LLMOps platform that offers extensive features like experiment tracking, evaluation, and deployment. To learn more about DSPy and MLflow integration, visit [this tutorial](../deployment/index.md#deploying-with-mlflow).
194 | 
195 | 
196 | ## Building a Custom Logging Solution
197 | 
198 | Sometimes, you may want to implement a custom logging solution. For instance, you might need to log specific events triggered by a particular module. DSPy's **callback** mechanism supports such use cases. The `BaseCallback` class provides several handlers for customizing logging behavior:
199 | 
200 | |Handlers|Description|
201 | |:--|:--|
202 | |`on_module_start` / `on_module_end` | Triggered when a `dspy.Module` subclass is invoked. |
203 | |`on_lm_start` / `on_lm_end` | Triggered when a `dspy.LM` subclass is invoked. |
204 | |`on_adapter_format_start` / `on_adapter_format_end`| Triggered when a `dspy.Adapter` subclass formats the input prompt. |
205 | |`on_adapter_parse_start` / `on_adapter_parse_end`| Triggered when a `dspy.Adapter` subclass postprocesses the output text from an LM. |
206 | |`on_tool_start` / `on_tool_end` | Triggered when a `dspy.Tool` subclass is invoked. |
207 | |`on_evaluate_start` / `on_evaluate_end` | Triggered when a `dspy.Evaluate` instance is invoked. |
208 | 
209 | Here's an example of a custom callback that logs the intermediate steps of a ReAct agent:
210 | 
211 | ```python
212 | import dspy
213 | from dspy.utils.callback import BaseCallback
214 | 
215 | # 1. Define a custom callback class that extends the BaseCallback class
216 | class AgentLoggingCallback(BaseCallback):
217 | 
218 |     # 2. Implement the on_module_end handler to run custom logging code.
219 |     def on_module_end(self, call_id, outputs, exception):
220 |         step = "Reasoning" if self._is_reasoning_output(outputs) else "Acting"
221 |         print(f"== {step} Step ===")
222 |         for k, v in outputs.items():
223 |             print(f"  {k}: {v}")
224 |         print("\n")
225 | 
226 |     def _is_reasoning_output(self, outputs):
227 |         return any(k.startswith("Thought") for k in outputs.keys())
228 | 
229 | # 3. Register the callback in DSPy settings so it is applied during program execution
230 | dspy.configure(callbacks=[AgentLoggingCallback()])
231 | ```
232 | 
233 | 
234 | ```
235 | == Reasoning Step ===
236 | Thought_1: I need to find the current team that Shohei Ohtani plays for in Major League Baseball.
237 | Action_1: Search[Shohei Ohtani current team 2023]
238 | 
239 | == Acting Step ===
240 | passages: ["Shohei Ohtani ..."]
241 | 
242 | ...
243 | 
244 | ```
245 | !!! info "Handling Inputs and Outputs in Callbacks"
246 | 
247 |     Be cautious when working with input or output data in callbacks. Mutating them in-place can modify the original data passed to the program, potentially leading to unexpected behavior. To avoid this, it's strongly recommended to create a copy of the data before performing any operations that may alter it.
248 | ```
--------------------------------------------------------------------------------
/dspy/predict/predict.py:
--------------------------------------------------------------------------------
```python
1 | import logging
2 | import random
3 | 
4 | from pydantic import BaseModel
5 | 
6 | from dspy.adapters.chat_adapter import ChatAdapter
7 | from dspy.clients.base_lm import BaseLM
8 | from dspy.clients.lm import LM
9 | from dspy.dsp.utils.settings import settings
10 | from dspy.predict.parameter import Parameter
11 | from dspy.primitives.module import Module
12 | from dspy.primitives.prediction import Prediction
13 | from dspy.signatures.signature import Signature, ensure_signature
14 | from dspy.utils.callback import BaseCallback
15 | 
16 | logger = logging.getLogger(__name__)
17 | 
18 | 
19 | class Predict(Module, Parameter):
20 |     """Basic DSPy module that maps inputs to outputs using a language model.
21 | 
22 |     Args:
23 |         signature: The input/output signature describing the task.
24 |         callbacks: Optional list of callbacks for instrumentation.
25 |         **config: Default keyword arguments forwarded to the underlying
26 |             language model. These values can be overridden for a single
27 |             invocation by passing a ``config`` dictionary when calling the
28 |             module. For example::
29 | 
30 |                 predict = dspy.Predict("q -> a", rollout_id=1, temperature=1.0)
31 |                 predict(q="What is 1 + 52?", config={"rollout_id": 2, "temperature": 1.0})
32 |     """
33 | 
34 |     def __init__(self, signature: str | type[Signature], callbacks: list[BaseCallback] | None = None, **config):
35 |         super().__init__(callbacks=callbacks)
36 |         self.stage = random.randbytes(8).hex()
37 |         self.signature = ensure_signature(signature)
38 |         self.config = config
39 |         self.reset()
40 | 
41 |     def reset(self):
42 |         self.lm = None
43 |         self.traces = []
44 |         self.train = []
45 |         self.demos = []
46 | 
47 |     def dump_state(self, json_mode=True):
48 |         state_keys = ["traces", "train"]
49 |         state = {k: getattr(self, k) for k in state_keys}
50 | 
51 |         state["demos"] = []
52 |         for demo in self.demos:
53 |             demo = demo.copy()
54 | 
55 |             for field in demo:
56 |                 # FIXME: Saving BaseModels as strings in examples doesn't matter because you never re-access them as objects
57 |                 demo[field] = serialize_object(demo[field])
58 | 
59 |             if isinstance(demo, dict) or not json_mode:
60 |                 state["demos"].append(demo)
61 |             else:
62 |                 state["demos"].append(demo.toDict())
63 | 
64 |         state["signature"] = self.signature.dump_state()
65 |         state["lm"] = self.lm.dump_state() if self.lm else None
66 |         return state
67 | 
68 |     def load_state(self, state: dict) -> "Predict":
69 |         """Load the saved state of a `Predict` object.
70 | 
71 |         Args:
72 |             state: The saved state of a `Predict` object.
73 | 
74 |         Returns:
75 |             Self to allow method chaining.
76 |         """
77 |         excluded_keys = ["signature", "extended_signature", "lm"]
78 |         for name, value in state.items():
79 |             # `excluded_keys` are fields that go through special handling.
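            # For example, `signature` and `lm` are reconstructed below rather than set directly via `setattr`.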
80 | if name not in excluded_keys: 81 | setattr(self, name, value) 82 | 83 | self.signature = self.signature.load_state(state["signature"]) 84 | self.lm = LM(**state["lm"]) if state["lm"] else None 85 | 86 | if "extended_signature" in state: # legacy, up to and including 2.5, for CoT. 87 | raise NotImplementedError("Loading extended_signature is no longer supported in DSPy 2.6+") 88 | 89 | return self 90 | 91 | def _get_positional_args_error_message(self): 92 | input_fields = list(self.signature.input_fields.keys()) 93 | return ( 94 | "Positional arguments are not allowed when calling `dspy.Predict`, must use keyword arguments " 95 | f"that match your signature input fields: '{', '.join(input_fields)}'. For example: " 96 | f"`predict({input_fields[0]}=input_value, ...)`." 97 | ) 98 | 99 | def __call__(self, *args, **kwargs): 100 | if args: 101 | raise ValueError(self._get_positional_args_error_message()) 102 | 103 | return super().__call__(**kwargs) 104 | 105 | async def acall(self, *args, **kwargs): 106 | if args: 107 | raise ValueError(self._get_positional_args_error_message()) 108 | 109 | return await super().acall(**kwargs) 110 | 111 | def _forward_preprocess(self, **kwargs): 112 | # Extract the three privileged keyword arguments. 113 | assert "new_signature" not in kwargs, "new_signature is no longer a valid keyword argument." 114 | signature = ensure_signature(kwargs.pop("signature", self.signature)) 115 | demos = kwargs.pop("demos", self.demos) 116 | config = {**self.config, **kwargs.pop("config", {})} 117 | 118 | # Get the right LM to use. 119 | lm = kwargs.pop("lm", self.lm) or settings.lm 120 | 121 | if lm is None: 122 | raise ValueError( 123 | "No LM is loaded. Please configure the LM using `dspy.configure(lm=dspy.LM(...))`. e.g, " 124 | "`dspy.configure(lm=dspy.LM('openai/gpt-4o-mini'))`" 125 | ) 126 | 127 | if isinstance(lm, str): 128 | # Many users mistakenly use `dspy.configure(lm="openai/gpt-4o-mini")` instead of 129 | # `dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))`, so we are providing a specific error message. 130 | raise ValueError( 131 | f"LM must be an instance of `dspy.BaseLM`, not a string. Instead of using a string like " 132 | f"'dspy.configure(lm=\"{lm}\")', please configure the LM like 'dspy.configure(lm=dspy.LM(\"{lm}\"))'" 133 | ) 134 | elif not isinstance(lm, BaseLM): 135 | raise ValueError(f"LM must be an instance of `dspy.BaseLM`, not {type(lm)}. Received `lm={lm}`.") 136 | 137 | # If temperature is unset or <=0.15, and n > 1, set temperature to 0.7 to keep randomness. 138 | temperature = config.get("temperature") or lm.kwargs.get("temperature") 139 | num_generations = config.get("n") or lm.kwargs.get("n") or lm.kwargs.get("num_generations") or 1 140 | 141 | if (temperature is None or temperature <= 0.15) and num_generations > 1: 142 | config["temperature"] = 0.7 143 | 144 | if "prediction" in kwargs: 145 | if ( 146 | isinstance(kwargs["prediction"], dict) 147 | and kwargs["prediction"].get("type") == "content" 148 | and "content" in kwargs["prediction"] 149 | ): 150 | # If the `prediction` is the standard predicted outputs format 151 | # (https://platform.openai.com/docs/guides/predicted-outputs), we remove it from input kwargs and add it 152 | # to the lm kwargs. 
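                # e.g., predict(question=..., prediction={"type": "content", "content": "<expected output text>"})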
153 | config["prediction"] = kwargs.pop("prediction") 154 | 155 | if not all(k in kwargs for k in signature.input_fields): 156 | present = [k for k in signature.input_fields if k in kwargs] 157 | missing = [k for k in signature.input_fields if k not in kwargs] 158 | logger.warning( 159 | "Not all input fields were provided to module. Present: %s. Missing: %s.", 160 | present, 161 | missing, 162 | ) 163 | return lm, config, signature, demos, kwargs 164 | 165 | def _forward_postprocess(self, completions, signature, **kwargs): 166 | pred = Prediction.from_completions(completions, signature=signature) 167 | if kwargs.pop("_trace", True) and settings.trace is not None and settings.max_trace_size > 0: 168 | trace = settings.trace 169 | if len(trace) >= settings.max_trace_size: 170 | trace.pop(0) 171 | trace.append((self, {**kwargs}, pred)) 172 | return pred 173 | 174 | def _should_stream(self): 175 | stream_listeners = settings.stream_listeners or [] 176 | should_stream = settings.send_stream is not None 177 | if should_stream and len(stream_listeners) > 0: 178 | should_stream = any(stream_listener.predict == self for stream_listener in stream_listeners) 179 | 180 | return should_stream 181 | 182 | def forward(self, **kwargs): 183 | lm, config, signature, demos, kwargs = self._forward_preprocess(**kwargs) 184 | 185 | adapter = settings.adapter or ChatAdapter() 186 | 187 | if self._should_stream(): 188 | with settings.context(caller_predict=self): 189 | completions = adapter(lm, lm_kwargs=config, signature=signature, demos=demos, inputs=kwargs) 190 | else: 191 | with settings.context(send_stream=None): 192 | completions = adapter(lm, lm_kwargs=config, signature=signature, demos=demos, inputs=kwargs) 193 | 194 | return self._forward_postprocess(completions, signature, **kwargs) 195 | 196 | async def aforward(self, **kwargs): 197 | lm, config, signature, demos, kwargs = self._forward_preprocess(**kwargs) 198 | 199 | adapter = settings.adapter or ChatAdapter() 200 | if self._should_stream(): 201 | with settings.context(caller_predict=self): 202 | completions = await adapter.acall(lm, lm_kwargs=config, signature=signature, demos=demos, inputs=kwargs) 203 | else: 204 | with settings.context(send_stream=None): 205 | completions = await adapter.acall(lm, lm_kwargs=config, signature=signature, demos=demos, inputs=kwargs) 206 | 207 | return self._forward_postprocess(completions, signature, **kwargs) 208 | 209 | def update_config(self, **kwargs): 210 | self.config = {**self.config, **kwargs} 211 | 212 | def get_config(self): 213 | return self.config 214 | 215 | def __repr__(self): 216 | return f"{self.__class__.__name__}({self.signature})" 217 | 218 | 219 | def serialize_object(obj): 220 | """ 221 | Recursively serialize a given object into a JSON-compatible format. 222 | Supports Pydantic models, lists, dicts, and primitive types. 223 | """ 224 | if isinstance(obj, BaseModel): 225 | # Use model_dump with mode="json" to ensure all fields (including HttpUrl, datetime, etc.) 226 | # are converted to JSON-serializable types (strings) 227 | return obj.model_dump(mode="json") 228 | elif isinstance(obj, list): 229 | return [serialize_object(item) for item in obj] 230 | elif isinstance(obj, tuple): 231 | return tuple(serialize_object(item) for item in obj) 232 | elif isinstance(obj, dict): 233 | return {key: serialize_object(value) for key, value in obj.items()} 234 | else: 235 | return obj 236 | 237 | 238 | # # TODO: FIXME: Hmm, I guess expected behavior is that contexts can 239 | # affect execution. 
Well, we need to determine whether context dominates, __init__ dominates, or forward dominates.
240 | # Generally, unless overwritten, we'd see n=None, temperature=None.
241 | # That will eventually mean we have to learn them.
242 | ```
--------------------------------------------------------------------------------
/dspy/adapters/baml_adapter.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Custom adapter for improving structured outputs using the information from Pydantic models.
3 | Based on the format used by BAML: https://github.com/BoundaryML/baml
4 | """
5 | 
6 | import inspect
7 | import types
8 | from typing import Any, Literal, Union, get_args, get_origin
9 | 
10 | from pydantic import BaseModel
11 | 
12 | from dspy.adapters.json_adapter import JSONAdapter
13 | from dspy.adapters.utils import format_field_value as original_format_field_value
14 | from dspy.signatures.signature import Signature
15 | 
16 | # Changing the comment symbol to Python's # rather than other languages' // seems to help
17 | COMMENT_SYMBOL = "#"
18 | 
19 | 
20 | def _render_type_str(
21 |     annotation: Any,
22 |     depth: int = 0,
23 |     indent: int = 0,
24 |     seen_models: set[type] | None = None,
25 | ) -> str:
26 |     """Recursively renders a type annotation into a simplified string.
27 | 
28 |     Args:
29 |         annotation: The type annotation to render
30 |         depth: Current recursion depth (prevents infinite recursion)
31 |         indent: Current indentation level for nested structures
32 |     """
33 |     # Non-nested types
34 |     if annotation is str:
35 |         return "string"
36 |     if annotation is int:
37 |         return "int"
38 |     if annotation is float:
39 |         return "float"
40 |     if annotation is bool:
41 |         return "boolean"
42 |     if inspect.isclass(annotation) and issubclass(annotation, BaseModel):
43 |         return _build_simplified_schema(annotation, indent, seen_models)
44 | 
45 |     try:
46 |         origin = get_origin(annotation)
47 |         args = get_args(annotation)
48 |     except Exception:
49 |         return str(annotation)
50 | 
51 |     # Optional[T] or T | None
52 |     if origin in (types.UnionType, Union):
53 |         non_none_args = [arg for arg in args if arg is not type(None)]
54 |         # Render the non-None part of the union
55 |         type_render = " or ".join([_render_type_str(arg, depth + 1, indent) for arg in non_none_args])
56 |         # Add "or null" if None was part of the union
57 |         if len(non_none_args) < len(args):
58 |             return f"{type_render} or null"
59 |         return type_render
60 | 
61 |     # Literal[T1, T2, ...]
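    # e.g., Literal["US", "CA"] is rendered as: "US" or "CA"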
62 | if origin is Literal: 63 | return " or ".join(f'"{arg}"' for arg in args) 64 | 65 | # list[T] 66 | if origin is list: 67 | # For Pydantic models in lists, use bracket notation 68 | inner_type = args[0] 69 | if inspect.isclass(inner_type) and issubclass(inner_type, BaseModel): 70 | # Build inner schema - the Pydantic model inside should use indent level for array contents 71 | inner_schema = _build_simplified_schema(inner_type, indent + 1, seen_models) 72 | # Format with proper bracket notation and indentation 73 | current_indent = " " * indent 74 | return f"[\n{inner_schema}\n{current_indent}]" 75 | else: 76 | return f"{_render_type_str(inner_type, depth + 1, indent)}[]" 77 | 78 | # dict[T1, T2] 79 | if origin is dict: 80 | return f"dict[{_render_type_str(args[0], depth + 1, indent)}, {_render_type_str(args[1], depth + 1, indent)}]" 81 | 82 | # fallback 83 | if hasattr(annotation, "__name__"): 84 | return annotation.__name__ 85 | return str(annotation) 86 | 87 | 88 | def _build_simplified_schema( 89 | pydantic_model: type[BaseModel], 90 | indent: int = 0, 91 | seen_models: set[type] | None = None, 92 | ) -> str: 93 | """Builds a simplified, human-readable schema from a Pydantic model. 94 | 95 | Args: 96 | pydantic_model: The Pydantic model to build schema for 97 | indent: Current indentation level 98 | seen_models: Set to track visited pydantic models (prevents infinite recursion) 99 | """ 100 | seen_models = seen_models or set() 101 | 102 | if pydantic_model in seen_models: 103 | raise ValueError("BAMLAdapter cannot handle recursive pydantic models, please use a different adapter.") 104 | 105 | # Add `pydantic_model` to `seen_models` with a placeholder value to avoid infinite recursion. 106 | seen_models.add(pydantic_model) 107 | 108 | lines = [] 109 | current_indent = " " * indent 110 | next_indent = " " * (indent + 1) 111 | 112 | lines.append(f"{current_indent}{{") 113 | 114 | fields = pydantic_model.model_fields 115 | if not fields: 116 | lines.append(f"{next_indent}{COMMENT_SYMBOL} No fields defined") 117 | for name, field in fields.items(): 118 | if field.description: 119 | lines.append(f"{next_indent}{COMMENT_SYMBOL} {field.description}") 120 | elif field.alias and field.alias != name: 121 | # If there's an alias but no description, show the alias as a comment 122 | lines.append(f"{next_indent}{COMMENT_SYMBOL} alias: {field.alias}") 123 | 124 | rendered_type = _render_type_str(field.annotation, indent=indent + 1, seen_models=seen_models) 125 | line = f"{next_indent}{name}: {rendered_type}," 126 | 127 | lines.append(line) 128 | 129 | lines.append(f"{current_indent}}}") 130 | return "\n".join(lines) 131 | 132 | 133 | class BAMLAdapter(JSONAdapter): 134 | """ 135 | A DSPy adapter that improves the rendering of complex/nested Pydantic models to help LMs. 136 | 137 | This adapter generates a compact, human-readable schema representation for nested Pydantic output 138 | fields, inspired by the BAML project's JSON formatter (https://github.com/BoundaryML/baml). 139 | The resulting rendered schema is more token-efficient and easier for smaller LMs to follow than a 140 | raw JSON schema. It also includes Pydantic field descriptions as comments in the schema, which 141 | provide valuable additional context for the LM to understand the expected output. 142 | 143 | Example Usage: 144 | ```python 145 | import dspy 146 | from pydantic import BaseModel, Field 147 | from typing import Literal 148 | from baml_adapter import BAMLAdapter # Import from your module 149 | 150 | # 1. 
Define your Pydantic models
151 |     class PatientAddress(BaseModel):
152 |         street: str
153 |         city: str
154 |         country: Literal["US", "CA"]
155 | 
156 |     class PatientDetails(BaseModel):
157 |         name: str = Field(description="Full name of the patient.")
158 |         age: int
159 |         address: PatientAddress | None
160 | 
161 |     # 2. Define a signature using the Pydantic model as an output field
162 |     class ExtractPatientInfo(dspy.Signature):
163 |         '''Extract patient information from the clinical note.'''
164 |         clinical_note: str = dspy.InputField()
165 |         patient_info: PatientDetails = dspy.OutputField()
166 | 
167 |     # 3. Configure dspy to use the new adapter
168 |     lm = dspy.LM("openai/gpt-4.1-mini")
169 |     dspy.configure(lm=lm, adapter=BAMLAdapter())
170 | 
171 |     # 4. Run your program
172 |     extractor = dspy.Predict(ExtractPatientInfo)
173 |     note = "John Doe, 45 years old, lives at 123 Main St, Anytown. Resident of the US."
174 |     result = extractor(clinical_note=note)
175 |     print(result.patient_info)
176 | 
177 |     # Expected output:
178 |     # PatientDetails(name='John Doe', age=45, address=PatientAddress(street='123 Main St', city='Anytown', country='US'))
179 |     ```
180 |     """
181 | 
182 |     def format_field_description(self, signature: type[Signature]) -> str:
183 |         """Format the field description for the system message."""
184 |         sections = []
185 | 
186 |         # Add input field descriptions
187 |         if signature.input_fields:
188 |             sections.append("Your input fields are:")
189 |             for i, (name, field) in enumerate(signature.input_fields.items(), 1):
190 |                 type_name = getattr(field.annotation, "__name__", str(field.annotation))
191 |                 description = f": {field.description}" if field.description else ":"
192 |                 sections.append(f"{i}. `{name}` ({type_name}){description}")
193 | 
194 |         # Add output field descriptions
195 |         if signature.output_fields:
196 |             sections.append("Your output fields are:")
197 |             for i, (name, field) in enumerate(signature.output_fields.items(), 1):
198 |                 type_name = getattr(field.annotation, "__name__", str(field.annotation))
199 |                 description = f": {field.description}" if field.description else ":"
200 |                 sections.append(f"{i}.
`{name}` ({type_name}){description}") 201 | 202 | return "\n".join(sections) 203 | 204 | def format_field_structure(self, signature: type[Signature]) -> str: 205 | """Overrides the base method to generate a simplified schema for Pydantic models.""" 206 | 207 | sections = [] 208 | 209 | # Add structural explanation 210 | sections.append( 211 | "All interactions will be structured in the following way, with the appropriate values filled in.\n" 212 | ) 213 | 214 | # Add input structure section 215 | if signature.input_fields: 216 | for name in signature.input_fields.keys(): 217 | sections.append(f"[[ ## {name} ## ]]") 218 | sections.append(f"{{{name}}}") 219 | sections.append("") # Empty line after each input 220 | 221 | # Add output structure section 222 | if signature.output_fields: 223 | for name, field in signature.output_fields.items(): 224 | field_type = field.annotation 225 | sections.append(f"[[ ## {name} ## ]]") 226 | sections.append(f"Output field `{name}` should be of type: {_render_type_str(field_type, indent=0)}\n") 227 | 228 | # Add completed section 229 | sections.append("[[ ## completed ## ]]") 230 | 231 | return "\n".join(sections) 232 | 233 | def format_user_message_content( 234 | self, 235 | signature: type[Signature], 236 | inputs: dict[str, Any], 237 | prefix: str = "", 238 | suffix: str = "", 239 | main_request: bool = False, 240 | ) -> str: 241 | """Overrides the base method to render Pydantic input instances as clean JSON.""" 242 | messages = [prefix] 243 | for key, field_info in signature.input_fields.items(): 244 | if key in inputs: 245 | value = inputs.get(key) 246 | formatted_value = "" 247 | if isinstance(value, BaseModel): 248 | # Use clean, indented JSON for Pydantic instances 249 | formatted_value = value.model_dump_json(indent=2, by_alias=True) 250 | else: 251 | # Fallback to the original dspy formatter for other types 252 | formatted_value = original_format_field_value(field_info=field_info, value=value) 253 | 254 | messages.append(f"[[ ## {key} ## ]]\n{formatted_value}") 255 | 256 | if main_request: 257 | output_requirements = self.user_message_output_requirements(signature) 258 | if output_requirements is not None: 259 | messages.append(output_requirements) 260 | 261 | messages.append(suffix) 262 | return "\n\n".join(m for m in messages if m).strip() 263 | ``` -------------------------------------------------------------------------------- /docs/docs/tutorials/llms_txt_generation/index.md: -------------------------------------------------------------------------------- ```markdown 1 | # Generating llms.txt for Code Documentation with DSPy 2 | 3 | This tutorial demonstrates how to use DSPy to automatically generate an `llms.txt` file for the DSPy repository itself. The `llms.txt` standard provides LLM-friendly documentation that helps AI systems better understand codebases. 4 | 5 | ## What is llms.txt? 6 | 7 | `llms.txt` is a proposed standard for providing structured, LLM-friendly documentation about a project. It typically includes: 8 | 9 | - Project overview and purpose 10 | - Key concepts and terminology 11 | - Architecture and structure 12 | - Usage examples 13 | - Important files and directories 14 | 15 | ## Building a DSPy Program for llms.txt Generation 16 | 17 | Let's create a DSPy program that analyzes a repository and generates comprehensive `llms.txt` documentation. 
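
As a reference point for what we are aiming to produce, here is a minimal skeleton in the spirit of the proposed standard: an H1 title, a blockquote summary, and H2 sections of links. This is illustrative only; the section names and URLs below are placeholders rather than DSPy specifics:

```
# Project Name

> One-paragraph summary of what the project does and who it is for.

## Docs

- [Getting started](https://example.com/start): Installation and first steps

## Examples

- [Tutorials](https://example.com/tutorials): End-to-end usage examples
```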
18 | 19 | ### Step 1: Define Our Signatures 20 | 21 | First, we'll define signatures for different aspects of documentation generation: 22 | 23 | ```python 24 | import dspy 25 | from typing import List 26 | 27 | class AnalyzeRepository(dspy.Signature): 28 | """Analyze a repository structure and identify key components.""" 29 | repo_url: str = dspy.InputField(desc="GitHub repository URL") 30 | file_tree: str = dspy.InputField(desc="Repository file structure") 31 | readme_content: str = dspy.InputField(desc="README.md content") 32 | 33 | project_purpose: str = dspy.OutputField(desc="Main purpose and goals of the project") 34 | key_concepts: list[str] = dspy.OutputField(desc="List of important concepts and terminology") 35 | architecture_overview: str = dspy.OutputField(desc="High-level architecture description") 36 | 37 | class AnalyzeCodeStructure(dspy.Signature): 38 | """Analyze code structure to identify important directories and files.""" 39 | file_tree: str = dspy.InputField(desc="Repository file structure") 40 | package_files: str = dspy.InputField(desc="Key package and configuration files") 41 | 42 | important_directories: list[str] = dspy.OutputField(desc="Key directories and their purposes") 43 | entry_points: list[str] = dspy.OutputField(desc="Main entry points and important files") 44 | development_info: str = dspy.OutputField(desc="Development setup and workflow information") 45 | 46 | class GenerateLLMsTxt(dspy.Signature): 47 | """Generate a comprehensive llms.txt file from analyzed repository information.""" 48 | project_purpose: str = dspy.InputField() 49 | key_concepts: list[str] = dspy.InputField() 50 | architecture_overview: str = dspy.InputField() 51 | important_directories: list[str] = dspy.InputField() 52 | entry_points: list[str] = dspy.InputField() 53 | development_info: str = dspy.InputField() 54 | usage_examples: str = dspy.InputField(desc="Common usage patterns and examples") 55 | 56 | llms_txt_content: str = dspy.OutputField(desc="Complete llms.txt file content following the standard format") 57 | ``` 58 | 59 | ### Step 2: Create the Repository Analyzer Module 60 | 61 | ```python 62 | class RepositoryAnalyzer(dspy.Module): 63 | def __init__(self): 64 | super().__init__() 65 | self.analyze_repo = dspy.ChainOfThought(AnalyzeRepository) 66 | self.analyze_structure = dspy.ChainOfThought(AnalyzeCodeStructure) 67 | self.generate_examples = dspy.ChainOfThought("repo_info -> usage_examples") 68 | self.generate_llms_txt = dspy.ChainOfThought(GenerateLLMsTxt) 69 | 70 | def forward(self, repo_url, file_tree, readme_content, package_files): 71 | # Analyze repository purpose and concepts 72 | repo_analysis = self.analyze_repo( 73 | repo_url=repo_url, 74 | file_tree=file_tree, 75 | readme_content=readme_content 76 | ) 77 | 78 | # Analyze code structure 79 | structure_analysis = self.analyze_structure( 80 | file_tree=file_tree, 81 | package_files=package_files 82 | ) 83 | 84 | # Generate usage examples 85 | usage_examples = self.generate_examples( 86 | repo_info=f"Purpose: {repo_analysis.project_purpose}\nConcepts: {repo_analysis.key_concepts}" 87 | ) 88 | 89 | # Generate final llms.txt 90 | llms_txt = self.generate_llms_txt( 91 | project_purpose=repo_analysis.project_purpose, 92 | key_concepts=repo_analysis.key_concepts, 93 | architecture_overview=repo_analysis.architecture_overview, 94 | important_directories=structure_analysis.important_directories, 95 | entry_points=structure_analysis.entry_points, 96 | development_info=structure_analysis.development_info, 97 | 
usage_examples=usage_examples.usage_examples 98 | ) 99 | 100 | return dspy.Prediction( 101 | llms_txt_content=llms_txt.llms_txt_content, 102 | analysis=repo_analysis, 103 | structure=structure_analysis 104 | ) 105 | ``` 106 | 107 | ### Step 3: Gather Repository Information 108 | 109 | Let's create helper functions to extract repository information: 110 | 111 | ```python 112 | import requests 113 | import os 114 | from pathlib import Path 115 | 116 | os.environ["GITHUB_ACCESS_TOKEN"] = "<your_access_token>" 117 | 118 | def get_github_file_tree(repo_url): 119 | """Get repository file structure from GitHub API.""" 120 | # Extract owner/repo from URL 121 | parts = repo_url.rstrip('/').split('/') 122 | owner, repo = parts[-2], parts[-1] 123 | 124 | api_url = f"https://api.github.com/repos/{owner}/{repo}/git/trees/main?recursive=1" 125 | response = requests.get(api_url, headers={ 126 | "Authorization": f"Bearer {os.environ.get('GITHUB_ACCESS_TOKEN')}" 127 | }) 128 | 129 | if response.status_code == 200: 130 | tree_data = response.json() 131 | file_paths = [item['path'] for item in tree_data['tree'] if item['type'] == 'blob'] 132 | return '\n'.join(sorted(file_paths)) 133 | else: 134 | raise Exception(f"Failed to fetch repository tree: {response.status_code}") 135 | 136 | def get_github_file_content(repo_url, file_path): 137 | """Get specific file content from GitHub.""" 138 | parts = repo_url.rstrip('/').split('/') 139 | owner, repo = parts[-2], parts[-1] 140 | 141 | api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{file_path}" 142 | response = requests.get(api_url, headers={ 143 | "Authorization": f"Bearer {os.environ.get('GITHUB_ACCESS_TOKEN')}" 144 | }) 145 | 146 | if response.status_code == 200: 147 | import base64 148 | content = base64.b64decode(response.json()['content']).decode('utf-8') 149 | return content 150 | else: 151 | return f"Could not fetch {file_path}" 152 | 153 | def gather_repository_info(repo_url): 154 | """Gather all necessary repository information.""" 155 | file_tree = get_github_file_tree(repo_url) 156 | readme_content = get_github_file_content(repo_url, "README.md") 157 | 158 | # Get key package files 159 | package_files = [] 160 | for file_path in ["pyproject.toml", "setup.py", "requirements.txt", "package.json"]: 161 | try: 162 | content = get_github_file_content(repo_url, file_path) 163 | if "Could not fetch" not in content: 164 | package_files.append(f"=== {file_path} ===\n{content}") 165 | except: 166 | continue 167 | 168 | package_files_content = "\n\n".join(package_files) 169 | 170 | return file_tree, readme_content, package_files_content 171 | ``` 172 | 173 | ### Step 4: Configure DSPy and Generate llms.txt 174 | 175 | ```python 176 | def generate_llms_txt_for_dspy(): 177 | # Configure DSPy (use your preferred LM) 178 | lm = dspy.LM(model="gpt-4o-mini") 179 | dspy.configure(lm=lm) 180 | os.environ["OPENAI_API_KEY"] = "<YOUR OPENAI KEY>" 181 | 182 | # Initialize our analyzer 183 | analyzer = RepositoryAnalyzer() 184 | 185 | # Gather DSPy repository information 186 | repo_url = "https://github.com/stanfordnlp/dspy" 187 | file_tree, readme_content, package_files = gather_repository_info(repo_url) 188 | 189 | # Generate llms.txt 190 | result = analyzer( 191 | repo_url=repo_url, 192 | file_tree=file_tree, 193 | readme_content=readme_content, 194 | package_files=package_files 195 | ) 196 | 197 | return result 198 | 199 | # Run the generation 200 | if __name__ == "__main__": 201 | result = generate_llms_txt_for_dspy() 202 | 203 | # Save the generated 
llms.txt 204 | with open("llms.txt", "w") as f: 205 | f.write(result.llms_txt_content) 206 | 207 | print("Generated llms.txt file!") 208 | print("\nPreview:") 209 | print(result.llms_txt_content[:500] + "...") 210 | ``` 211 | 212 | ## Expected Output Structure 213 | 214 | The generated `llms.txt` for DSPy would follow this structure: 215 | 216 | ``` 217 | # DSPy: Programming Language Models 218 | 219 | ## Project Overview 220 | DSPy is a framework for programming—rather than prompting—language models... 221 | 222 | ## Key Concepts 223 | - **Modules**: Building blocks for LM programs 224 | - **Signatures**: Input/output specifications 225 | - **Teleprompters**: Optimization algorithms 226 | - **Predictors**: Core reasoning components 227 | 228 | ## Architecture 229 | - `/dspy/`: Main package directory 230 | - `/adapters/`: Input/output format handlers 231 | - `/clients/`: LM client interfaces 232 | - `/predict/`: Core prediction modules 233 | - `/teleprompt/`: Optimization algorithms 234 | 235 | ## Usage Examples 236 | 1. **Building a Classifier**: Using DSPy, a user can define a modular classifier that takes in text data and categorizes it into predefined classes. The user can specify the classification logic declaratively, allowing for easy adjustments and optimizations. 237 | 2. **Creating a RAG Pipeline**: A developer can implement a retrieval-augmented generation pipeline that first retrieves relevant documents based on a query and then generates a coherent response using those documents. DSPy facilitates the integration of retrieval and generation components seamlessly. 238 | 3. **Optimizing Prompts**: Users can leverage DSPy to create a system that automatically optimizes prompts for language models based on performance metrics, improving the quality of responses over time without manual intervention. 239 | 4. **Implementing Agent Loops**: A user can design an agent loop that continuously interacts with users, learns from feedback, and refines its responses, showcasing the self-improving capabilities of the DSPy framework. 240 | 5. **Compositional Code**: Developers can write compositional code that allows different modules of the AI system to interact with each other, enabling complex workflows that can be easily modified and extended. 241 | ``` 242 | 243 | The resulting `llms.txt` file provides a comprehensive, LLM-friendly overview of the DSPy repository that can help other AI systems better understand and work with the codebase. 244 | 245 | ## Next Steps 246 | 247 | - Extend the program to analyze multiple repositories 248 | - Add support for different documentation formats 249 | - Create metrics for documentation quality assessment 250 | - Build a web interface for interactive repository analysis 251 | ``` -------------------------------------------------------------------------------- /docs/docs/learn/programming/language_models.md: -------------------------------------------------------------------------------- ```markdown 1 | --- 2 | sidebar_position: 2 3 | --- 4 | 5 | # Language Models 6 | 7 | The first step in any DSPy code is to set up your language model. For example, you can configure OpenAI's GPT-4o-mini as your default LM as follows. 8 | 9 | ```python linenums="1" 10 | # Authenticate via `OPENAI_API_KEY` env: import os; os.environ['OPENAI_API_KEY'] = 'here' 11 | lm = dspy.LM('openai/gpt-4o-mini') 12 | dspy.configure(lm=lm) 13 | ``` 14 | 15 | !!! 
info "A few different LMs" 16 | 17 | === "OpenAI" 18 | You can authenticate by setting the `OPENAI_API_KEY` env variable or passing `api_key` below. 19 | 20 | ```python linenums="1" 21 | import dspy 22 | lm = dspy.LM('openai/gpt-4o-mini', api_key='YOUR_OPENAI_API_KEY') 23 | dspy.configure(lm=lm) 24 | ``` 25 | 26 | === "Gemini (AI Studio)" 27 | You can authenticate by setting the GEMINI_API_KEY env variable or passing `api_key` below. 28 | 29 | ```python linenums="1" 30 | import dspy 31 | lm = dspy.LM('gemini/gemini-2.5-pro-preview-03-25', api_key='GEMINI_API_KEY') 32 | dspy.configure(lm=lm) 33 | ``` 34 | 35 | === "Anthropic" 36 | You can authenticate by setting the ANTHROPIC_API_KEY env variable or passing `api_key` below. 37 | 38 | ```python linenums="1" 39 | import dspy 40 | lm = dspy.LM('anthropic/claude-3-opus-20240229', api_key='YOUR_ANTHROPIC_API_KEY') 41 | dspy.configure(lm=lm) 42 | ``` 43 | 44 | === "Databricks" 45 | If you're on the Databricks platform, authentication is automatic via their SDK. If not, you can set the env variables `DATABRICKS_API_KEY` and `DATABRICKS_API_BASE`, or pass `api_key` and `api_base` below. 46 | 47 | ```python linenums="1" 48 | import dspy 49 | lm = dspy.LM('databricks/databricks-meta-llama-3-1-70b-instruct') 50 | dspy.configure(lm=lm) 51 | ``` 52 | 53 | === "Local LMs on a GPU server" 54 | First, install [SGLang](https://sgl-project.github.io/start/install.html) and launch its server with your LM. 55 | 56 | ```bash 57 | > pip install "sglang[all]" 58 | > pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ 59 | 60 | > CUDA_VISIBLE_DEVICES=0 python -m sglang.launch_server --port 7501 --model-path meta-llama/Meta-Llama-3-8B-Instruct 61 | ``` 62 | 63 | Then, connect to it from your DSPy code as an OpenAI-compatible endpoint. 64 | 65 | ```python linenums="1" 66 | lm = dspy.LM("openai/meta-llama/Meta-Llama-3-8B-Instruct", 67 | api_base="http://localhost:7501/v1", # ensure this points to your port 68 | api_key="", model_type='chat') 69 | dspy.configure(lm=lm) 70 | ``` 71 | 72 | === "Local LMs on your laptop" 73 | First, install [Ollama](https://github.com/ollama/ollama) and launch its server with your LM. 74 | 75 | ```bash 76 | > curl -fsSL https://ollama.ai/install.sh | sh 77 | > ollama run llama3.2:1b 78 | ``` 79 | 80 | Then, connect to it from your DSPy code. 81 | 82 | ```python linenums="1" 83 | import dspy 84 | lm = dspy.LM('ollama_chat/llama3.2', api_base='http://localhost:11434', api_key='') 85 | dspy.configure(lm=lm) 86 | ``` 87 | 88 | === "Other providers" 89 | In DSPy, you can use any of the dozens of [LLM providers supported by LiteLLM](https://docs.litellm.ai/docs/providers). Simply follow their instructions for which `{PROVIDER}_API_KEY` to set and how to write pass the `{provider_name}/{model_name}` to the constructor. 90 | 91 | Some examples: 92 | 93 | - `anyscale/mistralai/Mistral-7B-Instruct-v0.1`, with `ANYSCALE_API_KEY` 94 | - `together_ai/togethercomputer/llama-2-70b-chat`, with `TOGETHERAI_API_KEY` 95 | - `sagemaker/<your-endpoint-name>`, with `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and `AWS_REGION_NAME` 96 | - `azure/<your_deployment_name>`, with `AZURE_API_KEY`, `AZURE_API_BASE`, `AZURE_API_VERSION`, and the optional `AZURE_AD_TOKEN` and `AZURE_API_TYPE` as environment variables. 
If you are initiating external models without setting environment variables, use the following: 97 | `lm = dspy.LM('azure/<your_deployment_name>', api_key = 'AZURE_API_KEY' , api_base = 'AZURE_API_BASE', api_version = 'AZURE_API_VERSION')` 98 | 99 | 100 | 101 | If your provider offers an OpenAI-compatible endpoint, just add an `openai/` prefix to your full model name. 102 | 103 | ```python linenums="1" 104 | import dspy 105 | lm = dspy.LM('openai/your-model-name', api_key='PROVIDER_API_KEY', api_base='YOUR_PROVIDER_URL') 106 | dspy.configure(lm=lm) 107 | ``` 108 | If you run into errors, please refer to the [LiteLLM Docs](https://docs.litellm.ai/docs/providers) to verify if you are using the same variable names/following the right procedure. 109 | 110 | ## Calling the LM directly. 111 | 112 | It's easy to call the `lm` you configured above directly. This gives you a unified API and lets you benefit from utilities like automatic caching. 113 | 114 | ```python linenums="1" 115 | lm("Say this is a test!", temperature=0.7) # => ['This is a test!'] 116 | lm(messages=[{"role": "user", "content": "Say this is a test!"}]) # => ['This is a test!'] 117 | ``` 118 | 119 | ## Using the LM with DSPy modules. 120 | 121 | Idiomatic DSPy involves using _modules_, which we discuss in the next guide. 122 | 123 | ```python linenums="1" 124 | # Define a module (ChainOfThought) and assign it a signature (return an answer, given a question). 125 | qa = dspy.ChainOfThought('question -> answer') 126 | 127 | # Run with the default LM configured with `dspy.configure` above. 128 | response = qa(question="How many floors are in the castle David Gregory inherited?") 129 | print(response.answer) 130 | ``` 131 | **Possible Output:** 132 | ```text 133 | The castle David Gregory inherited has 7 floors. 134 | ``` 135 | 136 | ## Using multiple LMs. 137 | 138 | You can change the default LM globally with `dspy.configure` or change it inside a block of code with `dspy.context`. 139 | 140 | !!! tip 141 | Using `dspy.configure` and `dspy.context` is thread-safe! 142 | 143 | 144 | ```python linenums="1" 145 | dspy.configure(lm=dspy.LM('openai/gpt-4o-mini')) 146 | response = qa(question="How many floors are in the castle David Gregory inherited?") 147 | print('GPT-4o-mini:', response.answer) 148 | 149 | with dspy.context(lm=dspy.LM('openai/gpt-3.5-turbo')): 150 | response = qa(question="How many floors are in the castle David Gregory inherited?") 151 | print('GPT-3.5-turbo:', response.answer) 152 | ``` 153 | **Possible Output:** 154 | ```text 155 | GPT-4o-mini: The number of floors in the castle David Gregory inherited cannot be determined with the information provided. 156 | GPT-3.5-turbo: The castle David Gregory inherited has 7 floors. 157 | ``` 158 | 159 | ## Configuring LM generation. 160 | 161 | For any LM, you can configure any of the following attributes at initialization or in each subsequent call. 162 | 163 | ```python linenums="1" 164 | gpt_4o_mini = dspy.LM('openai/gpt-4o-mini', temperature=0.9, max_tokens=3000, stop=None, cache=False) 165 | ``` 166 | 167 | By default LMs in DSPy are cached. If you repeat the same call, you will get the same outputs. But you can turn off caching by setting `cache=False`. 168 | 169 | If you want to keep caching enabled but force a new request (for example, to obtain diverse outputs), 170 | pass a unique `rollout_id` and set a non-zero `temperature` in your call. 
DSPy hashes both the inputs
171 | and the `rollout_id` when looking up a cache entry, so different values force a new LM request while
172 | still caching future calls with the same inputs and `rollout_id`. The ID is also recorded in
173 | `lm.history`, which makes it easy to track or compare different rollouts during experiments. Changing
174 | only the `rollout_id` while keeping `temperature=0` will not affect the LM's output.
175 | 
176 | ```python linenums="1"
177 | lm("Say this is a test!", rollout_id=1, temperature=1.0)
178 | ```
179 | 
180 | You can pass these LM kwargs directly to DSPy modules as well. Supplying them at
181 | initialization sets the defaults for every call:
182 | 
183 | ```python linenums="1"
184 | predict = dspy.Predict("question -> answer", rollout_id=1, temperature=1.0)
185 | ```
186 | 
187 | To override them for a single invocation, provide a `config` dictionary when
188 | calling the module:
189 | 
190 | ```python linenums="1"
191 | predict = dspy.Predict("question -> answer")
192 | predict(question="What is 1 + 52?", config={"rollout_id": 5, "temperature": 1.0})
193 | ```
194 | 
195 | In both cases, `rollout_id` is forwarded to the underlying LM, affects
196 | its caching behavior, and is stored alongside each response so you can
197 | replay or analyze specific rollouts later.
198 | 
199 | 
200 | ## Inspecting output and usage metadata.
201 | 
202 | Every LM object maintains the history of its interactions, including inputs, outputs, token usage (and $$$ cost), and metadata.
203 | 
204 | ```python linenums="1"
205 | len(lm.history)  # e.g., 3 calls to the LM
206 | 
207 | lm.history[-1].keys()  # access the last call to the LM, with all metadata
208 | ```
209 | 
210 | **Output:**
211 | ```text
212 | dict_keys(['prompt', 'messages', 'kwargs', 'response', 'outputs', 'usage', 'cost', 'timestamp', 'uuid', 'model', 'response_model', 'model_type'])
213 | ```
214 | 
215 | ## Using the Responses API
216 | 
217 | By default, DSPy calls language models (LMs) using LiteLLM's [Chat Completions API](https://docs.litellm.ai/docs/completion), which is suitable for most standard models and tasks. However, some advanced models, such as OpenAI's reasoning models (e.g., `gpt-5` or other future models), may offer improved quality or additional features when accessed via the [Responses API](https://docs.litellm.ai/docs/response_api), which is supported in DSPy.
218 | 
219 | **When should you use the Responses API?**
220 | 
221 | - If you are working with models that support or require the `responses` endpoint (such as OpenAI's reasoning models).
222 | - When you want to leverage enhanced reasoning, multi-turn, or richer output capabilities provided by certain models.
223 | 
224 | **How to enable the Responses API in DSPy:**
225 | 
226 | To enable the Responses API, just set `model_type="responses"` when creating the `dspy.LM` instance.
227 | 
228 | ```python
229 | import dspy
230 | 
231 | # Configure DSPy to use the Responses API for your language model
232 | dspy.settings.configure(
233 |     lm=dspy.LM(
234 |         "openai/gpt-5-mini",
235 |         model_type="responses",
236 |         temperature=1.0,
237 |         max_tokens=16000,
238 |     ),
239 | )
240 | ```
241 | 
242 | Please note that not all models or providers support the Responses API; check [LiteLLM's documentation](https://docs.litellm.ai/docs/response_api) for more details.
243 | 
244 | 
245 | ## Advanced: Building custom LMs and writing your own Adapters.
246 | 
247 | Though rarely needed, you can write custom LMs by inheriting from `dspy.BaseLM`.
Another advanced layer in the DSPy ecosystem is that of _adapters_, which sit between DSPy signatures and LMs. A future version of this guide will discuss these advanced features, though you likely don't need them. 248 | 249 | ``` -------------------------------------------------------------------------------- /dspy/adapters/chat_adapter.py: -------------------------------------------------------------------------------- ```python 1 | import re 2 | import textwrap 3 | from typing import Any, NamedTuple 4 | 5 | from litellm import ContextWindowExceededError 6 | from pydantic.fields import FieldInfo 7 | 8 | from dspy.adapters.base import Adapter 9 | from dspy.adapters.utils import ( 10 | format_field_value, 11 | get_annotation_name, 12 | get_field_description_string, 13 | parse_value, 14 | translate_field_type, 15 | ) 16 | from dspy.clients.lm import LM 17 | from dspy.signatures.signature import Signature 18 | from dspy.utils.exceptions import AdapterParseError 19 | 20 | field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]") 21 | 22 | 23 | class FieldInfoWithName(NamedTuple): 24 | name: str 25 | info: FieldInfo 26 | 27 | 28 | class ChatAdapter(Adapter): 29 | def __call__( 30 | self, 31 | lm: LM, 32 | lm_kwargs: dict[str, Any], 33 | signature: type[Signature], 34 | demos: list[dict[str, Any]], 35 | inputs: dict[str, Any], 36 | ) -> list[dict[str, Any]]: 37 | try: 38 | return super().__call__(lm, lm_kwargs, signature, demos, inputs) 39 | except Exception as e: 40 | # fallback to JSONAdapter 41 | from dspy.adapters.json_adapter import JSONAdapter 42 | 43 | if isinstance(e, ContextWindowExceededError) or isinstance(self, JSONAdapter): 44 | # On context window exceeded error or already using JSONAdapter, we don't want to retry with a different 45 | # adapter. 46 | raise e 47 | return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs) 48 | 49 | async def acall( 50 | self, 51 | lm: LM, 52 | lm_kwargs: dict[str, Any], 53 | signature: type[Signature], 54 | demos: list[dict[str, Any]], 55 | inputs: dict[str, Any], 56 | ) -> list[dict[str, Any]]: 57 | try: 58 | return await super().acall(lm, lm_kwargs, signature, demos, inputs) 59 | except Exception as e: 60 | # fallback to JSONAdapter 61 | from dspy.adapters.json_adapter import JSONAdapter 62 | 63 | if isinstance(e, ContextWindowExceededError) or isinstance(self, JSONAdapter): 64 | # On context window exceeded error or already using JSONAdapter, we don't want to retry with a different 65 | # adapter. 66 | raise e 67 | return await JSONAdapter().acall(lm, lm_kwargs, signature, demos, inputs) 68 | 69 | def format_field_description(self, signature: type[Signature]) -> str: 70 | return ( 71 | f"Your input fields are:\n{get_field_description_string(signature.input_fields)}\n" 72 | f"Your output fields are:\n{get_field_description_string(signature.output_fields)}" 73 | ) 74 | 75 | def format_field_structure(self, signature: type[Signature]) -> str: 76 | """ 77 | `ChatAdapter` requires input and output fields to be in their own sections, with section header using markers 78 | `[[ ## field_name ## ]]`. An arbitrary field `completed` ([[ ## completed ## ]]) is added to the end of the 79 | output fields section to indicate the end of the output fields. 
80 | """ 81 | parts = [] 82 | parts.append("All interactions will be structured in the following way, with the appropriate values filled in.") 83 | 84 | def format_signature_fields_for_instructions(fields: dict[str, FieldInfo]): 85 | return self.format_field_with_value( 86 | fields_with_values={ 87 | FieldInfoWithName(name=field_name, info=field_info): translate_field_type(field_name, field_info) 88 | for field_name, field_info in fields.items() 89 | }, 90 | ) 91 | 92 | parts.append(format_signature_fields_for_instructions(signature.input_fields)) 93 | parts.append(format_signature_fields_for_instructions(signature.output_fields)) 94 | parts.append("[[ ## completed ## ]]\n") 95 | return "\n\n".join(parts).strip() 96 | 97 | def format_task_description(self, signature: type[Signature]) -> str: 98 | instructions = textwrap.dedent(signature.instructions) 99 | objective = ("\n" + " " * 8).join([""] + instructions.splitlines()) 100 | return f"In adhering to this structure, your objective is: {objective}" 101 | 102 | def format_user_message_content( 103 | self, 104 | signature: type[Signature], 105 | inputs: dict[str, Any], 106 | prefix: str = "", 107 | suffix: str = "", 108 | main_request: bool = False, 109 | ) -> str: 110 | messages = [prefix] 111 | for k, v in signature.input_fields.items(): 112 | if k in inputs: 113 | value = inputs.get(k) 114 | formatted_field_value = format_field_value(field_info=v, value=value) 115 | messages.append(f"[[ ## {k} ## ]]\n{formatted_field_value}") 116 | 117 | if main_request: 118 | output_requirements = self.user_message_output_requirements(signature) 119 | if output_requirements is not None: 120 | messages.append(output_requirements) 121 | 122 | messages.append(suffix) 123 | return "\n\n".join(messages).strip() 124 | 125 | def user_message_output_requirements(self, signature: type[Signature]) -> str: 126 | """Returns a simplified format reminder for the language model. 127 | 128 | In chat-based interactions, language models may lose track of the required output format 129 | as the conversation context grows longer. This method generates a concise reminder of 130 | the expected output structure that can be included in user messages. 131 | 132 | Args: 133 | signature (Type[Signature]): The DSPy signature defining the expected input/output fields. 134 | 135 | Returns: 136 | str: A simplified description of the required output format. 137 | 138 | Note: 139 | This is a more lightweight version of `format_field_structure` specifically designed 140 | for inline reminders within chat messages. 141 | """ 142 | 143 | def type_info(v): 144 | if v.annotation is not str: 145 | return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" 146 | else: 147 | return "" 148 | 149 | message = "Respond with the corresponding output fields, starting with the field " 150 | message += ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items()) 151 | message += ", and then ending with the marker for `[[ ## completed ## ]]`." 
152 |         return message
153 | 
154 |     def format_assistant_message_content(
155 |         self,
156 |         signature: type[Signature],
157 |         outputs: dict[str, Any],
158 |         missing_field_message=None,
159 |     ) -> str:
160 |         assistant_message_content = self.format_field_with_value(
161 |             {
162 |                 FieldInfoWithName(name=k, info=v): outputs.get(k, missing_field_message)
163 |                 for k, v in signature.output_fields.items()
164 |             },
165 |         )
166 |         assistant_message_content += "\n\n[[ ## completed ## ]]\n"
167 |         return assistant_message_content
168 | 
169 |     def parse(self, signature: type[Signature], completion: str) -> dict[str, Any]:
170 |         sections = [(None, [])]
171 | 
172 |         for line in completion.splitlines():
173 |             match = field_header_pattern.match(line.strip())
174 |             if match:
175 |                 # If the header pattern is found, split the rest of the line as content
176 |                 header = match.group(1)
177 |                 remaining_content = line[match.end() :].strip()
178 |                 sections.append((header, [remaining_content] if remaining_content else []))
179 |             else:
180 |                 sections[-1][1].append(line)
181 | 
182 |         sections = [(k, "\n".join(v).strip()) for k, v in sections]
183 | 
184 |         fields = {}
185 |         for k, v in sections:
186 |             if (k not in fields) and (k in signature.output_fields):
187 |                 try:
188 |                     fields[k] = parse_value(v, signature.output_fields[k].annotation)
189 |                 except Exception as e:
190 |                     raise AdapterParseError(
191 |                         adapter_name="ChatAdapter",
192 |                         signature=signature,
193 |                         lm_response=completion,
194 |                         message=f"Failed to parse field {k} with value {v} from the LM response. Error message: {e}",
195 |                     )
196 |         if fields.keys() != signature.output_fields.keys():
197 |             raise AdapterParseError(
198 |                 adapter_name="ChatAdapter",
199 |                 signature=signature,
200 |                 lm_response=completion,
201 |                 parsed_result=fields,
202 |             )
203 | 
204 |         return fields
205 | 
206 |     def format_field_with_value(self, fields_with_values: dict[FieldInfoWithName, Any]) -> str:
207 |         """
208 |         Formats the values of the specified fields according to the field's DSPy type (input or output),
209 |         annotation (e.g. str, int, etc.), and the type of the value itself. Joins the formatted values
210 |         into a single string, which is a multiline string if there are multiple fields.
211 | 
212 |         Args:
213 |             fields_with_values: A dictionary mapping information about a field to its corresponding
214 |                 value.
215 | 
216 |         Returns:
217 |             The joined formatted values of the fields, represented as a string
218 |         """
219 |         output = []
220 |         for field, field_value in fields_with_values.items():
221 |             formatted_field_value = format_field_value(field_info=field.info, value=field_value)
222 |             output.append(f"[[ ## {field.name} ## ]]\n{formatted_field_value}")
223 | 
224 |         return "\n\n".join(output).strip()
225 | 
226 |     def format_finetune_data(
227 |         self,
228 |         signature: type[Signature],
229 |         demos: list[dict[str, Any]],
230 |         inputs: dict[str, Any],
231 |         outputs: dict[str, Any],
232 |     ) -> dict[str, list[Any]]:
233 |         """
234 |         Format the call data into finetuning data according to the OpenAI API specifications.
235 | 
236 |         For the chat adapter, this means formatting the data as a list of messages, where each message is a dictionary
237 |         with a "role" and "content" key. The role can be "system", "user", or "assistant". Then, the messages are
238 |         wrapped in a dictionary with a "messages" key.
239 | """ 240 | system_user_messages = self.format( # returns a list of dicts with the keys "role" and "content" 241 | signature=signature, demos=demos, inputs=inputs 242 | ) 243 | assistant_message_content = self.format_assistant_message_content( # returns a string, without the role 244 | signature=signature, outputs=outputs 245 | ) 246 | assistant_message = {"role": "assistant", "content": assistant_message_content} 247 | messages = system_user_messages + [assistant_message] 248 | return {"messages": messages} 249 | ``` -------------------------------------------------------------------------------- /docs/docs/tutorials/mem0_react_agent/index.md: -------------------------------------------------------------------------------- ```markdown 1 | # Building Memory-Enabled Agents with DSPy ReAct and Mem0 2 | 3 | This tutorial demonstrates how to build intelligent conversational agents that can remember information across interactions using DSPy's ReAct framework combined with [Mem0](https://docs.mem0.ai/)'s memory capabilities. You'll learn to create agents that can store, retrieve, and use contextual information to provide personalized and coherent responses. 4 | 5 | ## What You'll Build 6 | 7 | By the end of this tutorial, you'll have a memory-enabled agent that can: 8 | 9 | - **Remember user preferences** and past conversations 10 | - **Store and retrieve factual information** about users and topics 11 | - **Use memory to inform decisions** and provide personalized responses 12 | - **Handle complex multi-turn conversations** with context awareness 13 | - **Manage different types of memories** (facts, preferences, experiences) 14 | 15 | ## Prerequisites 16 | 17 | - Basic understanding of DSPy and ReAct agents 18 | - Python 3.9+ installed 19 | - API keys for your preferred LLM provider 20 | 21 | ## Installation and Setup 22 | 23 | ```bash 24 | pip install dspy mem0ai 25 | ``` 26 | 27 | ## Step 1: Understanding Mem0 Integration 28 | 29 | Mem0 provides a memory layer that can store, search, and retrieve memories for AI agents. 
Let's start by understanding how to integrate it with DSPy:
30 | 
31 | ```python
32 | import dspy
33 | from mem0 import Memory
34 | import os
35 | from typing import Optional
36 | from datetime import datetime
37 | 
38 | # Configure environment
39 | os.environ["OPENAI_API_KEY"] = "your-openai-api-key"
40 | 
41 | # Initialize Mem0 memory system
42 | config = {
43 |     "llm": {
44 |         "provider": "openai",
45 |         "config": {
46 |             "model": "gpt-4o-mini",
47 |             "temperature": 0.1
48 |         }
49 |     },
50 |     "embedder": {
51 |         "provider": "openai",
52 |         "config": {
53 |             "model": "text-embedding-3-small"
54 |         }
55 |     }
56 | }
57 | ```
58 | 
59 | ## Step 2: Create Memory-Aware Tools
60 | 
61 | Let's create tools that can interact with the memory system:
62 | 
63 | ```python
64 | from datetime import datetime
65 | 
66 | class MemoryTools:
67 |     """Tools for interacting with the Mem0 memory system."""
68 | 
69 |     def __init__(self, memory: Memory):
70 |         self.memory = memory
71 | 
72 |     def store_memory(self, content: str, user_id: str = "default_user") -> str:
73 |         """Store information in memory."""
74 |         try:
75 |             self.memory.add(content, user_id=user_id)
76 |             return f"Stored memory: {content}"
77 |         except Exception as e:
78 |             return f"Error storing memory: {str(e)}"
79 | 
80 |     def search_memories(self, query: str, user_id: str = "default_user", limit: int = 5) -> str:
81 |         """Search for relevant memories."""
82 |         try:
83 |             results = self.memory.search(query, user_id=user_id, limit=limit)
84 |             if not results or not results.get("results"):
85 |                 return "No relevant memories found."
86 | 
87 |             memory_text = "Relevant memories found:\n"
88 |             for i, result in enumerate(results["results"], 1):
89 |                 memory_text += f"{i}. {result['memory']}\n"
90 |             return memory_text
91 |         except Exception as e:
92 |             return f"Error searching memories: {str(e)}"
93 | 
94 |     def get_all_memories(self, user_id: str = "default_user") -> str:
95 |         """Get all memories for a user."""
96 |         try:
97 |             results = self.memory.get_all(user_id=user_id)
98 |             if not results or not results.get("results"):
99 |                 return "No memories found for this user."
100 | 
101 |             memory_text = "All memories for user:\n"
102 |             for i, result in enumerate(results["results"], 1):
103 |                 memory_text += f"{i}. {result['memory']}\n"
104 |             return memory_text
105 |         except Exception as e:
106 |             return f"Error retrieving memories: {str(e)}"
107 | 
108 |     def update_memory(self, memory_id: str, new_content: str) -> str:
109 |         """Update an existing memory."""
110 |         try:
111 |             self.memory.update(memory_id, new_content)
112 |             return f"Updated memory with new content: {new_content}"
113 |         except Exception as e:
114 |             return f"Error updating memory: {str(e)}"
115 | 
116 |     def delete_memory(self, memory_id: str) -> str:
117 |         """Delete a specific memory."""
118 |         try:
119 |             self.memory.delete(memory_id)
120 |             return "Memory deleted successfully."
121 |         except Exception as e:
122 |             return f"Error deleting memory: {str(e)}"
123 | 
124 | def get_current_time() -> str:
125 |     """Get the current date and time."""
126 |     return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
127 | ```
128 | 
129 | ## Step 3: Build the Memory-Enhanced ReAct Agent
130 | 
131 | Now let's create our main ReAct agent that can use memory:
132 | 
133 | ```python
134 | class MemoryQA(dspy.Signature):
135 |     """
136 |     You're a helpful assistant with access to memory tools.
137 |     Whenever you answer a user's input, remember to store the information in memory
138 |     so that you can use it later.
139 |     """
140 |     user_input: str = dspy.InputField()
141 |     response: str = dspy.OutputField()
142 | 
143 | class MemoryReActAgent(dspy.Module):
144 |     """A ReAct agent enhanced with Mem0 memory capabilities."""
145 | 
146 |     def __init__(self, memory: Memory):
147 |         super().__init__()
148 |         self.memory_tools = MemoryTools(memory)
149 | 
150 |         # Create tools list for ReAct
151 |         self.tools = [
152 |             self.memory_tools.store_memory,
153 |             self.memory_tools.search_memories,
154 |             self.memory_tools.get_all_memories,
155 |             get_current_time,
156 |             self.set_reminder,
157 |             self.get_preferences,
158 |             self.update_preferences,
159 |         ]
160 | 
161 |         # Initialize ReAct with our tools
162 |         self.react = dspy.ReAct(
163 |             signature=MemoryQA,
164 |             tools=self.tools,
165 |             max_iters=6
166 |         )
167 | 
168 |     def forward(self, user_input: str):
169 |         """Process user input with memory-aware reasoning."""
170 | 
171 |         return self.react(user_input=user_input)
172 | 
173 |     def set_reminder(self, reminder_text: str, date_time: Optional[str] = None, user_id: str = "default_user") -> str:
174 |         """Set a reminder for the user."""
175 |         reminder = f"Reminder set for {date_time or 'an unspecified time'}: {reminder_text}"
176 |         return self.memory_tools.store_memory(
177 |             f"REMINDER: {reminder}",
178 |             user_id=user_id
179 |         )
180 | 
181 |     def get_preferences(self, category: str = "general", user_id: str = "default_user") -> str:
182 |         """Get user preferences for a specific category."""
183 |         query = f"user preferences {category}"
184 |         return self.memory_tools.search_memories(
185 |             query=query,
186 |             user_id=user_id
187 |         )
188 | 
189 |     def update_preferences(self, category: str, preference: str, user_id: str = "default_user") -> str:
190 |         """Update user preferences."""
191 |         preference_text = f"User preference for {category}: {preference}"
192 |         return self.memory_tools.store_memory(
193 |             preference_text,
194 |             user_id=user_id
195 |         )
196 | ```
197 | 
198 | ## Step 4: Running the Memory-Enhanced Agent
199 | 
200 | Let's create a simple interface to interact with our memory-enabled agent:
201 | 
202 | ```python
203 | import time
204 | def run_memory_agent_demo():
205 |     """Demonstration of memory-enhanced ReAct agent."""
206 | 
207 |     # Configure DSPy
208 |     lm = dspy.LM(model='openai/gpt-4o-mini')
209 |     dspy.configure(lm=lm)
210 | 
211 |     # Initialize memory system
212 |     memory = Memory.from_config(config)
213 | 
214 |     # Create our agent
215 |     agent = MemoryReActAgent(memory)
216 | 
217 |     # Sample conversation demonstrating memory capabilities
218 |     print("🧠 Memory-Enhanced ReAct Agent Demo")
219 |     print("=" * 50)
220 | 
221 |     conversations = [
222 |         "Hi, I'm Alice and I love Italian food, especially pasta carbonara.",
223 |         "I'm Alice. I prefer to exercise in the morning around 7 AM.",
224 |         "I'm Alice. What do you remember about my food preferences?",
225 |         "I'm Alice. Set a reminder for me to go grocery shopping tomorrow.",
226 |         "I'm Alice. What are my exercise preferences?",
227 |         "I'm Alice. I also enjoy hiking on weekends.",
228 |         "I'm Alice. What do you know about me so far?"
229 |     ]
230 | 
231 |     for user_input in conversations:
232 |         print(f"\n📝 User: {user_input}")
233 | 
234 |         try:
235 |             response = agent(user_input=user_input)
236 |             print(f"🤖 Agent: {response.response}")
237 |             time.sleep(1)
238 | 
239 |         except Exception as e:
240 |             print(f"❌ Error: {e}")
241 | 
242 | # Run the demonstration
243 | if __name__ == "__main__":
244 |     run_memory_agent_demo()
245 | ```
246 | 
247 | ## Expected Output
248 | ```
249 | 🧠 Memory-Enhanced ReAct Agent Demo
250 | ==================================================
251 | 
252 | 📝 User: Hi, I'm Alice and I love Italian food, especially pasta carbonara.
253 | 
254 | 🤖 Agent: Hi Alice! It's great to meet you! Pasta carbonara is a delicious choice. Do you have a favorite place to get it, or do you enjoy making it at home?
255 | 
256 | 📝 User: I'm Alice. I prefer to exercise in the morning around 7 AM.
257 | 
258 | 🤖 Agent: Nice to meet you, Alice! I've noted that you prefer to exercise in the morning around 7 AM. If you have any questions or need tips related to your morning workouts, feel free to ask!
259 | 
260 | 📝 User: I'm Alice. What do you remember about my food preferences?
261 | 
262 | 🤖 Agent: I remember that you love Italian food, and your favorite dish is pasta carbonara!
263 | 
264 | 📝 User: I'm Alice. Set a reminder for me to go grocery shopping tomorrow.
265 | 
266 | 🤖 Agent: I have set a reminder for you, Alice, to go grocery shopping tomorrow.
267 | 
268 | 📝 User: I'm Alice. What are my exercise preferences?
269 | 
270 | 🤖 Agent: Your exercise preference is to work out in the morning around 7 AM.
271 | 
272 | 📝 User: I'm Alice. I also enjoy hiking on weekends.
273 | 
274 | 🤖 Agent: Nice to meet you, Alice! It's great to hear that you enjoy hiking on weekends. If you need any recommendations for trails or tips for your hikes, feel free to ask!
275 | 
276 | 📝 User: I'm Alice. What do you know about me so far?
277 | 
278 | 🤖 Agent: So far, I know that you enjoy hiking on weekends, prefer to exercise in the morning around 7 AM, and love Italian food, especially pasta carbonara. Additionally, you have a reminder set for grocery shopping on October 7, 2023.
279 | ```
280 | 
281 | ## Next Steps
282 | 
283 | - **Implement memory persistence** with databases (PostgreSQL, MongoDB)
284 | - **Add memory categorization** and tagging for better organization
285 | - **Create memory expiration policies** for data management
286 | - **Build multi-user memory isolation** for production applications (see the sketch below)
287 | - **Add memory analytics** and insights
288 | - **Integrate with vector databases** for enhanced semantic search
289 | - **Implement memory compression** for long-term storage efficiency
290 | 
291 | This tutorial demonstrates how DSPy's ReAct framework can be enhanced with Mem0's memory capabilities to create intelligent, context-aware agents that can learn and remember information across interactions, making them more useful for real-world applications.
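As a quick starting point for the multi-user isolation item above, here is a minimal, hypothetical sketch (the `for_user` helper is not part of the tutorial code): it simply threads a per-user `user_id` through the memory tools, which is how Mem0 scopes memories to a user:

```python
# Hypothetical helper: scope memory lookups to a specific user.
class MultiUserAgent(MemoryReActAgent):
    def for_user(self, user_input: str, user_id: str) -> str:
        # Retrieve only this user's memories and hand them to the agent as context.
        context = self.memory_tools.search_memories(user_input, user_id=user_id)
        return self.react(user_input=f"{user_input}\n\nKnown context:\n{context}").response

agent = MultiUserAgent(memory)
print(agent.for_user("What do I like to eat?", user_id="alice"))
print(agent.for_user("What do I like to eat?", user_id="bob"))  # sees none of Alice's memories
```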
``` -------------------------------------------------------------------------------- /dspy/adapters/utils.py: -------------------------------------------------------------------------------- ```python 1 | import ast 2 | import enum 3 | import inspect 4 | import json 5 | import types 6 | from collections.abc import Mapping 7 | from typing import Any, Literal, Union, get_args, get_origin 8 | 9 | import json_repair 10 | import pydantic 11 | from pydantic import TypeAdapter 12 | from pydantic.fields import FieldInfo 13 | 14 | from dspy.adapters.types.base_type import Type as DspyType 15 | from dspy.signatures.utils import get_dspy_field_type 16 | 17 | 18 | def serialize_for_json(value: Any) -> Any: 19 | """ 20 | Formats the specified value so that it can be serialized as a JSON string. 21 | 22 | Args: 23 | value: The value to format as a JSON string. 24 | Returns: 25 | The formatted value, which is serializable as a JSON string. 26 | """ 27 | # Attempt to format the value as a JSON-compatible object using pydantic, falling back to 28 | # a string representation of the value if that fails (e.g. if the value contains an object 29 | # that pydantic doesn't recognize or can't serialize) 30 | try: 31 | return TypeAdapter(type(value)).dump_python(value, mode="json") 32 | except Exception: 33 | return str(value) 34 | 35 | 36 | def format_field_value(field_info: FieldInfo, value: Any, assume_text=True) -> str | dict: 37 | """ 38 | Formats the value of the specified field according to the field's DSPy type (input or output), 39 | annotation (e.g. str, int, etc.), and the type of the value itself. 40 | 41 | Args: 42 | field_info: Information about the field, including its DSPy field type and annotation. 43 | value: The value of the field. 44 | Returns: 45 | The formatted value of the field, represented as a string. 46 | """ 47 | string_value = None 48 | if isinstance(value, list) and field_info.annotation is str: 49 | # If the field has no special type requirements, format it as a nice numbered list for the LM. 50 | string_value = _format_input_list_field_value(value) 51 | else: 52 | jsonable_value = serialize_for_json(value) 53 | if isinstance(jsonable_value, dict) or isinstance(jsonable_value, list): 54 | string_value = json.dumps(jsonable_value, ensure_ascii=False) 55 | else: 56 | # If the value is not a Python representation of a JSON object or Array 57 | # (e.g. the value is a JSON string), just use the string representation of the value 58 | # to avoid double-quoting the JSON string (which would hurt accuracy for certain 59 | # tasks, e.g. tasks that rely on computing string length) 60 | string_value = str(jsonable_value) 61 | 62 | if assume_text: 63 | return string_value 64 | else: 65 | return {"type": "text", "text": string_value} 66 | 67 | 68 | def _get_json_schema(field_type): 69 | def move_type_to_front(d): 70 | # Move the 'type' key to the front of the dictionary, recursively, for LLM readability/adherence. 
71 | if isinstance(d, Mapping): 72 | return { 73 | k: move_type_to_front(v) for k, v in sorted(d.items(), key=lambda item: (item[0] != "type", item[0])) 74 | } 75 | elif isinstance(d, list): 76 | return [move_type_to_front(item) for item in d] 77 | return d 78 | 79 | schema = pydantic.TypeAdapter(field_type).json_schema() 80 | schema = move_type_to_front(schema) 81 | return schema 82 | 83 | 84 | def translate_field_type(field_name, field_info): 85 | field_type = field_info.annotation 86 | 87 | if get_dspy_field_type(field_info) == "input" or field_type is str: 88 | desc = "" 89 | elif field_type is bool: 90 | desc = "must be True or False" 91 | elif field_type in (int, float): 92 | desc = f"must be a single {field_type.__name__} value" 93 | elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum): 94 | enum_vals = "; ".join(str(member.value) for member in field_type) 95 | desc = f"must be one of: {enum_vals}" 96 | elif hasattr(field_type, "__origin__") and field_type.__origin__ is Literal: 97 | desc = ( 98 | # Strongly encourage the LM to avoid choosing values that don't appear in the 99 | # literal or returning a value of the form 'Literal[<selected_value>]' 100 | f"must exactly match (no extra characters) one of: {'; '.join([str(x) for x in field_type.__args__])}" 101 | ) 102 | else: 103 | desc = f"must adhere to the JSON schema: {json.dumps(_get_json_schema(field_type), ensure_ascii=False)}" 104 | 105 | desc = (" " * 8) + f"# note: the value you produce {desc}" if desc else "" 106 | return f"{{{field_name}}}{desc}" 107 | 108 | 109 | def find_enum_member(enum, identifier): 110 | """ 111 | Finds the enum member corresponding to the specified identifier, which may be the 112 | enum member's name or value. 113 | 114 | Args: 115 | enum: The enum to search for the member. 116 | identifier: If the enum is explicitly-valued, this is the value of the enum member to find. 117 | If the enum is auto-valued, this is the name of the enum member to find. 118 | Returns: 119 | The enum member corresponding to the specified identifier. 120 | """ 121 | # Check if the identifier is a valid enum member value *before* checking if it's a valid enum 122 | # member name, since the identifier will be a value for explicitly-valued enums. 
This handles 123 | # the (rare) case where an enum member value is the same as another enum member's name in 124 | # an explicitly-valued enum 125 | for member in enum: 126 | if member.value == identifier: 127 | return member 128 | 129 | # If the identifier is not a valid enum member value, check if it's a valid enum member name, 130 | # since the identifier will be a member name for auto-valued enums 131 | if identifier in enum.__members__: 132 | return enum[identifier] 133 | 134 | raise ValueError(f"{identifier} is not a valid name or value for the enum {enum.__name__}") 135 | 136 | 137 | def parse_value(value, annotation): 138 | if annotation is str: 139 | return str(value) 140 | 141 | if isinstance(annotation, enum.EnumMeta): 142 | return find_enum_member(annotation, value) 143 | 144 | origin = get_origin(annotation) 145 | 146 | if origin is Literal: 147 | allowed = get_args(annotation) 148 | if value in allowed: 149 | return value 150 | 151 | if isinstance(value, str): 152 | v = value.strip() 153 | if v.startswith(("Literal[", "str[")) and v.endswith("]"): 154 | v = v[v.find("[") + 1 : -1] 155 | if len(v) > 1 and v[0] == v[-1] and v[0] in "\"'": 156 | v = v[1:-1] 157 | 158 | if v in allowed: 159 | return v 160 | 161 | raise ValueError(f"{value!r} is not one of {allowed!r}") 162 | 163 | if not isinstance(value, str): 164 | return TypeAdapter(annotation).validate_python(value) 165 | 166 | if origin in (Union, types.UnionType) and type(None) in get_args(annotation) and str in get_args(annotation): 167 | # Handle union annotations, e.g., `str | None`, `Optional[str]`, `Union[str, int, None]`, etc. 168 | return TypeAdapter(annotation).validate_python(value) 169 | 170 | candidate = json_repair.loads(value) # json_repair.loads returns "" on failure. 171 | if candidate == "" and value != "": 172 | try: 173 | candidate = ast.literal_eval(value) 174 | except (ValueError, SyntaxError): 175 | candidate = value 176 | 177 | try: 178 | return TypeAdapter(annotation).validate_python(candidate) 179 | except pydantic.ValidationError as e: 180 | if inspect.isclass(annotation) and issubclass(annotation, DspyType): 181 | try: 182 | # For dspy.Type, try parsing from the original value in case it has a custom parser 183 | return TypeAdapter(annotation).validate_python(value) 184 | except Exception: 185 | raise e 186 | raise 187 | 188 | 189 | def get_annotation_name(annotation): 190 | origin = get_origin(annotation) 191 | args = get_args(annotation) 192 | if origin is None: 193 | if hasattr(annotation, "__name__"): 194 | return annotation.__name__ 195 | else: 196 | return str(annotation) 197 | 198 | if origin is Literal: 199 | args_str = ", ".join( 200 | _quoted_string_for_literal_type_annotation(a) if isinstance(a, str) else get_annotation_name(a) 201 | for a in args 202 | ) 203 | return f"{get_annotation_name(origin)}[{args_str}]" 204 | else: 205 | args_str = ", ".join(get_annotation_name(a) for a in args) 206 | return f"{get_annotation_name(origin)}[{args_str}]" 207 | 208 | 209 | def get_field_description_string(fields: dict) -> str: 210 | field_descriptions = [] 211 | for idx, (k, v) in enumerate(fields.items()): 212 | field_message = f"{idx + 1}. 
`{k}`" 213 | field_message += f" ({get_annotation_name(v.annotation)})" 214 | desc = v.json_schema_extra["desc"] if v.json_schema_extra["desc"] != f"${{{k}}}" else "" 215 | 216 | custom_types = DspyType.extract_custom_type_from_annotation(v.annotation) 217 | for custom_type in custom_types: 218 | if len(custom_type.description()) > 0: 219 | desc += f"\n Type description of {get_annotation_name(custom_type)}: {custom_type.description()}" 220 | 221 | field_message += f": {desc}" 222 | field_message += ( 223 | f"\nConstraints: {v.json_schema_extra['constraints']}" if v.json_schema_extra.get("constraints") else "" 224 | ) 225 | field_descriptions.append(field_message) 226 | return "\n".join(field_descriptions).strip() 227 | 228 | 229 | def _format_input_list_field_value(value: list[Any]) -> str: 230 | """ 231 | Formats the value of an input field of type list[Any]. 232 | 233 | Args: 234 | value: The value of the list-type input field. 235 | Returns: 236 | A string representation of the input field's list value. 237 | """ 238 | if len(value) == 0: 239 | return "N/A" 240 | if len(value) == 1: 241 | return _format_blob(value[0]) 242 | 243 | return "\n".join([f"[{idx + 1}] {_format_blob(txt)}" for idx, txt in enumerate(value)]) 244 | 245 | 246 | def _format_blob(blob: str) -> str: 247 | """ 248 | Formats the specified text blobs so that an LM can parse it correctly within a list 249 | of multiple text blobs. 250 | 251 | Args: 252 | blob: The text blob to format. 253 | Returns: 254 | The formatted text blob. 255 | """ 256 | if "\n" not in blob and "«" not in blob and "»" not in blob: 257 | return f"«{blob}»" 258 | 259 | modified_blob = blob.replace("\n", "\n ") 260 | return f"«««\n {modified_blob}\n»»»" 261 | 262 | 263 | def _quoted_string_for_literal_type_annotation(s: str) -> str: 264 | """ 265 | Return the specified string quoted for inclusion in a literal type annotation. 266 | """ 267 | has_single = "'" in s 268 | has_double = '"' in s 269 | 270 | if has_single and not has_double: 271 | # Only single quotes => enclose in double quotes 272 | return f'"{s}"' 273 | elif has_double and not has_single: 274 | # Only double quotes => enclose in single quotes 275 | return f"'{s}'" 276 | elif has_single and has_double: 277 | # Both => enclose in single quotes; escape each single quote with \' 278 | escaped = s.replace("'", "\\'") 279 | return f"'{escaped}'" 280 | else: 281 | # Neither => enclose in single quotes 282 | return f"'{s}'" 283 | ``` -------------------------------------------------------------------------------- /dspy/clients/cache.py: -------------------------------------------------------------------------------- ```python 1 | import copy 2 | import inspect 3 | import logging 4 | import threading 5 | from functools import wraps 6 | from hashlib import sha256 7 | from typing import Any 8 | 9 | import cloudpickle 10 | import orjson 11 | import pydantic 12 | from cachetools import LRUCache 13 | from diskcache import FanoutCache 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class Cache: 19 | """DSPy Cache 20 | 21 | `Cache` provides 2 levels of caching (in the given order): 22 | 1. In-memory cache - implemented with cachetools.LRUCache 23 | 2. 
On-disk cache - implemented with diskcache.FanoutCache 24 | """ 25 | 26 | def __init__( 27 | self, 28 | enable_disk_cache: bool, 29 | enable_memory_cache: bool, 30 | disk_cache_dir: str, 31 | disk_size_limit_bytes: int | None = 1024 * 1024 * 10, 32 | memory_max_entries: int | None = 1000000, 33 | ): 34 | """ 35 | Args: 36 | enable_disk_cache: Whether to enable on-disk cache. 37 | enable_memory_cache: Whether to enable in-memory cache. 38 | disk_cache_dir: The directory where the disk cache is stored. 39 | disk_size_limit_bytes: The maximum size of the disk cache (in bytes). 40 | memory_max_entries: The maximum size of the in-memory cache (in number of items). 41 | """ 42 | 43 | self.enable_disk_cache = enable_disk_cache 44 | self.enable_memory_cache = enable_memory_cache 45 | if self.enable_memory_cache: 46 | self.memory_cache = LRUCache(maxsize=memory_max_entries) 47 | else: 48 | self.memory_cache = {} 49 | if self.enable_disk_cache: 50 | self.disk_cache = FanoutCache( 51 | shards=16, 52 | timeout=10, 53 | directory=disk_cache_dir, 54 | size_limit=disk_size_limit_bytes, 55 | ) 56 | else: 57 | self.disk_cache = {} 58 | 59 | self._lock = threading.RLock() 60 | 61 | def __contains__(self, key: str) -> bool: 62 | """Check if a key is in the cache.""" 63 | return key in self.memory_cache or key in self.disk_cache 64 | 65 | def cache_key(self, request: dict[str, Any], ignored_args_for_cache_key: list[str] | None = None) -> str: 66 | """ 67 | Obtain a unique cache key for the given request dictionary by hashing its JSON 68 | representation. For request fields having types that are known to be JSON-incompatible, 69 | convert them to a JSON-serializable format before hashing. 70 | """ 71 | 72 | ignored_args_for_cache_key = ignored_args_for_cache_key or [] 73 | 74 | def transform_value(value): 75 | if isinstance(value, type) and issubclass(value, pydantic.BaseModel): 76 | return value.model_json_schema() 77 | elif isinstance(value, pydantic.BaseModel): 78 | return value.model_dump(mode="json") 79 | elif callable(value): 80 | # Try to get the source code of the callable if available 81 | import inspect 82 | 83 | try: 84 | # For regular functions, we can get the source code 85 | return f"<callable_source:{inspect.getsource(value)}>" 86 | except (TypeError, OSError): 87 | # For lambda functions or other callables where source isn't available, 88 | # use a string representation 89 | return f"<callable:{value.__name__ if hasattr(value, '__name__') else 'lambda'}>" 90 | elif isinstance(value, dict): 91 | return {k: transform_value(v) for k, v in value.items()} 92 | else: 93 | return value 94 | 95 | params = {k: transform_value(v) for k, v in request.items() if k not in ignored_args_for_cache_key} 96 | return sha256(orjson.dumps(params, option=orjson.OPT_SORT_KEYS)).hexdigest() 97 | 98 | def get(self, request: dict[str, Any], ignored_args_for_cache_key: list[str] | None = None) -> Any: 99 | 100 | if not self.enable_memory_cache and not self.enable_disk_cache: 101 | return None 102 | 103 | try: 104 | key = self.cache_key(request, ignored_args_for_cache_key) 105 | except Exception: 106 | logger.debug(f"Failed to generate cache key for request: {request}") 107 | return None 108 | 109 | if self.enable_memory_cache and key in self.memory_cache: 110 | with self._lock: 111 | response = self.memory_cache[key] 112 | elif self.enable_disk_cache and key in self.disk_cache: 113 | # Found on disk but not in memory cache, add to memory cache 114 | response = self.disk_cache[key] 115 | if self.enable_memory_cache: 116 | 
with self._lock: 117 | self.memory_cache[key] = response 118 | else: 119 | return None 120 | 121 | response = copy.deepcopy(response) 122 | if hasattr(response, "usage"): 123 | # Clear the usage data when cache is hit, because no LM call is made 124 | response.usage = {} 125 | response.cache_hit = True 126 | return response 127 | 128 | def put( 129 | self, 130 | request: dict[str, Any], 131 | value: Any, 132 | ignored_args_for_cache_key: list[str] | None = None, 133 | enable_memory_cache: bool = True, 134 | ) -> None: 135 | enable_memory_cache = self.enable_memory_cache and enable_memory_cache 136 | 137 | # Early return to avoid computing cache key if both memory and disk cache are disabled 138 | if not enable_memory_cache and not self.enable_disk_cache: 139 | return 140 | 141 | try: 142 | key = self.cache_key(request, ignored_args_for_cache_key) 143 | except Exception: 144 | logger.debug(f"Failed to generate cache key for request: {request}") 145 | return 146 | 147 | if enable_memory_cache: 148 | with self._lock: 149 | self.memory_cache[key] = value 150 | 151 | if self.enable_disk_cache: 152 | try: 153 | self.disk_cache[key] = value 154 | except Exception as e: 155 | # Disk cache writing can fail for different reasons, e.g. disk full or the `value` is not picklable. 156 | logger.debug(f"Failed to put value in disk cache: {value}, {e}") 157 | 158 | def reset_memory_cache(self) -> None: 159 | if not self.enable_memory_cache: 160 | return 161 | 162 | with self._lock: 163 | self.memory_cache.clear() 164 | 165 | def save_memory_cache(self, filepath: str) -> None: 166 | if not self.enable_memory_cache: 167 | return 168 | 169 | with self._lock: 170 | with open(filepath, "wb") as f: 171 | cloudpickle.dump(self.memory_cache, f) 172 | 173 | def load_memory_cache(self, filepath: str) -> None: 174 | if not self.enable_memory_cache: 175 | return 176 | 177 | with self._lock: 178 | with open(filepath, "rb") as f: 179 | self.memory_cache = cloudpickle.load(f) 180 | 181 | 182 | def request_cache( 183 | cache_arg_name: str | None = None, 184 | ignored_args_for_cache_key: list[str] | None = None, 185 | enable_memory_cache: bool = True, 186 | *, # everything after this is keyword-only 187 | maxsize: int | None = None, # legacy / no-op 188 | ): 189 | """ 190 | Decorator for applying caching to a function based on the request argument. 191 | 192 | Args: 193 | cache_arg_name: The name of the argument that contains the request. If not provided, the entire kwargs is used 194 | as the request. 195 | ignored_args_for_cache_key: A list of arguments to ignore when computing the cache key from the request. 196 | enable_memory_cache: Whether to enable in-memory cache at call time. If False, the memory cache will not be 197 | written to on new data. 198 | """ 199 | ignored_args_for_cache_key = ignored_args_for_cache_key or ["api_key", "api_base", "base_url"] 200 | # Deprecation notice 201 | if maxsize is not None: 202 | logger.warning( 203 | "[DEPRECATION] `maxsize` is deprecated and no longer does anything; " 204 | "the cache is now handled internally by `dspy.cache`. " 205 | "This parameter will be removed in a future release.", 206 | ) 207 | 208 | def decorator(fn): 209 | @wraps(fn) 210 | def process_request(args, kwargs): 211 | # Use fully qualified function name for uniqueness 212 | fn_identifier = f"{fn.__module__}.{fn.__qualname__}" 213 | 214 | # Create a modified request that includes the function identifier so that it's incorporated into the cache 215 | # key. 
Deep copy is required because litellm sometimes modifies the kwargs in place. 216 | if cache_arg_name: 217 | # When `cache_arg_name` is provided, use the value of the argument with this name as the request for 218 | # caching. 219 | modified_request = copy.deepcopy(kwargs[cache_arg_name]) 220 | else: 221 | # When `cache_arg_name` is not provided, use the entire kwargs as the request for caching. 222 | modified_request = copy.deepcopy(kwargs) 223 | for i, arg in enumerate(args): 224 | modified_request[f"positional_arg_{i}"] = arg 225 | modified_request["_fn_identifier"] = fn_identifier 226 | 227 | return modified_request 228 | 229 | @wraps(fn) 230 | def sync_wrapper(*args, **kwargs): 231 | import dspy 232 | 233 | cache = dspy.cache 234 | modified_request = process_request(args, kwargs) 235 | 236 | # Retrieve from cache if available 237 | cached_result = cache.get(modified_request, ignored_args_for_cache_key) 238 | 239 | if cached_result is not None: 240 | return cached_result 241 | 242 | # Otherwise, compute and store the result 243 | # Make a copy of the original request in case it's modified in place, e.g., deleting some fields 244 | original_request = copy.deepcopy(modified_request) 245 | result = fn(*args, **kwargs) 246 | # `enable_memory_cache` can be provided at call time to avoid indefinite growth. 247 | cache.put(original_request, result, ignored_args_for_cache_key, enable_memory_cache) 248 | 249 | return result 250 | 251 | @wraps(fn) 252 | async def async_wrapper(*args, **kwargs): 253 | import dspy 254 | 255 | cache = dspy.cache 256 | modified_request = process_request(args, kwargs) 257 | 258 | # Retrieve from cache if available 259 | cached_result = cache.get(modified_request, ignored_args_for_cache_key) 260 | if cached_result is not None: 261 | return cached_result 262 | 263 | # Otherwise, compute and store the result 264 | # Make a copy of the original request in case it's modified in place, e.g., deleting some fields 265 | original_request = copy.deepcopy(modified_request) 266 | result = await fn(*args, **kwargs) 267 | cache.put(original_request, result, ignored_args_for_cache_key, enable_memory_cache) 268 | 269 | return result 270 | 271 | if inspect.iscoroutinefunction(fn): 272 | return async_wrapper 273 | else: 274 | return sync_wrapper 275 | 276 | return decorator 277 | ```