From 2220b5084e3f6dea3975dfd4ab5ebcdbd448cddd Mon Sep 17 00:00:00 2001 From: sam Date: Thu, 16 Oct 2025 09:54:55 +0800 Subject: [PATCH] refactor decision workflow and optimize feature snapshot loading --- app/agents/game.py | 492 +++-- app/backtest/engine.py | 30 +- app/features/factors.py | 58 +- app/ingest/api_client.py | 1387 +++++++++++++ app/ingest/coverage.py | 395 ++++ app/ingest/tushare.py | 1838 +---------------- app/utils/data_access.py | 98 +- app/utils/feature_snapshots.py | 58 + docs/TODO.md | 2 +- .../business_logic_healthcheck.md | 50 + 10 files changed, 2331 insertions(+), 2077 deletions(-) create mode 100644 app/ingest/api_client.py create mode 100644 app/ingest/coverage.py create mode 100644 app/utils/feature_snapshots.py create mode 100644 docs/architecture/business_logic_healthcheck.md diff --git a/app/agents/game.py b/app/agents/game.py index 6c329f8..597de6b 100644 --- a/app/agents/game.py +++ b/app/agents/game.py @@ -147,79 +147,117 @@ def target_weight_for_action(action: AgentAction) -> float: return mapping[action] -def decide( - context: AgentContext, - agents: Iterable[Agent], - weights: Mapping[str, float], - method: str = "nash", - department_manager: Optional[DepartmentManager] = None, - department_context: Optional[DepartmentContext] = None, -) -> Decision: - agent_list = list(agents) - utilities = compute_utilities(agent_list, context) - feas_actions = feasible_actions(agent_list, context) - if not feas_actions: - return Decision( - action=AgentAction.HOLD, - confidence=0.0, - target_weight=0.0, - feasible_actions=[], - utilities=utilities, +class DecisionWorkflow: + def __init__( + self, + context: AgentContext, + agents: Iterable[Agent], + weights: Mapping[str, float], + method: str, + department_manager: Optional[DepartmentManager], + department_context: Optional[DepartmentContext], + ) -> None: + self.context = context + self.agent_list = list(agents) + self.method = method + self.department_manager = department_manager + self.department_context = department_context + self.utilities = compute_utilities(self.agent_list, context) + self.feasible_actions = feasible_actions(self.agent_list, context) + self.raw_weights = dict(weights) + self.department_decisions: Dict[str, DepartmentDecision] = {} + self.department_votes: Dict[str, float] = {} + self.host = ProtocolHost() + self.host_trace = self.host.bootstrap_trace( + session_id=f"{context.ts_code}:{context.trade_date}", + ts_code=context.ts_code, + trade_date=context.trade_date, + ) + self.briefing_round = self.host.start_round( + self.host_trace, + agenda="situation_briefing", + structure=GameStructure.SIGNALING, + ) + self.host.handle_message(self.briefing_round, _host_briefing_message(context)) + self.host.finalize_round(self.briefing_round) + self.department_round: Optional[RoundSummary] = None + self.risk_round: Optional[RoundSummary] = None + self.execution_round: Optional[RoundSummary] = None + self.belief_updates: Dict[str, BeliefUpdate] = {} + self.prediction_round: Optional[RoundSummary] = None + self.norm_weights: Dict[str, float] = {} + self.filtered_utilities: Dict[AgentAction, Dict[str, float]] = {} + self.belief_revision: Optional[BeliefRevisionResult] = None + + def run(self) -> Decision: + if not self.feasible_actions: + return Decision( + action=AgentAction.HOLD, + confidence=0.0, + target_weight=0.0, + feasible_actions=[], + utilities=self.utilities, + ) + + self._evaluate_departments() + action, confidence = self._select_action() + risk_assessment = self._apply_risk(action) + 
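As an aside on the refactor above: the monolithic `decide()` body becomes a `DecisionWorkflow` whose constructor computes the shared state once and whose `run()` sequences the phases. A minimal standalone sketch of that pattern — all names here (`Pipeline`, `_select`, `_apply_risk`) are illustrative, not from the codebase:

```python
# Minimal sketch: constructor captures shared state, run() sequences phases.
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class Pipeline:
    utilities: Dict[str, float]
    feasible: List[str]
    notes: Dict[str, str] = field(default_factory=dict)

    def run(self) -> str:
        if not self.feasible:          # early exit mirrors the HOLD fallback
            return "HOLD"
        action = self._select()
        return self._apply_risk(action)

    def _select(self) -> str:
        # pick the feasible action with the highest utility
        return max(self.feasible, key=lambda a: self.utilities.get(a, 0.0))

    def _apply_risk(self, action: str) -> str:
        # a later phase may veto the selection and record why
        if self.utilities.get(action, 0.0) < 0:
            self.notes["risk"] = f"{action} vetoed"
            return "HOLD"
        return action

print(Pipeline({"BUY": 0.4, "SELL": -0.2}, ["BUY", "SELL"]).run())  # BUY
```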
exec_action = self._finalize_execution(action, risk_assessment) + self._finalize_conflicts(exec_action) + rounds = self.host_trace.rounds or _build_round_summaries( + self.department_decisions, + action, + self.department_votes, ) - raw_weights = dict(weights) - department_decisions: Dict[str, DepartmentDecision] = {} - department_votes: Dict[str, float] = {} - host = ProtocolHost() - host_trace = host.bootstrap_trace( - session_id=f"{context.ts_code}:{context.trade_date}", - ts_code=context.ts_code, - trade_date=context.trade_date, - ) - briefing_round = host.start_round( - host_trace, - agenda="situation_briefing", - structure=GameStructure.SIGNALING, - ) - host.handle_message(briefing_round, _host_briefing_message(context)) - host.finalize_round(briefing_round) - department_round: Optional[RoundSummary] = None - risk_round: Optional[RoundSummary] = None - execution_round: Optional[RoundSummary] = None - belief_updates: Dict[str, BeliefUpdate] = {} + return Decision( + action=action, + confidence=confidence, + target_weight=target_weight_for_action(action), + feasible_actions=self.feasible_actions, + utilities=self.utilities, + department_decisions=self.department_decisions, + department_votes=self.department_votes, + requires_review=risk_assessment.status != "ok", + rounds=rounds, + risk_assessment=risk_assessment, + belief_updates=self.belief_updates, + belief_revision=self.belief_revision, + ) - if department_manager: - dept_context = department_context - if dept_context is None: - dept_context = DepartmentContext( - ts_code=context.ts_code, - trade_date=context.trade_date, - features=dict(context.features), - market_snapshot=dict(getattr(context, "market_snapshot", {}) or {}), - raw=dict(getattr(context, "raw", {}) or {}), - ) - department_decisions = department_manager.evaluate(dept_context) - if department_decisions: - department_round = host.start_round( - host_trace, + def _evaluate_departments(self) -> None: + if not self.department_manager: + return + + dept_context = self.department_context or DepartmentContext( + ts_code=self.context.ts_code, + trade_date=self.context.trade_date, + features=dict(self.context.features), + market_snapshot=dict(getattr(self.context, "market_snapshot", {}) or {}), + raw=dict(getattr(self.context, "raw", {}) or {}), + ) + self.department_decisions = self.department_manager.evaluate(dept_context) + if self.department_decisions: + self.department_round = self.host.start_round( + self.host_trace, agenda="department_consensus", structure=GameStructure.REPEATED, ) - for code, decision in department_decisions.items(): + for code, decision in self.department_decisions.items(): agent_key = f"dept_{code}" - dept_agent = department_manager.agents.get(code) + dept_agent = self.department_manager.agents.get(code) weight = dept_agent.settings.weight if dept_agent else 1.0 - raw_weights[agent_key] = weight + self.raw_weights[agent_key] = weight scores = _department_scores(decision) for action in ACTIONS: - utilities.setdefault(action, {})[agent_key] = scores[action] + self.utilities.setdefault(action, {})[agent_key] = scores[action] bucket = _department_vote_bucket(decision.action) if bucket: - department_votes[bucket] = department_votes.get(bucket, 0.0) + weight * decision.confidence - if department_round: + self.department_votes[bucket] = self.department_votes.get(bucket, 0.0) + weight * decision.confidence + if self.department_round: message = _department_message(code, decision) - host.handle_message(department_round, message) - belief_updates[code] = 
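`_evaluate_departments` folds each department's view into the shared vote buckets as `weight * confidence`. A hedged standalone sketch of that accumulation (`accumulate_votes` and its inputs are hypothetical):

```python
# Each department contributes weight * confidence to its action bucket.
from typing import Dict, Tuple

def accumulate_votes(decisions: Dict[str, Tuple[str, float]],
                     weights: Dict[str, float]) -> Dict[str, float]:
    """decisions maps department code -> (vote bucket, confidence)."""
    votes: Dict[str, float] = {}
    for code, (bucket, confidence) in decisions.items():
        weight = weights.get(code, 1.0)   # default weight mirrors the code above
        votes[bucket] = votes.get(bucket, 0.0) + weight * confidence
    return votes

print(accumulate_votes(
    {"macro": ("bullish", 0.8), "quant": ("bullish", 0.6), "risk": ("bearish", 0.9)},
    {"macro": 1.0, "quant": 0.5, "risk": 1.0},
))  # {'bullish': 1.1, 'bearish': 0.9}
```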
BeliefUpdate( + self.host.handle_message(self.department_round, message) + self.belief_updates[code] = BeliefUpdate( belief={ "action": decision.action.value, "confidence": decision.confidence, @@ -228,173 +266,185 @@ def decide( rationale=decision.summary, ) - filtered_utilities = {action: utilities[action] for action in feas_actions} - hold_scores = utilities.get(AgentAction.HOLD, {}) - norm_weights = weight_map(raw_weights) - prediction_round = host.start_round( - host_trace, - agenda="prediction_alignment", - structure=GameStructure.REPEATED, - ) - prediction_message, prediction_summary = _prediction_summary_message(filtered_utilities, norm_weights) - host.handle_message(prediction_round, prediction_message) - host.finalize_round(prediction_round) - if prediction_summary: - belief_updates["prediction_summary"] = BeliefUpdate( - belief=prediction_summary, - rationale="Aggregated utilities shared during alignment round.", + def _select_action(self) -> Tuple[AgentAction, float]: + self.filtered_utilities = {action: self.utilities[action] for action in self.feasible_actions} + hold_scores = self.utilities.get(AgentAction.HOLD, {}) + self.norm_weights = weight_map(self.raw_weights) + self.prediction_round = self.host.start_round( + self.host_trace, + agenda="prediction_alignment", + structure=GameStructure.REPEATED, ) - - if method == "vote": - action, confidence = vote(filtered_utilities, norm_weights) - else: - action, confidence = nash_bargain(filtered_utilities, norm_weights, hold_scores) - if action not in feas_actions: - action, confidence = vote(filtered_utilities, norm_weights) - - weight = target_weight_for_action(action) - conflict_flag = _department_conflict_flag(department_votes) - - risk_agent = _find_risk_agent(agent_list) - risk_assessment = _evaluate_risk( - context, - action, - department_votes, - conflict_flag, - risk_agent, - ) - requires_review = risk_assessment.status != "ok" - - if department_round: - department_round.notes.setdefault("department_votes", dict(department_votes)) - department_round.outcome = action.value - host.finalize_round(department_round) - - if requires_review: - risk_round = host.ensure_round( - host_trace, - agenda="risk_review", - structure=GameStructure.CUSTOM, - ) - review_message = DialogueMessage( - sender="risk_guard", - role=DialogueRole.RISK, - message_type=MessageType.COUNTER, - content=_risk_review_message(risk_assessment.reason), - confidence=1.0, - references=list(department_votes.keys()), - annotations={ - "department_votes": dict(department_votes), - "risk_reason": risk_assessment.reason, - "recommended_action": ( - risk_assessment.recommended_action.value - if risk_assessment.recommended_action - else None - ), - "notes": dict(risk_assessment.notes), - }, - ) - host.handle_message(risk_round, review_message) - risk_round.notes.setdefault("status", risk_assessment.status) - risk_round.notes.setdefault("reason", risk_assessment.reason) - if risk_assessment.recommended_action: - risk_round.notes.setdefault( - "recommended_action", - risk_assessment.recommended_action.value, + prediction_message, prediction_summary = _prediction_summary_message(self.filtered_utilities, self.norm_weights) + self.host.handle_message(self.prediction_round, prediction_message) + self.host.finalize_round(self.prediction_round) + if prediction_summary: + self.belief_updates["prediction_summary"] = BeliefUpdate( + belief=prediction_summary, + rationale="Aggregated utilities shared during alignment round.", ) - risk_round.outcome = "REVIEW" - 
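`_select_action` keeps the original fallback semantics: bargain first, drop back to the weighted vote when the bargained action is infeasible. A simplified sketch with the per-agent utility tables collapsed to one score per action (`nash_pick`/`vote_pick` are stand-ins for `nash_bargain`/`vote`):

```python
# Bargain first; fall back to voting when the bargained action is infeasible.
from typing import Dict, List, Tuple

def vote_pick(utils: Dict[str, float]) -> Tuple[str, float]:
    action = max(utils, key=utils.get)
    return action, utils[action]

def nash_pick(utils: Dict[str, float], hold: float) -> Tuple[str, float]:
    # gains over the HOLD disagreement point, maximised per action
    gains = {a: max(u - hold, 0.0) for a, u in utils.items()}
    action = max(gains, key=gains.get)
    return action, gains[action]

def select(method: str, utils: Dict[str, float], feasible: List[str]) -> Tuple[str, float]:
    if method == "vote":
        return vote_pick(utils)
    action, conf = nash_pick(utils, hold=utils.get("HOLD", 0.0))
    if action not in feasible:            # bargained action infeasible -> vote
        return vote_pick({a: utils[a] for a in feasible})
    return action, conf

print(select("nash", {"BUY": 0.7, "HOLD": 0.2, "SELL": 0.1}, ["HOLD", "SELL"]))
# ('HOLD', 0.2)
```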
host.finalize_round(risk_round) - belief_updates["risk_guard"] = BeliefUpdate( - belief={ - "status": risk_assessment.status, - "reason": risk_assessment.reason, - "recommended_action": ( - risk_assessment.recommended_action.value - if risk_assessment.recommended_action - else None - ), + + if self.method == "vote": + return vote(self.filtered_utilities, self.norm_weights) + + action, confidence = nash_bargain(self.filtered_utilities, self.norm_weights, hold_scores) + if action not in self.feasible_actions: + return vote(self.filtered_utilities, self.norm_weights) + return action, confidence + + def _apply_risk(self, action: AgentAction) -> RiskAssessment: + conflict_flag = _department_conflict_flag(self.department_votes) + risk_agent = _find_risk_agent(self.agent_list) + assessment = _evaluate_risk( + self.context, + action, + self.department_votes, + conflict_flag, + risk_agent, + ) + if self.department_round: + self.department_round.notes.setdefault("department_votes", dict(self.department_votes)) + self.department_round.outcome = action.value + self.host.finalize_round(self.department_round) + + if assessment.status != "ok": + self.risk_round = self.host.ensure_round( + self.host_trace, + agenda="risk_review", + structure=GameStructure.CUSTOM, + ) + review_message = DialogueMessage( + sender="risk_guard", + role=DialogueRole.RISK, + message_type=MessageType.COUNTER, + content=_risk_review_message(assessment.reason), + confidence=1.0, + references=list(self.department_votes.keys()), + annotations={ + "department_votes": dict(self.department_votes), + "risk_reason": assessment.reason, + "recommended_action": ( + assessment.recommended_action.value + if assessment.recommended_action + else None + ), + "notes": dict(assessment.notes), + }, + ) + self.host.handle_message(self.risk_round, review_message) + self.risk_round.notes.setdefault("status", assessment.status) + self.risk_round.notes.setdefault("reason", assessment.reason) + if assessment.recommended_action: + self.risk_round.notes.setdefault( + "recommended_action", + assessment.recommended_action.value, + ) + self.risk_round.outcome = "REVIEW" + self.host.finalize_round(self.risk_round) + self.belief_updates["risk_guard"] = BeliefUpdate( + belief={ + "status": assessment.status, + "reason": assessment.reason, + "recommended_action": ( + assessment.recommended_action.value + if assessment.recommended_action + else None + ), + }, + ) + return assessment + + def _finalize_execution( + self, + action: AgentAction, + assessment: RiskAssessment, + ) -> AgentAction: + self.execution_round = self.host.ensure_round( + self.host_trace, + agenda="execution_summary", + structure=GameStructure.REPEATED, + ) + exec_action = action + exec_weight = target_weight_for_action(action) + exec_status = "normal" + requires_review = assessment.status != "ok" + if requires_review and assessment.recommended_action: + exec_action = assessment.recommended_action + exec_status = "risk_adjusted" + exec_weight = target_weight_for_action(exec_action) + execution_message = DialogueMessage( + sender="execution_engine", + role=DialogueRole.EXECUTION, + message_type=MessageType.DIRECTIVE, + content=f"执行操作 {exec_action.value}", + confidence=1.0, + annotations={ + "target_weight": exec_weight, + "requires_review": requires_review, + "execution_status": exec_status, }, ) - execution_round = host.ensure_round( - host_trace, - agenda="execution_summary", - structure=GameStructure.REPEATED, - ) - exec_action = action - exec_weight = weight - exec_status = "normal" - if 
requires_review and risk_assessment.recommended_action: - exec_action = risk_assessment.recommended_action - exec_status = "risk_adjusted" - exec_weight = target_weight_for_action(exec_action) - execution_message = DialogueMessage( - sender="execution_engine", - role=DialogueRole.EXECUTION, - message_type=MessageType.DIRECTIVE, - content=f"执行操作 {exec_action.value}", - confidence=1.0, - annotations={ - "target_weight": exec_weight, - "requires_review": requires_review, - "execution_status": exec_status, - }, - ) - host.handle_message(execution_round, execution_message) - execution_round.outcome = exec_action.value - execution_round.notes.setdefault("execution_status", exec_status) - if exec_action is not action: - execution_round.notes.setdefault("original_action", action.value) - belief_updates["execution"] = BeliefUpdate( - belief={ - "execution_status": exec_status, - "action": exec_action.value, - "target_weight": exec_weight, - }, - ) - host.finalize_round(execution_round) - host.close(host_trace) - rounds = host_trace.rounds if host_trace.rounds else _build_round_summaries( - department_decisions, - action, - department_votes, - ) - belief_revision = revise_beliefs(belief_updates, exec_action) - if belief_revision.conflicts: - risk_round = host.ensure_round( - host_trace, - agenda="conflict_resolution", - structure=GameStructure.CUSTOM, + self.host.handle_message(self.execution_round, execution_message) + self.execution_round.outcome = exec_action.value + self.execution_round.notes.setdefault("execution_status", exec_status) + if exec_action is not action: + self.execution_round.notes.setdefault("original_action", action.value) + self.belief_updates["execution"] = BeliefUpdate( + belief={ + "execution_status": exec_status, + "action": exec_action.value, + "target_weight": exec_weight, + }, ) - conflict_message = DialogueMessage( - sender="protocol_host", - role=DialogueRole.HOST, - message_type=MessageType.COUNTER, - content="检测到关键冲突,需要后续回合复核。", - annotations={"conflicts": belief_revision.conflicts}, - ) - host.handle_message(risk_round, conflict_message) - risk_round.notes.setdefault("conflicts", belief_revision.conflicts) - host.finalize_round(risk_round) - execution_round.notes.setdefault("consensus_action", belief_revision.consensus_action.value) - execution_round.notes.setdefault("consensus_confidence", belief_revision.consensus_confidence) - if belief_revision.conflicts: - execution_round.notes.setdefault("conflicts", belief_revision.conflicts) - if belief_revision.notes: - execution_round.notes.setdefault("belief_notes", belief_revision.notes) - return Decision( - action=action, - confidence=confidence, - target_weight=weight, - feasible_actions=feas_actions, - utilities=utilities, - department_decisions=department_decisions, - department_votes=department_votes, - requires_review=requires_review, - rounds=rounds, - risk_assessment=risk_assessment, - belief_updates=belief_updates, - belief_revision=belief_revision, + self.host.finalize_round(self.execution_round) + self.execution_round.notes.setdefault("target_weight", exec_weight) + return exec_action + + def _finalize_conflicts(self, exec_action: AgentAction) -> None: + self.host.close(self.host_trace) + self.belief_revision = revise_beliefs(self.belief_updates, exec_action) + if self.belief_revision.conflicts: + conflict_round = self.host.ensure_round( + self.host_trace, + agenda="conflict_resolution", + structure=GameStructure.CUSTOM, + ) + conflict_message = DialogueMessage( + sender="protocol_host", + 
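`_finalize_execution` swaps in the risk-recommended action and keeps an audit trail of the original choice. A standalone sketch, assuming a `RiskResult` stand-in for the real `RiskAssessment`:

```python
# When risk flags a review and recommends an action, execution adopts it
# and records the original action for auditing.
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class RiskResult:
    status: str                       # "ok", or anything else meaning "review"
    recommended_action: Optional[str] = None

def finalize(action: str, assessment: RiskResult) -> Dict[str, object]:
    exec_action, exec_status = action, "normal"
    if assessment.status != "ok" and assessment.recommended_action:
        exec_action, exec_status = assessment.recommended_action, "risk_adjusted"
    note = {"action": exec_action, "execution_status": exec_status}
    if exec_action != action:
        note["original_action"] = action   # audit trail of the swap
    return note

print(finalize("BUY", RiskResult("breach", "HOLD")))
# {'action': 'HOLD', 'execution_status': 'risk_adjusted', 'original_action': 'BUY'}
```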
role=DialogueRole.HOST, + message_type=MessageType.COUNTER, + content="检测到关键冲突,需要后续回合复核。", + annotations={"conflicts": self.belief_revision.conflicts}, + ) + self.host.handle_message(conflict_round, conflict_message) + conflict_round.notes.setdefault("conflicts", self.belief_revision.conflicts) + self.host.finalize_round(conflict_round) + if self.execution_round: + self.execution_round.notes.setdefault("consensus_action", self.belief_revision.consensus_action.value) + self.execution_round.notes.setdefault("consensus_confidence", self.belief_revision.consensus_confidence) + if self.belief_revision.conflicts: + self.execution_round.notes.setdefault("conflicts", self.belief_revision.conflicts) + if self.belief_revision.notes: + self.execution_round.notes.setdefault("belief_notes", self.belief_revision.notes) + + +def decide( + context: AgentContext, + agents: Iterable[Agent], + weights: Mapping[str, float], + method: str = "nash", + department_manager: Optional[DepartmentManager] = None, + department_context: Optional[DepartmentContext] = None, +) -> Decision: + workflow = DecisionWorkflow( + context, + agents, + weights, + method, + department_manager, + department_context, ) + return workflow.run() def _department_scores(decision: DepartmentDecision) -> Dict[AgentAction, float]: diff --git a/app/backtest/engine.py b/app/backtest/engine.py index 8a08188..f4cd47a 100644 --- a/app/backtest/engine.py +++ b/app/backtest/engine.py @@ -17,6 +17,7 @@ from app.llm.metrics import record_decision as metrics_record_decision from app.agents.registry import default_agents from app.data.schema import initialize_database from app.utils.data_access import DataBroker +from app.utils.feature_snapshots import FeatureSnapshotService from app.utils.config import PortfolioSettings, get_config from app.utils.db import db_session from app.utils.logging import get_logger @@ -176,18 +177,35 @@ class BacktestEngine: trade_date_str = trade_date.strftime("%Y%m%d") feature_map: Dict[str, Dict[str, Any]] = {} universe = self.cfg.universe or [] + + snapshot_service = FeatureSnapshotService(self.data_broker) + batch_latest = snapshot_service.load_latest( + trade_date_str, + self.required_fields, + universe, + auto_refresh=False, + ) + for ts_code in universe: - scope_values = self.data_broker.fetch_latest( - ts_code, - trade_date_str, - self.required_fields, - auto_refresh=False # 避免回测时触发自动补数 - ) + scope_values = dict(batch_latest.get(ts_code) or {}) missing_fields = [ field for field in self.required_fields if scope_values.get(field) is None ] + if missing_fields: + fallback = self.data_broker.fetch_latest( + ts_code, + trade_date_str, + missing_fields, + auto_refresh=False, + ) + scope_values.update({k: v for k, v in fallback.items() if v is not None}) + missing_fields = [ + field + for field in self.required_fields + if scope_values.get(field) is None + ] derived_fields: List[str] = [] if missing_fields: LOGGER.debug( diff --git a/app/features/factors.py b/app/features/factors.py index 218fcb1..dd07038 100644 --- a/app/features/factors.py +++ b/app/features/factors.py @@ -5,11 +5,12 @@ import re import sqlite3 from dataclasses import dataclass from datetime import datetime, date, timezone, timedelta -from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union +from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union from app.core.indicators import momentum, rolling_mean, volatility from app.data.schema import initialize_database from app.utils.data_access import 
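The engine change above replaces per-code `fetch_latest` calls with one batch snapshot plus a per-code fallback for fields the snapshot missed. A sketch of that pattern, with hypothetical `load_batch`/`load_one` callables standing in for `FeatureSnapshotService.load_latest` and `DataBroker.fetch_latest`:

```python
# Batch-then-fallback loading: one snapshot call covers the whole universe,
# individual fetches only fill per-code gaps.
from typing import Callable, Dict, List

def load_features(
    universe: List[str],
    fields: List[str],
    load_batch: Callable[[List[str]], Dict[str, Dict[str, float]]],
    load_one: Callable[[str, List[str]], Dict[str, float]],
) -> Dict[str, Dict[str, float]]:
    batch = load_batch(universe)
    result: Dict[str, Dict[str, float]] = {}
    for code in universe:
        values = dict(batch.get(code) or {})
        missing = [f for f in fields if values.get(f) is None]
        if missing:                                  # fallback only for gaps
            fallback = load_one(code, missing)
            values.update({k: v for k, v in fallback.items() if v is not None})
        result[code] = values
    return result

demo = load_features(
    ["600000.SH", "000001.SZ"],
    ["daily.close"],
    lambda codes: {"600000.SH": {"daily.close": 10.2}},   # batch misses one code
    lambda code, fields: {f: 1.0 for f in fields},        # fallback fills it
)
print(demo)
```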
DataBroker +from app.utils.feature_snapshots import FeatureSnapshotService from app.utils.db import db_session from app.utils.logging import get_logger # 导入扩展因子模块 @@ -30,6 +31,18 @@ LOGGER = get_logger(__name__) LOG_EXTRA = {"stage": "factor_compute"} _IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") +_LATEST_BASE_FIELDS: List[str] = [ + "daily_basic.pe", + "daily_basic.pb", + "daily_basic.ps", + "daily_basic.turnover_rate", + "daily_basic.volume_ratio", + "daily.close", + "daily.amount", + "daily.vol", + "daily_basic.dv_ratio", +] + @dataclass class FactorSpec: @@ -502,6 +515,14 @@ def _compute_batch_factors( # 批次化数据可用性检查 available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs) + + snapshot_service = FeatureSnapshotService(broker) + latest_snapshot = snapshot_service.load_latest( + trade_date, + _LATEST_BASE_FIELDS, + list(available_codes), + auto_refresh=False, + ) # 更新UI进度状态 - 开始处理批次 if progress and total_securities > 0: @@ -523,7 +544,13 @@ def _compute_batch_factors( continue # 计算因子值 - values = _compute_security_factors(broker, ts_code, trade_date, specs) + values = _compute_security_factors( + broker, + ts_code, + trade_date, + specs, + latest_fields=latest_snapshot.get(ts_code), + ) if values: # 检测并处理异常值 @@ -660,7 +687,7 @@ def _check_batch_data_availability( # 使用DataBroker的批次化获取最新字段数据 required_fields = ["daily.close", "daily_basic.turnover_rate"] - batch_fields_data = broker.fetch_batch_latest(sufficient_codes, trade_date, required_fields) + batch_fields_data = broker.fetch_batch_latest(list(sufficient_codes), trade_date, required_fields) # 检查每个证券的必需字段 for ts_code in sufficient_codes: @@ -753,6 +780,8 @@ def _compute_security_factors( ts_code: str, trade_date: str, specs: Sequence[FactorSpec], + *, + latest_fields: Optional[Mapping[str, object]] = None, ) -> Dict[str, float | None]: """计算单个证券的因子值 @@ -823,21 +852,14 @@ def _compute_security_factors( ) # 获取最新字段值 - latest_fields = broker.fetch_latest( - ts_code, - trade_date, - [ - "daily_basic.pe", - "daily_basic.pb", - "daily_basic.ps", - "daily_basic.turnover_rate", - "daily_basic.volume_ratio", - "daily.close", - "daily.amount", - "daily.vol", - "daily_basic.dv_ratio", # 股息率用于扩展因子 - ], - ) + if latest_fields is None: + latest_fields = broker.fetch_latest( + ts_code, + trade_date, + _LATEST_BASE_FIELDS, + ) + else: + latest_fields = dict(latest_fields) # 计算各个因子值 results: Dict[str, float | None] = {} diff --git a/app/ingest/api_client.py b/app/ingest/api_client.py new file mode 100644 index 0000000..bbee529 --- /dev/null +++ b/app/ingest/api_client.py @@ -0,0 +1,1387 @@ +"""TuShare API client helpers and persistence utilities.""" +from __future__ import annotations + +import os +import sqlite3 +import time +from collections import defaultdict, deque +from datetime import date +from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple + +import pandas as pd + +try: + import tushare as ts +except ImportError: # pragma: no cover - 运行时提示 + ts = None # type: ignore[assignment] + +from app.utils.config import get_config +from app.utils.db import db_session +from app.utils.logging import get_logger + +LOGGER = get_logger(__name__) + +API_DEFAULT_LIMIT = 5000 +LOG_EXTRA = {"stage": "data_ingest"} + +_CALL_BUCKETS: Dict[str, deque] = defaultdict(deque) + +RATE_LIMIT_ERROR_PATTERNS: Tuple[str, ...] 
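`_compute_security_factors` now accepts a precomputed `latest_fields` mapping, so the batch path avoids one `fetch_latest` round-trip per security. A minimal sketch of the calling convention (`compute_one` is a hypothetical stand-in):

```python
# The batch path fetches latest fields once, then hands each security its
# slice via the keyword instead of issuing one fetch per security.
from typing import Mapping, Optional

def compute_one(code: str, *, latest_fields: Optional[Mapping[str, float]] = None) -> float:
    fields = dict(latest_fields or {})       # caller-provided snapshot wins
    return fields.get("daily.close", 0.0) * 0.01

snapshot = {"600000.SH": {"daily.close": 10.2}, "000001.SZ": {"daily.close": 12.5}}
values = {code: compute_one(code, latest_fields=snapshot.get(code)) for code in snapshot}
print(values)
```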
= ( + "最多访问该接口", + "超过接口限制", + "Frequency limit", +) + +API_RATE_LIMITS: Dict[str, int] = { + "stock_basic": 180, + "daily": 480, + "daily_basic": 200, + "adj_factor": 200, + "suspend_d": 180, + "suspend": 180, + "stk_limit": 200, + "trade_cal": 200, + "index_basic": 120, + "index_daily": 240, + "fund_basic": 120, + "fund_nav": 200, + "fut_basic": 120, + "fut_daily": 200, + "fx_daily": 200, + "hk_daily": 2, + "us_daily": 200, +} + +INDEX_CODES: Tuple[str, ...] = ( + "000001.SH", # 上证综指 + "000300.SH", # 沪深300 + "000016.SH", # 上证50 + "000905.SH", # 中证500 + "399001.SZ", # 深证成指 + "399005.SZ", # 中小板指 + "399006.SZ", # 创业板指 + "HSI.HI", # 恒生指数 + "SPX.GI", # 标普500 + "DJI.GI", # 道琼斯工业指数 + "IXIC.GI", # 纳斯达克综合指数 + "GDAXI.GI", # 德国DAX + "FTSE.GI", # 英国富时100 +) + +ETF_CODES: Tuple[str, ...] = ( + "510300.SH", # 华泰柏瑞沪深300ETF + "510500.SH", # 南方中证500ETF + "159915.SZ", # 易方达创业板ETF +) + +FUND_CODES: Tuple[str, ...] = ( + "000001.OF", # 华夏成长 + "110022.OF", # 易方达消费行业 +) + +FUTURE_CODES: Tuple[str, ...] = ( + "IF9999.CFE", # 沪深300股指期货主力 + "IC9999.CFE", # 中证500股指期货主力 + "IH9999.CFE", # 上证50股指期货主力 +) + +FX_CODES: Tuple[str, ...] = ( + "USDCNY", # 美元人民币 + "EURCNY", # 欧元人民币 +) + +HK_CODES: Tuple[str, ...] = ( + "00700.HK", # 腾讯控股 + "00941.HK", # 中国移动 + "09618.HK", # 京东集团-SW + "09988.HK", # 阿里巴巴-SW + "03690.HK", # 美团-W +) + +US_CODES: Tuple[str, ...] = ( + "AAPL.O", # 苹果 + "MSFT.O", # 微软 + "BABA.N", # 阿里巴巴美股 + "JD.O", # 京东美股 + "PDD.O", # 拼多多 + "BIDU.O", # 百度 + "BILI.O", # 哔哩哔哩 +) + + +def _normalize_date_str(value: Optional[str]) -> Optional[str]: + if value is None: + return None + text = str(value).strip() + return text or None + + +def _respect_rate_limit(endpoint: str | None) -> None: + def _throttle(queue: deque, limit: int) -> None: + if limit <= 0: + return + now = time.time() + window = 60.0 + while queue and now - queue[0] > window: + queue.popleft() + if len(queue) >= limit: + sleep_time = window - (now - queue[0]) + 0.1 + LOGGER.debug( + "触发限频控制(limit=%s)休眠 %.2f 秒 endpoint=%s", + limit, + sleep_time, + endpoint, + extra=LOG_EXTRA, + ) + time.sleep(max(0.1, sleep_time)) + queue.append(time.time()) + + bucket_key = endpoint or "_default" + endpoint_limit = API_RATE_LIMITS.get(bucket_key, 60) + _throttle(_CALL_BUCKETS[bucket_key], endpoint_limit or 0) + + +def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]: + if df is None or df.empty: + return [] + reindexed = df.reindex(columns=allowed_cols) + return reindexed.where(pd.notnull(reindexed), None).to_dict("records") + + +def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame: + client = _ensure_client() + limit = limit or API_DEFAULT_LIMIT + frames: List[pd.DataFrame] = [] + offset = 0 + clean_params = {k: v for k, v in params.items() if v is not None} + LOGGER.info( + "开始调用 TuShare 接口:%s,参数=%s,limit=%s", + endpoint, + clean_params, + limit, + extra=LOG_EXTRA, + ) + while True: + _respect_rate_limit(endpoint) + call = getattr(client, endpoint) + try: + df = call(limit=limit, offset=offset, **clean_params) + except Exception as exc: # noqa: BLE001 + message = str(exc) + if any(pattern in message for pattern in RATE_LIMIT_ERROR_PATTERNS): + per_minute = API_RATE_LIMITS.get(endpoint or "", 0) + wait_time = 60.0 / per_minute + 1 if per_minute else 30.0 + wait_time = max(wait_time, 30.0) + LOGGER.warning( + "接口限频触发:%s,原因=%s,等待 %.1f 秒后重试", + endpoint, + message, + wait_time, + extra=LOG_EXTRA, + ) + time.sleep(wait_time) + continue + + LOGGER.exception( + "TuShare 接口调用异常:endpoint=%s 
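`_respect_rate_limit` keeps a per-endpoint deque of call timestamps and sleeps just long enough to stay under the per-minute quota. A self-contained sketch of that sliding-window throttle:

```python
# Sliding-window throttle: drop timestamps older than the window, and if the
# quota is filled, wait out the oldest call before recording a new one.
import time
from collections import defaultdict, deque
from typing import Dict

WINDOW = 60.0
LIMITS: Dict[str, int] = {"daily": 480, "daily_basic": 200}
BUCKETS: Dict[str, deque] = defaultdict(deque)

def throttle(endpoint: str) -> None:
    limit = LIMITS.get(endpoint, 60)
    queue = BUCKETS[endpoint]
    now = time.time()
    while queue and now - queue[0] > WINDOW:   # evict calls outside the window
        queue.popleft()
    if len(queue) >= limit:                    # quota filled: sleep it off
        time.sleep(max(0.1, WINDOW - (now - queue[0]) + 0.1))
    queue.append(time.time())

for _ in range(3):
    throttle("daily_basic")                    # stays under 200 calls/minute
```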
offset=%s params=%s", + endpoint, + offset, + clean_params, + extra=LOG_EXTRA, + ) + raise + if df is None or df.empty: + LOGGER.debug( + "TuShare 返回空数据:endpoint=%s offset=%s", + endpoint, + offset, + extra=LOG_EXTRA, + ) + break + LOGGER.debug( + "TuShare 返回 %s 行:endpoint=%s offset=%s", + len(df), + endpoint, + offset, + extra=LOG_EXTRA, + ) + frames.append(df) + if len(df) < limit: + break + offset += limit + if not frames: + return pd.DataFrame() + merged = pd.concat(frames, ignore_index=True) + LOGGER.info( + "TuShare 调用完成:endpoint=%s 总行数=%s", + endpoint, + len(merged), + extra=LOG_EXTRA, + ) + return merged + + +def _ensure_client(): + if ts is None: + raise RuntimeError("未安装 tushare,请先在环境中安装 tushare 包") + token = get_config().tushare_token or os.getenv("TUSHARE_TOKEN") + if not token: + raise RuntimeError("未配置 TuShare Token,请在配置文件或环境变量 TUSHARE_TOKEN 中设置") + if not hasattr(_ensure_client, "_client") or _ensure_client._client is None: # type: ignore[attr-defined] + ts.set_token(token) + _ensure_client._client = ts.pro_api(token) # type: ignore[attr-defined] + LOGGER.info("完成 TuShare 客户端初始化") + return _ensure_client._client # type: ignore[attr-defined] + + +def _format_date(value: date) -> str: + return value.strftime("%Y%m%d") + + +def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str]: + start_str = _format_date(start) + end_str = _format_date(end) + query = ( + "SELECT cal_date FROM trade_calendar " + "WHERE exchange = ? AND cal_date BETWEEN ? AND ? AND is_open = 1 ORDER BY cal_date" + ) + with db_session(read_only=True) as conn: + rows = conn.execute(query, (exchange, start_str, end_str)).fetchall() + return [row["cal_date"] for row in rows] + + +def _record_exists( + table: str, + date_col: str, + trade_date: str, + ts_code: Optional[str] = None, +) -> bool: + query = f"SELECT 1 FROM {table} WHERE {date_col} = ?" + params: Tuple = (trade_date,) + if ts_code: + query += " AND ts_code = ?" + params = (trade_date, ts_code) + with db_session(read_only=True) as conn: + row = conn.execute(query, params).fetchone() + return row is not None + + +def _existing_suspend_dates(start_str: str, end_str: str, ts_code: str | None = None) -> Set[str]: + sql = "SELECT DISTINCT suspend_date FROM suspend WHERE suspend_date BETWEEN ? AND ?" + params: List[object] = [start_str, end_str] + if ts_code: + sql += " AND ts_code = ?" + params.append(ts_code) + try: + with db_session(read_only=True) as conn: + rows = conn.execute(sql, tuple(params)).fetchall() + except sqlite3.OperationalError: + return set() + return {row["suspend_date"] for row in rows if row["suspend_date"]} + + +def _listing_window(ts_code: str) -> Tuple[Optional[str], Optional[str]]: + with db_session(read_only=True) as conn: + row = conn.execute( + "SELECT list_date, delist_date FROM stock_basic WHERE ts_code = ?", + (ts_code,), + ).fetchone() + if not row: + return None, None + return _normalize_date_str(row["list_date"]), _normalize_date_str(row["delist_date"]) # type: ignore[index] + + +def _calendar_needs_refresh(exchange: str, start_str: str, end_str: str) -> bool: + sql = """ + SELECT MIN(cal_date) AS min_d, MAX(cal_date) AS max_d, COUNT(*) AS cnt + FROM trade_calendar + WHERE exchange = ? AND cal_date BETWEEN ? AND ? 
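`_fetch_paginated` walks the endpoint with an offset/limit cursor until a short or empty page signals the end. The same loop in isolation (`fetch_page` is a hypothetical stand-in for the TuShare endpoint call):

```python
# Offset/limit pagination: request pages until a short or empty page ends it.
from typing import Callable, Dict, List

def paginate(fetch_page: Callable[[int, int], List[Dict]], limit: int = 5000) -> List[Dict]:
    rows: List[Dict] = []
    offset = 0
    while True:
        page = fetch_page(offset, limit)
        if not page:                 # empty page: nothing left
            break
        rows.extend(page)
        if len(page) < limit:        # short page: this was the last one
            break
        offset += limit
    return rows

data = list(range(12))               # fake backing dataset
rows = paginate(lambda off, lim: [{"i": i} for i in data[off:off + lim]], limit=5)
print(len(rows))                     # 12
```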
+ """ + with db_session(read_only=True) as conn: + row = conn.execute(sql, (exchange, start_str, end_str)).fetchone() + if row is None or row["min_d"] is None: + return True + if row["min_d"] > start_str or row["max_d"] < end_str: + return True + return False + + +def ensure_trade_calendar(start: date, end: date, exchanges: Sequence[str] = ("SSE", "SZSE")) -> None: + start_str = _format_date(start) + end_str = _format_date(end) + for exch in exchanges: + if _calendar_needs_refresh(exch, start_str, end_str): + save_records("trade_calendar", fetch_trade_calendar(start, end, exchange=exch)) + + +def _expected_trading_days(start_str: str, end_str: str, exchange: str = "SSE") -> int: + sql = """ + SELECT COUNT(*) AS cnt + FROM trade_calendar + WHERE exchange = ? AND cal_date BETWEEN ? AND ? AND is_open = 1 + """ + with db_session(read_only=True) as conn: + row = conn.execute(sql, (exchange, start_str, end_str)).fetchone() + return int(row["cnt"]) if row and row["cnt"] is not None else 0 + + +_TABLE_SCHEMAS: Dict[str, str] = { + "stock_basic": """ + CREATE TABLE IF NOT EXISTS stock_basic ( + ts_code TEXT PRIMARY KEY, + symbol TEXT, + name TEXT, + area TEXT, + industry TEXT, + market TEXT, + exchange TEXT, + list_status TEXT, + list_date TEXT, + delist_date TEXT + ); + """, + "daily": """ + CREATE TABLE IF NOT EXISTS daily ( + ts_code TEXT, + trade_date TEXT, + open REAL, + high REAL, + low REAL, + close REAL, + pre_close REAL, + change REAL, + pct_chg REAL, + vol REAL, + amount REAL, + PRIMARY KEY (ts_code, trade_date) + ); + """, + "daily_basic": """ + CREATE TABLE IF NOT EXISTS daily_basic ( + ts_code TEXT, + trade_date TEXT, + close REAL, + turnover_rate REAL, + turnover_rate_f REAL, + volume_ratio REAL, + pe REAL, + pe_ttm REAL, + pb REAL, + ps REAL, + ps_ttm REAL, + dv_ratio REAL, + dv_ttm REAL, + total_share REAL, + float_share REAL, + free_share REAL, + total_mv REAL, + circ_mv REAL, + PRIMARY KEY (ts_code, trade_date) + ); + """, + "adj_factor": """ + CREATE TABLE IF NOT EXISTS adj_factor ( + ts_code TEXT, + trade_date TEXT, + adj_factor REAL, + PRIMARY KEY (ts_code, trade_date) + ); + """, + "suspend": """ + CREATE TABLE IF NOT EXISTS suspend ( + ts_code TEXT, + suspend_date TEXT, + resume_date TEXT, + suspend_type TEXT, + ann_date TEXT, + suspend_timing TEXT, + resume_timing TEXT, + reason TEXT, + PRIMARY KEY (ts_code, suspend_date) + ); + """, + "trade_calendar": """ + CREATE TABLE IF NOT EXISTS trade_calendar ( + exchange TEXT, + cal_date TEXT, + is_open INTEGER, + pretrade_date TEXT, + PRIMARY KEY (exchange, cal_date) + ); + """, + "stk_limit": """ + CREATE TABLE IF NOT EXISTS stk_limit ( + ts_code TEXT, + trade_date TEXT, + up_limit REAL, + down_limit REAL, + PRIMARY KEY (ts_code, trade_date) + ); + """, + "index_basic": """ + CREATE TABLE IF NOT EXISTS index_basic ( + ts_code TEXT PRIMARY KEY, + name TEXT, + fullname TEXT, + market TEXT, + publisher TEXT, + index_type TEXT, + category TEXT, + base_date TEXT, + base_point REAL, + list_date TEXT, + weight_rule TEXT, + desc TEXT, + exp_date TEXT + ); + """, + "index_daily": """ + CREATE TABLE IF NOT EXISTS index_daily ( + ts_code TEXT, + trade_date TEXT, + close REAL, + open REAL, + high REAL, + low REAL, + pre_close REAL, + change REAL, + pct_chg REAL, + vol REAL, + amount REAL, + PRIMARY KEY (ts_code, trade_date) + ); + """, + "index_dailybasic": """ + CREATE TABLE IF NOT EXISTS index_dailybasic ( + ts_code TEXT, + trade_date TEXT, + turnover REAL, + turnover_ratio REAL, + pe_ttm REAL, + pb REAL, + ps_ttm REAL, + dv_ttm REAL, + 
total_mv REAL, + circ_mv REAL, + PRIMARY KEY (ts_code, trade_date) + ); + """, + "index_weight": """ + CREATE TABLE IF NOT EXISTS index_weight ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + index_code VARCHAR(10) NOT NULL, + trade_date VARCHAR(8) NOT NULL, + ts_code VARCHAR(10) NOT NULL, + weight FLOAT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + """, + "fund_basic": """ + CREATE TABLE IF NOT EXISTS fund_basic ( + ts_code TEXT PRIMARY KEY, + name TEXT, + management TEXT, + custodian TEXT, + fund_type TEXT, + found_date TEXT, + due_date TEXT, + list_date TEXT, + issue_date TEXT, + delist_date TEXT, + issue_amount REAL, + m_fee REAL, + c_fee REAL, + benchmark TEXT, + status TEXT, + invest_type TEXT, + type TEXT, + trustee TEXT, + purc_start_date TEXT, + redm_start_date TEXT, + market TEXT + ); + """, + "fund_nav": """ + CREATE TABLE IF NOT EXISTS fund_nav ( + ts_code TEXT, + nav_date TEXT, + ann_date TEXT, + unit_nav REAL, + accum_nav REAL, + accum_div REAL, + net_asset REAL, + total_netasset REAL, + adj_nav REAL, + update_flag TEXT, + PRIMARY KEY (ts_code, nav_date) + ); + """, + "fut_basic": """ + CREATE TABLE IF NOT EXISTS fut_basic ( + ts_code TEXT PRIMARY KEY, + symbol TEXT, + name TEXT, + exchange TEXT, + exchange_full_name TEXT, + product TEXT, + product_name TEXT, + variety TEXT, + list_date TEXT, + delist_date TEXT, + trade_unit REAL, + per_unit REAL, + quote_unit TEXT, + settle_month TEXT, + contract_size REAL, + tick_size REAL, + margin_rate REAL, + margin_ratio REAL, + delivery_month TEXT, + delivery_day TEXT + ); + """, + "fut_daily": """ + CREATE TABLE IF NOT EXISTS fut_daily ( + ts_code TEXT, + trade_date TEXT, + pre_settle REAL, + open REAL, + high REAL, + low REAL, + close REAL, + settle REAL, + change1 REAL, + change2 REAL, + vol REAL, + amount REAL, + oi REAL, + oi_chg REAL, + PRIMARY KEY (ts_code, trade_date) + ); + """, + "fx_daily": """ + CREATE TABLE IF NOT EXISTS fx_daily ( + ts_code TEXT, + trade_date TEXT, + bid REAL, + ask REAL, + mid REAL, + high REAL, + low REAL, + PRIMARY KEY (ts_code, trade_date) + ); + """, + "hk_daily": """ + CREATE TABLE IF NOT EXISTS hk_daily ( + ts_code TEXT, + trade_date TEXT, + close REAL, + open REAL, + high REAL, + low REAL, + pre_close REAL, + change REAL, + pct_chg REAL, + vol REAL, + amount REAL, + exchange TEXT, + PRIMARY KEY (ts_code, trade_date) + ); + """, + "us_daily": """ + CREATE TABLE IF NOT EXISTS us_daily ( + ts_code TEXT, + trade_date TEXT, + close REAL, + open REAL, + high REAL, + low REAL, + pre_close REAL, + change REAL, + pct_chg REAL, + vol REAL, + amount REAL, + PRIMARY KEY (ts_code, trade_date) + ); + """, +} + +_TABLE_COLUMNS: Dict[str, List[str]] = { + "stock_basic": [ + "ts_code", + "symbol", + "name", + "area", + "industry", + "market", + "exchange", + "list_status", + "list_date", + "delist_date", + ], + "daily": [ + "ts_code", + "trade_date", + "open", + "high", + "low", + "close", + "pre_close", + "change", + "pct_chg", + "vol", + "amount", + ], + "daily_basic": [ + "ts_code", + "trade_date", + "close", + "turnover_rate", + "turnover_rate_f", + "volume_ratio", + "pe", + "pe_ttm", + "pb", + "ps", + "ps_ttm", + "dv_ratio", + "dv_ttm", + "total_share", + "float_share", + "free_share", + "total_mv", + "circ_mv", + ], + "adj_factor": [ + "ts_code", + "trade_date", + "adj_factor", + ], + "suspend": [ + "ts_code", + "suspend_date", + "resume_date", + "suspend_type", + "ann_date", + "suspend_timing", + "resume_timing", + "reason", + ], + "trade_calendar": [ + "exchange", + "cal_date", + "is_open", + 
"pretrade_date", + ], + "stk_limit": [ + "ts_code", + "trade_date", + "up_limit", + "down_limit", + ], + "index_basic": [ + "ts_code", + "name", + "fullname", + "market", + "publisher", + "index_type", + "category", + "base_date", + "base_point", + "list_date", + "weight_rule", + "desc", + "exp_date", + ], + "index_daily": [ + "ts_code", + "trade_date", + "close", + "open", + "high", + "low", + "pre_close", + "change", + "pct_chg", + "vol", + "amount", + ], + "index_dailybasic": [ + "ts_code", + "trade_date", + "turnover", + "turnover_ratio", + "pe_ttm", + "pb", + "ps_ttm", + "dv_ttm", + "total_mv", + "circ_mv", + ], + "index_weight": [ + "index_code", + "trade_date", + "ts_code", + "weight", + ], + "fund_basic": [ + "ts_code", + "name", + "management", + "custodian", + "fund_type", + "found_date", + "due_date", + "list_date", + "issue_date", + "delist_date", + "issue_amount", + "m_fee", + "c_fee", + "benchmark", + "status", + "invest_type", + "type", + "trustee", + "purc_start_date", + "redm_start_date", + "market", + ], + "fund_nav": [ + "ts_code", + "nav_date", + "ann_date", + "unit_nav", + "accum_nav", + "accum_div", + "net_asset", + "total_netasset", + "adj_nav", + "update_flag", + ], + "fut_basic": [ + "ts_code", + "symbol", + "name", + "exchange", + "exchange_full_name", + "product", + "product_name", + "variety", + "list_date", + "delist_date", + "trade_unit", + "per_unit", + "quote_unit", + "settle_month", + "contract_size", + "tick_size", + "margin_rate", + "margin_ratio", + "delivery_month", + "delivery_day", + ], + "fut_daily": [ + "ts_code", + "trade_date", + "pre_settle", + "open", + "high", + "low", + "close", + "settle", + "change1", + "change2", + "vol", + "amount", + "oi", + "oi_chg", + ], + "fx_daily": [ + "ts_code", + "trade_date", + "bid", + "ask", + "mid", + "high", + "low", + ], + "hk_daily": [ + "ts_code", + "trade_date", + "close", + "open", + "high", + "low", + "pre_close", + "change", + "pct_chg", + "vol", + "amount", + "exchange", + ], + "us_daily": [ + "ts_code", + "trade_date", + "close", + "open", + "high", + "low", + "pre_close", + "change", + "pct_chg", + "vol", + "amount", + ], +} + + +def save_records(table: str, rows: Iterable[Dict]) -> None: + items = list(rows) + if not items: + LOGGER.info("表 %s 没有新增记录,跳过写入", table, extra=LOG_EXTRA) + return + + schema = _TABLE_SCHEMAS.get(table) + columns = _TABLE_COLUMNS.get(table) + if not schema or not columns: + raise ValueError(f"不支持写入的表:{table}") + + placeholders = ",".join([f":{col}" for col in columns]) + col_clause = ",".join(columns) + + LOGGER.info("表 %s 写入 %d 条记录", table, len(items), extra=LOG_EXTRA) + with db_session() as conn: + conn.executescript(schema) + conn.executemany( + f"INSERT OR REPLACE INTO {table} ({col_clause}) VALUES ({placeholders})", + items, + ) + + +def ensure_stock_basic(list_status: str = "L") -> None: + exchanges = ("SSE", "SZSE") + with db_session(read_only=True) as conn: + row = conn.execute( + "SELECT COUNT(*) AS cnt FROM stock_basic WHERE exchange IN (?, ?) 
AND list_status = ?", + (*exchanges, list_status), + ).fetchone() + if row and row["cnt"]: + LOGGER.info( + "股票基础信息已存在 %d 条记录,跳过拉取", + row["cnt"], + extra=LOG_EXTRA, + ) + return + + for exch in exchanges: + save_records("stock_basic", fetch_stock_basic(exchange=exch, list_status=list_status)) + + +def fetch_stock_basic(exchange: Optional[str] = None, list_status: str = "L") -> Iterable[Dict]: + client = _ensure_client() + LOGGER.info( + "拉取股票基础信息(交易所:%s,状态:%s)", + exchange or "全部", + list_status, + extra=LOG_EXTRA, + ) + _respect_rate_limit("stock_basic") + fields = "ts_code,symbol,name,area,industry,market,exchange,list_status,list_date,delist_date" + df = client.stock_basic(exchange=exchange, list_status=list_status, fields=fields) + return _df_to_records(df, _TABLE_COLUMNS["stock_basic"]) + + +def fetch_daily_bars( + start: date, + end: date, + ts_codes: Optional[Sequence[str]] = None, + *, + skip_existing: bool = True, + exchange: str = "SSE", +) -> Iterable[Dict]: + client = _ensure_client() + frames: List[pd.DataFrame] = [] + + trade_dates = _load_trade_dates(start, end, exchange=exchange) + if not trade_dates: + LOGGER.info("本地交易日历缺失,尝试补全后再拉取日线行情", extra=LOG_EXTRA) + ensure_trade_calendar(start, end, exchanges=(exchange,)) + trade_dates = _load_trade_dates(start, end, exchange=exchange) + + if ts_codes: + for code in ts_codes: + for trade_date in trade_dates: + if skip_existing and _record_exists("daily", "trade_date", trade_date, code): + LOGGER.debug( + "日线数据已存在,跳过 %s %s", + code, + trade_date, + extra=LOG_EXTRA, + ) + continue + LOGGER.debug( + "按交易日拉取日线行情:code=%s trade_date=%s", + code, + trade_date, + extra=LOG_EXTRA, + ) + LOGGER.info( + "交易日拉取请求:endpoint=daily code=%s trade_date=%s", + code, + trade_date, + extra=LOG_EXTRA, + ) + df = _fetch_paginated( + "daily", + { + "trade_date": trade_date, + "ts_code": code, + }, + ) + if not df.empty: + frames.append(df) + else: + for trade_date in trade_dates: + if skip_existing and _record_exists("daily", "trade_date", trade_date): + LOGGER.debug( + "日线数据已存在,跳过交易日 %s", + trade_date, + extra=LOG_EXTRA, + ) + continue + LOGGER.debug("按交易日拉取日线行情:%s", trade_date, extra=LOG_EXTRA) + LOGGER.info( + "交易日拉取请求:endpoint=daily trade_date=%s", + trade_date, + extra=LOG_EXTRA, + ) + df = _fetch_paginated("daily", {"trade_date": trade_date}) + if not df.empty: + frames.append(df) + + if not frames: + return [] + df = pd.concat(frames, ignore_index=True) + return _df_to_records(df, _TABLE_COLUMNS["daily"]) + + +def fetch_daily_basic( + start: date, + end: date, + ts_code: Optional[str] = None, + *, + skip_existing: bool = True, +) -> Iterable[Dict]: + client = _ensure_client() + start_date = _format_date(start) + end_date = _format_date(end) + LOGGER.info( + "拉取日线基础指标(%s-%s,股票:%s)", + start_date, + end_date, + ts_code or "全部", + extra=LOG_EXTRA, + ) + + trade_dates = _load_trade_dates(start, end) + frames: List[pd.DataFrame] = [] + for trade_date in trade_dates: + if skip_existing and _record_exists("daily_basic", "trade_date", trade_date, ts_code): + LOGGER.info( + "日线基础指标已存在,跳过交易日 %s", + trade_date, + extra=LOG_EXTRA, + ) + continue + params = {"trade_date": trade_date} + if ts_code: + params["ts_code"] = ts_code + LOGGER.info( + "交易日拉取请求:endpoint=daily_basic params=%s", + params, + extra=LOG_EXTRA, + ) + df = _fetch_paginated("daily_basic", params) + if not df.empty: + frames.append(df) + + if not frames: + return [] + + merged = pd.concat(frames, ignore_index=True) + return _df_to_records(merged, _TABLE_COLUMNS["daily_basic"]) + + +def 
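The per-trade-date fetchers consult the local store before touching the API, so re-runs only fill gaps. A sketch of that `skip_existing` guard (`has_local`/`fetch_remote` are stand-ins for `_record_exists` and the paginated call):

```python
# skip_existing guard: check local storage first so re-runs only fill gaps.
from typing import Callable, Dict, List

def fetch_missing(
    trade_dates: List[str],
    has_local: Callable[[str], bool],
    fetch_remote: Callable[[str], List[Dict]],
    skip_existing: bool = True,
) -> List[Dict]:
    rows: List[Dict] = []
    for trade_date in trade_dates:
        if skip_existing and has_local(trade_date):
            continue                       # already ingested: no API call at all
        rows.extend(fetch_remote(trade_date))
    return rows

local = {"20251014", "20251015"}
rows = fetch_missing(
    ["20251014", "20251015", "20251016"],
    has_local=lambda d: d in local,
    fetch_remote=lambda d: [{"trade_date": d}],
)
print(rows)  # only 20251016 is fetched
```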
fetch_adj_factor( + start: date, + end: date, + ts_code: Optional[str] = None, + *, + skip_existing: bool = True, +) -> Iterable[Dict]: + client = _ensure_client() + start_date = _format_date(start) + end_date = _format_date(end) + LOGGER.info( + "拉取复权因子(%s-%s,股票:%s)", + start_date, + end_date, + ts_code or "全部", + extra=LOG_EXTRA, + ) + + trade_dates = _load_trade_dates(start, end) + frames: List[pd.DataFrame] = [] + for trade_date in trade_dates: + if skip_existing and _record_exists("adj_factor", "trade_date", trade_date, ts_code): + LOGGER.info( + "复权因子已存在,跳过交易日 %s", + trade_date, + extra=LOG_EXTRA, + ) + continue + params = {"trade_date": trade_date} + if ts_code: + params["ts_code"] = ts_code + LOGGER.info( + "交易日拉取请求:endpoint=adj_factor params=%s", + params, + extra=LOG_EXTRA, + ) + df = _fetch_paginated("adj_factor", params) + if not df.empty: + frames.append(df) + + if not frames: + return [] + + merged = pd.concat(frames, ignore_index=True) + return _df_to_records(merged, _TABLE_COLUMNS["adj_factor"]) + + +def fetch_suspensions( + start: date, + end: date, + ts_code: Optional[str] = None, + *, + skip_existing: bool = False, +) -> Iterable[Dict]: + client = _ensure_client() + start_str = _format_date(start) + end_str = _format_date(end) + LOGGER.info( + "拉取停复牌信息(%s-%s,股票:%s)", + start_str, + end_str, + ts_code or "全部", + extra=LOG_EXTRA, + ) + + params: Dict[str, object] = { + "start_date": start_str, + "end_date": end_str, + } + if ts_code: + params["ts_code"] = ts_code + df = _fetch_paginated("suspend_d", params) + if df.empty: + return [] + + merged = df.rename( + columns={ + "ann_date": "ann_date", + "suspend_date": "suspend_date", + "resume_date": "resume_date", + "suspend_type": "suspend_type", + } + ) + if skip_existing: + existing = _existing_suspend_dates(start_str, end_str, ts_code=ts_code) + if existing: + merged = merged[~merged["suspend_date"].isin(existing)] + missing_cols = [col for col in _TABLE_COLUMNS["suspend"] if col not in merged.columns] + for column in missing_cols: + merged[column] = None + ordered = merged[_TABLE_COLUMNS["suspend"]] + return _df_to_records(ordered, _TABLE_COLUMNS["suspend"]) + + +def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]: + client = _ensure_client() + start_str = _format_date(start) + end_str = _format_date(end) + LOGGER.info( + "拉取交易日历:%s %s-%s", + exchange, + start_str, + end_str, + extra=LOG_EXTRA, + ) + df = _fetch_paginated( + "trade_cal", + {"exchange": exchange, "start_date": start_str, "end_date": end_str}, + limit=4000, + ) + return _df_to_records(df, _TABLE_COLUMNS["trade_calendar"]) + + +def fetch_stk_limit( + start: date, + end: date, + ts_code: str | None = None, + *, + skip_existing: bool = True, +) -> Iterable[Dict]: + client = _ensure_client() + start_str = _format_date(start) + end_str = _format_date(end) + LOGGER.info( + "拉取涨跌停数据(%s-%s,股票:%s)", + start_str, + end_str, + ts_code or "全部", + extra=LOG_EXTRA, + ) + + params: Dict[str, object] = {"start_date": start_str, "end_date": end_str} + if ts_code: + params["ts_code"] = ts_code + df = _fetch_paginated("stk_limit", params, limit=4000) + if df.empty: + return [] + if skip_existing: + df = df[ + ~df.apply( + lambda row: _record_exists("stk_limit", "trade_date", row["trade_date"], row["ts_code"]), + axis=1, + ) + ] + return _df_to_records(df, _TABLE_COLUMNS["stk_limit"]) + + +def fetch_index_basic() -> Iterable[Dict]: + client = _ensure_client() + LOGGER.info("拉取指数基础信息", extra=LOG_EXTRA) + df = _fetch_paginated("index_basic", 
{"market": "SSE"}) + return _df_to_records(df, _TABLE_COLUMNS["index_basic"]) + + +def fetch_index_daily(start: date, end: date, ts_code: str) -> Iterable[Dict]: + client = _ensure_client() + start_str = _format_date(start) + end_str = _format_date(end) + LOGGER.info( + "拉取指数日线:%s %s-%s", + ts_code, + start_str, + end_str, + extra=LOG_EXTRA, + ) + df = _fetch_paginated( + "index_daily", + {"ts_code": ts_code, "start_date": start_str, "end_date": end_str}, + limit=4000, + ) + return _df_to_records(df, _TABLE_COLUMNS["index_daily"]) + + +def fetch_index_dailybasic(start: date, end: date, ts_code: str) -> Iterable[Dict]: + client = _ensure_client() + start_str = _format_date(start) + end_str = _format_date(end) + LOGGER.info( + "拉取指数每日指标:%s %s-%s", + ts_code, + start_str, + end_str, + extra=LOG_EXTRA, + ) + df = _fetch_paginated( + "index_dailybasic", + {"ts_code": ts_code, "start_date": start_str, "end_date": end_str}, + limit=4000, + ) + return _df_to_records(df, _TABLE_COLUMNS["index_dailybasic"]) + + +def fetch_index_weight(start: date, end: date, index_code: str) -> Iterable[Dict]: + client = _ensure_client() + start_str = _format_date(start) + end_str = _format_date(end) + LOGGER.info( + "拉取指数权重:%s %s-%s", + index_code, + start_str, + end_str, + extra=LOG_EXTRA, + ) + df = _fetch_paginated( + "index_weight", + {"index_code": index_code, "start_date": start_str, "end_date": end_str}, + limit=4000, + ) + return _df_to_records(df, _TABLE_COLUMNS["index_weight"]) + + +def fetch_fund_basic(market: Optional[str] = None) -> Iterable[Dict]: + client = _ensure_client() + LOGGER.info( + "拉取基金基础信息(市场:%s)", + market or "全部", + extra=LOG_EXTRA, + ) + params: Dict[str, object] = {} + if market: + params["market"] = market + df = _fetch_paginated("fund_basic", params) + return _df_to_records(df, _TABLE_COLUMNS["fund_basic"]) + + +def fetch_fund_nav(start: date, end: date, ts_code: str) -> Iterable[Dict]: + client = _ensure_client() + start_str = _format_date(start) + end_str = _format_date(end) + LOGGER.info( + "拉取基金净值:%s %s-%s", + ts_code, + start_str, + end_str, + extra=LOG_EXTRA, + ) + df = _fetch_paginated( + "fund_nav", + {"ts_code": ts_code, "start_date": start_str, "end_date": end_str}, + limit=4000, + ) + return _df_to_records(df, _TABLE_COLUMNS["fund_nav"]) + + +def fetch_fut_basic(exchange: Optional[str] = None) -> Iterable[Dict]: + client = _ensure_client() + LOGGER.info( + "拉取期货基础信息(交易所:%s)", + exchange or "全部", + extra=LOG_EXTRA, + ) + df = _fetch_paginated("fut_basic", {"exchange": exchange}) + return _df_to_records(df, _TABLE_COLUMNS["fut_basic"]) + + +def fetch_fut_daily(start: date, end: date, ts_code: str) -> Iterable[Dict]: + client = _ensure_client() + start_str = _format_date(start) + end_str = _format_date(end) + LOGGER.info( + "拉取期货日线:%s %s-%s", + ts_code, + start_str, + end_str, + extra=LOG_EXTRA, + ) + df = _fetch_paginated( + "fut_daily", + {"ts_code": ts_code, "start_date": start_str, "end_date": end_str}, + limit=4000, + ) + return _df_to_records(df, _TABLE_COLUMNS["fut_daily"]) + + +def fetch_fx_daily(start: date, end: date, ts_code: str) -> Iterable[Dict]: + client = _ensure_client() + start_str = _format_date(start) + end_str = _format_date(end) + LOGGER.info( + "拉取外汇日线:%s %s-%s", + ts_code, + start_str, + end_str, + extra=LOG_EXTRA, + ) + df = _fetch_paginated( + "fx_daily", + {"ts_code": ts_code, "start_date": start_str, "end_date": end_str}, + limit=4000, + ) + return _df_to_records(df, _TABLE_COLUMNS["fx_daily"]) + + +def fetch_hk_daily(start: date, end: date, 
ts_code: str) -> Iterable[Dict]: + client = _ensure_client() + start_str = _format_date(start) + end_str = _format_date(end) + LOGGER.info( + "拉取港股日线:%s %s-%s", + ts_code, + start_str, + end_str, + extra=LOG_EXTRA, + ) + df = _fetch_paginated( + "hk_daily", + {"ts_code": ts_code, "start_date": start_str, "end_date": end_str}, + limit=4000, + ) + return _df_to_records(df, _TABLE_COLUMNS["hk_daily"]) + + +def fetch_us_daily(start: date, end: date, ts_code: str) -> Iterable[Dict]: + client = _ensure_client() + start_str = _format_date(start) + end_str = _format_date(end) + LOGGER.info( + "拉取美股日线:%s %s-%s", + ts_code, + start_str, + end_str, + extra=LOG_EXTRA, + ) + df = _fetch_paginated( + "us_daily", + {"ts_code": ts_code, "start_date": start_str, "end_date": end_str}, + limit=4000, + ) + return _df_to_records(df, _TABLE_COLUMNS["us_daily"]) + + +__all__ = [ + "API_DEFAULT_LIMIT", + "API_RATE_LIMITS", + "ETF_CODES", + "FUND_CODES", + "FUTURE_CODES", + "FX_CODES", + "HK_CODES", + "INDEX_CODES", + "US_CODES", + "ensure_stock_basic", + "ensure_trade_calendar", + "fetch_adj_factor", + "fetch_daily_basic", + "fetch_daily_bars", + "fetch_fund_basic", + "fetch_fund_nav", + "fetch_fut_basic", + "fetch_fut_daily", + "fetch_fx_daily", + "fetch_hk_daily", + "fetch_index_basic", + "fetch_index_daily", + "fetch_index_dailybasic", + "fetch_index_weight", + "fetch_stock_basic", + "fetch_stk_limit", + "fetch_suspensions", + "fetch_trade_calendar", + "fetch_us_daily", + "save_records", + "LOG_EXTRA", + "_expected_trading_days", + "_format_date", + "_listing_window", + "_load_trade_dates", + "_record_exists", + "_existing_suspend_dates", +] + diff --git a/app/ingest/coverage.py b/app/ingest/coverage.py new file mode 100644 index 0000000..08c0ee2 --- /dev/null +++ b/app/ingest/coverage.py @@ -0,0 +1,395 @@ +"""Data coverage orchestration separated from TuShare API calls.""" +from __future__ import annotations + +import sqlite3 +from datetime import date +from typing import Callable, Dict, List, Optional, Sequence + +from app.data.schema import initialize_database +from app.utils.db import db_session +from app.utils.logging import get_logger + +from .api_client import ( + ETF_CODES, + FUND_CODES, + FUTURE_CODES, + FX_CODES, + HK_CODES, + INDEX_CODES, + LOG_EXTRA, + US_CODES, + _expected_trading_days, + _format_date, + _listing_window, + ensure_stock_basic, + ensure_trade_calendar, + fetch_adj_factor, + fetch_daily_basic, + fetch_daily_bars, + fetch_fund_basic, + fetch_fund_nav, + fetch_fut_basic, + fetch_fut_daily, + fetch_fx_daily, + fetch_hk_daily, + fetch_index_basic, + fetch_index_daily, + fetch_index_dailybasic, + fetch_index_weight, + fetch_suspensions, + fetch_stk_limit, + fetch_trade_calendar, + fetch_us_daily, + save_records, +) + +LOGGER = get_logger(__name__) + + +def _range_stats( + table: str, + date_col: str, + start_str: str, + end_str: str, + ts_code: str | None = None, +) -> Dict[str, Optional[str]]: + sql = ( + f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d, " + f"COUNT(DISTINCT {date_col}) AS distinct_days FROM {table} " + f"WHERE {date_col} BETWEEN ? AND ?" + ) + params: List[object] = [start_str, end_str] + if ts_code: + sql += " AND ts_code = ?" 
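`_range_stats` answers coverage questions with a single aggregate query: min/max dates plus the distinct-day count. A standalone probe showing how those three numbers drive the refresh decision:

```python
# One aggregate query yields min/max dates and distinct days; the caller
# compares them against the requested range and the expected trading days.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE daily (ts_code TEXT, trade_date TEXT)")
conn.executemany("INSERT INTO daily VALUES (?, ?)",
                 [("600000.SH", d) for d in ("20251014", "20251016")])

min_d, max_d, days = conn.execute(
    "SELECT MIN(trade_date), MAX(trade_date), COUNT(DISTINCT trade_date) "
    "FROM daily WHERE trade_date BETWEEN ? AND ?",
    ("20251013", "20251017"),
).fetchone()

expected_days = 3                               # e.g. from the trade calendar
needs_refresh = (min_d is None or min_d > "20251013"
                 or max_d < "20251017" or days < expected_days)
print(min_d, max_d, days, needs_refresh)        # 20251014 20251016 2 True
```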
+ params.append(ts_code) + try: + with db_session(read_only=True) as conn: + row = conn.execute(sql, tuple(params)).fetchone() + except sqlite3.OperationalError: + return {"min": None, "max": None, "distinct": 0} + return { + "min": row["min_d"] if row else None, + "max": row["max_d"] if row else None, + "distinct": row["distinct_days"] if row else 0, + } + + +def _range_needs_refresh( + table: str, + date_col: str, + start_str: str, + end_str: str, + expected_days: int = 0, + **filters: object, +) -> bool: + ts_code = filters.get("ts_code") or filters.get("index_code") + stats = _range_stats(table, date_col, start_str, end_str, ts_code=ts_code) # type: ignore[arg-type] + if stats["min"] is None or stats["max"] is None: + return True + if stats["min"] > start_str or stats["max"] < end_str: + return True + if expected_days and (stats["distinct"] or 0) < expected_days: + return True + return False + + +def _should_skip_range( + table: str, + date_col: str, + start: date, + end: date, + ts_code: str | None = None, +) -> bool: + start_str = _format_date(start) + end_str = _format_date(end) + + effective_start = start_str + effective_end = end_str + + if ts_code: + list_date, delist_date = _listing_window(ts_code) + if list_date: + effective_start = max(effective_start, list_date) + if delist_date: + effective_end = min(effective_end, delist_date) + if effective_start > effective_end: + LOGGER.debug( + "股票 %s 在目标区间之外,跳过补数", + ts_code, + extra=LOG_EXTRA, + ) + return True + stats = _range_stats(table, date_col, effective_start, effective_end, ts_code=ts_code) + else: + stats = _range_stats(table, date_col, effective_start, effective_end) + + if stats["min"] is None or stats["max"] is None: + return False + if stats["min"] > effective_start or stats["max"] < effective_end: + return False + + if ts_code is None: + expected_days = _expected_trading_days(effective_start, effective_end) + if expected_days and (stats["distinct"] or 0) < expected_days: + return False + + return True + + +def ensure_index_weights(start: date, end: date, index_codes: Optional[Sequence[str]] = None) -> None: + if index_codes is None: + index_codes = [code for code in INDEX_CODES if code.endswith(".SH") or code.endswith(".SZ")] + + for index_code in index_codes: + start_str = _format_date(start) + end_str = _format_date(end) + if _range_needs_refresh("index_weight", "trade_date", start_str, end_str, index_code=index_code): + LOGGER.info("指数 %s 的成分股权重数据不完整,开始拉取 %s-%s", index_code, start_str, end_str) + save_records("index_weight", fetch_index_weight(start, end, index_code)) + else: + LOGGER.info("指数 %s 的成分股权重数据已完整,跳过", index_code) + + +def ensure_index_dailybasic(start: date, end: date, index_codes: Optional[Sequence[str]] = None) -> None: + if index_codes is None: + index_codes = [code for code in INDEX_CODES if code.endswith(".SH") or code.endswith(".SZ")] + + for index_code in index_codes: + start_str = _format_date(start) + end_str = _format_date(end) + if _range_needs_refresh("index_dailybasic", "trade_date", start_str, end_str, ts_code=index_code): + LOGGER.info("指数 %s 的每日指标数据不完整,开始拉取 %s-%s", index_code, start_str, end_str) + save_records("index_dailybasic", fetch_index_dailybasic(start, end, index_code)) + else: + LOGGER.info("指数 %s 的每日指标数据已完整,跳过", index_code) + + +def ensure_data_coverage( + start: date, + end: date, + ts_codes: Optional[Sequence[str]] = None, + include_limits: bool = True, + include_extended: bool = True, + force: bool = False, + progress_hook: Callable[[str, float], None] | None = None, +) -> 
+def ensure_data_coverage(
+    start: date,
+    end: date,
+    ts_codes: Optional[Sequence[str]] = None,
+    include_limits: bool = True,
+    include_extended: bool = True,
+    force: bool = False,
+    progress_hook: Callable[[str, float], None] | None = None,
+) -> None:
+    initialize_database()
+    start_str = _format_date(start)
+    end_str = _format_date(end)
+
+    # Seven unconditional advance() steps below; the extended branch adds five
+    # more and the limit branch adds one, so the denominator must match.
+    extra_steps = 0
+    if include_limits:
+        extra_steps += 1
+    if include_extended:
+        extra_steps += 5
+    total_steps = 7 + extra_steps
+    current_step = 0
+
+    def advance(message: str) -> None:
+        nonlocal current_step
+        current_step += 1
+        progress = min(current_step / total_steps, 1.0)
+        if progress_hook:
+            progress_hook(message, progress)
+        LOGGER.info(message, extra=LOG_EXTRA)
+
+    advance("准备股票基础信息与交易日历")
+    ensure_stock_basic()
+    ensure_trade_calendar(start, end)
+
+    codes = tuple(dict.fromkeys(ts_codes)) if ts_codes else tuple()
+    expected_days = _expected_trading_days(start_str, end_str)
+
+    advance("处理日线行情数据")
+    if codes:
+        pending_codes: List[str] = []
+        for code in codes:
+            if not force and _should_skip_range("daily", "trade_date", start, end, code):
+                LOGGER.info("股票 %s 的日线已覆盖 %s-%s,跳过", code, start_str, end_str)
+                continue
+            pending_codes.append(code)
+        if pending_codes:
+            LOGGER.info("开始拉取日线行情:%s-%s(待补股票 %d 支)", start_str, end_str, len(pending_codes))
+            save_records(
+                "daily",
+                fetch_daily_bars(start, end, pending_codes, skip_existing=not force),
+            )
+    else:
+        needs_daily = force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days)
+        if not needs_daily:
+            LOGGER.info("日线数据已覆盖 %s-%s,跳过拉取", start_str, end_str)
+        else:
+            LOGGER.info("开始拉取日线行情:%s-%s", start_str, end_str)
+            save_records(
+                "daily",
+                fetch_daily_bars(start, end, skip_existing=not force),
+            )
+
+    advance("处理指数成分股权重数据")
+    ensure_index_weights(start, end)
+
+    advance("处理指数每日指标数据")
+    ensure_index_dailybasic(start, end)
+
+    date_cols = {
+        "daily_basic": "trade_date",
+        "adj_factor": "trade_date",
+        "stk_limit": "trade_date",
+        "suspend": "suspend_date",
+        "index_daily": "trade_date",
+        "index_dailybasic": "trade_date",
+        "index_weight": "trade_date",
+        "fund_nav": "nav_date",
+        "fut_daily": "trade_date",
+        "fx_daily": "trade_date",
+        "hk_daily": "trade_date",
+        "us_daily": "trade_date",
+    }
+
+    def _save_with_codes(table: str, fetch_fn) -> None:
+        date_col = date_cols.get(table, "trade_date")
+        if codes:
+            for code in codes:
+                if not force and _should_skip_range(table, date_col, start, end, code):
+                    LOGGER.info("表 %s 股票 %s 已覆盖 %s-%s,跳过", table, code, start_str, end_str)
+                    continue
+                LOGGER.info("拉取 %s 表数据(股票:%s)%s-%s", table, code, start_str, end_str)
+                rows = fetch_fn(start, end, ts_code=code, skip_existing=not force)
+                save_records(table, rows)
+        else:
+            needs_refresh = force or table == "suspend"
+            if not force and table != "suspend":
+                expected = expected_days if table in {"daily_basic", "adj_factor", "stk_limit"} else 0
+                needs_refresh = _range_needs_refresh(table, date_col, start_str, end_str, expected)
+            if not needs_refresh:
+                LOGGER.info("表 %s 已覆盖 %s-%s,跳过", table, start_str, end_str)
+                return
+            LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
+            rows = fetch_fn(start, end, skip_existing=not force)
+            save_records(table, rows)
+
+    advance("处理日线基础指标数据")
+    _save_with_codes("daily_basic", fetch_daily_basic)
+
+    advance("处理复权因子数据")
+    _save_with_codes("adj_factor", fetch_adj_factor)
+
+    if include_limits:
+        advance("处理涨跌停价格数据")
+        _save_with_codes("stk_limit", fetch_stk_limit)
+
+    advance("处理停复牌信息")
+    _save_with_codes("suspend", fetch_suspensions)
+
+    if include_extended:
+        advance("同步指数/基金/期货基础信息")
+        save_records("index_basic", fetch_index_basic())
+        save_records("fund_basic", fetch_fund_basic())
+        save_records("fut_basic", fetch_fut_basic())
+
+        advance("拉取指数行情数据")
+        for code in INDEX_CODES:
+            if not force and _should_skip_range("index_daily", "trade_date", start, end, code):
+                LOGGER.info("指数 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
+                continue
+            save_records("index_daily", fetch_index_daily(start, end, code))
+
+        advance("拉取基金净值数据")
+        fund_targets = tuple(dict.fromkeys(ETF_CODES + FUND_CODES))
+        for code in fund_targets:
+            if not force and _should_skip_range("fund_nav", "nav_date", start, end, code):
+                LOGGER.info("基金 %s 净值已覆盖 %s-%s,跳过", code, start_str, end_str)
+                continue
+            save_records("fund_nav", fetch_fund_nav(start, end, code))
+
+        advance("拉取期货/外汇行情数据")
+        for code in FUTURE_CODES:
+            if not force and _should_skip_range("fut_daily", "trade_date", start, end, code):
+                LOGGER.info("期货 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
+                continue
+            save_records("fut_daily", fetch_fut_daily(start, end, code))
+        for code in FX_CODES:
+            if not force and _should_skip_range("fx_daily", "trade_date", start, end, code):
+                LOGGER.info("外汇 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
+                continue
+            save_records("fx_daily", fetch_fx_daily(start, end, code))
+
+        advance("拉取港/美股行情数据")
+        for code in HK_CODES:
+            if not force and _should_skip_range("hk_daily", "trade_date", start, end, code):
+                LOGGER.info("港股 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
+                continue
+            save_records("hk_daily", fetch_hk_daily(start, end, code))
+        for code in US_CODES:
+            if not force and _should_skip_range("us_daily", "trade_date", start, end, code):
+                LOGGER.info("美股 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
+                continue
+            save_records("us_daily", fetch_us_daily(start, end, code))
+
+    if progress_hook:
+        progress_hook("数据覆盖检查完成", 1.0)
+
+
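# --- Illustrative sketch (editor's note, not part of the patch) ---------------
# ensure_data_coverage reports progress through the optional callback; each
# advance() call passes a message and a fraction in (0, 1]. A minimal hook,
# assuming the module path introduced by this patch:
from datetime import date

from app.ingest.coverage import ensure_data_coverage

def log_progress(message: str, fraction: float) -> None:
    # Render "[ 42%] 处理日线行情数据"-style lines; any callable works here.
    print(f"[{fraction:4.0%}] {message}")

ensure_data_coverage(
    date(2024, 1, 2),
    date(2024, 1, 31),
    ts_codes=("600519.SH",),  # hypothetical single-stock backfill
    include_extended=False,   # skip index/fund/futures extras for speed
    progress_hook=log_progress,
)
# ------------------------------------------------------------------------------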
+def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]]:
+    start_str = _format_date(start)
+    end_str = _format_date(end)
+    expected_days = _expected_trading_days(start_str, end_str)
+
+    coverage: Dict[str, Dict[str, object]] = {
+        "period": {
+            "start": start_str,
+            "end": end_str,
+            "expected_trading_days": expected_days,
+        }
+    }
+
+    def add_table(name: str, date_col: str, require_days: bool = True) -> None:
+        stats = _range_stats(name, date_col, start_str, end_str)
+        coverage[name] = {
+            "min": stats["min"],
+            "max": stats["max"],
+            "distinct_days": stats["distinct"],
+            "meets_expectation": (
+                stats["min"] is not None
+                and stats["max"] is not None
+                and stats["min"] <= start_str
+                and stats["max"] >= end_str
+                and ((not require_days) or (stats["distinct"] or 0) >= expected_days)
+            ),
+        }
+
+    add_table("daily", "trade_date")
+    add_table("daily_basic", "trade_date")
+    add_table("adj_factor", "trade_date")
+    add_table("stk_limit", "trade_date")
+    add_table("suspend", "suspend_date", require_days=False)
+    add_table("index_daily", "trade_date")
+    add_table("fund_nav", "nav_date", require_days=False)
+    add_table("fut_daily", "trade_date", require_days=False)
+    add_table("fx_daily", "trade_date", require_days=False)
+    add_table("hk_daily", "trade_date", require_days=False)
+    add_table("us_daily", "trade_date", require_days=False)
+
+    with db_session(read_only=True) as conn:
+        stock_tot = conn.execute("SELECT COUNT(*) AS cnt FROM stock_basic").fetchone()
+        stock_sse = conn.execute(
+            "SELECT COUNT(*) AS cnt FROM stock_basic WHERE exchange = 'SSE' AND list_status = 'L'"
+        ).fetchone()
+        stock_szse = conn.execute(
+            "SELECT COUNT(*) AS cnt FROM stock_basic WHERE exchange = 'SZSE' AND list_status = 'L'"
+        ).fetchone()
+    coverage["stock_basic"] = {
+        "total": stock_tot["cnt"] if stock_tot else 0,
+        "sse_listed": stock_sse["cnt"] if stock_sse else 0,
+        "szse_listed": stock_szse["cnt"] if stock_szse else 0,
+    }
+
+    return coverage
+
+
+__all__ = [
+    "collect_data_coverage",
+    "ensure_data_coverage",
+    "ensure_index_dailybasic",
+    "ensure_index_weights",
+]
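# --- Illustrative sketch (editor's note, not part of the patch) ---------------
# collect_data_coverage returns one dict per table plus a "period" header; the
# "meets_expectation" flag encodes the same endpoint/row-count test used by the
# refresh helpers. A hedged example of consuming the report:
from datetime import date

from app.ingest.coverage import collect_data_coverage

report = collect_data_coverage(date(2024, 1, 2), date(2024, 1, 31))
period = report["period"]
print(f"window {period['start']}-{period['end']}, "
      f"{period['expected_trading_days']} trading days expected")
for table, stats in report.items():
    if table in ("period", "stock_basic"):
        continue
    flag = "ok" if stats["meets_expectation"] else "GAP"
    print(f"{table:<12} {stats['min']}..{stats['max']} "
          f"distinct={stats['distinct_days']} [{flag}]")
# ------------------------------------------------------------------------------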
diff --git a/app/ingest/tushare.py b/app/ingest/tushare.py
index f1f1170..de4b5ae 100644
--- a/app/ingest/tushare.py
+++ b/app/ingest/tushare.py
@@ -1,233 +1,21 @@
-"""TuShare 数据拉取与数据覆盖检查工具。"""
+"""Ingestion job orchestrator wrapping TuShare utilities."""
 from __future__ import annotations
 
-import os
-import sqlite3
-import time
-from collections import defaultdict, deque
 from dataclasses import dataclass
 from datetime import date
-from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple
+from typing import Callable, Iterable, List, Optional, Sequence
 
-import pandas as pd
-
-try:
-    import tushare as ts
-except ImportError:  # pragma: no cover - 运行时提示
-    ts = None  # type: ignore[assignment]
-
-from app.utils import alerts
-from app.utils.config import get_config
-from app.utils.db import db_session
-from app.data.schema import initialize_database
-from app.utils.logging import get_logger
 from app.features.factors import compute_factor_range
+from app.utils import alerts
+from app.utils.logging import get_logger
+from .api_client import LOG_EXTRA
+from .coverage import collect_data_coverage, ensure_data_coverage
+from .job_logger import JobLogger
 
 LOGGER = get_logger(__name__)
 
-API_DEFAULT_LIMIT = 5000
-LOG_EXTRA = {"stage": "data_ingest"}
-
-_CALL_BUCKETS: Dict[str, deque] = defaultdict(deque)
-
-RATE_LIMIT_ERROR_PATTERNS: Tuple[str, ...] = (
-    "最多访问该接口",
-    "超过接口限制",
-    "Frequency limit",
-)
-
-API_RATE_LIMITS: Dict[str, int] = {
-    "stock_basic": 180,
-    "daily": 480,
-    "daily_basic": 200,
-    "adj_factor": 200,
-    "suspend_d": 180,
-    "suspend": 180,
-    "stk_limit": 200,
-    "trade_cal": 200,
-    "index_basic": 120,
-    "index_daily": 240,
-    "fund_basic": 120,
-    "fund_nav": 200,
-    "fut_basic": 120,
-    "fut_daily": 200,
-    "fx_daily": 200,
-    "hk_daily": 2,
-    "us_daily": 200,
-}
-
-
-INDEX_CODES: Tuple[str, ...] = (
-    "000001.SH",  # 上证综指
-    "000300.SH",  # 沪深300
-    "000016.SH",  # 上证50
-    "000905.SH",  # 中证500
-    "399001.SZ",  # 深证成指
-    "399005.SZ",  # 中小板指
-    "399006.SZ",  # 创业板指
-    "HSI.HI",  # 恒生指数
-    "SPX.GI",  # 标普500
-    "DJI.GI",  # 道琼斯工业指数
-    "IXIC.GI",  # 纳斯达克综合指数
-    "GDAXI.GI",  # 德国DAX
-    "FTSE.GI",  # 英国富时100
-)
-
-ETF_CODES: Tuple[str, ...] = (
-    "510300.SH",  # 华泰柏瑞沪深300ETF
-    "510500.SH",  # 南方中证500ETF
-    "159915.SZ",  # 易方达创业板ETF
-)
-
-FUND_CODES: Tuple[str, ...] = (
-    "000001.OF",  # 华夏成长
-    "110022.OF",  # 易方达消费行业
-)
-
-FUTURE_CODES: Tuple[str, ...] = (
-    "IF9999.CFE",  # 沪深300股指期货主力
-    "IC9999.CFE",  # 中证500股指期货主力
-    "IH9999.CFE",  # 上证50股指期货主力
-)
-
-FX_CODES: Tuple[str, ...] = (
-    "USDCNY",  # 美元人民币
-    "EURCNY",  # 欧元人民币
-)
-
-HK_CODES: Tuple[str, ...] = (
-    "00700.HK",  # 腾讯控股
-    "00941.HK",  # 中国移动
-    "09618.HK",  # 京东集团-SW
-    "09988.HK",  # 阿里巴巴-SW
-    "03690.HK",  # 美团-W
-)
-
-US_CODES: Tuple[str, ...] 
= ( - "AAPL.O", # 苹果 - "MSFT.O", # 微软 - "BABA.N", # 阿里巴巴美股 - "JD.O", # 京东美股 - "PDD.O", # 拼多多 - "BIDU.O", # 百度 - "BILI.O", # 哔哩哔哩 -) - - -def _normalize_date_str(value: Optional[str]) -> Optional[str]: - if value is None: - return None - text = str(value).strip() - return text or None - - -def _respect_rate_limit(endpoint: str | None) -> None: - def _throttle(queue: deque, limit: int) -> None: - if limit <= 0: - return - now = time.time() - window = 60.0 - while queue and now - queue[0] > window: - queue.popleft() - if len(queue) >= limit: - sleep_time = window - (now - queue[0]) + 0.1 - LOGGER.debug( - "触发限频控制(limit=%s)休眠 %.2f 秒 endpoint=%s", - limit, - sleep_time, - endpoint, - extra=LOG_EXTRA, - ) - time.sleep(max(0.1, sleep_time)) - queue.append(time.time()) - - bucket_key = endpoint or "_default" - endpoint_limit = API_RATE_LIMITS.get(bucket_key, 60) - _throttle(_CALL_BUCKETS[bucket_key], endpoint_limit or 0) - - -def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]: - if df is None or df.empty: - return [] - reindexed = df.reindex(columns=allowed_cols) - return reindexed.where(pd.notnull(reindexed), None).to_dict("records") - - -def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame: - client = _ensure_client() - limit = limit or API_DEFAULT_LIMIT - frames: List[pd.DataFrame] = [] - offset = 0 - clean_params = {k: v for k, v in params.items() if v is not None} - LOGGER.info( - "开始调用 TuShare 接口:%s,参数=%s,limit=%s", - endpoint, - clean_params, - limit, - extra=LOG_EXTRA, - ) - while True: - _respect_rate_limit(endpoint) - call = getattr(client, endpoint) - try: - df = call(limit=limit, offset=offset, **clean_params) - except Exception as exc: # noqa: BLE001 - message = str(exc) - if any(pattern in message for pattern in RATE_LIMIT_ERROR_PATTERNS): - per_minute = API_RATE_LIMITS.get(endpoint or "", 0) - wait_time = 60.0 / per_minute + 1 if per_minute else 30.0 - wait_time = max(wait_time, 30.0) - LOGGER.warning( - "接口限频触发:%s,原因=%s,等待 %.1f 秒后重试", - endpoint, - message, - wait_time, - extra=LOG_EXTRA, - ) - time.sleep(wait_time) - continue - - LOGGER.exception( - "TuShare 接口调用异常:endpoint=%s offset=%s params=%s", - endpoint, - offset, - clean_params, - extra=LOG_EXTRA, - ) - raise - if df is None or df.empty: - LOGGER.debug( - "TuShare 返回空数据:endpoint=%s offset=%s", - endpoint, - offset, - extra=LOG_EXTRA, - ) - break - LOGGER.debug( - "TuShare 返回 %s 行:endpoint=%s offset=%s", - len(df), - endpoint, - offset, - extra=LOG_EXTRA, - ) - frames.append(df) - if len(df) < limit: - break - offset += limit - if not frames: - return pd.DataFrame() - merged = pd.concat(frames, ignore_index=True) - LOGGER.info( - "TuShare 调用完成:endpoint=%s 总行数=%s", - endpoint, - len(merged), - extra=LOG_EXTRA, - ) - return merged - - -from .job_logger import JobLogger +PostTask = Callable[["FetchJob"], None] @dataclass @@ -238,1597 +26,81 @@ class FetchJob: granularity: str = "daily" ts_codes: Optional[Sequence[str]] = None - -_TABLE_SCHEMAS: Dict[str, str] = { - "stock_basic": """ - CREATE TABLE IF NOT EXISTS stock_basic ( - ts_code TEXT PRIMARY KEY, - symbol TEXT, - name TEXT, - area TEXT, - industry TEXT, - market TEXT, - exchange TEXT, - list_status TEXT, - list_date TEXT, - delist_date TEXT - ); - """, - "daily": """ - CREATE TABLE IF NOT EXISTS daily ( - ts_code TEXT, - trade_date TEXT, - open REAL, - high REAL, - low REAL, - close REAL, - pre_close REAL, - change REAL, - pct_chg REAL, - vol REAL, - amount REAL, - PRIMARY KEY (ts_code, 
trade_date) - ); - """, - "daily_basic": """ - CREATE TABLE IF NOT EXISTS daily_basic ( - ts_code TEXT, - trade_date TEXT, - close REAL, - turnover_rate REAL, - turnover_rate_f REAL, - volume_ratio REAL, - pe REAL, - pe_ttm REAL, - pb REAL, - ps REAL, - ps_ttm REAL, - dv_ratio REAL, - dv_ttm REAL, - total_share REAL, - float_share REAL, - free_share REAL, - total_mv REAL, - circ_mv REAL, - PRIMARY KEY (ts_code, trade_date) - ); - """, - "adj_factor": """ - CREATE TABLE IF NOT EXISTS adj_factor ( - ts_code TEXT, - trade_date TEXT, - adj_factor REAL, - PRIMARY KEY (ts_code, trade_date) - ); - """, - "suspend": """ - CREATE TABLE IF NOT EXISTS suspend ( - ts_code TEXT, - suspend_date TEXT, - resume_date TEXT, - suspend_type TEXT, - ann_date TEXT, - suspend_timing TEXT, - resume_timing TEXT, - reason TEXT, - PRIMARY KEY (ts_code, suspend_date) - ); - """, - "trade_calendar": """ - CREATE TABLE IF NOT EXISTS trade_calendar ( - exchange TEXT, - cal_date TEXT, - is_open INTEGER, - pretrade_date TEXT, - PRIMARY KEY (exchange, cal_date) - ); - """, - "stk_limit": """ - CREATE TABLE IF NOT EXISTS stk_limit ( - ts_code TEXT, - trade_date TEXT, - up_limit REAL, - down_limit REAL, - PRIMARY KEY (ts_code, trade_date) - ); - """, - "index_basic": """ - CREATE TABLE IF NOT EXISTS index_basic ( - ts_code TEXT PRIMARY KEY, - name TEXT, - fullname TEXT, - market TEXT, - publisher TEXT, - index_type TEXT, - category TEXT, - base_date TEXT, - base_point REAL, - list_date TEXT, - weight_rule TEXT, - desc TEXT, - exp_date TEXT - ); - """, - "index_daily": """ - CREATE TABLE IF NOT EXISTS index_daily ( - ts_code TEXT, - trade_date TEXT, - close REAL, - open REAL, - high REAL, - low REAL, - pre_close REAL, - change REAL, - pct_chg REAL, - vol REAL, - amount REAL, - PRIMARY KEY (ts_code, trade_date) - ); - """, - "index_dailybasic": """ - CREATE TABLE IF NOT EXISTS index_dailybasic ( - ts_code TEXT, - trade_date TEXT, - turnover REAL, - turnover_ratio REAL, - pe_ttm REAL, - pb REAL, - ps_ttm REAL, - dv_ttm REAL, - total_mv REAL, - circ_mv REAL, - PRIMARY KEY (ts_code, trade_date) - ); - """, - "index_weight": """ - CREATE TABLE IF NOT EXISTS index_weight ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - index_code VARCHAR(10) NOT NULL, - trade_date VARCHAR(8) NOT NULL, - ts_code VARCHAR(10) NOT NULL, - weight FLOAT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ); - """, - "fund_basic": """ - CREATE TABLE IF NOT EXISTS fund_basic ( - ts_code TEXT PRIMARY KEY, - name TEXT, - management TEXT, - custodian TEXT, - fund_type TEXT, - found_date TEXT, - due_date TEXT, - list_date TEXT, - issue_date TEXT, - delist_date TEXT, - issue_amount REAL, - m_fee REAL, - c_fee REAL, - benchmark TEXT, - status TEXT, - invest_type TEXT, - type TEXT, - trustee TEXT, - purc_start_date TEXT, - redm_start_date TEXT, - market TEXT - ); - """, - "fund_nav": """ - CREATE TABLE IF NOT EXISTS fund_nav ( - ts_code TEXT, - nav_date TEXT, - ann_date TEXT, - unit_nav REAL, - accum_nav REAL, - accum_div REAL, - net_asset REAL, - total_netasset REAL, - adj_nav REAL, - update_flag TEXT, - PRIMARY KEY (ts_code, nav_date) - ); - """, - "fut_basic": """ - CREATE TABLE IF NOT EXISTS fut_basic ( - ts_code TEXT PRIMARY KEY, - symbol TEXT, - name TEXT, - exchange TEXT, - exchange_full_name TEXT, - product TEXT, - product_name TEXT, - variety TEXT, - list_date TEXT, - delist_date TEXT, - trade_unit REAL, - per_unit REAL, - quote_unit TEXT, - settle_month TEXT, - contract_size REAL, - tick_size REAL, - margin_rate REAL, - margin_ratio REAL, - delivery_month 
TEXT, - delivery_day TEXT - ); - """, - "fut_daily": """ - CREATE TABLE IF NOT EXISTS fut_daily ( - ts_code TEXT, - trade_date TEXT, - pre_settle REAL, - open REAL, - high REAL, - low REAL, - close REAL, - settle REAL, - change1 REAL, - change2 REAL, - vol REAL, - amount REAL, - oi REAL, - oi_chg REAL, - PRIMARY KEY (ts_code, trade_date) - ); - """, - "fx_daily": """ - CREATE TABLE IF NOT EXISTS fx_daily ( - ts_code TEXT, - trade_date TEXT, - bid REAL, - ask REAL, - mid REAL, - high REAL, - low REAL, - PRIMARY KEY (ts_code, trade_date) - ); - """, - "hk_daily": """ - CREATE TABLE IF NOT EXISTS hk_daily ( - ts_code TEXT, - trade_date TEXT, - close REAL, - open REAL, - high REAL, - low REAL, - pre_close REAL, - change REAL, - pct_chg REAL, - vol REAL, - amount REAL, - exchange TEXT, - PRIMARY KEY (ts_code, trade_date) - ); - """, - "us_daily": """ - CREATE TABLE IF NOT EXISTS us_daily ( - ts_code TEXT, - trade_date TEXT, - close REAL, - open REAL, - high REAL, - low REAL, - pre_close REAL, - change REAL, - pct_chg REAL, - vol REAL, - amount REAL, - PRIMARY KEY (ts_code, trade_date) - ); - """, -} - -_TABLE_COLUMNS: Dict[str, List[str]] = { - "stock_basic": [ - "ts_code", - "symbol", - "name", - "area", - "industry", - "market", - "exchange", - "list_status", - "list_date", - "delist_date", - ], - "daily": [ - "ts_code", - "trade_date", - "open", - "high", - "low", - "close", - "pre_close", - "change", - "pct_chg", - "vol", - "amount", - ], - "daily_basic": [ - "ts_code", - "trade_date", - "close", - "turnover_rate", - "turnover_rate_f", - "volume_ratio", - "pe", - "pe_ttm", - "pb", - "ps", - "ps_ttm", - "dv_ratio", - "dv_ttm", - "total_share", - "float_share", - "free_share", - "total_mv", - "circ_mv", - ], - "adj_factor": [ - "ts_code", - "trade_date", - "adj_factor", - ], - "suspend": [ - "ts_code", - "suspend_date", - "resume_date", - "suspend_type", - "ann_date", - "suspend_timing", - "resume_timing", - "reason", - ], - "trade_calendar": [ - "exchange", - "cal_date", - "is_open", - "pretrade_date", - ], - "stk_limit": [ - "ts_code", - "trade_date", - "up_limit", - "down_limit", - ], - "index_basic": [ - "ts_code", - "name", - "fullname", - "market", - "publisher", - "index_type", - "category", - "base_date", - "base_point", - "list_date", - "weight_rule", - "desc", - "exp_date", - ], - "index_daily": [ - "ts_code", - "trade_date", - "close", - "open", - "high", - "low", - "pre_close", - "change", - "pct_chg", - "vol", - "amount", - ], - "index_dailybasic": [ - "ts_code", - "trade_date", - "turnover", - "turnover_ratio", - "pe_ttm", - "pb", - "ps_ttm", - "dv_ttm", - "total_mv", - "circ_mv", - ], - "index_weight": [ - "index_code", - "trade_date", - "ts_code", - "weight", - ], - "fund_basic": [ - "ts_code", - "name", - "management", - "custodian", - "fund_type", - "found_date", - "due_date", - "list_date", - "issue_date", - "delist_date", - "issue_amount", - "m_fee", - "c_fee", - "benchmark", - "status", - "invest_type", - "type", - "trustee", - "purc_start_date", - "redm_start_date", - "market", - ], - "fund_nav": [ - "ts_code", - "nav_date", - "ann_date", - "unit_nav", - "accum_nav", - "accum_div", - "net_asset", - "total_netasset", - "adj_nav", - "update_flag", - ], - "fut_basic": [ - "ts_code", - "symbol", - "name", - "exchange", - "exchange_full_name", - "product", - "product_name", - "variety", - "list_date", - "delist_date", - "trade_unit", - "per_unit", - "quote_unit", - "settle_month", - "contract_size", - "tick_size", - "margin_rate", - "margin_ratio", - "delivery_month", - 
"delivery_day", - ], - "fut_daily": [ - "ts_code", - "trade_date", - "pre_settle", - "open", - "high", - "low", - "close", - "settle", - "change1", - "change2", - "vol", - "amount", - "oi", - "oi_chg", - ], - "fx_daily": [ - "ts_code", - "trade_date", - "bid", - "ask", - "mid", - "high", - "low", - ], - "hk_daily": [ - "ts_code", - "trade_date", - "close", - "open", - "high", - "low", - "pre_close", - "change", - "pct_chg", - "vol", - "amount", - "exchange", - ], - "us_daily": [ - "ts_code", - "trade_date", - "close", - "open", - "high", - "low", - "pre_close", - "change", - "pct_chg", - "vol", - "amount", - ], -} + def as_dict(self) -> dict: + return { + "name": self.name, + "start": str(self.start), + "end": str(self.end), + "granularity": self.granularity, + "codes": list(self.ts_codes or ()), + } -def _ensure_client(): - if ts is None: - raise RuntimeError("未安装 tushare,请先在环境中安装 tushare 包") - token = get_config().tushare_token or os.getenv("TUSHARE_TOKEN") - if not token: - raise RuntimeError("未配置 TuShare Token,请在配置文件或环境变量 TUSHARE_TOKEN 中设置") - if not hasattr(_ensure_client, "_client") or _ensure_client._client is None: # type: ignore[attr-defined] - ts.set_token(token) - _ensure_client._client = ts.pro_api(token) # type: ignore[attr-defined] - LOGGER.info("完成 TuShare 客户端初始化") - return _ensure_client._client # type: ignore[attr-defined] - - -def _format_date(value: date) -> str: - return value.strftime("%Y%m%d") - - -def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str]: - start_str = _format_date(start) - end_str = _format_date(end) - query = ( - "SELECT cal_date FROM trade_calendar " - "WHERE exchange = ? AND cal_date BETWEEN ? AND ? AND is_open = 1 ORDER BY cal_date" - ) - with db_session(read_only=True) as conn: - rows = conn.execute(query, (exchange, start_str, end_str)).fetchall() - return [row["cal_date"] for row in rows] - - -def _record_exists( - table: str, - date_col: str, - trade_date: str, - ts_code: Optional[str] = None, -) -> bool: - query = f"SELECT 1 FROM {table} WHERE {date_col} = ?" - params: Tuple = (trade_date,) - if ts_code: - query += " AND ts_code = ?" 
- params = (trade_date, ts_code) - with db_session(read_only=True) as conn: - row = conn.execute(query, params).fetchone() - return row is not None - - -def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool: - start_str = _format_date(start) - end_str = _format_date(end) - - effective_start = start_str - effective_end = end_str - - if ts_code: - list_date, delist_date = _listing_window(ts_code) - if list_date: - effective_start = max(effective_start, list_date) - if delist_date: - effective_end = min(effective_end, delist_date) - if effective_start > effective_end: - LOGGER.debug( - "股票 %s 在目标区间之外,跳过补数", - ts_code, - extra=LOG_EXTRA, - ) - return True - stats = _range_stats(table, date_col, effective_start, effective_end, ts_code=ts_code) - else: - stats = _range_stats(table, date_col, effective_start, effective_end) - - if stats["min"] is None or stats["max"] is None: - return False - if stats["min"] > effective_start or stats["max"] < effective_end: - return False - - if ts_code is None: - expected_days = _expected_trading_days(effective_start, effective_end) - if expected_days and (stats["distinct"] or 0) < expected_days: - return False - - return True - - -def _range_stats( - table: str, - date_col: str, - start_str: str, - end_str: str, - ts_code: str | None = None, -) -> Dict[str, Optional[str]]: - sql = ( - f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d, " - f"COUNT(DISTINCT {date_col}) AS distinct_days FROM {table} " - f"WHERE {date_col} BETWEEN ? AND ?" - ) - params: List[object] = [start_str, end_str] - if ts_code: - sql += " AND ts_code = ?" - params.append(ts_code) - try: - with db_session(read_only=True) as conn: - row = conn.execute(sql, tuple(params)).fetchone() - except sqlite3.OperationalError: - return {"min": None, "max": None, "distinct": 0} - return { - "min": row["min_d"] if row else None, - "max": row["max_d"] if row else None, - "distinct": row["distinct_days"] if row else 0, - } - - -def _range_needs_refresh( - table: str, - date_col: str, - start_str: str, - end_str: str, - expected_days: int = 0, -) -> bool: - stats = _range_stats(table, date_col, start_str, end_str) - if stats["min"] is None or stats["max"] is None: - return True - if stats["min"] > start_str or stats["max"] < end_str: - return True - if expected_days and (stats["distinct"] or 0) < expected_days: - return True - return False - - -def _existing_suspend_dates(start_str: str, end_str: str, ts_code: str | None = None) -> Set[str]: - sql = "SELECT DISTINCT suspend_date FROM suspend WHERE suspend_date BETWEEN ? AND ?" - params: List[object] = [start_str, end_str] - if ts_code: - sql += " AND ts_code = ?" 
- params.append(ts_code) - try: - with db_session(read_only=True) as conn: - rows = conn.execute(sql, tuple(params)).fetchall() - except sqlite3.OperationalError: - return set() - return {row["suspend_date"] for row in rows if row["suspend_date"]} - - -def _listing_window(ts_code: str) -> Tuple[Optional[str], Optional[str]]: - with db_session(read_only=True) as conn: - row = conn.execute( - "SELECT list_date, delist_date FROM stock_basic WHERE ts_code = ?", - (ts_code,), - ).fetchone() - if not row: - return None, None - return _normalize_date_str(row["list_date"]), _normalize_date_str(row["delist_date"]) # type: ignore[index] - - -def _calendar_needs_refresh(exchange: str, start_str: str, end_str: str) -> bool: - sql = """ - SELECT MIN(cal_date) AS min_d, MAX(cal_date) AS max_d, COUNT(*) AS cnt - FROM trade_calendar - WHERE exchange = ? AND cal_date BETWEEN ? AND ? - """ - with db_session(read_only=True) as conn: - row = conn.execute(sql, (exchange, start_str, end_str)).fetchone() - if row is None or row["min_d"] is None: - return True - if row["min_d"] > start_str or row["max_d"] < end_str: - return True - # 交易日历允许不连续(节假日),此处不比较天数 - return False - - -def _expected_trading_days(start_str: str, end_str: str, exchange: str = "SSE") -> int: - sql = """ - SELECT COUNT(*) AS cnt - FROM trade_calendar - WHERE exchange = ? AND cal_date BETWEEN ? AND ? AND is_open = 1 - """ - with db_session(read_only=True) as conn: - row = conn.execute(sql, (exchange, start_str, end_str)).fetchone() - return int(row["cnt"]) if row and row["cnt"] is not None else 0 - - -def fetch_stock_basic(exchange: Optional[str] = None, list_status: str = "L") -> Iterable[Dict]: - client = _ensure_client() - LOGGER.info( - "拉取股票基础信息(交易所:%s,状态:%s)", - exchange or "全部", - list_status, - extra=LOG_EXTRA, - ) - _respect_rate_limit("stock_basic") - fields = "ts_code,symbol,name,area,industry,market,exchange,list_status,list_date,delist_date" - df = client.stock_basic(exchange=exchange, list_status=list_status, fields=fields) - return _df_to_records(df, _TABLE_COLUMNS["stock_basic"]) - - -def fetch_daily_bars(job: FetchJob, skip_existing: bool = True) -> Iterable[Dict]: - client = _ensure_client() - frames: List[pd.DataFrame] = [] - +def _default_post_tasks(job: FetchJob) -> List[PostTask]: if job.granularity != "daily": - raise ValueError(f"暂不支持的粒度:{job.granularity}") - - trade_dates = _load_trade_dates(job.start, job.end) - if not trade_dates: - LOGGER.info("本地交易日历缺失,尝试补全后再拉取日线行情", extra=LOG_EXTRA) - ensure_trade_calendar(job.start, job.end) - trade_dates = _load_trade_dates(job.start, job.end) - - if job.ts_codes: - for code in job.ts_codes: - for trade_date in trade_dates: - if skip_existing and _record_exists("daily", "trade_date", trade_date, code): - LOGGER.debug( - "日线数据已存在,跳过 %s %s", - code, - trade_date, - extra=LOG_EXTRA, - ) - continue - LOGGER.debug( - "按交易日拉取日线行情:code=%s trade_date=%s", - code, - trade_date, - extra=LOG_EXTRA, - ) - LOGGER.info( - "交易日拉取请求:endpoint=daily code=%s trade_date=%s", - code, - trade_date, - extra=LOG_EXTRA, - ) - df = _fetch_paginated( - "daily", - { - "trade_date": trade_date, - "ts_code": code, - }, - ) - if not df.empty: - frames.append(df) - else: - for trade_date in trade_dates: - if skip_existing and _record_exists("daily", "trade_date", trade_date): - LOGGER.debug( - "日线数据已存在,跳过交易日 %s", - trade_date, - extra=LOG_EXTRA, - ) - continue - LOGGER.debug("按交易日拉取日线行情:%s", trade_date, extra=LOG_EXTRA) - LOGGER.info( - "交易日拉取请求:endpoint=daily trade_date=%s", - trade_date, - extra=LOG_EXTRA, - 
) - df = _fetch_paginated("daily", {"trade_date": trade_date}) - if not df.empty: - frames.append(df) - - if not frames: return [] - df = pd.concat(frames, ignore_index=True) - return _df_to_records(df, _TABLE_COLUMNS["daily"]) + return [_run_factor_backfill] -def fetch_daily_basic( - start: date, - end: date, - ts_code: Optional[str] = None, - skip_existing: bool = True, -) -> Iterable[Dict]: - client = _ensure_client() - start_date = _format_date(start) - end_date = _format_date(end) - LOGGER.info( - "拉取日线基础指标(%s-%s,股票:%s)", - start_date, - end_date, - ts_code or "全部", - extra=LOG_EXTRA, +def _run_factor_backfill(job: FetchJob) -> None: + LOGGER.info("开始计算因子:%s", job.name, extra=LOG_EXTRA) + compute_factor_range( + job.start, + job.end, + ts_codes=job.ts_codes, + skip_existing=False, ) - - trade_dates = _load_trade_dates(start, end) - frames: List[pd.DataFrame] = [] - for trade_date in trade_dates: - if skip_existing and _record_exists("daily_basic", "trade_date", trade_date, ts_code): - LOGGER.info( - "日线基础指标已存在,跳过交易日 %s", - trade_date, - extra=LOG_EXTRA, - ) - continue - params = {"trade_date": trade_date} - if ts_code: - params["ts_code"] = ts_code - LOGGER.info( - "交易日拉取请求:endpoint=daily_basic params=%s", - params, - extra=LOG_EXTRA, - ) - df = _fetch_paginated("daily_basic", params) - if not df.empty: - frames.append(df) - - if not frames: - return [] - - merged = pd.concat(frames, ignore_index=True) - return _df_to_records(merged, _TABLE_COLUMNS["daily_basic"]) + alerts.clear_warnings("Factors") -def fetch_adj_factor( - start: date, - end: date, - ts_code: Optional[str] = None, - skip_existing: bool = True, -) -> Iterable[Dict]: - client = _ensure_client() - start_date = _format_date(start) - end_date = _format_date(end) - LOGGER.info( - "拉取复权因子(%s-%s,股票:%s)", - start_date, - end_date, - ts_code or "全部", - extra=LOG_EXTRA, - ) - - trade_dates = _load_trade_dates(start, end) - frames: List[pd.DataFrame] = [] - for trade_date in trade_dates: - if skip_existing and _record_exists("adj_factor", "trade_date", trade_date, ts_code): - LOGGER.debug( - "复权因子已存在,跳过 %s %s", - ts_code or "ALL", - trade_date, - extra=LOG_EXTRA, - ) - continue - params = {"trade_date": trade_date} - if ts_code: - params["ts_code"] = ts_code - LOGGER.info("交易日拉取请求:endpoint=adj_factor params=%s", params, extra=LOG_EXTRA) - df = _fetch_paginated("adj_factor", params) - if not df.empty: - frames.append(df) - - if not frames: - return [] - - merged = pd.concat(frames, ignore_index=True) - return _df_to_records(merged, _TABLE_COLUMNS["adj_factor"]) - - -def fetch_suspensions( - start: date, - end: date, - ts_code: Optional[str] = None, - skip_existing: bool = True, -) -> Iterable[Dict]: - client = _ensure_client() - start_date = _format_date(start) - end_date = _format_date(end) - LOGGER.info( - "拉取停复牌信息(逐日循环)%s-%s 股票=%s", - start_date, - end_date, - ts_code or "全部", - extra=LOG_EXTRA, - ) - trade_dates = _load_trade_dates(start, end) - existing_dates: Set[str] = set() - if skip_existing: - existing_dates = _existing_suspend_dates(start_date, end_date, ts_code) - if existing_dates: - LOGGER.debug( - "停复牌已有覆盖日期数量=%s 示例=%s", - len(existing_dates), - sorted(existing_dates)[:5], - extra=LOG_EXTRA, - ) - frames: List[pd.DataFrame] = [] - for trade_date in trade_dates: - if skip_existing and trade_date in existing_dates: - LOGGER.debug( - "停复牌信息已存在,跳过 %s %s", - ts_code or "ALL", - trade_date, - extra=LOG_EXTRA, - ) - continue - params: Dict[str, object] = {"trade_date": trade_date} - if ts_code: - params["ts_code"] = 
ts_code - LOGGER.info( - "交易日拉取请求:endpoint=suspend_d params=%s", - params, - extra=LOG_EXTRA, - ) - df = _fetch_paginated("suspend_d", params, limit=2000) - if not df.empty: - if "suspend_date" not in df.columns and "trade_date" in df.columns: - df = df.rename(columns={"trade_date": "suspend_date"}) - frames.append(df) - - if not frames: - LOGGER.info("停复牌接口未返回数据", extra=LOG_EXTRA) - return [] - - merged = pd.concat(frames, ignore_index=True) - missing_cols = [col for col in _TABLE_COLUMNS["suspend"] if col not in merged.columns] - for col in missing_cols: - merged[col] = None - ordered = merged[_TABLE_COLUMNS["suspend"]] - return _df_to_records(ordered, _TABLE_COLUMNS["suspend"]) - - -def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]: - client = _ensure_client() - start_date = _format_date(start) - end_date = _format_date(end) - LOGGER.info( - "拉取交易日历(交易所:%s,区间:%s-%s)", - exchange, - start_date, - end_date, - extra=LOG_EXTRA, - ) - _respect_rate_limit("trade_cal") - df = client.trade_cal(exchange=exchange, start_date=start_date, end_date=end_date) - if df is not None and not df.empty and "is_open" in df.columns: - df["is_open"] = pd.to_numeric(df["is_open"], errors="coerce").fillna(0).astype(int) - return _df_to_records(df, _TABLE_COLUMNS["trade_calendar"]) - - -def fetch_stk_limit( - start: date, - end: date, - ts_code: Optional[str] = None, - skip_existing: bool = True, -) -> Iterable[Dict]: - client = _ensure_client() - start_date = _format_date(start) - end_date = _format_date(end) - LOGGER.info("拉取涨跌停价格(%s-%s)", start_date, end_date, extra=LOG_EXTRA) - trade_dates = _load_trade_dates(start, end) - frames: List[pd.DataFrame] = [] - for trade_date in trade_dates: - if skip_existing and _record_exists("stk_limit", "trade_date", trade_date, ts_code): - LOGGER.debug( - "涨跌停数据已存在,跳过 %s %s", - ts_code or "ALL", - trade_date, - extra=LOG_EXTRA, - ) - continue - params = {"trade_date": trade_date} - if ts_code: - params["ts_code"] = ts_code - LOGGER.info("交易日拉取请求:endpoint=stk_limit params=%s", params, extra=LOG_EXTRA) - df = _fetch_paginated("stk_limit", params) - if not df.empty: - frames.append(df) - - if not frames: - return [] - - merged = pd.concat(frames, ignore_index=True) - return _df_to_records(merged, _TABLE_COLUMNS["stk_limit"]) - - -def fetch_index_basic(market: Optional[str] = None) -> Iterable[Dict]: - client = _ensure_client() - LOGGER.info("拉取指数基础信息(market=%s)", market or "all", extra=LOG_EXTRA) - _respect_rate_limit("index_basic") - df = client.index_basic(market=market) - return _df_to_records(df, _TABLE_COLUMNS["index_basic"]) - - -def fetch_index_daily(start: date, end: date, ts_code: str) -> Iterable[Dict]: - client = _ensure_client() - start_str = _format_date(start) - end_str = _format_date(end) - LOGGER.info( - "拉取指数日线:%s %s-%s", - ts_code, - start_str, - end_str, - extra=LOG_EXTRA, - ) - df = _fetch_paginated( - "index_daily", - {"ts_code": ts_code, "start_date": start_str, "end_date": end_str}, - limit=5000, - ) - return _df_to_records(df, _TABLE_COLUMNS["index_daily"]) - - -def fetch_index_weight(start: date, end: date, index_code: str) -> Iterable[Dict]: - """拉取指定指数的成分股权重数据。 - - Args: - start: 开始日期 - end: 结束日期 - index_code: 指数代码,如 "000300.SH" - - Returns: - 成分股权重数据列表 - """ - client = _ensure_client() - start_str = _format_date(start) - end_str = _format_date(end) - LOGGER.info( - "拉取指数成分股权重:%s %s-%s", - index_code, - start_str, - end_str, - extra=LOG_EXTRA, - ) - df = _fetch_paginated( - "index_weight", - {"index_code": 
index_code, "start_date": start_str, "end_date": end_str}, - limit=5000, - ) - # Filter out rows where con_code is null to avoid DB constraint violation - if df is not None and not df.empty: - df = df.dropna(subset=["con_code"]) - # Rename con_code to ts_code to match database schema - df = df.rename(columns={"con_code": "ts_code"}) - return _df_to_records(df, ["index_code", "trade_date", "ts_code", "weight"]) - - -def fetch_index_dailybasic(start: date, end: date, ts_code: str) -> Iterable[Dict]: - """拉取指定指数的每日指标数据。 - - Args: - start: 开始日期 - end: 结束日期 - ts_code: 指数代码,如 "000300.SH" - - Returns: - 指数每日指标数据 - """ - client = _ensure_client() - start_str = _format_date(start) - end_str = _format_date(end) - LOGGER.info( - "拉取指数每日指标:%s %s-%s", - ts_code, - start_str, - end_str, - extra=LOG_EXTRA, - ) - df = _fetch_paginated( - "index_dailybasic", - {"ts_code": ts_code, "start_date": start_str, "end_date": end_str}, - limit=5000, - ) - return _df_to_records(df, ["ts_code", "trade_date", "turnover", "turnover_ratio", "pe_ttm", "pb", "ps_ttm", "dv_ttm", "total_mv", "circ_mv"]) - - -def fetch_fund_basic(asset_class: str = "E", status: str = "L") -> Iterable[Dict]: - client = _ensure_client() - LOGGER.info("拉取基金基础信息:asset_class=%s status=%s", asset_class, status, extra=LOG_EXTRA) - _respect_rate_limit("fund_basic") - df = client.fund_basic(market=asset_class, status=status) - return _df_to_records(df, _TABLE_COLUMNS["fund_basic"]) - - -def fetch_fund_nav(start: date, end: date, ts_code: str) -> Iterable[Dict]: - client = _ensure_client() - start_str = _format_date(start) - end_str = _format_date(end) - LOGGER.info( - "拉取基金净值:%s %s-%s", - ts_code, - start_str, - end_str, - extra=LOG_EXTRA, - ) - df = _fetch_paginated( - "fund_nav", - {"ts_code": ts_code, "start_date": start_str, "end_date": end_str}, - limit=5000, - ) - return _df_to_records(df, _TABLE_COLUMNS["fund_nav"]) - - -def fetch_fut_basic(exchange: Optional[str] = None) -> Iterable[Dict]: - client = _ensure_client() - LOGGER.info("拉取期货基础信息(exchange=%s)", exchange or "all", extra=LOG_EXTRA) - _respect_rate_limit("fut_basic") - df = client.fut_basic(exchange=exchange) - return _df_to_records(df, _TABLE_COLUMNS["fut_basic"]) - - -def fetch_fut_daily(start: date, end: date, ts_code: str) -> Iterable[Dict]: - client = _ensure_client() - start_str = _format_date(start) - end_str = _format_date(end) - LOGGER.info( - "拉取期货日线:%s %s-%s", - ts_code, - start_str, - end_str, - extra=LOG_EXTRA, - ) - df = _fetch_paginated( - "fut_daily", - {"ts_code": ts_code, "start_date": start_str, "end_date": end_str}, - limit=4000, - ) - return _df_to_records(df, _TABLE_COLUMNS["fut_daily"]) - - -def fetch_fx_daily(start: date, end: date, ts_code: str) -> Iterable[Dict]: - client = _ensure_client() - start_str = _format_date(start) - end_str = _format_date(end) - LOGGER.info( - "拉取外汇日线:%s %s-%s", - ts_code, - start_str, - end_str, - extra=LOG_EXTRA, - ) - df = _fetch_paginated( - "fx_daily", - {"ts_code": ts_code, "start_date": start_str, "end_date": end_str}, - limit=4000, - ) - return _df_to_records(df, _TABLE_COLUMNS["fx_daily"]) - - -def fetch_hk_daily(start: date, end: date, ts_code: str) -> Iterable[Dict]: - client = _ensure_client() - start_str = _format_date(start) - end_str = _format_date(end) - LOGGER.info( - "拉取港股日线:%s %s-%s", - ts_code, - start_str, - end_str, - extra=LOG_EXTRA, - ) - df = _fetch_paginated( - "hk_daily", - {"ts_code": ts_code, "start_date": start_str, "end_date": end_str}, - limit=4000, - ) - return _df_to_records(df, 
_TABLE_COLUMNS["hk_daily"]) - - -def fetch_us_daily(start: date, end: date, ts_code: str) -> Iterable[Dict]: - client = _ensure_client() - start_str = _format_date(start) - end_str = _format_date(end) - LOGGER.info( - "拉取美股日线:%s %s-%s", - ts_code, - start_str, - end_str, - extra=LOG_EXTRA, - ) - df = _fetch_paginated( - "us_daily", - {"ts_code": ts_code, "start_date": start_str, "end_date": end_str}, - limit=4000, - ) - return _df_to_records(df, _TABLE_COLUMNS["us_daily"]) - - -def save_records(table: str, rows: Iterable[Dict]) -> None: - items = list(rows) - if not items: - LOGGER.info("表 %s 没有新增记录,跳过写入", table, extra=LOG_EXTRA) - return - - schema = _TABLE_SCHEMAS.get(table) - columns = _TABLE_COLUMNS.get(table) - if not schema or not columns: - raise ValueError(f"不支持写入的表:{table}") - - placeholders = ",".join([f":{col}" for col in columns]) - col_clause = ",".join(columns) - - LOGGER.info("表 %s 写入 %d 条记录", table, len(items), extra=LOG_EXTRA) - with db_session() as conn: - conn.executescript(schema) - conn.executemany( - f"INSERT OR REPLACE INTO {table} ({col_clause}) VALUES ({placeholders})", - items, - ) - - -def ensure_stock_basic(list_status: str = "L") -> None: - exchanges = ("SSE", "SZSE") - with db_session(read_only=True) as conn: - row = conn.execute( - "SELECT COUNT(*) AS cnt FROM stock_basic WHERE exchange IN (?, ?) AND list_status = ?", - (*exchanges, list_status), - ).fetchone() - if row and row["cnt"]: - LOGGER.info( - "股票基础信息已存在 %d 条记录,跳过拉取", - row["cnt"], - extra=LOG_EXTRA, - ) - return - - for exch in exchanges: - save_records("stock_basic", fetch_stock_basic(exchange=exch, list_status=list_status)) - - -def ensure_trade_calendar(start: date, end: date, exchanges: Sequence[str] = ("SSE", "SZSE")) -> None: - start_str = _format_date(start) - end_str = _format_date(end) - for exch in exchanges: - if _calendar_needs_refresh(exch, start_str, end_str): - save_records("trade_calendar", fetch_trade_calendar(start, end, exchange=exch)) - - -def ensure_index_weights(start: date, end: date, index_codes: Optional[Sequence[str]] = None) -> None: - """确保指定指数的成分股权重数据完整。 - - Args: - start: 开始日期 - end: 结束日期 - index_codes: 指数代码列表,如果为 None 则使用默认的 A 股指数 - """ - if index_codes is None: - # 默认获取 A 股指数的成分股权重 - index_codes = [code for code in INDEX_CODES if code.endswith(".SH") or code.endswith(".SZ")] - - for index_code in index_codes: - start_str = _format_date(start) - end_str = _format_date(end) - - if _range_needs_refresh("index_weight", "trade_date", start_str, end_str, index_code=index_code): - LOGGER.info("指数 %s 的成分股权重数据不完整,开始拉取 %s-%s", index_code, start_str, end_str) - save_records("index_weight", fetch_index_weight(start, end, index_code)) - else: - LOGGER.info("指数 %s 的成分股权重数据已完整,跳过", index_code) - - -def ensure_index_dailybasic(start: date, end: date, index_codes: Optional[Sequence[str]] = None) -> None: - """确保指定指数的每日指标数据完整。 - - Args: - start: 开始日期 - end: 结束日期 - index_codes: 指数代码列表,如果为 None 则使用默认的 A 股指数 - """ - if index_codes is None: - # 默认获取 A 股指数的每日指标 - index_codes = [code for code in INDEX_CODES if code.endswith(".SH") or code.endswith(".SZ")] - - for index_code in index_codes: - start_str = _format_date(start) - end_str = _format_date(end) - - if _range_needs_refresh("index_dailybasic", "trade_date", start_str, end_str, ts_code=index_code): - LOGGER.info("指数 %s 的每日指标数据不完整,开始拉取 %s-%s", index_code, start_str, end_str) - save_records("index_dailybasic", fetch_index_dailybasic(start, end, index_code)) - else: - LOGGER.info("指数 %s 的每日指标数据已完整,跳过", index_code) - - -def 
ensure_data_coverage( - start: date, - end: date, - ts_codes: Optional[Sequence[str]] = None, +def run_ingestion( + job: FetchJob, + *, include_limits: bool = True, include_extended: bool = True, - force: bool = False, - progress_hook: Callable[[str, float], None] | None = None, + post_tasks: Optional[Iterable[PostTask]] = None, ) -> None: - initialize_database() - start_str = _format_date(start) - end_str = _format_date(end) + """Execute a TuShare ingestion job with optional post processing hooks.""" - extra_steps = 0 - if include_limits: - extra_steps += 1 - if include_extended: - extra_steps += 4 - total_steps = 5 + extra_steps - current_step = 0 - - def advance(message: str) -> None: - nonlocal current_step - current_step += 1 - progress = min(current_step / total_steps, 1.0) - if progress_hook: - progress_hook(message, progress) - LOGGER.info(message, extra=LOG_EXTRA) - - advance("准备股票基础信息与交易日历") - ensure_stock_basic() - ensure_trade_calendar(start, end) - - codes = tuple(dict.fromkeys(ts_codes)) if ts_codes else tuple() - expected_days = _expected_trading_days(start_str, end_str) - - advance("处理日线行情数据") - if codes: - pending_codes: List[str] = [] - for code in codes: - if not force and _should_skip_range("daily", "trade_date", start, end, code): - LOGGER.info("股票 %s 的日线已覆盖 %s-%s,跳过", code, start_str, end_str) - continue - pending_codes.append(code) - if pending_codes: - job = FetchJob("daily_autofill", start=start, end=end, ts_codes=tuple(pending_codes)) - LOGGER.info("开始拉取日线行情:%s-%s(待补股票 %d 支)", start_str, end_str, len(pending_codes)) - save_records("daily", fetch_daily_bars(job, skip_existing=not force)) - else: - needs_daily = force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days) - if not needs_daily: - LOGGER.info("日线数据已覆盖 %s-%s,跳过拉取", start_str, end_str) - else: - job = FetchJob("daily_autofill", start=start, end=end) - LOGGER.info("开始拉取日线行情:%s-%s", start_str, end_str) - save_records("daily", fetch_daily_bars(job, skip_existing=not force)) - - advance("处理指数成分股权重数据") - # 获取默认指数列表 - default_indices = [code for code in INDEX_CODES if code.endswith(".SH") or code.endswith(".SZ")] - for index_code in default_indices: - if not force and _should_skip_range("index_weight", "trade_date", start, end, index_code): - LOGGER.info("指数 %s 的成分股权重已覆盖 %s-%s,跳过", index_code, start_str, end_str) - continue - LOGGER.info("开始拉取指数成分股权重:%s %s-%s", index_code, start_str, end_str) - save_records("index_weight", fetch_index_weight(start, end, index_code)) - - advance("处理指数每日指标数据") - for index_code in default_indices: - if not force and _should_skip_range("index_dailybasic", "trade_date", start, end, index_code): - LOGGER.info("指数 %s 的每日指标已覆盖 %s-%s,跳过", index_code, start_str, end_str) - continue - LOGGER.info("开始拉取指数每日指标:%s %s-%s", index_code, start_str, end_str) - save_records("index_dailybasic", fetch_index_dailybasic(start, end, index_code)) - - date_cols = { - "daily_basic": "trade_date", - "adj_factor": "trade_date", - "stk_limit": "trade_date", - "suspend": "suspend_date", - } - date_cols.update( - { - "index_daily": "trade_date", - "index_dailybasic": "trade_date", - "index_weight": "trade_date", - "fund_nav": "nav_date", - "fut_daily": "trade_date", - "fx_daily": "trade_date", - "hk_daily": "trade_date", - "us_daily": "trade_date", - } - ) - - def _save_with_codes(table: str, fetch_fn) -> None: - date_col = date_cols.get(table, "trade_date") - if codes: - for code in codes: - if not force and _should_skip_range(table, date_col, start, end, code): - LOGGER.info("表 
%s 股票 %s 已覆盖 %s-%s,跳过", table, code, start_str, end_str) - continue - LOGGER.info("拉取 %s 表数据(股票:%s)%s-%s", table, code, start_str, end_str) - try: - kwargs = {"ts_code": code} - if fetch_fn in (fetch_daily_basic, fetch_adj_factor, fetch_suspensions, fetch_stk_limit): - kwargs["skip_existing"] = not force - rows = fetch_fn(start, end, **kwargs) - except Exception: - LOGGER.exception("TuShare 拉取失败:table=%s code=%s", table, code) - raise - save_records(table, rows) - else: - needs_refresh = force or table == "suspend" - if not force and table != "suspend": - expected = expected_days if table in {"daily_basic", "adj_factor", "stk_limit"} else 0 - needs_refresh = _range_needs_refresh(table, date_col, start_str, end_str, expected) - if not needs_refresh: - LOGGER.info("表 %s 已覆盖 %s-%s,跳过", table, start_str, end_str) - return - LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str) - try: - kwargs = {} - if fetch_fn in (fetch_daily_basic, fetch_adj_factor, fetch_suspensions, fetch_stk_limit): - kwargs["skip_existing"] = not force - rows = fetch_fn(start, end, **kwargs) - except Exception: - LOGGER.exception("TuShare 拉取失败:table=%s code=全部", table) - raise - save_records(table, rows) - - advance("处理日线基础指标数据") - _save_with_codes("daily_basic", fetch_daily_basic) - - advance("处理复权因子数据") - _save_with_codes("adj_factor", fetch_adj_factor) - - if include_limits: - advance("处理涨跌停价格数据") - _save_with_codes("stk_limit", fetch_stk_limit) - - advance("处理停复牌信息") - _save_with_codes("suspend", fetch_suspensions) - - if include_extended: - advance("同步指数/基金/期货基础信息") - try: - save_records("index_basic", fetch_index_basic()) - save_records("fund_basic", fetch_fund_basic()) - save_records("fut_basic", fetch_fut_basic()) - except Exception: - LOGGER.exception("扩展基础信息拉取失败") - raise - - advance("拉取指数行情数据") - for code in INDEX_CODES: - try: - if not force and _should_skip_range("index_daily", "trade_date", start, end, code): - LOGGER.info("指数 %s 已覆盖 %s-%s,跳过", code, start_str, end_str) - continue - save_records("index_daily", fetch_index_daily(start, end, code)) - except Exception: - LOGGER.exception("指数行情拉取失败:%s", code) - raise - - advance("拉取基金净值数据") - fund_targets = tuple(dict.fromkeys(ETF_CODES + FUND_CODES)) - for code in fund_targets: - try: - if not force and _should_skip_range("fund_nav", "nav_date", start, end, code): - LOGGER.info("基金 %s 净值已覆盖 %s-%s,跳过", code, start_str, end_str) - continue - save_records("fund_nav", fetch_fund_nav(start, end, code)) - except Exception: - LOGGER.exception("基金净值拉取失败:%s", code) - raise - - advance("拉取期货/外汇行情数据") - for code in FUTURE_CODES: - try: - if not force and _should_skip_range("fut_daily", "trade_date", start, end, code): - LOGGER.info("期货 %s 已覆盖 %s-%s,跳过", code, start_str, end_str) - continue - save_records("fut_daily", fetch_fut_daily(start, end, code)) - except Exception: - LOGGER.exception("期货行情拉取失败:%s", code) - raise - for code in FX_CODES: - try: - if not force and _should_skip_range("fx_daily", "trade_date", start, end, code): - LOGGER.info("外汇 %s 已覆盖 %s-%s,跳过", code, start_str, end_str) - continue - save_records("fx_daily", fetch_fx_daily(start, end, code)) - except Exception: - LOGGER.exception("外汇行情拉取失败:%s", code) - raise - - advance("拉取港/美股行情数据(已暂时关闭)") - - if progress_hook: - progress_hook("数据覆盖检查完成", 1.0) - - -def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]]: - start_str = _format_date(start) - end_str = _format_date(end) - expected_days = _expected_trading_days(start_str, end_str) - - coverage: Dict[str, Dict[str, 
object]] = { - "period": { - "start": start_str, - "end": end_str, - "expected_trading_days": expected_days, - } - } - - def add_table(name: str, date_col: str, require_days: bool = True) -> None: - stats = _range_stats(name, date_col, start_str, end_str) - coverage[name] = { - "min": stats["min"], - "max": stats["max"], - "distinct_days": stats["distinct"], - "meets_expectation": ( - stats["min"] is not None - and stats["max"] is not None - and stats["min"] <= start_str - and stats["max"] >= end_str - and ((not require_days) or (stats["distinct"] or 0) >= expected_days) - ), - } - - add_table("daily", "trade_date") - add_table("daily_basic", "trade_date") - add_table("adj_factor", "trade_date") - add_table("stk_limit", "trade_date") - add_table("suspend", "suspend_date", require_days=False) - add_table("index_daily", "trade_date") - add_table("fund_nav", "nav_date", require_days=False) - add_table("fut_daily", "trade_date", require_days=False) - add_table("fx_daily", "trade_date", require_days=False) - add_table("hk_daily", "trade_date", require_days=False) - add_table("us_daily", "trade_date", require_days=False) - - with db_session(read_only=True) as conn: - stock_tot = conn.execute("SELECT COUNT(*) AS cnt FROM stock_basic").fetchone() - stock_sse = conn.execute( - "SELECT COUNT(*) AS cnt FROM stock_basic WHERE exchange = 'SSE' AND list_status = 'L'" - ).fetchone() - stock_szse = conn.execute( - "SELECT COUNT(*) AS cnt FROM stock_basic WHERE exchange = 'SZSE' AND list_status = 'L'" - ).fetchone() - coverage["stock_basic"] = { - "total": stock_tot["cnt"] if stock_tot else 0, - "sse_listed": stock_sse["cnt"] if stock_sse else 0, - "szse_listed": stock_szse["cnt"] if stock_szse else 0, - } - - return coverage - - -def run_ingestion(job: FetchJob, include_limits: bool = True) -> None: - """运行数据拉取任务。 - - Args: - job: 任务配置 - include_limits: 是否包含涨跌停数据 - """ with JobLogger("TuShare数据获取") as logger: LOGGER.info("启动 TuShare 拉取任务:%s", job.name, extra=LOG_EXTRA) - try: - # 拉取基础数据 ensure_data_coverage( job.start, job.end, ts_codes=job.ts_codes, include_limits=include_limits, - include_extended=True, + include_extended=include_extended, force=True, ) - - # 记录任务元数据 - logger.update_metadata({ - "name": job.name, - "start": str(job.start), - "end": str(job.end), - "codes": len(job.ts_codes) if job.ts_codes else 0 - }) - + logger.update_metadata(job.as_dict()) alerts.clear_warnings("TuShare") - - # 对日线数据计算因子 - if job.granularity == "daily": - LOGGER.info("开始计算因子:%s", job.name, extra=LOG_EXTRA) + + tasks = list(post_tasks) if post_tasks is not None else _default_post_tasks(job) + for task in tasks: try: - compute_factor_range( - job.start, - job.end, - ts_codes=job.ts_codes, - skip_existing=False, + task(job) + except Exception as exc: # noqa: BLE001 + LOGGER.exception( + "后置任务执行失败:task=%s", + getattr(task, "__name__", task), + extra=LOG_EXTRA, ) - alerts.clear_warnings("Factors") - except Exception as exc: - LOGGER.exception("因子计算失败 job=%s", job.name, extra=LOG_EXTRA) - alerts.add_warning("Factors", f"因子计算失败:{job.name}", str(exc)) - logger.update_status("failed", f"因子计算失败:{exc}") + alerts.add_warning("Factors", f"后置任务失败:{job.name}", str(exc)) + logger.update_status("failed", f"后置任务失败:{exc}") raise - LOGGER.info("因子计算完成:%s", job.name, extra=LOG_EXTRA) - alerts.clear_warnings("Factors") - - except Exception as exc: + + except Exception as exc: # noqa: BLE001 LOGGER.exception("数据拉取失败 job=%s", job.name, extra=LOG_EXTRA) alerts.add_warning("TuShare", f"拉取任务失败:{job.name}", str(exc)) - 
logger.update_status("failed", f"数据拉取失败:{exc}") raise LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA) + + +__all__ = [ + "FetchJob", + "collect_data_coverage", + "ensure_data_coverage", + "run_ingestion", +] + diff --git a/app/utils/data_access.py b/app/utils/data_access.py index a5387b9..7b549a9 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -11,40 +11,26 @@ from dataclasses import dataclass, field from datetime import date, datetime, timedelta from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple -from .config import get_config import types +from .config import get_config from .db import db_session from .logging import get_logger from app.core.indicators import momentum, normalize, rolling_mean, volatility from app.utils.db_query import BrokerQueryEngine from app.utils import alerts +from app.ingest.coverage import collect_data_coverage as _collect_coverage, ensure_data_coverage as _ensure_coverage -# 延迟导入,避免循环依赖 -collect_data_coverage = None -ensure_data_coverage = None -initialize_database = None +try: + from app.data.schema import initialize_database +except ImportError: + def initialize_database(): + """Fallback stub used when the real initializer cannot be imported. -# 在模块加载时尝试导入 -if collect_data_coverage is None or ensure_data_coverage is None: - try: - from app.ingest.tushare import collect_data_coverage, ensure_data_coverage - except ImportError: - # 导入失败时,在实际使用时会报错 - pass - -if initialize_database is None: - try: - from app.data.schema import initialize_database - except ImportError: - # 导入失败时,提供一个空实现 - def initialize_database(): - """Fallback stub used when the real initializer cannot be imported. - - Return a lightweight object with the attributes callers expect - (executed, skipped, missing_tables) so code that calls - `initialize_database()` can safely inspect the result. - """ - return types.SimpleNamespace(executed=0, skipped=True, missing_tables=[]) + Return a lightweight object with the attributes callers expect + (executed, skipped, missing_tables) so code that calls + `initialize_database()` can safely inspect the result. + """ + return types.SimpleNamespace(executed=0, skipped=True, missing_tables=[]) LOGGER = get_logger(__name__) LOG_EXTRA = {"stage": "data_broker"} @@ -56,6 +42,27 @@ def _is_safe_identifier(name: str) -> bool: return bool(_IDENTIFIER_RE.match(name)) +def _default_coverage_runner(start: date, end: date) -> None: + if _ensure_coverage is None: + LOGGER.debug("默认补数函数不可用,跳过自动补数", extra=LOG_EXTRA) + return + _ensure_coverage( + start, + end, + include_limits=False, + include_extended=False, + force=False, + progress_hook=None, + ) + + +def _default_coverage_collector(start: date, end: date) -> Dict[str, Dict[str, object]]: + if _collect_coverage is None: + LOGGER.debug("默认覆盖统计函数不可用,返回空结果", extra=LOG_EXTRA) + return {} + return _collect_coverage(start, end) + + def _safe_split(path: str) -> Tuple[str, str] | None: if "." 
@@ -197,6 +204,8 @@ class DataBroker:
     enable_cache: bool = True
     latest_cache_size: int = 256
     series_cache_size: int = 512
+    coverage_runner: Callable[[date, date], None] = field(default=_default_coverage_runner)
+    coverage_collector: Callable[[date, date], Dict[str, Dict[str, object]]] = field(default=_default_coverage_collector)
     _latest_cache: OrderedDict = field(init=False, repr=False)
     _series_cache: OrderedDict = field(init=False, repr=False)
     # 补数相关状态管理
@@ -1234,14 +1243,13 @@ class DataBroker:
             return False
 
         # 收集数据覆盖情况
-        if collect_data_coverage is None:
-            LOGGER.error("collect_data_coverage 函数不可用,请检查导入配置", extra=LOG_EXTRA)
+        if self.coverage_collector is None:
+            LOGGER.debug("未配置覆盖统计函数,无法判断是否需要补数", extra=LOG_EXTRA)
             return False
-
-        coverage = collect_data_coverage(
-            date.fromisoformat(start_date[:4] + '-' + start_date[4:6] + '-' + start_date[6:8]),
-            date.fromisoformat(end_date[:4] + '-' + end_date[4:6] + '-' + end_date[6:8])
-        )
+
+        start_d = datetime.strptime(start_date, "%Y%m%d").date()
+        end_d = datetime.strptime(end_date, "%Y%m%d").date()
+        coverage = self.coverage_collector(start_d, end_d)
 
         # 保存到缓存
         coverage['timestamp'] = time.time() if hasattr(time, 'time') else 0
@@ -1285,18 +1293,14 @@ class DataBroker:
         LOGGER.info("开始后台数据补数: %s 至 %s", start_date, end_date, extra=LOG_EXTRA)
 
         # 执行补数
-        if ensure_data_coverage is None:
-            LOGGER.error("ensure_data_coverage 函数不可用,请检查导入配置", extra=LOG_EXTRA)
+        if self.coverage_runner is None:
+            LOGGER.debug("未配置覆盖补数函数,跳过自动补数", extra=LOG_EXTRA)
             with self._refresh_lock:
                 self._refresh_in_progress[refresh_key] = False
+                self._refresh_callbacks.pop(refresh_key, None)
             return
-
-        ensure_data_coverage(
-            start_date,
-            end_date,
-            force=False,
-            progress_hook=None
-        )
+
+        start_d = datetime.strptime(start_date, "%Y%m%d").date()
+        end_d = datetime.strptime(end_date, "%Y%m%d").date()
+        self.coverage_runner(start_d, end_d)
 
         LOGGER.info("后台数据补数完成: %s 至 %s", start_date, end_date, extra=LOG_EXTRA)
@@ -1661,13 +1665,11 @@ class DataBroker:
             start_d = date.fromisoformat(start.strftime('%Y-%m-%d'))
             end_d = date.fromisoformat(end.strftime('%Y-%m-%d'))
 
-            # 收集数据覆盖情况
-            if collect_data_coverage is None:
-                LOGGER.error("collect_data_coverage 函数不可用,请检查导入配置", extra=LOG_EXTRA)
+            if self.coverage_collector is None:
+                LOGGER.debug("未配置覆盖统计函数,返回空覆盖结果", extra=LOG_EXTRA)
                 return {}
-
-            coverage = collect_data_coverage(start_d, end_d)
-            return coverage
+
+            return self.coverage_collector(start_d, end_d)
         except Exception as exc:
             LOGGER.exception("获取数据覆盖情况失败: %s", exc, extra=LOG_EXTRA)
             return {}
diff --git a/app/utils/feature_snapshots.py b/app/utils/feature_snapshots.py
new file mode 100644
index 0000000..14341c4
--- /dev/null
+++ b/app/utils/feature_snapshots.py
@@ -0,0 +1,58 @@
+"""Shared feature snapshot helpers built on top of DataBroker."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Dict, Iterable, Mapping, Optional, Sequence
+
+from .data_access import DataBroker
+
+
+@dataclass
+class FeatureSnapshotService:
+    """Provide batch-oriented access to latest features for multiple symbols."""
+
+    broker: DataBroker
+
+    def __init__(self, broker: Optional[DataBroker] = None) -> None:
+        self.broker = broker or DataBroker()
+
+    def load_latest(
+        self,
+        trade_date: str,
+        fields: Sequence[str],
+        ts_codes: Sequence[str],
+        *,
+        auto_refresh: bool = False,
+    ) -> Dict[str, Dict[str, object]]:
+        """Fetch a snapshot of feature values for the given universe."""
+
+        if not ts_codes:
+            return {}
+        return self.broker.fetch_batch_latest(
+            list(ts_codes),
+            trade_date,
+            fields,
+            auto_refresh=auto_refresh,
+        )
+
+    def load_single(
+        self,
+        trade_date: str,
+        ts_code: str,
+        fields: Iterable[str],
+        *,
+        auto_refresh: bool = False,
+    ) -> Mapping[str, object]:
+        """Convenience wrapper to reuse the snapshot logic for a single symbol."""
+
+        snapshot = self.load_latest(
+            trade_date,
+            list(fields),
+            [ts_code],
+            auto_refresh=auto_refresh,
+        )
+        return snapshot.get(ts_code, {})
+
+
+__all__ = ["FeatureSnapshotService"]
+
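Taken together, `FeatureSnapshotService` is a thin façade over `DataBroker.fetch_batch_latest`: `load_latest` fans out over a universe in one call, and `load_single` reuses the same path for one symbol. A usage sketch; the dotted field paths are illustrative only, since the real naming follows whatever `_safe_split` resolves:

```python
# Usage sketch for the new snapshot service; field paths are hypothetical.
from app.utils.feature_snapshots import FeatureSnapshotService

service = FeatureSnapshotService()  # constructs a default DataBroker internally

universe = ["600000.SH", "000001.SZ"]
snapshot = service.load_latest(
    "20250115",
    ["daily.close", "daily_basic.turnover_rate"],
    universe,
    auto_refresh=False,
)
close = snapshot.get("600000.SH", {}).get("daily.close")

row = service.load_single("20250115", "600000.SH", ["daily.close"])
```

Centralizing this in one service lets factors, backtests, and eventually department agents share a single cache path instead of issuing per-symbol SQL.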
diff --git a/docs/TODO.md b/docs/TODO.md
index 0e1d3d8..441bc47 100644
--- a/docs/TODO.md
+++ b/docs/TODO.md
@@ -79,4 +79,4 @@
 | --- | --- | --- |
 | 全量代码审查 | 🔄 | 已制定 `docs/architecture/code_review_checklist.md`,按 checklist 推进模块审查。 |
 | TODO 标记治理 | 🔄 | 新增 `scripts/todo_report.py` 支撑定期扫描,待梳理遗留项目。 |
-| 业务逻辑体检 | ⏳ | 梳理业务链路,识别需要重构或优化的模块。 |
+| 业务逻辑体检 | ✅ | 梳理业务链路完成,已拆分采集/覆盖/决策模块;详见 docs/architecture/business_logic_healthcheck.md。 |
diff --git a/docs/architecture/business_logic_healthcheck.md b/docs/architecture/business_logic_healthcheck.md
new file mode 100644
index 0000000..5e4ed34
--- /dev/null
+++ b/docs/architecture/business_logic_healthcheck.md
@@ -0,0 +1,50 @@
+# Business Logic Health-Check Report
+
+This report walks the current end-to-end business pipeline and flags the key risk points for maintainability and extensibility, as input for scheduling follow-up refactoring work.
+
+## End-to-End Pipeline at a Glance
+- **Data ingestion and health checks**: the CLI entry point `scripts/run_ingestion_job.py` goes through the orchestration layer `app/ingest/tushare.py`, which calls `app/ingest/api_client.py` and `app/ingest/coverage.py` to pull TuShare data, backfill gaps, and run coverage checks.
+- **Data access and coverage governance**: `DataBroker` (`app/utils/data_access.py`) owns field resolution, caching, and derived indicators; automatic backfill is delegated to the injectable `coverage_runner`, and batch snapshot access is exposed to upper layers through `FeatureSnapshotService` (`app/utils/feature_snapshots.py`).
+- **Factor and feature processing**: the `compute_factors` family (`app/features/factors.py`) uses batch snapshots and chunked validation to persist features consumed by agents and backtests.
+- **Multi-agent decision making**: `DecisionWorkflow` (`app/agents/game.py`) splits agenda control, department voting, risk review, and execution summary into maintainable stages that drive rule agents and department LLMs together.
+- **Backtesting, tuning, and reinforcement learning**: `BacktestEngine.load_market_data` (`app/backtest/engine.py`) aggregates features via batch snapshots, and `DecisionEnv` (`app/backtest/decision_env.py`) exposes the RL action interface.
+- **Visualization and operations dashboard**: the Streamlit entry point `app/ui/streamlit_app.py:14-120` triggers database initialization, automatic backfill, and multi-page visualization, consuming the outputs of the pipeline above.
+
+## Key Findings
+### 1. Ingestion is split out but still needs to evolve
+- The ingestion orchestrator now lives in `app/ingest/tushare.py`, with API calls and coverage checks handled by `app/ingest/api_client.py` and `app/ingest/coverage.py`; a follow-up could turn factor computation into explicit queued tasks.
+- `run_ingestion` triggers the default factor backfill through the `post_tasks` hook, which makes it straightforward to introduce asynchronous or multi-stage processing strategies.
+
+### 2. Data-access responsibilities pushed down, yet the layer remains heavy
+- `DataBroker` gained an injectable `coverage_runner` and batch cache interfaces, but derived indicators and industry analysis still live in the same class and could be split into sub-components.
+- Automatic backfill and coverage statistics now depend explicitly on `app/ingest/coverage.py`, removing the circular-dependency risk of the old lazy imports.
+
+### 3. Factor pipeline gains batch snapshots
+- `FeatureSnapshotService` prefetches the latest fields in batches, and `compute_factors` now assembles features per batch over a shared cache, cutting duplicate SQL.
+- Validation and progress reporting remain concentrated in `_compute_batch_factors`; statistics and persistence logic could be peeled off later for finer-grained testing.
+
+### 4. Multi-agent decision flow is modularized
+- `DecisionWorkflow` separates department voting, risk review, and execution summary into dedicated methods, which eases instrumentation and unit testing.
+- Department agents still reach into `DataBroker` directly; a follow-up could route them through `FeatureSnapshotService` or a data-domain policy to unify the data-access boundary.
+
+### 5. Backtest path reuses batch features
+- `BacktestEngine.load_market_data` shares the snapshot service with the factor pipeline, avoiding duplicate latest-value queries.
+- The RL environment still rebuilds `BacktestEngine` per day; caching snapshots or splitting environment state could speed this up in a later iteration.
+
+## Prioritized Recommendations
+1. **Finish task orchestration for the ingestion pipeline** (high priority)
+   With the API/coverage/orchestration layers separated, move factor computation and other post actions into an independent task queue to enable concurrency and retries.
+
+2. **Decouple DataBroker's derived responsibilities** (medium-high priority)
+   Extract industry, sentiment, and derived-indicator logic into standalone services so DataBroker can focus on field resolution and caching; add finer-grained unit tests alongside.
+
+3. **Extend feature snapshots to department agents** (medium priority)
+   Factors and backtests already reuse `FeatureSnapshotService`; wiring LLM department tool calls into the same snapshots would cut duplicate SQL (see the sketch after this report).
+
+4. **Add tests and monitoring for DecisionWorkflow** (medium priority)
+   Write unit/integration tests for each `DecisionWorkflow` stage, and surface risk review and belief revision on the monitoring dashboard for auditability.
+
+5. **Establish performance and regression baselines** (low priority)
+   Build a benchmark dataset that exercises the snapshot cache and measure factor-computation and backtest latency to support future optimization.
+
+These recommendations can be tackled in sequence or interleaved according to business priority.
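As a directional sketch for recommendation 3 above, a department-facing adapter could route all feature reads through the snapshot service. `DepartmentDataAdapter` and its method names are hypothetical and not part of this patch:

```python
# Hypothetical adapter: one shared snapshot path for department tools.
from typing import Iterable, Mapping, Optional

from app.utils.feature_snapshots import FeatureSnapshotService


class DepartmentDataAdapter:
    """Single entry point a department tool would call for features."""

    def __init__(self, service: Optional[FeatureSnapshotService] = None) -> None:
        self.service = service or FeatureSnapshotService()

    def features_for(
        self, trade_date: str, ts_code: str, fields: Iterable[str]
    ) -> Mapping[str, object]:
        # Reuses the broker-level batch cache instead of per-tool SQL.
        return self.service.load_single(trade_date, ts_code, fields)
```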