commit 6eac6c5f691a3eba64465019589096b3e51c4a6b Author: sam Date: Fri Sep 26 18:21:25 2025 +0800 init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..94fc8fe --- /dev/null +++ b/.gitignore @@ -0,0 +1,29 @@ +# Python cache and build artifacts +__pycache__/ +*.py[cod] +*.so +.build/ +.eggs/ +*.egg-info/ + +# Virtual environments +.venv/ +venv/ +ENV/ +env/ + +# Jupyter checkpoints +.ipynb_checkpoints/ + +# Logs and data +app/data/*.db* +app/data/backups/ +app/data/logs/ +*.log + +# Streamlit temporary files +.streamlit/ + +# System files +.DS_Store +Thumbs.db diff --git a/README.md b/README.md new file mode 100644 index 0000000..ab2fbfd --- /dev/null +++ b/README.md @@ -0,0 +1,27 @@ +# 多智能体投资助理骨架 + +本仓库提供一个基于多智能体博弈的 A 股日线投资助理代码框架,满足单机可运行、SQLite 存储和 Streamlit UI 的需求。核心模块划分如下: + +- `app/data`:数据库初始化与 Schema 定义。 +- `app/utils`:配置、数据库连接、日志和交易日历工具。 +- `app/ingest`:TuShare 与 RSS 数据拉取骨架。 +- `app/features`:指标与信号计算接口。 +- `app/agents`:多智能体博弈实现,包括动量、价值、新闻、流动性、宏观与风险代理。 +- `app/backtest`:日线回测引擎与指标计算的占位实现。 +- `app/llm`:人类可读卡片与摘要生成入口(仅构建提示,不直接交易)。 +- `app/ui`:Streamlit 三页界面骨架。 + +## 快速开始 + +```bash +python -m app.main # 初始化数据库 +streamlit run app/ui/streamlit_app.py +``` + +## 下一步 + +1. 在 `app/ingest` 中补充 TuShare 和 RSS 数据抓取逻辑。 +2. 完善 `app/features` 和 `app/backtest` 以实现实际的信号计算与事件驱动回测。 +3. 将代理效用写入 SQLite 的 `agent_utils` 和 `alloc_log` 表,驱动 UI 展示。 +4. 使用轻量情感分析与热度计算,填充 `news` 和 `heat_daily`。 +5. 接入本地小模型或 API 完成 LLM 文本解释,并在 UI 中展示。 diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/agents/__init__.py b/app/agents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/agents/base.py b/app/agents/base.py new file mode 100644 index 0000000..d78d6eb --- /dev/null +++ b/app/agents/base.py @@ -0,0 +1,44 @@ +"""Agent abstractions for the multi-agent decision engine.""" +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Dict, Mapping + + +class AgentAction(str, Enum): + SELL = "SELL" + HOLD = "HOLD" + BUY_S = "BUY_S" + BUY_M = "BUY_M" + BUY_L = "BUY_L" + + +@dataclass +class AgentContext: + ts_code: str + trade_date: str + features: Mapping[str, float] + + +class Agent: + """Base class for all decision agents.""" + + name: str + + def __init__(self, name: str) -> None: + self.name = name + + def score(self, context: AgentContext, action: AgentAction) -> float: + """Return a normalized utility value in [0,1] for the proposed action.""" + + raise NotImplementedError + + def feasible(self, context: AgentContext, action: AgentAction) -> bool: + """Optional hook for agents with veto power (defaults to True).""" + + _ = context, action + return True + + +UtilityMatrix = Dict[AgentAction, Dict[str, float]] diff --git a/app/agents/game.py b/app/agents/game.py new file mode 100644 index 0000000..31b6659 --- /dev/null +++ b/app/agents/game.py @@ -0,0 +1,127 @@ +"""Multi-agent decision game implementation.""" +from __future__ import annotations + +from dataclasses import dataclass +from math import exp, log +from typing import Dict, Iterable, List, Mapping, Tuple + +from .base import Agent, AgentAction, AgentContext, UtilityMatrix +from .registry import weight_map + + +ACTIONS: Tuple[AgentAction, ...] = ( + AgentAction.SELL, + AgentAction.HOLD, + AgentAction.BUY_S, + AgentAction.BUY_M, + AgentAction.BUY_L, +) + + +def _clamp(value: float) -> float: + return max(0.0, min(1.0, value)) + + +@dataclass +class Decision: + action: AgentAction + confidence: float + target_weight: float + feasible_actions: List[AgentAction] + utilities: UtilityMatrix + + +def compute_utilities(agents: Iterable[Agent], context: AgentContext) -> UtilityMatrix: + utilities: UtilityMatrix = {} + for action in ACTIONS: + utilities[action] = {} + for agent in agents: + score = _clamp(agent.score(context, action)) + utilities[action][agent.name] = score + return utilities + + +def feasible_actions(agents: Iterable[Agent], context: AgentContext) -> List[AgentAction]: + feas: List[AgentAction] = [] + for action in ACTIONS: + if all(agent.feasible(context, action) for agent in agents): + feas.append(action) + return feas + + +def nash_bargain(utilities: UtilityMatrix, weights: Mapping[str, float], disagreement: Mapping[str, float]) -> Tuple[AgentAction, float]: + best_action = AgentAction.HOLD + best_score = float("-inf") + for action, agent_scores in utilities.items(): + if action not in utilities: + continue + log_product = 0.0 + valid = True + for agent_name, score in agent_scores.items(): + w = weights.get(agent_name, 0.0) + if w == 0: + continue + gap = score - disagreement.get(agent_name, 0.0) + if gap <= 0: + valid = False + break + log_product += w * log(gap) + if not valid: + continue + if log_product > best_score: + best_score = log_product + best_action = action + if best_score == float("-inf"): + return AgentAction.HOLD, 0.0 + confidence = _aggregate_confidence(utilities[best_action], weights) + return best_action, confidence + + +def vote(utilities: UtilityMatrix, weights: Mapping[str, float]) -> Tuple[AgentAction, float]: + scores: Dict[AgentAction, float] = {} + for action, agent_scores in utilities.items(): + scores[action] = sum(weights.get(agent, 0.0) * score for agent, score in agent_scores.items()) + best_action = max(scores, key=scores.get) + confidence = _aggregate_confidence(utilities[best_action], weights) + return best_action, confidence + + +def _aggregate_confidence(agent_scores: Mapping[str, float], weights: Mapping[str, float]) -> float: + total = sum(weights.values()) + if total <= 0: + return 0.0 + weighted = sum(weights.get(agent, 0.0) * score for agent, score in agent_scores.items()) + return weighted / total + + +def target_weight_for_action(action: AgentAction) -> float: + mapping = { + AgentAction.SELL: -1.0, + AgentAction.HOLD: 0.0, + AgentAction.BUY_S: 0.01, + AgentAction.BUY_M: 0.02, + AgentAction.BUY_L: 0.03, + } + return mapping[action] + + +def decide(context: AgentContext, agents: Iterable[Agent], weights: Mapping[str, float], method: str = "nash") -> Decision: + agent_list = list(agents) + norm_weights = weight_map(dict(weights)) + utilities = compute_utilities(agent_list, context) + feas_actions = feasible_actions(agent_list, context) + if not feas_actions: + return Decision(AgentAction.HOLD, 0.0, 0.0, [], utilities) + + filtered_utilities = {action: utilities[action] for action in feas_actions} + hold_scores = utilities.get(AgentAction.HOLD, {}) + + if method == "vote": + action, confidence = vote(filtered_utilities, norm_weights) + else: + action, confidence = nash_bargain(filtered_utilities, norm_weights, hold_scores) + if action not in feas_actions: + action, confidence = vote(filtered_utilities, norm_weights) + + weight = target_weight_for_action(action) + return Decision(action, confidence, weight, feas_actions, utilities) diff --git a/app/agents/liquidity.py b/app/agents/liquidity.py new file mode 100644 index 0000000..03a66ff --- /dev/null +++ b/app/agents/liquidity.py @@ -0,0 +1,20 @@ +"""Liquidity and transaction cost agent.""" +from __future__ import annotations + +from .base import Agent, AgentAction, AgentContext + + +class LiquidityAgent(Agent): + def __init__(self) -> None: + super().__init__(name="A_liq") + + def score(self, context: AgentContext, action: AgentAction) -> float: + liq = context.features.get("liquidity_score", 0.5) + cost = context.features.get("cost_penalty", 0.0) + penalty = cost + if action is AgentAction.SELL: + return min(1.0, liq + penalty) + if action is AgentAction.HOLD: + return 0.4 + 0.2 * liq + scale = {AgentAction.BUY_S: 0.5, AgentAction.BUY_M: 0.75, AgentAction.BUY_L: 1.0}.get(action, 0.0) + return max(0.0, liq * scale - penalty) diff --git a/app/agents/macro.py b/app/agents/macro.py new file mode 100644 index 0000000..102c3e0 --- /dev/null +++ b/app/agents/macro.py @@ -0,0 +1,23 @@ +"""Macro and industry regime agent.""" +from __future__ import annotations + +from .base import Agent, AgentAction, AgentContext + + +class MacroAgent(Agent): + def __init__(self) -> None: + super().__init__(name="A_macro") + + def score(self, context: AgentContext, action: AgentAction) -> float: + industry_heat = context.features.get("industry_heat", 0.5) + relative_strength = context.features.get("industry_relative_mom", 0.0) + raw = min(1.0, max(0.0, industry_heat * 0.6 + relative_strength * 0.4)) + if action is AgentAction.SELL: + return 1 - raw + if action is AgentAction.HOLD: + return 0.5 + if action is AgentAction.BUY_S: + return raw * 0.6 + if action is AgentAction.BUY_M: + return raw * 0.8 + return raw diff --git a/app/agents/momentum.py b/app/agents/momentum.py new file mode 100644 index 0000000..3a7f17b --- /dev/null +++ b/app/agents/momentum.py @@ -0,0 +1,29 @@ +"""Momentum oriented agent.""" +from __future__ import annotations + +from math import tanh + +from .base import Agent, AgentAction, AgentContext + + +def _sigmoid(x: float) -> float: + return 0.5 * (tanh(x) + 1) + + +class MomentumAgent(Agent): + def __init__(self) -> None: + super().__init__(name="A_mom") + + def score(self, context: AgentContext, action: AgentAction) -> float: + mom20 = context.features.get("mom_20", 0.0) + mom60 = context.features.get("mom_60", 0.0) + strength = _sigmoid(0.5 * mom20 + 0.5 * mom60) + if action is AgentAction.SELL: + return 1 - strength + if action is AgentAction.HOLD: + return 0.5 + if action is AgentAction.BUY_S: + return strength * 0.6 + if action is AgentAction.BUY_M: + return strength * 0.8 + return strength diff --git a/app/agents/news.py b/app/agents/news.py new file mode 100644 index 0000000..e72b14c --- /dev/null +++ b/app/agents/news.py @@ -0,0 +1,26 @@ +"""News and sentiment aware agent.""" +from __future__ import annotations + +from .base import Agent, AgentAction, AgentContext + + +class NewsAgent(Agent): + def __init__(self) -> None: + super().__init__(name="A_news") + + def score(self, context: AgentContext, action: AgentAction) -> float: + heat = context.features.get("news_heat", 0.0) + sentiment = context.features.get("news_sentiment", 0.0) + positive = max(0.0, sentiment) + negative = max(0.0, -sentiment) + buy_score = min(1.0, heat * positive) + sell_score = min(1.0, heat * negative) + if action is AgentAction.SELL: + return sell_score + if action is AgentAction.HOLD: + return 0.3 + 0.4 * (1 - heat) + if action is AgentAction.BUY_S: + return 0.5 * buy_score + if action is AgentAction.BUY_M: + return 0.75 * buy_score + return buy_score diff --git a/app/agents/registry.py b/app/agents/registry.py new file mode 100644 index 0000000..1eafc7b --- /dev/null +++ b/app/agents/registry.py @@ -0,0 +1,30 @@ +"""Factory helpers to construct the agent ensemble.""" +from __future__ import annotations + +from typing import Dict, List + +from .base import Agent +from .liquidity import LiquidityAgent +from .macro import MacroAgent +from .momentum import MomentumAgent +from .news import NewsAgent +from .risk import RiskAgent +from .value import ValueAgent + + +def default_agents() -> List[Agent]: + return [ + MomentumAgent(), + ValueAgent(), + NewsAgent(), + LiquidityAgent(), + MacroAgent(), + RiskAgent(), + ] + + +def weight_map(raw: Dict[str, float]) -> Dict[str, float]: + total = sum(raw.values()) + if total == 0: + return raw + return {name: weight / total for name, weight in raw.items()} diff --git a/app/agents/risk.py b/app/agents/risk.py new file mode 100644 index 0000000..57497c9 --- /dev/null +++ b/app/agents/risk.py @@ -0,0 +1,29 @@ +"""Risk agent acts as leader with veto rights.""" +from __future__ import annotations + +from .base import Agent, AgentAction, AgentContext + + +class RiskAgent(Agent): + def __init__(self) -> None: + super().__init__(name="A_risk") + + def score(self, context: AgentContext, action: AgentAction) -> float: + # Base risk agent is neutral unless penalties are triggered. + penalty = context.features.get("risk_penalty", 0.0) + if action is AgentAction.SELL: + return min(1.0, 0.6 + penalty) + if action is AgentAction.HOLD: + return 0.5 + return max(0.0, 1.0 - penalty) + + def feasible(self, context: AgentContext, action: AgentAction) -> bool: + if action is AgentAction.SELL: + return True + if context.features.get("is_suspended", False): + return False + if context.features.get("limit_up", False) and action not in (AgentAction.SELL, AgentAction.HOLD): + return False + if context.features.get("position_limit", False) and action in (AgentAction.BUY_M, AgentAction.BUY_L): + return False + return True diff --git a/app/agents/value.py b/app/agents/value.py new file mode 100644 index 0000000..ff135e3 --- /dev/null +++ b/app/agents/value.py @@ -0,0 +1,26 @@ +"""Value and quality filtering agent.""" +from __future__ import annotations + +from .base import Agent, AgentAction, AgentContext + + +class ValueAgent(Agent): + def __init__(self) -> None: + super().__init__(name="A_val") + + def score(self, context: AgentContext, action: AgentAction) -> float: + pe = context.features.get("pe_percentile", 0.5) + pb = context.features.get("pb_percentile", 0.5) + roe = context.features.get("roe_percentile", 0.5) + # Lower valuation percentiles and higher quality percentiles add value. + raw = max(0.0, (1 - pe) * 0.4 + (1 - pb) * 0.3 + roe * 0.3) + raw = min(raw, 1.0) + if action is AgentAction.SELL: + return 1 - raw + if action is AgentAction.HOLD: + return 0.5 + if action is AgentAction.BUY_S: + return raw * 0.7 + if action is AgentAction.BUY_M: + return raw * 0.85 + return raw diff --git a/app/backtest/__init__.py b/app/backtest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/backtest/engine.py b/app/backtest/engine.py new file mode 100644 index 0000000..3876ea2 --- /dev/null +++ b/app/backtest/engine.py @@ -0,0 +1,90 @@ +"""Backtest engine skeleton for daily bar simulation.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import date +from typing import Dict, Iterable, List, Mapping + +from app.agents.base import AgentContext +from app.agents.game import Decision, decide +from app.agents.registry import default_agents +from app.utils.db import db_session + + +@dataclass +class BtConfig: + id: str + name: str + start_date: date + end_date: date + universe: List[str] + params: Dict[str, float] + method: str = "nash" + + +@dataclass +class PortfolioState: + cash: float = 1_000_000.0 + holdings: Dict[str, float] = field(default_factory=dict) + + +@dataclass +class BacktestResult: + nav_series: List[Dict[str, float]] = field(default_factory=list) + trades: List[Dict[str, str]] = field(default_factory=list) + + +class BacktestEngine: + """Runs the multi-agent game inside a daily event-driven loop.""" + + def __init__(self, cfg: BtConfig) -> None: + self.cfg = cfg + self.agents = default_agents() + self.weights = {agent.name: 1.0 for agent in self.agents} + + def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, float]]: + """Load per-stock feature vectors. Replace with real data access.""" + + _ = trade_date + return {} + + def simulate_day(self, trade_date: date, state: PortfolioState) -> List[Decision]: + feature_map = self.load_market_data(trade_date) + decisions: List[Decision] = [] + for ts_code, features in feature_map.items(): + context = AgentContext(ts_code=ts_code, trade_date=trade_date.isoformat(), features=features) + decision = decide(context, self.agents, self.weights, method=self.cfg.method) + decisions.append(decision) + self.record_agent_state(context, decision) + # TODO: translate decisions into fills, holdings, and NAV updates. + _ = state + return decisions + + def record_agent_state(self, context: AgentContext, decision: Decision) -> None: + payload = { + "trade_date": context.trade_date, + "ts_code": context.ts_code, + "action": decision.action.value, + "confidence": decision.confidence, + } + _ = payload + # Implementation should persist into agent_utils and bt_trades. + + def run(self) -> BacktestResult: + state = PortfolioState() + result = BacktestResult() + current = self.cfg.start_date + while current <= self.cfg.end_date: + decisions = self.simulate_day(current, state) + _ = decisions + current = date.fromordinal(current.toordinal() + 1) + return result + + +def run_backtest(cfg: BtConfig) -> BacktestResult: + engine = BacktestEngine(cfg) + result = engine.run() + with db_session() as conn: + _ = conn + # Implementation should persist bt_nav, bt_trades, and bt_report rows. + return result diff --git a/app/backtest/metrics.py b/app/backtest/metrics.py new file mode 100644 index 0000000..b002637 --- /dev/null +++ b/app/backtest/metrics.py @@ -0,0 +1,19 @@ +"""Performance metric utilities.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable, List + + +@dataclass +class Metric: + name: str + value: float + description: str + + +def compute_nav_metrics(nav_series: Iterable[float]) -> List[Metric]: + """Compute core statistics such as CAGR, Sharpe, and max drawdown.""" + + _ = nav_series + return [] diff --git a/app/data/__init__.py b/app/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/data/schema.py b/app/data/schema.py new file mode 100644 index 0000000..e50a1ee --- /dev/null +++ b/app/data/schema.py @@ -0,0 +1,146 @@ +"""Database schema management for the investment assistant.""" +from __future__ import annotations + +import sqlite3 +from dataclasses import dataclass +from typing import Iterable + +from app.utils.config import get_config +from app.utils.db import db_session + + +SCHEMA_STATEMENTS: Iterable[str] = ( + """ + CREATE TABLE IF NOT EXISTS news ( + id TEXT PRIMARY KEY, + ts_code TEXT, + pub_time TEXT, + source TEXT, + title TEXT, + summary TEXT, + url TEXT, + entities TEXT, + sentiment REAL, + heat REAL + ); + """, + """ + CREATE INDEX IF NOT EXISTS idx_news_time ON news(pub_time DESC); + """, + """ + CREATE INDEX IF NOT EXISTS idx_news_code ON news(ts_code, pub_time DESC); + """, + """ + CREATE TABLE IF NOT EXISTS heat_daily ( + scope TEXT, + key TEXT, + trade_date TEXT, + heat REAL, + top_topics TEXT, + PRIMARY KEY (scope, key, trade_date) + ); + """, + """ + CREATE TABLE IF NOT EXISTS bt_config ( + id TEXT PRIMARY KEY, + name TEXT, + start_date TEXT, + end_date TEXT, + universe TEXT, + params TEXT + ); + """, + """ + CREATE TABLE IF NOT EXISTS bt_trades ( + cfg_id TEXT, + ts_code TEXT, + trade_date TEXT, + side TEXT, + price REAL, + qty REAL, + reason TEXT, + PRIMARY KEY (cfg_id, ts_code, trade_date, side) + ); + """, + """ + CREATE TABLE IF NOT EXISTS bt_nav ( + cfg_id TEXT, + trade_date TEXT, + nav REAL, + ret REAL, + pos_count INTEGER, + turnover REAL, + dd REAL, + info TEXT, + PRIMARY KEY (cfg_id, trade_date) + ); + """, + """ + CREATE TABLE IF NOT EXISTS bt_report ( + cfg_id TEXT PRIMARY KEY, + summary TEXT + ); + """, + """ + CREATE TABLE IF NOT EXISTS run_log ( + ts TEXT PRIMARY KEY, + stage TEXT, + level TEXT, + msg TEXT + ); + """, + """ + CREATE TABLE IF NOT EXISTS agent_utils ( + trade_date TEXT, + ts_code TEXT, + agent TEXT, + action TEXT, + utils TEXT, + feasible TEXT, + weight REAL, + PRIMARY KEY (trade_date, ts_code, agent) + ); + """, + """ + CREATE TABLE IF NOT EXISTS alloc_log ( + trade_date TEXT, + ts_code TEXT, + target_weight REAL, + clipped_weight REAL, + reason TEXT, + PRIMARY KEY (trade_date, ts_code) + ); + """ +) + + +@dataclass +class MigrationResult: + executed: int + skipped: bool = False + + +def _schema_exists() -> bool: + try: + with db_session(read_only=True) as conn: + cursor = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='news'" + ) + return cursor.fetchone() is not None + except sqlite3.OperationalError: + return False + + +def initialize_database() -> MigrationResult: + """Create tables and indexes required by the application.""" + + if _schema_exists(): + return MigrationResult(executed=0, skipped=True) + + executed = 0 + with db_session() as conn: + cursor = conn.cursor() + for statement in SCHEMA_STATEMENTS: + cursor.executescript(statement) + executed += 1 + return MigrationResult(executed=executed) diff --git a/app/features/__init__.py b/app/features/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/features/factors.py b/app/features/factors.py new file mode 100644 index 0000000..485884f --- /dev/null +++ b/app/features/factors.py @@ -0,0 +1,39 @@ +"""Feature engineering for signals and indicator computation.""" +from __future__ import annotations + +from dataclasses import dataclass +from datetime import date +from typing import Iterable, List + + +@dataclass +class FactorSpec: + name: str + window: int + + +@dataclass +class FactorResult: + ts_code: str + trade_date: date + values: dict + + +DEFAULT_FACTORS: List[FactorSpec] = [ + FactorSpec("mom_20", 20), + FactorSpec("mom_60", 60), + FactorSpec("volat_20", 20), + FactorSpec("turn_20", 20), +] + + +def compute_factors(trade_date: date, factors: Iterable[FactorSpec] = DEFAULT_FACTORS) -> List[FactorResult]: + """Calculate factor values for the requested date. + + This function should join historical price data, apply rolling windows, and + persist results into an factors table. The implementation is left as future + work. + """ + + _ = trade_date, factors + raise NotImplementedError diff --git a/app/ingest/__init__.py b/app/ingest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/ingest/rss.py b/app/ingest/rss.py new file mode 100644 index 0000000..50b97d4 --- /dev/null +++ b/app/ingest/rss.py @@ -0,0 +1,42 @@ +"""RSS ingestion for news and heat scores.""" +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime +from typing import Iterable, List + + +@dataclass +class RssItem: + id: str + title: str + link: str + published: datetime + summary: str + source: str + + +def fetch_rss_feed(url: str) -> List[RssItem]: + """Download and parse an RSS feed into structured items.""" + + raise NotImplementedError + + +def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]: + """Drop duplicate stories by link/id fingerprint.""" + + seen = set() + unique: List[RssItem] = [] + for item in items: + key = item.id or item.link + if key in seen: + continue + seen.add(key) + unique.append(item) + return unique + + +def save_news_items(items: Iterable[RssItem]) -> None: + """Persist RSS items into the `news` table.""" + + raise NotImplementedError diff --git a/app/ingest/tushare.py b/app/ingest/tushare.py new file mode 100644 index 0000000..46e3e2d --- /dev/null +++ b/app/ingest/tushare.py @@ -0,0 +1,240 @@ +"""TuShare 数据拉取管线实现。""" +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass +from datetime import date +from typing import Dict, Iterable, List, Optional, Sequence + +import pandas as pd + +try: + import tushare as ts +except ImportError as exc: # pragma: no cover - dependency error surfaced at runtime + ts = None # type: ignore[assignment] + +from app.utils.config import get_config +from app.utils.db import db_session + +LOGGER = logging.getLogger(__name__) + + +@dataclass +class FetchJob: + name: str + start: date + end: date + granularity: str = "daily" + ts_codes: Optional[Sequence[str]] = None + + +_TABLE_SCHEMAS: Dict[str, str] = { + "daily": """ + CREATE TABLE IF NOT EXISTS daily ( + ts_code TEXT, + trade_date TEXT, + open REAL, + high REAL, + low REAL, + close REAL, + pre_close REAL, + change REAL, + pct_chg REAL, + vol REAL, + amount REAL, + PRIMARY KEY (ts_code, trade_date) + ); + """, + "suspend": """ + CREATE TABLE IF NOT EXISTS suspend ( + ts_code TEXT, + suspend_date TEXT, + resume_date TEXT, + suspend_type TEXT, + ann_date TEXT, + suspend_timing TEXT, + resume_timing TEXT, + reason TEXT, + PRIMARY KEY (ts_code, suspend_date) + ); + """, + "trade_calendar": """ + CREATE TABLE IF NOT EXISTS trade_calendar ( + exchange TEXT, + cal_date TEXT PRIMARY KEY, + is_open INTEGER, + pretrade_date TEXT + ); + """, + "stk_limit": """ + CREATE TABLE IF NOT EXISTS stk_limit ( + ts_code TEXT, + trade_date TEXT, + up_limit REAL, + down_limit REAL, + PRIMARY KEY (ts_code, trade_date) + ); + """, +} + +_TABLE_COLUMNS: Dict[str, List[str]] = { + "daily": [ + "ts_code", + "trade_date", + "open", + "high", + "low", + "close", + "pre_close", + "change", + "pct_chg", + "vol", + "amount", + ], + "suspend": [ + "ts_code", + "suspend_date", + "resume_date", + "suspend_type", + "ann_date", + "suspend_timing", + "resume_timing", + "reason", + ], + "trade_calendar": [ + "exchange", + "cal_date", + "is_open", + "pretrade_date", + ], + "stk_limit": [ + "ts_code", + "trade_date", + "up_limit", + "down_limit", + ], +} + + +def _ensure_client(): + if ts is None: + raise RuntimeError("未安装 tushare,请先在环境中安装 tushare 包") + token = get_config().tushare_token or os.getenv("TUSHARE_TOKEN") + if not token: + raise RuntimeError("未配置 TuShare Token,请在配置文件或环境变量 TUSHARE_TOKEN 中设置") + if not hasattr(_ensure_client, "_client") or _ensure_client._client is None: # type: ignore[attr-defined] + ts.set_token(token) + _ensure_client._client = ts.pro_api(token) # type: ignore[attr-defined] + LOGGER.info("完成 TuShare 客户端初始化") + return _ensure_client._client # type: ignore[attr-defined] + + +def _format_date(value: date) -> str: + return value.strftime("%Y%m%d") + + +def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]: + if df is None or df.empty: + return [] + # 对缺失列进行补全,防止写库时缺少绑定参数 + reindexed = df.reindex(columns=allowed_cols) + return reindexed.where(pd.notnull(reindexed), None).to_dict("records") + + +def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]: + """拉取日线行情。""" + + client = _ensure_client() + start_date = _format_date(job.start) + end_date = _format_date(job.end) + frames: List[pd.DataFrame] = [] + + if job.granularity != "daily": + raise ValueError(f"暂不支持的粒度:{job.granularity}") + + if job.ts_codes: + for code in job.ts_codes: + LOGGER.info("拉取 %s 的日线行情(%s-%s)", code, start_date, end_date) + frames.append(client.daily(ts_code=code, start_date=start_date, end_date=end_date)) + else: + LOGGER.info("按全市场拉取日线行情(%s-%s)", start_date, end_date) + frames.append(client.daily(start_date=start_date, end_date=end_date)) + + if not frames: + return [] + df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0] + return _df_to_records(df, _TABLE_COLUMNS["daily"]) + + +def fetch_suspensions(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]: + client = _ensure_client() + start_date = _format_date(start) + end_date = _format_date(end) + LOGGER.info("拉取停复牌信息(%s-%s)", start_date, end_date) + df = client.suspend_d(ts_code=ts_code, start_date=start_date, end_date=end_date) + return _df_to_records(df, _TABLE_COLUMNS["suspend"]) + + +def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]: + client = _ensure_client() + start_date = _format_date(start) + end_date = _format_date(end) + LOGGER.info("拉取交易日历(交易所:%s,区间:%s-%s)", exchange, start_date, end_date) + df = client.trade_cal(exchange=exchange, start_date=start_date, end_date=end_date) + return _df_to_records(df, _TABLE_COLUMNS["trade_calendar"]) + + +def fetch_stk_limit(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]: + client = _ensure_client() + start_date = _format_date(start) + end_date = _format_date(end) + LOGGER.info("拉取涨跌停价格(%s-%s)", start_date, end_date) + df = client.stk_limit(ts_code=ts_code, start_date=start_date, end_date=end_date) + return _df_to_records(df, _TABLE_COLUMNS["stk_limit"]) + + +def save_records(table: str, rows: Iterable[Dict]) -> None: + """将拉取的数据写入 SQLite。""" + + items = list(rows) + if not items: + LOGGER.info("表 %s 没有新增记录,跳过写入", table) + return + + schema = _TABLE_SCHEMAS.get(table) + columns = _TABLE_COLUMNS.get(table) + if not schema or not columns: + raise ValueError(f"不支持写入的表:{table}") + + placeholders = ",".join([f":{col}" for col in columns]) + col_clause = ",".join(columns) + + LOGGER.info("表 %s 写入 %d 条记录", table, len(items)) + with db_session() as conn: + conn.executescript(schema) + conn.executemany( + f"INSERT OR REPLACE INTO {table} ({col_clause}) VALUES ({placeholders})", + items, + ) + + +def run_ingestion(job: FetchJob, include_limits: bool = True) -> None: + """按任务配置拉取 TuShare 数据。""" + + LOGGER.info("启动 TuShare 拉取任务:%s", job.name) + + daily_rows = fetch_daily_bars(job) + save_records("daily", daily_rows) + + suspend_rows = fetch_suspensions(job.start, job.end) + save_records("suspend", suspend_rows) + + calendar_rows = fetch_trade_calendar(job.start, job.end) + save_records("trade_calendar", calendar_rows) + + if include_limits: + limit_rows = fetch_stk_limit(job.start, job.end) + save_records("stk_limit", limit_rows) + + LOGGER.info("任务 %s 完成", job.name) diff --git a/app/llm/__init__.py b/app/llm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/llm/explain.py b/app/llm/explain.py new file mode 100644 index 0000000..5d6a7bb --- /dev/null +++ b/app/llm/explain.py @@ -0,0 +1,18 @@ +"""LLM assisted explanations and summaries.""" +from __future__ import annotations + +from typing import Dict + +from .prompts import plan_prompt + + +def make_human_card(ts_code: str, trade_date: str, context: Dict) -> Dict: + """Compose payload for UI cards and LLM requests.""" + + prompt = plan_prompt(context) + return { + "ts_code": ts_code, + "trade_date": trade_date, + "prompt": prompt, + "context": context, + } diff --git a/app/llm/prompts.py b/app/llm/prompts.py new file mode 100644 index 0000000..96cd99f --- /dev/null +++ b/app/llm/prompts.py @@ -0,0 +1,11 @@ +"""Prompt templates for natural language outputs.""" +from __future__ import annotations + +from typing import Dict + + +def plan_prompt(data: Dict) -> str: + """Build a concise instruction prompt for the LLM.""" + + _ = data + return "你是一个投资助理,请根据提供的数据给出三条要点和两条风险提示。" diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..ba8901a --- /dev/null +++ b/app/main.py @@ -0,0 +1,35 @@ +"""Command line entry points for routine tasks.""" +from __future__ import annotations + +from datetime import date + +from app.backtest.engine import BtConfig, run_backtest +from app.data.schema import initialize_database + + +def init_db() -> None: + result = initialize_database() + if result.skipped: + print("Database already initialized; skipping schema creation") + else: + print(f"Initialized database with {result.executed} statements") + + +def run_sample_backtest() -> None: + cfg = BtConfig( + id="demo", + name="Demo Strategy", + start_date=date(2020, 1, 1), + end_date=date(2020, 3, 31), + universe=["000001.SZ"], + params={ + "target": 0.035, + "stop": -0.015, + "hold_days": 10, + }, + ) + run_backtest(cfg) + + +if __name__ == "__main__": + init_db() diff --git a/app/ui/__init__.py b/app/ui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py new file mode 100644 index 0000000..a25dce2 --- /dev/null +++ b/app/ui/streamlit_app.py @@ -0,0 +1,85 @@ +"""Streamlit UI scaffold for the investment assistant.""" +from __future__ import annotations + +import sys +from datetime import date +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[2] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +import streamlit as st + +from app.data.schema import initialize_database +from app.ingest.tushare import FetchJob, run_ingestion +from app.llm.explain import make_human_card + + +def render_today_plan() -> None: + st.header("今日计划") + st.write("待接入候选池筛选与多智能体决策结果。") + sample = make_human_card("000001.SZ", "2025-01-01", {"decisions": []}) + st.json(sample) + + +def render_backtest() -> None: + st.header("回测与复盘") + st.write("在此运行回测、展示净值曲线与代理贡献。") + st.button("开始回测") + + +def render_settings() -> None: + st.header("数据与设置") + st.text_input("TuShare Token") + st.write("新闻源开关与数据库备份将在此配置。") + + +def render_tests() -> None: + st.header("自检测试") + st.write("用于快速检查数据库与数据拉取是否正常工作。") + + if st.button("测试数据库初始化"): + with st.spinner("正在检查数据库..."): + result = initialize_database() + if result.skipped: + st.success("数据库已存在,检查通过。") + else: + st.success(f"数据库初始化完成,共执行 {result.executed} 条语句。") + + st.divider() + + if st.button("测试 TuShare 拉取(示例 2024-01-01 至 2024-01-03)"): + with st.spinner("正在调用 TuShare 接口..."): + try: + run_ingestion( + FetchJob( + name="streamlit_self_test", + start=date(2024, 1, 1), + end=date(2024, 1, 3), + ts_codes=("000001.SZ",), + ), + include_limits=False, + ) + st.success("TuShare 示例拉取完成,数据已写入数据库。") + except Exception as exc: # noqa: BLE001 + st.error(f"拉取失败:{exc}") + + st.info("注意:TuShare 拉取依赖网络与 Token,若环境未配置将出现错误提示。") + + +def main() -> None: + st.set_page_config(page_title="多智能体投资助理", layout="wide") + tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"]) + with tabs[0]: + render_today_plan() + with tabs[1]: + render_backtest() + with tabs[2]: + render_settings() + with tabs[3]: + render_tests() + + +if __name__ == "__main__": + main() diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/utils/calendar.py b/app/utils/calendar.py new file mode 100644 index 0000000..60cc3c2 --- /dev/null +++ b/app/utils/calendar.py @@ -0,0 +1,34 @@ +"""Trading calendar utilities. + +These helpers abstract exchange calendars and trading day lookups. Real +implementations should integrate with TuShare or cached calendars. +""" +from __future__ import annotations + +from datetime import date, timedelta +from typing import Iterable, List + + +def is_trading_day(day: date, holidays: Iterable[date] | None = None) -> bool: + if day.weekday() >= 5: + return False + if holidays and day in set(holidays): + return False + return True + + +def previous_trading_day(day: date, holidays: Iterable[date] | None = None) -> date: + current = day - timedelta(days=1) + while not is_trading_day(current, holidays): + current -= timedelta(days=1) + return current + + +def trading_days_between(start: date, end: date, holidays: Iterable[date] | None = None) -> List[date]: + current = start + days: List[date] = [] + while current <= end: + if is_trading_day(current, holidays): + days.append(current) + current += timedelta(days=1) + return days diff --git a/app/utils/config.py b/app/utils/config.py new file mode 100644 index 0000000..fb0a128 --- /dev/null +++ b/app/utils/config.py @@ -0,0 +1,65 @@ +"""Application configuration models and helpers.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, Optional + + +def _default_root() -> Path: + return Path(__file__).resolve().parents[2] / "app" / "data" + + +@dataclass +class DataPaths: + """Holds filesystem locations for persistent artifacts.""" + + root: Path = field(default_factory=_default_root) + database: Path = field(init=False) + backups: Path = field(init=False) + + def __post_init__(self) -> None: + self.root.mkdir(parents=True, exist_ok=True) + self.database = self.root / "llm_quant.db" + self.backups = self.root / "backups" + self.backups.mkdir(parents=True, exist_ok=True) + + +@dataclass +class AgentWeights: + """Default weighting for decision agents.""" + + momentum: float = 0.30 + value: float = 0.20 + news: float = 0.20 + liquidity: float = 0.15 + macro: float = 0.15 + + def as_dict(self) -> Dict[str, float]: + return { + "A_mom": self.momentum, + "A_val": self.value, + "A_news": self.news, + "A_liq": self.liquidity, + "A_macro": self.macro, + } + + +@dataclass +class AppConfig: + """User configurable settings persisted in a simple structure.""" + + tushare_token: Optional[str] = None + rss_sources: Dict[str, bool] = field(default_factory=dict) + decision_method: str = "nash" + data_paths: DataPaths = field(default_factory=DataPaths) + agent_weights: AgentWeights = field(default_factory=AgentWeights) + + +CONFIG = AppConfig() + + +def get_config() -> AppConfig: + """Return a mutable global configuration instance.""" + + return CONFIG diff --git a/app/utils/db.py b/app/utils/db.py new file mode 100644 index 0000000..32ee0d9 --- /dev/null +++ b/app/utils/db.py @@ -0,0 +1,47 @@ +"""SQLite helpers for application data access.""" +from __future__ import annotations + +import sqlite3 +from contextlib import contextmanager +from pathlib import Path +from typing import Generator, Iterable, Tuple + +from .config import get_config + + +def get_connection(read_only: bool = False) -> sqlite3.Connection: + """Create a SQLite connection against the configured database file.""" + + db_path: Path = get_config().data_paths.database + uri = f"file:{db_path}?mode={'ro' if read_only else 'rw'}" + if not db_path.exists() and not read_only: + # Ensure directory exists before first write. + db_path.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(db_path) + else: + conn = sqlite3.connect(uri, uri=True) + conn.row_factory = sqlite3.Row + return conn + + +@contextmanager +def db_session(read_only: bool = False) -> Generator[sqlite3.Connection, None, None]: + """Context manager yielding a connection with automatic commit/rollback.""" + + conn = get_connection(read_only=read_only) + try: + yield conn + if not read_only: + conn.commit() + except Exception: + conn.rollback() + raise + finally: + conn.close() + + +def execute_many(sql: str, params: Iterable[Tuple]) -> None: + """Bulk execute a parameterized statement inside a write session.""" + + with db_session() as conn: + conn.executemany(sql, params) diff --git a/app/utils/logging.py b/app/utils/logging.py new file mode 100644 index 0000000..687a850 --- /dev/null +++ b/app/utils/logging.py @@ -0,0 +1,28 @@ +"""Centralized logging configuration.""" +from __future__ import annotations + +import logging +from pathlib import Path + +from .config import get_config + + +def configure_logging(level: int = logging.INFO) -> None: + """Setup root logger with file and console handlers.""" + + cfg = get_config() + log_dir = cfg.data_paths.root / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + logfile = log_dir / "app.log" + + logging.basicConfig( + level=level, + format="%(asctime)s %(levelname)s %(name)s - %(message)s", + handlers=[ + logging.FileHandler(logfile, encoding="utf-8"), + logging.StreamHandler(), + ], + ) + + +configure_logging()