# Directory Structure
```
├── .gitignore
├── LICENSE
├── mcp_server.py
├── pyproject.toml
├── README.md
├── resource
│ └── img
│ ├── ali.png
│ ├── gzh_code.jpg
│ ├── img_1.png
│ ├── img_2.png
│ └── planet.jpg
├── src
│ ├── __init__.py
│ ├── baostock_data_source.py
│ ├── data_source_interface.py
│ ├── formatting
│ │ ├── __init__.py
│ │ └── markdown_formatter.py
│ ├── services
│ │ ├── __init__.py
│ │ ├── tool_runner.py
│ │ └── validation.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── analysis.py
│ │ ├── date_utils.py
│ │ ├── financial_reports.py
│ │ ├── helpers.py
│ │ ├── indices.py
│ │ ├── macroeconomic.py
│ │ ├── market_overview.py
│ │ └── stock_market.py
│ ├── use_cases
│ │ ├── __init__.py
│ │ ├── analysis.py
│ │ ├── date_utils.py
│ │ ├── financial_reports.py
│ │ ├── helpers.py
│ │ ├── indices.py
│ │ ├── macroeconomic.py
│ │ ├── market_overview.py
│ │ └── stock_market.py
│ └── utils.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
```
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 | AGENTS.md
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # Pipfile.lock
91 |
92 | # PEP 582; used by PDM, PEP 582 proposal
93 | __pypackages__/
94 |
95 | # Celery stuff
96 | celerybeat-schedule
97 | celerybeat.pid
98 |
99 | # SageMath parsed files
100 | *.sage.py
101 |
102 | # Environments
103 | .env
104 | .venv
105 | env/
106 | venv/
107 | ENV/
108 | env.bak/
109 | venv.bak/
110 |
111 | # Spyder project settings
112 | .spyderproject
113 | .spyproject
114 |
115 | # Rope project settings
116 | .ropeproject
117 |
118 | # mkdocs documentation
119 | /site
120 |
121 | # mypy
122 | .mypy_cache/
123 | .dmypy.json
124 | dmypy.json
125 |
126 | # Pyre type checker
127 | .pyre/
128 |
129 | # pytype static type analyzer
130 | .pytype/
131 |
132 | # Cython debug symbols
133 | cython_debug/
134 |
135 | # VS Code settings
136 | .vscode/
137 |
138 | docs/
139 |
140 | # Local tests (do not commit real-data logs or scripts)
141 | test/
142 |
143 |
144 |
145 |
146 |
147 |
```
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
```python
1 | # This file makes src a Python package
2 |
```
--------------------------------------------------------------------------------
/src/tools/__init__.py:
--------------------------------------------------------------------------------
```python
1 | # Initialization file for tools package
2 |
```
--------------------------------------------------------------------------------
/src/formatting/__init__.py:
--------------------------------------------------------------------------------
```python
1 | # Initialization file for formatting package
2 |
```
--------------------------------------------------------------------------------
/src/use_cases/__init__.py:
--------------------------------------------------------------------------------
```python
1 | """Use case layer to keep tools thin and consistent."""
2 |
```
--------------------------------------------------------------------------------
/src/services/__init__.py:
--------------------------------------------------------------------------------
```python
1 | """Shared services for validation and execution helpers."""
2 |
```
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
```toml
1 | [project]
2 | name = "a-share-mcp"
3 | version = "0.1.0"
4 | description = "Add your description here"
5 | readme = "README.md"
6 | requires-python = ">=3.12"
7 | dependencies = [
8 | "baostock>=0.8.9",
9 | "httpx>=0.28.1",
10 | "mcp[cli]>=1.2.0",
11 | "pandas>=2.2.3",
12 | "annotated-types>=0.7.0",
13 | "anyio>=4.9.0",
14 | "certifi>=2025.4.26",
15 | "click>=8.1.8",
16 | "colorama>=0.4.6",
17 | "h11>=0.16.0",
18 | "httpcore>=1.0.9",
19 | "httpx-sse>=0.4.0",
20 | "idna>=3.10",
21 | "markdown-it-py>=3.0.0",
22 | "mdurl>=0.1.2",
23 | "numpy>=2.2.5",
24 | "pydantic>=2.11.3",
25 | "pydantic-core>=2.33.1",
26 | "pydantic-settings>=2.9.1",
27 | "pygments>=2.19.1",
28 | "python-dateutil>=2.9.0",
29 | "python-dotenv>=1.1.0",
30 | "pytz>=2025.2",
31 | "rich>=14.0.0",
32 | "shellingham>=1.5.4",
33 | "six>=1.17.0",
34 | "sniffio>=1.3.1",
35 | "sse-starlette>=2.3.3",
36 | "starlette>=0.46.2",
37 | "tabulate>=0.9.0",
38 | "typer>=0.15.3",
39 | "typing-extensions>=4.13.2",
40 | "typing-inspection>=0.4.0",
41 | "tzdata>=2025.2",
42 | "uvicorn>=0.34.2",
43 | ]
44 |
45 | [tool.uv]
46 | dev-dependencies = [
47 | "pytest>=8.3.0",
48 | ]
49 |
```
--------------------------------------------------------------------------------
/src/tools/analysis.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Analysis tools for MCP server.
3 | Delegates heavy lifting to use-case layer.
4 | """
5 | import logging
6 |
7 | from mcp.server.fastmcp import FastMCP
8 | from src.data_source_interface import FinancialDataSource
9 | from src.services.tool_runner import run_tool_with_handling
10 | from src.use_cases.analysis import build_stock_analysis_report
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
def register_analysis_tools(app: FastMCP, active_data_source: FinancialDataSource):
    """Attach the stock-analysis tool to ``app``, bound to ``active_data_source``."""

    @app.tool()
    def get_stock_analysis(code: str, analysis_type: str = "fundamental") -> str:
        """
        Provide a data-driven stock analysis report; this is not investment advice.

        Args:
            code: Stock code, e.g. 'sh.600000'
            analysis_type: 'fundamental'|'technical'|'comprehensive'
        """
        logger.info("Tool 'get_stock_analysis' called for %s, type=%s", code, analysis_type)

        def _build_report() -> str:
            # Deferred so that any exception is raised inside the shared runner.
            return build_stock_analysis_report(
                active_data_source,
                code=code,
                analysis_type=analysis_type,
            )

        return run_tool_with_handling(
            _build_report,
            context=f"get_stock_analysis:{code}:{analysis_type}",
        )
32 |
```
--------------------------------------------------------------------------------
/src/use_cases/helpers.py:
--------------------------------------------------------------------------------
```python
1 | """Helper use cases for normalization utilities."""
2 | import re
3 |
4 | from src.services.validation import validate_non_empty_str
5 |
6 |
def normalize_stock_code_logic(code: str) -> str:
    """Convert a stock code into the ``<exchange>.<six digits>`` form.

    Accepted spellings are an exchange prefix (``sh600000``, ``SZ.000001``),
    an exchange suffix (``000001.SZ``), or six bare digits. Bare digits are
    assigned to ``sh`` when they start with ``6`` and to ``sz`` otherwise.

    Raises:
        ValueError: If ``code`` is empty or matches none of the accepted forms.
    """
    validate_non_empty_str(code, "code")
    text = code.strip()

    prefixed = re.fullmatch(r"(?i)(sh|sz)[.]?(\d{6})", text)
    if prefixed is not None:
        return f"{prefixed.group(1).lower()}.{prefixed.group(2)}"

    suffixed = re.fullmatch(r"(\d{6})[.]?(?i:(sh|sz))", text)
    if suffixed is not None:
        return f"{suffixed.group(2).lower()}.{suffixed.group(1)}"

    bare = re.fullmatch(r"(\d{6})", text)
    if bare is None:
        raise ValueError("Unsupported code format. Examples: 'sh.600000', '600000', '000001.SZ'.")

    digits = bare.group(1)
    exchange = "sh" if digits.startswith("6") else "sz"
    return f"{exchange}.{digits}"
28 |
29 |
def normalize_index_code_logic(code: str) -> str:
    """Map a known index alias (case-insensitive) to its ``sh.``-prefixed code.

    Raises:
        ValueError: If ``code`` is empty or is not one of the known aliases.
    """
    validate_non_empty_str(code, "code")
    alias = code.strip().upper()

    # Each normalized code paired with the aliases that resolve to it.
    alias_table = (
        ("sh.000300", {"000300", "CSI300", "HS300"}),
        ("sh.000016", {"000016", "SSE50", "SZ50"}),
        ("sh.000905", {"000905", "ZZ500", "CSI500"}),
    )
    for normalized, aliases in alias_table:
        if alias in aliases:
            return normalized

    raise ValueError("Unsupported index code. Examples: 000300/CSI300/HS300, 000016, 000905.")
40 |
```
--------------------------------------------------------------------------------
/src/services/tool_runner.py:
--------------------------------------------------------------------------------
```python
1 | """Common error handling for MCP tools."""
2 | import logging
3 | from typing import Callable
4 |
5 | from src.data_source_interface import NoDataFoundError, LoginError, DataSourceError
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
def run_tool_with_handling(action: Callable[[], str], context: str) -> str:
    """
    Run ``action`` and turn any raised exception into an ``Error: ...`` string.

    Args:
        action: Zero-argument callable whose string result is returned on success.
        context: Short label placed at the start of every log line.

    Returns:
        The result of ``action()``, or a user-facing error message.
    """
    # The first matching row wins, so the row order is the handler precedence.
    handlers = (
        (NoDataFoundError, logger.warning, "No data found", "Error: "),
        (LoginError, logger.error, "Login error", "Error: Could not connect to data source. "),
        (DataSourceError, logger.error, "Data source error", "Error: An error occurred while fetching data. "),
        (ValueError, logger.warning, "Validation error", "Error: Invalid input parameter. "),
    )
    try:
        return action()
    except Exception as exc:
        for exc_type, log, label, reply_prefix in handlers:
            if isinstance(exc, exc_type):
                log(f"{context}: {label}: {exc}")
                return f"{reply_prefix}{exc}"
        # Anything unrecognized is logged together with its traceback.
        logger.exception(f"{context}: Unexpected error: {exc}")
        return f"Error: An unexpected error occurred: {exc}"
35 |
```
--------------------------------------------------------------------------------
/src/services/validation.py:
--------------------------------------------------------------------------------
```python
1 | """Validation utilities for tool inputs."""
2 | from typing import Iterable
3 |
# Allowed values for the `frequency` parameter.
VALID_FREQS = [
    "d",
    "w",
    "m",
    "5",
    "15",
    "30",
    "60",
]
# Allowed values for the `adjust_flag` parameter.
VALID_ADJUST_FLAGS = [
    "1",
    "2",
    "3",
]
# Allowed output renderings.
VALID_FORMATS = [
    "markdown",
    "json",
    "csv",
]
# Allowed `year_type` values checked by validate_year_type.
VALID_YEAR_TYPES = [
    "report",
    "operate",
]
# Allowed `year_type` values checked by validate_year_type_reserve.
VALID_RESERVE_YEAR_TYPES = [
    "0",
    "1",
    "2",
]
9 |
10 |
def _ensure_in(value: str, allowed: Iterable[str], label: str) -> None:
    """Raise ``ValueError`` naming ``label`` unless ``value`` is among ``allowed``."""
    if value in allowed:
        return
    options = list(allowed)
    raise ValueError(f"Invalid {label} '{value}'. Valid options are: {options}")
14 |
15 |
def validate_frequency(frequency: str) -> None:
    """Raise ``ValueError`` unless ``frequency`` is listed in ``VALID_FREQS``."""
    _ensure_in(frequency, allowed=VALID_FREQS, label="frequency")
18 |
19 |
def validate_adjust_flag(adjust_flag: str) -> None:
    """Raise ``ValueError`` unless ``adjust_flag`` is listed in ``VALID_ADJUST_FLAGS``."""
    _ensure_in(adjust_flag, allowed=VALID_ADJUST_FLAGS, label="adjust_flag")
22 |
23 |
def validate_output_format(fmt: str) -> None:
    """Raise ``ValueError`` unless ``fmt`` is listed in ``VALID_FORMATS``."""
    _ensure_in(fmt, allowed=VALID_FORMATS, label="format")
26 |
27 |
def validate_year(year: str) -> None:
    """Ensure ``year`` is a string of exactly four ASCII digits.

    Args:
        year: Candidate year such as ``"2023"``.

    Raises:
        ValueError: If ``year`` is not a string, is not four characters long,
            or contains anything other than the digits 0-9.
    """
    # str.isdigit() alone also accepts superscripts and full-width digits,
    # so the ASCII check is required to reject them.
    is_four_ascii_digits = (
        isinstance(year, str)
        and len(year) == 4
        and year.isascii()
        and year.isdigit()
    )
    if not is_four_ascii_digits:
        raise ValueError(f"Invalid year '{year}'. Please provide a 4-digit year.")
31 |
32 |
def validate_year_type(year_type: str) -> None:
    """Raise ``ValueError`` unless ``year_type`` is listed in ``VALID_YEAR_TYPES``."""
    _ensure_in(year_type, allowed=VALID_YEAR_TYPES, label="year_type")
35 |
36 |
def validate_quarter(quarter: int) -> None:
    """Raise ``ValueError`` unless ``quarter`` equals 1, 2, 3 or 4."""
    if quarter in (1, 2, 3, 4):
        return
    raise ValueError("Invalid quarter. Must be between 1 and 4.")
40 |
41 |
def validate_non_empty_str(value: str, label: str) -> None:
    """Raise ``ValueError`` when ``value`` is ``None`` or blank after stripping.

    ``label`` is the parameter name quoted in the error message.
    """
    is_blank = value is None or str(value).strip() == ""
    if is_blank:
        raise ValueError(f"'{label}' is required.")
45 |
46 |
def validate_index_key(value: str, mapping: dict) -> str:
    """Resolve ``value`` through ``mapping`` after lower-casing it.

    Returns:
        The mapped key.

    Raises:
        ValueError: If ``value`` is not a string or has no truthy entry in
            ``mapping``; the message lists the distinct mapped values.
    """
    resolved = None
    if isinstance(value, str):
        resolved = mapping.get(value.lower())
    if resolved:
        return resolved
    choices = sorted(set(mapping.values()))
    raise ValueError(f"Invalid index '{value}'. Valid options: {choices}")
52 |
53 |
def validate_year_type_reserve(year_type: str) -> None:
    """Raise ``ValueError`` unless ``year_type`` is listed in ``VALID_RESERVE_YEAR_TYPES``."""
    _ensure_in(year_type, allowed=VALID_RESERVE_YEAR_TYPES, label="year_type")
56 |
57 |
def validate_limit(limit: int) -> None:
    """Raise ``ValueError`` when ``limit`` is zero or negative."""
    if limit <= 0:
        message = "limit must be positive."
        raise ValueError(message)
61 |
```
--------------------------------------------------------------------------------
/src/use_cases/market_overview.py:
--------------------------------------------------------------------------------
```python
1 | """Use cases for market overview tools."""
2 | from typing import Optional
3 |
4 | from src.data_source_interface import FinancialDataSource
5 | from src.formatting.markdown_formatter import format_table_output
6 | from src.services.validation import validate_output_format, validate_non_empty_str
7 |
8 |
def fetch_trade_dates(data_source: FinancialDataSource, *, start_date: Optional[str], end_date: Optional[str], limit: int, format: str) -> str:
    """Fetch trade dates and render them in the requested output format.

    Missing bounds are reported as "default" in the output metadata.
    """
    validate_output_format(format)
    frame = data_source.get_trade_dates(start_date=start_date, end_date=end_date)
    return format_table_output(
        frame,
        format=format,
        max_rows=limit,
        meta={
            "start_date": start_date if start_date else "default",
            "end_date": end_date if end_date else "default",
        },
    )
14 |
15 |
def fetch_all_stock(data_source: FinancialDataSource, *, date: Optional[str], limit: int, format: str) -> str:
    """Fetch the stock list for ``date`` and render it in the requested format.

    A missing ``date`` is reported as "default" in the output metadata.
    """
    validate_output_format(format)
    frame = data_source.get_all_stock(date=date)
    return format_table_output(
        frame,
        format=format,
        max_rows=limit,
        meta={"as_of": date if date else "default"},
    )
21 |
22 |
def fetch_search_stocks(data_source: FinancialDataSource, *, keyword: str, date: Optional[str], limit: int, format: str) -> str:
    """Return the stocks whose code contains ``keyword`` (case-insensitive).

    Args:
        data_source: Provider whose ``get_all_stock`` result is searched.
        keyword: Literal substring to look for in the ``code`` column.
        date: Optional date forwarded to ``get_all_stock``.
        limit: Maximum number of rows handed to the formatter.
        format: Output format; validated before any data is fetched.

    Returns:
        The formatted matches, or a placeholder string when the data source
        returns no rows.

    Raises:
        ValueError: If ``format`` is invalid or ``keyword`` is blank.
    """
    validate_output_format(format)
    validate_non_empty_str(keyword, "keyword")
    df = data_source.get_all_stock(date=date)
    if df is None or df.empty:
        return "(No data available to display)"
    kw = keyword.strip().lower()
    # regex=False matches the keyword literally: "sh.600" must not treat "."
    # as a wildcard, and input such as "600(" must not raise a regex error.
    matches = df["code"].str.lower().str.contains(kw, regex=False, na=False)
    filtered = df[matches]
    meta = {"keyword": keyword, "as_of": date or "current"}
    return format_table_output(filtered, format=format, max_rows=limit, meta=meta)
33 |
34 |
def fetch_suspensions(data_source: FinancialDataSource, *, date: Optional[str], limit: int, format: str) -> str:
    """List the rows of the stock list whose ``tradeStatus`` equals ``'0'``.

    Returns a placeholder string when the data source returns no rows.

    Raises:
        ValueError: If ``format`` is invalid or the ``tradeStatus`` column is absent.
    """
    validate_output_format(format)
    frame = data_source.get_all_stock(date=date)
    if frame is None or frame.empty:
        return "(No data available to display)"
    if "tradeStatus" not in frame.columns:
        raise ValueError("'tradeStatus' column not present in data source response.")

    is_suspended = frame["tradeStatus"] == '0'
    halted = frame[is_suspended]
    return format_table_output(
        halted,
        format=format,
        max_rows=limit,
        meta={
            "as_of": date if date else "current",
            "total_suspended": int(halted.shape[0]),
        },
    )
45 |
```
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
```python
1 | # Utility functions, including the Baostock login context manager and logging setup
2 | import baostock as bs
3 | import os
4 | import sys
5 | import logging
6 | from contextlib import contextmanager
7 | from .data_source_interface import LoginError
8 |
9 | # --- Logging Setup ---
def setup_logging(level=logging.INFO):
    """Apply the application's basic logging configuration at ``level``."""
    config = {
        "level": level,
        "format": '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        "datefmt": '%Y-%m-%d %H:%M:%S',
    }
    logging.basicConfig(**config)
    # If dependency logs become too verbose, they can be silenced here, e.g.
    # by raising the level of the "mcp" logger to WARNING.
19 |
20 | # Get a logger instance for this module (optional, but good practice)
21 | logger = logging.getLogger(__name__)
22 |
23 | # --- Baostock Context Manager ---
@contextmanager
def _suppress_stdout():
    """Point the process-level stdout descriptor at ``os.devnull`` for the block.

    The original descriptor is always restored and the temporary descriptors
    are always closed, even when the wrapped code raises.
    """
    try:
        # Flush first so output buffered before the redirect still reaches
        # the real stdout.
        sys.stdout.flush()
        stdout_fd = sys.stdout.fileno()
    except (AttributeError, OSError, ValueError):
        # stdout was replaced by an object without a real descriptor, so
        # there is nothing to redirect.
        stdout_fd = None

    if stdout_fd is None:
        yield
        return

    saved_fd = os.dup(stdout_fd)
    try:
        devnull_fd = os.open(os.devnull, os.O_WRONLY)
        try:
            os.dup2(devnull_fd, stdout_fd)
        finally:
            os.close(devnull_fd)
        yield
    finally:
        try:
            # Drain Python's buffer into devnull; otherwise the suppressed
            # text would be written to the real stdout after the restore.
            sys.stdout.flush()
        finally:
            try:
                os.dup2(saved_fd, stdout_fd)
            finally:
                os.close(saved_fd)


@contextmanager
def baostock_login_context():
    """Log in to Baostock on entry and log out on exit, silencing stdout output.

    Raises:
        LoginError: If the login result's ``error_code`` is not ``'0'``.
    """
    logger.debug("Attempting Baostock login...")
    with _suppress_stdout():
        lg = bs.login()
    logger.debug(f"Login result: code={lg.error_code}, msg={lg.error_msg}")

    if lg.error_code != '0':
        # Log the failure before raising.
        logger.error(f"Baostock login failed: {lg.error_msg}")
        raise LoginError(f"Baostock login failed: {lg.error_msg}")

    logger.info("Baostock login successful.")
    try:
        yield  # API calls happen here
    finally:
        logger.debug("Attempting Baostock logout...")
        with _suppress_stdout():
            bs.logout()
        logger.debug("Logout completed.")
        logger.info("Baostock logout successful.")
68 |
69 | # You can add other utility functions or classes here if needed
70 |
```
--------------------------------------------------------------------------------
/src/use_cases/indices.py:
--------------------------------------------------------------------------------
```python
1 | """Use cases for index and industry related tools."""
2 | from typing import Optional
3 |
4 | from src.data_source_interface import FinancialDataSource
5 | from src.formatting.markdown_formatter import format_table_output
6 | from src.services.validation import validate_output_format, validate_index_key, validate_non_empty_str
7 |
# Maps every accepted index alias to its canonical key.
INDEX_MAP = {
    alias: canonical
    for canonical, aliases in (
        ("hs300", ("hs300", "沪深300")),
        ("zz500", ("zz500", "中证500")),
        ("sz50", ("sz50", "上证50")),
    )
    for alias in aliases
}
16 |
17 |
def fetch_stock_industry(data_source: FinancialDataSource, *, code: Optional[str], date: Optional[str], limit: int, format: str) -> str:
    """Fetch industry data for ``code`` and render it in the requested format.

    A missing ``code`` is reported as "all" and a missing ``date`` as "latest"
    in the output metadata.
    """
    validate_output_format(format)
    frame = data_source.get_stock_industry(code=code, date=date)
    return format_table_output(
        frame,
        format=format,
        max_rows=limit,
        meta={
            "code": code if code else "all",
            "as_of": date if date else "latest",
        },
    )
23 |
24 |
def fetch_index_constituents(data_source: FinancialDataSource, *, index: str, date: Optional[str], limit: int, format: str) -> str:
    """Fetch the constituents of ``index`` and render them in the requested format.

    ``index`` is resolved through ``INDEX_MAP``; any resolved key other than
    "hs300" or "sz50" is served by the zz500 loader.

    Raises:
        ValueError: If ``format`` or ``index`` is invalid.
    """
    validate_output_format(format)
    key = validate_index_key(index, INDEX_MAP)

    loader_names = {"hs300": "get_hs300_stocks", "sz50": "get_sz50_stocks"}
    loader = getattr(data_source, loader_names.get(key, "get_zz500_stocks"))
    frame = loader(date=date)

    return format_table_output(
        frame,
        format=format,
        max_rows=limit,
        meta={"index": key, "as_of": date if date else "latest"},
    )
36 |
37 |
def fetch_list_industries(data_source: FinancialDataSource, *, date: Optional[str], format: str) -> str:
    """List the distinct industries, sorted, in the requested output format.

    Uses the ``industry`` column when present and the last column otherwise.
    Every distinct value is shown, with no row limit. Returns a placeholder
    string when the data source returns no rows.
    """
    validate_output_format(format)
    frame = data_source.get_stock_industry(code=None, date=date)
    if frame is None or frame.empty:
        return "(No data available to display)"

    columns = frame.columns
    source_col = "industry" if "industry" in columns else columns[-1]
    unique = frame[[source_col]].drop_duplicates()
    ordered = unique.sort_values(by=source_col)
    industries = ordered.rename(columns={source_col: "industry"})

    total = industries.shape[0]
    return format_table_output(
        industries,
        format=format,
        max_rows=total,
        meta={"as_of": date if date else "latest", "count": int(total)},
    )
48 |
49 |
def fetch_industry_members(data_source: FinancialDataSource, *, industry: str, date: Optional[str], limit: int, format: str) -> str:
    """List the rows whose industry value equals ``industry`` exactly.

    Uses the ``industry`` column when present and the last column otherwise.
    Returns a placeholder string when the data source returns no rows.

    Raises:
        ValueError: If ``format`` is invalid or ``industry`` is blank.
    """
    validate_output_format(format)
    validate_non_empty_str(industry, "industry")
    frame = data_source.get_stock_industry(code=None, date=date)
    if frame is None or frame.empty:
        return "(No data available to display)"

    columns = frame.columns
    industry_col = "industry" if "industry" in columns else columns[-1]
    is_member = frame[industry_col] == industry
    members = frame[is_member].copy()

    return format_table_output(
        members,
        format=format,
        max_rows=limit,
        meta={"industry": industry, "as_of": date if date else "latest"},
    )
60 |
```
--------------------------------------------------------------------------------
/src/use_cases/macroeconomic.py:
--------------------------------------------------------------------------------
```python
1 | """Use cases for macroeconomic data tools."""
2 | from typing import Optional
3 |
4 | from src.data_source_interface import FinancialDataSource
5 | from src.formatting.markdown_formatter import format_table_output
6 | from src.services.validation import validate_output_format, validate_year_type_reserve
7 |
8 |
def fetch_deposit_rate_data(data_source: FinancialDataSource, *, start_date: Optional[str], end_date: Optional[str], limit: int, format: str) -> str:
    """Fetch deposit rate data for the date window and render it as ``format``."""
    validate_output_format(format)
    window = {"start_date": start_date, "end_date": end_date}
    frame = data_source.get_deposit_rate_data(**window)
    return format_table_output(
        frame,
        format=format,
        max_rows=limit,
        meta={"dataset": "deposit_rate", **window},
    )
14 |
15 |
def fetch_loan_rate_data(data_source: FinancialDataSource, *, start_date: Optional[str], end_date: Optional[str], limit: int, format: str) -> str:
    """Fetch loan rate data for the date window and render it as ``format``."""
    validate_output_format(format)
    window = {"start_date": start_date, "end_date": end_date}
    frame = data_source.get_loan_rate_data(**window)
    return format_table_output(
        frame,
        format=format,
        max_rows=limit,
        meta={"dataset": "loan_rate", **window},
    )
21 |
22 |
def fetch_required_reserve_ratio_data(data_source: FinancialDataSource, *, start_date: Optional[str], end_date: Optional[str], year_type: str, limit: int, format: str) -> str:
    """Fetch required reserve ratio data and render it as ``format``.

    Raises:
        ValueError: If ``format`` or ``year_type`` is not an accepted value.
    """
    validate_output_format(format)
    validate_year_type_reserve(year_type)
    query = {"start_date": start_date, "end_date": end_date, "year_type": year_type}
    frame = data_source.get_required_reserve_ratio_data(**query)
    return format_table_output(
        frame,
        format=format,
        max_rows=limit,
        meta={"dataset": "required_reserve_ratio", **query},
    )
29 |
30 |
def fetch_money_supply_data_month(data_source: FinancialDataSource, *, start_date: Optional[str], end_date: Optional[str], limit: int, format: str) -> str:
    """Fetch monthly money supply data for the window and render it as ``format``."""
    validate_output_format(format)
    window = {"start_date": start_date, "end_date": end_date}
    frame = data_source.get_money_supply_data_month(**window)
    return format_table_output(
        frame,
        format=format,
        max_rows=limit,
        meta={"dataset": "money_supply_month", **window},
    )
36 |
37 |
def fetch_money_supply_data_year(data_source: FinancialDataSource, *, start_date: Optional[str], end_date: Optional[str], limit: int, format: str) -> str:
    """Fetch yearly money supply data for the window and render it as ``format``."""
    validate_output_format(format)
    window = {"start_date": start_date, "end_date": end_date}
    frame = data_source.get_money_supply_data_year(**window)
    return format_table_output(
        frame,
        format=format,
        max_rows=limit,
        meta={"dataset": "money_supply_year", **window},
    )
43 |
```
--------------------------------------------------------------------------------
/src/use_cases/stock_market.py:
--------------------------------------------------------------------------------
```python
1 | """Stock market use cases orchestrating data fetch and formatting."""
2 | from typing import List, Optional
3 |
4 | import pandas as pd
5 |
6 | from src.data_source_interface import FinancialDataSource
7 | from src.formatting.markdown_formatter import format_table_output
8 | from src.services.validation import (
9 | validate_adjust_flag,
10 | validate_frequency,
11 | validate_output_format,
12 | validate_year,
13 | validate_year_type,
14 | )
15 |
16 |
def fetch_historical_k_data(
    data_source: FinancialDataSource,
    *,
    code: str,
    start_date: str,
    end_date: str,
    frequency: str = "d",
    adjust_flag: str = "3",
    fields: Optional[List[str]] = None,
    limit: int = 250,
    format: str = "markdown",
) -> str:
    """Validate the K-line query options, fetch the data and format it.

    ``frequency``, ``adjust_flag`` and ``format`` are checked, in that order,
    before the data source is called. The query parameters other than
    ``fields`` are echoed back as output metadata.

    Raises:
        ValueError: If any of the three validated options is not accepted.
    """
    validate_frequency(frequency)
    validate_adjust_flag(adjust_flag)
    validate_output_format(format)

    query = {
        "code": code,
        "start_date": start_date,
        "end_date": end_date,
        "frequency": frequency,
        "adjust_flag": adjust_flag,
    }
    frame = data_source.get_historical_k_data(**query, fields=fields)
    return format_table_output(frame, format=format, max_rows=limit, meta=query)
49 |
50 |
def fetch_stock_basic_info(
    data_source: FinancialDataSource,
    *,
    code: str,
    fields: Optional[List[str]] = None,
    format: str = "markdown",
) -> str:
    """Fetch basic info for ``code`` and render every returned row.

    The row limit handed to the formatter is the frame's own row count, or 0
    when the data source returns ``None``.
    """
    validate_output_format(format)
    frame = data_source.get_stock_basic_info(code=code, fields=fields)
    if frame is None:
        row_count = 0
    else:
        row_count = frame.shape[0]
    return format_table_output(frame, format=format, max_rows=row_count, meta={"code": code})
62 |
63 |
def fetch_dividend_data(
    data_source: FinancialDataSource,
    *,
    code: str,
    year: str,
    year_type: str = "report",
    limit: int = 250,
    format: str = "markdown",
) -> str:
    """Validate the inputs, fetch dividend data and render it as ``format``.

    ``year``, ``year_type`` and ``format`` are checked, in that order, before
    the data source is called.

    Raises:
        ValueError: If any of the three validated inputs is not accepted.
    """
    validate_year(year)
    validate_year_type(year_type)
    validate_output_format(format)

    query = {"code": code, "year": year, "year_type": year_type}
    frame = data_source.get_dividend_data(**query)
    return format_table_output(frame, format=format, max_rows=limit, meta=query)
80 |
81 |
def fetch_adjust_factor_data(
    data_source: FinancialDataSource,
    *,
    code: str,
    start_date: str,
    end_date: str,
    limit: int = 250,
    format: str = "markdown",
) -> str:
    """Fetch adjust-factor data for ``code`` over the window and render it as ``format``."""
    validate_output_format(format)
    query = {"code": code, "start_date": start_date, "end_date": end_date}
    frame = data_source.get_adjust_factor_data(**query)
    return format_table_output(frame, format=format, max_rows=limit, meta=query)
95 |
```
--------------------------------------------------------------------------------
/src/tools/helpers.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Helper tools for code normalization and constants discovery.
3 | Uses shared validation and helper logic.
4 | """
5 | import logging
6 | from typing import Optional
7 |
8 | from mcp.server.fastmcp import FastMCP
9 | from src.services.tool_runner import run_tool_with_handling
10 | from src.use_cases.helpers import normalize_index_code_logic, normalize_stock_code_logic
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
def register_helpers_tools(app: FastMCP):
    """Attach the code-normalization and constants-listing tools to ``app``."""

    @app.tool()
    def normalize_stock_code(code: str) -> str:
        """Normalize a stock code to Baostock format."""
        logger.info("Tool 'normalize_stock_code' called with input=%s", code)

        def _normalize() -> str:
            return normalize_stock_code_logic(code)

        return run_tool_with_handling(_normalize, context="normalize_stock_code")

    @app.tool()
    def normalize_index_code(code: str) -> str:
        """Normalize common index codes to Baostock format."""
        logger.info("Tool 'normalize_index_code' called with input=%s", code)

        def _normalize() -> str:
            return normalize_index_code_logic(code)

        return run_tool_with_handling(_normalize, context="normalize_index_code")

    @app.tool()
    def list_tool_constants(kind: Optional[str] = None) -> str:
        """
        List valid constants for tool parameters.

        Args:
            kind: Optional filter: 'frequency' | 'adjust_flag' | 'year_type' | 'index'. If None, show all.
        """
        logger.info("Tool 'list_tool_constants' called kind=%s", kind or "all")

        # Section title -> (value, meaning) rows, in display order.
        catalog = {
            "frequency": [
                ("d", "daily"), ("w", "weekly"), ("m", "monthly"),
                ("5", "5 minutes"), ("15", "15 minutes"), ("30", "30 minutes"), ("60", "60 minutes"),
            ],
            "adjust_flag": [("1", "forward adjusted"), ("2", "backward adjusted"), ("3", "unadjusted")],
            "year_type": [("report", "announcement year"), ("operate", "ex-dividend year")],
            "index": [("hs300", "CSI 300"), ("sz50", "SSE 50"), ("zz500", "CSI 500")],
        }

        def render(title: str, rows) -> str:
            # One markdown table per section; an empty section renders as "".
            if not rows:
                return ""
            body = "\n".join(f"| {value} | {meaning} |" for value, meaning in rows)
            return f"### {title}\n\n| value | meaning |\n|---|---|\n" + body + "\n"

        wanted = (kind or "").strip().lower()
        rendered = [
            render(title, rows)
            for title, rows in catalog.items()
            if wanted in ("", title)
        ]
        output = "\n".join(part for part in rendered if part)
        if output:
            return output
        return "Error: Invalid kind. Use one of 'frequency', 'adjust_flag', 'year_type', 'index'."
76 |
```
--------------------------------------------------------------------------------
/src/tools/macroeconomic.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Macroeconomic tools for the MCP server.
3 | Delegates to use cases with shared validation and error handling.
4 | """
5 | import logging
6 | from typing import Optional
7 |
8 | from mcp.server.fastmcp import FastMCP
9 | from src.data_source_interface import FinancialDataSource
10 | from src.services.tool_runner import run_tool_with_handling
11 | from src.use_cases.macroeconomic import (
12 | fetch_deposit_rate_data,
13 | fetch_loan_rate_data,
14 | fetch_money_supply_data_month,
15 | fetch_money_supply_data_year,
16 | fetch_required_reserve_ratio_data,
17 | )
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
def register_macroeconomic_tools(app: FastMCP, active_data_source: FinancialDataSource):
    """Attach the macroeconomic data tools to ``app``, bound to ``active_data_source``."""

    @app.tool()
    def get_deposit_rate_data(start_date: Optional[str] = None, end_date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str:
        """Benchmark deposit rates."""

        def _fetch() -> str:
            return fetch_deposit_rate_data(
                active_data_source, start_date=start_date, end_date=end_date, limit=limit, format=format
            )

        return run_tool_with_handling(_fetch, context="get_deposit_rate_data")

    @app.tool()
    def get_loan_rate_data(start_date: Optional[str] = None, end_date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str:
        """Benchmark loan rates."""

        def _fetch() -> str:
            return fetch_loan_rate_data(
                active_data_source, start_date=start_date, end_date=end_date, limit=limit, format=format
            )

        return run_tool_with_handling(_fetch, context="get_loan_rate_data")

    @app.tool()
    def get_required_reserve_ratio_data(start_date: Optional[str] = None, end_date: Optional[str] = None, year_type: str = '0', limit: int = 250, format: str = "markdown") -> str:
        """Required reserve ratio data."""

        def _fetch() -> str:
            return fetch_required_reserve_ratio_data(
                active_data_source,
                start_date=start_date,
                end_date=end_date,
                year_type=year_type,
                limit=limit,
                format=format,
            )

        return run_tool_with_handling(_fetch, context="get_required_reserve_ratio_data")

    @app.tool()
    def get_money_supply_data_month(start_date: Optional[str] = None, end_date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str:
        """Monthly money supply data."""

        def _fetch() -> str:
            return fetch_money_supply_data_month(
                active_data_source, start_date=start_date, end_date=end_date, limit=limit, format=format
            )

        return run_tool_with_handling(_fetch, context="get_money_supply_data_month")

    @app.tool()
    def get_money_supply_data_year(start_date: Optional[str] = None, end_date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str:
        """Yearly money supply data."""

        def _fetch() -> str:
            return fetch_money_supply_data_year(
                active_data_source, start_date=start_date, end_date=end_date, limit=limit, format=format
            )

        return run_tool_with_handling(_fetch, context="get_money_supply_data_year")
70 |
```
--------------------------------------------------------------------------------
/src/tools/date_utils.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Date utility tools for the MCP server.
3 | Delegates to use-case layer for consistent behavior.
4 | """
5 | import logging
6 | from typing import Optional
7 |
8 | from mcp.server.fastmcp import FastMCP
9 | from src.data_source_interface import FinancialDataSource
10 | from src.services.tool_runner import run_tool_with_handling
11 | from src.use_cases import date_utils as uc_date
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
def register_date_utils_tools(app: FastMCP, active_data_source: FinancialDataSource):
    """Attach the trading-calendar helper tools to ``app``.

    Every tool delegates to the function of the same name in
    ``src.use_cases.date_utils`` and runs it through ``run_tool_with_handling``.

    Args:
        app: FastMCP application the tools are registered on.
        active_data_source: Data source handed to the calendar use cases.
    """

    @app.tool()
    def get_latest_trading_date() -> str:
        """Return the most recent trading date that is not after today."""
        logger.info("Tool 'get_latest_trading_date' called")

        def _call() -> str:
            return uc_date.get_latest_trading_date(active_data_source)

        return run_tool_with_handling(_call, context="get_latest_trading_date")

    @app.tool()
    def get_market_analysis_timeframe(period: str = "recent") -> str:
        """Return a date-range label for an analysis period.

        Args:
            period: One of 'recent', 'quarter', 'half_year' or 'year'.
                Defaults to 'recent'.
        """
        logger.info(f"Tool 'get_market_analysis_timeframe' called with period={period}")

        def _call() -> str:
            return uc_date.get_market_analysis_timeframe(period=period)

        return run_tool_with_handling(_call, context="get_market_analysis_timeframe")

    @app.tool()
    def is_trading_day(date: str) -> str:
        """Report whether ``date`` ('YYYY-MM-DD') is a trading day."""

        def _call() -> str:
            return uc_date.is_trading_day(active_data_source, date=date)

        return run_tool_with_handling(_call, context=f"is_trading_day:{date}")

    @app.tool()
    def previous_trading_day(date: str) -> str:
        """Return the closest trading day strictly before ``date`` ('YYYY-MM-DD')."""

        def _call() -> str:
            return uc_date.previous_trading_day(active_data_source, date=date)

        return run_tool_with_handling(_call, context=f"previous_trading_day:{date}")

    @app.tool()
    def next_trading_day(date: str) -> str:
        """Return the closest trading day strictly after ``date`` ('YYYY-MM-DD')."""

        def _call() -> str:
            return uc_date.next_trading_day(active_data_source, date=date)

        return run_tool_with_handling(_call, context=f"next_trading_day:{date}")

    @app.tool()
    def get_last_n_trading_days(days: int = 5) -> str:
        """Return the most recent trading dates as a comma-separated string.

        Args:
            days: Number of trading days requested. Defaults to 5.
        """

        def _call() -> str:
            return uc_date.get_last_n_trading_days(active_data_source, days=days)

        return run_tool_with_handling(_call, context=f"get_last_n_trading_days:{days}")

    @app.tool()
    def get_recent_trading_range(days: int = 5) -> str:
        """Return a 'start 至 end' range string spanning the recent trading days.

        Args:
            days: Number of trading days the range should cover. Defaults to 5.
        """

        def _call() -> str:
            return uc_date.get_recent_trading_range(active_data_source, days=days)

        return run_tool_with_handling(_call, context=f"get_recent_trading_range:{days}")

    @app.tool()
    def get_month_end_trading_dates(year: int) -> str:
        """Return the month-end trading dates of ``year`` as a comma-separated string."""

        def _call() -> str:
            return uc_date.get_month_end_trading_dates(active_data_source, year=year)

        return run_tool_with_handling(_call, context=f"get_month_end_trading_dates:{year}")
84 |
```
--------------------------------------------------------------------------------
/src/tools/indices.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Index-related tools for the MCP server.
3 | Delegates to use-case layer for validation and formatting.
4 | """
5 | import logging
6 | from typing import Optional
7 |
8 | from mcp.server.fastmcp import FastMCP
9 | from src.data_source_interface import FinancialDataSource
10 | from src.services.tool_runner import run_tool_with_handling
11 | from src.use_cases.indices import (
12 | fetch_index_constituents,
13 | fetch_industry_members,
14 | fetch_list_industries,
15 | fetch_stock_industry,
16 | )
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
def register_index_tools(app: FastMCP, active_data_source: FinancialDataSource):
    """Register index and industry related tools with the MCP app.

    Each tool forwards its arguments to a use case from
    ``src.use_cases.indices`` and runs it through ``run_tool_with_handling``.

    Args:
        app: The FastMCP app instance.
        active_data_source: The active financial data source.
    """

    @app.tool()
    def get_stock_industry(code: Optional[str] = None, date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str:
        """Get industry classification for a specific stock or all stocks on a date.

        Args:
            code: Optional stock code. If None, all stocks are requested.
            date: Optional date. If None, the latest data is requested.
            limit: Max rows to return. Defaults to 250.
            format: Output format. Defaults to 'markdown'.
        """
        # Lazy %-style arguments, consistent with the other log calls in this module.
        logger.info("Tool 'get_stock_industry' called for code=%s, date=%s", code or "all", date or "latest")
        return run_tool_with_handling(
            lambda: fetch_stock_industry(active_data_source, code=code, date=date, limit=limit, format=format),
            context=f"get_stock_industry:{code or 'all'}",
        )

    @app.tool()
    def get_sz50_stocks(date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str:
        """SSE 50 (上证50) constituents.

        Args:
            date: Optional date. If None, the latest data is requested.
            limit: Max rows to return. Defaults to 250.
            format: Output format. Defaults to 'markdown'.
        """
        # "sz50" is Baostock's key for the Shanghai (SSE) 50 index, not a Shenzhen index.
        return run_tool_with_handling(
            lambda: fetch_index_constituents(active_data_source, index="sz50", date=date, limit=limit, format=format),
            context="get_sz50_stocks",
        )

    @app.tool()
    def get_hs300_stocks(date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str:
        """CSI 300 constituents.

        Args:
            date: Optional date. If None, the latest data is requested.
            limit: Max rows to return. Defaults to 250.
            format: Output format. Defaults to 'markdown'.
        """
        return run_tool_with_handling(
            lambda: fetch_index_constituents(active_data_source, index="hs300", date=date, limit=limit, format=format),
            context="get_hs300_stocks",
        )

    @app.tool()
    def get_zz500_stocks(date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str:
        """CSI 500 constituents.

        Args:
            date: Optional date. If None, the latest data is requested.
            limit: Max rows to return. Defaults to 250.
            format: Output format. Defaults to 'markdown'.
        """
        return run_tool_with_handling(
            lambda: fetch_index_constituents(active_data_source, index="zz500", date=date, limit=limit, format=format),
            context="get_zz500_stocks",
        )

    @app.tool()
    def get_index_constituents(index: str, date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str:
        """Generic index constituent fetch (hs300/sz50/zz500).

        Args:
            index: Index key: 'hs300', 'sz50' or 'zz500'.
            date: Optional date. If None, the latest data is requested.
            limit: Max rows to return. Defaults to 250.
            format: Output format. Defaults to 'markdown'.
        """
        return run_tool_with_handling(
            lambda: fetch_index_constituents(active_data_source, index=index, date=date, limit=limit, format=format),
            context=f"get_index_constituents:{index}",
        )

    @app.tool()
    def list_industries(date: Optional[str] = None, format: str = "markdown") -> str:
        """List distinct industries for a given date.

        Args:
            date: Optional date. If None, the latest data is requested.
            format: Output format. Defaults to 'markdown'.
        """
        logger.info("Tool 'list_industries' called date=%s", date or "latest")
        return run_tool_with_handling(
            lambda: fetch_list_industries(active_data_source, date=date, format=format),
            context="list_industries",
        )

    @app.tool()
    def get_industry_members(industry: str, date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str:
        """Get all stocks in a given industry on a date.

        Args:
            industry: Industry name to look up.
            date: Optional date. If None, the latest data is requested.
            limit: Max rows to return. Defaults to 250.
            format: Output format. Defaults to 'markdown'.
        """
        logger.info("Tool 'get_industry_members' called industry=%s, date=%s", industry, date or "latest")
        return run_tool_with_handling(
            lambda: fetch_industry_members(active_data_source, industry=industry, date=date, limit=limit, format=format),
            context=f"get_industry_members:{industry}",
        )
82 |
```
--------------------------------------------------------------------------------
/src/formatting/markdown_formatter.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Markdown formatting utilities for A-Share MCP Server.
3 | """
4 | import pandas as pd
5 | import logging
6 | import json
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 | # Configuration: Max rows to display in string outputs to protect context length
11 | MAX_MARKDOWN_ROWS = 250
12 |
13 |
def format_df_to_markdown(df: pd.DataFrame, max_rows: int | None = None) -> str:
    """Format a Pandas DataFrame as a Markdown table with row truncation.

    Args:
        df: The DataFrame to format. ``None`` or an empty frame yields a
            placeholder message instead of a table.
        max_rows: Maximum rows to include in output. Defaults to
            MAX_MARKDOWN_ROWS if None. Negative values are treated as 0.

    Returns:
        The Markdown table, prefixed with a truncation note when rows were
        dropped; a placeholder string when there is no data; or an error
        string when the conversion itself fails.
    """
    if df is None or df.empty:
        logger.warning("Attempted to format an empty DataFrame to Markdown.")
        return "(No data available to display)"

    if max_rows is None:
        max_rows = MAX_MARKDOWN_ROWS

    original_rows = df.shape[0]
    # Clamp at zero: DataFrame.head() with a negative count drops rows from the
    # end instead of limiting the output, which would defeat the row cap.
    rows_to_show = max(0, min(original_rows, max_rows))
    df_display = df.head(rows_to_show)

    truncated = original_rows > rows_to_show

    try:
        markdown_table = df_display.to_markdown(index=False)
    except Exception as e:
        # Best effort: report the failure as text rather than raising.
        logger.error("Error converting DataFrame to Markdown: %s", e, exc_info=True)
        return "Error: Could not format data into Markdown table."

    if truncated:
        notes = f"rows truncated to {rows_to_show} from {original_rows}"
        return f"Note: Data truncated ({notes}).\n\n{markdown_table}"
    return markdown_table
47 |
48 |
def format_table_output(
    df: pd.DataFrame,
    format: str = "markdown",
    max_rows: int | None = None,
    meta: dict | None = None,
) -> str:
    """Format a DataFrame into the requested string format with optional meta.

    Args:
        df: Data to format. ``None`` is treated as an empty table.
        format: 'markdown' | 'json' | 'csv'. Defaults to 'markdown'. Unknown
            values fall back to markdown (without the meta header).
        max_rows: Optional max rows to include. Defaults to MAX_MARKDOWN_ROWS
            if None. Negative values are treated as 0.
        meta: Optional metadata dict to include (prepended for markdown,
            embedded for json).

    Returns:
        A string suitable for tool responses. Conversion failures are reported
        as an "Error: ..." string rather than raised.
    """
    fmt = (format or "markdown").lower()

    # Normalize row cap; a negative cap would make head() drop trailing rows.
    if max_rows is None:
        max_rows = MAX_MARKDOWN_ROWS
    max_rows = max(0, max_rows)

    total_rows = 0 if df is None else int(df.shape[0])
    rows_to_show = min(total_rows, max_rows)
    truncated = total_rows > rows_to_show
    df_display = df.head(rows_to_show) if df is not None else pd.DataFrame()

    if fmt == "markdown":
        header = ""
        if meta:
            # Render a compact meta header
            lines = ["Meta:"]
            for k, v in meta.items():
                lines.append(f"- {k}: {v}")
            header = "\n".join(lines) + "\n\n"
        # Hand over the untruncated frame: format_df_to_markdown applies the
        # cap itself and can only emit its truncation note if it sees the
        # original row count.
        return header + format_df_to_markdown(df, max_rows=max_rows)

    if fmt == "csv":
        try:
            return df_display.to_csv(index=False)
        except Exception as e:
            logger.error("Error converting DataFrame to CSV: %s", e, exc_info=True)
            return "Error: Could not format data into CSV."

    if fmt == "json":
        try:
            payload = {
                "data": df_display.to_dict(orient="records"),
                "meta": {
                    **(meta or {}),
                    "total_rows": total_rows,
                    "returned_rows": rows_to_show,
                    "truncated": truncated,
                    "columns": list(df_display.columns),
                },
            }
            return json.dumps(payload, ensure_ascii=False)
        except Exception as e:
            logger.error("Error converting DataFrame to JSON: %s", e, exc_info=True)
            return "Error: Could not format data into JSON."

    # Fallback to markdown if unknown format
    logger.warning("Unknown format '%s', falling back to markdown", fmt)
    return format_df_to_markdown(df, max_rows=max_rows)
114 |
```
--------------------------------------------------------------------------------
/src/tools/market_overview.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Market overview tools for the MCP server.
3 | Includes trading calendar, stock list, and discovery helpers.
4 | """
5 | import logging
6 | from typing import Optional
7 |
8 | from mcp.server.fastmcp import FastMCP
9 | from src.data_source_interface import FinancialDataSource
10 | from src.services.tool_runner import run_tool_with_handling
11 | from src.use_cases.market_overview import (
12 | fetch_all_stock,
13 | fetch_search_stocks,
14 | fetch_suspensions,
15 | fetch_trade_dates,
16 | )
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
def register_market_overview_tools(app: FastMCP, active_data_source: FinancialDataSource):
    """
    Register market overview tools with the MCP app.

    Each tool forwards its arguments to a use case from
    ``src.use_cases.market_overview`` and runs it through
    ``run_tool_with_handling``.

    Args:
        app: The FastMCP app instance
        active_data_source: The active financial data source
    """

    @app.tool()
    def get_trade_dates(start_date: Optional[str] = None, end_date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str:
        """
        Fetch trading dates within a specified range.

        Args:
            start_date: Optional. Start date in 'YYYY-MM-DD' format. Defaults to 2015-01-01 if None.
            end_date: Optional. End date in 'YYYY-MM-DD' format. Defaults to the current date if None.
            limit: Max rows to return. Defaults to 250.
            format: Output format: 'markdown' | 'json' | 'csv'. Defaults to 'markdown'.

        Returns:
            Table with 'is_trading_day' (1=trading, 0=non-trading) in the requested format.
        """
        # Lazy %-style arguments, consistent with the other tools in this module.
        logger.info(
            "Tool 'get_trade_dates' called for range %s to %s",
            start_date or "default",
            end_date or "default",
        )
        return run_tool_with_handling(
            lambda: fetch_trade_dates(active_data_source, start_date=start_date, end_date=end_date, limit=limit, format=format),
            context="get_trade_dates",
        )

    @app.tool()
    def get_all_stock(date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str:
        """
        Fetch a list of all stocks (A-shares and indices) and their trading status for a date.

        Args:
            date: Optional. The date in 'YYYY-MM-DD' format. If None, uses the current date.
            limit: Max rows to return. Defaults to 250.
            format: Output format: 'markdown' | 'json' | 'csv'. Defaults to 'markdown'.

        Returns:
            Table listing stock codes and trading status (1=trading, 0=suspended) in the requested format.
        """
        logger.info("Tool 'get_all_stock' called for date=%s", date or "default")
        return run_tool_with_handling(
            lambda: fetch_all_stock(active_data_source, date=date, limit=limit, format=format),
            context=f"get_all_stock:{date or 'default'}",
        )

    @app.tool()
    def search_stocks(keyword: str, date: Optional[str] = None, limit: int = 50, format: str = "markdown") -> str:
        """
        Search stocks by code substring on a date.

        Args:
            keyword: Substring to match in the stock code (e.g., '600', '000001').
            date: Optional 'YYYY-MM-DD'. If None, uses current date.
            limit: Max rows to return. Defaults to 50.
            format: Output format: 'markdown' | 'json' | 'csv'. Defaults to 'markdown'.

        Returns:
            Matching stock codes with their trading status.
        """
        logger.info("Tool 'search_stocks' called keyword=%s, date=%s, limit=%s, format=%s", keyword, date or "default", limit, format)
        return run_tool_with_handling(
            lambda: fetch_search_stocks(active_data_source, keyword=keyword, date=date, limit=limit, format=format),
            context=f"search_stocks:{keyword}",
        )

    @app.tool()
    def get_suspensions(date: Optional[str] = None, limit: int = 250, format: str = "markdown") -> str:
        """
        List suspended stocks for a date.

        Args:
            date: Optional 'YYYY-MM-DD'. If None, uses current date.
            limit: Max rows to return. Defaults to 250.
            format: Output format: 'markdown' | 'json' | 'csv'. Defaults to 'markdown'.

        Returns:
            Table of stocks where tradeStatus==0.
        """
        logger.info("Tool 'get_suspensions' called date=%s, limit=%s, format=%s", date or "current", limit, format)
        return run_tool_with_handling(
            lambda: fetch_suspensions(active_data_source, date=date, limit=limit, format=format),
            context=f"get_suspensions:{date or 'current'}",
        )
103 |
```
--------------------------------------------------------------------------------
/src/use_cases/financial_reports.py:
--------------------------------------------------------------------------------
```python
1 | """Use cases for financial report related tools."""
2 | from typing import Optional
3 |
4 | from src.data_source_interface import FinancialDataSource
5 | from src.formatting.markdown_formatter import format_table_output
6 | from src.services.validation import (
7 | validate_output_format,
8 | validate_quarter,
9 | validate_year,
10 | )
11 |
12 |
def _format_financial_df(df, *, code: str, year: str | None, quarter: Optional[int], dataset: str, format: str, limit: int) -> str:
    """Render a financial-report frame together with a descriptive meta block.

    The meta always carries ``code`` and ``dataset``; ``year`` is added when
    truthy and ``quarter`` when it is not None. Rendering is delegated to
    ``format_table_output`` with ``limit`` as the row cap.
    """
    period_fields = {}
    if year:
        period_fields["year"] = year
    if quarter is not None:
        period_fields["quarter"] = quarter

    # Key order (code, dataset, year, quarter) is kept for the rendered header.
    description = {"code": code, "dataset": dataset, **period_fields}
    return format_table_output(df, format=format, max_rows=limit, meta=description)
20 |
21 |
def fetch_profit_data(data_source: FinancialDataSource, *, code: str, year: str, quarter: int, limit: int, format: str) -> str:
    """Fetch quarterly profitability data for ``code`` and render it.

    ``year``, ``quarter`` and ``format`` are validated (in that order) before
    the data source is queried; the result is labelled "Profitability".
    """
    for check, value in (
        (validate_year, year),
        (validate_quarter, quarter),
        (validate_output_format, format),
    ):
        check(value)

    frame = data_source.get_profit_data(code=code, year=year, quarter=quarter)
    return _format_financial_df(
        frame,
        code=code,
        year=year,
        quarter=quarter,
        dataset="Profitability",
        format=format,
        limit=limit,
    )
28 |
29 |
def fetch_operation_data(data_source: FinancialDataSource, *, code: str, year: str, quarter: int, limit: int, format: str) -> str:
    """Fetch quarterly operation-capability data for ``code`` and render it.

    ``year``, ``quarter`` and ``format`` are validated (in that order) before
    the data source is queried; the result is labelled "Operation Capability".
    """
    for check, value in (
        (validate_year, year),
        (validate_quarter, quarter),
        (validate_output_format, format),
    ):
        check(value)

    frame = data_source.get_operation_data(code=code, year=year, quarter=quarter)
    return _format_financial_df(
        frame,
        code=code,
        year=year,
        quarter=quarter,
        dataset="Operation Capability",
        format=format,
        limit=limit,
    )
36 |
37 |
def fetch_growth_data(data_source: FinancialDataSource, *, code: str, year: str, quarter: int, limit: int, format: str) -> str:
    """Fetch quarterly growth data for ``code`` and render it.

    ``year``, ``quarter`` and ``format`` are validated (in that order) before
    the data source is queried; the result is labelled "Growth".
    """
    for check, value in (
        (validate_year, year),
        (validate_quarter, quarter),
        (validate_output_format, format),
    ):
        check(value)

    frame = data_source.get_growth_data(code=code, year=year, quarter=quarter)
    return _format_financial_df(
        frame,
        code=code,
        year=year,
        quarter=quarter,
        dataset="Growth",
        format=format,
        limit=limit,
    )
44 |
45 |
def fetch_balance_data(data_source: FinancialDataSource, *, code: str, year: str, quarter: int, limit: int, format: str) -> str:
    """Fetch quarterly balance-sheet data for ``code`` and render it.

    ``year``, ``quarter`` and ``format`` are validated (in that order) before
    the data source is queried; the result is labelled "Balance Sheet".
    """
    for check, value in (
        (validate_year, year),
        (validate_quarter, quarter),
        (validate_output_format, format),
    ):
        check(value)

    frame = data_source.get_balance_data(code=code, year=year, quarter=quarter)
    return _format_financial_df(
        frame,
        code=code,
        year=year,
        quarter=quarter,
        dataset="Balance Sheet",
        format=format,
        limit=limit,
    )
52 |
53 |
def fetch_cash_flow_data(data_source: FinancialDataSource, *, code: str, year: str, quarter: int, limit: int, format: str) -> str:
    """Fetch quarterly cash-flow data for ``code`` and render it.

    ``year``, ``quarter`` and ``format`` are validated (in that order) before
    the data source is queried; the result is labelled "Cash Flow".
    """
    for check, value in (
        (validate_year, year),
        (validate_quarter, quarter),
        (validate_output_format, format),
    ):
        check(value)

    frame = data_source.get_cash_flow_data(code=code, year=year, quarter=quarter)
    return _format_financial_df(
        frame,
        code=code,
        year=year,
        quarter=quarter,
        dataset="Cash Flow",
        format=format,
        limit=limit,
    )
60 |
61 |
def fetch_dupont_data(data_source: FinancialDataSource, *, code: str, year: str, quarter: int, limit: int, format: str) -> str:
    """Fetch quarterly Dupont analysis data for ``code`` and render it.

    ``year``, ``quarter`` and ``format`` are validated (in that order) before
    the data source is queried; the result is labelled "Dupont".
    """
    for check, value in (
        (validate_year, year),
        (validate_quarter, quarter),
        (validate_output_format, format),
    ):
        check(value)

    frame = data_source.get_dupont_data(code=code, year=year, quarter=quarter)
    return _format_financial_df(
        frame,
        code=code,
        year=year,
        quarter=quarter,
        dataset="Dupont",
        format=format,
        limit=limit,
    )
68 |
69 |
def fetch_performance_express_report(data_source: FinancialDataSource, *, code: str, start_date: str, end_date: str, limit: int, format: str) -> str:
    """Fetch the performance express report for ``code`` over a date range.

    Only ``format`` is validated here; the dates are forwarded to the data
    source as given. The rendered output carries the code, the date range and
    the dataset label "Performance Express" as meta.
    """
    validate_output_format(format)

    date_range = {"start_date": start_date, "end_date": end_date}
    report = data_source.get_performance_express_report(code=code, **date_range)

    # Meta key order: code, start_date, end_date, dataset.
    description = {"code": code, **date_range, "dataset": "Performance Express"}
    return format_table_output(report, format=format, max_rows=limit, meta=description)
75 |
76 |
def fetch_forecast_report(data_source: FinancialDataSource, *, code: str, start_date: str, end_date: str, limit: int, format: str) -> str:
    """Fetch the earnings forecast report for ``code`` over a date range.

    Only ``format`` is validated here; the dates are forwarded to the data
    source as given. The rendered output carries the code, the date range and
    the dataset label "Forecast" as meta.
    """
    validate_output_format(format)

    date_range = {"start_date": start_date, "end_date": end_date}
    report = data_source.get_forecast_report(code=code, **date_range)

    # Meta key order: code, start_date, end_date, dataset.
    description = {"code": code, **date_range, "dataset": "Forecast"}
    return format_table_output(report, format=format, max_rows=limit, meta=description)
82 |
```
--------------------------------------------------------------------------------
/src/tools/financial_reports.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Financial report tools for the MCP server.
3 | Thin wrappers delegating to use cases with shared validation and error handling.
4 | """
5 | import logging
6 | from typing import Optional
7 |
8 | from mcp.server.fastmcp import FastMCP
9 | from src.data_source_interface import FinancialDataSource
10 | from src.services.tool_runner import run_tool_with_handling
11 | from src.use_cases.financial_reports import (
12 | fetch_balance_data,
13 | fetch_cash_flow_data,
14 | fetch_dupont_data,
15 | fetch_forecast_report,
16 | fetch_growth_data,
17 | fetch_operation_data,
18 | fetch_performance_express_report,
19 | fetch_profit_data,
20 | )
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
def register_financial_report_tools(app: FastMCP, active_data_source: FinancialDataSource):
    """
    Attach the financial report tools to ``app``.

    Each tool forwards its arguments to the matching use case from
    ``src.use_cases.financial_reports`` and runs it through
    ``run_tool_with_handling``.

    Args:
        app: FastMCP application the tools are registered on.
        active_data_source: Data source handed to every use case.
    """

    @app.tool()
    def get_profit_data(code: str, year: str, quarter: int, limit: int = 250, format: str = "markdown") -> str:
        """Quarterly profitability data for one stock.

        Args:
            code: Stock code.
            year: Report year.
            quarter: Report quarter.
            limit: Max rows to return. Defaults to 250.
            format: Output format. Defaults to 'markdown'.
        """

        def _call() -> str:
            return fetch_profit_data(
                active_data_source, code=code, year=year, quarter=quarter, limit=limit, format=format
            )

        return run_tool_with_handling(_call, context=f"get_profit_data:{code}:{year}Q{quarter}")

    @app.tool()
    def get_operation_data(code: str, year: str, quarter: int, limit: int = 250, format: str = "markdown") -> str:
        """Quarterly operation capability data for one stock.

        Args:
            code: Stock code.
            year: Report year.
            quarter: Report quarter.
            limit: Max rows to return. Defaults to 250.
            format: Output format. Defaults to 'markdown'.
        """

        def _call() -> str:
            return fetch_operation_data(
                active_data_source, code=code, year=year, quarter=quarter, limit=limit, format=format
            )

        return run_tool_with_handling(_call, context=f"get_operation_data:{code}:{year}Q{quarter}")

    @app.tool()
    def get_growth_data(code: str, year: str, quarter: int, limit: int = 250, format: str = "markdown") -> str:
        """Quarterly growth capability data for one stock.

        Args:
            code: Stock code.
            year: Report year.
            quarter: Report quarter.
            limit: Max rows to return. Defaults to 250.
            format: Output format. Defaults to 'markdown'.
        """

        def _call() -> str:
            return fetch_growth_data(
                active_data_source, code=code, year=year, quarter=quarter, limit=limit, format=format
            )

        return run_tool_with_handling(_call, context=f"get_growth_data:{code}:{year}Q{quarter}")

    @app.tool()
    def get_balance_data(code: str, year: str, quarter: int, limit: int = 250, format: str = "markdown") -> str:
        """Quarterly balance sheet data for one stock.

        Args:
            code: Stock code.
            year: Report year.
            quarter: Report quarter.
            limit: Max rows to return. Defaults to 250.
            format: Output format. Defaults to 'markdown'.
        """

        def _call() -> str:
            return fetch_balance_data(
                active_data_source, code=code, year=year, quarter=quarter, limit=limit, format=format
            )

        return run_tool_with_handling(_call, context=f"get_balance_data:{code}:{year}Q{quarter}")

    @app.tool()
    def get_cash_flow_data(code: str, year: str, quarter: int, limit: int = 250, format: str = "markdown") -> str:
        """Quarterly cash flow data for one stock.

        Args:
            code: Stock code.
            year: Report year.
            quarter: Report quarter.
            limit: Max rows to return. Defaults to 250.
            format: Output format. Defaults to 'markdown'.
        """

        def _call() -> str:
            return fetch_cash_flow_data(
                active_data_source, code=code, year=year, quarter=quarter, limit=limit, format=format
            )

        return run_tool_with_handling(_call, context=f"get_cash_flow_data:{code}:{year}Q{quarter}")

    @app.tool()
    def get_dupont_data(code: str, year: str, quarter: int, limit: int = 250, format: str = "markdown") -> str:
        """Quarterly Dupont analysis data for one stock.

        Args:
            code: Stock code.
            year: Report year.
            quarter: Report quarter.
            limit: Max rows to return. Defaults to 250.
            format: Output format. Defaults to 'markdown'.
        """

        def _call() -> str:
            return fetch_dupont_data(
                active_data_source, code=code, year=year, quarter=quarter, limit=limit, format=format
            )

        return run_tool_with_handling(_call, context=f"get_dupont_data:{code}:{year}Q{quarter}")

    @app.tool()
    def get_performance_express_report(code: str, start_date: str, end_date: str, limit: int = 250, format: str = "markdown") -> str:
        """Performance express report for one stock within a date range.

        Args:
            code: Stock code.
            start_date: Start of the range.
            end_date: End of the range.
            limit: Max rows to return. Defaults to 250.
            format: Output format. Defaults to 'markdown'.
        """

        def _call() -> str:
            return fetch_performance_express_report(
                active_data_source, code=code, start_date=start_date, end_date=end_date, limit=limit, format=format
            )

        return run_tool_with_handling(
            _call,
            context=f"get_performance_express_report:{code}:{start_date}-{end_date}",
        )

    @app.tool()
    def get_forecast_report(code: str, start_date: str, end_date: str, limit: int = 250, format: str = "markdown") -> str:
        """Earnings forecast report for one stock within a date range.

        Args:
            code: Stock code.
            start_date: Start of the range.
            end_date: End of the range.
            limit: Max rows to return. Defaults to 250.
            format: Output format. Defaults to 'markdown'.
        """

        def _call() -> str:
            return fetch_forecast_report(
                active_data_source, code=code, start_date=start_date, end_date=end_date, limit=limit, format=format
            )

        return run_tool_with_handling(
            _call,
            context=f"get_forecast_report:{code}:{start_date}-{end_date}",
        )
97 |
```
--------------------------------------------------------------------------------
/src/use_cases/date_utils.py:
--------------------------------------------------------------------------------
```python
1 | """Use cases for date utility tools."""
2 | import calendar
3 | from datetime import datetime, timedelta
4 | from typing import Optional
5 |
6 | import pandas as pd
7 |
8 | from src.data_source_interface import FinancialDataSource
9 |
10 |
def _fetch_trading_days(data_source: FinancialDataSource, start_date: str, end_date: str) -> pd.DataFrame:
    """Return the data source's trading calendar for ``start_date``..``end_date``.

    Single access point for calendar queries in this module; the frame is
    returned exactly as ``data_source.get_trade_dates`` produced it.
    """
    calendar_frame = data_source.get_trade_dates(start_date=start_date, end_date=end_date)
    return calendar_frame
13 |
14 |
def get_latest_trading_date(data_source: FinancialDataSource) -> str:
    """Return the most recent trading date that is not after today.

    The calendar is queried for the 31 days ending today, so the answer is
    found even when today falls on the 29th-31st of a month or when the
    current month has not had a trading day yet (e.g. right after a holiday
    that spans the month boundary).

    Args:
        data_source: Data source providing the trading calendar.

    Returns:
        The latest trading date as 'YYYY-MM-DD', or today's date when the
        calendar contains no trading day in the look-back window.
    """
    now = datetime.now()
    today = now.strftime("%Y-%m-%d")
    # Look back a full month instead of scanning a fixed day-1..day-28 window
    # of the current month, which missed month-end and previous-month dates.
    start_date = (now - timedelta(days=31)).strftime("%Y-%m-%d")
    df = _fetch_trading_days(data_source, start_date=start_date, end_date=today)
    valid_trading_days = df[df["is_trading_day"] == "1"]["calendar_date"].tolist()
    # ISO 'YYYY-MM-DD' strings order chronologically when compared as text.
    candidates = [dstr for dstr in valid_trading_days if dstr <= today]
    return max(candidates, default=None) or today
26 |
27 |
def get_market_analysis_timeframe(period: str = "recent") -> str:
    """Build a 'start 至 end' label for the requested analysis period.

    The range always ends today and starts on the first day of a month:

    * ``'recent'``: the current month from the 15th onwards, otherwise the
      previous month (November of the previous year when run in early January).
    * ``'quarter'``: first month of the current calendar quarter.
    * ``'half_year'``: January or July, whichever half today falls in.
    * ``'year'``: January of the current year.

    Args:
        period: One of 'recent', 'quarter', 'half_year', 'year'.

    Returns:
        A string of the form 'YYYY-MM-DD 至 YYYY-MM-DD'.

    Raises:
        ValueError: If ``period`` is not one of the supported values.
    """
    now = datetime.now()
    year, month = now.year, now.month

    if period == "recent":
        if now.day >= 15:
            first_month = (year, month)
        elif month == 1:
            # NOTE(review): every other month steps back exactly one month,
            # but January steps back to November — confirm this is intended.
            first_month = (year - 1, 11)
        else:
            first_month = (year, month - 1)
    elif period == "quarter":
        first_month = (year, 3 * ((month - 1) // 3) + 1)
    elif period == "half_year":
        first_month = (year, 7 if month > 6 else 1)
    elif period == "year":
        first_month = (year, 1)
    else:
        raise ValueError("Invalid period. Use 'recent', 'quarter', 'half_year', or 'year'.")

    range_start = datetime(first_month[0], first_month[1], 1)
    return f"{range_start.strftime('%Y-%m-%d')} 至 {now.strftime('%Y-%m-%d')}"
54 |
55 |
def is_trading_day(data_source: FinancialDataSource, *, date: str) -> str:
    """Tell whether ``date`` is a trading day.

    Returns:
        "是" when the first calendar row for ``date`` has ``is_trading_day``
        equal to "1", "否" when it has any other value, and "未知" when the
        calendar has no row for that date.
    """
    calendar_rows = _fetch_trading_days(data_source, start_date=date, end_date=date)
    if calendar_rows.empty:
        return "未知"

    # str() makes the comparison work whether the flag is stored as text or number.
    flag = str(calendar_rows.iloc[0].get("is_trading_day", ""))
    if flag == "1":
        return "是"
    return "否"
62 |
63 |
def previous_trading_day(data_source: FinancialDataSource, *, date: str) -> str:
    """Return the latest trading day strictly before ``date``.

    The calendar is searched over the 31 days ending at ``date``. If no
    earlier trading day is found in that window, ``date`` itself is returned.

    Raises:
        ValueError: If ``date`` is not in 'YYYY-MM-DD' format.
    """
    window_start = datetime.strptime(date, "%Y-%m-%d") - timedelta(days=31)
    calendar_df = _fetch_trading_days(
        data_source,
        start_date=window_start.strftime("%Y-%m-%d"),
        end_date=date,
    )
    open_mask = calendar_df["is_trading_day"] == "1"
    # ISO date strings compare chronologically, so plain text comparison works.
    earlier = [day for day in calendar_df[open_mask]["calendar_date"].tolist() if day < date]
    if not earlier:
        return date
    return max(earlier) or date
71 |
72 |
def next_trading_day(data_source: FinancialDataSource, *, date: str) -> str:
    """Return the earliest trading day strictly after ``date``.

    The calendar is searched over the 31 days starting at ``date``. If no
    later trading day is found in that window, ``date`` itself is returned.

    Raises:
        ValueError: If ``date`` is not in 'YYYY-MM-DD' format.
    """
    window_end = datetime.strptime(date, "%Y-%m-%d") + timedelta(days=31)
    calendar_df = _fetch_trading_days(
        data_source,
        start_date=date,
        end_date=window_end.strftime("%Y-%m-%d"),
    )
    open_mask = calendar_df["is_trading_day"] == "1"
    # ISO date strings compare chronologically, so plain text comparison works.
    later = [day for day in calendar_df[open_mask]["calendar_date"].tolist() if day > date]
    if not later:
        return date
    return min(later) or date
80 |
81 |
def get_last_n_trading_days(data_source: FinancialDataSource, *, days: int) -> str:
    """Return the most recent ``days`` trading dates up to today.

    Args:
        data_source: Data source providing the trading calendar.
        days: Number of trading days requested; must be at least 1.

    Returns:
        The dates in ascending order joined by ', ' (fewer than ``days``
        entries if the calendar window holds fewer), or '' when no trading
        day is found.

    Raises:
        ValueError: If ``days`` is smaller than 1.
    """
    # Guard explicitly: with days == 0 the slice ``[-0:]`` would return the
    # whole list, and a negative value would put the window start in the future.
    if days < 1:
        raise ValueError("days must be a positive integer.")

    today = datetime.now()
    # Twice the requested count covers weekends; the extra two weeks absorb
    # multi-day market holidays so short requests are not under-filled.
    start = (today - timedelta(days=days * 2 + 14)).strftime("%Y-%m-%d")
    end = today.strftime("%Y-%m-%d")
    df = _fetch_trading_days(data_source, start_date=start, end_date=end)
    # Sort so the tail slice is the most recent days regardless of row order.
    trading_days = sorted(df[df["is_trading_day"] == "1"]["calendar_date"].tolist())
    return ", ".join(trading_days[-days:])
89 |
90 |
def get_recent_trading_range(data_source: FinancialDataSource, *, days: int) -> str:
    """Return a 'start 至 end' string spanning the most recent ``days`` trading days.

    Args:
        data_source: Data source providing the trading calendar.
        days: Number of trading days the range should cover; must be at least 1.

    Returns:
        'YYYY-MM-DD 至 YYYY-MM-DD' from the first to the last of the selected
        trading days (all available days if fewer than ``days`` were found),
        or '' when no trading day is found.

    Raises:
        ValueError: If ``days`` is smaller than 1.
    """
    # A non-positive count has no meaningful range and would put the window
    # start on or after today.
    if days < 1:
        raise ValueError("days must be a positive integer.")

    today = datetime.now()
    # Twice the requested count covers weekends; the extra two weeks absorb
    # multi-day market holidays so the range really spans ``days`` sessions.
    start = (today - timedelta(days=days * 2 + 14)).strftime("%Y-%m-%d")
    end = today.strftime("%Y-%m-%d")
    df = _fetch_trading_days(data_source, start_date=start, end_date=end)
    trading_days = sorted(df[df["is_trading_day"] == "1"]["calendar_date"].tolist())
    if not trading_days:
        return ""
    # The tail slice yields the last ``days`` entries, or everything available
    # when the calendar returned fewer.
    selected = trading_days[-days:]
    return f"{selected[0]} 至 {selected[-1]}"
100 |
101 |
def get_month_end_trading_dates(data_source: FinancialDataSource, *, year: int) -> str:
    """Return the last trading date of every month in ``year``.

    The whole year's calendar is fetched in one query and the latest trading
    day of each month is selected. This keeps months whose final week is a
    market holiday, which a fixed "last 8 days of the month" window would
    silently skip.

    Args:
        data_source: Data source providing the trading calendar.
        year: Calendar year to inspect.

    Returns:
        Month-end trading dates in chronological order joined by ', '. Months
        without any trading day in the calendar are omitted.

    Raises:
        ValueError: If ``year`` is outside the range ``datetime`` supports.
    """
    # Building the bounds through datetime keeps the year validation.
    start_date = datetime(year, 1, 1).strftime("%Y-%m-%d")
    end_date = datetime(year, 12, 31).strftime("%Y-%m-%d")
    df = _fetch_trading_days(data_source, start_date=start_date, end_date=end_date)
    trading_days = df[df["is_trading_day"] == "1"]["calendar_date"].tolist()

    # Assumes 'YYYY-MM-DD' strings (as the other helpers in this module do):
    # the first 7 characters identify the month and text order is date order.
    month_ends: dict[str, str] = {}
    for day in trading_days:
        month_key = day[:7]
        if day > month_ends.get(month_key, ""):
            month_ends[month_key] = day

    return ", ".join(month_ends[month_key] for month_key in sorted(month_ends))
113 |
```
--------------------------------------------------------------------------------
/src/tools/stock_market.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Stock market tools for the MCP server.
3 | Thin wrappers that delegate to use cases with shared validation and error handling.
4 | """
5 | import logging
6 | from typing import List, Optional
7 |
8 | from mcp.server.fastmcp import FastMCP
9 | from src.data_source_interface import FinancialDataSource
10 | from src.services.tool_runner import run_tool_with_handling
11 | from src.use_cases.stock_market import (
12 | fetch_adjust_factor_data,
13 | fetch_dividend_data,
14 | fetch_historical_k_data,
15 | fetch_stock_basic_info,
16 | )
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
def register_stock_market_tools(app: FastMCP, active_data_source: FinancialDataSource):
    """
    Attach the stock-market tools to an MCP application.

    Every tool defined here is a thin wrapper: it logs the call, then hands a
    zero-argument callable to ``run_tool_with_handling`` which invokes the
    matching use case.

    Args:
        app: The FastMCP app instance the tools are registered on.
        active_data_source: The data source handed to every use case.
    """

    @app.tool()
    def get_historical_k_data(
        code: str,
        start_date: str,
        end_date: str,
        frequency: str = "d",
        adjust_flag: str = "3",
        fields: Optional[List[str]] = None,
        limit: int = 250,
        format: str = "markdown",
    ) -> str:
        """
        Fetches historical K-line (OHLCV) data for a Chinese A-share stock.

        Args:
            code: The stock code in Baostock format (e.g., 'sh.600000', 'sz.000001').
            start_date: Start date in 'YYYY-MM-DD' format.
            end_date: End date in 'YYYY-MM-DD' format.
            frequency: Data frequency. Valid options (from Baostock):
                         'd': daily
                         'w': weekly
                         'm': monthly
                         '5': 5 minutes
                         '15': 15 minutes
                         '30': 30 minutes
                         '60': 60 minutes
                       Defaults to 'd'.
            adjust_flag: Adjustment flag for price/volume. Valid options (from Baostock):
                           '1': Forward adjusted (后复权)
                           '2': Backward adjusted (前复权)
                           '3': Non-adjusted (不复权)
                         Defaults to '3'.
            fields: Optional list of specific data fields to retrieve (must be valid Baostock fields).
                    If None or empty, default fields will be used (e.g., date, code, open, high, low, close, volume, amount, pctChg).
            limit: Max rows to return. Defaults to 250.
            format: Output format: 'markdown' | 'json' | 'csv'. Defaults to 'markdown'.

        Returns:
            The K-line data rendered in the requested format, or an error message.
            The result might be truncated if the result set is too large.
        """
        logger.info(
            f"Tool 'get_historical_k_data' called for {code} ({start_date}-{end_date}, freq={frequency}, adj={adjust_flag}, fields={fields})"
        )

        def _query():
            return fetch_historical_k_data(
                active_data_source,
                code=code,
                start_date=start_date,
                end_date=end_date,
                frequency=frequency,
                adjust_flag=adjust_flag,
                fields=fields,
                limit=limit,
                format=format,
            )

        return run_tool_with_handling(_query, context=f"get_historical_k_data:{code}")

    @app.tool()
    def get_stock_basic_info(
        code: str,
        fields: Optional[List[str]] = None,
        format: str = "markdown",
    ) -> str:
        """
        Fetches basic information for a given Chinese A-share stock.

        Args:
            code: The stock code in Baostock format (e.g., 'sh.600000', 'sz.000001').
            fields: Optional list to select specific columns from the available basic info
                    (e.g., ['code', 'code_name', 'industry', 'listingDate']).
                    If None or empty, returns all available basic info columns from Baostock.
            format: Output format passed through to the use case. Defaults to 'markdown'.

        Returns:
            Basic stock information in the requested format.
        """
        logger.info(f"Tool 'get_stock_basic_info' called for {code} (fields={fields})")

        def _query():
            return fetch_stock_basic_info(
                active_data_source, code=code, fields=fields, format=format
            )

        return run_tool_with_handling(_query, context=f"get_stock_basic_info:{code}")

    @app.tool()
    def get_dividend_data(
        code: str,
        year: str,
        year_type: str = "report",
        limit: int = 250,
        format: str = "markdown",
    ) -> str:
        """
        Fetches dividend information for a given stock code and year.

        Args:
            code: The stock code in Baostock format (e.g., 'sh.600000', 'sz.000001').
            year: The year to query (e.g., '2023').
            year_type: Type of year. Valid options (from Baostock):
                         'report': Announcement year (预案公告年份)
                         'operate': Ex-dividend year (除权除息年份)
                       Defaults to 'report'.
            limit: Row limit passed through to the use case. Defaults to 250.
            format: Output format passed through to the use case. Defaults to 'markdown'.

        Returns:
            Dividend records table.
        """
        logger.info(f"Tool 'get_dividend_data' called for {code}, year={year}, year_type={year_type}")

        def _query():
            return fetch_dividend_data(
                active_data_source,
                code=code,
                year=year,
                year_type=year_type,
                limit=limit,
                format=format,
            )

        return run_tool_with_handling(_query, context=f"get_dividend_data:{code}:{year}")

    @app.tool()
    def get_adjust_factor_data(
        code: str,
        start_date: str,
        end_date: str,
        limit: int = 250,
        format: str = "markdown",
    ) -> str:
        """
        Fetches adjustment factor data for a given stock code and date range.
        Uses Baostock's "涨跌幅复权算法" factors. Useful for calculating adjusted prices.

        Args:
            code: The stock code in Baostock format (e.g., 'sh.600000', 'sz.000001').
            start_date: Start date in 'YYYY-MM-DD' format.
            end_date: End date in 'YYYY-MM-DD' format.
            limit: Row limit passed through to the use case. Defaults to 250.
            format: Output format passed through to the use case. Defaults to 'markdown'.

        Returns:
            Adjustment factors table.
        """
        logger.info(f"Tool 'get_adjust_factor_data' called for {code} ({start_date} to {end_date})")

        def _query():
            return fetch_adjust_factor_data(
                active_data_source,
                code=code,
                start_date=start_date,
                end_date=end_date,
                limit=limit,
                format=format,
            )

        return run_tool_with_handling(_query, context=f"get_adjust_factor_data:{code}")
166 |
```
--------------------------------------------------------------------------------
/src/data_source_interface.py:
--------------------------------------------------------------------------------
```python
1 | # Defines the abstract interface for financial data sources
2 | from abc import ABC, abstractmethod
3 | import pandas as pd
4 | from typing import Optional, List
5 |
class DataSourceError(Exception):
    """Root of the data-source exception hierarchy; catch this to handle any data source failure."""


class LoginError(DataSourceError):
    """Raised when logging in to the data source fails."""


class NoDataFoundError(DataSourceError):
    """Raised when a query succeeds but yields no data."""
19 |
20 |
class FinancialDataSource(ABC):
    """
    Abstract base class defining the interface for financial data sources.
    Implementations of this class provide access to specific financial data APIs
    (e.g., Baostock, Akshare).

    Every method returns a pandas DataFrame whose columns depend on the
    underlying data source.
    """

    @abstractmethod
    def get_historical_k_data(
        self,
        code: str,
        start_date: str,
        end_date: str,
        frequency: str = "d",
        adjust_flag: str = "3",
        fields: Optional[List[str]] = None,
    ) -> pd.DataFrame:
        """
        Fetches historical K-line (OHLCV) data for a given stock code.

        Args:
            code: The stock code (e.g., 'sh.600000', 'sz.000001').
            start_date: Start date in 'YYYY-MM-DD' format.
            end_date: End date in 'YYYY-MM-DD' format.
            frequency: Data frequency. Common values depend on the underlying
                       source (e.g., 'd' for daily, 'w' for weekly, 'm' for monthly,
                       '5', '15', '30', '60' for minutes). Defaults to 'd'.
            adjust_flag: Adjustment flag for historical data. Common values
                         depend on the source (e.g., '1' for forward adjusted,
                         '2' for backward adjusted, '3' for non-adjusted).
                         Defaults to '3'.
            fields: Optional list of specific fields to retrieve. If None,
                    retrieves default fields defined by the implementation.

        Returns:
            A pandas DataFrame containing the historical K-line data, with
            columns corresponding to the requested fields.

        Raises:
            LoginError: If login to the data source fails.
            NoDataFoundError: If no data is found for the query.
            DataSourceError: For other data source related errors.
            ValueError: If input parameters are invalid.
        """
        pass

    @abstractmethod
    def get_stock_basic_info(self, code: str, fields: Optional[List[str]] = None) -> pd.DataFrame:
        """
        Fetches basic information for a given stock code.

        Args:
            code: The stock code (e.g., 'sh.600000', 'sz.000001').
            fields: Optional list of column names to keep in the returned
                    DataFrame. If None or empty, all available columns are returned.

        Returns:
            A pandas DataFrame containing the basic stock information.
            The structure and columns depend on the underlying data source.
            Typically contains info like name, industry, listing date, etc.

        Raises:
            LoginError: If login to the data source fails.
            NoDataFoundError: If no data is found for the query.
            DataSourceError: For other data source related errors.
            ValueError: If the input code is invalid.
        """
        pass

    @abstractmethod
    def get_trade_dates(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame:
        """Fetches trading dates information within a range."""
        pass

    # Market overview
    # NOTE: this method used to be declared twice in this class; the later,
    # undocumented declaration silently replaced the first. It is declared once now.
    @abstractmethod
    def get_all_stock(self, date: Optional[str] = None) -> pd.DataFrame:
        """Fetches list of all stocks and their trading status on a given date."""
        pass

    @abstractmethod
    def get_deposit_rate_data(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame:
        """Fetches benchmark deposit rates."""
        pass

    @abstractmethod
    def get_loan_rate_data(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame:
        """Fetches benchmark loan rates."""
        pass

    @abstractmethod
    def get_required_reserve_ratio_data(self, start_date: Optional[str] = None, end_date: Optional[str] = None, year_type: str = '0') -> pd.DataFrame:
        """Fetches required reserve ratio data."""
        pass

    @abstractmethod
    def get_money_supply_data_month(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame:
        """Fetches monthly money supply data (M0, M1, M2)."""
        pass

    @abstractmethod
    def get_money_supply_data_year(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame:
        """Fetches yearly money supply data (M0, M1, M2 - year end balance)."""
        pass

    @abstractmethod
    def get_dividend_data(self, code: str, year: str, year_type: str = "report") -> pd.DataFrame:
        """Fetches dividend information for a stock and year."""
        pass

    @abstractmethod
    def get_adjust_factor_data(self, code: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Fetches adjustment factor data used for price adjustments."""
        pass

    # Financial report datasets
    @abstractmethod
    def get_profit_data(self, code: str, year: str, quarter: int) -> pd.DataFrame:
        """Fetches quarterly profitability data for a stock."""
        pass

    @abstractmethod
    def get_operation_data(self, code: str, year: str, quarter: int) -> pd.DataFrame:
        """Fetches quarterly operation capability data for a stock."""
        pass

    @abstractmethod
    def get_growth_data(self, code: str, year: str, quarter: int) -> pd.DataFrame:
        """Fetches quarterly growth capability data for a stock."""
        pass

    @abstractmethod
    def get_balance_data(self, code: str, year: str, quarter: int) -> pd.DataFrame:
        """Fetches quarterly balance sheet data for a stock."""
        pass

    @abstractmethod
    def get_cash_flow_data(self, code: str, year: str, quarter: int) -> pd.DataFrame:
        """Fetches quarterly cash flow data for a stock."""
        pass

    @abstractmethod
    def get_dupont_data(self, code: str, year: str, quarter: int) -> pd.DataFrame:
        """Fetches quarterly DuPont analysis data for a stock."""
        pass

    @abstractmethod
    def get_performance_express_report(self, code: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Fetches performance express reports for a stock within a date range."""
        pass

    @abstractmethod
    def get_forecast_report(self, code: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Fetches performance forecast reports for a stock within a date range."""
        pass

    # Index / industry
    @abstractmethod
    def get_stock_industry(self, code: Optional[str] = None, date: Optional[str] = None) -> pd.DataFrame:
        """Fetches industry classification data, optionally for one code and/or date."""
        pass

    @abstractmethod
    def get_hs300_stocks(self, date: Optional[str] = None) -> pd.DataFrame:
        """Fetches HS300 constituent stocks, optionally for a given date."""
        pass

    @abstractmethod
    def get_sz50_stocks(self, date: Optional[str] = None) -> pd.DataFrame:
        """Fetches SZ50 constituent stocks, optionally for a given date."""
        pass

    @abstractmethod
    def get_zz500_stocks(self, date: Optional[str] = None) -> pd.DataFrame:
        """Fetches ZZ500 constituent stocks, optionally for a given date."""
        pass

    # Note: SHIBOR is not implemented in current Baostock bindings; no abstract method here.
188 |
```
--------------------------------------------------------------------------------
/src/baostock_data_source.py:
--------------------------------------------------------------------------------
```python
1 | # Implementation of the FinancialDataSource interface using Baostock
2 | import baostock as bs
3 | import pandas as pd
4 | from typing import List, Optional
5 | import logging
6 | from .data_source_interface import FinancialDataSource, DataSourceError, NoDataFoundError, LoginError
7 | from .utils import baostock_login_context
8 |
# Module-level logger, named after this module
logger = logging.getLogger(__name__)

# Columns requested for K-line queries when the caller names none
DEFAULT_K_FIELDS = [
    # identity
    "date", "code",
    # prices
    "open", "high", "low", "close", "preclose",
    # activity and status
    "volume", "amount", "adjustflag", "turn", "tradestatus",
    # change and valuation
    "pctChg", "peTTM", "pbMRQ", "psTTM", "pcfNcfTTM",
    "isST",
]

# Default basic-info columns
# Add more default fields as needed, e.g., "industry", "listingDate"
DEFAULT_BASIC_FIELDS = [
    "code",
    "tradeStatus",
    "code_name",
]
22 |
23 | # Helper function to reduce repetition in financial data fetching
24 |
25 |
def _fetch_financial_data(
    bs_query_func,
    data_type_name: str,
    code: str,
    year: str,
    quarter: int
) -> pd.DataFrame:
    """
    Runs one Baostock quarterly financial query and returns its rows as a DataFrame.

    Args:
        bs_query_func: Query callable, invoked as
            ``bs_query_func(code=..., year=..., quarter=...)``.
        data_type_name: Dataset label used only in log and error messages.
        code: Stock code passed through to the query.
        year: Report year passed through to the query.
        quarter: Report quarter passed through to the query.

    Returns:
        DataFrame built from every row of the result set, with ``rs.fields``
        as column names.

    Raises:
        NoDataFoundError: The error message contains "no record found", the
            error code is '10002', or the result set is empty.
        DataSourceError: Any other non-'0' error code, or an unexpected
            exception (chained as ``__cause__``).
        LoginError, ValueError: Re-raised unchanged if they occur.
    """
    logger.info(
        f"Fetching {data_type_name} data for {code}, year={year}, quarter={quarter}")
    try:
        with baostock_login_context():
            # Assuming all these functions take code, year, quarter
            rs = bs_query_func(code=code, year=year, quarter=quarter)

            if rs.error_code != '0':
                logger.error(
                    f"Baostock API error ({data_type_name}) for {code}: {rs.error_msg} (code: {rs.error_code})")
                if "no record found" in rs.error_msg.lower() or rs.error_code == '10002':
                    raise NoDataFoundError(
                        f"No {data_type_name} data found for {code}, {year}Q{quarter}. Baostock msg: {rs.error_msg}")
                else:
                    raise DataSourceError(
                        f"Baostock API error fetching {data_type_name} data: {rs.error_msg} (code: {rs.error_code})")

            data_list = []
            while rs.next():
                data_list.append(rs.get_row_data())

            if not data_list:
                logger.warning(
                    f"No {data_type_name} data found for {code}, {year}Q{quarter} (empty result set from Baostock).")
                raise NoDataFoundError(
                    f"No {data_type_name} data found for {code}, {year}Q{quarter} (empty result set).")

            result_df = pd.DataFrame(data_list, columns=rs.fields)
            logger.info(
                f"Retrieved {len(result_df)} {data_type_name} records for {code}, {year}Q{quarter}.")
            return result_df

    except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e:
        logger.warning(
            f"Caught known error fetching {data_type_name} data for {code}: {type(e).__name__}")
        # Bare raise keeps the original traceback intact.
        raise
    except Exception as e:
        logger.exception(
            f"Unexpected error fetching {data_type_name} data for {code}: {e}")
        # Chain explicitly so the root cause is reported as the direct cause.
        raise DataSourceError(
            f"Unexpected error fetching {data_type_name} data for {code}: {e}") from e
74 |
75 | # Helper function to reduce repetition for index constituent data fetching
76 |
77 |
def _fetch_index_constituent_data(
    bs_query_func,
    index_name: str,
    date: Optional[str] = None
) -> pd.DataFrame:
    """
    Runs one Baostock index-constituent query and returns its rows as a DataFrame.

    Args:
        bs_query_func: Query callable, invoked as ``bs_query_func(date=...)``.
        index_name: Index label used only in log and error messages.
        date: Optional query date, passed through as-is (None included).

    Returns:
        DataFrame built from every row of the result set, with ``rs.fields``
        as column names.

    Raises:
        NoDataFoundError: The error message contains "no record found", the
            error code is '10002', or the result set is empty.
        DataSourceError: Any other non-'0' error code, or an unexpected
            exception (chained as ``__cause__``).
        LoginError, ValueError: Re-raised unchanged if they occur.
    """
    logger.info(
        f"Fetching {index_name} constituents for date={date or 'latest'}")
    try:
        with baostock_login_context():
            # date is optional, defaults to latest
            rs = bs_query_func(date=date)

            if rs.error_code != '0':
                logger.error(
                    f"Baostock API error ({index_name} Constituents) for date {date}: {rs.error_msg} (code: {rs.error_code})")
                if "no record found" in rs.error_msg.lower() or rs.error_code == '10002':
                    raise NoDataFoundError(
                        f"No {index_name} constituent data found for date {date}. Baostock msg: {rs.error_msg}")
                else:
                    raise DataSourceError(
                        f"Baostock API error fetching {index_name} constituents: {rs.error_msg} (code: {rs.error_code})")

            data_list = []
            while rs.next():
                data_list.append(rs.get_row_data())

            if not data_list:
                logger.warning(
                    f"No {index_name} constituent data found for date {date} (empty result set).")
                raise NoDataFoundError(
                    f"No {index_name} constituent data found for date {date} (empty result set).")

            result_df = pd.DataFrame(data_list, columns=rs.fields)
            logger.info(
                f"Retrieved {len(result_df)} {index_name} constituents for date {date or 'latest'}.")
            return result_df

    except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e:
        logger.warning(
            f"Caught known error fetching {index_name} constituents for date {date}: {type(e).__name__}")
        # Bare raise keeps the original traceback intact.
        raise
    except Exception as e:
        logger.exception(
            f"Unexpected error fetching {index_name} constituents for date {date}: {e}")
        # Chain explicitly so the root cause is reported as the direct cause.
        raise DataSourceError(
            f"Unexpected error fetching {index_name} constituents for date {date}: {e}") from e
124 |
125 | # Helper function to reduce repetition for macroeconomic data fetching
126 |
127 |
def _fetch_macro_data(
    bs_query_func,
    data_type_name: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    **kwargs  # For extra params like yearType
) -> pd.DataFrame:
    """
    Runs one Baostock macroeconomic query and returns its rows as a DataFrame.

    Args:
        bs_query_func: Query callable, invoked as
            ``bs_query_func(start_date=..., end_date=..., **kwargs)``.
        data_type_name: Dataset label used only in log and error messages.
        start_date: Optional range start, passed through as-is.
        end_date: Optional range end, passed through as-is.
        **kwargs: Extra keyword arguments forwarded untouched to the query.

    Returns:
        DataFrame built from every row of the result set, with ``rs.fields``
        as column names.

    Raises:
        NoDataFoundError: The error message contains "no record found", the
            error code is '10002', or the result set is empty.
        DataSourceError: Any other non-'0' error code, or an unexpected
            exception (chained as ``__cause__``).
        LoginError, ValueError: Re-raised unchanged if they occur.
    """
    date_range_log = f"from {start_date or 'default'} to {end_date or 'default'}"
    kwargs_log = f", extra_args={kwargs}" if kwargs else ""
    logger.info(f"Fetching {data_type_name} data {date_range_log}{kwargs_log}")
    try:
        with baostock_login_context():
            rs = bs_query_func(start_date=start_date,
                               end_date=end_date, **kwargs)

            if rs.error_code != '0':
                logger.error(
                    f"Baostock API error ({data_type_name}): {rs.error_msg} (code: {rs.error_code})")
                if "no record found" in rs.error_msg.lower() or rs.error_code == '10002':
                    raise NoDataFoundError(
                        f"No {data_type_name} data found for the specified criteria. Baostock msg: {rs.error_msg}")
                else:
                    raise DataSourceError(
                        f"Baostock API error fetching {data_type_name} data: {rs.error_msg} (code: {rs.error_code})")

            data_list = []
            while rs.next():
                data_list.append(rs.get_row_data())

            if not data_list:
                logger.warning(
                    f"No {data_type_name} data found for the specified criteria (empty result set).")
                raise NoDataFoundError(
                    f"No {data_type_name} data found for the specified criteria (empty result set).")

            result_df = pd.DataFrame(data_list, columns=rs.fields)
            logger.info(
                f"Retrieved {len(result_df)} {data_type_name} records.")
            return result_df

    except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e:
        logger.warning(
            f"Caught known error fetching {data_type_name} data: {type(e).__name__}")
        # Bare raise keeps the original traceback intact.
        raise
    except Exception as e:
        logger.exception(
            f"Unexpected error fetching {data_type_name} data: {e}")
        # Chain explicitly so the root cause is reported as the direct cause.
        raise DataSourceError(
            f"Unexpected error fetching {data_type_name} data: {e}") from e
177 |
178 |
179 | class BaostockDataSource(FinancialDataSource):
180 | """
181 | Concrete implementation of FinancialDataSource using the Baostock library.
182 | """
183 |
184 | def _format_fields(self, fields: Optional[List[str]], default_fields: List[str]) -> str:
185 | """Formats the list of fields into a comma-separated string for Baostock."""
186 | if fields is None or not fields:
187 | logger.debug(
188 | f"No specific fields requested, using defaults: {default_fields}")
189 | return ",".join(default_fields)
190 | # Basic validation: ensure requested fields are strings
191 | if not all(isinstance(f, str) for f in fields):
192 | raise ValueError("All items in the fields list must be strings.")
193 | logger.debug(f"Using requested fields: {fields}")
194 | return ",".join(fields)
195 |
def get_historical_k_data(
    self,
    code: str,
    start_date: str,
    end_date: str,
    frequency: str = "d",
    adjust_flag: str = "3",
    fields: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Fetches historical K-line data using Baostock.

    Args:
        code: Stock code passed to ``bs.query_history_k_data_plus``.
        start_date: Range start passed through to the query.
        end_date: Range end passed through to the query.
        frequency: Passed through as ``frequency``. Defaults to 'd'.
        adjust_flag: Passed through as ``adjustflag``. Defaults to '3'.
        fields: Field names to request; None or empty uses DEFAULT_K_FIELDS.

    Returns:
        DataFrame built from every row of the result set, with ``rs.fields``
        as column names.

    Raises:
        NoDataFoundError: The error message contains "no record found", the
            error code is '10002', or the result set is empty.
        DataSourceError: Any other non-'0' error code, or an unexpected
            exception (chained as ``__cause__``).
        ValueError: A requested field is not a string.
        LoginError: Re-raised unchanged if it occurs.
    """
    logger.info(
        f"Fetching K-data for {code} ({start_date} to {end_date}), freq={frequency}, adjust={adjust_flag}")
    try:
        formatted_fields = self._format_fields(fields, DEFAULT_K_FIELDS)
        logger.debug(
            f"Requesting fields from Baostock: {formatted_fields}")

        with baostock_login_context():
            rs = bs.query_history_k_data_plus(
                code,
                formatted_fields,
                start_date=start_date,
                end_date=end_date,
                frequency=frequency,
                adjustflag=adjust_flag
            )

            if rs.error_code != '0':
                logger.error(
                    f"Baostock API error (K-data) for {code}: {rs.error_msg} (code: {rs.error_code})")
                # Check common error codes, e.g., for no data
                if "no record found" in rs.error_msg.lower() or rs.error_code == '10002':  # Example error code
                    raise NoDataFoundError(
                        f"No historical data found for {code} in the specified range. Baostock msg: {rs.error_msg}")
                else:
                    raise DataSourceError(
                        f"Baostock API error fetching K-data: {rs.error_msg} (code: {rs.error_code})")

            data_list = []
            while rs.next():
                data_list.append(rs.get_row_data())

            if not data_list:
                logger.warning(
                    f"No historical data found for {code} in range (empty result set from Baostock).")
                raise NoDataFoundError(
                    f"No historical data found for {code} in the specified range (empty result set).")

            # Crucial: Use rs.fields for column names
            result_df = pd.DataFrame(data_list, columns=rs.fields)
            logger.info(f"Retrieved {len(result_df)} records for {code}.")
            return result_df

    except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e:
        # Re-raise known errors; bare raise keeps the original traceback intact
        logger.warning(
            f"Caught known error fetching K-data for {code}: {type(e).__name__}")
        raise
    except Exception as e:
        # Wrap unexpected errors
        # Use logger.exception to include traceback
        logger.exception(
            f"Unexpected error fetching K-data for {code}: {e}")
        # Chain explicitly so the root cause is reported as the direct cause.
        raise DataSourceError(
            f"Unexpected error fetching K-data for {code}: {e}") from e
261 |
def get_stock_basic_info(self, code: str, fields: Optional[List[str]] = None) -> pd.DataFrame:
    """
    Fetches basic stock information using Baostock.

    Args:
        code: Stock code passed to ``bs.query_stock_basic``.
        fields: Optional column names to keep. Selection happens after the
            query; requested names absent from the result are dropped.

    Returns:
        DataFrame built from every row of the result set, with ``rs.fields``
        as column names, narrowed to the requested columns (in request order)
        when ``fields`` is given.

    Raises:
        NoDataFoundError: The error message contains "no record found", the
            error code is '10002', or the result set is empty.
        DataSourceError: Any other non-'0' error code, or an unexpected
            exception (chained as ``__cause__``).
        ValueError: ``fields`` was given but none of its names are in the result.
        LoginError: Re-raised unchanged if it occurs.
    """
    logger.info(f"Fetching basic info for {code}")
    try:
        # Note: query_stock_basic doesn't seem to have a fields parameter in docs,
        # but we keep the signature consistent. It returns a fixed set.
        # We will use the `fields` argument post-query to select columns if needed.
        logger.debug(
            f"Requesting basic info for {code}. Optional fields requested: {fields}")

        with baostock_login_context():
            # Example: Fetch basic info; adjust API call if needed based on baostock docs
            # rs = bs.query_stock_basic(code=code, code_name=code_name) # If supporting name lookup
            rs = bs.query_stock_basic(code=code)

            if rs.error_code != '0':
                logger.error(
                    f"Baostock API error (Basic Info) for {code}: {rs.error_msg} (code: {rs.error_code})")
                if "no record found" in rs.error_msg.lower() or rs.error_code == '10002':
                    raise NoDataFoundError(
                        f"No basic info found for {code}. Baostock msg: {rs.error_msg}")
                else:
                    raise DataSourceError(
                        f"Baostock API error fetching basic info: {rs.error_msg} (code: {rs.error_code})")

            data_list = []
            while rs.next():
                data_list.append(rs.get_row_data())

            if not data_list:
                logger.warning(
                    f"No basic info found for {code} (empty result set from Baostock).")
                raise NoDataFoundError(
                    f"No basic info found for {code} (empty result set).")

            # Crucial: Use rs.fields for column names
            result_df = pd.DataFrame(data_list, columns=rs.fields)
            logger.info(
                f"Retrieved basic info for {code}. Columns: {result_df.columns.tolist()}")

            # Optional: Select subset of columns if `fields` argument was provided
            if fields:
                available_cols = [
                    col for col in fields if col in result_df.columns]
                if not available_cols:
                    raise ValueError(
                        f"None of the requested fields {fields} are available in the basic info result.")
                logger.debug(
                    f"Selecting columns: {available_cols} from basic info for {code}")
                result_df = result_df[available_cols]

            return result_df

    except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e:
        logger.warning(
            f"Caught known error fetching basic info for {code}: {type(e).__name__}")
        # Bare raise keeps the original traceback intact.
        raise
    except Exception as e:
        logger.exception(
            f"Unexpected error fetching basic info for {code}: {e}")
        # Chain explicitly so the root cause is reported as the direct cause.
        raise DataSourceError(
            f"Unexpected error fetching basic info for {code}: {e}") from e
324 |
def get_dividend_data(self, code: str, year: str, year_type: str = "report") -> pd.DataFrame:
    """
    Fetches dividend information using Baostock.

    Args:
        code: Stock code passed to ``bs.query_dividend_data``.
        year: Year passed through to the query.
        year_type: Passed through as ``yearType``. Defaults to "report".

    Returns:
        DataFrame built from every row of the result set, with ``rs.fields``
        as column names.

    Raises:
        NoDataFoundError: The error message contains "no record found", the
            error code is '10002', or the result set is empty.
        DataSourceError: Any other non-'0' error code, or an unexpected
            exception (chained as ``__cause__``).
        LoginError, ValueError: Re-raised unchanged if they occur.
    """
    logger.info(
        f"Fetching dividend data for {code}, year={year}, year_type={year_type}")
    try:
        with baostock_login_context():
            rs = bs.query_dividend_data(
                code=code, year=year, yearType=year_type)

            if rs.error_code != '0':
                logger.error(
                    f"Baostock API error (Dividend) for {code}: {rs.error_msg} (code: {rs.error_code})")
                if "no record found" in rs.error_msg.lower() or rs.error_code == '10002':
                    raise NoDataFoundError(
                        f"No dividend data found for {code} and year {year}. Baostock msg: {rs.error_msg}")
                else:
                    raise DataSourceError(
                        f"Baostock API error fetching dividend data: {rs.error_msg} (code: {rs.error_code})")

            data_list = []
            while rs.next():
                data_list.append(rs.get_row_data())

            if not data_list:
                logger.warning(
                    f"No dividend data found for {code}, year {year} (empty result set from Baostock).")
                raise NoDataFoundError(
                    f"No dividend data found for {code}, year {year} (empty result set).")

            result_df = pd.DataFrame(data_list, columns=rs.fields)
            logger.info(
                f"Retrieved {len(result_df)} dividend records for {code}, year {year}.")
            return result_df

    except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e:
        logger.warning(
            f"Caught known error fetching dividend data for {code}: {type(e).__name__}")
        # Bare raise keeps the original traceback intact.
        raise
    except Exception as e:
        logger.exception(
            f"Unexpected error fetching dividend data for {code}: {e}")
        # Chain explicitly so the root cause is reported as the direct cause.
        raise DataSourceError(
            f"Unexpected error fetching dividend data for {code}: {e}") from e
368 |
def get_adjust_factor_data(self, code: str, start_date: str, end_date: str) -> pd.DataFrame:
    """
    Fetches adjustment factor data using Baostock.

    Args:
        code: Stock code passed to ``bs.query_adjust_factor``.
        start_date: Range start passed through to the query.
        end_date: Range end passed through to the query.

    Returns:
        DataFrame built from every row of the result set, with ``rs.fields``
        as column names.

    Raises:
        NoDataFoundError: The error message contains "no record found", the
            error code is '10002', or the result set is empty.
        DataSourceError: Any other non-'0' error code, or an unexpected
            exception (chained as ``__cause__``).
        LoginError, ValueError: Re-raised unchanged if they occur.
    """
    logger.info(
        f"Fetching adjustment factor data for {code} ({start_date} to {end_date})")
    try:
        with baostock_login_context():
            rs = bs.query_adjust_factor(
                code=code, start_date=start_date, end_date=end_date)

            if rs.error_code != '0':
                logger.error(
                    f"Baostock API error (Adjust Factor) for {code}: {rs.error_msg} (code: {rs.error_code})")
                if "no record found" in rs.error_msg.lower() or rs.error_code == '10002':
                    raise NoDataFoundError(
                        f"No adjustment factor data found for {code} in the specified range. Baostock msg: {rs.error_msg}")
                else:
                    raise DataSourceError(
                        f"Baostock API error fetching adjust factor data: {rs.error_msg} (code: {rs.error_code})")

            data_list = []
            while rs.next():
                data_list.append(rs.get_row_data())

            if not data_list:
                logger.warning(
                    f"No adjustment factor data found for {code} in range (empty result set from Baostock).")
                raise NoDataFoundError(
                    f"No adjustment factor data found for {code} in the specified range (empty result set).")

            result_df = pd.DataFrame(data_list, columns=rs.fields)
            logger.info(
                f"Retrieved {len(result_df)} adjustment factor records for {code}.")
            return result_df

    except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e:
        logger.warning(
            f"Caught known error fetching adjust factor data for {code}: {type(e).__name__}")
        # Bare raise keeps the original traceback intact.
        raise
    except Exception as e:
        logger.exception(
            f"Unexpected error fetching adjust factor data for {code}: {e}")
        # Chain explicitly so the root cause is reported as the direct cause.
        raise DataSourceError(
            f"Unexpected error fetching adjust factor data for {code}: {e}") from e
412 |
def get_profit_data(self, code: str, year: str, quarter: int) -> pd.DataFrame:
    """Return quarterly profitability data by delegating to the shared financial-data helper."""
    return _fetch_financial_data(
        bs_query_func=bs.query_profit_data,
        data_type_name="Profitability",
        code=code,
        year=year,
        quarter=quarter,
    )
416 |
def get_operation_data(self, code: str, year: str, quarter: int) -> pd.DataFrame:
    """Return quarterly operation capability data by delegating to the shared financial-data helper."""
    return _fetch_financial_data(
        bs_query_func=bs.query_operation_data,
        data_type_name="Operation Capability",
        code=code,
        year=year,
        quarter=quarter,
    )
420 |
def get_growth_data(self, code: str, year: str, quarter: int) -> pd.DataFrame:
    """Return quarterly growth capability data by delegating to the shared financial-data helper."""
    return _fetch_financial_data(
        bs_query_func=bs.query_growth_data,
        data_type_name="Growth Capability",
        code=code,
        year=year,
        quarter=quarter,
    )
424 |
def get_balance_data(self, code: str, year: str, quarter: int) -> pd.DataFrame:
    """Return quarterly balance sheet (solvency) data by delegating to the shared financial-data helper."""
    return _fetch_financial_data(
        bs_query_func=bs.query_balance_data,
        data_type_name="Balance Sheet",
        code=code,
        year=year,
        quarter=quarter,
    )
428 |
def get_cash_flow_data(self, code: str, year: str, quarter: int) -> pd.DataFrame:
    """Return quarterly cash flow data by delegating to the shared financial-data helper."""
    return _fetch_financial_data(
        bs_query_func=bs.query_cash_flow_data,
        data_type_name="Cash Flow",
        code=code,
        year=year,
        quarter=quarter,
    )
432 |
def get_dupont_data(self, code: str, year: str, quarter: int) -> pd.DataFrame:
    """Return quarterly DuPont analysis data by delegating to the shared financial-data helper."""
    return _fetch_financial_data(
        bs_query_func=bs.query_dupont_data,
        data_type_name="DuPont Analysis",
        code=code,
        year=year,
        quarter=quarter,
    )
436 |
def get_performance_express_report(self, code: str, start_date: str, end_date: str) -> pd.DataFrame:
    """
    Fetches performance express reports (业绩快报) using Baostock.

    Args:
        code: Stock code passed to ``bs.query_performance_express_report``.
        start_date: Range start passed through to the query.
        end_date: Range end passed through to the query.

    Returns:
        DataFrame built from every row of the result set, with ``rs.fields``
        as column names.

    Raises:
        NoDataFoundError: The error message contains "no record found", the
            error code is '10002', or the result set is empty.
        DataSourceError: Any other non-'0' error code, or an unexpected
            exception (chained as ``__cause__``).
        LoginError, ValueError: Re-raised unchanged if they occur.
    """
    logger.info(
        f"Fetching Performance Express Report for {code} ({start_date} to {end_date})")
    try:
        with baostock_login_context():
            rs = bs.query_performance_express_report(
                code=code, start_date=start_date, end_date=end_date)

            if rs.error_code != '0':
                logger.error(
                    f"Baostock API error (Perf Express) for {code}: {rs.error_msg} (code: {rs.error_code})")
                if "no record found" in rs.error_msg.lower() or rs.error_code == '10002':
                    raise NoDataFoundError(
                        f"No performance express report found for {code} in range {start_date}-{end_date}. Baostock msg: {rs.error_msg}")
                else:
                    raise DataSourceError(
                        f"Baostock API error fetching performance express report: {rs.error_msg} (code: {rs.error_code})")

            data_list = []
            while rs.next():
                data_list.append(rs.get_row_data())

            if not data_list:
                logger.warning(
                    f"No performance express report found for {code} in range {start_date}-{end_date} (empty result set).")
                raise NoDataFoundError(
                    f"No performance express report found for {code} in range {start_date}-{end_date} (empty result set).")

            result_df = pd.DataFrame(data_list, columns=rs.fields)
            logger.info(
                f"Retrieved {len(result_df)} performance express report records for {code}.")
            return result_df

    except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e:
        logger.warning(
            f"Caught known error fetching performance express report for {code}: {type(e).__name__}")
        # Bare raise keeps the original traceback intact.
        raise
    except Exception as e:
        logger.exception(
            f"Unexpected error fetching performance express report for {code}: {e}")
        # Chain explicitly so the root cause is reported as the direct cause.
        raise DataSourceError(
            f"Unexpected error fetching performance express report for {code}: {e}") from e
480 |
def get_forecast_report(self, code: str, start_date: str, end_date: str) -> pd.DataFrame:
    """
    Fetches performance forecast reports (业绩预告) using Baostock.

    Args:
        code: Stock code passed to ``bs.query_forecast_report``.
        start_date: Range start passed through to the query.
        end_date: Range end passed through to the query.

    Returns:
        DataFrame built from every row of the result set, with ``rs.fields``
        as column names.

    Raises:
        NoDataFoundError: The error message contains "no record found", the
            error code is '10002', or the result set is empty.
        DataSourceError: Any other non-'0' error code, or an unexpected
            exception (chained as ``__cause__``).
        LoginError, ValueError: Re-raised unchanged if they occur.
    """
    logger.info(
        f"Fetching Performance Forecast Report for {code} ({start_date} to {end_date})")
    try:
        with baostock_login_context():
            rs = bs.query_forecast_report(
                code=code, start_date=start_date, end_date=end_date)
            # Note: Baostock docs mention pagination for this, but the Python API doesn't seem to expose it directly.
            # We fetch all available pages in the loop below.

            if rs.error_code != '0':
                logger.error(
                    f"Baostock API error (Forecast) for {code}: {rs.error_msg} (code: {rs.error_code})")
                if "no record found" in rs.error_msg.lower() or rs.error_code == '10002':
                    raise NoDataFoundError(
                        f"No performance forecast report found for {code} in range {start_date}-{end_date}. Baostock msg: {rs.error_msg}")
                else:
                    raise DataSourceError(
                        f"Baostock API error fetching performance forecast report: {rs.error_msg} (code: {rs.error_code})")

            data_list = []
            while rs.next():  # Loop should handle pagination implicitly if rs manages it
                data_list.append(rs.get_row_data())

            if not data_list:
                logger.warning(
                    f"No performance forecast report found for {code} in range {start_date}-{end_date} (empty result set).")
                raise NoDataFoundError(
                    f"No performance forecast report found for {code} in range {start_date}-{end_date} (empty result set).")

            result_df = pd.DataFrame(data_list, columns=rs.fields)
            logger.info(
                f"Retrieved {len(result_df)} performance forecast report records for {code}.")
            return result_df

    except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e:
        logger.warning(
            f"Caught known error fetching performance forecast report for {code}: {type(e).__name__}")
        # Bare raise keeps the original traceback intact.
        raise
    except Exception as e:
        logger.exception(
            f"Unexpected error fetching performance forecast report for {code}: {e}")
        # Chain explicitly so the root cause is reported as the direct cause.
        raise DataSourceError(
            f"Unexpected error fetching performance forecast report for {code}: {e}") from e
526 |
def get_stock_industry(self, code: Optional[str] = None, date: Optional[str] = None) -> pd.DataFrame:
    """Fetch industry classification data using Baostock.

    Args:
        code: Optional stock code forwarded to ``bs.query_stock_industry``.
            Logged as 'all' when not provided.
        date: Optional query date forwarded to ``bs.query_stock_industry``.
            Logged as 'latest' when not provided.

    Returns:
        A DataFrame with one row per returned record; column names come
        from ``rs.fields``.

    Raises:
        NoDataFoundError: Baostock reported "no record found" / error code
            '10002', or the result set contained no rows.
        DataSourceError: Any other Baostock API error, or an unexpected
            exception (the original exception is chained as the cause).
        LoginError, ValueError: Propagated unchanged when raised inside the
            query block.
    """
    log_msg = f"Fetching industry data for code={code or 'all'}, date={date or 'latest'}"
    logger.info(log_msg)
    try:
        with baostock_login_context():
            rs = bs.query_stock_industry(code=code, date=date)

            if rs.error_code != '0':
                logger.error(
                    f"Baostock API error (Industry) for {code}, {date}: {rs.error_msg} (code: {rs.error_code})")
                if "no record found" in rs.error_msg.lower() or rs.error_code == '10002':
                    raise NoDataFoundError(
                        f"No industry data found for {code}, {date}. Baostock msg: {rs.error_msg}")
                raise DataSourceError(
                    f"Baostock API error fetching industry data: {rs.error_msg} (code: {rs.error_code})")

            # Drain the result set while the login context is still open.
            data_list = []
            while rs.next():
                data_list.append(rs.get_row_data())

            if not data_list:
                logger.warning(
                    f"No industry data found for {code}, {date} (empty result set).")
                raise NoDataFoundError(
                    f"No industry data found for {code}, {date} (empty result set).")

            result_df = pd.DataFrame(data_list, columns=rs.fields)
            logger.info(
                f"Retrieved {len(result_df)} industry records for {code or 'all'}, {date or 'latest'}.")
            return result_df

    except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e:
        logger.warning(
            f"Caught known error fetching industry data for {code}, {date}: {type(e).__name__}")
        # Bare raise re-raises the active exception without adding a frame.
        raise
    except Exception as e:
        logger.exception(
            f"Unexpected error fetching industry data for {code}, {date}: {e}")
        # Chain explicitly so the root cause survives the wrapping.
        raise DataSourceError(
            f"Unexpected error fetching industry data for {code}, {date}: {e}") from e
569 |
def get_sz50_stocks(self, date: Optional[str] = None) -> pd.DataFrame:
    """Fetch SSE 50 index constituents using Baostock.

    Baostock's ``query_sz50_stocks`` returns the Shanghai Stock Exchange 50
    index ("sz" abbreviates the Shanghai index name, not Shenzhen/SZSE), so
    the label handed to the shared helper is "SSE 50".

    Args:
        date: Optional query date forwarded to the shared helper.

    Returns:
        Whatever ``_fetch_index_constituent_data`` returns for this query.
    """
    return _fetch_index_constituent_data(bs.query_sz50_stocks, "SSE 50", date)
573 |
def get_hs300_stocks(self, date: Optional[str] = None) -> pd.DataFrame:
    """Return CSI 300 index constituents via the shared constituent helper.

    Args:
        date: Optional query date, forwarded to the helper unchanged.
    """
    query_fn = bs.query_hs300_stocks
    constituents = _fetch_index_constituent_data(query_fn, "CSI 300", date)
    return constituents
577 |
def get_zz500_stocks(self, date: Optional[str] = None) -> pd.DataFrame:
    """Return CSI 500 index constituents via the shared constituent helper.

    Args:
        date: Optional query date, forwarded to the helper unchanged.
    """
    query_fn = bs.query_zz500_stocks
    constituents = _fetch_index_constituent_data(query_fn, "CSI 500", date)
    return constituents
581 |
def get_trade_dates(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame:
    """Fetch trading-calendar dates using Baostock.

    Args:
        start_date: Optional range start forwarded to ``bs.query_trade_dates``.
            Logged as 'default' when not provided.
        end_date: Optional range end forwarded to ``bs.query_trade_dates``.
            Logged as 'default' when not provided.

    Returns:
        A DataFrame with one row per returned record; column names come
        from ``rs.fields``.

    Raises:
        NoDataFoundError: The result set contained no rows.
        DataSourceError: Baostock returned a non-'0' error code, or an
            unexpected exception occurred (chained as the cause).
        LoginError, ValueError: Propagated unchanged when raised inside the
            query block.
    """
    logger.info(
        f"Fetching trade dates from {start_date or 'default'} to {end_date or 'default'}")
    try:
        # Login might not be strictly needed for this query, but it is kept
        # for consistency with the other fetch methods.
        with baostock_login_context():
            rs = bs.query_trade_dates(
                start_date=start_date, end_date=end_date)

            if rs.error_code != '0':
                logger.error(
                    f"Baostock API error (Trade Dates): {rs.error_msg} (code: {rs.error_code})")
                # Unlikely to see 'no record found' for dates, so every API
                # error is reported as a DataSourceError.
                raise DataSourceError(
                    f"Baostock API error fetching trade dates: {rs.error_msg} (code: {rs.error_code})")

            # Drain the result set while the login context is still open.
            data_list = []
            while rs.next():
                data_list.append(rs.get_row_data())

            if not data_list:
                # Should ideally not happen if the API returns a valid range.
                logger.warning(
                    f"No trade dates returned for range {start_date}-{end_date} (empty result set).")
                raise NoDataFoundError(
                    f"No trade dates found for range {start_date}-{end_date} (empty result set).")

            result_df = pd.DataFrame(data_list, columns=rs.fields)
            logger.info(f"Retrieved {len(result_df)} trade date records.")
            return result_df

    except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e:
        logger.warning(
            f"Caught known error fetching trade dates: {type(e).__name__}")
        # Bare raise re-raises the active exception without adding a frame.
        raise
    except Exception as e:
        logger.exception(f"Unexpected error fetching trade dates: {e}")
        # Chain explicitly so the root cause survives the wrapping.
        raise DataSourceError(
            f"Unexpected error fetching trade dates: {e}") from e
621 |
def get_all_stock(self, date: Optional[str] = None) -> pd.DataFrame:
    """Fetch the full stock list for a given date using Baostock.

    Args:
        date: Optional query date, passed to ``bs.query_all_stock`` as its
            ``day`` argument. Logged as 'default' when not provided.

    Returns:
        A DataFrame with one row per returned record; column names come
        from ``rs.fields``.

    Raises:
        NoDataFoundError: Baostock reported "no record found" / error code
            '10002', or the result set contained no rows.
        DataSourceError: Any other Baostock API error, or an unexpected
            exception (the original exception is chained as the cause).
        LoginError, ValueError: Propagated unchanged when raised inside the
            query block.
    """
    logger.info(f"Fetching all stock list for date={date or 'default'}")
    try:
        with baostock_login_context():
            rs = bs.query_all_stock(day=date)

            if rs.error_code != '0':
                logger.error(
                    f"Baostock API error (All Stock) for date {date}: {rs.error_msg} (code: {rs.error_code})")
                # NOTE(review): whether the 'no record found' / '10002' signal
                # applies to this endpoint is unverified -- confirm.
                if "no record found" in rs.error_msg.lower() or rs.error_code == '10002':
                    raise NoDataFoundError(
                        f"No stock data found for date {date}. Baostock msg: {rs.error_msg}")
                raise DataSourceError(
                    f"Baostock API error fetching all stock list: {rs.error_msg} (code: {rs.error_code})")

            # Drain the result set while the login context is still open.
            data_list = []
            while rs.next():
                data_list.append(rs.get_row_data())

            if not data_list:
                logger.warning(
                    f"No stock list returned for date {date} (empty result set).")
                raise NoDataFoundError(
                    f"No stock list found for date {date} (empty result set).")

            result_df = pd.DataFrame(data_list, columns=rs.fields)
            logger.info(
                f"Retrieved {len(result_df)} stock records for date {date or 'default'}.")
            return result_df

    except (LoginError, NoDataFoundError, DataSourceError, ValueError) as e:
        logger.warning(
            f"Caught known error fetching all stock list for date {date}: {type(e).__name__}")
        # Bare raise re-raises the active exception without adding a frame.
        raise
    except Exception as e:
        logger.exception(
            f"Unexpected error fetching all stock list for date {date}: {e}")
        # Chain explicitly so the root cause survives the wrapping.
        raise DataSourceError(
            f"Unexpected error fetching all stock list for date {date}: {e}") from e
663 |
def get_deposit_rate_data(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame:
    """Return benchmark deposit rate data via the shared macro-data helper.

    Args:
        start_date: Optional range start, forwarded to the helper unchanged.
        end_date: Optional range end, forwarded to the helper unchanged.
    """
    query_fn = bs.query_deposit_rate_data
    rates = _fetch_macro_data(query_fn, "Deposit Rate", start_date, end_date)
    return rates
667 |
def get_loan_rate_data(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame:
    """Return benchmark loan rate data via the shared macro-data helper.

    Args:
        start_date: Optional range start, forwarded to the helper unchanged.
        end_date: Optional range end, forwarded to the helper unchanged.
    """
    query_fn = bs.query_loan_rate_data
    rates = _fetch_macro_data(query_fn, "Loan Rate", start_date, end_date)
    return rates
671 |
def get_required_reserve_ratio_data(self, start_date: Optional[str] = None, end_date: Optional[str] = None, year_type: str = '0') -> pd.DataFrame:
    """Return required reserve ratio data via the shared macro-data helper.

    Args:
        start_date: Optional range start, forwarded to the helper unchanged.
        end_date: Optional range end, forwarded to the helper unchanged.
        year_type: Forwarded to the helper as the ``yearType`` keyword
            argument; defaults to '0'.
    """
    # The helper receives the extra Baostock parameter as a keyword argument.
    extra_kwargs = {"yearType": year_type}
    return _fetch_macro_data(
        bs.query_required_reserve_ratio_data,
        "Required Reserve Ratio",
        start_date,
        end_date,
        **extra_kwargs,
    )
676 |
def get_money_supply_data_month(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame:
    """Return monthly money supply data (M0, M1, M2) via the shared macro-data helper.

    Args:
        start_date: Optional range start, forwarded to the helper unchanged.
        end_date: Optional range end, forwarded to the helper unchanged.
    """
    # Dates are passed through as-is; Baostock expects YYYY-MM here.
    query_fn = bs.query_money_supply_data_month
    supply = _fetch_macro_data(query_fn, "Monthly Money Supply", start_date, end_date)
    return supply
681 |
def get_money_supply_data_year(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame:
    """Return yearly money supply data (M0, M1, M2 year-end balance) via the shared macro-data helper.

    Args:
        start_date: Optional range start, forwarded to the helper unchanged.
        end_date: Optional range end, forwarded to the helper unchanged.
    """
    # Dates are passed through as-is; Baostock expects YYYY here.
    query_fn = bs.query_money_supply_data_year
    supply = _fetch_macro_data(query_fn, "Yearly Money Supply", start_date, end_date)
    return supply
686 |
687 | # Note: SHIBOR is not available in current Baostock API bindings used; not implemented.
688 |
```