#
tokens: 43079/50000 3/391 files (page 16/17)
lines: on (toggle) GitHub
raw markdown copy reset
This is page 16 of 17. Use http://codebase.md/stanfordnlp/dspy?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .github
│   ├── .internal_dspyai
│   │   ├── internals
│   │   │   ├── build-and-release.md
│   │   │   └── release-checklist.md
│   │   └── pyproject.toml
│   ├── .tmp
│   │   └── .generated-actions
│   │       └── run-pypi-publish-in-docker-container
│   │           └── action.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yml
│   │   └── feature_request.yml
│   ├── PULL_REQUEST_TEMPLATE
│   │   └── pull_request_template.md
│   ├── workflow_scripts
│   │   └── install_testpypi_pkg.sh
│   └── workflows
│       ├── build_and_release.yml
│       ├── build_utils
│       │   └── test_version.py
│       ├── docs-push.yml
│       ├── precommits_check.yml
│       └── run_tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── docs
│   ├── .gitignore
│   ├── docs
│   │   ├── api
│   │   │   ├── adapters
│   │   │   │   ├── Adapter.md
│   │   │   │   ├── ChatAdapter.md
│   │   │   │   ├── JSONAdapter.md
│   │   │   │   └── TwoStepAdapter.md
│   │   │   ├── evaluation
│   │   │   │   ├── answer_exact_match.md
│   │   │   │   ├── answer_passage_match.md
│   │   │   │   ├── CompleteAndGrounded.md
│   │   │   │   ├── Evaluate.md
│   │   │   │   ├── EvaluationResult.md
│   │   │   │   └── SemanticF1.md
│   │   │   ├── experimental
│   │   │   │   ├── Citations.md
│   │   │   │   └── Document.md
│   │   │   ├── index.md
│   │   │   ├── models
│   │   │   │   ├── Embedder.md
│   │   │   │   └── LM.md
│   │   │   ├── modules
│   │   │   │   ├── BestOfN.md
│   │   │   │   ├── ChainOfThought.md
│   │   │   │   ├── CodeAct.md
│   │   │   │   ├── Module.md
│   │   │   │   ├── MultiChainComparison.md
│   │   │   │   ├── Parallel.md
│   │   │   │   ├── Predict.md
│   │   │   │   ├── ProgramOfThought.md
│   │   │   │   ├── ReAct.md
│   │   │   │   └── Refine.md
│   │   │   ├── optimizers
│   │   │   │   ├── BetterTogether.md
│   │   │   │   ├── BootstrapFewShot.md
│   │   │   │   ├── BootstrapFewShotWithRandomSearch.md
│   │   │   │   ├── BootstrapFinetune.md
│   │   │   │   ├── BootstrapRS.md
│   │   │   │   ├── COPRO.md
│   │   │   │   ├── Ensemble.md
│   │   │   │   ├── GEPA
│   │   │   │   │   ├── GEPA_Advanced.md
│   │   │   │   │   └── overview.md
│   │   │   │   ├── InferRules.md
│   │   │   │   ├── KNN.md
│   │   │   │   ├── KNNFewShot.md
│   │   │   │   ├── LabeledFewShot.md
│   │   │   │   ├── MIPROv2.md
│   │   │   │   └── SIMBA.md
│   │   │   ├── primitives
│   │   │   │   ├── Audio.md
│   │   │   │   ├── Code.md
│   │   │   │   ├── Example.md
│   │   │   │   ├── History.md
│   │   │   │   ├── Image.md
│   │   │   │   ├── Prediction.md
│   │   │   │   ├── Tool.md
│   │   │   │   └── ToolCalls.md
│   │   │   ├── signatures
│   │   │   │   ├── InputField.md
│   │   │   │   ├── OutputField.md
│   │   │   │   └── Signature.md
│   │   │   ├── tools
│   │   │   │   ├── ColBERTv2.md
│   │   │   │   ├── Embeddings.md
│   │   │   │   └── PythonInterpreter.md
│   │   │   └── utils
│   │   │       ├── asyncify.md
│   │   │       ├── configure_cache.md
│   │   │       ├── disable_litellm_logging.md
│   │   │       ├── disable_logging.md
│   │   │       ├── enable_litellm_logging.md
│   │   │       ├── enable_logging.md
│   │   │       ├── inspect_history.md
│   │   │       ├── load.md
│   │   │       ├── StatusMessage.md
│   │   │       ├── StatusMessageProvider.md
│   │   │       ├── streamify.md
│   │   │       └── StreamListener.md
│   │   ├── cheatsheet.md
│   │   ├── community
│   │   │   ├── community-resources.md
│   │   │   ├── how-to-contribute.md
│   │   │   └── use-cases.md
│   │   ├── deep-dive
│   │   │   └── data-handling
│   │   │       ├── built-in-datasets.md
│   │   │       ├── examples.md
│   │   │       ├── img
│   │   │       │   └── data-loading.png
│   │   │       └── loading-custom-data.md
│   │   ├── faqs.md
│   │   ├── index.md
│   │   ├── js
│   │   │   └── runllm-widget.js
│   │   ├── learn
│   │   │   ├── evaluation
│   │   │   │   ├── data.md
│   │   │   │   ├── metrics.md
│   │   │   │   └── overview.md
│   │   │   ├── figures
│   │   │   │   ├── native_tool_call.png
│   │   │   │   └── teleprompter-classes.png
│   │   │   ├── index.md
│   │   │   ├── optimization
│   │   │   │   ├── optimizers.md
│   │   │   │   └── overview.md
│   │   │   └── programming
│   │   │       ├── 7-assertions.md
│   │   │       ├── adapters.md
│   │   │       ├── language_models.md
│   │   │       ├── mcp.md
│   │   │       ├── modules.md
│   │   │       ├── overview.md
│   │   │       ├── signatures.md
│   │   │       └── tools.md
│   │   ├── production
│   │   │   └── index.md
│   │   ├── roadmap.md
│   │   ├── static
│   │   │   ├── .nojekyll
│   │   │   └── img
│   │   │       ├── dspy_logo.png
│   │   │       ├── logo.png
│   │   │       ├── mlflow-tracing-rag.png
│   │   │       ├── modular.png
│   │   │       ├── optimize.png
│   │   │       ├── undraw_docusaurus_mountain.svg
│   │   │       ├── undraw_docusaurus_react.svg
│   │   │       ├── undraw_docusaurus_tree.svg
│   │   │       └── universal_compatibility.png
│   │   ├── stylesheets
│   │   │   └── extra.css
│   │   └── tutorials
│   │       ├── agents
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-agent.png
│   │       ├── ai_text_game
│   │       │   └── index.md
│   │       ├── async
│   │       │   └── index.md
│   │       ├── audio
│   │       │   └── index.ipynb
│   │       ├── build_ai_program
│   │       │   └── index.md
│   │       ├── cache
│   │       │   └── index.md
│   │       ├── classification
│   │       │   └── index.md
│   │       ├── classification_finetuning
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-classification.png
│   │       ├── conversation_history
│   │       │   └── index.md
│   │       ├── core_development
│   │       │   └── index.md
│   │       ├── custom_module
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-custom-module.png
│   │       ├── customer_service_agent
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-customer-service-agent.png
│   │       ├── deployment
│   │       │   ├── dspy_mlflow_ui.png
│   │       │   └── index.md
│   │       ├── email_extraction
│   │       │   ├── index.md
│   │       │   └── mlflow-tracing-email-extraction.png
│   │       ├── entity_extraction
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-entity-extraction.png
│   │       ├── games
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-agent.png
│   │       ├── gepa_ai_program
│   │       │   └── index.md
│   │       ├── gepa_aime
│   │       │   ├── index.ipynb
│   │       │   ├── mlflow-tracing-gepa-aime.png
│   │       │   └── mlflow-tracking-gepa-aime-optimization.png
│   │       ├── gepa_facilitysupportanalyzer
│   │       │   ├── index.ipynb
│   │       │   ├── mlflow-tracing-gepa-support.png
│   │       │   └── mlflow-tracking-gepa-support-optimization.png
│   │       ├── gepa_papillon
│   │       │   ├── index.ipynb
│   │       │   ├── mlflow-tracing-gepa-papilon.png
│   │       │   └── mlflow-tracking-gepa-papilon-optimization.png
│   │       ├── image_generation_prompting
│   │       │   └── index.ipynb
│   │       ├── index.md
│   │       ├── llms_txt_generation
│   │       │   └── index.md
│   │       ├── math
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-math.png
│   │       ├── mcp
│   │       │   └── index.md
│   │       ├── mem0_react_agent
│   │       │   └── index.md
│   │       ├── multihop_search
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-multi-hop.png
│   │       ├── observability
│   │       │   ├── index.md
│   │       │   ├── mlflow_trace_ui_navigation.gif
│   │       │   ├── mlflow_trace_ui.png
│   │       │   └── mlflow_trace_view.png
│   │       ├── optimize_ai_program
│   │       │   └── index.md
│   │       ├── optimizer_tracking
│   │       │   ├── child_run.png
│   │       │   ├── experiment.png
│   │       │   ├── index.md
│   │       │   └── parent_run.png
│   │       ├── output_refinement
│   │       │   └── best-of-n-and-refine.md
│   │       ├── papillon
│   │       │   └── index.md
│   │       ├── program_of_thought
│   │       │   └── index.ipynb
│   │       ├── rag
│   │       │   ├── index.ipynb
│   │       │   └── mlflow-tracing-rag.png
│   │       ├── real_world_examples
│   │       │   └── index.md
│   │       ├── rl_ai_program
│   │       │   └── index.md
│   │       ├── rl_multihop
│   │       │   └── index.ipynb
│   │       ├── rl_papillon
│   │       │   └── index.ipynb
│   │       ├── sample_code_generation
│   │       │   └── index.md
│   │       ├── saving
│   │       │   └── index.md
│   │       ├── streaming
│   │       │   └── index.md
│   │       ├── tool_use
│   │       │   └── index.ipynb
│   │       └── yahoo_finance_react
│   │           └── index.md
│   ├── mkdocs.yml
│   ├── overrides
│   │   ├── home.html
│   │   ├── main.html
│   │   └── partials
│   │       └── tabs.html
│   ├── Pipfile
│   ├── Pipfile.lock
│   ├── README.md
│   ├── requirements.txt
│   ├── scripts
│   │   ├── generate_api_docs.py
│   │   └── generate_api_summary.py
│   └── vercel.json
├── dspy
│   ├── __init__.py
│   ├── __metadata__.py
│   ├── adapters
│   │   ├── __init__.py
│   │   ├── baml_adapter.py
│   │   ├── base.py
│   │   ├── chat_adapter.py
│   │   ├── json_adapter.py
│   │   ├── two_step_adapter.py
│   │   ├── types
│   │   │   ├── __init__.py
│   │   │   ├── audio.py
│   │   │   ├── base_type.py
│   │   │   ├── citation.py
│   │   │   ├── code.py
│   │   │   ├── document.py
│   │   │   ├── history.py
│   │   │   ├── image.py
│   │   │   └── tool.py
│   │   ├── utils.py
│   │   └── xml_adapter.py
│   ├── clients
│   │   ├── __init__.py
│   │   ├── base_lm.py
│   │   ├── cache.py
│   │   ├── databricks.py
│   │   ├── embedding.py
│   │   ├── lm_local_arbor.py
│   │   ├── lm_local.py
│   │   ├── lm.py
│   │   ├── openai.py
│   │   ├── provider.py
│   │   └── utils_finetune.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── alfworld
│   │   │   ├── __init__.py
│   │   │   ├── alfworld.py
│   │   │   └── base_config.yml
│   │   ├── colors.py
│   │   ├── dataloader.py
│   │   ├── dataset.py
│   │   ├── gsm8k.py
│   │   ├── hotpotqa.py
│   │   └── math.py
│   ├── dsp
│   │   ├── __init__.py
│   │   ├── colbertv2.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── dpr.py
│   │       ├── settings.py
│   │       └── utils.py
│   ├── evaluate
│   │   ├── __init__.py
│   │   ├── auto_evaluation.py
│   │   ├── evaluate.py
│   │   └── metrics.py
│   ├── experimental
│   │   └── __init__.py
│   ├── predict
│   │   ├── __init__.py
│   │   ├── aggregation.py
│   │   ├── avatar
│   │   │   ├── __init__.py
│   │   │   ├── avatar.py
│   │   │   ├── models.py
│   │   │   └── signatures.py
│   │   ├── best_of_n.py
│   │   ├── chain_of_thought.py
│   │   ├── code_act.py
│   │   ├── knn.py
│   │   ├── multi_chain_comparison.py
│   │   ├── parallel.py
│   │   ├── parameter.py
│   │   ├── predict.py
│   │   ├── program_of_thought.py
│   │   ├── react.py
│   │   ├── refine.py
│   │   └── retry.py
│   ├── primitives
│   │   ├── __init__.py
│   │   ├── base_module.py
│   │   ├── example.py
│   │   ├── module.py
│   │   ├── prediction.py
│   │   ├── python_interpreter.py
│   │   └── runner.js
│   ├── propose
│   │   ├── __init__.py
│   │   ├── dataset_summary_generator.py
│   │   ├── grounded_proposer.py
│   │   ├── propose_base.py
│   │   └── utils.py
│   ├── retrievers
│   │   ├── __init__.py
│   │   ├── databricks_rm.py
│   │   ├── embeddings.py
│   │   ├── retrieve.py
│   │   └── weaviate_rm.py
│   ├── signatures
│   │   ├── __init__.py
│   │   ├── field.py
│   │   ├── signature.py
│   │   └── utils.py
│   ├── streaming
│   │   ├── __init__.py
│   │   ├── messages.py
│   │   ├── streamify.py
│   │   └── streaming_listener.py
│   ├── teleprompt
│   │   ├── __init__.py
│   │   ├── avatar_optimizer.py
│   │   ├── bettertogether.py
│   │   ├── bootstrap_finetune.py
│   │   ├── bootstrap_trace.py
│   │   ├── bootstrap.py
│   │   ├── copro_optimizer.py
│   │   ├── ensemble.py
│   │   ├── gepa
│   │   │   ├── __init__.py
│   │   │   ├── gepa_utils.py
│   │   │   ├── gepa.py
│   │   │   └── instruction_proposal.py
│   │   ├── grpo.py
│   │   ├── infer_rules.py
│   │   ├── knn_fewshot.py
│   │   ├── mipro_optimizer_v2.py
│   │   ├── random_search.py
│   │   ├── signature_opt.py
│   │   ├── simba_utils.py
│   │   ├── simba.py
│   │   ├── teleprompt_optuna.py
│   │   ├── teleprompt.py
│   │   ├── utils.py
│   │   └── vanilla.py
│   └── utils
│       ├── __init__.py
│       ├── annotation.py
│       ├── asyncify.py
│       ├── caching.py
│       ├── callback.py
│       ├── dummies.py
│       ├── exceptions.py
│       ├── hasher.py
│       ├── inspect_history.py
│       ├── langchain_tool.py
│       ├── logging_utils.py
│       ├── mcp.py
│       ├── parallelizer.py
│       ├── saving.py
│       ├── syncify.py
│       ├── unbatchify.py
│       └── usage_tracker.py
├── LICENSE
├── pyproject.toml
├── README.md
├── tests
│   ├── __init__.py
│   ├── adapters
│   │   ├── test_adapter_utils.py
│   │   ├── test_baml_adapter.py
│   │   ├── test_base_type.py
│   │   ├── test_chat_adapter.py
│   │   ├── test_citation.py
│   │   ├── test_code.py
│   │   ├── test_document.py
│   │   ├── test_json_adapter.py
│   │   ├── test_tool.py
│   │   ├── test_two_step_adapter.py
│   │   └── test_xml_adapter.py
│   ├── callback
│   │   └── test_callback.py
│   ├── clients
│   │   ├── test_cache.py
│   │   ├── test_databricks.py
│   │   ├── test_embedding.py
│   │   ├── test_inspect_global_history.py
│   │   └── test_lm.py
│   ├── conftest.py
│   ├── datasets
│   │   └── test_dataset.py
│   ├── docs
│   │   └── test_mkdocs_links.py
│   ├── evaluate
│   │   ├── test_evaluate.py
│   │   └── test_metrics.py
│   ├── examples
│   │   └── test_baleen.py
│   ├── metadata
│   │   └── test_metadata.py
│   ├── predict
│   │   ├── test_aggregation.py
│   │   ├── test_best_of_n.py
│   │   ├── test_chain_of_thought.py
│   │   ├── test_code_act.py
│   │   ├── test_knn.py
│   │   ├── test_multi_chain_comparison.py
│   │   ├── test_parallel.py
│   │   ├── test_predict.py
│   │   ├── test_program_of_thought.py
│   │   ├── test_react.py
│   │   ├── test_refine.py
│   │   └── test_retry.py
│   ├── primitives
│   │   ├── resources
│   │   │   └── saved_program.json
│   │   ├── test_base_module.py
│   │   ├── test_example.py
│   │   ├── test_module.py
│   │   └── test_python_interpreter.py
│   ├── propose
│   │   └── test_grounded_proposer.py
│   ├── README.md
│   ├── reliability
│   │   ├── __init__.py
│   │   ├── complex_types
│   │   │   └── generated
│   │   │       ├── test_many_types_1
│   │   │       │   ├── inputs
│   │   │       │   │   ├── input1.json
│   │   │       │   │   └── input2.json
│   │   │       │   ├── program.py
│   │   │       │   └── schema.json
│   │   │       ├── test_nesting_1
│   │   │       │   ├── inputs
│   │   │       │   │   ├── input1.json
│   │   │       │   │   └── input2.json
│   │   │       │   ├── program.py
│   │   │       │   └── schema.json
│   │   │       └── test_nesting_2
│   │   │           ├── inputs
│   │   │           │   └── input1.json
│   │   │           ├── program.py
│   │   │           └── schema.json
│   │   ├── conftest.py
│   │   ├── generate
│   │   │   ├── __init__.py
│   │   │   ├── __main__.py
│   │   │   └── utils.py
│   │   ├── input_formats
│   │   │   └── generated
│   │   │       └── test_markdown_1
│   │   │           ├── inputs
│   │   │           │   ├── input1.json
│   │   │           │   └── input2.json
│   │   │           ├── program.py
│   │   │           └── schema.json
│   │   ├── README.md
│   │   ├── reliability_conf.yaml
│   │   ├── test_generated.py
│   │   ├── test_pydantic_models.py
│   │   └── utils.py
│   ├── retrievers
│   │   └── test_embeddings.py
│   ├── signatures
│   │   ├── test_adapter_image.py
│   │   ├── test_custom_types.py
│   │   └── test_signature.py
│   ├── streaming
│   │   └── test_streaming.py
│   ├── teleprompt
│   │   ├── gepa_dummy_lm_custom_component_selector_custom_instruction_proposer.json
│   │   ├── gepa_dummy_lm.json
│   │   ├── test_bootstrap_finetune.py
│   │   ├── test_bootstrap_trace.py
│   │   ├── test_bootstrap.py
│   │   ├── test_copro_optimizer.py
│   │   ├── test_ensemble.py
│   │   ├── test_finetune.py
│   │   ├── test_gepa_instruction_proposer.py
│   │   ├── test_gepa.py
│   │   ├── test_grpo.py
│   │   ├── test_knn_fewshot.py
│   │   ├── test_random_search.py
│   │   ├── test_teleprompt.py
│   │   └── test_utils.py
│   ├── test_utils
│   │   ├── __init__.py
│   │   └── server
│   │       ├── __init__.py
│   │       ├── litellm_server_config.yaml
│   │       └── litellm_server.py
│   └── utils
│       ├── __init__.py
│       ├── resources
│       │   └── mcp_server.py
│       ├── test_annotation.py
│       ├── test_asyncify.py
│       ├── test_exceptions.py
│       ├── test_langchain_tool.py
│       ├── test_mcp.py
│       ├── test_parallelizer.py
│       ├── test_saving.py
│       ├── test_settings.py
│       ├── test_syncify.py
│       ├── test_unbatchify.py
│       └── test_usage_tracker.py
└── uv.lock
```

# Files

--------------------------------------------------------------------------------
/dspy/teleprompt/grpo.py:
--------------------------------------------------------------------------------

```python
  1 | import logging
  2 | import random
  3 | from collections import Counter
  4 | from typing import Any, Callable, Literal
  5 | 
  6 | from dspy.adapters.base import Adapter
  7 | from dspy.adapters.chat_adapter import ChatAdapter
  8 | from dspy.clients.lm import LM
  9 | from dspy.clients.utils_finetune import GRPOGroup, MultiGPUConfig, TrainDataFormat
 10 | from dspy.dsp.utils.settings import settings
 11 | from dspy.evaluate.evaluate import Evaluate
 12 | from dspy.primitives.example import Example
 13 | from dspy.primitives.module import Module
 14 | from dspy.teleprompt.bootstrap_finetune import (
 15 |     FinetuneTeleprompter,
 16 |     all_predictors_have_lms,
 17 |     assert_structural_equivalency,
 18 | )
 19 | from dspy.teleprompt.bootstrap_trace import FailedPrediction, bootstrap_trace_data
 20 | 
 21 | logger = logging.getLogger(__name__)
 22 | 
 23 | 
 24 | class GRPO(FinetuneTeleprompter):
    def __init__(
        self,
        metric: Callable | None = None,
        multitask: bool = True,
        train_kwargs: dict[str, Any] | dict[LM, dict[str, Any]] | None = None,
        adapter: Adapter | dict[LM, Adapter] | None = None,
        exclude_demos: bool = False,
        num_threads: int = 6,
        num_train_steps: int = 100,
        seed: int = 0,
        num_dspy_examples_per_grpo_step: int = 1,
        num_rollouts_per_grpo_step: int = 1,
        use_train_as_val: bool = False,
        num_steps_for_val: int = 5,
        report_train_scores: bool = False,
        failure_score: float = 0,
        format_failure_score: float = -1,
        variably_invoked_predictor_grouping_mode: Literal["truncate"] | Literal["fill"] | Literal["ragged"] = "truncate",
        variably_invoked_predictor_fill_strategy: Literal["randint"] | Literal["max"] | None = None,
        gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1),
    ):
        """Store the GRPO configuration and validate option combinations.

        Args:
            metric: Scoring callable; passed as ``metric=`` when the student is
                evaluated in ``report_validation_metrics``.
            multitask: Must be True; any other value fails an assertion.
            train_kwargs: Forwarded unchanged to ``FinetuneTeleprompter.__init__``.
            adapter: A single adapter or a per-LM mapping; normalized through
                ``self.convert_to_lm_dict`` before being stored.
            exclude_demos: Must be True. The default (False) fails the
                assertion below, so callers have to pass ``exclude_demos=True``.
            num_threads: Thread count handed to ``Evaluate``.
            num_train_steps: Total number of training steps; used to recognize
                the final step and in progress log messages.
            seed: Seed for the private ``random.Random`` used to shuffle
                training-example ids.
            num_dspy_examples_per_grpo_step: Number of training examples
                selected for each step.
            num_rollouts_per_grpo_step: Stored as-is; presumably the number of
                rollouts sampled per example -- TODO confirm against ``compile``.
            use_train_as_val: Requires ``report_train_scores`` to be True.
            num_steps_for_val: Validation metrics are reported every this many
                steps (plus before training and after the last step).
            report_train_scores: Whether training-set scores are reported.
            failure_score: Handed to ``Evaluate`` as ``failure_score``; must be
                strictly greater than ``format_failure_score``.
            format_failure_score: Lower bound of the score range that, per the
                assertion message, is reserved for formatting rewards.
            variably_invoked_predictor_grouping_mode: One of "truncate",
                "fill" or "ragged"; only "fill" triggers extra validation here.
            variably_invoked_predictor_fill_strategy: "randint" or "max";
                required when the grouping mode is "fill".
            gpu_config: Stored as-is. The default instance is created once,
                when the function is defined, so instances relying on the
                default share the same object.

        Raises:
            AssertionError: If any of the option constraints above is violated.
        """
        super().__init__(train_kwargs=train_kwargs)
        self.metric = metric
        self.multitask = multitask
        self.adapter: dict[LM, Adapter] = self.convert_to_lm_dict(adapter)
        self.exclude_demos = exclude_demos
        self.num_threads = num_threads
        self.num_train_steps = num_train_steps
        self.rng = random.Random(seed)
        self.num_dspy_examples_per_grpo_step = num_dspy_examples_per_grpo_step
        self.num_rollouts_per_grpo_step = num_rollouts_per_grpo_step
        self.use_train_as_val = use_train_as_val
        self.num_steps_for_val = num_steps_for_val
        self.report_train_scores = report_train_scores
        self.failure_score = failure_score
        self.format_failure_score = format_failure_score
        self.gpu_config = gpu_config

        assert failure_score > format_failure_score, "failure_score must be greater than format_failure_score since the range [format_failure_score, failure_score] is used to provide dspy formatting rewards"

        if self.use_train_as_val:
            assert report_train_scores, "If use_train_as_val is True, report_train_scores must be True."

        # Both features are unimplemented, so the only accepted values are
        # exclude_demos=True and multitask=True.
        assert exclude_demos, "exclude_demos==False is not supported yet. Please set it to True."
        assert multitask, "independent GRPO training jobs for each predictor in the student program is not supported yet. Please set multitask=True."

        # The backend will be called with a batch of (num_dspy_examples_per_grpo_step * num_rollouts_per_grpo_step * num_predictors) per training set if multitask is True
        # If multitask is False, the backend will be called with a batch of (num_dspy_examples_per_grpo_step * num_rollouts_per_grpo_step) per training job
        self.variably_invoked_predictor_grouping_mode = variably_invoked_predictor_grouping_mode
        if variably_invoked_predictor_grouping_mode == "fill":
            assert variably_invoked_predictor_fill_strategy is not None, "variably_invoked_predictor_fill_strategy must be set when variably_invoked_predictor_grouping_mode is 'fill'"
            assert variably_invoked_predictor_fill_strategy in ["randint", "max"], "variably_invoked_predictor_fill_strategy must be either 'randint' or 'max'"
        self.variably_invoked_predictor_fill_strategy = variably_invoked_predictor_fill_strategy

        # Epoch bookkeeping: the id order is (re)built lazily by
        # update_shuffled_trainset; epoch == -1 means nothing has been shuffled yet.
        self.shuffled_trainset_ids = []
        self.epoch = -1
        # How many times each training-example id has been scheduled so far.
        self.id_freqs = Counter()
 82 | 
 83 |     def validate_trace_data_and_log_issues(
 84 |         self,
 85 |         trace_data: list[list[list[dict[str, Any]]]],
 86 |         subsample_training_dataset: list[Example],
 87 |         num_teachers: int,
 88 |         num_samples_per_input: int,
 89 |         pred_signature_hash_to_ind: dict[int, int],
 90 |     ):
 91 |         # At this point, trace_data: list[example_idx -> list[teacher_idx -> [num_samples_per_input * Dict(example, prediction, trace, example_ind, score)]]]
 92 |         # Shape of trace is: [dspy_module_invocation_idx -> Tuple[Predictor, PredictorInputs, Prediction]]
 93 |         assert len(trace_data) == len(subsample_training_dataset), f"Trace data length {len(trace_data)} does not match the number of examples {len(subsample_training_dataset)}"
 94 |         assert len(trace_data[0]) == num_teachers, f"Trace data length {len(trace_data[0])} does not match the number of teachers {num_teachers}"
 95 |         # TODO(GRPO Team): Ideally, once the dspy format issue is fixed, this change should be reverted back to being a normal assert.
 96 |         if len(trace_data[0][0]) == 0:
 97 |             logger.warning(f"Trace data for example {0} and teacher {0} is empty. This is likely due to all examples in the training set input, resulting in the model generating output not following the dspy response format.")
 98 |         elif len(trace_data[0][0]) != num_samples_per_input:
 99 |             logger.warning(f"Trace data length {len(trace_data[0][0])} does not match the expected number of samples per input {num_samples_per_input}")
100 |             assert "trace" in trace_data[0][0][0], "Trace data does not contain the 'trace' key"
101 |             assert len(trace_data[0][0][0]["trace"]) > 0, "Trace data is empty"
102 |             assert len(trace_data[0][0][0]["trace"][0]) == 3, f"Trace tuple length {len(trace_data[0][0][0]['trace'][0])} does not match the expected length 3"
103 | 
104 |         for example_data in trace_data:
105 |             for teacher_data in example_data:
106 |                 for sample in teacher_data:
107 |                     for t in sample["trace"]:
108 |                         assert hash(t[0].signature) in pred_signature_hash_to_ind
109 | 
110 |     def report_validation_metrics(self, student, trainset, valset, logger, step_idx=-1):
111 |         if step_idx == -1 or step_idx == self.num_train_steps - 1 or (step_idx + 1) % self.num_steps_for_val == 0:
112 |             pass
113 |         else:
114 |             return
115 | 
116 |         if valset is not None:
117 |             # Validation set provided by user
118 |             assert not self.use_train_as_val, "If valset is provided, use_train_as_val must be False."
119 |             assert isinstance(self.num_steps_for_val, int) and self.num_steps_for_val > 0, "num_steps_for_val must be a positive integer."
120 |             if self.report_train_scores:
121 |                 if step_idx == -1:
122 |                     logger.info("Using user provided validation set and reporting train scores for every validation step in addition.")
123 |                 valset_evaluator = Evaluate(
124 |                     devset=valset + trainset,
125 |                     num_threads=self.num_threads,
126 |                     display_progress=True,
127 |                     provide_traceback=False,  # TODO(check with team)
128 |                     max_errors=len(valset)*10,  # TODO(check with team)
129 |                     failure_score=self.failure_score
130 |                 )
131 |                 if step_idx == -1:
132 |                     logger.info("Evaluating the student program on the train+validation set before training loop...")
133 |                 else:
134 |                     logger.info(f"Evaluating the student program on the validation set after training step {step_idx + 1}/{self.num_train_steps}")
135 |                 valset_evaluation = valset_evaluator(student, metric=self.metric)
136 |                 trainset_scores = [r[-1] for r in valset_evaluation.results[len(valset):]]
137 |                 valset_scores = [r[-1] for r in valset_evaluation.results[:len(valset)]]
138 |                 trainset_agg = sum(trainset_scores) / len(trainset_scores)
139 |                 valset_agg = sum(valset_scores) / len(valset_scores)
140 |                 if step_idx == -1:
141 |                     logger.info(f"Student program training set score before training loop: {trainset_agg}")
142 |                     logger.info(f"Student program validation set score before training loop: {valset_agg}")
143 |                 else:
144 |                     logger.info(f"Student program training set score after training step {step_idx + 1}/{self.num_train_steps}: {trainset_agg}")
145 |                     logger.info(f"Student program validation set score after training step {step_idx + 1}/{self.num_train_steps}: {valset_agg}")
146 |             else:
147 |                 if step_idx == -1:
148 |                     logger.info("Using user provided validation set and not reporting train scores.")
149 |                 valset_evaluator = Evaluate(
150 |                     devset=valset,
151 |                     num_threads=self.num_threads,
152 |                     display_progress=True,
153 |                     provide_traceback=False,  # TODO(check with team)
154 |                     max_errors=len(valset)*10,  # TODO(check with team)
155 |                     failure_score=self.failure_score
156 |                 )
157 |                 if step_idx == -1:
158 |                     logger.info("Evaluating the student program on the validation set before training loop...")
159 |                 else:
160 |                     logger.info(f"Evaluating the student program on the validation set after training step {step_idx + 1}/{self.num_train_steps}")
161 |                 valset_evaluation = valset_evaluator(student, metric=self.metric)
162 |                 if step_idx == -1:
163 |                     logger.info(f"Student program validation set score before training loop: {valset_evaluation.score}")
164 |                 else:
165 |                     logger.info(f"Student program validation set score after training step {step_idx + 1}/{self.num_train_steps}: {valset_evaluation.score}")
166 |         else:
167 |             # No validation set provided by user
168 |             if self.report_train_scores:
169 |                 assert self.use_train_as_val, "If report_train_scores is True, use_train_as_val must be True when valset is not provided explicitly."
170 |                 assert isinstance(self.num_steps_for_val, int) and self.num_steps_for_val > 0, "num_steps_for_val must be a positive integer."
171 |                 if step_idx == -1:
172 |                     logger.info("Using trainset as validation set.")
173 |                 valset_evaluator = Evaluate(
174 |                     devset=trainset,
175 |                     num_threads=self.num_threads,
176 |                     display_progress=True,
177 |                     provide_traceback=False,  # TODO(check with team)
178 |                     max_errors=len(trainset)*10,  # TODO(check with team)
179 |                     failure_score=self.failure_score
180 |                 )
181 |                 if step_idx == -1:
182 |                     logger.info("Evaluating the student program on the validation set before training loop...")
183 |                 else:
184 |                     logger.info(f"Evaluating the student program on the validation set after training step {step_idx + 1}/{self.num_train_steps}")
185 |                 valset_evaluation = valset_evaluator(student, metric=self.metric)
186 |                 if step_idx == -1:
187 |                     logger.info(f"Student program training set score before training loop: {valset_evaluation.score}")
188 |                 else:
189 |                     logger.info(f"Student program training set score after training step {step_idx + 1}/{self.num_train_steps}: {valset_evaluation.score}")
190 |             else:
191 |                 # No valset provided, and not using train as val
192 |                 assert not self.use_train_as_val, "If report_train_scores is False, use_train_as_val must be False."
193 |                 if step_idx == -1:
194 |                     logger.info("Not using any validation set and not reporting train scores.")
195 | 
196 |     def update_shuffled_trainset(self, original_trainset):
197 |         self.shuffled_trainset_ids = list(range(len(original_trainset)))
198 |         self.rng.shuffle(self.shuffled_trainset_ids)
199 |         for id in self.shuffled_trainset_ids:
200 |             self.id_freqs[id] += 1
201 | 
202 |         num_to_pad = self.num_dspy_examples_per_grpo_step - (len(original_trainset) % self.num_dspy_examples_per_grpo_step)
203 |         if num_to_pad > 0:
204 |             # Select ids based on least frequent ids
205 |             for _ in range(num_to_pad):
206 |                 selected_id = self.id_freqs.most_common()[::-1][0][0]
207 |                 self.shuffled_trainset_ids.append(selected_id)
208 |                 self.id_freqs[selected_id] += 1
209 | 
210 |     def select_training_sample_and_update_shuffled_trainset(  # Return the examples for GRPO step `train_step_idx`, reshuffling the ids first when a new epoch begins.
211 |         self,
212 |         original_trainset: list[Example],
213 |         train_step_idx: int,
214 |     ) -> list[Example]:
215 |         base_idx = train_step_idx * self.num_dspy_examples_per_grpo_step  # offset of this step's first example in the running stream of shuffled ids
216 |         if self.epoch == -1:  # presumably the "no shuffle done yet" state set by the constructor — confirm; also avoids dividing by a possibly empty id list
217 |             curr_epoch = 0
218 |         else:
219 |             curr_epoch = base_idx // len(self.shuffled_trainset_ids)
220 |         if curr_epoch > self.epoch:  # first call or epoch boundary crossed: build a fresh shuffled/padded id list
221 |             logger.info(f"Updating shuffled trainset for epoch {curr_epoch}...")
222 |             self.epoch = curr_epoch
223 |             self.update_shuffled_trainset(original_trainset)
224 | 
225 |         assert len(self.shuffled_trainset_ids) >= self.num_dspy_examples_per_grpo_step, f"Shuffled trainset length {len(self.shuffled_trainset_ids)} is less than num_dspy_examples_per_grpo_step {self.num_dspy_examples_per_grpo_step}"
226 |         assert len(self.shuffled_trainset_ids) % self.num_dspy_examples_per_grpo_step == 0, f"Shuffled trainset length {len(self.shuffled_trainset_ids)} is not divisible by num_dspy_examples_per_grpo_step {self.num_dspy_examples_per_grpo_step}"
227 | 
228 |         base_idx = base_idx % len(self.shuffled_trainset_ids)  # wrap the offset into the current id list
229 |         end_idx = base_idx + self.num_dspy_examples_per_grpo_step
230 |         assert end_idx <= len(self.shuffled_trainset_ids), f"End index {end_idx} is out of bounds for shuffled trainset length {len(self.shuffled_trainset_ids)}"
231 |         selected_ids = self.shuffled_trainset_ids[base_idx:end_idx]
232 |         selected_trainset = [original_trainset[i] for i in selected_ids]  # map the selected ids back to the actual examples
233 |         return selected_trainset
234 | 
235 |     def compile(
236 |         self,
237 |         student: Module,
238 |         trainset: list[Example],
239 |         teacher: Module | list[Module] | None = None,
240 |         valset: list[Example] | None = None,
241 |         **kwargs,
242 |     ) -> Module:
243 |         logger.info("Starting the GRPO compilation process... The LM(s) for the student program will be updated in place at the end of the training.")
244 |         logger.info("Validating the inputs...")
245 | 
246 |         assert len(trainset) > 0, "Training set is empty. Please provide a non-empty training set."
247 | 
248 |         if len(trainset) < self.num_dspy_examples_per_grpo_step:
249 |             logger.warning(
250 |             f"Number of training examples {len(trainset)} is less than the number of examples per GRPO step {self.num_dspy_examples_per_grpo_step}. "
251 |                 "Repeating the training set to fill the GRPO step. This could lead to overfitting and training instability."
252 |             )
253 |             multiplier = (self.num_dspy_examples_per_grpo_step + len(trainset) - 1) // len(trainset)
254 |             if multiplier > 1:
255 |                 logger.warning(
256 |                     f"Repeating the training set {multiplier} times to fill the GRPO step. This could lead to overfitting and training instability."
257 |                 )
258 |                 trainset = trainset * multiplier
259 | 
260 |         # TODO(GRPO Team): Following checks are for unimplemented features.
261 |         # Consider if we want to eventually implement them or remove. We don't
262 |         # yet support:
263 |         # * multitask == False
264 |         # * student program with multiple predictor LMs
265 |         # The main reason for these is that we update the LMs in place. If these
266 |         # LMs are shared between the different predictors of the student
267 |         # program and we have multitask == False, we need to decide which steps
268 |         # will use new LM copies and we need to ensure our decision is
269 |         # consistent with any teacher LMs that share the same LMs.
270 |         # TODO(GRPO Team): We want to make it possible to continue GRPO runs in
271 |         # the future by saving the state of the GRPO run in the event of a
272 |         # process failure.
273 |         if not self.multitask:
274 |             raise ValueError(
275 |                 "Independent GRPO training jobs for each predictor in the student program "
276 |                 "are not supported yet. Please set multitask=True."
277 |             )
278 | 
279 |         student_lms = {id(pred.lm) for pred in student.predictors()}
280 |         assert len(student_lms) == 1, (
281 |             f"Student program has multiple LMs: {student_lms}. "
282 |             "GRPO only supports student programs with a single LM."
283 |             "You can set the LM for a program with `program.set_lm(...)`"
284 |         )
285 | 
286 |         # Our regular input validation starts here
287 |         if self.use_train_as_val:
288 |             assert valset is None, "If use_train_as_val is True, valset must be None."
289 | 
290 |         logger.info("Preparing the student program...")
291 |         all_predictors_have_lms(student)
292 |         pred_signature_hash_to_ind = {hash(pred.signature): ind for ind, pred in enumerate(student.predictors())}
293 |         num_student_predictors = len(student.predictors())
294 | 
295 |         logging.info("Preparing the teacher program(s)... We will ensure that the provided programs have the same program structure as the student program.")
296 |         if (isinstance(teacher, list) and len(teacher) == 0) or teacher is None:
297 |             teacher = student
298 |         teachers = teacher if isinstance(teacher, list) else [teacher]
299 |         for t in teachers:
300 |             assert_structural_equivalency(student, t)
301 |             all_predictors_have_lms(t)
302 | 
303 |         # Ensure that the teachers list contain the student program
304 |         assert student in teachers, f"Student program {student} is not in the list of teachers {teachers}. Please provide the student program as one of the teachers. Alternatively, you can leave the teacher argument as None, and the student program will be used as the teacher program."
305 |         assert self.num_rollouts_per_grpo_step % len(teachers) == 0, (
306 |             f"The GRPO group size (num_rollouts_per_grpo_step) {self.num_rollouts_per_grpo_step} is not divisible by the number of teachers {len(teachers)}. "
307 |             "This is required to ensure that each teacher gets the same number of examples."
308 |             "Please provide a number of examples that is divisible by the number of teachers."
309 |         )
310 |         num_samples_per_input = self.num_rollouts_per_grpo_step // len(teachers)
311 | 
312 |         # We will disable the LM cache for all programs (student and teachers)
313 |         # These will be reverted to their original state at the end of the
314 |         # training
315 |         lm_cache_dict = {}
316 |         disable_lm_cache(program=student, lm_cache_dict=lm_cache_dict)
317 |         for t in teachers:
318 |             disable_lm_cache(program=t, lm_cache_dict=lm_cache_dict)
319 | 
320 |         # Update train_kwargs
321 |         for pred in student.predictors():
322 |             train_kwargs = self.train_kwargs[pred.lm]
323 |             train_kwargs = {} if train_kwargs is None else train_kwargs
324 |             train_kwargs["num_generations"] = self.num_rollouts_per_grpo_step
325 |             self.train_kwargs[pred.lm] = train_kwargs
326 | 
327 |         # We need to have a separate job for each unique LM x the data
328 |         # collection strategy. This properly handles all combinations of
329 |         # multitask and predictor LMs
330 |         logger.info("Preparing the GRPO training job(s)...")
331 |         grpo_training_jobs = {}
332 |         for pred_ind, pred in enumerate(student.predictors()):
333 |             data_key = None if self.multitask else pred_ind
334 |             job_key = (pred.lm, data_key)
335 |             if job_key not in grpo_training_jobs:
336 |                 train_kwargs = self.train_kwargs[pred.lm]
337 |                 job = pred.lm.reinforce(train_kwargs=train_kwargs, gpu_config=self.gpu_config)
338 |                 grpo_training_jobs[job_key] = job
339 | 
340 |         self.report_validation_metrics(
341 |             student=student,
342 |             trainset=trainset,
343 |             valset=valset,
344 |             logger=logger,
345 |             step_idx=-1,
346 |         )
347 | 
348 |         logger.info("Starting the GRPO training loop...")
349 |         for train_step_idx in range(self.num_train_steps):
350 |             logger.info(f"GRPO training step {train_step_idx + 1}/{self.num_train_steps}...")
351 | 
352 |             subsample_training_dataset = self.select_training_sample_and_update_shuffled_trainset(
353 |                 original_trainset=trainset,
354 |                 train_step_idx=train_step_idx,
355 |             )
356 | 
357 |             logger.info("Bootstrapping data...")
358 |             trace_data = [[[] for _ in range(len(teachers))] for _ in range(len(subsample_training_dataset))]
359 |             for tind, teacher in enumerate(teachers):
360 |                 subsample_training_dataset_repeated = [example for _ in range(num_samples_per_input) for example in subsample_training_dataset]
361 |                 round_data = bootstrap_trace_data(
362 |                     program=teacher,
363 |                     dataset=subsample_training_dataset_repeated,
364 |                     metric=self.metric,
365 |                     num_threads=self.num_threads,
366 |                     raise_on_error=False, # TODO(GRPO Team): This should be True, once the dspy format issue is fixed
367 |                     capture_failed_parses=True,
368 |                     failure_score=self.failure_score,
369 |                     format_failure_score=self.format_failure_score,
370 |                 )
371 |                 for data_dict in round_data:
372 |                     example_ind_in_subsample = data_dict["example_ind"] % len(subsample_training_dataset)
373 |                     data_dict["example_ind"] = example_ind_in_subsample
374 |                     trace_data[example_ind_in_subsample][tind].append(data_dict)
375 | 
376 |             # The trace_data for examples with FailedPrediction cases will have the signature at index 0, instead of the predictor
377 |             # We need to replace the signature with the predictor
378 | 
379 |             # At this point, trace_data: list[example_idx -> list[teacher_idx -> [num_samples_per_input * Dict(example, prediction, trace, example_ind, score)]]]
380 |             # Shape of trace is: [dspy_module_invocation_idx -> Tuple[Predictor, PredictorInputs, Prediction]]
381 |             self.validate_trace_data_and_log_issues(
382 |                 trace_data=trace_data,
383 |                 subsample_training_dataset=subsample_training_dataset,
384 |                 num_teachers=len(teachers),
385 |                 num_samples_per_input=num_samples_per_input,
386 |                 pred_signature_hash_to_ind=pred_signature_hash_to_ind,
387 |             )
388 | 
389 |             logger.info("Preparing the training data batch from bootstrapped examples for GRPO...")
390 |             # Now, we need to prepare batches of data to be sent for training
391 |             # Shape of train_batch_per_predictor: list[num_student_predictors -> list[ ]]
392 |             train_batch_per_predictor: list[list[GRPOGroup]] = [[] for _ in range(num_student_predictors)]
393 |             for pred_id in range(num_student_predictors):
394 |                 for example_ind, example_data in enumerate(trace_data):
395 |                     # Each example_data is a list of teacher_idx -> [num_samples_per_input * Dict(example, prediction, trace, example_ind, score)]
396 |                     # We need to flatten this list and create a batch for each predictor
397 | 
398 |                     # TODO(Lakshya, Omar, Noah): Discuss what to do with the same module being invoked multiple times within a single dspy.Example
399 |                     predictor_example_invocations: list[list[tuple]] = []
400 | 
401 |                     for teacher_data in example_data:
402 |                         for sample in teacher_data:
403 |                             # Each sample is a Dict(example, prediction, trace, example_ind, score)
404 |                             # sample['prediction'] is module_level prediction
405 |                             assert sample["example_ind"] == example_ind, f"Example index {sample['example_ind']} does not match the expected index {example_ind}"
406 | 
407 |                             trace_instances_for_current_pred = [(*t, sample["score"]) for t in sample["trace"] if hash(t[0].signature) == hash(student.predictors()[pred_id].signature)]
408 | 
409 |                             predictor_example_invocations.append(trace_instances_for_current_pred)
410 | 
411 |                     if len(predictor_example_invocations) == 0:
412 |                         logger.warning(f"Skipping example {example_ind} for predictor {pred_id} as it has no invocations. This is likely due to all examples in the training set input, resulting in the model generating output not following the dspy response format.")
413 |                         continue
414 |                     elif len(predictor_example_invocations) != self.num_rollouts_per_grpo_step:
415 |                         logger.warning(f"Number of predictor example invocations {len(predictor_example_invocations)} does not match the expected batch size {self.num_rollouts_per_grpo_step}. This is likely due to all examples in the training set input, resulting in the model generating output not following the dspy response format.")
416 | 
417 |                     min_len = min([len(predictor_example_invocations[i]) for i in range(len(predictor_example_invocations))])
418 |                     max_len = max([len(predictor_example_invocations[i]) for i in range(len(predictor_example_invocations))])
419 |                     if min_len == 0:
420 |                         logger.warning(f"Skipping example {example_ind} for predictor {pred_id} as it has no invocations.")
421 |                         continue
422 | 
423 |                     if self.variably_invoked_predictor_grouping_mode == "truncate":
424 |                         predictor_example_invocations = [invocation[:min_len] for invocation in predictor_example_invocations]
425 |                     elif self.variably_invoked_predictor_grouping_mode == "fill":
426 |                         if self.variably_invoked_predictor_fill_strategy == "randint":
427 |                             selector = lambda l: self.rng.choice(l) # noqa: E731, E741
428 |                         else:
429 |                             selector = lambda l: l[-1] # noqa: E731, E741
430 |                         predictor_example_invocations = [
431 |                             invocation + [selector(invocation) for _ in range(max_len - len(invocation))]
432 |                             for invocation in predictor_example_invocations
433 |                         ]
434 |                     else:
435 |                         assert self.variably_invoked_predictor_grouping_mode == "ragged", f"Unknown variably invoked predictor grouping mode {self.variably_invoked_predictor_grouping_mode}"
436 |                     max_len = max([len(predictor_example_invocations[i]) for i in range(len(predictor_example_invocations))])
437 | 
438 |                     example_training_data: list[GRPOGroup] = [[] for _ in range(max_len)]
439 | 
440 |                     for group_idx in range(max_len):
441 |                         for rollout_idx in range(len(predictor_example_invocations)):
442 |                             trace_instance = predictor_example_invocations[rollout_idx][group_idx]
443 |                             score = trace_instance[3]
444 |                             # for module_invocation_idx, trace_instance in enumerate(trace_instances_for_current_pred):
445 |                             # Each trace is a tuple of (Predictor, PredictorInputs, Prediction)
446 |                             trace_pred_id = pred_signature_hash_to_ind.get(hash(trace_instance[0].signature))
447 |                             assert trace_pred_id == pred_id
448 | 
449 |                             predictor = trace_instance[0]
450 |                             pred_lm = predictor.lm
451 |                             adapter = self.adapter[pred_lm] or settings.adapter or ChatAdapter()
452 |                             assert isinstance(adapter, ChatAdapter), f"Adapter {adapter} is not a ChatAdapter. GRPO training is not supported for this adapter."
453 |                             # TODO(Lakshya): Currently we exclude demos from the training data
454 |                             # TODO(GRPO Team): Use build_call_data_from_trace (from bootstrap_finetune) instead of
455 |                             # dealing with the message formatting ourselves.
456 |                             inp_messages = adapter.format(
457 |                                 signature=trace_instance[0].signature,
458 |                                 inputs=trace_instance[1],
459 |                                 demos=[] # TODO: Add support for demos
460 |                             )
461 | 
462 |                             if isinstance(trace_instance[2], FailedPrediction):
463 |                                 score = trace_instance[2].format_reward or self.format_failure_score
464 |                                 example_training_data[group_idx].append({
465 |                                     "messages": inp_messages,
466 |                                     "completion": {
467 |                                         "role": "assistant",
468 |                                         "content": trace_instance[2].completion_text,
469 |                                     },
470 |                                     "reward": float(score),
471 |                                 })
472 |                                 logger.warning(f"Adding a format failure example to the training data for predictor {pred_id} and example {example_ind}.")
473 |                             else:
474 |                                 all_messages = adapter.format_finetune_data(
475 |                                     signature=trace_instance[0].signature,
476 |                                     inputs=trace_instance[1],
477 |                                     outputs=trace_instance[2],
478 |                                     demos=[] # TODO: Add support for demos
479 |                                 )["messages"]
480 | 
481 |                                 assert all_messages[:-1] == inp_messages, f"Input messages {inp_messages} do not match the expected messages {all_messages[:-1]}"
482 | 
483 |                                 example_training_data[group_idx].append({
484 |                                     "messages": inp_messages,
485 |                                     "completion": {
486 |                                         "role": all_messages[-1]["role"],
487 |                                         "content": all_messages[-1]["content"],
488 |                                     },
489 |                                     "reward": float(score),
490 |                                 })
491 | 
492 |                     train_batch_per_predictor[pred_id].extend(example_training_data)
493 | 
494 |             if not any(train_batch_per_predictor):
495 |                 logger.warning("No training data found for this training step. This means that the model did not generate valid formatted responses for any of the examples in the training set. This is a critical error. Please check the model and the training set.")
496 |                 continue
497 | 
498 |             for predictor_train_batch in train_batch_per_predictor:
499 |                 for grpo_train_group in predictor_train_batch:
500 |                     if len(grpo_train_group) != self.num_rollouts_per_grpo_step:
501 |                         logger.warning(f"Number of completions {len(grpo_train_group)} does not match the expected number num_rollouts_per_grpo_step={self.num_rollouts_per_grpo_step}")
502 |                         assert len(grpo_train_group) <= self.num_rollouts_per_grpo_step, f"Number of completions {len(grpo_train_group)} is greater than the expected number num_rollouts_per_grpo_step={self.num_rollouts_per_grpo_step}"
503 |                     if len(set(map(repr, grpo_train_group))) < 2:
504 |                         # TODO(GRPO Team): How can we avoid this warning?
505 |                         logger.warning(f"GRPOGroup has no diversity. This could be due to low temperature, or low number of rollouts, or the cache could be enabled inadvertently. The GRPOGroup is {grpo_train_group}.")
506 | 
507 |             # We now run the GRPO step. Notes:
508 |             # * The job here has a reference to a particular M that's attached
509 |             #   to the student program. We update the .model field of this LM
510 |             #   inside the job, which also updates the LM in the student program
511 |             #   since these point to the same reference (along with any teacher
512 |             #   program that shares the same LM).
513 |             # * TODO(GRPO Team): This is inconsistent with how
514 |             #   BootstrapFinetune works, which creates new LM instances post
515 |             #   training. We should decide whether the LMs should be updated in
516 |             #   place or new LMs should be created, and standardize our approach
517 |             #   for both. If we decide to create new LMs, we should find a way
518 |             #   to update self.adapter and self.train_kwargs accordingly, in
519 |             #   addition to updating any teacher programs that share the same
520 |             #   LM.
521 |             logger.info("Invoking GRPO training step...")
522 |             for (_, data_key), job in grpo_training_jobs.items():
523 |                 train_data: list[GRPOGroup] = sum(train_batch_per_predictor, []) if data_key is None else train_batch_per_predictor[data_key] #noqa: RUF017
524 |                 for group in train_data:
525 |                     if len(group) != self.num_rollouts_per_grpo_step:
526 |                         # TODO(GRPO Team): This is very undesirable. This occurs only because in some of the generations, the model does not follow the correct dspy format.
527 |                         # The ideal solution is to identify the full response string in that predictor's group, and then assign  a high-negative (user-configurable) reward to that group.
528 | 
529 |                         # Pad the group to the expected number of generations by repeating the whole group, might require multiple iterations
530 |                         while len(group) < self.num_rollouts_per_grpo_step:
531 |                             group.extend(group[:min(self.num_rollouts_per_grpo_step - len(group), len(group))])
532 |                     assert len(group) == self.num_rollouts_per_grpo_step, f"Number of completions {len(group)} does not match the expected number self.num_rollouts_per_grpo_step={self.num_rollouts_per_grpo_step}"
533 | 
534 |                 job.step(train_data=train_data, train_data_format=TrainDataFormat.GRPO_CHAT)
535 | 
536 |             logger.info(f"GRPO training step {train_step_idx + 1}/{self.num_train_steps} completed.")
537 | 
538 |             self.report_validation_metrics(
539 |                 student=student,
540 |                 trainset=trainset,
541 |                 valset=valset,
542 |                 logger=logger,
543 |                 step_idx=train_step_idx,
544 |             )
545 | 
546 |         logger.info("Done with the iterations! Retrieving the final model(s)...")
547 |         for _, job in grpo_training_jobs.items():
548 |             job.terminate()
549 | 
550 |         # Revert cache states to their initial values
551 |         recover_lm_cache(program=student, lm_cache_dict=lm_cache_dict)
552 |         for t in teachers:
553 |             recover_lm_cache(program=t, lm_cache_dict=lm_cache_dict)
554 | 
555 |         logger.info("GRPO compiler has finished compiling the student program")
556 |         student._compiled = True
557 |         return student
558 | 
559 | 
560 | def disable_lm_cache(program: Module, lm_cache_dict: dict):
561 |     """Disable the LM cache for all predictors in the program."""
562 |     for pred in program.predictors():
563 |         if not pred.lm:
564 |             raise ValueError(f"Cannot disable cache: predictor {pred} does not have an LM set.")
565 |         if pred.lm not in lm_cache_dict:  # Check to avoid overwriting the cache
566 |             lm_cache_dict[pred.lm] = pred.lm.cache  # remember the original setting, keyed by the LM object, the first time that LM is seen
567 |         pred.lm.cache = False
568 | 
569 | 
570 | def recover_lm_cache(program: Module, lm_cache_dict: dict):
571 |     """Recover the LM caches for all predictors in the program to their original state."""
572 |     for pred in program.predictors():
573 |         if pred.lm in lm_cache_dict:
574 |             pred.lm.cache = lm_cache_dict[pred.lm]  # restore the value stored in lm_cache_dict for this LM
575 |         else:
576 |             # We do not expect this branch to execute at all since all the LMs
577 |             # are modified in place and no new LMs are created during training.
578 |             # However, we do not complain if this happens since this is a
579 |             # relatively minor feature. We default the LM cache to True.
580 |             pred.lm.cache = True
581 | 
```

--------------------------------------------------------------------------------
/tests/adapters/test_json_adapter.py:
--------------------------------------------------------------------------------

```python
  1 | from unittest import mock
  2 | 
  3 | import pydantic
  4 | import pytest
  5 | from litellm.types.llms.openai import ResponseAPIUsage, ResponsesAPIResponse
  6 | from litellm.utils import ChatCompletionMessageToolCall, Choices, Function, Message, ModelResponse
  7 | from openai.types.responses import ResponseOutputMessage
  8 | 
  9 | import dspy
 10 | 
 11 | 
 12 | def test_json_adapter_passes_structured_output_when_supported_by_model():  # Predict under JSONAdapter should hand litellm.completion a pydantic model as response_format.
 13 |     class OutputField3(pydantic.BaseModel):
 14 |         subfield1: int = pydantic.Field(description="Int subfield 1", ge=0, le=10)
 15 |         subfield2: float = pydantic.Field(description="Float subfield 2")
 16 | 
 17 |     class TestSignature(dspy.Signature):
 18 |         input1: str = dspy.InputField()
 19 |         output1: str = dspy.OutputField()  # Description intentionally left blank
 20 |         output2: bool = dspy.OutputField(desc="Boolean output field")
 21 |         output3: OutputField3 = dspy.OutputField(desc="Nested output field")
 22 |         output4_unannotated = dspy.OutputField(desc="Unannotated output field")
 23 | 
 24 |     program = dspy.Predict(TestSignature)
 25 | 
 26 |     # Configure DSPy to use an OpenAI LM that supports structured outputs
 27 |     dspy.configure(lm=dspy.LM(model="openai/gpt-4o"), adapter=dspy.JSONAdapter())
 28 |     with mock.patch("litellm.completion") as mock_completion:
 29 |         program(input1="Test input")
 30 | 
 31 |     def clean_schema_extra(field_name, field_info):  # NOTE(review): this helper is never called within this test
 32 |         attrs = dict(field_info.__repr_args__())
 33 |         if "json_schema_extra" in attrs:
 34 |             attrs["json_schema_extra"] = {
 35 |                 k: v
 36 |                 for k, v in attrs["json_schema_extra"].items()
 37 |                 if k != "__dspy_field_type" and not (k == "desc" and v == f"${{{field_name}}}")
 38 |             }
 39 |         return attrs
 40 | 
 41 |     mock_completion.assert_called_once()
 42 |     _, call_kwargs = mock_completion.call_args
 43 |     response_format = call_kwargs.get("response_format")
 44 |     assert response_format is not None
 45 |     assert issubclass(response_format, pydantic.BaseModel)
 46 |     assert response_format.model_fields.keys() == {"output1", "output2", "output3", "output4_unannotated"}  # exactly the signature's output fields
 47 | 
 48 | 
 49 | def test_json_adapter_not_using_structured_outputs_when_not_supported_by_model():  # With an unknown provider, no response_format should be sent to litellm.completion.
 50 |     class TestSignature(dspy.Signature):
 51 |         input1: str = dspy.InputField()
 52 |         output1: str = dspy.OutputField()
 53 |         output2: bool = dspy.OutputField()
 54 | 
 55 |     program = dspy.Predict(TestSignature)
 56 | 
 57 |     # Configure DSPy to use a model from a fake provider that doesn't support structured outputs
 58 |     dspy.configure(lm=dspy.LM(model="fakeprovider/fakemodel", cache=False), adapter=dspy.JSONAdapter())
 59 |     with mock.patch("litellm.completion") as mock_completion:
 60 |         mock_completion.return_value = ModelResponse(  # canned reply that carries both output fields
 61 |             choices=[Choices(message=Message(content=("{'output1': 'Test output', 'output2': True}")))],
 62 |             model="openai/gpt-4o",
 63 |         )
 64 | 
 65 |         program(input1="Test input")
 66 | 
 67 |     mock_completion.assert_called_once()
 68 |     _, call_kwargs = mock_completion.call_args
 69 |     assert "response_format" not in call_kwargs
 70 | 
 71 | 
 72 | def test_json_adapter_falls_back_when_structured_outputs_fails():  # A failing structured-output call should be retried with the plain json_object response format.
 73 |     class TestSignature(dspy.Signature):
 74 |         input1: str = dspy.InputField()
 75 |         output1: str = dspy.OutputField(desc="String output field")
 76 | 
 77 |     dspy.configure(lm=dspy.LM(model="openai/gpt-4o", cache=False), adapter=dspy.JSONAdapter())
 78 |     program = dspy.Predict(TestSignature)
 79 |     with mock.patch("litellm.completion") as mock_completion:
 80 |         mock_completion.side_effect = [Exception("Bad structured outputs!"), mock_completion.return_value]  # first call raises, second returns the default mock value
 81 |         program(input1="Test input")
 82 |         assert mock_completion.call_count == 2
 83 |         _, first_call_kwargs = mock_completion.call_args_list[0]
 84 |         assert issubclass(first_call_kwargs.get("response_format"), pydantic.BaseModel)  # first attempt used a pydantic model
 85 |         _, second_call_kwargs = mock_completion.call_args_list[1]
 86 |         assert second_call_kwargs.get("response_format") == {"type": "json_object"}  # retry used the generic JSON mode
 87 | 
 88 | 
 89 | def test_json_adapter_with_structured_outputs_does_not_mutate_original_signature():  # After a call, the program's output fields must still equal those of TestSignature.
 90 |     class OutputField3(pydantic.BaseModel):
 91 |         subfield1: int = pydantic.Field(description="Int subfield 1")
 92 |         subfield2: float = pydantic.Field(description="Float subfield 2")
 93 | 
 94 |     class TestSignature(dspy.Signature):
 95 |         input1: str = dspy.InputField()
 96 |         output1: str = dspy.OutputField()  # Description intentionally left blank
 97 |         output2: bool = dspy.OutputField(desc="Boolean output field")
 98 |         output3: OutputField3 = dspy.OutputField(desc="Nested output field")
 99 |         output4_unannotated = dspy.OutputField(desc="Unannotated output field")
100 | 
101 |     dspy.configure(lm=dspy.LM(model="openai/gpt-4o"), adapter=dspy.JSONAdapter())
102 |     program = dspy.Predict(TestSignature)
103 |     with mock.patch("litellm.completion"):  # the mocked reply is irrelevant; only the call's effect on the signature is checked
104 |         program(input1="Test input")
105 | 
106 |     assert program.signature.output_fields == TestSignature.output_fields
107 | 
108 | 
109 | def test_json_adapter_sync_call():  # Calling the adapter directly with a DummyLM should return the parsed answer.
110 |     signature = dspy.make_signature("question->answer")
111 |     adapter = dspy.JSONAdapter()
112 |     lm = dspy.utils.DummyLM([{"answer": "Paris"}], adapter=adapter)
113 |     with dspy.context(adapter=adapter):
114 |         result = adapter(lm, {}, signature, [], {"question": "What is the capital of France?"})
115 |     assert result == [{"answer": "Paris"}]
116 | 
117 | 
118 | @pytest.mark.asyncio
119 | async def test_json_adapter_async_call():  # Same check as the sync test, going through adapter.acall.
120 |     signature = dspy.make_signature("question->answer")
121 |     adapter = dspy.JSONAdapter()
122 |     lm = dspy.utils.DummyLM([{"answer": "Paris"}], adapter=adapter)
123 |     with dspy.context(adapter=adapter):
124 |         result = await adapter.acall(lm, {}, signature, [], {"question": "What is the capital of France?"})
125 |     assert result == [{"answer": "Paris"}]
126 | 
127 | 
def test_json_adapter_on_pydantic_model():
    """Inspect the prompt JSONAdapter builds for pydantic-typed input and output fields.

    `litellm.completion` is patched to return a canned response; the test then
    checks the system and user messages passed to it, and the parsed prediction.
    """
    from litellm.utils import Choices, Message, ModelResponse

    class User(pydantic.BaseModel):
        id: int
        name: str
        email: str

    class Answer(pydantic.BaseModel):
        analysis: str
        result: str

    class TestSignature(dspy.Signature):
        user: User = dspy.InputField(desc="The user who asks the question")
        question: str = dspy.InputField(desc="Question the user asks")
        answer: Answer = dspy.OutputField(desc="Answer to this question")

    program = dspy.Predict(TestSignature)

    dspy.configure(lm=dspy.LM(model="openai/gpt-4o", cache=False), adapter=dspy.JSONAdapter())

    canned_message = Message(
        content="{'answer': {'analysis': 'Paris is the capital of France', 'result': 'Paris'}}"
    )
    canned_response = ModelResponse(choices=[Choices(message=canned_message)], model="openai/gpt-4o")

    with mock.patch("litellm.completion") as mock_completion:
        mock_completion.return_value = canned_response
        prediction = program(
            user={"id": 5, "name": "name_test", "email": "email_test"}, question="What is the capital of France?"
        )

    # The patched litellm.completion must have been invoked exactly once.
    mock_completion.assert_called_once()

    sent_messages = mock_completion.call_args.kwargs["messages"]
    # Exactly two messages are expected: system + user.
    assert len(sent_messages) == 2
    system_message, user_message = sent_messages

    assert system_message["role"] == "system"
    system_prompt = system_message["content"]
    assert system_prompt is not None

    # Input field descriptions listed in the system prompt.
    expected_input_fields = (
        "1. `user` (User): The user who asks the question\n2. `question` (str): Question the user asks\n"
    )
    assert expected_input_fields in system_prompt

    # Output field description listed in the system prompt.
    expected_output_fields = "1. `answer` (Answer): Answer to this question\n"
    assert expected_output_fields in system_prompt

    # Input formatting structure shown in the system prompt.
    expected_input_structure = "[[ ## user ## ]]\n{user}\n\n[[ ## question ## ]]\n{question}\n\n"
    assert expected_input_structure in system_prompt

    # Output formatting structure, including the JSON schema of `Answer`.
    expected_output_structure = (
        "Outputs will be a JSON object with the following fields.\n\n{\n  "
        '"answer": "{answer}        # note: the value you produce must adhere to the JSON schema: '
        '{\\"type\\": \\"object\\", \\"properties\\": {\\"analysis\\": {\\"type\\": \\"string\\", \\"title\\": '
        '\\"Analysis\\"}, \\"result\\": {\\"type\\": \\"string\\", \\"title\\": \\"Result\\"}}, \\"required\\": '
        '[\\"analysis\\", \\"result\\"], \\"title\\": \\"Answer\\"}"\n}'
    )
    assert expected_output_structure in system_prompt

    assert user_message["role"] == "user"
    user_prompt = user_message["content"]
    assert user_prompt is not None

    # The concrete input values are serialized into the user message.
    expected_input_data = (
        '[[ ## user ## ]]\n{"id": 5, "name": "name_test", "email": "email_test"}\n\n[[ ## question ## ]]\n'
        "What is the capital of France?\n\n"
    )
    assert expected_input_data in user_prompt

    # The canned response is parsed into the nested `Answer` fields.
    parsed_answer = prediction.answer
    assert parsed_answer.analysis == "Paris is the capital of France"
    assert parsed_answer.result == "Paris"
213 | 
214 | 
def test_json_adapter_parse_raise_error_on_mismatch_fields():
    """An LM response whose keys do not match the signature raises AdapterParseError."""
    qa_signature = dspy.make_signature("question->answer")
    json_adapter = dspy.JSONAdapter()
    # The response carries `answer1` while the signature expects `answer`.
    mismatched_completion = "{'answer1': 'Paris'}"

    with mock.patch("litellm.completion") as mock_completion:
        mock_completion.return_value = ModelResponse(
            choices=[Choices(message=Message(content=mismatched_completion))],
            model="openai/gpt-4o",
        )
        lm = dspy.LM(model="openai/gpt-4o-mini")
        with pytest.raises(dspy.utils.exceptions.AdapterParseError) as e:
            json_adapter(lm, {}, qa_signature, [], {"question": "What is the capital of France?"})

    parse_error = e.value
    assert parse_error.adapter_name == "JSONAdapter"
    assert parse_error.signature == qa_signature
    assert parse_error.lm_response == mismatched_completion
    assert parse_error.parsed_result == {}

    expected_message = (
        "Adapter JSONAdapter failed to parse the LM response. \n\n"
        "LM Response: {'answer1': 'Paris'} \n\n"
        "Expected to find output fields in the LM response: [answer] \n\n"
        "Actual output fields parsed from the LM response: [] \n\n"
    )
    assert str(parse_error) == expected_message
240 | 
241 | 
def test_json_adapter_formats_image():
    """A `dspy.Image` input appears as an `image_url` chunk in the formatted user message."""

    class MySignature(dspy.Signature):
        image: dspy.Image = dspy.InputField()
        text: str = dspy.OutputField()

    # Basic case: a single image passed directly as the input value.
    input_image = dspy.Image(url="https://example.com/image.jpg")
    messages = dspy.JSONAdapter().format(MySignature, [], {"image": input_image})

    assert len(messages) == 2
    user_chunks = messages[1]["content"]
    assert user_chunks is not None

    # Three chunks are expected, of types text, image_url, text.
    assert len(user_chunks) == 3
    for text_position in (0, 2):
        assert user_chunks[text_position]["type"] == "text"

    # The image chunk must carry the original URL.
    assert {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}} in user_chunks
265 | 
266 | 
def test_json_adapter_formats_image_with_few_shot_examples():
    """Images in few-shot demos and in the query each land in their own user message."""

    class MySignature(dspy.Signature):
        image: dspy.Image = dspy.InputField()
        text: str = dspy.OutputField()

    demo_url_1 = "https://example.com/image1.jpg"
    demo_url_2 = "https://example.com/image2.jpg"
    query_url = "https://example.com/image3.jpg"

    few_shot_demos = [
        dspy.Example(image=dspy.Image(url=demo_url_1), text="This is a test image"),
        dspy.Example(image=dspy.Image(url=demo_url_2), text="This is another test image"),
    ]

    json_adapter = dspy.JSONAdapter()
    messages = json_adapter.format(MySignature, few_shot_demos, {"image": dspy.Image(url=query_url)})

    # 1 system message + 2 demos (a user and an assistant message each) + 1 final user message.
    assert len(messages) == 6

    # User messages sit at the odd indices: demo 1, demo 2, then the query.
    for message_index, url in ((1, demo_url_1), (3, demo_url_2), (5, query_url)):
        assert {"type": "image_url", "image_url": {"url": url}} in messages[message_index]["content"]
292 | 
293 | 
def test_json_adapter_formats_image_with_nested_images():
    """Images nested inside a pydantic model are each emitted as an `image_url` chunk."""

    class ImageWrapper(pydantic.BaseModel):
        images: list[dspy.Image]
        tag: list[str]

    class MySignature(dspy.Signature):
        image: ImageWrapper = dspy.InputField()
        text: str = dspy.OutputField()

    image_urls = [
        "https://example.com/image1.jpg",
        "https://example.com/image2.jpg",
        "https://example.com/image3.jpg",
    ]
    wrapped_images = ImageWrapper(
        images=[dspy.Image(url=url) for url in image_urls],
        tag=["test", "example"],
    )

    messages = dspy.JSONAdapter().format(MySignature, [], {"image": wrapped_images})

    # Every nested image must show up in the user message content.
    user_content = messages[1]["content"]
    for url in image_urls:
        assert {"type": "image_url", "image_url": {"url": url}} in user_content
319 | 
320 | 
def test_json_adapter_formats_with_nested_documents():
    """Documents nested inside a pydantic model are each emitted as a `document` chunk."""

    class DocumentWrapper(pydantic.BaseModel):
        documents: list[dspy.experimental.Document]

    class MySignature(dspy.Signature):
        document: DocumentWrapper = dspy.InputField()
        text: str = dspy.OutputField()

    document_texts = ["Hello, world!", "Hello, world 2!"]
    wrapped_documents = DocumentWrapper(
        documents=[dspy.experimental.Document(data=text) for text in document_texts]
    )

    messages = dspy.JSONAdapter().format(MySignature, [], {"document": wrapped_documents})

    # Each document is expected as a plain-text source with citations enabled.
    user_content = messages[1]["content"]
    for text in document_texts:
        expected_chunk = {
            "type": "document",
            "source": {"type": "text", "media_type": "text/plain", "data": text},
            "citations": {"enabled": True},
        }
        assert expected_chunk in user_content
342 | 
343 | 
def test_json_adapter_formats_image_with_few_shot_examples_with_nested_images():
    """Nested images are formatted both in a few-shot demo and in the final query message."""

    class ImageWrapper(pydantic.BaseModel):
        images: list[dspy.Image]
        tag: list[str]

    class MySignature(dspy.Signature):
        image: ImageWrapper = dspy.InputField()
        text: str = dspy.OutputField()

    demo_urls = [
        "https://example.com/image1.jpg",
        "https://example.com/image2.jpg",
        "https://example.com/image3.jpg",
    ]
    query_url = "https://example.com/image4.jpg"

    demo_wrapper = ImageWrapper(images=[dspy.Image(url=url) for url in demo_urls], tag=["test", "example"])
    few_shot_demos = [dspy.Example(image=demo_wrapper, text="This is a test image")]
    query_wrapper = ImageWrapper(images=[dspy.Image(url=query_url)], tag=["test", "example"])

    messages = dspy.JSONAdapter().format(MySignature, few_shot_demos, {"image": query_wrapper})

    assert len(messages) == 4

    # The demo's images belong to the few-shot example's user message.
    demo_user_content = messages[1]["content"]
    for url in demo_urls:
        assert {"type": "image_url", "image_url": {"url": url}} in demo_user_content

    # The query image is formatted in the last user message.
    assert {"type": "image_url", "image_url": {"url": query_url}} in messages[-1]["content"]
381 | 
382 | 
def test_json_adapter_with_tool():
    """Tools given as an input field show up in the prompt and in the `tools` LM argument."""

    class MySignature(dspy.Signature):
        """Answer question with the help of the tools"""

        question: str = dspy.InputField()
        tools: list[dspy.Tool] = dspy.InputField()
        answer: str = dspy.OutputField()
        tool_calls: dspy.ToolCalls = dspy.OutputField()

    def get_weather(city: str) -> str:
        """Get the weather for a city"""
        return f"The weather in {city} is sunny"

    def get_population(country: str, year: int) -> str:
        """Get the population for a country"""
        return f"The population of {country} in {year} is 1000000"

    tools = [dspy.Tool(get_weather), dspy.Tool(get_population)]
    inputs = {"question": "What is the weather in Tokyo?", "tools": tools}

    json_adapter = dspy.JSONAdapter()
    messages = json_adapter.format(MySignature, [], inputs)

    assert len(messages) == 2
    system_prompt = messages[0]["content"]
    user_prompt = messages[1]["content"]

    # The ToolCalls type description must be part of the system message.
    assert dspy.ToolCalls.description() in system_prompt

    # The user message carries the question, both tool names and both argument schemas.
    expected_user_fragments = (
        "What is the weather in Tokyo?",
        "get_weather",
        "get_population",
        "{'city': {'type': 'string'}}",
        "{'country': {'type': 'string'}, 'year': {'type': 'integer'}}",
    )
    for fragment in expected_user_fragments:
        assert fragment in user_prompt

    with mock.patch("litellm.completion") as mock_completion:
        lm = dspy.LM(model="openai/gpt-4o-mini")
        json_adapter(lm, {}, MySignature, [], inputs)

    mock_completion.assert_called_once()
    sent_tools = mock_completion.call_args.kwargs["tools"]

    # The tools are forwarded through the `tools` argument as function specs.
    expected_weather_spec = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
    expected_population_spec = {
        "type": "function",
        "function": {
            "name": "get_population",
            "description": "Get the population for a country",
            "parameters": {
                "type": "object",
                "properties": {
                    "country": {"type": "string"},
                    "year": {"type": "integer"},
                },
                "required": ["country", "year"],
            },
        },
    }
    assert len(sent_tools) > 0
    assert sent_tools[0] == expected_weather_spec
    assert sent_tools[1] == expected_population_spec
463 | 
464 | 
def test_json_adapter_with_code():
    """`dspy.Code` works as an input field (formatting) and as an output field (parsing)."""

    # Part 1: code as an input field.
    class CodeAnalysis(dspy.Signature):
        """Analyze the time complexity of the code"""

        code: dspy.Code = dspy.InputField()
        result: str = dspy.OutputField()

    formatting_adapter = dspy.JSONAdapter()
    messages = formatting_adapter.format(CodeAnalysis, [], {"code": "print('Hello, world!')"})

    assert len(messages) == 2
    system_prompt, user_prompt = messages[0]["content"], messages[1]["content"]

    # The Code type description must be part of the system message.
    assert dspy.Code.description() in system_prompt

    # The user message must contain the code snippet that was passed in.
    assert "print('Hello, world!')" in user_prompt

    # Part 2: code as an output field.
    class CodeGeneration(dspy.Signature):
        """Generate code to answer the question"""

        question: str = dspy.InputField()
        code: dspy.Code = dspy.OutputField()

    parsing_adapter = dspy.JSONAdapter()
    with mock.patch("litellm.completion") as mock_completion:
        mock_completion.return_value = ModelResponse(
            choices=[Choices(message=Message(content="{'code': 'print(\"Hello, world!\")'}"))],
            model="openai/gpt-4o-mini",
        )
        outputs = parsing_adapter(
            dspy.LM(model="openai/gpt-4o-mini", cache=False),
            {},
            CodeGeneration,
            [],
            {"question": "Write a python program to print 'Hello, world!'"},
        )

    # The parsed output exposes the generated source through `.code`.
    generated = outputs[0]["code"]
    assert generated.code == 'print("Hello, world!")'
505 | 
506 | 
def test_json_adapter_formats_conversation_history():
    """A `dspy.History` input is expanded into alternating question/answer messages."""

    class MySignature(dspy.Signature):
        question: str = dspy.InputField()
        history: dspy.History = dspy.InputField()
        answer: str = dspy.OutputField()

    past_turns = [
        {"question": "What is the capital of France?", "answer": "Paris"},
        {"question": "What is the capital of Germany?", "answer": "Berlin"},
    ]
    inputs = {
        "question": "What is the capital of France?",
        "history": dspy.History(messages=past_turns),
    }

    messages = dspy.JSONAdapter().format(MySignature, [], inputs)

    assert len(messages) == 6

    # Messages 1-4 hold the two history turns: question, then JSON answer, per turn.
    expected_history_contents = [
        "[[ ## question ## ]]\nWhat is the capital of France?",
        '{\n  "answer": "Paris"\n}',
        "[[ ## question ## ]]\nWhat is the capital of Germany?",
        '{\n  "answer": "Berlin"\n}',
    ]
    for message_index, expected_content in enumerate(expected_history_contents, start=1):
        assert messages[message_index]["content"] == expected_content
528 | 
529 | 
@pytest.mark.asyncio
async def test_json_adapter_on_pydantic_model_async():
    """Async variant: inspect the prompt JSONAdapter builds for pydantic-typed fields.

    `litellm.acompletion` is patched to return a canned response; the test then
    checks the system and user messages passed to it, and the parsed prediction.
    """
    from litellm.utils import Choices, Message, ModelResponse

    class User(pydantic.BaseModel):
        id: int
        name: str
        email: str

    class Answer(pydantic.BaseModel):
        analysis: str
        result: str

    class TestSignature(dspy.Signature):
        user: User = dspy.InputField(desc="The user who asks the question")
        question: str = dspy.InputField(desc="Question the user asks")
        answer: Answer = dspy.OutputField(desc="Answer to this question")

    program = dspy.Predict(TestSignature)

    canned_message = Message(
        content="{'answer': {'analysis': 'Paris is the capital of France', 'result': 'Paris'}}"
    )
    canned_response = ModelResponse(choices=[Choices(message=canned_message)], model="openai/gpt-4o")

    with mock.patch("litellm.acompletion") as mock_completion:
        mock_completion.return_value = canned_response

        with dspy.context(lm=dspy.LM(model="openai/gpt-4o", cache=False), adapter=dspy.JSONAdapter()):
            prediction = await program.acall(
                user={"id": 5, "name": "name_test", "email": "email_test"}, question="What is the capital of France?"
            )

    # The patched litellm.acompletion must have been invoked exactly once.
    mock_completion.assert_called_once()

    sent_messages = mock_completion.call_args.kwargs["messages"]
    # Exactly two messages are expected: system + user.
    assert len(sent_messages) == 2
    system_message, user_message = sent_messages

    assert system_message["role"] == "system"
    system_prompt = system_message["content"]
    assert system_prompt is not None

    # Input field descriptions listed in the system prompt.
    expected_input_fields = (
        "1. `user` (User): The user who asks the question\n2. `question` (str): Question the user asks\n"
    )
    assert expected_input_fields in system_prompt

    # Output field description listed in the system prompt.
    expected_output_fields = "1. `answer` (Answer): Answer to this question\n"
    assert expected_output_fields in system_prompt

    # Input formatting structure shown in the system prompt.
    expected_input_structure = "[[ ## user ## ]]\n{user}\n\n[[ ## question ## ]]\n{question}\n\n"
    assert expected_input_structure in system_prompt

    # Output formatting structure, including the JSON schema of `Answer`.
    expected_output_structure = (
        "Outputs will be a JSON object with the following fields.\n\n{\n  "
        '"answer": "{answer}        # note: the value you produce must adhere to the JSON schema: '
        '{\\"type\\": \\"object\\", \\"properties\\": {\\"analysis\\": {\\"type\\": \\"string\\", \\"title\\": '
        '\\"Analysis\\"}, \\"result\\": {\\"type\\": \\"string\\", \\"title\\": \\"Result\\"}}, \\"required\\": '
        '[\\"analysis\\", \\"result\\"], \\"title\\": \\"Answer\\"}"\n}'
    )
    assert expected_output_structure in system_prompt

    assert user_message["role"] == "user"
    user_prompt = user_message["content"]
    assert user_prompt is not None

    # The concrete input values are serialized into the user message.
    expected_input_data = (
        '[[ ## user ## ]]\n{"id": 5, "name": "name_test", "email": "email_test"}\n\n[[ ## question ## ]]\n'
        "What is the capital of France?\n\n"
    )
    assert expected_input_data in user_prompt

    # The canned response is parsed into the nested `Answer` fields.
    parsed_answer = prediction.answer
    assert parsed_answer.analysis == "Paris is the capital of France"
    assert parsed_answer.result == "Paris"
616 | 
617 | 
def test_json_adapter_fallback_to_json_mode_on_structured_output_failure():
    """When the first LM call raises, JSONAdapter retries with `{"type": "json_object"}`."""

    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        answer: str = dspy.OutputField(desc="String output field")

    dspy.configure(lm=dspy.LM(model="openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter())
    program = dspy.Predict(TestSignature)

    # First call fails (simulated structured-output failure); second call succeeds.
    scripted_outcomes = [
        RuntimeError("Structured output failed!"),
        ModelResponse(choices=[Choices(message=Message(content="{'answer': 'Test output'}"))]),
    ]

    with mock.patch("litellm.completion") as mock_completion:
        mock_completion.side_effect = scripted_outcomes
        prediction = program(question="Dummy question!")

    # Parsing succeeds on the second call.
    assert mock_completion.call_count == 2
    assert prediction.answer == "Test output"

    first_call, second_call = mock_completion.call_args_list

    # The first attempt requested structured output via a pydantic model class.
    assert issubclass(first_call.kwargs.get("response_format"), pydantic.BaseModel)

    # The retry used plain JSON mode.
    assert second_call.kwargs.get("response_format") == {"type": "json_object"}
645 | 
def test_json_adapter_json_mode_no_structured_outputs():
    """With `response_format` supported but response schemas unsupported, JSON mode is used."""

    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        answer: str = dspy.OutputField(desc="String output field")

    dspy.configure(lm=dspy.LM(model="openai/gpt-4o", cache=False), adapter=dspy.JSONAdapter())
    program = dspy.Predict(TestSignature)

    completion_patch = mock.patch("litellm.completion")
    supported_params_patch = mock.patch("litellm.get_supported_openai_params")
    response_schema_patch = mock.patch("litellm.supports_response_schema")

    with completion_patch as mock_completion, supported_params_patch as mock_get_supported_openai_params, response_schema_patch as mock_supports_response_schema:
        # Simulate a model that allows JSON mode but not structured outputs.
        mock_completion.return_value = ModelResponse(
            choices=[Choices(message=Message(content="{'answer': 'Test output'}"))]
        )
        mock_get_supported_openai_params.return_value = ["response_format"]
        mock_supports_response_schema.return_value = False

        prediction = program(question="Dummy question!")

        # A single call is enough: no structured-output attempt precedes it.
        assert mock_completion.call_count == 1
        assert prediction.answer == "Test output"

        only_call = mock_completion.call_args_list[0]
        assert only_call.kwargs.get("response_format") == {"type": "json_object"}
669 | 
670 | 
@pytest.mark.asyncio
async def test_json_adapter_json_mode_no_structured_outputs_async():
    """Async variant: JSON mode is used when response schemas are reported unsupported."""

    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        answer: str = dspy.OutputField(desc="String output field")

    program = dspy.Predict(TestSignature)

    acompletion_patch = mock.patch("litellm.acompletion")
    supported_params_patch = mock.patch("litellm.get_supported_openai_params")
    response_schema_patch = mock.patch("litellm.supports_response_schema")

    with acompletion_patch as mock_acompletion, supported_params_patch as mock_get_supported_openai_params, response_schema_patch as mock_supports_response_schema:
        # Simulate a model that allows JSON mode but not structured outputs.
        mock_acompletion.return_value = ModelResponse(
            choices=[Choices(message=Message(content="{'answer': 'Test output'}"))]
        )
        mock_get_supported_openai_params.return_value = ["response_format"]
        mock_supports_response_schema.return_value = False

        with dspy.context(lm=dspy.LM(model="openai/gpt-4o", cache=False), adapter=dspy.JSONAdapter()):
            prediction = await program.acall(question="Dummy question!")

        # A single call is enough: no structured-output attempt precedes it.
        assert mock_acompletion.call_count == 1
        assert prediction.answer == "Test output"

        only_call = mock_acompletion.call_args_list[0]
        assert only_call.kwargs.get("response_format") == {"type": "json_object"}
695 | 
696 | 
@pytest.mark.asyncio
async def test_json_adapter_fallback_to_json_mode_on_structured_output_failure_async():
    """Async variant: a failing first LM call is retried with `{"type": "json_object"}`."""

    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        answer: str = dspy.OutputField(desc="String output field")

    program = dspy.Predict(TestSignature)

    # First call fails (simulated structured-output failure); second call succeeds.
    scripted_outcomes = [
        RuntimeError("Structured output failed!"),
        ModelResponse(choices=[Choices(message=Message(content="{'answer': 'Test output'}"))]),
    ]

    with mock.patch("litellm.acompletion") as mock_acompletion:
        mock_acompletion.side_effect = scripted_outcomes

        with dspy.context(lm=dspy.LM(model="openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()):
            prediction = await program.acall(question="Dummy question!")

    # Parsing succeeds on the second call.
    assert mock_acompletion.call_count == 2
    assert prediction.answer == "Test output"

    first_call, second_call = mock_acompletion.call_args_list

    # The first attempt requested structured output via a pydantic model class.
    assert issubclass(first_call.kwargs.get("response_format"), pydantic.BaseModel)

    # The retry used plain JSON mode.
    assert second_call.kwargs.get("response_format") == {"type": "json_object"}
725 | 
726 | 
def test_error_message_on_json_adapter_failure():
    """Errors raised by the LM call propagate with their original type and message."""

    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        answer: str = dspy.OutputField(desc="String output field")

    program = dspy.Predict(TestSignature)

    dspy.configure(lm=dspy.LM(model="openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter())

    failure_cases = (
        (RuntimeError, "RuntimeError!"),
        (ValueError, "ValueError!"),
    )

    with mock.patch("litellm.completion") as mock_completion:
        for error_type, error_text in failure_cases:
            # Every call to the patched completion raises this error.
            mock_completion.side_effect = error_type(error_text)

            with pytest.raises(error_type) as error:
                program(question="Dummy question!")

            assert error_text in str(error.value)
749 | 
750 | 
@pytest.mark.asyncio
async def test_error_message_on_json_adapter_failure_async():
    """Async variant: LM call errors propagate with their original type and message."""

    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        answer: str = dspy.OutputField(desc="String output field")

    program = dspy.Predict(TestSignature)

    failure_cases = (
        (RuntimeError, "RuntimeError!"),
        (ValueError, "ValueError!"),
    )

    with mock.patch("litellm.acompletion") as mock_acompletion:
        with dspy.context(lm=dspy.LM(model="openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()):
            for error_type, error_text in failure_cases:
                # Every call to the patched acompletion raises this error.
                mock_acompletion.side_effect = error_type(error_text)

                with pytest.raises(error_type) as error:
                    await program.acall(question="Dummy question!")

                assert error_text in str(error.value)
772 | 
773 | 
def test_json_adapter_toolcalls_native_function_calling():
    """With native function calling on, tool calls and plain content are both parsed.

    Two responses are exercised: one carrying only tool calls (content is None)
    and one carrying only JSON content (no tool calls).
    """

    class MySignature(dspy.Signature):
        question: str = dspy.InputField()
        tools: list[dspy.Tool] = dspy.InputField()
        answer: str = dspy.OutputField()
        tool_calls: dspy.ToolCalls = dspy.OutputField()

    def get_weather(city: str) -> str:
        return f"The weather in {city} is sunny"

    tools = [dspy.Tool(get_weather)]
    inputs = {"question": "What is the weather in Paris?", "tools": tools}

    json_adapter = dspy.JSONAdapter(use_native_function_calling=True)

    # Case 1: the response carries tool calls and its content is None.
    native_tool_call = ChatCompletionMessageToolCall(
        function=Function(arguments='{"city":"Paris"}', name="get_weather"),
        id="call_pQm8ajtSMxgA0nrzK2ivFmxG",
        type="function",
    )
    tool_call_response = ModelResponse(
        choices=[
            Choices(
                finish_reason="tool_calls",
                index=0,
                message=Message(content=None, role="assistant", tool_calls=[native_tool_call]),
            ),
        ],
        model="openai/gpt-4o-mini",
    )
    with mock.patch("litellm.completion") as mock_completion:
        mock_completion.return_value = tool_call_response
        outputs = json_adapter(dspy.LM(model="openai/gpt-4o-mini", cache=False), {}, MySignature, [], inputs)

    expected_tool_calls = dspy.ToolCalls(
        tool_calls=[dspy.ToolCalls.ToolCall(name="get_weather", args={"city": "Paris"})]
    )
    assert outputs[0]["tool_calls"] == expected_tool_calls
    # No `answer` was produced, so it is reported as None.
    assert outputs[0]["answer"] is None

    # Case 2: the response carries JSON content and no tool calls.
    content_only_response = ModelResponse(
        choices=[Choices(message=Message(content="{'answer': 'Paris'}"))],
        model="openai/gpt-4o-mini",
    )
    with mock.patch("litellm.completion") as mock_completion:
        mock_completion.return_value = content_only_response
        outputs = json_adapter(dspy.LM(model="openai/gpt-4o-mini", cache=False), {}, MySignature, [], inputs)

    assert outputs[0]["answer"] == "Paris"
    assert outputs[0]["tool_calls"] is None
839 | 
840 | 
841 | def test_json_adapter_toolcalls_no_native_function_calling():
842 |     class MySignature(dspy.Signature):
843 |         question: str = dspy.InputField()
844 |         tools: list[dspy.Tool] = dspy.InputField()
845 |         answer: str = dspy.OutputField()
846 |         tool_calls: dspy.ToolCalls = dspy.OutputField()
847 | 
848 |     def get_weather(city: str) -> str:
849 |         return f"The weather in {city} is sunny"
850 | 
851 |     tools = [dspy.Tool(get_weather)]
852 | 
853 |     # Patch _get_structured_outputs_response_format to track calls
854 |     with mock.patch("dspy.adapters.json_adapter._get_structured_outputs_response_format") as mock_structured:
855 |         # Patch litellm.completion to return a dummy response
856 |         with mock.patch("litellm.completion") as mock_completion:
857 |             mock_completion.return_value = ModelResponse(
858 |                 choices=[Choices(message=Message(content="{'answer': 'sunny', 'tool_calls': {'tool_calls': []}}"))],
859 |                 model="openai/gpt-4o-mini",
860 |             )
861 |             adapter = dspy.JSONAdapter(use_native_function_calling=False)
862 |             lm = dspy.LM(model="openai/gpt-4o-mini", cache=False)
863 |             adapter(lm, {}, MySignature, [], {"question": "What is the weather in Tokyo?", "tools": tools})
864 | 
865 |         # _get_structured_outputs_response_format is not called because without using native function calling,
866 |         # JSONAdapter falls back to json mode for stable quality.
867 |         mock_structured.assert_not_called()
868 |         mock_completion.assert_called_once()
869 |         _, call_kwargs = mock_completion.call_args
870 |         assert call_kwargs["response_format"] == {"type": "json_object"}
871 | 
872 | 
873 | def test_json_adapter_with_responses_api():
874 |     class TestSignature(dspy.Signature):
875 |         question: str = dspy.InputField()
876 |         answer: str = dspy.OutputField()
877 | 
878 |     api_response = ResponsesAPIResponse(
879 |         id="resp_1",
880 |         created_at=0.0,
881 |         error=None,
882 |         incomplete_details=None,
883 |         instructions=None,
884 |         model="openai/gpt-4o",
885 |         object="response",
886 |         output=[
887 |             ResponseOutputMessage(
888 |                 **{
889 |                     "id": "msg_1",
890 |                     "type": "message",
891 |                     "role": "assistant",
892 |                     "status": "completed",
893 |                     "content": [
894 |                         {"type": "output_text", "text": '{"answer": "Washington, D.C."}', "annotations": []}
895 |                     ],
896 |                 },
897 |             ),
898 |         ],
899 |         metadata={},
900 |         parallel_tool_calls=False,
901 |         temperature=1.0,
902 |         tool_choice="auto",
903 |         tools=[],
904 |         top_p=1.0,
905 |         max_output_tokens=None,
906 |         previous_response_id=None,
907 |         reasoning=None,
908 |         status="completed",
909 |         text=None,
910 |         truncation="disabled",
911 |         usage=ResponseAPIUsage(input_tokens=10, output_tokens=5, total_tokens=15),
912 |         user=None,
913 |     )
914 | 
915 |     lm = dspy.LM(model="openai/gpt-4o", model_type="responses", cache=False)
916 |     dspy.configure(lm=lm, adapter=dspy.JSONAdapter())
917 | 
918 |     program = dspy.Predict(TestSignature)
919 |     with mock.patch("litellm.responses", autospec=True, return_value=api_response) as mock_responses:
920 |         result = program(question="What is the capital of the USA?")
921 | 
922 |     assert result.answer == "Washington, D.C."
923 |     mock_responses.assert_called_once()
924 |     # Verify that response_format was converted to text.format
925 |     call_kwargs = mock_responses.call_args.kwargs
926 |     assert "response_format" not in call_kwargs
927 |     assert "text" in call_kwargs
928 |     assert isinstance(call_kwargs["text"]["format"], type)
929 |     assert issubclass(call_kwargs["text"]["format"], pydantic.BaseModel)
930 | 
```

--------------------------------------------------------------------------------
/tests/streaming/test_streaming.py:
--------------------------------------------------------------------------------

```python
   1 | import asyncio
   2 | import time
   3 | from dataclasses import dataclass
   4 | from unittest import mock
   5 | from unittest.mock import AsyncMock
   6 | 
   7 | import pytest
   8 | from asyncer import syncify
   9 | from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices
  10 | 
  11 | import dspy
  12 | from dspy.adapters.types import Type
  13 | from dspy.experimental import Citations, Document
  14 | from dspy.streaming import StatusMessage, StatusMessageProvider, streaming_response
  15 | 
  16 | 
  17 | @pytest.mark.anyio
  18 | async def test_streamify_yields_expected_response_chunks(litellm_test_server):
  19 |     api_base, _ = litellm_test_server
  20 |     lm = dspy.LM(
  21 |         model="openai/dspy-test-model",
  22 |         api_base=api_base,
  23 |         api_key="fakekey",
  24 |         cache=True,
  25 |     )
  26 |     with dspy.context(lm=lm, adapter=dspy.JSONAdapter()):
  27 | 
  28 |         class TestSignature(dspy.Signature):
  29 |             input_text: str = dspy.InputField()
  30 |             output_text: str = dspy.OutputField()
  31 | 
  32 |         program = dspy.streamify(dspy.Predict(TestSignature))
  33 |         output_stream1 = program(input_text="Test")
  34 |         output_chunks1 = [chunk async for chunk in output_stream1]
  35 |         last_chunk1 = output_chunks1[-1]
  36 |         assert isinstance(last_chunk1, dspy.Prediction)
  37 |         assert last_chunk1.output_text == "Hello!"
  38 | 
  39 |         output_stream2 = program(input_text="Test")
  40 |         output_chunks2 = [chunk async for chunk in output_stream2]
  41 |         # Since the input is cached, only one chunk should be
  42 |         # yielded containing the prediction
  43 |         assert len(output_chunks2) == 1
  44 |         last_chunk2 = output_chunks2[-1]
  45 |         assert isinstance(last_chunk2, dspy.Prediction)
  46 |         assert last_chunk2.output_text == "Hello!"
  47 | 
  48 | 
  49 | @pytest.mark.anyio
  50 | async def test_streaming_response_yields_expected_response_chunks(litellm_test_server):
  51 |     api_base, _ = litellm_test_server
  52 |     lm = dspy.LM(
  53 |         model="openai/dspy-test-model",
  54 |         api_base=api_base,
  55 |         api_key="fakekey",
  56 |         cache=False,
  57 |     )
  58 |     with dspy.context(lm=lm):
  59 | 
  60 |         class TestSignature(dspy.Signature):
  61 |             input_text: str = dspy.InputField()
  62 |             output_text: str = dspy.OutputField()
  63 | 
  64 |         program = dspy.streamify(dspy.Predict(TestSignature))
  65 |         output_stream_from_program = streaming_response(program(input_text="Test"))
  66 |         output_stream_for_server_response = streaming_response(output_stream_from_program)
  67 |         output_chunks = [chunk async for chunk in output_stream_for_server_response]
  68 |         assert all(chunk.startswith("data: ") for chunk in output_chunks)
  69 |         assert 'data: {"prediction":{"output_text":"Hello!"}}\n\n' in output_chunks
  70 |         assert output_chunks[-1] == "data: [DONE]\n\n"
  71 | 
  72 | 
  73 | @pytest.mark.anyio
  74 | async def test_default_status_streaming():
  75 |     class MyProgram(dspy.Module):
  76 |         def __init__(self):
  77 |             self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
  78 |             self.predict = dspy.Predict("question->answer")
  79 | 
  80 |         def __call__(self, x: str):
  81 |             question = self.generate_question(x=x)
  82 |             return self.predict(question=question)
  83 | 
  84 |     lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}])
  85 |     with dspy.context(lm=lm):
  86 |         program = dspy.streamify(MyProgram())
  87 |         output = program("sky")
  88 | 
  89 |         status_messages = []
  90 |         async for value in output:
  91 |             if isinstance(value, StatusMessage):
  92 |                 status_messages.append(value)
  93 | 
  94 |     assert len(status_messages) == 2
  95 |     assert status_messages[0].message == "Calling tool generate_question..."
  96 |     assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."
  97 | 
  98 | 
  99 | @pytest.mark.anyio
 100 | async def test_custom_status_streaming():
 101 |     class MyProgram(dspy.Module):
 102 |         def __init__(self):
 103 |             self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
 104 |             self.predict = dspy.Predict("question->answer")
 105 | 
 106 |         def __call__(self, x: str):
 107 |             question = self.generate_question(x=x)
 108 |             return self.predict(question=question)
 109 | 
 110 |     class MyStatusMessageProvider(StatusMessageProvider):
 111 |         def tool_start_status_message(self, instance, inputs):
 112 |             return "Tool starting!"
 113 | 
 114 |         def tool_end_status_message(self, outputs):
 115 |             return "Tool finished!"
 116 | 
 117 |         def module_start_status_message(self, instance, inputs):
 118 |             if isinstance(instance, dspy.Predict):
 119 |                 return "Predict starting!"
 120 | 
 121 |     lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}])
 122 |     with dspy.context(lm=lm):
 123 |         program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider())
 124 |         output = program("sky")
 125 | 
 126 |         status_messages = []
 127 |         async for value in output:
 128 |             if isinstance(value, StatusMessage):
 129 |                 status_messages.append(value)
 130 | 
 131 |         assert len(status_messages) == 3
 132 |         assert status_messages[0].message == "Tool starting!"
 133 |         assert status_messages[1].message == "Tool finished!"
 134 |         assert status_messages[2].message == "Predict starting!"
 135 | 
 136 | 
 137 | @pytest.mark.llm_call
 138 | @pytest.mark.anyio
 139 | async def test_stream_listener_chat_adapter(lm_for_test):
 140 |     class MyProgram(dspy.Module):
 141 |         def __init__(self):
 142 |             self.predict1 = dspy.Predict("question->answer")
 143 |             self.predict2 = dspy.Predict("question, answer->judgement")
 144 | 
 145 |         def __call__(self, x: str, **kwargs):
 146 |             answer = self.predict1(question=x, **kwargs)
 147 |             judgement = self.predict2(question=x, answer=answer, **kwargs)
 148 |             return judgement
 149 | 
 150 |     my_program = MyProgram()
 151 |     program = dspy.streamify(
 152 |         my_program,
 153 |         stream_listeners=[
 154 |             dspy.streaming.StreamListener(signature_field_name="answer"),
 155 |             dspy.streaming.StreamListener(signature_field_name="judgement"),
 156 |         ],
 157 |         include_final_prediction_in_output_stream=False,
 158 |     )
 159 |     # Turn off the cache to ensure the stream is produced.
 160 |     with dspy.context(lm=dspy.LM(lm_for_test, cache=False, temperature=0.0)):
 161 |         output = program(x="why did a chicken cross the kitchen?")
 162 |         all_chunks = []
 163 |         async for value in output:
 164 |             if isinstance(value, dspy.streaming.StreamResponse):
 165 |                 all_chunks.append(value)
 166 | 
 167 |     assert all_chunks[0].predict_name == "predict1"
 168 |     assert all_chunks[0].signature_field_name == "answer"
 169 |     # The last chunk can be from either predictor because sometimes small LMs miss the `[[ ## completed ## ]]` marker,
 170 |     # which results in an extra chunk that flushes out the buffer.
 171 |     assert all_chunks[-2].predict_name == "predict2"
 172 |     assert all_chunks[-2].signature_field_name == "judgement"
 173 | 
 174 | 
 175 | @pytest.mark.anyio
 176 | async def test_default_status_streaming_in_async_program():
 177 |     class MyProgram(dspy.Module):
 178 |         def __init__(self):
 179 |             self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
 180 |             self.predict = dspy.Predict("question->answer")
 181 | 
 182 |         async def acall(self, x: str):
 183 |             question = await self.generate_question.acall(x=x)
 184 |             return await self.predict.acall(question=question)
 185 | 
 186 |     lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}])
 187 |     with dspy.context(lm=lm):
 188 |         program = dspy.streamify(MyProgram(), is_async_program=True)
 189 |         output = program("sky")
 190 | 
 191 |         status_messages = []
 192 |         async for value in output:
 193 |             if isinstance(value, StatusMessage):
 194 |                 status_messages.append(value)
 195 | 
 196 |     assert len(status_messages) == 2
 197 |     assert status_messages[0].message == "Calling tool generate_question..."
 198 |     assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."
 199 | 
 200 | 
 201 | @pytest.mark.llm_call
 202 | @pytest.mark.anyio
 203 | async def test_stream_listener_json_adapter(lm_for_test):
 204 |     class MyProgram(dspy.Module):
 205 |         def __init__(self):
 206 |             self.predict1 = dspy.Predict("question->answer")
 207 |             self.predict2 = dspy.Predict("question, answer->judgement")
 208 | 
 209 |         def __call__(self, x: str, **kwargs):
 210 |             answer = self.predict1(question=x, **kwargs)
 211 |             judgement = self.predict2(question=x, answer=answer, **kwargs)
 212 |             return judgement
 213 | 
 214 |     my_program = MyProgram()
 215 |     program = dspy.streamify(
 216 |         my_program,
 217 |         stream_listeners=[
 218 |             dspy.streaming.StreamListener(signature_field_name="answer"),
 219 |             dspy.streaming.StreamListener(signature_field_name="judgement"),
 220 |         ],
 221 |         include_final_prediction_in_output_stream=False,
 222 |     )
 223 |     # Turn off the cache to ensure the stream is produced.
 224 |     with dspy.context(lm=dspy.LM(lm_for_test, cache=False, temperature=0.0), adapter=dspy.JSONAdapter()):
 225 |         output = program(x="why did a chicken cross the kitchen?")
 226 |         all_chunks = []
 227 |         async for value in output:
 228 |             if isinstance(value, dspy.streaming.StreamResponse):
 229 |                 all_chunks.append(value)
 230 | 
 231 |     assert all_chunks[0].predict_name == "predict1"
 232 |     assert all_chunks[0].signature_field_name == "answer"
 233 |     assert all_chunks[0].is_last_chunk is False
 234 | 
 235 |     assert all_chunks[-1].predict_name == "predict2"
 236 |     assert all_chunks[-1].signature_field_name == "judgement"
 237 | 
 238 | 
 239 | @pytest.mark.anyio
 240 | async def test_streaming_handles_space_correctly():
 241 |     my_program = dspy.Predict("question->answer")
 242 |     program = dspy.streamify(
 243 |         my_program, stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")]
 244 |     )
 245 | 
 246 |     async def gpt_4o_mini_stream(*args, **kwargs):
 247 |         yield ModelResponseStream(
 248 |             model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[[ ## answer ## ]]\n"))]
 249 |         )
 250 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="How "))])
 251 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="are "))])
 252 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="you "))])
 253 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="doing?"))])
 254 |         yield ModelResponseStream(
 255 |             model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="\n\n[[ ## completed ## ]]"))]
 256 |         )
 257 | 
 258 |     with mock.patch("litellm.acompletion", side_effect=gpt_4o_mini_stream):
 259 |         with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.ChatAdapter()):
 260 |             output = program(question="What is the capital of France?")
 261 |             all_chunks = []
 262 |             async for value in output:
 263 |                 if isinstance(value, dspy.streaming.StreamResponse):
 264 |                     all_chunks.append(value)
 265 | 
 266 |     assert "".join([chunk.chunk for chunk in all_chunks]) == "How are you doing?"
 267 | 
 268 | 
 269 | @pytest.mark.llm_call
 270 | def test_sync_streaming(lm_for_test):
 271 |     class MyProgram(dspy.Module):
 272 |         def __init__(self):
 273 |             self.predict1 = dspy.Predict("question->answer")
 274 |             self.predict2 = dspy.Predict("question, answer->judgement")
 275 | 
 276 |         def __call__(self, x: str, **kwargs):
 277 |             answer = self.predict1(question=x, **kwargs)
 278 |             judgement = self.predict2(question=x, answer=answer, **kwargs)
 279 |             return judgement
 280 | 
 281 |     my_program = MyProgram()
 282 |     program = dspy.streamify(
 283 |         my_program,
 284 |         stream_listeners=[
 285 |             dspy.streaming.StreamListener(signature_field_name="answer"),
 286 |             dspy.streaming.StreamListener(signature_field_name="judgement"),
 287 |         ],
 288 |         include_final_prediction_in_output_stream=False,
 289 |         async_streaming=False,
 290 |     )
 291 |     # Turn off the cache to ensure the stream is produced.
 292 |     with dspy.context(lm=dspy.LM(lm_for_test, cache=False, temperature=0.0)):
 293 |         output = program(x="why did a chicken cross the kitchen?")
 294 |         all_chunks = []
 295 |         for value in output:
 296 |             if isinstance(value, dspy.streaming.StreamResponse):
 297 |                 all_chunks.append(value)
 298 | 
 299 |     assert all_chunks[0].predict_name == "predict1"
 300 |     assert all_chunks[0].signature_field_name == "answer"
 301 |     assert all_chunks[0].is_last_chunk is False
 302 |     # The last chunk can be from either predictor because sometimes small LMs miss the `[[ ## completed ## ]]` marker,
 303 |     # which results in an extra chunk that flushes out the buffer.
 304 |     assert all_chunks[-2].predict_name == "predict2"
 305 |     assert all_chunks[-2].signature_field_name == "judgement"
 306 | 
 307 | 
 308 | def test_sync_status_streaming():
 309 |     class MyProgram(dspy.Module):
 310 |         def __init__(self):
 311 |             self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
 312 |             self.predict = dspy.Predict("question->answer")
 313 | 
 314 |         def __call__(self, x: str):
 315 |             question = self.generate_question(x=x)
 316 |             return self.predict(question=question)
 317 | 
 318 |     lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}])
 319 |     with dspy.context(lm=lm):
 320 |         program = dspy.streamify(MyProgram())
 321 |         output = program("sky")
 322 |         sync_output = dspy.streaming.apply_sync_streaming(output)
 323 |         status_messages = []
 324 |         for value in sync_output:
 325 |             if isinstance(value, StatusMessage):
 326 |                 status_messages.append(value)
 327 | 
 328 |     assert len(status_messages) == 2
 329 |     assert status_messages[0].message == "Calling tool generate_question..."
 330 |     assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."
 331 | 
 332 | 
 333 | @pytest.mark.anyio
 334 | async def test_stream_listener_returns_correct_chunk_chat_adapter():
 335 |     class MyProgram(dspy.Module):
 336 |         def __init__(self):
 337 |             super().__init__()
 338 |             self.predict1 = dspy.Predict("question->answer")
 339 |             self.predict2 = dspy.Predict("question, answer->judgement")
 340 | 
 341 |         def forward(self, question, **kwargs):
 342 |             answer = self.predict1(question=question, **kwargs).answer
 343 |             judgement = self.predict2(question=question, answer=answer, **kwargs)
 344 |             return judgement
 345 | 
 346 |     async def gpt_4o_mini_stream_1(*args, **kwargs):
 347 |         # Recorded streaming from openai/gpt-4o-mini
 348 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[["))])
 349 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
 350 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" answer"))])
 351 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
 352 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ]]\n\n"))])
 353 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="To"))])
 354 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" get"))])
 355 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" to"))])
 356 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" the"))])
 357 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" other"))])
 358 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" side"))])
 359 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" of"))])
 360 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" the"))])
 361 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" dinner"))])
 362 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" plate"))])
 363 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="!\n\n[[ ##"))])
 364 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" completed"))])
 365 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
 366 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ]]"))])
 367 | 
 368 |     async def gpt_4o_mini_stream_2():
 369 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[[ ##"))])
 370 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" judgement"))])
 371 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
 372 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ]]\n\n"))])
 373 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="The"))])
 374 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" answer"))])
 375 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" is"))])
 376 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" humorous"))])
 377 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" and"))])
 378 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" plays"))])
 379 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" on"))])
 380 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" the"))])
 381 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" classic"))])
 382 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" joke"))])
 383 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" format"))])
 384 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=".\n\n[[ ##"))])
 385 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" completed"))])
 386 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
 387 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ]]"))])
 388 | 
 389 |     stream_generators = [gpt_4o_mini_stream_1, gpt_4o_mini_stream_2]
 390 | 
 391 |     async def completion_side_effect(*args, **kwargs):
 392 |         return stream_generators.pop(0)()  # return new async generator instance
 393 | 
 394 |     with mock.patch("litellm.acompletion", side_effect=completion_side_effect):
 395 |         program = dspy.streamify(
 396 |             MyProgram(),
 397 |             stream_listeners=[
 398 |                 dspy.streaming.StreamListener(signature_field_name="answer"),
 399 |                 dspy.streaming.StreamListener(signature_field_name="judgement"),
 400 |             ],
 401 |         )
 402 |         with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
 403 |             output = program(question="why did a chicken cross the kitchen?")
 404 |             all_chunks = []
 405 |             async for value in output:
 406 |                 if isinstance(value, dspy.streaming.StreamResponse):
 407 |                     all_chunks.append(value)
 408 | 
 409 |         assert all_chunks[0].predict_name == "predict1"
 410 |         assert all_chunks[0].signature_field_name == "answer"
 411 |         assert all_chunks[0].chunk == "To"
 412 |         assert all_chunks[1].chunk == " get"
 413 |         assert all_chunks[2].chunk == " to"
 414 |         assert all_chunks[3].chunk == " the"
 415 |         assert all_chunks[4].chunk == " other"
 416 |         assert all_chunks[5].chunk == " side"
 417 |         assert all_chunks[6].chunk == " of"
 418 |         assert all_chunks[7].chunk == " the"
 419 |         assert all_chunks[8].chunk == " dinner"
 420 |         assert all_chunks[9].chunk == " plate"
 421 |         assert all_chunks[10].chunk == "!"
 422 |         assert all_chunks[10].is_last_chunk is True
 423 | 
 424 |         assert all_chunks[11].predict_name == "predict2"
 425 |         assert all_chunks[11].signature_field_name == "judgement"
 426 |         assert all_chunks[11].chunk == "The"
 427 |         assert all_chunks[12].chunk == " answer"
 428 |         assert all_chunks[13].chunk == " is"
 429 |         assert all_chunks[14].chunk == " humorous"
 430 |         assert all_chunks[15].chunk == " and"
 431 |         assert all_chunks[16].chunk == " plays"
 432 |         assert all_chunks[17].chunk == " on"
 433 |         assert all_chunks[18].chunk == " the"
 434 |         assert all_chunks[19].chunk == " classic"
 435 |         assert all_chunks[20].chunk == " joke"
 436 |         assert all_chunks[21].chunk == " format"
 437 |         assert all_chunks[22].chunk == "."
 438 |         assert all_chunks[22].is_last_chunk is True
 439 | 
 440 | 
 441 | @pytest.mark.anyio
 442 | async def test_stream_listener_returns_correct_chunk_json_adapter():
 443 |     class MyProgram(dspy.Module):
 444 |         def __init__(self):
 445 |             super().__init__()
 446 |             self.predict1 = dspy.Predict("question->answer")
 447 |             self.predict2 = dspy.Predict("question,answer->judgement")
 448 | 
 449 |         def forward(self, question, **kwargs):
 450 |             answer = self.predict1(question=question, **kwargs).answer
 451 |             judgement = self.predict2(question=question, answer=answer, **kwargs)
 452 |             return judgement
 453 | 
 454 |     async def gpt_4o_mini_stream_1(*args, **kwargs):
 455 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"'))])
 456 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="answer"))])
 457 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='":'))])
 458 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="To"))])
 459 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" get"))])
 460 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" to"))])
 461 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" the"))])
 462 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" other"))])
 463 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" side"))])
 464 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" of"))])
 465 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" the"))])
 466 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" frying"))])
 467 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" pan"))])
 468 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='!"'))])
 469 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="}\n"))])
 470 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
 471 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
 472 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
 473 | 
 474 |     async def gpt_4o_mini_stream_2(*args, **kwargs):
 475 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"'))])
 476 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="jud"))])
 477 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="gement"))])
 478 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='":"'))])
 479 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="The"))])
 480 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" answer"))])
 481 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" is"))])
 482 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" humorous"))])
 483 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" and"))])
 484 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" plays"))])
 485 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" on"))])
 486 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" the"))])
 487 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" very"))])
 488 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" funny"))])
 489 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" and"))])
 490 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" classic"))])
 491 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" joke"))])
 492 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" format"))])
 493 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='."'))])
 494 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="}"))])
 495 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
 496 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
 497 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
 498 | 
 499 |     with mock.patch(
 500 |         "litellm.acompletion", new_callable=AsyncMock, side_effect=[gpt_4o_mini_stream_1(), gpt_4o_mini_stream_2()]
 501 |     ):
 502 |         program = dspy.streamify(
 503 |             MyProgram(),
 504 |             stream_listeners=[
 505 |                 dspy.streaming.StreamListener(signature_field_name="answer"),
 506 |                 dspy.streaming.StreamListener(signature_field_name="judgement"),
 507 |             ],
 508 |         )
 509 |         with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()):
 510 |             output = program(question="why did a chicken cross the kitchen?")
 511 |             all_chunks = []
 512 |             async for value in output:
 513 |                 if isinstance(value, dspy.streaming.StreamResponse):
 514 |                     all_chunks.append(value)
 515 | 
 516 |         assert all_chunks[0].predict_name == "predict1"
 517 |         assert all_chunks[0].signature_field_name == "answer"
 518 |         assert all_chunks[0].chunk == "To"
 519 |         assert all_chunks[1].chunk == " get"
 520 |         assert all_chunks[2].chunk == " to"
 521 |         assert all_chunks[3].chunk == " the"
 522 |         assert all_chunks[4].chunk == " other"
 523 |         assert all_chunks[5].chunk == " side"
 524 |         assert all_chunks[6].chunk == " of"
 525 |         assert all_chunks[7].chunk == " the"
 526 |         assert all_chunks[8].chunk == " frying"
 527 |         assert all_chunks[9].chunk == " pan"
 528 |         assert all_chunks[10].chunk == "!"
 529 |         assert all_chunks[10].is_last_chunk is True
 530 | 
 531 |         assert all_chunks[11].predict_name == "predict2"
 532 |         assert all_chunks[11].signature_field_name == "judgement"
 533 |         assert all_chunks[11].chunk == "The"
 534 |         assert all_chunks[12].chunk == " answer"
 535 |         assert all_chunks[13].chunk == " is"
 536 |         assert all_chunks[14].chunk == " humorous"
 537 |         assert all_chunks[15].chunk == " and"
 538 |         assert all_chunks[16].chunk == " plays"
 539 |         assert all_chunks[17].chunk == " on"
 540 |         assert all_chunks[18].chunk == " the"
 541 |         assert all_chunks[19].chunk == " very"
 542 |         assert all_chunks[20].chunk == " funny"
 543 |         assert all_chunks[21].chunk == " and"
 544 |         assert all_chunks[22].chunk == " classic"
 545 |         assert all_chunks[23].chunk == " joke"
 546 |         assert all_chunks[24].chunk == " format"
 547 |         assert all_chunks[25].chunk == "."
 548 |         assert all_chunks[25].is_last_chunk is True
 549 | 
 550 | 
 551 | @pytest.mark.anyio
 552 | async def test_stream_listener_returns_correct_chunk_chat_adapter_untokenized_stream():
 553 |     class MyProgram(dspy.Module):
 554 |         def __init__(self):
 555 |             super().__init__()
 556 |             self.predict1 = dspy.Predict("question->answer")
 557 |             self.predict2 = dspy.Predict("question,answer->judgement")
 558 | 
 559 |         def forward(self, question, **kwargs):
 560 |             answer = self.predict1(question=question, **kwargs).answer
 561 |             judgement = self.predict2(question=question, answer=answer, **kwargs)
 562 |             return judgement
 563 | 
 564 |     async def gemini_stream_1(*args, **kwargs):
 565 |         yield ModelResponseStream(model="gemini", choices=[StreamingChoices(delta=Delta(content="[[ ##"))])
 566 |         yield ModelResponseStream(model="gemini", choices=[StreamingChoices(delta=Delta(content=" answer ## ]]"))])
 567 |         yield ModelResponseStream(
 568 |             model="gemini", choices=[StreamingChoices(delta=Delta(content="To get to the other side."))]
 569 |         )
 570 |         yield ModelResponseStream(
 571 |             model="gemini", choices=[StreamingChoices(delta=Delta(content="\n\n[[ ## completed ## ]]"))]
 572 |         )
 573 | 
 574 |     async def gemini_stream_2(*args, **kwargs):
 575 |         yield ModelResponseStream(
 576 |             model="gemini", choices=[StreamingChoices(delta=Delta(content="[[ ## judgement ## ]]\n\n"))]
 577 |         )
 578 |         yield ModelResponseStream(
 579 |             model="gemini",
 580 |             choices=[
 581 |                 StreamingChoices(
 582 |                     delta=Delta(
 583 |                         content=(
 584 |                             "The answer provides the standard punchline for this classic joke format, adapted to the "
 585 |                             "specific location mentioned in the question. It is the expected and appropriate response."
 586 |                         )
 587 |                     )
 588 |                 )
 589 |             ],
 590 |         )
 591 |         yield ModelResponseStream(
 592 |             model="gemini",
 593 |             choices=[StreamingChoices(delta=Delta(content="\n\n[[ ## completed ## ]]"))],
 594 |         )
 595 |         yield ModelResponseStream(model="gemini", choices=[StreamingChoices(delta=Delta(content="}\n"))])
 596 | 
 597 |     with mock.patch("litellm.acompletion", new_callable=AsyncMock, side_effect=[gemini_stream_1(), gemini_stream_2()]):
 598 |         program = dspy.streamify(
 599 |             MyProgram(),
 600 |             stream_listeners=[
 601 |                 dspy.streaming.StreamListener(signature_field_name="answer"),
 602 |                 dspy.streaming.StreamListener(signature_field_name="judgement"),
 603 |             ],
 604 |         )
 605 |         with dspy.context(lm=dspy.LM("gemini/gemini-2.5-flash", cache=False), adapter=dspy.ChatAdapter()):
 606 |             output = program(question="why did a chicken cross the kitchen?")
 607 |             all_chunks = []
 608 |             async for value in output:
 609 |                 if isinstance(value, dspy.streaming.StreamResponse):
 610 |                     all_chunks.append(value)
 611 | 
 612 |         assert all_chunks[0].predict_name == "predict1"
 613 |         assert all_chunks[0].signature_field_name == "answer"
 614 |         assert all_chunks[0].chunk == "To get to the other side."
 615 | 
 616 |         assert all_chunks[1].predict_name == "predict2"
 617 |         assert all_chunks[1].signature_field_name == "judgement"
 618 |         assert all_chunks[1].chunk == (
 619 |             "The answer provides the standard punchline for this classic joke format, adapted to the specific location "
 620 |             "mentioned in the question. It is the expected and appropriate response."
 621 |         )
 622 | 
 623 | 
 624 | @pytest.mark.anyio
 625 | async def test_stream_listener_missing_completion_marker_chat_adapter():
 626 |     """Test that streaming works correctly when LLM response omits a final completion marker.
 627 | 
 628 |     This test verifies that:
 629 |     1. All tokens are yielded including those in the buffer
 630 |     2. The last chunk is properly marked with is_last_chunk=True
 631 |     3. No tokens are lost when the completion marker is missing
 632 |     """
 633 | 
 634 |     class MyProgram(dspy.Module):
 635 |         def __init__(self):
 636 |             super().__init__()
 637 |             self.predict = dspy.Predict("question->answer")
 638 | 
 639 |         def forward(self, question, **kwargs):
 640 |             return self.predict(question=question, **kwargs)
 641 | 
 642 |     async def incomplete_stream(*args, **kwargs):
 643 |         """Stream that includes start marker but MISSING completion marker"""
 644 |         # Start marker
 645 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[[ ##"))])
 646 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" answer"))])
 647 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ## ]]\n\n"))])
 648 | 
 649 |         # Content tokens - more than 10 to ensure buffering happens
 650 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="This"))])
 651 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" is"))])
 652 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" a"))])
 653 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" test"))])
 654 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" response"))])
 655 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" with"))])
 656 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" many"))])
 657 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" tokens"))])
 658 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" to"))])
 659 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ensure"))])
 660 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" buffering"))])
 661 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" works"))])
 662 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" correctly"))])
 663 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="."))])
 664 |         # NO COMPLETION MARKER
 665 | 
 666 |     with mock.patch("litellm.acompletion", side_effect=incomplete_stream):
 667 |         program = dspy.streamify(
 668 |             MyProgram(),
 669 |             stream_listeners=[
 670 |                 dspy.streaming.StreamListener(signature_field_name="answer"),
 671 |             ],
 672 |         )
 673 |         with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.ChatAdapter()):
 674 |             output = program(question="Test question")
 675 |             all_chunks = []
 676 |             final_prediction = None
 677 |             async for value in output:
 678 |                 if isinstance(value, dspy.streaming.StreamResponse):
 679 |                     all_chunks.append(value)
 680 |                 elif isinstance(value, dspy.Prediction):
 681 |                     final_prediction = value
 682 | 
 683 |     full_content = "".join([chunk.chunk for chunk in all_chunks])
 684 |     expected_content = "This is a test response with many tokens to ensure buffering works correctly."
 685 |     assert full_content == expected_content
 686 |     assert final_prediction.answer == expected_content
 687 | 
 688 | 
 689 | @pytest.mark.anyio
 690 | async def test_stream_listener_returns_correct_chunk_json_adapter_untokenized_stream():
 691 |     class MyProgram(dspy.Module):
 692 |         def __init__(self):
 693 |             super().__init__()
 694 |             self.predict1 = dspy.Predict("question->answer")
 695 |             self.predict2 = dspy.Predict("question,answer->judgement")
 696 | 
 697 |         def forward(self, question, **kwargs):
 698 |             answer = self.predict1(question=question, **kwargs).answer
 699 |             judgement = self.predict2(question=question, answer=answer, **kwargs)
 700 |             return judgement
 701 | 
 702 |     async def gemini_stream_1(*args, **kwargs):
 703 |         yield ModelResponseStream(model="gemini", choices=[StreamingChoices(delta=Delta(content="{\n"))])
 704 |         yield ModelResponseStream(
 705 |             model="gemini", choices=[StreamingChoices(delta=Delta(content='  "answer": "To get to'))]
 706 |         )
 707 |         yield ModelResponseStream(
 708 |             model="gemini", choices=[StreamingChoices(delta=Delta(content=' the other side... of the cutting board!"'))]
 709 |         )
 710 |         yield ModelResponseStream(model="gemini", choices=[StreamingChoices(delta=Delta(content="}\n"))])
 711 | 
 712 |     async def gemini_stream_2(*args, **kwargs):
 713 |         yield ModelResponseStream(model="gemini", choices=[StreamingChoices(delta=Delta(content="{\n"))])
 714 |         yield ModelResponseStream(
 715 |             model="gemini", choices=[StreamingChoices(delta=Delta(content='  "judgement": "The'))]
 716 |         )
 717 |         yield ModelResponseStream(
 718 |             model="gemini",
 719 |             choices=[
 720 |                 StreamingChoices(
 721 |                     delta=Delta(
 722 |                         content=' answer provides a humorous and relevant punchline to the classic joke setup."'
 723 |                     )
 724 |                 )
 725 |             ],
 726 |         )
 727 |         yield ModelResponseStream(model="gemini", choices=[StreamingChoices(delta=Delta(content="}\n"))])
 728 | 
 729 |     with mock.patch("litellm.acompletion", new_callable=AsyncMock, side_effect=[gemini_stream_1(), gemini_stream_2()]):
 730 |         program = dspy.streamify(
 731 |             MyProgram(),
 732 |             stream_listeners=[
 733 |                 dspy.streaming.StreamListener(signature_field_name="answer"),
 734 |                 dspy.streaming.StreamListener(signature_field_name="judgement"),
 735 |             ],
 736 |         )
 737 |         with dspy.context(lm=dspy.LM("gemini/gemini-2.5-flash", cache=False), adapter=dspy.JSONAdapter()):
 738 |             output = program(question="why did a chicken cross the kitchen?")
 739 |             all_chunks = []
 740 |             async for value in output:
 741 |                 if isinstance(value, dspy.streaming.StreamResponse):
 742 |                     all_chunks.append(value)
 743 | 
 744 |         assert all_chunks[0].predict_name == "predict1"
 745 |         assert all_chunks[0].signature_field_name == "answer"
 746 |         assert all_chunks[0].chunk == "To get to the other side... of the cutting board!"
 747 | 
 748 |         assert all_chunks[1].predict_name == "predict2"
 749 |         assert all_chunks[1].signature_field_name == "judgement"
 750 |         assert all_chunks[1].chunk == "The answer provides a humorous and relevant punchline to the classic joke setup."
 751 | 
 752 | 
 753 | @pytest.mark.anyio
 754 | async def test_status_message_non_blocking():
 755 |     def dummy_tool():
 756 |         time.sleep(1)
 757 |         return "dummy_tool_output"
 758 | 
 759 |     class MyProgram(dspy.Module):
 760 |         def forward(self, question, **kwargs):
 761 |             dspy.Tool(dummy_tool)()
 762 |             return dspy.Prediction(answer="dummy_tool_output")
 763 | 
 764 |     program = dspy.streamify(MyProgram(), status_message_provider=StatusMessageProvider())
 765 | 
 766 |     with mock.patch("litellm.acompletion", new_callable=AsyncMock, side_effect=[dummy_tool]):
 767 |         with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
 768 |             output = program(question="why did a chicken cross the kitchen?")
 769 |             timestamps = []
 770 |             async for value in output:
 771 |                 if isinstance(value, dspy.streaming.StatusMessage):
 772 |                     timestamps.append(time.time())
 773 | 
 774 |     # timestamps[0]: tool start message
 775 |     # timestamps[1]: tool end message
 776 |     # There should be ~1 second delay between the tool start and end messages because we explicitly sleep for 1 second
 777 |     # in the tool.
 778 |     assert timestamps[1] - timestamps[0] >= 1
 779 | 
 780 | 
 781 | @pytest.mark.anyio
 782 | async def test_status_message_non_blocking_async_program():
 783 |     async def dummy_tool():
 784 |         await asyncio.sleep(1)
 785 |         return "dummy_tool_output"
 786 | 
 787 |     class MyProgram(dspy.Module):
 788 |         async def aforward(self, question, **kwargs):
 789 |             await dspy.Tool(dummy_tool).acall()
 790 |             return dspy.Prediction(answer="dummy_tool_output")
 791 | 
 792 |     program = dspy.streamify(MyProgram(), status_message_provider=StatusMessageProvider(), is_async_program=True)
 793 | 
 794 |     with mock.patch("litellm.acompletion", new_callable=AsyncMock, side_effect=[dummy_tool]):
 795 |         with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
 796 |             output = program(question="why did a chicken cross the kitchen?")
 797 |             timestamps = []
 798 |             async for value in output:
 799 |                 if isinstance(value, dspy.streaming.StatusMessage):
 800 |                     timestamps.append(time.time())
 801 | 
 802 |     # timestamps[0]: tool start message
 803 |     # timestamps[1]: tool end message
 804 |     # There should be ~1 second delay between the tool start and end messages because we explicitly sleep for 1 second
 805 |     # in the tool.
 806 |     assert timestamps[1] - timestamps[0] >= 1
 807 | 
 808 | 
 809 | @pytest.mark.anyio
 810 | async def test_stream_listener_allow_reuse():
 811 |     class MyProgram(dspy.Module):
 812 |         def __init__(self):
 813 |             super().__init__()
 814 |             self.predict = dspy.Predict("question->answer")
 815 | 
 816 |         def forward(self, question, **kwargs):
 817 |             self.predict(question=question, **kwargs)
 818 |             return self.predict(question=question, **kwargs)
 819 | 
 820 |     program = dspy.streamify(
 821 |         MyProgram(),
 822 |         stream_listeners=[
 823 |             dspy.streaming.StreamListener(signature_field_name="answer", allow_reuse=True),
 824 |         ],
 825 |     )
 826 | 
 827 |     async def gpt_4o_mini_stream(*args, **kwargs):
 828 |         # Recorded streaming from openai/gpt-4o-mini
 829 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[["))])
 830 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
 831 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" answer"))])
 832 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
 833 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ]]\n\n"))])
 834 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="To"))])
 835 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" get"))])
 836 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" to"))])
 837 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" the"))])
 838 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" other"))])
 839 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" side"))])
 840 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="!\n\n[[ ##"))])
 841 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" completed"))])
 842 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
 843 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ]]"))])
 844 | 
 845 |     stream_generators = [gpt_4o_mini_stream, gpt_4o_mini_stream]
 846 | 
 847 |     async def completion_side_effect(*args, **kwargs):
 848 |         return stream_generators.pop(0)()  # return new async generator instance
 849 | 
 850 |     with mock.patch("litellm.acompletion", side_effect=completion_side_effect):
 851 |         with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
 852 |             output = program(question="why did a chicken cross the kitchen?")
 853 |             all_chunks = []
 854 |             async for value in output:
 855 |                 if isinstance(value, dspy.streaming.StreamResponse):
 856 |                     all_chunks.append(value)
 857 | 
 858 |     concat_message = "".join([chunk.chunk for chunk in all_chunks])
 859 |     # The listener functions twice.
 860 |     assert concat_message == "To get to the other side!To get to the other side!"
 861 | 
 862 | 
 863 | @pytest.mark.anyio
 864 | async def test_stream_listener_returns_correct_chunk_xml_adapter():
 865 |     class MyProgram(dspy.Module):
 866 |         def __init__(self):
 867 |             super().__init__()
 868 |             self.predict1 = dspy.Predict("question->answer")
 869 |             self.predict2 = dspy.Predict("question,answer->judgement")
 870 | 
 871 |         def forward(self, question, **kwargs):
 872 |             answer = self.predict1(question=question, **kwargs).answer
 873 |             judgement = self.predict2(question=question, answer=answer, **kwargs)
 874 |             return judgement
 875 | 
 876 |     async def xml_stream_1(*args, **kwargs):
 877 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="<"))])
 878 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="answer"))])
 879 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=">"))])
 880 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="To"))])
 881 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" get"))])
 882 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" to"))])
 883 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" the"))])
 884 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" other"))])
 885 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" side"))])
 886 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="!"))])
 887 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="<"))])
 888 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="/answer"))])
 889 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=">"))])
 890 | 
 891 |     async def xml_stream_2(*args, **kwargs):
 892 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="<"))])
 893 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="judgement"))])
 894 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=">"))])
 895 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="The"))])
 896 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" answer"))])
 897 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" is"))])
 898 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" humorous"))])
 899 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="."))])
 900 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="<"))])
 901 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="/judgement"))])
 902 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=">"))])
 903 | 
 904 |     stream_generators = [xml_stream_1, xml_stream_2]
 905 | 
 906 |     async def completion_side_effect(*args, **kwargs):
 907 |         return stream_generators.pop(0)()
 908 | 
 909 |     with mock.patch("litellm.acompletion", side_effect=completion_side_effect):
 910 |         program = dspy.streamify(
 911 |             MyProgram(),
 912 |             stream_listeners=[
 913 |                 dspy.streaming.StreamListener(signature_field_name="answer"),
 914 |                 dspy.streaming.StreamListener(signature_field_name="judgement"),
 915 |             ],
 916 |         )
 917 |         with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.XMLAdapter()):
 918 |             output = program(question="why did a chicken cross the kitchen?")
 919 |             all_chunks = []
 920 |             async for value in output:
 921 |                 if isinstance(value, dspy.streaming.StreamResponse):
 922 |                     all_chunks.append(value)
 923 | 
 924 |     # Verify answer chunks
 925 |     answer_chunks = [chunk for chunk in all_chunks if chunk.signature_field_name == "answer"]
 926 |     assert len(answer_chunks) > 0
 927 |     assert answer_chunks[0].predict_name == "predict1"
 928 |     assert "".join([chunk.chunk for chunk in answer_chunks]) == "To get to the other side!"
 929 | 
 930 |     # Verify judgement chunks
 931 |     judgement_chunks = [chunk for chunk in all_chunks if chunk.signature_field_name == "judgement"]
 932 |     assert len(judgement_chunks) > 0
 933 |     assert judgement_chunks[0].predict_name == "predict2"
 934 |     assert "".join([chunk.chunk for chunk in judgement_chunks]) == "The answer is humorous."
 935 | 
 936 | 
 937 | @pytest.mark.anyio
 938 | async def test_streaming_allows_custom_chunk_types():
 939 |     @dataclass
 940 |     class CustomChunk:
 941 |         text: str
 942 | 
 943 |     class MyProgram(dspy.Module):
 944 |         def forward(self, question, **kwargs):
 945 |             async def send_to_stream():
 946 |                 chunk = CustomChunk(text="hello")
 947 |                 await dspy.settings.send_stream.send(chunk)
 948 | 
 949 |             syncified_send_to_stream = syncify(send_to_stream)
 950 |             syncified_send_to_stream()
 951 |             return dspy.Prediction(answer="dummy output")
 952 | 
 953 |     program = dspy.streamify(MyProgram())
 954 | 
 955 |     output = program(question="why did a chicken cross the kitchen?")
 956 |     all_chunks = []
 957 |     async for value in output:
 958 |         all_chunks.append(value)
 959 | 
 960 |     assert isinstance(all_chunks[0], CustomChunk)
 961 |     assert isinstance(all_chunks[1], dspy.Prediction)
 962 | 
 963 | 
 964 | @pytest.mark.anyio
 965 | async def test_streaming_allows_custom_streamable_type():
 966 |     class CustomType(Type):
 967 |         message: str
 968 | 
 969 |         @classmethod
 970 |         def is_streamable(cls) -> bool:
 971 |             return True
 972 | 
 973 |         @classmethod
 974 |         def parse_stream_chunk(cls, chunk):
 975 |             return CustomType(message=chunk.choices[0].delta.content)
 976 | 
 977 |         @classmethod
 978 |         def parse_lm_response(cls, response: dict) -> "CustomType":
 979 |             return CustomType(message=response.split("\n\n")[0])
 980 | 
 981 |     class CustomSignature(dspy.Signature):
 982 |         question: str = dspy.InputField()
 983 |         answer: CustomType = dspy.OutputField()
 984 | 
 985 |     program = dspy.streamify(
 986 |         dspy.Predict(CustomSignature),
 987 |         stream_listeners=[
 988 |             dspy.streaming.StreamListener(signature_field_name="answer"),
 989 |         ],
 990 |     )
 991 | 
 992 |     async def stream(*args, **kwargs):
 993 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="Hello"))])
 994 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="World"))])
 995 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="\n\n"))])
 996 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[[ ##"))])
 997 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" completed"))])
 998 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
 999 |         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ]]"))])
1000 | 
1001 |     with mock.patch("litellm.acompletion", side_effect=stream):
1002 |         with dspy.context(
1003 |             lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.ChatAdapter(native_response_types=[CustomType])
1004 |         ):
1005 |             output = program(question="why did a chicken cross the kitchen?")
1006 |             all_chunks = []
1007 |             async for value in output:
1008 |                 if isinstance(value, dspy.streaming.StreamResponse):
1009 |                     all_chunks.append(value)
1010 |                 elif isinstance(value, dspy.Prediction):
1011 |                     assert isinstance(value.answer, CustomType)
1012 |                     assert value.answer.message == "HelloWorld"
1013 | 
1014 |     assert all(isinstance(chunk.chunk, CustomType) for chunk in all_chunks)
1015 | 
1016 | 
@pytest.mark.anyio
async def test_streaming_with_citations():
    """Stream an answer together with a citation delta and check both listeners.

    The mocked stream carries the answer text in small pieces, with one empty-content
    chunk in the middle whose ``provider_specific_fields`` hold a citation. The test
    checks the streamed citation chunk, the reassembled answer text, and the final
    prediction.
    """

    class AnswerWithSources(dspy.Signature):
        """Answer questions using provided documents with citations."""

        documents: list[Document] = dspy.InputField()
        question: str = dspy.InputField()
        answer: str = dspy.OutputField()
        citations: Citations = dspy.OutputField()

    class MyProgram(dspy.Module):
        def __init__(self):
            super().__init__()
            self.predict = dspy.Predict(AnswerWithSources)

        def forward(self, documents, question, **kwargs):
            return self.predict(documents=documents, question=question, **kwargs)

    def text_chunk(content):
        # One streaming chunk whose single choice carries ``content`` as its text delta.
        return ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=content))])

    async def citation_stream(*args, **kwargs):
        # More than 10 chunks are emitted before the citation so the stream resembles
        # a realistic, long-running response. "According" is sent one character at a time.
        for piece in ("[[ ##", " answer", " ## ]]\n\n", *"According", " to ", "the references,"):
            yield text_chunk(piece)

        # The citation travels in provider_specific_fields on a chunk with empty content.
        citation_delta = Delta(
            content="",
            provider_specific_fields={
                "citation": {
                    "type": "char_location",
                    "cited_text": "water boils at 100°C",
                    "document_index": 0,
                    "document_title": "Physics Facts",
                    "start_char_index": 0,
                    "end_char_index": 19,
                }
            },
        )
        yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=citation_delta)])

        for piece in (" water", " boils", " at", " 100°C", ".\n\n[[ ##", " completed", " ## ]]"):
            yield text_chunk(piece)

    # Replace litellm.acompletion so the LM call returns the scripted stream above.
    with mock.patch("litellm.acompletion", return_value=citation_stream()):
        listeners = [
            dspy.streaming.StreamListener(signature_field_name="answer"),
            dspy.streaming.StreamListener(signature_field_name="citations"),
        ]
        program = dspy.streamify(MyProgram(), stream_listeners=listeners)

        # A single source document for the question.
        docs = [Document(data="Water boils at 100°C at standard pressure.", title="Physics Facts")]

        with dspy.context(
            lm=dspy.LM("anthropic/claude-3-5-sonnet-20241022", cache=False),
            adapter=dspy.ChatAdapter(native_response_types=[Citations]),
        ):
            output = program(documents=docs, question="What temperature does water boil?")
            citation_chunks = []
            answer_chunks = []
            final_prediction = None
            async for value in output:
                if isinstance(value, dspy.streaming.StreamResponse):
                    # Sort streamed responses by the signature field they belong to.
                    if value.signature_field_name == "citations":
                        citation_chunks.append(value)
                    elif value.signature_field_name == "answer":
                        answer_chunks.append(value.chunk)
                elif isinstance(value, dspy.Prediction):
                    final_prediction = value

            # At least one citation chunk must have been streamed, carrying the scripted citation.
            assert len(citation_chunks) > 0
            first_citations = citation_chunks[0].chunk
            assert isinstance(first_citations, Citations)
            assert len(first_citations) == 1
            assert first_citations[0].cited_text == "water boils at 100°C"
            assert first_citations[0].document_title == "Physics Facts"

            # The streamed answer pieces must reassemble into the full sentence.
            assert "".join(answer_chunks) == "According to the references, water boils at 100°C."

            # The final prediction must expose both output fields and the full answer text.
            assert final_prediction is not None
            assert hasattr(final_prediction, "answer")
            assert hasattr(final_prediction, "citations")
            assert final_prediction.answer == "According to the references, water boils at 100°C."
1125 | 
1126 | 
def test_stream_listener_could_form_end_identifier_chat_adapter():
    """Check ``_could_form_end_identifier`` against ChatAdapter-style buffers."""
    listener = dspy.streaming.StreamListener(signature_field_name="answer")

    # Buffers ending in a partial bracket sequence, or in a partial field name
    # after "[[ ##", must be reported as possible end identifiers.
    may_form_identifier = [
        "some text [",
        "some text [[",
        "some text [[ ",
        "some text [[ #",
        "some text [[ ##",
        "some text [[ ## com",
        "some text [[ ## completed",
    ]
    # Buffers that clearly cannot form the pattern must be rejected.
    cannot_form_identifier = [
        "hello world",
        "some text",
        "answer: hello",
    ]

    for buffered_text in may_form_identifier:
        assert listener._could_form_end_identifier(buffered_text, "ChatAdapter") is True

    for buffered_text in cannot_form_identifier:
        assert listener._could_form_end_identifier(buffered_text, "ChatAdapter") is False
1145 | 
1146 | 
def test_stream_listener_could_form_end_identifier_json_adapter():
    """Check ``_could_form_end_identifier`` against JSONAdapter-style buffers."""
    listener = dspy.streaming.StreamListener(signature_field_name="output")

    # Buffers ending in a partial quote/brace sequence must be reported as
    # possible end identifiers.
    may_form_identifier = [
        'some text "',
        'some text ",',
        'some text " ',
        'some text "}',
    ]
    # Buffers that cannot form the pattern must be rejected.
    cannot_form_identifier = [
        "hello world",
        "some text",
    ]

    for buffered_text in may_form_identifier:
        assert listener._could_form_end_identifier(buffered_text, "JSONAdapter") is True

    for buffered_text in cannot_form_identifier:
        assert listener._could_form_end_identifier(buffered_text, "JSONAdapter") is False
1159 | 
1160 | 
def test_stream_listener_could_form_end_identifier_xml_adapter():
    """Check ``_could_form_end_identifier`` against XMLAdapter-style buffers."""
    listener = dspy.streaming.StreamListener(signature_field_name="result")

    # Buffers ending in a partial closing tag must be reported as possible
    # end identifiers.
    may_form_identifier = [
        "some text <",
        "some text </",
        "some text </result",
    ]
    # Buffers that cannot form the pattern must be rejected.
    cannot_form_identifier = [
        "hello world",
        "some text",
    ]

    for buffered_text in may_form_identifier:
        assert listener._could_form_end_identifier(buffered_text, "XMLAdapter") is True

    for buffered_text in cannot_form_identifier:
        assert listener._could_form_end_identifier(buffered_text, "XMLAdapter") is False
1172 | 
```
Page 16/17 | First | Prev | Next | Last