This is page 14 of 17. Use http://codebase.md/stanfordnlp/dspy?lines=true&page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── .internal_dspyai │ │ ├── internals │ │ │ ├── build-and-release.md │ │ │ └── release-checklist.md │ │ └── pyproject.toml │ ├── .tmp │ │ └── .generated-actions │ │ └── run-pypi-publish-in-docker-container │ │ └── action.yml │ ├── ISSUE_TEMPLATE │ │ ├── bug_report.yml │ │ └── feature_request.yml │ ├── PULL_REQUEST_TEMPLATE │ │ └── pull_request_template.md │ ├── workflow_scripts │ │ └── install_testpypi_pkg.sh │ └── workflows │ ├── build_and_release.yml │ ├── build_utils │ │ └── test_version.py │ ├── docs-push.yml │ ├── precommits_check.yml │ └── run_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── docs │ ├── .gitignore │ ├── docs │ │ ├── api │ │ │ ├── adapters │ │ │ │ ├── Adapter.md │ │ │ │ ├── ChatAdapter.md │ │ │ │ ├── JSONAdapter.md │ │ │ │ └── TwoStepAdapter.md │ │ │ ├── evaluation │ │ │ │ ├── answer_exact_match.md │ │ │ │ ├── answer_passage_match.md │ │ │ │ ├── CompleteAndGrounded.md │ │ │ │ ├── Evaluate.md │ │ │ │ ├── EvaluationResult.md │ │ │ │ └── SemanticF1.md │ │ │ ├── experimental │ │ │ │ ├── Citations.md │ │ │ │ └── Document.md │ │ │ ├── index.md │ │ │ ├── models │ │ │ │ ├── Embedder.md │ │ │ │ └── LM.md │ │ │ ├── modules │ │ │ │ ├── BestOfN.md │ │ │ │ ├── ChainOfThought.md │ │ │ │ ├── CodeAct.md │ │ │ │ ├── Module.md │ │ │ │ ├── MultiChainComparison.md │ │ │ │ ├── Parallel.md │ │ │ │ ├── Predict.md │ │ │ │ ├── ProgramOfThought.md │ │ │ │ ├── ReAct.md │ │ │ │ └── Refine.md │ │ │ ├── optimizers │ │ │ │ ├── BetterTogether.md │ │ │ │ ├── BootstrapFewShot.md │ │ │ │ ├── BootstrapFewShotWithRandomSearch.md │ │ │ │ ├── BootstrapFinetune.md │ │ │ │ ├── BootstrapRS.md │ │ │ │ ├── COPRO.md │ │ │ │ ├── Ensemble.md │ │ │ │ ├── GEPA │ │ │ │ │ ├── GEPA_Advanced.md │ │ │ │ │ └── overview.md │ │ │ │ ├── InferRules.md │ │ │ │ ├── KNN.md │ │ │ │ ├── KNNFewShot.md │ │ │ │ ├── LabeledFewShot.md │ │ │ │ ├── MIPROv2.md │ │ │ │ └── SIMBA.md │ │ │ ├── primitives │ │ │ │ ├── Audio.md │ │ │ │ ├── Code.md │ │ │ │ ├── Example.md │ │ │ │ ├── History.md │ │ │ │ ├── Image.md │ │ │ │ ├── Prediction.md │ │ │ │ ├── Tool.md │ │ │ │ └── ToolCalls.md │ │ │ ├── signatures │ │ │ │ ├── InputField.md │ │ │ │ ├── OutputField.md │ │ │ │ └── Signature.md │ │ │ ├── tools │ │ │ │ ├── ColBERTv2.md │ │ │ │ ├── Embeddings.md │ │ │ │ └── PythonInterpreter.md │ │ │ └── utils │ │ │ ├── asyncify.md │ │ │ ├── configure_cache.md │ │ │ ├── disable_litellm_logging.md │ │ │ ├── disable_logging.md │ │ │ ├── enable_litellm_logging.md │ │ │ ├── enable_logging.md │ │ │ ├── inspect_history.md │ │ │ ├── load.md │ │ │ ├── StatusMessage.md │ │ │ ├── StatusMessageProvider.md │ │ │ ├── streamify.md │ │ │ └── StreamListener.md │ │ ├── cheatsheet.md │ │ ├── community │ │ │ ├── community-resources.md │ │ │ ├── how-to-contribute.md │ │ │ └── use-cases.md │ │ ├── deep-dive │ │ │ └── data-handling │ │ │ ├── built-in-datasets.md │ │ │ ├── examples.md │ │ │ ├── img │ │ │ │ └── data-loading.png │ │ │ └── loading-custom-data.md │ │ ├── faqs.md │ │ ├── index.md │ │ ├── js │ │ │ └── runllm-widget.js │ │ ├── learn │ │ │ ├── evaluation │ │ │ │ ├── data.md │ │ │ │ ├── metrics.md │ │ │ │ └── overview.md │ │ │ ├── figures │ │ │ │ ├── native_tool_call.png │ │ │ │ └── teleprompter-classes.png │ │ │ ├── index.md │ │ │ ├── optimization │ │ │ │ ├── optimizers.md │ │ │ │ └── overview.md │ │ │ └── programming │ │ │ ├── 7-assertions.md │ │ │ ├── adapters.md │ │ │ ├── language_models.md │ │ │ 
├── mcp.md │ │ │ ├── modules.md │ │ │ ├── overview.md │ │ │ ├── signatures.md │ │ │ └── tools.md │ │ ├── production │ │ │ └── index.md │ │ ├── roadmap.md │ │ ├── static │ │ │ ├── .nojekyll │ │ │ └── img │ │ │ ├── dspy_logo.png │ │ │ ├── logo.png │ │ │ ├── mlflow-tracing-rag.png │ │ │ ├── modular.png │ │ │ ├── optimize.png │ │ │ ├── undraw_docusaurus_mountain.svg │ │ │ ├── undraw_docusaurus_react.svg │ │ │ ├── undraw_docusaurus_tree.svg │ │ │ └── universal_compatibility.png │ │ ├── stylesheets │ │ │ └── extra.css │ │ └── tutorials │ │ ├── agents │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── ai_text_game │ │ │ └── index.md │ │ ├── async │ │ │ └── index.md │ │ ├── audio │ │ │ └── index.ipynb │ │ ├── build_ai_program │ │ │ └── index.md │ │ ├── cache │ │ │ └── index.md │ │ ├── classification │ │ │ └── index.md │ │ ├── classification_finetuning │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-classification.png │ │ ├── conversation_history │ │ │ └── index.md │ │ ├── core_development │ │ │ └── index.md │ │ ├── custom_module │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-custom-module.png │ │ ├── customer_service_agent │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-customer-service-agent.png │ │ ├── deployment │ │ │ ├── dspy_mlflow_ui.png │ │ │ └── index.md │ │ ├── email_extraction │ │ │ ├── index.md │ │ │ └── mlflow-tracing-email-extraction.png │ │ ├── entity_extraction │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-entity-extraction.png │ │ ├── games │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── gepa_ai_program │ │ │ └── index.md │ │ ├── gepa_aime │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-aime.png │ │ │ └── mlflow-tracking-gepa-aime-optimization.png │ │ ├── gepa_facilitysupportanalyzer │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-support.png │ │ │ └── mlflow-tracking-gepa-support-optimization.png │ │ ├── gepa_papillon │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-papilon.png │ │ │ └── mlflow-tracking-gepa-papilon-optimization.png │ │ ├── image_generation_prompting │ │ │ └── index.ipynb │ │ ├── index.md │ │ ├── llms_txt_generation │ │ │ └── index.md │ │ ├── math │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-math.png │ │ ├── mcp │ │ │ └── index.md │ │ ├── mem0_react_agent │ │ │ └── index.md │ │ ├── multihop_search │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-multi-hop.png │ │ ├── observability │ │ │ ├── index.md │ │ │ ├── mlflow_trace_ui_navigation.gif │ │ │ ├── mlflow_trace_ui.png │ │ │ └── mlflow_trace_view.png │ │ ├── optimize_ai_program │ │ │ └── index.md │ │ ├── optimizer_tracking │ │ │ ├── child_run.png │ │ │ ├── experiment.png │ │ │ ├── index.md │ │ │ └── parent_run.png │ │ ├── output_refinement │ │ │ └── best-of-n-and-refine.md │ │ ├── papillon │ │ │ └── index.md │ │ ├── program_of_thought │ │ │ └── index.ipynb │ │ ├── rag │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-rag.png │ │ ├── real_world_examples │ │ │ └── index.md │ │ ├── rl_ai_program │ │ │ └── index.md │ │ ├── rl_multihop │ │ │ └── index.ipynb │ │ ├── rl_papillon │ │ │ └── index.ipynb │ │ ├── sample_code_generation │ │ │ └── index.md │ │ ├── saving │ │ │ └── index.md │ │ ├── streaming │ │ │ └── index.md │ │ ├── tool_use │ │ │ └── index.ipynb │ │ └── yahoo_finance_react │ │ └── index.md │ ├── mkdocs.yml │ ├── overrides │ │ ├── home.html │ │ ├── main.html │ │ └── partials │ │ └── tabs.html │ ├── Pipfile │ ├── Pipfile.lock │ ├── README.md │ ├── requirements.txt │ ├── scripts │ │ ├── generate_api_docs.py │ │ └── generate_api_summary.py │ └── vercel.json ├── dspy │ ├── 
__init__.py │ ├── __metadata__.py │ ├── adapters │ │ ├── __init__.py │ │ ├── baml_adapter.py │ │ ├── base.py │ │ ├── chat_adapter.py │ │ ├── json_adapter.py │ │ ├── two_step_adapter.py │ │ ├── types │ │ │ ├── __init__.py │ │ │ ├── audio.py │ │ │ ├── base_type.py │ │ │ ├── citation.py │ │ │ ├── code.py │ │ │ ├── document.py │ │ │ ├── history.py │ │ │ ├── image.py │ │ │ └── tool.py │ │ ├── utils.py │ │ └── xml_adapter.py │ ├── clients │ │ ├── __init__.py │ │ ├── base_lm.py │ │ ├── cache.py │ │ ├── databricks.py │ │ ├── embedding.py │ │ ├── lm_local_arbor.py │ │ ├── lm_local.py │ │ ├── lm.py │ │ ├── openai.py │ │ ├── provider.py │ │ └── utils_finetune.py │ ├── datasets │ │ ├── __init__.py │ │ ├── alfworld │ │ │ ├── __init__.py │ │ │ ├── alfworld.py │ │ │ └── base_config.yml │ │ ├── colors.py │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── gsm8k.py │ │ ├── hotpotqa.py │ │ └── math.py │ ├── dsp │ │ ├── __init__.py │ │ ├── colbertv2.py │ │ └── utils │ │ ├── __init__.py │ │ ├── dpr.py │ │ ├── settings.py │ │ └── utils.py │ ├── evaluate │ │ ├── __init__.py │ │ ├── auto_evaluation.py │ │ ├── evaluate.py │ │ └── metrics.py │ ├── experimental │ │ └── __init__.py │ ├── predict │ │ ├── __init__.py │ │ ├── aggregation.py │ │ ├── avatar │ │ │ ├── __init__.py │ │ │ ├── avatar.py │ │ │ ├── models.py │ │ │ └── signatures.py │ │ ├── best_of_n.py │ │ ├── chain_of_thought.py │ │ ├── code_act.py │ │ ├── knn.py │ │ ├── multi_chain_comparison.py │ │ ├── parallel.py │ │ ├── parameter.py │ │ ├── predict.py │ │ ├── program_of_thought.py │ │ ├── react.py │ │ ├── refine.py │ │ └── retry.py │ ├── primitives │ │ ├── __init__.py │ │ ├── base_module.py │ │ ├── example.py │ │ ├── module.py │ │ ├── prediction.py │ │ ├── python_interpreter.py │ │ └── runner.js │ ├── propose │ │ ├── __init__.py │ │ ├── dataset_summary_generator.py │ │ ├── grounded_proposer.py │ │ ├── propose_base.py │ │ └── utils.py │ ├── retrievers │ │ ├── __init__.py │ │ ├── databricks_rm.py │ │ ├── embeddings.py │ │ ├── retrieve.py │ │ └── weaviate_rm.py │ ├── signatures │ │ ├── __init__.py │ │ ├── field.py │ │ ├── signature.py │ │ └── utils.py │ ├── streaming │ │ ├── __init__.py │ │ ├── messages.py │ │ ├── streamify.py │ │ └── streaming_listener.py │ ├── teleprompt │ │ ├── __init__.py │ │ ├── avatar_optimizer.py │ │ ├── bettertogether.py │ │ ├── bootstrap_finetune.py │ │ ├── bootstrap_trace.py │ │ ├── bootstrap.py │ │ ├── copro_optimizer.py │ │ ├── ensemble.py │ │ ├── gepa │ │ │ ├── __init__.py │ │ │ ├── gepa_utils.py │ │ │ ├── gepa.py │ │ │ └── instruction_proposal.py │ │ ├── grpo.py │ │ ├── infer_rules.py │ │ ├── knn_fewshot.py │ │ ├── mipro_optimizer_v2.py │ │ ├── random_search.py │ │ ├── signature_opt.py │ │ ├── simba_utils.py │ │ ├── simba.py │ │ ├── teleprompt_optuna.py │ │ ├── teleprompt.py │ │ ├── utils.py │ │ └── vanilla.py │ └── utils │ ├── __init__.py │ ├── annotation.py │ ├── asyncify.py │ ├── caching.py │ ├── callback.py │ ├── dummies.py │ ├── exceptions.py │ ├── hasher.py │ ├── inspect_history.py │ ├── langchain_tool.py │ ├── logging_utils.py │ ├── mcp.py │ ├── parallelizer.py │ ├── saving.py │ ├── syncify.py │ ├── unbatchify.py │ └── usage_tracker.py ├── LICENSE ├── pyproject.toml ├── README.md ├── tests │ ├── __init__.py │ ├── adapters │ │ ├── test_adapter_utils.py │ │ ├── test_baml_adapter.py │ │ ├── test_base_type.py │ │ ├── test_chat_adapter.py │ │ ├── test_citation.py │ │ ├── test_code.py │ │ ├── test_document.py │ │ ├── test_json_adapter.py │ │ ├── test_tool.py │ │ ├── test_two_step_adapter.py │ │ └── test_xml_adapter.py │ ├── 
callback │ │ └── test_callback.py │ ├── clients │ │ ├── test_cache.py │ │ ├── test_databricks.py │ │ ├── test_embedding.py │ │ ├── test_inspect_global_history.py │ │ └── test_lm.py │ ├── conftest.py │ ├── datasets │ │ └── test_dataset.py │ ├── docs │ │ └── test_mkdocs_links.py │ ├── evaluate │ │ ├── test_evaluate.py │ │ └── test_metrics.py │ ├── examples │ │ └── test_baleen.py │ ├── metadata │ │ └── test_metadata.py │ ├── predict │ │ ├── test_aggregation.py │ │ ├── test_best_of_n.py │ │ ├── test_chain_of_thought.py │ │ ├── test_code_act.py │ │ ├── test_knn.py │ │ ├── test_multi_chain_comparison.py │ │ ├── test_parallel.py │ │ ├── test_predict.py │ │ ├── test_program_of_thought.py │ │ ├── test_react.py │ │ ├── test_refine.py │ │ └── test_retry.py │ ├── primitives │ │ ├── resources │ │ │ └── saved_program.json │ │ ├── test_base_module.py │ │ ├── test_example.py │ │ ├── test_module.py │ │ └── test_python_interpreter.py │ ├── propose │ │ └── test_grounded_proposer.py │ ├── README.md │ ├── reliability │ │ ├── __init__.py │ │ ├── complex_types │ │ │ └── generated │ │ │ ├── test_many_types_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ ├── test_nesting_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ └── test_nesting_2 │ │ │ ├── inputs │ │ │ │ └── input1.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── conftest.py │ │ ├── generate │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ └── utils.py │ │ ├── input_formats │ │ │ └── generated │ │ │ └── test_markdown_1 │ │ │ ├── inputs │ │ │ │ ├── input1.json │ │ │ │ └── input2.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── README.md │ │ ├── reliability_conf.yaml │ │ ├── test_generated.py │ │ ├── test_pydantic_models.py │ │ └── utils.py │ ├── retrievers │ │ └── test_embeddings.py │ ├── signatures │ │ ├── test_adapter_image.py │ │ ├── test_custom_types.py │ │ └── test_signature.py │ ├── streaming │ │ └── test_streaming.py │ ├── teleprompt │ │ ├── gepa_dummy_lm_custom_component_selector_custom_instruction_proposer.json │ │ ├── gepa_dummy_lm.json │ │ ├── test_bootstrap_finetune.py │ │ ├── test_bootstrap_trace.py │ │ ├── test_bootstrap.py │ │ ├── test_copro_optimizer.py │ │ ├── test_ensemble.py │ │ ├── test_finetune.py │ │ ├── test_gepa_instruction_proposer.py │ │ ├── test_gepa.py │ │ ├── test_grpo.py │ │ ├── test_knn_fewshot.py │ │ ├── test_random_search.py │ │ ├── test_teleprompt.py │ │ └── test_utils.py │ ├── test_utils │ │ ├── __init__.py │ │ └── server │ │ ├── __init__.py │ │ ├── litellm_server_config.yaml │ │ └── litellm_server.py │ └── utils │ ├── __init__.py │ ├── resources │ │ └── mcp_server.py │ ├── test_annotation.py │ ├── test_asyncify.py │ ├── test_exceptions.py │ ├── test_langchain_tool.py │ ├── test_mcp.py │ ├── test_parallelizer.py │ ├── test_saving.py │ ├── test_settings.py │ ├── test_syncify.py │ ├── test_unbatchify.py │ └── test_usage_tracker.py └── uv.lock ``` # Files -------------------------------------------------------------------------------- /tests/predict/test_predict.py: -------------------------------------------------------------------------------- ```python 1 | import asyncio 2 | import copy 3 | import enum 4 | import time 5 | import types 6 | from datetime import datetime 7 | from unittest.mock import patch 8 | 9 | import orjson 10 | import pydantic 11 | import pytest 12 | from litellm import ModelResponse 13 | from pydantic import BaseModel, HttpUrl 14 | 15 | 
import dspy 16 | from dspy import Predict, Signature 17 | from dspy.predict.predict import serialize_object 18 | from dspy.utils.dummies import DummyLM 19 | 20 | 21 | def test_initialization_with_string_signature(): 22 | signature_string = "input1, input2 -> output" 23 | predict = Predict(signature_string) 24 | expected_instruction = "Given the fields `input1`, `input2`, produce the fields `output`." 25 | assert predict.signature.instructions == expected_instruction 26 | assert predict.signature.instructions == Signature(signature_string).instructions 27 | 28 | 29 | def test_reset_method(): 30 | predict_instance = Predict("input -> output") 31 | predict_instance.lm = "modified" 32 | predict_instance.traces = ["trace"] 33 | predict_instance.train = ["train"] 34 | predict_instance.demos = ["demo"] 35 | predict_instance.reset() 36 | assert predict_instance.lm is None 37 | assert predict_instance.traces == [] 38 | assert predict_instance.train == [] 39 | assert predict_instance.demos == [] 40 | 41 | 42 | def test_lm_after_dump_and_load_state(): 43 | predict_instance = Predict("input -> output") 44 | lm = dspy.LM( 45 | model="openai/gpt-4o-mini", 46 | model_type="chat", 47 | temperature=1, 48 | max_tokens=100, 49 | num_retries=10, 50 | ) 51 | predict_instance.lm = lm 52 | expected_lm_state = { 53 | "model": "openai/gpt-4o-mini", 54 | "model_type": "chat", 55 | "temperature": 1, 56 | "max_tokens": 100, 57 | "num_retries": 10, 58 | "cache": True, 59 | "finetuning_model": None, 60 | "launch_kwargs": {}, 61 | "train_kwargs": {}, 62 | } 63 | assert lm.dump_state() == expected_lm_state 64 | dumped_state = predict_instance.dump_state() 65 | new_instance = Predict("input -> output") 66 | new_instance.load_state(dumped_state) 67 | assert new_instance.lm.dump_state() == expected_lm_state 68 | 69 | 70 | def test_call_method(): 71 | predict_instance = Predict("input -> output") 72 | lm = DummyLM([{"output": "test output"}]) 73 | dspy.settings.configure(lm=lm) 74 | result = predict_instance(input="test input") 75 | assert result.output == "test output" 76 | 77 | 78 | def test_instructions_after_dump_and_load_state(): 79 | predict_instance = Predict(Signature("input -> output", "original instructions")) 80 | dumped_state = predict_instance.dump_state() 81 | new_instance = Predict(Signature("input -> output", "new instructions")) 82 | new_instance.load_state(dumped_state) 83 | assert new_instance.signature.instructions == "original instructions" 84 | 85 | 86 | def test_demos_after_dump_and_load_state(): 87 | class TranslateToEnglish(dspy.Signature): 88 | """Translate content from a language to English.""" 89 | 90 | content: str = dspy.InputField() 91 | language: str = dspy.InputField() 92 | translation: str = dspy.OutputField() 93 | 94 | original_instance = Predict(TranslateToEnglish) 95 | original_instance.demos = [ 96 | dspy.Example( 97 | content="¿Qué tal?", 98 | language="SPANISH", 99 | translation="Hello there", 100 | ).with_inputs("content", "language"), 101 | ] 102 | 103 | dumped_state = original_instance.dump_state() 104 | assert len(dumped_state["demos"]) == len(original_instance.demos) 105 | assert dumped_state["demos"][0]["content"] == original_instance.demos[0].content 106 | 107 | saved_state = orjson.dumps(dumped_state).decode() 108 | loaded_state = orjson.loads(saved_state) 109 | 110 | new_instance = Predict(TranslateToEnglish) 111 | new_instance.load_state(loaded_state) 112 | assert len(new_instance.demos) == len(original_instance.demos) 113 | # Demos don't need to keep the same types after 
saving and loading the state. 114 | assert new_instance.demos[0]["content"] == original_instance.demos[0].content 115 | 116 | 117 | def test_typed_demos_after_dump_and_load_state(): 118 | class Item(pydantic.BaseModel): 119 | name: str 120 | quantity: int 121 | 122 | class InventorySignature(dspy.Signature): 123 | """Handle inventory items and their translations.""" 124 | 125 | items: list[Item] = dspy.InputField() 126 | language: str = dspy.InputField() 127 | translated_items: list[Item] = dspy.OutputField() 128 | total_quantity: int = dspy.OutputField() 129 | 130 | original_instance = Predict(InventorySignature) 131 | original_instance.demos = [ 132 | dspy.Example( 133 | items=[Item(name="apple", quantity=5), Item(name="banana", quantity=3)], 134 | language="SPANISH", 135 | translated_items=[Item(name="manzana", quantity=5), Item(name="plátano", quantity=3)], 136 | total_quantity=8, 137 | ).with_inputs("items", "language"), 138 | ] 139 | 140 | # Test dump_state 141 | dumped_state = original_instance.dump_state() 142 | assert len(dumped_state["demos"]) == len(original_instance.demos) 143 | # Verify the input items were properly serialized 144 | assert isinstance(dumped_state["demos"][0]["items"], list) 145 | assert len(dumped_state["demos"][0]["items"]) == 2 146 | assert dumped_state["demos"][0]["items"][0] == {"name": "apple", "quantity": 5} 147 | 148 | # Test serialization/deserialization 149 | saved_state = orjson.dumps(dumped_state).decode() 150 | loaded_state = orjson.loads(saved_state) 151 | 152 | # Test load_state 153 | new_instance = Predict(InventorySignature) 154 | new_instance.load_state(loaded_state) 155 | assert len(new_instance.demos) == len(original_instance.demos) 156 | 157 | # Verify the structure is maintained after loading 158 | loaded_demo = new_instance.demos[0] 159 | assert isinstance(loaded_demo["items"], list) 160 | assert len(loaded_demo["items"]) == 2 161 | assert loaded_demo["items"][0]["name"] == "apple" 162 | assert loaded_demo["items"][0]["quantity"] == 5 163 | assert loaded_demo["items"][1]["name"] == "banana" 164 | assert loaded_demo["items"][1]["quantity"] == 3 165 | 166 | # Verify output items were also properly maintained 167 | assert isinstance(loaded_demo["translated_items"], list) 168 | assert len(loaded_demo["translated_items"]) == 2 169 | assert loaded_demo["translated_items"][0]["name"] == "manzana" 170 | assert loaded_demo["translated_items"][1]["name"] == "plátano" 171 | 172 | 173 | # def test_typed_demos_after_dump_and_load_state(): 174 | # class TypedTranslateToEnglish(dspy.Signature): 175 | # """Translate content from a language to English.""" 176 | 177 | # class Input(pydantic.BaseModel): 178 | # content: str 179 | # language: str 180 | 181 | # class Output(pydantic.BaseModel): 182 | # translation: str 183 | 184 | # input: Input = dspy.InputField() 185 | # output: Output = dspy.OutputField() 186 | 187 | # original_instance = TypedPredictor(TypedTranslateToEnglish).predictor 188 | # original_instance.demos = [ 189 | # dspy.Example( 190 | # input=TypedTranslateToEnglish.Input( 191 | # content="¿Qué tal?", 192 | # language="SPANISH", 193 | # ), 194 | # output=TypedTranslateToEnglish.Output( 195 | # translation="Hello there", 196 | # ), 197 | # ).with_inputs("input"), 198 | # ] 199 | 200 | # dumped_state = original_instance.dump_state() 201 | # assert len(dumped_state["demos"]) == len(original_instance.demos) 202 | # assert dumped_state["demos"][0]["input"] == original_instance.demos[0].input.model_dump_json() 203 | 204 | # saved_state = 
ujson.dumps(dumped_state) 205 | # loaded_state = ujson.loads(saved_state) 206 | 207 | # new_instance = TypedPredictor(TypedTranslateToEnglish).predictor 208 | # new_instance.load_state(loaded_state) 209 | # assert len(new_instance.demos) == len(original_instance.demos) 210 | # # Demos don't need to keep the same types after saving and loading the state. 211 | # assert new_instance.demos[0]["input"] == original_instance.demos[0].input.model_dump_json() 212 | 213 | 214 | def test_signature_fields_after_dump_and_load_state(tmp_path): 215 | class CustomSignature(dspy.Signature): 216 | """I am just an instruction.""" 217 | 218 | sentence = dspy.InputField(desc="I am an innocent input!") 219 | sentiment = dspy.OutputField() 220 | 221 | file_path = tmp_path / "tmp.json" 222 | original_instance = Predict(CustomSignature) 223 | original_instance.save(file_path) 224 | 225 | class CustomSignature2(dspy.Signature): 226 | """I am not a pure instruction.""" 227 | 228 | sentence = dspy.InputField(desc="I am a malicious input!") 229 | sentiment = dspy.OutputField(desc="I am a malicious output!", prefix="I am a prefix!") 230 | 231 | new_instance = Predict(CustomSignature2) 232 | assert new_instance.signature.dump_state() != original_instance.signature.dump_state() 233 | # After loading, the fields should be the same. 234 | new_instance.load(file_path) 235 | assert new_instance.signature.dump_state() == original_instance.signature.dump_state() 236 | 237 | 238 | @pytest.mark.parametrize("filename", ["model.json", "model.pkl"]) 239 | def test_lm_field_after_dump_and_load_state(tmp_path, filename): 240 | file_path = tmp_path / filename 241 | lm = dspy.LM( 242 | model="openai/gpt-4o-mini", 243 | model_type="chat", 244 | temperature=1, 245 | max_tokens=100, 246 | num_retries=10, 247 | ) 248 | original_predict = dspy.Predict("q->a") 249 | original_predict.lm = lm 250 | 251 | original_predict.save(file_path) 252 | 253 | assert file_path.exists() 254 | 255 | loaded_predict = dspy.Predict("q->a") 256 | loaded_predict.load(file_path) 257 | 258 | assert original_predict.dump_state() == loaded_predict.dump_state() 259 | 260 | 261 | def test_forward_method(): 262 | program = Predict("question -> answer") 263 | dspy.settings.configure(lm=DummyLM([{"answer": "No more responses"}])) 264 | result = program(question="What is 1+1?").answer 265 | assert result == "No more responses" 266 | 267 | 268 | def test_forward_method2(): 269 | program = Predict("question -> answer1, answer2") 270 | dspy.settings.configure(lm=DummyLM([{"answer1": "my first answer", "answer2": "my second answer"}])) 271 | result = program(question="What is 1+1?") 272 | assert result.answer1 == "my first answer" 273 | assert result.answer2 == "my second answer" 274 | 275 | 276 | def test_config_management(): 277 | predict_instance = Predict("input -> output") 278 | predict_instance.update_config(new_key="value") 279 | config = predict_instance.get_config() 280 | assert "new_key" in config and config["new_key"] == "value" 281 | 282 | 283 | def test_multi_output(): 284 | program = Predict("question -> answer", n=2) 285 | dspy.settings.configure(lm=DummyLM([{"answer": "my first answer"}, {"answer": "my second answer"}])) 286 | results = program(question="What is 1+1?") 287 | assert results.completions.answer[0] == "my first answer" 288 | assert results.completions.answer[1] == "my second answer" 289 | 290 | 291 | def test_multi_output2(): 292 | program = Predict("question -> answer1, answer2", n=2) 293 | dspy.settings.configure( 294 | lm=DummyLM( 295 | [ 296 
| {"answer1": "my 0 answer", "answer2": "my 2 answer"}, 297 | {"answer1": "my 1 answer", "answer2": "my 3 answer"}, 298 | ], 299 | ) 300 | ) 301 | results = program(question="What is 1+1?") 302 | assert results.completions.answer1[0] == "my 0 answer" 303 | assert results.completions.answer1[1] == "my 1 answer" 304 | assert results.completions.answer2[0] == "my 2 answer" 305 | assert results.completions.answer2[1] == "my 3 answer" 306 | 307 | 308 | def test_datetime_inputs_and_outputs(): 309 | # Define a model for datetime inputs and outputs 310 | class TimedEvent(pydantic.BaseModel): 311 | event_name: str 312 | event_time: datetime 313 | 314 | class TimedSignature(dspy.Signature): 315 | events: list[TimedEvent] = dspy.InputField() 316 | summary: str = dspy.OutputField() 317 | next_event_time: datetime = dspy.OutputField() 318 | 319 | program = Predict(TimedSignature) 320 | 321 | lm = DummyLM( 322 | [ 323 | { 324 | "reasoning": "Processed datetime inputs", 325 | "summary": "All events are processed", 326 | "next_event_time": "2024-11-27T14:00:00", 327 | } 328 | ] 329 | ) 330 | dspy.settings.configure(lm=lm) 331 | 332 | output = program( 333 | events=[ 334 | TimedEvent(event_name="Event 1", event_time=datetime(2024, 11, 25, 10, 0, 0)), 335 | TimedEvent(event_name="Event 2", event_time=datetime(2024, 11, 25, 15, 30, 0)), 336 | ] 337 | ) 338 | assert output.summary == "All events are processed" 339 | assert output.next_event_time == datetime(2024, 11, 27, 14, 0, 0) 340 | 341 | 342 | def test_explicitly_valued_enum_inputs_and_outputs(): 343 | class Status(enum.Enum): 344 | PENDING = "pending" 345 | IN_PROGRESS = "in_progress" 346 | COMPLETED = "completed" 347 | 348 | class StatusSignature(dspy.Signature): 349 | current_status: Status = dspy.InputField() 350 | next_status: Status = dspy.OutputField() 351 | 352 | program = Predict(StatusSignature) 353 | 354 | lm = DummyLM( 355 | [ 356 | { 357 | "reasoning": "The current status is 'PENDING', advancing to 'IN_PROGRESS'.", 358 | "next_status": "in_progress", 359 | } 360 | ] 361 | ) 362 | dspy.settings.configure(lm=lm) 363 | 364 | output = program(current_status=Status.PENDING) 365 | assert output.next_status == Status.IN_PROGRESS 366 | 367 | 368 | def test_enum_inputs_and_outputs_with_shared_names_and_values(): 369 | class TicketStatus(enum.Enum): 370 | OPEN = "CLOSED" 371 | CLOSED = "RESOLVED" 372 | RESOLVED = "OPEN" 373 | 374 | class TicketStatusSignature(dspy.Signature): 375 | current_status: TicketStatus = dspy.InputField() 376 | next_status: TicketStatus = dspy.OutputField() 377 | 378 | program = Predict(TicketStatusSignature) 379 | 380 | # Mock reasoning and output 381 | lm = DummyLM( 382 | [ 383 | { 384 | "reasoning": "The ticket is currently 'OPEN', transitioning to 'CLOSED'.", 385 | "next_status": "RESOLVED", # Refers to TicketStatus.CLOSED by value 386 | } 387 | ] 388 | ) 389 | dspy.settings.configure(lm=lm) 390 | 391 | output = program(current_status=TicketStatus.OPEN) 392 | assert output.next_status == TicketStatus.CLOSED # By value 393 | 394 | 395 | def test_auto_valued_enum_inputs_and_outputs(): 396 | Status = enum.Enum("Status", ["PENDING", "IN_PROGRESS", "COMPLETED"]) # noqa: N806 397 | 398 | class StatusSignature(dspy.Signature): 399 | current_status: Status = dspy.InputField() 400 | next_status: Status = dspy.OutputField() 401 | 402 | program = Predict(StatusSignature) 403 | 404 | lm = DummyLM( 405 | [ 406 | { 407 | "reasoning": "The current status is 'PENDING', advancing to 'IN_PROGRESS'.", 408 | "next_status": "IN_PROGRESS", # 
Use the member name; auto-assigned enum values are just integers (1, 2, 3) 409 | } 410 | ] 411 | ) 412 | dspy.settings.configure(lm=lm) 413 | 414 | output = program(current_status=Status.PENDING) 415 | assert output.next_status == Status.IN_PROGRESS 416 | 417 | 418 | def test_named_predictors(): 419 | class MyModule(dspy.Module): 420 | def __init__(self): 421 | super().__init__() 422 | self.inner = Predict("question -> answer") 423 | 424 | program = MyModule() 425 | assert program.named_predictors() == [("inner", program.inner)] 426 | 427 | # Check that it also works the second time. 428 | program2 = copy.deepcopy(program) 429 | assert program2.named_predictors() == [("inner", program2.inner)] 430 | 431 | 432 | def test_output_only(): 433 | class OutputOnlySignature(dspy.Signature): 434 | output = dspy.OutputField() 435 | 436 | predictor = Predict(OutputOnlySignature) 437 | 438 | lm = DummyLM([{"output": "short answer"}]) 439 | dspy.settings.configure(lm=lm) 440 | assert predictor().output == "short answer" 441 | 442 | 443 | def test_load_state_chaining(): 444 | """Test that load_state returns self for chaining.""" 445 | original = Predict("question -> answer") 446 | original.demos = [{"question": "test", "answer": "response"}] 447 | state = original.dump_state() 448 | 449 | new_instance = Predict("question -> answer").load_state(state) 450 | assert new_instance is not None 451 | assert new_instance.demos == original.demos 452 | 453 | 454 | @pytest.mark.parametrize("adapter_type", ["chat", "json"]) 455 | def test_call_predict_with_chat_history(adapter_type): 456 | class SpyLM(dspy.LM): 457 | def __init__(self, *args, return_json=False, **kwargs): 458 | super().__init__(*args, **kwargs) 459 | self.calls = [] 460 | self.return_json = return_json 461 | 462 | def __call__(self, prompt=None, messages=None, **kwargs): 463 | self.calls.append({"prompt": prompt, "messages": messages, "kwargs": kwargs}) 464 | if self.return_json: 465 | return ["{'answer':'100%'}"] 466 | return ["[[ ## answer ## ]]\n100%!"] 467 | 468 | class MySignature(dspy.Signature): 469 | question: str = dspy.InputField() 470 | history: dspy.History = dspy.InputField() 471 | answer: str = dspy.OutputField() 472 | 473 | program = Predict(MySignature) 474 | 475 | if adapter_type == "chat": 476 | lm = SpyLM("dummy_model") 477 | dspy.settings.configure(adapter=dspy.ChatAdapter(), lm=lm) 478 | else: 479 | lm = SpyLM("dummy_model", return_json=True) 480 | dspy.settings.configure(adapter=dspy.JSONAdapter(), lm=lm) 481 | 482 | program( 483 | question="are you sure that's correct?", 484 | history=dspy.History(messages=[{"question": "what's the capital of france?", "answer": "paris"}]), 485 | ) 486 | 487 | # Verify the LM was called with correct messages 488 | assert len(lm.calls) == 1 489 | messages = lm.calls[0]["messages"] 490 | 491 | assert len(messages) == 4 492 | 493 | assert "what's the capital of france?" 
in messages[1]["content"] 494 | assert "paris" in messages[2]["content"] 495 | assert "are you sure that's correct" in messages[3]["content"] 496 | 497 | 498 | def test_lm_usage(): 499 | program = Predict("question -> answer") 500 | dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True) 501 | with patch( 502 | "dspy.clients.lm.litellm_completion", 503 | return_value=ModelResponse( 504 | choices=[{"message": {"content": "[[ ## answer ## ]]\nParis"}}], 505 | usage={"total_tokens": 10}, 506 | ), 507 | ): 508 | result = program(question="What is the capital of France?") 509 | assert result.answer == "Paris" 510 | assert result.get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10 511 | 512 | 513 | def test_lm_usage_with_parallel(): 514 | program = Predict("question -> answer") 515 | 516 | def program_wrapper(question): 517 | # Sleep to make it possible to cause a race condition 518 | time.sleep(0.5) 519 | return program(question=question) 520 | 521 | dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True) 522 | with patch( 523 | "dspy.clients.lm.litellm_completion", 524 | return_value=ModelResponse( 525 | choices=[{"message": {"content": "[[ ## answer ## ]]\nParis"}}], 526 | usage={"total_tokens": 10}, 527 | ), 528 | ): 529 | parallelizer = dspy.Parallel() 530 | input_pairs = [ 531 | (program_wrapper, {"question": "What is the capital of France?"}), 532 | (program_wrapper, {"question": "What is the capital of France?"}), 533 | ] 534 | results = parallelizer(input_pairs) 535 | assert results[0].answer == "Paris" 536 | assert results[1].answer == "Paris" 537 | assert results[0].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10 538 | assert results[1].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10 539 | 540 | 541 | @pytest.mark.asyncio 542 | async def test_lm_usage_with_async(): 543 | program = Predict("question -> answer") 544 | 545 | original_aforward = program.aforward 546 | 547 | async def patched_aforward(self, **kwargs): 548 | await asyncio.sleep(1) 549 | return await original_aforward(**kwargs) 550 | 551 | program.aforward = types.MethodType(patched_aforward, program) 552 | 553 | with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True): 554 | with patch( 555 | "litellm.acompletion", 556 | return_value=ModelResponse( 557 | choices=[{"message": {"content": "[[ ## answer ## ]]\nParis"}}], 558 | usage={"total_tokens": 10}, 559 | ), 560 | ): 561 | coroutines = [ 562 | program.acall(question="What is the capital of France?"), 563 | program.acall(question="What is the capital of France?"), 564 | program.acall(question="What is the capital of France?"), 565 | program.acall(question="What is the capital of France?"), 566 | ] 567 | results = await asyncio.gather(*coroutines) 568 | assert results[0].answer == "Paris" 569 | assert results[1].answer == "Paris" 570 | assert results[0].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10 571 | assert results[1].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10 572 | assert results[2].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10 573 | assert results[3].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10 574 | 575 | 576 | def test_positional_arguments(): 577 | program = Predict("question -> answer") 578 | with pytest.raises(ValueError) as e: 579 | program("What is the capital of France?") 580 | assert str(e.value) == ( 581 | "Positional arguments are not allowed when calling `dspy.Predict`, must use 
keyword arguments that match " 582 | "your signature input fields: 'question'. For example: `predict(question=input_value, ...)`." 583 | ) 584 | 585 | 586 | def test_error_message_on_invalid_lm_setup(): 587 | # No LM is loaded. 588 | with pytest.raises(ValueError, match="No LM is loaded"): 589 | Predict("question -> answer")(question="Why did a chicken cross the kitchen?") 590 | 591 | # LM is a string. 592 | dspy.configure(lm="openai/gpt-4o-mini") 593 | with pytest.raises(ValueError) as e: 594 | Predict("question -> answer")(question="Why did a chicken cross the kitchen?") 595 | 596 | assert "LM must be an instance of `dspy.BaseLM`, not a string." in str(e.value) 597 | 598 | def dummy_lm(): 599 | pass 600 | 601 | # LM is not an instance of dspy.BaseLM. 602 | dspy.configure(lm=dummy_lm) 603 | with pytest.raises(ValueError) as e: 604 | Predict("question -> answer")(question="Why did a chicken cross the kitchen?") 605 | assert "LM must be an instance of `dspy.BaseLM`, not <class 'function'>." in str(e.value) 606 | 607 | 608 | @pytest.mark.parametrize("adapter_type", ["chat", "json"]) 609 | def test_field_constraints(adapter_type): 610 | class SpyLM(dspy.LM): 611 | def __init__(self, *args, return_json=False, **kwargs): 612 | super().__init__(*args, **kwargs) 613 | self.calls = [] 614 | self.return_json = return_json 615 | 616 | def __call__(self, prompt=None, messages=None, **kwargs): 617 | self.calls.append({"prompt": prompt, "messages": messages, "kwargs": kwargs}) 618 | if self.return_json: 619 | return ["{'score':'0.5', 'count':'2'}"] 620 | return ["[[ ## score ## ]]\n0.5\n[[ ## count ## ]]\n2"] 621 | 622 | class ConstrainedSignature(dspy.Signature): 623 | """Test signature with constrained fields.""" 624 | 625 | # Input with length and value constraints 626 | text: str = dspy.InputField(min_length=5, max_length=100, desc="Input text") 627 | number: int = dspy.InputField(gt=0, lt=10, desc="A number between 0 and 10") 628 | 629 | # Output with multiple constraints 630 | score: float = dspy.OutputField(ge=0.0, le=1.0, desc="Score between 0 and 1") 631 | count: int = dspy.OutputField(multiple_of=2, desc="Even number count") 632 | 633 | program = Predict(ConstrainedSignature) 634 | lm = SpyLM("dummy_model") 635 | if adapter_type == "chat": 636 | lm = SpyLM("dummy_model") 637 | dspy.settings.configure(adapter=dspy.ChatAdapter(), lm=lm) 638 | else: 639 | lm = SpyLM("dummy_model", return_json=True) 640 | dspy.settings.configure(adapter=dspy.JSONAdapter(), lm=lm) 641 | 642 | # Call the predictor to trigger instruction generation 643 | program(text="hello world", number=5) 644 | 645 | # Get the system message containing the instructions 646 | system_message = lm.calls[0]["messages"][0]["content"] 647 | 648 | # Verify constraints are included in the field descriptions 649 | assert "minimum length: 5" in system_message 650 | assert "maximum length: 100" in system_message 651 | assert "greater than: 0" in system_message 652 | assert "less than: 10" in system_message 653 | assert "greater than or equal to: 0.0" in system_message 654 | assert "less than or equal to: 1.0" in system_message 655 | assert "a multiple of the given number: 2" in system_message 656 | 657 | 658 | @pytest.mark.asyncio 659 | async def test_async_predict(): 660 | program = Predict("question -> answer") 661 | with dspy.context(lm=DummyLM([{"answer": "Paris"}])): 662 | result = await program.acall(question="What is the capital of France?") 663 | assert result.answer == "Paris" 664 | 665 | 666 | def 
test_predicted_outputs_piped_from_predict_to_lm_call(): 667 | program = Predict("question -> answer") 668 | dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini")) 669 | 670 | with patch("litellm.completion") as mock_completion: 671 | program( 672 | question="Why did a chicken cross the kitchen?", 673 | prediction={"type": "content", "content": "A chicken crossing the kitchen"}, 674 | ) 675 | 676 | assert mock_completion.call_args[1]["prediction"] == { 677 | "type": "content", 678 | "content": "A chicken crossing the kitchen", 679 | } 680 | 681 | # If the signature has prediction as an input field, and the prediction is not set as the standard predicted output 682 | # format, it should not be passed to the LM. 683 | program = Predict("question, prediction -> judgement") 684 | with patch("litellm.completion") as mock_completion: 685 | program(question="Why did a chicken cross the kitchen?", prediction="To get to the other side!") 686 | 687 | assert "prediction" not in mock_completion.call_args[1] 688 | 689 | 690 | def test_dump_state_pydantic_non_primitive_types(): 691 | class WebsiteInfo(BaseModel): 692 | name: str 693 | url: HttpUrl 694 | description: str | None = None 695 | created_at: datetime 696 | 697 | class TestSignature(dspy.Signature): 698 | website_info: WebsiteInfo = dspy.InputField() 699 | summary: str = dspy.OutputField() 700 | 701 | website_info = WebsiteInfo( 702 | name="Example", 703 | url="https://www.example.com", 704 | description="Test website", 705 | created_at=datetime(2021, 1, 1, 12, 0, 0), 706 | ) 707 | 708 | serialized = serialize_object(website_info) 709 | 710 | assert serialized["url"] == "https://www.example.com/" 711 | assert serialized["created_at"] == "2021-01-01T12:00:00" 712 | 713 | json_str = orjson.dumps(serialized).decode() 714 | reloaded = orjson.loads(json_str) 715 | assert reloaded == serialized 716 | 717 | predictor = Predict(TestSignature) 718 | demo = {"website_info": website_info, "summary": "This is a test website."} 719 | predictor.demos = [demo] 720 | 721 | state = predictor.dump_state() 722 | json_str = orjson.dumps(state).decode() 723 | reloaded_state = orjson.loads(json_str) 724 | 725 | demo_data = reloaded_state["demos"][0] 726 | assert demo_data["website_info"]["url"] == "https://www.example.com/" 727 | assert demo_data["website_info"]["created_at"] == "2021-01-01T12:00:00" 728 | 729 | 730 | def test_trace_size_limit(): 731 | program = Predict("question -> answer") 732 | dspy.settings.configure(lm=DummyLM([{"answer": "Paris"}]), max_trace_size=3) 733 | 734 | for _ in range(10): 735 | program(question="What is the capital of France?") 736 | 737 | assert len(dspy.settings.trace) == 3 738 | 739 | 740 | def test_disable_trace(): 741 | program = Predict("question -> answer") 742 | dspy.settings.configure(lm=DummyLM([{"answer": "Paris"}]), trace=None) 743 | 744 | for _ in range(10): 745 | program(question="What is the capital of France?") 746 | 747 | assert dspy.settings.trace is None 748 | 749 | 750 | def test_per_module_history_size_limit(): 751 | program = Predict("question -> answer") 752 | dspy.settings.configure(lm=DummyLM([{"answer": "Paris"}]), max_history_size=5) 753 | 754 | for _ in range(10): 755 | program(question="What is the capital of France?") 756 | assert len(program.history) == 5 757 | 758 | 759 | def test_per_module_history_disabled(): 760 | program = Predict("question -> answer") 761 | dspy.settings.configure(lm=DummyLM([{"answer": "Paris"}]), disable_history=True) 762 | 763 | for _ in range(10): 764 | 
program(question="What is the capital of France?") 765 | assert len(program.history) == 0 766 | ``` -------------------------------------------------------------------------------- /tests/teleprompt/gepa_dummy_lm.json: -------------------------------------------------------------------------------- ```json 1 | {"lm": [{"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the fields `input`, produce the fields `output`."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat does the fox say?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nThe fox says a variety of sounds including barks, yips, and howls. However, the phrase \"What does the fox say?\" is famously known from the viral song by Ylvis, where humorous and nonsensical sounds are suggested as the fox's call.\n\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nThe fox says a variety of sounds including barks, yips, and howls. However, the phrase \"What does the fox say?\" is famously known from the viral song by Ylvis, where humorous and nonsensical sounds are suggested as the fox's call.\n\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.0001624, "timestamp": "2025-08-13T18:20:44.059820", "uuid": "93b03c3f-7e96-43a0-bac1-2ececc179365", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the fields `input`, produce the fields `output`."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat is the color of the sky?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nThe color of the sky is typically blue during the day due to the scattering of sunlight by the atmosphere. However, it can appear different colors at sunrise, sunset, or under various weather conditions.\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nThe color of the sky is typically blue during the day due to the scattering of sunlight by the atmosphere. However, it can appear different colors at sunrise, sunset, or under various weather conditions.\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.000136, "timestamp": "2025-08-13T18:20:44.060309", "uuid": "b3367637-304b-46cf-9147-6a976a8db439", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. 
`output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the fields `input`, produce the fields `output`."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat is the color of the sky?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nThe color of the sky is typically blue during the day due to the scattering of sunlight by the atmosphere. However, it can appear different colors at sunrise, sunset, or under various weather conditions.\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nThe color of the sky is typically blue during the day due to the scattering of sunlight by the atmosphere. However, it can appear different colors at sunrise, sunset, or under various weather conditions.\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.000136, "timestamp": "2025-08-13T18:20:44.074500", "uuid": "34df9ccb-128c-4a66-8ee7-c17de080d2ad", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the fields `input`, produce the fields `output`."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat does the fox say?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nThe fox says a variety of sounds including barks, yips, and howls. However, the phrase \"What does the fox say?\" is famously known from the viral song by Ylvis, where humorous and nonsensical sounds are suggested as the fox's call.\n\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nThe fox says a variety of sounds including barks, yips, and howls. However, the phrase \"What does the fox say?\" is famously known from the viral song by Ylvis, where humorous and nonsensical sounds are suggested as the fox's call.\n\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.0001624, "timestamp": "2025-08-13T18:20:44.076517", "uuid": "eae79449-93ab-474c-887e-e8ecde69f8e8", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. 
`output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the fields `input`, produce the fields `output`."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat does the fox say?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nThe fox says a variety of sounds including barks, yips, and howls. However, the phrase \"What does the fox say?\" is famously known from the viral song by Ylvis, where humorous and nonsensical sounds are suggested as the fox's call.\n\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nThe fox says a variety of sounds including barks, yips, and howls. However, the phrase \"What does the fox say?\" is famously known from the viral song by Ylvis, where humorous and nonsensical sounds are suggested as the fox's call.\n\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.0001624, "timestamp": "2025-08-13T18:20:44.078480", "uuid": "6768dd52-0e72-4648-891b-4d60bea205ec", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the field `input` containing a question or phrase, produce the field `output` containing the exact, direct, and contextually appropriate answer or response that the user expects, without additional explanations, commentary, or general knowledge unless explicitly requested.\n \n Key details and guidelines:\n \n 1. The `input` field contains a question or phrase that may be literal, factual, or culturally specific (e.g., references to popular culture or memes).\n \n 2. The `output` must be the precise answer or response that directly addresses the `input` as intended by the user, not a general or encyclopedic explanation.\n \n 3. If the `input` is a well-known phrase or question from popular culture (e.g., \"What does the fox say?\"), the `output` should reflect the expected or canonical answer associated with that phrase, rather than a factual or scientific explanation.\n \n 4. Avoid providing additional background information, scientific explanations, or alternative interpretations unless explicitly requested.\n \n 5. The goal is to produce the answer that the user expects or the \"correct\" answer in the context of the question, including culturally recognized or meme-based answers.\n \n 6. If the `input` is a straightforward factual question (e.g., \"What is the color of the sky?\"), provide the commonly accepted direct answer (e.g., \"Blue\") rather than a detailed scientific explanation.\n \n 7. 
The output should be concise, clear, and focused solely on answering the question or phrase in the `input`.\n \n Example:\n \n - Input: \"What is the color of the sky?\"\n - Output: \"Blue.\"\n \n - Input: \"What does the fox say?\"\n - Output: \"Ring-ding-ding-ding-dingeringeding!\"\n \n This approach ensures that the assistant provides the expected, contextually appropriate answers rather than general or overly detailed responses that may be considered incorrect by the user."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat is the color of the sky?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nBlue.\n\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nBlue.\n\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.000234, "timestamp": "2025-08-13T18:20:44.093296", "uuid": "3d8eee4c-f7f0-4f74-a67a-8a43da85da15", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the field `input` containing a question or phrase, produce the field `output` containing the exact, direct, and contextually appropriate answer or response that the user expects, without additional explanations, commentary, or general knowledge unless explicitly requested.\n \n Key details and guidelines:\n \n 1. The `input` field contains a question or phrase that may be literal, factual, or culturally specific (e.g., references to popular culture or memes).\n \n 2. The `output` must be the precise answer or response that directly addresses the `input` as intended by the user, not a general or encyclopedic explanation.\n \n 3. If the `input` is a well-known phrase or question from popular culture (e.g., \"What does the fox say?\"), the `output` should reflect the expected or canonical answer associated with that phrase, rather than a factual or scientific explanation.\n \n 4. Avoid providing additional background information, scientific explanations, or alternative interpretations unless explicitly requested.\n \n 5. The goal is to produce the answer that the user expects or the \"correct\" answer in the context of the question, including culturally recognized or meme-based answers.\n \n 6. If the `input` is a straightforward factual question (e.g., \"What is the color of the sky?\"), provide the commonly accepted direct answer (e.g., \"Blue\") rather than a detailed scientific explanation.\n \n 7. 
The output should be concise, clear, and focused solely on answering the question or phrase in the `input`.\n \n Example:\n \n - Input: \"What is the color of the sky?\"\n - Output: \"Blue.\"\n \n - Input: \"What does the fox say?\"\n - Output: \"Ring-ding-ding-ding-dingeringeding!\"\n \n This approach ensures that the assistant provides the expected, contextually appropriate answers rather than general or overly detailed responses that may be considered incorrect by the user."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat does the fox say?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nRing-ding-ding-ding-dingeringeding!\n\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nRing-ding-ding-ding-dingeringeding!\n\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.0002492, "timestamp": "2025-08-13T18:20:44.094009", "uuid": "ef866572-bbe0-4b47-b64a-af377d18d786", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the field `input` containing a question or phrase, produce the field `output` containing the exact, direct, and contextually appropriate answer or response that the user expects, without additional explanations, commentary, or general knowledge unless explicitly requested.\n \n Key details and guidelines:\n \n 1. The `input` field contains a question or phrase that may be literal, factual, or culturally specific (e.g., references to popular culture or memes).\n \n 2. The `output` must be the precise answer or response that directly addresses the `input` as intended by the user, not a general or encyclopedic explanation.\n \n 3. If the `input` is a well-known phrase or question from popular culture (e.g., \"What does the fox say?\"), the `output` should reflect the expected or canonical answer associated with that phrase, rather than a factual or scientific explanation.\n \n 4. Avoid providing additional background information, scientific explanations, or alternative interpretations unless explicitly requested.\n \n 5. The goal is to produce the answer that the user expects or the \"correct\" answer in the context of the question, including culturally recognized or meme-based answers.\n \n 6. If the `input` is a straightforward factual question (e.g., \"What is the color of the sky?\"), provide the commonly accepted direct answer (e.g., \"Blue\") rather than a detailed scientific explanation.\n \n 7. 
The output should be concise, clear, and focused solely on answering the question or phrase in the `input`.\n \n Example:\n \n - Input: \"What is the color of the sky?\"\n - Output: \"Blue.\"\n \n - Input: \"What does the fox say?\"\n - Output: \"Ring-ding-ding-ding-dingeringeding!\"\n \n This approach ensures that the assistant provides the expected, contextually appropriate answers rather than general or overly detailed responses that may be considered incorrect by the user."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat does the fox say?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nRing-ding-ding-ding-dingeringeding!\n\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nRing-ding-ding-ding-dingeringeding!\n\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.0002492, "timestamp": "2025-08-13T18:20:44.094555", "uuid": "5b223989-d453-4c94-bf74-7c94d328c94d", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the field `input` containing a question or phrase, produce the field `output` containing the exact, direct, and contextually appropriate answer or response that the user expects, without additional explanations, commentary, or general knowledge unless explicitly requested.\n \n Key details and guidelines:\n \n 1. The `input` field contains a question or phrase that may be literal, factual, or culturally specific (e.g., references to popular culture or memes).\n \n 2. The `output` must be the precise answer or response that directly addresses the `input` as intended by the user, not a general or encyclopedic explanation.\n \n 3. If the `input` is a well-known phrase or question from popular culture (e.g., \"What does the fox say?\"), the `output` should reflect the expected or canonical answer associated with that phrase, rather than a factual or scientific explanation.\n \n 4. Avoid providing additional background information, scientific explanations, or alternative interpretations unless explicitly requested.\n \n 5. The goal is to produce the answer that the user expects or the \"correct\" answer in the context of the question, including culturally recognized or meme-based answers.\n \n 6. If the `input` is a straightforward factual question (e.g., \"What is the color of the sky?\"), provide the commonly accepted direct answer (e.g., \"Blue\") rather than a detailed scientific explanation.\n \n 7. 
The output should be concise, clear, and focused solely on answering the question or phrase in the `input`.\n \n Example:\n \n - Input: \"What is the color of the sky?\"\n - Output: \"Blue.\"\n \n - Input: \"What does the fox say?\"\n - Output: \"Ring-ding-ding-ding-dingeringeding!\"\n \n This approach ensures that the assistant provides the expected, contextually appropriate answers rather than general or overly detailed responses that may be considered incorrect by the user."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat does the fox say?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nRing-ding-ding-ding-dingeringeding!\n\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nRing-ding-ding-ding-dingeringeding!\n\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.0002492, "timestamp": "2025-08-13T18:20:44.123406", "uuid": "64ce9bc4-5a84-4eb4-b8fa-5cceddfb8c7c", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the field `input` containing a question or phrase, produce the field `output` containing the exact, direct, and contextually appropriate answer or response that the user expects, without additional explanations, commentary, or general knowledge unless explicitly requested.\n \n Key details and guidelines:\n \n 1. The `input` field contains a question or phrase that may be literal, factual, or culturally specific (e.g., references to popular culture or memes).\n \n 2. The `output` must be the precise answer or response that directly addresses the `input` as intended by the user, not a general or encyclopedic explanation.\n \n 3. If the `input` is a well-known phrase or question from popular culture (e.g., \"What does the fox say?\"), the `output` should reflect the expected or canonical answer associated with that phrase, rather than a factual or scientific explanation.\n \n 4. Avoid providing additional background information, scientific explanations, or alternative interpretations unless explicitly requested.\n \n 5. The goal is to produce the answer that the user expects or the \"correct\" answer in the context of the question, including culturally recognized or meme-based answers.\n \n 6. If the `input` is a straightforward factual question (e.g., \"What is the color of the sky?\"), provide the commonly accepted direct answer (e.g., \"Blue\") rather than a detailed scientific explanation.\n \n 7. 
The output should be concise, clear, and focused solely on answering the question or phrase in the `input`.\n \n Example:\n \n - Input: \"What is the color of the sky?\"\n - Output: \"Blue.\"\n \n - Input: \"What does the fox say?\"\n - Output: \"Ring-ding-ding-ding-dingeringeding!\"\n \n This approach ensures that the assistant provides the expected, contextually appropriate answers rather than general or overly detailed responses that may be considered incorrect by the user."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat is the color of the sky?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nBlue.\n\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nBlue.\n\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.000234, "timestamp": "2025-08-13T18:20:44.125206", "uuid": "5b38b61a-e139-4962-b788-570b3cbb5c38", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}], "reflection_lm": [{"prompt": "I provided an assistant with the following instructions to perform a task for me:\n```\nGiven the fields `input`, produce the fields `output`.\n```\n\nThe following are examples of different task inputs provided to the assistant along with the assistant's response for each of them, and some feedback on how the assistant's response could be better:\n```\n# Example 1\n## Inputs\n### input\nWhat is the color of the sky?\n\n## Generated Outputs\n### output\nThe color of the sky is typically blue during the day due to the scattering of sunlight by the atmosphere. However, it can appear different colors at sunrise, sunset, or under various weather conditions.\n\n## Feedback\nWrong answer.\n\n\n\n# Example 2\n## Inputs\n### input\nWhat does the fox say?\n\n## Generated Outputs\n### output\nThe fox says a variety of sounds including barks, yips, and howls. However, the phrase \"What does the fox say?\" is famously known from the viral song by Ylvis, where humorous and nonsensical sounds are suggested as the fox's call.\n\n## Feedback\nWrong answer.\n\n\n\n# Example 3\n## Inputs\n### input\nWhat does the fox say?\n\n## Generated Outputs\n### output\nThe fox says a variety of sounds including barks, yips, and howls. However, the phrase \"What does the fox say?\" is famously known from the viral song by Ylvis, where humorous and nonsensical sounds are suggested as the fox's call.\n\n## Feedback\nWrong answer.\n\n\n```\n\nYour task is to write a new instruction for the assistant.\n\nRead the inputs carefully and identify the input format and infer detailed task description about the task I wish to solve with the assistant.\n\nRead all the assistant responses and the corresponding feedback. Identify all niche and domain specific factual information about the task and include it in the instruction, as a lot of it may not be available to the assistant in the future. The assistant may have utilized a generalizable strategy to solve the task, if so, include that in the instruction as well.\n\nProvide the new instructions within ``` blocks.", "messages": null, "kwargs": {}, "response": "``` \nGiven the field `input` containing a question or phrase, produce the field `output` containing the exact, direct, and contextually appropriate answer or response that the user expects, without additional explanations, commentary, or general knowledge unless explicitly requested.\n\nKey details and guidelines:\n\n1. 
The `input` field contains a question or phrase that may be literal, factual, or culturally specific (e.g., references to popular culture or memes).\n\n2. The `output` must be the precise answer or response that directly addresses the `input` as intended by the user, not a general or encyclopedic explanation.\n\n3. If the `input` is a well-known phrase or question from popular culture (e.g., \"What does the fox say?\"), the `output` should reflect the expected or canonical answer associated with that phrase, rather than a factual or scientific explanation.\n\n4. Avoid providing additional background information, scientific explanations, or alternative interpretations unless explicitly requested.\n\n5. The goal is to produce the answer that the user expects or the \"correct\" answer in the context of the question, including culturally recognized or meme-based answers.\n\n6. If the `input` is a straightforward factual question (e.g., \"What is the color of the sky?\"), provide the commonly accepted direct answer (e.g., \"Blue\") rather than a detailed scientific explanation.\n\n7. The output should be concise, clear, and focused solely on answering the question or phrase in the `input`.\n\nExample:\n\n- Input: \"What is the color of the sky?\"\n- Output: \"Blue.\"\n\n- Input: \"What does the fox say?\"\n- Output: \"Ring-ding-ding-ding-dingeringeding!\"\n\nThis approach ensures that the assistant provides the expected, contextually appropriate answers rather than general or overly detailed responses that may be considered incorrect by the user.\n```", "outputs": ["``` \nGiven the field `input` containing a question or phrase, produce the field `output` containing the exact, direct, and contextually appropriate answer or response that the user expects, without additional explanations, commentary, or general knowledge unless explicitly requested.\n\nKey details and guidelines:\n\n1. The `input` field contains a question or phrase that may be literal, factual, or culturally specific (e.g., references to popular culture or memes).\n\n2. The `output` must be the precise answer or response that directly addresses the `input` as intended by the user, not a general or encyclopedic explanation.\n\n3. If the `input` is a well-known phrase or question from popular culture (e.g., \"What does the fox say?\"), the `output` should reflect the expected or canonical answer associated with that phrase, rather than a factual or scientific explanation.\n\n4. Avoid providing additional background information, scientific explanations, or alternative interpretations unless explicitly requested.\n\n5. The goal is to produce the answer that the user expects or the \"correct\" answer in the context of the question, including culturally recognized or meme-based answers.\n\n6. If the `input` is a straightforward factual question (e.g., \"What is the color of the sky?\"), provide the commonly accepted direct answer (e.g., \"Blue\") rather than a detailed scientific explanation.\n\n7. 
The output should be concise, clear, and focused solely on answering the question or phrase in the `input`.\n\nExample:\n\n- Input: \"What is the color of the sky?\"\n- Output: \"Blue.\"\n\n- Input: \"What does the fox say?\"\n- Output: \"Ring-ding-ding-ding-dingeringeding!\"\n\nThis approach ensures that the assistant provides the expected, contextually appropriate answers rather than general or overly detailed responses that may be considered incorrect by the user.\n```"], "usage": {}, "cost": 0.000774, "timestamp": "2025-08-13T18:20:44.080463", "uuid": "c71eee51-af0b-4469-a365-343105013d66", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}]} ``` -------------------------------------------------------------------------------- /dspy/teleprompt/gepa/gepa.py: -------------------------------------------------------------------------------- ```python 1 | import inspect 2 | import logging 3 | import random 4 | from dataclasses import dataclass 5 | from typing import Any, Literal, Optional, Protocol, Union 6 | 7 | from gepa import GEPAResult 8 | from gepa.core.adapter import ProposalFn 9 | from gepa.proposer.reflective_mutation.base import ReflectionComponentSelector 10 | 11 | from dspy.clients.lm import LM 12 | from dspy.primitives import Example, Module, Prediction 13 | from dspy.teleprompt.gepa.gepa_utils import DspyAdapter, DSPyTrace, PredictorFeedbackFn, ScoreWithFeedback 14 | from dspy.teleprompt.teleprompt import Teleprompter 15 | from dspy.utils.annotation import experimental 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | AUTO_RUN_SETTINGS = { 20 | "light": {"n": 6}, 21 | "medium": {"n": 12}, 22 | "heavy": {"n": 18}, 23 | } 24 | 25 | @experimental(version="3.0.0") 26 | class GEPAFeedbackMetric(Protocol): 27 | def __call__( 28 | gold: Example, 29 | pred: Prediction, 30 | trace: Optional["DSPyTrace"], 31 | pred_name: str | None, 32 | pred_trace: Optional["DSPyTrace"], 33 | ) -> Union[float, "ScoreWithFeedback"]: 34 | """ 35 | This function is called with the following arguments: 36 | - gold: The gold example. 37 | - pred: The predicted output. 38 | - trace: Optional. The trace of the program's execution. 39 | - pred_name: Optional. The name of the target predictor currently being optimized by GEPA, for which 40 | the feedback is being requested. 41 | - pred_trace: Optional. The trace of the target predictor's execution GEPA is seeking feedback for. 42 | 43 | Note the `pred_name` and `pred_trace` arguments. During optimization, GEPA will call the metric to obtain 44 | feedback for individual predictors being optimized. GEPA provides the name of the predictor in `pred_name` 45 | and the sub-trace (of the trace) corresponding to the predictor in `pred_trace`. 46 | If available at the predictor level, the metric should return dspy.Prediction(score: float, feedback: str) corresponding 47 | to the predictor. 48 | If not available at the predictor level, the metric can also return a text feedback at the program level 49 | (using just the gold, pred and trace). 50 | If no feedback is returned, GEPA will use a simple text feedback consisting of just the score: 51 | f"This trajectory got a score of {score}." 52 | """ 53 | ... 54 | 55 | @experimental(version="3.0.0") 56 | @dataclass(frozen=True) 57 | class DspyGEPAResult: 58 | """ 59 | Additional data related to the GEPA run. 
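# --- Editor's illustrative sketch (not part of gepa.py): a metric that satisfies the
# `GEPAFeedbackMetric` protocol documented above. The exact-match scoring, the `answer`
# field, and the `import dspy` line are assumptions made for this example only.
import dspy  # assumed available in the caller's environment

def example_metric(gold, pred, trace=None, pred_name=None, pred_trace=None):
    score = float(gold.answer.strip().lower() == pred.answer.strip().lower())
    feedback = (
        "Correct answer." if score == 1.0
        else f"Expected '{gold.answer}' but got '{pred.answer}'."
    )
    # GEPA accepts either a bare float or a dspy.Prediction carrying textual feedback.
    return dspy.Prediction(score=score, feedback=feedback)
# --- End of editor's sketch ---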
60 | 61 | Fields: 62 | - candidates: list of proposed candidates (component_name -> component_text) 63 | - parents: lineage info; for each candidate i, parents[i] is a list of parent indices or None 64 | - val_aggregate_scores: per-candidate aggregate score on the validation set (higher is better) 65 | - val_subscores: per-candidate per-instance scores on the validation set (len == num_val_instances) 66 | - per_val_instance_best_candidates: for each val instance t, a set of candidate indices achieving the best score on t 67 | - discovery_eval_counts: Budget (number of metric calls / rollouts) consumed up to the discovery of each candidate 68 | 69 | - total_metric_calls: total number of metric calls made across the run 70 | - num_full_val_evals: number of full validation evaluations performed 71 | - log_dir: where artifacts were written (if any) 72 | - seed: RNG seed for reproducibility (if known) 73 | 74 | - best_idx: candidate index with the highest val_aggregate_scores 75 | - best_candidate: the program text mapping for best_idx 76 | """ 77 | # Data about the proposed candidates 78 | candidates: list[Module] 79 | parents: list[list[int | None]] 80 | val_aggregate_scores: list[float] 81 | val_subscores: list[list[float]] 82 | per_val_instance_best_candidates: list[set[int]] 83 | discovery_eval_counts: list[int] 84 | 85 | # Optional data 86 | best_outputs_valset: list[list[tuple[int, list[Prediction]]]] | None = None 87 | 88 | # Optimization metadata 89 | total_metric_calls: int | None = None 90 | num_full_val_evals: int | None = None 91 | log_dir: str | None = None 92 | seed: int | None = None 93 | 94 | @property 95 | def best_idx(self) -> int: 96 | scores = self.val_aggregate_scores 97 | return max(range(len(scores)), key=lambda i: scores[i]) 98 | 99 | @property 100 | def best_candidate(self) -> dict[str, str]: 101 | return self.candidates[self.best_idx] 102 | 103 | @property 104 | def highest_score_achieved_per_val_task(self) -> list[float]: 105 | return [ 106 | self.val_subscores[list(self.per_val_instance_best_candidates[val_idx])[0]][val_idx] 107 | for val_idx in range(len(self.val_subscores[0])) 108 | ] 109 | 110 | def to_dict(self) -> dict[str, Any]: 111 | cands = [ 112 | {k: v for k, v in cand.items()} 113 | for cand in self.candidates 114 | ] 115 | 116 | return dict( 117 | candidates=cands, 118 | parents=self.parents, 119 | val_aggregate_scores=self.val_aggregate_scores, 120 | best_outputs_valset=self.best_outputs_valset, 121 | val_subscores=self.val_subscores, 122 | per_val_instance_best_candidates=[list(s) for s in self.per_val_instance_best_candidates], 123 | discovery_eval_counts=self.discovery_eval_counts, 124 | total_metric_calls=self.total_metric_calls, 125 | num_full_val_evals=self.num_full_val_evals, 126 | log_dir=self.log_dir, 127 | seed=self.seed, 128 | best_idx=self.best_idx, 129 | ) 130 | 131 | @staticmethod 132 | def from_gepa_result(gepa_result: "GEPAResult", adapter: "DspyAdapter") -> "DspyGEPAResult": 133 | return DspyGEPAResult( 134 | candidates=[adapter.build_program(c) for c in gepa_result.candidates], 135 | parents=gepa_result.parents, 136 | val_aggregate_scores=gepa_result.val_aggregate_scores, 137 | best_outputs_valset=gepa_result.best_outputs_valset, 138 | val_subscores=gepa_result.val_subscores, 139 | per_val_instance_best_candidates=gepa_result.per_val_instance_best_candidates, 140 | discovery_eval_counts=gepa_result.discovery_eval_counts, 141 | total_metric_calls=gepa_result.total_metric_calls, 142 | num_full_val_evals=gepa_result.num_full_val_evals, 
143 | log_dir=gepa_result.run_dir, 144 | seed=gepa_result.seed, 145 | ) 146 | 147 | @experimental(version="3.0.0") 148 | class GEPA(Teleprompter): 149 | """ 150 | GEPA is an evolutionary optimizer, which uses reflection to evolve text components 151 | of complex systems. GEPA is proposed in the paper [GEPA: Reflective Prompt Evolution Can Outperform Reinforcement Learning](https://arxiv.org/abs/2507.19457). 152 | The GEPA optimization engine is provided by the `gepa` package, available from [https://github.com/gepa-ai/gepa](https://github.com/gepa-ai/gepa). 153 | 154 | GEPA captures full traces of the DSPy module's execution, identifies the parts of the trace 155 | corresponding to a specific predictor, and reflects on the behaviour of the predictor to 156 | propose a new instruction for the predictor. GEPA allows users to provide textual feedback 157 | to the optimizer, which is used to guide the evolution of the predictor. The textual feedback 158 | can be provided at the granularity of individual predictors, or at the level of the entire system's 159 | execution. 160 | 161 | To provide feedback to the GEPA optimizer, implement a metric as follows: 162 | ``` 163 | def metric( 164 | gold: Example, 165 | pred: Prediction, 166 | trace: Optional[DSPyTrace] = None, 167 | pred_name: Optional[str] = None, 168 | pred_trace: Optional[DSPyTrace] = None, 169 | ) -> float | ScoreWithFeedback: 170 | \""" 171 | This function is called with the following arguments: 172 | - gold: The gold example. 173 | - pred: The predicted output. 174 | - trace: Optional. The trace of the program's execution. 175 | - pred_name: Optional. The name of the target predictor currently being optimized by GEPA, for which 176 | the feedback is being requested. 177 | - pred_trace: Optional. The trace of the target predictor's execution GEPA is seeking feedback for. 178 | 179 | Note the `pred_name` and `pred_trace` arguments. During optimization, GEPA will call the metric to obtain 180 | feedback for individual predictors being optimized. GEPA provides the name of the predictor in `pred_name` 181 | and the sub-trace (of the trace) corresponding to the predictor in `pred_trace`. 182 | If available at the predictor level, the metric should return {'score': float, 'feedback': str} corresponding 183 | to the predictor. 184 | If not available at the predictor level, the metric can also return a text feedback at the program level 185 | (using just the gold, pred and trace). 186 | If no feedback is returned, GEPA will use a simple text feedback consisting of just the score: 187 | f"This trajectory got a score of {score}." 188 | \""" 189 | ... 190 | ``` 191 | 192 | GEPA can also be used as a batch inference-time search strategy, by passing `valset=trainset, track_stats=True, track_best_outputs=True`, and using the 193 | `detailed_results` attribute of the optimized program (returned by `compile`) to get the Pareto frontier of the batch. `optimized_program.detailed_results.best_outputs_valset` will contain the best outputs for each task in the batch. 194 | 195 | Example: 196 | ``` 197 | gepa = GEPA(metric=metric, track_stats=True) 198 | batch_of_tasks = [dspy.Example(...) for task in tasks] 199 | new_prog = gepa.compile(student, trainset=trainset, valset=batch_of_tasks) 200 | pareto_frontier = new_prog.detailed_results.val_aggregate_scores 201 | # pareto_frontier is a list of scores, one for each task in the batch. 202 | ``` 203 | 204 | Args: 205 | metric: The metric function to use for feedback and evaluation. 
206 | auto: The auto budget to use for the run. Options: "light", "medium", "heavy". 207 | max_full_evals: The maximum number of full evaluations to perform. 208 | max_metric_calls: The maximum number of metric calls to perform. 209 | reflection_minibatch_size: The number of examples to use for reflection in a single GEPA step. Default is 3. 210 | candidate_selection_strategy: The strategy to use for candidate selection. Default is "pareto", 211 | which stochastically selects candidates from the Pareto frontier of all validation scores. 212 | Options: "pareto", "current_best". 213 | reflection_lm: The language model to use for reflection. Required parameter. GEPA benefits from 214 | a strong reflection model. Consider using `dspy.LM(model='gpt-5', temperature=1.0, max_tokens=32000)` 215 | for optimal performance. 216 | skip_perfect_score: Whether to skip examples with perfect scores during reflection. Default is True. 217 | instruction_proposer: Optional custom instruction proposer implementing GEPA's ProposalFn protocol. 218 | **Default: None (recommended for most users)** - Uses GEPA's proven instruction proposer from 219 | the [GEPA library](https://github.com/gepa-ai/gepa), which implements the 220 | [`ProposalFn`](https://github.com/gepa-ai/gepa/blob/main/src/gepa/core/adapter.py). This default 221 | proposer is highly capable and was validated across diverse experiments reported in the GEPA 222 | paper and tutorials. 223 | 224 | See documentation on custom instruction proposers 225 | [here](https://dspy.ai/api/optimizers/GEPA/GEPA_Advanced/#custom-instruction-proposers). 226 | 227 | **Advanced Feature**: Only needed for specialized scenarios: 228 | - **Multi-modal handling**: Processing dspy.Image inputs alongside textual information 229 | - **Nuanced control over constraints**: Fine-grained control over instruction length, format, 230 | and structural requirements beyond standard feedback mechanisms 231 | - **Domain-specific knowledge injection**: Specialized terminology or context that cannot be 232 | provided through feedback_func alone 233 | - **Provider-specific prompting**: Optimizations for specific LLM providers (OpenAI, Anthropic) 234 | with unique formatting preferences 235 | - **Coupled component updates**: Coordinated updates of multiple components together rather 236 | than independent optimization 237 | - **External knowledge integration**: Runtime access to databases, APIs, or knowledge bases 238 | 239 | The default proposer handles the vast majority of use cases effectively. Use 240 | MultiModalInstructionProposer() from dspy.teleprompt.gepa.instruction_proposal for visual 241 | content or implement custom ProposalFn for highly specialized requirements. 242 | 243 | Note: When both instruction_proposer and reflection_lm are set, the instruction_proposer is called 244 | in the reflection_lm context. However, reflection_lm is optional when using a custom instruction_proposer. 245 | Custom instruction proposers can invoke their own LLMs if needed. 246 | component_selector: Custom component selector implementing the ReflectionComponentSelector protocol, 247 | or a string specifying a built-in selector strategy. Controls which components (predictors) are selected 248 | for optimization at each iteration. Defaults to 'round_robin' strategy which cycles through components 249 | one at a time. Available string options: 'round_robin' (cycles through components sequentially), 250 | 'all' (selects all components for simultaneous optimization). 
Custom selectors can implement strategies 251 | using LLM-driven selection logic based on optimization state and trajectories. 252 | See [gepa component selectors](https://github.com/gepa-ai/gepa/blob/main/src/gepa/strategies/component_selector.py) 253 | for available built-in selectors and the ReflectionComponentSelector protocol for implementing custom selectors. 254 | add_format_failure_as_feedback: Whether to add format failures as feedback. Default is False. 255 | use_merge: Whether to use merge-based optimization. Default is True. 256 | max_merge_invocations: The maximum number of merge invocations to perform. Default is 5. 257 | num_threads: The number of threads to use for evaluation with `Evaluate`. Optional. 258 | failure_score: The score to assign to failed examples. Default is 0.0. 259 | perfect_score: The maximum score achievable by the metric. Default is 1.0. Used by GEPA 260 | to determine if all examples in a minibatch are perfect. 261 | log_dir: The directory to save the logs. GEPA saves elaborate logs, along with all candidate 262 | programs, in this directory. Running GEPA with the same `log_dir` will resume the run 263 | from the last checkpoint. 264 | track_stats: Whether to return detailed results and all proposed programs in the `detailed_results` 265 | attribute of the optimized program. Default is False. 266 | use_wandb: Whether to use wandb for logging. Default is False. 267 | wandb_api_key: The API key to use for wandb. If not provided, wandb will use the API key 268 | from the environment variable `WANDB_API_KEY`. 269 | wandb_init_kwargs: Additional keyword arguments to pass to `wandb.init`. 270 | track_best_outputs: Whether to track the best outputs on the validation set. track_stats must 271 | be True if track_best_outputs is True. The optimized program's `detailed_results.best_outputs_valset` 272 | will contain the best outputs for each task in the validation set. 273 | warn_on_score_mismatch: GEPA (currently) expects the metric to return the same module-level score when 274 | called with and without the pred_name. This flag (defaults to True) determines whether a warning is 275 | raised if a mismatch in module-level and predictor-level score is detected. 276 | seed: The random seed to use for reproducibility. Default is 0. 277 | gepa_kwargs: (Optional) provide additional kwargs to be passed to [gepa.optimize](https://github.com/gepa-ai/gepa/blob/main/src/gepa/api.py) method 278 | 279 | Note: 280 | Budget Configuration: Exactly one of `auto`, `max_full_evals`, or `max_metric_calls` must be provided. 281 | The `auto` parameter provides preset configurations: "light" for quick experimentation, "medium" for 282 | balanced optimization, and "heavy" for thorough optimization. 283 | 284 | Reflection Configuration: The `reflection_lm` parameter is required and should be a strong language model. 285 | GEPA performs best with models like `dspy.LM(model='gpt-5', temperature=1.0, max_tokens=32000)`. 286 | The reflection process analyzes failed examples to generate feedback for program improvement. 287 | 288 | Merge Configuration: GEPA can merge successful program variants using `use_merge=True`. 289 | The `max_merge_invocations` parameter controls how many merge attempts are made during optimization. 290 | 291 | Evaluation Configuration: Use `num_threads` to parallelize evaluation. The `failure_score` and 292 | `perfect_score` parameters help GEPA understand your metric's range and optimize accordingly. 
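# --- Editor's illustrative sketch (not part of gepa.py): a typical end-to-end run that
# exercises the configuration described above, assuming `import dspy` and that
# `example_metric`, `student`, `trainset`, and `valset` are placeholders defined by the
# caller. Exactly one budget option (`auto`, `max_full_evals`, or `max_metric_calls`)
# may be set.
reflection_lm = dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000)
optimizer = GEPA(
    metric=example_metric,
    auto="light",                # or max_full_evals=3, or max_metric_calls=600
    reflection_lm=reflection_lm,
    num_threads=8,
    track_stats=True,
)
optimized = optimizer.compile(student, trainset=trainset, valset=valset)
print(optimized.detailed_results.best_idx)  # available because track_stats=True
# --- End of editor's sketch ---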
293 | 294 | Logging Configuration: Set `log_dir` to save detailed logs and enable checkpoint resuming. 295 | Use `track_stats=True` to access detailed optimization results via the `detailed_results` attribute. 296 | Enable `use_wandb=True` for experiment tracking and visualization. 297 | 298 | Reproducibility: Set `seed` to ensure consistent results across runs with the same configuration. 299 | """ 300 | def __init__( 301 | self, 302 | metric: GEPAFeedbackMetric, 303 | *, 304 | # Budget configuration 305 | auto: Literal["light", "medium", "heavy"] | None = None, 306 | max_full_evals: int | None = None, 307 | max_metric_calls: int | None = None, 308 | # Reflection configuration 309 | reflection_minibatch_size: int = 3, 310 | candidate_selection_strategy: Literal["pareto", "current_best"] = "pareto", 311 | reflection_lm: LM | None = None, 312 | skip_perfect_score: bool = True, 313 | add_format_failure_as_feedback: bool = False, 314 | instruction_proposer: "ProposalFn | None" = None, 315 | component_selector: "ReflectionComponentSelector | str" = "round_robin", 316 | # Merge-based configuration 317 | use_merge: bool = True, 318 | max_merge_invocations: int | None = 5, 319 | # Evaluation configuration 320 | num_threads: int | None = None, 321 | failure_score: float = 0.0, 322 | perfect_score: float = 1.0, 323 | # Logging 324 | log_dir: str = None, 325 | track_stats: bool = False, 326 | use_wandb: bool = False, 327 | wandb_api_key: str | None = None, 328 | wandb_init_kwargs: dict[str, Any] | None = None, 329 | track_best_outputs: bool = False, 330 | warn_on_score_mismatch: bool = True, 331 | use_mlflow: bool = False, 332 | # Reproducibility 333 | seed: int | None = 0, 334 | # GEPA passthrough kwargs 335 | gepa_kwargs: dict | None = None 336 | ): 337 | try: 338 | inspect.signature(metric).bind(None, None, None, None, None) 339 | except TypeError as e: 340 | raise TypeError( 341 | "GEPA metric must accept five arguments: (gold, pred, trace, pred_name, pred_trace). " 342 | "See https://dspy.ai/api/optimizers/GEPA for details." 343 | ) from e 344 | 345 | self.metric_fn = metric 346 | 347 | # Budget configuration 348 | assert ( 349 | (max_metric_calls is not None) + 350 | (max_full_evals is not None) + 351 | (auto is not None) 352 | == 1 353 | ), ( 354 | "Exactly one of max_metric_calls, max_full_evals, auto must be set. " 355 | f"You set max_metric_calls={max_metric_calls}, " 356 | f"max_full_evals={max_full_evals}, " 357 | f"auto={auto}." 358 | ) 359 | self.auto = auto 360 | self.max_full_evals = max_full_evals 361 | self.max_metric_calls = max_metric_calls 362 | 363 | # Reflection configuration 364 | self.reflection_minibatch_size = reflection_minibatch_size 365 | self.candidate_selection_strategy = candidate_selection_strategy 366 | 367 | assert reflection_lm is not None or instruction_proposer is not None, ( 368 | "GEPA requires a reflection language model, or custom instruction proposer to be provided. " 369 | "Typically, you can use `dspy.LM(model='gpt-5', temperature=1.0, max_tokens=32000)` to get a good reflection model. " 370 | "Reflection LM is used by GEPA to reflect on the behavior of the program and propose new instructions, and will benefit from a strong model. 
" 371 | ) 372 | 373 | self.reflection_lm = reflection_lm 374 | self.skip_perfect_score = skip_perfect_score 375 | self.add_format_failure_as_feedback = add_format_failure_as_feedback 376 | 377 | # Merge-based configuration 378 | self.use_merge = use_merge 379 | self.max_merge_invocations = max_merge_invocations 380 | 381 | # Evaluation Configuration 382 | self.num_threads = num_threads 383 | self.failure_score = failure_score 384 | self.perfect_score = perfect_score 385 | 386 | # Logging configuration 387 | self.log_dir = log_dir 388 | self.track_stats = track_stats 389 | self.use_wandb = use_wandb 390 | self.wandb_api_key = wandb_api_key 391 | self.wandb_init_kwargs = wandb_init_kwargs 392 | self.warn_on_score_mismatch = warn_on_score_mismatch 393 | self.use_mlflow = use_mlflow 394 | 395 | if track_best_outputs: 396 | assert track_stats, "track_stats must be True if track_best_outputs is True." 397 | self.track_best_outputs = track_best_outputs 398 | 399 | # Reproducibility 400 | self.seed = seed 401 | 402 | self.custom_instruction_proposer = instruction_proposer 403 | self.component_selector = component_selector 404 | self.gepa_kwargs = gepa_kwargs or {} 405 | 406 | def auto_budget(self, num_preds, num_candidates, valset_size: int, minibatch_size: int = 35, full_eval_steps: int = 5) -> int: 407 | import numpy as np 408 | num_trials = int(max(2 * (num_preds * 2) * np.log2(num_candidates), 1.5 * num_candidates)) 409 | if num_trials < 0 or valset_size < 0 or minibatch_size < 0: 410 | raise ValueError("num_trials, valset_size, and minibatch_size must be >= 0.") 411 | if full_eval_steps < 1: 412 | raise ValueError("full_eval_steps must be >= 1.") 413 | 414 | V = valset_size 415 | N = num_trials 416 | M = minibatch_size 417 | m = full_eval_steps 418 | 419 | # Initial full evaluation on the default program 420 | total = V 421 | 422 | # Assume upto 5 trials for bootstrapping each candidate 423 | total += num_candidates * 5 424 | 425 | # N minibatch evaluations 426 | total += N * M 427 | if N == 0: 428 | return total # no periodic/full evals inside the loop 429 | # Periodic full evals occur when trial_num % (m+1) == 0, where trial_num runs 2..N+1 430 | periodic_fulls = (N + 1) // (m) + 1 431 | # If 1 <= N < m, the code triggers one final full eval at the end 432 | extra_final = 1 if N < m else 0 433 | 434 | total += (periodic_fulls + extra_final) * V 435 | return total 436 | 437 | def compile( 438 | self, 439 | student: Module, 440 | *, 441 | trainset: list[Example], 442 | teacher: Module | None = None, 443 | valset: list[Example] | None = None, 444 | ) -> Module: 445 | """ 446 | GEPA uses the trainset to perform reflective updates to the prompt, but uses the valset for tracking Pareto scores. 447 | If no valset is provided, GEPA will use the trainset for both. 448 | 449 | Parameters: 450 | - student: The student module to optimize. 451 | - trainset: The training set to use for reflective updates. 452 | - valset: The validation set to use for tracking Pareto scores. If not provided, GEPA will use the trainset for both. 453 | """ 454 | from gepa import GEPAResult, optimize 455 | 456 | from dspy.teleprompt.gepa.gepa_utils import DspyAdapter, LoggerAdapter 457 | 458 | assert trainset is not None and len(trainset) > 0, "Trainset must be provided and non-empty" 459 | assert teacher is None, "Teacher is not supported in DspyGEPA yet." 
460 | 461 | if self.auto is not None: 462 | self.max_metric_calls = self.auto_budget( 463 | num_preds=len(student.predictors()), 464 | num_candidates=AUTO_RUN_SETTINGS[self.auto]["n"], 465 | valset_size=len(valset) if valset is not None else len(trainset), 466 | ) 467 | elif self.max_full_evals is not None: 468 | self.max_metric_calls = self.max_full_evals * (len(trainset) + (len(valset) if valset is not None else 0)) 469 | else: 470 | assert self.max_metric_calls is not None, "Either auto, max_full_evals, or max_metric_calls must be set." 471 | 472 | logger.info(f"Running GEPA for approx {self.max_metric_calls} metric calls of the program. This amounts to {self.max_metric_calls / len(trainset) if valset is None else self.max_metric_calls / (len(trainset) + len(valset)):.2f} full evals on the {'train' if valset is None else 'train+val'} set.") 473 | 474 | if valset is None: 475 | logger.warning("No valset provided; Using trainset as valset. This is useful as an inference-time scaling strategy where you want GEPA to find the best solutions for the provided tasks in the trainset, as it makes GEPA overfit prompts to the provided trainset. In order to ensure generalization and perform well on unseen tasks, please provide separate trainset and valset. Provide the smallest valset that is just large enough to match the downstream task distribution, while keeping trainset as large as possible.") 476 | valset = valset or trainset 477 | logger.info(f"Using {len(valset)} examples for tracking Pareto scores. You can consider using a smaller sample of the valset to allow GEPA to explore more diverse solutions within the same budget. GEPA requires you to provide the smallest valset that is just large enough to match your downstream task distribution, while providing as large trainset as possible.") 478 | 479 | rng = random.Random(self.seed) 480 | 481 | def feedback_fn_creator(pred_name: str, predictor) -> "PredictorFeedbackFn": 482 | def feedback_fn( 483 | predictor_output: dict[str, Any], 484 | predictor_inputs: dict[str, Any], 485 | module_inputs: Example, 486 | module_outputs: Prediction, 487 | captured_trace: "DSPyTrace", 488 | ) -> "ScoreWithFeedback": 489 | trace_for_pred = [(predictor, predictor_inputs, predictor_output)] 490 | o = self.metric_fn( 491 | module_inputs, 492 | module_outputs, 493 | captured_trace, 494 | pred_name, 495 | trace_for_pred, 496 | ) 497 | if hasattr(o, "feedback"): 498 | if o["feedback"] is None: 499 | o["feedback"] = f"This trajectory got a score of {o['score']}." 
500 | return o 501 | else: 502 | return dict(score=o, feedback=f"This trajectory got a score of {o}.") 503 | return feedback_fn 504 | 505 | feedback_map = { 506 | k: feedback_fn_creator(k, v) 507 | for k, v in student.named_predictors() 508 | } 509 | 510 | # Build the DSPy adapter that encapsulates evaluation, trace capture, feedback extraction, and instruction proposal 511 | adapter = DspyAdapter( 512 | student_module=student, 513 | metric_fn=self.metric_fn, 514 | feedback_map=feedback_map, 515 | failure_score=self.failure_score, 516 | num_threads=self.num_threads, 517 | add_format_failure_as_feedback=self.add_format_failure_as_feedback, 518 | rng=rng, 519 | reflection_lm=self.reflection_lm, 520 | custom_instruction_proposer=self.custom_instruction_proposer, 521 | warn_on_score_mismatch=self.warn_on_score_mismatch 522 | ) 523 | 524 | # Instantiate GEPA with the simpler adapter-based API 525 | base_program = {name: pred.signature.instructions for name, pred in student.named_predictors()} 526 | gepa_result: GEPAResult = optimize( 527 | seed_candidate=base_program, 528 | trainset=trainset, 529 | valset=valset, 530 | adapter=adapter, 531 | 532 | # Reflection-based configuration 533 | reflection_lm=(lambda x: self.reflection_lm(x)[0]) if self.reflection_lm is not None else None, 534 | candidate_selection_strategy=self.candidate_selection_strategy, 535 | skip_perfect_score=self.skip_perfect_score, 536 | reflection_minibatch_size=self.reflection_minibatch_size, 537 | module_selector=self.component_selector, 538 | 539 | perfect_score=self.perfect_score, 540 | 541 | # Merge-based configuration 542 | use_merge=self.use_merge, 543 | max_merge_invocations=self.max_merge_invocations, 544 | 545 | # Budget 546 | max_metric_calls=self.max_metric_calls, 547 | 548 | # Logging 549 | logger=LoggerAdapter(logger), 550 | run_dir=self.log_dir, 551 | use_wandb=self.use_wandb, 552 | wandb_api_key=self.wandb_api_key, 553 | wandb_init_kwargs=self.wandb_init_kwargs, 554 | use_mlflow=self.use_mlflow, 555 | track_best_outputs=self.track_best_outputs, 556 | display_progress_bar=True, 557 | raise_on_exception=True, 558 | 559 | # Reproducibility 560 | seed=self.seed, 561 | **self.gepa_kwargs 562 | ) 563 | 564 | new_prog = adapter.build_program(gepa_result.best_candidate) 565 | 566 | if self.track_stats: 567 | dspy_gepa_result = DspyGEPAResult.from_gepa_result(gepa_result, adapter) 568 | new_prog.detailed_results = dspy_gepa_result 569 | 570 | return new_prog 571 | ``` -------------------------------------------------------------------------------- /tests/reliability/generate/utils.py: -------------------------------------------------------------------------------- ```python 1 | import importlib.util 2 | import json 3 | import os 4 | import pathlib 5 | import random 6 | import re 7 | import shutil 8 | import sys 9 | import tempfile 10 | from contextlib import contextmanager 11 | from dataclasses import dataclass 12 | from functools import wraps 13 | from typing import Any, Dict, List, Optional, Tuple 14 | 15 | import pydantic 16 | from datamodel_code_generator import InputFileType, generate 17 | 18 | import dspy 19 | from tests.reliability.utils import assert_program_output_correct, judge_dspy_configuration 20 | 21 | 22 | def _retry(retries): 23 | """ 24 | A decorator to retry a function a specified number of times. 25 | 26 | Args: 27 | retries (int): The number of retries before failing. 
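# --- Editor's illustrative sketch (not part of utils.py): using the `_retry` decorator
# defined above on a hypothetical flaky helper. The decorator simply re-invokes the
# wrapped function up to `retries` times before letting the last exception propagate.
@_retry(retries=3)
def _flaky_generation_step():
    ...  # occasionally raises; retried up to 3 times
# --- End of editor's sketch ---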
28 | """ 29 | 30 | def decorator(func): 31 | @wraps(func) 32 | def wrapper(*args, **kwargs): 33 | attempt = 0 34 | while attempt < retries: 35 | try: 36 | return func(*args, **kwargs) 37 | except Exception as e: 38 | attempt += 1 39 | print(f"Retrying {func.__name__} (attempt {attempt} of {retries})." f" Exception: {e}") 40 | if attempt >= retries: 41 | raise e 42 | 43 | return wrapper 44 | 45 | return decorator 46 | 47 | 48 | @_retry(retries=5) 49 | def generate_test_program(dst_path: str, additional_instructions: Optional[str] = None) -> dspy.Module: 50 | """ 51 | Generate a DSPy program for a reliability test case and save it to a destination path. 52 | 53 | Args: 54 | dst_path: The directory path to which to save the generated program. 55 | additional_instructions: Additional instructions for generating the program signature. 56 | Return: 57 | A dspy.Module object representing the generated program. 58 | """ 59 | 60 | def generate_models(schema: dict[str, Any], class_name: str) -> str: 61 | with tempfile.TemporaryDirectory() as tmp_dir: 62 | tmp_schema_path = os.path.join(tmp_dir, "schema.json") 63 | tmp_model_path = os.path.join(tmp_dir, "model.py") 64 | with open(tmp_schema_path, "w") as f: 65 | json.dump(schema, f) 66 | 67 | generate( 68 | input_=pathlib.Path(tmp_schema_path), 69 | input_file_type=InputFileType.JsonSchema, 70 | output=pathlib.Path(tmp_model_path), 71 | class_name=class_name, 72 | # For enums with only one value, use the value as a literal instead of an enum 73 | # in order to test literals 74 | enum_field_as_literal="one", 75 | # Don't use con* field types, which are deprecated in recent pydantic versions 76 | field_constraints=True, 77 | use_annotated=False, 78 | ) 79 | # Remove annotation imports from __future__, which break compatibility with Python's 80 | # built-in type hints 81 | _remove_line_from_file(tmp_model_path, "from __future__ import annotations") 82 | # Remove comments inserted by datamodel-code-generator from the generated model file 83 | _remove_comments_from_file(tmp_model_path) 84 | with open(tmp_model_path, "r") as f: 85 | return f.read() 86 | 87 | def rename_conflicting_fields( 88 | input_schema: dict[str, Any], 89 | output_schema: dict[str, Any], 90 | ) -> dict[str, Any]: 91 | input_fields = set(input_schema.get("properties", {})) 92 | output_schema["properties"] = { 93 | (f"{field}_output" if field in input_fields else field): properties 94 | for field, properties in output_schema.get("properties", {}).items() 95 | } 96 | # Update required fields, if they exist 97 | if "required" in output_schema: 98 | output_schema["required"] = [ 99 | f"{field}_output" if field in input_fields else field for field in output_schema["required"] 100 | ] 101 | return output_schema 102 | 103 | # Disable caching and use a nonzero temperature to ensure that new programs are generated 104 | # upon retry if there's an error in the generation process (e.g. 
the program has an 105 | # invalid signature) 106 | with judge_dspy_configuration(cache=False, temperature=0.5), tempfile.TemporaryDirectory() as tmp_dir: 107 | generated_signature = _get_test_program_generation_program()( 108 | additional_instructions=additional_instructions or "" 109 | ) 110 | input_schema = json.loads(generated_signature.program_input_fields) 111 | output_schema = json.loads(generated_signature.program_output_fields) 112 | # If there are conflicting field names between input and output schemas, rename the output 113 | # fields to avoid conflicts 114 | output_schema = rename_conflicting_fields(input_schema, output_schema) 115 | 116 | # Generate input and output models 117 | input_models = generate_models(schema=input_schema, class_name="ProgramInputs") 118 | output_models = generate_models(schema=output_schema, class_name="ProgramOutputs") 119 | 120 | # Write program code 121 | program_code = ( 122 | "### Input models ###\n" 123 | + input_models 124 | + "\n" 125 | + "### Output models ###\n" 126 | + output_models 127 | + "\n" 128 | + "### Program definition ###\n" 129 | + _get_test_program_signature_and_module_definition( 130 | program_description=generated_signature.program_description 131 | ) 132 | ) 133 | program_path = os.path.join(tmp_dir, "program.py") 134 | with open(program_path, "w") as f: 135 | f.write(program_code) 136 | 137 | # Validate the generated program by loading it before copying it to the destination path 138 | loaded_program, _ = load_generated_program(program_path) 139 | 140 | # Write schema 141 | _write_pretty_json( 142 | data=_clean_schema(_get_json_schema(loaded_program.signature)), 143 | path=os.path.join(tmp_dir, "schema.json"), 144 | ) 145 | 146 | # Copy all generated files to the destination path 147 | os.makedirs(dst_path, exist_ok=True) 148 | shutil.copytree(tmp_dir, dst_path, dirs_exist_ok=True) 149 | 150 | return loaded_program 151 | 152 | 153 | @_retry(retries=5) 154 | def generate_test_inputs( 155 | dst_path: str, 156 | program_path: str, 157 | num_inputs: int, 158 | additional_instructions: Optional[str] = None, 159 | ): 160 | """ 161 | Generate test inputs for a reliability test case and save them to a destination path. 162 | 163 | Args: 164 | dst_path: The directory path to which to save the generated test inputs. 165 | program_path: The path to the program for which to generate test inputs. 166 | num_inputs: The number of test inputs to generate. 167 | additional_instructions: Additional instructions for generating the test inputs. 168 | """ 169 | # Disable caching and use a nonzero temperature to ensure that new inputs are generated 170 | # upon retry if there's an error in the generation process (e.g. 
the input doesn't match the 171 | # program signature) 172 | with judge_dspy_configuration(cache=False, temperature=0.5), tempfile.TemporaryDirectory() as tmp_dir: 173 | program: dspy.Module 174 | program_input_schema: pydantic.BaseModel 175 | program, program_input_schema = load_generated_program(program_path) 176 | signature_json_schema = _get_json_schema(program.signature) 177 | inputs, outputs = _split_schema(signature_json_schema) 178 | generated_test_inputs = _get_test_inputs_generation_program()( 179 | program_description=program.signature.__doc__ or "", 180 | program_input_signature=_write_pretty_json({"properties": _clean_schema(inputs)}), 181 | program_output_signature=_write_pretty_json({"properties": _clean_schema(outputs)}), 182 | additional_instructions=additional_instructions or "", 183 | num_inputs=num_inputs, 184 | ).test_inputs[:num_inputs] 185 | 186 | def find_max_input_number(directory): 187 | if not os.path.exists(directory): 188 | return 0 189 | 190 | max_number = 0 191 | pattern = re.compile(r"input(\d+)\.json") 192 | 193 | for filename in os.listdir(directory): 194 | match = pattern.match(filename) 195 | if match: 196 | number = int(match.group(1)) 197 | max_number = max(max_number, number) 198 | return max_number 199 | 200 | base_input_number = find_max_input_number(dst_path) + 1 201 | for idx, test_input in enumerate(generated_test_inputs): 202 | output_assertions = _get_assertions_generation_program()( 203 | program_description=program.signature.__doc__ or "", 204 | program_input=test_input.program_input, 205 | program_output_signature=_write_pretty_json({"properties": _clean_schema(outputs)}), 206 | ).output_assertions 207 | 208 | # Verify that the generated input is valid JSON and matches the input signature of the 209 | # program before saving it to the destination path 210 | _json_input_to_program_input( 211 | input_schema=program_input_schema, 212 | json_input=test_input.program_input, 213 | ) 214 | 215 | test_input_file_path = os.path.join(tmp_dir, f"input{base_input_number + idx}.json") 216 | json_program_input = json.loads(test_input.program_input) 217 | _write_pretty_json( 218 | data={ 219 | "input": json_program_input, 220 | "assertions": output_assertions, 221 | }, 222 | path=test_input_file_path, 223 | ) 224 | 225 | os.makedirs(dst_path, exist_ok=True) 226 | shutil.copytree(tmp_dir, dst_path, dirs_exist_ok=True) 227 | 228 | 229 | def load_generated_program(path) -> Tuple[dspy.Module, pydantic.BaseModel]: 230 | """ 231 | Loads a generated program from the specified file. 232 | 233 | Args: 234 | path: The path to the file containing the generated program. 235 | Returns: 236 | A tuple containing: 1. a dspy.Module object representing the generated program 237 | and 2. a pydantic.BaseModel object representing the program's input schema. 238 | """ 239 | if os.path.isdir(path): 240 | path = os.path.join(path, "program.py") 241 | if not os.path.exists(path): 242 | raise ValueError(f"DSPy test program file not found: {path}") 243 | 244 | program_module = _import_program_module_from_path(module_name="program", file_path=path) 245 | return program_module.program, program_module.ProgramInputs 246 | 247 | 248 | @dataclass 249 | class GeneratedTestCase: 250 | """ 251 | Represents a DSPy reliability test case that has been generated with the help of a 252 | DSPy program generator and program input generator. 
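# --- Editor's illustrative sketch (not part of utils.py): chaining the generation
# helpers defined above. The destination paths and instruction text are placeholders.
generate_test_program(
    dst_path="generated/chat_qa",
    additional_instructions="A program that answers short factual questions.",
)
generate_test_inputs(
    dst_path="generated/chat_qa/inputs",
    program_path="generated/chat_qa/program.py",
    num_inputs=3,
)
program, ProgramInputs = load_generated_program("generated/chat_qa")
# --- End of editor's sketch ---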
253 | """ 254 | 255 | # The name of the test case for identification / debugging with pytest 256 | name: str 257 | # The local filesystem path to the program that the test case is testing. 258 | program_path: str 259 | # A JSON representation of the input to the program that the test case is testing. 260 | program_input: str 261 | # The assertions that the output of the program must satisfy for the test case to pass. 262 | output_assertions: list[str] 263 | 264 | 265 | def load_generated_cases(dir_path) -> list[GeneratedTestCase]: 266 | """ 267 | Recursively loads generated test cases from the specified directory and its subdirectories. 268 | 269 | Args: 270 | dir_path: The path to the directory containing the generated test cases. 271 | Returns: 272 | A list of GeneratedTestCase objects. 273 | """ 274 | test_cases = [] 275 | 276 | # Walk through all directories and subdirectories in dir_path 277 | for root, dirs, files in os.walk(dir_path): 278 | # Check if the directory contains a program.py and an inputs directory 279 | if "program.py" in files and "inputs" in dirs: 280 | program_path = os.path.join(root, "program.py") 281 | inputs_path = os.path.join(root, "inputs") 282 | 283 | # Load each JSON test input file in the inputs directory 284 | for input_file in os.listdir(inputs_path): 285 | if input_file.endswith(".json"): 286 | with open(os.path.join(inputs_path, input_file), "r") as f: 287 | # Best effort to extract a meaningful enclosing directory name 288 | # from the test path that can be used as part of the test case name 289 | readable_dir_name = os.path.basename(os.path.dirname(os.path.dirname(root))) 290 | test_case_name = ( 291 | f"{readable_dir_name}-" f"{os.path.basename(root)}-" f"{os.path.splitext(input_file)[0]}" 292 | ) 293 | program_input_and_assertions = json.load(f) 294 | program_input = program_input_and_assertions["input"] 295 | assertions = program_input_and_assertions["assertions"] 296 | 297 | # Create a GeneratedTestCase object and add it to the list 298 | test_cases.append( 299 | GeneratedTestCase( 300 | name=test_case_name, 301 | program_path=program_path, 302 | program_input=json.dumps(program_input), 303 | output_assertions=assertions, 304 | ) 305 | ) 306 | 307 | return test_cases 308 | 309 | 310 | def run_generated_case(generated_case: GeneratedTestCase): 311 | """ 312 | Runs a generated reliability test case by 1. running the test case program on the test case 313 | input using the global DSPy configuration and 2. verifying that the output of the program 314 | satisfies the assertions specified in the test case. 315 | 316 | Args: 317 | generated_case: The generated test case to run. 318 | """ 319 | program, program_input_schema = load_generated_program(generated_case.program_path) 320 | program_input = _json_input_to_program_input( 321 | input_schema=program_input_schema, 322 | json_input=generated_case.program_input, 323 | ) 324 | program_output = program(**program_input) 325 | for assertion in generated_case.output_assertions: 326 | assert_program_output_correct( 327 | program_input=program_input, 328 | program_output=program_output, 329 | grading_guidelines=assertion, 330 | ) 331 | 332 | 333 | def _get_test_program_signature_and_module_definition(program_description: str) -> str: 334 | """ 335 | Generate the signature and model definition for a test DSPy program. 336 | 337 | Args: 338 | program_description: A description of the generated program. 
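# --- Editor's illustrative sketch (not part of utils.py): one plausible way to consume
# the `load_generated_cases` / `run_generated_case` helpers defined above in a pytest
# suite. The directory path and the pytest wiring are assumptions, not repository code.
import pytest  # assumed test runner

@pytest.mark.parametrize(
    "case",
    load_generated_cases("tests/reliability/generated"),  # hypothetical directory
    ids=lambda case: case.name,
)
def test_generated_case(case):
    run_generated_case(case)
# --- End of editor's sketch ---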
339 | """ 340 | use_cot = random.choice([True, False]) 341 | if use_cot: 342 | program_var_definition = "program = dspy.ChainOfThought(program_signature)" 343 | else: 344 | program_var_definition = "program = dspy.Predict(program_signature)" 345 | 346 | return ''' 347 | import dspy 348 | 349 | class BaseSignature(dspy.Signature): 350 | """ 351 | {program_description} 352 | """ 353 | 354 | program_signature = BaseSignature 355 | for input_field_name, input_field in ProgramInputs.model_fields.items(): 356 | program_signature = program_signature.append( 357 | name=input_field_name, 358 | field=dspy.InputField(description=input_field.description), 359 | type_=input_field.annotation, 360 | ) 361 | for output_field_name, output_field in ProgramOutputs.model_fields.items(): 362 | program_signature = program_signature.append( 363 | name=output_field_name, 364 | field=dspy.OutputField(description=input_field.description), 365 | type_=output_field.annotation, 366 | ) 367 | 368 | {program_var_definition} 369 | '''.format(program_description=program_description, program_var_definition=program_var_definition) 370 | 371 | 372 | def _get_test_program_generation_program() -> dspy.Module: 373 | """ 374 | Create a DSPy program for generating other DSPy test programs. 375 | 376 | Returns: 377 | A dspy.Module object representing the program generation program. 378 | """ 379 | 380 | class ProgramGeneration(dspy.Signature): 381 | """ 382 | Creates an AI program definition, including the AI program's description, input fields, and output fields. 383 | The AI program should be designed to solve a real problem for its users and produce correct outputs for a variety of inputs. 384 | 385 | The input fields and the output fields must be represented in JSON Schema format, including field names, types, and descriptions. 386 | The JSON schema definitions themselves MUST be valid JSON without any extra text (no backticks, no explanatory text, etc.). 387 | 388 | It's very important to be sure that the additional instructions, if specified, are obeyed 389 | precisely in absolutely all cases. 390 | """ 391 | 392 | additional_instructions: str = dspy.InputField( 393 | description="Additional instructions for what kind of program to generate and how to generate it" 394 | ) 395 | program_description: str = dspy.OutputField( 396 | description="A description of the generated AI program, including its purpose and expected behavior" 397 | ) 398 | program_input_fields: str = dspy.OutputField( 399 | description="The input fields of the generated program in JSON Schema format, including input field names, types, and descriptions." 400 | ) 401 | program_output_fields: str = dspy.OutputField( 402 | description="The output fields of the generated program in JSON Schema format, including input field names, types, and descriptions." 403 | ) 404 | 405 | return dspy.ChainOfThought(ProgramGeneration) 406 | 407 | 408 | def _get_test_inputs_generation_program() -> dspy.Module: 409 | """ 410 | Create a DSPy program for generating test inputs for a given DSPy test program. 411 | 412 | Returns: 413 | A dspy.Module object representing the test input generation program. 414 | """ 415 | 416 | class _TestInputsGeneration(dspy.Signature): 417 | """ 418 | Given the description and input / output signature (format) of an AI program that is designed to produce correct outputs for a variety 419 | of inputs while adhering to the input / output signature, generate test inputs used to verify that the program 420 | indeed produces correct outputs. 
421 |         responses.
422 | 
423 |         When generating an input, do not think about how the program will respond. Instead, focus on creating
424 |         valid and interesting inputs that are likely to test the program's capabilities.
425 | 
426 |         It's very important to be sure that the additional instructions, if specified, are obeyed
427 |         precisely in absolutely all cases.
428 |         """
429 | 
430 |         program_description: str = dspy.InputField(
431 |             description="A description of the AI program being tested, including its purpose and expected behavior"
432 |         )
433 |         program_input_signature: str = dspy.InputField(
434 |             description="The input signature of the program in JSON Schema format, including input field names, types, and descriptions. The outermost fields in the JSON schema definition represent the top-level input fields of the program."
435 |         )
436 |         program_output_signature: str = dspy.InputField(
437 |             description="The output signature of the program in JSON Schema format, including output field names, types, and descriptions. The outermost fields in the JSON schema definition represent the top-level output fields of the program."
438 |         )
439 |         additional_instructions: str = dspy.InputField(description="Additional instructions for generating test inputs")
440 |         test_inputs: list[_TestInput] = dspy.OutputField(
441 |             description="Generated test inputs for the program, used to verify the correctness of the program outputs for a variety of inputs"
442 |         )
443 | 
444 |     return dspy.ChainOfThought(_TestInputsGeneration)
445 | 
446 | 
447 | class _TestInput(pydantic.BaseModel):
448 |     """
449 |     Represents a generated test input for a DSPy program.
450 |     """
451 | 
452 |     program_input: str = pydantic.Field(
453 |         description="Generated input matching the program signature that will be used to test the program, represented as a JSON string."
454 |         " The schema of the JSON string must match the input signature of the program precisely, including any wrapper objects."
455 |         " Be very careful to ensure that the input is valid JSON and matches the input signature of the program, with correct"
456 |         " field nesting."
457 |     )
458 | 
459 | 
460 | def _get_assertions_generation_program() -> dspy.Module:
461 |     """
462 |     Create a DSPy program for generating assertions that verify the correctness of outputs
463 |     from other DSPy programs.
464 |     """
465 | 
466 |     class _TestInputsGeneration(dspy.Signature):
467 |         """
468 |         Given 1. the description and input / output signature (format) of an AI program that is designed to produce correct outputs for a variety
469 |         of inputs while adhering to the input / output signature and 2. an example input to the AI program, generate assertions that can be used
470 |         to verify the correctness of the program output.
471 | 
472 |         Assertions should be expressed in natural language where possible, rather than code. Only
473 |         include code if necessary to clarify the assertion. Assertions should be objective and verifiable,
474 |         with minimal subjectivity only where absolutely necessary.
475 | 
476 |         There should be a limited number of assertions, ideally about 5, that are sufficient to
477 |         verify the correctness of the program output.
478 | 
479 |         If it's too difficult to generate accurate assertions, leave them blank.
480 | """ 481 | 482 | program_description: str = dspy.InputField( 483 | description="A description of the AI program being tested, including its purpose and expected behavior" 484 | ) 485 | program_input: str = dspy.InputField( 486 | description="An example input to the AI program, represented as a JSON string" 487 | ) 488 | program_output_signature: str = dspy.InputField( 489 | description="The output signature of the program in JSON Schema format, including output field names, types, and descriptions. The outermost fields in the JSON schema definition represent the top-level output fields of the program." 490 | ) 491 | output_assertions: list[str] = dspy.OutputField( 492 | description="Assertions used to verify the correctness of the program output after running the program on the specified input" 493 | ) 494 | 495 | return dspy.ChainOfThought(_TestInputsGeneration) 496 | 497 | 498 | def _clean_json_schema_property(prop: dict[str, Any]) -> dict[str, Any]: 499 | """ 500 | Remove unnecessary keys from a JSON schema property dictionary, as well as 501 | all of its child properties. 502 | 503 | Args: 504 | prop: The JSON schema property dictionary to clean. 505 | Returns: 506 | The cleaned JSON schema property dictionary. 507 | """ 508 | cleaned_prop = { 509 | k: v for k, v in prop.items() if k not in {"desc", "__dspy_field_type", "title", "prefix", "required"} 510 | } 511 | 512 | # Recursively clean nested properties 513 | if "properties" in cleaned_prop: 514 | cleaned_prop["properties"] = {k: _clean_json_schema_property(v) for k, v in cleaned_prop["properties"].items()} 515 | 516 | return cleaned_prop 517 | 518 | 519 | def _get_json_schema(signature: dspy.Signature) -> dict[str, Any]: 520 | """ 521 | Obtain the JSON schema representation of a DSPy signature. 522 | 523 | Args: 524 | signature: The DSPy signature for which to generate a JSON schema. 525 | Returns: 526 | A JSON schema representation of the signature. 527 | """ 528 | 529 | def expand_refs(schema: dict[str, Any], definitions: dict[str, Any]) -> dict[str, Any]: 530 | """ 531 | Expand $ref fields in a JSON schema, inlining the referenced schema definitions 532 | directly into the $ref field locations. 533 | """ 534 | if isinstance(schema, dict): 535 | if "$ref" in schema: 536 | ref_path = schema["$ref"].replace("#/$defs/", "") 537 | ref_schema = definitions.get(ref_path, {}) 538 | if "__dspy_field_type" in schema: 539 | ref_schema["__dspy_field_type"] = schema["__dspy_field_type"] 540 | # Recursively expand the reference schema as well 541 | return expand_refs(ref_schema, definitions) 542 | else: 543 | # Recursively expand properties in the schema 544 | return {key: expand_refs(value, definitions) for key, value in schema.items()} 545 | elif isinstance(schema, list): 546 | return [expand_refs(item, definitions) for item in schema] 547 | return schema 548 | 549 | signature_schema_with_refs = signature.schema() 550 | definitions = signature_schema_with_refs.pop("$defs", {}) 551 | return expand_refs(signature_schema_with_refs, definitions) 552 | 553 | 554 | def _split_schema(schema: dict[str, Any]) -> Tuple[dict[str, Any], dict[str, Any]]: 555 | """ 556 | Split a JSON schema into input and output components based on DSPy field types. 557 | 558 | Args: 559 | schema: The JSON schema to split. 560 | Returns: 561 | A tuple containing the input and output components of the schema. 
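    Example (illustrative):
        {"properties": {"question": {"type": "string", "__dspy_field_type": "input"},
                        "answer": {"type": "string", "__dspy_field_type": "output"}}}
        splits into ({"question": {"type": "string"}}, {"answer": {"type": "string"}}).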
562 | """ 563 | inputs = {} 564 | outputs = {} 565 | 566 | # Traverse the properties to categorize inputs and outputs 567 | for key, prop in schema.get("properties", {}).items(): 568 | # Clean the property 569 | cleaned_prop = _clean_schema(prop) 570 | 571 | # Determine if the property is input or output based on __dspy_field_type 572 | field_type = prop.get("__dspy_field_type") 573 | if field_type == "input": 574 | inputs[key] = cleaned_prop 575 | elif field_type == "output" or field_type is None: 576 | outputs[key] = cleaned_prop 577 | 578 | # Handle nested properties for complex models 579 | if "properties" in prop: 580 | nested_inputs, nested_outputs = _split_schema(prop) 581 | if nested_inputs and field_type == "input": 582 | inputs[key] = {"properties": nested_inputs, **cleaned_prop} 583 | elif nested_outputs and (field_type == "output" or field_type is None): 584 | outputs[key] = {"properties": nested_outputs, **cleaned_prop} 585 | 586 | return inputs, outputs 587 | 588 | 589 | def _clean_schema(prop: dict[str, Any]) -> dict[str, Any]: 590 | """ 591 | Recursively clean a JSON schema property by removing unnecessary keys. 592 | 593 | Args: 594 | prop: The JSON schema property to clean. 595 | Returns: 596 | A cleaned version of the property. 597 | """ 598 | keys_to_remove = ["__dspy_field_type", "title"] # Add any other keys to be removed here 599 | 600 | # Iterate through the dictionary, applying cleaning recursively if value is a nested dict 601 | cleaned_prop = { 602 | k: (_clean_schema(v) if isinstance(v, dict) else v) # Recurse if value is a dict 603 | for k, v in prop.items() 604 | if k not in keys_to_remove 605 | } 606 | return cleaned_prop 607 | 608 | 609 | def _json_input_to_program_input(input_schema: pydantic.BaseModel, json_input: str) -> dict[str, Any]: 610 | """ 611 | Convert a JSON input string to a DSPy program input dictionary, validating it against the 612 | provided program signature. 613 | 614 | Args: 615 | input_schema: A pydantic model representing the program input schema. 616 | json_input: The JSON input string to convert to a DSPy program input. 617 | Returns: 618 | The converted DSPy program input dictionary. 619 | """ 620 | json_input = json.loads(json_input) 621 | program_input: pydantic.BaseModel = input_schema.model_validate(json_input) 622 | return {field: getattr(program_input, field) for field in program_input.__fields__} 623 | 624 | 625 | @contextmanager 626 | def _temporarily_prepend_to_system_path(path): 627 | """ 628 | Temporarily prepend a path to the system path for the duration of a context. 629 | 630 | Args: 631 | path: The path to prepend to the system path. 632 | """ 633 | original_sys_path = sys.path.copy() 634 | try: 635 | sys.path.insert(0, path) 636 | yield 637 | finally: 638 | sys.path = original_sys_path 639 | 640 | 641 | def _import_program_module_from_path(module_name: str, file_path: str): 642 | """ 643 | Import a Python module containing a DSPy program from a specified file path. 644 | 645 | Args: 646 | module_name: The name of the module containing the DSPy program to import. 647 | file_path: The path to the file containing the module definition. 
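    Returns:
        The imported module object containing the DSPy program definition.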
648 | """ 649 | program_dir = os.path.dirname(file_path) 650 | 651 | with _temporarily_prepend_to_system_path(program_dir): 652 | spec = importlib.util.spec_from_file_location(module_name, file_path) 653 | module = importlib.util.module_from_spec(spec) 654 | spec.loader.exec_module(module) 655 | return module 656 | 657 | 658 | def _remove_line_from_file(file_path: str, line_to_remove: str): 659 | """ 660 | Remove all instances of a specific line from a file. 661 | 662 | Args: 663 | file_path: The path to the file from which to remove all instances of the line. 664 | line_to_remove: The line to remove from the file. 665 | """ 666 | # Read all lines from the file 667 | with open(file_path, "r") as file: 668 | lines = file.readlines() 669 | 670 | # Write all lines back except the one to remove 671 | with open(file_path, "w") as file: 672 | for line in lines: 673 | if line.strip() != line_to_remove: 674 | file.write(line) 675 | 676 | 677 | def _remove_comments_from_file(file_path: str) -> None: 678 | """ 679 | Removes all lines with comments from the specified file. 680 | 681 | Args: 682 | file_path: Path to the file where comments should be removed. 683 | """ 684 | # Read the file contents 685 | with open(file_path, "r") as file: 686 | lines = file.readlines() 687 | 688 | # Filter out lines that start with '#' 689 | cleaned_lines = [line for line in lines if not line.strip().startswith("#")] 690 | 691 | # Write the cleaned lines back to the file 692 | with open(file_path, "w") as file: 693 | file.writelines(cleaned_lines) 694 | 695 | 696 | def _write_pretty_json(data: dict[str, Any], path: Optional[str] = None) -> Optional[str]: 697 | """ 698 | Format JSON data with indentation, and write it to a file if specified. 699 | 700 | Args: 701 | data: The JSON data to format. 702 | path: The optional path to which to write the formatted JSON data. 703 | Returns: 704 | The formatted JSON data as a string, if no path is specified. 
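    Example (illustrative):
        _write_pretty_json({"input": {"question": "2 + 2?"}, "assertions": ["The answer is 4."]},
                           path="inputs/input1.json")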
705 | """ 706 | formatted_json = json.dumps(data, indent=4) 707 | if path: 708 | with open(path, "w") as f: 709 | f.write(formatted_json) 710 | return None 711 | else: 712 | return formatted_json 713 | ``` -------------------------------------------------------------------------------- /tests/teleprompt/test_gepa.py: -------------------------------------------------------------------------------- ```python 1 | import json 2 | import threading 3 | from typing import Any 4 | from unittest import mock 5 | 6 | import pytest 7 | 8 | import dspy 9 | import dspy.clients 10 | from dspy import Example 11 | from dspy.predict import Predict 12 | from dspy.teleprompt.gepa import instruction_proposal 13 | from dspy.utils.dummies import DummyLM 14 | 15 | 16 | class SimpleModule(dspy.Module): 17 | def __init__(self, signature): 18 | super().__init__() 19 | self.predictor = Predict(signature) 20 | 21 | def forward(self, **kwargs): 22 | return self.predictor(**kwargs) 23 | 24 | 25 | class DictDummyLM(dspy.clients.lm.LM): 26 | def __init__(self, history): 27 | super().__init__("dummy", "chat", 0.0, 1000, True) 28 | self.history = {} 29 | for m in history: 30 | self.history[hash(repr(m["messages"]))] = m 31 | 32 | def __call__(self, prompt=None, messages=None, **kwargs): 33 | assert hash(repr(messages)) in self.history, f"Message {messages} not found in history" 34 | m = self.history[hash(repr(messages))] 35 | return m["outputs"] 36 | 37 | 38 | def simple_metric(example, prediction, trace=None, pred_name=None, pred_trace=None): 39 | return dspy.Prediction(score=example.output == prediction.output, feedback="Wrong answer.") 40 | 41 | 42 | def bad_metric(example, prediction): 43 | return 0.0 44 | 45 | 46 | def test_gepa_adapter_disables_logging_during_trace_capture(monkeypatch): 47 | from dspy.teleprompt import bootstrap_trace as bootstrap_trace_module 48 | from dspy.teleprompt.gepa import gepa_utils 49 | 50 | class DummyModule(dspy.Module): 51 | def forward(self, **kwargs): # pragma: no cover - stub forward 52 | return dspy.Prediction() 53 | 54 | # Exercise the adapter evaluate path directly. 
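    # DspyAdapter bundles the student module, the metric, a feedback map, and a failure
    # score for GEPA; bootstrap_trace_data is monkeypatched below so the test can capture
    # its kwargs and assert that trace capture passes callback_metadata={"disable_logging": True}.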
55 | adapter = gepa_utils.DspyAdapter( 56 | student_module=SimpleModule("input -> output"), 57 | metric_fn=simple_metric, 58 | feedback_map={}, 59 | failure_score=0.0, 60 | ) 61 | 62 | captured_kwargs: dict[str, Any] = {} 63 | 64 | def dummy_bootstrap_trace_data(*args, **kwargs): 65 | captured_kwargs.update(kwargs) 66 | return [] 67 | 68 | monkeypatch.setattr(bootstrap_trace_module, "bootstrap_trace_data", dummy_bootstrap_trace_data) 69 | monkeypatch.setattr( 70 | gepa_utils.DspyAdapter, 71 | "build_program", 72 | lambda self, candidate: DummyModule(), 73 | ) 74 | 75 | adapter.evaluate(batch=[], candidate={}, capture_traces=True) 76 | 77 | assert captured_kwargs["callback_metadata"] == {"disable_logging": True} 78 | 79 | 80 | @pytest.fixture 81 | def mock_mlflow(): 82 | mock_mlflow = mock.MagicMock() 83 | with mock.patch.dict("sys.modules", {"mlflow": mock_mlflow}): 84 | yield mock_mlflow 85 | 86 | 87 | @pytest.mark.parametrize("use_mlflow", [True, False]) 88 | def test_basic_workflow(use_mlflow, mock_mlflow): 89 | """Test to ensure the basic compile flow runs without errors.""" 90 | student = SimpleModule("input -> output") 91 | 92 | with open("tests/teleprompt/gepa_dummy_lm.json") as f: 93 | data = json.load(f) 94 | lm_history = data["lm"] 95 | reflection_lm_history = data["reflection_lm"] 96 | 97 | lm_main = DictDummyLM(lm_history) 98 | dspy.settings.configure(lm=lm_main) 99 | reflection_lm = DictDummyLM(reflection_lm_history) 100 | 101 | optimizer = dspy.GEPA( 102 | metric=simple_metric, 103 | reflection_lm=reflection_lm, 104 | max_metric_calls=5, 105 | use_mlflow=use_mlflow 106 | ) 107 | 108 | 109 | trainset = [ 110 | Example(input="What is the color of the sky?", output="blue").with_inputs("input"), 111 | Example(input="What does the fox say?", output="Ring-ding-ding-ding-dingeringeding!").with_inputs("input"), 112 | ] 113 | 114 | optimized_program = optimizer.compile(student, trainset=trainset, valset=trainset) 115 | assert optimized_program.predictor.signature.instructions == 'Given the field `input` containing a question or phrase, produce the field `output` containing the exact, direct, and contextually appropriate answer or response that the user expects, without additional explanations, commentary, or general knowledge unless explicitly requested.\n\nKey details and guidelines:\n\n1. The `input` field contains a question or phrase that may be literal, factual, or culturally specific (e.g., references to popular culture or memes).\n\n2. The `output` must be the precise answer or response that directly addresses the `input` as intended by the user, not a general or encyclopedic explanation.\n\n3. If the `input` is a well-known phrase or question from popular culture (e.g., "What does the fox say?"), the `output` should reflect the expected or canonical answer associated with that phrase, rather than a factual or scientific explanation.\n\n4. Avoid providing additional background information, scientific explanations, or alternative interpretations unless explicitly requested.\n\n5. The goal is to produce the answer that the user expects or the "correct" answer in the context of the question, including culturally recognized or meme-based answers.\n\n6. If the `input` is a straightforward factual question (e.g., "What is the color of the sky?"), provide the commonly accepted direct answer (e.g., "Blue") rather than a detailed scientific explanation.\n\n7. 
The output should be concise, clear, and focused solely on answering the question or phrase in the `input`.\n\nExample:\n\n- Input: "What is the color of the sky?"\n- Output: "Blue."\n\n- Input: "What does the fox say?"\n- Output: "Ring-ding-ding-ding-dingeringeding!"\n\nThis approach ensures that the assistant provides the expected, contextually appropriate answers rather than general or overly detailed responses that may be considered incorrect by the user.' 116 | if use_mlflow: 117 | assert mock_mlflow.start_run.call_count == 1 118 | else: 119 | assert mock_mlflow.start_run.call_count == 0 120 | 121 | def test_workflow_with_custom_instruction_proposer_and_component_selector(): 122 | """Test to ensure the basic compile flow runs without errors when using a custom instruction proposer and component selector.""" 123 | 124 | class TimeReader(dspy.Module): 125 | def __init__(self): 126 | super().__init__() 127 | self.hour_predictor = dspy.ChainOfThought("clock_photo: dspy.Image -> hour: int") 128 | self.minute_predictor = dspy.ChainOfThought("clock_photo: dspy.Image -> minute: int") 129 | 130 | self.parallel = dspy.Parallel(num_threads=2) 131 | 132 | def forward(self, clock_photo: dspy.Image): 133 | hour_prediction, minute_prediction = self.parallel( 134 | [ 135 | (self.hour_predictor, dict(clock_photo=clock_photo)), 136 | (self.minute_predictor, dict(clock_photo=clock_photo)), 137 | ] 138 | ) 139 | return dspy.Prediction(hour=hour_prediction.hour, minute=minute_prediction.minute) 140 | 141 | def metric(example, prediction, trace=None, pred_name=None, pred_trace=None): 142 | target_hour, target_minute = example.hour, example.minute 143 | predicted_hour, predicted_minute = prediction.hour, prediction.minute 144 | 145 | score = -abs(target_hour * 60 + target_minute - (predicted_hour * 60 + predicted_minute)) 146 | 147 | return dspy.Prediction( 148 | score=score, 149 | feedback=f"Target: {target_hour}:{target_minute}, Predicted: {predicted_hour}:{predicted_minute}", 150 | ) 151 | 152 | def all_component_selector(state, trajectories, subsample_scores, candidate_idx, candidate): 153 | """Select all components.""" 154 | return list(candidate.keys()) 155 | 156 | student = TimeReader() 157 | 158 | with open("tests/teleprompt/gepa_dummy_lm_custom_component_selector_custom_instruction_proposer.json") as f: 159 | data = json.load(f) 160 | 161 | lm_history = data["lm"] 162 | reflection_lm_history = data["reflection_lm"] 163 | 164 | lm_main = DictDummyLM(lm_history) 165 | reflection_lm = DictDummyLM(reflection_lm_history) 166 | 167 | dspy.settings.configure(lm=lm_main) 168 | optimizer = dspy.GEPA( 169 | metric=metric, 170 | reflection_lm=reflection_lm, 171 | max_metric_calls=5, 172 | instruction_proposer=instruction_proposal.MultiModalInstructionProposer(), 173 | component_selector=all_component_selector, 174 | num_threads=16, 175 | ) 176 | trainset = [ 177 | Example( 178 | clock_photo=dspy.Image( 179 | "https://upload.wikimedia.org/wikipedia/commons/thumb/c/cf/Pendulum_clock_by_Jacob_Kock%2C_antique_furniture_photography%2C_IMG_0931_edit.jpg/500px-Pendulum_clock_by_Jacob_Kock%2C_antique_furniture_photography%2C_IMG_0931_edit.jpg", 180 | download=False, 181 | ), 182 | hour=8, 183 | minute=18, 184 | ).with_inputs("clock_photo"), 185 | Example( 186 | clock_photo=dspy.Image( 187 | "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a5/Telechron_clock_2H07-Br_Administrator.JPG/960px-Telechron_clock_2H07-Br_Administrator.JPG", 188 | download=False, 189 | ), 190 | hour=4, 191 | minute=16, 192 | 
).with_inputs("clock_photo"), 193 | ] 194 | o = optimizer.compile(student, trainset=trainset, valset=trainset) 195 | 196 | assert o.hour_predictor.predict.signature.instructions == "Task\n- Input: clock_photo (an image of an analog clock)\n- Output: hour (an integer 1\u201312). Output only the hour number with no extra text.\n\nGoal\n- Determine the correct hour by accurately identifying the hour hand and its position relative to the hour marks, taking into account the minute hand\u2019s position (since the hour hand moves continuously between numbers).\n\nStep-by-step procedure\n1) Find the dial and pivot\n- Locate the clock face and the central pivot where all hands originate.\n- Ignore decorative elements that do not originate at the central pivot (e.g., ornaments, shadows, reflections).\n\n2) Determine the 12 o\u2019clock direction\n- Prefer the numeral \u201c12\u201d if visible. Otherwise use the upright orientation of numerals or the topmost marker.\n- If the photo is rotated, mentally rotate so numerals read upright: 12 at top, 3 right, 6 bottom, 9 left.\n\n3) Identify the hands correctly (do not assume a default \u201c10:10\u201d)\n- Second hand: thinnest, often with a counterweight, may span very long; ignore for the hour.\n- Minute hand: longest, usually reaches or nearly reaches the outer minute tick marks.\n- Hour hand: shortest, usually thicker, typically ends well inside the numerals.\n- If ambiguous, classify by tip distance from center: minute \u2265 hour. Use the piece actually anchored at the pivot, not its shadow.\n\n4) Measure positions (angles)\n- Measure each hand\u2019s angle clockwise from 12 o\u2019clock.\n- Minute angle \u03b8m \u2248 position of the minute hand; hour angle \u03b8h \u2248 position of the hour hand.\n\n5) Use minute-hand position to validate the hour-hand location\n- The hour hand advances 0.5\u00b0 per minute (i.e., 1/12 of the distance between hour marks every 5 minutes).\n- Sanity check examples:\n - ~15 minutes past: hour hand \u2248 1/4 of the way from the current hour toward the next.\n - ~30 minutes: \u2248 halfway.\n - ~45 minutes: \u2248 3/4 of the way.\n- If this relationship doesn\u2019t hold, you likely swapped hour and minute hands\u2014re-identify them.\n\n6) Determine the hour\n- Compute the \u201clast passed\u201d hour: H = floor((\u03b8h mod 360) / 30). Map 0 to 12 (i.e., if floor(...) = 0, H = 12).\n- Do not round up to the next hour. 
The correct hour is the number the hour hand has most recently passed, not the one it is approaching.\n- If the hour hand appears exactly on an hour mark but the minute hand is not at 12, treat it as still between hours and choose the lower (last passed) hour.\n\n7) Edge cases and robustness\n- Stylized or missing numerals: rely on the 12/3/6/9 axes and tick marks rather than numeral shapes.\n- Roman numerals: \u201c4\u201d may be IIII; positions are unchanged.\n- Ignore mirrored effects, reflections, and shadows; only consider hands anchored at the pivot.\n- Overlap times: if hands nearly overlap, use \u03b8m to ensure the hour hand offset matches 0.5\u00b0 per minute.\n- Return 12, not 0, when appropriate (e.g., just after 12:00).\n\nOutput format\n- Provide only: hour as an integer in [1,12], with no additional text.\n\nCommon error prevention (from prior mistakes)\n- Do not confuse the minute hand for the hour hand; verify by length and reach to the outer tick marks.\n- Do not infer times like \u201c10:10\u201d by default; always read from the actual hand angles.\n- Ensure the hour chosen matches the \u201clast passed\u201d number given the minute hand\u2019s position (e.g., at ~:16, the hour hand must be just past the hour, not near 1 when the minute hand is at 3)." 197 | assert o.minute_predictor.predict.signature.instructions == "Task: From the image field clock_photo (an analog clock), output the minute value as an integer from 0\u201359 in the field minute. Output only the minute number\u2014no text or other fields.\n\nWhat to analyze\n- Clock face orientation: Identify where \u201c12\u201d is on the dial. Use the numerals (Arabic or Roman, stylized fonts) or the positions of 3, 6, 9, 12 to set the reference. If the photo is tilted, measure angles relative to the clock face, not the image frame.\n- Hands identification (do not confuse them):\n - Minute hand: typically the longest solid hand reaching near the minute ticks/outer ring; thicker than the second hand; often has a pronounced pointer tip.\n - Hour hand: shorter and thicker, typically ends near the numerals.\n - Second hand (if present): the thinnest, often the longest, usually with a counterweight; ignore it for minute reading.\n - If two non-second hands look similar, the one whose tip reaches closer to the minute tick ring is the minute hand.\n- Ticks and numerals: Each numeral-to-numeral segment equals 5 minutes. If minute tick marks exist, use them. If not, divide each numeral interval evenly into five.\n\nHow to compute the minute\n1. Locate the clock center and the minute hand\u2019s tip.\n2. Determine the angle of the minute hand from the 12 o\u2019clock direction, increasing clockwise.\n3. Convert angle to minutes: minute_estimate = (angle_from_12 / 6). Round to the nearest whole minute.\n - Mapping: 12 \u2192 0, 1 \u2192 5, 2 \u2192 10, 3 \u2192 15, 4 \u2192 20, 5 \u2192 25, 6 \u2192 30, 7 \u2192 35, 8 \u2192 40, 9 \u2192 45, 10 \u2192 50, 11 \u2192 55.\n - If the tip is slightly past a numeral (e.g., just past 3), do not snap to the numeral; round to the nearest minute (e.g., 16 instead of 15).\n4. 
Consistency check with the hour hand (useful to avoid off-by-one and hand mix-ups):\n - The hour hand moves continuously: it advances 0.5 degrees per minute (i.e., 1/12 of the way to the next numeral every 5 minutes).\n - If your minute_estimate is an exact multiple of 5 but the hour hand is clearly between hour markers (not aligned with an hour), re-examine: the minute hand is likely slightly past the numeral; adjust to the nearest minute accordingly.\n - If the minute hand choice is ambiguous, infer the minute from the hour hand\u2019s fraction toward the next hour: minute \u2248 fraction_between_hour_markers \u00d7 60, then choose the hand assignment that matches this.\n5. Edge cases:\n - Overlapping hands: Look at which tip extends farther toward the tick ring to identify the minute hand.\n - Strong perspective or glare: Use the line from center to the visible tip; ignore reflections.\n - No minute ticks: Evenly interpolate between numerals.\n - Subdials or decorative elements (e.g., pendulum windows) are not the minute indicator; use the main dial only.\n\nOutput format\n- Return only the integer minute value (0\u201359) in the minute field.\n- If the angle computes to 60, output 0.\n\nError prevention reminders\n- Do not treat the hour hand as the minute hand.\n- Do not use the second hand to compute minutes.\n- Do not assume the minute hand is exactly on a numeral\u2014check for slight offsets and round to the nearest minute.\n- Ensure the final minute agrees with the hour hand\u2019s position trend (hour hand slightly past an hour implies minutes > 0)." 198 | 199 | 200 | def test_metric_requires_feedback_signature(): 201 | reflection_lm = DictDummyLM([]) 202 | with pytest.raises(TypeError): 203 | dspy.GEPA(metric=bad_metric, reflection_lm=reflection_lm, max_metric_calls=1) 204 | 205 | 206 | def any_metric( 207 | gold: dspy.Example, 208 | pred: dspy.Prediction, 209 | trace: Any = None, 210 | pred_name: str | None = None, 211 | pred_trace: Any = None, 212 | ) -> float: 213 | """ 214 | For this test, we only care that the program runs, not the score. 215 | """ 216 | return 0.0 # ← Just returns 0.0, doesn't access any attributes! 217 | 218 | 219 | def test_gepa_compile_with_track_usage_no_tuple_error(caplog): 220 | """ 221 | GEPA.compile should not log tuple-usage error when track_usage=True and complete without hanging. 222 | Before, compile would hang and/or log "'tuple' object has no attribute 'set_lm_usage'" repeatedly. 223 | """ 224 | student = dspy.Predict("question -> answer") 225 | trainset = [dspy.Example(question="What is 2+2?", answer="4").with_inputs("question")] 226 | 227 | task_lm = DummyLM([{"answer": "mock answer 1"}]) 228 | reflection_lm = DummyLM([{"new_instruction": "Something new."}]) 229 | 230 | compiled_container: dict[str, Any] = {} 231 | exc_container: dict[str, BaseException] = {} 232 | 233 | def run_compile(): 234 | try: 235 | with dspy.context(lm=task_lm, track_usage=True): 236 | optimizer = dspy.GEPA(metric=any_metric, reflection_lm=reflection_lm, max_metric_calls=3) 237 | compiled_container["prog"] = optimizer.compile(student, trainset=trainset, valset=trainset) 238 | except BaseException as e: 239 | exc_container["e"] = e 240 | 241 | t = threading.Thread(target=run_compile, daemon=True) 242 | t.start() 243 | t.join(timeout=1.0) 244 | 245 | # Assert compile did not hang (pre-fix behavior would time out here) 246 | assert not t.is_alive(), "GEPA.compile did not complete within timeout (likely pre-fix behavior)." 
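    # daemon=True plus the 1-second join timeout keeps the check bounded even if
    # compile() regresses and hangs again.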
247 | 
248 |     # Assert no tuple-usage error is logged anymore
249 |     assert "'tuple' object has no attribute 'set_lm_usage'" not in caplog.text
250 | 
251 |     # If any exception occurred, fail explicitly
252 |     if "e" in exc_container:
253 |         pytest.fail(f"GEPA.compile raised unexpectedly: {exc_container['e']}")
254 | 
255 |     # No timeout, no exception -> so the program must exist
256 |     if "prog" not in compiled_container:
257 |         pytest.fail("GEPA.compile did not return a program (likely pre-fix behavior).")
258 | 
259 | 
260 | class MultiComponentModule(dspy.Module):
261 |     """Test module with multiple predictors."""
262 | 
263 |     def __init__(self):
264 |         super().__init__()
265 |         self.classifier = Predict("input -> category")
266 |         self.generator = Predict("category, input -> output")
267 | 
268 |     def forward(self, input):
269 |         category = self.classifier(input=input).category
270 |         output = self.generator(category=category, input=input).output
271 |         return dspy.Prediction(category=category, output=output)
272 | 
273 | 
274 | def component_selection_metric(example, prediction, trace=None, pred_name=None, pred_trace=None):
275 |     """Simple metric for component selection testing."""
276 |     return dspy.Prediction(score=0.3, feedback="Test feedback")
277 | 
278 | 
279 | def test_component_selector_functionality():
280 |     """Test that a custom component selector function can select single or multiple components."""
281 | 
282 |     # Track calls for verification
283 |     selector_calls = []
284 | 
285 |     def test_selector(state, trajectories, subsample_scores, candidate_idx, candidate):
286 |         selector_calls.append({"components": list(candidate.keys()), "candidate_idx": candidate_idx})
287 |         # Test both single and multiple selection
288 |         return ["classifier"] if candidate_idx == 0 else ["classifier", "generator"]
289 | 
290 |     student = MultiComponentModule()
291 | 
292 |     # Provide enough responses for all possible LM calls during optimization
293 |     task_lm = DummyLM([{"category": "test_category", "output": "test_output"}] * 20)
294 |     reflection_lm = DummyLM(
295 |         [
296 |             {"improved_instruction": "Improved classifier instruction"},
297 |             {"improved_instruction": "Improved generator instruction"},
298 |         ]
299 |         * 10
300 |     )
301 |     trainset = [dspy.Example(input="test", output="expected").with_inputs("input")]
302 | 
303 |     with dspy.context(lm=task_lm):
304 |         optimizer = dspy.GEPA(
305 |             metric=component_selection_metric,
306 |             reflection_lm=reflection_lm,
307 |             max_metric_calls=6,  # Reduced to minimize output
308 |             component_selector=test_selector,
309 |         )
310 |         result = optimizer.compile(student, trainset=trainset, valset=trainset)
311 | 
312 |     # Verify selector was called with correct parameters
313 |     assert len(selector_calls) > 0, "Custom selector should be invoked"
314 |     assert "classifier" in selector_calls[0]["components"], "Should receive all available components"
315 |     assert "generator" in selector_calls[0]["components"], "Should receive all available components"
316 |     assert result is not None, "Should return optimized program"
317 | 
318 | 
319 | def test_component_selector_default_behavior():
320 |     """Test default behavior when no custom selector is provided."""
321 |     student = MultiComponentModule()
322 | 
323 |     # Provide enough responses for all possible LM calls
324 |     task_lm = DummyLM([{"category": "test_category", "output": "test_output"}] * 15)
325 |     reflection_lm = DummyLM([{"improved_instruction": "Better instruction"}] * 8)
326 |     trainset = [dspy.Example(input="test", output="expected").with_inputs("input")]
327 | 
328 |     with 
dspy.context(lm=task_lm): 329 | # No component_selector - should use round-robin default 330 | optimizer = dspy.GEPA( 331 | metric=component_selection_metric, 332 | reflection_lm=reflection_lm, 333 | max_metric_calls=4, # Minimal calls to reduce noise 334 | ) 335 | result = optimizer.compile(student, trainset=trainset, valset=trainset) 336 | 337 | assert result is not None, "Should work with default selector" 338 | 339 | 340 | def test_component_selector_string_round_robin(): 341 | """Test string-based round_robin selector.""" 342 | student = MultiComponentModule() 343 | 344 | # Provide enough responses for all possible LM calls 345 | task_lm = DummyLM([{"category": "test_category", "output": "test_output"}] * 15) 346 | reflection_lm = DummyLM([{"improved_instruction": "Better instruction"}] * 8) 347 | trainset = [dspy.Example(input="test", output="expected").with_inputs("input")] 348 | 349 | with dspy.context(lm=task_lm): 350 | optimizer = dspy.GEPA( 351 | metric=component_selection_metric, 352 | reflection_lm=reflection_lm, 353 | max_metric_calls=4, 354 | component_selector="round_robin", # String-based selector 355 | ) 356 | result = optimizer.compile(student, trainset=trainset, valset=trainset) 357 | 358 | assert result is not None, "Should work with 'round_robin' string selector" 359 | 360 | 361 | def test_component_selector_string_all(): 362 | """Test string-based 'all' selector and verify it actually updates all components.""" 363 | student = MultiComponentModule() 364 | 365 | # Store original instructions to verify they get updated 366 | original_classifier_instruction = student.classifier.signature.instructions 367 | original_generator_instruction = student.generator.signature.instructions 368 | 369 | def optimize(component_selector): 370 | # Metric that progressively improves to encourage GEPA to accept new candidates 371 | call_count = 0 372 | 373 | def improving_metric(example, prediction, trace=None, pred_name=None, pred_trace=None): 374 | nonlocal call_count 375 | call_count += 1 376 | # Score improves with each call to encourage acceptance of new candidates 377 | score = min(0.3 + (call_count * 0.1), 1.0) 378 | return dspy.Prediction(score=score, feedback="Improving feedback") 379 | 380 | # Provide enough responses for all possible LM calls 381 | task_lm = DummyLM([{"category": "test_category", "output": "test_output"}] * 20) 382 | reflection_lm = DummyLM( 383 | [ 384 | {"improved_instruction": "Updated classifier instruction"}, 385 | {"improved_instruction": "Updated generator instruction"}, 386 | ] 387 | * 10 388 | ) 389 | trainset = [dspy.Example(input="test", output="expected").with_inputs("input")] 390 | 391 | with dspy.context(lm=task_lm): 392 | optimizer = dspy.GEPA( 393 | metric=improving_metric, 394 | reflection_lm=reflection_lm, 395 | max_metric_calls=8, 396 | component_selector=component_selector, 397 | track_stats=True, # Track intermediate results to verify updates 398 | ) 399 | return optimizer.compile(student, trainset=trainset, valset=trainset) 400 | 401 | result_round_robin = optimize(component_selector="round_robin") 402 | 403 | candidates_round_robin = result_round_robin.detailed_results.candidates 404 | 405 | assert ( 406 | candidates_round_robin[1].classifier.signature.instructions == original_classifier_instruction 407 | and candidates_round_robin[1].generator.signature.instructions != original_generator_instruction 408 | ) or ( 409 | candidates_round_robin[1].classifier.signature.instructions != original_classifier_instruction 410 | and 
candidates_round_robin[1].generator.signature.instructions == original_generator_instruction 411 | ), "First candidate should have only one component updated, when using round_robin selector" 412 | 413 | result_all = optimize(component_selector="all") 414 | 415 | candidates_all = result_all.detailed_results.candidates 416 | 417 | assert ( 418 | candidates_all[1].classifier.signature.instructions != original_classifier_instruction 419 | and candidates_all[1].generator.signature.instructions != original_generator_instruction 420 | ), "First candidate should have both components updated, when using all selector" 421 | 422 | 423 | def test_component_selector_custom_random(): 424 | """Test custom component selector function that randomly samples components.""" 425 | import random 426 | 427 | # Simple function-based selector 428 | def random_component_selector(state, trajectories, subsample_scores, candidate_idx, candidate): 429 | """Randomly select half of the available components.""" 430 | component_names = list(candidate.keys()) 431 | num_to_select = max(1, len(component_names) // 2) # At least 1, half of total 432 | return random.sample(component_names, num_to_select) 433 | 434 | student = MultiComponentModule() 435 | 436 | # Provide enough responses for all possible LM calls 437 | task_lm = DummyLM([{"category": "test_category", "output": "test_output"}] * 15) 438 | reflection_lm = DummyLM([{"improved_instruction": "Better instruction"}] * 8) 439 | trainset = [dspy.Example(input="test", output="expected").with_inputs("input")] 440 | 441 | with dspy.context(lm=task_lm): 442 | optimizer = dspy.GEPA( 443 | metric=component_selection_metric, 444 | reflection_lm=reflection_lm, 445 | max_metric_calls=4, 446 | component_selector=random_component_selector, # Function-based selector 447 | ) 448 | result = optimizer.compile(student, trainset=trainset, valset=trainset) 449 | 450 | assert result is not None, "Should work with custom random function selector" 451 | 452 | 453 | def test_alternating_half_component_selector(): 454 | """Test alternating half selector that optimizes different halves on even/odd iterations.""" 455 | 456 | selection_history = [] 457 | 458 | def alternating_half_selector(state, trajectories, subsample_scores, candidate_idx, candidate): 459 | """Optimize half the components on even iterations, half on odd iterations.""" 460 | components = list(candidate.keys()) 461 | 462 | # If there's only one component, always optimize it 463 | if len(components) <= 1: 464 | selected = components 465 | else: 466 | mid_point = len(components) // 2 467 | 468 | # Use state.i (iteration counter) to alternate between halves 469 | if state.i % 2 == 0: 470 | # Even iteration: optimize first half 471 | selected = components[:mid_point] 472 | else: 473 | # Odd iteration: optimize second half 474 | selected = components[mid_point:] 475 | 476 | # Track selections for verification 477 | selection_history.append({ 478 | "iteration": state.i, 479 | "selected": selected.copy(), 480 | "all_components": components.copy() 481 | }) 482 | 483 | return selected 484 | 485 | student = MultiComponentModule() # Has "classifier" and "generator" components 486 | 487 | # Provide enough responses for multiple iterations 488 | task_lm = DummyLM([{"category": "test_category", "output": "test_output"}] * 20) 489 | reflection_lm = DummyLM([{"improved_instruction": "Better instruction"}] * 10) 490 | trainset = [dspy.Example(input="test", output="expected").with_inputs("input")] 491 | 492 | with dspy.context(lm=task_lm): 493 | 
optimizer = dspy.GEPA( 494 | metric=component_selection_metric, 495 | reflection_lm=reflection_lm, 496 | max_metric_calls=8, # Allow multiple iterations 497 | component_selector=alternating_half_selector, 498 | ) 499 | result = optimizer.compile(student, trainset=trainset, valset=trainset) 500 | 501 | assert result is not None, "Should work with alternating half selector" 502 | assert len(selection_history) >= 2, "Should have made multiple selections" 503 | 504 | for i, selection in enumerate(selection_history): 505 | if selection["iteration"] % 2 == 0: 506 | # Even iteration should select first half: ["classifier"] 507 | assert "classifier" in selection["selected"], f"Even iteration {selection['iteration']} should include classifier" 508 | assert "generator" not in selection["selected"], f"Even iteration {selection['iteration']} should not include generator" 509 | else: 510 | # Odd iteration should select second half: ["generator"] 511 | assert "generator" in selection["selected"], f"Odd iteration {selection['iteration']} should include generator" 512 | assert "classifier" not in selection["selected"], f"Odd iteration {selection['iteration']} should not include classifier" 513 | ```
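
For reference, the conventions these tests exercise — a metric that returns `dspy.Prediction(score=..., feedback=...)`, a separate `reflection_lm`, and a `component_selector` that is either a string (`"round_robin"`, `"all"`) or a function over `(state, trajectories, subsample_scores, candidate_idx, candidate)` — combine into the following minimal sketch. The model names and the `max_metric_calls` budget are illustrative placeholders, not values taken from the test suite.

```python
import dspy


def exact_match_metric(example, prediction, trace=None, pred_name=None, pred_trace=None):
    # GEPA metrics may return a plain score, but the Prediction form also carries
    # textual feedback for the reflection LM.
    score = float(example.output == prediction.output)
    return dspy.Prediction(score=score, feedback="Exact-match check against the gold output.")


def first_component_selector(state, trajectories, subsample_scores, candidate_idx, candidate):
    # `candidate` is keyed by component (predictor) name; return the subset of
    # components to rewrite in this round.
    return [next(iter(candidate))]


# Illustrative model choices; substitute whichever LMs you actually use.
task_lm = dspy.LM("openai/gpt-4o-mini")
reflection_lm = dspy.LM("openai/gpt-4o")

student = dspy.Predict("input -> output")
trainset = [dspy.Example(input="What is the color of the sky?", output="blue").with_inputs("input")]

with dspy.context(lm=task_lm):
    optimizer = dspy.GEPA(
        metric=exact_match_metric,
        reflection_lm=reflection_lm,
        max_metric_calls=20,
        component_selector=first_component_selector,  # or "round_robin" / "all"
    )
    optimized = optimizer.compile(student, trainset=trainset, valset=trainset)
```

As `test_gepa_compile_with_track_usage_no_tuple_error` shows, a bare numeric metric also compiles; the `Prediction` form is what supplies feedback text to the reflection step.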