init
This commit is contained in:
commit
6eac6c5f69
29
.gitignore
vendored
Normal file
29
.gitignore
vendored
Normal file
@ -0,0 +1,29 @@
|
||||
# Python cache and build artifacts
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.so
|
||||
.build/
|
||||
.eggs/
|
||||
*.egg-info/
|
||||
|
||||
# Virtual environments
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
env/
|
||||
|
||||
# Jupyter checkpoints
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# Logs and data
|
||||
app/data/*.db*
|
||||
app/data/backups/
|
||||
app/data/logs/
|
||||
*.log
|
||||
|
||||
# Streamlit temporary files
|
||||
.streamlit/
|
||||
|
||||
# System files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
27
README.md
Normal file
27
README.md
Normal file
@ -0,0 +1,27 @@
|
||||
# 多智能体投资助理骨架
|
||||
|
||||
本仓库提供一个基于多智能体博弈的 A 股日线投资助理代码框架,满足单机可运行、SQLite 存储和 Streamlit UI 的需求。核心模块划分如下:
|
||||
|
||||
- `app/data`:数据库初始化与 Schema 定义。
|
||||
- `app/utils`:配置、数据库连接、日志和交易日历工具。
|
||||
- `app/ingest`:TuShare 与 RSS 数据拉取骨架。
|
||||
- `app/features`:指标与信号计算接口。
|
||||
- `app/agents`:多智能体博弈实现,包括动量、价值、新闻、流动性、宏观与风险代理。
|
||||
- `app/backtest`:日线回测引擎与指标计算的占位实现。
|
||||
- `app/llm`:人类可读卡片与摘要生成入口(仅构建提示,不直接交易)。
|
||||
- `app/ui`:Streamlit 三页界面骨架。
|
||||
|
||||
## 快速开始
|
||||
|
||||
```bash
|
||||
python -m app.main # 初始化数据库
|
||||
streamlit run app/ui/streamlit_app.py
|
||||
```
|
||||
|
||||
## 下一步
|
||||
|
||||
1. 在 `app/ingest` 中补充 TuShare 和 RSS 数据抓取逻辑。
|
||||
2. 完善 `app/features` 和 `app/backtest` 以实现实际的信号计算与事件驱动回测。
|
||||
3. 将代理效用写入 SQLite 的 `agent_utils` 和 `alloc_log` 表,驱动 UI 展示。
|
||||
4. 使用轻量情感分析与热度计算,填充 `news` 和 `heat_daily`。
|
||||
5. 接入本地小模型或 API 完成 LLM 文本解释,并在 UI 中展示。
|
||||
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
0
app/agents/__init__.py
Normal file
0
app/agents/__init__.py
Normal file
44
app/agents/base.py
Normal file
44
app/agents/base.py
Normal file
@ -0,0 +1,44 @@
|
||||
"""Agent abstractions for the multi-agent decision engine."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Dict, Mapping
|
||||
|
||||
|
||||
class AgentAction(str, Enum):
|
||||
SELL = "SELL"
|
||||
HOLD = "HOLD"
|
||||
BUY_S = "BUY_S"
|
||||
BUY_M = "BUY_M"
|
||||
BUY_L = "BUY_L"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentContext:
|
||||
ts_code: str
|
||||
trade_date: str
|
||||
features: Mapping[str, float]
|
||||
|
||||
|
||||
class Agent:
|
||||
"""Base class for all decision agents."""
|
||||
|
||||
name: str
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
def score(self, context: AgentContext, action: AgentAction) -> float:
|
||||
"""Return a normalized utility value in [0,1] for the proposed action."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def feasible(self, context: AgentContext, action: AgentAction) -> bool:
|
||||
"""Optional hook for agents with veto power (defaults to True)."""
|
||||
|
||||
_ = context, action
|
||||
return True
|
||||
|
||||
|
||||
UtilityMatrix = Dict[AgentAction, Dict[str, float]]
|
||||
127
app/agents/game.py
Normal file
127
app/agents/game.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""Multi-agent decision game implementation."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from math import exp, log
|
||||
from typing import Dict, Iterable, List, Mapping, Tuple
|
||||
|
||||
from .base import Agent, AgentAction, AgentContext, UtilityMatrix
|
||||
from .registry import weight_map
|
||||
|
||||
|
||||
ACTIONS: Tuple[AgentAction, ...] = (
|
||||
AgentAction.SELL,
|
||||
AgentAction.HOLD,
|
||||
AgentAction.BUY_S,
|
||||
AgentAction.BUY_M,
|
||||
AgentAction.BUY_L,
|
||||
)
|
||||
|
||||
|
||||
def _clamp(value: float) -> float:
|
||||
return max(0.0, min(1.0, value))
|
||||
|
||||
|
||||
@dataclass
|
||||
class Decision:
|
||||
action: AgentAction
|
||||
confidence: float
|
||||
target_weight: float
|
||||
feasible_actions: List[AgentAction]
|
||||
utilities: UtilityMatrix
|
||||
|
||||
|
||||
def compute_utilities(agents: Iterable[Agent], context: AgentContext) -> UtilityMatrix:
|
||||
utilities: UtilityMatrix = {}
|
||||
for action in ACTIONS:
|
||||
utilities[action] = {}
|
||||
for agent in agents:
|
||||
score = _clamp(agent.score(context, action))
|
||||
utilities[action][agent.name] = score
|
||||
return utilities
|
||||
|
||||
|
||||
def feasible_actions(agents: Iterable[Agent], context: AgentContext) -> List[AgentAction]:
|
||||
feas: List[AgentAction] = []
|
||||
for action in ACTIONS:
|
||||
if all(agent.feasible(context, action) for agent in agents):
|
||||
feas.append(action)
|
||||
return feas
|
||||
|
||||
|
||||
def nash_bargain(utilities: UtilityMatrix, weights: Mapping[str, float], disagreement: Mapping[str, float]) -> Tuple[AgentAction, float]:
|
||||
best_action = AgentAction.HOLD
|
||||
best_score = float("-inf")
|
||||
for action, agent_scores in utilities.items():
|
||||
if action not in utilities:
|
||||
continue
|
||||
log_product = 0.0
|
||||
valid = True
|
||||
for agent_name, score in agent_scores.items():
|
||||
w = weights.get(agent_name, 0.0)
|
||||
if w == 0:
|
||||
continue
|
||||
gap = score - disagreement.get(agent_name, 0.0)
|
||||
if gap <= 0:
|
||||
valid = False
|
||||
break
|
||||
log_product += w * log(gap)
|
||||
if not valid:
|
||||
continue
|
||||
if log_product > best_score:
|
||||
best_score = log_product
|
||||
best_action = action
|
||||
if best_score == float("-inf"):
|
||||
return AgentAction.HOLD, 0.0
|
||||
confidence = _aggregate_confidence(utilities[best_action], weights)
|
||||
return best_action, confidence
|
||||
|
||||
|
||||
def vote(utilities: UtilityMatrix, weights: Mapping[str, float]) -> Tuple[AgentAction, float]:
|
||||
scores: Dict[AgentAction, float] = {}
|
||||
for action, agent_scores in utilities.items():
|
||||
scores[action] = sum(weights.get(agent, 0.0) * score for agent, score in agent_scores.items())
|
||||
best_action = max(scores, key=scores.get)
|
||||
confidence = _aggregate_confidence(utilities[best_action], weights)
|
||||
return best_action, confidence
|
||||
|
||||
|
||||
def _aggregate_confidence(agent_scores: Mapping[str, float], weights: Mapping[str, float]) -> float:
|
||||
total = sum(weights.values())
|
||||
if total <= 0:
|
||||
return 0.0
|
||||
weighted = sum(weights.get(agent, 0.0) * score for agent, score in agent_scores.items())
|
||||
return weighted / total
|
||||
|
||||
|
||||
def target_weight_for_action(action: AgentAction) -> float:
|
||||
mapping = {
|
||||
AgentAction.SELL: -1.0,
|
||||
AgentAction.HOLD: 0.0,
|
||||
AgentAction.BUY_S: 0.01,
|
||||
AgentAction.BUY_M: 0.02,
|
||||
AgentAction.BUY_L: 0.03,
|
||||
}
|
||||
return mapping[action]
|
||||
|
||||
|
||||
def decide(context: AgentContext, agents: Iterable[Agent], weights: Mapping[str, float], method: str = "nash") -> Decision:
|
||||
agent_list = list(agents)
|
||||
norm_weights = weight_map(dict(weights))
|
||||
utilities = compute_utilities(agent_list, context)
|
||||
feas_actions = feasible_actions(agent_list, context)
|
||||
if not feas_actions:
|
||||
return Decision(AgentAction.HOLD, 0.0, 0.0, [], utilities)
|
||||
|
||||
filtered_utilities = {action: utilities[action] for action in feas_actions}
|
||||
hold_scores = utilities.get(AgentAction.HOLD, {})
|
||||
|
||||
if method == "vote":
|
||||
action, confidence = vote(filtered_utilities, norm_weights)
|
||||
else:
|
||||
action, confidence = nash_bargain(filtered_utilities, norm_weights, hold_scores)
|
||||
if action not in feas_actions:
|
||||
action, confidence = vote(filtered_utilities, norm_weights)
|
||||
|
||||
weight = target_weight_for_action(action)
|
||||
return Decision(action, confidence, weight, feas_actions, utilities)
|
||||
20
app/agents/liquidity.py
Normal file
20
app/agents/liquidity.py
Normal file
@ -0,0 +1,20 @@
|
||||
"""Liquidity and transaction cost agent."""
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import Agent, AgentAction, AgentContext
|
||||
|
||||
|
||||
class LiquidityAgent(Agent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(name="A_liq")
|
||||
|
||||
def score(self, context: AgentContext, action: AgentAction) -> float:
|
||||
liq = context.features.get("liquidity_score", 0.5)
|
||||
cost = context.features.get("cost_penalty", 0.0)
|
||||
penalty = cost
|
||||
if action is AgentAction.SELL:
|
||||
return min(1.0, liq + penalty)
|
||||
if action is AgentAction.HOLD:
|
||||
return 0.4 + 0.2 * liq
|
||||
scale = {AgentAction.BUY_S: 0.5, AgentAction.BUY_M: 0.75, AgentAction.BUY_L: 1.0}.get(action, 0.0)
|
||||
return max(0.0, liq * scale - penalty)
|
||||
23
app/agents/macro.py
Normal file
23
app/agents/macro.py
Normal file
@ -0,0 +1,23 @@
|
||||
"""Macro and industry regime agent."""
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import Agent, AgentAction, AgentContext
|
||||
|
||||
|
||||
class MacroAgent(Agent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(name="A_macro")
|
||||
|
||||
def score(self, context: AgentContext, action: AgentAction) -> float:
|
||||
industry_heat = context.features.get("industry_heat", 0.5)
|
||||
relative_strength = context.features.get("industry_relative_mom", 0.0)
|
||||
raw = min(1.0, max(0.0, industry_heat * 0.6 + relative_strength * 0.4))
|
||||
if action is AgentAction.SELL:
|
||||
return 1 - raw
|
||||
if action is AgentAction.HOLD:
|
||||
return 0.5
|
||||
if action is AgentAction.BUY_S:
|
||||
return raw * 0.6
|
||||
if action is AgentAction.BUY_M:
|
||||
return raw * 0.8
|
||||
return raw
|
||||
29
app/agents/momentum.py
Normal file
29
app/agents/momentum.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""Momentum oriented agent."""
|
||||
from __future__ import annotations
|
||||
|
||||
from math import tanh
|
||||
|
||||
from .base import Agent, AgentAction, AgentContext
|
||||
|
||||
|
||||
def _sigmoid(x: float) -> float:
|
||||
return 0.5 * (tanh(x) + 1)
|
||||
|
||||
|
||||
class MomentumAgent(Agent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(name="A_mom")
|
||||
|
||||
def score(self, context: AgentContext, action: AgentAction) -> float:
|
||||
mom20 = context.features.get("mom_20", 0.0)
|
||||
mom60 = context.features.get("mom_60", 0.0)
|
||||
strength = _sigmoid(0.5 * mom20 + 0.5 * mom60)
|
||||
if action is AgentAction.SELL:
|
||||
return 1 - strength
|
||||
if action is AgentAction.HOLD:
|
||||
return 0.5
|
||||
if action is AgentAction.BUY_S:
|
||||
return strength * 0.6
|
||||
if action is AgentAction.BUY_M:
|
||||
return strength * 0.8
|
||||
return strength
|
||||
26
app/agents/news.py
Normal file
26
app/agents/news.py
Normal file
@ -0,0 +1,26 @@
|
||||
"""News and sentiment aware agent."""
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import Agent, AgentAction, AgentContext
|
||||
|
||||
|
||||
class NewsAgent(Agent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(name="A_news")
|
||||
|
||||
def score(self, context: AgentContext, action: AgentAction) -> float:
|
||||
heat = context.features.get("news_heat", 0.0)
|
||||
sentiment = context.features.get("news_sentiment", 0.0)
|
||||
positive = max(0.0, sentiment)
|
||||
negative = max(0.0, -sentiment)
|
||||
buy_score = min(1.0, heat * positive)
|
||||
sell_score = min(1.0, heat * negative)
|
||||
if action is AgentAction.SELL:
|
||||
return sell_score
|
||||
if action is AgentAction.HOLD:
|
||||
return 0.3 + 0.4 * (1 - heat)
|
||||
if action is AgentAction.BUY_S:
|
||||
return 0.5 * buy_score
|
||||
if action is AgentAction.BUY_M:
|
||||
return 0.75 * buy_score
|
||||
return buy_score
|
||||
30
app/agents/registry.py
Normal file
30
app/agents/registry.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""Factory helpers to construct the agent ensemble."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
from .base import Agent
|
||||
from .liquidity import LiquidityAgent
|
||||
from .macro import MacroAgent
|
||||
from .momentum import MomentumAgent
|
||||
from .news import NewsAgent
|
||||
from .risk import RiskAgent
|
||||
from .value import ValueAgent
|
||||
|
||||
|
||||
def default_agents() -> List[Agent]:
|
||||
return [
|
||||
MomentumAgent(),
|
||||
ValueAgent(),
|
||||
NewsAgent(),
|
||||
LiquidityAgent(),
|
||||
MacroAgent(),
|
||||
RiskAgent(),
|
||||
]
|
||||
|
||||
|
||||
def weight_map(raw: Dict[str, float]) -> Dict[str, float]:
|
||||
total = sum(raw.values())
|
||||
if total == 0:
|
||||
return raw
|
||||
return {name: weight / total for name, weight in raw.items()}
|
||||
29
app/agents/risk.py
Normal file
29
app/agents/risk.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""Risk agent acts as leader with veto rights."""
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import Agent, AgentAction, AgentContext
|
||||
|
||||
|
||||
class RiskAgent(Agent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(name="A_risk")
|
||||
|
||||
def score(self, context: AgentContext, action: AgentAction) -> float:
|
||||
# Base risk agent is neutral unless penalties are triggered.
|
||||
penalty = context.features.get("risk_penalty", 0.0)
|
||||
if action is AgentAction.SELL:
|
||||
return min(1.0, 0.6 + penalty)
|
||||
if action is AgentAction.HOLD:
|
||||
return 0.5
|
||||
return max(0.0, 1.0 - penalty)
|
||||
|
||||
def feasible(self, context: AgentContext, action: AgentAction) -> bool:
|
||||
if action is AgentAction.SELL:
|
||||
return True
|
||||
if context.features.get("is_suspended", False):
|
||||
return False
|
||||
if context.features.get("limit_up", False) and action not in (AgentAction.SELL, AgentAction.HOLD):
|
||||
return False
|
||||
if context.features.get("position_limit", False) and action in (AgentAction.BUY_M, AgentAction.BUY_L):
|
||||
return False
|
||||
return True
|
||||
26
app/agents/value.py
Normal file
26
app/agents/value.py
Normal file
@ -0,0 +1,26 @@
|
||||
"""Value and quality filtering agent."""
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import Agent, AgentAction, AgentContext
|
||||
|
||||
|
||||
class ValueAgent(Agent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(name="A_val")
|
||||
|
||||
def score(self, context: AgentContext, action: AgentAction) -> float:
|
||||
pe = context.features.get("pe_percentile", 0.5)
|
||||
pb = context.features.get("pb_percentile", 0.5)
|
||||
roe = context.features.get("roe_percentile", 0.5)
|
||||
# Lower valuation percentiles and higher quality percentiles add value.
|
||||
raw = max(0.0, (1 - pe) * 0.4 + (1 - pb) * 0.3 + roe * 0.3)
|
||||
raw = min(raw, 1.0)
|
||||
if action is AgentAction.SELL:
|
||||
return 1 - raw
|
||||
if action is AgentAction.HOLD:
|
||||
return 0.5
|
||||
if action is AgentAction.BUY_S:
|
||||
return raw * 0.7
|
||||
if action is AgentAction.BUY_M:
|
||||
return raw * 0.85
|
||||
return raw
|
||||
0
app/backtest/__init__.py
Normal file
0
app/backtest/__init__.py
Normal file
90
app/backtest/engine.py
Normal file
90
app/backtest/engine.py
Normal file
@ -0,0 +1,90 @@
|
||||
"""Backtest engine skeleton for daily bar simulation."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date
|
||||
from typing import Dict, Iterable, List, Mapping
|
||||
|
||||
from app.agents.base import AgentContext
|
||||
from app.agents.game import Decision, decide
|
||||
from app.agents.registry import default_agents
|
||||
from app.utils.db import db_session
|
||||
|
||||
|
||||
@dataclass
|
||||
class BtConfig:
|
||||
id: str
|
||||
name: str
|
||||
start_date: date
|
||||
end_date: date
|
||||
universe: List[str]
|
||||
params: Dict[str, float]
|
||||
method: str = "nash"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PortfolioState:
|
||||
cash: float = 1_000_000.0
|
||||
holdings: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BacktestResult:
|
||||
nav_series: List[Dict[str, float]] = field(default_factory=list)
|
||||
trades: List[Dict[str, str]] = field(default_factory=list)
|
||||
|
||||
|
||||
class BacktestEngine:
|
||||
"""Runs the multi-agent game inside a daily event-driven loop."""
|
||||
|
||||
def __init__(self, cfg: BtConfig) -> None:
|
||||
self.cfg = cfg
|
||||
self.agents = default_agents()
|
||||
self.weights = {agent.name: 1.0 for agent in self.agents}
|
||||
|
||||
def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, float]]:
|
||||
"""Load per-stock feature vectors. Replace with real data access."""
|
||||
|
||||
_ = trade_date
|
||||
return {}
|
||||
|
||||
def simulate_day(self, trade_date: date, state: PortfolioState) -> List[Decision]:
|
||||
feature_map = self.load_market_data(trade_date)
|
||||
decisions: List[Decision] = []
|
||||
for ts_code, features in feature_map.items():
|
||||
context = AgentContext(ts_code=ts_code, trade_date=trade_date.isoformat(), features=features)
|
||||
decision = decide(context, self.agents, self.weights, method=self.cfg.method)
|
||||
decisions.append(decision)
|
||||
self.record_agent_state(context, decision)
|
||||
# TODO: translate decisions into fills, holdings, and NAV updates.
|
||||
_ = state
|
||||
return decisions
|
||||
|
||||
def record_agent_state(self, context: AgentContext, decision: Decision) -> None:
|
||||
payload = {
|
||||
"trade_date": context.trade_date,
|
||||
"ts_code": context.ts_code,
|
||||
"action": decision.action.value,
|
||||
"confidence": decision.confidence,
|
||||
}
|
||||
_ = payload
|
||||
# Implementation should persist into agent_utils and bt_trades.
|
||||
|
||||
def run(self) -> BacktestResult:
|
||||
state = PortfolioState()
|
||||
result = BacktestResult()
|
||||
current = self.cfg.start_date
|
||||
while current <= self.cfg.end_date:
|
||||
decisions = self.simulate_day(current, state)
|
||||
_ = decisions
|
||||
current = date.fromordinal(current.toordinal() + 1)
|
||||
return result
|
||||
|
||||
|
||||
def run_backtest(cfg: BtConfig) -> BacktestResult:
|
||||
engine = BacktestEngine(cfg)
|
||||
result = engine.run()
|
||||
with db_session() as conn:
|
||||
_ = conn
|
||||
# Implementation should persist bt_nav, bt_trades, and bt_report rows.
|
||||
return result
|
||||
19
app/backtest/metrics.py
Normal file
19
app/backtest/metrics.py
Normal file
@ -0,0 +1,19 @@
|
||||
"""Performance metric utilities."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List
|
||||
|
||||
|
||||
@dataclass
|
||||
class Metric:
|
||||
name: str
|
||||
value: float
|
||||
description: str
|
||||
|
||||
|
||||
def compute_nav_metrics(nav_series: Iterable[float]) -> List[Metric]:
|
||||
"""Compute core statistics such as CAGR, Sharpe, and max drawdown."""
|
||||
|
||||
_ = nav_series
|
||||
return []
|
||||
0
app/data/__init__.py
Normal file
0
app/data/__init__.py
Normal file
146
app/data/schema.py
Normal file
146
app/data/schema.py
Normal file
@ -0,0 +1,146 @@
|
||||
"""Database schema management for the investment assistant."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
|
||||
from app.utils.config import get_config
|
||||
from app.utils.db import db_session
|
||||
|
||||
|
||||
SCHEMA_STATEMENTS: Iterable[str] = (
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS news (
|
||||
id TEXT PRIMARY KEY,
|
||||
ts_code TEXT,
|
||||
pub_time TEXT,
|
||||
source TEXT,
|
||||
title TEXT,
|
||||
summary TEXT,
|
||||
url TEXT,
|
||||
entities TEXT,
|
||||
sentiment REAL,
|
||||
heat REAL
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_news_time ON news(pub_time DESC);
|
||||
""",
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_news_code ON news(ts_code, pub_time DESC);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS heat_daily (
|
||||
scope TEXT,
|
||||
key TEXT,
|
||||
trade_date TEXT,
|
||||
heat REAL,
|
||||
top_topics TEXT,
|
||||
PRIMARY KEY (scope, key, trade_date)
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS bt_config (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT,
|
||||
start_date TEXT,
|
||||
end_date TEXT,
|
||||
universe TEXT,
|
||||
params TEXT
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS bt_trades (
|
||||
cfg_id TEXT,
|
||||
ts_code TEXT,
|
||||
trade_date TEXT,
|
||||
side TEXT,
|
||||
price REAL,
|
||||
qty REAL,
|
||||
reason TEXT,
|
||||
PRIMARY KEY (cfg_id, ts_code, trade_date, side)
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS bt_nav (
|
||||
cfg_id TEXT,
|
||||
trade_date TEXT,
|
||||
nav REAL,
|
||||
ret REAL,
|
||||
pos_count INTEGER,
|
||||
turnover REAL,
|
||||
dd REAL,
|
||||
info TEXT,
|
||||
PRIMARY KEY (cfg_id, trade_date)
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS bt_report (
|
||||
cfg_id TEXT PRIMARY KEY,
|
||||
summary TEXT
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS run_log (
|
||||
ts TEXT PRIMARY KEY,
|
||||
stage TEXT,
|
||||
level TEXT,
|
||||
msg TEXT
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS agent_utils (
|
||||
trade_date TEXT,
|
||||
ts_code TEXT,
|
||||
agent TEXT,
|
||||
action TEXT,
|
||||
utils TEXT,
|
||||
feasible TEXT,
|
||||
weight REAL,
|
||||
PRIMARY KEY (trade_date, ts_code, agent)
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS alloc_log (
|
||||
trade_date TEXT,
|
||||
ts_code TEXT,
|
||||
target_weight REAL,
|
||||
clipped_weight REAL,
|
||||
reason TEXT,
|
||||
PRIMARY KEY (trade_date, ts_code)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MigrationResult:
|
||||
executed: int
|
||||
skipped: bool = False
|
||||
|
||||
|
||||
def _schema_exists() -> bool:
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
cursor = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='news'"
|
||||
)
|
||||
return cursor.fetchone() is not None
|
||||
except sqlite3.OperationalError:
|
||||
return False
|
||||
|
||||
|
||||
def initialize_database() -> MigrationResult:
|
||||
"""Create tables and indexes required by the application."""
|
||||
|
||||
if _schema_exists():
|
||||
return MigrationResult(executed=0, skipped=True)
|
||||
|
||||
executed = 0
|
||||
with db_session() as conn:
|
||||
cursor = conn.cursor()
|
||||
for statement in SCHEMA_STATEMENTS:
|
||||
cursor.executescript(statement)
|
||||
executed += 1
|
||||
return MigrationResult(executed=executed)
|
||||
0
app/features/__init__.py
Normal file
0
app/features/__init__.py
Normal file
39
app/features/factors.py
Normal file
39
app/features/factors.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""Feature engineering for signals and indicator computation."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
from typing import Iterable, List
|
||||
|
||||
|
||||
@dataclass
|
||||
class FactorSpec:
|
||||
name: str
|
||||
window: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class FactorResult:
|
||||
ts_code: str
|
||||
trade_date: date
|
||||
values: dict
|
||||
|
||||
|
||||
DEFAULT_FACTORS: List[FactorSpec] = [
|
||||
FactorSpec("mom_20", 20),
|
||||
FactorSpec("mom_60", 60),
|
||||
FactorSpec("volat_20", 20),
|
||||
FactorSpec("turn_20", 20),
|
||||
]
|
||||
|
||||
|
||||
def compute_factors(trade_date: date, factors: Iterable[FactorSpec] = DEFAULT_FACTORS) -> List[FactorResult]:
|
||||
"""Calculate factor values for the requested date.
|
||||
|
||||
This function should join historical price data, apply rolling windows, and
|
||||
persist results into an factors table. The implementation is left as future
|
||||
work.
|
||||
"""
|
||||
|
||||
_ = trade_date, factors
|
||||
raise NotImplementedError
|
||||
0
app/ingest/__init__.py
Normal file
0
app/ingest/__init__.py
Normal file
42
app/ingest/rss.py
Normal file
42
app/ingest/rss.py
Normal file
@ -0,0 +1,42 @@
|
||||
"""RSS ingestion for news and heat scores."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Iterable, List
|
||||
|
||||
|
||||
@dataclass
|
||||
class RssItem:
|
||||
id: str
|
||||
title: str
|
||||
link: str
|
||||
published: datetime
|
||||
summary: str
|
||||
source: str
|
||||
|
||||
|
||||
def fetch_rss_feed(url: str) -> List[RssItem]:
|
||||
"""Download and parse an RSS feed into structured items."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
|
||||
"""Drop duplicate stories by link/id fingerprint."""
|
||||
|
||||
seen = set()
|
||||
unique: List[RssItem] = []
|
||||
for item in items:
|
||||
key = item.id or item.link
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
unique.append(item)
|
||||
return unique
|
||||
|
||||
|
||||
def save_news_items(items: Iterable[RssItem]) -> None:
|
||||
"""Persist RSS items into the `news` table."""
|
||||
|
||||
raise NotImplementedError
|
||||
240
app/ingest/tushare.py
Normal file
240
app/ingest/tushare.py
Normal file
@ -0,0 +1,240 @@
|
||||
"""TuShare 数据拉取管线实现。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
from typing import Dict, Iterable, List, Optional, Sequence
|
||||
|
||||
import pandas as pd
|
||||
|
||||
try:
|
||||
import tushare as ts
|
||||
except ImportError as exc: # pragma: no cover - dependency error surfaced at runtime
|
||||
ts = None # type: ignore[assignment]
|
||||
|
||||
from app.utils.config import get_config
|
||||
from app.utils.db import db_session
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FetchJob:
|
||||
name: str
|
||||
start: date
|
||||
end: date
|
||||
granularity: str = "daily"
|
||||
ts_codes: Optional[Sequence[str]] = None
|
||||
|
||||
|
||||
_TABLE_SCHEMAS: Dict[str, str] = {
|
||||
"daily": """
|
||||
CREATE TABLE IF NOT EXISTS daily (
|
||||
ts_code TEXT,
|
||||
trade_date TEXT,
|
||||
open REAL,
|
||||
high REAL,
|
||||
low REAL,
|
||||
close REAL,
|
||||
pre_close REAL,
|
||||
change REAL,
|
||||
pct_chg REAL,
|
||||
vol REAL,
|
||||
amount REAL,
|
||||
PRIMARY KEY (ts_code, trade_date)
|
||||
);
|
||||
""",
|
||||
"suspend": """
|
||||
CREATE TABLE IF NOT EXISTS suspend (
|
||||
ts_code TEXT,
|
||||
suspend_date TEXT,
|
||||
resume_date TEXT,
|
||||
suspend_type TEXT,
|
||||
ann_date TEXT,
|
||||
suspend_timing TEXT,
|
||||
resume_timing TEXT,
|
||||
reason TEXT,
|
||||
PRIMARY KEY (ts_code, suspend_date)
|
||||
);
|
||||
""",
|
||||
"trade_calendar": """
|
||||
CREATE TABLE IF NOT EXISTS trade_calendar (
|
||||
exchange TEXT,
|
||||
cal_date TEXT PRIMARY KEY,
|
||||
is_open INTEGER,
|
||||
pretrade_date TEXT
|
||||
);
|
||||
""",
|
||||
"stk_limit": """
|
||||
CREATE TABLE IF NOT EXISTS stk_limit (
|
||||
ts_code TEXT,
|
||||
trade_date TEXT,
|
||||
up_limit REAL,
|
||||
down_limit REAL,
|
||||
PRIMARY KEY (ts_code, trade_date)
|
||||
);
|
||||
""",
|
||||
}
|
||||
|
||||
_TABLE_COLUMNS: Dict[str, List[str]] = {
|
||||
"daily": [
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"open",
|
||||
"high",
|
||||
"low",
|
||||
"close",
|
||||
"pre_close",
|
||||
"change",
|
||||
"pct_chg",
|
||||
"vol",
|
||||
"amount",
|
||||
],
|
||||
"suspend": [
|
||||
"ts_code",
|
||||
"suspend_date",
|
||||
"resume_date",
|
||||
"suspend_type",
|
||||
"ann_date",
|
||||
"suspend_timing",
|
||||
"resume_timing",
|
||||
"reason",
|
||||
],
|
||||
"trade_calendar": [
|
||||
"exchange",
|
||||
"cal_date",
|
||||
"is_open",
|
||||
"pretrade_date",
|
||||
],
|
||||
"stk_limit": [
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"up_limit",
|
||||
"down_limit",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _ensure_client():
|
||||
if ts is None:
|
||||
raise RuntimeError("未安装 tushare,请先在环境中安装 tushare 包")
|
||||
token = get_config().tushare_token or os.getenv("TUSHARE_TOKEN")
|
||||
if not token:
|
||||
raise RuntimeError("未配置 TuShare Token,请在配置文件或环境变量 TUSHARE_TOKEN 中设置")
|
||||
if not hasattr(_ensure_client, "_client") or _ensure_client._client is None: # type: ignore[attr-defined]
|
||||
ts.set_token(token)
|
||||
_ensure_client._client = ts.pro_api(token) # type: ignore[attr-defined]
|
||||
LOGGER.info("完成 TuShare 客户端初始化")
|
||||
return _ensure_client._client # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _format_date(value: date) -> str:
|
||||
return value.strftime("%Y%m%d")
|
||||
|
||||
|
||||
def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
|
||||
if df is None or df.empty:
|
||||
return []
|
||||
# 对缺失列进行补全,防止写库时缺少绑定参数
|
||||
reindexed = df.reindex(columns=allowed_cols)
|
||||
return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
|
||||
|
||||
|
||||
def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]:
|
||||
"""拉取日线行情。"""
|
||||
|
||||
client = _ensure_client()
|
||||
start_date = _format_date(job.start)
|
||||
end_date = _format_date(job.end)
|
||||
frames: List[pd.DataFrame] = []
|
||||
|
||||
if job.granularity != "daily":
|
||||
raise ValueError(f"暂不支持的粒度:{job.granularity}")
|
||||
|
||||
if job.ts_codes:
|
||||
for code in job.ts_codes:
|
||||
LOGGER.info("拉取 %s 的日线行情(%s-%s)", code, start_date, end_date)
|
||||
frames.append(client.daily(ts_code=code, start_date=start_date, end_date=end_date))
|
||||
else:
|
||||
LOGGER.info("按全市场拉取日线行情(%s-%s)", start_date, end_date)
|
||||
frames.append(client.daily(start_date=start_date, end_date=end_date))
|
||||
|
||||
if not frames:
|
||||
return []
|
||||
df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
|
||||
return _df_to_records(df, _TABLE_COLUMNS["daily"])
|
||||
|
||||
|
||||
def fetch_suspensions(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
|
||||
client = _ensure_client()
|
||||
start_date = _format_date(start)
|
||||
end_date = _format_date(end)
|
||||
LOGGER.info("拉取停复牌信息(%s-%s)", start_date, end_date)
|
||||
df = client.suspend_d(ts_code=ts_code, start_date=start_date, end_date=end_date)
|
||||
return _df_to_records(df, _TABLE_COLUMNS["suspend"])
|
||||
|
||||
|
||||
def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]:
|
||||
client = _ensure_client()
|
||||
start_date = _format_date(start)
|
||||
end_date = _format_date(end)
|
||||
LOGGER.info("拉取交易日历(交易所:%s,区间:%s-%s)", exchange, start_date, end_date)
|
||||
df = client.trade_cal(exchange=exchange, start_date=start_date, end_date=end_date)
|
||||
return _df_to_records(df, _TABLE_COLUMNS["trade_calendar"])
|
||||
|
||||
|
||||
def fetch_stk_limit(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
|
||||
client = _ensure_client()
|
||||
start_date = _format_date(start)
|
||||
end_date = _format_date(end)
|
||||
LOGGER.info("拉取涨跌停价格(%s-%s)", start_date, end_date)
|
||||
df = client.stk_limit(ts_code=ts_code, start_date=start_date, end_date=end_date)
|
||||
return _df_to_records(df, _TABLE_COLUMNS["stk_limit"])
|
||||
|
||||
|
||||
def save_records(table: str, rows: Iterable[Dict]) -> None:
|
||||
"""将拉取的数据写入 SQLite。"""
|
||||
|
||||
items = list(rows)
|
||||
if not items:
|
||||
LOGGER.info("表 %s 没有新增记录,跳过写入", table)
|
||||
return
|
||||
|
||||
schema = _TABLE_SCHEMAS.get(table)
|
||||
columns = _TABLE_COLUMNS.get(table)
|
||||
if not schema or not columns:
|
||||
raise ValueError(f"不支持写入的表:{table}")
|
||||
|
||||
placeholders = ",".join([f":{col}" for col in columns])
|
||||
col_clause = ",".join(columns)
|
||||
|
||||
LOGGER.info("表 %s 写入 %d 条记录", table, len(items))
|
||||
with db_session() as conn:
|
||||
conn.executescript(schema)
|
||||
conn.executemany(
|
||||
f"INSERT OR REPLACE INTO {table} ({col_clause}) VALUES ({placeholders})",
|
||||
items,
|
||||
)
|
||||
|
||||
|
||||
def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
|
||||
"""按任务配置拉取 TuShare 数据。"""
|
||||
|
||||
LOGGER.info("启动 TuShare 拉取任务:%s", job.name)
|
||||
|
||||
daily_rows = fetch_daily_bars(job)
|
||||
save_records("daily", daily_rows)
|
||||
|
||||
suspend_rows = fetch_suspensions(job.start, job.end)
|
||||
save_records("suspend", suspend_rows)
|
||||
|
||||
calendar_rows = fetch_trade_calendar(job.start, job.end)
|
||||
save_records("trade_calendar", calendar_rows)
|
||||
|
||||
if include_limits:
|
||||
limit_rows = fetch_stk_limit(job.start, job.end)
|
||||
save_records("stk_limit", limit_rows)
|
||||
|
||||
LOGGER.info("任务 %s 完成", job.name)
|
||||
0
app/llm/__init__.py
Normal file
0
app/llm/__init__.py
Normal file
18
app/llm/explain.py
Normal file
18
app/llm/explain.py
Normal file
@ -0,0 +1,18 @@
|
||||
"""LLM assisted explanations and summaries."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from .prompts import plan_prompt
|
||||
|
||||
|
||||
def make_human_card(ts_code: str, trade_date: str, context: Dict) -> Dict:
|
||||
"""Compose payload for UI cards and LLM requests."""
|
||||
|
||||
prompt = plan_prompt(context)
|
||||
return {
|
||||
"ts_code": ts_code,
|
||||
"trade_date": trade_date,
|
||||
"prompt": prompt,
|
||||
"context": context,
|
||||
}
|
||||
11
app/llm/prompts.py
Normal file
11
app/llm/prompts.py
Normal file
@ -0,0 +1,11 @@
|
||||
"""Prompt templates for natural language outputs."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def plan_prompt(data: Dict) -> str:
|
||||
"""Build a concise instruction prompt for the LLM."""
|
||||
|
||||
_ = data
|
||||
return "你是一个投资助理,请根据提供的数据给出三条要点和两条风险提示。"
|
||||
35
app/main.py
Normal file
35
app/main.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""Command line entry points for routine tasks."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
from app.backtest.engine import BtConfig, run_backtest
|
||||
from app.data.schema import initialize_database
|
||||
|
||||
|
||||
def init_db() -> None:
|
||||
result = initialize_database()
|
||||
if result.skipped:
|
||||
print("Database already initialized; skipping schema creation")
|
||||
else:
|
||||
print(f"Initialized database with {result.executed} statements")
|
||||
|
||||
|
||||
def run_sample_backtest() -> None:
|
||||
cfg = BtConfig(
|
||||
id="demo",
|
||||
name="Demo Strategy",
|
||||
start_date=date(2020, 1, 1),
|
||||
end_date=date(2020, 3, 31),
|
||||
universe=["000001.SZ"],
|
||||
params={
|
||||
"target": 0.035,
|
||||
"stop": -0.015,
|
||||
"hold_days": 10,
|
||||
},
|
||||
)
|
||||
run_backtest(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_db()
|
||||
0
app/ui/__init__.py
Normal file
0
app/ui/__init__.py
Normal file
85
app/ui/streamlit_app.py
Normal file
85
app/ui/streamlit_app.py
Normal file
@ -0,0 +1,85 @@
|
||||
"""Streamlit UI scaffold for the investment assistant."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from datetime import date
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from app.data.schema import initialize_database
|
||||
from app.ingest.tushare import FetchJob, run_ingestion
|
||||
from app.llm.explain import make_human_card
|
||||
|
||||
|
||||
def render_today_plan() -> None:
|
||||
st.header("今日计划")
|
||||
st.write("待接入候选池筛选与多智能体决策结果。")
|
||||
sample = make_human_card("000001.SZ", "2025-01-01", {"decisions": []})
|
||||
st.json(sample)
|
||||
|
||||
|
||||
def render_backtest() -> None:
|
||||
st.header("回测与复盘")
|
||||
st.write("在此运行回测、展示净值曲线与代理贡献。")
|
||||
st.button("开始回测")
|
||||
|
||||
|
||||
def render_settings() -> None:
|
||||
st.header("数据与设置")
|
||||
st.text_input("TuShare Token")
|
||||
st.write("新闻源开关与数据库备份将在此配置。")
|
||||
|
||||
|
||||
def render_tests() -> None:
|
||||
st.header("自检测试")
|
||||
st.write("用于快速检查数据库与数据拉取是否正常工作。")
|
||||
|
||||
if st.button("测试数据库初始化"):
|
||||
with st.spinner("正在检查数据库..."):
|
||||
result = initialize_database()
|
||||
if result.skipped:
|
||||
st.success("数据库已存在,检查通过。")
|
||||
else:
|
||||
st.success(f"数据库初始化完成,共执行 {result.executed} 条语句。")
|
||||
|
||||
st.divider()
|
||||
|
||||
if st.button("测试 TuShare 拉取(示例 2024-01-01 至 2024-01-03)"):
|
||||
with st.spinner("正在调用 TuShare 接口..."):
|
||||
try:
|
||||
run_ingestion(
|
||||
FetchJob(
|
||||
name="streamlit_self_test",
|
||||
start=date(2024, 1, 1),
|
||||
end=date(2024, 1, 3),
|
||||
ts_codes=("000001.SZ",),
|
||||
),
|
||||
include_limits=False,
|
||||
)
|
||||
st.success("TuShare 示例拉取完成,数据已写入数据库。")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
st.error(f"拉取失败:{exc}")
|
||||
|
||||
st.info("注意:TuShare 拉取依赖网络与 Token,若环境未配置将出现错误提示。")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
st.set_page_config(page_title="多智能体投资助理", layout="wide")
|
||||
tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"])
|
||||
with tabs[0]:
|
||||
render_today_plan()
|
||||
with tabs[1]:
|
||||
render_backtest()
|
||||
with tabs[2]:
|
||||
render_settings()
|
||||
with tabs[3]:
|
||||
render_tests()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
app/utils/__init__.py
Normal file
0
app/utils/__init__.py
Normal file
34
app/utils/calendar.py
Normal file
34
app/utils/calendar.py
Normal file
@ -0,0 +1,34 @@
|
||||
"""Trading calendar utilities.
|
||||
|
||||
These helpers abstract exchange calendars and trading day lookups. Real
|
||||
implementations should integrate with TuShare or cached calendars.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, timedelta
|
||||
from typing import Iterable, List
|
||||
|
||||
|
||||
def is_trading_day(day: date, holidays: Iterable[date] | None = None) -> bool:
|
||||
if day.weekday() >= 5:
|
||||
return False
|
||||
if holidays and day in set(holidays):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def previous_trading_day(day: date, holidays: Iterable[date] | None = None) -> date:
|
||||
current = day - timedelta(days=1)
|
||||
while not is_trading_day(current, holidays):
|
||||
current -= timedelta(days=1)
|
||||
return current
|
||||
|
||||
|
||||
def trading_days_between(start: date, end: date, holidays: Iterable[date] | None = None) -> List[date]:
|
||||
current = start
|
||||
days: List[date] = []
|
||||
while current <= end:
|
||||
if is_trading_day(current, holidays):
|
||||
days.append(current)
|
||||
current += timedelta(days=1)
|
||||
return days
|
||||
65
app/utils/config.py
Normal file
65
app/utils/config.py
Normal file
@ -0,0 +1,65 @@
|
||||
"""Application configuration models and helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
def _default_root() -> Path:
|
||||
return Path(__file__).resolve().parents[2] / "app" / "data"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataPaths:
|
||||
"""Holds filesystem locations for persistent artifacts."""
|
||||
|
||||
root: Path = field(default_factory=_default_root)
|
||||
database: Path = field(init=False)
|
||||
backups: Path = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.root.mkdir(parents=True, exist_ok=True)
|
||||
self.database = self.root / "llm_quant.db"
|
||||
self.backups = self.root / "backups"
|
||||
self.backups.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentWeights:
|
||||
"""Default weighting for decision agents."""
|
||||
|
||||
momentum: float = 0.30
|
||||
value: float = 0.20
|
||||
news: float = 0.20
|
||||
liquidity: float = 0.15
|
||||
macro: float = 0.15
|
||||
|
||||
def as_dict(self) -> Dict[str, float]:
|
||||
return {
|
||||
"A_mom": self.momentum,
|
||||
"A_val": self.value,
|
||||
"A_news": self.news,
|
||||
"A_liq": self.liquidity,
|
||||
"A_macro": self.macro,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppConfig:
|
||||
"""User configurable settings persisted in a simple structure."""
|
||||
|
||||
tushare_token: Optional[str] = None
|
||||
rss_sources: Dict[str, bool] = field(default_factory=dict)
|
||||
decision_method: str = "nash"
|
||||
data_paths: DataPaths = field(default_factory=DataPaths)
|
||||
agent_weights: AgentWeights = field(default_factory=AgentWeights)
|
||||
|
||||
|
||||
CONFIG = AppConfig()
|
||||
|
||||
|
||||
def get_config() -> AppConfig:
|
||||
"""Return a mutable global configuration instance."""
|
||||
|
||||
return CONFIG
|
||||
47
app/utils/db.py
Normal file
47
app/utils/db.py
Normal file
@ -0,0 +1,47 @@
|
||||
"""SQLite helpers for application data access."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Generator, Iterable, Tuple
|
||||
|
||||
from .config import get_config
|
||||
|
||||
|
||||
def get_connection(read_only: bool = False) -> sqlite3.Connection:
|
||||
"""Create a SQLite connection against the configured database file."""
|
||||
|
||||
db_path: Path = get_config().data_paths.database
|
||||
uri = f"file:{db_path}?mode={'ro' if read_only else 'rw'}"
|
||||
if not db_path.exists() and not read_only:
|
||||
# Ensure directory exists before first write.
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(db_path)
|
||||
else:
|
||||
conn = sqlite3.connect(uri, uri=True)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
|
||||
@contextmanager
|
||||
def db_session(read_only: bool = False) -> Generator[sqlite3.Connection, None, None]:
|
||||
"""Context manager yielding a connection with automatic commit/rollback."""
|
||||
|
||||
conn = get_connection(read_only=read_only)
|
||||
try:
|
||||
yield conn
|
||||
if not read_only:
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def execute_many(sql: str, params: Iterable[Tuple]) -> None:
|
||||
"""Bulk execute a parameterized statement inside a write session."""
|
||||
|
||||
with db_session() as conn:
|
||||
conn.executemany(sql, params)
|
||||
28
app/utils/logging.py
Normal file
28
app/utils/logging.py
Normal file
@ -0,0 +1,28 @@
|
||||
"""Centralized logging configuration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from .config import get_config
|
||||
|
||||
|
||||
def configure_logging(level: int = logging.INFO) -> None:
|
||||
"""Setup root logger with file and console handlers."""
|
||||
|
||||
cfg = get_config()
|
||||
log_dir = cfg.data_paths.root / "logs"
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
logfile = log_dir / "app.log"
|
||||
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(asctime)s %(levelname)s %(name)s - %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(logfile, encoding="utf-8"),
|
||||
logging.StreamHandler(),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
configure_logging()
|
||||
Loading…
Reference in New Issue
Block a user