diff --git a/app/ui/views/backtest.py b/app/ui/views/backtest.py index 8b27923..1b205b4 100644 --- a/app/ui/views/backtest.py +++ b/app/ui/views/backtest.py @@ -381,7 +381,7 @@ def render_backtest_review() -> None: title="风险事件分布", ) agg_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=20)) - st.plotly_chart(agg_fig, use_container_width=True) + st.plotly_chart(agg_fig, width="stretch") except Exception: # noqa: BLE001 LOGGER.debug("绘制风险事件分布失败", extra=LOG_EXTRA) except Exception: # noqa: BLE001 diff --git a/app/ui/views/factor_calculation.py b/app/ui/views/factor_calculation.py index 0c30482..ab787ff 100644 --- a/app/ui/views/factor_calculation.py +++ b/app/ui/views/factor_calculation.py @@ -274,7 +274,7 @@ def render_factor_calculation() -> None: if df_data: df = pd.DataFrame(df_data) - st.dataframe(df.head(100), use_container_width=True) # 只显示前100条 + st.dataframe(df.head(100), width="stretch") # 只显示前100条 st.info(f"共 {len(df_data)} 条因子记录(显示前100条)") else: st.info("没有找到因子计算结果") diff --git a/app/ui/views/market.py b/app/ui/views/market.py index 0ea1778..de6a08d 100644 --- a/app/ui/views/market.py +++ b/app/ui/views/market.py @@ -194,10 +194,10 @@ def render_market_visualization() -> None: ] ) fig.update_layout(title=f"{ts_code} K线图", xaxis_title="日期", yaxis_title="价格") - st.plotly_chart(fig, use_container_width=True) + st.plotly_chart(fig, width="stretch") fig_vol = px.bar(df, x="trade_date", y="vol", title="成交量") - st.plotly_chart(fig_vol, use_container_width=True) + st.plotly_chart(fig_vol, width="stretch") df_ma = df.copy() df_ma["MA5"] = df_ma["close"].rolling(window=5).mean() @@ -205,6 +205,6 @@ def render_market_visualization() -> None: df_ma["MA60"] = df_ma["close"].rolling(window=60).mean() fig_ma = px.line(df_ma, x="trade_date", y=["close", "MA5", "MA20", "MA60"], title="均线对比") - st.plotly_chart(fig_ma, use_container_width=True) + st.plotly_chart(fig_ma, width="stretch") st.dataframe(df, hide_index=True, width='stretch') diff --git a/app/ui/views/stock_eval.py b/app/ui/views/stock_eval.py index e681bea..537c5d2 100644 --- a/app/ui/views/stock_eval.py +++ b/app/ui/views/stock_eval.py @@ -208,7 +208,7 @@ def render_stock_evaluation() -> None: st.dataframe( result_df, hide_index=True, - use_container_width=True + width="stretch" ) # 绘制IC均值分布 @@ -237,7 +237,7 @@ def render_stock_evaluation() -> None: st.dataframe( score_df, hide_index=True, - use_container_width=True + width="stretch" ) # 添加入池功能 @@ -260,7 +260,12 @@ def _calculate_stock_scores( # 标准化权重 weights = np.array(factor_weights) - weights = weights / np.sum(np.abs(weights)) + abs_sum = np.sum(np.abs(weights)) + if abs_sum > 0: # 避免除以零 + weights = weights / abs_sum + else: + # 如果所有权重都是零,则使用均匀分布 + weights = np.ones_like(weights) / len(weights) # 获取所有股票的因子值 stocks = universe or broker.get_all_stocks(eval_date.strftime("%Y%m%d")) diff --git a/app/utils/data_access.py b/app/utils/data_access.py index f75a3de..18794fe 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -1397,6 +1397,83 @@ class DataBroker: self._coverage_cache.clear() LOGGER.info("数据覆盖缓存已清除", extra=LOG_EXTRA) + def get_stock_info(self, ts_code: str, trade_date: str = None) -> Optional[Dict[str, Any]]: + """获取股票基本信息。 + + Args: + ts_code: 股票代码 + trade_date: 交易日期,默认为最新日期 + + Returns: + Dict: 股票基本信息,包含名称、行业等 + """ + if not trade_date: + # 如果没有提供交易日期,使用当前日期 + trade_date = datetime.now().strftime("%Y%m%d") + + try: + # 获取股票基本信息 + info = self.fetch_latest( + ts_code=ts_code, + trade_date=trade_date, + fields=["stock_basic.name", "stock_basic.industry"] + ) + + if not info: + return None + + # 添加股票代码 + result = {"ts_code": ts_code} + result.update(info) + + return result + except Exception as exc: + LOGGER.debug( + "获取股票信息失败 ts_code=%s err=%s", + ts_code, + exc, + extra=LOG_EXTRA + ) + return None + + def fetch_latest_factor(self, ts_code: str, factor: str, eval_date: date) -> Optional[float]: + """获取指定股票的最新因子值。 + + Args: + ts_code: 股票代码 + factor: 因子名称 + eval_date: 评估日期 + + Returns: + float: 因子值,如果获取失败则返回None + """ + trade_date = eval_date.strftime("%Y%m%d") + + try: + # 构建因子字段名称 + factor_field = f"factors.{factor}" + + # 获取因子值 + result = self.fetch_latest( + ts_code=ts_code, + trade_date=trade_date, + fields=[factor_field] + ) + + if not result or factor_field not in result: + return None + + return result[factor_field] + except Exception as exc: + LOGGER.debug( + "获取因子值失败 ts_code=%s factor=%s err=%s", + ts_code, + factor, + exc, + extra=LOG_EXTRA + ) + return None + def get_data_coverage(self, start_date: str, end_date: str) -> Dict: """获取指定日期范围内的数据覆盖情况。