This is page 9 of 17. Use http://codebase.md/stanfordnlp/dspy?lines=true&page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── .internal_dspyai │ │ ├── internals │ │ │ ├── build-and-release.md │ │ │ └── release-checklist.md │ │ └── pyproject.toml │ ├── .tmp │ │ └── .generated-actions │ │ └── run-pypi-publish-in-docker-container │ │ └── action.yml │ ├── ISSUE_TEMPLATE │ │ ├── bug_report.yml │ │ └── feature_request.yml │ ├── PULL_REQUEST_TEMPLATE │ │ └── pull_request_template.md │ ├── workflow_scripts │ │ └── install_testpypi_pkg.sh │ └── workflows │ ├── build_and_release.yml │ ├── build_utils │ │ └── test_version.py │ ├── docs-push.yml │ ├── precommits_check.yml │ └── run_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── docs │ ├── .gitignore │ ├── docs │ │ ├── api │ │ │ ├── adapters │ │ │ │ ├── Adapter.md │ │ │ │ ├── ChatAdapter.md │ │ │ │ ├── JSONAdapter.md │ │ │ │ └── TwoStepAdapter.md │ │ │ ├── evaluation │ │ │ │ ├── answer_exact_match.md │ │ │ │ ├── answer_passage_match.md │ │ │ │ ├── CompleteAndGrounded.md │ │ │ │ ├── Evaluate.md │ │ │ │ ├── EvaluationResult.md │ │ │ │ └── SemanticF1.md │ │ │ ├── experimental │ │ │ │ ├── Citations.md │ │ │ │ └── Document.md │ │ │ ├── index.md │ │ │ ├── models │ │ │ │ ├── Embedder.md │ │ │ │ └── LM.md │ │ │ ├── modules │ │ │ │ ├── BestOfN.md │ │ │ │ ├── ChainOfThought.md │ │ │ │ ├── CodeAct.md │ │ │ │ ├── Module.md │ │ │ │ ├── MultiChainComparison.md │ │ │ │ ├── Parallel.md │ │ │ │ ├── Predict.md │ │ │ │ ├── ProgramOfThought.md │ │ │ │ ├── ReAct.md │ │ │ │ └── Refine.md │ │ │ ├── optimizers │ │ │ │ ├── BetterTogether.md │ │ │ │ ├── BootstrapFewShot.md │ │ │ │ ├── BootstrapFewShotWithRandomSearch.md │ │ │ │ ├── BootstrapFinetune.md │ │ │ │ ├── BootstrapRS.md │ │ │ │ ├── COPRO.md │ │ │ │ ├── Ensemble.md │ │ │ │ ├── GEPA │ │ │ │ │ ├── GEPA_Advanced.md │ │ │ │ │ └── overview.md │ │ │ │ ├── InferRules.md │ │ │ │ ├── KNN.md │ │ │ │ ├── KNNFewShot.md │ │ │ │ ├── LabeledFewShot.md │ │ │ │ ├── MIPROv2.md │ │ │ │ └── SIMBA.md │ │ │ ├── primitives │ │ │ │ ├── Audio.md │ │ │ │ ├── Code.md │ │ │ │ ├── Example.md │ │ │ │ ├── History.md │ │ │ │ ├── Image.md │ │ │ │ ├── Prediction.md │ │ │ │ ├── Tool.md │ │ │ │ └── ToolCalls.md │ │ │ ├── signatures │ │ │ │ ├── InputField.md │ │ │ │ ├── OutputField.md │ │ │ │ └── Signature.md │ │ │ ├── tools │ │ │ │ ├── ColBERTv2.md │ │ │ │ ├── Embeddings.md │ │ │ │ └── PythonInterpreter.md │ │ │ └── utils │ │ │ ├── asyncify.md │ │ │ ├── configure_cache.md │ │ │ ├── disable_litellm_logging.md │ │ │ ├── disable_logging.md │ │ │ ├── enable_litellm_logging.md │ │ │ ├── enable_logging.md │ │ │ ├── inspect_history.md │ │ │ ├── load.md │ │ │ ├── StatusMessage.md │ │ │ ├── StatusMessageProvider.md │ │ │ ├── streamify.md │ │ │ └── StreamListener.md │ │ ├── cheatsheet.md │ │ ├── community │ │ │ ├── community-resources.md │ │ │ ├── how-to-contribute.md │ │ │ └── use-cases.md │ │ ├── deep-dive │ │ │ └── data-handling │ │ │ ├── built-in-datasets.md │ │ │ ├── examples.md │ │ │ ├── img │ │ │ │ └── data-loading.png │ │ │ └── loading-custom-data.md │ │ ├── faqs.md │ │ ├── index.md │ │ ├── js │ │ │ └── runllm-widget.js │ │ ├── learn │ │ │ ├── evaluation │ │ │ │ ├── data.md │ │ │ │ ├── metrics.md │ │ │ │ └── overview.md │ │ │ ├── figures │ │ │ │ ├── native_tool_call.png │ │ │ │ └── teleprompter-classes.png │ │ │ ├── index.md │ │ │ ├── optimization │ │ │ │ ├── optimizers.md │ │ │ │ └── overview.md │ │ │ └── programming │ │ │ ├── 7-assertions.md │ │ │ ├── adapters.md │ │ │ ├── language_models.md │ │ │ 
├── mcp.md │ │ │ ├── modules.md │ │ │ ├── overview.md │ │ │ ├── signatures.md │ │ │ └── tools.md │ │ ├── production │ │ │ └── index.md │ │ ├── roadmap.md │ │ ├── static │ │ │ ├── .nojekyll │ │ │ └── img │ │ │ ├── dspy_logo.png │ │ │ ├── logo.png │ │ │ ├── mlflow-tracing-rag.png │ │ │ ├── modular.png │ │ │ ├── optimize.png │ │ │ ├── undraw_docusaurus_mountain.svg │ │ │ ├── undraw_docusaurus_react.svg │ │ │ ├── undraw_docusaurus_tree.svg │ │ │ └── universal_compatibility.png │ │ ├── stylesheets │ │ │ └── extra.css │ │ └── tutorials │ │ ├── agents │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── ai_text_game │ │ │ └── index.md │ │ ├── async │ │ │ └── index.md │ │ ├── audio │ │ │ └── index.ipynb │ │ ├── build_ai_program │ │ │ └── index.md │ │ ├── cache │ │ │ └── index.md │ │ ├── classification │ │ │ └── index.md │ │ ├── classification_finetuning │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-classification.png │ │ ├── conversation_history │ │ │ └── index.md │ │ ├── core_development │ │ │ └── index.md │ │ ├── custom_module │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-custom-module.png │ │ ├── customer_service_agent │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-customer-service-agent.png │ │ ├── deployment │ │ │ ├── dspy_mlflow_ui.png │ │ │ └── index.md │ │ ├── email_extraction │ │ │ ├── index.md │ │ │ └── mlflow-tracing-email-extraction.png │ │ ├── entity_extraction │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-entity-extraction.png │ │ ├── games │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── gepa_ai_program │ │ │ └── index.md │ │ ├── gepa_aime │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-aime.png │ │ │ └── mlflow-tracking-gepa-aime-optimization.png │ │ ├── gepa_facilitysupportanalyzer │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-support.png │ │ │ └── mlflow-tracking-gepa-support-optimization.png │ │ ├── gepa_papillon │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-papilon.png │ │ │ └── mlflow-tracking-gepa-papilon-optimization.png │ │ ├── image_generation_prompting │ │ │ └── index.ipynb │ │ ├── index.md │ │ ├── llms_txt_generation │ │ │ └── index.md │ │ ├── math │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-math.png │ │ ├── mcp │ │ │ └── index.md │ │ ├── mem0_react_agent │ │ │ └── index.md │ │ ├── multihop_search │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-multi-hop.png │ │ ├── observability │ │ │ ├── index.md │ │ │ ├── mlflow_trace_ui_navigation.gif │ │ │ ├── mlflow_trace_ui.png │ │ │ └── mlflow_trace_view.png │ │ ├── optimize_ai_program │ │ │ └── index.md │ │ ├── optimizer_tracking │ │ │ ├── child_run.png │ │ │ ├── experiment.png │ │ │ ├── index.md │ │ │ └── parent_run.png │ │ ├── output_refinement │ │ │ └── best-of-n-and-refine.md │ │ ├── papillon │ │ │ └── index.md │ │ ├── program_of_thought │ │ │ └── index.ipynb │ │ ├── rag │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-rag.png │ │ ├── real_world_examples │ │ │ └── index.md │ │ ├── rl_ai_program │ │ │ └── index.md │ │ ├── rl_multihop │ │ │ └── index.ipynb │ │ ├── rl_papillon │ │ │ └── index.ipynb │ │ ├── sample_code_generation │ │ │ └── index.md │ │ ├── saving │ │ │ └── index.md │ │ ├── streaming │ │ │ └── index.md │ │ ├── tool_use │ │ │ └── index.ipynb │ │ └── yahoo_finance_react │ │ └── index.md │ ├── mkdocs.yml │ ├── overrides │ │ ├── home.html │ │ ├── main.html │ │ └── partials │ │ └── tabs.html │ ├── Pipfile │ ├── Pipfile.lock │ ├── README.md │ ├── requirements.txt │ ├── scripts │ │ ├── generate_api_docs.py │ │ └── generate_api_summary.py │ └── vercel.json ├── dspy │ ├── 
__init__.py │ ├── __metadata__.py │ ├── adapters │ │ ├── __init__.py │ │ ├── baml_adapter.py │ │ ├── base.py │ │ ├── chat_adapter.py │ │ ├── json_adapter.py │ │ ├── two_step_adapter.py │ │ ├── types │ │ │ ├── __init__.py │ │ │ ├── audio.py │ │ │ ├── base_type.py │ │ │ ├── citation.py │ │ │ ├── code.py │ │ │ ├── document.py │ │ │ ├── history.py │ │ │ ├── image.py │ │ │ └── tool.py │ │ ├── utils.py │ │ └── xml_adapter.py │ ├── clients │ │ ├── __init__.py │ │ ├── base_lm.py │ │ ├── cache.py │ │ ├── databricks.py │ │ ├── embedding.py │ │ ├── lm_local_arbor.py │ │ ├── lm_local.py │ │ ├── lm.py │ │ ├── openai.py │ │ ├── provider.py │ │ └── utils_finetune.py │ ├── datasets │ │ ├── __init__.py │ │ ├── alfworld │ │ │ ├── __init__.py │ │ │ ├── alfworld.py │ │ │ └── base_config.yml │ │ ├── colors.py │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── gsm8k.py │ │ ├── hotpotqa.py │ │ └── math.py │ ├── dsp │ │ ├── __init__.py │ │ ├── colbertv2.py │ │ └── utils │ │ ├── __init__.py │ │ ├── dpr.py │ │ ├── settings.py │ │ └── utils.py │ ├── evaluate │ │ ├── __init__.py │ │ ├── auto_evaluation.py │ │ ├── evaluate.py │ │ └── metrics.py │ ├── experimental │ │ └── __init__.py │ ├── predict │ │ ├── __init__.py │ │ ├── aggregation.py │ │ ├── avatar │ │ │ ├── __init__.py │ │ │ ├── avatar.py │ │ │ ├── models.py │ │ │ └── signatures.py │ │ ├── best_of_n.py │ │ ├── chain_of_thought.py │ │ ├── code_act.py │ │ ├── knn.py │ │ ├── multi_chain_comparison.py │ │ ├── parallel.py │ │ ├── parameter.py │ │ ├── predict.py │ │ ├── program_of_thought.py │ │ ├── react.py │ │ ├── refine.py │ │ └── retry.py │ ├── primitives │ │ ├── __init__.py │ │ ├── base_module.py │ │ ├── example.py │ │ ├── module.py │ │ ├── prediction.py │ │ ├── python_interpreter.py │ │ └── runner.js │ ├── propose │ │ ├── __init__.py │ │ ├── dataset_summary_generator.py │ │ ├── grounded_proposer.py │ │ ├── propose_base.py │ │ └── utils.py │ ├── retrievers │ │ ├── __init__.py │ │ ├── databricks_rm.py │ │ ├── embeddings.py │ │ ├── retrieve.py │ │ └── weaviate_rm.py │ ├── signatures │ │ ├── __init__.py │ │ ├── field.py │ │ ├── signature.py │ │ └── utils.py │ ├── streaming │ │ ├── __init__.py │ │ ├── messages.py │ │ ├── streamify.py │ │ └── streaming_listener.py │ ├── teleprompt │ │ ├── __init__.py │ │ ├── avatar_optimizer.py │ │ ├── bettertogether.py │ │ ├── bootstrap_finetune.py │ │ ├── bootstrap_trace.py │ │ ├── bootstrap.py │ │ ├── copro_optimizer.py │ │ ├── ensemble.py │ │ ├── gepa │ │ │ ├── __init__.py │ │ │ ├── gepa_utils.py │ │ │ ├── gepa.py │ │ │ └── instruction_proposal.py │ │ ├── grpo.py │ │ ├── infer_rules.py │ │ ├── knn_fewshot.py │ │ ├── mipro_optimizer_v2.py │ │ ├── random_search.py │ │ ├── signature_opt.py │ │ ├── simba_utils.py │ │ ├── simba.py │ │ ├── teleprompt_optuna.py │ │ ├── teleprompt.py │ │ ├── utils.py │ │ └── vanilla.py │ └── utils │ ├── __init__.py │ ├── annotation.py │ ├── asyncify.py │ ├── caching.py │ ├── callback.py │ ├── dummies.py │ ├── exceptions.py │ ├── hasher.py │ ├── inspect_history.py │ ├── langchain_tool.py │ ├── logging_utils.py │ ├── mcp.py │ ├── parallelizer.py │ ├── saving.py │ ├── syncify.py │ ├── unbatchify.py │ └── usage_tracker.py ├── LICENSE ├── pyproject.toml ├── README.md ├── tests │ ├── __init__.py │ ├── adapters │ │ ├── test_adapter_utils.py │ │ ├── test_baml_adapter.py │ │ ├── test_base_type.py │ │ ├── test_chat_adapter.py │ │ ├── test_citation.py │ │ ├── test_code.py │ │ ├── test_document.py │ │ ├── test_json_adapter.py │ │ ├── test_tool.py │ │ ├── test_two_step_adapter.py │ │ └── test_xml_adapter.py │ ├── 
callback │ │ └── test_callback.py │ ├── clients │ │ ├── test_cache.py │ │ ├── test_databricks.py │ │ ├── test_embedding.py │ │ ├── test_inspect_global_history.py │ │ └── test_lm.py │ ├── conftest.py │ ├── datasets │ │ └── test_dataset.py │ ├── docs │ │ └── test_mkdocs_links.py │ ├── evaluate │ │ ├── test_evaluate.py │ │ └── test_metrics.py │ ├── examples │ │ └── test_baleen.py │ ├── metadata │ │ └── test_metadata.py │ ├── predict │ │ ├── test_aggregation.py │ │ ├── test_best_of_n.py │ │ ├── test_chain_of_thought.py │ │ ├── test_code_act.py │ │ ├── test_knn.py │ │ ├── test_multi_chain_comparison.py │ │ ├── test_parallel.py │ │ ├── test_predict.py │ │ ├── test_program_of_thought.py │ │ ├── test_react.py │ │ ├── test_refine.py │ │ └── test_retry.py │ ├── primitives │ │ ├── resources │ │ │ └── saved_program.json │ │ ├── test_base_module.py │ │ ├── test_example.py │ │ ├── test_module.py │ │ └── test_python_interpreter.py │ ├── propose │ │ └── test_grounded_proposer.py │ ├── README.md │ ├── reliability │ │ ├── __init__.py │ │ ├── complex_types │ │ │ └── generated │ │ │ ├── test_many_types_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ ├── test_nesting_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ └── test_nesting_2 │ │ │ ├── inputs │ │ │ │ └── input1.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── conftest.py │ │ ├── generate │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ └── utils.py │ │ ├── input_formats │ │ │ └── generated │ │ │ └── test_markdown_1 │ │ │ ├── inputs │ │ │ │ ├── input1.json │ │ │ │ └── input2.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── README.md │ │ ├── reliability_conf.yaml │ │ ├── test_generated.py │ │ ├── test_pydantic_models.py │ │ └── utils.py │ ├── retrievers │ │ └── test_embeddings.py │ ├── signatures │ │ ├── test_adapter_image.py │ │ ├── test_custom_types.py │ │ └── test_signature.py │ ├── streaming │ │ └── test_streaming.py │ ├── teleprompt │ │ ├── gepa_dummy_lm_custom_component_selector_custom_instruction_proposer.json │ │ ├── gepa_dummy_lm.json │ │ ├── test_bootstrap_finetune.py │ │ ├── test_bootstrap_trace.py │ │ ├── test_bootstrap.py │ │ ├── test_copro_optimizer.py │ │ ├── test_ensemble.py │ │ ├── test_finetune.py │ │ ├── test_gepa_instruction_proposer.py │ │ ├── test_gepa.py │ │ ├── test_grpo.py │ │ ├── test_knn_fewshot.py │ │ ├── test_random_search.py │ │ ├── test_teleprompt.py │ │ └── test_utils.py │ ├── test_utils │ │ ├── __init__.py │ │ └── server │ │ ├── __init__.py │ │ ├── litellm_server_config.yaml │ │ └── litellm_server.py │ └── utils │ ├── __init__.py │ ├── resources │ │ └── mcp_server.py │ ├── test_annotation.py │ ├── test_asyncify.py │ ├── test_exceptions.py │ ├── test_langchain_tool.py │ ├── test_mcp.py │ ├── test_parallelizer.py │ ├── test_saving.py │ ├── test_settings.py │ ├── test_syncify.py │ ├── test_unbatchify.py │ └── test_usage_tracker.py └── uv.lock ``` # Files -------------------------------------------------------------------------------- /docs/docs/faqs.md: -------------------------------------------------------------------------------- ```markdown 1 | --- 2 | sidebar_position: 998 3 | --- 4 | 5 | !!! warning "This page is outdated and may not be fully accurate in DSPy 2.5 and 2.6" 6 | 7 | 8 | # FAQs 9 | 10 | ## Is DSPy right for me? DSPy vs. 
other frameworks 11 | 12 | The **DSPy** philosophy and abstraction differ significantly from other libraries and frameworks, so it's usually straightforward to decide when **DSPy** is (or isn't) the right framework for your use case. If you're an NLP/AI researcher (or a practitioner exploring new pipelines or new tasks), the answer is generally an invariable **yes**. If you're a practitioner doing other things, please read on. 13 | 14 | **DSPy vs. thin wrappers for prompts (OpenAI API, MiniChain, basic templating)** In other words: _Why can't I just write my prompts directly as string templates?_ Well, for extremely simple settings, this _might_ work just fine. (If you're familiar with neural networks, this is like expressing a tiny two-layer NN as a Python for-loop. It kinda works.) However, when you need higher quality (or manageable cost), you need to iteratively explore multi-stage decomposition, improved prompting, data bootstrapping, careful finetuning, retrieval augmentation, and/or using smaller (or cheaper, or local) models. The true expressive power of building with foundation models lies in the interactions between these pieces. But every time you change one piece, you likely break (or weaken) multiple other components. **DSPy** cleanly abstracts away (_and_ powerfully optimizes) the parts of these interactions that are external to your actual system design. It lets you focus on designing the module-level interactions: the _same program_ expressed in 10 or 20 lines of **DSPy** can easily be compiled into multi-stage instructions for `GPT-4`, detailed prompts for `Llama2-13b`, or finetunes for `T5-base`. Oh, and you wouldn't need to maintain long, brittle, model-specific strings at the core of your project anymore. 15 | 16 | **DSPy vs. application development libraries like LangChain, LlamaIndex** LangChain and LlamaIndex target high-level application development; they offer _batteries-included_, pre-built application modules that plug in with your data or configuration. If you'd be happy to use a generic, off-the-shelf prompt for question answering over PDFs or standard text-to-SQL, you will find a rich ecosystem in these libraries. **DSPy** doesn't internally contain hand-crafted prompts that target specific applications. Instead, **DSPy** introduces a small set of much more powerful and general-purpose modules _that can learn to prompt (or finetune) your LM within your pipeline on your data_. When you change your data, make tweaks to your program's control flow, or change your target LM, the **DSPy compiler** can map your program into a new set of prompts (or finetunes) that are optimized specifically for this pipeline. Because of this, you may find that **DSPy** obtains the highest quality for your task, with the least effort, provided you're willing to implement (or extend) your own short program. In short, **DSPy** is for when you need a lightweight but automatically-optimizing programming model — not a library of predefined prompts and integrations. If you're familiar with neural networks: This is like the difference between PyTorch (i.e., representing **DSPy**) and HuggingFace Transformers (i.e., representing the higher-level libraries). 17 | 18 | **DSPy vs. generation control libraries like Guidance, LMQL, RELM, Outlines** These are all exciting new libraries for controlling the individual completions of LMs, e.g., if you want to enforce a JSON output schema or constrain sampling to a particular regular expression.
This is very useful in many settings, but it's generally focused on low-level, structured control of a single LM call. It doesn't help ensure the JSON (or structured output) you get is going to be correct or useful for your task. In contrast, **DSPy** automatically optimizes the prompts in your programs to align them with various task needs, which may also include producing valid structured outputs. That said, we are considering allowing **Signatures** in **DSPy** to express regex-like constraints that are implemented by these libraries. 19 | 20 | ## Basic Usage 21 | 22 | **How should I use DSPy for my task?** We wrote an [eight-step guide](learn/index.md) on this. In short, using DSPy is an iterative process. You first define your task and the metrics you want to maximize, and prepare a few example inputs — typically without labels (or only with labels for the final outputs, if your metric requires them). Then, you build your pipeline by selecting built-in layers (`modules`) to use, giving each layer a `signature` (input/output spec), and then calling your modules freely in your Python code. Lastly, you use a DSPy `optimizer` to compile your code into high-quality instructions, automatic few-shot examples, or updated weights for your LM. 23 | 24 | **How do I convert my complex prompt into a DSPy pipeline?** See the same answer above. 25 | 26 | **What do DSPy optimizers tune?** Or, _what does compiling actually do?_ Each optimizer is different, but they all seek to maximize a metric on your program by updating prompts or LM weights. Current DSPy `optimizers` can inspect your data, simulate traces through your program to generate good/bad examples of each step, propose or refine instructions for each step based on past results, finetune the weights of your LM on self-generated examples, or combine several of these to improve quality or cut cost. We'd love to merge new optimizers that explore a richer space: most manual steps you currently go through for prompt engineering, "synthetic data" generation, or self-improvement can probably be generalized into a DSPy optimizer that acts on arbitrary LM programs. 27 | 28 | Other FAQs. We welcome PRs to add formal answers to each of these here. For all or most of these, you will find answers in existing issues, tutorials, or the papers. 29 | 30 | - **How do I get multiple outputs?** 31 | 32 | You can specify multiple output fields. For the short-form signature, you can list multiple outputs as comma-separated values, following the "->" indicator (e.g. "inputs -> output1, output2"). For the long-form signature, you can include multiple `dspy.OutputField`s. 33 | 34 | 35 | - **How do I define my own metrics? Can metrics return a float?** 36 | 37 | You can define metrics simply as Python functions that process model generations and evaluate them based on user-defined requirements. Metrics can compare existing data (e.g. gold labels) to model predictions, or they can be used to assess various components of an output using validation feedback from LMs (e.g. LLMs-as-Judges). Metrics can return `bool`, `int`, or `float` scores. Check out the official [Metrics docs](learn/evaluation/metrics.md) to learn more about defining custom metrics and advanced evaluations using AI feedback and/or DSPy programs.
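For illustration, here is a minimal sketch of such a metric. The `(gold, pred, trace=None)` signature follows the convention DSPy optimizers pass to metrics, and the `answer` field name is an assumption for this example:

```python
def exact_answer_match(gold, pred, trace=None) -> float:
    # Compare the gold label to the model prediction (the `answer` field is assumed for illustration).
    return float(gold.answer.strip().lower() == pred.answer.strip().lower())
```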
38 | 39 | - **How expensive or slow is compiling?** 40 | 41 | To give a sense of compiling costs, we highlight an experiment for reference: compiling a program using the [BootstrapFewShotWithRandomSearch](api/optimizers/BootstrapFewShotWithRandomSearch.md) optimizer on the `gpt-3.5-turbo-1106` model over 7 candidate programs and 10 threads. Compiling this program takes around 6 minutes, makes 3200 API calls, and uses 2.7 million input tokens and 156,000 output tokens, for a total cost of about $3 USD (at the pricing of the OpenAI model at the time). 42 | 43 | Compiling with DSPy `optimizers` naturally incurs additional LM calls, but this overhead comes from minimal, targeted executions aimed at maximizing performance. This also invites avenues to enhance the performance of smaller models: compile DSPy programs with larger models to learn enhanced behavior at compile time, then propagate that behavior to the smaller model used at inference time. 44 | 45 | 46 | ## Deployment or Reproducibility Concerns 47 | 48 | - **How do I save a checkpoint of my compiled program?** 49 | 50 | Here is an example of saving/loading a compiled module: 51 | 52 | ```python 53 | cot_compiled = teleprompter.compile(CoT(), trainset=trainset, valset=devset) 54 | 55 | # Saving 56 | cot_compiled.save('compiled_cot_gsm8k.json') 57 | 58 | # Loading: 59 | cot = CoT() 60 | cot.load('compiled_cot_gsm8k.json') 61 | ``` 62 | 63 | - **How do I export for deployment?** 64 | 65 | Exporting DSPy programs is simply saving them as highlighted above! 66 | 67 | - **How do I search my own data?** 68 | 69 | Open source libraries such as [RAGatouille](https://github.com/bclavie/ragatouille) enable you to search your own data through advanced retrieval models like ColBERT, with tools to embed and index documents. Feel free to integrate such libraries to create searchable datasets while developing your DSPy programs! 70 | 71 | - **How do I turn off the cache? How do I export the cache?** 72 | 73 | From v2.5, you can turn off the cache by setting the `cache` parameter in `dspy.LM` to `False`: 74 | 75 | ```python 76 | dspy.LM('openai/gpt-4o-mini', cache=False) 77 | ``` 78 | 79 | Your local cache will be saved to the global env directory `os.environ["DSP_CACHEDIR"]` or, for notebooks, `os.environ["DSP_NOTEBOOK_CACHEDIR"]`. You can usually set the cachedir to `os.path.join(repo_path, 'cache')` and export this cache from here: 80 | ```python 81 | os.environ["DSP_NOTEBOOK_CACHEDIR"] = os.path.join(os.getcwd(), 'cache') 82 | ``` 83 | 84 | !!! warning "Important" 85 | `DSP_CACHEDIR` is responsible for old clients (including dspy.OpenAI, dspy.ColBERTv2, etc.) and `DSPY_CACHEDIR` is responsible for the new dspy.LM client. 86 | 87 | In AWS Lambda deployments, you should disable both DSP_\* and DSPY_\* caches. 88 | 89 | 90 | ## Advanced Usage 91 | 92 | - **How do I parallelize?** 93 | You can parallelize DSPy programs during both compilation and evaluation by specifying multiple thread settings in the respective DSPy `optimizers` or within the `dspy.Evaluate` utility function. 94 | 95 | - **How do I freeze a module?** 96 | 97 | Modules can be frozen by setting their `._compiled` attribute to `True`, indicating the module has gone through optimizer compilation and should not have its parameters adjusted. This is handled internally in optimizers such as `dspy.BootstrapFewShot`, where the student program is ensured to be frozen before the teacher propagates the gathered few-shot demonstrations in the bootstrapping process.
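As a minimal sketch of that mechanism (assuming `compiled_program` is a module returned by an optimizer; the attribute name is taken from the answer above):

```python
# Mark the module as compiled so subsequent optimizers leave its parameters untouched.
compiled_program._compiled = True
```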
98 | 99 | - **How do I use DSPy assertions?** 100 | 101 | a) **How to Add Assertions to Your Program**: 102 | - **Define Constraints**: Use `dspy.Assert` and/or `dspy.Suggest` to define constraints within your DSPy program. These are based on boolean validation checks for the outcomes you want to enforce, which can simply be Python functions to validate the model outputs. 103 | - **Integrating Assertions**: Keep your assertion statements immediately after the model generations they validate (hint: after a module layer). 104 | 105 | b) **How to Activate the Assertions**: 106 | 1. **Using `assert_transform_module`**: 107 | - Wrap your DSPy module with assertions using the `assert_transform_module` function, along with a `backtrack_handler`. This function transforms your program to include internal assertion backtracking and retry logic, which can be customized as well: 108 | `program_with_assertions = assert_transform_module(ProgramWithAssertions(), backtrack_handler)` 109 | 2. **Activate Assertions**: 110 | - Directly call `activate_assertions` on your DSPy program with assertions: `program_with_assertions = ProgramWithAssertions().activate_assertions()` 111 | 112 | **Note**: To use Assertions properly, you must **activate** a DSPy program that includes `dspy.Assert` or `dspy.Suggest` statements, via either of the methods above. 113 | 114 | ## Errors 115 | 116 | - **How do I deal with "context too long" errors?** 117 | 118 | If you're dealing with "context too long" errors in DSPy, you're likely using DSPy optimizers to include demonstrations within your prompt, and this is exceeding your current context window. Try reducing these parameters (e.g. `max_bootstrapped_demos` and `max_labeled_demos`). You can also reduce the number of retrieved passages/docs/embeddings to ensure your prompt fits within your model's context length. 119 | 120 | A more general fix is simply to increase the `max_tokens` specified in the LM request (e.g. `lm = dspy.OpenAI(model=..., max_tokens=...)`). 121 | 122 | ## Set Verbose Level 123 | DSPy utilizes the [logging library](https://docs.python.org/3/library/logging.html) to print logs. If you want to debug your DSPy code, set the logging level to DEBUG with the example code below. 124 | 125 | ```python 126 | import logging 127 | logging.getLogger("dspy").setLevel(logging.DEBUG) 128 | ``` 129 | 130 | Alternatively, if you want to reduce the amount of logs, set the logging level to WARNING or ERROR. 131 | 132 | ```python 133 | import logging 134 | logging.getLogger("dspy").setLevel(logging.WARNING) 135 | ``` ``` -------------------------------------------------------------------------------- /docs/docs/learn/programming/modules.md: -------------------------------------------------------------------------------- ```markdown 1 | --- 2 | sidebar_position: 3 3 | --- 4 | 5 | # Modules 6 | 7 | A **DSPy module** is a building block for programs that use LMs. 8 | 9 | - Each built-in module abstracts a **prompting technique** (like chain of thought or ReAct). Crucially, they are generalized to handle any signature. 10 | 11 | - A DSPy module has **learnable parameters** (i.e., the little pieces comprising the prompt and the LM weights) and can be invoked (called) to process inputs and return outputs. 12 | 13 | - Multiple modules can be composed into bigger modules (programs). DSPy modules are inspired directly by NN modules in PyTorch, but applied to LM programs. 14 | 15 | 16 | ## How do I use a built-in module, like `dspy.Predict` or `dspy.ChainOfThought`?
17 | 18 | Let's start with the most fundamental module, `dspy.Predict`. Internally, all other DSPy modules are built using `dspy.Predict`. We'll assume you are already at least a little familiar with [DSPy signatures](signatures.md), which are declarative specs for defining the behavior of any module we use in DSPy. 19 | 20 | To use a module, we first **declare** it by giving it a signature. Then we **call** the module with the input arguments, and extract the output fields! 21 | 22 | ```python 23 | sentence = "it's a charming and often affecting journey." # example from the SST-2 dataset. 24 | 25 | # 1) Declare with a signature. 26 | classify = dspy.Predict('sentence -> sentiment: bool') 27 | 28 | # 2) Call with input argument(s). 29 | response = classify(sentence=sentence) 30 | 31 | # 3) Access the output. 32 | print(response.sentiment) 33 | ``` 34 | **Output:** 35 | ```text 36 | True 37 | ``` 38 | 39 | When we declare a module, we can pass configuration keys to it. 40 | 41 | Below, we'll pass `n=5` to request five completions. We can also pass `temperature` or `max_len`, etc. 42 | 43 | Let's use `dspy.ChainOfThought`. In many cases, simply swapping `dspy.ChainOfThought` in place of `dspy.Predict` improves quality. 44 | 45 | ```python 46 | question = "What's something great about the ColBERT retrieval model?" 47 | 48 | # 1) Declare with a signature, and pass some config. 49 | classify = dspy.ChainOfThought('question -> answer', n=5) 50 | 51 | # 2) Call with input argument. 52 | response = classify(question=question) 53 | 54 | # 3) Access the outputs. 55 | response.completions.answer 56 | ``` 57 | **Possible Output:** 58 | ```text 59 | ['One great thing about the ColBERT retrieval model is its superior efficiency and effectiveness compared to other models.', 60 | 'Its ability to efficiently retrieve relevant information from large document collections.', 61 | 'One great thing about the ColBERT retrieval model is its superior performance compared to other models and its efficient use of pre-trained language models.', 62 | 'One great thing about the ColBERT retrieval model is its superior efficiency and accuracy compared to other models.', 63 | 'One great thing about the ColBERT retrieval model is its ability to incorporate user feedback and support complex queries.'] 64 | ``` 65 | 66 | Let's discuss the output object here. The `dspy.ChainOfThought` module will generally inject a `reasoning` before the output field(s) of your signature. 67 | 68 | Let's inspect the (first) reasoning and answer! 69 | 70 | ```python 71 | print(f"Reasoning: {response.reasoning}") 72 | print(f"Answer: {response.answer}") 73 | ``` 74 | **Possible Output:** 75 | ```text 76 | Reasoning: We can consider the fact that ColBERT has shown to outperform other state-of-the-art retrieval models in terms of efficiency and effectiveness. It uses contextualized embeddings and performs document retrieval in a way that is both accurate and scalable. 77 | Answer: One great thing about the ColBERT retrieval model is its superior efficiency and effectiveness compared to other models. 78 | ``` 79 | 80 | This is accessible whether we request one or many completions. 81 | 82 | We can also access the different completions as a list of `Prediction`s or as several lists, one for each field. 83 | 84 | ```python 85 | response.completions[3].reasoning == response.completions.reasoning[3] 86 | ``` 87 | **Output:** 88 | ```text 89 | True 90 | ``` 91 | 92 | 93 | ## What other DSPy modules are there? How can I use them? 
94 | 95 | The others are very similar. They mainly change the internal behavior with which your signature is implemented! 96 | 97 | 1. **`dspy.Predict`**: Basic predictor. Does not modify the signature. Handles the key forms of learning (i.e., storing the instructions and demonstrations and updates to the LM). 98 | 99 | 2. **`dspy.ChainOfThought`**: Teaches the LM to think step-by-step before committing to the signature's response. 100 | 101 | 3. **`dspy.ProgramOfThought`**: Teaches the LM to output code, whose execution results will dictate the response. 102 | 103 | 4. **`dspy.ReAct`**: An agent that can use tools to implement the given signature. 104 | 105 | 5. **`dspy.MultiChainComparison`**: Can compare multiple outputs from `ChainOfThought` to produce a final prediction. 106 | 107 | 108 | We also have some function-style modules: 109 | 110 | 6. **`dspy.majority`**: Can do basic voting to return the most popular response from a set of predictions. 111 | 112 | 113 | !!! info "A few examples of DSPy modules on simple tasks." 114 | Try the examples below after configuring your `lm`. Adjust the fields to explore what tasks your LM can do well out of the box. 115 | 116 | === "Math" 117 | 118 | ```python linenums="1" 119 | math = dspy.ChainOfThought("question -> answer: float") 120 | math(question="Two dice are tossed. What is the probability that the sum equals two?") 121 | ``` 122 | 123 | **Possible Output:** 124 | ```text 125 | Prediction( 126 | reasoning='When two dice are tossed, each die has 6 faces, resulting in a total of 6 x 6 = 36 possible outcomes. The sum of the numbers on the two dice equals two only when both dice show a 1. This is just one specific outcome: (1, 1). Therefore, there is only 1 favorable outcome. The probability of the sum being two is the number of favorable outcomes divided by the total number of possible outcomes, which is 1/36.', 127 | answer=0.0277776 128 | ) 129 | ``` 130 | 131 | === "Retrieval-Augmented Generation" 132 | 133 | ```python linenums="1" 134 | def search(query: str) -> list[str]: 135 | """Retrieves abstracts from Wikipedia.""" 136 | results = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')(query, k=3) 137 | return [x['text'] for x in results] 138 | 139 | rag = dspy.ChainOfThought('context, question -> response') 140 | 141 | question = "What's the name of the castle that David Gregory inherited?" 142 | rag(context=search(question), question=question) 143 | ``` 144 | 145 | **Possible Output:** 146 | ```text 147 | Prediction( 148 | reasoning='The context provides information about David Gregory, a Scottish physician and inventor. It specifically mentions that he inherited Kinnairdy Castle in 1664. 
This detail directly answers the question about the name of the castle that David Gregory inherited.', 149 | response='Kinnairdy Castle' 150 | ) 151 | ``` 152 | 153 | === "Classification" 154 | 155 | ```python linenums="1" 156 | from typing import Literal 157 | 158 | class Classify(dspy.Signature): 159 | """Classify sentiment of a given sentence.""" 160 | 161 | sentence: str = dspy.InputField() 162 | sentiment: Literal['positive', 'negative', 'neutral'] = dspy.OutputField() 163 | confidence: float = dspy.OutputField() 164 | 165 | classify = dspy.Predict(Classify) 166 | classify(sentence="This book was super fun to read, though not the last chapter.") 167 | ``` 168 | 169 | **Possible Output:** 170 | 171 | ```text 172 | Prediction( 173 | sentiment='positive', 174 | confidence=0.75 175 | ) 176 | ``` 177 | 178 | === "Information Extraction" 179 | 180 | ```python linenums="1" 181 | text = "Apple Inc. announced its latest iPhone 14 today. The CEO, Tim Cook, highlighted its new features in a press release." 182 | 183 | module = dspy.Predict("text -> title, headings: list[str], entities_and_metadata: list[dict[str, str]]") 184 | response = module(text=text) 185 | 186 | print(response.title) 187 | print(response.headings) 188 | print(response.entities_and_metadata) 189 | ``` 190 | 191 | **Possible Output:** 192 | ```text 193 | Apple Unveils iPhone 14 194 | ['Introduction', 'Key Features', "CEO's Statement"] 195 | [{'entity': 'Apple Inc.', 'type': 'Organization'}, {'entity': 'iPhone 14', 'type': 'Product'}, {'entity': 'Tim Cook', 'type': 'Person'}] 196 | ``` 197 | 198 | === "Agents" 199 | 200 | ```python linenums="1" 201 | def evaluate_math(expression: str) -> float: 202 | return dspy.PythonInterpreter({}).execute(expression) 203 | 204 | def search_wikipedia(query: str) -> list[str]: 205 | results = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')(query, k=3) 206 | return [x['text'] for x in results] 207 | 208 | react = dspy.ReAct("question -> answer: float", tools=[evaluate_math, search_wikipedia]) 209 | 210 | pred = react(question="What is 9362158 divided by the year of birth of David Gregory of Kinnairdy castle?") 211 | print(pred.answer) 212 | ``` 213 | 214 | **Possible Output:** 215 | 216 | ```text 217 | 5761.328 218 | ``` 219 | 220 | 221 | ## How do I compose multiple modules into a bigger program? 222 | 223 | DSPy is just Python code that uses modules in any control flow you like, with a little magic internally at `compile` time to trace your LM calls. What this means is that you can just call the modules freely. 224 | 225 | See tutorials like [multi-hop search](https://dspy.ai/tutorials/multihop_search/), whose module is reproduced below as an example.
226 | 227 | ```python linenums="1" 228 | class Hop(dspy.Module): 229 | def __init__(self, num_docs=10, num_hops=4): 230 | self.num_docs, self.num_hops = num_docs, num_hops 231 | self.generate_query = dspy.ChainOfThought('claim, notes -> query') 232 | self.append_notes = dspy.ChainOfThought('claim, notes, context -> new_notes: list[str], titles: list[str]') 233 | 234 | def forward(self, claim: str) -> dspy.Prediction: 235 | notes = [] 236 | titles = [] 237 | 238 | for _ in range(self.num_hops): 239 | query = self.generate_query(claim=claim, notes=notes).query 240 | context = search(query, k=self.num_docs) 241 | prediction = self.append_notes(claim=claim, notes=notes, context=context) 242 | notes.extend(prediction.new_notes) 243 | titles.extend(prediction.titles) 244 | 245 | return dspy.Prediction(notes=notes, titles=list(set(titles))) 246 | ``` 247 | 248 | Then you can create an instance of the custom module class `Hop` and invoke it via its `__call__` method: 249 | 250 | ```python 251 | hop = Hop() 252 | print(hop(claim="Stephen Curry is the best 3 pointer shooter ever in the human history")) 253 | ``` 254 | 255 | ## How do I track LM usage? 256 | 257 | !!! note "Version Requirement" 258 | LM usage tracking is available in DSPy version 2.6.16 and later. 259 | 260 | DSPy provides built-in tracking of language model usage across all module calls. To enable tracking: 261 | 262 | ```python 263 | dspy.settings.configure(track_usage=True) 264 | ``` 265 | 266 | Once enabled, you can access usage statistics from any `dspy.Prediction` object: 267 | 268 | ```python 269 | usage = prediction_instance.get_lm_usage() 270 | ``` 271 | 272 | The usage data is returned as a dictionary that maps each language model name to its usage statistics. Here's a complete example: 273 | 274 | ```python 275 | import dspy 276 | 277 | # Configure DSPy with tracking enabled 278 | dspy.settings.configure( 279 | lm=dspy.LM("openai/gpt-4o-mini", cache=False), 280 | track_usage=True 281 | ) 282 | 283 | # Define a simple program that makes multiple LM calls 284 | class MyProgram(dspy.Module): 285 | def __init__(self): 286 | self.predict1 = dspy.ChainOfThought("question -> answer") 287 | self.predict2 = dspy.ChainOfThought("question, answer -> score") 288 | 289 | def __call__(self, question: str) -> dspy.Prediction: 290 | answer = self.predict1(question=question) 291 | score = self.predict2(question=question, answer=answer) 292 | return score 293 | 294 | # Run the program and check usage 295 | program = MyProgram() 296 | output = program(question="What is the capital of France?") 297 | print(output.get_lm_usage()) 298 | ``` 299 | 300 | This will output usage statistics like: 301 | 302 | ```python 303 | { 304 | 'openai/gpt-4o-mini': { 305 | 'completion_tokens': 61, 306 | 'prompt_tokens': 260, 307 | 'total_tokens': 321, 308 | 'completion_tokens_details': { 309 | 'accepted_prediction_tokens': 0, 310 | 'audio_tokens': 0, 311 | 'reasoning_tokens': 0, 312 | 'rejected_prediction_tokens': 0, 313 | 'text_tokens': None 314 | }, 315 | 'prompt_tokens_details': { 316 | 'audio_tokens': 0, 317 | 'cached_tokens': 0, 318 | 'text_tokens': None, 319 | 'image_tokens': None 320 | } 321 | } 322 | } 323 | ``` 324 | 325 | When using DSPy's caching features (either in-memory or on-disk via litellm), cached responses won't count toward usage statistics.
For example: 326 | 327 | ```python 328 | # Enable caching 329 | dspy.settings.configure( 330 | lm=dspy.LM("openai/gpt-4o-mini", cache=True), 331 | track_usage=True 332 | ) 333 | 334 | program = MyProgram() 335 | 336 | # First call - will show usage statistics 337 | output = program(question="What is the capital of Zambia?") 338 | print(output.get_lm_usage()) # Shows token usage 339 | 340 | # Second call - same question, will use cache 341 | output = program(question="What is the capital of Zambia?") 342 | print(output.get_lm_usage()) # Shows empty dict: {} 343 | ``` 344 | ``` -------------------------------------------------------------------------------- /dspy/adapters/json_adapter.py: -------------------------------------------------------------------------------- ```python 1 | import json 2 | import logging 3 | from typing import Any, get_origin 4 | 5 | import json_repair 6 | import litellm 7 | import pydantic 8 | import regex 9 | from pydantic.fields import FieldInfo 10 | 11 | from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName 12 | from dspy.adapters.types.tool import ToolCalls 13 | from dspy.adapters.utils import ( 14 | format_field_value, 15 | get_annotation_name, 16 | parse_value, 17 | serialize_for_json, 18 | translate_field_type, 19 | ) 20 | from dspy.clients.lm import LM 21 | from dspy.signatures.signature import Signature, SignatureMeta 22 | from dspy.utils.callback import BaseCallback 23 | from dspy.utils.exceptions import AdapterParseError 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def _has_open_ended_mapping(signature: SignatureMeta) -> bool: 29 | """ 30 | Check whether any output field in the signature has an open-ended mapping type, 31 | such as dict[str, Any]. Structured Outputs require explicit properties, so such fields 32 | are incompatible. 33 | """ 34 | for field in signature.output_fields.values(): 35 | annotation = field.annotation 36 | if get_origin(annotation) is dict: 37 | return True 38 | return False 39 | 40 | 41 | class JSONAdapter(ChatAdapter): 42 | def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = True): 43 | # JSONAdapter uses native function calling by default. 44 | super().__init__(callbacks=callbacks, use_native_function_calling=use_native_function_calling) 45 | 46 | def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, call_fn): 47 | """Common call logic to be used for both sync and async calls.""" 48 | provider = lm.model.split("/", 1)[0] or "openai" 49 | params = litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider) 50 | 51 | if not params or "response_format" not in params: 52 | return call_fn(lm, lm_kwargs, signature, demos, inputs) 53 | 54 | has_tool_calls = any(field.annotation == ToolCalls for field in signature.output_fields.values()) 55 | # Some models support json mode but not structured outputs 56 | # Follows guidance from: https://docs.litellm.ai/docs/completion/json_mode#check-model-support 57 | supports_structured_outputs = litellm.supports_response_schema(model=lm.model, custom_llm_provider=provider) 58 | 59 | if _has_open_ended_mapping(signature) or (not self.use_native_function_calling and has_tool_calls) or not supports_structured_outputs: 60 | # We found that structured output mode doesn't work well with dspy.ToolCalls as output field. 61 | # So we fall back to json mode if native function calling is disabled and ToolCalls is present. 
62 | lm_kwargs["response_format"] = {"type": "json_object"} 63 | return call_fn(lm, lm_kwargs, signature, demos, inputs) 64 | 65 | def __call__( 66 | self, 67 | lm: LM, 68 | lm_kwargs: dict[str, Any], 69 | signature: type[Signature], 70 | demos: list[dict[str, Any]], 71 | inputs: dict[str, Any], 72 | ) -> list[dict[str, Any]]: 73 | result = self._json_adapter_call_common(lm, lm_kwargs, signature, demos, inputs, super().__call__) 74 | if result: 75 | return result 76 | 77 | try: 78 | structured_output_model = _get_structured_outputs_response_format( 79 | signature, self.use_native_function_calling 80 | ) 81 | lm_kwargs["response_format"] = structured_output_model 82 | return super().__call__(lm, lm_kwargs, signature, demos, inputs) 83 | except Exception: 84 | logger.warning("Failed to use structured output format, falling back to JSON mode.") 85 | lm_kwargs["response_format"] = {"type": "json_object"} 86 | return super().__call__(lm, lm_kwargs, signature, demos, inputs) 87 | 88 | async def acall( 89 | self, 90 | lm: LM, 91 | lm_kwargs: dict[str, Any], 92 | signature: type[Signature], 93 | demos: list[dict[str, Any]], 94 | inputs: dict[str, Any], 95 | ) -> list[dict[str, Any]]: 96 | result = self._json_adapter_call_common(lm, lm_kwargs, signature, demos, inputs, super().acall) 97 | if result: 98 | return await result 99 | 100 | try: 101 | structured_output_model = _get_structured_outputs_response_format(signature) 102 | lm_kwargs["response_format"] = structured_output_model 103 | return await super().acall(lm, lm_kwargs, signature, demos, inputs) 104 | except Exception: 105 | logger.warning("Failed to use structured output format, falling back to JSON mode.") 106 | lm_kwargs["response_format"] = {"type": "json_object"} 107 | return await super().acall(lm, lm_kwargs, signature, demos, inputs) 108 | 109 | def format_field_structure(self, signature: type[Signature]) -> str: 110 | parts = [] 111 | parts.append("All interactions will be structured in the following way, with the appropriate values filled in.") 112 | 113 | def format_signature_fields_for_instructions(fields: dict[str, FieldInfo], role: str): 114 | return self.format_field_with_value( 115 | fields_with_values={ 116 | FieldInfoWithName(name=field_name, info=field_info): translate_field_type(field_name, field_info) 117 | for field_name, field_info in fields.items() 118 | }, 119 | role=role, 120 | ) 121 | 122 | parts.append("Inputs will have the following structure:") 123 | parts.append(format_signature_fields_for_instructions(signature.input_fields, role="user")) 124 | parts.append("Outputs will be a JSON object with the following fields.") 125 | parts.append(format_signature_fields_for_instructions(signature.output_fields, role="assistant")) 126 | return "\n\n".join(parts).strip() 127 | 128 | def user_message_output_requirements(self, signature: type[Signature]) -> str: 129 | def type_info(v): 130 | return ( 131 | f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" 132 | if v.annotation is not str 133 | else "" 134 | ) 135 | 136 | message = "Respond with a JSON object in the following order of fields: " 137 | message += ", then ".join(f"`{f}`{type_info(v)}" for f, v in signature.output_fields.items()) 138 | message += "." 
139 | return message 140 | 141 | def format_assistant_message_content( 142 | self, 143 | signature: type[Signature], 144 | outputs: dict[str, Any], 145 | missing_field_message=None, 146 | ) -> str: 147 | fields_with_values = { 148 | FieldInfoWithName(name=k, info=v): outputs.get(k, missing_field_message) 149 | for k, v in signature.output_fields.items() 150 | } 151 | return self.format_field_with_value(fields_with_values, role="assistant") 152 | 153 | def parse(self, signature: type[Signature], completion: str) -> dict[str, Any]: 154 | pattern = r"\{(?:[^{}]|(?R))*\}" 155 | match = regex.search(pattern, completion, regex.DOTALL) 156 | if match: 157 | completion = match.group(0) 158 | fields = json_repair.loads(completion) 159 | 160 | if not isinstance(fields, dict): 161 | raise AdapterParseError( 162 | adapter_name="JSONAdapter", 163 | signature=signature, 164 | lm_response=completion, 165 | message="LM response cannot be serialized to a JSON object.", 166 | ) 167 | 168 | fields = {k: v for k, v in fields.items() if k in signature.output_fields} 169 | 170 | # Attempt to cast each value to type signature.output_fields[k].annotation. 171 | for k, v in fields.items(): 172 | if k in signature.output_fields: 173 | fields[k] = parse_value(v, signature.output_fields[k].annotation) 174 | 175 | if fields.keys() != signature.output_fields.keys(): 176 | raise AdapterParseError( 177 | adapter_name="JSONAdapter", 178 | signature=signature, 179 | lm_response=completion, 180 | parsed_result=fields, 181 | ) 182 | 183 | return fields 184 | 185 | def format_field_with_value(self, fields_with_values: dict[FieldInfoWithName, Any], role: str = "user") -> str: 186 | """ 187 | Formats the values of the specified fields according to the field's DSPy type (input or output), 188 | annotation (e.g. str, int, etc.), and the type of the value itself. Joins the formatted values 189 | into a single string, which is a multiline string if there are multiple fields. 190 | 191 | Args: 192 | fields_with_values: A dictionary mapping information about a field to its corresponding value. 193 | Returns: 194 | The joined formatted values of the fields, represented as a string. 195 | """ 196 | if role == "user": 197 | output = [] 198 | for field, field_value in fields_with_values.items(): 199 | formatted_field_value = format_field_value(field_info=field.info, value=field_value) 200 | output.append(f"[[ ## {field.name} ## ]]\n{formatted_field_value}") 201 | return "\n\n".join(output).strip() 202 | else: 203 | d = fields_with_values.items() 204 | d = {k.name: v for k, v in d} 205 | return json.dumps(serialize_for_json(d), indent=2) 206 | 207 | def format_finetune_data( 208 | self, signature: type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any] 209 | ) -> dict[str, list[Any]]: 210 | # TODO: implement format_finetune_data method in JSONAdapter 211 | raise NotImplementedError 212 | 213 | 214 | def _get_structured_outputs_response_format( 215 | signature: SignatureMeta, 216 | use_native_function_calling: bool = True, 217 | ) -> type[pydantic.BaseModel]: 218 | """ 219 | Builds a Pydantic model from a DSPy signature's output_fields and ensures the generated JSON schema 220 | is compatible with OpenAI Structured Outputs (all objects have a "required" key listing every property, 221 | and additionalProperties is always false). 222 | 223 | IMPORTANT: If any field's annotation is an open-ended mapping (e.g. 
dict[str, Any]), then a structured 224 | schema cannot be generated since all properties must be explicitly declared. In that case, an exception 225 | is raised so that the caller can fall back to using a plain "json_object" response_format. 226 | """ 227 | # Although we've already performed an early check, we keep this here as a final guard. 228 | for name, field in signature.output_fields.items(): 229 | annotation = field.annotation 230 | if get_origin(annotation) is dict: 231 | raise ValueError( 232 | f"Field '{name}' has an open-ended mapping type which is not supported by Structured Outputs." 233 | ) 234 | 235 | fields = {} 236 | for name, field in signature.output_fields.items(): 237 | annotation = field.annotation 238 | if use_native_function_calling and annotation == ToolCalls: 239 | # Skip ToolCalls field if native function calling is enabled. 240 | continue 241 | default = field.default if hasattr(field, "default") else ... 242 | fields[name] = (annotation, default) 243 | 244 | # Build the model with extra fields forbidden. 245 | pydantic_model = pydantic.create_model( 246 | "DSPyProgramOutputs", 247 | __config__=pydantic.ConfigDict(extra="forbid"), 248 | **fields, 249 | ) 250 | 251 | # Generate the initial schema. 252 | schema = pydantic_model.model_json_schema() 253 | 254 | # Remove any DSPy-specific metadata. 255 | for prop in schema.get("properties", {}).values(): 256 | prop.pop("json_schema_extra", None) 257 | 258 | def enforce_required(schema_part: dict): 259 | """ 260 | Recursively ensure that: 261 | - for any object schema, a "required" key is added with all property names (or [] if no properties) 262 | - additionalProperties is set to False regardless of the previous value. 263 | - the same enforcement is run for nested arrays and definitions. 264 | """ 265 | if schema_part.get("type") == "object": 266 | props = schema_part.get("properties") 267 | if props is not None: 268 | # For objects with explicitly declared properties: 269 | schema_part["required"] = list(props.keys()) 270 | schema_part["additionalProperties"] = False 271 | for sub_schema in props.values(): 272 | if isinstance(sub_schema, dict): 273 | enforce_required(sub_schema) 274 | else: 275 | # For objects with no properties (should not happen normally but a fallback). 276 | schema_part["properties"] = {} 277 | schema_part["required"] = [] 278 | schema_part["additionalProperties"] = False 279 | if schema_part.get("type") == "array" and isinstance(schema_part.get("items"), dict): 280 | enforce_required(schema_part["items"]) 281 | # Also enforce in any nested definitions / $defs. 282 | for key in ("$defs", "definitions"): 283 | if key in schema_part: 284 | for def_schema in schema_part[key].values(): 285 | enforce_required(def_schema) 286 | 287 | enforce_required(schema) 288 | 289 | # Override the model's JSON schema generation to return our precomputed schema. 
290 | pydantic_model.model_json_schema = lambda *args, **kwargs: schema 291 | 292 | return pydantic_model 293 | ``` -------------------------------------------------------------------------------- /tests/teleprompt/test_gepa_instruction_proposer.py: -------------------------------------------------------------------------------- ```python 1 | from dataclasses import dataclass 2 | from typing import Any 3 | 4 | import dspy 5 | from dspy.teleprompt.gepa import instruction_proposal 6 | from dspy.utils.dummies import DummyLM 7 | 8 | 9 | def count_messages_with_image_url_pattern(messages): 10 | """Helper to count image URLs in messages - borrowed from image adapter tests""" 11 | pattern = {"type": "image_url", "image_url": {"url": lambda x: isinstance(x, str)}} 12 | 13 | try: 14 | 15 | def check_pattern(obj, pattern): 16 | if isinstance(pattern, dict): 17 | if not isinstance(obj, dict): 18 | return False 19 | return all(k in obj and check_pattern(obj[k], v) for k, v in pattern.items()) 20 | if callable(pattern): 21 | return pattern(obj) 22 | return obj == pattern 23 | 24 | def count_patterns(obj, pattern): 25 | count = 0 26 | if check_pattern(obj, pattern): 27 | count += 1 28 | if isinstance(obj, dict): 29 | count += sum(count_patterns(v, pattern) for v in obj.values()) 30 | if isinstance(obj, (list, tuple)): 31 | count += sum(count_patterns(v, pattern) for v in obj) 32 | return count 33 | 34 | return count_patterns(messages, pattern) 35 | except Exception: 36 | return 0 37 | 38 | 39 | @dataclass 40 | class ImagesInHistory: 41 | has_structured_images: bool 42 | has_text_serialized_images: bool 43 | 44 | 45 | def check_images_in_history(history: list[Any]) -> ImagesInHistory: 46 | def check_text_serialized(item: Any) -> bool: 47 | if isinstance(item, list): 48 | return any(check_text_serialized(i) for i in item) 49 | if isinstance(item, dict): 50 | return any(check_text_serialized(i) for i in item.values()) 51 | if isinstance(item, str): 52 | return "CUSTOM-TYPE-START-IDENTIFIER" in item 53 | 54 | return False 55 | 56 | has_structured_images = False 57 | 58 | for call in history: 59 | if call.get("messages"): 60 | image_count = count_messages_with_image_url_pattern(call["messages"]) 61 | if image_count > 0: 62 | has_structured_images = True 63 | 64 | break 65 | 66 | return ImagesInHistory( 67 | has_structured_images=has_structured_images, 68 | has_text_serialized_images=any(check_text_serialized(i) for i in history), 69 | ) 70 | 71 | 72 | def test_reflection_lm_gets_structured_images(): 73 | """ 74 | Verify reflection LM receives structured image messages, not serialized text. 
75 | """ 76 | student = dspy.Predict("image: dspy.Image -> label: str") 77 | image = dspy.Image("https://example.com/test.jpg") 78 | example = dspy.Example(image=image, label="dog").with_inputs("image") 79 | 80 | reflection_lm = DummyLM( 81 | [ 82 | {"improved_instruction": "Better instruction"}, 83 | {"improved_instruction": "Enhanced visual analysis instruction"}, 84 | {"improved_instruction": "Focus on key features"}, 85 | {"improved_instruction": "Analyze visual patterns systematically"}, 86 | {"improved_instruction": "Consider distinctive visual elements"}, 87 | {"improved_instruction": "Enhance recognition accuracy"}, 88 | {"improved_instruction": "Improve classification methodology"}, 89 | ] 90 | ) 91 | lm = DummyLM( 92 | [ 93 | {"label": "cat"}, 94 | {"label": "dog"}, 95 | {"label": "animal"}, 96 | {"label": "pet"}, 97 | {"label": "feline"}, 98 | {"label": "canine"}, 99 | {"label": "mammal"}, 100 | {"label": "creature"}, 101 | {"label": "species"}, 102 | {"label": "domestic"}, 103 | {"label": "wild"}, 104 | {"label": "carnivore"}, 105 | {"label": "herbivore"}, 106 | {"label": "quadruped"}, 107 | {"label": "vertebrate"}, 108 | ] 109 | ) 110 | dspy.settings.configure(lm=lm) 111 | 112 | gepa = dspy.GEPA( 113 | metric=lambda gold, pred, trace=None, pred_name=None, pred_trace=None: 0.3, 114 | max_metric_calls=2, 115 | reflection_lm=reflection_lm, 116 | instruction_proposer=instruction_proposal.MultiModalInstructionProposer(), 117 | ) 118 | 119 | gepa.compile(student, trainset=[example], valset=[example]) 120 | 121 | assert len(lm.history) > 0, "LM should have been called" 122 | assert len(reflection_lm.history) > 0, "Reflection LM should have been called" 123 | 124 | images_in_history = check_images_in_history(reflection_lm.history) 125 | 126 | assert images_in_history.has_structured_images, "Reflection LM should have received structured images" 127 | assert not images_in_history.has_text_serialized_images, "Reflection LM received serialized images in prompts" 128 | 129 | 130 | def test_custom_proposer_without_reflection_lm(): 131 | """Test that custom instruction proposers can work without reflection_lm when using updated GEPA core.""" 132 | 133 | # External reflection LM managed by the custom proposer 134 | external_reflection_lm = DummyLM( 135 | [ 136 | {"improved_instruction": "External LM response"}, 137 | {"improved_instruction": "Enhanced instruction"}, 138 | {"improved_instruction": "Better guidance"}, 139 | {"improved_instruction": "Optimized instruction"}, 140 | {"improved_instruction": "Refined approach"}, 141 | ] 142 | ) 143 | 144 | class ProposerWithExternalLM: 145 | def __call__(self, candidate, reflective_dataset, components_to_update): 146 | # This proposer manages its own external reflection LM 147 | with dspy.context(lm=external_reflection_lm): 148 | # Use external LM for reflection (optional - could be any custom logic) 149 | external_reflection_lm([{"role": "user", "content": "Improve this instruction"}]) 150 | return {name: f"Externally-improved: {candidate[name]}" for name in components_to_update} 151 | 152 | student = dspy.Predict("text -> label") 153 | example = dspy.Example(text="test input", label="test").with_inputs("text") 154 | 155 | # Use a robust dummy LM with enough responses for optimization steps 156 | lm = DummyLM( 157 | [ 158 | {"label": "test"}, 159 | {"label": "result"}, 160 | {"label": "output"}, 161 | {"label": "response"}, 162 | {"label": "classification"}, 163 | {"label": "prediction"}, 164 | {"label": "category"}, 165 | {"label": "type"}, 166 | 
{"label": "class"}, 167 | {"label": "group"}, 168 | {"label": "kind"}, 169 | {"label": "variant"}, 170 | {"label": "form"}, 171 | {"label": "style"}, 172 | {"label": "mode"}, 173 | ] 174 | ) 175 | dspy.settings.configure(lm=lm) 176 | 177 | # Test the full flexibility: no reflection_lm provided to GEPA at all! 178 | # The updated GEPA core library now allows this when using custom proposers 179 | gepa = dspy.GEPA( 180 | metric=lambda gold, pred, trace=None, pred_name=None, pred_trace=None: 0.7, # Score to trigger optimization 181 | max_metric_calls=5, # More calls to allow proper optimization 182 | reflection_lm=None, # No reflection_lm provided - this now works! 183 | instruction_proposer=ProposerWithExternalLM(), 184 | ) 185 | 186 | result = gepa.compile(student, trainset=[example], valset=[example]) 187 | 188 | assert result is not None 189 | assert len(lm.history) > 0, "Main LM should have been called" 190 | assert len(external_reflection_lm.history) > 0, "External reflection LM should have been called by custom proposer" 191 | 192 | 193 | def test_image_serialization_into_strings(): 194 | """ 195 | Test that demonstrates the image serialization problem when calling lm directly with serialized image data. 196 | """ 197 | 198 | class InstructionProposerCallingLMDirectly: 199 | def __call__( 200 | self, 201 | candidate: dict[str, str], 202 | reflective_dataset: dict[str, list[dict[str, Any]]], 203 | components_to_update: list[str], 204 | ) -> dict[str, str]: 205 | updated_components = {} 206 | 207 | for component_name in components_to_update: 208 | if component_name not in candidate or component_name not in reflective_dataset: 209 | continue 210 | 211 | current_instruction = candidate[component_name] 212 | component_data = reflective_dataset[component_name] 213 | 214 | feedback_analysis = "Feedback analysis:\n" 215 | for i, example in enumerate(component_data): 216 | feedback_analysis += f"Example {i + 1}:\n" 217 | 218 | # Non ideal approach: extract and serialize image objects directly 219 | inputs = example.get("Inputs", {}) 220 | for key, value in inputs.items(): 221 | feedback_analysis += f" {key}: {value}\n" 222 | 223 | outputs = example.get("Generated Outputs", {}) 224 | feedback = example.get("Feedback", "") 225 | feedback_analysis += f" Outputs: {outputs}\n" 226 | feedback_analysis += f" Feedback: {feedback}\n\n" 227 | 228 | context_lm = dspy.settings.lm 229 | messages = [ 230 | {"role": "system", "content": "You are an instruction improvement assistant."}, 231 | { 232 | "role": "user", 233 | "content": f"Current instruction: {current_instruction}\n\nFeedback: {feedback_analysis}\n\nProvide an improved instruction:", 234 | }, 235 | ] 236 | 237 | result = context_lm(messages=messages) 238 | updated_components[component_name] = result[0] 239 | 240 | return updated_components 241 | 242 | direct_lm_call_proposer = InstructionProposerCallingLMDirectly() 243 | 244 | student = dspy.Predict("image -> label") 245 | 246 | image = dspy.Image("https://picsum.photos/id/237/200/300") 247 | 248 | examples = [ 249 | dspy.Example(image=image, label="cat").with_inputs("image"), 250 | dspy.Example(image=image, label="animal").with_inputs("image"), 251 | ] 252 | 253 | lm = DummyLM( 254 | [ 255 | {"label": "cat"}, 256 | {"label": "dog"}, 257 | {"label": "animal"}, 258 | {"label": "pet"}, 259 | {"label": "feline"}, 260 | {"label": "mammal"}, 261 | {"label": "creature"}, 262 | {"label": "species"}, 263 | {"label": "domestic"}, 264 | {"label": "wild"}, 265 | {"label": "carnivore"}, 266 | {"label": 
"herbivore"}, 267 | ] 268 | ) 269 | dspy.settings.configure(lm=lm) 270 | 271 | reflection_lm = DummyLM( 272 | [ 273 | {"improved_instruction": "Be more specific about image analysis"}, 274 | {"improved_instruction": "Focus on visual features when classifying"}, 275 | {"improved_instruction": "Consider contextual clues in the image"}, 276 | {"improved_instruction": "Analyze shape, color, and texture patterns"}, 277 | {"improved_instruction": "Look for distinguishing characteristics"}, 278 | ] 279 | ) 280 | 281 | gepa = dspy.GEPA( 282 | metric=lambda gold, pred, trace=None, pred_name=None, pred_trace=None: 0.3, 283 | max_metric_calls=5, 284 | reflection_lm=reflection_lm, 285 | instruction_proposer=direct_lm_call_proposer, 286 | ) 287 | 288 | gepa.compile(student, trainset=examples, valset=examples) 289 | 290 | assert len(lm.history) > 0, "LM should have been called" 291 | assert len(reflection_lm.history) > 0, "Reflection LM should have been called" 292 | 293 | images_in_history = check_images_in_history(reflection_lm.history) 294 | 295 | assert images_in_history.has_text_serialized_images, ( 296 | "Expected to find serialized images (CUSTOM-TYPE-START-IDENTIFIER)" 297 | ) 298 | 299 | 300 | def test_default_proposer(): 301 | student = dspy.Predict("image -> label") 302 | 303 | image = dspy.Image("https://picsum.photos/id/237/200/300") 304 | 305 | examples = [ 306 | dspy.Example(image=image, label="cat").with_inputs("image"), 307 | dspy.Example(image=image, label="animal").with_inputs("image"), 308 | ] 309 | 310 | lm = DummyLM( 311 | [ 312 | {"label": "cat"}, 313 | {"label": "dog"}, 314 | {"label": "animal"}, 315 | {"label": "pet"}, 316 | {"label": "feline"}, 317 | {"label": "mammal"}, 318 | {"label": "creature"}, 319 | {"label": "species"}, 320 | {"label": "domestic"}, 321 | {"label": "wild"}, 322 | {"label": "carnivore"}, 323 | {"label": "herbivore"}, 324 | ] 325 | ) 326 | dspy.settings.configure(lm=lm) 327 | 328 | reflection_lm = DummyLM( 329 | [ 330 | {"improved_instruction": "Be more specific about image analysis"}, 331 | {"improved_instruction": "Focus on visual features when classifying"}, 332 | {"improved_instruction": "Consider contextual clues in the image"}, 333 | {"improved_instruction": "Analyze shape, color, and texture patterns"}, 334 | {"improved_instruction": "Look for distinguishing characteristics"}, 335 | ] 336 | ) 337 | 338 | gepa = dspy.GEPA( 339 | metric=lambda gold, pred, trace=None, pred_name=None, pred_trace=None: 0.3, 340 | max_metric_calls=5, 341 | reflection_lm=reflection_lm, 342 | ) 343 | 344 | gepa.compile(student, trainset=examples, valset=examples) 345 | 346 | assert len(lm.history) > 0, "LM should have been called" 347 | assert len(reflection_lm.history) > 0, "Reflection LM should have been called" 348 | 349 | images_in_history = check_images_in_history(reflection_lm.history) 350 | 351 | assert images_in_history.has_text_serialized_images, ( 352 | "Expected to find serialized images (CUSTOM-TYPE-START-IDENTIFIER)" 353 | ) 354 | ``` -------------------------------------------------------------------------------- /dspy/teleprompt/bootstrap_finetune.py: -------------------------------------------------------------------------------- ```python 1 | import logging 2 | from collections import defaultdict 3 | from typing import Any, Callable 4 | 5 | import dspy 6 | from dspy.adapters.base import Adapter 7 | from dspy.adapters.chat_adapter import ChatAdapter 8 | from dspy.clients.lm import LM 9 | from dspy.clients.utils_finetune import infer_data_format 10 | from 
dspy.dsp.utils.settings import settings
11 | from dspy.predict.predict import Predict
12 | from dspy.primitives.example import Example
13 | from dspy.primitives.module import Module
14 | from dspy.teleprompt.bootstrap_trace import bootstrap_trace_data
15 | from dspy.teleprompt.teleprompt import Teleprompter
16 | 
17 | logger = logging.getLogger(__name__)
18 | 
19 | 
20 | class FinetuneTeleprompter(Teleprompter):
21 |     def __init__(
22 |         self,
23 |         train_kwargs: dict[str, Any] | dict[LM, dict[str, Any]] | None = None,
24 |     ):
25 |         self.train_kwargs: dict[LM, Any] = self.convert_to_lm_dict(train_kwargs or {})
26 | 
27 |     @staticmethod
28 |     def convert_to_lm_dict(arg) -> dict[LM, Any]:
29 |         non_empty_dict = arg and isinstance(arg, dict)
30 |         if non_empty_dict and all(isinstance(k, LM) for k in arg.keys()):
31 |             return arg
32 |         # Default to using the same value for all LMs
33 |         return defaultdict(lambda: arg)
34 | 
35 | 
36 | class BootstrapFinetune(FinetuneTeleprompter):
37 |     def __init__(
38 |         self,
39 |         metric: Callable | None = None,
40 |         multitask: bool = True,
41 |         train_kwargs: dict[str, Any] | dict[LM, dict[str, Any]] | None = None,
42 |         adapter: Adapter | dict[LM, Adapter] | None = None,
43 |         exclude_demos: bool = False,
44 |         num_threads: int | None = None,
45 |     ):
46 |         # TODO(feature): Inputs train_kwargs (a dict with string keys) and
47 |         # adapter (Adapter) can depend on the LM they are used with. We are
48 |         # taking these as parameters for the time being. However, they can be
49 |         # attached to LMs themselves -- an LM could know which adapter it should
50 |         # be used with along with the train_kwargs. This would leave the train
51 |         # dataset as the only required argument for LM.finetune().
52 | 
53 |         super().__init__(train_kwargs=train_kwargs)
54 |         self.metric = metric
55 |         self.multitask = multitask
56 |         self.adapter: dict[LM, Adapter] = self.convert_to_lm_dict(adapter)
57 |         self.exclude_demos = exclude_demos
58 |         self.num_threads = num_threads
59 | 
60 |     def compile(
61 |         self, student: Module, trainset: list[Example], teacher: Module | list[Module] | None = None
62 |     ) -> Module:
63 |         # TODO: Print statements can be converted to logger.info if we ensure
64 |         # that the default DSPy logger logs info level messages in notebook
65 |         # environments.
66 |         logger.info("Preparing the student and teacher programs...")
67 |         all_predictors_have_lms(student)
68 | 
69 |         logger.info("Bootstrapping data...")
70 |         trace_data = []
71 | 
72 |         teachers = teacher if isinstance(teacher, list) else [teacher]
73 |         teachers = [prepare_teacher(student, t) for t in teachers]
74 |         num_threads = self.num_threads or dspy.settings.num_threads
75 |         for t in teachers:
76 |             trace_data += bootstrap_trace_data(program=t, dataset=trainset, metric=self.metric, num_threads=num_threads)
77 | 
78 |         logger.info("Preparing the train data...")
79 |         key_to_data = {}
80 |         for pred_ind, pred in enumerate(student.predictors()):
81 |             data_pred_ind = None if self.multitask else pred_ind
82 |             if pred.lm is None:
83 |                 raise ValueError(
84 |                     f"Predictor {pred_ind} does not have an LM assigned. "
85 |                     f"Please ensure the module's predictors have their LM set before fine-tuning. 
" 86 | f"You can set it using: your_module.set_lm(your_lm)" 87 | ) 88 | training_key = (pred.lm, data_pred_ind) 89 | 90 | if training_key not in key_to_data: 91 | train_data, data_format = self._prepare_finetune_data( 92 | trace_data=trace_data, lm=pred.lm, pred_ind=data_pred_ind 93 | ) 94 | logger.info(f"Using {len(train_data)} data points for fine-tuning the model: {pred.lm.model}") 95 | finetune_kwargs = { 96 | "lm": pred.lm, 97 | "train_data": train_data, 98 | "train_data_format": data_format, 99 | "train_kwargs": self.train_kwargs[pred.lm], 100 | } 101 | key_to_data[training_key] = finetune_kwargs 102 | 103 | logger.info("Starting LM fine-tuning...") 104 | # TODO(feature): We could run batches of fine-tuning jobs in sequence 105 | # to avoid exceeding the number of threads. 106 | if len(key_to_data) > num_threads: 107 | raise ValueError( 108 | "BootstrapFinetune requires `num_threads` to be bigger than or equal to the number of fine-tuning " 109 | f"jobs. There are {len(key_to_data)} fine-tuning jobs to start, but the number of threads is: " 110 | f"{num_threads}! If the `multitask` flag is set to False, the number of fine-tuning jobs will " 111 | "be equal to the number of predictors in the student program. If the `multitask` flag is set to True, " 112 | "the number of fine-tuning jobs will be equal to: 1 if there is only a context LM, or the number of " 113 | "unique LMs attached to the predictors in the student program. In any case, the number of fine-tuning " 114 | "jobs will be less than or equal to the number of predictors." 115 | ) 116 | logger.info(f"{len(key_to_data)} fine-tuning job(s) to start") 117 | key_to_lm = self.finetune_lms(key_to_data) 118 | 119 | logger.info("Updating the student program with the fine-tuned LMs...") 120 | for pred_ind, pred in enumerate(student.predictors()): 121 | data_pred_ind = None if self.multitask else pred_ind 122 | training_key = (pred.lm, data_pred_ind) 123 | finetuned_lm = key_to_lm[training_key] 124 | if isinstance(finetuned_lm, Exception): 125 | raise RuntimeError(f"Finetuned LM for predictor {pred_ind} failed.") from finetuned_lm 126 | pred.lm = finetuned_lm 127 | # TODO: What should the correct behavior be here? Should 128 | # BootstrapFinetune modify the prompt demos according to the 129 | # train data? 130 | pred.demos = [] if self.exclude_demos else pred.demos 131 | 132 | logger.info("BootstrapFinetune has finished compiling the student program") 133 | student._compiled = True 134 | return student 135 | 136 | @staticmethod 137 | def finetune_lms(finetune_dict) -> dict[Any, LM]: 138 | num_jobs = len(finetune_dict) 139 | logger.info(f"Starting {num_jobs} fine-tuning job(s)...") 140 | # TODO(nit) Pass an identifier to the job so that we can tell the logs 141 | # coming from different fine-tune threads. 142 | 143 | key_to_job = {} 144 | for key, finetune_kwargs in finetune_dict.items(): 145 | lm: LM = finetune_kwargs.pop("lm") 146 | # TODO: The following line is a hack. We should re-think how to free 147 | # up resources for fine-tuning. This might mean introducing a new 148 | # provider method (e.g. prepare_for_finetune) that can be called 149 | # before fine-tuning is started. 150 | logger.info( 151 | "Calling lm.kill() on the LM to be fine-tuned to free up resources. This won't have any effect if the " 152 | "LM is not running." 
153 |             )
154 |             lm.kill()
155 |             key_to_job[key] = lm.finetune(**finetune_kwargs)
156 | 
157 |         key_to_lm = {}
158 |         for ind, (key, job) in enumerate(key_to_job.items()):
159 |             result = job.result()
160 |             if isinstance(result, Exception):
161 |                 raise result
162 |             key_to_lm[key] = result
163 |             job.thread.join()
164 |             logger.info(f"Job {ind + 1}/{num_jobs} is done")
165 | 
166 |         return key_to_lm
167 | 
168 |     def _prepare_finetune_data(self, trace_data: list[dict[str, Any]], lm: LM, pred_ind: int | None = None):
169 |         # TODO(nit) Log dataset details/size; make logs nicer
170 |         if self.metric:
171 |             logger.info(f"Collected data for {len(trace_data)} examples")
172 |             trace_data = [d for d in trace_data if d["score"]]
173 |             logger.info(f"After filtering with the metric, {len(trace_data)} examples remain")
174 | 
175 |         data = []
176 |         adapter = self.adapter[lm] or settings.adapter or ChatAdapter()
177 |         data_format = infer_data_format(adapter)
178 |         for item in trace_data:
179 |             for step_ind, _ in enumerate(item["trace"]):
180 |                 include_data = pred_ind is None or step_ind == pred_ind  # all steps when multitask, else only this predictor's
181 |                 if include_data:
182 |                     call_data = build_call_data_from_trace(
183 |                         trace=item["trace"], pred_ind=step_ind, adapter=adapter, exclude_demos=self.exclude_demos
184 |                     )
185 |                     data.append(call_data)
186 | 
187 |         import random
188 | 
189 |         random.Random(0).shuffle(data)
190 | 
191 |         return data, data_format
192 | 
193 | 
194 | # Note: Shared below are useful functions for preparing student/teacher programs.
195 | # Similar methods are implemented separately and used by other DSPy
196 | # teleprompters. These can be moved to shared locations.
197 | def build_call_data_from_trace(
198 |     trace: list[dict],
199 |     pred_ind: int,
200 |     adapter: Adapter,
201 |     exclude_demos: bool = False,
202 | ) -> dict[str, list[dict[str, Any]]]:
203 |     # Find data that's relevant to the predictor
204 |     pred, inputs, outputs = trace[pred_ind]  # assuming that the order is kept
205 | 
206 |     demos = [] if exclude_demos else pred.demos
207 |     call_data = adapter.format_finetune_data(
208 |         signature=pred.signature,
209 |         demos=demos,
210 |         inputs=inputs,
211 |         outputs=outputs,
212 |     )
213 |     return call_data
214 | 
215 | 
216 | # # TODO(PR) check with team
217 | # def bootstrap_trace_data_one_example(
218 | #     example: Example,
219 | #     program: Program,
220 | #     metric: Optional[Callable] = None
221 | # ) -> dict[str, Any]:
222 | #     # Return a dict with the following keys:
223 | #     # example, prediction, trace, and score (if metric != None)
224 | #     with dspy.context(trace=[]):
225 | #         prediction = program(**example.inputs())
226 | #         trace = dspy.settings.trace
227 | #         score = metric(example, prediction, trace) if metric else None
228 | 
229 | #     data_dict = dict(
230 | #         example=example,
231 | #         prediction=prediction,
232 | #         trace=trace,
233 | #     )
234 | #     if metric:
235 | #         data_dict["score"] = score
236 | 
237 | #     return data_dict
238 | 
239 | 
240 | # Note: Shared below are useful functions for preparing student/teacher programs.
241 | # Similar methods are implemented separately and used by other DSPy
242 | # teleprompters. These can be moved to shared locations.
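# A quick orientation for the helper functions that follow. This is an illustrative
# outline only -- `my_program` and `my_teacher` are placeholder names, and only some
# of these helpers are invoked by BootstrapFinetune.compile above:
#
#     student = prepare_student(my_program)           # rejects already-compiled programs
#     teacher = prepare_teacher(student, my_teacher)  # checks structural equivalency and
#                                                     # that no predictors are shared
#     launch_lms(teacher)                             # launch each unique predictor LM
#     ...bootstrap traces and fine-tune...
#     kill_lms(teacher)                               # release the LMs afterwards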
243 | def all_predictors_have_lms(program: Module) -> bool:
244 |     """Return True if all predictors in the program have an LM set."""
245 |     return all(pred.lm for pred in program.predictors())
246 | 
247 | 
248 | def copy_program_with_lms(program: Module) -> Module:
249 |     pred_lms = [pred.lm for pred in program.predictors()]
250 |     program = program.deepcopy()
251 |     for ind, pred in enumerate(program.predictors()):
252 |         pred.lm = pred_lms[ind]
253 |     return program
254 | 
255 | 
256 | def prepare_student(student: Module) -> Module:
257 |     if getattr(student, "_compiled", False):
258 |         raise ValueError("The student program should not be compiled.")
259 | 
260 |     # TODO: Should we use reset_copy here? How would it affect the student
261 |     # program's predictor LMs, if they are set?
262 | 
263 |     # TODO: Should there be a deepcopy here?
264 |     # student = student.deepcopy()
265 |     return student
266 | 
267 | 
268 | def prepare_teacher(student: Module, teacher: Module | None = None) -> Module:
269 |     if teacher is None:
270 |         return student
271 | 
272 |     # Ensuring that the student and teacher are structurally equivalent
273 |     assert_structural_equivalency(student, teacher)
274 | 
275 |     # Ensuring that the student and teacher programs do not share predictors
276 |     assert_no_shared_predictor(student, teacher)
277 | 
278 |     return teacher
279 | 
280 | 
281 | def assert_structural_equivalency(program1: object, program2: object):
282 |     assert isinstance(program1, Module)
283 |     assert isinstance(program2, Module)
284 | 
285 |     num1 = len(program1.predictors())
286 |     num2 = len(program2.predictors())
287 |     err = f"Structurally equivalent programs must have the same number of predictors. The numbers of predictors for the two modules do not match: {num1} != {num2}"
288 |     assert num1 == num2, err
289 | 
290 |     pzip = zip(program1.named_predictors(), program2.named_predictors(), strict=False)
291 |     for ind, ((name1, pred1), (name2, pred2)) in enumerate(pzip):
292 |         err = f"Program predictor names must match at corresponding indices for structural equivalency. 
The predictor names for the programs do not match at index {ind}: '{name1}' != '{name2}'" 293 | assert name1 == name2, err 294 | assert isinstance(pred1, Predict) 295 | assert isinstance(pred2, Predict) 296 | 297 | 298 | def assert_no_shared_predictor(program1: Module, program2: Module): 299 | id_to_name1 = {id(p): n for n, p in program1.named_predictors()} 300 | id_to_name2 = {id(p): n for n, p in program2.named_predictors()} 301 | shared_ids = set(id_to_name1.keys()) & set(id_to_name2.keys()) 302 | 303 | pred_names = ", ".join(id_to_name1[id] for id in shared_ids) 304 | err = f"The programs share the following predictor(s) with each other: {pred_names}" 305 | assert not shared_ids, err 306 | 307 | 308 | def get_unique_lms(program: Module) -> list[LM]: 309 | lms = [pred.lm for pred in program.predictors()] 310 | return list(set(lms)) 311 | 312 | 313 | def launch_lms(program: Module): 314 | lms = get_unique_lms(program) 315 | for lm in lms: 316 | lm.launch() 317 | 318 | 319 | def kill_lms(program: Module): 320 | lms = get_unique_lms(program) 321 | for lm in lms: 322 | lm.kill() 323 | ``` -------------------------------------------------------------------------------- /tests/predict/test_react.py: -------------------------------------------------------------------------------- ```python 1 | import re 2 | 3 | import litellm 4 | import pytest 5 | from pydantic import BaseModel 6 | 7 | import dspy 8 | from dspy.utils.dummies import DummyLM 9 | 10 | 11 | @pytest.mark.extra 12 | def test_tool_observation_preserves_custom_type(): 13 | pytest.importorskip("PIL.Image") 14 | from PIL import Image 15 | 16 | captured_calls = [] 17 | 18 | class SpyChatAdapter(dspy.ChatAdapter): 19 | def format_user_message_content(self, signature, inputs, *args, **kwargs): 20 | captured_calls.append((signature, dict(inputs))) 21 | return super().format_user_message_content(signature, inputs, *args, **kwargs) 22 | 23 | def make_images(): 24 | return dspy.Image("https://example.com/test.png"), dspy.Image(Image.new("RGB", (100, 100), "red")) 25 | 26 | 27 | adapter = SpyChatAdapter() 28 | lm = DummyLM( 29 | [ 30 | { 31 | "next_thought": "I should call the image tool.", 32 | "next_tool_name": "make_images", 33 | "next_tool_args": {}, 34 | }, 35 | { 36 | "next_thought": "I now have the image so I can finish.", 37 | "next_tool_name": "finish", 38 | "next_tool_args": {}, 39 | }, 40 | {"reasoning": "image ready", "answer": "done"}, 41 | ], 42 | adapter=adapter, 43 | ) 44 | dspy.settings.configure(lm=lm, adapter=adapter) 45 | 46 | react = dspy.ReAct("question -> answer", tools=[make_images]) 47 | react(question="Draw me something red") 48 | 49 | sigs_with_obs = [sig for sig, inputs in captured_calls if "observation_0" in inputs] 50 | assert sigs_with_obs, "Expected ReAct to format a trajectory containing observation_0" 51 | 52 | observation_content = lm.history[1]["messages"][1]["content"] 53 | assert sum(1 for part in observation_content if isinstance(part, dict) and part.get("type") == "image_url") == 2 54 | 55 | 56 | def test_tool_calling_with_pydantic_args(): 57 | class CalendarEvent(BaseModel): 58 | name: str 59 | date: str 60 | participants: dict[str, str] 61 | 62 | def write_invitation_letter(participant_name: str, event_info: CalendarEvent): 63 | if participant_name not in event_info.participants: 64 | return None 65 | return f"It's my honor to invite {participant_name} to event {event_info.name} on {event_info.date}" 66 | 67 | class InvitationSignature(dspy.Signature): 68 | participant_name: str = 
dspy.InputField(desc="The name of the participant to invite") 69 | event_info: CalendarEvent = dspy.InputField(desc="The information about the event") 70 | invitation_letter: str = dspy.OutputField(desc="The invitation letter to be sent to the participant") 71 | 72 | react = dspy.ReAct(InvitationSignature, tools=[write_invitation_letter]) 73 | 74 | lm = DummyLM( 75 | [ 76 | { 77 | "next_thought": "I need to write an invitation letter for Alice to the Science Fair event.", 78 | "next_tool_name": "write_invitation_letter", 79 | "next_tool_args": { 80 | "participant_name": "Alice", 81 | "event_info": { 82 | "name": "Science Fair", 83 | "date": "Friday", 84 | "participants": {"Alice": "female", "Bob": "male"}, 85 | }, 86 | }, 87 | }, 88 | { 89 | "next_thought": ( 90 | "I have successfully written the invitation letter for Alice to the Science Fair. Now " 91 | "I can finish the task." 92 | ), 93 | "next_tool_name": "finish", 94 | "next_tool_args": {}, 95 | }, 96 | { 97 | "reasoning": "This is a very rigorous reasoning process, trust me bro!", 98 | "invitation_letter": "It's my honor to invite Alice to the Science Fair event on Friday.", 99 | }, 100 | ] 101 | ) 102 | dspy.settings.configure(lm=lm) 103 | 104 | outputs = react( 105 | participant_name="Alice", 106 | event_info=CalendarEvent( 107 | name="Science Fair", 108 | date="Friday", 109 | participants={"Alice": "female", "Bob": "male"}, 110 | ), 111 | ) 112 | assert outputs.invitation_letter == "It's my honor to invite Alice to the Science Fair event on Friday." 113 | 114 | expected_trajectory = { 115 | "thought_0": "I need to write an invitation letter for Alice to the Science Fair event.", 116 | "tool_name_0": "write_invitation_letter", 117 | "tool_args_0": { 118 | "participant_name": "Alice", 119 | "event_info": { 120 | "name": "Science Fair", 121 | "date": "Friday", 122 | "participants": {"Alice": "female", "Bob": "male"}, 123 | }, 124 | }, 125 | "observation_0": "It's my honor to invite Alice to event Science Fair on Friday", 126 | "thought_1": "I have successfully written the invitation letter for Alice to the Science Fair. 
Now I can finish the task.", 127 | "tool_name_1": "finish", 128 | "tool_args_1": {}, 129 | "observation_1": "Completed.", 130 | } 131 | assert outputs.trajectory == expected_trajectory 132 | 133 | 134 | def test_tool_calling_without_typehint(): 135 | def foo(a, b): 136 | """Add two numbers.""" 137 | return a + b 138 | 139 | react = dspy.ReAct("a, b -> c:int", tools=[foo]) 140 | lm = DummyLM( 141 | [ 142 | {"next_thought": "I need to add two numbers.", "next_tool_name": "foo", "next_tool_args": {"a": 1, "b": 2}}, 143 | {"next_thought": "I have the sum, now I can finish.", "next_tool_name": "finish", "next_tool_args": {}}, 144 | {"reasoning": "I added the numbers successfully", "c": 3}, 145 | ] 146 | ) 147 | dspy.settings.configure(lm=lm) 148 | outputs = react(a=1, b=2) 149 | 150 | expected_trajectory = { 151 | "thought_0": "I need to add two numbers.", 152 | "tool_name_0": "foo", 153 | "tool_args_0": { 154 | "a": 1, 155 | "b": 2, 156 | }, 157 | "observation_0": 3, 158 | "thought_1": "I have the sum, now I can finish.", 159 | "tool_name_1": "finish", 160 | "tool_args_1": {}, 161 | "observation_1": "Completed.", 162 | } 163 | assert outputs.trajectory == expected_trajectory 164 | 165 | 166 | def test_trajectory_truncation(): 167 | # Create a simple tool for testing 168 | def echo(text: str) -> str: 169 | return f"Echoed: {text}" 170 | 171 | # Create ReAct instance with our echo tool 172 | react = dspy.ReAct("input_text -> output_text", tools=[echo]) 173 | 174 | # Mock react.react to simulate multiple tool calls 175 | call_count = 0 176 | 177 | def mock_react(**kwargs): 178 | nonlocal call_count 179 | call_count += 1 180 | 181 | if call_count < 3: 182 | # First 2 calls use the echo tool 183 | return dspy.Prediction( 184 | next_thought=f"Thought {call_count}", 185 | next_tool_name="echo", 186 | next_tool_args={"text": f"Text {call_count}"}, 187 | ) 188 | elif call_count == 3: 189 | # The 3rd call raises context window exceeded error 190 | raise litellm.ContextWindowExceededError("Context window exceeded", "dummy_model", "dummy_provider") 191 | else: 192 | # The 4th call finishes 193 | return dspy.Prediction(next_thought="Final thought", next_tool_name="finish", next_tool_args={}) 194 | 195 | react.react = mock_react 196 | react.extract = lambda **kwargs: dspy.Prediction(output_text="Final output") 197 | 198 | # Call forward and get the result 199 | result = react(input_text="test input") 200 | 201 | # Verify that older entries in the trajectory were truncated 202 | assert "thought_0" not in result.trajectory 203 | assert "thought_2" in result.trajectory 204 | assert result.output_text == "Final output" 205 | 206 | 207 | def test_error_retry(): 208 | # --- a tiny tool that always fails ------------------------------------- 209 | def foo(a, b): 210 | raise Exception("tool error") 211 | 212 | # --- program under test ------------------------------------------------- 213 | react = dspy.ReAct("a, b -> c:int", tools=[foo]) 214 | lm = DummyLM( 215 | [ 216 | { 217 | "next_thought": "I need to add two numbers.", 218 | "next_tool_name": "foo", 219 | "next_tool_args": {"a": 1, "b": 2}, 220 | }, 221 | { 222 | "next_thought": "I need to add two numbers.", 223 | "next_tool_name": "foo", 224 | "next_tool_args": {"a": 1, "b": 2}, 225 | }, 226 | # (The model *would* succeed on the 3rd turn, but max_iters=2 stops earlier.) 
227 | {"reasoning": "I added the numbers successfully", "c": 3}, 228 | ] 229 | ) 230 | dspy.settings.configure(lm=lm) 231 | 232 | outputs = react(a=1, b=2, max_iters=2) 233 | traj = outputs.trajectory 234 | 235 | # --- exact-match checks (thoughts + tool calls) ------------------------- 236 | control_expected = { 237 | "thought_0": "I need to add two numbers.", 238 | "tool_name_0": "foo", 239 | "tool_args_0": {"a": 1, "b": 2}, 240 | "thought_1": "I need to add two numbers.", 241 | "tool_name_1": "foo", 242 | "tool_args_1": {"a": 1, "b": 2}, 243 | } 244 | for k, v in control_expected.items(): 245 | assert traj[k] == v, f"{k} mismatch" 246 | 247 | # --- flexible checks for observations ---------------------------------- 248 | # We only care that each observation mentions our error string; we ignore 249 | # any extra traceback detail or differing prefixes. 250 | for i in range(2): 251 | obs = traj[f"observation_{i}"] 252 | assert re.search(r"\btool error\b", obs), f"unexpected observation_{i!r}: {obs}" 253 | 254 | 255 | @pytest.mark.asyncio 256 | async def test_async_tool_calling_with_pydantic_args(): 257 | class CalendarEvent(BaseModel): 258 | name: str 259 | date: str 260 | participants: dict[str, str] 261 | 262 | async def write_invitation_letter(participant_name: str, event_info: CalendarEvent): 263 | if participant_name not in event_info.participants: 264 | return None 265 | return f"It's my honor to invite {participant_name} to event {event_info.name} on {event_info.date}" 266 | 267 | class InvitationSignature(dspy.Signature): 268 | participant_name: str = dspy.InputField(desc="The name of the participant to invite") 269 | event_info: CalendarEvent = dspy.InputField(desc="The information about the event") 270 | invitation_letter: str = dspy.OutputField(desc="The invitation letter to be sent to the participant") 271 | 272 | react = dspy.ReAct(InvitationSignature, tools=[write_invitation_letter]) 273 | 274 | lm = DummyLM( 275 | [ 276 | { 277 | "next_thought": "I need to write an invitation letter for Alice to the Science Fair event.", 278 | "next_tool_name": "write_invitation_letter", 279 | "next_tool_args": { 280 | "participant_name": "Alice", 281 | "event_info": { 282 | "name": "Science Fair", 283 | "date": "Friday", 284 | "participants": {"Alice": "female", "Bob": "male"}, 285 | }, 286 | }, 287 | }, 288 | { 289 | "next_thought": ( 290 | "I have successfully written the invitation letter for Alice to the Science Fair. Now " 291 | "I can finish the task." 292 | ), 293 | "next_tool_name": "finish", 294 | "next_tool_args": {}, 295 | }, 296 | { 297 | "reasoning": "This is a very rigorous reasoning process, trust me bro!", 298 | "invitation_letter": "It's my honor to invite Alice to the Science Fair event on Friday.", 299 | }, 300 | ] 301 | ) 302 | with dspy.context(lm=lm): 303 | outputs = await react.acall( 304 | participant_name="Alice", 305 | event_info=CalendarEvent( 306 | name="Science Fair", 307 | date="Friday", 308 | participants={"Alice": "female", "Bob": "male"}, 309 | ), 310 | ) 311 | assert outputs.invitation_letter == "It's my honor to invite Alice to the Science Fair event on Friday." 
312 | 313 | expected_trajectory = { 314 | "thought_0": "I need to write an invitation letter for Alice to the Science Fair event.", 315 | "tool_name_0": "write_invitation_letter", 316 | "tool_args_0": { 317 | "participant_name": "Alice", 318 | "event_info": { 319 | "name": "Science Fair", 320 | "date": "Friday", 321 | "participants": {"Alice": "female", "Bob": "male"}, 322 | }, 323 | }, 324 | "observation_0": "It's my honor to invite Alice to event Science Fair on Friday", 325 | "thought_1": "I have successfully written the invitation letter for Alice to the Science Fair. Now I can finish the task.", 326 | "tool_name_1": "finish", 327 | "tool_args_1": {}, 328 | "observation_1": "Completed.", 329 | } 330 | assert outputs.trajectory == expected_trajectory 331 | 332 | 333 | @pytest.mark.asyncio 334 | async def test_async_error_retry(): 335 | # A tiny tool that always fails 336 | async def foo(a, b): 337 | raise Exception("tool error") 338 | 339 | react = dspy.ReAct("a, b -> c:int", tools=[foo]) 340 | lm = DummyLM( 341 | [ 342 | { 343 | "next_thought": "I need to add two numbers.", 344 | "next_tool_name": "foo", 345 | "next_tool_args": {"a": 1, "b": 2}, 346 | }, 347 | { 348 | "next_thought": "I need to add two numbers.", 349 | "next_tool_name": "foo", 350 | "next_tool_args": {"a": 1, "b": 2}, 351 | }, 352 | # (The model *would* succeed on the 3rd turn, but max_iters=2 stops earlier.) 353 | {"reasoning": "I added the numbers successfully", "c": 3}, 354 | ] 355 | ) 356 | with dspy.context(lm=lm): 357 | outputs = await react.acall(a=1, b=2, max_iters=2) 358 | traj = outputs.trajectory 359 | 360 | # Exact-match checks (thoughts + tool calls) 361 | control_expected = { 362 | "thought_0": "I need to add two numbers.", 363 | "tool_name_0": "foo", 364 | "tool_args_0": {"a": 1, "b": 2}, 365 | "thought_1": "I need to add two numbers.", 366 | "tool_name_1": "foo", 367 | "tool_args_1": {"a": 1, "b": 2}, 368 | } 369 | for k, v in control_expected.items(): 370 | assert traj[k] == v, f"{k} mismatch" 371 | 372 | # Flexible checks for observations 373 | # We only care that each observation mentions our error string; we ignore 374 | # any extra traceback detail or differing prefixes. 375 | for i in range(2): 376 | obs = traj[f"observation_{i}"] 377 | assert re.search(r"\btool error\b", obs), f"unexpected observation_{i!r}: {obs}" 378 | ``` -------------------------------------------------------------------------------- /dspy/utils/callback.py: -------------------------------------------------------------------------------- ```python 1 | import functools 2 | import inspect 3 | import logging 4 | import uuid 5 | from contextvars import ContextVar 6 | from typing import Any, Callable 7 | 8 | import dspy 9 | 10 | ACTIVE_CALL_ID = ContextVar("active_call_id", default=None) 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class BaseCallback: 16 | """A base class for defining callback handlers for DSPy components. 17 | 18 | To use a callback, subclass this class and implement the desired handlers. Each handler 19 | will be called at the appropriate time before/after the execution of the corresponding component. For example, if 20 | you want to print a message before and after an LM is called, implement `the on_llm_start` and `on_lm_end` handler. 21 | Users can set the callback globally using `dspy.settings.configure` or locally by passing it to the component 22 | constructor. 23 | 24 | 25 | Example 1: Set a global callback using `dspy.settings.configure`. 
26 | 27 | ``` 28 | import dspy 29 | from dspy.utils.callback import BaseCallback 30 | 31 | class LoggingCallback(BaseCallback): 32 | 33 | def on_lm_start(self, call_id, instance, inputs): 34 | print(f"LM is called with inputs: {inputs}") 35 | 36 | def on_lm_end(self, call_id, outputs, exception): 37 | print(f"LM is finished with outputs: {outputs}") 38 | 39 | dspy.settings.configure( 40 | callbacks=[LoggingCallback()] 41 | ) 42 | 43 | cot = dspy.ChainOfThought("question -> answer") 44 | cot(question="What is the meaning of life?") 45 | 46 | # > LM is called with inputs: {'question': 'What is the meaning of life?'} 47 | # > LM is finished with outputs: {'answer': '42'} 48 | ``` 49 | 50 | Example 2: Set a local callback by passing it to the component constructor. 51 | 52 | ``` 53 | lm_1 = dspy.LM("gpt-3.5-turbo", callbacks=[LoggingCallback()]) 54 | lm_1(question="What is the meaning of life?") 55 | 56 | # > LM is called with inputs: {'question': 'What is the meaning of life?'} 57 | # > LM is finished with outputs: {'answer': '42'} 58 | 59 | lm_2 = dspy.LM("gpt-3.5-turbo") 60 | lm_2(question="What is the meaning of life?") 61 | # No logging here because only `lm_1` has the callback set. 62 | ``` 63 | """ 64 | 65 | def on_module_start( 66 | self, 67 | call_id: str, 68 | instance: Any, 69 | inputs: dict[str, Any], 70 | ): 71 | """A handler triggered when forward() method of a module (subclass of dspy.Module) is called. 72 | 73 | Args: 74 | call_id: A unique identifier for the call. Can be used to connect start/end handlers. 75 | instance: The Module instance. 76 | inputs: The inputs to the module's forward() method. Each arguments is stored as 77 | a key-value pair in a dictionary. 78 | """ 79 | pass 80 | 81 | def on_module_end( 82 | self, 83 | call_id: str, 84 | outputs: Any | None, 85 | exception: Exception | None = None, 86 | ): 87 | """A handler triggered after forward() method of a module (subclass of dspy.Module) is executed. 88 | 89 | Args: 90 | call_id: A unique identifier for the call. Can be used to connect start/end handlers. 91 | outputs: The outputs of the module's forward() method. If the method is interrupted by 92 | an exception, this will be None. 93 | exception: If an exception is raised during the execution, it will be stored here. 94 | """ 95 | pass 96 | 97 | def on_lm_start( 98 | self, 99 | call_id: str, 100 | instance: Any, 101 | inputs: dict[str, Any], 102 | ): 103 | """A handler triggered when __call__ method of dspy.LM instance is called. 104 | 105 | Args: 106 | call_id: A unique identifier for the call. Can be used to connect start/end handlers. 107 | instance: The LM instance. 108 | inputs: The inputs to the LM's __call__ method. Each arguments is stored as 109 | a key-value pair in a dictionary. 110 | """ 111 | pass 112 | 113 | def on_lm_end( 114 | self, 115 | call_id: str, 116 | outputs: dict[str, Any] | None, 117 | exception: Exception | None = None, 118 | ): 119 | """A handler triggered after __call__ method of dspy.LM instance is executed. 120 | 121 | Args: 122 | call_id: A unique identifier for the call. Can be used to connect start/end handlers. 123 | outputs: The outputs of the LM's __call__ method. If the method is interrupted by 124 | an exception, this will be None. 125 | exception: If an exception is raised during the execution, it will be stored here. 
126 | """ 127 | pass 128 | 129 | def on_adapter_format_start( 130 | self, 131 | call_id: str, 132 | instance: Any, 133 | inputs: dict[str, Any], 134 | ): 135 | """A handler triggered when format() method of an adapter (subclass of dspy.Adapter) is called. 136 | 137 | Args: 138 | call_id: A unique identifier for the call. Can be used to connect start/end handlers. 139 | instance: The Adapter instance. 140 | inputs: The inputs to the Adapter's format() method. Each arguments is stored as 141 | a key-value pair in a dictionary. 142 | """ 143 | pass 144 | 145 | def on_adapter_format_end( 146 | self, 147 | call_id: str, 148 | outputs: dict[str, Any] | None, 149 | exception: Exception | None = None, 150 | ): 151 | """A handler triggered after format() method of an adapter (subclass of dspy.Adapter) is called.. 152 | 153 | Args: 154 | call_id: A unique identifier for the call. Can be used to connect start/end handlers. 155 | outputs: The outputs of the Adapter's format() method. If the method is interrupted 156 | by an exception, this will be None. 157 | exception: If an exception is raised during the execution, it will be stored here. 158 | """ 159 | pass 160 | 161 | def on_adapter_parse_start( 162 | self, 163 | call_id: str, 164 | instance: Any, 165 | inputs: dict[str, Any], 166 | ): 167 | """A handler triggered when parse() method of an adapter (subclass of dspy.Adapter) is called. 168 | 169 | Args: 170 | call_id: A unique identifier for the call. Can be used to connect start/end handlers. 171 | instance: The Adapter instance. 172 | inputs: The inputs to the Adapter's parse() method. Each arguments is stored as 173 | a key-value pair in a dictionary. 174 | """ 175 | pass 176 | 177 | def on_adapter_parse_end( 178 | self, 179 | call_id: str, 180 | outputs: dict[str, Any] | None, 181 | exception: Exception | None = None, 182 | ): 183 | """A handler triggered after parse() method of an adapter (subclass of dspy.Adapter) is called. 184 | 185 | Args: 186 | call_id: A unique identifier for the call. Can be used to connect start/end handlers. 187 | outputs: The outputs of the Adapter's parse() method. If the method is interrupted 188 | by an exception, this will be None. 189 | exception: If an exception is raised during the execution, it will be stored here. 190 | """ 191 | pass 192 | 193 | def on_tool_start( 194 | self, 195 | call_id: str, 196 | instance: Any, 197 | inputs: dict[str, Any], 198 | ): 199 | """A handler triggered when a tool is called. 200 | 201 | Args: 202 | call_id: A unique identifier for the call. Can be used to connect start/end handlers. 203 | instance: The Tool instance. 204 | inputs: The inputs to the Tool's __call__ method. Each arguments is stored as 205 | a key-value pair in a dictionary. 206 | """ 207 | pass 208 | 209 | def on_tool_end( 210 | self, 211 | call_id: str, 212 | outputs: dict[str, Any] | None, 213 | exception: Exception | None = None, 214 | ): 215 | """A handler triggered after a tool is executed. 216 | 217 | Args: 218 | call_id: A unique identifier for the call. Can be used to connect start/end handlers. 219 | outputs: The outputs of the Tool's __call__ method. If the method is interrupted by 220 | an exception, this will be None. 221 | exception: If an exception is raised during the execution, it will be stored here. 222 | """ 223 | pass 224 | 225 | def on_evaluate_start( 226 | self, 227 | call_id: str, 228 | instance: Any, 229 | inputs: dict[str, Any], 230 | ): 231 | """A handler triggered when evaluation is started. 
232 | 233 | Args: 234 | call_id: A unique identifier for the call. Can be used to connect start/end handlers. 235 | instance: The Evaluate instance. 236 | inputs: The inputs to the Evaluate's __call__ method. Each arguments is stored as 237 | a key-value pair in a dictionary. 238 | """ 239 | pass 240 | 241 | def on_evaluate_end( 242 | self, 243 | call_id: str, 244 | outputs: Any | None, 245 | exception: Exception | None = None, 246 | ): 247 | """A handler triggered after evaluation is executed. 248 | 249 | Args: 250 | call_id: A unique identifier for the call. Can be used to connect start/end handlers. 251 | outputs: The outputs of the Evaluate's __call__ method. If the method is interrupted by 252 | an exception, this will be None. 253 | exception: If an exception is raised during the execution, it will be stored here. 254 | """ 255 | pass 256 | 257 | 258 | def with_callbacks(fn): 259 | """Decorator to add callback functionality to instance methods.""" 260 | 261 | def _execute_start_callbacks(instance, fn, call_id, callbacks, args, kwargs): 262 | """Execute all start callbacks for a function call.""" 263 | inputs = inspect.getcallargs(fn, instance, *args, **kwargs) 264 | if "self" in inputs: 265 | inputs.pop("self") 266 | elif "instance" in inputs: 267 | inputs.pop("instance") 268 | for callback in callbacks: 269 | try: 270 | _get_on_start_handler(callback, instance, fn)(call_id=call_id, instance=instance, inputs=inputs) 271 | except Exception as e: 272 | logger.warning(f"Error when calling callback {callback}: {e}") 273 | 274 | def _execute_end_callbacks(instance, fn, call_id, results, exception, callbacks): 275 | """Execute all end callbacks for a function call.""" 276 | for callback in callbacks: 277 | try: 278 | _get_on_end_handler(callback, instance, fn)( 279 | call_id=call_id, 280 | outputs=results, 281 | exception=exception, 282 | ) 283 | except Exception as e: 284 | logger.warning(f"Error when applying callback {callback}'s end handler on function {fn.__name__}: {e}.") 285 | 286 | def _get_active_callbacks(instance): 287 | """Get combined global and instance-level callbacks.""" 288 | return dspy.settings.get("callbacks", []) + getattr(instance, "callbacks", []) 289 | 290 | if inspect.iscoroutinefunction(fn): 291 | 292 | @functools.wraps(fn) 293 | async def async_wrapper(instance, *args, **kwargs): 294 | callbacks = _get_active_callbacks(instance) 295 | if not callbacks: 296 | return await fn(instance, *args, **kwargs) 297 | 298 | call_id = uuid.uuid4().hex 299 | 300 | _execute_start_callbacks(instance, fn, call_id, callbacks, args, kwargs) 301 | 302 | # Active ID must be set right before the function is called, not before calling the callbacks. 
303 | parent_call_id = ACTIVE_CALL_ID.get() 304 | ACTIVE_CALL_ID.set(call_id) 305 | 306 | results = None 307 | exception = None 308 | try: 309 | results = await fn(instance, *args, **kwargs) 310 | return results 311 | except Exception as e: 312 | exception = e 313 | raise exception 314 | finally: 315 | ACTIVE_CALL_ID.set(parent_call_id) 316 | _execute_end_callbacks(instance, fn, call_id, results, exception, callbacks) 317 | 318 | return async_wrapper 319 | 320 | else: 321 | 322 | @functools.wraps(fn) 323 | def sync_wrapper(instance, *args, **kwargs): 324 | callbacks = _get_active_callbacks(instance) 325 | if not callbacks: 326 | return fn(instance, *args, **kwargs) 327 | 328 | call_id = uuid.uuid4().hex 329 | 330 | _execute_start_callbacks(instance, fn, call_id, callbacks, args, kwargs) 331 | 332 | # Active ID must be set right before the function is called, not before calling the callbacks. 333 | parent_call_id = ACTIVE_CALL_ID.get() 334 | ACTIVE_CALL_ID.set(call_id) 335 | 336 | results = None 337 | exception = None 338 | try: 339 | results = fn(instance, *args, **kwargs) 340 | return results 341 | except Exception as e: 342 | exception = e 343 | raise exception 344 | finally: 345 | ACTIVE_CALL_ID.set(parent_call_id) 346 | _execute_end_callbacks(instance, fn, call_id, results, exception, callbacks) 347 | 348 | return sync_wrapper 349 | 350 | 351 | def _get_on_start_handler(callback: BaseCallback, instance: Any, fn: Callable) -> Callable: 352 | """Selects the appropriate on_start handler of the callback based on the instance and function name.""" 353 | if isinstance(instance, dspy.LM): 354 | return callback.on_lm_start 355 | elif isinstance(instance, dspy.Evaluate): 356 | return callback.on_evaluate_start 357 | 358 | if isinstance(instance, dspy.Adapter): 359 | if fn.__name__ == "format": 360 | return callback.on_adapter_format_start 361 | elif fn.__name__ == "parse": 362 | return callback.on_adapter_parse_start 363 | else: 364 | raise ValueError(f"Unsupported adapter method for using callback: {fn.__name__}.") 365 | 366 | if isinstance(instance, dspy.Tool): 367 | return callback.on_tool_start 368 | 369 | # We treat everything else as a module. 370 | return callback.on_module_start 371 | 372 | 373 | def _get_on_end_handler(callback: BaseCallback, instance: Any, fn: Callable) -> Callable: 374 | """Selects the appropriate on_end handler of the callback based on the instance and function name.""" 375 | if isinstance(instance, (dspy.LM)): 376 | return callback.on_lm_end 377 | elif isinstance(instance, dspy.Evaluate): 378 | return callback.on_evaluate_end 379 | 380 | if isinstance(instance, (dspy.Adapter)): 381 | if fn.__name__ == "format": 382 | return callback.on_adapter_format_end 383 | elif fn.__name__ == "parse": 384 | return callback.on_adapter_parse_end 385 | else: 386 | raise ValueError(f"Unsupported adapter method for using callback: {fn.__name__}.") 387 | 388 | if isinstance(instance, dspy.Tool): 389 | return callback.on_tool_end 390 | 391 | # We treat everything else as a module. 
392 | return callback.on_module_end 393 | ``` -------------------------------------------------------------------------------- /dspy/teleprompt/gepa/gepa_utils.py: -------------------------------------------------------------------------------- ```python 1 | import logging 2 | import random 3 | from typing import Any, Callable, Protocol, TypedDict 4 | 5 | from gepa import EvaluationBatch, GEPAAdapter 6 | from gepa.core.adapter import ProposalFn 7 | 8 | import dspy 9 | from dspy.adapters.chat_adapter import ChatAdapter 10 | from dspy.adapters.types import History 11 | from dspy.adapters.types.base_type import Type 12 | from dspy.evaluate import Evaluate 13 | from dspy.primitives import Example, Prediction 14 | from dspy.teleprompt.bootstrap_trace import TraceData 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | class LoggerAdapter: 19 | def __init__(self, logger: logging.Logger): 20 | self.logger = logger 21 | 22 | def log(self, x: str): 23 | self.logger.info(x) 24 | 25 | DSPyTrace = list[tuple[Any, dict[str, Any], Prediction]] 26 | 27 | 28 | class ReflectiveExample(TypedDict): 29 | """ 30 | Structure of individual examples in the reflective dataset. 31 | 32 | Each example contains the predictor inputs, generated outputs, and feedback from evaluation. 33 | """ 34 | Inputs: dict[str, Any] # Predictor inputs (may include str, dspy.Image, etc.) 35 | Generated_Outputs: dict[str, Any] | str # Success: dict with output fields, Failure: error message string 36 | Feedback: str # Always a string - from metric function or parsing error message 37 | 38 | 39 | class ScoreWithFeedback(Prediction): 40 | score: float 41 | feedback: str 42 | 43 | class PredictorFeedbackFn(Protocol): 44 | def __call__( 45 | predictor_output: dict[str, Any], 46 | predictor_inputs: dict[str, Any], 47 | module_inputs: Example, 48 | module_outputs: Prediction, 49 | captured_trace: DSPyTrace, 50 | ) -> ScoreWithFeedback: 51 | """ 52 | This function is used to provide feedback to a specific predictor. 53 | The function is called with the following arguments: 54 | - predictor_output: The output of the predictor. 55 | - predictor_inputs: The inputs to the predictor. 56 | - module_inputs: The inputs to the whole program --- `Example`. 57 | - module_outputs: The outputs of the whole program --- `Prediction`. 58 | - captured_trace: The trace of the module's execution. 59 | # Shape of trace is: [predictor_invocation_idx -> Tuple[Predictor, PredictorInputs, Prediction]] 60 | # Each trace is a tuple of (Predictor, PredictorInputs, Prediction) 61 | 62 | The function should return a `ScoreWithFeedback` object. 63 | The feedback is a string that is used to guide the evolution of the predictor. 64 | """ 65 | ... 
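# A minimal sketch of a feedback function that satisfies the PredictorFeedbackFn
# protocol above. It is illustrative only: the field name "answer" and the
# exact-match scoring are assumptions for the example, not requirements of GEPA.
def example_answer_feedback(
    predictor_output: dict[str, Any],
    predictor_inputs: dict[str, Any],
    module_inputs: Example,
    module_outputs: Prediction,
    captured_trace: DSPyTrace,
) -> ScoreWithFeedback:
    # Compare the program's final answer against the gold answer on the example.
    expected = getattr(module_inputs, "answer", None)
    produced = getattr(module_outputs, "answer", None)
    if produced == expected:
        return ScoreWithFeedback(score=1.0, feedback="The final answer matches the gold answer.")
    feedback = (
        f"Expected {expected!r} but the program produced {produced!r}. "
        "Revise the instruction so the predictor grounds its answer in the provided inputs."
    )
    return ScoreWithFeedback(score=0.0, feedback=feedback)
# Functions like this are what DspyAdapter expects in `feedback_map`, keyed by predictor name.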
66 | 67 | class DspyAdapter(GEPAAdapter[Example, TraceData, Prediction]): 68 | def __init__( 69 | self, 70 | student_module, 71 | metric_fn: Callable, 72 | feedback_map: dict[str, Callable], 73 | failure_score=0.0, 74 | num_threads: int | None = None, 75 | add_format_failure_as_feedback: bool = False, 76 | rng: random.Random | None = None, 77 | reflection_lm=None, 78 | custom_instruction_proposer: "ProposalFn | None" = None, 79 | warn_on_score_mismatch: bool = True 80 | ): 81 | self.student = student_module 82 | self.metric_fn = metric_fn 83 | self.feedback_map = feedback_map 84 | self.failure_score = failure_score 85 | self.num_threads = num_threads 86 | self.add_format_failure_as_feedback = add_format_failure_as_feedback 87 | self.rng = rng or random.Random(0) 88 | self.reflection_lm = reflection_lm 89 | self.custom_instruction_proposer = custom_instruction_proposer 90 | self.warn_on_score_mismatch = warn_on_score_mismatch 91 | 92 | if self.custom_instruction_proposer is not None: 93 | # We are only overriding the propose_new_texts method when a custom 94 | # instruction proposer is provided. Otherwise, we use the GEPA 95 | # default propose_new_texts. 96 | 97 | def custom_propose_new_texts( 98 | candidate: dict[str, str], 99 | reflective_dataset: dict[str, list[dict[str, Any]]], 100 | components_to_update: list[str] 101 | ) -> dict[str, str]: 102 | if self.reflection_lm is not None: 103 | with dspy.context(lm=self.reflection_lm): 104 | return self.custom_instruction_proposer( 105 | candidate=candidate, 106 | reflective_dataset=reflective_dataset, 107 | components_to_update=components_to_update 108 | ) 109 | else: 110 | return self.custom_instruction_proposer( 111 | candidate=candidate, 112 | reflective_dataset=reflective_dataset, 113 | components_to_update=components_to_update 114 | ) 115 | 116 | self.propose_new_texts = custom_propose_new_texts 117 | 118 | # Cache predictor names/signatures 119 | self.named_predictors = list(self.student.named_predictors()) 120 | 121 | 122 | def build_program(self, candidate: dict[str, str]): 123 | new_prog = self.student.deepcopy() 124 | for name, pred in new_prog.named_predictors(): 125 | if name in candidate: 126 | pred.signature = pred.signature.with_instructions(candidate[name]) 127 | return new_prog 128 | 129 | def evaluate(self, batch, candidate, capture_traces=False): 130 | program = self.build_program(candidate) 131 | 132 | if capture_traces: 133 | # bootstrap_trace_data-like flow with trace capture 134 | from dspy.teleprompt import bootstrap_trace as bootstrap_trace_module 135 | 136 | eval_callback_metadata = {"disable_logging": True} 137 | trajs = bootstrap_trace_module.bootstrap_trace_data( 138 | program=program, 139 | dataset=batch, 140 | metric=self.metric_fn, 141 | num_threads=self.num_threads, 142 | raise_on_error=False, 143 | capture_failed_parses=True, 144 | failure_score=self.failure_score, 145 | format_failure_score=self.failure_score, 146 | callback_metadata=eval_callback_metadata, 147 | ) 148 | scores = [] 149 | outputs = [] 150 | for t in trajs: 151 | outputs.append(t["prediction"]) 152 | if hasattr(t["prediction"], "__class__") and t.get("score") is None: 153 | scores.append(self.failure_score) 154 | else: 155 | score = t["score"] 156 | if hasattr(score, "score"): 157 | score = score["score"] 158 | scores.append(score) 159 | return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajs) 160 | else: 161 | evaluator = Evaluate( 162 | devset=batch, 163 | metric=self.metric_fn, 164 | num_threads=self.num_threads, 165 | 
return_all_scores=True, 166 | failure_score=self.failure_score, 167 | provide_traceback=True, 168 | max_errors=len(batch) * 100 169 | ) 170 | res = evaluator(program) 171 | outputs = [r[1] for r in res.results] 172 | scores = [r[2] for r in res.results] 173 | scores = [s["score"] if hasattr(s, "score") else s for s in scores] 174 | return EvaluationBatch(outputs=outputs, scores=scores, trajectories=None) 175 | 176 | def make_reflective_dataset(self, candidate, eval_batch, components_to_update) -> dict[str, list[ReflectiveExample]]: 177 | from dspy.teleprompt.bootstrap_trace import FailedPrediction 178 | program = self.build_program(candidate) 179 | 180 | ret_d: dict[str, list[ReflectiveExample]] = {} 181 | for pred_name in components_to_update: 182 | module = None 183 | for name, m in program.named_predictors(): 184 | if name == pred_name: 185 | module = m 186 | break 187 | assert module is not None 188 | 189 | items: list[ReflectiveExample] = [] 190 | for data in eval_batch.trajectories or []: 191 | trace = data["trace"] 192 | example = data["example"] 193 | prediction = data["prediction"] 194 | module_score = data["score"] 195 | if hasattr(module_score, "score"): 196 | module_score = module_score["score"] 197 | 198 | trace_instances = [t for t in trace if t[0].signature.equals(module.signature)] 199 | if not self.add_format_failure_as_feedback: 200 | trace_instances = [t for t in trace_instances if not isinstance(t[2], FailedPrediction)] 201 | if len(trace_instances) == 0: 202 | continue 203 | 204 | selected = None 205 | for t in trace_instances: 206 | if isinstance(t[2], FailedPrediction): 207 | selected = t 208 | break 209 | 210 | if selected is None: 211 | if isinstance(prediction, FailedPrediction): 212 | continue 213 | selected = self.rng.choice(trace_instances) 214 | 215 | inputs = selected[1] 216 | outputs = selected[2] 217 | 218 | new_inputs = {} 219 | new_outputs = {} 220 | 221 | contains_history = False 222 | history_key_name = None 223 | for input_key, input_val in inputs.items(): 224 | if isinstance(input_val, History): 225 | contains_history = True 226 | assert history_key_name is None 227 | history_key_name = input_key 228 | 229 | if contains_history: 230 | s = "```json\n" 231 | for i, message in enumerate(inputs[history_key_name].messages): 232 | s += f" {i}: {message}\n" 233 | s += "```" 234 | new_inputs["Context"] = s 235 | 236 | for input_key, input_val in inputs.items(): 237 | if contains_history and input_key == history_key_name: 238 | continue 239 | 240 | if isinstance(input_val, Type) and self.custom_instruction_proposer is not None: 241 | # Keep original object - will be properly formatted when sent to reflection LM 242 | new_inputs[input_key] = input_val 243 | else: 244 | new_inputs[input_key] = str(input_val) 245 | 246 | if isinstance(outputs, FailedPrediction): 247 | s = "Couldn't parse the output as per the expected output format. The model's raw response was:\n" 248 | s += "```\n" 249 | s += outputs.completion_text + "\n" 250 | s += "```\n\n" 251 | new_outputs = s 252 | else: 253 | for output_key, output_val in outputs.items(): 254 | new_outputs[output_key] = str(output_val) 255 | 256 | d = {"Inputs": new_inputs, "Generated Outputs": new_outputs} 257 | if isinstance(outputs, FailedPrediction): 258 | adapter = ChatAdapter() 259 | structure_instruction = "" 260 | for dd in adapter.format(module.signature, [], {}): 261 | structure_instruction += dd["role"] + ": " + dd["content"] + "\n" 262 | d["Feedback"] = "Your output failed to parse. 
Follow this structure:\n" + structure_instruction 263 | # d['score'] = self.failure_score 264 | else: 265 | feedback_fn = self.feedback_map[pred_name] 266 | fb = feedback_fn( 267 | predictor_output=outputs, 268 | predictor_inputs=inputs, 269 | module_inputs=example, 270 | module_outputs=prediction, 271 | captured_trace=trace, 272 | ) 273 | d["Feedback"] = fb["feedback"] 274 | if fb["score"] != module_score: 275 | if self.warn_on_score_mismatch: 276 | logger.warning("The score returned by the metric with pred_name is different from the overall metric score. This can indicate 2 things: Either the metric is non-deterministic (e.g., LLM-as-judge, Semantic score, etc.) or the metric returned a score specific to pred_name that differs from the module level score. Currently, GEPA does not support predictor level scoring (support coming soon), and only requires a feedback text to be provided, which can be specific to the predictor or program level. GEPA will ignore the differing score returned, and instead use module level score. You can safely ignore this warning if using a semantic metric, however, if this mismatch is caused due to predictor scoring, please return module-level scores. To disable this warning, set warn_on_score_mismatch=False.") 277 | self.warn_on_score_mismatch = False 278 | fb["score"] = module_score 279 | 280 | items.append(d) 281 | 282 | if len(items) == 0: 283 | # raise Exception(f"No valid predictions found for module {module.signature}.") 284 | continue 285 | ret_d[pred_name] = items 286 | 287 | if len(ret_d) == 0: 288 | raise Exception("No valid predictions found for any module.") 289 | 290 | return ret_d 291 | 292 | # TODO: The current DSPyAdapter implementation uses the GEPA default propose_new_texts. 293 | # We can potentially override this, to use the instruction proposal similar to MIPROv2. 294 | 295 | # def propose_new_texts( 296 | # self, 297 | # candidate: Dict[str, str], 298 | # reflective_dataset: Dict[str, List[Dict[str, Any]]], 299 | # components_to_update: List[str] 300 | # ) -> Dict[str, str]: 301 | # if self.adapter.propose_new_texts is not None: 302 | # return self.adapter.propose_new_texts(candidate, reflective_dataset, components_to_update) 303 | 304 | # from .instruction_proposal import InstructionProposalSignature 305 | # new_texts: Dict[str, str] = {} 306 | # for name in components_to_update: 307 | # base_instruction = candidate[name] 308 | # dataset_with_feedback = reflective_dataset[name] 309 | # new_texts[name] = InstructionProposalSignature.run( 310 | # lm=self.reflection_lm, 311 | # input_dict={ 312 | # "current_instruction_doc": base_instruction, 313 | # "dataset_with_feedback": dataset_with_feedback 314 | # } 315 | # )['new_instruction'] 316 | # return new_texts 317 | ``` -------------------------------------------------------------------------------- /dspy/teleprompt/gepa/instruction_proposal.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Any 2 | 3 | from gepa.core.adapter import ProposalFn 4 | 5 | import dspy 6 | from dspy.adapters.types.base_type import Type 7 | from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample 8 | 9 | 10 | class GenerateEnhancedMultimodalInstructionFromFeedback(dspy.Signature): 11 | """I provided an assistant with instructions to perform a task involving visual content, but the assistant's performance needs improvement based on the examples and feedback below. 
12 | 13 | Your task is to write a better instruction for the assistant that addresses the specific issues identified in the feedback, with particular attention to how visual and textual information should be analyzed and integrated. 14 | 15 | ## Analysis Steps: 16 | 1. **Read the inputs carefully** and identify both the visual and textual input formats, understanding how they work together 17 | 2. **Read all the assistant responses and corresponding feedback** to understand what went wrong with visual analysis, text processing, or their integration 18 | 3. **Identify visual analysis patterns** - what visual features, relationships, or details are important for this task 19 | 4. **Identify domain-specific knowledge** about both visual and textual aspects, as this information may not be available to the assistant in the future 20 | 5. **Look for successful visual-textual integration strategies** and include these patterns in the instruction 21 | 6. **Address specific visual analysis issues** mentioned in the feedback 22 | 23 | ## Instruction Requirements: 24 | - **Clear task definition** explaining how to process both visual and textual inputs 25 | - **Visual analysis guidance** specific to this task (what to look for, how to describe, what features matter) 26 | - **Integration strategies** for combining visual observations with textual information 27 | - **Domain-specific knowledge** about visual concepts, terminology, or relationships 28 | - **Error prevention guidance** for common visual analysis mistakes shown in the feedback 29 | - **Precise, actionable language** for both visual and textual processing 30 | 31 | Focus on creating an instruction that helps the assistant properly analyze visual content, integrate it with textual information, and avoid the specific visual analysis mistakes shown in the examples.""" 32 | 33 | current_instruction = dspy.InputField( 34 | desc="The current instruction that was provided to the assistant to perform the multimodal task" 35 | ) 36 | examples_with_feedback = dspy.InputField( 37 | desc="Task examples with visual content showing inputs, assistant outputs, and feedback. " 38 | "Pay special attention to feedback about visual analysis accuracy, visual-textual integration, " 39 | "and any domain-specific visual knowledge that the assistant missed." 40 | ) 41 | 42 | improved_instruction = dspy.OutputField( 43 | desc="A better instruction for the assistant that addresses visual analysis issues, provides " 44 | "clear guidance on how to process and integrate visual and textual information, includes " 45 | "necessary visual domain knowledge, and prevents the visual analysis mistakes shown in the examples." 46 | ) 47 | 48 | 49 | class SingleComponentMultiModalProposer(dspy.Module): 50 | """ 51 | dspy.Module for proposing improved instructions based on feedback. 52 | """ 53 | 54 | def __init__(self): 55 | super().__init__() 56 | self.propose_instruction = dspy.Predict(GenerateEnhancedMultimodalInstructionFromFeedback) 57 | 58 | def forward(self, current_instruction: str, reflective_dataset: list[ReflectiveExample]) -> str: 59 | """ 60 | Generate an improved instruction based on current instruction and feedback examples. 
61 | 62 | Args: 63 | current_instruction: The current instruction that needs improvement 64 | reflective_dataset: List of examples with inputs, outputs, and feedback 65 | May contain dspy.Image objects in inputs 66 | 67 | Returns: 68 | str: Improved instruction text 69 | """ 70 | # Format examples with enhanced pattern recognition 71 | formatted_examples, image_map = self._format_examples_with_pattern_analysis(reflective_dataset) 72 | 73 | # Build kwargs for the prediction call 74 | predict_kwargs = { 75 | "current_instruction": current_instruction, 76 | "examples_with_feedback": formatted_examples, 77 | } 78 | 79 | # Create a rich multimodal examples_with_feedback that includes both text and images 80 | predict_kwargs["examples_with_feedback"] = self._create_multimodal_examples(formatted_examples, image_map) 81 | 82 | # Use current dspy LM settings (GEPA will pass reflection_lm via context) 83 | result = self.propose_instruction(**predict_kwargs) 84 | 85 | return result.improved_instruction 86 | 87 | def _format_examples_with_pattern_analysis( 88 | self, reflective_dataset: list[ReflectiveExample] 89 | ) -> tuple[str, dict[int, list[Type]]]: 90 | """ 91 | Format examples with pattern analysis and feedback categorization. 92 | 93 | Returns: 94 | tuple: (formatted_text_with_patterns, image_map) 95 | """ 96 | # First, use the existing proven formatting approach 97 | formatted_examples, image_map = self._format_examples_for_instruction_generation(reflective_dataset) 98 | 99 | # Enhanced analysis: categorize feedback patterns 100 | feedback_analysis = self._analyze_feedback_patterns(reflective_dataset) 101 | 102 | # Add pattern analysis to the formatted examples 103 | if feedback_analysis["summary"]: 104 | pattern_summary = self._create_pattern_summary(feedback_analysis) 105 | enhanced_examples = f"{pattern_summary}\n\n{formatted_examples}" 106 | return enhanced_examples, image_map 107 | 108 | return formatted_examples, image_map 109 | 110 | def _analyze_feedback_patterns(self, reflective_dataset: list[ReflectiveExample]) -> dict[str, Any]: 111 | """ 112 | Analyze feedback patterns to provide better context for instruction generation. 
113 | 114 | Categorizes feedback into: 115 | - Error patterns: Common mistakes and their types 116 | - Success patterns: What worked well and should be preserved/emphasized 117 | - Domain knowledge gaps: Missing information that should be included 118 | - Task-specific guidance: Specific requirements or edge cases 119 | """ 120 | analysis = { 121 | "error_patterns": [], 122 | "success_patterns": [], 123 | "domain_knowledge_gaps": [], 124 | "task_specific_guidance": [], 125 | "summary": "", 126 | } 127 | 128 | # Simple pattern recognition - could be enhanced further 129 | for example in reflective_dataset: 130 | feedback = example.get("Feedback", "").lower() 131 | 132 | # Identify error patterns 133 | if any(error_word in feedback for error_word in ["incorrect", "wrong", "error", "failed", "missing"]): 134 | analysis["error_patterns"].append(feedback) 135 | 136 | # Identify success patterns 137 | if any( 138 | success_word in feedback for success_word in ["correct", "good", "accurate", "well", "successfully"] 139 | ): 140 | analysis["success_patterns"].append(feedback) 141 | 142 | # Identify domain knowledge needs 143 | if any( 144 | knowledge_word in feedback 145 | for knowledge_word in ["should know", "domain", "specific", "context", "background"] 146 | ): 147 | analysis["domain_knowledge_gaps"].append(feedback) 148 | 149 | # Create summary if patterns were found 150 | if any(analysis[key] for key in ["error_patterns", "success_patterns", "domain_knowledge_gaps"]): 151 | analysis["summary"] = ( 152 | f"Patterns identified: {len(analysis['error_patterns'])} error(s), {len(analysis['success_patterns'])} success(es), {len(analysis['domain_knowledge_gaps'])} knowledge gap(s)" 153 | ) 154 | 155 | return analysis 156 | 157 | def _create_pattern_summary(self, feedback_analysis: dict[str, Any]) -> str: 158 | """Create a summary of feedback patterns to help guide instruction generation.""" 159 | 160 | summary_parts = ["## Feedback Pattern Analysis\n"] 161 | 162 | if feedback_analysis["error_patterns"]: 163 | summary_parts.append(f"**Common Issues Found ({len(feedback_analysis['error_patterns'])} examples):**") 164 | summary_parts.append("Focus on preventing these types of mistakes in the new instruction.\n") 165 | 166 | if feedback_analysis["success_patterns"]: 167 | summary_parts.append( 168 | f"**Successful Approaches Found ({len(feedback_analysis['success_patterns'])} examples):**" 169 | ) 170 | summary_parts.append("Build on these successful strategies in the new instruction.\n") 171 | 172 | if feedback_analysis["domain_knowledge_gaps"]: 173 | summary_parts.append( 174 | f"**Domain Knowledge Needs Identified ({len(feedback_analysis['domain_knowledge_gaps'])} examples):**" 175 | ) 176 | summary_parts.append("Include this specialized knowledge in the new instruction.\n") 177 | 178 | return "\n".join(summary_parts) 179 | 180 | def _format_examples_for_instruction_generation( 181 | self, reflective_dataset: list[ReflectiveExample] 182 | ) -> tuple[str, dict[int, list[Type]]]: 183 | """ 184 | Format examples using GEPA's markdown structure while preserving image objects. 
185 | 186 | Returns: 187 | tuple: (formatted_text, image_map) where image_map maps example_index -> list[images] 188 | """ 189 | 190 | def render_value_with_images(value, level=3, example_images=None): 191 | if example_images is None: 192 | example_images = [] 193 | 194 | if isinstance(value, Type): 195 | image_idx = len(example_images) + 1 196 | example_images.append(value) 197 | return f"[IMAGE-{image_idx} - see visual content]\n\n" 198 | elif isinstance(value, dict): 199 | s = "" 200 | for k, v in value.items(): 201 | s += f"{'#' * level} {k}\n" 202 | s += render_value_with_images(v, min(level + 1, 6), example_images) 203 | if not value: 204 | s += "\n" 205 | return s 206 | elif isinstance(value, (list, tuple)): 207 | s = "" 208 | for i, item in enumerate(value): 209 | s += f"{'#' * level} Item {i + 1}\n" 210 | s += render_value_with_images(item, min(level + 1, 6), example_images) 211 | if not value: 212 | s += "\n" 213 | return s 214 | else: 215 | return f"{str(value).strip()}\n\n" 216 | 217 | def convert_sample_to_markdown_with_images(sample, example_num): 218 | example_images = [] 219 | s = f"# Example {example_num}\n" 220 | 221 | for key, val in sample.items(): 222 | s += f"## {key}\n" 223 | s += render_value_with_images(val, level=3, example_images=example_images) 224 | 225 | return s, example_images 226 | 227 | formatted_parts = [] 228 | image_map = {} 229 | 230 | for i, example_data in enumerate(reflective_dataset): 231 | formatted_example, example_images = convert_sample_to_markdown_with_images(example_data, i + 1) 232 | formatted_parts.append(formatted_example) 233 | 234 | if example_images: 235 | image_map[i] = example_images 236 | 237 | formatted_text = "\n\n".join(formatted_parts) 238 | 239 | if image_map: 240 | total_images = sum(len(imgs) for imgs in image_map.values()) 241 | formatted_text = ( 242 | f"The examples below include visual content ({total_images} images total). " 243 | "Please analyze both the text and visual elements when suggesting improvements.\n\n" + formatted_text 244 | ) 245 | 246 | return formatted_text, image_map 247 | 248 | def _create_multimodal_examples(self, formatted_text: str, image_map: dict[int, list[Type]]) -> Any: 249 | """ 250 | Create a multimodal input that contains both text and images for the reflection LM. 251 | 252 | Args: 253 | formatted_text: The formatted text with image placeholders 254 | image_map: Dictionary mapping example_index -> list[images] for structured access 255 | """ 256 | if not image_map: 257 | return formatted_text 258 | 259 | # Collect all images from all examples 260 | all_images = [] 261 | for example_images in image_map.values(): 262 | all_images.extend(example_images) 263 | 264 | multimodal_content = [formatted_text] 265 | multimodal_content.extend(all_images) 266 | return multimodal_content 267 | 268 | 269 | class MultiModalInstructionProposer(ProposalFn): 270 | """GEPA-compatible multimodal instruction proposer. 271 | 272 | This class handles multimodal inputs (like dspy.Image) during GEPA optimization by using 273 | a single-component proposer for each component that needs to be updated. 274 | """ 275 | 276 | def __init__(self): 277 | self.single_proposer = SingleComponentMultiModalProposer() 278 | 279 | def __call__( 280 | self, 281 | candidate: dict[str, str], 282 | reflective_dataset: dict[str, list[ReflectiveExample]], 283 | components_to_update: list[str], 284 | ) -> dict[str, str]: 285 | """GEPA-compatible proposal function. 
286 | 287 | Args: 288 | candidate: Current component name -> instruction mapping 289 | reflective_dataset: Component name -> list of reflective examples 290 | components_to_update: List of component names to update 291 | 292 | Returns: 293 | dict: Component name -> new instruction mapping 294 | """ 295 | updated_components = {} 296 | 297 | for component_name in components_to_update: 298 | if component_name in candidate and component_name in reflective_dataset: 299 | current_instruction = candidate[component_name] 300 | component_reflective_data = reflective_dataset[component_name] 301 | 302 | # Call the single-instruction proposer. 303 | # 304 | # In the future, proposals could consider multiple components instructions, 305 | # instead of just the current instruction, for more holistic instruction proposals. 306 | new_instruction = self.single_proposer( 307 | current_instruction=current_instruction, reflective_dataset=component_reflective_data 308 | ) 309 | 310 | updated_components[component_name] = new_instruction 311 | 312 | return updated_components 313 | ``` -------------------------------------------------------------------------------- /docs/docs/learn/programming/7-assertions.md: -------------------------------------------------------------------------------- ```markdown 1 | # DSPy Assertions 2 | 3 | !!! warning "Assertions are deprecated and NOT supported. Please use the `dspy.Refine` module instead. (or dspy.Suggest)." 4 | 5 | The content below is deprecated, and is scheduled to be removed. 6 | 7 | ## Introduction 8 | 9 | Language models (LMs) have transformed how we interact with machine learning, offering vast capabilities in natural language understanding and generation. However, ensuring these models adhere to domain-specific constraints remains a challenge. Despite the growth of techniques like fine-tuning or “prompt engineering”, these approaches are extremely tedious and rely on heavy, manual hand-waving to guide the LMs in adhering to specific constraints. Even DSPy's modularity of programming prompting pipelines lacks mechanisms to effectively and automatically enforce these constraints. 10 | 11 | To address this, we introduce DSPy Assertions, a feature within the DSPy framework designed to automate the enforcement of computational constraints on LMs. DSPy Assertions empower developers to guide LMs towards desired outcomes with minimal manual intervention, enhancing the reliability, predictability, and correctness of LM outputs. 12 | 13 | ### dspy.Assert and dspy.Suggest API 14 | 15 | We introduce two primary constructs within DSPy Assertions: 16 | 17 | - **`dspy.Assert`**: 18 | - **Parameters**: 19 | - `constraint (bool)`: Outcome of Python-defined boolean validation check. 20 | - `msg (Optional[str])`: User-defined error message providing feedback or correction guidance. 21 | - `backtrack (Optional[module])`: Specifies target module for retry attempts upon constraint failure. The default backtracking module is the last module before the assertion. 22 | - **Behavior**: Initiates retry upon failure, dynamically adjusting the pipeline's execution. If failures persist, it halts execution and raises a `dspy.AssertionError`. 23 | 24 | - **`dspy.Suggest`**: 25 | - **Parameters**: Similar to `dspy.Assert`. 26 | - **Behavior**: Encourages self-refinement through retries without enforcing hard stops. Logs failures after maximum backtracking attempts and continues execution. 27 | 28 | - **dspy.Assert vs. 
Python Assertions**: Unlike conventional Python `assert` statements that terminate the program upon failure, `dspy.Assert` conducts a sophisticated retry mechanism, allowing the pipeline to adjust. 29 | 30 | Specifically, when a constraint is not met: 31 | 32 | - Backtracking Mechanism: An under-the-hood backtracking is initiated, offering the model a chance to self-refine and proceed, which is done through signature modification. 33 | - Dynamic Signature Modification: internally modifying your DSPy program’s Signature by adding the following fields: 34 | - Past Output: your model's past output that did not pass the validation_fn 35 | - Instruction: your user-defined feedback message on what went wrong and what possibly to fix 36 | 37 | If the error continues past the `max_backtracking_attempts`, then `dspy.Assert` will halt the pipeline execution, alerting you with an `dspy.AssertionError`. This ensures your program doesn't continue executing with “bad” LM behavior and immediately highlights sample failure outputs for user assessment. 38 | 39 | - **dspy.Suggest vs. dspy.Assert**: `dspy.Suggest` on the other hand offers a softer approach. It maintains the same retry backtracking as `dspy.Assert` but instead serves as a gentle nudger. If the model outputs cannot pass the model constraints after the `max_backtracking_attempts`, `dspy.Suggest` will log the persistent failure and continue execution of the program on the rest of the data. This ensures the LM pipeline works in a "best-effort" manner without halting execution. 40 | 41 | - **`dspy.Suggest`** statements are best utilized as "helpers" during the evaluation phase, offering guidance and potential corrections without halting the pipeline. 42 | - **`dspy.Assert`** statements are recommended during the development stage as "checkers" to ensure the LM behaves as expected, providing a robust mechanism for identifying and addressing errors early in the development cycle. 43 | 44 | 45 | ## Use Case: Including Assertions in DSPy Programs 46 | 47 | We start with using an example of a multi-hop QA SimplifiedBaleen pipeline as defined in the intro walkthrough. 48 | 49 | ```python 50 | class SimplifiedBaleen(dspy.Module): 51 | def __init__(self, passages_per_hop=2, max_hops=2): 52 | super().__init__() 53 | 54 | self.generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)] 55 | self.retrieve = dspy.Retrieve(k=passages_per_hop) 56 | self.generate_answer = dspy.ChainOfThought(GenerateAnswer) 57 | self.max_hops = max_hops 58 | 59 | def forward(self, question): 60 | context = [] 61 | prev_queries = [question] 62 | 63 | for hop in range(self.max_hops): 64 | query = self.generate_query[hop](context=context, question=question).query 65 | prev_queries.append(query) 66 | passages = self.retrieve(query).passages 67 | context = deduplicate(context + passages) 68 | 69 | pred = self.generate_answer(context=context, question=question) 70 | pred = dspy.Prediction(context=context, answer=pred.answer) 71 | return pred 72 | 73 | baleen = SimplifiedBaleen() 74 | 75 | baleen(question = "Which award did Gary Zukav's first book receive?") 76 | ``` 77 | 78 | To include DSPy Assertions, we simply define our validation functions and declare our assertions following the respective model generation. 79 | 80 | For this use case, suppose we want to impose the following constraints: 81 | 1. Length - each query should be less than 100 characters 82 | 2. Uniqueness - each generated query should differ from previously-generated queries. 
83 | 84 | We can define these validation checks as boolean functions: 85 | 86 | ```python 87 | #simplistic boolean check for query length 88 | len(query) <= 100 89 | 90 | #Python function for validating distinct queries 91 | def validate_query_distinction_local(previous_queries, query): 92 | """check if query is distinct from previous queries""" 93 | if previous_queries == []: 94 | return True 95 | if dspy.evaluate.answer_exact_match_str(query, previous_queries, frac=0.8): 96 | return False 97 | return True 98 | ``` 99 | 100 | We can declare these validation checks through `dspy.Suggest` statements (as we want to test the program in a best-effort demonstration). We want to keep these after the query generation `query = self.generate_query[hop](context=context, question=question).query`. 101 | 102 | ```python 103 | dspy.Suggest( 104 | len(query) <= 100, 105 | "Query should be short and less than 100 characters", 106 | target_module=self.generate_query 107 | ) 108 | 109 | dspy.Suggest( 110 | validate_query_distinction_local(prev_queries, query), 111 | "Query should be distinct from: " 112 | + "; ".join(f"{i+1}) {q}" for i, q in enumerate(prev_queries)), 113 | target_module=self.generate_query 114 | ) 115 | ``` 116 | 117 | It is recommended to define a program with assertions separately than your original program if you are doing comparative evaluation for the effect of assertions. If not, feel free to set Assertions away! 118 | 119 | Let's take a look at how the SimplifiedBaleen program will look with Assertions included: 120 | 121 | ```python 122 | class SimplifiedBaleenAssertions(dspy.Module): 123 | def __init__(self, passages_per_hop=2, max_hops=2): 124 | super().__init__() 125 | self.generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)] 126 | self.retrieve = dspy.Retrieve(k=passages_per_hop) 127 | self.generate_answer = dspy.ChainOfThought(GenerateAnswer) 128 | self.max_hops = max_hops 129 | 130 | def forward(self, question): 131 | context = [] 132 | prev_queries = [question] 133 | 134 | for hop in range(self.max_hops): 135 | query = self.generate_query[hop](context=context, question=question).query 136 | 137 | dspy.Suggest( 138 | len(query) <= 100, 139 | "Query should be short and less than 100 characters", 140 | target_module=self.generate_query 141 | ) 142 | 143 | dspy.Suggest( 144 | validate_query_distinction_local(prev_queries, query), 145 | "Query should be distinct from: " 146 | + "; ".join(f"{i+1}) {q}" for i, q in enumerate(prev_queries)), 147 | target_module=self.generate_query 148 | ) 149 | 150 | prev_queries.append(query) 151 | passages = self.retrieve(query).passages 152 | context = deduplicate(context + passages) 153 | 154 | if all_queries_distinct(prev_queries): 155 | self.passed_suggestions += 1 156 | 157 | pred = self.generate_answer(context=context, question=question) 158 | pred = dspy.Prediction(context=context, answer=pred.answer) 159 | return pred 160 | ``` 161 | 162 | Now calling programs with DSPy Assertions requires one last step, and that is transforming the program to wrap it with internal assertions backtracking and Retry logic. 
163 | 164 | ```python 165 | from dspy.primitives.assertions import assert_transform_module, backtrack_handler 166 | 167 | baleen_with_assertions = assert_transform_module(SimplifiedBaleenAssertions(), backtrack_handler) 168 | 169 | # backtrack_handler is parameterized over a few settings for the backtracking mechanism 170 | # To change the number of max retry attempts, you can do 171 | baleen_with_assertions_retry_once = assert_transform_module(SimplifiedBaleenAssertions(), 172 | functools.partial(backtrack_handler, max_backtracks=1)) 173 | ``` 174 | 175 | Alternatively, you can also directly call `activate_assertions` on the program with `dspy.Assert/Suggest` statements using the default backtracking mechanism (`max_backtracks=2`): 176 | 177 | ```python 178 | baleen_with_assertions = SimplifiedBaleenAssertions().activate_assertions() 179 | ``` 180 | 181 | Now let's take a look at the internal LM backtracking by inspecting the history of the LM query generations. Here we see that when a query fails to pass the validation check of being less than 100 characters, its internal `GenerateSearchQuery` signature is dynamically modified during the backtracking+Retry process to include the past query and the corresponding user-defined instruction: `"Query should be short and less than 100 characters"`. 182 | 183 | 184 | ```text 185 | Write a simple search query that will help answer a complex question. 186 | 187 | --- 188 | 189 | Follow the following format. 190 | 191 | Context: may contain relevant facts 192 | 193 | Question: ${question} 194 | 195 | Reasoning: Let's think step by step in order to ${produce the query}. We ... 196 | 197 | Query: ${query} 198 | 199 | --- 200 | 201 | Context: 202 | [1] «Kerry Condon | Kerry Condon (born 4 January 1983) is [...]» 203 | [2] «Corona Riccardo | Corona Riccardo (c. 1878October 15, 1917) was [...]» 204 | 205 | Question: Who acted in the shot film The Shore and is also the youngest actress ever to play Ophelia in a Royal Shakespeare Company production of "Hamlet." ? 206 | 207 | Reasoning: Let's think step by step in order to find the answer to this question. First, we need to identify the actress who played Ophelia in a Royal Shakespeare Company production of "Hamlet." Then, we need to find out if this actress also acted in the short film "The Shore." 208 | 209 | Query: "actress who played Ophelia in Royal Shakespeare Company production of Hamlet" + "actress in short film The Shore" 210 | 211 | 212 | 213 | Write a simple search query that will help answer a complex question. 214 | 215 | --- 216 | 217 | Follow the following format. 218 | 219 | Context: may contain relevant facts 220 | 221 | Question: ${question} 222 | 223 | Past Query: past output with errors 224 | 225 | Instructions: Some instructions you must satisfy 226 | 227 | Query: ${query} 228 | 229 | --- 230 | 231 | Context: 232 | [1] «Kerry Condon | Kerry Condon (born 4 January 1983) is an Irish television and film actress, best known for her role as Octavia of the Julii in the HBO/BBC series "Rome," as Stacey Ehrmantraut in AMC's "Better Call Saul" and as the voice of F.R.I.D.A.Y. in various films in the Marvel Cinematic Universe. She is also the youngest actress ever to play Ophelia in a Royal Shakespeare Company production of "Hamlet."» 233 | [2] «Corona Riccardo | Corona Riccardo (c. 1878October 15, 1917) was an Italian born American actress who had a brief Broadway stage career before leaving to become a wife and mother. 
Born in Naples she came to acting in 1894 playing a Mexican girl in a play at the Empire Theatre. Wilson Barrett engaged her for a role in his play "The Sign of the Cross" which he took on tour of the United States. Riccardo played the role of Ancaria and later played Berenice in the same play. Robert B. Mantell in 1898 who struck by her beauty also cast her in two Shakespeare plays, "Romeo and Juliet" and "Othello". Author Lewis Strang writing in 1899 said Riccardo was the most promising actress in America at the time. Towards the end of 1898 Mantell chose her for another Shakespeare part, Ophelia im Hamlet. Afterwards she was due to join Augustin Daly's Theatre Company but Daly died in 1899. In 1899 she gained her biggest fame by playing Iras in the first stage production of Ben-Hur.» 234 | 235 | Question: Who acted in the shot film The Shore and is also the youngest actress ever to play Ophelia in a Royal Shakespeare Company production of "Hamlet." ? 236 | 237 | Past Query: "actress who played Ophelia in Royal Shakespeare Company production of Hamlet" + "actress in short film The Shore" 238 | 239 | Instructions: Query should be short and less than 100 characters 240 | 241 | Query: "actress Ophelia RSC Hamlet" + "actress The Shore" 242 | 243 | ``` 244 | 245 | 246 | ## Assertion-Driven Optimizations 247 | 248 | DSPy Assertions work with optimizations that DSPy offers, particularly with `BootstrapFewShotWithRandomSearch`, including the following settings: 249 | 250 | - Compilation with Assertions 251 | This includes assertion-driven example bootstrapping and counterexample bootstrapping during compilation. The teacher model for bootstrapping few-shot demonstrations can make use of DSPy Assertions to offer robust bootstrapped examples for the student model to learn from during inference. In this setting, the student model does not perform assertion-aware optimizations (backtracking and retry) during inference. 252 | - Compilation + Inference with Assertions 253 | This includes assertion-driven optimizations in both compilation and inference. Now the teacher model offers assertion-driven examples but the student can further optimize with assertions of its own during inference time. 
254 | ```python 255 | teleprompter = BootstrapFewShotWithRandomSearch( 256 | metric=validate_context_and_answer_and_hops, 257 | max_bootstrapped_demos=max_bootstrapped_demos, 258 | num_candidate_programs=6, 259 | ) 260 | 261 | #Compilation with Assertions 262 | compiled_with_assertions_baleen = teleprompter.compile(student = baleen, teacher = baleen_with_assertions, trainset = trainset, valset = devset) 263 | 264 | #Compilation + Inference with Assertions 265 | compiled_baleen_with_assertions = teleprompter.compile(student=baleen_with_assertions, teacher = baleen_with_assertions, trainset=trainset, valset=devset) 266 | 267 | ``` 268 | ``` -------------------------------------------------------------------------------- /dspy/clients/databricks.py: -------------------------------------------------------------------------------- ```python 1 | import logging 2 | import os 3 | import re 4 | import time 5 | from typing import TYPE_CHECKING, Any 6 | 7 | import orjson 8 | import requests 9 | 10 | from dspy.clients.provider import Provider, TrainingJob 11 | from dspy.clients.utils_finetune import TrainDataFormat, get_finetune_directory 12 | 13 | if TYPE_CHECKING: 14 | from databricks.sdk import WorkspaceClient 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class TrainingJobDatabricks(TrainingJob): 20 | def __init__(self, finetuning_run=None, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | self.finetuning_run = finetuning_run 23 | self.launch_started = False 24 | self.launch_completed = False 25 | self.endpoint_name = None 26 | 27 | def status(self): 28 | if not self.finetuning_run: 29 | return None 30 | try: 31 | from databricks.model_training import foundation_model as fm 32 | except ImportError: 33 | raise ImportError( 34 | "To use Databricks finetuning, please install the databricks_genai package via " 35 | "`pip install databricks_genai`." 36 | ) 37 | run = fm.get(self.finetuning_run) 38 | return run.status 39 | 40 | 41 | class DatabricksProvider(Provider): 42 | finetunable = True 43 | TrainingJob = TrainingJobDatabricks 44 | 45 | @staticmethod 46 | def is_provider_model(model: str) -> bool: 47 | # We don't automatically infer Databricks models because Databricks is not a proprietary model provider. 48 | return False 49 | 50 | @staticmethod 51 | def deploy_finetuned_model( 52 | model: str, 53 | data_format: TrainDataFormat | None = None, 54 | databricks_host: str | None = None, 55 | databricks_token: str | None = None, 56 | deploy_timeout: int = 900, 57 | ): 58 | workspace_client = _get_workspace_client() 59 | model_version = next(workspace_client.model_versions.list(model)).version 60 | 61 | # Allow users to override the host and token. This is useful on Databricks hosted runtime. 
62 | databricks_host = databricks_host or workspace_client.config.host 63 | databricks_token = databricks_token or workspace_client.config.token 64 | 65 | headers = {"Context-Type": "text/json", "Authorization": f"Bearer {databricks_token}"} 66 | 67 | optimizable_info = requests.get( 68 | url=f"{databricks_host}/api/2.0/serving-endpoints/get-model-optimization-info/{model}/{model_version}", 69 | headers=headers, 70 | ).json() 71 | 72 | if "optimizable" not in optimizable_info or not optimizable_info["optimizable"]: 73 | raise ValueError(f"Model is not eligible for provisioned throughput: {optimizable_info}") 74 | 75 | chunk_size = optimizable_info["throughput_chunk_size"] 76 | 77 | # Minimum desired provisioned throughput 78 | min_provisioned_throughput = 0 79 | 80 | # Maximum desired provisioned throughput 81 | max_provisioned_throughput = chunk_size 82 | 83 | # Databricks serving endpoint names cannot contain ".". 84 | model_name = model.replace(".", "_") 85 | 86 | get_endpoint_response = requests.get( 87 | url=f"{databricks_host}/api/2.0/serving-endpoints/{model_name}", json={"name": model_name}, headers=headers 88 | ) 89 | 90 | if get_endpoint_response.status_code == 200: 91 | logger.info(f"Serving endpoint {model_name} already exists, updating it instead of creating a new one.") 92 | # The serving endpoint already exists, we will update it instead of creating a new one. 93 | data = { 94 | "served_entities": [ 95 | { 96 | "name": model_name, 97 | "entity_name": model, 98 | "entity_version": model_version, 99 | "min_provisioned_throughput": min_provisioned_throughput, 100 | "max_provisioned_throughput": max_provisioned_throughput, 101 | } 102 | ] 103 | } 104 | 105 | response = requests.put( 106 | url=f"{databricks_host}/api/2.0/serving-endpoints/{model_name}/config", 107 | json=data, 108 | headers=headers, 109 | ) 110 | else: 111 | logger.info(f"Creating serving endpoint {model_name} on Databricks model serving!") 112 | # Send the POST request to create the serving endpoint. 113 | data = { 114 | "name": model_name, 115 | "config": { 116 | "served_entities": [ 117 | { 118 | "name": model_name, 119 | "entity_name": model, 120 | "entity_version": model_version, 121 | "min_provisioned_throughput": min_provisioned_throughput, 122 | "max_provisioned_throughput": max_provisioned_throughput, 123 | } 124 | ] 125 | }, 126 | } 127 | 128 | response = requests.post(url=f"{databricks_host}/api/2.0/serving-endpoints", json=data, headers=headers) 129 | 130 | if response.status_code == 200: 131 | logger.info( 132 | f"Successfully started creating/updating serving endpoint {model_name} on Databricks model serving!" 133 | ) 134 | else: 135 | raise ValueError(f"Failed to create serving endpoint: {response.json()}.") 136 | 137 | logger.info( 138 | f"Waiting for serving endpoint {model_name} to be ready, this might take a few minutes... You can check " 139 | f"the status of the endpoint at {databricks_host}/ml/endpoints/{model_name}" 140 | ) 141 | from openai import OpenAI 142 | 143 | client = OpenAI( 144 | api_key=databricks_token, 145 | base_url=f"{databricks_host}/serving-endpoints", 146 | ) 147 | # Wait for the deployment to be ready. 
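        # Poll roughly once per minute with a 1-token test request until the endpoint responds or deploy_timeout elapses.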
148 | num_retries = deploy_timeout // 60 149 | for _ in range(num_retries): 150 | try: 151 | if data_format == TrainDataFormat.CHAT: 152 | client.chat.completions.create( 153 | messages=[{"role": "user", "content": "hi"}], model=model_name, max_tokens=1 154 | ) 155 | elif data_format == TrainDataFormat.COMPLETION: 156 | client.completions.create(prompt="hi", model=model_name, max_tokens=1) 157 | logger.info(f"Databricks model serving endpoint {model_name} is ready!") 158 | return 159 | except Exception: 160 | time.sleep(60) 161 | 162 | raise ValueError( 163 | f"Failed to create serving endpoint {model_name} on Databricks model serving platform within " 164 | f"{deploy_timeout} seconds." 165 | ) 166 | 167 | @staticmethod 168 | def finetune( 169 | job: TrainingJobDatabricks, 170 | model: str, 171 | train_data: list[dict[str, Any]], 172 | train_data_format: TrainDataFormat | str | None = "chat", 173 | train_kwargs: dict[str, Any] | None = None, 174 | ) -> str: 175 | if isinstance(train_data_format, str): 176 | if train_data_format == "chat": 177 | train_data_format = TrainDataFormat.CHAT 178 | elif train_data_format == "completion": 179 | train_data_format = TrainDataFormat.COMPLETION 180 | else: 181 | raise ValueError( 182 | f"String `train_data_format` must be one of 'chat' or 'completion', but received: {train_data_format}." 183 | ) 184 | 185 | if "train_data_path" not in train_kwargs: 186 | raise ValueError("The `train_data_path` must be provided to finetune on Databricks.") 187 | # Add the file name to the directory path. 188 | train_kwargs["train_data_path"] = DatabricksProvider.upload_data( 189 | train_data, train_kwargs["train_data_path"], train_data_format 190 | ) 191 | 192 | try: 193 | from databricks.model_training import foundation_model as fm 194 | except ImportError: 195 | raise ImportError( 196 | "To use Databricks finetuning, please install the databricks_genai package via " 197 | "`pip install databricks_genai`." 198 | ) 199 | 200 | if "register_to" not in train_kwargs: 201 | raise ValueError("The `register_to` must be provided to finetune on Databricks.") 202 | 203 | # Allow users to override the host and token. This is useful on Databricks hosted runtime. 204 | databricks_host = train_kwargs.pop("databricks_host", None) 205 | databricks_token = train_kwargs.pop("databricks_token", None) 206 | 207 | skip_deploy = train_kwargs.pop("skip_deploy", False) 208 | deploy_timeout = train_kwargs.pop("deploy_timeout", 900) 209 | 210 | logger.info("Starting finetuning on Databricks... this might take a few minutes to finish.") 211 | finetuning_run = fm.create( 212 | model=model, 213 | **train_kwargs, 214 | ) 215 | 216 | job.run = finetuning_run 217 | 218 | # Wait for the finetuning run to be ready. 219 | while True: 220 | job.run = fm.get(job.run) 221 | if job.run.status.display_name == "Completed": 222 | logger.info("Finetuning run completed successfully!") 223 | break 224 | elif job.run.status.display_name == "Failed": 225 | raise ValueError( 226 | f"Finetuning run failed with status: {job.run.status.display_name}. Please check the Databricks " 227 | f"workspace for more details. Finetuning job's metadata: {job.run}." 
228 | ) 229 | else: 230 | time.sleep(60) 231 | 232 | if skip_deploy: 233 | return None 234 | 235 | job.launch_started = True 236 | model_to_deploy = train_kwargs.get("register_to") 237 | job.endpoint_name = model_to_deploy.replace(".", "_") 238 | DatabricksProvider.deploy_finetuned_model( 239 | model_to_deploy, train_data_format, databricks_host, databricks_token, deploy_timeout 240 | ) 241 | job.launch_completed = True 242 | # The finetuned model name should be in the format: "databricks/<endpoint_name>". 243 | return f"databricks/{job.endpoint_name}" 244 | 245 | @staticmethod 246 | def upload_data(train_data: list[dict[str, Any]], databricks_unity_catalog_path: str, data_format: TrainDataFormat): 247 | logger.info("Uploading finetuning data to Databricks Unity Catalog...") 248 | file_path = _save_data_to_local_file(train_data, data_format) 249 | 250 | w = _get_workspace_client() 251 | _create_directory_in_databricks_unity_catalog(w, databricks_unity_catalog_path) 252 | 253 | try: 254 | with open(file_path, "rb") as f: 255 | target_path = os.path.join(databricks_unity_catalog_path, os.path.basename(file_path)) 256 | w.files.upload(target_path, f, overwrite=True) 257 | logger.info("Successfully uploaded finetuning data to Databricks Unity Catalog!") 258 | return target_path 259 | except Exception as e: 260 | raise ValueError(f"Failed to upload finetuning data to Databricks Unity Catalog: {e}") 261 | 262 | 263 | def _get_workspace_client() -> "WorkspaceClient": 264 | try: 265 | from databricks.sdk import WorkspaceClient 266 | except ImportError: 267 | raise ImportError( 268 | "To use Databricks finetuning, please install the databricks-sdk package via `pip install databricks-sdk`." 269 | ) 270 | return WorkspaceClient() 271 | 272 | 273 | def _create_directory_in_databricks_unity_catalog(w: "WorkspaceClient", databricks_unity_catalog_path: str): 274 | pattern = r"^/Volumes/(?P<catalog>[^/]+)/(?P<schema>[^/]+)/(?P<volume>[^/]+)(/[^/]+)+$" 275 | match = re.match(pattern, databricks_unity_catalog_path) 276 | if not match: 277 | raise ValueError( 278 | f"Databricks Unity Catalog path must be in the format '/Volumes/<catalog>/<schema>/<volume>/...', but " 279 | f"received: {databricks_unity_catalog_path}." 280 | ) 281 | 282 | catalog = match.group("catalog") 283 | schema = match.group("schema") 284 | volume = match.group("volume") 285 | 286 | try: 287 | volume_path = f"{catalog}.{schema}.{volume}" 288 | w.volumes.read(volume_path) 289 | except Exception: 290 | raise ValueError( 291 | f"Databricks Unity Catalog volume does not exist: {volume_path}, please create it on the Databricks " 292 | "workspace." 293 | ) 294 | 295 | try: 296 | w.files.get_directory_metadata(databricks_unity_catalog_path) 297 | logger.info(f"Directory {databricks_unity_catalog_path} already exists, skip creating it.") 298 | except Exception: 299 | # Create the directory if it doesn't exist, we don't raise an error because this is a common case. 
300 | logger.info(f"Creating directory {databricks_unity_catalog_path} in Databricks Unity Catalog...") 301 | w.files.create_directory(databricks_unity_catalog_path) 302 | logger.info(f"Successfully created directory {databricks_unity_catalog_path} in Databricks Unity Catalog!") 303 | 304 | 305 | def _save_data_to_local_file(train_data: list[dict[str, Any]], data_format: TrainDataFormat): 306 | import uuid 307 | 308 | file_name = f"finetuning_{uuid.uuid4()}.jsonl" 309 | 310 | finetune_dir = get_finetune_directory() 311 | file_path = os.path.join(finetune_dir, file_name) 312 | file_path = os.path.abspath(file_path) 313 | with open(file_path, "wb") as f: 314 | for item in train_data: 315 | if data_format == TrainDataFormat.CHAT: 316 | _validate_chat_data(item) 317 | elif data_format == TrainDataFormat.COMPLETION: 318 | _validate_completion_data(item) 319 | 320 | f.write(orjson.dumps(item) + b"\n") 321 | return file_path 322 | 323 | 324 | def _validate_chat_data(data: dict[str, Any]): 325 | if "messages" not in data: 326 | raise ValueError( 327 | "Each finetuning data must be a dict with a 'messages' key when `task=CHAT_COMPLETION`, but " 328 | f"received: {data}" 329 | ) 330 | 331 | if not isinstance(data["messages"], list): 332 | raise ValueError( 333 | "The value of the 'messages' key in each finetuning data must be a list of dicts with keys 'role' and " 334 | f"'content' when `task=CHAT_COMPLETION`, but received: {data['messages']}" 335 | ) 336 | 337 | for message in data["messages"]: 338 | if "role" not in message: 339 | raise ValueError(f"Each message in the 'messages' list must contain a 'role' key, but received: {message}.") 340 | if "content" not in message: 341 | raise ValueError( 342 | f"Each message in the 'messages' list must contain a 'content' key, but received: {message}." 343 | ) 344 | 345 | 346 | def _validate_completion_data(data: dict[str, Any]): 347 | if "prompt" not in data: 348 | raise ValueError( 349 | "Each finetuning data must be a dict with a 'prompt' key when `task=INSTRUCTION_FINETUNE`, but " 350 | f"received: {data}" 351 | ) 352 | if "response" not in data and "completion" not in data: 353 | raise ValueError( 354 | "Each finetuning data must be a dict with a 'response' or 'completion' key when " 355 | f"`task=INSTRUCTION_FINETUNE`, but received: {data}" 356 | ) 357 | ```