This is page 8 of 17. Use http://codebase.md/stanfordnlp/dspy?lines=true&page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── .internal_dspyai │ │ ├── internals │ │ │ ├── build-and-release.md │ │ │ └── release-checklist.md │ │ └── pyproject.toml │ ├── .tmp │ │ └── .generated-actions │ │ └── run-pypi-publish-in-docker-container │ │ └── action.yml │ ├── ISSUE_TEMPLATE │ │ ├── bug_report.yml │ │ └── feature_request.yml │ ├── PULL_REQUEST_TEMPLATE │ │ └── pull_request_template.md │ ├── workflow_scripts │ │ └── install_testpypi_pkg.sh │ └── workflows │ ├── build_and_release.yml │ ├── build_utils │ │ └── test_version.py │ ├── docs-push.yml │ ├── precommits_check.yml │ └── run_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── docs │ ├── .gitignore │ ├── docs │ │ ├── api │ │ │ ├── adapters │ │ │ │ ├── Adapter.md │ │ │ │ ├── ChatAdapter.md │ │ │ │ ├── JSONAdapter.md │ │ │ │ └── TwoStepAdapter.md │ │ │ ├── evaluation │ │ │ │ ├── answer_exact_match.md │ │ │ │ ├── answer_passage_match.md │ │ │ │ ├── CompleteAndGrounded.md │ │ │ │ ├── Evaluate.md │ │ │ │ ├── EvaluationResult.md │ │ │ │ └── SemanticF1.md │ │ │ ├── experimental │ │ │ │ ├── Citations.md │ │ │ │ └── Document.md │ │ │ ├── index.md │ │ │ ├── models │ │ │ │ ├── Embedder.md │ │ │ │ └── LM.md │ │ │ ├── modules │ │ │ │ ├── BestOfN.md │ │ │ │ ├── ChainOfThought.md │ │ │ │ ├── CodeAct.md │ │ │ │ ├── Module.md │ │ │ │ ├── MultiChainComparison.md │ │ │ │ ├── Parallel.md │ │ │ │ ├── Predict.md │ │ │ │ ├── ProgramOfThought.md │ │ │ │ ├── ReAct.md │ │ │ │ └── Refine.md │ │ │ ├── optimizers │ │ │ │ ├── BetterTogether.md │ │ │ │ ├── BootstrapFewShot.md │ │ │ │ ├── BootstrapFewShotWithRandomSearch.md │ │ │ │ ├── BootstrapFinetune.md │ │ │ │ ├── BootstrapRS.md │ │ │ │ ├── COPRO.md │ │ │ │ ├── Ensemble.md │ │ │ │ ├── GEPA │ │ │ │ │ ├── GEPA_Advanced.md │ │ │ │ │ └── overview.md │ │ │ │ ├── InferRules.md │ │ │ │ ├── KNN.md │ │ │ │ ├── KNNFewShot.md │ │ │ │ ├── LabeledFewShot.md │ │ │ │ ├── MIPROv2.md │ │ │ │ └── SIMBA.md │ │ │ ├── primitives │ │ │ │ ├── Audio.md │ │ │ │ ├── Code.md │ │ │ │ ├── Example.md │ │ │ │ ├── History.md │ │ │ │ ├── Image.md │ │ │ │ ├── Prediction.md │ │ │ │ ├── Tool.md │ │ │ │ └── ToolCalls.md │ │ │ ├── signatures │ │ │ │ ├── InputField.md │ │ │ │ ├── OutputField.md │ │ │ │ └── Signature.md │ │ │ ├── tools │ │ │ │ ├── ColBERTv2.md │ │ │ │ ├── Embeddings.md │ │ │ │ └── PythonInterpreter.md │ │ │ └── utils │ │ │ ├── asyncify.md │ │ │ ├── configure_cache.md │ │ │ ├── disable_litellm_logging.md │ │ │ ├── disable_logging.md │ │ │ ├── enable_litellm_logging.md │ │ │ ├── enable_logging.md │ │ │ ├── inspect_history.md │ │ │ ├── load.md │ │ │ ├── StatusMessage.md │ │ │ ├── StatusMessageProvider.md │ │ │ ├── streamify.md │ │ │ └── StreamListener.md │ │ ├── cheatsheet.md │ │ ├── community │ │ │ ├── community-resources.md │ │ │ ├── how-to-contribute.md │ │ │ └── use-cases.md │ │ ├── deep-dive │ │ │ └── data-handling │ │ │ ├── built-in-datasets.md │ │ │ ├── examples.md │ │ │ ├── img │ │ │ │ └── data-loading.png │ │ │ └── loading-custom-data.md │ │ ├── faqs.md │ │ ├── index.md │ │ ├── js │ │ │ └── runllm-widget.js │ │ ├── learn │ │ │ ├── evaluation │ │ │ │ ├── data.md │ │ │ │ ├── metrics.md │ │ │ │ └── overview.md │ │ │ ├── figures │ │ │ │ ├── native_tool_call.png │ │ │ │ └── teleprompter-classes.png │ │ │ ├── index.md │ │ │ ├── optimization │ │ │ │ ├── optimizers.md │ │ │ │ └── overview.md │ │ │ └── programming │ │ │ ├── 7-assertions.md │ │ │ ├── adapters.md │ │ │ ├── language_models.md │ │ │ 
├── mcp.md │ │ │ ├── modules.md │ │ │ ├── overview.md │ │ │ ├── signatures.md │ │ │ └── tools.md │ │ ├── production │ │ │ └── index.md │ │ ├── roadmap.md │ │ ├── static │ │ │ ├── .nojekyll │ │ │ └── img │ │ │ ├── dspy_logo.png │ │ │ ├── logo.png │ │ │ ├── mlflow-tracing-rag.png │ │ │ ├── modular.png │ │ │ ├── optimize.png │ │ │ ├── undraw_docusaurus_mountain.svg │ │ │ ├── undraw_docusaurus_react.svg │ │ │ ├── undraw_docusaurus_tree.svg │ │ │ └── universal_compatibility.png │ │ ├── stylesheets │ │ │ └── extra.css │ │ └── tutorials │ │ ├── agents │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── ai_text_game │ │ │ └── index.md │ │ ├── async │ │ │ └── index.md │ │ ├── audio │ │ │ └── index.ipynb │ │ ├── build_ai_program │ │ │ └── index.md │ │ ├── cache │ │ │ └── index.md │ │ ├── classification │ │ │ └── index.md │ │ ├── classification_finetuning │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-classification.png │ │ ├── conversation_history │ │ │ └── index.md │ │ ├── core_development │ │ │ └── index.md │ │ ├── custom_module │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-custom-module.png │ │ ├── customer_service_agent │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-customer-service-agent.png │ │ ├── deployment │ │ │ ├── dspy_mlflow_ui.png │ │ │ └── index.md │ │ ├── email_extraction │ │ │ ├── index.md │ │ │ └── mlflow-tracing-email-extraction.png │ │ ├── entity_extraction │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-entity-extraction.png │ │ ├── games │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── gepa_ai_program │ │ │ └── index.md │ │ ├── gepa_aime │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-aime.png │ │ │ └── mlflow-tracking-gepa-aime-optimization.png │ │ ├── gepa_facilitysupportanalyzer │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-support.png │ │ │ └── mlflow-tracking-gepa-support-optimization.png │ │ ├── gepa_papillon │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-papilon.png │ │ │ └── mlflow-tracking-gepa-papilon-optimization.png │ │ ├── image_generation_prompting │ │ │ └── index.ipynb │ │ ├── index.md │ │ ├── llms_txt_generation │ │ │ └── index.md │ │ ├── math │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-math.png │ │ ├── mcp │ │ │ └── index.md │ │ ├── mem0_react_agent │ │ │ └── index.md │ │ ├── multihop_search │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-multi-hop.png │ │ ├── observability │ │ │ ├── index.md │ │ │ ├── mlflow_trace_ui_navigation.gif │ │ │ ├── mlflow_trace_ui.png │ │ │ └── mlflow_trace_view.png │ │ ├── optimize_ai_program │ │ │ └── index.md │ │ ├── optimizer_tracking │ │ │ ├── child_run.png │ │ │ ├── experiment.png │ │ │ ├── index.md │ │ │ └── parent_run.png │ │ ├── output_refinement │ │ │ └── best-of-n-and-refine.md │ │ ├── papillon │ │ │ └── index.md │ │ ├── program_of_thought │ │ │ └── index.ipynb │ │ ├── rag │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-rag.png │ │ ├── real_world_examples │ │ │ └── index.md │ │ ├── rl_ai_program │ │ │ └── index.md │ │ ├── rl_multihop │ │ │ └── index.ipynb │ │ ├── rl_papillon │ │ │ └── index.ipynb │ │ ├── sample_code_generation │ │ │ └── index.md │ │ ├── saving │ │ │ └── index.md │ │ ├── streaming │ │ │ └── index.md │ │ ├── tool_use │ │ │ └── index.ipynb │ │ └── yahoo_finance_react │ │ └── index.md │ ├── mkdocs.yml │ ├── overrides │ │ ├── home.html │ │ ├── main.html │ │ └── partials │ │ └── tabs.html │ ├── Pipfile │ ├── Pipfile.lock │ ├── README.md │ ├── requirements.txt │ ├── scripts │ │ ├── generate_api_docs.py │ │ └── generate_api_summary.py │ └── vercel.json ├── dspy │ ├── 
__init__.py │ ├── __metadata__.py │ ├── adapters │ │ ├── __init__.py │ │ ├── baml_adapter.py │ │ ├── base.py │ │ ├── chat_adapter.py │ │ ├── json_adapter.py │ │ ├── two_step_adapter.py │ │ ├── types │ │ │ ├── __init__.py │ │ │ ├── audio.py │ │ │ ├── base_type.py │ │ │ ├── citation.py │ │ │ ├── code.py │ │ │ ├── document.py │ │ │ ├── history.py │ │ │ ├── image.py │ │ │ └── tool.py │ │ ├── utils.py │ │ └── xml_adapter.py │ ├── clients │ │ ├── __init__.py │ │ ├── base_lm.py │ │ ├── cache.py │ │ ├── databricks.py │ │ ├── embedding.py │ │ ├── lm_local_arbor.py │ │ ├── lm_local.py │ │ ├── lm.py │ │ ├── openai.py │ │ ├── provider.py │ │ └── utils_finetune.py │ ├── datasets │ │ ├── __init__.py │ │ ├── alfworld │ │ │ ├── __init__.py │ │ │ ├── alfworld.py │ │ │ └── base_config.yml │ │ ├── colors.py │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── gsm8k.py │ │ ├── hotpotqa.py │ │ └── math.py │ ├── dsp │ │ ├── __init__.py │ │ ├── colbertv2.py │ │ └── utils │ │ ├── __init__.py │ │ ├── dpr.py │ │ ├── settings.py │ │ └── utils.py │ ├── evaluate │ │ ├── __init__.py │ │ ├── auto_evaluation.py │ │ ├── evaluate.py │ │ └── metrics.py │ ├── experimental │ │ └── __init__.py │ ├── predict │ │ ├── __init__.py │ │ ├── aggregation.py │ │ ├── avatar │ │ │ ├── __init__.py │ │ │ ├── avatar.py │ │ │ ├── models.py │ │ │ └── signatures.py │ │ ├── best_of_n.py │ │ ├── chain_of_thought.py │ │ ├── code_act.py │ │ ├── knn.py │ │ ├── multi_chain_comparison.py │ │ ├── parallel.py │ │ ├── parameter.py │ │ ├── predict.py │ │ ├── program_of_thought.py │ │ ├── react.py │ │ ├── refine.py │ │ └── retry.py │ ├── primitives │ │ ├── __init__.py │ │ ├── base_module.py │ │ ├── example.py │ │ ├── module.py │ │ ├── prediction.py │ │ ├── python_interpreter.py │ │ └── runner.js │ ├── propose │ │ ├── __init__.py │ │ ├── dataset_summary_generator.py │ │ ├── grounded_proposer.py │ │ ├── propose_base.py │ │ └── utils.py │ ├── retrievers │ │ ├── __init__.py │ │ ├── databricks_rm.py │ │ ├── embeddings.py │ │ ├── retrieve.py │ │ └── weaviate_rm.py │ ├── signatures │ │ ├── __init__.py │ │ ├── field.py │ │ ├── signature.py │ │ └── utils.py │ ├── streaming │ │ ├── __init__.py │ │ ├── messages.py │ │ ├── streamify.py │ │ └── streaming_listener.py │ ├── teleprompt │ │ ├── __init__.py │ │ ├── avatar_optimizer.py │ │ ├── bettertogether.py │ │ ├── bootstrap_finetune.py │ │ ├── bootstrap_trace.py │ │ ├── bootstrap.py │ │ ├── copro_optimizer.py │ │ ├── ensemble.py │ │ ├── gepa │ │ │ ├── __init__.py │ │ │ ├── gepa_utils.py │ │ │ ├── gepa.py │ │ │ └── instruction_proposal.py │ │ ├── grpo.py │ │ ├── infer_rules.py │ │ ├── knn_fewshot.py │ │ ├── mipro_optimizer_v2.py │ │ ├── random_search.py │ │ ├── signature_opt.py │ │ ├── simba_utils.py │ │ ├── simba.py │ │ ├── teleprompt_optuna.py │ │ ├── teleprompt.py │ │ ├── utils.py │ │ └── vanilla.py │ └── utils │ ├── __init__.py │ ├── annotation.py │ ├── asyncify.py │ ├── caching.py │ ├── callback.py │ ├── dummies.py │ ├── exceptions.py │ ├── hasher.py │ ├── inspect_history.py │ ├── langchain_tool.py │ ├── logging_utils.py │ ├── mcp.py │ ├── parallelizer.py │ ├── saving.py │ ├── syncify.py │ ├── unbatchify.py │ └── usage_tracker.py ├── LICENSE ├── pyproject.toml ├── README.md ├── tests │ ├── __init__.py │ ├── adapters │ │ ├── test_adapter_utils.py │ │ ├── test_baml_adapter.py │ │ ├── test_base_type.py │ │ ├── test_chat_adapter.py │ │ ├── test_citation.py │ │ ├── test_code.py │ │ ├── test_document.py │ │ ├── test_json_adapter.py │ │ ├── test_tool.py │ │ ├── test_two_step_adapter.py │ │ └── test_xml_adapter.py │ ├── 
callback │ │ └── test_callback.py │ ├── clients │ │ ├── test_cache.py │ │ ├── test_databricks.py │ │ ├── test_embedding.py │ │ ├── test_inspect_global_history.py │ │ └── test_lm.py │ ├── conftest.py │ ├── datasets │ │ └── test_dataset.py │ ├── docs │ │ └── test_mkdocs_links.py │ ├── evaluate │ │ ├── test_evaluate.py │ │ └── test_metrics.py │ ├── examples │ │ └── test_baleen.py │ ├── metadata │ │ └── test_metadata.py │ ├── predict │ │ ├── test_aggregation.py │ │ ├── test_best_of_n.py │ │ ├── test_chain_of_thought.py │ │ ├── test_code_act.py │ │ ├── test_knn.py │ │ ├── test_multi_chain_comparison.py │ │ ├── test_parallel.py │ │ ├── test_predict.py │ │ ├── test_program_of_thought.py │ │ ├── test_react.py │ │ ├── test_refine.py │ │ └── test_retry.py │ ├── primitives │ │ ├── resources │ │ │ └── saved_program.json │ │ ├── test_base_module.py │ │ ├── test_example.py │ │ ├── test_module.py │ │ └── test_python_interpreter.py │ ├── propose │ │ └── test_grounded_proposer.py │ ├── README.md │ ├── reliability │ │ ├── __init__.py │ │ ├── complex_types │ │ │ └── generated │ │ │ ├── test_many_types_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ ├── test_nesting_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ └── test_nesting_2 │ │ │ ├── inputs │ │ │ │ └── input1.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── conftest.py │ │ ├── generate │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ └── utils.py │ │ ├── input_formats │ │ │ └── generated │ │ │ └── test_markdown_1 │ │ │ ├── inputs │ │ │ │ ├── input1.json │ │ │ │ └── input2.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── README.md │ │ ├── reliability_conf.yaml │ │ ├── test_generated.py │ │ ├── test_pydantic_models.py │ │ └── utils.py │ ├── retrievers │ │ └── test_embeddings.py │ ├── signatures │ │ ├── test_adapter_image.py │ │ ├── test_custom_types.py │ │ └── test_signature.py │ ├── streaming │ │ └── test_streaming.py │ ├── teleprompt │ │ ├── gepa_dummy_lm_custom_component_selector_custom_instruction_proposer.json │ │ ├── gepa_dummy_lm.json │ │ ├── test_bootstrap_finetune.py │ │ ├── test_bootstrap_trace.py │ │ ├── test_bootstrap.py │ │ ├── test_copro_optimizer.py │ │ ├── test_ensemble.py │ │ ├── test_finetune.py │ │ ├── test_gepa_instruction_proposer.py │ │ ├── test_gepa.py │ │ ├── test_grpo.py │ │ ├── test_knn_fewshot.py │ │ ├── test_random_search.py │ │ ├── test_teleprompt.py │ │ └── test_utils.py │ ├── test_utils │ │ ├── __init__.py │ │ └── server │ │ ├── __init__.py │ │ ├── litellm_server_config.yaml │ │ └── litellm_server.py │ └── utils │ ├── __init__.py │ ├── resources │ │ └── mcp_server.py │ ├── test_annotation.py │ ├── test_asyncify.py │ ├── test_exceptions.py │ ├── test_langchain_tool.py │ ├── test_mcp.py │ ├── test_parallelizer.py │ ├── test_saving.py │ ├── test_settings.py │ ├── test_syncify.py │ ├── test_unbatchify.py │ └── test_usage_tracker.py └── uv.lock
```

# Files

--------------------------------------------------------------------------------
/docs/scripts/generate_api_docs.py:
--------------------------------------------------------------------------------

```python
import importlib
import inspect
import pkgutil
from pathlib import Path
from typing import Any

import dspy

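# Maps each docs/api/<category> page group to the public objects documented under that category.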
API_MAPPING = {
    "models": [
        dspy.LM,
        dspy.Embedder,
    ],
    "primitives": [
        dspy.Audio,
        dspy.Code,
        dspy.Example,
        dspy.Image,
        dspy.History,
        dspy.Prediction,
        dspy.Tool,
        dspy.ToolCalls,
    ],
    "signatures": [
        dspy.Signature,
        dspy.InputField,
        dspy.OutputField,
    ],
    "adapters": [
        dspy.Adapter,
        dspy.ChatAdapter,
        dspy.JSONAdapter,
        dspy.TwoStepAdapter,
    ],
    "modules": [
        dspy.Module,
        dspy.Predict,
        dspy.ChainOfThought,
        dspy.ReAct,
        dspy.ProgramOfThought,
        dspy.MultiChainComparison,
        dspy.Parallel,
        dspy.BestOfN,
        dspy.Refine,
    ],
    "tools": [
        dspy.ColBERTv2,
        dspy.retrievers.Embeddings,
        dspy.PythonInterpreter,
    ],
    "utils": [
        dspy.inspect_history,
        dspy.load,
        dspy.asyncify,
        dspy.streamify,
        dspy.enable_logging,
        dspy.disable_logging,
        dspy.enable_litellm_logging,
        dspy.disable_litellm_logging,
        dspy.configure_cache,
        dspy.streaming.StatusMessageProvider,
        dspy.streaming.StatusMessage,
        dspy.streaming.StreamListener,
    ],
    "evaluation": [
        dspy.Evaluate,
        dspy.evaluate.answer_exact_match,
        dspy.evaluate.answer_passage_match,
        dspy.evaluate.SemanticF1,
        dspy.evaluate.CompleteAndGrounded,
        dspy.evaluate.EvaluationResult,
    ],
    "optimizers": [
        dspy.LabeledFewShot,
        dspy.BootstrapFewShot,
        dspy.BootstrapFewShotWithRandomSearch,
        dspy.MIPROv2,
        dspy.BetterTogether,
        dspy.BootstrapFinetune,
        dspy.COPRO,
        dspy.Ensemble,
        dspy.KNN,
        dspy.KNNFewShot,
        dspy.InferRules,
        dspy.GEPA,
    ],
    "experimental": [
        dspy.experimental.Citations,
        dspy.experimental.Document,
    ],
}

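# Pages that should be written somewhere other than the default docs/api/<category>/<Name>.md path.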
LOCATION_OVERRIDES = {
    "docs/api/optimizers/GEPA.md": "docs/api/optimizers/GEPA/overview.md",
}

def should_document_method(obj):
    name = obj.__name__
    # Exclude methods not defined in dspy, such as `model_dump_json` from pydantic.
    module = getattr(obj, "__module__", "")
    if not module or not module.startswith("dspy"):
        return False
    # Exclude private and dunder methods, but include `__call__`
    if name == "__call__" or not name.startswith("_"):
        return True
    return False


def get_module_contents(module):
    """Get all public classes and functions from a module."""
    contents_in_all = getattr(module, "__all__", None)

    contents = {}
    for name, obj in inspect.getmembers(module):
        if contents_in_all and name not in contents_in_all:
            continue
        if inspect.ismodule(obj) and obj.__name__.startswith(module.__name__) and not name.startswith("_"):
            contents[name] = obj
        elif (
            (inspect.isclass(obj) or (inspect.isroutine(obj) and should_document_method(obj)))
            # classes or functions in experimental module are not located in dspy/experimental
            and (obj.__module__.startswith(module.__name__) or module.__name__.startswith("dspy.experimental"))
            and not name.startswith("_")
        ):
            contents[name] = obj
    return contents


def get_public_methods(cls):
    """Returns a list of all public methods in a class."""
    return [
        name
        for name, member in inspect.getmembers(
            cls, predicate=lambda x: inspect.isroutine(x) and should_document_method(x)
        )
    ]


def generate_doc_page(name: str, module_path: str, obj: Any, is_root: bool = False) -> str:
    """Generate documentation page content for an object."""
    members_config = ""
    if inspect.isclass(obj):
        methods = get_public_methods(obj)
        if methods:
            methods_list = "\n".join(f"        - {method}" for method in methods)
            members_config = f"""
        members:
{methods_list}"""

    # We need to put ::: at last to avoid unclosed div. See https://github.com/danielfrg/mkdocs-jupyter/issues/231 for more details.
    return f"""<!-- START_API_REF -->
::: {module_path}.{name}
    handler: python
    options:{members_config}
        show_source: true
        show_root_heading: true
        heading_level: 2
        docstring_style: google
        show_root_full_path: true
        show_object_full_path: false
        separate_signature: false
        inherited_members: true
:::
<!-- END_API_REF -->
"""


def get_api_category(obj):
    for category, objects in API_MAPPING.items():
        if obj in objects:
            return category
    return None


def read_existing_content(file_path: Path) -> tuple[str, str]:
    """Read existing file content and split into pre and post API reference sections.

    Returns:
        tuple[str, str]: (content_before_api_ref, content_after_api_ref)
            If file doesn't exist or no API ref section found, returns empty strings.
    """
    if not file_path.exists():
        return "", ""

    content = file_path.read_text()

    # Look for our specific API reference markers
    api_start_marker = "<!-- START_API_REF -->"
    api_end_marker = "<!-- END_API_REF -->"

    api_start = content.find(api_start_marker)
    if api_start == -1:
        # No API section found, treat all content as pre-content
        return content, ""

    api_end = content.find(api_end_marker)
    if api_end == -1:
        # Start marker found but no end marker - treat rest of file as post-content
        api_end = len(content)
    else:
        api_end = api_end + len(api_end_marker)

    return content[:api_start].rstrip(), content[api_end:].lstrip()


def write_doc_file(file_path: Path, title: str, api_content: str):
    """Write documentation to file while preserving existing content."""
    pre_content, post_content = read_existing_content(file_path)

    # If no pre-content exists, add the title
    if not pre_content:
        pre_content = f"# {title}\n"

    # Combine all sections
    full_content = f"{pre_content}\n\n{api_content}\n{post_content}".strip() + "\n"

    # Write the combined content
    file_path.write_text(full_content)


def generate_md_docs(output_dir: Path, excluded_modules=None):
    """Generate documentation for all public classes and functions in the dspy package.

    Args:
        output_dir: The directory to write the documentation to, e.g. "docs/api"
        excluded_modules: A list of modules to exclude from documentation, e.g. ["dspy.dsp"]
    """
    module = importlib.import_module("dspy")
    output_dir.mkdir(parents=True, exist_ok=True)

    init_contents = get_module_contents(module)
    objects_processed = {}

    # Generate docs for root-level objects, e.g. dspy.Predict, dspy.Example, etc.
    for name, obj in init_contents.items():
        if inspect.ismodule(obj):
            continue

        category = get_api_category(obj)
        if category is None:
            # Skip if the object is not in the API mapping.
            continue

        page_content = generate_doc_page(name, "dspy", obj, is_root=True)
        file_path = output_dir / category / f"{name}.md"
        if file_path.as_posix() in LOCATION_OVERRIDES:
            file_path = Path(LOCATION_OVERRIDES[file_path.as_posix()])
        write_doc_file(file_path, f"dspy.{name}", page_content)

        objects_processed[f"{obj.__module__}.{name}"] = obj

    for submodule in pkgutil.iter_modules(module.__path__, prefix=f"{module.__name__}."):
        submodule_name = submodule.name.split(".")[-1]

        # Skip if this is a private module or not in __init__.py
        if submodule_name.startswith("_") or submodule_name not in init_contents:
            continue

        generate_md_docs_submodule(submodule.name, output_dir, objects_processed, excluded_modules)

"docs/api/predict" 271 | objects_processed: A dictionary of objects that have already been processed, used to avoid redundant processing. 272 | excluded_modules: A list of modules to exclude from documentation, e.g. ["dspy.dsp"] 273 | """ 274 | if excluded_modules and module_path in excluded_modules: 275 | return 276 | 277 | try: 278 | module = importlib.import_module(module_path) 279 | except ImportError: 280 | print(f"Skipping {module_path} due to import error") 281 | return 282 | 283 | init_contents = get_module_contents(module) 284 | 285 | for name, obj in init_contents.items(): 286 | if inspect.ismodule(obj): 287 | continue 288 | 289 | category = get_api_category(obj) 290 | if category is None: 291 | # Skip if the object is not in the API mapping. 292 | continue 293 | 294 | full_name = f"{obj.__module__}.{name}" 295 | if full_name not in objects_processed: 296 | # Only generate docs for objects that are not root-level objects. 297 | page_content = generate_doc_page(name, module_path, obj, is_root=False) 298 | file_path = output_dir / category / f"{name}.md" 299 | if file_path.as_posix() in LOCATION_OVERRIDES: 300 | file_path = Path(LOCATION_OVERRIDES[file_path.as_posix()]) 301 | write_doc_file(file_path, f"{module_path}.{name}", page_content) 302 | 303 | objects_processed[full_name] = obj 304 | 305 | for name, obj in init_contents.items(): 306 | if inspect.ismodule(obj): 307 | generate_md_docs_submodule(f"{module_path}.{name}", output_dir / name, objects_processed) 308 | 309 | 310 | def remove_empty_dirs(path: Path): 311 | """Recursively remove empty directories.""" 312 | for child in path.glob("*"): 313 | if child.is_dir(): 314 | remove_empty_dirs(child) 315 | 316 | if path.is_dir() and not any(path.iterdir()): 317 | path.rmdir() 318 | 319 | 320 | if __name__ == "__main__": 321 | api_dir = Path("docs/api") 322 | api_dir.mkdir(parents=True, exist_ok=True) 323 | 324 | # Create category directories if they don't exist 325 | for category in API_MAPPING.keys(): 326 | subpath = api_dir / category 327 | subpath.mkdir(parents=True, exist_ok=True) 328 | 329 | excluded_modules = ["dspy.dsp"] 330 | generate_md_docs(api_dir, excluded_modules=excluded_modules) 331 | 332 | # Clean up empty directories 333 | remove_empty_dirs(api_dir) 334 | ``` -------------------------------------------------------------------------------- /dspy/teleprompt/simba_utils.py: -------------------------------------------------------------------------------- ```python 1 | import inspect 2 | import logging 3 | import textwrap 4 | from typing import Callable 5 | 6 | import orjson 7 | 8 | import dspy 9 | from dspy.adapters.utils import get_field_description_string 10 | from dspy.signatures import InputField, OutputField 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings: dict | None = None): 15 | lm = program.get_lm() or dspy.settings.lm 16 | 17 | start_rollout_id = lm.kwargs.get("rollout_id", 0) 18 | rollout_ids = [start_rollout_id + i for i in range(n)] 19 | 20 | 21 | start_rollout_idx, models = 0, [] 22 | # If we have a teacher model, use this as the first model 23 | if teacher_settings: 24 | teacher_lm = teacher_settings.get("lm") or lm 25 | teacher_lm.kwargs["rollout_id"] = rollout_ids[start_rollout_idx] 26 | models.append(teacher_lm) 27 | start_rollout_idx += 1 28 | 29 | # The rest of the models are just copies of the base model 30 | models.extend([lm.copy(rollout_id=r, temperature=1.0) for r in rollout_ids[start_rollout_idx:]]) 31 
    start_rollout_id = lm.kwargs.get("rollout_id", 0)
    rollout_ids = [start_rollout_id + i for i in range(n)]


    start_rollout_idx, models = 0, []
    # If we have a teacher model, use this as the first model
    if teacher_settings:
        teacher_lm = teacher_settings.get("lm") or lm
        teacher_lm.kwargs["rollout_id"] = rollout_ids[start_rollout_idx]
        models.append(teacher_lm)
        start_rollout_idx += 1

    # The rest of the models are just copies of the base model
    models.extend([lm.copy(rollout_id=r, temperature=1.0) for r in rollout_ids[start_rollout_idx:]])

    return models

def wrap_program(program: dspy.Module, metric: Callable):
    def wrapped_program(example):
        with dspy.context(trace=[]):
            prediction, trace, score = None, None, 0.0
            try:
                prediction = program(**example.inputs())
            except Exception as e:
                logger.warning(e)
            trace = dspy.settings.trace.copy()

        output = None
        score = 0.0
        output_metadata = {}

        try:
            output = metric(example, prediction)
            if isinstance(output, (int, float)):
                score = output
            elif isinstance(output, dspy.Prediction):
                if not hasattr(output, "score"):
                    raise ValueError("When `metric` returns a `dspy.Prediction`, it must contain a `score` field.")
                score = output.score
                # Extract fields from the output dspy.Prediction, excluding `score`
                output_metadata = {
                    k: v for k, v in output.items() if k != "score"
                }
        except Exception as e:
            logger.warning(e)

        return {
            "prediction": prediction,
            "trace": trace,
            "score": score,
            "example": example,
            "output_metadata": output_metadata
        }

    return wrapped_program

def append_a_demo(demo_input_field_maxlen):
    def append_a_demo_(bucket, system, **kwargs):
        predictor2name, name2predictor = kwargs["predictor2name"], kwargs["name2predictor"]
        batch_10p_score = kwargs["batch_10p_score"]

        good = bucket[0]
        trace = good["trace"]
        name2demo = {}

        if good["score"] <= batch_10p_score:
            logger.info(f"Skipping appending a demo as good score {good['score']} is at or below the 10th percentile.")
            return False

        for step in trace:
            predictor, _inputs, _outputs = step

            for k, v in _inputs.items():
                if demo_input_field_maxlen and len(str(v)) > demo_input_field_maxlen:
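                    # Truncate overly long input values so the stored demo stays compact in the prompt.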
                    _inputs[k] = f"{str(v)[:demo_input_field_maxlen]}\n\t\t... <TRUNCATED FOR BREVITY>"

            demo = dspy.Example(augmented=True, **_inputs, **_outputs)
            name = predictor2name[id(predictor)]
            name2demo[name] = demo  # keep the last demo for each predictor
        for name, demo in name2demo.items():
            predictor = name2predictor[name]
            predictor.demos.append(demo)

        logger.info(f"Added {len(name2demo)} demos (one each) across all predictors.")
        return True

    return append_a_demo_


def append_a_rule(bucket, system, **kwargs):
    predictor2name = kwargs["predictor2name"]
    batch_10p_score, batch_90p_score = kwargs["batch_10p_score"], kwargs["batch_90p_score"]
    prompt_model = kwargs["prompt_model"] or dspy.settings.lm

    module_names = [name for name, _ in system.named_predictors()]
    good, bad = bucket[0], bucket[-1]
    example = good["example"]

    if good["score"] <= batch_10p_score or bad["score"] >= batch_90p_score:
        logger.info(f"Skipping rule generation as good score {good['score']} is at or below the 10th percentile "
                    f"*or* bad score {bad['score']} is at or above the 90th percentile.")
        return False

    if good["score"] <= bad["score"]:
        if good["score"] > batch_90p_score:
            bad["trace"] = []
            bad["score"] = "N/A"
            bad["prediction"] = {"N/A": "Prediction not available"}
        else:
            good["trace"] = []
            good["score"] = "N/A"
            good["prediction"] = {"N/A": "Prediction not available"}

    better_trajectory = [
        {"module_name": predictor2name[id(p)], "inputs": i, "outputs": dict(o)}
        for p, i, o in good["trace"]
    ]
    worse_trajectory = [
        {"module_name": predictor2name[id(p)], "inputs": i, "outputs": dict(o)}
        for p, i, o in bad["trace"]
    ]

    kwargs = {
        "program_code": inspect.getsource(system.__class__),
        "modules_defn": inspect_modules(system),
        "program_inputs": {**example.inputs()},
        "oracle_metadata": {**example.labels()},
        "better_program_trajectory": better_trajectory,
        "better_program_outputs": dict(good["prediction"]),
        "worse_program_trajectory": worse_trajectory,
        "worse_program_outputs": dict(bad["prediction"] or {}),
        "worse_reward_value": bad["score"],
        "better_reward_value": good["score"],
        "worse_reward_info": bad["output_metadata"],
        "better_reward_info": good["output_metadata"],
        "module_names": module_names,
    }

    kwargs = {k: v if isinstance(v, str) else orjson.dumps(recursive_mask(v), option=orjson.OPT_INDENT_2).decode()
              for k, v in kwargs.items()}

    with dspy.settings.context(trace=[], lm=prompt_model):
        advice_program = dspy.Predict(OfferFeedback)
        advice = advice_program(**kwargs).module_advice

    for name, predictor in system.named_predictors():
        if name in advice:
            logger.info(f"Advice for {name}: {advice[name]}")
            instructions = predictor.signature.instructions + "\n\n" + advice[name]
            predictor.signature = predictor.signature.with_instructions(instructions)

    return True

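# Signature used by `append_a_rule` above: the prompt model contrasts a better and a worse trajectory
# and returns per-module advice.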
class OfferFeedback(dspy.Signature):
    """
    You will be given two trajectories of an LLM-driven program's execution. Your goal is to help the program's modules
    build up experience on how to maximize the reward value assigned to the program's outputs if it were to receive
    similar inputs in the future.

    The module won't see its own history. It will rely on your advice, which should balance being concrete and being
    generalizable.

    In your advice:
    - Avoid boilerplate. Offer advice that would change the module's behavior for the better in the future.
    - Ensure that advice offered to a module M is specific to that M's specific sub-task, not the overall program.
    - Rely on contrasting the behavior of the worse trajectory against the better trajectory in making recommendations.
    - Ensure each unique module name appears exactly once as a key in the advice dictionary.
    """

    program_code: str = InputField(desc="The code of the program that we are analyzing")
    modules_defn: str = InputField(desc="The definition of each module in the program, including its I/O")
    program_inputs: str = InputField(desc="The inputs to the program that we are analyzing")
    oracle_metadata: str = InputField(desc="Any (hidden) metadata about the training set instance we're analyzing")
    worse_program_trajectory: str = InputField(
        desc="The trajectory of the program's execution, showing each module's I/O"
    )
    worse_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing")
    worse_reward_value: float = InputField(desc="The reward value assigned to the program's outputs")
    worse_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.")
    better_program_trajectory: str = InputField(
        desc="The trajectory of the program's execution, showing each module's I/O"
    )
    better_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing")
    better_reward_value: float = InputField(desc="The reward value assigned to the program's outputs")
    better_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.")
    module_names: list[str] = InputField(desc="The names of the modules in the program, for which we seek advice")
    discussion: str = OutputField(desc="Discussing blame of where each module went wrong, if it did")
    module_advice: dict[str, str] = OutputField(
        desc="For each module, describe very concretely: If the module receives ${description of input or patterns "
        "therein}, then it should ${description of content, behavior, or strategies to adopt and/or others to avoid}. "
        "Basically, your advice should be such that if the module has access to your tip, it would be much more likely to act "
        "like the successful trajectory rather than the lower-scoring trajectory."
    )
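
# Renders each predictor's input/output fields and instructions as an indented text block
# (the `modules_defn` input above).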
def inspect_modules(program):
    separator = "-" * 80
    output = [separator]

    for name, predictor in program.named_predictors():
        signature = predictor.signature
        instructions = textwrap.dedent(signature.instructions)
        instructions = ("\n" + "\t" * 2).join([""] + instructions.splitlines())

        output.append(f"Module {name}")
        output.append("\n\tInput Fields:")
        output.append(("\n" + "\t" * 2).join([""] + get_field_description_string(signature.input_fields).splitlines()))
        output.append("\tOutput Fields:")
        output.append(("\n" + "\t" * 2).join([""] + get_field_description_string(signature.output_fields).splitlines()))
        output.append(f"\tOriginal Instructions: {instructions}")
        output.append(separator)

    return "\n".join([o.strip("\n") for o in output])


def recursive_mask(o):
    # If the object is already serializable, return it.
    try:
        orjson.dumps(o)
        return o
    except (TypeError, orjson.JSONEncodeError):
        pass

    # If it's a dictionary, apply recursively to its values.
    if isinstance(o, dict):
        return {k: recursive_mask(v) for k, v in o.items()}
    # If it's a list, apply recursively.
    elif isinstance(o, list):
        return [recursive_mask(v) for v in o]
    # If it's a tuple, apply recursively.
    elif isinstance(o, tuple):
        return tuple(recursive_mask(v) for v in o)
    # Otherwise, replace it with a placeholder string (or use repr(o)).
    else:
        return f"<non-serializable: {type(o).__name__}>"
```

--------------------------------------------------------------------------------
/docs/docs/learn/programming/adapters.md:
--------------------------------------------------------------------------------

```markdown
# Understanding DSPy Adapters

## What are Adapters?

Adapters are the bridge between `dspy.Predict` and the actual Language Model (LM). When you call a DSPy module, the
adapter takes your signature, user inputs, and other attributes like `demos` (few-shot examples) and converts them
into multi-turn messages that get sent to the LM.

The adapter system is responsible for:

- Translating DSPy signatures into system messages that define the task and request/response structure.
- Formatting input data according to the request structure outlined in DSPy signatures.
- Parsing LM responses back into structured DSPy outputs, such as `dspy.Prediction` instances.
- Managing conversation history and function calls.
- Converting pre-built DSPy types into LM prompt messages, e.g., `dspy.Tool`, `dspy.Image`, etc.

## Configure Adapters

You can use `dspy.configure(adapter=...)` to choose the adapter for the entire Python process, or
`with dspy.context(adapter=...):` to affect only the code inside that context block.

If no adapter is specified in the DSPy workflow, each `dspy.Predict.__call__` defaults to using the `dspy.ChatAdapter`.
Thus, the two code snippets below are equivalent:

```python
import dspy

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))

predict = dspy.Predict("question -> answer")
result = predict(question="What is the capital of France?")
```

```python
import dspy

dspy.configure(
    lm=dspy.LM("openai/gpt-4o-mini"),
    adapter=dspy.ChatAdapter(),  # This is the default value
)

predict = dspy.Predict("question -> answer")
result = predict(question="What is the capital of France?")
```
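
For scoped changes, the context-manager form overrides the adapter only for calls made inside the block. A minimal
sketch of this pattern, assuming the same model as above:

```python
import dspy

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"), adapter=dspy.ChatAdapter())

predict = dspy.Predict("question -> answer")

# Only calls made inside this block use JSONAdapter; the process-wide
# ChatAdapter configuration applies again once the block exits.
with dspy.context(adapter=dspy.JSONAdapter()):
    result = predict(question="What is the capital of France?")
```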

## Where Adapters Fit in the System

The flow works as follows:

1. The user calls their DSPy agent, typically a `dspy.Module` with inputs.
2. The inner `dspy.Predict` is invoked to obtain the LM response.
3. `dspy.Predict` calls **Adapter.format()**, which converts its signature, inputs, and demos into multi-turn messages sent to the `dspy.LM`. `dspy.LM` is a thin wrapper around `litellm`, which communicates with the LM endpoint.
4. The LM receives the messages and generates a response.
5. **Adapter.parse()** converts the LM response into structured DSPy outputs, as specified in the signature.
6. The caller of `dspy.Predict` receives the parsed outputs.

You can explicitly call `Adapter.format()` to view the messages sent to the LM.

```python
# Simplified flow example
signature = dspy.Signature("question -> answer")
inputs = {"question": "What is 2+2?"}
demos = [{"question": "What is 1+1?", "answer": "2"}]

adapter = dspy.ChatAdapter()
print(adapter.format(signature, demos, inputs))
```

The output should resemble:

```
{'role': 'system', 'content': 'Your input fields are:\n1. `question` (str):\nYour output fields are:\n1. `answer` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## question ## ]]\n{question}\n\n[[ ## answer ## ]]\n{answer}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n        Given the fields `question`, produce the fields `answer`.'}
{'role': 'user', 'content': '[[ ## question ## ]]\nWhat is 1+1?'}
{'role': 'assistant', 'content': '[[ ## answer ## ]]\n2\n\n[[ ## completed ## ]]\n'}
{'role': 'user', 'content': '[[ ## question ## ]]\nWhat is 2+2?\n\nRespond with the corresponding output fields, starting with the field `[[ ## answer ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`.'}
```
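
The inverse step works the same way: `Adapter.parse()` takes the signature and a raw completion and returns a
dictionary of output fields. A small sketch; the `completion` string here is a hypothetical LM response written in
`ChatAdapter`'s format:

```python
completion = "[[ ## answer ## ]]\n4\n\n[[ ## completed ## ]]"
print(adapter.parse(signature, completion))
# -> {'answer': '4'}
```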

## Types of Adapters

DSPy offers several adapter types, each tailored for specific use cases:

### ChatAdapter

**ChatAdapter** is the default adapter and works with all language models. It uses a field-based format with special markers.

#### Format Structure

ChatAdapter uses `[[ ## field_name ## ]]` markers to delineate fields. For fields of non-primitive Python types, it includes the JSON schema of the type. Below, we use `dspy.inspect_history()` to display the messages formatted by `dspy.ChatAdapter`.

```python
import dspy
import pydantic

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"), adapter=dspy.ChatAdapter())


class ScienceNews(pydantic.BaseModel):
    text: str
    scientists_involved: list[str]


class NewsQA(dspy.Signature):
    """Get news about the given science field"""

    science_field: str = dspy.InputField()
    year: int = dspy.InputField()
    num_of_outputs: int = dspy.InputField()
    news: list[ScienceNews] = dspy.OutputField(desc="science news")


predict = dspy.Predict(NewsQA)
predict(science_field="Computer Theory", year=2022, num_of_outputs=1)
dspy.inspect_history()
```

The output looks like:

```
[2025-08-15T22:24:29.378666]

System message:

Your input fields are:
1. `science_field` (str):
2. `year` (int):
3. `num_of_outputs` (int):
Your output fields are:
1. `news` (list[ScienceNews]): science news
All interactions will be structured in the following way, with the appropriate values filled in.

[[ ## science_field ## ]]
{science_field}

[[ ## year ## ]]
{year}

[[ ## num_of_outputs ## ]]
{num_of_outputs}

[[ ## news ## ]]
{news}        # note: the value you produce must adhere to the JSON schema: {"type": "array", "$defs": {"ScienceNews": {"type": "object", "properties": {"scientists_involved": {"type": "array", "items": {"type": "string"}, "title": "Scientists Involved"}, "text": {"type": "string", "title": "Text"}}, "required": ["text", "scientists_involved"], "title": "ScienceNews"}}, "items": {"$ref": "#/$defs/ScienceNews"}}

[[ ## completed ## ]]
In adhering to this structure, your objective is:
        Get news about the given science field


User message:

[[ ## science_field ## ]]
Computer Theory

[[ ## year ## ]]
2022

[[ ## num_of_outputs ## ]]
1

Respond with the corresponding output fields, starting with the field `[[ ## news ## ]]` (must be formatted as a valid Python list[ScienceNews]), and then ending with the marker for `[[ ## completed ## ]]`.


Response:

[[ ## news ## ]]
[
  {
    "scientists_involved": ["John Doe", "Jane Smith"],
    "text": "In 2022, researchers made significant advancements in quantum computing algorithms, demonstrating their potential to solve complex problems faster than classical computers. This breakthrough could revolutionize fields such as cryptography and optimization."
  }
]

[[ ## completed ## ]]
```

!!! info "Practice: locate Signature information in the printed LM history"

    Try adjusting the signature, and observe how the changes are reflected in the printed LM message.


Each field is preceded by a marker `[[ ## field_name ## ]]`. If an output field has non-primitive types, the instruction includes the type's JSON schema, and the output is formatted accordingly. Because the output field is structured as defined by ChatAdapter, it can be automatically parsed into structured data.

#### When to Use ChatAdapter

`ChatAdapter` offers the following advantages:

- **Universal compatibility**: Works with all language models, though smaller models may generate responses that do not match the required format.
- **Fallback protection**: If `ChatAdapter` fails, it automatically retries with `JSONAdapter`.

In general, `ChatAdapter` is a reliable choice if you don't have specific requirements.

#### When Not to Use ChatAdapter

Avoid using `ChatAdapter` if you are:

- **Latency sensitive**: `ChatAdapter` includes more boilerplate output tokens compared to other adapters, so if you're building a system sensitive to latency, consider using a different adapter.

### JSONAdapter

**JSONAdapter** prompts the LM to return JSON data containing all output fields as specified in the signature. It is effective for models that support structured output via the `response_format` parameter, leveraging native JSON generation capabilities for more reliable parsing.

#### Format Structure

The input part of the prompt formatted by `JSONAdapter` is similar to `ChatAdapter`, but the output part differs, as shown below:

```python
import dspy
import pydantic

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"), adapter=dspy.JSONAdapter())


class ScienceNews(pydantic.BaseModel):
    text: str
    scientists_involved: list[str]


class NewsQA(dspy.Signature):
    """Get news about the given science field"""

    science_field: str = dspy.InputField()
    year: int = dspy.InputField()
    num_of_outputs: int = dspy.InputField()
    news: list[ScienceNews] = dspy.OutputField(desc="science news")


predict = dspy.Predict(NewsQA)
predict(science_field="Computer Theory", year=2022, num_of_outputs=1)
dspy.inspect_history()
```

```
System message:

Your input fields are:
1. `science_field` (str):
2. `year` (int):
3. `num_of_outputs` (int):
Your output fields are:
1. `news` (list[ScienceNews]): science news
All interactions will be structured in the following way, with the appropriate values filled in.

Inputs will have the following structure:

[[ ## science_field ## ]]
{science_field}

[[ ## year ## ]]
{year}

[[ ## num_of_outputs ## ]]
{num_of_outputs}

Outputs will be a JSON object with the following fields.

{
  "news": "{news}        # note: the value you produce must adhere to the JSON schema: {\"type\": \"array\", \"$defs\": {\"ScienceNews\": {\"type\": \"object\", \"properties\": {\"scientists_involved\": {\"type\": \"array\", \"items\": {\"type\": \"string\"}, \"title\": \"Scientists Involved\"}, \"text\": {\"type\": \"string\", \"title\": \"Text\"}}, \"required\": [\"text\", \"scientists_involved\"], \"title\": \"ScienceNews\"}}, \"items\": {\"$ref\": \"#/$defs/ScienceNews\"}}"
}
In adhering to this structure, your objective is:
        Get news about the given science field


User message:

[[ ## science_field ## ]]
Computer Theory

[[ ## year ## ]]
2022

[[ ## num_of_outputs ## ]]
1

Respond with a JSON object in the following order of fields: `news` (must be formatted as a valid Python list[ScienceNews]).


Response:

{
  "news": [
    {
      "text": "In 2022, researchers made significant advancements in quantum computing algorithms, demonstrating that quantum systems can outperform classical computers in specific tasks. This breakthrough could revolutionize fields such as cryptography and complex system simulations.",
      "scientists_involved": [
        "Dr. Alice Smith",
        "Dr. Bob Johnson",
        "Dr. Carol Lee"
      ]
    }
  ]
}
```

#### When to Use JSONAdapter

`JSONAdapter` offers the following advantages:

- **Structured output support**: When the model supports the `response_format` parameter, JSONAdapter can rely on native structured output for reliable parsing.
- **Low latency**: Minimal boilerplate in the LM response results in faster responses.

#### When Not to Use JSONAdapter

Avoid using `JSONAdapter` if you are:

- Using a model that does not natively support structured output, such as a small open-source model hosted on Ollama.

## Summary

Adapters are a crucial component of DSPy that bridge the gap between structured DSPy signatures and language model APIs.
Understanding when and how to use different adapters will help you build more reliable and efficient DSPy programs.
```

--------------------------------------------------------------------------------
/dspy/predict/react.py:
--------------------------------------------------------------------------------

```python
import logging
from typing import TYPE_CHECKING, Any, Callable, Literal

from litellm import ContextWindowExceededError

import dspy
from dspy.adapters.types.tool import Tool
from dspy.primitives.module import Module
from dspy.signatures.signature import ensure_signature

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from dspy.signatures.signature import Signature


class ReAct(Module):
    def __init__(self, signature: type["Signature"], tools: list[Callable], max_iters: int = 10):
        """
        ReAct stands for "Reasoning and Acting," a popular paradigm for building tool-using agents.
        In this approach, the language model is iteratively provided with a list of tools and has
        to reason about the current situation. The model decides whether to call a tool to gather more
        information or to finish the task based on its reasoning process. The DSPy version of ReAct is
        generalized to work over any signature, thanks to signature polymorphism.

        Args:
            signature: The signature of the module, which defines the input and output of the react module.
            tools (list[Callable]): A list of functions, callable objects, or `dspy.Tool` instances.
            max_iters (int): The maximum number of iterations to run. Defaults to 10.

        Example:

        ```python
        def get_weather(city: str) -> str:
            return f"The weather in {city} is sunny."

        react = dspy.ReAct(signature="question->answer", tools=[get_weather])
        pred = react(question="What is the weather in Tokyo?")
        ```
        """
        super().__init__()
        self.signature = signature = ensure_signature(signature)
        self.max_iters = max_iters

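        # Normalize plain callables into `dspy.Tool` objects, then index all tools by name.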
        tools = [t if isinstance(t, Tool) else Tool(t) for t in tools]
        tools = {tool.name: tool for tool in tools}

        inputs = ", ".join([f"`{k}`" for k in signature.input_fields.keys()])
        outputs = ", ".join([f"`{k}`" for k in signature.output_fields.keys()])
        instr = [f"{signature.instructions}\n"] if signature.instructions else []

        instr.extend(
            [
                f"You are an Agent. In each episode, you will be given the fields {inputs} as input. And you can see your past trajectory so far.",
                f"Your goal is to use one or more of the supplied tools to collect any necessary information for producing {outputs}.\n",
                "To do this, you will interleave next_thought, next_tool_name, and next_tool_args in each turn, and also when finishing the task.",
                "After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n",
                "When writing next_thought, you may reason about the current situation and plan for future steps.",
                "When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n",
            ]
        )

        tools["finish"] = Tool(
            func=lambda: "Completed.",
            name="finish",
            desc=f"Marks the task as complete. That is, signals that all information for producing the outputs, i.e. {outputs}, are now available to be extracted.",
            args={},
        )

        for idx, tool in enumerate(tools.values()):
            instr.append(f"({idx + 1}) {tool}")
        instr.append("When providing `next_tool_args`, the value inside the field must be in JSON format")

        react_signature = (
            dspy.Signature({**signature.input_fields}, "\n".join(instr))
            .append("trajectory", dspy.InputField(), type_=str)
            .append("next_thought", dspy.OutputField(), type_=str)
            .append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())])
            .append("next_tool_args", dspy.OutputField(), type_=dict[str, Any])
        )

        fallback_signature = dspy.Signature(
            {**signature.input_fields, **signature.output_fields},
            signature.instructions,
        ).append("trajectory", dspy.InputField(), type_=str)

        self.tools = tools
        self.react = dspy.Predict(react_signature)
        self.extract = dspy.ChainOfThought(fallback_signature)

    def _format_trajectory(self, trajectory: dict[str, Any]):
        adapter = dspy.settings.adapter or dspy.ChatAdapter()
        trajectory_signature = dspy.Signature(f"{', '.join(trajectory.keys())} -> x")
        return adapter.format_user_message_content(trajectory_signature, trajectory)

    def forward(self, **input_args):
        trajectory = {}
        max_iters = input_args.pop("max_iters", self.max_iters)
        for idx in range(max_iters):
            try:
                pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
            except ValueError as err:
                logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}")
                break

            trajectory[f"thought_{idx}"] = pred.next_thought
            trajectory[f"tool_name_{idx}"] = pred.next_tool_name
            trajectory[f"tool_args_{idx}"] = pred.next_tool_args

            try:
                trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args)
            except Exception as err:
                trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}"

            if pred.next_tool_name == "finish":
                break

        extract = self._call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args)
        return dspy.Prediction(trajectory=trajectory, **extract)

    async def aforward(self, **input_args):
        trajectory = {}
        max_iters = input_args.pop("max_iters", self.max_iters)
        for idx in range(max_iters):
            try:
                pred = await self._async_call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
            except ValueError as err:
                logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}")
                break

            trajectory[f"thought_{idx}"] = pred.next_thought
            trajectory[f"tool_name_{idx}"] = pred.next_tool_name
            trajectory[f"tool_args_{idx}"] = pred.next_tool_args

            try:
                trajectory[f"observation_{idx}"] = await self.tools[pred.next_tool_name].acall(**pred.next_tool_args)
            except Exception as err:
                trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}"

            if pred.next_tool_name == "finish":
                break

        extract = await self._async_call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args)
        return dspy.Prediction(trajectory=trajectory, **extract)

    def _call_with_potential_trajectory_truncation(self, module, trajectory, **input_args):
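        # Try up to three times, truncating the oldest tool call from the trajectory
        # whenever the context window is exceeded.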
        for _ in range(3):
            try:
                return module(
                    **input_args,
                    trajectory=self._format_trajectory(trajectory),
                )
            except ContextWindowExceededError:
                logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
                trajectory = self.truncate_trajectory(trajectory)

    async def _async_call_with_potential_trajectory_truncation(self, module, trajectory, **input_args):
        for _ in range(3):
            try:
                return await module.acall(
                    **input_args,
                    trajectory=self._format_trajectory(trajectory),
                )
            except ContextWindowExceededError:
                logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
                trajectory = self.truncate_trajectory(trajectory)

    def truncate_trajectory(self, trajectory):
        """Truncates the trajectory so that it fits in the context window.

        Users can override this method to implement their own truncation logic.
        """
        keys = list(trajectory.keys())
        if len(keys) < 4:
            # Every tool call has 4 keys: thought, tool_name, tool_args, and observation.
            raise ValueError(
                "The trajectory is too long so your prompt exceeded the context window, but the trajectory cannot be "
                "truncated because it only has one tool call."
            )

        for key in keys[:4]:
            trajectory.pop(key)

        return trajectory


def _fmt_exc(err: BaseException, *, limit: int = 5) -> str:
    """
    Return a one-string traceback summary.
    * `limit` - how many stack frames to keep (from the innermost outwards).
    """

    import traceback

    return "\n" + "".join(traceback.format_exception(type(err), err, err.__traceback__, limit=limit)).strip()
211 | 212 | Another potential fix is to more natively support a "variadic" input field, where the input is a list of dictionaries, 213 | or a big dictionary, and have each adapter format it accordingly. 214 | 215 | Trajectories also affect meta-programming modules that view the trace later. It's inefficient O(n^2) to view the 216 | trace of every module repeating the prefix. 217 | 218 | 219 | TOPIC 03: Simplifying ReAct's __init__ by moving modular logic to the Tool class. 220 | * Handling exceptions and error messages. 221 | * More cleanly defining the "finish" tool, perhaps as a runtime-defined function? 222 | 223 | 224 | TOPIC 04: Default behavior when the trajectory gets too long. 225 | 226 | 227 | TOPIC 05: Adding more structure around how the instruction is formatted. 228 | * Concretely, it's now a string, so an optimizer can and does rewrite it freely. 229 | * An alternative would be to add more structure, such that a certain template is fixed but values are variable? 230 | 231 | 232 | TOPIC 06: Idiomatically allowing tools that maintain state across iterations, but not across different `forward` calls. 233 | * So the tool would be newly initialized at the start of each `forward` call, but maintain state across iterations. 234 | * This is pretty useful for allowing the agent to keep notes or count certain things, etc. 235 | """ 236 | ``` -------------------------------------------------------------------------------- /tests/adapters/test_xml_adapter.py: -------------------------------------------------------------------------------- ```python 1 | from unittest import mock 2 | 3 | import pydantic 4 | import pytest 5 | from litellm import Choices, Message, ModelResponse 6 | 7 | import dspy 8 | from dspy.adapters.chat_adapter import FieldInfoWithName 9 | from dspy.adapters.xml_adapter import XMLAdapter 10 | 11 | 12 | def test_xml_adapter_format_and_parse_basic(): 13 | class TestSignature(dspy.Signature): 14 | question: str = dspy.InputField() 15 | answer: str = dspy.OutputField() 16 | 17 | adapter = XMLAdapter() 18 | # Format output fields as XML 19 | fields_with_values = {FieldInfoWithName(name="answer", info=TestSignature.output_fields["answer"]): "Paris"} 20 | xml = adapter.format_field_with_value(fields_with_values) 21 | assert xml.strip() == "<answer>\nParis\n</answer>" 22 | 23 | # Parse XML output 24 | completion = "<answer>Paris</answer>" 25 | parsed = adapter.parse(TestSignature, completion) 26 | assert parsed == {"answer": "Paris"} 27 | 28 | 29 | def test_xml_adapter_parse_multiple_fields(): 30 | class TestSignature(dspy.Signature): 31 | question: str = dspy.InputField() 32 | answer: str = dspy.OutputField() 33 | explanation: str = dspy.OutputField() 34 | 35 | adapter = XMLAdapter() 36 | completion = """ 37 | <answer>Paris</answer> 38 | <explanation>The capital of France is Paris.</explanation> 39 | """ 40 | parsed = adapter.parse(TestSignature, completion) 41 | assert parsed == {"answer": "Paris", "explanation": "The capital of France is Paris."} 42 | 43 | 44 | def test_xml_adapter_parse_raises_on_missing_field(): 45 | class TestSignature(dspy.Signature): 46 | question: str = dspy.InputField() 47 | answer: str = dspy.OutputField() 48 | explanation: str = dspy.OutputField() 49 | 50 | adapter = XMLAdapter() 51 | completion = "<answer>Paris</answer>" 52 | with pytest.raises(dspy.utils.exceptions.AdapterParseError) as e: 53 | adapter.parse(TestSignature, completion) 54 | assert e.value.adapter_name == "XMLAdapter" 55 | assert e.value.signature == TestSignature 56 | assert 
e.value.lm_response == "<answer>Paris</answer>" 57 | assert "explanation" in str(e.value) 58 | 59 | 60 | def test_xml_adapter_parse_casts_types(): 61 | class TestSignature(dspy.Signature): 62 | number: int = dspy.OutputField() 63 | flag: bool = dspy.OutputField() 64 | 65 | adapter = XMLAdapter() 66 | completion = """ 67 | <number>42</number> 68 | <flag>true</flag> 69 | """ 70 | parsed = adapter.parse(TestSignature, completion) 71 | assert parsed == {"number": 42, "flag": True} 72 | 73 | 74 | def test_xml_adapter_parse_raises_on_type_error(): 75 | class TestSignature(dspy.Signature): 76 | number: int = dspy.OutputField() 77 | 78 | adapter = XMLAdapter() 79 | completion = "<number>not_a_number</number>" 80 | with pytest.raises(dspy.utils.exceptions.AdapterParseError) as e: 81 | adapter.parse(TestSignature, completion) 82 | assert "Failed to parse field" in str(e.value) 83 | 84 | 85 | def test_xml_adapter_format_and_parse_nested_model(): 86 | class InnerModel(pydantic.BaseModel): 87 | value: int 88 | label: str 89 | 90 | class TestSignature(dspy.Signature): 91 | question: str = dspy.InputField() 92 | result: InnerModel = dspy.OutputField() 93 | 94 | adapter = XMLAdapter() 95 | # Format output fields as XML 96 | fields_with_values = { 97 | FieldInfoWithName(name="result", info=TestSignature.output_fields["result"]): InnerModel(value=5, label="foo") 98 | } 99 | xml = adapter.format_field_with_value(fields_with_values) 100 | # The output will be a JSON string inside the XML tag 101 | assert xml.strip().startswith("<result>") 102 | assert '"value": 5' in xml 103 | assert '"label": "foo"' in xml 104 | assert xml.strip().endswith("</result>") 105 | 106 | # Parse XML output (should parse as string, not as model) 107 | completion = '<result>{"value": 5, "label": "foo"}</result>' 108 | parsed = adapter.parse(TestSignature, completion) 109 | # The parse_value helper will try to cast to InnerModel 110 | assert isinstance(parsed["result"], InnerModel) 111 | assert parsed["result"].value == 5 112 | assert parsed["result"].label == "foo" 113 | 114 | 115 | def test_xml_adapter_format_and_parse_list_of_models(): 116 | class Item(pydantic.BaseModel): 117 | name: str 118 | score: float 119 | 120 | class TestSignature(dspy.Signature): 121 | items: list[Item] = dspy.OutputField() 122 | 123 | adapter = XMLAdapter() 124 | items = [Item(name="a", score=1.1), Item(name="b", score=2.2)] 125 | fields_with_values = {FieldInfoWithName(name="items", info=TestSignature.output_fields["items"]): items} 126 | xml = adapter.format_field_with_value(fields_with_values) 127 | assert xml.strip().startswith("<items>") 128 | assert '"name": "a"' in xml 129 | assert '"score": 2.2' in xml 130 | assert xml.strip().endswith("</items>") 131 | 132 | # Parse XML output 133 | import json 134 | 135 | completion = f"<items>{json.dumps([i.model_dump() for i in items])}</items>" 136 | parsed = adapter.parse(TestSignature, completion) 137 | assert isinstance(parsed["items"], list) 138 | assert all(isinstance(i, Item) for i in parsed["items"]) 139 | assert parsed["items"][0].name == "a" 140 | assert parsed["items"][1].score == 2.2 141 | 142 | 143 | def test_xml_adapter_with_tool_like_output(): 144 | # XMLAdapter does not natively support tool calls, but we can test structured output 145 | class ToolCall(pydantic.BaseModel): 146 | name: str 147 | args: dict 148 | result: str 149 | 150 | class TestSignature(dspy.Signature): 151 | question: str = dspy.InputField() 152 | tool_calls: list[ToolCall] = dspy.OutputField() 153 | answer: str = 
dspy.OutputField() 154 | 155 | adapter = XMLAdapter() 156 | tool_calls = [ 157 | ToolCall(name="get_weather", args={"city": "Tokyo"}, result="Sunny"), 158 | ToolCall(name="get_population", args={"country": "Japan", "year": 2023}, result="125M"), 159 | ] 160 | fields_with_values = { 161 | FieldInfoWithName(name="tool_calls", info=TestSignature.output_fields["tool_calls"]): tool_calls, 162 | FieldInfoWithName( 163 | name="answer", info=TestSignature.output_fields["answer"] 164 | ): "The weather is Sunny. Population is 125M.", 165 | } 166 | xml = adapter.format_field_with_value(fields_with_values) 167 | assert xml.strip().startswith("<tool_calls>") 168 | assert '"name": "get_weather"' in xml 169 | assert '"result": "125M"' in xml 170 | assert xml.strip().endswith("</answer>") 171 | 172 | import json 173 | 174 | completion = ( 175 | f"<tool_calls>{json.dumps([tc.model_dump() for tc in tool_calls])}</tool_calls>" 176 | f"\n<answer>The weather is Sunny. Population is 125M.</answer>" 177 | ) 178 | parsed = adapter.parse(TestSignature, completion) 179 | assert isinstance(parsed["tool_calls"], list) 180 | assert parsed["tool_calls"][0].name == "get_weather" 181 | assert parsed["tool_calls"][1].result == "125M" 182 | assert parsed["answer"] == "The weather is Sunny. Population is 125M." 183 | 184 | 185 | def test_xml_adapter_formats_nested_images(): 186 | class ImageWrapper(pydantic.BaseModel): 187 | images: list[dspy.Image] 188 | tag: list[str] 189 | 190 | class MySignature(dspy.Signature): 191 | image: ImageWrapper = dspy.InputField() 192 | text: str = dspy.OutputField() 193 | 194 | image1 = dspy.Image(url="https://example.com/image1.jpg") 195 | image2 = dspy.Image(url="https://example.com/image2.jpg") 196 | image3 = dspy.Image(url="https://example.com/image3.jpg") 197 | 198 | image_wrapper = ImageWrapper(images=[image1, image2, image3], tag=["test", "example"]) 199 | demos = [ 200 | dspy.Example( 201 | image=image_wrapper, 202 | text="This is a test image", 203 | ), 204 | ] 205 | 206 | image_wrapper_2 = ImageWrapper(images=[dspy.Image(url="https://example.com/image4.jpg")], tag=["test", "example"]) 207 | adapter = dspy.XMLAdapter() 208 | messages = adapter.format(MySignature, demos, {"image": image_wrapper_2}) 209 | 210 | assert len(messages) == 4 211 | 212 | # Image information in the few-shot example's user message 213 | expected_image1_content = {"type": "image_url", "image_url": {"url": "https://example.com/image1.jpg"}} 214 | expected_image2_content = {"type": "image_url", "image_url": {"url": "https://example.com/image2.jpg"}} 215 | expected_image3_content = {"type": "image_url", "image_url": {"url": "https://example.com/image3.jpg"}} 216 | assert expected_image1_content in messages[1]["content"] 217 | assert expected_image2_content in messages[1]["content"] 218 | assert expected_image3_content in messages[1]["content"] 219 | 220 | # The query image is formatted in the last user message 221 | assert {"type": "image_url", "image_url": {"url": "https://example.com/image4.jpg"}} in messages[-1]["content"] 222 | 223 | 224 | def test_xml_adapter_with_code(): 225 | # Test with code as input field 226 | class CodeAnalysis(dspy.Signature): 227 | """Analyze the time complexity of the code""" 228 | 229 | code: dspy.Code = dspy.InputField() 230 | result: str = dspy.OutputField() 231 | 232 | adapter = dspy.XMLAdapter() 233 | messages = adapter.format(CodeAnalysis, [], {"code": "print('Hello, world!')"}) 234 | 235 | assert len(messages) == 2 236 | 237 | # The output field type description should be 
included in the system message even if the output field is nested 238 | assert dspy.Code.description() in messages[0]["content"] 239 | 240 | # The user message should include the question and the tools 241 | assert "print('Hello, world!')" in messages[1]["content"] 242 | 243 | # Test with code as output field 244 | class CodeGeneration(dspy.Signature): 245 | """Generate code to answer the question""" 246 | 247 | question: str = dspy.InputField() 248 | code: dspy.Code = dspy.OutputField() 249 | 250 | adapter = dspy.XMLAdapter() 251 | with mock.patch("litellm.completion") as mock_completion: 252 | mock_completion.return_value = ModelResponse( 253 | choices=[Choices(message=Message(content='<code>print("Hello, world!")</code>'))], 254 | model="openai/gpt-4o-mini", 255 | ) 256 | result = adapter( 257 | dspy.LM(model="openai/gpt-4o-mini", cache=False), 258 | {}, 259 | CodeGeneration, 260 | [], 261 | {"question": "Write a python program to print 'Hello, world!'"}, 262 | ) 263 | assert result[0]["code"].code == 'print("Hello, world!")' 264 | 265 | 266 | def test_xml_adapter_full_prompt(): 267 | class QA(dspy.Signature): 268 | query: str = dspy.InputField() 269 | context: str | None = dspy.InputField() 270 | answer: str = dspy.OutputField() 271 | 272 | adapter = dspy.XMLAdapter() 273 | messages = adapter.format(QA, [], {"query": "when was Marie Curie born"}) 274 | 275 | assert len(messages) == 2 276 | assert messages[0]["role"] == "system" 277 | assert messages[1]["role"] == "user" 278 | 279 | expected_system = ( 280 | "Your input fields are:\n" 281 | "1. `query` (str): \n" 282 | "2. `context` (UnionType[str, NoneType]):\n" 283 | "Your output fields are:\n" 284 | "1. `answer` (str):\n" 285 | "All interactions will be structured in the following way, with the appropriate values filled in.\n\n" 286 | "<query>\n{query}\n</query>\n\n" 287 | "<context>\n{context}\n</context>\n\n" 288 | "<answer>\n{answer}\n</answer>\n" 289 | "In adhering to this structure, your objective is: \n" 290 | " Given the fields `query`, `context`, produce the fields `answer`." 291 | ) 292 | 293 | expected_user = ( 294 | "[[ ## query ## ]]\nwhen was Marie Curie born\n\n" 295 | "Respond with the corresponding output fields wrapped in XML tags `<answer>`." 296 | ) 297 | 298 | assert messages[0]["content"] == expected_system 299 | assert messages[1]["content"] == expected_user 300 | ``` -------------------------------------------------------------------------------- /dspy/primitives/base_module.py: -------------------------------------------------------------------------------- ```python 1 | import copy 2 | import logging 3 | from collections import deque 4 | from collections.abc import Generator 5 | from pathlib import Path 6 | 7 | import cloudpickle 8 | import orjson 9 | 10 | from dspy.utils.saving import get_dependency_versions 11 | 12 | # NOTE: Note: It's important (temporary decision) to maintain named_parameters that's different in behavior from 13 | # named_sub_modules for the time being. 14 | 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class BaseModule: 20 | def __init__(self): 21 | pass 22 | 23 | def named_parameters(self): 24 | """ 25 | Unlike PyTorch, handles (non-recursive) lists of parameters too. 
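        For example (illustrative), a list attribute such as `self.predictors = [dspy.Predict("q -> a")]`
        is traversed too, yielding the indexed name `predictors[0]`.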
 26 |         """
 27 | 
 28 |         import dspy
 29 |         from dspy.predict.parameter import Parameter
 30 | 
 31 |         visited = set()
 32 |         named_parameters = []
 33 | 
 34 |         def add_parameter(param_name, param_value):
 35 |             if isinstance(param_value, Parameter):
 36 |                 if id(param_value) not in visited:
 37 |                     visited.add(id(param_value))
 38 |                     named_parameters.append((param_name, param_value))
 39 | 
 40 |             elif isinstance(param_value, dspy.Module):
 41 |                 # When a sub-module is pre-compiled, keep it frozen.
 42 |                 if not getattr(param_value, "_compiled", False):
 43 |                     for sub_name, param in param_value.named_parameters():
 44 |                         add_parameter(f"{param_name}.{sub_name}", param)
 45 | 
 46 |         if isinstance(self, Parameter):
 47 |             add_parameter("self", self)
 48 | 
 49 |         for name, value in self.__dict__.items():
 50 |             if isinstance(value, Parameter):
 51 |                 add_parameter(name, value)
 52 | 
 53 |             elif isinstance(value, dspy.Module):
 54 |                 # When a sub-module is pre-compiled, keep it frozen.
 55 |                 if not getattr(value, "_compiled", False):
 56 |                     for sub_name, param in value.named_parameters():
 57 |                         add_parameter(f"{name}.{sub_name}", param)
 58 | 
 59 |             elif isinstance(value, (list, tuple)):
 60 |                 for idx, item in enumerate(value):
 61 |                     add_parameter(f"{name}[{idx}]", item)
 62 | 
 63 |             elif isinstance(value, dict):
 64 |                 for key, item in value.items():
 65 |                     add_parameter(f"{name}['{key}']", item)
 66 | 
 67 |         return named_parameters
 68 | 
 69 |     def named_sub_modules(self, type_=None, skip_compiled=False) -> Generator[tuple[str, "BaseModule"], None, None]:
 70 |         """Find all sub-modules in the module, as well as their names.
 71 | 
 72 |         Say `self.children[4]['key'].sub_module` is a sub-module. Then the name will be
 73 |         `children[4]['key'].sub_module`. But if the sub-module is accessible at different
 74 |         paths, only one of the paths will be returned.
 75 |         """
 76 |         if type_ is None:
 77 |             type_ = BaseModule
 78 | 
 79 |         queue = deque([("self", self)])
 80 |         seen = {id(self)}
 81 | 
 82 |         def add_to_queue(name, item):
 83 |             if id(item) not in seen:
 84 |                 seen.add(id(item))
 85 |                 queue.append((name, item))
 86 | 
 87 |         while queue:
 88 |             name, item = queue.popleft()
 89 | 
 90 |             if isinstance(item, type_):
 91 |                 yield name, item
 92 | 
 93 |             if isinstance(item, BaseModule):
 94 |                 if skip_compiled and getattr(item, "_compiled", False):
 95 |                     continue
 96 |                 for sub_name, sub_item in item.__dict__.items():
 97 |                     add_to_queue(f"{name}.{sub_name}", sub_item)
 98 | 
 99 |             elif isinstance(item, (list, tuple)):
100 |                 for i, sub_item in enumerate(item):
101 |                     add_to_queue(f"{name}[{i}]", sub_item)
102 | 
103 |             elif isinstance(item, dict):
104 |                 for key, sub_item in item.items():
105 |                     add_to_queue(f"{name}[{key}]", sub_item)
106 | 
107 |     def parameters(self):
108 |         return [param for _, param in self.named_parameters()]
109 | 
110 |     def deepcopy(self):
111 |         """Deep copy the module.
112 | 
113 |         This is a tweak to the default python deepcopy that only deep copies `self.parameters()`, and for other
114 |         attributes, we just do the shallow copy.
115 |         """
116 |         try:
117 |             # If the instance itself is copyable, we can just deep copy it.
118 |             # Otherwise we will have to create a new instance and copy over the attributes one by one.
119 |             return copy.deepcopy(self)
120 |         except Exception:
121 |             pass
122 | 
123 |         # Create an empty instance.
124 |         new_instance = self.__class__.__new__(self.__class__)
125 |         # Set attributes of the copied instance.
126 | for attr, value in self.__dict__.items(): 127 | if isinstance(value, BaseModule): 128 | setattr(new_instance, attr, value.deepcopy()) 129 | else: 130 | try: 131 | # Try to deep copy the attribute 132 | setattr(new_instance, attr, copy.deepcopy(value)) 133 | except Exception: 134 | logging.warning( 135 | f"Failed to deep copy attribute '{attr}' of {self.__class__.__name__}, " 136 | "falling back to shallow copy or reference copy." 137 | ) 138 | try: 139 | # Fallback to shallow copy if deep copy fails 140 | setattr(new_instance, attr, copy.copy(value)) 141 | except Exception: 142 | # If even the shallow copy fails, we just copy over the reference. 143 | setattr(new_instance, attr, value) 144 | 145 | return new_instance 146 | 147 | def reset_copy(self): 148 | """Deep copy the module and reset all parameters.""" 149 | new_instance = self.deepcopy() 150 | 151 | for param in new_instance.parameters(): 152 | param.reset() 153 | 154 | return new_instance 155 | 156 | def dump_state(self, json_mode=True): 157 | return {name: param.dump_state(json_mode=json_mode) for name, param in self.named_parameters()} 158 | 159 | def load_state(self, state): 160 | for name, param in self.named_parameters(): 161 | param.load_state(state[name]) 162 | 163 | def save(self, path, save_program=False, modules_to_serialize=None): 164 | """Save the module. 165 | 166 | Save the module to a directory or a file. There are two modes: 167 | - `save_program=False`: Save only the state of the module to a json or pickle file, based on the value of 168 | the file extension. 169 | - `save_program=True`: Save the whole module to a directory via cloudpickle, which contains both the state and 170 | architecture of the model. 171 | 172 | If `save_program=True` and `modules_to_serialize` are provided, it will register those modules for serialization 173 | with cloudpickle's `register_pickle_by_value`. This causes cloudpickle to serialize the module by value rather 174 | than by reference, ensuring the module is fully preserved along with the saved program. This is useful 175 | when you have custom modules that need to be serialized alongside your program. If None, then no modules 176 | will be registered for serialization. 177 | 178 | We also save the dependency versions, so that the loaded model can check if there is a version mismatch on 179 | critical dependencies or DSPy version. 180 | 181 | Args: 182 | path (str): Path to the saved state file, which should be a .json or .pkl file when `save_program=False`, 183 | and a directory when `save_program=True`. 184 | save_program (bool): If True, save the whole module to a directory via cloudpickle, otherwise only save 185 | the state. 186 | modules_to_serialize (list): A list of modules to serialize with cloudpickle's `register_pickle_by_value`. 187 | If None, then no modules will be registered for serialization. 
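        Example (a minimal sketch; assumes `program` is any DSPy module, e.g. `dspy.ChainOfThought("q -> a")`):

            program.save("state.json")                      # state-only, JSON (chosen by file extension)
            program.save("state.pkl")                       # state-only, pickle
            program.save("program_dir", save_program=True)  # whole program plus state, via cloudpickle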
188 | 189 | """ 190 | metadata = {} 191 | metadata["dependency_versions"] = get_dependency_versions() 192 | path = Path(path) 193 | 194 | if save_program: 195 | if path.suffix: 196 | raise ValueError( 197 | f"`path` must point to a directory without a suffix when `save_program=True`, but received: {path}" 198 | ) 199 | if path.exists() and not path.is_dir(): 200 | raise NotADirectoryError(f"The path '{path}' exists but is not a directory.") 201 | 202 | if not path.exists(): 203 | # Create the directory (and any parent directories) 204 | path.mkdir(parents=True) 205 | 206 | try: 207 | modules_to_serialize = modules_to_serialize or [] 208 | for module in modules_to_serialize: 209 | cloudpickle.register_pickle_by_value(module) 210 | 211 | with open(path / "program.pkl", "wb") as f: 212 | cloudpickle.dump(self, f) 213 | except Exception as e: 214 | raise RuntimeError( 215 | f"Saving failed with error: {e}. Please remove the non-picklable attributes from your DSPy program, " 216 | "or consider using state-only saving by setting `save_program=False`." 217 | ) 218 | with open(path / "metadata.json", "wb") as f: 219 | f.write(orjson.dumps(metadata, option=orjson.OPT_INDENT_2 | orjson.OPT_APPEND_NEWLINE)) 220 | 221 | return 222 | 223 | if path.suffix == ".json": 224 | state = self.dump_state() 225 | state["metadata"] = metadata 226 | try: 227 | with open(path, "wb") as f: 228 | f.write(orjson.dumps(state, option=orjson.OPT_INDENT_2 | orjson.OPT_APPEND_NEWLINE)) 229 | except Exception as e: 230 | raise RuntimeError( 231 | f"Failed to save state to {path} with error: {e}. Your DSPy program may contain non " 232 | "json-serializable objects, please consider saving the state in .pkl by using `path` ending " 233 | "with `.pkl`, or saving the whole program by setting `save_program=True`." 234 | ) 235 | elif path.suffix == ".pkl": 236 | state = self.dump_state(json_mode=False) 237 | state["metadata"] = metadata 238 | with open(path, "wb") as f: 239 | cloudpickle.dump(state, f) 240 | else: 241 | raise ValueError(f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}") 242 | 243 | def load(self, path): 244 | """Load the saved module. You may also want to check out dspy.load, if you want to 245 | load an entire program, not just the state for an existing program. 246 | 247 | Args: 248 | path (str): Path to the saved state file, which should be a .json or a .pkl file 249 | """ 250 | path = Path(path) 251 | 252 | if path.suffix == ".json": 253 | with open(path, "rb") as f: 254 | state = orjson.loads(f.read()) 255 | elif path.suffix == ".pkl": 256 | with open(path, "rb") as f: 257 | state = cloudpickle.load(f) 258 | else: 259 | raise ValueError(f"`path` must end with `.json` or `.pkl`, but received: {path}") 260 | 261 | dependency_versions = get_dependency_versions() 262 | saved_dependency_versions = state["metadata"]["dependency_versions"] 263 | for key, saved_version in saved_dependency_versions.items(): 264 | if dependency_versions[key] != saved_version: 265 | logger.warning( 266 | f"There is a mismatch of {key} version between saved model and current environment. " 267 | f"You saved with `{key}=={saved_version}`, but now you have " 268 | f"`{key}=={dependency_versions[key]}`. This might cause errors or performance downgrade " 269 | "on the loaded model, please consider loading the model in the same environment as the " 270 | "saving environment." 
271 | ) 272 | self.load_state(state) 273 | ``` -------------------------------------------------------------------------------- /tests/clients/test_cache.py: -------------------------------------------------------------------------------- ```python 1 | import os 2 | from dataclasses import dataclass 3 | from unittest.mock import patch 4 | 5 | import pydantic 6 | import pytest 7 | from cachetools import LRUCache 8 | from diskcache import FanoutCache 9 | 10 | from dspy.clients.cache import Cache 11 | 12 | 13 | @dataclass 14 | class DummyResponse: 15 | message: str 16 | usage: dict 17 | 18 | 19 | @pytest.fixture 20 | def cache_config(tmp_path): 21 | """Default cache configuration.""" 22 | return { 23 | "enable_disk_cache": True, 24 | "enable_memory_cache": True, 25 | "disk_cache_dir": str(tmp_path), 26 | "disk_size_limit_bytes": 1024 * 1024, # 1MB 27 | "memory_max_entries": 100, 28 | } 29 | 30 | 31 | @pytest.fixture 32 | def cache(cache_config): 33 | """Create a cache instance with the default configuration.""" 34 | return Cache(**cache_config) 35 | 36 | 37 | def test_initialization(tmp_path): 38 | """Test different cache initialization configurations.""" 39 | # Test memory-only cache 40 | memory_cache = Cache( 41 | enable_disk_cache=False, 42 | enable_memory_cache=True, 43 | disk_cache_dir="", 44 | disk_size_limit_bytes=0, 45 | memory_max_entries=50, 46 | ) 47 | assert isinstance(memory_cache.memory_cache, LRUCache) 48 | assert memory_cache.memory_cache.maxsize == 50 49 | assert memory_cache.disk_cache == {} 50 | 51 | # Test disk-only cache 52 | disk_cache = Cache( 53 | enable_disk_cache=True, 54 | enable_memory_cache=False, 55 | disk_cache_dir=str(tmp_path), 56 | disk_size_limit_bytes=1024, 57 | memory_max_entries=0, 58 | ) 59 | assert isinstance(disk_cache.disk_cache, FanoutCache) 60 | assert disk_cache.memory_cache == {} 61 | 62 | # Test disabled cache 63 | disabled_cache = Cache( 64 | enable_disk_cache=False, 65 | enable_memory_cache=False, 66 | disk_cache_dir="", 67 | disk_size_limit_bytes=0, 68 | memory_max_entries=0, 69 | ) 70 | assert disabled_cache.memory_cache == {} 71 | assert disabled_cache.disk_cache == {} 72 | 73 | 74 | def test_cache_key_generation(cache): 75 | """Test cache key generation with different types of inputs.""" 76 | # Test with simple dictionary 77 | request = {"prompt": "Hello", "model": "openai/gpt-4o-mini", "temperature": 0.7} 78 | key = cache.cache_key(request) 79 | assert isinstance(key, str) 80 | assert len(key) == 64 # SHA-256 hash is 64 characters 81 | 82 | # Test with pydantic model 83 | class TestModel(pydantic.BaseModel): 84 | name: str 85 | value: int 86 | 87 | model = TestModel(name="test", value=42) 88 | request_with_model = {"data": model} 89 | key_with_model = cache.cache_key(request_with_model) 90 | assert isinstance(key_with_model, str) 91 | 92 | # Test with pydantic model class 93 | request_with_model_class = {"model_class": TestModel} 94 | key_with_model_class = cache.cache_key(request_with_model_class) 95 | assert isinstance(key_with_model_class, str) 96 | 97 | 98 | def test_put_and_get(cache): 99 | """Test putting and getting from cache.""" 100 | # Test putting and getting from memory cache 101 | request = {"prompt": "Hello", "model": "openai/gpt-4o-mini", "temperature": 0.7} 102 | 103 | value = DummyResponse(message="This is a test response", usage={"prompt_tokens": 10, "completion_tokens": 20}) 104 | 105 | cache.put(request, value) 106 | result = cache.get(request) 107 | 108 | assert result.message == value.message 109 | assert result.usage 
== {} 110 | 111 | # Test with disk cache 112 | # First, clear memory cache to ensure we're using disk cache 113 | cache.reset_memory_cache() 114 | 115 | # Get from disk cache 116 | result_from_disk = cache.get(request) 117 | assert result_from_disk.message == value.message 118 | assert result_from_disk.usage == {} 119 | 120 | # Verify it was also added back to memory cache 121 | assert cache.cache_key(request) in cache.memory_cache 122 | 123 | 124 | def test_cache_miss(cache): 125 | """Test getting a non-existent key.""" 126 | request = {"prompt": "Non-existent", "model": "gpt-4"} 127 | result = cache.get(request) 128 | assert result is None 129 | 130 | 131 | def test_cache_key_error_handling(cache): 132 | """Test error handling for unserializable objects.""" 133 | 134 | # Test with a request that can't be serialized to JSON 135 | class UnserializableObject: 136 | pass 137 | 138 | request = {"data": UnserializableObject()} 139 | 140 | # Should not raise an exception 141 | result = cache.get(request) 142 | assert result is None 143 | 144 | # Should not raise an exception 145 | cache.put(request, "value") 146 | 147 | 148 | def test_reset_memory_cache(cache): 149 | """Test resetting memory cache.""" 150 | # Add some items to the memory cache 151 | requests = [{"prompt": f"Hello {i}", "model": "openai/gpt-4o-mini"} for i in range(5)] 152 | for i, req in enumerate(requests): 153 | cache.put(req, f"Response {i}") 154 | 155 | # Verify items are in memory cache 156 | for req in requests: 157 | key = cache.cache_key(req) 158 | assert key in cache.memory_cache 159 | 160 | # Reset memory cache 161 | cache.reset_memory_cache() 162 | 163 | # Verify memory cache is empty 164 | assert len(cache.memory_cache) == 0 165 | 166 | # But disk cache still has the items 167 | for req in requests: 168 | result = cache.get(req) 169 | assert result is not None 170 | 171 | 172 | def test_save_and_load_memory_cache(cache, tmp_path): 173 | """Test saving and loading memory cache.""" 174 | # Add some items to the memory cache 175 | requests = [{"prompt": f"Hello {i}", "model": "openai/gpt-4o-mini"} for i in range(5)] 176 | for i, req in enumerate(requests): 177 | cache.put(req, f"Response {i}") 178 | 179 | # Save memory cache to a temporary file 180 | temp_cache_file = tmp_path / "memory_cache.pkl" 181 | cache.save_memory_cache(str(temp_cache_file)) 182 | 183 | # Create a new cache instance with disk cache disabled 184 | new_cache = Cache( 185 | enable_memory_cache=True, 186 | enable_disk_cache=False, 187 | disk_cache_dir=tmp_path / "disk_cache", 188 | disk_size_limit_bytes=0, 189 | memory_max_entries=100, 190 | ) 191 | 192 | # Load the memory cache 193 | new_cache.load_memory_cache(str(temp_cache_file)) 194 | 195 | # Verify items are in the new memory cache 196 | for req in requests: 197 | result = new_cache.get(req) 198 | assert result is not None 199 | assert result == f"Response {requests.index(req)}" 200 | 201 | 202 | def test_request_cache_decorator(cache): 203 | """Test the lm_cache decorator.""" 204 | from dspy.clients.cache import request_cache 205 | 206 | # Mock the dspy.cache attribute 207 | with patch("dspy.cache", cache): 208 | # Define a test function 209 | @request_cache() 210 | def test_function(prompt, model): 211 | return f"Response for {prompt} with {model}" 212 | 213 | # First call should compute the result 214 | result1 = test_function(prompt="Hello", model="openai/gpt-4o-mini") 215 | assert result1 == "Response for Hello with openai/gpt-4o-mini" 216 | 217 | # Second call with same arguments should 
use cache 218 | with patch.object(cache, "get") as mock_get: 219 | mock_get.return_value = "Cached response" 220 | result2 = test_function(prompt="Hello", model="openai/gpt-4o-mini") 221 | assert result2 == "Cached response" 222 | mock_get.assert_called_once() 223 | 224 | # Call with different arguments should compute again 225 | result3 = test_function(prompt="Different", model="openai/gpt-4o-mini") 226 | assert result3 == "Response for Different with openai/gpt-4o-mini" 227 | 228 | 229 | def test_request_cache_decorator_with_ignored_args_for_cache_key(cache): 230 | """Test the request_cache decorator with ignored_args_for_cache_key.""" 231 | from dspy.clients.cache import request_cache 232 | 233 | # Mock the dspy.cache attribute 234 | with patch("dspy.cache", cache): 235 | # Define a test function 236 | @request_cache(ignored_args_for_cache_key=["model"]) 237 | def test_function1(prompt, model): 238 | return f"Response for {prompt} with {model}" 239 | 240 | @request_cache() 241 | def test_function2(prompt, model): 242 | return f"Response for {prompt} with {model}" 243 | 244 | # First call should compute the result 245 | result1 = test_function1(prompt="Hello", model="openai/gpt-4o-mini") 246 | result2 = test_function1(prompt="Hello", model="openai/gpt-4o") 247 | 248 | # Because model arg is ignored, the second call should return the same result as the first 249 | assert result1 == result2 250 | 251 | result3 = test_function2(prompt="Hello", model="openai/gpt-4o-mini") 252 | result4 = test_function2(prompt="Hello", model="openai/gpt-4o") 253 | 254 | # Because model arg is not ignored, the second call should return a different result 255 | assert result3 != result4 256 | 257 | 258 | @pytest.mark.asyncio 259 | async def test_request_cache_decorator_async(cache): 260 | """Test the request_cache decorator with async functions.""" 261 | from dspy.clients.cache import request_cache 262 | 263 | # Mock the dspy.cache attribute 264 | with patch("dspy.cache", cache): 265 | # Define a test function 266 | @request_cache() 267 | async def test_function(prompt, model): 268 | return f"Response for {prompt} with {model}" 269 | 270 | # First call should compute the result 271 | result1 = await test_function(prompt="Hello", model="openai/gpt-4o-mini") 272 | assert result1 == "Response for Hello with openai/gpt-4o-mini" 273 | 274 | # Second call with same arguments should use cache 275 | with patch.object(cache, "get") as mock_get: 276 | mock_get.return_value = "Cached response" 277 | result2 = await test_function(prompt="Hello", model="openai/gpt-4o-mini") 278 | assert result2 == "Cached response" 279 | mock_get.assert_called_once() 280 | 281 | # Call with different arguments should compute again 282 | result3 = await test_function(prompt="Different", model="openai/gpt-4o-mini") 283 | assert result3 == "Response for Different with openai/gpt-4o-mini" 284 | 285 | 286 | def test_cache_consistency_with_lm_call_modifies_the_request(cache): 287 | """Test that the cache is consistent with the LM call that modifies the request.""" 288 | from dspy.clients.cache import request_cache 289 | 290 | # Mock the dspy.cache attribute 291 | with patch("dspy.cache", cache): 292 | # Define a test function 293 | @request_cache() 294 | def test_function(**kwargs): 295 | del kwargs["field_to_delete"] 296 | return kwargs 297 | 298 | # First call should compute the result 299 | test_function(field_to_delete="delete", field_to_keep="keep") 300 | 301 | # The cache key should use the original request, not the modified one 302 | 
assert ( 303 | cache.get( 304 | { 305 | "field_to_keep": "keep", 306 | "_fn_identifier": f"{test_function.__module__}.{test_function.__qualname__}", 307 | } 308 | ) 309 | is None 310 | ) 311 | assert ( 312 | cache.get( 313 | { 314 | "field_to_keep": "keep", 315 | "field_to_delete": "delete", 316 | "_fn_identifier": f"{test_function.__module__}.{test_function.__qualname__}", 317 | } 318 | ) 319 | is not None 320 | ) 321 | 322 | 323 | def test_cache_fallback_on_restricted_environment(): 324 | """Test that DSPy gracefully falls back to memory-only cache when disk cache fails.""" 325 | old_env = os.environ.get("DSPY_CACHEDIR") 326 | try: 327 | # Set an invalid cache directory that can't be created 328 | os.environ["DSPY_CACHEDIR"] = "/dev/null/invalid_path" 329 | 330 | import dspy 331 | from dspy.clients import _get_dspy_cache 332 | 333 | dspy.cache = _get_dspy_cache() 334 | 335 | # Cache should work with memory-only fallback despite invalid disk path 336 | test_request = {"model": "test", "prompt": "hello"} 337 | dspy.cache.put(test_request, "fallback_result") 338 | result = dspy.cache.get(test_request) 339 | 340 | assert result == "fallback_result", "Memory cache fallback should work" 341 | 342 | finally: 343 | if old_env is None: 344 | os.environ.pop("DSPY_CACHEDIR", None) 345 | else: 346 | os.environ["DSPY_CACHEDIR"] = old_env 347 | ``` -------------------------------------------------------------------------------- /docs/docs/community/use-cases.md: -------------------------------------------------------------------------------- ```markdown 1 | # Use Cases 2 | 3 | We often get questions like “How are people using DSPy in practice?”, both in production and for research. This list was created to collect a few pointers and to encourage others in the community to add their own work below. 4 | 5 | This list is continuously growing. We regularly add new use cases and welcome community contributions. If you would like to add your product or research to this list, please submit a PR. 6 | 7 | ## A Few Company Use Cases 8 | 9 | | **Name** | **Use Cases** | 10 | |---|---| 11 | | **[JetBlue](https://www.jetblue.com/)** | Multiple chatbot use cases. [Blog](https://www.databricks.com/blog/optimizing-databricks-llm-pipelines-dspy) | 12 | | **[Replit](https://replit.com/)** | Synthesize diffs using code LLMs using a DSPy pipeline. [Blog](https://blog.replit.com/code-repair) | 13 | | **[Databricks](https://www.databricks.com/)** | Research, products, and customer solutions around LM Judges, RAG, classification, and other applications. [Blog](https://www.databricks.com/blog/dspy-databricks), [Blog II](https://www.databricks.com/customers/ddi) | 14 | | **[Sephora](https://www.sephora.com/)** | Undisclosed agent usecases; perspectives shared in [DAIS Session](https://www.youtube.com/watch?v=D2HurSldDkE). | 15 | | **[Zoro UK](https://www.zoro.co.uk/)** | E-commerce applications around structured shopping. [Portkey Session](https://www.youtube.com/watch?v=_vGKSc1tekE) | 16 | | **[VMware](https://www.vmware.com/)** | RAG and other prompt optimization applications. [Interview in The Register.](https://www.theregister.com/2024/02/22/prompt_engineering_ai_models/) [Business Insider.](https://www.businessinsider.com/chaptgpt-large-language-model-ai-prompt-engineering-automated-optimizer-2024-3) | 17 | | **[Haize Labs](https://www.haizelabs.com/)** | Automated red-teaming for LLMs. [Blog](https://blog.haizelabs.com/posts/dspy/) | 18 | | **[Plastic Labs](https://www.plasticlabs.ai/)** | R&D pipelines for Honcho. 
[Blog](https://blog.plasticlabs.ai/blog/User-State-is-State-of-the-Art) | 19 | | **[PingCAP](https://pingcap.com/)** | Building a knowledge graph. [Article](https://www.pingcap.com/article/building-a-graphrag-from-wikipedia-page-using-dspy-openai-and-tidb-vector-database/) | 20 | | **[Salomatic](https://langtrace.ai/blog/case-study-how-salomatic-used-langtrace-to-build-a-reliable-medical-report-generation-system)** | Enriching medical reports using DSPy. [Blog](https://langtrace.ai/blog/case-study-how-salomatic-used-langtrace-to-build-a-reliable-medical-report-generation-system) | 21 | | **[Truelaw](https://www.youtube.com/watch?v=O0F3RAWZNfM)** | How Truelaw builds bespoke LLM pipelines for law firms using DSPy. [Podcast](https://www.youtube.com/watch?v=O0F3RAWZNfM) | 22 | | **[STChealth](https://stchealth.com/)** | Using DSPy for entity resolution including human-readable rationale for decisions. | 23 | | **[Moody's](https://www.moodys.com/)** | Leveraging DSPy to optimize RAG systems, LLM-as-a-Judge, and agentic systems for financial workflows. | 24 | | **[Normal Computing](https://www.normalcomputing.com/)** | Translate specs from chip companies from English to intermediate formal languages | 25 | | **[Procure.FYI](https://www.procure.fyi/)** | Process messy, publicly available technology spending and pricing data via DSPy. | 26 | | **[RadiantLogic](https://www.radiantlogic.com/)** | AI Data Assistant. DSPy is used for the agent that routes the query, the context extraction module, the text-to-sql conversion engine, and the table summarization module. | 27 | | **[Raia](https://raiahealth.com/)** | Using DSPy for AI-powered Personal Healthcare Agents. | 28 | | **[Hyperlint](https://hyperlint.com)** | Uses DSPy to generate technical documentation. DSPy helps to fetch relevant information and synthesize that into tutorials. | 29 | | **[Starops](https://staropshq.com/) & [Saya](https://heysaya.ai/)** | Building research documents given a user's corpus. Generate prompts to create more articles from example articles. | 30 | | **[Tessel AI](https://tesselai.com/)** | Enhancing human-machine interaction with data use cases. | 31 | | **[Dicer.ai](https://dicer.ai/)** | Uses DSPy for marketing AI to get the most from their paid ads. | 32 | | **[Howie](https://howie.ai)** | Using DSPy to automate meeting scheduling through email. | 33 | | **[Isoform.ai](https://isoform.ai)** | Building custom integrations using DSPy. | 34 | | **[Trampoline AI](https://trampoline.ai)** | Uses DSPy to power their data-augmentation and LM pipelines. | 35 | | **[Pretrain](https://pretrain.com)** | Uses DSPy to automatically optimize AI performance towards user-defined tasks based on uploaded examples. | 36 | | **[Spindle AI](https://spindle.ai)** | Turns natural-language constrained optimization problems into solvable mathematical programs whose candidate solutions are scenarios. | 37 | | **[Infinitus](https://www.infinitus.ai/product/ai-agents/)** | Leverages DSPy to build and optimize healthcare AI agents | 38 | 39 | This list represents companies that have publicly shared their use cases or have provided permission to be included. It reflects a selection of the many industry applications of DSPy currently in production. 40 | 41 | 42 | ## A Few Papers Using DSPy 43 | 44 | | **Name** | **Description** | 45 | |---|---| 46 | | **[STORM](https://arxiv.org/abs/2402.14207)** | Writing Wikipedia-like Articles From Scratch. 
| 47 | | **[PATH](https://arxiv.org/abs/2406.11706)** | Prompts as Auto-Optimized Training Hyperparameters: Training Best-in-Class IR Models from Scratch with 10 Gold Labels | 48 | | **[WangLab @ MEDIQA](https://arxiv.org/abs/2404.14544)** | UofT's winning system at MEDIQA, outperforms the next best system by 20 points | 49 | | **[UMD's Suicide Detection System](https://arxiv.org/abs/2406.06608)** | Outperforms 20-hour expert human prompt engineering by 40% | 50 | | **[IReRa](https://arxiv.org/abs/2401.12178)** | Infer-Retrieve-Rank: Extreme Classification with > 10,000 Labels | 51 | | **[Unreasonably Effective Eccentric Prompts](https://arxiv.org/abs/2402.10949v2)** | General Prompt Optimization | 52 | | **[Palimpzest](https://arxiv.org/abs/2405.14696)** | A Declarative System for Optimizing AI Workloads | 53 | | **[AI Agents that Matter](https://arxiv.org/abs/2407.01502v1)** | Agent Efficiency Optimization | 54 | | **[EDEN](https://arxiv.org/abs/2406.17982v1)** | Empathetic Dialogues for English Learning: Uses adaptive empathetic feedback to improve student grit | 55 | | **[ECG-Chat](https://arxiv.org/pdf/2408.08849)** | Uses DSPy with GraphRAG for medical report generation | 56 | | **[DSPy Assertions](https://arxiv.org/abs/2312.13382)** | Various applications of imposing hard and soft constraints on LM outputs | 57 | | **[DSPy Guardrails](https://boxiyu.github.io/assets/pdf/DSPy_Guardrails.pdf)** | Reduce the attack success rate of CodeAttack, decreasing from 75% to 5% | 58 | | **[Co-STORM](https://arxiv.org/pdf/2408.15232)** | Collaborative STORM: Generate Wikipedia-like articles through collaborative discourse among users and multiple LM agents | 59 | | **[MedVAL](https://arxiv.org/abs/2507.03152)** | Expert-level validation of AI-generated medical text with scalable language models | 60 | | **[Neural @ ArchEHR-QA 2025](https://aclanthology.org/2025.bionlp-share.13.pdf)** | Runner up method at 2025 BioNLP Shared Task Workshop 61 | 62 | This list is regularly updated with new research publications using DSPy. 
63 | 64 | ## A Few Repositories (or other OSS examples) using DSPy 65 | 66 | | **Name** | **Description/Link** | 67 | |---|---| 68 | | **Stanford CS 224U Homework** | [Github](https://github.com/cgpotts/cs224u/blob/main/hw_openqa.ipynb) | 69 | | **STORM Report Generation (10,000 GitHub stars)** | [Github](https://github.com/stanford-oval/storm) | 70 | | **DSPy Redteaming** | [Github](https://github.com/haizelabs/dspy-redteam) | 71 | | **DSPy Theory of Mind** | [Github](https://github.com/plastic-labs/dspy-opentom) | 72 | | **Indic cross-lingual Natural Language Inference** | [Github](https://github.com/saifulhaq95/DSPy-Indic/blob/main/indicxlni.ipynb) | 73 | | **Optimizing LM for Text2SQL using DSPy** | [Github](https://github.com/jjovalle99/DSPy-Text2SQL) | 74 | | **DSPy PII Masking Demo by Eric Ness** | [Colab](https://colab.research.google.com/drive/1KZR1sGTp_RLWUJPAiK1FKPKI-Qn9neUm?usp=sharing) | 75 | | **DSPy on BIG-Bench Hard Example** | [Github](https://drchrislevy.github.io/posts/dspy/dspy.html) | 76 | | **Building a chess playing agent using DSPy** | [Github](https://medium.com/thoughts-on-machine-learning/building-a-chess-playing-agent-using-dspy-9b87c868f71e) | 77 | | **Ittia Research Fact Checking** | [Github](https://github.com/ittia-research/check) | 78 | | **Strategic Debate via Tree-of-Thought** | [Github](https://github.com/zbambergerNLP/strategic-debate-tot) | 79 | | **Sanskrit to English Translation App**| [Github](https://github.com/ganarajpr/sanskrit-translator-dspy) | 80 | | **DSPy for extracting features from PDFs on arXiv**| [Github](https://github.com/S1M0N38/dspy-arxiv) | 81 | | **DSPygen: DSPy in Ruby on Rails**| [Github](https://github.com/seanchatmangpt/dspygen) | 82 | | **DSPy Inspector**| [Github](https://github.com/Neoxelox/dspy-inspector) | 83 | | **DSPy with FastAPI**| [Github](https://github.com/diicellman/dspy-rag-fastapi) | 84 | | **DSPy for Indian Languages**| [Github](https://github.com/saifulhaq95/DSPy-Indic) | 85 | | **Hurricane: Blog Posts with Generative Feedback Loops!**| [Github](https://github.com/weaviate-tutorials/Hurricane) | 86 | | **RAG example using DSPy, Gradio, FastAPI, and Ollama**| [Github](https://github.com/diicellman/dspy-gradio-rag) | 87 | | **Synthetic Data Generation**| [Github](https://colab.research.google.com/drive/1CweVOu0qhTC0yOfW5QkLDRIKuAuWJKEr?usp=sharing) | 88 | | **Self Discover**| [Github](https://colab.research.google.com/drive/1GkAQKmw1XQgg5UNzzy8OncRe79V6pADB?usp=sharing) | 89 | | **MedVAL**| [Github](https://github.com/StanfordMIMI/MedVAL) | 90 | 91 | This list showcases some of the open-source projects and repositories using DSPy, with many more examples available in the community. 
92 | 93 | ## A Few Providers, Integrations, and related Blog Releases 94 | 95 | | **Name** | **Link** | 96 | |---|---| 97 | | **Databricks** | [Link](https://www.databricks.com/blog/dspy-databricks) | 98 | | **Zenbase** | [Link](https://zenbase.ai/) | 99 | | **LangWatch** | [Link](https://langwatch.ai/blog/introducing-dspy-visualizer) | 100 | | **Gradient** | [Link](https://gradient.ai/blog/achieving-gpt-4-level-performance-at-lower-cost-using-dspy) | 101 | | **Snowflake** | [Link](https://medium.com/snowflake/dspy-snowflake-140d6d947d73) | 102 | | **Langchain** | [Link](https://python.langchain.com/v0.2/docs/integrations/providers/dspy/) | 103 | | **Weaviate** | [Link](https://weaviate.io/blog/dspy-optimizers) | 104 | | **Qdrant** | [Link](https://qdrant.tech/documentation/frameworks/dspy/) | 105 | | **Weights & Biases Weave** | [Link](https://weave-docs.wandb.ai/guides/integrations/dspy/) | 106 | | **Milvus** | [Link](https://milvus.io/docs/integrate_with_dspy.md) | 107 | | **Neo4j** | [Link](https://neo4j.com/labs/genai-ecosystem/dspy/) | 108 | | **Lightning AI** | [Link](https://lightning.ai/lightning-ai/studios/dspy-programming-with-foundation-models) | 109 | | **Haystack** | [Link](https://towardsdatascience.com/automating-prompt-engineering-with-dspy-and-haystack-926a637a3f43) | 110 | | **Arize** | [Link](https://docs.arize.com/phoenix/tracing/integrations-tracing/dspy) | 111 | | **LlamaIndex** | [Link](https://github.com/stanfordnlp/dspy/blob/main/examples/llamaindex/dspy_llamaindex_rag.ipynb) | 112 | | **Langtrace** | [Link](https://docs.langtrace.ai/supported-integrations/llm-frameworks/dspy) | 113 | | **Langfuse** | [Link](https://langfuse.com/docs/integrations/dspy) | 114 | | **OpenLIT** | [Link](https://docs.openlit.io/latest/integrations/dspy) | 115 | | **Relevance AI** | [Link](https://relevanceai.com/blog/dspy-programming---not-prompting---language-models) | 116 | 117 | Credit: Some of these resources were originally compiled in the [Awesome DSPy](https://github.com/ganarajpr/awesome-dspy/tree/master) repo. 
118 | ``` -------------------------------------------------------------------------------- /dspy/streaming/streamify.py: -------------------------------------------------------------------------------- ```python 1 | import asyncio 2 | import contextvars 3 | import logging 4 | import threading 5 | from asyncio import iscoroutinefunction 6 | from queue import Queue 7 | from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator 8 | 9 | import litellm 10 | import orjson 11 | from anyio import create_memory_object_stream, create_task_group 12 | from anyio.streams.memory import MemoryObjectSendStream 13 | from litellm import ModelResponseStream 14 | 15 | from dspy.dsp.utils.settings import settings 16 | from dspy.primitives.prediction import Prediction 17 | from dspy.streaming.messages import StatusMessage, StatusMessageProvider, StatusStreamingCallback 18 | from dspy.streaming.streaming_listener import StreamListener, find_predictor_for_stream_listeners 19 | from dspy.utils.asyncify import asyncify 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | if TYPE_CHECKING: 24 | from dspy.primitives.module import Module 25 | 26 | 27 | def streamify( 28 | program: "Module", 29 | status_message_provider: StatusMessageProvider | None = None, 30 | stream_listeners: list[StreamListener] | None = None, 31 | include_final_prediction_in_output_stream: bool = True, 32 | is_async_program: bool = False, 33 | async_streaming: bool = True, 34 | ) -> Callable[[Any, Any], Awaitable[Any]]: 35 | """ 36 | Wrap a DSPy program so that it streams its outputs incrementally, rather than returning them 37 | all at once. It also provides status messages to the user to indicate the progress of the program, and users 38 | can implement their own status message provider to customize the status messages and what module to generate 39 | status messages for. 40 | 41 | Args: 42 | program: The DSPy program to wrap with streaming functionality. 43 | status_message_provider: A custom status message generator to use instead of the default one. Users can 44 | implement their own status message generator to customize the status messages and what module to generate 45 | status messages for. 46 | stream_listeners: A list of stream listeners to capture the streaming output of specific fields of sub predicts 47 | in the program. When provided, only the target fields in the target predict will be streamed to the user. 48 | include_final_prediction_in_output_stream: Whether to include the final prediction in the output stream, only 49 | useful when `stream_listeners` is provided. If `False`, the final prediction will not be included in the 50 | output stream. When the program hit cache, or no listeners captured anything, the final prediction will 51 | still be included in the output stream even if this is `False`. 52 | is_async_program: Whether the program is async. If `False`, the program will be wrapped with `asyncify`, 53 | otherwise the program will be called with `acall`. 54 | async_streaming: Whether to return an async generator or a sync generator. If `False`, the streaming will be 55 | converted to a sync generator. 56 | 57 | Returns: 58 | A function that takes the same arguments as the original program, but returns an async 59 | generator that yields the program's outputs incrementally. 
60 | 61 | Example: 62 | 63 | ```python 64 | import asyncio 65 | import dspy 66 | 67 | dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini")) 68 | # Create the program and wrap it with streaming functionality 69 | program = dspy.streamify(dspy.Predict("q->a")) 70 | 71 | # Use the program with streaming output 72 | async def use_streaming(): 73 | output = program(q="Why did a chicken cross the kitchen?") 74 | return_value = None 75 | async for value in output: 76 | if isinstance(value, dspy.Prediction): 77 | return_value = value 78 | else: 79 | print(value) 80 | return return_value 81 | 82 | output = asyncio.run(use_streaming()) 83 | print(output) 84 | ``` 85 | 86 | Example with custom status message provider: 87 | ```python 88 | import asyncio 89 | import dspy 90 | 91 | dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini")) 92 | 93 | class MyStatusMessageProvider(StatusMessageProvider): 94 | def module_start_status_message(self, instance, inputs): 95 | return f"Predicting..." 96 | 97 | def tool_end_status_message(self, outputs): 98 | return f"Tool calling finished with output: {outputs}!" 99 | 100 | # Create the program and wrap it with streaming functionality 101 | program = dspy.streamify(dspy.Predict("q->a"), status_message_provider=MyStatusMessageProvider()) 102 | 103 | # Use the program with streaming output 104 | async def use_streaming(): 105 | output = program(q="Why did a chicken cross the kitchen?") 106 | return_value = None 107 | async for value in output: 108 | if isinstance(value, dspy.Prediction): 109 | return_value = value 110 | else: 111 | print(value) 112 | return return_value 113 | 114 | output = asyncio.run(use_streaming()) 115 | print(output) 116 | ``` 117 | 118 | Example with stream listeners: 119 | 120 | ```python 121 | import asyncio 122 | import dspy 123 | 124 | dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False)) 125 | 126 | # Create the program and wrap it with streaming functionality 127 | predict = dspy.Predict("question->answer, reasoning") 128 | stream_listeners = [ 129 | dspy.streaming.StreamListener(signature_field_name="answer"), 130 | dspy.streaming.StreamListener(signature_field_name="reasoning"), 131 | ] 132 | stream_predict = dspy.streamify(predict, stream_listeners=stream_listeners) 133 | 134 | async def use_streaming(): 135 | output = stream_predict( 136 | question="why did a chicken cross the kitchen?", 137 | include_final_prediction_in_output_stream=False, 138 | ) 139 | return_value = None 140 | async for value in output: 141 | if isinstance(value, dspy.Prediction): 142 | return_value = value 143 | else: 144 | print(value) 145 | return return_value 146 | 147 | output = asyncio.run(use_streaming()) 148 | print(output) 149 | ``` 150 | 151 | You should see the streaming chunks (in the format of `dspy.streaming.StreamResponse`) in the console output. 
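    Example with synchronous streaming via `async_streaming=False` (a minimal sketch, reusing the `predict`
    and `stream_listeners` from the previous example):

        sync_stream_predict = dspy.streamify(predict, stream_listeners=stream_listeners, async_streaming=False)

        for value in sync_stream_predict(question="why did a chicken cross the kitchen?"):
            print(value)  # StreamResponse chunks first, then the final dspy.Prediction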
152 |     """
153 |     stream_listeners = stream_listeners or []
154 |     if len(stream_listeners) > 0:
155 |         predict_id_to_listener = find_predictor_for_stream_listeners(program, stream_listeners)
156 |     else:
157 |         predict_id_to_listener = {}
158 | 
159 |     if is_async_program:
160 |         program = program.acall
161 |     elif not iscoroutinefunction(program):
162 |         program = asyncify(program)
163 | 
164 |     callbacks = settings.callbacks
165 |     status_streaming_callback = StatusStreamingCallback(status_message_provider)
166 |     if not any(isinstance(c, StatusStreamingCallback) for c in callbacks):
167 |         callbacks.append(status_streaming_callback)
168 | 
169 |     async def generator(args, kwargs, stream: MemoryObjectSendStream):
170 |         with settings.context(send_stream=stream, callbacks=callbacks, stream_listeners=stream_listeners):
171 |             prediction = await program(*args, **kwargs)
172 | 
173 |         await stream.send(prediction)
174 | 
175 |     async def async_streamer(*args, **kwargs):
176 |         send_stream, receive_stream = create_memory_object_stream(16)
177 |         async with create_task_group() as tg, send_stream, receive_stream:
178 |             tg.start_soon(generator, args, kwargs, send_stream)
179 | 
180 |             async for value in receive_stream:
181 |                 if isinstance(value, ModelResponseStream):
182 |                     if len(predict_id_to_listener) == 0:
183 |                         # No listeners are configured, yield the chunk directly for backwards compatibility.
184 |                         yield value
185 |                     else:
186 |                         # We are receiving a chunk from the LM's response stream, delegate it to the listeners to
187 |                         # determine if we should yield a value to the user.
188 |                         for listener in predict_id_to_listener[value.predict_id]:
189 |                             # In some special cases such as Citation API, it is possible that multiple listeners
190 |                             # return values at the same time due to the chunk buffer of the listener.
191 |                             if output := listener.receive(value):
192 |                                 yield output
193 |                 elif isinstance(value, StatusMessage):
194 |                     yield value
195 |                 elif isinstance(value, Prediction):
196 |                     # Flush remaining buffered tokens before yielding the Prediction instance
197 |                     for listener in stream_listeners:
198 |                         if final_chunk := listener.finalize():
199 |                             yield final_chunk
200 | 
201 |                     if include_final_prediction_in_output_stream:
202 |                         yield value
203 |                     elif (
204 |                         len(stream_listeners) == 0
205 |                         or any(listener.cache_hit for listener in stream_listeners)
206 |                         or not any(listener.stream_start for listener in stream_listeners)
207 |                     ):
208 |                         yield value
209 |                     return
210 |                 else:
211 |                     # This wildcard case allows for customized streaming behavior.
212 |                     # It is useful when a user has a custom LM that returns stream chunks in a custom format.
213 |                     # We let those chunks pass through to the user to handle them as needed.
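                    # For example (purely hypothetical), a custom LM that yields plain dicts such as
                    # {"delta": "..."} would have each dict forwarded to the caller unchanged.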
214 | yield value 215 | 216 | if async_streaming: 217 | return async_streamer 218 | else: 219 | 220 | def sync_streamer(*args, **kwargs): 221 | output = async_streamer(*args, **kwargs) 222 | return apply_sync_streaming(output) 223 | 224 | return sync_streamer 225 | 226 | 227 | def apply_sync_streaming(async_generator: AsyncGenerator) -> Generator: 228 | """Convert the async streaming generator to a sync generator.""" 229 | queue = Queue() # Queue to hold items from the async generator 230 | stop_sentinel = object() # Sentinel to signal the generator is complete 231 | 232 | # To propagate prediction request ID context to the child thread 233 | context = contextvars.copy_context() 234 | 235 | def producer(): 236 | """Runs in a background thread to fetch items asynchronously.""" 237 | 238 | async def runner(): 239 | try: 240 | async for item in async_generator: 241 | queue.put(item) 242 | finally: 243 | # Signal completion 244 | queue.put(stop_sentinel) 245 | 246 | context.run(asyncio.run, runner()) 247 | 248 | # Start the producer in a background thread 249 | thread = threading.Thread(target=producer, daemon=True) 250 | thread.start() 251 | 252 | # Consume items from the queue 253 | while True: 254 | item = queue.get() # Block until an item is available 255 | if item is stop_sentinel: 256 | break 257 | yield item 258 | 259 | 260 | async def streaming_response(streamer: AsyncGenerator) -> AsyncGenerator: 261 | """ 262 | Convert a DSPy program output stream to an OpenAI-compatible output stream that can be 263 | used by a service as an API response to a streaming request. 264 | 265 | Args: 266 | streamer: An async generator that yields values from a DSPy program output stream. 267 | Returns: 268 | An async generator that yields OpenAI-compatible streaming response chunks. 269 | """ 270 | async for value in streamer: 271 | if isinstance(value, Prediction): 272 | data = {"prediction": dict(value.items(include_dspy=False))} 273 | yield f"data: {orjson.dumps(data).decode()}\n\n" 274 | elif isinstance(value, litellm.ModelResponseStream): 275 | data = {"chunk": value.json()} 276 | yield f"data: {orjson.dumps(data).decode()}\n\n" 277 | elif isinstance(value, str) and value.startswith("data:"): 278 | # The chunk value is an OpenAI-compatible streaming chunk value, 279 | # e.g. 
"data: {"finish_reason": "stop", "index": 0, "is_finished": True, ...}", 280 | # so yield it directly 281 | yield value 282 | else: 283 | raise ValueError(f"Unknown chunk value type: {value}") 284 | yield "data: [DONE]\n\n" 285 | ``` -------------------------------------------------------------------------------- /docs/docs/static/img/undraw_docusaurus_tree.svg: -------------------------------------------------------------------------------- ``` 1 | <svg xmlns="http://www.w3.org/2000/svg" width="1129" height="663" viewBox="0 0 1129 663"> 2 | <title>Focus on What Matters</title> 3 | <circle cx="321" cy="321" r="321" fill="#f2f2f2" /> 4 | <ellipse cx="559" cy="635.49998" rx="514" ry="27.50002" fill="#3f3d56" /> 5 | <ellipse cx="558" cy="627" rx="460" ry="22" opacity="0.2" /> 6 | <rect x="131" y="152.5" width="840" height="50" fill="#3f3d56" /> 7 | <path d="M166.5,727.3299A21.67009,21.67009,0,0,0,188.1701,749H984.8299A21.67009,21.67009,0,0,0,1006.5,727.3299V296h-840Z" transform="translate(-35.5 -118.5)" fill="#3f3d56" /> 8 | <path d="M984.8299,236H188.1701A21.67009,21.67009,0,0,0,166.5,257.6701V296h840V257.6701A21.67009,21.67009,0,0,0,984.8299,236Z" transform="translate(-35.5 -118.5)" fill="#3f3d56" /> 9 | <path d="M984.8299,236H188.1701A21.67009,21.67009,0,0,0,166.5,257.6701V296h840V257.6701A21.67009,21.67009,0,0,0,984.8299,236Z" transform="translate(-35.5 -118.5)" opacity="0.2" /> 10 | <circle cx="181" cy="147.5" r="13" fill="#3f3d56" /> 11 | <circle cx="217" cy="147.5" r="13" fill="#3f3d56" /> 12 | <circle cx="253" cy="147.5" r="13" fill="#3f3d56" /> 13 | <rect x="168" y="213.5" width="337" height="386" rx="5.33505" fill="#606060" /> 14 | <rect x="603" y="272.5" width="284" height="22" rx="5.47638" fill="#2e8555" /> 15 | <rect x="537" y="352.5" width="416" height="15" rx="5.47638" fill="#2e8555" /> 16 | <rect x="537" y="396.5" width="416" height="15" rx="5.47638" fill="#2e8555" /> 17 | <rect x="537" y="440.5" width="416" height="15" rx="5.47638" fill="#2e8555" /> 18 | <rect x="537" y="484.5" width="416" height="15" rx="5.47638" fill="#2e8555" /> 19 | <rect x="865" y="552.5" width="88" height="26" rx="7.02756" fill="#3ecc5f" /> 20 | <path d="M1088.60287,624.61594a30.11371,30.11371,0,0,0,3.98291-15.266c0-13.79652-8.54358-24.98081-19.08256-24.98081s-19.08256,11.18429-19.08256,24.98081a30.11411,30.11411,0,0,0,3.98291,15.266,31.248,31.248,0,0,0,0,30.53213,31.248,31.248,0,0,0,0,30.53208,31.248,31.248,0,0,0,0,30.53208,30.11408,30.11408,0,0,0-3.98291,15.266c0,13.79652,8.54353,24.98081,19.08256,24.98081s19.08256-11.18429,19.08256-24.98081a30.11368,30.11368,0,0,0-3.98291-15.266,31.248,31.248,0,0,0,0-30.53208,31.248,31.248,0,0,0,0-30.53208,31.248,31.248,0,0,0,0-30.53213Z" transform="translate(-35.5 -118.5)" fill="#3f3d56" /> 21 | <ellipse cx="1038.00321" cy="460.31783" rx="19.08256" ry="24.9808" fill="#3f3d56" /> 22 | <ellipse cx="1038.00321" cy="429.78574" rx="19.08256" ry="24.9808" fill="#3f3d56" /> 23 | <path d="M1144.93871,339.34489a91.61081,91.61081,0,0,0,7.10658-10.46092l-50.141-8.23491,54.22885.4033a91.566,91.566,0,0,0,1.74556-72.42605l-72.75449,37.74139,67.09658-49.32086a91.41255,91.41255,0,1,0-150.971,102.29805,91.45842,91.45842,0,0,0-10.42451,16.66946l65.0866,33.81447-69.40046-23.292a91.46011,91.46011,0,0,0,14.73837,85.83669,91.40575,91.40575,0,1,0,143.68892,0,91.41808,91.41808,0,0,0,0-113.02862Z" transform="translate(-35.5 -118.5)" fill="#3ecc5f" fill-rule="evenodd" /> 24 | <path 
d="M981.6885,395.8592a91.01343,91.01343,0,0,0,19.56129,56.51431,91.40575,91.40575,0,1,0,143.68892,0C1157.18982,436.82067,981.6885,385.60008,981.6885,395.8592Z" transform="translate(-35.5 -118.5)" opacity="0.1" /> 25 | <path d="M365.62,461.43628H477.094v45.12043H365.62Z" transform="translate(-35.5 -118.5)" fill="#fff" fill-rule="evenodd" /> 26 | <path d="M264.76252,608.74122a26.50931,26.50931,0,0,1-22.96231-13.27072,26.50976,26.50976,0,0,0,22.96231,39.81215H291.304V608.74122Z" transform="translate(-35.5 -118.5)" fill="#3ecc5f" fill-rule="evenodd" /> 27 | <path d="M384.17242,468.57061l92.92155-5.80726V449.49263a26.54091,26.54091,0,0,0-26.54143-26.54143H331.1161l-3.31768-5.74622a3.83043,3.83043,0,0,0-6.63536,0l-3.31768,5.74622-3.31767-5.74622a3.83043,3.83043,0,0,0-6.63536,0l-3.31768,5.74622L301.257,417.205a3.83043,3.83043,0,0,0-6.63536,0L291.304,422.9512c-.02919,0-.05573.004-.08625.004l-5.49674-5.49541a3.8293,3.8293,0,0,0-6.4071,1.71723l-1.81676,6.77338L270.607,424.1031a3.82993,3.82993,0,0,0-4.6912,4.69253l1.84463,6.89148-6.77072,1.81411a3.8315,3.8315,0,0,0-1.71988,6.40975l5.49673,5.49673c0,.02787-.004.05574-.004.08493l-5.74622,3.31768a3.83043,3.83043,0,0,0,0,6.63536l5.74621,3.31768L259.0163,466.081a3.83043,3.83043,0,0,0,0,6.63536l5.74622,3.31768-5.74622,3.31767a3.83043,3.83043,0,0,0,0,6.63536l5.74622,3.31768-5.74622,3.31768a3.83043,3.83043,0,0,0,0,6.63536l5.74622,3.31768-5.74622,3.31767a3.83043,3.83043,0,0,0,0,6.63536l5.74622,3.31768-5.74622,3.31768a3.83043,3.83043,0,0,0,0,6.63536l5.74622,3.31768-5.74622,3.31768a3.83042,3.83042,0,0,0,0,6.63535l5.74622,3.31768-5.74622,3.31768a3.83043,3.83043,0,0,0,0,6.63536l5.74622,3.31768L259.0163,558.976a3.83042,3.83042,0,0,0,0,6.63535l5.74622,3.31768-5.74622,3.31768a3.83043,3.83043,0,0,0,0,6.63536l5.74622,3.31768-5.74622,3.31768a3.83042,3.83042,0,0,0,0,6.63535l5.74622,3.31768-5.74622,3.31768a3.83043,3.83043,0,0,0,0,6.63536l5.74622,3.31768A26.54091,26.54091,0,0,0,291.304,635.28265H450.55254A26.5409,26.5409,0,0,0,477.094,608.74122V502.5755l-92.92155-5.80727a14.12639,14.12639,0,0,1,0-28.19762" transform="translate(-35.5 -118.5)" fill="#3ecc5f" fill-rule="evenodd" /> 28 | <path d="M424.01111,635.28265h39.81214V582.19979H424.01111Z" transform="translate(-35.5 -118.5)" fill="#3ecc5f" fill-rule="evenodd" /> 29 | <path d="M490.36468,602.10586a6.60242,6.60242,0,0,0-.848.08493c-.05042-.19906-.09821-.39945-.15393-.59852A6.62668,6.62668,0,1,0,482.80568,590.21q-.2203-.22491-.44457-.44589a6.62391,6.62391,0,1,0-11.39689-6.56369c-.1964-.05575-.39414-.10218-.59056-.15262a6.63957,6.63957,0,1,0-13.10086,0c-.1964.05042-.39414.09687-.59056.15262a6.62767,6.62767,0,1,0-11.39688,6.56369,26.52754,26.52754,0,1,0,44.23127,25.52756,6.6211,6.6211,0,1,0,.848-13.18579" transform="translate(-35.5 -118.5)" fill="#44d860" fill-rule="evenodd" /> 30 | <path d="M437.28182,555.65836H477.094V529.11693H437.28182Z" transform="translate(-35.5 -118.5)" fill="#3ecc5f" fill-rule="evenodd" /> 31 | <path d="M490.36468,545.70532a3.31768,3.31768,0,0,0,0-6.63536,3.41133,3.41133,0,0,0-.42333.04247c-.02655-.09953-.04911-.19907-.077-.29859a3.319,3.319,0,0,0-1.278-6.37923,3.28174,3.28174,0,0,0-2.00122.68742q-.10947-.11346-.22294-.22295a3.282,3.282,0,0,0,.67149-1.98265,3.31768,3.31768,0,0,0-6.37-1.2992,13.27078,13.27078,0,1,0,0,25.54082,3.31768,3.31768,0,0,0,6.37-1.2992,3.282,3.282,0,0,0-.67149-1.98265q.11347-.10947.22294-.22294a3.28174,3.28174,0,0,0,2.00122.68742,3.31768,3.31768,0,0,0,1.278-6.37923c.02786-.0982.05042-.19907.077-.29859a3.41325,3.41325,0,0,0,.42333.04246" transform="translate(-35.5 -118.5)" 
fill="#44d860" fill-rule="evenodd" /> 32 | <path d="M317.84538,466.081a3.31768,3.31768,0,0,1-3.31767-3.31768,9.953,9.953,0,1,0-19.90608,0,3.31768,3.31768,0,1,1-6.63535,0,16.58839,16.58839,0,1,1,33.17678,0,3.31768,3.31768,0,0,1-3.31768,3.31768" transform="translate(-35.5 -118.5)" fill-rule="evenodd" /> 33 | <path d="M370.92825,635.28265h79.62429A26.5409,26.5409,0,0,0,477.094,608.74122v-92.895H397.46968a26.54091,26.54091,0,0,0-26.54143,26.54143Z" transform="translate(-35.5 -118.5)" fill="#ffff50" fill-rule="evenodd" /> 34 | <path d="M457.21444,556.98543H390.80778a1.32707,1.32707,0,0,1,0-2.65414h66.40666a1.32707,1.32707,0,0,1,0,2.65414m0,26.54143H390.80778a1.32707,1.32707,0,1,1,0-2.65414h66.40666a1.32707,1.32707,0,0,1,0,2.65414m0,26.54143H390.80778a1.32707,1.32707,0,1,1,0-2.65414h66.40666a1.32707,1.32707,0,0,1,0,2.65414m0-66.10674H390.80778a1.32707,1.32707,0,0,1,0-2.65414h66.40666a1.32707,1.32707,0,0,1,0,2.65414m0,26.29459H390.80778a1.32707,1.32707,0,0,1,0-2.65414h66.40666a1.32707,1.32707,0,0,1,0,2.65414m0,26.54143H390.80778a1.32707,1.32707,0,0,1,0-2.65414h66.40666a1.32707,1.32707,0,0,1,0,2.65414M477.094,474.19076c-.01592,0-.0292-.008-.04512-.00663-4.10064.13934-6.04083,4.24132-7.75274,7.86024-1.78623,3.78215-3.16771,6.24122-5.43171,6.16691-2.50685-.09024-3.94007-2.92222-5.45825-5.91874-1.74377-3.44243-3.73438-7.34667-7.91333-7.20069-4.04227.138-5.98907,3.70784-7.70631,6.857-1.82738,3.35484-3.07084,5.39455-5.46887,5.30033-2.55727-.09289-3.91619-2.39536-5.48877-5.06013-1.75306-2.96733-3.77951-6.30359-7.8775-6.18946-3.97326.13669-5.92537,3.16507-7.64791,5.83912-1.82207,2.82666-3.09872,4.5492-5.52725,4.447-2.61832-.09289-3.9706-2.00388-5.53522-4.21611-1.757-2.4856-3.737-5.299-7.82308-5.16231-3.88567.13271-5.83779,2.61434-7.559,4.80135-1.635,2.07555-2.9116,3.71846-5.61218,3.615a1.32793,1.32793,0,1,0-.09555,2.65414c4.00377.134,6.03154-2.38873,7.79257-4.6275,1.562-1.9853,2.91027-3.69855,5.56441-3.78879,2.55594-.10882,3.75429,1.47968,5.56707,4.04093,1.7212,2.43385,3.67465,5.19416,7.60545,5.33616,4.11789.138,6.09921-2.93946,7.8536-5.66261,1.56861-2.43385,2.92221-4.53461,5.50734-4.62352,2.37944-.08892,3.67466,1.79154,5.50072,4.885,1.72121,2.91557,3.67069,6.21865,7.67977,6.36463,4.14709.14332,6.14965-3.47693,7.89475-6.68181,1.51155-2.77092,2.93814-5.38791,5.46621-5.4755,2.37944-.05573,3.62025,2.11668,5.45558,5.74622,1.71459,3.388,3.65875,7.22591,7.73019,7.37321l.22429.004c4.06614,0,5.99571-4.08074,7.70364-7.68905,1.51154-3.19825,2.94211-6.21069,5.3972-6.33411Z" transform="translate(-35.5 -118.5)" fill-rule="evenodd" /> 35 | <path d="M344.38682,635.28265h53.08286V582.19979H344.38682Z" transform="translate(-35.5 -118.5)" fill="#3ecc5f" fill-rule="evenodd" /> 36 | <path d="M424.01111,602.10586a6.60242,6.60242,0,0,0-.848.08493c-.05042-.19906-.09821-.39945-.15394-.59852A6.62667,6.62667,0,1,0,416.45211,590.21q-.2203-.22491-.44458-.44589a6.62391,6.62391,0,1,0-11.39689-6.56369c-.1964-.05575-.39413-.10218-.59054-.15262a6.63957,6.63957,0,1,0-13.10084,0c-.19641.05042-.39414.09687-.59055.15262a6.62767,6.62767,0,1,0-11.39689,6.56369,26.52755,26.52755,0,1,0,44.2313,25.52756,6.6211,6.6211,0,1,0,.848-13.18579" transform="translate(-35.5 -118.5)" fill="#44d860" fill-rule="evenodd" /> 37 | <path d="M344.38682,555.65836h53.08286V529.11693H344.38682Z" transform="translate(-35.5 -118.5)" fill="#3ecc5f" fill-rule="evenodd" /> 38 | <path 
d="M410.74039,545.70532a3.31768,3.31768,0,1,0,0-6.63536,3.41133,3.41133,0,0,0-.42333.04247c-.02655-.09953-.04911-.19907-.077-.29859a3.319,3.319,0,0,0-1.278-6.37923,3.28174,3.28174,0,0,0-2.00122.68742q-.10947-.11346-.22294-.22295a3.282,3.282,0,0,0,.67149-1.98265,3.31768,3.31768,0,0,0-6.37-1.2992,13.27078,13.27078,0,1,0,0,25.54082,3.31768,3.31768,0,0,0,6.37-1.2992,3.282,3.282,0,0,0-.67149-1.98265q.11347-.10947.22294-.22294a3.28174,3.28174,0,0,0,2.00122.68742,3.31768,3.31768,0,0,0,1.278-6.37923c.02786-.0982.05042-.19907.077-.29859a3.41325,3.41325,0,0,0,.42333.04246" transform="translate(-35.5 -118.5)" fill="#44d860" fill-rule="evenodd" /> 39 | <path d="M424.01111,447.8338a3.60349,3.60349,0,0,1-.65028-.06636,3.34415,3.34415,0,0,1-.62372-.18579,3.44679,3.44679,0,0,1-.572-.30522,5.02708,5.02708,0,0,1-.50429-.4114,3.88726,3.88726,0,0,1-.41007-.50428,3.27532,3.27532,0,0,1-.55737-1.84463,3.60248,3.60248,0,0,1,.06636-.65027,3.82638,3.82638,0,0,1,.18447-.62373,3.48858,3.48858,0,0,1,.30656-.57064,3.197,3.197,0,0,1,.91436-.91568,3.44685,3.44685,0,0,1,.572-.30523,3.344,3.344,0,0,1,.62372-.18578,3.06907,3.06907,0,0,1,1.30053,0,3.22332,3.22332,0,0,1,1.19436.491,5.02835,5.02835,0,0,1,.50429.41139,4.8801,4.8801,0,0,1,.41139.50429,3.38246,3.38246,0,0,1,.30522.57064,3.47806,3.47806,0,0,1,.25215,1.274A3.36394,3.36394,0,0,1,426.36,446.865a5.02708,5.02708,0,0,1-.50429.4114,3.3057,3.3057,0,0,1-1.84463.55737m26.54143-1.65884a3.38754,3.38754,0,0,1-2.35024-.96877,5.04185,5.04185,0,0,1-.41007-.50428,3.27532,3.27532,0,0,1-.55737-1.84463,3.38659,3.38659,0,0,1,.96744-2.34892,5.02559,5.02559,0,0,1,.50429-.41139,3.44685,3.44685,0,0,1,.572-.30523,3.3432,3.3432,0,0,1,.62373-.18579,3.06952,3.06952,0,0,1,1.30052,0,3.22356,3.22356,0,0,1,1.19436.491,5.02559,5.02559,0,0,1,.50429.41139,3.38792,3.38792,0,0,1,.96876,2.34892,3.72635,3.72635,0,0,1-.06636.65026,3.37387,3.37387,0,0,1-.18579.62373,4.71469,4.71469,0,0,1-.30522.57064,4.8801,4.8801,0,0,1-.41139.50429,5.02559,5.02559,0,0,1-.50429.41139,3.30547,3.30547,0,0,1-1.84463.55737" transform="translate(-35.5 -118.5)" fill-rule="evenodd" /> 40 | </svg> 41 | ``` -------------------------------------------------------------------------------- /dspy/teleprompt/bootstrap.py: -------------------------------------------------------------------------------- ```python 1 | import logging 2 | import random 3 | import threading 4 | 5 | import tqdm 6 | 7 | import dspy 8 | from dspy.teleprompt.teleprompt import Teleprompter 9 | 10 | from .vanilla import LabeledFewShot 11 | 12 | # TODO: metrics should return an object with __bool__ basically, but fine if they're more complex. 13 | # They can also be sortable. 14 | 15 | # TODO: Switch here from dspy.dsp.Example to dspy.Example. Right now, it's okay because it's internal only (predictors). 16 | # NOTE: Notice the places where we don't shuffle examples. I do like that this one doesn't shuffle. 17 | # Other ones that consider options may want to use both unshuffled and then shuffle a few times, when 18 | # considering candidates. 19 | 20 | # TODO: the max_rounds via branch_idx to get past the cache, not just temperature. 21 | # In principle, we can also sample multiple outputs from the final generation step 22 | # (or even each step, in case the validation function just wants *one* thing that works, but nah) 23 | # and try them all. Having a pretty solid guess on the "final step" of each example isn't hard by the second round, 24 | # in the sense that we have the trace from the first round. 
(Yes it may change but that's an edge case that 25 | # won't hurt our "best effort" guarantees.) 26 | 27 | # TODO: When this bootstraps for another teleprompter like finetune, we want all demos we gather. 28 | # But when it's for direct use we may want to sample ONE demo per predictor--example pair. 29 | # This is important for "multi-use" modules. 30 | 31 | # TODO: Add baselines=[...] 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | 36 | class BootstrapFewShot(Teleprompter): 37 | def __init__( 38 | self, 39 | metric=None, 40 | metric_threshold=None, 41 | teacher_settings: dict | None = None, 42 | max_bootstrapped_demos=4, 43 | max_labeled_demos=16, 44 | max_rounds=1, 45 | max_errors=None, 46 | ): 47 | """A Teleprompter class that composes a set of demos/examples to go into a predictor's prompt. 48 | These demos come from a combination of labeled examples in the training set, and bootstrapped demos. 49 | 50 | Each bootstrap round copies the LM with a new ``rollout_id`` at ``temperature=1.0`` to 51 | bypass caches and gather diverse traces. 52 | 53 | Args: 54 | metric (Callable): A function that compares an expected value and predicted value, 55 | outputting the result of that comparison. 56 | metric_threshold (float, optional): If the metric yields a numerical value, then check it 57 | against this threshold when deciding whether or not to accept a bootstrap example. 58 | Defaults to None. 59 | teacher_settings (dict, optional): Settings for the `teacher` model. 60 | Defaults to None. 61 | max_bootstrapped_demos (int): Maximum number of bootstrapped demonstrations to include. 62 | Defaults to 4. 63 | max_labeled_demos (int): Maximum number of labeled demonstrations to include. 64 | Defaults to 16. 65 | max_rounds (int): Number of iterations to attempt generating the required bootstrap 66 | examples. If unsuccessful after `max_rounds`, the program ends. Defaults to 1. 67 | max_errors (Optional[int]): Maximum number of errors until program ends. 68 | If ``None``, inherits from ``dspy.settings.max_errors``. 69 | """ 70 | self.metric = metric 71 | self.metric_threshold = metric_threshold 72 | self.teacher_settings = {} if teacher_settings is None else teacher_settings 73 | 74 | self.max_bootstrapped_demos = max_bootstrapped_demos 75 | self.max_labeled_demos = max_labeled_demos 76 | self.max_rounds = max_rounds 77 | self.max_errors = max_errors 78 | self.error_count = 0 79 | self.error_lock = threading.Lock() 80 | 81 | def compile(self, student, *, teacher=None, trainset): 82 | self.trainset = trainset 83 | 84 | self._prepare_student_and_teacher(student, teacher) 85 | self._prepare_predictor_mappings() 86 | self._bootstrap() 87 | 88 | self.student = self._train() 89 | self.student._compiled = True 90 | 91 | return self.student 92 | 93 | def _prepare_student_and_teacher(self, student, teacher): 94 | self.student = student.reset_copy() 95 | 96 | # NOTE: behavior change on Oct 28, 2024. Deep copy instead of reset copy for the student-as-teacher. 97 | self.teacher = teacher.deepcopy() if teacher is not None else student.deepcopy() 98 | 99 | assert getattr(self.student, "_compiled", False) is False, "Student must be uncompiled." 
100 | 101 | if self.max_labeled_demos and getattr(self.teacher, "_compiled", False) is False: 102 | teleprompter = LabeledFewShot(k=self.max_labeled_demos) 103 | self.teacher = teleprompter.compile(self.teacher.reset_copy(), trainset=self.trainset) 104 | 105 | def _prepare_predictor_mappings(self): 106 | name2predictor, predictor2name = {}, {} 107 | student, teacher = self.student, self.teacher 108 | 109 | assert len(student.predictors()) == len( 110 | teacher.predictors(), 111 | ), "Student and teacher must have the same number of predictors." 112 | 113 | for (name1, predictor1), (name2, predictor2) in zip( 114 | student.named_predictors(), teacher.named_predictors(), strict=False 115 | ): 116 | assert name1 == name2, "Student and teacher must have the same program structure." 117 | if hasattr(predictor1.signature, "equals"): 118 | assert predictor1.signature.equals( 119 | predictor2.signature, 120 | ), ( 121 | f"Student and teacher must have the same signatures. " 122 | f"{type(predictor1.signature)} != {type(predictor2.signature)}" 123 | ) 124 | else: 125 | # fallback in case if .equals is not implemented (e.g. dsp.Prompt) 126 | assert predictor1.signature == predictor2.signature, ( 127 | f"Student and teacher must have the same signatures. " 128 | f"{type(predictor1.signature)} != {type(predictor2.signature)}" 129 | ) 130 | assert id(predictor1) != id(predictor2), "Student and teacher must be different objects." 131 | 132 | name2predictor[name1] = None # dict(student=predictor1, teacher=predictor2) 133 | predictor2name[id(predictor1)] = name1 134 | 135 | # FIXME(shangyint): This is an ugly hack to bind traces of 136 | # retry.module to retry 137 | # if isinstance(predictor1, Retry): 138 | # predictor2name[id(predictor1.module)] = name1 139 | 140 | predictor2name[id(predictor2)] = name2 141 | 142 | self.name2predictor = name2predictor 143 | self.predictor2name = predictor2name 144 | 145 | def _bootstrap(self, *, max_bootstraps=None): 146 | max_bootstraps = max_bootstraps or self.max_bootstrapped_demos 147 | bootstrap_attempts = 0 148 | 149 | bootstrapped = {} 150 | self.name2traces = {name: [] for name in self.name2predictor} 151 | 152 | for example_idx, example in enumerate(tqdm.tqdm(self.trainset)): 153 | if len(bootstrapped) >= max_bootstraps: 154 | break 155 | 156 | for round_idx in range(self.max_rounds): 157 | bootstrap_attempts += 1 158 | 159 | if self._bootstrap_one_example(example, round_idx): 160 | bootstrapped[example_idx] = True 161 | break 162 | 163 | print( 164 | f"Bootstrapped {len(bootstrapped)} full traces after {example_idx} examples " 165 | f"for up to {self.max_rounds} rounds, amounting to {bootstrap_attempts} attempts." 166 | ) 167 | 168 | # Unbootstrapped training examples 169 | 170 | self.validation = [x for idx, x in enumerate(self.trainset) if idx not in bootstrapped] 171 | random.Random(0).shuffle(self.validation) 172 | 173 | self.validation = self.validation 174 | 175 | # NOTE: Can't yet use evaluate because we need to trace *per example* 176 | # evaluate = Evaluate(program=self.teacher, metric=self.metric, num_threads=12) 177 | # score = evaluate(self.metric, display_table=False, display_progress=True) 178 | 179 | def _bootstrap_one_example(self, example, round_idx=0): 180 | name2traces = {} 181 | teacher = self.teacher 182 | predictor_cache = {} 183 | 184 | try: 185 | with dspy.settings.context(trace=[], **self.teacher_settings): 186 | lm = dspy.settings.lm 187 | # Use a fresh rollout with temperature=1.0 to bypass caches. 
188 | lm = lm.copy(rollout_id=round_idx, temperature=1.0) if round_idx > 0 else lm 189 | new_settings = {"lm": lm} if round_idx > 0 else {} 190 | 191 | with dspy.settings.context(**new_settings): 192 | for name, predictor in teacher.named_predictors(): 193 | predictor_cache[name] = predictor.demos 194 | predictor.demos = [x for x in predictor.demos if x != example] 195 | 196 | prediction = teacher(**example.inputs()) 197 | trace = dspy.settings.trace 198 | 199 | for name, predictor in teacher.named_predictors(): 200 | predictor.demos = predictor_cache[name] 201 | 202 | if self.metric: 203 | metric_val = self.metric(example, prediction, trace) 204 | if self.metric_threshold: 205 | success = metric_val >= self.metric_threshold 206 | else: 207 | success = metric_val 208 | else: 209 | success = True 210 | except Exception as e: 211 | success = False 212 | with self.error_lock: 213 | self.error_count += 1 214 | current_error_count = self.error_count 215 | effective_max_errors = self.max_errors if self.max_errors is not None else dspy.settings.max_errors 216 | if current_error_count >= effective_max_errors: 217 | raise e 218 | logger.error(f"Failed to run or to evaluate example {example} with {self.metric} due to {e}.") 219 | 220 | if success: 221 | for step in trace: 222 | predictor, inputs, outputs = step 223 | demo = dspy.Example(augmented=True, **inputs, **outputs) 224 | 225 | try: 226 | predictor_name = self.predictor2name[id(predictor)] 227 | except KeyError: 228 | continue # FIXME: ! 229 | 230 | # # TODO: Look closer into this. It's a bit tricky to reproduce. 231 | # print(f"Failed to find predictor {predictor} in {self.predictor2name}.") 232 | # print( 233 | # "Are you doing this in a notebook (Jupyter)? This might be caused by redefining values by rerunning cells.", 234 | # ) 235 | # print("Try restarting the notebook, or open an issue.") 236 | # raise KeyError( 237 | # f"Failed to find predictor {id(predictor)} {predictor} in {self.predictor2name}.", 238 | # ) from e 239 | 240 | name2traces[predictor_name] = name2traces.get(predictor_name, []) 241 | name2traces[predictor_name].append(demo) 242 | 243 | # Update the traces 244 | for name, demos in name2traces.items(): 245 | # If there are multiple traces for the same predictor in the sample example, 246 | # sample 50/50 from the first N-1 traces or the last trace. 247 | if len(demos) > 1: 248 | from dspy.utils.hasher import Hasher 249 | 250 | rng = random.Random(Hasher.hash(tuple(demos))) 251 | demos = [rng.choice(demos[:-1]) if rng.random() < 0.5 else demos[-1]] 252 | self.name2traces[name].extend(demos) 253 | 254 | return success 255 | 256 | def _train(self): 257 | rng = random.Random(0) 258 | raw_demos = self.validation 259 | 260 | for name, predictor in self.student.named_predictors(): 261 | augmented_demos = self.name2traces[name][: self.max_bootstrapped_demos] 262 | 263 | sample_size = min(self.max_labeled_demos - len(augmented_demos), len(raw_demos)) 264 | sample_size = max(0, sample_size) 265 | 266 | raw_demos = rng.sample(raw_demos, sample_size) 267 | predictor.demos = augmented_demos + raw_demos 268 | 269 | return self.student 270 | ``` -------------------------------------------------------------------------------- /docs/docs/tutorials/email_extraction/index.md: -------------------------------------------------------------------------------- ```markdown 1 | # Extracting Information from Emails with DSPy 2 | 3 | This tutorial demonstrates how to build an intelligent email processing system using DSPy. 
We'll create a system that can automatically extract key information from various types of emails, classify their intent, and structure the data for further processing.
4 | 
5 | ## What You'll Build
6 | 
7 | By the end of this tutorial, you'll have a DSPy-powered email processing system that can:
8 | 
9 | - **Classify email types** (order confirmation, support request, meeting invitation, etc.)
10 | - **Extract key entities** (dates, amounts, product names, contact info)
11 | - **Determine urgency levels** and required actions
12 | - **Structure extracted data** into consistent formats
13 | - **Handle multiple email formats** robustly
14 | 
15 | ## Prerequisites
16 | 
17 | - Basic understanding of DSPy modules and signatures
18 | - Python 3.9+ installed
19 | - OpenAI API key (or access to another supported LLM)
20 | 
21 | ## Installation and Setup
22 | 
23 | ```bash
24 | pip install dspy
25 | ```
26 | 
27 | <details>
28 | <summary>Recommended: Set up MLflow Tracing to understand what's happening under the hood.</summary>
29 | 
30 | ### MLflow DSPy Integration
31 | 
32 | <a href="https://mlflow.org/">MLflow</a> is an LLMOps tool that natively integrates with DSPy and offers explainability and experiment tracking. In this tutorial, you can use MLflow to visualize prompts and optimization progress as traces to understand DSPy's behavior better. You can set up MLflow easily by following the four steps below.
33 | 
34 | 
35 | 
36 | 1. Install MLflow
37 | 
38 | ```bash
39 | %pip install mlflow>=3.0.0
40 | ```
41 | 
42 | 2. Start the MLflow UI in a separate terminal
43 | ```bash
44 | mlflow ui --port 5000 --backend-store-uri sqlite:///mlruns.db
45 | ```
46 | 
47 | 3. Connect the notebook to MLflow
48 | ```python
49 | import mlflow
50 | 
51 | mlflow.set_tracking_uri("http://localhost:5000")
52 | mlflow.set_experiment("DSPy")
53 | ```
54 | 
55 | 4. Enable tracing.
56 | ```python
57 | mlflow.dspy.autolog()
58 | ```
59 | 
60 | 
61 | To learn more about the integration, visit the [MLflow DSPy Documentation](https://mlflow.org/docs/latest/llms/dspy/index.html).
62 | </details>
63 | 
64 | ## Step 1: Define Our Data Structures
65 | 
66 | First, let's define the types of information we want to extract from emails:
67 | 
68 | ```python
69 | import dspy
70 | from typing import Optional
71 | 
72 | from pydantic import BaseModel
73 | from enum import Enum
74 | 
75 | class EmailType(str, Enum):
76 |     ORDER_CONFIRMATION = "order_confirmation"
77 |     SUPPORT_REQUEST = "support_request"
78 |     MEETING_INVITATION = "meeting_invitation"
79 |     NEWSLETTER = "newsletter"
80 |     PROMOTIONAL = "promotional"
81 |     INVOICE = "invoice"
82 |     SHIPPING_NOTIFICATION = "shipping_notification"
83 |     OTHER = "other"
84 | 
85 | class UrgencyLevel(str, Enum):
86 |     LOW = "low"
87 |     MEDIUM = "medium"
88 |     HIGH = "high"
89 |     CRITICAL = "critical"
90 | 
91 | class ExtractedEntity(BaseModel):
92 |     entity_type: str
93 |     value: str
94 |     confidence: float
95 | ```
96 | 
97 | ## Step 2: Create DSPy Signatures
98 | 
99 | Now let's define the signatures for our email processing pipeline:
100 | 
101 | ```python
102 | class ClassifyEmail(dspy.Signature):
103 |     """Classify the type and urgency of an email based on its content."""
104 | 
105 |     email_subject: str = dspy.InputField(desc="The subject line of the email")
106 |     email_body: str = dspy.InputField(desc="The main content of the email")
107 |     sender: str = dspy.InputField(desc="Email sender information")
108 | 
109 |     email_type: EmailType = dspy.OutputField(desc="The classified type of email")
110 |     urgency: UrgencyLevel = dspy.OutputField(desc="The urgency level of the email")
111 |     reasoning: str = dspy.OutputField(desc="Brief explanation of the classification")
112 | 
113 | class ExtractEntities(dspy.Signature):
114 |     """Extract key entities and information from email content."""
115 | 
116 |     email_content: str = dspy.InputField(desc="The full email content including subject and body")
117 |     email_type: EmailType = dspy.InputField(desc="The classified type of email")
118 | 
119 |     key_entities: list[ExtractedEntity] = dspy.OutputField(desc="List of extracted entities with type, value, and confidence")
120 |     financial_amount: Optional[float] = dspy.OutputField(desc="Any monetary amount found, as a number (e.g., 99.99)")
121 |     important_dates: list[str] = dspy.OutputField(desc="List of important dates found in the email")
122 |     contact_info: list[str] = dspy.OutputField(desc="Relevant contact information extracted")
123 | 
124 | class GenerateActionItems(dspy.Signature):
125 |     """Determine what actions are needed based on the email content and extracted information."""
126 | 
127 |     email_type: EmailType = dspy.InputField()
128 |     urgency: UrgencyLevel = dspy.InputField()
129 |     email_summary: str = dspy.InputField(desc="Brief summary of the email content")
130 |     extracted_entities: list[ExtractedEntity] = dspy.InputField(desc="Key entities found in the email")
131 | 
132 |     action_required: bool = dspy.OutputField(desc="Whether any action is required")
133 |     action_items: list[str] = dspy.OutputField(desc="List of specific actions needed")
134 |     deadline: Optional[str] = dspy.OutputField(desc="Deadline for action if applicable")
135 |     priority_score: int = dspy.OutputField(desc="Priority score from 1-10")
136 | 
137 | class SummarizeEmail(dspy.Signature):
138 |     """Create a concise summary of the email content."""
139 | 
140 |     email_subject: str = dspy.InputField()
141 |     email_body: str = dspy.InputField()
142 |     key_entities: list[ExtractedEntity] = dspy.InputField()
143 | 
144 |     summary: str = dspy.OutputField(desc="A 2-3 sentence
summary of the email's main points")
145 | ```
146 | 
147 | ## Step 3: Build the Email Processing Module
148 | 
149 | Now let's create our main email processing module:
150 | 
151 | ```python
152 | class EmailProcessor(dspy.Module):
153 |     """A comprehensive email processing system using DSPy."""
154 | 
155 |     def __init__(self):
156 |         super().__init__()
157 | 
158 |         # Initialize our processing components
159 |         self.classifier = dspy.ChainOfThought(ClassifyEmail)
160 |         self.entity_extractor = dspy.ChainOfThought(ExtractEntities)
161 |         self.action_generator = dspy.ChainOfThought(GenerateActionItems)
162 |         self.summarizer = dspy.ChainOfThought(SummarizeEmail)
163 | 
164 |     def forward(self, email_subject: str, email_body: str, sender: str = ""):
165 |         """Process an email and extract structured information."""
166 | 
167 |         # Step 1: Classify the email
168 |         classification = self.classifier(
169 |             email_subject=email_subject,
170 |             email_body=email_body,
171 |             sender=sender
172 |         )
173 | 
174 |         # Step 2: Extract entities
175 |         full_content = f"Subject: {email_subject}\n\nFrom: {sender}\n\n{email_body}"
176 |         entities = self.entity_extractor(
177 |             email_content=full_content,
178 |             email_type=classification.email_type
179 |         )
180 | 
181 |         # Step 3: Generate summary
182 |         summary = self.summarizer(
183 |             email_subject=email_subject,
184 |             email_body=email_body,
185 |             key_entities=entities.key_entities
186 |         )
187 | 
188 |         # Step 4: Determine actions
189 |         actions = self.action_generator(
190 |             email_type=classification.email_type,
191 |             urgency=classification.urgency,
192 |             email_summary=summary.summary,
193 |             extracted_entities=entities.key_entities
194 |         )
195 | 
196 |         # Step 5: Structure the results
197 |         return dspy.Prediction(
198 |             email_type=classification.email_type,
199 |             urgency=classification.urgency,
200 |             summary=summary.summary,
201 |             key_entities=entities.key_entities,
202 |             financial_amount=entities.financial_amount,
203 |             important_dates=entities.important_dates,
204 |             action_required=actions.action_required,
205 |             action_items=actions.action_items,
206 |             deadline=actions.deadline,
207 |             priority_score=actions.priority_score,
208 |             reasoning=classification.reasoning,
209 |             contact_info=entities.contact_info
210 |         )
211 | ```
212 | 
213 | ## Step 4: Running the Email Processing System
214 | 
215 | Let's create a simple function to test our email processing system:
216 | 
217 | ```python
218 | import os
219 | def run_email_processing_demo():
220 |     """Demonstration of the email processing system."""
221 | 
222 |     # Configure DSPy (set your API key before making any LM calls)
223 |     os.environ["OPENAI_API_KEY"] = "<YOUR OPENAI KEY>"
224 |     lm = dspy.LM(model='openai/gpt-4o-mini')
225 |     dspy.configure(lm=lm)
226 | 
227 |     # Create our email processor
228 |     processor = EmailProcessor()
229 | 
230 |     # Sample emails for testing
231 |     sample_emails = [
232 |         {
233 |             "subject": "Order Confirmation #12345 - Your MacBook Pro is on the way!",
234 |             "body": """Dear John Smith,
235 | 
236 | Thank you for your order! We're excited to confirm that your order #12345 has been processed.
237 | 
238 | Order Details:
239 | - MacBook Pro 14-inch (Space Gray)
240 | - Order Total: $2,399.00
241 | - Estimated Delivery: December 15, 2024
242 | - Tracking Number: 1Z999AA1234567890
243 | 
244 | If you have any questions, please contact our support team at [email protected].
245 | 246 | Best regards, 247 | TechStore Team""", 248 | "sender": "[email protected]" 249 | }, 250 | { 251 | "subject": "URGENT: Server Outage - Immediate Action Required", 252 | "body": """Hi DevOps Team, 253 | 254 | We're experiencing a critical server outage affecting our production environment. 255 | 256 | Impact: All users unable to access the platform 257 | Started: 2:30 PM EST 258 | 259 | Please join the emergency call immediately: +1-555-123-4567 260 | 261 | This is our highest priority. 262 | 263 | Thanks, 264 | Site Reliability Team""", 265 | "sender": "[email protected]" 266 | }, 267 | { 268 | "subject": "Meeting Invitation: Q4 Planning Session", 269 | "body": """Hello team, 270 | 271 | You're invited to our Q4 planning session. 272 | 273 | When: Friday, December 20, 2024 at 2:00 PM - 4:00 PM EST 274 | Where: Conference Room A 275 | 276 | Please confirm your attendance by December 18th. 277 | 278 | Best, 279 | Sarah Johnson""", 280 | "sender": "[email protected]" 281 | } 282 | ] 283 | 284 | # Process each email and display results 285 | print("🚀 Email Processing Demo") 286 | print("=" * 50) 287 | 288 | for i, email in enumerate(sample_emails): 289 | print(f"\n📧 EMAIL {i+1}: {email['subject'][:50]}...") 290 | 291 | # Process the email 292 | result = processor( 293 | email_subject=email["subject"], 294 | email_body=email["body"], 295 | sender=email["sender"] 296 | ) 297 | 298 | # Display key results 299 | print(f" 📊 Type: {result.email_type}") 300 | print(f" 🚨 Urgency: {result.urgency}") 301 | print(f" 📝 Summary: {result.summary}") 302 | 303 | if result.financial_amount: 304 | print(f" 💰 Amount: ${result.financial_amount:,.2f}") 305 | 306 | if result.action_required: 307 | print(f" ✅ Action Required: Yes") 308 | if result.deadline: 309 | print(f" ⏰ Deadline: {result.deadline}") 310 | else: 311 | print(f" ✅ Action Required: No") 312 | 313 | # Run the demo 314 | if __name__ == "__main__": 315 | run_email_processing_demo() 316 | ``` 317 | 318 | ## Expected Output 319 | ``` 320 | 🚀 Email Processing Demo 321 | ================================================== 322 | 323 | 📧 EMAIL 1: Order Confirmation #12345 - Your MacBook Pro is on... 324 | 📊 Type: order_confirmation 325 | 🚨 Urgency: low 326 | 📝 Summary: The email confirms John Smith's order #12345 for a MacBook Pro 14-inch in Space Gray, totaling $2,399.00, with an estimated delivery date of December 15, 2024. It includes a tracking number and contact information for customer support. 327 | 💰 Amount: $2,399.00 328 | ✅ Action Required: No 329 | 330 | 📧 EMAIL 2: URGENT: Server Outage - Immediate Action Required... 331 | 📊 Type: other 332 | 🚨 Urgency: critical 333 | 📝 Summary: The Site Reliability Team has reported a critical server outage that began at 2:30 PM EST, preventing all users from accessing the platform. They have requested the DevOps Team to join an emergency call immediately to address the issue. 334 | ✅ Action Required: Yes 335 | ⏰ Deadline: Immediately 336 | 337 | 📧 EMAIL 3: Meeting Invitation: Q4 Planning Session... 338 | 📊 Type: meeting_invitation 339 | 🚨 Urgency: medium 340 | 📝 Summary: Sarah Johnson has invited the team to a Q4 planning session on December 20, 2024, from 2:00 PM to 4:00 PM EST in Conference Room A. Attendees are asked to confirm their participation by December 18th. 341 | ✅ Action Required: Yes 342 | ⏰ Deadline: December 18th 343 | ``` 344 | 345 | ## Next Steps 346 | 347 | - **Add more email types** and refine classification (newsletter, promotional, etc.) 
348 | - **Add integration** with email providers (Gmail API, Outlook, IMAP)
349 | - **Experiment with different LLMs** and optimization strategies
350 | - **Add multilingual support** for international email processing
351 | - **Optimize your program** with a DSPy optimizer to improve extraction quality (see the sketch below)
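352 | 
353 | As a starting point for that last item, here is a minimal optimization sketch. It assumes you
354 | have a small labeled devset of `dspy.Example` objects and a simple type-match metric; the
355 | devset contents and the metric below are illustrative placeholders, not part of the tutorial above:
356 | 
357 | ```python
358 | # Hypothetical devset: dspy.Example objects labeled with the expected email_type.
359 | devset = [
360 |     dspy.Example(
361 |         email_subject="Invoice #998 due January 5",
362 |         email_body="Please find attached invoice #998 for $450.00, due January 5, 2025.",
363 |         sender="[email protected]",
364 |         email_type=EmailType.INVOICE,
365 |     ).with_inputs("email_subject", "email_body", "sender"),
366 |     # ... more labeled examples ...
367 | ]
368 | 
369 | def email_type_match(example, prediction, trace=None):
370 |     """Illustrative metric: did the program classify the email type correctly?"""
371 |     return prediction.email_type == example.email_type
372 | 
373 | # BootstrapFewShot runs the program over the trainset, keeps the traces that pass
374 | # the metric, and attaches them to each predictor as few-shot demos.
375 | optimizer = dspy.BootstrapFewShot(metric=email_type_match, max_bootstrapped_demos=4)
376 | optimized_processor = optimizer.compile(EmailProcessor(), trainset=devset)
377 | ```
378 | 
```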