This is page 12 of 17. Use http://codebase.md/stanfordnlp/dspy?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .github
│   ├── .internal_dspyai
│   │   ├── internals
│   │   │   ├── build-and-release.md
│   │   │   └── release-checklist.md
│   │   └── pyproject.toml
│   ├── .tmp
│   │   └── .generated-actions
│   │       └── run-pypi-publish-in-docker-container
│   │           └── action.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yml
│   │   └── feature_request.yml
│   ├── PULL_REQUEST_TEMPLATE
│   │   └── pull_request_template.md
│   ├── workflow_scripts
│   │   └── install_testpypi_pkg.sh
│   └── workflows
│       ├── build_and_release.yml
│       ├── build_utils
│       │   └── test_version.py
│       ├── docs-push.yml
│       ├── precommits_check.yml
│       └── run_tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── docs
│   ├── .gitignore
│   ├── docs
│   │   ├── api
│   │   │   ├── adapters
│   │   │   │   ├── Adapter.md
│   │   │   │   ├── ChatAdapter.md
│   │   │   │   ├── JSONAdapter.md
│   │   │   │   └── TwoStepAdapter.md
│   │   │   ├── evaluation
│   │   │   │   ├── answer_exact_match.md
│   │   │   │   ├── answer_passage_match.md
│   │   │   │   ├── CompleteAndGrounded.md
│   │   │   │   ├── Evaluate.md
│   │   │   │   ├── EvaluationResult.md
│   │   │   │   └── SemanticF1.md
│   │   │   ├── experimental
│   │   │   │   ├── Citations.md
│   │   │   │   └── Document.md
│   │   │   ├── index.md
│   │   │   ├── models
│   │   │   │   ├── Embedder.md
│   │   │   │   └── LM.md
│   │   │   ├── modules
│   │   │   │   ├── BestOfN.md
│   │   │   │   ├── ChainOfThought.md
│   │   │   │   ├── CodeAct.md
│   │   │   │   ├── Module.md
│   │   │   │   ├── MultiChainComparison.md
│   │   │   │   ├── Parallel.md
│   │   │   │   ├── Predict.md
│   │   │   │   ├── ProgramOfThought.md
│   │   │   │   ├── ReAct.md
│   │   │   │   └── Refine.md
│   │   │   ├── optimizers
│   │   │   │   ├── BetterTogether.md
│   │   │   │   ├── BootstrapFewShot.md
│   │   │   │   ├── BootstrapFewShotWithRandomSearch.md
│   │   │   │   ├── BootstrapFinetune.md
│   │   │   │   ├── BootstrapRS.md
│   │   │   │   ├── COPRO.md
│   │   │   │   ├── Ensemble.md
│   │   │   │   ├── GEPA
│   │   │   │   │   ├── GEPA_Advanced.md
│   │   │   │   │   └── overview.md
│   │   │   │   ├── InferRules.md
│   │   │   │   ├── KNN.md
│   │   │   │   ├── KNNFewShot.md
│   │   │   │   ├── LabeledFewShot.md
│   │   │   │   ├── MIPROv2.md
│   │   │   │   └── SIMBA.md
│   │   │   ├── primitives
│   │   │   │   ├── Audio.md
│   │   │   │   ├── Code.md
│   │   │   │   ├── Example.md
│   │   │   │   ├── History.md
│   │   │   │   ├── Image.md
│   │   │   │   ├── Prediction.md
│   │   │   │   ├── Tool.md
│   │   │   │   └── ToolCalls.md
│   │   │   ├── signatures
│   │   │   │   ├── InputField.md
│   │   │   │   ├── OutputField.md
│   │   │   │   └── Signature.md
│   │   │   ├── tools
│   │   │   │   ├── ColBERTv2.md
│   │   │   │   ├── Embeddings.md
│   │   │   │   └── PythonInterpreter.md
│   │   │   └── utils
│   │   │       ├── asyncify.md
│   │   │       ├── configure_cache.md
│   │   │       ├── disable_litellm_logging.md
│   │   │       ├── disable_logging.md
│   │   │       ├── enable_litellm_logging.md
│   │   │       ├── enable_logging.md
│   │   │       ├── inspect_history.md
│   │   │       ├── load.md
│   │   │       ├── StatusMessage.md
│   │   │       ├── StatusMessageProvider.md
│   │   │       ├── streamify.md
│   │   │       └── StreamListener.md
│   │   ├── cheatsheet.md
│   │   ├── community
│   │   │   ├── community-resources.md
│   │   │   ├── how-to-contribute.md
│   │   │   └── use-cases.md
│   │   ├── deep-dive
│   │   │   └── data-handling
│   │   │       ├── built-in-datasets.md
│   │   │       ├── examples.md
│   │   │       ├── img
│   │   │       │   └── data-loading.png
│   │   │       └── loading-custom-data.md
│   │   ├── faqs.md
│   │   ├── index.md
│   │   ├── js
│   │   │   └── runllm-widget.js
│   │   ├── learn
│   │   │   ├── evaluation
│   │   │   │   ├── data.md
│   │   │   │   ├── metrics.md
│   │   │   │   └── overview.md
│   │   │   ├── figures
│   │   │   │   ├── native_tool_call.png
│   │   │   │   └── teleprompter-classes.png
│   │   │   ├── index.md
│   │   │   ├── optimization
│   │   │   │   ├── optimizers.md
│   │   │   │   └── overview.md
│   │   │   └── programming
│   │   │       ├── 7-assertions.md
│   │   │       ├── adapters.md
│   │   │       ├── language_models.md
│   │   │       ├── mcp.md
│   │   │       ├── modules.md
│   │   │       ├── overview.md
│   │   │       ├── signatures.md
│   │   │       └── tools.md
│   │   ├── production
│   │   │   └── index.md
│   │   ├── roadmap.md
│   │   ├── static
│   │   │   ├── .nojekyll
│   │   │   └── img
│   │   │       ├── dspy_logo.png
│   │   │       ├── logo.png
│   │   │       ├── mlflow-tracing-rag.png
│   │   │       ├── modular.png
│   │   │       ├── optimize.png
│   │   │       ├── undraw_docusaurus_mountain.svg
│   │   │       ├── undraw_docusaurus_react.svg
│   │   │       ├── undraw_docusaurus_tree.svg
│   │   │       └── universal_compatibility.png
│   │   ├── stylesheets
│   │   │   └── extra.css
│   │   └── tutorials
│   │       ├── agents
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-agent.png
│   │       ├── ai_text_game
│   │       │   └── index.md
│   │       ├── async
│   │       │   └── index.md
│   │       ├── audio
│   │       │   └── index.ipynb
│   │       ├── build_ai_program
│   │       │   └── index.md
│   │       ├── cache
│   │       │   └── index.md
│   │       ├── classification
│   │       │   └── index.md
│   │       ├── classification_finetuning
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-classification.png
│   │       ├── conversation_history
│   │       │   └── index.md
│   │       ├── core_development
│   │       │   └── index.md
│   │       ├── custom_module
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-custom-module.png
│   │       ├── customer_service_agent
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-customer-service-agent.png
│   │       ├── deployment
│   │       │   ├── dspy_mlflow_ui.png
│   │       │   └── index.md
│   │       ├── email_extraction
│   │       │   ├── index.md
│   │       │   └── mlflow-tracing-email-extraction.png
│   │       ├── entity_extraction
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-entity-extraction.png
│   │       ├── games
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-agent.png
│   │       ├── gepa_ai_program
│   │       │   └── index.md
│   │       ├── gepa_aime
│   │       │   ├── index.ipynb
│   │       │   ├── mlflow-tracing-gepa-aime.png
│   │       │   └── mlflow-tracking-gepa-aime-optimization.png
│   │       ├── gepa_facilitysupportanalyzer
│   │       │   ├── index.ipynb
│   │       │   ├── mlflow-tracing-gepa-support.png
│   │       │   └── mlflow-tracking-gepa-support-optimization.png
│   │       ├── gepa_papillon
│   │       │   ├── index.ipynb
│   │       │   ├── mlflow-tracing-gepa-papilon.png
│   │       │   └── mlflow-tracking-gepa-papilon-optimization.png
│   │       ├── image_generation_prompting
│   │       │   └── index.ipynb
│   │       ├── index.md
│   │       ├── llms_txt_generation
│   │       │   └── index.md
│   │       ├── math
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-math.png
│   │       ├── mcp
│   │       │   └── index.md
│   │       ├── mem0_react_agent
│   │       │   └── index.md
│   │       ├── multihop_search
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-multi-hop.png
│   │       ├── observability
│   │       │   ├── index.md
│   │       │   ├── mlflow_trace_ui_navigation.gif
│   │       │   ├── mlflow_trace_ui.png
│   │       │   └── mlflow_trace_view.png
│   │       ├── optimize_ai_program
│   │       │   └── index.md
│   │       ├── optimizer_tracking
│   │       │   ├── child_run.png
│   │       │   ├── experiment.png
│   │       │   ├── index.md
│   │       │   └── parent_run.png
│   │       ├── output_refinement
│   │       │   └── best-of-n-and-refine.md
│   │       ├── papillon
│   │       │   └── index.md
│   │       ├── program_of_thought
│   │       │   └── index.ipynb
│   │       ├── rag
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-rag.png
│   │       ├── real_world_examples
│   │       │   └── index.md
│   │       ├── rl_ai_program
│   │       │   └── index.md
│   │       ├── rl_multihop
│   │       │   └── index.ipynb
│   │       ├── rl_papillon
│   │       │   └── index.ipynb
│   │       ├── sample_code_generation
│   │       │   └── index.md
│   │       ├── saving
│   │       │   └── index.md
│   │       ├── streaming
│   │       │   └── index.md
│   │       ├── tool_use
│   │       │   └── index.ipynb
│   │       └── yahoo_finance_react
│   │           └── index.md
│   ├── mkdocs.yml
│   ├── overrides
│   │   ├── home.html
│   │   ├── main.html
│   │   └── partials
│   │       └── tabs.html
│   ├── Pipfile
│   ├── Pipfile.lock
│   ├── README.md
│   ├── requirements.txt
│   ├── scripts
│   │   ├── generate_api_docs.py
│   │   └── generate_api_summary.py
│   └── vercel.json
├── dspy
│   ├── __init__.py
│   ├── __metadata__.py
│   ├── adapters
│   │   ├── __init__.py
│   │   ├── baml_adapter.py
│   │   ├── base.py
│   │   ├── chat_adapter.py
│   │   ├── json_adapter.py
│   │   ├── two_step_adapter.py
│   │   ├── types
│   │   │   ├── __init__.py
│   │   │   ├── audio.py
│   │   │   ├── base_type.py
│   │   │   ├── citation.py
│   │   │   ├── code.py
│   │   │   ├── document.py
│   │   │   ├── history.py
│   │   │   ├── image.py
│   │   │   └── tool.py
│   │   ├── utils.py
│   │   └── xml_adapter.py
│   ├── clients
│   │   ├── __init__.py
│   │   ├── base_lm.py
│   │   ├── cache.py
│   │   ├── databricks.py
│   │   ├── embedding.py
│   │   ├── lm_local_arbor.py
│   │   ├── lm_local.py
│   │   ├── lm.py
│   │   ├── openai.py
│   │   ├── provider.py
│   │   └── utils_finetune.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── alfworld
│   │   │   ├── __init__.py
│   │   │   ├── alfworld.py
│   │   │   └── base_config.yml
│   │   ├── colors.py
│   │   ├── dataloader.py
│   │   ├── dataset.py
│   │   ├── gsm8k.py
│   │   ├── hotpotqa.py
│   │   └── math.py
│   ├── dsp
│   │   ├── __init__.py
│   │   ├── colbertv2.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── dpr.py
│   │       ├── settings.py
│   │       └── utils.py
│   ├── evaluate
│   │   ├── __init__.py
│   │   ├── auto_evaluation.py
│   │   ├── evaluate.py
│   │   └── metrics.py
│   ├── experimental
│   │   └── __init__.py
│   ├── predict
│   │   ├── __init__.py
│   │   ├── aggregation.py
│   │   ├── avatar
│   │   │   ├── __init__.py
│   │   │   ├── avatar.py
│   │   │   ├── models.py
│   │   │   └── signatures.py
│   │   ├── best_of_n.py
│   │   ├── chain_of_thought.py
│   │   ├── code_act.py
│   │   ├── knn.py
│   │   ├── multi_chain_comparison.py
│   │   ├── parallel.py
│   │   ├── parameter.py
│   │   ├── predict.py
│   │   ├── program_of_thought.py
│   │   ├── react.py
│   │   ├── refine.py
│   │   └── retry.py
│   ├── primitives
│   │   ├── __init__.py
│   │   ├── base_module.py
│   │   ├── example.py
│   │   ├── module.py
│   │   ├── prediction.py
│   │   ├── python_interpreter.py
│   │   └── runner.js
│   ├── propose
│   │   ├── __init__.py
│   │   ├── dataset_summary_generator.py
│   │   ├── grounded_proposer.py
│   │   ├── propose_base.py
│   │   └── utils.py
│   ├── retrievers
│   │   ├── __init__.py
│   │   ├── databricks_rm.py
│   │   ├── embeddings.py
│   │   ├── retrieve.py
│   │   └── weaviate_rm.py
│   ├── signatures
│   │   ├── __init__.py
│   │   ├── field.py
│   │   ├── signature.py
│   │   └── utils.py
│   ├── streaming
│   │   ├── __init__.py
│   │   ├── messages.py
│   │   ├── streamify.py
│   │   └── streaming_listener.py
│   ├── teleprompt
│   │   ├── __init__.py
│   │   ├── avatar_optimizer.py
│   │   ├── bettertogether.py
│   │   ├── bootstrap_finetune.py
│   │   ├── bootstrap_trace.py
│   │   ├── bootstrap.py
│   │   ├── copro_optimizer.py
│   │   ├── ensemble.py
│   │   ├── gepa
│   │   │   ├── __init__.py
│   │   │   ├── gepa_utils.py
│   │   │   ├── gepa.py
│   │   │   └── instruction_proposal.py
│   │   ├── grpo.py
│   │   ├── infer_rules.py
│   │   ├── knn_fewshot.py
│   │   ├── mipro_optimizer_v2.py
│   │   ├── random_search.py
│   │   ├── signature_opt.py
│   │   ├── simba_utils.py
│   │   ├── simba.py
│   │   ├── teleprompt_optuna.py
│   │   ├── teleprompt.py
│   │   ├── utils.py
│   │   └── vanilla.py
│   └── utils
│       ├── __init__.py
│       ├── annotation.py
│       ├── asyncify.py
│       ├── caching.py
│       ├── callback.py
│       ├── dummies.py
│       ├── exceptions.py
│       ├── hasher.py
│       ├── inspect_history.py
│       ├── langchain_tool.py
│       ├── logging_utils.py
│       ├── mcp.py
│       ├── parallelizer.py
│       ├── saving.py
│       ├── syncify.py
│       ├── unbatchify.py
│       └── usage_tracker.py
├── LICENSE
├── pyproject.toml
├── README.md
├── tests
│   ├── __init__.py
│   ├── adapters
│   │   ├── test_adapter_utils.py
│   │   ├── test_baml_adapter.py
│   │   ├── test_base_type.py
│   │   ├── test_chat_adapter.py
│   │   ├── test_citation.py
│   │   ├── test_code.py
│   │   ├── test_document.py
│   │   ├── test_json_adapter.py
│   │   ├── test_tool.py
│   │   ├── test_two_step_adapter.py
│   │   └── test_xml_adapter.py
│   ├── callback
│   │   └── test_callback.py
│   ├── clients
│   │   ├── test_cache.py
│   │   ├── test_databricks.py
│   │   ├── test_embedding.py
│   │   ├── test_inspect_global_history.py
│   │   └── test_lm.py
│   ├── conftest.py
│   ├── datasets
│   │   └── test_dataset.py
│   ├── docs
│   │   └── test_mkdocs_links.py
│   ├── evaluate
│   │   ├── test_evaluate.py
│   │   └── test_metrics.py
│   ├── examples
│   │   └── test_baleen.py
│   ├── metadata
│   │   └── test_metadata.py
│   ├── predict
│   │   ├── test_aggregation.py
│   │   ├── test_best_of_n.py
│   │   ├── test_chain_of_thought.py
│   │   ├── test_code_act.py
│   │   ├── test_knn.py
│   │   ├── test_multi_chain_comparison.py
│   │   ├── test_parallel.py
│   │   ├── test_predict.py
│   │   ├── test_program_of_thought.py
│   │   ├── test_react.py
│   │   ├── test_refine.py
│   │   └── test_retry.py
│   ├── primitives
│   │   ├── resources
│   │   │   └── saved_program.json
│   │   ├── test_base_module.py
│   │   ├── test_example.py
│   │   ├── test_module.py
│   │   └── test_python_interpreter.py
│   ├── propose
│   │   └── test_grounded_proposer.py
│   ├── README.md
│   ├── reliability
│   │   ├── __init__.py
│   │   ├── complex_types
│   │   │   └── generated
│   │   │       ├── test_many_types_1
│   │   │       │   ├── inputs
│   │   │       │   │   ├── input1.json
│   │   │       │   │   └── input2.json
│   │   │       │   ├── program.py
│   │   │       │   └── schema.json
│   │   │       ├── test_nesting_1
│   │   │       │   ├── inputs
│   │   │       │   │   ├── input1.json
│   │   │       │   │   └── input2.json
│   │   │       │   ├── program.py
│   │   │       │   └── schema.json
│   │   │       └── test_nesting_2
│   │   │           ├── inputs
│   │   │           │   └── input1.json
│   │   │           ├── program.py
│   │   │           └── schema.json
│   │   ├── conftest.py
│   │   ├── generate
│   │   │   ├── __init__.py
│   │   │   ├── __main__.py
│   │   │   └── utils.py
│   │   ├── input_formats
│   │   │   └── generated
│   │   │       └── test_markdown_1
│   │   │           ├── inputs
│   │   │           │   ├── input1.json
│   │   │           │   └── input2.json
│   │   │           ├── program.py
│   │   │           └── schema.json
│   │   ├── README.md
│   │   ├── reliability_conf.yaml
│   │   ├── test_generated.py
│   │   ├── test_pydantic_models.py
│   │   └── utils.py
│   ├── retrievers
│   │   └── test_embeddings.py
│   ├── signatures
│   │   ├── test_adapter_image.py
│   │   ├── test_custom_types.py
│   │   └── test_signature.py
│   ├── streaming
│   │   └── test_streaming.py
│   ├── teleprompt
│   │   ├── gepa_dummy_lm_custom_component_selector_custom_instruction_proposer.json
│   │   ├── gepa_dummy_lm.json
│   │   ├── test_bootstrap_finetune.py
│   │   ├── test_bootstrap_trace.py
│   │   ├── test_bootstrap.py
│   │   ├── test_copro_optimizer.py
│   │   ├── test_ensemble.py
│   │   ├── test_finetune.py
│   │   ├── test_gepa_instruction_proposer.py
│   │   ├── test_gepa.py
│   │   ├── test_grpo.py
│   │   ├── test_knn_fewshot.py
│   │   ├── test_random_search.py
│   │   ├── test_teleprompt.py
│   │   └── test_utils.py
│   ├── test_utils
│   │   ├── __init__.py
│   │   └── server
│   │       ├── __init__.py
│   │       ├── litellm_server_config.yaml
│   │       └── litellm_server.py
│   └── utils
│       ├── __init__.py
│       ├── resources
│       │   └── mcp_server.py
│       ├── test_annotation.py
│       ├── test_asyncify.py
│       ├── test_exceptions.py
│       ├── test_langchain_tool.py
│       ├── test_mcp.py
│       ├── test_parallelizer.py
│       ├── test_saving.py
│       ├── test_settings.py
│       ├── test_syncify.py
│       ├── test_unbatchify.py
│       └── test_usage_tracker.py
└── uv.lock
```

# Files

--------------------------------------------------------------------------------
/dspy/clients/lm.py:
--------------------------------------------------------------------------------

```python
  1 | import logging
  2 | import os
  3 | import re
  4 | import threading
  5 | import warnings
  6 | from typing import Any, Literal, cast
  7 | 
  8 | import litellm
  9 | from anyio.streams.memory import MemoryObjectSendStream
 10 | from asyncer import syncify
 11 | 
 12 | import dspy
 13 | from dspy.clients.cache import request_cache
 14 | from dspy.clients.openai import OpenAIProvider
 15 | from dspy.clients.provider import Provider, ReinforceJob, TrainingJob
 16 | from dspy.clients.utils_finetune import MultiGPUConfig, TrainDataFormat
 17 | from dspy.dsp.utils.settings import settings
 18 | from dspy.utils.callback import BaseCallback
 19 | 
 20 | from .base_lm import BaseLM
 21 | 
 22 | logger = logging.getLogger(__name__)
 23 | 
 24 | 
 25 | class LM(BaseLM):
 26 |     """
 27 |     A language model supporting chat or text completion requests for use with DSPy modules.
 28 |     """
 29 | 
 30 |     def __init__(
 31 |         self,
 32 |         model: str,
 33 |         model_type: Literal["chat", "text", "responses"] = "chat",
 34 |         temperature: float | None = None,
 35 |         max_tokens: int | None = None,
 36 |         cache: bool = True,
 37 |         callbacks: list[BaseCallback] | None = None,
 38 |         num_retries: int = 3,
 39 |         provider: Provider | None = None,
 40 |         finetuning_model: str | None = None,
 41 |         launch_kwargs: dict[str, Any] | None = None,
 42 |         train_kwargs: dict[str, Any] | None = None,
 43 |         use_developer_role: bool = False,
 44 |         **kwargs,
 45 |     ):
 46 |         """
 47 |         Create a new language model instance for use with DSPy modules and programs.
 48 | 
 49 |         Args:
 50 |             model: The model to use. This should be a string of the form ``"llm_provider/llm_name"``
 51 |                    supported by LiteLLM. For example, ``"openai/gpt-4o"``.
 52 |             model_type: The type of the model, either ``"chat"``, ``"text"``, or ``"responses"``.
 53 |             temperature: The sampling temperature to use when generating responses.
 54 |             max_tokens: The maximum number of tokens to generate per response.
 55 |             cache: Whether to cache the model responses for reuse to improve performance
 56 |                    and reduce costs.
 57 |             callbacks: A list of callback functions to run before and after each request.
 58 |             num_retries: The number of times to retry a request if it fails transiently due to
 59 |                          network error, rate limiting, etc. Requests are retried with exponential
 60 |                          backoff.
 61 |             provider: The provider to use. If not specified, the provider will be inferred from the model.
 62 |             finetuning_model: The model to finetune. In some providers, the models available for finetuning are different
 63 |                 from the models available for inference.
 64 |             rollout_id: Optional integer, accepted via ``**kwargs``, used to differentiate cache entries for otherwise
 65 |                 identical requests. Different values bypass DSPy's caches while still caching
 66 |                 future calls with the same inputs and rollout ID. Note that `rollout_id`
 67 |                 only affects generation when `temperature` is non-zero. This argument is
 68 |                 stripped before sending requests to the provider.
 69 |         """
 70 |         # Remember to update LM.copy() if you modify the constructor!
 71 |         self.model = model
 72 |         self.model_type = model_type
 73 |         self.cache = cache
 74 |         self.provider = provider or self.infer_provider()
 75 |         self.callbacks = callbacks or []
 76 |         self.history = []
 77 |         self.num_retries = num_retries
 78 |         self.finetuning_model = finetuning_model
 79 |         self.launch_kwargs = launch_kwargs or {}
 80 |         self.train_kwargs = train_kwargs or {}
 81 |         self.use_developer_role = use_developer_role
 82 |         self._warned_zero_temp_rollout = False
 83 | 
 84 |         # Handle model-specific configuration for different model families
 85 |         model_family = model.split("/")[-1].lower() if "/" in model else model.lower()
 86 | 
 87 |         # Recognize OpenAI reasoning models (o1, o3, o4, gpt-5 family)
 88 |         model_pattern = re.match(r"^(?:o[1345]|gpt-5)(?:-(?:mini|nano))?", model_family)
 89 | 
 90 |         if model_pattern:
 91 | 
 92 |             if (temperature and temperature != 1.0) or (max_tokens and max_tokens < 16000):
 93 |                 raise ValueError(
 94 |                     "OpenAI's reasoning models require passing temperature=1.0 or None and max_tokens >= 16000 or None to "
 95 |                     "`dspy.LM(...)`, e.g., dspy.LM('openai/gpt-5', temperature=1.0, max_tokens=16000)"
 96 |                 )
 97 |             self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs)
 98 |             if self.kwargs.get("rollout_id") is None:
 99 |                 self.kwargs.pop("rollout_id", None)
100 |         else:
101 |             self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
102 |             if self.kwargs.get("rollout_id") is None:
103 |                 self.kwargs.pop("rollout_id", None)
104 | 
105 |         self._warn_zero_temp_rollout(self.kwargs.get("temperature"), self.kwargs.get("rollout_id"))
106 | 
107 |     def _warn_zero_temp_rollout(self, temperature: float | None, rollout_id):
108 |         if not self._warned_zero_temp_rollout and rollout_id is not None and (temperature is None or temperature == 0):
109 |             warnings.warn(
110 |                 "rollout_id has no effect when temperature=0; set temperature>0 to bypass the cache.",
111 |                 stacklevel=3,
112 |             )
113 |             self._warned_zero_temp_rollout = True
114 | 
115 |     def _get_cached_completion_fn(self, completion_fn, cache):
116 |         ignored_args_for_cache_key = ["api_key", "api_base", "base_url"]
117 |         if cache:
118 |             completion_fn = request_cache(
119 |                 cache_arg_name="request",
120 |                 ignored_args_for_cache_key=ignored_args_for_cache_key,
121 |             )(completion_fn)
122 | 
123 |         litellm_cache_args = {"no-cache": True, "no-store": True}
124 | 
125 |         return completion_fn, litellm_cache_args
126 | 
127 |     def forward(self, prompt=None, messages=None, **kwargs):
128 |         # Build the request.
129 |         kwargs = dict(kwargs)
130 |         cache = kwargs.pop("cache", self.cache)
131 | 
132 |         messages = messages or [{"role": "user", "content": prompt}]
133 |         if self.use_developer_role and self.model_type == "responses":
134 |             messages = [{**m, "role": "developer"} if m.get("role") == "system" else m for m in messages]
135 |         kwargs = {**self.kwargs, **kwargs}
136 |         self._warn_zero_temp_rollout(kwargs.get("temperature"), kwargs.get("rollout_id"))
137 |         if kwargs.get("rollout_id") is None:
138 |             kwargs.pop("rollout_id", None)
139 | 
140 |         if self.model_type == "chat":
141 |             completion = litellm_completion
142 |         elif self.model_type == "text":
143 |             completion = litellm_text_completion
144 |         elif self.model_type == "responses":
145 |             completion = litellm_responses_completion
146 |         completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache)
147 | 
148 |         results = completion(
149 |             request=dict(model=self.model, messages=messages, **kwargs),
150 |             num_retries=self.num_retries,
151 |             cache=litellm_cache_args,
152 |         )
153 | 
154 |         self._check_truncation(results)
155 | 
156 |         if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
157 |             settings.usage_tracker.add_usage(self.model, dict(results.usage))
158 |         return results
159 | 
160 |     async def aforward(self, prompt=None, messages=None, **kwargs):
161 |         # Build the request.
162 |         kwargs = dict(kwargs)
163 |         cache = kwargs.pop("cache", self.cache)
164 | 
165 |         messages = messages or [{"role": "user", "content": prompt}]
166 |         if self.use_developer_role and self.model_type == "responses":
167 |             messages = [{**m, "role": "developer"} if m.get("role") == "system" else m for m in messages]
168 |         kwargs = {**self.kwargs, **kwargs}
169 |         self._warn_zero_temp_rollout(kwargs.get("temperature"), kwargs.get("rollout_id"))
170 |         if kwargs.get("rollout_id") is None:
171 |             kwargs.pop("rollout_id", None)
172 | 
173 |         if self.model_type == "chat":
174 |             completion = alitellm_completion
175 |         elif self.model_type == "text":
176 |             completion = alitellm_text_completion
177 |         elif self.model_type == "responses":
178 |             completion = alitellm_responses_completion
179 |         completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache)
180 | 
181 |         results = await completion(
182 |             request=dict(model=self.model, messages=messages, **kwargs),
183 |             num_retries=self.num_retries,
184 |             cache=litellm_cache_args,
185 |         )
186 | 
187 |         self._check_truncation(results)
188 | 
189 |         if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
190 |             settings.usage_tracker.add_usage(self.model, dict(results.usage))
191 |         return results
192 | 
193 |     def launch(self, launch_kwargs: dict[str, Any] | None = None):
194 |         self.provider.launch(self, launch_kwargs)
195 | 
196 |     def kill(self, launch_kwargs: dict[str, Any] | None = None):
197 |         self.provider.kill(self, launch_kwargs)
198 | 
199 |     def finetune(
200 |         self,
201 |         train_data: list[dict[str, Any]],
202 |         train_data_format: TrainDataFormat | None,
203 |         train_kwargs: dict[str, Any] | None = None,
204 |     ) -> TrainingJob:
205 |         from dspy import settings as settings
206 | 
207 |         if not self.provider.finetunable:
208 |             raise ValueError(
209 |                 f"Provider {self.provider} does not support fine-tuning, please specify your provider by explicitly "
210 |                 "setting `provider` when creating the `dspy.LM` instance. For example, "
211 |                 "`dspy.LM('openai/gpt-4.1-mini-2025-04-14', provider=dspy.OpenAIProvider())`."
212 |             )
213 | 
214 |         def thread_function_wrapper():
215 |             return self._run_finetune_job(job)
216 | 
217 |         thread = threading.Thread(target=thread_function_wrapper)
218 |         train_kwargs = train_kwargs or self.train_kwargs
219 |         model_to_finetune = self.finetuning_model or self.model
220 |         job = self.provider.TrainingJob(
221 |             thread=thread,
222 |             model=model_to_finetune,
223 |             train_data=train_data,
224 |             train_data_format=train_data_format,
225 |             train_kwargs=train_kwargs,
226 |         )
227 |         thread.start()
228 | 
229 |         return job
230 | 
231 |     def reinforce(
232 |         self, train_kwargs, gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)
233 |     ) -> ReinforceJob:
234 |         # TODO(GRPO Team): Should we return an initialized job here?
235 |         from dspy import settings as settings
236 | 
237 |         err = f"Provider {self.provider} does not implement the reinforcement learning interface."
238 |         assert self.provider.reinforceable, err
239 | 
240 |         job = self.provider.ReinforceJob(lm=self, train_kwargs=train_kwargs, gpu_config=gpu_config)
241 |         job.initialize()
242 |         return job
243 | 
244 |     def _run_finetune_job(self, job: TrainingJob):
245 |         # TODO(enhance): We should listen for keyboard interrupts somewhere.
246 |         # Requires TrainingJob.cancel() to be implemented for each provider.
247 |         try:
248 |             model = self.provider.finetune(
249 |                 job=job,
250 |                 model=job.model,
251 |                 train_data=job.train_data,
252 |                 train_data_format=job.train_data_format,
253 |                 train_kwargs=job.train_kwargs,
254 |             )
255 |             lm = self.copy(model=model)
256 |             job.set_result(lm)
257 |         except Exception as err:
258 |             logger.error(err)
259 |             job.set_result(err)
260 | 
261 |     def infer_provider(self) -> Provider:
262 |         if OpenAIProvider.is_provider_model(self.model):
263 |             return OpenAIProvider()
264 |         return Provider()
265 | 
266 |     def dump_state(self):
267 |         state_keys = [
268 |             "model",
269 |             "model_type",
270 |             "cache",
271 |             "num_retries",
272 |             "finetuning_model",
273 |             "launch_kwargs",
274 |             "train_kwargs",
275 |         ]
276 |         return {key: getattr(self, key) for key in state_keys} | self.kwargs
277 | 
278 |     def _check_truncation(self, results):
279 |         if self.model_type != "responses" and any(c.finish_reason == "length" for c in results["choices"]):
280 |             logger.warning(
281 |                 f"LM response was truncated due to exceeding max_tokens={self.kwargs['max_tokens']}. "
282 |                 "You can inspect the latest LM interactions with `dspy.inspect_history()`. "
283 |                 "To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
284 |                 f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
285 |                 "if the reason for truncation is repetition."
286 |             )
287 | 
288 | 
289 | def _get_stream_completion_fn(
290 |     request: dict[str, Any],
291 |     cache_kwargs: dict[str, Any],
292 |     sync=True,
293 | ):
294 |     stream = dspy.settings.send_stream
295 |     caller_predict = dspy.settings.caller_predict
296 | 
297 |     if stream is None:
298 |         return None
299 | 
300 |     # The stream is already opened, and will be closed by the caller.
301 |     stream = cast(MemoryObjectSendStream, stream)
302 |     caller_predict_id = id(caller_predict) if caller_predict else None
303 | 
304 |     if dspy.settings.track_usage:
305 |         request["stream_options"] = {"include_usage": True}
306 | 
307 |     async def stream_completion(request: dict[str, Any], cache_kwargs: dict[str, Any]):
308 |         headers = request.pop("headers", None)
309 |         response = await litellm.acompletion(
310 |             cache=cache_kwargs,
311 |             stream=True,
312 |             headers=_get_headers(headers),
313 |             **request,
314 |         )
315 |         chunks = []
316 |         async for chunk in response:
317 |             if caller_predict_id:
318 |                 # Add the predict id to the chunk so that the stream listener can identify which predict produces it.
319 |                 chunk.predict_id = caller_predict_id
320 |             chunks.append(chunk)
321 |             await stream.send(chunk)
322 |         return litellm.stream_chunk_builder(chunks)
323 | 
324 |     def sync_stream_completion():
325 |         syncified_stream_completion = syncify(stream_completion)
326 |         return syncified_stream_completion(request, cache_kwargs)
327 | 
328 |     async def async_stream_completion():
329 |         return await stream_completion(request, cache_kwargs)
330 | 
331 |     if sync:
332 |         return sync_stream_completion
333 |     else:
334 |         return async_stream_completion
335 | 
336 | 
337 | def litellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
338 |     cache = cache or {"no-cache": True, "no-store": True}
339 |     request = dict(request)
340 |     request.pop("rollout_id", None)
341 |     headers = request.pop("headers", None)
342 |     stream_completion = _get_stream_completion_fn(request, cache, sync=True)
343 |     if stream_completion is None:
344 |         return litellm.completion(
345 |             cache=cache,
346 |             num_retries=num_retries,
347 |             retry_strategy="exponential_backoff_retry",
348 |             headers=_get_headers(headers),
349 |             **request,
350 |         )
351 | 
352 |     return stream_completion()
353 | 
354 | 
355 | def litellm_text_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
356 |     cache = cache or {"no-cache": True, "no-store": True}
357 |     request = dict(request)
358 |     request.pop("rollout_id", None)
359 |     headers = request.pop("headers", None)
360 |     # Extract the provider and model from the model string.
361 |     # TODO: Not all the models are in the format of "provider/model"
362 |     model = request.pop("model").split("/", 1)
363 |     provider, model = model[0] if len(model) > 1 else "openai", model[-1]
364 | 
365 |     # Use the API key and base from the request, or from the environment.
366 |     api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
367 |     api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
368 | 
369 |     # Build the prompt from the messages.
370 |     prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])
371 | 
372 |     return litellm.text_completion(
373 |         cache=cache,
374 |         model=f"text-completion-openai/{model}",
375 |         api_key=api_key,
376 |         api_base=api_base,
377 |         prompt=prompt,
378 |         num_retries=num_retries,
379 |         retry_strategy="exponential_backoff_retry",
380 |         headers=_get_headers(headers),
381 |         **request,
382 |     )
383 | 
384 | 
385 | async def alitellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
386 |     cache = cache or {"no-cache": True, "no-store": True}
387 |     request = dict(request)
388 |     request.pop("rollout_id", None)
389 |     headers = request.pop("headers", None)
390 |     stream_completion = _get_stream_completion_fn(request, cache, sync=False)
391 |     if stream_completion is None:
392 |         return await litellm.acompletion(
393 |             cache=cache,
394 |             num_retries=num_retries,
395 |             retry_strategy="exponential_backoff_retry",
396 |             headers=_get_headers(headers),
397 |             **request,
398 |         )
399 | 
400 |     return await stream_completion()
401 | 
402 | 
403 | async def alitellm_text_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
404 |     cache = cache or {"no-cache": True, "no-store": True}
405 |     request = dict(request)
406 |     request.pop("rollout_id", None)
407 |     model = request.pop("model").split("/", 1)
408 |     headers = request.pop("headers", None)
409 |     provider, model = model[0] if len(model) > 1 else "openai", model[-1]
410 | 
411 |     # Use the API key and base from the request, or from the environment.
412 |     api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
413 |     api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
414 | 
415 |     # Build the prompt from the messages.
416 |     prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])
417 | 
418 |     return await litellm.atext_completion(
419 |         cache=cache,
420 |         model=f"text-completion-openai/{model}",
421 |         api_key=api_key,
422 |         api_base=api_base,
423 |         prompt=prompt,
424 |         num_retries=num_retries,
425 |         retry_strategy="exponential_backoff_retry",
426 |         headers=_get_headers(headers),
427 |         **request,
428 |     )
429 | 
430 | 
431 | def litellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
432 |     cache = cache or {"no-cache": True, "no-store": True}
433 |     request = dict(request)
434 |     request.pop("rollout_id", None)
435 |     headers = request.pop("headers", None)
436 |     request = _convert_chat_request_to_responses_request(request)
437 | 
438 |     return litellm.responses(
439 |         cache=cache,
440 |         num_retries=num_retries,
441 |         retry_strategy="exponential_backoff_retry",
442 |         headers=_get_headers(headers),
443 |         **request,
444 |     )
445 | 
446 | 
447 | async def alitellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
448 |     cache = cache or {"no-cache": True, "no-store": True}
449 |     request = dict(request)
450 |     request.pop("rollout_id", None)
451 |     headers = request.pop("headers", None)
452 |     request = _convert_chat_request_to_responses_request(request)
453 | 
454 |     return await litellm.aresponses(
455 |         cache=cache,
456 |         num_retries=num_retries,
457 |         retry_strategy="exponential_backoff_retry",
458 |         headers=_get_headers(headers),
459 |         **request,
460 |     )
461 | 
462 | 
463 | def _convert_chat_request_to_responses_request(request: dict[str, Any]):
464 |     request = dict(request)
465 |     if "messages" in request:
466 |         content_blocks = []
467 |         for msg in request.pop("messages"):
468 |             c = msg.get("content")
469 |             if isinstance(c, str):
470 |                 content_blocks.append({"type": "input_text", "text": c})
471 |             elif isinstance(c, list):
472 |                 content_blocks.extend(c)
473 |         request["input"] = [{"role": msg.get("role", "user"), "content": content_blocks}]
474 | 
475 |     # Convert `response_format` to `text.format` for Responses API
476 |     if "response_format" in request:
477 |         response_format = request.pop("response_format")
478 |         text = request.pop("text", {})
479 |         request["text"] = {**text, "format": response_format}
480 | 
481 |     return request
482 | 
483 | def _get_headers(headers: dict[str, Any] | None = None):
484 |     headers = headers or {}
485 |     return {
486 |         "User-Agent": f"DSPy/{dspy.__version__}",
487 |         **headers,
488 |     }
489 | 
```
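
A minimal usage sketch of the `LM` class defined above, assuming `dspy` is installed and an `OPENAI_API_KEY` is available in the environment; the model name and prompts are illustrative rather than taken from this file:

```python
import dspy

# Construct an LM as described in the constructor docstring above:
# a LiteLLM-style "llm_provider/llm_name" string plus optional sampling settings.
lm = dspy.LM("openai/gpt-4o-mini", temperature=0.7, max_tokens=1000)

# Calling the LM returns a list of completion strings (see the tests below).
print(lm("Say hello in one word."))

# Responses are cached by default. With a non-zero temperature, passing a
# different rollout_id bypasses the cache while still caching the new call;
# rollout_id itself is stripped before the request reaches the provider.
print(lm("Say hello in one word.", rollout_id=1))
```

Passing `cache=False` at construction time, as several tests below do, skips DSPy's request cache entirely for that instance.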
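
For `model_type="responses"`, chat-style requests are rewritten by `_convert_chat_request_to_responses_request` before being handed to `litellm.responses`. The sketch below, using a hypothetical request dict, illustrates the shape of that conversion:

```python
from dspy.clients.lm import _convert_chat_request_to_responses_request

# Hypothetical chat-style request; only the keys shown here are assumed.
chat_request = {
    "model": "openai/gpt-5",
    "messages": [{"role": "user", "content": "Summarize DSPy in one line."}],
    "response_format": {"type": "json_object"},
}

converted = _convert_chat_request_to_responses_request(chat_request)
# converted["input"] == [{"role": "user",
#                         "content": [{"type": "input_text",
#                                      "text": "Summarize DSPy in one line."}]}]
# converted["text"] == {"format": {"type": "json_object"}}
# "messages" and "response_format" are dropped; "model" passes through unchanged.
```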

--------------------------------------------------------------------------------
/tests/clients/test_lm.py:
--------------------------------------------------------------------------------

```python
  1 | import json
  2 | import time
  3 | import warnings
  4 | from unittest import mock
  5 | from unittest.mock import patch
  6 | 
  7 | import litellm
  8 | import pydantic
  9 | import pytest
 10 | from litellm.types.llms.openai import ResponseAPIUsage, ResponsesAPIResponse
 11 | from litellm.utils import Choices, Message, ModelResponse
 12 | from openai import RateLimitError
 13 | from openai.types.responses import ResponseOutputMessage, ResponseReasoningItem
 14 | from openai.types.responses.response_reasoning_item import Summary
 15 | 
 16 | import dspy
 17 | from dspy.utils.dummies import DummyLM
 18 | from dspy.utils.usage_tracker import track_usage
 19 | 
 20 | 
 21 | def make_response(output_blocks):
 22 |     return ResponsesAPIResponse(
 23 |         id="resp_1",
 24 |         created_at=0.0,
 25 |         error=None,
 26 |         incomplete_details=None,
 27 |         instructions=None,
 28 |         model="openai/dspy-test-model",
 29 |         object="response",
 30 |         output=output_blocks,
 31 |         metadata={},
 32 |         parallel_tool_calls=False,
 33 |         temperature=1.0,
 34 |         tool_choice="auto",
 35 |         tools=[],
 36 |         top_p=1.0,
 37 |         max_output_tokens=None,
 38 |         previous_response_id=None,
 39 |         reasoning=None,
 40 |         status="completed",
 41 |         text=None,
 42 |         truncation="disabled",
 43 |         usage=ResponseAPIUsage(input_tokens=1, output_tokens=1, total_tokens=2),
 44 |         user=None,
 45 |     )
 46 | 
 47 | 
 48 | def test_chat_lms_can_be_queried(litellm_test_server):
 49 |     api_base, _ = litellm_test_server
 50 |     expected_response = ["Hi!"]
 51 | 
 52 |     openai_lm = dspy.LM(
 53 |         model="openai/dspy-test-model",
 54 |         api_base=api_base,
 55 |         api_key="fakekey",
 56 |         model_type="chat",
 57 |     )
 58 |     assert openai_lm("openai query") == expected_response
 59 | 
 60 |     azure_openai_lm = dspy.LM(
 61 |         model="azure/dspy-test-model",
 62 |         api_base=api_base,
 63 |         api_key="fakekey",
 64 |         model_type="chat",
 65 |     )
 66 |     assert azure_openai_lm("azure openai query") == expected_response
 67 | 
 68 | 
 69 | def test_dspy_cache(litellm_test_server, tmp_path):
 70 |     api_base, _ = litellm_test_server
 71 | 
 72 |     original_cache = dspy.cache
 73 |     dspy.clients.configure_cache(
 74 |         enable_disk_cache=True,
 75 |         enable_memory_cache=True,
 76 |         disk_cache_dir=tmp_path / ".disk_cache",
 77 |     )
 78 |     cache = dspy.cache
 79 | 
 80 |     lm = dspy.LM(
 81 |         model="openai/dspy-test-model",
 82 |         api_base=api_base,
 83 |         api_key="fakekey",
 84 |         model_type="text",
 85 |     )
 86 |     with track_usage() as usage_tracker:
 87 |         lm("Query")
 88 | 
 89 |     assert len(cache.memory_cache) == 1
 90 |     cache_key = next(iter(cache.memory_cache.keys()))
 91 |     assert cache_key in cache.disk_cache
 92 |     assert len(usage_tracker.usage_data) == 1
 93 | 
 94 |     with track_usage() as usage_tracker:
 95 |         lm("Query")
 96 | 
 97 |     assert len(usage_tracker.usage_data) == 0
 98 | 
 99 |     dspy.cache = original_cache
100 | 
101 | 
102 | def test_disabled_cache_skips_cache_key(monkeypatch):
103 |     original_cache = dspy.cache
104 |     dspy.configure_cache(enable_disk_cache=False, enable_memory_cache=False)
105 |     cache = dspy.cache
106 | 
107 |     try:
108 |         with mock.patch.object(cache, "cache_key", wraps=cache.cache_key) as cache_key_spy, \
109 |              mock.patch.object(cache, "get", wraps=cache.get) as cache_get_spy, \
110 |              mock.patch.object(cache, "put", wraps=cache.put) as cache_put_spy:
111 | 
112 |             def fake_completion(*, cache, num_retries, retry_strategy, **request):
113 |                 return ModelResponse(
114 |                     choices=[Choices(message=Message(role="assistant", content="Hi!"))],
115 |                     usage={"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
116 |                     model="dummy",
117 |                 )
118 | 
119 |             monkeypatch.setattr(litellm, "completion", fake_completion)
120 | 
121 |             dummy_lm = DummyLM([{"answer": "ignored"}])
122 |             # TODO(isaacbmiller): Change from dummy_lm.forward to just dummy_lm.__call__ #8864
123 |             dummy_lm.forward(messages=[{"role": "user", "content": "Hello"}])
124 | 
125 |             cache_key_spy.assert_not_called()
126 |             cache_get_spy.assert_called_once()
127 |             cache_put_spy.assert_called_once()
128 |     finally:
129 |         dspy.cache = original_cache
130 | 
131 | 
132 | def test_rollout_id_bypasses_cache(monkeypatch, tmp_path):
133 |     calls: list[dict] = []
134 | 
135 |     def fake_completion(*, cache, num_retries, retry_strategy, **request):
136 |         calls.append(request)
137 |         return ModelResponse(
138 |             choices=[Choices(message=Message(role="assistant", content="Hi!"))],
139 |             usage={"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
140 |             model="openai/dspy-test-model",
141 |         )
142 | 
143 |     monkeypatch.setattr(litellm, "completion", fake_completion)
144 | 
145 |     original_cache = dspy.cache
146 |     dspy.clients.configure_cache(
147 |         enable_disk_cache=True,
148 |         enable_memory_cache=True,
149 |         disk_cache_dir=tmp_path / ".disk_cache",
150 |     )
151 | 
152 |     lm = dspy.LM(model="openai/dspy-test-model", model_type="chat")
153 | 
154 |     with track_usage() as usage_tracker:
155 |         lm(messages=[{"role": "user", "content": "Query"}], rollout_id=1)
156 |     assert len(usage_tracker.usage_data) == 1
157 | 
158 |     with track_usage() as usage_tracker:
159 |         lm(messages=[{"role": "user", "content": "Query"}], rollout_id=1)
160 |     assert len(usage_tracker.usage_data) == 0
161 | 
162 |     with track_usage() as usage_tracker:
163 |         lm(messages=[{"role": "user", "content": "Query"}], rollout_id=2)
164 |     assert len(usage_tracker.usage_data) == 1
165 | 
166 |     with track_usage() as usage_tracker:
167 |         lm(messages=[{"role": "user", "content": "NoRID"}])
168 |     assert len(usage_tracker.usage_data) == 1
169 | 
170 |     with track_usage() as usage_tracker:
171 |         lm(messages=[{"role": "user", "content": "NoRID"}], rollout_id=None)
172 |     assert len(usage_tracker.usage_data) == 0
173 | 
174 |     assert len(dspy.cache.memory_cache) == 3
175 |     assert all("rollout_id" not in r for r in calls)
176 |     dspy.cache = original_cache
177 | 
178 | 
179 | def test_zero_temperature_rollout_warns_once(monkeypatch):
180 |     def fake_completion(*, cache, num_retries, retry_strategy, **request):
181 |         return ModelResponse(
182 |             choices=[Choices(message=Message(role="assistant", content="Hi!"))],
183 |             usage={"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
184 |             model="openai/dspy-test-model",
185 |         )
186 | 
187 |     monkeypatch.setattr(litellm, "completion", fake_completion)
188 | 
189 |     lm = dspy.LM(model="openai/dspy-test-model", model_type="chat")
190 |     with pytest.warns(UserWarning, match="rollout_id has no effect"):
191 |         lm("Query", rollout_id=1)
192 |     with warnings.catch_warnings(record=True) as record:
193 |         warnings.simplefilter("always")
194 |         lm("Query", rollout_id=2)
195 |         assert len(record) == 0
196 | 
197 | 
198 | def test_text_lms_can_be_queried(litellm_test_server):
199 |     api_base, _ = litellm_test_server
200 |     expected_response = ["Hi!"]
201 | 
202 |     openai_lm = dspy.LM(
203 |         model="openai/dspy-test-model",
204 |         api_base=api_base,
205 |         api_key="fakekey",
206 |         model_type="text",
207 |     )
208 |     assert openai_lm("openai query") == expected_response
209 | 
210 |     azure_openai_lm = dspy.LM(
211 |         model="azure/dspy-test-model",
212 |         api_base=api_base,
213 |         api_key="fakekey",
214 |         model_type="text",
215 |     )
216 |     assert azure_openai_lm("azure openai query") == expected_response
217 | 
218 | 
219 | def test_lm_calls_support_callables(litellm_test_server):
220 |     api_base, _ = litellm_test_server
221 | 
222 |     with mock.patch("litellm.completion", autospec=True, wraps=litellm.completion) as spy_completion:
223 | 
224 |         def azure_ad_token_provider(*args, **kwargs):
225 |             return None
226 | 
227 |         lm_with_callable = dspy.LM(
228 |             model="openai/dspy-test-model",
229 |             api_base=api_base,
230 |             api_key="fakekey",
231 |             azure_ad_token_provider=azure_ad_token_provider,
232 |             cache=False,
233 |         )
234 | 
235 |         lm_with_callable("Query")
236 | 
237 |         spy_completion.assert_called_once()
238 |         call_args = spy_completion.call_args.kwargs
239 |         assert call_args["model"] == "openai/dspy-test-model"
240 |         assert call_args["api_base"] == api_base
241 |         assert call_args["api_key"] == "fakekey"
242 |         assert call_args["azure_ad_token_provider"] is azure_ad_token_provider
243 | 
244 | 
245 | def test_lm_calls_support_pydantic_models(litellm_test_server):
246 |     api_base, _ = litellm_test_server
247 | 
248 |     class ResponseFormat(pydantic.BaseModel):
249 |         response: str
250 | 
251 |     lm = dspy.LM(
252 |         model="openai/dspy-test-model",
253 |         api_base=api_base,
254 |         api_key="fakekey",
255 |         response_format=ResponseFormat,
256 |     )
257 |     lm("Query")
258 | 
259 | 
260 | def test_retry_number_set_correctly():
261 |     lm = dspy.LM("openai/gpt-4o-mini", num_retries=3)
262 |     with mock.patch("litellm.completion") as mock_completion:
263 |         lm("query")
264 | 
265 |     assert mock_completion.call_args.kwargs["num_retries"] == 3
266 | 
267 | 
268 | def test_retry_made_on_system_errors():
269 |     retry_tracking = [0]  # Using a list to track retries
270 | 
271 |     def mock_create(*args, **kwargs):
272 |         retry_tracking[0] += 1
273 |         # These fields are called during the error handling
274 |         mock_response = mock.Mock()
275 |         mock_response.headers = {}
276 |         mock_response.status_code = 429
277 |         raise RateLimitError(response=mock_response, message="message", body="error")
278 | 
279 |     lm = dspy.LM(model="openai/gpt-4o-mini", max_tokens=250, num_retries=3)
280 |     with mock.patch.object(litellm.OpenAIChatCompletion, "completion", side_effect=mock_create):
281 |         with pytest.raises(RateLimitError):
282 |             lm("question")
283 | 
284 |     assert retry_tracking[0] == 4
285 | 
286 | 
287 | def test_reasoning_model_token_parameter():
288 |     test_cases = [
289 |         ("openai/o1", True),
290 |         ("openai/o1-mini", True),
291 |         ("openai/o1-2023-01-01", True),
292 |         ("openai/o3", True),
293 |         ("openai/o3-mini-2023-01-01", True),
294 |         ("openai/gpt-5", True),
295 |         ("openai/gpt-5-mini", True),
296 |         ("openai/gpt-5-nano", True),
297 |         ("openai/gpt-4", False),
298 |         ("anthropic/claude-2", False),
299 |     ]
300 | 
301 |     for model_name, is_reasoning_model in test_cases:
302 |         lm = dspy.LM(
303 |             model=model_name,
304 |             temperature=1.0 if is_reasoning_model else 0.7,
305 |             max_tokens=16_000 if is_reasoning_model else 1000,
306 |         )
307 |         if is_reasoning_model:
308 |             assert "max_completion_tokens" in lm.kwargs
309 |             assert "max_tokens" not in lm.kwargs
310 |             assert lm.kwargs["max_completion_tokens"] == 16_000
311 |         else:
312 |             assert "max_completion_tokens" not in lm.kwargs
313 |             assert "max_tokens" in lm.kwargs
314 |             assert lm.kwargs["max_tokens"] == 1000
315 | 
316 | @pytest.mark.parametrize("model_name", ["openai/o1", "openai/gpt-5-nano"])
317 | def test_reasoning_model_requirements(model_name):
318 |     # Should raise a ValueError if the temperature or max_tokens requirements are not met
319 |     with pytest.raises(
320 |         ValueError,
321 |         match="reasoning models require passing temperature=1.0 or None and max_tokens >= 16000 or None",
322 |     ):
323 |         dspy.LM(
324 |             model=model_name,
325 |             temperature=0.7,  # Should be 1.0
326 |             max_tokens=1000,  # Should be >= 16_000
327 |         )
328 | 
329 |     # Should pass with correct parameters
330 |     lm = dspy.LM(
331 |         model=model_name,
332 |         temperature=1.0,
333 |         max_tokens=16_000,
334 |     )
335 |     assert lm.kwargs["max_completion_tokens"] == 16_000
336 | 
337 |     # Should pass with no parameters
338 |     lm = dspy.LM(
339 |         model=model_name,
340 |     )
341 |     assert lm.kwargs["temperature"] is None
342 |     assert lm.kwargs["max_completion_tokens"] is None
343 | 
344 | 
345 | def test_dump_state():
346 |     lm = dspy.LM(
347 |         model="openai/gpt-4o-mini",
348 |         model_type="chat",
349 |         temperature=1,
350 |         max_tokens=100,
351 |         num_retries=10,
352 |         launch_kwargs={"temperature": 1},
353 |         train_kwargs={"temperature": 5},
354 |     )
355 | 
356 |     assert lm.dump_state() == {
357 |         "model": "openai/gpt-4o-mini",
358 |         "model_type": "chat",
359 |         "temperature": 1,
360 |         "max_tokens": 100,
361 |         "num_retries": 10,
362 |         "cache": True,
363 |         "finetuning_model": None,
364 |         "launch_kwargs": {"temperature": 1},
365 |         "train_kwargs": {"temperature": 5},
366 |     }
367 | 
368 | 
369 | def test_exponential_backoff_retry():
370 |     time_counter = []
371 | 
372 |     def mock_create(*args, **kwargs):
373 |         time_counter.append(time.time())
374 |         # These fields are called during the error handling
375 |         mock_response = mock.Mock()
376 |         mock_response.headers = {}
377 |         mock_response.status_code = 429
378 |         raise RateLimitError(response=mock_response, message="message", body="error")
379 | 
380 |     lm = dspy.LM(model="openai/gpt-3.5-turbo", max_tokens=250, num_retries=3)
381 |     with mock.patch.object(litellm.OpenAIChatCompletion, "completion", side_effect=mock_create):
382 |         with pytest.raises(RateLimitError):
383 |             lm("question")
384 | 
385 |     # The first retry happens immediately regardless of the configuration
386 |     for i in range(1, len(time_counter) - 1):
387 |         assert time_counter[i + 1] - time_counter[i] >= 2 ** (i - 1)
388 | 
389 | 
390 | def test_logprobs_included_when_requested():
391 |     lm = dspy.LM(model="dspy-test-model", logprobs=True, cache=False)
392 |     with mock.patch("litellm.completion") as mock_completion:
393 |         mock_completion.return_value = ModelResponse(
394 |             choices=[
395 |                 Choices(
396 |                     message=Message(content="test answer"),
397 |                     logprobs={
398 |                         "content": [
399 |                             {"token": "test", "logprob": 0.1, "top_logprobs": [{"token": "test", "logprob": 0.1}]},
400 |                             {"token": "answer", "logprob": 0.2, "top_logprobs": [{"token": "answer", "logprob": 0.2}]},
401 |                         ]
402 |                     },
403 |                 )
404 |             ],
405 |             model="dspy-test-model",
406 |         )
407 |         result = lm("question")
408 |         assert result[0]["text"] == "test answer"
409 |         assert result[0]["logprobs"].model_dump() == {
410 |             "content": [
411 |                 {
412 |                     "token": "test",
413 |                     "bytes": None,
414 |                     "logprob": 0.1,
415 |                     "top_logprobs": [{"token": "test", "bytes": None, "logprob": 0.1}],
416 |                 },
417 |                 {
418 |                     "token": "answer",
419 |                     "bytes": None,
420 |                     "logprob": 0.2,
421 |                     "top_logprobs": [{"token": "answer", "bytes": None, "logprob": 0.2}],
422 |                 },
423 |             ]
424 |         }
425 |         assert mock_completion.call_args.kwargs["logprobs"]
426 | 
427 | 
428 | @pytest.mark.asyncio
429 | async def test_async_lm_call():
430 |     from litellm.utils import Choices, Message, ModelResponse
431 | 
432 |     mock_response = ModelResponse(choices=[Choices(message=Message(content="answer"))], model="openai/gpt-4o-mini")
433 | 
434 |     with patch("litellm.acompletion") as mock_acompletion:
435 |         mock_acompletion.return_value = mock_response
436 | 
437 |         lm = dspy.LM(model="openai/gpt-4o-mini", cache=False)
438 |         result = await lm.acall("question")
439 | 
440 |         assert result == ["answer"]
441 |         mock_acompletion.assert_called_once()
442 | 
443 | 
444 | @pytest.mark.asyncio
445 | async def test_async_lm_call_with_cache(tmp_path):
446 |     """Test the async LM call with caching."""
447 |     original_cache = dspy.cache
448 |     dspy.clients.configure_cache(
449 |         enable_disk_cache=True,
450 |         enable_memory_cache=True,
451 |         disk_cache_dir=tmp_path / ".disk_cache",
452 |     )
453 |     cache = dspy.cache
454 | 
455 |     lm = dspy.LM(model="openai/gpt-4o-mini")
456 | 
457 |     with mock.patch("dspy.clients.lm.alitellm_completion") as mock_alitellm_completion:
458 |         mock_alitellm_completion.return_value = ModelResponse(
459 |             choices=[Choices(message=Message(content="answer"))], model="openai/gpt-4o-mini"
460 |         )
461 |         mock_alitellm_completion.__qualname__ = "alitellm_completion"
462 |         await lm.acall("Query")
463 | 
464 |         assert len(cache.memory_cache) == 1
465 |         cache_key = next(iter(cache.memory_cache.keys()))
466 |         assert cache_key in cache.disk_cache
467 |         assert mock_alitellm_completion.call_count == 1
468 | 
469 |         await lm.acall("Query")
470 |         # Second call should hit the cache, so no new call to LiteLLM is made.
471 |         assert mock_alitellm_completion.call_count == 1
472 | 
473 |         # A new query should result in a new LiteLLM call and a new cache entry.
474 |         await lm.acall("New query")
475 | 
476 |         assert len(cache.memory_cache) == 2
477 |         assert mock_alitellm_completion.call_count == 2
478 | 
479 |     dspy.cache = original_cache
480 | 
481 | 
482 | def test_lm_history_size_limit():
483 |     lm = dspy.LM(model="openai/gpt-4o-mini")
484 |     with dspy.context(max_history_size=5):
485 |         with mock.patch("litellm.completion") as mock_completion:
486 |             mock_completion.return_value = ModelResponse(
487 |                 choices=[Choices(message=Message(content="test answer"))],
488 |                 model="openai/gpt-4o-mini",
489 |             )
490 | 
491 |             for _ in range(10):
492 |                 lm("query")
493 | 
494 |     assert len(lm.history) == 5
495 | 
496 | 
497 | def test_disable_history():
498 |     lm = dspy.LM(model="openai/gpt-4o-mini")
499 |     with dspy.context(disable_history=True):
500 |         with mock.patch("litellm.completion") as mock_completion:
501 |             mock_completion.return_value = ModelResponse(
502 |                 choices=[Choices(message=Message(content="test answer"))],
503 |                 model="openai/gpt-4o-mini",
504 |             )
505 |             for _ in range(10):
506 |                 lm("query")
507 | 
508 |     assert len(lm.history) == 0
509 | 
510 |     with dspy.context(disable_history=False):
511 |         with mock.patch("litellm.completion") as mock_completion:
512 |             mock_completion.return_value = ModelResponse(
513 |                 choices=[Choices(message=Message(content="test answer"))],
514 |                 model="openai/gpt-4o-mini",
515 |             )
516 | 
517 | def test_responses_api():
518 |     api_response = make_response(
519 |         output_blocks=[
520 |             ResponseOutputMessage(
521 |                 **{
522 |                     "id": "msg_1",
523 |                     "type": "message",
524 |                     "role": "assistant",
525 |                     "status": "completed",
526 |                     "content": [
527 |                         {"type": "output_text", "text": "This is a test answer from responses API.", "annotations": []}
528 |                     ],
529 |                 },
530 |             ),
531 |             ResponseReasoningItem(
532 |                 **{
533 |                     "id": "reasoning_1",
534 |                     "type": "reasoning",
535 |                     "summary": [Summary(**{"type": "summary_text", "text": "This is a dummy reasoning."})],
536 |                 },
537 |             ),
538 |         ]
539 |     )
540 | 
541 |     with mock.patch("litellm.responses", autospec=True, return_value=api_response) as dspy_responses:
542 |         lm = dspy.LM(
543 |             model="openai/gpt-5-mini",
544 |             model_type="responses",
545 |             cache=False,
546 |             temperature=1.0,
547 |             max_tokens=16000,
548 |         )
549 |         lm_result = lm("openai query")
550 | 
551 |         assert lm_result == [
552 |             {
553 |                 "text": "This is a test answer from responses API.",
554 |                 "reasoning_content": "This is a dummy reasoning.",
555 |             }
556 |         ]
557 | 
558 |         dspy_responses.assert_called_once()
559 |         assert dspy_responses.call_args.kwargs["model"] == "openai/gpt-5-mini"
560 | 
561 | 
562 | def test_lm_replaces_system_with_developer_role():
563 |     with mock.patch(
564 |         "dspy.clients.lm.litellm_responses_completion", return_value={"choices": []}
565 |     ) as mock_completion:
566 |         lm = dspy.LM(
567 |             "openai/gpt-4o-mini",
568 |             cache=False,
569 |             model_type="responses",
570 |             use_developer_role=True,
571 |         )
572 |         lm.forward(messages=[{"role": "system", "content": "hi"}])
573 |         assert (
574 |             mock_completion.call_args.kwargs["request"]["messages"][0]["role"]
575 |             == "developer"
576 |         )
577 | 
578 | 
579 | def test_responses_api_tool_calls(litellm_test_server):
580 |     api_base, _ = litellm_test_server
581 |     expected_tool_call = {
582 |         "type": "function_call",
583 |         "name": "get_weather",
584 |         "arguments": json.dumps({"city": "Paris"}),
585 |         "call_id": "call_1",
586 |         "status": "completed",
587 |         "id": "call_1",
588 |     }
589 |     expected_response = [{"tool_calls": [expected_tool_call]}]
590 | 
591 |     api_response = make_response(
592 |         output_blocks=[expected_tool_call],
593 |     )
594 | 
595 |     with mock.patch("litellm.responses", autospec=True, return_value=api_response) as dspy_responses:
596 |         lm = dspy.LM(
597 |             model="openai/dspy-test-model",
598 |             api_base=api_base,
599 |             api_key="fakekey",
600 |             model_type="responses",
601 |             cache=False,
602 |         )
603 |         assert lm("openai query") == expected_response
604 | 
605 |         dspy_responses.assert_called_once()
606 |         assert dspy_responses.call_args.kwargs["model"] == "openai/dspy-test-model"
607 | 
```

--------------------------------------------------------------------------------
/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md:
--------------------------------------------------------------------------------

```markdown
  1 | # dspy.GEPA - Advanced Features
  2 | 
  3 | ## Custom Instruction Proposers
  4 | 
  5 | ### What is instruction_proposer?
  6 | 
  7 | The `instruction_proposer` is the component responsible for invoking the `reflection_lm` and proposing new prompts during GEPA optimization. When GEPA identifies underperforming components in your DSPy program, the instruction proposer analyzes execution traces, feedback, and failures to generate improved instructions tailored to the observed issues.
  8 | 
  9 | ### Default Implementation
 10 | 
 11 | By default, GEPA uses the built-in instruction proposer from the [GEPA library](https://github.com/gepa-ai/gepa), which implements the [`ProposalFn`](https://github.com/gepa-ai/gepa/blob/main/src/gepa/core/adapter.py) protocol. The [default proposer](https://github.com/gepa-ai/gepa/blob/main/src/gepa/proposer/reflective_mutation/reflective_mutation.py#L53-L75) uses the following prompt template:
 12 | 
 13 | ````
 14 | I provided an assistant with the following instructions to perform a task for me:
 15 | ```
 16 | <curr_instructions>
 17 | ```
 18 | 
 19 | The following are examples of different task inputs provided to the assistant along with the assistant's response for each of them, and some feedback on how the assistant's response could be better:
 20 | ```
 21 | <inputs_outputs_feedback>
 22 | ```
 23 | 
 24 | Your task is to write a new instruction for the assistant.
 25 | 
 26 | Read the inputs carefully and identify the input format and infer detailed task description about the task I wish to solve with the assistant.
 27 | 
 28 | Read all the assistant responses and the corresponding feedback. Identify all niche and domain specific factual information about the task and include it in the instruction, as a lot of it may not be available to the assistant in the future. The assistant may have utilized a generalizable strategy to solve the task, if so, include that in the instruction as well.
 29 | 
 30 | Provide the new instructions within ``` blocks.
 31 | ````
 32 | 
 33 | This template is automatically filled with:
 34 | 
 35 | - `<curr_instructions>`: The current instruction being optimized
 36 | - `<inputs_outputs_feedback>`: Structured markdown containing predictor inputs, generated outputs, and evaluation feedback
 37 | 
 38 | Example of default behavior:
 39 | 
 40 | ```python
 41 | # Default instruction proposer is used automatically
 42 | gepa = dspy.GEPA(
 43 |     metric=my_metric,
 44 |     reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
 45 |     auto="medium"
 46 | )
 47 | optimized_program = gepa.compile(student, trainset=examples)
 48 | ```
 49 | 
 50 | ### When to Use Custom instruction_proposer
 51 | 
 52 | **Note:** Custom instruction proposers are an advanced feature. Most users should start with the default proposer, which works well for most text-based optimization tasks.
 53 | 
 54 | Consider implementing a custom instruction proposer when you need:
 55 | 
 56 | - **Multi-modal handling**: Process images (dspy.Image) alongside textual information in your inputs
 57 | - **Nuanced control on limits and length constraints**: Have more fine-grained control over instruction length, format, and structural requirements
 58 | - **Domain-specific information**: Inject specialized knowledge, terminology, or context that the default proposer lacks and that cannot be conveyed through the feedback function alone
 59 | - **Provider-specific prompting guides**: Optimize instructions for specific LLM providers (OpenAI, Anthropic, etc.) with their unique formatting preferences
 60 | - **Coupled component updates**: Handle situations where two or more components must be updated together in a coordinated manner rather than optimized independently (see the `component_selector` parameter in the [Custom Component Selection](#custom-component-selection) section for related functionality)
 61 | - **External knowledge integration**: Connect to databases, APIs, or knowledge bases during instruction generation
 62 | 
 63 | ### Available Options
 64 | 
 65 | **Built-in Options:**
 66 | 
 67 | - **Default Proposer**: The standard GEPA instruction proposer (used when `instruction_proposer=None`). It is the most general option and is the one used for the diverse experiments reported in the GEPA paper and tutorials.
 68 | - **MultiModalInstructionProposer**: Handles `dspy.Image` inputs and structured multimodal content.
 69 | 
 70 | ```python
 71 | from dspy.teleprompt.gepa.instruction_proposal import MultiModalInstructionProposer
 72 | 
 73 | # For tasks involving images or multimodal inputs
 74 | gepa = dspy.GEPA(
 75 |     metric=my_metric,
 76 |     reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
 77 |     instruction_proposer=MultiModalInstructionProposer(),
 78 |     auto="medium"
 79 | )
 80 | ```
 81 | 
 82 | We invite community contributions of new instruction proposers for specialized domains as the [GEPA library](https://github.com/gepa-ai/gepa) continues to grow.
 83 | 
 84 | ### How to Implement Custom Instruction Proposers
 85 | 
 86 | Custom instruction proposers must implement the `ProposalFn` protocol by defining a callable class or function. GEPA will call your proposer during optimization:
 87 | 
 88 | ```python
 89 | from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample
 90 | 
 91 | class CustomInstructionProposer:
 92 |     def __call__(
 93 |         self,
 94 |         candidate: dict[str, str],                          # Candidate component name -> instruction mapping to be updated in this round
 95 |         reflective_dataset: dict[str, list[ReflectiveExample]],  # Component -> examples with structure: {"Inputs": ..., "Generated Outputs": ..., "Feedback": ...}
 96 |         components_to_update: list[str]                     # Which components to improve
 97 |     ) -> dict[str, str]:                                    # Return new instruction mapping only for components being updated
 98 |         # Your custom instruction generation logic here
 99 |         return updated_instructions
100 | 
101 | # Or as a function:
102 | def custom_instruction_proposer(candidate, reflective_dataset, components_to_update):
103 |     # Your custom instruction generation logic here
104 |     return updated_instructions
105 | ```
106 | 
107 | **Reflective Dataset Structure** (illustrated in the sketch after this list):
108 | 
109 | - `dict[str, list[ReflectiveExample]]` - Maps component names to lists of examples
110 | - `ReflectiveExample` TypedDict contains:
111 |   - `Inputs: dict[str, Any]` - Predictor inputs (may include dspy.Image objects)
112 |   - `Generated_Outputs: dict[str, Any] | str` - Success: output fields dict, Failure: error message
113 |   - `Feedback: str` - Always a string from metric function or auto-generated by GEPA
114 | 
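To make this shape concrete, here is a hypothetical `reflective_dataset` for a single component. The component name and all values are purely illustrative; the key names follow the `ReflectiveExample` fields listed above.

```python
reflective_dataset = {
    "generate_answer.predict": [
        {
            # Inputs the predictor received for this example
            "Inputs": {"question": "What is the capital of France?"},
            # Success case: the predictor's output fields as a dict
            "Generated_Outputs": {"answer": "Paris is the capital of France."},
            # Feedback string produced by the metric (or auto-generated by GEPA)
            "Feedback": "Correct answer, but the task requires a single-word response.",
        },
        {
            "Inputs": {"question": "Who wrote Hamlet?"},
            # Failure case: an error message string instead of output fields
            "Generated_Outputs": "Parsing failed: the response did not contain an 'answer' field.",
            "Feedback": "The response was not formatted as requested, so no answer could be extracted.",
        },
    ]
}
```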
115 | #### Basic Example: Word Limit Proposer
116 | 
117 | ```python
118 | import dspy
119 | from gepa.core.adapter import ProposalFn
120 | from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample
121 | 
122 | class GenerateWordLimitedInstruction(dspy.Signature):
123 |     """Given a current instruction and feedback examples, generate an improved instruction with word limit constraints."""
124 | 
125 |     current_instruction = dspy.InputField(desc="The current instruction that needs improvement")
126 |     feedback_summary = dspy.InputField(desc="Feedback from examples that might include both positive and negative cases")
127 |     max_words = dspy.InputField(desc="Maximum number of words allowed in the new instruction")
128 | 
129 |     improved_instruction = dspy.OutputField(desc="A new instruction that fixes the issues while staying under the max_words limit")
130 | 
131 | class WordLimitProposer(ProposalFn):
132 |     def __init__(self, max_words: int = 1000):
133 |         self.max_words = max_words
134 |         self.instruction_improver = dspy.ChainOfThought(GenerateWordLimitedInstruction)
135 | 
136 |     def __call__(self, candidate: dict[str, str], reflective_dataset: dict[str, list[ReflectiveExample]], components_to_update: list[str]) -> dict[str, str]:
137 |         updated_components = {}
138 | 
139 |         for component_name in components_to_update:
140 |             if component_name not in candidate or component_name not in reflective_dataset:
141 |                 continue
142 | 
143 |             current_instruction = candidate[component_name]
144 |             component_examples = reflective_dataset[component_name]
145 | 
146 |             # Create feedback summary
147 |             feedback_text = "\n".join([
148 |                 f"Example {i+1}: {ex.get('Feedback', 'No feedback')}"
149 |                 for i, ex in enumerate(component_examples[:10])  # Limit examples to prevent context overflow
150 |             ])
151 | 
152 |             # Use the module to improve the instruction
153 |             result = self.instruction_improver(
154 |                 current_instruction=current_instruction,
155 |                 feedback_summary=feedback_text,
156 |                 max_words=self.max_words
157 |             )
158 | 
159 |             updated_components[component_name] = result.improved_instruction
160 | 
161 |         return updated_components
162 | 
163 | # Usage
164 | gepa = dspy.GEPA(
165 |     metric=my_metric,
166 |     reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
167 |     instruction_proposer=WordLimitProposer(max_words=700),
168 |     auto="medium"
169 | )
170 | ```
171 | 
172 | #### Advanced Example: RAG-Enhanced Instruction Proposer
173 | 
174 | ```python
175 | import dspy
176 | from gepa.core.adapter import ProposalFn
177 | from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample
178 | 
179 | class GenerateDocumentationQuery(dspy.Signature):
180 |     """Analyze examples with feedback to identify common issue patterns and generate targeted database queries for retrieving relevant documentation.
181 | 
182 |     Your goal is to search a document database for guidelines that address the problematic patterns found in the examples. Look for recurring issues, error types, or failure modes in the feedback, then craft specific search queries that will find documentation to help resolve these patterns."""
183 | 
184 |     current_instruction = dspy.InputField(desc="The current instruction that needs improvement")
185 |     examples_with_feedback = dspy.InputField(desc="Examples with their feedback showing what issues occurred and any recurring patterns")
186 | 
187 |     failure_patterns: str = dspy.OutputField(desc="Summarize the common failure patterns identified in the examples")
188 | 
189 |     retrieval_queries: list[str] = dspy.OutputField(desc="Specific search queries to find relevant documentation in the database that addresses the common issue patterns identified in the problematic examples")
190 | 
191 | class GenerateRAGEnhancedInstruction(dspy.Signature):
192 |     """Generate improved instructions using retrieved documentation and examples analysis."""
193 | 
194 |     current_instruction = dspy.InputField(desc="The current instruction that needs improvement")
195 |     relevant_documentation = dspy.InputField(desc="Retrieved guidelines and best practices from specialized documentation")
196 |     examples_with_feedback = dspy.InputField(desc="Examples showing what issues occurred with the current instruction")
197 | 
198 |     improved_instruction: str = dspy.OutputField(desc="Enhanced instruction that incorporates retrieved guidelines and addresses the issues shown in the examples")
199 | 
200 | class RAGInstructionImprover(dspy.Module):
201 |     """Module that uses RAG to improve instructions with specialized documentation."""
202 | 
203 |     def __init__(self, retrieval_model):
204 |         super().__init__()
205 |         self.retrieve = retrieval_model  # Could be dspy.Retrieve or custom retriever
206 |         self.query_generator = dspy.ChainOfThought(GenerateDocumentationQuery)
207 |         self.generate_answer = dspy.ChainOfThought(GenerateRAGEnhancedInstruction)
208 | 
209 |     def forward(self, current_instruction: str, component_examples: list):
210 |         """Improve instruction using retrieved documentation."""
211 | 
212 |         # Let LM analyze examples and generate targeted retrieval queries
213 |         query_result = self.query_generator(
214 |             current_instruction=current_instruction,
215 |             examples_with_feedback=component_examples
216 |         )
217 | 
218 |         results = self.retrieve.query(
219 |             query_texts=query_result.retrieval_queries,
220 |             n_results=3
221 |         )
222 | 
223 |         relevant_docs_parts = []
224 |         for i, (query, query_docs) in enumerate(zip(query_result.retrieval_queries, results['documents'])):
225 |             if query_docs:
226 |                 docs_formatted = "\n".join([f"  - {doc}" for doc in query_docs])
227 |                 relevant_docs_parts.append(
228 |                     f"**Search Query #{i+1}**: {query}\n"
229 |                     f"**Retrieved Guidelines**:\n{docs_formatted}"
230 |                 )
231 | 
232 |         relevant_docs = ("\n\n" + "=" * 60 + "\n\n").join(relevant_docs_parts)
233 | 
234 |         # Generate improved instruction with retrieved context
235 |         result = self.generate_answer(
236 |             current_instruction=current_instruction,
237 |             relevant_documentation=relevant_docs,
238 |             examples_with_feedback=component_examples
239 |         )
240 | 
241 |         return result
242 | 
243 | class DocumentationEnhancedProposer(ProposalFn):
244 |     """Instruction proposer that accesses specialized documentation via RAG."""
245 | 
246 |     def __init__(self, documentation_retriever):
247 |         """
248 |         Args:
249 |             documentation_retriever: A retrieval model that can search your specialized docs
250 |                                    Could be dspy.Retrieve, ChromadbRM, or custom retriever
251 |         """
252 |         self.instruction_improver = RAGInstructionImprover(documentation_retriever)
253 | 
254 |     def __call__(self, candidate: dict[str, str], reflective_dataset: dict[str, list[ReflectiveExample]], components_to_update: list[str]) -> dict[str, str]:
255 |         updated_components = {}
256 | 
257 |         for component_name in components_to_update:
258 |             if component_name not in candidate or component_name not in reflective_dataset:
259 |                 continue
260 | 
261 |             current_instruction = candidate[component_name]
262 |             component_examples = reflective_dataset[component_name]
263 | 
264 |             result = self.instruction_improver(
265 |                 current_instruction=current_instruction,
266 |                 component_examples=component_examples
267 |             )
268 | 
269 |             updated_components[component_name] = result.improved_instruction
270 | 
271 |         return updated_components
272 | 
273 | import chromadb
274 | 
275 | client = chromadb.Client()
276 | collection = client.get_collection("instruction_guidelines")
277 | 
278 | gepa = dspy.GEPA(
279 |     metric=task_specific_metric,
280 |     reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
281 |     instruction_proposer=DocumentationEnhancedProposer(collection),
282 |     auto="medium"
283 | )
284 | ```
285 | 
286 | #### Integration Patterns
287 | 
288 | **Using Custom Proposer with External LM:**
289 | 
290 | ```python
291 | class ExternalLMProposer(ProposalFn):
292 |     def __init__(self):
293 |         # Manage your own LM instance
294 |         self.external_lm = dspy.LM('gemini/gemini-2.5-pro')
295 | 
296 |     def __call__(self, candidate, reflective_dataset, components_to_update):
297 |         updated_components = {}
298 | 
299 |         with dspy.context(lm=self.external_lm):
300 |             # Your custom logic here using self.external_lm
301 |             for component_name in components_to_update:
302 |                 # ... implementation
303 |                 pass
304 | 
305 |         return updated_components
306 | 
307 | gepa = dspy.GEPA(
308 |     metric=my_metric,
309 |     reflection_lm=None,  # Optional when using custom proposer
310 |     instruction_proposer=ExternalLMProposer(),
311 |     auto="medium"
312 | )
313 | ```
314 | 
315 | **Best Practices:**
316 | 
317 | - **Use the full power of DSPy**: Build your instruction proposer from DSPy components such as `dspy.Module`, `dspy.Signature`, and `dspy.Predict` rather than making direct LM calls. Consider `dspy.Refine` for constraint satisfaction and `dspy.ChainOfThought` for complex reasoning, and compose multiple modules for more sophisticated instruction-improvement workflows
318 | - **Enable holistic feedback analysis**: While dspy.GEPA's `GEPAFeedbackMetric` processes one (gold, prediction) pair at a time, instruction proposers receive all of a component's examples in a single batch, enabling cross-example pattern detection and systematic issue identification (see the sketch after this list).
319 | - **Mind data serialization**: Serializing everything to strings might not be ideal - handle complex input types (like `dspy.Image`) by maintaining their structure for better LM processing
320 | - **Test thoroughly**: Test your custom proposer with representative failure cases
321 | 
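As a minimal sketch of the batch-level analysis mentioned above (the helper below is illustrative and not part of the GEPA API), a custom proposer can tally feedback across all of a component's reflective examples to surface recurring issues before drafting a new instruction:

```python
from collections import Counter

from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample


def summarize_recurring_feedback(examples: list[ReflectiveExample], top_k: int = 3) -> list[str]:
    """Return the most frequent feedback messages across a component's examples."""
    counts = Counter(ex.get("Feedback", "") for ex in examples)
    return [feedback for feedback, _ in counts.most_common(top_k) if feedback]
```

The resulting summary can then be fed to an instruction-improver module (for example, as the `feedback_summary` input of the `WordLimitProposer` shown earlier).
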
322 | ## Custom Component Selection
323 | 
324 | ### What is component_selector?
325 | 
326 | The `component_selector` parameter controls which components (predictors) in your DSPy program are selected for optimization at each GEPA iteration. Instead of the default round-robin approach that updates one component at a time, you can implement custom selection strategies that choose single or multiple components based on optimization state, performance trajectories, and other contextual information.
327 | 
328 | ### Default Behavior
329 | 
330 | By default, GEPA uses a **round-robin strategy** (`RoundRobinReflectionComponentSelector`) that cycles through components sequentially, optimizing one component per iteration:
331 | 
332 | ```python
333 | # Default round-robin component selection
334 | gepa = dspy.GEPA(
335 |     metric=my_metric,
336 |     reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
337 |     # component_selector="round_robin"  # This is the default
338 |     auto="medium"
339 | )
340 | ```
341 | 
342 | ### Built-in Selection Strategies
343 | 
344 | **String-based selectors:**
345 | 
346 | - `"round_robin"` (default): Cycles through components one at a time
347 | - `"all"`: Selects all components for simultaneous optimization
348 | 
349 | ```python
350 | # Optimize all components simultaneously
351 | gepa = dspy.GEPA(
352 |     metric=my_metric,
353 |     reflection_lm=reflection_lm,
354 |     component_selector="all",  # Update all components together
355 |     auto="medium"
356 | )
357 | 
358 | # Explicit round-robin selection
359 | gepa = dspy.GEPA(
360 |     metric=my_metric,
361 |     reflection_lm=reflection_lm,
362 |     component_selector="round_robin",  # One component per iteration
363 |     auto="medium"
364 | )
365 | ```
366 | 
367 | ### When to Use Custom Component Selection
368 | 
369 | Consider implementing custom component selection when you need:
370 | 
371 | - **Dependency-aware optimization**: Update related components together (e.g., a classifier and its input formatter); a sketch of such a selector appears after the custom implementation example below
372 | - **LLM-driven selection**: Let an LLM analyze trajectories and decide which components need attention
373 | - **Resource-conscious optimization**: Balance optimization thoroughness with computational budget
374 | 
375 | ### Custom Component Selector Protocol
376 | 
377 | Custom component selectors must implement the [`ReflectionComponentSelector`](https://github.com/gepa-ai/gepa/blob/main/src/gepa/proposer/reflective_mutation/base.py) protocol by defining a callable class or function. GEPA will call your selector during optimization:
378 | 
379 | ```python
380 | from dspy.teleprompt.gepa.gepa_utils import GEPAState, Trajectory
381 | 
382 | class CustomComponentSelector:
383 |     def __call__(
384 |         self,
385 |         state: GEPAState,                    # Complete optimization state with history
386 |         trajectories: list[Trajectory],      # Execution traces from the current minibatch
387 |         subsample_scores: list[float],       # Scores for each example in the current minibatch
388 |         candidate_idx: int,                  # Index of the current program candidate being optimized
389 |         candidate: dict[str, str],           # Component name -> instruction mapping
390 |     ) -> list[str]:                          # Return list of component names to optimize
391 |         # Your custom component selection logic here
392 |         return selected_components
393 | 
394 | # Or as a function:
395 | def custom_component_selector(state, trajectories, subsample_scores, candidate_idx, candidate):
396 |     # Your custom component selection logic here
397 |     return selected_components
398 | ```
399 | 
400 | ### Custom Implementation Example
401 | 
402 | Here's a simple function that alternates between optimizing different halves of your components:
403 | 
404 | ```python
405 | def alternating_half_selector(state, trajectories, subsample_scores, candidate_idx, candidate):
406 |     """Optimize half the components on even iterations, half on odd iterations."""
407 |     components = list(candidate.keys())
408 | 
409 |     # If there's only one component, always optimize it
410 |     if len(components) <= 1:
411 |         return components
412 | 
413 |     mid_point = len(components) // 2
414 | 
415 |     # Use state.i (iteration counter) to alternate between halves
416 |     if state.i % 2 == 0:
417 |         # Even iteration: optimize first half
418 |         return components[:mid_point]
419 |     else:
420 |         # Odd iteration: optimize second half
421 |         return components[mid_point:]
422 | 
423 | # Usage
424 | gepa = dspy.GEPA(
425 |     metric=my_metric,
426 |     reflection_lm=reflection_lm,
427 |     component_selector=alternating_half_selector,
428 |     auto="medium"
429 | )
430 | ```
431 | 
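A dependency-aware variant of the same idea is sketched below; the component names are hypothetical and should be replaced with the predictor names from your own program:

```python
def dependency_aware_selector(state, trajectories, subsample_scores, candidate_idx, candidate):
    """Always update a (hypothetical) classifier together with its input formatter."""
    coupled = {"classify.predict", "format_input.predict"}
    components = list(candidate.keys())

    # If any coupled component exists in this candidate, update the whole group together.
    selected = [name for name in components if name in coupled]
    if selected:
        return selected

    # Otherwise fall back to simple round-robin over the remaining components.
    return [components[state.i % len(components)]]
```
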
432 | ### Integration with Custom Instruction Proposers
433 | 
434 | Component selectors work seamlessly with custom instruction proposers. The selector determines which components to update, then the instruction proposer generates new instructions for those components:
435 | 
436 | ```python
437 | # Combined custom selector + custom proposer
438 | gepa = dspy.GEPA(
439 |     metric=my_metric,
440 |     reflection_lm=reflection_lm,
441 |     component_selector=alternating_half_selector,
442 |     instruction_proposer=WordLimitProposer(max_words=500),
443 |     auto="medium"
444 | )
445 | ```
446 | 
```

--------------------------------------------------------------------------------
/dspy/retrievers/databricks_rm.py:
--------------------------------------------------------------------------------

```python
  1 | import json
  2 | import os
  3 | from dataclasses import dataclass
  4 | from importlib.util import find_spec
  5 | from typing import Any
  6 | 
  7 | import requests
  8 | 
  9 | import dspy
 10 | from dspy.primitives.prediction import Prediction
 11 | 
 12 | _databricks_sdk_installed = find_spec("databricks.sdk") is not None
 13 | 
 14 | 
 15 | @dataclass
 16 | class Document:
 17 |     page_content: str
 18 |     metadata: dict[str, Any]
 19 |     type: str
 20 | 
 21 |     def to_dict(self) -> dict[str, Any]:
 22 |         return {
 23 |             "page_content": self.page_content,
 24 |             "metadata": self.metadata,
 25 |             "type": self.type,
 26 |         }
 27 | 
 28 | 
 29 | class DatabricksRM(dspy.Retrieve):
 30 |     """
 31 |     A retriever module that uses a Databricks Mosaic AI Vector Search Index to return the top-k
 32 |     matching documents for a given query.
 33 | 
 34 |     Examples:
 35 |         Below is a code snippet that shows how to set up a Databricks Vector Search Index
 36 |         and configure a DatabricksRM DSPy retriever module to query the index.
 37 | 
 38 |         (Example adapted from "Databricks: How to create and query a Vector Search Index":
 39 |         https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-index)
 40 | 
 41 |         ```python
 42 |         from databricks.vector_search.client import VectorSearchClient
 43 | 
 44 |         # Create a Databricks Vector Search Endpoint
 45 |         client = VectorSearchClient()
 46 |         client.create_endpoint(
 47 |             name="your_vector_search_endpoint_name",
 48 |             endpoint_type="STANDARD"
 49 |         )
 50 | 
 51 |         # Create a Databricks Direct Access Vector Search Index
 52 |         index = client.create_direct_access_index(
 53 |             endpoint_name="your_vector_search_endpoint_name",
 54 |             index_name="your_index_name",
 55 |             primary_key="id",
 56 |             embedding_dimension=1024,
 57 |             embedding_vector_column="text_vector",
 58 |             schema={
 59 |               "id": "int",
 60 |               "field2": "str",
 61 |               "field3": "float",
 62 |               "text_vector": "array<float>"
 63 |             }
 64 |         )
 65 | 
 66 |         # Create a DatabricksRM retriever module to query the Databricks Direct Access Vector
 67 |         # Search Index
 68 |         retriever = DatabricksRM(
 69 |             databricks_index_name = "your_index_name",
 70 |             docs_id_column_name="id",
 71 |             text_column_name="field2",
 72 |             k=3
 73 |         )
 74 |         ```
 75 | 
 76 |         Below is a code snippet that shows how to query the Databricks Direct Access Vector
 77 |         Search Index using the DatabricksRM retriever module:
 78 | 
 79 |         ```python
 80 |         retrieved_results = retriever(query="Example query text")
 81 |         ```
 82 |     """
 83 | 
 84 |     def __init__(
 85 |         self,
 86 |         databricks_index_name: str,
 87 |         databricks_endpoint: str | None = None,
 88 |         databricks_token: str | None = None,
 89 |         databricks_client_id: str | None = None,
 90 |         databricks_client_secret: str | None = None,
 91 |         columns: list[str] | None = None,
 92 |         filters_json: str | None = None,
 93 |         k: int = 3,
 94 |         docs_id_column_name: str = "id",
 95 |         docs_uri_column_name: str | None = None,
 96 |         text_column_name: str = "text",
 97 |         use_with_databricks_agent_framework: bool = False,
 98 |     ):
 99 |         """
100 |         Args:
101 |             databricks_index_name (str): The name of the Databricks Vector Search Index to query.
102 |             databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing
103 |                 the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST``
104 |                 environment variable. If unspecified, the Databricks SDK is used to identify the
105 |                 endpoint based on the current environment.
106 |             databricks_token (Optional[str]): The Databricks Workspace authentication token to use
107 |                 when querying the Vector Search Index. Defaults to the value of the
108 |                 ``DATABRICKS_TOKEN`` environment variable. If unspecified, the Databricks SDK is
109 |                 used to identify the token based on the current environment.
110 |             databricks_client_id (Optional[str]): Databricks service principal ID. If not specified,
111 |                 the client ID is resolved from the ``DATABRICKS_CLIENT_ID`` environment variable.
112 |             databricks_client_secret (Optional[str]): Databricks service principal secret. If not specified,
113 |                 the secret is resolved from the ``DATABRICKS_CLIENT_SECRET`` environment variable.
114 |             columns (Optional[list[str]]): Extra column names to include in response,
115 |                 in addition to the document id and text columns specified by
116 |                 ``docs_id_column_name`` and ``text_column_name``.
117 |             filters_json (Optional[str]): A JSON string specifying additional query filters.
118 |                 Example filters: ``{"id <": 5}`` selects records that have an ``id`` column value
119 |                 less than 5, and ``{"id >=": 5, "id <": 10}`` selects records that have an ``id``
120 |                 column value greater than or equal to 5 and less than 10.
121 |             k (int): The number of documents to retrieve.
122 |             docs_id_column_name (str): The name of the column in the Databricks Vector Search Index
123 |                 containing document IDs.
124 |             docs_uri_column_name (Optional[str]): The name of the column in the Databricks Vector Search Index
125 |                 containing document URI.
126 |             text_column_name (str): The name of the column in the Databricks Vector Search Index
127 |                 containing document text to retrieve.
128 |             use_with_databricks_agent_framework (bool): Whether to use the `DatabricksRM` in a way that is
129 |                 compatible with the Databricks Mosaic Agent Framework.
130 |         """
131 |         super().__init__(k=k)
132 |         self.databricks_token = databricks_token if databricks_token is not None else os.environ.get("DATABRICKS_TOKEN")
133 |         self.databricks_endpoint = (
134 |             databricks_endpoint if databricks_endpoint is not None else os.environ.get("DATABRICKS_HOST")
135 |         )
136 |         self.databricks_client_id = (
137 |             databricks_client_id if databricks_client_id is not None else os.environ.get("DATABRICKS_CLIENT_ID")
138 |         )
139 |         self.databricks_client_secret = (
140 |             databricks_client_secret
141 |             if databricks_client_secret is not None
142 |             else os.environ.get("DATABRICKS_CLIENT_SECRET")
143 |         )
144 |         if not _databricks_sdk_installed and (self.databricks_token, self.databricks_endpoint).count(None) > 0:
145 |             raise ValueError(
146 |                 "To retrieve documents with Databricks Vector Search, you must install the"
147 |                 " databricks-sdk Python library, supply the databricks_token and"
148 |                 " databricks_endpoint parameters, or set the DATABRICKS_TOKEN and DATABRICKS_HOST"
149 |                 " environment variables. You may also supply a service principal via the databricks_client_id and"
150 |                 " databricks_client_secret parameters, or set the DATABRICKS_CLIENT_ID and DATABRICKS_CLIENT_SECRET environment variables."
151 |             )
152 |         self.databricks_index_name = databricks_index_name
153 |         self.columns = list({docs_id_column_name, text_column_name, *(columns or [])})
154 |         self.filters_json = filters_json
155 |         self.k = k
156 |         self.docs_id_column_name = docs_id_column_name
157 |         self.docs_uri_column_name = docs_uri_column_name
158 |         self.text_column_name = text_column_name
159 |         self.use_with_databricks_agent_framework = use_with_databricks_agent_framework
160 |         if self.use_with_databricks_agent_framework:
161 |             try:
162 |                 import mlflow
163 | 
164 |                 mlflow.models.set_retriever_schema(
165 |                     primary_key="doc_id",
166 |                     text_column="page_content",
167 |                     doc_uri="doc_uri",
168 |                 )
169 |             except ImportError:
170 |                 raise ValueError(
171 |                     "To use the `DatabricksRM` retriever module with the Databricks Mosaic Agent Framework, "
172 |                     "you must install the mlflow Python library. Please install mlflow via `pip install mlflow`."
173 |                 )
174 | 
175 |     def _extract_doc_ids(self, item: dict[str, Any]) -> str:
176 |         """Extracts the document id from a search result
177 | 
178 |         Args:
179 |             item: dict[str, Any]: a record from the search results.
180 |         Returns:
181 |             str: document id.
182 |         """
183 |         if self.docs_id_column_name == "metadata":
184 |             docs_dict = json.loads(item["metadata"])
185 |             return docs_dict["document_id"]
186 |         return item[self.docs_id_column_name]
187 | 
188 |     def _get_extra_columns(self, item: dict[str, Any]) -> dict[str, Any]:
189 |         """Extracts search result column values, excluding the "text", "id", and "uri" columns.
190 | 
191 |         Args:
192 |             item: dict[str, Any]: a record from the search results.
193 |         Returns:
194 |             dict[str, Any]: Search result column values, excluding the "text", "id" and "uri" columns.
195 |         """
196 |         extra_columns = {
197 |             k: v
198 |             for k, v in item.items()
199 |             if k not in [self.docs_id_column_name, self.text_column_name, self.docs_uri_column_name]
200 |         }
201 |         if self.docs_id_column_name == "metadata":
202 |             extra_columns = {
203 |                 **extra_columns,
204 |                 **{"metadata": {k: v for k, v in json.loads(item["metadata"]).items() if k != "document_id"}},
205 |             }
206 |         return extra_columns
207 | 
208 |     def forward(
209 |         self,
210 |         query: str | list[float],
211 |         query_type: str = "ANN",
212 |         filters_json: str | None = None,
213 |     ) -> dspy.Prediction | list[dict[str, Any]]:
214 |         """
215 |         Retrieve documents from a Databricks Mosaic AI Vector Search Index that are relevant to the
216 |         specified query.
217 | 
218 |         Args:
219 |             query (Union[str, list[float]]): The query text or numeric query vector for which to
220 |                 retrieve relevant documents.
221 |             query_type (str): The type of search query to perform against the Databricks Vector
222 |                 Search Index. Must be either 'ANN' (approximate nearest neighbor) or 'HYBRID'
223 |                 (hybrid search).
224 |             filters_json (Optional[str]): A JSON string specifying additional query filters.
225 |                 Example filters: ``{"id <": 5}`` selects records that have an ``id`` column value
226 |                 less than 5, and ``{"id >=": 5, "id <": 10}`` selects records that have an ``id``
227 |                 column value greater than or equal to 5 and less than 10. If specified, this
228 |                 parameter overrides the `filters_json` parameter passed to the constructor.
229 | 
230 |         Returns:
231 |             A list of dictionaries when ``use_with_databricks_agent_framework`` is ``True``,
232 |             or a ``dspy.Prediction`` object when ``use_with_databricks_agent_framework`` is
233 |             ``False``.
234 |         """
235 |         if query_type in ["vector", "text"]:
236 |             # Older versions of DSPy used a `query_type` argument to disambiguate between text
237 |             # and vector queries, rather than checking the type of the `query` argument. This
238 |             # differs from the Databricks Vector Search definition of `query_type`, which
239 |             # specifies the search algorithm to use (e.g. "ANN" or "HYBRID"). To maintain
240 |             # backwards compatibility with older versions of DSPy, we map the old `query_type`
241 |             # values to the Databricks Vector Search default query type of "ANN".
242 |             query_type = "ANN"
243 | 
244 |         if isinstance(query, str):
245 |             query_text = query
246 |             query_vector = None
247 |         elif isinstance(query, list):
248 |             query_vector = query
249 |             query_text = None
250 |         else:
251 |             raise ValueError("Query must be a string or a list of floats.")
252 | 
253 |         if _databricks_sdk_installed:
254 |             results = self._query_via_databricks_sdk(
255 |                 index_name=self.databricks_index_name,
256 |                 k=self.k,
257 |                 columns=self.columns,
258 |                 query_type=query_type,
259 |                 query_text=query_text,
260 |                 query_vector=query_vector,
261 |                 databricks_token=self.databricks_token,
262 |                 databricks_endpoint=self.databricks_endpoint,
263 |                 databricks_client_id=self.databricks_client_id,
264 |                 databricks_client_secret=self.databricks_client_secret,
265 |                 filters_json=filters_json or self.filters_json,
266 |             )
267 |         else:
268 |             results = self._query_via_requests(
269 |                 index_name=self.databricks_index_name,
270 |                 k=self.k,
271 |                 columns=self.columns,
272 |                 databricks_token=self.databricks_token,
273 |                 databricks_endpoint=self.databricks_endpoint,
274 |                 query_type=query_type,
275 |                 query_text=query_text,
276 |                 query_vector=query_vector,
277 |                 filters_json=filters_json or self.filters_json,
278 |             )
279 | 
280 |         # Checking if defined columns are present in the index columns
281 |         col_names = [column["name"] for column in results["manifest"]["columns"]]
282 | 
283 |         if self.docs_id_column_name not in col_names:
284 |             raise Exception(
285 |                 f"docs_id_column_name: '{self.docs_id_column_name}' is not in the index columns: \n {col_names}"
286 |             )
287 | 
288 |         if self.text_column_name not in col_names:
289 |             raise Exception(f"text_column_name: '{self.text_column_name}' is not in the index columns: \n {col_names}")
290 | 
291 |         # Extracting the results
292 |         items = []
293 |         if "data_array" in results["result"]:
294 |             for _, data_row in enumerate(results["result"]["data_array"]):
295 |                 item = {}
296 |                 for col_name, val in zip(col_names, data_row, strict=False):
297 |                     item[col_name] = val
298 |                 items += [item]
299 | 
300 |         # Sorting results by score in descending order
301 |         sorted_docs = sorted(items, key=lambda x: x["score"], reverse=True)[: self.k]
302 | 
303 |         if self.use_with_databricks_agent_framework:
304 |             return [
305 |                 Document(
306 |                     page_content=doc[self.text_column_name],
307 |                     metadata={
308 |                         "doc_id": self._extract_doc_ids(doc),
309 |                         "doc_uri": doc[self.docs_uri_column_name] if self.docs_uri_column_name else None,
310 |                     }
311 |                     | self._get_extra_columns(doc),
312 |                     type="Document",
313 |                 ).to_dict()
314 |                 for doc in sorted_docs
315 |             ]
316 |         else:
317 |             # Returning the prediction
318 |             return Prediction(
319 |                 docs=[doc[self.text_column_name] for doc in sorted_docs],
320 |                 doc_ids=[self._extract_doc_ids(doc) for doc in sorted_docs],
321 |                 doc_uris=[doc[self.docs_uri_column_name] for doc in sorted_docs] if self.docs_uri_column_name else None,
322 |                 extra_columns=[self._get_extra_columns(item) for item in sorted_docs],
323 |             )
324 | 
325 |     @staticmethod
326 |     def _query_via_databricks_sdk(
327 |         index_name: str,
328 |         k: int,
329 |         columns: list[str],
330 |         query_type: str,
331 |         query_text: str | None,
332 |         query_vector: list[float] | None,
333 |         databricks_token: str | None,
334 |         databricks_endpoint: str | None,
335 |         databricks_client_id: str | None,
336 |         databricks_client_secret: str | None,
337 |         filters_json: str | None,
338 |     ) -> dict[str, Any]:
339 |         """
340 |         Query a Databricks Vector Search Index via the Databricks SDK.
341 |         Assumes that the databricks-sdk Python library is installed.
342 | 
343 |         Args:
344 |             index_name (str): Name of the Databricks vector search index to query
345 |             k (int): Number of relevant documents to retrieve.
346 |             columns (list[str]): Column names to include in response.
347 |             query_text (Optional[str]): Text query for which to find relevant documents. Exactly
348 |                 one of query_text or query_vector must be specified.
349 |             query_vector (Optional[list[float]]): Numeric query vector for which to find relevant
350 |                 documents. Exactly one of query_text or query_vector must be specified.
351 |             filters_json (Optional[str]): JSON string representing additional query filters.
352 |             databricks_token (str): Databricks authentication token. If not specified,
353 |                 the token is resolved from the current environment.
354 |             databricks_endpoint (str): Databricks index endpoint url. If not specified,
355 |                 the endpoint is resolved from the current environment.
356 |             databricks_client_id (Optional[str]): Databricks service principal ID. If not specified,
357 |                 the client ID is resolved from the current environment (DATABRICKS_CLIENT_ID).
358 |             databricks_client_secret (Optional[str]): Databricks service principal secret. If not specified,
359 |                 the secret is resolved from the current environment (DATABRICKS_CLIENT_SECRET).
360 | 
361 |         Returns:
362 |             dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query.
363 |         """
364 | 
365 |         from databricks.sdk import WorkspaceClient
366 | 
367 |         if (query_text, query_vector).count(None) != 1:
368 |             raise ValueError("Exactly one of query_text or query_vector must be specified.")
369 | 
370 |         if databricks_client_secret and databricks_client_id:
371 |             # Use client ID and secret for authentication if they are provided
372 |             databricks_client = WorkspaceClient(
373 |                 client_id=databricks_client_id,
374 |                 client_secret=databricks_client_secret,
375 |             )
376 |             print("Creating Databricks workspace client using service principal authentication.")
377 | 
378 |         else:
379 |             # Fallback for token-based authentication
380 |             databricks_client = WorkspaceClient(
381 |                 host=databricks_endpoint,
382 |                 token=databricks_token,
383 |             )
384 |             print("Creating Databricks workspace client using token authentication.")
385 | 
386 |         return databricks_client.vector_search_indexes.query_index(
387 |             index_name=index_name,
388 |             query_type=query_type,
389 |             query_text=query_text,
390 |             query_vector=query_vector,
391 |             columns=columns,
392 |             filters_json=filters_json,
393 |             num_results=k,
394 |         ).as_dict()
395 | 
396 |     @staticmethod
397 |     def _query_via_requests(
398 |         index_name: str,
399 |         k: int,
400 |         columns: list[str],
401 |         databricks_token: str,
402 |         databricks_endpoint: str,
403 |         query_type: str,
404 |         query_text: str | None,
405 |         query_vector: list[float] | None,
406 |         filters_json: str | None,
407 |     ) -> dict[str, Any]:
408 |         """
409 |         Query a Databricks Vector Search Index via the Python requests library.
410 | 
411 |         Args:
412 |             index_name (str): Name of the Databricks vector search index to query
413 |             k (int): Number of relevant documents to retrieve.
414 |             columns (list[str]): Column names to include in response.
415 |             databricks_token (str): Databricks authentication token.
416 |             databricks_endpoint (str): Databricks index endpoint url.
417 |             query_text (Optional[str]): Text query for which to find relevant documents. Exactly
418 |                 one of query_text or query_vector must be specified.
419 |             query_vector (Optional[list[float]]): Numeric query vector for which to find relevant
420 |                 documents. Exactly one of query_text or query_vector must be specified.
421 |             filters_json (Optional[str]): JSON string representing additional query filters.
422 | 
423 |         Returns:
424 |             dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query.
425 |         """
426 |         if (query_text, query_vector).count(None) != 1:
427 |             raise ValueError("Exactly one of query_text or query_vector must be specified.")
428 | 
429 |         headers = {
430 |             "Authorization": f"Bearer {databricks_token}",
431 |             "Content-Type": "application/json",
432 |         }
433 |         payload = {
434 |             "columns": columns,
435 |             "num_results": k,
436 |             "query_type": query_type,
437 |         }
438 |         if filters_json is not None:
439 |             payload["filters_json"] = filters_json
440 |         if query_text is not None:
441 |             payload["query_text"] = query_text
442 |         elif query_vector is not None:
443 |             payload["query_vector"] = query_vector
444 |         response = requests.post(
445 |             f"{databricks_endpoint}/api/2.0/vector-search/indexes/{index_name}/query",
446 |             json=payload,
447 |             headers=headers,
448 |         )
449 |         results = response.json()
450 |         if "error_code" in results:
451 |             raise Exception(f"ERROR: {results['error_code']} -- {results['message']}")
452 |         return results
453 | 
```

--------------------------------------------------------------------------------
/tests/primitives/test_base_module.py:
--------------------------------------------------------------------------------

```python
  1 | import asyncio
  2 | import logging
  3 | import os
  4 | import threading
  5 | from unittest.mock import patch
  6 | 
  7 | import pytest
  8 | from litellm import Choices, Message, ModelResponse
  9 | from litellm.types.utils import Usage
 10 | 
 11 | import dspy
 12 | from dspy.primitives.prediction import Prediction
 13 | from dspy.utils.dummies import DummyLM
 14 | 
 15 | 
 16 | def test_deepcopy_basic():
 17 |     signature = dspy.Signature("q -> a")
 18 |     cot = dspy.ChainOfThought(signature)
 19 |     cot_copy = cot.deepcopy()
 20 |     assert len(cot.parameters()) == len(cot_copy.parameters())
 21 |     # Parameters should be different objects with the same values.
 22 |     assert id(cot.parameters()[0]) != id(cot_copy.parameters()[0])
 23 |     assert cot.parameters()[0].__dict__ == cot_copy.parameters()[0].__dict__
 24 | 
 25 | 
 26 | def test_deepcopy_with_uncopyable_modules():
 27 |     class CustomClass(dspy.Module):
 28 |         def __init__(self):
 29 |             self.lock = threading.Lock()  # Non-copyable object.
 30 |             self.cot = dspy.ChainOfThought(dspy.Signature("q -> a"))
 31 | 
 32 |     model = CustomClass()
 33 |     model_copy = model.deepcopy()
 34 |     assert len(model.parameters()) == len(model_copy.parameters())
 35 |     # The lock should refer to the same object (shallow copy).
 36 |     assert id(model.lock) == id(model_copy.lock)
 37 |     # Parameters should be different objects with the same values.
 38 |     assert id(model.parameters()[0]) != id(model_copy.parameters()[0])
 39 |     assert model.parameters()[0].__dict__ == model_copy.parameters()[0].__dict__
 40 | 
 41 | 
 42 | def test_deepcopy_with_nested_modules():
 43 |     class CustomClass1(dspy.Module):
 44 |         def __init__(self):
 45 |             self.lock = threading.Lock()  # Non-copyable object.
 46 |             self.cot = dspy.ChainOfThought(dspy.Signature("q -> a"))
 47 | 
 48 |     class CustomClass2(dspy.Module):
 49 |         def __init__(self):
 50 |             self.submodel = CustomClass1()
 51 | 
 52 |     model = CustomClass2()
 53 |     model_copy = model.deepcopy()
 54 |     assert len(model.parameters()) == len(model_copy.parameters())
 55 |     # The lock should refer to the same object (shallow copy).
 56 |     assert id(model.submodel.lock) == id(model_copy.submodel.lock)
 57 |     # Parameters should be different objects with the same values.
 58 |     assert id(model.parameters()[0]) != id(model_copy.parameters()[0])
 59 |     assert model.parameters()[0].__dict__ == model_copy.parameters()[0].__dict__
 60 | 
 61 | 
 62 | def test_save_and_load_with_json(tmp_path):
 63 |     model = dspy.ChainOfThought(dspy.Signature("q -> a"))
 64 |     model.predict.signature = model.predict.signature.with_instructions("You are a helpful assistant.")
 65 |     model.predict.demos = [
 66 |         dspy.Example(q="What is the capital of France?", a="Paris", reasoning="n/a").with_inputs("q"),
 67 |         # Nested example
 68 |         dspy.Example(
 69 |             q=[
 70 |                 dspy.Example(q="What is the capital of France?"),
 71 |                 dspy.Example(q="What is actually the capital of France?"),
 72 |             ],
 73 |             a="Paris",
 74 |             reasoning="n/a",
 75 |         ).with_inputs("q"),
 76 |     ]
 77 |     save_path = tmp_path / "model.json"
 78 |     model.save(save_path)
 79 |     new_model = dspy.ChainOfThought(dspy.Signature("q -> a"))
 80 |     new_model.load(save_path)
 81 | 
 82 |     assert str(new_model.predict.signature) == str(model.predict.signature)
 83 |     assert new_model.predict.demos[0] == model.predict.demos[0].toDict()
 84 |     assert new_model.predict.demos[1] == model.predict.demos[1].toDict()
 85 | 
 86 | 
 87 | @pytest.mark.extra
 88 | def test_save_and_load_with_pkl(tmp_path):
 89 |     import datetime
 90 | 
 91 |     # `datetime.date` is not json serializable, so we need to save with pickle.
 92 |     class MySignature(dspy.Signature):
 93 |         """Just a custom signature."""
 94 | 
 95 |         current_date: datetime.date = dspy.InputField()
 96 |         target_date: datetime.date = dspy.InputField()
 97 |         date_diff: int = dspy.OutputField(desc="The difference in days between the current_date and the target_date")
 98 | 
 99 |     trainset = [
100 |         {"current_date": datetime.date(2024, 1, 1), "target_date": datetime.date(2024, 1, 2), "date_diff": 1},
101 |         {"current_date": datetime.date(2024, 1, 1), "target_date": datetime.date(2024, 1, 3), "date_diff": 2},
102 |         {"current_date": datetime.date(2024, 1, 1), "target_date": datetime.date(2024, 1, 4), "date_diff": 3},
103 |         {"current_date": datetime.date(2024, 1, 1), "target_date": datetime.date(2024, 1, 5), "date_diff": 4},
104 |         {"current_date": datetime.date(2024, 1, 1), "target_date": datetime.date(2024, 1, 6), "date_diff": 5},
105 |     ]
106 |     trainset = [dspy.Example(**example).with_inputs("current_date", "target_date") for example in trainset]
107 | 
108 |     dspy.settings.configure(
109 |         lm=DummyLM([{"date_diff": "1", "reasoning": "n/a"}, {"date_diff": "2", "reasoning": "n/a"}] * 10)
110 |     )
111 | 
112 |     cot = dspy.ChainOfThought(MySignature)
113 |     cot(current_date=datetime.date(2024, 1, 1), target_date=datetime.date(2024, 1, 2))
114 | 
115 |     def dummy_metric(example, pred, trace=None):
116 |         return True
117 | 
118 |     optimizer = dspy.BootstrapFewShot(max_bootstrapped_demos=4, max_labeled_demos=4, max_rounds=5, metric=dummy_metric)
119 |     compiled_cot = optimizer.compile(cot, trainset=trainset)
120 |     compiled_cot.predict.signature = compiled_cot.predict.signature.with_instructions("You are a helpful assistant.")
121 | 
122 |     save_path = tmp_path / "program.pkl"
123 |     compiled_cot.save(save_path)
124 | 
125 |     new_cot = dspy.ChainOfThought(MySignature)
126 |     new_cot.load(save_path)
127 | 
128 |     assert str(new_cot.predict.signature) == str(compiled_cot.predict.signature)
129 |     assert new_cot.predict.demos == compiled_cot.predict.demos
130 | 
131 | 
132 | def test_save_with_extra_modules(tmp_path):
133 |     import sys
134 | 
135 |     # Create a temporary Python file with our custom module
136 |     custom_module_path = tmp_path / "custom_module.py"
137 |     with open(custom_module_path, "w") as f:
138 |         f.write("""
139 | import dspy
140 | 
141 | class MyModule(dspy.Module):
142 |     def __init__(self):
143 |         self.cot = dspy.ChainOfThought(dspy.Signature("q -> a"))
144 | 
145 |     def forward(self, q):
146 |         return self.cot(q=q)
147 | """)
148 | 
149 |     # Add the tmp_path to Python path so we can import the module
150 |     sys.path.insert(0, str(tmp_path))
151 |     try:
152 |         import custom_module
153 | 
154 |         cot = custom_module.MyModule()
155 | 
156 |         cot.save(tmp_path, save_program=True)
157 |         # Remove the custom module from sys.modules to simulate it not being available
158 |         sys.modules.pop("custom_module", None)
159 |         # Also remove it from sys.path
160 |         sys.path.remove(str(tmp_path))
161 |         del custom_module
162 | 
163 |         # Test that loading fails without using `modules_to_serialize`
164 |         with pytest.raises(ModuleNotFoundError):
165 |             dspy.load(tmp_path)
166 | 
167 |         sys.path.insert(0, str(tmp_path))
168 |         import custom_module
169 | 
170 |         cot.save(
171 |             tmp_path,
172 |             modules_to_serialize=[custom_module],
173 |             save_program=True,
174 |         )
175 | 
176 |         # Remove the custom module from sys.modules to simulate it not being available
177 |         sys.modules.pop("custom_module", None)
178 |         # Also remove it from sys.path
179 |         sys.path.remove(str(tmp_path))
180 |         del custom_module
181 | 
182 |         loaded_module = dspy.load(tmp_path)
183 |         assert loaded_module.cot.predict.signature == cot.cot.predict.signature
184 | 
185 |     finally:
186 |         # Only need to clean up sys.path
187 |         if str(tmp_path) in sys.path:
188 |             sys.path.remove(str(tmp_path))
189 | 
190 | 
191 | def test_load_with_version_mismatch(tmp_path):
192 |     from dspy.primitives.base_module import logger
193 | 
194 |     # Mock versions during save
195 |     save_versions = {"python": "3.9", "dspy": "2.4.0", "cloudpickle": "2.0"}
196 | 
197 |     # Mock versions during load
198 |     load_versions = {"python": "3.10", "dspy": "2.5.0", "cloudpickle": "2.1"}
199 | 
200 |     predict = dspy.Predict("question->answer")
201 | 
202 |     # Create a custom handler to capture log messages
203 |     class ListHandler(logging.Handler):
204 |         def __init__(self):
205 |             super().__init__()
206 |             self.messages = []
207 | 
208 |         def emit(self, record):
209 |             self.messages.append(record.getMessage())
210 | 
211 |     # Add handler and set level
212 |     handler = ListHandler()
213 |     original_level = logger.level
214 |     logger.addHandler(handler)
215 |     logger.setLevel(logging.WARNING)
216 | 
217 |     try:
218 |         save_path = tmp_path / "program.pkl"
219 |         # Mock version during save
220 |         with patch("dspy.primitives.base_module.get_dependency_versions", return_value=save_versions):
221 |             predict.save(save_path)
222 | 
223 |         # Mock version during load
224 |         with patch("dspy.primitives.base_module.get_dependency_versions", return_value=load_versions):
225 |             loaded_predict = dspy.Predict("question->answer")
226 |             loaded_predict.load(save_path)
227 | 
228 |         # Assert that warnings were logged, one for each mismatched dependency.
229 |         assert len(handler.messages) == 3
230 | 
231 |         for msg in handler.messages:
232 |             assert "There is a mismatch of" in msg
233 | 
234 |         # Verify the model still loads correctly despite version mismatches
235 |         assert isinstance(loaded_predict, dspy.Predict)
236 |         assert str(predict.signature) == str(loaded_predict.signature)
237 | 
238 |     finally:
239 |         # Clean up: restore original level and remove handler
240 |         logger.setLevel(original_level)
241 |         logger.removeHandler(handler)
242 | 
243 | 
244 | @pytest.mark.llm_call
245 | def test_single_module_call_with_usage_tracker(lm_for_test):
246 |     dspy.settings.configure(lm=dspy.LM(lm_for_test, cache=False), track_usage=True)
247 | 
248 |     predict = dspy.ChainOfThought("question -> answer")
249 |     output = predict(question="What is the capital of France?")
250 | 
251 |     lm_usage = output.get_lm_usage()
252 |     assert len(lm_usage) == 1
253 |     assert lm_usage[lm_for_test]["prompt_tokens"] > 0
254 |     assert lm_usage[lm_for_test]["completion_tokens"] > 0
255 |     assert lm_usage[lm_for_test]["total_tokens"] > 0
256 | 
257 |     # Test that no usage is tracked when the cache is enabled
258 |     dspy.settings.configure(lm=dspy.LM(lm_for_test, cache=True), track_usage=True)
259 |     for _ in range(2):
260 |         output = predict(question="What is the capital of France?")
261 | 
262 |     assert len(output.get_lm_usage()) == 0
263 | 
264 | 
265 | @pytest.mark.llm_call
266 | def test_multi_module_call_with_usage_tracker(lm_for_test):
267 |     dspy.settings.configure(lm=dspy.LM(lm_for_test, cache=False), track_usage=True)
268 | 
269 |     class MyProgram(dspy.Module):
270 |         def __init__(self):
271 |             self.predict1 = dspy.ChainOfThought("question -> answer")
272 |             self.predict2 = dspy.ChainOfThought("question, answer -> score")
273 | 
274 |         def __call__(self, question: str) -> Prediction:
275 |             answer = self.predict1(question=question)
276 |             score = self.predict2(question=question, answer=answer)
277 |             return score
278 | 
279 |     program = MyProgram()
280 |     output = program(question="What is the capital of France?")
281 | 
282 |     lm_usage = output.get_lm_usage()
283 |     assert len(lm_usage) == 1
284 |     assert lm_usage[lm_for_test]["prompt_tokens"] > 0
285 |     assert lm_usage[lm_for_test]["prompt_tokens"] > 0
286 |     assert lm_usage[lm_for_test]["completion_tokens"] > 0
287 |     assert lm_usage[lm_for_test]["total_tokens"] > 0
288 | 
289 | 
290 | # TODO: Prepare a second model so this unit test can run in CI.
291 | @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Skip the test if OPENAI_API_KEY is not set.")
292 | def test_usage_tracker_in_parallel():
293 |     class MyProgram(dspy.Module):
294 |         def __init__(self, lm):
295 |             self.lm = lm
296 |             self.predict1 = dspy.ChainOfThought("question -> answer")
297 |             self.predict2 = dspy.ChainOfThought("question, answer -> score")
298 | 
299 |         def __call__(self, question: str) -> Prediction:
300 |             with dspy.settings.context(lm=self.lm):
301 |                 answer = self.predict1(question=question)
302 |                 score = self.predict2(question=question, answer=answer)
303 |                 return score
304 | 
305 |     dspy.settings.configure(track_usage=True)
306 |     program1 = MyProgram(lm=dspy.LM("openai/gpt-4o-mini", cache=False))
307 |     program2 = MyProgram(lm=dspy.LM("openai/gpt-3.5-turbo", cache=False))
308 | 
309 |     parallelizer = dspy.Parallel()
310 | 
311 |     results = parallelizer(
312 |         [
313 |             (program1, {"question": "What is the meaning of life?"}),
314 |             (program2, {"question": "why did a chicken cross the kitchen?"}),
315 |         ]
316 |     )
317 | 
318 |     assert results[0].get_lm_usage() is not None
319 |     assert results[1].get_lm_usage() is not None
320 | 
321 |     assert results[0].get_lm_usage().keys() == set(["openai/gpt-4o-mini"])
322 |     assert results[1].get_lm_usage().keys() == set(["openai/gpt-3.5-turbo"])
323 | 
324 | 
325 | @pytest.mark.asyncio
326 | async def test_usage_tracker_async_parallel():
327 |     program = dspy.Predict("question -> answer")
328 | 
329 |     with patch("litellm.acompletion") as mock_completion:
330 |         mock_completion.return_value = ModelResponse(
331 |             choices=[Choices(message=Message(content="{'answer': 'Paris'}"))],
332 |             usage=Usage(
333 |                 **{
334 |                     "prompt_tokens": 1117,
335 |                     "completion_tokens": 46,
336 |                     "total_tokens": 1163,
337 |                     "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
338 |                     "completion_tokens_details": {
339 |                         "reasoning_tokens": 0,
340 |                         "audio_tokens": 0,
341 |                         "accepted_prediction_tokens": 0,
342 |                         "rejected_prediction_tokens": 0,
343 |                     },
344 |                 },
345 |             ),
346 |             model="openai/gpt-4o-mini",
347 |         )
348 | 
349 |         coroutines = [
350 |             program.acall(question="What is the capital of France?"),
351 |             program.acall(question="What is the capital of France?"),
352 |             program.acall(question="What is the capital of France?"),
353 |             program.acall(question="What is the capital of France?"),
354 |         ]
355 |         with dspy.settings.context(
356 |             lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True, adapter=dspy.JSONAdapter()
357 |         ):
358 |             results = await asyncio.gather(*coroutines)
359 | 
360 |         assert results[0].get_lm_usage() is not None
361 |         assert results[1].get_lm_usage() is not None
362 | 
363 |         lm_usage0 = results[0].get_lm_usage()["openai/gpt-4o-mini"]
364 |         lm_usage1 = results[1].get_lm_usage()["openai/gpt-4o-mini"]
365 |         assert lm_usage0["prompt_tokens"] == 1117
366 |         assert lm_usage1["prompt_tokens"] == 1117
367 |         assert lm_usage0["completion_tokens"] == 46
368 |         assert lm_usage1["completion_tokens"] == 46
369 |         assert lm_usage0["total_tokens"] == 1163
370 |         assert lm_usage1["total_tokens"] == 1163
371 | 
372 | 
373 | def test_usage_tracker_no_side_effect():
374 |     class MyProgram(dspy.Module):
375 |         def __init__(self):
376 |             self.predict = dspy.Predict("question -> answer")
377 | 
378 |         def forward(self, question: str, **kwargs) -> str:
379 |             return self.predict(question=question).answer
380 | 
381 |     program = MyProgram()
382 |     with dspy.context(lm=DummyLM([{"answer": "Paris"}]), track_usage=True):
383 |         result = program(question="What is the capital of France?")
384 |     assert result == "Paris"
385 | 
386 | 
387 | def test_module_history():
388 |     class MyProgram(dspy.Module):
389 |         def __init__(self, **kwargs):
390 |             super().__init__(**kwargs)
391 |             self.cot = dspy.ChainOfThought("question -> answer")
392 | 
393 |         def forward(self, question: str, **kwargs) -> Prediction:
394 |             return self.cot(question=question)
395 | 
396 |     with patch("litellm.completion") as mock_completion:
397 |         mock_completion.return_value = ModelResponse(
398 |             choices=[
399 |                 Choices(message=Message(content="{'reasoning': 'Paris is the capital of France', 'answer': 'Paris'}"))
400 |             ],
401 |             model="openai/gpt-4o-mini",
402 |         )
403 |         dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter())
404 |         program = MyProgram()
405 |         program(question="What is the capital of France?")
406 | 
407 |         # The second call only invokes the submodule.
408 |         program.cot(question="What is the capital of France?")
409 | 
410 |         # The LM history entity exists in all the ancestor callers.
411 |         assert len(program.history) == 1
412 |         assert len(program.cot.history) == 2
413 |         assert len(program.cot.predict.history) == 2
414 | 
415 |         # The same history entity is shared across all the ancestor callers to reduce memory usage.
416 |         assert id(program.history[0]) == id(program.cot.history[0])
417 | 
418 |         assert program.history[0]["outputs"] == ["{'reasoning': 'Paris is the capital of France', 'answer': 'Paris'}"]
419 | 
420 |         dspy.settings.configure(disable_history=True)
421 | 
422 |         program(question="What is the capital of France?")
423 |         # No history is recorded when history is disabled.
424 |         assert len(program.history) == 1
425 |         assert len(program.cot.history) == 2
426 |         assert len(program.cot.predict.history) == 2
427 | 
428 |         dspy.settings.configure(disable_history=False)
429 | 
430 |         program(question="What is the capital of France?")
431 |         # History is recorded again when history is enabled.
432 |         assert len(program.history) == 2
433 |         assert len(program.cot.history) == 3
434 |         assert len(program.cot.predict.history) == 3
435 | 
436 | 
437 | def test_module_history_with_concurrency():
438 |     class MyProgram(dspy.Module):
439 |         def __init__(self):
440 |             super().__init__()
441 |             self.cot = dspy.ChainOfThought("question -> answer")
442 | 
443 |         def forward(self, question: str, **kwargs) -> Prediction:
444 |             return self.cot(question=question)
445 | 
446 |     with patch("litellm.completion") as mock_completion:
447 |         mock_completion.return_value = ModelResponse(
448 |             choices=[Choices(message=Message(content="{'reasoning': 'N/A', 'answer': 'Holy crab!'}"))],
449 |             model="openai/gpt-4o-mini",
450 |         )
451 |         dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter())
452 |         program = MyProgram()
453 | 
454 |         parallelizer = dspy.Parallel()
455 | 
456 |         parallelizer(
457 |             [
458 |                 (program, {"question": "What is the meaning of life?"}),
459 |                 (program, {"question": "why did a chicken cross the kitchen?"}),
460 |             ]
461 |         )
462 |         assert len(program.history) == 2
463 |         assert len(program.cot.history) == 2
464 |         assert len(program.cot.predict.history) == 2
465 | 
466 | 
467 | @pytest.mark.asyncio
468 | async def test_module_history_async():
469 |     class MyProgram(dspy.Module):
470 |         def __init__(self, **kwargs):
471 |             super().__init__(**kwargs)
472 |             self.cot = dspy.ChainOfThought("question -> answer")
473 | 
474 |         async def aforward(self, question: str, **kwargs) -> Prediction:
475 |             return await self.cot.acall(question=question)
476 | 
477 |     with patch("litellm.acompletion") as mock_completion:
478 |         mock_completion.return_value = ModelResponse(
479 |             choices=[
480 |                 Choices(message=Message(content="{'reasoning': 'Paris is the capital of France', 'answer': 'Paris'}"))
481 |             ],
482 |             model="openai/gpt-4o-mini",
483 |         )
484 |         program = MyProgram()
485 |         with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()):
486 |             await program.acall(question="What is the capital of France?")
487 | 
488 |             # The second call only invokes the submodule.
489 |             await program.cot.acall(question="What is the capital of France?")
490 | 
491 |         # The LM history entity exists in all the ancestor callers.
492 |         assert len(program.history) == 1
493 |         assert len(program.cot.history) == 2
494 |         assert len(program.cot.predict.history) == 2
495 | 
496 |         # The same history entity is shared across all the ancestor callers to reduce memory usage.
497 |         assert id(program.history[0]) == id(program.cot.history[0])
498 | 
499 |         assert program.history[0]["outputs"] == ["{'reasoning': 'Paris is the capital of France', 'answer': 'Paris'}"]
500 | 
501 |         with dspy.context(
502 |             disable_history=True, lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()
503 |         ):
504 |             await program.acall(question="What is the capital of France?")
505 | 
506 |         # No history is recorded when history is disabled.
507 |         assert len(program.history) == 1
508 |         assert len(program.cot.history) == 2
509 |         assert len(program.cot.predict.history) == 2
510 | 
511 |         with dspy.context(
512 |             disable_history=False, lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()
513 |         ):
514 |             await program.acall(question="What is the capital of France?")
515 |         # History is recorded again when history is enabled.
516 |         assert len(program.history) == 2
517 |         assert len(program.cot.history) == 3
518 |         assert len(program.cot.predict.history) == 3
519 | 
520 | 
521 | def test_forward_direct_call_warning(capsys):
522 |     class TestModule(dspy.Module):
523 |         def forward(self, x):
524 |             return x
525 | 
526 |     module = TestModule()
527 |     module.forward("test")
528 |     captured = capsys.readouterr()
529 |     assert "directly is discouraged" in captured.err
530 | 
531 | 
532 | def test_forward_through_call_no_warning(capsys):
533 |     class TestModule(dspy.Module):
534 |         def forward(self, x):
535 |             return x
536 | 
537 |     module = TestModule()
538 |     module(x="test")
539 |     captured = capsys.readouterr()
540 |     assert "directly is discouraged" not in captured.err
541 | 
```
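
The save/load and usage-tracking behaviors exercised above can be reproduced outside pytest in a few lines. The sketch below mirrors `test_save_and_load_with_json` and the usage-tracker tests, using `DummyLM` so no real model is called; the file name and question are arbitrary, and `DummyLM` may report empty token usage:

```python
import dspy
from dspy.utils.dummies import DummyLM

# State-only save/load round trip, as in test_save_and_load_with_json.
model = dspy.ChainOfThought(dspy.Signature("q -> a"))
model.predict.signature = model.predict.signature.with_instructions("You are a helpful assistant.")
model.save("model.json")  # arbitrary path; with the default save_program=False, only the state is saved

restored = dspy.ChainOfThought(dspy.Signature("q -> a"))
restored.load("model.json")
assert str(restored.predict.signature) == str(model.predict.signature)

# Usage tracking, as in the usage-tracker tests (DummyLM may report no tokens).
with dspy.context(lm=DummyLM([{"reasoning": "n/a", "a": "Paris"}]), track_usage=True):
    prediction = dspy.ChainOfThought(dspy.Signature("q -> a"))(q="What is the capital of France?")
print(prediction.get_lm_usage())
```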

--------------------------------------------------------------------------------
/dspy/clients/lm_local_arbor.py:
--------------------------------------------------------------------------------

```python
  1 | import time
  2 | from datetime import datetime
  3 | from typing import TYPE_CHECKING, Any, TypedDict
  4 | from urllib.parse import urljoin
  5 | 
  6 | import openai
  7 | import requests
  8 | 
  9 | import dspy
 10 | from dspy.clients.provider import Provider, ReinforceJob, TrainingJob
 11 | from dspy.clients.utils_finetune import GRPOGroup, MultiGPUConfig, TrainDataFormat, TrainingStatus, save_data
 12 | 
 13 | if TYPE_CHECKING:
 14 |     from dspy.clients.lm import LM
 15 | 
 16 | 
 17 | class GRPOTrainKwargs(TypedDict):
 18 |     num_generations: int
 19 | 
 20 | 
 21 | class ArborTrainingJob(TrainingJob):
 22 |     def __init__(self, *args, **kwargs):
 23 |         super().__init__(*args, **kwargs)
 24 |         self.provider_file_id = None
 25 |         self.provider_job_id = None
 26 | 
 27 |     def cancel(self):
 28 |         if ArborProvider.does_job_exist(self.provider_job_id):
 29 |             status = self.status()
 30 |             if ArborProvider.is_terminal_training_status(status):
 31 |                 err_msg = "Jobs that are complete cannot be canceled."
 32 |                 err_msg += f" Job with ID {self.provider_job_id} is done."
 33 |                 raise Exception(err_msg)
 34 |             openai.fine_tuning.jobs.cancel(self.provider_job_id)
 35 |             self.provider_job_id = None
 36 | 
 37 |         if self.provider_file_id is not None:
 38 |             if ArborProvider.does_file_exist(self.provider_file_id):
 39 |                 openai.files.delete(self.provider_file_id)
 40 |             self.provider_file_id = None
 41 | 
 42 |         super().cancel()
 43 | 
 44 |     def status(self) -> TrainingStatus:
 45 |         status = ArborProvider.get_training_status(self.provider_job_id)
 46 |         return status
 47 | 
 48 | 
 49 | class ArborReinforceJob(ReinforceJob):
 50 |     DEFAULT_TRAIN_KWARGS = {  # noqa: RUF012
 51 |         "temperature": 0.9,
 52 |         "beta": 0.04,
 53 |         "num_iterations": 1,
 54 |         "per_device_train_batch_size": 8,
 55 |         "learning_rate": 1e-6,
 56 |         "gradient_accumulation_steps": 1,
 57 |         # This is False by default in TRL, but it makes sense to enable it here.
 58 |         "gradient_checkpointing": True,
 59 |         "lr_scheduler_type": "constant_with_warmup",
 60 |         "max_prompt_length": None,
 61 |         "max_completion_length": None,
 62 |         "gradient_checkpointing_kwargs": None,
 63 |         "bf16": False,
 64 |         "scale_rewards": True,
 65 |         "max_grad_norm": 1.0,
 66 |         "report_to": "none",
 67 |         "log_completions": True,
 68 |         "logging_steps": 100,
 69 |         # By default (None), the model's max context length is used.
 70 |         "max_context_length": None,
 71 |         "lora": False,
 72 |     }
 73 | 
 74 |     def __init__(self, lm: "LM", train_kwargs: GRPOTrainKwargs, gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)):
 75 |         # The teleprompter must ensure that this is set
 76 |         if "num_generations" not in train_kwargs:
 77 |             raise ValueError("num_generations must be set in the training kwargs")
 78 | 
 79 |         self.lm = lm
 80 |         self.train_kwargs = train_kwargs
 81 |         self.provider_job_id = None
 82 |         self.checkpoints = {}
 83 |         self.last_checkpoint = None
 84 |         self.gpu_config = gpu_config
 85 | 
 86 |     def initialize(self):
 87 |         # TODO(GRPO Team): Set provider job ID
 88 |         num_generations = self.train_kwargs.get("num_generations")
 89 |         temperature = self.train_kwargs.get("temperature", self.DEFAULT_TRAIN_KWARGS["temperature"])
 90 |         beta = self.train_kwargs.get("beta", self.DEFAULT_TRAIN_KWARGS["beta"])
 91 |         num_iterations = self.train_kwargs.get("num_iterations", self.DEFAULT_TRAIN_KWARGS["num_iterations"])
 92 |         per_device_train_batch_size = self.train_kwargs.get(
 93 |             "per_device_train_batch_size", self.DEFAULT_TRAIN_KWARGS["per_device_train_batch_size"]
 94 |         )
 95 |         learning_rate = self.train_kwargs.get("learning_rate", self.DEFAULT_TRAIN_KWARGS["learning_rate"])
 96 |         gradient_accumulation_steps = self.train_kwargs.get(
 97 |             "gradient_accumulation_steps", self.DEFAULT_TRAIN_KWARGS["gradient_accumulation_steps"]
 98 |         )
 99 |         gradient_checkpointing = self.train_kwargs.get(
100 |             "gradient_checkpointing", self.DEFAULT_TRAIN_KWARGS["gradient_checkpointing"]
101 |         )
102 |         lr_scheduler_type = self.train_kwargs.get("lr_scheduler_type", self.DEFAULT_TRAIN_KWARGS["lr_scheduler_type"])
103 |         max_prompt_length = self.train_kwargs.get("max_prompt_length", self.DEFAULT_TRAIN_KWARGS["max_prompt_length"])
104 |         max_completion_length = self.train_kwargs.get(
105 |             "max_completion_length", self.DEFAULT_TRAIN_KWARGS["max_completion_length"]
106 |         )
107 |         bf16 = self.train_kwargs.get("bf16", self.DEFAULT_TRAIN_KWARGS["bf16"])
108 |         scale_rewards = self.train_kwargs.get("scale_rewards", self.DEFAULT_TRAIN_KWARGS["scale_rewards"])
109 |         gradient_checkpointing_kwargs = self.train_kwargs.get(
110 |             "gradient_checkpointing_kwargs", self.DEFAULT_TRAIN_KWARGS["gradient_checkpointing_kwargs"]
111 |         )
112 |         max_grad_norm = self.train_kwargs.get("max_grad_norm", self.DEFAULT_TRAIN_KWARGS["max_grad_norm"])
113 |         report_to = self.train_kwargs.get("report_to", self.DEFAULT_TRAIN_KWARGS["report_to"])
114 |         log_completions = self.train_kwargs.get("log_completions", self.DEFAULT_TRAIN_KWARGS["log_completions"])
115 |         logging_steps = self.train_kwargs.get("logging_steps", self.DEFAULT_TRAIN_KWARGS["logging_steps"])
116 |         max_context_length = self.train_kwargs.get(
117 |             "max_context_length", self.DEFAULT_TRAIN_KWARGS["max_context_length"]
118 |         )
119 |         lora = self.train_kwargs.get("lora", self.DEFAULT_TRAIN_KWARGS["lora"])
120 |         api_base = self.lm.kwargs["api_base"]
121 | 
122 |         finetune_model = ArborProvider._remove_provider_prefix(self.lm.model)
123 |         # Only multi-GPU is supported for now
124 |         gpu_config_type = "multi"
125 |         data = {
126 |             "model": finetune_model,
127 |             "num_generations": num_generations,
128 |             "temperature": temperature,
129 |             "beta": beta,
130 |             "num_iterations": num_iterations,
131 |             "per_device_train_batch_size": per_device_train_batch_size,
132 |             "learning_rate": learning_rate,
133 |             "gradient_accumulation_steps": gradient_accumulation_steps,
134 |             "gradient_checkpointing": gradient_checkpointing,
135 |             "lr_scheduler_type": lr_scheduler_type,
136 |             "max_prompt_length": max_prompt_length,
137 |             "max_completion_length": max_completion_length,
138 |             "bf16": bf16,
139 |             "scale_rewards": scale_rewards,
140 |             "gradient_checkpointing_kwargs": gradient_checkpointing_kwargs,
141 |             "max_grad_norm": max_grad_norm,
142 |             "report_to": report_to,
143 |             "log_completions": log_completions,
144 |             "logging_steps": logging_steps,
145 |             "max_context_length": max_context_length,
146 |             "lora": lora,
147 |             "gpu_config": {
148 |                 "type": gpu_config_type,
149 |                 gpu_config_type: self.gpu_config,
150 |             },
151 |         }
152 |         url = urljoin(api_base, "fine_tuning/grpo/initialize")
153 |         headers = {"Content-Type": "application/json"}
154 |         response = requests.post(url=url, headers=headers, json=data)
155 |         assert response.status_code == 200, f"Failed to initialize GRPO: {response}"
156 |         response = response.json()
157 |         self.lm.model = ArborProvider._add_provider_prefix(response["current_model"])
158 |         self.provider_job_id = response.get("job_id")
159 | 
160 |     def _run_grpo_step_one_group(
161 |         self, train_group: GRPOGroup, train_data_format: TrainDataFormat | str | None = None
162 |     ):
163 |         # TODO: Check that the data follows the intended format
164 |         api_base = self.lm.kwargs["api_base"]
165 |         # api_key = self.lm.kwargs["api_key"]
166 | 
167 |         finetune_model = ArborProvider._remove_provider_prefix(self.lm.model)
168 |         data = {"job_id": self.provider_job_id, "model": finetune_model, "batch": train_group}
169 |         url = urljoin(api_base, f"fine_tuning/grpo/{self.provider_job_id}/step")
170 |         headers = {"Content-Type": "application/json"}
171 |         response = requests.post(url, headers=headers, json=data)
172 |         assert response.status_code == 200, f"Failed to run a GRPO step: {response.text}"
173 |         response = response.json()
174 |         assert "current_model" in response, f"Response does not contain the next model ID to be used: {response}"
175 |         current_model = response["current_model"]
176 |         self.lm.model = ArborProvider._add_provider_prefix(current_model)
177 | 
178 |     def step(self, train_data: list[GRPOGroup], train_data_format: TrainDataFormat | str | None):
179 |         # Note: TrainDataFormat specifies the format for the innermost dict.
180 |         # Because we run GRPO at the group level, train_data will be a list of
181 |         # groups, where each group is a list of GRPOChatData. Our teleprompters
182 |         # ensure that we pass the right data format.
183 |         # We can consider making this distinction clearer, e.g., by having two
184 |         # different step methods or changing our smallest data format to be the
185 |         # GRPO group.
186 |         # TODO: Support step on the server side
187 |         assert (
188 |             train_data_format == TrainDataFormat.GRPO_CHAT
189 |         ), f"GRPO only supports the GRPO_CHAT data format. Got {train_data_format} instead."
190 |         for group in train_data:
191 |             self._run_grpo_step_one_group(group, train_data_format)
192 | 
193 |     def save_checkpoint(self, checkpoint_name: str, score: float | None = None):
194 |         api_base = self.lm.kwargs["api_base"]
195 |         url = urljoin(api_base, f"fine_tuning/grpo/{self.provider_job_id}/checkpoint")
196 |         headers = {"Content-Type": "application/json"}
197 |         body = {"job_id": self.provider_job_id, "checkpoint_name": checkpoint_name}
198 |         response = requests.post(url, headers=headers, json=body)
199 |         assert response.status_code == 200, f"Failed to save checkpoint: {response.text}"
200 |         response = response.json()
201 | 
202 |         last_checkpoint = response["last_checkpoint"]
203 |         checkpoints = response["checkpoints"]
204 |         checkpoint_model_path = checkpoints[last_checkpoint]
205 |         self.checkpoints[last_checkpoint] = {
206 |             "model_path": checkpoint_model_path,
207 |             "score": score,
208 |         }
209 |         self.last_checkpoint = last_checkpoint
210 | 
211 |     def terminate(self):
212 |         api_base = self.lm.kwargs["api_base"]
213 | 
214 |         url = urljoin(api_base, f"fine_tuning/grpo/{self.provider_job_id}/terminate")
215 |         headers = {"Content-Type": "application/json"}
216 |         body = {"job_id": self.provider_job_id}
217 |         response = requests.post(url, headers=headers, json=body)
218 |         assert response.status_code == 200, f"Failed to terminate GRPO: {response.text}"
219 | 
220 |         response = response.json()
221 |         current_model = response["current_model"]
222 |         self.lm.model = ArborProvider._add_provider_prefix(current_model)
223 | 
224 |     def cancel(self):
225 |         if self.provider_job_id:
226 |             api_base = self.lm.kwargs["api_base"]
227 |             url = urljoin(api_base, f"fine_tuning/grpo/{self.provider_job_id}/cancel")
228 |             headers = {"Content-Type": "application/json"}
229 |             response = requests.post(url, headers=headers)
230 |             if response.status_code == 200:
231 |                 self.provider_job_id = None
232 |             else:
233 |                 raise Exception(f"Failed to cancel GRPO job: {response.text}")
234 | 
235 |     def status(self) -> TrainingStatus:
236 |         status = ArborProvider.get_training_status(self.provider_job_id)
237 |         return status
238 | 
239 | 
240 | class ArborProvider(Provider):
241 |     def __init__(self):
242 |         super().__init__()
243 |         self.finetunable = True
244 |         self.reinforceable = True
245 |         self.TrainingJob = ArborTrainingJob
246 |         self.ReinforceJob = ArborReinforceJob
247 | 
248 |     @staticmethod
249 |     def launch(lm: "LM", launch_kwargs: dict[str, Any] | None = None):
250 |         model = ArborProvider._remove_provider_prefix(lm.model)
251 | 
252 |         api_base = lm.kwargs["api_base"]
253 | 
254 |         launch_kwargs = launch_kwargs or lm.launch_kwargs
255 | 
256 |         # Make request to launch endpoint
257 |         response = requests.post(urljoin(api_base, "chat/launch"), json={"model": model, "launch_kwargs": launch_kwargs})
258 | 
259 |         if response.status_code != 200:
260 |             raise Exception(f"Failed to launch model. Status code: {response.status_code}, Response: {response.text}")
261 | 
262 |         print(f"Inference server for model {model} launched successfully")
263 | 
264 |     @staticmethod
265 |     def kill(lm: "LM", launch_kwargs: dict[str, Any] | None = None):
266 |         api_base = lm.kwargs["api_base"]
267 | 
268 |         response = requests.post(
269 |             urljoin(api_base, "chat/kill"),
270 |         )
271 | 
272 |         if response.status_code != 200:
273 |             raise Exception(f"Failed to kill model. Status code: {response.status_code}, Response: {response.text}")
274 | 
275 |         print("Inference killed successfully")
276 | 
277 |     @staticmethod
278 |     def _remove_provider_prefix(model: str) -> str:
279 |         if model.startswith("openai/"):
280 |             model = model[7:]
281 |         if model.startswith("arbor:"):
282 |             model = model[6:]
283 |         return model
284 | 
285 |     @staticmethod
286 |     def _add_provider_prefix(model: str) -> str:
287 |         if not model.startswith("openai/arbor:"):
288 |             model = "openai/arbor:" + model
289 |         return model
290 | 
291 |     @staticmethod
292 |     def _get_arbor_base_api():
293 |         # TODO: We will delete this method once we start passing the LM object
294 |         # to finetune.
295 |         settings = dspy.settings
296 | 
297 |         if not hasattr(settings, "arbor_api_base"):
298 |             raise ValueError(
299 |                 "Arbor API base not set. Please set the `dspy.settings.arbor_api_base` to the URL for the Arbor server (e.g. 'http://localhost:8000/v1/')."
300 |             )
301 |         return dspy.settings.arbor_api_base
302 | 
303 |     @staticmethod
304 |     def finetune(
305 |         job: ArborTrainingJob,
306 |         model: str,
307 |         train_data: list[dict[str, Any]],
308 |         train_data_format: TrainDataFormat | None,
309 |         train_kwargs: dict[str, Any] | None = None,
310 |     ) -> str:
311 |         # TODO: We want to refactor finetune so that it takes in an LM.
312 |         # Until then, we use the following to get the API information. The
313 |         # following is a dummy call to ensure that dspy.settings.arbor_api_base
314 |         # is set.
315 |         ArborProvider._get_arbor_base_api()
316 | 
317 |         model = ArborProvider._remove_provider_prefix(model)
318 | 
319 |         print("[Arbor Provider] Validating the data format")
320 |         ArborProvider.validate_data_format(train_data_format)
321 | 
322 |         print("[Arbor Provider] Saving the data to a file")
323 |         data_path = save_data(train_data)
324 |         print(f"[Arbor Provider] Data saved to {data_path}")
325 | 
326 |         print("[Arbor Provider] Uploading the data to the provider")
327 |         provider_file_id = ArborProvider.upload_data(data_path)
328 |         job.provider_file_id = provider_file_id
329 | 
330 |         print("[Arbor Provider] Starting remote training")
331 |         provider_job_id = ArborProvider._start_remote_training(
332 |             train_file_id=job.provider_file_id,
333 |             model=model,
334 |             train_kwargs=train_kwargs,
335 |         )
336 |         job.provider_job_id = provider_job_id
337 |         print(f"[Arbor Provider] Job started with the Arbor Job ID {provider_job_id}")
338 | 
339 |         print("[Arbor Provider] Waiting for training to complete")
340 |         ArborProvider.wait_for_job(job, train_kwargs)
341 | 
342 |         print("[Arbor Provider] Attempting to retrieve the trained model")
343 |         model = ArborProvider.get_trained_model(job)
344 |         print(f"[Arbor Provider] Model retrieved: {model}")
345 | 
346 |         return ArborProvider._add_provider_prefix(model)
347 | 
348 |     @staticmethod
349 |     def does_job_exist(job_id: str, training_kwargs: dict[str, Any] | None = None) -> bool:
350 |         try:
351 |             original_base_url = openai.base_url
352 |             openai.base_url = ArborProvider._get_arbor_base_api()
353 |             openai.fine_tuning.jobs.retrieve(job_id)
354 |             openai.base_url = original_base_url
355 |             return True
356 |         except Exception:
357 |             return False
358 | 
359 |     @staticmethod
360 |     def does_file_exist(file_id: str, training_kwargs: dict[str, Any] | None = None) -> bool:
361 |         try:
362 |             original_base_url = openai.base_url
363 |             openai.base_url = ArborProvider._get_arbor_base_api()
364 |             openai.files.retrieve(file_id)
365 |             openai.base_url = original_base_url
366 |             return True
367 |         except Exception:
368 |             return False
369 | 
370 |     @staticmethod
371 |     def is_terminal_training_status(status: TrainingStatus) -> bool:
372 |         return status in [
373 |             TrainingStatus.succeeded,
374 |             TrainingStatus.failed,
375 |             TrainingStatus.cancelled,
376 |         ]
377 | 
378 |     @staticmethod
379 |     def get_training_status(job_id: str, training_kwargs: dict[str, Any] | None = None) -> TrainingStatus:
380 |         provider_status_to_training_status = {
381 |             "validating_files": TrainingStatus.pending,
382 |             "queued": TrainingStatus.pending,
383 |             "running": TrainingStatus.running,
384 |             "succeeded": TrainingStatus.succeeded,
385 |             "failed": TrainingStatus.failed,
386 |             "cancelled": TrainingStatus.cancelled,
387 |             "pending": TrainingStatus.pending,
388 |             "pending_pause": TrainingStatus.pending,
389 |             "pending_resume": TrainingStatus.pending,
390 |             "paused": TrainingStatus.pending,
391 |             "pending_cancel": TrainingStatus.pending,
392 |         }
393 | 
394 |         if job_id is None:
395 |             print("There is no active job.")
396 |             return TrainingStatus.not_started
397 | 
398 |         err_msg = f"Job with ID {job_id} does not exist."
399 |         assert ArborProvider.does_job_exist(job_id, training_kwargs), err_msg
400 | 
401 |         original_base_url = openai.base_url
402 |         openai.base_url = ArborProvider._get_arbor_base_api()
403 |         provider_job = openai.fine_tuning.jobs.retrieve(job_id)
404 |         openai.base_url = original_base_url
405 | 
406 |         provider_status = provider_job.status
407 |         status = provider_status_to_training_status[provider_status]
408 | 
409 |         return status
410 | 
411 |     @staticmethod
412 |     def validate_data_format(data_format: TrainDataFormat):
413 |         supported_data_formats = [
414 |             TrainDataFormat.CHAT,
415 |             TrainDataFormat.COMPLETION,
416 |             TrainDataFormat.GRPO_CHAT,
417 |         ]
418 | 
419 |         if data_format not in supported_data_formats:
420 |             err_msg = f"Arbor does not support the data format {data_format}."
421 |             raise ValueError(err_msg)
422 | 
423 |     @staticmethod
424 |     def upload_data(data_path: str, training_kwargs: dict[str, Any] | None = None) -> str:
425 |         original_base_url = openai.base_url
426 |         openai.base_url = ArborProvider._get_arbor_base_api()
427 |         provider_file = openai.files.create(
428 |             file=open(data_path, "rb"),
429 |             purpose="fine-tune",
430 |         )
431 |         openai.base_url = original_base_url
432 | 
433 |         return provider_file.id
434 | 
435 |     @staticmethod
436 |     def _start_remote_training(train_file_id: str, model: str, train_kwargs: dict[str, Any]) -> str:
437 |         train_kwargs = train_kwargs or {}
438 |         original_base_url = openai.base_url
439 |         openai.base_url = ArborProvider._get_arbor_base_api()
440 |         provider_job = openai.fine_tuning.jobs.create(
441 |             model=model,
442 |             training_file=train_file_id,
443 |             hyperparameters=train_kwargs,
444 |         )
445 |         openai.base_url = original_base_url
446 |         return provider_job.id
447 | 
448 |     @staticmethod
449 |     def wait_for_job(
450 |         job: TrainingJob,
451 |         training_kwargs: dict[str, Any] | None = None,
452 |         poll_frequency: int = 20,
453 |     ):
454 |         done = False
455 |         cur_event_id = None
456 |         reported_estimated_time = False
457 |         while not done:
458 |             # Report estimated time if not already reported
459 |             if not reported_estimated_time:
460 |                 original_base_url = openai.base_url
461 |                 openai.base_url = ArborProvider._get_arbor_base_api()
462 |                 remote_job = openai.fine_tuning.jobs.retrieve(job.provider_job_id)
463 |                 openai.base_url = original_base_url
464 | 
465 |                 timestamp = remote_job.estimated_finish
466 |                 if timestamp:
467 |                     estimated_finish_dt = datetime.fromtimestamp(timestamp)
468 |                     delta_dt = estimated_finish_dt - datetime.now()
469 |                     print(f"[Arbor Provider] The Arbor estimated time remaining is: {delta_dt}")
470 |                     reported_estimated_time = True
471 | 
472 |             # Get new events
473 |             original_base_url = openai.base_url
474 |             openai.base_url = ArborProvider._get_arbor_base_api()
475 |             page = openai.fine_tuning.jobs.list_events(fine_tuning_job_id=job.provider_job_id, limit=1)
476 |             openai.base_url = original_base_url
477 | 
478 |             new_event = page.data[0] if page.data else None
479 |             if new_event and new_event.id != cur_event_id:
480 |                 dt = datetime.fromtimestamp(new_event.created_at)
481 |                 print(f"[Arbor Provider] {dt} {new_event.message}")
482 |                 cur_event_id = new_event.id
483 | 
484 |             # Sleep and update the flag
485 |             time.sleep(poll_frequency)
486 |             done = ArborProvider.is_terminal_training_status(job.status())
487 | 
488 |     @staticmethod
489 |     def get_trained_model(job, training_kwargs: dict[str, Any] | None = None):
490 |         status = job.status()
491 |         if status != TrainingStatus.succeeded:
492 |             err_msg = f"Job status is {status}."
493 |             err_msg += f" Must be {TrainingStatus.succeeded} to retrieve model."
494 |             raise Exception(err_msg)
495 | 
496 |         original_base_url = openai.base_url
497 |         openai.base_url = ArborProvider._get_arbor_base_api()
498 |         provider_job = openai.fine_tuning.jobs.retrieve(job.provider_job_id)
499 |         openai.base_url = original_base_url
500 | 
501 |         finetuned_model = provider_job.fine_tuned_model
502 |         return finetuned_model
503 | 
```
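
As a rough sketch of how this provider is driven, the snippet below points a `dspy.LM` at a locally running Arbor server and starts a GRPO job with the required `num_generations`. The server URL, model name, and GPU split are assumptions, and in practice the GRPO teleprompter orchestrates these calls rather than user code:

```python
import dspy
from dspy.clients.lm_local_arbor import ArborReinforceJob
from dspy.clients.utils_finetune import MultiGPUConfig

# Assumed local Arbor deployment and model name -- adjust to your setup.
lm = dspy.LM(
    "openai/arbor:Qwen/Qwen2.5-0.5B-Instruct",
    api_base="http://localhost:8000/v1/",
    api_key="arbor",
)

job = ArborReinforceJob(
    lm,
    train_kwargs={"num_generations": 4},  # required by ArborReinforceJob.__init__
    gpu_config=MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1),
)
job.initialize()  # POSTs to fine_tuning/grpo/initialize and updates lm.model
# ... call job.step(train_data, TrainDataFormat.GRPO_CHAT) with GRPO groups ...
job.terminate()   # POSTs to fine_tuning/grpo/<job_id>/terminate
```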

--------------------------------------------------------------------------------
/dspy/adapters/base.py:
--------------------------------------------------------------------------------

```python
  1 | import logging
  2 | from typing import TYPE_CHECKING, Any, get_origin
  3 | 
  4 | import json_repair
  5 | import litellm
  6 | 
  7 | from dspy.adapters.types import History, Type
  8 | from dspy.adapters.types.base_type import split_message_content_for_custom_types
  9 | from dspy.adapters.types.tool import Tool, ToolCalls
 10 | from dspy.experimental import Citations
 11 | from dspy.signatures.signature import Signature
 12 | from dspy.utils.callback import BaseCallback, with_callbacks
 13 | 
 14 | logger = logging.getLogger(__name__)
 15 | 
 16 | if TYPE_CHECKING:
 17 |     from dspy.clients.lm import LM
 18 | 
 19 | _DEFAULT_NATIVE_RESPONSE_TYPES = [Citations]
 20 | 
 21 | 
 22 | class Adapter:
 23 |     """Base Adapter class.
 24 | 
 25 |     The Adapter serves as the interface layer between DSPy module/signature and Language Models (LMs). It handles the
 26 |     complete transformation pipeline from DSPy inputs to LM calls and back to structured outputs.
 27 | 
 28 |     Key responsibilities:
 29 |         - Transform user inputs and signatures into properly formatted LM prompts, which also instruct the LM to
 30 |             produce its response in a specific format.
 31 |         - Parse LM outputs into dictionaries matching the signature's output fields.
 32 |         - Enable/disable native LM features (function calling, citations, etc.) based on configuration.
 33 |         - Handle conversation history, few-shot examples, and custom type processing.
 34 | 
 35 |     The adapter pattern allows DSPy to work with different LM interfaces while maintaining a consistent programming
 36 |     model for users.
 37 |     """
 38 | 
 39 |     def __init__(
 40 |         self,
 41 |         callbacks: list[BaseCallback] | None = None,
 42 |         use_native_function_calling: bool = False,
 43 |         native_response_types: list[type[Type]] | None = None,
 44 |     ):
 45 |         """
 46 |         Args:
 47 |             callbacks: List of callback functions to execute during `format()` and `parse()` methods. Callbacks can be
 48 |                 used for logging, monitoring, or custom processing. Defaults to None (empty list).
 49 |             use_native_function_calling: Whether to enable native function calling capabilities when the LM supports it.
 50 |                 If True, the adapter will automatically configure function calling when input fields contain `dspy.Tool`
 51 |                 or `list[dspy.Tool]` types. Defaults to False.
 52 |             native_response_types: List of output field types that should be handled by native LM features rather than
 53 |                 adapter parsing. For example, `dspy.Citations` can be populated directly by citation APIs
 54 |                 (e.g., Anthropic's citation feature). Defaults to `[Citations]`.
 55 |         """
 56 |         self.callbacks = callbacks or []
 57 |         self.use_native_function_calling = use_native_function_calling
 58 |         self.native_response_types = native_response_types or _DEFAULT_NATIVE_RESPONSE_TYPES
 59 | 
 60 |     def __init_subclass__(cls, **kwargs) -> None:
 61 |         super().__init_subclass__(**kwargs)
 62 | 
 63 |         # Decorate the format() and parse() methods with with_callbacks.
 64 |         cls.format = with_callbacks(cls.format)
 65 |         cls.parse = with_callbacks(cls.parse)
 66 | 
 67 |     def _call_preprocess(
 68 |         self,
 69 |         lm: "LM",
 70 |         lm_kwargs: dict[str, Any],
 71 |         signature: type[Signature],
 72 |         inputs: dict[str, Any],
 73 |     ) -> type[Signature]:
 74 |         if self.use_native_function_calling:
 75 |             tool_call_input_field_name = self._get_tool_call_input_field_name(signature)
 76 |             tool_call_output_field_name = self._get_tool_call_output_field_name(signature)
 77 | 
 78 |             if tool_call_output_field_name and tool_call_input_field_name is None:
 79 |                 raise ValueError(
 80 |                     f"You provided an output field {tool_call_output_field_name} to receive the tool calls information, "
 81 |                     "but did not provide any tools as the input. Please provide a list of tools as the input by adding an "
 82 |                     "input field with type `list[dspy.Tool]`."
 83 |                 )
 84 | 
 85 |             if tool_call_output_field_name and litellm.supports_function_calling(model=lm.model):
 86 |                 tools = inputs[tool_call_input_field_name]
 87 |                 tools = tools if isinstance(tools, list) else [tools]
 88 | 
 89 |                 litellm_tools = []
 90 |                 for tool in tools:
 91 |                     litellm_tools.append(tool.format_as_litellm_function_call())
 92 | 
 93 |                 lm_kwargs["tools"] = litellm_tools
 94 | 
 95 |                 signature_for_native_function_calling = signature.delete(tool_call_output_field_name)
 96 |                 signature_for_native_function_calling = signature_for_native_function_calling.delete(
 97 |                     tool_call_input_field_name
 98 |                 )
 99 | 
100 |                 return signature_for_native_function_calling
101 | 
102 |         # Handle custom types that use native response
103 |         for name, field in signature.output_fields.items():
104 |             if (
105 |                 isinstance(field.annotation, type)
106 |                 and issubclass(field.annotation, Type)
107 |                 and field.annotation in self.native_response_types
108 |             ):
109 |                 signature = signature.delete(name)
110 | 
111 |         return signature
112 | 
113 |     def _call_postprocess(
114 |         self,
115 |         processed_signature: type[Signature],
116 |         original_signature: type[Signature],
117 |         outputs: list[dict[str, Any]],
118 |         lm: "LM",
119 |     ) -> list[dict[str, Any]]:
120 |         values = []
121 | 
122 |         tool_call_output_field_name = self._get_tool_call_output_field_name(original_signature)
123 | 
124 |         for output in outputs:
125 |             output_logprobs = None
126 |             tool_calls = None
127 |             text = output
128 | 
129 |             if isinstance(output, dict):
130 |                 text = output["text"]
131 |                 output_logprobs = output.get("logprobs")
132 |                 tool_calls = output.get("tool_calls")
133 | 
134 |             if text:
135 |                 value = self.parse(processed_signature, text)
136 |                 for field_name in original_signature.output_fields.keys():
137 |                     if field_name not in value:
138 |                         # Set fields that are not present in the processed signature to None for consistency.
139 |                         value[field_name] = None
140 |             else:
141 |                 value = {}
142 |                 for field_name in original_signature.output_fields.keys():
143 |                     value[field_name] = None
144 | 
145 |             if tool_calls and tool_call_output_field_name:
146 |                 tool_calls = [
147 |                     {
148 |                         "name": v["function"]["name"],
149 |                         "args": json_repair.loads(v["function"]["arguments"]),
150 |                     }
151 |                     for v in tool_calls
152 |                 ]
153 |                 value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls)
154 | 
155 |             # Parse custom types that do not rely on the adapter's parsing.
156 |             for name, field in original_signature.output_fields.items():
157 |                 if (
158 |                     isinstance(field.annotation, type)
159 |                     and issubclass(field.annotation, Type)
160 |                     and field.annotation in self.native_response_types
161 |                 ):
162 |                     value[name] = field.annotation.parse_lm_response(output)
163 | 
164 |             if output_logprobs:
165 |                 value["logprobs"] = output_logprobs
166 | 
167 |             values.append(value)
168 | 
169 |         return values
170 | 
171 |     def __call__(
172 |         self,
173 |         lm: "LM",
174 |         lm_kwargs: dict[str, Any],
175 |         signature: type[Signature],
176 |         demos: list[dict[str, Any]],
177 |         inputs: dict[str, Any],
178 |     ) -> list[dict[str, Any]]:
179 |         """
180 |         Execute the adapter pipeline: format inputs, call LM, and parse outputs.
181 | 
182 |         Args:
183 |             lm: The Language Model instance to use for generation. Must be an instance of `dspy.BaseLM`.
184 |             lm_kwargs: Additional keyword arguments to pass to the LM call (e.g., temperature, max_tokens). These are
185 |                 passed directly to the LM.
186 |             signature: The DSPy signature associated with this LM call.
187 |             demos: List of few-shot examples to include in the prompt. Each dictionary should contain keys matching the
188 |                 signature's input and output field names. Examples are formatted as user/assistant message pairs.
189 |             inputs: The current input values for this call. Keys must match the signature's input field names.
190 | 
191 |         Returns:
192 |             List of dictionaries representing parsed LM responses. Each dictionary contains keys matching the
193 |             signature's output field names. For multiple generations (n > 1), returns multiple dictionaries.
194 |         """
195 |         processed_signature = self._call_preprocess(lm, lm_kwargs, signature, inputs)
196 |         inputs = self.format(processed_signature, demos, inputs)
197 | 
198 |         outputs = lm(messages=inputs, **lm_kwargs)
199 |         return self._call_postprocess(processed_signature, signature, outputs, lm)
200 | 
201 |     async def acall(
202 |         self,
203 |         lm: "LM",
204 |         lm_kwargs: dict[str, Any],
205 |         signature: type[Signature],
206 |         demos: list[dict[str, Any]],
207 |         inputs: dict[str, Any],
208 |     ) -> list[dict[str, Any]]:
209 |         processed_signature = self._call_preprocess(lm, lm_kwargs, signature, inputs)
210 |         inputs = self.format(processed_signature, demos, inputs)
211 | 
212 |         outputs = await lm.acall(messages=inputs, **lm_kwargs)
213 |         return self._call_postprocess(processed_signature, signature, outputs, lm)
214 | 
215 |     def format(
216 |         self,
217 |         signature: type[Signature],
218 |         demos: list[dict[str, Any]],
219 |         inputs: dict[str, Any],
220 |     ) -> list[dict[str, Any]]:
221 |         """Format the input messages for the LM call.
222 | 
223 |         This method converts the DSPy structured input along with few-shot examples and conversation history into
224 |         multiturn messages as expected by the LM. For custom adapters, this method can be overridden to customize
225 |         the formatting of the input messages.
226 | 
227 |         In general, we recommend that the messages have the following structure:
228 |         ```
229 |         [
230 |             {"role": "system", "content": system_message},
231 |             # Begin few-shot examples
232 |             {"role": "user", "content": few_shot_example_1_input},
233 |             {"role": "assistant", "content": few_shot_example_1_output},
234 |             {"role": "user", "content": few_shot_example_2_input},
235 |             {"role": "assistant", "content": few_shot_example_2_output},
236 |             ...
237 |             # End few-shot examples
238 |             # Begin conversation history
239 |             {"role": "user", "content": conversation_history_1_input},
240 |             {"role": "assistant", "content": conversation_history_1_output},
241 |             {"role": "user", "content": conversation_history_2_input},
242 |             {"role": "assistant", "content": conversation_history_2_output},
243 |             ...
244 |             # End conversation history
245 |             {"role": "user", "content": current_input},
246 |         ]
247 |         ```
248 | 
249 |         The system message should contain the field description, field structure, and task description.
250 | 
251 | 
252 |         Args:
253 |             signature: The DSPy signature for which to format the input messages.
254 |             demos: A list of few-shot examples.
255 |             inputs: The input arguments to the DSPy module.
256 | 
257 |         Returns:
258 |             A list of multiturn messages as expected by the LM.
259 |         """
260 |         inputs_copy = dict(inputs)
261 | 
262 |         # If the signature and inputs have conversation history, we need to format the conversation history and
263 |         # remove the history field from the signature.
264 |         history_field_name = self._get_history_field_name(signature)
265 |         if history_field_name:
266 |             # In order to format the conversation history, we need to remove the history field from the signature.
267 |             signature_without_history = signature.delete(history_field_name)
268 |             conversation_history = self.format_conversation_history(
269 |                 signature_without_history,
270 |                 history_field_name,
271 |                 inputs_copy,
272 |             )
273 | 
274 |         messages = []
275 |         system_message = (
276 |             f"{self.format_field_description(signature)}\n"
277 |             f"{self.format_field_structure(signature)}\n"
278 |             f"{self.format_task_description(signature)}"
279 |         )
280 |         messages.append({"role": "system", "content": system_message})
281 |         messages.extend(self.format_demos(signature, demos))
282 |         if history_field_name:
283 |             # Conversation history and current input
284 |             content = self.format_user_message_content(signature_without_history, inputs_copy, main_request=True)
285 |             messages.extend(conversation_history)
286 |             messages.append({"role": "user", "content": content})
287 |         else:
288 |             # Only current input
289 |             content = self.format_user_message_content(signature, inputs_copy, main_request=True)
290 |             messages.append({"role": "user", "content": content})
291 | 
292 |         messages = split_message_content_for_custom_types(messages)
293 |         return messages
294 | 
295 |     def format_field_description(self, signature: type[Signature]) -> str:
296 |         """Format the field description for the system message.
297 | 
298 |         This method formats the field description for the system message. It should return a string that contains
299 |         the field description for the input fields and the output fields.
300 | 
301 |         Args:
302 |             signature: The DSPy signature for which to format the field description.
303 | 
304 |         Returns:
305 |             A string that contains the field description for the input fields and the output fields.
306 |         """
307 |         raise NotImplementedError
308 | 
309 |     def format_field_structure(self, signature: type[Signature]) -> str:
310 |         """Format the field structure for the system message.
311 | 
312 |         This method formats the field structure for the system message. It should return a string that dictates the
313 |         format the input fields should be provided to the LM, and the format the output fields will be in the response.
314 |         Refer to the ChatAdapter and JSONAdapter for an example.
315 | 
316 |         Args:
317 |             signature: The DSPy signature for which to format the field structure.
318 |         """
319 |         raise NotImplementedError
320 | 
321 |     def format_task_description(self, signature: type[Signature]) -> str:
322 |         """Format the task description for the system message.
323 | 
324 |         This method formats the task description for the system message. In most cases this is just a thin wrapper
325 |         over `signature.instructions`.
326 | 
327 |         Args:
328 |             signature: The DSPy signature of the DSPy module.
329 | 
330 |         Returns:
331 |             A string that describes the task.
332 |         """
333 |         raise NotImplementedError
334 | 
335 |     def format_user_message_content(
336 |         self,
337 |         signature: type[Signature],
338 |         inputs: dict[str, Any],
339 |         prefix: str = "",
340 |         suffix: str = "",
341 |         main_request: bool = False,
342 |     ) -> str:
343 |         """Format the user message content.
344 | 
345 |         This method formats the user message content, which can be used in formatting few-shot examples, conversation
346 |         history, and the current input.
347 | 
348 |         Args:
349 |             signature: The DSPy signature for which to format the user message content.
350 |             inputs: The input arguments to the DSPy module.
351 |             prefix: A prefix to the user message content.
352 |             suffix: A suffix to the user message content.
353 | 
354 |         Returns:
355 |             A string that contains the user message content.
356 |         """
357 |         raise NotImplementedError
358 | 
359 |     def format_assistant_message_content(
360 |         self,
361 |         signature: type[Signature],
362 |         outputs: dict[str, Any],
363 |         missing_field_message: str | None = None,
364 |     ) -> str:
365 |         """Format the assistant message content.
366 | 
367 |         This method formats the assistant message content, which can be used in formatting few-shot examples
368 |         and conversation history.
369 | 
370 |         Args:
371 |             signature: The DSPy signature for which to format the assistant message content.
372 |             outputs: The output fields to be formatted.
373 |             missing_field_message: A message to be used when a field is missing.
374 | 
375 |         Returns:
376 |             A string that contains the assistant message content.
377 |         """
378 |         raise NotImplementedError
379 | 
380 |     def format_demos(self, signature: type[Signature], demos: list[dict[str, Any]]) -> list[dict[str, Any]]:
381 |         """Format the few-shot examples.
382 | 
383 |         This method formats the few-shot examples as multiturn messages.
384 | 
385 |         Args:
386 |             signature: The DSPy signature for which to format the few-shot examples.
387 |             demos: A list of few-shot examples, each element is a dictionary with keys of the input and output fields of
388 |                 the signature.
389 | 
390 |         Returns:
391 |             A list of multiturn messages.
392 |         """
393 |         complete_demos = []
394 |         incomplete_demos = []
395 | 
396 |         for demo in demos:
397 |             # Check if all fields are present and not None
398 |             is_complete = all(k in demo and demo[k] is not None for k in signature.fields)
399 | 
400 |             # Check if demo has at least one input and one output field
401 |             has_input = any(k in demo for k in signature.input_fields)
402 |             has_output = any(k in demo for k in signature.output_fields)
403 | 
404 |             if is_complete:
405 |                 complete_demos.append(demo)
406 |             elif has_input and has_output:
407 |                 # We only keep incomplete demos that have at least one input and one output field
408 |                 incomplete_demos.append(demo)
409 | 
410 |         messages = []
411 | 
412 |         incomplete_demo_prefix = "This is an example of the task, though some input or output fields are not supplied."
413 |         for demo in incomplete_demos:
414 |             messages.append(
415 |                 {
416 |                     "role": "user",
417 |                     "content": self.format_user_message_content(signature, demo, prefix=incomplete_demo_prefix),
418 |                 }
419 |             )
420 |             messages.append(
421 |                 {
422 |                     "role": "assistant",
423 |                     "content": self.format_assistant_message_content(
424 |                         signature, demo, missing_field_message="Not supplied for this particular example. "
425 |                     ),
426 |                 }
427 |             )
428 | 
429 |         for demo in complete_demos:
430 |             messages.append({"role": "user", "content": self.format_user_message_content(signature, demo)})
431 |             messages.append(
432 |                 {
433 |                     "role": "assistant",
434 |                     "content": self.format_assistant_message_content(
435 |                         signature, demo, missing_field_message="Not supplied for this conversation history message. "
436 |                     ),
437 |                 }
438 |             )
439 | 
440 |         return messages
441 | 
442 |     def _get_history_field_name(self, signature: type[Signature]) -> str | None:
443 |         for name, field in signature.input_fields.items():
444 |             if field.annotation == History:
445 |                 return name
446 |         return None
447 | 
448 |     def _get_tool_call_input_field_name(self, signature: type[Signature]) -> str | None:
449 |         for name, field in signature.input_fields.items():
450 |             # Look for annotation `list[dspy.Tool]` or `dspy.Tool`
451 |             origin = get_origin(field.annotation)
452 |             if origin is list and field.annotation.__args__[0] == Tool:
453 |                 return name
454 |             if field.annotation == Tool:
455 |                 return name
456 |         return None
457 | 
458 |     def _get_tool_call_output_field_name(self, signature: type[Signature]) -> str | None:
459 |         for name, field in signature.output_fields.items():
460 |             if field.annotation == ToolCalls:
461 |                 return name
462 |         return None
463 | 
464 |     def format_conversation_history(
465 |         self,
466 |         signature: type[Signature],
467 |         history_field_name: str,
468 |         inputs: dict[str, Any],
469 |     ) -> list[dict[str, Any]]:
470 |         """Format the conversation history.
471 | 
472 |         This method formats the conversation history as multiturn user/assistant messages.
473 | 
474 |         Args:
475 |             signature: The DSPy signature for which to format the conversation history.
476 |             history_field_name: The name of the history field in the signature.
477 |             inputs: The input arguments to the DSPy module.
478 | 
479 |         Returns:
480 |             A list of multiturn messages.
481 |         """
482 |         conversation_history = inputs[history_field_name].messages if history_field_name in inputs else None
483 | 
484 |         if conversation_history is None:
485 |             return []
486 | 
487 |         messages = []
488 |         for message in conversation_history:
489 |             messages.append(
490 |                 {
491 |                     "role": "user",
492 |                     "content": self.format_user_message_content(signature, message),
493 |                 }
494 |             )
495 |             messages.append(
496 |                 {
497 |                     "role": "assistant",
498 |                     "content": self.format_assistant_message_content(signature, message),
499 |                 }
500 |             )
501 | 
502 |         # Remove the history field from the inputs
503 |         del inputs[history_field_name]
504 | 
505 |         return messages
506 | 
507 |     def parse(self, signature: type[Signature], completion: str) -> dict[str, Any]:
508 |         """Parse the LM output into a dictionary of the output fields.
509 | 
510 |         This method parses the LM output into a dictionary of the output fields.
511 | 
512 |         Args:
513 |             signature: The DSPy signature for which to parse the LM output.
514 |             completion: The LM output to be parsed.
515 | 
516 |         Returns:
517 |             A dictionary of the output fields.
518 |         """
519 |         raise NotImplementedError
520 | 
```
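The abstract methods above (`format_field_description`, `format_field_structure`, `format_task_description`, `format_user_message_content`, `format_assistant_message_content`, and `parse`) define the surface a custom adapter must implement; `format` and `__call__` are inherited. The sketch below is illustrative only and not part of the file above: the class name `PlainTextAdapter`, its `name: value` wire format, and the import paths are assumptions based on this repository layout.

```python
# A minimal sketch of a custom adapter, assuming the abstract API shown in the
# file above. `PlainTextAdapter` and its `name: value` format are made up for
# illustration; they are not part of DSPy.
from typing import Any

from dspy.adapters.base import Adapter  # import path assumed from this repository layout
from dspy.signatures.signature import Signature


class PlainTextAdapter(Adapter):
    """Illustrative adapter that renders every field as a `name: value` line."""

    def format_field_description(self, signature: type[Signature]) -> str:
        return (
            f"Input fields: {', '.join(signature.input_fields)}. "
            f"Output fields: {', '.join(signature.output_fields)}."
        )

    def format_field_structure(self, signature: type[Signature]) -> str:
        lines = [f"{name}: <{name}>" for name in signature.output_fields]
        return "Respond with one line per output field:\n" + "\n".join(lines)

    def format_task_description(self, signature: type[Signature]) -> str:
        # Usually a thin wrapper over the signature's instructions.
        return signature.instructions

    def format_user_message_content(
        self,
        signature: type[Signature],
        inputs: dict[str, Any],
        prefix: str = "",
        suffix: str = "",
        main_request: bool = False,
    ) -> str:
        body = "\n".join(f"{name}: {inputs[name]}" for name in signature.input_fields if name in inputs)
        return f"{prefix}{body}{suffix}"

    def format_assistant_message_content(
        self,
        signature: type[Signature],
        outputs: dict[str, Any],
        missing_field_message: str | None = None,
    ) -> str:
        return "\n".join(
            f"{name}: {outputs.get(name, missing_field_message)}" for name in signature.output_fields
        )

    def parse(self, signature: type[Signature], completion: str) -> dict[str, Any]:
        # Recover `name: value` lines for the declared output fields.
        parsed: dict[str, Any] = {}
        for line in completion.splitlines():
            key, sep, value = line.partition(":")
            if sep and key.strip() in signature.output_fields:
                parsed[key.strip()] = value.strip()
        return parsed
```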
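For context, `__call__` and `acall` above tie the pieces together: `format` builds the system message, demos, conversation history, and current input into multiturn messages; the LM is invoked; and `_call_postprocess` routes each completion through `parse`. A hedged usage sketch, assuming the `PlainTextAdapter` above and an illustrative model name:

```python
import dspy

# Model name and provider setup are illustrative; substitute your own.
lm = dspy.LM("openai/gpt-4o-mini")

# A string-form signature: one input field ("question") and one output field ("answer").
signature = dspy.Signature("question -> answer", "Answer the question concisely.")

adapter = PlainTextAdapter()
results = adapter(
    lm,
    lm_kwargs={},
    signature=signature,
    demos=[{"question": "What is 2 + 2?", "answer": "4"}],
    inputs={"question": "What is the capital of France?"},
)
# Each element is a dict keyed by the signature's output field names.
print(results[0]["answer"])
```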