This is page 11 of 17. Use http://codebase.md/stanfordnlp/dspy?lines=true&page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── .internal_dspyai │ │ ├── internals │ │ │ ├── build-and-release.md │ │ │ └── release-checklist.md │ │ └── pyproject.toml │ ├── .tmp │ │ └── .generated-actions │ │ └── run-pypi-publish-in-docker-container │ │ └── action.yml │ ├── ISSUE_TEMPLATE │ │ ├── bug_report.yml │ │ └── feature_request.yml │ ├── PULL_REQUEST_TEMPLATE │ │ └── pull_request_template.md │ ├── workflow_scripts │ │ └── install_testpypi_pkg.sh │ └── workflows │ ├── build_and_release.yml │ ├── build_utils │ │ └── test_version.py │ ├── docs-push.yml │ ├── precommits_check.yml │ └── run_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── docs │ ├── .gitignore │ ├── docs │ │ ├── api │ │ │ ├── adapters │ │ │ │ ├── Adapter.md │ │ │ │ ├── ChatAdapter.md │ │ │ │ ├── JSONAdapter.md │ │ │ │ └── TwoStepAdapter.md │ │ │ ├── evaluation │ │ │ │ ├── answer_exact_match.md │ │ │ │ ├── answer_passage_match.md │ │ │ │ ├── CompleteAndGrounded.md │ │ │ │ ├── Evaluate.md │ │ │ │ ├── EvaluationResult.md │ │ │ │ └── SemanticF1.md │ │ │ ├── experimental │ │ │ │ ├── Citations.md │ │ │ │ └── Document.md │ │ │ ├── index.md │ │ │ ├── models │ │ │ │ ├── Embedder.md │ │ │ │ └── LM.md │ │ │ ├── modules │ │ │ │ ├── BestOfN.md │ │ │ │ ├── ChainOfThought.md │ │ │ │ ├── CodeAct.md │ │ │ │ ├── Module.md │ │ │ │ ├── MultiChainComparison.md │ │ │ │ ├── Parallel.md │ │ │ │ ├── Predict.md │ │ │ │ ├── ProgramOfThought.md │ │ │ │ ├── ReAct.md │ │ │ │ └── Refine.md │ │ │ ├── optimizers │ │ │ │ ├── BetterTogether.md │ │ │ │ ├── BootstrapFewShot.md │ │ │ │ ├── BootstrapFewShotWithRandomSearch.md │ │ │ │ ├── BootstrapFinetune.md │ │ │ │ ├── BootstrapRS.md │ │ │ │ ├── COPRO.md │ │ │ │ ├── Ensemble.md │ │ │ │ ├── GEPA │ │ │ │ │ ├── GEPA_Advanced.md │ │ │ │ │ └── overview.md │ │ │ │ ├── InferRules.md │ │ │ │ ├── KNN.md │ │ │ │ ├── KNNFewShot.md │ │ │ │ ├── LabeledFewShot.md │ │ │ │ ├── MIPROv2.md │ │ │ │ └── SIMBA.md │ │ │ ├── primitives │ │ │ │ ├── Audio.md │ │ │ │ ├── Code.md │ │ │ │ ├── Example.md │ │ │ │ ├── History.md │ │ │ │ ├── Image.md │ │ │ │ ├── Prediction.md │ │ │ │ ├── Tool.md │ │ │ │ └── ToolCalls.md │ │ │ ├── signatures │ │ │ │ ├── InputField.md │ │ │ │ ├── OutputField.md │ │ │ │ └── Signature.md │ │ │ ├── tools │ │ │ │ ├── ColBERTv2.md │ │ │ │ ├── Embeddings.md │ │ │ │ └── PythonInterpreter.md │ │ │ └── utils │ │ │ ├── asyncify.md │ │ │ ├── configure_cache.md │ │ │ ├── disable_litellm_logging.md │ │ │ ├── disable_logging.md │ │ │ ├── enable_litellm_logging.md │ │ │ ├── enable_logging.md │ │ │ ├── inspect_history.md │ │ │ ├── load.md │ │ │ ├── StatusMessage.md │ │ │ ├── StatusMessageProvider.md │ │ │ ├── streamify.md │ │ │ └── StreamListener.md │ │ ├── cheatsheet.md │ │ ├── community │ │ │ ├── community-resources.md │ │ │ ├── how-to-contribute.md │ │ │ └── use-cases.md │ │ ├── deep-dive │ │ │ └── data-handling │ │ │ ├── built-in-datasets.md │ │ │ ├── examples.md │ │ │ ├── img │ │ │ │ └── data-loading.png │ │ │ └── loading-custom-data.md │ │ ├── faqs.md │ │ ├── index.md │ │ ├── js │ │ │ └── runllm-widget.js │ │ ├── learn │ │ │ ├── evaluation │ │ │ │ ├── data.md │ │ │ │ ├── metrics.md │ │ │ │ └── overview.md │ │ │ ├── figures │ │ │ │ ├── native_tool_call.png │ │ │ │ └── teleprompter-classes.png │ │ │ ├── index.md │ │ │ ├── optimization │ │ │ │ ├── optimizers.md │ │ │ │ └── overview.md │ │ │ └── programming │ │ │ ├── 7-assertions.md │ │ │ ├── adapters.md │ │ │ ├── language_models.md │ │ │ 
├── mcp.md │ │ │ ├── modules.md │ │ │ ├── overview.md │ │ │ ├── signatures.md │ │ │ └── tools.md │ │ ├── production │ │ │ └── index.md │ │ ├── roadmap.md │ │ ├── static │ │ │ ├── .nojekyll │ │ │ └── img │ │ │ ├── dspy_logo.png │ │ │ ├── logo.png │ │ │ ├── mlflow-tracing-rag.png │ │ │ ├── modular.png │ │ │ ├── optimize.png │ │ │ ├── undraw_docusaurus_mountain.svg │ │ │ ├── undraw_docusaurus_react.svg │ │ │ ├── undraw_docusaurus_tree.svg │ │ │ └── universal_compatibility.png │ │ ├── stylesheets │ │ │ └── extra.css │ │ └── tutorials │ │ ├── agents │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── ai_text_game │ │ │ └── index.md │ │ ├── async │ │ │ └── index.md │ │ ├── audio │ │ │ └── index.ipynb │ │ ├── build_ai_program │ │ │ └── index.md │ │ ├── cache │ │ │ └── index.md │ │ ├── classification │ │ │ └── index.md │ │ ├── classification_finetuning │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-classification.png │ │ ├── conversation_history │ │ │ └── index.md │ │ ├── core_development │ │ │ └── index.md │ │ ├── custom_module │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-custom-module.png │ │ ├── customer_service_agent │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-customer-service-agent.png │ │ ├── deployment │ │ │ ├── dspy_mlflow_ui.png │ │ │ └── index.md │ │ ├── email_extraction │ │ │ ├── index.md │ │ │ └── mlflow-tracing-email-extraction.png │ │ ├── entity_extraction │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-entity-extraction.png │ │ ├── games │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── gepa_ai_program │ │ │ └── index.md │ │ ├── gepa_aime │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-aime.png │ │ │ └── mlflow-tracking-gepa-aime-optimization.png │ │ ├── gepa_facilitysupportanalyzer │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-support.png │ │ │ └── mlflow-tracking-gepa-support-optimization.png │ │ ├── gepa_papillon │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-papilon.png │ │ │ └── mlflow-tracking-gepa-papilon-optimization.png │ │ ├── image_generation_prompting │ │ │ └── index.ipynb │ │ ├── index.md │ │ ├── llms_txt_generation │ │ │ └── index.md │ │ ├── math │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-math.png │ │ ├── mcp │ │ │ └── index.md │ │ ├── mem0_react_agent │ │ │ └── index.md │ │ ├── multihop_search │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-multi-hop.png │ │ ├── observability │ │ │ ├── index.md │ │ │ ├── mlflow_trace_ui_navigation.gif │ │ │ ├── mlflow_trace_ui.png │ │ │ └── mlflow_trace_view.png │ │ ├── optimize_ai_program │ │ │ └── index.md │ │ ├── optimizer_tracking │ │ │ ├── child_run.png │ │ │ ├── experiment.png │ │ │ ├── index.md │ │ │ └── parent_run.png │ │ ├── output_refinement │ │ │ └── best-of-n-and-refine.md │ │ ├── papillon │ │ │ └── index.md │ │ ├── program_of_thought │ │ │ └── index.ipynb │ │ ├── rag │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-rag.png │ │ ├── real_world_examples │ │ │ └── index.md │ │ ├── rl_ai_program │ │ │ └── index.md │ │ ├── rl_multihop │ │ │ └── index.ipynb │ │ ├── rl_papillon │ │ │ └── index.ipynb │ │ ├── sample_code_generation │ │ │ └── index.md │ │ ├── saving │ │ │ └── index.md │ │ ├── streaming │ │ │ └── index.md │ │ ├── tool_use │ │ │ └── index.ipynb │ │ └── yahoo_finance_react │ │ └── index.md │ ├── mkdocs.yml │ ├── overrides │ │ ├── home.html │ │ ├── main.html │ │ └── partials │ │ └── tabs.html │ ├── Pipfile │ ├── Pipfile.lock │ ├── README.md │ ├── requirements.txt │ ├── scripts │ │ ├── generate_api_docs.py │ │ └── generate_api_summary.py │ └── vercel.json ├── dspy │ ├── 
__init__.py │ ├── __metadata__.py │ ├── adapters │ │ ├── __init__.py │ │ ├── baml_adapter.py │ │ ├── base.py │ │ ├── chat_adapter.py │ │ ├── json_adapter.py │ │ ├── two_step_adapter.py │ │ ├── types │ │ │ ├── __init__.py │ │ │ ├── audio.py │ │ │ ├── base_type.py │ │ │ ├── citation.py │ │ │ ├── code.py │ │ │ ├── document.py │ │ │ ├── history.py │ │ │ ├── image.py │ │ │ └── tool.py │ │ ├── utils.py │ │ └── xml_adapter.py │ ├── clients │ │ ├── __init__.py │ │ ├── base_lm.py │ │ ├── cache.py │ │ ├── databricks.py │ │ ├── embedding.py │ │ ├── lm_local_arbor.py │ │ ├── lm_local.py │ │ ├── lm.py │ │ ├── openai.py │ │ ├── provider.py │ │ └── utils_finetune.py │ ├── datasets │ │ ├── __init__.py │ │ ├── alfworld │ │ │ ├── __init__.py │ │ │ ├── alfworld.py │ │ │ └── base_config.yml │ │ ├── colors.py │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── gsm8k.py │ │ ├── hotpotqa.py │ │ └── math.py │ ├── dsp │ │ ├── __init__.py │ │ ├── colbertv2.py │ │ └── utils │ │ ├── __init__.py │ │ ├── dpr.py │ │ ├── settings.py │ │ └── utils.py │ ├── evaluate │ │ ├── __init__.py │ │ ├── auto_evaluation.py │ │ ├── evaluate.py │ │ └── metrics.py │ ├── experimental │ │ └── __init__.py │ ├── predict │ │ ├── __init__.py │ │ ├── aggregation.py │ │ ├── avatar │ │ │ ├── __init__.py │ │ │ ├── avatar.py │ │ │ ├── models.py │ │ │ └── signatures.py │ │ ├── best_of_n.py │ │ ├── chain_of_thought.py │ │ ├── code_act.py │ │ ├── knn.py │ │ ├── multi_chain_comparison.py │ │ ├── parallel.py │ │ ├── parameter.py │ │ ├── predict.py │ │ ├── program_of_thought.py │ │ ├── react.py │ │ ├── refine.py │ │ └── retry.py │ ├── primitives │ │ ├── __init__.py │ │ ├── base_module.py │ │ ├── example.py │ │ ├── module.py │ │ ├── prediction.py │ │ ├── python_interpreter.py │ │ └── runner.js │ ├── propose │ │ ├── __init__.py │ │ ├── dataset_summary_generator.py │ │ ├── grounded_proposer.py │ │ ├── propose_base.py │ │ └── utils.py │ ├── retrievers │ │ ├── __init__.py │ │ ├── databricks_rm.py │ │ ├── embeddings.py │ │ ├── retrieve.py │ │ └── weaviate_rm.py │ ├── signatures │ │ ├── __init__.py │ │ ├── field.py │ │ ├── signature.py │ │ └── utils.py │ ├── streaming │ │ ├── __init__.py │ │ ├── messages.py │ │ ├── streamify.py │ │ └── streaming_listener.py │ ├── teleprompt │ │ ├── __init__.py │ │ ├── avatar_optimizer.py │ │ ├── bettertogether.py │ │ ├── bootstrap_finetune.py │ │ ├── bootstrap_trace.py │ │ ├── bootstrap.py │ │ ├── copro_optimizer.py │ │ ├── ensemble.py │ │ ├── gepa │ │ │ ├── __init__.py │ │ │ ├── gepa_utils.py │ │ │ ├── gepa.py │ │ │ └── instruction_proposal.py │ │ ├── grpo.py │ │ ├── infer_rules.py │ │ ├── knn_fewshot.py │ │ ├── mipro_optimizer_v2.py │ │ ├── random_search.py │ │ ├── signature_opt.py │ │ ├── simba_utils.py │ │ ├── simba.py │ │ ├── teleprompt_optuna.py │ │ ├── teleprompt.py │ │ ├── utils.py │ │ └── vanilla.py │ └── utils │ ├── __init__.py │ ├── annotation.py │ ├── asyncify.py │ ├── caching.py │ ├── callback.py │ ├── dummies.py │ ├── exceptions.py │ ├── hasher.py │ ├── inspect_history.py │ ├── langchain_tool.py │ ├── logging_utils.py │ ├── mcp.py │ ├── parallelizer.py │ ├── saving.py │ ├── syncify.py │ ├── unbatchify.py │ └── usage_tracker.py ├── LICENSE ├── pyproject.toml ├── README.md ├── tests │ ├── __init__.py │ ├── adapters │ │ ├── test_adapter_utils.py │ │ ├── test_baml_adapter.py │ │ ├── test_base_type.py │ │ ├── test_chat_adapter.py │ │ ├── test_citation.py │ │ ├── test_code.py │ │ ├── test_document.py │ │ ├── test_json_adapter.py │ │ ├── test_tool.py │ │ ├── test_two_step_adapter.py │ │ └── test_xml_adapter.py │ ├── 
callback │ │ └── test_callback.py │ ├── clients │ │ ├── test_cache.py │ │ ├── test_databricks.py │ │ ├── test_embedding.py │ │ ├── test_inspect_global_history.py │ │ └── test_lm.py │ ├── conftest.py │ ├── datasets │ │ └── test_dataset.py │ ├── docs │ │ └── test_mkdocs_links.py │ ├── evaluate │ │ ├── test_evaluate.py │ │ └── test_metrics.py │ ├── examples │ │ └── test_baleen.py │ ├── metadata │ │ └── test_metadata.py │ ├── predict │ │ ├── test_aggregation.py │ │ ├── test_best_of_n.py │ │ ├── test_chain_of_thought.py │ │ ├── test_code_act.py │ │ ├── test_knn.py │ │ ├── test_multi_chain_comparison.py │ │ ├── test_parallel.py │ │ ├── test_predict.py │ │ ├── test_program_of_thought.py │ │ ├── test_react.py │ │ ├── test_refine.py │ │ └── test_retry.py │ ├── primitives │ │ ├── resources │ │ │ └── saved_program.json │ │ ├── test_base_module.py │ │ ├── test_example.py │ │ ├── test_module.py │ │ └── test_python_interpreter.py │ ├── propose │ │ └── test_grounded_proposer.py │ ├── README.md │ ├── reliability │ │ ├── __init__.py │ │ ├── complex_types │ │ │ └── generated │ │ │ ├── test_many_types_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ ├── test_nesting_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ └── test_nesting_2 │ │ │ ├── inputs │ │ │ │ └── input1.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── conftest.py │ │ ├── generate │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ └── utils.py │ │ ├── input_formats │ │ │ └── generated │ │ │ └── test_markdown_1 │ │ │ ├── inputs │ │ │ │ ├── input1.json │ │ │ │ └── input2.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── README.md │ │ ├── reliability_conf.yaml │ │ ├── test_generated.py │ │ ├── test_pydantic_models.py │ │ └── utils.py │ ├── retrievers │ │ └── test_embeddings.py │ ├── signatures │ │ ├── test_adapter_image.py │ │ ├── test_custom_types.py │ │ └── test_signature.py │ ├── streaming │ │ └── test_streaming.py │ ├── teleprompt │ │ ├── gepa_dummy_lm_custom_component_selector_custom_instruction_proposer.json │ │ ├── gepa_dummy_lm.json │ │ ├── test_bootstrap_finetune.py │ │ ├── test_bootstrap_trace.py │ │ ├── test_bootstrap.py │ │ ├── test_copro_optimizer.py │ │ ├── test_ensemble.py │ │ ├── test_finetune.py │ │ ├── test_gepa_instruction_proposer.py │ │ ├── test_gepa.py │ │ ├── test_grpo.py │ │ ├── test_knn_fewshot.py │ │ ├── test_random_search.py │ │ ├── test_teleprompt.py │ │ └── test_utils.py │ ├── test_utils │ │ ├── __init__.py │ │ └── server │ │ ├── __init__.py │ │ ├── litellm_server_config.yaml │ │ └── litellm_server.py │ └── utils │ ├── __init__.py │ ├── resources │ │ └── mcp_server.py │ ├── test_annotation.py │ ├── test_asyncify.py │ ├── test_exceptions.py │ ├── test_langchain_tool.py │ ├── test_mcp.py │ ├── test_parallelizer.py │ ├── test_saving.py │ ├── test_settings.py │ ├── test_syncify.py │ ├── test_unbatchify.py │ └── test_usage_tracker.py └── uv.lock ``` # Files -------------------------------------------------------------------------------- /dspy/teleprompt/simba.py: -------------------------------------------------------------------------------- ```python 1 | from __future__ import annotations 2 | 3 | import logging 4 | import random 5 | from typing import Any, Callable 6 | 7 | import numpy as np 8 | 9 | import dspy 10 | from dspy.teleprompt.simba_utils import append_a_demo, append_a_rule, prepare_models_for_resampling, wrap_program 11 | from 
dspy.teleprompt.teleprompt import Teleprompter 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class SIMBA(Teleprompter): 17 | """ 18 | SIMBA (Stochastic Introspective Mini-Batch Ascent) optimizer for DSPy. 19 | 20 | SIMBA is a DSPy optimizer that uses the LLM to analyze its own performance and 21 | generate improvement rules. It samples mini-batches, identifies challenging examples 22 | with high output variability, then either creates self-reflective rules or adds 23 | successful examples as demonstrations. 24 | 25 | For more details, see: https://dspy.ai/api/optimizers/SIMBA/ 26 | """ 27 | 28 | def __init__( 29 | self, 30 | *, 31 | metric: Callable[[dspy.Example, dict[str, Any]], float], 32 | bsize: int = 32, 33 | num_candidates: int = 6, 34 | max_steps: int = 8, 35 | max_demos: int = 4, 36 | prompt_model: dspy.LM | None = None, 37 | teacher_settings: dict | None = None, 38 | demo_input_field_maxlen: int = 100_000, 39 | num_threads: int | None = None, 40 | temperature_for_sampling: float = 0.2, 41 | temperature_for_candidates: float = 0.2, 42 | ) -> None: 43 | """ 44 | Initializes SIMBA. 45 | 46 | Args: 47 | metric: A function that takes an Example and a prediction_dict 48 | as input and returns a float. 49 | bsize: Mini-batch size. Defaults to 32. 50 | num_candidates: Number of new candidate programs to produce 51 | per iteration. Defaults to 6. 52 | max_steps: Number of optimization steps to run. Defaults to 8. 53 | max_demos: Maximum number of demos a predictor can hold 54 | before dropping some. Defaults to 4. 55 | prompt_model: The model to use to evolve the program. When `prompt_model is None`, the globally configured 56 | lm is used. 57 | teacher_settings: Settings for the teacher model. Defaults to None. 58 | demo_input_field_maxlen: Maximum number of characters to keep 59 | in an input field when building a new demo. Defaults to 100,000. 60 | num_threads: Number of threads for parallel execution. 61 | Defaults to None. 62 | temperature_for_sampling: Temperature used for picking 63 | programs during the trajectory-sampling step. Defaults to 0.2. 64 | temperature_for_candidates: Temperature used for picking 65 | the source program for building new candidates. Defaults to 0.2. 66 | """ 67 | self.metric = metric 68 | self.bsize = bsize 69 | self.num_candidates = num_candidates 70 | self.max_steps = max_steps 71 | self.max_demos = max_demos 72 | self.prompt_model = prompt_model or dspy.settings.lm 73 | self.teacher_settings = teacher_settings 74 | self.demo_input_field_maxlen = demo_input_field_maxlen 75 | self.num_threads = num_threads 76 | 77 | self.temperature_for_sampling = temperature_for_sampling 78 | self.temperature_for_candidates = temperature_for_candidates 79 | 80 | if self.max_demos > 0: 81 | self.strategies = [append_a_demo(demo_input_field_maxlen), append_a_rule] 82 | else: 83 | self.strategies = [append_a_rule] 84 | 85 | def compile( 86 | self, 87 | student: dspy.Module, 88 | *, 89 | trainset: list[dspy.Example], 90 | seed: int = 0 91 | ) -> dspy.Module: 92 | """ 93 | Compile and optimize the student module using SIMBA. 
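        A minimal usage sketch (illustrative only; it assumes an LM is already configured,
        e.g. via `dspy.configure(lm=...)`, and that `metric(example, prediction_dict)` returns
        a float). Note that `trainset` must contain at least `bsize` examples:

        ```python
        optimizer = dspy.SIMBA(metric=metric, bsize=32, max_steps=8)
        optimized_program = optimizer.compile(program, trainset=trainset, seed=0)
        ```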
94 | 95 | Args: 96 | student: The module to optimize 97 | trainset: Training examples for optimization 98 | seed: Random seed for reproducibility 99 | 100 | Returns: 101 | The optimized module with candidate_programs and trial_logs attached 102 | """ 103 | # Basic checks 104 | assert len(trainset) >= self.bsize, f"Trainset too small: {len(trainset)} < {self.bsize}" 105 | 106 | # Initialize RNG 107 | rng = random.Random(seed) 108 | rng_np = np.random.default_rng(seed) 109 | 110 | programs = [] 111 | program_scores = {} 112 | next_program_idx = 0 113 | 114 | # Helper functions 115 | def calc_average_score(prog_idx: int) -> float: 116 | scores = program_scores.get(prog_idx, []) 117 | if not scores: 118 | return 0.0 119 | return sum(scores) / len(scores) 120 | 121 | def top_k_plus_baseline(k: int) -> list[int]: 122 | # Sort all programs by descending average score 123 | scored_programs = sorted(programs, key=lambda p: calc_average_score(p.simba_idx), reverse=True) 124 | top_k = [p.simba_idx for p in scored_programs[:k]] 125 | # Ensure baseline=0 is in there: 126 | if 0 not in top_k and len(top_k) > 0: 127 | top_k[-1] = 0 128 | return list(dict.fromkeys(top_k)) 129 | 130 | def softmax_sample(rng_obj: random.Random, program_idxs: list[int], temperature: float) -> int: 131 | if not program_idxs: 132 | raise ValueError("No programs available for softmax sampling.") 133 | 134 | # Unnormalized weights 135 | scores = [calc_average_score(idx) for idx in program_idxs] 136 | exps = [np.exp(s / temperature) for s in scores] 137 | sum_exps = sum(exps) 138 | if sum_exps <= 0: 139 | # Fallback: uniform if all exps are zero 140 | return rng_obj.choice(program_idxs) 141 | 142 | # Weighted random choice 143 | probs = [val / sum_exps for val in exps] 144 | return rng_obj.choices(program_idxs, weights=probs, k=1)[0] 145 | 146 | def register_new_program(prog: dspy.Module, score_list: list[float]) -> None: 147 | nonlocal next_program_idx 148 | next_program_idx += 1 149 | new_idx = next_program_idx 150 | prog.simba_idx = new_idx 151 | programs.append(prog) 152 | program_scores[new_idx] = score_list 153 | 154 | # Initialize the baseline program: index=0 155 | student = student.deepcopy() 156 | student.simba_idx = 0 157 | programs.append(student) 158 | program_scores[0] = [] 159 | 160 | winning_programs = [student] 161 | 162 | # Data shuffling 163 | data_indices = list(range(len(trainset))) 164 | rng.shuffle(data_indices) 165 | instance_idx = 0 166 | 167 | # Parallel runner 168 | run_parallel = dspy.Parallel(access_examples=False, num_threads=self.num_threads) 169 | 170 | trial_logs = {} 171 | for batch_idx in range(self.max_steps): 172 | trial_logs[batch_idx] = {} 173 | 174 | logger.info(f"Starting batch {batch_idx+1} of {self.max_steps}.") 175 | 176 | # STEP 1: Get next batch 177 | if instance_idx + self.bsize > len(trainset): 178 | rng.shuffle(data_indices) 179 | instance_idx = 0 180 | 181 | batch_indices = data_indices[instance_idx : instance_idx + self.bsize] 182 | batch = [trainset[i] for i in batch_indices] 183 | instance_idx += self.bsize 184 | 185 | # We'll generate (program, model) pairs for the trajectory sampling. 186 | # Prepare distinct LMs (with different temperatures, etc.) from the baseline=programs[0]. 
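            # Illustration (hypothetical sizes): with the defaults bsize=32 and num_candidates=6,
            # the loop below builds 6 * 32 = 192 (wrapped program, example) pairs. Because the
            # pairs are laid out model-major, the outputs for the j-th example of the batch land
            # at indices j, j + bsize, j + 2 * bsize, ..., which is how STEP 3 regroups them into
            # per-example buckets.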
187 | models = prepare_models_for_resampling(programs[0], self.num_candidates, self.teacher_settings) 188 | top_programs = top_k_plus_baseline(self.num_candidates) 189 | 190 | exec_pairs = [] 191 | predictor2name = {} 192 | 193 | # For each model, for each example, pick a program from the pool via softmax 194 | for model in models: 195 | for example in batch: 196 | chosen_prog_idx = softmax_sample(rng, top_programs, self.temperature_for_sampling) 197 | candidate_system = programs[chosen_prog_idx].deepcopy() 198 | candidate_system.set_lm(model) 199 | 200 | for name, predictor in candidate_system.named_predictors(): 201 | predictor2name[id(predictor)] = name 202 | 203 | # Use the special wrap that includes the 'example' in the output 204 | wrapped_candidate_system = wrap_program(candidate_system, self.metric) 205 | exec_pairs.append((wrapped_candidate_system, example)) 206 | 207 | # STEP 2: Execute 208 | logger.info(f"Sampling program trajectories on {self.bsize} examples x {self.num_candidates} samples.") 209 | outputs = run_parallel(exec_pairs) 210 | assert len(outputs) == len(exec_pairs) == self.bsize * self.num_candidates 211 | 212 | # STEP 3: Sort the training buckets by (max-to-min gap, max score, and max-to-avg gap). 213 | buckets = [] 214 | largest_max_to_avg_gap = float("-inf") 215 | batch_10th_percentile_score = np.percentile([float(o["score"]) for o in outputs], 10) 216 | batch_90th_percentile_score = np.percentile([float(o["score"]) for o in outputs], 90) 217 | 218 | # We'll chunk `outputs` by example index, each chunk has length = num_candidates 219 | for idx, _ in enumerate(batch): 220 | # gather all results for this example 221 | bucket = [outputs[i] for i in range(idx, len(outputs), self.bsize)] 222 | bucket.sort(key=lambda x: x["score"], reverse=True) 223 | 224 | max_score = float(bucket[0]["score"]) 225 | min_score = float(bucket[-1]["score"]) 226 | avg_score = sum(x["score"] for x in bucket) / len(bucket) 227 | max_to_min_gap = max_score - min_score 228 | max_to_avg_gap = max_score - avg_score 229 | if max_to_avg_gap > largest_max_to_avg_gap: 230 | largest_max_to_avg_gap = max_to_avg_gap 231 | 232 | buckets.append((bucket, (max_to_min_gap, max_score, max_to_avg_gap))) 233 | 234 | # sort the buckets 235 | buckets.sort(key=lambda x: x[1], reverse=True) 236 | 237 | # Baseline for the batch is just the average of all runs 238 | all_scores_in_this_batch = [o["score"] for o in outputs] 239 | baseline_score = sum(all_scores_in_this_batch) / len(all_scores_in_this_batch) 240 | logger.info(f"Batch {batch_idx+1}: Baseline mini-batch score: {baseline_score}\n") 241 | 242 | # STEP 4: Build new candidate programs by applying a strategy to some top buckets. 243 | system_candidates = [] 244 | for bucket_idx, (bucket, bucket_stats) in enumerate(buckets): 245 | max_to_min_gap, max_score, max_to_avg_gap = bucket_stats 246 | logger.info( 247 | f"Batch {batch_idx+1}: Processing bucket #{bucket_idx+1}, with max score {max_score}, " 248 | f"max-to-min gap {max_to_min_gap}, and max-to-avg gap {max_to_avg_gap}." 
249 | ) 250 | 251 | # pick source program 252 | src_prog_idx = softmax_sample( 253 | rng, top_k_plus_baseline(self.num_candidates), self.temperature_for_candidates 254 | ) 255 | system_candidate = programs[src_prog_idx].deepcopy() 256 | 257 | # Drop some demos from each predictor 258 | name2predictor = {} 259 | num_demos_list = [] 260 | 261 | max_demos_tmp = self.max_demos if self.max_demos > 0 else 3 262 | 263 | for name, predictor in system_candidate.named_predictors(): 264 | name2predictor[name] = predictor 265 | num_demos_list.append(len(predictor.demos)) 266 | 267 | num_demos = max(num_demos_list) if num_demos_list else 0 268 | num_demos_to_drop = max(rng_np.poisson(num_demos / max_demos_tmp), int(num_demos >= max_demos_tmp)) 269 | num_demos_to_drop = min(num_demos_to_drop, num_demos) 270 | demos_to_drop = [rng.randrange(num_demos) for _ in range(num_demos_to_drop)] 271 | 272 | for _, predictor in name2predictor.items(): 273 | predictor.demos = [demo for idxd, demo in enumerate(predictor.demos) if idxd not in demos_to_drop] 274 | 275 | # Pick a strategy 276 | strategy = rng.choice(self.strategies) 277 | logger.info( 278 | f"Batch {batch_idx+1}: Invoking strategy: {strategy.__name__}" 279 | + (f", having dropped {num_demos_to_drop} demos per predictor" if num_demos_to_drop else "") 280 | ) 281 | 282 | try: 283 | strategy( 284 | bucket, 285 | system_candidate, 286 | predictor2name=predictor2name, 287 | name2predictor=name2predictor, 288 | batch_10p_score=batch_10th_percentile_score, 289 | batch_90p_score=batch_90th_percentile_score, 290 | prompt_model=self.prompt_model, 291 | ) 292 | except Exception as e: 293 | logger.error(f"Strategy failed with error: {e}") 294 | continue 295 | 296 | system_candidates.append(system_candidate) 297 | logger.info("\n") 298 | 299 | if len(system_candidates) >= self.num_candidates + 1: 300 | break 301 | 302 | # STEP 5: Evaluate these new system_candidates on the same mini-batch 303 | logger.info(f"Batch {batch_idx+1}: Evaluating {len(system_candidates)} programs on {self.bsize} examples.") 304 | 305 | exec_pairs = [(wrap_program(sys, self.metric), ex) for sys in system_candidates for ex in batch] 306 | outputs = run_parallel(exec_pairs) 307 | assert len(outputs) == len(exec_pairs) == len(system_candidates) * self.bsize 308 | 309 | # STEP 6: Compute average mini-batch scores for each new candidate 310 | candidate_scores = [] 311 | for idx_cand, _ in enumerate(system_candidates): 312 | start = idx_cand * self.bsize 313 | end = (idx_cand + 1) * self.bsize 314 | sys_scores = [outputs[i]["score"] for i in range(start, end)] 315 | avg_sys_score = sum(sys_scores) / len(sys_scores) 316 | candidate_scores.append(avg_sys_score) 317 | 318 | logger.info( 319 | f"Scores after {batch_idx+1} batches: {candidate_scores}, " 320 | f"Best: {max(candidate_scores) if candidate_scores else 'N/A'}\n" 321 | ) 322 | 323 | # STEP 7: Select the best among these new ones for "winning" record 324 | if candidate_scores: 325 | best_idx_among_candidates = candidate_scores.index(max(candidate_scores)) 326 | best_program = system_candidates[best_idx_among_candidates] 327 | winning_programs.append(best_program.deepcopy()) 328 | 329 | # STEP 8: Register all new candidate systems in our global pool 330 | for idx_cand, cand_sys in enumerate(system_candidates): 331 | start = idx_cand * self.bsize 332 | end = (idx_cand + 1) * self.bsize 333 | sys_scores = [outputs[i]["score"] for i in range(start, end)] 334 | register_new_program(cand_sys, sys_scores) 335 | 336 | M = len(winning_programs) - 1 # 
noqa: N806 337 | N = self.num_candidates + 1 # noqa: N806 338 | if M < 1: 339 | program_idxs = [0] * N 340 | else: 341 | program_idxs = [round(i * M / (N - 1)) for i in range(N)] 342 | 343 | program_idxs = list(dict.fromkeys(program_idxs)) 344 | 345 | candidate_programs = [winning_programs[i].deepcopy() for i in program_idxs] 346 | logger.info(f"VALIDATION: Evaluating {len(candidate_programs)} programs on the full trainset.") 347 | exec_pairs = [(wrap_program(sys, self.metric), ex) for sys in candidate_programs for ex in trainset] 348 | outputs = run_parallel(exec_pairs) 349 | 350 | scores = [] 351 | for idx_prog, _ in enumerate(candidate_programs): 352 | start = idx_prog * len(trainset) 353 | end = (idx_prog + 1) * len(trainset) 354 | sys_scores = [outputs[i]["score"] for i in range(start, end)] 355 | avg_score = sum(sys_scores) / len(sys_scores) if sys_scores else 0.0 356 | scores.append(avg_score) 357 | if idx_prog != 0: 358 | trial_logs[idx_prog - 1]["train_score"] = avg_score 359 | 360 | # Build sorted list of {"score", "program"} dicts 361 | assert len(scores) == len(candidate_programs) 362 | candidate_data = [{"score": s, "program": p} for s, p in zip(scores, candidate_programs, strict=False)] 363 | candidate_data.sort(key=lambda x: x["score"], reverse=True) 364 | 365 | best_idx = scores.index(max(scores)) if scores else 0 366 | best_program = candidate_programs[best_idx].deepcopy() 367 | logger.info( 368 | f"Final trainset scores: {scores}, Best: {max(scores) if scores else 'N/A'} " 369 | f"(at index {best_idx if scores else 'N/A'})\n\n\n" 370 | ) 371 | 372 | # Attach sorted, scored candidates & logs 373 | best_program.candidate_programs = candidate_data 374 | best_program.trial_logs = trial_logs 375 | 376 | return best_program 377 | ``` -------------------------------------------------------------------------------- /dspy/adapters/types/tool.py: -------------------------------------------------------------------------------- ```python 1 | import asyncio 2 | import inspect 3 | from typing import TYPE_CHECKING, Any, Callable, get_origin, get_type_hints 4 | 5 | import pydantic 6 | from jsonschema import ValidationError, validate 7 | from pydantic import BaseModel, TypeAdapter, create_model 8 | 9 | from dspy.adapters.types.base_type import Type 10 | from dspy.dsp.utils.settings import settings 11 | from dspy.utils.callback import with_callbacks 12 | 13 | if TYPE_CHECKING: 14 | import mcp 15 | from langchain.tools import BaseTool 16 | 17 | _TYPE_MAPPING = {"string": str, "integer": int, "number": float, "boolean": bool, "array": list, "object": dict} 18 | 19 | 20 | class Tool(Type): 21 | """Tool class. 22 | 23 | This class is used to simplify the creation of tools for tool calling (function calling) in LLMs. Only supports 24 | functions for now. 25 | """ 26 | 27 | func: Callable 28 | name: str | None = None 29 | desc: str | None = None 30 | args: dict[str, Any] | None = None 31 | arg_types: dict[str, Any] | None = None 32 | arg_desc: dict[str, str] | None = None 33 | has_kwargs: bool = False 34 | 35 | def __init__( 36 | self, 37 | func: Callable, 38 | name: str | None = None, 39 | desc: str | None = None, 40 | args: dict[str, Any] | None = None, 41 | arg_types: dict[str, Any] | None = None, 42 | arg_desc: dict[str, str] | None = None, 43 | ): 44 | """Initialize the Tool class. 45 | 46 | Users can choose to specify the `name`, `desc`, `args`, and `arg_types`, or let the `dspy.Tool` 47 | automatically infer the values from the function. 
For values that are specified by the user, automatic inference 48 | will not be performed on them. 49 | 50 | Args: 51 | func (Callable): The actual function that is being wrapped by the tool. 52 | name (Optional[str], optional): The name of the tool. Defaults to None. 53 | desc (Optional[str], optional): The description of the tool. Defaults to None. 54 | args (Optional[dict[str, Any]], optional): The args and their schema of the tool, represented as a 55 | dictionary from arg name to arg's json schema. Defaults to None. 56 | arg_types (Optional[dict[str, Any]], optional): The argument types of the tool, represented as a dictionary 57 | from arg name to the type of the argument. Defaults to None. 58 | arg_desc (Optional[dict[str, str]], optional): Descriptions for each arg, represented as a 59 | dictionary from arg name to description string. Defaults to None. 60 | 61 | Example: 62 | 63 | ```python 64 | def foo(x: int, y: str = "hello"): 65 | return str(x) + y 66 | 67 | tool = Tool(foo) 68 | print(tool.args) 69 | # Expected output: {'x': {'type': 'integer'}, 'y': {'type': 'string', 'default': 'hello'}} 70 | ``` 71 | """ 72 | super().__init__(func=func, name=name, desc=desc, args=args, arg_types=arg_types, arg_desc=arg_desc) 73 | self._parse_function(func, arg_desc) 74 | 75 | def _parse_function(self, func: Callable, arg_desc: dict[str, str] | None = None): 76 | """Helper method that parses a function to extract the name, description, and args. 77 | 78 | This is a helper function that automatically infers the name, description, and args of the tool from the 79 | provided function. In order to make the inference work, the function must have valid type hints. 80 | """ 81 | annotations_func = func if inspect.isfunction(func) or inspect.ismethod(func) else func.__call__ 82 | name = getattr(func, "__name__", type(func).__name__) 83 | desc = getattr(func, "__doc__", None) or getattr(annotations_func, "__doc__", "") 84 | args = {} 85 | arg_types = {} 86 | 87 | # Use inspect.signature to get all arg names 88 | sig = inspect.signature(annotations_func) 89 | # Get available type hints 90 | available_hints = get_type_hints(annotations_func) 91 | # Build a dictionary of arg name -> type (defaulting to Any when missing) 92 | hints = {param_name: available_hints.get(param_name, Any) for param_name in sig.parameters.keys()} 93 | default_values = {param_name: sig.parameters[param_name].default for param_name in sig.parameters.keys()} 94 | 95 | # Process each argument's type to generate its JSON schema. 
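        # For instance (illustrative), a hint of `int` becomes {"type": "integer"} and `str`
        # becomes {"type": "string"}, while a Pydantic BaseModel subclass is expanded into its
        # full JSON schema with any $ref/$defs entries inlined by _resolve_json_schema_reference.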
96 | for k, v in hints.items(): 97 | arg_types[k] = v 98 | if k == "return": 99 | continue 100 | # Check if the type (or its origin) is a subclass of Pydantic's BaseModel 101 | origin = get_origin(v) or v 102 | if isinstance(origin, type) and issubclass(origin, BaseModel): 103 | # Get json schema, and replace $ref with the actual schema 104 | v_json_schema = _resolve_json_schema_reference(v.model_json_schema()) 105 | args[k] = v_json_schema 106 | else: 107 | args[k] = _resolve_json_schema_reference(TypeAdapter(v).json_schema()) 108 | if default_values[k] is not inspect.Parameter.empty: 109 | args[k]["default"] = default_values[k] 110 | if arg_desc and k in arg_desc: 111 | args[k]["description"] = arg_desc[k] 112 | 113 | self.name = self.name or name 114 | self.desc = self.desc or desc 115 | self.args = self.args if self.args is not None else args 116 | self.arg_types = self.arg_types if self.arg_types is not None else arg_types 117 | self.has_kwargs = any(param.kind == param.VAR_KEYWORD for param in sig.parameters.values()) 118 | 119 | def _validate_and_parse_args(self, **kwargs): 120 | # Validate the args value comply to the json schema. 121 | for k, v in kwargs.items(): 122 | if k not in self.args: 123 | if self.has_kwargs: 124 | continue 125 | else: 126 | raise ValueError(f"Arg {k} is not in the tool's args.") 127 | try: 128 | instance = v.model_dump() if hasattr(v, "model_dump") else v 129 | type_str = self.args[k].get("type") 130 | if type_str is not None and type_str != "Any": 131 | validate(instance=instance, schema=self.args[k]) 132 | except ValidationError as e: 133 | raise ValueError(f"Arg {k} is invalid: {e.message}") 134 | 135 | # Parse the args to the correct type. 136 | parsed_kwargs = {} 137 | for k, v in kwargs.items(): 138 | if k in self.arg_types and self.arg_types[k] != Any: 139 | # Create a pydantic model wrapper with a dummy field `value` to parse the arg to the correct type. 140 | # This is specifically useful for handling nested Pydantic models like `list[list[MyPydanticModel]]` 141 | pydantic_wrapper = create_model("Wrapper", value=(self.arg_types[k], ...)) 142 | parsed = pydantic_wrapper.model_validate({"value": v}) 143 | parsed_kwargs[k] = parsed.value 144 | else: 145 | parsed_kwargs[k] = v 146 | return parsed_kwargs 147 | 148 | def format(self): 149 | return str(self) 150 | 151 | def format_as_litellm_function_call(self): 152 | return { 153 | "type": "function", 154 | "function": { 155 | "name": self.name, 156 | "description": self.desc, 157 | "parameters": { 158 | "type": "object", 159 | "properties": self.args, 160 | "required": list(self.args.keys()), 161 | }, 162 | }, 163 | } 164 | 165 | def _run_async_in_sync(self, coroutine): 166 | try: 167 | loop = asyncio.get_running_loop() 168 | except RuntimeError: 169 | return asyncio.run(coroutine) 170 | 171 | return loop.run_until_complete(coroutine) 172 | 173 | @with_callbacks 174 | def __call__(self, **kwargs): 175 | parsed_kwargs = self._validate_and_parse_args(**kwargs) 176 | result = self.func(**parsed_kwargs) 177 | if asyncio.iscoroutine(result): 178 | if settings.allow_tool_async_sync_conversion: 179 | return self._run_async_in_sync(result) 180 | else: 181 | raise ValueError( 182 | "You are calling `__call__` on an async tool, please use `acall` instead or set " 183 | "`allow_async=True` to run the async tool in sync mode." 
184 | ) 185 | return result 186 | 187 | @with_callbacks 188 | async def acall(self, **kwargs): 189 | parsed_kwargs = self._validate_and_parse_args(**kwargs) 190 | result = self.func(**parsed_kwargs) 191 | if asyncio.iscoroutine(result): 192 | return await result 193 | else: 194 | # We should allow calling a sync tool in the async path. 195 | return result 196 | 197 | @classmethod 198 | def from_mcp_tool(cls, session: "mcp.ClientSession", tool: "mcp.types.Tool") -> "Tool": 199 | """ 200 | Build a DSPy tool from an MCP tool and a ClientSession. 201 | 202 | Args: 203 | session: The MCP session to use. 204 | tool: The MCP tool to convert. 205 | 206 | Returns: 207 | A Tool object. 208 | """ 209 | from dspy.utils.mcp import convert_mcp_tool 210 | 211 | return convert_mcp_tool(session, tool) 212 | 213 | @classmethod 214 | def from_langchain(cls, tool: "BaseTool") -> "Tool": 215 | """ 216 | Build a DSPy tool from a LangChain tool. 217 | 218 | Args: 219 | tool: The LangChain tool to convert. 220 | 221 | Returns: 222 | A Tool object. 223 | 224 | Example: 225 | 226 | ```python 227 | import asyncio 228 | import dspy 229 | from langchain.tools import tool as lc_tool 230 | 231 | @lc_tool 232 | def add(x: int, y: int): 233 | "Add two numbers together." 234 | return x + y 235 | 236 | dspy_tool = dspy.Tool.from_langchain(add) 237 | 238 | async def run_tool(): 239 | return await dspy_tool.acall(x=1, y=2) 240 | 241 | print(asyncio.run(run_tool())) 242 | # 3 243 | ``` 244 | """ 245 | from dspy.utils.langchain_tool import convert_langchain_tool 246 | 247 | return convert_langchain_tool(tool) 248 | 249 | def __repr__(self): 250 | return f"Tool(name={self.name}, desc={self.desc}, args={self.args})" 251 | 252 | def __str__(self): 253 | desc = f", whose description is <desc>{self.desc}</desc>.".replace("\n", " ") if self.desc else "." 254 | arg_desc = f"It takes arguments {self.args}." 255 | return f"{self.name}{desc} {arg_desc}" 256 | 257 | 258 | class ToolCalls(Type): 259 | class ToolCall(Type): 260 | name: str 261 | args: dict[str, Any] 262 | 263 | def format(self): 264 | return { 265 | "type": "function", 266 | "function": { 267 | "name": self.name, 268 | "arguments": self.args, 269 | }, 270 | } 271 | 272 | def execute(self, functions: dict[str, Any] | list[Tool] | None = None) -> Any: 273 | """Execute this individual tool call and return its result. 274 | 275 | Args: 276 | functions: Functions to search for the tool. Can be: 277 | - Dict mapping tool names to functions: {"tool_name": function} 278 | - List of Tool objects: [Tool(function), ...] 279 | - None: Will search in caller's locals and globals (automatic lookup) 280 | 281 | Returns: 282 | The result from executing this tool call. 283 | 284 | Raises: 285 | ValueError: If the tool function cannot be found. 286 | Exception: Any exception raised by the tool function. 287 | """ 288 | func = None 289 | 290 | if functions is None: 291 | # Automatic lookup in caller's globals and locals 292 | frame = inspect.currentframe().f_back 293 | try: 294 | caller_globals = frame.f_globals 295 | caller_locals = frame.f_locals 296 | func = caller_locals.get(self.name) or caller_globals.get(self.name) 297 | finally: 298 | del frame 299 | 300 | elif isinstance(functions, dict): 301 | func = functions.get(self.name) 302 | elif isinstance(functions, list): 303 | for tool in functions: 304 | if tool.name == self.name: 305 | func = tool.func 306 | break 307 | 308 | if func is None: 309 | raise ValueError(f"Tool function '{self.name}' not found. 
Please pass the tool functions to the `execute` method.") 310 | 311 | try: 312 | args = self.args or {} 313 | return func(**args) 314 | except Exception as e: 315 | raise RuntimeError(f"Error executing tool '{self.name}': {e}") from e 316 | 317 | tool_calls: list[ToolCall] 318 | 319 | @classmethod 320 | def from_dict_list(cls, tool_calls_dicts: list[dict[str, Any]]) -> "ToolCalls": 321 | """Convert a list of dictionaries to a ToolCalls instance. 322 | 323 | Args: 324 | dict_list: A list of dictionaries, where each dictionary should have 'name' and 'args' keys. 325 | 326 | Returns: 327 | A ToolCalls instance. 328 | 329 | Example: 330 | 331 | ```python 332 | tool_calls_dict = [ 333 | {"name": "search", "args": {"query": "hello"}}, 334 | {"name": "translate", "args": {"text": "world"}} 335 | ] 336 | tool_calls = ToolCalls.from_dict_list(tool_calls_dict) 337 | ``` 338 | """ 339 | tool_calls = [cls.ToolCall(**item) for item in tool_calls_dicts] 340 | return cls(tool_calls=tool_calls) 341 | 342 | @classmethod 343 | def description(cls) -> str: 344 | return ( 345 | "Tool calls information, including the name of the tools and the arguments to be passed to it. " 346 | "Arguments must be provided in JSON format." 347 | ) 348 | 349 | def format(self) -> list[dict[str, Any]]: 350 | # The tool_call field is compatible with OpenAI's tool calls schema. 351 | return { 352 | "tool_calls": [tool_call.format() for tool_call in self.tool_calls], 353 | } 354 | 355 | @pydantic.model_validator(mode="before") 356 | @classmethod 357 | def validate_input(cls, data: Any): 358 | if isinstance(data, cls): 359 | return data 360 | 361 | # Handle case where data is a list of dicts with "name" and "args" keys 362 | if isinstance(data, list) and all( 363 | isinstance(item, dict) and "name" in item and "args" in item for item in data 364 | ): 365 | return {"tool_calls": [cls.ToolCall(**item) for item in data]} 366 | # Handle case where data is a dict 367 | elif isinstance(data, dict): 368 | if "tool_calls" in data: 369 | # Handle case where data is a dict with "tool_calls" key 370 | tool_calls_data = data["tool_calls"] 371 | if isinstance(tool_calls_data, list): 372 | return { 373 | "tool_calls": [ 374 | cls.ToolCall(**item) if isinstance(item, dict) else item for item in tool_calls_data 375 | ] 376 | } 377 | elif "name" in data and "args" in data: 378 | # Handle case where data is a dict with "name" and "args" keys 379 | return {"tool_calls": [cls.ToolCall(**data)]} 380 | 381 | raise ValueError(f"Received invalid value for `dspy.ToolCalls`: {data}") 382 | 383 | 384 | def _resolve_json_schema_reference(schema: dict) -> dict: 385 | """Recursively resolve json model schema, expanding all references.""" 386 | 387 | # If there are no definitions to resolve, return the main schema 388 | if "$defs" not in schema and "definitions" not in schema: 389 | return schema 390 | 391 | def resolve_refs(obj: Any) -> Any: 392 | if not isinstance(obj, (dict, list)): 393 | return obj 394 | if isinstance(obj, dict): 395 | if "$ref" in obj: 396 | ref_path = obj["$ref"].split("/")[-1] 397 | return resolve_refs(schema["$defs"][ref_path]) 398 | return {k: resolve_refs(v) for k, v in obj.items()} 399 | 400 | # Must be a list 401 | return [resolve_refs(item) for item in obj] 402 | 403 | # Resolve all references in the main schema 404 | resolved_schema = resolve_refs(schema) 405 | # Remove the $defs key as it's no longer needed 406 | resolved_schema.pop("$defs", None) 407 | return resolved_schema 408 | 409 | 410 | def 
convert_input_schema_to_tool_args( 411 | schema: dict[str, Any], 412 | ) -> tuple[dict[str, Any], dict[str, Type], dict[str, str]]: 413 | """Convert an input json schema to tool arguments compatible with DSPy Tool. 414 | 415 | Args: 416 | schema: An input json schema describing the tool's input parameters 417 | 418 | Returns: 419 | A tuple of (args, arg_types, arg_desc) for DSPy Tool definition. 420 | """ 421 | args, arg_types, arg_desc = {}, {}, {} 422 | properties = schema.get("properties", None) 423 | if properties is None: 424 | return args, arg_types, arg_desc 425 | 426 | required = schema.get("required", []) 427 | 428 | defs = schema.get("$defs", {}) 429 | 430 | for name, prop in properties.items(): 431 | if len(defs) > 0: 432 | prop = _resolve_json_schema_reference({"$defs": defs, **prop}) 433 | args[name] = prop 434 | arg_types[name] = _TYPE_MAPPING.get(prop.get("type"), Any) 435 | arg_desc[name] = prop.get("description", "No description provided.") 436 | if name in required: 437 | arg_desc[name] += " (Required)" 438 | 439 | return args, arg_types, arg_desc 440 | ``` -------------------------------------------------------------------------------- /dspy/teleprompt/copro_optimizer.py: -------------------------------------------------------------------------------- ```python 1 | import logging 2 | from collections import defaultdict 3 | 4 | import dspy 5 | from dspy.evaluate.evaluate import Evaluate 6 | from dspy.signatures import Signature 7 | from dspy.teleprompt.teleprompt import Teleprompter 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | """ 12 | USAGE SUGGESTIONS: 13 | 14 | The following code can be used to compile a optimized signature teleprompter, and evaluate it on an end task: 15 | 16 | teleprompter = COPRO(prompt_model=prompt_model, metric=metric, breadth=BREADTH, depth=DEPTH, init_temperature=INIT_TEMPERATURE) 17 | kwargs = dict(num_threads=NUM_THREADS, display_progress=True, display_table=0) 18 | compiled_prompt_opt = teleprompter.compile(program.deepcopy(), trainset=trainset[:DEV_NUM], eval_kwargs=kwargs) 19 | eval_score = evaluate(compiled_prompt_opt, devset=evalset[:EVAL_NUM], **kwargs) 20 | 21 | Note that this teleprompter takes in the following parameters: 22 | 23 | * prompt_model: The model used for prompt generation. When unspecified, defaults to the model set in settings (ie. dspy.settings.configure(lm=task_model)). 24 | * metric: The task metric used for optimization. 25 | * breadth: The number of new prompts to generate at each iteration. Default=10. 26 | * depth: The number of times we should ask our prompt model to generate new prompts, with the history of the past prompts as input. Default=3. 27 | * init_temperature: The temperature used to generate new prompts. Higher roughly equals more creative. Default=1.4. 28 | * track_stats: Tells the method whether or not to track statistics about the optimization process. 29 | If True, the method will track the following statistics: 30 | * results_best: The min,max,avg,stddev of top 10 scores for each predictor at each depth. 31 | * results_latest: The min,max,avg,stddev of newest prompt scores for each predictor at each depth. 32 | * total_calls: The total number of calls to the task metric. 33 | These statistics will be returned as attributes of the best program. 34 | """ 35 | 36 | 37 | class BasicGenerateInstruction(Signature): 38 | """You are an instruction optimizer for large language models. I will give you a ``signature`` of fields (inputs and outputs) in English. 
Your task is to propose an instruction that will lead a good language model to perform the task well. Don't be afraid to be creative.""" 39 | 40 | basic_instruction = dspy.InputField(desc="The initial instructions before optimization") 41 | proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") 42 | proposed_prefix_for_output_field = dspy.OutputField( 43 | desc="The string at the end of the prompt, which will help the model start solving the task", 44 | ) 45 | 46 | 47 | class GenerateInstructionGivenAttempts(dspy.Signature): 48 | """You are an instruction optimizer for large language models. I will give some task instructions I've tried, along with their corresponding validation scores. The instructions are arranged in increasing order based on their scores, where higher scores indicate better quality. 49 | 50 | Your task is to propose a new instruction that will lead a good language model to perform the task even better. Don't be afraid to be creative.""" 51 | 52 | attempted_instructions = dspy.InputField() 53 | proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") 54 | proposed_prefix_for_output_field = dspy.OutputField( 55 | desc="The string at the end of the prompt, which will help the model start solving the task", 56 | ) 57 | 58 | 59 | class COPRO(Teleprompter): 60 | def __init__( 61 | self, 62 | prompt_model=None, 63 | metric=None, 64 | breadth=10, 65 | depth=3, 66 | init_temperature=1.4, 67 | track_stats=False, 68 | **_kwargs, 69 | ): 70 | if breadth <= 1: 71 | raise ValueError("Breadth must be greater than 1") 72 | self.metric = metric 73 | self.breadth = breadth 74 | self.depth = depth 75 | self.init_temperature = init_temperature 76 | self.prompt_model = prompt_model 77 | self.track_stats = track_stats 78 | 79 | def _check_candidates_equal(self, candidate1, candidate2): 80 | for p1, p2 in zip(candidate1["program"].predictors(), candidate2["program"].predictors(), strict=False): 81 | if self._get_signature(p1).instructions != self._get_signature(p2).instructions: 82 | return False 83 | *_, p1_last_field = self._get_signature(p1).fields.values() 84 | *_, p2_last_field = self._get_signature(p2).fields.values() 85 | if p1_last_field != p2_last_field: 86 | return False 87 | return True 88 | 89 | def _drop_duplicates(self, candidates): 90 | final_candidates = [] 91 | last_batch = [] 92 | last_batch_score = -1 93 | for c in candidates: 94 | repeat = False 95 | if c["score"] == last_batch_score: 96 | for c2 in last_batch: 97 | if self._check_candidates_equal(c, c2): 98 | repeat = True 99 | break 100 | if not repeat: 101 | last_batch.append(c) 102 | else: 103 | last_batch = [c] 104 | last_batch_score = c["score"] 105 | if not repeat: 106 | final_candidates.append(c) 107 | return final_candidates 108 | 109 | def _print_signature(self, predictor): 110 | signature = self._get_signature(predictor) 111 | 112 | logger.debug(f"i: {signature.instructions}") 113 | logger.debug(f"p: {list(signature.fields.values())[-1].json_schema_extra['prefix']}") 114 | 115 | def _get_signature(self, predictor): 116 | assert hasattr(predictor, "signature") 117 | return predictor.signature 118 | 119 | def _set_signature(self, predictor, updated_signature): 120 | assert hasattr(predictor, "signature") 121 | predictor.signature = updated_signature 122 | 123 | def compile(self, student, *, trainset, eval_kwargs): 124 | """ 125 | optimizes `signature` of `student` program - note that it may be zero-shot or already pre-optimized 
(demos already chosen - `demos != []`) 126 | 127 | parameters: 128 | student: program to optimize and left modified. 129 | trainset: iterable of `Example`s 130 | eval_kwargs: optional, dict 131 | Additional keywords to go into `Evaluate` for the metric. 132 | 133 | Returns optimized version of `student`. 134 | """ 135 | module = student.deepcopy() 136 | evaluate = Evaluate(devset=trainset, metric=self.metric, **eval_kwargs) 137 | total_calls = 0 138 | results_best = { 139 | id(p): {"depth": [], "max": [], "average": [], "min": [], "std": []} for p in module.predictors() 140 | } 141 | results_latest = { 142 | id(p): {"depth": [], "max": [], "average": [], "min": [], "std": []} for p in module.predictors() 143 | } 144 | 145 | if self.track_stats: 146 | import numpy as np 147 | 148 | candidates = {} 149 | evaluated_candidates = defaultdict(dict) 150 | 151 | # Seed the prompt optimizer zero shot with just the instruction, generate BREADTH new prompts 152 | for predictor in module.predictors(): 153 | basic_instruction = None 154 | basic_prefix = None 155 | *_, last_key = self._get_signature(predictor).fields.keys() 156 | basic_instruction = self._get_signature(predictor).instructions 157 | basic_prefix = self._get_signature(predictor).fields[last_key].json_schema_extra["prefix"] 158 | if self.prompt_model: 159 | with dspy.settings.context(lm=self.prompt_model): 160 | instruct = dspy.Predict( 161 | BasicGenerateInstruction, 162 | n=self.breadth - 1, 163 | temperature=self.init_temperature, 164 | )(basic_instruction=basic_instruction) 165 | else: 166 | instruct = dspy.Predict( 167 | BasicGenerateInstruction, 168 | n=self.breadth - 1, 169 | temperature=self.init_temperature, 170 | )(basic_instruction=basic_instruction) 171 | # Add in our initial prompt as a candidate as well 172 | instruct.completions.proposed_instruction.append(basic_instruction) 173 | instruct.completions.proposed_prefix_for_output_field.append(basic_prefix) 174 | candidates[id(predictor)] = instruct.completions 175 | evaluated_candidates[id(predictor)] = {} 176 | 177 | if self.prompt_model: 178 | logger.debug(f"{self.prompt_model.inspect_history(n=1)}") 179 | 180 | latest_candidates = candidates 181 | all_candidates = candidates 182 | 183 | module_clone = module.deepcopy() 184 | 185 | # For each iteration in depth... 186 | for d in range( 187 | self.depth, 188 | ): # TODO: fix this so that we eval the new batch of predictors with the new best following predictors 189 | logger.info(f"Iteration Depth: {d+1}/{self.depth}.") 190 | 191 | latest_scores = [] 192 | 193 | # Go through our module's predictors 194 | for p_i, (p_old, p_new) in enumerate(zip(module.predictors(), module_clone.predictors(), strict=False)): 195 | candidates_ = latest_candidates[id(p_old)] # Use the most recently generated candidates for evaluation 196 | if len(module.predictors()) > 1: 197 | # Unless our program has multiple predictors, in which case we need to reevaluate all prompts with 198 | # the new prompt(s) for the other predictor(s). 
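                    # The pool in `all_candidates` grows by `breadth` new prompts after each depth
                    # iteration, so this branch re-scores an increasingly large set of prompts
                    # against the latest best prompts chosen for the other predictors.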
199 | candidates_ = all_candidates[ 200 | id(p_old) 201 | ] 202 | 203 | # For each candidate 204 | for c_i, c in enumerate(candidates_): 205 | # Get the candidate instruction and prefix 206 | instruction, prefix = ( 207 | c.proposed_instruction.strip('"').strip(), 208 | c.proposed_prefix_for_output_field.strip('"').strip(), 209 | ) 210 | 211 | # Set this new module with our instruction / prefix 212 | *_, last_key = self._get_signature(p_new).fields.keys() 213 | updated_signature = ( 214 | self._get_signature(p_new) 215 | .with_instructions(instruction) 216 | .with_updated_fields(last_key, prefix=prefix) 217 | ) 218 | self._set_signature(p_new, updated_signature) 219 | 220 | # Score the instruction / prefix 221 | for i, predictor in enumerate(module_clone.predictors()): 222 | logger.debug(f"Predictor {i+1}") 223 | self._print_signature(predictor) 224 | logger.info( 225 | f"At Depth {d+1}/{self.depth}, Evaluating Prompt Candidate #{c_i+1}/{len(candidates_)} for " 226 | f"Predictor {p_i+1} of {len(module.predictors())}.", 227 | ) 228 | score = evaluate(module_clone, devset=trainset, **eval_kwargs).score 229 | if self.prompt_model: 230 | logger.debug(f"prompt_model.inspect_history(n=1) {self.prompt_model.inspect_history(n=1)}") 231 | total_calls += 1 232 | 233 | replace_entry = True 234 | logger.debug(f"(instruction, prefix) {(instruction, prefix)}") 235 | if (instruction, prefix) in evaluated_candidates[id(p_old)]: 236 | if evaluated_candidates[id(p_old)][(instruction, prefix)]["score"] >= score: 237 | replace_entry = False 238 | 239 | if replace_entry: 240 | # Add it to our evaluated candidates list 241 | evaluated_candidates[id(p_old)][(instruction, prefix)] = { 242 | "score": score, 243 | "program": module_clone.deepcopy(), 244 | "instruction": instruction, 245 | "prefix": prefix, 246 | "depth": d, 247 | } 248 | 249 | if len(candidates_) - self.breadth <= c_i: 250 | latest_scores.append(score) 251 | 252 | if self.track_stats: 253 | results_latest[id(p_old)]["depth"].append(d) 254 | results_latest[id(p_old)]["max"].append(max(latest_scores)) 255 | results_latest[id(p_old)]["average"].append(sum(latest_scores) / len(latest_scores)) 256 | results_latest[id(p_old)]["min"].append(min(latest_scores)) 257 | results_latest[id(p_old)]["std"].append(np.std(latest_scores)) 258 | 259 | # Now that we've evaluated the candidates, set this predictor to the best performing version 260 | # to ensure the next round of scores reflect the best possible version 261 | best_candidate = max(evaluated_candidates[id(p_old)].values(), key=lambda candidate: candidate["score"]) 262 | *_, last_key = self._get_signature(p_old).fields.keys() 263 | updated_signature = ( 264 | self._get_signature(p_new) 265 | .with_instructions(best_candidate["instruction"]) 266 | .with_updated_fields(last_key, prefix=best_candidate["prefix"]) 267 | ) 268 | self._set_signature(p_new, updated_signature) 269 | 270 | logger.debug( 271 | f"Updating Predictor {id(p_old)} to:\ni: {best_candidate['instruction']}\n" 272 | f"p: {best_candidate['prefix']}", 273 | ) 274 | logger.debug("Full predictor with update: ") 275 | for i, predictor in enumerate(module_clone.predictors()): 276 | logger.debug(f"Predictor {i}") 277 | self._print_signature(predictor) 278 | 279 | if d == self.depth - 1: 280 | break 281 | 282 | new_candidates = {} 283 | for p_base in module.predictors(): 284 | # Build Few-Shot Example of Optimized Prompts 285 | attempts = [] 286 | shortest_len = self.breadth 287 | shortest_len = min(len(evaluated_candidates[id(p_base)]), 
shortest_len) 288 | best_predictors = list(evaluated_candidates[id(p_base)].values()) 289 | 290 | # best_predictors = evaluated_candidates[id(p_base)].values()[:] 291 | best_predictors.sort(key=lambda x: x["score"], reverse=True) 292 | 293 | if self.track_stats: 294 | scores = [x["score"] for x in best_predictors][:10] 295 | results_best[id(p_base)]["depth"].append(d) 296 | results_best[id(p_base)]["max"].append(max(scores)) 297 | results_best[id(p_base)]["average"].append(sum(scores) / len(scores)) 298 | results_best[id(p_base)]["min"].append(min(scores)) 299 | results_best[id(p_base)]["std"].append(np.std(scores)) 300 | 301 | for i in range(shortest_len - 1, -1, -1): 302 | # breakpoint() 303 | attempts.append(f'Instruction #{shortest_len-i}: {best_predictors[i]["instruction"]}') 304 | attempts.append(f'Prefix #{shortest_len-i}: {best_predictors[i]["prefix"]}') 305 | attempts.append(f'Resulting Score #{shortest_len-i}: {best_predictors[i]["score"]}') 306 | 307 | # Generate next batch of potential prompts to optimize, with previous attempts as input 308 | if self.prompt_model: 309 | with dspy.settings.context(lm=self.prompt_model): 310 | instr = dspy.Predict( 311 | GenerateInstructionGivenAttempts, 312 | n=self.breadth, 313 | temperature=self.init_temperature, 314 | )(attempted_instructions=attempts) 315 | else: 316 | instr = dspy.Predict( 317 | GenerateInstructionGivenAttempts, 318 | n=self.breadth, 319 | temperature=self.init_temperature, 320 | )(attempted_instructions=attempts) 321 | 322 | # Get candidates for each predictor 323 | new_candidates[id(p_base)] = instr.completions 324 | all_candidates[id(p_base)].proposed_instruction.extend(instr.completions.proposed_instruction) 325 | all_candidates[id(p_base)].proposed_prefix_for_output_field.extend( 326 | instr.completions.proposed_prefix_for_output_field, 327 | ) 328 | 329 | latest_candidates = new_candidates 330 | 331 | candidates = [] 332 | for predictor in module.predictors(): 333 | candidates.extend(list(evaluated_candidates[id(predictor)].values())) 334 | 335 | if self.track_stats: 336 | best_predictors = list(evaluated_candidates[id(predictor)].values()) 337 | best_predictors.sort(key=lambda x: x["score"], reverse=True) 338 | 339 | scores = [x["score"] for x in best_predictors][:10] 340 | results_best[id(predictor)]["depth"].append(d) 341 | results_best[id(predictor)]["max"].append(max(scores)) 342 | results_best[id(predictor)]["average"].append(sum(scores) / len(scores)) 343 | results_best[id(predictor)]["min"].append(min(scores)) 344 | results_best[id(predictor)]["std"].append(np.std(scores)) 345 | 346 | candidates.sort(key=lambda x: x["score"], reverse=True) 347 | 348 | candidates = self._drop_duplicates(candidates) 349 | 350 | best_program = candidates[0]["program"] 351 | best_program.candidate_programs = candidates 352 | best_program.total_calls = total_calls 353 | if self.track_stats: 354 | best_program.results_best = results_best 355 | best_program.results_latest = results_latest 356 | 357 | return best_program 358 | ``` -------------------------------------------------------------------------------- /dspy/propose/grounded_proposer.py: -------------------------------------------------------------------------------- ```python 1 | import random 2 | 3 | import dspy 4 | from dspy.propose.dataset_summary_generator import create_dataset_summary 5 | from dspy.propose.propose_base import Proposer 6 | from dspy.propose.utils import ( 7 | create_example_string, 8 | create_predictor_level_history_string, 9 | get_dspy_source_code, 10 | 
strip_prefix, 11 | ) 12 | from dspy.teleprompt.utils import get_prompt_model, get_signature 13 | 14 | # Hardcoded variables (TODO: update) 15 | MAX_INSTRUCT_IN_HISTORY = 5 # 10 16 | 17 | TIPS = { 18 | "none": "", 19 | "creative": "Don't be afraid to be creative when creating the new instruction!", 20 | "simple": "Keep the instruction clear and concise.", 21 | "description": "Make sure your instruction is very informative and descriptive.", 22 | "high_stakes": "The instruction should include a high stakes scenario in which the LM must solve the task!", 23 | "persona": 'Include a persona that is relevant to the task in the instruction (ie. "You are a ...")', 24 | } 25 | 26 | ### SIGNATURES USED TO HELP WITH INSTRUCTION GENERATION ### 27 | 28 | class DescribeProgram(dspy.Signature): 29 | ( 30 | """Below is some pseudo-code for a pipeline that solves tasks with calls to language models. Please describe what type of task this program appears to be designed to solve, and how it appears to work.""" 31 | ) 32 | program_code = dspy.InputField( 33 | format=str, 34 | desc="Pseudocode for a language model program designed to solve a particular task.", 35 | prefix="PROGRAM CODE:", 36 | ) 37 | program_example = dspy.InputField( 38 | format=str, 39 | desc="An example of the program in use.", 40 | prefix="EXAMPLE OF PROGRAM IN USE:", 41 | ) 42 | program_description = dspy.OutputField( 43 | desc="Describe what task the program is designed to solve, and how it goes about solving this task.", 44 | prefix="SUMMARY OF PROGRAM ABOVE:", 45 | ) 46 | 47 | 48 | class DescribeModule(dspy.Signature): 49 | ( 50 | """Below is some pseudo-code for a pipeline that solves tasks with calls to language models. Please describe the purpose of one of the specified module in this pipeline.""" 51 | ) 52 | program_code = dspy.InputField( 53 | format=str, 54 | desc="Pseudocode for a language model program designed to solve a particular task.", 55 | prefix="PROGRAM CODE:", 56 | ) 57 | program_example = dspy.InputField( 58 | format=str, 59 | desc="An example of the program in use.", 60 | prefix="EXAMPLE OF PROGRAM IN USE:", 61 | ) 62 | program_description = dspy.InputField( 63 | desc="Summary of the task the program is designed to solve, and how it goes about solving it.", 64 | prefix="SUMMARY OF PROGRAM ABOVE:", 65 | ) 66 | module = dspy.InputField( 67 | desc="The module in the program that we want to describe.", prefix="MODULE:", 68 | ) 69 | module_description = dspy.OutputField( 70 | desc="Description of the module's role in the broader program.", 71 | prefix="MODULE DESCRIPTION:", 72 | ) 73 | 74 | 75 | def generate_instruction_class( 76 | use_dataset_summary=True, 77 | program_aware=True, 78 | use_task_demos=True, 79 | use_instruct_history=True, 80 | use_tip=True, 81 | ): 82 | class GenerateSingleModuleInstruction(dspy.Signature): 83 | ( 84 | """Use the information below to learn about a task that we are trying to solve using calls to an LM, then generate a new instruction that will be used to prompt a Language Model to better solve the task.""" 85 | ) 86 | if use_dataset_summary: 87 | dataset_description = dspy.InputField( 88 | desc="A description of the dataset that we are using.", 89 | prefix="DATASET SUMMARY:", 90 | ) 91 | if program_aware: 92 | program_code = dspy.InputField( 93 | format=str, 94 | desc="Language model program designed to solve a particular task.", 95 | prefix="PROGRAM CODE:", 96 | ) 97 | program_description = dspy.InputField( 98 | desc="Summary of the task the program is designed to solve, and how it goes 
about solving it.", 99 | prefix="PROGRAM DESCRIPTION:", 100 | ) 101 | module = dspy.InputField( 102 | desc="The module to create an instruction for.", prefix="MODULE:", 103 | ) 104 | module_description = dspy.InputField( 105 | desc="Description of the module to create an instruction for.", prefix="MODULE DESCRIPTION:", 106 | ) 107 | task_demos = dspy.InputField( 108 | format=str, 109 | desc="Example inputs/outputs of our module.", 110 | prefix="TASK DEMO(S):", 111 | ) 112 | if use_instruct_history: 113 | previous_instructions = dspy.InputField( 114 | format=str, 115 | desc="Previous instructions we've attempted, along with their associated scores.", 116 | prefix="PREVIOUS INSTRUCTIONS:", 117 | ) 118 | basic_instruction = dspy.InputField( 119 | format=str, desc="Basic instruction.", prefix="BASIC INSTRUCTION:", 120 | ) 121 | if use_tip: 122 | tip = dspy.InputField( 123 | format=str, 124 | desc="A suggestion for how to go about generating the new instruction.", 125 | prefix="TIP:", 126 | ) 127 | proposed_instruction = dspy.OutputField( 128 | desc="Propose an instruction that will be used to prompt a Language Model to perform this task.", 129 | prefix="PROPOSED INSTRUCTION:", 130 | ) 131 | 132 | return dspy.Predict(GenerateSingleModuleInstruction) 133 | 134 | ### CLASS RESPONSIBLE FOR GENERATING A NEW INSTRUCTION, USING THE HELPER SIGNATURES ABOVE ### 135 | 136 | class GenerateModuleInstruction(dspy.Module): 137 | def __init__( 138 | self, 139 | program_code_string=None, 140 | use_dataset_summary=True, 141 | program_aware=False, 142 | use_task_demos=True, 143 | use_instruct_history=True, 144 | use_tip=True, 145 | verbose=False, 146 | ): 147 | super().__init__() 148 | self.use_dataset_summary = use_dataset_summary 149 | self.program_aware = program_aware 150 | self.use_task_demos = use_task_demos 151 | self.use_instruct_history = use_instruct_history 152 | self.use_tip = use_tip 153 | self.verbose = verbose 154 | 155 | self.program_code_string = program_code_string 156 | self.describe_program = dspy.Predict(DescribeProgram) 157 | self.describe_module = dspy.Predict(DescribeModule) 158 | self.generate_module_instruction = generate_instruction_class( 159 | use_dataset_summary=use_dataset_summary, 160 | program_aware=program_aware, 161 | use_task_demos=use_task_demos, 162 | use_instruct_history=use_instruct_history, 163 | use_tip=use_tip, 164 | ) 165 | 166 | def forward( 167 | self, 168 | demo_candidates, 169 | pred_i, 170 | demo_set_i, 171 | program, 172 | previous_instructions, 173 | data_summary, 174 | num_demos_in_context=3, 175 | tip=None, 176 | ): 177 | def gather_examples_from_sets(candidate_sets, max_examples): 178 | """Helper function to gather up to augmented examples from given sets.""" 179 | count = 0 180 | for candidate_set in candidate_sets: 181 | for example in candidate_set: 182 | if "augmented" in example.keys(): 183 | fields_to_use = get_signature(program.predictors()[pred_i]).fields 184 | yield create_example_string(fields_to_use, example) 185 | count += 1 186 | if count >= max_examples: 187 | return 188 | 189 | # Construct full program demo or single module demo depending on settings 190 | basic_instruction = get_signature(program.predictors()[pred_i]).instructions 191 | task_demos = "" 192 | 193 | if self.use_task_demos: 194 | # Combine current and adjacent sets 195 | adjacent_sets = ( 196 | [demo_candidates[pred_i][demo_set_i]] + 197 | demo_candidates[pred_i][demo_set_i + 1:] + 198 | demo_candidates[pred_i][:demo_set_i] 199 | ) 200 | 201 | # Gather examples up to the 
required count 202 | example_strings = gather_examples_from_sets(adjacent_sets, num_demos_in_context) 203 | task_demos = "\n\n".join(example_strings) + "\n\n" 204 | 205 | # Default to no demos provided if no examples were gathered, or if we're using the first demo set 206 | if not task_demos.strip() or demo_set_i == 0: 207 | task_demos = "No task demos provided." 208 | 209 | # Summarize the program 210 | program_description = "Not available" 211 | module_code = "Not provided" 212 | module_description = "Not provided" 213 | if self.program_aware: 214 | try: 215 | program_description = strip_prefix( 216 | self.describe_program( 217 | program_code=self.program_code_string, program_example=task_demos, 218 | ).program_description, 219 | ) 220 | if self.verbose: 221 | print(f"PROGRAM DESCRIPTION: {program_description}") 222 | 223 | inputs = [] 224 | outputs = [] 225 | for field_name, field in get_signature(program.predictors()[pred_i]).fields.items(): 226 | # Access the '__dspy_field_type' from the extra metadata 227 | dspy_field_type = field.json_schema_extra.get("__dspy_field_type") 228 | 229 | # Based on the '__dspy_field_type', append to the respective list 230 | if dspy_field_type == "input": 231 | inputs.append(field_name) 232 | else: 233 | outputs.append(field_name) 234 | 235 | module_code = f"{program.predictors()[pred_i].__class__.__name__}({', '.join(inputs)}) -> {', '.join(outputs)}" 236 | 237 | module_description = self.describe_module( 238 | program_code=self.program_code_string, 239 | program_description=program_description, 240 | program_example=task_demos, 241 | module=module_code, 242 | max_depth=10, 243 | ).module_description 244 | except Exception as e: 245 | if self.verbose: 246 | print(f"Error getting program description. Running without program aware proposer. 
Error: {e}") 247 | self.program_aware = False 248 | 249 | # Generate an instruction for our chosen module 250 | if self.verbose: 251 | print(f"task_demos {task_demos}") 252 | 253 | instruct = self.generate_module_instruction( 254 | dataset_description=data_summary, 255 | program_code=self.program_code_string, 256 | module=module_code, 257 | program_description=program_description, 258 | module_description=module_description, 259 | task_demos=task_demos, 260 | tip=tip, 261 | basic_instruction=basic_instruction, 262 | previous_instructions=previous_instructions, 263 | ) 264 | 265 | proposed_instruction = strip_prefix(instruct.proposed_instruction) 266 | 267 | return dspy.Prediction(proposed_instruction=proposed_instruction) 268 | 269 | ### CLASS USED TO GENERATE THE FULL SET OF INSTRUCTIONS GIVEN THE SPECIFIED CRITERIA ### 270 | 271 | class GroundedProposer(Proposer): 272 | def __init__( 273 | self, 274 | prompt_model, 275 | program, 276 | trainset, 277 | view_data_batch_size=10, 278 | use_dataset_summary=True, 279 | program_aware=True, 280 | use_task_demos=True, 281 | num_demos_in_context = 3, 282 | use_instruct_history=True, 283 | use_tip=True, 284 | set_tip_randomly=True, 285 | set_history_randomly=True, 286 | verbose=False, 287 | rng=None, 288 | init_temperature: float = 1.0, 289 | ): 290 | super().__init__() 291 | self.program_aware = program_aware 292 | self.use_dataset_summary = use_dataset_summary 293 | self.use_task_demos = use_task_demos 294 | self.num_demos_in_context = num_demos_in_context 295 | self.use_instruct_history = use_instruct_history 296 | self.use_tip = use_tip 297 | self.set_tip_randomly=set_tip_randomly 298 | self.set_history_randomly=set_history_randomly 299 | self.verbose = verbose 300 | self.rng = rng or random 301 | 302 | self.prompt_model = get_prompt_model(prompt_model) 303 | self.init_temperature = init_temperature 304 | 305 | self.program_code_string = None 306 | if self.program_aware: 307 | try: 308 | self.program_code_string = get_dspy_source_code(program) 309 | if self.verbose: 310 | print("SOURCE CODE:",self.program_code_string) 311 | except Exception as e: 312 | print(f"Error getting source code: {e}.\n\nRunning without program aware proposer.") 313 | self.program_aware = False 314 | 315 | self.data_summary = None 316 | if self.use_dataset_summary: 317 | try: 318 | self.data_summary = create_dataset_summary( 319 | trainset=trainset, view_data_batch_size=view_data_batch_size, prompt_model=prompt_model, 320 | ) 321 | if self.verbose: 322 | print(f"DATA SUMMARY: {self.data_summary}") 323 | except Exception as e: 324 | print(f"Error getting data summary: {e}.\n\nRunning without data aware proposer.") 325 | self.use_dataset_summary = False 326 | print("") 327 | 328 | def propose_instructions_for_program( 329 | self, 330 | trainset, 331 | program, 332 | demo_candidates, 333 | trial_logs, 334 | N, # noqa: N803 335 | ) -> list[str]: 336 | """This method is responsible for returning the full set of new instructions for our program, given the specified criteria.""" 337 | 338 | proposed_instructions = {} 339 | 340 | if self.set_history_randomly: 341 | # Randomly select whether or not we're using instruction history 342 | use_history = self.rng.random() < 0.5 343 | self.use_instruct_history = use_history 344 | if self.verbose: 345 | print(f"Use history T/F: {self.use_instruct_history}") 346 | 347 | if not demo_candidates: 348 | if self.verbose: 349 | print("No demo candidates provided. 
Running without task demos.") 350 | self.use_task_demos = False 351 | # When no demo candidates are provided, default to N 352 | num_demos = N 353 | else: 354 | num_demos = max(len(demo_candidates[0]), 1) 355 | 356 | # Create an instruction for each predictor 357 | for pred_i, predictor in enumerate(program.predictors()): 358 | for demo_set_i in range(num_demos)[:min(N, num_demos)]: 359 | if pred_i not in proposed_instructions: 360 | proposed_instructions[pred_i] = [] 361 | selected_tip = None 362 | if self.set_tip_randomly: 363 | if self.verbose: 364 | print("Using a randomly generated configuration for our grounded proposer.") 365 | # Randomly select the tip 366 | selected_tip_key = self.rng.choice(list(TIPS.keys())) 367 | selected_tip = TIPS[selected_tip_key] 368 | self.use_tip = bool( 369 | selected_tip, 370 | ) 371 | if self.verbose: 372 | print(f"Selected tip: {selected_tip_key}") 373 | 374 | proposed_instructions[pred_i].append( 375 | self.propose_instruction_for_predictor( 376 | program=program, 377 | predictor=predictor, 378 | pred_i=pred_i, 379 | demo_candidates=demo_candidates, 380 | demo_set_i=demo_set_i, 381 | trial_logs=trial_logs, 382 | tip=selected_tip, 383 | ), 384 | ) 385 | 386 | return proposed_instructions 387 | 388 | def propose_instruction_for_predictor( 389 | self, 390 | program, 391 | predictor, 392 | pred_i, 393 | demo_candidates, 394 | demo_set_i, 395 | trial_logs, 396 | tip=None, 397 | ) -> str: 398 | """This method is responsible for returning a single instruction for a given predictor, using the specified criteria.""" 399 | 400 | # Create an instruction history string for our predictor 401 | instruction_history = create_predictor_level_history_string( 402 | program, pred_i, trial_logs, MAX_INSTRUCT_IN_HISTORY, 403 | ) 404 | 405 | # Create our instruction generator class (given specific criteria for this round of proposal) 406 | instruction_generator = GenerateModuleInstruction( 407 | program_code_string=self.program_code_string, 408 | use_dataset_summary=self.use_dataset_summary, 409 | program_aware=self.program_aware, 410 | use_task_demos=self.use_task_demos and demo_candidates, 411 | use_instruct_history=self.use_instruct_history and instruction_history, 412 | use_tip=self.use_tip, 413 | verbose=self.verbose 414 | ) 415 | 416 | # Generate a new instruction for our predictor using a unique rollout id to bypass cache 417 | rollout_lm = self.prompt_model.copy( 418 | rollout_id=self.rng.randint(0, 10**9), 419 | temperature=self.init_temperature, 420 | ) 421 | 422 | with dspy.settings.context(lm=rollout_lm): 423 | proposed_instruction = instruction_generator( 424 | demo_candidates=demo_candidates, 425 | pred_i=pred_i, 426 | demo_set_i=demo_set_i, 427 | program=program, 428 | data_summary=self.data_summary, 429 | previous_instructions=instruction_history, 430 | num_demos_in_context = self.num_demos_in_context, 431 | tip=tip, 432 | ).proposed_instruction 433 | 434 | # Log the trace used to generate the new instruction, along with the new instruction itself 435 | if self.verbose: 436 | self.prompt_model.inspect_history(n=1) 437 | print(f"PROPOSED INSTRUCTION: {proposed_instruction}") 438 | 439 | return strip_prefix(proposed_instruction) 440 | ``` -------------------------------------------------------------------------------- /tests/adapters/test_tool.py: -------------------------------------------------------------------------------- ```python 1 | import asyncio 2 | from typing import Any 3 | 4 | import pytest 5 | from pydantic import BaseModel 6 | 7 | import dspy 8 
| from dspy.adapters.types.tool import Tool, ToolCalls, convert_input_schema_to_tool_args 9 | 10 | 11 | # Test fixtures 12 | def dummy_function(x: int, y: str = "hello") -> str: 13 | """A dummy function for testing. 14 | 15 | Args: 16 | x: An integer parameter 17 | y: A string parameter 18 | """ 19 | return f"{y} {x}" 20 | 21 | 22 | class DummyModel(BaseModel): 23 | field1: str = "hello" 24 | field2: int 25 | 26 | 27 | def dummy_with_pydantic(model: DummyModel) -> str: 28 | """A dummy function that accepts a Pydantic model.""" 29 | return f"{model.field1} {model.field2}" 30 | 31 | 32 | class Address(BaseModel): 33 | street: str 34 | city: str 35 | zip_code: str 36 | is_primary: bool = False 37 | 38 | 39 | class ContactInfo(BaseModel): 40 | email: str 41 | phone: str | None = None 42 | addresses: list[Address] 43 | 44 | 45 | class UserProfile(BaseModel): 46 | user_id: int 47 | name: str 48 | age: int | None = None 49 | contact: ContactInfo 50 | tags: list[str] = [] 51 | 52 | 53 | class Note(BaseModel): 54 | content: str 55 | author: str 56 | 57 | 58 | def complex_dummy_function(profile: UserProfile, priority: int, notes: list[Note] | None = None) -> dict[str, Any]: 59 | """Process user profile with complex nested structure. 60 | 61 | Args: 62 | profile: User profile containing nested contact and address information 63 | priority: Priority level of the processing 64 | notes: Optional processing notes 65 | """ 66 | primary_address = next( 67 | (addr for addr in profile.contact.addresses if addr.is_primary), profile.contact.addresses[0] 68 | ) 69 | 70 | return { 71 | "user_id": profile.user_id, 72 | "name": profile.name, 73 | "priority": priority, 74 | "primary_address": primary_address.model_dump(), 75 | "notes": notes, 76 | } 77 | 78 | 79 | async def async_dummy_function(x: int, y: str = "hello") -> str: 80 | """An async dummy function for testing. 81 | 82 | Args: 83 | x: An integer parameter 84 | y: A string parameter 85 | """ 86 | await asyncio.sleep(0.1) # Simulate some async work 87 | return f"{y} {x}" 88 | 89 | 90 | async def async_dummy_with_pydantic(model: DummyModel) -> str: 91 | """An async dummy function that accepts a Pydantic model.""" 92 | await asyncio.sleep(0.1) # Simulate some async work 93 | return f"{model.field1} {model.field2}" 94 | 95 | 96 | async def async_complex_dummy_function( 97 | profile: UserProfile, 98 | priority: int, 99 | notes: list[Note] | None = None, 100 | ) -> dict[str, Any]: 101 | """Process user profile with complex nested structure asynchronously. 
102 | 103 | Args: 104 | profile: User profile containing nested contact and address information 105 | priority: Priority level of the processing 106 | notes: Optional processing notes 107 | """ 108 | # Simulate some async processing work 109 | await asyncio.sleep(0.1) 110 | 111 | primary_address = next( 112 | (addr for addr in profile.contact.addresses if addr.is_primary), profile.contact.addresses[0] 113 | ) 114 | 115 | # Simulate more async work after finding primary address 116 | await asyncio.sleep(0.1) 117 | 118 | return { 119 | "user_id": profile.user_id, 120 | "name": profile.name, 121 | "priority": priority, 122 | "primary_address": primary_address.model_dump(), 123 | "notes": notes, 124 | } 125 | 126 | 127 | def test_basic_initialization(): 128 | tool = Tool(name="test_tool", desc="A test tool", args={"param1": {"type": "string"}}, func=lambda x: x) 129 | assert tool.name == "test_tool" 130 | assert tool.desc == "A test tool" 131 | assert tool.args == {"param1": {"type": "string"}} 132 | assert callable(tool.func) 133 | 134 | 135 | def test_tool_from_function(): 136 | tool = Tool(dummy_function) 137 | 138 | assert tool.name == "dummy_function" 139 | assert "A dummy function for testing" in tool.desc 140 | assert "x" in tool.args 141 | assert "y" in tool.args 142 | assert tool.args["x"]["type"] == "integer" 143 | assert tool.args["y"]["type"] == "string" 144 | assert tool.args["y"]["default"] == "hello" 145 | 146 | 147 | def test_tool_from_class(): 148 | class Foo: 149 | def __init__(self, user_id: str): 150 | self.user_id = user_id 151 | 152 | def __call__(self, a: int, b: int) -> int: 153 | """Add two numbers.""" 154 | return a + b 155 | 156 | tool = Tool(Foo("123")) 157 | assert tool.name == "Foo" 158 | assert tool.desc == "Add two numbers." 
159 | assert tool.args == {"a": {"type": "integer"}, "b": {"type": "integer"}} 160 | 161 | 162 | def test_tool_from_function_with_pydantic(): 163 | tool = Tool(dummy_with_pydantic) 164 | 165 | assert tool.name == "dummy_with_pydantic" 166 | assert "model" in tool.args 167 | assert tool.args["model"]["type"] == "object" 168 | assert "field1" in tool.args["model"]["properties"] 169 | assert "field2" in tool.args["model"]["properties"] 170 | assert tool.args["model"]["properties"]["field1"]["default"] == "hello" 171 | 172 | 173 | def test_tool_from_function_with_pydantic_nesting(): 174 | tool = Tool(complex_dummy_function) 175 | 176 | assert tool.name == "complex_dummy_function" 177 | 178 | assert "profile" in tool.args 179 | assert "priority" in tool.args 180 | assert "notes" in tool.args 181 | assert tool.args["profile"]["type"] == "object" 182 | assert tool.args["profile"]["properties"]["user_id"]["type"] == "integer" 183 | assert tool.args["profile"]["properties"]["name"]["type"] == "string" 184 | assert tool.args["profile"]["properties"]["age"]["anyOf"] == [{"type": "integer"}, {"type": "null"}] 185 | assert tool.args["profile"]["properties"]["contact"]["type"] == "object" 186 | assert tool.args["profile"]["properties"]["contact"]["properties"]["email"]["type"] == "string" 187 | 188 | # Reference should be resolved for nested pydantic models 189 | assert "$defs" not in str(tool.args["notes"]) 190 | assert tool.args["notes"]["anyOf"][0]["type"] == "array" 191 | assert tool.args["notes"]["anyOf"][0]["items"]["type"] == "object" 192 | assert tool.args["notes"]["anyOf"][0]["items"]["properties"]["content"]["type"] == "string" 193 | assert tool.args["notes"]["anyOf"][0]["items"]["properties"]["author"]["type"] == "string" 194 | 195 | 196 | def test_tool_callable(): 197 | tool = Tool(dummy_function) 198 | result = tool(x=42, y="hello") 199 | assert result == "hello 42" 200 | 201 | 202 | def test_tool_with_pydantic_callable(): 203 | tool = Tool(dummy_with_pydantic) 204 | model = DummyModel(field1="test", field2=123) 205 | result = tool(model=model) 206 | assert result == "test 123" 207 | 208 | 209 | def test_invalid_function_call(): 210 | tool = Tool(dummy_function) 211 | with pytest.raises(ValueError): 212 | tool(x="not an integer", y="hello") 213 | 214 | 215 | def test_parameter_desc(): 216 | tool = Tool(dummy_function, arg_desc={"x": "The x parameter"}) 217 | assert tool.args["x"]["description"] == "The x parameter" 218 | 219 | 220 | def test_tool_with_default_args_without_type_hints(): 221 | def foo(x=100): 222 | return x 223 | 224 | tool = Tool(foo) 225 | assert tool.args["x"]["default"] == 100 226 | assert not hasattr(tool.args["x"], "type") 227 | 228 | 229 | def test_tool_call_parses_args(): 230 | tool = Tool(dummy_with_pydantic) 231 | 232 | args = { 233 | "model": { 234 | "field1": "hello", 235 | "field2": 123, 236 | } 237 | } 238 | 239 | result = tool(**args) 240 | assert result == "hello 123" 241 | 242 | 243 | def test_tool_call_parses_nested_list_of_pydantic_model(): 244 | def dummy_function(x: list[list[DummyModel]]): 245 | return x 246 | 247 | tool = Tool(dummy_function) 248 | args = { 249 | "x": [ 250 | [ 251 | { 252 | "field1": "hello", 253 | "field2": 123, 254 | } 255 | ] 256 | ] 257 | } 258 | 259 | result = tool(**args) 260 | assert result == [[DummyModel(field1="hello", field2=123)]] 261 | 262 | 263 | def test_tool_call_kwarg(): 264 | def fn(x: int, **kwargs): 265 | return kwargs 266 | 267 | tool = Tool(fn) 268 | 269 | assert tool(x=1, y=2, z=3) == {"y": 2, "z": 3} 270 | 271 | 
272 | def test_tool_str(): 273 | def add(x: int, y: int = 0) -> int: 274 | """Add two integers.""" 275 | return x + y 276 | 277 | tool = Tool(add) 278 | assert ( 279 | str(tool) 280 | == "add, whose description is <desc>Add two integers.</desc>. It takes arguments {'x': {'type': 'integer'}, 'y': {'type': 'integer', 'default': 0}}." 281 | ) 282 | 283 | 284 | @pytest.mark.asyncio 285 | async def test_async_tool_from_function(): 286 | tool = Tool(async_dummy_function) 287 | 288 | assert tool.name == "async_dummy_function" 289 | assert "An async dummy function for testing" in tool.desc 290 | assert "x" in tool.args 291 | assert "y" in tool.args 292 | assert tool.args["x"]["type"] == "integer" 293 | assert tool.args["y"]["type"] == "string" 294 | assert tool.args["y"]["default"] == "hello" 295 | 296 | # Test async call 297 | result = await tool.acall(x=42, y="hello") 298 | assert result == "hello 42" 299 | 300 | 301 | @pytest.mark.asyncio 302 | async def test_async_tool_with_pydantic(): 303 | tool = Tool(async_dummy_with_pydantic) 304 | 305 | assert tool.name == "async_dummy_with_pydantic" 306 | assert "model" in tool.args 307 | assert tool.args["model"]["type"] == "object" 308 | assert "field1" in tool.args["model"]["properties"] 309 | assert "field2" in tool.args["model"]["properties"] 310 | 311 | # Test async call with pydantic model 312 | model = DummyModel(field1="test", field2=123) 313 | result = await tool.acall(model=model) 314 | assert result == "test 123" 315 | 316 | # Test async call with dict 317 | result = await tool.acall(model={"field1": "test", "field2": 123}) 318 | assert result == "test 123" 319 | 320 | 321 | @pytest.mark.asyncio 322 | async def test_async_tool_with_complex_pydantic(): 323 | tool = Tool(async_complex_dummy_function) 324 | 325 | profile = UserProfile( 326 | user_id=1, 327 | name="Test User", 328 | contact=ContactInfo( 329 | email="[email protected]", 330 | addresses=[ 331 | Address(street="123 Main St", city="Test City", zip_code="12345", is_primary=True), 332 | Address(street="456 Side St", city="Test City", zip_code="12345"), 333 | ], 334 | ), 335 | ) 336 | 337 | result = await tool.acall(profile=profile, priority=1, notes=[Note(content="Test note", author="Test author")]) 338 | assert result["user_id"] == 1 339 | assert result["name"] == "Test User" 340 | assert result["priority"] == 1 341 | assert result["notes"] == [Note(content="Test note", author="Test author")] 342 | assert result["primary_address"]["street"] == "123 Main St" 343 | 344 | 345 | @pytest.mark.asyncio 346 | async def test_async_tool_invalid_call(): 347 | tool = Tool(async_dummy_function) 348 | with pytest.raises(ValueError): 349 | await tool.acall(x="not an integer", y="hello") 350 | 351 | 352 | @pytest.mark.asyncio 353 | async def test_async_tool_with_kwargs(): 354 | async def fn(x: int, **kwargs): 355 | return kwargs 356 | 357 | tool = Tool(fn) 358 | 359 | result = await tool.acall(x=1, y=2, z=3) 360 | assert result == {"y": 2, "z": 3} 361 | 362 | 363 | @pytest.mark.asyncio 364 | async def test_async_concurrent_calls(): 365 | """Test that multiple async tools can run concurrently.""" 366 | tool = Tool(async_dummy_function) 367 | 368 | # Create multiple concurrent calls 369 | tasks = [tool.acall(x=i, y=f"hello{i}") for i in range(5)] 370 | 371 | # Run them concurrently and measure time 372 | start_time = asyncio.get_event_loop().time() 373 | results = await asyncio.gather(*tasks) 374 | end_time = asyncio.get_event_loop().time() 375 | 376 | # Verify results, `asyncio.gather` returns results 
in the order of the tasks 377 | assert results == [f"hello{i} {i}" for i in range(5)] 378 | 379 | # Check that it ran concurrently (should take ~0.1s, not ~0.5s) 380 | # We use 0.3s as threshold to account for some overhead 381 | assert end_time - start_time < 0.3 382 | 383 | 384 | @pytest.mark.filterwarnings("ignore::RuntimeWarning") 385 | def test_async_tool_call_in_sync_mode(): 386 | tool = Tool(async_dummy_function) 387 | with dspy.context(allow_tool_async_sync_conversion=False): 388 | with pytest.raises(ValueError): 389 | result = tool(x=1, y="hello") 390 | 391 | with dspy.context(allow_tool_async_sync_conversion=True): 392 | result = tool(x=1, y="hello") 393 | assert result == "hello 1" 394 | 395 | 396 | TOOL_CALL_TEST_CASES = [ 397 | ([], {"tool_calls": []}), 398 | ( 399 | [{"name": "search", "args": {"query": "hello"}}], 400 | { 401 | "tool_calls": [{"type": "function", "function": {"name": "search", "arguments": {"query": "hello"}}}], 402 | }, 403 | ), 404 | ( 405 | [ 406 | {"name": "search", "args": {"query": "hello"}}, 407 | {"name": "translate", "args": {"text": "world", "lang": "fr"}}, 408 | ], 409 | { 410 | "tool_calls": [ 411 | {"type": "function", "function": {"name": "search", "arguments": {"query": "hello"}}}, 412 | { 413 | "type": "function", 414 | "function": {"name": "translate", "arguments": {"text": "world", "lang": "fr"}}, 415 | }, 416 | ], 417 | }, 418 | ), 419 | ( 420 | [{"name": "get_time", "args": {}}], 421 | { 422 | "tool_calls": [{"type": "function", "function": {"name": "get_time", "arguments": {}}}], 423 | }, 424 | ), 425 | ] 426 | 427 | 428 | @pytest.mark.parametrize("tool_calls_data,expected", TOOL_CALL_TEST_CASES) 429 | def test_tool_calls_format_basic(tool_calls_data, expected): 430 | """Test ToolCalls.format with various basic scenarios.""" 431 | tool_calls_list = [ToolCalls.ToolCall(**data) for data in tool_calls_data] 432 | tool_calls = ToolCalls(tool_calls=tool_calls_list) 433 | result = tool_calls.format() 434 | 435 | assert result == expected 436 | 437 | 438 | def test_tool_calls_format_from_dict_list(): 439 | """Test format works with ToolCalls created from from_dict_list.""" 440 | tool_calls_dicts = [ 441 | {"name": "search", "args": {"query": "hello"}}, 442 | {"name": "translate", "args": {"text": "world", "lang": "fr"}}, 443 | ] 444 | 445 | tool_calls = ToolCalls.from_dict_list(tool_calls_dicts) 446 | result = tool_calls.format() 447 | 448 | assert len(result["tool_calls"]) == 2 449 | assert result["tool_calls"][0]["function"]["name"] == "search" 450 | assert result["tool_calls"][1]["function"]["name"] == "translate" 451 | 452 | 453 | def test_toolcalls_vague_match(): 454 | """ 455 | Test that ToolCalls can parse the data with slightly off format: 456 | 457 | - a single dict with "name" and "args" 458 | - a list of dicts with "name" and "args" 459 | - invalid input (should raise ValueError) 460 | """ 461 | # Single dict with "name" and "args" should parse as one ToolCall 462 | data_single = {"name": "search", "args": {"query": "hello"}} 463 | tc = ToolCalls.model_validate(data_single) 464 | assert isinstance(tc, ToolCalls) 465 | assert len(tc.tool_calls) == 1 466 | assert tc.tool_calls[0].name == "search" 467 | assert tc.tool_calls[0].args == {"query": "hello"} 468 | 469 | # List of dicts with "name" and "args" should parse as multiple ToolCalls 470 | data_list = [ 471 | {"name": "search", "args": {"query": "hello"}}, 472 | {"name": "translate", "args": {"text": "world", "lang": "fr"}}, 473 | ] 474 | tc = ToolCalls.model_validate(data_list) 475 
| assert isinstance(tc, ToolCalls) 476 | assert len(tc.tool_calls) == 2 477 | assert tc.tool_calls[0].name == "search" 478 | assert tc.tool_calls[1].name == "translate" 479 | 480 | # Dict with "tool_calls" key containing a list of dicts 481 | data_tool_calls = { 482 | "tool_calls": [ 483 | {"name": "search", "args": {"query": "hello"}}, 484 | {"name": "get_time", "args": {}}, 485 | ] 486 | } 487 | tc = ToolCalls.model_validate(data_tool_calls) 488 | assert isinstance(tc, ToolCalls) 489 | assert len(tc.tool_calls) == 2 490 | assert tc.tool_calls[0].name == "search" 491 | assert tc.tool_calls[1].name == "get_time" 492 | 493 | # Invalid input should raise ValueError 494 | with pytest.raises(ValueError): 495 | ToolCalls.model_validate({"foo": "bar"}) 496 | with pytest.raises(ValueError): 497 | ToolCalls.model_validate([{"foo": "bar"}]) 498 | 499 | 500 | def test_tool_convert_input_schema_to_tool_args_no_input_params(): 501 | args, arg_types, arg_desc = convert_input_schema_to_tool_args(schema={"properties": {}}) 502 | assert args == {} 503 | assert arg_types == {} 504 | assert arg_desc == {} 505 | 506 | 507 | def test_tool_convert_input_schema_to_tool_args_lang_chain(): 508 | # Example from langchain docs: 509 | # https://web.archive.org/web/20250723101359/https://api.python.langchain.com/en/latest/tools/langchain_core.tools.tool.html 510 | args, arg_types, arg_desc = convert_input_schema_to_tool_args( 511 | schema={ 512 | "title": "fooSchema", 513 | "description": "The foo.", 514 | "type": "object", 515 | "properties": { 516 | "bar": { 517 | "title": "Bar", 518 | "description": "The bar.", 519 | "type": "string", 520 | }, 521 | "baz": { 522 | "title": "Baz", 523 | "type": "integer", 524 | }, 525 | }, 526 | "required": [ 527 | "baz", 528 | ], 529 | } 530 | ) 531 | assert args == { 532 | "bar": {"title": "Bar", "description": "The bar.", "type": "string"}, 533 | "baz": {"title": "Baz", "type": "integer"}, 534 | } 535 | assert arg_types == { 536 | "bar": str, 537 | "baz": int, 538 | } 539 | assert arg_desc == { 540 | "bar": "The bar.", 541 | "baz": "No description provided. 
(Required)", 542 | } 543 | 544 | 545 | 546 | 547 | def test_tool_call_execute(): 548 | def get_weather(city: str) -> str: 549 | return f"The weather in {city} is sunny" 550 | 551 | def add_numbers(a: int, b: int) -> int: 552 | return a + b 553 | 554 | tools = [ 555 | dspy.Tool(get_weather), 556 | dspy.Tool(add_numbers) 557 | ] 558 | 559 | tool_call = dspy.ToolCalls.ToolCall(name="get_weather", args={"city": "Berlin"}) 560 | result = tool_call.execute(functions=tools) 561 | assert result == "The weather in Berlin is sunny" 562 | 563 | # Test individual tool call with function dict 564 | tool_call2 = dspy.ToolCalls.ToolCall(name="add_numbers", args={"a": 7, "b": 13}) 565 | result2 = tool_call2.execute(functions={"add_numbers": add_numbers}) 566 | assert result2 == 20 567 | 568 | # Test individual tool call with no arguments 569 | def get_pi(): 570 | return 3.14159 571 | 572 | tool_call3 = dspy.ToolCalls.ToolCall(name="get_pi", args={}) 573 | result3 = tool_call3.execute(functions={"get_pi": get_pi}) 574 | assert result3 == 3.14159 575 | 576 | # Test error case 577 | tool_call4 = dspy.ToolCalls.ToolCall(name="nonexistent", args={}) 578 | try: 579 | tool_call4.execute(functions=tools) 580 | assert False, "Should have raised ValueError" 581 | except ValueError as e: 582 | assert "not found" in str(e) 583 | 584 | 585 | def test_tool_call_execute_with_local_functions(): 586 | def main(): 587 | def local_add(a: int, b: int) -> int: 588 | return a + b 589 | 590 | def local_multiply(x: int, y: int) -> int: 591 | return x * y 592 | 593 | # Test individual execution with local function 594 | tool_call1 = dspy.ToolCalls.ToolCall(name="local_add", args={"a": 10, "b": 15}) 595 | result1 = tool_call1.execute() # Should find local function automatically 596 | assert result1 == 25 597 | 598 | tool_call2 = dspy.ToolCalls.ToolCall(name="local_multiply", args={"x": 4, "y": 7}) 599 | result2 = tool_call2.execute() # Should find local function automatically 600 | assert result2 == 28 601 | 602 | # Test locals take precedence over globals 603 | try: 604 | globals()["local_add"] = lambda a, b: a + b + 1000 605 | precedence_call = dspy.ToolCalls.ToolCall(name="local_add", args={"a": 1, "b": 2}) 606 | result = precedence_call.execute() 607 | assert result == 3 # Should use local function (1+2=3), not global (1+2+1000=1003) 608 | finally: 609 | globals().pop("local_add", None) 610 | 611 | main() 612 | ``` -------------------------------------------------------------------------------- /docs/docs/tutorials/streaming/index.md: -------------------------------------------------------------------------------- ```markdown 1 | # Streaming 2 | 3 | In this guide, we will walk you through how to enable streaming in your DSPy program. DSPy Streaming 4 | consists of two parts: 5 | 6 | - **Output Token Streaming**: Stream individual tokens as they're generated, rather than waiting for the complete response. 7 | - **Intermediate Status Streaming**: Provide real-time updates about the program's execution state (e.g., "Calling web search...", "Processing results..."). 8 | 9 | ## Output Token Streaming 10 | 11 | DSPy's token streaming feature works with any module in your pipeline, not just the final output. The only requirement is that the streamed field must be of type `str`. To enable token streaming: 12 | 13 | 1. Wrap your program with `dspy.streamify` 14 | 2. 
Create one or more `dspy.streaming.StreamListener` objects to specify which fields to stream 15 | 16 | Here's a basic example: 17 | 18 | ```python 19 | import os 20 | 21 | import dspy 22 | 23 | os.environ["OPENAI_API_KEY"] = "your_api_key" 24 | 25 | dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")) 26 | 27 | predict = dspy.Predict("question->answer") 28 | 29 | # Enable streaming for the 'answer' field 30 | stream_predict = dspy.streamify( 31 | predict, 32 | stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")], 33 | ) 34 | ``` 35 | 36 | To consume the streamed output: 37 | 38 | ```python 39 | import asyncio 40 | 41 | async def read_output_stream(): 42 | output_stream = stream_predict(question="Why did a chicken cross the kitchen?") 43 | 44 | async for chunk in output_stream: 45 | print(chunk) 46 | 47 | asyncio.run(read_output_stream()) 48 | ``` 49 | 50 | This will produce output like: 51 | 52 | ``` 53 | StreamResponse(predict_name='self', signature_field_name='answer', chunk='To') 54 | StreamResponse(predict_name='self', signature_field_name='answer', chunk=' get') 55 | StreamResponse(predict_name='self', signature_field_name='answer', chunk=' to') 56 | StreamResponse(predict_name='self', signature_field_name='answer', chunk=' the') 57 | StreamResponse(predict_name='self', signature_field_name='answer', chunk=' other') 58 | StreamResponse(predict_name='self', signature_field_name='answer', chunk=' side of the frying pan!') 59 | Prediction( 60 | answer='To get to the other side of the frying pan!' 61 | ) 62 | ``` 63 | 64 | Note: Since `dspy.streamify` returns an async generator, you must use it within an async context. If you're using an environment like Jupyter or Google Colab that already has an event loop (async context), you can use the generator directly. 65 | 66 | You may have noticed that the stream above contains two different kinds of entities: `StreamResponse` 67 | and `Prediction`. `StreamResponse` is the wrapper over the streaming tokens of the field being listened to, which in 68 | this example is the `answer` field. `Prediction` is the program's final output. In DSPy, streaming is 69 | implemented in a sidecar fashion: we enable streaming on the LM so that it outputs a stream of tokens. We send these 70 | tokens to a side channel, which is continuously read by the user-defined listeners. Each listener keeps interpreting 71 | the stream and decides whether the `signature_field_name` it is listening to has started to appear and has finalized. 72 | Once it decides that the field has appeared, the listener begins emitting tokens to the async generator users can 73 | read. A listener's internal mechanism changes according to the adapter behind the scenes, and because we usually 74 | cannot tell that a field has finalized until we see the next field, the listener buffers the output tokens 75 | before sending them to the final generator, which is why you will usually see that the last chunk of type `StreamResponse` 76 | has more than one token. The program's output is also written to the stream, which is the `Prediction` chunk 77 | in the sample output above.
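Because every token of the monitored field arrives as a `StreamResponse` chunk, concatenating the `chunk` values reconstructs the finalized field value. Here is a minimal sketch reusing the `stream_predict` defined above (the helper name `collect_answer` is illustrative, and the exact whitespace of the reassembled text can vary slightly with the adapter):

```python
import asyncio

async def collect_answer():
    output_stream = stream_predict(question="Why did a chicken cross the kitchen?")

    streamed_text = ""
    final_prediction = None
    async for chunk in output_stream:
        if isinstance(chunk, dspy.streaming.StreamResponse):
            # Each chunk carries a fragment of the streamed "answer" field.
            streamed_text += chunk.chunk
        elif isinstance(chunk, dspy.Prediction):
            final_prediction = chunk

    # The concatenated fragments should match the finalized field on the Prediction.
    print(streamed_text)
    print(final_prediction.answer)

asyncio.run(collect_answer())
```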
78 | 79 | To handle these different types and implement custom logic: 80 | 81 | ```python 82 | import asyncio 83 | 84 | async def read_output_stream(): 85 | output_stream = stream_predict(question="Why did a chicken cross the kitchen?") 86 | 87 | return_value = None 88 | async for chunk in output_stream: 89 | if isinstance(chunk, dspy.streaming.StreamResponse): 90 | print(f"Output token of field {chunk.signature_field_name}: {chunk.chunk}") 91 | elif isinstance(chunk, dspy.Prediction): 92 | return_value = chunk 93 | return return_value 94 | 95 | program_output = asyncio.run(read_output_stream()) 96 | print("Final output: ", program_output) 97 | ``` 98 | 99 | ### Understand `StreamResponse` 100 | 101 | `StreamResponse` (`dspy.streaming.StreamResponse`) is the wrapper class of streaming tokens. It comes with 3 102 | fields: 103 | 104 | - `predict_name`: the name of the predictor that holds the `signature_field_name`. This name matches the 105 | keys returned by `your_program.named_predictors()`. In the code above, because `answer` comes from 106 | the `predict` module itself, the `predict_name` shows up as `self`, which is the only key returned by 107 | `predict.named_predictors()`. 108 | - `signature_field_name`: the output field that these tokens map to. `predict_name` and `signature_field_name` 109 | together form the unique identifier of the field. We will demonstrate how to handle streaming of multiple fields 110 | and duplicated field names later in this guide. 111 | - `chunk`: the value of the stream chunk. 112 | 113 | ### Streaming with Cache 114 | 115 | When a cached result is found, the stream will skip individual tokens and only yield the final `Prediction`. For example: 116 | 117 | ``` 118 | Prediction( 119 | answer='To get to the other side of the dinner plate!' 120 | ) 121 | ``` 122 | 123 | ### Streaming Multiple Fields 124 | 125 | You can monitor multiple fields by creating a `StreamListener` for each one.
Here's an example with a multi-module program: 126 | 127 | ```python 128 | import asyncio 129 | 130 | import dspy 131 | 132 | lm = dspy.LM("openai/gpt-4o-mini", cache=False) 133 | dspy.settings.configure(lm=lm) 134 | 135 | 136 | class MyModule(dspy.Module): 137 | def __init__(self): 138 | super().__init__() 139 | 140 | self.predict1 = dspy.Predict("question->answer") 141 | self.predict2 = dspy.Predict("answer->simplified_answer") 142 | 143 | def forward(self, question: str, **kwargs): 144 | answer = self.predict1(question=question) 145 | simplified_answer = self.predict2(answer=answer) 146 | return simplified_answer 147 | 148 | 149 | predict = MyModule() 150 | stream_listeners = [ 151 | dspy.streaming.StreamListener(signature_field_name="answer"), 152 | dspy.streaming.StreamListener(signature_field_name="simplified_answer"), 153 | ] 154 | stream_predict = dspy.streamify( 155 | predict, 156 | stream_listeners=stream_listeners, 157 | ) 158 | 159 | async def read_output_stream(): 160 | output = stream_predict(question="why did a chicken cross the kitchen?") 161 | 162 | return_value = None 163 | async for chunk in output: 164 | if isinstance(chunk, dspy.streaming.StreamResponse): 165 | print(chunk) 166 | elif isinstance(chunk, dspy.Prediction): 167 | return_value = chunk 168 | return return_value 169 | 170 | program_output = asyncio.run(read_output_stream()) 171 | print("Final output: ", program_output) 172 | ``` 173 | 174 | The output will look like: 175 | 176 | ``` 177 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk='To') 178 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' get') 179 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' to') 180 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' the') 181 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' other side of the recipe!') 182 | StreamResponse(predict_name='predict2', signature_field_name='simplified_answer', chunk='To') 183 | StreamResponse(predict_name='predict2', signature_field_name='simplified_answer', chunk=' reach') 184 | StreamResponse(predict_name='predict2', signature_field_name='simplified_answer', chunk=' the') 185 | StreamResponse(predict_name='predict2', signature_field_name='simplified_answer', chunk=' other side of the recipe!') 186 | Final output: Prediction( 187 | simplified_answer='To reach the other side of the recipe!' 188 | ) 189 | ``` 190 | 191 | ### Streaming the Same Field Multiple Times (as in dspy.ReAct) 192 | 193 | By default, a `StreamListener` automatically closes itself after completing a single streaming session. 194 | This design helps prevent performance issues, since every token is broadcast to all configured stream listeners, 195 | and having too many active listeners can introduce significant overhead. 196 | 197 | However, in scenarios where a DSPy module is used repeatedly in a loop—such as with `dspy.ReAct` — you may want to stream 198 | the same field from each prediction, every time it is used. To enable this behavior, set allow_reuse=True when creating 199 | your `StreamListener`. 
See the example below: 200 | 201 | ```python 202 | import asyncio 203 | 204 | import dspy 205 | 206 | lm = dspy.LM("openai/gpt-4o-mini", cache=False) 207 | dspy.settings.configure(lm=lm) 208 | 209 | 210 | def fetch_user_info(user_name: str): 211 | """Get user information like name, birthday, etc.""" 212 | return { 213 | "name": user_name, 214 | "birthday": "2009-05-16", 215 | } 216 | 217 | 218 | def get_sports_news(year: int): 219 | """Get sports news for a given year.""" 220 | if year == 2009: 221 | return "Usane Bolt broke the world record in the 100m race." 222 | return None 223 | 224 | 225 | react = dspy.ReAct("question->answer", tools=[fetch_user_info, get_sports_news]) 226 | 227 | stream_listeners = [ 228 | # dspy.ReAct has a built-in output field called "next_thought". 229 | dspy.streaming.StreamListener(signature_field_name="next_thought", allow_reuse=True), 230 | ] 231 | stream_react = dspy.streamify(react, stream_listeners=stream_listeners) 232 | 233 | 234 | async def read_output_stream(): 235 | output = stream_react(question="What sports news happened in the year Adam was born?") 236 | return_value = None 237 | async for chunk in output: 238 | if isinstance(chunk, dspy.streaming.StreamResponse): 239 | print(chunk) 240 | elif isinstance(chunk, dspy.Prediction): 241 | return_value = chunk 242 | return return_value 243 | 244 | 245 | print(asyncio.run(read_output_stream())) 246 | ``` 247 | 248 | In this example, by setting `allow_reuse=True` in the StreamListener, you ensure that streaming for "next_thought" is 249 | available for every iteration, not just the first. When you run this code, you will see the streaming tokens for `next_thought` 250 | output each time the field is produced. 251 | 252 | #### Handling Duplicate Field Names 253 | 254 | When streaming fields with the same name from different modules, specify both the `predict` and `predict_name` in the `StreamListener`: 255 | 256 | ```python 257 | import asyncio 258 | 259 | import dspy 260 | 261 | lm = dspy.LM("openai/gpt-4o-mini", cache=False) 262 | dspy.settings.configure(lm=lm) 263 | 264 | 265 | class MyModule(dspy.Module): 266 | def __init__(self): 267 | super().__init__() 268 | 269 | self.predict1 = dspy.Predict("question->answer") 270 | self.predict2 = dspy.Predict("question, answer->answer, score") 271 | 272 | def forward(self, question: str, **kwargs): 273 | answer = self.predict1(question=question) 274 | simplified_answer = self.predict2(answer=answer) 275 | return simplified_answer 276 | 277 | 278 | predict = MyModule() 279 | stream_listeners = [ 280 | dspy.streaming.StreamListener( 281 | signature_field_name="answer", 282 | predict=predict.predict1, 283 | predict_name="predict1" 284 | ), 285 | dspy.streaming.StreamListener( 286 | signature_field_name="answer", 287 | predict=predict.predict2, 288 | predict_name="predict2" 289 | ), 290 | ] 291 | stream_predict = dspy.streamify( 292 | predict, 293 | stream_listeners=stream_listeners, 294 | ) 295 | 296 | 297 | async def read_output_stream(): 298 | output = stream_predict(question="why did a chicken cross the kitchen?") 299 | 300 | return_value = None 301 | async for chunk in output: 302 | if isinstance(chunk, dspy.streaming.StreamResponse): 303 | print(chunk) 304 | elif isinstance(chunk, dspy.Prediction): 305 | return_value = chunk 306 | return return_value 307 | 308 | 309 | program_output = asyncio.run(read_output_stream()) 310 | print("Final output: ", program_output) 311 | ``` 312 | 313 | The output will be like: 314 | 315 | ``` 316 | 
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk='To') 317 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' get') 318 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' to') 319 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' the') 320 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' other side of the recipe!') 321 | StreamResponse(predict_name='predict2', signature_field_name='answer', chunk="I'm") 322 | StreamResponse(predict_name='predict2', signature_field_name='answer', chunk=' ready') 323 | StreamResponse(predict_name='predict2', signature_field_name='answer', chunk=' to') 324 | StreamResponse(predict_name='predict2', signature_field_name='answer', chunk=' assist') 325 | StreamResponse(predict_name='predict2', signature_field_name='answer', chunk=' you') 326 | StreamResponse(predict_name='predict2', signature_field_name='answer', chunk='! Please provide a question.') 327 | Final output: Prediction( 328 | answer="I'm ready to assist you! Please provide a question.", 329 | score='N/A' 330 | ) 331 | ``` 332 | 333 | ## Intermediate Status Streaming 334 | 335 | Status streaming keeps users informed about the program's progress, which is especially useful for long-running operations like tool calls or complex AI pipelines. To implement status streaming: 336 | 337 | 1. Create a custom status message provider by subclassing `dspy.streaming.StatusMessageProvider` 338 | 2. Override the desired hook methods to provide custom status messages 339 | 3. Pass your provider to `dspy.streamify` 340 | 341 | Example: 342 | 343 | ```python 344 | class MyStatusMessageProvider(dspy.streaming.StatusMessageProvider): 345 | def lm_start_status_message(self, instance, inputs): 346 | return f"Calling LM with inputs {inputs}..." 347 | 348 | def lm_end_status_message(self, outputs): 349 | return f"LM finished with outputs: {outputs}!" 350 | ``` 351 | 352 | Available hooks: 353 | 354 | - `lm_start_status_message`: status message at the start of calling `dspy.LM`. 355 | - `lm_end_status_message`: status message at the end of calling `dspy.LM`. 356 | - `module_start_status_message`: status message at the start of calling a `dspy.Module`. 357 | - `module_end_status_message`: status message at the end of calling a `dspy.Module`. 358 | - `tool_start_status_message`: status message at the start of calling `dspy.Tool`. 359 | - `tool_end_status_message`: status message at the end of calling `dspy.Tool`. 360 | 361 | Each hook should return a string containing the status message. 362 | 363 | After creating the message provider, pass it to `dspy.streamify` to enable both 364 | status message streaming and output token streaming, as shown in the example below. Each intermediate 365 | status message is represented by the class `dspy.streaming.StatusMessage`, so we need an 366 | additional condition check to capture it.
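The examples in this guide override the LM- and tool-level hooks; the module-level hooks can be overridden the same way. A minimal sketch, assuming the module hooks receive the same arguments as the LM and tool hooks shown above (the message wording is illustrative):

```python
class ModuleStatusMessageProvider(dspy.streaming.StatusMessageProvider):
    def module_start_status_message(self, instance, inputs):
        # `instance` is the dspy.Module being invoked.
        return f"Running module {instance.__class__.__name__}..."

    def module_end_status_message(self, outputs):
        return f"Module finished with outputs: {outputs}!"
```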
367 | 368 | ```python 369 | import asyncio 370 | 371 | import dspy 372 | 373 | lm = dspy.LM("openai/gpt-4o-mini", cache=False) 374 | dspy.settings.configure(lm=lm) 375 | 376 | 377 | class MyModule(dspy.Module): 378 | def __init__(self): 379 | super().__init__() 380 | 381 | self.tool = dspy.Tool(lambda x: 2 * x, name="double_the_number") 382 | self.predict = dspy.ChainOfThought("num1, num2->sum") 383 | 384 | def forward(self, num, **kwargs): 385 | num2 = self.tool(x=num) 386 | return self.predict(num1=num, num2=num2) 387 | 388 | 389 | class MyStatusMessageProvider(dspy.streaming.StatusMessageProvider): 390 | def tool_start_status_message(self, instance, inputs): 391 | return f"Calling Tool {instance.name} with inputs {inputs}..." 392 | 393 | def tool_end_status_message(self, outputs): 394 | return f"Tool finished with output: {outputs}!" 395 | 396 | 397 | predict = MyModule() 398 | stream_listeners = [ 399 | # dspy.ChainOfThought has a built-in output field called "reasoning". 400 | dspy.streaming.StreamListener(signature_field_name="reasoning"), 401 | ] 402 | stream_predict = dspy.streamify( 403 | predict, 404 | stream_listeners=stream_listeners, 405 | status_message_provider=MyStatusMessageProvider(), 406 | ) 407 | 408 | 409 | async def read_output_stream(): 410 | output = stream_predict(num=3) 411 | 412 | return_value = None 413 | async for chunk in output: 414 | if isinstance(chunk, dspy.streaming.StreamResponse): 415 | print(chunk) 416 | elif isinstance(chunk, dspy.Prediction): 417 | return_value = chunk 418 | elif isinstance(chunk, dspy.streaming.StatusMessage): 419 | print(chunk) 420 | return return_value 421 | 422 | 423 | program_output = asyncio.run(read_output_stream()) 424 | print("Final output: ", program_output) 425 | ``` 426 | 427 | Sample output: 428 | 429 | ``` 430 | StatusMessage(message='Calling tool double_the_number...') 431 | StatusMessage(message='Tool calling finished! 
Querying the LLM with tool calling results...') 432 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk='To') 433 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' find') 434 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' the') 435 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' sum') 436 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' of') 437 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' the') 438 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' two') 439 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' numbers') 440 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=',') 441 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' we') 442 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' simply') 443 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' add') 444 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' them') 445 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' together') 446 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk='.') 447 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' Here') 448 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=',') 449 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' ') 450 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk='3') 451 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' plus') 452 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' 6 equals 9.') 453 | Final output: Prediction( 454 | reasoning='To find the sum of the two numbers, we simply add them together. Here, 3 plus 6 equals 9.', 455 | sum='9' 456 | ) 457 | ``` 458 | 459 | ## Synchronous Streaming 460 | 461 | By default calling a streamified DSPy program produces an async generator. 
In order to get back 462 | a sync generator, you can set the flag `async_streaming=False`: 463 | 464 | 465 | ```python 466 | import os 467 | 468 | import dspy 469 | 470 | os.environ["OPENAI_API_KEY"] = "your_api_key" 471 | 472 | dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")) 473 | 474 | predict = dspy.Predict("question->answer") 475 | 476 | # Enable streaming for the 'answer' field 477 | stream_predict = dspy.streamify( 478 | predict, 479 | stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")], 480 | async_streaming=False, 481 | ) 482 | 483 | output = stream_predict(question="why did a chicken cross the kitchen?") 484 | 485 | program_output = None 486 | for chunk in output: 487 | if isinstance(chunk, dspy.streaming.StreamResponse): 488 | print(chunk) 489 | elif isinstance(chunk, dspy.Prediction): 490 | program_output = chunk 491 | print(f"Program output: {program_output}") 492 | ``` 493 | ``` -------------------------------------------------------------------------------- /tests/signatures/test_adapter_image.py: -------------------------------------------------------------------------------- ```python 1 | import os 2 | import tempfile 3 | from io import BytesIO 4 | 5 | import pydantic 6 | import pytest 7 | import requests 8 | from PIL import Image as PILImage 9 | 10 | import dspy 11 | from dspy.adapters.types.image import encode_image 12 | from dspy.utils.dummies import DummyLM 13 | 14 | 15 | @pytest.fixture 16 | def sample_pil_image(): 17 | """Fixture to provide a sample image for testing""" 18 | url = "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg" 19 | response = requests.get(url) 20 | response.raise_for_status() 21 | return PILImage.open(BytesIO(response.content)) 22 | 23 | 24 | @pytest.fixture 25 | def sample_dspy_image_download(): 26 | url = "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg" 27 | return dspy.Image(url, download=True) 28 | 29 | 30 | @pytest.fixture 31 | def sample_url(): 32 | return "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg" 33 | 34 | 35 | @pytest.fixture 36 | def sample_dspy_image_no_download(): 37 | return dspy.Image("https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg") 38 | 39 | 40 | def count_messages_with_image_url_pattern(messages): 41 | pattern = {"type": "image_url", "image_url": {"url": lambda x: isinstance(x, str)}} 42 | 43 | try: 44 | 45 | def check_pattern(obj, pattern): 46 | if isinstance(pattern, dict): 47 | if not isinstance(obj, dict): 48 | return False 49 | return all(k in obj and check_pattern(obj[k], v) for k, v in pattern.items()) 50 | if callable(pattern): 51 | return pattern(obj) 52 | return obj == pattern 53 | 54 | def count_patterns(obj, pattern): 55 | count = 0 56 | if check_pattern(obj, pattern): 57 | count += 1 58 | if isinstance(obj, dict): 59 | count += sum(count_patterns(v, pattern) for v in obj.values()) 60 | if isinstance(obj, (list, tuple)): 61 | count += sum(count_patterns(v, pattern) for v in obj) 62 | return count 63 | 64 | return count_patterns(messages, pattern) 65 | except Exception: 66 | return 0 67 | 68 | 69 | def setup_predictor(signature, expected_output): 70 | """Helper to set up a predictor with DummyLM""" 71 | lm = DummyLM([expected_output]) 72 | dspy.settings.configure(lm=lm) 73 | return dspy.Predict(signature), lm 74 | 75 | 76 | @pytest.mark.parametrize( 77 | "test_case", 78 | [ 79 | { 80 | "name": "probabilistic_classification", 81 | "signature": "image: dspy.Image, class_labels: list[str] -> probabilities: dict[str, float]", 82 | "inputs": 
{"image": "https://example.com/dog.jpg", "class_labels": ["dog", "cat", "bird"]}, 83 | "key_output": "probabilities", 84 | "expected": {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}}, 85 | }, 86 | { 87 | "name": "image_to_code", 88 | "signature": "ui_image: dspy.Image, target_language: str -> generated_code: str", 89 | "inputs": {"ui_image": "https://example.com/button.png", "target_language": "HTML"}, 90 | "key_output": "generated_code", 91 | "expected": {"generated_code": "<button>Click me</button>"}, 92 | }, 93 | { 94 | "name": "bbox_detection", 95 | "signature": "image: dspy.Image -> bboxes: list[Tuple[int, int, int, int]]", 96 | "inputs": {"image": "https://example.com/image.jpg"}, 97 | "key_output": "bboxes", 98 | "expected": {"bboxes": [(10, 20, 30, 40), (50, 60, 70, 80)]}, 99 | }, 100 | { 101 | "name": "multilingual_caption", 102 | "signature": "image: dspy.Image, languages: list[str] -> captions: dict[str, str]", 103 | "inputs": {"image": "https://example.com/dog.jpg", "languages": ["en", "es", "fr"]}, 104 | "key_output": "captions", 105 | "expected": { 106 | "captions": {"en": "A golden retriever", "es": "Un golden retriever", "fr": "Un golden retriever"} 107 | }, 108 | }, 109 | ], 110 | ) 111 | def test_basic_image_operations(test_case): 112 | """Consolidated test for basic image operations""" 113 | predictor, lm = setup_predictor(test_case["signature"], test_case["expected"]) 114 | 115 | # Convert string URLs to dspy.Image objects 116 | inputs = { 117 | k: dspy.Image(v) if isinstance(v, str) and k in ["image", "ui_image"] else v 118 | for k, v in test_case["inputs"].items() 119 | } 120 | 121 | result = predictor(**inputs) 122 | 123 | # Check result based on output field name 124 | output_field = next(f for f in ["probabilities", "generated_code", "bboxes", "captions"] if hasattr(result, f)) 125 | assert getattr(result, output_field) == test_case["expected"][test_case["key_output"]] 126 | assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1 127 | 128 | 129 | @pytest.mark.parametrize( 130 | "image_input,description", 131 | [ 132 | ("pil_image", "PIL Image"), 133 | ("encoded_pil_image", "encoded PIL image string"), 134 | ("dspy_image_download", "dspy.Image with download=True"), 135 | ("dspy_image_no_download", "dspy.Image without download"), 136 | ], 137 | ) 138 | def test_image_input_formats( 139 | request, sample_pil_image, sample_dspy_image_download, sample_dspy_image_no_download, image_input, description 140 | ): 141 | """Test different input formats for image fields""" 142 | signature = "image: dspy.Image, class_labels: list[str] -> probabilities: dict[str, float]" 143 | expected = {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}} 144 | predictor, lm = setup_predictor(signature, expected) 145 | 146 | input_map = { 147 | "pil_image": sample_pil_image, 148 | "encoded_pil_image": encode_image(sample_pil_image), 149 | "dspy_image_download": sample_dspy_image_download, 150 | "dspy_image_no_download": sample_dspy_image_no_download, 151 | } 152 | 153 | actual_input = input_map[image_input] 154 | # TODO(isaacbmiller): Support the cases without direct dspy.Image coercion 155 | if image_input in ["pil_image", "encoded_pil_image"]: 156 | pytest.xfail(f"{description} not fully supported without dspy.Image coercion") 157 | 158 | result = predictor(image=actual_input, class_labels=["dog", "cat", "bird"]) 159 | assert result.probabilities == expected["probabilities"] 160 | assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1 
161 | 162 | 163 | def test_predictor_save_load(sample_url, sample_pil_image): 164 | """Test saving and loading predictors with image fields""" 165 | signature = "image: dspy.Image -> caption: str" 166 | examples = [ 167 | dspy.Example(image=dspy.Image(sample_url), caption="Example 1"), 168 | dspy.Example(image=sample_pil_image, caption="Example 2"), 169 | ] 170 | 171 | predictor, lm = setup_predictor(signature, {"caption": "A golden retriever"}) 172 | optimizer = dspy.teleprompt.LabeledFewShot(k=1) 173 | compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False) 174 | 175 | with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".json") as temp_file: 176 | compiled_predictor.save(temp_file.name) 177 | loaded_predictor = dspy.Predict(signature) 178 | loaded_predictor.load(temp_file.name) 179 | 180 | loaded_predictor(image=dspy.Image("https://example.com/dog.jpg")) 181 | assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 2 182 | assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"]) 183 | 184 | 185 | def test_save_load_complex_default_types(): 186 | """Test saving and loading predictors with complex default types (lists of images)""" 187 | examples = [ 188 | dspy.Example( 189 | image_list=[ 190 | dspy.Image("https://example.com/dog.jpg"), 191 | dspy.Image("https://example.com/cat.jpg"), 192 | ], 193 | caption="Example 1", 194 | ).with_inputs("image_list"), 195 | ] 196 | 197 | class ComplexTypeSignature(dspy.Signature): 198 | image_list: list[dspy.Image] = dspy.InputField(desc="A list of images") 199 | caption: str = dspy.OutputField(desc="A caption for the image list") 200 | 201 | predictor, lm = setup_predictor(ComplexTypeSignature, {"caption": "A list of images"}) 202 | optimizer = dspy.teleprompt.LabeledFewShot(k=1) 203 | compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False) 204 | 205 | with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".json") as temp_file: 206 | compiled_predictor.save(temp_file.name) 207 | loaded_predictor = dspy.Predict(ComplexTypeSignature) 208 | loaded_predictor.load(temp_file.name) 209 | 210 | result = loaded_predictor(**examples[0].inputs()) 211 | assert result.caption == "A list of images" 212 | assert str(lm.history[-1]["messages"]).count("'url'") == 4 213 | assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"]) 214 | 215 | 216 | class BasicImageSignature(dspy.Signature): 217 | """Basic signature with a single image input""" 218 | 219 | image: dspy.Image = dspy.InputField() 220 | output: str = dspy.OutputField() 221 | 222 | 223 | class ImageListSignature(dspy.Signature): 224 | """Signature with a list of images input""" 225 | 226 | image_list: list[dspy.Image] = dspy.InputField() 227 | output: str = dspy.OutputField() 228 | 229 | 230 | @pytest.mark.parametrize( 231 | "test_case", 232 | [ 233 | { 234 | "name": "basic_dspy_signature", 235 | "signature_class": BasicImageSignature, 236 | "inputs": {"image": "https://example.com/dog.jpg"}, 237 | "expected": {"output": "A dog photo"}, 238 | "expected_image_urls": 2, 239 | }, 240 | { 241 | "name": "list_dspy_signature", 242 | "signature_class": ImageListSignature, 243 | "inputs": {"image_list": ["https://example.com/dog.jpg", "https://example.com/cat.jpg"]}, 244 | "expected": {"output": "Multiple photos"}, 245 | "expected_image_urls": 4, 246 | }, 247 | ], 248 | ) 249 | def test_save_load_complex_types(test_case): 250 | """Test saving and loading predictors with complex types""" 
251 | signature_cls = test_case["signature_class"] 252 | 253 | # Convert string URLs to dspy.Image objects in input 254 | processed_input = {} 255 | for key, value in test_case["inputs"].items(): 256 | if isinstance(value, str) and "http" in value: 257 | processed_input[key] = dspy.Image(value) 258 | elif isinstance(value, list) and value and isinstance(value[0], str): 259 | processed_input[key] = [dspy.Image(url) for url in value] 260 | else: 261 | processed_input[key] = value 262 | 263 | # Create example and predictor 264 | examples = [dspy.Example(**processed_input, **test_case["expected"]).with_inputs(*processed_input.keys())] 265 | 266 | predictor, lm = setup_predictor(signature_cls, test_case["expected"]) 267 | optimizer = dspy.teleprompt.LabeledFewShot(k=1) 268 | compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False) 269 | 270 | # Test save and load 271 | with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".json") as temp_file: 272 | compiled_predictor.save(temp_file.name) 273 | loaded_predictor = dspy.Predict(signature_cls) 274 | loaded_predictor.load(temp_file.name) 275 | 276 | # Run prediction 277 | result = loaded_predictor(**processed_input) 278 | 279 | # Verify output matches expected 280 | for key, value in test_case["expected"].items(): 281 | assert getattr(result, key) == value 282 | 283 | # Verify correct number of image URLs in messages 284 | assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == test_case["expected_image_urls"] 285 | assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"]) 286 | 287 | 288 | def test_save_load_pydantic_model(): 289 | """Test saving and loading predictors with pydantic models""" 290 | 291 | class ImageModel(pydantic.BaseModel): 292 | image: dspy.Image 293 | image_list: list[dspy.Image] | None = None 294 | output: str 295 | 296 | class PydanticSignature(dspy.Signature): 297 | model_input: ImageModel = dspy.InputField() 298 | output: str = dspy.OutputField() 299 | 300 | # Create model instance 301 | model_input = ImageModel( 302 | image=dspy.Image("https://example.com/dog.jpg"), 303 | image_list=[dspy.Image("https://example.com/cat.jpg")], 304 | output="Multiple photos", 305 | ) 306 | 307 | # Create example and predictor 308 | examples = [dspy.Example(model_input=model_input, output="Multiple photos").with_inputs("model_input")] 309 | 310 | predictor, lm = setup_predictor(PydanticSignature, {"output": "Multiple photos"}) 311 | optimizer = dspy.teleprompt.LabeledFewShot(k=1) 312 | compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False) 313 | 314 | # Test save and load 315 | with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".json") as temp_file: 316 | compiled_predictor.save(temp_file.name) 317 | loaded_predictor = dspy.Predict(PydanticSignature) 318 | loaded_predictor.load(temp_file.name) 319 | 320 | # Run prediction 321 | result = loaded_predictor(model_input=model_input) 322 | 323 | # Verify output matches expected 324 | assert result.output == "Multiple photos" 325 | assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 4 326 | assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"]) 327 | 328 | 329 | def test_optional_image_field(): 330 | """Test that optional image fields are not required""" 331 | 332 | class OptionalImageSignature(dspy.Signature): 333 | image: dspy.Image | None = dspy.InputField() 334 | output: str = dspy.OutputField() 335 | 336 | predictor, lm = 
setup_predictor(OptionalImageSignature, {"output": "Hello"}) 337 | result = predictor(image=None) 338 | assert result.output == "Hello" 339 | assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 0 340 | 341 | 342 | def test_pdf_url_support(): 343 | """Test support for PDF files from URLs""" 344 | pdf_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf" 345 | 346 | # Create a dspy.Image object from the PDF URL with download=True 347 | pdf_image = dspy.Image(pdf_url, download=True) 348 | 349 | # The data URI should contain application/pdf in the MIME type 350 | assert "data:application/pdf" in pdf_image.url 351 | assert ";base64," in pdf_image.url 352 | 353 | # Test using it in a predictor 354 | class PDFSignature(dspy.Signature): 355 | document: dspy.Image = dspy.InputField(desc="A PDF document") 356 | summary: str = dspy.OutputField(desc="A summary of the PDF") 357 | 358 | predictor, lm = setup_predictor(PDFSignature, {"summary": "This is a dummy PDF"}) 359 | result = predictor(document=pdf_image) 360 | 361 | assert result.summary == "This is a dummy PDF" 362 | assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1 363 | 364 | # Ensure the URL was properly expanded in messages 365 | messages_str = str(lm.history[-1]["messages"]) 366 | assert "application/pdf" in messages_str 367 | 368 | 369 | def test_different_mime_types(): 370 | """Test support for different file types and MIME type detection""" 371 | # Test with various file types 372 | file_urls = { 373 | "pdf": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf", 374 | "image": "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg", 375 | } 376 | 377 | expected_mime_types = { 378 | "pdf": "application/pdf", 379 | "image": "image/jpeg", 380 | } 381 | 382 | for file_type, url in file_urls.items(): 383 | # Download and encode 384 | encoded = encode_image(url, download_images=True) 385 | 386 | # Check for correct MIME type in the encoded data - using 'in' instead of startswith 387 | # to account for possible parameters in the MIME type 388 | assert f"data:{expected_mime_types[file_type]}" in encoded 389 | assert ";base64," in encoded 390 | 391 | 392 | def test_mime_type_from_response_headers(): 393 | """Test that MIME types from response headers are correctly used""" 394 | # This URL returns proper Content-Type header 395 | pdf_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf" 396 | 397 | # Make an actual request to get the content type from headers 398 | response = requests.get(pdf_url) 399 | expected_mime_type = response.headers.get("Content-Type", "") 400 | 401 | # Should be application/pdf or similar 402 | assert "pdf" in expected_mime_type.lower() 403 | 404 | # Encode with download to test MIME type from headers 405 | encoded = encode_image(pdf_url, download_images=True) 406 | 407 | # The encoded data should contain the correct MIME type 408 | assert "application/pdf" in encoded 409 | assert ";base64," in encoded 410 | 411 | 412 | def test_pdf_from_file(): 413 | """Test handling a PDF file from disk""" 414 | # Download a PDF to a temporary file 415 | pdf_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf" 416 | response = requests.get(pdf_url) 417 | response.raise_for_status() 418 | 419 | with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp_file: 420 | tmp_file.write(response.content) 421 | tmp_file_path = tmp_file.name 422 | 423 | try: 424 | # Create a 
dspy.Image from the file 425 | pdf_image = dspy.Image(tmp_file_path) 426 | 427 | # The constructor encodes the file into a data URI we can inspect directly 428 | assert "data:application/pdf" in pdf_image.url 429 | assert ";base64," in pdf_image.url 430 | 431 | # Test the image in a predictor 432 | class FilePDFSignature(dspy.Signature): 433 | document: dspy.Image = dspy.InputField(desc="A PDF document from file") 434 | summary: str = dspy.OutputField(desc="A summary of the PDF") 435 | 436 | predictor, lm = setup_predictor(FilePDFSignature, {"summary": "This is a PDF from file"}) 437 | result = predictor(document=pdf_image) 438 | 439 | assert result.summary == "This is a PDF from file" 440 | assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1 441 | finally: 442 | # Clean up the temporary file 443 | try: 444 | os.unlink(tmp_file_path) 445 | except Exception: 446 | pass 447 | 448 | 449 | def test_image_repr(): 450 | """Test string representation of Image objects""" 451 | url_image = dspy.Image("https://example.com/dog.jpg") 452 | assert str(url_image) == ( 453 | "<<CUSTOM-TYPE-START-IDENTIFIER>>" 454 | '[{"type": "image_url", "image_url": {"url": "https://example.com/dog.jpg"}}]' 455 | "<<CUSTOM-TYPE-END-IDENTIFIER>>" 456 | ) 457 | assert repr(url_image) == "Image(url='https://example.com/dog.jpg')" 458 | 459 | sample_pil = PILImage.new("RGB", (60, 30), color="red") 460 | pil_image = dspy.Image(sample_pil) 461 | assert str(pil_image).startswith('<<CUSTOM-TYPE-START-IDENTIFIER>>[{"type": "image_url",') 462 | assert str(pil_image).endswith("<<CUSTOM-TYPE-END-IDENTIFIER>>") 463 | assert "base64" in str(pil_image) 464 | 465 | 466 | def test_from_methods_warn(tmp_path): 467 | """Deprecated from_* methods emit warnings""" 468 | tmp_file = tmp_path / "test.png" 469 | tmp_file.write_bytes(b"pngdata") 470 | 471 | with pytest.warns(DeprecationWarning): 472 | dspy.Image.from_url("https://example.com/dog.jpg") 473 | with pytest.warns(DeprecationWarning): 474 | dspy.Image.from_file(str(tmp_file)) 475 | sample_pil = PILImage.new("RGB", (10, 10), color="blue") 476 | with pytest.warns(DeprecationWarning): 477 | dspy.Image.from_PIL(sample_pil) 478 | 479 | 480 | def test_invalid_string_format(): 481 | """Test that invalid string formats raise a ValueError""" 482 | invalid_string = "this_is_not_a_url_or_file" 483 | 484 | # Should raise a ValueError and not pass the string through 485 | with pytest.raises(ValueError, match="Unrecognized") as warning_info: 486 | image = dspy.Image(invalid_string) 487 | 488 | def test_pil_image_with_download_parameter(): 489 | """Test behavior when PIL image is passed with download=True""" 490 | sample_pil = PILImage.new("RGB", (60, 30), color="red") 491 | 492 | # PIL image should be encoded regardless of download parameter 493 | image_no_download = dspy.Image(sample_pil) 494 | image_with_download = dspy.Image(sample_pil, download=True) 495 | 496 | # Both should result in base64 encoded data URIs 497 | assert image_no_download.url.startswith("data:") 498 | assert image_with_download.url.startswith("data:") 499 | assert "base64," in image_no_download.url 500 | assert "base64," in image_with_download.url 501 | 502 | # They should be identical since PIL images are always encoded 503 | assert image_no_download.url == image_with_download.url 504 | ``` -------------------------------------------------------------------------------- /dspy/clients/lm.py: -------------------------------------------------------------------------------- ```python 1 | import logging 
2 | import os 3 | import re 4 | import threading 5 | import warnings 6 | from typing import Any, Literal, cast 7 | 8 | import litellm 9 | from anyio.streams.memory import MemoryObjectSendStream 10 | from asyncer import syncify 11 | 12 | import dspy 13 | from dspy.clients.cache import request_cache 14 | from dspy.clients.openai import OpenAIProvider 15 | from dspy.clients.provider import Provider, ReinforceJob, TrainingJob 16 | from dspy.clients.utils_finetune import TrainDataFormat 17 | from dspy.dsp.utils.settings import settings 18 | from dspy.utils.callback import BaseCallback 19 | 20 | from .base_lm import BaseLM 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class LM(BaseLM): 26 | """ 27 | A language model supporting chat or text completion requests for use with DSPy modules. 28 | """ 29 | 30 | def __init__( 31 | self, 32 | model: str, 33 | model_type: Literal["chat", "text", "responses"] = "chat", 34 | temperature: float | None = None, 35 | max_tokens: int | None = None, 36 | cache: bool = True, 37 | callbacks: list[BaseCallback] | None = None, 38 | num_retries: int = 3, 39 | provider: Provider | None = None, 40 | finetuning_model: str | None = None, 41 | launch_kwargs: dict[str, Any] | None = None, 42 | train_kwargs: dict[str, Any] | None = None, 43 | use_developer_role: bool = False, 44 | **kwargs, 45 | ): 46 | """ 47 | Create a new language model instance for use with DSPy modules and programs. 48 | 49 | Args: 50 | model: The model to use. This should be a string of the form ``"llm_provider/llm_name"`` 51 | supported by LiteLLM. For example, ``"openai/gpt-4o"``. 52 | model_type: The type of the model, either ``"chat"``, ``"text"``, or ``"responses"``. 53 | temperature: The sampling temperature to use when generating responses. 54 | max_tokens: The maximum number of tokens to generate per response. 55 | cache: Whether to cache the model responses for reuse to improve performance 56 | and reduce costs. 57 | callbacks: A list of callback functions to run before and after each request. 58 | num_retries: The number of times to retry a request if it fails transiently due to 59 | network error, rate limiting, etc. Requests are retried with exponential 60 | backoff. 61 | provider: The provider to use. If not specified, the provider will be inferred from the model. 62 | finetuning_model: The model to finetune. In some providers, the models available for finetuning are different 63 | from the models available for inference. 64 | rollout_id: Optional integer used to differentiate cache entries for otherwise 65 | identical requests. Different values bypass DSPy's caches while still caching 66 | future calls with the same inputs and rollout ID. Note that `rollout_id` 67 | only affects generation when `temperature` is non-zero. This argument is 68 | stripped before sending requests to the provider. 69 | """ 70 | # Remember to update LM.copy() if you modify the constructor!
71 | self.model = model 72 | self.model_type = model_type 73 | self.cache = cache 74 | self.provider = provider or self.infer_provider() 75 | self.callbacks = callbacks or [] 76 | self.history = [] 77 | self.num_retries = num_retries 78 | self.finetuning_model = finetuning_model 79 | self.launch_kwargs = launch_kwargs or {} 80 | self.train_kwargs = train_kwargs or {} 81 | self.use_developer_role = use_developer_role 82 | self._warned_zero_temp_rollout = False 83 | 84 | # Handle model-specific configuration for different model families 85 | model_family = model.split("/")[-1].lower() if "/" in model else model.lower() 86 | 87 | # Recognize OpenAI reasoning models (o1, o3, o4, gpt-5 family) 88 | model_pattern = re.match(r"^(?:o[1345]|gpt-5)(?:-(?:mini|nano))?", model_family) 89 | 90 | if model_pattern: 91 | 92 | if (temperature and temperature != 1.0) or (max_tokens and max_tokens < 16000): 93 | raise ValueError( 94 | "OpenAI's reasoning models require passing temperature=1.0 or None and max_tokens >= 16000 or None to " 95 | "`dspy.LM(...)`, e.g., dspy.LM('openai/gpt-5', temperature=1.0, max_tokens=16000)" 96 | ) 97 | self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs) 98 | if self.kwargs.get("rollout_id") is None: 99 | self.kwargs.pop("rollout_id", None) 100 | else: 101 | self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) 102 | if self.kwargs.get("rollout_id") is None: 103 | self.kwargs.pop("rollout_id", None) 104 | 105 | self._warn_zero_temp_rollout(self.kwargs.get("temperature"), self.kwargs.get("rollout_id")) 106 | 107 | def _warn_zero_temp_rollout(self, temperature: float | None, rollout_id): 108 | if not self._warned_zero_temp_rollout and rollout_id is not None and (temperature is None or temperature == 0): 109 | warnings.warn( 110 | "rollout_id has no effect when temperature=0; set temperature>0 to bypass the cache.", 111 | stacklevel=3, 112 | ) 113 | self._warned_zero_temp_rollout = True 114 | 115 | def _get_cached_completion_fn(self, completion_fn, cache): 116 | ignored_args_for_cache_key = ["api_key", "api_base", "base_url"] 117 | if cache: 118 | completion_fn = request_cache( 119 | cache_arg_name="request", 120 | ignored_args_for_cache_key=ignored_args_for_cache_key, 121 | )(completion_fn) 122 | 123 | litellm_cache_args = {"no-cache": True, "no-store": True} 124 | 125 | return completion_fn, litellm_cache_args 126 | 127 | def forward(self, prompt=None, messages=None, **kwargs): 128 | # Build the request. 
129 | kwargs = dict(kwargs) 130 | cache = kwargs.pop("cache", self.cache) 131 | 132 | messages = messages or [{"role": "user", "content": prompt}] 133 | if self.use_developer_role and self.model_type == "responses": 134 | messages = [{**m, "role": "developer"} if m.get("role") == "system" else m for m in messages] 135 | kwargs = {**self.kwargs, **kwargs} 136 | self._warn_zero_temp_rollout(kwargs.get("temperature"), kwargs.get("rollout_id")) 137 | if kwargs.get("rollout_id") is None: 138 | kwargs.pop("rollout_id", None) 139 | 140 | if self.model_type == "chat": 141 | completion = litellm_completion 142 | elif self.model_type == "text": 143 | completion = litellm_text_completion 144 | elif self.model_type == "responses": 145 | completion = litellm_responses_completion 146 | completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache) 147 | 148 | results = completion( 149 | request=dict(model=self.model, messages=messages, **kwargs), 150 | num_retries=self.num_retries, 151 | cache=litellm_cache_args, 152 | ) 153 | 154 | self._check_truncation(results) 155 | 156 | if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"): 157 | settings.usage_tracker.add_usage(self.model, dict(results.usage)) 158 | return results 159 | 160 | async def aforward(self, prompt=None, messages=None, **kwargs): 161 | # Build the request. 162 | kwargs = dict(kwargs) 163 | cache = kwargs.pop("cache", self.cache) 164 | 165 | messages = messages or [{"role": "user", "content": prompt}] 166 | if self.use_developer_role and self.model_type == "responses": 167 | messages = [{**m, "role": "developer"} if m.get("role") == "system" else m for m in messages] 168 | kwargs = {**self.kwargs, **kwargs} 169 | self._warn_zero_temp_rollout(kwargs.get("temperature"), kwargs.get("rollout_id")) 170 | if kwargs.get("rollout_id") is None: 171 | kwargs.pop("rollout_id", None) 172 | 173 | if self.model_type == "chat": 174 | completion = alitellm_completion 175 | elif self.model_type == "text": 176 | completion = alitellm_text_completion 177 | elif self.model_type == "responses": 178 | completion = alitellm_responses_completion 179 | completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache) 180 | 181 | results = await completion( 182 | request=dict(model=self.model, messages=messages, **kwargs), 183 | num_retries=self.num_retries, 184 | cache=litellm_cache_args, 185 | ) 186 | 187 | self._check_truncation(results) 188 | 189 | if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"): 190 | settings.usage_tracker.add_usage(self.model, dict(results.usage)) 191 | return results 192 | 193 | def launch(self, launch_kwargs: dict[str, Any] | None = None): 194 | self.provider.launch(self, launch_kwargs) 195 | 196 | def kill(self, launch_kwargs: dict[str, Any] | None = None): 197 | self.provider.kill(self, launch_kwargs) 198 | 199 | def finetune( 200 | self, 201 | train_data: list[dict[str, Any]], 202 | train_data_format: TrainDataFormat | None, 203 | train_kwargs: dict[str, Any] | None = None, 204 | ) -> TrainingJob: 205 | from dspy import settings as settings 206 | 207 | if not self.provider.finetunable: 208 | raise ValueError( 209 | f"Provider {self.provider} does not support fine-tuning, please specify your provider by explicitly " 210 | "setting `provider` when creating the `dspy.LM` instance. For example, " 211 | "`dspy.LM('openai/gpt-4.1-mini-2025-04-14', provider=dspy.OpenAIProvider())`." 
212 | ) 213 | 214 | def thread_function_wrapper(): 215 | return self._run_finetune_job(job) 216 | 217 | thread = threading.Thread(target=thread_function_wrapper) 218 | train_kwargs = train_kwargs or self.train_kwargs 219 | model_to_finetune = self.finetuning_model or self.model 220 | job = self.provider.TrainingJob( 221 | thread=thread, 222 | model=model_to_finetune, 223 | train_data=train_data, 224 | train_data_format=train_data_format, 225 | train_kwargs=train_kwargs, 226 | ) 227 | thread.start() 228 | 229 | return job 230 | 231 | def reinforce( 232 | self, train_kwargs 233 | ) -> ReinforceJob: 234 | # TODO(GRPO Team): Should we return an initialized job here? 235 | from dspy import settings as settings 236 | 237 | err = f"Provider {self.provider} does not implement the reinforcement learning interface." 238 | assert self.provider.reinforceable, err 239 | 240 | job = self.provider.ReinforceJob(lm=self, train_kwargs=train_kwargs) 241 | job.initialize() 242 | return job 243 | 244 | def _run_finetune_job(self, job: TrainingJob): 245 | # TODO(enhance): We should listen for keyboard interrupts somewhere. 246 | # Requires TrainingJob.cancel() to be implemented for each provider. 247 | try: 248 | model = self.provider.finetune( 249 | job=job, 250 | model=job.model, 251 | train_data=job.train_data, 252 | train_data_format=job.train_data_format, 253 | train_kwargs=job.train_kwargs, 254 | ) 255 | lm = self.copy(model=model) 256 | job.set_result(lm) 257 | except Exception as err: 258 | logger.error(err) 259 | job.set_result(err) 260 | 261 | def infer_provider(self) -> Provider: 262 | if OpenAIProvider.is_provider_model(self.model): 263 | return OpenAIProvider() 264 | return Provider() 265 | 266 | def dump_state(self): 267 | state_keys = [ 268 | "model", 269 | "model_type", 270 | "cache", 271 | "num_retries", 272 | "finetuning_model", 273 | "launch_kwargs", 274 | "train_kwargs", 275 | ] 276 | return {key: getattr(self, key) for key in state_keys} | self.kwargs 277 | 278 | def _check_truncation(self, results): 279 | if self.model_type != "responses" and any(c.finish_reason == "length" for c in results["choices"]): 280 | logger.warning( 281 | f"LM response was truncated due to exceeding max_tokens={self.kwargs['max_tokens']}. " 282 | "You can inspect the latest LM interactions with `dspy.inspect_history()`. " 283 | "To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. " 284 | f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) " 285 | " if the reason for truncation is repetition." 286 | ) 287 | 288 | 289 | def _get_stream_completion_fn( 290 | request: dict[str, Any], 291 | cache_kwargs: dict[str, Any], 292 | sync=True, 293 | ): 294 | stream = dspy.settings.send_stream 295 | caller_predict = dspy.settings.caller_predict 296 | 297 | if stream is None: 298 | return None 299 | 300 | # The stream is already opened, and will be closed by the caller. 
301 | stream = cast(MemoryObjectSendStream, stream) 302 | caller_predict_id = id(caller_predict) if caller_predict else None 303 | 304 | if dspy.settings.track_usage: 305 | request["stream_options"] = {"include_usage": True} 306 | 307 | async def stream_completion(request: dict[str, Any], cache_kwargs: dict[str, Any]): 308 | headers = request.pop("headers", None) 309 | response = await litellm.acompletion( 310 | cache=cache_kwargs, 311 | stream=True, 312 | headers=_get_headers(headers), 313 | **request, 314 | ) 315 | chunks = [] 316 | async for chunk in response: 317 | if caller_predict_id: 318 | # Add the predict id to the chunk so that the stream listener can identify which predict produces it. 319 | chunk.predict_id = caller_predict_id 320 | chunks.append(chunk) 321 | await stream.send(chunk) 322 | return litellm.stream_chunk_builder(chunks) 323 | 324 | def sync_stream_completion(): 325 | syncified_stream_completion = syncify(stream_completion) 326 | return syncified_stream_completion(request, cache_kwargs) 327 | 328 | async def async_stream_completion(): 329 | return await stream_completion(request, cache_kwargs) 330 | 331 | if sync: 332 | return sync_stream_completion 333 | else: 334 | return async_stream_completion 335 | 336 | 337 | def litellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): 338 | cache = cache or {"no-cache": True, "no-store": True} 339 | request = dict(request) 340 | request.pop("rollout_id", None) 341 | headers = request.pop("headers", None) 342 | stream_completion = _get_stream_completion_fn(request, cache, sync=True) 343 | if stream_completion is None: 344 | return litellm.completion( 345 | cache=cache, 346 | num_retries=num_retries, 347 | retry_strategy="exponential_backoff_retry", 348 | headers=_get_headers(headers), 349 | **request, 350 | ) 351 | 352 | return stream_completion() 353 | 354 | 355 | def litellm_text_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): 356 | cache = cache or {"no-cache": True, "no-store": True} 357 | request = dict(request) 358 | request.pop("rollout_id", None) 359 | headers = request.pop("headers", None) 360 | # Extract the provider and model from the model string. 361 | # TODO: Not all the models are in the format of "provider/model" 362 | model = request.pop("model").split("/", 1) 363 | provider, model = model[0] if len(model) > 1 else "openai", model[-1] 364 | 365 | # Use the API key and base from the request, or from the environment. 366 | api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") 367 | api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE") 368 | 369 | # Build the prompt from the messages. 
370 | prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"]) 371 | 372 | return litellm.text_completion( 373 | cache=cache, 374 | model=f"text-completion-openai/{model}", 375 | api_key=api_key, 376 | api_base=api_base, 377 | prompt=prompt, 378 | num_retries=num_retries, 379 | retry_strategy="exponential_backoff_retry", 380 | headers=_get_headers(headers), 381 | **request, 382 | ) 383 | 384 | 385 | async def alitellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): 386 | cache = cache or {"no-cache": True, "no-store": True} 387 | request = dict(request) 388 | request.pop("rollout_id", None) 389 | headers = request.pop("headers", None) 390 | stream_completion = _get_stream_completion_fn(request, cache, sync=False) 391 | if stream_completion is None: 392 | return await litellm.acompletion( 393 | cache=cache, 394 | num_retries=num_retries, 395 | retry_strategy="exponential_backoff_retry", 396 | headers=_get_headers(headers), 397 | **request, 398 | ) 399 | 400 | return await stream_completion() 401 | 402 | 403 | async def alitellm_text_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): 404 | cache = cache or {"no-cache": True, "no-store": True} 405 | request = dict(request) 406 | request.pop("rollout_id", None) 407 | model = request.pop("model").split("/", 1) 408 | headers = request.pop("headers", None) 409 | provider, model = model[0] if len(model) > 1 else "openai", model[-1] 410 | 411 | # Use the API key and base from the request, or from the environment. 412 | api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") 413 | api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE") 414 | 415 | # Build the prompt from the messages. 
416 | prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"]) 417 | 418 | return await litellm.atext_completion( 419 | cache=cache, 420 | model=f"text-completion-openai/{model}", 421 | api_key=api_key, 422 | api_base=api_base, 423 | prompt=prompt, 424 | num_retries=num_retries, 425 | retry_strategy="exponential_backoff_retry", 426 | headers=_get_headers(headers), 427 | **request, 428 | ) 429 | 430 | 431 | def litellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): 432 | cache = cache or {"no-cache": True, "no-store": True} 433 | request = dict(request) 434 | request.pop("rollout_id", None) 435 | headers = request.pop("headers", None) 436 | request = _convert_chat_request_to_responses_request(request) 437 | 438 | return litellm.responses( 439 | cache=cache, 440 | num_retries=num_retries, 441 | retry_strategy="exponential_backoff_retry", 442 | headers=_get_headers(headers), 443 | **request, 444 | ) 445 | 446 | 447 | async def alitellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): 448 | cache = cache or {"no-cache": True, "no-store": True} 449 | request = dict(request) 450 | request.pop("rollout_id", None) 451 | headers = request.pop("headers", None) 452 | request = _convert_chat_request_to_responses_request(request) 453 | 454 | return await litellm.aresponses( 455 | cache=cache, 456 | num_retries=num_retries, 457 | retry_strategy="exponential_backoff_retry", 458 | headers=_get_headers(headers), 459 | **request, 460 | ) 461 | 462 | 463 | def _convert_chat_request_to_responses_request(request: dict[str, Any]): 464 | request = dict(request) 465 | if "messages" in request: 466 | content_blocks = [] 467 | for msg in request.pop("messages"): 468 | c = msg.get("content") 469 | if isinstance(c, str): 470 | content_blocks.append({"type": "input_text", "text": c}) 471 | elif isinstance(c, list): 472 | content_blocks.extend(c) 473 | request["input"] = [{"role": msg.get("role", "user"), "content": content_blocks}] 474 | 475 | # Convert `response_format` to `text.format` for Responses API 476 | if "response_format" in request: 477 | response_format = request.pop("response_format") 478 | text = request.pop("text", {}) 479 | request["text"] = {**text, "format": response_format} 480 | 481 | return request 482 | 483 | def _get_headers(headers: dict[str, Any] | None = None): 484 | headers = headers or {} 485 | return { 486 | "User-Agent": f"DSPy/{dspy.__version__}", 487 | **headers, 488 | } 489 | ```
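
For quick reference, below is a minimal, hypothetical usage sketch of the `LM` class defined above. The model names, prompt text, and `rollout_id` value are illustrative assumptions rather than values taken from the repository; the behavior described in the comments follows the constructor checks and `rollout_id` handling shown in `dspy/clients/lm.py`.

```python
import dspy

# Illustrative sketch (not from the repository): the model names, prompt, and
# rollout_id value are assumptions used only to show how the constructor
# arguments interact.
lm = dspy.LM("openai/gpt-4o-mini", temperature=0.7, max_tokens=4000, cache=True)
dspy.configure(lm=lm)

# `rollout_id` is passed through **kwargs; with a non-zero temperature it bypasses
# previously cached responses while still caching future calls that reuse the same
# inputs and rollout ID. With temperature=0 it has no effect, and the LM emits a
# one-time warning.
outputs = lm("Write a one-line greeting.", rollout_id=1)

# OpenAI reasoning models (the o1/o3/o4 and gpt-5 families) must be configured with
# temperature=1.0 (or None) and max_tokens >= 16000 (or None); other values make
# the constructor raise a ValueError.
reasoning_lm = dspy.LM("openai/gpt-5-mini", temperature=1.0, max_tokens=16000)
```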