update
This commit is contained in:
parent
16a5fae732
commit
adfc8ee148
@ -9,6 +9,17 @@ from app.utils.db import db_session
|
|||||||
|
|
||||||
|
|
||||||
SCHEMA_STATEMENTS: Iterable[str] = (
|
SCHEMA_STATEMENTS: Iterable[str] = (
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS fetch_jobs (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
job_type TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL DEFAULT 'pending',
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
error_msg TEXT,
|
||||||
|
metadata TEXT -- JSON object for additional info
|
||||||
|
);
|
||||||
|
""",
|
||||||
"""
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS stock_basic (
|
CREATE TABLE IF NOT EXISTS stock_basic (
|
||||||
ts_code TEXT PRIMARY KEY,
|
ts_code TEXT PRIMARY KEY,
|
||||||
@ -515,17 +526,32 @@ def _missing_tables() -> List[str]:
|
|||||||
return [name for name in REQUIRED_TABLES if name not in existing]
|
return [name for name in REQUIRED_TABLES if name not in existing]
|
||||||
|
|
||||||
|
|
||||||
def initialize_database() -> MigrationResult:
|
def initialize_database() -> None:
|
||||||
"""Create tables and indexes required by the application."""
|
"""Initialize the SQLite database with all required tables."""
|
||||||
|
with db_session() as session:
|
||||||
|
cursor = session.cursor()
|
||||||
|
|
||||||
missing = _missing_tables()
|
# 创建表
|
||||||
if not missing:
|
|
||||||
return MigrationResult(executed=0, skipped=True, missing_tables=[])
|
|
||||||
|
|
||||||
executed = 0
|
|
||||||
with db_session() as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
for statement in SCHEMA_STATEMENTS:
|
for statement in SCHEMA_STATEMENTS:
|
||||||
cursor.executescript(statement)
|
try:
|
||||||
executed += 1
|
cursor.execute(statement)
|
||||||
return MigrationResult(executed=executed, skipped=False, missing_tables=missing)
|
except Exception as e: # noqa: BLE001
|
||||||
|
print(f"初始化数据库时出错: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# 添加触发器以自动更新 updated_at 字段
|
||||||
|
try:
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TRIGGER IF NOT EXISTS update_fetch_jobs_timestamp
|
||||||
|
AFTER UPDATE ON fetch_jobs
|
||||||
|
BEGIN
|
||||||
|
UPDATE fetch_jobs
|
||||||
|
SET updated_at = CURRENT_TIMESTAMP
|
||||||
|
WHERE id = NEW.id;
|
||||||
|
END;
|
||||||
|
""")
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
print(f"创建触发器时出错: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
|||||||
85
app/ingest/job_logger.py
Normal file
85
app/ingest/job_logger.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
"""任务记录工具类。"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from app.utils.db import db_session
|
||||||
|
|
||||||
|
|
||||||
|
class JobLogger:
|
||||||
|
"""任务记录器。"""
|
||||||
|
|
||||||
|
def __init__(self, job_type: str) -> None:
|
||||||
|
"""初始化任务记录器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job_type: 任务类型
|
||||||
|
"""
|
||||||
|
self.job_type = job_type
|
||||||
|
self.job_id: Optional[int] = None
|
||||||
|
|
||||||
|
def __enter__(self) -> "JobLogger":
|
||||||
|
"""开始记录任务。"""
|
||||||
|
with db_session() as session:
|
||||||
|
cursor = session.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO fetch_jobs (job_type, status, created_at, updated_at)
|
||||||
|
VALUES (?, 'running', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||||
|
""",
|
||||||
|
(self.job_type,)
|
||||||
|
)
|
||||||
|
self.job_id = cursor.lastrowid
|
||||||
|
session.commit()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||||
|
"""结束任务记录。"""
|
||||||
|
if exc_val:
|
||||||
|
self.update_status("failed", str(exc_val))
|
||||||
|
else:
|
||||||
|
self.update_status("success")
|
||||||
|
|
||||||
|
def update_status(self, status: str, error_msg: Optional[str] = None) -> None:
|
||||||
|
"""更新任务状态。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
status: 新状态
|
||||||
|
error_msg: 错误信息(如果有)
|
||||||
|
"""
|
||||||
|
if not self.job_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
with db_session() as session:
|
||||||
|
session.execute(
|
||||||
|
"""
|
||||||
|
UPDATE fetch_jobs
|
||||||
|
SET status = ?,
|
||||||
|
error_msg = ?,
|
||||||
|
updated_at = CURRENT_TIMESTAMP
|
||||||
|
WHERE id = ?
|
||||||
|
""",
|
||||||
|
(status, error_msg, self.job_id)
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def update_metadata(self, metadata: Dict[str, Any]) -> None:
|
||||||
|
"""更新任务元数据。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata: 元数据字典
|
||||||
|
"""
|
||||||
|
if not self.job_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
with db_session() as session:
|
||||||
|
session.execute(
|
||||||
|
"""
|
||||||
|
UPDATE fetch_jobs
|
||||||
|
SET metadata = ?
|
||||||
|
WHERE id = ?
|
||||||
|
""",
|
||||||
|
(json.dumps(metadata), self.job_id)
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
@ -227,6 +227,9 @@ def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None
|
|||||||
return merged
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
from .job_logger import JobLogger
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FetchJob:
|
class FetchJob:
|
||||||
name: str
|
name: str
|
||||||
@ -1602,8 +1605,17 @@ def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]
|
|||||||
|
|
||||||
|
|
||||||
def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
|
def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
|
||||||
|
"""运行数据拉取任务。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job: 任务配置
|
||||||
|
include_limits: 是否包含涨跌停数据
|
||||||
|
"""
|
||||||
|
with JobLogger("TuShare数据获取") as logger:
|
||||||
LOGGER.info("启动 TuShare 拉取任务:%s", job.name, extra=LOG_EXTRA)
|
LOGGER.info("启动 TuShare 拉取任务:%s", job.name, extra=LOG_EXTRA)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 拉取基础数据
|
||||||
ensure_data_coverage(
|
ensure_data_coverage(
|
||||||
job.start,
|
job.start,
|
||||||
job.end,
|
job.end,
|
||||||
@ -1612,25 +1624,39 @@ def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
|
|||||||
include_extended=True,
|
include_extended=True,
|
||||||
force=True,
|
force=True,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
|
||||||
alerts.add_warning("TuShare", f"拉取任务失败:{job.name}", str(exc))
|
# 记录任务元数据
|
||||||
raise
|
logger.update_metadata({
|
||||||
else:
|
"name": job.name,
|
||||||
|
"start": str(job.start),
|
||||||
|
"end": str(job.end),
|
||||||
|
"codes": len(job.ts_codes) if job.ts_codes else 0
|
||||||
|
})
|
||||||
|
|
||||||
alerts.clear_warnings("TuShare")
|
alerts.clear_warnings("TuShare")
|
||||||
|
|
||||||
|
# 对日线数据计算因子
|
||||||
if job.granularity == "daily":
|
if job.granularity == "daily":
|
||||||
try:
|
|
||||||
LOGGER.info("开始计算因子:%s", job.name, extra=LOG_EXTRA)
|
LOGGER.info("开始计算因子:%s", job.name, extra=LOG_EXTRA)
|
||||||
|
try:
|
||||||
compute_factor_range(
|
compute_factor_range(
|
||||||
job.start,
|
job.start,
|
||||||
job.end,
|
job.end,
|
||||||
ts_codes=job.ts_codes,
|
ts_codes=job.ts_codes,
|
||||||
skip_existing=False,
|
skip_existing=False,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
|
||||||
alerts.add_warning("Factors", f"因子计算失败:{job.name}", str(exc))
|
|
||||||
LOGGER.exception("因子计算失败 job=%s", job.name, extra=LOG_EXTRA)
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
alerts.clear_warnings("Factors")
|
alerts.clear_warnings("Factors")
|
||||||
|
except Exception as exc:
|
||||||
|
LOGGER.exception("因子计算失败 job=%s", job.name, extra=LOG_EXTRA)
|
||||||
|
alerts.add_warning("Factors", f"因子计算失败:{job.name}", str(exc))
|
||||||
|
logger.update_status("failed", f"因子计算失败:{exc}")
|
||||||
|
raise
|
||||||
LOGGER.info("因子计算完成:%s", job.name, extra=LOG_EXTRA)
|
LOGGER.info("因子计算完成:%s", job.name, extra=LOG_EXTRA)
|
||||||
|
alerts.clear_warnings("Factors")
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
LOGGER.exception("数据拉取失败 job=%s", job.name, extra=LOG_EXTRA)
|
||||||
|
alerts.add_warning("TuShare", f"拉取任务失败:{job.name}", str(exc))
|
||||||
|
logger.update_status("failed", f"数据拉取失败:{exc}")
|
||||||
|
raise
|
||||||
LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA)
|
LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA)
|
||||||
|
|||||||
159
app/ui/portfolio_config.py
Normal file
159
app/ui/portfolio_config.py
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
"""Portfolio configuration UI components."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from app.utils.portfolio_init import get_portfolio_config, update_portfolio_config
|
||||||
|
|
||||||
|
|
||||||
|
def render_portfolio_config() -> None:
|
||||||
|
"""渲染投资组合配置界面."""
|
||||||
|
st.title("投资组合配置")
|
||||||
|
|
||||||
|
# 获取当前配置
|
||||||
|
config = get_portfolio_config()
|
||||||
|
|
||||||
|
# 基本配置部分
|
||||||
|
st.header("基本配置")
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
initial_capital = st.number_input(
|
||||||
|
"初始投资金额",
|
||||||
|
min_value=100000,
|
||||||
|
max_value=100000000,
|
||||||
|
value=int(config["initial_capital"]),
|
||||||
|
step=100000,
|
||||||
|
format="%d"
|
||||||
|
)
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
currency = st.selectbox(
|
||||||
|
"币种",
|
||||||
|
options=["CNY", "USD", "HKD"],
|
||||||
|
index=["CNY", "USD", "HKD"].index(config["currency"])
|
||||||
|
)
|
||||||
|
|
||||||
|
# 仓位限制配置
|
||||||
|
st.header("仓位限制")
|
||||||
|
position_limits = config["position_limits"]
|
||||||
|
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
max_position = st.slider(
|
||||||
|
"单个持仓上限",
|
||||||
|
min_value=0.05,
|
||||||
|
max_value=0.50,
|
||||||
|
value=float(position_limits["max_position"]),
|
||||||
|
step=0.01,
|
||||||
|
format="%.2f",
|
||||||
|
help="单个股票最大持仓比例"
|
||||||
|
)
|
||||||
|
|
||||||
|
min_position = st.slider(
|
||||||
|
"单个持仓下限",
|
||||||
|
min_value=0.01,
|
||||||
|
max_value=0.10,
|
||||||
|
value=float(position_limits["min_position"]),
|
||||||
|
step=0.01,
|
||||||
|
format="%.2f",
|
||||||
|
help="单个股票最小持仓比例"
|
||||||
|
)
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
max_total_positions = st.slider(
|
||||||
|
"最大持仓数",
|
||||||
|
min_value=5,
|
||||||
|
max_value=50,
|
||||||
|
value=int(position_limits["max_total_positions"]),
|
||||||
|
step=1,
|
||||||
|
help="投资组合中的最大股票数量"
|
||||||
|
)
|
||||||
|
|
||||||
|
max_sector_exposure = st.slider(
|
||||||
|
"行业敞口上限",
|
||||||
|
min_value=0.20,
|
||||||
|
max_value=0.50,
|
||||||
|
value=float(position_limits["max_sector_exposure"]),
|
||||||
|
step=0.05,
|
||||||
|
format="%.2f",
|
||||||
|
help="单个行业的最大持仓比例"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 配置预览
|
||||||
|
st.header("当前配置概览")
|
||||||
|
df = pd.DataFrame([
|
||||||
|
["初始资金", f"{initial_capital:,} {currency}"],
|
||||||
|
["单个持仓上限", f"{max_position:.1%}"],
|
||||||
|
["单个持仓下限", f"{min_position:.1%}"],
|
||||||
|
["最大持仓数", max_total_positions],
|
||||||
|
["行业敞口上限", f"{max_sector_exposure:.1%}"],
|
||||||
|
], columns=["配置项", "当前值"])
|
||||||
|
|
||||||
|
st.table(df.set_index("配置项"))
|
||||||
|
|
||||||
|
# 保存按钮
|
||||||
|
if st.button("保存配置"):
|
||||||
|
try:
|
||||||
|
update_portfolio_config({
|
||||||
|
"initial_capital": initial_capital,
|
||||||
|
"currency": currency,
|
||||||
|
"position_limits": {
|
||||||
|
"max_position": max_position,
|
||||||
|
"min_position": min_position,
|
||||||
|
"max_total_positions": max_total_positions,
|
||||||
|
"max_sector_exposure": max_sector_exposure
|
||||||
|
}
|
||||||
|
})
|
||||||
|
st.success("配置已更新!")
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"配置更新失败:{str(e)}")
|
||||||
|
|
||||||
|
# 投资组合限制可视化
|
||||||
|
st.header("仓位限制可视化")
|
||||||
|
|
||||||
|
# 生成示例数据
|
||||||
|
example_positions = np.random.uniform(
|
||||||
|
min_position,
|
||||||
|
max_position,
|
||||||
|
min(max_total_positions, 10)
|
||||||
|
)
|
||||||
|
example_positions = example_positions / example_positions.sum()
|
||||||
|
|
||||||
|
example_sectors = {
|
||||||
|
"科技": 0.30,
|
||||||
|
"金融": 0.25,
|
||||||
|
"消费": 0.20,
|
||||||
|
"医药": 0.15,
|
||||||
|
"其他": 0.10
|
||||||
|
}
|
||||||
|
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
st.subheader("示例持仓分布")
|
||||||
|
positions_df = pd.DataFrame({
|
||||||
|
"股票": [f"股票{i+1}" for i in range(len(example_positions))],
|
||||||
|
"持仓比例": example_positions
|
||||||
|
})
|
||||||
|
positions_df = positions_df.sort_values("持仓比例", ascending=True)
|
||||||
|
|
||||||
|
st.bar_chart(
|
||||||
|
positions_df.set_index("股票")["持仓比例"],
|
||||||
|
use_container_width=True
|
||||||
|
)
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
st.subheader("示例行业分布")
|
||||||
|
sectors_df = pd.DataFrame({
|
||||||
|
"行业": list(example_sectors.keys()),
|
||||||
|
"敞口": list(example_sectors.values())
|
||||||
|
})
|
||||||
|
|
||||||
|
st.bar_chart(
|
||||||
|
sectors_df.set_index("行业")["敞口"],
|
||||||
|
use_container_width=True
|
||||||
|
)
|
||||||
@ -25,6 +25,7 @@ import streamlit as st
|
|||||||
from app.agents.base import AgentContext
|
from app.agents.base import AgentContext
|
||||||
from app.agents.game import Decision
|
from app.agents.game import Decision
|
||||||
from app.backtest.engine import BtConfig, run_backtest
|
from app.backtest.engine import BtConfig, run_backtest
|
||||||
|
from app.ui.portfolio_config import render_portfolio_config
|
||||||
from app.backtest.decision_env import DecisionEnv, ParameterSpec
|
from app.backtest.decision_env import DecisionEnv, ParameterSpec
|
||||||
from app.data.schema import initialize_database
|
from app.data.schema import initialize_database
|
||||||
from app.ingest.checker import run_boot_check
|
from app.ingest.checker import run_boot_check
|
||||||
@ -2667,10 +2668,79 @@ def render_tests() -> None:
|
|||||||
st.write(response)
|
st.write(response)
|
||||||
|
|
||||||
|
|
||||||
|
def render_data_settings() -> None:
|
||||||
|
"""渲染数据源配置界面."""
|
||||||
|
st.subheader("Tushare 数据源")
|
||||||
|
cfg = get_config()
|
||||||
|
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
with col1:
|
||||||
|
tushare_token = st.text_input(
|
||||||
|
"Tushare Token",
|
||||||
|
value=cfg.tushare_token or "",
|
||||||
|
type="password",
|
||||||
|
help="从 tushare.pro 获取的 API token"
|
||||||
|
)
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
auto_update = st.checkbox(
|
||||||
|
"启动时自动更新数据",
|
||||||
|
value=cfg.auto_update_data,
|
||||||
|
help="启动应用时自动检查并更新数据"
|
||||||
|
)
|
||||||
|
|
||||||
|
update_interval = st.slider(
|
||||||
|
"数据更新间隔(天)",
|
||||||
|
min_value=1,
|
||||||
|
max_value=30,
|
||||||
|
value=cfg.data_update_interval,
|
||||||
|
help="自动更新时检查的数据时间范围"
|
||||||
|
)
|
||||||
|
|
||||||
|
if st.button("保存数据源配置"):
|
||||||
|
cfg.tushare_token = tushare_token
|
||||||
|
cfg.auto_update_data = auto_update
|
||||||
|
cfg.data_update_interval = update_interval
|
||||||
|
save_config(cfg)
|
||||||
|
st.success("数据源配置已更新!")
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
st.subheader("数据更新记录")
|
||||||
|
|
||||||
|
with db_session() as session:
|
||||||
|
df = pd.read_sql_query(
|
||||||
|
"""
|
||||||
|
SELECT job_type, status, created_at, updated_at, error_msg
|
||||||
|
FROM fetch_jobs
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT 50
|
||||||
|
""",
|
||||||
|
session
|
||||||
|
)
|
||||||
|
|
||||||
|
if not df.empty:
|
||||||
|
df["duration"] = (df["updated_at"] - df["created_at"]).dt.total_seconds().round(2)
|
||||||
|
df = df.drop(columns=["updated_at"])
|
||||||
|
df = df.rename(columns={
|
||||||
|
"job_type": "数据类型",
|
||||||
|
"status": "状态",
|
||||||
|
"created_at": "开始时间",
|
||||||
|
"error_msg": "错误信息",
|
||||||
|
"duration": "耗时(秒)"
|
||||||
|
})
|
||||||
|
st.dataframe(df, width='stretch')
|
||||||
|
else:
|
||||||
|
st.info("暂无数据更新记录")
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA)
|
LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA)
|
||||||
st.set_page_config(page_title="多智能体个人投资助理", layout="wide")
|
st.set_page_config(page_title="多智能体个人投资助理", layout="wide")
|
||||||
|
|
||||||
|
# 确保数据库表已创建
|
||||||
|
from app.data.schema import initialize_database
|
||||||
|
initialize_database()
|
||||||
|
|
||||||
# 检查是否需要自动更新数据
|
# 检查是否需要自动更新数据
|
||||||
cfg = get_config()
|
cfg = get_config()
|
||||||
if cfg.auto_update_data:
|
if cfg.auto_update_data:
|
||||||
@ -2713,7 +2783,18 @@ def main() -> None:
|
|||||||
with tabs[1]:
|
with tabs[1]:
|
||||||
render_log_viewer()
|
render_log_viewer()
|
||||||
with tabs[2]:
|
with tabs[2]:
|
||||||
|
st.header("系统设置")
|
||||||
|
settings_tabs = st.tabs(["基本配置", "投资组合", "数据源"])
|
||||||
|
|
||||||
|
with settings_tabs[0]:
|
||||||
render_settings()
|
render_settings()
|
||||||
|
|
||||||
|
with settings_tabs[1]:
|
||||||
|
from app.ui.portfolio_config import render_portfolio_config
|
||||||
|
render_portfolio_config()
|
||||||
|
|
||||||
|
with settings_tabs[2]:
|
||||||
|
render_data_settings()
|
||||||
with tabs[3]:
|
with tabs[3]:
|
||||||
render_tests()
|
render_tests()
|
||||||
|
|
||||||
|
|||||||
@ -39,6 +39,18 @@ class DataPaths:
|
|||||||
self.config_file = config_path
|
self.config_file = config_path
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PortfolioSettings:
|
||||||
|
"""Portfolio configuration settings."""
|
||||||
|
|
||||||
|
initial_capital: float = 1000000 # 默认100万
|
||||||
|
currency: str = "CNY" # 默认人民币
|
||||||
|
max_position: float = 0.2 # 单个持仓上限 20%
|
||||||
|
min_position: float = 0.02 # 单个持仓下限 2%
|
||||||
|
max_total_positions: int = 20 # 最大持仓数
|
||||||
|
max_sector_exposure: float = 0.35 # 行业敞口上限 35%
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AgentWeights:
|
class AgentWeights:
|
||||||
"""Default weighting for decision agents."""
|
"""Default weighting for decision agents."""
|
||||||
@ -340,9 +352,11 @@ class AppConfig:
|
|||||||
agent_weights: AgentWeights = field(default_factory=AgentWeights)
|
agent_weights: AgentWeights = field(default_factory=AgentWeights)
|
||||||
force_refresh: bool = False
|
force_refresh: bool = False
|
||||||
auto_update_data: bool = False
|
auto_update_data: bool = False
|
||||||
|
data_update_interval: int = 7 # 数据更新间隔(天)
|
||||||
llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers)
|
llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers)
|
||||||
llm: LLMConfig = field(default_factory=LLMConfig)
|
llm: LLMConfig = field(default_factory=LLMConfig)
|
||||||
departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
|
departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
|
||||||
|
portfolio: PortfolioSettings = field(default_factory=PortfolioSettings)
|
||||||
|
|
||||||
def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
|
def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
|
||||||
return self.llm
|
return self.llm
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from typing import Any, Dict, Iterable, List, Optional
|
|||||||
|
|
||||||
from .db import db_session
|
from .db import db_session
|
||||||
from .logging import get_logger
|
from .logging import get_logger
|
||||||
|
from .portfolio_init import get_portfolio_config
|
||||||
|
|
||||||
LOGGER = get_logger(__name__)
|
LOGGER = get_logger(__name__)
|
||||||
LOG_EXTRA = {"stage": "portfolio"}
|
LOG_EXTRA = {"stage": "portfolio"}
|
||||||
@ -169,8 +170,11 @@ class PortfolioSnapshot:
|
|||||||
|
|
||||||
|
|
||||||
def get_latest_snapshot() -> Optional[PortfolioSnapshot]:
|
def get_latest_snapshot() -> Optional[PortfolioSnapshot]:
|
||||||
"""Fetch the most recent portfolio snapshot."""
|
"""Fetch the most recent portfolio snapshot.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
最新的投资组合快照,如果没有数据则返回初始快照(仅包含初始资金)
|
||||||
|
"""
|
||||||
sql = """
|
sql = """
|
||||||
SELECT trade_date, total_value, cash, invested_value, unrealized_pnl,
|
SELECT trade_date, total_value, cash, invested_value, unrealized_pnl,
|
||||||
realized_pnl, net_flow, exposure, notes, metadata
|
realized_pnl, net_flow, exposure, notes, metadata
|
||||||
@ -186,7 +190,22 @@ def get_latest_snapshot() -> Optional[PortfolioSnapshot]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
# 如果没有快照,返回初始状态(只有初始资金)
|
||||||
|
config = get_portfolio_config()
|
||||||
|
initial_capital = config["initial_capital"]
|
||||||
|
return PortfolioSnapshot(
|
||||||
|
trade_date="", # 空日期表示初始状态
|
||||||
|
total_value=initial_capital,
|
||||||
|
cash=initial_capital,
|
||||||
|
invested_value=0.0,
|
||||||
|
unrealized_pnl=0.0,
|
||||||
|
realized_pnl=0.0,
|
||||||
|
net_flow=0.0,
|
||||||
|
exposure=0.0,
|
||||||
|
notes="Initial portfolio state",
|
||||||
|
metadata={"initial_capital": initial_capital, "currency": config["currency"]},
|
||||||
|
)
|
||||||
|
|
||||||
return PortfolioSnapshot(
|
return PortfolioSnapshot(
|
||||||
trade_date=row["trade_date"],
|
trade_date=row["trade_date"],
|
||||||
total_value=row["total_value"],
|
total_value=row["total_value"],
|
||||||
|
|||||||
149
app/utils/portfolio_init.py
Normal file
149
app/utils/portfolio_init.py
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
"""Initialize portfolio database tables."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .logging import get_logger
|
||||||
|
from .config import get_config
|
||||||
|
|
||||||
|
|
||||||
|
LOGGER = get_logger(__name__)
|
||||||
|
LOG_EXTRA = {"stage": "portfolio_init"}
|
||||||
|
|
||||||
|
|
||||||
|
def get_portfolio_config() -> dict[str, Any]:
|
||||||
|
"""获取投资组合配置.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含以下字段的字典:
|
||||||
|
- initial_capital: 初始投资金额
|
||||||
|
- currency: 货币类型
|
||||||
|
- position_limits: 仓位限制
|
||||||
|
"""
|
||||||
|
config = get_config()
|
||||||
|
settings = config.portfolio if hasattr(config, "portfolio") else None
|
||||||
|
|
||||||
|
if not settings:
|
||||||
|
from .config import PortfolioSettings
|
||||||
|
settings = PortfolioSettings()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"initial_capital": settings.initial_capital,
|
||||||
|
"currency": settings.currency,
|
||||||
|
"position_limits": {
|
||||||
|
"max_position": settings.max_position,
|
||||||
|
"min_position": settings.min_position,
|
||||||
|
"max_total_positions": settings.max_total_positions,
|
||||||
|
"max_sector_exposure": settings.max_sector_exposure
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def update_portfolio_config(updates: dict[str, Any]) -> None:
|
||||||
|
"""更新投资组合配置.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
updates: 要更新的配置项字典
|
||||||
|
"""
|
||||||
|
from .config import get_config, save_config, PortfolioSettings
|
||||||
|
|
||||||
|
# 获取当前配置
|
||||||
|
config = get_config()
|
||||||
|
|
||||||
|
# 创建新的投资组合设置
|
||||||
|
portfolio = PortfolioSettings(
|
||||||
|
initial_capital=updates["initial_capital"],
|
||||||
|
currency=updates["currency"],
|
||||||
|
max_position=updates["position_limits"]["max_position"],
|
||||||
|
min_position=updates["position_limits"]["min_position"],
|
||||||
|
max_total_positions=updates["position_limits"]["max_total_positions"],
|
||||||
|
max_sector_exposure=updates["position_limits"]["max_sector_exposure"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新配置
|
||||||
|
config.portfolio = portfolio
|
||||||
|
save_config(config)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
SCHEMA_STATEMENTS = [
|
||||||
|
# 投资池表
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS investment_pool (
|
||||||
|
trade_date TEXT,
|
||||||
|
ts_code TEXT,
|
||||||
|
score REAL,
|
||||||
|
status TEXT,
|
||||||
|
rationale TEXT,
|
||||||
|
tags TEXT, -- JSON array
|
||||||
|
metadata TEXT, -- JSON object
|
||||||
|
PRIMARY KEY (trade_date, ts_code)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
|
||||||
|
# 数据获取任务表
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS fetch_jobs (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
job_type TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
error_msg TEXT,
|
||||||
|
metadata TEXT -- JSON object for additional info
|
||||||
|
)
|
||||||
|
);
|
||||||
|
""",
|
||||||
|
|
||||||
|
# 持仓表
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS portfolio_positions (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
ts_code TEXT NOT NULL,
|
||||||
|
opened_date TEXT NOT NULL,
|
||||||
|
closed_date TEXT,
|
||||||
|
quantity REAL NOT NULL,
|
||||||
|
cost_price REAL NOT NULL,
|
||||||
|
market_price REAL,
|
||||||
|
market_value REAL,
|
||||||
|
realized_pnl REAL DEFAULT 0,
|
||||||
|
unrealized_pnl REAL DEFAULT 0,
|
||||||
|
target_weight REAL,
|
||||||
|
status TEXT NOT NULL DEFAULT 'open',
|
||||||
|
notes TEXT,
|
||||||
|
metadata TEXT -- JSON object
|
||||||
|
);
|
||||||
|
""",
|
||||||
|
|
||||||
|
# 投资组合快照表
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS portfolio_snapshots (
|
||||||
|
trade_date TEXT PRIMARY KEY,
|
||||||
|
total_value REAL,
|
||||||
|
cash REAL,
|
||||||
|
invested_value REAL,
|
||||||
|
unrealized_pnl REAL,
|
||||||
|
realized_pnl REAL,
|
||||||
|
net_flow REAL,
|
||||||
|
exposure REAL,
|
||||||
|
notes TEXT,
|
||||||
|
metadata TEXT -- JSON object
|
||||||
|
);
|
||||||
|
""",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_database_schema() -> None:
|
||||||
|
"""Create database tables if they don't exist."""
|
||||||
|
from .db import db_session
|
||||||
|
|
||||||
|
with db_session() as conn:
|
||||||
|
for statement in SCHEMA_STATEMENTS:
|
||||||
|
try:
|
||||||
|
conn.execute(statement)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.exception(
|
||||||
|
"执行 schema 语句失败",
|
||||||
|
extra={"sql": statement, **LOG_EXTRA}
|
||||||
|
)
|
||||||
|
raise
|
||||||
57
tests/test_portfolio_config.py
Normal file
57
tests/test_portfolio_config.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
"""Test portfolio configuration and initialization."""
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from app.utils.portfolio import get_latest_snapshot
|
||||||
|
from app.utils.db import db_session
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_portfolio_config():
|
||||||
|
"""Test default portfolio configuration."""
|
||||||
|
# Mock db_session as a context manager
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.__enter__ = MagicMock(return_value=mock_session)
|
||||||
|
mock_session.__exit__ = MagicMock(return_value=None)
|
||||||
|
|
||||||
|
# Mock the database query result
|
||||||
|
mock_session.execute.return_value.fetchone.return_value = None
|
||||||
|
|
||||||
|
# 使用默认配置
|
||||||
|
with patch("app.utils.portfolio.get_portfolio_config") as mock_config, \
|
||||||
|
patch("app.utils.portfolio.db_session", return_value=mock_session):
|
||||||
|
mock_config.return_value = {
|
||||||
|
"initial_capital": 1000000,
|
||||||
|
"currency": "CNY"
|
||||||
|
}
|
||||||
|
|
||||||
|
snapshot = get_latest_snapshot()
|
||||||
|
assert snapshot is not None
|
||||||
|
assert snapshot.total_value == 1000000
|
||||||
|
assert snapshot.cash == 1000000
|
||||||
|
assert snapshot.metadata["initial_capital"] == 1000000
|
||||||
|
assert snapshot.metadata["currency"] == "CNY"
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_portfolio_config():
|
||||||
|
"""Test custom portfolio configuration."""
|
||||||
|
# Mock db_session as a context manager
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.__enter__ = MagicMock(return_value=mock_session)
|
||||||
|
mock_session.__exit__ = MagicMock(return_value=None)
|
||||||
|
|
||||||
|
# Mock the database query result
|
||||||
|
mock_session.execute.return_value.fetchone.return_value = None
|
||||||
|
|
||||||
|
# 使用自定义配置
|
||||||
|
with patch("app.utils.portfolio.get_portfolio_config") as mock_config, \
|
||||||
|
patch("app.utils.portfolio.db_session", return_value=mock_session):
|
||||||
|
mock_config.return_value = {
|
||||||
|
"initial_capital": 2000000,
|
||||||
|
"currency": "USD"
|
||||||
|
}
|
||||||
|
|
||||||
|
snapshot = get_latest_snapshot()
|
||||||
|
assert snapshot is not None
|
||||||
|
assert snapshot.total_value == 2000000
|
||||||
|
assert snapshot.cash == 2000000
|
||||||
|
assert snapshot.metadata["initial_capital"] == 2000000
|
||||||
|
assert snapshot.metadata["currency"] == "USD"
|
||||||
Loading…
Reference in New Issue
Block a user