This is page 9 of 14. Use http://codebase.md/stanfordnlp/dspy?page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── .internal_dspyai │ │ ├── internals │ │ │ ├── build-and-release.md │ │ │ └── release-checklist.md │ │ └── pyproject.toml │ ├── .tmp │ │ └── .generated-actions │ │ └── run-pypi-publish-in-docker-container │ │ └── action.yml │ ├── ISSUE_TEMPLATE │ │ ├── bug_report.yml │ │ └── feature_request.yml │ ├── PULL_REQUEST_TEMPLATE │ │ └── pull_request_template.md │ ├── workflow_scripts │ │ └── install_testpypi_pkg.sh │ └── workflows │ ├── build_and_release.yml │ ├── build_utils │ │ └── test_version.py │ ├── docs-push.yml │ ├── precommits_check.yml │ └── run_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── docs │ ├── .gitignore │ ├── docs │ │ ├── api │ │ │ ├── adapters │ │ │ │ ├── Adapter.md │ │ │ │ ├── ChatAdapter.md │ │ │ │ ├── JSONAdapter.md │ │ │ │ └── TwoStepAdapter.md │ │ │ ├── evaluation │ │ │ │ ├── answer_exact_match.md │ │ │ │ ├── answer_passage_match.md │ │ │ │ ├── CompleteAndGrounded.md │ │ │ │ ├── Evaluate.md │ │ │ │ ├── EvaluationResult.md │ │ │ │ └── SemanticF1.md │ │ │ ├── experimental │ │ │ │ ├── Citations.md │ │ │ │ └── Document.md │ │ │ ├── index.md │ │ │ ├── models │ │ │ │ ├── Embedder.md │ │ │ │ └── LM.md │ │ │ ├── modules │ │ │ │ ├── BestOfN.md │ │ │ │ ├── ChainOfThought.md │ │ │ │ ├── CodeAct.md │ │ │ │ ├── Module.md │ │ │ │ ├── MultiChainComparison.md │ │ │ │ ├── Parallel.md │ │ │ │ ├── Predict.md │ │ │ │ ├── ProgramOfThought.md │ │ │ │ ├── ReAct.md │ │ │ │ └── Refine.md │ │ │ ├── optimizers │ │ │ │ ├── BetterTogether.md │ │ │ │ ├── BootstrapFewShot.md │ │ │ │ ├── BootstrapFewShotWithRandomSearch.md │ │ │ │ ├── BootstrapFinetune.md │ │ │ │ ├── BootstrapRS.md │ │ │ │ ├── COPRO.md │ │ │ │ ├── Ensemble.md │ │ │ │ ├── GEPA │ │ │ │ │ ├── GEPA_Advanced.md │ │ │ │ │ └── overview.md │ │ │ │ ├── InferRules.md │ │ │ │ ├── KNN.md │ │ │ │ ├── KNNFewShot.md │ │ │ │ ├── LabeledFewShot.md │ │ │ │ ├── MIPROv2.md │ │ │ │ └── SIMBA.md │ │ │ ├── primitives │ │ │ │ ├── Audio.md │ │ │ │ ├── Code.md │ │ │ │ ├── Example.md │ │ │ │ ├── History.md │ │ │ │ ├── Image.md │ │ │ │ ├── Prediction.md │ │ │ │ ├── Tool.md │ │ │ │ └── ToolCalls.md │ │ │ ├── signatures │ │ │ │ ├── InputField.md │ │ │ │ ├── OutputField.md │ │ │ │ └── Signature.md │ │ │ ├── tools │ │ │ │ ├── ColBERTv2.md │ │ │ │ ├── Embeddings.md │ │ │ │ └── PythonInterpreter.md │ │ │ └── utils │ │ │ ├── asyncify.md │ │ │ ├── configure_cache.md │ │ │ ├── disable_litellm_logging.md │ │ │ ├── disable_logging.md │ │ │ ├── enable_litellm_logging.md │ │ │ ├── enable_logging.md │ │ │ ├── inspect_history.md │ │ │ ├── load.md │ │ │ ├── StatusMessage.md │ │ │ ├── StatusMessageProvider.md │ │ │ ├── streamify.md │ │ │ └── StreamListener.md │ │ ├── cheatsheet.md │ │ ├── community │ │ │ ├── community-resources.md │ │ │ ├── how-to-contribute.md │ │ │ └── use-cases.md │ │ ├── deep-dive │ │ │ └── data-handling │ │ │ ├── built-in-datasets.md │ │ │ ├── examples.md │ │ │ ├── img │ │ │ │ └── data-loading.png │ │ │ └── loading-custom-data.md │ │ ├── faqs.md │ │ ├── index.md │ │ ├── js │ │ │ └── runllm-widget.js │ │ ├── learn │ │ │ ├── evaluation │ │ │ │ ├── data.md │ │ │ │ ├── metrics.md │ │ │ │ └── overview.md │ │ │ ├── figures │ │ │ │ ├── native_tool_call.png │ │ │ │ └── teleprompter-classes.png │ │ │ ├── index.md │ │ │ ├── optimization │ │ │ │ ├── optimizers.md │ │ │ │ └── overview.md │ │ │ └── programming │ │ │ ├── 7-assertions.md │ │ │ ├── adapters.md │ │ │ ├── language_models.md │ │ │ ├── mcp.md │ 
│ │ ├── modules.md │ │ │ ├── overview.md │ │ │ ├── signatures.md │ │ │ └── tools.md │ │ ├── production │ │ │ └── index.md │ │ ├── roadmap.md │ │ ├── static │ │ │ ├── .nojekyll │ │ │ └── img │ │ │ ├── dspy_logo.png │ │ │ ├── logo.png │ │ │ ├── mlflow-tracing-rag.png │ │ │ ├── modular.png │ │ │ ├── optimize.png │ │ │ ├── undraw_docusaurus_mountain.svg │ │ │ ├── undraw_docusaurus_react.svg │ │ │ ├── undraw_docusaurus_tree.svg │ │ │ └── universal_compatibility.png │ │ ├── stylesheets │ │ │ └── extra.css │ │ └── tutorials │ │ ├── agents │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── ai_text_game │ │ │ └── index.md │ │ ├── async │ │ │ └── index.md │ │ ├── audio │ │ │ └── index.ipynb │ │ ├── build_ai_program │ │ │ └── index.md │ │ ├── cache │ │ │ └── index.md │ │ ├── classification │ │ │ └── index.md │ │ ├── classification_finetuning │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-classification.png │ │ ├── conversation_history │ │ │ └── index.md │ │ ├── core_development │ │ │ └── index.md │ │ ├── custom_module │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-custom-module.png │ │ ├── customer_service_agent │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-customer-service-agent.png │ │ ├── deployment │ │ │ ├── dspy_mlflow_ui.png │ │ │ └── index.md │ │ ├── email_extraction │ │ │ ├── index.md │ │ │ └── mlflow-tracing-email-extraction.png │ │ ├── entity_extraction │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-entity-extraction.png │ │ ├── games │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── gepa_ai_program │ │ │ └── index.md │ │ ├── gepa_aime │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-aime.png │ │ │ └── mlflow-tracking-gepa-aime-optimization.png │ │ ├── gepa_facilitysupportanalyzer │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-support.png │ │ │ └── mlflow-tracking-gepa-support-optimization.png │ │ ├── gepa_papillon │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-papilon.png │ │ │ └── mlflow-tracking-gepa-papilon-optimization.png │ │ ├── image_generation_prompting │ │ │ └── index.ipynb │ │ ├── index.md │ │ ├── llms_txt_generation │ │ │ └── index.md │ │ ├── math │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-math.png │ │ ├── mcp │ │ │ └── index.md │ │ ├── mem0_react_agent │ │ │ └── index.md │ │ ├── multihop_search │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-multi-hop.png │ │ ├── observability │ │ │ ├── index.md │ │ │ ├── mlflow_trace_ui_navigation.gif │ │ │ ├── mlflow_trace_ui.png │ │ │ └── mlflow_trace_view.png │ │ ├── optimize_ai_program │ │ │ └── index.md │ │ ├── optimizer_tracking │ │ │ ├── child_run.png │ │ │ ├── experiment.png │ │ │ ├── index.md │ │ │ └── parent_run.png │ │ ├── output_refinement │ │ │ └── best-of-n-and-refine.md │ │ ├── papillon │ │ │ └── index.md │ │ ├── program_of_thought │ │ │ └── index.ipynb │ │ ├── rag │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-rag.png │ │ ├── real_world_examples │ │ │ └── index.md │ │ ├── rl_ai_program │ │ │ └── index.md │ │ ├── rl_multihop │ │ │ └── index.ipynb │ │ ├── rl_papillon │ │ │ └── index.ipynb │ │ ├── sample_code_generation │ │ │ └── index.md │ │ ├── saving │ │ │ └── index.md │ │ ├── streaming │ │ │ └── index.md │ │ ├── tool_use │ │ │ └── index.ipynb │ │ └── yahoo_finance_react │ │ └── index.md │ ├── mkdocs.yml │ ├── overrides │ │ ├── home.html │ │ ├── main.html │ │ └── partials │ │ └── tabs.html │ ├── Pipfile │ ├── Pipfile.lock │ ├── README.md │ ├── requirements.txt │ ├── scripts │ │ ├── generate_api_docs.py │ │ └── generate_api_summary.py │ └── vercel.json ├── dspy │ ├── __init__.py │ 
├── __metadata__.py │ ├── adapters │ │ ├── __init__.py │ │ ├── baml_adapter.py │ │ ├── base.py │ │ ├── chat_adapter.py │ │ ├── json_adapter.py │ │ ├── two_step_adapter.py │ │ ├── types │ │ │ ├── __init__.py │ │ │ ├── audio.py │ │ │ ├── base_type.py │ │ │ ├── citation.py │ │ │ ├── code.py │ │ │ ├── document.py │ │ │ ├── history.py │ │ │ ├── image.py │ │ │ └── tool.py │ │ ├── utils.py │ │ └── xml_adapter.py │ ├── clients │ │ ├── __init__.py │ │ ├── base_lm.py │ │ ├── cache.py │ │ ├── databricks.py │ │ ├── embedding.py │ │ ├── lm_local_arbor.py │ │ ├── lm_local.py │ │ ├── lm.py │ │ ├── openai.py │ │ ├── provider.py │ │ └── utils_finetune.py │ ├── datasets │ │ ├── __init__.py │ │ ├── alfworld │ │ │ ├── __init__.py │ │ │ ├── alfworld.py │ │ │ └── base_config.yml │ │ ├── colors.py │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── gsm8k.py │ │ ├── hotpotqa.py │ │ └── math.py │ ├── dsp │ │ ├── __init__.py │ │ ├── colbertv2.py │ │ └── utils │ │ ├── __init__.py │ │ ├── dpr.py │ │ ├── settings.py │ │ └── utils.py │ ├── evaluate │ │ ├── __init__.py │ │ ├── auto_evaluation.py │ │ ├── evaluate.py │ │ └── metrics.py │ ├── experimental │ │ └── __init__.py │ ├── predict │ │ ├── __init__.py │ │ ├── aggregation.py │ │ ├── avatar │ │ │ ├── __init__.py │ │ │ ├── avatar.py │ │ │ ├── models.py │ │ │ └── signatures.py │ │ ├── best_of_n.py │ │ ├── chain_of_thought.py │ │ ├── code_act.py │ │ ├── knn.py │ │ ├── multi_chain_comparison.py │ │ ├── parallel.py │ │ ├── parameter.py │ │ ├── predict.py │ │ ├── program_of_thought.py │ │ ├── react.py │ │ ├── refine.py │ │ └── retry.py │ ├── primitives │ │ ├── __init__.py │ │ ├── base_module.py │ │ ├── example.py │ │ ├── module.py │ │ ├── prediction.py │ │ ├── python_interpreter.py │ │ └── runner.js │ ├── propose │ │ ├── __init__.py │ │ ├── dataset_summary_generator.py │ │ ├── grounded_proposer.py │ │ ├── propose_base.py │ │ └── utils.py │ ├── retrievers │ │ ├── __init__.py │ │ ├── databricks_rm.py │ │ ├── embeddings.py │ │ ├── retrieve.py │ │ └── weaviate_rm.py │ ├── signatures │ │ ├── __init__.py │ │ ├── field.py │ │ ├── signature.py │ │ └── utils.py │ ├── streaming │ │ ├── __init__.py │ │ ├── messages.py │ │ ├── streamify.py │ │ └── streaming_listener.py │ ├── teleprompt │ │ ├── __init__.py │ │ ├── avatar_optimizer.py │ │ ├── bettertogether.py │ │ ├── bootstrap_finetune.py │ │ ├── bootstrap_trace.py │ │ ├── bootstrap.py │ │ ├── copro_optimizer.py │ │ ├── ensemble.py │ │ ├── gepa │ │ │ ├── __init__.py │ │ │ ├── gepa_utils.py │ │ │ ├── gepa.py │ │ │ └── instruction_proposal.py │ │ ├── grpo.py │ │ ├── infer_rules.py │ │ ├── knn_fewshot.py │ │ ├── mipro_optimizer_v2.py │ │ ├── random_search.py │ │ ├── signature_opt.py │ │ ├── simba_utils.py │ │ ├── simba.py │ │ ├── teleprompt_optuna.py │ │ ├── teleprompt.py │ │ ├── utils.py │ │ └── vanilla.py │ └── utils │ ├── __init__.py │ ├── annotation.py │ ├── asyncify.py │ ├── caching.py │ ├── callback.py │ ├── dummies.py │ ├── exceptions.py │ ├── hasher.py │ ├── inspect_history.py │ ├── langchain_tool.py │ ├── logging_utils.py │ ├── mcp.py │ ├── parallelizer.py │ ├── saving.py │ ├── syncify.py │ ├── unbatchify.py │ └── usage_tracker.py ├── LICENSE ├── pyproject.toml ├── README.md ├── tests │ ├── __init__.py │ ├── adapters │ │ ├── test_adapter_utils.py │ │ ├── test_baml_adapter.py │ │ ├── test_base_type.py │ │ ├── test_chat_adapter.py │ │ ├── test_citation.py │ │ ├── test_code.py │ │ ├── test_document.py │ │ ├── test_json_adapter.py │ │ ├── test_tool.py │ │ ├── test_two_step_adapter.py │ │ └── test_xml_adapter.py │ ├── callback │ │ └── 
test_callback.py │ ├── clients │ │ ├── test_cache.py │ │ ├── test_databricks.py │ │ ├── test_embedding.py │ │ ├── test_inspect_global_history.py │ │ └── test_lm.py │ ├── conftest.py │ ├── datasets │ │ └── test_dataset.py │ ├── docs │ │ └── test_mkdocs_links.py │ ├── evaluate │ │ ├── test_evaluate.py │ │ └── test_metrics.py │ ├── examples │ │ └── test_baleen.py │ ├── metadata │ │ └── test_metadata.py │ ├── predict │ │ ├── test_aggregation.py │ │ ├── test_best_of_n.py │ │ ├── test_chain_of_thought.py │ │ ├── test_code_act.py │ │ ├── test_knn.py │ │ ├── test_multi_chain_comparison.py │ │ ├── test_parallel.py │ │ ├── test_predict.py │ │ ├── test_program_of_thought.py │ │ ├── test_react.py │ │ ├── test_refine.py │ │ └── test_retry.py │ ├── primitives │ │ ├── resources │ │ │ └── saved_program.json │ │ ├── test_base_module.py │ │ ├── test_example.py │ │ ├── test_module.py │ │ └── test_python_interpreter.py │ ├── propose │ │ └── test_grounded_proposer.py │ ├── README.md │ ├── reliability │ │ ├── __init__.py │ │ ├── complex_types │ │ │ └── generated │ │ │ ├── test_many_types_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ ├── test_nesting_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ └── test_nesting_2 │ │ │ ├── inputs │ │ │ │ └── input1.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── conftest.py │ │ ├── generate │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ └── utils.py │ │ ├── input_formats │ │ │ └── generated │ │ │ └── test_markdown_1 │ │ │ ├── inputs │ │ │ │ ├── input1.json │ │ │ │ └── input2.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── README.md │ │ ├── reliability_conf.yaml │ │ ├── test_generated.py │ │ ├── test_pydantic_models.py │ │ └── utils.py │ ├── retrievers │ │ └── test_embeddings.py │ ├── signatures │ │ ├── test_adapter_image.py │ │ ├── test_custom_types.py │ │ └── test_signature.py │ ├── streaming │ │ └── test_streaming.py │ ├── teleprompt │ │ ├── gepa_dummy_lm_custom_component_selector_custom_instruction_proposer.json │ │ ├── gepa_dummy_lm.json │ │ ├── test_bootstrap_finetune.py │ │ ├── test_bootstrap_trace.py │ │ ├── test_bootstrap.py │ │ ├── test_copro_optimizer.py │ │ ├── test_ensemble.py │ │ ├── test_finetune.py │ │ ├── test_gepa_instruction_proposer.py │ │ ├── test_gepa.py │ │ ├── test_grpo.py │ │ ├── test_knn_fewshot.py │ │ ├── test_random_search.py │ │ ├── test_teleprompt.py │ │ └── test_utils.py │ ├── test_utils │ │ ├── __init__.py │ │ └── server │ │ ├── __init__.py │ │ ├── litellm_server_config.yaml │ │ └── litellm_server.py │ └── utils │ ├── __init__.py │ ├── resources │ │ └── mcp_server.py │ ├── test_annotation.py │ ├── test_asyncify.py │ ├── test_exceptions.py │ ├── test_langchain_tool.py │ ├── test_mcp.py │ ├── test_parallelizer.py │ ├── test_saving.py │ ├── test_settings.py │ ├── test_syncify.py │ ├── test_unbatchify.py │ └── test_usage_tracker.py └── uv.lock ``` # Files -------------------------------------------------------------------------------- /dspy/teleprompt/copro_optimizer.py: -------------------------------------------------------------------------------- ```python import logging from collections import defaultdict import dspy from dspy.evaluate.evaluate import Evaluate from dspy.signatures import Signature from dspy.teleprompt.teleprompt import Teleprompter logger = logging.getLogger(__name__) """ USAGE SUGGESTIONS: The following code can be used to compile a optimized signature 
teleprompter, and evaluate it on an end task: teleprompter = COPRO(prompt_model=prompt_model, metric=metric, breadth=BREADTH, depth=DEPTH, init_temperature=INIT_TEMPERATURE) kwargs = dict(num_threads=NUM_THREADS, display_progress=True, display_table=0) compiled_prompt_opt = teleprompter.compile(program.deepcopy(), trainset=trainset[:DEV_NUM], eval_kwargs=kwargs) eval_score = evaluate(compiled_prompt_opt, devset=evalset[:EVAL_NUM], **kwargs) Note that this teleprompter takes in the following parameters: * prompt_model: The model used for prompt generation. When unspecified, defaults to the model set in settings (ie. dspy.settings.configure(lm=task_model)). * metric: The task metric used for optimization. * breadth: The number of new prompts to generate at each iteration. Default=10. * depth: The number of times we should ask our prompt model to generate new prompts, with the history of the past prompts as input. Default=3. * init_temperature: The temperature used to generate new prompts. Higher roughly equals more creative. Default=1.4. * track_stats: Tells the method whether or not to track statistics about the optimization process. If True, the method will track the following statistics: * results_best: The min,max,avg,stddev of top 10 scores for each predictor at each depth. * results_latest: The min,max,avg,stddev of newest prompt scores for each predictor at each depth. * total_calls: The total number of calls to the task metric. These statistics will be returned as attributes of the best program. """ class BasicGenerateInstruction(Signature): """You are an instruction optimizer for large language models. I will give you a ``signature`` of fields (inputs and outputs) in English. Your task is to propose an instruction that will lead a good language model to perform the task well. Don't be afraid to be creative.""" basic_instruction = dspy.InputField(desc="The initial instructions before optimization") proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") proposed_prefix_for_output_field = dspy.OutputField( desc="The string at the end of the prompt, which will help the model start solving the task", ) class GenerateInstructionGivenAttempts(dspy.Signature): """You are an instruction optimizer for large language models. I will give some task instructions I've tried, along with their corresponding validation scores. The instructions are arranged in increasing order based on their scores, where higher scores indicate better quality. Your task is to propose a new instruction that will lead a good language model to perform the task even better. 
Don't be afraid to be creative.""" attempted_instructions = dspy.InputField() proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") proposed_prefix_for_output_field = dspy.OutputField( desc="The string at the end of the prompt, which will help the model start solving the task", ) class COPRO(Teleprompter): def __init__( self, prompt_model=None, metric=None, breadth=10, depth=3, init_temperature=1.4, track_stats=False, **_kwargs, ): if breadth <= 1: raise ValueError("Breadth must be greater than 1") self.metric = metric self.breadth = breadth self.depth = depth self.init_temperature = init_temperature self.prompt_model = prompt_model self.track_stats = track_stats def _check_candidates_equal(self, candidate1, candidate2): for p1, p2 in zip(candidate1["program"].predictors(), candidate2["program"].predictors(), strict=False): if self._get_signature(p1).instructions != self._get_signature(p2).instructions: return False *_, p1_last_field = self._get_signature(p1).fields.values() *_, p2_last_field = self._get_signature(p2).fields.values() if p1_last_field != p2_last_field: return False return True def _drop_duplicates(self, candidates): final_candidates = [] last_batch = [] last_batch_score = -1 for c in candidates: repeat = False if c["score"] == last_batch_score: for c2 in last_batch: if self._check_candidates_equal(c, c2): repeat = True break if not repeat: last_batch.append(c) else: last_batch = [c] last_batch_score = c["score"] if not repeat: final_candidates.append(c) return final_candidates def _print_signature(self, predictor): signature = self._get_signature(predictor) logger.debug(f"i: {signature.instructions}") logger.debug(f"p: {list(signature.fields.values())[-1].json_schema_extra['prefix']}") def _get_signature(self, predictor): assert hasattr(predictor, "signature") return predictor.signature def _set_signature(self, predictor, updated_signature): assert hasattr(predictor, "signature") predictor.signature = updated_signature def compile(self, student, *, trainset, eval_kwargs): """ optimizes `signature` of `student` program - note that it may be zero-shot or already pre-optimized (demos already chosen - `demos != []`) parameters: student: program to optimize and left modified. trainset: iterable of `Example`s eval_kwargs: optional, dict Additional keywords to go into `Evaluate` for the metric. Returns optimized version of `student`. 
""" module = student.deepcopy() evaluate = Evaluate(devset=trainset, metric=self.metric, **eval_kwargs) total_calls = 0 results_best = { id(p): {"depth": [], "max": [], "average": [], "min": [], "std": []} for p in module.predictors() } results_latest = { id(p): {"depth": [], "max": [], "average": [], "min": [], "std": []} for p in module.predictors() } if self.track_stats: import numpy as np candidates = {} evaluated_candidates = defaultdict(dict) # Seed the prompt optimizer zero shot with just the instruction, generate BREADTH new prompts for predictor in module.predictors(): basic_instruction = None basic_prefix = None *_, last_key = self._get_signature(predictor).fields.keys() basic_instruction = self._get_signature(predictor).instructions basic_prefix = self._get_signature(predictor).fields[last_key].json_schema_extra["prefix"] if self.prompt_model: with dspy.settings.context(lm=self.prompt_model): instruct = dspy.Predict( BasicGenerateInstruction, n=self.breadth - 1, temperature=self.init_temperature, )(basic_instruction=basic_instruction) else: instruct = dspy.Predict( BasicGenerateInstruction, n=self.breadth - 1, temperature=self.init_temperature, )(basic_instruction=basic_instruction) # Add in our initial prompt as a candidate as well instruct.completions.proposed_instruction.append(basic_instruction) instruct.completions.proposed_prefix_for_output_field.append(basic_prefix) candidates[id(predictor)] = instruct.completions evaluated_candidates[id(predictor)] = {} if self.prompt_model: logger.debug(f"{self.prompt_model.inspect_history(n=1)}") latest_candidates = candidates all_candidates = candidates module_clone = module.deepcopy() # For each iteration in depth... for d in range( self.depth, ): # TODO: fix this so that we eval the new batch of predictors with the new best following predictors logger.info(f"Iteration Depth: {d+1}/{self.depth}.") latest_scores = [] # Go through our module's predictors for p_i, (p_old, p_new) in enumerate(zip(module.predictors(), module_clone.predictors(), strict=False)): candidates_ = latest_candidates[id(p_old)] # Use the most recently generated candidates for evaluation if len(module.predictors()) > 1: # Unless our program has multiple predictors, in which case we need to reevaluate all prompts with # the new prompt(s) for the other predictor(s). 
candidates_ = all_candidates[ id(p_old) ] # For each candidate for c_i, c in enumerate(candidates_): # Get the candidate instruction and prefix instruction, prefix = ( c.proposed_instruction.strip('"').strip(), c.proposed_prefix_for_output_field.strip('"').strip(), ) # Set this new module with our instruction / prefix *_, last_key = self._get_signature(p_new).fields.keys() updated_signature = ( self._get_signature(p_new) .with_instructions(instruction) .with_updated_fields(last_key, prefix=prefix) ) self._set_signature(p_new, updated_signature) # Score the instruction / prefix for i, predictor in enumerate(module_clone.predictors()): logger.debug(f"Predictor {i+1}") self._print_signature(predictor) logger.info( f"At Depth {d+1}/{self.depth}, Evaluating Prompt Candidate #{c_i+1}/{len(candidates_)} for " f"Predictor {p_i+1} of {len(module.predictors())}.", ) score = evaluate(module_clone, devset=trainset, **eval_kwargs).score if self.prompt_model: logger.debug(f"prompt_model.inspect_history(n=1) {self.prompt_model.inspect_history(n=1)}") total_calls += 1 replace_entry = True logger.debug(f"(instruction, prefix) {(instruction, prefix)}") if (instruction, prefix) in evaluated_candidates[id(p_old)]: if evaluated_candidates[id(p_old)][(instruction, prefix)]["score"] >= score: replace_entry = False if replace_entry: # Add it to our evaluated candidates list evaluated_candidates[id(p_old)][(instruction, prefix)] = { "score": score, "program": module_clone.deepcopy(), "instruction": instruction, "prefix": prefix, "depth": d, } if len(candidates_) - self.breadth <= c_i: latest_scores.append(score) if self.track_stats: results_latest[id(p_old)]["depth"].append(d) results_latest[id(p_old)]["max"].append(max(latest_scores)) results_latest[id(p_old)]["average"].append(sum(latest_scores) / len(latest_scores)) results_latest[id(p_old)]["min"].append(min(latest_scores)) results_latest[id(p_old)]["std"].append(np.std(latest_scores)) # Now that we've evaluated the candidates, set this predictor to the best performing version # to ensure the next round of scores reflect the best possible version best_candidate = max(evaluated_candidates[id(p_old)].values(), key=lambda candidate: candidate["score"]) *_, last_key = self._get_signature(p_old).fields.keys() updated_signature = ( self._get_signature(p_new) .with_instructions(best_candidate["instruction"]) .with_updated_fields(last_key, prefix=best_candidate["prefix"]) ) self._set_signature(p_new, updated_signature) logger.debug( f"Updating Predictor {id(p_old)} to:\ni: {best_candidate['instruction']}\n" f"p: {best_candidate['prefix']}", ) logger.debug("Full predictor with update: ") for i, predictor in enumerate(module_clone.predictors()): logger.debug(f"Predictor {i}") self._print_signature(predictor) if d == self.depth - 1: break new_candidates = {} for p_base in module.predictors(): # Build Few-Shot Example of Optimized Prompts attempts = [] shortest_len = self.breadth shortest_len = min(len(evaluated_candidates[id(p_base)]), shortest_len) best_predictors = list(evaluated_candidates[id(p_base)].values()) # best_predictors = evaluated_candidates[id(p_base)].values()[:] best_predictors.sort(key=lambda x: x["score"], reverse=True) if self.track_stats: scores = [x["score"] for x in best_predictors][:10] results_best[id(p_base)]["depth"].append(d) results_best[id(p_base)]["max"].append(max(scores)) results_best[id(p_base)]["average"].append(sum(scores) / len(scores)) results_best[id(p_base)]["min"].append(min(scores)) 
results_best[id(p_base)]["std"].append(np.std(scores)) for i in range(shortest_len - 1, -1, -1): # breakpoint() attempts.append(f'Instruction #{shortest_len-i}: {best_predictors[i]["instruction"]}') attempts.append(f'Prefix #{shortest_len-i}: {best_predictors[i]["prefix"]}') attempts.append(f'Resulting Score #{shortest_len-i}: {best_predictors[i]["score"]}') # Generate next batch of potential prompts to optimize, with previous attempts as input if self.prompt_model: with dspy.settings.context(lm=self.prompt_model): instr = dspy.Predict( GenerateInstructionGivenAttempts, n=self.breadth, temperature=self.init_temperature, )(attempted_instructions=attempts) else: instr = dspy.Predict( GenerateInstructionGivenAttempts, n=self.breadth, temperature=self.init_temperature, )(attempted_instructions=attempts) # Get candidates for each predictor new_candidates[id(p_base)] = instr.completions all_candidates[id(p_base)].proposed_instruction.extend(instr.completions.proposed_instruction) all_candidates[id(p_base)].proposed_prefix_for_output_field.extend( instr.completions.proposed_prefix_for_output_field, ) latest_candidates = new_candidates candidates = [] for predictor in module.predictors(): candidates.extend(list(evaluated_candidates[id(predictor)].values())) if self.track_stats: best_predictors = list(evaluated_candidates[id(predictor)].values()) best_predictors.sort(key=lambda x: x["score"], reverse=True) scores = [x["score"] for x in best_predictors][:10] results_best[id(predictor)]["depth"].append(d) results_best[id(predictor)]["max"].append(max(scores)) results_best[id(predictor)]["average"].append(sum(scores) / len(scores)) results_best[id(predictor)]["min"].append(min(scores)) results_best[id(predictor)]["std"].append(np.std(scores)) candidates.sort(key=lambda x: x["score"], reverse=True) candidates = self._drop_duplicates(candidates) best_program = candidates[0]["program"] best_program.candidate_programs = candidates best_program.total_calls = total_calls if self.track_stats: best_program.results_best = results_best best_program.results_latest = results_latest return best_program ``` -------------------------------------------------------------------------------- /dspy/propose/grounded_proposer.py: -------------------------------------------------------------------------------- ```python import random import dspy from dspy.propose.dataset_summary_generator import create_dataset_summary from dspy.propose.propose_base import Proposer from dspy.propose.utils import ( create_example_string, create_predictor_level_history_string, get_dspy_source_code, strip_prefix, ) from dspy.teleprompt.utils import get_prompt_model, get_signature # Hardcoded variables (TODO: update) MAX_INSTRUCT_IN_HISTORY = 5 # 10 TIPS = { "none": "", "creative": "Don't be afraid to be creative when creating the new instruction!", "simple": "Keep the instruction clear and concise.", "description": "Make sure your instruction is very informative and descriptive.", "high_stakes": "The instruction should include a high stakes scenario in which the LM must solve the task!", "persona": 'Include a persona that is relevant to the task in the instruction (ie. "You are a ...")', } ### SIGNATURES USED TO HELP WITH INSTRUCTION GENERATION ### class DescribeProgram(dspy.Signature): ( """Below is some pseudo-code for a pipeline that solves tasks with calls to language models. 
Please describe what type of task this program appears to be designed to solve, and how it appears to work.""" ) program_code = dspy.InputField( format=str, desc="Pseudocode for a language model program designed to solve a particular task.", prefix="PROGRAM CODE:", ) program_example = dspy.InputField( format=str, desc="An example of the program in use.", prefix="EXAMPLE OF PROGRAM IN USE:", ) program_description = dspy.OutputField( desc="Describe what task the program is designed to solve, and how it goes about solving this task.", prefix="SUMMARY OF PROGRAM ABOVE:", ) class DescribeModule(dspy.Signature): ( """Below is some pseudo-code for a pipeline that solves tasks with calls to language models. Please describe the purpose of one of the specified module in this pipeline.""" ) program_code = dspy.InputField( format=str, desc="Pseudocode for a language model program designed to solve a particular task.", prefix="PROGRAM CODE:", ) program_example = dspy.InputField( format=str, desc="An example of the program in use.", prefix="EXAMPLE OF PROGRAM IN USE:", ) program_description = dspy.InputField( desc="Summary of the task the program is designed to solve, and how it goes about solving it.", prefix="SUMMARY OF PROGRAM ABOVE:", ) module = dspy.InputField( desc="The module in the program that we want to describe.", prefix="MODULE:", ) module_description = dspy.OutputField( desc="Description of the module's role in the broader program.", prefix="MODULE DESCRIPTION:", ) def generate_instruction_class( use_dataset_summary=True, program_aware=True, use_task_demos=True, use_instruct_history=True, use_tip=True, ): class GenerateSingleModuleInstruction(dspy.Signature): ( """Use the information below to learn about a task that we are trying to solve using calls to an LM, then generate a new instruction that will be used to prompt a Language Model to better solve the task.""" ) if use_dataset_summary: dataset_description = dspy.InputField( desc="A description of the dataset that we are using.", prefix="DATASET SUMMARY:", ) if program_aware: program_code = dspy.InputField( format=str, desc="Language model program designed to solve a particular task.", prefix="PROGRAM CODE:", ) program_description = dspy.InputField( desc="Summary of the task the program is designed to solve, and how it goes about solving it.", prefix="PROGRAM DESCRIPTION:", ) module = dspy.InputField( desc="The module to create an instruction for.", prefix="MODULE:", ) module_description = dspy.InputField( desc="Description of the module to create an instruction for.", prefix="MODULE DESCRIPTION:", ) task_demos = dspy.InputField( format=str, desc="Example inputs/outputs of our module.", prefix="TASK DEMO(S):", ) if use_instruct_history: previous_instructions = dspy.InputField( format=str, desc="Previous instructions we've attempted, along with their associated scores.", prefix="PREVIOUS INSTRUCTIONS:", ) basic_instruction = dspy.InputField( format=str, desc="Basic instruction.", prefix="BASIC INSTRUCTION:", ) if use_tip: tip = dspy.InputField( format=str, desc="A suggestion for how to go about generating the new instruction.", prefix="TIP:", ) proposed_instruction = dspy.OutputField( desc="Propose an instruction that will be used to prompt a Language Model to perform this task.", prefix="PROPOSED INSTRUCTION:", ) return dspy.Predict(GenerateSingleModuleInstruction) ### CLASS RESPONSIBLE FOR GENERATING A NEW INSTRUCTION, USING THE HELPER SIGNATURES ABOVE ### class GenerateModuleInstruction(dspy.Module): def __init__( self, 
program_code_string=None, use_dataset_summary=True, program_aware=False, use_task_demos=True, use_instruct_history=True, use_tip=True, verbose=False, ): super().__init__() self.use_dataset_summary = use_dataset_summary self.program_aware = program_aware self.use_task_demos = use_task_demos self.use_instruct_history = use_instruct_history self.use_tip = use_tip self.verbose = verbose self.program_code_string = program_code_string self.describe_program = dspy.Predict(DescribeProgram) self.describe_module = dspy.Predict(DescribeModule) self.generate_module_instruction = generate_instruction_class( use_dataset_summary=use_dataset_summary, program_aware=program_aware, use_task_demos=use_task_demos, use_instruct_history=use_instruct_history, use_tip=use_tip, ) def forward( self, demo_candidates, pred_i, demo_set_i, program, previous_instructions, data_summary, num_demos_in_context=3, tip=None, ): def gather_examples_from_sets(candidate_sets, max_examples): """Helper function to gather up to augmented examples from given sets.""" count = 0 for candidate_set in candidate_sets: for example in candidate_set: if "augmented" in example.keys(): fields_to_use = get_signature(program.predictors()[pred_i]).fields yield create_example_string(fields_to_use, example) count += 1 if count >= max_examples: return # Construct full program demo or single module demo depending on settings basic_instruction = get_signature(program.predictors()[pred_i]).instructions task_demos = "" if self.use_task_demos: # Combine current and adjacent sets adjacent_sets = ( [demo_candidates[pred_i][demo_set_i]] + demo_candidates[pred_i][demo_set_i + 1:] + demo_candidates[pred_i][:demo_set_i] ) # Gather examples up to the required count example_strings = gather_examples_from_sets(adjacent_sets, num_demos_in_context) task_demos = "\n\n".join(example_strings) + "\n\n" # Default to no demos provided if no examples were gathered, or if we're using the first demo set if not task_demos.strip() or demo_set_i == 0: task_demos = "No task demos provided." # Summarize the program program_description = "Not available" module_code = "Not provided" module_description = "Not provided" if self.program_aware: try: program_description = strip_prefix( self.describe_program( program_code=self.program_code_string, program_example=task_demos, ).program_description, ) if self.verbose: print(f"PROGRAM DESCRIPTION: {program_description}") inputs = [] outputs = [] for field_name, field in get_signature(program.predictors()[pred_i]).fields.items(): # Access the '__dspy_field_type' from the extra metadata dspy_field_type = field.json_schema_extra.get("__dspy_field_type") # Based on the '__dspy_field_type', append to the respective list if dspy_field_type == "input": inputs.append(field_name) else: outputs.append(field_name) module_code = f"{program.predictors()[pred_i].__class__.__name__}({', '.join(inputs)}) -> {', '.join(outputs)}" module_description = self.describe_module( program_code=self.program_code_string, program_description=program_description, program_example=task_demos, module=module_code, max_depth=10, ).module_description except Exception as e: if self.verbose: print(f"Error getting program description. Running without program aware proposer. 
Error: {e}") self.program_aware = False # Generate an instruction for our chosen module if self.verbose: print(f"task_demos {task_demos}") instruct = self.generate_module_instruction( dataset_description=data_summary, program_code=self.program_code_string, module=module_code, program_description=program_description, module_description=module_description, task_demos=task_demos, tip=tip, basic_instruction=basic_instruction, previous_instructions=previous_instructions, ) proposed_instruction = strip_prefix(instruct.proposed_instruction) return dspy.Prediction(proposed_instruction=proposed_instruction) ### CLASS USED TO GENERATE THE FULL SET OF INSTRUCTIONS GIVEN THE SPECIFIED CRITERIA ### class GroundedProposer(Proposer): def __init__( self, prompt_model, program, trainset, view_data_batch_size=10, use_dataset_summary=True, program_aware=True, use_task_demos=True, num_demos_in_context = 3, use_instruct_history=True, use_tip=True, set_tip_randomly=True, set_history_randomly=True, verbose=False, rng=None, init_temperature: float = 1.0, ): super().__init__() self.program_aware = program_aware self.use_dataset_summary = use_dataset_summary self.use_task_demos = use_task_demos self.num_demos_in_context = num_demos_in_context self.use_instruct_history = use_instruct_history self.use_tip = use_tip self.set_tip_randomly=set_tip_randomly self.set_history_randomly=set_history_randomly self.verbose = verbose self.rng = rng or random self.prompt_model = get_prompt_model(prompt_model) self.init_temperature = init_temperature self.program_code_string = None if self.program_aware: try: self.program_code_string = get_dspy_source_code(program) if self.verbose: print("SOURCE CODE:",self.program_code_string) except Exception as e: print(f"Error getting source code: {e}.\n\nRunning without program aware proposer.") self.program_aware = False self.data_summary = None if self.use_dataset_summary: try: self.data_summary = create_dataset_summary( trainset=trainset, view_data_batch_size=view_data_batch_size, prompt_model=prompt_model, ) if self.verbose: print(f"DATA SUMMARY: {self.data_summary}") except Exception as e: print(f"Error getting data summary: {e}.\n\nRunning without data aware proposer.") self.use_dataset_summary = False print("") def propose_instructions_for_program( self, trainset, program, demo_candidates, trial_logs, N, # noqa: N803 ) -> list[str]: """This method is responsible for returning the full set of new instructions for our program, given the specified criteria.""" proposed_instructions = {} if self.set_history_randomly: # Randomly select whether or not we're using instruction history use_history = self.rng.random() < 0.5 self.use_instruct_history = use_history if self.verbose: print(f"Use history T/F: {self.use_instruct_history}") if not demo_candidates: if self.verbose: print("No demo candidates provided. 
Running without task demos.") self.use_task_demos = False # When no demo candidates are provided, default to N num_demos = N else: num_demos = max(len(demo_candidates[0]), 1) # Create an instruction for each predictor for pred_i, predictor in enumerate(program.predictors()): for demo_set_i in range(num_demos)[:min(N, num_demos)]: if pred_i not in proposed_instructions: proposed_instructions[pred_i] = [] selected_tip = None if self.set_tip_randomly: if self.verbose: print("Using a randomly generated configuration for our grounded proposer.") # Randomly select the tip selected_tip_key = self.rng.choice(list(TIPS.keys())) selected_tip = TIPS[selected_tip_key] self.use_tip = bool( selected_tip, ) if self.verbose: print(f"Selected tip: {selected_tip_key}") proposed_instructions[pred_i].append( self.propose_instruction_for_predictor( program=program, predictor=predictor, pred_i=pred_i, demo_candidates=demo_candidates, demo_set_i=demo_set_i, trial_logs=trial_logs, tip=selected_tip, ), ) return proposed_instructions def propose_instruction_for_predictor( self, program, predictor, pred_i, demo_candidates, demo_set_i, trial_logs, tip=None, ) -> str: """This method is responsible for returning a single instruction for a given predictor, using the specified criteria.""" # Create an instruction history string for our predictor instruction_history = create_predictor_level_history_string( program, pred_i, trial_logs, MAX_INSTRUCT_IN_HISTORY, ) # Create our instruction generator class (given specific criteria for this round of proposal) instruction_generator = GenerateModuleInstruction( program_code_string=self.program_code_string, use_dataset_summary=self.use_dataset_summary, program_aware=self.program_aware, use_task_demos=self.use_task_demos and demo_candidates, use_instruct_history=self.use_instruct_history and instruction_history, use_tip=self.use_tip, verbose=self.verbose ) # Generate a new instruction for our predictor using a unique rollout id to bypass cache rollout_lm = self.prompt_model.copy( rollout_id=self.rng.randint(0, 10**9), temperature=self.init_temperature, ) with dspy.settings.context(lm=rollout_lm): proposed_instruction = instruction_generator( demo_candidates=demo_candidates, pred_i=pred_i, demo_set_i=demo_set_i, program=program, data_summary=self.data_summary, previous_instructions=instruction_history, num_demos_in_context = self.num_demos_in_context, tip=tip, ).proposed_instruction # Log the trace used to generate the new instruction, along with the new instruction itself if self.verbose: self.prompt_model.inspect_history(n=1) print(f"PROPOSED INSTRUCTION: {proposed_instruction}") return strip_prefix(proposed_instruction) ``` -------------------------------------------------------------------------------- /tests/adapters/test_tool.py: -------------------------------------------------------------------------------- ```python import asyncio from typing import Any import pytest from pydantic import BaseModel import dspy from dspy.adapters.types.tool import Tool, ToolCalls, convert_input_schema_to_tool_args # Test fixtures def dummy_function(x: int, y: str = "hello") -> str: """A dummy function for testing. 
Args: x: An integer parameter y: A string parameter """ return f"{y} {x}" class DummyModel(BaseModel): field1: str = "hello" field2: int def dummy_with_pydantic(model: DummyModel) -> str: """A dummy function that accepts a Pydantic model.""" return f"{model.field1} {model.field2}" class Address(BaseModel): street: str city: str zip_code: str is_primary: bool = False class ContactInfo(BaseModel): email: str phone: str | None = None addresses: list[Address] class UserProfile(BaseModel): user_id: int name: str age: int | None = None contact: ContactInfo tags: list[str] = [] class Note(BaseModel): content: str author: str def complex_dummy_function(profile: UserProfile, priority: int, notes: list[Note] | None = None) -> dict[str, Any]: """Process user profile with complex nested structure. Args: profile: User profile containing nested contact and address information priority: Priority level of the processing notes: Optional processing notes """ primary_address = next( (addr for addr in profile.contact.addresses if addr.is_primary), profile.contact.addresses[0] ) return { "user_id": profile.user_id, "name": profile.name, "priority": priority, "primary_address": primary_address.model_dump(), "notes": notes, } async def async_dummy_function(x: int, y: str = "hello") -> str: """An async dummy function for testing. Args: x: An integer parameter y: A string parameter """ await asyncio.sleep(0.1) # Simulate some async work return f"{y} {x}" async def async_dummy_with_pydantic(model: DummyModel) -> str: """An async dummy function that accepts a Pydantic model.""" await asyncio.sleep(0.1) # Simulate some async work return f"{model.field1} {model.field2}" async def async_complex_dummy_function( profile: UserProfile, priority: int, notes: list[Note] | None = None, ) -> dict[str, Any]: """Process user profile with complex nested structure asynchronously. Args: profile: User profile containing nested contact and address information priority: Priority level of the processing notes: Optional processing notes """ # Simulate some async processing work await asyncio.sleep(0.1) primary_address = next( (addr for addr in profile.contact.addresses if addr.is_primary), profile.contact.addresses[0] ) # Simulate more async work after finding primary address await asyncio.sleep(0.1) return { "user_id": profile.user_id, "name": profile.name, "priority": priority, "primary_address": primary_address.model_dump(), "notes": notes, } def test_basic_initialization(): tool = Tool(name="test_tool", desc="A test tool", args={"param1": {"type": "string"}}, func=lambda x: x) assert tool.name == "test_tool" assert tool.desc == "A test tool" assert tool.args == {"param1": {"type": "string"}} assert callable(tool.func) def test_tool_from_function(): tool = Tool(dummy_function) assert tool.name == "dummy_function" assert "A dummy function for testing" in tool.desc assert "x" in tool.args assert "y" in tool.args assert tool.args["x"]["type"] == "integer" assert tool.args["y"]["type"] == "string" assert tool.args["y"]["default"] == "hello" def test_tool_from_class(): class Foo: def __init__(self, user_id: str): self.user_id = user_id def __call__(self, a: int, b: int) -> int: """Add two numbers.""" return a + b tool = Tool(Foo("123")) assert tool.name == "Foo" assert tool.desc == "Add two numbers." 
assert tool.args == {"a": {"type": "integer"}, "b": {"type": "integer"}} def test_tool_from_function_with_pydantic(): tool = Tool(dummy_with_pydantic) assert tool.name == "dummy_with_pydantic" assert "model" in tool.args assert tool.args["model"]["type"] == "object" assert "field1" in tool.args["model"]["properties"] assert "field2" in tool.args["model"]["properties"] assert tool.args["model"]["properties"]["field1"]["default"] == "hello" def test_tool_from_function_with_pydantic_nesting(): tool = Tool(complex_dummy_function) assert tool.name == "complex_dummy_function" assert "profile" in tool.args assert "priority" in tool.args assert "notes" in tool.args assert tool.args["profile"]["type"] == "object" assert tool.args["profile"]["properties"]["user_id"]["type"] == "integer" assert tool.args["profile"]["properties"]["name"]["type"] == "string" assert tool.args["profile"]["properties"]["age"]["anyOf"] == [{"type": "integer"}, {"type": "null"}] assert tool.args["profile"]["properties"]["contact"]["type"] == "object" assert tool.args["profile"]["properties"]["contact"]["properties"]["email"]["type"] == "string" # Reference should be resolved for nested pydantic models assert "$defs" not in str(tool.args["notes"]) assert tool.args["notes"]["anyOf"][0]["type"] == "array" assert tool.args["notes"]["anyOf"][0]["items"]["type"] == "object" assert tool.args["notes"]["anyOf"][0]["items"]["properties"]["content"]["type"] == "string" assert tool.args["notes"]["anyOf"][0]["items"]["properties"]["author"]["type"] == "string" def test_tool_callable(): tool = Tool(dummy_function) result = tool(x=42, y="hello") assert result == "hello 42" def test_tool_with_pydantic_callable(): tool = Tool(dummy_with_pydantic) model = DummyModel(field1="test", field2=123) result = tool(model=model) assert result == "test 123" def test_invalid_function_call(): tool = Tool(dummy_function) with pytest.raises(ValueError): tool(x="not an integer", y="hello") def test_parameter_desc(): tool = Tool(dummy_function, arg_desc={"x": "The x parameter"}) assert tool.args["x"]["description"] == "The x parameter" def test_tool_with_default_args_without_type_hints(): def foo(x=100): return x tool = Tool(foo) assert tool.args["x"]["default"] == 100 assert not hasattr(tool.args["x"], "type") def test_tool_call_parses_args(): tool = Tool(dummy_with_pydantic) args = { "model": { "field1": "hello", "field2": 123, } } result = tool(**args) assert result == "hello 123" def test_tool_call_parses_nested_list_of_pydantic_model(): def dummy_function(x: list[list[DummyModel]]): return x tool = Tool(dummy_function) args = { "x": [ [ { "field1": "hello", "field2": 123, } ] ] } result = tool(**args) assert result == [[DummyModel(field1="hello", field2=123)]] def test_tool_call_kwarg(): def fn(x: int, **kwargs): return kwargs tool = Tool(fn) assert tool(x=1, y=2, z=3) == {"y": 2, "z": 3} def test_tool_str(): def add(x: int, y: int = 0) -> int: """Add two integers.""" return x + y tool = Tool(add) assert ( str(tool) == "add, whose description is <desc>Add two integers.</desc>. It takes arguments {'x': {'type': 'integer'}, 'y': {'type': 'integer', 'default': 0}}." 
) @pytest.mark.asyncio async def test_async_tool_from_function(): tool = Tool(async_dummy_function) assert tool.name == "async_dummy_function" assert "An async dummy function for testing" in tool.desc assert "x" in tool.args assert "y" in tool.args assert tool.args["x"]["type"] == "integer" assert tool.args["y"]["type"] == "string" assert tool.args["y"]["default"] == "hello" # Test async call result = await tool.acall(x=42, y="hello") assert result == "hello 42" @pytest.mark.asyncio async def test_async_tool_with_pydantic(): tool = Tool(async_dummy_with_pydantic) assert tool.name == "async_dummy_with_pydantic" assert "model" in tool.args assert tool.args["model"]["type"] == "object" assert "field1" in tool.args["model"]["properties"] assert "field2" in tool.args["model"]["properties"] # Test async call with pydantic model model = DummyModel(field1="test", field2=123) result = await tool.acall(model=model) assert result == "test 123" # Test async call with dict result = await tool.acall(model={"field1": "test", "field2": 123}) assert result == "test 123" @pytest.mark.asyncio async def test_async_tool_with_complex_pydantic(): tool = Tool(async_complex_dummy_function) profile = UserProfile( user_id=1, name="Test User", contact=ContactInfo( email="[email protected]", addresses=[ Address(street="123 Main St", city="Test City", zip_code="12345", is_primary=True), Address(street="456 Side St", city="Test City", zip_code="12345"), ], ), ) result = await tool.acall(profile=profile, priority=1, notes=[Note(content="Test note", author="Test author")]) assert result["user_id"] == 1 assert result["name"] == "Test User" assert result["priority"] == 1 assert result["notes"] == [Note(content="Test note", author="Test author")] assert result["primary_address"]["street"] == "123 Main St" @pytest.mark.asyncio async def test_async_tool_invalid_call(): tool = Tool(async_dummy_function) with pytest.raises(ValueError): await tool.acall(x="not an integer", y="hello") @pytest.mark.asyncio async def test_async_tool_with_kwargs(): async def fn(x: int, **kwargs): return kwargs tool = Tool(fn) result = await tool.acall(x=1, y=2, z=3) assert result == {"y": 2, "z": 3} @pytest.mark.asyncio async def test_async_concurrent_calls(): """Test that multiple async tools can run concurrently.""" tool = Tool(async_dummy_function) # Create multiple concurrent calls tasks = [tool.acall(x=i, y=f"hello{i}") for i in range(5)] # Run them concurrently and measure time start_time = asyncio.get_event_loop().time() results = await asyncio.gather(*tasks) end_time = asyncio.get_event_loop().time() # Verify results, `asyncio.gather` returns results in the order of the tasks assert results == [f"hello{i} {i}" for i in range(5)] # Check that it ran concurrently (should take ~0.1s, not ~0.5s) # We use 0.3s as threshold to account for some overhead assert end_time - start_time < 0.3 @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_async_tool_call_in_sync_mode(): tool = Tool(async_dummy_function) with dspy.context(allow_tool_async_sync_conversion=False): with pytest.raises(ValueError): result = tool(x=1, y="hello") with dspy.context(allow_tool_async_sync_conversion=True): result = tool(x=1, y="hello") assert result == "hello 1" TOOL_CALL_TEST_CASES = [ ([], {"tool_calls": []}), ( [{"name": "search", "args": {"query": "hello"}}], { "tool_calls": [{"type": "function", "function": {"name": "search", "arguments": {"query": "hello"}}}], }, ), ( [ {"name": "search", "args": {"query": "hello"}}, {"name": "translate", "args": {"text": 
"world", "lang": "fr"}}, ], { "tool_calls": [ {"type": "function", "function": {"name": "search", "arguments": {"query": "hello"}}}, { "type": "function", "function": {"name": "translate", "arguments": {"text": "world", "lang": "fr"}}, }, ], }, ), ( [{"name": "get_time", "args": {}}], { "tool_calls": [{"type": "function", "function": {"name": "get_time", "arguments": {}}}], }, ), ] @pytest.mark.parametrize("tool_calls_data,expected", TOOL_CALL_TEST_CASES) def test_tool_calls_format_basic(tool_calls_data, expected): """Test ToolCalls.format with various basic scenarios.""" tool_calls_list = [ToolCalls.ToolCall(**data) for data in tool_calls_data] tool_calls = ToolCalls(tool_calls=tool_calls_list) result = tool_calls.format() assert result == expected def test_tool_calls_format_from_dict_list(): """Test format works with ToolCalls created from from_dict_list.""" tool_calls_dicts = [ {"name": "search", "args": {"query": "hello"}}, {"name": "translate", "args": {"text": "world", "lang": "fr"}}, ] tool_calls = ToolCalls.from_dict_list(tool_calls_dicts) result = tool_calls.format() assert len(result["tool_calls"]) == 2 assert result["tool_calls"][0]["function"]["name"] == "search" assert result["tool_calls"][1]["function"]["name"] == "translate" def test_toolcalls_vague_match(): """ Test that ToolCalls can parse the data with slightly off format: - a single dict with "name" and "args" - a list of dicts with "name" and "args" - invalid input (should raise ValueError) """ # Single dict with "name" and "args" should parse as one ToolCall data_single = {"name": "search", "args": {"query": "hello"}} tc = ToolCalls.model_validate(data_single) assert isinstance(tc, ToolCalls) assert len(tc.tool_calls) == 1 assert tc.tool_calls[0].name == "search" assert tc.tool_calls[0].args == {"query": "hello"} # List of dicts with "name" and "args" should parse as multiple ToolCalls data_list = [ {"name": "search", "args": {"query": "hello"}}, {"name": "translate", "args": {"text": "world", "lang": "fr"}}, ] tc = ToolCalls.model_validate(data_list) assert isinstance(tc, ToolCalls) assert len(tc.tool_calls) == 2 assert tc.tool_calls[0].name == "search" assert tc.tool_calls[1].name == "translate" # Dict with "tool_calls" key containing a list of dicts data_tool_calls = { "tool_calls": [ {"name": "search", "args": {"query": "hello"}}, {"name": "get_time", "args": {}}, ] } tc = ToolCalls.model_validate(data_tool_calls) assert isinstance(tc, ToolCalls) assert len(tc.tool_calls) == 2 assert tc.tool_calls[0].name == "search" assert tc.tool_calls[1].name == "get_time" # Invalid input should raise ValueError with pytest.raises(ValueError): ToolCalls.model_validate({"foo": "bar"}) with pytest.raises(ValueError): ToolCalls.model_validate([{"foo": "bar"}]) def test_tool_convert_input_schema_to_tool_args_no_input_params(): args, arg_types, arg_desc = convert_input_schema_to_tool_args(schema={"properties": {}}) assert args == {} assert arg_types == {} assert arg_desc == {} def test_tool_convert_input_schema_to_tool_args_lang_chain(): # Example from langchain docs: # https://web.archive.org/web/20250723101359/https://api.python.langchain.com/en/latest/tools/langchain_core.tools.tool.html args, arg_types, arg_desc = convert_input_schema_to_tool_args( schema={ "title": "fooSchema", "description": "The foo.", "type": "object", "properties": { "bar": { "title": "Bar", "description": "The bar.", "type": "string", }, "baz": { "title": "Baz", "type": "integer", }, }, "required": [ "baz", ], } ) assert args == { "bar": {"title": "Bar", 
"description": "The bar.", "type": "string"}, "baz": {"title": "Baz", "type": "integer"}, } assert arg_types == { "bar": str, "baz": int, } assert arg_desc == { "bar": "The bar.", "baz": "No description provided. (Required)", } def test_tool_call_execute(): def get_weather(city: str) -> str: return f"The weather in {city} is sunny" def add_numbers(a: int, b: int) -> int: return a + b tools = [ dspy.Tool(get_weather), dspy.Tool(add_numbers) ] tool_call = dspy.ToolCalls.ToolCall(name="get_weather", args={"city": "Berlin"}) result = tool_call.execute(functions=tools) assert result == "The weather in Berlin is sunny" # Test individual tool call with function dict tool_call2 = dspy.ToolCalls.ToolCall(name="add_numbers", args={"a": 7, "b": 13}) result2 = tool_call2.execute(functions={"add_numbers": add_numbers}) assert result2 == 20 # Test individual tool call with no arguments def get_pi(): return 3.14159 tool_call3 = dspy.ToolCalls.ToolCall(name="get_pi", args={}) result3 = tool_call3.execute(functions={"get_pi": get_pi}) assert result3 == 3.14159 # Test error case tool_call4 = dspy.ToolCalls.ToolCall(name="nonexistent", args={}) try: tool_call4.execute(functions=tools) assert False, "Should have raised ValueError" except ValueError as e: assert "not found" in str(e) def test_tool_call_execute_with_local_functions(): def main(): def local_add(a: int, b: int) -> int: return a + b def local_multiply(x: int, y: int) -> int: return x * y # Test individual execution with local function tool_call1 = dspy.ToolCalls.ToolCall(name="local_add", args={"a": 10, "b": 15}) result1 = tool_call1.execute() # Should find local function automatically assert result1 == 25 tool_call2 = dspy.ToolCalls.ToolCall(name="local_multiply", args={"x": 4, "y": 7}) result2 = tool_call2.execute() # Should find local function automatically assert result2 == 28 # Test locals take precedence over globals try: globals()["local_add"] = lambda a, b: a + b + 1000 precedence_call = dspy.ToolCalls.ToolCall(name="local_add", args={"a": 1, "b": 2}) result = precedence_call.execute() assert result == 3 # Should use local function (1+2=3), not global (1+2+1000=1003) finally: globals().pop("local_add", None) main() ``` -------------------------------------------------------------------------------- /docs/docs/tutorials/streaming/index.md: -------------------------------------------------------------------------------- ```markdown # Streaming In this guide, we will walk you through how to enable streaming in your DSPy program. DSPy Streaming consists of two parts: - **Output Token Streaming**: Stream individual tokens as they're generated, rather than waiting for the complete response. - **Intermediate Status Streaming**: Provide real-time updates about the program's execution state (e.g., "Calling web search...", "Processing results..."). ## Output Token Streaming DSPy's token streaming feature works with any module in your pipeline, not just the final output. The only requirement is that the streamed field must be of type `str`. To enable token streaming: 1. Wrap your program with `dspy.streamify` 2. 
Create one or more `dspy.streaming.StreamListener` objects to specify which fields to stream

Here's a basic example:

```python
import os

import dspy

os.environ["OPENAI_API_KEY"] = "your_api_key"

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))

predict = dspy.Predict("question->answer")

# Enable streaming for the 'answer' field
stream_predict = dspy.streamify(
    predict,
    stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")],
)
```

To consume the streamed output:

```python
import asyncio

async def read_output_stream():
    output_stream = stream_predict(question="Why did a chicken cross the kitchen?")

    async for chunk in output_stream:
        print(chunk)

asyncio.run(read_output_stream())
```

This will produce output like:

```
StreamResponse(predict_name='self', signature_field_name='answer', chunk='To')
StreamResponse(predict_name='self', signature_field_name='answer', chunk=' get')
StreamResponse(predict_name='self', signature_field_name='answer', chunk=' to')
StreamResponse(predict_name='self', signature_field_name='answer', chunk=' the')
StreamResponse(predict_name='self', signature_field_name='answer', chunk=' other')
StreamResponse(predict_name='self', signature_field_name='answer', chunk=' side of the frying pan!')
Prediction(
    answer='To get to the other side of the frying pan!'
)
```

Note: Since `dspy.streamify` returns an async generator, you must use it within an async context. If you're using an environment like Jupyter or Google Colab that already has an event loop (async context), you can use the generator directly.

You may have noticed that the stream above contains two different entities: `StreamResponse` and `Prediction`. `StreamResponse` wraps the streaming tokens of the field being listened to, which in this example is the `answer` field. `Prediction` is the program's final output.

In DSPy, streaming is implemented in a sidecar fashion: we enable streaming on the LM so that the LM outputs a stream of tokens. These tokens are sent to a side channel that is continuously read by the user-defined listeners. Each listener interprets the stream and decides whether the `signature_field_name` it is listening to has started to appear and whether it has finalized. Once it detects that the field has appeared, the listener begins emitting tokens to the async generator that users read from. A listener's internal mechanism depends on the adapter behind the scenes, and because we usually cannot tell that a field has finalized until the next field starts, the listener buffers output tokens before sending them to the final generator. This is why the last `StreamResponse` chunk usually contains more than one token. The program's final output is also written to the stream, which is the `Prediction` chunk in the sample output above.

To handle these different types and implement custom logic:

```python
import asyncio

async def read_output_stream():
    output_stream = stream_predict(question="Why did a chicken cross the kitchen?")

    return_value = None
    async for chunk in output_stream:
        if isinstance(chunk, dspy.streaming.StreamResponse):
            print(f"Output token of field {chunk.signature_field_name}: {chunk.chunk}")
        elif isinstance(chunk, dspy.Prediction):
            return_value = chunk
    return return_value

program_output = asyncio.run(read_output_stream())
print("Final output: ", program_output)
```
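As the note above says, inside an environment that already runs an event loop (such as a Jupyter or Colab cell) you can drive the generator with top-level `await` instead of `asyncio.run`. The following is a minimal sketch of that variant, reusing `stream_predict` from the example above; the `final_prediction` bookkeeping is illustrative, not an additional DSPy API:

```python
async def read_output_stream():
    output_stream = stream_predict(question="Why did a chicken cross the kitchen?")

    final_prediction = None
    async for chunk in output_stream:
        if isinstance(chunk, dspy.streaming.StreamResponse):
            print(chunk.chunk, end="")  # print the streamed text without per-chunk newlines
        elif isinstance(chunk, dspy.Prediction):
            final_prediction = chunk
    return final_prediction

# In a notebook cell, top-level await works without asyncio.run:
program_output = await read_output_stream()
```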
### Understand `StreamResponse`

`StreamResponse` (`dspy.streaming.StreamResponse`) is the wrapper class for streaming tokens. It comes with three fields:

- `predict_name`: the name of the predict that holds the `signature_field_name`. This name matches the keys returned by `your_program.named_predictors()`. In the code above, because `answer` comes from `predict` itself, `predict_name` shows up as `self`, which is the only key returned by `predict.named_predictors()`.
- `signature_field_name`: the output field that these tokens map to. `predict_name` and `signature_field_name` together form the unique identifier of the field. We will demonstrate how to stream multiple fields and handle duplicated field names later in this guide.
- `chunk`: the value of the stream chunk.

### Streaming with Cache

When a cached result is found, the stream will skip individual tokens and only yield the final `Prediction`. For example:

```
Prediction(
    answer='To get to the other side of the dinner plate!'
)
```

### Streaming Multiple Fields

You can monitor multiple fields by creating a `StreamListener` for each one. Here's an example with a multi-module program:

```python
import asyncio

import dspy

lm = dspy.LM("openai/gpt-4o-mini", cache=False)
dspy.settings.configure(lm=lm)


class MyModule(dspy.Module):
    def __init__(self):
        super().__init__()

        self.predict1 = dspy.Predict("question->answer")
        self.predict2 = dspy.Predict("answer->simplified_answer")

    def forward(self, question: str, **kwargs):
        answer = self.predict1(question=question)
        simplified_answer = self.predict2(answer=answer)
        return simplified_answer


predict = MyModule()
stream_listeners = [
    dspy.streaming.StreamListener(signature_field_name="answer"),
    dspy.streaming.StreamListener(signature_field_name="simplified_answer"),
]
stream_predict = dspy.streamify(
    predict,
    stream_listeners=stream_listeners,
)


async def read_output_stream():
    output = stream_predict(question="why did a chicken cross the kitchen?")

    return_value = None
    async for chunk in output:
        if isinstance(chunk, dspy.streaming.StreamResponse):
            print(chunk)
        elif isinstance(chunk, dspy.Prediction):
            return_value = chunk
    return return_value


program_output = asyncio.run(read_output_stream())
print("Final output: ", program_output)
```

The output will look like:

```
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk='To')
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' get')
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' to')
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' the')
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' other side of the recipe!')
StreamResponse(predict_name='predict2', signature_field_name='simplified_answer', chunk='To')
StreamResponse(predict_name='predict2', signature_field_name='simplified_answer', chunk=' reach')
StreamResponse(predict_name='predict2', signature_field_name='simplified_answer', chunk=' the')
StreamResponse(predict_name='predict2', signature_field_name='simplified_answer', chunk=' other side of the recipe!')
Final output: Prediction(
    simplified_answer='To reach the other side of the recipe!'
)
```
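Because each chunk carries both `predict_name` and `signature_field_name`, you can route tokens to separate buffers instead of printing them interleaved. The following is a sketch that reuses `stream_predict` from the example above; the buffer dictionary and joining logic are illustrative only, not part of the DSPy API:

```python
import asyncio

async def route_chunks():
    output = stream_predict(question="why did a chicken cross the kitchen?")

    buffers = {"answer": [], "simplified_answer": []}
    final_prediction = None
    async for chunk in output:
        if isinstance(chunk, dspy.streaming.StreamResponse):
            # Key on the field name (include predict_name in the key if field names collide).
            buffers[chunk.signature_field_name].append(chunk.chunk)
        elif isinstance(chunk, dspy.Prediction):
            final_prediction = chunk
    return {name: "".join(parts) for name, parts in buffers.items()}, final_prediction

field_texts, final_prediction = asyncio.run(route_chunks())
print(field_texts["simplified_answer"])
```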
### Streaming the Same Field Multiple Times (as in `dspy.ReAct`)

By default, a `StreamListener` automatically closes itself after completing a single streaming session. This design helps prevent performance issues, since every token is broadcast to all configured stream listeners, and having too many active listeners can introduce significant overhead. However, in scenarios where a DSPy module is used repeatedly in a loop, such as with `dspy.ReAct`, you may want to stream the same field from every prediction, each time it is produced. To enable this behavior, set `allow_reuse=True` when creating your `StreamListener`. See the example below:

```python
import asyncio

import dspy

lm = dspy.LM("openai/gpt-4o-mini", cache=False)
dspy.settings.configure(lm=lm)


def fetch_user_info(user_name: str):
    """Get user information like name, birthday, etc."""
    return {
        "name": user_name,
        "birthday": "2009-05-16",
    }


def get_sports_news(year: int):
    """Get sports news for a given year."""
    if year == 2009:
        return "Usain Bolt broke the world record in the 100m race."
    return None


react = dspy.ReAct("question->answer", tools=[fetch_user_info, get_sports_news])

stream_listeners = [
    # dspy.ReAct has a built-in output field called "next_thought".
    dspy.streaming.StreamListener(signature_field_name="next_thought", allow_reuse=True),
]
stream_react = dspy.streamify(react, stream_listeners=stream_listeners)


async def read_output_stream():
    output = stream_react(question="What sports news happened in the year Adam was born?")

    return_value = None
    async for chunk in output:
        if isinstance(chunk, dspy.streaming.StreamResponse):
            print(chunk)
        elif isinstance(chunk, dspy.Prediction):
            return_value = chunk
    return return_value


print(asyncio.run(read_output_stream()))
```

In this example, by setting `allow_reuse=True` in the `StreamListener`, you ensure that streaming for "next_thought" is available for every iteration, not just the first. When you run this code, you will see the streaming tokens for `next_thought` each time the field is produced.

#### Handling Duplicate Field Names

When streaming fields with the same name from different modules, specify both the `predict` and `predict_name` in the `StreamListener`:

```python
import asyncio

import dspy

lm = dspy.LM("openai/gpt-4o-mini", cache=False)
dspy.settings.configure(lm=lm)


class MyModule(dspy.Module):
    def __init__(self):
        super().__init__()

        self.predict1 = dspy.Predict("question->answer")
        self.predict2 = dspy.Predict("question, answer->answer, score")

    def forward(self, question: str, **kwargs):
        answer = self.predict1(question=question)
        simplified_answer = self.predict2(answer=answer)
        return simplified_answer


predict = MyModule()
stream_listeners = [
    dspy.streaming.StreamListener(
        signature_field_name="answer",
        predict=predict.predict1,
        predict_name="predict1"
    ),
    dspy.streaming.StreamListener(
        signature_field_name="answer",
        predict=predict.predict2,
        predict_name="predict2"
    ),
]
stream_predict = dspy.streamify(
    predict,
    stream_listeners=stream_listeners,
)


async def read_output_stream():
    output = stream_predict(question="why did a chicken cross the kitchen?")

    return_value = None
    async for chunk in output:
        if isinstance(chunk, dspy.streaming.StreamResponse):
            print(chunk)
        elif isinstance(chunk, dspy.Prediction):
            return_value = chunk
    return return_value


program_output = asyncio.run(read_output_stream())
print("Final output: ", program_output)
```

The output will be like:

```
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk='To')
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' get')
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' to')
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' the')
StreamResponse(predict_name='predict1', signature_field_name='answer', chunk=' other side of the recipe!')
StreamResponse(predict_name='predict2', signature_field_name='answer', chunk="I'm")
StreamResponse(predict_name='predict2', signature_field_name='answer', chunk=' ready')
StreamResponse(predict_name='predict2', signature_field_name='answer', chunk=' to')
StreamResponse(predict_name='predict2', signature_field_name='answer', chunk=' assist')
StreamResponse(predict_name='predict2', signature_field_name='answer', chunk=' you')
StreamResponse(predict_name='predict2', signature_field_name='answer', chunk='! Please provide a question.')
Final output: Prediction(
    answer="I'm ready to assist you! Please provide a question.",
    score='N/A'
)
```

## Intermediate Status Streaming

Status streaming keeps users informed about the program's progress, which is especially useful for long-running operations like tool calls or complex AI pipelines.

To implement status streaming:

1. Create a custom status message provider by subclassing `dspy.streaming.StatusMessageProvider`
2. Override the desired hook methods to provide custom status messages
3. Pass your provider to `dspy.streamify`

Example:

```python
class MyStatusMessageProvider(dspy.streaming.StatusMessageProvider):
    def lm_start_status_message(self, instance, inputs):
        return f"Calling LM with inputs {inputs}..."

    def lm_end_status_message(self, outputs):
        return f"LM finished with outputs {outputs}!"
```

Available hooks:

- `lm_start_status_message`: status message at the start of calling `dspy.LM`.
- `lm_end_status_message`: status message at the end of calling `dspy.LM`.
- `module_start_status_message`: status message at the start of calling a `dspy.Module`.
- `module_end_status_message`: status message at the end of calling a `dspy.Module`.
- `tool_start_status_message`: status message at the start of calling `dspy.Tool`.
- `tool_end_status_message`: status message at the end of calling `dspy.Tool`.

Each hook should return a string containing the status message.

After creating the message provider, pass it to `dspy.streamify` to enable both status message streaming and output token streaming, as shown in the example below. Intermediate status messages are represented by the class `dspy.streaming.StatusMessage`, so we need an additional condition check to capture them.

```python
import asyncio

import dspy

lm = dspy.LM("openai/gpt-4o-mini", cache=False)
dspy.settings.configure(lm=lm)


class MyModule(dspy.Module):
    def __init__(self):
        super().__init__()

        self.tool = dspy.Tool(lambda x: 2 * x, name="double_the_number")
        self.predict = dspy.ChainOfThought("num1, num2->sum")

    def forward(self, num, **kwargs):
        num2 = self.tool(x=num)
        return self.predict(num1=num, num2=num2)


class MyStatusMessageProvider(dspy.streaming.StatusMessageProvider):
    def tool_start_status_message(self, instance, inputs):
        return f"Calling Tool {instance.name} with inputs {inputs}..."

    def tool_end_status_message(self, outputs):
        return f"Tool finished with output: {outputs}!"


predict = MyModule()
stream_listeners = [
    # dspy.ChainOfThought has a built-in output field called "reasoning".
dspy.streaming.StreamListener(signature_field_name="reasoning"), ] stream_predict = dspy.streamify( predict, stream_listeners=stream_listeners, status_message_provider=MyStatusMessageProvider(), ) async def read_output_stream(): output = stream_predict(num=3) return_value = None async for chunk in output: if isinstance(chunk, dspy.streaming.StreamResponse): print(chunk) elif isinstance(chunk, dspy.Prediction): return_value = chunk elif isinstance(chunk, dspy.streaming.StatusMessage): print(chunk) return return_value program_output = asyncio.run(read_output_stream()) print("Final output: ", program_output) ``` Sample output: ``` StatusMessage(message='Calling tool double_the_number...') StatusMessage(message='Tool calling finished! Querying the LLM with tool calling results...') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk='To') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' find') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' the') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' sum') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' of') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' the') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' two') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' numbers') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=',') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' we') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' simply') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' add') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' them') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' together') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk='.') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' Here') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=',') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' ') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk='3') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' plus') StreamResponse(predict_name='predict.predict', signature_field_name='reasoning', chunk=' 6 equals 9.') Final output: Prediction( reasoning='To find the sum of the two numbers, we simply add them together. Here, 3 plus 6 equals 9.', sum='9' ) ``` ## Synchronous Streaming By default calling a streamified DSPy program produces an async generator. 
In order to get back a sync generator, you can set the flag `async_streaming=False`: ```python import os import dspy os.environ["OPENAI_API_KEY"] = "your_api_key" dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")) predict = dspy.Predict("question->answer") # Enable streaming for the 'answer' field stream_predict = dspy.streamify( predict, stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")], async_streaming=False, ) output = stream_predict(question="why did a chicken cross the kitchen?") program_output = None for chunk in output: if isinstance(chunk, dspy.streaming.StreamResponse): print(chunk) elif isinstance(chunk, dspy.Prediction): program_output = chunk print(f"Program output: {program_output}") ``` ``` -------------------------------------------------------------------------------- /tests/signatures/test_adapter_image.py: -------------------------------------------------------------------------------- ```python import os import tempfile from io import BytesIO import pydantic import pytest import requests from PIL import Image as PILImage import dspy from dspy.adapters.types.image import encode_image from dspy.utils.dummies import DummyLM @pytest.fixture def sample_pil_image(): """Fixture to provide a sample image for testing""" url = "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg" response = requests.get(url) response.raise_for_status() return PILImage.open(BytesIO(response.content)) @pytest.fixture def sample_dspy_image_download(): url = "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg" return dspy.Image(url, download=True) @pytest.fixture def sample_url(): return "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg" @pytest.fixture def sample_dspy_image_no_download(): return dspy.Image("https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg") def count_messages_with_image_url_pattern(messages): pattern = {"type": "image_url", "image_url": {"url": lambda x: isinstance(x, str)}} try: def check_pattern(obj, pattern): if isinstance(pattern, dict): if not isinstance(obj, dict): return False return all(k in obj and check_pattern(obj[k], v) for k, v in pattern.items()) if callable(pattern): return pattern(obj) return obj == pattern def count_patterns(obj, pattern): count = 0 if check_pattern(obj, pattern): count += 1 if isinstance(obj, dict): count += sum(count_patterns(v, pattern) for v in obj.values()) if isinstance(obj, (list, tuple)): count += sum(count_patterns(v, pattern) for v in obj) return count return count_patterns(messages, pattern) except Exception: return 0 def setup_predictor(signature, expected_output): """Helper to set up a predictor with DummyLM""" lm = DummyLM([expected_output]) dspy.settings.configure(lm=lm) return dspy.Predict(signature), lm @pytest.mark.parametrize( "test_case", [ { "name": "probabilistic_classification", "signature": "image: dspy.Image, class_labels: list[str] -> probabilities: dict[str, float]", "inputs": {"image": "https://example.com/dog.jpg", "class_labels": ["dog", "cat", "bird"]}, "key_output": "probabilities", "expected": {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}}, }, { "name": "image_to_code", "signature": "ui_image: dspy.Image, target_language: str -> generated_code: str", "inputs": {"ui_image": "https://example.com/button.png", "target_language": "HTML"}, "key_output": "generated_code", "expected": {"generated_code": "<button>Click me</button>"}, }, { "name": "bbox_detection", "signature": "image: dspy.Image -> bboxes: list[Tuple[int, int, int, int]]", "inputs": 
{"image": "https://example.com/image.jpg"}, "key_output": "bboxes", "expected": {"bboxes": [(10, 20, 30, 40), (50, 60, 70, 80)]}, }, { "name": "multilingual_caption", "signature": "image: dspy.Image, languages: list[str] -> captions: dict[str, str]", "inputs": {"image": "https://example.com/dog.jpg", "languages": ["en", "es", "fr"]}, "key_output": "captions", "expected": { "captions": {"en": "A golden retriever", "es": "Un golden retriever", "fr": "Un golden retriever"} }, }, ], ) def test_basic_image_operations(test_case): """Consolidated test for basic image operations""" predictor, lm = setup_predictor(test_case["signature"], test_case["expected"]) # Convert string URLs to dspy.Image objects inputs = { k: dspy.Image(v) if isinstance(v, str) and k in ["image", "ui_image"] else v for k, v in test_case["inputs"].items() } result = predictor(**inputs) # Check result based on output field name output_field = next(f for f in ["probabilities", "generated_code", "bboxes", "captions"] if hasattr(result, f)) assert getattr(result, output_field) == test_case["expected"][test_case["key_output"]] assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1 @pytest.mark.parametrize( "image_input,description", [ ("pil_image", "PIL Image"), ("encoded_pil_image", "encoded PIL image string"), ("dspy_image_download", "dspy.Image with download=True"), ("dspy_image_no_download", "dspy.Image without download"), ], ) def test_image_input_formats( request, sample_pil_image, sample_dspy_image_download, sample_dspy_image_no_download, image_input, description ): """Test different input formats for image fields""" signature = "image: dspy.Image, class_labels: list[str] -> probabilities: dict[str, float]" expected = {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}} predictor, lm = setup_predictor(signature, expected) input_map = { "pil_image": sample_pil_image, "encoded_pil_image": encode_image(sample_pil_image), "dspy_image_download": sample_dspy_image_download, "dspy_image_no_download": sample_dspy_image_no_download, } actual_input = input_map[image_input] # TODO(isaacbmiller): Support the cases without direct dspy.Image coercion if image_input in ["pil_image", "encoded_pil_image"]: pytest.xfail(f"{description} not fully supported without dspy.Image coercion") result = predictor(image=actual_input, class_labels=["dog", "cat", "bird"]) assert result.probabilities == expected["probabilities"] assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1 def test_predictor_save_load(sample_url, sample_pil_image): """Test saving and loading predictors with image fields""" signature = "image: dspy.Image -> caption: str" examples = [ dspy.Example(image=dspy.Image(sample_url), caption="Example 1"), dspy.Example(image=sample_pil_image, caption="Example 2"), ] predictor, lm = setup_predictor(signature, {"caption": "A golden retriever"}) optimizer = dspy.teleprompt.LabeledFewShot(k=1) compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False) with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".json") as temp_file: compiled_predictor.save(temp_file.name) loaded_predictor = dspy.Predict(signature) loaded_predictor.load(temp_file.name) loaded_predictor(image=dspy.Image("https://example.com/dog.jpg")) assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 2 assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"]) def test_save_load_complex_default_types(): """Test saving and loading predictors with complex 
default types (lists of images)""" examples = [ dspy.Example( image_list=[ dspy.Image("https://example.com/dog.jpg"), dspy.Image("https://example.com/cat.jpg"), ], caption="Example 1", ).with_inputs("image_list"), ] class ComplexTypeSignature(dspy.Signature): image_list: list[dspy.Image] = dspy.InputField(desc="A list of images") caption: str = dspy.OutputField(desc="A caption for the image list") predictor, lm = setup_predictor(ComplexTypeSignature, {"caption": "A list of images"}) optimizer = dspy.teleprompt.LabeledFewShot(k=1) compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False) with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".json") as temp_file: compiled_predictor.save(temp_file.name) loaded_predictor = dspy.Predict(ComplexTypeSignature) loaded_predictor.load(temp_file.name) result = loaded_predictor(**examples[0].inputs()) assert result.caption == "A list of images" assert str(lm.history[-1]["messages"]).count("'url'") == 4 assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"]) class BasicImageSignature(dspy.Signature): """Basic signature with a single image input""" image: dspy.Image = dspy.InputField() output: str = dspy.OutputField() class ImageListSignature(dspy.Signature): """Signature with a list of images input""" image_list: list[dspy.Image] = dspy.InputField() output: str = dspy.OutputField() @pytest.mark.parametrize( "test_case", [ { "name": "basic_dspy_signature", "signature_class": BasicImageSignature, "inputs": {"image": "https://example.com/dog.jpg"}, "expected": {"output": "A dog photo"}, "expected_image_urls": 2, }, { "name": "list_dspy_signature", "signature_class": ImageListSignature, "inputs": {"image_list": ["https://example.com/dog.jpg", "https://example.com/cat.jpg"]}, "expected": {"output": "Multiple photos"}, "expected_image_urls": 4, }, ], ) def test_save_load_complex_types(test_case): """Test saving and loading predictors with complex types""" signature_cls = test_case["signature_class"] # Convert string URLs to dspy.Image objects in input processed_input = {} for key, value in test_case["inputs"].items(): if isinstance(value, str) and "http" in value: processed_input[key] = dspy.Image(value) elif isinstance(value, list) and value and isinstance(value[0], str): processed_input[key] = [dspy.Image(url) for url in value] else: processed_input[key] = value # Create example and predictor examples = [dspy.Example(**processed_input, **test_case["expected"]).with_inputs(*processed_input.keys())] predictor, lm = setup_predictor(signature_cls, test_case["expected"]) optimizer = dspy.teleprompt.LabeledFewShot(k=1) compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False) # Test save and load with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".json") as temp_file: compiled_predictor.save(temp_file.name) loaded_predictor = dspy.Predict(signature_cls) loaded_predictor.load(temp_file.name) # Run prediction result = loaded_predictor(**processed_input) # Verify output matches expected for key, value in test_case["expected"].items(): assert getattr(result, key) == value # Verify correct number of image URLs in messages assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == test_case["expected_image_urls"] assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"]) def test_save_load_pydantic_model(): """Test saving and loading predictors with pydantic models""" class ImageModel(pydantic.BaseModel): image: dspy.Image image_list: 
list[dspy.Image] | None = None output: str class PydanticSignature(dspy.Signature): model_input: ImageModel = dspy.InputField() output: str = dspy.OutputField() # Create model instance model_input = ImageModel( image=dspy.Image("https://example.com/dog.jpg"), image_list=[dspy.Image("https://example.com/cat.jpg")], output="Multiple photos", ) # Create example and predictor examples = [dspy.Example(model_input=model_input, output="Multiple photos").with_inputs("model_input")] predictor, lm = setup_predictor(PydanticSignature, {"output": "Multiple photos"}) optimizer = dspy.teleprompt.LabeledFewShot(k=1) compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False) # Test save and load with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".json") as temp_file: compiled_predictor.save(temp_file.name) loaded_predictor = dspy.Predict(PydanticSignature) loaded_predictor.load(temp_file.name) # Run prediction result = loaded_predictor(model_input=model_input) # Verify output matches expected assert result.output == "Multiple photos" assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 4 assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"]) def test_optional_image_field(): """Test that optional image fields are not required""" class OptionalImageSignature(dspy.Signature): image: dspy.Image | None = dspy.InputField() output: str = dspy.OutputField() predictor, lm = setup_predictor(OptionalImageSignature, {"output": "Hello"}) result = predictor(image=None) assert result.output == "Hello" assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 0 def test_pdf_url_support(): """Test support for PDF files from URLs""" pdf_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf" # Create a dspy.Image object from the PDF URL with download=True pdf_image = dspy.Image(pdf_url, download=True) # The data URI should contain application/pdf in the MIME type assert "data:application/pdf" in pdf_image.url assert ";base64," in pdf_image.url # Test using it in a predictor class PDFSignature(dspy.Signature): document: dspy.Image = dspy.InputField(desc="A PDF document") summary: str = dspy.OutputField(desc="A summary of the PDF") predictor, lm = setup_predictor(PDFSignature, {"summary": "This is a dummy PDF"}) result = predictor(document=pdf_image) assert result.summary == "This is a dummy PDF" assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1 # Ensure the URL was properly expanded in messages messages_str = str(lm.history[-1]["messages"]) assert "application/pdf" in messages_str def test_different_mime_types(): """Test support for different file types and MIME type detection""" # Test with various file types file_urls = { "pdf": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf", "image": "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg", } expected_mime_types = { "pdf": "application/pdf", "image": "image/jpeg", } for file_type, url in file_urls.items(): # Download and encode encoded = encode_image(url, download_images=True) # Check for correct MIME type in the encoded data - using 'in' instead of startswith # to account for possible parameters in the MIME type assert f"data:{expected_mime_types[file_type]}" in encoded assert ";base64," in encoded def test_mime_type_from_response_headers(): """Test that MIME types from response headers are correctly used""" # This URL returns proper Content-Type header pdf_url = 
"https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf" # Make an actual request to get the content type from headers response = requests.get(pdf_url) expected_mime_type = response.headers.get("Content-Type", "") # Should be application/pdf or similar assert "pdf" in expected_mime_type.lower() # Encode with download to test MIME type from headers encoded = encode_image(pdf_url, download_images=True) # The encoded data should contain the correct MIME type assert "application/pdf" in encoded assert ";base64," in encoded def test_pdf_from_file(): """Test handling a PDF file from disk""" # Download a PDF to a temporary file pdf_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf" response = requests.get(pdf_url) response.raise_for_status() with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp_file: tmp_file.write(response.content) tmp_file_path = tmp_file.name try: # Create a dspy.Image from the file pdf_image = dspy.Image(tmp_file_path) # The constructor encodes the file into a data URI we can inspect directly assert "data:application/pdf" in pdf_image.url assert ";base64," in pdf_image.url # Test the image in a predictor class FilePDFSignature(dspy.Signature): document: dspy.Image = dspy.InputField(desc="A PDF document from file") summary: str = dspy.OutputField(desc="A summary of the PDF") predictor, lm = setup_predictor(FilePDFSignature, {"summary": "This is a PDF from file"}) result = predictor(document=pdf_image) assert result.summary == "This is a PDF from file" assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1 finally: # Clean up the temporary file try: os.unlink(tmp_file_path) except Exception: pass def test_image_repr(): """Test string representation of Image objects""" url_image = dspy.Image("https://example.com/dog.jpg") assert str(url_image) == ( "<<CUSTOM-TYPE-START-IDENTIFIER>>" '[{"type": "image_url", "image_url": {"url": "https://example.com/dog.jpg"}}]' "<<CUSTOM-TYPE-END-IDENTIFIER>>" ) assert repr(url_image) == "Image(url='https://example.com/dog.jpg')" sample_pil = PILImage.new("RGB", (60, 30), color="red") pil_image = dspy.Image(sample_pil) assert str(pil_image).startswith('<<CUSTOM-TYPE-START-IDENTIFIER>>[{"type": "image_url",') assert str(pil_image).endswith("<<CUSTOM-TYPE-END-IDENTIFIER>>") assert "base64" in str(pil_image) def test_from_methods_warn(tmp_path): """Deprecated from_* methods emit warnings""" tmp_file = tmp_path / "test.png" tmp_file.write_bytes(b"pngdata") with pytest.warns(DeprecationWarning): dspy.Image.from_url("https://example.com/dog.jpg") with pytest.warns(DeprecationWarning): dspy.Image.from_file(str(tmp_file)) sample_pil = PILImage.new("RGB", (10, 10), color="blue") with pytest.warns(DeprecationWarning): dspy.Image.from_PIL(sample_pil) def test_invalid_string_format(): """Test that invalid string formats raise a ValueError""" invalid_string = "this_is_not_a_url_or_file" # Should raise a ValueError and not pass the string through with pytest.raises(ValueError, match="Unrecognized") as warning_info: image = dspy.Image(invalid_string) def test_pil_image_with_download_parameter(): """Test behavior when PIL image is passed with download=True""" sample_pil = PILImage.new("RGB", (60, 30), color="red") # PIL image should be encoded regardless of download parameter image_no_download = dspy.Image(sample_pil) image_with_download = dspy.Image(sample_pil, download=True) # Both should result in base64 encoded data URIs assert image_no_download.url.startswith("data:") assert 
image_with_download.url.startswith("data:") assert "base64," in image_no_download.url assert "base64," in image_with_download.url # They should be identical since PIL images are always encoded assert image_no_download.url == image_with_download.url ``` -------------------------------------------------------------------------------- /dspy/clients/lm.py: -------------------------------------------------------------------------------- ```python import logging import os import re import threading import warnings from typing import Any, Literal, cast import litellm from anyio.streams.memory import MemoryObjectSendStream from asyncer import syncify import dspy from dspy.clients.cache import request_cache from dspy.clients.openai import OpenAIProvider from dspy.clients.provider import Provider, ReinforceJob, TrainingJob from dspy.clients.utils_finetune import TrainDataFormat from dspy.dsp.utils.settings import settings from dspy.utils.callback import BaseCallback from .base_lm import BaseLM logger = logging.getLogger(__name__) class LM(BaseLM): """ A language model supporting chat or text completion requests for use with DSPy modules. """ def __init__( self, model: str, model_type: Literal["chat", "text", "responses"] = "chat", temperature: float | None = None, max_tokens: int | None = None, cache: bool = True, callbacks: list[BaseCallback] | None = None, num_retries: int = 3, provider: Provider | None = None, finetuning_model: str | None = None, launch_kwargs: dict[str, Any] | None = None, train_kwargs: dict[str, Any] | None = None, use_developer_role: bool = False, **kwargs, ): """ Create a new language model instance for use with DSPy modules and programs. Args: model: The model to use. This should be a string of the form ``"llm_provider/llm_name"`` supported by LiteLLM. For example, ``"openai/gpt-4o"``. model_type: The type of the model, either ``"chat"`` or ``"text"``. temperature: The sampling temperature to use when generating responses. max_tokens: The maximum number of tokens to generate per response. cache: Whether to cache the model responses for reuse to improve performance and reduce costs. callbacks: A list of callback functions to run before and after each request. num_retries: The number of times to retry a request if it fails transiently due to network error, rate limiting, etc. Requests are retried with exponential backoff. provider: The provider to use. If not specified, the provider will be inferred from the model. finetuning_model: The model to finetune. In some providers, the models available for finetuning is different from the models available for inference. rollout_id: Optional integer used to differentiate cache entries for otherwise identical requests. Different values bypass DSPy's caches while still caching future calls with the same inputs and rollout ID. Note that `rollout_id` only affects generation when `temperature` is non-zero. This argument is stripped before sending requests to the provider. """ # Remember to update LM.copy() if you modify the constructor! 
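        # Store core configuration as attributes; sampling parameters (temperature, max_tokens, extra kwargs) are collected into self.kwargs below.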
self.model = model self.model_type = model_type self.cache = cache self.provider = provider or self.infer_provider() self.callbacks = callbacks or [] self.history = [] self.num_retries = num_retries self.finetuning_model = finetuning_model self.launch_kwargs = launch_kwargs or {} self.train_kwargs = train_kwargs or {} self.use_developer_role = use_developer_role self._warned_zero_temp_rollout = False # Handle model-specific configuration for different model families model_family = model.split("/")[-1].lower() if "/" in model else model.lower() # Recognize OpenAI reasoning models (o1, o3, o4, gpt-5 family) model_pattern = re.match(r"^(?:o[1345]|gpt-5)(?:-(?:mini|nano))?", model_family) if model_pattern: if (temperature and temperature != 1.0) or (max_tokens and max_tokens < 16000): raise ValueError( "OpenAI's reasoning models require passing temperature=1.0 or None and max_tokens >= 16000 or None to " "`dspy.LM(...)`, e.g., dspy.LM('openai/gpt-5', temperature=1.0, max_tokens=16000)" ) self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs) if self.kwargs.get("rollout_id") is None: self.kwargs.pop("rollout_id", None) else: self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) if self.kwargs.get("rollout_id") is None: self.kwargs.pop("rollout_id", None) self._warn_zero_temp_rollout(self.kwargs.get("temperature"), self.kwargs.get("rollout_id")) def _warn_zero_temp_rollout(self, temperature: float | None, rollout_id): if not self._warned_zero_temp_rollout and rollout_id is not None and (temperature is None or temperature == 0): warnings.warn( "rollout_id has no effect when temperature=0; set temperature>0 to bypass the cache.", stacklevel=3, ) self._warned_zero_temp_rollout = True def _get_cached_completion_fn(self, completion_fn, cache): ignored_args_for_cache_key = ["api_key", "api_base", "base_url"] if cache: completion_fn = request_cache( cache_arg_name="request", ignored_args_for_cache_key=ignored_args_for_cache_key, )(completion_fn) litellm_cache_args = {"no-cache": True, "no-store": True} return completion_fn, litellm_cache_args def forward(self, prompt=None, messages=None, **kwargs): # Build the request. kwargs = dict(kwargs) cache = kwargs.pop("cache", self.cache) messages = messages or [{"role": "user", "content": prompt}] if self.use_developer_role and self.model_type == "responses": messages = [{**m, "role": "developer"} if m.get("role") == "system" else m for m in messages] kwargs = {**self.kwargs, **kwargs} self._warn_zero_temp_rollout(kwargs.get("temperature"), kwargs.get("rollout_id")) if kwargs.get("rollout_id") is None: kwargs.pop("rollout_id", None) if self.model_type == "chat": completion = litellm_completion elif self.model_type == "text": completion = litellm_text_completion elif self.model_type == "responses": completion = litellm_responses_completion completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache) results = completion( request=dict(model=self.model, messages=messages, **kwargs), num_retries=self.num_retries, cache=litellm_cache_args, ) self._check_truncation(results) if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"): settings.usage_tracker.add_usage(self.model, dict(results.usage)) return results async def aforward(self, prompt=None, messages=None, **kwargs): # Build the request. 
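        # Async variant of forward(): same request construction, but awaits the async LiteLLM completion functions.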
kwargs = dict(kwargs) cache = kwargs.pop("cache", self.cache) messages = messages or [{"role": "user", "content": prompt}] if self.use_developer_role and self.model_type == "responses": messages = [{**m, "role": "developer"} if m.get("role") == "system" else m for m in messages] kwargs = {**self.kwargs, **kwargs} self._warn_zero_temp_rollout(kwargs.get("temperature"), kwargs.get("rollout_id")) if kwargs.get("rollout_id") is None: kwargs.pop("rollout_id", None) if self.model_type == "chat": completion = alitellm_completion elif self.model_type == "text": completion = alitellm_text_completion elif self.model_type == "responses": completion = alitellm_responses_completion completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache) results = await completion( request=dict(model=self.model, messages=messages, **kwargs), num_retries=self.num_retries, cache=litellm_cache_args, ) self._check_truncation(results) if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"): settings.usage_tracker.add_usage(self.model, dict(results.usage)) return results def launch(self, launch_kwargs: dict[str, Any] | None = None): self.provider.launch(self, launch_kwargs) def kill(self, launch_kwargs: dict[str, Any] | None = None): self.provider.kill(self, launch_kwargs) def finetune( self, train_data: list[dict[str, Any]], train_data_format: TrainDataFormat | None, train_kwargs: dict[str, Any] | None = None, ) -> TrainingJob: from dspy import settings as settings if not self.provider.finetunable: raise ValueError( f"Provider {self.provider} does not support fine-tuning, please specify your provider by explicitly " "setting `provider` when creating the `dspy.LM` instance. For example, " "`dspy.LM('openai/gpt-4.1-mini-2025-04-14', provider=dspy.OpenAIProvider())`." ) def thread_function_wrapper(): return self._run_finetune_job(job) thread = threading.Thread(target=thread_function_wrapper) train_kwargs = train_kwargs or self.train_kwargs model_to_finetune = self.finetuning_model or self.model job = self.provider.TrainingJob( thread=thread, model=model_to_finetune, train_data=train_data, train_data_format=train_data_format, train_kwargs=train_kwargs, ) thread.start() return job def reinforce( self, train_kwargs ) -> ReinforceJob: # TODO(GRPO Team): Should we return an initialized job here? from dspy import settings as settings err = f"Provider {self.provider} does not implement the reinforcement learning interface." assert self.provider.reinforceable, err job = self.provider.ReinforceJob(lm=self, train_kwargs=train_kwargs) job.initialize() return job def _run_finetune_job(self, job: TrainingJob): # TODO(enhance): We should listen for keyboard interrupts somewhere. # Requires TrainingJob.cancel() to be implemented for each provider. 
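        # Runs on the worker thread started by finetune(); the job result is either the finetuned LM copy or the raised exception.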
try: model = self.provider.finetune( job=job, model=job.model, train_data=job.train_data, train_data_format=job.train_data_format, train_kwargs=job.train_kwargs, ) lm = self.copy(model=model) job.set_result(lm) except Exception as err: logger.error(err) job.set_result(err) def infer_provider(self) -> Provider: if OpenAIProvider.is_provider_model(self.model): return OpenAIProvider() return Provider() def dump_state(self): state_keys = [ "model", "model_type", "cache", "num_retries", "finetuning_model", "launch_kwargs", "train_kwargs", ] return {key: getattr(self, key) for key in state_keys} | self.kwargs def _check_truncation(self, results): if self.model_type != "responses" and any(c.finish_reason == "length" for c in results["choices"]): logger.warning( f"LM response was truncated due to exceeding max_tokens={self.kwargs['max_tokens']}. " "You can inspect the latest LM interactions with `dspy.inspect_history()`. " "To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. " f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) " " if the reason for truncation is repetition." ) def _get_stream_completion_fn( request: dict[str, Any], cache_kwargs: dict[str, Any], sync=True, ): stream = dspy.settings.send_stream caller_predict = dspy.settings.caller_predict if stream is None: return None # The stream is already opened, and will be closed by the caller. stream = cast(MemoryObjectSendStream, stream) caller_predict_id = id(caller_predict) if caller_predict else None if dspy.settings.track_usage: request["stream_options"] = {"include_usage": True} async def stream_completion(request: dict[str, Any], cache_kwargs: dict[str, Any]): headers = request.pop("headers", None) response = await litellm.acompletion( cache=cache_kwargs, stream=True, headers=_get_headers(headers), **request, ) chunks = [] async for chunk in response: if caller_predict_id: # Add the predict id to the chunk so that the stream listener can identify which predict produces it. chunk.predict_id = caller_predict_id chunks.append(chunk) await stream.send(chunk) return litellm.stream_chunk_builder(chunks) def sync_stream_completion(): syncified_stream_completion = syncify(stream_completion) return syncified_stream_completion(request, cache_kwargs) async def async_stream_completion(): return await stream_completion(request, cache_kwargs) if sync: return sync_stream_completion else: return async_stream_completion def litellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): cache = cache or {"no-cache": True, "no-store": True} request = dict(request) request.pop("rollout_id", None) headers = request.pop("headers", None) stream_completion = _get_stream_completion_fn(request, cache, sync=True) if stream_completion is None: return litellm.completion( cache=cache, num_retries=num_retries, retry_strategy="exponential_backoff_retry", headers=_get_headers(headers), **request, ) return stream_completion() def litellm_text_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): cache = cache or {"no-cache": True, "no-store": True} request = dict(request) request.pop("rollout_id", None) headers = request.pop("headers", None) # Extract the provider and model from the model string. 
# TODO: Not all the models are in the format of "provider/model" model = request.pop("model").split("/", 1) provider, model = model[0] if len(model) > 1 else "openai", model[-1] # Use the API key and base from the request, or from the environment. api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE") # Build the prompt from the messages. prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"]) return litellm.text_completion( cache=cache, model=f"text-completion-openai/{model}", api_key=api_key, api_base=api_base, prompt=prompt, num_retries=num_retries, retry_strategy="exponential_backoff_retry", headers=_get_headers(headers), **request, ) async def alitellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): cache = cache or {"no-cache": True, "no-store": True} request = dict(request) request.pop("rollout_id", None) headers = request.pop("headers", None) stream_completion = _get_stream_completion_fn(request, cache, sync=False) if stream_completion is None: return await litellm.acompletion( cache=cache, num_retries=num_retries, retry_strategy="exponential_backoff_retry", headers=_get_headers(headers), **request, ) return await stream_completion() async def alitellm_text_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): cache = cache or {"no-cache": True, "no-store": True} request = dict(request) request.pop("rollout_id", None) model = request.pop("model").split("/", 1) headers = request.pop("headers", None) provider, model = model[0] if len(model) > 1 else "openai", model[-1] # Use the API key and base from the request, or from the environment. api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE") # Build the prompt from the messages. 
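    # Text-completion endpoints take a single prompt string, so the chat messages are joined and terminated with a "BEGIN RESPONSE:" marker.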
prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"]) return await litellm.atext_completion( cache=cache, model=f"text-completion-openai/{model}", api_key=api_key, api_base=api_base, prompt=prompt, num_retries=num_retries, retry_strategy="exponential_backoff_retry", headers=_get_headers(headers), **request, ) def litellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): cache = cache or {"no-cache": True, "no-store": True} request = dict(request) request.pop("rollout_id", None) headers = request.pop("headers", None) request = _convert_chat_request_to_responses_request(request) return litellm.responses( cache=cache, num_retries=num_retries, retry_strategy="exponential_backoff_retry", headers=_get_headers(headers), **request, ) async def alitellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None): cache = cache or {"no-cache": True, "no-store": True} request = dict(request) request.pop("rollout_id", None) headers = request.pop("headers", None) request = _convert_chat_request_to_responses_request(request) return await litellm.aresponses( cache=cache, num_retries=num_retries, retry_strategy="exponential_backoff_retry", headers=_get_headers(headers), **request, ) def _convert_chat_request_to_responses_request(request: dict[str, Any]): request = dict(request) if "messages" in request: content_blocks = [] for msg in request.pop("messages"): c = msg.get("content") if isinstance(c, str): content_blocks.append({"type": "input_text", "text": c}) elif isinstance(c, list): content_blocks.extend(c) request["input"] = [{"role": msg.get("role", "user"), "content": content_blocks}] # Convert `response_format` to `text.format` for Responses API if "response_format" in request: response_format = request.pop("response_format") text = request.pop("text", {}) request["text"] = {**text, "format": response_format} return request def _get_headers(headers: dict[str, Any] | None = None): headers = headers or {} return { "User-Agent": f"DSPy/{dspy.__version__}", **headers, } ``` -------------------------------------------------------------------------------- /tests/adapters/test_baml_adapter.py: -------------------------------------------------------------------------------- ```python from typing import Literal from unittest import mock import pydantic import pytest from litellm import Choices, Message from litellm.files.main import ModelResponse import dspy from dspy.adapters.baml_adapter import COMMENT_SYMBOL, BAMLAdapter # Test fixtures - Pydantic models for testing class PatientAddress(pydantic.BaseModel): street: str city: str country: Literal["US", "CA"] class PatientDetails(pydantic.BaseModel): name: str = pydantic.Field(description="Full name of the patient") age: int address: PatientAddress | None = None class ComplexNestedModel(pydantic.BaseModel): id: int = pydantic.Field(description="Unique identifier") details: PatientDetails tags: list[str] = pydantic.Field(default_factory=list) metadata: dict[str, str] = pydantic.Field(default_factory=dict) class ModelWithLists(pydantic.BaseModel): items: list[PatientAddress] = pydantic.Field(description="List of patient addresses") scores: list[float] class ImageWrapper(pydantic.BaseModel): images: list[dspy.Image] tag: list[str] class CircularModel(pydantic.BaseModel): name: str field: "CircularModel" def test_baml_adapter_basic_schema_generation(): """Test that BAMLAdapter generates simplified schemas for Pydantic models.""" 
class TestSignature(dspy.Signature): question: str = dspy.InputField() patient: PatientDetails = dspy.OutputField() adapter = BAMLAdapter() schema = adapter.format_field_structure(TestSignature) # Should contain simplified schema with comments assert f"{COMMENT_SYMBOL} Full name of the patient" in schema assert "name: string," in schema assert "age: int," in schema assert "address:" in schema assert "street: string," in schema assert 'country: "US" or "CA",' in schema def test_baml_adapter_handles_optional_fields(): """Test optional field rendering with 'or null' syntax.""" class TestSignature(dspy.Signature): input: str = dspy.InputField() patient: PatientDetails = dspy.OutputField() adapter = BAMLAdapter() schema = adapter.format_field_structure(TestSignature) # Optional address field should show 'or null' assert "address:" in schema assert "or null" in schema def test_baml_adapter_handles_primitive_types(): """Test rendering of basic primitive types.""" class SimpleModel(pydantic.BaseModel): text: str number: int decimal: float flag: bool class TestSignature(dspy.Signature): input: str = dspy.InputField() output: SimpleModel = dspy.OutputField() adapter = BAMLAdapter() schema = adapter.format_field_structure(TestSignature) assert "text: string," in schema assert "number: int," in schema assert "decimal: float," in schema assert "flag: boolean," in schema def test_baml_adapter_handles_lists_with_bracket_notation(): """Test that lists of Pydantic models use proper bracket notation.""" class TestSignature(dspy.Signature): input: str = dspy.InputField() addresses: ModelWithLists = dspy.OutputField() adapter = BAMLAdapter() schema = adapter.format_field_structure(TestSignature) # Should use bracket notation for lists and include comments assert "items: [" in schema assert f"{COMMENT_SYMBOL} List of patient addresses" in schema assert "street: string," in schema assert "city: string," in schema assert "]," in schema assert "scores: float[]," in schema def test_baml_adapter_handles_complex_nested_models(): """Test deeply nested Pydantic model schema generation.""" class TestSignature(dspy.Signature): input: str = dspy.InputField() complex: ComplexNestedModel = dspy.OutputField() adapter = BAMLAdapter() schema = adapter.format_field_structure(TestSignature) # Should include nested structure with comments assert f"{COMMENT_SYMBOL} Unique identifier" in schema assert "details:" in schema assert f"{COMMENT_SYMBOL} Full name of the patient" in schema assert "tags: string[]," in schema assert "metadata: dict[string, string]," in schema def test_baml_adapter_raise_error_on_circular_references(): """Test that circular references are handled gracefully.""" class TestSignature(dspy.Signature): input: str = dspy.InputField() circular: CircularModel = dspy.OutputField() adapter = BAMLAdapter() with pytest.raises(ValueError) as error: adapter.format_field_structure(TestSignature) assert "BAMLAdapter cannot handle recursive pydantic models" in str(error.value) def test_baml_adapter_formats_pydantic_inputs_as_clean_json(): """Test that Pydantic input instances are formatted as clean JSON.""" class TestSignature(dspy.Signature): patient: PatientDetails = dspy.InputField() question: str = dspy.InputField() answer: str = dspy.OutputField() adapter = BAMLAdapter() patient = PatientDetails( name="John Doe", age=45, address=PatientAddress(street="123 Main St", city="Anytown", country="US") ) messages = adapter.format(TestSignature, [], {"patient": patient, "question": "What is the diagnosis?"}) # Should have 
clean, indented JSON for Pydantic input user_message = messages[-1]["content"] assert '"name": "John Doe"' in user_message assert '"age": 45' in user_message assert '"street": "123 Main St"' in user_message assert '"country": "US"' in user_message def test_baml_adapter_handles_mixed_input_types(): """Test formatting of mixed Pydantic and primitive inputs.""" class TestSignature(dspy.Signature): patient: PatientDetails = dspy.InputField() priority: int = dspy.InputField() notes: str = dspy.InputField() result: str = dspy.OutputField() adapter = BAMLAdapter() patient = PatientDetails(name="Jane Doe", age=30) messages = adapter.format(TestSignature, [], {"patient": patient, "priority": 1, "notes": "Urgent case"}) user_message = messages[-1]["content"] # Pydantic should be JSON formatted assert '"name": "Jane Doe"' in user_message # Primitives should be formatted normally assert "priority ## ]]\n1" in user_message assert "notes ## ]]\nUrgent case" in user_message def test_baml_adapter_handles_schema_generation_errors_gracefully(): """Test graceful handling of schema generation errors.""" class ProblematicModel(pydantic.BaseModel): # This might cause issues in schema generation field: object class TestSignature(dspy.Signature): input: str = dspy.InputField() output: ProblematicModel = dspy.OutputField() adapter = BAMLAdapter() # Should not raise an exception try: schema = adapter.format_field_structure(TestSignature) # If no exception, schema should at least contain some basic structure assert "schema" in schema.lower() except Exception: # If exception occurs, test passes as we're testing graceful handling pass def test_baml_adapter_raises_on_missing_fields(): """Test that missing required fields raise appropriate errors.""" class TestSignature(dspy.Signature): input: str = dspy.InputField() patient: PatientDetails = dspy.OutputField() summary: str = dspy.OutputField() adapter = BAMLAdapter() # Missing 'summary' field completion = '{"patient": {"name": "John", "age": 30}}' with pytest.raises(dspy.utils.exceptions.AdapterParseError) as e: adapter.parse(TestSignature, completion) assert e.value.adapter_name == "JSONAdapter" # BAMLAdapter inherits from JSONAdapter assert "summary" in str(e.value) def test_baml_adapter_handles_type_casting_errors(): """Test graceful handling of type casting errors.""" class TestSignature(dspy.Signature): input: str = dspy.InputField() patient: PatientDetails = dspy.OutputField() adapter = BAMLAdapter() # Invalid age type completion = '{"patient": {"name": "John", "age": "not_a_number"}}' # Should raise ValidationError from Pydantic (which is the expected behavior) with pytest.raises((dspy.utils.exceptions.AdapterParseError, pydantic.ValidationError)): adapter.parse(TestSignature, completion) def test_baml_adapter_with_images(): """Test BAMLAdapter integration with dspy.Image objects.""" class TestSignature(dspy.Signature): image_data: ImageWrapper = dspy.InputField() description: str = dspy.OutputField() adapter = BAMLAdapter() image_wrapper = ImageWrapper( images=[dspy.Image(url="https://example.com/image1.jpg"), dspy.Image(url="https://example.com/image2.jpg")], tag=["test", "medical"], ) messages = adapter.format(TestSignature, [], {"image_data": image_wrapper}) # Should contain image URLs in the message content user_message = messages[-1]["content"] image_contents = [ content for content in user_message if isinstance(content, dict) and content.get("type") == "image_url" ] assert len(image_contents) == 2 assert {"type": "image_url", "image_url": {"url": 
"https://example.com/image1.jpg"}} in user_message assert {"type": "image_url", "image_url": {"url": "https://example.com/image2.jpg"}} in user_message def test_baml_adapter_with_tools(): """Test BAMLAdapter integration with dspy.Tool objects.""" class TestSignature(dspy.Signature): question: str = dspy.InputField() tools: list[dspy.Tool] = dspy.InputField() answer: str = dspy.OutputField() def get_patient_info(patient_id: int) -> str: """Get patient information by ID""" return f"Patient info for ID {patient_id}" def schedule_appointment(patient_name: str, date: str) -> str: """Schedule an appointment for a patient""" return f"Scheduled appointment for {patient_name} on {date}" tools = [dspy.Tool(get_patient_info), dspy.Tool(schedule_appointment)] adapter = BAMLAdapter() messages = adapter.format(TestSignature, [], {"question": "Schedule an appointment for John", "tools": tools}) user_message = messages[-1]["content"] assert "get_patient_info" in user_message assert "schedule_appointment" in user_message assert "Get patient information by ID" in user_message assert "Schedule an appointment for a patient" in user_message def test_baml_adapter_with_code(): """Test BAMLAdapter integration with dspy.Code objects.""" # Test with code as input field class CodeAnalysisSignature(dspy.Signature): code: dspy.Code = dspy.InputField() analysis: str = dspy.OutputField() adapter = BAMLAdapter() messages = adapter.format(CodeAnalysisSignature, [], {"code": "def hello():\n print('Hello, world!')"}) user_message = messages[-1]["content"] assert "def hello():" in user_message assert "print('Hello, world!')" in user_message # Test with code as output field class CodeGenSignature(dspy.Signature): task: str = dspy.InputField() code: dspy.Code = dspy.OutputField() with mock.patch("litellm.completion") as mock_completion: mock_completion.return_value = ModelResponse( choices=[Choices(message=Message(content='{"code": "print(\\"Generated code\\")"}'))], model="openai/gpt-4o-mini", ) result = adapter( dspy.LM(model="openai/gpt-4o-mini", cache=False), {}, CodeGenSignature, [], {"task": "Write a hello world program"}, ) assert result[0]["code"].code == 'print("Generated code")' def test_baml_adapter_with_conversation_history(): """Test BAMLAdapter integration with dspy.History objects.""" class TestSignature(dspy.Signature): history: dspy.History = dspy.InputField() question: str = dspy.InputField() answer: str = dspy.OutputField() history = dspy.History( messages=[ {"question": "What is the patient's age?", "answer": "45 years old"}, {"question": "Any allergies?", "answer": "Penicillin allergy"}, ] ) adapter = BAMLAdapter() messages = adapter.format(TestSignature, [], {"history": history, "question": "What medications should we avoid?"}) # Should format history as separate messages assert len(messages) == 6 # system + 2 history pairs + user assert "What is the patient's age?" in messages[1]["content"] assert '"answer": "45 years old"' in messages[2]["content"] assert "Any allergies?" 
in messages[3]["content"] assert '"answer": "Penicillin allergy"' in messages[4]["content"] # Comparison tests with JSONAdapter def test_baml_vs_json_adapter_token_efficiency(): """Test that BAMLAdapter generates more token-efficient schemas.""" class TestSignature(dspy.Signature): input: str = dspy.InputField() complex: ComplexNestedModel = dspy.OutputField() baml_adapter = BAMLAdapter() json_adapter = dspy.JSONAdapter() baml_schema = baml_adapter.format_field_structure(TestSignature) json_schema = json_adapter.format_field_structure(TestSignature) # Simple character count as proxy for token efficiency # BAMLAdapter should always produce shorter schemas assert len(baml_schema) < len(json_schema) def test_baml_vs_json_adapter_functional_compatibility(): """Test that both adapters parse identical outputs to the same results.""" class TestSignature(dspy.Signature): question: str = dspy.InputField() patient: PatientDetails = dspy.OutputField() baml_adapter = BAMLAdapter() json_adapter = dspy.JSONAdapter() completion = """{"patient": { "name": "Alice Brown", "age": 35, "address": {"street": "789 Pine St", "city": "Boston", "country": "US"} }}""" baml_result = baml_adapter.parse(TestSignature, completion) json_result = json_adapter.parse(TestSignature, completion) # Results should be functionally equivalent assert baml_result["patient"].name == json_result["patient"].name assert baml_result["patient"].age == json_result["patient"].age assert baml_result["patient"].address.street == json_result["patient"].address.street @pytest.mark.asyncio async def test_baml_adapter_async_functionality(): """Test BAMLAdapter async operations.""" class TestSignature(dspy.Signature): question: str = dspy.InputField() patient: PatientDetails = dspy.OutputField() with mock.patch("litellm.acompletion") as mock_acompletion: mock_acompletion.return_value = ModelResponse( choices=[Choices(message=Message(content='{"patient": {"name": "John Doe", "age": 28}}'))], model="openai/gpt-4o", ) adapter = BAMLAdapter() result = await adapter.acall( dspy.LM(model="openai/gpt-4o", cache=False), {}, TestSignature, [], {"question": "Extract patient info"} ) assert result[0]["patient"].name == "John Doe" assert result[0]["patient"].age == 28 def test_baml_adapter_with_field_aliases(): """Test BAMLAdapter with Pydantic field aliases.""" class ModelWithAliases(pydantic.BaseModel): full_name: str = pydantic.Field(alias="name") patient_age: int = pydantic.Field(alias="age") class TestSignature(dspy.Signature): input: str = dspy.InputField() data: ModelWithAliases = dspy.OutputField() adapter = BAMLAdapter() # Schema should show aliases in the output structure schema = adapter.format_field_structure(TestSignature) assert "name:" in schema # Should use alias, not field name assert "age:" in schema # Should use alias, not field name def test_baml_adapter_field_alias_without_description(): """Test BAMLAdapter with field alias present but description absent.""" class ModelWithAliasNoDescription(pydantic.BaseModel): internal_field: str = pydantic.Field(alias="public_name") regular_field: int field_with_description: str = pydantic.Field(description="This field has a description", alias="desc_field") class TestSignature(dspy.Signature): input: str = dspy.InputField() data: ModelWithAliasNoDescription = dspy.OutputField() adapter = BAMLAdapter() schema = adapter.format_field_structure(TestSignature) # Should show alias as comment when description is absent assert f"{COMMENT_SYMBOL} alias: public_name" in schema # Should show description comment 
when present assert f"{COMMENT_SYMBOL} This field has a description" in schema # Regular field (without alias) should appear in schema but without alias comment assert "regular_field: int," in schema # Check that regular_field section doesn't have an alias comment regular_field_section = schema.split("regular_field: int,")[0].split("\n")[-1] assert f"{COMMENT_SYMBOL} alias:" not in regular_field_section def test_baml_adapter_multiple_pydantic_input_fields(): """Test that multiple InputField() with Pydantic models are rendered correctly.""" class UserProfile(pydantic.BaseModel): name: str = pydantic.Field(description="User's full name") email: str age: int class SystemConfig(pydantic.BaseModel): timeout: int = pydantic.Field(description="Timeout in seconds") debug: bool endpoints: list[str] class TestSignature(dspy.Signature): input_1: UserProfile = dspy.InputField() input_2: SystemConfig = dspy.InputField() result: str = dspy.OutputField() adapter = BAMLAdapter() # Test schema generation includes headers for ALL input fields schema = adapter.format_field_structure(TestSignature) assert "[[ ## input_1 ## ]]" in schema # Should include first input field header assert "[[ ## input_2 ## ]]" in schema # Should include second input field header assert "[[ ## result ## ]]" in schema # Should include output field header assert "[[ ## completed ## ]]" in schema # Should include completed section assert "All interactions will be structured in the following way" in schema assert "{input_1}" in schema assert "{input_2}" in schema assert "Output field `result` should be of type: string" in schema # Test field descriptions are in the correct method field_desc = adapter.format_field_description(TestSignature) assert "Your input fields are:" in field_desc assert "Your output fields are:" in field_desc # Test message formatting with actual Pydantic instances user_profile = UserProfile(name="John Doe", email="[email protected]", age=30) system_config = SystemConfig(timeout=300, debug=True, endpoints=["api1", "api2"]) messages = adapter.format(TestSignature, [], {"input_1": user_profile, "input_2": system_config}) user_message = messages[-1]["content"] # Verify both inputs are rendered with the correct bracket notation assert "[[ ## input_1 ## ]]" in user_message assert "[[ ## input_2 ## ]]" in user_message # Verify JSON content for both inputs assert '"name": "John Doe"' in user_message assert '"email": "[email protected]"' in user_message assert '"age": 30' in user_message assert '"timeout": 300' in user_message assert '"debug": true' in user_message # Endpoints array is formatted with indentation, so check for individual elements assert '"api1"' in user_message assert '"api2"' in user_message assert '"endpoints":' in user_message ``` -------------------------------------------------------------------------------- /tests/clients/test_lm.py: -------------------------------------------------------------------------------- ```python import json import time import warnings from unittest import mock from unittest.mock import patch import litellm import pydantic import pytest from litellm.types.llms.openai import ResponseAPIUsage, ResponsesAPIResponse from litellm.utils import Choices, Message, ModelResponse from openai import RateLimitError from openai.types.responses import ResponseOutputMessage, ResponseReasoningItem from openai.types.responses.response_reasoning_item import Summary import dspy from dspy.utils.dummies import DummyLM from dspy.utils.usage_tracker import track_usage def 
make_response(output_blocks): return ResponsesAPIResponse( id="resp_1", created_at=0.0, error=None, incomplete_details=None, instructions=None, model="openai/dspy-test-model", object="response", output=output_blocks, metadata = {}, parallel_tool_calls=False, temperature=1.0, tool_choice="auto", tools=[], top_p=1.0, max_output_tokens=None, previous_response_id=None, reasoning=None, status="completed", text=None, truncation="disabled", usage=ResponseAPIUsage(input_tokens=1, output_tokens=1, total_tokens=2), user=None, ) def test_chat_lms_can_be_queried(litellm_test_server): api_base, _ = litellm_test_server expected_response = ["Hi!"] openai_lm = dspy.LM( model="openai/dspy-test-model", api_base=api_base, api_key="fakekey", model_type="chat", ) assert openai_lm("openai query") == expected_response azure_openai_lm = dspy.LM( model="azure/dspy-test-model", api_base=api_base, api_key="fakekey", model_type="chat", ) assert azure_openai_lm("azure openai query") == expected_response def test_dspy_cache(litellm_test_server, tmp_path): api_base, _ = litellm_test_server original_cache = dspy.cache dspy.clients.configure_cache( enable_disk_cache=True, enable_memory_cache=True, disk_cache_dir=tmp_path / ".disk_cache", ) cache = dspy.cache lm = dspy.LM( model="openai/dspy-test-model", api_base=api_base, api_key="fakekey", model_type="text", ) with track_usage() as usage_tracker: lm("Query") assert len(cache.memory_cache) == 1 cache_key = next(iter(cache.memory_cache.keys())) assert cache_key in cache.disk_cache assert len(usage_tracker.usage_data) == 1 with track_usage() as usage_tracker: lm("Query") assert len(usage_tracker.usage_data) == 0 dspy.cache = original_cache def test_disabled_cache_skips_cache_key(monkeypatch): original_cache = dspy.cache dspy.configure_cache(enable_disk_cache=False, enable_memory_cache=False) cache = dspy.cache try: with mock.patch.object(cache, "cache_key", wraps=cache.cache_key) as cache_key_spy, \ mock.patch.object(cache, "get", wraps=cache.get) as cache_get_spy, \ mock.patch.object(cache, "put", wraps=cache.put) as cache_put_spy: def fake_completion(*, cache, num_retries, retry_strategy, **request): return ModelResponse( choices=[Choices(message=Message(role="assistant", content="Hi!"))], usage={"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, model="dummy", ) monkeypatch.setattr(litellm, "completion", fake_completion) dummy_lm = DummyLM([{"answer": "ignored"}]) # TODO(isaacbmiller): Change from dummy_lm.forward to just dummy_lm.__call__ #8864 dummy_lm.forward(messages=[{"role": "user", "content": "Hello"}]) cache_key_spy.assert_not_called() cache_get_spy.assert_called_once() cache_put_spy.assert_called_once() finally: dspy.cache = original_cache def test_rollout_id_bypasses_cache(monkeypatch, tmp_path): calls: list[dict] = [] def fake_completion(*, cache, num_retries, retry_strategy, **request): calls.append(request) return ModelResponse( choices=[Choices(message=Message(role="assistant", content="Hi!"))], usage={"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, model="openai/dspy-test-model", ) monkeypatch.setattr(litellm, "completion", fake_completion) original_cache = dspy.cache dspy.clients.configure_cache( enable_disk_cache=True, enable_memory_cache=True, disk_cache_dir=tmp_path / ".disk_cache", ) lm = dspy.LM(model="openai/dspy-test-model", model_type="chat") with track_usage() as usage_tracker: lm(messages=[{"role": "user", "content": "Query"}], rollout_id=1) assert len(usage_tracker.usage_data) == 1 with track_usage() as 
usage_tracker: lm(messages=[{"role": "user", "content": "Query"}], rollout_id=1) assert len(usage_tracker.usage_data) == 0 with track_usage() as usage_tracker: lm(messages=[{"role": "user", "content": "Query"}], rollout_id=2) assert len(usage_tracker.usage_data) == 1 with track_usage() as usage_tracker: lm(messages=[{"role": "user", "content": "NoRID"}]) assert len(usage_tracker.usage_data) == 1 with track_usage() as usage_tracker: lm(messages=[{"role": "user", "content": "NoRID"}], rollout_id=None) assert len(usage_tracker.usage_data) == 0 assert len(dspy.cache.memory_cache) == 3 assert all("rollout_id" not in r for r in calls) dspy.cache = original_cache def test_zero_temperature_rollout_warns_once(monkeypatch): def fake_completion(*, cache, num_retries, retry_strategy, **request): return ModelResponse( choices=[Choices(message=Message(role="assistant", content="Hi!"))], usage={"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, model="openai/dspy-test-model", ) monkeypatch.setattr(litellm, "completion", fake_completion) lm = dspy.LM(model="openai/dspy-test-model", model_type="chat") with pytest.warns(UserWarning, match="rollout_id has no effect"): lm("Query", rollout_id=1) with warnings.catch_warnings(record=True) as record: warnings.simplefilter("always") lm("Query", rollout_id=2) assert len(record) == 0 def test_text_lms_can_be_queried(litellm_test_server): api_base, _ = litellm_test_server expected_response = ["Hi!"] openai_lm = dspy.LM( model="openai/dspy-test-model", api_base=api_base, api_key="fakekey", model_type="text", ) assert openai_lm("openai query") == expected_response azure_openai_lm = dspy.LM( model="azure/dspy-test-model", api_base=api_base, api_key="fakekey", model_type="text", ) assert azure_openai_lm("azure openai query") == expected_response def test_lm_calls_support_callables(litellm_test_server): api_base, _ = litellm_test_server with mock.patch("litellm.completion", autospec=True, wraps=litellm.completion) as spy_completion: def azure_ad_token_provider(*args, **kwargs): return None lm_with_callable = dspy.LM( model="openai/dspy-test-model", api_base=api_base, api_key="fakekey", azure_ad_token_provider=azure_ad_token_provider, cache=False, ) lm_with_callable("Query") spy_completion.assert_called_once() call_args = spy_completion.call_args.kwargs assert call_args["model"] == "openai/dspy-test-model" assert call_args["api_base"] == api_base assert call_args["api_key"] == "fakekey" assert call_args["azure_ad_token_provider"] is azure_ad_token_provider def test_lm_calls_support_pydantic_models(litellm_test_server): api_base, _ = litellm_test_server class ResponseFormat(pydantic.BaseModel): response: str lm = dspy.LM( model="openai/dspy-test-model", api_base=api_base, api_key="fakekey", response_format=ResponseFormat, ) lm("Query") def test_retry_number_set_correctly(): lm = dspy.LM("openai/gpt-4o-mini", num_retries=3) with mock.patch("litellm.completion") as mock_completion: lm("query") assert mock_completion.call_args.kwargs["num_retries"] == 3 def test_retry_made_on_system_errors(): retry_tracking = [0] # Using a list to track retries def mock_create(*args, **kwargs): retry_tracking[0] += 1 # These fields are called during the error handling mock_response = mock.Mock() mock_response.headers = {} mock_response.status_code = 429 raise RateLimitError(response=mock_response, message="message", body="error") lm = dspy.LM(model="openai/gpt-4o-mini", max_tokens=250, num_retries=3) with mock.patch.object(litellm.OpenAIChatCompletion, "completion", 
side_effect=mock_create): with pytest.raises(RateLimitError): lm("question") assert retry_tracking[0] == 4 def test_reasoning_model_token_parameter(): test_cases = [ ("openai/o1", True), ("openai/o1-mini", True), ("openai/o1-2023-01-01", True), ("openai/o3", True), ("openai/o3-mini-2023-01-01", True), ("openai/gpt-5", True), ("openai/gpt-5-mini", True), ("openai/gpt-5-nano", True), ("openai/gpt-4", False), ("anthropic/claude-2", False), ] for model_name, is_reasoning_model in test_cases: lm = dspy.LM( model=model_name, temperature=1.0 if is_reasoning_model else 0.7, max_tokens=16_000 if is_reasoning_model else 1000, ) if is_reasoning_model: assert "max_completion_tokens" in lm.kwargs assert "max_tokens" not in lm.kwargs assert lm.kwargs["max_completion_tokens"] == 16_000 else: assert "max_completion_tokens" not in lm.kwargs assert "max_tokens" in lm.kwargs assert lm.kwargs["max_tokens"] == 1000 @pytest.mark.parametrize("model_name", ["openai/o1", "openai/gpt-5-nano"]) def test_reasoning_model_requirements(model_name): # Should raise assertion error if temperature or max_tokens requirements not met with pytest.raises( ValueError, match="reasoning models require passing temperature=1.0 or None and max_tokens >= 16000 or None", ): dspy.LM( model=model_name, temperature=0.7, # Should be 1.0 max_tokens=1000, # Should be >= 16_000 ) # Should pass with correct parameters lm = dspy.LM( model=model_name, temperature=1.0, max_tokens=16_000, ) assert lm.kwargs["max_completion_tokens"] == 16_000 # Should pass with no parameters lm = dspy.LM( model=model_name, ) assert lm.kwargs["temperature"] == None assert lm.kwargs["max_completion_tokens"] == None def test_dump_state(): lm = dspy.LM( model="openai/gpt-4o-mini", model_type="chat", temperature=1, max_tokens=100, num_retries=10, launch_kwargs={"temperature": 1}, train_kwargs={"temperature": 5}, ) assert lm.dump_state() == { "model": "openai/gpt-4o-mini", "model_type": "chat", "temperature": 1, "max_tokens": 100, "num_retries": 10, "cache": True, "finetuning_model": None, "launch_kwargs": {"temperature": 1}, "train_kwargs": {"temperature": 5}, } def test_exponential_backoff_retry(): time_counter = [] def mock_create(*args, **kwargs): time_counter.append(time.time()) # These fields are called during the error handling mock_response = mock.Mock() mock_response.headers = {} mock_response.status_code = 429 raise RateLimitError(response=mock_response, message="message", body="error") lm = dspy.LM(model="openai/gpt-3.5-turbo", max_tokens=250, num_retries=3) with mock.patch.object(litellm.OpenAIChatCompletion, "completion", side_effect=mock_create): with pytest.raises(RateLimitError): lm("question") # The first retry happens immediately regardless of the configuration for i in range(1, len(time_counter) - 1): assert time_counter[i + 1] - time_counter[i] >= 2 ** (i - 1) def test_logprobs_included_when_requested(): lm = dspy.LM(model="dspy-test-model", logprobs=True, cache=False) with mock.patch("litellm.completion") as mock_completion: mock_completion.return_value = ModelResponse( choices=[ Choices( message=Message(content="test answer"), logprobs={ "content": [ {"token": "test", "logprob": 0.1, "top_logprobs": [{"token": "test", "logprob": 0.1}]}, {"token": "answer", "logprob": 0.2, "top_logprobs": [{"token": "answer", "logprob": 0.2}]}, ] }, ) ], model="dspy-test-model", ) result = lm("question") assert result[0]["text"] == "test answer" assert result[0]["logprobs"].model_dump() == { "content": [ { "token": "test", "bytes": None, "logprob": 0.1, 
"top_logprobs": [{"token": "test", "bytes": None, "logprob": 0.1}], }, { "token": "answer", "bytes": None, "logprob": 0.2, "top_logprobs": [{"token": "answer", "bytes": None, "logprob": 0.2}], }, ] } assert mock_completion.call_args.kwargs["logprobs"] @pytest.mark.asyncio async def test_async_lm_call(): from litellm.utils import Choices, Message, ModelResponse mock_response = ModelResponse(choices=[Choices(message=Message(content="answer"))], model="openai/gpt-4o-mini") with patch("litellm.acompletion") as mock_acompletion: mock_acompletion.return_value = mock_response lm = dspy.LM(model="openai/gpt-4o-mini", cache=False) result = await lm.acall("question") assert result == ["answer"] mock_acompletion.assert_called_once() @pytest.mark.asyncio async def test_async_lm_call_with_cache(tmp_path): """Test the async LM call with caching.""" original_cache = dspy.cache dspy.clients.configure_cache( enable_disk_cache=True, enable_memory_cache=True, disk_cache_dir=tmp_path / ".disk_cache", ) cache = dspy.cache lm = dspy.LM(model="openai/gpt-4o-mini") with mock.patch("dspy.clients.lm.alitellm_completion") as mock_alitellm_completion: mock_alitellm_completion.return_value = ModelResponse( choices=[Choices(message=Message(content="answer"))], model="openai/gpt-4o-mini" ) mock_alitellm_completion.__qualname__ = "alitellm_completion" await lm.acall("Query") assert len(cache.memory_cache) == 1 cache_key = next(iter(cache.memory_cache.keys())) assert cache_key in cache.disk_cache assert mock_alitellm_completion.call_count == 1 await lm.acall("Query") # Second call should hit the cache, so no new call to LiteLLM is made. assert mock_alitellm_completion.call_count == 1 # A new query should result in a new LiteLLM call and a new cache entry. await lm.acall("New query") assert len(cache.memory_cache) == 2 assert mock_alitellm_completion.call_count == 2 dspy.cache = original_cache def test_lm_history_size_limit(): lm = dspy.LM(model="openai/gpt-4o-mini") with dspy.context(max_history_size=5): with mock.patch("litellm.completion") as mock_completion: mock_completion.return_value = ModelResponse( choices=[Choices(message=Message(content="test answer"))], model="openai/gpt-4o-mini", ) for _ in range(10): lm("query") assert len(lm.history) == 5 def test_disable_history(): lm = dspy.LM(model="openai/gpt-4o-mini") with dspy.context(disable_history=True): with mock.patch("litellm.completion") as mock_completion: mock_completion.return_value = ModelResponse( choices=[Choices(message=Message(content="test answer"))], model="openai/gpt-4o-mini", ) for _ in range(10): lm("query") assert len(lm.history) == 0 with dspy.context(disable_history=False): with mock.patch("litellm.completion") as mock_completion: mock_completion.return_value = ModelResponse( choices=[Choices(message=Message(content="test answer"))], model="openai/gpt-4o-mini", ) def test_responses_api(): api_response = make_response( output_blocks=[ ResponseOutputMessage( **{ "id": "msg_1", "type": "message", "role": "assistant", "status": "completed", "content": [ {"type": "output_text", "text": "This is a test answer from responses API.", "annotations": []} ], }, ), ResponseReasoningItem( **{ "id": "reasoning_1", "type": "reasoning", "summary": [Summary(**{"type": "summary_text", "text": "This is a dummy reasoning."})], }, ), ] ) with mock.patch("litellm.responses", autospec=True, return_value=api_response) as dspy_responses: lm = dspy.LM( model="openai/gpt-5-mini", model_type="responses", cache=False, temperature=1.0, max_tokens=16000, ) lm_result = 
lm("openai query") assert lm_result == [ { "text": "This is a test answer from responses API.", "reasoning_content": "This is a dummy reasoning.", } ] dspy_responses.assert_called_once() assert dspy_responses.call_args.kwargs["model"] == "openai/gpt-5-mini" def test_lm_replaces_system_with_developer_role(): with mock.patch( "dspy.clients.lm.litellm_responses_completion", return_value={"choices": []} ) as mock_completion: lm = dspy.LM( "openai/gpt-4o-mini", cache=False, model_type="responses", use_developer_role=True, ) lm.forward(messages=[{"role": "system", "content": "hi"}]) assert ( mock_completion.call_args.kwargs["request"]["messages"][0]["role"] == "developer" ) def test_responses_api_tool_calls(litellm_test_server): api_base, _ = litellm_test_server expected_tool_call = { "type": "function_call", "name": "get_weather", "arguments": json.dumps({"city": "Paris"}), "call_id": "call_1", "status": "completed", "id": "call_1", } expected_response = [{"tool_calls": [expected_tool_call]}] api_response = make_response( output_blocks=[expected_tool_call], ) with mock.patch("litellm.responses", autospec=True, return_value=api_response) as dspy_responses: lm = dspy.LM( model="openai/dspy-test-model", api_base=api_base, api_key="fakekey", model_type="responses", cache=False, ) assert lm("openai query") == expected_response dspy_responses.assert_called_once() assert dspy_responses.call_args.kwargs["model"] == "openai/dspy-test-model" ``` -------------------------------------------------------------------------------- /docs/docs/api/optimizers/GEPA/GEPA_Advanced.md: -------------------------------------------------------------------------------- ```markdown # dspy.GEPA - Advanced Features ## Custom Instruction Proposers ### What is instruction_proposer? The `instruction_proposer` is the component responsible for invoking the `reflection_lm` and proposing new prompts during GEPA optimization. When GEPA identifies underperforming components in your DSPy program, the instruction proposer analyzes execution traces, feedback, and failures to generate improved instructions tailored to the observed issues. ### Default Implementation By default, GEPA uses the built-in instruction proposer from the [GEPA library](https://github.com/gepa-ai/gepa), which implements the [`ProposalFn`](https://github.com/gepa-ai/gepa/blob/main/src/gepa/core/adapter.py). The [default proposer](https://github.com/gepa-ai/gepa/blob/main/src/gepa/proposer/reflective_mutation/reflective_mutation.py#L53-L75) uses this prompt template: ```` I provided an assistant with the following instructions to perform a task for me: ``` <curr_instructions> ``` The following are examples of different task inputs provided to the assistant along with the assistant's response for each of them, and some feedback on how the assistant's response could be better: ``` <inputs_outputs_feedback> ``` Your task is to write a new instruction for the assistant. Read the inputs carefully and identify the input format and infer detailed task description about the task I wish to solve with the assistant. Read all the assistant responses and the corresponding feedback. Identify all niche and domain specific factual information about the task and include it in the instruction, as a lot of it may not be available to the assistant in the future. The assistant may have utilized a generalizable strategy to solve the task, if so, include that in the instruction as well. Provide the new instructions within ``` blocks. 
````

This template is automatically filled with:

- `<curr_instructions>`: The current instruction being optimized
- `<inputs_outputs_feedback>`: Structured markdown containing predictor inputs, generated outputs, and evaluation feedback

Example of default behavior:

```python
# Default instruction proposer is used automatically
gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
    auto="medium"
)
optimized_program = gepa.compile(student, trainset=examples)
```

### When to Use Custom instruction_proposer

**Note:** Custom instruction proposers are an advanced feature. Most users should start with the default proposer, which works well for most text-based optimization tasks.

Consider implementing a custom instruction proposer when you need:

- **Multi-modal handling**: Process images (dspy.Image) alongside textual information in your inputs
- **Nuanced control over length and format constraints**: Have more fine-grained control over instruction length, format, and structural requirements
- **Domain-specific information**: Inject specialized knowledge, terminology, or context that the default proposer lacks and that cannot be provided via feedback_func. This is an advanced feature, and most users should not need it.
- **Provider-specific prompting guides**: Optimize instructions for specific LLM providers (OpenAI, Anthropic, etc.) with their unique formatting preferences
- **Coupled component updates**: Handle situations where two or more components need to be updated together in a coordinated manner, rather than optimizing each component independently (refer to the component_selector parameter in the [Custom Component Selection](#custom-component-selection) section for related functionality)
- **External knowledge integration**: Connect to databases, APIs, or knowledge bases during instruction generation

### Available Options

**Built-in Options:**

- **Default Proposer**: The standard GEPA instruction proposer (used when `instruction_proposer=None`). The default proposer is itself an instruction proposer; it is the most general one, and it was used for the diverse experiments reported in the GEPA paper and tutorials.
- **MultiModalInstructionProposer**: Handles `dspy.Image` inputs and structured multimodal content.

```python
from dspy.teleprompt.gepa.instruction_proposal import MultiModalInstructionProposer

# For tasks involving images or multimodal inputs
gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
    instruction_proposer=MultiModalInstructionProposer(),
    auto="medium"
)
```

We invite community contributions of new instruction proposers for specialized domains as the [GEPA library](https://github.com/gepa-ai/gepa) continues to grow.

### How to Implement Custom Instruction Proposers

Custom instruction proposers must implement the `ProposalFn` protocol by defining a callable class or function.
GEPA will call your proposer during optimization:

```python
from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample

class CustomInstructionProposer:
    def __call__(
        self,
        candidate: dict[str, str],  # Candidate component name -> instruction mapping to be updated in this round
        reflective_dataset: dict[str, list[ReflectiveExample]],  # Component -> examples with structure: {"Inputs": ..., "Generated Outputs": ..., "Feedback": ...}
        components_to_update: list[str]  # Which components to improve
    ) -> dict[str, str]:  # Return new instruction mapping only for components being updated
        # Your custom instruction generation logic here
        return updated_instructions

# Or as a function:
def custom_instruction_proposer(candidate, reflective_dataset, components_to_update):
    # Your custom instruction generation logic here
    return updated_instructions
```

**Reflective Dataset Structure:**

- `dict[str, list[ReflectiveExample]]` - Maps component names to lists of examples
- `ReflectiveExample` TypedDict contains:
  - `Inputs: dict[str, Any]` - Predictor inputs (may include dspy.Image objects)
  - `Generated_Outputs: dict[str, Any] | str` - Success: output fields dict, Failure: error message
  - `Feedback: str` - Always a string from metric function or auto-generated by GEPA

#### Basic Example: Word Limit Proposer

```python
import dspy
from gepa.core.adapter import ProposalFn

from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample

class GenerateWordLimitedInstruction(dspy.Signature):
    """Given a current instruction and feedback examples, generate an improved instruction with word limit constraints."""

    current_instruction = dspy.InputField(desc="The current instruction that needs improvement")
    feedback_summary = dspy.InputField(desc="Feedback from examples that might include both positive and negative cases")
    max_words = dspy.InputField(desc="Maximum number of words allowed in the new instruction")

    improved_instruction = dspy.OutputField(desc="A new instruction that fixes the issues while staying under the max_words limit")

class WordLimitProposer(ProposalFn):
    def __init__(self, max_words: int = 1000):
        self.max_words = max_words
        self.instruction_improver = dspy.ChainOfThought(GenerateWordLimitedInstruction)

    def __call__(self, candidate: dict[str, str], reflective_dataset: dict[str, list[ReflectiveExample]], components_to_update: list[str]) -> dict[str, str]:
        updated_components = {}

        for component_name in components_to_update:
            if component_name not in candidate or component_name not in reflective_dataset:
                continue

            current_instruction = candidate[component_name]
            component_examples = reflective_dataset[component_name]

            # Create feedback summary
            feedback_text = "\n".join([
                f"Example {i+1}: {ex.get('Feedback', 'No feedback')}"
                for i, ex in enumerate(component_examples)  # Limit examples to prevent context overflow
            ])

            # Use the module to improve the instruction
            result = self.instruction_improver(
                current_instruction=current_instruction,
                feedback_summary=feedback_text,
                max_words=self.max_words
            )

            updated_components[component_name] = result.improved_instruction

        return updated_components

# Usage
gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
    instruction_proposer=WordLimitProposer(max_words=700),
    auto="medium"
)
```

#### Advanced Example: RAG-Enhanced Instruction Proposer

```python
import dspy
from gepa.core.adapter import ProposalFn

from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample

class GenerateDocumentationQuery(dspy.Signature):
    """Analyze examples with feedback to identify common issue patterns and generate targeted database queries for retrieving relevant documentation.

    Your goal is to search a document database for guidelines that address the problematic patterns found in the examples. Look for recurring issues, error types, or failure modes in the feedback, then craft specific search queries that will find documentation to help resolve these patterns."""

    current_instruction = dspy.InputField(desc="The current instruction that needs improvement")
    examples_with_feedback = dspy.InputField(desc="Examples with their feedback showing what issues occurred and any recurring patterns")

    failure_patterns: str = dspy.OutputField(desc="Summarize the common failure patterns identified in the examples")
    retrieval_queries: list[str] = dspy.OutputField(desc="Specific search queries to find relevant documentation in the database that addresses the common issue patterns identified in the problematic examples")

class GenerateRAGEnhancedInstruction(dspy.Signature):
    """Generate improved instructions using retrieved documentation and examples analysis."""

    current_instruction = dspy.InputField(desc="The current instruction that needs improvement")
    relevant_documentation = dspy.InputField(desc="Retrieved guidelines and best practices from specialized documentation")
    examples_with_feedback = dspy.InputField(desc="Examples showing what issues occurred with the current instruction")

    improved_instruction: str = dspy.OutputField(desc="Enhanced instruction that incorporates retrieved guidelines and addresses the issues shown in the examples")

class RAGInstructionImprover(dspy.Module):
    """Module that uses RAG to improve instructions with specialized documentation."""

    def __init__(self, retrieval_model):
        super().__init__()
        self.retrieve = retrieval_model  # Could be dspy.Retrieve or custom retriever
        self.query_generator = dspy.ChainOfThought(GenerateDocumentationQuery)
        self.generate_answer = dspy.ChainOfThought(GenerateRAGEnhancedInstruction)

    def forward(self, current_instruction: str, component_examples: list):
        """Improve instruction using retrieved documentation."""

        # Let LM analyze examples and generate targeted retrieval queries
        query_result = self.query_generator(
            current_instruction=current_instruction,
            examples_with_feedback=component_examples
        )

        results = self.retrieve.query(
            query_texts=query_result.retrieval_queries,
            n_results=3
        )

        relevant_docs_parts = []
        for i, (query, query_docs) in enumerate(zip(query_result.retrieval_queries, results['documents'])):
            if query_docs:
                docs_formatted = "\n".join([f" - {doc}" for doc in query_docs])
                relevant_docs_parts.append(
                    f"**Search Query #{i+1}**: {query}\n"
                    f"**Retrieved Guidelines**:\n{docs_formatted}"
                )

        relevant_docs = "\n\n" + "="*60 + "\n\n".join(relevant_docs_parts) + "\n" + "="*60

        # Generate improved instruction with retrieved context
        result = self.generate_answer(
            current_instruction=current_instruction,
            relevant_documentation=relevant_docs,
            examples_with_feedback=component_examples
        )

        return result

class DocumentationEnhancedProposer(ProposalFn):
    """Instruction proposer that accesses specialized documentation via RAG."""

    def __init__(self, documentation_retriever):
        """
        Args:
            documentation_retriever: A retrieval model that can search your specialized docs
                Could be dspy.Retrieve, ChromadbRM, or custom retriever
        """
        self.instruction_improver = RAGInstructionImprover(documentation_retriever)

    def __call__(self, candidate: dict[str, str], reflective_dataset: dict[str, list[ReflectiveExample]], components_to_update: list[str]) -> dict[str, str]:
        updated_components = {}

        for component_name in components_to_update:
            if component_name not in candidate or component_name not in reflective_dataset:
                continue

            current_instruction = candidate[component_name]
            component_examples = reflective_dataset[component_name]

            result = self.instruction_improver(
                current_instruction=current_instruction,
                component_examples=component_examples
            )

            updated_components[component_name] = result.improved_instruction

        return updated_components

import chromadb

client = chromadb.Client()
collection = client.get_collection("instruction_guidelines")

gepa = dspy.GEPA(
    metric=task_specific_metric,
    reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
    instruction_proposer=DocumentationEnhancedProposer(collection),
    auto="medium"
)
```

#### Integration Patterns

**Using Custom Proposer with External LM:**

```python
class ExternalLMProposer(ProposalFn):
    def __init__(self):
        # Manage your own LM instance
        self.external_lm = dspy.LM('gemini/gemini-2.5-pro')

    def __call__(self, candidate, reflective_dataset, components_to_update):
        updated_components = {}

        with dspy.context(lm=self.external_lm):
            # Your custom logic here using self.external_lm
            for component_name in components_to_update:
                # ... implementation
                pass

        return updated_components

gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=None,  # Optional when using custom proposer
    instruction_proposer=ExternalLMProposer(),
    auto="medium"
)
```

**Best Practices:**

- **Use the full power of DSPy**: Leverage DSPy components like `dspy.Module`, `dspy.Signature`, and `dspy.Predict` to create your instruction proposer rather than direct LM calls. Consider `dspy.Refine` for constraint satisfaction, `dspy.ChainOfThought` for complex reasoning tasks, and compose multiple modules for sophisticated instruction improvement workflows
- **Enable holistic feedback analysis**: While dspy.GEPA's `GEPAFeedbackMetric` processes one (gold, prediction) pair at a time, instruction proposers receive all examples for a component in batch, enabling cross-example pattern detection and systematic issue identification.
- **Mind data serialization**: Serializing everything to strings might not be ideal - handle complex input types (like `dspy.Image`) by maintaining their structure for better LM processing
- **Test thoroughly**: Test your custom proposer with representative failure cases

## Custom Component Selection

### What is component_selector?

The `component_selector` parameter controls which components (predictors) in your DSPy program are selected for optimization at each GEPA iteration. Instead of the default round-robin approach that updates one component at a time, you can implement custom selection strategies that choose single or multiple components based on optimization state, performance trajectories, and other contextual information.
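
For a quick sense of what this enables, here is a minimal, hypothetical sketch (the selector name and the iteration threshold are illustrative; the exact selector protocol and the built-in strategies are described in the sections below): a selector that updates every component during the first few iterations and then narrows to one component per iteration to conserve reflection budget.

```python
def warmup_then_focus_selector(state, trajectories, subsample_scores, candidate_idx, candidate):
    """Hypothetical sketch: optimize all components early, then one at a time.

    `state.i` is GEPA's iteration counter and `candidate` maps component names
    to their current instructions (see the protocol section below).
    """
    components = list(candidate.keys())
    if state.i < 3:  # warm-up phase: update every component together
        return components
    # Afterwards, cycle through the components one per iteration.
    return [components[state.i % len(components)]]
```

As with any custom selector, it would be passed to GEPA via `component_selector=warmup_then_focus_selector`.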

### Default Behavior

By default, GEPA uses a **round-robin strategy** (`RoundRobinReflectionComponentSelector`) that cycles through components sequentially, optimizing one component per iteration:

```python
# Default round-robin component selection
gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key),
    # component_selector="round_robin"  # This is the default
    auto="medium"
)
```

### Built-in Selection Strategies

**String-based selectors:**

- `"round_robin"` (default): Cycles through components one at a time
- `"all"`: Selects all components for simultaneous optimization

```python
# Optimize all components simultaneously
gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=reflection_lm,
    component_selector="all",  # Update all components together
    auto="medium"
)

# Explicit round-robin selection
gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=reflection_lm,
    component_selector="round_robin",  # One component per iteration
    auto="medium"
)
```

### When to Use Custom Component Selection

Consider implementing custom component selection when you need:

- **Dependency-aware optimization**: Update related components together (e.g., a classifier and its input formatter)
- **LLM-driven selection**: Let an LLM analyze trajectories and decide which components need attention
- **Resource-conscious optimization**: Balance optimization thoroughness with computational budget

### Custom Component Selector Protocol

Custom component selectors must implement the [`ReflectionComponentSelector`](https://github.com/gepa-ai/gepa/blob/main/src/gepa/proposer/reflective_mutation/base.py) protocol by defining a callable class or function. GEPA will call your selector during optimization:

```python
from dspy.teleprompt.gepa.gepa_utils import GEPAState, Trajectory

class CustomComponentSelector:
    def __call__(
        self,
        state: GEPAState,  # Complete optimization state with history
        trajectories: list[Trajectory],  # Execution traces from the current minibatch
        subsample_scores: list[float],  # Scores for each example in the current minibatch
        candidate_idx: int,  # Index of the current program candidate being optimized
        candidate: dict[str, str],  # Component name -> instruction mapping
    ) -> list[str]:  # Return list of component names to optimize
        # Your custom component selection logic here
        return selected_components

# Or as a function:
def custom_component_selector(state, trajectories, subsample_scores, candidate_idx, candidate):
    # Your custom component selection logic here
    return selected_components
```

### Custom Implementation Example

Here's a simple function that alternates between optimizing different halves of your components:

```python
def alternating_half_selector(state, trajectories, subsample_scores, candidate_idx, candidate):
    """Optimize half the components on even iterations, half on odd iterations."""
    components = list(candidate.keys())

    # If there's only one component, always optimize it
    if len(components) <= 1:
        return components

    mid_point = len(components) // 2

    # Use state.i (iteration counter) to alternate between halves
    if state.i % 2 == 0:
        # Even iteration: optimize first half
        return components[:mid_point]
    else:
        # Odd iteration: optimize second half
        return components[mid_point:]

# Usage
gepa = dspy.GEPA(
    metric=my_metric,
    reflection_lm=reflection_lm,
    component_selector=alternating_half_selector,
    auto="medium"
)
```

### Integration with Custom Instruction Proposers

Component selectors work seamlessly with custom instruction proposers.
The selector determines which components to update, then the instruction proposer generates new instructions for those components: ```python # Combined custom selector + custom proposer gepa = dspy.GEPA( metric=my_metric, reflection_lm=reflection_lm, component_selector=alternating_half_selector, instruction_proposer=WordLimitProposer(max_words=500), auto="medium" ) ``` ``` -------------------------------------------------------------------------------- /dspy/retrievers/databricks_rm.py: -------------------------------------------------------------------------------- ```python import json import os from dataclasses import dataclass from importlib.util import find_spec from typing import Any import requests import dspy from dspy.primitives.prediction import Prediction _databricks_sdk_installed = find_spec("databricks.sdk") is not None @dataclass class Document: page_content: str metadata: dict[str, Any] type: str def to_dict(self) -> dict[str, Any]: return { "page_content": self.page_content, "metadata": self.metadata, "type": self.type, } class DatabricksRM(dspy.Retrieve): """ A retriever module that uses a Databricks Mosaic AI Vector Search Index to return the top-k embeddings for a given query. Examples: Below is a code snippet that shows how to set up a Databricks Vector Search Index and configure a DatabricksRM DSPy retriever module to query the index. (example adapted from "Databricks: How to create and query a Vector Search Index: https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-index) ```python from databricks.vector_search.client import VectorSearchClient # Create a Databricks Vector Search Endpoint client = VectorSearchClient() client.create_endpoint( name="your_vector_search_endpoint_name", endpoint_type="STANDARD" ) # Create a Databricks Direct Access Vector Search Index index = client.create_direct_access_index( endpoint_name="your_vector_search_endpoint_name", index_name="your_index_name", primary_key="id", embedding_dimension=1024, embedding_vector_column="text_vector", schema={ "id": "int", "field2": "str", "field3": "float", "text_vector": "array<float>" } ) # Create a DatabricksRM retriever module to query the Databricks Direct Access Vector # Search Index retriever = DatabricksRM( databricks_index_name = "your_index_name", docs_id_column_name="id", text_column_name="field2", k=3 ) ``` Below is a code snippet that shows how to query the Databricks Direct Access Vector Search Index using the DatabricksRM retriever module: ```python retrieved_results = DatabricksRM(query="Example query text")) ``` """ def __init__( self, databricks_index_name: str, databricks_endpoint: str | None = None, databricks_token: str | None = None, databricks_client_id: str | None = None, databricks_client_secret: str | None = None, columns: list[str] | None = None, filters_json: str | None = None, k: int = 3, docs_id_column_name: str = "id", docs_uri_column_name: str | None = None, text_column_name: str = "text", use_with_databricks_agent_framework: bool = False, ): """ Args: databricks_index_name (str): The name of the Databricks Vector Search Index to query. databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST`` environment variable. If unspecified, the Databricks SDK is used to identify the endpoint based on the current environment. 
databricks_token (Optional[str]): The Databricks Workspace authentication token to use when querying the Vector Search Index. Defaults to the value of the ``DATABRICKS_TOKEN`` environment variable. If unspecified, the Databricks SDK is used to identify the token based on the current environment. databricks_client_id (str): Databricks service principal id. If not specified, the token is resolved from the current environment (DATABRICKS_CLIENT_ID). databricks_client_secret (str): Databricks service principal secret. If not specified, the endpoint is resolved from the current environment (DATABRICKS_CLIENT_SECRET). columns (Optional[list[str]]): Extra column names to include in response, in addition to the document id and text columns specified by ``docs_id_column_name`` and ``text_column_name``. filters_json (Optional[str]): A JSON string specifying additional query filters. Example filters: ``{"id <": 5}`` selects records that have an ``id`` column value less than 5, and ``{"id >=": 5, "id <": 10}`` selects records that have an ``id`` column value greater than or equal to 5 and less than 10. k (int): The number of documents to retrieve. docs_id_column_name (str): The name of the column in the Databricks Vector Search Index containing document IDs. docs_uri_column_name (Optional[str]): The name of the column in the Databricks Vector Search Index containing document URI. text_column_name (str): The name of the column in the Databricks Vector Search Index containing document text to retrieve. use_with_databricks_agent_framework (bool): Whether to use the `DatabricksRM` in a way that is compatible with the Databricks Mosaic Agent Framework. """ super().__init__(k=k) self.databricks_token = databricks_token if databricks_token is not None else os.environ.get("DATABRICKS_TOKEN") self.databricks_endpoint = ( databricks_endpoint if databricks_endpoint is not None else os.environ.get("DATABRICKS_HOST") ) self.databricks_client_id = ( databricks_client_id if databricks_client_id is not None else os.environ.get("DATABRICKS_CLIENT_ID") ) self.databricks_client_secret = ( databricks_client_secret if databricks_client_secret is not None else os.environ.get("DATABRICKS_CLIENT_SECRET") ) if not _databricks_sdk_installed and (self.databricks_token, self.databricks_endpoint).count(None) > 0: raise ValueError( "To retrieve documents with Databricks Vector Search, you must install the" " databricks-sdk Python library, supply the databricks_token and" " databricks_endpoint parameters, or set the DATABRICKS_TOKEN and DATABRICKS_HOST" " environment variables. You may also supply a service principal the databricks_client_id and" " databricks_client_secret parameters, or set the DATABRICKS_CLIENT_ID and DATABRICKS_CLIENT_SECRET" ) self.databricks_index_name = databricks_index_name self.columns = list({docs_id_column_name, text_column_name, *(columns or [])}) self.filters_json = filters_json self.k = k self.docs_id_column_name = docs_id_column_name self.docs_uri_column_name = docs_uri_column_name self.text_column_name = text_column_name self.use_with_databricks_agent_framework = use_with_databricks_agent_framework if self.use_with_databricks_agent_framework: try: import mlflow mlflow.models.set_retriever_schema( primary_key="doc_id", text_column="page_content", doc_uri="doc_uri", ) except ImportError: raise ValueError( "To use the `DatabricksRM` retriever module with the Databricks Mosaic Agent Framework, " "you must install the mlflow Python library. Please install mlflow via `pip install mlflow`." 
) def _extract_doc_ids(self, item: dict[str, Any]) -> str: """Extracts the document id from a search result Args: item: dict[str, Any]: a record from the search results. Returns: str: document id. """ if self.docs_id_column_name == "metadata": docs_dict = json.loads(item["metadata"]) return docs_dict["document_id"] return item[self.docs_id_column_name] def _get_extra_columns(self, item: dict[str, Any]) -> dict[str, Any]: """Extracts search result column values, excluding the "text" and not "id" columns Args: item: dict[str, Any]: a record from the search results. Returns: dict[str, Any]: Search result column values, excluding the "text", "id" and "uri" columns. """ extra_columns = { k: v for k, v in item.items() if k not in [self.docs_id_column_name, self.text_column_name, self.docs_uri_column_name] } if self.docs_id_column_name == "metadata": extra_columns = { **extra_columns, **{"metadata": {k: v for k, v in json.loads(item["metadata"]).items() if k != "document_id"}}, } return extra_columns def forward( self, query: str | list[float], query_type: str = "ANN", filters_json: str | None = None, ) -> dspy.Prediction | list[dict[str, Any]]: """ Retrieve documents from a Databricks Mosaic AI Vector Search Index that are relevant to the specified query. Args: query (Union[str, list[float]]): The query text or numeric query vector for which to retrieve relevant documents. query_type (str): The type of search query to perform against the Databricks Vector Search Index. Must be either 'ANN' (approximate nearest neighbor) or 'HYBRID' (hybrid search). filters_json (Optional[str]): A JSON string specifying additional query filters. Example filters: ``{"id <": 5}`` selects records that have an ``id`` column value less than 5, and ``{"id >=": 5, "id <": 10}`` selects records that have an ``id`` column value greater than or equal to 5 and less than 10. If specified, this parameter overrides the `filters_json` parameter passed to the constructor. Returns: A list of dictionaries when ``use_with_databricks_agent_framework`` is ``True``, or a ``dspy.Prediction`` object when ``use_with_databricks_agent_framework`` is ``False``. """ if query_type in ["vector", "text"]: # Older versions of DSPy used a `query_type` argument to disambiguate between text # and vector queries, rather than checking the type of the `query` argument. This # differs from the Databricks Vector Search definition of `query_type`, which # specifies the search algorithm to use (e.g. "ANN" or "HYBRID"). To maintain # backwards compatibility with older versions of DSPy, we map the old `query_type` # values to the Databricks Vector Search default query type of "ANN". 
query_type = "ANN" if isinstance(query, str): query_text = query query_vector = None elif isinstance(query, list): query_vector = query query_text = None else: raise ValueError("Query must be a string or a list of floats.") if _databricks_sdk_installed: results = self._query_via_databricks_sdk( index_name=self.databricks_index_name, k=self.k, columns=self.columns, query_type=query_type, query_text=query_text, query_vector=query_vector, databricks_token=self.databricks_token, databricks_endpoint=self.databricks_endpoint, databricks_client_id=self.databricks_client_id, databricks_client_secret=self.databricks_client_secret, filters_json=filters_json or self.filters_json, ) else: results = self._query_via_requests( index_name=self.databricks_index_name, k=self.k, columns=self.columns, databricks_token=self.databricks_token, databricks_endpoint=self.databricks_endpoint, query_type=query_type, query_text=query_text, query_vector=query_vector, filters_json=filters_json or self.filters_json, ) # Checking if defined columns are present in the index columns col_names = [column["name"] for column in results["manifest"]["columns"]] if self.docs_id_column_name not in col_names: raise Exception( f"docs_id_column_name: '{self.docs_id_column_name}' is not in the index columns: \n {col_names}" ) if self.text_column_name not in col_names: raise Exception(f"text_column_name: '{self.text_column_name}' is not in the index columns: \n {col_names}") # Extracting the results items = [] if "data_array" in results["result"]: for _, data_row in enumerate(results["result"]["data_array"]): item = {} for col_name, val in zip(col_names, data_row, strict=False): item[col_name] = val items += [item] # Sorting results by score in descending order sorted_docs = sorted(items, key=lambda x: x["score"], reverse=True)[: self.k] if self.use_with_databricks_agent_framework: return [ Document( page_content=doc[self.text_column_name], metadata={ "doc_id": self._extract_doc_ids(doc), "doc_uri": doc[self.docs_uri_column_name] if self.docs_uri_column_name else None, } | self._get_extra_columns(doc), type="Document", ).to_dict() for doc in sorted_docs ] else: # Returning the prediction return Prediction( docs=[doc[self.text_column_name] for doc in sorted_docs], doc_ids=[self._extract_doc_ids(doc) for doc in sorted_docs], doc_uris=[doc[self.docs_uri_column_name] for doc in sorted_docs] if self.docs_uri_column_name else None, extra_columns=[self._get_extra_columns(item) for item in sorted_docs], ) @staticmethod def _query_via_databricks_sdk( index_name: str, k: int, columns: list[str], query_type: str, query_text: str | None, query_vector: list[float] | None, databricks_token: str | None, databricks_endpoint: str | None, databricks_client_id: str | None, databricks_client_secret: str | None, filters_json: str | None, ) -> dict[str, Any]: """ Query a Databricks Vector Search Index via the Databricks SDK. Assumes that the databricks-sdk Python library is installed. Args: index_name (str): Name of the Databricks vector search index to query k (int): Number of relevant documents to retrieve. columns (list[str]): Column names to include in response. query_text (Optional[str]): Text query for which to find relevant documents. Exactly one of query_text or query_vector must be specified. query_vector (Optional[list[float]]): Numeric query vector for which to find relevant documents. Exactly one of query_text or query_vector must be specified. filters_json (Optional[str]): JSON string representing additional query filters. 
databricks_token (str): Databricks authentication token. If not specified, the token is resolved from the current environment. databricks_endpoint (str): Databricks index endpoint url. If not specified, the endpoint is resolved from the current environment. databricks_client_id (str): Databricks service principal id. If not specified, the token is resolved from the current environment (DATABRICKS_CLIENT_ID). databricks_client_secret (str): Databricks service principal secret. If not specified, the endpoint is resolved from the current environment (DATABRICKS_CLIENT_SECRET). Returns: Returns: dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query. """ from databricks.sdk import WorkspaceClient if (query_text, query_vector).count(None) != 1: raise ValueError("Exactly one of query_text or query_vector must be specified.") if databricks_client_secret and databricks_client_id: # Use client ID and secret for authentication if they are provided databricks_client = WorkspaceClient( client_id=databricks_client_id, client_secret=databricks_client_secret, ) print("Creating Databricks workspace client using service principal authentication.") else: # Fallback for token-based authentication databricks_client = WorkspaceClient( host=databricks_endpoint, token=databricks_token, ) print("Creating Databricks workspace client using token authentication.") return databricks_client.vector_search_indexes.query_index( index_name=index_name, query_type=query_type, query_text=query_text, query_vector=query_vector, columns=columns, filters_json=filters_json, num_results=k, ).as_dict() @staticmethod def _query_via_requests( index_name: str, k: int, columns: list[str], databricks_token: str, databricks_endpoint: str, query_type: str, query_text: str | None, query_vector: list[float] | None, filters_json: str | None, ) -> dict[str, Any]: """ Query a Databricks Vector Search Index via the Python requests library. Args: index_name (str): Name of the Databricks vector search index to query k (int): Number of relevant documents to retrieve. columns (list[str]): Column names to include in response. databricks_token (str): Databricks authentication token. databricks_endpoint (str): Databricks index endpoint url. query_text (Optional[str]): Text query for which to find relevant documents. Exactly one of query_text or query_vector must be specified. query_vector (Optional[list[float]]): Numeric query vector for which to find relevant documents. Exactly one of query_text or query_vector must be specified. filters_json (Optional[str]): JSON string representing additional query filters. Returns: dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query. """ if (query_text, query_vector).count(None) != 1: raise ValueError("Exactly one of query_text or query_vector must be specified.") headers = { "Authorization": f"Bearer {databricks_token}", "Content-Type": "application/json", } payload = { "columns": columns, "num_results": k, "query_type": query_type, } if filters_json is not None: payload["filters_json"] = filters_json if query_text is not None: payload["query_text"] = query_text elif query_vector is not None: payload["query_vector"] = query_vector response = requests.post( f"{databricks_endpoint}/api/2.0/vector-search/indexes/{index_name}/query", json=payload, headers=headers, ) results = response.json() if "error_code" in results: raise Exception(f"ERROR: {results['error_code']} -- {results['message']}") return results ```