diff --git a/app/features/factors.py b/app/features/factors.py
index 5a6cd3c..81f3e3f 100644
--- a/app/features/factors.py
+++ b/app/features/factors.py
@@ -4,7 +4,8 @@ from __future__ import annotations
 
 import re
+import sqlite3
 from dataclasses import dataclass
 from datetime import datetime, date, timezone
-from typing import Dict, Iterable, List, Optional, Sequence
+from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
 
 from app.core.indicators import momentum, rolling_mean, volatility
 from app.data.schema import initialize_database
@@ -47,15 +47,26 @@ DEFAULT_FACTORS: List[FactorSpec] = [
 
 def compute_factors(
     trade_date: date,
     factors: Iterable[FactorSpec] = DEFAULT_FACTORS,
     *,
     ts_codes: Optional[Sequence[str]] = None,
     skip_existing: bool = False,
+    batch_size: int = 100,
 ) -> List[FactorResult]:
     """Calculate and persist factor values for the requested date.
 
     ``ts_codes`` can be supplied to restrict computation to a subset of the
     universe. When ``skip_existing`` is True, securities that already have an
     entry for ``trade_date`` will be ignored.
+
+    Args:
+        trade_date: Trading date to compute factors for.
+        factors: Factor specifications to evaluate.
+        ts_codes: Optional subset of securities to compute.
+        skip_existing: Skip securities that already have stored values.
+        batch_size: Number of securities processed per batch.
+
+    Returns:
+        A list of factor computation results.
     """
 
     specs = [spec for spec in factors if spec.window >= 0]
@@ -85,18 +96,62 @@ def compute_factors(
         )
         return []
 
+    LOGGER.info(
+        "开始计算因子 universe_size=%s factors=%s trade_date=%s",
+        len(universe),
+        [spec.name for spec in specs],
+        trade_date_str,
+        extra=LOG_EXTRA,
+    )
+
+    # Counters used to track data quality over the run.
+    validation_stats = {
+        "total": len(universe),
+        "skipped": 0,
+        "success": 0,
+        "data_missing": 0,
+        "outliers": 0,
+    }
+
     broker = DataBroker()
     results: List[FactorResult] = []
     rows_to_persist: List[tuple[str, Dict[str, float | None]]] = []
-    for ts_code in universe:
-        values = _compute_security_factors(broker, ts_code, trade_date_str, specs)
-        if not values:
-            continue
-        results.append(FactorResult(ts_code=ts_code, trade_date=trade_date, values=values))
-        rows_to_persist.append((ts_code, values))
+
+    # Process the universe in batches for performance and progress reporting.
+    for i in range(0, len(universe), batch_size):
+        batch = universe[i:i + batch_size]
+        batch_results = _compute_batch_factors(broker, batch, trade_date_str, specs, validation_stats)
+
+        for ts_code, values in batch_results:
+            if values:
+                results.append(FactorResult(ts_code=ts_code, trade_date=trade_date, values=values))
+                rows_to_persist.append((ts_code, values))
+
+        # Log progress every five batches and on completion.
+        processed = min(i + batch_size, len(universe))
+        if processed % (batch_size * 5) == 0 or processed == len(universe):
+            LOGGER.info(
+                "因子计算进度: %s/%s (%.1f%%) 成功:%s 跳过:%s 数据缺失:%s 异常值:%s",
+                processed,
+                len(universe),
+                (processed / len(universe)) * 100,
+                validation_stats["success"],
+                validation_stats["skipped"],
+                validation_stats["data_missing"],
+                validation_stats["outliers"],
+                extra=LOG_EXTRA,
+            )
 
     if rows_to_persist:
         _persist_factor_rows(trade_date_str, rows_to_persist, specs)
+
+    LOGGER.info(
+        "因子计算完成 总数量:%s 成功:%s 失败:%s",
+        len(universe),
+        validation_stats["success"],
+        validation_stats["total"] - validation_stats["success"],
+        extra=LOG_EXTRA,
+    )
+
     return results
@@ -184,17 +239,134 @@ def _list_trade_dates(
     return [row["trade_date"] for row in rows if row["trade_date"]]
 
 
+def _compute_batch_factors(
+    broker: DataBroker,
+    ts_codes: List[str],
+    trade_date: str,
+    specs: Sequence[FactorSpec],
+    validation_stats: Dict[str, int],
+) -> List[tuple[str, Dict[str, float | None]]]:
+    """Compute factor values for a batch of securities, updating the shared stats."""
+    batch_results: List[tuple[str, Dict[str, float | None]]] = []
+
+    for ts_code in ts_codes:
+        try:
+            # Check data availability before doing any factor math.
+            if not _check_data_availability(broker, ts_code, trade_date, specs):
+                validation_stats["data_missing"] += 1
+                continue
+
+            values = _compute_security_factors(broker, ts_code, trade_date, specs)
+
+            if values:
+                # Detect and clamp outliers before accepting the row.
+                values = _detect_and_handle_outliers(values, ts_code)
+                batch_results.append((ts_code, values))
+                validation_stats["success"] += 1
+            else:
+                validation_stats["skipped"] += 1
+        except Exception as e:
+            LOGGER.error(
+                "计算因子失败 ts_code=%s err=%s",
+                ts_code,
+                str(e),
+                extra=LOG_EXTRA,
+            )
+            validation_stats["skipped"] += 1
+
+    return batch_results
+
+
+def _check_data_availability(
+    broker: DataBroker,
+    ts_code: str,
+    trade_date: str,
+    specs: Sequence[FactorSpec],
+) -> bool:
+    """Return True when enough history exists to compute every requested factor."""
+    # The largest factor window determines how much history is required.
+    required_days = 1  # at least the current day's data
+    for spec in specs:
+        if spec.window > required_days:
+            required_days = spec.window
+
+    # Check that the basic fields exist for the trade date.
+    basic_data = broker.fetch_latest(
+        ts_code,
+        trade_date,
+        ["daily.close", "daily_basic.turnover_rate"],
+    )
+
+    # Check that the close series covers the largest window. The explicit
+    # None check avoids treating a zero close as missing data.
+    close_check = broker.fetch_series("daily", "close", ts_code, trade_date, required_days)
+
+    return (
+        basic_data.get("daily.close") is not None
+        and len(close_check) >= required_days
+    )
+
+
+def _detect_and_handle_outliers(
+    values: Dict[str, float | None],
+    ts_code: str,
+) -> Dict[str, float | None]:
+    """Detect and clamp outliers among the computed factor values."""
+    result = values.copy()
+    outliers_found = False
+
+    # Momentum outliers: absolute momentum above 3 is treated as anomalous.
+    for key in [k for k in values if k.startswith("mom_") and values[k] is not None]:
+        value = values[key]
+        if abs(value) > 3.0:
+            LOGGER.debug(
+                "检测到动量因子异常值 ts_code=%s factor=%s value=%.4f",
+                ts_code, key, value,
+                extra=LOG_EXTRA,
+            )
+            # Clamp to the accepted range.
+            result[key] = min(3.0, max(-3.0, value))
+            outliers_found = True
+
+    # Volatility outliers: volatility above 100% is treated as anomalous.
+    for key in [k for k in values if k.startswith("volat_") and values[k] is not None]:
+        value = values[key]
+        if value > 1.0:
+            LOGGER.debug(
+                "检测到波动率因子异常值 ts_code=%s factor=%s value=%.4f",
+                ts_code, key, value,
+                extra=LOG_EXTRA,
+            )
+            # Cap at 100%.
+            result[key] = min(1.0, value)
+            outliers_found = True
+
+    if outliers_found:
+        LOGGER.debug(
+            "处理后因子值 ts_code=%s values=%s",
+            ts_code, {k: f"{v:.4f}" for k, v in result.items() if v is not None},
+            extra=LOG_EXTRA,
+        )
+
+    return result
+
+
 def _compute_security_factors(
     broker: DataBroker,
     ts_code: str,
     trade_date: str,
     specs: Sequence[FactorSpec],
 ) -> Dict[str, float | None]:
+    """Compute factor values for a single security."""
+    # Determine the largest window required for each series.
     close_windows = [spec.window for spec in specs if _factor_prefix(spec.name) in {"mom", "volat"}]
     turnover_windows = [spec.window for spec in specs if _factor_prefix(spec.name) == "turn"]
     max_close_window = max(close_windows) if close_windows else 0
     max_turn_window = max(turnover_windows) if turnover_windows else 0
 
+    # Fetch the required time series.
     close_series = _fetch_series_values(
         broker,
         "daily",
@@ -203,6 +375,12 @@ def _compute_security_factors(
         trade_date,
         max_close_window,
     )
+
+    # Bail out early when no close data is available.
+    if not close_series:
+        LOGGER.debug("缺少收盘价数据 ts_code=%s", ts_code, extra=LOG_EXTRA)
+        return {}
+
     turnover_series = _fetch_series_values(
         broker,
         "daily_basic",
@@ -212,6 +390,7 @@ def _compute_security_factors(
         max_turn_window,
     )
 
+    # Fetch the latest field values.
     latest_fields = broker.fetch_latest(
         ts_code,
         trade_date,
@@ -224,6 +403,7 @@ def _compute_security_factors(
         ],
     )
 
+    # Compute each requested factor.
     results: Dict[str, float | None] = {}
     for spec in specs:
         prefix = _factor_prefix(spec.name)
@@ -258,6 +438,11 @@ def _compute_security_factors(
         ts_code,
         extra=LOG_EXTRA,
     )
+
+    # Treat an all-None result as empty.
+    if not any(v is not None for v in results.values()):
+        return {}
+
     return results
@@ -266,8 +451,14 @@ def _persist_factor_rows(
     rows: Sequence[tuple[str, Dict[str, float | None]]],
     specs: Sequence[FactorSpec],
 ) -> None:
+    """Persist factor rows using batched writes."""
+    if not rows:
+        return
+
     columns = sorted({spec.name for spec in specs})
     timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
+
+    # Prepare the upsert statement once.
     insert_columns = ["ts_code", "trade_date", "updated_at", *columns]
     placeholders = ", ".join(["?"] * len(insert_columns))
     update_clause = ", ".join(
@@ -279,11 +470,55 @@ def _persist_factor_rows(
         f"ON CONFLICT(ts_code, trade_date) DO UPDATE SET {update_clause}"
     )
 
+    # Build the payloads, dropping rows whose values are all None.
+    batch_size = 500  # rows per executemany call
+    batch_payloads = []
+
+    for ts_code, values in rows:
+        if not any(values.get(col) is not None for col in columns):
+            continue
+
+        payload = [ts_code, trade_date, timestamp]
+        payload.extend(values.get(column) for column in columns)
+        batch_payloads.append(payload)
+
+    if not batch_payloads:
+        LOGGER.debug("无可持久化的有效因子数据", extra=LOG_EXTRA)
+        return
+
+    # Write in batches to stay under SQLite's bound-parameter limits.
+    total_inserted = 0
     with db_session() as conn:
-        for ts_code, values in rows:
-            payload = [ts_code, trade_date, timestamp]
-            payload.extend(values.get(column) for column in columns)
-            conn.execute(sql, payload)
+        for i in range(0, len(batch_payloads), batch_size):
+            batch = batch_payloads[i:i + batch_size]
+            try:
+                conn.executemany(sql, batch)
+                total_inserted += len(batch)
+
+                # Log progress every five batches.
+                if (i // batch_size) % 5 == 0:
+                    LOGGER.debug(
+                        "因子数据持久化进度: %s/%s",
+                        min(i + batch_size, len(batch_payloads)),
+                        len(batch_payloads),
+                        extra=LOG_EXTRA,
+                    )
+            except sqlite3.Error as e:
+                LOGGER.error(
+                    "因子数据持久化失败 批次=%s-%s err=%s",
+                    i,
+                    min(i + batch_size, len(batch_payloads)),
+                    str(e),
+                    extra=LOG_EXTRA,
+                )
+
+    LOGGER.info(
+        "因子数据持久化完成 写入记录数=%s 总记录数=%s",
+        total_inserted,
+        len(batch_payloads),
+        extra=LOG_EXTRA,
+    )
 
 
 def _ensure_factor_columns(specs: Sequence[FactorSpec]) -> None:
@@ -322,21 +557,35 @@ def _factor_prefix(name: str) -> str:
 
 
 def _valuation_score(value: object, *, scale: float) -> float:
+    """Normalise a valuation metric into a [0, 1] score."""
     try:
         numeric = float(value)
     except (TypeError, ValueError):
         return 0.0
+
+    # Non-positive valuations are treated as missing.
     if numeric <= 0:
         return 0.0
+
+    # Outlier handling: cap the metric at ten times the scale.
+    max_limit = scale * 10
+    if numeric > max_limit:
+        numeric = max_limit
+
     score = scale / (scale + numeric)
     return max(0.0, min(1.0, score))
 
 
 def _volume_ratio_score(value: object) -> float:
+    """Normalise a volume-ratio metric into a [0, 1] score."""
     try:
         numeric = float(value)
     except (TypeError, ValueError):
         return 0.0
+
+    # Negative ratios are treated as zero; the min() below already caps the
+    # score at 1.0 for any ratio of 10 or more.
     if numeric < 0:
         numeric = 0.0
+
     return max(0.0, min(1.0, numeric / 10.0))
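For reference, a minimal usage sketch of the new `compute_factors` parameters introduced above. The module path and `FactorResult` fields are as in this diff; the trade date and symbols are placeholders, and running it requires the project's database to be initialised and populated:

```python
from datetime import date

from app.features.factors import compute_factors

# Hypothetical call: restrict the run to two symbols, skip rows that already
# exist for the date, and process the universe in batches of 50.
results = compute_factors(
    date(2024, 5, 10),                    # placeholder trade date
    ts_codes=["600000.SH", "000001.SZ"],  # placeholder symbols
    skip_existing=True,
    batch_size=50,
)
for result in results:
    print(result.ts_code, result.values)  # FactorResult fields per the module
```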
diff --git a/app/ingest/rss.py b/app/ingest/rss.py
index a1dcc03..98b10b2 100644
--- a/app/ingest/rss.py
+++ b/app/ingest/rss.py
@@ -4,7 +4,7 @@ from __future__ import annotations
 
 import json
 import re
 import sqlite3
-from dataclasses import dataclass, replace
+from dataclasses import dataclass, field, replace
 from datetime import datetime, timedelta, timezone
 from email.utils import parsedate_to_datetime
 from typing import Dict, Iterable, List, Optional, Sequence, Tuple
@@ -37,33 +37,51 @@ DEFAULT_TIMEOUT = 10.0
 MAX_SUMMARY_LENGTH = 1500
 
 POSITIVE_KEYWORDS: Tuple[str, ...] = (
-    "利好",
-    "增长",
-    "超预期",
-    "创新高",
-    "增持",
-    "回购",
-    "盈利",
-    "strong",
-    "beat",
-    "upgrade",
+    # Chinese positive keywords
+    "利好", "增长", "超预期", "创新高", "增持", "回购", "盈利",
+    "高增长", "业绩好", "优秀", "强劲", "突破", "新高", "上升",
+    "上涨", "反弹", "复苏", "景气", "扩张", "加速", "改善",
+    "提升", "增加", "优化", "利好消息", "超出预期",
+    "盈利超预期", "利润增长", "收入增长", "订单增长", "销量增长",
+    "高景气", "量价齐升", "拐点", "反转", "政策利好", "政策支持",
+    # English positive keywords
+    "strong", "beat", "upgrade", "growth", "positive", "better",
+    "exceed", "surpass", "outperform", "rally", "bullish", "upbeat",
+    "improve", "increase", "rise", "gain", "profit", "earnings",
+    "recovery", "expansion", "boom", "upside", "promising",
 )
 NEGATIVE_KEYWORDS: Tuple[str, ...] = (
-    "利空",
-    "下跌",
-    "亏损",
-    "裁员",
-    "违约",
-    "处罚",
-    "暴跌",
-    "减持",
-    "downgrade",
-    "miss",
+    # Chinese negative keywords
+    "利空", "下跌", "亏损", "裁员", "违约", "处罚", "暴跌", "减持",
+    "业绩差", "下滑", "下降", "恶化", "不及预期", "低于预期",
+    "业绩下滑", "利润下降", "收入下降", "订单减少", "销量减少",
+    "利空消息", "亏损超预期", "利润下滑",
+    "需求萎缩", "量价齐跌", "拐点向下", "政策利空", "政策收紧",
+    "监管收紧", "调查", "违规", "风险", "警示", "预警",
+    "降级", "抛售", "大跌", "下挫", "阴跌",
+    # English negative keywords
+    "downgrade", "miss", "weak", "decline", "negative", "worse",
+    "drop", "fall", "loss", "losses", "slowdown", "contract",
+    "bearish", "pessimistic", "worsen", "decrease", "reduce",
+    "slide", "plunge", "crash", "deteriorate", "risk", "warning",
+    "regulatory", "penalty", "investigation",
 )
 
 A_SH_CODE_PATTERN = re.compile(r"\b(\d{6})(?:\.(SH|SZ))?\b", re.IGNORECASE)
 HK_CODE_PATTERN = re.compile(r"\b(\d{4})\.HK\b", re.IGNORECASE)
 
+# Industry keyword map used to tag news items.
+INDUSTRY_KEYWORDS: Dict[str, List[str]] = {
+    "半导体": ["半导体", "芯片", "集成电路", "IC", "晶圆", "封装", "设计", "制造", "光刻"],
+    "新能源": ["新能源", "光伏", "太阳能", "风电", "风电设备", "锂电池", "储能", "氢能"],
+    "医药": ["医药", "生物制药", "创新药", "医疗器械", "疫苗", "CXO", "CDMO", "CRO"],
+    "消费": ["消费", "食品", "饮料", "白酒", "啤酒", "乳制品", "零食", "零售", "家电"],
+    "科技": ["科技", "人工智能", "AI", "云计算", "大数据", "互联网", "软件", "SaaS"],
+    "金融": ["银行", "保险", "券商", "证券", "金融", "资管", "基金", "投资"],
+    "地产": ["房地产", "地产", "物业", "建筑", "建材", "家居"],
+    "汽车": ["汽车", "新能源汽车", "智能汽车", "自动驾驶", "零部件", "锂电"],
+}
+
 
 @dataclass
 class RssFeedConfig:
@@ -87,7 +105,9 @@ class RssItem:
     published: datetime
     summary: str
     source: str
-    ts_codes: Tuple[str, ...] = ()
+    ts_codes: List[str] = field(default_factory=list)
+    industries: List[str] = field(default_factory=list)  # new: related industries
+    important_keywords: List[str] = field(default_factory=list)  # new: salient keywords
 
 
 DEFAULT_RSS_SOURCES: Tuple[RssFeedConfig, ...] = ()
@@ -260,11 +280,23 @@ def save_news_items(items: Iterable[RssItem]) -> int:
         text_payload = f"{item.title}\n{item.summary}"
         sentiment = _estimate_sentiment(text_payload)
         base_codes = tuple(code for code in item.ts_codes if code)
-        heat = _estimate_heat(item.published, now, len(base_codes), sentiment)
+        # Pass the new signals into the heat estimate.
+        heat = _estimate_heat(
+            item.published,
+            now,
+            len(base_codes),
+            sentiment,
+            text_length=len(text_payload),
+            industry_count=len(item.industries),
+        )
+        # Store the enriched metadata alongside the codes.
         entities = json.dumps(
             {
                 "ts_codes": list(base_codes),
                 "source_url": item.link,
+                "industries": item.industries,
+                "important_keywords": item.important_keywords,
+                "text_length": len(text_payload),
             },
             ensure_ascii=False,
         )
@@ -723,6 +755,7 @@ def _assign_ts_codes(
     base_codes: Sequence[str],
     keywords: Sequence[str],
 ) -> List[str]:
+    """Assign ts codes to a news item, also extracting industries and salient keywords."""
     matches: set[str] = set()
     text = f"{item.title} {item.summary}".lower()
     if keywords:
@@ -736,36 +769,192 @@ def _assign_ts_codes(
     detected = _detect_ts_codes(text)
     matches.update(detected)
+
+    # Tag related industries.
+    item.industries = _detect_industries(text)
+
+    # Extract salient keywords.
+    item.important_keywords = _extract_important_keywords(text)
+
     return [code for code in matches if code]
 
 
+def _detect_industries(text: str) -> List[str]:
+    """Detect industries mentioned in the text via the keyword map."""
+    detected_industries: List[str] = []
+    text_lower = text.lower()
+
+    for industry, keywords in INDUSTRY_KEYWORDS.items():
+        for keyword in keywords:
+            if keyword.lower() in text_lower:
+                if industry not in detected_industries:
+                    detected_industries.append(industry)
+                # One keyword hit is enough to tag an industry.
+                break
+
+    return detected_industries
+
+
+def _extract_important_keywords(text: str) -> List[str]:
+    """Extract salient keywords: sentiment words plus industry words."""
+    important_keywords: List[str] = []
+    text_lower = text.lower()
+
+    # Sentiment keywords.
+    for keyword in POSITIVE_KEYWORDS + NEGATIVE_KEYWORDS:
+        if keyword.lower() in text_lower and keyword not in important_keywords:
+            important_keywords.append(keyword)
+
+    # Industry keywords.
+    for keywords in INDUSTRY_KEYWORDS.values():
+        for keyword in keywords:
+            if keyword.lower() in text_lower and keyword not in important_keywords:
+                important_keywords.append(keyword)
+
+    # Cap the list at ten keywords.
+    return important_keywords[:10]
+
+
 def _detect_ts_codes(text: str) -> List[str]:
+    """Detect stock codes with extra validation to reduce false positives."""
     codes: set[str] = set()
+
+    # A-share codes, optionally suffixed with .SH/.SZ.
     for match in A_SH_CODE_PATTERN.finditer(text):
         digits, suffix = match.groups()
-        if suffix:
-            codes.add(f"{digits}.{suffix.upper()}")
-        else:
-            exchange = "SH" if digits.startswith(tuple("569")) else "SZ"
-            codes.add(f"{digits}.{exchange}")
+        # Skip six-digit numbers that cannot be stock codes.
+        if _is_valid_stock_code(digits, suffix):
+            if suffix:
+                codes.add(f"{digits}.{suffix.upper()}")
+            else:
+                # Infer the exchange from the code range.
+                exchange = "SH" if digits.startswith(tuple("569")) else "SZ"
+                codes.add(f"{digits}.{exchange}")
+
+    # Hong Kong codes.
     for match in HK_CODE_PATTERN.finditer(text):
         digits = match.group(1)
+        # Zero-pad to four digits.
         codes.add(f"{digits.zfill(4)}.HK")
+
+    # Company-name based detection (placeholder for a future mapping table).
+    codes.update(_detect_codes_by_company_name(text))
+
     return sorted(codes)
 
 
+def _is_valid_stock_code(digits: str, suffix: Optional[str]) -> bool:
+    """Validate that a matched number plausibly is a stock code."""
+    if len(digits) != 6:
+        return False
+
+    # With an explicit .SH/.SZ suffix, trust the match.
+    if suffix and suffix.upper() in ("SH", "SZ"):
+        return True
+
+    # Without a suffix, use coarse exchange ranges to filter out other
+    # six-digit numbers:
+    #   SSE: 600000-609999 (A shares), 688000-688999 (STAR Market)
+    #   SZSE: 000001-009999 (main board, incl. the former 002xxx SME board),
+    #         300000-309999 (ChiNext)
+    code_int = int(digits)
+    return (
+        (600000 <= code_int <= 609999)
+        or (688000 <= code_int <= 688999)
+        or (1 <= code_int <= 9999)
+        or (300000 <= code_int <= 309999)
+    )
+
+
+def _detect_codes_by_company_name(text: str) -> List[str]:
+    """Map company names to stock codes.
+
+    This is a stub: a real implementation needs a name-to-code mapping table.
+    It currently returns an empty list and only reserves the extension point.
+    """
+    return []
+
+
 def _estimate_sentiment(text: str) -> float:
+    """Weighted keyword-based sentiment estimate."""
     normalized = text.lower()
-    score = 0
+    score = 0.0
+    positive_matches = 0
+    negative_matches = 0
+
     for keyword in POSITIVE_KEYWORDS:
         if keyword.lower() in normalized:
-            score += 1
+            # Weight each hit by the keyword's strength.
+            score += _get_sentiment_keyword_weight(keyword, positive=True)
+            positive_matches += 1
+
     for keyword in NEGATIVE_KEYWORDS:
         if keyword.lower() in normalized:
-            score -= 1
-    if score == 0:
-        return 0.0
-    return max(-1.0, min(1.0, score / 3.0))
+            score -= _get_sentiment_keyword_weight(keyword, positive=False)
+            negative_matches += 1
+
+    if positive_matches == 0 and negative_matches == 0:
+        # No direct hits: fall back to negation/neutral analysis.
+        return _analyze_neutral_text(normalized)
+
+    # Normalise the score; keep the denominator at least 3 for scaling.
+    max_score = max(3.0, positive_matches + negative_matches)
+    normalized_score = score / max_score
+
+    return max(-1.0, min(1.0, normalized_score))
+
+
+def _get_sentiment_keyword_weight(keyword: str, positive: bool) -> float:
+    """Return a weight reflecting how strong a sentiment keyword is."""
+    base_weight = 1.0
+
+    # Strong sentiment words get extra weight.
+    strong_positive = ("超预期", "超出预期", "盈利超预期", "利好", "upgrade", "beat")
+    strong_negative = ("不及预期", "低于预期", "亏损超预期", "利空", "downgrade", "miss")
+    if keyword in (strong_positive if positive else strong_negative):
+        return base_weight * 1.5
+
+    # Weak sentiment words get reduced weight.
+    weak_positive = ("增长", "改善", "增加", "rise", "increase", "improve")
+    weak_negative = ("下降", "减少", "恶化", "drop", "decrease", "decline")
+    if keyword in (weak_positive if positive else weak_negative):
+        return base_weight * 0.8
+
+    return base_weight
+
+
+def _analyze_neutral_text(text: str) -> float:
+    """Score text that contains no direct sentiment keyword."""
+    # Look for a negation word followed by a sentiment keyword.
+    negation_words = ["不", "非", "无", "未", "没有", "不是", "不会"]
+
+    # Simple negation-pattern check (real use may need proper NLP handling).
+    for neg_word in negation_words:
+        neg_pos = text.find(neg_word)
+        if neg_pos != -1:
+            # Inspect the 30 characters after the negation word.
+            window = text[neg_pos:neg_pos + 30]
+            for pos_word in POSITIVE_KEYWORDS:
+                if pos_word.lower() in window:
+                    return -0.3  # negated positive: mildly negative
+            for neg_kw in NEGATIVE_KEYWORDS:
+                if neg_kw.lower() in window:
+                    return 0.3  # negated negative: mildly positive
+
+    # Neutral wording that leans slightly positive or negative.
+    neutral_positive = ["稳定", "平稳", "正常", "符合预期", "stable", "steady", "normal"]
+    neutral_negative = ["波动", "不确定", "风险", "挑战", "fluctuate", "uncertain", "risk"]
+
+    for word in neutral_positive:
+        if word.lower() in text:
+            return 0.1
+    for word in neutral_negative:
+        if word.lower() in text:
+            return -0.1
+
+    return 0.0
 
 
 def _estimate_heat(
@@ -773,12 +975,41 @@ def _estimate_heat(
     now: datetime,
     code_count: int,
     sentiment: float,
+    text_length: int = 0,
+    source_quality: float = 1.0,
+    industry_count: int = 0,
 ) -> float:
+    """Heat estimate blending recency with coverage, sentiment and content signals."""
+    # Recency (contributes up to 1.0) on a stepped decay curve.
     delta_hours = max(0.0, (now - published).total_seconds() / 3600.0)
-    recency = max(0.0, 1.0 - min(delta_hours, 72.0) / 72.0)
-    coverage_bonus = min(code_count, 3) * 0.05
-    sentiment_bonus = min(abs(sentiment) * 0.1, 0.2)
-    heat = recency + coverage_bonus + sentiment_bonus
+    if delta_hours < 1:
+        recency = 1.0  # within the hour: maximum recency
+    elif delta_hours < 6:
+        recency = 0.8  # 1-6 hours
+    elif delta_hours < 24:
+        recency = 0.6  # 6-24 hours
+    elif delta_hours < 48:
+        recency = 0.3  # 24-48 hours
+    else:
+        recency = 0.1  # older than 48 hours
+
+    # Coverage (up to 0.2): number of securities mentioned.
+    coverage_score = min(code_count / 5, 1.0) * 0.2
+
+    # Sentiment strength (up to 0.15).
+    sentiment_score = min(abs(sentiment), 1.0) * 0.15
+
+    # Content richness (up to 0.1), based on text length.
+    content_score = min(text_length / 1000, 1.0) * 0.1
+
+    # Industry breadth (up to 0.05): multi-industry news tends to matter more.
+    industry_score = min(industry_count / 3, 1.0) * 0.05
+
+    # Combine and scale by source quality (expected range 0.5-1.5); the raw
+    # sum can reach 1.5, so the final clamp below is load-bearing.
+    heat = (recency + coverage_score + sentiment_score + content_score + industry_score) * source_quality
+
+    # Clamp to [0.0, 1.0], rounded to four decimals.
     return max(0.0, min(1.0, round(heat, 4)))
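The weighted sum inside `_estimate_heat` can reach 1.5 before the source-quality factor and the final clamp. Below is a self-contained sketch that mirrors the stepped recency curve and weights from the diff, useful for sanity-checking the bounds outside the app:

```python
# Standalone mirror of the stepped recency curve and weighted sum used by
# _estimate_heat above; not the app code itself.
def tiered_recency(delta_hours: float) -> float:
    if delta_hours < 1:
        return 1.0
    if delta_hours < 6:
        return 0.8
    if delta_hours < 24:
        return 0.6
    if delta_hours < 48:
        return 0.3
    return 0.1

def heat(delta_hours: float, codes: int, sentiment: float,
         text_len: int, industries: int, quality: float = 1.0) -> float:
    raw = (
        tiered_recency(delta_hours)
        + min(codes / 5, 1.0) * 0.2
        + min(abs(sentiment), 1.0) * 0.15
        + min(text_len / 1000, 1.0) * 0.1
        + min(industries / 3, 1.0) * 0.05
    ) * quality
    return max(0.0, min(1.0, round(raw, 4)))

# A fresh, widely-covered, strongly-toned item saturates at 1.0 even before
# any source-quality boost (raw score 1.5); a stale item decays sharply.
print(heat(0.5, 5, 1.0, 1500, 3))   # 1.0
print(heat(36.0, 1, 0.2, 300, 1))   # 0.4167
```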
diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py
index a519186..3e735b9 100644
--- a/app/ui/streamlit_app.py
+++ b/app/ui/streamlit_app.py
@@ -450,6 +450,88 @@ def render_today_plan() -> None:
     ts_code = st.selectbox("标的", symbols, index=default_ts_idx)
     # ADD: batch selection for re-evaluation
     batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[])
+
+    # One-click re-evaluation across every symbol in today's plan.
+    if st.button("一键重评估所有标的", type="primary", use_container_width=True):
+        with st.spinner("正在对所有标的进行重评估,请稍候..."):
+            try:
+                # Parse the trade date, accepting ISO and YYYYMMDD formats.
+                trade_date_obj = None
+                try:
+                    trade_date_obj = date.fromisoformat(str(trade_date))
+                except Exception:
+                    try:
+                        trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date()
+                    except Exception:
+                        pass
+                if trade_date_obj is None:
+                    raise ValueError(f"无法解析交易日:{trade_date}")
+
+                progress = st.progress(0.0)
+                changes_all = []
+                success_count = 0
+                error_count = 0
+
+                # Re-evaluate each symbol in turn.
+                for idx, code in enumerate(symbols, start=1):
+                    try:
+                        # Snapshot agent actions before re-evaluation.
+                        with db_session(read_only=True) as conn:
+                            before_rows = conn.execute(
+                                "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
+                                (trade_date, code),
+                            ).fetchall()
+                        before_map = {row["agent"]: row["action"] for row in before_rows}
+
+                        # Run a single-day simulation to refresh the decision.
+                        cfg = BtConfig(
+                            id="reeval_ui_all",
+                            name="UI All Re-eval",
+                            start_date=trade_date_obj,
+                            end_date=trade_date_obj,
+                            universe=[code],
+                            params={},
+                        )
+                        engine = BacktestEngine(cfg)
+                        state = PortfolioState()
+                        _ = engine.simulate_day(trade_date_obj, state)
+
+                        # Diff the agent actions after re-evaluation.
+                        with db_session(read_only=True) as conn:
+                            after_rows = conn.execute(
+                                "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
+                                (trade_date, code),
+                            ).fetchall()
+                        for row in after_rows:
+                            agent = row["agent"]
+                            new_action = row["action"]
+                            old_action = before_map.get(agent)
+                            if new_action != old_action:
+                                changes_all.append({"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action})
+                        success_count += 1
+                    except Exception:
+                        LOGGER.exception("重评估 %s 失败", code, extra=LOG_EXTRA)
+                        error_count += 1
+
+                    # Update the progress bar.
+                    progress.progress(idx / len(symbols))
+
+                # Report the outcome.
+                if error_count > 0:
+                    st.error(f"一键重评估完成:成功 {success_count} 个,失败 {error_count} 个")
+                else:
+                    st.success(f"一键重评估完成:所有 {success_count} 个标的重评估成功")
+
+                # Show detected action changes.
+                if changes_all:
+                    st.write("检测到以下动作变更:")
+                    st.dataframe(pd.DataFrame(changes_all), hide_index=True, width="stretch")
+
+                # Refresh page data.
+                st.rerun()
+            except Exception as exc:
+                LOGGER.exception("一键重评估失败", extra=LOG_EXTRA)
+                st.error(f"一键重评估执行过程中发生错误:{exc}")
 
     # sync URL params
     _set_query_params(date=str(trade_date), code=str(ts_code))
@@ -891,10 +973,174 @@ def render_today_plan() -> None:
             st.error(f"批量重评估失败:{exc}")
 
 
-def render_backtest() -> None:
-    LOGGER.info("渲染回测页面", extra=LOG_EXTRA)
-    st.header("回测与复盘")
-    st.write("在此运行回测、展示净值曲线与代理贡献。")
+def render_log_viewer() -> None:
+    """Render the log drill-down and history comparison view."""
+    LOGGER.info("渲染日志视图页面", extra=LOG_EXTRA)
+    st.header("日志钻取与历史对比")
+    st.write("查看系统运行日志,支持时间范围筛选、关键词搜索和历史对比功能。")
+
+    # Log time-range pickers.
+    col1, col2 = st.columns(2)
+    with col1:
+        start_date = st.date_input("开始日期", value=date.today() - timedelta(days=7))
+    with col2:
+        end_date = st.date_input("结束日期", value=date.today())
+
+    # Log-level filter.
+    log_levels = ["ALL", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
+    selected_level = st.selectbox("日志级别", log_levels, index=1)
+
+    # Keyword search.
+    search_query = st.text_input("搜索关键词")
+
+    # Stage filter.
+    with db_session(read_only=True) as conn:
+        stages = [row["stage"] for row in conn.execute("SELECT DISTINCT stage FROM run_log").fetchall()]
+    stages = [s for s in stages if s]  # drop empty values
+    stages.insert(0, "ALL")
+    selected_stage = st.selectbox("执行阶段", stages)
+
+    # Query the logs.
+    with st.spinner("加载日志数据中..."):
+        try:
+            with db_session(read_only=True) as conn:
+                query_parts = ["SELECT ts, stage, level, msg FROM run_log WHERE 1=1"]
+                params = []
+
+                # Date-range filter.
+                start_ts = f"{start_date.isoformat()}T00:00:00Z"
+                end_ts = f"{end_date.isoformat()}T23:59:59Z"
+                query_parts.append("AND ts BETWEEN ? AND ?")
+                params.extend([start_ts, end_ts])
+
+                # Level filter.
+                if selected_level != "ALL":
+                    query_parts.append("AND level = ?")
+                    params.append(selected_level)
+
+                # Keyword filter.
+                if search_query:
+                    query_parts.append("AND msg LIKE ?")
+                    params.append(f"%{search_query}%")
+
+                # Stage filter.
+                if selected_stage != "ALL":
+                    query_parts.append("AND stage = ?")
+                    params.append(selected_stage)
+
+                # Newest entries first.
+                query_parts.append("ORDER BY ts DESC")
+
+                query = " ".join(query_parts)
+                rows = conn.execute(query, params).fetchall()
+
+            # Convert to a DataFrame.
+            if rows:
+                # sqlite3.Row objects -> list of dicts.
+                rows_dict = [{key: row[key] for key in row.keys()} for row in rows]
+                log_df = pd.DataFrame(rows_dict)
+                # Format the timestamps.
+                log_df["ts"] = pd.to_datetime(log_df["ts"]).dt.strftime("%Y-%m-%d %H:%M:%S")
+            else:
+                log_df = pd.DataFrame(columns=["ts", "stage", "level", "msg"])
+
+            # Render the log table.
+            st.dataframe(
+                log_df,
+                hide_index=True,
+                width="stretch",
+                column_config={
+                    "ts": st.column_config.TextColumn("时间"),
+                    "stage": st.column_config.TextColumn("执行阶段"),
+                    "level": st.column_config.TextColumn("日志级别"),
+                    "msg": st.column_config.TextColumn("日志消息", width="large"),
+                },
+            )
+
+            # Downloads.
+            if not log_df.empty:
+                csv_data = log_df.to_csv(index=False).encode("utf-8")
+                st.download_button(
+                    label="下载日志CSV",
+                    data=csv_data,
+                    file_name=f"logs_{start_date}_{end_date}.csv",
+                    mime="text/csv",
+                    key="download_logs",
+                )
+
+                json_data = log_df.to_json(orient="records", force_ascii=False, indent=2)
+                st.download_button(
+                    label="下载日志JSON",
+                    data=json_data,
+                    file_name=f"logs_{start_date}_{end_date}.json",
+                    mime="application/json",
+                    key="download_logs_json",
+                )
+        except Exception as e:
+            LOGGER.exception("加载日志失败", extra=LOG_EXTRA)
+            st.error(f"加载日志数据失败:{e}")
+
+    # History comparison.
+    st.subheader("历史对比")
+    st.write("选择两个时间点的日志进行对比分析。")
+
+    col3, col4 = st.columns(2)
+    with col3:
+        compare_date1 = st.date_input("对比日期1", value=date.today() - timedelta(days=1))
+    with col4:
+        compare_date2 = st.date_input("对比日期2", value=date.today())
+
+    if st.button("执行对比", type="secondary"):
+        with st.spinner("执行日志对比分析中..."):
+            try:
+                with db_session(read_only=True) as conn:
+                    # Per-level counts for each of the two dates.
+                    start_ts1 = f"{compare_date1.isoformat()}T00:00:00Z"
+                    end_ts1 = f"{compare_date1.isoformat()}T23:59:59Z"
+                    logs1 = conn.execute(
+                        "SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level",
+                        (start_ts1, end_ts1),
+                    ).fetchall()
+
+                    start_ts2 = f"{compare_date2.isoformat()}T00:00:00Z"
+                    end_ts2 = f"{compare_date2.isoformat()}T23:59:59Z"
+                    logs2 = conn.execute(
+                        "SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level",
+                        (start_ts2, end_ts2),
+                    ).fetchall()
+
+                # Build DataFrames for plotting.
+                df1 = pd.DataFrame(logs1, columns=["level", "count"])
+                df1["date"] = compare_date1.strftime("%Y-%m-%d")
+                df2 = pd.DataFrame(logs2, columns=["level", "count"])
+                df2["date"] = compare_date2.strftime("%Y-%m-%d")
+
+                compare_df = pd.concat([df1, df2])
+
+                # Grouped bar chart of per-level counts.
+                fig = px.bar(
+                    compare_df,
+                    x="level",
+                    y="count",
+                    color="date",
+                    barmode="group",
+                    title=f"日志级别分布对比 ({compare_date1} vs {compare_date2})",
+                )
+                st.plotly_chart(fig, use_container_width=True)
+
+                # Detailed comparison table.
+                st.write("日志统计对比:")
+                # Use hyphen-free date suffixes for column names to avoid Arrow
+                # type-conversion errors.
+                date1_str = compare_date1.strftime("%Y%m%d")
+                date2_str = compare_date2.strftime("%Y%m%d")
+                merged_df = df1.merge(df2, on="level", suffixes=(f"_{date1_str}", f"_{date2_str}"), how="outer").fillna(0)
+                st.dataframe(merged_df, hide_index=True, width="stretch")
+            except Exception as e:
+                LOGGER.exception("日志对比失败", extra=LOG_EXTRA)
+                st.error(f"日志对比分析失败:{e}")
 
     cfg = get_config()
     default_start, default_end = _default_backtest_range(window_days=60)
@@ -1695,7 +1942,15 @@ def render_settings() -> None:
         key=f"provider_default_model_{selected_provider}"
     )
     temp_val = st.number_input("默认温度", value=provider_cfg.default_temperature, min_value=0.0, max_value=2.0, step=0.1, key=temp_key)
-    timeout_val = st.number_input("默认超时(秒)", value=provider_cfg.default_timeout, min_value=1, max_value=300, step=1, key=timeout_key)
+    # Cast to float so st.number_input receives a consistent numeric type.
+    timeout_val = st.number_input(
+        "默认超时(秒)",
+        value=float(provider_cfg.default_timeout) if provider_cfg.default_timeout is not None else 30.0,
+        min_value=1.0,
+        max_value=300.0,
+        step=1.0,
+        key=timeout_key,
+    )
     prompt_template_val = st.text_area("Prompt 模板", value=provider_cfg.prompt_template or "", key=prompt_key)
     enabled_val = st.checkbox("启用", value=provider_cfg.enabled, key=enabled_key)
     mode_val = st.selectbox("模式", options=["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key)
@@ -2365,7 +2619,7 @@ def main() -> None:
     with tabs[0]:
         render_today_plan()
     with tabs[1]:
-        render_backtest()
+        render_log_viewer()
     with tabs[2]:
         render_settings()
     with tabs[3]:
diff --git a/docs/TODO_TRAE.md b/docs/TODO_TRAE.md
new file mode 100644
index 0000000..ad4e686
--- /dev/null
+++ b/docs/TODO_TRAE.md
@@ -0,0 +1,98 @@
+# Pending Development and Optimization Items for the Multi-Agent Personal Investment Assistant
+
+Based on a review of the project's codebase and documentation, the pending development and optimization items are organized as follows:
+
+## 1. UI and Logging Enhancements
+- **Today's plan page**: add a "one-click re-evaluation" entry point, plus log drill-down / history comparison views
+- **Backtest page**: support multi-version experiment management and compare return curves across prompts/temperatures, linked to `tuning_results` records
+- **Streamlit UI polish**: add a real-time metrics panel, exception-log drill-down, and a one-click re-evaluation strategy for the "monitor-only" mode
+- **Department opinion detail page**: surface the newly added `_telemetry` and `_department_telemetry` JSON fields
+
+## 2. Data and Feature Layer
+- **Factor computation module**: `compute_factors()` in `app/features/factors.py` needs further work on the computation and persistence flow
+- **News data sources**: finish the RSS fetching and persistence logic in `app/ingest/rss.py`, strengthening news and sentiment processing
+- **DataBroker hardening**: strengthen fetch validation, caching and fallback strategies so quote/feature backfills are fully automated with less manual patching
+- **Factor set expansion**: grow a lightweight, high-quality factor set around momentum, valuation and liquidity signals, generated entirely programmatically for end-to-end automated decisions
+
+## 3. Decision Optimization and Reinforcement Learning
+- **Action space expansion**: extend `DecisionEnv`'s action space with prompt versions, department temperatures, function-calling strategies, etc.
+- **RL algorithm integration**: introduce bandit/Bayesian optimization or RL algorithms to explore the action space, folding `portfolio_snapshots` and `portfolio_trades` metrics into reward constraints
+- **Real-time data pipeline**: build a real-time positions/trades ingestion path so online monitoring and offline tuning share the same data source
+- **Environment/policy split**: following TradingAgents-CN, split environment from policy, provide training scripts/configs, and emit rich evaluation metrics
+- **Backtest engine**: finish `BacktestEngine`'s fill matching, risk thresholds and metric outputs so backtest signals plug directly into execution
+
+## 4. Testing and Validation
+- **Coverage**: add unit/integration tests for the core paths: department context construction, multi-model calls, and backtest metric generation
+- **Regression cases**: build regression tests for the decision flow so behavior stays reproducible after prompt-template or config changes
+- **Tutorials and examples**: write example notebooks / end-to-end tutorials covering the full "data → backtest → tuning → evaluation" loop
+- **Automated validation pipeline**: set up automated validation for data ingestion, the strategy trunk, and backtest metrics
+
+## 5. Documentation Sync
+- Keep the README and discussion docs updated as features land, so descriptions match the actual implementation
+
+## 6. LLM Collaboration and Configuration
+- **Provider cleanup**: trim the provider list, strengthen the function-calling architecture, and complete fallback and retry strategies
+- **Prompt engineering**: use configurable role prompts and data scopes to make model behavior more controllable
+- **Logging**: record full prompt parameters and decision outputs to support later analysis
+
+## 7. Risk Loop Hardening
+- Adjust fill matching in the backtest engine to uniformly account for position caps, turnover constraints, slippage and fees
+- Finish the `bt_risk_events` table and its persistence path; backtest reports should output risk-event statistics
+- Add turnover, risk events and similar fields to `DecisionEnv`'s episode observations; the default reward should penalize drawdown, risk and turnover
+
+## 8. Other Optimizations
+- **LLM call stability**: improve error handling and retry for LLM calls
+- **Response latency**: optimize data queries and computation to speed up end-to-end responses
+- **Config management**: improve config storage and loading for more flexible management
+
+These items span every layer of the system, from frontend UI to backend data processing and from strategy optimization to testing, and can be scheduled incrementally according to priority and available resources.
+
+## 9. Concrete Recommendations for the Data and Feature Layer
+Based on the current state of the system, detailed recommendations and priorities for the data and feature layer:
+
+### Priority 1: Factor Computation Module
+- **Current issue**: `compute_factors()` in `factors.py` lacks sufficient error handling and boundary checks
+- **Directions**:
+  - Add data-validity checks during factor computation
+  - Implement outlier detection and handling for factor values
+  - Add progress reporting and logging
+  - Improve the batch performance of `_persist_factor_rows`
+- **Expected benefit**: more accurate and stable factor computation, fewer decision errors caused by data-quality problems
+
+### Priority 2: DataBroker Access Layer
+- **Current issue**: the DataBroker class in `data_access.py` lacks an effective fallback when data fetches fail
+- **Directions**:
+  - Implement multi-level caching to reduce duplicate requests
+  - Add automatic retry for failed data requests (see the sketch after this document)
+  - Monitor data-source health
+  - Define data-quality metrics
+- **Expected benefit**: more stable and efficient data access, better resilience to data-source instability
+
+### Priority 3: News Ingestion
+- **Current issue**: the news handling in `rss.py` is simplistic; sentiment analysis and entity extraction are limited
+- **Directions**:
+  - Expand the number and variety of supported RSS sources
+  - Strengthen the sentiment model to improve recognition accuracy
+  - Improve entity extraction to identify stock codes in news more accurately
+  - Implement a news recency scoring mechanism
+- **Expected benefit**: news data that contributes more to investment decisions, richer decision inputs
+
+### Priority 4: Data Integrity Checks
+- **Current issue**: no systematic data-integrity checking
+- **Directions**:
+  - Establish integrity rules and a metric system
+  - Build scheduled data-quality check scripts
+  - Add automatic alerts on anomalous data
+  - Design backfill and repair workflows
+- **Expected benefit**: assured data quality, fewer decisions derailed by data problems
+
+### Priority 5: Factor Library Expansion
+- **Current issue**: the factor types in `DEFAULT_FACTORS` are relatively limited
+- **Directions**:
+  - Research and implement more high-quality technical factors
+  - Develop fundamental factor computation
+  - Design factor combination and weight-optimization algorithms
+  - Build a factor performance evaluation framework
+- **Expected benefit**: richer model inputs, more accurate and diverse decisions
+
+These recommendations can guide the team's work on the data and feature layer; implementing them in priority order will systematically improve the system's data processing and decision support capabilities.
\ No newline at end of file
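Illustrating the retry mechanism proposed under Priority 2 above: a minimal sketch with exponential backoff and jitter. `DataBroker`'s real interface is not shown in this diff, so `with_retry` and `fetch_quote` are hypothetical names:

```python
import random
import time
from functools import wraps

def with_retry(attempts: int = 3, base_delay: float = 0.5):
    """Retry a flaky fetch with exponential backoff and jitter (hypothetical helper)."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            last_exc = None
            for attempt in range(attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as exc:  # real code should catch narrower errors
                    last_exc = exc
                    # Exponential backoff: 0.5s, 1s, 2s, ... plus a little jitter.
                    time.sleep(base_delay * (2 ** attempt) + random.uniform(0, 0.1))
            raise last_exc
        return wrapper
    return decorator

@with_retry(attempts=3)
def fetch_quote(ts_code: str) -> dict:
    # Placeholder for a DataBroker-style fetch that may fail transiently.
    raise TimeoutError(f"upstream timeout for {ts_code}")
```

A production version would layer this under the multi-level cache, so a retry only happens on a genuine cache miss, and would record failures toward the data-source health monitoring also proposed above.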