This is page 11 of 14. Use http://codebase.md/stanfordnlp/dspy?page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── .internal_dspyai │ │ ├── internals │ │ │ ├── build-and-release.md │ │ │ └── release-checklist.md │ │ └── pyproject.toml │ ├── .tmp │ │ └── .generated-actions │ │ └── run-pypi-publish-in-docker-container │ │ └── action.yml │ ├── ISSUE_TEMPLATE │ │ ├── bug_report.yml │ │ └── feature_request.yml │ ├── PULL_REQUEST_TEMPLATE │ │ └── pull_request_template.md │ ├── workflow_scripts │ │ └── install_testpypi_pkg.sh │ └── workflows │ ├── build_and_release.yml │ ├── build_utils │ │ └── test_version.py │ ├── docs-push.yml │ ├── precommits_check.yml │ └── run_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── docs │ ├── .gitignore │ ├── docs │ │ ├── api │ │ │ ├── adapters │ │ │ │ ├── Adapter.md │ │ │ │ ├── ChatAdapter.md │ │ │ │ ├── JSONAdapter.md │ │ │ │ └── TwoStepAdapter.md │ │ │ ├── evaluation │ │ │ │ ├── answer_exact_match.md │ │ │ │ ├── answer_passage_match.md │ │ │ │ ├── CompleteAndGrounded.md │ │ │ │ ├── Evaluate.md │ │ │ │ ├── EvaluationResult.md │ │ │ │ └── SemanticF1.md │ │ │ ├── experimental │ │ │ │ ├── Citations.md │ │ │ │ └── Document.md │ │ │ ├── index.md │ │ │ ├── models │ │ │ │ ├── Embedder.md │ │ │ │ └── LM.md │ │ │ ├── modules │ │ │ │ ├── BestOfN.md │ │ │ │ ├── ChainOfThought.md │ │ │ │ ├── CodeAct.md │ │ │ │ ├── Module.md │ │ │ │ ├── MultiChainComparison.md │ │ │ │ ├── Parallel.md │ │ │ │ ├── Predict.md │ │ │ │ ├── ProgramOfThought.md │ │ │ │ ├── ReAct.md │ │ │ │ └── Refine.md │ │ │ ├── optimizers │ │ │ │ ├── BetterTogether.md │ │ │ │ ├── BootstrapFewShot.md │ │ │ │ ├── BootstrapFewShotWithRandomSearch.md │ │ │ │ ├── BootstrapFinetune.md │ │ │ │ ├── BootstrapRS.md │ │ │ │ ├── COPRO.md │ │ │ │ ├── Ensemble.md │ │ │ │ ├── GEPA │ │ │ │ │ ├── GEPA_Advanced.md │ │ │ │ │ └── overview.md │ │ │ │ ├── InferRules.md │ │ │ │ ├── KNN.md │ │ │ │ ├── KNNFewShot.md │ │ │ │ ├── LabeledFewShot.md │ │ │ │ ├── MIPROv2.md │ │ │ │ └── SIMBA.md │ │ │ ├── primitives │ │ │ │ ├── Audio.md │ │ │ │ ├── Code.md │ │ │ │ ├── Example.md │ │ │ │ ├── History.md │ │ │ │ ├── Image.md │ │ │ │ ├── Prediction.md │ │ │ │ ├── Tool.md │ │ │ │ └── ToolCalls.md │ │ │ ├── signatures │ │ │ │ ├── InputField.md │ │ │ │ ├── OutputField.md │ │ │ │ └── Signature.md │ │ │ ├── tools │ │ │ │ ├── ColBERTv2.md │ │ │ │ ├── Embeddings.md │ │ │ │ └── PythonInterpreter.md │ │ │ └── utils │ │ │ ├── asyncify.md │ │ │ ├── configure_cache.md │ │ │ ├── disable_litellm_logging.md │ │ │ ├── disable_logging.md │ │ │ ├── enable_litellm_logging.md │ │ │ ├── enable_logging.md │ │ │ ├── inspect_history.md │ │ │ ├── load.md │ │ │ ├── StatusMessage.md │ │ │ ├── StatusMessageProvider.md │ │ │ ├── streamify.md │ │ │ └── StreamListener.md │ │ ├── cheatsheet.md │ │ ├── community │ │ │ ├── community-resources.md │ │ │ ├── how-to-contribute.md │ │ │ └── use-cases.md │ │ ├── deep-dive │ │ │ └── data-handling │ │ │ ├── built-in-datasets.md │ │ │ ├── examples.md │ │ │ ├── img │ │ │ │ └── data-loading.png │ │ │ └── loading-custom-data.md │ │ ├── faqs.md │ │ ├── index.md │ │ ├── js │ │ │ └── runllm-widget.js │ │ ├── learn │ │ │ ├── evaluation │ │ │ │ ├── data.md │ │ │ │ ├── metrics.md │ │ │ │ └── overview.md │ │ │ ├── figures │ │ │ │ ├── native_tool_call.png │ │ │ │ └── teleprompter-classes.png │ │ │ ├── index.md │ │ │ ├── optimization │ │ │ │ ├── optimizers.md │ │ │ │ └── overview.md │ │ │ └── programming │ │ │ ├── 7-assertions.md │ │ │ ├── adapters.md │ │ │ ├── language_models.md │ │ │ ├── mcp.md │ 
│ │ ├── modules.md │ │ │ ├── overview.md │ │ │ ├── signatures.md │ │ │ └── tools.md │ │ ├── production │ │ │ └── index.md │ │ ├── roadmap.md │ │ ├── static │ │ │ ├── .nojekyll │ │ │ └── img │ │ │ ├── dspy_logo.png │ │ │ ├── logo.png │ │ │ ├── mlflow-tracing-rag.png │ │ │ ├── modular.png │ │ │ ├── optimize.png │ │ │ ├── undraw_docusaurus_mountain.svg │ │ │ ├── undraw_docusaurus_react.svg │ │ │ ├── undraw_docusaurus_tree.svg │ │ │ └── universal_compatibility.png │ │ ├── stylesheets │ │ │ └── extra.css │ │ └── tutorials │ │ ├── agents │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── ai_text_game │ │ │ └── index.md │ │ ├── async │ │ │ └── index.md │ │ ├── audio │ │ │ └── index.ipynb │ │ ├── build_ai_program │ │ │ └── index.md │ │ ├── cache │ │ │ └── index.md │ │ ├── classification │ │ │ └── index.md │ │ ├── classification_finetuning │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-classification.png │ │ ├── conversation_history │ │ │ └── index.md │ │ ├── core_development │ │ │ └── index.md │ │ ├── custom_module │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-custom-module.png │ │ ├── customer_service_agent │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-customer-service-agent.png │ │ ├── deployment │ │ │ ├── dspy_mlflow_ui.png │ │ │ └── index.md │ │ ├── email_extraction │ │ │ ├── index.md │ │ │ └── mlflow-tracing-email-extraction.png │ │ ├── entity_extraction │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-entity-extraction.png │ │ ├── games │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-agent.png │ │ ├── gepa_ai_program │ │ │ └── index.md │ │ ├── gepa_aime │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-aime.png │ │ │ └── mlflow-tracking-gepa-aime-optimization.png │ │ ├── gepa_facilitysupportanalyzer │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-support.png │ │ │ └── mlflow-tracking-gepa-support-optimization.png │ │ ├── gepa_papillon │ │ │ ├── index.ipynb │ │ │ ├── mlflow-tracing-gepa-papilon.png │ │ │ └── mlflow-tracking-gepa-papilon-optimization.png │ │ ├── image_generation_prompting │ │ │ └── index.ipynb │ │ ├── index.md │ │ ├── llms_txt_generation │ │ │ └── index.md │ │ ├── math │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-math.png │ │ ├── mcp │ │ │ └── index.md │ │ ├── mem0_react_agent │ │ │ └── index.md │ │ ├── multihop_search │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-multi-hop.png │ │ ├── observability │ │ │ ├── index.md │ │ │ ├── mlflow_trace_ui_navigation.gif │ │ │ ├── mlflow_trace_ui.png │ │ │ └── mlflow_trace_view.png │ │ ├── optimize_ai_program │ │ │ └── index.md │ │ ├── optimizer_tracking │ │ │ ├── child_run.png │ │ │ ├── experiment.png │ │ │ ├── index.md │ │ │ └── parent_run.png │ │ ├── output_refinement │ │ │ └── best-of-n-and-refine.md │ │ ├── papillon │ │ │ └── index.md │ │ ├── program_of_thought │ │ │ └── index.ipynb │ │ ├── rag │ │ │ ├── index.ipynb │ │ │ └── mlflow-tracing-rag.png │ │ ├── real_world_examples │ │ │ └── index.md │ │ ├── rl_ai_program │ │ │ └── index.md │ │ ├── rl_multihop │ │ │ └── index.ipynb │ │ ├── rl_papillon │ │ │ └── index.ipynb │ │ ├── sample_code_generation │ │ │ └── index.md │ │ ├── saving │ │ │ └── index.md │ │ ├── streaming │ │ │ └── index.md │ │ ├── tool_use │ │ │ └── index.ipynb │ │ └── yahoo_finance_react │ │ └── index.md │ ├── mkdocs.yml │ ├── overrides │ │ ├── home.html │ │ ├── main.html │ │ └── partials │ │ └── tabs.html │ ├── Pipfile │ ├── Pipfile.lock │ ├── README.md │ ├── requirements.txt │ ├── scripts │ │ ├── generate_api_docs.py │ │ └── generate_api_summary.py │ └── vercel.json ├── dspy │ ├── __init__.py │ 
├── __metadata__.py │ ├── adapters │ │ ├── __init__.py │ │ ├── baml_adapter.py │ │ ├── base.py │ │ ├── chat_adapter.py │ │ ├── json_adapter.py │ │ ├── two_step_adapter.py │ │ ├── types │ │ │ ├── __init__.py │ │ │ ├── audio.py │ │ │ ├── base_type.py │ │ │ ├── citation.py │ │ │ ├── code.py │ │ │ ├── document.py │ │ │ ├── history.py │ │ │ ├── image.py │ │ │ └── tool.py │ │ ├── utils.py │ │ └── xml_adapter.py │ ├── clients │ │ ├── __init__.py │ │ ├── base_lm.py │ │ ├── cache.py │ │ ├── databricks.py │ │ ├── embedding.py │ │ ├── lm_local_arbor.py │ │ ├── lm_local.py │ │ ├── lm.py │ │ ├── openai.py │ │ ├── provider.py │ │ └── utils_finetune.py │ ├── datasets │ │ ├── __init__.py │ │ ├── alfworld │ │ │ ├── __init__.py │ │ │ ├── alfworld.py │ │ │ └── base_config.yml │ │ ├── colors.py │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── gsm8k.py │ │ ├── hotpotqa.py │ │ └── math.py │ ├── dsp │ │ ├── __init__.py │ │ ├── colbertv2.py │ │ └── utils │ │ ├── __init__.py │ │ ├── dpr.py │ │ ├── settings.py │ │ └── utils.py │ ├── evaluate │ │ ├── __init__.py │ │ ├── auto_evaluation.py │ │ ├── evaluate.py │ │ └── metrics.py │ ├── experimental │ │ └── __init__.py │ ├── predict │ │ ├── __init__.py │ │ ├── aggregation.py │ │ ├── avatar │ │ │ ├── __init__.py │ │ │ ├── avatar.py │ │ │ ├── models.py │ │ │ └── signatures.py │ │ ├── best_of_n.py │ │ ├── chain_of_thought.py │ │ ├── code_act.py │ │ ├── knn.py │ │ ├── multi_chain_comparison.py │ │ ├── parallel.py │ │ ├── parameter.py │ │ ├── predict.py │ │ ├── program_of_thought.py │ │ ├── react.py │ │ ├── refine.py │ │ └── retry.py │ ├── primitives │ │ ├── __init__.py │ │ ├── base_module.py │ │ ├── example.py │ │ ├── module.py │ │ ├── prediction.py │ │ ├── python_interpreter.py │ │ └── runner.js │ ├── propose │ │ ├── __init__.py │ │ ├── dataset_summary_generator.py │ │ ├── grounded_proposer.py │ │ ├── propose_base.py │ │ └── utils.py │ ├── retrievers │ │ ├── __init__.py │ │ ├── databricks_rm.py │ │ ├── embeddings.py │ │ ├── retrieve.py │ │ └── weaviate_rm.py │ ├── signatures │ │ ├── __init__.py │ │ ├── field.py │ │ ├── signature.py │ │ └── utils.py │ ├── streaming │ │ ├── __init__.py │ │ ├── messages.py │ │ ├── streamify.py │ │ └── streaming_listener.py │ ├── teleprompt │ │ ├── __init__.py │ │ ├── avatar_optimizer.py │ │ ├── bettertogether.py │ │ ├── bootstrap_finetune.py │ │ ├── bootstrap_trace.py │ │ ├── bootstrap.py │ │ ├── copro_optimizer.py │ │ ├── ensemble.py │ │ ├── gepa │ │ │ ├── __init__.py │ │ │ ├── gepa_utils.py │ │ │ ├── gepa.py │ │ │ └── instruction_proposal.py │ │ ├── grpo.py │ │ ├── infer_rules.py │ │ ├── knn_fewshot.py │ │ ├── mipro_optimizer_v2.py │ │ ├── random_search.py │ │ ├── signature_opt.py │ │ ├── simba_utils.py │ │ ├── simba.py │ │ ├── teleprompt_optuna.py │ │ ├── teleprompt.py │ │ ├── utils.py │ │ └── vanilla.py │ └── utils │ ├── __init__.py │ ├── annotation.py │ ├── asyncify.py │ ├── caching.py │ ├── callback.py │ ├── dummies.py │ ├── exceptions.py │ ├── hasher.py │ ├── inspect_history.py │ ├── langchain_tool.py │ ├── logging_utils.py │ ├── mcp.py │ ├── parallelizer.py │ ├── saving.py │ ├── syncify.py │ ├── unbatchify.py │ └── usage_tracker.py ├── LICENSE ├── pyproject.toml ├── README.md ├── tests │ ├── __init__.py │ ├── adapters │ │ ├── test_adapter_utils.py │ │ ├── test_baml_adapter.py │ │ ├── test_base_type.py │ │ ├── test_chat_adapter.py │ │ ├── test_citation.py │ │ ├── test_code.py │ │ ├── test_document.py │ │ ├── test_json_adapter.py │ │ ├── test_tool.py │ │ ├── test_two_step_adapter.py │ │ └── test_xml_adapter.py │ ├── callback │ │ └── 
test_callback.py │ ├── clients │ │ ├── test_cache.py │ │ ├── test_databricks.py │ │ ├── test_embedding.py │ │ ├── test_inspect_global_history.py │ │ └── test_lm.py │ ├── conftest.py │ ├── datasets │ │ └── test_dataset.py │ ├── docs │ │ └── test_mkdocs_links.py │ ├── evaluate │ │ ├── test_evaluate.py │ │ └── test_metrics.py │ ├── examples │ │ └── test_baleen.py │ ├── metadata │ │ └── test_metadata.py │ ├── predict │ │ ├── test_aggregation.py │ │ ├── test_best_of_n.py │ │ ├── test_chain_of_thought.py │ │ ├── test_code_act.py │ │ ├── test_knn.py │ │ ├── test_multi_chain_comparison.py │ │ ├── test_parallel.py │ │ ├── test_predict.py │ │ ├── test_program_of_thought.py │ │ ├── test_react.py │ │ ├── test_refine.py │ │ └── test_retry.py │ ├── primitives │ │ ├── resources │ │ │ └── saved_program.json │ │ ├── test_base_module.py │ │ ├── test_example.py │ │ ├── test_module.py │ │ └── test_python_interpreter.py │ ├── propose │ │ └── test_grounded_proposer.py │ ├── README.md │ ├── reliability │ │ ├── __init__.py │ │ ├── complex_types │ │ │ └── generated │ │ │ ├── test_many_types_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ ├── test_nesting_1 │ │ │ │ ├── inputs │ │ │ │ │ ├── input1.json │ │ │ │ │ └── input2.json │ │ │ │ ├── program.py │ │ │ │ └── schema.json │ │ │ └── test_nesting_2 │ │ │ ├── inputs │ │ │ │ └── input1.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── conftest.py │ │ ├── generate │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ └── utils.py │ │ ├── input_formats │ │ │ └── generated │ │ │ └── test_markdown_1 │ │ │ ├── inputs │ │ │ │ ├── input1.json │ │ │ │ └── input2.json │ │ │ ├── program.py │ │ │ └── schema.json │ │ ├── README.md │ │ ├── reliability_conf.yaml │ │ ├── test_generated.py │ │ ├── test_pydantic_models.py │ │ └── utils.py │ ├── retrievers │ │ └── test_embeddings.py │ ├── signatures │ │ ├── test_adapter_image.py │ │ ├── test_custom_types.py │ │ └── test_signature.py │ ├── streaming │ │ └── test_streaming.py │ ├── teleprompt │ │ ├── gepa_dummy_lm_custom_component_selector_custom_instruction_proposer.json │ │ ├── gepa_dummy_lm.json │ │ ├── test_bootstrap_finetune.py │ │ ├── test_bootstrap_trace.py │ │ ├── test_bootstrap.py │ │ ├── test_copro_optimizer.py │ │ ├── test_ensemble.py │ │ ├── test_finetune.py │ │ ├── test_gepa_instruction_proposer.py │ │ ├── test_gepa.py │ │ ├── test_grpo.py │ │ ├── test_knn_fewshot.py │ │ ├── test_random_search.py │ │ ├── test_teleprompt.py │ │ └── test_utils.py │ ├── test_utils │ │ ├── __init__.py │ │ └── server │ │ ├── __init__.py │ │ ├── litellm_server_config.yaml │ │ └── litellm_server.py │ └── utils │ ├── __init__.py │ ├── resources │ │ └── mcp_server.py │ ├── test_annotation.py │ ├── test_asyncify.py │ ├── test_exceptions.py │ ├── test_langchain_tool.py │ ├── test_mcp.py │ ├── test_parallelizer.py │ ├── test_saving.py │ ├── test_settings.py │ ├── test_syncify.py │ ├── test_unbatchify.py │ └── test_usage_tracker.py └── uv.lock ``` # Files -------------------------------------------------------------------------------- /dspy/signatures/signature.py: -------------------------------------------------------------------------------- ```python """Signature class for DSPy. You typically subclass the Signature class, like this: class MySignature(dspy.Signature): input: str = InputField(desc="...") output: int = OutputField(desc="...") You can call Signature("input1, input2 -> output1, output2") to create a new signature type. 
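For example, a minimal sketch of the string form (field names here are illustrative):

    MySig = Signature("question, context -> answer")
    MySig.input_fields    # {"question": ..., "context": ...}
    MySig.output_fields   # {"answer": ...}

Inline type annotations are also supported, e.g. Signature("question: str -> score: float").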
You can also include instructions, Signature("input -> output", "This is a test"). But it's generally better to use the make_signature function. If you are not sure if your input is a string representation, (like "input1, input2 -> output1, output2"), or a signature, you can use the ensure_signature function. For compatibility with the legacy dsp format, you can use the signature_to_template function. """ import ast import importlib import inspect import re import sys import types import typing from copy import deepcopy from typing import Any from pydantic import BaseModel, Field, create_model from pydantic.fields import FieldInfo from dspy.signatures.field import InputField, OutputField def _default_instructions(cls) -> str: inputs_ = ", ".join([f"`{field}`" for field in cls.input_fields]) outputs_ = ", ".join([f"`{field}`" for field in cls.output_fields]) return f"Given the fields {inputs_}, produce the fields {outputs_}." class SignatureMeta(type(BaseModel)): def __call__(cls, *args, **kwargs): if cls is Signature: # We don't create an actual Signature instance, instead, we create a new Signature class. custom_types = kwargs.pop("custom_types", None) if custom_types is None and args and isinstance(args[0], str): custom_types = cls._detect_custom_types_from_caller(args[0]) return make_signature(*args, custom_types=custom_types, **kwargs) return super().__call__(*args, **kwargs) @staticmethod def _detect_custom_types_from_caller(signature_str): """Detect custom types from the caller's frame based on the signature string. Note: This method relies on Python's frame introspection which has some limitations: 1. May not work in all Python implementations (e.g., compiled with optimizations) 2. Looks up a limited number of frames in the call stack 3. Cannot find types that are imported but not in the caller's namespace For more reliable custom type resolution, explicitly provide types using the `custom_types` parameter when creating a Signature. """ # Extract potential type names from the signature string, including dotted names # Match both simple types like 'MyType' and dotted names like 'Module.Type' type_pattern = r":\s*([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*)" type_names = re.findall(type_pattern, signature_str) if not type_names: return None # Get type references from caller frames by walking the stack found_types = {} needed_types = set() dotted_types = {} for type_name in type_names: parts = type_name.split(".") base_name = parts[0] if base_name not in typing.__dict__ and base_name not in __builtins__: if len(parts) > 1: dotted_types[type_name] = base_name needed_types.add(base_name) else: needed_types.add(type_name) if not needed_types: return None frame = None try: frame = sys._getframe(1) # Start one level up (skip this function) max_frames = 100 frame_count = 0 while frame and needed_types and frame_count < max_frames: frame_count += 1 for type_name in list(needed_types): if type_name in frame.f_locals: found_types[type_name] = frame.f_locals[type_name] needed_types.remove(type_name) elif frame.f_globals and type_name in frame.f_globals: found_types[type_name] = frame.f_globals[type_name] needed_types.remove(type_name) # If we found all needed types, stop looking if not needed_types: break frame = frame.f_back if needed_types and frame_count >= max_frames: import logging logging.getLogger("dspy").warning( f"Reached maximum frame search depth ({max_frames}) while looking for types: {needed_types}. " "Consider providing custom_types explicitly to Signature." 
) except (AttributeError, ValueError): # Handle environments where frame introspection is not available import logging logging.getLogger("dspy").debug( "Frame introspection failed while trying to resolve custom types. " "Consider providing custom_types explicitly to Signature." ) finally: if frame: del frame return found_types or None def __new__(mcs, signature_name, bases, namespace, **kwargs): # At this point, the orders have been swapped already. field_order = [name for name, value in namespace.items() if isinstance(value, FieldInfo)] # Set `str` as the default type for all fields raw_annotations = namespace.get("__annotations__", {}) for name, field in namespace.items(): if not isinstance(field, FieldInfo): continue # Don't add types to non-field attributes if not name.startswith("__") and name not in raw_annotations: raw_annotations[name] = str # Create ordered annotations dictionary that preserves field order ordered_annotations = {name: raw_annotations[name] for name in field_order if name in raw_annotations} # Add any remaining annotations that weren't in field_order ordered_annotations.update({k: v for k, v in raw_annotations.items() if k not in ordered_annotations}) namespace["__annotations__"] = ordered_annotations # Let Pydantic do its thing cls = super().__new__(mcs, signature_name, bases, namespace, **kwargs) # If we don't have instructions, it might be because we are a derived generic type. # In that case, we should inherit the instructions from the base class. if cls.__doc__ is None: for base in bases: if isinstance(base, SignatureMeta): doc = getattr(base, "__doc__", "") if doc != "": cls.__doc__ = doc # The more likely case is that the user has just not given us a type. # In that case, we should default to the input/output format. 
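# e.g. for Signature("question -> answer") the default becomes:
# "Given the fields `question`, produce the fields `answer`."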
if cls.__doc__ is None: cls.__doc__ = _default_instructions(cls) # Ensure all fields are declared with InputField or OutputField cls._validate_fields() # Ensure all fields have a prefix for name, field in cls.model_fields.items(): if "prefix" not in field.json_schema_extra: field.json_schema_extra["prefix"] = infer_prefix(name) + ":" if "desc" not in field.json_schema_extra: field.json_schema_extra["desc"] = f"${{{name}}}" return cls def _validate_fields(cls): for name, field in cls.model_fields.items(): extra = field.json_schema_extra or {} field_type = extra.get("__dspy_field_type") if field_type not in ["input", "output"]: raise TypeError( f"Field `{name}` in `{cls.__name__}` must be declared with InputField or OutputField, but " f"field `{name}` has `field.json_schema_extra={field.json_schema_extra}`", ) @property def instructions(cls) -> str: return inspect.cleandoc(getattr(cls, "__doc__", "")) @instructions.setter def instructions(cls, instructions: str) -> None: cls.__doc__ = instructions @property def input_fields(cls) -> dict[str, FieldInfo]: return cls._get_fields_with_type("input") @property def output_fields(cls) -> dict[str, FieldInfo]: return cls._get_fields_with_type("output") @property def fields(cls) -> dict[str, FieldInfo]: # Make sure to give input fields before output fields return {**cls.input_fields, **cls.output_fields} @property def signature(cls) -> str: """The string representation of the signature.""" input_fields = ", ".join(cls.input_fields.keys()) output_fields = ", ".join(cls.output_fields.keys()) return f"{input_fields} -> {output_fields}" def _get_fields_with_type(cls, field_type) -> dict[str, FieldInfo]: return {k: v for k, v in cls.model_fields.items() if v.json_schema_extra["__dspy_field_type"] == field_type} def __repr__(cls): """Output a representation of the signature. Uses the form: Signature(question, context -> answer question: str = InputField(desc="..."), context: list[str] = InputField(desc="..."), answer: int = OutputField(desc="..."), ). """ field_reprs = [] for name, field in cls.fields.items(): field_reprs.append(f"{name} = Field({field})") field_repr = "\n ".join(field_reprs) return f"{cls.__name__}({cls.signature}\n instructions={cls.instructions!r}\n {field_repr}\n)" class Signature(BaseModel, metaclass=SignatureMeta): "" # Note: Don't put a docstring here, as it will become the default instructions # for any signature that doesn't define it's own instructions. @classmethod def with_instructions(cls, instructions: str) -> type["Signature"]: return Signature(cls.fields, instructions) @classmethod def with_updated_fields(cls, name: str, type_: type | None = None, **kwargs: dict[str, Any]) -> type["Signature"]: """Create a new Signature class with the updated field information. Returns a new Signature class with the field, name, updated with fields[name].json_schema_extra[key] = value. Args: name: The name of the field to update. type_: The new type of the field. kwargs: The new values for the field. Returns: A new Signature class (not an instance) with the updated field information. """ fields_copy = deepcopy(cls.fields) # Update `fields_copy[name].json_schema_extra` with the new kwargs, on conflicts # we use the new value in kwargs. 
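# (the dict-unpacking below keeps the existing keys and lets the new kwargs win on conflicts)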
fields_copy[name].json_schema_extra = { **fields_copy[name].json_schema_extra, **kwargs, } if type_ is not None: fields_copy[name].annotation = type_ return Signature(fields_copy, cls.instructions) @classmethod def prepend(cls, name, field, type_=None) -> type["Signature"]: return cls.insert(0, name, field, type_) @classmethod def append(cls, name, field, type_=None) -> type["Signature"]: return cls.insert(-1, name, field, type_) @classmethod def delete(cls, name) -> type["Signature"]: fields = dict(cls.fields) fields.pop(name, None) return Signature(fields, cls.instructions) @classmethod def insert(cls, index: int, name: str, field, type_: type | None = None) -> type["Signature"]: # It's possible to set the type as annotation=type in pydantic.Field(...) # But this may be annoying for users, so we allow them to pass the type if type_ is None: type_ = field.annotation if type_ is None: type_ = str input_fields = list(cls.input_fields.items()) output_fields = list(cls.output_fields.items()) # Choose the list to insert into based on the field type lst = input_fields if field.json_schema_extra["__dspy_field_type"] == "input" else output_fields # We support negative insert indices if index < 0: index += len(lst) + 1 if index < 0 or index > len(lst): raise ValueError( f"Invalid index to insert: {index}, index must be in the range of [{len(lst) - 1}, {len(lst)}] for " f"{field.json_schema_extra['__dspy_field_type']} fields, but received: {index}.", ) lst.insert(index, (name, (type_, field))) new_fields = dict(input_fields + output_fields) return Signature(new_fields, cls.instructions) @classmethod def equals(cls, other) -> bool: """Compare the JSON schema of two Signature classes.""" if not isinstance(other, type) or not issubclass(other, BaseModel): return False if cls.instructions != other.instructions: return False for name in cls.fields.keys() | other.fields.keys(): if name not in other.fields or name not in cls.fields: return False if cls.fields[name].json_schema_extra != other.fields[name].json_schema_extra: return False return True @classmethod def dump_state(cls): state = {"instructions": cls.instructions, "fields": []} for field in cls.fields: state["fields"].append( { "prefix": cls.fields[field].json_schema_extra["prefix"], "description": cls.fields[field].json_schema_extra["desc"], } ) return state @classmethod def load_state(cls, state): signature_copy = Signature(deepcopy(cls.fields), cls.instructions) signature_copy.instructions = state["instructions"] for field, saved_field in zip(signature_copy.fields.values(), state["fields"], strict=False): field.json_schema_extra["prefix"] = saved_field["prefix"] field.json_schema_extra["desc"] = saved_field["description"] return signature_copy def ensure_signature(signature: str | type[Signature], instructions=None) -> type[Signature]: if signature is None: return None if isinstance(signature, str): return Signature(signature, instructions) if instructions is not None: raise ValueError("Don't specify instructions when initializing with a Signature") return signature def make_signature( signature: str | dict[str, tuple[type, FieldInfo]], instructions: str | None = None, signature_name: str = "StringSignature", custom_types: dict[str, type] | None = None, ) -> type[Signature]: """Create a new Signature subclass with the specified fields and instructions. Args: signature: Either a string in the format "input1, input2 -> output1, output2" or a dictionary mapping field names to tuples of (type, FieldInfo). 
instructions: Optional string containing instructions/prompt for the signature. If not provided, defaults to a basic description of inputs and outputs. signature_name: Optional string to name the generated Signature subclass. Defaults to "StringSignature". custom_types: Optional dictionary mapping type names to their actual type objects. Useful for resolving custom types that aren't built-ins or in the typing module. Returns: A new signature class with the specified fields and instructions. Examples: ``` # Using string format sig1 = make_signature("question, context -> answer") # Using dictionary format sig2 = make_signature({ "question": (str, InputField()), "answer": (str, OutputField()) }) # Using custom types class MyType: pass sig3 = make_signature("input: MyType -> output", custom_types={"MyType": MyType}) ``` """ # Prepare the names dictionary for type resolution names = None if custom_types: names = dict(typing.__dict__) names.update(custom_types) fields = _parse_signature(signature, names) if isinstance(signature, str) else signature # Validate the fields, this is important because we sometimes forget the # slightly unintuitive syntax with tuples of (type, Field) fixed_fields = {} for name, type_field in fields.items(): if not isinstance(name, str): raise ValueError(f"Field names must be strings, but received: {name}.") if isinstance(type_field, FieldInfo): type_ = type_field.annotation field = type_field else: if not isinstance(type_field, tuple): raise ValueError(f"Field values must be tuples, but received: {type_field}.") type_, field = type_field # It might be better to be explicit about the type, but it currently would break # program of thought and teleprompters, so we just silently default to string. if type_ is None: type_ = str if not isinstance(type_, (type, typing._GenericAlias, types.GenericAlias, typing._SpecialForm, types.UnionType)): raise ValueError(f"Field types must be types, but received: {type_} of type {type(type_)}.") if not isinstance(field, FieldInfo): raise ValueError(f"Field values must be Field instances, but received: {field}.") fixed_fields[name] = (type_, field) # Default prompt when no instructions are provided if instructions is None: sig = Signature(signature, "") # Simple way to parse input/output fields instructions = _default_instructions(sig) return create_model( signature_name, __base__=Signature, __doc__=instructions, **fixed_fields, ) def _parse_signature(signature: str, names=None) -> dict[str, tuple[type, Field]]: if signature.count("->") != 1: raise ValueError(f"Invalid signature format: '{signature}', must contain exactly one '->'.") inputs_str, outputs_str = signature.split("->") fields = {} for field_name, field_type in _parse_field_string(inputs_str, names): fields[field_name] = (field_type, InputField()) for field_name, field_type in _parse_field_string(outputs_str, names): fields[field_name] = (field_type, OutputField()) return fields def _parse_field_string(field_string: str, names=None) -> dict[str, str]: """Extract the field name and type from field string in the string-based Signature. It takes a string like "x: int, y: str" and returns a dictionary mapping field names to their types. For example, "x: int, y: str" -> [("x", int), ("y", str)]. This function utitlizes the Python AST to parse the fields and types. 
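Fields without an annotation default to `str`.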
""" args = ast.parse(f"def f({field_string}): pass").body[0].args.args field_names = [arg.arg for arg in args] types = [str if arg.annotation is None else _parse_type_node(arg.annotation, names) for arg in args] return zip(field_names, types, strict=False) def _parse_type_node(node, names=None) -> Any: """Recursively parse an AST node representing a type annotation. This function converts Python's Abstract Syntax Tree (AST) nodes into actual Python types. It's used to parse type annotations in signature strings like "x: list[int] -> y: str". Examples: - For "x: int", the AST node represents 'int' and returns the int type - For "x: list[str]", it processes a subscript node to return typing.list[str] - For "x: Optional[int]", it handles the Union type to return Optional[int] - For "x: MyModule.CustomType", it processes attribute access to return the actual type Args: node: An AST node from Python's ast module, representing a type annotation. Common node types include: - ast.Name: Simple types like 'int', 'str' - ast.Attribute: Nested types like 'typing.List' - ast.Subscript: Generic types like 'list[int]' names: Optional dictionary mapping type names to their actual type objects. Defaults to Python's typing module contents plus NoneType. Returns: The actual Python type represented by the AST node. Raises: ValueError: If the AST node represents an unknown or invalid type annotation. """ if names is None: names = dict(typing.__dict__) names["NoneType"] = type(None) def resolve_name(type_name: str): # Check if it's a built-in known type or in the provided names if type_name in names: return names[type_name] # Common built-in types builtin_types = [int, str, float, bool, list, tuple, dict, set, frozenset, complex, bytes, bytearray] # Check if it matches any known built-in type by name for t in builtin_types: if t.__name__ == type_name: return t # Attempt to import a module with this name dynamically # This allows handling of module-based annotations like `dspy.Image`. 
try: mod = importlib.import_module(type_name) names[type_name] = mod return mod except ImportError: pass # If we don't know the type or module, raise an error raise ValueError(f"Unknown name: {type_name}") if isinstance(node, ast.Module): if len(node.body) != 1: raise ValueError(f"Code is not syntactically valid: {ast.dump(node)}") return _parse_type_node(node.body[0], names) if isinstance(node, ast.Expr): return _parse_type_node(node.value, names) if isinstance(node, ast.Name): return resolve_name(node.id) if isinstance(node, ast.Attribute): base = _parse_type_node(node.value, names) attr_name = node.attr if hasattr(base, attr_name): return getattr(base, attr_name) if isinstance(node.value, ast.Name): full_name = f"{node.value.id}.{attr_name}" if full_name in names: return names[full_name] raise ValueError(f"Unknown attribute: {attr_name} on {base}") if isinstance(node, ast.Subscript): base_type = _parse_type_node(node.value, names) slice_node = node.slice if isinstance(slice_node, ast.Index): # For older Python versions slice_node = slice_node.value if isinstance(slice_node, ast.Tuple): arg_types = tuple(_parse_type_node(elt, names) for elt in slice_node.elts) else: arg_types = (_parse_type_node(slice_node, names),) # Special handling for Union, Optional if base_type is typing.Union: return typing.Union[arg_types] if base_type is typing.Optional: if len(arg_types) != 1: raise ValueError("Optional must have exactly one type argument") return typing.Optional[arg_types[0]] return base_type[arg_types] if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): # Handle PEP 604: int | None, str | float, etc. left = _parse_type_node(node.left, names) right = _parse_type_node(node.right, names) # Optional[X] is Union[X, NoneType] if right is type(None): return typing.Optional[left] if left is type(None): return typing.Optional[right] return typing.Union[left, right] if isinstance(node, ast.Tuple): return tuple(_parse_type_node(elt, names) for elt in node.elts) if isinstance(node, ast.Constant): return node.value if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == "Field": keys = [kw.arg for kw in node.keywords] values = [] for kw in node.keywords: if isinstance(kw.value, ast.Constant): values.append(kw.value.value) else: values.append(_parse_type_node(kw.value, names)) return Field(**dict(zip(keys, values, strict=False))) raise ValueError( f"Failed to parse string-base Signature due to unhandled AST node type in annotation: {ast.dump(node)}. " "Please consider using class-based DSPy Signatures instead." ) def infer_prefix(attribute_name: str) -> str: """Infer a prefix from an attribute name by converting it to a human-readable format. 
Examples: "camelCaseText" -> "Camel Case Text" "snake_case_text" -> "Snake Case Text" "text2number" -> "Text 2 Number" "HTMLParser" -> "HTML Parser" """ # Step 1: Convert camelCase to snake_case # Example: "camelCase" -> "camel_Case" s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", attribute_name) # Handle consecutive capitals # Example: "camel_Case" -> "camel_case" intermediate_name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1) # Step 2: Handle numbers by adding underscores around them # Example: "text2number" -> "text_2_number" with_underscores_around_numbers = re.sub( r"([a-zA-Z])(\d)", # Match letter followed by number r"\1_\2", # Add underscore between them intermediate_name, ) # Example: "2text" -> "2_text" with_underscores_around_numbers = re.sub( r"(\d)([a-zA-Z])", # Match number followed by letter r"\1_\2", # Add underscore between them with_underscores_around_numbers, ) # Step 3: Convert to Title Case while preserving acronyms words = with_underscores_around_numbers.split("_") title_cased_words = [] for word in words: if word.isupper(): # Preserve acronyms like 'HTML', 'API' as-is title_cased_words.append(word) else: # Capitalize first letter: 'text' -> 'Text' title_cased_words.append(word.capitalize()) # Join words with spaces # Example: ["Text", "2", "Number"] -> "Text 2 Number" return " ".join(title_cased_words) ``` -------------------------------------------------------------------------------- /tests/predict/test_predict.py: -------------------------------------------------------------------------------- ```python import asyncio import copy import enum import time import types from datetime import datetime from unittest.mock import patch import orjson import pydantic import pytest from litellm import ModelResponse from pydantic import BaseModel, HttpUrl import dspy from dspy import Predict, Signature from dspy.predict.predict import serialize_object from dspy.utils.dummies import DummyLM def test_initialization_with_string_signature(): signature_string = "input1, input2 -> output" predict = Predict(signature_string) expected_instruction = "Given the fields `input1`, `input2`, produce the fields `output`." 
assert predict.signature.instructions == expected_instruction assert predict.signature.instructions == Signature(signature_string).instructions def test_reset_method(): predict_instance = Predict("input -> output") predict_instance.lm = "modified" predict_instance.traces = ["trace"] predict_instance.train = ["train"] predict_instance.demos = ["demo"] predict_instance.reset() assert predict_instance.lm is None assert predict_instance.traces == [] assert predict_instance.train == [] assert predict_instance.demos == [] def test_lm_after_dump_and_load_state(): predict_instance = Predict("input -> output") lm = dspy.LM( model="openai/gpt-4o-mini", model_type="chat", temperature=1, max_tokens=100, num_retries=10, ) predict_instance.lm = lm expected_lm_state = { "model": "openai/gpt-4o-mini", "model_type": "chat", "temperature": 1, "max_tokens": 100, "num_retries": 10, "cache": True, "finetuning_model": None, "launch_kwargs": {}, "train_kwargs": {}, } assert lm.dump_state() == expected_lm_state dumped_state = predict_instance.dump_state() new_instance = Predict("input -> output") new_instance.load_state(dumped_state) assert new_instance.lm.dump_state() == expected_lm_state def test_call_method(): predict_instance = Predict("input -> output") lm = DummyLM([{"output": "test output"}]) dspy.settings.configure(lm=lm) result = predict_instance(input="test input") assert result.output == "test output" def test_instructions_after_dump_and_load_state(): predict_instance = Predict(Signature("input -> output", "original instructions")) dumped_state = predict_instance.dump_state() new_instance = Predict(Signature("input -> output", "new instructions")) new_instance.load_state(dumped_state) assert new_instance.signature.instructions == "original instructions" def test_demos_after_dump_and_load_state(): class TranslateToEnglish(dspy.Signature): """Translate content from a language to English.""" content: str = dspy.InputField() language: str = dspy.InputField() translation: str = dspy.OutputField() original_instance = Predict(TranslateToEnglish) original_instance.demos = [ dspy.Example( content="¿Qué tal?", language="SPANISH", translation="Hello there", ).with_inputs("content", "language"), ] dumped_state = original_instance.dump_state() assert len(dumped_state["demos"]) == len(original_instance.demos) assert dumped_state["demos"][0]["content"] == original_instance.demos[0].content saved_state = orjson.dumps(dumped_state).decode() loaded_state = orjson.loads(saved_state) new_instance = Predict(TranslateToEnglish) new_instance.load_state(loaded_state) assert len(new_instance.demos) == len(original_instance.demos) # Demos don't need to keep the same types after saving and loading the state. 
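    # (the original demo is a dspy.Example read via attribute access; the reloaded one is read via dict-style indexing)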
assert new_instance.demos[0]["content"] == original_instance.demos[0].content def test_typed_demos_after_dump_and_load_state(): class Item(pydantic.BaseModel): name: str quantity: int class InventorySignature(dspy.Signature): """Handle inventory items and their translations.""" items: list[Item] = dspy.InputField() language: str = dspy.InputField() translated_items: list[Item] = dspy.OutputField() total_quantity: int = dspy.OutputField() original_instance = Predict(InventorySignature) original_instance.demos = [ dspy.Example( items=[Item(name="apple", quantity=5), Item(name="banana", quantity=3)], language="SPANISH", translated_items=[Item(name="manzana", quantity=5), Item(name="plátano", quantity=3)], total_quantity=8, ).with_inputs("items", "language"), ] # Test dump_state dumped_state = original_instance.dump_state() assert len(dumped_state["demos"]) == len(original_instance.demos) # Verify the input items were properly serialized assert isinstance(dumped_state["demos"][0]["items"], list) assert len(dumped_state["demos"][0]["items"]) == 2 assert dumped_state["demos"][0]["items"][0] == {"name": "apple", "quantity": 5} # Test serialization/deserialization saved_state = orjson.dumps(dumped_state).decode() loaded_state = orjson.loads(saved_state) # Test load_state new_instance = Predict(InventorySignature) new_instance.load_state(loaded_state) assert len(new_instance.demos) == len(original_instance.demos) # Verify the structure is maintained after loading loaded_demo = new_instance.demos[0] assert isinstance(loaded_demo["items"], list) assert len(loaded_demo["items"]) == 2 assert loaded_demo["items"][0]["name"] == "apple" assert loaded_demo["items"][0]["quantity"] == 5 assert loaded_demo["items"][1]["name"] == "banana" assert loaded_demo["items"][1]["quantity"] == 3 # Verify output items were also properly maintained assert isinstance(loaded_demo["translated_items"], list) assert len(loaded_demo["translated_items"]) == 2 assert loaded_demo["translated_items"][0]["name"] == "manzana" assert loaded_demo["translated_items"][1]["name"] == "plátano" # def test_typed_demos_after_dump_and_load_state(): # class TypedTranslateToEnglish(dspy.Signature): # """Translate content from a language to English.""" # class Input(pydantic.BaseModel): # content: str # language: str # class Output(pydantic.BaseModel): # translation: str # input: Input = dspy.InputField() # output: Output = dspy.OutputField() # original_instance = TypedPredictor(TypedTranslateToEnglish).predictor # original_instance.demos = [ # dspy.Example( # input=TypedTranslateToEnglish.Input( # content="¿Qué tal?", # language="SPANISH", # ), # output=TypedTranslateToEnglish.Output( # translation="Hello there", # ), # ).with_inputs("input"), # ] # dumped_state = original_instance.dump_state() # assert len(dumped_state["demos"]) == len(original_instance.demos) # assert dumped_state["demos"][0]["input"] == original_instance.demos[0].input.model_dump_json() # saved_state = ujson.dumps(dumped_state) # loaded_state = ujson.loads(saved_state) # new_instance = TypedPredictor(TypedTranslateToEnglish).predictor # new_instance.load_state(loaded_state) # assert len(new_instance.demos) == len(original_instance.demos) # # Demos don't need to keep the same types after saving and loading the state. 
# assert new_instance.demos[0]["input"] == original_instance.demos[0].input.model_dump_json() def test_signature_fields_after_dump_and_load_state(tmp_path): class CustomSignature(dspy.Signature): """I am just an instruction.""" sentence = dspy.InputField(desc="I am an innocent input!") sentiment = dspy.OutputField() file_path = tmp_path / "tmp.json" original_instance = Predict(CustomSignature) original_instance.save(file_path) class CustomSignature2(dspy.Signature): """I am not a pure instruction.""" sentence = dspy.InputField(desc="I am a malicious input!") sentiment = dspy.OutputField(desc="I am a malicious output!", prefix="I am a prefix!") new_instance = Predict(CustomSignature2) assert new_instance.signature.dump_state() != original_instance.signature.dump_state() # After loading, the fields should be the same. new_instance.load(file_path) assert new_instance.signature.dump_state() == original_instance.signature.dump_state() @pytest.mark.parametrize("filename", ["model.json", "model.pkl"]) def test_lm_field_after_dump_and_load_state(tmp_path, filename): file_path = tmp_path / filename lm = dspy.LM( model="openai/gpt-4o-mini", model_type="chat", temperature=1, max_tokens=100, num_retries=10, ) original_predict = dspy.Predict("q->a") original_predict.lm = lm original_predict.save(file_path) assert file_path.exists() loaded_predict = dspy.Predict("q->a") loaded_predict.load(file_path) assert original_predict.dump_state() == loaded_predict.dump_state() def test_forward_method(): program = Predict("question -> answer") dspy.settings.configure(lm=DummyLM([{"answer": "No more responses"}])) result = program(question="What is 1+1?").answer assert result == "No more responses" def test_forward_method2(): program = Predict("question -> answer1, answer2") dspy.settings.configure(lm=DummyLM([{"answer1": "my first answer", "answer2": "my second answer"}])) result = program(question="What is 1+1?") assert result.answer1 == "my first answer" assert result.answer2 == "my second answer" def test_config_management(): predict_instance = Predict("input -> output") predict_instance.update_config(new_key="value") config = predict_instance.get_config() assert "new_key" in config and config["new_key"] == "value" def test_multi_output(): program = Predict("question -> answer", n=2) dspy.settings.configure(lm=DummyLM([{"answer": "my first answer"}, {"answer": "my second answer"}])) results = program(question="What is 1+1?") assert results.completions.answer[0] == "my first answer" assert results.completions.answer[1] == "my second answer" def test_multi_output2(): program = Predict("question -> answer1, answer2", n=2) dspy.settings.configure( lm=DummyLM( [ {"answer1": "my 0 answer", "answer2": "my 2 answer"}, {"answer1": "my 1 answer", "answer2": "my 3 answer"}, ], ) ) results = program(question="What is 1+1?") assert results.completions.answer1[0] == "my 0 answer" assert results.completions.answer1[1] == "my 1 answer" assert results.completions.answer2[0] == "my 2 answer" assert results.completions.answer2[1] == "my 3 answer" def test_datetime_inputs_and_outputs(): # Define a model for datetime inputs and outputs class TimedEvent(pydantic.BaseModel): event_name: str event_time: datetime class TimedSignature(dspy.Signature): events: list[TimedEvent] = dspy.InputField() summary: str = dspy.OutputField() next_event_time: datetime = dspy.OutputField() program = Predict(TimedSignature) lm = DummyLM( [ { "reasoning": "Processed datetime inputs", "summary": "All events are processed", "next_event_time": 
"2024-11-27T14:00:00", } ] ) dspy.settings.configure(lm=lm) output = program( events=[ TimedEvent(event_name="Event 1", event_time=datetime(2024, 11, 25, 10, 0, 0)), TimedEvent(event_name="Event 2", event_time=datetime(2024, 11, 25, 15, 30, 0)), ] ) assert output.summary == "All events are processed" assert output.next_event_time == datetime(2024, 11, 27, 14, 0, 0) def test_explicitly_valued_enum_inputs_and_outputs(): class Status(enum.Enum): PENDING = "pending" IN_PROGRESS = "in_progress" COMPLETED = "completed" class StatusSignature(dspy.Signature): current_status: Status = dspy.InputField() next_status: Status = dspy.OutputField() program = Predict(StatusSignature) lm = DummyLM( [ { "reasoning": "The current status is 'PENDING', advancing to 'IN_PROGRESS'.", "next_status": "in_progress", } ] ) dspy.settings.configure(lm=lm) output = program(current_status=Status.PENDING) assert output.next_status == Status.IN_PROGRESS def test_enum_inputs_and_outputs_with_shared_names_and_values(): class TicketStatus(enum.Enum): OPEN = "CLOSED" CLOSED = "RESOLVED" RESOLVED = "OPEN" class TicketStatusSignature(dspy.Signature): current_status: TicketStatus = dspy.InputField() next_status: TicketStatus = dspy.OutputField() program = Predict(TicketStatusSignature) # Mock reasoning and output lm = DummyLM( [ { "reasoning": "The ticket is currently 'OPEN', transitioning to 'CLOSED'.", "next_status": "RESOLVED", # Refers to TicketStatus.CLOSED by value } ] ) dspy.settings.configure(lm=lm) output = program(current_status=TicketStatus.OPEN) assert output.next_status == TicketStatus.CLOSED # By value def test_auto_valued_enum_inputs_and_outputs(): Status = enum.Enum("Status", ["PENDING", "IN_PROGRESS", "COMPLETED"]) # noqa: N806 class StatusSignature(dspy.Signature): current_status: Status = dspy.InputField() next_status: Status = dspy.OutputField() program = Predict(StatusSignature) lm = DummyLM( [ { "reasoning": "The current status is 'PENDING', advancing to 'IN_PROGRESS'.", "next_status": "IN_PROGRESS", # Use the auto-assigned value for IN_PROGRESS } ] ) dspy.settings.configure(lm=lm) output = program(current_status=Status.PENDING) assert output.next_status == Status.IN_PROGRESS def test_named_predictors(): class MyModule(dspy.Module): def __init__(self): super().__init__() self.inner = Predict("question -> answer") program = MyModule() assert program.named_predictors() == [("inner", program.inner)] # Check that it also works the second time. 
program2 = copy.deepcopy(program) assert program2.named_predictors() == [("inner", program2.inner)] def test_output_only(): class OutputOnlySignature(dspy.Signature): output = dspy.OutputField() predictor = Predict(OutputOnlySignature) lm = DummyLM([{"output": "short answer"}]) dspy.settings.configure(lm=lm) assert predictor().output == "short answer" def test_load_state_chaining(): """Test that load_state returns self for chaining.""" original = Predict("question -> answer") original.demos = [{"question": "test", "answer": "response"}] state = original.dump_state() new_instance = Predict("question -> answer").load_state(state) assert new_instance is not None assert new_instance.demos == original.demos @pytest.mark.parametrize("adapter_type", ["chat", "json"]) def test_call_predict_with_chat_history(adapter_type): class SpyLM(dspy.LM): def __init__(self, *args, return_json=False, **kwargs): super().__init__(*args, **kwargs) self.calls = [] self.return_json = return_json def __call__(self, prompt=None, messages=None, **kwargs): self.calls.append({"prompt": prompt, "messages": messages, "kwargs": kwargs}) if self.return_json: return ["{'answer':'100%'}"] return ["[[ ## answer ## ]]\n100%!"] class MySignature(dspy.Signature): question: str = dspy.InputField() history: dspy.History = dspy.InputField() answer: str = dspy.OutputField() program = Predict(MySignature) if adapter_type == "chat": lm = SpyLM("dummy_model") dspy.settings.configure(adapter=dspy.ChatAdapter(), lm=lm) else: lm = SpyLM("dummy_model", return_json=True) dspy.settings.configure(adapter=dspy.JSONAdapter(), lm=lm) program( question="are you sure that's correct?", history=dspy.History(messages=[{"question": "what's the capital of france?", "answer": "paris"}]), ) # Verify the LM was called with correct messages assert len(lm.calls) == 1 messages = lm.calls[0]["messages"] assert len(messages) == 4 assert "what's the capital of france?" 
in messages[1]["content"] assert "paris" in messages[2]["content"] assert "are you sure that's correct" in messages[3]["content"] def test_lm_usage(): program = Predict("question -> answer") dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True) with patch( "dspy.clients.lm.litellm_completion", return_value=ModelResponse( choices=[{"message": {"content": "[[ ## answer ## ]]\nParis"}}], usage={"total_tokens": 10}, ), ): result = program(question="What is the capital of France?") assert result.answer == "Paris" assert result.get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10 def test_lm_usage_with_parallel(): program = Predict("question -> answer") def program_wrapper(question): # Sleep to make it possible to cause a race condition time.sleep(0.5) return program(question=question) dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True) with patch( "dspy.clients.lm.litellm_completion", return_value=ModelResponse( choices=[{"message": {"content": "[[ ## answer ## ]]\nParis"}}], usage={"total_tokens": 10}, ), ): parallelizer = dspy.Parallel() input_pairs = [ (program_wrapper, {"question": "What is the capital of France?"}), (program_wrapper, {"question": "What is the capital of France?"}), ] results = parallelizer(input_pairs) assert results[0].answer == "Paris" assert results[1].answer == "Paris" assert results[0].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10 assert results[1].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10 @pytest.mark.asyncio async def test_lm_usage_with_async(): program = Predict("question -> answer") original_aforward = program.aforward async def patched_aforward(self, **kwargs): await asyncio.sleep(1) return await original_aforward(**kwargs) program.aforward = types.MethodType(patched_aforward, program) with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True): with patch( "litellm.acompletion", return_value=ModelResponse( choices=[{"message": {"content": "[[ ## answer ## ]]\nParis"}}], usage={"total_tokens": 10}, ), ): coroutines = [ program.acall(question="What is the capital of France?"), program.acall(question="What is the capital of France?"), program.acall(question="What is the capital of France?"), program.acall(question="What is the capital of France?"), ] results = await asyncio.gather(*coroutines) assert results[0].answer == "Paris" assert results[1].answer == "Paris" assert results[0].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10 assert results[1].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10 assert results[2].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10 assert results[3].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10 def test_positional_arguments(): program = Predict("question -> answer") with pytest.raises(ValueError) as e: program("What is the capital of France?") assert str(e.value) == ( "Positional arguments are not allowed when calling `dspy.Predict`, must use keyword arguments that match " "your signature input fields: 'question'. For example: `predict(question=input_value, ...)`." ) def test_error_message_on_invalid_lm_setup(): # No LM is loaded. with pytest.raises(ValueError, match="No LM is loaded"): Predict("question -> answer")(question="Why did a chicken cross the kitchen?") # LM is a string. 
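    # (dspy.configure itself accepts the string; the error is raised only when Predict is called)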
dspy.configure(lm="openai/gpt-4o-mini") with pytest.raises(ValueError) as e: Predict("question -> answer")(question="Why did a chicken cross the kitchen?") assert "LM must be an instance of `dspy.BaseLM`, not a string." in str(e.value) def dummy_lm(): pass # LM is not an instance of dspy.BaseLM. dspy.configure(lm=dummy_lm) with pytest.raises(ValueError) as e: Predict("question -> answer")(question="Why did a chicken cross the kitchen?") assert "LM must be an instance of `dspy.BaseLM`, not <class 'function'>." in str(e.value) @pytest.mark.parametrize("adapter_type", ["chat", "json"]) def test_field_constraints(adapter_type): class SpyLM(dspy.LM): def __init__(self, *args, return_json=False, **kwargs): super().__init__(*args, **kwargs) self.calls = [] self.return_json = return_json def __call__(self, prompt=None, messages=None, **kwargs): self.calls.append({"prompt": prompt, "messages": messages, "kwargs": kwargs}) if self.return_json: return ["{'score':'0.5', 'count':'2'}"] return ["[[ ## score ## ]]\n0.5\n[[ ## count ## ]]\n2"] class ConstrainedSignature(dspy.Signature): """Test signature with constrained fields.""" # Input with length and value constraints text: str = dspy.InputField(min_length=5, max_length=100, desc="Input text") number: int = dspy.InputField(gt=0, lt=10, desc="A number between 0 and 10") # Output with multiple constraints score: float = dspy.OutputField(ge=0.0, le=1.0, desc="Score between 0 and 1") count: int = dspy.OutputField(multiple_of=2, desc="Even number count") program = Predict(ConstrainedSignature) lm = SpyLM("dummy_model") if adapter_type == "chat": lm = SpyLM("dummy_model") dspy.settings.configure(adapter=dspy.ChatAdapter(), lm=lm) else: lm = SpyLM("dummy_model", return_json=True) dspy.settings.configure(adapter=dspy.JSONAdapter(), lm=lm) # Call the predictor to trigger instruction generation program(text="hello world", number=5) # Get the system message containing the instructions system_message = lm.calls[0]["messages"][0]["content"] # Verify constraints are included in the field descriptions assert "minimum length: 5" in system_message assert "maximum length: 100" in system_message assert "greater than: 0" in system_message assert "less than: 10" in system_message assert "greater than or equal to: 0.0" in system_message assert "less than or equal to: 1.0" in system_message assert "a multiple of the given number: 2" in system_message @pytest.mark.asyncio async def test_async_predict(): program = Predict("question -> answer") with dspy.context(lm=DummyLM([{"answer": "Paris"}])): result = await program.acall(question="What is the capital of France?") assert result.answer == "Paris" def test_predicted_outputs_piped_from_predict_to_lm_call(): program = Predict("question -> answer") dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini")) with patch("litellm.completion") as mock_completion: program( question="Why did a chicken cross the kitchen?", prediction={"type": "content", "content": "A chicken crossing the kitchen"}, ) assert mock_completion.call_args[1]["prediction"] == { "type": "content", "content": "A chicken crossing the kitchen", } # If the signature has prediction as an input field, and the prediction is not set as the standard predicted output # format, it should not be passed to the LM. 
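    # Here `prediction` is an ordinary string input field, so no `prediction` kwarg should reach litellm.completion.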
program = Predict("question, prediction -> judgement") with patch("litellm.completion") as mock_completion: program(question="Why did a chicken cross the kitchen?", prediction="To get to the other side!") assert "prediction" not in mock_completion.call_args[1] def test_dump_state_pydantic_non_primitive_types(): class WebsiteInfo(BaseModel): name: str url: HttpUrl description: str | None = None created_at: datetime class TestSignature(dspy.Signature): website_info: WebsiteInfo = dspy.InputField() summary: str = dspy.OutputField() website_info = WebsiteInfo( name="Example", url="https://www.example.com", description="Test website", created_at=datetime(2021, 1, 1, 12, 0, 0), ) serialized = serialize_object(website_info) assert serialized["url"] == "https://www.example.com/" assert serialized["created_at"] == "2021-01-01T12:00:00" json_str = orjson.dumps(serialized).decode() reloaded = orjson.loads(json_str) assert reloaded == serialized predictor = Predict(TestSignature) demo = {"website_info": website_info, "summary": "This is a test website."} predictor.demos = [demo] state = predictor.dump_state() json_str = orjson.dumps(state).decode() reloaded_state = orjson.loads(json_str) demo_data = reloaded_state["demos"][0] assert demo_data["website_info"]["url"] == "https://www.example.com/" assert demo_data["website_info"]["created_at"] == "2021-01-01T12:00:00" def test_trace_size_limit(): program = Predict("question -> answer") dspy.settings.configure(lm=DummyLM([{"answer": "Paris"}]), max_trace_size=3) for _ in range(10): program(question="What is the capital of France?") assert len(dspy.settings.trace) == 3 def test_disable_trace(): program = Predict("question -> answer") dspy.settings.configure(lm=DummyLM([{"answer": "Paris"}]), trace=None) for _ in range(10): program(question="What is the capital of France?") assert dspy.settings.trace is None def test_per_module_history_size_limit(): program = Predict("question -> answer") dspy.settings.configure(lm=DummyLM([{"answer": "Paris"}]), max_history_size=5) for _ in range(10): program(question="What is the capital of France?") assert len(program.history) == 5 def test_per_module_history_disabled(): program = Predict("question -> answer") dspy.settings.configure(lm=DummyLM([{"answer": "Paris"}]), disable_history=True) for _ in range(10): program(question="What is the capital of France?") assert len(program.history) == 0 ``` -------------------------------------------------------------------------------- /tests/teleprompt/gepa_dummy_lm.json: -------------------------------------------------------------------------------- ```json {"lm": [{"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the fields `input`, produce the fields `output`."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat does the fox say?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nThe fox says a variety of sounds including barks, yips, and howls. 
However, the phrase \"What does the fox say?\" is famously known from the viral song by Ylvis, where humorous and nonsensical sounds are suggested as the fox's call.\n\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nThe fox says a variety of sounds including barks, yips, and howls. However, the phrase \"What does the fox say?\" is famously known from the viral song by Ylvis, where humorous and nonsensical sounds are suggested as the fox's call.\n\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.0001624, "timestamp": "2025-08-13T18:20:44.059820", "uuid": "93b03c3f-7e96-43a0-bac1-2ececc179365", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the fields `input`, produce the fields `output`."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat is the color of the sky?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nThe color of the sky is typically blue during the day due to the scattering of sunlight by the atmosphere. However, it can appear different colors at sunrise, sunset, or under various weather conditions.\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nThe color of the sky is typically blue during the day due to the scattering of sunlight by the atmosphere. However, it can appear different colors at sunrise, sunset, or under various weather conditions.\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.000136, "timestamp": "2025-08-13T18:20:44.060309", "uuid": "b3367637-304b-46cf-9147-6a976a8db439", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the fields `input`, produce the fields `output`."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat is the color of the sky?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nThe color of the sky is typically blue during the day due to the scattering of sunlight by the atmosphere. However, it can appear different colors at sunrise, sunset, or under various weather conditions.\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nThe color of the sky is typically blue during the day due to the scattering of sunlight by the atmosphere. 
However, it can appear different colors at sunrise, sunset, or under various weather conditions.\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.000136, "timestamp": "2025-08-13T18:20:44.074500", "uuid": "34df9ccb-128c-4a66-8ee7-c17de080d2ad", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the fields `input`, produce the fields `output`."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat does the fox say?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nThe fox says a variety of sounds including barks, yips, and howls. However, the phrase \"What does the fox say?\" is famously known from the viral song by Ylvis, where humorous and nonsensical sounds are suggested as the fox's call.\n\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nThe fox says a variety of sounds including barks, yips, and howls. However, the phrase \"What does the fox say?\" is famously known from the viral song by Ylvis, where humorous and nonsensical sounds are suggested as the fox's call.\n\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.0001624, "timestamp": "2025-08-13T18:20:44.076517", "uuid": "eae79449-93ab-474c-887e-e8ecde69f8e8", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the fields `input`, produce the fields `output`."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat does the fox say?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nThe fox says a variety of sounds including barks, yips, and howls. However, the phrase \"What does the fox say?\" is famously known from the viral song by Ylvis, where humorous and nonsensical sounds are suggested as the fox's call.\n\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nThe fox says a variety of sounds including barks, yips, and howls. However, the phrase \"What does the fox say?\" is famously known from the viral song by Ylvis, where humorous and nonsensical sounds are suggested as the fox's call.\n\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.0001624, "timestamp": "2025-08-13T18:20:44.078480", "uuid": "6768dd52-0e72-4648-891b-4d60bea205ec", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. 
`output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the field `input` containing a question or phrase, produce the field `output` containing the exact, direct, and contextually appropriate answer or response that the user expects, without additional explanations, commentary, or general knowledge unless explicitly requested.\n \n Key details and guidelines:\n \n 1. The `input` field contains a question or phrase that may be literal, factual, or culturally specific (e.g., references to popular culture or memes).\n \n 2. The `output` must be the precise answer or response that directly addresses the `input` as intended by the user, not a general or encyclopedic explanation.\n \n 3. If the `input` is a well-known phrase or question from popular culture (e.g., \"What does the fox say?\"), the `output` should reflect the expected or canonical answer associated with that phrase, rather than a factual or scientific explanation.\n \n 4. Avoid providing additional background information, scientific explanations, or alternative interpretations unless explicitly requested.\n \n 5. The goal is to produce the answer that the user expects or the \"correct\" answer in the context of the question, including culturally recognized or meme-based answers.\n \n 6. If the `input` is a straightforward factual question (e.g., \"What is the color of the sky?\"), provide the commonly accepted direct answer (e.g., \"Blue\") rather than a detailed scientific explanation.\n \n 7. The output should be concise, clear, and focused solely on answering the question or phrase in the `input`.\n \n Example:\n \n - Input: \"What is the color of the sky?\"\n - Output: \"Blue.\"\n \n - Input: \"What does the fox say?\"\n - Output: \"Ring-ding-ding-ding-dingeringeding!\"\n \n This approach ensures that the assistant provides the expected, contextually appropriate answers rather than general or overly detailed responses that may be considered incorrect by the user."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat is the color of the sky?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nBlue.\n\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nBlue.\n\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.000234, "timestamp": "2025-08-13T18:20:44.093296", "uuid": "3d8eee4c-f7f0-4f74-a67a-8a43da85da15", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the field `input` containing a question or phrase, produce the field `output` containing the exact, direct, and contextually appropriate answer or response that the user expects, without additional explanations, commentary, or general knowledge unless explicitly requested.\n \n Key details and guidelines:\n \n 1. 
The `input` field contains a question or phrase that may be literal, factual, or culturally specific (e.g., references to popular culture or memes).\n \n 2. The `output` must be the precise answer or response that directly addresses the `input` as intended by the user, not a general or encyclopedic explanation.\n \n 3. If the `input` is a well-known phrase or question from popular culture (e.g., \"What does the fox say?\"), the `output` should reflect the expected or canonical answer associated with that phrase, rather than a factual or scientific explanation.\n \n 4. Avoid providing additional background information, scientific explanations, or alternative interpretations unless explicitly requested.\n \n 5. The goal is to produce the answer that the user expects or the \"correct\" answer in the context of the question, including culturally recognized or meme-based answers.\n \n 6. If the `input` is a straightforward factual question (e.g., \"What is the color of the sky?\"), provide the commonly accepted direct answer (e.g., \"Blue\") rather than a detailed scientific explanation.\n \n 7. The output should be concise, clear, and focused solely on answering the question or phrase in the `input`.\n \n Example:\n \n - Input: \"What is the color of the sky?\"\n - Output: \"Blue.\"\n \n - Input: \"What does the fox say?\"\n - Output: \"Ring-ding-ding-ding-dingeringeding!\"\n \n This approach ensures that the assistant provides the expected, contextually appropriate answers rather than general or overly detailed responses that may be considered incorrect by the user."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat does the fox say?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nRing-ding-ding-ding-dingeringeding!\n\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nRing-ding-ding-ding-dingeringeding!\n\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.0002492, "timestamp": "2025-08-13T18:20:44.094009", "uuid": "ef866572-bbe0-4b47-b64a-af377d18d786", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the field `input` containing a question or phrase, produce the field `output` containing the exact, direct, and contextually appropriate answer or response that the user expects, without additional explanations, commentary, or general knowledge unless explicitly requested.\n \n Key details and guidelines:\n \n 1. The `input` field contains a question or phrase that may be literal, factual, or culturally specific (e.g., references to popular culture or memes).\n \n 2. The `output` must be the precise answer or response that directly addresses the `input` as intended by the user, not a general or encyclopedic explanation.\n \n 3. If the `input` is a well-known phrase or question from popular culture (e.g., \"What does the fox say?\"), the `output` should reflect the expected or canonical answer associated with that phrase, rather than a factual or scientific explanation.\n \n 4. 
Avoid providing additional background information, scientific explanations, or alternative interpretations unless explicitly requested.\n \n 5. The goal is to produce the answer that the user expects or the \"correct\" answer in the context of the question, including culturally recognized or meme-based answers.\n \n 6. If the `input` is a straightforward factual question (e.g., \"What is the color of the sky?\"), provide the commonly accepted direct answer (e.g., \"Blue\") rather than a detailed scientific explanation.\n \n 7. The output should be concise, clear, and focused solely on answering the question or phrase in the `input`.\n \n Example:\n \n - Input: \"What is the color of the sky?\"\n - Output: \"Blue.\"\n \n - Input: \"What does the fox say?\"\n - Output: \"Ring-ding-ding-ding-dingeringeding!\"\n \n This approach ensures that the assistant provides the expected, contextually appropriate answers rather than general or overly detailed responses that may be considered incorrect by the user."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat does the fox say?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nRing-ding-ding-ding-dingeringeding!\n\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nRing-ding-ding-ding-dingeringeding!\n\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.0002492, "timestamp": "2025-08-13T18:20:44.094555", "uuid": "5b223989-d453-4c94-bf74-7c94d328c94d", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the field `input` containing a question or phrase, produce the field `output` containing the exact, direct, and contextually appropriate answer or response that the user expects, without additional explanations, commentary, or general knowledge unless explicitly requested.\n \n Key details and guidelines:\n \n 1. The `input` field contains a question or phrase that may be literal, factual, or culturally specific (e.g., references to popular culture or memes).\n \n 2. The `output` must be the precise answer or response that directly addresses the `input` as intended by the user, not a general or encyclopedic explanation.\n \n 3. If the `input` is a well-known phrase or question from popular culture (e.g., \"What does the fox say?\"), the `output` should reflect the expected or canonical answer associated with that phrase, rather than a factual or scientific explanation.\n \n 4. Avoid providing additional background information, scientific explanations, or alternative interpretations unless explicitly requested.\n \n 5. The goal is to produce the answer that the user expects or the \"correct\" answer in the context of the question, including culturally recognized or meme-based answers.\n \n 6. If the `input` is a straightforward factual question (e.g., \"What is the color of the sky?\"), provide the commonly accepted direct answer (e.g., \"Blue\") rather than a detailed scientific explanation.\n \n 7. 
The output should be concise, clear, and focused solely on answering the question or phrase in the `input`.\n \n Example:\n \n - Input: \"What is the color of the sky?\"\n - Output: \"Blue.\"\n \n - Input: \"What does the fox say?\"\n - Output: \"Ring-ding-ding-ding-dingeringeding!\"\n \n This approach ensures that the assistant provides the expected, contextually appropriate answers rather than general or overly detailed responses that may be considered incorrect by the user."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat does the fox say?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nRing-ding-ding-ding-dingeringeding!\n\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nRing-ding-ding-ding-dingeringeding!\n\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.0002492, "timestamp": "2025-08-13T18:20:44.123406", "uuid": "64ce9bc4-5a84-4eb4-b8fa-5cceddfb8c7c", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}, {"prompt": null, "messages": [{"role": "system", "content": "Your input fields are:\n1. `input` (str):\nYour output fields are:\n1. `output` (str):\nAll interactions will be structured in the following way, with the appropriate values filled in.\n\n[[ ## input ## ]]\n{input}\n\n[[ ## output ## ]]\n{output}\n\n[[ ## completed ## ]]\nIn adhering to this structure, your objective is: \n Given the field `input` containing a question or phrase, produce the field `output` containing the exact, direct, and contextually appropriate answer or response that the user expects, without additional explanations, commentary, or general knowledge unless explicitly requested.\n \n Key details and guidelines:\n \n 1. The `input` field contains a question or phrase that may be literal, factual, or culturally specific (e.g., references to popular culture or memes).\n \n 2. The `output` must be the precise answer or response that directly addresses the `input` as intended by the user, not a general or encyclopedic explanation.\n \n 3. If the `input` is a well-known phrase or question from popular culture (e.g., \"What does the fox say?\"), the `output` should reflect the expected or canonical answer associated with that phrase, rather than a factual or scientific explanation.\n \n 4. Avoid providing additional background information, scientific explanations, or alternative interpretations unless explicitly requested.\n \n 5. The goal is to produce the answer that the user expects or the \"correct\" answer in the context of the question, including culturally recognized or meme-based answers.\n \n 6. If the `input` is a straightforward factual question (e.g., \"What is the color of the sky?\"), provide the commonly accepted direct answer (e.g., \"Blue\") rather than a detailed scientific explanation.\n \n 7. 
The output should be concise, clear, and focused solely on answering the question or phrase in the `input`.\n \n Example:\n \n - Input: \"What is the color of the sky?\"\n - Output: \"Blue.\"\n \n - Input: \"What does the fox say?\"\n - Output: \"Ring-ding-ding-ding-dingeringeding!\"\n \n This approach ensures that the assistant provides the expected, contextually appropriate answers rather than general or overly detailed responses that may be considered incorrect by the user."}, {"role": "user", "content": "[[ ## input ## ]]\nWhat is the color of the sky?\n\nRespond with the corresponding output fields, starting with the field `[[ ## output ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`."}], "kwargs": {}, "response": "[[ ## output ## ]]\nBlue.\n\n[[ ## completed ## ]]", "outputs": ["[[ ## output ## ]]\nBlue.\n\n[[ ## completed ## ]]"], "usage": {}, "cost": 0.000234, "timestamp": "2025-08-13T18:20:44.125206", "uuid": "5b38b61a-e139-4962-b788-570b3cbb5c38", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}], "reflection_lm": [{"prompt": "I provided an assistant with the following instructions to perform a task for me:\n```\nGiven the fields `input`, produce the fields `output`.\n```\n\nThe following are examples of different task inputs provided to the assistant along with the assistant's response for each of them, and some feedback on how the assistant's response could be better:\n```\n# Example 1\n## Inputs\n### input\nWhat is the color of the sky?\n\n## Generated Outputs\n### output\nThe color of the sky is typically blue during the day due to the scattering of sunlight by the atmosphere. However, it can appear different colors at sunrise, sunset, or under various weather conditions.\n\n## Feedback\nWrong answer.\n\n\n\n# Example 2\n## Inputs\n### input\nWhat does the fox say?\n\n## Generated Outputs\n### output\nThe fox says a variety of sounds including barks, yips, and howls. However, the phrase \"What does the fox say?\" is famously known from the viral song by Ylvis, where humorous and nonsensical sounds are suggested as the fox's call.\n\n## Feedback\nWrong answer.\n\n\n\n# Example 3\n## Inputs\n### input\nWhat does the fox say?\n\n## Generated Outputs\n### output\nThe fox says a variety of sounds including barks, yips, and howls. However, the phrase \"What does the fox say?\" is famously known from the viral song by Ylvis, where humorous and nonsensical sounds are suggested as the fox's call.\n\n## Feedback\nWrong answer.\n\n\n```\n\nYour task is to write a new instruction for the assistant.\n\nRead the inputs carefully and identify the input format and infer detailed task description about the task I wish to solve with the assistant.\n\nRead all the assistant responses and the corresponding feedback. Identify all niche and domain specific factual information about the task and include it in the instruction, as a lot of it may not be available to the assistant in the future. The assistant may have utilized a generalizable strategy to solve the task, if so, include that in the instruction as well.\n\nProvide the new instructions within ``` blocks.", "messages": null, "kwargs": {}, "response": "``` \nGiven the field `input` containing a question or phrase, produce the field `output` containing the exact, direct, and contextually appropriate answer or response that the user expects, without additional explanations, commentary, or general knowledge unless explicitly requested.\n\nKey details and guidelines:\n\n1. 
The `input` field contains a question or phrase that may be literal, factual, or culturally specific (e.g., references to popular culture or memes).\n\n2. The `output` must be the precise answer or response that directly addresses the `input` as intended by the user, not a general or encyclopedic explanation.\n\n3. If the `input` is a well-known phrase or question from popular culture (e.g., \"What does the fox say?\"), the `output` should reflect the expected or canonical answer associated with that phrase, rather than a factual or scientific explanation.\n\n4. Avoid providing additional background information, scientific explanations, or alternative interpretations unless explicitly requested.\n\n5. The goal is to produce the answer that the user expects or the \"correct\" answer in the context of the question, including culturally recognized or meme-based answers.\n\n6. If the `input` is a straightforward factual question (e.g., \"What is the color of the sky?\"), provide the commonly accepted direct answer (e.g., \"Blue\") rather than a detailed scientific explanation.\n\n7. The output should be concise, clear, and focused solely on answering the question or phrase in the `input`.\n\nExample:\n\n- Input: \"What is the color of the sky?\"\n- Output: \"Blue.\"\n\n- Input: \"What does the fox say?\"\n- Output: \"Ring-ding-ding-ding-dingeringeding!\"\n\nThis approach ensures that the assistant provides the expected, contextually appropriate answers rather than general or overly detailed responses that may be considered incorrect by the user.\n```", "outputs": ["``` \nGiven the field `input` containing a question or phrase, produce the field `output` containing the exact, direct, and contextually appropriate answer or response that the user expects, without additional explanations, commentary, or general knowledge unless explicitly requested.\n\nKey details and guidelines:\n\n1. The `input` field contains a question or phrase that may be literal, factual, or culturally specific (e.g., references to popular culture or memes).\n\n2. The `output` must be the precise answer or response that directly addresses the `input` as intended by the user, not a general or encyclopedic explanation.\n\n3. If the `input` is a well-known phrase or question from popular culture (e.g., \"What does the fox say?\"), the `output` should reflect the expected or canonical answer associated with that phrase, rather than a factual or scientific explanation.\n\n4. Avoid providing additional background information, scientific explanations, or alternative interpretations unless explicitly requested.\n\n5. The goal is to produce the answer that the user expects or the \"correct\" answer in the context of the question, including culturally recognized or meme-based answers.\n\n6. If the `input` is a straightforward factual question (e.g., \"What is the color of the sky?\"), provide the commonly accepted direct answer (e.g., \"Blue\") rather than a detailed scientific explanation.\n\n7. 
The output should be concise, clear, and focused solely on answering the question or phrase in the `input`.\n\nExample:\n\n- Input: \"What is the color of the sky?\"\n- Output: \"Blue.\"\n\n- Input: \"What does the fox say?\"\n- Output: \"Ring-ding-ding-ding-dingeringeding!\"\n\nThis approach ensures that the assistant provides the expected, contextually appropriate answers rather than general or overly detailed responses that may be considered incorrect by the user.\n```"], "usage": {}, "cost": 0.000774, "timestamp": "2025-08-13T18:20:44.080463", "uuid": "c71eee51-af0b-4469-a365-343105013d66", "model": "openai/gpt-4.1-mini", "response_model": "gpt-4.1-mini-2025-04-14", "model_type": "chat"}]} ``` -------------------------------------------------------------------------------- /dspy/teleprompt/gepa/gepa.py: -------------------------------------------------------------------------------- ```python import inspect import logging import random from dataclasses import dataclass from typing import Any, Literal, Optional, Protocol, Union from gepa import GEPAResult from gepa.core.adapter import ProposalFn from gepa.proposer.reflective_mutation.base import ReflectionComponentSelector from dspy.clients.lm import LM from dspy.primitives import Example, Module, Prediction from dspy.teleprompt.gepa.gepa_utils import DspyAdapter, DSPyTrace, PredictorFeedbackFn, ScoreWithFeedback from dspy.teleprompt.teleprompt import Teleprompter from dspy.utils.annotation import experimental logger = logging.getLogger(__name__) AUTO_RUN_SETTINGS = { "light": {"n": 6}, "medium": {"n": 12}, "heavy": {"n": 18}, } @experimental(version="3.0.0") class GEPAFeedbackMetric(Protocol): def __call__( gold: Example, pred: Prediction, trace: Optional["DSPyTrace"], pred_name: str | None, pred_trace: Optional["DSPyTrace"], ) -> Union[float, "ScoreWithFeedback"]: """ This function is called with the following arguments: - gold: The gold example. - pred: The predicted output. - trace: Optional. The trace of the program's execution. - pred_name: Optional. The name of the target predictor currently being optimized by GEPA, for which the feedback is being requested. - pred_trace: Optional. The trace of the target predictor's execution GEPA is seeking feedback for. Note the `pred_name` and `pred_trace` arguments. During optimization, GEPA will call the metric to obtain feedback for individual predictors being optimized. GEPA provides the name of the predictor in `pred_name` and the sub-trace (of the trace) corresponding to the predictor in `pred_trace`. If available at the predictor level, the metric should return dspy.Prediction(score: float, feedback: str) corresponding to the predictor. If not available at the predictor level, the metric can also return a text feedback at the program level (using just the gold, pred and trace). If no feedback is returned, GEPA will use a simple text feedback consisting of just the score: f"This trajectory got a score of {score}." """ ... @experimental(version="3.0.0") @dataclass(frozen=True) class DspyGEPAResult: """ Additional data related to the GEPA run. 
    Fields:
    - candidates: list of proposed candidates (component_name -> component_text)
    - parents: lineage info; for each candidate i, parents[i] is a list of parent indices or None
    - val_aggregate_scores: per-candidate aggregate score on the validation set (higher is better)
    - val_subscores: per-candidate per-instance scores on the validation set (len == num_val_instances)
    - per_val_instance_best_candidates: for each val instance t, a set of candidate indices achieving the best score on t
    - discovery_eval_counts: Budget (number of metric calls / rollouts) consumed up to the discovery of each candidate
    - total_metric_calls: total number of metric calls made across the run
    - num_full_val_evals: number of full validation evaluations performed
    - log_dir: where artifacts were written (if any)
    - seed: RNG seed for reproducibility (if known)
    - best_idx: candidate index with the highest val_aggregate_scores
    - best_candidate: the program text mapping for best_idx
    """

    # Data about the proposed candidates
    candidates: list[Module]
    parents: list[list[int | None]]
    val_aggregate_scores: list[float]
    val_subscores: list[list[float]]
    per_val_instance_best_candidates: list[set[int]]
    discovery_eval_counts: list[int]

    # Optional data
    best_outputs_valset: list[list[tuple[int, list[Prediction]]]] | None = None

    # Optimization metadata
    total_metric_calls: int | None = None
    num_full_val_evals: int | None = None
    log_dir: str | None = None
    seed: int | None = None

    @property
    def best_idx(self) -> int:
        scores = self.val_aggregate_scores
        return max(range(len(scores)), key=lambda i: scores[i])

    @property
    def best_candidate(self) -> dict[str, str]:
        return self.candidates[self.best_idx]

    @property
    def highest_score_achieved_per_val_task(self) -> list[float]:
        return [
            self.val_subscores[list(self.per_val_instance_best_candidates[val_idx])[0]][val_idx]
            for val_idx in range(len(self.val_subscores[0]))
        ]

    def to_dict(self) -> dict[str, Any]:
        cands = [
            {k: v for k, v in cand.items()}
            for cand in self.candidates
        ]
        return dict(
            candidates=cands,
            parents=self.parents,
            val_aggregate_scores=self.val_aggregate_scores,
            best_outputs_valset=self.best_outputs_valset,
            val_subscores=self.val_subscores,
            per_val_instance_best_candidates=[list(s) for s in self.per_val_instance_best_candidates],
            discovery_eval_counts=self.discovery_eval_counts,
            total_metric_calls=self.total_metric_calls,
            num_full_val_evals=self.num_full_val_evals,
            log_dir=self.log_dir,
            seed=self.seed,
            best_idx=self.best_idx,
        )

    @staticmethod
    def from_gepa_result(gepa_result: "GEPAResult", adapter: "DspyAdapter") -> "DspyGEPAResult":
        return DspyGEPAResult(
            candidates=[adapter.build_program(c) for c in gepa_result.candidates],
            parents=gepa_result.parents,
            val_aggregate_scores=gepa_result.val_aggregate_scores,
            best_outputs_valset=gepa_result.best_outputs_valset,
            val_subscores=gepa_result.val_subscores,
            per_val_instance_best_candidates=gepa_result.per_val_instance_best_candidates,
            discovery_eval_counts=gepa_result.discovery_eval_counts,
            total_metric_calls=gepa_result.total_metric_calls,
            num_full_val_evals=gepa_result.num_full_val_evals,
            log_dir=gepa_result.run_dir,
            seed=gepa_result.seed,
        )


@experimental(version="3.0.0")
class GEPA(Teleprompter):
    """
    GEPA is an evolutionary optimizer, which uses reflection to evolve text components of complex
    systems. GEPA is proposed in the paper
    [GEPA: Reflective Prompt Evolution Can Outperform Reinforcement Learning](https://arxiv.org/abs/2507.19457).
    The GEPA optimization engine is provided by the `gepa` package, available from
    [https://github.com/gepa-ai/gepa](https://github.com/gepa-ai/gepa).

    GEPA captures full traces of the DSPy module's execution, identifies the parts of the trace
    corresponding to a specific predictor, and reflects on the behaviour of the predictor to propose
    a new instruction for the predictor.

    GEPA allows users to provide textual feedback to the optimizer, which is used to guide the
    evolution of the predictor. The textual feedback can be provided at the granularity of individual
    predictors, or at the level of the entire system's execution.

    To provide feedback to the GEPA optimizer, implement a metric as follows:
    ```
    def metric(
        gold: Example,
        pred: Prediction,
        trace: Optional[DSPyTrace] = None,
        pred_name: Optional[str] = None,
        pred_trace: Optional[DSPyTrace] = None,
    ) -> float | ScoreWithFeedback:
        \"""
        This function is called with the following arguments:
        - gold: The gold example.
        - pred: The predicted output.
        - trace: Optional. The trace of the program's execution.
        - pred_name: Optional. The name of the target predictor currently being optimized by GEPA,
          for which the feedback is being requested.
        - pred_trace: Optional. The trace of the target predictor's execution GEPA is seeking feedback for.

        Note the `pred_name` and `pred_trace` arguments. During optimization, GEPA will call the metric
        to obtain feedback for individual predictors being optimized. GEPA provides the name of the
        predictor in `pred_name` and the sub-trace (of the trace) corresponding to the predictor in
        `pred_trace`. If available at the predictor level, the metric should return
        {'score': float, 'feedback': str} corresponding to the predictor. If not available at the
        predictor level, the metric can also return a text feedback at the program level (using just
        the gold, pred and trace). If no feedback is returned, GEPA will use a simple text feedback
        consisting of just the score: f"This trajectory got a score of {score}."
        \"""
        ...
    ```

    GEPA can also be used as a batch inference-time search strategy, by passing
    `valset=trainset, track_stats=True, track_best_outputs=True`, and using the `detailed_results`
    attribute of the optimized program (returned by `compile`) to get the Pareto frontier of the batch.
    `optimized_program.detailed_results.best_outputs_valset` will contain the best outputs for each task
    in the batch.

    Example:
    ```
    gepa = GEPA(metric=metric, track_stats=True)
    batch_of_tasks = [dspy.Example(...) for task in tasks]
    new_prog = gepa.compile(student, trainset=trainset, valset=batch_of_tasks)
    pareto_frontier = new_prog.detailed_results.val_aggregate_scores
    # pareto_frontier is a list of scores, one for each task in the batch.
    ```

    Args:
        metric: The metric function to use for feedback and evaluation.
        auto: The auto budget to use for the run. Options: "light", "medium", "heavy".
        max_full_evals: The maximum number of full evaluations to perform.
        max_metric_calls: The maximum number of metric calls to perform.
        reflection_minibatch_size: The number of examples to use for reflection in a single GEPA step.
            Default is 3.
        candidate_selection_strategy: The strategy to use for candidate selection. Default is "pareto",
            which stochastically selects candidates from the Pareto frontier of all validation scores.
            Options: "pareto", "current_best".
        reflection_lm: The language model to use for reflection. Required parameter. GEPA benefits from
            a strong reflection model. Consider using `dspy.LM(model='gpt-5', temperature=1.0, max_tokens=32000)`
            for optimal performance.
        skip_perfect_score: Whether to skip examples with perfect scores during reflection. Default is True.
        instruction_proposer: Optional custom instruction proposer implementing GEPA's ProposalFn protocol.
            **Default: None (recommended for most users)** - Uses GEPA's proven instruction proposer from
            the [GEPA library](https://github.com/gepa-ai/gepa), which implements the
            [`ProposalFn`](https://github.com/gepa-ai/gepa/blob/main/src/gepa/core/adapter.py). This default
            proposer is highly capable and was validated across diverse experiments reported in the GEPA
            paper and tutorials. See documentation on custom instruction proposers
            [here](https://dspy.ai/api/optimizers/GEPA/GEPA_Advanced/#custom-instruction-proposers).

            **Advanced Feature**: Only needed for specialized scenarios:
            - **Multi-modal handling**: Processing dspy.Image inputs alongside textual information
            - **Nuanced control over constraints**: Fine-grained control over instruction length, format,
              and structural requirements beyond standard feedback mechanisms
            - **Domain-specific knowledge injection**: Specialized terminology or context that cannot be
              provided through feedback_func alone
            - **Provider-specific prompting**: Optimizations for specific LLM providers (OpenAI, Anthropic)
              with unique formatting preferences
            - **Coupled component updates**: Coordinated updates of multiple components together rather
              than independent optimization
            - **External knowledge integration**: Runtime access to databases, APIs, or knowledge bases

            The default proposer handles the vast majority of use cases effectively. Use
            MultiModalInstructionProposer() from dspy.teleprompt.gepa.instruction_proposal for visual
            content or implement custom ProposalFn for highly specialized requirements.

            Note: When both instruction_proposer and reflection_lm are set, the instruction_proposer is
            called in the reflection_lm context. However, reflection_lm is optional when using a custom
            instruction_proposer. Custom instruction proposers can invoke their own LLMs if needed.
        component_selector: Custom component selector implementing the ReflectionComponentSelector
            protocol, or a string specifying a built-in selector strategy. Controls which components
            (predictors) are selected for optimization at each iteration. Defaults to 'round_robin'
            strategy which cycles through components one at a time. Available string options:
            'round_robin' (cycles through components sequentially), 'all' (selects all components for
            simultaneous optimization). Custom selectors can implement strategies using LLM-driven
            selection logic based on optimization state and trajectories. See
            [gepa component selectors](https://github.com/gepa-ai/gepa/blob/main/src/gepa/strategies/component_selector.py)
            for available built-in selectors and the ReflectionComponentSelector protocol for implementing
            custom selectors.
        add_format_failure_as_feedback: Whether to add format failures as feedback. Default is False.
        use_merge: Whether to use merge-based optimization. Default is True.
        max_merge_invocations: The maximum number of merge invocations to perform. Default is 5.
        num_threads: The number of threads to use for evaluation with `Evaluate`. Optional.
        failure_score: The score to assign to failed examples. Default is 0.0.
        perfect_score: The maximum score achievable by the metric. Default is 1.0. Used by GEPA to
            determine if all examples in a minibatch are perfect.
        log_dir: The directory to save the logs. GEPA saves elaborate logs, along with all candidate
            programs, in this directory. Running GEPA with the same `log_dir` will resume the run from
            the last checkpoint.
        track_stats: Whether to return detailed results and all proposed programs in the
            `detailed_results` attribute of the optimized program. Default is False.
        use_wandb: Whether to use wandb for logging. Default is False.
        wandb_api_key: The API key to use for wandb. If not provided, wandb will use the API key from
            the environment variable `WANDB_API_KEY`.
        wandb_init_kwargs: Additional keyword arguments to pass to `wandb.init`.
        track_best_outputs: Whether to track the best outputs on the validation set. track_stats must
            be True if track_best_outputs is True. The optimized program's
            `detailed_results.best_outputs_valset` will contain the best outputs for each task in the
            validation set.
        warn_on_score_mismatch: GEPA (currently) expects the metric to return the same module-level
            score when called with and without the pred_name. This flag (defaults to True) determines
            whether a warning is raised if a mismatch in module-level and predictor-level score is
            detected.
        seed: The random seed to use for reproducibility. Default is 0.
        gepa_kwargs: (Optional) provide additional kwargs to be passed to
            [gepa.optimize](https://github.com/gepa-ai/gepa/blob/main/src/gepa/api.py) method

    Note:
        Budget Configuration: Exactly one of `auto`, `max_full_evals`, or `max_metric_calls` must be
        provided. The `auto` parameter provides preset configurations: "light" for quick experimentation,
        "medium" for balanced optimization, and "heavy" for thorough optimization.

        Reflection Configuration: The `reflection_lm` parameter is required and should be a strong
        language model. GEPA performs best with models like
        `dspy.LM(model='gpt-5', temperature=1.0, max_tokens=32000)`. The reflection process analyzes
        failed examples to generate feedback for program improvement.

        Merge Configuration: GEPA can merge successful program variants using `use_merge=True`. The
        `max_merge_invocations` parameter controls how many merge attempts are made during optimization.

        Evaluation Configuration: Use `num_threads` to parallelize evaluation. The `failure_score` and
        `perfect_score` parameters help GEPA understand your metric's range and optimize accordingly.

        Logging Configuration: Set `log_dir` to save detailed logs and enable checkpoint resuming. Use
        `track_stats=True` to access detailed optimization results via the `detailed_results` attribute.
        Enable `use_wandb=True` for experiment tracking and visualization.

        Reproducibility: Set `seed` to ensure consistent results across runs with the same configuration.
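    Example:
        A minimal end-to-end usage sketch (illustrative only; it assumes you have already defined a
        five-argument `metric` as described above, a `student` module, and `trainset`/`valset` lists
        of `dspy.Example` objects):

    ```
    reflection_lm = dspy.LM(model='gpt-5', temperature=1.0, max_tokens=32000)
    gepa = GEPA(metric=metric, auto="light", reflection_lm=reflection_lm, track_stats=True)
    optimized_program = gepa.compile(student, trainset=trainset, valset=valset)
    # With track_stats=True, per-candidate run data is attached as `detailed_results`.
    print(optimized_program.detailed_results.best_idx)
    ```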
""" def __init__( self, metric: GEPAFeedbackMetric, *, # Budget configuration auto: Literal["light", "medium", "heavy"] | None = None, max_full_evals: int | None = None, max_metric_calls: int | None = None, # Reflection configuration reflection_minibatch_size: int = 3, candidate_selection_strategy: Literal["pareto", "current_best"] = "pareto", reflection_lm: LM | None = None, skip_perfect_score: bool = True, add_format_failure_as_feedback: bool = False, instruction_proposer: "ProposalFn | None" = None, component_selector: "ReflectionComponentSelector | str" = "round_robin", # Merge-based configuration use_merge: bool = True, max_merge_invocations: int | None = 5, # Evaluation configuration num_threads: int | None = None, failure_score: float = 0.0, perfect_score: float = 1.0, # Logging log_dir: str = None, track_stats: bool = False, use_wandb: bool = False, wandb_api_key: str | None = None, wandb_init_kwargs: dict[str, Any] | None = None, track_best_outputs: bool = False, warn_on_score_mismatch: bool = True, use_mlflow: bool = False, # Reproducibility seed: int | None = 0, # GEPA passthrough kwargs gepa_kwargs: dict | None = None ): try: inspect.signature(metric).bind(None, None, None, None, None) except TypeError as e: raise TypeError( "GEPA metric must accept five arguments: (gold, pred, trace, pred_name, pred_trace). " "See https://dspy.ai/api/optimizers/GEPA for details." ) from e self.metric_fn = metric # Budget configuration assert ( (max_metric_calls is not None) + (max_full_evals is not None) + (auto is not None) == 1 ), ( "Exactly one of max_metric_calls, max_full_evals, auto must be set. " f"You set max_metric_calls={max_metric_calls}, " f"max_full_evals={max_full_evals}, " f"auto={auto}." ) self.auto = auto self.max_full_evals = max_full_evals self.max_metric_calls = max_metric_calls # Reflection configuration self.reflection_minibatch_size = reflection_minibatch_size self.candidate_selection_strategy = candidate_selection_strategy assert reflection_lm is not None or instruction_proposer is not None, ( "GEPA requires a reflection language model, or custom instruction proposer to be provided. " "Typically, you can use `dspy.LM(model='gpt-5', temperature=1.0, max_tokens=32000)` to get a good reflection model. " "Reflection LM is used by GEPA to reflect on the behavior of the program and propose new instructions, and will benefit from a strong model. " ) self.reflection_lm = reflection_lm self.skip_perfect_score = skip_perfect_score self.add_format_failure_as_feedback = add_format_failure_as_feedback # Merge-based configuration self.use_merge = use_merge self.max_merge_invocations = max_merge_invocations # Evaluation Configuration self.num_threads = num_threads self.failure_score = failure_score self.perfect_score = perfect_score # Logging configuration self.log_dir = log_dir self.track_stats = track_stats self.use_wandb = use_wandb self.wandb_api_key = wandb_api_key self.wandb_init_kwargs = wandb_init_kwargs self.warn_on_score_mismatch = warn_on_score_mismatch self.use_mlflow = use_mlflow if track_best_outputs: assert track_stats, "track_stats must be True if track_best_outputs is True." 
        self.track_best_outputs = track_best_outputs

        # Reproducibility
        self.seed = seed

        self.custom_instruction_proposer = instruction_proposer
        self.component_selector = component_selector
        self.gepa_kwargs = gepa_kwargs or {}

    def auto_budget(self, num_preds, num_candidates, valset_size: int, minibatch_size: int = 35, full_eval_steps: int = 5) -> int:
        import numpy as np

        num_trials = int(max(2 * (num_preds * 2) * np.log2(num_candidates), 1.5 * num_candidates))
        if num_trials < 0 or valset_size < 0 or minibatch_size < 0:
            raise ValueError("num_trials, valset_size, and minibatch_size must be >= 0.")
        if full_eval_steps < 1:
            raise ValueError("full_eval_steps must be >= 1.")

        V = valset_size
        N = num_trials
        M = minibatch_size
        m = full_eval_steps

        # Initial full evaluation on the default program
        total = V

        # Assume up to 5 trials for bootstrapping each candidate
        total += num_candidates * 5

        # N minibatch evaluations
        total += N * M

        if N == 0:
            return total  # no periodic/full evals inside the loop

        # Periodic full evals occur when trial_num % (m+1) == 0, where trial_num runs 2..N+1
        periodic_fulls = (N + 1) // (m) + 1
        # If 1 <= N < m, the code triggers one final full eval at the end
        extra_final = 1 if N < m else 0

        total += (periodic_fulls + extra_final) * V
        return total

    def compile(
        self,
        student: Module,
        *,
        trainset: list[Example],
        teacher: Module | None = None,
        valset: list[Example] | None = None,
    ) -> Module:
        """
        GEPA uses the trainset to perform reflective updates to the prompt, but uses the valset
        for tracking Pareto scores. If no valset is provided, GEPA will use the trainset for both.

        Parameters:
        - student: The student module to optimize.
        - trainset: The training set to use for reflective updates.
        - valset: The validation set to use for tracking Pareto scores. If not provided, GEPA will
          use the trainset for both.
        """
        from gepa import GEPAResult, optimize

        from dspy.teleprompt.gepa.gepa_utils import DspyAdapter, LoggerAdapter

        assert trainset is not None and len(trainset) > 0, "Trainset must be provided and non-empty"
        assert teacher is None, "Teacher is not supported in DspyGEPA yet."

        if self.auto is not None:
            self.max_metric_calls = self.auto_budget(
                num_preds=len(student.predictors()),
                num_candidates=AUTO_RUN_SETTINGS[self.auto]["n"],
                valset_size=len(valset) if valset is not None else len(trainset),
            )
        elif self.max_full_evals is not None:
            self.max_metric_calls = self.max_full_evals * (len(trainset) + (len(valset) if valset is not None else 0))
        else:
            assert self.max_metric_calls is not None, "Either auto, max_full_evals, or max_metric_calls must be set."

        logger.info(f"Running GEPA for approx {self.max_metric_calls} metric calls of the program. This amounts to {self.max_metric_calls / len(trainset) if valset is None else self.max_metric_calls / (len(trainset) + len(valset)):.2f} full evals on the {'train' if valset is None else 'train+val'} set.")

        if valset is None:
            logger.warning("No valset provided; Using trainset as valset. This is useful as an inference-time scaling strategy where you want GEPA to find the best solutions for the provided tasks in the trainset, as it makes GEPA overfit prompts to the provided trainset. In order to ensure generalization and perform well on unseen tasks, please provide separate trainset and valset. Provide the smallest valset that is just large enough to match the downstream task distribution, while keeping trainset as large as possible.")
        valset = valset or trainset

        logger.info(f"Using {len(valset)} examples for tracking Pareto scores.
You can consider using a smaller sample of the valset to allow GEPA to explore more diverse solutions within the same budget. GEPA requires you to provide the smallest valset that is just large enough to match your downstream task distribution, while providing as large trainset as possible.") rng = random.Random(self.seed) def feedback_fn_creator(pred_name: str, predictor) -> "PredictorFeedbackFn": def feedback_fn( predictor_output: dict[str, Any], predictor_inputs: dict[str, Any], module_inputs: Example, module_outputs: Prediction, captured_trace: "DSPyTrace", ) -> "ScoreWithFeedback": trace_for_pred = [(predictor, predictor_inputs, predictor_output)] o = self.metric_fn( module_inputs, module_outputs, captured_trace, pred_name, trace_for_pred, ) if hasattr(o, "feedback"): if o["feedback"] is None: o["feedback"] = f"This trajectory got a score of {o['score']}." return o else: return dict(score=o, feedback=f"This trajectory got a score of {o}.") return feedback_fn feedback_map = { k: feedback_fn_creator(k, v) for k, v in student.named_predictors() } # Build the DSPy adapter that encapsulates evaluation, trace capture, feedback extraction, and instruction proposal adapter = DspyAdapter( student_module=student, metric_fn=self.metric_fn, feedback_map=feedback_map, failure_score=self.failure_score, num_threads=self.num_threads, add_format_failure_as_feedback=self.add_format_failure_as_feedback, rng=rng, reflection_lm=self.reflection_lm, custom_instruction_proposer=self.custom_instruction_proposer, warn_on_score_mismatch=self.warn_on_score_mismatch ) # Instantiate GEPA with the simpler adapter-based API base_program = {name: pred.signature.instructions for name, pred in student.named_predictors()} gepa_result: GEPAResult = optimize( seed_candidate=base_program, trainset=trainset, valset=valset, adapter=adapter, # Reflection-based configuration reflection_lm=(lambda x: self.reflection_lm(x)[0]) if self.reflection_lm is not None else None, candidate_selection_strategy=self.candidate_selection_strategy, skip_perfect_score=self.skip_perfect_score, reflection_minibatch_size=self.reflection_minibatch_size, module_selector=self.component_selector, perfect_score=self.perfect_score, # Merge-based configuration use_merge=self.use_merge, max_merge_invocations=self.max_merge_invocations, # Budget max_metric_calls=self.max_metric_calls, # Logging logger=LoggerAdapter(logger), run_dir=self.log_dir, use_wandb=self.use_wandb, wandb_api_key=self.wandb_api_key, wandb_init_kwargs=self.wandb_init_kwargs, use_mlflow=self.use_mlflow, track_best_outputs=self.track_best_outputs, display_progress_bar=True, raise_on_exception=True, # Reproducibility seed=self.seed, **self.gepa_kwargs ) new_prog = adapter.build_program(gepa_result.best_candidate) if self.track_stats: dspy_gepa_result = DspyGEPAResult.from_gepa_result(gepa_result, adapter) new_prog.detailed_results = dspy_gepa_result return new_prog ``` -------------------------------------------------------------------------------- /tests/reliability/generate/utils.py: -------------------------------------------------------------------------------- ```python import importlib.util import json import os import pathlib import random import re import shutil import sys import tempfile from contextlib import contextmanager from dataclasses import dataclass from functools import wraps from typing import Any, Dict, List, Optional, Tuple import pydantic from datamodel_code_generator import InputFileType, generate import dspy from tests.reliability.utils import 
assert_program_output_correct, judge_dspy_configuration def _retry(retries): """ A decorator to retry a function a specified number of times. Args: retries (int): The number of retries before failing. """ def decorator(func): @wraps(func) def wrapper(*args, **kwargs): attempt = 0 while attempt < retries: try: return func(*args, **kwargs) except Exception as e: attempt += 1 print(f"Retrying {func.__name__} (attempt {attempt} of {retries})." f" Exception: {e}") if attempt >= retries: raise e return wrapper return decorator @_retry(retries=5) def generate_test_program(dst_path: str, additional_instructions: Optional[str] = None) -> dspy.Module: """ Generate a DSPy program for a reliability test case and save it to a destination path. Args: dst_path: The directory path to which to save the generated program. additional_instructions: Additional instructions for generating the program signature. Return: A dspy.Module object representing the generated program. """ def generate_models(schema: dict[str, Any], class_name: str) -> str: with tempfile.TemporaryDirectory() as tmp_dir: tmp_schema_path = os.path.join(tmp_dir, "schema.json") tmp_model_path = os.path.join(tmp_dir, "model.py") with open(tmp_schema_path, "w") as f: json.dump(schema, f) generate( input_=pathlib.Path(tmp_schema_path), input_file_type=InputFileType.JsonSchema, output=pathlib.Path(tmp_model_path), class_name=class_name, # For enums with only one value, use the value as a literal instead of an enum # in order to test literals enum_field_as_literal="one", # Don't use con* field types, which are deprecated in recent pydantic versions field_constraints=True, use_annotated=False, ) # Remove annotation imports from __future__, which break compatibility with Python's # built-in type hints _remove_line_from_file(tmp_model_path, "from __future__ import annotations") # Remove comments inserted by datamodel-code-generator from the generated model file _remove_comments_from_file(tmp_model_path) with open(tmp_model_path, "r") as f: return f.read() def rename_conflicting_fields( input_schema: dict[str, Any], output_schema: dict[str, Any], ) -> dict[str, Any]: input_fields = set(input_schema.get("properties", {})) output_schema["properties"] = { (f"{field}_output" if field in input_fields else field): properties for field, properties in output_schema.get("properties", {}).items() } # Update required fields, if they exist if "required" in output_schema: output_schema["required"] = [ f"{field}_output" if field in input_fields else field for field in output_schema["required"] ] return output_schema # Disable caching and use a nonzero temperature to ensure that new programs are generated # upon retry if there's an error in the generation process (e.g. 
the program has an # invalid signature) with judge_dspy_configuration(cache=False, temperature=0.5), tempfile.TemporaryDirectory() as tmp_dir: generated_signature = _get_test_program_generation_program()( additional_instructions=additional_instructions or "" ) input_schema = json.loads(generated_signature.program_input_fields) output_schema = json.loads(generated_signature.program_output_fields) # If there are conflicting field names between input and output schemas, rename the output # fields to avoid conflicts output_schema = rename_conflicting_fields(input_schema, output_schema) # Generate input and output models input_models = generate_models(schema=input_schema, class_name="ProgramInputs") output_models = generate_models(schema=output_schema, class_name="ProgramOutputs") # Write program code program_code = ( "### Input models ###\n" + input_models + "\n" + "### Output models ###\n" + output_models + "\n" + "### Program definition ###\n" + _get_test_program_signature_and_module_definition( program_description=generated_signature.program_description ) ) program_path = os.path.join(tmp_dir, "program.py") with open(program_path, "w") as f: f.write(program_code) # Validate the generated program by loading it before copying it to the destination path loaded_program, _ = load_generated_program(program_path) # Write schema _write_pretty_json( data=_clean_schema(_get_json_schema(loaded_program.signature)), path=os.path.join(tmp_dir, "schema.json"), ) # Copy all generated files to the destination path os.makedirs(dst_path, exist_ok=True) shutil.copytree(tmp_dir, dst_path, dirs_exist_ok=True) return loaded_program @_retry(retries=5) def generate_test_inputs( dst_path: str, program_path: str, num_inputs: int, additional_instructions: Optional[str] = None, ): """ Generate test inputs for a reliability test case and save them to a destination path. Args: dst_path: The directory path to which to save the generated test inputs. program_path: The path to the program for which to generate test inputs. num_inputs: The number of test inputs to generate. additional_instructions: Additional instructions for generating the test inputs. """ # Disable caching and use a nonzero temperature to ensure that new inputs are generated # upon retry if there's an error in the generation process (e.g. 
the input doesn't match the # program signature) with judge_dspy_configuration(cache=False, temperature=0.5), tempfile.TemporaryDirectory() as tmp_dir: program: dspy.Module program_input_schema: pydantic.BaseModel program, program_input_schema = load_generated_program(program_path) signature_json_schema = _get_json_schema(program.signature) inputs, outputs = _split_schema(signature_json_schema) generated_test_inputs = _get_test_inputs_generation_program()( program_description=program.signature.__doc__ or "", program_input_signature=_write_pretty_json({"properties": _clean_schema(inputs)}), program_output_signature=_write_pretty_json({"properties": _clean_schema(outputs)}), additional_instructions=additional_instructions or "", num_inputs=num_inputs, ).test_inputs[:num_inputs] def find_max_input_number(directory): if not os.path.exists(directory): return 0 max_number = 0 pattern = re.compile(r"input(\d+)\.json") for filename in os.listdir(directory): match = pattern.match(filename) if match: number = int(match.group(1)) max_number = max(max_number, number) return max_number base_input_number = find_max_input_number(dst_path) + 1 for idx, test_input in enumerate(generated_test_inputs): output_assertions = _get_assertions_generation_program()( program_description=program.signature.__doc__ or "", program_input=test_input.program_input, program_output_signature=_write_pretty_json({"properties": _clean_schema(outputs)}), ).output_assertions # Verify that the generated input is valid JSON and matches the input signature of the # program before saving it to the destination path _json_input_to_program_input( input_schema=program_input_schema, json_input=test_input.program_input, ) test_input_file_path = os.path.join(tmp_dir, f"input{base_input_number + idx}.json") json_program_input = json.loads(test_input.program_input) _write_pretty_json( data={ "input": json_program_input, "assertions": output_assertions, }, path=test_input_file_path, ) os.makedirs(dst_path, exist_ok=True) shutil.copytree(tmp_dir, dst_path, dirs_exist_ok=True) def load_generated_program(path) -> Tuple[dspy.Module, pydantic.BaseModel]: """ Loads a generated program from the specified file. Args: path: The path to the file containing the generated program. Returns: A tuple containing: 1. a dspy.Module object representing the generated program and 2. a pydantic.BaseModel object representing the program's input schema. """ if os.path.isdir(path): path = os.path.join(path, "program.py") if not os.path.exists(path): raise ValueError(f"DSPy test program file not found: {path}") program_module = _import_program_module_from_path(module_name="program", file_path=path) return program_module.program, program_module.ProgramInputs @dataclass class GeneratedTestCase: """ Represents a DSPy reliability test case that has been generated with the help of a DSPy program generator and program input generator. """ # The name of the test case for identification / debugging with pytest name: str # The local filesystem path to the program that the test case is testing. program_path: str # A JSON representation of the input to the program that the test case is testing. program_input: str # The assertions that the output of the program must satisfy for the test case to pass. output_assertions: list[str] def load_generated_cases(dir_path) -> list[GeneratedTestCase]: """ Recursively loads generated test cases from the specified directory and its subdirectories. Args: dir_path: The path to the directory containing the generated test cases. 
Returns: A list of GeneratedTestCase objects. """ test_cases = [] # Walk through all directories and subdirectories in dir_path for root, dirs, files in os.walk(dir_path): # Check if the directory contains a program.py and an inputs directory if "program.py" in files and "inputs" in dirs: program_path = os.path.join(root, "program.py") inputs_path = os.path.join(root, "inputs") # Load each JSON test input file in the inputs directory for input_file in os.listdir(inputs_path): if input_file.endswith(".json"): with open(os.path.join(inputs_path, input_file), "r") as f: # Best effort to extract a meaningful enclosing directory name # from the test path that can be used as part of the test case name readable_dir_name = os.path.basename(os.path.dirname(os.path.dirname(root))) test_case_name = ( f"{readable_dir_name}-" f"{os.path.basename(root)}-" f"{os.path.splitext(input_file)[0]}" ) program_input_and_assertions = json.load(f) program_input = program_input_and_assertions["input"] assertions = program_input_and_assertions["assertions"] # Create a GeneratedTestCase object and add it to the list test_cases.append( GeneratedTestCase( name=test_case_name, program_path=program_path, program_input=json.dumps(program_input), output_assertions=assertions, ) ) return test_cases def run_generated_case(generated_case: GeneratedTestCase): """ Runs a generated reliability test case by 1. running the test case program on the test case input using the global DSPy configuration and 2. verifying that the output of the program satisfies the assertions specified in the test case. Args: generated_case: The generated test case to run. """ program, program_input_schema = load_generated_program(generated_case.program_path) program_input = _json_input_to_program_input( input_schema=program_input_schema, json_input=generated_case.program_input, ) program_output = program(**program_input) for assertion in generated_case.output_assertions: assert_program_output_correct( program_input=program_input, program_output=program_output, grading_guidelines=assertion, ) def _get_test_program_signature_and_module_definition(program_description: str) -> str: """ Generate the signature and model definition for a test DSPy program. Args: program_description: A description of the generated program. """ use_cot = random.choice([True, False]) if use_cot: program_var_definition = "program = dspy.ChainOfThought(program_signature)" else: program_var_definition = "program = dspy.Predict(program_signature)" return ''' import dspy class BaseSignature(dspy.Signature): """ {program_description} """ program_signature = BaseSignature for input_field_name, input_field in ProgramInputs.model_fields.items(): program_signature = program_signature.append( name=input_field_name, field=dspy.InputField(description=input_field.description), type_=input_field.annotation, ) for output_field_name, output_field in ProgramOutputs.model_fields.items(): program_signature = program_signature.append( name=output_field_name, field=dspy.OutputField(description=input_field.description), type_=output_field.annotation, ) {program_var_definition} '''.format(program_description=program_description, program_var_definition=program_var_definition) def _get_test_program_generation_program() -> dspy.Module: """ Create a DSPy program for generating other DSPy test programs. Returns: A dspy.Module object representing the program generation program. 
""" class ProgramGeneration(dspy.Signature): """ Creates an AI program definition, including the AI program's description, input fields, and output fields. The AI program should be designed to solve a real problem for its users and produce correct outputs for a variety of inputs. The input fields and the output fields must be represented in JSON Schema format, including field names, types, and descriptions. The JSON schema definitions themselves MUST be valid JSON without any extra text (no backticks, no explanatory text, etc.). It's very important to be sure that the additional instructions, if specified, are obeyed precisely in absolutely all cases. """ additional_instructions: str = dspy.InputField( description="Additional instructions for what kind of program to generate and how to generate it" ) program_description: str = dspy.OutputField( description="A description of the generated AI program, including its purpose and expected behavior" ) program_input_fields: str = dspy.OutputField( description="The input fields of the generated program in JSON Schema format, including input field names, types, and descriptions." ) program_output_fields: str = dspy.OutputField( description="The output fields of the generated program in JSON Schema format, including input field names, types, and descriptions." ) return dspy.ChainOfThought(ProgramGeneration) def _get_test_inputs_generation_program() -> dspy.Module: """ Create a DSPy program for generating test inputs for a given DSPy test program. Returns: A dspy.Module object representing the test input generation program. """ class _TestInputsGeneration(dspy.Signature): """ Given the description and input / output signature (format) of an AI program that is designed to produce correct outputs for a variety of inputs while adhering to the input / output signature, generate test inputs used to verify that the program indeed produces correct outputs. The AI program uses LLM prompting with carefully crafted prompt templates to generate responses. When generating an input, do not think about how the program will respond. Instead, focus on creating valid and interesting inputs that are likely to test the program's capabilities. It's very important to be sure that the additional instructions, if specified, are obeyed precisely in absolutely all cases. """ program_description: str = dspy.InputField( description="A description of the AI program being tested, including its purpose and expected behavior" ) program_input_signature: str = dspy.InputField( description="The input signature of the program in JSON Schema format, including input field names, types, and descriptions. The outermost fields in the JSON schema definition represent the top-level input fields of the program." ) program_output_signature: str = dspy.InputField( description="The output signature of the program in JSON Schema format, including output field names, types, and descriptions. The outermost fields in the JSON schema definition represent the top-level output fields of the program." ) additional_instructions: str = dspy.InputField(description="Additional instructions for generating test inputs") test_inputs: list[_TestInput] = dspy.OutputField( description="Generated test inputs for the program, used to verify the correctness of the program outputs for a variety of inputs" ) return dspy.ChainOfThought(_TestInputsGeneration) class _TestInput(pydantic.BaseModel): """ Represents a generated test input for a DSPy program. 
""" program_input: str = pydantic.Field( "Generated input matching the program signature that will be used to test the program, represented as a JSON string." " The schema of the JSON string must match the input signature of the program precisely, including any wrapper objects." " Be very careful to ensure that the input is valid JSON and matches the input signature of the program, with correct" " field nesting." ) def _get_assertions_generation_program() -> dspy.Module: """ Create a DSPy program for generating assertions that verify the correctness of outputs from other DSPy programs. """ class _TestInputsGeneration(dspy.Signature): """ Given 1. the description and input / output signature (format) of an AI program that is designed to produce correct outputs for a variety of inputs while adhering to the input / output signature and 2. an example input to the AI program, generate assertions that can be used to verify the correctness of the program output. Assertions should be expressed in natural language where possible, rather than code. Only include code if necessary to clarify the assertion. Assertions should be objective and verifiable, with minimal subjectivity only where absolutely necessary. There should be a limited number of assertions, ideally about 5, that are sufficient to verify the correctness of the program output. If it's too difficult to generate accurate assertions, leave them blank. """ program_description: str = dspy.InputField( description="A description of the AI program being tested, including its purpose and expected behavior" ) program_input: str = dspy.InputField( description="An example input to the AI program, represented as a JSON string" ) program_output_signature: str = dspy.InputField( description="The output signature of the program in JSON Schema format, including output field names, types, and descriptions. The outermost fields in the JSON schema definition represent the top-level output fields of the program." ) output_assertions: list[str] = dspy.OutputField( description="Assertions used to verify the correctness of the program output after running the program on the specified input" ) return dspy.ChainOfThought(_TestInputsGeneration) def _clean_json_schema_property(prop: dict[str, Any]) -> dict[str, Any]: """ Remove unnecessary keys from a JSON schema property dictionary, as well as all of its child properties. Args: prop: The JSON schema property dictionary to clean. Returns: The cleaned JSON schema property dictionary. """ cleaned_prop = { k: v for k, v in prop.items() if k not in {"desc", "__dspy_field_type", "title", "prefix", "required"} } # Recursively clean nested properties if "properties" in cleaned_prop: cleaned_prop["properties"] = {k: _clean_json_schema_property(v) for k, v in cleaned_prop["properties"].items()} return cleaned_prop def _get_json_schema(signature: dspy.Signature) -> dict[str, Any]: """ Obtain the JSON schema representation of a DSPy signature. Args: signature: The DSPy signature for which to generate a JSON schema. Returns: A JSON schema representation of the signature. """ def expand_refs(schema: dict[str, Any], definitions: dict[str, Any]) -> dict[str, Any]: """ Expand $ref fields in a JSON schema, inlining the referenced schema definitions directly into the $ref field locations. 
""" if isinstance(schema, dict): if "$ref" in schema: ref_path = schema["$ref"].replace("#/$defs/", "") ref_schema = definitions.get(ref_path, {}) if "__dspy_field_type" in schema: ref_schema["__dspy_field_type"] = schema["__dspy_field_type"] # Recursively expand the reference schema as well return expand_refs(ref_schema, definitions) else: # Recursively expand properties in the schema return {key: expand_refs(value, definitions) for key, value in schema.items()} elif isinstance(schema, list): return [expand_refs(item, definitions) for item in schema] return schema signature_schema_with_refs = signature.schema() definitions = signature_schema_with_refs.pop("$defs", {}) return expand_refs(signature_schema_with_refs, definitions) def _split_schema(schema: dict[str, Any]) -> Tuple[dict[str, Any], dict[str, Any]]: """ Split a JSON schema into input and output components based on DSPy field types. Args: schema: The JSON schema to split. Returns: A tuple containing the input and output components of the schema. """ inputs = {} outputs = {} # Traverse the properties to categorize inputs and outputs for key, prop in schema.get("properties", {}).items(): # Clean the property cleaned_prop = _clean_schema(prop) # Determine if the property is input or output based on __dspy_field_type field_type = prop.get("__dspy_field_type") if field_type == "input": inputs[key] = cleaned_prop elif field_type == "output" or field_type is None: outputs[key] = cleaned_prop # Handle nested properties for complex models if "properties" in prop: nested_inputs, nested_outputs = _split_schema(prop) if nested_inputs and field_type == "input": inputs[key] = {"properties": nested_inputs, **cleaned_prop} elif nested_outputs and (field_type == "output" or field_type is None): outputs[key] = {"properties": nested_outputs, **cleaned_prop} return inputs, outputs def _clean_schema(prop: dict[str, Any]) -> dict[str, Any]: """ Recursively clean a JSON schema property by removing unnecessary keys. Args: prop: The JSON schema property to clean. Returns: A cleaned version of the property. """ keys_to_remove = ["__dspy_field_type", "title"] # Add any other keys to be removed here # Iterate through the dictionary, applying cleaning recursively if value is a nested dict cleaned_prop = { k: (_clean_schema(v) if isinstance(v, dict) else v) # Recurse if value is a dict for k, v in prop.items() if k not in keys_to_remove } return cleaned_prop def _json_input_to_program_input(input_schema: pydantic.BaseModel, json_input: str) -> dict[str, Any]: """ Convert a JSON input string to a DSPy program input dictionary, validating it against the provided program signature. Args: input_schema: A pydantic model representing the program input schema. json_input: The JSON input string to convert to a DSPy program input. Returns: The converted DSPy program input dictionary. """ json_input = json.loads(json_input) program_input: pydantic.BaseModel = input_schema.model_validate(json_input) return {field: getattr(program_input, field) for field in program_input.__fields__} @contextmanager def _temporarily_prepend_to_system_path(path): """ Temporarily prepend a path to the system path for the duration of a context. Args: path: The path to prepend to the system path. """ original_sys_path = sys.path.copy() try: sys.path.insert(0, path) yield finally: sys.path = original_sys_path def _import_program_module_from_path(module_name: str, file_path: str): """ Import a Python module containing a DSPy program from a specified file path. 
    Args:
        module_name: The name of the module containing the DSPy program to import.
        file_path: The path to the file containing the module definition.
    """
    program_dir = os.path.dirname(file_path)
    with _temporarily_prepend_to_system_path(program_dir):
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module


def _remove_line_from_file(file_path: str, line_to_remove: str):
    """
    Remove all instances of a specific line from a file.

    Args:
        file_path: The path to the file from which to remove all instances of the line.
        line_to_remove: The line to remove from the file.
    """
    # Read all lines from the file
    with open(file_path, "r") as file:
        lines = file.readlines()

    # Write all lines back except the one to remove
    with open(file_path, "w") as file:
        for line in lines:
            if line.strip() != line_to_remove:
                file.write(line)


def _remove_comments_from_file(file_path: str) -> None:
    """
    Removes all lines with comments from the specified file.

    Args:
        file_path: Path to the file where comments should be removed.
    """
    # Read the file contents
    with open(file_path, "r") as file:
        lines = file.readlines()

    # Filter out lines that start with '#'
    cleaned_lines = [line for line in lines if not line.strip().startswith("#")]

    # Write the cleaned lines back to the file
    with open(file_path, "w") as file:
        file.writelines(cleaned_lines)


def _write_pretty_json(data: dict[str, Any], path: Optional[str] = None) -> Optional[str]:
    """
    Format JSON data with indentation, and write it to a file if specified.

    Args:
        data: The JSON data to format.
        path: The optional path to which to write the formatted JSON data.

    Returns:
        The formatted JSON data as a string, if no path is specified.
    """
    formatted_json = json.dumps(data, indent=4)
    if path:
        with open(path, "w") as f:
            f.write(formatted_json)
        return None
    else:
        return formatted_json
```

--------------------------------------------------------------------------------
/tests/teleprompt/test_gepa.py:
--------------------------------------------------------------------------------

```python
import json
import threading
from typing import Any
from unittest import mock

import pytest

import dspy
import dspy.clients
from dspy import Example
from dspy.predict import Predict
from dspy.teleprompt.gepa import instruction_proposal
from dspy.utils.dummies import DummyLM


class SimpleModule(dspy.Module):
    def __init__(self, signature):
        super().__init__()
        self.predictor = Predict(signature)

    def forward(self, **kwargs):
        return self.predictor(**kwargs)


class DictDummyLM(dspy.clients.lm.LM):
    def __init__(self, history):
        super().__init__("dummy", "chat", 0.0, 1000, True)
        self.history = {}
        for m in history:
            self.history[hash(repr(m["messages"]))] = m

    def __call__(self, prompt=None, messages=None, **kwargs):
        assert hash(repr(messages)) in self.history, f"Message {messages} not found in history"
        m = self.history[hash(repr(messages))]
        return m["outputs"]


def simple_metric(example, prediction, trace=None, pred_name=None, pred_trace=None):
    return dspy.Prediction(score=example.output == prediction.output, feedback="Wrong answer.")


def bad_metric(example, prediction):
    return 0.0


def test_gepa_adapter_disables_logging_during_trace_capture(monkeypatch):
    from dspy.teleprompt import bootstrap_trace as bootstrap_trace_module
    from dspy.teleprompt.gepa import gepa_utils

    class DummyModule(dspy.Module):
        def forward(self, **kwargs):  # pragma: no cover - stub forward
            return dspy.Prediction()

    # Exercise the adapter evaluate path directly.
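    # Note added for clarity: DspyAdapter (from dspy.teleprompt.gepa.gepa_utils) is the
    # bridge between dspy.GEPA and the GEPA engine -- it wraps the student module, the
    # metric, and per-predictor feedback functions so candidates can be evaluated and
    # traces captured. An empty feedback_map is sufficient in this test because
    # bootstrap_trace_data and build_program are monkeypatched below.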
adapter = gepa_utils.DspyAdapter( student_module=SimpleModule("input -> output"), metric_fn=simple_metric, feedback_map={}, failure_score=0.0, ) captured_kwargs: dict[str, Any] = {} def dummy_bootstrap_trace_data(*args, **kwargs): captured_kwargs.update(kwargs) return [] monkeypatch.setattr(bootstrap_trace_module, "bootstrap_trace_data", dummy_bootstrap_trace_data) monkeypatch.setattr( gepa_utils.DspyAdapter, "build_program", lambda self, candidate: DummyModule(), ) adapter.evaluate(batch=[], candidate={}, capture_traces=True) assert captured_kwargs["callback_metadata"] == {"disable_logging": True} @pytest.fixture def mock_mlflow(): mock_mlflow = mock.MagicMock() with mock.patch.dict("sys.modules", {"mlflow": mock_mlflow}): yield mock_mlflow @pytest.mark.parametrize("use_mlflow", [True, False]) def test_basic_workflow(use_mlflow, mock_mlflow): """Test to ensure the basic compile flow runs without errors.""" student = SimpleModule("input -> output") with open("tests/teleprompt/gepa_dummy_lm.json") as f: data = json.load(f) lm_history = data["lm"] reflection_lm_history = data["reflection_lm"] lm_main = DictDummyLM(lm_history) dspy.settings.configure(lm=lm_main) reflection_lm = DictDummyLM(reflection_lm_history) optimizer = dspy.GEPA( metric=simple_metric, reflection_lm=reflection_lm, max_metric_calls=5, use_mlflow=use_mlflow ) trainset = [ Example(input="What is the color of the sky?", output="blue").with_inputs("input"), Example(input="What does the fox say?", output="Ring-ding-ding-ding-dingeringeding!").with_inputs("input"), ] optimized_program = optimizer.compile(student, trainset=trainset, valset=trainset) assert optimized_program.predictor.signature.instructions == 'Given the field `input` containing a question or phrase, produce the field `output` containing the exact, direct, and contextually appropriate answer or response that the user expects, without additional explanations, commentary, or general knowledge unless explicitly requested.\n\nKey details and guidelines:\n\n1. The `input` field contains a question or phrase that may be literal, factual, or culturally specific (e.g., references to popular culture or memes).\n\n2. The `output` must be the precise answer or response that directly addresses the `input` as intended by the user, not a general or encyclopedic explanation.\n\n3. If the `input` is a well-known phrase or question from popular culture (e.g., "What does the fox say?"), the `output` should reflect the expected or canonical answer associated with that phrase, rather than a factual or scientific explanation.\n\n4. Avoid providing additional background information, scientific explanations, or alternative interpretations unless explicitly requested.\n\n5. The goal is to produce the answer that the user expects or the "correct" answer in the context of the question, including culturally recognized or meme-based answers.\n\n6. If the `input` is a straightforward factual question (e.g., "What is the color of the sky?"), provide the commonly accepted direct answer (e.g., "Blue") rather than a detailed scientific explanation.\n\n7. 
The output should be concise, clear, and focused solely on answering the question or phrase in the `input`.\n\nExample:\n\n- Input: "What is the color of the sky?"\n- Output: "Blue."\n\n- Input: "What does the fox say?"\n- Output: "Ring-ding-ding-ding-dingeringeding!"\n\nThis approach ensures that the assistant provides the expected, contextually appropriate answers rather than general or overly detailed responses that may be considered incorrect by the user.' if use_mlflow: assert mock_mlflow.start_run.call_count == 1 else: assert mock_mlflow.start_run.call_count == 0 def test_workflow_with_custom_instruction_proposer_and_component_selector(): """Test to ensure the basic compile flow runs without errors when using a custom instruction proposer and component selector.""" class TimeReader(dspy.Module): def __init__(self): super().__init__() self.hour_predictor = dspy.ChainOfThought("clock_photo: dspy.Image -> hour: int") self.minute_predictor = dspy.ChainOfThought("clock_photo: dspy.Image -> minute: int") self.parallel = dspy.Parallel(num_threads=2) def forward(self, clock_photo: dspy.Image): hour_prediction, minute_prediction = self.parallel( [ (self.hour_predictor, dict(clock_photo=clock_photo)), (self.minute_predictor, dict(clock_photo=clock_photo)), ] ) return dspy.Prediction(hour=hour_prediction.hour, minute=minute_prediction.minute) def metric(example, prediction, trace=None, pred_name=None, pred_trace=None): target_hour, target_minute = example.hour, example.minute predicted_hour, predicted_minute = prediction.hour, prediction.minute score = -abs(target_hour * 60 + target_minute - (predicted_hour * 60 + predicted_minute)) return dspy.Prediction( score=score, feedback=f"Target: {target_hour}:{target_minute}, Predicted: {predicted_hour}:{predicted_minute}", ) def all_component_selector(state, trajectories, subsample_scores, candidate_idx, candidate): """Select all components.""" return list(candidate.keys()) student = TimeReader() with open("tests/teleprompt/gepa_dummy_lm_custom_component_selector_custom_instruction_proposer.json") as f: data = json.load(f) lm_history = data["lm"] reflection_lm_history = data["reflection_lm"] lm_main = DictDummyLM(lm_history) reflection_lm = DictDummyLM(reflection_lm_history) dspy.settings.configure(lm=lm_main) optimizer = dspy.GEPA( metric=metric, reflection_lm=reflection_lm, max_metric_calls=5, instruction_proposer=instruction_proposal.MultiModalInstructionProposer(), component_selector=all_component_selector, num_threads=16, ) trainset = [ Example( clock_photo=dspy.Image( "https://upload.wikimedia.org/wikipedia/commons/thumb/c/cf/Pendulum_clock_by_Jacob_Kock%2C_antique_furniture_photography%2C_IMG_0931_edit.jpg/500px-Pendulum_clock_by_Jacob_Kock%2C_antique_furniture_photography%2C_IMG_0931_edit.jpg", download=False, ), hour=8, minute=18, ).with_inputs("clock_photo"), Example( clock_photo=dspy.Image( "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a5/Telechron_clock_2H07-Br_Administrator.JPG/960px-Telechron_clock_2H07-Br_Administrator.JPG", download=False, ), hour=4, minute=16, ).with_inputs("clock_photo"), ] o = optimizer.compile(student, trainset=trainset, valset=trainset) assert o.hour_predictor.predict.signature.instructions == "Task\n- Input: clock_photo (an image of an analog clock)\n- Output: hour (an integer 1\u201312). 
Output only the hour number with no extra text.\n\nGoal\n- Determine the correct hour by accurately identifying the hour hand and its position relative to the hour marks, taking into account the minute hand\u2019s position (since the hour hand moves continuously between numbers).\n\nStep-by-step procedure\n1) Find the dial and pivot\n- Locate the clock face and the central pivot where all hands originate.\n- Ignore decorative elements that do not originate at the central pivot (e.g., ornaments, shadows, reflections).\n\n2) Determine the 12 o\u2019clock direction\n- Prefer the numeral \u201c12\u201d if visible. Otherwise use the upright orientation of numerals or the topmost marker.\n- If the photo is rotated, mentally rotate so numerals read upright: 12 at top, 3 right, 6 bottom, 9 left.\n\n3) Identify the hands correctly (do not assume a default \u201c10:10\u201d)\n- Second hand: thinnest, often with a counterweight, may span very long; ignore for the hour.\n- Minute hand: longest, usually reaches or nearly reaches the outer minute tick marks.\n- Hour hand: shortest, usually thicker, typically ends well inside the numerals.\n- If ambiguous, classify by tip distance from center: minute \u2265 hour. Use the piece actually anchored at the pivot, not its shadow.\n\n4) Measure positions (angles)\n- Measure each hand\u2019s angle clockwise from 12 o\u2019clock.\n- Minute angle \u03b8m \u2248 position of the minute hand; hour angle \u03b8h \u2248 position of the hour hand.\n\n5) Use minute-hand position to validate the hour-hand location\n- The hour hand advances 0.5\u00b0 per minute (i.e., 1/12 of the distance between hour marks every 5 minutes).\n- Sanity check examples:\n - ~15 minutes past: hour hand \u2248 1/4 of the way from the current hour toward the next.\n - ~30 minutes: \u2248 halfway.\n - ~45 minutes: \u2248 3/4 of the way.\n- If this relationship doesn\u2019t hold, you likely swapped hour and minute hands\u2014re-identify them.\n\n6) Determine the hour\n- Compute the \u201clast passed\u201d hour: H = floor((\u03b8h mod 360) / 30). Map 0 to 12 (i.e., if floor(...) = 0, H = 12).\n- Do not round up to the next hour. The correct hour is the number the hour hand has most recently passed, not the one it is approaching.\n- If the hour hand appears exactly on an hour mark but the minute hand is not at 12, treat it as still between hours and choose the lower (last passed) hour.\n\n7) Edge cases and robustness\n- Stylized or missing numerals: rely on the 12/3/6/9 axes and tick marks rather than numeral shapes.\n- Roman numerals: \u201c4\u201d may be IIII; positions are unchanged.\n- Ignore mirrored effects, reflections, and shadows; only consider hands anchored at the pivot.\n- Overlap times: if hands nearly overlap, use \u03b8m to ensure the hour hand offset matches 0.5\u00b0 per minute.\n- Return 12, not 0, when appropriate (e.g., just after 12:00).\n\nOutput format\n- Provide only: hour as an integer in [1,12], with no additional text.\n\nCommon error prevention (from prior mistakes)\n- Do not confuse the minute hand for the hour hand; verify by length and reach to the outer tick marks.\n- Do not infer times like \u201c10:10\u201d by default; always read from the actual hand angles.\n- Ensure the hour chosen matches the \u201clast passed\u201d number given the minute hand\u2019s position (e.g., at ~:16, the hour hand must be just past the hour, not near 1 when the minute hand is at 3)." 
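    # The hour predictor's optimized instructions are checked above; the assertion
    # below covers the minute predictor, which is updated in the same GEPA step
    # because the custom "all components" selector proposes new instructions for
    # every predictor at once.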
assert o.minute_predictor.predict.signature.instructions == "Task: From the image field clock_photo (an analog clock), output the minute value as an integer from 0\u201359 in the field minute. Output only the minute number\u2014no text or other fields.\n\nWhat to analyze\n- Clock face orientation: Identify where \u201c12\u201d is on the dial. Use the numerals (Arabic or Roman, stylized fonts) or the positions of 3, 6, 9, 12 to set the reference. If the photo is tilted, measure angles relative to the clock face, not the image frame.\n- Hands identification (do not confuse them):\n - Minute hand: typically the longest solid hand reaching near the minute ticks/outer ring; thicker than the second hand; often has a pronounced pointer tip.\n - Hour hand: shorter and thicker, typically ends near the numerals.\n - Second hand (if present): the thinnest, often the longest, usually with a counterweight; ignore it for minute reading.\n - If two non-second hands look similar, the one whose tip reaches closer to the minute tick ring is the minute hand.\n- Ticks and numerals: Each numeral-to-numeral segment equals 5 minutes. If minute tick marks exist, use them. If not, divide each numeral interval evenly into five.\n\nHow to compute the minute\n1. Locate the clock center and the minute hand\u2019s tip.\n2. Determine the angle of the minute hand from the 12 o\u2019clock direction, increasing clockwise.\n3. Convert angle to minutes: minute_estimate = (angle_from_12 / 6). Round to the nearest whole minute.\n - Mapping: 12 \u2192 0, 1 \u2192 5, 2 \u2192 10, 3 \u2192 15, 4 \u2192 20, 5 \u2192 25, 6 \u2192 30, 7 \u2192 35, 8 \u2192 40, 9 \u2192 45, 10 \u2192 50, 11 \u2192 55.\n - If the tip is slightly past a numeral (e.g., just past 3), do not snap to the numeral; round to the nearest minute (e.g., 16 instead of 15).\n4. Consistency check with the hour hand (useful to avoid off-by-one and hand mix-ups):\n - The hour hand moves continuously: it advances 0.5 degrees per minute (i.e., 1/12 of the way to the next numeral every 5 minutes).\n - If your minute_estimate is an exact multiple of 5 but the hour hand is clearly between hour markers (not aligned with an hour), re-examine: the minute hand is likely slightly past the numeral; adjust to the nearest minute accordingly.\n - If the minute hand choice is ambiguous, infer the minute from the hour hand\u2019s fraction toward the next hour: minute \u2248 fraction_between_hour_markers \u00d7 60, then choose the hand assignment that matches this.\n5. Edge cases:\n - Overlapping hands: Look at which tip extends farther toward the tick ring to identify the minute hand.\n - Strong perspective or glare: Use the line from center to the visible tip; ignore reflections.\n - No minute ticks: Evenly interpolate between numerals.\n - Subdials or decorative elements (e.g., pendulum windows) are not the minute indicator; use the main dial only.\n\nOutput format\n- Return only the integer minute value (0\u201359) in the minute field.\n- If the angle computes to 60, output 0.\n\nError prevention reminders\n- Do not treat the hour hand as the minute hand.\n- Do not use the second hand to compute minutes.\n- Do not assume the minute hand is exactly on a numeral\u2014check for slight offsets and round to the nearest minute.\n- Ensure the final minute agrees with the hour hand\u2019s position trend (hour hand slightly past an hour implies minutes > 0)." 
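# Illustrative sketch (not part of the original test suite): the metric contract
# that dspy.GEPA expects, as exercised by the tests above. A GEPA metric accepts
# (example, prediction, trace, pred_name, pred_trace) and may return either a
# plain float or a dspy.Prediction carrying both `score` and `feedback`. The
# function name below is hypothetical.
def _example_feedback_metric(example, prediction, trace=None, pred_name=None, pred_trace=None):
    correct = example.output == prediction.output
    return dspy.Prediction(
        score=1.0 if correct else 0.0,
        feedback="Correct." if correct else f"Expected {example.output!r}, got {prediction.output!r}.",
    )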
def test_metric_requires_feedback_signature():
    reflection_lm = DictDummyLM([])
    with pytest.raises(TypeError):
        dspy.GEPA(metric=bad_metric, reflection_lm=reflection_lm, max_metric_calls=1)


def any_metric(
    gold: dspy.Example,
    pred: dspy.Prediction,
    trace: Any = None,
    pred_name: str | None = None,
    pred_trace: Any = None,
) -> float:
    """
    For this test, we only care that the program runs, not the score.
    """
    return 0.0  # ← Just returns 0.0, doesn't access any attributes!


def test_gepa_compile_with_track_usage_no_tuple_error(caplog):
    """
    GEPA.compile should not log a tuple-usage error when track_usage=True, and it should
    complete without hanging.

    Before, compile would hang and/or log "'tuple' object has no attribute 'set_lm_usage'" repeatedly.
    """
    student = dspy.Predict("question -> answer")
    trainset = [dspy.Example(question="What is 2+2?", answer="4").with_inputs("question")]

    task_lm = DummyLM([{"answer": "mock answer 1"}])
    reflection_lm = DummyLM([{"new_instruction": "Something new."}])

    compiled_container: dict[str, Any] = {}
    exc_container: dict[str, BaseException] = {}

    def run_compile():
        try:
            with dspy.context(lm=task_lm, track_usage=True):
                optimizer = dspy.GEPA(metric=any_metric, reflection_lm=reflection_lm, max_metric_calls=3)
                compiled_container["prog"] = optimizer.compile(student, trainset=trainset, valset=trainset)
        except BaseException as e:
            exc_container["e"] = e

    t = threading.Thread(target=run_compile, daemon=True)
    t.start()
    t.join(timeout=1.0)

    # Assert compile did not hang (pre-fix behavior would time out here)
    assert not t.is_alive(), "GEPA.compile did not complete within timeout (likely pre-fix behavior)."

    # Assert no tuple-usage error is logged anymore
    assert "'tuple' object has no attribute 'set_lm_usage'" not in caplog.text

    # If any exception occurred, fail explicitly
    if "e" in exc_container:
        pytest.fail(f"GEPA.compile raised unexpectedly: {exc_container['e']}")

    # No timeout, no exception -> so the program must exist
    if "prog" not in compiled_container:
        pytest.fail("GEPA.compile did not return a program (likely pre-fix behavior).")


class MultiComponentModule(dspy.Module):
    """Test module with multiple predictors."""

    def __init__(self):
        super().__init__()
        self.classifier = Predict("input -> category")
        self.generator = Predict("category, input -> output")

    def forward(self, input):
        category = self.classifier(input=input).category
        output = self.generator(category=category, input=input).output
        return dspy.Prediction(category=category, output=output)


def component_selection_metric(example, prediction, trace=None, pred_name=None, pred_trace=None):
    """Simple metric for component selection testing."""
    return dspy.Prediction(score=0.3, feedback="Test feedback")


def test_component_selector_functionality():
    """Test custom component selector function can select single/multiple components."""
    # Track calls for verification
    selector_calls = []

    def test_selector(state, trajectories, subsample_scores, candidate_idx, candidate):
        selector_calls.append({"components": list(candidate.keys()), "candidate_idx": candidate_idx})
        # Test both single and multiple selection
        return ["classifier"] if candidate_idx == 0 else ["classifier", "generator"]

    student = MultiComponentModule()
    # Provide enough responses for all possible LM calls during optimization
    task_lm = DummyLM([{"category": "test_category", "output": "test_output"}] * 20)
    reflection_lm = DummyLM(
        [
            {"improved_instruction": "Improved classifier instruction"},
            {"improved_instruction": "Improved generator instruction"},
        ]
        * 10
    )
    trainset = [dspy.Example(input="test",
output="expected").with_inputs("input")] with dspy.context(lm=task_lm): optimizer = dspy.GEPA( metric=component_selection_metric, reflection_lm=reflection_lm, max_metric_calls=6, # Reduced to minimize output component_selector=test_selector, ) result = optimizer.compile(student, trainset=trainset, valset=trainset) # Verify selector was called with correct parameters assert len(selector_calls) > 0, "Custom selector should be invoked" assert "classifier" in selector_calls[0]["components"], "Should receive all available components" assert "generator" in selector_calls[0]["components"], "Should receive all available components" assert result is not None, "Should return optimized program" def test_component_selector_default_behavior(): """Test default behavior when no custom selector provided.""" student = MultiComponentModule() # Provide enough responses for all possible LM calls task_lm = DummyLM([{"category": "test_category", "output": "test_output"}] * 15) reflection_lm = DummyLM([{"improved_instruction": "Better instruction"}] * 8) trainset = [dspy.Example(input="test", output="expected").with_inputs("input")] with dspy.context(lm=task_lm): # No component_selector - should use round-robin default optimizer = dspy.GEPA( metric=component_selection_metric, reflection_lm=reflection_lm, max_metric_calls=4, # Minimal calls to reduce noise ) result = optimizer.compile(student, trainset=trainset, valset=trainset) assert result is not None, "Should work with default selector" def test_component_selector_string_round_robin(): """Test string-based round_robin selector.""" student = MultiComponentModule() # Provide enough responses for all possible LM calls task_lm = DummyLM([{"category": "test_category", "output": "test_output"}] * 15) reflection_lm = DummyLM([{"improved_instruction": "Better instruction"}] * 8) trainset = [dspy.Example(input="test", output="expected").with_inputs("input")] with dspy.context(lm=task_lm): optimizer = dspy.GEPA( metric=component_selection_metric, reflection_lm=reflection_lm, max_metric_calls=4, component_selector="round_robin", # String-based selector ) result = optimizer.compile(student, trainset=trainset, valset=trainset) assert result is not None, "Should work with 'round_robin' string selector" def test_component_selector_string_all(): """Test string-based 'all' selector and verify it actually updates all components.""" student = MultiComponentModule() # Store original instructions to verify they get updated original_classifier_instruction = student.classifier.signature.instructions original_generator_instruction = student.generator.signature.instructions def optimize(component_selector): # Metric that progressively improves to encourage GEPA to accept new candidates call_count = 0 def improving_metric(example, prediction, trace=None, pred_name=None, pred_trace=None): nonlocal call_count call_count += 1 # Score improves with each call to encourage acceptance of new candidates score = min(0.3 + (call_count * 0.1), 1.0) return dspy.Prediction(score=score, feedback="Improving feedback") # Provide enough responses for all possible LM calls task_lm = DummyLM([{"category": "test_category", "output": "test_output"}] * 20) reflection_lm = DummyLM( [ {"improved_instruction": "Updated classifier instruction"}, {"improved_instruction": "Updated generator instruction"}, ] * 10 ) trainset = [dspy.Example(input="test", output="expected").with_inputs("input")] with dspy.context(lm=task_lm): optimizer = dspy.GEPA( metric=improving_metric, reflection_lm=reflection_lm, 
max_metric_calls=8, component_selector=component_selector, track_stats=True, # Track intermediate results to verify updates ) return optimizer.compile(student, trainset=trainset, valset=trainset) result_round_robin = optimize(component_selector="round_robin") candidates_round_robin = result_round_robin.detailed_results.candidates assert ( candidates_round_robin[1].classifier.signature.instructions == original_classifier_instruction and candidates_round_robin[1].generator.signature.instructions != original_generator_instruction ) or ( candidates_round_robin[1].classifier.signature.instructions != original_classifier_instruction and candidates_round_robin[1].generator.signature.instructions == original_generator_instruction ), "First candidate should have only one component updated, when using round_robin selector" result_all = optimize(component_selector="all") candidates_all = result_all.detailed_results.candidates assert ( candidates_all[1].classifier.signature.instructions != original_classifier_instruction and candidates_all[1].generator.signature.instructions != original_generator_instruction ), "First candidate should have both components updated, when using all selector" def test_component_selector_custom_random(): """Test custom component selector function that randomly samples components.""" import random # Simple function-based selector def random_component_selector(state, trajectories, subsample_scores, candidate_idx, candidate): """Randomly select half of the available components.""" component_names = list(candidate.keys()) num_to_select = max(1, len(component_names) // 2) # At least 1, half of total return random.sample(component_names, num_to_select) student = MultiComponentModule() # Provide enough responses for all possible LM calls task_lm = DummyLM([{"category": "test_category", "output": "test_output"}] * 15) reflection_lm = DummyLM([{"improved_instruction": "Better instruction"}] * 8) trainset = [dspy.Example(input="test", output="expected").with_inputs("input")] with dspy.context(lm=task_lm): optimizer = dspy.GEPA( metric=component_selection_metric, reflection_lm=reflection_lm, max_metric_calls=4, component_selector=random_component_selector, # Function-based selector ) result = optimizer.compile(student, trainset=trainset, valset=trainset) assert result is not None, "Should work with custom random function selector" def test_alternating_half_component_selector(): """Test alternating half selector that optimizes different halves on even/odd iterations.""" selection_history = [] def alternating_half_selector(state, trajectories, subsample_scores, candidate_idx, candidate): """Optimize half the components on even iterations, half on odd iterations.""" components = list(candidate.keys()) # If there's only one component, always optimize it if len(components) <= 1: selected = components else: mid_point = len(components) // 2 # Use state.i (iteration counter) to alternate between halves if state.i % 2 == 0: # Even iteration: optimize first half selected = components[:mid_point] else: # Odd iteration: optimize second half selected = components[mid_point:] # Track selections for verification selection_history.append({ "iteration": state.i, "selected": selected.copy(), "all_components": components.copy() }) return selected student = MultiComponentModule() # Has "classifier" and "generator" components # Provide enough responses for multiple iterations task_lm = DummyLM([{"category": "test_category", "output": "test_output"}] * 20) reflection_lm = DummyLM([{"improved_instruction": 
"Better instruction"}] * 10) trainset = [dspy.Example(input="test", output="expected").with_inputs("input")] with dspy.context(lm=task_lm): optimizer = dspy.GEPA( metric=component_selection_metric, reflection_lm=reflection_lm, max_metric_calls=8, # Allow multiple iterations component_selector=alternating_half_selector, ) result = optimizer.compile(student, trainset=trainset, valset=trainset) assert result is not None, "Should work with alternating half selector" assert len(selection_history) >= 2, "Should have made multiple selections" for i, selection in enumerate(selection_history): if selection["iteration"] % 2 == 0: # Even iteration should select first half: ["classifier"] assert "classifier" in selection["selected"], f"Even iteration {selection['iteration']} should include classifier" assert "generator" not in selection["selected"], f"Even iteration {selection['iteration']} should not include generator" else: # Odd iteration should select second half: ["generator"] assert "generator" in selection["selected"], f"Odd iteration {selection['iteration']} should include generator" assert "classifier" not in selection["selected"], f"Odd iteration {selection['iteration']} should not include classifier" ```