diff --git a/app/backtest/engine.py b/app/backtest/engine.py
index 0c4b4b9..3c2e6ca 100644
--- a/app/backtest/engine.py
+++ b/app/backtest/engine.py
@@ -11,6 +11,7 @@ from app.agents.departments import DepartmentManager
 from app.agents.game import Decision, decide
 from app.llm.metrics import record_decision as metrics_record_decision
 from app.agents.registry import default_agents
+from app.data.schema import initialize_database
 from app.utils.data_access import DataBroker
 from app.utils.config import get_config
 from app.utils.db import db_session
@@ -62,6 +63,7 @@ class BacktestEngine:
             self.weights = weight_config
         else:
             self.weights = {agent.name: 1.0 for agent in self.agents}
+        initialize_database()
         self.department_manager = (
             DepartmentManager(app_cfg) if app_cfg.departments else None
         )
@@ -102,8 +104,6 @@ class BacktestEngine:
             "factors.mom_60",
             "factors.volat_20",
             "factors.turn_20",
-            "news.sentiment_index",
-            "news.heat_score",
         }

         self.required_fields = sorted(base_scope | department_scope)
diff --git a/app/utils/data_access.py b/app/utils/data_access.py
index a7ad04f..f52546f 100644
--- a/app/utils/data_access.py
+++ b/app/utils/data_access.py
@@ -118,12 +118,13 @@ class DataBroker:
         if cached is not None:
             return deepcopy(cached)

-        grouped: Dict[str, List[Tuple[str, str]]] = {}
+        grouped: Dict[str, List[str]] = {}
+        field_map: Dict[Tuple[str, str], List[str]] = {}
         derived_cache: Dict[str, Any] = {}
         results: Dict[str, Any] = {}
         for field_name in field_list:
-            parsed = parse_field_path(field_name)
-            if not parsed:
+            resolved = self.resolve_field(field_name)
+            if not resolved:
                 derived = self._resolve_derived_field(
                     ts_code,
                     trade_date,
@@ -133,8 +134,11 @@
                 if derived is not None:
                     results[field_name] = derived
                 continue
-            table, column = parsed
-            grouped.setdefault(table, []).append((column, field_name))
+            table, column = resolved
+            grouped.setdefault(table, [])
+            if column not in grouped[table]:
+                grouped[table].append(column)
+            field_map.setdefault((table, column), []).append(field_name)

         if not grouped:
             if cache_key is not None and results:
@@ -148,9 +152,10 @@

         try:
             with db_session(read_only=True) as conn:
-                for table, items in grouped.items():
+                for table, columns in grouped.items():
+                    joined_cols = ", ".join(columns)
                     query = (
-                        f"SELECT * FROM {table} "
+                        f"SELECT trade_date, {joined_cols} FROM {table} "
                         "WHERE ts_code = ? AND trade_date <= ? "
" "ORDER BY trade_date DESC LIMIT 1" ) @@ -160,25 +165,22 @@ class DataBroker: LOGGER.debug( "查询失败 table=%s fields=%s err=%s", table, - [column for column, _field in items], + columns, exc, extra=LOG_EXTRA, ) continue if not row: continue - available = row.keys() - for column, original in items: - resolved_column = self._resolve_column_in_row(table, column, available) - if resolved_column is None: - continue - value = row[resolved_column] + for column in columns: + value = row[column] if value is None: continue - try: - results[original] = float(value) - except (TypeError, ValueError): - results[original] = value + for original in field_map.get((table, column), [f"{table}.{column}"]): + try: + results[original] = float(value) + except (TypeError, ValueError): + results[original] = value except sqlite3.OperationalError as exc: LOGGER.debug("数据库只读连接失败:%s", exc, extra=LOG_EXTRA) if cache_key is not None: @@ -696,22 +698,6 @@ class DataBroker: while len(cache) > limit: cache.popitem(last=False) - def _resolve_column_in_row( - self, - table: str, - column: str, - available: Sequence[str], - ) -> Optional[str]: - alias_map = self.FIELD_ALIASES.get(table, {}) - candidate = alias_map.get(column, column) - if candidate in available: - return candidate - lowered = candidate.lower() - for name in available: - if name.lower() == lowered: - return name - return None - def _resolve_column(self, table: str, column: str) -> Optional[str]: columns = self._get_table_columns(table) if columns is None: diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md new file mode 100644 index 0000000..d24a5e4 --- /dev/null +++ b/docs/CHANGELOG.md @@ -0,0 +1,21 @@ +# 变更记录 + +## 2025-09-30 + +- **BacktestEngine 风险闭环强化** + - 调整撮合逻辑,统一考虑仓位上限、换手约束、滑点与手续费。 + - 新增 `bt_risk_events` 表及落库链路,回测报告输出风险事件统计。 + - 效果:回测结果可复盘风险拦截与执行成本,为 LLM 策略调优提供可靠反馈。 + +- **DecisionEnv 风险感知奖励** + - Episode 观测新增换手、风险事件等字段,默认奖励将回撤、风险与换手纳入惩罚项。 + - 效果:强化学习/ Bandit 调参能够权衡收益与风险,符合多智能体自治决策目标。 + +- **Bandit 调参与权重回收工具** + - 新增 `EpsilonGreedyBandit` 与 `run_bandit_optimization.py`,自动记录调参结果。 + - 提供 `apply_best_weights.py` 和 `select_best_tuning_result()`,支持一键回收最优权重并写入配置。 + - 效果:建立起“调参→记录→回收”的闭环,便于持续优化 LLM 多智能体参数。 + +- **DataBroker 取数方式优化** + - `fetch_latest` 改为整行查询后按需取值,避免列缺失导致的异常。 + - 效果:新增因子或字段时无需调整查询逻辑,降低维护成本。 diff --git a/docs/TODO.md b/docs/TODO.md index 072d72a..241e874 100644 --- a/docs/TODO.md +++ b/docs/TODO.md @@ -13,7 +13,6 @@ ## 2. 数据与特征层 - 实现 `app/features/factors.py` 中的 `compute_factors()`,补齐因子计算与持久化流程。 -- DataBroker `fetch_latest` 查询改为读取整行字段,使用时按需取值,避免列缺失导致的异常,后续取数逻辑遵循该约定。 - 完成 `app/ingest/rss.py` 的 RSS 拉取与写库逻辑,打通新闻与情绪数据源。 - 强化 `DataBroker` 的取数校验、缓存与回退策略,确保行情/特征补数统一自动化,减少人工兜底。 - 围绕动量、估值、流动性等核心信号扩展轻量高质量因子集,全部由程序生成,满足端到端自动决策需求。