This commit is contained in:
sam 2025-09-26 18:21:25 +08:00
commit 6eac6c5f69
34 changed files with 1309 additions and 0 deletions

29
.gitignore vendored Normal file
View File

@ -0,0 +1,29 @@
# Python cache and build artifacts
__pycache__/
*.py[cod]
*.so
.build/
.eggs/
*.egg-info/
# Virtual environments
.venv/
venv/
ENV/
env/
# Jupyter checkpoints
.ipynb_checkpoints/
# Logs and data
app/data/*.db*
app/data/backups/
app/data/logs/
*.log
# Streamlit temporary files
.streamlit/
# System files
.DS_Store
Thumbs.db

27
README.md Normal file
View File

@ -0,0 +1,27 @@
# 多智能体投资助理骨架
本仓库提供一个基于多智能体博弈的 A 股日线投资助理代码框架满足单机可运行、SQLite 存储和 Streamlit UI 的需求。核心模块划分如下:
- `app/data`:数据库初始化与 Schema 定义。
- `app/utils`:配置、数据库连接、日志和交易日历工具。
- `app/ingest`TuShare 与 RSS 数据拉取骨架。
- `app/features`:指标与信号计算接口。
- `app/agents`:多智能体博弈实现,包括动量、价值、新闻、流动性、宏观与风险代理。
- `app/backtest`:日线回测引擎与指标计算的占位实现。
- `app/llm`:人类可读卡片与摘要生成入口(仅构建提示,不直接交易)。
- `app/ui`Streamlit 三页界面骨架。
## 快速开始
```bash
python -m app.main # 初始化数据库
streamlit run app/ui/streamlit_app.py
```
## 下一步
1. 在 `app/ingest` 中补充 TuShare 和 RSS 数据抓取逻辑。
2. 完善 `app/features``app/backtest` 以实现实际的信号计算与事件驱动回测。
3. 将代理效用写入 SQLite 的 `agent_utils``alloc_log` 表,驱动 UI 展示。
4. 使用轻量情感分析与热度计算,填充 `news``heat_daily`
5. 接入本地小模型或 API 完成 LLM 文本解释,并在 UI 中展示。

0
app/__init__.py Normal file
View File

0
app/agents/__init__.py Normal file
View File

44
app/agents/base.py Normal file
View File

@ -0,0 +1,44 @@
"""Agent abstractions for the multi-agent decision engine."""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Mapping
class AgentAction(str, Enum):
SELL = "SELL"
HOLD = "HOLD"
BUY_S = "BUY_S"
BUY_M = "BUY_M"
BUY_L = "BUY_L"
@dataclass
class AgentContext:
ts_code: str
trade_date: str
features: Mapping[str, float]
class Agent:
"""Base class for all decision agents."""
name: str
def __init__(self, name: str) -> None:
self.name = name
def score(self, context: AgentContext, action: AgentAction) -> float:
"""Return a normalized utility value in [0,1] for the proposed action."""
raise NotImplementedError
def feasible(self, context: AgentContext, action: AgentAction) -> bool:
"""Optional hook for agents with veto power (defaults to True)."""
_ = context, action
return True
UtilityMatrix = Dict[AgentAction, Dict[str, float]]

127
app/agents/game.py Normal file
View File

@ -0,0 +1,127 @@
"""Multi-agent decision game implementation."""
from __future__ import annotations
from dataclasses import dataclass
from math import exp, log
from typing import Dict, Iterable, List, Mapping, Tuple
from .base import Agent, AgentAction, AgentContext, UtilityMatrix
from .registry import weight_map
ACTIONS: Tuple[AgentAction, ...] = (
AgentAction.SELL,
AgentAction.HOLD,
AgentAction.BUY_S,
AgentAction.BUY_M,
AgentAction.BUY_L,
)
def _clamp(value: float) -> float:
return max(0.0, min(1.0, value))
@dataclass
class Decision:
action: AgentAction
confidence: float
target_weight: float
feasible_actions: List[AgentAction]
utilities: UtilityMatrix
def compute_utilities(agents: Iterable[Agent], context: AgentContext) -> UtilityMatrix:
utilities: UtilityMatrix = {}
for action in ACTIONS:
utilities[action] = {}
for agent in agents:
score = _clamp(agent.score(context, action))
utilities[action][agent.name] = score
return utilities
def feasible_actions(agents: Iterable[Agent], context: AgentContext) -> List[AgentAction]:
feas: List[AgentAction] = []
for action in ACTIONS:
if all(agent.feasible(context, action) for agent in agents):
feas.append(action)
return feas
def nash_bargain(utilities: UtilityMatrix, weights: Mapping[str, float], disagreement: Mapping[str, float]) -> Tuple[AgentAction, float]:
best_action = AgentAction.HOLD
best_score = float("-inf")
for action, agent_scores in utilities.items():
if action not in utilities:
continue
log_product = 0.0
valid = True
for agent_name, score in agent_scores.items():
w = weights.get(agent_name, 0.0)
if w == 0:
continue
gap = score - disagreement.get(agent_name, 0.0)
if gap <= 0:
valid = False
break
log_product += w * log(gap)
if not valid:
continue
if log_product > best_score:
best_score = log_product
best_action = action
if best_score == float("-inf"):
return AgentAction.HOLD, 0.0
confidence = _aggregate_confidence(utilities[best_action], weights)
return best_action, confidence
def vote(utilities: UtilityMatrix, weights: Mapping[str, float]) -> Tuple[AgentAction, float]:
scores: Dict[AgentAction, float] = {}
for action, agent_scores in utilities.items():
scores[action] = sum(weights.get(agent, 0.0) * score for agent, score in agent_scores.items())
best_action = max(scores, key=scores.get)
confidence = _aggregate_confidence(utilities[best_action], weights)
return best_action, confidence
def _aggregate_confidence(agent_scores: Mapping[str, float], weights: Mapping[str, float]) -> float:
total = sum(weights.values())
if total <= 0:
return 0.0
weighted = sum(weights.get(agent, 0.0) * score for agent, score in agent_scores.items())
return weighted / total
def target_weight_for_action(action: AgentAction) -> float:
mapping = {
AgentAction.SELL: -1.0,
AgentAction.HOLD: 0.0,
AgentAction.BUY_S: 0.01,
AgentAction.BUY_M: 0.02,
AgentAction.BUY_L: 0.03,
}
return mapping[action]
def decide(context: AgentContext, agents: Iterable[Agent], weights: Mapping[str, float], method: str = "nash") -> Decision:
agent_list = list(agents)
norm_weights = weight_map(dict(weights))
utilities = compute_utilities(agent_list, context)
feas_actions = feasible_actions(agent_list, context)
if not feas_actions:
return Decision(AgentAction.HOLD, 0.0, 0.0, [], utilities)
filtered_utilities = {action: utilities[action] for action in feas_actions}
hold_scores = utilities.get(AgentAction.HOLD, {})
if method == "vote":
action, confidence = vote(filtered_utilities, norm_weights)
else:
action, confidence = nash_bargain(filtered_utilities, norm_weights, hold_scores)
if action not in feas_actions:
action, confidence = vote(filtered_utilities, norm_weights)
weight = target_weight_for_action(action)
return Decision(action, confidence, weight, feas_actions, utilities)

20
app/agents/liquidity.py Normal file
View File

@ -0,0 +1,20 @@
"""Liquidity and transaction cost agent."""
from __future__ import annotations
from .base import Agent, AgentAction, AgentContext
class LiquidityAgent(Agent):
def __init__(self) -> None:
super().__init__(name="A_liq")
def score(self, context: AgentContext, action: AgentAction) -> float:
liq = context.features.get("liquidity_score", 0.5)
cost = context.features.get("cost_penalty", 0.0)
penalty = cost
if action is AgentAction.SELL:
return min(1.0, liq + penalty)
if action is AgentAction.HOLD:
return 0.4 + 0.2 * liq
scale = {AgentAction.BUY_S: 0.5, AgentAction.BUY_M: 0.75, AgentAction.BUY_L: 1.0}.get(action, 0.0)
return max(0.0, liq * scale - penalty)

23
app/agents/macro.py Normal file
View File

@ -0,0 +1,23 @@
"""Macro and industry regime agent."""
from __future__ import annotations
from .base import Agent, AgentAction, AgentContext
class MacroAgent(Agent):
def __init__(self) -> None:
super().__init__(name="A_macro")
def score(self, context: AgentContext, action: AgentAction) -> float:
industry_heat = context.features.get("industry_heat", 0.5)
relative_strength = context.features.get("industry_relative_mom", 0.0)
raw = min(1.0, max(0.0, industry_heat * 0.6 + relative_strength * 0.4))
if action is AgentAction.SELL:
return 1 - raw
if action is AgentAction.HOLD:
return 0.5
if action is AgentAction.BUY_S:
return raw * 0.6
if action is AgentAction.BUY_M:
return raw * 0.8
return raw

29
app/agents/momentum.py Normal file
View File

@ -0,0 +1,29 @@
"""Momentum oriented agent."""
from __future__ import annotations
from math import tanh
from .base import Agent, AgentAction, AgentContext
def _sigmoid(x: float) -> float:
return 0.5 * (tanh(x) + 1)
class MomentumAgent(Agent):
def __init__(self) -> None:
super().__init__(name="A_mom")
def score(self, context: AgentContext, action: AgentAction) -> float:
mom20 = context.features.get("mom_20", 0.0)
mom60 = context.features.get("mom_60", 0.0)
strength = _sigmoid(0.5 * mom20 + 0.5 * mom60)
if action is AgentAction.SELL:
return 1 - strength
if action is AgentAction.HOLD:
return 0.5
if action is AgentAction.BUY_S:
return strength * 0.6
if action is AgentAction.BUY_M:
return strength * 0.8
return strength

26
app/agents/news.py Normal file
View File

@ -0,0 +1,26 @@
"""News and sentiment aware agent."""
from __future__ import annotations
from .base import Agent, AgentAction, AgentContext
class NewsAgent(Agent):
def __init__(self) -> None:
super().__init__(name="A_news")
def score(self, context: AgentContext, action: AgentAction) -> float:
heat = context.features.get("news_heat", 0.0)
sentiment = context.features.get("news_sentiment", 0.0)
positive = max(0.0, sentiment)
negative = max(0.0, -sentiment)
buy_score = min(1.0, heat * positive)
sell_score = min(1.0, heat * negative)
if action is AgentAction.SELL:
return sell_score
if action is AgentAction.HOLD:
return 0.3 + 0.4 * (1 - heat)
if action is AgentAction.BUY_S:
return 0.5 * buy_score
if action is AgentAction.BUY_M:
return 0.75 * buy_score
return buy_score

30
app/agents/registry.py Normal file
View File

@ -0,0 +1,30 @@
"""Factory helpers to construct the agent ensemble."""
from __future__ import annotations
from typing import Dict, List
from .base import Agent
from .liquidity import LiquidityAgent
from .macro import MacroAgent
from .momentum import MomentumAgent
from .news import NewsAgent
from .risk import RiskAgent
from .value import ValueAgent
def default_agents() -> List[Agent]:
return [
MomentumAgent(),
ValueAgent(),
NewsAgent(),
LiquidityAgent(),
MacroAgent(),
RiskAgent(),
]
def weight_map(raw: Dict[str, float]) -> Dict[str, float]:
total = sum(raw.values())
if total == 0:
return raw
return {name: weight / total for name, weight in raw.items()}

29
app/agents/risk.py Normal file
View File

@ -0,0 +1,29 @@
"""Risk agent acts as leader with veto rights."""
from __future__ import annotations
from .base import Agent, AgentAction, AgentContext
class RiskAgent(Agent):
def __init__(self) -> None:
super().__init__(name="A_risk")
def score(self, context: AgentContext, action: AgentAction) -> float:
# Base risk agent is neutral unless penalties are triggered.
penalty = context.features.get("risk_penalty", 0.0)
if action is AgentAction.SELL:
return min(1.0, 0.6 + penalty)
if action is AgentAction.HOLD:
return 0.5
return max(0.0, 1.0 - penalty)
def feasible(self, context: AgentContext, action: AgentAction) -> bool:
if action is AgentAction.SELL:
return True
if context.features.get("is_suspended", False):
return False
if context.features.get("limit_up", False) and action not in (AgentAction.SELL, AgentAction.HOLD):
return False
if context.features.get("position_limit", False) and action in (AgentAction.BUY_M, AgentAction.BUY_L):
return False
return True

26
app/agents/value.py Normal file
View File

@ -0,0 +1,26 @@
"""Value and quality filtering agent."""
from __future__ import annotations
from .base import Agent, AgentAction, AgentContext
class ValueAgent(Agent):
def __init__(self) -> None:
super().__init__(name="A_val")
def score(self, context: AgentContext, action: AgentAction) -> float:
pe = context.features.get("pe_percentile", 0.5)
pb = context.features.get("pb_percentile", 0.5)
roe = context.features.get("roe_percentile", 0.5)
# Lower valuation percentiles and higher quality percentiles add value.
raw = max(0.0, (1 - pe) * 0.4 + (1 - pb) * 0.3 + roe * 0.3)
raw = min(raw, 1.0)
if action is AgentAction.SELL:
return 1 - raw
if action is AgentAction.HOLD:
return 0.5
if action is AgentAction.BUY_S:
return raw * 0.7
if action is AgentAction.BUY_M:
return raw * 0.85
return raw

0
app/backtest/__init__.py Normal file
View File

90
app/backtest/engine.py Normal file
View File

@ -0,0 +1,90 @@
"""Backtest engine skeleton for daily bar simulation."""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import date
from typing import Dict, Iterable, List, Mapping
from app.agents.base import AgentContext
from app.agents.game import Decision, decide
from app.agents.registry import default_agents
from app.utils.db import db_session
@dataclass
class BtConfig:
id: str
name: str
start_date: date
end_date: date
universe: List[str]
params: Dict[str, float]
method: str = "nash"
@dataclass
class PortfolioState:
cash: float = 1_000_000.0
holdings: Dict[str, float] = field(default_factory=dict)
@dataclass
class BacktestResult:
nav_series: List[Dict[str, float]] = field(default_factory=list)
trades: List[Dict[str, str]] = field(default_factory=list)
class BacktestEngine:
"""Runs the multi-agent game inside a daily event-driven loop."""
def __init__(self, cfg: BtConfig) -> None:
self.cfg = cfg
self.agents = default_agents()
self.weights = {agent.name: 1.0 for agent in self.agents}
def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, float]]:
"""Load per-stock feature vectors. Replace with real data access."""
_ = trade_date
return {}
def simulate_day(self, trade_date: date, state: PortfolioState) -> List[Decision]:
feature_map = self.load_market_data(trade_date)
decisions: List[Decision] = []
for ts_code, features in feature_map.items():
context = AgentContext(ts_code=ts_code, trade_date=trade_date.isoformat(), features=features)
decision = decide(context, self.agents, self.weights, method=self.cfg.method)
decisions.append(decision)
self.record_agent_state(context, decision)
# TODO: translate decisions into fills, holdings, and NAV updates.
_ = state
return decisions
def record_agent_state(self, context: AgentContext, decision: Decision) -> None:
payload = {
"trade_date": context.trade_date,
"ts_code": context.ts_code,
"action": decision.action.value,
"confidence": decision.confidence,
}
_ = payload
# Implementation should persist into agent_utils and bt_trades.
def run(self) -> BacktestResult:
state = PortfolioState()
result = BacktestResult()
current = self.cfg.start_date
while current <= self.cfg.end_date:
decisions = self.simulate_day(current, state)
_ = decisions
current = date.fromordinal(current.toordinal() + 1)
return result
def run_backtest(cfg: BtConfig) -> BacktestResult:
engine = BacktestEngine(cfg)
result = engine.run()
with db_session() as conn:
_ = conn
# Implementation should persist bt_nav, bt_trades, and bt_report rows.
return result

19
app/backtest/metrics.py Normal file
View File

@ -0,0 +1,19 @@
"""Performance metric utilities."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, List
@dataclass
class Metric:
name: str
value: float
description: str
def compute_nav_metrics(nav_series: Iterable[float]) -> List[Metric]:
"""Compute core statistics such as CAGR, Sharpe, and max drawdown."""
_ = nav_series
return []

0
app/data/__init__.py Normal file
View File

146
app/data/schema.py Normal file
View File

@ -0,0 +1,146 @@
"""Database schema management for the investment assistant."""
from __future__ import annotations
import sqlite3
from dataclasses import dataclass
from typing import Iterable
from app.utils.config import get_config
from app.utils.db import db_session
SCHEMA_STATEMENTS: Iterable[str] = (
"""
CREATE TABLE IF NOT EXISTS news (
id TEXT PRIMARY KEY,
ts_code TEXT,
pub_time TEXT,
source TEXT,
title TEXT,
summary TEXT,
url TEXT,
entities TEXT,
sentiment REAL,
heat REAL
);
""",
"""
CREATE INDEX IF NOT EXISTS idx_news_time ON news(pub_time DESC);
""",
"""
CREATE INDEX IF NOT EXISTS idx_news_code ON news(ts_code, pub_time DESC);
""",
"""
CREATE TABLE IF NOT EXISTS heat_daily (
scope TEXT,
key TEXT,
trade_date TEXT,
heat REAL,
top_topics TEXT,
PRIMARY KEY (scope, key, trade_date)
);
""",
"""
CREATE TABLE IF NOT EXISTS bt_config (
id TEXT PRIMARY KEY,
name TEXT,
start_date TEXT,
end_date TEXT,
universe TEXT,
params TEXT
);
""",
"""
CREATE TABLE IF NOT EXISTS bt_trades (
cfg_id TEXT,
ts_code TEXT,
trade_date TEXT,
side TEXT,
price REAL,
qty REAL,
reason TEXT,
PRIMARY KEY (cfg_id, ts_code, trade_date, side)
);
""",
"""
CREATE TABLE IF NOT EXISTS bt_nav (
cfg_id TEXT,
trade_date TEXT,
nav REAL,
ret REAL,
pos_count INTEGER,
turnover REAL,
dd REAL,
info TEXT,
PRIMARY KEY (cfg_id, trade_date)
);
""",
"""
CREATE TABLE IF NOT EXISTS bt_report (
cfg_id TEXT PRIMARY KEY,
summary TEXT
);
""",
"""
CREATE TABLE IF NOT EXISTS run_log (
ts TEXT PRIMARY KEY,
stage TEXT,
level TEXT,
msg TEXT
);
""",
"""
CREATE TABLE IF NOT EXISTS agent_utils (
trade_date TEXT,
ts_code TEXT,
agent TEXT,
action TEXT,
utils TEXT,
feasible TEXT,
weight REAL,
PRIMARY KEY (trade_date, ts_code, agent)
);
""",
"""
CREATE TABLE IF NOT EXISTS alloc_log (
trade_date TEXT,
ts_code TEXT,
target_weight REAL,
clipped_weight REAL,
reason TEXT,
PRIMARY KEY (trade_date, ts_code)
);
"""
)
@dataclass
class MigrationResult:
executed: int
skipped: bool = False
def _schema_exists() -> bool:
try:
with db_session(read_only=True) as conn:
cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='news'"
)
return cursor.fetchone() is not None
except sqlite3.OperationalError:
return False
def initialize_database() -> MigrationResult:
"""Create tables and indexes required by the application."""
if _schema_exists():
return MigrationResult(executed=0, skipped=True)
executed = 0
with db_session() as conn:
cursor = conn.cursor()
for statement in SCHEMA_STATEMENTS:
cursor.executescript(statement)
executed += 1
return MigrationResult(executed=executed)

0
app/features/__init__.py Normal file
View File

39
app/features/factors.py Normal file
View File

@ -0,0 +1,39 @@
"""Feature engineering for signals and indicator computation."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import date
from typing import Iterable, List
@dataclass
class FactorSpec:
name: str
window: int
@dataclass
class FactorResult:
ts_code: str
trade_date: date
values: dict
DEFAULT_FACTORS: List[FactorSpec] = [
FactorSpec("mom_20", 20),
FactorSpec("mom_60", 60),
FactorSpec("volat_20", 20),
FactorSpec("turn_20", 20),
]
def compute_factors(trade_date: date, factors: Iterable[FactorSpec] = DEFAULT_FACTORS) -> List[FactorResult]:
"""Calculate factor values for the requested date.
This function should join historical price data, apply rolling windows, and
persist results into an factors table. The implementation is left as future
work.
"""
_ = trade_date, factors
raise NotImplementedError

0
app/ingest/__init__.py Normal file
View File

42
app/ingest/rss.py Normal file
View File

@ -0,0 +1,42 @@
"""RSS ingestion for news and heat scores."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
from typing import Iterable, List
@dataclass
class RssItem:
id: str
title: str
link: str
published: datetime
summary: str
source: str
def fetch_rss_feed(url: str) -> List[RssItem]:
"""Download and parse an RSS feed into structured items."""
raise NotImplementedError
def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
"""Drop duplicate stories by link/id fingerprint."""
seen = set()
unique: List[RssItem] = []
for item in items:
key = item.id or item.link
if key in seen:
continue
seen.add(key)
unique.append(item)
return unique
def save_news_items(items: Iterable[RssItem]) -> None:
"""Persist RSS items into the `news` table."""
raise NotImplementedError

240
app/ingest/tushare.py Normal file
View File

@ -0,0 +1,240 @@
"""TuShare 数据拉取管线实现。"""
from __future__ import annotations
import logging
import os
from dataclasses import dataclass
from datetime import date
from typing import Dict, Iterable, List, Optional, Sequence
import pandas as pd
try:
import tushare as ts
except ImportError as exc: # pragma: no cover - dependency error surfaced at runtime
ts = None # type: ignore[assignment]
from app.utils.config import get_config
from app.utils.db import db_session
LOGGER = logging.getLogger(__name__)
@dataclass
class FetchJob:
name: str
start: date
end: date
granularity: str = "daily"
ts_codes: Optional[Sequence[str]] = None
_TABLE_SCHEMAS: Dict[str, str] = {
"daily": """
CREATE TABLE IF NOT EXISTS daily (
ts_code TEXT,
trade_date TEXT,
open REAL,
high REAL,
low REAL,
close REAL,
pre_close REAL,
change REAL,
pct_chg REAL,
vol REAL,
amount REAL,
PRIMARY KEY (ts_code, trade_date)
);
""",
"suspend": """
CREATE TABLE IF NOT EXISTS suspend (
ts_code TEXT,
suspend_date TEXT,
resume_date TEXT,
suspend_type TEXT,
ann_date TEXT,
suspend_timing TEXT,
resume_timing TEXT,
reason TEXT,
PRIMARY KEY (ts_code, suspend_date)
);
""",
"trade_calendar": """
CREATE TABLE IF NOT EXISTS trade_calendar (
exchange TEXT,
cal_date TEXT PRIMARY KEY,
is_open INTEGER,
pretrade_date TEXT
);
""",
"stk_limit": """
CREATE TABLE IF NOT EXISTS stk_limit (
ts_code TEXT,
trade_date TEXT,
up_limit REAL,
down_limit REAL,
PRIMARY KEY (ts_code, trade_date)
);
""",
}
_TABLE_COLUMNS: Dict[str, List[str]] = {
"daily": [
"ts_code",
"trade_date",
"open",
"high",
"low",
"close",
"pre_close",
"change",
"pct_chg",
"vol",
"amount",
],
"suspend": [
"ts_code",
"suspend_date",
"resume_date",
"suspend_type",
"ann_date",
"suspend_timing",
"resume_timing",
"reason",
],
"trade_calendar": [
"exchange",
"cal_date",
"is_open",
"pretrade_date",
],
"stk_limit": [
"ts_code",
"trade_date",
"up_limit",
"down_limit",
],
}
def _ensure_client():
if ts is None:
raise RuntimeError("未安装 tushare请先在环境中安装 tushare 包")
token = get_config().tushare_token or os.getenv("TUSHARE_TOKEN")
if not token:
raise RuntimeError("未配置 TuShare Token请在配置文件或环境变量 TUSHARE_TOKEN 中设置")
if not hasattr(_ensure_client, "_client") or _ensure_client._client is None: # type: ignore[attr-defined]
ts.set_token(token)
_ensure_client._client = ts.pro_api(token) # type: ignore[attr-defined]
LOGGER.info("完成 TuShare 客户端初始化")
return _ensure_client._client # type: ignore[attr-defined]
def _format_date(value: date) -> str:
return value.strftime("%Y%m%d")
def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
if df is None or df.empty:
return []
# 对缺失列进行补全,防止写库时缺少绑定参数
reindexed = df.reindex(columns=allowed_cols)
return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]:
"""拉取日线行情。"""
client = _ensure_client()
start_date = _format_date(job.start)
end_date = _format_date(job.end)
frames: List[pd.DataFrame] = []
if job.granularity != "daily":
raise ValueError(f"暂不支持的粒度:{job.granularity}")
if job.ts_codes:
for code in job.ts_codes:
LOGGER.info("拉取 %s 的日线行情(%s-%s", code, start_date, end_date)
frames.append(client.daily(ts_code=code, start_date=start_date, end_date=end_date))
else:
LOGGER.info("按全市场拉取日线行情(%s-%s", start_date, end_date)
frames.append(client.daily(start_date=start_date, end_date=end_date))
if not frames:
return []
df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
return _df_to_records(df, _TABLE_COLUMNS["daily"])
def fetch_suspensions(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
client = _ensure_client()
start_date = _format_date(start)
end_date = _format_date(end)
LOGGER.info("拉取停复牌信息(%s-%s", start_date, end_date)
df = client.suspend_d(ts_code=ts_code, start_date=start_date, end_date=end_date)
return _df_to_records(df, _TABLE_COLUMNS["suspend"])
def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]:
client = _ensure_client()
start_date = _format_date(start)
end_date = _format_date(end)
LOGGER.info("拉取交易日历(交易所:%s,区间:%s-%s", exchange, start_date, end_date)
df = client.trade_cal(exchange=exchange, start_date=start_date, end_date=end_date)
return _df_to_records(df, _TABLE_COLUMNS["trade_calendar"])
def fetch_stk_limit(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
client = _ensure_client()
start_date = _format_date(start)
end_date = _format_date(end)
LOGGER.info("拉取涨跌停价格(%s-%s", start_date, end_date)
df = client.stk_limit(ts_code=ts_code, start_date=start_date, end_date=end_date)
return _df_to_records(df, _TABLE_COLUMNS["stk_limit"])
def save_records(table: str, rows: Iterable[Dict]) -> None:
"""将拉取的数据写入 SQLite。"""
items = list(rows)
if not items:
LOGGER.info("%s 没有新增记录,跳过写入", table)
return
schema = _TABLE_SCHEMAS.get(table)
columns = _TABLE_COLUMNS.get(table)
if not schema or not columns:
raise ValueError(f"不支持写入的表:{table}")
placeholders = ",".join([f":{col}" for col in columns])
col_clause = ",".join(columns)
LOGGER.info("%s 写入 %d 条记录", table, len(items))
with db_session() as conn:
conn.executescript(schema)
conn.executemany(
f"INSERT OR REPLACE INTO {table} ({col_clause}) VALUES ({placeholders})",
items,
)
def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
"""按任务配置拉取 TuShare 数据。"""
LOGGER.info("启动 TuShare 拉取任务:%s", job.name)
daily_rows = fetch_daily_bars(job)
save_records("daily", daily_rows)
suspend_rows = fetch_suspensions(job.start, job.end)
save_records("suspend", suspend_rows)
calendar_rows = fetch_trade_calendar(job.start, job.end)
save_records("trade_calendar", calendar_rows)
if include_limits:
limit_rows = fetch_stk_limit(job.start, job.end)
save_records("stk_limit", limit_rows)
LOGGER.info("任务 %s 完成", job.name)

0
app/llm/__init__.py Normal file
View File

18
app/llm/explain.py Normal file
View File

@ -0,0 +1,18 @@
"""LLM assisted explanations and summaries."""
from __future__ import annotations
from typing import Dict
from .prompts import plan_prompt
def make_human_card(ts_code: str, trade_date: str, context: Dict) -> Dict:
"""Compose payload for UI cards and LLM requests."""
prompt = plan_prompt(context)
return {
"ts_code": ts_code,
"trade_date": trade_date,
"prompt": prompt,
"context": context,
}

11
app/llm/prompts.py Normal file
View File

@ -0,0 +1,11 @@
"""Prompt templates for natural language outputs."""
from __future__ import annotations
from typing import Dict
def plan_prompt(data: Dict) -> str:
"""Build a concise instruction prompt for the LLM."""
_ = data
return "你是一个投资助理,请根据提供的数据给出三条要点和两条风险提示。"

35
app/main.py Normal file
View File

@ -0,0 +1,35 @@
"""Command line entry points for routine tasks."""
from __future__ import annotations
from datetime import date
from app.backtest.engine import BtConfig, run_backtest
from app.data.schema import initialize_database
def init_db() -> None:
result = initialize_database()
if result.skipped:
print("Database already initialized; skipping schema creation")
else:
print(f"Initialized database with {result.executed} statements")
def run_sample_backtest() -> None:
cfg = BtConfig(
id="demo",
name="Demo Strategy",
start_date=date(2020, 1, 1),
end_date=date(2020, 3, 31),
universe=["000001.SZ"],
params={
"target": 0.035,
"stop": -0.015,
"hold_days": 10,
},
)
run_backtest(cfg)
if __name__ == "__main__":
init_db()

0
app/ui/__init__.py Normal file
View File

85
app/ui/streamlit_app.py Normal file
View File

@ -0,0 +1,85 @@
"""Streamlit UI scaffold for the investment assistant."""
from __future__ import annotations
import sys
from datetime import date
from pathlib import Path
ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
import streamlit as st
from app.data.schema import initialize_database
from app.ingest.tushare import FetchJob, run_ingestion
from app.llm.explain import make_human_card
def render_today_plan() -> None:
st.header("今日计划")
st.write("待接入候选池筛选与多智能体决策结果。")
sample = make_human_card("000001.SZ", "2025-01-01", {"decisions": []})
st.json(sample)
def render_backtest() -> None:
st.header("回测与复盘")
st.write("在此运行回测、展示净值曲线与代理贡献。")
st.button("开始回测")
def render_settings() -> None:
st.header("数据与设置")
st.text_input("TuShare Token")
st.write("新闻源开关与数据库备份将在此配置。")
def render_tests() -> None:
st.header("自检测试")
st.write("用于快速检查数据库与数据拉取是否正常工作。")
if st.button("测试数据库初始化"):
with st.spinner("正在检查数据库..."):
result = initialize_database()
if result.skipped:
st.success("数据库已存在,检查通过。")
else:
st.success(f"数据库初始化完成,共执行 {result.executed} 条语句。")
st.divider()
if st.button("测试 TuShare 拉取(示例 2024-01-01 至 2024-01-03"):
with st.spinner("正在调用 TuShare 接口..."):
try:
run_ingestion(
FetchJob(
name="streamlit_self_test",
start=date(2024, 1, 1),
end=date(2024, 1, 3),
ts_codes=("000001.SZ",),
),
include_limits=False,
)
st.success("TuShare 示例拉取完成,数据已写入数据库。")
except Exception as exc: # noqa: BLE001
st.error(f"拉取失败:{exc}")
st.info("注意TuShare 拉取依赖网络与 Token若环境未配置将出现错误提示。")
def main() -> None:
st.set_page_config(page_title="多智能体投资助理", layout="wide")
tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"])
with tabs[0]:
render_today_plan()
with tabs[1]:
render_backtest()
with tabs[2]:
render_settings()
with tabs[3]:
render_tests()
if __name__ == "__main__":
main()

0
app/utils/__init__.py Normal file
View File

34
app/utils/calendar.py Normal file
View File

@ -0,0 +1,34 @@
"""Trading calendar utilities.
These helpers abstract exchange calendars and trading day lookups. Real
implementations should integrate with TuShare or cached calendars.
"""
from __future__ import annotations
from datetime import date, timedelta
from typing import Iterable, List
def is_trading_day(day: date, holidays: Iterable[date] | None = None) -> bool:
if day.weekday() >= 5:
return False
if holidays and day in set(holidays):
return False
return True
def previous_trading_day(day: date, holidays: Iterable[date] | None = None) -> date:
current = day - timedelta(days=1)
while not is_trading_day(current, holidays):
current -= timedelta(days=1)
return current
def trading_days_between(start: date, end: date, holidays: Iterable[date] | None = None) -> List[date]:
current = start
days: List[date] = []
while current <= end:
if is_trading_day(current, holidays):
days.append(current)
current += timedelta(days=1)
return days

65
app/utils/config.py Normal file
View File

@ -0,0 +1,65 @@
"""Application configuration models and helpers."""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Optional
def _default_root() -> Path:
return Path(__file__).resolve().parents[2] / "app" / "data"
@dataclass
class DataPaths:
"""Holds filesystem locations for persistent artifacts."""
root: Path = field(default_factory=_default_root)
database: Path = field(init=False)
backups: Path = field(init=False)
def __post_init__(self) -> None:
self.root.mkdir(parents=True, exist_ok=True)
self.database = self.root / "llm_quant.db"
self.backups = self.root / "backups"
self.backups.mkdir(parents=True, exist_ok=True)
@dataclass
class AgentWeights:
"""Default weighting for decision agents."""
momentum: float = 0.30
value: float = 0.20
news: float = 0.20
liquidity: float = 0.15
macro: float = 0.15
def as_dict(self) -> Dict[str, float]:
return {
"A_mom": self.momentum,
"A_val": self.value,
"A_news": self.news,
"A_liq": self.liquidity,
"A_macro": self.macro,
}
@dataclass
class AppConfig:
"""User configurable settings persisted in a simple structure."""
tushare_token: Optional[str] = None
rss_sources: Dict[str, bool] = field(default_factory=dict)
decision_method: str = "nash"
data_paths: DataPaths = field(default_factory=DataPaths)
agent_weights: AgentWeights = field(default_factory=AgentWeights)
CONFIG = AppConfig()
def get_config() -> AppConfig:
"""Return a mutable global configuration instance."""
return CONFIG

47
app/utils/db.py Normal file
View File

@ -0,0 +1,47 @@
"""SQLite helpers for application data access."""
from __future__ import annotations
import sqlite3
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Iterable, Tuple
from .config import get_config
def get_connection(read_only: bool = False) -> sqlite3.Connection:
"""Create a SQLite connection against the configured database file."""
db_path: Path = get_config().data_paths.database
uri = f"file:{db_path}?mode={'ro' if read_only else 'rw'}"
if not db_path.exists() and not read_only:
# Ensure directory exists before first write.
db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(db_path)
else:
conn = sqlite3.connect(uri, uri=True)
conn.row_factory = sqlite3.Row
return conn
@contextmanager
def db_session(read_only: bool = False) -> Generator[sqlite3.Connection, None, None]:
"""Context manager yielding a connection with automatic commit/rollback."""
conn = get_connection(read_only=read_only)
try:
yield conn
if not read_only:
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
def execute_many(sql: str, params: Iterable[Tuple]) -> None:
"""Bulk execute a parameterized statement inside a write session."""
with db_session() as conn:
conn.executemany(sql, params)

28
app/utils/logging.py Normal file
View File

@ -0,0 +1,28 @@
"""Centralized logging configuration."""
from __future__ import annotations
import logging
from pathlib import Path
from .config import get_config
def configure_logging(level: int = logging.INFO) -> None:
"""Setup root logger with file and console handlers."""
cfg = get_config()
log_dir = cfg.data_paths.root / "logs"
log_dir.mkdir(parents=True, exist_ok=True)
logfile = log_dir / "app.log"
logging.basicConfig(
level=level,
format="%(asctime)s %(levelname)s %(name)s - %(message)s",
handlers=[
logging.FileHandler(logfile, encoding="utf-8"),
logging.StreamHandler(),
],
)
configure_logging()