This is page 12 of 17. Use http://codebase.md/stanfordnlp/dspy?lines=true&page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── .internal_dspyai │ │ ├── internals │ │ │ ├── build-and-release.md │ │ │ └── release-checklist.md │ │ └── pyproject.toml │ ├── .tmp │ │ └── .generated-actions │ │ └── run-pypi-publish-in-docker-container │ │ └── action.yml │ ├── ISSUE_TEMPLATE │ │ ├── bug_report.yml │ │ └── feature_request.yml │ ├── PULL_REQUEST_TEMPLATE │ │ └── pull_request_template.md │ ├── workflow_scripts │ │ └── install_testpypi_pkg.sh │ └── workflows │ ├── build_and_release.yml │ ├── build_utils │ │ └── test_version.py │ ├── docs-push.yml │ ├── precommits_check.yml │ └── run_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── docs │ ├── .gitignore │ ├── docs │ │ ├── api │ │ │ ├── adapters │ │ │ │ ├── Adapter.md │ │ │ │ ├── ChatAdapter.md │ │ │ │ ├── JSONAdapter.md │ │ │ │ └── TwoStepAdapter.md │ │ │ ├── evaluation │ │ │ │ ├── answer_exact_match.md │ │ │ │ ├── answer_passage_match.md │ │ │ │ ├── CompleteAndGrounded.md │ │ │ │ ├── Evaluate.md │ │ │ │ ├── EvaluationResult.md │ │ │ │ └── SemanticF1.md │ │ │ ├── experimental │ │ │ │ ├── Citations.md │ │ │ │ └── Document.md │ │ │ ├── index.md │ │ │ ├── models │ │ │ │ ├── Embedder.md │ │ │ │ └── LM.md │ │ │ ├── modules │ │ │ │ ├── BestOfN.md │ │ │ │ ├── ChainOfThought.md │ │ │ │ ├── CodeAct.md │ │ │ │ ├── Module.md │ │ │ │ ├── MultiChainComparison.md │ │ │ │ ├── Parallel.md │ │ │ │ ├── Predict.md │ │ │ │ ├── ProgramOfThought.md │ │ │ │ ├── ReAct.md │ │ │ │ └── Refine.md │ │ │ ├── optimizers │ │ │ │ ├── BetterTogether.md │ │ │ │ ├── BootstrapFewShot.md │ │ │ │ ├── BootstrapFewShotWithRandomSearch.md │ │ │ │ ├── BootstrapFinetune.md │ │ │ │ ├── BootstrapRS.md │ │ │ │ ├── COPRO.md │ │ │ │ ├── Ensemble.md │ │ │ │ ├── GEPA │ │ │ │ │ ├── GEPA_Advanced.md │ │ │ │ │ └── overview.md │ │ │ │ ├── InferRules.md │ │ │ │ ├── KNN.md │ │ │ │ ├── KNNFewShot.md │ │ │ │ ├── LabeledFewShot.md │ │ │ │ ├── MIPROv2.md │ │ │ │ └── SIMBA.md │ │ │ ├── primitives │ │ │ │ ├── Audio.md │ │ │ │ ├── Code.md │ │ │ │ ├── Example.md │ │ │ │ ├── History.md │ │ │ │ ├── Image.md │ │ │ │ ├── Prediction.md │ │ │ │ ├── Tool.md │ │ │ │ └── ToolCalls.md │ │ │ ├── signatures │ │ │ │ ├── InputField.md │ │ │ │ ├── OutputField.md │ │ │ │ └── Signature.md │ │ │ ├── tools │ │ │ │ ├── ColBERTv2.md │ │ │ │ ├── Embeddings.md │ │ │ │ └── PythonInterpreter.md │ │ │ └── utils │ │ │ ├── asyncify.md │ │ │ ├── configure_cache.md │ │ │ ├── disable_litellm_logging.md │ │ │ ├── disable_logging.md │ │ │ ├── enable_litellm_logging.md │ │ │ ├── enable_logging.md │ │ │ ├── inspect_history.md │ │ │ ├── load.md │ │ │ ├── StatusMessage.md │ │ │ ├── StatusMessageProvider.md │ │ │ ├── streamify.md │ │ │ └── StreamListener.md │ │ ├── cheatsheet.md │ │ ├── community │ │ │ ├── community-resources.md │ │ │ ├── how-to-contribute.md │ │ │ └── use-cases.md │ │ ├── deep-dive │ │ │ └── data-handling │ │ │ ├── built-in-datasets.md │ │ │ ├── examples.md │ │ │ ├── img │ │ │ │ └── data-loading.png │ │ │ └── loading-custom-data.md │ │ ├── faqs.md │ │ ├── index.md │ │ ├── js │ │ │ └── runllm-widget.js │ │ ├── learn │ │ │ ├── evaluation │ │ │ │ ├── data.md │ │ │ │ ├── metrics.md │ │ │ │ └── overview.md │ │ │ ├── figures │ │ │ │ ├── native_tool_call.png │ │ │ │ └── teleprompter-classes.png │ │ │ ├── index.md │ │ │ ├── optimization │ │ │ │ ├── optimizers.md │ │ │ │ └── overview.md │ │ │ └── programming │ │ │ ├── 7-assertions.md │ │ │ ├── adapters.md │ │ │ ├── language_models.md │ │ │ 
├── mcp.md │ │ │ ├── modules.md │ │ │ ├── overview.md │ │ │ ├── signatures.md │ │ │ └── tools.md │ │ ├── production │ │ │ └── index.md │ │ ├── roadmap.md │ │ ├── static │ │ │ ├── .nojekyll │ │ │ └── img │ │ │ ├── dspy_logo.png │ │ │ ├── logo.png │ │ │ ├── mlflow-tracing-rag.png │ │ │ ├── modular.png │ │ │ ├── optimize.png │ │ │ ├── undraw_docusaurus_mountain.svg │ │ │ ├── undraw_docusaurus_react.svg │ │ │ ├── undraw_docusaurus_tree.svg │ │ │ └── universal_compatibility.png │ │ ├── stylesheets │ │ │ └── extra.css │ │ └── tutorials │ │ ├── agents │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── ai_text_game │ │ │ └── index.md │ │ ├── async │ │ │ └── index.md │ │ ├── audio │ │ │ └── index.ipynb │ │ ├── build_ai_program │ │ │ └── index.md │ │ ├── cache │ │ │ └── index.md │ │ ├── classification │ │ │ └── index.md │ │ ├── classification_finetuning │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-classification.png │ │ ├── conversation_history │ │ │ └── index.md │ │ ├── core_development │ │ │ └── index.md │ │ ├── custom_module │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-custom-module.png │ │ ├── customer_service_agent │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-customer-service-agent.png │ │ ├── deployment │ │ │ ├── dspy_mlflow_ui.png │ │ │ └── index.md │ │ ├── email_extraction │ │ │ ├── index.md │ │ │ └── mlflow-tracing-email-extraction.png │ │ ├── entity_extraction │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-entity-extraction.png │ │ ├── games │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── gepa_ai_program │ │ │ └── index.md │ │ ├── gepa_aime │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-aime.png │ │ │ └── mlflow-tracking-gepa-aime-optimization.png │ │ ├── gepa_facilitysupportanalyzer │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-support.png │ │ │ └── mlflow-tracking-gepa-support-optimization.png │ │ ├── gepa_papillon │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-papilon.png │ │ │ └── mlflow-tracking-gepa-papilon-optimization.png │ │ ├── image_generation_prompting │ │ │ └── index.ipynb │ │ ├── index.md │ │ ├── llms_txt_generation │ │ │ └── index.md │ │ ├── math │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-math.png │ │ ├── mcp │ │ │ └── index.md │ │ ├── mem0_react_agent │ │ │ └── index.md │ │ ├── multihop_search │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-multi-hop.png │ │ ├── observability │ │ │ ├── index.md │ │ │ ├── mlflow_trace_ui_navigation.gif │ │ │ ├── mlflow_trace_ui.png │ │ │ └── mlflow_trace_view.png │ │ ├── optimize_ai_program │ │ │ └── index.md │ │ ├── optimizer_tracking │ │ │ ├── child_run.png │ │ │ ├── experiment.png │ │ │ ├── index.md │ │ │ └── parent_run.png │ │ ├── output_refinement │ │ │ └── best-of-n-and-refine.md │ │ ├── papillon │ │ │ └── index.md │ │ ├── program_of_thought │ │ │ └── index.ipynb │ │ ├── rag │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-rag.png │ │ ├── real_world_examples │ │ │ └── index.md │ │ ├── rl_ai_program │ │ │ └── index.md │ │ ├── rl_multihop │ │ │ └── index.ipynb │ │ ├── rl_papillon │ │ │ └── index.ipynb │ │ ├── sample_code_generation │ │ │ └── index.md │ │ ├── saving │ │ │ └── index.md │ │ ├── streaming │ │ │ └── index.md │ │ ├── tool_use │ │ │ └── index.ipynb │ │ └── yahoo_finance_react │ │ └── index.md │ ├── mkdocs.yml │ ├── overrides │ │ ├── home.html │ │ ├── main.html │ │ └── partials │ │ └── tabs.html │ ├── Pipfile │ ├── Pipfile.lock │ ├── README.md │ ├── requirements.txt │ ├── scripts │ │ ├── generate_api_docs.py │ │ └── generate_api_summary.py │ └── vercel.json ├── dspy │ ├── 
__init__.py │ ├── __metadata__.py │ ├── adapters │ │ ├── __init__.py │ │ ├── baml_adapter.py │ │ ├── base.py │ │ ├── chat_adapter.py │ │ ├── json_adapter.py │ │ ├── two_step_adapter.py │ │ ├── types │ │ │ ├── __init__.py │ │ │ ├── audio.py │ │ │ ├── base_type.py │ │ │ ├── citation.py │ │ │ ├── code.py │ │ │ ├── document.py │ │ │ ├── history.py │ │ │ ├── image.py │ │ │ └── tool.py │ │ ├── utils.py │ │ └── xml_adapter.py │ ├── clients │ │ ├── __init__.py │ │ ├── base_lm.py │ │ ├── cache.py │ │ ├── databricks.py │ │ ├── embedding.py │ │ ├── lm_local_arbor.py │ │ ├── lm_local.py │ │ ├── lm.py │ │ ├── openai.py │ │ ├── provider.py │ │ └── utils_finetune.py │ ├── datasets │ │ ├── __init__.py │ │ ├── alfworld │ │ │ ├── __init__.py │ │ │ ├── alfworld.py │ │ │ └── base_config.yml │ │ ├── colors.py │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── gsm8k.py │ │ ├── hotpotqa.py │ │ └── math.py │ ├── dsp │ │ ├── __init__.py │ │ ├── colbertv2.py │ │ └── utils │ │ ├── __init__.py │ │ ├── dpr.py │ │ ├── settings.py │ │ └── utils.py │ ├── evaluate │ │ ├── __init__.py │ │ ├── auto_evaluation.py │ │ ├── evaluate.py │ │ └── metrics.py │ ├── experimental │ │ └── __init__.py │ ├── predict │ │ ├── __init__.py │ │ ├── aggregation.py │ │ ├── avatar │ │ │ ├── __init__.py │ │ │ ├── avatar.py │ │ │ ├── models.py │ │ │ └── signatures.py │ │ ├── best_of_n.py │ │ ├── chain_of_thought.py │ │ ├── code_act.py │ │ ├── knn.py │ │ ├── multi_chain_comparison.py │ │ ├── parallel.py │ │ ├── parameter.py │ │ ├── predict.py │ │ ├── program_of_thought.py │ │ ├── react.py │ │ ├── refine.py │ │ └── retry.py │ ├── primitives │ │ ├── __init__.py │ │ ├── base_module.py │ │ ├── example.py │ │ ├── module.py │ │ ├── prediction.py │ │ ├── python_interpreter.py │ │ └── runner.js │ ├── propose │ │ ├── __init__.py │ │ ├── dataset_summary_generator.py │ │ ├── grounded_proposer.py │ │ ├── propose_base.py │ │ └── utils.py │ ├── retrievers │ │ ├── __init__.py │ │ ├── databricks_rm.py │ │ ├── embeddings.py │ │ ├── retrieve.py │ │ └── weaviate_rm.py │ ├── signatures │ │ ├── __init__.py │ │ ├── field.py │ │ ├── signature.py │ │ └── utils.py │ ├── streaming │ │ ├── __init__.py │ │ ├── messages.py │ │ ├── streamify.py │ │ └── streaming_listener.py │ ├── teleprompt │ │ ├── __init__.py │ │ ├── avatar_optimizer.py │ │ ├── bettertogether.py │ │ ├── bootstrap_finetune.py │ │ ├── bootstrap_trace.py │ │ ├── bootstrap.py │ │ ├── copro_optimizer.py │ │ ├── ensemble.py │ │ ├── gepa │ │ │ ├── __init__.py │ │ │ ├── gepa_utils.py │ │ │ ├── gepa.py │ │ │ └── instruction_proposal.py │ │ ├── grpo.py │ │ ├── infer_rules.py │ │ ├── knn_fewshot.py │ │ ├── mipro_optimizer_v2.py │ │ ├── random_search.py │ │ ├── signature_opt.py │ │ ├── simba_utils.py │ │ ├── simba.py │ │ ├── teleprompt_optuna.py │ │ ├── teleprompt.py │ │ ├── utils.py │ │ └── vanilla.py │ └── utils │ ├── __init__.py │ ├── annotation.py │ ├── asyncify.py │ ├── caching.py │ ├── callback.py │ ├── dummies.py │ ├── exceptions.py │ ├── hasher.py │ ├── inspect_history.py │ ├── langchain_tool.py │ ├── logging_utils.py │ ├── mcp.py │ ├── parallelizer.py │ ├── saving.py │ ├── syncify.py │ ├── unbatchify.py │ └── usage_tracker.py ├── LICENSE ├── pyproject.toml ├── README.md ├── tests │ ├── __init__.py │ ├── adapters │ │ ├── test_adapter_utils.py │ │ ├── test_baml_adapter.py │ │ ├── test_base_type.py │ │ ├── test_chat_adapter.py │ │ ├── test_citation.py │ │ ├── test_code.py │ │ ├── test_document.py │ │ ├── test_json_adapter.py │ │ ├── test_tool.py │ │ ├── test_two_step_adapter.py │ │ └── test_xml_adapter.py │ ├── 
callback │ │ └── test_callback.py │ ├── clients │ │ ├── test_cache.py │ │ ├── test_databricks.py │ │ ├── test_embedding.py │ │ ├── test_inspect_global_history.py │ │ └── test_lm.py │ ├── conftest.py │ ├── datasets │ │ └── test_dataset.py │ ├── docs │ │ └── test_mkdocs_links.py │ ├── evaluate │ │ ├── test_evaluate.py │ │ └── test_metrics.py │ ├── examples │ │ └── test_baleen.py │ ├── metadata │ │ └── test_metadata.py │ ├── predict │ │ ├── test_aggregation.py │ │ ├── test_best_of_n.py │ │ ├── test_chain_of_thought.py │ │ ├── test_code_act.py │ │ ├── test_knn.py │ │ ├── test_multi_chain_comparison.py │ │ ├── test_parallel.py │ │ ├── test_predict.py │ │ ├── test_program_of_thought.py │ │ ├── test_react.py │ │ ├── test_refine.py │ │ └── test_retry.py │ ├── primitives │ │ ├── resources │ │ │ └── saved_program.json │ │ ├── test_base_module.py │ │ ├── test_example.py │ │ ├── test_module.py │ │ └── test_python_interpreter.py │ ├── propose │ │ └── test_grounded_proposer.py │ ├── README.md │ ├── reliability │ │ ├── __init__.py │ │ ├── complex_types │ │ │ └── generated │ │ │ ├── test_many_types_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ ├── test_nesting_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ └── test_nesting_2 │ │ │ ├── inputs │ │ │ │ └── input1.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── conftest.py │ │ ├── generate │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ └── utils.py │ │ ├── input_formats │ │ │ └── generated │ │ │ └── test_markdown_1 │ │ │ ├── inputs │ │ │ │ ├── input1.json │ │ │ │ └── input2.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── README.md │ │ ├── reliability_conf.yaml │ │ ├── test_generated.py │ │ ├── test_pydantic_models.py │ │ └── utils.py │ ├── retrievers │ │ └── test_embeddings.py │ ├── signatures │ │ ├── test_adapter_image.py │ │ ├── test_custom_types.py │ │ └── test_signature.py │ ├── streaming │ │ └── test_streaming.py │ ├── teleprompt │ │ ├── gepa_dummy_lm_custom_component_selector_custom_instruction_proposer.json │ │ ├── gepa_dummy_lm.json │ │ ├── test_bootstrap_finetune.py │ │ ├── test_bootstrap_trace.py │ │ ├── test_bootstrap.py │ │ ├── test_copro_optimizer.py │ │ ├── test_ensemble.py │ │ ├── test_finetune.py │ │ ├── test_gepa_instruction_proposer.py │ │ ├── test_gepa.py │ │ ├── test_grpo.py │ │ ├── test_knn_fewshot.py │ │ ├── test_random_search.py │ │ ├── test_teleprompt.py │ │ └── test_utils.py │ ├── test_utils │ │ ├── __init__.py │ │ └── server │ │ ├── __init__.py │ │ ├── litellm_server_config.yaml │ │ └── litellm_server.py │ └── utils │ ├── __init__.py │ ├── resources │ │ └── mcp_server.py │ ├── test_annotation.py │ ├── test_asyncify.py │ ├── test_exceptions.py │ ├── test_langchain_tool.py │ ├── test_mcp.py │ ├── test_parallelizer.py │ ├── test_saving.py │ ├── test_settings.py │ ├── test_syncify.py │ ├── test_unbatchify.py │ └── test_usage_tracker.py └── uv.lock ``` # Files -------------------------------------------------------------------------------- /dspy/clients/lm.py: -------------------------------------------------------------------------------- ```python 1 | import logging 2 | import os 3 | import re 4 | import threading 5 | import warnings 6 | from typing import Any, Literal, cast 7 | 8 | import litellm 9 | from anyio.streams.memory import MemoryObjectSendStream 10 | from asyncer import syncify 11 | 12 | import dspy 13 | from dspy.clients.cache import request_cache 14 
| from dspy.clients.openai import OpenAIProvider
15 | from dspy.clients.provider import Provider, ReinforceJob, TrainingJob
16 | from dspy.clients.utils_finetune import MultiGPUConfig, TrainDataFormat
17 | from dspy.dsp.utils.settings import settings
18 | from dspy.utils.callback import BaseCallback
19 |
20 | from .base_lm import BaseLM
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | class LM(BaseLM):
26 |     """
27 |     A language model supporting chat, text completion, or responses-API requests for use with DSPy modules.
28 |     """
29 |
30 |     def __init__(
31 |         self,
32 |         model: str,
33 |         model_type: Literal["chat", "text", "responses"] = "chat",
34 |         temperature: float | None = None,
35 |         max_tokens: int | None = None,
36 |         cache: bool = True,
37 |         callbacks: list[BaseCallback] | None = None,
38 |         num_retries: int = 3,
39 |         provider: Provider | None = None,
40 |         finetuning_model: str | None = None,
41 |         launch_kwargs: dict[str, Any] | None = None,
42 |         train_kwargs: dict[str, Any] | None = None,
43 |         use_developer_role: bool = False,
44 |         **kwargs,
45 |     ):
46 |         """
47 |         Create a new language model instance for use with DSPy modules and programs.
48 |
49 |         Args:
50 |             model: The model to use. This should be a string of the form ``"llm_provider/llm_name"``
51 |                 supported by LiteLLM. For example, ``"openai/gpt-4o"``.
52 |             model_type: The type of the model, one of ``"chat"``, ``"text"``, or ``"responses"``.
53 |             temperature: The sampling temperature to use when generating responses.
54 |             max_tokens: The maximum number of tokens to generate per response.
55 |             cache: Whether to cache the model responses for reuse to improve performance
56 |                 and reduce costs.
57 |             callbacks: A list of callback functions to run before and after each request.
58 |             num_retries: The number of times to retry a request if it fails transiently due to
59 |                 network error, rate limiting, etc. Requests are retried with exponential
60 |                 backoff.
61 |             provider: The provider to use. If not specified, the provider will be inferred from the model.
62 |             finetuning_model: The model to finetune. In some providers, the models available for finetuning are different
63 |                 from the models available for inference.
64 |             rollout_id: Optional integer used to differentiate cache entries for otherwise
65 |                 identical requests. Different values bypass DSPy's caches while still caching
66 |                 future calls with the same inputs and rollout ID. Note that `rollout_id`
67 |                 only affects generation when `temperature` is non-zero. This argument is
68 |                 stripped before sending requests to the provider.
69 |         """
70 |         # Remember to update LM.copy() if you modify the constructor!
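# ---------------------------------------------------------------------------
# Editorial note (not part of the original source): a minimal usage sketch of
# the constructor documented above, assuming an OpenAI API key is configured
# in the environment. The model name and prompt text are illustrative only.
#
#     lm = dspy.LM("openai/gpt-4o-mini", temperature=0.7, max_tokens=4000)
#     lm("Say hello in one word.")                # returns a list of completions
#     lm("Say hello in one word.", rollout_id=1)  # bypasses cached results when temperature > 0
# ---------------------------------------------------------------------------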
71 | self.model = model 72 | self.model_type = model_type 73 | self.cache = cache 74 | self.provider = provider or self.infer_provider() 75 | self.callbacks = callbacks or [] 76 | self.history = [] 77 | self.num_retries = num_retries 78 | self.finetuning_model = finetuning_model 79 | self.launch_kwargs = launch_kwargs or {} 80 | self.train_kwargs = train_kwargs or {} 81 | self.use_developer_role = use_developer_role 82 | self._warned_zero_temp_rollout = False 83 | 84 | # Handle model-specific configuration for different model families 85 | model_family = model.split("/")[-1].lower() if "/" in model else model.lower() 86 | 87 | # Recognize OpenAI reasoning models (o1, o3, o4, gpt-5 family) 88 | model_pattern = re.match(r"^(?:o[1345]|gpt-5)(?:-(?:mini|nano))?", model_family) 89 | 90 | if model_pattern: 91 | 92 | if (temperature and temperature != 1.0) or (max_tokens and max_tokens < 16000): 93 | raise ValueError( 94 | "OpenAI's reasoning models require passing temperature=1.0 or None and max_tokens >= 16000 or None to " 95 | "`dspy.LM(...)`, e.g., dspy.LM('openai/gpt-5', temperature=1.0, max_tokens=16000)" 96 | ) 97 | self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs) 98 | if self.kwargs.get("rollout_id") is None: 99 | self.kwargs.pop("rollout_id", None) 100 | else: 101 | self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) 102 | if self.kwargs.get("rollout_id") is None: 103 | self.kwargs.pop("rollout_id", None) 104 | 105 | self._warn_zero_temp_rollout(self.kwargs.get("temperature"), self.kwargs.get("rollout_id")) 106 | 107 | def _warn_zero_temp_rollout(self, temperature: float | None, rollout_id): 108 | if not self._warned_zero_temp_rollout and rollout_id is not None and (temperature is None or temperature == 0): 109 | warnings.warn( 110 | "rollout_id has no effect when temperature=0; set temperature>0 to bypass the cache.", 111 | stacklevel=3, 112 | ) 113 | self._warned_zero_temp_rollout = True 114 | 115 | def _get_cached_completion_fn(self, completion_fn, cache): 116 | ignored_args_for_cache_key = ["api_key", "api_base", "base_url"] 117 | if cache: 118 | completion_fn = request_cache( 119 | cache_arg_name="request", 120 | ignored_args_for_cache_key=ignored_args_for_cache_key, 121 | )(completion_fn) 122 | 123 | litellm_cache_args = {"no-cache": True, "no-store": True} 124 | 125 | return completion_fn, litellm_cache_args 126 | 127 | def forward(self, prompt=None, messages=None, **kwargs): 128 | # Build the request. 
129 | kwargs = dict(kwargs) 130 | cache = kwargs.pop("cache", self.cache) 131 | 132 | messages = messages or [{"role": "user", "content": prompt}] 133 | if self.use_developer_role and self.model_type == "responses": 134 | messages = [{**m, "role": "developer"} if m.get("role") == "system" else m for m in messages] 135 | kwargs = {**self.kwargs, **kwargs} 136 | self._warn_zero_temp_rollout(kwargs.get("temperature"), kwargs.get("rollout_id")) 137 | if kwargs.get("rollout_id") is None: 138 | kwargs.pop("rollout_id", None) 139 | 140 | if self.model_type == "chat": 141 | completion = litellm_completion 142 | elif self.model_type == "text": 143 | completion = litellm_text_completion 144 | elif self.model_type == "responses": 145 | completion = litellm_responses_completion 146 | completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache) 147 | 148 | results = completion( 149 | request=dict(model=self.model, messages=messages, **kwargs), 150 | num_retries=self.num_retries, 151 | cache=litellm_cache_args, 152 | ) 153 | 154 | self._check_truncation(results) 155 | 156 | if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"): 157 | settings.usage_tracker.add_usage(self.model, dict(results.usage)) 158 | return results 159 | 160 | async def aforward(self, prompt=None, messages=None, **kwargs): 161 | # Build the request. 162 | kwargs = dict(kwargs) 163 | cache = kwargs.pop("cache", self.cache) 164 | 165 | messages = messages or [{"role": "user", "content": prompt}] 166 | if self.use_developer_role and self.model_type == "responses": 167 | messages = [{**m, "role": "developer"} if m.get("role") == "system" else m for m in messages] 168 | kwargs = {**self.kwargs, **kwargs} 169 | self._warn_zero_temp_rollout(kwargs.get("temperature"), kwargs.get("rollout_id")) 170 | if kwargs.get("rollout_id") is None: 171 | kwargs.pop("rollout_id", None) 172 | 173 | if self.model_type == "chat": 174 | completion = alitellm_completion 175 | elif self.model_type == "text": 176 | completion = alitellm_text_completion 177 | elif self.model_type == "responses": 178 | completion = alitellm_responses_completion 179 | completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache) 180 | 181 | results = await completion( 182 | request=dict(model=self.model, messages=messages, **kwargs), 183 | num_retries=self.num_retries, 184 | cache=litellm_cache_args, 185 | ) 186 | 187 | self._check_truncation(results) 188 | 189 | if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"): 190 | settings.usage_tracker.add_usage(self.model, dict(results.usage)) 191 | return results 192 | 193 | def launch(self, launch_kwargs: dict[str, Any] | None = None): 194 | self.provider.launch(self, launch_kwargs) 195 | 196 | def kill(self, launch_kwargs: dict[str, Any] | None = None): 197 | self.provider.kill(self, launch_kwargs) 198 | 199 | def finetune( 200 | self, 201 | train_data: list[dict[str, Any]], 202 | train_data_format: TrainDataFormat | None, 203 | train_kwargs: dict[str, Any] | None = None, 204 | ) -> TrainingJob: 205 | from dspy import settings as settings 206 | 207 | if not self.provider.finetunable: 208 | raise ValueError( 209 | f"Provider {self.provider} does not support fine-tuning, please specify your provider by explicitly " 210 | "setting `provider` when creating the `dspy.LM` instance. For example, " 211 | "`dspy.LM('openai/gpt-4.1-mini-2025-04-14', provider=dspy.OpenAIProvider())`." 
212 | ) 213 | 214 | def thread_function_wrapper(): 215 | return self._run_finetune_job(job) 216 | 217 | thread = threading.Thread(target=thread_function_wrapper) 218 | train_kwargs = train_kwargs or self.train_kwargs 219 | model_to_finetune = self.finetuning_model or self.model 220 | job = self.provider.TrainingJob( 221 | thread=thread, 222 | model=model_to_finetune, 223 | train_data=train_data, 224 | train_data_format=train_data_format, 225 | train_kwargs=train_kwargs, 226 | ) 227 | thread.start() 228 | 229 | return job 230 | 231 | def reinforce( 232 | self, train_kwargs, gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1) 233 | ) -> ReinforceJob: 234 | # TODO(GRPO Team): Should we return an initialized job here? 235 | from dspy import settings as settings 236 | 237 | err = f"Provider {self.provider} does not implement the reinforcement learning interface." 238 | assert self.provider.reinforceable, err 239 | 240 | job = self.provider.ReinforceJob(lm=self, train_kwargs=train_kwargs, gpu_config=gpu_config) 241 | job.initialize() 242 | return job 243 | 244 | def _run_finetune_job(self, job: TrainingJob): 245 | # TODO(enhance): We should listen for keyboard interrupts somewhere. 246 | # Requires TrainingJob.cancel() to be implemented for each provider. 247 | try: 248 | model = self.provider.finetune( 249 | job=job, 250 | model=job.model, 251 | train_data=job.train_data, 252 | train_data_format=job.train_data_format, 253 | train_kwargs=job.train_kwargs, 254 | ) 255 | lm = self.copy(model=model) 256 | job.set_result(lm) 257 | except Exception as err: 258 | logger.error(err) 259 | job.set_result(err) 260 | 261 | def infer_provider(self) -> Provider: 262 | if OpenAIProvider.is_provider_model(self.model): 263 | return OpenAIProvider() 264 | return Provider() 265 | 266 | def dump_state(self): 267 | state_keys = [ 268 | "model", 269 | "model_type", 270 | "cache", 271 | "num_retries", 272 | "finetuning_model", 273 | "launch_kwargs", 274 | "train_kwargs", 275 | ] 276 | return {key: getattr(self, key) for key in state_keys} | self.kwargs 277 | 278 | def _check_truncation(self, results): 279 | if self.model_type != "responses" and any(c.finish_reason == "length" for c in results["choices"]): 280 | logger.warning( 281 | f"LM response was truncated due to exceeding max_tokens={self.kwargs['max_tokens']}. " 282 | "You can inspect the latest LM interactions with `dspy.inspect_history()`. " 283 | "To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. " 284 | f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) " 285 | " if the reason for truncation is repetition." 286 | ) 287 | 288 | 289 | def _get_stream_completion_fn( 290 | request: dict[str, Any], 291 | cache_kwargs: dict[str, Any], 292 | sync=True, 293 | ): 294 | stream = dspy.settings.send_stream 295 | caller_predict = dspy.settings.caller_predict 296 | 297 | if stream is None: 298 | return None 299 | 300 | # The stream is already opened, and will be closed by the caller. 
301 | stream = cast(MemoryObjectSendStream, stream) 302 | caller_predict_id = id(caller_predict) if caller_predict else None 303 | 304 | if dspy.settings.track_usage: 305 | request["stream_options"] = {"include_usage": True} 306 | 307 | async def stream_completion(request: dict[str, Any], cache_kwargs: dict[str, Any]): 308 | headers = request.pop("headers", None) 309 | response = await litellm.acompletion( 310 | cache=cache_kwargs, 311 | stream=True, 312 | headers=_get_headers(headers), 313 | **request, 314 | ) 315 | chunks = [] 316 | async for chunk in response: 317 | if caller_predict_id: 318 | # Add the predict id to the chunk so that the stream listener can identify which predict produces it. 319 | chunk.predict_id = caller_predict_id 320 | chunks.append(chunk) 321 | await stream.send(chunk) 322 | return litellm.stream_chunk_builder(chunks) 323 | 324 | def sync_stream_completion(): 325 | syncified_stream_completion = syncify(stream_completion) 326 | return syncified_stream_completion(request, cache_kwargs) 327 | 328 | async def async_stream_completion(): 329 | return await stream_completion(request, cache_kwargs) 330 | 331 | if sync: 332 | return sync_stream_completion 333 | else: 334 | return async_stream_completion 335 | 336 | 337 | def litellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): 338 | cache = cache or {"no-cache": True, "no-store": True} 339 | request = dict(request) 340 | request.pop("rollout_id", None) 341 | headers = request.pop("headers", None) 342 | stream_completion = _get_stream_completion_fn(request, cache, sync=True) 343 | if stream_completion is None: 344 | return litellm.completion( 345 | cache=cache, 346 | num_retries=num_retries, 347 | retry_strategy="exponential_backoff_retry", 348 | headers=_get_headers(headers), 349 | **request, 350 | ) 351 | 352 | return stream_completion() 353 | 354 | 355 | def litellm_text_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): 356 | cache = cache or {"no-cache": True, "no-store": True} 357 | request = dict(request) 358 | request.pop("rollout_id", None) 359 | headers = request.pop("headers", None) 360 | # Extract the provider and model from the model string. 361 | # TODO: Not all the models are in the format of "provider/model" 362 | model = request.pop("model").split("/", 1) 363 | provider, model = model[0] if len(model) > 1 else "openai", model[-1] 364 | 365 | # Use the API key and base from the request, or from the environment. 366 | api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") 367 | api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE") 368 | 369 | # Build the prompt from the messages. 
370 | prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"]) 371 | 372 | return litellm.text_completion( 373 | cache=cache, 374 | model=f"text-completion-openai/{model}", 375 | api_key=api_key, 376 | api_base=api_base, 377 | prompt=prompt, 378 | num_retries=num_retries, 379 | retry_strategy="exponential_backoff_retry", 380 | headers=_get_headers(headers), 381 | **request, 382 | ) 383 | 384 | 385 | async def alitellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): 386 | cache = cache or {"no-cache": True, "no-store": True} 387 | request = dict(request) 388 | request.pop("rollout_id", None) 389 | headers = request.pop("headers", None) 390 | stream_completion = _get_stream_completion_fn(request, cache, sync=False) 391 | if stream_completion is None: 392 | return await litellm.acompletion( 393 | cache=cache, 394 | num_retries=num_retries, 395 | retry_strategy="exponential_backoff_retry", 396 | headers=_get_headers(headers), 397 | **request, 398 | ) 399 | 400 | return await stream_completion() 401 | 402 | 403 | async def alitellm_text_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): 404 | cache = cache or {"no-cache": True, "no-store": True} 405 | request = dict(request) 406 | request.pop("rollout_id", None) 407 | model = request.pop("model").split("/", 1) 408 | headers = request.pop("headers", None) 409 | provider, model = model[0] if len(model) > 1 else "openai", model[-1] 410 | 411 | # Use the API key and base from the request, or from the environment. 412 | api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") 413 | api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE") 414 | 415 | # Build the prompt from the messages. 
416 | prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"]) 417 | 418 | return await litellm.atext_completion( 419 | cache=cache, 420 | model=f"text-completion-openai/{model}", 421 | api_key=api_key, 422 | api_base=api_base, 423 | prompt=prompt, 424 | num_retries=num_retries, 425 | retry_strategy="exponential_backoff_retry", 426 | headers=_get_headers(headers), 427 | **request, 428 | ) 429 | 430 | 431 | def litellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): 432 | cache = cache or {"no-cache": True, "no-store": True} 433 | request = dict(request) 434 | request.pop("rollout_id", None) 435 | headers = request.pop("headers", None) 436 | request = _convert_chat_request_to_responses_request(request) 437 | 438 | return litellm.responses( 439 | cache=cache, 440 | num_retries=num_retries, 441 | retry_strategy="exponential_backoff_retry", 442 | headers=_get_headers(headers), 443 | **request, 444 | ) 445 | 446 | 447 | async def alitellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): 448 | cache = cache or {"no-cache": True, "no-store": True} 449 | request = dict(request) 450 | request.pop("rollout_id", None) 451 | headers = request.pop("headers", None) 452 | request = _convert_chat_request_to_responses_request(request) 453 | 454 | return await litellm.aresponses( 455 | cache=cache, 456 | num_retries=num_retries, 457 | retry_strategy="exponential_backoff_retry", 458 | headers=_get_headers(headers), 459 | **request, 460 | ) 461 | 462 | 463 | def _convert_chat_request_to_responses_request(request: dict[str, Any]): 464 | request = dict(request) 465 | if "messages" in request: 466 | content_blocks = [] 467 | for msg in request.pop("messages"): 468 | c = msg.get("content") 469 | if isinstance(c, str): 470 | content_blocks.append({"type": "input_text", "text": c}) 471 | elif isinstance(c, list): 472 | content_blocks.extend(c) 473 | request["input"] = [{"role": msg.get("role", "user"), "content": content_blocks}] 474 | 475 | # Convert `response_format` to `text.format` for Responses API 476 | if "response_format" in request: 477 | response_format = request.pop("response_format") 478 | text = request.pop("text", {}) 479 | request["text"] = {**text, "format": response_format} 480 | 481 | return request 482 | 483 | def _get_headers(headers: dict[str, Any] | None = None): 484 | headers = headers or {} 485 | return { 486 | "User-Agent": f"DSPy/{dspy.__version__}", 487 | **headers, 488 | } 489 | ``` -------------------------------------------------------------------------------- /tests/clients/test_lm.py: -------------------------------------------------------------------------------- ```python 1 | import json 2 | import time 3 | import warnings 4 | from unittest import mock 5 | from unittest.mock import patch 6 | 7 | import litellm 8 | import pydantic 9 | import pytest 10 | from litellm.types.llms.openai import ResponseAPIUsage, ResponsesAPIResponse 11 | from litellm.utils import Choices, Message, ModelResponse 12 | from openai import RateLimitError 13 | from openai.types.responses import ResponseOutputMessage, ResponseReasoningItem 14 | from openai.types.responses.response_reasoning_item import Summary 15 | 16 | import dspy 17 | from dspy.utils.dummies import DummyLM 18 | from dspy.utils.usage_tracker import track_usage 19 | 20 | 21 | def make_response(output_blocks): 22 | return ResponsesAPIResponse( 23 | id="resp_1", 24 | created_at=0.0, 25 | error=None, 26 | 
incomplete_details=None, 27 | instructions=None, 28 | model="openai/dspy-test-model", 29 | object="response", 30 | output=output_blocks, 31 | metadata = {}, 32 | parallel_tool_calls=False, 33 | temperature=1.0, 34 | tool_choice="auto", 35 | tools=[], 36 | top_p=1.0, 37 | max_output_tokens=None, 38 | previous_response_id=None, 39 | reasoning=None, 40 | status="completed", 41 | text=None, 42 | truncation="disabled", 43 | usage=ResponseAPIUsage(input_tokens=1, output_tokens=1, total_tokens=2), 44 | user=None, 45 | ) 46 | 47 | 48 | def test_chat_lms_can_be_queried(litellm_test_server): 49 | api_base, _ = litellm_test_server 50 | expected_response = ["Hi!"] 51 | 52 | openai_lm = dspy.LM( 53 | model="openai/dspy-test-model", 54 | api_base=api_base, 55 | api_key="fakekey", 56 | model_type="chat", 57 | ) 58 | assert openai_lm("openai query") == expected_response 59 | 60 | azure_openai_lm = dspy.LM( 61 | model="azure/dspy-test-model", 62 | api_base=api_base, 63 | api_key="fakekey", 64 | model_type="chat", 65 | ) 66 | assert azure_openai_lm("azure openai query") == expected_response 67 | 68 | 69 | def test_dspy_cache(litellm_test_server, tmp_path): 70 | api_base, _ = litellm_test_server 71 | 72 | original_cache = dspy.cache 73 | dspy.clients.configure_cache( 74 | enable_disk_cache=True, 75 | enable_memory_cache=True, 76 | disk_cache_dir=tmp_path / ".disk_cache", 77 | ) 78 | cache = dspy.cache 79 | 80 | lm = dspy.LM( 81 | model="openai/dspy-test-model", 82 | api_base=api_base, 83 | api_key="fakekey", 84 | model_type="text", 85 | ) 86 | with track_usage() as usage_tracker: 87 | lm("Query") 88 | 89 | assert len(cache.memory_cache) == 1 90 | cache_key = next(iter(cache.memory_cache.keys())) 91 | assert cache_key in cache.disk_cache 92 | assert len(usage_tracker.usage_data) == 1 93 | 94 | with track_usage() as usage_tracker: 95 | lm("Query") 96 | 97 | assert len(usage_tracker.usage_data) == 0 98 | 99 | dspy.cache = original_cache 100 | 101 | 102 | def test_disabled_cache_skips_cache_key(monkeypatch): 103 | original_cache = dspy.cache 104 | dspy.configure_cache(enable_disk_cache=False, enable_memory_cache=False) 105 | cache = dspy.cache 106 | 107 | try: 108 | with mock.patch.object(cache, "cache_key", wraps=cache.cache_key) as cache_key_spy, \ 109 | mock.patch.object(cache, "get", wraps=cache.get) as cache_get_spy, \ 110 | mock.patch.object(cache, "put", wraps=cache.put) as cache_put_spy: 111 | 112 | def fake_completion(*, cache, num_retries, retry_strategy, **request): 113 | return ModelResponse( 114 | choices=[Choices(message=Message(role="assistant", content="Hi!"))], 115 | usage={"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, 116 | model="dummy", 117 | ) 118 | 119 | monkeypatch.setattr(litellm, "completion", fake_completion) 120 | 121 | dummy_lm = DummyLM([{"answer": "ignored"}]) 122 | # TODO(isaacbmiller): Change from dummy_lm.forward to just dummy_lm.__call__ #8864 123 | dummy_lm.forward(messages=[{"role": "user", "content": "Hello"}]) 124 | 125 | cache_key_spy.assert_not_called() 126 | cache_get_spy.assert_called_once() 127 | cache_put_spy.assert_called_once() 128 | finally: 129 | dspy.cache = original_cache 130 | 131 | 132 | def test_rollout_id_bypasses_cache(monkeypatch, tmp_path): 133 | calls: list[dict] = [] 134 | 135 | def fake_completion(*, cache, num_retries, retry_strategy, **request): 136 | calls.append(request) 137 | return ModelResponse( 138 | choices=[Choices(message=Message(role="assistant", content="Hi!"))], 139 | usage={"prompt_tokens": 1, "completion_tokens": 1, 
"total_tokens": 2}, 140 | model="openai/dspy-test-model", 141 | ) 142 | 143 | monkeypatch.setattr(litellm, "completion", fake_completion) 144 | 145 | original_cache = dspy.cache 146 | dspy.clients.configure_cache( 147 | enable_disk_cache=True, 148 | enable_memory_cache=True, 149 | disk_cache_dir=tmp_path / ".disk_cache", 150 | ) 151 | 152 | lm = dspy.LM(model="openai/dspy-test-model", model_type="chat") 153 | 154 | with track_usage() as usage_tracker: 155 | lm(messages=[{"role": "user", "content": "Query"}], rollout_id=1) 156 | assert len(usage_tracker.usage_data) == 1 157 | 158 | with track_usage() as usage_tracker: 159 | lm(messages=[{"role": "user", "content": "Query"}], rollout_id=1) 160 | assert len(usage_tracker.usage_data) == 0 161 | 162 | with track_usage() as usage_tracker: 163 | lm(messages=[{"role": "user", "content": "Query"}], rollout_id=2) 164 | assert len(usage_tracker.usage_data) == 1 165 | 166 | with track_usage() as usage_tracker: 167 | lm(messages=[{"role": "user", "content": "NoRID"}]) 168 | assert len(usage_tracker.usage_data) == 1 169 | 170 | with track_usage() as usage_tracker: 171 | lm(messages=[{"role": "user", "content": "NoRID"}], rollout_id=None) 172 | assert len(usage_tracker.usage_data) == 0 173 | 174 | assert len(dspy.cache.memory_cache) == 3 175 | assert all("rollout_id" not in r for r in calls) 176 | dspy.cache = original_cache 177 | 178 | 179 | def test_zero_temperature_rollout_warns_once(monkeypatch): 180 | def fake_completion(*, cache, num_retries, retry_strategy, **request): 181 | return ModelResponse( 182 | choices=[Choices(message=Message(role="assistant", content="Hi!"))], 183 | usage={"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, 184 | model="openai/dspy-test-model", 185 | ) 186 | 187 | monkeypatch.setattr(litellm, "completion", fake_completion) 188 | 189 | lm = dspy.LM(model="openai/dspy-test-model", model_type="chat") 190 | with pytest.warns(UserWarning, match="rollout_id has no effect"): 191 | lm("Query", rollout_id=1) 192 | with warnings.catch_warnings(record=True) as record: 193 | warnings.simplefilter("always") 194 | lm("Query", rollout_id=2) 195 | assert len(record) == 0 196 | 197 | 198 | def test_text_lms_can_be_queried(litellm_test_server): 199 | api_base, _ = litellm_test_server 200 | expected_response = ["Hi!"] 201 | 202 | openai_lm = dspy.LM( 203 | model="openai/dspy-test-model", 204 | api_base=api_base, 205 | api_key="fakekey", 206 | model_type="text", 207 | ) 208 | assert openai_lm("openai query") == expected_response 209 | 210 | azure_openai_lm = dspy.LM( 211 | model="azure/dspy-test-model", 212 | api_base=api_base, 213 | api_key="fakekey", 214 | model_type="text", 215 | ) 216 | assert azure_openai_lm("azure openai query") == expected_response 217 | 218 | 219 | def test_lm_calls_support_callables(litellm_test_server): 220 | api_base, _ = litellm_test_server 221 | 222 | with mock.patch("litellm.completion", autospec=True, wraps=litellm.completion) as spy_completion: 223 | 224 | def azure_ad_token_provider(*args, **kwargs): 225 | return None 226 | 227 | lm_with_callable = dspy.LM( 228 | model="openai/dspy-test-model", 229 | api_base=api_base, 230 | api_key="fakekey", 231 | azure_ad_token_provider=azure_ad_token_provider, 232 | cache=False, 233 | ) 234 | 235 | lm_with_callable("Query") 236 | 237 | spy_completion.assert_called_once() 238 | call_args = spy_completion.call_args.kwargs 239 | assert call_args["model"] == "openai/dspy-test-model" 240 | assert call_args["api_base"] == api_base 241 | assert call_args["api_key"] 
== "fakekey" 242 | assert call_args["azure_ad_token_provider"] is azure_ad_token_provider 243 | 244 | 245 | def test_lm_calls_support_pydantic_models(litellm_test_server): 246 | api_base, _ = litellm_test_server 247 | 248 | class ResponseFormat(pydantic.BaseModel): 249 | response: str 250 | 251 | lm = dspy.LM( 252 | model="openai/dspy-test-model", 253 | api_base=api_base, 254 | api_key="fakekey", 255 | response_format=ResponseFormat, 256 | ) 257 | lm("Query") 258 | 259 | 260 | def test_retry_number_set_correctly(): 261 | lm = dspy.LM("openai/gpt-4o-mini", num_retries=3) 262 | with mock.patch("litellm.completion") as mock_completion: 263 | lm("query") 264 | 265 | assert mock_completion.call_args.kwargs["num_retries"] == 3 266 | 267 | 268 | def test_retry_made_on_system_errors(): 269 | retry_tracking = [0] # Using a list to track retries 270 | 271 | def mock_create(*args, **kwargs): 272 | retry_tracking[0] += 1 273 | # These fields are called during the error handling 274 | mock_response = mock.Mock() 275 | mock_response.headers = {} 276 | mock_response.status_code = 429 277 | raise RateLimitError(response=mock_response, message="message", body="error") 278 | 279 | lm = dspy.LM(model="openai/gpt-4o-mini", max_tokens=250, num_retries=3) 280 | with mock.patch.object(litellm.OpenAIChatCompletion, "completion", side_effect=mock_create): 281 | with pytest.raises(RateLimitError): 282 | lm("question") 283 | 284 | assert retry_tracking[0] == 4 285 | 286 | 287 | def test_reasoning_model_token_parameter(): 288 | test_cases = [ 289 | ("openai/o1", True), 290 | ("openai/o1-mini", True), 291 | ("openai/o1-2023-01-01", True), 292 | ("openai/o3", True), 293 | ("openai/o3-mini-2023-01-01", True), 294 | ("openai/gpt-5", True), 295 | ("openai/gpt-5-mini", True), 296 | ("openai/gpt-5-nano", True), 297 | ("openai/gpt-4", False), 298 | ("anthropic/claude-2", False), 299 | ] 300 | 301 | for model_name, is_reasoning_model in test_cases: 302 | lm = dspy.LM( 303 | model=model_name, 304 | temperature=1.0 if is_reasoning_model else 0.7, 305 | max_tokens=16_000 if is_reasoning_model else 1000, 306 | ) 307 | if is_reasoning_model: 308 | assert "max_completion_tokens" in lm.kwargs 309 | assert "max_tokens" not in lm.kwargs 310 | assert lm.kwargs["max_completion_tokens"] == 16_000 311 | else: 312 | assert "max_completion_tokens" not in lm.kwargs 313 | assert "max_tokens" in lm.kwargs 314 | assert lm.kwargs["max_tokens"] == 1000 315 | 316 | @pytest.mark.parametrize("model_name", ["openai/o1", "openai/gpt-5-nano"]) 317 | def test_reasoning_model_requirements(model_name): 318 | # Should raise assertion error if temperature or max_tokens requirements not met 319 | with pytest.raises( 320 | ValueError, 321 | match="reasoning models require passing temperature=1.0 or None and max_tokens >= 16000 or None", 322 | ): 323 | dspy.LM( 324 | model=model_name, 325 | temperature=0.7, # Should be 1.0 326 | max_tokens=1000, # Should be >= 16_000 327 | ) 328 | 329 | # Should pass with correct parameters 330 | lm = dspy.LM( 331 | model=model_name, 332 | temperature=1.0, 333 | max_tokens=16_000, 334 | ) 335 | assert lm.kwargs["max_completion_tokens"] == 16_000 336 | 337 | # Should pass with no parameters 338 | lm = dspy.LM( 339 | model=model_name, 340 | ) 341 | assert lm.kwargs["temperature"] == None 342 | assert lm.kwargs["max_completion_tokens"] == None 343 | 344 | 345 | def test_dump_state(): 346 | lm = dspy.LM( 347 | model="openai/gpt-4o-mini", 348 | model_type="chat", 349 | temperature=1, 350 | max_tokens=100, 351 | num_retries=10, 352 
| launch_kwargs={"temperature": 1}, 353 | train_kwargs={"temperature": 5}, 354 | ) 355 | 356 | assert lm.dump_state() == { 357 | "model": "openai/gpt-4o-mini", 358 | "model_type": "chat", 359 | "temperature": 1, 360 | "max_tokens": 100, 361 | "num_retries": 10, 362 | "cache": True, 363 | "finetuning_model": None, 364 | "launch_kwargs": {"temperature": 1}, 365 | "train_kwargs": {"temperature": 5}, 366 | } 367 | 368 | 369 | def test_exponential_backoff_retry(): 370 | time_counter = [] 371 | 372 | def mock_create(*args, **kwargs): 373 | time_counter.append(time.time()) 374 | # These fields are called during the error handling 375 | mock_response = mock.Mock() 376 | mock_response.headers = {} 377 | mock_response.status_code = 429 378 | raise RateLimitError(response=mock_response, message="message", body="error") 379 | 380 | lm = dspy.LM(model="openai/gpt-3.5-turbo", max_tokens=250, num_retries=3) 381 | with mock.patch.object(litellm.OpenAIChatCompletion, "completion", side_effect=mock_create): 382 | with pytest.raises(RateLimitError): 383 | lm("question") 384 | 385 | # The first retry happens immediately regardless of the configuration 386 | for i in range(1, len(time_counter) - 1): 387 | assert time_counter[i + 1] - time_counter[i] >= 2 ** (i - 1) 388 | 389 | 390 | def test_logprobs_included_when_requested(): 391 | lm = dspy.LM(model="dspy-test-model", logprobs=True, cache=False) 392 | with mock.patch("litellm.completion") as mock_completion: 393 | mock_completion.return_value = ModelResponse( 394 | choices=[ 395 | Choices( 396 | message=Message(content="test answer"), 397 | logprobs={ 398 | "content": [ 399 | {"token": "test", "logprob": 0.1, "top_logprobs": [{"token": "test", "logprob": 0.1}]}, 400 | {"token": "answer", "logprob": 0.2, "top_logprobs": [{"token": "answer", "logprob": 0.2}]}, 401 | ] 402 | }, 403 | ) 404 | ], 405 | model="dspy-test-model", 406 | ) 407 | result = lm("question") 408 | assert result[0]["text"] == "test answer" 409 | assert result[0]["logprobs"].model_dump() == { 410 | "content": [ 411 | { 412 | "token": "test", 413 | "bytes": None, 414 | "logprob": 0.1, 415 | "top_logprobs": [{"token": "test", "bytes": None, "logprob": 0.1}], 416 | }, 417 | { 418 | "token": "answer", 419 | "bytes": None, 420 | "logprob": 0.2, 421 | "top_logprobs": [{"token": "answer", "bytes": None, "logprob": 0.2}], 422 | }, 423 | ] 424 | } 425 | assert mock_completion.call_args.kwargs["logprobs"] 426 | 427 | 428 | @pytest.mark.asyncio 429 | async def test_async_lm_call(): 430 | from litellm.utils import Choices, Message, ModelResponse 431 | 432 | mock_response = ModelResponse(choices=[Choices(message=Message(content="answer"))], model="openai/gpt-4o-mini") 433 | 434 | with patch("litellm.acompletion") as mock_acompletion: 435 | mock_acompletion.return_value = mock_response 436 | 437 | lm = dspy.LM(model="openai/gpt-4o-mini", cache=False) 438 | result = await lm.acall("question") 439 | 440 | assert result == ["answer"] 441 | mock_acompletion.assert_called_once() 442 | 443 | 444 | @pytest.mark.asyncio 445 | async def test_async_lm_call_with_cache(tmp_path): 446 | """Test the async LM call with caching.""" 447 | original_cache = dspy.cache 448 | dspy.clients.configure_cache( 449 | enable_disk_cache=True, 450 | enable_memory_cache=True, 451 | disk_cache_dir=tmp_path / ".disk_cache", 452 | ) 453 | cache = dspy.cache 454 | 455 | lm = dspy.LM(model="openai/gpt-4o-mini") 456 | 457 | with mock.patch("dspy.clients.lm.alitellm_completion") as mock_alitellm_completion: 458 | 
mock_alitellm_completion.return_value = ModelResponse( 459 | choices=[Choices(message=Message(content="answer"))], model="openai/gpt-4o-mini" 460 | ) 461 | mock_alitellm_completion.__qualname__ = "alitellm_completion" 462 | await lm.acall("Query") 463 | 464 | assert len(cache.memory_cache) == 1 465 | cache_key = next(iter(cache.memory_cache.keys())) 466 | assert cache_key in cache.disk_cache 467 | assert mock_alitellm_completion.call_count == 1 468 | 469 | await lm.acall("Query") 470 | # Second call should hit the cache, so no new call to LiteLLM is made. 471 | assert mock_alitellm_completion.call_count == 1 472 | 473 | # A new query should result in a new LiteLLM call and a new cache entry. 474 | await lm.acall("New query") 475 | 476 | assert len(cache.memory_cache) == 2 477 | assert mock_alitellm_completion.call_count == 2 478 | 479 | dspy.cache = original_cache 480 | 481 | 482 | def test_lm_history_size_limit(): 483 | lm = dspy.LM(model="openai/gpt-4o-mini") 484 | with dspy.context(max_history_size=5): 485 | with mock.patch("litellm.completion") as mock_completion: 486 | mock_completion.return_value = ModelResponse( 487 | choices=[Choices(message=Message(content="test answer"))], 488 | model="openai/gpt-4o-mini", 489 | ) 490 | 491 | for _ in range(10): 492 | lm("query") 493 | 494 | assert len(lm.history) == 5 495 | 496 | 497 | def test_disable_history(): 498 | lm = dspy.LM(model="openai/gpt-4o-mini") 499 | with dspy.context(disable_history=True): 500 | with mock.patch("litellm.completion") as mock_completion: 501 | mock_completion.return_value = ModelResponse( 502 | choices=[Choices(message=Message(content="test answer"))], 503 | model="openai/gpt-4o-mini", 504 | ) 505 | for _ in range(10): 506 | lm("query") 507 | 508 | assert len(lm.history) == 0 509 | 510 | with dspy.context(disable_history=False): 511 | with mock.patch("litellm.completion") as mock_completion: 512 | mock_completion.return_value = ModelResponse( 513 | choices=[Choices(message=Message(content="test answer"))], 514 | model="openai/gpt-4o-mini", 515 | ) 516 | 517 | def test_responses_api(): 518 | api_response = make_response( 519 | output_blocks=[ 520 | ResponseOutputMessage( 521 | **{ 522 | "id": "msg_1", 523 | "type": "message", 524 | "role": "assistant", 525 | "status": "completed", 526 | "content": [ 527 | {"type": "output_text", "text": "This is a test answer from responses API.", "annotations": []} 528 | ], 529 | }, 530 | ), 531 | ResponseReasoningItem( 532 | **{ 533 | "id": "reasoning_1", 534 | "type": "reasoning", 535 | "summary": [Summary(**{"type": "summary_text", "text": "This is a dummy reasoning."})], 536 | }, 537 | ), 538 | ] 539 | ) 540 | 541 | with mock.patch("litellm.responses", autospec=True, return_value=api_response) as dspy_responses: 542 | lm = dspy.LM( 543 | model="openai/gpt-5-mini", 544 | model_type="responses", 545 | cache=False, 546 | temperature=1.0, 547 | max_tokens=16000, 548 | ) 549 | lm_result = lm("openai query") 550 | 551 | assert lm_result == [ 552 | { 553 | "text": "This is a test answer from responses API.", 554 | "reasoning_content": "This is a dummy reasoning.", 555 | } 556 | ] 557 | 558 | dspy_responses.assert_called_once() 559 | assert dspy_responses.call_args.kwargs["model"] == "openai/gpt-5-mini" 560 | 561 | 562 | def test_lm_replaces_system_with_developer_role(): 563 | with mock.patch( 564 | "dspy.clients.lm.litellm_responses_completion", return_value={"choices": []} 565 | ) as mock_completion: 566 | lm = dspy.LM( 567 | "openai/gpt-4o-mini", 568 | cache=False, 569 | 
model_type="responses", 570 | use_developer_role=True, 571 | ) 572 | lm.forward(messages=[{"role": "system", "content": "hi"}]) 573 | assert ( 574 | mock_completion.call_args.kwargs["request"]["messages"][0]["role"] 575 | == "developer" 576 | ) 577 | 578 | 579 | def test_responses_api_tool_calls(litellm_test_server): 580 | api_base, _ = litellm_test_server 581 | expected_tool_call = { 582 | "type": "function_call", 583 | "name": "get_weather", 584 | "arguments": json.dumps({"city": "Paris"}), 585 | "call_id": "call_1", 586 | "status": "completed", 587 | "id": "call_1", 588 | } 589 | expected_response = [{"tool_calls": [expected_tool_call]}] 590 | 591 | api_response = make_response( 592 | output_blocks=[expected_tool_call], 593 | ) 594 | 595 | with mock.patch("litellm.responses", autospec=True, return_value=api_response) as dspy_responses: 596 | lm = dspy.LM( 597 | model="openai/dspy-test-model", 598 | api_base=api_base, 599 | api_key="fakekey", 600 | model_type="responses", 601 | cache=False, 602 | ) 603 | assert lm("openai query") == expected_response 604 | 605 | dspy_responses.assert_called_once() 606 | assert dspy_responses.call_args.kwargs["model"] == "openai/dspy-test-model" 607 | ``` -------------------------------------------------------------------------------- /docs/docs/api/optimizers/GEPA/GEPA_Advanced.md: -------------------------------------------------------------------------------- ```markdown 1 | # dspy.GEPA - Advanced Features 2 | 3 | ## Custom Instruction Proposers 4 | 5 | ### What is instruction_proposer? 6 | 7 | The `instruction_proposer` is the component responsible for invoking the `reflection_lm` and proposing new prompts during GEPA optimization. When GEPA identifies underperforming components in your DSPy program, the instruction proposer analyzes execution traces, feedback, and failures to generate improved instructions tailored to the observed issues. 8 | 9 | ### Default Implementation 10 | 11 | By default, GEPA uses the built-in instruction proposer from the [GEPA library](https://github.com/gepa-ai/gepa), which implements the [`ProposalFn`](https://github.com/gepa-ai/gepa/blob/main/src/gepa/core/adapter.py). The [default proposer](https://github.com/gepa-ai/gepa/blob/main/src/gepa/proposer/reflective_mutation/reflective_mutation.py#L53-L75) uses this prompt template: 12 | 13 | ```` 14 | I provided an assistant with the following instructions to perform a task for me: 15 | ``` 16 | <curr_instructions> 17 | ``` 18 | 19 | The following are examples of different task inputs provided to the assistant along with the assistant's response for each of them, and some feedback on how the assistant's response could be better: 20 | ``` 21 | <inputs_outputs_feedback> 22 | ``` 23 | 24 | Your task is to write a new instruction for the assistant. 25 | 26 | Read the inputs carefully and identify the input format and infer detailed task description about the task I wish to solve with the assistant. 27 | 28 | Read all the assistant responses and the corresponding feedback. Identify all niche and domain specific factual information about the task and include it in the instruction, as a lot of it may not be available to the assistant in the future. The assistant may have utilized a generalizable strategy to solve the task, if so, include that in the instruction as well. 29 | 30 | Provide the new instructions within ``` blocks. 
31 | ````
32 |
33 | This template is automatically filled with:
34 |
35 | - `<curr_instructions>`: The current instruction being optimized
36 | - `<inputs_outputs_feedback>`: Structured markdown containing predictor inputs, generated outputs, and evaluation feedback
37 |
38 | Example of default behavior:
39 |
40 | ```python
41 | # Default instruction proposer is used automatically
42 | gepa = dspy.GEPA(
43 |     metric=my_metric,
44 |     reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
45 |     auto="medium"
46 | )
47 | optimized_program = gepa.compile(student, trainset=examples)
48 | ```
49 |
50 | ### When to Use Custom instruction_proposer
51 |
52 | **Note:** Custom instruction proposers are an advanced feature. Most users should start with the default proposer, which works well for most text-based optimization tasks.
53 |
54 | Consider implementing a custom instruction proposer when you need:
55 |
56 | - **Multi-modal handling**: Process images (`dspy.Image`) alongside textual information in your inputs
57 | - **Nuanced control over limits and length constraints**: Exercise fine-grained control over instruction length, format, and structural requirements
58 | - **Domain-specific information**: Inject specialized knowledge, terminology, or context that the default proposer lacks and that cannot be supplied through the feedback function. This is an advanced feature, and most users should not need it.
59 | - **Provider-specific prompting guides**: Optimize instructions for specific LLM providers (OpenAI, Anthropic, etc.) with their unique formatting preferences
60 | - **Coupled component updates**: Handle situations where two or more components need to be updated together in a coordinated manner, rather than optimizing each component independently (see the `component_selector` parameter in the [Custom Component Selection](#custom-component-selection) section for related functionality)
61 | - **External knowledge integration**: Connect to databases, APIs, or knowledge bases during instruction generation
62 |
63 | ### Available Options
64 |
65 | **Built-in Options:**
66 |
67 | - **Default Proposer**: The standard GEPA instruction proposer (used when `instruction_proposer=None`). The default is itself an instruction proposer, and the most general one: it was used for the diverse experiments reported in the GEPA paper and tutorials.
68 | - **MultiModalInstructionProposer**: Handles `dspy.Image` inputs and structured multimodal content.
69 |
70 | ```python
71 | from dspy.teleprompt.gepa.instruction_proposal import MultiModalInstructionProposer
72 |
73 | # For tasks involving images or multimodal inputs
74 | gepa = dspy.GEPA(
75 |     metric=my_metric,
76 |     reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
77 |     instruction_proposer=MultiModalInstructionProposer(),
78 |     auto="medium"
79 | )
80 | ```
81 |
82 | We invite community contributions of new instruction proposers for specialized domains as the [GEPA library](https://github.com/gepa-ai/gepa) continues to grow.
83 |
84 | ### How to Implement Custom Instruction Proposers
85 |
86 | Custom instruction proposers must implement the `ProposalFn` protocol by defining a callable class or function.
GEPA will call your proposer during optimization: 87 | 88 | ```python 89 | from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample 90 | 91 | class CustomInstructionProposer: 92 | def __call__( 93 | self, 94 | candidate: dict[str, str], # Candidate component name -> instruction mapping to be updated in this round 95 | reflective_dataset: dict[str, list[ReflectiveExample]], # Component -> examples with structure: {"Inputs": ..., "Generated Outputs": ..., "Feedback": ...} 96 | components_to_update: list[str] # Which components to improve 97 | ) -> dict[str, str]: # Return new instruction mapping only for components being updated 98 | # Your custom instruction generation logic here 99 | return updated_instructions 100 | 101 | # Or as a function: 102 | def custom_instruction_proposer(candidate, reflective_dataset, components_to_update): 103 | # Your custom instruction generation logic here 104 | return updated_instructions 105 | ``` 106 | 107 | **Reflective Dataset Structure:** 108 | 109 | - `dict[str, list[ReflectiveExample]]` - Maps component names to lists of examples 110 | - `ReflectiveExample` TypedDict contains: 111 | - `Inputs: dict[str, Any]` - Predictor inputs (may include dspy.Image objects) 112 | - `Generated_Outputs: dict[str, Any] | str` - Success: output fields dict, Failure: error message 113 | - `Feedback: str` - Always a string from metric function or auto-generated by GEPA 114 | 115 | #### Basic Example: Word Limit Proposer 116 | 117 | ```python 118 | import dspy 119 | from gepa.core.adapter import ProposalFn 120 | from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample 121 | 122 | class GenerateWordLimitedInstruction(dspy.Signature): 123 | """Given a current instruction and feedback examples, generate an improved instruction with word limit constraints.""" 124 | 125 | current_instruction = dspy.InputField(desc="The current instruction that needs improvement") 126 | feedback_summary = dspy.InputField(desc="Feedback from examples that might include both positive and negative cases") 127 | max_words = dspy.InputField(desc="Maximum number of words allowed in the new instruction") 128 | 129 | improved_instruction = dspy.OutputField(desc="A new instruction that fixes the issues while staying under the max_words limit") 130 | 131 | class WordLimitProposer(ProposalFn): 132 | def __init__(self, max_words: int = 1000): 133 | self.max_words = max_words 134 | self.instruction_improver = dspy.ChainOfThought(GenerateWordLimitedInstruction) 135 | 136 | def __call__(self, candidate: dict[str, str], reflective_dataset: dict[str, list[ReflectiveExample]], components_to_update: list[str]) -> dict[str, str]: 137 | updated_components = {} 138 | 139 | for component_name in components_to_update: 140 | if component_name not in candidate or component_name not in reflective_dataset: 141 | continue 142 | 143 | current_instruction = candidate[component_name] 144 | component_examples = reflective_dataset[component_name] 145 | 146 | # Create feedback summary 147 | feedback_text = "\n".join([ 148 | f"Example {i+1}: {ex.get('Feedback', 'No feedback')}" 149 | for i, ex in enumerate(component_examples) # Limit examples to prevent context overflow 150 | ]) 151 | 152 | # Use the module to improve the instruction 153 | result = self.instruction_improver( 154 | current_instruction=current_instruction, 155 | feedback_summary=feedback_text, 156 | max_words=self.max_words 157 | ) 158 | 159 | updated_components[component_name] = result.improved_instruction 160 | 161 | return updated_components 162 | 163 | 
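# Optional illustrative check (hypothetical component name and data; assumes a configured dspy.LM):
# a ProposalFn can be invoked directly with a hand-built reflective dataset before handing it to GEPA.
# proposer = WordLimitProposer(max_words=300)
# proposer(
#     candidate={"qa.predict": "Answer the question."},
#     reflective_dataset={"qa.predict": [
#         {"Inputs": {"question": "..."}, "Generated Outputs": {"answer": "..."}, "Feedback": "Too verbose."},
#     ]},
#     components_to_update=["qa.predict"],
# )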
# Usage 164 | gepa = dspy.GEPA( 165 | metric=my_metric, 166 | reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key), 167 | instruction_proposer=WordLimitProposer(max_words=700), 168 | auto="medium" 169 | ) 170 | ``` 171 | 172 | #### Advanced Example: RAG-Enhanced Instruction Proposer 173 | 174 | ```python 175 | import dspy 176 | from gepa.core.adapter import ProposalFn 177 | from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample 178 | 179 | class GenerateDocumentationQuery(dspy.Signature): 180 | """Analyze examples with feedback to identify common issue patterns and generate targeted database queries for retrieving relevant documentation. 181 | 182 | Your goal is to search a document database for guidelines that address the problematic patterns found in the examples. Look for recurring issues, error types, or failure modes in the feedback, then craft specific search queries that will find documentation to help resolve these patterns.""" 183 | 184 | current_instruction = dspy.InputField(desc="The current instruction that needs improvement") 185 | examples_with_feedback = dspy.InputField(desc="Examples with their feedback showing what issues occurred and any recurring patterns") 186 | 187 | failure_patterns: str = dspy.OutputField(desc="Summarize the common failure patterns identified in the examples") 188 | 189 | retrieval_queries: list[str] = dspy.OutputField(desc="Specific search queries to find relevant documentation in the database that addresses the common issue patterns identified in the problematic examples") 190 | 191 | class GenerateRAGEnhancedInstruction(dspy.Signature): 192 | """Generate improved instructions using retrieved documentation and examples analysis.""" 193 | 194 | current_instruction = dspy.InputField(desc="The current instruction that needs improvement") 195 | relevant_documentation = dspy.InputField(desc="Retrieved guidelines and best practices from specialized documentation") 196 | examples_with_feedback = dspy.InputField(desc="Examples showing what issues occurred with the current instruction") 197 | 198 | improved_instruction: str = dspy.OutputField(desc="Enhanced instruction that incorporates retrieved guidelines and addresses the issues shown in the examples") 199 | 200 | class RAGInstructionImprover(dspy.Module): 201 | """Module that uses RAG to improve instructions with specialized documentation.""" 202 | 203 | def __init__(self, retrieval_model): 204 | super().__init__() 205 | self.retrieve = retrieval_model # Could be dspy.Retrieve or custom retriever 206 | self.query_generator = dspy.ChainOfThought(GenerateDocumentationQuery) 207 | self.generate_answer = dspy.ChainOfThought(GenerateRAGEnhancedInstruction) 208 | 209 | def forward(self, current_instruction: str, component_examples: list): 210 | """Improve instruction using retrieved documentation.""" 211 | 212 | # Let LM analyze examples and generate targeted retrieval queries 213 | query_result = self.query_generator( 214 | current_instruction=current_instruction, 215 | examples_with_feedback=component_examples 216 | ) 217 | 218 | results = self.retrieve.query( 219 | query_texts=query_result.retrieval_queries, 220 | n_results=3 221 | ) 222 | 223 | relevant_docs_parts = [] 224 | for i, (query, query_docs) in enumerate(zip(query_result.retrieval_queries, results['documents'])): 225 | if query_docs: 226 | docs_formatted = "\n".join([f" - {doc}" for doc in query_docs]) 227 | relevant_docs_parts.append( 228 | f"**Search Query #{i+1}**: {query}\n" 229 | f"**Retrieved 
Guidelines**:\n{docs_formatted}" 230 | ) 231 | 232 | relevant_docs = "\n\n" + "="*60 + "\n\n".join(relevant_docs_parts) + "\n" + "="*60 233 | 234 | # Generate improved instruction with retrieved context 235 | result = self.generate_answer( 236 | current_instruction=current_instruction, 237 | relevant_documentation=relevant_docs, 238 | examples_with_feedback=component_examples 239 | ) 240 | 241 | return result 242 | 243 | class DocumentationEnhancedProposer(ProposalFn): 244 | """Instruction proposer that accesses specialized documentation via RAG.""" 245 | 246 | def __init__(self, documentation_retriever): 247 | """ 248 | Args: 249 | documentation_retriever: A retrieval model that can search your specialized docs 250 | Could be dspy.Retrieve, ChromadbRM, or custom retriever 251 | """ 252 | self.instruction_improver = RAGInstructionImprover(documentation_retriever) 253 | 254 | def __call__(self, candidate: dict[str, str], reflective_dataset: dict[str, list[ReflectiveExample]], components_to_update: list[str]) -> dict[str, str]: 255 | updated_components = {} 256 | 257 | for component_name in components_to_update: 258 | if component_name not in candidate or component_name not in reflective_dataset: 259 | continue 260 | 261 | current_instruction = candidate[component_name] 262 | component_examples = reflective_dataset[component_name] 263 | 264 | result = self.instruction_improver( 265 | current_instruction=current_instruction, 266 | component_examples=component_examples 267 | ) 268 | 269 | updated_components[component_name] = result.improved_instruction 270 | 271 | return updated_components 272 | 273 | import chromadb 274 | 275 | client = chromadb.Client() 276 | collection = client.get_collection("instruction_guidelines") 277 | 278 | gepa = dspy.GEPA( 279 | metric=task_specific_metric, 280 | reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key), 281 | instruction_proposer=DocumentationEnhancedProposer(collection), 282 | auto="medium" 283 | ) 284 | ``` 285 | 286 | #### Integration Patterns 287 | 288 | **Using Custom Proposer with External LM:** 289 | 290 | ```python 291 | class ExternalLMProposer(ProposalFn): 292 | def __init__(self): 293 | # Manage your own LM instance 294 | self.external_lm = dspy.LM('gemini/gemini-2.5-pro') 295 | 296 | def __call__(self, candidate, reflective_dataset, components_to_update): 297 | updated_components = {} 298 | 299 | with dspy.context(lm=self.external_lm): 300 | # Your custom logic here using self.external_lm 301 | for component_name in components_to_update: 302 | # ... implementation 303 | pass 304 | 305 | return updated_components 306 | 307 | gepa = dspy.GEPA( 308 | metric=my_metric, 309 | reflection_lm=None, # Optional when using custom proposer 310 | instruction_proposer=ExternalLMProposer(), 311 | auto="medium" 312 | ) 313 | ``` 314 | 315 | **Best Practices:** 316 | 317 | - **Use the full power of DSPy**: Leverage DSPy components like `dspy.Module`, `dspy.Signature`, and `dspy.Predict` to create your instruction proposer rather than direct LM calls. Consider `dspy.Refine` for constraint satisfaction, `dspy.ChainOfThought` for complex reasoning tasks, and compose multiple modules for sophisticated instruction improvement workflows 318 | - **Enable holistic feedback analysis**: While dspy.GEPA's `GEPAFeedbackMetric` processes one (gold, prediction) pair at a time, instruction proposers receive all examples for a component in batch, enabling cross-example pattern detection and systematic issue identification. 
319 | - **Mind data serialization**: Serializing everything to strings might not be ideal - handle complex input types (like `dspy.Image`) by maintaining their structure for better LM processing 320 | - **Test thoroughly**: Test your custom proposer with representative failure cases 321 | 322 | ## Custom Component Selection 323 | 324 | ### What is component_selector? 325 | 326 | The `component_selector` parameter controls which components (predictors) in your DSPy program are selected for optimization at each GEPA iteration. Instead of the default round-robin approach that updates one component at a time, you can implement custom selection strategies that choose single or multiple components based on optimization state, performance trajectories, and other contextual information. 327 | 328 | ### Default Behavior 329 | 330 | By default, GEPA uses a **round-robin strategy** (`RoundRobinReflectionComponentSelector`) that cycles through components sequentially, optimizing one component per iteration: 331 | 332 | ```python 333 | # Default round-robin component selection 334 | gepa = dspy.GEPA( 335 | metric=my_metric, 336 | reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key), 337 | # component_selector="round_robin" # This is the default 338 | auto="medium" 339 | ) 340 | ``` 341 | 342 | ### Built-in Selection Strategies 343 | 344 | **String-based selectors:** 345 | 346 | - `"round_robin"` (default): Cycles through components one at a time 347 | - `"all"`: Selects all components for simultaneous optimization 348 | 349 | ```python 350 | # Optimize all components simultaneously 351 | gepa = dspy.GEPA( 352 | metric=my_metric, 353 | reflection_lm=reflection_lm, 354 | component_selector="all", # Update all components together 355 | auto="medium" 356 | ) 357 | 358 | # Explicit round-robin selection 359 | gepa = dspy.GEPA( 360 | metric=my_metric, 361 | reflection_lm=reflection_lm, 362 | component_selector="round_robin", # One component per iteration 363 | auto="medium" 364 | ) 365 | ``` 366 | 367 | ### When to Use Custom Component Selection 368 | 369 | Consider implementing custom component selection when you need: 370 | 371 | - **Dependency-aware optimization**: Update related components together (e.g., a classifier and its input formatter) 372 | - **LLM-driven selection**: Let an LLM analyze trajectories and decide which components need attention 373 | - **Resource-conscious optimization**: Balance optimization thoroughness with computational budget 374 | 375 | ### Custom Component Selector Protocol 376 | 377 | Custom component selectors must implement the [`ReflectionComponentSelector`](https://github.com/gepa-ai/gepa/blob/main/src/gepa/proposer/reflective_mutation/base.py) protocol by defining a callable class or function. 
GEPA will call your selector during optimization: 378 | 379 | ```python 380 | from dspy.teleprompt.gepa.gepa_utils import GEPAState, Trajectory 381 | 382 | class CustomComponentSelector: 383 | def __call__( 384 | self, 385 | state: GEPAState, # Complete optimization state with history 386 | trajectories: list[Trajectory], # Execution traces from the current minibatch 387 | subsample_scores: list[float], # Scores for each example in the current minibatch 388 | candidate_idx: int, # Index of the current program candidate being optimized 389 | candidate: dict[str, str], # Component name -> instruction mapping 390 | ) -> list[str]: # Return list of component names to optimize 391 | # Your custom component selection logic here 392 | return selected_components 393 | 394 | # Or as a function: 395 | def custom_component_selector(state, trajectories, subsample_scores, candidate_idx, candidate): 396 | # Your custom component selection logic here 397 | return selected_components 398 | ``` 399 | 400 | ### Custom Implementation Example 401 | 402 | Here's a simple function that alternates between optimizing different halves of your components: 403 | 404 | ```python 405 | def alternating_half_selector(state, trajectories, subsample_scores, candidate_idx, candidate): 406 | """Optimize half the components on even iterations, half on odd iterations.""" 407 | components = list(candidate.keys()) 408 | 409 | # If there's only one component, always optimize it 410 | if len(components) <= 1: 411 | return components 412 | 413 | mid_point = len(components) // 2 414 | 415 | # Use state.i (iteration counter) to alternate between halves 416 | if state.i % 2 == 0: 417 | # Even iteration: optimize first half 418 | return components[:mid_point] 419 | else: 420 | # Odd iteration: optimize second half 421 | return components[mid_point:] 422 | 423 | # Usage 424 | gepa = dspy.GEPA( 425 | metric=my_metric, 426 | reflection_lm=reflection_lm, 427 | component_selector=alternating_half_selector, 428 | auto="medium" 429 | ) 430 | ``` 431 | 432 | ### Integration with Custom Instruction Proposers 433 | 434 | Component selectors work seamlessly with custom instruction proposers. 
The selector determines which components to update, then the instruction proposer generates new instructions for those components: 435 | 436 | ```python 437 | # Combined custom selector + custom proposer 438 | gepa = dspy.GEPA( 439 | metric=my_metric, 440 | reflection_lm=reflection_lm, 441 | component_selector=alternating_half_selector, 442 | instruction_proposer=WordLimitProposer(max_words=500), 443 | auto="medium" 444 | ) 445 | ``` 446 | ``` -------------------------------------------------------------------------------- /dspy/retrievers/databricks_rm.py: -------------------------------------------------------------------------------- ```python 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | from importlib.util import find_spec 5 | from typing import Any 6 | 7 | import requests 8 | 9 | import dspy 10 | from dspy.primitives.prediction import Prediction 11 | 12 | _databricks_sdk_installed = find_spec("databricks.sdk") is not None 13 | 14 | 15 | @dataclass 16 | class Document: 17 | page_content: str 18 | metadata: dict[str, Any] 19 | type: str 20 | 21 | def to_dict(self) -> dict[str, Any]: 22 | return { 23 | "page_content": self.page_content, 24 | "metadata": self.metadata, 25 | "type": self.type, 26 | } 27 | 28 | 29 | class DatabricksRM(dspy.Retrieve): 30 | """ 31 | A retriever module that uses a Databricks Mosaic AI Vector Search Index to return the top-k 32 | embeddings for a given query. 33 | 34 | Examples: 35 | Below is a code snippet that shows how to set up a Databricks Vector Search Index 36 | and configure a DatabricksRM DSPy retriever module to query the index. 37 | 38 | (example adapted from "Databricks: How to create and query a Vector Search Index: 39 | https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-index) 40 | 41 | ```python 42 | from databricks.vector_search.client import VectorSearchClient 43 | 44 | # Create a Databricks Vector Search Endpoint 45 | client = VectorSearchClient() 46 | client.create_endpoint( 47 | name="your_vector_search_endpoint_name", 48 | endpoint_type="STANDARD" 49 | ) 50 | 51 | # Create a Databricks Direct Access Vector Search Index 52 | index = client.create_direct_access_index( 53 | endpoint_name="your_vector_search_endpoint_name", 54 | index_name="your_index_name", 55 | primary_key="id", 56 | embedding_dimension=1024, 57 | embedding_vector_column="text_vector", 58 | schema={ 59 | "id": "int", 60 | "field2": "str", 61 | "field3": "float", 62 | "text_vector": "array<float>" 63 | } 64 | ) 65 | 66 | # Create a DatabricksRM retriever module to query the Databricks Direct Access Vector 67 | # Search Index 68 | retriever = DatabricksRM( 69 | databricks_index_name = "your_index_name", 70 | docs_id_column_name="id", 71 | text_column_name="field2", 72 | k=3 73 | ) 74 | ``` 75 | 76 | Below is a code snippet that shows how to query the Databricks Direct Access Vector 77 | Search Index using the DatabricksRM retriever module: 78 | 79 | ```python 80 | retrieved_results = DatabricksRM(query="Example query text")) 81 | ``` 82 | """ 83 | 84 | def __init__( 85 | self, 86 | databricks_index_name: str, 87 | databricks_endpoint: str | None = None, 88 | databricks_token: str | None = None, 89 | databricks_client_id: str | None = None, 90 | databricks_client_secret: str | None = None, 91 | columns: list[str] | None = None, 92 | filters_json: str | None = None, 93 | k: int = 3, 94 | docs_id_column_name: str = "id", 95 | docs_uri_column_name: str | None = None, 96 | text_column_name: str = "text", 97 
| use_with_databricks_agent_framework: bool = False, 98 | ): 99 | """ 100 | Args: 101 | databricks_index_name (str): The name of the Databricks Vector Search Index to query. 102 | databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing 103 | the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST`` 104 | environment variable. If unspecified, the Databricks SDK is used to identify the 105 | endpoint based on the current environment. 106 | databricks_token (Optional[str]): The Databricks Workspace authentication token to use 107 | when querying the Vector Search Index. Defaults to the value of the 108 | ``DATABRICKS_TOKEN`` environment variable. If unspecified, the Databricks SDK is 109 | used to identify the token based on the current environment. 110 | databricks_client_id (str): Databricks service principal id. If not specified, 111 | the token is resolved from the current environment (DATABRICKS_CLIENT_ID). 112 | databricks_client_secret (str): Databricks service principal secret. If not specified, 113 | the endpoint is resolved from the current environment (DATABRICKS_CLIENT_SECRET). 114 | columns (Optional[list[str]]): Extra column names to include in response, 115 | in addition to the document id and text columns specified by 116 | ``docs_id_column_name`` and ``text_column_name``. 117 | filters_json (Optional[str]): A JSON string specifying additional query filters. 118 | Example filters: ``{"id <": 5}`` selects records that have an ``id`` column value 119 | less than 5, and ``{"id >=": 5, "id <": 10}`` selects records that have an ``id`` 120 | column value greater than or equal to 5 and less than 10. 121 | k (int): The number of documents to retrieve. 122 | docs_id_column_name (str): The name of the column in the Databricks Vector Search Index 123 | containing document IDs. 124 | docs_uri_column_name (Optional[str]): The name of the column in the Databricks Vector Search Index 125 | containing document URI. 126 | text_column_name (str): The name of the column in the Databricks Vector Search Index 127 | containing document text to retrieve. 128 | use_with_databricks_agent_framework (bool): Whether to use the `DatabricksRM` in a way that is 129 | compatible with the Databricks Mosaic Agent Framework. 130 | """ 131 | super().__init__(k=k) 132 | self.databricks_token = databricks_token if databricks_token is not None else os.environ.get("DATABRICKS_TOKEN") 133 | self.databricks_endpoint = ( 134 | databricks_endpoint if databricks_endpoint is not None else os.environ.get("DATABRICKS_HOST") 135 | ) 136 | self.databricks_client_id = ( 137 | databricks_client_id if databricks_client_id is not None else os.environ.get("DATABRICKS_CLIENT_ID") 138 | ) 139 | self.databricks_client_secret = ( 140 | databricks_client_secret 141 | if databricks_client_secret is not None 142 | else os.environ.get("DATABRICKS_CLIENT_SECRET") 143 | ) 144 | if not _databricks_sdk_installed and (self.databricks_token, self.databricks_endpoint).count(None) > 0: 145 | raise ValueError( 146 | "To retrieve documents with Databricks Vector Search, you must install the" 147 | " databricks-sdk Python library, supply the databricks_token and" 148 | " databricks_endpoint parameters, or set the DATABRICKS_TOKEN and DATABRICKS_HOST" 149 | " environment variables. 
You may also supply a service principal the databricks_client_id and" 150 | " databricks_client_secret parameters, or set the DATABRICKS_CLIENT_ID and DATABRICKS_CLIENT_SECRET" 151 | ) 152 | self.databricks_index_name = databricks_index_name 153 | self.columns = list({docs_id_column_name, text_column_name, *(columns or [])}) 154 | self.filters_json = filters_json 155 | self.k = k 156 | self.docs_id_column_name = docs_id_column_name 157 | self.docs_uri_column_name = docs_uri_column_name 158 | self.text_column_name = text_column_name 159 | self.use_with_databricks_agent_framework = use_with_databricks_agent_framework 160 | if self.use_with_databricks_agent_framework: 161 | try: 162 | import mlflow 163 | 164 | mlflow.models.set_retriever_schema( 165 | primary_key="doc_id", 166 | text_column="page_content", 167 | doc_uri="doc_uri", 168 | ) 169 | except ImportError: 170 | raise ValueError( 171 | "To use the `DatabricksRM` retriever module with the Databricks Mosaic Agent Framework, " 172 | "you must install the mlflow Python library. Please install mlflow via `pip install mlflow`." 173 | ) 174 | 175 | def _extract_doc_ids(self, item: dict[str, Any]) -> str: 176 | """Extracts the document id from a search result 177 | 178 | Args: 179 | item: dict[str, Any]: a record from the search results. 180 | Returns: 181 | str: document id. 182 | """ 183 | if self.docs_id_column_name == "metadata": 184 | docs_dict = json.loads(item["metadata"]) 185 | return docs_dict["document_id"] 186 | return item[self.docs_id_column_name] 187 | 188 | def _get_extra_columns(self, item: dict[str, Any]) -> dict[str, Any]: 189 | """Extracts search result column values, excluding the "text" and not "id" columns 190 | 191 | Args: 192 | item: dict[str, Any]: a record from the search results. 193 | Returns: 194 | dict[str, Any]: Search result column values, excluding the "text", "id" and "uri" columns. 195 | """ 196 | extra_columns = { 197 | k: v 198 | for k, v in item.items() 199 | if k not in [self.docs_id_column_name, self.text_column_name, self.docs_uri_column_name] 200 | } 201 | if self.docs_id_column_name == "metadata": 202 | extra_columns = { 203 | **extra_columns, 204 | **{"metadata": {k: v for k, v in json.loads(item["metadata"]).items() if k != "document_id"}}, 205 | } 206 | return extra_columns 207 | 208 | def forward( 209 | self, 210 | query: str | list[float], 211 | query_type: str = "ANN", 212 | filters_json: str | None = None, 213 | ) -> dspy.Prediction | list[dict[str, Any]]: 214 | """ 215 | Retrieve documents from a Databricks Mosaic AI Vector Search Index that are relevant to the 216 | specified query. 217 | 218 | Args: 219 | query (Union[str, list[float]]): The query text or numeric query vector for which to 220 | retrieve relevant documents. 221 | query_type (str): The type of search query to perform against the Databricks Vector 222 | Search Index. Must be either 'ANN' (approximate nearest neighbor) or 'HYBRID' 223 | (hybrid search). 224 | filters_json (Optional[str]): A JSON string specifying additional query filters. 225 | Example filters: ``{"id <": 5}`` selects records that have an ``id`` column value 226 | less than 5, and ``{"id >=": 5, "id <": 10}`` selects records that have an ``id`` 227 | column value greater than or equal to 5 and less than 10. If specified, this 228 | parameter overrides the `filters_json` parameter passed to the constructor. 
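                For example (an illustrative call, assuming a configured ``DatabricksRM`` instance named
                ``retriever``): ``retriever(query="example query", query_type="HYBRID", filters_json='{"id <": 5}')``.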
229 | 230 | Returns: 231 | A list of dictionaries when ``use_with_databricks_agent_framework`` is ``True``, 232 | or a ``dspy.Prediction`` object when ``use_with_databricks_agent_framework`` is 233 | ``False``. 234 | """ 235 | if query_type in ["vector", "text"]: 236 | # Older versions of DSPy used a `query_type` argument to disambiguate between text 237 | # and vector queries, rather than checking the type of the `query` argument. This 238 | # differs from the Databricks Vector Search definition of `query_type`, which 239 | # specifies the search algorithm to use (e.g. "ANN" or "HYBRID"). To maintain 240 | # backwards compatibility with older versions of DSPy, we map the old `query_type` 241 | # values to the Databricks Vector Search default query type of "ANN". 242 | query_type = "ANN" 243 | 244 | if isinstance(query, str): 245 | query_text = query 246 | query_vector = None 247 | elif isinstance(query, list): 248 | query_vector = query 249 | query_text = None 250 | else: 251 | raise ValueError("Query must be a string or a list of floats.") 252 | 253 | if _databricks_sdk_installed: 254 | results = self._query_via_databricks_sdk( 255 | index_name=self.databricks_index_name, 256 | k=self.k, 257 | columns=self.columns, 258 | query_type=query_type, 259 | query_text=query_text, 260 | query_vector=query_vector, 261 | databricks_token=self.databricks_token, 262 | databricks_endpoint=self.databricks_endpoint, 263 | databricks_client_id=self.databricks_client_id, 264 | databricks_client_secret=self.databricks_client_secret, 265 | filters_json=filters_json or self.filters_json, 266 | ) 267 | else: 268 | results = self._query_via_requests( 269 | index_name=self.databricks_index_name, 270 | k=self.k, 271 | columns=self.columns, 272 | databricks_token=self.databricks_token, 273 | databricks_endpoint=self.databricks_endpoint, 274 | query_type=query_type, 275 | query_text=query_text, 276 | query_vector=query_vector, 277 | filters_json=filters_json or self.filters_json, 278 | ) 279 | 280 | # Checking if defined columns are present in the index columns 281 | col_names = [column["name"] for column in results["manifest"]["columns"]] 282 | 283 | if self.docs_id_column_name not in col_names: 284 | raise Exception( 285 | f"docs_id_column_name: '{self.docs_id_column_name}' is not in the index columns: \n {col_names}" 286 | ) 287 | 288 | if self.text_column_name not in col_names: 289 | raise Exception(f"text_column_name: '{self.text_column_name}' is not in the index columns: \n {col_names}") 290 | 291 | # Extracting the results 292 | items = [] 293 | if "data_array" in results["result"]: 294 | for _, data_row in enumerate(results["result"]["data_array"]): 295 | item = {} 296 | for col_name, val in zip(col_names, data_row, strict=False): 297 | item[col_name] = val 298 | items += [item] 299 | 300 | # Sorting results by score in descending order 301 | sorted_docs = sorted(items, key=lambda x: x["score"], reverse=True)[: self.k] 302 | 303 | if self.use_with_databricks_agent_framework: 304 | return [ 305 | Document( 306 | page_content=doc[self.text_column_name], 307 | metadata={ 308 | "doc_id": self._extract_doc_ids(doc), 309 | "doc_uri": doc[self.docs_uri_column_name] if self.docs_uri_column_name else None, 310 | } 311 | | self._get_extra_columns(doc), 312 | type="Document", 313 | ).to_dict() 314 | for doc in sorted_docs 315 | ] 316 | else: 317 | # Returning the prediction 318 | return Prediction( 319 | docs=[doc[self.text_column_name] for doc in sorted_docs], 320 | doc_ids=[self._extract_doc_ids(doc) for doc in 
sorted_docs], 321 | doc_uris=[doc[self.docs_uri_column_name] for doc in sorted_docs] if self.docs_uri_column_name else None, 322 | extra_columns=[self._get_extra_columns(item) for item in sorted_docs], 323 | ) 324 | 325 | @staticmethod 326 | def _query_via_databricks_sdk( 327 | index_name: str, 328 | k: int, 329 | columns: list[str], 330 | query_type: str, 331 | query_text: str | None, 332 | query_vector: list[float] | None, 333 | databricks_token: str | None, 334 | databricks_endpoint: str | None, 335 | databricks_client_id: str | None, 336 | databricks_client_secret: str | None, 337 | filters_json: str | None, 338 | ) -> dict[str, Any]: 339 | """ 340 | Query a Databricks Vector Search Index via the Databricks SDK. 341 | Assumes that the databricks-sdk Python library is installed. 342 | 343 | Args: 344 | index_name (str): Name of the Databricks vector search index to query 345 | k (int): Number of relevant documents to retrieve. 346 | columns (list[str]): Column names to include in response. 347 | query_text (Optional[str]): Text query for which to find relevant documents. Exactly 348 | one of query_text or query_vector must be specified. 349 | query_vector (Optional[list[float]]): Numeric query vector for which to find relevant 350 | documents. Exactly one of query_text or query_vector must be specified. 351 | filters_json (Optional[str]): JSON string representing additional query filters. 352 | databricks_token (str): Databricks authentication token. If not specified, 353 | the token is resolved from the current environment. 354 | databricks_endpoint (str): Databricks index endpoint url. If not specified, 355 | the endpoint is resolved from the current environment. 356 | databricks_client_id (str): Databricks service principal id. If not specified, 357 | the token is resolved from the current environment (DATABRICKS_CLIENT_ID). 358 | databricks_client_secret (str): Databricks service principal secret. If not specified, 359 | the endpoint is resolved from the current environment (DATABRICKS_CLIENT_SECRET). 360 | Returns: 361 | Returns: 362 | dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query. 
363 | """ 364 | 365 | from databricks.sdk import WorkspaceClient 366 | 367 | if (query_text, query_vector).count(None) != 1: 368 | raise ValueError("Exactly one of query_text or query_vector must be specified.") 369 | 370 | if databricks_client_secret and databricks_client_id: 371 | # Use client ID and secret for authentication if they are provided 372 | databricks_client = WorkspaceClient( 373 | client_id=databricks_client_id, 374 | client_secret=databricks_client_secret, 375 | ) 376 | print("Creating Databricks workspace client using service principal authentication.") 377 | 378 | else: 379 | # Fallback for token-based authentication 380 | databricks_client = WorkspaceClient( 381 | host=databricks_endpoint, 382 | token=databricks_token, 383 | ) 384 | print("Creating Databricks workspace client using token authentication.") 385 | 386 | return databricks_client.vector_search_indexes.query_index( 387 | index_name=index_name, 388 | query_type=query_type, 389 | query_text=query_text, 390 | query_vector=query_vector, 391 | columns=columns, 392 | filters_json=filters_json, 393 | num_results=k, 394 | ).as_dict() 395 | 396 | @staticmethod 397 | def _query_via_requests( 398 | index_name: str, 399 | k: int, 400 | columns: list[str], 401 | databricks_token: str, 402 | databricks_endpoint: str, 403 | query_type: str, 404 | query_text: str | None, 405 | query_vector: list[float] | None, 406 | filters_json: str | None, 407 | ) -> dict[str, Any]: 408 | """ 409 | Query a Databricks Vector Search Index via the Python requests library. 410 | 411 | Args: 412 | index_name (str): Name of the Databricks vector search index to query 413 | k (int): Number of relevant documents to retrieve. 414 | columns (list[str]): Column names to include in response. 415 | databricks_token (str): Databricks authentication token. 416 | databricks_endpoint (str): Databricks index endpoint url. 417 | query_text (Optional[str]): Text query for which to find relevant documents. Exactly 418 | one of query_text or query_vector must be specified. 419 | query_vector (Optional[list[float]]): Numeric query vector for which to find relevant 420 | documents. Exactly one of query_text or query_vector must be specified. 421 | filters_json (Optional[str]): JSON string representing additional query filters. 422 | 423 | Returns: 424 | dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query. 
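
        Example request payload (illustrative; mirrors the fields assembled into the request body below):

            {
                "columns": ["id", "text"],
                "num_results": 3,
                "query_type": "ANN",
                "query_text": "example query"
            }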
425 | """ 426 | if (query_text, query_vector).count(None) != 1: 427 | raise ValueError("Exactly one of query_text or query_vector must be specified.") 428 | 429 | headers = { 430 | "Authorization": f"Bearer {databricks_token}", 431 | "Content-Type": "application/json", 432 | } 433 | payload = { 434 | "columns": columns, 435 | "num_results": k, 436 | "query_type": query_type, 437 | } 438 | if filters_json is not None: 439 | payload["filters_json"] = filters_json 440 | if query_text is not None: 441 | payload["query_text"] = query_text 442 | elif query_vector is not None: 443 | payload["query_vector"] = query_vector 444 | response = requests.post( 445 | f"{databricks_endpoint}/api/2.0/vector-search/indexes/{index_name}/query", 446 | json=payload, 447 | headers=headers, 448 | ) 449 | results = response.json() 450 | if "error_code" in results: 451 | raise Exception(f"ERROR: {results['error_code']} -- {results['message']}") 452 | return results 453 | ``` -------------------------------------------------------------------------------- /tests/primitives/test_base_module.py: -------------------------------------------------------------------------------- ```python 1 | import asyncio 2 | import logging 3 | import os 4 | import threading 5 | from unittest.mock import patch 6 | 7 | import pytest 8 | from litellm import Choices, Message, ModelResponse 9 | from litellm.types.utils import Usage 10 | 11 | import dspy 12 | from dspy.primitives.prediction import Prediction 13 | from dspy.utils.dummies import DummyLM 14 | 15 | 16 | def test_deepcopy_basic(): 17 | signature = dspy.Signature("q -> a") 18 | cot = dspy.ChainOfThought(signature) 19 | cot_copy = cot.deepcopy() 20 | assert len(cot.parameters()) == len(cot_copy.parameters()) 21 | # Parameters should be different objects with the same values. 22 | assert id(cot.parameters()[0]) != id(cot_copy.parameters()[0]) 23 | assert cot.parameters()[0].__dict__ == cot_copy.parameters()[0].__dict__ 24 | 25 | 26 | def test_deepcopy_with_uncopyable_modules(): 27 | class CustomClass(dspy.Module): 28 | def __init__(self): 29 | self.lock = threading.Lock() # Non-copyable object. 30 | self.cot = dspy.ChainOfThought(dspy.Signature("q -> a")) 31 | 32 | model = CustomClass() 33 | model_copy = model.deepcopy() 34 | assert len(model.parameters()) == len(model_copy.parameters()) 35 | # The lock should be refer to the same object (shallow copy). 36 | assert id(model.lock) == id(model_copy.lock) 37 | # Parameters should be different objects with the same values. 38 | assert id(model.parameters()[0]) != id(model_copy.parameters()[0]) 39 | assert model.parameters()[0].__dict__ == model_copy.parameters()[0].__dict__ 40 | 41 | 42 | def test_deepcopy_with_nested_modules(): 43 | class CustomClass1(dspy.Module): 44 | def __init__(self): 45 | self.lock = threading.Lock() # Non-copyable object. 46 | self.cot = dspy.ChainOfThought(dspy.Signature("q -> a")) 47 | 48 | class CustomClass2(dspy.Module): 49 | def __init__(self): 50 | self.submodel = CustomClass1() 51 | 52 | model = CustomClass2() 53 | model_copy = model.deepcopy() 54 | assert len(model.parameters()) == len(model_copy.parameters()) 55 | # The lock should be refer to the same object (shallow copy). 56 | assert id(model.submodel.lock) == id(model_copy.submodel.lock) 57 | # Parameters should be different objects with the same values. 
58 | assert id(model.parameters()[0]) != id(model_copy.parameters()[0]) 59 | assert model.parameters()[0].__dict__ == model_copy.parameters()[0].__dict__ 60 | 61 | 62 | def test_save_and_load_with_json(tmp_path): 63 | model = dspy.ChainOfThought(dspy.Signature("q -> a")) 64 | model.predict.signature = model.predict.signature.with_instructions("You are a helpful assistant.") 65 | model.predict.demos = [ 66 | dspy.Example(q="What is the capital of France?", a="Paris", reasoning="n/a").with_inputs("q"), 67 | # Nested example 68 | dspy.Example( 69 | q=[ 70 | dspy.Example(q="What is the capital of France?"), 71 | dspy.Example(q="What is actually the capital of France?"), 72 | ], 73 | a="Paris", 74 | reasoning="n/a", 75 | ).with_inputs("q"), 76 | ] 77 | save_path = tmp_path / "model.json" 78 | model.save(save_path) 79 | new_model = dspy.ChainOfThought(dspy.Signature("q -> a")) 80 | new_model.load(save_path) 81 | 82 | assert str(new_model.predict.signature) == str(model.predict.signature) 83 | assert new_model.predict.demos[0] == model.predict.demos[0].toDict() 84 | assert new_model.predict.demos[1] == model.predict.demos[1].toDict() 85 | 86 | 87 | @pytest.mark.extra 88 | def test_save_and_load_with_pkl(tmp_path): 89 | import datetime 90 | 91 | # `datetime.date` is not json serializable, so we need to save with pickle. 92 | class MySignature(dspy.Signature): 93 | """Just a custom signature.""" 94 | 95 | current_date: datetime.date = dspy.InputField() 96 | target_date: datetime.date = dspy.InputField() 97 | date_diff: int = dspy.OutputField(desc="The difference in days between the current_date and the target_date") 98 | 99 | trainset = [ 100 | {"current_date": datetime.date(2024, 1, 1), "target_date": datetime.date(2024, 1, 2), "date_diff": 1}, 101 | {"current_date": datetime.date(2024, 1, 1), "target_date": datetime.date(2024, 1, 3), "date_diff": 2}, 102 | {"current_date": datetime.date(2024, 1, 1), "target_date": datetime.date(2024, 1, 4), "date_diff": 3}, 103 | {"current_date": datetime.date(2024, 1, 1), "target_date": datetime.date(2024, 1, 5), "date_diff": 4}, 104 | {"current_date": datetime.date(2024, 1, 1), "target_date": datetime.date(2024, 1, 6), "date_diff": 5}, 105 | ] 106 | trainset = [dspy.Example(**example).with_inputs("current_date", "target_date") for example in trainset] 107 | 108 | dspy.settings.configure( 109 | lm=DummyLM([{"date_diff": "1", "reasoning": "n/a"}, {"date_diff": "2", "reasoning": "n/a"}] * 10) 110 | ) 111 | 112 | cot = dspy.ChainOfThought(MySignature) 113 | cot(current_date=datetime.date(2024, 1, 1), target_date=datetime.date(2024, 1, 2)) 114 | 115 | def dummy_metric(example, pred, trace=None): 116 | return True 117 | 118 | optimizer = dspy.BootstrapFewShot(max_bootstrapped_demos=4, max_labeled_demos=4, max_rounds=5, metric=dummy_metric) 119 | compiled_cot = optimizer.compile(cot, trainset=trainset) 120 | compiled_cot.predict.signature = compiled_cot.predict.signature.with_instructions("You are a helpful assistant.") 121 | 122 | save_path = tmp_path / "program.pkl" 123 | compiled_cot.save(save_path) 124 | 125 | new_cot = dspy.ChainOfThought(MySignature) 126 | new_cot.load(save_path) 127 | 128 | assert str(new_cot.predict.signature) == str(compiled_cot.predict.signature) 129 | assert new_cot.predict.demos == compiled_cot.predict.demos 130 | 131 | 132 | def test_save_with_extra_modules(tmp_path): 133 | import sys 134 | 135 | # Create a temporary Python file with our custom module 136 | custom_module_path = tmp_path / "custom_module.py" 137 | with 
open(custom_module_path, "w") as f: 138 | f.write(""" 139 | import dspy 140 | 141 | class MyModule(dspy.Module): 142 | def __init__(self): 143 | self.cot = dspy.ChainOfThought(dspy.Signature("q -> a")) 144 | 145 | def forward(self, q): 146 | return self.cot(q=q) 147 | """) 148 | 149 | # Add the tmp_path to Python path so we can import the module 150 | sys.path.insert(0, str(tmp_path)) 151 | try: 152 | import custom_module 153 | 154 | cot = custom_module.MyModule() 155 | 156 | cot.save(tmp_path, save_program=True) 157 | # Remove the custom module from sys.modules to simulate it not being available 158 | sys.modules.pop("custom_module", None) 159 | # Also remove it from sys.path 160 | sys.path.remove(str(tmp_path)) 161 | del custom_module 162 | 163 | # Test the loading fails without using `modules_to_serialize` 164 | with pytest.raises(ModuleNotFoundError): 165 | dspy.load(tmp_path) 166 | 167 | sys.path.insert(0, str(tmp_path)) 168 | import custom_module 169 | 170 | cot.save( 171 | tmp_path, 172 | modules_to_serialize=[custom_module], 173 | save_program=True, 174 | ) 175 | 176 | # Remove the custom module from sys.modules to simulate it not being available 177 | sys.modules.pop("custom_module", None) 178 | # Also remove it from sys.path 179 | sys.path.remove(str(tmp_path)) 180 | del custom_module 181 | 182 | loaded_module = dspy.load(tmp_path) 183 | assert loaded_module.cot.predict.signature == cot.cot.predict.signature 184 | 185 | finally: 186 | # Only need to clean up sys.path 187 | if str(tmp_path) in sys.path: 188 | sys.path.remove(str(tmp_path)) 189 | 190 | 191 | def test_load_with_version_mismatch(tmp_path): 192 | from dspy.primitives.base_module import logger 193 | 194 | # Mock versions during save 195 | save_versions = {"python": "3.9", "dspy": "2.4.0", "cloudpickle": "2.0"} 196 | 197 | # Mock versions during load 198 | load_versions = {"python": "3.10", "dspy": "2.5.0", "cloudpickle": "2.1"} 199 | 200 | predict = dspy.Predict("question->answer") 201 | 202 | # Create a custom handler to capture log messages 203 | class ListHandler(logging.Handler): 204 | def __init__(self): 205 | super().__init__() 206 | self.messages = [] 207 | 208 | def emit(self, record): 209 | self.messages.append(record.getMessage()) 210 | 211 | # Add handler and set level 212 | handler = ListHandler() 213 | original_level = logger.level 214 | logger.addHandler(handler) 215 | logger.setLevel(logging.WARNING) 216 | 217 | try: 218 | save_path = tmp_path / "program.pkl" 219 | # Mock version during save 220 | with patch("dspy.primitives.base_module.get_dependency_versions", return_value=save_versions): 221 | predict.save(save_path) 222 | 223 | # Mock version during load 224 | with patch("dspy.primitives.base_module.get_dependency_versions", return_value=load_versions): 225 | loaded_predict = dspy.Predict("question->answer") 226 | loaded_predict.load(save_path) 227 | 228 | # Assert warnings were logged, and one warning for each mismatched dependency. 
229 | assert len(handler.messages) == 3 230 | 231 | for msg in handler.messages: 232 | assert "There is a mismatch of" in msg 233 | 234 | # Verify the model still loads correctly despite version mismatches 235 | assert isinstance(loaded_predict, dspy.Predict) 236 | assert str(predict.signature) == str(loaded_predict.signature) 237 | 238 | finally: 239 | # Clean up: restore original level and remove handler 240 | logger.setLevel(original_level) 241 | logger.removeHandler(handler) 242 | 243 | 244 | @pytest.mark.llm_call 245 | def test_single_module_call_with_usage_tracker(lm_for_test): 246 | dspy.settings.configure(lm=dspy.LM(lm_for_test, cache=False), track_usage=True) 247 | 248 | predict = dspy.ChainOfThought("question -> answer") 249 | output = predict(question="What is the capital of France?") 250 | 251 | lm_usage = output.get_lm_usage() 252 | assert len(lm_usage) == 1 253 | assert lm_usage[lm_for_test]["prompt_tokens"] > 0 254 | assert lm_usage[lm_for_test]["completion_tokens"] > 0 255 | assert lm_usage[lm_for_test]["total_tokens"] > 0 256 | 257 | # Test no usage being tracked when cache is enabled 258 | dspy.settings.configure(lm=dspy.LM(lm_for_test, cache=True), track_usage=True) 259 | for _ in range(2): 260 | output = predict(question="What is the capital of France?") 261 | 262 | assert len(output.get_lm_usage()) == 0 263 | 264 | 265 | @pytest.mark.llm_call 266 | def test_multi_module_call_with_usage_tracker(lm_for_test): 267 | dspy.settings.configure(lm=dspy.LM(lm_for_test, cache=False), track_usage=True) 268 | 269 | class MyProgram(dspy.Module): 270 | def __init__(self): 271 | self.predict1 = dspy.ChainOfThought("question -> answer") 272 | self.predict2 = dspy.ChainOfThought("question, answer -> score") 273 | 274 | def __call__(self, question: str) -> Prediction: 275 | answer = self.predict1(question=question) 276 | score = self.predict2(question=question, answer=answer) 277 | return score 278 | 279 | program = MyProgram() 280 | output = program(question="What is the capital of France?") 281 | 282 | lm_usage = output.get_lm_usage() 283 | assert len(lm_usage) == 1 284 | assert lm_usage[lm_for_test]["prompt_tokens"] > 0 285 | assert lm_usage[lm_for_test]["prompt_tokens"] > 0 286 | assert lm_usage[lm_for_test]["completion_tokens"] > 0 287 | assert lm_usage[lm_for_test]["total_tokens"] > 0 288 | 289 | 290 | # TODO: prepare second model for testing this unit test in ci 291 | @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Skip the test if OPENAI_API_KEY is not set.") 292 | def test_usage_tracker_in_parallel(): 293 | class MyProgram(dspy.Module): 294 | def __init__(self, lm): 295 | self.lm = lm 296 | self.predict1 = dspy.ChainOfThought("question -> answer") 297 | self.predict2 = dspy.ChainOfThought("question, answer -> score") 298 | 299 | def __call__(self, question: str) -> Prediction: 300 | with dspy.settings.context(lm=self.lm): 301 | answer = self.predict1(question=question) 302 | score = self.predict2(question=question, answer=answer) 303 | return score 304 | 305 | dspy.settings.configure(track_usage=True) 306 | program1 = MyProgram(lm=dspy.LM("openai/gpt-4o-mini", cache=False)) 307 | program2 = MyProgram(lm=dspy.LM("openai/gpt-3.5-turbo", cache=False)) 308 | 309 | parallelizer = dspy.Parallel() 310 | 311 | results = parallelizer( 312 | [ 313 | (program1, {"question": "What is the meaning of life?"}), 314 | (program2, {"question": "why did a chicken cross the kitchen?"}), 315 | ] 316 | ) 317 | 318 | assert results[0].get_lm_usage() is not None 319 | assert 
results[1].get_lm_usage() is not None 320 | 321 | assert results[0].get_lm_usage().keys() == set(["openai/gpt-4o-mini"]) 322 | assert results[1].get_lm_usage().keys() == set(["openai/gpt-3.5-turbo"]) 323 | 324 | 325 | @pytest.mark.asyncio 326 | async def test_usage_tracker_async_parallel(): 327 | program = dspy.Predict("question -> answer") 328 | 329 | with patch("litellm.acompletion") as mock_completion: 330 | mock_completion.return_value = ModelResponse( 331 | choices=[Choices(message=Message(content="{'answer': 'Paris'}"))], 332 | usage=Usage( 333 | **{ 334 | "prompt_tokens": 1117, 335 | "completion_tokens": 46, 336 | "total_tokens": 1163, 337 | "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, 338 | "completion_tokens_details": { 339 | "reasoning_tokens": 0, 340 | "audio_tokens": 0, 341 | "accepted_prediction_tokens": 0, 342 | "rejected_prediction_tokens": 0, 343 | }, 344 | }, 345 | ), 346 | model="openai/gpt-4o-mini", 347 | ) 348 | 349 | coroutines = [ 350 | program.acall(question="What is the capital of France?"), 351 | program.acall(question="What is the capital of France?"), 352 | program.acall(question="What is the capital of France?"), 353 | program.acall(question="What is the capital of France?"), 354 | ] 355 | with dspy.settings.context( 356 | lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True, adapter=dspy.JSONAdapter() 357 | ): 358 | results = await asyncio.gather(*coroutines) 359 | 360 | assert results[0].get_lm_usage() is not None 361 | assert results[1].get_lm_usage() is not None 362 | 363 | lm_usage0 = results[0].get_lm_usage()["openai/gpt-4o-mini"] 364 | lm_usage1 = results[1].get_lm_usage()["openai/gpt-4o-mini"] 365 | assert lm_usage0["prompt_tokens"] == 1117 366 | assert lm_usage1["prompt_tokens"] == 1117 367 | assert lm_usage0["completion_tokens"] == 46 368 | assert lm_usage1["completion_tokens"] == 46 369 | assert lm_usage0["total_tokens"] == 1163 370 | assert lm_usage1["total_tokens"] == 1163 371 | 372 | 373 | def test_usage_tracker_no_side_effect(): 374 | class MyProgram(dspy.Module): 375 | def __init__(self): 376 | self.predict = dspy.Predict("question -> answer") 377 | 378 | def forward(self, question: str, **kwargs) -> str: 379 | return self.predict(question=question).answer 380 | 381 | program = MyProgram() 382 | with dspy.context(lm=DummyLM([{"answer": "Paris"}]), track_usage=True): 383 | result = program(question="What is the capital of France?") 384 | assert result == "Paris" 385 | 386 | 387 | def test_module_history(): 388 | class MyProgram(dspy.Module): 389 | def __init__(self, **kwargs): 390 | super().__init__(**kwargs) 391 | self.cot = dspy.ChainOfThought("question -> answer") 392 | 393 | def forward(self, question: str, **kwargs) -> Prediction: 394 | return self.cot(question=question) 395 | 396 | with patch("litellm.completion") as mock_completion: 397 | mock_completion.return_value = ModelResponse( 398 | choices=[ 399 | Choices(message=Message(content="{'reasoning': 'Paris is the capital of France', 'answer': 'Paris'}")) 400 | ], 401 | model="openai/gpt-4o-mini", 402 | ) 403 | dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()) 404 | program = MyProgram() 405 | program(question="What is the capital of France?") 406 | 407 | # Second call only call the submodule. 408 | program.cot(question="What is the capital of France?") 409 | 410 | # The LM history entity exists in all the ancestor callers. 
411 | assert len(program.history) == 1 412 | assert len(program.cot.history) == 2 413 | assert len(program.cot.predict.history) == 2 414 | 415 | # The same history entity is shared across all the ancestor callers to reduce memory usage. 416 | assert id(program.history[0]) == id(program.cot.history[0]) 417 | 418 | assert program.history[0]["outputs"] == ["{'reasoning': 'Paris is the capital of France', 'answer': 'Paris'}"] 419 | 420 | dspy.settings.configure(disable_history=True) 421 | 422 | program(question="What is the capital of France?") 423 | # No history is recorded when history is disabled. 424 | assert len(program.history) == 1 425 | assert len(program.cot.history) == 2 426 | assert len(program.cot.predict.history) == 2 427 | 428 | dspy.settings.configure(disable_history=False) 429 | 430 | program(question="What is the capital of France?") 431 | # History is recorded again when history is enabled. 432 | assert len(program.history) == 2 433 | assert len(program.cot.history) == 3 434 | assert len(program.cot.predict.history) == 3 435 | 436 | 437 | def test_module_history_with_concurrency(): 438 | class MyProgram(dspy.Module): 439 | def __init__(self): 440 | super().__init__() 441 | self.cot = dspy.ChainOfThought("question -> answer") 442 | 443 | def forward(self, question: str, **kwargs) -> Prediction: 444 | return self.cot(question=question) 445 | 446 | with patch("litellm.completion") as mock_completion: 447 | mock_completion.return_value = ModelResponse( 448 | choices=[Choices(message=Message(content="{'reasoning': 'N/A', 'answer': 'Holy crab!'}"))], 449 | model="openai/gpt-4o-mini", 450 | ) 451 | dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()) 452 | program = MyProgram() 453 | 454 | parallelizer = dspy.Parallel() 455 | 456 | parallelizer( 457 | [ 458 | (program, {"question": "What is the meaning of life?"}), 459 | (program, {"question": "why did a chicken cross the kitchen?"}), 460 | ] 461 | ) 462 | assert len(program.history) == 2 463 | assert len(program.cot.history) == 2 464 | assert len(program.cot.predict.history) == 2 465 | 466 | 467 | @pytest.mark.asyncio 468 | async def test_module_history_async(): 469 | class MyProgram(dspy.Module): 470 | def __init__(self, **kwargs): 471 | super().__init__(**kwargs) 472 | self.cot = dspy.ChainOfThought("question -> answer") 473 | 474 | async def aforward(self, question: str, **kwargs) -> Prediction: 475 | return await self.cot.acall(question=question) 476 | 477 | with patch("litellm.acompletion") as mock_completion: 478 | mock_completion.return_value = ModelResponse( 479 | choices=[ 480 | Choices(message=Message(content="{'reasoning': 'Paris is the capital of France', 'answer': 'Paris'}")) 481 | ], 482 | model="openai/gpt-4o-mini", 483 | ) 484 | program = MyProgram() 485 | with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()): 486 | await program.acall(question="What is the capital of France?") 487 | 488 | # Second call only call the submodule. 489 | await program.cot.acall(question="What is the capital of France?") 490 | 491 | # The LM history entity exists in all the ancestor callers. 492 | assert len(program.history) == 1 493 | assert len(program.cot.history) == 2 494 | assert len(program.cot.predict.history) == 2 495 | 496 | # The same history entity is shared across all the ancestor callers to reduce memory usage. 
497 | assert id(program.history[0]) == id(program.cot.history[0]) 498 | 499 | assert program.history[0]["outputs"] == ["{'reasoning': 'Paris is the capital of France', 'answer': 'Paris'}"] 500 | 501 | with dspy.context( 502 | disable_history=True, lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter() 503 | ): 504 | await program.acall(question="What is the capital of France?") 505 | 506 | # No history is recorded when history is disabled. 507 | assert len(program.history) == 1 508 | assert len(program.cot.history) == 2 509 | assert len(program.cot.predict.history) == 2 510 | 511 | with dspy.context( 512 | disable_history=False, lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter() 513 | ): 514 | await program.acall(question="What is the capital of France?") 515 | # History is recorded again when history is enabled. 516 | assert len(program.history) == 2 517 | assert len(program.cot.history) == 3 518 | assert len(program.cot.predict.history) == 3 519 | 520 | 521 | def test_forward_direct_call_warning(capsys): 522 | class TestModule(dspy.Module): 523 | def forward(self, x): 524 | return x 525 | 526 | module = TestModule() 527 | module.forward("test") 528 | captured = capsys.readouterr() 529 | assert "directly is discouraged" in captured.err 530 | 531 | 532 | def test_forward_through_call_no_warning(capsys): 533 | class TestModule(dspy.Module): 534 | def forward(self, x): 535 | return x 536 | 537 | module = TestModule() 538 | module(x="test") 539 | captured = capsys.readouterr() 540 | assert "directly is discouraged" not in captured.err 541 | ``` -------------------------------------------------------------------------------- /dspy/clients/lm_local_arbor.py: -------------------------------------------------------------------------------- ```python 1 | import time 2 | from datetime import datetime 3 | from typing import TYPE_CHECKING, Any, TypedDict 4 | from urllib.parse import urljoin 5 | 6 | import openai 7 | import requests 8 | 9 | import dspy 10 | from dspy.clients.provider import Provider, ReinforceJob, TrainingJob 11 | from dspy.clients.utils_finetune import GRPOGroup, MultiGPUConfig, TrainDataFormat, TrainingStatus, save_data 12 | 13 | if TYPE_CHECKING: 14 | from dspy.clients.lm import LM 15 | 16 | 17 | class GRPOTrainKwargs(TypedDict): 18 | num_generations: int 19 | 20 | 21 | class ArborTrainingJob(TrainingJob): 22 | def __init__(self, *args, **kwargs): 23 | super().__init__(*args, **kwargs) 24 | self.provider_file_id = None 25 | self.provider_job_id = None 26 | 27 | def cancel(self): 28 | if ArborProvider.does_job_exist(self.provider_job_id): 29 | status = self.status() 30 | if ArborProvider.is_terminal_training_status(status): 31 | err_msg = "Jobs that are complete cannot be canceled." 32 | err_msg += f" Job with ID {self.provider_job_id} is done." 
33 | raise Exception(err_msg) 34 | openai.fine_tuning.jobs.cancel(self.provider_job_id) 35 | self.provider_job_id = None 36 | 37 | if self.provider_file_id is not None: 38 | if ArborProvider.does_file_exist(self.provider_file_id): 39 | openai.files.delete(self.provider_file_id) 40 | self.provider_file_id = None 41 | 42 | super().cancel() 43 | 44 | def status(self) -> TrainingStatus: 45 | status = ArborProvider.get_training_status(self.provider_job_id) 46 | return status 47 | 48 | 49 | class ArborReinforceJob(ReinforceJob): 50 | DEFAULT_TRAIN_KWARGS = { # noqa: RUF012 51 | "temperature": 0.9, 52 | "beta": 0.04, 53 | "num_iterations": 1, 54 | "per_device_train_batch_size": 8, 55 | "learning_rate": 1e-6, 56 | "gradient_accumulation_steps": 1, 57 | # This is false by default in TRL, but I think it makes sense to be true for us 58 | "gradient_checkpointing": True, 59 | "lr_scheduler_type": "constant_with_warmup", 60 | "max_prompt_length": None, 61 | "max_completion_length": None, 62 | "gradient_checkpointing_kwargs": None, 63 | "bf16": False, 64 | "scale_rewards": True, 65 | "max_grad_norm": 1.0, 66 | "report_to": "none", 67 | "log_completions": True, 68 | "logging_steps": 100, 69 | # By default, none is the model's max context length 70 | "max_context_length": None, 71 | "lora": False, 72 | } 73 | 74 | def __init__(self, lm: "LM", train_kwargs: GRPOTrainKwargs, gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)): 75 | # The teleprompter must ensure that this is set 76 | if "num_generations" not in train_kwargs: 77 | raise ValueError("num_generations must be set in the training kwargs") 78 | 79 | self.lm = lm 80 | self.train_kwargs = train_kwargs 81 | self.provider_job_id = None 82 | self.checkpoints = {} 83 | self.last_checkpoint = None 84 | self.gpu_config = gpu_config 85 | 86 | def initialize(self): 87 | # TODO(GRPO Team): Set provider job ID 88 | num_generations = self.train_kwargs.get("num_generations") 89 | temperature = self.train_kwargs.get("temperature", self.DEFAULT_TRAIN_KWARGS["temperature"]) 90 | beta = self.train_kwargs.get("beta", self.DEFAULT_TRAIN_KWARGS["beta"]) 91 | num_iterations = self.train_kwargs.get("num_iterations", self.DEFAULT_TRAIN_KWARGS["num_iterations"]) 92 | per_device_train_batch_size = self.train_kwargs.get( 93 | "per_device_train_batch_size", self.DEFAULT_TRAIN_KWARGS["per_device_train_batch_size"] 94 | ) 95 | learning_rate = self.train_kwargs.get("learning_rate", self.DEFAULT_TRAIN_KWARGS["learning_rate"]) 96 | gradient_accumulation_steps = self.train_kwargs.get( 97 | "gradient_accumulation_steps", self.DEFAULT_TRAIN_KWARGS["gradient_accumulation_steps"] 98 | ) 99 | gradient_checkpointing = self.train_kwargs.get( 100 | "gradient_checkpointing", self.DEFAULT_TRAIN_KWARGS["gradient_checkpointing"] 101 | ) 102 | lr_scheduler_type = self.train_kwargs.get("lr_scheduler_type", self.DEFAULT_TRAIN_KWARGS["lr_scheduler_type"]) 103 | max_prompt_length = self.train_kwargs.get("max_prompt_length", self.DEFAULT_TRAIN_KWARGS["max_prompt_length"]) 104 | max_completion_length = self.train_kwargs.get( 105 | "max_completion_length", self.DEFAULT_TRAIN_KWARGS["max_completion_length"] 106 | ) 107 | bf16 = self.train_kwargs.get("bf16", self.DEFAULT_TRAIN_KWARGS["bf16"]) 108 | scale_rewards = self.train_kwargs.get("scale_rewards", self.DEFAULT_TRAIN_KWARGS["scale_rewards"]) 109 | gradient_checkpointing_kwargs = self.train_kwargs.get( 110 | "gradient_checkpointing_kwargs", self.DEFAULT_TRAIN_KWARGS["gradient_checkpointing_kwargs"] 111 | ) 112 | 
max_grad_norm = self.train_kwargs.get("max_grad_norm", self.DEFAULT_TRAIN_KWARGS["max_grad_norm"]) 113 | report_to = self.train_kwargs.get("report_to", self.DEFAULT_TRAIN_KWARGS["report_to"]) 114 | log_completions = self.train_kwargs.get("log_completions", self.DEFAULT_TRAIN_KWARGS["log_completions"]) 115 | logging_steps = self.train_kwargs.get("logging_steps", self.DEFAULT_TRAIN_KWARGS["logging_steps"]) 116 | max_context_length = self.train_kwargs.get( 117 | "max_context_length", self.DEFAULT_TRAIN_KWARGS["max_context_length"] 118 | ) 119 | lora = self.train_kwargs.get("lora", self.DEFAULT_TRAIN_KWARGS["lora"]) 120 | api_base = self.lm.kwargs["api_base"] 121 | 122 | finetune_model = ArborProvider._remove_provider_prefix(self.lm.model) 123 | # Only multi-GPU is supported for now 124 | gpu_config_type = "multi" 125 | data = { 126 | "model": finetune_model, 127 | "num_generations": num_generations, 128 | "temperature": temperature, 129 | "beta": beta, 130 | "num_iterations": num_iterations, 131 | "per_device_train_batch_size": per_device_train_batch_size, 132 | "learning_rate": learning_rate, 133 | "gradient_accumulation_steps": gradient_accumulation_steps, 134 | "gradient_checkpointing": gradient_checkpointing, 135 | "lr_scheduler_type": lr_scheduler_type, 136 | "max_prompt_length": max_prompt_length, 137 | "max_completion_length": max_completion_length, 138 | "bf16": bf16, 139 | "scale_rewards": scale_rewards, 140 | "gradient_checkpointing_kwargs": gradient_checkpointing_kwargs, 141 | "max_grad_norm": max_grad_norm, 142 | "report_to": report_to, 143 | "log_completions": log_completions, 144 | "logging_steps": logging_steps, 145 | "max_context_length": max_context_length, 146 | "lora": lora, 147 | "gpu_config": { 148 | "type": gpu_config_type, 149 | gpu_config_type: self.gpu_config, 150 | }, 151 | } 152 | url = urljoin(api_base, "fine_tuning/grpo/initialize") 153 | headers = {"Content-Type": "application/json"} 154 | response = requests.post(url=url, headers=headers, json=data) 155 | assert response.status_code == 200, f"Failed to initialize GRPO: {response}" 156 | response = response.json() 157 | self.lm.model = ArborProvider._add_provider_prefix(response["current_model"]) 158 | self.provider_job_id = response.get("job_id") 159 | 160 | def _run_grpo_step_one_group( 161 | self, train_group: GRPOGroup, train_data_format: TrainDataFormat | str | None = None 162 | ): 163 | # TODO: Check that the data follows the intended format 164 | api_base = self.lm.kwargs["api_base"] 165 | # api_key = self.lm.kwargs["api_key"] 166 | 167 | finetune_model = ArborProvider._remove_provider_prefix(self.lm.model) 168 | data = {"job_id": self.provider_job_id, "model": finetune_model, "batch": train_group} 169 | url = urljoin(api_base, f"fine_tuning/grpo/{self.provider_job_id}/step") 170 | headers = {"Content-Type": "application/json"} 171 | response = requests.post(url, headers=headers, json=data) 172 | assert response.status_code == 200, f"Failed to run a GRPO step: {response.text}" 173 | response = response.json() 174 | assert "current_model" in response, f"Response does not contain the next model ID to be used: {response}" 175 | current_model = response["current_model"] 176 | self.lm.model = ArborProvider._add_provider_prefix(current_model) 177 | 178 | def step(self, train_data: list[GRPOGroup], train_data_format: TrainDataFormat | str | None): 179 | # Note: TrainDataFormat specifies the format for the inner most dict. 
180 | # Because we run GRPO at the group level, train_data will be a list of 181 | # groups, where each group is a list of GRPOChatData. Our teleprompters 182 | # ensure that we pass the right data format. 183 | # We can consider making this distinction clearer, e.g., by having two 184 | # different step methods or changing our smallets data format to be the 185 | # GRPO group. 186 | # TODO: Support step on the server side 187 | assert ( 188 | train_data_format == TrainDataFormat.GRPO_CHAT 189 | ), f"GRPO only supports the GRPO_CHAT data format. Got {train_data_format} instead." 190 | for group in train_data: 191 | self._run_grpo_step_one_group(group, train_data_format) 192 | 193 | def save_checkpoint(self, checkpoint_name: str, score: float | None = None): 194 | api_base = self.lm.kwargs["api_base"] 195 | url = urljoin(api_base, f"fine_tuning/grpo/{self.provider_job_id}/checkpoint") 196 | headers = {"Content-Type": "application/json"} 197 | body = {"job_id": self.provider_job_id, "checkpoint_name": checkpoint_name} 198 | response = requests.post(url, headers=headers, json=body) 199 | assert response.status_code == 200, f"Failed to save checkpoint: {response.text}" 200 | response = response.json() 201 | 202 | last_checkpoint = response["last_checkpoint"] 203 | checkpoints = response["checkpoints"] 204 | checkpoint_model_path = checkpoints[last_checkpoint] 205 | self.checkpoints[last_checkpoint] = { 206 | "model_path": checkpoint_model_path, 207 | "score": score, 208 | } 209 | self.last_checkpoint = last_checkpoint 210 | 211 | def terminate(self): 212 | api_base = self.lm.kwargs["api_base"] 213 | 214 | url = urljoin(api_base, f"fine_tuning/grpo/{self.provider_job_id}/terminate") 215 | headers = {"Content-Type": "application/json"} 216 | body = {"job_id": self.provider_job_id} 217 | response = requests.post(url, headers=headers, json=body) 218 | assert response.status_code == 200, f"Failed to terminate GRPO: {response.text}" 219 | 220 | response = response.json() 221 | current_model = response["current_model"] 222 | self.lm.model = ArborProvider._add_provider_prefix(current_model) 223 | 224 | def cancel(self): 225 | if self.provider_job_id: 226 | api_base = self.lm.kwargs["api_base"] 227 | url = urljoin(api_base, f"fine_tuning/grpo/{self.provider_job_id}/cancel") 228 | headers = {"Content-Type": "application/json"} 229 | response = requests.post(url, headers=headers) 230 | if response.status_code == 200: 231 | self.provider_job_id = None 232 | else: 233 | raise Exception(f"Failed to cancel GRPO job: {response.text}") 234 | 235 | def status(self) -> TrainingStatus: 236 | status = ArborProvider.get_training_status(self.provider_job_id) 237 | return status 238 | 239 | 240 | class ArborProvider(Provider): 241 | def __init__(self): 242 | super().__init__() 243 | self.finetunable = True 244 | self.reinforceable = True 245 | self.TrainingJob = ArborTrainingJob 246 | self.ReinforceJob = ArborReinforceJob 247 | 248 | @staticmethod 249 | def launch(lm: "LM", launch_kwargs: dict[str, Any] | None = None): 250 | model = ArborProvider._remove_provider_prefix(lm.model) 251 | 252 | api_base = lm.kwargs["api_base"] 253 | 254 | launch_kwargs = launch_kwargs or lm.launch_kwargs 255 | 256 | # Make request to launch endpoint 257 | response = requests.post(urljoin(api_base, "chat/launch"), json={"model": model, "launch_kwargs": launch_kwargs}) 258 | 259 | if response.status_code != 200: 260 | raise Exception(f"Failed to launch model. 
Status code: {response.status_code}, Response: {response.text}") 261 | 262 | print(f"Inference server for model {model} launched successfully") 263 | 264 | @staticmethod 265 | def kill(lm: "LM", launch_kwargs: dict[str, Any] | None = None): 266 | api_base = lm.kwargs["api_base"] 267 | 268 | response = requests.post( 269 | urljoin(api_base, "chat/kill"), 270 | ) 271 | 272 | if response.status_code != 200: 273 | raise Exception(f"Failed to kill model. Status code: {response.status_code}, Response: {response.text}") 274 | 275 | print("Inference killed successfully") 276 | 277 | @staticmethod 278 | def _remove_provider_prefix(model: str) -> str: 279 | if model.startswith("openai/"): 280 | model = model[7:] 281 | if model.startswith("arbor:"): 282 | model = model[6:] 283 | return model 284 | 285 | @staticmethod 286 | def _add_provider_prefix(model: str) -> str: 287 | if not model.startswith("openai/arbor:"): 288 | model = "openai/arbor:" + model 289 | return model 290 | 291 | @staticmethod 292 | def _get_arbor_base_api(): 293 | # TODO: We will delete this method once we start passing the LM object 294 | # to finetune. 295 | import dspy.settings as settings 296 | 297 | if not hasattr(settings, "arbor_api_base"): 298 | raise ValueError( 299 | "Arbor API base not set. Please set the `dspy.settings.arbor_api_base` to the URL for the Arbor server (e.g. 'http://localhost:8000/v1/')." 300 | ) 301 | return dspy.settings.arbor_api_base 302 | 303 | @staticmethod 304 | def finetune( 305 | job: ArborTrainingJob, 306 | model: str, 307 | train_data: list[dict[str, Any]], 308 | train_data_format: TrainDataFormat | None, 309 | train_kwargs: dict[str, Any] | None = None, 310 | ) -> str: 311 | # TODO: We want to re-factor finetune so that it takes in an LM. 312 | # Until then, we use the following to get the api information. The 313 | # following is a dummy call to ensure that dspy.settings.arbor_base_api 314 | # is set. 
315 | ArborProvider._get_arbor_base_api() 316 | 317 | model = ArborProvider._remove_provider_prefix(model) 318 | 319 | print("[Arbor Provider] Validating the data format") 320 | ArborProvider.validate_data_format(train_data_format) 321 | 322 | print("[Arbor Provider] Saving the data to a file") 323 | data_path = save_data(train_data) 324 | print(f"[Arbor Provider] Data saved to {data_path}") 325 | 326 | print("[Arbor Provider] Uploading the data to the provider") 327 | provider_file_id = ArborProvider.upload_data(data_path) 328 | job.provider_file_id = provider_file_id 329 | 330 | print("[Arbor Provider] Starting remote training") 331 | provider_job_id = ArborProvider._start_remote_training( 332 | train_file_id=job.provider_file_id, 333 | model=model, 334 | train_kwargs=train_kwargs, 335 | ) 336 | job.provider_job_id = provider_job_id 337 | print(f"[Arbor Provider] Job started with the Arbor Job ID {provider_job_id}") 338 | 339 | print("[Arbor Provider] Waiting for training to complete") 340 | ArborProvider.wait_for_job(job, train_kwargs) 341 | 342 | print("[Arbor Provider] Attempting to retrieve the trained model") 343 | model = ArborProvider.get_trained_model(job) 344 | print(f"[Arbor Provider] Model retrieved: {model}") 345 | 346 | return ArborProvider._add_provider_prefix(model) 347 | 348 | @staticmethod 349 | def does_job_exist(job_id: str, training_kwargs: dict[str, Any]) -> bool: 350 | try: 351 | original_base_url = openai.base_url 352 | openai.base_url = ArborProvider._get_arbor_base_api() 353 | openai.fine_tuning.jobs.retrieve(job_id) 354 | openai.base_url = original_base_url 355 | return True 356 | except Exception: 357 | return False 358 | 359 | @staticmethod 360 | def does_file_exist(file_id: str, training_kwargs: dict[str, Any]) -> bool: 361 | try: 362 | original_base_url = openai.base_url 363 | openai.base_url = ArborProvider._get_arbor_base_api() 364 | openai.files.retrieve(file_id) 365 | openai.base_url = original_base_url 366 | return True 367 | except Exception: 368 | return False 369 | 370 | @staticmethod 371 | def is_terminal_training_status(status: TrainingStatus) -> bool: 372 | return status in [ 373 | TrainingStatus.succeeded, 374 | TrainingStatus.failed, 375 | TrainingStatus.cancelled, 376 | ] 377 | 378 | @staticmethod 379 | def get_training_status(job_id: str, training_kwargs: dict[str, Any]) -> TrainingStatus: 380 | provider_status_to_training_status = { 381 | "validating_files": TrainingStatus.pending, 382 | "queued": TrainingStatus.pending, 383 | "running": TrainingStatus.running, 384 | "succeeded": TrainingStatus.succeeded, 385 | "failed": TrainingStatus.failed, 386 | "cancelled": TrainingStatus.cancelled, 387 | "pending": TrainingStatus.pending, 388 | "pending_pause": TrainingStatus.pending, 389 | "pending_resume": TrainingStatus.pending, 390 | "paused": TrainingStatus.pending, 391 | "pending_cancel": TrainingStatus.pending, 392 | } 393 | 394 | if job_id is None: 395 | print("There is no active job.") 396 | return TrainingStatus.not_started 397 | 398 | err_msg = f"Job with ID {job_id} does not exist." 
399 | assert ArborProvider.does_job_exist(job_id, training_kwargs), err_msg 400 | 401 | original_base_url = openai.base_url 402 | openai.base_url = ArborProvider._get_arbor_base_api() 403 | provider_job = openai.fine_tuning.jobs.retrieve(job_id) 404 | openai.base_url = original_base_url 405 | 406 | provider_status = provider_job.status 407 | status = provider_status_to_training_status[provider_status] 408 | 409 | return status 410 | 411 | @staticmethod 412 | def validate_data_format(data_format: TrainDataFormat): 413 | supported_data_formats = [ 414 | TrainDataFormat.CHAT, 415 | TrainDataFormat.COMPLETION, 416 | TrainDataFormat.GRPO_CHAT, 417 | ] 418 | 419 | if data_format not in supported_data_formats: 420 | err_msg = f"Arbor does not support the data format {data_format}." 421 | raise ValueError(err_msg) 422 | 423 | @staticmethod 424 | def upload_data(data_path: str, training_kwargs: dict[str, Any]) -> str: 425 | original_base_url = openai.base_url 426 | openai.base_url = ArborProvider._get_arbor_base_api() 427 | provider_file = openai.files.create( 428 | file=open(data_path, "rb"), 429 | purpose="fine-tune", 430 | ) 431 | openai.base_url = original_base_url 432 | 433 | return provider_file.id 434 | 435 | @staticmethod 436 | def _start_remote_training(train_file_id: str, model: str, train_kwargs: dict[str, Any]) -> str: 437 | train_kwargs = train_kwargs or {} 438 | original_base_url = openai.base_url 439 | openai.base_url = ArborProvider._get_arbor_base_api() 440 | provider_job = openai.fine_tuning.jobs.create( 441 | model=model, 442 | training_file=train_file_id, 443 | hyperparameters=train_kwargs, 444 | ) 445 | openai.base_url = original_base_url 446 | return provider_job.id 447 | 448 | @staticmethod 449 | def wait_for_job( 450 | job: TrainingJob, 451 | training_kwargs: dict[str, Any], 452 | poll_frequency: int = 20, 453 | ): 454 | done = False 455 | cur_event_id = None 456 | reported_estimated_time = False 457 | while not done: 458 | # Report estimated time if not already reported 459 | if not reported_estimated_time: 460 | original_base_url = openai.base_url 461 | openai.base_url = ArborProvider._get_arbor_base_api() 462 | remote_job = openai.fine_tuning.jobs.retrieve(job.provider_job_id) 463 | openai.base_url = original_base_url 464 | 465 | timestamp = remote_job.estimated_finish 466 | if timestamp: 467 | estimated_finish_dt = datetime.fromtimestamp(timestamp) 468 | delta_dt = estimated_finish_dt - datetime.now() 469 | print(f"[Arbor Provider] The Arbor estimated time remaining is: {delta_dt}") 470 | reported_estimated_time = True 471 | 472 | # Get new events 473 | original_base_url = openai.base_url 474 | openai.base_url = ArborProvider._get_arbor_base_api() 475 | page = openai.fine_tuning.jobs.list_events(fine_tuning_job_id=job.provider_job_id, limit=1) 476 | openai.base_url = original_base_url 477 | 478 | new_event = page.data[0] if page.data else None 479 | if new_event and new_event.id != cur_event_id: 480 | dt = datetime.fromtimestamp(new_event.created_at) 481 | print(f"[Arbor Provider] {dt} {new_event.message}") 482 | cur_event_id = new_event.id 483 | 484 | # Sleep and update the flag 485 | time.sleep(poll_frequency) 486 | done = ArborProvider.is_terminal_training_status(job.status()) 487 | 488 | @staticmethod 489 | def get_trained_model(job, training_kwargs: dict[str, Any]): 490 | status = job.status() 491 | if status != TrainingStatus.succeeded: 492 | err_msg = f"Job status is {status}." 493 | err_msg += f" Must be {TrainingStatus.succeeded} to retrieve model." 
494 | raise Exception(err_msg) 495 | 496 | original_base_url = openai.base_url 497 | openai.base_url = ArborProvider._get_arbor_base_api() 498 | provider_job = openai.fine_tuning.jobs.retrieve(job.provider_job_id) 499 | openai.base_url = original_base_url 500 | 501 | finetuned_model = provider_job.fine_tuned_model 502 | return finetuned_model 503 | ``` -------------------------------------------------------------------------------- /dspy/adapters/base.py: -------------------------------------------------------------------------------- ```python 1 | import logging 2 | from typing import TYPE_CHECKING, Any, get_origin 3 | 4 | import json_repair 5 | import litellm 6 | 7 | from dspy.adapters.types import History, Type 8 | from dspy.adapters.types.base_type import split_message_content_for_custom_types 9 | from dspy.adapters.types.tool import Tool, ToolCalls 10 | from dspy.experimental import Citations 11 | from dspy.signatures.signature import Signature 12 | from dspy.utils.callback import BaseCallback, with_callbacks 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | if TYPE_CHECKING: 17 | from dspy.clients.lm import LM 18 | 19 | _DEFAULT_NATIVE_RESPONSE_TYPES = [Citations] 20 | 21 | 22 | class Adapter: 23 | """Base Adapter class. 24 | 25 | The Adapter serves as the interface layer between DSPy module/signature and Language Models (LMs). It handles the 26 | complete transformation pipeline from DSPy inputs to LM calls and back to structured outputs. 27 | 28 | Key responsibilities: 29 | - Transform user inputs and signatures into properly formatted LM prompts, which also instructs the LM to format 30 | the response in a specific format. 31 | - Parse LM outputs into dictionaries matching the signature's output fields. 32 | - Enable/disable native LM features (function calling, citations, etc.) based on configuration. 33 | - Handle conversation history, few-shot examples, and custom type processing. 34 | 35 | The adapter pattern allows DSPy to work with different LM interfaces while maintaining a consistent programming 36 | model for users. 37 | """ 38 | 39 | def __init__( 40 | self, 41 | callbacks: list[BaseCallback] | None = None, 42 | use_native_function_calling: bool = False, 43 | native_response_types: list[type[Type]] | None = None, 44 | ): 45 | """ 46 | Args: 47 | callbacks: List of callback functions to execute during `format()` and `parse()` methods. Callbacks can be 48 | used for logging, monitoring, or custom processing. Defaults to None (empty list). 49 | use_native_function_calling: Whether to enable native function calling capabilities when the LM supports it. 50 | If True, the adapter will automatically configure function calling when input fields contain `dspy.Tool` 51 | or `list[dspy.Tool]` types. Defaults to False. 52 | native_response_types: List of output field types that should be handled by native LM features rather than 53 | adapter parsing. For example, `dspy.Citations` can be populated directly by citation APIs 54 | (e.g., Anthropic's citation feature). Defaults to `[Citations]`. 
55 | """ 56 | self.callbacks = callbacks or [] 57 | self.use_native_function_calling = use_native_function_calling 58 | self.native_response_types = native_response_types or _DEFAULT_NATIVE_RESPONSE_TYPES 59 | 60 | def __init_subclass__(cls, **kwargs) -> None: 61 | super().__init_subclass__(**kwargs) 62 | 63 | # Decorate format() and parse() method with with_callbacks 64 | cls.format = with_callbacks(cls.format) 65 | cls.parse = with_callbacks(cls.parse) 66 | 67 | def _call_preprocess( 68 | self, 69 | lm: "LM", 70 | lm_kwargs: dict[str, Any], 71 | signature: type[Signature], 72 | inputs: dict[str, Any], 73 | ) -> type[Signature]: 74 | if self.use_native_function_calling: 75 | tool_call_input_field_name = self._get_tool_call_input_field_name(signature) 76 | tool_call_output_field_name = self._get_tool_call_output_field_name(signature) 77 | 78 | if tool_call_output_field_name and tool_call_input_field_name is None: 79 | raise ValueError( 80 | f"You provided an output field {tool_call_output_field_name} to receive the tool calls information, " 81 | "but did not provide any tools as the input. Please provide a list of tools as the input by adding an " 82 | "input field with type `list[dspy.Tool]`." 83 | ) 84 | 85 | if tool_call_output_field_name and litellm.supports_function_calling(model=lm.model): 86 | tools = inputs[tool_call_input_field_name] 87 | tools = tools if isinstance(tools, list) else [tools] 88 | 89 | litellm_tools = [] 90 | for tool in tools: 91 | litellm_tools.append(tool.format_as_litellm_function_call()) 92 | 93 | lm_kwargs["tools"] = litellm_tools 94 | 95 | signature_for_native_function_calling = signature.delete(tool_call_output_field_name) 96 | signature_for_native_function_calling = signature_for_native_function_calling.delete( 97 | tool_call_input_field_name 98 | ) 99 | 100 | return signature_for_native_function_calling 101 | 102 | # Handle custom types that use native response 103 | for name, field in signature.output_fields.items(): 104 | if ( 105 | isinstance(field.annotation, type) 106 | and issubclass(field.annotation, Type) 107 | and field.annotation in self.native_response_types 108 | ): 109 | signature = signature.delete(name) 110 | 111 | return signature 112 | 113 | def _call_postprocess( 114 | self, 115 | processed_signature: type[Signature], 116 | original_signature: type[Signature], 117 | outputs: list[dict[str, Any]], 118 | lm: "LM", 119 | ) -> list[dict[str, Any]]: 120 | values = [] 121 | 122 | tool_call_output_field_name = self._get_tool_call_output_field_name(original_signature) 123 | 124 | for output in outputs: 125 | output_logprobs = None 126 | tool_calls = None 127 | text = output 128 | 129 | if isinstance(output, dict): 130 | text = output["text"] 131 | output_logprobs = output.get("logprobs") 132 | tool_calls = output.get("tool_calls") 133 | 134 | if text: 135 | value = self.parse(processed_signature, text) 136 | for field_name in original_signature.output_fields.keys(): 137 | if field_name not in value: 138 | # We need to set the field not present in the processed signature to None for consistency. 
139 |                         value[field_name] = None
140 |             else:
141 |                 value = {}
142 |                 for field_name in original_signature.output_fields.keys():
143 |                     value[field_name] = None
144 | 
145 |             if tool_calls and tool_call_output_field_name:
146 |                 tool_calls = [
147 |                     {
148 |                         "name": v["function"]["name"],
149 |                         "args": json_repair.loads(v["function"]["arguments"]),
150 |                     }
151 |                     for v in tool_calls
152 |                 ]
153 |                 value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls)
154 | 
155 |             # Parse custom types that do not rely on the adapter parsing
156 |             for name, field in original_signature.output_fields.items():
157 |                 if (
158 |                     isinstance(field.annotation, type)
159 |                     and issubclass(field.annotation, Type)
160 |                     and field.annotation in self.native_response_types
161 |                 ):
162 |                     value[name] = field.annotation.parse_lm_response(output)
163 | 
164 |             if output_logprobs:
165 |                 value["logprobs"] = output_logprobs
166 | 
167 |             values.append(value)
168 | 
169 |         return values
170 | 
171 |     def __call__(
172 |         self,
173 |         lm: "LM",
174 |         lm_kwargs: dict[str, Any],
175 |         signature: type[Signature],
176 |         demos: list[dict[str, Any]],
177 |         inputs: dict[str, Any],
178 |     ) -> list[dict[str, Any]]:
179 |         """
180 |         Execute the adapter pipeline: format inputs, call LM, and parse outputs.
181 | 
182 |         Args:
183 |             lm: The Language Model instance to use for generation. Must be an instance of `dspy.BaseLM`.
184 |             lm_kwargs: Additional keyword arguments to pass to the LM call (e.g., temperature, max_tokens). These are
185 |                 passed directly to the LM.
186 |             signature: The DSPy signature associated with this LM call.
187 |             demos: List of few-shot examples to include in the prompt. Each dictionary should contain keys matching the
188 |                 signature's input and output field names. Examples are formatted as user/assistant message pairs.
189 |             inputs: The current input values for this call. Keys must match the signature's input field names.
190 | 
191 |         Returns:
192 |             List of dictionaries representing parsed LM responses. Each dictionary contains keys matching the
193 |                 signature's output field names. For multiple generations (n > 1), returns multiple dictionaries.
194 |         """
195 |         processed_signature = self._call_preprocess(lm, lm_kwargs, signature, inputs)
196 |         inputs = self.format(processed_signature, demos, inputs)
197 | 
198 |         outputs = lm(messages=inputs, **lm_kwargs)
199 |         return self._call_postprocess(processed_signature, signature, outputs, lm)
200 | 
201 |     async def acall(
202 |         self,
203 |         lm: "LM",
204 |         lm_kwargs: dict[str, Any],
205 |         signature: type[Signature],
206 |         demos: list[dict[str, Any]],
207 |         inputs: dict[str, Any],
208 |     ) -> list[dict[str, Any]]:
209 |         processed_signature = self._call_preprocess(lm, lm_kwargs, signature, inputs)
210 |         inputs = self.format(processed_signature, demos, inputs)
211 | 
212 |         outputs = await lm.acall(messages=inputs, **lm_kwargs)
213 |         return self._call_postprocess(processed_signature, signature, outputs, lm)
214 | 
215 |     def format(
216 |         self,
217 |         signature: type[Signature],
218 |         demos: list[dict[str, Any]],
219 |         inputs: dict[str, Any],
220 |     ) -> list[dict[str, Any]]:
221 |         """Format the input messages for the LM call.
222 | 
223 |         This method converts the DSPy structured input along with few-shot examples and conversation history into
224 |         multiturn messages as expected by the LM. For custom adapters, this method can be overridden to customize
225 |         the formatting of the input messages.
226 | 227 | In general we recommend the messages to have the following structure: 228 | ``` 229 | [ 230 | {"role": "system", "content": system_message}, 231 | # Begin few-shot examples 232 | {"role": "user", "content": few_shot_example_1_input}, 233 | {"role": "assistant", "content": few_shot_example_1_output}, 234 | {"role": "user", "content": few_shot_example_2_input}, 235 | {"role": "assistant", "content": few_shot_example_2_output}, 236 | ... 237 | # End few-shot examples 238 | # Begin conversation history 239 | {"role": "user", "content": conversation_history_1_input}, 240 | {"role": "assistant", "content": conversation_history_1_output}, 241 | {"role": "user", "content": conversation_history_2_input}, 242 | {"role": "assistant", "content": conversation_history_2_output}, 243 | ... 244 | # End conversation history 245 | {"role": "user", "content": current_input}, 246 | ] 247 | 248 | And system message should contain the field description, field structure, and task description. 249 | ``` 250 | 251 | 252 | Args: 253 | signature: The DSPy signature for which to format the input messages. 254 | demos: A list of few-shot examples. 255 | inputs: The input arguments to the DSPy module. 256 | 257 | Returns: 258 | A list of multiturn messages as expected by the LM. 259 | """ 260 | inputs_copy = dict(inputs) 261 | 262 | # If the signature and inputs have conversation history, we need to format the conversation history and 263 | # remove the history field from the signature. 264 | history_field_name = self._get_history_field_name(signature) 265 | if history_field_name: 266 | # In order to format the conversation history, we need to remove the history field from the signature. 267 | signature_without_history = signature.delete(history_field_name) 268 | conversation_history = self.format_conversation_history( 269 | signature_without_history, 270 | history_field_name, 271 | inputs_copy, 272 | ) 273 | 274 | messages = [] 275 | system_message = ( 276 | f"{self.format_field_description(signature)}\n" 277 | f"{self.format_field_structure(signature)}\n" 278 | f"{self.format_task_description(signature)}" 279 | ) 280 | messages.append({"role": "system", "content": system_message}) 281 | messages.extend(self.format_demos(signature, demos)) 282 | if history_field_name: 283 | # Conversation history and current input 284 | content = self.format_user_message_content(signature_without_history, inputs_copy, main_request=True) 285 | messages.extend(conversation_history) 286 | messages.append({"role": "user", "content": content}) 287 | else: 288 | # Only current input 289 | content = self.format_user_message_content(signature, inputs_copy, main_request=True) 290 | messages.append({"role": "user", "content": content}) 291 | 292 | messages = split_message_content_for_custom_types(messages) 293 | return messages 294 | 295 | def format_field_description(self, signature: type[Signature]) -> str: 296 | """Format the field description for the system message. 297 | 298 | This method formats the field description for the system message. It should return a string that contains 299 | the field description for the input fields and the output fields. 300 | 301 | Args: 302 | signature: The DSPy signature for which to format the field description. 303 | 304 | Returns: 305 | A string that contains the field description for the input fields and the output fields. 
306 | """ 307 | raise NotImplementedError 308 | 309 | def format_field_structure(self, signature: type[Signature]) -> str: 310 | """Format the field structure for the system message. 311 | 312 | This method formats the field structure for the system message. It should return a string that dictates the 313 | format the input fields should be provided to the LM, and the format the output fields will be in the response. 314 | Refer to the ChatAdapter and JsonAdapter for an example. 315 | 316 | Args: 317 | signature: The DSPy signature for which to format the field structure. 318 | """ 319 | raise NotImplementedError 320 | 321 | def format_task_description(self, signature: type[Signature]) -> str: 322 | """Format the task description for the system message. 323 | 324 | This method formats the task description for the system message. In most cases this is just a thin wrapper 325 | over `signature.instructions`. 326 | 327 | Args: 328 | signature: The DSPy signature of the DSpy module. 329 | 330 | Returns: 331 | A string that describes the task. 332 | """ 333 | raise NotImplementedError 334 | 335 | def format_user_message_content( 336 | self, 337 | signature: type[Signature], 338 | inputs: dict[str, Any], 339 | prefix: str = "", 340 | suffix: str = "", 341 | main_request: bool = False, 342 | ) -> str: 343 | """Format the user message content. 344 | 345 | This method formats the user message content, which can be used in formatting few-shot examples, conversation 346 | history, and the current input. 347 | 348 | Args: 349 | signature: The DSPy signature for which to format the user message content. 350 | inputs: The input arguments to the DSPy module. 351 | prefix: A prefix to the user message content. 352 | suffix: A suffix to the user message content. 353 | 354 | Returns: 355 | A string that contains the user message content. 356 | """ 357 | raise NotImplementedError 358 | 359 | def format_assistant_message_content( 360 | self, 361 | signature: type[Signature], 362 | outputs: dict[str, Any], 363 | missing_field_message: str | None = None, 364 | ) -> str: 365 | """Format the assistant message content. 366 | 367 | This method formats the assistant message content, which can be used in formatting few-shot examples, 368 | conversation history. 369 | 370 | Args: 371 | signature: The DSPy signature for which to format the assistant message content. 372 | outputs: The output fields to be formatted. 373 | missing_field_message: A message to be used when a field is missing. 374 | 375 | Returns: 376 | A string that contains the assistant message content. 377 | """ 378 | raise NotImplementedError 379 | 380 | def format_demos(self, signature: type[Signature], demos: list[dict[str, Any]]) -> list[dict[str, Any]]: 381 | """Format the few-shot examples. 382 | 383 | This method formats the few-shot examples as multiturn messages. 384 | 385 | Args: 386 | signature: The DSPy signature for which to format the few-shot examples. 387 | demos: A list of few-shot examples, each element is a dictionary with keys of the input and output fields of 388 | the signature. 389 | 390 | Returns: 391 | A list of multiturn messages. 
392 |         """
393 |         complete_demos = []
394 |         incomplete_demos = []
395 | 
396 |         for demo in demos:
397 |             # Check if all fields are present and not None
398 |             is_complete = all(k in demo and demo[k] is not None for k in signature.fields)
399 | 
400 |             # Check if demo has at least one input and one output field
401 |             has_input = any(k in demo for k in signature.input_fields)
402 |             has_output = any(k in demo for k in signature.output_fields)
403 | 
404 |             if is_complete:
405 |                 complete_demos.append(demo)
406 |             elif has_input and has_output:
407 |                 # We only keep incomplete demos that have at least one input and one output field
408 |                 incomplete_demos.append(demo)
409 | 
410 |         messages = []
411 | 
412 |         incomplete_demo_prefix = "This is an example of the task, though some input or output fields are not supplied."
413 |         for demo in incomplete_demos:
414 |             messages.append(
415 |                 {
416 |                     "role": "user",
417 |                     "content": self.format_user_message_content(signature, demo, prefix=incomplete_demo_prefix),
418 |                 }
419 |             )
420 |             messages.append(
421 |                 {
422 |                     "role": "assistant",
423 |                     "content": self.format_assistant_message_content(
424 |                         signature, demo, missing_field_message="Not supplied for this particular example. "
425 |                     ),
426 |                 }
427 |             )
428 | 
429 |         for demo in complete_demos:
430 |             messages.append({"role": "user", "content": self.format_user_message_content(signature, demo)})
431 |             messages.append(
432 |                 {
433 |                     "role": "assistant",
434 |                     "content": self.format_assistant_message_content(
435 |                         signature, demo, missing_field_message="Not supplied for this conversation history message. "
436 |                     ),
437 |                 }
438 |             )
439 | 
440 |         return messages
441 | 
442 |     def _get_history_field_name(self, signature: type[Signature]) -> str | None:
443 |         for name, field in signature.input_fields.items():
444 |             if field.annotation == History:
445 |                 return name
446 |         return None
447 | 
448 |     def _get_tool_call_input_field_name(self, signature: type[Signature]) -> str | None:
449 |         for name, field in signature.input_fields.items():
450 |             # Look for annotation `list[dspy.Tool]` or `dspy.Tool`
451 |             origin = get_origin(field.annotation)
452 |             if origin is list and field.annotation.__args__[0] == Tool:
453 |                 return name
454 |             if field.annotation == Tool:
455 |                 return name
456 |         return None
457 | 
458 |     def _get_tool_call_output_field_name(self, signature: type[Signature]) -> str | None:
459 |         for name, field in signature.output_fields.items():
460 |             if field.annotation == ToolCalls:
461 |                 return name
462 |         return None
463 | 
464 |     def format_conversation_history(
465 |         self,
466 |         signature: type[Signature],
467 |         history_field_name: str,
468 |         inputs: dict[str, Any],
469 |     ) -> list[dict[str, Any]]:
470 |         """Format the conversation history.
471 | 
472 |         This method formats the conversation history and the current input as multiturn messages.
473 | 
474 |         Args:
475 |             signature: The DSPy signature for which to format the conversation history.
476 |             history_field_name: The name of the history field in the signature.
477 |             inputs: The input arguments to the DSPy module.
478 | 
479 |         Returns:
480 |             A list of multiturn messages.
481 | """ 482 | conversation_history = inputs[history_field_name].messages if history_field_name in inputs else None 483 | 484 | if conversation_history is None: 485 | return [] 486 | 487 | messages = [] 488 | for message in conversation_history: 489 | messages.append( 490 | { 491 | "role": "user", 492 | "content": self.format_user_message_content(signature, message), 493 | } 494 | ) 495 | messages.append( 496 | { 497 | "role": "assistant", 498 | "content": self.format_assistant_message_content(signature, message), 499 | } 500 | ) 501 | 502 | # Remove the history field from the inputs 503 | del inputs[history_field_name] 504 | 505 | return messages 506 | 507 | def parse(self, signature: type[Signature], completion: str) -> dict[str, Any]: 508 | """Parse the LM output into a dictionary of the output fields. 509 | 510 | This method parses the LM output into a dictionary of the output fields. 511 | 512 | Args: 513 | signature: The DSPy signature for which to parse the LM output. 514 | completion: The LM output to be parsed. 515 | 516 | Returns: 517 | A dictionary of the output fields. 518 | """ 519 | raise NotImplementedError 520 | ```
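
The `Adapter` base class above defines the hooks a concrete adapter must implement: `format_field_description`, `format_field_structure`, `format_task_description`, `format_user_message_content`, `format_assistant_message_content`, and `parse`. The sketch below is not part of the repository; it is a minimal, hypothetical subclass that fills in those hooks with a naive `field: value` convention, only to make the contract concrete. The shipped `ChatAdapter` and `JSONAdapter` remain the reference implementations.

```python
# Minimal sketch of a custom adapter built on the hooks in dspy/adapters/base.py.
# The class and its formatting convention are illustrative, not library code.
from typing import Any

from dspy.adapters.base import Adapter
from dspy.signatures.signature import Signature


class TinyAdapter(Adapter):
    """Formats each field as a `field: value` line and parses lines of the same shape."""

    def format_field_description(self, signature: type[Signature]) -> str:
        inputs = ", ".join(signature.input_fields.keys())
        outputs = ", ".join(signature.output_fields.keys())
        return f"Inputs: {inputs}. Outputs: {outputs}."

    def format_field_structure(self, signature: type[Signature]) -> str:
        return "Respond with one `field: value` line per output field."

    def format_task_description(self, signature: type[Signature]) -> str:
        # Thin wrapper over the signature's instructions, as the base docstring suggests.
        return signature.instructions

    def format_user_message_content(
        self, signature, inputs, prefix="", suffix="", main_request=False
    ) -> str:
        lines = [f"{name}: {inputs.get(name, '')}" for name in signature.input_fields]
        return "\n".join([prefix, *lines, suffix]).strip()

    def format_assistant_message_content(
        self, signature, outputs, missing_field_message=None
    ) -> str:
        lines = [
            f"{name}: {outputs.get(name, missing_field_message or '')}"
            for name in signature.output_fields
        ]
        return "\n".join(lines)

    def parse(self, signature: type[Signature], completion: str) -> dict[str, Any]:
        values: dict[str, Any] = {}
        for line in completion.splitlines():
            name, _, value = line.partition(":")
            if name.strip() in signature.output_fields:
                values[name.strip()] = value.strip()
        return values
```

Because the base class's `format()` assembles the system message, few-shot demos, conversation history, and current request from these hooks, a subclass like this inherits the full message pipeline without reimplementing it.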
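
For orientation, here is a small driver for the `__call__` pipeline documented above (preprocess, format, LM call, postprocess). The model name and the `question -> answer` signature are placeholders, and the snippet assumes `dspy.Signature` accepts the string shorthand used elsewhere in the repository; treat it as a sketch rather than canonical usage.

```python
# Illustrative only: drive an adapter end to end, mirroring Adapter.__call__.
import dspy

lm = dspy.LM("openai/gpt-4o-mini")                 # placeholder model name
adapter = dspy.JSONAdapter()                       # concrete Adapter subclass
signature = dspy.Signature("question -> answer")   # string-form signature (assumed shorthand)

# Inspect the formatted multiturn messages (system message plus the current user turn).
messages = adapter.format(signature, demos=[], inputs={"question": "What is the capital of France?"})

# Run the full pipeline: _call_preprocess -> format -> lm(...) -> _call_postprocess.
outputs = adapter(lm, {}, signature, demos=[], inputs={"question": "What is the capital of France?"})
print(outputs[0]["answer"])
```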
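
Stepping back to `dspy/clients/lm_local_arbor.py` earlier on this page: the GRPO job it implements follows an initialize, step, checkpoint, terminate lifecycle against the Arbor server's `fine_tuning/grpo/...` endpoints. The sketch below is a hypothetical driver of that lifecycle; the model name, API base, API key, and empty training batch are placeholders, and in practice the GRPO teleprompter constructs both the job and the `GRPOGroup` batches.

```python
# Hypothetical driver for ArborReinforceJob; all concrete values are placeholders.
import dspy
from dspy.clients.lm_local_arbor import ArborReinforceJob
from dspy.clients.utils_finetune import MultiGPUConfig, TrainDataFormat

lm = dspy.LM(
    "openai/arbor:example-base-model",        # placeholder; ArborProvider adds/strips this prefix
    api_base="http://localhost:8000/v1/",     # Arbor server URL (placeholder)
    api_key="arbor",                          # placeholder
)

job = ArborReinforceJob(
    lm,
    train_kwargs={"num_generations": 8},      # required; other keys fall back to DEFAULT_TRAIN_KWARGS
    gpu_config=MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1),
)

job.initialize()                              # POST fine_tuning/grpo/initialize; updates lm.model
groups = []                                   # in practice, GRPOGroup batches built by the teleprompter
job.step(train_data=groups, train_data_format=TrainDataFormat.GRPO_CHAT)
job.save_checkpoint("checkpoint-1", score=0.5)  # POST .../checkpoint; records model path and score
job.terminate()                               # POST .../terminate; swaps lm.model to the final model
```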
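
Finally, the module-history tests at the top of this page encode a behavior worth stating in user-facing terms: every module in the call chain appends the same history entry object, and `disable_history` suppresses recording. A minimal illustration is sketched below; the model name is a placeholder, and the identity check between a bare `ChainOfThought` and its inner `predict` is an extrapolation from the parent/child sharing the tests assert.

```python
# Illustrative sketch of shared module history; not part of the test suite.
import dspy

dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"), adapter=dspy.JSONAdapter())

program = dspy.ChainOfThought("question -> answer")
program(question="What is the capital of France?")

# The wrapper module and its inner Predict share the same history entry object.
assert program.history[0] is program.predict.history[0]

# Recording can be turned off for a scope via the context manager.
with dspy.context(disable_history=True):
    program(question="What is the capital of Germany?")
assert len(program.history) == 1
```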