init
commit 6eac6c5f69
29 .gitignore vendored Normal file
@@ -0,0 +1,29 @@
# Python cache and build artifacts
__pycache__/
*.py[cod]
*.so
.build/
.eggs/
*.egg-info/

# Virtual environments
.venv/
venv/
ENV/
env/

# Jupyter checkpoints
.ipynb_checkpoints/

# Logs and data
app/data/*.db*
app/data/backups/
app/data/logs/
*.log

# Streamlit temporary files
.streamlit/

# System files
.DS_Store
Thumbs.db
27 README.md Normal file
@@ -0,0 +1,27 @@
# Multi-Agent Investment Assistant Skeleton

This repository provides the code skeleton for an A-share daily-bar investment assistant driven by a multi-agent game. It is designed to run on a single machine, persist everything in SQLite, and expose a Streamlit UI. The core modules are laid out as follows:

- `app/data`: database initialization and schema definitions.
- `app/utils`: configuration, database connections, logging, and trading-calendar utilities.
- `app/ingest`: skeletons for TuShare and RSS data ingestion.
- `app/features`: interfaces for indicator and signal computation.
- `app/agents`: the multi-agent game implementation, including momentum, value, news, liquidity, macro, and risk agents.
- `app/backtest`: placeholder implementation of the daily-bar backtest engine and its metrics.
- `app/llm`: entry points for human-readable cards and summaries (prompt building only, no direct trading).
- `app/ui`: the three-page Streamlit UI skeleton.

## Quick start

```bash
python -m app.main  # initialize the database
streamlit run app/ui/streamlit_app.py
```

## Next steps

1. Fill in the TuShare and RSS fetching logic in `app/ingest`.
2. Flesh out `app/features` and `app/backtest` with real signal computation and an event-driven backtest.
3. Write agent utilities into the SQLite `agent_utils` and `alloc_log` tables to drive the UI.
4. Populate `news` and `heat_daily` with lightweight sentiment analysis and heat scoring.
5. Hook up a local small model or an API for LLM explanations and surface them in the UI.
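To see how the pieces above are meant to fit together, here is a minimal usage sketch that runs the agent game directly; the feature values are invented for illustration and the weights come from the defaults in `app/utils/config.py`:

```python
from app.agents.base import AgentContext
from app.agents.game import decide
from app.agents.registry import default_agents
from app.utils.config import AgentWeights

# Illustrative feature vector; real values come from app/features and app/ingest.
features = {
    "mom_20": 0.08, "mom_60": 0.15,                                      # momentum agent inputs
    "pe_percentile": 0.3, "pb_percentile": 0.4, "roe_percentile": 0.7,   # value agent inputs
    "news_heat": 0.6, "news_sentiment": 0.5,                             # news agent inputs
    "liquidity_score": 0.8, "industry_heat": 0.7,                        # liquidity / macro agent inputs
}
context = AgentContext(ts_code="000001.SZ", trade_date="2024-01-05", features=features)
decision = decide(context, default_agents(), AgentWeights().as_dict(), method="nash")
print(decision.action, round(decision.confidence, 3), decision.target_weight)
```

With the default feasibility rules, suspended or limit-up names are vetoed by the risk agent before the bargaining step ever runs.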
0 app/__init__.py Normal file
0 app/agents/__init__.py Normal file
44 app/agents/base.py Normal file
@@ -0,0 +1,44 @@
"""Agent abstractions for the multi-agent decision engine."""
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Dict, Mapping


class AgentAction(str, Enum):
    SELL = "SELL"
    HOLD = "HOLD"
    BUY_S = "BUY_S"
    BUY_M = "BUY_M"
    BUY_L = "BUY_L"


@dataclass
class AgentContext:
    ts_code: str
    trade_date: str
    features: Mapping[str, float]


class Agent:
    """Base class for all decision agents."""

    name: str

    def __init__(self, name: str) -> None:
        self.name = name

    def score(self, context: AgentContext, action: AgentAction) -> float:
        """Return a normalized utility value in [0,1] for the proposed action."""

        raise NotImplementedError

    def feasible(self, context: AgentContext, action: AgentAction) -> bool:
        """Optional hook for agents with veto power (defaults to True)."""

        _ = context, action
        return True


UtilityMatrix = Dict[AgentAction, Dict[str, float]]
127 app/agents/game.py Normal file
@@ -0,0 +1,127 @@
"""Multi-agent decision game implementation."""
from __future__ import annotations

from dataclasses import dataclass
from math import log
from typing import Dict, Iterable, List, Mapping, Tuple

from .base import Agent, AgentAction, AgentContext, UtilityMatrix
from .registry import weight_map


ACTIONS: Tuple[AgentAction, ...] = (
    AgentAction.SELL,
    AgentAction.HOLD,
    AgentAction.BUY_S,
    AgentAction.BUY_M,
    AgentAction.BUY_L,
)


def _clamp(value: float) -> float:
    return max(0.0, min(1.0, value))


@dataclass
class Decision:
    action: AgentAction
    confidence: float
    target_weight: float
    feasible_actions: List[AgentAction]
    utilities: UtilityMatrix


def compute_utilities(agents: Iterable[Agent], context: AgentContext) -> UtilityMatrix:
    utilities: UtilityMatrix = {}
    for action in ACTIONS:
        utilities[action] = {}
        for agent in agents:
            score = _clamp(agent.score(context, action))
            utilities[action][agent.name] = score
    return utilities


def feasible_actions(agents: Iterable[Agent], context: AgentContext) -> List[AgentAction]:
    feas: List[AgentAction] = []
    for action in ACTIONS:
        if all(agent.feasible(context, action) for agent in agents):
            feas.append(action)
    return feas


def nash_bargain(utilities: UtilityMatrix, weights: Mapping[str, float], disagreement: Mapping[str, float]) -> Tuple[AgentAction, float]:
    best_action = AgentAction.HOLD
    best_score = float("-inf")
    for action, agent_scores in utilities.items():
        log_product = 0.0
        valid = True
        for agent_name, score in agent_scores.items():
            w = weights.get(agent_name, 0.0)
            if w == 0:
                continue
            gap = score - disagreement.get(agent_name, 0.0)
            if gap <= 0:
                valid = False
                break
            log_product += w * log(gap)
        if not valid:
            continue
        if log_product > best_score:
            best_score = log_product
            best_action = action
    if best_score == float("-inf"):
        return AgentAction.HOLD, 0.0
    confidence = _aggregate_confidence(utilities[best_action], weights)
    return best_action, confidence


def vote(utilities: UtilityMatrix, weights: Mapping[str, float]) -> Tuple[AgentAction, float]:
    scores: Dict[AgentAction, float] = {}
    for action, agent_scores in utilities.items():
        scores[action] = sum(weights.get(agent, 0.0) * score for agent, score in agent_scores.items())
    best_action = max(scores, key=scores.get)
    confidence = _aggregate_confidence(utilities[best_action], weights)
    return best_action, confidence


def _aggregate_confidence(agent_scores: Mapping[str, float], weights: Mapping[str, float]) -> float:
    total = sum(weights.values())
    if total <= 0:
        return 0.0
    weighted = sum(weights.get(agent, 0.0) * score for agent, score in agent_scores.items())
    return weighted / total


def target_weight_for_action(action: AgentAction) -> float:
    mapping = {
        AgentAction.SELL: -1.0,
        AgentAction.HOLD: 0.0,
        AgentAction.BUY_S: 0.01,
        AgentAction.BUY_M: 0.02,
        AgentAction.BUY_L: 0.03,
    }
    return mapping[action]


def decide(context: AgentContext, agents: Iterable[Agent], weights: Mapping[str, float], method: str = "nash") -> Decision:
    agent_list = list(agents)
    norm_weights = weight_map(dict(weights))
    utilities = compute_utilities(agent_list, context)
    feas_actions = feasible_actions(agent_list, context)
    if not feas_actions:
        return Decision(AgentAction.HOLD, 0.0, 0.0, [], utilities)

    filtered_utilities = {action: utilities[action] for action in feas_actions}
    hold_scores = utilities.get(AgentAction.HOLD, {})

    if method == "vote":
        action, confidence = vote(filtered_utilities, norm_weights)
    else:
        action, confidence = nash_bargain(filtered_utilities, norm_weights, hold_scores)
        if action not in feas_actions:
            action, confidence = vote(filtered_utilities, norm_weights)

    weight = target_weight_for_action(action)
    return Decision(action, confidence, weight, feas_actions, utilities)
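In formula terms, `nash_bargain` picks, among feasible actions, the maximizer of the weighted Nash product, with the HOLD utilities acting as the disagreement point:

$$a^{*} \;=\; \arg\max_{a}\ \prod_{i} \bigl(u_i(a) - d_i\bigr)^{w_i} \;=\; \arg\max_{a}\ \sum_{i} w_i \log\bigl(u_i(a) - d_i\bigr),$$

where any action for which a weighted agent has $u_i(a) \le d_i$ is discarded, and `decide` falls back to weighted voting if nothing survives.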
20 app/agents/liquidity.py Normal file
@@ -0,0 +1,20 @@
"""Liquidity and transaction cost agent."""
from __future__ import annotations

from .base import Agent, AgentAction, AgentContext


class LiquidityAgent(Agent):
    def __init__(self) -> None:
        super().__init__(name="A_liq")

    def score(self, context: AgentContext, action: AgentAction) -> float:
        liq = context.features.get("liquidity_score", 0.5)
        cost = context.features.get("cost_penalty", 0.0)
        penalty = cost
        if action is AgentAction.SELL:
            return min(1.0, liq + penalty)
        if action is AgentAction.HOLD:
            return 0.4 + 0.2 * liq
        scale = {AgentAction.BUY_S: 0.5, AgentAction.BUY_M: 0.75, AgentAction.BUY_L: 1.0}.get(action, 0.0)
        return max(0.0, liq * scale - penalty)
23 app/agents/macro.py Normal file
@@ -0,0 +1,23 @@
"""Macro and industry regime agent."""
from __future__ import annotations

from .base import Agent, AgentAction, AgentContext


class MacroAgent(Agent):
    def __init__(self) -> None:
        super().__init__(name="A_macro")

    def score(self, context: AgentContext, action: AgentAction) -> float:
        industry_heat = context.features.get("industry_heat", 0.5)
        relative_strength = context.features.get("industry_relative_mom", 0.0)
        raw = min(1.0, max(0.0, industry_heat * 0.6 + relative_strength * 0.4))
        if action is AgentAction.SELL:
            return 1 - raw
        if action is AgentAction.HOLD:
            return 0.5
        if action is AgentAction.BUY_S:
            return raw * 0.6
        if action is AgentAction.BUY_M:
            return raw * 0.8
        return raw
29 app/agents/momentum.py Normal file
@@ -0,0 +1,29 @@
"""Momentum oriented agent."""
from __future__ import annotations

from math import tanh

from .base import Agent, AgentAction, AgentContext


def _sigmoid(x: float) -> float:
    return 0.5 * (tanh(x) + 1)


class MomentumAgent(Agent):
    def __init__(self) -> None:
        super().__init__(name="A_mom")

    def score(self, context: AgentContext, action: AgentAction) -> float:
        mom20 = context.features.get("mom_20", 0.0)
        mom60 = context.features.get("mom_60", 0.0)
        strength = _sigmoid(0.5 * mom20 + 0.5 * mom60)
        if action is AgentAction.SELL:
            return 1 - strength
        if action is AgentAction.HOLD:
            return 0.5
        if action is AgentAction.BUY_S:
            return strength * 0.6
        if action is AgentAction.BUY_M:
            return strength * 0.8
        return strength
26 app/agents/news.py Normal file
@@ -0,0 +1,26 @@
"""News and sentiment aware agent."""
from __future__ import annotations

from .base import Agent, AgentAction, AgentContext


class NewsAgent(Agent):
    def __init__(self) -> None:
        super().__init__(name="A_news")

    def score(self, context: AgentContext, action: AgentAction) -> float:
        heat = context.features.get("news_heat", 0.0)
        sentiment = context.features.get("news_sentiment", 0.0)
        positive = max(0.0, sentiment)
        negative = max(0.0, -sentiment)
        buy_score = min(1.0, heat * positive)
        sell_score = min(1.0, heat * negative)
        if action is AgentAction.SELL:
            return sell_score
        if action is AgentAction.HOLD:
            return 0.3 + 0.4 * (1 - heat)
        if action is AgentAction.BUY_S:
            return 0.5 * buy_score
        if action is AgentAction.BUY_M:
            return 0.75 * buy_score
        return buy_score
30 app/agents/registry.py Normal file
@@ -0,0 +1,30 @@
"""Factory helpers to construct the agent ensemble."""
from __future__ import annotations

from typing import Dict, List

from .base import Agent
from .liquidity import LiquidityAgent
from .macro import MacroAgent
from .momentum import MomentumAgent
from .news import NewsAgent
from .risk import RiskAgent
from .value import ValueAgent


def default_agents() -> List[Agent]:
    return [
        MomentumAgent(),
        ValueAgent(),
        NewsAgent(),
        LiquidityAgent(),
        MacroAgent(),
        RiskAgent(),
    ]


def weight_map(raw: Dict[str, float]) -> Dict[str, float]:
    total = sum(raw.values())
    if total == 0:
        return raw
    return {name: weight / total for name, weight in raw.items()}
29 app/agents/risk.py Normal file
@@ -0,0 +1,29 @@
"""Risk agent acts as leader with veto rights."""
from __future__ import annotations

from .base import Agent, AgentAction, AgentContext


class RiskAgent(Agent):
    def __init__(self) -> None:
        super().__init__(name="A_risk")

    def score(self, context: AgentContext, action: AgentAction) -> float:
        # Base risk agent is neutral unless penalties are triggered.
        penalty = context.features.get("risk_penalty", 0.0)
        if action is AgentAction.SELL:
            return min(1.0, 0.6 + penalty)
        if action is AgentAction.HOLD:
            return 0.5
        return max(0.0, 1.0 - penalty)

    def feasible(self, context: AgentContext, action: AgentAction) -> bool:
        if action is AgentAction.SELL:
            return True
        if context.features.get("is_suspended", False):
            return False
        if context.features.get("limit_up", False) and action not in (AgentAction.SELL, AgentAction.HOLD):
            return False
        if context.features.get("position_limit", False) and action in (AgentAction.BUY_M, AgentAction.BUY_L):
            return False
        return True
26 app/agents/value.py Normal file
@@ -0,0 +1,26 @@
"""Value and quality filtering agent."""
from __future__ import annotations

from .base import Agent, AgentAction, AgentContext


class ValueAgent(Agent):
    def __init__(self) -> None:
        super().__init__(name="A_val")

    def score(self, context: AgentContext, action: AgentAction) -> float:
        pe = context.features.get("pe_percentile", 0.5)
        pb = context.features.get("pb_percentile", 0.5)
        roe = context.features.get("roe_percentile", 0.5)
        # Lower valuation percentiles and higher quality percentiles add value.
        raw = max(0.0, (1 - pe) * 0.4 + (1 - pb) * 0.3 + roe * 0.3)
        raw = min(raw, 1.0)
        if action is AgentAction.SELL:
            return 1 - raw
        if action is AgentAction.HOLD:
            return 0.5
        if action is AgentAction.BUY_S:
            return raw * 0.7
        if action is AgentAction.BUY_M:
            return raw * 0.85
        return raw
0 app/backtest/__init__.py Normal file
90 app/backtest/engine.py Normal file
@@ -0,0 +1,90 @@
"""Backtest engine skeleton for daily bar simulation."""
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import date
from typing import Dict, Iterable, List, Mapping

from app.agents.base import AgentContext
from app.agents.game import Decision, decide
from app.agents.registry import default_agents
from app.utils.db import db_session


@dataclass
class BtConfig:
    id: str
    name: str
    start_date: date
    end_date: date
    universe: List[str]
    params: Dict[str, float]
    method: str = "nash"


@dataclass
class PortfolioState:
    cash: float = 1_000_000.0
    holdings: Dict[str, float] = field(default_factory=dict)


@dataclass
class BacktestResult:
    nav_series: List[Dict[str, float]] = field(default_factory=list)
    trades: List[Dict[str, str]] = field(default_factory=list)


class BacktestEngine:
    """Runs the multi-agent game inside a daily event-driven loop."""

    def __init__(self, cfg: BtConfig) -> None:
        self.cfg = cfg
        self.agents = default_agents()
        self.weights = {agent.name: 1.0 for agent in self.agents}

    def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, float]]:
        """Load per-stock feature vectors. Replace with real data access."""

        _ = trade_date
        return {}

    def simulate_day(self, trade_date: date, state: PortfolioState) -> List[Decision]:
        feature_map = self.load_market_data(trade_date)
        decisions: List[Decision] = []
        for ts_code, features in feature_map.items():
            context = AgentContext(ts_code=ts_code, trade_date=trade_date.isoformat(), features=features)
            decision = decide(context, self.agents, self.weights, method=self.cfg.method)
            decisions.append(decision)
            self.record_agent_state(context, decision)
        # TODO: translate decisions into fills, holdings, and NAV updates.
        _ = state
        return decisions

    def record_agent_state(self, context: AgentContext, decision: Decision) -> None:
        payload = {
            "trade_date": context.trade_date,
            "ts_code": context.ts_code,
            "action": decision.action.value,
            "confidence": decision.confidence,
        }
        _ = payload
        # Implementation should persist into agent_utils and bt_trades.

    def run(self) -> BacktestResult:
        state = PortfolioState()
        result = BacktestResult()
        current = self.cfg.start_date
        while current <= self.cfg.end_date:
            decisions = self.simulate_day(current, state)
            _ = decisions
            current = date.fromordinal(current.toordinal() + 1)
        return result


def run_backtest(cfg: BtConfig) -> BacktestResult:
    engine = BacktestEngine(cfg)
    result = engine.run()
    with db_session() as conn:
        _ = conn
        # Implementation should persist bt_nav, bt_trades, and bt_report rows.
    return result
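One way the TODO in `simulate_day` could eventually be filled in — a minimal sketch only, which treats `target_weight` as a fraction of current NAV and fills at a caller-supplied price map (both assumptions, not part of the engine yet):

```python
from typing import Mapping

from app.agents.base import AgentAction
from app.agents.game import Decision
from app.backtest.engine import PortfolioState


def apply_decision(state: PortfolioState, ts_code: str, decision: Decision, prices: Mapping[str, float]) -> None:
    """Hypothetical fill logic: rebalance one name toward target_weight of NAV."""
    nav = state.cash + sum(qty * prices.get(code, 0.0) for code, qty in state.holdings.items())
    price = prices[ts_code]
    held_qty = state.holdings.get(ts_code, 0.0)
    if decision.action is AgentAction.HOLD:
        return  # keep the current position unchanged
    if decision.action is AgentAction.SELL:
        state.cash += held_qty * price
        state.holdings.pop(ts_code, None)
        return
    # BUY_S / BUY_M / BUY_L: move the position to target_weight of NAV.
    target_qty = decision.target_weight * nav / price
    state.cash -= (target_qty - held_qty) * price
    state.holdings[ts_code] = target_qty
```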
19 app/backtest/metrics.py Normal file
@@ -0,0 +1,19 @@
"""Performance metric utilities."""
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class Metric:
    name: str
    value: float
    description: str


def compute_nav_metrics(nav_series: Iterable[float]) -> List[Metric]:
    """Compute core statistics such as CAGR, Sharpe, and max drawdown."""

    _ = nav_series
    return []
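Since `compute_nav_metrics` is still a stub, here is a minimal sketch of the statistics its docstring names, assuming a daily NAV series and 252 trading days per year (both assumptions of this sketch):

```python
from math import sqrt
from typing import List

from app.backtest.metrics import Metric

TRADING_DAYS = 252  # assumed annualization factor


def nav_metrics_sketch(nav: List[float]) -> List[Metric]:
    if len(nav) < 2:
        return []
    rets = [nav[i] / nav[i - 1] - 1 for i in range(1, len(nav))]
    years = len(rets) / TRADING_DAYS
    cagr = (nav[-1] / nav[0]) ** (1 / years) - 1
    mean = sum(rets) / len(rets)
    var = sum((r - mean) ** 2 for r in rets) / len(rets)
    sharpe = (mean / sqrt(var) * sqrt(TRADING_DAYS)) if var > 0 else 0.0
    peak, max_dd = nav[0], 0.0
    for value in nav:
        peak = max(peak, value)
        max_dd = min(max_dd, value / peak - 1)
    return [
        Metric("cagr", cagr, "Compound annual growth rate"),
        Metric("sharpe", sharpe, "Annualized Sharpe ratio (zero risk-free rate)"),
        Metric("max_drawdown", max_dd, "Maximum peak-to-trough drawdown"),
    ]
```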
0 app/data/__init__.py Normal file
146 app/data/schema.py Normal file
@@ -0,0 +1,146 @@
"""Database schema management for the investment assistant."""
from __future__ import annotations

import sqlite3
from dataclasses import dataclass
from typing import Iterable

from app.utils.config import get_config
from app.utils.db import db_session


SCHEMA_STATEMENTS: Iterable[str] = (
    """
    CREATE TABLE IF NOT EXISTS news (
        id TEXT PRIMARY KEY,
        ts_code TEXT,
        pub_time TEXT,
        source TEXT,
        title TEXT,
        summary TEXT,
        url TEXT,
        entities TEXT,
        sentiment REAL,
        heat REAL
    );
    """,
    """
    CREATE INDEX IF NOT EXISTS idx_news_time ON news(pub_time DESC);
    """,
    """
    CREATE INDEX IF NOT EXISTS idx_news_code ON news(ts_code, pub_time DESC);
    """,
    """
    CREATE TABLE IF NOT EXISTS heat_daily (
        scope TEXT,
        key TEXT,
        trade_date TEXT,
        heat REAL,
        top_topics TEXT,
        PRIMARY KEY (scope, key, trade_date)
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS bt_config (
        id TEXT PRIMARY KEY,
        name TEXT,
        start_date TEXT,
        end_date TEXT,
        universe TEXT,
        params TEXT
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS bt_trades (
        cfg_id TEXT,
        ts_code TEXT,
        trade_date TEXT,
        side TEXT,
        price REAL,
        qty REAL,
        reason TEXT,
        PRIMARY KEY (cfg_id, ts_code, trade_date, side)
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS bt_nav (
        cfg_id TEXT,
        trade_date TEXT,
        nav REAL,
        ret REAL,
        pos_count INTEGER,
        turnover REAL,
        dd REAL,
        info TEXT,
        PRIMARY KEY (cfg_id, trade_date)
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS bt_report (
        cfg_id TEXT PRIMARY KEY,
        summary TEXT
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS run_log (
        ts TEXT PRIMARY KEY,
        stage TEXT,
        level TEXT,
        msg TEXT
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS agent_utils (
        trade_date TEXT,
        ts_code TEXT,
        agent TEXT,
        action TEXT,
        utils TEXT,
        feasible TEXT,
        weight REAL,
        PRIMARY KEY (trade_date, ts_code, agent)
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS alloc_log (
        trade_date TEXT,
        ts_code TEXT,
        target_weight REAL,
        clipped_weight REAL,
        reason TEXT,
        PRIMARY KEY (trade_date, ts_code)
    );
    """
)


@dataclass
class MigrationResult:
    executed: int
    skipped: bool = False


def _schema_exists() -> bool:
    try:
        with db_session(read_only=True) as conn:
            cursor = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='news'"
            )
            return cursor.fetchone() is not None
    except sqlite3.OperationalError:
        return False


def initialize_database() -> MigrationResult:
    """Create tables and indexes required by the application."""

    if _schema_exists():
        return MigrationResult(executed=0, skipped=True)

    executed = 0
    with db_session() as conn:
        cursor = conn.cursor()
        for statement in SCHEMA_STATEMENTS:
            cursor.executescript(statement)
            executed += 1
    return MigrationResult(executed=executed)
0 app/features/__init__.py Normal file
39 app/features/factors.py Normal file
@@ -0,0 +1,39 @@
"""Feature engineering for signals and indicator computation."""
from __future__ import annotations

from dataclasses import dataclass
from datetime import date
from typing import Iterable, List


@dataclass
class FactorSpec:
    name: str
    window: int


@dataclass
class FactorResult:
    ts_code: str
    trade_date: date
    values: dict


DEFAULT_FACTORS: List[FactorSpec] = [
    FactorSpec("mom_20", 20),
    FactorSpec("mom_60", 60),
    FactorSpec("volat_20", 20),
    FactorSpec("turn_20", 20),
]


def compute_factors(trade_date: date, factors: Iterable[FactorSpec] = DEFAULT_FACTORS) -> List[FactorResult]:
    """Calculate factor values for the requested date.

    This function should join historical price data, apply rolling windows, and
    persist results into a factors table. The implementation is left as future
    work.
    """

    _ = trade_date, factors
    raise NotImplementedError
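As a hint of what the rolling-window computation could look like, here is a minimal sketch for a single stock, assuming a pandas Series of daily closes indexed by trade date (pandas is already a dependency of the ingest layer):

```python
import pandas as pd


def momentum_and_vol_sketch(close: pd.Series) -> pd.DataFrame:
    """Rolling-window factors for one ts_code, given a close series indexed by trade_date."""
    out = pd.DataFrame(index=close.index)
    out["mom_20"] = close.pct_change(20)                      # 20-day price momentum
    out["mom_60"] = close.pct_change(60)                      # 60-day price momentum
    out["volat_20"] = close.pct_change().rolling(20).std()    # 20-day volatility of daily returns
    return out
```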
0 app/ingest/__init__.py Normal file
42 app/ingest/rss.py Normal file
@@ -0,0 +1,42 @@
"""RSS ingestion for news and heat scores."""
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime
from typing import Iterable, List


@dataclass
class RssItem:
    id: str
    title: str
    link: str
    published: datetime
    summary: str
    source: str


def fetch_rss_feed(url: str) -> List[RssItem]:
    """Download and parse an RSS feed into structured items."""

    raise NotImplementedError


def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
    """Drop duplicate stories by link/id fingerprint."""

    seen = set()
    unique: List[RssItem] = []
    for item in items:
        key = item.id or item.link
        if key in seen:
            continue
        seen.add(key)
        unique.append(item)
    return unique


def save_news_items(items: Iterable[RssItem]) -> None:
    """Persist RSS items into the `news` table."""

    raise NotImplementedError
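One possible shape for `fetch_rss_feed`, sketched with the third-party `feedparser` package (an assumption here, not a declared dependency of this repository):

```python
from datetime import datetime
from typing import List

import feedparser  # assumed third-party dependency

from app.ingest.rss import RssItem


def fetch_rss_feed_sketch(url: str, source: str = "rss") -> List[RssItem]:
    parsed = feedparser.parse(url)
    items: List[RssItem] = []
    for entry in parsed.entries:
        # feedparser exposes entries as dict-like objects; fall back gracefully on missing fields.
        published = datetime(*entry.published_parsed[:6]) if entry.get("published_parsed") else datetime.now()
        items.append(
            RssItem(
                id=entry.get("id", entry.get("link", "")),
                title=entry.get("title", ""),
                link=entry.get("link", ""),
                published=published,
                summary=entry.get("summary", ""),
                source=source,
            )
        )
    return items
```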
240 app/ingest/tushare.py Normal file
@@ -0,0 +1,240 @@
"""TuShare data ingestion pipeline implementation."""
from __future__ import annotations

import logging
import os
from dataclasses import dataclass
from datetime import date
from typing import Dict, Iterable, List, Optional, Sequence

import pandas as pd

try:
    import tushare as ts
except ImportError:  # pragma: no cover - dependency error surfaced at runtime
    ts = None  # type: ignore[assignment]

from app.utils.config import get_config
from app.utils.db import db_session

LOGGER = logging.getLogger(__name__)


@dataclass
class FetchJob:
    name: str
    start: date
    end: date
    granularity: str = "daily"
    ts_codes: Optional[Sequence[str]] = None


_TABLE_SCHEMAS: Dict[str, str] = {
    "daily": """
        CREATE TABLE IF NOT EXISTS daily (
            ts_code TEXT,
            trade_date TEXT,
            open REAL,
            high REAL,
            low REAL,
            close REAL,
            pre_close REAL,
            change REAL,
            pct_chg REAL,
            vol REAL,
            amount REAL,
            PRIMARY KEY (ts_code, trade_date)
        );
    """,
    "suspend": """
        CREATE TABLE IF NOT EXISTS suspend (
            ts_code TEXT,
            suspend_date TEXT,
            resume_date TEXT,
            suspend_type TEXT,
            ann_date TEXT,
            suspend_timing TEXT,
            resume_timing TEXT,
            reason TEXT,
            PRIMARY KEY (ts_code, suspend_date)
        );
    """,
    "trade_calendar": """
        CREATE TABLE IF NOT EXISTS trade_calendar (
            exchange TEXT,
            cal_date TEXT PRIMARY KEY,
            is_open INTEGER,
            pretrade_date TEXT
        );
    """,
    "stk_limit": """
        CREATE TABLE IF NOT EXISTS stk_limit (
            ts_code TEXT,
            trade_date TEXT,
            up_limit REAL,
            down_limit REAL,
            PRIMARY KEY (ts_code, trade_date)
        );
    """,
}

_TABLE_COLUMNS: Dict[str, List[str]] = {
    "daily": [
        "ts_code",
        "trade_date",
        "open",
        "high",
        "low",
        "close",
        "pre_close",
        "change",
        "pct_chg",
        "vol",
        "amount",
    ],
    "suspend": [
        "ts_code",
        "suspend_date",
        "resume_date",
        "suspend_type",
        "ann_date",
        "suspend_timing",
        "resume_timing",
        "reason",
    ],
    "trade_calendar": [
        "exchange",
        "cal_date",
        "is_open",
        "pretrade_date",
    ],
    "stk_limit": [
        "ts_code",
        "trade_date",
        "up_limit",
        "down_limit",
    ],
}


def _ensure_client():
    if ts is None:
        raise RuntimeError("未安装 tushare,请先在环境中安装 tushare 包")
    token = get_config().tushare_token or os.getenv("TUSHARE_TOKEN")
    if not token:
        raise RuntimeError("未配置 TuShare Token,请在配置文件或环境变量 TUSHARE_TOKEN 中设置")
    if not hasattr(_ensure_client, "_client") or _ensure_client._client is None:  # type: ignore[attr-defined]
        ts.set_token(token)
        _ensure_client._client = ts.pro_api(token)  # type: ignore[attr-defined]
        LOGGER.info("完成 TuShare 客户端初始化")
    return _ensure_client._client  # type: ignore[attr-defined]


def _format_date(value: date) -> str:
    return value.strftime("%Y%m%d")


def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
    if df is None or df.empty:
        return []
    # Fill in missing columns so every INSERT binding has a value.
    reindexed = df.reindex(columns=allowed_cols)
    return reindexed.where(pd.notnull(reindexed), None).to_dict("records")


def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]:
    """Fetch daily bars for the job's universe and date range."""

    client = _ensure_client()
    start_date = _format_date(job.start)
    end_date = _format_date(job.end)
    frames: List[pd.DataFrame] = []

    if job.granularity != "daily":
        raise ValueError(f"暂不支持的粒度:{job.granularity}")

    if job.ts_codes:
        for code in job.ts_codes:
            LOGGER.info("拉取 %s 的日线行情(%s-%s)", code, start_date, end_date)
            frames.append(client.daily(ts_code=code, start_date=start_date, end_date=end_date))
    else:
        LOGGER.info("按全市场拉取日线行情(%s-%s)", start_date, end_date)
        frames.append(client.daily(start_date=start_date, end_date=end_date))

    if not frames:
        return []
    df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
    return _df_to_records(df, _TABLE_COLUMNS["daily"])


def fetch_suspensions(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
    client = _ensure_client()
    start_date = _format_date(start)
    end_date = _format_date(end)
    LOGGER.info("拉取停复牌信息(%s-%s)", start_date, end_date)
    df = client.suspend_d(ts_code=ts_code, start_date=start_date, end_date=end_date)
    return _df_to_records(df, _TABLE_COLUMNS["suspend"])


def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]:
    client = _ensure_client()
    start_date = _format_date(start)
    end_date = _format_date(end)
    LOGGER.info("拉取交易日历(交易所:%s,区间:%s-%s)", exchange, start_date, end_date)
    df = client.trade_cal(exchange=exchange, start_date=start_date, end_date=end_date)
    return _df_to_records(df, _TABLE_COLUMNS["trade_calendar"])


def fetch_stk_limit(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
    client = _ensure_client()
    start_date = _format_date(start)
    end_date = _format_date(end)
    LOGGER.info("拉取涨跌停价格(%s-%s)", start_date, end_date)
    df = client.stk_limit(ts_code=ts_code, start_date=start_date, end_date=end_date)
    return _df_to_records(df, _TABLE_COLUMNS["stk_limit"])


def save_records(table: str, rows: Iterable[Dict]) -> None:
    """Persist fetched rows into SQLite."""

    items = list(rows)
    if not items:
        LOGGER.info("表 %s 没有新增记录,跳过写入", table)
        return

    schema = _TABLE_SCHEMAS.get(table)
    columns = _TABLE_COLUMNS.get(table)
    if not schema or not columns:
        raise ValueError(f"不支持写入的表:{table}")

    placeholders = ",".join([f":{col}" for col in columns])
    col_clause = ",".join(columns)

    LOGGER.info("表 %s 写入 %d 条记录", table, len(items))
    with db_session() as conn:
        conn.executescript(schema)
        conn.executemany(
            f"INSERT OR REPLACE INTO {table} ({col_clause}) VALUES ({placeholders})",
            items,
        )


def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
    """Run the TuShare ingestion steps described by the job configuration."""

    LOGGER.info("启动 TuShare 拉取任务:%s", job.name)

    daily_rows = fetch_daily_bars(job)
    save_records("daily", daily_rows)

    suspend_rows = fetch_suspensions(job.start, job.end)
    save_records("suspend", suspend_rows)

    calendar_rows = fetch_trade_calendar(job.start, job.end)
    save_records("trade_calendar", calendar_rows)

    if include_limits:
        limit_rows = fetch_stk_limit(job.start, job.end)
        save_records("stk_limit", limit_rows)

    LOGGER.info("任务 %s 完成", job.name)
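A short usage sketch for the pipeline above, assuming a valid TuShare token is configured; it ingests a small window for one stock and reads the result back:

```python
from datetime import date

from app.ingest.tushare import FetchJob, run_ingestion
from app.utils.db import db_session

run_ingestion(
    FetchJob(name="demo", start=date(2024, 1, 2), end=date(2024, 1, 5), ts_codes=("000001.SZ",)),
    include_limits=False,
)

with db_session(read_only=True) as conn:
    rows = conn.execute(
        "SELECT trade_date, close FROM daily WHERE ts_code = ? ORDER BY trade_date", ("000001.SZ",)
    ).fetchall()
    print(len(rows), "daily bars ingested")
```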
0 app/llm/__init__.py Normal file
18 app/llm/explain.py Normal file
@@ -0,0 +1,18 @@
"""LLM assisted explanations and summaries."""
from __future__ import annotations

from typing import Dict

from .prompts import plan_prompt


def make_human_card(ts_code: str, trade_date: str, context: Dict) -> Dict:
    """Compose payload for UI cards and LLM requests."""

    prompt = plan_prompt(context)
    return {
        "ts_code": ts_code,
        "trade_date": trade_date,
        "prompt": prompt,
        "context": context,
    }
11 app/llm/prompts.py Normal file
@@ -0,0 +1,11 @@
"""Prompt templates for natural language outputs."""
from __future__ import annotations

from typing import Dict


def plan_prompt(data: Dict) -> str:
    """Build a concise instruction prompt for the LLM."""

    _ = data
    return "你是一个投资助理,请根据提供的数据给出三条要点和两条风险提示。"
35 app/main.py Normal file
@@ -0,0 +1,35 @@
"""Command line entry points for routine tasks."""
from __future__ import annotations

from datetime import date

from app.backtest.engine import BtConfig, run_backtest
from app.data.schema import initialize_database


def init_db() -> None:
    result = initialize_database()
    if result.skipped:
        print("Database already initialized; skipping schema creation")
    else:
        print(f"Initialized database with {result.executed} statements")


def run_sample_backtest() -> None:
    cfg = BtConfig(
        id="demo",
        name="Demo Strategy",
        start_date=date(2020, 1, 1),
        end_date=date(2020, 3, 31),
        universe=["000001.SZ"],
        params={
            "target": 0.035,
            "stop": -0.015,
            "hold_days": 10,
        },
    )
    run_backtest(cfg)


if __name__ == "__main__":
    init_db()
0 app/ui/__init__.py Normal file
85 app/ui/streamlit_app.py Normal file
@@ -0,0 +1,85 @@
"""Streamlit UI scaffold for the investment assistant."""
from __future__ import annotations

import sys
from datetime import date
from pathlib import Path

ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import streamlit as st

from app.data.schema import initialize_database
from app.ingest.tushare import FetchJob, run_ingestion
from app.llm.explain import make_human_card


def render_today_plan() -> None:
    st.header("今日计划")
    st.write("待接入候选池筛选与多智能体决策结果。")
    sample = make_human_card("000001.SZ", "2025-01-01", {"decisions": []})
    st.json(sample)


def render_backtest() -> None:
    st.header("回测与复盘")
    st.write("在此运行回测、展示净值曲线与代理贡献。")
    st.button("开始回测")


def render_settings() -> None:
    st.header("数据与设置")
    st.text_input("TuShare Token")
    st.write("新闻源开关与数据库备份将在此配置。")


def render_tests() -> None:
    st.header("自检测试")
    st.write("用于快速检查数据库与数据拉取是否正常工作。")

    if st.button("测试数据库初始化"):
        with st.spinner("正在检查数据库..."):
            result = initialize_database()
            if result.skipped:
                st.success("数据库已存在,检查通过。")
            else:
                st.success(f"数据库初始化完成,共执行 {result.executed} 条语句。")

    st.divider()

    if st.button("测试 TuShare 拉取(示例 2024-01-01 至 2024-01-03)"):
        with st.spinner("正在调用 TuShare 接口..."):
            try:
                run_ingestion(
                    FetchJob(
                        name="streamlit_self_test",
                        start=date(2024, 1, 1),
                        end=date(2024, 1, 3),
                        ts_codes=("000001.SZ",),
                    ),
                    include_limits=False,
                )
                st.success("TuShare 示例拉取完成,数据已写入数据库。")
            except Exception as exc:  # noqa: BLE001
                st.error(f"拉取失败:{exc}")

    st.info("注意:TuShare 拉取依赖网络与 Token,若环境未配置将出现错误提示。")


def main() -> None:
    st.set_page_config(page_title="多智能体投资助理", layout="wide")
    tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"])
    with tabs[0]:
        render_today_plan()
    with tabs[1]:
        render_backtest()
    with tabs[2]:
        render_settings()
    with tabs[3]:
        render_tests()


if __name__ == "__main__":
    main()
0 app/utils/__init__.py Normal file
34 app/utils/calendar.py Normal file
@@ -0,0 +1,34 @@
"""Trading calendar utilities.

These helpers abstract exchange calendars and trading day lookups. Real
implementations should integrate with TuShare or cached calendars.
"""
from __future__ import annotations

from datetime import date, timedelta
from typing import Iterable, List


def is_trading_day(day: date, holidays: Iterable[date] | None = None) -> bool:
    if day.weekday() >= 5:
        return False
    if holidays and day in set(holidays):
        return False
    return True


def previous_trading_day(day: date, holidays: Iterable[date] | None = None) -> date:
    current = day - timedelta(days=1)
    while not is_trading_day(current, holidays):
        current -= timedelta(days=1)
    return current


def trading_days_between(start: date, end: date, holidays: Iterable[date] | None = None) -> List[date]:
    current = start
    days: List[date] = []
    while current <= end:
        if is_trading_day(current, holidays):
            days.append(current)
        current += timedelta(days=1)
    return days
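A quick usage example of the calendar helpers; until a real exchange calendar is wired in, only weekends and an explicitly supplied holiday list are excluded (the holiday below is a made-up example):

```python
from datetime import date

from app.utils.calendar import previous_trading_day, trading_days_between

holidays = [date(2024, 1, 1)]  # hypothetical holiday for illustration
print(previous_trading_day(date(2024, 1, 2), holidays))                 # 2023-12-29 (prior Friday)
print(len(trading_days_between(date(2024, 1, 1), date(2024, 1, 7), holidays)))  # 4 (Jan 2-5; Jan 1 excluded)
```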
65 app/utils/config.py Normal file
@@ -0,0 +1,65 @@
"""Application configuration models and helpers."""
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Optional


def _default_root() -> Path:
    return Path(__file__).resolve().parents[2] / "app" / "data"


@dataclass
class DataPaths:
    """Holds filesystem locations for persistent artifacts."""

    root: Path = field(default_factory=_default_root)
    database: Path = field(init=False)
    backups: Path = field(init=False)

    def __post_init__(self) -> None:
        self.root.mkdir(parents=True, exist_ok=True)
        self.database = self.root / "llm_quant.db"
        self.backups = self.root / "backups"
        self.backups.mkdir(parents=True, exist_ok=True)


@dataclass
class AgentWeights:
    """Default weighting for decision agents."""

    momentum: float = 0.30
    value: float = 0.20
    news: float = 0.20
    liquidity: float = 0.15
    macro: float = 0.15

    def as_dict(self) -> Dict[str, float]:
        return {
            "A_mom": self.momentum,
            "A_val": self.value,
            "A_news": self.news,
            "A_liq": self.liquidity,
            "A_macro": self.macro,
        }


@dataclass
class AppConfig:
    """User configurable settings persisted in a simple structure."""

    tushare_token: Optional[str] = None
    rss_sources: Dict[str, bool] = field(default_factory=dict)
    decision_method: str = "nash"
    data_paths: DataPaths = field(default_factory=DataPaths)
    agent_weights: AgentWeights = field(default_factory=AgentWeights)


CONFIG = AppConfig()


def get_config() -> AppConfig:
    """Return a mutable global configuration instance."""

    return CONFIG
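Because `get_config()` returns a mutable module-level instance, settings such as the TuShare token or the decision method can be adjusted at runtime, for example:

```python
from app.utils.config import get_config

cfg = get_config()
cfg.tushare_token = "your-token-here"   # placeholder value, normally read from the UI or env
cfg.decision_method = "vote"            # switch the game from Nash bargaining to weighted voting
print(cfg.agent_weights.as_dict())
```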
47 app/utils/db.py Normal file
@@ -0,0 +1,47 @@
"""SQLite helpers for application data access."""
from __future__ import annotations

import sqlite3
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Iterable, Tuple

from .config import get_config


def get_connection(read_only: bool = False) -> sqlite3.Connection:
    """Create a SQLite connection against the configured database file."""

    db_path: Path = get_config().data_paths.database
    uri = f"file:{db_path}?mode={'ro' if read_only else 'rw'}"
    if not db_path.exists() and not read_only:
        # Ensure directory exists before first write.
        db_path.parent.mkdir(parents=True, exist_ok=True)
        conn = sqlite3.connect(db_path)
    else:
        conn = sqlite3.connect(uri, uri=True)
    conn.row_factory = sqlite3.Row
    return conn


@contextmanager
def db_session(read_only: bool = False) -> Generator[sqlite3.Connection, None, None]:
    """Context manager yielding a connection with automatic commit/rollback."""

    conn = get_connection(read_only=read_only)
    try:
        yield conn
        if not read_only:
            conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()


def execute_many(sql: str, params: Iterable[Tuple]) -> None:
    """Bulk execute a parameterized statement inside a write session."""

    with db_session() as conn:
        conn.executemany(sql, params)
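A small usage sketch of these helpers: `db_session` commits on success and rolls back on error, and `execute_many` wraps bulk writes (the `demo_kv` table below is hypothetical, used only for illustration):

```python
from app.utils.db import db_session, execute_many

with db_session() as conn:
    conn.execute("CREATE TABLE IF NOT EXISTS demo_kv (k TEXT PRIMARY KEY, v REAL)")

execute_many("INSERT OR REPLACE INTO demo_kv (k, v) VALUES (?, ?)", [("a", 1.0), ("b", 2.0)])

with db_session(read_only=True) as conn:
    print(conn.execute("SELECT COUNT(*) AS n FROM demo_kv").fetchone()["n"])
```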
28 app/utils/logging.py Normal file
@@ -0,0 +1,28 @@
"""Centralized logging configuration."""
from __future__ import annotations

import logging
from pathlib import Path

from .config import get_config


def configure_logging(level: int = logging.INFO) -> None:
    """Setup root logger with file and console handlers."""

    cfg = get_config()
    log_dir = cfg.data_paths.root / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    logfile = log_dir / "app.log"

    logging.basicConfig(
        level=level,
        format="%(asctime)s %(levelname)s %(name)s - %(message)s",
        handlers=[
            logging.FileHandler(logfile, encoding="utf-8"),
            logging.StreamHandler(),
        ],
    )


configure_logging()