This commit is contained in:
sam 2025-10-05 15:26:01 +08:00
parent 16a5fae732
commit adfc8ee148
9 changed files with 662 additions and 46 deletions

View File

@ -9,6 +9,17 @@ from app.utils.db import db_session
SCHEMA_STATEMENTS: Iterable[str] = ( SCHEMA_STATEMENTS: Iterable[str] = (
"""
CREATE TABLE IF NOT EXISTS fetch_jobs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
job_type TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
error_msg TEXT,
metadata TEXT -- JSON object for additional info
);
""",
""" """
CREATE TABLE IF NOT EXISTS stock_basic ( CREATE TABLE IF NOT EXISTS stock_basic (
ts_code TEXT PRIMARY KEY, ts_code TEXT PRIMARY KEY,
@ -515,17 +526,32 @@ def _missing_tables() -> List[str]:
return [name for name in REQUIRED_TABLES if name not in existing] return [name for name in REQUIRED_TABLES if name not in existing]
def initialize_database() -> MigrationResult: def initialize_database() -> None:
"""Create tables and indexes required by the application.""" """Initialize the SQLite database with all required tables."""
with db_session() as session:
cursor = session.cursor()
missing = _missing_tables() # 创建表
if not missing:
return MigrationResult(executed=0, skipped=True, missing_tables=[])
executed = 0
with db_session() as conn:
cursor = conn.cursor()
for statement in SCHEMA_STATEMENTS: for statement in SCHEMA_STATEMENTS:
cursor.executescript(statement) try:
executed += 1 cursor.execute(statement)
return MigrationResult(executed=executed, skipped=False, missing_tables=missing) except Exception as e: # noqa: BLE001
print(f"初始化数据库时出错: {e}")
raise
# 添加触发器以自动更新 updated_at 字段
try:
cursor.execute("""
CREATE TRIGGER IF NOT EXISTS update_fetch_jobs_timestamp
AFTER UPDATE ON fetch_jobs
BEGIN
UPDATE fetch_jobs
SET updated_at = CURRENT_TIMESTAMP
WHERE id = NEW.id;
END;
""")
except Exception as e: # noqa: BLE001
print(f"创建触发器时出错: {e}")
raise
session.commit()

85
app/ingest/job_logger.py Normal file
View File

@ -0,0 +1,85 @@
"""任务记录工具类。"""
from __future__ import annotations
import json
from datetime import datetime
from typing import Any, Dict, Optional
from app.utils.db import db_session
class JobLogger:
"""任务记录器。"""
def __init__(self, job_type: str) -> None:
"""初始化任务记录器。
Args:
job_type: 任务类型
"""
self.job_type = job_type
self.job_id: Optional[int] = None
def __enter__(self) -> "JobLogger":
"""开始记录任务。"""
with db_session() as session:
cursor = session.execute(
"""
INSERT INTO fetch_jobs (job_type, status, created_at, updated_at)
VALUES (?, 'running', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
""",
(self.job_type,)
)
self.job_id = cursor.lastrowid
session.commit()
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""结束任务记录。"""
if exc_val:
self.update_status("failed", str(exc_val))
else:
self.update_status("success")
def update_status(self, status: str, error_msg: Optional[str] = None) -> None:
"""更新任务状态。
Args:
status: 新状态
error_msg: 错误信息如果有
"""
if not self.job_id:
return
with db_session() as session:
session.execute(
"""
UPDATE fetch_jobs
SET status = ?,
error_msg = ?,
updated_at = CURRENT_TIMESTAMP
WHERE id = ?
""",
(status, error_msg, self.job_id)
)
session.commit()
def update_metadata(self, metadata: Dict[str, Any]) -> None:
"""更新任务元数据。
Args:
metadata: 元数据字典
"""
if not self.job_id:
return
with db_session() as session:
session.execute(
"""
UPDATE fetch_jobs
SET metadata = ?
WHERE id = ?
""",
(json.dumps(metadata), self.job_id)
)
session.commit()

View File

@ -227,6 +227,9 @@ def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None
return merged return merged
from .job_logger import JobLogger
@dataclass @dataclass
class FetchJob: class FetchJob:
name: str name: str
@ -1602,8 +1605,17 @@ def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]
def run_ingestion(job: FetchJob, include_limits: bool = True) -> None: def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
"""运行数据拉取任务。
Args:
job: 任务配置
include_limits: 是否包含涨跌停数据
"""
with JobLogger("TuShare数据获取") as logger:
LOGGER.info("启动 TuShare 拉取任务:%s", job.name, extra=LOG_EXTRA) LOGGER.info("启动 TuShare 拉取任务:%s", job.name, extra=LOG_EXTRA)
try: try:
# 拉取基础数据
ensure_data_coverage( ensure_data_coverage(
job.start, job.start,
job.end, job.end,
@ -1612,25 +1624,39 @@ def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
include_extended=True, include_extended=True,
force=True, force=True,
) )
except Exception as exc:
alerts.add_warning("TuShare", f"拉取任务失败:{job.name}", str(exc)) # 记录任务元数据
raise logger.update_metadata({
else: "name": job.name,
"start": str(job.start),
"end": str(job.end),
"codes": len(job.ts_codes) if job.ts_codes else 0
})
alerts.clear_warnings("TuShare") alerts.clear_warnings("TuShare")
# 对日线数据计算因子
if job.granularity == "daily": if job.granularity == "daily":
try:
LOGGER.info("开始计算因子:%s", job.name, extra=LOG_EXTRA) LOGGER.info("开始计算因子:%s", job.name, extra=LOG_EXTRA)
try:
compute_factor_range( compute_factor_range(
job.start, job.start,
job.end, job.end,
ts_codes=job.ts_codes, ts_codes=job.ts_codes,
skip_existing=False, skip_existing=False,
) )
except Exception as exc:
alerts.add_warning("Factors", f"因子计算失败:{job.name}", str(exc))
LOGGER.exception("因子计算失败 job=%s", job.name, extra=LOG_EXTRA)
raise
else:
alerts.clear_warnings("Factors") alerts.clear_warnings("Factors")
except Exception as exc:
LOGGER.exception("因子计算失败 job=%s", job.name, extra=LOG_EXTRA)
alerts.add_warning("Factors", f"因子计算失败:{job.name}", str(exc))
logger.update_status("failed", f"因子计算失败:{exc}")
raise
LOGGER.info("因子计算完成:%s", job.name, extra=LOG_EXTRA) LOGGER.info("因子计算完成:%s", job.name, extra=LOG_EXTRA)
alerts.clear_warnings("Factors")
except Exception as exc:
LOGGER.exception("数据拉取失败 job=%s", job.name, extra=LOG_EXTRA)
alerts.add_warning("TuShare", f"拉取任务失败:{job.name}", str(exc))
logger.update_status("failed", f"数据拉取失败:{exc}")
raise
LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA) LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA)

159
app/ui/portfolio_config.py Normal file
View File

@ -0,0 +1,159 @@
"""Portfolio configuration UI components."""
from __future__ import annotations
import streamlit as st
import numpy as np
import pandas as pd
from app.utils.portfolio_init import get_portfolio_config, update_portfolio_config
def render_portfolio_config() -> None:
"""渲染投资组合配置界面."""
st.title("投资组合配置")
# 获取当前配置
config = get_portfolio_config()
# 基本配置部分
st.header("基本配置")
col1, col2 = st.columns(2)
with col1:
initial_capital = st.number_input(
"初始投资金额",
min_value=100000,
max_value=100000000,
value=int(config["initial_capital"]),
step=100000,
format="%d"
)
with col2:
currency = st.selectbox(
"币种",
options=["CNY", "USD", "HKD"],
index=["CNY", "USD", "HKD"].index(config["currency"])
)
# 仓位限制配置
st.header("仓位限制")
position_limits = config["position_limits"]
col1, col2 = st.columns(2)
with col1:
max_position = st.slider(
"单个持仓上限",
min_value=0.05,
max_value=0.50,
value=float(position_limits["max_position"]),
step=0.01,
format="%.2f",
help="单个股票最大持仓比例"
)
min_position = st.slider(
"单个持仓下限",
min_value=0.01,
max_value=0.10,
value=float(position_limits["min_position"]),
step=0.01,
format="%.2f",
help="单个股票最小持仓比例"
)
with col2:
max_total_positions = st.slider(
"最大持仓数",
min_value=5,
max_value=50,
value=int(position_limits["max_total_positions"]),
step=1,
help="投资组合中的最大股票数量"
)
max_sector_exposure = st.slider(
"行业敞口上限",
min_value=0.20,
max_value=0.50,
value=float(position_limits["max_sector_exposure"]),
step=0.05,
format="%.2f",
help="单个行业的最大持仓比例"
)
# 配置预览
st.header("当前配置概览")
df = pd.DataFrame([
["初始资金", f"{initial_capital:,} {currency}"],
["单个持仓上限", f"{max_position:.1%}"],
["单个持仓下限", f"{min_position:.1%}"],
["最大持仓数", max_total_positions],
["行业敞口上限", f"{max_sector_exposure:.1%}"],
], columns=["配置项", "当前值"])
st.table(df.set_index("配置项"))
# 保存按钮
if st.button("保存配置"):
try:
update_portfolio_config({
"initial_capital": initial_capital,
"currency": currency,
"position_limits": {
"max_position": max_position,
"min_position": min_position,
"max_total_positions": max_total_positions,
"max_sector_exposure": max_sector_exposure
}
})
st.success("配置已更新!")
except Exception as e:
st.error(f"配置更新失败:{str(e)}")
# 投资组合限制可视化
st.header("仓位限制可视化")
# 生成示例数据
example_positions = np.random.uniform(
min_position,
max_position,
min(max_total_positions, 10)
)
example_positions = example_positions / example_positions.sum()
example_sectors = {
"科技": 0.30,
"金融": 0.25,
"消费": 0.20,
"医药": 0.15,
"其他": 0.10
}
col1, col2 = st.columns(2)
with col1:
st.subheader("示例持仓分布")
positions_df = pd.DataFrame({
"股票": [f"股票{i+1}" for i in range(len(example_positions))],
"持仓比例": example_positions
})
positions_df = positions_df.sort_values("持仓比例", ascending=True)
st.bar_chart(
positions_df.set_index("股票")["持仓比例"],
use_container_width=True
)
with col2:
st.subheader("示例行业分布")
sectors_df = pd.DataFrame({
"行业": list(example_sectors.keys()),
"敞口": list(example_sectors.values())
})
st.bar_chart(
sectors_df.set_index("行业")["敞口"],
use_container_width=True
)

View File

@ -25,6 +25,7 @@ import streamlit as st
from app.agents.base import AgentContext from app.agents.base import AgentContext
from app.agents.game import Decision from app.agents.game import Decision
from app.backtest.engine import BtConfig, run_backtest from app.backtest.engine import BtConfig, run_backtest
from app.ui.portfolio_config import render_portfolio_config
from app.backtest.decision_env import DecisionEnv, ParameterSpec from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.data.schema import initialize_database from app.data.schema import initialize_database
from app.ingest.checker import run_boot_check from app.ingest.checker import run_boot_check
@ -2667,10 +2668,79 @@ def render_tests() -> None:
st.write(response) st.write(response)
def render_data_settings() -> None:
"""渲染数据源配置界面."""
st.subheader("Tushare 数据源")
cfg = get_config()
col1, col2 = st.columns(2)
with col1:
tushare_token = st.text_input(
"Tushare Token",
value=cfg.tushare_token or "",
type="password",
help="从 tushare.pro 获取的 API token"
)
with col2:
auto_update = st.checkbox(
"启动时自动更新数据",
value=cfg.auto_update_data,
help="启动应用时自动检查并更新数据"
)
update_interval = st.slider(
"数据更新间隔(天)",
min_value=1,
max_value=30,
value=cfg.data_update_interval,
help="自动更新时检查的数据时间范围"
)
if st.button("保存数据源配置"):
cfg.tushare_token = tushare_token
cfg.auto_update_data = auto_update
cfg.data_update_interval = update_interval
save_config(cfg)
st.success("数据源配置已更新!")
st.divider()
st.subheader("数据更新记录")
with db_session() as session:
df = pd.read_sql_query(
"""
SELECT job_type, status, created_at, updated_at, error_msg
FROM fetch_jobs
ORDER BY created_at DESC
LIMIT 50
""",
session
)
if not df.empty:
df["duration"] = (df["updated_at"] - df["created_at"]).dt.total_seconds().round(2)
df = df.drop(columns=["updated_at"])
df = df.rename(columns={
"job_type": "数据类型",
"status": "状态",
"created_at": "开始时间",
"error_msg": "错误信息",
"duration": "耗时(秒)"
})
st.dataframe(df, width='stretch')
else:
st.info("暂无数据更新记录")
def main() -> None: def main() -> None:
LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA) LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA)
st.set_page_config(page_title="多智能体个人投资助理", layout="wide") st.set_page_config(page_title="多智能体个人投资助理", layout="wide")
# 确保数据库表已创建
from app.data.schema import initialize_database
initialize_database()
# 检查是否需要自动更新数据 # 检查是否需要自动更新数据
cfg = get_config() cfg = get_config()
if cfg.auto_update_data: if cfg.auto_update_data:
@ -2713,7 +2783,18 @@ def main() -> None:
with tabs[1]: with tabs[1]:
render_log_viewer() render_log_viewer()
with tabs[2]: with tabs[2]:
st.header("系统设置")
settings_tabs = st.tabs(["基本配置", "投资组合", "数据源"])
with settings_tabs[0]:
render_settings() render_settings()
with settings_tabs[1]:
from app.ui.portfolio_config import render_portfolio_config
render_portfolio_config()
with settings_tabs[2]:
render_data_settings()
with tabs[3]: with tabs[3]:
render_tests() render_tests()

View File

@ -39,6 +39,18 @@ class DataPaths:
self.config_file = config_path self.config_file = config_path
@dataclass
class PortfolioSettings:
"""Portfolio configuration settings."""
initial_capital: float = 1000000 # 默认100万
currency: str = "CNY" # 默认人民币
max_position: float = 0.2 # 单个持仓上限 20%
min_position: float = 0.02 # 单个持仓下限 2%
max_total_positions: int = 20 # 最大持仓数
max_sector_exposure: float = 0.35 # 行业敞口上限 35%
@dataclass @dataclass
class AgentWeights: class AgentWeights:
"""Default weighting for decision agents.""" """Default weighting for decision agents."""
@ -340,9 +352,11 @@ class AppConfig:
agent_weights: AgentWeights = field(default_factory=AgentWeights) agent_weights: AgentWeights = field(default_factory=AgentWeights)
force_refresh: bool = False force_refresh: bool = False
auto_update_data: bool = False auto_update_data: bool = False
data_update_interval: int = 7 # 数据更新间隔(天)
llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers) llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers)
llm: LLMConfig = field(default_factory=LLMConfig) llm: LLMConfig = field(default_factory=LLMConfig)
departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments) departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
portfolio: PortfolioSettings = field(default_factory=PortfolioSettings)
def resolve_llm(self, route: Optional[str] = None) -> LLMConfig: def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
return self.llm return self.llm

View File

@ -7,6 +7,7 @@ from typing import Any, Dict, Iterable, List, Optional
from .db import db_session from .db import db_session
from .logging import get_logger from .logging import get_logger
from .portfolio_init import get_portfolio_config
LOGGER = get_logger(__name__) LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "portfolio"} LOG_EXTRA = {"stage": "portfolio"}
@ -169,8 +170,11 @@ class PortfolioSnapshot:
def get_latest_snapshot() -> Optional[PortfolioSnapshot]: def get_latest_snapshot() -> Optional[PortfolioSnapshot]:
"""Fetch the most recent portfolio snapshot.""" """Fetch the most recent portfolio snapshot.
Returns:
最新的投资组合快照如果没有数据则返回初始快照仅包含初始资金
"""
sql = """ sql = """
SELECT trade_date, total_value, cash, invested_value, unrealized_pnl, SELECT trade_date, total_value, cash, invested_value, unrealized_pnl,
realized_pnl, net_flow, exposure, notes, metadata realized_pnl, net_flow, exposure, notes, metadata
@ -186,7 +190,22 @@ def get_latest_snapshot() -> Optional[PortfolioSnapshot]:
return None return None
if not row: if not row:
return None # 如果没有快照,返回初始状态(只有初始资金)
config = get_portfolio_config()
initial_capital = config["initial_capital"]
return PortfolioSnapshot(
trade_date="", # 空日期表示初始状态
total_value=initial_capital,
cash=initial_capital,
invested_value=0.0,
unrealized_pnl=0.0,
realized_pnl=0.0,
net_flow=0.0,
exposure=0.0,
notes="Initial portfolio state",
metadata={"initial_capital": initial_capital, "currency": config["currency"]},
)
return PortfolioSnapshot( return PortfolioSnapshot(
trade_date=row["trade_date"], trade_date=row["trade_date"],
total_value=row["total_value"], total_value=row["total_value"],

149
app/utils/portfolio_init.py Normal file
View File

@ -0,0 +1,149 @@
"""Initialize portfolio database tables."""
from __future__ import annotations
import json
from typing import Any
from .logging import get_logger
from .config import get_config
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "portfolio_init"}
def get_portfolio_config() -> dict[str, Any]:
"""获取投资组合配置.
Returns:
包含以下字段的字典:
- initial_capital: 初始投资金额
- currency: 货币类型
- position_limits: 仓位限制
"""
config = get_config()
settings = config.portfolio if hasattr(config, "portfolio") else None
if not settings:
from .config import PortfolioSettings
settings = PortfolioSettings()
return {
"initial_capital": settings.initial_capital,
"currency": settings.currency,
"position_limits": {
"max_position": settings.max_position,
"min_position": settings.min_position,
"max_total_positions": settings.max_total_positions,
"max_sector_exposure": settings.max_sector_exposure
}
}
def update_portfolio_config(updates: dict[str, Any]) -> None:
"""更新投资组合配置.
Args:
updates: 要更新的配置项字典
"""
from .config import get_config, save_config, PortfolioSettings
# 获取当前配置
config = get_config()
# 创建新的投资组合设置
portfolio = PortfolioSettings(
initial_capital=updates["initial_capital"],
currency=updates["currency"],
max_position=updates["position_limits"]["max_position"],
min_position=updates["position_limits"]["min_position"],
max_total_positions=updates["position_limits"]["max_total_positions"],
max_sector_exposure=updates["position_limits"]["max_sector_exposure"]
)
# 更新配置
config.portfolio = portfolio
save_config(config)
SCHEMA_STATEMENTS = [
# 投资池表
"""
CREATE TABLE IF NOT EXISTS investment_pool (
trade_date TEXT,
ts_code TEXT,
score REAL,
status TEXT,
rationale TEXT,
tags TEXT, -- JSON array
metadata TEXT, -- JSON object
PRIMARY KEY (trade_date, ts_code)
)
""",
# 数据获取任务表
"""
CREATE TABLE IF NOT EXISTS fetch_jobs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
job_type TEXT NOT NULL,
status TEXT NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
error_msg TEXT,
metadata TEXT -- JSON object for additional info
)
);
""",
# 持仓表
"""
CREATE TABLE IF NOT EXISTS portfolio_positions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ts_code TEXT NOT NULL,
opened_date TEXT NOT NULL,
closed_date TEXT,
quantity REAL NOT NULL,
cost_price REAL NOT NULL,
market_price REAL,
market_value REAL,
realized_pnl REAL DEFAULT 0,
unrealized_pnl REAL DEFAULT 0,
target_weight REAL,
status TEXT NOT NULL DEFAULT 'open',
notes TEXT,
metadata TEXT -- JSON object
);
""",
# 投资组合快照表
"""
CREATE TABLE IF NOT EXISTS portfolio_snapshots (
trade_date TEXT PRIMARY KEY,
total_value REAL,
cash REAL,
invested_value REAL,
unrealized_pnl REAL,
realized_pnl REAL,
net_flow REAL,
exposure REAL,
notes TEXT,
metadata TEXT -- JSON object
);
""",
]
def initialize_database_schema() -> None:
"""Create database tables if they don't exist."""
from .db import db_session
with db_session() as conn:
for statement in SCHEMA_STATEMENTS:
try:
conn.execute(statement)
except Exception: # noqa: BLE001
LOGGER.exception(
"执行 schema 语句失败",
extra={"sql": statement, **LOG_EXTRA}
)
raise

View File

@ -0,0 +1,57 @@
"""Test portfolio configuration and initialization."""
from unittest.mock import patch, MagicMock
from app.utils.portfolio import get_latest_snapshot
from app.utils.db import db_session
def test_default_portfolio_config():
"""Test default portfolio configuration."""
# Mock db_session as a context manager
mock_session = MagicMock()
mock_session.__enter__ = MagicMock(return_value=mock_session)
mock_session.__exit__ = MagicMock(return_value=None)
# Mock the database query result
mock_session.execute.return_value.fetchone.return_value = None
# 使用默认配置
with patch("app.utils.portfolio.get_portfolio_config") as mock_config, \
patch("app.utils.portfolio.db_session", return_value=mock_session):
mock_config.return_value = {
"initial_capital": 1000000,
"currency": "CNY"
}
snapshot = get_latest_snapshot()
assert snapshot is not None
assert snapshot.total_value == 1000000
assert snapshot.cash == 1000000
assert snapshot.metadata["initial_capital"] == 1000000
assert snapshot.metadata["currency"] == "CNY"
def test_custom_portfolio_config():
"""Test custom portfolio configuration."""
# Mock db_session as a context manager
mock_session = MagicMock()
mock_session.__enter__ = MagicMock(return_value=mock_session)
mock_session.__exit__ = MagicMock(return_value=None)
# Mock the database query result
mock_session.execute.return_value.fetchone.return_value = None
# 使用自定义配置
with patch("app.utils.portfolio.get_portfolio_config") as mock_config, \
patch("app.utils.portfolio.db_session", return_value=mock_session):
mock_config.return_value = {
"initial_capital": 2000000,
"currency": "USD"
}
snapshot = get_latest_snapshot()
assert snapshot is not None
assert snapshot.total_value == 2000000
assert snapshot.cash == 2000000
assert snapshot.metadata["initial_capital"] == 2000000
assert snapshot.metadata["currency"] == "USD"