update
This commit is contained in:
parent
16a5fae732
commit
adfc8ee148
@ -9,6 +9,17 @@ from app.utils.db import db_session
|
||||
|
||||
|
||||
SCHEMA_STATEMENTS: Iterable[str] = (
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS fetch_jobs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
job_type TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
error_msg TEXT,
|
||||
metadata TEXT -- JSON object for additional info
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS stock_basic (
|
||||
ts_code TEXT PRIMARY KEY,
|
||||
@ -515,17 +526,32 @@ def _missing_tables() -> List[str]:
|
||||
return [name for name in REQUIRED_TABLES if name not in existing]
|
||||
|
||||
|
||||
def initialize_database() -> MigrationResult:
|
||||
"""Create tables and indexes required by the application."""
|
||||
def initialize_database() -> None:
|
||||
"""Initialize the SQLite database with all required tables."""
|
||||
with db_session() as session:
|
||||
cursor = session.cursor()
|
||||
|
||||
missing = _missing_tables()
|
||||
if not missing:
|
||||
return MigrationResult(executed=0, skipped=True, missing_tables=[])
|
||||
|
||||
executed = 0
|
||||
with db_session() as conn:
|
||||
cursor = conn.cursor()
|
||||
# 创建表
|
||||
for statement in SCHEMA_STATEMENTS:
|
||||
cursor.executescript(statement)
|
||||
executed += 1
|
||||
return MigrationResult(executed=executed, skipped=False, missing_tables=missing)
|
||||
try:
|
||||
cursor.execute(statement)
|
||||
except Exception as e: # noqa: BLE001
|
||||
print(f"初始化数据库时出错: {e}")
|
||||
raise
|
||||
|
||||
# 添加触发器以自动更新 updated_at 字段
|
||||
try:
|
||||
cursor.execute("""
|
||||
CREATE TRIGGER IF NOT EXISTS update_fetch_jobs_timestamp
|
||||
AFTER UPDATE ON fetch_jobs
|
||||
BEGIN
|
||||
UPDATE fetch_jobs
|
||||
SET updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = NEW.id;
|
||||
END;
|
||||
""")
|
||||
except Exception as e: # noqa: BLE001
|
||||
print(f"创建触发器时出错: {e}")
|
||||
raise
|
||||
|
||||
session.commit()
|
||||
|
||||
85
app/ingest/job_logger.py
Normal file
85
app/ingest/job_logger.py
Normal file
@ -0,0 +1,85 @@
|
||||
"""任务记录工具类。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from app.utils.db import db_session
|
||||
|
||||
|
||||
class JobLogger:
|
||||
"""任务记录器。"""
|
||||
|
||||
def __init__(self, job_type: str) -> None:
|
||||
"""初始化任务记录器。
|
||||
|
||||
Args:
|
||||
job_type: 任务类型
|
||||
"""
|
||||
self.job_type = job_type
|
||||
self.job_id: Optional[int] = None
|
||||
|
||||
def __enter__(self) -> "JobLogger":
|
||||
"""开始记录任务。"""
|
||||
with db_session() as session:
|
||||
cursor = session.execute(
|
||||
"""
|
||||
INSERT INTO fetch_jobs (job_type, status, created_at, updated_at)
|
||||
VALUES (?, 'running', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
""",
|
||||
(self.job_type,)
|
||||
)
|
||||
self.job_id = cursor.lastrowid
|
||||
session.commit()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""结束任务记录。"""
|
||||
if exc_val:
|
||||
self.update_status("failed", str(exc_val))
|
||||
else:
|
||||
self.update_status("success")
|
||||
|
||||
def update_status(self, status: str, error_msg: Optional[str] = None) -> None:
|
||||
"""更新任务状态。
|
||||
|
||||
Args:
|
||||
status: 新状态
|
||||
error_msg: 错误信息(如果有)
|
||||
"""
|
||||
if not self.job_id:
|
||||
return
|
||||
|
||||
with db_session() as session:
|
||||
session.execute(
|
||||
"""
|
||||
UPDATE fetch_jobs
|
||||
SET status = ?,
|
||||
error_msg = ?,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = ?
|
||||
""",
|
||||
(status, error_msg, self.job_id)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
def update_metadata(self, metadata: Dict[str, Any]) -> None:
|
||||
"""更新任务元数据。
|
||||
|
||||
Args:
|
||||
metadata: 元数据字典
|
||||
"""
|
||||
if not self.job_id:
|
||||
return
|
||||
|
||||
with db_session() as session:
|
||||
session.execute(
|
||||
"""
|
||||
UPDATE fetch_jobs
|
||||
SET metadata = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(json.dumps(metadata), self.job_id)
|
||||
)
|
||||
session.commit()
|
||||
@ -227,6 +227,9 @@ def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None
|
||||
return merged
|
||||
|
||||
|
||||
from .job_logger import JobLogger
|
||||
|
||||
|
||||
@dataclass
|
||||
class FetchJob:
|
||||
name: str
|
||||
@ -1602,35 +1605,58 @@ def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]
|
||||
|
||||
|
||||
def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
|
||||
LOGGER.info("启动 TuShare 拉取任务:%s", job.name, extra=LOG_EXTRA)
|
||||
try:
|
||||
ensure_data_coverage(
|
||||
job.start,
|
||||
job.end,
|
||||
ts_codes=job.ts_codes,
|
||||
include_limits=include_limits,
|
||||
include_extended=True,
|
||||
force=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
alerts.add_warning("TuShare", f"拉取任务失败:{job.name}", str(exc))
|
||||
raise
|
||||
else:
|
||||
alerts.clear_warnings("TuShare")
|
||||
if job.granularity == "daily":
|
||||
try:
|
||||
"""运行数据拉取任务。
|
||||
|
||||
Args:
|
||||
job: 任务配置
|
||||
include_limits: 是否包含涨跌停数据
|
||||
"""
|
||||
with JobLogger("TuShare数据获取") as logger:
|
||||
LOGGER.info("启动 TuShare 拉取任务:%s", job.name, extra=LOG_EXTRA)
|
||||
|
||||
try:
|
||||
# 拉取基础数据
|
||||
ensure_data_coverage(
|
||||
job.start,
|
||||
job.end,
|
||||
ts_codes=job.ts_codes,
|
||||
include_limits=include_limits,
|
||||
include_extended=True,
|
||||
force=True,
|
||||
)
|
||||
|
||||
# 记录任务元数据
|
||||
logger.update_metadata({
|
||||
"name": job.name,
|
||||
"start": str(job.start),
|
||||
"end": str(job.end),
|
||||
"codes": len(job.ts_codes) if job.ts_codes else 0
|
||||
})
|
||||
|
||||
alerts.clear_warnings("TuShare")
|
||||
|
||||
# 对日线数据计算因子
|
||||
if job.granularity == "daily":
|
||||
LOGGER.info("开始计算因子:%s", job.name, extra=LOG_EXTRA)
|
||||
compute_factor_range(
|
||||
job.start,
|
||||
job.end,
|
||||
ts_codes=job.ts_codes,
|
||||
skip_existing=False,
|
||||
)
|
||||
except Exception as exc:
|
||||
alerts.add_warning("Factors", f"因子计算失败:{job.name}", str(exc))
|
||||
LOGGER.exception("因子计算失败 job=%s", job.name, extra=LOG_EXTRA)
|
||||
raise
|
||||
else:
|
||||
alerts.clear_warnings("Factors")
|
||||
try:
|
||||
compute_factor_range(
|
||||
job.start,
|
||||
job.end,
|
||||
ts_codes=job.ts_codes,
|
||||
skip_existing=False,
|
||||
)
|
||||
alerts.clear_warnings("Factors")
|
||||
except Exception as exc:
|
||||
LOGGER.exception("因子计算失败 job=%s", job.name, extra=LOG_EXTRA)
|
||||
alerts.add_warning("Factors", f"因子计算失败:{job.name}", str(exc))
|
||||
logger.update_status("failed", f"因子计算失败:{exc}")
|
||||
raise
|
||||
LOGGER.info("因子计算完成:%s", job.name, extra=LOG_EXTRA)
|
||||
alerts.clear_warnings("Factors")
|
||||
|
||||
except Exception as exc:
|
||||
LOGGER.exception("数据拉取失败 job=%s", job.name, extra=LOG_EXTRA)
|
||||
alerts.add_warning("TuShare", f"拉取任务失败:{job.name}", str(exc))
|
||||
logger.update_status("failed", f"数据拉取失败:{exc}")
|
||||
raise
|
||||
LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA)
|
||||
|
||||
159
app/ui/portfolio_config.py
Normal file
159
app/ui/portfolio_config.py
Normal file
@ -0,0 +1,159 @@
|
||||
"""Portfolio configuration UI components."""
|
||||
from __future__ import annotations
|
||||
|
||||
import streamlit as st
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from app.utils.portfolio_init import get_portfolio_config, update_portfolio_config
|
||||
|
||||
|
||||
def render_portfolio_config() -> None:
|
||||
"""渲染投资组合配置界面."""
|
||||
st.title("投资组合配置")
|
||||
|
||||
# 获取当前配置
|
||||
config = get_portfolio_config()
|
||||
|
||||
# 基本配置部分
|
||||
st.header("基本配置")
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
initial_capital = st.number_input(
|
||||
"初始投资金额",
|
||||
min_value=100000,
|
||||
max_value=100000000,
|
||||
value=int(config["initial_capital"]),
|
||||
step=100000,
|
||||
format="%d"
|
||||
)
|
||||
|
||||
with col2:
|
||||
currency = st.selectbox(
|
||||
"币种",
|
||||
options=["CNY", "USD", "HKD"],
|
||||
index=["CNY", "USD", "HKD"].index(config["currency"])
|
||||
)
|
||||
|
||||
# 仓位限制配置
|
||||
st.header("仓位限制")
|
||||
position_limits = config["position_limits"]
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
max_position = st.slider(
|
||||
"单个持仓上限",
|
||||
min_value=0.05,
|
||||
max_value=0.50,
|
||||
value=float(position_limits["max_position"]),
|
||||
step=0.01,
|
||||
format="%.2f",
|
||||
help="单个股票最大持仓比例"
|
||||
)
|
||||
|
||||
min_position = st.slider(
|
||||
"单个持仓下限",
|
||||
min_value=0.01,
|
||||
max_value=0.10,
|
||||
value=float(position_limits["min_position"]),
|
||||
step=0.01,
|
||||
format="%.2f",
|
||||
help="单个股票最小持仓比例"
|
||||
)
|
||||
|
||||
with col2:
|
||||
max_total_positions = st.slider(
|
||||
"最大持仓数",
|
||||
min_value=5,
|
||||
max_value=50,
|
||||
value=int(position_limits["max_total_positions"]),
|
||||
step=1,
|
||||
help="投资组合中的最大股票数量"
|
||||
)
|
||||
|
||||
max_sector_exposure = st.slider(
|
||||
"行业敞口上限",
|
||||
min_value=0.20,
|
||||
max_value=0.50,
|
||||
value=float(position_limits["max_sector_exposure"]),
|
||||
step=0.05,
|
||||
format="%.2f",
|
||||
help="单个行业的最大持仓比例"
|
||||
)
|
||||
|
||||
# 配置预览
|
||||
st.header("当前配置概览")
|
||||
df = pd.DataFrame([
|
||||
["初始资金", f"{initial_capital:,} {currency}"],
|
||||
["单个持仓上限", f"{max_position:.1%}"],
|
||||
["单个持仓下限", f"{min_position:.1%}"],
|
||||
["最大持仓数", max_total_positions],
|
||||
["行业敞口上限", f"{max_sector_exposure:.1%}"],
|
||||
], columns=["配置项", "当前值"])
|
||||
|
||||
st.table(df.set_index("配置项"))
|
||||
|
||||
# 保存按钮
|
||||
if st.button("保存配置"):
|
||||
try:
|
||||
update_portfolio_config({
|
||||
"initial_capital": initial_capital,
|
||||
"currency": currency,
|
||||
"position_limits": {
|
||||
"max_position": max_position,
|
||||
"min_position": min_position,
|
||||
"max_total_positions": max_total_positions,
|
||||
"max_sector_exposure": max_sector_exposure
|
||||
}
|
||||
})
|
||||
st.success("配置已更新!")
|
||||
except Exception as e:
|
||||
st.error(f"配置更新失败:{str(e)}")
|
||||
|
||||
# 投资组合限制可视化
|
||||
st.header("仓位限制可视化")
|
||||
|
||||
# 生成示例数据
|
||||
example_positions = np.random.uniform(
|
||||
min_position,
|
||||
max_position,
|
||||
min(max_total_positions, 10)
|
||||
)
|
||||
example_positions = example_positions / example_positions.sum()
|
||||
|
||||
example_sectors = {
|
||||
"科技": 0.30,
|
||||
"金融": 0.25,
|
||||
"消费": 0.20,
|
||||
"医药": 0.15,
|
||||
"其他": 0.10
|
||||
}
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
st.subheader("示例持仓分布")
|
||||
positions_df = pd.DataFrame({
|
||||
"股票": [f"股票{i+1}" for i in range(len(example_positions))],
|
||||
"持仓比例": example_positions
|
||||
})
|
||||
positions_df = positions_df.sort_values("持仓比例", ascending=True)
|
||||
|
||||
st.bar_chart(
|
||||
positions_df.set_index("股票")["持仓比例"],
|
||||
use_container_width=True
|
||||
)
|
||||
|
||||
with col2:
|
||||
st.subheader("示例行业分布")
|
||||
sectors_df = pd.DataFrame({
|
||||
"行业": list(example_sectors.keys()),
|
||||
"敞口": list(example_sectors.values())
|
||||
})
|
||||
|
||||
st.bar_chart(
|
||||
sectors_df.set_index("行业")["敞口"],
|
||||
use_container_width=True
|
||||
)
|
||||
@ -25,6 +25,7 @@ import streamlit as st
|
||||
from app.agents.base import AgentContext
|
||||
from app.agents.game import Decision
|
||||
from app.backtest.engine import BtConfig, run_backtest
|
||||
from app.ui.portfolio_config import render_portfolio_config
|
||||
from app.backtest.decision_env import DecisionEnv, ParameterSpec
|
||||
from app.data.schema import initialize_database
|
||||
from app.ingest.checker import run_boot_check
|
||||
@ -2667,10 +2668,79 @@ def render_tests() -> None:
|
||||
st.write(response)
|
||||
|
||||
|
||||
def render_data_settings() -> None:
|
||||
"""渲染数据源配置界面."""
|
||||
st.subheader("Tushare 数据源")
|
||||
cfg = get_config()
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
tushare_token = st.text_input(
|
||||
"Tushare Token",
|
||||
value=cfg.tushare_token or "",
|
||||
type="password",
|
||||
help="从 tushare.pro 获取的 API token"
|
||||
)
|
||||
|
||||
with col2:
|
||||
auto_update = st.checkbox(
|
||||
"启动时自动更新数据",
|
||||
value=cfg.auto_update_data,
|
||||
help="启动应用时自动检查并更新数据"
|
||||
)
|
||||
|
||||
update_interval = st.slider(
|
||||
"数据更新间隔(天)",
|
||||
min_value=1,
|
||||
max_value=30,
|
||||
value=cfg.data_update_interval,
|
||||
help="自动更新时检查的数据时间范围"
|
||||
)
|
||||
|
||||
if st.button("保存数据源配置"):
|
||||
cfg.tushare_token = tushare_token
|
||||
cfg.auto_update_data = auto_update
|
||||
cfg.data_update_interval = update_interval
|
||||
save_config(cfg)
|
||||
st.success("数据源配置已更新!")
|
||||
|
||||
st.divider()
|
||||
st.subheader("数据更新记录")
|
||||
|
||||
with db_session() as session:
|
||||
df = pd.read_sql_query(
|
||||
"""
|
||||
SELECT job_type, status, created_at, updated_at, error_msg
|
||||
FROM fetch_jobs
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 50
|
||||
""",
|
||||
session
|
||||
)
|
||||
|
||||
if not df.empty:
|
||||
df["duration"] = (df["updated_at"] - df["created_at"]).dt.total_seconds().round(2)
|
||||
df = df.drop(columns=["updated_at"])
|
||||
df = df.rename(columns={
|
||||
"job_type": "数据类型",
|
||||
"status": "状态",
|
||||
"created_at": "开始时间",
|
||||
"error_msg": "错误信息",
|
||||
"duration": "耗时(秒)"
|
||||
})
|
||||
st.dataframe(df, width='stretch')
|
||||
else:
|
||||
st.info("暂无数据更新记录")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA)
|
||||
st.set_page_config(page_title="多智能体个人投资助理", layout="wide")
|
||||
|
||||
# 确保数据库表已创建
|
||||
from app.data.schema import initialize_database
|
||||
initialize_database()
|
||||
|
||||
# 检查是否需要自动更新数据
|
||||
cfg = get_config()
|
||||
if cfg.auto_update_data:
|
||||
@ -2713,7 +2783,18 @@ def main() -> None:
|
||||
with tabs[1]:
|
||||
render_log_viewer()
|
||||
with tabs[2]:
|
||||
render_settings()
|
||||
st.header("系统设置")
|
||||
settings_tabs = st.tabs(["基本配置", "投资组合", "数据源"])
|
||||
|
||||
with settings_tabs[0]:
|
||||
render_settings()
|
||||
|
||||
with settings_tabs[1]:
|
||||
from app.ui.portfolio_config import render_portfolio_config
|
||||
render_portfolio_config()
|
||||
|
||||
with settings_tabs[2]:
|
||||
render_data_settings()
|
||||
with tabs[3]:
|
||||
render_tests()
|
||||
|
||||
|
||||
@ -39,6 +39,18 @@ class DataPaths:
|
||||
self.config_file = config_path
|
||||
|
||||
|
||||
@dataclass
|
||||
class PortfolioSettings:
|
||||
"""Portfolio configuration settings."""
|
||||
|
||||
initial_capital: float = 1000000 # 默认100万
|
||||
currency: str = "CNY" # 默认人民币
|
||||
max_position: float = 0.2 # 单个持仓上限 20%
|
||||
min_position: float = 0.02 # 单个持仓下限 2%
|
||||
max_total_positions: int = 20 # 最大持仓数
|
||||
max_sector_exposure: float = 0.35 # 行业敞口上限 35%
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentWeights:
|
||||
"""Default weighting for decision agents."""
|
||||
@ -340,9 +352,11 @@ class AppConfig:
|
||||
agent_weights: AgentWeights = field(default_factory=AgentWeights)
|
||||
force_refresh: bool = False
|
||||
auto_update_data: bool = False
|
||||
data_update_interval: int = 7 # 数据更新间隔(天)
|
||||
llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers)
|
||||
llm: LLMConfig = field(default_factory=LLMConfig)
|
||||
departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
|
||||
portfolio: PortfolioSettings = field(default_factory=PortfolioSettings)
|
||||
|
||||
def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
|
||||
return self.llm
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from .db import db_session
|
||||
from .logging import get_logger
|
||||
from .portfolio_init import get_portfolio_config
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
LOG_EXTRA = {"stage": "portfolio"}
|
||||
@ -169,8 +170,11 @@ class PortfolioSnapshot:
|
||||
|
||||
|
||||
def get_latest_snapshot() -> Optional[PortfolioSnapshot]:
|
||||
"""Fetch the most recent portfolio snapshot."""
|
||||
"""Fetch the most recent portfolio snapshot.
|
||||
|
||||
Returns:
|
||||
最新的投资组合快照,如果没有数据则返回初始快照(仅包含初始资金)
|
||||
"""
|
||||
sql = """
|
||||
SELECT trade_date, total_value, cash, invested_value, unrealized_pnl,
|
||||
realized_pnl, net_flow, exposure, notes, metadata
|
||||
@ -186,7 +190,22 @@ def get_latest_snapshot() -> Optional[PortfolioSnapshot]:
|
||||
return None
|
||||
|
||||
if not row:
|
||||
return None
|
||||
# 如果没有快照,返回初始状态(只有初始资金)
|
||||
config = get_portfolio_config()
|
||||
initial_capital = config["initial_capital"]
|
||||
return PortfolioSnapshot(
|
||||
trade_date="", # 空日期表示初始状态
|
||||
total_value=initial_capital,
|
||||
cash=initial_capital,
|
||||
invested_value=0.0,
|
||||
unrealized_pnl=0.0,
|
||||
realized_pnl=0.0,
|
||||
net_flow=0.0,
|
||||
exposure=0.0,
|
||||
notes="Initial portfolio state",
|
||||
metadata={"initial_capital": initial_capital, "currency": config["currency"]},
|
||||
)
|
||||
|
||||
return PortfolioSnapshot(
|
||||
trade_date=row["trade_date"],
|
||||
total_value=row["total_value"],
|
||||
|
||||
149
app/utils/portfolio_init.py
Normal file
149
app/utils/portfolio_init.py
Normal file
@ -0,0 +1,149 @@
|
||||
"""Initialize portfolio database tables."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from .logging import get_logger
|
||||
from .config import get_config
|
||||
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
LOG_EXTRA = {"stage": "portfolio_init"}
|
||||
|
||||
|
||||
def get_portfolio_config() -> dict[str, Any]:
|
||||
"""获取投资组合配置.
|
||||
|
||||
Returns:
|
||||
包含以下字段的字典:
|
||||
- initial_capital: 初始投资金额
|
||||
- currency: 货币类型
|
||||
- position_limits: 仓位限制
|
||||
"""
|
||||
config = get_config()
|
||||
settings = config.portfolio if hasattr(config, "portfolio") else None
|
||||
|
||||
if not settings:
|
||||
from .config import PortfolioSettings
|
||||
settings = PortfolioSettings()
|
||||
|
||||
return {
|
||||
"initial_capital": settings.initial_capital,
|
||||
"currency": settings.currency,
|
||||
"position_limits": {
|
||||
"max_position": settings.max_position,
|
||||
"min_position": settings.min_position,
|
||||
"max_total_positions": settings.max_total_positions,
|
||||
"max_sector_exposure": settings.max_sector_exposure
|
||||
}
|
||||
}
|
||||
|
||||
def update_portfolio_config(updates: dict[str, Any]) -> None:
|
||||
"""更新投资组合配置.
|
||||
|
||||
Args:
|
||||
updates: 要更新的配置项字典
|
||||
"""
|
||||
from .config import get_config, save_config, PortfolioSettings
|
||||
|
||||
# 获取当前配置
|
||||
config = get_config()
|
||||
|
||||
# 创建新的投资组合设置
|
||||
portfolio = PortfolioSettings(
|
||||
initial_capital=updates["initial_capital"],
|
||||
currency=updates["currency"],
|
||||
max_position=updates["position_limits"]["max_position"],
|
||||
min_position=updates["position_limits"]["min_position"],
|
||||
max_total_positions=updates["position_limits"]["max_total_positions"],
|
||||
max_sector_exposure=updates["position_limits"]["max_sector_exposure"]
|
||||
)
|
||||
|
||||
# 更新配置
|
||||
config.portfolio = portfolio
|
||||
save_config(config)
|
||||
|
||||
|
||||
|
||||
SCHEMA_STATEMENTS = [
|
||||
# 投资池表
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS investment_pool (
|
||||
trade_date TEXT,
|
||||
ts_code TEXT,
|
||||
score REAL,
|
||||
status TEXT,
|
||||
rationale TEXT,
|
||||
tags TEXT, -- JSON array
|
||||
metadata TEXT, -- JSON object
|
||||
PRIMARY KEY (trade_date, ts_code)
|
||||
)
|
||||
""",
|
||||
|
||||
# 数据获取任务表
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS fetch_jobs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
job_type TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
error_msg TEXT,
|
||||
metadata TEXT -- JSON object for additional info
|
||||
)
|
||||
);
|
||||
""",
|
||||
|
||||
# 持仓表
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS portfolio_positions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
ts_code TEXT NOT NULL,
|
||||
opened_date TEXT NOT NULL,
|
||||
closed_date TEXT,
|
||||
quantity REAL NOT NULL,
|
||||
cost_price REAL NOT NULL,
|
||||
market_price REAL,
|
||||
market_value REAL,
|
||||
realized_pnl REAL DEFAULT 0,
|
||||
unrealized_pnl REAL DEFAULT 0,
|
||||
target_weight REAL,
|
||||
status TEXT NOT NULL DEFAULT 'open',
|
||||
notes TEXT,
|
||||
metadata TEXT -- JSON object
|
||||
);
|
||||
""",
|
||||
|
||||
# 投资组合快照表
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS portfolio_snapshots (
|
||||
trade_date TEXT PRIMARY KEY,
|
||||
total_value REAL,
|
||||
cash REAL,
|
||||
invested_value REAL,
|
||||
unrealized_pnl REAL,
|
||||
realized_pnl REAL,
|
||||
net_flow REAL,
|
||||
exposure REAL,
|
||||
notes TEXT,
|
||||
metadata TEXT -- JSON object
|
||||
);
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
def initialize_database_schema() -> None:
|
||||
"""Create database tables if they don't exist."""
|
||||
from .db import db_session
|
||||
|
||||
with db_session() as conn:
|
||||
for statement in SCHEMA_STATEMENTS:
|
||||
try:
|
||||
conn.execute(statement)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception(
|
||||
"执行 schema 语句失败",
|
||||
extra={"sql": statement, **LOG_EXTRA}
|
||||
)
|
||||
raise
|
||||
57
tests/test_portfolio_config.py
Normal file
57
tests/test_portfolio_config.py
Normal file
@ -0,0 +1,57 @@
|
||||
"""Test portfolio configuration and initialization."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from app.utils.portfolio import get_latest_snapshot
|
||||
from app.utils.db import db_session
|
||||
|
||||
|
||||
def test_default_portfolio_config():
|
||||
"""Test default portfolio configuration."""
|
||||
# Mock db_session as a context manager
|
||||
mock_session = MagicMock()
|
||||
mock_session.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_session.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
# Mock the database query result
|
||||
mock_session.execute.return_value.fetchone.return_value = None
|
||||
|
||||
# 使用默认配置
|
||||
with patch("app.utils.portfolio.get_portfolio_config") as mock_config, \
|
||||
patch("app.utils.portfolio.db_session", return_value=mock_session):
|
||||
mock_config.return_value = {
|
||||
"initial_capital": 1000000,
|
||||
"currency": "CNY"
|
||||
}
|
||||
|
||||
snapshot = get_latest_snapshot()
|
||||
assert snapshot is not None
|
||||
assert snapshot.total_value == 1000000
|
||||
assert snapshot.cash == 1000000
|
||||
assert snapshot.metadata["initial_capital"] == 1000000
|
||||
assert snapshot.metadata["currency"] == "CNY"
|
||||
|
||||
|
||||
def test_custom_portfolio_config():
|
||||
"""Test custom portfolio configuration."""
|
||||
# Mock db_session as a context manager
|
||||
mock_session = MagicMock()
|
||||
mock_session.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_session.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
# Mock the database query result
|
||||
mock_session.execute.return_value.fetchone.return_value = None
|
||||
|
||||
# 使用自定义配置
|
||||
with patch("app.utils.portfolio.get_portfolio_config") as mock_config, \
|
||||
patch("app.utils.portfolio.db_session", return_value=mock_session):
|
||||
mock_config.return_value = {
|
||||
"initial_capital": 2000000,
|
||||
"currency": "USD"
|
||||
}
|
||||
|
||||
snapshot = get_latest_snapshot()
|
||||
assert snapshot is not None
|
||||
assert snapshot.total_value == 2000000
|
||||
assert snapshot.cash == 2000000
|
||||
assert snapshot.metadata["initial_capital"] == 2000000
|
||||
assert snapshot.metadata["currency"] == "USD"
|
||||
Loading…
Reference in New Issue
Block a user