This is page 9 of 14. Use http://codebase.md/stanfordnlp/dspy?page={x} to view the full context.

# Directory Structure

```
├── .github
│   ├── .internal_dspyai
│   │   ├── internals
│   │   │   ├── build-and-release.md
│   │   │   └── release-checklist.md
│   │   └── pyproject.toml
│   ├── .tmp
│   │   └── .generated-actions
│   │       └── run-pypi-publish-in-docker-container
│   │           └── action.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yml
│   │   └── feature_request.yml
│   ├── PULL_REQUEST_TEMPLATE
│   │   └── pull_request_template.md
│   ├── workflow_scripts
│   │   └── install_testpypi_pkg.sh
│   └── workflows
│       ├── build_and_release.yml
│       ├── build_utils
│       │   └── test_version.py
│       ├── docs-push.yml
│       ├── precommits_check.yml
│       └── run_tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── docs
│   ├── .gitignore
│   ├── docs
│   │   ├── api
│   │   │   ├── adapters
│   │   │   │   ├── Adapter.md
│   │   │   │   ├── ChatAdapter.md
│   │   │   │   ├── JSONAdapter.md
│   │   │   │   └── TwoStepAdapter.md
│   │   │   ├── evaluation
│   │   │   │   ├── answer_exact_match.md
│   │   │   │   ├── answer_passage_match.md
│   │   │   │   ├── CompleteAndGrounded.md
│   │   │   │   ├── Evaluate.md
│   │   │   │   ├── EvaluationResult.md
│   │   │   │   └── SemanticF1.md
│   │   │   ├── experimental
│   │   │   │   ├── Citations.md
│   │   │   │   └── Document.md
│   │   │   ├── index.md
│   │   │   ├── models
│   │   │   │   ├── Embedder.md
│   │   │   │   └── LM.md
│   │   │   ├── modules
│   │   │   │   ├── BestOfN.md
│   │   │   │   ├── ChainOfThought.md
│   │   │   │   ├── CodeAct.md
│   │   │   │   ├── Module.md
│   │   │   │   ├── MultiChainComparison.md
│   │   │   │   ├── Parallel.md
│   │   │   │   ├── Predict.md
│   │   │   │   ├── ProgramOfThought.md
│   │   │   │   ├── ReAct.md
│   │   │   │   └── Refine.md
│   │   │   ├── optimizers
│   │   │   │   ├── BetterTogether.md
│   │   │   │   ├── BootstrapFewShot.md
│   │   │   │   ├── BootstrapFewShotWithRandomSearch.md
│   │   │   │   ├── BootstrapFinetune.md
│   │   │   │   ├── BootstrapRS.md
│   │   │   │   ├── COPRO.md
│   │   │   │   ├── Ensemble.md
│   │   │   │   ├── GEPA
│   │   │   │   │   ├── GEPA_Advanced.md
│   │   │   │   │   └── overview.md
│   │   │   │   ├── InferRules.md
│   │   │   │   ├── KNN.md
│   │   │   │   ├── KNNFewShot.md
│   │   │   │   ├── LabeledFewShot.md
│   │   │   │   ├── MIPROv2.md
│   │   │   │   └── SIMBA.md
│   │   │   ├── primitives
│   │   │   │   ├── Audio.md
│   │   │   │   ├── Code.md
│   │   │   │   ├── Example.md
│   │   │   │   ├── History.md
│   │   │   │   ├── Image.md
│   │   │   │   ├── Prediction.md
│   │   │   │   ├── Tool.md
│   │   │   │   └── ToolCalls.md
│   │   │   ├── signatures
│   │   │   │   ├── InputField.md
│   │   │   │   ├── OutputField.md
│   │   │   │   └── Signature.md
│   │   │   ├── tools
│   │   │   │   ├── ColBERTv2.md
│   │   │   │   ├── Embeddings.md
│   │   │   │   └── PythonInterpreter.md
│   │   │   └── utils
│   │   │       ├── asyncify.md
│   │   │       ├── configure_cache.md
│   │   │       ├── disable_litellm_logging.md
│   │   │       ├── disable_logging.md
│   │   │       ├── enable_litellm_logging.md
│   │   │       ├── enable_logging.md
│   │   │       ├── inspect_history.md
│   │   │       ├── load.md
│   │   │       ├── StatusMessage.md
│   │   │       ├── StatusMessageProvider.md
│   │   │       ├── streamify.md
│   │   │       └── StreamListener.md
│   │   ├── cheatsheet.md
│   │   ├── community
│   │   │   ├── community-resources.md
│   │   │   ├── how-to-contribute.md
│   │   │   └── use-cases.md
│   │   ├── deep-dive
│   │   │   └── data-handling
│   │   │       ├── built-in-datasets.md
│   │   │       ├── examples.md
│   │   │       ├── img
│   │   │       │   └── data-loading.png
│   │   │       └── loading-custom-data.md
│   │   ├── faqs.md
│   │   ├── index.md
│   │   ├── js
│   │   │   └── runllm-widget.js
│   │   ├── learn
│   │   │   ├── evaluation
│   │   │   │   ├── data.md
│   │   │   │   ├── metrics.md
│   │   │   │   └── overview.md
│   │   │   ├── figures
│   │   │   │   ├── native_tool_call.png
│   │   │   │   └── teleprompter-classes.png
│   │   │   ├── index.md
│   │   │   ├── optimization
│   │   │   │   ├── optimizers.md
│   │   │   │   └── overview.md
│   │   │   └── programming
│   │   │       ├── 7-assertions.md
│   │   │       ├── adapters.md
│   │   │       ├── language_models.md
│   │   │       ├── mcp.md
│   │   │       ├── modules.md
│   │   │       ├── overview.md
│   │   │       ├── signatures.md
│   │   │       └── tools.md
│   │   ├── production
│   │   │   └── index.md
│   │   ├── roadmap.md
│   │   ├── static
│   │   │   ├── .nojekyll
│   │   │   └── img
│   │   │       ├── dspy_logo.png
│   │   │       ├── logo.png
│   │   │       ├── mlflow-tracing-rag.png
│   │   │       ├── modular.png
│   │   │       ├── optimize.png
│   │   │       ├── undraw_docusaurus_mountain.svg
│   │   │       ├── undraw_docusaurus_react.svg
│   │   │       ├── undraw_docusaurus_tree.svg
│   │   │       └── universal_compatibility.png
│   │   ├── stylesheets
│   │   │   └── extra.css
│   │   └── tutorials
│   │       ├── agents
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-agent.png
│   │       ├── ai_text_game
│   │       │   └── index.md
│   │       ├── async
│   │       │   └── index.md
│   │       ├── audio
│   │       │   └── index.ipynb
│   │       ├── build_ai_program
│   │       │   └── index.md
│   │       ├── cache
│   │       │   └── index.md
│   │       ├── classification
│   │       │   └── index.md
│   │       ├── classification_finetuning
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-classification.png
│   │       ├── conversation_history
│   │       │   └── index.md
│   │       ├── core_development
│   │       │   └── index.md
│   │       ├── custom_module
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-custom-module.png
│   │       ├── customer_service_agent
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-customer-service-agent.png
│   │       ├── deployment
│   │       │   ├── dspy_mlflow_ui.png
│   │       │   └── index.md
│   │       ├── email_extraction
│   │       │   ├── index.md
│   │       │   └── mlflow-tracing-email-extraction.png
│   │       ├── entity_extraction
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-entity-extraction.png
│   │       ├── games
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-agent.png
│   │       ├── gepa_ai_program
│   │       │   └── index.md
│   │       ├── gepa_aime
│   │       │   ├── index.ipynb
│   │       │   ├── mlflow-tracing-gepa-aime.png
│   │       │   └── mlflow-tracking-gepa-aime-optimization.png
│   │       ├── gepa_facilitysupportanalyzer
│   │       │   ├── index.ipynb
│   │       │   ├── mlflow-tracing-gepa-support.png
│   │       │   └── mlflow-tracking-gepa-support-optimization.png
│   │       ├── gepa_papillon
│   │       │   ├── index.ipynb
│   │       │   ├── mlflow-tracing-gepa-papilon.png
│   │       │   └── mlflow-tracking-gepa-papilon-optimization.png
│   │       ├── image_generation_prompting
│   │       │   └── index.ipynb
│   │       ├── index.md
│   │       ├── llms_txt_generation
│   │       │   └── index.md
│   │       ├── math
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-math.png
│   │       ├── mcp
│   │       │   └── index.md
│   │       ├── mem0_react_agent
│   │       │   └── index.md
│   │       ├── multihop_search
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-multi-hop.png
│   │       ├── observability
│   │       │   ├── index.md
│   │       │   ├── mlflow_trace_ui_navigation.gif
│   │       │   ├── mlflow_trace_ui.png
│   │       │   └── mlflow_trace_view.png
│   │       ├── optimize_ai_program
│   │       │   └── index.md
│   │       ├── optimizer_tracking
│   │       │   ├── child_run.png
│   │       │   ├── experiment.png
│   │       │   ├── index.md
│   │       │   └── parent_run.png
│   │       ├── output_refinement
│   │       │   └── best-of-n-and-refine.md
│   │       ├── papillon
│   │       │   └── index.md
│   │       ├── program_of_thought
│   │       │   └── index.ipynb
│   │       ├── rag
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-rag.png
│   │       ├── real_world_examples
│   │       │   └── index.md
│   │       ├── rl_ai_program
│   │       │   └── index.md
│   │       ├── rl_multihop
│   │       │   └── index.ipynb
│   │       ├── rl_papillon
│   │       │   └── index.ipynb
│   │       ├── sample_code_generation
│   │       │   └── index.md
│   │       ├── saving
│   │       │   └── index.md
│   │       ├── streaming
│   │       │   └── index.md
│   │       ├── tool_use
│   │       │   └── index.ipynb
│   │       └── yahoo_finance_react
│   │           └── index.md
│   ├── mkdocs.yml
│   ├── overrides
│   │   ├── home.html
│   │   ├── main.html
│   │   └── partials
│   │       └── tabs.html
│   ├── Pipfile
│   ├── Pipfile.lock
│   ├── README.md
│   ├── requirements.txt
│   ├── scripts
│   │   ├── generate_api_docs.py
│   │   └── generate_api_summary.py
│   └── vercel.json
├── dspy
│   ├── __init__.py
│   ├── __metadata__.py
│   ├── adapters
│   │   ├── __init__.py
│   │   ├── baml_adapter.py
│   │   ├── base.py
│   │   ├── chat_adapter.py
│   │   ├── json_adapter.py
│   │   ├── two_step_adapter.py
│   │   ├── types
│   │   │   ├── __init__.py
│   │   │   ├── audio.py
│   │   │   ├── base_type.py
│   │   │   ├── citation.py
│   │   │   ├── code.py
│   │   │   ├── document.py
│   │   │   ├── history.py
│   │   │   ├── image.py
│   │   │   └── tool.py
│   │   ├── utils.py
│   │   └── xml_adapter.py
│   ├── clients
│   │   ├── __init__.py
│   │   ├── base_lm.py
│   │   ├── cache.py
│   │   ├── databricks.py
│   │   ├── embedding.py
│   │   ├── lm_local_arbor.py
│   │   ├── lm_local.py
│   │   ├── lm.py
│   │   ├── openai.py
│   │   ├── provider.py
│   │   └── utils_finetune.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── alfworld
│   │   │   ├── __init__.py
│   │   │   ├── alfworld.py
│   │   │   └── base_config.yml
│   │   ├── colors.py
│   │   ├── dataloader.py
│   │   ├── dataset.py
│   │   ├── gsm8k.py
│   │   ├── hotpotqa.py
│   │   └── math.py
│   ├── dsp
│   │   ├── __init__.py
│   │   ├── colbertv2.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── dpr.py
│   │       ├── settings.py
│   │       └── utils.py
│   ├── evaluate
│   │   ├── __init__.py
│   │   ├── auto_evaluation.py
│   │   ├── evaluate.py
│   │   └── metrics.py
│   ├── experimental
│   │   └── __init__.py
│   ├── predict
│   │   ├── __init__.py
│   │   ├── aggregation.py
│   │   ├── avatar
│   │   │   ├── __init__.py
│   │   │   ├── avatar.py
│   │   │   ├── models.py
│   │   │   └── signatures.py
│   │   ├── best_of_n.py
│   │   ├── chain_of_thought.py
│   │   ├── code_act.py
│   │   ├── knn.py
│   │   ├── multi_chain_comparison.py
│   │   ├── parallel.py
│   │   ├── parameter.py
│   │   ├── predict.py
│   │   ├── program_of_thought.py
│   │   ├── react.py
│   │   ├── refine.py
│   │   └── retry.py
│   ├── primitives
│   │   ├── __init__.py
│   │   ├── base_module.py
│   │   ├── example.py
│   │   ├── module.py
│   │   ├── prediction.py
│   │   ├── python_interpreter.py
│   │   └── runner.js
│   ├── propose
│   │   ├── __init__.py
│   │   ├── dataset_summary_generator.py
│   │   ├── grounded_proposer.py
│   │   ├── propose_base.py
│   │   └── utils.py
│   ├── retrievers
│   │   ├── __init__.py
│   │   ├── databricks_rm.py
│   │   ├── embeddings.py
│   │   ├── retrieve.py
│   │   └── weaviate_rm.py
│   ├── signatures
│   │   ├── __init__.py
│   │   ├── field.py
│   │   ├── signature.py
│   │   └── utils.py
│   ├── streaming
│   │   ├── __init__.py
│   │   ├── messages.py
│   │   ├── streamify.py
│   │   └── streaming_listener.py
│   ├── teleprompt
│   │   ├── __init__.py
│   │   ├── avatar_optimizer.py
│   │   ├── bettertogether.py
│   │   ├── bootstrap_finetune.py
│   │   ├── bootstrap_trace.py
│   │   ├── bootstrap.py
│   │   ├── copro_optimizer.py
│   │   ├── ensemble.py
│   │   ├── gepa
│   │   │   ├── __init__.py
│   │   │   ├── gepa_utils.py
│   │   │   ├── gepa.py
│   │   │   └── instruction_proposal.py
│   │   ├── grpo.py
│   │   ├── infer_rules.py
│   │   ├── knn_fewshot.py
│   │   ├── mipro_optimizer_v2.py
│   │   ├── random_search.py
│   │   ├── signature_opt.py
│   │   ├── simba_utils.py
│   │   ├── simba.py
│   │   ├── teleprompt_optuna.py
│   │   ├── teleprompt.py
│   │   ├── utils.py
│   │   └── vanilla.py
│   └── utils
│       ├── __init__.py
│       ├── annotation.py
│       ├── asyncify.py
│       ├── caching.py
│       ├── callback.py
│       ├── dummies.py
│       ├── exceptions.py
│       ├── hasher.py
│       ├── inspect_history.py
│       ├── langchain_tool.py
│       ├── logging_utils.py
│       ├── mcp.py
│       ├── parallelizer.py
│       ├── saving.py
│       ├── syncify.py
│       ├── unbatchify.py
│       └── usage_tracker.py
├── LICENSE
├── pyproject.toml
├── README.md
├── tests
│   ├── __init__.py
│   ├── adapters
│   │   ├── test_adapter_utils.py
│   │   ├── test_baml_adapter.py
│   │   ├── test_base_type.py
│   │   ├── test_chat_adapter.py
│   │   ├── test_citation.py
│   │   ├── test_code.py
│   │   ├── test_document.py
│   │   ├── test_json_adapter.py
│   │   ├── test_tool.py
│   │   ├── test_two_step_adapter.py
│   │   └── test_xml_adapter.py
│   ├── callback
│   │   └── test_callback.py
│   ├── clients
│   │   ├── test_cache.py
│   │   ├── test_databricks.py
│   │   ├── test_embedding.py
│   │   ├── test_inspect_global_history.py
│   │   └── test_lm.py
│   ├── conftest.py
│   ├── datasets
│   │   └── test_dataset.py
│   ├── docs
│   │   └── test_mkdocs_links.py
│   ├── evaluate
│   │   ├── test_evaluate.py
│   │   └── test_metrics.py
│   ├── examples
│   │   └── test_baleen.py
│   ├── metadata
│   │   └── test_metadata.py
│   ├── predict
│   │   ├── test_aggregation.py
│   │   ├── test_best_of_n.py
│   │   ├── test_chain_of_thought.py
│   │   ├── test_code_act.py
│   │   ├── test_knn.py
│   │   ├── test_multi_chain_comparison.py
│   │   ├── test_parallel.py
│   │   ├── test_predict.py
│   │   ├── test_program_of_thought.py
│   │   ├── test_react.py
│   │   ├── test_refine.py
│   │   └── test_retry.py
│   ├── primitives
│   │   ├── resources
│   │   │   └── saved_program.json
│   │   ├── test_base_module.py
│   │   ├── test_example.py
│   │   ├── test_module.py
│   │   └── test_python_interpreter.py
│   ├── propose
│   │   └── test_grounded_proposer.py
│   ├── README.md
│   ├── reliability
│   │   ├── __init__.py
│   │   ├── complex_types
│   │   │   └── generated
│   │   │       ├── test_many_types_1
│   │   │       │   ├── inputs
│   │   │       │   │   ├── input1.json
│   │   │       │   │   └── input2.json
│   │   │       │   ├── program.py
│   │   │       │   └── schema.json
│   │   │       ├── test_nesting_1
│   │   │       │   ├── inputs
│   │   │       │   │   ├── input1.json
│   │   │       │   │   └── input2.json
│   │   │       │   ├── program.py
│   │   │       │   └── schema.json
│   │   │       └── test_nesting_2
│   │   │           ├── inputs
│   │   │           │   └── input1.json
│   │   │           ├── program.py
│   │   │           └── schema.json
│   │   ├── conftest.py
│   │   ├── generate
│   │   │   ├── __init__.py
│   │   │   ├── __main__.py
│   │   │   └── utils.py
│   │   ├── input_formats
│   │   │   └── generated
│   │   │       └── test_markdown_1
│   │   │           ├── inputs
│   │   │           │   ├── input1.json
│   │   │           │   └── input2.json
│   │   │           ├── program.py
│   │   │           └── schema.json
│   │   ├── README.md
│   │   ├── reliability_conf.yaml
│   │   ├── test_generated.py
│   │   ├── test_pydantic_models.py
│   │   └── utils.py
│   ├── retrievers
│   │   └── test_embeddings.py
│   ├── signatures
│   │   ├── test_adapter_image.py
│   │   ├── test_custom_types.py
│   │   └── test_signature.py
│   ├── streaming
│   │   └── test_streaming.py
│   ├── teleprompt
│   │   ├── gepa_dummy_lm_custom_component_selector_custom_instruction_proposer.json
│   │   ├── gepa_dummy_lm.json
│   │   ├── test_bootstrap_finetune.py
│   │   ├── test_bootstrap_trace.py
│   │   ├── test_bootstrap.py
│   │   ├── test_copro_optimizer.py
│   │   ├── test_ensemble.py
│   │   ├── test_finetune.py
│   │   ├── test_gepa_instruction_proposer.py
│   │   ├── test_gepa.py
│   │   ├── test_grpo.py
│   │   ├── test_knn_fewshot.py
│   │   ├── test_random_search.py
│   │   ├── test_teleprompt.py
│   │   └── test_utils.py
│   ├── test_utils
│   │   ├── __init__.py
│   │   └── server
│   │       ├── __init__.py
│   │       ├── litellm_server_config.yaml
│   │       └── litellm_server.py
│   └── utils
│       ├── __init__.py
│       ├── resources
│       │   └── mcp_server.py
│       ├── test_annotation.py
│       ├── test_asyncify.py
│       ├── test_exceptions.py
│       ├── test_langchain_tool.py
│       ├── test_mcp.py
│       ├── test_parallelizer.py
│       ├── test_saving.py
│       ├── test_settings.py
│       ├── test_syncify.py
│       ├── test_unbatchify.py
│       └── test_usage_tracker.py
└── uv.lock
```

# Files

--------------------------------------------------------------------------------
/dspy/teleprompt/copro_optimizer.py:
--------------------------------------------------------------------------------

```python
import logging
from collections import defaultdict

import dspy
from dspy.evaluate.evaluate import Evaluate
from dspy.signatures import Signature
from dspy.teleprompt.teleprompt import Teleprompter

logger = logging.getLogger(__name__)

"""
USAGE SUGGESTIONS:

The following code can be used to compile an optimized signature teleprompter and evaluate it on an end task:

teleprompter = COPRO(prompt_model=prompt_model, metric=metric, breadth=BREADTH, depth=DEPTH, init_temperature=INIT_TEMPERATURE)
kwargs = dict(num_threads=NUM_THREADS, display_progress=True, display_table=0)
compiled_prompt_opt = teleprompter.compile(program.deepcopy(), trainset=trainset[:DEV_NUM], eval_kwargs=kwargs)
eval_score = evaluate(compiled_prompt_opt, devset=evalset[:EVAL_NUM], **kwargs)

Note that this teleprompter takes in the following parameters:

* prompt_model: The model used for prompt generation. When unspecified, defaults to the model set in settings (i.e., dspy.settings.configure(lm=task_model)).
* metric: The task metric used for optimization.
* breadth: The number of new prompts to generate at each iteration. Default=10.
* depth: The number of times we should ask our prompt model to generate new prompts, with the history of the past prompts as input. Default=3.
* init_temperature: The temperature used to generate new prompts. Higher roughly equals more creative. Default=1.4.
* track_stats: Tells the method whether or not to track statistics about the optimization process.
                If True, the method will track the following statistics:
                    * results_best: The min,max,avg,stddev of top 10 scores for each predictor at each depth.
                    * results_latest: The min,max,avg,stddev of newest prompt scores for each predictor at each depth.
                    * total_calls: The total number of calls to the task metric.
                These statistics will be returned as attributes of the best program.
"""


class BasicGenerateInstruction(Signature):
    """You are an instruction optimizer for large language models. I will give you a ``signature`` of fields (inputs and outputs) in English. Your task is to propose an instruction that will lead a good language model to perform the task well. Don't be afraid to be creative."""

    basic_instruction = dspy.InputField(desc="The initial instructions before optimization")
    proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model")
    proposed_prefix_for_output_field = dspy.OutputField(
        desc="The string at the end of the prompt, which will help the model start solving the task",
    )


class GenerateInstructionGivenAttempts(dspy.Signature):
    """You are an instruction optimizer for large language models. I will give some task instructions I've tried, along with their corresponding validation scores. The instructions are arranged in increasing order based on their scores, where higher scores indicate better quality.

    Your task is to propose a new instruction that will lead a good language model to perform the task even better. Don't be afraid to be creative."""

    attempted_instructions = dspy.InputField()
    proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model")
    proposed_prefix_for_output_field = dspy.OutputField(
        desc="The string at the end of the prompt, which will help the model start solving the task",
    )


class COPRO(Teleprompter):
    def __init__(
        self,
        prompt_model=None,
        metric=None,
        breadth=10,
        depth=3,
        init_temperature=1.4,
        track_stats=False,
        **_kwargs,
    ):
        if breadth <= 1:
            raise ValueError("Breadth must be greater than 1")
        self.metric = metric
        self.breadth = breadth
        self.depth = depth
        self.init_temperature = init_temperature
        self.prompt_model = prompt_model
        self.track_stats = track_stats

    def _check_candidates_equal(self, candidate1, candidate2):
        for p1, p2 in zip(candidate1["program"].predictors(), candidate2["program"].predictors(), strict=False):
            if self._get_signature(p1).instructions != self._get_signature(p2).instructions:
                return False
            *_, p1_last_field = self._get_signature(p1).fields.values()
            *_, p2_last_field = self._get_signature(p2).fields.values()
            if p1_last_field != p2_last_field:
                return False
        return True

    def _drop_duplicates(self, candidates):
        final_candidates = []
        last_batch = []
        last_batch_score = -1
        for c in candidates:
            repeat = False
            if c["score"] == last_batch_score:
                for c2 in last_batch:
                    if self._check_candidates_equal(c, c2):
                        repeat = True
                        break
                if not repeat:
                    last_batch.append(c)
            else:
                last_batch = [c]
                last_batch_score = c["score"]
            if not repeat:
                final_candidates.append(c)
        return final_candidates

    def _print_signature(self, predictor):
        signature = self._get_signature(predictor)

        logger.debug(f"i: {signature.instructions}")
        logger.debug(f"p: {list(signature.fields.values())[-1].json_schema_extra['prefix']}")

    def _get_signature(self, predictor):
        assert hasattr(predictor, "signature")
        return predictor.signature

    def _set_signature(self, predictor, updated_signature):
        assert hasattr(predictor, "signature")
        predictor.signature = updated_signature

    def compile(self, student, *, trainset, eval_kwargs):
        """
        Optimizes the `signature` of the `student` program. Note that the student may be zero-shot or already pre-optimized (i.e., demos already chosen, so `demos != []`).

        Parameters:
        student: the program to optimize; a deep copy is optimized, so the original is left unchanged.
        trainset: iterable of `Example`s used for evaluation during optimization.
        eval_kwargs: dict
           Additional keyword arguments passed to `Evaluate` for the metric.

        Returns an optimized version of `student`.
        """
        module = student.deepcopy()
        evaluate = Evaluate(devset=trainset, metric=self.metric, **eval_kwargs)
        total_calls = 0
        results_best = {
            id(p): {"depth": [], "max": [], "average": [], "min": [], "std": []} for p in module.predictors()
        }
        results_latest = {
            id(p): {"depth": [], "max": [], "average": [], "min": [], "std": []} for p in module.predictors()
        }

        if self.track_stats:
            import numpy as np

        candidates = {}
        evaluated_candidates = defaultdict(dict)

        # Seed the prompt optimizer zero-shot with just the initial instruction, generating `breadth - 1`
        # new prompts (the original instruction is added back below as an additional candidate).
        for predictor in module.predictors():
            basic_instruction = None
            basic_prefix = None
            *_, last_key = self._get_signature(predictor).fields.keys()
            basic_instruction = self._get_signature(predictor).instructions
            basic_prefix = self._get_signature(predictor).fields[last_key].json_schema_extra["prefix"]
            if self.prompt_model:
                with dspy.settings.context(lm=self.prompt_model):
                    instruct = dspy.Predict(
                        BasicGenerateInstruction,
                        n=self.breadth - 1,
                        temperature=self.init_temperature,
                    )(basic_instruction=basic_instruction)
            else:
                instruct = dspy.Predict(
                    BasicGenerateInstruction,
                    n=self.breadth - 1,
                    temperature=self.init_temperature,
                )(basic_instruction=basic_instruction)
            # Add in our initial prompt as a candidate as well
            instruct.completions.proposed_instruction.append(basic_instruction)
            instruct.completions.proposed_prefix_for_output_field.append(basic_prefix)
            candidates[id(predictor)] = instruct.completions
            evaluated_candidates[id(predictor)] = {}

        if self.prompt_model:
            logger.debug(f"{self.prompt_model.inspect_history(n=1)}")

        latest_candidates = candidates
        all_candidates = candidates

        module_clone = module.deepcopy()

        # For each iteration in depth...
        # TODO: fix this so that we eval the new batch of predictors with the new best following predictors
        for d in range(self.depth):
            logger.info(f"Iteration Depth: {d+1}/{self.depth}.")

            latest_scores = []

            # Go through our module's predictors
            for p_i, (p_old, p_new) in enumerate(zip(module.predictors(), module_clone.predictors(), strict=False)):
                # Use the most recently generated candidates for evaluation
                candidates_ = latest_candidates[id(p_old)]
                if len(module.predictors()) > 1:
                    # If our program has multiple predictors, we need to re-evaluate all prompts with
                    # the new prompt(s) for the other predictor(s).
                    candidates_ = all_candidates[id(p_old)]

                # For each candidate
                for c_i, c in enumerate(candidates_):
                    # Get the candidate instruction and prefix
                    instruction, prefix = (
                        c.proposed_instruction.strip('"').strip(),
                        c.proposed_prefix_for_output_field.strip('"').strip(),
                    )

                    # Set this new module with our instruction / prefix
                    *_, last_key = self._get_signature(p_new).fields.keys()
                    updated_signature = (
                        self._get_signature(p_new)
                        .with_instructions(instruction)
                        .with_updated_fields(last_key, prefix=prefix)
                    )
                    self._set_signature(p_new, updated_signature)

                    # Score the instruction / prefix
                    for i, predictor in enumerate(module_clone.predictors()):
                        logger.debug(f"Predictor {i+1}")
                        self._print_signature(predictor)
                    logger.info(
                        f"At Depth {d+1}/{self.depth}, Evaluating Prompt Candidate #{c_i+1}/{len(candidates_)} for "
                        f"Predictor {p_i+1} of {len(module.predictors())}.",
                    )
                    score = evaluate(module_clone, devset=trainset, **eval_kwargs).score
                    if self.prompt_model:
                        logger.debug(f"prompt_model.inspect_history(n=1) {self.prompt_model.inspect_history(n=1)}")
                    total_calls += 1

                    replace_entry = True
                    logger.debug(f"(instruction, prefix) {(instruction, prefix)}")
                    if (instruction, prefix) in evaluated_candidates[id(p_old)]:
                        if evaluated_candidates[id(p_old)][(instruction, prefix)]["score"] >= score:
                            replace_entry = False

                    if replace_entry:
                        # Add it to our evaluated candidates list
                        evaluated_candidates[id(p_old)][(instruction, prefix)] = {
                            "score": score,
                            "program": module_clone.deepcopy(),
                            "instruction": instruction,
                            "prefix": prefix,
                            "depth": d,
                        }

                    if len(candidates_) - self.breadth <= c_i:
                        latest_scores.append(score)

                if self.track_stats:
                    results_latest[id(p_old)]["depth"].append(d)
                    results_latest[id(p_old)]["max"].append(max(latest_scores))
                    results_latest[id(p_old)]["average"].append(sum(latest_scores) / len(latest_scores))
                    results_latest[id(p_old)]["min"].append(min(latest_scores))
                    results_latest[id(p_old)]["std"].append(np.std(latest_scores))

                # Now that we've evaluated the candidates, set this predictor to the best performing version
                # to ensure the next round of scores reflect the best possible version
                best_candidate = max(evaluated_candidates[id(p_old)].values(), key=lambda candidate: candidate["score"])
                *_, last_key = self._get_signature(p_old).fields.keys()
                updated_signature = (
                    self._get_signature(p_new)
                    .with_instructions(best_candidate["instruction"])
                    .with_updated_fields(last_key, prefix=best_candidate["prefix"])
                )
                self._set_signature(p_new, updated_signature)

                logger.debug(
                    f"Updating Predictor {id(p_old)} to:\ni: {best_candidate['instruction']}\n"
                    f"p: {best_candidate['prefix']}",
                )
                logger.debug("Full predictor with update: ")
                for i, predictor in enumerate(module_clone.predictors()):
                    logger.debug(f"Predictor {i}")
                    self._print_signature(predictor)

            if d == self.depth - 1:
                break

            new_candidates = {}
            for p_base in module.predictors():
                # Build Few-Shot Example of Optimized Prompts
                attempts = []
                shortest_len = self.breadth
                shortest_len = min(len(evaluated_candidates[id(p_base)]), shortest_len)
                best_predictors = list(evaluated_candidates[id(p_base)].values())

                # best_predictors = evaluated_candidates[id(p_base)].values()[:]
                best_predictors.sort(key=lambda x: x["score"], reverse=True)

                if self.track_stats:
                    scores = [x["score"] for x in best_predictors][:10]
                    results_best[id(p_base)]["depth"].append(d)
                    results_best[id(p_base)]["max"].append(max(scores))
                    results_best[id(p_base)]["average"].append(sum(scores) / len(scores))
                    results_best[id(p_base)]["min"].append(min(scores))
                    results_best[id(p_base)]["std"].append(np.std(scores))

                for i in range(shortest_len - 1, -1, -1):
                    # breakpoint()
                    attempts.append(f'Instruction #{shortest_len-i}: {best_predictors[i]["instruction"]}')
                    attempts.append(f'Prefix #{shortest_len-i}: {best_predictors[i]["prefix"]}')
                    attempts.append(f'Resulting Score #{shortest_len-i}: {best_predictors[i]["score"]}')

                # Generate next batch of potential prompts to optimize, with previous attempts as input
                if self.prompt_model:
                    with dspy.settings.context(lm=self.prompt_model):
                        instr = dspy.Predict(
                            GenerateInstructionGivenAttempts,
                            n=self.breadth,
                            temperature=self.init_temperature,
                        )(attempted_instructions=attempts)
                else:
                    instr = dspy.Predict(
                        GenerateInstructionGivenAttempts,
                        n=self.breadth,
                        temperature=self.init_temperature,
                    )(attempted_instructions=attempts)

                # Get candidates for each predictor
                new_candidates[id(p_base)] = instr.completions
                all_candidates[id(p_base)].proposed_instruction.extend(instr.completions.proposed_instruction)
                all_candidates[id(p_base)].proposed_prefix_for_output_field.extend(
                    instr.completions.proposed_prefix_for_output_field,
                )

            latest_candidates = new_candidates

        candidates = []
        for predictor in module.predictors():
            candidates.extend(list(evaluated_candidates[id(predictor)].values()))

            if self.track_stats:
                best_predictors = list(evaluated_candidates[id(predictor)].values())
                best_predictors.sort(key=lambda x: x["score"], reverse=True)

                scores = [x["score"] for x in best_predictors][:10]
                results_best[id(predictor)]["depth"].append(d)
                results_best[id(predictor)]["max"].append(max(scores))
                results_best[id(predictor)]["average"].append(sum(scores) / len(scores))
                results_best[id(predictor)]["min"].append(min(scores))
                results_best[id(predictor)]["std"].append(np.std(scores))

        candidates.sort(key=lambda x: x["score"], reverse=True)

        candidates = self._drop_duplicates(candidates)

        best_program = candidates[0]["program"]
        best_program.candidate_programs = candidates
        best_program.total_calls = total_calls
        if self.track_stats:
            best_program.results_best = results_best
            best_program.results_latest = results_latest

        return best_program

```
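
The module docstring above sketches COPRO's intended usage. Below is a minimal, self-contained version of that pattern; the model name, metric, program, and training example are placeholders chosen for illustration, while the `COPRO(...)` arguments and the `compile(..., trainset=..., eval_kwargs=...)` call follow the signatures defined in this file.

```python
import dspy
from dspy.teleprompt.copro_optimizer import COPRO

# Placeholder LM; any model supported by dspy.LM works here.
lm = dspy.LM("openai/gpt-4o-mini")
dspy.settings.configure(lm=lm)

# Toy metric and program purely for illustration.
def exact_match(example, prediction, trace=None):
    return example.answer == prediction.answer

program = dspy.ChainOfThought("question -> answer")
trainset = [dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question")]

teleprompter = COPRO(metric=exact_match, breadth=4, depth=2, init_temperature=1.4)
eval_kwargs = dict(num_threads=1, display_progress=True, display_table=0)

# compile() deep-copies `program`, proposes new instructions/prefixes per predictor,
# evaluates each candidate on `trainset`, and returns the best-scoring variant.
compiled = teleprompter.compile(program, trainset=trainset, eval_kwargs=eval_kwargs)
```

Note that every candidate is scored by running `Evaluate` over the full `trainset`, so even this small sketch issues multiple LM calls per depth iteration.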

--------------------------------------------------------------------------------
/dspy/propose/grounded_proposer.py:
--------------------------------------------------------------------------------

```python
import random

import dspy
from dspy.propose.dataset_summary_generator import create_dataset_summary
from dspy.propose.propose_base import Proposer
from dspy.propose.utils import (
    create_example_string,
    create_predictor_level_history_string,
    get_dspy_source_code,
    strip_prefix,
)
from dspy.teleprompt.utils import get_prompt_model, get_signature

# Hardcoded variables (TODO: update)
MAX_INSTRUCT_IN_HISTORY = 5  # 10

TIPS = {
        "none": "",
        "creative": "Don't be afraid to be creative when creating the new instruction!",
        "simple": "Keep the instruction clear and concise.",
        "description": "Make sure your instruction is very informative and descriptive.",
        "high_stakes": "The instruction should include a high stakes scenario in which the LM must solve the task!",
        "persona": 'Include a persona that is relevant to the task in the instruction (ie. "You are a ...")',
    }

### SIGNATURES USED TO HELP WITH INSTRUCTION GENERATION ###

class DescribeProgram(dspy.Signature):
    (
        """Below is some pseudo-code for a pipeline that solves tasks with calls to language models. Please describe what type of task this program appears to be designed to solve, and how it appears to work."""
    )
    program_code = dspy.InputField(
        format=str,
        desc="Pseudocode for a language model program designed to solve a particular task.",
        prefix="PROGRAM CODE:",
    )
    program_example = dspy.InputField(
        format=str,
        desc="An example of the program in use.",
        prefix="EXAMPLE OF PROGRAM IN USE:",
    )
    program_description = dspy.OutputField(
        desc="Describe what task the program is designed to solve, and how it goes about solving this task.",
        prefix="SUMMARY OF PROGRAM ABOVE:",
    )


class DescribeModule(dspy.Signature):
    (
        """Below is some pseudo-code for a pipeline that solves tasks with calls to language models. Please describe the purpose of one of the specified module in this pipeline."""
    )
    program_code = dspy.InputField(
        format=str,
        desc="Pseudocode for a language model program designed to solve a particular task.",
        prefix="PROGRAM CODE:",
    )
    program_example = dspy.InputField(
        format=str,
        desc="An example of the program in use.",
        prefix="EXAMPLE OF PROGRAM IN USE:",
    )
    program_description = dspy.InputField(
        desc="Summary of the task the program is designed to solve, and how it goes about solving it.",
        prefix="SUMMARY OF PROGRAM ABOVE:",
    )
    module = dspy.InputField(
        desc="The module in the program that we want to describe.", prefix="MODULE:",
    )
    module_description = dspy.OutputField(
        desc="Description of the module's role in the broader program.",
        prefix="MODULE DESCRIPTION:",
    )


def generate_instruction_class(
    use_dataset_summary=True,
    program_aware=True,
    use_task_demos=True,
    use_instruct_history=True,
    use_tip=True,
):
    class GenerateSingleModuleInstruction(dspy.Signature):
        (
            """Use the information below to learn about a task that we are trying to solve using calls to an LM, then generate a new instruction that will be used to prompt a Language Model to better solve the task."""
        )
        if use_dataset_summary:
            dataset_description = dspy.InputField(
                desc="A description of the dataset that we are using.",
                prefix="DATASET SUMMARY:",
            )
        if program_aware:
            program_code = dspy.InputField(
                format=str,
                desc="Language model program designed to solve a particular task.",
                prefix="PROGRAM CODE:",
            )
            program_description = dspy.InputField(
                desc="Summary of the task the program is designed to solve, and how it goes about solving it.",
                prefix="PROGRAM DESCRIPTION:",
            )
            module = dspy.InputField(
                desc="The module to create an instruction for.", prefix="MODULE:",
            )
            module_description = dspy.InputField(
                desc="Description of the module to create an instruction for.", prefix="MODULE DESCRIPTION:",
            )
        task_demos = dspy.InputField(
            format=str,
            desc="Example inputs/outputs of our module.",
            prefix="TASK DEMO(S):",
        )
        if use_instruct_history:
            previous_instructions = dspy.InputField(
                format=str,
                desc="Previous instructions we've attempted, along with their associated scores.",
                prefix="PREVIOUS INSTRUCTIONS:",
            )
        basic_instruction = dspy.InputField(
            format=str, desc="Basic instruction.", prefix="BASIC INSTRUCTION:",
        )
        if use_tip:
            tip = dspy.InputField(
                format=str,
                desc="A suggestion for how to go about generating the new instruction.",
                prefix="TIP:",
            )
        proposed_instruction = dspy.OutputField(
            desc="Propose an instruction that will be used to prompt a Language Model to perform this task.",
            prefix="PROPOSED INSTRUCTION:",
        )

    return dspy.Predict(GenerateSingleModuleInstruction)

### CLASS RESPONSIBLE FOR GENERATING A NEW INSTRUCTION, USING THE HELPER SIGNATURES ABOVE ###

class GenerateModuleInstruction(dspy.Module):
    def __init__(
        self,
        program_code_string=None,
        use_dataset_summary=True,
        program_aware=False,
        use_task_demos=True,
        use_instruct_history=True,
        use_tip=True,
        verbose=False,
    ):
        super().__init__()
        self.use_dataset_summary = use_dataset_summary
        self.program_aware = program_aware
        self.use_task_demos = use_task_demos
        self.use_instruct_history = use_instruct_history
        self.use_tip = use_tip
        self.verbose = verbose

        self.program_code_string = program_code_string
        self.describe_program = dspy.Predict(DescribeProgram)
        self.describe_module = dspy.Predict(DescribeModule)
        self.generate_module_instruction = generate_instruction_class(
            use_dataset_summary=use_dataset_summary,
            program_aware=program_aware,
            use_task_demos=use_task_demos,
            use_instruct_history=use_instruct_history,
            use_tip=use_tip,
        )

    def forward(
        self,
        demo_candidates,
        pred_i,
        demo_set_i,
        program,
        previous_instructions,
        data_summary,
        num_demos_in_context=3,
        tip=None,
    ):
        def gather_examples_from_sets(candidate_sets, max_examples):
            """Helper function to gather up to augmented examples from given sets."""
            count = 0
            for candidate_set in candidate_sets:
                for example in candidate_set:
                    if "augmented" in example.keys():
                        fields_to_use = get_signature(program.predictors()[pred_i]).fields
                        yield create_example_string(fields_to_use, example)
                        count += 1
                        if count >= max_examples:
                            return

        # Construct full program demo or single module demo depending on settings
        basic_instruction = get_signature(program.predictors()[pred_i]).instructions
        task_demos = ""

        if self.use_task_demos:
            # Combine current and adjacent sets
            adjacent_sets = (
                [demo_candidates[pred_i][demo_set_i]] +
                demo_candidates[pred_i][demo_set_i + 1:] +
                demo_candidates[pred_i][:demo_set_i]
            )

            # Gather examples up to the required count
            example_strings = gather_examples_from_sets(adjacent_sets, num_demos_in_context)
            task_demos = "\n\n".join(example_strings) + "\n\n"

        # Default to no demos provided if no examples were gathered, or if we're using the first demo set
        if not task_demos.strip() or demo_set_i == 0:
            task_demos = "No task demos provided."

        # Summarize the program
        program_description = "Not available"
        module_code = "Not provided"
        module_description = "Not provided"
        if self.program_aware:
            try:
                program_description = strip_prefix(
                    self.describe_program(
                        program_code=self.program_code_string, program_example=task_demos,
                    ).program_description,
                )
                if self.verbose:
                    print(f"PROGRAM DESCRIPTION: {program_description}")

                inputs = []
                outputs = []
                for field_name, field in get_signature(program.predictors()[pred_i]).fields.items():
                    # Access the '__dspy_field_type' from the extra metadata
                    dspy_field_type = field.json_schema_extra.get("__dspy_field_type")

                    # Based on the '__dspy_field_type', append to the respective list
                    if dspy_field_type == "input":
                        inputs.append(field_name)
                    else:
                        outputs.append(field_name)

                module_code = f"{program.predictors()[pred_i].__class__.__name__}({', '.join(inputs)}) -> {', '.join(outputs)}"

                module_description = self.describe_module(
                    program_code=self.program_code_string,
                    program_description=program_description,
                    program_example=task_demos,
                    module=module_code,
                    max_depth=10,
                ).module_description
            except Exception as e:
                if self.verbose:
                    print(f"Error getting program description. Running without program aware proposer. Error: {e}")
                self.program_aware = False

        # Generate an instruction for our chosen module
        if self.verbose:
            print(f"task_demos {task_demos}")

        instruct = self.generate_module_instruction(
            dataset_description=data_summary,
            program_code=self.program_code_string,
            module=module_code,
            program_description=program_description,
            module_description=module_description,
            task_demos=task_demos,
            tip=tip,
            basic_instruction=basic_instruction,
            previous_instructions=previous_instructions,
        )

        proposed_instruction = strip_prefix(instruct.proposed_instruction)

        return dspy.Prediction(proposed_instruction=proposed_instruction)

### CLASS USED TO GENERATE THE FULL SET OF INSTRUCTIONS GIVEN THE SPECIFIED CRITERIA ###

class GroundedProposer(Proposer):
    def __init__(
        self,
        prompt_model,
        program,
        trainset,
        view_data_batch_size=10,
        use_dataset_summary=True,
        program_aware=True,
        use_task_demos=True,
        num_demos_in_context = 3,
        use_instruct_history=True,
        use_tip=True,
        set_tip_randomly=True,
        set_history_randomly=True,
        verbose=False,
        rng=None,
        init_temperature: float = 1.0,
    ):
        super().__init__()
        self.program_aware = program_aware
        self.use_dataset_summary = use_dataset_summary
        self.use_task_demos = use_task_demos
        self.num_demos_in_context = num_demos_in_context
        self.use_instruct_history = use_instruct_history
        self.use_tip = use_tip
        self.set_tip_randomly = set_tip_randomly
        self.set_history_randomly = set_history_randomly
        self.verbose = verbose
        self.rng = rng or random

        self.prompt_model = get_prompt_model(prompt_model)
        self.init_temperature = init_temperature

        self.program_code_string = None
        if self.program_aware:
            try:
                self.program_code_string = get_dspy_source_code(program)
                if self.verbose:
                    print("SOURCE CODE:",self.program_code_string)
            except Exception as e:
                print(f"Error getting source code: {e}.\n\nRunning without program aware proposer.")
                self.program_aware = False

        self.data_summary  = None
        if self.use_dataset_summary:
            try:
                self.data_summary = create_dataset_summary(
                    trainset=trainset, view_data_batch_size=view_data_batch_size, prompt_model=prompt_model,
                )
                if self.verbose:
                    print(f"DATA SUMMARY: {self.data_summary}")
            except Exception as e:
                print(f"Error getting data summary: {e}.\n\nRunning without data aware proposer.")
                self.use_dataset_summary = False
                print("")

    def propose_instructions_for_program(
        self,
        trainset,
        program,
        demo_candidates,
        trial_logs,
        N, # noqa: N803
    ) -> list[str]:
        """This method is responsible for returning the full set of new instructions for our program, given the specified criteria."""

        proposed_instructions = {}

        if self.set_history_randomly:
            # Randomly select whether or not we're using instruction history
            use_history = self.rng.random() < 0.5
            self.use_instruct_history = use_history
            if self.verbose:
                print(f"Use history T/F: {self.use_instruct_history}")

        if not demo_candidates:
            if self.verbose:
                print("No demo candidates provided. Running without task demos.")
            self.use_task_demos = False
            # When no demo candidates are provided, default to N
            num_demos = N
        else:
            num_demos = max(len(demo_candidates[0]), 1)

        # Create an instruction for each predictor
        for pred_i, predictor in enumerate(program.predictors()):
            for demo_set_i in range(num_demos)[:min(N, num_demos)]:
                if pred_i not in proposed_instructions:
                    proposed_instructions[pred_i] = []
                selected_tip = None
                if self.set_tip_randomly:
                    if self.verbose:
                        print("Using a randomly generated configuration for our grounded proposer.")
                    # Randomly select the tip
                    selected_tip_key = self.rng.choice(list(TIPS.keys()))
                    selected_tip = TIPS[selected_tip_key]
                    self.use_tip = bool(
                        selected_tip,
                    )
                    if self.verbose:
                        print(f"Selected tip: {selected_tip_key}")

                proposed_instructions[pred_i].append(
                    self.propose_instruction_for_predictor(
                        program=program,
                        predictor=predictor,
                        pred_i=pred_i,
                        demo_candidates=demo_candidates,
                        demo_set_i=demo_set_i,
                        trial_logs=trial_logs,
                        tip=selected_tip,
                    ),
                )

        return proposed_instructions

    def propose_instruction_for_predictor(
        self,
        program,
        predictor,
        pred_i,
        demo_candidates,
        demo_set_i,
        trial_logs,
        tip=None,
    ) -> str:
        """This method is responsible for returning a single instruction for a given predictor, using the specified criteria."""

        # Create an instruction history string for our predictor
        instruction_history = create_predictor_level_history_string(
            program, pred_i, trial_logs, MAX_INSTRUCT_IN_HISTORY,
        )

        # Create our instruction generator class (given specific criteria for this round of proposal)
        instruction_generator = GenerateModuleInstruction(
            program_code_string=self.program_code_string,
            use_dataset_summary=self.use_dataset_summary,
            program_aware=self.program_aware,
            use_task_demos=self.use_task_demos and demo_candidates,
            use_instruct_history=self.use_instruct_history and instruction_history,
            use_tip=self.use_tip,
            verbose=self.verbose
        )

        # Generate a new instruction for our predictor using a unique rollout id to bypass cache
        rollout_lm = self.prompt_model.copy(
            rollout_id=self.rng.randint(0, 10**9),
            temperature=self.init_temperature,
        )

        with dspy.settings.context(lm=rollout_lm):
            proposed_instruction = instruction_generator(
                demo_candidates=demo_candidates,
                pred_i=pred_i,
                demo_set_i=demo_set_i,
                program=program,
                data_summary=self.data_summary,
                previous_instructions=instruction_history,
                num_demos_in_context = self.num_demos_in_context,
                tip=tip,
            ).proposed_instruction

        # Log the trace used to generate the new instruction, along with the new instruction itself
        if self.verbose:
            self.prompt_model.inspect_history(n=1)
            print(f"PROPOSED INSTRUCTION: {proposed_instruction}")

        return strip_prefix(proposed_instruction)

```
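
`GroundedProposer` is normally constructed internally by instruction-proposing optimizers (e.g. MIPROv2 in `dspy/teleprompt/mipro_optimizer_v2.py`) rather than called directly. The sketch below shows how its public surface fits together, based only on the signatures above; the LM, program, and training example are placeholders, and `demo_candidates`/`trial_logs` are passed as empty stand-ins, which the proposer treats as having no task demos (and, presumably, no usable instruction history).

```python
import dspy
from dspy.propose.grounded_proposer import GroundedProposer

# Placeholder prompt model and program.
prompt_lm = dspy.LM("openai/gpt-4o-mini")
dspy.settings.configure(lm=prompt_lm)

program = dspy.ChainOfThought("question -> answer")
trainset = [dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question")]

proposer = GroundedProposer(
    prompt_model=prompt_lm,
    program=program,
    trainset=trainset,
    program_aware=True,        # summarize the program's source via DescribeProgram/DescribeModule
    use_dataset_summary=True,  # summarize `trainset` at construction time (issues LM calls)
    verbose=True,
)

# Propose N candidate instructions per predictor. demo_candidates/trial_logs are
# normally supplied by the optimizer; empty stand-ins are used here.
instructions = proposer.propose_instructions_for_program(
    trainset=trainset,
    program=program,
    demo_candidates=None,
    trial_logs={},
    N=3,
)
# -> dict mapping predictor index to a list of proposed instruction strings
```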

--------------------------------------------------------------------------------
/tests/adapters/test_tool.py:
--------------------------------------------------------------------------------

```python
import asyncio
from typing import Any

import pytest
from pydantic import BaseModel

import dspy
from dspy.adapters.types.tool import Tool, ToolCalls, convert_input_schema_to_tool_args


# Test fixtures
def dummy_function(x: int, y: str = "hello") -> str:
    """A dummy function for testing.

    Args:
        x: An integer parameter
        y: A string parameter
    """
    return f"{y} {x}"


class DummyModel(BaseModel):
    field1: str = "hello"
    field2: int


def dummy_with_pydantic(model: DummyModel) -> str:
    """A dummy function that accepts a Pydantic model."""
    return f"{model.field1} {model.field2}"


class Address(BaseModel):
    street: str
    city: str
    zip_code: str
    is_primary: bool = False


class ContactInfo(BaseModel):
    email: str
    phone: str | None = None
    addresses: list[Address]


class UserProfile(BaseModel):
    user_id: int
    name: str
    age: int | None = None
    contact: ContactInfo
    tags: list[str] = []


class Note(BaseModel):
    content: str
    author: str


def complex_dummy_function(profile: UserProfile, priority: int, notes: list[Note] | None = None) -> dict[str, Any]:
    """Process user profile with complex nested structure.

    Args:
        profile: User profile containing nested contact and address information
        priority: Priority level of the processing
        notes: Optional processing notes
    """
    primary_address = next(
        (addr for addr in profile.contact.addresses if addr.is_primary), profile.contact.addresses[0]
    )

    return {
        "user_id": profile.user_id,
        "name": profile.name,
        "priority": priority,
        "primary_address": primary_address.model_dump(),
        "notes": notes,
    }


async def async_dummy_function(x: int, y: str = "hello") -> str:
    """An async dummy function for testing.

    Args:
        x: An integer parameter
        y: A string parameter
    """
    await asyncio.sleep(0.1)  # Simulate some async work
    return f"{y} {x}"


async def async_dummy_with_pydantic(model: DummyModel) -> str:
    """An async dummy function that accepts a Pydantic model."""
    await asyncio.sleep(0.1)  # Simulate some async work
    return f"{model.field1} {model.field2}"


async def async_complex_dummy_function(
    profile: UserProfile,
    priority: int,
    notes: list[Note] | None = None,
) -> dict[str, Any]:
    """Process user profile with complex nested structure asynchronously.

    Args:
        profile: User profile containing nested contact and address information
        priority: Priority level of the processing
        notes: Optional processing notes
    """
    # Simulate some async processing work
    await asyncio.sleep(0.1)

    primary_address = next(
        (addr for addr in profile.contact.addresses if addr.is_primary), profile.contact.addresses[0]
    )

    # Simulate more async work after finding primary address
    await asyncio.sleep(0.1)

    return {
        "user_id": profile.user_id,
        "name": profile.name,
        "priority": priority,
        "primary_address": primary_address.model_dump(),
        "notes": notes,
    }


def test_basic_initialization():
    tool = Tool(name="test_tool", desc="A test tool", args={"param1": {"type": "string"}}, func=lambda x: x)
    assert tool.name == "test_tool"
    assert tool.desc == "A test tool"
    assert tool.args == {"param1": {"type": "string"}}
    assert callable(tool.func)


def test_tool_from_function():
    tool = Tool(dummy_function)

    assert tool.name == "dummy_function"
    assert "A dummy function for testing" in tool.desc
    assert "x" in tool.args
    assert "y" in tool.args
    assert tool.args["x"]["type"] == "integer"
    assert tool.args["y"]["type"] == "string"
    assert tool.args["y"]["default"] == "hello"


def test_tool_from_class():
    class Foo:
        def __init__(self, user_id: str):
            self.user_id = user_id

        def __call__(self, a: int, b: int) -> int:
            """Add two numbers."""
            return a + b

    tool = Tool(Foo("123"))
    assert tool.name == "Foo"
    assert tool.desc == "Add two numbers."
    assert tool.args == {"a": {"type": "integer"}, "b": {"type": "integer"}}


def test_tool_from_function_with_pydantic():
    tool = Tool(dummy_with_pydantic)

    assert tool.name == "dummy_with_pydantic"
    assert "model" in tool.args
    assert tool.args["model"]["type"] == "object"
    assert "field1" in tool.args["model"]["properties"]
    assert "field2" in tool.args["model"]["properties"]
    assert tool.args["model"]["properties"]["field1"]["default"] == "hello"


def test_tool_from_function_with_pydantic_nesting():
    tool = Tool(complex_dummy_function)

    assert tool.name == "complex_dummy_function"

    assert "profile" in tool.args
    assert "priority" in tool.args
    assert "notes" in tool.args
    assert tool.args["profile"]["type"] == "object"
    assert tool.args["profile"]["properties"]["user_id"]["type"] == "integer"
    assert tool.args["profile"]["properties"]["name"]["type"] == "string"
    assert tool.args["profile"]["properties"]["age"]["anyOf"] == [{"type": "integer"}, {"type": "null"}]
    assert tool.args["profile"]["properties"]["contact"]["type"] == "object"
    assert tool.args["profile"]["properties"]["contact"]["properties"]["email"]["type"] == "string"

    # Reference should be resolved for nested pydantic models
    assert "$defs" not in str(tool.args["notes"])
    assert tool.args["notes"]["anyOf"][0]["type"] == "array"
    assert tool.args["notes"]["anyOf"][0]["items"]["type"] == "object"
    assert tool.args["notes"]["anyOf"][0]["items"]["properties"]["content"]["type"] == "string"
    assert tool.args["notes"]["anyOf"][0]["items"]["properties"]["author"]["type"] == "string"


def test_tool_callable():
    tool = Tool(dummy_function)
    result = tool(x=42, y="hello")
    assert result == "hello 42"


def test_tool_with_pydantic_callable():
    tool = Tool(dummy_with_pydantic)
    model = DummyModel(field1="test", field2=123)
    result = tool(model=model)
    assert result == "test 123"


def test_invalid_function_call():
    tool = Tool(dummy_function)
    with pytest.raises(ValueError):
        tool(x="not an integer", y="hello")


def test_parameter_desc():
    tool = Tool(dummy_function, arg_desc={"x": "The x parameter"})
    assert tool.args["x"]["description"] == "The x parameter"


def test_tool_with_default_args_without_type_hints():
    def foo(x=100):
        return x

    tool = Tool(foo)
    assert tool.args["x"]["default"] == 100
    assert not hasattr(tool.args["x"], "type")


def test_tool_call_parses_args():
    tool = Tool(dummy_with_pydantic)

    args = {
        "model": {
            "field1": "hello",
            "field2": 123,
        }
    }

    result = tool(**args)
    assert result == "hello 123"


def test_tool_call_parses_nested_list_of_pydantic_model():
    def dummy_function(x: list[list[DummyModel]]):
        return x

    tool = Tool(dummy_function)
    args = {
        "x": [
            [
                {
                    "field1": "hello",
                    "field2": 123,
                }
            ]
        ]
    }

    result = tool(**args)
    assert result == [[DummyModel(field1="hello", field2=123)]]


def test_tool_call_kwarg():
    def fn(x: int, **kwargs):
        return kwargs

    tool = Tool(fn)

    assert tool(x=1, y=2, z=3) == {"y": 2, "z": 3}


def test_tool_str():
    def add(x: int, y: int = 0) -> int:
        """Add two integers."""
        return x + y

    tool = Tool(add)
    assert (
        str(tool)
        == "add, whose description is <desc>Add two integers.</desc>. It takes arguments {'x': {'type': 'integer'}, 'y': {'type': 'integer', 'default': 0}}."
    )


@pytest.mark.asyncio
async def test_async_tool_from_function():
    tool = Tool(async_dummy_function)

    assert tool.name == "async_dummy_function"
    assert "An async dummy function for testing" in tool.desc
    assert "x" in tool.args
    assert "y" in tool.args
    assert tool.args["x"]["type"] == "integer"
    assert tool.args["y"]["type"] == "string"
    assert tool.args["y"]["default"] == "hello"

    # Test async call
    result = await tool.acall(x=42, y="hello")
    assert result == "hello 42"


@pytest.mark.asyncio
async def test_async_tool_with_pydantic():
    tool = Tool(async_dummy_with_pydantic)

    assert tool.name == "async_dummy_with_pydantic"
    assert "model" in tool.args
    assert tool.args["model"]["type"] == "object"
    assert "field1" in tool.args["model"]["properties"]
    assert "field2" in tool.args["model"]["properties"]

    # Test async call with pydantic model
    model = DummyModel(field1="test", field2=123)
    result = await tool.acall(model=model)
    assert result == "test 123"

    # Test async call with dict
    result = await tool.acall(model={"field1": "test", "field2": 123})
    assert result == "test 123"


@pytest.mark.asyncio
async def test_async_tool_with_complex_pydantic():
    tool = Tool(async_complex_dummy_function)

    profile = UserProfile(
        user_id=1,
        name="Test User",
        contact=ContactInfo(
            email="[email protected]",
            addresses=[
                Address(street="123 Main St", city="Test City", zip_code="12345", is_primary=True),
                Address(street="456 Side St", city="Test City", zip_code="12345"),
            ],
        ),
    )

    result = await tool.acall(profile=profile, priority=1, notes=[Note(content="Test note", author="Test author")])
    assert result["user_id"] == 1
    assert result["name"] == "Test User"
    assert result["priority"] == 1
    assert result["notes"] == [Note(content="Test note", author="Test author")]
    assert result["primary_address"]["street"] == "123 Main St"


@pytest.mark.asyncio
async def test_async_tool_invalid_call():
    tool = Tool(async_dummy_function)
    with pytest.raises(ValueError):
        await tool.acall(x="not an integer", y="hello")


@pytest.mark.asyncio
async def test_async_tool_with_kwargs():
    async def fn(x: int, **kwargs):
        return kwargs

    tool = Tool(fn)

    result = await tool.acall(x=1, y=2, z=3)
    assert result == {"y": 2, "z": 3}


@pytest.mark.asyncio
async def test_async_concurrent_calls():
    """Test that multiple async tools can run concurrently."""
    tool = Tool(async_dummy_function)

    # Create multiple concurrent calls
    tasks = [tool.acall(x=i, y=f"hello{i}") for i in range(5)]

    # Run them concurrently and measure time
    start_time = asyncio.get_event_loop().time()
    results = await asyncio.gather(*tasks)
    end_time = asyncio.get_event_loop().time()

    # Verify results, `asyncio.gather` returns results in the order of the tasks
    assert results == [f"hello{i} {i}" for i in range(5)]

    # Check that it ran concurrently (should take ~0.1s, not ~0.5s)
    # We use 0.3s as threshold to account for some overhead
    assert end_time - start_time < 0.3


@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_async_tool_call_in_sync_mode():
    tool = Tool(async_dummy_function)
    with dspy.context(allow_tool_async_sync_conversion=False):
        with pytest.raises(ValueError):
            result = tool(x=1, y="hello")

    with dspy.context(allow_tool_async_sync_conversion=True):
        result = tool(x=1, y="hello")
        assert result == "hello 1"


TOOL_CALL_TEST_CASES = [
    ([], {"tool_calls": []}),
    (
        [{"name": "search", "args": {"query": "hello"}}],
        {
            "tool_calls": [{"type": "function", "function": {"name": "search", "arguments": {"query": "hello"}}}],
        },
    ),
    (
        [
            {"name": "search", "args": {"query": "hello"}},
            {"name": "translate", "args": {"text": "world", "lang": "fr"}},
        ],
        {
            "tool_calls": [
                {"type": "function", "function": {"name": "search", "arguments": {"query": "hello"}}},
                {
                    "type": "function",
                    "function": {"name": "translate", "arguments": {"text": "world", "lang": "fr"}},
                },
            ],
        },
    ),
    (
        [{"name": "get_time", "args": {}}],
        {
            "tool_calls": [{"type": "function", "function": {"name": "get_time", "arguments": {}}}],
        },
    ),
]


@pytest.mark.parametrize("tool_calls_data,expected", TOOL_CALL_TEST_CASES)
def test_tool_calls_format_basic(tool_calls_data, expected):
    """Test ToolCalls.format with various basic scenarios."""
    tool_calls_list = [ToolCalls.ToolCall(**data) for data in tool_calls_data]
    tool_calls = ToolCalls(tool_calls=tool_calls_list)
    result = tool_calls.format()

    assert result == expected


def test_tool_calls_format_from_dict_list():
    """Test format works with ToolCalls created from from_dict_list."""
    tool_calls_dicts = [
        {"name": "search", "args": {"query": "hello"}},
        {"name": "translate", "args": {"text": "world", "lang": "fr"}},
    ]

    tool_calls = ToolCalls.from_dict_list(tool_calls_dicts)
    result = tool_calls.format()

    assert len(result["tool_calls"]) == 2
    assert result["tool_calls"][0]["function"]["name"] == "search"
    assert result["tool_calls"][1]["function"]["name"] == "translate"


def test_toolcalls_vague_match():
    """
    Test that ToolCalls can parse the data with slightly off format:

    - a single dict with "name" and "args"
    - a list of dicts with "name" and "args"
    - invalid input (should raise ValueError)
    """
    # Single dict with "name" and "args" should parse as one ToolCall
    data_single = {"name": "search", "args": {"query": "hello"}}
    tc = ToolCalls.model_validate(data_single)
    assert isinstance(tc, ToolCalls)
    assert len(tc.tool_calls) == 1
    assert tc.tool_calls[0].name == "search"
    assert tc.tool_calls[0].args == {"query": "hello"}

    # List of dicts with "name" and "args" should parse as multiple ToolCalls
    data_list = [
        {"name": "search", "args": {"query": "hello"}},
        {"name": "translate", "args": {"text": "world", "lang": "fr"}},
    ]
    tc = ToolCalls.model_validate(data_list)
    assert isinstance(tc, ToolCalls)
    assert len(tc.tool_calls) == 2
    assert tc.tool_calls[0].name == "search"
    assert tc.tool_calls[1].name == "translate"

    # Dict with "tool_calls" key containing a list of dicts
    data_tool_calls = {
        "tool_calls": [
            {"name": "search", "args": {"query": "hello"}},
            {"name": "get_time", "args": {}},
        ]
    }
    tc = ToolCalls.model_validate(data_tool_calls)
    assert isinstance(tc, ToolCalls)
    assert len(tc.tool_calls) == 2
    assert tc.tool_calls[0].name == "search"
    assert tc.tool_calls[1].name == "get_time"

    # Invalid input should raise ValueError
    with pytest.raises(ValueError):
        ToolCalls.model_validate({"foo": "bar"})
    with pytest.raises(ValueError):
        ToolCalls.model_validate([{"foo": "bar"}])


def test_tool_convert_input_schema_to_tool_args_no_input_params():
    args, arg_types, arg_desc = convert_input_schema_to_tool_args(schema={"properties": {}})
    assert args == {}
    assert arg_types == {}
    assert arg_desc == {}


def test_tool_convert_input_schema_to_tool_args_lang_chain():
    # Example from langchain docs:
    # https://web.archive.org/web/20250723101359/https://api.python.langchain.com/en/latest/tools/langchain_core.tools.tool.html
    args, arg_types, arg_desc = convert_input_schema_to_tool_args(
        schema={
            "title": "fooSchema",
            "description": "The foo.",
            "type": "object",
            "properties": {
                "bar": {
                    "title": "Bar",
                    "description": "The bar.",
                    "type": "string",
                },
                "baz": {
                    "title": "Baz",
                    "type": "integer",
                },
            },
            "required": [
                "baz",
            ],
        }
    )
    assert args == {
        "bar": {"title": "Bar", "description": "The bar.", "type": "string"},
        "baz": {"title": "Baz", "type": "integer"},
    }
    assert arg_types == {
        "bar": str,
        "baz": int,
    }
    assert arg_desc == {
        "bar": "The bar.",
        "baz": "No description provided. (Required)",
    }


def test_tool_call_execute():
    def get_weather(city: str) -> str:
        return f"The weather in {city} is sunny"

    def add_numbers(a: int, b: int) -> int:
        return a + b

    tools = [
        dspy.Tool(get_weather),
        dspy.Tool(add_numbers)
    ]

    tool_call = dspy.ToolCalls.ToolCall(name="get_weather", args={"city": "Berlin"})
    result = tool_call.execute(functions=tools)
    assert result == "The weather in Berlin is sunny"

    # Test individual tool call with function dict
    tool_call2 = dspy.ToolCalls.ToolCall(name="add_numbers", args={"a": 7, "b": 13})
    result2 = tool_call2.execute(functions={"add_numbers": add_numbers})
    assert result2 == 20

    # Test individual tool call with no arguments
    def get_pi():
        return 3.14159

    tool_call3 = dspy.ToolCalls.ToolCall(name="get_pi", args={})
    result3 = tool_call3.execute(functions={"get_pi": get_pi})
    assert result3 == 3.14159

    # Test error case
    tool_call4 = dspy.ToolCalls.ToolCall(name="nonexistent", args={})
    with pytest.raises(ValueError, match="not found"):
        tool_call4.execute(functions=tools)


def test_tool_call_execute_with_local_functions():
    def main():
        def local_add(a: int, b: int) -> int:
            return a + b

        def local_multiply(x: int, y: int) -> int:
            return x * y

        # Test individual execution with local function
        tool_call1 = dspy.ToolCalls.ToolCall(name="local_add", args={"a": 10, "b": 15})
        result1 = tool_call1.execute()  # Should find local function automatically
        assert result1 == 25

        tool_call2 = dspy.ToolCalls.ToolCall(name="local_multiply", args={"x": 4, "y": 7})
        result2 = tool_call2.execute()  # Should find local function automatically
        assert result2 == 28

        # Test locals take precedence over globals
        try:
            globals()["local_add"] = lambda a, b: a + b + 1000
            precedence_call = dspy.ToolCalls.ToolCall(name="local_add", args={"a": 1, "b": 2})
            result = precedence_call.execute()
            assert result == 3  # Should use local function (1+2=3), not global (1+2+1000=1003)
        finally:
            globals().pop("local_add", None)

    main()

```

--------------------------------------------------------------------------------
/docs/docs/tutorials/streaming/index.md:
--------------------------------------------------------------------------------

```markdown
# Streaming

In this guide, we will walk you through how to enable streaming in your DSPy program. DSPy Streaming
consists of two parts:

- **Output Token Streaming**: Stream individual tokens as they're generated, rather than waiting for the complete response.
- **Intermediate Status Streaming**: Provide real-time updates about the program's execution state (e.g., "Calling web search...", "Processing results...").

## Output Token Streaming

DSPy's token streaming feature works with any module in your pipeline, not just the final output. The only requirement is that the streamed field must be of type `str`. To enable token streaming:

1. Wrap your program with `dspy.streamify`
2. Create one or more `dspy.streaming.StreamListener` objects to specify which fields to stream

Here's a basic example:

```python
import os

import dspy

os.environ["OPENAI_API_KEY"] = "your_api_key"

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))

predict = dspy.Predict("question->answer")

# Enable streaming for the 'answer' field
stream_predict = dspy.streamify(
    predict,
    stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")],
)
```

To consume the streamed output:

```python
import asyncio

async def read_output_stream():
    output_stream = stream_predict(question="Why did a chicken cross the kitchen?")

    async for chunk in output_stream:
        print(chunk)

asyncio.run(read_output_stream())
```

This will produce output like:

```
StreamResponse(predict_name='self', signature_field_name='answer', chunk='To')
StreamResponse(predict_name='self', signature_field_name='answer', chunk=' get')
StreamResponse(predict_name='self', signature_field_name='answer', chunk=' to')
StreamResponse(predict_name='self', signature_field_name='answer', chunk=' the')
StreamResponse(predict_name='self', signature_field_name='answer', chunk=' other')
StreamResponse(predict_name='self', signature_field_name='answer', chunk=' side of the frying pan!')
Prediction(
    answer='To get to the other side of the frying pan!'
)
```

Note: Since `dspy.streamify` returns an async generator, you must use it within an async context. If you're using an environment like Jupyter or Google Colab that already has an event loop (async context), you can use the generator directly.
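
For example, in a notebook cell that is already running inside an event loop, a minimal sketch (reusing the `stream_predict` object defined above) is to iterate over the generator with a top-level `async for`:

```python
# Inside Jupyter/Colab, the cell already runs in an event loop,
# so the async generator can be consumed directly (no asyncio.run needed).
output_stream = stream_predict(question="Why did a chicken cross the kitchen?")

async for chunk in output_stream:
    print(chunk)  # StreamResponse chunks, followed by the final Prediction
```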

You may have noticed that the stream above contains two different entities: `StreamResponse`
and `Prediction`. `StreamResponse` is the wrapper over the streaming tokens of the field being listened to, which
in this example is the `answer` field. `Prediction` is the program's final output. In DSPy, streaming is
implemented in a sidecar fashion: we enable streaming on the LM so that it outputs a stream of tokens. These tokens
are sent to a side channel that is continuously read by the user-defined listeners. Each listener interprets the
stream and decides whether the `signature_field_name` it is listening to has started to appear and has finalized.
Once it determines that the field has appeared, the listener begins emitting tokens to the async generator that
users read from. A listener's internal mechanism depends on the adapter behind the scenes, and because we usually
cannot tell that a field has finalized until the next field appears, the listener buffers output tokens before
sending them to the final generator. This is why the last `StreamResponse` chunk usually contains more than one
token. The program's final output is also written to the stream, which is the `Prediction` chunk shown in the
sample output above.

To handle these different types and implement custom logic:

```python
import asyncio

async def read_output_stream():
    output_stream = stream_predict(question="Why did a chicken cross the kitchen?")

    return_value = None
    async for chunk in output_stream:
        if isinstance(chunk, dspy.streaming.StreamResponse):
            print(f"Output token of field {chunk.signature_field_name}: {chunk.chunk}")
        elif isinstance(chunk, dspy.Prediction):
            return_value = chunk
    return return_value

program_output = asyncio.run(read_output_stream())
print("Final output: ", program_output)
```

### Understand `StreamResponse`

`StreamResponse` (`dspy.streaming.StreamResponse`) is the wrapper class of streaming tokens. It comes with 3
fields:

- `predict_name`: the name of the predictor that produces the `signature_field_name`. This name matches the keys
  returned by `your_program.named_predictors()`. In the code above, because `answer` comes from `predict` itself,
  `predict_name` shows up as `self`, which is the only key returned by `predict.named_predictors()`.
- `signature_field_name`: the output field that these tokens map to. `predict_name` and `signature_field_name`
  together form the unique identifier of the field. We will demonstrate how to handle streaming of multiple fields
  and duplicate field names later in this guide.
- `chunk`: the value of the stream chunk.
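
As a minimal illustration of how these fields fit together (reusing the `stream_predict` program defined earlier), you can route chunks by their identifier:

```python
import asyncio

async def route_chunks():
    output = stream_predict(question="Why did a chicken cross the kitchen?")
    async for chunk in output:
        if isinstance(chunk, dspy.streaming.StreamResponse):
            # (predict_name, signature_field_name) uniquely identifies the streamed field.
            key = (chunk.predict_name, chunk.signature_field_name)
            print(f"{key}: {chunk.chunk}")

asyncio.run(route_chunks())
```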

### Streaming with Cache

When a cached result is found, the stream will skip individual tokens and only yield the final `Prediction`. For example:

```
Prediction(
    answer='To get to the other side of the dinner plate!'
)
```
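
A minimal sketch of what this looks like in practice, assuming a cache-enabled LM (the default) and the `stream_predict` program from above; if the second, identical call is served from the cache, its stream yields only the final `Prediction`:

```python
import asyncio

async def collect_chunks(question):
    # Gather every chunk the stream produces for one call.
    return [chunk async for chunk in stream_predict(question=question)]

first = asyncio.run(collect_chunks("Why did a chicken cross the kitchen?"))
second = asyncio.run(collect_chunks("Why did a chicken cross the kitchen?"))

# The first call streams StreamResponse chunks followed by a Prediction;
# a cache hit on the second call typically yields just the Prediction.
print(len(first), len(second))
```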

### Streaming Multiple Fields

You can monitor multiple fields by creating a `StreamListener` for each one. Here's an example with a multi-module program:

```python
import asyncio

import dspy

lm = dspy.LM("openai/gpt-4o-mini", cache=False)
dspy.settings.configure(lm=lm)


class MyModule(dspy.Module):
    def __init__(self):
        super().__init__()

        self.predict1 = dspy.Predict("question->answer")
        self.predict2 = dspy.Predict("answer->simplified_answer")

    def forward(self, question: str, **kwargs):
        answer = self.predict1(question=question)
        simplified_answer = self.predict2(answer=answer)
        return simplified_answer


predict = MyModule()
stream_listeners = [
    dspy.streaming.StreamListener(signature_field_name="answer"),
    dspy.streaming.StreamListener(signature_field_name="simplified_answer"),
]
stream_predict = dspy.streamify(
    predict,
    stream_listeners=stream_listeners,
)

async def read_output_stream():
    output = stream_predict(question="why did a chicken cross the kitchen?")

    return_value = None
    async for chunk in output:
        if isinstance(chunk, dspy.streaming.StreamResponse):
            print(chunk)
        elif isinstance(chunk, dspy.Prediction):
            return_value = chunk
    return return_value

program_output = asyncio.run(read_output_stream())
print("Final output: ", program_output)
```

The output will look like:

```
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk='To')
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' get')
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' to')
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' the')
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' other side of the recipe!')
StreamResponse(predict_name='predict2', signature_field_name='simplified_answer', chunk='To')
StreamResponse(predict_name='predict2', signature_field_name='simplified_answer', chunk=' reach')
StreamResponse(predict_name='predict2', signature_field_name='simplified_answer', chunk=' the')
StreamResponse(predict_name='predict2', signature_field_name='simplified_answer', chunk=' other side of the recipe!')
Final output:  Prediction(
    simplified_answer='To reach the other side of the recipe!'
)
```

### Streaming the Same Field Multiple Times (as in dspy.ReAct)

By default, a `StreamListener` automatically closes itself after completing a single streaming session.
This design helps prevent performance issues, since every token is broadcast to all configured stream listeners,
and having too many active listeners can introduce significant overhead.

However, in scenarios where a DSPy module is used repeatedly in a loop, such as with `dspy.ReAct`, you may want to stream
the same field from each prediction, every time it is produced. To enable this behavior, set `allow_reuse=True` when creating
your `StreamListener`. See the example below:

```python
import asyncio

import dspy

lm = dspy.LM("openai/gpt-4o-mini", cache=False)
dspy.settings.configure(lm=lm)


def fetch_user_info(user_name: str):
    """Get user information like name, birthday, etc."""
    return {
        "name": user_name,
        "birthday": "2009-05-16",
    }


def get_sports_news(year: int):
    """Get sports news for a given year."""
    if year == 2009:
        return "Usane Bolt broke the world record in the 100m race."
    return None


react = dspy.ReAct("question->answer", tools=[fetch_user_info, get_sports_news])

stream_listeners = [
    # dspy.ReAct has a built-in output field called "next_thought".
    dspy.streaming.StreamListener(signature_field_name="next_thought", allow_reuse=True),
]
stream_react = dspy.streamify(react, stream_listeners=stream_listeners)


async def read_output_stream():
    output = stream_react(question="What sports news happened in the year Adam was born?")
    return_value = None
    async for chunk in output:
        if isinstance(chunk, dspy.streaming.StreamResponse):
            print(chunk)
        elif isinstance(chunk, dspy.Prediction):
            return_value = chunk
    return return_value


print(asyncio.run(read_output_stream()))
```

In this example, setting `allow_reuse=True` on the `StreamListener` ensures that streaming for `next_thought` is
available on every iteration, not just the first. When you run this code, you will see the streaming tokens for
`next_thought` printed each time the field is produced.

#### Handling Duplicate Field Names

When streaming fields with the same name from different modules, specify both the `predict` and `predict_name` in the `StreamListener`:

```python
import asyncio

import dspy

lm = dspy.LM("openai/gpt-4o-mini", cache=False)
dspy.settings.configure(lm=lm)


class MyModule(dspy.Module):
    def __init__(self):
        super().__init__()

        self.predict1 = dspy.Predict("question->answer")
        self.predict2 = dspy.Predict("question, answer->answer, score")

    def forward(self, question: str, **kwargs):
        answer = self.predict1(question=question)
        simplified_answer = self.predict2(answer=answer)
        return simplified_answer


predict = MyModule()
stream_listeners = [
    dspy.streaming.StreamListener(
        signature_field_name="answer",
        predict=predict.predict1,
        predict_name="predict1"
    ),
    dspy.streaming.StreamListener(
        signature_field_name="answer",
        predict=predict.predict2,
        predict_name="predict2"
    ),
]
stream_predict = dspy.streamify(
    predict,
    stream_listeners=stream_listeners,
)


async def read_output_stream():
    output = stream_predict(question="why did a chicken cross the kitchen?")

    return_value = None
    async for chunk in output:
        if isinstance(chunk, dspy.streaming.StreamResponse):
            print(chunk)
        elif isinstance(chunk, dspy.Prediction):
            return_value = chunk
    return return_value


program_output = asyncio.run(read_output_stream())
print("Final output: ", program_output)
```

The output will be like:

```
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk='To')
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' get')
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' to')
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' the')
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' other side of the recipe!')
StreamResponse(predict_name='predict2', signature_field_name='answer', chunk="I'm")
StreamResponse(predict_name='predict2', signature_field_name='answer', chunk=' ready')
StreamResponse(predict_name='predict2', signature_field_name='answer', chunk=' to')
StreamResponse(predict_name='predict2', signature_field_name='answer', chunk=' assist')
StreamResponse(predict_name='predict2', signature_field_name='answer', chunk=' you')
StreamResponse(predict_name='predict2', signature_field_name='answer', chunk='! Please provide a question.')
Final output:  Prediction(
    answer="I'm ready to assist you! Please provide a question.",
    score='N/A'
)
```

## Intermediate Status Streaming

Status streaming keeps users informed about the program's progress, especially useful for long-running operations like tool calls or complex AI pipelines. To implement status streaming:

1. Create a custom status message provider by subclassing `dspy.streaming.StatusMessageProvider`
2. Override the desired hook methods to provide custom status messages
3. Pass your provider to `dspy.streamify`

Example:

```python
class MyStatusMessageProvider(dspy.streaming.StatusMessageProvider):
    def lm_start_status_message(self, instance, inputs):
        return f"Calling LM with inputs {inputs}..."

    def lm_end_status_message(self, outputs):
        return f"LM finished with output: {outputs}!"
```

Available hooks:

- `lm_start_status_message`: status message at the start of calling `dspy.LM`.
- `lm_end_status_message`: status message at the end of calling `dspy.LM`.
- `module_start_status_message`: status message at the start of calling a `dspy.Module`.
- `module_end_status_message`: status message at the end of calling a `dspy.Module`.
- `tool_start_status_message`: status message at the start of calling `dspy.Tool`.
- `tool_end_status_message`: status message at the end of calling `dspy.Tool`.

Each hook should return a string containing the status message.
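
For instance, a provider can override any subset of these hooks. Here is a minimal sketch, assuming the module hooks follow the same `(instance, inputs)` / `(outputs)` signatures as the LM and tool hooks shown above; the message strings are purely illustrative:

```python
class VerboseStatusMessageProvider(dspy.streaming.StatusMessageProvider):
    # Hypothetical provider for illustration only.
    def module_start_status_message(self, instance, inputs):
        return f"Running module {instance.__class__.__name__}..."

    def module_end_status_message(self, outputs):
        return "Module finished!"
```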

After creating the message provider, pass it to `dspy.streamify` to enable both status message streaming and
output token streaming, as shown in the example below. Intermediate status messages are represented by the class
`dspy.streaming.StatusMessage`, so we need an additional condition check to capture them.

```python
import asyncio

import dspy

lm = dspy.LM("openai/gpt-4o-mini", cache=False)
dspy.settings.configure(lm=lm)


class MyModule(dspy.Module):
    def __init__(self):
        super().__init__()

        self.tool = dspy.Tool(lambda x: 2 * x, name="double_the_number")
        self.predict = dspy.ChainOfThought("num1, num2->sum")

    def forward(self, num, **kwargs):
        num2 = self.tool(x=num)
        return self.predict(num1=num, num2=num2)


class MyStatusMessageProvider(dspy.streaming.StatusMessageProvider):
    def tool_start_status_message(self, instance, inputs):
        return f"Calling Tool {instance.name} with inputs {inputs}..."

    def tool_end_status_message(self, outputs):
        return f"Tool finished with output: {outputs}!"


predict = MyModule()
stream_listeners = [
    # dspy.ChainOfThought has a built-in output field called "reasoning".
    dspy.streaming.StreamListener(signature_field_name="reasoning"),
]
stream_predict = dspy.streamify(
    predict,
    stream_listeners=stream_listeners,
    status_message_provider=MyStatusMessageProvider(),
)


async def read_output_stream():
    output = stream_predict(num=3)

    return_value = None
    async for chunk in output:
        if isinstance(chunk, dspy.streaming.StreamResponse):
            print(chunk)
        elif isinstance(chunk, dspy.Prediction):
            return_value = chunk
        elif isinstance(chunk, dspy.streaming.StatusMessage):
            print(chunk)
    return return_value


program_output = asyncio.run(read_output_stream())
print("Final output: ", program_output)
```

Sample output:

```
StatusMessage(message='Calling tool double_the_number...')
StatusMessage(message='Tool calling finished! Querying the LLM with tool calling results...')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk='To')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' find')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' the')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' sum')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' of')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' the')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' two')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' numbers')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=',')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' we')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' simply')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' add')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' them')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' together')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk='.')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' Here')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=',')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' ')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk='3')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' plus')
StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' 6 equals 9.')
Final output:  Prediction(
    reasoning='To find the sum of the two numbers, we simply add them together. Here, 3 plus 6 equals 9.',
    sum='9'
)
```

## Synchronous Streaming

By default, calling a streamified DSPy program produces an async generator. To get a sync generator instead, set
the flag `async_streaming=False`:


```python
import os

import dspy

os.environ["OPENAI_API_KEY"] = "your_api_key"

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))

predict = dspy.Predict("question->answer")

# Enable streaming for the 'answer' field
stream_predict = dspy.streamify(
    predict,
    stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")],
    async_streaming=False,
)

output = stream_predict(question="why did a chicken cross the kitchen?")

program_output = None
for chunk in output:
    if isinstance(chunk, dspy.streaming.StreamResponse):
        print(chunk)
    elif isinstance(chunk, dspy.Prediction):
        program_output = chunk
print(f"Program output: {program_output}")
```

```

--------------------------------------------------------------------------------
/tests/signatures/test_adapter_image.py:
--------------------------------------------------------------------------------

```python
import os
import tempfile
from io import BytesIO

import pydantic
import pytest
import requests
from PIL import Image as PILImage

import dspy
from dspy.adapters.types.image import encode_image
from dspy.utils.dummies import DummyLM


@pytest.fixture
def sample_pil_image():
    """Fixture to provide a sample image for testing"""
    url = "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg"
    response = requests.get(url)
    response.raise_for_status()
    return PILImage.open(BytesIO(response.content))


@pytest.fixture
def sample_dspy_image_download():
    url = "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg"
    return dspy.Image(url, download=True)


@pytest.fixture
def sample_url():
    return "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg"


@pytest.fixture
def sample_dspy_image_no_download():
    return dspy.Image("https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg")


def count_messages_with_image_url_pattern(messages):
    pattern = {"type": "image_url", "image_url": {"url": lambda x: isinstance(x, str)}}

    try:

        def check_pattern(obj, pattern):
            if isinstance(pattern, dict):
                if not isinstance(obj, dict):
                    return False
                return all(k in obj and check_pattern(obj[k], v) for k, v in pattern.items())
            if callable(pattern):
                return pattern(obj)
            return obj == pattern

        def count_patterns(obj, pattern):
            count = 0
            if check_pattern(obj, pattern):
                count += 1
            if isinstance(obj, dict):
                count += sum(count_patterns(v, pattern) for v in obj.values())
            if isinstance(obj, (list, tuple)):
                count += sum(count_patterns(v, pattern) for v in obj)
            return count

        return count_patterns(messages, pattern)
    except Exception:
        return 0


def setup_predictor(signature, expected_output):
    """Helper to set up a predictor with DummyLM"""
    lm = DummyLM([expected_output])
    dspy.settings.configure(lm=lm)
    return dspy.Predict(signature), lm


@pytest.mark.parametrize(
    "test_case",
    [
        {
            "name": "probabilistic_classification",
            "signature": "image: dspy.Image, class_labels: list[str] -> probabilities: dict[str, float]",
            "inputs": {"image": "https://example.com/dog.jpg", "class_labels": ["dog", "cat", "bird"]},
            "key_output": "probabilities",
            "expected": {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}},
        },
        {
            "name": "image_to_code",
            "signature": "ui_image: dspy.Image, target_language: str -> generated_code: str",
            "inputs": {"ui_image": "https://example.com/button.png", "target_language": "HTML"},
            "key_output": "generated_code",
            "expected": {"generated_code": "<button>Click me</button>"},
        },
        {
            "name": "bbox_detection",
            "signature": "image: dspy.Image -> bboxes: list[Tuple[int, int, int, int]]",
            "inputs": {"image": "https://example.com/image.jpg"},
            "key_output": "bboxes",
            "expected": {"bboxes": [(10, 20, 30, 40), (50, 60, 70, 80)]},
        },
        {
            "name": "multilingual_caption",
            "signature": "image: dspy.Image, languages: list[str] -> captions: dict[str, str]",
            "inputs": {"image": "https://example.com/dog.jpg", "languages": ["en", "es", "fr"]},
            "key_output": "captions",
            "expected": {
                "captions": {"en": "A golden retriever", "es": "Un golden retriever", "fr": "Un golden retriever"}
            },
        },
    ],
)
def test_basic_image_operations(test_case):
    """Consolidated test for basic image operations"""
    predictor, lm = setup_predictor(test_case["signature"], test_case["expected"])

    # Convert string URLs to dspy.Image objects
    inputs = {
        k: dspy.Image(v) if isinstance(v, str) and k in ["image", "ui_image"] else v
        for k, v in test_case["inputs"].items()
    }

    result = predictor(**inputs)

    # Check result based on output field name
    output_field = next(f for f in ["probabilities", "generated_code", "bboxes", "captions"] if hasattr(result, f))
    assert getattr(result, output_field) == test_case["expected"][test_case["key_output"]]
    assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1


@pytest.mark.parametrize(
    "image_input,description",
    [
        ("pil_image", "PIL Image"),
        ("encoded_pil_image", "encoded PIL image string"),
        ("dspy_image_download", "dspy.Image with download=True"),
        ("dspy_image_no_download", "dspy.Image without download"),
    ],
)
def test_image_input_formats(
    request, sample_pil_image, sample_dspy_image_download, sample_dspy_image_no_download, image_input, description
):
    """Test different input formats for image fields"""
    signature = "image: dspy.Image, class_labels: list[str] -> probabilities: dict[str, float]"
    expected = {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}}
    predictor, lm = setup_predictor(signature, expected)

    input_map = {
        "pil_image": sample_pil_image,
        "encoded_pil_image": encode_image(sample_pil_image),
        "dspy_image_download": sample_dspy_image_download,
        "dspy_image_no_download": sample_dspy_image_no_download,
    }

    actual_input = input_map[image_input]
    # TODO(isaacbmiller): Support the cases without direct dspy.Image coercion
    if image_input in ["pil_image", "encoded_pil_image"]:
        pytest.xfail(f"{description} not fully supported without dspy.Image coercion")

    result = predictor(image=actual_input, class_labels=["dog", "cat", "bird"])
    assert result.probabilities == expected["probabilities"]
    assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1


def test_predictor_save_load(sample_url, sample_pil_image):
    """Test saving and loading predictors with image fields"""
    signature = "image: dspy.Image -> caption: str"
    examples = [
        dspy.Example(image=dspy.Image(sample_url), caption="Example 1"),
        dspy.Example(image=sample_pil_image, caption="Example 2"),
    ]

    predictor, lm = setup_predictor(signature, {"caption": "A golden retriever"})
    optimizer = dspy.teleprompt.LabeledFewShot(k=1)
    compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False)

    with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".json") as temp_file:
        compiled_predictor.save(temp_file.name)
        loaded_predictor = dspy.Predict(signature)
        loaded_predictor.load(temp_file.name)

    loaded_predictor(image=dspy.Image("https://example.com/dog.jpg"))
    assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 2
    assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"])


def test_save_load_complex_default_types():
    """Test saving and loading predictors with complex default types (lists of images)"""
    examples = [
        dspy.Example(
            image_list=[
                dspy.Image("https://example.com/dog.jpg"),
                dspy.Image("https://example.com/cat.jpg"),
            ],
            caption="Example 1",
        ).with_inputs("image_list"),
    ]

    class ComplexTypeSignature(dspy.Signature):
        image_list: list[dspy.Image] = dspy.InputField(desc="A list of images")
        caption: str = dspy.OutputField(desc="A caption for the image list")

    predictor, lm = setup_predictor(ComplexTypeSignature, {"caption": "A list of images"})
    optimizer = dspy.teleprompt.LabeledFewShot(k=1)
    compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False)

    with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".json") as temp_file:
        compiled_predictor.save(temp_file.name)
        loaded_predictor = dspy.Predict(ComplexTypeSignature)
        loaded_predictor.load(temp_file.name)

    result = loaded_predictor(**examples[0].inputs())
    assert result.caption == "A list of images"
    assert str(lm.history[-1]["messages"]).count("'url'") == 4
    assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"])


class BasicImageSignature(dspy.Signature):
    """Basic signature with a single image input"""

    image: dspy.Image = dspy.InputField()
    output: str = dspy.OutputField()


class ImageListSignature(dspy.Signature):
    """Signature with a list of images input"""

    image_list: list[dspy.Image] = dspy.InputField()
    output: str = dspy.OutputField()


@pytest.mark.parametrize(
    "test_case",
    [
        {
            "name": "basic_dspy_signature",
            "signature_class": BasicImageSignature,
            "inputs": {"image": "https://example.com/dog.jpg"},
            "expected": {"output": "A dog photo"},
            "expected_image_urls": 2,
        },
        {
            "name": "list_dspy_signature",
            "signature_class": ImageListSignature,
            "inputs": {"image_list": ["https://example.com/dog.jpg", "https://example.com/cat.jpg"]},
            "expected": {"output": "Multiple photos"},
            "expected_image_urls": 4,
        },
    ],
)
def test_save_load_complex_types(test_case):
    """Test saving and loading predictors with complex types"""
    signature_cls = test_case["signature_class"]

    # Convert string URLs to dspy.Image objects in input
    processed_input = {}
    for key, value in test_case["inputs"].items():
        if isinstance(value, str) and "http" in value:
            processed_input[key] = dspy.Image(value)
        elif isinstance(value, list) and value and isinstance(value[0], str):
            processed_input[key] = [dspy.Image(url) for url in value]
        else:
            processed_input[key] = value

    # Create example and predictor
    examples = [dspy.Example(**processed_input, **test_case["expected"]).with_inputs(*processed_input.keys())]

    predictor, lm = setup_predictor(signature_cls, test_case["expected"])
    optimizer = dspy.teleprompt.LabeledFewShot(k=1)
    compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False)

    # Test save and load
    with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".json") as temp_file:
        compiled_predictor.save(temp_file.name)
        loaded_predictor = dspy.Predict(signature_cls)
        loaded_predictor.load(temp_file.name)

    # Run prediction
    result = loaded_predictor(**processed_input)

    # Verify output matches expected
    for key, value in test_case["expected"].items():
        assert getattr(result, key) == value

    # Verify correct number of image URLs in messages
    assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == test_case["expected_image_urls"]
    assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"])


def test_save_load_pydantic_model():
    """Test saving and loading predictors with pydantic models"""

    class ImageModel(pydantic.BaseModel):
        image: dspy.Image
        image_list: list[dspy.Image] | None = None
        output: str

    class PydanticSignature(dspy.Signature):
        model_input: ImageModel = dspy.InputField()
        output: str = dspy.OutputField()

    # Create model instance
    model_input = ImageModel(
        image=dspy.Image("https://example.com/dog.jpg"),
        image_list=[dspy.Image("https://example.com/cat.jpg")],
        output="Multiple photos",
    )

    # Create example and predictor
    examples = [dspy.Example(model_input=model_input, output="Multiple photos").with_inputs("model_input")]

    predictor, lm = setup_predictor(PydanticSignature, {"output": "Multiple photos"})
    optimizer = dspy.teleprompt.LabeledFewShot(k=1)
    compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False)

    # Test save and load
    with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".json") as temp_file:
        compiled_predictor.save(temp_file.name)
        loaded_predictor = dspy.Predict(PydanticSignature)
        loaded_predictor.load(temp_file.name)

    # Run prediction
    result = loaded_predictor(model_input=model_input)

    # Verify output matches expected
    assert result.output == "Multiple photos"
    assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 4
    assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"])


def test_optional_image_field():
    """Test that optional image fields are not required"""

    class OptionalImageSignature(dspy.Signature):
        image: dspy.Image | None = dspy.InputField()
        output: str = dspy.OutputField()

    predictor, lm = setup_predictor(OptionalImageSignature, {"output": "Hello"})
    result = predictor(image=None)
    assert result.output == "Hello"
    assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 0


def test_pdf_url_support():
    """Test support for PDF files from URLs"""
    pdf_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"

    # Create a dspy.Image object from the PDF URL with download=True
    pdf_image = dspy.Image(pdf_url, download=True)

    # The data URI should contain application/pdf in the MIME type
    assert "data:application/pdf" in pdf_image.url
    assert ";base64," in pdf_image.url

    # Test using it in a predictor
    class PDFSignature(dspy.Signature):
        document: dspy.Image = dspy.InputField(desc="A PDF document")
        summary: str = dspy.OutputField(desc="A summary of the PDF")

    predictor, lm = setup_predictor(PDFSignature, {"summary": "This is a dummy PDF"})
    result = predictor(document=pdf_image)

    assert result.summary == "This is a dummy PDF"
    assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1

    # Ensure the URL was properly expanded in messages
    messages_str = str(lm.history[-1]["messages"])
    assert "application/pdf" in messages_str


def test_different_mime_types():
    """Test support for different file types and MIME type detection"""
    # Test with various file types
    file_urls = {
        "pdf": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf",
        "image": "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg",
    }

    expected_mime_types = {
        "pdf": "application/pdf",
        "image": "image/jpeg",
    }

    for file_type, url in file_urls.items():
        # Download and encode
        encoded = encode_image(url, download_images=True)

        # Check for correct MIME type in the encoded data - using 'in' instead of startswith
        # to account for possible parameters in the MIME type
        assert f"data:{expected_mime_types[file_type]}" in encoded
        assert ";base64," in encoded


def test_mime_type_from_response_headers():
    """Test that MIME types from response headers are correctly used"""
    # This URL returns proper Content-Type header
    pdf_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"

    # Make an actual request to get the content type from headers
    response = requests.get(pdf_url)
    expected_mime_type = response.headers.get("Content-Type", "")

    # Should be application/pdf or similar
    assert "pdf" in expected_mime_type.lower()

    # Encode with download to test MIME type from headers
    encoded = encode_image(pdf_url, download_images=True)

    # The encoded data should contain the correct MIME type
    assert "application/pdf" in encoded
    assert ";base64," in encoded


def test_pdf_from_file():
    """Test handling a PDF file from disk"""
    # Download a PDF to a temporary file
    pdf_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
    response = requests.get(pdf_url)
    response.raise_for_status()

    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp_file:
        tmp_file.write(response.content)
        tmp_file_path = tmp_file.name

    try:
        # Create a dspy.Image from the file
        pdf_image = dspy.Image(tmp_file_path)

        # The constructor encodes the file into a data URI we can inspect directly
        assert "data:application/pdf" in pdf_image.url
        assert ";base64," in pdf_image.url

        # Test the image in a predictor
        class FilePDFSignature(dspy.Signature):
            document: dspy.Image = dspy.InputField(desc="A PDF document from file")
            summary: str = dspy.OutputField(desc="A summary of the PDF")

        predictor, lm = setup_predictor(FilePDFSignature, {"summary": "This is a PDF from file"})
        result = predictor(document=pdf_image)

        assert result.summary == "This is a PDF from file"
        assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1
    finally:
        # Clean up the temporary file
        try:
            os.unlink(tmp_file_path)
        except Exception:
            pass


def test_image_repr():
    """Test string representation of Image objects"""
    url_image = dspy.Image("https://example.com/dog.jpg")
    assert str(url_image) == (
        "<<CUSTOM-TYPE-START-IDENTIFIER>>"
        '[{"type": "image_url", "image_url": {"url": "https://example.com/dog.jpg"}}]'
        "<<CUSTOM-TYPE-END-IDENTIFIER>>"
    )
    assert repr(url_image) == "Image(url='https://example.com/dog.jpg')"

    sample_pil = PILImage.new("RGB", (60, 30), color="red")
    pil_image = dspy.Image(sample_pil)
    assert str(pil_image).startswith('<<CUSTOM-TYPE-START-IDENTIFIER>>[{"type": "image_url",')
    assert str(pil_image).endswith("<<CUSTOM-TYPE-END-IDENTIFIER>>")
    assert "base64" in str(pil_image)


def test_from_methods_warn(tmp_path):
    """Deprecated from_* methods emit warnings"""
    tmp_file = tmp_path / "test.png"
    tmp_file.write_bytes(b"pngdata")

    with pytest.warns(DeprecationWarning):
        dspy.Image.from_url("https://example.com/dog.jpg")
    with pytest.warns(DeprecationWarning):
        dspy.Image.from_file(str(tmp_file))
    sample_pil = PILImage.new("RGB", (10, 10), color="blue")
    with pytest.warns(DeprecationWarning):
        dspy.Image.from_PIL(sample_pil)


def test_invalid_string_format():
    """Test that invalid string formats raise a ValueError"""
    invalid_string = "this_is_not_a_url_or_file"

    # Should raise a ValueError and not pass the string through
    with pytest.raises(ValueError, match="Unrecognized"):
        dspy.Image(invalid_string)

def test_pil_image_with_download_parameter():
    """Test behavior when PIL image is passed with download=True"""
    sample_pil = PILImage.new("RGB", (60, 30), color="red")

    # PIL image should be encoded regardless of download parameter
    image_no_download = dspy.Image(sample_pil)
    image_with_download = dspy.Image(sample_pil, download=True)

    # Both should result in base64 encoded data URIs
    assert image_no_download.url.startswith("data:")
    assert image_with_download.url.startswith("data:")
    assert "base64," in image_no_download.url
    assert "base64," in image_with_download.url

    # They should be identical since PIL images are always encoded
    assert image_no_download.url == image_with_download.url

```

--------------------------------------------------------------------------------
/dspy/clients/lm.py:
--------------------------------------------------------------------------------

```python
import logging
import os
import re
import threading
import warnings
from typing import Any, Literal, cast

import litellm
from anyio.streams.memory import MemoryObjectSendStream
from asyncer import syncify

import dspy
from dspy.clients.cache import request_cache
from dspy.clients.openai import OpenAIProvider
from dspy.clients.provider import Provider, ReinforceJob, TrainingJob
from dspy.clients.utils_finetune import TrainDataFormat
from dspy.dsp.utils.settings import settings
from dspy.utils.callback import BaseCallback

from .base_lm import BaseLM

logger = logging.getLogger(__name__)


class LM(BaseLM):
    """
    A language model supporting chat or text completion requests for use with DSPy modules.
    """

    def __init__(
        self,
        model: str,
        model_type: Literal["chat", "text", "responses"] = "chat",
        temperature: float | None = None,
        max_tokens: int | None = None,
        cache: bool = True,
        callbacks: list[BaseCallback] | None = None,
        num_retries: int = 3,
        provider: Provider | None = None,
        finetuning_model: str | None = None,
        launch_kwargs: dict[str, Any] | None = None,
        train_kwargs: dict[str, Any] | None = None,
        use_developer_role: bool = False,
        **kwargs,
    ):
        """
        Create a new language model instance for use with DSPy modules and programs.

        Args:
            model: The model to use. This should be a string of the form ``"llm_provider/llm_name"``
                   supported by LiteLLM. For example, ``"openai/gpt-4o"``.
            model_type: The type of the model: ``"chat"``, ``"text"``, or ``"responses"``.
            temperature: The sampling temperature to use when generating responses.
            max_tokens: The maximum number of tokens to generate per response.
            cache: Whether to cache the model responses for reuse to improve performance
                   and reduce costs.
            callbacks: A list of callback functions to run before and after each request.
            num_retries: The number of times to retry a request if it fails transiently due to
                         network error, rate limiting, etc. Requests are retried with exponential
                         backoff.
            provider: The provider to use. If not specified, the provider will be inferred from the model.
            finetuning_model: The model to finetune. In some providers, the models available for finetuning differ
                from the models available for inference.
            rollout_id: Optional integer used to differentiate cache entries for otherwise
                identical requests. Different values bypass DSPy's caches while still caching
                future calls with the same inputs and rollout ID. Note that `rollout_id`
                only affects generation when `temperature` is non-zero. This argument is
                stripped before sending requests to the provider.
        """
        # Remember to update LM.copy() if you modify the constructor!
        self.model = model
        self.model_type = model_type
        self.cache = cache
        self.provider = provider or self.infer_provider()
        self.callbacks = callbacks or []
        self.history = []
        self.num_retries = num_retries
        self.finetuning_model = finetuning_model
        self.launch_kwargs = launch_kwargs or {}
        self.train_kwargs = train_kwargs or {}
        self.use_developer_role = use_developer_role
        self._warned_zero_temp_rollout = False

        # Handle model-specific configuration for different model families
        model_family = model.split("/")[-1].lower() if "/" in model else model.lower()

        # Recognize OpenAI reasoning models (o1, o3, o4, gpt-5 family)
        model_pattern = re.match(r"^(?:o[1345]|gpt-5)(?:-(?:mini|nano))?", model_family)

        if model_pattern:

            if (temperature and temperature != 1.0) or (max_tokens and max_tokens < 16000):
                raise ValueError(
                    "OpenAI's reasoning models require passing temperature=1.0 or None and max_tokens >= 16000 or None to "
                    "`dspy.LM(...)`, e.g., dspy.LM('openai/gpt-5', temperature=1.0, max_tokens=16000)"
                )
            self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs)
        else:
            self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)

        if self.kwargs.get("rollout_id") is None:
            self.kwargs.pop("rollout_id", None)

        self._warn_zero_temp_rollout(self.kwargs.get("temperature"), self.kwargs.get("rollout_id"))

    def _warn_zero_temp_rollout(self, temperature: float | None, rollout_id):
        if not self._warned_zero_temp_rollout and rollout_id is not None and (temperature is None or temperature == 0):
            warnings.warn(
                "rollout_id has no effect when temperature=0; set temperature>0 to bypass the cache.",
                stacklevel=3,
            )
            self._warned_zero_temp_rollout = True

    def _get_cached_completion_fn(self, completion_fn, cache):
        ignored_args_for_cache_key = ["api_key", "api_base", "base_url"]
        if cache:
            completion_fn = request_cache(
                cache_arg_name="request",
                ignored_args_for_cache_key=ignored_args_for_cache_key,
            )(completion_fn)

        litellm_cache_args = {"no-cache": True, "no-store": True}

        return completion_fn, litellm_cache_args

    def forward(self, prompt=None, messages=None, **kwargs):
        # Build the request.
        kwargs = dict(kwargs)
        cache = kwargs.pop("cache", self.cache)

        messages = messages or [{"role": "user", "content": prompt}]
        if self.use_developer_role and self.model_type == "responses":
            messages = [{**m, "role": "developer"} if m.get("role") == "system" else m for m in messages]
        kwargs = {**self.kwargs, **kwargs}
        self._warn_zero_temp_rollout(kwargs.get("temperature"), kwargs.get("rollout_id"))
        if kwargs.get("rollout_id") is None:
            kwargs.pop("rollout_id", None)

        if self.model_type == "chat":
            completion = litellm_completion
        elif self.model_type == "text":
            completion = litellm_text_completion
        elif self.model_type == "responses":
            completion = litellm_responses_completion
        completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache)

        results = completion(
            request=dict(model=self.model, messages=messages, **kwargs),
            num_retries=self.num_retries,
            cache=litellm_cache_args,
        )

        self._check_truncation(results)

        if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
            settings.usage_tracker.add_usage(self.model, dict(results.usage))
        return results

    async def aforward(self, prompt=None, messages=None, **kwargs):
        # Build the request.
        kwargs = dict(kwargs)
        cache = kwargs.pop("cache", self.cache)

        messages = messages or [{"role": "user", "content": prompt}]
        if self.use_developer_role and self.model_type == "responses":
            messages = [{**m, "role": "developer"} if m.get("role") == "system" else m for m in messages]
        kwargs = {**self.kwargs, **kwargs}
        self._warn_zero_temp_rollout(kwargs.get("temperature"), kwargs.get("rollout_id"))
        if kwargs.get("rollout_id") is None:
            kwargs.pop("rollout_id", None)

        if self.model_type == "chat":
            completion = alitellm_completion
        elif self.model_type == "text":
            completion = alitellm_text_completion
        elif self.model_type == "responses":
            completion = alitellm_responses_completion
        completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache)

        results = await completion(
            request=dict(model=self.model, messages=messages, **kwargs),
            num_retries=self.num_retries,
            cache=litellm_cache_args,
        )

        self._check_truncation(results)

        if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
            settings.usage_tracker.add_usage(self.model, dict(results.usage))
        return results

    def launch(self, launch_kwargs: dict[str, Any] | None = None):
        self.provider.launch(self, launch_kwargs)

    def kill(self, launch_kwargs: dict[str, Any] | None = None):
        self.provider.kill(self, launch_kwargs)

    def finetune(
        self,
        train_data: list[dict[str, Any]],
        train_data_format: TrainDataFormat | None,
        train_kwargs: dict[str, Any] | None = None,
    ) -> TrainingJob:
        from dspy import settings as settings

        if not self.provider.finetunable:
            raise ValueError(
                f"Provider {self.provider} does not support fine-tuning, please specify your provider by explicitly "
                "setting `provider` when creating the `dspy.LM` instance. For example, "
                "`dspy.LM('openai/gpt-4.1-mini-2025-04-14', provider=dspy.OpenAIProvider())`."
            )

        def thread_function_wrapper():
            return self._run_finetune_job(job)

        thread = threading.Thread(target=thread_function_wrapper)
        train_kwargs = train_kwargs or self.train_kwargs
        model_to_finetune = self.finetuning_model or self.model
        job = self.provider.TrainingJob(
            thread=thread,
            model=model_to_finetune,
            train_data=train_data,
            train_data_format=train_data_format,
            train_kwargs=train_kwargs,
        )
        thread.start()

        return job

    def reinforce(self, train_kwargs) -> ReinforceJob:
        # TODO(GRPO Team): Should we return an initialized job here?
        from dspy import settings as settings

        err = f"Provider {self.provider} does not implement the reinforcement learning interface."
        assert self.provider.reinforceable, err

        job = self.provider.ReinforceJob(lm=self, train_kwargs=train_kwargs)
        job.initialize()
        return job

    def _run_finetune_job(self, job: TrainingJob):
        # TODO(enhance): We should listen for keyboard interrupts somewhere.
        # Requires TrainingJob.cancel() to be implemented for each provider.
        try:
            model = self.provider.finetune(
                job=job,
                model=job.model,
                train_data=job.train_data,
                train_data_format=job.train_data_format,
                train_kwargs=job.train_kwargs,
            )
            lm = self.copy(model=model)
            job.set_result(lm)
        except Exception as err:
            logger.error(err)
            job.set_result(err)

    def infer_provider(self) -> Provider:
        if OpenAIProvider.is_provider_model(self.model):
            return OpenAIProvider()
        return Provider()

    def dump_state(self):
        state_keys = [
            "model",
            "model_type",
            "cache",
            "num_retries",
            "finetuning_model",
            "launch_kwargs",
            "train_kwargs",
        ]
        return {key: getattr(self, key) for key in state_keys} | self.kwargs

    def _check_truncation(self, results):
        if self.model_type != "responses" and any(c.finish_reason == "length" for c in results["choices"]):
            max_tokens = self.kwargs.get("max_tokens") or self.kwargs.get("max_completion_tokens")
            logger.warning(
                f"LM response was truncated due to exceeding max_tokens={max_tokens}. "
                "You can inspect the latest LM interactions with `dspy.inspect_history()`. "
                "To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
                f"You may also consider increasing the temperature (currently {self.kwargs.get('temperature')}) "
                "if the reason for truncation is repetition."
            )


def _get_stream_completion_fn(
    request: dict[str, Any],
    cache_kwargs: dict[str, Any],
    sync=True,
):
    stream = dspy.settings.send_stream
    caller_predict = dspy.settings.caller_predict

    if stream is None:
        return None

    # The stream is already opened, and will be closed by the caller.
    stream = cast(MemoryObjectSendStream, stream)
    caller_predict_id = id(caller_predict) if caller_predict else None

    if dspy.settings.track_usage:
        request["stream_options"] = {"include_usage": True}

    async def stream_completion(request: dict[str, Any], cache_kwargs: dict[str, Any]):
        headers = request.pop("headers", None)
        response = await litellm.acompletion(
            cache=cache_kwargs,
            stream=True,
            headers=_get_headers(headers),
            **request,
        )
        chunks = []
        async for chunk in response:
            if caller_predict_id:
                # Add the predict id to the chunk so that the stream listener can identify which predict produces it.
                chunk.predict_id = caller_predict_id
            chunks.append(chunk)
            await stream.send(chunk)
        return litellm.stream_chunk_builder(chunks)

    def sync_stream_completion():
        syncified_stream_completion = syncify(stream_completion)
        return syncified_stream_completion(request, cache_kwargs)

    async def async_stream_completion():
        return await stream_completion(request, cache_kwargs)

    if sync:
        return sync_stream_completion
    else:
        return async_stream_completion


def litellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
    cache = cache or {"no-cache": True, "no-store": True}
    request = dict(request)
    request.pop("rollout_id", None)
    headers = request.pop("headers", None)
    stream_completion = _get_stream_completion_fn(request, cache, sync=True)
    if stream_completion is None:
        return litellm.completion(
            cache=cache,
            num_retries=num_retries,
            retry_strategy="exponential_backoff_retry",
            headers=_get_headers(headers),
            **request,
        )

    return stream_completion()


def litellm_text_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
    cache = cache or {"no-cache": True, "no-store": True}
    request = dict(request)
    request.pop("rollout_id", None)
    headers = request.pop("headers", None)
    # Extract the provider and model from the model string.
    # TODO: Not all the models are in the format of "provider/model"
    model = request.pop("model").split("/", 1)
    provider, model = model[0] if len(model) > 1 else "openai", model[-1]

    # Use the API key and base from the request, or from the environment.
    api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
    api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")

    # Build the prompt from the messages.
    prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])

    return litellm.text_completion(
        cache=cache,
        model=f"text-completion-openai/{model}",
        api_key=api_key,
        api_base=api_base,
        prompt=prompt,
        num_retries=num_retries,
        retry_strategy="exponential_backoff_retry",
        headers=_get_headers(headers),
        **request,
    )


async def alitellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
    cache = cache or {"no-cache": True, "no-store": True}
    request = dict(request)
    request.pop("rollout_id", None)
    headers = request.pop("headers", None)
    stream_completion = _get_stream_completion_fn(request, cache, sync=False)
    if stream_completion is None:
        return await litellm.acompletion(
            cache=cache,
            num_retries=num_retries,
            retry_strategy="exponential_backoff_retry",
            headers=_get_headers(headers),
            **request,
        )

    return await stream_completion()


async def alitellm_text_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
    cache = cache or {"no-cache": True, "no-store": True}
    request = dict(request)
    request.pop("rollout_id", None)
    model = request.pop("model").split("/", 1)
    headers = request.pop("headers", None)
    provider, model = model[0] if len(model) > 1 else "openai", model[-1]

    # Use the API key and base from the request, or from the environment.
    api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
    api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")

    # Build the prompt from the messages.
    prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])

    return await litellm.atext_completion(
        cache=cache,
        model=f"text-completion-openai/{model}",
        api_key=api_key,
        api_base=api_base,
        prompt=prompt,
        num_retries=num_retries,
        retry_strategy="exponential_backoff_retry",
        headers=_get_headers(headers),
        **request,
    )


def litellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
    cache = cache or {"no-cache": True, "no-store": True}
    request = dict(request)
    request.pop("rollout_id", None)
    headers = request.pop("headers", None)
    request = _convert_chat_request_to_responses_request(request)

    return litellm.responses(
        cache=cache,
        num_retries=num_retries,
        retry_strategy="exponential_backoff_retry",
        headers=_get_headers(headers),
        **request,
    )


async def alitellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
    cache = cache or {"no-cache": True, "no-store": True}
    request = dict(request)
    request.pop("rollout_id", None)
    headers = request.pop("headers", None)
    request = _convert_chat_request_to_responses_request(request)

    return await litellm.aresponses(
        cache=cache,
        num_retries=num_retries,
        retry_strategy="exponential_backoff_retry",
        headers=_get_headers(headers),
        **request,
    )


def _convert_chat_request_to_responses_request(request: dict[str, Any]):
    request = dict(request)
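    # Collapse chat-style messages into a single Responses API "input" turn, preserving structured content blocks.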
    if "messages" in request:
        content_blocks = []
        for msg in request.pop("messages"):
            c = msg.get("content")
            if isinstance(c, str):
                content_blocks.append({"type": "input_text", "text": c})
            elif isinstance(c, list):
                content_blocks.extend(c)
        request["input"] = [{"role": msg.get("role", "user"), "content": content_blocks}]

    # Convert `response_format` to `text.format` for Responses API
    if "response_format" in request:
        response_format = request.pop("response_format")
        text = request.pop("text", {})
        request["text"] = {**text, "format": response_format}

    return request


def _get_headers(headers: dict[str, Any] | None = None):
    headers = headers or {}
    return {
        "User-Agent": f"DSPy/{dspy.__version__}",
        **headers,
    }

```
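
As a quick orientation for the client above, here is a minimal usage sketch (the model name and prompts are illustrative and assume an `OPENAI_API_KEY` is configured); it exercises the `cache` and `rollout_id` semantics documented in the constructor:

```python
import dspy

# Chat-completion client; responses are cached by default.
lm = dspy.LM("openai/gpt-4o-mini", temperature=0.7, max_tokens=512)

print(lm("Say hello in one word."))  # fresh request
print(lm("Say hello in one word."))  # served from DSPy's cache

# A distinct rollout_id bypasses the cache for an otherwise identical request
# (only meaningful when temperature > 0); the new result is cached under that ID.
print(lm("Say hello in one word.", rollout_id=1))

# Per-call cache override.
print(lm("Say hello in one word.", cache=False))
```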

--------------------------------------------------------------------------------
/tests/adapters/test_baml_adapter.py:
--------------------------------------------------------------------------------

```python
from typing import Literal
from unittest import mock

import pydantic
import pytest
from litellm import Choices, Message
from litellm.files.main import ModelResponse

import dspy
from dspy.adapters.baml_adapter import COMMENT_SYMBOL, BAMLAdapter


# Test fixtures - Pydantic models for testing
class PatientAddress(pydantic.BaseModel):
    street: str
    city: str
    country: Literal["US", "CA"]


class PatientDetails(pydantic.BaseModel):
    name: str = pydantic.Field(description="Full name of the patient")
    age: int
    address: PatientAddress | None = None


class ComplexNestedModel(pydantic.BaseModel):
    id: int = pydantic.Field(description="Unique identifier")
    details: PatientDetails
    tags: list[str] = pydantic.Field(default_factory=list)
    metadata: dict[str, str] = pydantic.Field(default_factory=dict)


class ModelWithLists(pydantic.BaseModel):
    items: list[PatientAddress] = pydantic.Field(description="List of patient addresses")
    scores: list[float]


class ImageWrapper(pydantic.BaseModel):
    images: list[dspy.Image]
    tag: list[str]


class CircularModel(pydantic.BaseModel):
    name: str
    field: "CircularModel"


def test_baml_adapter_basic_schema_generation():
    """Test that BAMLAdapter generates simplified schemas for Pydantic models."""

    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        patient: PatientDetails = dspy.OutputField()

    adapter = BAMLAdapter()
    schema = adapter.format_field_structure(TestSignature)

    # Should contain simplified schema with comments
    assert f"{COMMENT_SYMBOL} Full name of the patient" in schema
    assert "name: string," in schema
    assert "age: int," in schema
    assert "address:" in schema
    assert "street: string," in schema
    assert 'country: "US" or "CA",' in schema


def test_baml_adapter_handles_optional_fields():
    """Test optional field rendering with 'or null' syntax."""

    class TestSignature(dspy.Signature):
        input: str = dspy.InputField()
        patient: PatientDetails = dspy.OutputField()

    adapter = BAMLAdapter()
    schema = adapter.format_field_structure(TestSignature)

    # Optional address field should show 'or null'
    assert "address:" in schema
    assert "or null" in schema


def test_baml_adapter_handles_primitive_types():
    """Test rendering of basic primitive types."""

    class SimpleModel(pydantic.BaseModel):
        text: str
        number: int
        decimal: float
        flag: bool

    class TestSignature(dspy.Signature):
        input: str = dspy.InputField()
        output: SimpleModel = dspy.OutputField()

    adapter = BAMLAdapter()
    schema = adapter.format_field_structure(TestSignature)

    assert "text: string," in schema
    assert "number: int," in schema
    assert "decimal: float," in schema
    assert "flag: boolean," in schema


def test_baml_adapter_handles_lists_with_bracket_notation():
    """Test that lists of Pydantic models use proper bracket notation."""

    class TestSignature(dspy.Signature):
        input: str = dspy.InputField()
        addresses: ModelWithLists = dspy.OutputField()

    adapter = BAMLAdapter()
    schema = adapter.format_field_structure(TestSignature)

    # Should use bracket notation for lists and include comments
    assert "items: [" in schema
    assert f"{COMMENT_SYMBOL} List of patient addresses" in schema
    assert "street: string," in schema
    assert "city: string," in schema
    assert "]," in schema
    assert "scores: float[]," in schema


def test_baml_adapter_handles_complex_nested_models():
    """Test deeply nested Pydantic model schema generation."""

    class TestSignature(dspy.Signature):
        input: str = dspy.InputField()
        complex: ComplexNestedModel = dspy.OutputField()

    adapter = BAMLAdapter()
    schema = adapter.format_field_structure(TestSignature)

    # Should include nested structure with comments
    assert f"{COMMENT_SYMBOL} Unique identifier" in schema
    assert "details:" in schema
    assert f"{COMMENT_SYMBOL} Full name of the patient" in schema
    assert "tags: string[]," in schema
    assert "metadata: dict[string, string]," in schema


def test_baml_adapter_raise_error_on_circular_references():
    """Test that circular references are handled gracefully."""

    class TestSignature(dspy.Signature):
        input: str = dspy.InputField()
        circular: CircularModel = dspy.OutputField()

    adapter = BAMLAdapter()
    with pytest.raises(ValueError) as error:
        adapter.format_field_structure(TestSignature)

    assert "BAMLAdapter cannot handle recursive pydantic models" in str(error.value)


def test_baml_adapter_formats_pydantic_inputs_as_clean_json():
    """Test that Pydantic input instances are formatted as clean JSON."""

    class TestSignature(dspy.Signature):
        patient: PatientDetails = dspy.InputField()
        question: str = dspy.InputField()
        answer: str = dspy.OutputField()

    adapter = BAMLAdapter()
    patient = PatientDetails(
        name="John Doe", age=45, address=PatientAddress(street="123 Main St", city="Anytown", country="US")
    )

    messages = adapter.format(TestSignature, [], {"patient": patient, "question": "What is the diagnosis?"})

    # Should have clean, indented JSON for Pydantic input
    user_message = messages[-1]["content"]
    assert '"name": "John Doe"' in user_message
    assert '"age": 45' in user_message
    assert '"street": "123 Main St"' in user_message
    assert '"country": "US"' in user_message


def test_baml_adapter_handles_mixed_input_types():
    """Test formatting of mixed Pydantic and primitive inputs."""

    class TestSignature(dspy.Signature):
        patient: PatientDetails = dspy.InputField()
        priority: int = dspy.InputField()
        notes: str = dspy.InputField()
        result: str = dspy.OutputField()

    adapter = BAMLAdapter()
    patient = PatientDetails(name="Jane Doe", age=30)

    messages = adapter.format(TestSignature, [], {"patient": patient, "priority": 1, "notes": "Urgent case"})

    user_message = messages[-1]["content"]
    # Pydantic should be JSON formatted
    assert '"name": "Jane Doe"' in user_message
    # Primitives should be formatted normally
    assert "priority ## ]]\n1" in user_message
    assert "notes ## ]]\nUrgent case" in user_message


def test_baml_adapter_handles_schema_generation_errors_gracefully():
    """Test graceful handling of schema generation errors."""

    class ProblematicModel(pydantic.BaseModel):
        # This might cause issues in schema generation
        field: object

    class TestSignature(dspy.Signature):
        input: str = dspy.InputField()
        output: ProblematicModel = dspy.OutputField()

    adapter = BAMLAdapter()

    # Should not raise an exception
    try:
        schema = adapter.format_field_structure(TestSignature)
        # If no exception, schema should at least contain some basic structure
        assert "schema" in schema.lower()
    except Exception:
        # If exception occurs, test passes as we're testing graceful handling
        pass


def test_baml_adapter_raises_on_missing_fields():
    """Test that missing required fields raise appropriate errors."""

    class TestSignature(dspy.Signature):
        input: str = dspy.InputField()
        patient: PatientDetails = dspy.OutputField()
        summary: str = dspy.OutputField()

    adapter = BAMLAdapter()

    # Missing 'summary' field
    completion = '{"patient": {"name": "John", "age": 30}}'

    with pytest.raises(dspy.utils.exceptions.AdapterParseError) as e:
        adapter.parse(TestSignature, completion)

    assert e.value.adapter_name == "JSONAdapter"  # BAMLAdapter inherits from JSONAdapter
    assert "summary" in str(e.value)


def test_baml_adapter_handles_type_casting_errors():
    """Test graceful handling of type casting errors."""

    class TestSignature(dspy.Signature):
        input: str = dspy.InputField()
        patient: PatientDetails = dspy.OutputField()

    adapter = BAMLAdapter()

    # Invalid age type
    completion = '{"patient": {"name": "John", "age": "not_a_number"}}'

    # Should raise ValidationError from Pydantic (which is the expected behavior)
    with pytest.raises((dspy.utils.exceptions.AdapterParseError, pydantic.ValidationError)):
        adapter.parse(TestSignature, completion)


def test_baml_adapter_with_images():
    """Test BAMLAdapter integration with dspy.Image objects."""

    class TestSignature(dspy.Signature):
        image_data: ImageWrapper = dspy.InputField()
        description: str = dspy.OutputField()

    adapter = BAMLAdapter()

    image_wrapper = ImageWrapper(
        images=[dspy.Image(url="https://example.com/image1.jpg"), dspy.Image(url="https://example.com/image2.jpg")],
        tag=["test", "medical"],
    )

    messages = adapter.format(TestSignature, [], {"image_data": image_wrapper})

    # Should contain image URLs in the message content
    user_message = messages[-1]["content"]
    image_contents = [
        content for content in user_message if isinstance(content, dict) and content.get("type") == "image_url"
    ]

    assert len(image_contents) == 2
    assert {"type": "image_url", "image_url": {"url": "https://example.com/image1.jpg"}} in user_message
    assert {"type": "image_url", "image_url": {"url": "https://example.com/image2.jpg"}} in user_message


def test_baml_adapter_with_tools():
    """Test BAMLAdapter integration with dspy.Tool objects."""

    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        tools: list[dspy.Tool] = dspy.InputField()
        answer: str = dspy.OutputField()

    def get_patient_info(patient_id: int) -> str:
        """Get patient information by ID"""
        return f"Patient info for ID {patient_id}"

    def schedule_appointment(patient_name: str, date: str) -> str:
        """Schedule an appointment for a patient"""
        return f"Scheduled appointment for {patient_name} on {date}"

    tools = [dspy.Tool(get_patient_info), dspy.Tool(schedule_appointment)]

    adapter = BAMLAdapter()
    messages = adapter.format(TestSignature, [], {"question": "Schedule an appointment for John", "tools": tools})

    user_message = messages[-1]["content"]
    assert "get_patient_info" in user_message
    assert "schedule_appointment" in user_message
    assert "Get patient information by ID" in user_message
    assert "Schedule an appointment for a patient" in user_message


def test_baml_adapter_with_code():
    """Test BAMLAdapter integration with dspy.Code objects."""

    # Test with code as input field
    class CodeAnalysisSignature(dspy.Signature):
        code: dspy.Code = dspy.InputField()
        analysis: str = dspy.OutputField()

    adapter = BAMLAdapter()
    messages = adapter.format(CodeAnalysisSignature, [], {"code": "def hello():\n    print('Hello, world!')"})

    user_message = messages[-1]["content"]
    assert "def hello():" in user_message
    assert "print('Hello, world!')" in user_message

    # Test with code as output field
    class CodeGenSignature(dspy.Signature):
        task: str = dspy.InputField()
        code: dspy.Code = dspy.OutputField()

    with mock.patch("litellm.completion") as mock_completion:
        mock_completion.return_value = ModelResponse(
            choices=[Choices(message=Message(content='{"code": "print(\\"Generated code\\")"}'))],
            model="openai/gpt-4o-mini",
        )

        result = adapter(
            dspy.LM(model="openai/gpt-4o-mini", cache=False),
            {},
            CodeGenSignature,
            [],
            {"task": "Write a hello world program"},
        )

        assert result[0]["code"].code == 'print("Generated code")'


def test_baml_adapter_with_conversation_history():
    """Test BAMLAdapter integration with dspy.History objects."""

    class TestSignature(dspy.Signature):
        history: dspy.History = dspy.InputField()
        question: str = dspy.InputField()
        answer: str = dspy.OutputField()

    history = dspy.History(
        messages=[
            {"question": "What is the patient's age?", "answer": "45 years old"},
            {"question": "Any allergies?", "answer": "Penicillin allergy"},
        ]
    )

    adapter = BAMLAdapter()
    messages = adapter.format(TestSignature, [], {"history": history, "question": "What medications should we avoid?"})

    # Should format history as separate messages
    assert len(messages) == 6  # system + 2 history pairs + user
    assert "What is the patient's age?" in messages[1]["content"]
    assert '"answer": "45 years old"' in messages[2]["content"]
    assert "Any allergies?" in messages[3]["content"]
    assert '"answer": "Penicillin allergy"' in messages[4]["content"]


# Comparison tests with JSONAdapter
def test_baml_vs_json_adapter_token_efficiency():
    """Test that BAMLAdapter generates more token-efficient schemas."""

    class TestSignature(dspy.Signature):
        input: str = dspy.InputField()
        complex: ComplexNestedModel = dspy.OutputField()

    baml_adapter = BAMLAdapter()
    json_adapter = dspy.JSONAdapter()

    baml_schema = baml_adapter.format_field_structure(TestSignature)
    json_schema = json_adapter.format_field_structure(TestSignature)

    # Simple character count as proxy for token efficiency
    # BAMLAdapter should always produce shorter schemas
    assert len(baml_schema) < len(json_schema)


def test_baml_vs_json_adapter_functional_compatibility():
    """Test that both adapters parse identical outputs to the same results."""

    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        patient: PatientDetails = dspy.OutputField()

    baml_adapter = BAMLAdapter()
    json_adapter = dspy.JSONAdapter()

    completion = """{"patient": {
        "name": "Alice Brown",
        "age": 35,
        "address": {"street": "789 Pine St", "city": "Boston", "country": "US"}
    }}"""

    baml_result = baml_adapter.parse(TestSignature, completion)
    json_result = json_adapter.parse(TestSignature, completion)

    # Results should be functionally equivalent
    assert baml_result["patient"].name == json_result["patient"].name
    assert baml_result["patient"].age == json_result["patient"].age
    assert baml_result["patient"].address.street == json_result["patient"].address.street


@pytest.mark.asyncio
async def test_baml_adapter_async_functionality():
    """Test BAMLAdapter async operations."""

    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        patient: PatientDetails = dspy.OutputField()

    with mock.patch("litellm.acompletion") as mock_acompletion:
        mock_acompletion.return_value = ModelResponse(
            choices=[Choices(message=Message(content='{"patient": {"name": "John Doe", "age": 28}}'))],
            model="openai/gpt-4o",
        )

        adapter = BAMLAdapter()
        result = await adapter.acall(
            dspy.LM(model="openai/gpt-4o", cache=False), {}, TestSignature, [], {"question": "Extract patient info"}
        )

        assert result[0]["patient"].name == "John Doe"
        assert result[0]["patient"].age == 28


def test_baml_adapter_with_field_aliases():
    """Test BAMLAdapter with Pydantic field aliases."""

    class ModelWithAliases(pydantic.BaseModel):
        full_name: str = pydantic.Field(alias="name")
        patient_age: int = pydantic.Field(alias="age")

    class TestSignature(dspy.Signature):
        input: str = dspy.InputField()
        data: ModelWithAliases = dspy.OutputField()

    adapter = BAMLAdapter()

    # Schema should show aliases in the output structure
    schema = adapter.format_field_structure(TestSignature)
    assert "name:" in schema  # Should use alias, not field name
    assert "age:" in schema  # Should use alias, not field name


def test_baml_adapter_field_alias_without_description():
    """Test BAMLAdapter with field alias present but description absent."""

    class ModelWithAliasNoDescription(pydantic.BaseModel):
        internal_field: str = pydantic.Field(alias="public_name")
        regular_field: int
        field_with_description: str = pydantic.Field(description="This field has a description", alias="desc_field")

    class TestSignature(dspy.Signature):
        input: str = dspy.InputField()
        data: ModelWithAliasNoDescription = dspy.OutputField()

    adapter = BAMLAdapter()
    schema = adapter.format_field_structure(TestSignature)

    # Should show alias as comment when description is absent
    assert f"{COMMENT_SYMBOL} alias: public_name" in schema
    # Should show description comment when present
    assert f"{COMMENT_SYMBOL} This field has a description" in schema
    # Regular field (without alias) should appear in schema but without alias comment
    assert "regular_field: int," in schema
    # Check that regular_field section doesn't have an alias comment
    regular_field_section = schema.split("regular_field: int,")[0].split("\n")[-1]
    assert f"{COMMENT_SYMBOL} alias:" not in regular_field_section


def test_baml_adapter_multiple_pydantic_input_fields():
    """Test that multiple InputField() with Pydantic models are rendered correctly."""

    class UserProfile(pydantic.BaseModel):
        name: str = pydantic.Field(description="User's full name")
        email: str
        age: int

    class SystemConfig(pydantic.BaseModel):
        timeout: int = pydantic.Field(description="Timeout in seconds")
        debug: bool
        endpoints: list[str]

    class TestSignature(dspy.Signature):
        input_1: UserProfile = dspy.InputField()
        input_2: SystemConfig = dspy.InputField()
        result: str = dspy.OutputField()

    adapter = BAMLAdapter()

    # Test schema generation includes headers for ALL input fields
    schema = adapter.format_field_structure(TestSignature)
    assert "[[ ## input_1 ## ]]" in schema  # Should include first input field header
    assert "[[ ## input_2 ## ]]" in schema  # Should include second input field header
    assert "[[ ## result ## ]]" in schema  # Should include output field header
    assert "[[ ## completed ## ]]" in schema  # Should include completed section
    assert "All interactions will be structured in the following way" in schema
    assert "{input_1}" in schema
    assert "{input_2}" in schema
    assert "Output field `result` should be of type: string" in schema

    # Test field descriptions are in the correct method
    field_desc = adapter.format_field_description(TestSignature)
    assert "Your input fields are:" in field_desc
    assert "Your output fields are:" in field_desc

    # Test message formatting with actual Pydantic instances
    user_profile = UserProfile(name="John Doe", email="[email protected]", age=30)
    system_config = SystemConfig(timeout=300, debug=True, endpoints=["api1", "api2"])

    messages = adapter.format(TestSignature, [], {"input_1": user_profile, "input_2": system_config})

    user_message = messages[-1]["content"]

    # Verify both inputs are rendered with the correct bracket notation
    assert "[[ ## input_1 ## ]]" in user_message
    assert "[[ ## input_2 ## ]]" in user_message

    # Verify JSON content for both inputs
    assert '"name": "John Doe"' in user_message
    assert '"email": "[email protected]"' in user_message
    assert '"age": 30' in user_message
    assert '"timeout": 300' in user_message
    assert '"debug": true' in user_message
    # Endpoints array is formatted with indentation, so check for individual elements
    assert '"api1"' in user_message
    assert '"api2"' in user_message
    assert '"endpoints":' in user_message

```
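
A minimal end-to-end sketch of the adapter exercised by these tests (the signature, model name, and expected parse are illustrative assumptions, not part of the test suite):

```python
import pydantic

import dspy
from dspy.adapters.baml_adapter import BAMLAdapter


class PatientDetails(pydantic.BaseModel):
    name: str
    age: int


class ExtractPatient(dspy.Signature):
    """Extract structured patient details from a clinical note."""

    note: str = dspy.InputField()
    patient: PatientDetails = dspy.OutputField()


# Route predictions through the BAML-style adapter instead of the default ChatAdapter.
dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"), adapter=BAMLAdapter())

prediction = dspy.Predict(ExtractPatient)(note="John Doe, 45 years old.")
print(prediction.patient)  # expected to be a PatientDetails instance parsed from the LM's JSON output
```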

--------------------------------------------------------------------------------
/tests/clients/test_lm.py:
--------------------------------------------------------------------------------

```python
import json
import time
import warnings
from unittest import mock
from unittest.mock import patch

import litellm
import pydantic
import pytest
from litellm.types.llms.openai import ResponseAPIUsage, ResponsesAPIResponse
from litellm.utils import Choices, Message, ModelResponse
from openai import RateLimitError
from openai.types.responses import ResponseOutputMessage, ResponseReasoningItem
from openai.types.responses.response_reasoning_item import Summary

import dspy
from dspy.utils.dummies import DummyLM
from dspy.utils.usage_tracker import track_usage


def make_response(output_blocks):
    return ResponsesAPIResponse(
        id="resp_1",
        created_at=0.0,
        error=None,
        incomplete_details=None,
        instructions=None,
        model="openai/dspy-test-model",
        object="response",
        output=output_blocks,
        metadata={},
        parallel_tool_calls=False,
        temperature=1.0,
        tool_choice="auto",
        tools=[],
        top_p=1.0,
        max_output_tokens=None,
        previous_response_id=None,
        reasoning=None,
        status="completed",
        text=None,
        truncation="disabled",
        usage=ResponseAPIUsage(input_tokens=1, output_tokens=1, total_tokens=2),
        user=None,
    )


def test_chat_lms_can_be_queried(litellm_test_server):
    api_base, _ = litellm_test_server
    expected_response = ["Hi!"]

    openai_lm = dspy.LM(
        model="openai/dspy-test-model",
        api_base=api_base,
        api_key="fakekey",
        model_type="chat",
    )
    assert openai_lm("openai query") == expected_response

    azure_openai_lm = dspy.LM(
        model="azure/dspy-test-model",
        api_base=api_base,
        api_key="fakekey",
        model_type="chat",
    )
    assert azure_openai_lm("azure openai query") == expected_response


def test_dspy_cache(litellm_test_server, tmp_path):
    api_base, _ = litellm_test_server

    original_cache = dspy.cache
    dspy.clients.configure_cache(
        enable_disk_cache=True,
        enable_memory_cache=True,
        disk_cache_dir=tmp_path / ".disk_cache",
    )
    cache = dspy.cache

    lm = dspy.LM(
        model="openai/dspy-test-model",
        api_base=api_base,
        api_key="fakekey",
        model_type="text",
    )
    with track_usage() as usage_tracker:
        lm("Query")

    assert len(cache.memory_cache) == 1
    cache_key = next(iter(cache.memory_cache.keys()))
    assert cache_key in cache.disk_cache
    assert len(usage_tracker.usage_data) == 1

    with track_usage() as usage_tracker:
        lm("Query")

    assert len(usage_tracker.usage_data) == 0

    dspy.cache = original_cache


def test_disabled_cache_skips_cache_key(monkeypatch):
    original_cache = dspy.cache
    dspy.configure_cache(enable_disk_cache=False, enable_memory_cache=False)
    cache = dspy.cache

    try:
        with mock.patch.object(cache, "cache_key", wraps=cache.cache_key) as cache_key_spy, \
             mock.patch.object(cache, "get", wraps=cache.get) as cache_get_spy, \
             mock.patch.object(cache, "put", wraps=cache.put) as cache_put_spy:

            def fake_completion(*, cache, num_retries, retry_strategy, **request):
                return ModelResponse(
                    choices=[Choices(message=Message(role="assistant", content="Hi!"))],
                    usage={"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
                    model="dummy",
                )

            monkeypatch.setattr(litellm, "completion", fake_completion)

            dummy_lm = DummyLM([{"answer": "ignored"}])
            # TODO(isaacbmiller): Change from dummy_lm.forward to just dummy_lm.__call__ #8864
            dummy_lm.forward(messages=[{"role": "user", "content": "Hello"}])

            cache_key_spy.assert_not_called()
            cache_get_spy.assert_called_once()
            cache_put_spy.assert_called_once()
    finally:
        dspy.cache = original_cache


def test_rollout_id_bypasses_cache(monkeypatch, tmp_path):
    calls: list[dict] = []

    def fake_completion(*, cache, num_retries, retry_strategy, **request):
        calls.append(request)
        return ModelResponse(
            choices=[Choices(message=Message(role="assistant", content="Hi!"))],
            usage={"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
            model="openai/dspy-test-model",
        )

    monkeypatch.setattr(litellm, "completion", fake_completion)

    original_cache = dspy.cache
    dspy.clients.configure_cache(
        enable_disk_cache=True,
        enable_memory_cache=True,
        disk_cache_dir=tmp_path / ".disk_cache",
    )

    lm = dspy.LM(model="openai/dspy-test-model", model_type="chat")

    with track_usage() as usage_tracker:
        lm(messages=[{"role": "user", "content": "Query"}], rollout_id=1)
    assert len(usage_tracker.usage_data) == 1

    with track_usage() as usage_tracker:
        lm(messages=[{"role": "user", "content": "Query"}], rollout_id=1)
    assert len(usage_tracker.usage_data) == 0

    with track_usage() as usage_tracker:
        lm(messages=[{"role": "user", "content": "Query"}], rollout_id=2)
    assert len(usage_tracker.usage_data) == 1

    with track_usage() as usage_tracker:
        lm(messages=[{"role": "user", "content": "NoRID"}])
    assert len(usage_tracker.usage_data) == 1

    with track_usage() as usage_tracker:
        lm(messages=[{"role": "user", "content": "NoRID"}], rollout_id=None)
    assert len(usage_tracker.usage_data) == 0

    assert len(dspy.cache.memory_cache) == 3
    assert all("rollout_id" not in r for r in calls)
    dspy.cache = original_cache


def test_zero_temperature_rollout_warns_once(monkeypatch):
    def fake_completion(*, cache, num_retries, retry_strategy, **request):
        return ModelResponse(
            choices=[Choices(message=Message(role="assistant", content="Hi!"))],
            usage={"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
            model="openai/dspy-test-model",
        )

    monkeypatch.setattr(litellm, "completion", fake_completion)

    lm = dspy.LM(model="openai/dspy-test-model", model_type="chat")
    with pytest.warns(UserWarning, match="rollout_id has no effect"):
        lm("Query", rollout_id=1)
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        lm("Query", rollout_id=2)
        assert len(record) == 0


def test_text_lms_can_be_queried(litellm_test_server):
    api_base, _ = litellm_test_server
    expected_response = ["Hi!"]

    openai_lm = dspy.LM(
        model="openai/dspy-test-model",
        api_base=api_base,
        api_key="fakekey",
        model_type="text",
    )
    assert openai_lm("openai query") == expected_response

    azure_openai_lm = dspy.LM(
        model="azure/dspy-test-model",
        api_base=api_base,
        api_key="fakekey",
        model_type="text",
    )
    assert azure_openai_lm("azure openai query") == expected_response


def test_lm_calls_support_callables(litellm_test_server):
    api_base, _ = litellm_test_server

    with mock.patch("litellm.completion", autospec=True, wraps=litellm.completion) as spy_completion:

        def azure_ad_token_provider(*args, **kwargs):
            return None

        lm_with_callable = dspy.LM(
            model="openai/dspy-test-model",
            api_base=api_base,
            api_key="fakekey",
            azure_ad_token_provider=azure_ad_token_provider,
            cache=False,
        )

        lm_with_callable("Query")

        spy_completion.assert_called_once()
        call_args = spy_completion.call_args.kwargs
        assert call_args["model"] == "openai/dspy-test-model"
        assert call_args["api_base"] == api_base
        assert call_args["api_key"] == "fakekey"
        assert call_args["azure_ad_token_provider"] is azure_ad_token_provider


def test_lm_calls_support_pydantic_models(litellm_test_server):
    api_base, _ = litellm_test_server

    class ResponseFormat(pydantic.BaseModel):
        response: str

    lm = dspy.LM(
        model="openai/dspy-test-model",
        api_base=api_base,
        api_key="fakekey",
        response_format=ResponseFormat,
    )
    lm("Query")


def test_retry_number_set_correctly():
    lm = dspy.LM("openai/gpt-4o-mini", num_retries=3)
    with mock.patch("litellm.completion") as mock_completion:
        lm("query")

    assert mock_completion.call_args.kwargs["num_retries"] == 3


def test_retry_made_on_system_errors():
    retry_tracking = [0]  # Using a list to track retries

    def mock_create(*args, **kwargs):
        retry_tracking[0] += 1
        # These fields are called during the error handling
        mock_response = mock.Mock()
        mock_response.headers = {}
        mock_response.status_code = 429
        raise RateLimitError(response=mock_response, message="message", body="error")

    lm = dspy.LM(model="openai/gpt-4o-mini", max_tokens=250, num_retries=3)
    with mock.patch.object(litellm.OpenAIChatCompletion, "completion", side_effect=mock_create):
        with pytest.raises(RateLimitError):
            lm("question")

    assert retry_tracking[0] == 4


def test_reasoning_model_token_parameter():
    test_cases = [
        ("openai/o1", True),
        ("openai/o1-mini", True),
        ("openai/o1-2023-01-01", True),
        ("openai/o3", True),
        ("openai/o3-mini-2023-01-01", True),
        ("openai/gpt-5", True),
        ("openai/gpt-5-mini", True),
        ("openai/gpt-5-nano", True),
        ("openai/gpt-4", False),
        ("anthropic/claude-2", False),
    ]

    for model_name, is_reasoning_model in test_cases:
        lm = dspy.LM(
            model=model_name,
            temperature=1.0 if is_reasoning_model else 0.7,
            max_tokens=16_000 if is_reasoning_model else 1000,
        )
        if is_reasoning_model:
            assert "max_completion_tokens" in lm.kwargs
            assert "max_tokens" not in lm.kwargs
            assert lm.kwargs["max_completion_tokens"] == 16_000
        else:
            assert "max_completion_tokens" not in lm.kwargs
            assert "max_tokens" in lm.kwargs
            assert lm.kwargs["max_tokens"] == 1000


@pytest.mark.parametrize("model_name", ["openai/o1", "openai/gpt-5-nano"])
def test_reasoning_model_requirements(model_name):
    # Should raise a ValueError if the temperature or max_tokens requirements are not met
    with pytest.raises(
        ValueError,
        match="reasoning models require passing temperature=1.0 or None and max_tokens >= 16000 or None",
    ):
        dspy.LM(
            model=model_name,
            temperature=0.7,  # Should be 1.0
            max_tokens=1000,  # Should be >= 16_000
        )

    # Should pass with correct parameters
    lm = dspy.LM(
        model=model_name,
        temperature=1.0,
        max_tokens=16_000,
    )
    assert lm.kwargs["max_completion_tokens"] == 16_000

    # Should pass with no parameters
    lm = dspy.LM(
        model=model_name,
    )
    assert lm.kwargs["temperature"] == None
    assert lm.kwargs["max_completion_tokens"] == None


def test_dump_state():
    lm = dspy.LM(
        model="openai/gpt-4o-mini",
        model_type="chat",
        temperature=1,
        max_tokens=100,
        num_retries=10,
        launch_kwargs={"temperature": 1},
        train_kwargs={"temperature": 5},
    )

    assert lm.dump_state() == {
        "model": "openai/gpt-4o-mini",
        "model_type": "chat",
        "temperature": 1,
        "max_tokens": 100,
        "num_retries": 10,
        "cache": True,
        "finetuning_model": None,
        "launch_kwargs": {"temperature": 1},
        "train_kwargs": {"temperature": 5},
    }


def test_exponential_backoff_retry():
    time_counter = []

    def mock_create(*args, **kwargs):
        time_counter.append(time.time())
        # These fields are called during the error handling
        mock_response = mock.Mock()
        mock_response.headers = {}
        mock_response.status_code = 429
        raise RateLimitError(response=mock_response, message="message", body="error")

    lm = dspy.LM(model="openai/gpt-3.5-turbo", max_tokens=250, num_retries=3)
    with mock.patch.object(litellm.OpenAIChatCompletion, "completion", side_effect=mock_create):
        with pytest.raises(RateLimitError):
            lm("question")

    # The first retry happens immediately regardless of the configuration
    for i in range(1, len(time_counter) - 1):
        assert time_counter[i + 1] - time_counter[i] >= 2 ** (i - 1)


def test_logprobs_included_when_requested():
    lm = dspy.LM(model="dspy-test-model", logprobs=True, cache=False)
    with mock.patch("litellm.completion") as mock_completion:
        mock_completion.return_value = ModelResponse(
            choices=[
                Choices(
                    message=Message(content="test answer"),
                    logprobs={
                        "content": [
                            {"token": "test", "logprob": 0.1, "top_logprobs": [{"token": "test", "logprob": 0.1}]},
                            {"token": "answer", "logprob": 0.2, "top_logprobs": [{"token": "answer", "logprob": 0.2}]},
                        ]
                    },
                )
            ],
            model="dspy-test-model",
        )
        result = lm("question")
        assert result[0]["text"] == "test answer"
        assert result[0]["logprobs"].model_dump() == {
            "content": [
                {
                    "token": "test",
                    "bytes": None,
                    "logprob": 0.1,
                    "top_logprobs": [{"token": "test", "bytes": None, "logprob": 0.1}],
                },
                {
                    "token": "answer",
                    "bytes": None,
                    "logprob": 0.2,
                    "top_logprobs": [{"token": "answer", "bytes": None, "logprob": 0.2}],
                },
            ]
        }
        assert mock_completion.call_args.kwargs["logprobs"]


@pytest.mark.asyncio
async def test_async_lm_call():
    from litellm.utils import Choices, Message, ModelResponse

    mock_response = ModelResponse(choices=[Choices(message=Message(content="answer"))], model="openai/gpt-4o-mini")

    with patch("litellm.acompletion") as mock_acompletion:
        mock_acompletion.return_value = mock_response

        lm = dspy.LM(model="openai/gpt-4o-mini", cache=False)
        result = await lm.acall("question")

        assert result == ["answer"]
        mock_acompletion.assert_called_once()


@pytest.mark.asyncio
async def test_async_lm_call_with_cache(tmp_path):
    """Test the async LM call with caching."""
    original_cache = dspy.cache
    dspy.clients.configure_cache(
        enable_disk_cache=True,
        enable_memory_cache=True,
        disk_cache_dir=tmp_path / ".disk_cache",
    )
    cache = dspy.cache

    lm = dspy.LM(model="openai/gpt-4o-mini")

    with mock.patch("dspy.clients.lm.alitellm_completion") as mock_alitellm_completion:
        mock_alitellm_completion.return_value = ModelResponse(
            choices=[Choices(message=Message(content="answer"))], model="openai/gpt-4o-mini"
        )
        mock_alitellm_completion.__qualname__ = "alitellm_completion"
        await lm.acall("Query")

        assert len(cache.memory_cache) == 1
        cache_key = next(iter(cache.memory_cache.keys()))
        assert cache_key in cache.disk_cache
        assert mock_alitellm_completion.call_count == 1

        await lm.acall("Query")
        # Second call should hit the cache, so no new call to LiteLLM is made.
        assert mock_alitellm_completion.call_count == 1

        # A new query should result in a new LiteLLM call and a new cache entry.
        await lm.acall("New query")

        assert len(cache.memory_cache) == 2
        assert mock_alitellm_completion.call_count == 2

    dspy.cache = original_cache


def test_lm_history_size_limit():
    lm = dspy.LM(model="openai/gpt-4o-mini")
    with dspy.context(max_history_size=5):
        with mock.patch("litellm.completion") as mock_completion:
            mock_completion.return_value = ModelResponse(
                choices=[Choices(message=Message(content="test answer"))],
                model="openai/gpt-4o-mini",
            )

            for _ in range(10):
                lm("query")

    assert len(lm.history) == 5


def test_disable_history():
    lm = dspy.LM(model="openai/gpt-4o-mini")
    with dspy.context(disable_history=True):
        with mock.patch("litellm.completion") as mock_completion:
            mock_completion.return_value = ModelResponse(
                choices=[Choices(message=Message(content="test answer"))],
                model="openai/gpt-4o-mini",
            )
            for _ in range(10):
                lm("query")

    assert len(lm.history) == 0

    with dspy.context(disable_history=False):
        with mock.patch("litellm.completion") as mock_completion:
            mock_completion.return_value = ModelResponse(
                choices=[Choices(message=Message(content="test answer"))],
                model="openai/gpt-4o-mini",
            )
            for _ in range(10):
                lm("query")

    assert len(lm.history) == 10


def test_responses_api():
    api_response = make_response(
        output_blocks=[
            ResponseOutputMessage(
                **{
                    "id": "msg_1",
                    "type": "message",
                    "role": "assistant",
                    "status": "completed",
                    "content": [
                        {"type": "output_text", "text": "This is a test answer from responses API.", "annotations": []}
                    ],
                },
            ),
            ResponseReasoningItem(
                **{
                    "id": "reasoning_1",
                    "type": "reasoning",
                    "summary": [Summary(**{"type": "summary_text", "text": "This is a dummy reasoning."})],
                },
            ),
        ]
    )

    with mock.patch("litellm.responses", autospec=True, return_value=api_response) as dspy_responses:
        lm = dspy.LM(
            model="openai/gpt-5-mini",
            model_type="responses",
            cache=False,
            temperature=1.0,
            max_tokens=16000,
        )
        lm_result = lm("openai query")

        assert lm_result == [
            {
                "text": "This is a test answer from responses API.",
                "reasoning_content": "This is a dummy reasoning.",
            }
        ]

        dspy_responses.assert_called_once()
        assert dspy_responses.call_args.kwargs["model"] == "openai/gpt-5-mini"


def test_lm_replaces_system_with_developer_role():
    with mock.patch(
        "dspy.clients.lm.litellm_responses_completion", return_value={"choices": []}
    ) as mock_completion:
        lm = dspy.LM(
            "openai/gpt-4o-mini",
            cache=False,
            model_type="responses",
            use_developer_role=True,
        )
        lm.forward(messages=[{"role": "system", "content": "hi"}])
        assert (
            mock_completion.call_args.kwargs["request"]["messages"][0]["role"]
            == "developer"
        )


def test_responses_api_tool_calls(litellm_test_server):
    api_base, _ = litellm_test_server
    expected_tool_call = {
        "type": "function_call",
        "name": "get_weather",
        "arguments": json.dumps({"city": "Paris"}),
        "call_id": "call_1",
        "status": "completed",
        "id": "call_1",
    }
    expected_response = [{"tool_calls": [expected_tool_call]}]

    api_response = make_response(
        output_blocks=[expected_tool_call],
    )

    with mock.patch("litellm.responses", autospec=True, return_value=api_response) as dspy_responses:
        lm = dspy.LM(
            model="openai/dspy-test-model",
            api_base=api_base,
            api_key="fakekey",
            model_type="responses",
            cache=False,
        )
        assert lm("openai query") == expected_response

        dspy_responses.assert_called_once()
        assert dspy_responses.call_args.kwargs["model"] == "openai/dspy-test-model"

```

--------------------------------------------------------------------------------
/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md:
--------------------------------------------------------------------------------

```markdown
# dspy.GEPA - Advanced Features

## Custom Instruction Proposers

### What is instruction_proposer?

The `instruction_proposer` is the component responsible for invoking the `reflection_lm` and proposing new prompts during GEPA optimization. When GEPA identifies underperforming components in your DSPy program, the instruction proposer analyzes execution traces, feedback, and failures to generate improved instructions tailored to the observed issues.

### Default Implementation

By default, GEPA uses the built-in instruction proposer from the [GEPA library](https://github.com/gepa-ai/gepa), which implements the [`ProposalFn`](https://github.com/gepa-ai/gepa/blob/main/src/gepa/core/adapter.py) protocol. The [default proposer](https://github.com/gepa-ai/gepa/blob/main/src/gepa/proposer/reflective_mutation/reflective_mutation.py#L53-L75) uses the following prompt template:

````
I provided an assistant with the following instructions to perform a task for me:
```
<curr_instructions>
```

The following are examples of different task inputs provided to the assistant along with the assistant's response for each of them, and some feedback on how the assistant's response could be better:
```
<inputs_outputs_feedback>
```

Your task is to write a new instruction for the assistant.

Read the inputs carefully and identify the input format and infer detailed task description about the task I wish to solve with the assistant.

Read all the assistant responses and the corresponding feedback. Identify all niche and domain specific factual information about the task and include it in the instruction, as a lot of it may not be available to the assistant in the future. The assistant may have utilized a generalizable strategy to solve the task, if so, include that in the instruction as well.

Provide the new instructions within ``` blocks.
````

This template is automatically filled with:

- `<curr_instructions>`: The current instruction being optimized
- `<inputs_outputs_feedback>`: Structured markdown containing predictor inputs, generated outputs, and evaluation feedback

Example of default behavior:

```python
# Default instruction proposer is used automatically
gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
    auto="medium"
)
optimized_program = gepa.compile(student, trainset=examples)
```

### When to Use Custom instruction_proposer

**Note:** Custom instruction proposers are an advanced feature. Most users should start with the default proposer, which works well for most text-based optimization tasks.

Consider implementing a custom instruction proposer when you need:

- **Multi-modal handling**: Process images (dspy.Image) alongside textual information in your inputs
- **Nuanced control over length and structure**: Exercise fine-grained control over instruction length, format, and structural requirements
- **Domain-specific information**: Inject specialized knowledge, terminology, or context that the default proposer lacks and that cannot be supplied through the feedback function
- **Provider-specific prompting guides**: Optimize instructions for specific LLM providers (OpenAI, Anthropic, etc.) with their unique formatting preferences
- **Coupled component updates**: Handle situations where two or more components need to be updated together in a coordinated manner rather than optimized independently (see the `component_selector` parameter in the [Custom Component Selection](#custom-component-selection) section for related functionality)
- **External knowledge integration**: Connect to databases, APIs, or knowledge bases during instruction generation

### Available Options

**Built-in Options:**

- **Default Proposer**: The standard GEPA instruction proposer (used when `instruction_proposer=None`). It is the most general option and is the one used for the diverse experiments reported in the GEPA paper and tutorials.
- **MultiModalInstructionProposer**: Handles `dspy.Image` inputs and structured multimodal content.

```python
from dspy.teleprompt.gepa.instruction_proposal import MultiModalInstructionProposer

# For tasks involving images or multimodal inputs
gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
    instruction_proposer=MultiModalInstructionProposer(),
    auto="medium"
)
```

We invite community contributions of new instruction proposers for specialized domains as the [GEPA library](https://github.com/gepa-ai/gepa) continues to grow.

### How to Implement Custom Instruction Proposers

Custom instruction proposers must implement the `ProposalFn` protocol by defining a callable class or function. GEPA will call your proposer during optimization:

```python
from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample

class CustomInstructionProposer:
    def __call__(
        self,
        candidate: dict[str, str],                          # Candidate component name -> instruction mapping to be updated in this round
        reflective_dataset: dict[str, list[ReflectiveExample]],  # Component -> examples with structure: {"Inputs": ..., "Generated Outputs": ..., "Feedback": ...}
        components_to_update: list[str]                     # Which components to improve
    ) -> dict[str, str]:                                    # Return new instruction mapping only for components being updated
        # Your custom instruction generation logic here
        return updated_instructions

# Or as a function:
def custom_instruction_proposer(candidate, reflective_dataset, components_to_update):
    # Your custom instruction generation logic here
    return updated_instructions
```

**Reflective Dataset Structure:**

- `dict[str, list[ReflectiveExample]]` - Maps component names to lists of examples
- `ReflectiveExample` TypedDict contains:
  - `Inputs: dict[str, Any]` - Predictor inputs (may include dspy.Image objects)
  - `Generated_Outputs: dict[str, Any] | str` - On success, a dict of the predictor's output fields; on failure, an error message
  - `Feedback: str` - Always a string, either returned by the metric function or auto-generated by GEPA
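
For illustration, a reflective dataset for a single hypothetical component named `answer_generator` might look like the following (the inputs, outputs, and feedback strings are made up; the keys follow the structure described above):

```python
# Hypothetical reflective dataset for one component; all values are illustrative.
reflective_dataset = {
    "answer_generator": [
        {
            "Inputs": {"question": "What is the capital of France?"},
            "Generated Outputs": {"answer": "Lyon"},  # successful run: dict of output fields
            "Feedback": "Incorrect answer; the gold answer is 'Paris'.",
        },
        {
            "Inputs": {"question": "Summarize the attached report."},
            "Generated Outputs": "Error: the output could not be parsed into the expected fields.",  # failed run: error message
            "Feedback": "The response did not follow the required output format.",
        },
    ]
}
```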

#### Basic Example: Word Limit Proposer

```python
import dspy
from gepa.core.adapter import ProposalFn
from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample

class GenerateWordLimitedInstruction(dspy.Signature):
    """Given a current instruction and feedback examples, generate an improved instruction with word limit constraints."""

    current_instruction = dspy.InputField(desc="The current instruction that needs improvement")
    feedback_summary = dspy.InputField(desc="Feedback from examples that might include both positive and negative cases")
    max_words = dspy.InputField(desc="Maximum number of words allowed in the new instruction")

    improved_instruction = dspy.OutputField(desc="A new instruction that fixes the issues while staying under the max_words limit")

class WordLimitProposer(ProposalFn):
    def __init__(self, max_words: int = 1000):
        self.max_words = max_words
        self.instruction_improver = dspy.ChainOfThought(GenerateWordLimitedInstruction)

    def __call__(self, candidate: dict[str, str], reflective_dataset: dict[str, list[ReflectiveExample]], components_to_update: list[str]) -> dict[str, str]:
        updated_components = {}

        for component_name in components_to_update:
            if component_name not in candidate or component_name not in reflective_dataset:
                continue

            current_instruction = candidate[component_name]
            component_examples = reflective_dataset[component_name]

            # Create a feedback summary; cap the number of examples so the
            # reflection prompt does not overflow the context window
            feedback_text = "\n".join([
                f"Example {i+1}: {ex.get('Feedback', 'No feedback')}"
                for i, ex in enumerate(component_examples[:10])
            ])

            # Use the module to improve the instruction
            result = self.instruction_improver(
                current_instruction=current_instruction,
                feedback_summary=feedback_text,
                max_words=self.max_words
            )

            updated_components[component_name] = result.improved_instruction

        return updated_components

# Usage
gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
    instruction_proposer=WordLimitProposer(max_words=700),
    auto="medium"
)
```

#### Advanced Example: RAG-Enhanced Instruction Proposer

```python
import dspy
from gepa.core.adapter import ProposalFn
from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample

class GenerateDocumentationQuery(dspy.Signature):
    """Analyze examples with feedback to identify common issue patterns and generate targeted database queries for retrieving relevant documentation.

    Your goal is to search a document database for guidelines that address the problematic patterns found in the examples. Look for recurring issues, error types, or failure modes in the feedback, then craft specific search queries that will find documentation to help resolve these patterns."""

    current_instruction = dspy.InputField(desc="The current instruction that needs improvement")
    examples_with_feedback = dspy.InputField(desc="Examples with their feedback showing what issues occurred and any recurring patterns")

    failure_patterns: str = dspy.OutputField(desc="Summarize the common failure patterns identified in the examples")

    retrieval_queries: list[str] = dspy.OutputField(desc="Specific search queries to find relevant documentation in the database that addresses the common issue patterns identified in the problematic examples")

class GenerateRAGEnhancedInstruction(dspy.Signature):
    """Generate improved instructions using retrieved documentation and examples analysis."""

    current_instruction = dspy.InputField(desc="The current instruction that needs improvement")
    relevant_documentation = dspy.InputField(desc="Retrieved guidelines and best practices from specialized documentation")
    examples_with_feedback = dspy.InputField(desc="Examples showing what issues occurred with the current instruction")

    improved_instruction: str = dspy.OutputField(desc="Enhanced instruction that incorporates retrieved guidelines and addresses the issues shown in the examples")

class RAGInstructionImprover(dspy.Module):
    """Module that uses RAG to improve instructions with specialized documentation."""

    def __init__(self, retrieval_model):
        super().__init__()
        self.retrieve = retrieval_model  # Could be dspy.Retrieve or custom retriever
        self.query_generator = dspy.ChainOfThought(GenerateDocumentationQuery)
        self.generate_answer = dspy.ChainOfThought(GenerateRAGEnhancedInstruction)

    def forward(self, current_instruction: str, component_examples: list):
        """Improve instruction using retrieved documentation."""

        # Let LM analyze examples and generate targeted retrieval queries
        query_result = self.query_generator(
            current_instruction=current_instruction,
            examples_with_feedback=component_examples
        )

        results = self.retrieve.query(
            query_texts=query_result.retrieval_queries,
            n_results=3
        )

        relevant_docs_parts = []
        for i, (query, query_docs) in enumerate(zip(query_result.retrieval_queries, results['documents'])):
            if query_docs:
                docs_formatted = "\n".join([f"  - {doc}" for doc in query_docs])
                relevant_docs_parts.append(
                    f"**Search Query #{i+1}**: {query}\n"
                    f"**Retrieved Guidelines**:\n{docs_formatted}"
                )

        relevant_docs = "\n\n" + "="*60 + "\n\n".join(relevant_docs_parts) + "\n" + "="*60

        # Generate improved instruction with retrieved context
        result = self.generate_answer(
            current_instruction=current_instruction,
            relevant_documentation=relevant_docs,
            examples_with_feedback=component_examples
        )

        return result

class DocumentationEnhancedProposer(ProposalFn):
    """Instruction proposer that accesses specialized documentation via RAG."""

    def __init__(self, documentation_retriever):
        """
        Args:
            documentation_retriever: A retrieval model that can search your specialized docs
                                   Could be dspy.Retrieve, ChromadbRM, or custom retriever
        """
        self.instruction_improver = RAGInstructionImprover(documentation_retriever)

    def __call__(self, candidate: dict[str, str], reflective_dataset: dict[str, list[ReflectiveExample]], components_to_update: list[str]) -> dict[str, str]:
        updated_components = {}

        for component_name in components_to_update:
            if component_name not in candidate or component_name not in reflective_dataset:
                continue

            current_instruction = candidate[component_name]
            component_examples = reflective_dataset[component_name]

            result = self.instruction_improver(
                current_instruction=current_instruction,
                component_examples=component_examples
            )

            updated_components[component_name] = result.improved_instruction

        return updated_components

import chromadb

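# Usage: connect the proposer to an existing ChromaDB collection of documentation guidelines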
client = chromadb.Client()
collection = client.get_collection("instruction_guidelines")

gepa = dspy.GEPA(
    metric=task_specific_metric,
    reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
    instruction_proposer=DocumentationEnhancedProposer(collection),
    auto="medium"
)
```

#### Integration Patterns

**Using Custom Proposer with External LM:**

```python
class ExternalLMProposer(ProposalFn):
    def __init__(self):
        # Manage your own LM instance
        self.external_lm = dspy.LM('gemini/gemini-2.5-pro')

    def __call__(self, candidate, reflective_dataset, components_to_update):
        updated_components = {}

        with dspy.context(lm=self.external_lm):
            # Your custom logic here using self.external_lm
            for component_name in components_to_update:
                # ... implementation
                pass

        return updated_components

gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=None,  # Optional when using custom proposer
    instruction_proposer=ExternalLMProposer(),
    auto="medium"
)
```

**Best Practices:**

- **Use the full power of DSPy**: Build your instruction proposer from DSPy components like `dspy.Module`, `dspy.Signature`, and `dspy.Predict` rather than making direct LM calls. Consider `dspy.Refine` for constraint satisfaction and `dspy.ChainOfThought` for complex reasoning, and compose multiple modules for more sophisticated instruction-improvement workflows.
- **Enable holistic feedback analysis**: While dspy.GEPA's `GEPAFeedbackMetric` processes one (gold, prediction) pair at a time, instruction proposers receive all of a component's examples in a single batch, enabling cross-example pattern detection and systematic issue identification (see the sketch after this list).
- **Mind data serialization**: Serializing everything to strings might not be ideal - handle complex input types (like `dspy.Image`) by maintaining their structure for better LM processing
- **Test thoroughly**: Test your custom proposer with representative failure cases
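
As a concrete illustration of the second point above, a proposer can aggregate feedback across all of a component's examples before prompting the reflection model. Below is a minimal sketch; the counting heuristic and the `top_k` cutoff are purely illustrative:

```python
from collections import Counter

def summarize_recurring_feedback(component_examples, top_k=5):
    """Tally recurring feedback messages across one component's reflective examples.

    `component_examples` is the list of reflective examples a proposer receives for a
    single component; the most frequent feedback messages form a compact failure summary
    that can be passed to the reflection model.
    """
    counts = Counter(
        ex.get("Feedback", "").strip()
        for ex in component_examples
        if ex.get("Feedback")
    )
    return [f"{count}x: {message}" for message, count in counts.most_common(top_k)]
```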

## Custom Component Selection

### What is component_selector?

The `component_selector` parameter controls which components (predictors) in your DSPy program are selected for optimization at each GEPA iteration. Instead of the default round-robin approach that updates one component at a time, you can implement custom selection strategies that choose single or multiple components based on optimization state, performance trajectories, and other contextual information.

### Default Behavior

By default, GEPA uses a **round-robin strategy** (`RoundRobinReflectionComponentSelector`) that cycles through components sequentially, optimizing one component per iteration:

```python
# Default round-robin component selection
gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
    # component_selector="round_robin"  # This is the default
    auto="medium"
)
```

### Built-in Selection Strategies

**String-based selectors:**

- `"round_robin"` (default): Cycles through components one at a time
- `"all"`: Selects all components for simultaneous optimization

```python
# Optimize all components simultaneously
gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=reflection_lm,
    component_selector="all",  # Update all components together
    auto="medium"
)

# Explicit round-robin selection
gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=reflection_lm,
    component_selector="round_robin",  # One component per iteration
    auto="medium"
)
```

### When to Use Custom Component Selection

Consider implementing custom component selection when you need:

- **Dependency-aware optimization**: Update related components together (e.g., a classifier and its input formatter)
- **LLM-driven selection**: Let an LLM analyze trajectories and decide which components need attention
- **Resource-conscious optimization**: Balance optimization thoroughness with computational budget

### Custom Component Selector Protocol

Custom component selectors must implement the [`ReflectionComponentSelector`](https://github.com/gepa-ai/gepa/blob/main/src/gepa/proposer/reflective_mutation/base.py) protocol by defining a callable class or function. GEPA will call your selector during optimization:

```python
from dspy.teleprompt.gepa.gepa_utils import GEPAState, Trajectory

class CustomComponentSelector:
    def __call__(
        self,
        state: GEPAState,                    # Complete optimization state with history
        trajectories: list[Trajectory],      # Execution traces from the current minibatch
        subsample_scores: list[float],       # Scores for each example in the current minibatch
        candidate_idx: int,                  # Index of the current program candidate being optimized
        candidate: dict[str, str],           # Component name -> instruction mapping
    ) -> list[str]:                          # Return list of component names to optimize
        # Your custom component selection logic here
        return selected_components

# Or as a function:
def custom_component_selector(state, trajectories, subsample_scores, candidate_idx, candidate):
    # Your custom component selection logic here
    return selected_components
```

### Custom Implementation Example

Here's a simple function that alternates between optimizing different halves of your components:

```python
def alternating_half_selector(state, trajectories, subsample_scores, candidate_idx, candidate):
    """Optimize half the components on even iterations, half on odd iterations."""
    components = list(candidate.keys())

    # If there's only one component, always optimize it
    if len(components) <= 1:
        return components

    mid_point = len(components) // 2

    # Use state.i (iteration counter) to alternate between halves
    if state.i % 2 == 0:
        # Even iteration: optimize first half
        return components[:mid_point]
    else:
        # Odd iteration: optimize second half
        return components[mid_point:]

# Usage
gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=reflection_lm,
    component_selector=alternating_half_selector,
    auto="medium"
)
```

### Integration with Custom Instruction Proposers

Component selectors work seamlessly with custom instruction proposers. The selector determines which components to update, then the instruction proposer generates new instructions for those components:

```python
# Combined custom selector + custom proposer
gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=reflection_lm,
    component_selector=alternating_half_selector,
    instruction_proposer=WordLimitProposer(max_words=500),
    auto="medium"
)
```

```

--------------------------------------------------------------------------------
/dspy/retrievers/databricks_rm.py:
--------------------------------------------------------------------------------

```python
import json
import os
from dataclasses import dataclass
from importlib.util import find_spec
from typing import Any

import requests

import dspy
from dspy.primitives.prediction import Prediction

_databricks_sdk_installed = find_spec("databricks.sdk") is not None


@dataclass
class Document:
    page_content: str
    metadata: dict[str, Any]
    type: str

    def to_dict(self) -> dict[str, Any]:
        return {
            "page_content": self.page_content,
            "metadata": self.metadata,
            "type": self.type,
        }


class DatabricksRM(dspy.Retrieve):
    """
    A retriever module that uses a Databricks Mosaic AI Vector Search Index to return the top-k
    documents for a given query.

    Examples:
        Below is a code snippet that shows how to set up a Databricks Vector Search Index
        and configure a DatabricksRM DSPy retriever module to query the index.

        (example adapted from "Databricks: How to create and query a Vector Search Index",
        https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-index)

        ```python
        from databricks.vector_search.client import VectorSearchClient

        # Create a Databricks Vector Search Endpoint
        client = VectorSearchClient()
        client.create_endpoint(
            name="your_vector_search_endpoint_name",
            endpoint_type="STANDARD"
        )

        # Create a Databricks Direct Access Vector Search Index
        index = client.create_direct_access_index(
            endpoint_name="your_vector_search_endpoint_name",
            index_name="your_index_name",
            primary_key="id",
            embedding_dimension=1024,
            embedding_vector_column="text_vector",
            schema={
              "id": "int",
              "field2": "str",
              "field3": "float",
              "text_vector": "array<float>"
            }
        )

        # Create a DatabricksRM retriever module to query the Databricks Direct Access Vector
        # Search Index
        retriever = DatabricksRM(
            databricks_index_name = "your_index_name",
            docs_id_column_name="id",
            text_column_name="field2",
            k=3
        )
        ```

        Below is a code snippet that shows how to query the Databricks Direct Access Vector
        Search Index using the DatabricksRM retriever module:

        ```python
        retrieved_results = retriever(query="Example query text")
        ```
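
        The ``forward`` method also accepts per-call options such as ``query_type`` and
        ``filters_json``; an illustrative call (the filter values are hypothetical):

        ```python
        retrieved_results = retriever(
            query="Example query text",
            query_type="HYBRID",
            filters_json='{"id <": 100}',
        )
        print(retrieved_results.docs)
        ```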
    """

    def __init__(
        self,
        databricks_index_name: str,
        databricks_endpoint: str | None = None,
        databricks_token: str | None = None,
        databricks_client_id: str | None = None,
        databricks_client_secret: str | None = None,
        columns: list[str] | None = None,
        filters_json: str | None = None,
        k: int = 3,
        docs_id_column_name: str = "id",
        docs_uri_column_name: str | None = None,
        text_column_name: str = "text",
        use_with_databricks_agent_framework: bool = False,
    ):
        """
        Args:
            databricks_index_name (str): The name of the Databricks Vector Search Index to query.
            databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing
                the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST``
                environment variable. If unspecified, the Databricks SDK is used to identify the
                endpoint based on the current environment.
            databricks_token (Optional[str]): The Databricks Workspace authentication token to use
                when querying the Vector Search Index. Defaults to the value of the
                ``DATABRICKS_TOKEN`` environment variable. If unspecified, the Databricks SDK is
                used to identify the token based on the current environment.
            databricks_client_id (str): Databricks service principal id. If not specified,
                the client ID is resolved from the current environment (DATABRICKS_CLIENT_ID).
            databricks_client_secret (str): Databricks service principal secret. If not specified,
                the client secret is resolved from the current environment (DATABRICKS_CLIENT_SECRET).
            columns (Optional[list[str]]): Extra column names to include in response,
                in addition to the document id and text columns specified by
                ``docs_id_column_name`` and ``text_column_name``.
            filters_json (Optional[str]): A JSON string specifying additional query filters.
                Example filters: ``{"id <": 5}`` selects records that have an ``id`` column value
                less than 5, and ``{"id >=": 5, "id <": 10}`` selects records that have an ``id``
                column value greater than or equal to 5 and less than 10.
            k (int): The number of documents to retrieve.
            docs_id_column_name (str): The name of the column in the Databricks Vector Search Index
                containing document IDs.
            docs_uri_column_name (Optional[str]): The name of the column in the Databricks Vector Search Index
                containing document URI.
            text_column_name (str): The name of the column in the Databricks Vector Search Index
                containing document text to retrieve.
            use_with_databricks_agent_framework (bool): Whether to use the `DatabricksRM` in a way that is
                compatible with the Databricks Mosaic Agent Framework.
        """
        super().__init__(k=k)
        self.databricks_token = databricks_token if databricks_token is not None else os.environ.get("DATABRICKS_TOKEN")
        self.databricks_endpoint = (
            databricks_endpoint if databricks_endpoint is not None else os.environ.get("DATABRICKS_HOST")
        )
        self.databricks_client_id = (
            databricks_client_id if databricks_client_id is not None else os.environ.get("DATABRICKS_CLIENT_ID")
        )
        self.databricks_client_secret = (
            databricks_client_secret
            if databricks_client_secret is not None
            else os.environ.get("DATABRICKS_CLIENT_SECRET")
        )
        if not _databricks_sdk_installed and (self.databricks_token, self.databricks_endpoint).count(None) > 0:
            raise ValueError(
                "To retrieve documents with Databricks Vector Search, you must install the"
                " databricks-sdk Python library, supply the databricks_token and"
                " databricks_endpoint parameters, or set the DATABRICKS_TOKEN and DATABRICKS_HOST"
                " environment variables. You may also supply a service principal the databricks_client_id and"
                " databricks_client_secret parameters, or set the DATABRICKS_CLIENT_ID and DATABRICKS_CLIENT_SECRET"
            )
        self.databricks_index_name = databricks_index_name
        self.columns = list({docs_id_column_name, text_column_name, *(columns or [])})
        self.filters_json = filters_json
        self.k = k
        self.docs_id_column_name = docs_id_column_name
        self.docs_uri_column_name = docs_uri_column_name
        self.text_column_name = text_column_name
        self.use_with_databricks_agent_framework = use_with_databricks_agent_framework
        if self.use_with_databricks_agent_framework:
            try:
                import mlflow

                mlflow.models.set_retriever_schema(
                    primary_key="doc_id",
                    text_column="page_content",
                    doc_uri="doc_uri",
                )
            except ImportError:
                raise ValueError(
                    "To use the `DatabricksRM` retriever module with the Databricks Mosaic Agent Framework, "
                    "you must install the mlflow Python library. Please install mlflow via `pip install mlflow`."
                )

    def _extract_doc_ids(self, item: dict[str, Any]) -> str:
        """Extracts the document id from a search result

        Args:
            item: dict[str, Any]: a record from the search results.
        Returns:
            str: document id.
        """
        if self.docs_id_column_name == "metadata":
            docs_dict = json.loads(item["metadata"])
            return docs_dict["document_id"]
        return item[self.docs_id_column_name]

    def _get_extra_columns(self, item: dict[str, Any]) -> dict[str, Any]:
        """Extracts search result column values, excluding the "text" and not "id" columns

        Args:
            item: dict[str, Any]: a record from the search results.
        Returns:
            dict[str, Any]: Search result column values, excluding the "text", "id" and "uri" columns.
        """
        extra_columns = {
            k: v
            for k, v in item.items()
            if k not in [self.docs_id_column_name, self.text_column_name, self.docs_uri_column_name]
        }
        if self.docs_id_column_name == "metadata":
            extra_columns = {
                **extra_columns,
                **{"metadata": {k: v for k, v in json.loads(item["metadata"]).items() if k != "document_id"}},
            }
        return extra_columns

    def forward(
        self,
        query: str | list[float],
        query_type: str = "ANN",
        filters_json: str | None = None,
    ) -> dspy.Prediction | list[dict[str, Any]]:
        """
        Retrieve documents from a Databricks Mosaic AI Vector Search Index that are relevant to the
        specified query.

        Args:
            query (Union[str, list[float]]): The query text or numeric query vector for which to
                retrieve relevant documents.
            query_type (str): The type of search query to perform against the Databricks Vector
                Search Index. Must be either 'ANN' (approximate nearest neighbor) or 'HYBRID'
                (hybrid search).
            filters_json (Optional[str]): A JSON string specifying additional query filters.
                Example filters: ``{"id <": 5}`` selects records that have an ``id`` column value
                less than 5, and ``{"id >=": 5, "id <": 10}`` selects records that have an ``id``
                column value greater than or equal to 5 and less than 10. If specified, this
                parameter overrides the `filters_json` parameter passed to the constructor.

        Returns:
            A list of dictionaries when ``use_with_databricks_agent_framework`` is ``True``,
            or a ``dspy.Prediction`` object when ``use_with_databricks_agent_framework`` is
            ``False``.
        """
        if query_type in ["vector", "text"]:
            # Older versions of DSPy used a `query_type` argument to disambiguate between text
            # and vector queries, rather than checking the type of the `query` argument. This
            # differs from the Databricks Vector Search definition of `query_type`, which
            # specifies the search algorithm to use (e.g. "ANN" or "HYBRID"). To maintain
            # backwards compatibility with older versions of DSPy, we map the old `query_type`
            # values to the Databricks Vector Search default query type of "ANN".
            query_type = "ANN"

        if isinstance(query, str):
            query_text = query
            query_vector = None
        elif isinstance(query, list):
            query_vector = query
            query_text = None
        else:
            raise ValueError("Query must be a string or a list of floats.")

        if _databricks_sdk_installed:
            results = self._query_via_databricks_sdk(
                index_name=self.databricks_index_name,
                k=self.k,
                columns=self.columns,
                query_type=query_type,
                query_text=query_text,
                query_vector=query_vector,
                databricks_token=self.databricks_token,
                databricks_endpoint=self.databricks_endpoint,
                databricks_client_id=self.databricks_client_id,
                databricks_client_secret=self.databricks_client_secret,
                filters_json=filters_json or self.filters_json,
            )
        else:
            results = self._query_via_requests(
                index_name=self.databricks_index_name,
                k=self.k,
                columns=self.columns,
                databricks_token=self.databricks_token,
                databricks_endpoint=self.databricks_endpoint,
                query_type=query_type,
                query_text=query_text,
                query_vector=query_vector,
                filters_json=filters_json or self.filters_json,
            )

        # Checking if defined columns are present in the index columns
        col_names = [column["name"] for column in results["manifest"]["columns"]]

        if self.docs_id_column_name not in col_names:
            raise Exception(
                f"docs_id_column_name: '{self.docs_id_column_name}' is not in the index columns: \n {col_names}"
            )

        if self.text_column_name not in col_names:
            raise Exception(f"text_column_name: '{self.text_column_name}' is not in the index columns: \n {col_names}")

        # Extracting the results
        items = []
        if "data_array" in results["result"]:
            for _, data_row in enumerate(results["result"]["data_array"]):
                item = {}
                for col_name, val in zip(col_names, data_row, strict=False):
                    item[col_name] = val
                items += [item]

        # Sorting results by score in descending order
        sorted_docs = sorted(items, key=lambda x: x["score"], reverse=True)[: self.k]

        if self.use_with_databricks_agent_framework:
            return [
                Document(
                    page_content=doc[self.text_column_name],
                    metadata={
                        "doc_id": self._extract_doc_ids(doc),
                        "doc_uri": doc[self.docs_uri_column_name] if self.docs_uri_column_name else None,
                    }
                    | self._get_extra_columns(doc),
                    type="Document",
                ).to_dict()
                for doc in sorted_docs
            ]
        else:
            # Returning the prediction
            return Prediction(
                docs=[doc[self.text_column_name] for doc in sorted_docs],
                doc_ids=[self._extract_doc_ids(doc) for doc in sorted_docs],
                doc_uris=[doc[self.docs_uri_column_name] for doc in sorted_docs] if self.docs_uri_column_name else None,
                extra_columns=[self._get_extra_columns(item) for item in sorted_docs],
            )

    @staticmethod
    def _query_via_databricks_sdk(
        index_name: str,
        k: int,
        columns: list[str],
        query_type: str,
        query_text: str | None,
        query_vector: list[float] | None,
        databricks_token: str | None,
        databricks_endpoint: str | None,
        databricks_client_id: str | None,
        databricks_client_secret: str | None,
        filters_json: str | None,
    ) -> dict[str, Any]:
        """
        Query a Databricks Vector Search Index via the Databricks SDK.
        Assumes that the databricks-sdk Python library is installed.

        Args:
            index_name (str): Name of the Databricks vector search index to query
            k (int): Number of relevant documents to retrieve.
            columns (list[str]): Column names to include in response.
            query_text (Optional[str]): Text query for which to find relevant documents. Exactly
                one of query_text or query_vector must be specified.
            query_vector (Optional[list[float]]): Numeric query vector for which to find relevant
                documents. Exactly one of query_text or query_vector must be specified.
            filters_json (Optional[str]): JSON string representing additional query filters.
            databricks_token (str): Databricks authentication token. If not specified,
                the token is resolved from the current environment.
            databricks_endpoint (str): Databricks index endpoint url. If not specified,
                the endpoint is resolved from the current environment.
            databricks_client_id (str): Databricks service principal id. If not specified,
                the client ID is resolved from the current environment (DATABRICKS_CLIENT_ID).
            databricks_client_secret (str): Databricks service principal secret. If not specified,
                the client secret is resolved from the current environment (DATABRICKS_CLIENT_SECRET).
        Returns:
            dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query.
        """

        from databricks.sdk import WorkspaceClient

        if (query_text, query_vector).count(None) != 1:
            raise ValueError("Exactly one of query_text or query_vector must be specified.")

        if databricks_client_secret and databricks_client_id:
            # Use client ID and secret for authentication if they are provided
            databricks_client = WorkspaceClient(
                client_id=databricks_client_id,
                client_secret=databricks_client_secret,
            )
            print("Creating Databricks workspace client using service principal authentication.")

        else:
            # Fallback for token-based authentication
            databricks_client = WorkspaceClient(
                host=databricks_endpoint,
                token=databricks_token,
            )
            print("Creating Databricks workspace client using token authentication.")

        return databricks_client.vector_search_indexes.query_index(
            index_name=index_name,
            query_type=query_type,
            query_text=query_text,
            query_vector=query_vector,
            columns=columns,
            filters_json=filters_json,
            num_results=k,
        ).as_dict()

    @staticmethod
    def _query_via_requests(
        index_name: str,
        k: int,
        columns: list[str],
        databricks_token: str,
        databricks_endpoint: str,
        query_type: str,
        query_text: str | None,
        query_vector: list[float] | None,
        filters_json: str | None,
    ) -> dict[str, Any]:
        """
        Query a Databricks Vector Search Index via the Python requests library.

        Args:
            index_name (str): Name of the Databricks vector search index to query
            k (int): Number of relevant documents to retrieve.
            columns (list[str]): Column names to include in response.
            databricks_token (str): Databricks authentication token.
            databricks_endpoint (str): Databricks index endpoint url.
            query_text (Optional[str]): Text query for which to find relevant documents. Exactly
                one of query_text or query_vector must be specified.
            query_vector (Optional[list[float]]): Numeric query vector for which to find relevant
                documents. Exactly one of query_text or query_vector must be specified.
            filters_json (Optional[str]): JSON string representing additional query filters.

        Returns:
            dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query.
        """
        if (query_text, query_vector).count(None) != 1:
            raise ValueError("Exactly one of query_text or query_vector must be specified.")

        headers = {
            "Authorization": f"Bearer {databricks_token}",
            "Content-Type": "application/json",
        }
        payload = {
            "columns": columns,
            "num_results": k,
            "query_type": query_type,
        }
        if filters_json is not None:
            payload["filters_json"] = filters_json
        if query_text is not None:
            payload["query_text"] = query_text
        elif query_vector is not None:
            payload["query_vector"] = query_vector
        response = requests.post(
            f"{databricks_endpoint}/api/2.0/vector-search/indexes/{index_name}/query",
            json=payload,
            headers=headers,
        )
        results = response.json()
        if "error_code" in results:
            raise Exception(f"ERROR: {results['error_code']} -- {results['message']}")
        return results

```