From adfc8ee1487702a6b32c64bf34bbdbc3ef48e43f Mon Sep 17 00:00:00 2001 From: sam Date: Sun, 5 Oct 2025 15:26:01 +0800 Subject: [PATCH] update --- app/data/schema.py | 52 ++++++++--- app/ingest/job_logger.py | 85 ++++++++++++++++++ app/ingest/tushare.py | 84 +++++++++++------ app/ui/portfolio_config.py | 159 +++++++++++++++++++++++++++++++++ app/ui/streamlit_app.py | 83 ++++++++++++++++- app/utils/config.py | 14 +++ app/utils/portfolio.py | 25 +++++- app/utils/portfolio_init.py | 149 ++++++++++++++++++++++++++++++ tests/test_portfolio_config.py | 57 ++++++++++++ 9 files changed, 662 insertions(+), 46 deletions(-) create mode 100644 app/ingest/job_logger.py create mode 100644 app/ui/portfolio_config.py create mode 100644 app/utils/portfolio_init.py create mode 100644 tests/test_portfolio_config.py diff --git a/app/data/schema.py b/app/data/schema.py index 4fa435e..a23af34 100644 --- a/app/data/schema.py +++ b/app/data/schema.py @@ -9,6 +9,17 @@ from app.utils.db import db_session SCHEMA_STATEMENTS: Iterable[str] = ( + """ + CREATE TABLE IF NOT EXISTS fetch_jobs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + job_type TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + error_msg TEXT, + metadata TEXT -- JSON object for additional info + ); + """, """ CREATE TABLE IF NOT EXISTS stock_basic ( ts_code TEXT PRIMARY KEY, @@ -515,17 +526,32 @@ def _missing_tables() -> List[str]: return [name for name in REQUIRED_TABLES if name not in existing] -def initialize_database() -> MigrationResult: - """Create tables and indexes required by the application.""" - - missing = _missing_tables() - if not missing: - return MigrationResult(executed=0, skipped=True, missing_tables=[]) - - executed = 0 - with db_session() as conn: - cursor = conn.cursor() +def initialize_database() -> None: + """Initialize the SQLite database with all required tables.""" + with db_session() as session: + cursor = session.cursor() + + # 创建表 for statement in SCHEMA_STATEMENTS: - cursor.executescript(statement) - executed += 1 - return MigrationResult(executed=executed, skipped=False, missing_tables=missing) + try: + cursor.execute(statement) + except Exception as e: # noqa: BLE001 + print(f"初始化数据库时出错: {e}") + raise + + # 添加触发器以自动更新 updated_at 字段 + try: + cursor.execute(""" + CREATE TRIGGER IF NOT EXISTS update_fetch_jobs_timestamp + AFTER UPDATE ON fetch_jobs + BEGIN + UPDATE fetch_jobs + SET updated_at = CURRENT_TIMESTAMP + WHERE id = NEW.id; + END; + """) + except Exception as e: # noqa: BLE001 + print(f"创建触发器时出错: {e}") + raise + + session.commit() diff --git a/app/ingest/job_logger.py b/app/ingest/job_logger.py new file mode 100644 index 0000000..739d94e --- /dev/null +++ b/app/ingest/job_logger.py @@ -0,0 +1,85 @@ +"""任务记录工具类。""" +from __future__ import annotations + +import json +from datetime import datetime +from typing import Any, Dict, Optional + +from app.utils.db import db_session + + +class JobLogger: + """任务记录器。""" + + def __init__(self, job_type: str) -> None: + """初始化任务记录器。 + + Args: + job_type: 任务类型 + """ + self.job_type = job_type + self.job_id: Optional[int] = None + + def __enter__(self) -> "JobLogger": + """开始记录任务。""" + with db_session() as session: + cursor = session.execute( + """ + INSERT INTO fetch_jobs (job_type, status, created_at, updated_at) + VALUES (?, 'running', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + """, + (self.job_type,) + ) + self.job_id = cursor.lastrowid + session.commit() + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """结束任务记录。""" + if exc_val: + self.update_status("failed", str(exc_val)) + else: + self.update_status("success") + + def update_status(self, status: str, error_msg: Optional[str] = None) -> None: + """更新任务状态。 + + Args: + status: 新状态 + error_msg: 错误信息(如果有) + """ + if not self.job_id: + return + + with db_session() as session: + session.execute( + """ + UPDATE fetch_jobs + SET status = ?, + error_msg = ?, + updated_at = CURRENT_TIMESTAMP + WHERE id = ? + """, + (status, error_msg, self.job_id) + ) + session.commit() + + def update_metadata(self, metadata: Dict[str, Any]) -> None: + """更新任务元数据。 + + Args: + metadata: 元数据字典 + """ + if not self.job_id: + return + + with db_session() as session: + session.execute( + """ + UPDATE fetch_jobs + SET metadata = ? + WHERE id = ? + """, + (json.dumps(metadata), self.job_id) + ) + session.commit() diff --git a/app/ingest/tushare.py b/app/ingest/tushare.py index 8343f4a..98551a0 100644 --- a/app/ingest/tushare.py +++ b/app/ingest/tushare.py @@ -227,6 +227,9 @@ def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None return merged +from .job_logger import JobLogger + + @dataclass class FetchJob: name: str @@ -1602,35 +1605,58 @@ def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object] def run_ingestion(job: FetchJob, include_limits: bool = True) -> None: - LOGGER.info("启动 TuShare 拉取任务:%s", job.name, extra=LOG_EXTRA) - try: - ensure_data_coverage( - job.start, - job.end, - ts_codes=job.ts_codes, - include_limits=include_limits, - include_extended=True, - force=True, - ) - except Exception as exc: - alerts.add_warning("TuShare", f"拉取任务失败:{job.name}", str(exc)) - raise - else: - alerts.clear_warnings("TuShare") - if job.granularity == "daily": - try: + """运行数据拉取任务。 + + Args: + job: 任务配置 + include_limits: 是否包含涨跌停数据 + """ + with JobLogger("TuShare数据获取") as logger: + LOGGER.info("启动 TuShare 拉取任务:%s", job.name, extra=LOG_EXTRA) + + try: + # 拉取基础数据 + ensure_data_coverage( + job.start, + job.end, + ts_codes=job.ts_codes, + include_limits=include_limits, + include_extended=True, + force=True, + ) + + # 记录任务元数据 + logger.update_metadata({ + "name": job.name, + "start": str(job.start), + "end": str(job.end), + "codes": len(job.ts_codes) if job.ts_codes else 0 + }) + + alerts.clear_warnings("TuShare") + + # 对日线数据计算因子 + if job.granularity == "daily": LOGGER.info("开始计算因子:%s", job.name, extra=LOG_EXTRA) - compute_factor_range( - job.start, - job.end, - ts_codes=job.ts_codes, - skip_existing=False, - ) - except Exception as exc: - alerts.add_warning("Factors", f"因子计算失败:{job.name}", str(exc)) - LOGGER.exception("因子计算失败 job=%s", job.name, extra=LOG_EXTRA) - raise - else: - alerts.clear_warnings("Factors") + try: + compute_factor_range( + job.start, + job.end, + ts_codes=job.ts_codes, + skip_existing=False, + ) + alerts.clear_warnings("Factors") + except Exception as exc: + LOGGER.exception("因子计算失败 job=%s", job.name, extra=LOG_EXTRA) + alerts.add_warning("Factors", f"因子计算失败:{job.name}", str(exc)) + logger.update_status("failed", f"因子计算失败:{exc}") + raise LOGGER.info("因子计算完成:%s", job.name, extra=LOG_EXTRA) + alerts.clear_warnings("Factors") + + except Exception as exc: + LOGGER.exception("数据拉取失败 job=%s", job.name, extra=LOG_EXTRA) + alerts.add_warning("TuShare", f"拉取任务失败:{job.name}", str(exc)) + logger.update_status("failed", f"数据拉取失败:{exc}") + raise LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA) diff --git a/app/ui/portfolio_config.py b/app/ui/portfolio_config.py new file mode 100644 index 0000000..869783f --- /dev/null +++ b/app/ui/portfolio_config.py @@ -0,0 +1,159 @@ +"""Portfolio configuration UI components.""" +from __future__ import annotations + +import streamlit as st +import numpy as np +import pandas as pd + +from app.utils.portfolio_init import get_portfolio_config, update_portfolio_config + + +def render_portfolio_config() -> None: + """渲染投资组合配置界面.""" + st.title("投资组合配置") + + # 获取当前配置 + config = get_portfolio_config() + + # 基本配置部分 + st.header("基本配置") + col1, col2 = st.columns(2) + + with col1: + initial_capital = st.number_input( + "初始投资金额", + min_value=100000, + max_value=100000000, + value=int(config["initial_capital"]), + step=100000, + format="%d" + ) + + with col2: + currency = st.selectbox( + "币种", + options=["CNY", "USD", "HKD"], + index=["CNY", "USD", "HKD"].index(config["currency"]) + ) + + # 仓位限制配置 + st.header("仓位限制") + position_limits = config["position_limits"] + + col1, col2 = st.columns(2) + + with col1: + max_position = st.slider( + "单个持仓上限", + min_value=0.05, + max_value=0.50, + value=float(position_limits["max_position"]), + step=0.01, + format="%.2f", + help="单个股票最大持仓比例" + ) + + min_position = st.slider( + "单个持仓下限", + min_value=0.01, + max_value=0.10, + value=float(position_limits["min_position"]), + step=0.01, + format="%.2f", + help="单个股票最小持仓比例" + ) + + with col2: + max_total_positions = st.slider( + "最大持仓数", + min_value=5, + max_value=50, + value=int(position_limits["max_total_positions"]), + step=1, + help="投资组合中的最大股票数量" + ) + + max_sector_exposure = st.slider( + "行业敞口上限", + min_value=0.20, + max_value=0.50, + value=float(position_limits["max_sector_exposure"]), + step=0.05, + format="%.2f", + help="单个行业的最大持仓比例" + ) + + # 配置预览 + st.header("当前配置概览") + df = pd.DataFrame([ + ["初始资金", f"{initial_capital:,} {currency}"], + ["单个持仓上限", f"{max_position:.1%}"], + ["单个持仓下限", f"{min_position:.1%}"], + ["最大持仓数", max_total_positions], + ["行业敞口上限", f"{max_sector_exposure:.1%}"], + ], columns=["配置项", "当前值"]) + + st.table(df.set_index("配置项")) + + # 保存按钮 + if st.button("保存配置"): + try: + update_portfolio_config({ + "initial_capital": initial_capital, + "currency": currency, + "position_limits": { + "max_position": max_position, + "min_position": min_position, + "max_total_positions": max_total_positions, + "max_sector_exposure": max_sector_exposure + } + }) + st.success("配置已更新!") + except Exception as e: + st.error(f"配置更新失败:{str(e)}") + + # 投资组合限制可视化 + st.header("仓位限制可视化") + + # 生成示例数据 + example_positions = np.random.uniform( + min_position, + max_position, + min(max_total_positions, 10) + ) + example_positions = example_positions / example_positions.sum() + + example_sectors = { + "科技": 0.30, + "金融": 0.25, + "消费": 0.20, + "医药": 0.15, + "其他": 0.10 + } + + col1, col2 = st.columns(2) + + with col1: + st.subheader("示例持仓分布") + positions_df = pd.DataFrame({ + "股票": [f"股票{i+1}" for i in range(len(example_positions))], + "持仓比例": example_positions + }) + positions_df = positions_df.sort_values("持仓比例", ascending=True) + + st.bar_chart( + positions_df.set_index("股票")["持仓比例"], + use_container_width=True + ) + + with col2: + st.subheader("示例行业分布") + sectors_df = pd.DataFrame({ + "行业": list(example_sectors.keys()), + "敞口": list(example_sectors.values()) + }) + + st.bar_chart( + sectors_df.set_index("行业")["敞口"], + use_container_width=True + ) diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index 3e3239f..56304db 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -25,6 +25,7 @@ import streamlit as st from app.agents.base import AgentContext from app.agents.game import Decision from app.backtest.engine import BtConfig, run_backtest +from app.ui.portfolio_config import render_portfolio_config from app.backtest.decision_env import DecisionEnv, ParameterSpec from app.data.schema import initialize_database from app.ingest.checker import run_boot_check @@ -2667,10 +2668,79 @@ def render_tests() -> None: st.write(response) +def render_data_settings() -> None: + """渲染数据源配置界面.""" + st.subheader("Tushare 数据源") + cfg = get_config() + + col1, col2 = st.columns(2) + with col1: + tushare_token = st.text_input( + "Tushare Token", + value=cfg.tushare_token or "", + type="password", + help="从 tushare.pro 获取的 API token" + ) + + with col2: + auto_update = st.checkbox( + "启动时自动更新数据", + value=cfg.auto_update_data, + help="启动应用时自动检查并更新数据" + ) + + update_interval = st.slider( + "数据更新间隔(天)", + min_value=1, + max_value=30, + value=cfg.data_update_interval, + help="自动更新时检查的数据时间范围" + ) + + if st.button("保存数据源配置"): + cfg.tushare_token = tushare_token + cfg.auto_update_data = auto_update + cfg.data_update_interval = update_interval + save_config(cfg) + st.success("数据源配置已更新!") + + st.divider() + st.subheader("数据更新记录") + + with db_session() as session: + df = pd.read_sql_query( + """ + SELECT job_type, status, created_at, updated_at, error_msg + FROM fetch_jobs + ORDER BY created_at DESC + LIMIT 50 + """, + session + ) + + if not df.empty: + df["duration"] = (df["updated_at"] - df["created_at"]).dt.total_seconds().round(2) + df = df.drop(columns=["updated_at"]) + df = df.rename(columns={ + "job_type": "数据类型", + "status": "状态", + "created_at": "开始时间", + "error_msg": "错误信息", + "duration": "耗时(秒)" + }) + st.dataframe(df, width='stretch') + else: + st.info("暂无数据更新记录") + + def main() -> None: LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA) st.set_page_config(page_title="多智能体个人投资助理", layout="wide") + # 确保数据库表已创建 + from app.data.schema import initialize_database + initialize_database() + # 检查是否需要自动更新数据 cfg = get_config() if cfg.auto_update_data: @@ -2713,7 +2783,18 @@ def main() -> None: with tabs[1]: render_log_viewer() with tabs[2]: - render_settings() + st.header("系统设置") + settings_tabs = st.tabs(["基本配置", "投资组合", "数据源"]) + + with settings_tabs[0]: + render_settings() + + with settings_tabs[1]: + from app.ui.portfolio_config import render_portfolio_config + render_portfolio_config() + + with settings_tabs[2]: + render_data_settings() with tabs[3]: render_tests() diff --git a/app/utils/config.py b/app/utils/config.py index c4c4749..b82f3b3 100644 --- a/app/utils/config.py +++ b/app/utils/config.py @@ -39,6 +39,18 @@ class DataPaths: self.config_file = config_path +@dataclass +class PortfolioSettings: + """Portfolio configuration settings.""" + + initial_capital: float = 1000000 # 默认100万 + currency: str = "CNY" # 默认人民币 + max_position: float = 0.2 # 单个持仓上限 20% + min_position: float = 0.02 # 单个持仓下限 2% + max_total_positions: int = 20 # 最大持仓数 + max_sector_exposure: float = 0.35 # 行业敞口上限 35% + + @dataclass class AgentWeights: """Default weighting for decision agents.""" @@ -340,9 +352,11 @@ class AppConfig: agent_weights: AgentWeights = field(default_factory=AgentWeights) force_refresh: bool = False auto_update_data: bool = False + data_update_interval: int = 7 # 数据更新间隔(天) llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers) llm: LLMConfig = field(default_factory=LLMConfig) departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments) + portfolio: PortfolioSettings = field(default_factory=PortfolioSettings) def resolve_llm(self, route: Optional[str] = None) -> LLMConfig: return self.llm diff --git a/app/utils/portfolio.py b/app/utils/portfolio.py index 85fd0a5..52c3108 100644 --- a/app/utils/portfolio.py +++ b/app/utils/portfolio.py @@ -7,6 +7,7 @@ from typing import Any, Dict, Iterable, List, Optional from .db import db_session from .logging import get_logger +from .portfolio_init import get_portfolio_config LOGGER = get_logger(__name__) LOG_EXTRA = {"stage": "portfolio"} @@ -169,8 +170,11 @@ class PortfolioSnapshot: def get_latest_snapshot() -> Optional[PortfolioSnapshot]: - """Fetch the most recent portfolio snapshot.""" - + """Fetch the most recent portfolio snapshot. + + Returns: + 最新的投资组合快照,如果没有数据则返回初始快照(仅包含初始资金) + """ sql = """ SELECT trade_date, total_value, cash, invested_value, unrealized_pnl, realized_pnl, net_flow, exposure, notes, metadata @@ -186,7 +190,22 @@ def get_latest_snapshot() -> Optional[PortfolioSnapshot]: return None if not row: - return None + # 如果没有快照,返回初始状态(只有初始资金) + config = get_portfolio_config() + initial_capital = config["initial_capital"] + return PortfolioSnapshot( + trade_date="", # 空日期表示初始状态 + total_value=initial_capital, + cash=initial_capital, + invested_value=0.0, + unrealized_pnl=0.0, + realized_pnl=0.0, + net_flow=0.0, + exposure=0.0, + notes="Initial portfolio state", + metadata={"initial_capital": initial_capital, "currency": config["currency"]}, + ) + return PortfolioSnapshot( trade_date=row["trade_date"], total_value=row["total_value"], diff --git a/app/utils/portfolio_init.py b/app/utils/portfolio_init.py new file mode 100644 index 0000000..72c6b34 --- /dev/null +++ b/app/utils/portfolio_init.py @@ -0,0 +1,149 @@ +"""Initialize portfolio database tables.""" +from __future__ import annotations + +import json +from typing import Any + +from .logging import get_logger +from .config import get_config + + +LOGGER = get_logger(__name__) +LOG_EXTRA = {"stage": "portfolio_init"} + + +def get_portfolio_config() -> dict[str, Any]: + """获取投资组合配置. + + Returns: + 包含以下字段的字典: + - initial_capital: 初始投资金额 + - currency: 货币类型 + - position_limits: 仓位限制 + """ + config = get_config() + settings = config.portfolio if hasattr(config, "portfolio") else None + + if not settings: + from .config import PortfolioSettings + settings = PortfolioSettings() + + return { + "initial_capital": settings.initial_capital, + "currency": settings.currency, + "position_limits": { + "max_position": settings.max_position, + "min_position": settings.min_position, + "max_total_positions": settings.max_total_positions, + "max_sector_exposure": settings.max_sector_exposure + } + } + +def update_portfolio_config(updates: dict[str, Any]) -> None: + """更新投资组合配置. + + Args: + updates: 要更新的配置项字典 + """ + from .config import get_config, save_config, PortfolioSettings + + # 获取当前配置 + config = get_config() + + # 创建新的投资组合设置 + portfolio = PortfolioSettings( + initial_capital=updates["initial_capital"], + currency=updates["currency"], + max_position=updates["position_limits"]["max_position"], + min_position=updates["position_limits"]["min_position"], + max_total_positions=updates["position_limits"]["max_total_positions"], + max_sector_exposure=updates["position_limits"]["max_sector_exposure"] + ) + + # 更新配置 + config.portfolio = portfolio + save_config(config) + + + +SCHEMA_STATEMENTS = [ + # 投资池表 + """ + CREATE TABLE IF NOT EXISTS investment_pool ( + trade_date TEXT, + ts_code TEXT, + score REAL, + status TEXT, + rationale TEXT, + tags TEXT, -- JSON array + metadata TEXT, -- JSON object + PRIMARY KEY (trade_date, ts_code) + ) + """, + + # 数据获取任务表 + """ + CREATE TABLE IF NOT EXISTS fetch_jobs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + job_type TEXT NOT NULL, + status TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + error_msg TEXT, + metadata TEXT -- JSON object for additional info + ) + ); + """, + + # 持仓表 + """ + CREATE TABLE IF NOT EXISTS portfolio_positions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + ts_code TEXT NOT NULL, + opened_date TEXT NOT NULL, + closed_date TEXT, + quantity REAL NOT NULL, + cost_price REAL NOT NULL, + market_price REAL, + market_value REAL, + realized_pnl REAL DEFAULT 0, + unrealized_pnl REAL DEFAULT 0, + target_weight REAL, + status TEXT NOT NULL DEFAULT 'open', + notes TEXT, + metadata TEXT -- JSON object + ); + """, + + # 投资组合快照表 + """ + CREATE TABLE IF NOT EXISTS portfolio_snapshots ( + trade_date TEXT PRIMARY KEY, + total_value REAL, + cash REAL, + invested_value REAL, + unrealized_pnl REAL, + realized_pnl REAL, + net_flow REAL, + exposure REAL, + notes TEXT, + metadata TEXT -- JSON object + ); + """, +] + + +def initialize_database_schema() -> None: + """Create database tables if they don't exist.""" + from .db import db_session + + with db_session() as conn: + for statement in SCHEMA_STATEMENTS: + try: + conn.execute(statement) + except Exception: # noqa: BLE001 + LOGGER.exception( + "执行 schema 语句失败", + extra={"sql": statement, **LOG_EXTRA} + ) + raise diff --git a/tests/test_portfolio_config.py b/tests/test_portfolio_config.py new file mode 100644 index 0000000..af7b920 --- /dev/null +++ b/tests/test_portfolio_config.py @@ -0,0 +1,57 @@ +"""Test portfolio configuration and initialization.""" +from unittest.mock import patch, MagicMock + +from app.utils.portfolio import get_latest_snapshot +from app.utils.db import db_session + + +def test_default_portfolio_config(): + """Test default portfolio configuration.""" + # Mock db_session as a context manager + mock_session = MagicMock() + mock_session.__enter__ = MagicMock(return_value=mock_session) + mock_session.__exit__ = MagicMock(return_value=None) + + # Mock the database query result + mock_session.execute.return_value.fetchone.return_value = None + + # 使用默认配置 + with patch("app.utils.portfolio.get_portfolio_config") as mock_config, \ + patch("app.utils.portfolio.db_session", return_value=mock_session): + mock_config.return_value = { + "initial_capital": 1000000, + "currency": "CNY" + } + + snapshot = get_latest_snapshot() + assert snapshot is not None + assert snapshot.total_value == 1000000 + assert snapshot.cash == 1000000 + assert snapshot.metadata["initial_capital"] == 1000000 + assert snapshot.metadata["currency"] == "CNY" + + +def test_custom_portfolio_config(): + """Test custom portfolio configuration.""" + # Mock db_session as a context manager + mock_session = MagicMock() + mock_session.__enter__ = MagicMock(return_value=mock_session) + mock_session.__exit__ = MagicMock(return_value=None) + + # Mock the database query result + mock_session.execute.return_value.fetchone.return_value = None + + # 使用自定义配置 + with patch("app.utils.portfolio.get_portfolio_config") as mock_config, \ + patch("app.utils.portfolio.db_session", return_value=mock_session): + mock_config.return_value = { + "initial_capital": 2000000, + "currency": "USD" + } + + snapshot = get_latest_snapshot() + assert snapshot is not None + assert snapshot.total_value == 2000000 + assert snapshot.cash == 2000000 + assert snapshot.metadata["initial_capital"] == 2000000 + assert snapshot.metadata["currency"] == "USD"