This is page 10 of 17. Use http://codebase.md/stanfordnlp/dspy?lines=true&page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── .internal_dspyai │ │ ├── internals │ │ │ ├── build-and-release.md │ │ │ └── release-checklist.md │ │ └── pyproject.toml │ ├── .tmp │ │ └── .generated-actions │ │ └── run-pypi-publish-in-docker-container │ │ └── action.yml │ ├── ISSUE_TEMPLATE │ │ ├── bug_report.yml │ │ └── feature_request.yml │ ├── PULL_REQUEST_TEMPLATE │ │ └── pull_request_template.md │ ├── workflow_scripts │ │ └── install_testpypi_pkg.sh │ └── workflows │ ├── build_and_release.yml │ ├── build_utils │ │ └── test_version.py │ ├── docs-push.yml │ ├── precommits_check.yml │ └── run_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── docs │ ├── .gitignore │ ├── docs │ │ ├── api │ │ │ ├── adapters │ │ │ │ ├── Adapter.md │ │ │ │ ├── ChatAdapter.md │ │ │ │ ├── JSONAdapter.md │ │ │ │ └── TwoStepAdapter.md │ │ │ ├── evaluation │ │ │ │ ├── answer_exact_match.md │ │ │ │ ├── answer_passage_match.md │ │ │ │ ├── CompleteAndGrounded.md │ │ │ │ ├── Evaluate.md │ │ │ │ ├── EvaluationResult.md │ │ │ │ └── SemanticF1.md │ │ │ ├── experimental │ │ │ │ ├── Citations.md │ │ │ │ └── Document.md │ │ │ ├── index.md │ │ │ ├── models │ │ │ │ ├── Embedder.md │ │ │ │ └── LM.md │ │ │ ├── modules │ │ │ │ ├── BestOfN.md │ │ │ │ ├── ChainOfThought.md │ │ │ │ ├── CodeAct.md │ │ │ │ ├── Module.md │ │ │ │ ├── MultiChainComparison.md │ │ │ │ ├── Parallel.md │ │ │ │ ├── Predict.md │ │ │ │ ├── ProgramOfThought.md │ │ │ │ ├── ReAct.md │ │ │ │ └── Refine.md │ │ │ ├── optimizers │ │ │ │ ├── BetterTogether.md │ │ │ │ ├── BootstrapFewShot.md │ │ │ │ ├── BootstrapFewShotWithRandomSearch.md │ │ │ │ ├── BootstrapFinetune.md │ │ │ │ ├── BootstrapRS.md │ │ │ │ ├── COPRO.md │ │ │ │ ├── Ensemble.md │ │ │ │ ├── GEPA │ │ │ │ │ ├── GEPA_Advanced.md │ │ │ │ │ └── overview.md │ │ │ │ ├── InferRules.md │ │ │ │ ├── KNN.md │ │ │ │ ├── KNNFewShot.md │ │ │ │ ├── LabeledFewShot.md │ │ │ │ ├── MIPROv2.md │ │ │ │ └── SIMBA.md │ │ │ ├── primitives │ │ │ │ ├── Audio.md │ │ │ │ ├── Code.md │ │ │ │ ├── Example.md │ │ │ │ ├── History.md │ │ │ │ ├── Image.md │ │ │ │ ├── Prediction.md │ │ │ │ ├── Tool.md │ │ │ │ └── ToolCalls.md │ │ │ ├── signatures │ │ │ │ ├── InputField.md │ │ │ │ ├── OutputField.md │ │ │ │ └── Signature.md │ │ │ ├── tools │ │ │ │ ├── ColBERTv2.md │ │ │ │ ├── Embeddings.md │ │ │ │ └── PythonInterpreter.md │ │ │ └── utils │ │ │ ├── asyncify.md │ │ │ ├── configure_cache.md │ │ │ ├── disable_litellm_logging.md │ │ │ ├── disable_logging.md │ │ │ ├── enable_litellm_logging.md │ │ │ ├── enable_logging.md │ │ │ ├── inspect_history.md │ │ │ ├── load.md │ │ │ ├── StatusMessage.md │ │ │ ├── StatusMessageProvider.md │ │ │ ├── streamify.md │ │ │ └── StreamListener.md │ │ ├── cheatsheet.md │ │ ├── community │ │ │ ├── community-resources.md │ │ │ ├── how-to-contribute.md │ │ │ └── use-cases.md │ │ ├── deep-dive │ │ │ └── data-handling │ │ │ ├── built-in-datasets.md │ │ │ ├── examples.md │ │ │ ├── img │ │ │ │ └── data-loading.png │ │ │ └── loading-custom-data.md │ │ ├── faqs.md │ │ ├── index.md │ │ ├── js │ │ │ └── runllm-widget.js │ │ ├── learn │ │ │ ├── evaluation │ │ │ │ ├── data.md │ │ │ │ ├── metrics.md │ │ │ │ └── overview.md │ │ │ ├── figures │ │ │ │ ├── native_tool_call.png │ │ │ │ └── teleprompter-classes.png │ │ │ ├── index.md │ │ │ ├── optimization │ │ │ │ ├── optimizers.md │ │ │ │ └── overview.md │ │ │ └── programming │ │ │ ├── 7-assertions.md │ │ │ ├── adapters.md │ │ │ ├── language_models.md │ │ │ 
├── mcp.md │ │ │ ├── modules.md │ │ │ ├── overview.md │ │ │ ├── signatures.md │ │ │ └── tools.md │ │ ├── production │ │ │ └── index.md │ │ ├── roadmap.md │ │ ├── static │ │ │ ├── .nojekyll │ │ │ └── img │ │ │ ├── dspy_logo.png │ │ │ ├── logo.png │ │ │ ├── mlflow-tracing-rag.png │ │ │ ├── modular.png │ │ │ ├── optimize.png │ │ │ ├── undraw_docusaurus_mountain.svg │ │ │ ├── undraw_docusaurus_react.svg │ │ │ ├── undraw_docusaurus_tree.svg │ │ │ └── universal_compatibility.png │ │ ├── stylesheets │ │ │ └── extra.css │ │ └── tutorials │ │ ├── agents │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── ai_text_game │ │ │ └── index.md │ │ ├── async │ │ │ └── index.md │ │ ├── audio │ │ │ └── index.ipynb │ │ ├── build_ai_program │ │ │ └── index.md │ │ ├── cache │ │ │ └── index.md │ │ ├── classification │ │ │ └── index.md │ │ ├── classification_finetuning │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-classification.png │ │ ├── conversation_history │ │ │ └── index.md │ │ ├── core_development │ │ │ └── index.md │ │ ├── custom_module │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-custom-module.png │ │ ├── customer_service_agent │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-customer-service-agent.png │ │ ├── deployment │ │ │ ├── dspy_mlflow_ui.png │ │ │ └── index.md │ │ ├── email_extraction │ │ │ ├── index.md │ │ │ └── mlflow-tracing-email-extraction.png │ │ ├── entity_extraction │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-entity-extraction.png │ │ ├── games │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── gepa_ai_program │ │ │ └── index.md │ │ ├── gepa_aime │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-aime.png │ │ │ └── mlflow-tracking-gepa-aime-optimization.png │ │ ├── gepa_facilitysupportanalyzer │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-support.png │ │ │ └── mlflow-tracking-gepa-support-optimization.png │ │ ├── gepa_papillon │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-papilon.png │ │ │ └── mlflow-tracking-gepa-papilon-optimization.png │ │ ├── image_generation_prompting │ │ │ └── index.ipynb │ │ ├── index.md │ │ ├── llms_txt_generation │ │ │ └── index.md │ │ ├── math │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-math.png │ │ ├── mcp │ │ │ └── index.md │ │ ├── mem0_react_agent │ │ │ └── index.md │ │ ├── multihop_search │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-multi-hop.png │ │ ├── observability │ │ │ ├── index.md │ │ │ ├── mlflow_trace_ui_navigation.gif │ │ │ ├── mlflow_trace_ui.png │ │ │ └── mlflow_trace_view.png │ │ ├── optimize_ai_program │ │ │ └── index.md │ │ ├── optimizer_tracking │ │ │ ├── child_run.png │ │ │ ├── experiment.png │ │ │ ├── index.md │ │ │ └── parent_run.png │ │ ├── output_refinement │ │ │ └── best-of-n-and-refine.md │ │ ├── papillon │ │ │ └── index.md │ │ ├── program_of_thought │ │ │ └── index.ipynb │ │ ├── rag │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-rag.png │ │ ├── real_world_examples │ │ │ └── index.md │ │ ├── rl_ai_program │ │ │ └── index.md │ │ ├── rl_multihop │ │ │ └── index.ipynb │ │ ├── rl_papillon │ │ │ └── index.ipynb │ │ ├── sample_code_generation │ │ │ └── index.md │ │ ├── saving │ │ │ └── index.md │ │ ├── streaming │ │ │ └── index.md │ │ ├── tool_use │ │ │ └── index.ipynb │ │ └── yahoo_finance_react │ │ └── index.md │ ├── mkdocs.yml │ ├── overrides │ │ ├── home.html │ │ ├── main.html │ │ └── partials │ │ └── tabs.html │ ├── Pipfile │ ├── Pipfile.lock │ ├── README.md │ ├── requirements.txt │ ├── scripts │ │ ├── generate_api_docs.py │ │ └── generate_api_summary.py │ └── vercel.json ├── dspy │ ├── 
__init__.py │ ├── __metadata__.py │ ├── adapters │ │ ├── __init__.py │ │ ├── baml_adapter.py │ │ ├── base.py │ │ ├── chat_adapter.py │ │ ├── json_adapter.py │ │ ├── two_step_adapter.py │ │ ├── types │ │ │ ├── __init__.py │ │ │ ├── audio.py │ │ │ ├── base_type.py │ │ │ ├── citation.py │ │ │ ├── code.py │ │ │ ├── document.py │ │ │ ├── history.py │ │ │ ├── image.py │ │ │ └── tool.py │ │ ├── utils.py │ │ └── xml_adapter.py │ ├── clients │ │ ├── __init__.py │ │ ├── base_lm.py │ │ ├── cache.py │ │ ├── databricks.py │ │ ├── embedding.py │ │ ├── lm_local_arbor.py │ │ ├── lm_local.py │ │ ├── lm.py │ │ ├── openai.py │ │ ├── provider.py │ │ └── utils_finetune.py │ ├── datasets │ │ ├── __init__.py │ │ ├── alfworld │ │ │ ├── __init__.py │ │ │ ├── alfworld.py │ │ │ └── base_config.yml │ │ ├── colors.py │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── gsm8k.py │ │ ├── hotpotqa.py │ │ └── math.py │ ├── dsp │ │ ├── __init__.py │ │ ├── colbertv2.py │ │ └── utils │ │ ├── __init__.py │ │ ├── dpr.py │ │ ├── settings.py │ │ └── utils.py │ ├── evaluate │ │ ├── __init__.py │ │ ├── auto_evaluation.py │ │ ├── evaluate.py │ │ └── metrics.py │ ├── experimental │ │ └── __init__.py │ ├── predict │ │ ├── __init__.py │ │ ├── aggregation.py │ │ ├── avatar │ │ │ ├── __init__.py │ │ │ ├── avatar.py │ │ │ ├── models.py │ │ │ └── signatures.py │ │ ├── best_of_n.py │ │ ├── chain_of_thought.py │ │ ├── code_act.py │ │ ├── knn.py │ │ ├── multi_chain_comparison.py │ │ ├── parallel.py │ │ ├── parameter.py │ │ ├── predict.py │ │ ├── program_of_thought.py │ │ ├── react.py │ │ ├── refine.py │ │ └── retry.py │ ├── primitives │ │ ├── __init__.py │ │ ├── base_module.py │ │ ├── example.py │ │ ├── module.py │ │ ├── prediction.py │ │ ├── python_interpreter.py │ │ └── runner.js │ ├── propose │ │ ├── __init__.py │ │ ├── dataset_summary_generator.py │ │ ├── grounded_proposer.py │ │ ├── propose_base.py │ │ └── utils.py │ ├── retrievers │ │ ├── __init__.py │ │ ├── databricks_rm.py │ │ ├── embeddings.py │ │ ├── retrieve.py │ │ └── weaviate_rm.py │ ├── signatures │ │ ├── __init__.py │ │ ├── field.py │ │ ├── signature.py │ │ └── utils.py │ ├── streaming │ │ ├── __init__.py │ │ ├── messages.py │ │ ├── streamify.py │ │ └── streaming_listener.py │ ├── teleprompt │ │ ├── __init__.py │ │ ├── avatar_optimizer.py │ │ ├── bettertogether.py │ │ ├── bootstrap_finetune.py │ │ ├── bootstrap_trace.py │ │ ├── bootstrap.py │ │ ├── copro_optimizer.py │ │ ├── ensemble.py │ │ ├── gepa │ │ │ ├── __init__.py │ │ │ ├── gepa_utils.py │ │ │ ├── gepa.py │ │ │ └── instruction_proposal.py │ │ ├── grpo.py │ │ ├── infer_rules.py │ │ ├── knn_fewshot.py │ │ ├── mipro_optimizer_v2.py │ │ ├── random_search.py │ │ ├── signature_opt.py │ │ ├── simba_utils.py │ │ ├── simba.py │ │ ├── teleprompt_optuna.py │ │ ├── teleprompt.py │ │ ├── utils.py │ │ └── vanilla.py │ └── utils │ ├── __init__.py │ ├── annotation.py │ ├── asyncify.py │ ├── caching.py │ ├── callback.py │ ├── dummies.py │ ├── exceptions.py │ ├── hasher.py │ ├── inspect_history.py │ ├── langchain_tool.py │ ├── logging_utils.py │ ├── mcp.py │ ├── parallelizer.py │ ├── saving.py │ ├── syncify.py │ ├── unbatchify.py │ └── usage_tracker.py ├── LICENSE ├── pyproject.toml ├── README.md ├── tests │ ├── __init__.py │ ├── adapters │ │ ├── test_adapter_utils.py │ │ ├── test_baml_adapter.py │ │ ├── test_base_type.py │ │ ├── test_chat_adapter.py │ │ ├── test_citation.py │ │ ├── test_code.py │ │ ├── test_document.py │ │ ├── test_json_adapter.py │ │ ├── test_tool.py │ │ ├── test_two_step_adapter.py │ │ └── test_xml_adapter.py │ ├── 
callback │ │ └── test_callback.py │ ├── clients │ │ ├── test_cache.py │ │ ├── test_databricks.py │ │ ├── test_embedding.py │ │ ├── test_inspect_global_history.py │ │ └── test_lm.py │ ├── conftest.py │ ├── datasets │ │ └── test_dataset.py │ ├── docs │ │ └── test_mkdocs_links.py │ ├── evaluate │ │ ├── test_evaluate.py │ │ └── test_metrics.py │ ├── examples │ │ └── test_baleen.py │ ├── metadata │ │ └── test_metadata.py │ ├── predict │ │ ├── test_aggregation.py │ │ ├── test_best_of_n.py │ │ ├── test_chain_of_thought.py │ │ ├── test_code_act.py │ │ ├── test_knn.py │ │ ├── test_multi_chain_comparison.py │ │ ├── test_parallel.py │ │ ├── test_predict.py │ │ ├── test_program_of_thought.py │ │ ├── test_react.py │ │ ├── test_refine.py │ │ └── test_retry.py │ ├── primitives │ │ ├── resources │ │ │ └── saved_program.json │ │ ├── test_base_module.py │ │ ├── test_example.py │ │ ├── test_module.py │ │ └── test_python_interpreter.py │ ├── propose │ │ └── test_grounded_proposer.py │ ├── README.md │ ├── reliability │ │ ├── __init__.py │ │ ├── complex_types │ │ │ └── generated │ │ │ ├── test_many_types_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ ├── test_nesting_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ └── test_nesting_2 │ │ │ ├── inputs │ │ │ │ └── input1.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── conftest.py │ │ ├── generate │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ └── utils.py │ │ ├── input_formats │ │ │ └── generated │ │ │ └── test_markdown_1 │ │ │ ├── inputs │ │ │ │ ├── input1.json │ │ │ │ └── input2.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── README.md │ │ ├── reliability_conf.yaml │ │ ├── test_generated.py │ │ ├── test_pydantic_models.py │ │ └── utils.py │ ├── retrievers │ │ └── test_embeddings.py │ ├── signatures │ │ ├── test_adapter_image.py │ │ ├── test_custom_types.py │ │ └── test_signature.py │ ├── streaming │ │ └── test_streaming.py │ ├── teleprompt │ │ ├── gepa_dummy_lm_custom_component_selector_custom_instruction_proposer.json │ │ ├── gepa_dummy_lm.json │ │ ├── test_bootstrap_finetune.py │ │ ├── test_bootstrap_trace.py │ │ ├── test_bootstrap.py │ │ ├── test_copro_optimizer.py │ │ ├── test_ensemble.py │ │ ├── test_finetune.py │ │ ├── test_gepa_instruction_proposer.py │ │ ├── test_gepa.py │ │ ├── test_grpo.py │ │ ├── test_knn_fewshot.py │ │ ├── test_random_search.py │ │ ├── test_teleprompt.py │ │ └── test_utils.py │ ├── test_utils │ │ ├── __init__.py │ │ └── server │ │ ├── __init__.py │ │ ├── litellm_server_config.yaml │ │ └── litellm_server.py │ └── utils │ ├── __init__.py │ ├── resources │ │ └── mcp_server.py │ ├── test_annotation.py │ ├── test_asyncify.py │ ├── test_exceptions.py │ ├── test_langchain_tool.py │ ├── test_mcp.py │ ├── test_parallelizer.py │ ├── test_saving.py │ ├── test_settings.py │ ├── test_syncify.py │ ├── test_unbatchify.py │ └── test_usage_tracker.py └── uv.lock ``` # Files -------------------------------------------------------------------------------- /docs/mkdocs.yml: -------------------------------------------------------------------------------- ```yaml 1 | site_name: DSPy 2 | site_description: The framework for programming—rather than prompting—language models. 
3 | site_url: https://dspy.ai/ 4 | 5 | repo_url: https://github.com/stanfordnlp/dspy 6 | repo_name: stanfordnlp/dspy 7 | 8 | edit_uri: blob/main/docs/docs/ 9 | docs_dir: "docs/" 10 | 11 | nav: 12 | - Get Started: index.md 13 | - Learn DSPy: 14 | - Learning DSPy: learn/index.md 15 | - DSPy Programming: 16 | - Programming Overview: learn/programming/overview.md 17 | - Language Models: learn/programming/language_models.md 18 | - Signatures: learn/programming/signatures.md 19 | - Modules: learn/programming/modules.md 20 | - Adapters: learn/programming/adapters.md 21 | - Tools: learn/programming/tools.md 22 | - MCP: learn/programming/mcp.md 23 | - DSPy Evaluation: 24 | - Evaluation Overview: learn/evaluation/overview.md 25 | - Data Handling: learn/evaluation/data.md 26 | - Metrics: learn/evaluation/metrics.md 27 | - DSPy Optimization: 28 | - Optimization Overview: learn/optimization/overview.md 29 | - Optimizers: learn/optimization/optimizers.md 30 | - Tutorials: 31 | - Tutorials Overview: tutorials/index.md 32 | - Build AI Programs with DSPy: 33 | - Overview: tutorials/build_ai_program/index.md 34 | - Managing Conversation History: tutorials/conversation_history/index.md 35 | - Building AI Agents with DSPy: tutorials/customer_service_agent/index.ipynb 36 | - Building AI Applications by Customizing DSPy Modules: tutorials/custom_module/index.ipynb 37 | - Retrieval-Augmented Generation (RAG): tutorials/rag/index.ipynb 38 | - Building RAG as Agent: tutorials/agents/index.ipynb 39 | - Entity Extraction: tutorials/entity_extraction/index.ipynb 40 | - Classification: tutorials/classification/index.md 41 | - Multi-Hop RAG: tutorials/multihop_search/index.ipynb 42 | - Privacy-Conscious Delegation: tutorials/papillon/index.md 43 | - Program Of Thought: tutorials/program_of_thought/index.ipynb 44 | - Image Generation Prompt iteration: tutorials/image_generation_prompting/index.ipynb 45 | - Audio: tutorials/audio/index.ipynb 46 | - Optimize AI Programs with DSPy: 47 | - Overview: tutorials/optimize_ai_program/index.md 48 | - Math Reasoning: tutorials/math/index.ipynb 49 | - Classification Finetuning: tutorials/classification_finetuning/index.ipynb 50 | - Advanced Tool Use: tutorials/tool_use/index.ipynb 51 | - Finetuning Agents: tutorials/games/index.ipynb 52 | - Reflective Prompt Evolution with dspy.GEPA: 53 | - Overview: tutorials/gepa_ai_program/index.md 54 | - GEPA for AIME (Math): tutorials/gepa_aime/index.ipynb 55 | - GEPA for Structured Information Extraction for Enterprise Tasks: tutorials/gepa_facilitysupportanalyzer/index.ipynb 56 | - GEPA for Privacy-Conscious Delegation: tutorials/gepa_papillon/index.ipynb 57 | - Experimental RL Optimization for DSPy: 58 | - Overview: tutorials/rl_ai_program/index.md 59 | - RL for Privacy-Conscious Delegation: tutorials/rl_papillon/index.ipynb 60 | - RL for Multi-Hop Research: tutorials/rl_multihop/index.ipynb 61 | - Tools, Development, and Deployment: 62 | - Overview: tutorials/core_development/index.md 63 | - Use MCP in DSPy: tutorials/mcp/index.md 64 | - Output Refinement: tutorials/output_refinement/best-of-n-and-refine.md 65 | - Saving and Loading: tutorials/saving/index.md 66 | - Cache: tutorials/cache/index.md 67 | - Deployment: tutorials/deployment/index.md 68 | - Debugging & Observability: tutorials/observability/index.md 69 | - Tracking DSPy Optimizers: tutorials/optimizer_tracking/index.md 70 | - Streaming: tutorials/streaming/index.md 71 | - Async: tutorials/async/index.md 72 | - Real-World Examples: 73 | - Overview: 
tutorials/real_world_examples/index.md 74 | - Generating llms.txt: tutorials/llms_txt_generation/index.md 75 | - Memory-Enabled ReAct Agents: tutorials/mem0_react_agent/index.md 76 | - Financial Analysis with Yahoo Finance: tutorials/yahoo_finance_react/index.md 77 | - Email Information Extraction: tutorials/email_extraction/index.md 78 | - Code Generation for Unfamiliar Libraries: tutorials/sample_code_generation/index.md 79 | - Building a Creative Text-Based AI Game: tutorials/ai_text_game/index.md 80 | - DSPy in Production: production/index.md 81 | - Community: 82 | - Community Resources: community/community-resources.md 83 | - Use Cases: community/use-cases.md 84 | - Contributing: community/how-to-contribute.md 85 | - FAQ: 86 | - FAQ: faqs.md 87 | - Cheatsheet: cheatsheet.md 88 | 89 | - API Reference: 90 | - API Reference: api/index.md 91 | - Adapters: 92 | - Adapter: api/adapters/Adapter.md 93 | - ChatAdapter: api/adapters/ChatAdapter.md 94 | - JSONAdapter: api/adapters/JSONAdapter.md 95 | - TwoStepAdapter: api/adapters/TwoStepAdapter.md 96 | - Evaluation: 97 | - CompleteAndGrounded: api/evaluation/CompleteAndGrounded.md 98 | - Evaluate: api/evaluation/Evaluate.md 99 | - EvaluationResult: api/evaluation/EvaluationResult.md 100 | - SemanticF1: api/evaluation/SemanticF1.md 101 | - answer_exact_match: api/evaluation/answer_exact_match.md 102 | - answer_passage_match: api/evaluation/answer_passage_match.md 103 | - Experimental: 104 | - Citations: api/experimental/Citations.md 105 | - Document: api/experimental/Document.md 106 | - Models: 107 | - Embedder: api/models/Embedder.md 108 | - LM: api/models/LM.md 109 | - Modules: 110 | - BestOfN: api/modules/BestOfN.md 111 | - ChainOfThought: api/modules/ChainOfThought.md 112 | - CodeAct: api/modules/CodeAct.md 113 | - Module: api/modules/Module.md 114 | - MultiChainComparison: api/modules/MultiChainComparison.md 115 | - Parallel: api/modules/Parallel.md 116 | - Predict: api/modules/Predict.md 117 | - ProgramOfThought: api/modules/ProgramOfThought.md 118 | - ReAct: api/modules/ReAct.md 119 | - Refine: api/modules/Refine.md 120 | - Optimizers: 121 | - GEPA: 122 | - 1. GEPA Overview: api/optimizers/GEPA/overview.md 123 | - 2. 
GEPA Advanced: api/optimizers/GEPA/GEPA_Advanced.md 124 | - BetterTogether: api/optimizers/BetterTogether.md 125 | - BootstrapFewShot: api/optimizers/BootstrapFewShot.md 126 | - BootstrapFewShotWithRandomSearch: api/optimizers/BootstrapFewShotWithRandomSearch.md 127 | - BootstrapFinetune: api/optimizers/BootstrapFinetune.md 128 | - BootstrapRS: api/optimizers/BootstrapRS.md 129 | - COPRO: api/optimizers/COPRO.md 130 | - Ensemble: api/optimizers/Ensemble.md 131 | - InferRules: api/optimizers/InferRules.md 132 | - KNN: api/optimizers/KNN.md 133 | - KNNFewShot: api/optimizers/KNNFewShot.md 134 | - LabeledFewShot: api/optimizers/LabeledFewShot.md 135 | - MIPROv2: api/optimizers/MIPROv2.md 136 | - SIMBA: api/optimizers/SIMBA.md 137 | - Primitives: 138 | - Audio: api/primitives/Audio.md 139 | - Code: api/primitives/Code.md 140 | - Example: api/primitives/Example.md 141 | - History: api/primitives/History.md 142 | - Image: api/primitives/Image.md 143 | - Prediction: api/primitives/Prediction.md 144 | - Tool: api/primitives/Tool.md 145 | - ToolCalls: api/primitives/ToolCalls.md 146 | - Signatures: 147 | - InputField: api/signatures/InputField.md 148 | - OutputField: api/signatures/OutputField.md 149 | - Signature: api/signatures/Signature.md 150 | - Tools: 151 | - ColBERTv2: api/tools/ColBERTv2.md 152 | - Embeddings: api/tools/Embeddings.md 153 | - PythonInterpreter: api/tools/PythonInterpreter.md 154 | - Utils: 155 | - StatusMessage: api/utils/StatusMessage.md 156 | - StatusMessageProvider: api/utils/StatusMessageProvider.md 157 | - StreamListener: api/utils/StreamListener.md 158 | - asyncify: api/utils/asyncify.md 159 | - configure_cache: api/utils/configure_cache.md 160 | - disable_litellm_logging: api/utils/disable_litellm_logging.md 161 | - disable_logging: api/utils/disable_logging.md 162 | - enable_litellm_logging: api/utils/enable_litellm_logging.md 163 | - enable_logging: api/utils/enable_logging.md 164 | - inspect_history: api/utils/inspect_history.md 165 | - load: api/utils/load.md 166 | - streamify: api/utils/streamify.md 167 | 168 | theme: 169 | name: material 170 | custom_dir: overrides 171 | features: 172 | - navigation.tabs 173 | - navigation.path 174 | - navigation.indexes 175 | - navigation.expand 176 | - toc.follow 177 | - toc.integrate 178 | - navigation.top 179 | - search.suggest 180 | - search.highlight 181 | - content.tabs.link 182 | - content.code.annotation 183 | - content.code.copy 184 | - navigation.footer 185 | - content.action.edit 186 | language: en 187 | palette: 188 | - scheme: default 189 | toggle: 190 | icon: material/weather-night 191 | name: Switch to dark mode 192 | primary: white 193 | accent: black 194 | - scheme: slate 195 | toggle: 196 | icon: material/weather-sunny 197 | name: Switch to light mode 198 | primary: black 199 | accent: lime 200 | icon: 201 | repo: fontawesome/brands/git-alt 202 | edit: material/pencil 203 | view: material/eye 204 | logo: static/img/dspy_logo.png 205 | favicon: static/img/logo.png 206 | 207 | extra_css: 208 | - stylesheets/extra.css 209 | 210 | plugins: 211 | - social 212 | - search: 213 | lang: en 214 | separator: '[\s\-\.]+' 215 | - mkdocstrings: 216 | handlers: 217 | python: 218 | options: 219 | docstring_style: google 220 | show_source: true 221 | show_root_heading: true 222 | heading_level: 3 223 | members_order: source 224 | separate_signature: false 225 | show_category_heading: true 226 | show_symbol_type_heading: true 227 | show_docstring_parameters: true 228 | show_if_no_docstring: true 229 | 
show_signature_annotations: true 230 | unwrap_annotated: true 231 | annotations_path: brief 232 | docstring_section_style: table 233 | merge_init_into_class: true 234 | rendering: 235 | show_if_no_docstring: true 236 | show_warnings: false 237 | html_meta: false 238 | - mkdocs-jupyter: 239 | ignore_h1_titles: true 240 | - redirects: 241 | redirect_maps: 242 | # Redirect /intro/ to the main page 243 | "intro/index.md": "index.md" 244 | "intro.md": "index.md" 245 | 246 | "deep-dive/optimizers/bootstrap-fewshot.md": "api/optimizers/BootstrapFewShot.md" 247 | "deep-dive/optimizers/bfrs.md": "api/optimizers/BootstrapFewShotWithRandomSearch.md" 248 | "deep-dive/optimizers/BootstrapFinetune.md": "api/optimizers/BootstrapFinetune.md" 249 | "deep-dive/optimizers/copro.md": "api/optimizers/COPRO.md" 250 | "deep-dive/optimizers/Ensemble.md": "api/optimizers/Ensemble.md" 251 | "deep-dive/optimizers/LabeledFewShot.md": "api/optimizers/LabeledFewShot.md" 252 | "deep-dive/optimizers/miprov2.md": "api/optimizers/MIPROv2.md" 253 | "api/optimizers/GEPA/index.md": "api/optimizers/GEPA/overview.md" 254 | 255 | "docs/quick-start/getting-started-01.md": "tutorials/rag/index.ipynb" 256 | "docs/quick-start/getting-started-02.md": "tutorials/rag/index.ipynb" 257 | "quick-start/getting-started-01.md": "tutorials/rag/index.ipynb" 258 | "quick-start/getting-started-02.md": "tutorials/rag/index.ipynb" 259 | - llmstxt: 260 | markdown_description: > 261 | DSPy is the framework for programming—rather than prompting—language models. 262 | DSPy unifies techniques for prompting, fine-tuning, reasoning, tool use, and evaluation of LMs. 263 | It provides a systematic approach to building AI applications through composable modules, 264 | optimization techniques, and evaluation frameworks. 
265 | sections: 266 | Getting Started: 267 | - index.md: DSPy overview and quick start guide 268 | - cheatsheet.md: DSPy cheatsheet for quick reference 269 | Core Concepts: 270 | - learn/programming/overview.md: Programming paradigm and philosophy 271 | - learn/programming/signatures.md: Signatures - declarative input/output specifications 272 | - learn/programming/modules.md: Modules - composable AI components 273 | - learn/programming/language_models.md: Language model interfaces and configuration 274 | Essential Tutorials: 275 | - tutorials/rag/index.ipynb: Retrieval-Augmented Generation (RAG) tutorial 276 | - tutorials/classification/index.md: Classification with DSPy 277 | - tutorials/agents/index.ipynb: Building AI agents with DSPy 278 | Optimization: 279 | - learn/optimization/overview.md: Optimization techniques overview 280 | - tutorials/optimize_ai_program/index.md: Guide to optimizing AI programs 281 | - api/optimizers/BootstrapFewShot.md: Bootstrap few-shot optimizer 282 | Key Modules API: 283 | - api/modules/Predict.md: Basic prediction module 284 | - api/modules/ChainOfThought.md: Chain of thought reasoning 285 | - api/modules/ReAct.md: ReAct agent module 286 | Core API Reference: 287 | - api/signatures/Signature.md: Signature system documentation 288 | - api/primitives/Example.md: Example primitive for training data 289 | Production: 290 | - tutorials/deployment/index.md: Production deployment guide 291 | - tutorials/observability/index.md: Debugging and observability 292 | 293 | extra: 294 | social: 295 | - icon: fontawesome/brands/github 296 | link: https://github.com/stanfordnlp/dspy 297 | - icon: fontawesome/brands/discord 298 | link: https://discord.gg/XCGy2WDCQB 299 | 300 | extra_javascript: 301 | - "js/runllm-widget.js" 302 | 303 | markdown_extensions: 304 | - toc: 305 | permalink: true 306 | toc_depth: 3 307 | - pymdownx.tabbed: 308 | alternate_style: true 309 | - pymdownx.highlight: 310 | anchor_linenums: true 311 | - pymdownx.inlinehilite 312 | - pymdownx.snippets 313 | - admonition 314 | - pymdownx.arithmatex: 315 | generic: true 316 | - footnotes 317 | - pymdownx.details 318 | - pymdownx.superfences 319 | - pymdownx.mark 320 | - attr_list 321 | - md_in_html 322 | - pymdownx.emoji: 323 | emoji_index: !!python/name:material.extensions.emoji.twemoji 324 | emoji_generator: !!python/name:material.extensions.emoji.to_svg 325 | 326 | copyright: | 327 | © 2025 <a href="https://github.com/stanfordnlp" target="_blank" rel="noopener">DSPy</a> ``` -------------------------------------------------------------------------------- /docs/docs/tutorials/mcp/index.md: -------------------------------------------------------------------------------- ```markdown 1 | # Tutorial: Use MCP tools in DSPy 2 | 3 | MCP, standing for Model Context Protocol, is an open protocol that standardizes how applications 4 | provide context to LLMs. Despite some development overhead, MCP offers a valuable opportunity to 5 | share tools, resources, and prompts with other developers regardless of the technical stack you are 6 | using. Likewise, you can use the tools built by other developers without rewriting code. 7 | 8 | In this guide, we will walk you through how to use MCP tools in DSPy. For demonstration purposes, 9 | we will build an airline service agent that can help users book flights and modify or cancel 10 | existing bookings. 
This will rely on an MCP server with custom tools, but it should be easy to generalize 11 | to [MCP servers built by the community](https://modelcontextprotocol.io/examples). 12 | 13 | ??? "How to run this tutorial" 14 | This tutorial cannot be run in hosted IPython notebooks like Google Colab or Databricks notebooks. 15 | To run the code, you will need to follow the guide to write code on your local device. The code 16 | is tested on macOS and should work the same way in Linux environments. 17 | 18 | ## Install Dependencies 19 | 20 | Before starting, let's install the required dependencies: 21 | 22 | ```shell 23 | pip install -U "dspy[mcp]" 24 | ``` 25 | 26 | ## MCP Server Setup 27 | 28 | Let's first set up the MCP server for the airline agent, which contains: 29 | 30 | - A set of databases 31 | - User database, storing user information. 32 | - Flight database, storing flight information. 33 | - Ticket database, storing customer tickets. 34 | - A set of tools 35 | - fetch_flight_info: get flight information for specific dates. 36 | - fetch_itinerary: get information about booked itineraries. 37 | - book_itinerary: book a flight on behalf of the user. 38 | - modify_itinerary: modify an itinerary, either through flight changes or cancellation. 39 | - get_user_info: get user information. 40 | - file_ticket: file a backlog ticket for human assistance. 41 | 42 | In your working directory, create a file `mcp_server.py`, and paste the following content into 43 | it: 44 | 45 | ```python 46 | import random 47 | import string 48 | 49 | from mcp.server.fastmcp import FastMCP 50 | from pydantic import BaseModel 51 | 52 | # Create an MCP server 53 | mcp = FastMCP("Airline Agent") 54 | 55 | 56 | class Date(BaseModel): 57 | # Somehow LLM is bad at specifying `datetime.datetime` 58 | year: int 59 | month: int 60 | day: int 61 | hour: int 62 | 63 | 64 | class UserProfile(BaseModel): 65 | user_id: str 66 | name: str 67 | email: str 68 | 69 | 70 | class Flight(BaseModel): 71 | flight_id: str 72 | date_time: Date 73 | origin: str 74 | destination: str 75 | duration: float 76 | price: float 77 | 78 | 79 | class Itinerary(BaseModel): 80 | confirmation_number: str 81 | user_profile: UserProfile 82 | flight: Flight 83 | 84 | 85 | class Ticket(BaseModel): 86 | user_request: str 87 | user_profile: UserProfile 88 | 89 | 90 | user_database = { 91 | "Adam": UserProfile(user_id="1", name="Adam", email="[email protected]"), 92 | "Bob": UserProfile(user_id="2", name="Bob", email="[email protected]"), 93 | "Chelsie": UserProfile(user_id="3", name="Chelsie", email="[email protected]"), 94 | "David": UserProfile(user_id="4", name="David", email="[email protected]"), 95 | } 96 | 97 | flight_database = { 98 | "DA123": Flight( 99 | flight_id="DA123", 100 | origin="SFO", 101 | destination="JFK", 102 | date_time=Date(year=2025, month=9, day=1, hour=1), 103 | duration=3, 104 | price=200, 105 | ), 106 | "DA125": Flight( 107 | flight_id="DA125", 108 | origin="SFO", 109 | destination="JFK", 110 | date_time=Date(year=2025, month=9, day=1, hour=7), 111 | duration=9, 112 | price=500, 113 | ), 114 | "DA456": Flight( 115 | flight_id="DA456", 116 | origin="SFO", 117 | destination="SNA", 118 | date_time=Date(year=2025, month=10, day=1, hour=1), 119 | duration=2, 120 | price=100, 121 | ), 122 | "DA460": Flight( 123 | flight_id="DA460", 124 | origin="SFO", 125 | destination="SNA", 126 | date_time=Date(year=2025, month=10, day=1, hour=9), 127 | duration=2, 128 | price=120, 129 | ), 130 | } 131 | 132 | itinery_database = {} 133 | 
ticket_database = {} 134 | 135 | 136 | @mcp.tool() 137 | def fetch_flight_info(date: Date, origin: str, destination: str): 138 | """Fetch flight information from origin to destination on the given date""" 139 | flights = [] 140 | 141 | for flight_id, flight in flight_database.items(): 142 | if ( 143 | flight.date_time.year == date.year 144 | and flight.date_time.month == date.month 145 | and flight.date_time.day == date.day 146 | and flight.origin == origin 147 | and flight.destination == destination 148 | ): 149 | flights.append(flight) 150 | return flights 151 | 152 | 153 | @mcp.tool() 154 | def fetch_itinerary(confirmation_number: str): 155 | """Fetch a booked itinerary information from database""" 156 | return itinery_database.get(confirmation_number) 157 | 158 | 159 | @mcp.tool() 160 | def pick_flight(flights: list[Flight]): 161 | """Pick up the best flight that matches users' request.""" 162 | sorted_flights = sorted( 163 | flights, 164 | key=lambda x: ( 165 | x.get("duration") if isinstance(x, dict) else x.duration, 166 | x.get("price") if isinstance(x, dict) else x.price, 167 | ), 168 | ) 169 | return sorted_flights[0] 170 | 171 | 172 | def generate_id(length=8): 173 | chars = string.ascii_lowercase + string.digits 174 | return "".join(random.choices(chars, k=length)) 175 | 176 | 177 | @mcp.tool() 178 | def book_itinerary(flight: Flight, user_profile: UserProfile): 179 | """Book a flight on behalf of the user.""" 180 | confirmation_number = generate_id() 181 | while confirmation_number in itinery_database: 182 | confirmation_number = generate_id() 183 | itinery_database[confirmation_number] = Itinerary( 184 | confirmation_number=confirmation_number, 185 | user_profile=user_profile, 186 | flight=flight, 187 | ) 188 | return confirmation_number, itinery_database[confirmation_number] 189 | 190 | 191 | @mcp.tool() 192 | def cancel_itinerary(confirmation_number: str, user_profile: UserProfile): 193 | """Cancel an itinerary on behalf of the user.""" 194 | if confirmation_number in itinery_database: 195 | del itinery_database[confirmation_number] 196 | return 197 | raise ValueError("Cannot find the itinerary, please check your confirmation number.") 198 | 199 | 200 | @mcp.tool() 201 | def get_user_info(name: str): 202 | """Fetch the user profile from database with given name.""" 203 | return user_database.get(name) 204 | 205 | 206 | @mcp.tool() 207 | def file_ticket(user_request: str, user_profile: UserProfile): 208 | """File a customer support ticket if this is something the agent cannot handle.""" 209 | ticket_id = generate_id(length=6) 210 | ticket_database[ticket_id] = Ticket( 211 | user_request=user_request, 212 | user_profile=user_profile, 213 | ) 214 | return ticket_id 215 | 216 | 217 | if __name__ == "__main__": 218 | mcp.run() 219 | ``` 220 | 221 | Before we start the server, let's take a look at the code. 222 | 223 | We first create a `FastMCP` instance, which is a utility that helps quickly build an MCP server: 224 | 225 | ```python 226 | mcp = FastMCP("Airline Agent") 227 | ``` 228 | 229 | Then we define our data structures, which in a real-world application would be the database schema, e.g.: 230 | 231 | ```python 232 | class Flight(BaseModel): 233 | flight_id: str 234 | date_time: Date 235 | origin: str 236 | destination: str 237 | duration: float 238 | price: float 239 | ``` 240 | 241 | Following that, we initialize our database instances. 
In a real-world application, these would be connectors to 242 | actual databases, but for simplicity, we just use dictionaries: 243 | 244 | ```python 245 | user_database = { 246 | "Adam": UserProfile(user_id="1", name="Adam", email="[email protected]"), 247 | "Bob": UserProfile(user_id="2", name="Bob", email="[email protected]"), 248 | "Chelsie": UserProfile(user_id="3", name="Chelsie", email="[email protected]"), 249 | "David": UserProfile(user_id="4", name="David", email="[email protected]"), 250 | } 251 | ``` 252 | 253 | The next step is to define the tools and mark them with `@mcp.tool()` so that they are discoverable by 254 | MCP clients as MCP tools: 255 | 256 | ```python 257 | @mcp.tool() 258 | def fetch_flight_info(date: Date, origin: str, destination: str): 259 | """Fetch flight information from origin to destination on the given date""" 260 | flights = [] 261 | 262 | for flight_id, flight in flight_database.items(): 263 | if ( 264 | flight.date_time.year == date.year 265 | and flight.date_time.month == date.month 266 | and flight.date_time.day == date.day 267 | and flight.origin == origin 268 | and flight.destination == destination 269 | ): 270 | flights.append(flight) 271 | return flights 272 | ``` 273 | 274 | The last step is spinning up the server: 275 | 276 | ```python 277 | if __name__ == "__main__": 278 | mcp.run() 279 | ``` 280 | 281 | Now we have finished writing the server! Let's launch it: 282 | 283 | ```shell 284 | python path_to_your_working_directory/mcp_server.py 285 | ``` 286 | 287 | ## Write a DSPy Program That Utilizes Tools in MCP Server 288 | 289 | Now that the server is running, let's build the actual airline service agent which 290 | utilizes the MCP tools in our server to assist users. In your working directory, 291 | create a file named `dspy_mcp_agent.py`, and follow the guide to add code to it. 292 | 293 | ### Gather Tools from MCP Servers 294 | 295 | We first need to gather all available tools from the MCP server and make them 296 | usable by DSPy. DSPy provides an API [`dspy.Tool`](https://dspy.ai/api/primitives/Tool/) 297 | as the standard tool interface. Let's convert all the MCP tools to `dspy.Tool`. 298 | 299 | We need to create an MCP client instance to communicate with the MCP server, fetch all available 300 | tools, and convert them to `dspy.Tool` using the static method `from_mcp_tool`: 301 | 302 | ```python 303 | from mcp import ClientSession, StdioServerParameters 304 | from mcp.client.stdio import stdio_client 305 | 306 | # Create server parameters for stdio connection 307 | server_params = StdioServerParameters( 308 | command="python", # Executable 309 | args=["path_to_your_working_directory/mcp_server.py"], 310 | env=None, 311 | ) 312 | 313 | async def run(): 314 | async with stdio_client(server_params) as (read, write): 315 | async with ClientSession(read, write) as session: 316 | # Initialize the connection 317 | await session.initialize() 318 | # List available tools 319 | tools = await session.list_tools() 320 | 321 | # Convert MCP tools to DSPy tools 322 | dspy_tools = [] 323 | for tool in tools.tools: 324 | dspy_tools.append(dspy.Tool.from_mcp_tool(session, tool)) 325 | 326 | print(len(dspy_tools)) 327 | print(dspy_tools[0].args) 328 | 329 | if __name__ == "__main__": 330 | import asyncio 331 | 332 | asyncio.run(run()) 333 | ``` 334 | 335 | With the code above, we have successfully collected all available MCP tools and converted 336 | them to DSPy tools. 
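
Each converted `dspy.Tool` keeps the MCP tool's name, description, and argument schema, so you can also invoke it directly. As a quick sanity check, here is a minimal sketch (it assumes the `get_user_info` tool defined in our server, and it would live inside the `run()` coroutine above, since MCP-backed tools must be awaited):

```python
# Hypothetical sanity check: call one converted tool directly via its async interface.
user_tool = next(tool for tool in dspy_tools if tool.name == "get_user_info")
profile = await user_tool.acall(name="Adam")  # forwards the call to the MCP server
print(profile)
```

If everything is wired up correctly, this should print Adam's profile as returned by the MCP server.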
337 | 338 | 339 | ### Build a DSPy Agent to Handle Customer Requests 340 | 341 | Now we will use `dspy.ReAct` to build the agent for handling customer requests. `ReAct` stands 342 | for "reasoning and acting," which asks the LLM to decide whether to call a tool or wrap up the process. 343 | If a tool is required, the LLM takes responsibility for deciding which tool to call and providing 344 | the appropriate arguments. 345 | 346 | As usual, we need to create a `dspy.Signature` to define the input and output of our agent: 347 | 348 | ```python 349 | import dspy 350 | 351 | class DSPyAirlineCustomerService(dspy.Signature): 352 | """You are an airline customer service agent. You are given a list of tools to handle user requests. You should decide the right tool to use in order to fulfill users' requests.""" 353 | 354 | user_request: str = dspy.InputField() 355 | process_result: str = dspy.OutputField( 356 | desc=( 357 | "Message that summarizes the process result, and the information users need, " 358 | "e.g., the confirmation_number if it's a flight booking request." 359 | ) 360 | ) 361 | ``` 362 | 363 | And choose an LM for our agent: 364 | 365 | ```python 366 | dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")) 367 | ``` 368 | 369 | Then we create the ReAct agent by passing the tools and signature into the `dspy.ReAct` API. We can now 370 | put together the complete code script: 371 | 372 | ```python 373 | from mcp import ClientSession, StdioServerParameters 374 | from mcp.client.stdio import stdio_client 375 | 376 | import dspy 377 | 378 | # Create server parameters for stdio connection 379 | server_params = StdioServerParameters( 380 | command="python", # Executable 381 | args=["script_tmp/mcp_server.py"], # Optional command line arguments 382 | env=None, # Optional environment variables 383 | ) 384 | 385 | 386 | class DSPyAirlineCustomerService(dspy.Signature): 387 | """You are an airline customer service agent. You are given a list of tools to handle user requests. 388 | You should decide the right tool to use in order to fulfill users' requests.""" 389 | 390 | user_request: str = dspy.InputField() 391 | process_result: str = dspy.OutputField( 392 | desc=( 393 | "Message that summarizes the process result, and the information users need, " 394 | "e.g., the confirmation_number if it's a flight booking request." 395 | ) 396 | ) 397 | 398 | 399 | dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")) 400 | 401 | 402 | async def run(user_request): 403 | async with stdio_client(server_params) as (read, write): 404 | async with ClientSession(read, write) as session: 405 | # Initialize the connection 406 | await session.initialize() 407 | # List available tools 408 | tools = await session.list_tools() 409 | 410 | # Convert MCP tools to DSPy tools 411 | dspy_tools = [] 412 | for tool in tools.tools: 413 | dspy_tools.append(dspy.Tool.from_mcp_tool(session, tool)) 414 | 415 | # Create the agent 416 | react = dspy.ReAct(DSPyAirlineCustomerService, tools=dspy_tools) 417 | 418 | result = await react.acall(user_request=user_request) 419 | print(result) 420 | 421 | 422 | if __name__ == "__main__": 423 | import asyncio 424 | 425 | asyncio.run(run("please help me book a flight from SFO to JFK on 09/01/2025, my name is Adam")) 426 | ``` 427 | 428 | Note that we must call `react.acall` because MCP tools are async by default. 
Let's execute the script: 429 | 430 | ```shell 431 | python path_to_your_working_directory/dspy_mcp_agent.py 432 | ``` 433 | 434 | You should see output similar to this: 435 | 436 | ``` 437 | Prediction( 438 | trajectory={'thought_0': 'I need to fetch flight information for Adam from SFO to JFK on 09/01/2025 to find available flights for booking.', 'tool_name_0': 'fetch_flight_info', 'tool_args_0': {'date': {'year': 2025, 'month': 9, 'day': 1, 'hour': 0}, 'origin': 'SFO', 'destination': 'JFK'}, 'observation_0': ['{"flight_id": "DA123", "date_time": {"year": 2025, "month": 9, "day": 1, "hour": 1}, "origin": "SFO", "destination": "JFK", "duration": 3.0, "price": 200.0}', '{"flight_id": "DA125", "date_time": {"year": 2025, "month": 9, "day": 1, "hour": 7}, "origin": "SFO", "destination": "JFK", "duration": 9.0, "price": 500.0}'], ..., 'tool_name_4': 'finish', 'tool_args_4': {}, 'observation_4': 'Completed.'}, 439 | reasoning="I successfully booked a flight for Adam from SFO to JFK on 09/01/2025. I found two available flights, selected the more economical option (flight DA123 at 1 AM for $200), retrieved Adam's user profile, and completed the booking process. The confirmation number for the flight is 8h7clk3q.", 440 | process_result='Your flight from SFO to JFK on 09/01/2025 has been successfully booked. Your confirmation number is 8h7clk3q.' 441 | ) 442 | ``` 443 | 444 | The `trajectory` field contains the entire thinking and acting process. If you're curious about what's happening 445 | under the hood, check out the [Observability Guide](https://dspy.ai/tutorials/observability/) to set up MLflow, 446 | which visualizes every step happening inside `dspy.ReAct`! 447 | 448 | 449 | ## Conclusion 450 | 451 | In this guide, we built an airline service agent that utilizes a custom MCP server and the `dspy.ReAct` module. In the context 452 | of MCP support, DSPy provides a simple interface for interacting with MCP tools, giving you the flexibility to implement 453 | any functionality you need. 454 | ``` -------------------------------------------------------------------------------- /docs/docs/learn/optimization/optimizers.md: -------------------------------------------------------------------------------- ```markdown 1 | --- 2 | sidebar_position: 1 3 | --- 4 | 5 | # DSPy Optimizers (formerly Teleprompters) 6 | 7 | 8 | A **DSPy optimizer** is an algorithm that can tune the parameters of a DSPy program (i.e., the prompts and/or the LM weights) to maximize the metrics you specify, like accuracy. 9 | 10 | 11 | A typical DSPy optimizer takes three things: 12 | 13 | - Your **DSPy program**. This may be a single module (e.g., `dspy.Predict`) or a complex multi-module program. 14 | 15 | - Your **metric**. This is a function that evaluates the output of your program, and assigns it a score (higher is better). 16 | 17 | - A few **training inputs**. This may be very small (i.e., only 5 or 10 examples) and incomplete (only inputs to your program, without any labels). 18 | 19 | If you happen to have a lot of data, DSPy can leverage that. But you can start small and get strong results. 20 | 21 | **Note:** Formerly called teleprompters. We are making an official name update, which will be reflected throughout the library and documentation. 22 | 23 | 24 | ## What does a DSPy Optimizer tune? How does it tune them? 
25 | 26 | Different optimizers in DSPy will tune your program's quality by **synthesizing good few-shot examples** for every module, like `dspy.BootstrapRS`,<sup>[1](https://arxiv.org/abs/2310.03714)</sup> **proposing and intelligently exploring better natural-language instructions** for every prompt, like `dspy.MIPROv2`,<sup>[2](https://arxiv.org/abs/2406.11695)</sup> and `dspy.GEPA`,<sup>[3](https://arxiv.org/abs/2507.19457)</sup> and **building datasets for your modules and using them to finetune the LM weights** in your system, like `dspy.BootstrapFinetune`.<sup>[4](https://arxiv.org/abs/2407.10930)</sup> 27 | 28 | ??? "What's an example of a DSPy optimizer? How do different optimizers work?" 29 | 30 | Take the `dspy.MIPROv2` optimizer as an example. First, MIPRO starts with the **bootstrapping stage**. It takes your program, which may be unoptimized at this point, and runs it many times across different inputs to collect traces of input/output behavior for each one of your modules. It filters these traces to keep only those that appear in trajectories scored highly by your metric. Second, MIPRO enters its **grounded proposal stage**. It previews your DSPy program's code, your data, and traces from running your program, and uses them to draft many potential instructions for every prompt in your program. Third, MIPRO launches the **discrete search stage**. It samples mini-batches from your training set, proposes a combination of instructions and traces to use for constructing every prompt in the pipeline, and evaluates the candidate program on the mini-batch. Using the resulting score, MIPRO updates a surrogate model that helps the proposals get better over time. 31 | 32 | One thing that makes DSPy optimizers so powerful is that they can be composed. You can run `dspy.MIPROv2` and use the produced program as an input to `dspy.MIPROv2` again or, say, to `dspy.BootstrapFinetune` to get better results. This is partly the essence of `dspy.BetterTogether`. Alternatively, you can run the optimizer and then extract the top-5 candidate programs and build a `dspy.Ensemble` of them. This allows you to scale _inference-time compute_ (e.g., ensembles) as well as DSPy's unique _pre-inference time compute_ (i.e., optimization budget) in highly systematic ways. 33 | 34 | 35 | 36 | ## What DSPy Optimizers are currently available? 37 | 38 | Optimizers can be accessed via `from dspy.teleprompt import *`. 39 | 40 | ### Automatic Few-Shot Learning 41 | 42 | These optimizers extend the signature by automatically generating and including **optimized** examples within the prompt sent to the model, implementing few-shot learning. 43 | 44 | 1. [**`LabeledFewShot`**](../../api/optimizers/LabeledFewShot.md): Simply constructs few-shot examples (demos) from provided labeled input and output data points. Requires `k` (number of examples for the prompt) and `trainset` to randomly select `k` examples from. 45 | 46 | 2. [**`BootstrapFewShot`**](../../api/optimizers/BootstrapFewShot.md): Uses a `teacher` module (which defaults to your program) to generate complete demonstrations for every stage of your program, along with labeled examples in `trainset`. Parameters include `max_labeled_demos` (the number of demonstrations randomly selected from the `trainset`) and `max_bootstrapped_demos` (the number of additional examples generated by the `teacher`). The bootstrapping process employs the metric to validate demonstrations, including only those that pass the metric in the "compiled" prompt. 
Advanced: Supports using a `teacher` program that is a *different* DSPy program with a compatible structure, for harder tasks. 47 | 48 | 3. [**`BootstrapFewShotWithRandomSearch`**](../../api/optimizers/BootstrapFewShotWithRandomSearch.md): Applies `BootstrapFewShot` several times with random search over generated demonstrations, and selects the best program over the optimization. Parameters mirror those of `BootstrapFewShot`, with the addition of `num_candidate_programs`, which specifies the number of random programs evaluated over the optimization, including candidates of the uncompiled program, the `LabeledFewShot`-optimized program, the `BootstrapFewShot`-compiled program with unshuffled examples, and `num_candidate_programs` of `BootstrapFewShot`-compiled programs with randomized example sets. 49 | 50 | 4. [**`KNNFewShot`**](../../api/optimizers/KNNFewShot.md): Uses the k-Nearest Neighbors algorithm to find the nearest training example demonstrations for a given input example. These nearest-neighbor demonstrations are then used as the trainset for the `BootstrapFewShot` optimization process. 51 | 52 | 53 | ### Automatic Instruction Optimization 54 | 55 | These optimizers produce optimal instructions for the prompt and, in the case of MIPROv2, can also optimize the set of few-shot demonstrations. 56 | 57 | 5. [**`COPRO`**](../../api/optimizers/COPRO.md): Generates and refines new instructions for each step, and optimizes them with coordinate ascent (hill-climbing using the metric function and the `trainset`). Parameters include `depth`, the number of prompt-improvement iterations the optimizer runs. 58 | 59 | 6. [**`MIPROv2`**](../../api/optimizers/MIPROv2.md): Generates instructions *and* few-shot examples in each step. The instruction generation is data-aware and demonstration-aware. Uses Bayesian Optimization to search effectively over the space of proposed instructions and demonstrations across your modules. 60 | 61 | 7. [**`SIMBA`**](../../api/optimizers/SIMBA.md): Stochastic Introspective Mini-Batch Ascent. Samples mini-batches from your training set and uses the LM to introspect over the resulting trajectories, appending new rules or demonstrations to the modules that need them. 62 | 63 | 8. [**`GEPA`**](../../api/optimizers/GEPA/overview.md): Uses LMs to reflect on the DSPy program's trajectory, to identify what worked and what didn't, and to propose prompts addressing the gaps. Additionally, GEPA can leverage domain-specific textual feedback to rapidly improve the DSPy program. Detailed tutorials on using GEPA are available at [dspy.GEPA Tutorials](../../tutorials/gepa_ai_program/index.md). 64 | 65 | ### Automatic Finetuning 66 | 67 | This optimizer is used to fine-tune the underlying LLM(s). 68 | 69 | 9. [**`BootstrapFinetune`**](../../api/optimizers/BootstrapFinetune.md): Distills a prompt-based DSPy program into weight updates. The output is a DSPy program that has the same steps, but where each step is conducted by a finetuned model instead of a prompted LM. [See the classification fine-tuning tutorial](https://dspy.ai/tutorials/classification_finetuning/) for a complete example. 70 | 71 | 72 | ### Program Transformations 73 | 74 | 10. [**`Ensemble`**](../../api/optimizers/Ensemble.md): Ensembles a set of DSPy programs and either uses the full set or randomly samples a subset into a single program. 75 | 76 | 77 | ## Which optimizer should I use? 78 | 79 | Ultimately, finding the "right" optimizer and the best configuration for your task will require experimentation. Success in DSPy is still an iterative process: getting the best performance on your task requires you to explore and iterate.
80 | 81 | That being said, here's the general guidance on getting started: 82 | 83 | - If you have **very few examples** (around 10), start with `BootstrapFewShot`. 84 | - If you have **more data** (50 examples or more), try `BootstrapFewShotWithRandomSearch`. 85 | - If you prefer to do **instruction optimization only** (i.e. you want to keep your prompt 0-shot), use `MIPROv2` [configured for 0-shot optimization](../../api/optimizers/MIPROv2.md). 86 | - If you’re willing to use more inference calls to perform **longer optimization runs** (e.g. 40 trials or more), and have enough data (e.g. 200 examples or more to prevent overfitting) then try `MIPROv2`. 87 | - If you have been able to use one of these with a large LM (e.g., 7B parameters or above) and need a very **efficient program**, finetune a small LM for your task with `BootstrapFinetune`. 88 | 89 | ## How do I use an optimizer? 90 | 91 | They all share this general interface, with some differences in the keyword arguments (hyperparameters). A full list can be found in the [API reference](../../api/optimizers/BetterTogether.md). 92 | 93 | Let's see this with the most common one, `BootstrapFewShotWithRandomSearch`. 94 | 95 | ```python 96 | from dspy.teleprompt import BootstrapFewShotWithRandomSearch 97 | 98 | # Set up the optimizer: we want to "bootstrap" (i.e., self-generate) 8-shot examples of your program's steps. 99 | # The optimizer will repeat this 10 times (plus some initial attempts) before selecting its best attempt on the devset. 100 | config = dict(max_bootstrapped_demos=4, max_labeled_demos=4, num_candidate_programs=10, num_threads=4) 101 | 102 | teleprompter = BootstrapFewShotWithRandomSearch(metric=YOUR_METRIC_HERE, **config) 103 | optimized_program = teleprompter.compile(YOUR_PROGRAM_HERE, trainset=YOUR_TRAINSET_HERE) 104 | ``` 105 | 106 | 107 | !!! info "Getting Started III: Optimizing the LM prompts or weights in DSPy programs" 108 | A typical simple optimization run costs on the order of $2 USD and takes around ten minutes, but be careful when running optimizers with very large LMs or very large datasets. 109 | Optimizer runs can cost as little as a few cents or up to tens of dollars, depending on your LM, dataset, and configuration. 110 | 111 | === "Optimizing prompts for a ReAct agent" 112 | This is a minimal but fully runnable example of setting up a `dspy.ReAct` agent that answers questions via 113 | search from Wikipedia and then optimizing it using `dspy.MIPROv2` in the cheap `light` mode on 500 114 | question-answer pairs sampled from the `HotPotQA` dataset. 115 | 116 | ```python linenums="1" 117 | import dspy 118 | from dspy.datasets import HotPotQA 119 | 120 | dspy.configure(lm=dspy.LM('openai/gpt-4o-mini')) 121 | 122 | def search(query: str) -> list[str]: 123 | """Retrieves abstracts from Wikipedia.""" 124 | results = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')(query, k=3) 125 | return [x['text'] for x in results] 126 | 127 | trainset = [x.with_inputs('question') for x in HotPotQA(train_seed=2024, train_size=500).train] 128 | react = dspy.ReAct("question -> answer", tools=[search]) 129 | 130 | tp = dspy.MIPROv2(metric=dspy.evaluate.answer_exact_match, auto="light", num_threads=24) 131 | optimized_react = tp.compile(react, trainset=trainset) 132 | ``` 133 | 134 | An informal run similar to this on DSPy 2.5.29 raises ReAct's score from 24% to 51%. 
135 | 136 | === "Optimizing prompts for RAG" 137 | Given a retrieval index to `search`, your favorite `dspy.LM`, and a small `trainset` of questions and ground-truth responses, the following code snippet can optimize your RAG system with long outputs against the built-in `dspy.SemanticF1` metric, which is implemented as a DSPy module. 138 | 139 | ```python linenums="1" 140 | class RAG(dspy.Module): 141 | def __init__(self, num_docs=5): 142 | self.num_docs = num_docs 143 | self.respond = dspy.ChainOfThought('context, question -> response') 144 | 145 | def forward(self, question): 146 | context = search(question, k=self.num_docs) # not defined in this snippet, see link above 147 | return self.respond(context=context, question=question) 148 | 149 | tp = dspy.MIPROv2(metric=dspy.SemanticF1(), auto="medium", num_threads=24) 150 | optimized_rag = tp.compile(RAG(), trainset=trainset, max_bootstrapped_demos=2, max_labeled_demos=2) 151 | ``` 152 | 153 | For a complete RAG example that you can run, start this [tutorial](../../tutorials/rag/index.ipynb). It improves the quality of a RAG system over a subset of StackExchange communities from 53% to 61%. 154 | 155 | === "Optimizing weights for Classification" 156 | <details><summary>Click to show dataset setup code.</summary> 157 | 158 | ```python linenums="1" 159 | import random 160 | from typing import Literal 161 | 162 | from datasets import load_dataset 163 | 164 | import dspy 165 | from dspy.datasets import DataLoader 166 | 167 | # Load the Banking77 dataset. 168 | CLASSES = load_dataset("PolyAI/banking77", split="train", trust_remote_code=True).features["label"].names 169 | kwargs = {"fields": ("text", "label"), "input_keys": ("text",), "split": "train", "trust_remote_code": True} 170 | 171 | # Load the first 2000 examples from the dataset, and assign a hint to each *training* example. 172 | trainset = [ 173 | dspy.Example(x, hint=CLASSES[x.label], label=CLASSES[x.label]).with_inputs("text", "hint") 174 | for x in DataLoader().from_huggingface(dataset_name="PolyAI/banking77", **kwargs)[:2000] 175 | ] 176 | random.Random(0).shuffle(trainset) 177 | ``` 178 | </details> 179 | 180 | ```python linenums="1" 181 | import dspy 182 | lm=dspy.LM('openai/gpt-4o-mini-2024-07-18') 183 | 184 | # Define the DSPy module for classification. It will use the hint at training time, if available. 185 | signature = dspy.Signature("text, hint -> label").with_updated_fields('label', type_=Literal[tuple(CLASSES)]) 186 | classify = dspy.ChainOfThought(signature) 187 | classify.set_lm(lm) 188 | 189 | # Optimize via BootstrapFinetune. 190 | optimizer = dspy.BootstrapFinetune(metric=(lambda x, y, trace=None: x.label == y.label), num_threads=24) 191 | optimized = optimizer.compile(classify, trainset=trainset) 192 | 193 | optimized(text="What does a pending cash withdrawal mean?") 194 | 195 | # For a complete fine-tuning tutorial, see: https://dspy.ai/tutorials/classification_finetuning/ 196 | ``` 197 | 198 | **Possible Output (from the last line):** 199 | ```text 200 | Prediction( 201 | reasoning='A pending cash withdrawal indicates that a request to withdraw cash has been initiated but has not yet been completed or processed. This status means that the transaction is still in progress and the funds have not yet been deducted from the account or made available to the user.', 202 | label='pending_cash_withdrawal' 203 | ) 204 | ``` 205 | 206 | An informal run similar to this on DSPy 2.5.29 raises GPT-4o-mini's score 66% to 87%. 
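
If you prefer to keep your program zero-shot while still optimizing its instructions, as suggested in the guidance above, the same interface applies. Here is a minimal sketch (where `program`, `trainset`, and `YOUR_METRIC_HERE` stand in for your own module, data, and metric, as in the snippets above); the key is setting both demo counts to zero at compile time:

```python
# Instruction-only (0-shot) optimization with MIPROv2: no few-shot demos are added to the prompt.
tp = dspy.MIPROv2(metric=YOUR_METRIC_HERE, auto="light", num_threads=24)
zeroshot_program = tp.compile(
    program,
    trainset=trainset,
    max_bootstrapped_demos=0,  # no self-generated demonstrations
    max_labeled_demos=0,       # no demonstrations taken from the labeled trainset
)
```

See the [MIPROv2 API reference](../../api/optimizers/MIPROv2.md) for the full set of options.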
207 | 208 | 209 | ## Saving and loading optimizer output 210 | 211 | After running a program through an optimizer, it's useful to also save it. At a later point, a program can be loaded from a file and used for inference. For this, the `load` and `save` methods can be used. 212 | 213 | ```python 214 | optimized_program.save(YOUR_SAVE_PATH) 215 | ``` 216 | 217 | The resulting file is in plain-text JSON format. It contains all the parameters and steps in the source program. You can always read it and see what the optimizer generated. 218 | 219 | 220 | To load a program from a file, you can instantiate an object from that class and then call the load method on it. 221 | 222 | ```python 223 | loaded_program = YOUR_PROGRAM_CLASS() 224 | loaded_program.load(path=YOUR_SAVE_PATH) 225 | ``` 226 | 227 | ``` -------------------------------------------------------------------------------- /dspy/evaluate/evaluate.py: -------------------------------------------------------------------------------- ```python 1 | import csv 2 | import importlib 3 | import json 4 | import logging 5 | import types 6 | from typing import TYPE_CHECKING, Any, Callable 7 | 8 | if TYPE_CHECKING: 9 | import pandas as pd 10 | 11 | import tqdm 12 | 13 | import dspy 14 | from dspy.primitives.prediction import Prediction 15 | from dspy.utils.callback import with_callbacks 16 | from dspy.utils.parallelizer import ParallelExecutor 17 | 18 | try: 19 | from IPython.display import HTML 20 | from IPython.display import display as display 21 | 22 | except ImportError: 23 | 24 | def display(obj: Any): 25 | """ 26 | Display the specified Python object in the console. 27 | 28 | :param obj: The Python object to display. 29 | """ 30 | print(obj) 31 | 32 | def HTML(x: str) -> str: # noqa: N802 33 | """ 34 | Obtain the HTML representation of the specified string. 35 | """ 36 | # NB: This method exists purely for code compatibility with the IPython HTML() function in 37 | # environments where IPython is not available. In such environments where IPython is not 38 | # available, this method will simply return the input string. 39 | return x 40 | 41 | 42 | # TODO: Counting failures and having a max_failure count. When that is exceeded (also just at the end), 43 | # we print the number of failures, the first N examples that failed, and the first N exceptions raised. 44 | 45 | logger = logging.getLogger(__name__) 46 | 47 | 48 | class EvaluationResult(Prediction): 49 | """ 50 | A class that represents the result of an evaluation. 51 | It is a subclass of `dspy.Prediction` that contains the following fields 52 | 53 | - score: An float value (e.g., 67.30) representing the overall performance 54 | - results: a list of (example, prediction, score) tuples for each example in devset 55 | """ 56 | 57 | def __init__(self, score: float, results: list[tuple["dspy.Example", "dspy.Example", Any]]): 58 | super().__init__(score=score, results=results) 59 | 60 | def __repr__(self): 61 | return f"EvaluationResult(score={self.score}, results=<list of {len(self.results)} results>)" 62 | 63 | 64 | class Evaluate: 65 | """DSPy Evaluate class. 66 | 67 | This class is used to evaluate the performance of a DSPy program. Users need to provide a evaluation dataset and 68 | a metric function in order to use this class. This class supports parallel evaluation on the provided dataset. 
69 | """ 70 | 71 | def __init__( 72 | self, 73 | *, 74 | devset: list["dspy.Example"], 75 | metric: Callable | None = None, 76 | num_threads: int | None = None, 77 | display_progress: bool = False, 78 | display_table: bool | int = False, 79 | max_errors: int | None = None, 80 | provide_traceback: bool | None = None, 81 | failure_score: float = 0.0, 82 | save_as_csv: str | None = None, 83 | save_as_json: str | None = None, 84 | **kwargs, 85 | ): 86 | """ 87 | Args: 88 | devset (list[dspy.Example]): the evaluation dataset. 89 | metric (Callable): The metric function to use for evaluation. 90 | num_threads (Optional[int]): The number of threads to use for parallel evaluation. 91 | display_progress (bool): Whether to display progress during evaluation. 92 | display_table (Union[bool, int]): Whether to display the evaluation results in a table. 93 | If a number is passed, the evaluation results will be truncated to that number before displayed. 94 | max_errors (Optional[int]): The maximum number of errors to allow before 95 | stopping evaluation. If ``None``, inherits from ``dspy.settings.max_errors``. 96 | provide_traceback (Optional[bool]): Whether to provide traceback information during evaluation. 97 | failure_score (float): The default score to use if evaluation fails due to an exception. 98 | save_as_csv (Optional[str]): The file name where the csv will be saved. 99 | save_as_json (Optional[str]): The file name where the json will be saved. 100 | 101 | """ 102 | self.devset = devset 103 | self.metric = metric 104 | self.num_threads = num_threads 105 | self.display_progress = display_progress 106 | self.display_table = display_table 107 | self.max_errors = max_errors 108 | self.provide_traceback = provide_traceback 109 | self.failure_score = failure_score 110 | self.save_as_csv = save_as_csv 111 | self.save_as_json = save_as_json 112 | 113 | if "return_outputs" in kwargs: 114 | raise ValueError("`return_outputs` is no longer supported. Results are always returned inside the `results` field of the `EvaluationResult` object.") 115 | 116 | @with_callbacks 117 | def __call__( 118 | self, 119 | program: "dspy.Module", 120 | metric: Callable | None = None, 121 | devset: list["dspy.Example"] | None = None, 122 | num_threads: int | None = None, 123 | display_progress: bool | None = None, 124 | display_table: bool | int | None = None, 125 | callback_metadata: dict[str, Any] | None = None, 126 | save_as_csv: str | None = None, 127 | save_as_json: str | None = None, 128 | ) -> EvaluationResult: 129 | """ 130 | Args: 131 | program (dspy.Module): The DSPy program to evaluate. 132 | metric (Callable): The metric function to use for evaluation. if not provided, use `self.metric`. 133 | devset (list[dspy.Example]): the evaluation dataset. if not provided, use `self.devset`. 134 | num_threads (Optional[int]): The number of threads to use for parallel evaluation. if not provided, use 135 | `self.num_threads`. 136 | display_progress (bool): Whether to display progress during evaluation. if not provided, use 137 | `self.display_progress`. 138 | display_table (Union[bool, int]): Whether to display the evaluation results in a table. if not provided, use 139 | `self.display_table`. If a number is passed, the evaluation results will be truncated to that number before displayed. 140 | callback_metadata (dict): Metadata to be used for evaluate callback handlers. 
141 | 142 | Returns: 143 | The evaluation results are returned as a dspy.EvaluationResult object containing the following attributes: 144 | 145 | - score: A float percentage score (e.g., 67.30) representing overall performance 146 | 147 | - results: a list of (example, prediction, score) tuples for each example in devset 148 | """ 149 | metric = metric if metric is not None else self.metric 150 | devset = devset if devset is not None else self.devset 151 | num_threads = num_threads if num_threads is not None else self.num_threads 152 | display_progress = display_progress if display_progress is not None else self.display_progress 153 | display_table = display_table if display_table is not None else self.display_table 154 | save_as_csv = save_as_csv if save_as_csv is not None else self.save_as_csv 155 | save_as_json = save_as_json if save_as_json is not None else self.save_as_json 156 | 157 | if callback_metadata: 158 | logger.debug(f"Evaluate is called with callback metadata: {callback_metadata}") 159 | 160 | tqdm.tqdm._instances.clear() 161 | 162 | executor = ParallelExecutor( 163 | num_threads=num_threads, 164 | disable_progress_bar=not display_progress, 165 | max_errors=(self.max_errors if self.max_errors is not None else dspy.settings.max_errors), 166 | provide_traceback=self.provide_traceback, 167 | compare_results=True, 168 | ) 169 | 170 | def process_item(example): 171 | prediction = program(**example.inputs()) 172 | score = metric(example, prediction) 173 | return prediction, score 174 | 175 | results = executor.execute(process_item, devset) 176 | assert len(devset) == len(results) 177 | 178 | results = [((dspy.Prediction(), self.failure_score) if r is None else r) for r in results] 179 | results = [(example, prediction, score) for example, (prediction, score) in zip(devset, results, strict=False)] 180 | ncorrect, ntotal = sum(score for *_, score in results), len(devset) 181 | 182 | logger.info(f"Average Metric: {ncorrect} / {ntotal} ({round(100 * ncorrect / ntotal, 1)}%)") 183 | 184 | if display_table: 185 | if importlib.util.find_spec("pandas") is not None: 186 | # Rename the 'correct' column to the name of the metric object 187 | metric_name = metric.__name__ if isinstance(metric, types.FunctionType) else metric.__class__.__name__ 188 | # Construct a pandas DataFrame from the results 189 | result_df = self._construct_result_table(results, metric_name) 190 | 191 | self._display_result_table(result_df, display_table, metric_name) 192 | else: 193 | logger.warning("Skipping table display since `pandas` is not installed.") 194 | 195 | if save_as_csv: 196 | metric_name = ( 197 | metric.__name__ 198 | if isinstance(metric, types.FunctionType) 199 | else metric.__class__.__name__ 200 | ) 201 | data = self._prepare_results_output(results, metric_name) 202 | 203 | with open(save_as_csv, "w", newline="") as csvfile: 204 | fieldnames = data[0].keys() 205 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 206 | 207 | writer.writeheader() 208 | for row in data: 209 | writer.writerow(row) 210 | if save_as_json: 211 | metric_name = ( 212 | metric.__name__ 213 | if isinstance(metric, types.FunctionType) 214 | else metric.__class__.__name__ 215 | ) 216 | data = self._prepare_results_output(results, metric_name) 217 | with open( 218 | save_as_json, 219 | "w", 220 | ) as f: 221 | json.dump(data, f) 222 | 223 | return EvaluationResult( 224 | score=round(100 * ncorrect / ntotal, 2), 225 | results=results, 226 | ) 227 | 228 | @staticmethod 229 | def _prepare_results_output( 230 | results: 
list[tuple["dspy.Example", "dspy.Example", Any]], metric_name: str 231 | ): 232 | return [ 233 | ( 234 | merge_dicts(example, prediction) | {metric_name: score} 235 | if prediction_is_dictlike(prediction) 236 | else dict(example) | {"prediction": prediction, metric_name: score} 237 | ) 238 | for example, prediction, score in results 239 | ] 240 | 241 | def _construct_result_table( 242 | self, results: list[tuple["dspy.Example", "dspy.Example", Any]], metric_name: str 243 | ) -> "pd.DataFrame": 244 | """ 245 | Construct a pandas DataFrame from the specified result list. 246 | Let's not try to change the name of this method as it may be patched by external tracing tools. 247 | 248 | Args: 249 | results: The list of results to construct the result DataFrame from. 250 | metric_name: The name of the metric used for evaluation. 251 | 252 | Returns: 253 | The constructed pandas DataFrame. 254 | """ 255 | import pandas as pd 256 | 257 | data = self._prepare_results_output(results, metric_name) 258 | 259 | # Truncate every cell in the DataFrame (DataFrame.applymap was renamed to DataFrame.map in Pandas 2.1.0) 260 | result_df = pd.DataFrame(data) 261 | result_df = result_df.map(truncate_cell) if hasattr(result_df, "map") else result_df.applymap(truncate_cell) 262 | 263 | return result_df.rename(columns={"correct": metric_name}) 264 | 265 | def _display_result_table(self, result_df: "pd.DataFrame", display_table: bool | int, metric_name: str): 266 | """ 267 | Display the specified result DataFrame in a table format. 268 | 269 | Args: 270 | result_df: The result DataFrame to display. 271 | display_table: Whether to display the evaluation results in a table. 272 | If a number is passed, the evaluation results will be truncated to that number before displayed. 273 | metric_name: The name of the metric used for evaluation. 274 | """ 275 | if isinstance(display_table, bool): 276 | df_to_display = result_df.copy() 277 | truncated_rows = 0 278 | else: 279 | df_to_display = result_df.head(display_table).copy() 280 | truncated_rows = len(result_df) - display_table 281 | 282 | df_to_display = stylize_metric_name(df_to_display, metric_name) 283 | 284 | display_dataframe(df_to_display) 285 | 286 | if truncated_rows > 0: 287 | # Simplified message about the truncated rows 288 | message = f""" 289 | <div style=' 290 | text-align: center; 291 | font-size: 16px; 292 | font-weight: bold; 293 | color: #555; 294 | margin: 10px 0;'> 295 | ... {truncated_rows} more rows not displayed ... 296 | </div> 297 | """ 298 | display(HTML(message)) 299 | 300 | 301 | def prediction_is_dictlike(prediction): 302 | # Downstream logic for displaying dictionary-like predictions depends solely on the predictions 303 | # having a method called `items()` for iterating through key/value pairs 304 | return hasattr(prediction, "items") and callable(prediction.items) 305 | 306 | 307 | def merge_dicts(d1, d2) -> dict: 308 | merged = {} 309 | for k, v in d1.items(): 310 | if k in d2: 311 | merged[f"example_{k}"] = v 312 | else: 313 | merged[k] = v 314 | 315 | for k, v in d2.items(): 316 | if k in d1: 317 | merged[f"pred_{k}"] = v 318 | else: 319 | merged[k] = v 320 | 321 | return merged 322 | 323 | 324 | def truncate_cell(content) -> str: 325 | """Truncate content of a cell to 25 words.""" 326 | words = str(content).split() 327 | if len(words) > 25: 328 | return " ".join(words[:25]) + "..." 
329 | return content 330 | 331 | 332 | def stylize_metric_name(df: "pd.DataFrame", metric_name: str) -> "pd.DataFrame": 333 | """ 334 | Stylize the cell contents of a pandas DataFrame corresponding to the specified metric name. 335 | 336 | :param df: The pandas DataFrame for which to stylize cell contents. 337 | :param metric_name: The name of the metric for which to stylize DataFrame cell contents. 338 | """ 339 | def format_metric(x): 340 | if isinstance(x, float): 341 | return f"✔️ [{x:.3f}]" 342 | elif x is not None: 343 | return f"✔️ [{x}]" 344 | else: 345 | return "" 346 | df[metric_name] = df[metric_name].apply(format_metric) 347 | return df 348 | 349 | 350 | def display_dataframe(df: "pd.DataFrame"): 351 | """ 352 | Display the specified Pandas DataFrame in the console. 353 | 354 | :param df: The Pandas DataFrame to display. 355 | """ 356 | import pandas as pd 357 | 358 | if is_in_ipython_notebook_environment(): 359 | display(configure_dataframe_for_ipython_notebook_display(df)) 360 | else: 361 | # Pretty print the DataFrame to the console 362 | with pd.option_context( 363 | "display.max_rows", None, "display.max_columns", None 364 | ): # more options can be specified also 365 | print(df) 366 | 367 | 368 | def configure_dataframe_for_ipython_notebook_display(df: "pd.DataFrame") -> "pd.DataFrame": 369 | """Set various pandas display options for DataFrame in an IPython notebook environment.""" 370 | import pandas as pd 371 | 372 | pd.options.display.max_colwidth = 70 373 | return df 374 | 375 | 376 | def is_in_ipython_notebook_environment(): 377 | """ 378 | Check if the current environment is an IPython notebook environment. 379 | 380 | :return: True if the current environment is an IPython notebook environment, False otherwise. 381 | """ 382 | try: 383 | from IPython import get_ipython 384 | 385 | # This is a best-effort check to see if we are in an IPython notebook environment 386 | return "IPKernelApp" in getattr(get_ipython(), "config", {}) 387 | except ImportError: 388 | return False 389 | 390 | 391 | # FIXME: TODO: The merge_dicts stuff above is way too quick and dirty. 392 | # TODO: the display_table can't handle False but can handle 0! 393 | # Not sure how it works with True exactly, probably fails too. 394 | ``` -------------------------------------------------------------------------------- /dspy/teleprompt/utils.py: -------------------------------------------------------------------------------- ```python 1 | import inspect 2 | import logging 3 | import math 4 | import os 5 | import random 6 | import shutil 7 | import sys 8 | 9 | import numpy as np 10 | 11 | try: 12 | from IPython.core.magics.code import extract_symbols 13 | except ImportError: 14 | # Won't be able to read code from jupyter notebooks 15 | extract_symbols = None 16 | 17 | import dspy 18 | from dspy.teleprompt.bootstrap import BootstrapFewShot, LabeledFewShot 19 | 20 | """ 21 | This file consists of helper functions for our variety of optimizers. 
22 | """ 23 | 24 | ### OPTIMIZER TRAINING UTILS ### 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | def create_minibatch(trainset, batch_size=50, rng=None): 30 | """Create a minibatch from the trainset.""" 31 | 32 | # Ensure batch_size isn't larger than the size of the dataset 33 | batch_size = min(batch_size, len(trainset)) 34 | 35 | # If no RNG is provided, fall back to the global random instance 36 | rng = rng or random 37 | 38 | # Randomly sample indices for the mini-batch using the provided rng 39 | sampled_indices = rng.sample(range(len(trainset)), batch_size) 40 | 41 | # Create the mini-batch using the sampled indices 42 | minibatch = [trainset[i] for i in sampled_indices] 43 | 44 | return minibatch 45 | 46 | 47 | def eval_candidate_program(batch_size, trainset, candidate_program, evaluate, rng=None): 48 | """Evaluate a candidate program on the trainset, using the specified batch size.""" 49 | 50 | try: 51 | # Evaluate on the full trainset 52 | if batch_size >= len(trainset): 53 | return evaluate(candidate_program, devset=trainset, callback_metadata={"metric_key": "eval_full"}) 54 | # Or evaluate on a minibatch 55 | else: 56 | return evaluate( 57 | candidate_program, 58 | devset=create_minibatch(trainset, batch_size, rng), 59 | callback_metadata={"metric_key": "eval_minibatch"} 60 | ) 61 | except Exception: 62 | logger.error("An exception occurred during evaluation", exc_info=True) 63 | # TODO: Handle this better, as -ve scores are possible 64 | return dspy.Prediction(score=0.0, results=[]) 65 | 66 | 67 | def eval_candidate_program_with_pruning( 68 | trial, 69 | trial_logs, 70 | trainset, 71 | candidate_program, 72 | evaluate, 73 | trial_num, 74 | batch_size=100, 75 | ): 76 | """Evaluation of candidate_program with pruning implemented""" 77 | 78 | # Evaluate with the new prompts 79 | total_score = 0 80 | num_batches = math.ceil(len(trainset) / batch_size) 81 | total_eval_size = 0 82 | 83 | for i in range(num_batches): 84 | start_index = i * batch_size 85 | end_index = min((i + 1) * batch_size, len(trainset)) 86 | split_trainset = trainset[start_index:end_index] 87 | split_score = evaluate( 88 | candidate_program, 89 | devset=split_trainset, 90 | display_table=0, 91 | ) 92 | print(f"{i}st split score: {split_score}") 93 | total_eval_size += len(split_trainset) 94 | 95 | total_score += split_score * len(split_trainset) 96 | curr_weighted_avg_score = total_score / min((i + 1) * batch_size, len(trainset)) 97 | print(f"curr average score: {curr_weighted_avg_score}") 98 | 99 | trial.report(curr_weighted_avg_score, i) 100 | 101 | # Handle pruning based on the intermediate value. 102 | if trial.should_prune(): 103 | print("Trial pruned.") 104 | trial_logs[trial_num]["score"] = curr_weighted_avg_score 105 | trial_logs[trial_num]["num_eval_calls"] = total_eval_size 106 | trial_logs[trial_num]["pruned"] = True 107 | return curr_weighted_avg_score, trial_logs, total_eval_size, True 108 | 109 | print(f"Fully evaled score: {curr_weighted_avg_score}") 110 | score = curr_weighted_avg_score 111 | 112 | trial_logs[trial_num]["full_eval"] = False 113 | trial_logs[trial_num]["score"] = score 114 | trial_logs[trial_num]["pruned"] = False 115 | return score, trial_logs, total_eval_size, False 116 | 117 | 118 | def get_program_with_highest_avg_score(param_score_dict, fully_evaled_param_combos): 119 | """Used as a helper function for bayesian + minibatching optimizers. 
Returns the program with the highest average score from the batches evaluated so far.""" 120 | 121 | # Calculate the mean for each combination of categorical parameters, based on past trials 122 | results = [] 123 | for key, values in param_score_dict.items(): 124 | scores = np.array([v[0] for v in values]) 125 | mean = np.average(scores) 126 | program = values[0][1] 127 | params = values[0][2] 128 | results.append((key, mean, program, params)) 129 | 130 | # Sort results by the mean 131 | sorted_results = sorted(results, key=lambda x: x[1], reverse=True) 132 | 133 | # Find the combination with the highest mean, skip fully evaluated ones 134 | for combination in sorted_results: 135 | key, mean, program, params = combination 136 | 137 | if key in fully_evaled_param_combos: 138 | continue 139 | 140 | return program, mean, key, params 141 | 142 | # If no valid program is found, we return the last valid one that we found 143 | return program, mean, key, params 144 | 145 | 146 | def calculate_last_n_proposed_quality( 147 | base_program, 148 | trial_logs, 149 | evaluate, 150 | trainset, 151 | devset, 152 | n, 153 | ): 154 | """ 155 | Calculate the average and best quality of the last n programs proposed. This is useful for seeing if our proposals 156 | are actually 'improving' overtime or not. 157 | """ 158 | # Get the trials from the last n keys in trial logs 159 | last_n_trial_nums = list(trial_logs.keys())[-n:] 160 | 161 | # Calculate the average and best score of these trials 162 | # if num_eval_calls in the trial is less than the trainset, throw a not-implemented error for now 163 | total_train_score = 0 164 | best_train_score = 0 165 | total_dev_score = 0 166 | best_dev_score = 0 167 | for trial_num in last_n_trial_nums: 168 | full_eval = trial_logs[trial_num]["full_eval"] 169 | if not full_eval: 170 | raise NotImplementedError( 171 | "Still need to implement non full eval handling in calculate_last_n_proposed_quality", 172 | ) 173 | train_score = trial_logs[trial_num]["score"] 174 | program = base_program.deepcopy() 175 | program.load(trial_logs[trial_num]["program_path"]) 176 | 177 | dev_score = evaluate(program, devset=devset) 178 | 179 | total_train_score += train_score 180 | total_dev_score += dev_score 181 | if train_score > best_train_score: 182 | best_train_score = train_score 183 | best_dev_score = dev_score 184 | 185 | return best_train_score, total_train_score / n, best_dev_score, total_dev_score / n 186 | 187 | 188 | ### LOGGING UTILS ### 189 | 190 | 191 | def get_task_model_history_for_full_example( 192 | candidate_program, 193 | task_model, 194 | devset, 195 | evaluate, 196 | ): 197 | """Get a full trace of the task model's history for a given candidate program.""" 198 | _ = evaluate(candidate_program, devset=devset[:1]) 199 | _ = task_model.inspect_history(n=len(candidate_program.predictors())) 200 | return task_model.inspect_history(n=len(candidate_program.predictors())) 201 | 202 | 203 | def print_full_program(program): 204 | """Print out the program's instructions & prefixes for each module.""" 205 | for i, predictor in enumerate(program.predictors()): 206 | print(f"Predictor {i}") 207 | print(f"i: {get_signature(predictor).instructions}") 208 | *_, last_field = get_signature(predictor).fields.values() 209 | print(f"p: {last_field.json_schema_extra['prefix']}") 210 | print("\n") 211 | 212 | 213 | def save_candidate_program(program, log_dir, trial_num, note=None): 214 | """Save the candidate program to the log directory.""" 215 | 216 | if log_dir is None: 217 | return None 
218 | 219 | # Ensure the directory exists 220 | eval_programs_dir = os.path.join(log_dir, "evaluated_programs") 221 | os.makedirs(eval_programs_dir, exist_ok=True) 222 | 223 | # Define the save path for the program 224 | if note: 225 | save_path = os.path.join(eval_programs_dir, f"program_{trial_num}_{note}.json") 226 | else: 227 | save_path = os.path.join(eval_programs_dir, f"program_{trial_num}.json") 228 | 229 | # Save the program 230 | program.save(save_path) 231 | 232 | return save_path 233 | 234 | 235 | def save_file_to_log_dir(source_file_path, log_dir): 236 | if log_dir is None: 237 | return 238 | """Save a file to our log directory""" 239 | if not os.path.exists(log_dir): 240 | os.makedirs(log_dir) 241 | destination_file_path = os.path.join(log_dir, os.path.basename(source_file_path)) 242 | 243 | # Copy the file 244 | shutil.copy(source_file_path, destination_file_path) 245 | 246 | 247 | def setup_logging(log_dir): 248 | """Setup logger, which will log our print statements to a txt file at our log_dir for later viewing""" 249 | if log_dir is None: 250 | return 251 | # Create a logger 252 | logger = logging.getLogger() 253 | logger.setLevel(logging.WARNING) 254 | 255 | # Create a file handler that logs debug and higher level messages 256 | file_handler = logging.FileHandler(f"{log_dir}/logs.txt") 257 | file_handler.setLevel(logging.WARNING) 258 | file_formatter = logging.Formatter("%(asctime)s - %(message)s") 259 | file_handler.setFormatter(file_formatter) 260 | logger.addHandler(file_handler) 261 | 262 | # Create a console handler with a higher log level 263 | console_handler = logging.StreamHandler() 264 | console_handler.setLevel(logging.WARNING) 265 | console_formatter = logging.Formatter("%(message)s") 266 | console_handler.setFormatter(console_formatter) 267 | logger.addHandler(console_handler) 268 | 269 | 270 | def get_token_usage(model) -> tuple[int, int]: 271 | """ 272 | Extract total input tokens and output tokens from a model's interaction history. 273 | Returns (total_input_tokens, total_output_tokens). 274 | """ 275 | if not hasattr(model, "history"): 276 | return 0, 0 277 | 278 | input_tokens = [] 279 | output_tokens = [] 280 | for interaction in model.history: 281 | usage = interaction.get("usage", {}) 282 | _input_tokens = usage.get("prompt_tokens", 0) 283 | _output_tokens = usage.get("completion_tokens", 0) 284 | input_tokens.append(_input_tokens) 285 | output_tokens.append(_output_tokens) 286 | 287 | total_input_tokens = int(np.sum(input_tokens)) 288 | total_output_tokens = int(np.sum(output_tokens)) 289 | 290 | return total_input_tokens, total_output_tokens 291 | 292 | 293 | def log_token_usage(trial_logs, trial_num, model_dict): 294 | """ 295 | Extract total input and output tokens used by each model and log to trial_logs[trial_num]["token_usage"]. 
296 | """ 297 | 298 | token_usage_dict = {} 299 | 300 | for model_name, model in model_dict.items(): 301 | in_tokens, out_tokens = get_token_usage(model) 302 | token_usage_dict[model_name] = {"total_input_tokens": in_tokens, "total_output_tokens": out_tokens} 303 | 304 | # Store token usage info in trial logs 305 | trial_logs[trial_num]["token_usage"] = token_usage_dict 306 | 307 | 308 | ### OTHER UTILS ### 309 | 310 | 311 | def get_prompt_model(prompt_model): 312 | if prompt_model: 313 | return prompt_model 314 | else: 315 | return dspy.settings.lm 316 | 317 | 318 | def get_signature(predictor): 319 | assert hasattr(predictor, "signature") 320 | return predictor.signature 321 | 322 | 323 | def set_signature(predictor, updated_signature): 324 | assert hasattr(predictor, "signature") 325 | predictor.signature = updated_signature 326 | 327 | 328 | def create_n_fewshot_demo_sets( 329 | student, 330 | num_candidate_sets, 331 | trainset, 332 | max_labeled_demos, 333 | max_bootstrapped_demos, 334 | metric, 335 | teacher_settings, 336 | max_errors=None, 337 | max_rounds=1, 338 | labeled_sample=True, 339 | min_num_samples=1, 340 | metric_threshold=None, 341 | teacher=None, 342 | include_non_bootstrapped=True, 343 | seed=0, 344 | rng=None, 345 | ): 346 | """ 347 | This function is copied from random_search.py, and creates fewshot examples in the same way that random search does. 348 | This allows us to take advantage of using the same fewshot examples when we use the same random seed in our optimizers. 349 | """ 350 | max_errors = dspy.settings.max_errors if max_errors is None else max_errors 351 | demo_candidates = {} 352 | 353 | # Account for confusing way this is set up, where we add in 3 more candidate sets to the N specified 354 | num_candidate_sets -= 3 355 | 356 | # Initialize demo_candidates dictionary 357 | for i, _ in enumerate(student.predictors()): 358 | demo_candidates[i] = [] 359 | 360 | rng = rng or random.Random(seed) 361 | 362 | # Go through and create each candidate set 363 | for seed in range(-3, num_candidate_sets): 364 | print(f"Bootstrapping set {seed + 4}/{num_candidate_sets + 3}") 365 | 366 | trainset_copy = list(trainset) 367 | 368 | if seed == -3 and include_non_bootstrapped: 369 | # zero-shot 370 | program2 = student.reset_copy() 371 | 372 | elif seed == -2 and max_labeled_demos > 0 and include_non_bootstrapped: 373 | # labels only 374 | teleprompter = LabeledFewShot(k=max_labeled_demos) 375 | program2 = teleprompter.compile( 376 | student, 377 | trainset=trainset_copy, 378 | sample=labeled_sample, 379 | ) 380 | 381 | elif seed == -1: 382 | # unshuffled few-shot 383 | program = BootstrapFewShot( 384 | metric=metric, 385 | max_errors=max_errors, 386 | max_bootstrapped_demos=max_bootstrapped_demos, 387 | max_labeled_demos=max_labeled_demos, 388 | teacher_settings=teacher_settings, 389 | max_rounds=max_rounds, 390 | ) 391 | program2 = program.compile(student, teacher=teacher, trainset=trainset_copy) 392 | 393 | else: 394 | # shuffled few-shot 395 | rng.shuffle(trainset_copy) 396 | size = rng.randint(min_num_samples, max_bootstrapped_demos) 397 | 398 | teleprompter = BootstrapFewShot( 399 | metric=metric, 400 | max_errors=max_errors, 401 | metric_threshold=metric_threshold, 402 | max_bootstrapped_demos=size, 403 | max_labeled_demos=max_labeled_demos, 404 | teacher_settings=teacher_settings, 405 | max_rounds=max_rounds, 406 | ) 407 | 408 | program2 = teleprompter.compile( 409 | student, 410 | teacher=teacher, 411 | trainset=trainset_copy, 412 | ) 413 | 414 | for i, _ in 
enumerate(student.predictors()): 415 | demo_candidates[i].append(program2.predictors()[i].demos) 416 | 417 | return demo_candidates 418 | 419 | 420 | def old_getfile(object): 421 | """Work out which source or compiled file an object was defined in.""" 422 | if inspect.ismodule(object): 423 | if getattr(object, "__file__", None): 424 | return object.__file__ 425 | raise TypeError(f"{object!r} is a built-in module") 426 | if inspect.isclass(object): 427 | if hasattr(object, "__module__"): 428 | module = sys.modules.get(object.__module__) 429 | if getattr(module, "__file__", None): 430 | return module.__file__ 431 | if object.__module__ == "__main__": 432 | raise OSError("source code not available") 433 | raise TypeError(f"{object!r} is a built-in class") 434 | if inspect.ismethod(object): 435 | object = object.__func__ 436 | if inspect.isfunction(object): 437 | object = object.__code__ 438 | if inspect.istraceback(object): 439 | object = object.tb_frame 440 | if inspect.isframe(object): 441 | object = object.f_code 442 | if inspect.iscode(object): 443 | return object.co_filename 444 | raise TypeError( 445 | f"module, class, method, function, traceback, frame, or code object was expected, got {type(object).__name__}" 446 | ) 447 | 448 | 449 | def new_getfile(object): 450 | if not inspect.isclass(object): 451 | return old_getfile(object) 452 | 453 | # Lookup by parent module (as in current inspect) 454 | if hasattr(object, "__module__"): 455 | object_ = sys.modules.get(object.__module__) 456 | if hasattr(object_, "__file__"): 457 | return object_.__file__ 458 | 459 | # If parent module is __main__, lookup by methods (NEW) 460 | for _, member in inspect.getmembers(object): 461 | if inspect.isfunction(member) and object.__qualname__ + "." + member.__name__ == member.__qualname__: 462 | return inspect.getfile(member) 463 | raise TypeError(f"Source for {object!r} not found") 464 | 465 | 466 | inspect.getfile = new_getfile 467 | ``` -------------------------------------------------------------------------------- /docs/docs/cheatsheet.md: -------------------------------------------------------------------------------- ```markdown 1 | --- 2 | sidebar_position: 999 3 | --- 4 | 5 | # DSPy Cheatsheet 6 | 7 | This page will contain snippets for frequent usage patterns. 8 | 9 | ## DSPy Programs 10 | 11 | ### Forcing fresh LM outputs 12 | 13 | DSPy caches LM calls. Provide a unique ``rollout_id`` and set a non-zero 14 | ``temperature`` (e.g., 1.0) to bypass an existing cache entry while still caching 15 | the new result: 16 | 17 | ```python 18 | predict = dspy.Predict("question -> answer") 19 | predict(question="1+1", config={"rollout_id": 1, "temperature": 1.0}) 20 | ``` 21 | 22 | ### dspy.Signature 23 | 24 | ```python 25 | class BasicQA(dspy.Signature): 26 | """Answer questions with short factoid answers.""" 27 | 28 | question: str = dspy.InputField() 29 | answer: str = dspy.OutputField(desc="often between 1 and 5 words") 30 | ``` 31 | 32 | ### dspy.ChainOfThought 33 | 34 | ```python 35 | generate_answer = dspy.ChainOfThought(BasicQA) 36 | 37 | # Call the predictor on a particular input alongside a hint. 38 | question='What is the color of the sky?' 39 | pred = generate_answer(question=question) 40 | ``` 41 | 42 | ### dspy.ProgramOfThought 43 | 44 | ```python 45 | pot = dspy.ProgramOfThought(BasicQA) 46 | 47 | question = 'Sarah has 5 apples. She buys 7 more apples from the store. How many apples does Sarah have now?' 
48 | result = pot(question=question) 49 | 50 | print(f"Question: {question}") 51 | print(f"Final Predicted Answer (after ProgramOfThought process): {result.answer}") 52 | ``` 53 | 54 | ### dspy.ReAct 55 | 56 | ```python 57 | react_module = dspy.ReAct(BasicQA) 58 | 59 | question = 'Sarah has 5 apples. She buys 7 more apples from the store. How many apples does Sarah have now?' 60 | result = react_module(question=question) 61 | 62 | print(f"Question: {question}") 63 | print(f"Final Predicted Answer (after ReAct process): {result.answer}") 64 | ``` 65 | 66 | ### dspy.Retrieve 67 | 68 | ```python 69 | colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts') 70 | dspy.settings.configure(rm=colbertv2_wiki17_abstracts) 71 | 72 | #Define Retrieve Module 73 | retriever = dspy.Retrieve(k=3) 74 | 75 | query='When was the first FIFA World Cup held?' 76 | 77 | # Call the retriever on a particular query. 78 | topK_passages = retriever(query).passages 79 | 80 | for idx, passage in enumerate(topK_passages): 81 | print(f'{idx+1}]', passage, '\n') 82 | ``` 83 | 84 | ### dspy.CodeAct 85 | 86 | ```python 87 | from dspy import CodeAct 88 | 89 | def factorial(n): 90 | """Calculate factorial of n""" 91 | if n == 1: 92 | return 1 93 | return n * factorial(n-1) 94 | 95 | act = CodeAct("n->factorial", tools=[factorial]) 96 | result = act(n=5) 97 | result # Returns 120 98 | ``` 99 | 100 | ### dspy.Parallel 101 | 102 | ```python 103 | import dspy 104 | 105 | parallel = dspy.Parallel(num_threads=2) 106 | predict = dspy.Predict("question -> answer") 107 | result = parallel( 108 | [ 109 | (predict, dspy.Example(question="1+1").with_inputs("question")), 110 | (predict, dspy.Example(question="2+2").with_inputs("question")) 111 | ] 112 | ) 113 | result 114 | ``` 115 | 116 | ## DSPy Metrics 117 | 118 | ### Function as Metric 119 | 120 | To create a custom metric you can create a function that returns either a number or a boolean value: 121 | 122 | ```python 123 | def parse_integer_answer(answer, only_first_line=True): 124 | try: 125 | if only_first_line: 126 | answer = answer.strip().split('\n')[0] 127 | 128 | # find the last token that has a number in it 129 | answer = [token for token in answer.split() if any(c.isdigit() for c in token)][-1] 130 | answer = answer.split('.')[0] 131 | answer = ''.join([c for c in answer if c.isdigit()]) 132 | answer = int(answer) 133 | 134 | except (ValueError, IndexError): 135 | # print(answer) 136 | answer = 0 137 | 138 | return answer 139 | 140 | # Metric Function 141 | def gsm8k_metric(gold, pred, trace=None) -> int: 142 | return int(parse_integer_answer(str(gold.answer))) == int(parse_integer_answer(str(pred.answer))) 143 | ``` 144 | 145 | ### LLM as Judge 146 | 147 | ```python 148 | class FactJudge(dspy.Signature): 149 | """Judge if the answer is factually correct based on the context.""" 150 | 151 | context = dspy.InputField(desc="Context for the prediction") 152 | question = dspy.InputField(desc="Question to be answered") 153 | answer = dspy.InputField(desc="Answer for the question") 154 | factually_correct: bool = dspy.OutputField(desc="Is the answer factually correct based on the context?") 155 | 156 | judge = dspy.ChainOfThought(FactJudge) 157 | 158 | def factuality_metric(example, pred): 159 | factual = judge(context=example.context, question=example.question, answer=pred.answer) 160 | return factual.factually_correct 161 | ``` 162 | 163 | ## DSPy Evaluation 164 | 165 | ```python 166 | from dspy.evaluate import Evaluate 167 | 168 | 
evaluate_program = Evaluate(devset=devset, metric=your_defined_metric, num_threads=NUM_THREADS, display_progress=True, display_table=num_rows_to_display) 169 | 170 | evaluate_program(your_dspy_program) 171 | ``` 172 | 173 | ## DSPy Optimizers 174 | 175 | ### LabeledFewShot 176 | 177 | ```python 178 | from dspy.teleprompt import LabeledFewShot 179 | 180 | labeled_fewshot_optimizer = LabeledFewShot(k=8) 181 | your_dspy_program_compiled = labeled_fewshot_optimizer.compile(student = your_dspy_program, trainset=trainset) 182 | ``` 183 | 184 | ### BootstrapFewShot 185 | 186 | ```python 187 | from dspy.teleprompt import BootstrapFewShot 188 | 189 | fewshot_optimizer = BootstrapFewShot(metric=your_defined_metric, max_bootstrapped_demos=4, max_labeled_demos=16, max_rounds=1, max_errors=10) 190 | 191 | your_dspy_program_compiled = fewshot_optimizer.compile(student = your_dspy_program, trainset=trainset) 192 | ``` 193 | 194 | #### Using another LM for compilation, specifying in teacher_settings 195 | 196 | ```python 197 | from dspy.teleprompt import BootstrapFewShot 198 | 199 | fewshot_optimizer = BootstrapFewShot(metric=your_defined_metric, max_bootstrapped_demos=4, max_labeled_demos=16, max_rounds=1, max_errors=10, teacher_settings=dict(lm=gpt4)) 200 | 201 | your_dspy_program_compiled = fewshot_optimizer.compile(student = your_dspy_program, trainset=trainset) 202 | ``` 203 | 204 | #### Compiling a compiled program - bootstrapping a bootstrapped program 205 | 206 | ```python 207 | your_dspy_program_compiledx2 = teleprompter.compile( 208 | your_dspy_program, 209 | teacher=your_dspy_program_compiled, 210 | trainset=trainset, 211 | ) 212 | ``` 213 | 214 | #### Saving/loading a compiled program 215 | 216 | ```python 217 | save_path = './v1.json' 218 | your_dspy_program_compiledx2.save(save_path) 219 | ``` 220 | 221 | ```python 222 | loaded_program = YourProgramClass() 223 | loaded_program.load(path=save_path) 224 | ``` 225 | 226 | ### BootstrapFewShotWithRandomSearch 227 | 228 | Detailed documentation on BootstrapFewShotWithRandomSearch can be found [here](api/optimizers/BootstrapFewShot.md). 229 | 230 | ```python 231 | from dspy.teleprompt import BootstrapFewShotWithRandomSearch 232 | 233 | fewshot_optimizer = BootstrapFewShotWithRandomSearch(metric=your_defined_metric, max_bootstrapped_demos=2, num_candidate_programs=8, num_threads=NUM_THREADS) 234 | 235 | your_dspy_program_compiled = fewshot_optimizer.compile(student = your_dspy_program, trainset=trainset, valset=devset) 236 | 237 | ``` 238 | 239 | Other custom configurations are similar to customizing the `BootstrapFewShot` optimizer. 
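After compiling with either of the bootstrapping optimizers above, it can be useful to see which demonstrations were attached to each step of your program. The following is a small illustrative sketch (not an official cheatsheet entry) that assumes the compiled program from the previous snippet:

```python
for i, predictor in enumerate(your_dspy_program_compiled.predictors()):
    print(f"Predictor {i}: {len(predictor.demos)} attached demos")
```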
240 | 241 | ### Ensemble 242 | 243 | ```python 244 | from dspy.teleprompt import BootstrapFewShotWithRandomSearch 245 | from dspy.teleprompt.ensemble import Ensemble 246 | 247 | fewshot_optimizer = BootstrapFewShotWithRandomSearch(metric=your_defined_metric, max_bootstrapped_demos=2, num_candidate_programs=8, num_threads=NUM_THREADS) 248 | your_dspy_program_compiled = fewshot_optimizer.compile(student = your_dspy_program, trainset=trainset, valset=devset) 249 | 250 | ensemble_optimizer = Ensemble(reduce_fn=dspy.majority) 251 | programs = [x[-1] for x in your_dspy_program_compiled.candidate_programs] 252 | your_dspy_program_compiled_ensemble = ensemble_optimizer.compile(programs[:3]) 253 | ``` 254 | 255 | ### BootstrapFinetune 256 | 257 | ```python 258 | from dspy.teleprompt import BootstrapFewShotWithRandomSearch, BootstrapFinetune 259 | 260 | #Compile program on current dspy.settings.lm 261 | fewshot_optimizer = BootstrapFewShotWithRandomSearch(metric=your_defined_metric, max_bootstrapped_demos=2, num_threads=NUM_THREADS) 262 | your_dspy_program_compiled = tp.compile(your_dspy_program, trainset=trainset[:some_num], valset=trainset[some_num:]) 263 | 264 | #Configure model to finetune 265 | config = dict(target=model_to_finetune, epochs=2, bf16=True, bsize=6, accumsteps=2, lr=5e-5) 266 | 267 | #Compile program on BootstrapFinetune 268 | finetune_optimizer = BootstrapFinetune(metric=your_defined_metric) 269 | finetune_program = finetune_optimizer.compile(your_dspy_program, trainset=some_new_dataset_for_finetuning_model, **config) 270 | 271 | finetune_program = your_dspy_program 272 | 273 | #Load program and activate model's parameters in program before evaluation 274 | ckpt_path = "saved_checkpoint_path_from_finetuning" 275 | LM = dspy.HFModel(checkpoint=ckpt_path, model=model_to_finetune) 276 | 277 | for p in finetune_program.predictors(): 278 | p.lm = LM 279 | p.activated = False 280 | ``` 281 | 282 | ### COPRO 283 | 284 | Detailed documentation on COPRO can be found [here](api/optimizers/COPRO.md). 285 | 286 | ```python 287 | from dspy.teleprompt import COPRO 288 | 289 | eval_kwargs = dict(num_threads=16, display_progress=True, display_table=0) 290 | 291 | copro_teleprompter = COPRO(prompt_model=model_to_generate_prompts, metric=your_defined_metric, breadth=num_new_prompts_generated, depth=times_to_generate_prompts, init_temperature=prompt_generation_temperature, verbose=False) 292 | 293 | compiled_program_optimized_signature = copro_teleprompter.compile(your_dspy_program, trainset=trainset, eval_kwargs=eval_kwargs) 294 | ``` 295 | 296 | ### MIPROv2 297 | 298 | Note: detailed documentation can be found [here](api/optimizers/MIPROv2.md). `MIPROv2` is the latest extension of `MIPRO` which includes updates such as (1) improvements to instruction proposal and (2) more efficient search with minibatching. 299 | 300 | #### Optimizing with MIPROv2 301 | 302 | This shows how to perform an easy out-of-the box run with `auto=light`, which configures many hyperparameters for you and performs a light optimization run. You can alternatively set `auto=medium` or `auto=heavy` to perform longer optimization runs. The more detailed `MIPROv2` documentation [here](api/optimizers/MIPROv2.md) also provides more information about how to set hyperparameters by hand. 
303 | 304 | ```python 305 | # Import the optimizer 306 | from dspy.teleprompt import MIPROv2 307 | 308 | # Initialize optimizer 309 | teleprompter = MIPROv2( 310 | metric=gsm8k_metric, 311 | auto="light", # Can choose between light, medium, and heavy optimization runs 312 | ) 313 | 314 | # Optimize program 315 | print(f"Optimizing program with MIPRO...") 316 | optimized_program = teleprompter.compile( 317 | program.deepcopy(), 318 | trainset=trainset, 319 | max_bootstrapped_demos=3, 320 | max_labeled_demos=4, 321 | ) 322 | 323 | # Save optimize program for future use 324 | optimized_program.save(f"mipro_optimized") 325 | 326 | # Evaluate optimized program 327 | print(f"Evaluate optimized program...") 328 | evaluate(optimized_program, devset=devset[:]) 329 | ``` 330 | 331 | #### Optimizing instructions only with MIPROv2 (0-Shot) 332 | 333 | ```python 334 | # Import the optimizer 335 | from dspy.teleprompt import MIPROv2 336 | 337 | # Initialize optimizer 338 | teleprompter = MIPROv2( 339 | metric=gsm8k_metric, 340 | auto="light", # Can choose between light, medium, and heavy optimization runs 341 | ) 342 | 343 | # Optimize program 344 | print(f"Optimizing program with MIPRO...") 345 | optimized_program = teleprompter.compile( 346 | program.deepcopy(), 347 | trainset=trainset, 348 | max_bootstrapped_demos=0, 349 | max_labeled_demos=0, 350 | ) 351 | 352 | # Save optimize program for future use 353 | optimized_program.save(f"mipro_optimized") 354 | 355 | # Evaluate optimized program 356 | print(f"Evaluate optimized program...") 357 | evaluate(optimized_program, devset=devset[:]) 358 | ``` 359 | 360 | ### KNNFewShot 361 | 362 | ```python 363 | from sentence_transformers import SentenceTransformer 364 | from dspy import Embedder 365 | from dspy.teleprompt import KNNFewShot 366 | from dspy import ChainOfThought 367 | 368 | knn_optimizer = KNNFewShot(k=3, trainset=trainset, vectorizer=Embedder(SentenceTransformer("all-MiniLM-L6-v2").encode)) 369 | 370 | qa_compiled = knn_optimizer.compile(student=ChainOfThought("question -> answer")) 371 | ``` 372 | 373 | ### BootstrapFewShotWithOptuna 374 | 375 | ```python 376 | from dspy.teleprompt import BootstrapFewShotWithOptuna 377 | 378 | fewshot_optuna_optimizer = BootstrapFewShotWithOptuna(metric=your_defined_metric, max_bootstrapped_demos=2, num_candidate_programs=8, num_threads=NUM_THREADS) 379 | 380 | your_dspy_program_compiled = fewshot_optuna_optimizer.compile(student=your_dspy_program, trainset=trainset, valset=devset) 381 | ``` 382 | 383 | Other custom configurations are similar to customizing the `dspy.BootstrapFewShot` optimizer. 384 | 385 | 386 | ### SIMBA 387 | 388 | SIMBA, which stands for Stochastic Introspective Mini-Batch Ascent, is a prompt optimizer that accepts arbitrary DSPy programs and proceeds in a sequence of mini-batches seeking to make incremental improvements to the prompt instructions or few-shot examples. 
389 | 390 | ```python 391 | from dspy.teleprompt import SIMBA 392 | 393 | simba = SIMBA(metric=your_defined_metric, max_steps=12, max_demos=10) 394 | 395 | optimized_program = simba.compile(student=your_dspy_program, trainset=trainset) 396 | ``` 397 | 398 | 399 | ## DSPy Tools and Utilities 400 | 401 | ### dspy.Tool 402 | 403 | ```python 404 | import dspy 405 | 406 | def search_web(query: str) -> str: 407 | """Search the web for information""" 408 | return f"Search results for: {query}" 409 | 410 | tool = dspy.Tool(search_web) 411 | result = tool(query="Python programming") 412 | ``` 413 | 414 | ### dspy.streamify 415 | 416 | ```python 417 | import dspy 418 | import asyncio 419 | 420 | predict = dspy.Predict("question->answer") 421 | 422 | stream_predict = dspy.streamify( 423 | predict, 424 | stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")], 425 | ) 426 | 427 | async def read_output_stream(): 428 | output_stream = stream_predict(question="Why did a chicken cross the kitchen?") 429 | 430 | async for chunk in output_stream: 431 | print(chunk) 432 | 433 | asyncio.run(read_output_stream()) 434 | ``` 435 | 436 | 437 | ### dspy.asyncify 438 | 439 | ```python 440 | import dspy 441 | 442 | dspy_program = dspy.ChainOfThought("question -> answer") 443 | dspy_program = dspy.asyncify(dspy_program) 444 | 445 | asyncio.run(dspy_program(question="What is DSPy")) 446 | ``` 447 | 448 | 449 | ### Track Usage 450 | 451 | ```python 452 | import dspy 453 | dspy.settings.configure(track_usage=True) 454 | 455 | result = dspy.ChainOfThought(BasicQA)(question="What is 2+2?") 456 | print(f"Token usage: {result.get_lm_usage()}") 457 | ``` 458 | 459 | ### dspy.configure_cache 460 | 461 | ```python 462 | import dspy 463 | 464 | # Configure cache settings 465 | dspy.configure_cache( 466 | enable_disk_cache=False, 467 | enable_memory_cache=False, 468 | ) 469 | ``` 470 | 471 | ## DSPy `Refine` and `BestofN` 472 | 473 | >`dspy.Suggest` and `dspy.Assert` are replaced by `dspy.Refine` and `dspy.BestofN` in DSPy 2.6. 474 | 475 | ### BestofN 476 | 477 | Runs a module up to `N` times with different rollout IDs (bypassing cache) and returns the best prediction, as defined by the `reward_fn`, or the first prediction that passes the `threshold`. 478 | 479 | ```python 480 | import dspy 481 | 482 | qa = dspy.ChainOfThought("question -> answer") 483 | def one_word_answer(args, pred): 484 | return 1.0 if len(pred.answer) == 1 else 0.0 485 | best_of_3 = dspy.BestOfN(module=qa, N=3, reward_fn=one_word_answer, threshold=1.0) 486 | best_of_3(question="What is the capital of Belgium?").answer 487 | # Brussels 488 | ``` 489 | 490 | ### Refine 491 | 492 | Refines a module by running it up to `N` times with different rollout IDs (bypassing cache) and returns the best prediction, as defined by the `reward_fn`, or the first prediction that passes the `threshold`. After each attempt (except the final one), `Refine` automatically generates detailed feedback about the module's performance and uses this feedback as hints for subsequent runs, creating an iterative refinement process. 
493 | 494 | ```python 495 | import dspy 496 | 497 | qa = dspy.ChainOfThought("question -> answer") 498 | def one_word_answer(args, pred): 499 | return 1.0 if len(pred.answer) == 1 else 0.0 500 | best_of_3 = dspy.Refine(module=qa, N=3, reward_fn=one_word_answer, threshold=1.0) 501 | best_of_3(question="What is the capital of Belgium?").answer 502 | # Brussels 503 | ``` 504 | 505 | #### Error Handling 506 | 507 | By default, `Refine` will try to run the module up to N times until the threshold is met. If the module encounters an error, it will keep going up to N failed attempts. You can change this behavior by setting `fail_count` to a smaller number than `N`. 508 | 509 | ```python 510 | refine = dspy.Refine(module=qa, N=3, reward_fn=one_word_answer, threshold=1.0, fail_count=1) 511 | ... 512 | refine(question="What is the capital of Belgium?") 513 | # If we encounter just one failed attempt, the module will raise an error. 514 | ``` 515 | 516 | If you want to run the module up to N times without any error handling, you can set `fail_count` to `N`. This is the default behavior. 517 | 518 | ```python 519 | refine = dspy.Refine(module=qa, N=3, reward_fn=one_word_answer, threshold=1.0, fail_count=3) 520 | ... 521 | refine(question="What is the capital of Belgium?") 522 | ``` 523 | ``` -------------------------------------------------------------------------------- /docs/docs/roadmap.md: -------------------------------------------------------------------------------- ```markdown 1 | --- 2 | draft: true 3 | --- 4 | 5 | !!! warning "This document is from Aug 2024. Since then, DSPy 2.5 and 2.6 were released, DSPy has grown considerably, and 3.0 is approaching! Content below is highly outdated." 6 | 7 | 8 | 9 | # Roadmap Sketch: DSPy 2.5+ 10 | 11 | It’s been a year since DSPy evolved out of Demonstrate–Search–Predict (DSP), whose research started at Stanford NLP all the way back in February 2022. Thanks to 200 wonderful contributors, DSPy has introduced tens of thousands of people to building modular LM programs and optimizing their prompts and weights automatically. In this time, DSPy has grown to 160,000 monthly downloads and 16,000 stars on GitHub, becoming synonymous with prompt optimization in many circles and inspiring at least a half-dozen cool new libraries. 12 | 13 | This document is an initial sketch of DSPy’s public roadmap for the next few weeks and months, as we work on DSPy 2.5 and plan for DSPy 3.0. Suggestions and open-source contributors are more than welcome: just open an issue or submit a pull request regarding the roadmap. 14 | 15 | 16 | 17 | ## Technical Objectives 18 | 19 | The thesis of DSPy is that for LMs to be useful, we have to shift from ad-hoc prompting to new notions of programming LMs. Instead of relying on LMs gaining much more general or more compositional capabilities, we need to enable developers to iteratively explore their problems and build modular software that invokes LMs for well-scoped tasks. We need to enable that through modules and optimizers that isolate how they decompose their problems and describe their system's objectives from how their LMs are invoked or fine-tuned to maximize their objectives. DSPy's goal has been to develop (and to build the community and shared infrastructure for the collective development of) the abstractions, programming patterns, and optimizers toward this thesis. 
20 | 21 | To a first approximation, DSPy’s current user-facing language has the minimum number of appropriate abstractions that address the goals above: declarative signatures, define-by-run modules, and optimizers that can be composed quite powerfully. But there are several things we need to do better to realize our goals. The upcoming DSPy releases will have the following objectives. 22 | 23 | 1. Polishing the core functionality. 24 | 2. Developing more accurate, lower-cost optimizers. 25 | 3. Building end-to-end tutorials from DSPy’s ML workflow to deployment. 26 | 4. Shifting towards more interactive optimization & tracking. 27 | 28 | 29 | 30 | ## Team & Organization 31 | 32 | DSPy is fairly unusual in its technical objectives, contributors, and audience. Though DSPy takes inspiration from PyTorch, a library for building and optimizing DNNs, there is one major difference: PyTorch was introduced well after DNNs were mature ML concepts, but DSPy seeks to establish and advance core LM Programs research: the framework is propelled by constant academic research from programming abstractions (like the original **Demonstrate–Search–Predict** concepts, DSPy **Signatures**, or **LM Assertions**) to NLP systems (like **STORM**, **PATH**, and **IReRa**) to prompt optimizers (like **MIPRO**) and RL (like **BetterTogether**), among many other related directions. 33 | 34 | This research all composes into a concrete, practical library, thanks to dozens of industry contributors, many of whom are deploying apps in production using DSPy. Because of this, DSPy reaches not only of grad students and ML engineers, but also many non-ML engineers, from early adopter SWEs to hobbyists exploring new ways of using LMs. The following team, with help from many folks in the OSS community, is working towards the objectives in this Roadmap. 35 | 36 | **Project Lead:** Omar Khattab (Stanford & Databricks) 37 | 38 | **Project Mentors:** Chris Potts (Stanford), Matei Zaharia (UC Berkeley & Databricks), Heather Miller (CMU & Two Sigma) 39 | 40 | **Core Library:** Arnav Singhvi (Databricks & Stanford), Herumb Shandilya (Stanford), Hanna Moazam (Databricks), Sri Vardhamanan (Dashworks), Cyrus Nouroozi (Zenbase), Amir Mehr (Zenbase), Kyle Caverly (Modular), with special thanks to Keshav Santhanam (Stanford), Thomas Ahle (Normal Computing), Connor Shorten (Weaviate) 41 | 42 | **Prompt Optimization:** Krista Opsahl-Ong (Stanford), Michael Ryan (Stanford), Josh Purtell (Basis), with special thanks to Eric Zhang (Stanford) 43 | 44 | **Finetuning & RL:** Dilara Soylu (Stanford), Isaac Miller (Anyscale), Karel D'Oosterlinck (Ghent), with special thanks to Paridhi Masehswari (Stanford) 45 | 46 | **PL Abstractions:** Shangyin Tan (UC Berkeley), Manish Shetty (UC Berkeley), Peter Zhong (CMU) 47 | 48 | **Applications:** Jasper Xian (Waterloo), Saron Samuel (Stanford), Alberto Mancarella (Stanford), Faraz Khoubsirat (Waterloo), Saiful Haq (IIT-B), Ashutosh Sharma (UIUC) 49 | 50 | 51 | 52 | ## 1) Polishing the core functionality. 53 | 54 | Over the next month, polishing is the main objective and likely the one to have the highest ROI on the experience of the average user. Conceptually, DSPy has an extremely small core. It’s nothing but (1) LMs, (2) Signatures & Modules, (3) Optimizers, and (4) Assertions. These concepts and their implementations evolved organically over the past couple of years. 
We are working now to consolidate what we’ve learned and refactor internally so that things “just work” out of the box for new users, who may not know all the tips-and-tricks just yet. 55 | 56 | More concretely: 57 | 58 | 1. We want to increase the quality of zero-shot, off-the-shelf DSPy programs, i.e. those not yet compiled on custom data. 59 | 2. Wherever possible, DSPy should delegate lower-level internal complexity (like managing LMs and structured generation) to emerging lower-level libraries. When required, we may fork smaller libraries out of DSPy to support infrastructure pieces as their own projects. 60 | 3. DSPy should internally be more modular and we need higher compatibility between internal components. Specifically, we need more deeper and more native investment in (i) typed multi-field constraints, (ii) assertions, (iii) observability and experimental tracking, (iv) deployment of artifacts and related concerns like streaming and async, and (v) fine-tuning and serving open models. 61 | 62 | 63 | ### On LMs 64 | 65 | As of DSPy 2.4, the library has approximately 20,000 lines of code and roughly another 10,000 lines of code for tests, examples, and documentation. Some of these are clearly necessary (e.g., DSPy optimizers) but others exist only because the LM space lacks the building blocks we need under the hood. Luckily, for LM interfaces, a very strong library now exists: LiteLLM, a library that unifies interfaces to various LM and embedding providers. We expect to reduce around 6000 LoCs of support for custom LMs and retrieval models by shifting a lot of that to LiteLLM. 66 | 67 | Objectives in this space include improved caching, saving/loading of LMs, support for streaming and async LM requests. Work here is currently led by Hanna Moazam and Sri Vardhamanan, building on a foundation by Cyrus Nouroozi, Amir Mehr, Kyle Caverly, and others. 68 | 69 | 70 | ### On Signatures & Modules 71 | 72 | Traditionally, LMs offer text-in-text-out interfaces. Toward modular programming, DSPy introduced signatures for the first time (as DSP Templates in Jan 2023) as a way to structure the inputs and outputs of LM interactions. Standard prompts conflate interface (“what should the LM do?”) with implementation (“how do we tell it to do that?”). DSPy signatures isolate the former so we can infer and learn the latter from data — in the context of a bigger program. Today in the LM landscape, notions of "structured outputs" have evolved dramatically, thanks to constrained decoding and other improvements, and have become mainstream. What may be called "structured inputs" remains is yet to become mainstream outside of DSPy, but is as crucial. 73 | 74 | Objectives in this space include refining the abstractions and implementations first-class notion of LM Adapters in DSPy, as translators that sits between signatures and LM interfaces. While Optimizers adjust prompts through interactions with a user-supplied metric and data, Adapters are more concerned with building up interactions with LMs to account for, e.g. (i) non-plaintext LM interfaces like chat APIs, structured outputs, function calling, and multi-modal APIs, (ii) languages beyond English or other forms of higher-level specialization. This has been explored in DSPy on and off in various forms, but we have started working on more fundamental approaches to this problem that will offer tangible improvements to most use-cases. Work here is currently led by Omar Khattab. 
75 | 76 | 77 | ### On Finetuning & Serving 78 | 79 | In February 2023, DSPy introduced the notion of compiling to optimize the weights of an LM program. (To understand just how long ago that was in AI terms, this was before the Alpaca training project at Stanford had even started and a month before the first GPT-4 was released.) Since then, we have shown in October 2023 and, much more expansively, in July 2024, that the fine-tuning flavor of DSPy can deliver large gains for small LMs, especially when composed with prompt optimization. 80 | 81 | Overall, though, most DSPy users in practice explore prompt optimization rather than weight optimization, and most of our examples do the same. The primary reason for this is infrastructure. Fine-tuning in the DSPy flavor is more than just training a model: ultimately, we need to bootstrap training data for several different modules in a program, train multiple models and handle model selection, and then load those models and plug them into the program's modules. Doing this robustly at the level of abstraction DSPy offers requires a level of resource management that is not generally supported by existing external tools. Major efforts in this regard are currently led by Dilara Soylu and Isaac Miller. 82 | 83 | 84 | ### On Optimizers & Assertions 85 | 86 | This is naturally a major direction in the course of polishing. We will share more thoughts here after making more progress on the three angles above. 87 | 88 | 89 | 90 | ## 2) Developing more accurate, lower-cost optimizers. 91 | 92 | A very large fraction of the research in DSPy focuses on optimizing the prompts and the weights of LM programs. In December 2022, we introduced the algorithm and abstractions behind BootstrapFewShot (as Demonstrate in DSP) and several of its variants. In February 2023, we introduced the core version of what later became BootstrapFinetune. In August 2023, we introduced new variations of both of these. In December 2023, we introduced the first couple of instruction optimizers into DSPy, CA-OPRO and early versions of MIPRO. These were again upgraded in March 2024. Fast forward to June and July 2024: we released MIPROv2 for prompt optimization and BetterTogether for fine-tuning the weights of LM programs. 93 | 94 | We have been working towards a number of stronger optimizers. While we cannot share the internal details of research on new optimizers yet, we can outline the goals. A DSPy optimizer can be characterized via three angles: 95 | 96 | 1. Quality: How much quality can it deliver from various LMs? How effective does it need the zero-shot program to be in order to work well? 97 | 2. Cost: How many labeled (and unlabeled) inputs does it need? How many invocations of the program does it need? How expensive is the resulting optimized program at inference time? 98 | 3. Robustness: How well can it generalize to different unseen data points or distributions? How sensitive is it to mistakes in the metric or labels? 99 | 100 | Over the next six months, our goal is to dramatically improve each of these angles _when the other two are held constant_. Concretely, there are three directions here. 101 | 102 | - Benchmarking: A key prerequisite here is work on benchmarking. On the team, Michael Ryan and Shangyin Tan are leading these efforts. More soon.
103 | 104 | - Quality: The goal here is optimizers that deliver, on average, 20% more on representative tasks than MIPROv2 and BetterTogether, under the usual conditions — like a few hundred inputs with labels and a good metric, starting from a decent zero-shot program. Various efforts here are led by Dilara Soylu, Michael Ryan, Josh Purtell, Krista Opsahl-Ong, and Isaac Miller. 105 | 106 | - Efficiency: The goal here is optimizers that match the current best scores from MIPROv2 and BetterTogether but under 1-2 challenges like: (i) starting from only 10-20 inputs with labels, (ii) starting with a weak zero-shot program that scores 0%, (iii) where significant misalignment exists between train/validation and test, or (iv) where the user supplies no metric but provides a very small number of output judgments. 107 | 108 | 109 | 110 | ## 3) Building end-to-end tutorials from DSPy’s ML workflow to deployment. 111 | 112 | Using DSPy well for solving a new task is just doing good machine learning with LMs, but teaching this is hard. On the one hand, it's an iterative process: you make some initial choices, which will be sub-optimal, and then you refine them incrementally. It's highly exploratory: it's often the case that no one knows yet how to best solve a problem in a DSPy-esque way. On the other hand, DSPy offers many emerging lessons from several years of building LM systems, in which the design space, the data regime, and many other factors are new both to ML experts and to the very large fraction of users who have no ML experience. 113 | 114 | Though current docs do address [a bunch of this](learn/index.md) in isolated ways, one thing we've learned is that we should separate teaching the core DSPy language (which is ultimately pretty small) from teaching the emerging ML workflow that works well in a DSPy-esque setting. As a natural extension of this, we need to place more emphasis on the steps prior to and after the explicit coding in DSPy, from data collection to deployment that serves and monitors the optimized DSPy program in practice. This is just starting, but efforts will ramp up, led by Omar Khattab, Isaac Miller, and Herumb Shandilya. 115 | 116 | 117 | ## 4) Shifting towards more interactive optimization & tracking. 118 | 119 | Right now, a DSPy user has a few ways to observe and tweak the process of optimization. They can study the prompts before, during, and after optimization using methods like `inspect_history`, built-in logging, and/or the metadata returned by optimizers. Similarly, they can rely on `program.save` and `program.load` to potentially adjust the optimized prompts by hand. Alternatively, they can use one of the many powerful observability integrations — like those from Phoenix Arize, LangWatch, or Weights & Biases Weave — to observe _in real time_ the process of optimization (e.g., scores, stack traces, successful & failed traces, and candidate prompts). DSPy encourages iterative engineering by adjusting the program, data, or metrics across optimization runs. For example, some optimizers allow “checkpointing” — e.g., if you optimize with BootstrapFewShotWithRandomSearch for 10 iterations and then increase to 15 iterations, the first 10 will be loaded from cache. 120 | 121 | While these can accomplish a lot, there are two limitations that future versions of DSPy will seek to address. 122 | 123 | 1.
In general, DSPy’s (i) observability, (ii) experimental tracking, (iii) cost management, and (iv) deployment of programs should become first-class concerns via integration with tools like MLflow. We will share more plans addressing this for DSPy 2.6 in the next 1-2 months. 124 | 125 | 2. DSPy 3.0 will introduce new optimizers that prioritize ad-hoc, human-in-the-loop feedback. This is perhaps the only substantial paradigm shift we see as necessary in the foreseeable future in DSPy. It involves various research questions at the level of the abstractions, UI/HCI, and ML, so it is a longer-term goal that we will share more about in the next 3-4 months. 126 | 127 | 128 | ``` -------------------------------------------------------------------------------- /dspy/clients/lm_local.py: -------------------------------------------------------------------------------- ```python 1 | import datetime 2 | import logging 3 | import random 4 | import socket 5 | import string 6 | import subprocess 7 | import threading 8 | import time 9 | from typing import TYPE_CHECKING, Any 10 | 11 | import requests 12 | 13 | from dspy.clients.provider import Provider, TrainingJob 14 | from dspy.clients.utils_finetune import TrainDataFormat, save_data 15 | 16 | if TYPE_CHECKING: 17 | from dspy.clients.lm import LM 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class LocalProvider(Provider): 23 | def __init__(self): 24 | super().__init__() 25 | self.finetunable = True 26 | self.TrainingJob = TrainingJob 27 | 28 | @staticmethod 29 | def launch(lm: "LM", launch_kwargs: dict[str, Any] | None = None): 30 | try: 31 | import sglang # noqa: F401 32 | except ImportError: 33 | raise ImportError( 34 | "For local model launching, please install sglang. " 35 | "Navigate to https://docs.sglang.ai/start/install.html for the latest installation instructions!" 36 | ) 37 | 38 | if hasattr(lm, "process"): 39 | logger.info("Server is already launched.") 40 | return 41 | 42 | launch_kwargs = launch_kwargs or lm.launch_kwargs 43 | 44 | import os 45 | 46 | model = lm.model 47 | if model.startswith("openai/"): 48 | model = model[7:] 49 | if model.startswith("local:"): 50 | model = model[6:] 51 | if model.startswith("huggingface/"): 52 | model = model[len("huggingface/") :] 53 | 54 | logger.info(f"Grabbing a free port to launch an SGLang server for model {model}") 55 | logger.info(f"We see that CUDA_VISIBLE_DEVICES is {os.environ.get('CUDA_VISIBLE_DEVICES', 'unset')}") 56 | port = get_free_port() 57 | timeout = launch_kwargs.get("timeout", 1800) 58 | command = f"python -m sglang.launch_server --model-path {model} --port {port} --host 0.0.0.0" 59 | 60 | # We will manually stream & capture logs. 61 | process = subprocess.Popen( 62 | command.replace("\\\n", " ").replace("\\", " ").split(), 63 | text=True, 64 | stdout=subprocess.PIPE, # We'll read from pipe 65 | stderr=subprocess.STDOUT, # Merge stderr into stdout 66 | ) 67 | 68 | # A threading.Event to control printing after the server is ready. 69 | # This will store *all* lines (both before and after readiness).
70 | logger.info(f"SGLang server process started with PID {process.pid}.") 71 | stop_printing_event = threading.Event() 72 | logs_buffer = [] 73 | 74 | def _tail_process(proc, buffer, stop_event): 75 | while True: 76 | line = proc.stdout.readline() 77 | if not line and proc.poll() is not None: 78 | # Process ended and no new line 79 | break 80 | if line: 81 | buffer.append(line) 82 | # Print only if stop_event is not set 83 | if not stop_event.is_set(): 84 | print(line, end="") 85 | 86 | # Start a background thread to read from the process continuously 87 | thread = threading.Thread( 88 | target=_tail_process, 89 | args=(process, logs_buffer, stop_printing_event), 90 | daemon=True, 91 | ) 92 | thread.start() 93 | 94 | # Wait until the server is ready (or times out) 95 | base_url = f"http://localhost:{port}" 96 | try: 97 | wait_for_server(base_url, timeout=timeout) 98 | except TimeoutError: 99 | # If the server doesn't come up, we might want to kill it: 100 | process.kill() 101 | raise 102 | 103 | # Once server is ready, we tell the thread to stop printing further lines. 104 | stop_printing_event.set() 105 | 106 | # A convenience getter so the caller can see all logs so far (and future). 107 | def get_logs() -> str: 108 | # Join them all into a single string, or you might return a list 109 | return "".join(logs_buffer) 110 | 111 | # Let the user know server is up 112 | logger.info(f"Server ready on random port {port}! Logs are available via lm.get_logs() method on returned lm.") 113 | 114 | lm.kwargs["api_base"] = f"http://localhost:{port}/v1" 115 | lm.kwargs["api_key"] = "local" 116 | lm.get_logs = get_logs 117 | lm.process = process 118 | lm.thread = thread 119 | 120 | @staticmethod 121 | def kill(lm: "LM", launch_kwargs: dict[str, Any] | None = None): 122 | from sglang.utils import terminate_process 123 | 124 | if not hasattr(lm, "process"): 125 | logger.info("No running server to kill.") 126 | return 127 | # Ideally, the following happens atomically 128 | terminate_process(lm.process) 129 | lm.thread.join() 130 | del lm.process 131 | del lm.thread 132 | del lm.get_logs 133 | logger.info("Server killed.") 134 | 135 | @staticmethod 136 | def finetune( 137 | job: TrainingJob, 138 | model: str, 139 | train_data: list[dict[str, Any]], 140 | train_data_format: TrainDataFormat | None, 141 | train_kwargs: dict[str, Any] | None = None, 142 | ) -> str: 143 | if model.startswith("openai/"): 144 | model = model[7:] 145 | if model.startswith("local:"): 146 | model = model[6:] 147 | 148 | if train_data_format != TrainDataFormat.CHAT: 149 | raise ValueError("Only chat models are supported for local finetuning.") 150 | 151 | data_path = save_data(train_data) 152 | logger.info(f"Train data saved to {data_path}") 153 | output_dir = create_output_dir(model, data_path) 154 | 155 | default_train_kwargs = { 156 | "device": None, 157 | "use_peft": False, 158 | "num_train_epochs": 5, 159 | "per_device_train_batch_size": 1, 160 | "gradient_accumulation_steps": 8, 161 | "learning_rate": 1e-5, 162 | "max_seq_length": None, 163 | "packing": True, 164 | "bf16": True, 165 | "output_dir": output_dir, 166 | } 167 | train_kwargs = {**default_train_kwargs, **(train_kwargs or {})} 168 | output_dir = train_kwargs["output_dir"] # user might have changed the output_dir 169 | 170 | logger.info(f"Starting local training, will save to {output_dir}") 171 | train_sft_locally( 172 | model_name=model, 173 | train_data=train_data, 174 | train_kwargs=train_kwargs, 175 | ) 176 | 177 | logger.info("Training complete") 178 | return 
f"openai/local:{output_dir}" 179 | 180 | 181 | def create_output_dir(model_name, data_path): 182 | model_str = model_name.replace("/", "-") 183 | time_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 184 | rnd_str = "".join(random.choices(string.ascii_lowercase + string.digits, k=6)) 185 | model_identifier = f"{rnd_str}_{model_str}_{time_str}" 186 | output_dir = data_path.replace(".jsonl", "_" + model_identifier) 187 | return output_dir 188 | 189 | 190 | def train_sft_locally(model_name, train_data, train_kwargs): 191 | try: 192 | import torch 193 | from transformers import AutoModelForCausalLM, AutoTokenizer 194 | from trl import SFTConfig, SFTTrainer, setup_chat_format 195 | except ImportError: 196 | raise ImportError( 197 | "For local finetuning, please install torch, transformers, and trl " 198 | "by running `pip install -U torch transformers accelerate trl peft`" 199 | ) 200 | 201 | device = train_kwargs.get("device", None) 202 | if device is None: 203 | device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 204 | logger.info(f"Using device: {device}") 205 | 206 | model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_name).to(device) 207 | tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name) 208 | 209 | # Set up the chat format; generally only for non-chat model variants, hence the try-except. 210 | try: 211 | model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer) 212 | except Exception: 213 | pass 214 | 215 | if tokenizer.pad_token_id is None: 216 | logger.info("Adding pad token to tokenizer") 217 | tokenizer.add_special_tokens({"pad_token": "[!#PAD#!]"}) 218 | 219 | logger.info("Creating dataset") 220 | if "max_seq_length" not in train_kwargs: 221 | train_kwargs["max_seq_length"] = 4096 222 | logger.info( 223 | f"The 'train_kwargs' parameter didn't include a 'max_seq_length', defaulting to {train_kwargs['max_seq_length']}" 224 | ) 225 | 226 | from datasets import Dataset 227 | 228 | hf_dataset = Dataset.from_list(train_data) 229 | 230 | def tokenize_function(example): 231 | return encode_sft_example(example, tokenizer, train_kwargs["max_seq_length"]) # noqa: F821 232 | 233 | tokenized_dataset = hf_dataset.map(tokenize_function, batched=False) 234 | tokenized_dataset.set_format(type="torch") 235 | tokenized_dataset = tokenized_dataset.filter(lambda example: (example["labels"] != -100).any()) 236 | 237 | use_peft = train_kwargs.get("use_peft", False) 238 | peft_config = None 239 | 240 | if use_peft: 241 | from peft import LoraConfig 242 | 243 | rank_dimension = 32 244 | lora_alpha = 64 245 | lora_dropout = 0.05 246 | 247 | peft_config = LoraConfig( 248 | r=rank_dimension, 249 | lora_alpha=lora_alpha, 250 | lora_dropout=lora_dropout, 251 | bias="none", 252 | target_modules="all-linear", 253 | task_type="CAUSAL_LM", 254 | ) 255 | 256 | sft_config = SFTConfig( 257 | output_dir=train_kwargs["output_dir"], 258 | num_train_epochs=train_kwargs["num_train_epochs"], 259 | per_device_train_batch_size=train_kwargs["per_device_train_batch_size"], 260 | gradient_accumulation_steps=train_kwargs["gradient_accumulation_steps"], 261 | learning_rate=train_kwargs["learning_rate"], 262 | max_grad_norm=2.0, # note that the current SFTConfig default is 1.0 263 | logging_steps=20, 264 | warmup_ratio=0.03, 265 | lr_scheduler_type="constant", 266 | save_steps=10_000, 267 | bf16=train_kwargs["bf16"], 268 | max_seq_length=train_kwargs["max_seq_length"], 269 | 
packing=train_kwargs["packing"], 270 | dataset_kwargs={ # We need to pass dataset_kwargs because we are processing the dataset ourselves 271 | "add_special_tokens": False, # Special tokens handled by template 272 | "append_concat_token": False, # No additional separator needed 273 | }, 274 | ) 275 | 276 | logger.info("Starting training") 277 | trainer = SFTTrainer( 278 | model=model, 279 | args=sft_config, 280 | train_dataset=tokenized_dataset, 281 | peft_config=peft_config, 282 | ) 283 | 284 | # Train! 285 | trainer.train() 286 | 287 | # Save the model! 288 | trainer.save_model() 289 | 290 | merge = True 291 | if use_peft and merge: 292 | from peft import AutoPeftModelForCausalLM 293 | 294 | # Load PEFT model on CPU 295 | model_ = AutoPeftModelForCausalLM.from_pretrained( 296 | pretrained_model_name_or_path=sft_config.output_dir, 297 | torch_dtype=torch.float16, 298 | low_cpu_mem_usage=True, 299 | ) 300 | 301 | merged_model = model_.merge_and_unload() 302 | merged_model.save_pretrained(sft_config.output_dir, safe_serialization=True, max_shard_size="5GB") 303 | 304 | # Clean up! 305 | import gc 306 | 307 | del model 308 | del tokenizer 309 | del trainer 310 | gc.collect() 311 | torch.cuda.empty_cache() 312 | 313 | return sft_config.output_dir 314 | 315 | 316 | def get_free_port() -> int: 317 | """ 318 | Return a free TCP port on localhost. 319 | """ 320 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 321 | s.bind(("localhost", 0)) 322 | return s.getsockname()[1] 323 | 324 | 325 | def wait_for_server(base_url: str, timeout: int | None = None) -> None: 326 | """ 327 | Wait for the server to be ready by polling the /v1/models endpoint. 328 | 329 | Args: 330 | base_url: The base URL of the server (e.g. http://localhost:1234) 331 | timeout: Maximum time to wait in seconds. None means wait forever. 332 | """ 333 | start_time = time.time() 334 | while True: 335 | try: 336 | response = requests.get( 337 | f"{base_url}/v1/models", 338 | headers={"Authorization": "Bearer None"}, 339 | ) 340 | if response.status_code == 200: 341 | # A small extra sleep to ensure server is fully up. 342 | time.sleep(5) 343 | break 344 | 345 | if timeout and (time.time() - start_time) > timeout: 346 | raise TimeoutError("Server did not become ready within timeout period") 347 | except requests.exceptions.RequestException: 348 | # Server not up yet, wait and retry 349 | time.sleep(1) 350 | 351 | 352 | def encode_sft_example(example, tokenizer, max_seq_length): 353 | """ 354 | This function encodes a single example into a format that can be used for sft training. 355 | Here, we assume each example has a 'messages' field. Each message in it is a dict with 'role' and 'content' fields. 356 | We use the `apply_chat_template` function from the tokenizer to tokenize the messages and prepare the input and label tensors. 
357 | 358 | Code obtained from the allenai/open-instruct repository: https://github.com/allenai/open-instruct/blob/4365dea3d1a6111e8b2712af06b22a4512a0df88/open_instruct/finetune.py 359 | """ 360 | import torch 361 | 362 | messages = example["messages"] 363 | if len(messages) == 0: 364 | raise ValueError("messages field is empty.") 365 | input_ids = tokenizer.apply_chat_template( 366 | conversation=messages, 367 | tokenize=True, 368 | return_tensors="pt", 369 | padding=False, 370 | truncation=True, 371 | max_length=max_seq_length, 372 | add_generation_prompt=False, 373 | ) 374 | labels = input_ids.clone() 375 | # mask the non-assistant part for avoiding loss 376 | for message_idx, message in enumerate(messages): 377 | if message["role"] != "assistant": 378 | # we calculate the start index of this non-assistant message 379 | if message_idx == 0: 380 | message_start_idx = 0 381 | else: 382 | message_start_idx = tokenizer.apply_chat_template( 383 | conversation=messages[:message_idx], # here marks the end of the previous messages 384 | tokenize=True, 385 | return_tensors="pt", 386 | padding=False, 387 | truncation=True, 388 | max_length=max_seq_length, 389 | add_generation_prompt=False, 390 | ).shape[1] 391 | # next, we calculate the end index of this non-assistant message 392 | if message_idx < len(messages) - 1 and messages[message_idx + 1]["role"] == "assistant": 393 | # for intermediate messages that follow with an assistant message, we need to 394 | # set `add_generation_prompt=True` to avoid the assistant generation prefix being included in the loss 395 | # (e.g., `<|assistant|>`) 396 | message_end_idx = tokenizer.apply_chat_template( 397 | conversation=messages[: message_idx + 1], 398 | tokenize=True, 399 | return_tensors="pt", 400 | padding=False, 401 | truncation=True, 402 | max_length=max_seq_length, 403 | add_generation_prompt=True, 404 | ).shape[1] 405 | else: 406 | # for the last message or the message that doesn't follow with an assistant message, 407 | # we don't need to add the assistant generation prefix 408 | message_end_idx = tokenizer.apply_chat_template( 409 | conversation=messages[: message_idx + 1], 410 | tokenize=True, 411 | return_tensors="pt", 412 | padding=False, 413 | truncation=True, 414 | max_length=max_seq_length, 415 | add_generation_prompt=False, 416 | ).shape[1] 417 | # set the label to -100 for the non-assistant part 418 | labels[:, message_start_idx:message_end_idx] = -100 419 | if max_seq_length and message_end_idx >= max_seq_length: 420 | break 421 | attention_mask = torch.ones_like(input_ids) 422 | return { 423 | "input_ids": input_ids.flatten(), 424 | "labels": labels.flatten(), 425 | "attention_mask": attention_mask.flatten(), 426 | } 427 | ``` -------------------------------------------------------------------------------- /dspy/streaming/streaming_listener.py: -------------------------------------------------------------------------------- ```python 1 | import re 2 | from collections import defaultdict 3 | from queue import Queue 4 | from typing import TYPE_CHECKING, Any 5 | 6 | from litellm import ModelResponseStream 7 | 8 | from dspy.adapters.chat_adapter import ChatAdapter 9 | from dspy.adapters.json_adapter import JSONAdapter 10 | from dspy.adapters.types import Type 11 | from dspy.adapters.xml_adapter import XMLAdapter 12 | from dspy.dsp.utils.settings import settings 13 | from dspy.streaming.messages import StreamResponse 14 | 15 | if TYPE_CHECKING: 16 | from dspy.primitives.module import Module 17 | 18 | ADAPTER_SUPPORT_STREAMING = 
[ChatAdapter, XMLAdapter, JSONAdapter] 19 | 20 | 21 | class StreamListener: 22 | """Class that listens to the stream to capture the streaming of a specific output field of a predictor.""" 23 | 24 | def __init__( 25 | self, 26 | signature_field_name: str, 27 | predict: Any = None, 28 | predict_name: str | None = None, 29 | allow_reuse: bool = False, 30 | ): 31 | """ 32 | Args: 33 | signature_field_name: The name of the field to listen to. 34 | predict: The predictor to listen to. If None, when calling `streamify()` it will automatically look for 35 | the predictor that has the `signature_field_name` in its signature. 36 | predict_name: The name of the predictor to listen to. If None, when calling `streamify()` it will 37 | automatically look for the predictor that has the `signature_field_name` in its signature. 38 | allow_reuse: If True, the stream listener can be reused for multiple streams. Please note that this could 39 | hurt performance because the same stream chunk is sent to multiple listeners. 40 | """ 41 | self.signature_field_name = signature_field_name 42 | self.predict = predict 43 | self.predict_name = predict_name 44 | 45 | self.field_start_queue = [] 46 | self.field_end_queue = Queue() 47 | self.stream_start = False 48 | self.stream_end = False 49 | self.cache_hit = False 50 | self.allow_reuse = allow_reuse 51 | 52 | self.adapter_identifiers = { 53 | "ChatAdapter": { 54 | "start_identifier": f"[[ ## {self.signature_field_name} ## ]]", 55 | "end_identifier": re.compile(r"\[\[ ## (\w+) ## \]\]"), 56 | "start_indicator": "[", 57 | "end_pattern_prefixes": ["[", "[[", "[[ ", "[[ #", "[[ ##"], 58 | "end_pattern_contains": "[[ ##", 59 | }, 60 | "JSONAdapter": { 61 | "start_identifier": f'"{self.signature_field_name}":', 62 | "end_identifier": re.compile(r"\w*\"(,|\s*})"), 63 | "start_indicator": '"', 64 | "end_pattern_prefixes": ['"', '",', '" ', '"}'], 65 | "end_pattern_contains": None, 66 | }, 67 | "XMLAdapter": { 68 | "start_identifier": f"<{self.signature_field_name}>", 69 | "end_identifier": re.compile(rf"</{self.signature_field_name}>"), 70 | "start_indicator": "<", 71 | "end_pattern_prefixes": ["<", "</"], 72 | "end_pattern_contains": "</", # Any closing tag start 73 | }, 74 | } 75 | 76 | def _buffered_message_end_with_start_identifier(self, concat_message: str, start_identifier: str) -> bool: 77 | for i in range(len(concat_message)): 78 | if start_identifier.startswith(concat_message[len(concat_message) - i - 1 :]): 79 | return True 80 | return False 81 | 82 | def _could_form_end_identifier(self, concat_message: str, adapter_name: str) -> bool: 83 | """Check if the buffered message could potentially form the end identifier. 84 | 85 | This prevents unnecessary buffering when the tokens clearly cannot form the end pattern. 86 | For example, if the buffered message is "hello world" and the end pattern is "[[ ## ... ## ]]", 87 | we know it cannot form the pattern, so we should yield immediately. 88 | 89 | Args: 90 | concat_message: The concatenated buffered message 91 | adapter_name: The name of the adapter being used 92 | 93 | Returns: 94 | True if the message could potentially form part of the end identifier 95 | """ 96 | adapter_config = self.adapter_identifiers[adapter_name] 97 | end_pattern_prefixes = adapter_config.get("end_pattern_prefixes", []) 98 | end_pattern_contains = adapter_config.get("end_pattern_contains") 99 | 100 | # First check: does it end with a potential start of the pattern?
101 | if any(concat_message.endswith(prefix) for prefix in end_pattern_prefixes): 102 | return True 103 | 104 | # Second check: if there's a pattern marker, check if message contains it 105 | # This handles cases like "[[ ## com" where we have partial field name 106 | if end_pattern_contains and end_pattern_contains in concat_message: 107 | return True 108 | 109 | return False 110 | 111 | def receive(self, chunk: ModelResponseStream): 112 | adapter_name = settings.adapter.__class__.__name__ if settings.adapter else "ChatAdapter" 113 | if adapter_name not in self.adapter_identifiers: 114 | raise ValueError( 115 | f"Unsupported adapter for streaming: {adapter_name}, please use one of the following adapters: " 116 | f"{', '.join([a.__name__ for a in ADAPTER_SUPPORT_STREAMING])}" 117 | ) 118 | start_identifier = self.adapter_identifiers[adapter_name]["start_identifier"] 119 | end_identifier = self.adapter_identifiers[adapter_name]["end_identifier"] 120 | start_indicator = self.adapter_identifiers[adapter_name]["start_indicator"] 121 | 122 | if self.stream_end: 123 | if self.allow_reuse: 124 | # Clear up the state for the next stream. 125 | self.stream_end = False 126 | self.cache_hit = False 127 | self.field_start_queue = [] 128 | self.field_end_queue = Queue() 129 | self.stream_start = False 130 | else: 131 | return 132 | 133 | try: 134 | chunk_message = chunk.choices[0].delta.content 135 | if chunk_message is None: 136 | return 137 | except Exception: 138 | return 139 | 140 | # Handle custom streamable types 141 | if self._output_type and issubclass(self._output_type, Type) and self._output_type.is_streamable(): 142 | if parsed_chunk := self._output_type.parse_stream_chunk(chunk): 143 | return StreamResponse( 144 | self.predict_name, 145 | self.signature_field_name, 146 | parsed_chunk, 147 | is_last_chunk=self.stream_end, 148 | ) 149 | 150 | if chunk_message and start_identifier in chunk_message: 151 | # If the cache is hit, the chunk_message could be the full response. When that happens, we can 152 | # directly end the stream listening. In some models like Gemini, each stream chunk can contain multiple 153 | # tokens, so it's possible that the response arrives as a single chunk; we also fall back to this logic then. 154 | message_after_start_identifier = chunk_message[ 155 | chunk_message.find(start_identifier) + len(start_identifier) : 156 | ] 157 | if re.search(end_identifier, message_after_start_identifier): 158 | self.cache_hit = True 159 | self.stream_start = True 160 | self.stream_end = True 161 | return 162 | 163 | if len(self.field_start_queue) == 0 and not self.stream_start and start_indicator in chunk_message: 164 | # We look for the pattern of start_identifier, i.e., "[[ ## {self.signature_field_name} ## ]]" for 165 | # ChatAdapter, to identify the start of the stream of our target field. Once the start_indicator, i.e., "[" 166 | # for ChatAdapter, is found, we start checking the next tokens. 167 | self.field_start_queue.append(chunk_message) 168 | return 169 | 170 | if len(self.field_start_queue) > 0 and not self.stream_start: 171 | # We keep appending the tokens to the queue until we have a full identifier or the concatenated 172 | # tokens no longer match our expected identifier. 173 | self.field_start_queue.append(chunk_message) 174 | concat_message = "".join(self.field_start_queue) 175 | 176 | if start_identifier in concat_message: 177 | # We have a full identifier, we can start the stream.
178 | self.stream_start = True 179 | self.field_start_queue = [] 180 | # Keep the part after the start_identifier from the concat_message; we need to write it to the buffer. 181 | value_start_index = concat_message.find(start_identifier) + len(start_identifier) 182 | chunk_message = concat_message[value_start_index:].lstrip() 183 | if isinstance(settings.adapter, JSONAdapter) and chunk_message.startswith('"'): 184 | # For JSONAdapter, we need to remove the leading ". We cannot do this with the start_identifier 185 | # because there could be a few splitters between ':' and '"', e.g., '"name": "value"'. 186 | chunk_message = chunk_message[1:] 187 | 188 | elif self._buffered_message_end_with_start_identifier(concat_message.strip(), start_identifier): 189 | # If the buffered message ends with part of the start_identifier, we keep looking for the 190 | # start_identifier from the token stream. 191 | return 192 | else: 193 | # Doesn't match the expected identifier, reset the queue. 194 | self.field_start_queue = [] 195 | return 196 | 197 | if self.stream_start and chunk_message: 198 | # The stream has started; we keep returning tokens until we see the start of the next field. 199 | token = None 200 | self.field_end_queue.put(chunk_message) 201 | 202 | concat_message = "".join(self.field_end_queue.queue).strip() 203 | if re.search(end_identifier, concat_message): 204 | # The next field is identified, we can end the stream and flush out all tokens in the buffer. 205 | self.stream_end = True 206 | token = self.flush() 207 | token = token.rstrip() # Remove the trailing \n\n 208 | elif not self._could_form_end_identifier(concat_message, adapter_name): 209 | # Buffer cannot form end identifier, safe to flush out the tokens in the buffer. 210 | token = self.flush() 211 | elif self.field_end_queue.qsize() > 10: 212 | # Buffer could form end identifier, but we've exceeded max buffer size 213 | # Yield the oldest token to prevent unbounded buffering 214 | token = self.field_end_queue.get() 215 | 216 | if token: 217 | return StreamResponse( 218 | self.predict_name, 219 | self.signature_field_name, 220 | token, 221 | is_last_chunk=self.stream_end, 222 | ) 223 | 224 | def flush(self) -> str: 225 | """Flush all tokens in the field end queue. 226 | 227 | This method is called to flush out the last few tokens when the stream ends. These tokens 228 | are held in the buffer because we don't directly yield the tokens received by the stream listener, 229 | in order to avoid yielding the end_identifier tokens, e.g., "[[ ## ... ## ]]" for ChatAdapter.
230 | """ 231 | last_tokens = "".join(self.field_end_queue.queue) 232 | self.field_end_queue = Queue() 233 | if isinstance(settings.adapter, JSONAdapter): 234 | match = re.search(r'",|"\s*}', last_tokens) 235 | if match: 236 | boundary_index = match.start() 237 | else: 238 | boundary_index = len(last_tokens) 239 | return last_tokens[:boundary_index] 240 | elif isinstance(settings.adapter, XMLAdapter): 241 | boundary_index = last_tokens.find(f"</{self.signature_field_name}>") 242 | if boundary_index == -1: 243 | boundary_index = len(last_tokens) 244 | return last_tokens[:boundary_index] 245 | elif isinstance(settings.adapter, ChatAdapter) or settings.adapter is None: 246 | boundary_index = last_tokens.find("[[") 247 | if boundary_index == -1: 248 | boundary_index = len(last_tokens) 249 | return last_tokens[:boundary_index] 250 | else: 251 | raise ValueError( 252 | f"Unsupported adapter for streaming: {settings.adapter}, please use one of the following adapters: " 253 | f"{', '.join([a.__name__ for a in ADAPTER_SUPPORT_STREAMING])}" 254 | ) 255 | 256 | def finalize(self) -> StreamResponse | None: 257 | """Finalize the stream and flush any remaining buffered tokens. 258 | 259 | This should be called when the stream ends. 260 | It ensures no tokens are lost from the buffer and marks the final chunk appropriately. 261 | 262 | Returns: 263 | A StreamResponse with the remaining buffered tokens and is_last_chunk=True, 264 | or None if there are no buffered tokens or the stream hasn't started. 265 | """ 266 | if self.stream_end or not self.stream_start: 267 | # Stream already ended or never started, nothing to finalize 268 | return None 269 | 270 | self.stream_end = True 271 | if self.field_end_queue.qsize() > 0: 272 | token = self.flush() 273 | if token: 274 | return StreamResponse( 275 | self.predict_name, 276 | self.signature_field_name, 277 | token, 278 | is_last_chunk=True, 279 | ) 280 | return None 281 | 282 | @property 283 | def _output_type(self) -> type | None: 284 | try: 285 | return self.predict.signature.output_fields[self.signature_field_name].annotation 286 | except Exception: 287 | return None 288 | 289 | 290 | 291 | def find_predictor_for_stream_listeners(program: "Module", stream_listeners: list[StreamListener]) -> dict[int, list[StreamListener]]: 292 | """Find the predictor for each stream listener. 293 | 294 | This is a utility function to automatically find the predictor for each stream listener. It is used when some 295 | listeners don't specify the predictor they want to listen to. If a listener's `signature_field_name` is not 296 | unique in the program, this function will raise an error. 297 | """ 298 | predictors = program.named_predictors() 299 | 300 | field_name_to_named_predictor = {} 301 | for listener in stream_listeners: 302 | if listener.predict: 303 | continue 304 | field_name_to_named_predictor[listener.signature_field_name] = None 305 | 306 | for name, predictor in predictors: 307 | for field_name, field_info in predictor.signature.output_fields.items(): 308 | if field_name not in field_name_to_named_predictor: 309 | continue 310 | 311 | if field_name_to_named_predictor[field_name] is not None: 312 | raise ValueError( 313 | f"Signature field {field_name} is not unique in the program, cannot automatically determine which " 314 | "predictor to use for streaming. Please specify the predictor to listen to." 
315 | ) 316 | 317 | if not _is_streamable(field_info.annotation): 318 | raise ValueError( 319 | f"Stream listener can only be applied to string or subclass of `dspy.Type` that has `is_streamable() == True`, " 320 | f"but your field {field_name} is of type {field_info.annotation}." 321 | ) 322 | 323 | field_name_to_named_predictor[field_name] = (name, predictor) 324 | 325 | predict_id_to_listener = defaultdict(list) 326 | for listener in stream_listeners: 327 | if listener.predict: 328 | predict_id_to_listener[id(listener.predict)].append(listener) 329 | continue 330 | if listener.signature_field_name not in field_name_to_named_predictor: 331 | raise ValueError( 332 | f"Signature field {listener.signature_field_name} is not a field of any predictor in the program, " 333 | "cannot automatically determine which predictor to use for streaming. Please verify your field name or " 334 | "specify the predictor to listen to." 335 | ) 336 | listener.predict_name, listener.predict = field_name_to_named_predictor[listener.signature_field_name] 337 | predict_id_to_listener[id(listener.predict)].append(listener) 338 | return predict_id_to_listener 339 | 340 | def _is_streamable(field_type: type | None) -> bool: 341 | if field_type is None: 342 | return False 343 | if field_type is str: 344 | return True 345 | if issubclass(field_type, Type): 346 | return field_type.is_streamable() 347 | return False 348 | ```
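For context, here is a minimal, hypothetical sketch of how a `StreamListener` is typically attached to a program via `dspy.streamify`; the model name and the `"question -> answer"` signature are placeholders, and the chunk handling follows the `StreamResponse` objects constructed in the listener above.

```python
import asyncio

import dspy

# Placeholder model; any LiteLLM-compatible model string could be used here.
dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))

# The listener watches the "answer" output field of the predictor below.
program = dspy.Predict("question -> answer")
stream_program = dspy.streamify(
    program,
    stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")],
)

async def main():
    async for chunk in stream_program(question="What is DSPy?"):
        if isinstance(chunk, dspy.streaming.StreamResponse):
            # Incremental text for the "answer" field, emitted by the listener.
            print(chunk.chunk, end="", flush=True)
        elif isinstance(chunk, dspy.Prediction):
            # The final, fully parsed prediction arrives last.
            print("\nFinal answer:", chunk.answer)

# asyncio.run(main())  # requires a valid API key for the configured LM
```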