# Directory Structure ``` ├── .gitignore ├── LICENSE ├── mcp_server.py ├── pyproject.toml ├── README.md ├── resource │ └── img │ ├── ali.png │ ├── gzh_code.jpg │ ├── img_1.png │ ├── img_2.png │ └── planet.jpg ├── src │ ├── __init__.py │ ├── baostock_data_source.py │ ├── data_source_interface.py │ ├── formatting │ │ ├── __init__.py │ │ └── markdown_formatter.py │ ├── tools │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── base.py │ │ ├── date_utils.py │ │ ├── financial_reports.py │ │ ├── helpers.py │ │ ├── indices.py │ │ ├── macroeconomic.py │ │ ├── market_overview.py │ │ └── stock_market.py │ └── utils.py └── uv.lock ``` # Files -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- ``` 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | AGENTS.md 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # Pipfile.lock 91 | 92 | # PEP 582; used by PDM, PEP 582 proposal 93 | __pypackages__/ 94 | 95 | # Celery stuff 96 | celerybeat-schedule 97 | celerybeat.pid 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | .dmypy.json 124 | dmypy.json 125 | 126 | # Pyre type checker 127 | .pyre/ 128 | 129 | # pytype static type analyzer 130 | .pytype/ 131 | 132 | # Cython debug symbols 133 | cython_debug/ 134 | 135 | # VS Code settings 136 | .vscode/ 137 | 138 | docs/ 139 | 140 | 141 | 142 | 143 | ``` -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- ```python 1 | # This 
file makes src a Python package 2 | ``` -------------------------------------------------------------------------------- /src/tools/__init__.py: -------------------------------------------------------------------------------- ```python 1 | # Initialization file for tools package 2 | ``` -------------------------------------------------------------------------------- /src/formatting/__init__.py: -------------------------------------------------------------------------------- ```python 1 | # Initialization file for formatting package 2 | ``` -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- ```toml 1 | [project] 2 | name = "a-share-mcp" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "baostock>=0.8.9", 9 | "httpx>=0.28.1", 10 | "mcp[cli]>=1.2.0", 11 | "pandas>=2.2.3", 12 | "annotated-types>=0.7.0", 13 | "anyio>=4.9.0", 14 | "certifi>=2025.4.26", 15 | "click>=8.1.8", 16 | "colorama>=0.4.6", 17 | "h11>=0.16.0", 18 | "httpcore>=1.0.9", 19 | "httpx-sse>=0.4.0", 20 | "idna>=3.10", 21 | "markdown-it-py>=3.0.0", 22 | "mdurl>=0.1.2", 23 | "numpy>=2.2.5", 24 | "pydantic>=2.11.3", 25 | "pydantic-core>=2.33.1", 26 | "pydantic-settings>=2.9.1", 27 | "pygments>=2.19.1", 28 | "python-dateutil>=2.9.0", 29 | "python-dotenv>=1.1.0", 30 | "pytz>=2025.2", 31 | "rich>=14.0.0", 32 | "shellingham>=1.5.4", 33 | "six>=1.17.0", 34 | "sniffio>=1.3.1", 35 | "sse-starlette>=2.3.3", 36 | "starlette>=0.46.2", 37 | "tabulate>=0.9.0", 38 | "typer>=0.15.3", 39 | "typing-extensions>=4.13.2", 40 | "typing-inspection>=0.4.0", 41 | "tzdata>=2025.2", 42 | "uvicorn>=0.34.2", 43 | ] 44 | ``` -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- ```python 1 | # 
Utility functions, including the Baostock login context manager and logging setup 2 | import baostock as bs 3 | import os 4 | import sys 5 | import logging 6 | from contextlib import contextmanager 7 | from .data_source_interface import LoginError 8 | 9 | # --- Logging Setup --- 10 | def setup_logging(level=logging.INFO): 11 | """Configures basic logging for the application.""" 12 | logging.basicConfig( 13 | level=level, 14 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 15 | datefmt='%Y-%m-%d %H:%M:%S' 16 | ) 17 | # Optionally silence logs from dependencies if they are too verbose 18 | # logging.getLogger("mcp").setLevel(logging.WARNING) 19 | 20 | # Get a logger instance for this module (optional, but good practice) 21 | logger = logging.getLogger(__name__) 22 | 23 | # --- Baostock Context Manager --- 24 | @contextmanager 25 | def baostock_login_context(): 26 | """Context manager to handle Baostock login and logout, suppressing stdout messages.""" 27 | # Redirect stdout to suppress login/logout messages 28 | original_stdout_fd = sys.stdout.fileno() 29 | saved_stdout_fd = os.dup(original_stdout_fd) 30 | devnull_fd = os.open(os.devnull, os.O_WRONLY) 31 | 32 | os.dup2(devnull_fd, original_stdout_fd) 33 | os.close(devnull_fd) 34 | 35 | logger.debug("Attempting Baostock login...") 36 | lg = bs.login() 37 | logger.debug(f"Login result: code={lg.error_code}, msg={lg.error_msg}") 38 | 39 | # Restore stdout 40 | os.dup2(saved_stdout_fd, original_stdout_fd) 41 | os.close(saved_stdout_fd) 42 | 43 | if lg.error_code != '0': 44 | # Log error before raising 45 | logger.error(f"Baostock login failed: {lg.error_msg}") 46 | raise LoginError(f"Baostock login failed: {lg.error_msg}") 47 | 48 | logger.info("Baostock login successful.") 49 | try: 50 | yield # API calls happen here 51 | finally: 52 | # Redirect stdout again for logout 53 | original_stdout_fd = sys.stdout.fileno() 54 | saved_stdout_fd = os.dup(original_stdout_fd) 55 | devnull_fd = os.open(os.devnull, 
os.O_WRONLY) 56 | 57 | os.dup2(devnull_fd, original_stdout_fd) 58 | os.close(devnull_fd) 59 | 60 | logger.debug("Attempting Baostock logout...") 61 | bs.logout() 62 | logger.debug("Logout completed.") 63 | 64 | # Restore stdout 65 | os.dup2(saved_stdout_fd, original_stdout_fd) 66 | os.close(saved_stdout_fd) 67 | logger.info("Baostock logout successful.") 68 | 69 | # You can add other utility functions or classes here if needed 70 | ``` -------------------------------------------------------------------------------- /src/formatting/markdown_formatter.py: -------------------------------------------------------------------------------- ```python 1 | """ 2 | Markdown formatting utilities for A-Share MCP Server. 3 | """ 4 | import pandas as pd 5 | import logging 6 | import json 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | # Configuration: Max rows to display in string outputs to protect context length 11 | MAX_MARKDOWN_ROWS = 250 12 | 13 | 14 | def format_df_to_markdown(df: pd.DataFrame, max_rows: int = None) -> str: 15 | """Formats a Pandas DataFrame to a Markdown string with row truncation. 16 | 17 | Args: 18 | df: The DataFrame to format 19 | max_rows: Maximum rows to include in output. Defaults to MAX_MARKDOWN_ROWS if None. 20 | 21 | Returns: 22 | A markdown formatted string representation of the DataFrame 23 | """ 24 | if df is None or df.empty: 25 | logger.warning("Attempted to format an empty DataFrame to Markdown.") 26 | return "(No data available to display)" 27 | 28 | if max_rows is None: 29 | max_rows = MAX_MARKDOWN_ROWS 30 | 31 | original_rows = df.shape[0] 32 | rows_to_show = min(original_rows, max_rows) 33 | df_display = df.head(rows_to_show) 34 | 35 | truncated = original_rows > rows_to_show 36 | 37 | try: 38 | markdown_table = df_display.to_markdown(index=False) 39 | except Exception as e: 40 | logger.error("Error converting DataFrame to Markdown: %s", e, exc_info=True) 41 | return "Error: Could not format data into Markdown table." 
42 | 43 | if truncated: 44 | notes = f"rows truncated to {rows_to_show} from {original_rows}" 45 | return f"Note: Data truncated ({notes}).\n\n{markdown_table}" 46 | return markdown_table 47 | 48 | 49 | def format_table_output( 50 | df: pd.DataFrame, 51 | format: str = "markdown", 52 | max_rows: int | None = None, 53 | meta: dict | None = None, 54 | ) -> str: 55 | """Formats a DataFrame into the requested string format with optional meta. 56 | 57 | Args: 58 | df: Data to format. 59 | format: 'markdown' | 'json' | 'csv'. Defaults to 'markdown'. 60 | max_rows: Optional max rows to include (defaults depend on formatters). 61 | meta: Optional metadata dict to include (prepended for markdown, embedded for json). 62 | 63 | Returns: 64 | A string suitable for tool responses. 65 | """ 66 | fmt = (format or "markdown").lower() 67 | 68 | # Normalize row cap 69 | if max_rows is None: 70 | max_rows = MAX_MARKDOWN_ROWS if fmt == "markdown" else MAX_MARKDOWN_ROWS 71 | 72 | total_rows = 0 if df is None else int(df.shape[0]) 73 | rows_to_show = 0 if df is None else min(total_rows, max_rows) 74 | truncated = total_rows > rows_to_show 75 | df_display = df.head(rows_to_show) if df is not None else pd.DataFrame() 76 | 77 | if fmt == "markdown": 78 | header = "" 79 | if meta: 80 | # Render a compact meta header 81 | lines = ["Meta:"] 82 | for k, v in meta.items(): 83 | lines.append(f"- {k}: {v}") 84 | header = "\n".join(lines) + "\n\n" 85 | return header + format_df_to_markdown(df_display, max_rows=max_rows) 86 | 87 | if fmt == "csv": 88 | try: 89 | return df_display.to_csv(index=False) 90 | except Exception as e: 91 | logger.error("Error converting DataFrame to CSV: %s", e, exc_info=True) 92 | return "Error: Could not format data into CSV." 
93 | 94 | if fmt == "json": 95 | try: 96 | payload = { 97 | "data": [] if df_display is None else df_display.to_dict(orient="records"), 98 | "meta": { 99 | **(meta or {}), 100 | "total_rows": total_rows, 101 | "returned_rows": rows_to_show, 102 | "truncated": truncated, 103 | "columns": [] if df_display is None else list(df_display.columns), 104 | }, 105 | } 106 | return json.dumps(payload, ensure_ascii=False) 107 | except Exception as e: 108 | logger.error("Error converting DataFrame to JSON: %s", e, exc_info=True) 109 | return "Error: Could not format data into JSON." 110 | 111 | # Fallback to markdown if unknown format 112 | logger.warning("Unknown format '%s', falling back to markdown", fmt) 113 | return format_df_to_markdown(df_display, max_rows=max_rows) 114 | ``` -------------------------------------------------------------------------------- /src/tools/helpers.py: -------------------------------------------------------------------------------- ```python 1 | """ 2 | Helper tools for code normalization and constants discovery. 3 | These are agent-friendly utilities with clear, unambiguous parameters. 4 | """ 5 | import logging 6 | import re 7 | from typing import Optional 8 | 9 | from mcp.server.fastmcp import FastMCP 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def register_helpers_tools(app: FastMCP): 15 | """ 16 | Register helper/utility tools with the MCP app. 17 | """ 18 | 19 | @app.tool() 20 | def normalize_stock_code(code: str) -> str: 21 | """ 22 | Normalize a stock code to Baostock format. 23 | 24 | Rules: 25 | - If 6 digits and starts with '6' -> 'sh.<code>' 26 | - If 6 digits and starts with other -> 'sz.<code>' 27 | - Accept '600000.SH'/'000001.SZ' -> lower and reorder to 'sh.600000'/'sz.000001' 28 | - Accept 'sh600000'/'sz000001' -> insert dot 29 | 30 | Args: 31 | code: Raw stock code (e.g., '600000', '000001.SZ', 'sh600000'). 32 | 33 | Returns: 34 | Normalized code like 'sh.600000' or an error string if invalid. 
35 | 36 | Examples: 37 | - normalize_stock_code('600000') -> 'sh.600000' 38 | - normalize_stock_code('000001.SZ') -> 'sz.000001' 39 | """ 40 | logger.info("Tool 'normalize_stock_code' called with input=%s", code) 41 | try: 42 | raw = (code or "").strip() 43 | if not raw: 44 | return "Error: 'code' is required." 45 | 46 | # Patterns 47 | m = re.fullmatch(r"(?i)(sh|sz)[\.]?(\d{6})", raw) 48 | if m: 49 | ex = m.group(1).lower() 50 | num = m.group(2) 51 | return f"{ex}.{num}" 52 | 53 | m2 = re.fullmatch(r"(\d{6})[\.]?(?i)(sh|sz)", raw) 54 | if m2: 55 | num = m2.group(1) 56 | ex = m2.group(2).lower() 57 | return f"{ex}.{num}" 58 | 59 | m3 = re.fullmatch(r"(\d{6})", raw) 60 | if m3: 61 | num = m3.group(1) 62 | ex = "sh" if num.startswith("6") else "sz" 63 | return f"{ex}.{num}" 64 | 65 | return "Error: Unsupported code format. Examples: 'sh.600000', '600000', '000001.SZ'." 66 | except Exception as e: 67 | logger.exception("Exception in normalize_stock_code: %s", e) 68 | return f"Error: {e}" 69 | 70 | @app.tool() 71 | def list_tool_constants(kind: Optional[str] = None) -> str: 72 | """ 73 | List valid constants for tool parameters. 74 | 75 | Args: 76 | kind: Optional filter: 'frequency' | 'adjust_flag' | 'year_type' | 'index'. If None, show all. 77 | 78 | Returns: 79 | Markdown table(s) of constants and meanings. 
80 | """ 81 | logger.info("Tool 'list_tool_constants' called kind=%s", kind or "all") 82 | freq = [ 83 | ("d", "daily"), ("w", "weekly"), ("m", "monthly"), 84 | ("5", "5 minutes"), ("15", "15 minutes"), ("30", "30 minutes"), ("60", "60 minutes"), 85 | ] 86 | adjust = [("1", "forward adjusted"), ("2", "backward adjusted"), ("3", "unadjusted")] 87 | year_type = [("report", "announcement year"), ("operate", "ex-dividend year")] 88 | index = [("hs300", "CSI 300"), ("sz50", "SSE 50"), ("zz500", "CSI 500")] 89 | 90 | sections = [] 91 | def as_md(title: str, rows): 92 | if not rows: 93 | return "" 94 | header = f"### {title}\n\n| value | meaning |\n|---|---|\n" 95 | lines = [f"| {v} | {m} |" for (v, m) in rows] 96 | return header + "\n".join(lines) + "\n" 97 | 98 | k = (kind or "").strip().lower() 99 | if k in ("", "frequency"): 100 | sections.append(as_md("frequency", freq)) 101 | if k in ("", "adjust_flag"): 102 | sections.append(as_md("adjust_flag", adjust)) 103 | if k in ("", "year_type"): 104 | sections.append(as_md("year_type", year_type)) 105 | if k in ("", "index"): 106 | sections.append(as_md("index", index)) 107 | 108 | out = "\n".join(s for s in sections if s) 109 | if not out: 110 | return "Error: Invalid kind. Use one of 'frequency', 'adjust_flag', 'year_type', 'index'." 
111 | return out 112 | 113 | ``` -------------------------------------------------------------------------------- /src/data_source_interface.py: -------------------------------------------------------------------------------- ```python 1 | # Defines the abstract interface for financial data sources 2 | from abc import ABC, abstractmethod 3 | import pandas as pd 4 | from typing import Optional, List 5 | 6 | class DataSourceError(Exception): 7 | """Base exception for data source errors.""" 8 | pass 9 | 10 | 11 | class LoginError(DataSourceError): 12 | """Exception raised for login failures to the data source.""" 13 | pass 14 | 15 | 16 | class NoDataFoundError(DataSourceError): 17 | """Exception raised when no data is found for the given query.""" 18 | pass 19 | 20 | 21 | class FinancialDataSource(ABC): 22 | """ 23 | Abstract base class defining the interface for financial data sources. 24 | Implementations of this class provide access to specific financial data APIs 25 | (e.g., Baostock, Akshare). 26 | """ 27 | 28 | @abstractmethod 29 | def get_historical_k_data( 30 | self, 31 | code: str, 32 | start_date: str, 33 | end_date: str, 34 | frequency: str = "d", 35 | adjust_flag: str = "3", 36 | fields: Optional[List[str]] = None, 37 | ) -> pd.DataFrame: 38 | """ 39 | Fetches historical K-line (OHLCV) data for a given stock code. 40 | 41 | Args: 42 | code: The stock code (e.g., 'sh.600000', 'sz.000001'). 43 | start_date: Start date in 'YYYY-MM-DD' format. 44 | end_date: End date in 'YYYY-MM-DD' format. 45 | frequency: Data frequency. Common values depend on the underlying 46 | source (e.g., 'd' for daily, 'w' for weekly, 'm' for monthly, 47 | '5', '15', '30', '60' for minutes). Defaults to 'd'. 48 | adjust_flag: Adjustment flag for historical data. Common values 49 | depend on the source (e.g., '1' for forward adjusted, 50 | '2' for backward adjusted, '3' for non-adjusted). 51 | Defaults to '3'. 52 | fields: Optional list of specific fields to retrieve. 
If None, 53 | retrieves default fields defined by the implementation. 54 | 55 | Returns: 56 | A pandas DataFrame containing the historical K-line data, with 57 | columns corresponding to the requested fields. 58 | 59 | Raises: 60 | LoginError: If login to the data source fails. 61 | NoDataFoundError: If no data is found for the query. 62 | DataSourceError: For other data source related errors. 63 | ValueError: If input parameters are invalid. 64 | """ 65 | pass 66 | 67 | @abstractmethod 68 | def get_stock_basic_info(self, code: str) -> pd.DataFrame: 69 | """ 70 | Fetches basic information for a given stock code. 71 | 72 | Args: 73 | code: The stock code (e.g., 'sh.600000', 'sz.000001'). 74 | 75 | Returns: 76 | A pandas DataFrame containing the basic stock information. 77 | The structure and columns depend on the underlying data source. 78 | Typically contains info like name, industry, listing date, etc. 79 | 80 | Raises: 81 | LoginError: If login to the data source fails. 82 | NoDataFoundError: If no data is found for the query. 83 | DataSourceError: For other data source related errors. 84 | ValueError: If the input code is invalid. 
85 | """ 86 | pass 87 | 88 | @abstractmethod 89 | def get_trade_dates(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame: 90 | """Fetches trading dates information within a range.""" 91 | pass 92 | 93 | @abstractmethod 94 | def get_all_stock(self, date: Optional[str] = None) -> pd.DataFrame: 95 | """Fetches list of all stocks and their trading status on a given date.""" 96 | pass 97 | 98 | @abstractmethod 99 | def get_deposit_rate_data(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame: 100 | """Fetches benchmark deposit rates.""" 101 | pass 102 | 103 | @abstractmethod 104 | def get_loan_rate_data(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame: 105 | """Fetches benchmark loan rates.""" 106 | pass 107 | 108 | @abstractmethod 109 | def get_required_reserve_ratio_data(self, start_date: Optional[str] = None, end_date: Optional[str] = None, year_type: str = '0') -> pd.DataFrame: 110 | """Fetches required reserve ratio data.""" 111 | pass 112 | 113 | @abstractmethod 114 | def get_money_supply_data_month(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame: 115 | """Fetches monthly money supply data (M0, M1, M2).""" 116 | pass 117 | 118 | @abstractmethod 119 | def get_money_supply_data_year(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame: 120 | """Fetches yearly money supply data (M0, M1, M2 - year end balance).""" 121 | pass 122 | 123 | # Note: SHIBOR is not implemented in current Baostock bindings; no abstract method here. 124 | ``` -------------------------------------------------------------------------------- /src/tools/macroeconomic.py: -------------------------------------------------------------------------------- ```python 1 | """ 2 | Macroeconomic tools for the MCP server. 3 | Fetch interest rates, money supply data, and more with consistent options. 
4 | """ 5 | import logging 6 | from typing import Optional 7 | 8 | from mcp.server.fastmcp import FastMCP 9 | from src.data_source_interface import FinancialDataSource 10 | from src.tools.base import call_macro_data_tool 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def register_macroeconomic_tools(app: FastMCP, active_data_source: FinancialDataSource): 16 | """ 17 | Register macroeconomic data tools with the MCP app. 18 | 19 | Args: 20 | app: The FastMCP app instance 21 | active_data_source: The active financial data source 22 | """ 23 | 24 | @app.tool() 25 | def get_deposit_rate_data(start_date: Optional[str] = None, end_date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str: 26 | """ 27 | Fetches benchmark deposit rates (活期, 定期) within a date range. 28 | 29 | Args: 30 | start_date: Optional. Start date in 'YYYY-MM-DD' format. 31 | end_date: Optional. End date in 'YYYY-MM-DD' format. 32 | 33 | Returns: 34 | Markdown table with deposit rate data or an error message. 35 | """ 36 | return call_macro_data_tool( 37 | "get_deposit_rate_data", 38 | active_data_source.get_deposit_rate_data, 39 | "Deposit Rate", 40 | start_date, end_date, 41 | limit=limit, format=format 42 | ) 43 | 44 | @app.tool() 45 | def get_loan_rate_data(start_date: Optional[str] = None, end_date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str: 46 | """ 47 | Fetches benchmark loan rates (贷款利率) within a date range. 48 | 49 | Args: 50 | start_date: Optional. Start date in 'YYYY-MM-DD' format. 51 | end_date: Optional. End date in 'YYYY-MM-DD' format. 52 | 53 | Returns: 54 | Markdown table with loan rate data or an error message. 
55 | """ 56 | return call_macro_data_tool( 57 | "get_loan_rate_data", 58 | active_data_source.get_loan_rate_data, 59 | "Loan Rate", 60 | start_date, end_date, 61 | limit=limit, format=format 62 | ) 63 | 64 | @app.tool() 65 | def get_required_reserve_ratio_data(start_date: Optional[str] = None, end_date: Optional[str] = None, year_type: str = '0', limit: int = 250, format: str = "markdown") -> str: 66 | """ 67 | Fetches required reserve ratio data (存款准备金率) within a date range. 68 | 69 | Args: 70 | start_date: Optional. Start date in 'YYYY-MM-DD' format. 71 | end_date: Optional. End date in 'YYYY-MM-DD' format. 72 | year_type: Optional. Year type for date filtering. '0' for announcement date (公告日期, default), 73 | '1' for effective date (生效日期). 74 | 75 | Returns: 76 | Markdown table with required reserve ratio data or an error message. 77 | """ 78 | # Basic validation for year_type 79 | if year_type not in ['0', '1']: 80 | logger.warning(f"Invalid year_type requested: {year_type}") 81 | return "Error: Invalid year_type '{year_type}'. Valid options are '0' (announcement date) or '1' (effective date)." 82 | 83 | return call_macro_data_tool( 84 | "get_required_reserve_ratio_data", 85 | active_data_source.get_required_reserve_ratio_data, 86 | "Required Reserve Ratio", 87 | start_date, end_date, 88 | limit=limit, format=format, 89 | yearType=year_type # Pass the extra arg correctly named for Baostock 90 | ) 91 | 92 | @app.tool() 93 | def get_money_supply_data_month(start_date: Optional[str] = None, end_date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str: 94 | """ 95 | Fetches monthly money supply data (M0, M1, M2) within a date range. 96 | 97 | Args: 98 | start_date: Optional. Start date in 'YYYY-MM' format. 99 | end_date: Optional. End date in 'YYYY-MM' format. 100 | 101 | Returns: 102 | Markdown table with monthly money supply data or an error message. 
103 | """ 104 | # Add specific validation for YYYY-MM format if desired 105 | return call_macro_data_tool( 106 | "get_money_supply_data_month", 107 | active_data_source.get_money_supply_data_month, 108 | "Monthly Money Supply", 109 | start_date, end_date, 110 | limit=limit, format=format 111 | ) 112 | 113 | @app.tool() 114 | def get_money_supply_data_year(start_date: Optional[str] = None, end_date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str: 115 | """ 116 | Fetches yearly money supply data (M0, M1, M2 - year end balance) within a date range. 117 | 118 | Args: 119 | start_date: Optional. Start year in 'YYYY' format. 120 | end_date: Optional. End year in 'YYYY' format. 121 | 122 | Returns: 123 | Markdown table with yearly money supply data or an error message. 124 | """ 125 | # Add specific validation for YYYY format if desired 126 | return call_macro_data_tool( 127 | "get_money_supply_data_year", 128 | active_data_source.get_money_supply_data_year, 129 | "Yearly Money Supply", 130 | start_date, end_date, 131 | limit=limit, format=format 132 | ) 133 | 134 | # Note: SHIBOR 查询未在当前 baostock 绑定中提供,对应工具不实现。 135 | ``` -------------------------------------------------------------------------------- /src/tools/base.py: -------------------------------------------------------------------------------- ```python 1 | """ 2 | Base utilities for MCP tools. 3 | Shared helpers for calling data sources with consistent formatting and errors. 
4 | """ 5 | import logging 6 | from typing import Callable, Optional 7 | import pandas as pd 8 | 9 | from src.formatting.markdown_formatter import format_df_to_markdown, format_table_output 10 | from src.data_source_interface import NoDataFoundError, LoginError, DataSourceError 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def call_financial_data_tool( 16 | tool_name: str, 17 | # Pass the bound method like active_data_source.get_profit_data 18 | data_source_method: Callable, 19 | data_type_name: str, 20 | code: str, 21 | year: str, 22 | quarter: int, 23 | *, 24 | limit: int = 250, 25 | format: str = "markdown", 26 | ) -> str: 27 | """ 28 | Helper function to reduce repetition for financial data tools 29 | 30 | Args: 31 | tool_name: Name of the tool for logging 32 | data_source_method: Method to call on the data source 33 | data_type_name: Type of financial data (for logging) 34 | code: Stock code 35 | year: Year to query 36 | quarter: Quarter to query 37 | 38 | Returns: 39 | Markdown formatted string with results or error message 40 | """ 41 | logger.info(f"Tool '{tool_name}' called for {code}, {year}Q{quarter}") 42 | try: 43 | # Basic validation 44 | if not year.isdigit() or len(year) != 4: 45 | logger.warning(f"Invalid year format requested: {year}") 46 | return f"Error: Invalid year '{year}'. Please provide a 4-digit year." 47 | if not 1 <= quarter <= 4: 48 | logger.warning(f"Invalid quarter requested: {quarter}") 49 | return f"Error: Invalid quarter '{quarter}'. Must be between 1 and 4." 
50 | 51 | # Call the appropriate method on the already instantiated active_data_source 52 | df = data_source_method(code=code, year=year, quarter=quarter) 53 | logger.info( 54 | f"Successfully retrieved {data_type_name} data for {code}, {year}Q{quarter}.") 55 | meta = {"code": code, "year": year, "quarter": quarter, "dataset": data_type_name} 56 | return format_table_output(df, format=format, max_rows=limit, meta=meta) 57 | 58 | except NoDataFoundError as e: 59 | logger.warning(f"NoDataFoundError for {code}, {year}Q{quarter}: {e}") 60 | return f"Error: {e}" 61 | except LoginError as e: 62 | logger.error(f"LoginError for {code}: {e}") 63 | return f"Error: Could not connect to data source. {e}" 64 | except DataSourceError as e: 65 | logger.error(f"DataSourceError for {code}: {e}") 66 | return f"Error: An error occurred while fetching data. {e}" 67 | except ValueError as e: 68 | logger.warning(f"ValueError processing request for {code}: {e}") 69 | return f"Error: Invalid input parameter. {e}" 70 | except Exception as e: 71 | logger.exception( 72 | f"Unexpected Exception processing {tool_name} for {code}: {e}") 73 | return f"Error: An unexpected error occurred: {e}" 74 | 75 | 76 | def call_macro_data_tool( 77 | tool_name: str, 78 | data_source_method: Callable, 79 | data_type_name: str, 80 | start_date: Optional[str] = None, 81 | end_date: Optional[str] = None, 82 | *, 83 | limit: int = 250, 84 | format: str = "markdown", 85 | **kwargs # For extra params like year_type 86 | ) -> str: 87 | """ 88 | Helper function for macroeconomic data tools 89 | 90 | Args: 91 | tool_name: Name of the tool for logging 92 | data_source_method: Method to call on the data source 93 | data_type_name: Type of data (for logging) 94 | start_date: Optional start date 95 | end_date: Optional end date 96 | **kwargs: Additional keyword arguments to pass to data_source_method 97 | 98 | Returns: 99 | Markdown formatted string with results or error message 100 | """ 101 | date_range_log = f"from 
{start_date or 'default'} to {end_date or 'default'}" 102 | kwargs_log = f", extra_args={kwargs}" if kwargs else "" 103 | logger.info(f"Tool '{tool_name}' called {date_range_log}{kwargs_log}") 104 | try: 105 | # Call the appropriate method on the active_data_source 106 | df = data_source_method(start_date=start_date, end_date=end_date, **kwargs) 107 | logger.info(f"Successfully retrieved {data_type_name} data.") 108 | meta = {"dataset": data_type_name, "start_date": start_date, "end_date": end_date} | ({"extra": kwargs} if kwargs else {}) 109 | return format_table_output(df, format=format, max_rows=limit, meta=meta) 110 | except NoDataFoundError as e: 111 | logger.warning(f"NoDataFoundError: {e}") 112 | return f"Error: {e}" 113 | except LoginError as e: 114 | logger.error(f"LoginError: {e}") 115 | return f"Error: Could not connect to data source. {e}" 116 | except DataSourceError as e: 117 | logger.error(f"DataSourceError: {e}") 118 | return f"Error: An error occurred while fetching data. {e}" 119 | except ValueError as e: 120 | logger.warning(f"ValueError: {e}") 121 | return f"Error: Invalid input parameter. 
{e}" 122 | except Exception as e: 123 | logger.exception(f"Unexpected Exception processing {tool_name}: {e}") 124 | return f"Error: An unexpected error occurred: {e}" 125 | 126 | 127 | def call_index_constituent_tool( 128 | tool_name: str, 129 | data_source_method: Callable, 130 | index_name: str, 131 | date: Optional[str] = None, 132 | *, 133 | limit: int = 250, 134 | format: str = "markdown", 135 | ) -> str: 136 | """ 137 | Helper function for index constituent tools 138 | 139 | Args: 140 | tool_name: Name of the tool for logging 141 | data_source_method: Method to call on the data source 142 | index_name: Name of the index (for logging) 143 | date: Optional date to query 144 | 145 | Returns: 146 | Markdown formatted string with results or error message 147 | """ 148 | log_msg = f"Tool '{tool_name}' called for date={date or 'latest'}" 149 | logger.info(log_msg) 150 | try: 151 | # Add date validation if desired 152 | df = data_source_method(date=date) 153 | logger.info( 154 | f"Successfully retrieved {index_name} constituents for {date or 'latest'}.") 155 | meta = {"index": index_name, "as_of": date or "latest"} 156 | return format_table_output(df, format=format, max_rows=limit, meta=meta) 157 | except NoDataFoundError as e: 158 | logger.warning(f"NoDataFoundError: {e}") 159 | return f"Error: {e}" 160 | except LoginError as e: 161 | logger.error(f"LoginError: {e}") 162 | return f"Error: Could not connect to data source. {e}" 163 | except DataSourceError as e: 164 | logger.error(f"DataSourceError: {e}") 165 | return f"Error: An error occurred while fetching data. {e}" 166 | except ValueError as e: 167 | logger.warning(f"ValueError: {e}") 168 | return f"Error: Invalid input parameter. 
{e}" 169 | except Exception as e: 170 | logger.exception(f"Unexpected Exception processing {tool_name}: {e}") 171 | return f"Error: An unexpected error occurred: {e}" 172 | ``` -------------------------------------------------------------------------------- /src/tools/market_overview.py: -------------------------------------------------------------------------------- ```python 1 | """ 2 | Market overview tools for the MCP server. 3 | Includes trading calendar, stock list, and discovery helpers. 4 | """ 5 | import logging 6 | from typing import Optional 7 | 8 | from mcp.server.fastmcp import FastMCP 9 | from src.data_source_interface import FinancialDataSource, NoDataFoundError, LoginError, DataSourceError 10 | from src.formatting.markdown_formatter import format_df_to_markdown, format_table_output 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def register_market_overview_tools(app: FastMCP, active_data_source: FinancialDataSource): 16 | """ 17 | Register market overview tools with the MCP app. 18 | 19 | Args: 20 | app: The FastMCP app instance 21 | active_data_source: The active financial data source 22 | """ 23 | 24 | @app.tool() 25 | def get_trade_dates(start_date: Optional[str] = None, end_date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str: 26 | """ 27 | Fetch trading dates within a specified range. 28 | 29 | Args: 30 | start_date: Optional. Start date in 'YYYY-MM-DD' format. Defaults to 2015-01-01 if None. 31 | end_date: Optional. End date in 'YYYY-MM-DD' format. Defaults to the current date if None. 32 | 33 | Returns: 34 | Markdown table with 'is_trading_day' (1=trading, 0=non-trading). 
35 | """ 36 | logger.info( 37 | f"Tool 'get_trade_dates' called for range {start_date or 'default'} to {end_date or 'default'}") 38 | try: 39 | # Add date validation if desired 40 | df = active_data_source.get_trade_dates( 41 | start_date=start_date, end_date=end_date) 42 | logger.info("Successfully retrieved trade dates.") 43 | meta = {"start_date": start_date or "default", "end_date": end_date or "default"} 44 | return format_table_output(df, format=format, max_rows=limit, meta=meta) 45 | 46 | except NoDataFoundError as e: 47 | logger.warning(f"NoDataFoundError: {e}") 48 | return f"Error: {e}" 49 | except LoginError as e: 50 | logger.error(f"LoginError: {e}") 51 | return f"Error: Could not connect to data source. {e}" 52 | except DataSourceError as e: 53 | logger.error(f"DataSourceError: {e}") 54 | return f"Error: An error occurred while fetching data. {e}" 55 | except ValueError as e: 56 | logger.warning(f"ValueError: {e}") 57 | return f"Error: Invalid input parameter. {e}" 58 | except Exception as e: 59 | logger.exception( 60 | f"Unexpected Exception processing get_trade_dates: {e}") 61 | return f"Error: An unexpected error occurred: {e}" 62 | 63 | @app.tool() 64 | def get_all_stock(date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str: 65 | """ 66 | Fetch a list of all stocks (A-shares and indices) and their trading status for a date. 67 | 68 | Args: 69 | date: Optional. The date in 'YYYY-MM-DD' format. If None, uses the current date. 70 | 71 | Returns: 72 | Markdown table listing stock codes and trading status (1=trading, 0=suspended). 
73 | """ 74 | logger.info( 75 | f"Tool 'get_all_stock' called for date={date or 'default'}") 76 | try: 77 | # Add date validation if desired 78 | df = active_data_source.get_all_stock(date=date) 79 | logger.info( 80 | f"Successfully retrieved stock list for {date or 'default'}.") 81 | meta = {"as_of": date or "default"} 82 | return format_table_output(df, format=format, max_rows=limit, meta=meta) 83 | 84 | except NoDataFoundError as e: 85 | logger.warning(f"NoDataFoundError: {e}") 86 | return f"Error: {e}" 87 | except LoginError as e: 88 | logger.error(f"LoginError: {e}") 89 | return f"Error: Could not connect to data source. {e}" 90 | except DataSourceError as e: 91 | logger.error(f"DataSourceError: {e}") 92 | return f"Error: An error occurred while fetching data. {e}" 93 | except ValueError as e: 94 | logger.warning(f"ValueError: {e}") 95 | return f"Error: Invalid input parameter. {e}" 96 | except Exception as e: 97 | logger.exception( 98 | f"Unexpected Exception processing get_all_stock: {e}") 99 | return f"Error: An unexpected error occurred: {e}" 100 | 101 | @app.tool() 102 | def search_stocks(keyword: str, date: Optional[str] = None, limit: int = 50, format: str = "markdown") -> str: 103 | """ 104 | Search stocks by code substring on a date. 105 | 106 | Args: 107 | keyword: Substring to match in the stock code (e.g., '600', '000001'). 108 | date: Optional 'YYYY-MM-DD'. If None, uses current date. 109 | limit: Max rows to return. Defaults to 50. 110 | format: Output format: 'markdown' | 'json' | 'csv'. Defaults to 'markdown'. 111 | 112 | Returns: 113 | Matching stock codes with their trading status. 114 | """ 115 | logger.info("Tool 'search_stocks' called keyword=%s, date=%s, limit=%s, format=%s", keyword, date or "default", limit, format) 116 | try: 117 | if not keyword or not keyword.strip(): 118 | return "Error: 'keyword' is required (substring of code)." 
119 | df = active_data_source.get_all_stock(date=date) 120 | if df is None or df.empty: 121 | return "(No data available to display)" 122 | kw = keyword.strip().lower() 123 | # baostock returns 'code' like 'sh.600000' 124 | filtered = df[df["code"].str.lower().str.contains(kw, na=False)] 125 | meta = {"keyword": keyword, "as_of": date or "current"} 126 | return format_table_output(filtered, format=format, max_rows=limit, meta=meta) 127 | except Exception as e: 128 | logger.exception("Exception processing search_stocks: %s", e) 129 | return f"Error: An unexpected error occurred: {e}" 130 | 131 | @app.tool() 132 | def get_suspensions(date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str: 133 | """ 134 | List suspended stocks for a date. 135 | 136 | Args: 137 | date: Optional 'YYYY-MM-DD'. If None, uses current date. 138 | limit: Max rows to return. Defaults to 250. 139 | format: Output format: 'markdown' | 'json' | 'csv'. Defaults to 'markdown'. 140 | 141 | Returns: 142 | Table of stocks where tradeStatus==0. 143 | """ 144 | logger.info("Tool 'get_suspensions' called date=%s, limit=%s, format=%s", date or "current", limit, format) 145 | try: 146 | df = active_data_source.get_all_stock(date=date) 147 | if df is None or df.empty: 148 | return "(No data available to display)" 149 | # tradeStatus: '1' trading, '0' suspended 150 | if "tradeStatus" not in df.columns: 151 | return "Error: 'tradeStatus' column not present in data source response." 
152 | suspended = df[df["tradeStatus"] == '0'] 153 | meta = {"as_of": date or "current", "total_suspended": int(suspended.shape[0])} 154 | return format_table_output(suspended, format=format, max_rows=limit, meta=meta) 155 | except Exception as e: 156 | logger.exception("Exception processing get_suspensions: %s", e) 157 | return f"Error: An unexpected error occurred: {e}" 158 | ``` -------------------------------------------------------------------------------- /src/tools/financial_reports.py: -------------------------------------------------------------------------------- ```python 1 | """ 2 | Financial report tools for the MCP server. 3 | Clear, strongly-typed parameters; consistent output options. 4 | """ 5 | import logging 6 | from typing import List, Optional 7 | 8 | from mcp.server.fastmcp import FastMCP 9 | from src.data_source_interface import FinancialDataSource 10 | from src.tools.base import call_financial_data_tool 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def register_financial_report_tools(app: FastMCP, active_data_source: FinancialDataSource): 16 | """ 17 | Register financial report related tools with the MCP app. 18 | 19 | Args: 20 | app: The FastMCP app instance 21 | active_data_source: The active financial data source 22 | """ 23 | 24 | @app.tool() 25 | def get_profit_data(code: str, year: str, quarter: int, limit: int = 250, format: str = "markdown") -> str: 26 | """ 27 | Get quarterly profitability data (e.g., ROE, net profit margin) for a stock. 28 | 29 | Args: 30 | code: The stock code (e.g., 'sh.600000'). 31 | year: The 4-digit year (e.g., '2023'). 32 | quarter: The quarter (1, 2, 3, or 4). 33 | 34 | Returns: 35 | Profitability metrics table. 
36 | """ 37 | return call_financial_data_tool( 38 | "get_profit_data", 39 | active_data_source.get_profit_data, 40 | "Profitability", 41 | code, year, quarter, 42 | limit=limit, format=format 43 | ) 44 | 45 | @app.tool() 46 | def get_operation_data(code: str, year: str, quarter: int, limit: int = 250, format: str = "markdown") -> str: 47 | """ 48 | Get quarterly operation capability data (e.g., turnover ratios) for a stock. 49 | 50 | Args: 51 | code: The stock code (e.g., 'sh.600000'). 52 | year: The 4-digit year (e.g., '2023'). 53 | quarter: The quarter (1, 2, 3, or 4). 54 | 55 | Returns: 56 | Operation capability metrics table. 57 | """ 58 | return call_financial_data_tool( 59 | "get_operation_data", 60 | active_data_source.get_operation_data, 61 | "Operation Capability", 62 | code, year, quarter, 63 | limit=limit, format=format 64 | ) 65 | 66 | @app.tool() 67 | def get_growth_data(code: str, year: str, quarter: int, limit: int = 250, format: str = "markdown") -> str: 68 | """ 69 | Get quarterly growth capability data (e.g., YOY growth rates) for a stock. 70 | 71 | Args: 72 | code: The stock code (e.g., 'sh.600000'). 73 | year: The 4-digit year (e.g., '2023'). 74 | quarter: The quarter (1, 2, 3, or 4). 75 | 76 | Returns: 77 | Growth capability metrics table. 78 | """ 79 | return call_financial_data_tool( 80 | "get_growth_data", 81 | active_data_source.get_growth_data, 82 | "Growth Capability", 83 | code, year, quarter, 84 | limit=limit, format=format 85 | ) 86 | 87 | @app.tool() 88 | def get_balance_data(code: str, year: str, quarter: int, limit: int = 250, format: str = "markdown") -> str: 89 | """ 90 | Get quarterly balance sheet / solvency data (e.g., current ratio, debt ratio) for a stock. 91 | 92 | Args: 93 | code: The stock code (e.g., 'sh.600000'). 94 | year: The 4-digit year (e.g., '2023'). 95 | quarter: The quarter (1, 2, 3, or 4). 96 | 97 | Returns: 98 | Balance sheet metrics table. 
99 | """ 100 | return call_financial_data_tool( 101 | "get_balance_data", 102 | active_data_source.get_balance_data, 103 | "Balance Sheet", 104 | code, year, quarter, 105 | limit=limit, format=format 106 | ) 107 | 108 | @app.tool() 109 | def get_cash_flow_data(code: str, year: str, quarter: int, limit: int = 250, format: str = "markdown") -> str: 110 | """ 111 | Get quarterly cash flow data (e.g., CFO/Operating Revenue ratio) for a stock. 112 | 113 | Args: 114 | code: The stock code (e.g., 'sh.600000'). 115 | year: The 4-digit year (e.g., '2023'). 116 | quarter: The quarter (1, 2, 3, or 4). 117 | 118 | Returns: 119 | Cash flow metrics table. 120 | """ 121 | return call_financial_data_tool( 122 | "get_cash_flow_data", 123 | active_data_source.get_cash_flow_data, 124 | "Cash Flow", 125 | code, year, quarter, 126 | limit=limit, format=format 127 | ) 128 | 129 | @app.tool() 130 | def get_dupont_data(code: str, year: str, quarter: int, limit: int = 250, format: str = "markdown") -> str: 131 | """ 132 | Get quarterly DuPont analysis data (ROE decomposition) for a stock. 133 | 134 | Args: 135 | code: The stock code (e.g., 'sh.600000'). 136 | year: The 4-digit year (e.g., '2023'). 137 | quarter: The quarter (1, 2, 3, or 4). 138 | 139 | Returns: 140 | DuPont analysis metrics table. 141 | """ 142 | return call_financial_data_tool( 143 | "get_dupont_data", 144 | active_data_source.get_dupont_data, 145 | "DuPont Analysis", 146 | code, year, quarter, 147 | limit=limit, format=format 148 | ) 149 | 150 | @app.tool() 151 | def get_performance_express_report(code: str, start_date: str, end_date: str, limit: int = 250, format: str = "markdown") -> str: 152 | """ 153 | Fetches performance express reports (业绩快报) for a stock within a date range. 154 | Note: Companies are not required to publish these except in specific cases. 155 | 156 | Args: 157 | code: The stock code (e.g., 'sh.600000'). 158 | start_date: Start date (for report publication/update) in 'YYYY-MM-DD' format. 
159 | end_date: End date (for report publication/update) in 'YYYY-MM-DD' format. 160 | 161 | Returns: 162 | Markdown table with performance express report data or an error message. 163 | """ 164 | logger.info( 165 | f"Tool 'get_performance_express_report' called for {code} ({start_date} to {end_date})") 166 | try: 167 | # Add date validation if desired 168 | df = active_data_source.get_performance_express_report( 169 | code=code, start_date=start_date, end_date=end_date) 170 | logger.info( 171 | f"Successfully retrieved performance express reports for {code}.") 172 | from src.formatting.markdown_formatter import format_table_output 173 | meta = {"code": code, "start_date": start_date, "end_date": end_date, "dataset": "performance_express"} 174 | return format_table_output(df, format=format, max_rows=limit, meta=meta) 175 | 176 | except Exception as e: 177 | logger.exception( 178 | f"Exception processing get_performance_express_report for {code}: {e}") 179 | return f"Error: An unexpected error occurred: {e}" 180 | 181 | @app.tool() 182 | def get_forecast_report(code: str, start_date: str, end_date: str, limit: int = 250, format: str = "markdown") -> str: 183 | """ 184 | Fetches performance forecast reports (业绩预告) for a stock within a date range. 185 | Note: Companies are not required to publish these except in specific cases. 186 | 187 | Args: 188 | code: The stock code (e.g., 'sh.600000'). 189 | start_date: Start date (for report publication/update) in 'YYYY-MM-DD' format. 190 | end_date: End date (for report publication/update) in 'YYYY-MM-DD' format. 191 | 192 | Returns: 193 | Markdown table with performance forecast report data or an error message. 
194 | """ 195 | logger.info( 196 | f"Tool 'get_forecast_report' called for {code} ({start_date} to {end_date})") 197 | try: 198 | # Add date validation if desired 199 | df = active_data_source.get_forecast_report( 200 | code=code, start_date=start_date, end_date=end_date) 201 | logger.info( 202 | f"Successfully retrieved performance forecast reports for {code}.") 203 | from src.formatting.markdown_formatter import format_table_output 204 | meta = {"code": code, "start_date": start_date, "end_date": end_date, "dataset": "forecast"} 205 | return format_table_output(df, format=format, max_rows=limit, meta=meta) 206 | 207 | except Exception as e: 208 | logger.exception( 209 | f"Exception processing get_forecast_report for {code}: {e}") 210 | return f"Error: An unexpected error occurred: {e}" 211 | ``` -------------------------------------------------------------------------------- /src/tools/indices.py: -------------------------------------------------------------------------------- ```python 1 | """ 2 | Index-related tools for the MCP server. 3 | Includes index constituents and industry utilities with clear, discoverable parameters. 4 | """ 5 | import logging 6 | from typing import Optional, List 7 | 8 | from mcp.server.fastmcp import FastMCP 9 | from src.data_source_interface import FinancialDataSource 10 | from src.tools.base import call_index_constituent_tool 11 | from src.formatting.markdown_formatter import format_table_output 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def register_index_tools(app: FastMCP, active_data_source: FinancialDataSource): 17 | """ 18 | Register index related tools with the MCP app. 
19 | 20 | Args: 21 | app: The FastMCP app instance 22 | active_data_source: The active financial data source 23 | """ 24 | 25 | @app.tool() 26 | def get_stock_industry(code: Optional[str] = None, date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str: 27 | """ 28 | Get industry classification for a specific stock or all stocks on a date. 29 | 30 | Args: 31 | code: Optional stock code in Baostock format (e.g., 'sh.600000'). If None, returns all. 32 | date: Optional 'YYYY-MM-DD'. If None, uses the latest available date. 33 | 34 | Returns: 35 | Markdown table with industry data or an error message. 36 | """ 37 | log_msg = f"Tool 'get_stock_industry' called for code={code or 'all'}, date={date or 'latest'}" 38 | logger.info(log_msg) 39 | try: 40 | # Add date validation if desired 41 | df = active_data_source.get_stock_industry(code=code, date=date) 42 | logger.info( 43 | f"Successfully retrieved industry data for {code or 'all'}, {date or 'latest'}.") 44 | meta = {"code": code or "all", "as_of": date or "latest"} 45 | return format_table_output(df, format=format, max_rows=limit, meta=meta) 46 | 47 | except Exception as e: 48 | logger.exception( 49 | f"Exception processing get_stock_industry: {e}") 50 | return f"Error: An unexpected error occurred: {e}" 51 | 52 | @app.tool() 53 | def get_sz50_stocks(date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str: 54 | """ 55 | Fetches the constituent stocks of the SZSE 50 Index for a given date. 56 | 57 | Args: 58 | date: Optional. The date in 'YYYY-MM-DD' format. If None, uses the latest available date. 59 | 60 | Returns: 61 | Markdown table with SZSE 50 constituent stocks or an error message. 
62 | """ 63 | return call_index_constituent_tool( 64 | "get_sz50_stocks", 65 | active_data_source.get_sz50_stocks, 66 | "SZSE 50", 67 | date, 68 | limit=limit, format=format 69 | ) 70 | 71 | @app.tool() 72 | def get_hs300_stocks(date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str: 73 | """ 74 | Fetch the constituent stocks of the CSI 300 Index for a given date. 75 | 76 | Args: 77 | date: Optional 'YYYY-MM-DD'. If None, uses the latest available date. 78 | 79 | Returns: 80 | Markdown table with CSI 300 constituent stocks or an error message. 81 | """ 82 | return call_index_constituent_tool( 83 | "get_hs300_stocks", 84 | active_data_source.get_hs300_stocks, 85 | "CSI 300", 86 | date, 87 | limit=limit, format=format 88 | ) 89 | 90 | @app.tool() 91 | def get_zz500_stocks(date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str: 92 | """ 93 | Fetch the constituent stocks of the CSI 500 Index for a given date. 94 | 95 | Args: 96 | date: Optional 'YYYY-MM-DD'. If None, uses the latest available date. 97 | 98 | Returns: 99 | Markdown table with CSI 500 constituent stocks or an error message. 100 | """ 101 | return call_index_constituent_tool( 102 | "get_zz500_stocks", 103 | active_data_source.get_zz500_stocks, 104 | "CSI 500", 105 | date, 106 | limit=limit, format=format 107 | ) 108 | 109 | @app.tool() 110 | def get_index_constituents( 111 | index: str, 112 | date: Optional[str] = None, 113 | limit: int = 250, 114 | format: str = "markdown", 115 | ) -> str: 116 | """ 117 | Get constituents for a major index. 118 | 119 | Args: 120 | index: One of 'hs300' (CSI 300), 'sz50' (SSE 50), 'zz500' (CSI 500). 121 | date: Optional 'YYYY-MM-DD'. If None, uses the latest available date. 122 | limit: Max rows to return (pagination helper). Defaults to 250. 123 | format: Output format: 'markdown' | 'json' | 'csv'. Defaults to 'markdown'. 124 | 125 | Returns: 126 | Table of index constituents in the requested format. Defaults to Markdown. 
127 | 128 | Examples: 129 | - get_index_constituents(index='hs300') 130 | - get_index_constituents(index='sz50', date='2024-12-31', format='json', limit=100) 131 | """ 132 | logger.info( 133 | f"Tool 'get_index_constituents' called index={index}, date={date or 'latest'}, limit={limit}, format={format}") 134 | try: 135 | key = (index or "").strip().lower() 136 | if key not in {"hs300", "sz50", "zz500"}: 137 | return "Error: Invalid index. Valid options are 'hs300', 'sz50', 'zz500'." 138 | 139 | if key == "hs300": 140 | df = active_data_source.get_hs300_stocks(date=date) 141 | elif key == "sz50": 142 | df = active_data_source.get_sz50_stocks(date=date) 143 | else: 144 | df = active_data_source.get_zz500_stocks(date=date) 145 | 146 | meta = { 147 | "index": key, 148 | "as_of": date or "latest", 149 | } 150 | return format_table_output(df, format=format, max_rows=limit, meta=meta) 151 | except Exception as e: 152 | logger.exception("Exception processing get_index_constituents: %s", e) 153 | return f"Error: An unexpected error occurred: {e}" 154 | 155 | @app.tool() 156 | def list_industries(date: Optional[str] = None, format: str = "markdown") -> str: 157 | """ 158 | List distinct industries for a given date. 159 | 160 | Args: 161 | date: Optional 'YYYY-MM-DD'. If None, uses the latest available date. 162 | format: Output format: 'markdown' | 'json' | 'csv'. Defaults to 'markdown'. 163 | 164 | Returns: 165 | One-column table of industries. 
166 | """ 167 | logger.info("Tool 'list_industries' called date=%s", date or "latest") 168 | try: 169 | df = active_data_source.get_stock_industry(code=None, date=date) 170 | if df is None or df.empty: 171 | return "(No data available to display)" 172 | col = "industry" if "industry" in df.columns else df.columns[-1] 173 | out = df[[col]].drop_duplicates().sort_values(by=col) 174 | out = out.rename(columns={col: "industry"}) 175 | meta = {"as_of": date or "latest", "count": int(out.shape[0])} 176 | return format_table_output(out, format=format, max_rows=out.shape[0], meta=meta) 177 | except Exception as e: 178 | logger.exception("Exception processing list_industries: %s", e) 179 | return f"Error: An unexpected error occurred: {e}" 180 | 181 | @app.tool() 182 | def get_industry_members( 183 | industry: str, 184 | date: Optional[str] = None, 185 | limit: int = 250, 186 | format: str = "markdown", 187 | ) -> str: 188 | """ 189 | Get all stocks that belong to a given industry on a date. 190 | 191 | Args: 192 | industry: Exact industry name to filter by (see list_industries). 193 | date: Optional 'YYYY-MM-DD'. If None, uses the latest available date. 194 | limit: Max rows to return. Defaults to 250. 195 | format: Output format: 'markdown' | 'json' | 'csv'. Defaults to 'markdown'. 196 | 197 | Returns: 198 | Table of stocks in the given industry. 199 | """ 200 | logger.info( 201 | "Tool 'get_industry_members' called industry=%s, date=%s, limit=%s, format=%s", 202 | industry, date or "latest", limit, format, 203 | ) 204 | try: 205 | if not industry or not industry.strip(): 206 | return "Error: 'industry' is required. Call list_industries() to discover available values." 
207 | df = active_data_source.get_stock_industry(code=None, date=date) 208 | if df is None or df.empty: 209 | return "(No data available to display)" 210 | col = "industry" if "industry" in df.columns else df.columns[-1] 211 | filtered = df[df[col] == industry].copy() 212 | meta = {"industry": industry, "as_of": date or "latest"} 213 | return format_table_output(filtered, format=format, max_rows=limit, meta=meta) 214 | except Exception as e: 215 | logger.exception("Exception processing get_industry_members: %s", e) 216 | return f"Error: An unexpected error occurred: {e}" 217 | ``` -------------------------------------------------------------------------------- /src/tools/date_utils.py: -------------------------------------------------------------------------------- ```python 1 | """ 2 | Date utility tools for the MCP server. 3 | Convenience helpers around trading days and analysis timeframes. 4 | """ 5 | import logging 6 | from datetime import datetime, timedelta 7 | import calendar 8 | 9 | from mcp.server.fastmcp import FastMCP 10 | from src.data_source_interface import FinancialDataSource 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def register_date_utils_tools(app: FastMCP, active_data_source: FinancialDataSource): 16 | """ 17 | Register date utility tools with the MCP app. 18 | 19 | Args: 20 | app: The FastMCP app instance 21 | active_data_source: The active financial data source 22 | """ 23 | 24 | @app.tool() 25 | def get_latest_trading_date() -> str: 26 | """ 27 | Get the latest trading date up to today. 28 | 29 | Returns: 30 | The latest trading date in 'YYYY-MM-DD' format. 
31 | """ 32 | logger.info("Tool 'get_latest_trading_date' called") 33 | try: 34 | today = datetime.now().strftime("%Y-%m-%d") 35 | # Query within the current month (safe bound) 36 | start_date = (datetime.now().replace(day=1)).strftime("%Y-%m-%d") 37 | end_date = (datetime.now().replace(day=28)).strftime("%Y-%m-%d") 38 | 39 | df = active_data_source.get_trade_dates( 40 | start_date=start_date, end_date=end_date) 41 | 42 | valid_trading_days = df[df['is_trading_day'] == '1']['calendar_date'].tolist() 43 | 44 | latest_trading_date = None 45 | for dstr in valid_trading_days: 46 | if dstr <= today and (latest_trading_date is None or dstr > latest_trading_date): 47 | latest_trading_date = dstr 48 | 49 | if latest_trading_date: 50 | logger.info("Latest trading date found: %s", latest_trading_date) 51 | return latest_trading_date 52 | else: 53 | logger.warning("No trading dates found before today, returning today's date") 54 | return today 55 | 56 | except Exception as e: 57 | logger.exception("Error determining latest trading date: %s", e) 58 | return datetime.now().strftime("%Y-%m-%d") 59 | 60 | @app.tool() 61 | def get_market_analysis_timeframe(period: str = "recent") -> str: 62 | """ 63 | Get a market analysis timeframe label tuned for current calendar context. 64 | 65 | Args: 66 | period: One of 'recent' (default), 'quarter', 'half_year', 'year'. 67 | 68 | Returns: 69 | A human-friendly label plus ISO range, like "2025年1月-3月 (ISO: 2025-01-01 至 2025-03-31)". 
70 | """ 71 | logger.info( 72 | f"Tool 'get_market_analysis_timeframe' called with period={period}") 73 | 74 | now = datetime.now() 75 | end_date = now 76 | 77 | if period == "recent": 78 | if now.day < 15: 79 | if now.month == 1: 80 | start_date = datetime(now.year - 1, 11, 1) 81 | middle_date = datetime(now.year - 1, 12, 1) 82 | elif now.month == 2: 83 | start_date = datetime(now.year, 1, 1) 84 | middle_date = start_date 85 | else: 86 | start_date = datetime(now.year, now.month - 2, 1) 87 | middle_date = datetime(now.year, now.month - 1, 1) 88 | else: 89 | if now.month == 1: 90 | start_date = datetime(now.year - 1, 12, 1) 91 | middle_date = start_date 92 | else: 93 | start_date = datetime(now.year, now.month - 1, 1) 94 | middle_date = start_date 95 | 96 | elif period == "quarter": 97 | if now.month <= 3: 98 | start_date = datetime(now.year - 1, now.month + 9, 1) 99 | else: 100 | start_date = datetime(now.year, now.month - 3, 1) 101 | middle_date = start_date 102 | 103 | elif period == "half_year": 104 | if now.month <= 6: 105 | start_date = datetime(now.year - 1, now.month + 6, 1) 106 | else: 107 | start_date = datetime(now.year, now.month - 6, 1) 108 | middle_date = datetime(start_date.year, start_date.month + 3, 1) if start_date.month <= 9 else \ 109 | datetime(start_date.year + 1, start_date.month - 9, 1) 110 | 111 | elif period == "year": 112 | start_date = datetime(now.year - 1, now.month, 1) 113 | middle_date = datetime(start_date.year, start_date.month + 6, 1) if start_date.month <= 6 else \ 114 | datetime(start_date.year + 1, start_date.month - 6, 1) 115 | else: 116 | if now.month == 1: 117 | start_date = datetime(now.year - 1, 12, 1) 118 | else: 119 | start_date = datetime(now.year, now.month - 1, 1) 120 | middle_date = start_date 121 | 122 | def get_month_end_day(year, month): 123 | return calendar.monthrange(year, month)[1] 124 | 125 | end_day = min(get_month_end_day(end_date.year, end_date.month), end_date.day) 126 | end_iso_date = 
f"{end_date.year}-{end_date.month:02d}-{end_day:02d}" 127 | 128 | start_iso_date = f"{start_date.year}-{start_date.month:02d}-01" 129 | 130 | if start_date.year != end_date.year: 131 | date_range = f"{start_date.year}年{start_date.month}月-{end_date.year}年{end_date.month}月" 132 | elif middle_date.month != start_date.month and middle_date.month != end_date.month: 133 | date_range = f"{start_date.year}年{start_date.month}月-{middle_date.month}月-{end_date.month}月" 134 | elif start_date.month != end_date.month: 135 | date_range = f"{start_date.year}年{start_date.month}月-{end_date.month}月" 136 | else: 137 | date_range = f"{start_date.year}年{start_date.month}月" 138 | 139 | result = f"{date_range} (ISO: {start_iso_date} to {end_iso_date})" 140 | logger.info(f"Generated market analysis timeframe: {result}") 141 | return result 142 | 143 | @app.tool() 144 | def is_trading_day(date: str) -> str: 145 | """ 146 | Check whether a given date is a trading day. 147 | 148 | Args: 149 | date: 'YYYY-MM-DD'. 150 | 151 | Returns: 152 | 'Yes' or 'No'. 153 | 154 | Examples: 155 | - is_trading_day('2025-01-03') 156 | """ 157 | logger.info("Tool 'is_trading_day' called date=%s", date) 158 | try: 159 | df = active_data_source.get_trade_dates(start_date=date, end_date=date) 160 | if df is None or df.empty: 161 | return "No" 162 | flag_col = 'is_trading_day' if 'is_trading_day' in df.columns else df.columns[-1] 163 | val = str(df.iloc[0][flag_col]) 164 | return "Yes" if val == '1' else "No" 165 | except Exception as e: 166 | logger.exception("Exception processing is_trading_day: %s", e) 167 | return f"Error: {e}" 168 | 169 | @app.tool() 170 | def previous_trading_day(date: str) -> str: 171 | """ 172 | Get the previous trading day before a given date. 173 | 174 | Args: 175 | date: 'YYYY-MM-DD'. 176 | 177 | Returns: 178 | The previous trading day in 'YYYY-MM-DD'. If none found nearby, returns input date. 
179 | """ 180 | logger.info("Tool 'previous_trading_day' called date=%s", date) 181 | try: 182 | d = datetime.strptime(date, "%Y-%m-%d") 183 | start = (d - timedelta(days=30)).strftime("%Y-%m-%d") 184 | end = date 185 | df = active_data_source.get_trade_dates(start_date=start, end_date=end) 186 | if df is None or df.empty: 187 | return date 188 | flag_col = 'is_trading_day' if 'is_trading_day' in df.columns else df.columns[-1] 189 | day_col = 'calendar_date' if 'calendar_date' in df.columns else df.columns[0] 190 | candidates = df[(df[flag_col] == '1') & (df[day_col] < date)].sort_values(by=day_col) 191 | if candidates.empty: 192 | return date 193 | return str(candidates.iloc[-1][day_col]) 194 | except Exception as e: 195 | logger.exception("Exception processing previous_trading_day: %s", e) 196 | return f"Error: {e}" 197 | 198 | @app.tool() 199 | def next_trading_day(date: str) -> str: 200 | """ 201 | Get the next trading day after a given date. 202 | 203 | Args: 204 | date: 'YYYY-MM-DD'. 205 | 206 | Returns: 207 | The next trading day in 'YYYY-MM-DD'. If none found nearby, returns input date. 
208 | """ 209 | logger.info("Tool 'next_trading_day' called date=%s", date) 210 | try: 211 | d = datetime.strptime(date, "%Y-%m-%d") 212 | start = date 213 | end = (d + timedelta(days=30)).strftime("%Y-%m-%d") 214 | df = active_data_source.get_trade_dates(start_date=start, end_date=end) 215 | if df is None or df.empty: 216 | return date 217 | flag_col = 'is_trading_day' if 'is_trading_day' in df.columns else df.columns[-1] 218 | day_col = 'calendar_date' if 'calendar_date' in df.columns else df.columns[0] 219 | candidates = df[(df[flag_col] == '1') & (df[day_col] > date)].sort_values(by=day_col) 220 | if candidates.empty: 221 | return date 222 | return str(candidates.iloc[0][day_col]) 223 | except Exception as e: 224 | logger.exception("Exception processing next_trading_day: %s", e) 225 | return f"Error: {e}" 226 | 227 | ``` -------------------------------------------------------------------------------- /src/tools/stock_market.py: -------------------------------------------------------------------------------- ```python 1 | """ 2 | Stock market tools for the MCP server. 3 | Historical prices, basic info, dividends, and adjust factors with clear options. 4 | """ 5 | import logging 6 | from typing import List, Optional 7 | 8 | from mcp.server.fastmcp import FastMCP 9 | from src.data_source_interface import FinancialDataSource, NoDataFoundError, LoginError, DataSourceError 10 | from src.formatting.markdown_formatter import format_df_to_markdown, format_table_output 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def register_stock_market_tools(app: FastMCP, active_data_source: FinancialDataSource): 16 | """ 17 | Register stock market data tools with the MCP app. 
18 | 19 | Args: 20 | app: The FastMCP app instance 21 | active_data_source: The active financial data source 22 | """ 23 | 24 | @app.tool() 25 | def get_historical_k_data( 26 | code: str, 27 | start_date: str, 28 | end_date: str, 29 | frequency: str = "d", 30 | adjust_flag: str = "3", 31 | fields: Optional[List[str]] = None, 32 | limit: int = 250, 33 | format: str = "markdown", 34 | ) -> str: 35 | """ 36 | Fetches historical K-line (OHLCV) data for a Chinese A-share stock. 37 | 38 | Args: 39 | code: The stock code in Baostock format (e.g., 'sh.600000', 'sz.000001'). 40 | start_date: Start date in 'YYYY-MM-DD' format. 41 | end_date: End date in 'YYYY-MM-DD' format. 42 | frequency: Data frequency. Valid options (from Baostock): 43 | 'd': daily 44 | 'w': weekly 45 | 'm': monthly 46 | '5': 5 minutes 47 | '15': 15 minutes 48 | '30': 30 minutes 49 | '60': 60 minutes 50 | Defaults to 'd'. 51 | adjust_flag: Adjustment flag for price/volume. Valid options (from Baostock): 52 | '1': Forward adjusted (后复权) 53 | '2': Backward adjusted (前复权) 54 | '3': Non-adjusted (不复权) 55 | Defaults to '3'. 56 | fields: Optional list of specific data fields to retrieve (must be valid Baostock fields). 57 | If None or empty, default fields will be used (e.g., date, code, open, high, low, close, volume, amount, pctChg). 58 | limit: Max rows to return. Defaults to 250. 59 | format: Output format: 'markdown' | 'json' | 'csv'. Defaults to 'markdown'. 60 | 61 | Returns: 62 | A Markdown formatted string containing the K-line data table, or an error message. 63 | The table might be truncated if the result set is too large. 
64 | """ 65 | logger.info( 66 | f"Tool 'get_historical_k_data' called for {code} ({start_date}-{end_date}, freq={frequency}, adj={adjust_flag}, fields={fields})") 67 | try: 68 | # Validate frequency and adjust_flag if necessary (basic example) 69 | valid_freqs = ['d', 'w', 'm', '5', '15', '30', '60'] 70 | valid_adjusts = ['1', '2', '3'] 71 | if frequency not in valid_freqs: 72 | logger.warning(f"Invalid frequency requested: {frequency}") 73 | return f"Error: Invalid frequency '{frequency}'. Valid options are: {valid_freqs}" 74 | if adjust_flag not in valid_adjusts: 75 | logger.warning(f"Invalid adjust_flag requested: {adjust_flag}") 76 | return f"Error: Invalid adjust_flag '{adjust_flag}'. Valid options are: {valid_adjusts}" 77 | 78 | # Call the injected data source 79 | df = active_data_source.get_historical_k_data( 80 | code=code, 81 | start_date=start_date, 82 | end_date=end_date, 83 | frequency=frequency, 84 | adjust_flag=adjust_flag, 85 | fields=fields, 86 | ) 87 | # Format the result 88 | logger.info( 89 | f"Successfully retrieved K-data for {code}, formatting output.") 90 | meta = {"code": code, "start_date": start_date, "end_date": end_date, "frequency": frequency, "adjust_flag": adjust_flag} 91 | return format_table_output(df, format=format, max_rows=limit, meta=meta) 92 | 93 | except NoDataFoundError as e: 94 | logger.warning(f"NoDataFoundError for {code}: {e}") 95 | return f"Error: {e}" 96 | except LoginError as e: 97 | logger.error(f"LoginError for {code}: {e}") 98 | return f"Error: Could not connect to data source. {e}" 99 | except DataSourceError as e: 100 | logger.error(f"DataSourceError for {code}: {e}") 101 | return f"Error: An error occurred while fetching data. {e}" 102 | except ValueError as e: 103 | logger.warning(f"ValueError processing request for {code}: {e}") 104 | return f"Error: Invalid input parameter. 
{e}" 105 | except Exception as e: 106 | # Catch-all for unexpected errors 107 | logger.exception( 108 | f"Unexpected Exception processing get_historical_k_data for {code}: {e}") 109 | return f"Error: An unexpected error occurred: {e}" 110 | 111 | @app.tool() 112 | def get_stock_basic_info(code: str, fields: Optional[List[str]] = None, format: str = "markdown") -> str: 113 | """ 114 | Fetches basic information for a given Chinese A-share stock. 115 | 116 | Args: 117 | code: The stock code in Baostock format (e.g., 'sh.600000', 'sz.000001'). 118 | fields: Optional list to select specific columns from the available basic info 119 | (e.g., ['code', 'code_name', 'industry', 'listingDate']). 120 | If None or empty, returns all available basic info columns from Baostock. 121 | 122 | Returns: 123 | Basic stock information in the requested format. 124 | """ 125 | logger.info( 126 | f"Tool 'get_stock_basic_info' called for {code} (fields={fields})") 127 | try: 128 | # Call the injected data source 129 | # Pass fields along; BaostockDataSource implementation handles selection 130 | df = active_data_source.get_stock_basic_info( 131 | code=code, fields=fields) 132 | 133 | # Format the result (basic info usually small) 134 | logger.info( 135 | f"Successfully retrieved basic info for {code}, formatting output.") 136 | meta = {"code": code} 137 | return format_table_output(df, format=format, max_rows=df.shape[0] if df is not None else 0, meta=meta) 138 | 139 | except NoDataFoundError as e: 140 | logger.warning(f"NoDataFoundError for {code}: {e}") 141 | return f"Error: {e}" 142 | except LoginError as e: 143 | logger.error(f"LoginError for {code}: {e}") 144 | return f"Error: Could not connect to data source. {e}" 145 | except DataSourceError as e: 146 | logger.error(f"DataSourceError for {code}: {e}") 147 | return f"Error: An error occurred while fetching data. 
{e}" 148 | except ValueError as e: 149 | logger.warning(f"ValueError processing request for {code}: {e}") 150 | return f"Error: Invalid input parameter or requested field not available. {e}" 151 | except Exception as e: 152 | logger.exception( 153 | f"Unexpected Exception processing get_stock_basic_info for {code}: {e}") 154 | return f"Error: An unexpected error occurred: {e}" 155 | 156 | @app.tool() 157 | def get_dividend_data(code: str, year: str, year_type: str = "report", limit: int = 250, format: str = "markdown") -> str: 158 | """ 159 | Fetches dividend information for a given stock code and year. 160 | 161 | Args: 162 | code: The stock code in Baostock format (e.g., 'sh.600000', 'sz.000001'). 163 | year: The year to query (e.g., '2023'). 164 | year_type: Type of year. Valid options (from Baostock): 165 | 'report': Announcement year (预案公告年份) 166 | 'operate': Ex-dividend year (除权除息年份) 167 | Defaults to 'report'. 168 | 169 | Returns: 170 | Dividend records table. 171 | """ 172 | logger.info( 173 | f"Tool 'get_dividend_data' called for {code}, year={year}, year_type={year_type}") 174 | try: 175 | # Basic validation 176 | if year_type not in ['report', 'operate']: 177 | logger.warning(f"Invalid year_type requested: {year_type}") 178 | return f"Error: Invalid year_type '{year_type}'. Valid options are: 'report', 'operate'" 179 | if not year.isdigit() or len(year) != 4: 180 | logger.warning(f"Invalid year format requested: {year}") 181 | return f"Error: Invalid year '{year}'. Please provide a 4-digit year." 
182 | 183 | df = active_data_source.get_dividend_data( 184 | code=code, year=year, year_type=year_type) 185 | logger.info( 186 | f"Successfully retrieved dividend data for {code}, year {year}.") 187 | meta = {"code": code, "year": year, "year_type": year_type} 188 | return format_table_output(df, format=format, max_rows=limit, meta=meta) 189 | 190 | except NoDataFoundError as e: 191 | logger.warning(f"NoDataFoundError for {code}, year {year}: {e}") 192 | return f"Error: {e}" 193 | except LoginError as e: 194 | logger.error(f"LoginError for {code}: {e}") 195 | return f"Error: Could not connect to data source. {e}" 196 | except DataSourceError as e: 197 | logger.error(f"DataSourceError for {code}: {e}") 198 | return f"Error: An error occurred while fetching data. {e}" 199 | except ValueError as e: 200 | logger.warning(f"ValueError processing request for {code}: {e}") 201 | return f"Error: Invalid input parameter. {e}" 202 | except Exception as e: 203 | logger.exception( 204 | f"Unexpected Exception processing get_dividend_data for {code}: {e}") 205 | return f"Error: An unexpected error occurred: {e}" 206 | 207 | @app.tool() 208 | def get_adjust_factor_data(code: str, start_date: str, end_date: str, limit: int = 250, format: str = "markdown") -> str: 209 | """ 210 | Fetches adjustment factor data for a given stock code and date range. 211 | Uses Baostock's "涨跌幅复权算法" factors. Useful for calculating adjusted prices. 212 | 213 | Args: 214 | code: The stock code in Baostock format (e.g., 'sh.600000', 'sz.000001'). 215 | start_date: Start date in 'YYYY-MM-DD' format. 216 | end_date: End date in 'YYYY-MM-DD' format. 217 | 218 | Returns: 219 | Adjustment factors table. 
220 | """ 221 | logger.info( 222 | f"Tool 'get_adjust_factor_data' called for {code} ({start_date} to {end_date})") 223 | try: 224 | # Basic date validation could be added here if desired 225 | df = active_data_source.get_adjust_factor_data( 226 | code=code, start_date=start_date, end_date=end_date) 227 | logger.info( 228 | f"Successfully retrieved adjustment factor data for {code}.") 229 | meta = {"code": code, "start_date": start_date, "end_date": end_date} 230 | return format_table_output(df, format=format, max_rows=limit, meta=meta) 231 | 232 | except NoDataFoundError as e: 233 | logger.warning(f"NoDataFoundError for {code}: {e}") 234 | return f"Error: {e}" 235 | except LoginError as e: 236 | logger.error(f"LoginError for {code}: {e}") 237 | return f"Error: Could not connect to data source. {e}" 238 | except DataSourceError as e: 239 | logger.error(f"DataSourceError for {code}: {e}") 240 | return f"Error: An error occurred while fetching data. {e}" 241 | except ValueError as e: 242 | logger.warning(f"ValueError processing request for {code}: {e}") 243 | return f"Error: Invalid input parameter. 
{e}" 244 | except Exception as e: 245 | logger.exception( 246 | f"Unexpected Exception processing get_adjust_factor_data for {code}: {e}") 247 | return f"Error: An unexpected error occurred: {e}" 248 | ``` -------------------------------------------------------------------------------- /src/baostock_data_source.py: -------------------------------------------------------------------------------- ```python 1 | # Implementation of the FinancialDataSource interface using Baostock 2 | import baostock as bs 3 | import pandas as pd 4 | from typing import List, Optional 5 | import logging 6 | from .data_source_interface import FinancialDataSource, DataSourceError, NoDataFoundError, LoginError 7 | from .utils import baostock_login_context 8 | 9 | # Get a logger instance for this module 10 | logger = logging.getLogger(__name__) 11 | 12 | DEFAULT_K_FIELDS = [ 13 | "date", "code", "open", "high", "low", "close", "preclose", 14 | "volume", "amount", "adjustflag", "turn", "tradestatus", 15 | "pctChg", "peTTM", "pbMRQ", "psTTM", "pcfNcfTTM", "isST" 16 | ] 17 | 18 | DEFAULT_BASIC_FIELDS = [ 19 | "code", "tradeStatus", "code_name" 20 | # Add more default fields as needed, e.g., "industry", "listingDate" 21 | ] 22 | 23 | # Helper function to reduce repetition in financial data fetching 24 | 25 | 26 | def _fetch_financial_data( 27 | bs_query_func, 28 | data_type_name: str, 29 | code: str, 30 | year: str, 31 | quarter: int 32 | ) -> pd.DataFrame: 33 | logger.info( 34 | f"Fetching {data_type_name} data for {code}, year={year}, quarter={quarter}") 35 | try: 36 | with baostock_login_context(): 37 | # Assuming all these functions take code, year, quarter 38 | rs = bs_query_func(code=code, year=year, quarter=quarter) 39 | 40 | if rs.error_code != '0': 41 | logger.error( 42 | f"Baostock API error ({data_type_name}) for {code}: {rs.error_msg} (code: {rs.error_code})") 43 | if "no record found" in rs.error_msg.lower() or rs.error_code == '10002': 44 | raise NoDataFoundError( 45 | f"No 
{data_type_name} data found for {code}, {year}Q{quarter}. Baostock msg: {rs.error_msg}") 46 | else: 47 | raise DataSourceError( 48 | f"Baostock API error fetching {data_type_name} data: {rs.error_msg} (code: {rs.error_code})") 49 | 50 | data_list = [] 51 | while rs.next(): 52 | data_list.append(rs.get_row_data()) 53 | 54 | if not data_list: 55 | logger.warning( 56 | f"No {data_type_name} data found for {code}, {year}Q{quarter} (empty result set from Baostock).") 57 | raise NoDataFoundError( 58 | f"No {data_type_name} data found for {code}, {year}Q{quarter} (empty result set).") 59 | 60 | result_df = pd.DataFrame(data_list, columns=rs.fields) 61 | logger.info( 62 | f"Retrieved {len(result_df)} {data_type_name} records for {code}, {year}Q{quarter}.") 63 | return result_df 64 | 65 | except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e: 66 | logger.warning( 67 | f"Caught known error fetching {data_type_name} data for {code}: {type(e).__name__}") 68 | raise e 69 | except Exception as e: 70 | logger.exception( 71 | f"Unexpected error fetching {data_type_name} data for {code}: {e}") 72 | raise DataSourceError( 73 | f"Unexpected error fetching {data_type_name} data for {code}: {e}") 74 | 75 | # Helper function to reduce repetition for index constituent data fetching 76 | 77 | 78 | def _fetch_index_constituent_data( 79 | bs_query_func, 80 | index_name: str, 81 | date: Optional[str] = None 82 | ) -> pd.DataFrame: 83 | logger.info( 84 | f"Fetching {index_name} constituents for date={date or 'latest'}") 85 | try: 86 | with baostock_login_context(): 87 | # date is optional, defaults to latest 88 | rs = bs_query_func(date=date) 89 | 90 | if rs.error_code != '0': 91 | logger.error( 92 | f"Baostock API error ({index_name} Constituents) for date {date}: {rs.error_msg} (code: {rs.error_code})") 93 | if "no record found" in rs.error_msg.lower() or rs.error_code == '10002': 94 | raise NoDataFoundError( 95 | f"No {index_name} constituent data found for date {date}. 
Baostock msg: {rs.error_msg}") 96 | else: 97 | raise DataSourceError( 98 | f"Baostock API error fetching {index_name} constituents: {rs.error_msg} (code: {rs.error_code})") 99 | 100 | data_list = [] 101 | while rs.next(): 102 | data_list.append(rs.get_row_data()) 103 | 104 | if not data_list: 105 | logger.warning( 106 | f"No {index_name} constituent data found for date {date} (empty result set).") 107 | raise NoDataFoundError( 108 | f"No {index_name} constituent data found for date {date} (empty result set).") 109 | 110 | result_df = pd.DataFrame(data_list, columns=rs.fields) 111 | logger.info( 112 | f"Retrieved {len(result_df)} {index_name} constituents for date {date or 'latest'}.") 113 | return result_df 114 | 115 | except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e: 116 | logger.warning( 117 | f"Caught known error fetching {index_name} constituents for date {date}: {type(e).__name__}") 118 | raise e 119 | except Exception as e: 120 | logger.exception( 121 | f"Unexpected error fetching {index_name} constituents for date {date}: {e}") 122 | raise DataSourceError( 123 | f"Unexpected error fetching {index_name} constituents for date {date}: {e}") 124 | 125 | # Helper function to reduce repetition for macroeconomic data fetching 126 | 127 | 128 | def _fetch_macro_data( 129 | bs_query_func, 130 | data_type_name: str, 131 | start_date: Optional[str] = None, 132 | end_date: Optional[str] = None, 133 | **kwargs # For extra params like yearType 134 | ) -> pd.DataFrame: 135 | date_range_log = f"from {start_date or 'default'} to {end_date or 'default'}" 136 | kwargs_log = f", extra_args={kwargs}" if kwargs else "" 137 | logger.info(f"Fetching {data_type_name} data {date_range_log}{kwargs_log}") 138 | try: 139 | with baostock_login_context(): 140 | rs = bs_query_func(start_date=start_date, 141 | end_date=end_date, **kwargs) 142 | 143 | if rs.error_code != '0': 144 | logger.error( 145 | f"Baostock API error ({data_type_name}): {rs.error_msg} (code: 
{rs.error_code})") 146 | if "no record found" in rs.error_msg.lower() or rs.error_code == '10002': 147 | raise NoDataFoundError( 148 | f"No {data_type_name} data found for the specified criteria. Baostock msg: {rs.error_msg}") 149 | else: 150 | raise DataSourceError( 151 | f"Baostock API error fetching {data_type_name} data: {rs.error_msg} (code: {rs.error_code})") 152 | 153 | data_list = [] 154 | while rs.next(): 155 | data_list.append(rs.get_row_data()) 156 | 157 | if not data_list: 158 | logger.warning( 159 | f"No {data_type_name} data found for the specified criteria (empty result set).") 160 | raise NoDataFoundError( 161 | f"No {data_type_name} data found for the specified criteria (empty result set).") 162 | 163 | result_df = pd.DataFrame(data_list, columns=rs.fields) 164 | logger.info( 165 | f"Retrieved {len(result_df)} {data_type_name} records.") 166 | return result_df 167 | 168 | except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e: 169 | logger.warning( 170 | f"Caught known error fetching {data_type_name} data: {type(e).__name__}") 171 | raise e 172 | except Exception as e: 173 | logger.exception( 174 | f"Unexpected error fetching {data_type_name} data: {e}") 175 | raise DataSourceError( 176 | f"Unexpected error fetching {data_type_name} data: {e}") 177 | 178 | 179 | class BaostockDataSource(FinancialDataSource): 180 | """ 181 | Concrete implementation of FinancialDataSource using the Baostock library. 
182 | """ 183 | 184 | def _format_fields(self, fields: Optional[List[str]], default_fields: List[str]) -> str: 185 | """Formats the list of fields into a comma-separated string for Baostock.""" 186 | if fields is None or not fields: 187 | logger.debug( 188 | f"No specific fields requested, using defaults: {default_fields}") 189 | return ",".join(default_fields) 190 | # Basic validation: ensure requested fields are strings 191 | if not all(isinstance(f, str) for f in fields): 192 | raise ValueError("All items in the fields list must be strings.") 193 | logger.debug(f"Using requested fields: {fields}") 194 | return ",".join(fields) 195 | 196 | def get_historical_k_data( 197 | self, 198 | code: str, 199 | start_date: str, 200 | end_date: str, 201 | frequency: str = "d", 202 | adjust_flag: str = "3", 203 | fields: Optional[List[str]] = None, 204 | ) -> pd.DataFrame: 205 | """Fetches historical K-line data using Baostock.""" 206 | logger.info( 207 | f"Fetching K-data for {code} ({start_date} to {end_date}), freq={frequency}, adjust={adjust_flag}") 208 | try: 209 | formatted_fields = self._format_fields(fields, DEFAULT_K_FIELDS) 210 | logger.debug( 211 | f"Requesting fields from Baostock: {formatted_fields}") 212 | 213 | with baostock_login_context(): 214 | rs = bs.query_history_k_data_plus( 215 | code, 216 | formatted_fields, 217 | start_date=start_date, 218 | end_date=end_date, 219 | frequency=frequency, 220 | adjustflag=adjust_flag 221 | ) 222 | 223 | if rs.error_code != '0': 224 | logger.error( 225 | f"Baostock API error (K-data) for {code}: {rs.error_msg} (code: {rs.error_code})") 226 | # Check common error codes, e.g., for no data 227 | if "no record found" in rs.error_msg.lower() or rs.error_code == '10002': # Example error code 228 | raise NoDataFoundError( 229 | f"No historical data found for {code} in the specified range. 
Baostock msg: {rs.error_msg}") 230 | else: 231 | raise DataSourceError( 232 | f"Baostock API error fetching K-data: {rs.error_msg} (code: {rs.error_code})") 233 | 234 | data_list = [] 235 | while rs.next(): 236 | data_list.append(rs.get_row_data()) 237 | 238 | if not data_list: 239 | logger.warning( 240 | f"No historical data found for {code} in range (empty result set from Baostock).") 241 | raise NoDataFoundError( 242 | f"No historical data found for {code} in the specified range (empty result set).") 243 | 244 | # Crucial: Use rs.fields for column names 245 | result_df = pd.DataFrame(data_list, columns=rs.fields) 246 | logger.info(f"Retrieved {len(result_df)} records for {code}.") 247 | return result_df 248 | 249 | except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e: 250 | # Re-raise known errors 251 | logger.warning( 252 | f"Caught known error fetching K-data for {code}: {type(e).__name__}") 253 | raise e 254 | except Exception as e: 255 | # Wrap unexpected errors 256 | # Use logger.exception to include traceback 257 | logger.exception( 258 | f"Unexpected error fetching K-data for {code}: {e}") 259 | raise DataSourceError( 260 | f"Unexpected error fetching K-data for {code}: {e}") 261 | 262 | def get_stock_basic_info(self, code: str, fields: Optional[List[str]] = None) -> pd.DataFrame: 263 | """Fetches basic stock information using Baostock.""" 264 | logger.info(f"Fetching basic info for {code}") 265 | try: 266 | # Note: query_stock_basic doesn't seem to have a fields parameter in docs, 267 | # but we keep the signature consistent. It returns a fixed set. 268 | # We will use the `fields` argument post-query to select columns if needed. 269 | logger.debug( 270 | f"Requesting basic info for {code}. 
Optional fields requested: {fields}") 271 | 272 | with baostock_login_context(): 273 | # Example: Fetch basic info; adjust API call if needed based on baostock docs 274 | # rs = bs.query_stock_basic(code=code, code_name=code_name) # If supporting name lookup 275 | rs = bs.query_stock_basic(code=code) 276 | 277 | if rs.error_code != '0': 278 | logger.error( 279 | f"Baostock API error (Basic Info) for {code}: {rs.error_msg} (code: {rs.error_code})") 280 | if "no record found" in rs.error_msg.lower() or rs.error_code == '10002': 281 | raise NoDataFoundError( 282 | f"No basic info found for {code}. Baostock msg: {rs.error_msg}") 283 | else: 284 | raise DataSourceError( 285 | f"Baostock API error fetching basic info: {rs.error_msg} (code: {rs.error_code})") 286 | 287 | data_list = [] 288 | while rs.next(): 289 | data_list.append(rs.get_row_data()) 290 | 291 | if not data_list: 292 | logger.warning( 293 | f"No basic info found for {code} (empty result set from Baostock).") 294 | raise NoDataFoundError( 295 | f"No basic info found for {code} (empty result set).") 296 | 297 | # Crucial: Use rs.fields for column names 298 | result_df = pd.DataFrame(data_list, columns=rs.fields) 299 | logger.info( 300 | f"Retrieved basic info for {code}. 
Columns: {result_df.columns.tolist()}") 301 | 302 | # Optional: Select subset of columns if `fields` argument was provided 303 | if fields: 304 | available_cols = [ 305 | col for col in fields if col in result_df.columns] 306 | if not available_cols: 307 | raise ValueError( 308 | f"None of the requested fields {fields} are available in the basic info result.") 309 | logger.debug( 310 | f"Selecting columns: {available_cols} from basic info for {code}") 311 | result_df = result_df[available_cols] 312 | 313 | return result_df 314 | 315 | except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e: 316 | logger.warning( 317 | f"Caught known error fetching basic info for {code}: {type(e).__name__}") 318 | raise e 319 | except Exception as e: 320 | logger.exception( 321 | f"Unexpected error fetching basic info for {code}: {e}") 322 | raise DataSourceError( 323 | f"Unexpected error fetching basic info for {code}: {e}") 324 | 325 | def get_dividend_data(self, code: str, year: str, year_type: str = "report") -> pd.DataFrame: 326 | """Fetches dividend information using Baostock.""" 327 | logger.info( 328 | f"Fetching dividend data for {code}, year={year}, year_type={year_type}") 329 | try: 330 | with baostock_login_context(): 331 | rs = bs.query_dividend_data( 332 | code=code, year=year, yearType=year_type) 333 | 334 | if rs.error_code != '0': 335 | logger.error( 336 | f"Baostock API error (Dividend) for {code}: {rs.error_msg} (code: {rs.error_code})") 337 | if "no record found" in rs.error_msg.lower() or rs.error_code == '10002': 338 | raise NoDataFoundError( 339 | f"No dividend data found for {code} and year {year}. 
Baostock msg: {rs.error_msg}") 340 | else: 341 | raise DataSourceError( 342 | f"Baostock API error fetching dividend data: {rs.error_msg} (code: {rs.error_code})") 343 | 344 | data_list = [] 345 | while rs.next(): 346 | data_list.append(rs.get_row_data()) 347 | 348 | if not data_list: 349 | logger.warning( 350 | f"No dividend data found for {code}, year {year} (empty result set from Baostock).") 351 | raise NoDataFoundError( 352 | f"No dividend data found for {code}, year {year} (empty result set).") 353 | 354 | result_df = pd.DataFrame(data_list, columns=rs.fields) 355 | logger.info( 356 | f"Retrieved {len(result_df)} dividend records for {code}, year {year}.") 357 | return result_df 358 | 359 | except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e: 360 | logger.warning( 361 | f"Caught known error fetching dividend data for {code}: {type(e).__name__}") 362 | raise e 363 | except Exception as e: 364 | logger.exception( 365 | f"Unexpected error fetching dividend data for {code}: {e}") 366 | raise DataSourceError( 367 | f"Unexpected error fetching dividend data for {code}: {e}") 368 | 369 | def get_adjust_factor_data(self, code: str, start_date: str, end_date: str) -> pd.DataFrame: 370 | """Fetches adjustment factor data using Baostock.""" 371 | logger.info( 372 | f"Fetching adjustment factor data for {code} ({start_date} to {end_date})") 373 | try: 374 | with baostock_login_context(): 375 | rs = bs.query_adjust_factor( 376 | code=code, start_date=start_date, end_date=end_date) 377 | 378 | if rs.error_code != '0': 379 | logger.error( 380 | f"Baostock API error (Adjust Factor) for {code}: {rs.error_msg} (code: {rs.error_code})") 381 | if "no record found" in rs.error_msg.lower() or rs.error_code == '10002': 382 | raise NoDataFoundError( 383 | f"No adjustment factor data found for {code} in the specified range. 
Baostock msg: {rs.error_msg}") 384 | else: 385 | raise DataSourceError( 386 | f"Baostock API error fetching adjust factor data: {rs.error_msg} (code: {rs.error_code})") 387 | 388 | data_list = [] 389 | while rs.next(): 390 | data_list.append(rs.get_row_data()) 391 | 392 | if not data_list: 393 | logger.warning( 394 | f"No adjustment factor data found for {code} in range (empty result set from Baostock).") 395 | raise NoDataFoundError( 396 | f"No adjustment factor data found for {code} in the specified range (empty result set).") 397 | 398 | result_df = pd.DataFrame(data_list, columns=rs.fields) 399 | logger.info( 400 | f"Retrieved {len(result_df)} adjustment factor records for {code}.") 401 | return result_df 402 | 403 | except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e: 404 | logger.warning( 405 | f"Caught known error fetching adjust factor data for {code}: {type(e).__name__}") 406 | raise e 407 | except Exception as e: 408 | logger.exception( 409 | f"Unexpected error fetching adjust factor data for {code}: {e}") 410 | raise DataSourceError( 411 | f"Unexpected error fetching adjust factor data for {code}: {e}") 412 | 413 | def get_profit_data(self, code: str, year: str, quarter: int) -> pd.DataFrame: 414 | """Fetches quarterly profitability data using Baostock.""" 415 | return _fetch_financial_data(bs.query_profit_data, "Profitability", code, year, quarter) 416 | 417 | def get_operation_data(self, code: str, year: str, quarter: int) -> pd.DataFrame: 418 | """Fetches quarterly operation capability data using Baostock.""" 419 | return _fetch_financial_data(bs.query_operation_data, "Operation Capability", code, year, quarter) 420 | 421 | def get_growth_data(self, code: str, year: str, quarter: int) -> pd.DataFrame: 422 | """Fetches quarterly growth capability data using Baostock.""" 423 | return _fetch_financial_data(bs.query_growth_data, "Growth Capability", code, year, quarter) 424 | 425 | def get_balance_data(self, code: str, year: str, 
quarter: int) -> pd.DataFrame: 426 | """Fetches quarterly balance sheet data (solvency) using Baostock.""" 427 | return _fetch_financial_data(bs.query_balance_data, "Balance Sheet", code, year, quarter) 428 | 429 | def get_cash_flow_data(self, code: str, year: str, quarter: int) -> pd.DataFrame: 430 | """Fetches quarterly cash flow data using Baostock.""" 431 | return _fetch_financial_data(bs.query_cash_flow_data, "Cash Flow", code, year, quarter) 432 | 433 | def get_dupont_data(self, code: str, year: str, quarter: int) -> pd.DataFrame: 434 | """Fetches quarterly DuPont analysis data using Baostock.""" 435 | return _fetch_financial_data(bs.query_dupont_data, "DuPont Analysis", code, year, quarter) 436 | 437 | def get_performance_express_report(self, code: str, start_date: str, end_date: str) -> pd.DataFrame: 438 | """Fetches performance express reports (业绩快报) using Baostock.""" 439 | logger.info( 440 | f"Fetching Performance Express Report for {code} ({start_date} to {end_date})") 441 | try: 442 | with baostock_login_context(): 443 | rs = bs.query_performance_express_report( 444 | code=code, start_date=start_date, end_date=end_date) 445 | 446 | if rs.error_code != '0': 447 | logger.error( 448 | f"Baostock API error (Perf Express) for {code}: {rs.error_msg} (code: {rs.error_code})") 449 | if "no record found" in rs.error_msg.lower() or rs.error_code == '10002': 450 | raise NoDataFoundError( 451 | f"No performance express report found for {code} in range {start_date}-{end_date}. 
Baostock msg: {rs.error_msg}") 452 | else: 453 | raise DataSourceError( 454 | f"Baostock API error fetching performance express report: {rs.error_msg} (code: {rs.error_code})") 455 | 456 | data_list = [] 457 | while rs.next(): 458 | data_list.append(rs.get_row_data()) 459 | 460 | if not data_list: 461 | logger.warning( 462 | f"No performance express report found for {code} in range {start_date}-{end_date} (empty result set).") 463 | raise NoDataFoundError( 464 | f"No performance express report found for {code} in range {start_date}-{end_date} (empty result set).") 465 | 466 | result_df = pd.DataFrame(data_list, columns=rs.fields) 467 | logger.info( 468 | f"Retrieved {len(result_df)} performance express report records for {code}.") 469 | return result_df 470 | 471 | except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e: 472 | logger.warning( 473 | f"Caught known error fetching performance express report for {code}: {type(e).__name__}") 474 | raise e 475 | except Exception as e: 476 | logger.exception( 477 | f"Unexpected error fetching performance express report for {code}: {e}") 478 | raise DataSourceError( 479 | f"Unexpected error fetching performance express report for {code}: {e}") 480 | 481 | def get_forecast_report(self, code: str, start_date: str, end_date: str) -> pd.DataFrame: 482 | """Fetches performance forecast reports (业绩预告) using Baostock.""" 483 | logger.info( 484 | f"Fetching Performance Forecast Report for {code} ({start_date} to {end_date})") 485 | try: 486 | with baostock_login_context(): 487 | rs = bs.query_forecast_report( 488 | code=code, start_date=start_date, end_date=end_date) 489 | # Note: Baostock docs mention pagination for this, but the Python API doesn't seem to expose it directly. 490 | # We fetch all available pages in the loop below. 
491 | 492 | if rs.error_code != '0': 493 | logger.error( 494 | f"Baostock API error (Forecast) for {code}: {rs.error_msg} (code: {rs.error_code})") 495 | if "no record found" in rs.error_msg.lower() or rs.error_code == '10002': 496 | raise NoDataFoundError( 497 | f"No performance forecast report found for {code} in range {start_date}-{end_date}. Baostock msg: {rs.error_msg}") 498 | else: 499 | raise DataSourceError( 500 | f"Baostock API error fetching performance forecast report: {rs.error_msg} (code: {rs.error_code})") 501 | 502 | data_list = [] 503 | while rs.next(): # Loop should handle pagination implicitly if rs manages it 504 | data_list.append(rs.get_row_data()) 505 | 506 | if not data_list: 507 | logger.warning( 508 | f"No performance forecast report found for {code} in range {start_date}-{end_date} (empty result set).") 509 | raise NoDataFoundError( 510 | f"No performance forecast report found for {code} in range {start_date}-{end_date} (empty result set).") 511 | 512 | result_df = pd.DataFrame(data_list, columns=rs.fields) 513 | logger.info( 514 | f"Retrieved {len(result_df)} performance forecast report records for {code}.") 515 | return result_df 516 | 517 | except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e: 518 | logger.warning( 519 | f"Caught known error fetching performance forecast report for {code}: {type(e).__name__}") 520 | raise e 521 | except Exception as e: 522 | logger.exception( 523 | f"Unexpected error fetching performance forecast report for {code}: {e}") 524 | raise DataSourceError( 525 | f"Unexpected error fetching performance forecast report for {code}: {e}") 526 | 527 | def get_stock_industry(self, code: Optional[str] = None, date: Optional[str] = None) -> pd.DataFrame: 528 | """Fetches industry classification using Baostock.""" 529 | log_msg = f"Fetching industry data for code={code or 'all'}, date={date or 'latest'}" 530 | logger.info(log_msg) 531 | try: 532 | with baostock_login_context(): 533 | rs = 
bs.query_stock_industry(code=code, date=date) 534 | 535 | if rs.error_code != '0': 536 | logger.error( 537 | f"Baostock API error (Industry) for {code}, {date}: {rs.error_msg} (code: {rs.error_code})") 538 | if "no record found" in rs.error_msg.lower() or rs.error_code == '10002': 539 | raise NoDataFoundError( 540 | f"No industry data found for {code}, {date}. Baostock msg: {rs.error_msg}") 541 | else: 542 | raise DataSourceError( 543 | f"Baostock API error fetching industry data: {rs.error_msg} (code: {rs.error_code})") 544 | 545 | data_list = [] 546 | while rs.next(): 547 | data_list.append(rs.get_row_data()) 548 | 549 | if not data_list: 550 | logger.warning( 551 | f"No industry data found for {code}, {date} (empty result set).") 552 | raise NoDataFoundError( 553 | f"No industry data found for {code}, {date} (empty result set).") 554 | 555 | result_df = pd.DataFrame(data_list, columns=rs.fields) 556 | logger.info( 557 | f"Retrieved {len(result_df)} industry records for {code or 'all'}, {date or 'latest'}.") 558 | return result_df 559 | 560 | except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e: 561 | logger.warning( 562 | f"Caught known error fetching industry data for {code}, {date}: {type(e).__name__}") 563 | raise e 564 | except Exception as e: 565 | logger.exception( 566 | f"Unexpected error fetching industry data for {code}, {date}: {e}") 567 | raise DataSourceError( 568 | f"Unexpected error fetching industry data for {code}, {date}: {e}") 569 | 570 | def get_sz50_stocks(self, date: Optional[str] = None) -> pd.DataFrame: 571 | """Fetches SZSE 50 index constituents using Baostock.""" 572 | return _fetch_index_constituent_data(bs.query_sz50_stocks, "SZSE 50", date) 573 | 574 | def get_hs300_stocks(self, date: Optional[str] = None) -> pd.DataFrame: 575 | """Fetches CSI 300 index constituents using Baostock.""" 576 | return _fetch_index_constituent_data(bs.query_hs300_stocks, "CSI 300", date) 577 | 578 | def get_zz500_stocks(self, date: 
Optional[str] = None) -> pd.DataFrame: 579 | """Fetches CSI 500 index constituents using Baostock.""" 580 | return _fetch_index_constituent_data(bs.query_zz500_stocks, "CSI 500", date) 581 | 582 | def get_trade_dates(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame: 583 | """Fetches trading dates using Baostock.""" 584 | logger.info( 585 | f"Fetching trade dates from {start_date or 'default'} to {end_date or 'default'}") 586 | try: 587 | with baostock_login_context(): # Login might not be strictly needed for this, but keeping consistent 588 | rs = bs.query_trade_dates( 589 | start_date=start_date, end_date=end_date) 590 | 591 | if rs.error_code != '0': 592 | logger.error( 593 | f"Baostock API error (Trade Dates): {rs.error_msg} (code: {rs.error_code})") 594 | # Unlikely to have 'no record found' for dates, but handle API errors 595 | raise DataSourceError( 596 | f"Baostock API error fetching trade dates: {rs.error_msg} (code: {rs.error_code})") 597 | 598 | data_list = [] 599 | while rs.next(): 600 | data_list.append(rs.get_row_data()) 601 | 602 | if not data_list: 603 | # This case should ideally not happen if the API returns a valid range 604 | logger.warning( 605 | f"No trade dates returned for range {start_date}-{end_date} (empty result set).") 606 | raise NoDataFoundError( 607 | f"No trade dates found for range {start_date}-{end_date} (empty result set).") 608 | 609 | result_df = pd.DataFrame(data_list, columns=rs.fields) 610 | logger.info(f"Retrieved {len(result_df)} trade date records.") 611 | return result_df 612 | 613 | except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e: 614 | logger.warning( 615 | f"Caught known error fetching trade dates: {type(e).__name__}") 616 | raise e 617 | except Exception as e: 618 | logger.exception(f"Unexpected error fetching trade dates: {e}") 619 | raise DataSourceError( 620 | f"Unexpected error fetching trade dates: {e}") 621 | 622 | def get_all_stock(self, date: 
Optional[str] = None) -> pd.DataFrame: 623 | """Fetches all stock list for a given date using Baostock.""" 624 | logger.info(f"Fetching all stock list for date={date or 'default'}") 625 | try: 626 | with baostock_login_context(): 627 | rs = bs.query_all_stock(day=date) 628 | 629 | if rs.error_code != '0': 630 | logger.error( 631 | f"Baostock API error (All Stock) for date {date}: {rs.error_msg} (code: {rs.error_code})") 632 | if "no record found" in rs.error_msg.lower() or rs.error_code == '10002': # Check if this applies 633 | raise NoDataFoundError( 634 | f"No stock data found for date {date}. Baostock msg: {rs.error_msg}") 635 | else: 636 | raise DataSourceError( 637 | f"Baostock API error fetching all stock list: {rs.error_msg} (code: {rs.error_code})") 638 | 639 | data_list = [] 640 | while rs.next(): 641 | data_list.append(rs.get_row_data()) 642 | 643 | if not data_list: 644 | logger.warning( 645 | f"No stock list returned for date {date} (empty result set).") 646 | raise NoDataFoundError( 647 | f"No stock list found for date {date} (empty result set).") 648 | 649 | result_df = pd.DataFrame(data_list, columns=rs.fields) 650 | logger.info( 651 | f"Retrieved {len(result_df)} stock records for date {date or 'default'}.") 652 | return result_df 653 | 654 | except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e: 655 | logger.warning( 656 | f"Caught known error fetching all stock list for date {date}: {type(e).__name__}") 657 | raise e 658 | except Exception as e: 659 | logger.exception( 660 | f"Unexpected error fetching all stock list for date {date}: {e}") 661 | raise DataSourceError( 662 | f"Unexpected error fetching all stock list for date {date}: {e}") 663 | 664 | def get_deposit_rate_data(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame: 665 | """Fetches benchmark deposit rates using Baostock.""" 666 | return _fetch_macro_data(bs.query_deposit_rate_data, "Deposit Rate", start_date, end_date) 667 | 668 | 
def get_loan_rate_data(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame: 669 | """Fetches benchmark loan rate data for the optional start_date/end_date range using Baostock, via the shared _fetch_macro_data helper.""" 670 | return _fetch_macro_data(bs.query_loan_rate_data, "Loan Rate", start_date, end_date) 671 | 672 | def get_required_reserve_ratio_data(self, start_date: Optional[str] = None, end_date: Optional[str] = None, year_type: str = '0') -> pd.DataFrame: 673 | """Fetches required reserve ratio data using Baostock. year_type (default '0') is handed to _fetch_macro_data as the yearType keyword; presumably it is forwarded to Baostock's yearType argument - see the Baostock docs for accepted values.""" 674 | # NOTE: yearType is the one extra Baostock-specific argument here, passed through _fetch_macro_data as a keyword argument 675 | return _fetch_macro_data(bs.query_required_reserve_ratio_data, "Required Reserve Ratio", start_date, end_date, yearType=year_type) 676 | 677 | def get_money_supply_data_month(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame: 678 | """Fetches monthly money supply data (M0, M1, M2) using Baostock; start_date/end_date are expected as YYYY-MM strings.""" 679 | # Baostock expects YYYY-MM format for dates here 680 | return _fetch_macro_data(bs.query_money_supply_data_month, "Monthly Money Supply", start_date, end_date) 681 | 682 | def get_money_supply_data_year(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame: 683 | """Fetches yearly money supply data (year-end M0, M1, M2 balances) using Baostock; start_date/end_date are expected as YYYY strings.""" 684 | # Baostock expects YYYY format for dates here 685 | return _fetch_macro_data(bs.query_money_supply_data_year, "Yearly Money Supply", start_date, end_date) 686 | 687 | # Note: SHIBOR is not available in the Baostock API bindings currently used, so it is not implemented. 688 | ```