This is page 11 of 17. Use http://codebase.md/stanfordnlp/dspy?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .github
│   ├── .internal_dspyai
│   │   ├── internals
│   │   │   ├── build-and-release.md
│   │   │   └── release-checklist.md
│   │   └── pyproject.toml
│   ├── .tmp
│   │   └── .generated-actions
│   │       └── run-pypi-publish-in-docker-container
│   │           └── action.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yml
│   │   └── feature_request.yml
│   ├── PULL_REQUEST_TEMPLATE
│   │   └── pull_request_template.md
│   ├── workflow_scripts
│   │   └── install_testpypi_pkg.sh
│   └── workflows
│       ├── build_and_release.yml
│       ├── build_utils
│       │   └── test_version.py
│       ├── docs-push.yml
│       ├── precommits_check.yml
│       └── run_tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── docs
│   ├── .gitignore
│   ├── docs
│   │   ├── api
│   │   │   ├── adapters
│   │   │   │   ├── Adapter.md
│   │   │   │   ├── ChatAdapter.md
│   │   │   │   ├── JSONAdapter.md
│   │   │   │   └── TwoStepAdapter.md
│   │   │   ├── evaluation
│   │   │   │   ├── answer_exact_match.md
│   │   │   │   ├── answer_passage_match.md
│   │   │   │   ├── CompleteAndGrounded.md
│   │   │   │   ├── Evaluate.md
│   │   │   │   ├── EvaluationResult.md
│   │   │   │   └── SemanticF1.md
│   │   │   ├── experimental
│   │   │   │   ├── Citations.md
│   │   │   │   └── Document.md
│   │   │   ├── index.md
│   │   │   ├── models
│   │   │   │   ├── Embedder.md
│   │   │   │   └── LM.md
│   │   │   ├── modules
│   │   │   │   ├── BestOfN.md
│   │   │   │   ├── ChainOfThought.md
│   │   │   │   ├── CodeAct.md
│   │   │   │   ├── Module.md
│   │   │   │   ├── MultiChainComparison.md
│   │   │   │   ├── Parallel.md
│   │   │   │   ├── Predict.md
│   │   │   │   ├── ProgramOfThought.md
│   │   │   │   ├── ReAct.md
│   │   │   │   └── Refine.md
│   │   │   ├── optimizers
│   │   │   │   ├── BetterTogether.md
│   │   │   │   ├── BootstrapFewShot.md
│   │   │   │   ├── BootstrapFewShotWithRandomSearch.md
│   │   │   │   ├── BootstrapFinetune.md
│   │   │   │   ├── BootstrapRS.md
│   │   │   │   ├── COPRO.md
│   │   │   │   ├── Ensemble.md
│   │   │   │   ├── GEPA
│   │   │   │   │   ├── GEPA_Advanced.md
│   │   │   │   │   └── overview.md
│   │   │   │   ├── InferRules.md
│   │   │   │   ├── KNN.md
│   │   │   │   ├── KNNFewShot.md
│   │   │   │   ├── LabeledFewShot.md
│   │   │   │   ├── MIPROv2.md
│   │   │   │   └── SIMBA.md
│   │   │   ├── primitives
│   │   │   │   ├── Audio.md
│   │   │   │   ├── Code.md
│   │   │   │   ├── Example.md
│   │   │   │   ├── History.md
│   │   │   │   ├── Image.md
│   │   │   │   ├── Prediction.md
│   │   │   │   ├── Tool.md
│   │   │   │   └── ToolCalls.md
│   │   │   ├── signatures
│   │   │   │   ├── InputField.md
│   │   │   │   ├── OutputField.md
│   │   │   │   └── Signature.md
│   │   │   ├── tools
│   │   │   │   ├── ColBERTv2.md
│   │   │   │   ├── Embeddings.md
│   │   │   │   └── PythonInterpreter.md
│   │   │   └── utils
│   │   │       ├── asyncify.md
│   │   │       ├── configure_cache.md
│   │   │       ├── disable_litellm_logging.md
│   │   │       ├── disable_logging.md
│   │   │       ├── enable_litellm_logging.md
│   │   │       ├── enable_logging.md
│   │   │       ├── inspect_history.md
│   │   │       ├── load.md
│   │   │       ├── StatusMessage.md
│   │   │       ├── StatusMessageProvider.md
│   │   │       ├── streamify.md
│   │   │       └── StreamListener.md
│   │   ├── cheatsheet.md
│   │   ├── community
│   │   │   ├── community-resources.md
│   │   │   ├── how-to-contribute.md
│   │   │   └── use-cases.md
│   │   ├── deep-dive
│   │   │   └── data-handling
│   │   │       ├── built-in-datasets.md
│   │   │       ├── examples.md
│   │   │       ├── img
│   │   │       │   └── data-loading.png
│   │   │       └── loading-custom-data.md
│   │   ├── faqs.md
│   │   ├── index.md
│   │   ├── js
│   │   │   └── runllm-widget.js
│   │   ├── learn
│   │   │   ├── evaluation
│   │   │   │   ├── data.md
│   │   │   │   ├── metrics.md
│   │   │   │   └── overview.md
│   │   │   ├── figures
│   │   │   │   ├── native_tool_call.png
│   │   │   │   └── teleprompter-classes.png
│   │   │   ├── index.md
│   │   │   ├── optimization
│   │   │   │   ├── optimizers.md
│   │   │   │   └── overview.md
│   │   │   └── programming
│   │   │       ├── 7-assertions.md
│   │   │       ├── adapters.md
│   │   │       ├── language_models.md
│   │   │       ├── mcp.md
│   │   │       ├── modules.md
│   │   │       ├── overview.md
│   │   │       ├── signatures.md
│   │   │       └── tools.md
│   │   ├── production
│   │   │   └── index.md
│   │   ├── roadmap.md
│   │   ├── static
│   │   │   ├── .nojekyll
│   │   │   └── img
│   │   │       ├── dspy_logo.png
│   │   │       ├── logo.png
│   │   │       ├── mlflow-tracing-rag.png
│   │   │       ├── modular.png
│   │   │       ├── optimize.png
│   │   │       ├── undraw_docusaurus_mountain.svg
│   │   │       ├── undraw_docusaurus_react.svg
│   │   │       ├── undraw_docusaurus_tree.svg
│   │   │       └── universal_compatibility.png
│   │   ├── stylesheets
│   │   │   └── extra.css
│   │   └── tutorials
│   │       ├── agents
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-agent.png
│   │       ├── ai_text_game
│   │       │   └── index.md
│   │       ├── async
│   │       │   └── index.md
│   │       ├── audio
│   │       │   └── index.ipynb
│   │       ├── build_ai_program
│   │       │   └── index.md
│   │       ├── cache
│   │       │   └── index.md
│   │       ├── classification
│   │       │   └── index.md
│   │       ├── classification_finetuning
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-classification.png
│   │       ├── conversation_history
│   │       │   └── index.md
│   │       ├── core_development
│   │       │   └── index.md
│   │       ├── custom_module
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-custom-module.png
│   │       ├── customer_service_agent
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-customer-service-agent.png
│   │       ├── deployment
│   │       │   ├── dspy_mlflow_ui.png
│   │       │   └── index.md
│   │       ├── email_extraction
│   │       │   ├── index.md
│   │       │   └── mlflow-tracing-email-extraction.png
│   │       ├── entity_extraction
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-entity-extraction.png
│   │       ├── games
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-agent.png
│   │       ├── gepa_ai_program
│   │       │   └── index.md
│   │       ├── gepa_aime
│   │       │   ├── index.ipynb
│   │       │   ├── mlflow-tracing-gepa-aime.png
│   │       │   └── mlflow-tracking-gepa-aime-optimization.png
│   │       ├── gepa_facilitysupportanalyzer
│   │       │   ├── index.ipynb
│   │       │   ├── mlflow-tracing-gepa-support.png
│   │       │   └── mlflow-tracking-gepa-support-optimization.png
│   │       ├── gepa_papillon
│   │       │   ├── index.ipynb
│   │       │   ├── mlflow-tracing-gepa-papilon.png
│   │       │   └── mlflow-tracking-gepa-papilon-optimization.png
│   │       ├── image_generation_prompting
│   │       │   └── index.ipynb
│   │       ├── index.md
│   │       ├── llms_txt_generation
│   │       │   └── index.md
│   │       ├── math
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-math.png
│   │       ├── mcp
│   │       │   └── index.md
│   │       ├── mem0_react_agent
│   │       │   └── index.md
│   │       ├── multihop_search
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-multi-hop.png
│   │       ├── observability
│   │       │   ├── index.md
│   │       │   ├── mlflow_trace_ui_navigation.gif
│   │       │   ├── mlflow_trace_ui.png
│   │       │   └── mlflow_trace_view.png
│   │       ├── optimize_ai_program
│   │       │   └── index.md
│   │       ├── optimizer_tracking
│   │       │   ├── child_run.png
│   │       │   ├── experiment.png
│   │       │   ├── index.md
│   │       │   └── parent_run.png
│   │       ├── output_refinement
│   │       │   └── best-of-n-and-refine.md
│   │       ├── papillon
│   │       │   └── index.md
│   │       ├── program_of_thought
│   │       │   └── index.ipynb
│   │       ├── rag
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-rag.png
│   │       ├── real_world_examples
│   │       │   └── index.md
│   │       ├── rl_ai_program
│   │       │   └── index.md
│   │       ├── rl_multihop
│   │       │   └── index.ipynb
│   │       ├── rl_papillon
│   │       │   └── index.ipynb
│   │       ├── sample_code_generation
│   │       │   └── index.md
│   │       ├── saving
│   │       │   └── index.md
│   │       ├── streaming
│   │       │   └── index.md
│   │       ├── tool_use
│   │       │   └── index.ipynb
│   │       └── yahoo_finance_react
│   │           └── index.md
│   ├── mkdocs.yml
│   ├── overrides
│   │   ├── home.html
│   │   ├── main.html
│   │   └── partials
│   │       └── tabs.html
│   ├── Pipfile
│   ├── Pipfile.lock
│   ├── README.md
│   ├── requirements.txt
│   ├── scripts
│   │   ├── generate_api_docs.py
│   │   └── generate_api_summary.py
│   └── vercel.json
├── dspy
│   ├── __init__.py
│   ├── __metadata__.py
│   ├── adapters
│   │   ├── __init__.py
│   │   ├── baml_adapter.py
│   │   ├── base.py
│   │   ├── chat_adapter.py
│   │   ├── json_adapter.py
│   │   ├── two_step_adapter.py
│   │   ├── types
│   │   │   ├── __init__.py
│   │   │   ├── audio.py
│   │   │   ├── base_type.py
│   │   │   ├── citation.py
│   │   │   ├── code.py
│   │   │   ├── document.py
│   │   │   ├── history.py
│   │   │   ├── image.py
│   │   │   └── tool.py
│   │   ├── utils.py
│   │   └── xml_adapter.py
│   ├── clients
│   │   ├── __init__.py
│   │   ├── base_lm.py
│   │   ├── cache.py
│   │   ├── databricks.py
│   │   ├── embedding.py
│   │   ├── lm_local_arbor.py
│   │   ├── lm_local.py
│   │   ├── lm.py
│   │   ├── openai.py
│   │   ├── provider.py
│   │   └── utils_finetune.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── alfworld
│   │   │   ├── __init__.py
│   │   │   ├── alfworld.py
│   │   │   └── base_config.yml
│   │   ├── colors.py
│   │   ├── dataloader.py
│   │   ├── dataset.py
│   │   ├── gsm8k.py
│   │   ├── hotpotqa.py
│   │   └── math.py
│   ├── dsp
│   │   ├── __init__.py
│   │   ├── colbertv2.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── dpr.py
│   │       ├── settings.py
│   │       └── utils.py
│   ├── evaluate
│   │   ├── __init__.py
│   │   ├── auto_evaluation.py
│   │   ├── evaluate.py
│   │   └── metrics.py
│   ├── experimental
│   │   └── __init__.py
│   ├── predict
│   │   ├── __init__.py
│   │   ├── aggregation.py
│   │   ├── avatar
│   │   │   ├── __init__.py
│   │   │   ├── avatar.py
│   │   │   ├── models.py
│   │   │   └── signatures.py
│   │   ├── best_of_n.py
│   │   ├── chain_of_thought.py
│   │   ├── code_act.py
│   │   ├── knn.py
│   │   ├── multi_chain_comparison.py
│   │   ├── parallel.py
│   │   ├── parameter.py
│   │   ├── predict.py
│   │   ├── program_of_thought.py
│   │   ├── react.py
│   │   ├── refine.py
│   │   └── retry.py
│   ├── primitives
│   │   ├── __init__.py
│   │   ├── base_module.py
│   │   ├── example.py
│   │   ├── module.py
│   │   ├── prediction.py
│   │   ├── python_interpreter.py
│   │   └── runner.js
│   ├── propose
│   │   ├── __init__.py
│   │   ├── dataset_summary_generator.py
│   │   ├── grounded_proposer.py
│   │   ├── propose_base.py
│   │   └── utils.py
│   ├── retrievers
│   │   ├── __init__.py
│   │   ├── databricks_rm.py
│   │   ├── embeddings.py
│   │   ├── retrieve.py
│   │   └── weaviate_rm.py
│   ├── signatures
│   │   ├── __init__.py
│   │   ├── field.py
│   │   ├── signature.py
│   │   └── utils.py
│   ├── streaming
│   │   ├── __init__.py
│   │   ├── messages.py
│   │   ├── streamify.py
│   │   └── streaming_listener.py
│   ├── teleprompt
│   │   ├── __init__.py
│   │   ├── avatar_optimizer.py
│   │   ├── bettertogether.py
│   │   ├── bootstrap_finetune.py
│   │   ├── bootstrap_trace.py
│   │   ├── bootstrap.py
│   │   ├── copro_optimizer.py
│   │   ├── ensemble.py
│   │   ├── gepa
│   │   │   ├── __init__.py
│   │   │   ├── gepa_utils.py
│   │   │   ├── gepa.py
│   │   │   └── instruction_proposal.py
│   │   ├── grpo.py
│   │   ├── infer_rules.py
│   │   ├── knn_fewshot.py
│   │   ├── mipro_optimizer_v2.py
│   │   ├── random_search.py
│   │   ├── signature_opt.py
│   │   ├── simba_utils.py
│   │   ├── simba.py
│   │   ├── teleprompt_optuna.py
│   │   ├── teleprompt.py
│   │   ├── utils.py
│   │   └── vanilla.py
│   └── utils
│       ├── __init__.py
│       ├── annotation.py
│       ├── asyncify.py
│       ├── caching.py
│       ├── callback.py
│       ├── dummies.py
│       ├── exceptions.py
│       ├── hasher.py
│       ├── inspect_history.py
│       ├── langchain_tool.py
│       ├── logging_utils.py
│       ├── mcp.py
│       ├── parallelizer.py
│       ├── saving.py
│       ├── syncify.py
│       ├── unbatchify.py
│       └── usage_tracker.py
├── LICENSE
├── pyproject.toml
├── README.md
├── tests
│   ├── __init__.py
│   ├── adapters
│   │   ├── test_adapter_utils.py
│   │   ├── test_baml_adapter.py
│   │   ├── test_base_type.py
│   │   ├── test_chat_adapter.py
│   │   ├── test_citation.py
│   │   ├── test_code.py
│   │   ├── test_document.py
│   │   ├── test_json_adapter.py
│   │   ├── test_tool.py
│   │   ├── test_two_step_adapter.py
│   │   └── test_xml_adapter.py
│   ├── callback
│   │   └── test_callback.py
│   ├── clients
│   │   ├── test_cache.py
│   │   ├── test_databricks.py
│   │   ├── test_embedding.py
│   │   ├── test_inspect_global_history.py
│   │   └── test_lm.py
│   ├── conftest.py
│   ├── datasets
│   │   └── test_dataset.py
│   ├── docs
│   │   └── test_mkdocs_links.py
│   ├── evaluate
│   │   ├── test_evaluate.py
│   │   └── test_metrics.py
│   ├── examples
│   │   └── test_baleen.py
│   ├── metadata
│   │   └── test_metadata.py
│   ├── predict
│   │   ├── test_aggregation.py
│   │   ├── test_best_of_n.py
│   │   ├── test_chain_of_thought.py
│   │   ├── test_code_act.py
│   │   ├── test_knn.py
│   │   ├── test_multi_chain_comparison.py
│   │   ├── test_parallel.py
│   │   ├── test_predict.py
│   │   ├── test_program_of_thought.py
│   │   ├── test_react.py
│   │   ├── test_refine.py
│   │   └── test_retry.py
│   ├── primitives
│   │   ├── resources
│   │   │   └── saved_program.json
│   │   ├── test_base_module.py
│   │   ├── test_example.py
│   │   ├── test_module.py
│   │   └── test_python_interpreter.py
│   ├── propose
│   │   └── test_grounded_proposer.py
│   ├── README.md
│   ├── reliability
│   │   ├── __init__.py
│   │   ├── complex_types
│   │   │   └── generated
│   │   │       ├── test_many_types_1
│   │   │       │   ├── inputs
│   │   │       │   │   ├── input1.json
│   │   │       │   │   └── input2.json
│   │   │       │   ├── program.py
│   │   │       │   └── schema.json
│   │   │       ├── test_nesting_1
│   │   │       │   ├── inputs
│   │   │       │   │   ├── input1.json
│   │   │       │   │   └── input2.json
│   │   │       │   ├── program.py
│   │   │       │   └── schema.json
│   │   │       └── test_nesting_2
│   │   │           ├── inputs
│   │   │           │   └── input1.json
│   │   │           ├── program.py
│   │   │           └── schema.json
│   │   ├── conftest.py
│   │   ├── generate
│   │   │   ├── __init__.py
│   │   │   ├── __main__.py
│   │   │   └── utils.py
│   │   ├── input_formats
│   │   │   └── generated
│   │   │       └── test_markdown_1
│   │   │           ├── inputs
│   │   │           │   ├── input1.json
│   │   │           │   └── input2.json
│   │   │           ├── program.py
│   │   │           └── schema.json
│   │   ├── README.md
│   │   ├── reliability_conf.yaml
│   │   ├── test_generated.py
│   │   ├── test_pydantic_models.py
│   │   └── utils.py
│   ├── retrievers
│   │   └── test_embeddings.py
│   ├── signatures
│   │   ├── test_adapter_image.py
│   │   ├── test_custom_types.py
│   │   └── test_signature.py
│   ├── streaming
│   │   └── test_streaming.py
│   ├── teleprompt
│   │   ├── gepa_dummy_lm_custom_component_selector_custom_instruction_proposer.json
│   │   ├── gepa_dummy_lm.json
│   │   ├── test_bootstrap_finetune.py
│   │   ├── test_bootstrap_trace.py
│   │   ├── test_bootstrap.py
│   │   ├── test_copro_optimizer.py
│   │   ├── test_ensemble.py
│   │   ├── test_finetune.py
│   │   ├── test_gepa_instruction_proposer.py
│   │   ├── test_gepa.py
│   │   ├── test_grpo.py
│   │   ├── test_knn_fewshot.py
│   │   ├── test_random_search.py
│   │   ├── test_teleprompt.py
│   │   └── test_utils.py
│   ├── test_utils
│   │   ├── __init__.py
│   │   └── server
│   │       ├── __init__.py
│   │       ├── litellm_server_config.yaml
│   │       └── litellm_server.py
│   └── utils
│       ├── __init__.py
│       ├── resources
│       │   └── mcp_server.py
│       ├── test_annotation.py
│       ├── test_asyncify.py
│       ├── test_exceptions.py
│       ├── test_langchain_tool.py
│       ├── test_mcp.py
│       ├── test_parallelizer.py
│       ├── test_saving.py
│       ├── test_settings.py
│       ├── test_syncify.py
│       ├── test_unbatchify.py
│       └── test_usage_tracker.py
└── uv.lock
```

# Files

--------------------------------------------------------------------------------
/dspy/teleprompt/simba.py:
--------------------------------------------------------------------------------

```python
  1 | from __future__ import annotations
  2 | 
  3 | import logging
  4 | import random
  5 | from typing import Any, Callable
  6 | 
  7 | import numpy as np
  8 | 
  9 | import dspy
 10 | from dspy.teleprompt.simba_utils import append_a_demo, append_a_rule, prepare_models_for_resampling, wrap_program
 11 | from dspy.teleprompt.teleprompt import Teleprompter
 12 | 
 13 | logger = logging.getLogger(__name__)
 14 | 
 15 | 
 16 | class SIMBA(Teleprompter):
 17 |     """
 18 |     SIMBA (Stochastic Introspective Mini-Batch Ascent) optimizer for DSPy.
 19 |     
 20 |     SIMBA is a DSPy optimizer that uses the LLM to analyze its own performance and 
 21 |     generate improvement rules. It samples mini-batches, identifies challenging examples 
 22 |     with high output variability, then either creates self-reflective rules or adds 
 23 |     successful examples as demonstrations.
 24 |     
 25 |     For more details, see: https://dspy.ai/api/optimizers/SIMBA/
 26 |     """
 27 | 
 28 |     def __init__(
 29 |         self,
 30 |         *,
 31 |         metric: Callable[[dspy.Example, dict[str, Any]], float],
 32 |         bsize: int = 32,
 33 |         num_candidates: int = 6,
 34 |         max_steps: int = 8,
 35 |         max_demos: int = 4,
 36 |         prompt_model: dspy.LM | None = None,
 37 |         teacher_settings: dict | None = None,
 38 |         demo_input_field_maxlen: int = 100_000,
 39 |         num_threads: int | None = None,
 40 |         temperature_for_sampling: float = 0.2,
 41 |         temperature_for_candidates: float = 0.2,
 42 |     ) -> None:
 43 |         """
 44 |         Initializes SIMBA.
 45 | 
 46 |         Args:
 47 |             metric: A function that takes an Example and a prediction_dict
 48 |                 as input and returns a float.
 49 |             bsize: Mini-batch size. Defaults to 32.
 50 |             num_candidates: Number of new candidate programs to produce
 51 |                 per iteration. Defaults to 6.
 52 |             max_steps: Number of optimization steps to run. Defaults to 8.
 53 |             max_demos: Maximum number of demos a predictor can hold
 54 |                 before dropping some. Defaults to 4.
 55 |             prompt_model: The model to use to evolve the program. When `prompt_model is None`, the globally configured
 56 |                 lm is used.
 57 |             teacher_settings: Settings for the teacher model. Defaults to None.
 58 |             demo_input_field_maxlen: Maximum number of characters to keep
 59 |                 in an input field when building a new demo. Defaults to 100,000.
 60 |             num_threads: Number of threads for parallel execution.
 61 |                 Defaults to None.
 62 |             temperature_for_sampling: Temperature used for picking
 63 |                 programs during the trajectory-sampling step. Defaults to 0.2.
 64 |             temperature_for_candidates: Temperature used for picking
 65 |                 the source program for building new candidates. Defaults to 0.2.
 66 |         """
 67 |         self.metric = metric
 68 |         self.bsize = bsize
 69 |         self.num_candidates = num_candidates
 70 |         self.max_steps = max_steps
 71 |         self.max_demos = max_demos
 72 |         self.prompt_model = prompt_model or dspy.settings.lm
 73 |         self.teacher_settings = teacher_settings
 74 |         self.demo_input_field_maxlen = demo_input_field_maxlen
 75 |         self.num_threads = num_threads
 76 | 
 77 |         self.temperature_for_sampling = temperature_for_sampling
 78 |         self.temperature_for_candidates = temperature_for_candidates
 79 | 
 80 |         if self.max_demos > 0:
 81 |             self.strategies = [append_a_demo(demo_input_field_maxlen), append_a_rule]
 82 |         else:
 83 |             self.strategies = [append_a_rule]
 84 | 
 85 |     def compile(
 86 |         self,
 87 |         student: dspy.Module,
 88 |         *,
 89 |         trainset: list[dspy.Example],
 90 |         seed: int = 0
 91 |     ) -> dspy.Module:
 92 |         """
 93 |         Compile and optimize the student module using SIMBA.
 94 |         
 95 |         Args:
 96 |             student: The module to optimize
 97 |             trainset: Training examples for optimization
 98 |             seed: Random seed for reproducibility
 99 |             
100 |         Returns:
101 |             The optimized module with candidate_programs and trial_logs attached
102 |         """
103 |         # Basic checks
104 |         assert len(trainset) >= self.bsize, f"Trainset too small: {len(trainset)} < {self.bsize}"
105 | 
106 |         # Initialize RNG
107 |         rng = random.Random(seed)
108 |         rng_np = np.random.default_rng(seed)
109 | 
110 |         programs = []
111 |         program_scores = {}
112 |         next_program_idx = 0
113 | 
114 |         # Helper functions
115 |         def calc_average_score(prog_idx: int) -> float:
116 |             scores = program_scores.get(prog_idx, [])
117 |             if not scores:
118 |                 return 0.0
119 |             return sum(scores) / len(scores)
120 | 
121 |         def top_k_plus_baseline(k: int) -> list[int]:
122 |             # Sort all programs by descending average score
123 |             scored_programs = sorted(programs, key=lambda p: calc_average_score(p.simba_idx), reverse=True)
124 |             top_k = [p.simba_idx for p in scored_programs[:k]]
125 |             # Ensure baseline=0 is in there:
126 |             if 0 not in top_k and len(top_k) > 0:
127 |                 top_k[-1] = 0
128 |             return list(dict.fromkeys(top_k))
129 | 
130 |         def softmax_sample(rng_obj: random.Random, program_idxs: list[int], temperature: float) -> int:
131 |             if not program_idxs:
132 |                 raise ValueError("No programs available for softmax sampling.")
133 | 
134 |             # Unnormalized weights
135 |             scores = [calc_average_score(idx) for idx in program_idxs]
136 |             exps = [np.exp(s / temperature) for s in scores]
137 |             sum_exps = sum(exps)
138 |             if sum_exps <= 0:
139 |                 # Fallback: uniform if all exps are zero
140 |                 return rng_obj.choice(program_idxs)
141 | 
142 |             # Weighted random choice
143 |             probs = [val / sum_exps for val in exps]
144 |             return rng_obj.choices(program_idxs, weights=probs, k=1)[0]
145 | 
146 |         def register_new_program(prog: dspy.Module, score_list: list[float]) -> None:
147 |             nonlocal next_program_idx
148 |             next_program_idx += 1
149 |             new_idx = next_program_idx
150 |             prog.simba_idx = new_idx
151 |             programs.append(prog)
152 |             program_scores[new_idx] = score_list
153 | 
154 |         # Initialize the baseline program: index=0
155 |         student = student.deepcopy()
156 |         student.simba_idx = 0
157 |         programs.append(student)
158 |         program_scores[0] = []
159 | 
160 |         winning_programs = [student]
161 | 
162 |         # Data shuffling
163 |         data_indices = list(range(len(trainset)))
164 |         rng.shuffle(data_indices)
165 |         instance_idx = 0
166 | 
167 |         # Parallel runner
168 |         run_parallel = dspy.Parallel(access_examples=False, num_threads=self.num_threads)
169 | 
170 |         trial_logs = {}
171 |         for batch_idx in range(self.max_steps):
172 |             trial_logs[batch_idx] = {}
173 | 
174 |             logger.info(f"Starting batch {batch_idx+1} of {self.max_steps}.")
175 | 
176 |             # STEP 1: Get next batch
177 |             if instance_idx + self.bsize > len(trainset):
178 |                 rng.shuffle(data_indices)
179 |                 instance_idx = 0
180 | 
181 |             batch_indices = data_indices[instance_idx : instance_idx + self.bsize]
182 |             batch = [trainset[i] for i in batch_indices]
183 |             instance_idx += self.bsize
184 | 
185 |             # We'll generate (program, model) pairs for the trajectory sampling.
186 |             # Prepare distinct LMs (with different temperatures, etc.) from the baseline=programs[0].
187 |             models = prepare_models_for_resampling(programs[0], self.num_candidates, self.teacher_settings)
188 |             top_programs = top_k_plus_baseline(self.num_candidates)
189 | 
190 |             exec_pairs = []
191 |             predictor2name = {}
192 | 
193 |             # For each model, for each example, pick a program from the pool via softmax
194 |             for model in models:
195 |                 for example in batch:
196 |                     chosen_prog_idx = softmax_sample(rng, top_programs, self.temperature_for_sampling)
197 |                     candidate_system = programs[chosen_prog_idx].deepcopy()
198 |                     candidate_system.set_lm(model)
199 | 
200 |                     for name, predictor in candidate_system.named_predictors():
201 |                         predictor2name[id(predictor)] = name
202 | 
203 |                     # Use the special wrap that includes the 'example' in the output
204 |                     wrapped_candidate_system = wrap_program(candidate_system, self.metric)
205 |                     exec_pairs.append((wrapped_candidate_system, example))
206 | 
207 |             # STEP 2: Execute
208 |             logger.info(f"Sampling program trajectories on {self.bsize} examples x {self.num_candidates} samples.")
209 |             outputs = run_parallel(exec_pairs)
210 |             assert len(outputs) == len(exec_pairs) == self.bsize * self.num_candidates
211 | 
212 |             # STEP 3: Sort the training buckets by (max-to-min gap, max score, and max-to-avg gap).
213 |             buckets = []
214 |             largest_max_to_avg_gap = float("-inf")
215 |             batch_10th_percentile_score = np.percentile([float(o["score"]) for o in outputs], 10)
216 |             batch_90th_percentile_score = np.percentile([float(o["score"]) for o in outputs], 90)
217 | 
218 |             # We'll chunk `outputs` by example index, each chunk has length = num_candidates
219 |             for idx, _ in enumerate(batch):
220 |                 # gather all results for this example
221 |                 bucket = [outputs[i] for i in range(idx, len(outputs), self.bsize)]
222 |                 bucket.sort(key=lambda x: x["score"], reverse=True)
223 | 
224 |                 max_score = float(bucket[0]["score"])
225 |                 min_score = float(bucket[-1]["score"])
226 |                 avg_score = sum(x["score"] for x in bucket) / len(bucket)
227 |                 max_to_min_gap = max_score - min_score
228 |                 max_to_avg_gap = max_score - avg_score
229 |                 if max_to_avg_gap > largest_max_to_avg_gap:
230 |                     largest_max_to_avg_gap = max_to_avg_gap
231 | 
232 |                 buckets.append((bucket, (max_to_min_gap, max_score, max_to_avg_gap)))
233 | 
234 |             # sort the buckets
235 |             buckets.sort(key=lambda x: x[1], reverse=True)
236 | 
237 |             # Baseline for the batch is just the average of all runs
238 |             all_scores_in_this_batch = [o["score"] for o in outputs]
239 |             baseline_score = sum(all_scores_in_this_batch) / len(all_scores_in_this_batch)
240 |             logger.info(f"Batch {batch_idx+1}: Baseline mini-batch score: {baseline_score}\n")
241 | 
242 |             # STEP 4: Build new candidate programs by applying a strategy to some top buckets.
243 |             system_candidates = []
244 |             for bucket_idx, (bucket, bucket_stats) in enumerate(buckets):
245 |                 max_to_min_gap, max_score, max_to_avg_gap = bucket_stats
246 |                 logger.info(
247 |                     f"Batch {batch_idx+1}: Processing bucket #{bucket_idx+1}, with max score {max_score}, "
248 |                     f"max-to-min gap {max_to_min_gap}, and max-to-avg gap {max_to_avg_gap}."
249 |                 )
250 | 
251 |                 # pick source program
252 |                 src_prog_idx = softmax_sample(
253 |                     rng, top_k_plus_baseline(self.num_candidates), self.temperature_for_candidates
254 |                 )
255 |                 system_candidate = programs[src_prog_idx].deepcopy()
256 | 
257 |                 # Drop some demos from each predictor
258 |                 name2predictor = {}
259 |                 num_demos_list = []
260 | 
261 |                 max_demos_tmp = self.max_demos if self.max_demos > 0 else 3
262 | 
263 |                 for name, predictor in system_candidate.named_predictors():
264 |                     name2predictor[name] = predictor
265 |                     num_demos_list.append(len(predictor.demos))
266 | 
267 |                 num_demos = max(num_demos_list) if num_demos_list else 0
268 |                 num_demos_to_drop = max(rng_np.poisson(num_demos / max_demos_tmp), int(num_demos >= max_demos_tmp))
269 |                 num_demos_to_drop = min(num_demos_to_drop, num_demos)
270 |                 demos_to_drop = [rng.randrange(num_demos) for _ in range(num_demos_to_drop)]
271 | 
272 |                 for _, predictor in name2predictor.items():
273 |                     predictor.demos = [demo for idxd, demo in enumerate(predictor.demos) if idxd not in demos_to_drop]
274 | 
275 |                 # Pick a strategy
276 |                 strategy = rng.choice(self.strategies)
277 |                 logger.info(
278 |                     f"Batch {batch_idx+1}: Invoking strategy: {strategy.__name__}"
279 |                     + (f", having dropped {num_demos_to_drop} demos per predictor" if num_demos_to_drop else "")
280 |                 )
281 | 
282 |                 try:
283 |                     strategy(
284 |                         bucket,
285 |                         system_candidate,
286 |                         predictor2name=predictor2name,
287 |                         name2predictor=name2predictor,
288 |                         batch_10p_score=batch_10th_percentile_score,
289 |                         batch_90p_score=batch_90th_percentile_score,
290 |                         prompt_model=self.prompt_model,
291 |                     )
292 |                 except Exception as e:
293 |                     logger.error(f"Strategy failed with error: {e}")
294 |                     continue
295 | 
296 |                 system_candidates.append(system_candidate)
297 |                 logger.info("\n")
298 | 
299 |                 if len(system_candidates) >= self.num_candidates + 1:
300 |                     break
301 | 
302 |             # STEP 5: Evaluate these new system_candidates on the same mini-batch
303 |             logger.info(f"Batch {batch_idx+1}: Evaluating {len(system_candidates)} programs on {self.bsize} examples.")
304 | 
305 |             exec_pairs = [(wrap_program(sys, self.metric), ex) for sys in system_candidates for ex in batch]
306 |             outputs = run_parallel(exec_pairs)
307 |             assert len(outputs) == len(exec_pairs) == len(system_candidates) * self.bsize
308 | 
309 |             # STEP 6: Compute average mini-batch scores for each new candidate
310 |             candidate_scores = []
311 |             for idx_cand, _ in enumerate(system_candidates):
312 |                 start = idx_cand * self.bsize
313 |                 end = (idx_cand + 1) * self.bsize
314 |                 sys_scores = [outputs[i]["score"] for i in range(start, end)]
315 |                 avg_sys_score = sum(sys_scores) / len(sys_scores)
316 |                 candidate_scores.append(avg_sys_score)
317 | 
318 |             logger.info(
319 |                 f"Scores after {batch_idx+1} batches: {candidate_scores}, "
320 |                 f"Best: {max(candidate_scores) if candidate_scores else 'N/A'}\n"
321 |             )
322 | 
323 |             # STEP 7: Select the best among these new ones for "winning" record
324 |             if candidate_scores:
325 |                 best_idx_among_candidates = candidate_scores.index(max(candidate_scores))
326 |                 best_program = system_candidates[best_idx_among_candidates]
327 |                 winning_programs.append(best_program.deepcopy())
328 | 
329 |             # STEP 8: Register all new candidate systems in our global pool
330 |             for idx_cand, cand_sys in enumerate(system_candidates):
331 |                 start = idx_cand * self.bsize
332 |                 end = (idx_cand + 1) * self.bsize
333 |                 sys_scores = [outputs[i]["score"] for i in range(start, end)]
334 |                 register_new_program(cand_sys, sys_scores)
335 | 
336 |         M = len(winning_programs) - 1  # noqa: N806
337 |         N = self.num_candidates + 1  # noqa: N806
338 |         if M < 1:
339 |             program_idxs = [0] * N
340 |         else:
341 |             program_idxs = [round(i * M / (N - 1)) for i in range(N)]
342 | 
343 |         program_idxs = list(dict.fromkeys(program_idxs))
344 | 
345 |         candidate_programs = [winning_programs[i].deepcopy() for i in program_idxs]
346 |         logger.info(f"VALIDATION: Evaluating {len(candidate_programs)} programs on the full trainset.")
347 |         exec_pairs = [(wrap_program(sys, self.metric), ex) for sys in candidate_programs for ex in trainset]
348 |         outputs = run_parallel(exec_pairs)
349 | 
350 |         scores = []
351 |         for idx_prog, _ in enumerate(candidate_programs):
352 |             start = idx_prog * len(trainset)
353 |             end = (idx_prog + 1) * len(trainset)
354 |             sys_scores = [outputs[i]["score"] for i in range(start, end)]
355 |             avg_score = sum(sys_scores) / len(sys_scores) if sys_scores else 0.0
356 |             scores.append(avg_score)
357 |             if idx_prog != 0:
358 |                 trial_logs[idx_prog - 1]["train_score"] = avg_score
359 | 
360 |         # Build sorted list of {"score", "program"} dicts
361 |         assert len(scores) == len(candidate_programs)
362 |         candidate_data = [{"score": s, "program": p} for s, p in zip(scores, candidate_programs, strict=False)]
363 |         candidate_data.sort(key=lambda x: x["score"], reverse=True)
364 | 
365 |         best_idx = scores.index(max(scores)) if scores else 0
366 |         best_program = candidate_programs[best_idx].deepcopy()
367 |         logger.info(
368 |             f"Final trainset scores: {scores}, Best: {max(scores) if scores else 'N/A'} "
369 |             f"(at index {best_idx if scores else 'N/A'})\n\n\n"
370 |         )
371 | 
372 |         # Attach sorted, scored candidates & logs
373 |         best_program.candidate_programs = candidate_data
374 |         best_program.trial_logs = trial_logs
375 | 
376 |         return best_program
377 | 
```

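A minimal usage sketch for the `SIMBA` optimizer above, assuming a placeholder LM, metric, and dataset (`my_qa_pairs` is hypothetical); the constructor and `compile` arguments follow the signatures in `dspy/teleprompt/simba.py`.

```python
# Minimal SIMBA sketch; the model id, metric, and dataset below are placeholders.
import dspy

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))  # any LiteLLM-style model id

def exact_match(example, prediction):
    # Per the docstring above: takes an Example and the program's prediction,
    # and returns a float score.
    return float(example.answer == prediction.answer)

program = dspy.ChainOfThought("question -> answer")

# `my_qa_pairs` is a hypothetical list of (question, answer) tuples;
# the trainset must contain at least `bsize` examples (default 32).
trainset = [
    dspy.Example(question=q, answer=a).with_inputs("question")
    for q, a in my_qa_pairs
]

simba = dspy.SIMBA(metric=exact_match, bsize=32, num_candidates=6, max_steps=8)
optimized = simba.compile(program, trainset=trainset, seed=0)

# The returned module carries the scored candidates and per-step logs.
print(optimized.candidate_programs[0]["score"])
print(optimized.trial_logs)
```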
--------------------------------------------------------------------------------
/dspy/adapters/types/tool.py:
--------------------------------------------------------------------------------

```python
  1 | import asyncio
  2 | import inspect
  3 | from typing import TYPE_CHECKING, Any, Callable, get_origin, get_type_hints
  4 | 
  5 | import pydantic
  6 | from jsonschema import ValidationError, validate
  7 | from pydantic import BaseModel, TypeAdapter, create_model
  8 | 
  9 | from dspy.adapters.types.base_type import Type
 10 | from dspy.dsp.utils.settings import settings
 11 | from dspy.utils.callback import with_callbacks
 12 | 
 13 | if TYPE_CHECKING:
 14 |     import mcp
 15 |     from langchain.tools import BaseTool
 16 | 
 17 | _TYPE_MAPPING = {"string": str, "integer": int, "number": float, "boolean": bool, "array": list, "object": dict}
 18 | 
 19 | 
 20 | class Tool(Type):
 21 |     """Tool class.
 22 | 
 23 |     This class is used to simplify the creation of tools for tool calling (function calling) in LLMs. Only supports
 24 |     functions for now.
 25 |     """
 26 | 
 27 |     func: Callable
 28 |     name: str | None = None
 29 |     desc: str | None = None
 30 |     args: dict[str, Any] | None = None
 31 |     arg_types: dict[str, Any] | None = None
 32 |     arg_desc: dict[str, str] | None = None
 33 |     has_kwargs: bool = False
 34 | 
 35 |     def __init__(
 36 |         self,
 37 |         func: Callable,
 38 |         name: str | None = None,
 39 |         desc: str | None = None,
 40 |         args: dict[str, Any] | None = None,
 41 |         arg_types: dict[str, Any] | None = None,
 42 |         arg_desc: dict[str, str] | None = None,
 43 |     ):
 44 |         """Initialize the Tool class.
 45 | 
 46 |         Users can choose to specify the `name`, `desc`, `args`, and `arg_types`, or let the `dspy.Tool`
 47 |         automatically infer the values from the function. For values that are specified by the user, automatic inference
 48 |         will not be performed on them.
 49 | 
 50 |         Args:
 51 |             func (Callable): The actual function that is being wrapped by the tool.
 52 |             name (Optional[str], optional): The name of the tool. Defaults to None.
 53 |             desc (Optional[str], optional): The description of the tool. Defaults to None.
 54 |             args (Optional[dict[str, Any]], optional): The args and their schema of the tool, represented as a
 55 |                 dictionary from arg name to arg's json schema. Defaults to None.
 56 |             arg_types (Optional[dict[str, Any]], optional): The argument types of the tool, represented as a dictionary
 57 |                 from arg name to the type of the argument. Defaults to None.
 58 |             arg_desc (Optional[dict[str, str]], optional): Descriptions for each arg, represented as a
 59 |                 dictionary from arg name to description string. Defaults to None.
 60 | 
 61 |         Example:
 62 | 
 63 |         ```python
 64 |         def foo(x: int, y: str = "hello"):
 65 |             return str(x) + y
 66 | 
 67 |         tool = Tool(foo)
 68 |         print(tool.args)
 69 |         # Expected output: {'x': {'type': 'integer'}, 'y': {'type': 'string', 'default': 'hello'}}
 70 |         ```
 71 |         """
 72 |         super().__init__(func=func, name=name, desc=desc, args=args, arg_types=arg_types, arg_desc=arg_desc)
 73 |         self._parse_function(func, arg_desc)
 74 | 
 75 |     def _parse_function(self, func: Callable, arg_desc: dict[str, str] | None = None):
 76 |         """Helper method that parses a function to extract the name, description, and args.
 77 | 
 78 |         This is a helper function that automatically infers the name, description, and args of the tool from the
 79 |         provided function. In order to make the inference work, the function must have valid type hints.
 80 |         """
 81 |         annotations_func = func if inspect.isfunction(func) or inspect.ismethod(func) else func.__call__
 82 |         name = getattr(func, "__name__", type(func).__name__)
 83 |         desc = getattr(func, "__doc__", None) or getattr(annotations_func, "__doc__", "")
 84 |         args = {}
 85 |         arg_types = {}
 86 | 
 87 |         # Use inspect.signature to get all arg names
 88 |         sig = inspect.signature(annotations_func)
 89 |         # Get available type hints
 90 |         available_hints = get_type_hints(annotations_func)
 91 |         # Build a dictionary of arg name -> type (defaulting to Any when missing)
 92 |         hints = {param_name: available_hints.get(param_name, Any) for param_name in sig.parameters.keys()}
 93 |         default_values = {param_name: sig.parameters[param_name].default for param_name in sig.parameters.keys()}
 94 | 
 95 |         # Process each argument's type to generate its JSON schema.
 96 |         for k, v in hints.items():
 97 |             arg_types[k] = v
 98 |             if k == "return":
 99 |                 continue
100 |             # Check if the type (or its origin) is a subclass of Pydantic's BaseModel
101 |             origin = get_origin(v) or v
102 |             if isinstance(origin, type) and issubclass(origin, BaseModel):
103 |                 # Get json schema, and replace $ref with the actual schema
104 |                 v_json_schema = _resolve_json_schema_reference(v.model_json_schema())
105 |                 args[k] = v_json_schema
106 |             else:
107 |                 args[k] = _resolve_json_schema_reference(TypeAdapter(v).json_schema())
108 |             if default_values[k] is not inspect.Parameter.empty:
109 |                 args[k]["default"] = default_values[k]
110 |             if arg_desc and k in arg_desc:
111 |                 args[k]["description"] = arg_desc[k]
112 | 
113 |         self.name = self.name or name
114 |         self.desc = self.desc or desc
115 |         self.args = self.args if self.args is not None else args
116 |         self.arg_types = self.arg_types if self.arg_types is not None else arg_types
117 |         self.has_kwargs = any(param.kind == param.VAR_KEYWORD for param in sig.parameters.values())
118 | 
119 |     def _validate_and_parse_args(self, **kwargs):
120 |         # Validate the args value comply to the json schema.
121 |         for k, v in kwargs.items():
122 |             if k not in self.args:
123 |                 if self.has_kwargs:
124 |                     continue
125 |                 else:
126 |                     raise ValueError(f"Arg {k} is not in the tool's args.")
127 |             try:
128 |                 instance = v.model_dump() if hasattr(v, "model_dump") else v
129 |                 type_str = self.args[k].get("type")
130 |                 if type_str is not None and type_str != "Any":
131 |                     validate(instance=instance, schema=self.args[k])
132 |             except ValidationError as e:
133 |                 raise ValueError(f"Arg {k} is invalid: {e.message}")
134 | 
135 |         # Parse the args to the correct type.
136 |         parsed_kwargs = {}
137 |         for k, v in kwargs.items():
138 |             if k in self.arg_types and self.arg_types[k] != Any:
139 |                 # Create a pydantic model wrapper with a dummy field `value` to parse the arg to the correct type.
140 |                 # This is specifically useful for handling nested Pydantic models like `list[list[MyPydanticModel]]`
141 |                 pydantic_wrapper = create_model("Wrapper", value=(self.arg_types[k], ...))
142 |                 parsed = pydantic_wrapper.model_validate({"value": v})
143 |                 parsed_kwargs[k] = parsed.value
144 |             else:
145 |                 parsed_kwargs[k] = v
146 |         return parsed_kwargs
147 | 
148 |     def format(self):
149 |         return str(self)
150 | 
151 |     def format_as_litellm_function_call(self):
152 |         return {
153 |             "type": "function",
154 |             "function": {
155 |                 "name": self.name,
156 |                 "description": self.desc,
157 |                 "parameters": {
158 |                     "type": "object",
159 |                     "properties": self.args,
160 |                     "required": list(self.args.keys()),
161 |                 },
162 |             },
163 |         }
164 | 
165 |     def _run_async_in_sync(self, coroutine):
166 |         try:
167 |             loop = asyncio.get_running_loop()
168 |         except RuntimeError:
169 |             return asyncio.run(coroutine)
170 | 
171 |         return loop.run_until_complete(coroutine)
172 | 
173 |     @with_callbacks
174 |     def __call__(self, **kwargs):
175 |         parsed_kwargs = self._validate_and_parse_args(**kwargs)
176 |         result = self.func(**parsed_kwargs)
177 |         if asyncio.iscoroutine(result):
178 |             if settings.allow_tool_async_sync_conversion:
179 |                 return self._run_async_in_sync(result)
180 |             else:
181 |                 raise ValueError(
182 |                     "You are calling `__call__` on an async tool, please use `acall` instead or set "
183 |                     "`allow_async=True` to run the async tool in sync mode."
 184 |                     "`allow_tool_async_sync_conversion=True` in dspy settings to run the async tool in sync mode."
185 |         return result
186 | 
187 |     @with_callbacks
188 |     async def acall(self, **kwargs):
189 |         parsed_kwargs = self._validate_and_parse_args(**kwargs)
190 |         result = self.func(**parsed_kwargs)
191 |         if asyncio.iscoroutine(result):
192 |             return await result
193 |         else:
194 |             # We should allow calling a sync tool in the async path.
195 |             return result
196 | 
197 |     @classmethod
198 |     def from_mcp_tool(cls, session: "mcp.ClientSession", tool: "mcp.types.Tool") -> "Tool":
199 |         """
200 |         Build a DSPy tool from an MCP tool and a ClientSession.
201 | 
202 |         Args:
203 |             session: The MCP session to use.
204 |             tool: The MCP tool to convert.
205 | 
206 |         Returns:
207 |             A Tool object.
208 |         """
209 |         from dspy.utils.mcp import convert_mcp_tool
210 | 
211 |         return convert_mcp_tool(session, tool)
212 | 
213 |     @classmethod
214 |     def from_langchain(cls, tool: "BaseTool") -> "Tool":
215 |         """
216 |         Build a DSPy tool from a LangChain tool.
217 | 
218 |         Args:
219 |             tool: The LangChain tool to convert.
220 | 
221 |         Returns:
222 |             A Tool object.
223 | 
224 |         Example:
225 | 
226 |         ```python
227 |         import asyncio
228 |         import dspy
229 |         from langchain.tools import tool as lc_tool
230 | 
231 |         @lc_tool
232 |         def add(x: int, y: int):
233 |             "Add two numbers together."
234 |             return x + y
235 | 
236 |         dspy_tool = dspy.Tool.from_langchain(add)
237 | 
238 |         async def run_tool():
239 |             return await dspy_tool.acall(x=1, y=2)
240 | 
241 |         print(asyncio.run(run_tool()))
242 |         # 3
243 |         ```
244 |         """
245 |         from dspy.utils.langchain_tool import convert_langchain_tool
246 | 
247 |         return convert_langchain_tool(tool)
248 | 
249 |     def __repr__(self):
250 |         return f"Tool(name={self.name}, desc={self.desc}, args={self.args})"
251 | 
252 |     def __str__(self):
253 |         desc = f", whose description is <desc>{self.desc}</desc>.".replace("\n", "  ") if self.desc else "."
254 |         arg_desc = f"It takes arguments {self.args}."
255 |         return f"{self.name}{desc} {arg_desc}"
256 | 
257 | 
258 | class ToolCalls(Type):
259 |     class ToolCall(Type):
260 |         name: str
261 |         args: dict[str, Any]
262 | 
263 |         def format(self):
264 |             return {
265 |                 "type": "function",
266 |                 "function": {
267 |                     "name": self.name,
268 |                     "arguments": self.args,
269 |                 },
270 |             }
271 | 
272 |         def execute(self, functions: dict[str, Any] | list[Tool] | None = None) -> Any:
273 |             """Execute this individual tool call and return its result.
274 | 
275 |             Args:
276 |                 functions: Functions to search for the tool. Can be:
277 |                           - Dict mapping tool names to functions: {"tool_name": function}
278 |                           - List of Tool objects: [Tool(function), ...]
279 |                           - None: Will search in caller's locals and globals (automatic lookup)
280 | 
281 |             Returns:
282 |                 The result from executing this tool call.
283 | 
284 |             Raises:
285 |                 ValueError: If the tool function cannot be found.
286 |                 Exception: Any exception raised by the tool function.
287 |             """
288 |             func = None
289 | 
290 |             if functions is None:
291 |                 # Automatic lookup in caller's globals and locals
292 |                 frame = inspect.currentframe().f_back
293 |                 try:
294 |                     caller_globals = frame.f_globals
295 |                     caller_locals = frame.f_locals
296 |                     func = caller_locals.get(self.name) or caller_globals.get(self.name)
297 |                 finally:
298 |                     del frame
299 | 
300 |             elif isinstance(functions, dict):
301 |                 func = functions.get(self.name)
302 |             elif isinstance(functions, list):
303 |                 for tool in functions:
304 |                     if tool.name == self.name:
305 |                         func = tool.func
306 |                         break
307 | 
308 |             if func is None:
309 |                 raise ValueError(f"Tool function '{self.name}' not found. Please pass the tool functions to the `execute` method.")
310 | 
311 |             try:
312 |                 args = self.args or {}
313 |                 return func(**args)
314 |             except Exception as e:
315 |                 raise RuntimeError(f"Error executing tool '{self.name}': {e}") from e
316 | 
317 |     tool_calls: list[ToolCall]
318 | 
319 |     @classmethod
320 |     def from_dict_list(cls, tool_calls_dicts: list[dict[str, Any]]) -> "ToolCalls":
321 |         """Convert a list of dictionaries to a ToolCalls instance.
322 | 
323 |         Args:
 324 |             tool_calls_dicts: A list of dictionaries, where each dictionary should have 'name' and 'args' keys.
325 | 
326 |         Returns:
327 |             A ToolCalls instance.
328 | 
329 |         Example:
330 | 
331 |             ```python
332 |             tool_calls_dict = [
333 |                 {"name": "search", "args": {"query": "hello"}},
334 |                 {"name": "translate", "args": {"text": "world"}}
335 |             ]
336 |             tool_calls = ToolCalls.from_dict_list(tool_calls_dict)
337 |             ```
338 |         """
339 |         tool_calls = [cls.ToolCall(**item) for item in tool_calls_dicts]
340 |         return cls(tool_calls=tool_calls)
341 | 
342 |     @classmethod
343 |     def description(cls) -> str:
344 |         return (
345 |             "Tool calls information, including the name of the tools and the arguments to be passed to it. "
346 |             "Arguments must be provided in JSON format."
347 |         )
348 | 
349 |     def format(self) -> list[dict[str, Any]]:
350 |         # The tool_call field is compatible with OpenAI's tool calls schema.
351 |         return {
352 |             "tool_calls": [tool_call.format() for tool_call in self.tool_calls],
353 |         }
354 | 
355 |     @pydantic.model_validator(mode="before")
356 |     @classmethod
357 |     def validate_input(cls, data: Any):
358 |         if isinstance(data, cls):
359 |             return data
360 | 
361 |         # Handle case where data is a list of dicts with "name" and "args" keys
362 |         if isinstance(data, list) and all(
363 |             isinstance(item, dict) and "name" in item and "args" in item for item in data
364 |         ):
365 |             return {"tool_calls": [cls.ToolCall(**item) for item in data]}
366 |         # Handle case where data is a dict
367 |         elif isinstance(data, dict):
368 |             if "tool_calls" in data:
369 |                 # Handle case where data is a dict with "tool_calls" key
370 |                 tool_calls_data = data["tool_calls"]
371 |                 if isinstance(tool_calls_data, list):
372 |                     return {
373 |                         "tool_calls": [
374 |                             cls.ToolCall(**item) if isinstance(item, dict) else item for item in tool_calls_data
375 |                         ]
376 |                     }
377 |             elif "name" in data and "args" in data:
378 |                 # Handle case where data is a dict with "name" and "args" keys
379 |                 return {"tool_calls": [cls.ToolCall(**data)]}
380 | 
381 |         raise ValueError(f"Received invalid value for `dspy.ToolCalls`: {data}")
382 | 
383 | 
384 | def _resolve_json_schema_reference(schema: dict) -> dict:
385 |     """Recursively resolve json model schema, expanding all references."""
386 | 
387 |     # If there are no definitions to resolve, return the main schema
388 |     if "$defs" not in schema and "definitions" not in schema:
389 |         return schema
390 | 
391 |     def resolve_refs(obj: Any) -> Any:
392 |         if not isinstance(obj, (dict, list)):
393 |             return obj
394 |         if isinstance(obj, dict):
395 |             if "$ref" in obj:
396 |                 ref_path = obj["$ref"].split("/")[-1]
397 |                 return resolve_refs(schema["$defs"][ref_path])
398 |             return {k: resolve_refs(v) for k, v in obj.items()}
399 | 
400 |         # Must be a list
401 |         return [resolve_refs(item) for item in obj]
402 | 
403 |     # Resolve all references in the main schema
404 |     resolved_schema = resolve_refs(schema)
405 |     # Remove the $defs key as it's no longer needed
406 |     resolved_schema.pop("$defs", None)
407 |     return resolved_schema
408 | 
409 | 
410 | def convert_input_schema_to_tool_args(
411 |     schema: dict[str, Any],
412 | ) -> tuple[dict[str, Any], dict[str, Type], dict[str, str]]:
413 |     """Convert an input json schema to tool arguments compatible with DSPy Tool.
414 | 
415 |     Args:
416 |         schema: An input json schema describing the tool's input parameters
417 | 
418 |     Returns:
419 |         A tuple of (args, arg_types, arg_desc) for DSPy Tool definition.
420 |     """
421 |     args, arg_types, arg_desc = {}, {}, {}
422 |     properties = schema.get("properties", None)
423 |     if properties is None:
424 |         return args, arg_types, arg_desc
425 | 
426 |     required = schema.get("required", [])
427 | 
428 |     defs = schema.get("$defs", {})
429 | 
430 |     for name, prop in properties.items():
431 |         if len(defs) > 0:
432 |             prop = _resolve_json_schema_reference({"$defs": defs, **prop})
433 |         args[name] = prop
434 |         arg_types[name] = _TYPE_MAPPING.get(prop.get("type"), Any)
435 |         arg_desc[name] = prop.get("description", "No description provided.")
436 |         if name in required:
437 |             arg_desc[name] += " (Required)"
438 | 
439 |     return args, arg_types, arg_desc
440 | 
```

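A short sketch of the `Tool` and `ToolCalls` types above, assuming a made-up `get_weather` function; schema inference, argument validation, and `ToolCall.execute` dispatch follow the code in this file.

```python
# Small Tool / ToolCalls sketch; `get_weather` is a made-up example function.
import dspy

def get_weather(city: str, unit: str = "celsius") -> str:
    """Return a fake weather report for `city`."""
    return f"Sunny in {city} (reported in {unit})."

# Tool infers the name, description, and a JSON-schema-like `args` dict from type hints.
tool = dspy.Tool(get_weather)
print(tool.name)   # get_weather
print(tool.args)   # {'city': {'type': 'string'}, 'unit': {'type': 'string', 'default': 'celsius'}}
print(tool(city="Paris"))  # kwargs are validated and parsed before the function is called
print(tool.format_as_litellm_function_call())

# ToolCalls can be built from plain dicts and each call dispatched back to a Tool.
calls = dspy.ToolCalls.from_dict_list([
    {"name": "get_weather", "args": {"city": "Tokyo", "unit": "fahrenheit"}},
])
for call in calls.tool_calls:
    print(call.execute(functions=[tool]))
```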
--------------------------------------------------------------------------------
/dspy/teleprompt/copro_optimizer.py:
--------------------------------------------------------------------------------

```python
  1 | import logging
  2 | from collections import defaultdict
  3 | 
  4 | import dspy
  5 | from dspy.evaluate.evaluate import Evaluate
  6 | from dspy.signatures import Signature
  7 | from dspy.teleprompt.teleprompt import Teleprompter
  8 | 
  9 | logger = logging.getLogger(__name__)
 10 | 
 11 | """
 12 | USAGE SUGGESTIONS:
 13 | 
 14 | The following code can be used to compile an optimized signature teleprompter and evaluate it on an end task:
 15 | 
 16 | teleprompter = COPRO(prompt_model=prompt_model, metric=metric, breadth=BREADTH, depth=DEPTH, init_temperature=INIT_TEMPERATURE)
 17 | kwargs = dict(num_threads=NUM_THREADS, display_progress=True, display_table=0)
 18 | compiled_prompt_opt = teleprompter.compile(program.deepcopy(), trainset=trainset[:DEV_NUM], eval_kwargs=kwargs)
 19 | eval_score = evaluate(compiled_prompt_opt, devset=evalset[:EVAL_NUM], **kwargs)
 20 | 
 21 | Note that this teleprompter takes in the following parameters:
 22 | 
 23 | * prompt_model: The model used for prompt generation. When unspecified, defaults to the model set in settings (i.e. dspy.settings.configure(lm=task_model)).
 24 | * metric: The task metric used for optimization.
 25 | * breadth: The number of new prompts to generate at each iteration. Default=10.
 26 | * depth: The number of times we should ask our prompt model to generate new prompts, with the history of the past prompts as input. Default=3.
 27 | * init_temperature: The temperature used to generate new prompts. Higher values yield more creative proposals. Default=1.4.
 28 | * track_stats: Tells the method whether or not to track statistics about the optimization process.
 29 |                 If True, the method will track the following statistics:
 30 |                     * results_best: The min, max, avg, and stddev of the top 10 scores for each predictor at each depth.
 31 |                     * results_latest: The min, max, avg, and stddev of the newest prompt scores for each predictor at each depth.
 32 |                     * total_calls: The total number of calls to the task metric.
 33 |                 These statistics will be returned as attributes of the best program.
 34 | """
 35 | 
 36 | 
 37 | class BasicGenerateInstruction(Signature):
 38 |     """You are an instruction optimizer for large language models. I will give you a ``signature`` of fields (inputs and outputs) in English. Your task is to propose an instruction that will lead a good language model to perform the task well. Don't be afraid to be creative."""
 39 | 
 40 |     basic_instruction = dspy.InputField(desc="The initial instructions before optimization")
 41 |     proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model")
 42 |     proposed_prefix_for_output_field = dspy.OutputField(
 43 |         desc="The string at the end of the prompt, which will help the model start solving the task",
 44 |     )
 45 | 
 46 | 
 47 | class GenerateInstructionGivenAttempts(dspy.Signature):
 48 |     """You are an instruction optimizer for large language models. I will give some task instructions I've tried, along with their corresponding validation scores. The instructions are arranged in increasing order based on their scores, where higher scores indicate better quality.
 49 | 
 50 |     Your task is to propose a new instruction that will lead a good language model to perform the task even better. Don't be afraid to be creative."""
 51 | 
 52 |     attempted_instructions = dspy.InputField()
 53 |     proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model")
 54 |     proposed_prefix_for_output_field = dspy.OutputField(
 55 |         desc="The string at the end of the prompt, which will help the model start solving the task",
 56 |     )
 57 | 
 58 | 
 59 | class COPRO(Teleprompter):
 60 |     def __init__(
 61 |         self,
 62 |         prompt_model=None,
 63 |         metric=None,
 64 |         breadth=10,
 65 |         depth=3,
 66 |         init_temperature=1.4,
 67 |         track_stats=False,
 68 |         **_kwargs,
 69 |     ):
 70 |         if breadth <= 1:
 71 |             raise ValueError("Breadth must be greater than 1")
 72 |         self.metric = metric
 73 |         self.breadth = breadth
 74 |         self.depth = depth
 75 |         self.init_temperature = init_temperature
 76 |         self.prompt_model = prompt_model
 77 |         self.track_stats = track_stats
 78 | 
 79 |     def _check_candidates_equal(self, candidate1, candidate2):
 80 |         for p1, p2 in zip(candidate1["program"].predictors(), candidate2["program"].predictors(), strict=False):
 81 |             if self._get_signature(p1).instructions != self._get_signature(p2).instructions:
 82 |                 return False
 83 |             *_, p1_last_field = self._get_signature(p1).fields.values()
 84 |             *_, p2_last_field = self._get_signature(p2).fields.values()
 85 |             if p1_last_field != p2_last_field:
 86 |                 return False
 87 |         return True
 88 | 
 89 |     def _drop_duplicates(self, candidates):
 90 |         final_candidates = []
 91 |         last_batch = []
 92 |         last_batch_score = -1
 93 |         for c in candidates:
 94 |             repeat = False
 95 |             if c["score"] == last_batch_score:
 96 |                 for c2 in last_batch:
 97 |                     if self._check_candidates_equal(c, c2):
 98 |                         repeat = True
 99 |                         break
100 |                 if not repeat:
101 |                     last_batch.append(c)
102 |             else:
103 |                 last_batch = [c]
104 |                 last_batch_score = c["score"]
105 |             if not repeat:
106 |                 final_candidates.append(c)
107 |         return final_candidates
108 | 
109 |     def _print_signature(self, predictor):
110 |         signature = self._get_signature(predictor)
111 | 
112 |         logger.debug(f"i: {signature.instructions}")
113 |         logger.debug(f"p: {list(signature.fields.values())[-1].json_schema_extra['prefix']}")
114 | 
115 |     def _get_signature(self, predictor):
116 |         assert hasattr(predictor, "signature")
117 |         return predictor.signature
118 | 
119 |     def _set_signature(self, predictor, updated_signature):
120 |         assert hasattr(predictor, "signature")
121 |         predictor.signature = updated_signature
122 | 
123 |     def compile(self, student, *, trainset, eval_kwargs):
124 |         """
125 |         Optimizes the `signature` of the `student` program. Note that the student may be zero-shot or already pre-optimized (demos already chosen, i.e. `demos != []`).
126 | 
127 |         Parameters:
128 |         student: program to optimize; a deep copy is made, so the original is left unmodified.
129 |         trainset: iterable of `Example`s
130 |         eval_kwargs: dict
131 |            Additional keywords to pass to `Evaluate` for the metric.
132 | 
133 |         Returns the optimized version of `student`.
134 |         """
135 |         module = student.deepcopy()
136 |         evaluate = Evaluate(devset=trainset, metric=self.metric, **eval_kwargs)
137 |         total_calls = 0
138 |         results_best = {
139 |             id(p): {"depth": [], "max": [], "average": [], "min": [], "std": []} for p in module.predictors()
140 |         }
141 |         results_latest = {
142 |             id(p): {"depth": [], "max": [], "average": [], "min": [], "std": []} for p in module.predictors()
143 |         }
144 | 
145 |         if self.track_stats:
146 |             import numpy as np
147 | 
148 |         candidates = {}
149 |         evaluated_candidates = defaultdict(dict)
150 | 
151 |         # Seed the prompt optimizer zero shot with just the instruction, generate BREADTH new prompts
152 |         for predictor in module.predictors():
153 |             basic_instruction = None
154 |             basic_prefix = None
155 |             *_, last_key = self._get_signature(predictor).fields.keys()
156 |             basic_instruction = self._get_signature(predictor).instructions
157 |             basic_prefix = self._get_signature(predictor).fields[last_key].json_schema_extra["prefix"]
158 |             if self.prompt_model:
159 |                 with dspy.settings.context(lm=self.prompt_model):
160 |                     instruct = dspy.Predict(
161 |                         BasicGenerateInstruction,
162 |                         n=self.breadth - 1,
163 |                         temperature=self.init_temperature,
164 |                     )(basic_instruction=basic_instruction)
165 |             else:
166 |                 instruct = dspy.Predict(
167 |                     BasicGenerateInstruction,
168 |                     n=self.breadth - 1,
169 |                     temperature=self.init_temperature,
170 |                 )(basic_instruction=basic_instruction)
171 |             # Add in our initial prompt as a candidate as well
172 |             instruct.completions.proposed_instruction.append(basic_instruction)
173 |             instruct.completions.proposed_prefix_for_output_field.append(basic_prefix)
174 |             candidates[id(predictor)] = instruct.completions
175 |             evaluated_candidates[id(predictor)] = {}
176 | 
177 |         if self.prompt_model:
178 |             logger.debug(f"{self.prompt_model.inspect_history(n=1)}")
179 | 
180 |         latest_candidates = candidates
181 |         all_candidates = candidates
182 | 
183 |         module_clone = module.deepcopy()
184 | 
185 |         # For each iteration in depth...
186 |         for d in range(
187 |             self.depth,
188 |         ):  # TODO: fix this so that we eval the new batch of predictors with the new best following predictors
189 |             logger.info(f"Iteration Depth: {d+1}/{self.depth}.")
190 | 
191 |             latest_scores = []
192 | 
193 |             # Go through our module's predictors
194 |             for p_i, (p_old, p_new) in enumerate(zip(module.predictors(), module_clone.predictors(), strict=False)):
195 |                 candidates_ = latest_candidates[id(p_old)]  # Use the most recently generated candidates for evaluation
196 |                 if len(module.predictors()) > 1:
197 |                     # Unless our program has multiple predictors, in which case we need to reevaluate all prompts with
198 |                     # the new prompt(s) for the other predictor(s).
199 |                     candidates_ = all_candidates[
200 |                         id(p_old)
201 |                     ]
202 | 
203 |                 # For each candidate
204 |                 for c_i, c in enumerate(candidates_):
205 |                     # Get the candidate instruction and prefix
206 |                     instruction, prefix = (
207 |                         c.proposed_instruction.strip('"').strip(),
208 |                         c.proposed_prefix_for_output_field.strip('"').strip(),
209 |                     )
210 | 
211 |                     # Set this new module with our instruction / prefix
212 |                     *_, last_key = self._get_signature(p_new).fields.keys()
213 |                     updated_signature = (
214 |                         self._get_signature(p_new)
215 |                         .with_instructions(instruction)
216 |                         .with_updated_fields(last_key, prefix=prefix)
217 |                     )
218 |                     self._set_signature(p_new, updated_signature)
219 | 
220 |                     # Score the instruction / prefix
221 |                     for i, predictor in enumerate(module_clone.predictors()):
222 |                         logger.debug(f"Predictor {i+1}")
223 |                         self._print_signature(predictor)
224 |                     logger.info(
225 |                         f"At Depth {d+1}/{self.depth}, Evaluating Prompt Candidate #{c_i+1}/{len(candidates_)} for "
226 |                         f"Predictor {p_i+1} of {len(module.predictors())}.",
227 |                     )
228 |                     score = evaluate(module_clone, devset=trainset, **eval_kwargs).score
229 |                     if self.prompt_model:
230 |                         logger.debug(f"prompt_model.inspect_history(n=1) {self.prompt_model.inspect_history(n=1)}")
231 |                     total_calls += 1
232 | 
233 |                     replace_entry = True
234 |                     logger.debug(f"(instruction, prefix) {(instruction, prefix)}")
235 |                     if (instruction, prefix) in evaluated_candidates[id(p_old)]:
236 |                         if evaluated_candidates[id(p_old)][(instruction, prefix)]["score"] >= score:
237 |                             replace_entry = False
238 | 
239 |                     if replace_entry:
240 |                         # Add it to our evaluated candidates list
241 |                         evaluated_candidates[id(p_old)][(instruction, prefix)] = {
242 |                             "score": score,
243 |                             "program": module_clone.deepcopy(),
244 |                             "instruction": instruction,
245 |                             "prefix": prefix,
246 |                             "depth": d,
247 |                         }
248 | 
249 |                     if len(candidates_) - self.breadth <= c_i:
250 |                         latest_scores.append(score)
251 | 
252 |                 if self.track_stats:
253 |                     results_latest[id(p_old)]["depth"].append(d)
254 |                     results_latest[id(p_old)]["max"].append(max(latest_scores))
255 |                     results_latest[id(p_old)]["average"].append(sum(latest_scores) / len(latest_scores))
256 |                     results_latest[id(p_old)]["min"].append(min(latest_scores))
257 |                     results_latest[id(p_old)]["std"].append(np.std(latest_scores))
258 | 
259 |                 # Now that we've evaluated the candidates, set this predictor to the best performing version
260 |                 # to ensure the next round of scores reflect the best possible version
261 |                 best_candidate = max(evaluated_candidates[id(p_old)].values(), key=lambda candidate: candidate["score"])
262 |                 *_, last_key = self._get_signature(p_old).fields.keys()
263 |                 updated_signature = (
264 |                     self._get_signature(p_new)
265 |                     .with_instructions(best_candidate["instruction"])
266 |                     .with_updated_fields(last_key, prefix=best_candidate["prefix"])
267 |                 )
268 |                 self._set_signature(p_new, updated_signature)
269 | 
270 |                 logger.debug(
271 |                     f"Updating Predictor {id(p_old)} to:\ni: {best_candidate['instruction']}\n"
272 |                     f"p: {best_candidate['prefix']}",
273 |                 )
274 |                 logger.debug("Full predictor with update: ")
275 |                 for i, predictor in enumerate(module_clone.predictors()):
276 |                     logger.debug(f"Predictor {i}")
277 |                     self._print_signature(predictor)
278 | 
279 |             if d == self.depth - 1:
280 |                 break
281 | 
282 |             new_candidates = {}
283 |             for p_base in module.predictors():
284 |                 # Build Few-Shot Example of Optimized Prompts
285 |                 attempts = []
286 |                 shortest_len = self.breadth
287 |                 shortest_len = min(len(evaluated_candidates[id(p_base)]), shortest_len)
288 |                 best_predictors = list(evaluated_candidates[id(p_base)].values())
289 | 
290 |                 # best_predictors = evaluated_candidates[id(p_base)].values()[:]
291 |                 best_predictors.sort(key=lambda x: x["score"], reverse=True)
292 | 
293 |                 if self.track_stats:
294 |                     scores = [x["score"] for x in best_predictors][:10]
295 |                     results_best[id(p_base)]["depth"].append(d)
296 |                     results_best[id(p_base)]["max"].append(max(scores))
297 |                     results_best[id(p_base)]["average"].append(sum(scores) / len(scores))
298 |                     results_best[id(p_base)]["min"].append(min(scores))
299 |                     results_best[id(p_base)]["std"].append(np.std(scores))
300 | 
301 |                 for i in range(shortest_len - 1, -1, -1):
302 |                     # breakpoint()
303 |                     attempts.append(f'Instruction #{shortest_len-i}: {best_predictors[i]["instruction"]}')
304 |                     attempts.append(f'Prefix #{shortest_len-i}: {best_predictors[i]["prefix"]}')
305 |                     attempts.append(f'Resulting Score #{shortest_len-i}: {best_predictors[i]["score"]}')
306 | 
307 |                 # Generate next batch of potential prompts to optimize, with previous attempts as input
308 |                 if self.prompt_model:
309 |                     with dspy.settings.context(lm=self.prompt_model):
310 |                         instr = dspy.Predict(
311 |                             GenerateInstructionGivenAttempts,
312 |                             n=self.breadth,
313 |                             temperature=self.init_temperature,
314 |                         )(attempted_instructions=attempts)
315 |                 else:
316 |                     instr = dspy.Predict(
317 |                         GenerateInstructionGivenAttempts,
318 |                         n=self.breadth,
319 |                         temperature=self.init_temperature,
320 |                     )(attempted_instructions=attempts)
321 | 
322 |                 # Get candidates for each predictor
323 |                 new_candidates[id(p_base)] = instr.completions
324 |                 all_candidates[id(p_base)].proposed_instruction.extend(instr.completions.proposed_instruction)
325 |                 all_candidates[id(p_base)].proposed_prefix_for_output_field.extend(
326 |                     instr.completions.proposed_prefix_for_output_field,
327 |                 )
328 | 
329 |             latest_candidates = new_candidates
330 | 
331 |         candidates = []
332 |         for predictor in module.predictors():
333 |             candidates.extend(list(evaluated_candidates[id(predictor)].values()))
334 | 
335 |             if self.track_stats:
336 |                 best_predictors = list(evaluated_candidates[id(predictor)].values())
337 |                 best_predictors.sort(key=lambda x: x["score"], reverse=True)
338 | 
339 |                 scores = [x["score"] for x in best_predictors][:10]
340 |                 results_best[id(predictor)]["depth"].append(d)
341 |                 results_best[id(predictor)]["max"].append(max(scores))
342 |                 results_best[id(predictor)]["average"].append(sum(scores) / len(scores))
343 |                 results_best[id(predictor)]["min"].append(min(scores))
344 |                 results_best[id(predictor)]["std"].append(np.std(scores))
345 | 
346 |         candidates.sort(key=lambda x: x["score"], reverse=True)
347 | 
348 |         candidates = self._drop_duplicates(candidates)
349 | 
350 |         best_program = candidates[0]["program"]
351 |         best_program.candidate_programs = candidates
352 |         best_program.total_calls = total_calls
353 |         if self.track_stats:
354 |             best_program.results_best = results_best
355 |             best_program.results_latest = results_latest
356 | 
357 |         return best_program
358 | 
```
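
To make the USAGE SUGGESTIONS docstring above concrete, here is a minimal, self-contained sketch of running COPRO end to end. The model name, metric, signature, and tiny trainset are illustrative assumptions; only the COPRO construction and `compile(...)` call pattern mirror the source file.

```python
# A minimal, runnable sketch of the usage pattern described in the file above.
# The model name, metric, signature, and trainset are assumptions for illustration.
import dspy
from dspy.teleprompt.copro_optimizer import COPRO

lm = dspy.LM("openai/gpt-4o-mini")  # hypothetical model choice
dspy.settings.configure(lm=lm)


def exact_match(example, prediction, trace=None):
    return example.answer.lower() == prediction.answer.lower()


program = dspy.Predict("question -> answer")
trainset = [
    dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question"),
    dspy.Example(question="What color is the sky?", answer="blue").with_inputs("question"),
]

teleprompter = COPRO(prompt_model=lm, metric=exact_match, breadth=4, depth=2, init_temperature=1.4)
eval_kwargs = {"num_threads": 1, "display_progress": True, "display_table": 0}
compiled = teleprompter.compile(program.deepcopy(), trainset=trainset, eval_kwargs=eval_kwargs)

# The returned program carries the evaluated candidates and total metric calls
# (plus results_best / results_latest when track_stats=True).
print(compiled.total_calls)
```

The deepcopy before `compile` matches the docstring's suggestion and keeps the original `program` untouched while COPRO mutates its clone's signatures.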

--------------------------------------------------------------------------------
/dspy/propose/grounded_proposer.py:
--------------------------------------------------------------------------------

```python
  1 | import random
  2 | 
  3 | import dspy
  4 | from dspy.propose.dataset_summary_generator import create_dataset_summary
  5 | from dspy.propose.propose_base import Proposer
  6 | from dspy.propose.utils import (
  7 |     create_example_string,
  8 |     create_predictor_level_history_string,
  9 |     get_dspy_source_code,
 10 |     strip_prefix,
 11 | )
 12 | from dspy.teleprompt.utils import get_prompt_model, get_signature
 13 | 
 14 | # Hardcoded variables (TODO: update)
 15 | MAX_INSTRUCT_IN_HISTORY = 5  # 10
 16 | 
 17 | TIPS = {
 18 |         "none": "",
 19 |         "creative": "Don't be afraid to be creative when creating the new instruction!",
 20 |         "simple": "Keep the instruction clear and concise.",
 21 |         "description": "Make sure your instruction is very informative and descriptive.",
 22 |         "high_stakes": "The instruction should include a high stakes scenario in which the LM must solve the task!",
 23 |         "persona": 'Include a persona that is relevant to the task in the instruction (ie. "You are a ...")',
 24 |     }
 25 | 
 26 | ### SIGNATURES USED TO HELP WITH INSTRUCTION GENERATION ###
 27 | 
 28 | class DescribeProgram(dspy.Signature):
 29 |     (
 30 |         """Below is some pseudo-code for a pipeline that solves tasks with calls to language models. Please describe what type of task this program appears to be designed to solve, and how it appears to work."""
 31 |     )
 32 |     program_code = dspy.InputField(
 33 |         format=str,
 34 |         desc="Pseudocode for a language model program designed to solve a particular task.",
 35 |         prefix="PROGRAM CODE:",
 36 |     )
 37 |     program_example = dspy.InputField(
 38 |         format=str,
 39 |         desc="An example of the program in use.",
 40 |         prefix="EXAMPLE OF PROGRAM IN USE:",
 41 |     )
 42 |     program_description = dspy.OutputField(
 43 |         desc="Describe what task the program is designed to solve, and how it goes about solving this task.",
 44 |         prefix="SUMMARY OF PROGRAM ABOVE:",
 45 |     )
 46 | 
 47 | 
 48 | class DescribeModule(dspy.Signature):
 49 |     (
 50 |         """Below is some pseudo-code for a pipeline that solves tasks with calls to language models. Please describe the purpose of the specified module in this pipeline."""
 51 |     )
 52 |     program_code = dspy.InputField(
 53 |         format=str,
 54 |         desc="Pseudocode for a language model program designed to solve a particular task.",
 55 |         prefix="PROGRAM CODE:",
 56 |     )
 57 |     program_example = dspy.InputField(
 58 |         format=str,
 59 |         desc="An example of the program in use.",
 60 |         prefix="EXAMPLE OF PROGRAM IN USE:",
 61 |     )
 62 |     program_description = dspy.InputField(
 63 |         desc="Summary of the task the program is designed to solve, and how it goes about solving it.",
 64 |         prefix="SUMMARY OF PROGRAM ABOVE:",
 65 |     )
 66 |     module = dspy.InputField(
 67 |         desc="The module in the program that we want to describe.", prefix="MODULE:",
 68 |     )
 69 |     module_description = dspy.OutputField(
 70 |         desc="Description of the module's role in the broader program.",
 71 |         prefix="MODULE DESCRIPTION:",
 72 |     )
 73 | 
 74 | 
 75 | def generate_instruction_class(
 76 |     use_dataset_summary=True,
 77 |     program_aware=True,
 78 |     use_task_demos=True,
 79 |     use_instruct_history=True,
 80 |     use_tip=True,
 81 | ):
 82 |     class GenerateSingleModuleInstruction(dspy.Signature):
 83 |         (
 84 |             """Use the information below to learn about a task that we are trying to solve using calls to an LM, then generate a new instruction that will be used to prompt a Language Model to better solve the task."""
 85 |         )
 86 |         if use_dataset_summary:
 87 |             dataset_description = dspy.InputField(
 88 |                 desc="A description of the dataset that we are using.",
 89 |                 prefix="DATASET SUMMARY:",
 90 |             )
 91 |         if program_aware:
 92 |             program_code = dspy.InputField(
 93 |                 format=str,
 94 |                 desc="Language model program designed to solve a particular task.",
 95 |                 prefix="PROGRAM CODE:",
 96 |             )
 97 |             program_description = dspy.InputField(
 98 |                 desc="Summary of the task the program is designed to solve, and how it goes about solving it.",
 99 |                 prefix="PROGRAM DESCRIPTION:",
100 |             )
101 |             module = dspy.InputField(
102 |                 desc="The module to create an instruction for.", prefix="MODULE:",
103 |             )
104 |             module_description = dspy.InputField(
105 |                 desc="Description of the module to create an instruction for.", prefix="MODULE DESCRIPTION:",
106 |             )
107 |         task_demos = dspy.InputField(
108 |             format=str,
109 |             desc="Example inputs/outputs of our module.",
110 |             prefix="TASK DEMO(S):",
111 |         )
112 |         if use_instruct_history:
113 |             previous_instructions = dspy.InputField(
114 |                 format=str,
115 |                 desc="Previous instructions we've attempted, along with their associated scores.",
116 |                 prefix="PREVIOUS INSTRUCTIONS:",
117 |             )
118 |         basic_instruction = dspy.InputField(
119 |             format=str, desc="Basic instruction.", prefix="BASIC INSTRUCTION:",
120 |         )
121 |         if use_tip:
122 |             tip = dspy.InputField(
123 |                 format=str,
124 |                 desc="A suggestion for how to go about generating the new instruction.",
125 |                 prefix="TIP:",
126 |             )
127 |         proposed_instruction = dspy.OutputField(
128 |             desc="Propose an instruction that will be used to prompt a Language Model to perform this task.",
129 |             prefix="PROPOSED INSTRUCTION:",
130 |         )
131 | 
132 |     return dspy.Predict(GenerateSingleModuleInstruction)
133 | 
134 | ### CLASS RESPONSIBLE FOR GENERATING A NEW INSTRUCTION, USING THE HELPER SIGNATURES ABOVE ###
135 | 
136 | class GenerateModuleInstruction(dspy.Module):
137 |     def __init__(
138 |         self,
139 |         program_code_string=None,
140 |         use_dataset_summary=True,
141 |         program_aware=False,
142 |         use_task_demos=True,
143 |         use_instruct_history=True,
144 |         use_tip=True,
145 |         verbose=False,
146 |     ):
147 |         super().__init__()
148 |         self.use_dataset_summary = use_dataset_summary
149 |         self.program_aware = program_aware
150 |         self.use_task_demos = use_task_demos
151 |         self.use_instruct_history = use_instruct_history
152 |         self.use_tip = use_tip
153 |         self.verbose = verbose
154 | 
155 |         self.program_code_string = program_code_string
156 |         self.describe_program = dspy.Predict(DescribeProgram)
157 |         self.describe_module = dspy.Predict(DescribeModule)
158 |         self.generate_module_instruction = generate_instruction_class(
159 |             use_dataset_summary=use_dataset_summary,
160 |             program_aware=program_aware,
161 |             use_task_demos=use_task_demos,
162 |             use_instruct_history=use_instruct_history,
163 |             use_tip=use_tip,
164 |         )
165 | 
166 |     def forward(
167 |         self,
168 |         demo_candidates,
169 |         pred_i,
170 |         demo_set_i,
171 |         program,
172 |         previous_instructions,
173 |         data_summary,
174 |         num_demos_in_context=3,
175 |         tip=None,
176 |     ):
177 |         def gather_examples_from_sets(candidate_sets, max_examples):
178 |             """Helper function to gather up to max_examples augmented examples from the given sets."""
179 |             count = 0
180 |             for candidate_set in candidate_sets:
181 |                 for example in candidate_set:
182 |                     if "augmented" in example.keys():
183 |                         fields_to_use = get_signature(program.predictors()[pred_i]).fields
184 |                         yield create_example_string(fields_to_use, example)
185 |                         count += 1
186 |                         if count >= max_examples:
187 |                             return
188 | 
189 |         # Construct full program demo or single module demo depending on settings
190 |         basic_instruction = get_signature(program.predictors()[pred_i]).instructions
191 |         task_demos = ""
192 | 
193 |         if self.use_task_demos:
194 |             # Combine current and adjacent sets
195 |             adjacent_sets = (
196 |                 [demo_candidates[pred_i][demo_set_i]] +
197 |                 demo_candidates[pred_i][demo_set_i + 1:] +
198 |                 demo_candidates[pred_i][:demo_set_i]
199 |             )
200 | 
201 |             # Gather examples up to the required count
202 |             example_strings = gather_examples_from_sets(adjacent_sets, num_demos_in_context)
203 |             task_demos = "\n\n".join(example_strings) + "\n\n"
204 | 
205 |         # Default to no demos provided if no examples were gathered, or if we're using the first demo set
206 |         if not task_demos.strip() or demo_set_i == 0:
207 |             task_demos = "No task demos provided."
208 | 
209 |         # Summarize the program
210 |         program_description = "Not available"
211 |         module_code = "Not provided"
212 |         module_description = "Not provided"
213 |         if self.program_aware:
214 |             try:
215 |                 program_description = strip_prefix(
216 |                     self.describe_program(
217 |                         program_code=self.program_code_string, program_example=task_demos,
218 |                     ).program_description,
219 |                 )
220 |                 if self.verbose:
221 |                     print(f"PROGRAM DESCRIPTION: {program_description}")
222 | 
223 |                 inputs = []
224 |                 outputs = []
225 |                 for field_name, field in get_signature(program.predictors()[pred_i]).fields.items():
226 |                     # Access the '__dspy_field_type' from the extra metadata
227 |                     dspy_field_type = field.json_schema_extra.get("__dspy_field_type")
228 | 
229 |                     # Based on the '__dspy_field_type', append to the respective list
230 |                     if dspy_field_type == "input":
231 |                         inputs.append(field_name)
232 |                     else:
233 |                         outputs.append(field_name)
234 | 
235 |                 module_code = f"{program.predictors()[pred_i].__class__.__name__}({', '.join(inputs)}) -> {', '.join(outputs)}"
236 | 
237 |                 module_description = self.describe_module(
238 |                     program_code=self.program_code_string,
239 |                     program_description=program_description,
240 |                     program_example=task_demos,
241 |                     module=module_code,
242 |                     max_depth=10,
243 |                 ).module_description
244 |             except Exception as e:
245 |                 if self.verbose:
246 |                     print(f"Error getting program description. Running without program aware proposer. Error: {e}")
247 |                 self.program_aware = False
248 | 
249 |         # Generate an instruction for our chosen module
250 |         if self.verbose:
251 |             print(f"task_demos {task_demos}")
252 | 
253 |         instruct = self.generate_module_instruction(
254 |             dataset_description=data_summary,
255 |             program_code=self.program_code_string,
256 |             module=module_code,
257 |             program_description=program_description,
258 |             module_description=module_description,
259 |             task_demos=task_demos,
260 |             tip=tip,
261 |             basic_instruction=basic_instruction,
262 |             previous_instructions=previous_instructions,
263 |         )
264 | 
265 |         proposed_instruction = strip_prefix(instruct.proposed_instruction)
266 | 
267 |         return dspy.Prediction(proposed_instruction=proposed_instruction)
268 | 
269 | ### CLASS USED TO GENERATE THE FULL SET OF INSTRUCTIONS GIVEN THE SPECIFIED CRITERIA ###
270 | 
271 | class GroundedProposer(Proposer):
272 |     def __init__(
273 |         self,
274 |         prompt_model,
275 |         program,
276 |         trainset,
277 |         view_data_batch_size=10,
278 |         use_dataset_summary=True,
279 |         program_aware=True,
280 |         use_task_demos=True,
281 |         num_demos_in_context = 3,
282 |         use_instruct_history=True,
283 |         use_tip=True,
284 |         set_tip_randomly=True,
285 |         set_history_randomly=True,
286 |         verbose=False,
287 |         rng=None,
288 |         init_temperature: float = 1.0,
289 |     ):
290 |         super().__init__()
291 |         self.program_aware = program_aware
292 |         self.use_dataset_summary = use_dataset_summary
293 |         self.use_task_demos = use_task_demos
294 |         self.num_demos_in_context = num_demos_in_context
295 |         self.use_instruct_history = use_instruct_history
296 |         self.use_tip = use_tip
297 |         self.set_tip_randomly=set_tip_randomly
298 |         self.set_history_randomly=set_history_randomly
299 |         self.verbose = verbose
300 |         self.rng = rng or random
301 | 
302 |         self.prompt_model = get_prompt_model(prompt_model)
303 |         self.init_temperature = init_temperature
304 | 
305 |         self.program_code_string = None
306 |         if self.program_aware:
307 |             try:
308 |                 self.program_code_string = get_dspy_source_code(program)
309 |                 if self.verbose:
310 |                     print("SOURCE CODE:",self.program_code_string)
311 |             except Exception as e:
312 |                 print(f"Error getting source code: {e}.\n\nRunning without program aware proposer.")
313 |                 self.program_aware = False
314 | 
315 |         self.data_summary  = None
316 |         if self.use_dataset_summary:
317 |             try:
318 |                 self.data_summary = create_dataset_summary(
319 |                     trainset=trainset, view_data_batch_size=view_data_batch_size, prompt_model=prompt_model,
320 |                 )
321 |                 if self.verbose:
322 |                     print(f"DATA SUMMARY: {self.data_summary}")
323 |             except Exception as e:
324 |                 print(f"Error getting data summary: {e}.\n\nRunning without data aware proposer.")
325 |                 self.use_dataset_summary = False
326 |                 print("")
327 | 
328 |     def propose_instructions_for_program(
329 |         self,
330 |         trainset,
331 |         program,
332 |         demo_candidates,
333 |         trial_logs,
334 |         N, # noqa: N803
335 |     ) -> list[str]:
336 |         """This method is responsible for returning the full set of new instructions for our program, given the specified criteria."""
337 | 
338 |         proposed_instructions = {}
339 | 
340 |         if self.set_history_randomly:
341 |             # Randomly select whether or not we're using instruction history
342 |             use_history = self.rng.random() < 0.5
343 |             self.use_instruct_history = use_history
344 |             if self.verbose:
345 |                 print(f"Use history T/F: {self.use_instruct_history}")
346 | 
347 |         if not demo_candidates:
348 |             if self.verbose:
349 |                 print("No demo candidates provided. Running without task demos.")
350 |             self.use_task_demos = False
351 |             # When no demo candidates are provided, default to N
352 |             num_demos = N
353 |         else:
354 |             num_demos = max(len(demo_candidates[0]), 1)
355 | 
356 |         # Create an instruction for each predictor
357 |         for pred_i, predictor in enumerate(program.predictors()):
358 |             for demo_set_i in range(num_demos)[:min(N, num_demos)]:
359 |                 if pred_i not in proposed_instructions:
360 |                     proposed_instructions[pred_i] = []
361 |                 selected_tip = None
362 |                 if self.set_tip_randomly:
363 |                     if self.verbose:
364 |                         print("Using a randomly generated configuration for our grounded proposer.")
365 |                     # Randomly select the tip
366 |                     selected_tip_key = self.rng.choice(list(TIPS.keys()))
367 |                     selected_tip = TIPS[selected_tip_key]
368 |                     self.use_tip = bool(
369 |                         selected_tip,
370 |                     )
371 |                     if self.verbose:
372 |                         print(f"Selected tip: {selected_tip_key}")
373 | 
374 |                 proposed_instructions[pred_i].append(
375 |                     self.propose_instruction_for_predictor(
376 |                         program=program,
377 |                         predictor=predictor,
378 |                         pred_i=pred_i,
379 |                         demo_candidates=demo_candidates,
380 |                         demo_set_i=demo_set_i,
381 |                         trial_logs=trial_logs,
382 |                         tip=selected_tip,
383 |                     ),
384 |                 )
385 | 
386 |         return proposed_instructions
387 | 
388 |     def propose_instruction_for_predictor(
389 |         self,
390 |         program,
391 |         predictor,
392 |         pred_i,
393 |         demo_candidates,
394 |         demo_set_i,
395 |         trial_logs,
396 |         tip=None,
397 |     ) -> str:
398 |         """This method is responsible for returning a single instruction for a given predictor, using the specified criteria."""
399 | 
400 |         # Create an instruction history string for our predictor
401 |         instruction_history = create_predictor_level_history_string(
402 |             program, pred_i, trial_logs, MAX_INSTRUCT_IN_HISTORY,
403 |         )
404 | 
405 |         # Create our instruction generator class (given specific criteria for this round of proposal)
406 |         instruction_generator = GenerateModuleInstruction(
407 |             program_code_string=self.program_code_string,
408 |             use_dataset_summary=self.use_dataset_summary,
409 |             program_aware=self.program_aware,
410 |             use_task_demos=self.use_task_demos and demo_candidates,
411 |             use_instruct_history=self.use_instruct_history and instruction_history,
412 |             use_tip=self.use_tip,
413 |             verbose=self.verbose
414 |         )
415 | 
416 |         # Generate a new instruction for our predictor using a unique rollout id to bypass cache
417 |         rollout_lm = self.prompt_model.copy(
418 |             rollout_id=self.rng.randint(0, 10**9),
419 |             temperature=self.init_temperature,
420 |         )
421 | 
422 |         with dspy.settings.context(lm=rollout_lm):
423 |             proposed_instruction = instruction_generator(
424 |                 demo_candidates=demo_candidates,
425 |                 pred_i=pred_i,
426 |                 demo_set_i=demo_set_i,
427 |                 program=program,
428 |                 data_summary=self.data_summary,
429 |                 previous_instructions=instruction_history,
430 |                 num_demos_in_context = self.num_demos_in_context,
431 |                 tip=tip,
432 |             ).proposed_instruction
433 | 
434 |         # Log the trace used to generate the new instruction, along with the new instruction itself
435 |         if self.verbose:
436 |             self.prompt_model.inspect_history(n=1)
437 |             print(f"PROPOSED INSTRUCTION: {proposed_instruction}")
438 | 
439 |         return strip_prefix(proposed_instruction)
440 | 
```
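
GroundedProposer is normally constructed inside optimizers such as MIPROv2, but the sketch below shows roughly how it could be driven directly. The model name, program, and one-example trainset are illustrative assumptions, not code from this repository.

```python
# A minimal sketch of driving GroundedProposer by hand; model, program, and
# trainset are assumptions made for illustration.
import dspy
from dspy.propose.grounded_proposer import GroundedProposer

lm = dspy.LM("openai/gpt-4o-mini")  # hypothetical model choice
dspy.settings.configure(lm=lm)

program = dspy.ChainOfThought("question -> answer")
trainset = [dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question")]

proposer = GroundedProposer(
    prompt_model=lm,
    program=program,
    trainset=trainset,
    use_dataset_summary=False,  # skip the dataset-summary step in this sketch
    program_aware=True,
    verbose=True,
)

# Ask for up to N candidate instructions per predictor. With demo_candidates=None,
# the proposer falls back to running without task demos.
instructions = proposer.propose_instructions_for_program(
    trainset=trainset,
    program=program,
    demo_candidates=None,
    trial_logs={},
    N=2,
)
print(instructions)  # e.g. {0: ["<instruction candidate 1>", "<instruction candidate 2>"]}
```

Each call draws a fresh rollout id and temperature on a copy of the prompt model, so repeated proposals are not served from cache.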

--------------------------------------------------------------------------------
/tests/adapters/test_tool.py:
--------------------------------------------------------------------------------

```python
  1 | import asyncio
  2 | from typing import Any
  3 | 
  4 | import pytest
  5 | from pydantic import BaseModel
  6 | 
  7 | import dspy
  8 | from dspy.adapters.types.tool import Tool, ToolCalls, convert_input_schema_to_tool_args
  9 | 
 10 | 
 11 | # Test fixtures
 12 | def dummy_function(x: int, y: str = "hello") -> str:
 13 |     """A dummy function for testing.
 14 | 
 15 |     Args:
 16 |         x: An integer parameter
 17 |         y: A string parameter
 18 |     """
 19 |     return f"{y} {x}"
 20 | 
 21 | 
 22 | class DummyModel(BaseModel):
 23 |     field1: str = "hello"
 24 |     field2: int
 25 | 
 26 | 
 27 | def dummy_with_pydantic(model: DummyModel) -> str:
 28 |     """A dummy function that accepts a Pydantic model."""
 29 |     return f"{model.field1} {model.field2}"
 30 | 
 31 | 
 32 | class Address(BaseModel):
 33 |     street: str
 34 |     city: str
 35 |     zip_code: str
 36 |     is_primary: bool = False
 37 | 
 38 | 
 39 | class ContactInfo(BaseModel):
 40 |     email: str
 41 |     phone: str | None = None
 42 |     addresses: list[Address]
 43 | 
 44 | 
 45 | class UserProfile(BaseModel):
 46 |     user_id: int
 47 |     name: str
 48 |     age: int | None = None
 49 |     contact: ContactInfo
 50 |     tags: list[str] = []
 51 | 
 52 | 
 53 | class Note(BaseModel):
 54 |     content: str
 55 |     author: str
 56 | 
 57 | 
 58 | def complex_dummy_function(profile: UserProfile, priority: int, notes: list[Note] | None = None) -> dict[str, Any]:
 59 |     """Process user profile with complex nested structure.
 60 | 
 61 |     Args:
 62 |         profile: User profile containing nested contact and address information
 63 |         priority: Priority level of the processing
 64 |         notes: Optional processing notes
 65 |     """
 66 |     primary_address = next(
 67 |         (addr for addr in profile.contact.addresses if addr.is_primary), profile.contact.addresses[0]
 68 |     )
 69 | 
 70 |     return {
 71 |         "user_id": profile.user_id,
 72 |         "name": profile.name,
 73 |         "priority": priority,
 74 |         "primary_address": primary_address.model_dump(),
 75 |         "notes": notes,
 76 |     }
 77 | 
 78 | 
 79 | async def async_dummy_function(x: int, y: str = "hello") -> str:
 80 |     """An async dummy function for testing.
 81 | 
 82 |     Args:
 83 |         x: An integer parameter
 84 |         y: A string parameter
 85 |     """
 86 |     await asyncio.sleep(0.1)  # Simulate some async work
 87 |     return f"{y} {x}"
 88 | 
 89 | 
 90 | async def async_dummy_with_pydantic(model: DummyModel) -> str:
 91 |     """An async dummy function that accepts a Pydantic model."""
 92 |     await asyncio.sleep(0.1)  # Simulate some async work
 93 |     return f"{model.field1} {model.field2}"
 94 | 
 95 | 
 96 | async def async_complex_dummy_function(
 97 |     profile: UserProfile,
 98 |     priority: int,
 99 |     notes: list[Note] | None = None,
100 | ) -> dict[str, Any]:
101 |     """Process user profile with complex nested structure asynchronously.
102 | 
103 |     Args:
104 |         profile: User profile containing nested contact and address information
105 |         priority: Priority level of the processing
106 |         notes: Optional processing notes
107 |     """
108 |     # Simulate some async processing work
109 |     await asyncio.sleep(0.1)
110 | 
111 |     primary_address = next(
112 |         (addr for addr in profile.contact.addresses if addr.is_primary), profile.contact.addresses[0]
113 |     )
114 | 
115 |     # Simulate more async work after finding primary address
116 |     await asyncio.sleep(0.1)
117 | 
118 |     return {
119 |         "user_id": profile.user_id,
120 |         "name": profile.name,
121 |         "priority": priority,
122 |         "primary_address": primary_address.model_dump(),
123 |         "notes": notes,
124 |     }
125 | 
126 | 
127 | def test_basic_initialization():
128 |     tool = Tool(name="test_tool", desc="A test tool", args={"param1": {"type": "string"}}, func=lambda x: x)
129 |     assert tool.name == "test_tool"
130 |     assert tool.desc == "A test tool"
131 |     assert tool.args == {"param1": {"type": "string"}}
132 |     assert callable(tool.func)
133 | 
134 | 
135 | def test_tool_from_function():
136 |     tool = Tool(dummy_function)
137 | 
138 |     assert tool.name == "dummy_function"
139 |     assert "A dummy function for testing" in tool.desc
140 |     assert "x" in tool.args
141 |     assert "y" in tool.args
142 |     assert tool.args["x"]["type"] == "integer"
143 |     assert tool.args["y"]["type"] == "string"
144 |     assert tool.args["y"]["default"] == "hello"
145 | 
146 | 
147 | def test_tool_from_class():
148 |     class Foo:
149 |         def __init__(self, user_id: str):
150 |             self.user_id = user_id
151 | 
152 |         def __call__(self, a: int, b: int) -> int:
153 |             """Add two numbers."""
154 |             return a + b
155 | 
156 |     tool = Tool(Foo("123"))
157 |     assert tool.name == "Foo"
158 |     assert tool.desc == "Add two numbers."
159 |     assert tool.args == {"a": {"type": "integer"}, "b": {"type": "integer"}}
160 | 
161 | 
162 | def test_tool_from_function_with_pydantic():
163 |     tool = Tool(dummy_with_pydantic)
164 | 
165 |     assert tool.name == "dummy_with_pydantic"
166 |     assert "model" in tool.args
167 |     assert tool.args["model"]["type"] == "object"
168 |     assert "field1" in tool.args["model"]["properties"]
169 |     assert "field2" in tool.args["model"]["properties"]
170 |     assert tool.args["model"]["properties"]["field1"]["default"] == "hello"
171 | 
172 | 
173 | def test_tool_from_function_with_pydantic_nesting():
174 |     tool = Tool(complex_dummy_function)
175 | 
176 |     assert tool.name == "complex_dummy_function"
177 | 
178 |     assert "profile" in tool.args
179 |     assert "priority" in tool.args
180 |     assert "notes" in tool.args
181 |     assert tool.args["profile"]["type"] == "object"
182 |     assert tool.args["profile"]["properties"]["user_id"]["type"] == "integer"
183 |     assert tool.args["profile"]["properties"]["name"]["type"] == "string"
184 |     assert tool.args["profile"]["properties"]["age"]["anyOf"] == [{"type": "integer"}, {"type": "null"}]
185 |     assert tool.args["profile"]["properties"]["contact"]["type"] == "object"
186 |     assert tool.args["profile"]["properties"]["contact"]["properties"]["email"]["type"] == "string"
187 | 
188 |     # Reference should be resolved for nested pydantic models
189 |     assert "$defs" not in str(tool.args["notes"])
190 |     assert tool.args["notes"]["anyOf"][0]["type"] == "array"
191 |     assert tool.args["notes"]["anyOf"][0]["items"]["type"] == "object"
192 |     assert tool.args["notes"]["anyOf"][0]["items"]["properties"]["content"]["type"] == "string"
193 |     assert tool.args["notes"]["anyOf"][0]["items"]["properties"]["author"]["type"] == "string"
194 | 
195 | 
196 | def test_tool_callable():
197 |     tool = Tool(dummy_function)
198 |     result = tool(x=42, y="hello")
199 |     assert result == "hello 42"
200 | 
201 | 
202 | def test_tool_with_pydantic_callable():
203 |     tool = Tool(dummy_with_pydantic)
204 |     model = DummyModel(field1="test", field2=123)
205 |     result = tool(model=model)
206 |     assert result == "test 123"
207 | 
208 | 
209 | def test_invalid_function_call():
210 |     tool = Tool(dummy_function)
211 |     with pytest.raises(ValueError):
212 |         tool(x="not an integer", y="hello")
213 | 
214 | 
215 | def test_parameter_desc():
216 |     tool = Tool(dummy_function, arg_desc={"x": "The x parameter"})
217 |     assert tool.args["x"]["description"] == "The x parameter"
218 | 
219 | 
220 | def test_tool_with_default_args_without_type_hints():
221 |     def foo(x=100):
222 |         return x
223 | 
224 |     tool = Tool(foo)
225 |     assert tool.args["x"]["default"] == 100
226 |     assert not hasattr(tool.args["x"], "type")
227 | 
228 | 
229 | def test_tool_call_parses_args():
230 |     tool = Tool(dummy_with_pydantic)
231 | 
232 |     args = {
233 |         "model": {
234 |             "field1": "hello",
235 |             "field2": 123,
236 |         }
237 |     }
238 | 
239 |     result = tool(**args)
240 |     assert result == "hello 123"
241 | 
242 | 
243 | def test_tool_call_parses_nested_list_of_pydantic_model():
244 |     def dummy_function(x: list[list[DummyModel]]):
245 |         return x
246 | 
247 |     tool = Tool(dummy_function)
248 |     args = {
249 |         "x": [
250 |             [
251 |                 {
252 |                     "field1": "hello",
253 |                     "field2": 123,
254 |                 }
255 |             ]
256 |         ]
257 |     }
258 | 
259 |     result = tool(**args)
260 |     assert result == [[DummyModel(field1="hello", field2=123)]]
261 | 
262 | 
263 | def test_tool_call_kwarg():
264 |     def fn(x: int, **kwargs):
265 |         return kwargs
266 | 
267 |     tool = Tool(fn)
268 | 
269 |     assert tool(x=1, y=2, z=3) == {"y": 2, "z": 3}
270 | 
271 | 
272 | def test_tool_str():
273 |     def add(x: int, y: int = 0) -> int:
274 |         """Add two integers."""
275 |         return x + y
276 | 
277 |     tool = Tool(add)
278 |     assert (
279 |         str(tool)
280 |         == "add, whose description is <desc>Add two integers.</desc>. It takes arguments {'x': {'type': 'integer'}, 'y': {'type': 'integer', 'default': 0}}."
281 |     )
282 | 
283 | 
284 | @pytest.mark.asyncio
285 | async def test_async_tool_from_function():
286 |     tool = Tool(async_dummy_function)
287 | 
288 |     assert tool.name == "async_dummy_function"
289 |     assert "An async dummy function for testing" in tool.desc
290 |     assert "x" in tool.args
291 |     assert "y" in tool.args
292 |     assert tool.args["x"]["type"] == "integer"
293 |     assert tool.args["y"]["type"] == "string"
294 |     assert tool.args["y"]["default"] == "hello"
295 | 
296 |     # Test async call
297 |     result = await tool.acall(x=42, y="hello")
298 |     assert result == "hello 42"
299 | 
300 | 
301 | @pytest.mark.asyncio
302 | async def test_async_tool_with_pydantic():
303 |     tool = Tool(async_dummy_with_pydantic)
304 | 
305 |     assert tool.name == "async_dummy_with_pydantic"
306 |     assert "model" in tool.args
307 |     assert tool.args["model"]["type"] == "object"
308 |     assert "field1" in tool.args["model"]["properties"]
309 |     assert "field2" in tool.args["model"]["properties"]
310 | 
311 |     # Test async call with pydantic model
312 |     model = DummyModel(field1="test", field2=123)
313 |     result = await tool.acall(model=model)
314 |     assert result == "test 123"
315 | 
316 |     # Test async call with dict
317 |     result = await tool.acall(model={"field1": "test", "field2": 123})
318 |     assert result == "test 123"
319 | 
320 | 
321 | @pytest.mark.asyncio
322 | async def test_async_tool_with_complex_pydantic():
323 |     tool = Tool(async_complex_dummy_function)
324 | 
325 |     profile = UserProfile(
326 |         user_id=1,
327 |         name="Test User",
328 |         contact=ContactInfo(
329 |             email="[email protected]",
330 |             addresses=[
331 |                 Address(street="123 Main St", city="Test City", zip_code="12345", is_primary=True),
332 |                 Address(street="456 Side St", city="Test City", zip_code="12345"),
333 |             ],
334 |         ),
335 |     )
336 | 
337 |     result = await tool.acall(profile=profile, priority=1, notes=[Note(content="Test note", author="Test author")])
338 |     assert result["user_id"] == 1
339 |     assert result["name"] == "Test User"
340 |     assert result["priority"] == 1
341 |     assert result["notes"] == [Note(content="Test note", author="Test author")]
342 |     assert result["primary_address"]["street"] == "123 Main St"
343 | 
344 | 
345 | @pytest.mark.asyncio
346 | async def test_async_tool_invalid_call():
347 |     tool = Tool(async_dummy_function)
348 |     with pytest.raises(ValueError):
349 |         await tool.acall(x="not an integer", y="hello")
350 | 
351 | 
352 | @pytest.mark.asyncio
353 | async def test_async_tool_with_kwargs():
354 |     async def fn(x: int, **kwargs):
355 |         return kwargs
356 | 
357 |     tool = Tool(fn)
358 | 
359 |     result = await tool.acall(x=1, y=2, z=3)
360 |     assert result == {"y": 2, "z": 3}
361 | 
362 | 
363 | @pytest.mark.asyncio
364 | async def test_async_concurrent_calls():
365 |     """Test that multiple async tools can run concurrently."""
366 |     tool = Tool(async_dummy_function)
367 | 
368 |     # Create multiple concurrent calls
369 |     tasks = [tool.acall(x=i, y=f"hello{i}") for i in range(5)]
370 | 
371 |     # Run them concurrently and measure time
372 |     start_time = asyncio.get_event_loop().time()
373 |     results = await asyncio.gather(*tasks)
374 |     end_time = asyncio.get_event_loop().time()
375 | 
376 |     # Verify results, `asyncio.gather` returns results in the order of the tasks
377 |     assert results == [f"hello{i} {i}" for i in range(5)]
378 | 
379 |     # Check that it ran concurrently (should take ~0.1s, not ~0.5s)
380 |     # We use 0.3s as threshold to account for some overhead
381 |     assert end_time - start_time < 0.3
382 | 
383 | 
384 | @pytest.mark.filterwarnings("ignore::RuntimeWarning")
385 | def test_async_tool_call_in_sync_mode():
386 |     tool = Tool(async_dummy_function)
387 |     with dspy.context(allow_tool_async_sync_conversion=False):
388 |         with pytest.raises(ValueError):
389 |             result = tool(x=1, y="hello")
390 | 
391 |     with dspy.context(allow_tool_async_sync_conversion=True):
392 |         result = tool(x=1, y="hello")
393 |         assert result == "hello 1"
394 | 
395 | 
396 | TOOL_CALL_TEST_CASES = [
397 |     ([], {"tool_calls": []}),
398 |     (
399 |         [{"name": "search", "args": {"query": "hello"}}],
400 |         {
401 |             "tool_calls": [{"type": "function", "function": {"name": "search", "arguments": {"query": "hello"}}}],
402 |         },
403 |     ),
404 |     (
405 |         [
406 |             {"name": "search", "args": {"query": "hello"}},
407 |             {"name": "translate", "args": {"text": "world", "lang": "fr"}},
408 |         ],
409 |         {
410 |             "tool_calls": [
411 |                 {"type": "function", "function": {"name": "search", "arguments": {"query": "hello"}}},
412 |                 {
413 |                     "type": "function",
414 |                     "function": {"name": "translate", "arguments": {"text": "world", "lang": "fr"}},
415 |                 },
416 |             ],
417 |         },
418 |     ),
419 |     (
420 |         [{"name": "get_time", "args": {}}],
421 |         {
422 |             "tool_calls": [{"type": "function", "function": {"name": "get_time", "arguments": {}}}],
423 |         },
424 |     ),
425 | ]
426 | 
427 | 
428 | @pytest.mark.parametrize("tool_calls_data,expected", TOOL_CALL_TEST_CASES)
429 | def test_tool_calls_format_basic(tool_calls_data, expected):
430 |     """Test ToolCalls.format with various basic scenarios."""
431 |     tool_calls_list = [ToolCalls.ToolCall(**data) for data in tool_calls_data]
432 |     tool_calls = ToolCalls(tool_calls=tool_calls_list)
433 |     result = tool_calls.format()
434 | 
435 |     assert result == expected
436 | 
437 | 
438 | def test_tool_calls_format_from_dict_list():
439 |     """Test format works with ToolCalls created from from_dict_list."""
440 |     tool_calls_dicts = [
441 |         {"name": "search", "args": {"query": "hello"}},
442 |         {"name": "translate", "args": {"text": "world", "lang": "fr"}},
443 |     ]
444 | 
445 |     tool_calls = ToolCalls.from_dict_list(tool_calls_dicts)
446 |     result = tool_calls.format()
447 | 
448 |     assert len(result["tool_calls"]) == 2
449 |     assert result["tool_calls"][0]["function"]["name"] == "search"
450 |     assert result["tool_calls"][1]["function"]["name"] == "translate"
451 | 
452 | 
453 | def test_toolcalls_vague_match():
454 |     """
455 |     Test that ToolCalls can parse the data with slightly off format:
456 | 
457 |     - a single dict with "name" and "args"
458 |     - a list of dicts with "name" and "args"
459 |     - invalid input (should raise ValueError)
460 |     """
461 |     # Single dict with "name" and "args" should parse as one ToolCall
462 |     data_single = {"name": "search", "args": {"query": "hello"}}
463 |     tc = ToolCalls.model_validate(data_single)
464 |     assert isinstance(tc, ToolCalls)
465 |     assert len(tc.tool_calls) == 1
466 |     assert tc.tool_calls[0].name == "search"
467 |     assert tc.tool_calls[0].args == {"query": "hello"}
468 | 
469 |     # List of dicts with "name" and "args" should parse as multiple ToolCalls
470 |     data_list = [
471 |         {"name": "search", "args": {"query": "hello"}},
472 |         {"name": "translate", "args": {"text": "world", "lang": "fr"}},
473 |     ]
474 |     tc = ToolCalls.model_validate(data_list)
475 |     assert isinstance(tc, ToolCalls)
476 |     assert len(tc.tool_calls) == 2
477 |     assert tc.tool_calls[0].name == "search"
478 |     assert tc.tool_calls[1].name == "translate"
479 | 
480 |     # Dict with "tool_calls" key containing a list of dicts
481 |     data_tool_calls = {
482 |         "tool_calls": [
483 |             {"name": "search", "args": {"query": "hello"}},
484 |             {"name": "get_time", "args": {}},
485 |         ]
486 |     }
487 |     tc = ToolCalls.model_validate(data_tool_calls)
488 |     assert isinstance(tc, ToolCalls)
489 |     assert len(tc.tool_calls) == 2
490 |     assert tc.tool_calls[0].name == "search"
491 |     assert tc.tool_calls[1].name == "get_time"
492 | 
493 |     # Invalid input should raise ValueError
494 |     with pytest.raises(ValueError):
495 |         ToolCalls.model_validate({"foo": "bar"})
496 |     with pytest.raises(ValueError):
497 |         ToolCalls.model_validate([{"foo": "bar"}])
498 | 
499 | 
500 | def test_tool_convert_input_schema_to_tool_args_no_input_params():
501 |     args, arg_types, arg_desc = convert_input_schema_to_tool_args(schema={"properties": {}})
502 |     assert args == {}
503 |     assert arg_types == {}
504 |     assert arg_desc == {}
505 | 
506 | 
507 | def test_tool_convert_input_schema_to_tool_args_lang_chain():
508 |     # Example from langchain docs:
509 |     # https://web.archive.org/web/20250723101359/https://api.python.langchain.com/en/latest/tools/langchain_core.tools.tool.html
510 |     args, arg_types, arg_desc = convert_input_schema_to_tool_args(
511 |         schema={
512 |             "title": "fooSchema",
513 |             "description": "The foo.",
514 |             "type": "object",
515 |             "properties": {
516 |                 "bar": {
517 |                     "title": "Bar",
518 |                     "description": "The bar.",
519 |                     "type": "string",
520 |                 },
521 |                 "baz": {
522 |                     "title": "Baz",
523 |                     "type": "integer",
524 |                 },
525 |             },
526 |             "required": [
527 |                 "baz",
528 |             ],
529 |         }
530 |     )
531 |     assert args == {
532 |         "bar": {"title": "Bar", "description": "The bar.", "type": "string"},
533 |         "baz": {"title": "Baz", "type": "integer"},
534 |     }
535 |     assert arg_types == {
536 |         "bar": str,
537 |         "baz": int,
538 |     }
539 |     assert arg_desc == {
540 |         "bar": "The bar.",
541 |         "baz": "No description provided. (Required)",
542 |     }
543 | 
544 | 
545 | 
546 | 
547 | def test_tool_call_execute():
548 |     def get_weather(city: str) -> str:
549 |         return f"The weather in {city} is sunny"
550 | 
551 |     def add_numbers(a: int, b: int) -> int:
552 |         return a + b
553 | 
554 |     tools = [
555 |         dspy.Tool(get_weather),
556 |         dspy.Tool(add_numbers)
557 |     ]
558 | 
559 |     tool_call = dspy.ToolCalls.ToolCall(name="get_weather", args={"city": "Berlin"})
560 |     result = tool_call.execute(functions=tools)
561 |     assert result == "The weather in Berlin is sunny"
562 | 
563 |     # Test individual tool call with function dict
564 |     tool_call2 = dspy.ToolCalls.ToolCall(name="add_numbers", args={"a": 7, "b": 13})
565 |     result2 = tool_call2.execute(functions={"add_numbers": add_numbers})
566 |     assert result2 == 20
567 | 
568 |     # Test individual tool call with no arguments
569 |     def get_pi():
570 |         return 3.14159
571 | 
572 |     tool_call3 = dspy.ToolCalls.ToolCall(name="get_pi", args={})
573 |     result3 = tool_call3.execute(functions={"get_pi": get_pi})
574 |     assert result3 == 3.14159
575 | 
576 |     # Test error case
577 |     tool_call4 = dspy.ToolCalls.ToolCall(name="nonexistent", args={})
578 |     try:
579 |         tool_call4.execute(functions=tools)
580 |         assert False, "Should have raised ValueError"
581 |     except ValueError as e:
582 |         assert "not found" in str(e)
583 | 
584 | 
585 | def test_tool_call_execute_with_local_functions():
586 |     def main():
587 |         def local_add(a: int, b: int) -> int:
588 |             return a + b
589 | 
590 |         def local_multiply(x: int, y: int) -> int:
591 |             return x * y
592 | 
593 |         # Test individual execution with local function
594 |         tool_call1 = dspy.ToolCalls.ToolCall(name="local_add", args={"a": 10, "b": 15})
595 |         result1 = tool_call1.execute()  # Should find local function automatically
596 |         assert result1 == 25
597 | 
598 |         tool_call2 = dspy.ToolCalls.ToolCall(name="local_multiply", args={"x": 4, "y": 7})
599 |         result2 = tool_call2.execute()  # Should find local function automatically
600 |         assert result2 == 28
601 | 
602 |         # Test locals take precedence over globals
603 |         try:
604 |             globals()["local_add"] = lambda a, b: a + b + 1000
605 |             precedence_call = dspy.ToolCalls.ToolCall(name="local_add", args={"a": 1, "b": 2})
606 |             result = precedence_call.execute()
607 |             assert result == 3  # Should use local function (1+2=3), not global (1+2+1000=1003)
608 |         finally:
609 |             globals().pop("local_add", None)
610 | 
611 |     main()
612 | 
```

--------------------------------------------------------------------------------
/docs/docs/tutorials/streaming/index.md:
--------------------------------------------------------------------------------

```markdown
  1 | # Streaming
  2 | 
  3 | In this guide, we will walk you through how to enable streaming in your DSPy program. DSPy Streaming
  4 | consists of two parts:
  5 | 
  6 | - **Output Token Streaming**: Stream individual tokens as they're generated, rather than waiting for the complete response.
  7 | - **Intermediate Status Streaming**: Provide real-time updates about the program's execution state (e.g., "Calling web search...", "Processing results...").
  8 | 
  9 | ## Output Token Streaming
 10 | 
 11 | DSPy's token streaming feature works with any module in your pipeline, not just the final output. The only requirement is that the streamed field must be of type `str`. To enable token streaming:
 12 | 
 13 | 1. Wrap your program with `dspy.streamify`
 14 | 2. Create one or more `dspy.streaming.StreamListener` objects to specify which fields to stream
 15 | 
 16 | Here's a basic example:
 17 | 
 18 | ```python
 19 | import os
 20 | 
 21 | import dspy
 22 | 
 23 | os.environ["OPENAI_API_KEY"] = "your_api_key"
 24 | 
 25 | dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))
 26 | 
 27 | predict = dspy.Predict("question->answer")
 28 | 
 29 | # Enable streaming for the 'answer' field
 30 | stream_predict = dspy.streamify(
 31 |     predict,
 32 |     stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")],
 33 | )
 34 | ```
 35 | 
 36 | To consume the streamed output:
 37 | 
 38 | ```python
 39 | import asyncio
 40 | 
 41 | async def read_output_stream():
 42 |     output_stream = stream_predict(question="Why did a chicken cross the kitchen?")
 43 | 
 44 |     async for chunk in output_stream:
 45 |         print(chunk)
 46 | 
 47 | asyncio.run(read_output_stream())
 48 | ```
 49 | 
 50 | This will produce output like:
 51 | 
 52 | ```
 53 | StreamResponse(predict_name='self', signature_field_name='answer', chunk='To')
 54 | StreamResponse(predict_name='self', signature_field_name='answer', chunk=' get')
 55 | StreamResponse(predict_name='self', signature_field_name='answer', chunk=' to')
 56 | StreamResponse(predict_name='self', signature_field_name='answer', chunk=' the')
 57 | StreamResponse(predict_name='self', signature_field_name='answer', chunk=' other')
 58 | StreamResponse(predict_name='self', signature_field_name='answer', chunk=' side of the frying pan!')
 59 | Prediction(
 60 |     answer='To get to the other side of the frying pan!'
 61 | )
 62 | ```
 63 | 
 64 | Note: Since `dspy.streamify` returns an async generator, you must use it within an async context. If you're using an environment like Jupyter or Google Colab that already has an event loop (async context), you can use the generator directly.
 65 | 
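    | For instance, in most notebook environments you can iterate over the stream with a top-level
    | `async for`, reusing `stream_predict` from the example above (a minimal sketch):
    | 
    | ```python
    | # Inside a Jupyter/Colab cell, which already runs an event loop,
    | # the async generator can be consumed directly:
    | output_stream = stream_predict(question="Why did a chicken cross the kitchen?")
    | 
    | async for chunk in output_stream:
    |     print(chunk)
    | ```
    | 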
 66 | You may have noticed that the stream above contains two different entities: `StreamResponse`
 67 | and `Prediction`. `StreamResponse` wraps the streaming tokens of the field being listened to, which in
 68 | this example is the `answer` field. `Prediction` is the program's final output. In DSPy, streaming is
 69 | implemented in a sidecar fashion: we enable streaming on the LM so that it emits a stream of tokens. These
 70 | tokens are sent to a side channel that is continuously read by the user-defined listeners. Each listener
 71 | interprets the stream and decides whether the `signature_field_name` it is listening to has started to appear
 72 | and whether it has finalized. Once it decides that the field has appeared, the listener begins emitting tokens
 73 | to the async generator that users read from. A listener's internal mechanism depends on the adapter behind the
 74 | scenes, and because we usually cannot tell that a field has finalized until we see the next field, the listener
 75 | buffers the output tokens before sending them to the final generator. This is why the last `StreamResponse`
 76 | chunk usually contains more than one token. The program's final output is also written to the stream, as the
 77 | `Prediction` chunk in the sample output above.
 78 | 
 79 | To handle these different types and implement custom logic:
 80 | 
 81 | ```python
 82 | import asyncio
 83 | 
 84 | async def read_output_stream():
 85 |   output_stream = stream_predict(question="Why did a chicken cross the kitchen?")
 86 | 
 87 |   return_value = None
 88 |   async for chunk in output_stream:
 89 |     if isinstance(chunk, dspy.streaming.StreamResponse):
 90 |       print(f"Output token of field {chunk.signature_field_name}: {chunk.chunk}")
 91 |     elif isinstance(chunk, dspy.Prediction):
 92 |       return_value = chunk
 93 |   return return_value
 94 | 
 95 | program_output = asyncio.run(read_output_stream())
 96 | print("Final output: ", program_output)
 97 | ```
 98 | 
 99 | ### Understand `StreamResponse`
100 | 
101 | `StreamResponse` (`dspy.streaming.StreamResponse`) is the wrapper class for streaming tokens. It has three
102 | fields:
103 | 
104 | - `predict_name`: the name of the predictor that produces the `signature_field_name`. This name matches
105 |   the keys returned by `your_program.named_predictors()`. In the code above, because `answer` comes from
106 |   `predict` itself, the `predict_name` shows up as `self`, which is the only key returned by
107 |   `predict.named_predictors()`.
108 | - `signature_field_name`: the output field that these tokens map to. `predict_name` and `signature_field_name`
109 |   together form the unique identifier of the field. We will demonstrate how to handle streaming multiple fields
110 |   and duplicated field names later in this guide.
111 | - `chunk`: the value of the stream chunk.
112 | 
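    | As a small sketch of how these fields fit together, the pair `(predict_name, signature_field_name)`
    | can be used to route chunks, reusing `stream_predict` and the `asyncio` import from the earlier example:
    | 
    | ```python
    | async def route_stream():
    |     output_stream = stream_predict(question="Why did a chicken cross the kitchen?")
    | 
    |     async for chunk in output_stream:
    |         if isinstance(chunk, dspy.streaming.StreamResponse):
    |             # (predict_name, signature_field_name) uniquely identifies the streamed field.
    |             key = (chunk.predict_name, chunk.signature_field_name)
    |             print(f"{key}: {chunk.chunk}")
    | 
    | asyncio.run(route_stream())
    | ```
    | 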
113 | ### Streaming with Cache
114 | 
115 | When a cached result is found, the stream will skip individual tokens and only yield the final `Prediction`. For example:
116 | 
117 | ```
118 | Prediction(
119 |     answer='To get to the other side of the dinner plate!'
120 | )
121 | ```
122 | 
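    | If you want token-level streaming on every call while experimenting, one option, as the later
    | examples in this guide do, is to disable the LM cache:
    | 
    | ```python
    | # With cache=False, repeated calls stream tokens instead of returning
    | # only the cached final Prediction.
    | dspy.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False))
    | ```
    | 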
123 | ### Streaming Multiple Fields
124 | 
125 | You can monitor multiple fields by creating a `StreamListener` for each one. Here's an example with a multi-module program:
126 | 
127 | ```python
128 | import asyncio
129 | 
130 | import dspy
131 | 
132 | lm = dspy.LM("openai/gpt-4o-mini", cache=False)
133 | dspy.settings.configure(lm=lm)
134 | 
135 | 
136 | class MyModule(dspy.Module):
137 |     def __init__(self):
138 |         super().__init__()
139 | 
140 |         self.predict1 = dspy.Predict("question->answer")
141 |         self.predict2 = dspy.Predict("answer->simplified_answer")
142 | 
143 |     def forward(self, question: str, **kwargs):
144 |         answer = self.predict1(question=question)
145 |         simplified_answer = self.predict2(answer=answer)
146 |         return simplified_answer
147 | 
148 | 
149 | predict = MyModule()
150 | stream_listeners = [
151 |     dspy.streaming.StreamListener(signature_field_name="answer"),
152 |     dspy.streaming.StreamListener(signature_field_name="simplified_answer"),
153 | ]
154 | stream_predict = dspy.streamify(
155 |     predict,
156 |     stream_listeners=stream_listeners,
157 | )
158 | 
159 | async def read_output_stream():
160 |     output = stream_predict(question="why did a chicken cross the kitchen?")
161 | 
162 |     return_value = None
163 |     async for chunk in output:
164 |         if isinstance(chunk, dspy.streaming.StreamResponse):
165 |             print(chunk)
166 |         elif isinstance(chunk, dspy.Prediction):
167 |             return_value = chunk
168 |     return return_value
169 | 
170 | program_output = asyncio.run(read_output_stream())
171 | print("Final output: ", program_output)
172 | ```
173 | 
174 | The output will look like:
175 | 
176 | ```
177 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk='To')
178 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' get')
179 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' to')
180 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' the')
181 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' other side of the recipe!')
182 | StreamResponse(predict_name='predict2', signature_field_name='simplified_answer', chunk='To')
183 | StreamResponse(predict_name='predict2', signature_field_name='simplified_answer', chunk=' reach')
184 | StreamResponse(predict_name='predict2', signature_field_name='simplified_answer', chunk=' the')
185 | StreamResponse(predict_name='predict2', signature_field_name='simplified_answer', chunk=' other side of the recipe!')
186 | Final output:  Prediction(
187 |     simplified_answer='To reach the other side of the recipe!'
188 | )
189 | ```
190 | 
191 | ### Streaming the Same Field Multiple Times (as in dspy.ReAct)
192 | 
193 | By default, a `StreamListener` automatically closes itself after completing a single streaming session.
194 | This design helps prevent performance issues, since every token is broadcast to all configured stream listeners,
195 | and having too many active listeners can introduce significant overhead.
196 | 
197 | However, when a DSPy module is used repeatedly in a loop, such as with `dspy.ReAct`, you may want to stream
198 | the same field from every prediction, each time it is produced. To enable this behavior, set `allow_reuse=True` when creating
199 | your `StreamListener`. See the example below:
200 | 
201 | ```python
202 | import asyncio
203 | 
204 | import dspy
205 | 
206 | lm = dspy.LM("openai/gpt-4o-mini", cache=False)
207 | dspy.settings.configure(lm=lm)
208 | 
209 | 
210 | def fetch_user_info(user_name: str):
211 |     """Get user information like name, birthday, etc."""
212 |     return {
213 |         "name": user_name,
214 |         "birthday": "2009-05-16",
215 |     }
216 | 
217 | 
218 | def get_sports_news(year: int):
219 |     """Get sports news for a given year."""
220 |     if year == 2009:
221 |         return "Usane Bolt broke the world record in the 100m race."
222 |     return None
223 | 
224 | 
225 | react = dspy.ReAct("question->answer", tools=[fetch_user_info, get_sports_news])
226 | 
227 | stream_listeners = [
228 |     # dspy.ReAct has a built-in output field called "next_thought".
229 |     dspy.streaming.StreamListener(signature_field_name="next_thought", allow_reuse=True),
230 | ]
231 | stream_react = dspy.streamify(react, stream_listeners=stream_listeners)
232 | 
233 | 
234 | async def read_output_stream():
235 |     output = stream_react(question="What sports news happened in the year Adam was born?")
236 |     return_value = None
237 |     async for chunk in output:
238 |         if isinstance(chunk, dspy.streaming.StreamResponse):
239 |             print(chunk)
240 |         elif isinstance(chunk, dspy.Prediction):
241 |             return_value = chunk
242 |     return return_value
243 | 
244 | 
245 | print(asyncio.run(read_output_stream()))
246 | ```
247 | 
248 | In this example, setting `allow_reuse=True` in the `StreamListener` ensures that streaming for `next_thought` is
249 | available for every iteration, not just the first. When you run this code, you will see the streaming tokens for `next_thought`
250 | printed each time the field is produced.
251 | 
252 | #### Handling Duplicate Field Names
253 | 
254 | When streaming fields with the same name from different modules, specify both the `predict` and `predict_name` in the `StreamListener`:
255 | 
256 | ```python
257 | import asyncio
258 | 
259 | import dspy
260 | 
261 | lm = dspy.LM("openai/gpt-4o-mini", cache=False)
262 | dspy.settings.configure(lm=lm)
263 | 
264 | 
265 | class MyModule(dspy.Module):
266 |     def __init__(self):
267 |         super().__init__()
268 | 
269 |         self.predict1 = dspy.Predict("question->answer")
270 |         self.predict2 = dspy.Predict("question, answer->answer, score")
271 | 
272 |     def forward(self, question: str, **kwargs):
273 |         answer = self.predict1(question=question)
274 |         simplified_answer = self.predict2(answer=answer)
275 |         return simplified_answer
276 | 
277 | 
278 | predict = MyModule()
279 | stream_listeners = [
280 |     dspy.streaming.StreamListener(
281 |         signature_field_name="answer",
282 |         predict=predict.predict1,
283 |         predict_name="predict1"
284 |     ),
285 |     dspy.streaming.StreamListener(
286 |         signature_field_name="answer",
287 |         predict=predict.predict2,
288 |         predict_name="predict2"
289 |     ),
290 | ]
291 | stream_predict = dspy.streamify(
292 |     predict,
293 |     stream_listeners=stream_listeners,
294 | )
295 | 
296 | 
297 | async def read_output_stream():
298 |     output = stream_predict(question="why did a chicken cross the kitchen?")
299 | 
300 |     return_value = None
301 |     async for chunk in output:
302 |         if isinstance(chunk, dspy.streaming.StreamResponse):
303 |             print(chunk)
304 |         elif isinstance(chunk, dspy.Prediction):
305 |             return_value = chunk
306 |     return return_value
307 | 
308 | 
309 | program_output = asyncio.run(read_output_stream())
310 | print("Final output: ", program_output)
311 | ```
312 | 
313 | The output will look like:
314 | 
315 | ```
316 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk='To')
317 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' get')
318 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' to')
319 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' the')
320 | StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' other side of the recipe!')
321 | StreamResponse(predict_name='predict2', signature_field_name='answer', chunk="I'm")
322 | StreamResponse(predict_name='predict2', signature_field_name='answer', chunk=' ready')
323 | StreamResponse(predict_name='predict2', signature_field_name='answer', chunk=' to')
324 | StreamResponse(predict_name='predict2', signature_field_name='answer', chunk=' assist')
325 | StreamResponse(predict_name='predict2', signature_field_name='answer', chunk=' you')
326 | StreamResponse(predict_name='predict2', signature_field_name='answer', chunk='! Please provide a question.')
327 | Final output:  Prediction(
328 |     answer="I'm ready to assist you! Please provide a question.",
329 |     score='N/A'
330 | )
331 | ```
332 | 
333 | ## Intermediate Status Streaming
334 | 
335 | Status streaming keeps users informed about the program's progress, which is especially useful for long-running operations like tool calls or complex AI pipelines. To implement status streaming:
336 | 
337 | 1. Create a custom status message provider by subclassing `dspy.streaming.StatusMessageProvider`
338 | 2. Override the desired hook methods to provide custom status messages
339 | 3. Pass your provider to `dspy.streamify`
340 | 
341 | Example:
342 | 
343 | ```python
344 | class MyStatusMessageProvider(dspy.streaming.StatusMessageProvider):
345 |     def lm_start_status_message(self, instance, inputs):
346 |         return f"Calling LM with inputs {inputs}..."
347 | 
348 |     def lm_end_status_message(self, outputs):
349 |         return f"Tool finished with output: {outputs}!"
350 | ```
351 | 
352 | Available hooks:
353 | 
354 | - `lm_start_status_message`: status message at the start of calling `dspy.LM`.
355 | - `lm_end_status_message`: status message at the end of calling `dspy.LM`.
356 | - `module_start_status_message`: status message at the start of calling a `dspy.Module`.
357 | - `module_end_status_message`: status message at the end of calling a `dspy.Module`.
358 | - `tool_start_status_message`: status message at the start of calling `dspy.Tool`.
359 | - `tool_end_status_message`: status message at the end of calling `dspy.Tool`.
360 | 
361 | Each hook should return a string containing the status message.
362 | 
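    | The module-level hooks are not shown in the examples in this guide; a minimal sketch, assuming they
    | receive arguments analogous to the LM and tool hooks above, could look like this:
    | 
    | ```python
    | class ModuleStatusProvider(dspy.streaming.StatusMessageProvider):
    |     def module_start_status_message(self, instance, inputs):
    |         # `instance` is assumed to be the dspy.Module being called.
    |         return f"Running module {instance.__class__.__name__}..."
    | 
    |     def module_end_status_message(self, outputs):
    |         return "Module finished!"
    | ```
    | 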
363 | After creating the message provider, pass it to `dspy.streamify` to enable both
364 | status message streaming and output token streaming, as in the example below. Intermediate
365 | status messages are represented by the class `dspy.streaming.StatusMessage`, so we need an
366 | additional condition check to capture them.
367 | 
368 | ```python
369 | import asyncio
370 | 
371 | import dspy
372 | 
373 | lm = dspy.LM("openai/gpt-4o-mini", cache=False)
374 | dspy.settings.configure(lm=lm)
375 | 
376 | 
377 | class MyModule(dspy.Module):
378 |     def __init__(self):
379 |         super().__init__()
380 | 
381 |         self.tool = dspy.Tool(lambda x: 2 * x, name="double_the_number")
382 |         self.predict = dspy.ChainOfThought("num1, num2->sum")
383 | 
384 |     def forward(self, num, **kwargs):
385 |         num2 = self.tool(x=num)
386 |         return self.predict(num1=num, num2=num2)
387 | 
388 | 
389 | class MyStatusMessageProvider(dspy.streaming.StatusMessageProvider):
390 |     def tool_start_status_message(self, instance, inputs):
391 |         return f"Calling Tool {instance.name} with inputs {inputs}..."
392 | 
393 |     def tool_end_status_message(self, outputs):
394 |         return f"Tool finished with output: {outputs}!"
395 | 
396 | 
397 | predict = MyModule()
398 | stream_listeners = [
399 |     # dspy.ChainOfThought has a built-in output field called "reasoning".
400 |     dspy.streaming.StreamListener(signature_field_name="reasoning"),
401 | ]
402 | stream_predict = dspy.streamify(
403 |     predict,
404 |     stream_listeners=stream_listeners,
405 |     status_message_provider=MyStatusMessageProvider(),
406 | )
407 | 
408 | 
409 | async def read_output_stream():
410 |     output = stream_predict(num=3)
411 | 
412 |     return_value = None
413 |     async for chunk in output:
414 |         if isinstance(chunk, dspy.streaming.StreamResponse):
415 |             print(chunk)
416 |         elif isinstance(chunk, dspy.Prediction):
417 |             return_value = chunk
418 |         elif isinstance(chunk, dspy.streaming.StatusMessage):
419 |             print(chunk)
420 |     return return_value
421 | 
422 | 
423 | program_output = asyncio.run(read_output_stream())
424 | print("Final output: ", program_output)
425 | ```
426 | 
427 | Sample output:
428 | 
429 | ```
430 | StatusMessage(message='Calling tool double_the_number...')
431 | StatusMessage(message='Tool calling finished! Querying the LLM with tool calling results...')
432 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk='To')
433 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' find')
434 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' the')
435 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' sum')
436 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' of')
437 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' the')
438 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' two')
439 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' numbers')
440 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=',')
441 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' we')
442 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' simply')
443 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' add')
444 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' them')
445 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' together')
446 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk='.')
447 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' Here')
448 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=',')
449 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' ')
450 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk='3')
451 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' plus')
452 | StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' 6 equals 9.')
453 | Final output:  Prediction(
454 |     reasoning='To find the sum of the two numbers, we simply add them together. Here, 3 plus 6 equals 9.',
455 |     sum='9'
456 | )
457 | ```
458 | 
459 | ## Synchronous Streaming
460 | 
461 | By default, calling a streamified DSPy program produces an async generator. To get back
462 | a sync generator, set the flag `async_streaming=False`:
463 | 
464 | 
465 | ```python
466 | import os
467 | 
468 | import dspy
469 | 
470 | os.environ["OPENAI_API_KEY"] = "your_api_key"
471 | 
472 | dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))
473 | 
474 | predict = dspy.Predict("question->answer")
475 | 
476 | # Enable streaming for the 'answer' field
477 | stream_predict = dspy.streamify(
478 |     predict,
479 |     stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")],
480 |     async_streaming=False,
481 | )
482 | 
483 | output = stream_predict(question="why did a chicken cross the kitchen?")
484 | 
485 | program_output = None
486 | for chunk in output:
487 |     if isinstance(chunk, dspy.streaming.StreamResponse):
488 |         print(chunk)
489 |     elif isinstance(chunk, dspy.Prediction):
490 |         program_output = chunk
491 | print(f"Program output: {program_output}")
492 | ```
493 | 
```

--------------------------------------------------------------------------------
/tests/signatures/test_adapter_image.py:
--------------------------------------------------------------------------------

```python
  1 | import os
  2 | import tempfile
  3 | from io import BytesIO
  4 | 
  5 | import pydantic
  6 | import pytest
  7 | import requests
  8 | from PIL import Image as PILImage
  9 | 
 10 | import dspy
 11 | from dspy.adapters.types.image import encode_image
 12 | from dspy.utils.dummies import DummyLM
 13 | 
 14 | 
 15 | @pytest.fixture
 16 | def sample_pil_image():
 17 |     """Fixture to provide a sample image for testing"""
 18 |     url = "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg"
 19 |     response = requests.get(url)
 20 |     response.raise_for_status()
 21 |     return PILImage.open(BytesIO(response.content))
 22 | 
 23 | 
 24 | @pytest.fixture
 25 | def sample_dspy_image_download():
 26 |     url = "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg"
 27 |     return dspy.Image(url, download=True)
 28 | 
 29 | 
 30 | @pytest.fixture
 31 | def sample_url():
 32 |     return "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg"
 33 | 
 34 | 
 35 | @pytest.fixture
 36 | def sample_dspy_image_no_download():
 37 |     return dspy.Image("https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg")
 38 | 
 39 | 
 40 | def count_messages_with_image_url_pattern(messages):
 41 |     pattern = {"type": "image_url", "image_url": {"url": lambda x: isinstance(x, str)}}
 42 | 
 43 |     try:
 44 | 
 45 |         def check_pattern(obj, pattern):
 46 |             if isinstance(pattern, dict):
 47 |                 if not isinstance(obj, dict):
 48 |                     return False
 49 |                 return all(k in obj and check_pattern(obj[k], v) for k, v in pattern.items())
 50 |             if callable(pattern):
 51 |                 return pattern(obj)
 52 |             return obj == pattern
 53 | 
 54 |         def count_patterns(obj, pattern):
 55 |             count = 0
 56 |             if check_pattern(obj, pattern):
 57 |                 count += 1
 58 |             if isinstance(obj, dict):
 59 |                 count += sum(count_patterns(v, pattern) for v in obj.values())
 60 |             if isinstance(obj, (list, tuple)):
 61 |                 count += sum(count_patterns(v, pattern) for v in obj)
 62 |             return count
 63 | 
 64 |         return count_patterns(messages, pattern)
 65 |     except Exception:
 66 |         return 0
 67 | 
 68 | 
 69 | def setup_predictor(signature, expected_output):
 70 |     """Helper to set up a predictor with DummyLM"""
 71 |     lm = DummyLM([expected_output])
 72 |     dspy.settings.configure(lm=lm)
 73 |     return dspy.Predict(signature), lm
 74 | 
 75 | 
 76 | @pytest.mark.parametrize(
 77 |     "test_case",
 78 |     [
 79 |         {
 80 |             "name": "probabilistic_classification",
 81 |             "signature": "image: dspy.Image, class_labels: list[str] -> probabilities: dict[str, float]",
 82 |             "inputs": {"image": "https://example.com/dog.jpg", "class_labels": ["dog", "cat", "bird"]},
 83 |             "key_output": "probabilities",
 84 |             "expected": {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}},
 85 |         },
 86 |         {
 87 |             "name": "image_to_code",
 88 |             "signature": "ui_image: dspy.Image, target_language: str -> generated_code: str",
 89 |             "inputs": {"ui_image": "https://example.com/button.png", "target_language": "HTML"},
 90 |             "key_output": "generated_code",
 91 |             "expected": {"generated_code": "<button>Click me</button>"},
 92 |         },
 93 |         {
 94 |             "name": "bbox_detection",
 95 |             "signature": "image: dspy.Image -> bboxes: list[Tuple[int, int, int, int]]",
 96 |             "inputs": {"image": "https://example.com/image.jpg"},
 97 |             "key_output": "bboxes",
 98 |             "expected": {"bboxes": [(10, 20, 30, 40), (50, 60, 70, 80)]},
 99 |         },
100 |         {
101 |             "name": "multilingual_caption",
102 |             "signature": "image: dspy.Image, languages: list[str] -> captions: dict[str, str]",
103 |             "inputs": {"image": "https://example.com/dog.jpg", "languages": ["en", "es", "fr"]},
104 |             "key_output": "captions",
105 |             "expected": {
106 |                 "captions": {"en": "A golden retriever", "es": "Un golden retriever", "fr": "Un golden retriever"}
107 |             },
108 |         },
109 |     ],
110 | )
111 | def test_basic_image_operations(test_case):
112 |     """Consolidated test for basic image operations"""
113 |     predictor, lm = setup_predictor(test_case["signature"], test_case["expected"])
114 | 
115 |     # Convert string URLs to dspy.Image objects
116 |     inputs = {
117 |         k: dspy.Image(v) if isinstance(v, str) and k in ["image", "ui_image"] else v
118 |         for k, v in test_case["inputs"].items()
119 |     }
120 | 
121 |     result = predictor(**inputs)
122 | 
123 |     # Check result based on output field name
124 |     output_field = next(f for f in ["probabilities", "generated_code", "bboxes", "captions"] if hasattr(result, f))
125 |     assert getattr(result, output_field) == test_case["expected"][test_case["key_output"]]
126 |     assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1
127 | 
128 | 
129 | @pytest.mark.parametrize(
130 |     "image_input,description",
131 |     [
132 |         ("pil_image", "PIL Image"),
133 |         ("encoded_pil_image", "encoded PIL image string"),
134 |         ("dspy_image_download", "dspy.Image with download=True"),
135 |         ("dspy_image_no_download", "dspy.Image without download"),
136 |     ],
137 | )
138 | def test_image_input_formats(
139 |     request, sample_pil_image, sample_dspy_image_download, sample_dspy_image_no_download, image_input, description
140 | ):
141 |     """Test different input formats for image fields"""
142 |     signature = "image: dspy.Image, class_labels: list[str] -> probabilities: dict[str, float]"
143 |     expected = {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}}
144 |     predictor, lm = setup_predictor(signature, expected)
145 | 
146 |     input_map = {
147 |         "pil_image": sample_pil_image,
148 |         "encoded_pil_image": encode_image(sample_pil_image),
149 |         "dspy_image_download": sample_dspy_image_download,
150 |         "dspy_image_no_download": sample_dspy_image_no_download,
151 |     }
152 | 
153 |     actual_input = input_map[image_input]
154 |     # TODO(isaacbmiller): Support the cases without direct dspy.Image coercion
155 |     if image_input in ["pil_image", "encoded_pil_image"]:
156 |         pytest.xfail(f"{description} not fully supported without dspy.Image coercion")
157 | 
158 |     result = predictor(image=actual_input, class_labels=["dog", "cat", "bird"])
159 |     assert result.probabilities == expected["probabilities"]
160 |     assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1
161 | 
162 | 
163 | def test_predictor_save_load(sample_url, sample_pil_image):
164 |     """Test saving and loading predictors with image fields"""
165 |     signature = "image: dspy.Image -> caption: str"
166 |     examples = [
167 |         dspy.Example(image=dspy.Image(sample_url), caption="Example 1"),
168 |         dspy.Example(image=sample_pil_image, caption="Example 2"),
169 |     ]
170 | 
171 |     predictor, lm = setup_predictor(signature, {"caption": "A golden retriever"})
172 |     optimizer = dspy.teleprompt.LabeledFewShot(k=1)
173 |     compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False)
174 | 
175 |     with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".json") as temp_file:
176 |         compiled_predictor.save(temp_file.name)
177 |         loaded_predictor = dspy.Predict(signature)
178 |         loaded_predictor.load(temp_file.name)
179 | 
180 |     loaded_predictor(image=dspy.Image("https://example.com/dog.jpg"))
181 |     assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 2
182 |     assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"])
183 | 
184 | 
185 | def test_save_load_complex_default_types():
186 |     """Test saving and loading predictors with complex default types (lists of images)"""
187 |     examples = [
188 |         dspy.Example(
189 |             image_list=[
190 |                 dspy.Image("https://example.com/dog.jpg"),
191 |                 dspy.Image("https://example.com/cat.jpg"),
192 |             ],
193 |             caption="Example 1",
194 |         ).with_inputs("image_list"),
195 |     ]
196 | 
197 |     class ComplexTypeSignature(dspy.Signature):
198 |         image_list: list[dspy.Image] = dspy.InputField(desc="A list of images")
199 |         caption: str = dspy.OutputField(desc="A caption for the image list")
200 | 
201 |     predictor, lm = setup_predictor(ComplexTypeSignature, {"caption": "A list of images"})
202 |     optimizer = dspy.teleprompt.LabeledFewShot(k=1)
203 |     compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False)
204 | 
205 |     with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".json") as temp_file:
206 |         compiled_predictor.save(temp_file.name)
207 |         loaded_predictor = dspy.Predict(ComplexTypeSignature)
208 |         loaded_predictor.load(temp_file.name)
209 | 
210 |     result = loaded_predictor(**examples[0].inputs())
211 |     assert result.caption == "A list of images"
212 |     assert str(lm.history[-1]["messages"]).count("'url'") == 4
213 |     assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"])
214 | 
215 | 
216 | class BasicImageSignature(dspy.Signature):
217 |     """Basic signature with a single image input"""
218 | 
219 |     image: dspy.Image = dspy.InputField()
220 |     output: str = dspy.OutputField()
221 | 
222 | 
223 | class ImageListSignature(dspy.Signature):
224 |     """Signature with a list of images input"""
225 | 
226 |     image_list: list[dspy.Image] = dspy.InputField()
227 |     output: str = dspy.OutputField()
228 | 
229 | 
230 | @pytest.mark.parametrize(
231 |     "test_case",
232 |     [
233 |         {
234 |             "name": "basic_dspy_signature",
235 |             "signature_class": BasicImageSignature,
236 |             "inputs": {"image": "https://example.com/dog.jpg"},
237 |             "expected": {"output": "A dog photo"},
238 |             "expected_image_urls": 2,
239 |         },
240 |         {
241 |             "name": "list_dspy_signature",
242 |             "signature_class": ImageListSignature,
243 |             "inputs": {"image_list": ["https://example.com/dog.jpg", "https://example.com/cat.jpg"]},
244 |             "expected": {"output": "Multiple photos"},
245 |             "expected_image_urls": 4,
246 |         },
247 |     ],
248 | )
249 | def test_save_load_complex_types(test_case):
250 |     """Test saving and loading predictors with complex types"""
251 |     signature_cls = test_case["signature_class"]
252 | 
253 |     # Convert string URLs to dspy.Image objects in input
254 |     processed_input = {}
255 |     for key, value in test_case["inputs"].items():
256 |         if isinstance(value, str) and "http" in value:
257 |             processed_input[key] = dspy.Image(value)
258 |         elif isinstance(value, list) and value and isinstance(value[0], str):
259 |             processed_input[key] = [dspy.Image(url) for url in value]
260 |         else:
261 |             processed_input[key] = value
262 | 
263 |     # Create example and predictor
264 |     examples = [dspy.Example(**processed_input, **test_case["expected"]).with_inputs(*processed_input.keys())]
265 | 
266 |     predictor, lm = setup_predictor(signature_cls, test_case["expected"])
267 |     optimizer = dspy.teleprompt.LabeledFewShot(k=1)
268 |     compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False)
269 | 
270 |     # Test save and load
271 |     with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".json") as temp_file:
272 |         compiled_predictor.save(temp_file.name)
273 |         loaded_predictor = dspy.Predict(signature_cls)
274 |         loaded_predictor.load(temp_file.name)
275 | 
276 |     # Run prediction
277 |     result = loaded_predictor(**processed_input)
278 | 
279 |     # Verify output matches expected
280 |     for key, value in test_case["expected"].items():
281 |         assert getattr(result, key) == value
282 | 
283 |     # Verify correct number of image URLs in messages
284 |     assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == test_case["expected_image_urls"]
285 |     assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"])
286 | 
287 | 
288 | def test_save_load_pydantic_model():
289 |     """Test saving and loading predictors with pydantic models"""
290 | 
291 |     class ImageModel(pydantic.BaseModel):
292 |         image: dspy.Image
293 |         image_list: list[dspy.Image] | None = None
294 |         output: str
295 | 
296 |     class PydanticSignature(dspy.Signature):
297 |         model_input: ImageModel = dspy.InputField()
298 |         output: str = dspy.OutputField()
299 | 
300 |     # Create model instance
301 |     model_input = ImageModel(
302 |         image=dspy.Image("https://example.com/dog.jpg"),
303 |         image_list=[dspy.Image("https://example.com/cat.jpg")],
304 |         output="Multiple photos",
305 |     )
306 | 
307 |     # Create example and predictor
308 |     examples = [dspy.Example(model_input=model_input, output="Multiple photos").with_inputs("model_input")]
309 | 
310 |     predictor, lm = setup_predictor(PydanticSignature, {"output": "Multiple photos"})
311 |     optimizer = dspy.teleprompt.LabeledFewShot(k=1)
312 |     compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False)
313 | 
314 |     # Test save and load
315 |     with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".json") as temp_file:
316 |         compiled_predictor.save(temp_file.name)
317 |         loaded_predictor = dspy.Predict(PydanticSignature)
318 |         loaded_predictor.load(temp_file.name)
319 | 
320 |     # Run prediction
321 |     result = loaded_predictor(model_input=model_input)
322 | 
323 |     # Verify output matches expected
324 |     assert result.output == "Multiple photos"
325 |     assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 4
326 |     assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"])
327 | 
328 | 
329 | def test_optional_image_field():
330 |     """Test that optional image fields are not required"""
331 | 
332 |     class OptionalImageSignature(dspy.Signature):
333 |         image: dspy.Image | None = dspy.InputField()
334 |         output: str = dspy.OutputField()
335 | 
336 |     predictor, lm = setup_predictor(OptionalImageSignature, {"output": "Hello"})
337 |     result = predictor(image=None)
338 |     assert result.output == "Hello"
339 |     assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 0
340 | 
341 | 
342 | def test_pdf_url_support():
343 |     """Test support for PDF files from URLs"""
344 |     pdf_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
345 | 
346 |     # Create a dspy.Image object from the PDF URL with download=True
347 |     pdf_image = dspy.Image(pdf_url, download=True)
348 | 
349 |     # The data URI should contain application/pdf in the MIME type
350 |     assert "data:application/pdf" in pdf_image.url
351 |     assert ";base64," in pdf_image.url
352 | 
353 |     # Test using it in a predictor
354 |     class PDFSignature(dspy.Signature):
355 |         document: dspy.Image = dspy.InputField(desc="A PDF document")
356 |         summary: str = dspy.OutputField(desc="A summary of the PDF")
357 | 
358 |     predictor, lm = setup_predictor(PDFSignature, {"summary": "This is a dummy PDF"})
359 |     result = predictor(document=pdf_image)
360 | 
361 |     assert result.summary == "This is a dummy PDF"
362 |     assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1
363 | 
364 |     # Ensure the URL was properly expanded in messages
365 |     messages_str = str(lm.history[-1]["messages"])
366 |     assert "application/pdf" in messages_str
367 | 
368 | 
369 | def test_different_mime_types():
370 |     """Test support for different file types and MIME type detection"""
371 |     # Test with various file types
372 |     file_urls = {
373 |         "pdf": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf",
374 |         "image": "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg",
375 |     }
376 | 
377 |     expected_mime_types = {
378 |         "pdf": "application/pdf",
379 |         "image": "image/jpeg",
380 |     }
381 | 
382 |     for file_type, url in file_urls.items():
383 |         # Download and encode
384 |         encoded = encode_image(url, download_images=True)
385 | 
386 |         # Check for correct MIME type in the encoded data - using 'in' instead of startswith
387 |         # to account for possible parameters in the MIME type
388 |         assert f"data:{expected_mime_types[file_type]}" in encoded
389 |         assert ";base64," in encoded
390 | 
391 | 
392 | def test_mime_type_from_response_headers():
393 |     """Test that MIME types from response headers are correctly used"""
394 |     # This URL returns proper Content-Type header
395 |     pdf_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
396 | 
397 |     # Make an actual request to get the content type from headers
398 |     response = requests.get(pdf_url)
399 |     expected_mime_type = response.headers.get("Content-Type", "")
400 | 
401 |     # Should be application/pdf or similar
402 |     assert "pdf" in expected_mime_type.lower()
403 | 
404 |     # Encode with download to test MIME type from headers
405 |     encoded = encode_image(pdf_url, download_images=True)
406 | 
407 |     # The encoded data should contain the correct MIME type
408 |     assert "application/pdf" in encoded
409 |     assert ";base64," in encoded
410 | 
411 | 
412 | def test_pdf_from_file():
413 |     """Test handling a PDF file from disk"""
414 |     # Download a PDF to a temporary file
415 |     pdf_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
416 |     response = requests.get(pdf_url)
417 |     response.raise_for_status()
418 | 
419 |     with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp_file:
420 |         tmp_file.write(response.content)
421 |         tmp_file_path = tmp_file.name
422 | 
423 |     try:
424 |         # Create a dspy.Image from the file
425 |         pdf_image = dspy.Image(tmp_file_path)
426 | 
427 |         # The constructor encodes the file into a data URI we can inspect directly
428 |         assert "data:application/pdf" in pdf_image.url
429 |         assert ";base64," in pdf_image.url
430 | 
431 |         # Test the image in a predictor
432 |         class FilePDFSignature(dspy.Signature):
433 |             document: dspy.Image = dspy.InputField(desc="A PDF document from file")
434 |             summary: str = dspy.OutputField(desc="A summary of the PDF")
435 | 
436 |         predictor, lm = setup_predictor(FilePDFSignature, {"summary": "This is a PDF from file"})
437 |         result = predictor(document=pdf_image)
438 | 
439 |         assert result.summary == "This is a PDF from file"
440 |         assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1
441 |     finally:
442 |         # Clean up the temporary file
443 |         try:
444 |             os.unlink(tmp_file_path)
445 |         except Exception:
446 |             pass
447 | 
448 | 
449 | def test_image_repr():
450 |     """Test string representation of Image objects"""
451 |     url_image = dspy.Image("https://example.com/dog.jpg")
452 |     assert str(url_image) == (
453 |         "<<CUSTOM-TYPE-START-IDENTIFIER>>"
454 |         '[{"type": "image_url", "image_url": {"url": "https://example.com/dog.jpg"}}]'
455 |         "<<CUSTOM-TYPE-END-IDENTIFIER>>"
456 |     )
457 |     assert repr(url_image) == "Image(url='https://example.com/dog.jpg')"
458 | 
459 |     sample_pil = PILImage.new("RGB", (60, 30), color="red")
460 |     pil_image = dspy.Image(sample_pil)
461 |     assert str(pil_image).startswith('<<CUSTOM-TYPE-START-IDENTIFIER>>[{"type": "image_url",')
462 |     assert str(pil_image).endswith("<<CUSTOM-TYPE-END-IDENTIFIER>>")
463 |     assert "base64" in str(pil_image)
464 | 
465 | 
466 | def test_from_methods_warn(tmp_path):
467 |     """Deprecated from_* methods emit warnings"""
468 |     tmp_file = tmp_path / "test.png"
469 |     tmp_file.write_bytes(b"pngdata")
470 | 
471 |     with pytest.warns(DeprecationWarning):
472 |         dspy.Image.from_url("https://example.com/dog.jpg")
473 |     with pytest.warns(DeprecationWarning):
474 |         dspy.Image.from_file(str(tmp_file))
475 |     sample_pil = PILImage.new("RGB", (10, 10), color="blue")
476 |     with pytest.warns(DeprecationWarning):
477 |         dspy.Image.from_PIL(sample_pil)
478 | 
479 | 
480 | def test_invalid_string_format():
481 |     """Test that invalid string formats raise a ValueError"""
482 |     invalid_string = "this_is_not_a_url_or_file"
483 | 
484 |     # Should raise a ValueError and not pass the string through
485 |     with pytest.raises(ValueError, match="Unrecognized"):
486 |         dspy.Image(invalid_string)
487 | 
488 | def test_pil_image_with_download_parameter():
489 |     """Test behavior when PIL image is passed with download=True"""
490 |     sample_pil = PILImage.new("RGB", (60, 30), color="red")
491 | 
492 |     # PIL image should be encoded regardless of download parameter
493 |     image_no_download = dspy.Image(sample_pil)
494 |     image_with_download = dspy.Image(sample_pil, download=True)
495 | 
496 |     # Both should result in base64 encoded data URIs
497 |     assert image_no_download.url.startswith("data:")
498 |     assert image_with_download.url.startswith("data:")
499 |     assert "base64," in image_no_download.url
500 |     assert "base64," in image_with_download.url
501 | 
502 |     # They should be identical since PIL images are always encoded
503 |     assert image_no_download.url == image_with_download.url
504 | 
```

--------------------------------------------------------------------------------
/dspy/clients/lm.py:
--------------------------------------------------------------------------------

```python
  1 | import logging
  2 | import os
  3 | import re
  4 | import threading
  5 | import warnings
  6 | from typing import Any, Literal, cast
  7 | 
  8 | import litellm
  9 | from anyio.streams.memory import MemoryObjectSendStream
 10 | from asyncer import syncify
 11 | 
 12 | import dspy
 13 | from dspy.clients.cache import request_cache
 14 | from dspy.clients.openai import OpenAIProvider
 15 | from dspy.clients.provider import Provider, ReinforceJob, TrainingJob
 16 | from dspy.clients.utils_finetune import TrainDataFormat
 17 | from dspy.dsp.utils.settings import settings
 18 | from dspy.utils.callback import BaseCallback
 19 | 
 20 | from .base_lm import BaseLM
 21 | 
 22 | logger = logging.getLogger(__name__)
 23 | 
 24 | 
 25 | class LM(BaseLM):
 26 |     """
 27 |     A language model supporting chat or text completion requests for use with DSPy modules.
 28 |     """
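    | 
    |     # Illustrative usage sketch (assumes an OpenAI API key is configured):
    |     #
    |     #     lm = dspy.LM("openai/gpt-4o-mini", temperature=0.7, max_tokens=1024)
    |     #     dspy.configure(lm=lm)
    |     #     print(lm("Say hello!"))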
 29 | 
 30 |     def __init__(
 31 |         self,
 32 |         model: str,
 33 |         model_type: Literal["chat", "text", "responses"] = "chat",
 34 |         temperature: float | None = None,
 35 |         max_tokens: int | None = None,
 36 |         cache: bool = True,
 37 |         callbacks: list[BaseCallback] | None = None,
 38 |         num_retries: int = 3,
 39 |         provider: Provider | None = None,
 40 |         finetuning_model: str | None = None,
 41 |         launch_kwargs: dict[str, Any] | None = None,
 42 |         train_kwargs: dict[str, Any] | None = None,
 43 |         use_developer_role: bool = False,
 44 |         **kwargs,
 45 |     ):
 46 |         """
 47 |         Create a new language model instance for use with DSPy modules and programs.
 48 | 
 49 |         Args:
 50 |             model: The model to use. This should be a string of the form ``"llm_provider/llm_name"``
 51 |                    supported by LiteLLM. For example, ``"openai/gpt-4o"``.
  52 |             model_type: The type of the model: ``"chat"``, ``"text"``, or ``"responses"``.
 53 |             temperature: The sampling temperature to use when generating responses.
 54 |             max_tokens: The maximum number of tokens to generate per response.
 55 |             cache: Whether to cache the model responses for reuse to improve performance
 56 |                    and reduce costs.
 57 |             callbacks: A list of callback functions to run before and after each request.
 58 |             num_retries: The number of times to retry a request if it fails transiently due to
 59 |                          network error, rate limiting, etc. Requests are retried with exponential
 60 |                          backoff.
 61 |             provider: The provider to use. If not specified, the provider will be inferred from the model.
  62 |             finetuning_model: The model to finetune. In some providers, the models available for finetuning are different
 63 |                 from the models available for inference.
 64 |             rollout_id: Optional integer used to differentiate cache entries for otherwise
 65 |                 identical requests. Different values bypass DSPy's caches while still caching
 66 |                 future calls with the same inputs and rollout ID. Note that `rollout_id`
 67 |                 only affects generation when `temperature` is non-zero. This argument is
 68 |                 stripped before sending requests to the provider.
 69 |         """
 70 |         # Remember to update LM.copy() if you modify the constructor!
 71 |         self.model = model
 72 |         self.model_type = model_type
 73 |         self.cache = cache
 74 |         self.provider = provider or self.infer_provider()
 75 |         self.callbacks = callbacks or []
 76 |         self.history = []
 77 |         self.num_retries = num_retries
 78 |         self.finetuning_model = finetuning_model
 79 |         self.launch_kwargs = launch_kwargs or {}
 80 |         self.train_kwargs = train_kwargs or {}
 81 |         self.use_developer_role = use_developer_role
 82 |         self._warned_zero_temp_rollout = False
 83 | 
 84 |         # Handle model-specific configuration for different model families
 85 |         model_family = model.split("/")[-1].lower() if "/" in model else model.lower()
 86 | 
 87 |         # Recognize OpenAI reasoning models (o1, o3, o4, gpt-5 family)
 88 |         model_pattern = re.match(r"^(?:o[1345]|gpt-5)(?:-(?:mini|nano))?", model_family)
 89 | 
 90 |         if model_pattern:
 91 | 
 92 |             if (temperature and temperature != 1.0) or (max_tokens and max_tokens < 16000):
 93 |                 raise ValueError(
 94 |                     "OpenAI's reasoning models require passing temperature=1.0 or None and max_tokens >= 16000 or None to "
 95 |                     "`dspy.LM(...)`, e.g., dspy.LM('openai/gpt-5', temperature=1.0, max_tokens=16000)"
 96 |                 )
 97 |             self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs)
 98 |             if self.kwargs.get("rollout_id") is None:
 99 |                 self.kwargs.pop("rollout_id", None)
100 |         else:
101 |             self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
102 |             if self.kwargs.get("rollout_id") is None:
103 |                 self.kwargs.pop("rollout_id", None)
104 | 
105 |         self._warn_zero_temp_rollout(self.kwargs.get("temperature"), self.kwargs.get("rollout_id"))
106 | 
107 |     def _warn_zero_temp_rollout(self, temperature: float | None, rollout_id):
108 |         if not self._warned_zero_temp_rollout and rollout_id is not None and (temperature is None or temperature == 0):
109 |             warnings.warn(
110 |                 "rollout_id has no effect when temperature=0; set temperature>0 to bypass the cache.",
111 |                 stacklevel=3,
112 |             )
113 |             self._warned_zero_temp_rollout = True
114 | 
115 |     def _get_cached_completion_fn(self, completion_fn, cache):
116 |         ignored_args_for_cache_key = ["api_key", "api_base", "base_url"]
117 |         if cache:
118 |             completion_fn = request_cache(
119 |                 cache_arg_name="request",
120 |                 ignored_args_for_cache_key=ignored_args_for_cache_key,
121 |             )(completion_fn)
122 | 
123 |         litellm_cache_args = {"no-cache": True, "no-store": True}
124 | 
125 |         return completion_fn, litellm_cache_args
126 | 
127 |     def forward(self, prompt=None, messages=None, **kwargs):
128 |         # Build the request.
129 |         kwargs = dict(kwargs)
130 |         cache = kwargs.pop("cache", self.cache)
131 | 
132 |         messages = messages or [{"role": "user", "content": prompt}]
133 |         if self.use_developer_role and self.model_type == "responses":
134 |             messages = [{**m, "role": "developer"} if m.get("role") == "system" else m for m in messages]
135 |         kwargs = {**self.kwargs, **kwargs}
136 |         self._warn_zero_temp_rollout(kwargs.get("temperature"), kwargs.get("rollout_id"))
137 |         if kwargs.get("rollout_id") is None:
138 |             kwargs.pop("rollout_id", None)
139 | 
140 |         if self.model_type == "chat":
141 |             completion = litellm_completion
142 |         elif self.model_type == "text":
143 |             completion = litellm_text_completion
144 |         elif self.model_type == "responses":
145 |             completion = litellm_responses_completion
146 |         completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache)
147 | 
148 |         results = completion(
149 |             request=dict(model=self.model, messages=messages, **kwargs),
150 |             num_retries=self.num_retries,
151 |             cache=litellm_cache_args,
152 |         )
153 | 
154 |         self._check_truncation(results)
155 | 
156 |         if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
157 |             dspy.settings.usage_tracker.add_usage(self.model, dict(results.usage))
158 |         return results
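    # Illustrative (hypothetical model name): because of the `{**self.kwargs, **kwargs}` merge above,
    # call-time arguments override the constructor defaults, e.g.
    #   lm = dspy.LM("openai/gpt-4o-mini", temperature=0.7)
    #   lm("Summarize DSPy in one sentence.", temperature=0.2, cache=False)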
159 | 
160 |     async def aforward(self, prompt=None, messages=None, **kwargs):
161 |         # Build the request.
162 |         kwargs = dict(kwargs)
163 |         cache = kwargs.pop("cache", self.cache)
164 | 
165 |         messages = messages or [{"role": "user", "content": prompt}]
166 |         if self.use_developer_role and self.model_type == "responses":
167 |             messages = [{**m, "role": "developer"} if m.get("role") == "system" else m for m in messages]
168 |         kwargs = {**self.kwargs, **kwargs}
169 |         self._warn_zero_temp_rollout(kwargs.get("temperature"), kwargs.get("rollout_id"))
170 |         if kwargs.get("rollout_id") is None:
171 |             kwargs.pop("rollout_id", None)
172 | 
173 |         if self.model_type == "chat":
174 |             completion = alitellm_completion
175 |         elif self.model_type == "text":
176 |             completion = alitellm_text_completion
177 |         elif self.model_type == "responses":
178 |             completion = alitellm_responses_completion
179 |         completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache)
180 | 
181 |         results = await completion(
182 |             request=dict(model=self.model, messages=messages, **kwargs),
183 |             num_retries=self.num_retries,
184 |             cache=litellm_cache_args,
185 |         )
186 | 
187 |         self._check_truncation(results)
188 | 
189 |         if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
190 |             dspy.settings.usage_tracker.add_usage(self.model, dict(results.usage))
191 |         return results
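    # aforward() mirrors forward() for async callers; assuming the async entry point on the
    # base class, a call would look like `await lm.acall("...")`.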
192 | 
193 |     def launch(self, launch_kwargs: dict[str, Any] | None = None):
194 |         self.provider.launch(self, launch_kwargs)
195 | 
196 |     def kill(self, launch_kwargs: dict[str, Any] | None = None):
197 |         self.provider.kill(self, launch_kwargs)
198 | 
199 |     def finetune(
200 |         self,
201 |         train_data: list[dict[str, Any]],
202 |         train_data_format: TrainDataFormat | None,
203 |         train_kwargs: dict[str, Any] | None = None,
204 |     ) -> TrainingJob:
207 |         if not self.provider.finetunable:
208 |             raise ValueError(
209 |                 f"Provider {self.provider} does not support fine-tuning, please specify your provider by explicitly "
210 |                 "setting `provider` when creating the `dspy.LM` instance. For example, "
211 |                 "`dspy.LM('openai/gpt-4.1-mini-2025-04-14', provider=dspy.OpenAIProvider())`."
212 |             )
213 | 
214 |         def thread_function_wrapper():
215 |             return self._run_finetune_job(job)
216 | 
217 |         thread = threading.Thread(target=thread_function_wrapper)
218 |         train_kwargs = train_kwargs or self.train_kwargs
219 |         model_to_finetune = self.finetuning_model or self.model
220 |         job = self.provider.TrainingJob(
221 |             thread=thread,
222 |             model=model_to_finetune,
223 |             train_data=train_data,
224 |             train_data_format=train_data_format,
225 |             train_kwargs=train_kwargs,
226 |         )
227 |         thread.start()
228 | 
229 |         return job
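    # The training job runs on a background thread. Since TrainingJob appears to follow the
    # concurrent.futures.Future interface (see job.set_result in _run_finetune_job below),
    # `job.result()` would block until the fine-tuned LM (or the raised error) is available.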
230 | 
231 |     def reinforce(self, train_kwargs) -> ReinforceJob:
232 |         # TODO(GRPO Team): Should we return an initialized job here?
233 |         if not self.provider.reinforceable:
234 |             raise ValueError(
235 |                 f"Provider {self.provider} does not implement the reinforcement learning interface."
236 |             )
239 | 
240 |         job = self.provider.ReinforceJob(lm=self, train_kwargs=train_kwargs)
241 |         job.initialize()
242 |         return job
243 | 
244 |     def _run_finetune_job(self, job: TrainingJob):
245 |         # TODO(enhance): We should listen for keyboard interrupts somewhere.
246 |         # Requires TrainingJob.cancel() to be implemented for each provider.
247 |         try:
248 |             model = self.provider.finetune(
249 |                 job=job,
250 |                 model=job.model,
251 |                 train_data=job.train_data,
252 |                 train_data_format=job.train_data_format,
253 |                 train_kwargs=job.train_kwargs,
254 |             )
255 |             lm = self.copy(model=model)
256 |             job.set_result(lm)
257 |         except Exception as err:
258 |             logger.error(err)
259 |             job.set_result(err)
260 | 
261 |     def infer_provider(self) -> Provider:
262 |         if OpenAIProvider.is_provider_model(self.model):
263 |             return OpenAIProvider()
264 |         return Provider()
265 | 
266 |     def dump_state(self):
267 |         state_keys = [
268 |             "model",
269 |             "model_type",
270 |             "cache",
271 |             "num_retries",
272 |             "finetuning_model",
273 |             "launch_kwargs",
274 |             "train_kwargs",
275 |         ]
276 |         return {key: getattr(self, key) for key in state_keys} | self.kwargs
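    # Illustrative (hypothetical values): the saved state is a flat dict of the fields above
    # merged with the per-request defaults, e.g.
    #   {"model": "openai/gpt-4o-mini", "model_type": "chat", "cache": True, ...,
    #    "temperature": 0.7, "max_tokens": 1000}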
277 | 
278 |     def _check_truncation(self, results):
279 |         if self.model_type != "responses" and any(c.finish_reason == "length" for c in results["choices"]):
280 |             max_tokens = self.kwargs.get("max_tokens") or self.kwargs.get("max_completion_tokens")
281 |             logger.warning(
282 |                 f"LM response was truncated due to exceeding max_tokens={max_tokens}. "
283 |                 "You can inspect the latest LM interactions with `dspy.inspect_history()`. "
284 |                 "To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
285 |                 f"You may also consider increasing the temperature (currently {self.kwargs.get('temperature')}) if the reason for truncation is repetition."
286 |             )
287 | 
288 | 
289 | def _get_stream_completion_fn(
290 |     request: dict[str, Any],
291 |     cache_kwargs: dict[str, Any],
292 |     sync=True,
293 | ):
294 |     stream = dspy.settings.send_stream
295 |     caller_predict = dspy.settings.caller_predict
296 | 
297 |     if stream is None:
298 |         return None
299 | 
300 |     # The stream is already opened, and will be closed by the caller.
301 |     stream = cast(MemoryObjectSendStream, stream)
302 |     caller_predict_id = id(caller_predict) if caller_predict else None
303 | 
304 |     if dspy.settings.track_usage:
305 |         request["stream_options"] = {"include_usage": True}
306 | 
307 |     async def stream_completion(request: dict[str, Any], cache_kwargs: dict[str, Any]):
308 |         headers = request.pop("headers", None)
309 |         response = await litellm.acompletion(
310 |             cache=cache_kwargs,
311 |             stream=True,
312 |             headers=_get_headers(headers),
313 |             **request,
314 |         )
315 |         chunks = []
316 |         async for chunk in response:
317 |             if caller_predict_id:
318 |                 # Add the predict id to the chunk so that the stream listener can identify which predict produces it.
319 |                 chunk.predict_id = caller_predict_id
320 |             chunks.append(chunk)
321 |             await stream.send(chunk)
322 |         return litellm.stream_chunk_builder(chunks)
323 | 
324 |     def sync_stream_completion():
325 |         syncified_stream_completion = syncify(stream_completion)
326 |         return syncified_stream_completion(request, cache_kwargs)
327 | 
328 |     async def async_stream_completion():
329 |         return await stream_completion(request, cache_kwargs)
330 | 
331 |     if sync:
332 |         return sync_stream_completion
333 |     else:
334 |         return async_stream_completion
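# Streaming only activates when dspy.settings.send_stream is set (e.g., presumably when the program
# is wrapped with dspy.streamify); otherwise the plain litellm completion path below is used, and
# streamed chunks are reassembled into a single response via litellm.stream_chunk_builder.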
335 | 
336 | 
337 | def litellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
338 |     cache = cache or {"no-cache": True, "no-store": True}
339 |     request = dict(request)
340 |     request.pop("rollout_id", None)
341 |     headers = request.pop("headers", None)
342 |     stream_completion = _get_stream_completion_fn(request, cache, sync=True)
343 |     if stream_completion is None:
344 |         return litellm.completion(
345 |             cache=cache,
346 |             num_retries=num_retries,
347 |             retry_strategy="exponential_backoff_retry",
348 |             headers=_get_headers(headers),
349 |             **request,
350 |         )
351 | 
352 |     return stream_completion()
353 | 
354 | 
355 | def litellm_text_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
356 |     cache = cache or {"no-cache": True, "no-store": True}
357 |     request = dict(request)
358 |     request.pop("rollout_id", None)
359 |     headers = request.pop("headers", None)
360 |     # Extract the provider and model from the model string.
361 |     # TODO: Not all the models are in the format of "provider/model"
362 |     model = request.pop("model").split("/", 1)
363 |     provider, model = model[0] if len(model) > 1 else "openai", model[-1]
364 | 
365 |     # Use the API key and base from the request, or from the environment.
366 |     api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
367 |     api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
368 | 
369 |     # Build the prompt from the messages.
370 |     prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])
371 | 
372 |     return litellm.text_completion(
373 |         cache=cache,
374 |         model=f"text-completion-openai/{model}",
375 |         api_key=api_key,
376 |         api_base=api_base,
377 |         prompt=prompt,
378 |         num_retries=num_retries,
379 |         retry_strategy="exponential_backoff_retry",
380 |         headers=_get_headers(headers),
381 |         **request,
382 |     )
383 | 
384 | 
385 | async def alitellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
386 |     cache = cache or {"no-cache": True, "no-store": True}
387 |     request = dict(request)
388 |     request.pop("rollout_id", None)
389 |     headers = request.pop("headers", None)
390 |     stream_completion = _get_stream_completion_fn(request, cache, sync=False)
391 |     if stream_completion is None:
392 |         return await litellm.acompletion(
393 |             cache=cache,
394 |             num_retries=num_retries,
395 |             retry_strategy="exponential_backoff_retry",
396 |             headers=_get_headers(headers),
397 |             **request,
398 |         )
399 | 
400 |     return await stream_completion()
401 | 
402 | 
403 | async def alitellm_text_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
404 |     cache = cache or {"no-cache": True, "no-store": True}
405 |     request = dict(request)
406 |     request.pop("rollout_id", None)
407 |     model = request.pop("model").split("/", 1)
408 |     headers = request.pop("headers", None)
409 |     provider, model = model[0] if len(model) > 1 else "openai", model[-1]
410 | 
411 |     # Use the API key and base from the request, or from the environment.
412 |     api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
413 |     api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
414 | 
415 |     # Build the prompt from the messages.
416 |     prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])
417 | 
418 |     return await litellm.atext_completion(
419 |         cache=cache,
420 |         model=f"text-completion-openai/{model}",
421 |         api_key=api_key,
422 |         api_base=api_base,
423 |         prompt=prompt,
424 |         num_retries=num_retries,
425 |         retry_strategy="exponential_backoff_retry",
426 |         headers=_get_headers(headers),
427 |         **request,
428 |     )
429 | 
430 | 
431 | def litellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
432 |     cache = cache or {"no-cache": True, "no-store": True}
433 |     request = dict(request)
434 |     request.pop("rollout_id", None)
435 |     headers = request.pop("headers", None)
436 |     request = _convert_chat_request_to_responses_request(request)
437 | 
438 |     return litellm.responses(
439 |         cache=cache,
440 |         num_retries=num_retries,
441 |         retry_strategy="exponential_backoff_retry",
442 |         headers=_get_headers(headers),
443 |         **request,
444 |     )
445 | 
446 | 
447 | async def alitellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
448 |     cache = cache or {"no-cache": True, "no-store": True}
449 |     request = dict(request)
450 |     request.pop("rollout_id", None)
451 |     headers = request.pop("headers", None)
452 |     request = _convert_chat_request_to_responses_request(request)
453 | 
454 |     return await litellm.aresponses(
455 |         cache=cache,
456 |         num_retries=num_retries,
457 |         retry_strategy="exponential_backoff_retry",
458 |         headers=_get_headers(headers),
459 |         **request,
460 |     )
461 | 
462 | 
463 | def _convert_chat_request_to_responses_request(request: dict[str, Any]):
464 |     request = dict(request)
465 |     if "messages" in request:
466 |         content_blocks = []
467 |         for msg in request.pop("messages"):
468 |             c = msg.get("content")
469 |             if isinstance(c, str):
470 |                 content_blocks.append({"type": "input_text", "text": c})
471 |             elif isinstance(c, list):
472 |                 content_blocks.extend(c)
473 |         request["input"] = [{"role": msg.get("role", "user"), "content": content_blocks}]  # merged into one item; role taken from the last message
474 | 
475 |     # Convert `response_format` to `text.format` for Responses API
476 |     if "response_format" in request:
477 |         response_format = request.pop("response_format")
478 |         text = request.pop("text", {})
479 |         request["text"] = {**text, "format": response_format}
480 | 
481 |     return request
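# Illustrative (hypothetical request): a chat-style request such as
#   {"messages": [{"role": "system", "content": "Be terse."}, {"role": "user", "content": "Hi"}]}
# becomes a single Responses-API input item
#   {"input": [{"role": "user", "content": [{"type": "input_text", "text": "Be terse."},
#                                           {"type": "input_text", "text": "Hi"}]}]}
# i.e., all message contents are merged and the last message's role is used.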
482 | 
483 | def _get_headers(headers: dict[str, Any] | None = None):
484 |     headers = headers or {}
485 |     return {
486 |         "User-Agent": f"DSPy/{dspy.__version__}",
487 |         **headers,
488 |     }
489 | 
```
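
A minimal usage sketch tying the pieces above together (assuming `OPENAI_API_KEY` is set, that the model name shown is available, and that the async entry point is `acall`):

```python
import asyncio

import dspy

# Constructor kwargs become the per-request defaults stored in lm.kwargs.
lm = dspy.LM("openai/gpt-4o-mini", temperature=0.7, max_tokens=1000)
dspy.configure(lm=lm)

# Synchronous call: routed through LM.forward(); returns a list of completions.
print(lm("Say hello in one short sentence."))

# Call-time kwargs override the defaults; rollout_id only matters when temperature > 0,
# and cache=False forces a fresh request.
print(lm("Say hello again.", temperature=1.0, rollout_id=1, cache=False))

# Asynchronous call: routed through LM.aforward().
async def main():
    return await lm.acall("And once more, asynchronously.")

print(asyncio.run(main()))
```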