diff --git a/app/features/extended_factors.py b/app/features/extended_factors.py index f164ec7..94aa1ee 100644 --- a/app/features/extended_factors.py +++ b/app/features/extended_factors.py @@ -8,10 +8,48 @@ end-to-end automated decision-making requirements. from __future__ import annotations from dataclasses import dataclass -from typing import Dict, List, Sequence, Optional +from typing import Dict, List, Sequence, Optional, Any +import functools import numpy as np +from app.utils.logging import get_logger + +LOGGER = get_logger(__name__) +LOG_EXTRA = {"stage": "extended_factors"} + + +def handle_factor_errors(func: Any) -> Any: + """装饰器:处理因子计算过程中的错误 + + Args: + func: 要装饰的函数 + + Returns: + 装饰后的函数 + """ + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Optional[float]: + try: + return func(*args, **kwargs) + except Exception as e: + # 获取因子名称(如果可能) + factor_name = "unknown" + if len(args) > 2 and isinstance(args[1], str): + factor_name = args[1] + elif "factor_name" in kwargs: + factor_name = kwargs["factor_name"] + + LOGGER.error( + "计算因子出错 name=%s error=%s", + factor_name, + str(e), + exc_info=True, + extra=LOG_EXTRA + ) + return None + return wrapper + from app.core.indicators import momentum, rolling_mean, normalize from app.core.technical import ( rsi, macd, bollinger_bands, obv_momentum, price_volume_trend @@ -101,12 +139,21 @@ class ExtendedFactors: ) all_factors = calculator.compute_all_factors(close_series, volume_series) normalized = calculator.normalize_factors(all_factors) + + 属性: + factor_specs: Dict[str, FactorSpec], 因子名称到因子规格的映射 """ def __init__(self): - """初始化因子计算器""" + """初始化因子计算器,构建因子规格映射""" self.factor_specs = {spec.name: spec for spec in EXTENDED_FACTORS} + LOGGER.info( + "初始化因子计算器,加载因子数量: %d", + len(self.factor_specs), + extra=LOG_EXTRA + ) + @handle_factor_errors def compute_factor(self, factor_name: str, close_series: Sequence[float], @@ -114,118 +161,166 @@ class ExtendedFactors: """计算单个因子值 Args: - factor_name: 因子名称 + factor_name: 因子名称,必须是已注册的因子 close_series: 收盘价序列,从新到旧排序 volume_series: 成交量序列,从新到旧排序 Returns: - 因子值,如果计算失败则返回None + factor_value: Optional[float], 计算得到的因子值,失败时返回None + + Raises: + ValueError: 当因子名称未知或数据不足时抛出 """ - try: - spec = self.factor_specs.get(factor_name) - if spec is None: - print(f"Unknown factor: {factor_name}") - return None - - if len(close_series) < spec.window: - return None - - # 技术分析因子 - if factor_name == "tech_rsi_14": - return rsi(close_series, 14) - - elif factor_name == "tech_macd_signal": - _, signal = macd(close_series) - return signal - - elif factor_name == "tech_bb_position": - upper, lower = bollinger_bands(close_series, 20) - pos = (close_series[0] - lower) / (upper - lower + 1e-8) - return pos - - elif factor_name == "tech_obv_momentum": - return obv_momentum(close_series, volume_series, 20) - - elif factor_name == "tech_pv_trend": - return price_volume_trend(close_series, volume_series, 20) + spec = self.factor_specs.get(factor_name) + if spec is None: + raise ValueError(f"未知的因子名称: {factor_name}") - # 趋势跟踪因子 - elif factor_name == "trend_ma_cross": - ma_5 = rolling_mean(close_series, 5) - ma_20 = rolling_mean(close_series, 20) - return ma_5 - ma_20 + if len(close_series) < spec.window: + raise ValueError( + f"数据长度不足: 需要{spec.window},实际{len(close_series)}" + ) + + # 技术分析因子 + if factor_name == "tech_rsi_14": + return rsi(close_series, 14) - # 波动率预测因子 - elif factor_name == "vol_garch": - return garch_volatility(close_series, 20) - - elif factor_name == "vol_regime": - regime, _ = volatility_regime(close_series, volume_series, 20) - return regime - - # 量价联合因子 - elif factor_name == "volume_price_corr": - return volume_price_correlation(close_series, volume_series, 20) - - # 增强动量因子 - elif factor_name == "momentum_adaptive": - return adaptive_momentum(close_series, volume_series, 20) - - elif factor_name == "momentum_regime": - return momentum_regime(close_series, volume_series, 20) - - elif factor_name == "momentum_quality": - return momentum_quality(close_series, 20) - - # 均线比率因子 - elif factor_name.endswith("_ratio"): - if "price_ma" in factor_name: - window = int(factor_name.split("_")[2]) - ma = rolling_mean(close_series, window) - return close_series[0] / ma if ma > 0 else None - - elif "volume_ma" in factor_name: - window = int(factor_name.split("_")[2]) - ma = rolling_mean(volume_series, window) - return volume_series[0] / ma if ma > 0 else None + elif factor_name == "tech_macd_signal": + _, signal = macd(close_series) + return signal - return None + elif factor_name == "tech_bb_position": + upper, lower = bollinger_bands(close_series, 20) + pos = (close_series[0] - lower) / (upper - lower + 1e-8) + return pos - except Exception as e: - print(f"Error computing factor {factor_name}: {str(e)}") - return None + elif factor_name == "tech_obv_momentum": + return obv_momentum(close_series, volume_series, 20) + + elif factor_name == "tech_pv_trend": + return price_volume_trend(close_series, volume_series, 20) + + # 趋势跟踪因子 + elif factor_name == "trend_ma_cross": + ma_5 = rolling_mean(close_series, 5) + ma_20 = rolling_mean(close_series, 20) + return ma_5 - ma_20 + + # 波动率预测因子 + elif factor_name == "vol_garch": + return garch_volatility(close_series, 20) + + elif factor_name == "vol_regime": + regime, _ = volatility_regime(close_series, volume_series, 20) + return regime + + # 量价联合因子 + elif factor_name == "volume_price_corr": + return volume_price_correlation(close_series, volume_series, 20) + + # 增强动量因子 + elif factor_name == "momentum_adaptive": + return adaptive_momentum(close_series, volume_series, 20) + + elif factor_name == "momentum_regime": + return momentum_regime(close_series, volume_series, 20) + + elif factor_name == "momentum_quality": + return momentum_quality(close_series, 20) + + # 均线比率因子 + elif factor_name.endswith("_ratio"): + if "price_ma" in factor_name: + window = int(factor_name.split("_")[2]) + ma = rolling_mean(close_series, window) + return close_series[0] / ma if ma > 0 else None + + elif "volume_ma" in factor_name: + window = int(factor_name.split("_")[2]) + ma = rolling_mean(volume_series, window) + return volume_series[0] / ma if ma > 0 else None + + raise ValueError(f"因子 {factor_name} 没有对应的计算实现") def compute_all_factors(self, close_series: Sequence[float], volume_series: Sequence[float]) -> Dict[str, float]: - """计算所有扩展因子值 + """计算所有已注册的扩展因子值 Args: close_series: 收盘价序列,从新到旧排序 volume_series: 成交量序列,从新到旧排序 Returns: - 因子名称到因子值的映射字典 + Dict[str, float]: 因子名称到因子值的映射字典, + 只包含成功计算的因子值 + + Note: + 该方法会尝试计算所有已注册的因子,失败的因子将被忽略。 + 如果所有因子计算都失败,将返回空字典。 """ results = {} + success_count = 0 + total_count = len(self.factor_specs) for factor_name in self.factor_specs: value = self.compute_factor(factor_name, close_series, volume_series) if value is not None: results[factor_name] = value + success_count += 1 + LOGGER.info( + "因子计算完成 total=%d success=%d failed=%d", + total_count, + success_count, + total_count - success_count, + extra=LOG_EXTRA + ) + return results - def normalize_factors(self, factors: Dict[str, float]) -> Dict[str, float]: + + def normalize_factors(self, + factors: Dict[str, float], + clip_threshold: float = 3.0) -> Dict[str, float]: """标准化因子值到[-1,1]区间 Args: factors: 原始因子值字典 + clip_threshold: float, 标准化时的截断阈值,默认为3.0 Returns: - 标准化后的因子值字典 + Dict[str, float]: 标准化后的因子值字典, + 只包含成功标准化的因子值 + + Note: + 标准化过程包括: + 1. Z-score标准化 + 2. 使用tanh压缩到[-1,1]区间 + 3. 异常值处理(截断) """ results = {} + success_count = 0 + for name, value in factors.items(): if value is not None: - results[name] = normalize(value) + try: + normalized = normalize(value, clip_threshold) + if not np.isnan(normalized): + results[name] = normalized + success_count += 1 + except Exception as e: + LOGGER.warning( + "因子标准化失败 name=%s error=%s", + name, + str(e), + extra=LOG_EXTRA + ) + + LOGGER.info( + "因子标准化完成 total=%d success=%d failed=%d", + len(factors), + success_count, + len(factors) - success_count, + extra=LOG_EXTRA + ) + return results \ No newline at end of file diff --git a/app/ingest/entity_recognition.py b/app/ingest/entity_recognition.py new file mode 100644 index 0000000..b47fda5 --- /dev/null +++ b/app/ingest/entity_recognition.py @@ -0,0 +1,145 @@ +"""Stock code mapping and entity recognition utilities.""" +from __future__ import annotations + +import re +from typing import Dict, List, Optional, Set, Tuple + +# 股票代码正则表达式 +A_SH_CODE_PATTERN = re.compile(r"\b(\d{6})(\.(?:SH|SZ))?\b", re.IGNORECASE) +HK_CODE_PATTERN = re.compile(r"\b(\d{4})\.HK\b", re.IGNORECASE) + +def normalize_stock_code(code: str, explicit_market: str = None) -> str: + """规范化股票代码格式. + + Args: + code: 原始股票代码 + explicit_market: 显式指定的市场,如 'SH' 或 'SZ' + + Returns: + 标准格式的股票代码,如 '000001.SZ' + """ + if '.' in code: + return code.upper() + + if explicit_market: + return f"{code}.{explicit_market.upper()}" + + # 根据代码规则判断市场 + if code.startswith('6'): + return f"{code}.SH" + elif code.startswith(('0', '3')): + return f"{code}.SZ" + else: + return f"{code}.SH" # 默认使用上交所 + +# 公司名称变体模式 +COMPANY_SUFFIXES = ["股份", "科技", "公司", "集团", "股份有限公司", "有限公司"] + +class CompanyNameMapper: + """Map company names to stock codes with fuzzy matching.""" + + def __init__(self): + self.name_to_code: Dict[str, str] = {} # 完整名称到代码映射 + self.short_names: Dict[str, str] = {} # 简称到代码映射 + self.aliases: Dict[str, str] = {} # 别名到代码映射 + + def add_company(self, ts_code: str, full_name: str, short_name: str, aliases: List[str] = None): + """Add a company to the mapping. + + Args: + ts_code: Stock code in format like '000001.SZ' + full_name: Full registered company name + short_name: Official short name + aliases: List of alternative names + """ + # 存储完整名称映射 + self.name_to_code[full_name] = ts_code + + # 存储简称映射 + self.short_names[short_name] = ts_code + + # 生成和存储名称变体 + name_variants = self._generate_name_variants(full_name) + for variant in name_variants: + if variant not in self.aliases: + self.aliases[variant] = ts_code + + # 存储额外的别名 + if aliases: + for alias in aliases: + if alias not in self.aliases: + self.aliases[alias] = ts_code + + def _generate_name_variants(self, full_name: str) -> Set[str]: + """Generate possible variants of a company name.""" + variants = set() + + # 仅移除整个公司类型后缀 + for suffix in COMPANY_SUFFIXES: + if full_name.endswith(suffix): + variant = full_name[:-len(suffix)].strip() + if len(variant) > 2: # 避免太短的变体 + variants.add(variant) + break + + return variants + + def find_codes(self, text: str) -> List[Tuple[str, str, str]]: + """Find company mentions and corresponding stock codes in text. + + Returns: + List of tuples (matched_text, stock_code, match_type) + where match_type is one of 'code', 'full_name', 'short_name', 'alias' + """ + matches = [] + + # 1. 查找直接的股票代码 + for match in A_SH_CODE_PATTERN.finditer(text): + code = match.group(1) + explicit_market = match.group(2)[1:] if match.group(2) else None + ts_code = normalize_stock_code(code, explicit_market) + matches.append((match.group(), ts_code, 'code')) + + for match in HK_CODE_PATTERN.finditer(text): + ts_code = match.group() + matches.append((match.group(), ts_code, 'code')) + + # 2. 按优先级顺序查找公司名称 + # 完整名称优先级最高 + for name, code in self.name_to_code.items(): + if name in text: + matches.append((name, code, 'full_name')) + + # 其次是简称 + for name, code in self.short_names.items(): + if name in text: + matches.append((name, code, 'short_name')) + + # 最后是别名 + for alias, code in self.aliases.items(): + if alias in text: + matches.append((alias, code, 'alias')) + + return matches + +# 创建全局单例实例 +company_mapper = CompanyNameMapper() + +def initialize_company_mapping(db_connection) -> None: + """从数据库加载公司名称映射. + + Args: + db_connection: SQLite数据库连接 + """ + cursor = db_connection.cursor() + cursor.execute(""" + SELECT ts_code, name, short_name + FROM stock_company + WHERE name IS NOT NULL + """) + + for ts_code, name, short_name in cursor.fetchall(): + if name and short_name: + company_mapper.add_company(ts_code, name, short_name) + + cursor.close() diff --git a/app/ingest/rss.py b/app/ingest/rss.py index 98b10b2..fdf9756 100644 --- a/app/ingest/rss.py +++ b/app/ingest/rss.py @@ -18,6 +18,8 @@ import hashlib import random import time +from app.ingest.entity_recognition import company_mapper, initialize_company_mapping + try: # pragma: no cover - optional dependency at runtime import feedparser # type: ignore[import-not-found] except ImportError: # pragma: no cover - graceful fallback @@ -95,6 +97,15 @@ class RssFeedConfig: max_items: int = 50 +@dataclass +class StockMention: + """A mention of a stock in text.""" + matched_text: str + ts_code: str + match_type: str # 'code', 'full_name', 'short_name', 'alias' + context: str # 相关的上下文片段 + confidence: float # 匹配的置信度 + @dataclass class RssItem: """Structured representation of an RSS entry.""" @@ -106,8 +117,214 @@ class RssItem: summary: str source: str ts_codes: List[str] = field(default_factory=list) - industries: List[str] = field(default_factory=list) # 新增:相关行业列表 - important_keywords: List[str] = field(default_factory=list) # 新增:重要关键词列表 + stock_mentions: List[StockMention] = field(default_factory=list) + industries: List[str] = field(default_factory=list) + important_keywords: List[str] = field(default_factory=list) + + def __post_init__(self): + """Initialize company mapper if not already initialized.""" + # 测试环境下跳过数据库初始化 + if not hasattr(self, '_skip_db_init'): # 仅在非测试环境下初始化 + from app.utils.db import db_session + + # 如果company_mapper还没有数据,初始化它 + if not company_mapper.name_to_code: + with db_session() as conn: + initialize_company_mapping(conn) + + def extract_entities(self) -> None: + """Extract and validate entity mentions from title and summary.""" + # 分别处理标题和摘要 + title_matches = company_mapper.find_codes(self.title) + summary_matches = company_mapper.find_codes(self.summary) + + # 按优先级合并去重后的匹配 + code_best_matches = {} # ts_code -> (matched_text, match_type, is_title, context) + + # 优先级顺序: 代码 > 全称 > 简称 > 别名 + priority = {'code': 0, 'full_name': 1, 'short_name': 2, 'alias': 3} + + for matches, text, is_title in [(title_matches, self.title, True), + (summary_matches, self.summary, False)]: + for matched_text, ts_code, match_type in matches: + # 提取上下文 + context = self._extract_context(text, matched_text) + + # 如果是新代码或优先级更高的匹配 + if (ts_code not in code_best_matches or + priority[match_type] < priority[code_best_matches[ts_code][1]]): + code_best_matches[ts_code] = (matched_text, match_type, is_title, context) + + # 创建股票提及列表 + for ts_code, (matched_text, match_type, is_title, context) in code_best_matches.items(): + confidence = self._calculate_confidence(match_type, matched_text, context, is_title) + + mention = StockMention( + matched_text=matched_text, + ts_code=ts_code, + match_type=match_type, + context=context, + confidence=confidence + ) + self.stock_mentions.append(mention) + + # 更新ts_codes列表,只包含高置信度的匹配 + self.ts_codes = list(set( + mention.ts_code + for mention in self.stock_mentions + if mention.confidence > 0.7 # 只保留高置信度的匹配 + )) + + # 提取行业关键词 + self.extract_industries() + + # 提取重要关键词 + self.extract_important_keywords() + + def _extract_context(self, text: str, matched_text: str) -> str: + """提取匹配文本的上下文,尽量提取完整的句子.""" + # 找到匹配文本的位置 + start_pos = text.find(matched_text) + if start_pos == -1: + return "" + + # 向前找到句子开始(句号、问号、感叹号或换行符之后) + sent_start = start_pos + while sent_start > 0: + if text[sent_start-1] in '。?!\n': + break + sent_start -= 1 + + # 向后找到句子结束 + sent_end = start_pos + len(matched_text) + while sent_end < len(text): + if text[sent_end] in '。?!\n': + sent_end += 1 + break + sent_end += 1 + + # 如果上下文太长,则截取固定长度 + context = text[sent_start:sent_end].strip() + if len(context) > 100: # 最大上下文长度 + start = max(0, start_pos - 30) + end = min(len(text), start_pos + len(matched_text) + 30) + context = text[start:end].strip() + + return context + + def extract_industries(self) -> None: + """从新闻标题和摘要中提取行业关键词.""" + content = f"{self.title} {self.summary}".lower() + found_industries = set() + + # 对每个行业检查其关键词 + for industry, keywords in INDUSTRY_KEYWORDS.items(): + # 如果找到任意关键词,认为属于该行业 + if any(keyword.lower() in content for keyword in keywords): + found_industries.add(industry) + + self.industries = list(found_industries) + + def extract_important_keywords(self) -> None: + """提取重要关键词,包括积极/消极情感词和特定事件.""" + content = f"{self.title} {self.summary}".lower() + found_keywords = set() + + # 1. 检查积极关键词 + for keyword in POSITIVE_KEYWORDS: + if keyword.lower() in content: + found_keywords.add(f"+{keyword}") # 加前缀表示积极 + + # 2. 检查消极关键词 + for keyword in NEGATIVE_KEYWORDS: + if keyword.lower() in content: + found_keywords.add(f"-{keyword}") # 加前缀表示消极 + + # 3. 检查特定事件关键词 + event_keywords = { + # 公司行为 + "收购": "M&A", + "并购": "M&A", + "重组": "重组", + "分拆": "分拆", + "上市": "IPO", + # 财务事件 + "业绩": "业绩", + "亏损": "业绩预警", + "盈利": "业绩预增", + "分红": "分红", + "回购": "回购", + # 监管事件 + "立案": "监管", + "调查": "监管", + "问询": "监管", + "处罚": "处罚", + # 重大项目 + "中标": "中标", + "签约": "签约", + "战略合作": "合作", + } + + for trigger, event in event_keywords.items(): + if trigger in content: + found_keywords.add(f"#{event}") # 加前缀表示事件 + + self.important_keywords = list(found_keywords) + + def _calculate_confidence(self, match_type: str, matched_text: str, context: str, is_title: bool = False) -> float: + """计算实体匹配的置信度. + + 考虑以下因素: + 1. 匹配类型的基础置信度 + 2. 实体在文本中的位置(标题/开头更重要) + 3. 上下文关键词 + 4. 股票相关动词 + 5. 实体的完整性 + """ + # 基础置信度 + base_confidence = { + 'code': 0.9, # 直接的股票代码匹配 + 'full_name': 0.85,# 完整公司名称匹配 + 'short_name': 0.7,# 公司简称匹配 + 'alias': 0.6 # 别名匹配 + }.get(match_type, 0.5) + + confidence = base_confidence + context_lower = context.lower() + + # 1. 位置加权 + if is_title: + confidence += 0.1 + if context.startswith(matched_text): + confidence += 0.05 + + # 2. 实体完整性检查 + if match_type == 'code' and '.' in matched_text: # 完整股票代码(带市场后缀) + confidence += 0.05 + elif match_type == 'full_name' and any(suffix in matched_text for suffix in ["股份有限公司", "有限公司"]): + confidence += 0.05 + + # 3. 上下文关键词 + context_bonus = 0.0 + corporate_terms = ["公司", "集团", "企业", "上市", "控股", "总部"] + if any(term in context_lower for term in corporate_terms): + context_bonus += 0.1 + + # 4. 股票相关动词 + stock_verbs = ["发布", "公告", "披露", "表示", "报告", "投资", "回购", "增持", "减持"] + if any(verb in context_lower for verb in stock_verbs): + context_bonus += 0.05 + + # 5. 财务/业务相关词汇 + business_terms = ["业绩", "营收", "利润", "股价", "市值", "经营", "产品", "服务", "战略"] + if any(term in context_lower for term in business_terms): + context_bonus += 0.05 + + # 限制上下文加成的最大值 + confidence += min(context_bonus, 0.2) + + # 确保置信度在0-1之间 + return min(1.0, max(0.0, confidence)) DEFAULT_RSS_SOURCES: Tuple[RssFeedConfig, ...] = () @@ -255,7 +472,7 @@ def _fetch_feed_items( def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]: - """Drop duplicate stories by link/id fingerprint.""" + """Drop duplicate stories by link/id fingerprint and process entities.""" seen = set() unique: List[RssItem] = [] @@ -264,7 +481,14 @@ def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]: if key in seen: continue seen.add(key) - unique.append(item) + + # 提取实体和相关信息 + item.extract_entities() + + # 如果找到了相关股票,则保留这条新闻 + if item.stock_mentions: + unique.append(item) + return unique diff --git a/docs/TODO_UNIFIED.md b/docs/TODO_UNIFIED.md new file mode 100644 index 0000000..4820eb1 --- /dev/null +++ b/docs/TODO_UNIFIED.md @@ -0,0 +1,106 @@ +# LLM量化交易助理系统开发计划 + +> 项目愿景:开发一个可实战的投资助理工具,其业务水平要处在投资的前列。核心是通过多智能体协作提供高质量的投资决策支持。 + +> 开发进度(2025-10-05): +> ✓ 基础因子计算框架 +> ✓ 数据访问与监控 +> ✓ 核心回测系统 +> ✓ LLM基础集成 +> △ RSS新闻处理 +> △ UI与监控系统 + +## 一、核心功能模块优先级排序 + +### 1. 数据与特征层(P0) +#### 1.1 因子计算模块优化 +- [x] 完善 `compute_factors()` 函数实现: + - [x] 添加数据有效性校验机制 + - [x] 实现异常值检测与处理逻辑 + - [ ] 增加计算进度显示和日志记录 + - [ ] 优化因子持久化性能 + - [ ] 支持增量计算模式 + +#### 1.2 DataBroker增强 +- [x] 开发数据请求失败的自动重试机制 +- [x] 增加数据源健康状态监控 +- [ ] 设计数据质量评估指标系统 + +#### 1.3 因子库扩展 +- [x] 扩展动量类因子群 +- [x] 开发估值类因子群 +- [x] 设计流动性因子群 +- [ ] 构建市场情绪因子群 +- [ ] 开发因子组合和权重优化算法 + +#### 1.4 新闻数据源完善 +- [x] 完成RSS数据获取和解析 +- [x] 增强情感分析能力 +- [ ] 改进实体识别准确率 +- [ ] 实现新闻时效性评分 + +### 2. 决策优化(P1) +#### 2.1 决策环境增强 +- [x] 扩展DecisionEnv动作空间: + - [x] 支持提示版本选择 + - [x] 允许调节部门温度 + - [ ] 优化function调用策略 +- [x] 增加环境观测维度: + - [x] 加入换手率指标 + - [x] 纳入风险事件统计 + - [ ] 补充市场情绪指标 + +#### 2.2 回测系统完善 +- [x] 优化成交撮合逻辑: + - [x] 统一仓位限制 + - [x] 考虑换手约束 + - [x] 加入滑点模拟 + - [x] 计算交易成本 +- [x] 完善风险控制: + - [x] 实现止损机制 + - [x] 添加波动率限制 + - [x] 设置集中度控制 + +### 3. LLM协同(P1) +- [x] 精简和优化Provider管理 +- [x] 增强function-calling架构 +- [x] 完善错误处理和重试策略 +- [ ] 优化提示工程: + - [ ] 设计配置化角色提示 + - [ ] 优化数据范围控制 + - [ ] 改进上下文管理 + +### 4. UI与监控(P2) +#### 4.1 功能增强 +- [ ] 实现"一键重评估"功能 +- [ ] 开发多版本实验对比 +- [x] 添加实时指标面板 +- [ ] 设计异常日志钻取功能 + +#### 4.2 监控增强 +- [ ] 开发"仅监控不干预"模式 +- [x] 实现策略实时评估 +- [x] 添加风险预警功能 +- [ ] 设计绩效归因分析 + +### 5. 测试与部署(P2) +- [x] 补充核心路径单元测试 +- [ ] 建立端到端集成测试 +- [ ] 完善日志收集机制 + +## 二、近期开发重点(2025 Q4) + +1. ✓ 完成因子计算模块的优化和重构 +2. ✓ 实现基础因子库的扩展 +3. ✓ 优化DataBroker的数据访问性能 +4. △ 完善RSS新闻数据源的接入 +5. ✓ 开始着手决策环境的增强 + +## 三、开发原则 + +1. 保持简单:每个模块只实现最核心的功能 +2. 重视可靠性:核心功能必须稳定可靠 +3. 易于使用:交互界面简单直观 +4. 容错设计:关键节点预留人工介入的可能 + +> 注:此计划将根据实际使用体验持续优化,始终保持简单实用的原则。上次更新:2025-10-05 diff --git a/tests/test_entity_recognition.py b/tests/test_entity_recognition.py new file mode 100644 index 0000000..2bd937d --- /dev/null +++ b/tests/test_entity_recognition.py @@ -0,0 +1,110 @@ +"""Test improved entity recognition in RSS processing.""" +from datetime import datetime, timezone +import pytest + +from app.ingest.entity_recognition import CompanyNameMapper, company_mapper +from app.ingest.rss import RssItem, StockMention + +def test_company_name_mapper(): + mapper = CompanyNameMapper() + + # 添加测试公司 + mapper.add_company( + ts_code="000001.SZ", + full_name="平安银行股份有限公司", + short_name="平安银行", + aliases=["平安", "PAB"] + ) + + # 测试名称匹配,会找到所有可能的匹配 + matches = mapper.find_codes("平安银行股份有限公司公布2025年业绩") + assert len([m for m in matches if m[1] == "000001.SZ"]) >= 1 + # 确保有一个全称匹配 + assert any(m[2] == "full_name" and m[1] == "000001.SZ" for m in matches) + + # 测试简称匹配 + matches = mapper.find_codes("平安银行发布公告") + # 应该找到简称匹配,可能还有其他匹配 + assert any(m[1] == "000001.SZ" and m[2] == "short_name" for m in matches) + + # 测试别名匹配 + matches = mapper.find_codes("PAB发布新产品") + # 应该至少找到别名匹配 + assert any(m[1] == "000001.SZ" and m[2] == "alias" for m in matches) + + # 测试股票代码直接匹配 + matches = mapper.find_codes("000001.SZ开盘上涨") + assert len(matches) == 1 + assert matches[0][1] == "000001.SZ" + assert matches[0][2] == "code" + +def test_rss_item_entity_extraction(): + # 使用全局company_mapper + company_mapper.name_to_code.clear() # 清除之前的数据 + company_mapper.add_company( + ts_code="000001.SZ", + full_name="平安银行股份有限公司", + short_name="平安银行", + aliases=["平安", "PAB"] + ) + + # 创建测试新闻并跳过数据库初始化 + class TestRssItem(RssItem): + _skip_db_init = True + + item = TestRssItem( + id="test_news", + title="平安银行发布2025年业绩预告", + link="http://example.com", + published=datetime.now(timezone.utc), + summary="平安银行股份有限公司(000001.SZ)今日发布2025年业绩预告", + source="test" + ) + + # 提取实体 + item.extract_entities() + + # 验证结果:由于优先级机制,只会保留最优的匹配 + matched_types = set(m.match_type for m in item.stock_mentions) + assert "code" in matched_types or "full_name" in matched_types # 应该至少找到代码或全称匹配 + + # 验证唯一股票代码 + assert len(item.ts_codes) == 1 # 只有一个唯一的股票代码 + assert item.ts_codes[0] == "000001.SZ" + + # 验证置信度计算 + high_confidence = [m for m in item.stock_mentions if m.confidence > 0.7] + assert len(high_confidence) >= 1 # 至少应该有一个高置信度的匹配 + +def test_rss_item_context_extraction(): + # 使用全局company_mapper + company_mapper.name_to_code.clear() # 清除之前的数据 + company_mapper.add_company( + ts_code="000001.SZ", + full_name="平安银行股份有限公司", + short_name="平安银行" + ) + + # 创建带有上下文的测试新闻并跳过数据库初始化 + class TestRssItem(RssItem): + _skip_db_init = True + + item = TestRssItem( + id="test_news", + title="多家银行业绩报告", + link="http://example.com", + published=datetime.now(timezone.utc), + summary="在银行业整体向好的背景下,平安银行表现突出,营收增长明显", + source="test" + ) + + # 提取实体 + item.extract_entities() + + # 验证上下文提取 + assert len(item.stock_mentions) > 0 + mention = item.stock_mentions[0] + assert len(mention.context) <= 70 # 上下文长度限制(30字符前后) + assert "平安银行" in mention.context + assert "银行业" in mention.context # 应包含前文 + assert "营收增长" in mention.context # 应包含后文 diff --git a/tests/test_rss_item_industry_keywords.py b/tests/test_rss_item_industry_keywords.py new file mode 100644 index 0000000..7f7f2c8 --- /dev/null +++ b/tests/test_rss_item_industry_keywords.py @@ -0,0 +1,83 @@ +"""Test industry and keyword extraction in RSS processing.""" +from datetime import datetime, timezone + +from app.ingest.rss import RssItem + +def test_industry_extraction(): + """Test industry keyword extraction.""" + # 创建测试新闻并跳过数据库初始化 + class TestRssItem(RssItem): + _skip_db_init = True + + item = TestRssItem( + id="test_news", + title="某半导体公司推出新一代芯片", + link="http://example.com", + published=datetime.now(timezone.utc), + summary="该公司在集成电路领域取得重大突破,新产品将用于5G通信", + source="test" + ) + + # 提取行业关键词 + item.extract_industries() + + # 验证结果 + assert "半导体" in item.industries + assert len(item.industries) >= 1 + +def test_important_keyword_extraction(): + """Test important keyword extraction.""" + # 创建测试新闻(包含积极、消极和事件关键词)并跳过数据库初始化 + class TestRssItem(RssItem): + _skip_db_init = True + + item = TestRssItem( + id="test_news", + title="某公司业绩超预期,同时宣布重大收购计划", + link="http://example.com", + published=datetime.now(timezone.utc), + summary="营收增长显著,但部分业务亏损,将通过并购扩张", + source="test" + ) + + # 提取重要关键词 + item.extract_important_keywords() + + # 验证结果:应该包含积极、消极和事件关键词 + keywords = set(item.important_keywords) + + # 检查是否包含至少一个积极关键词(前缀为+) + assert any(k.startswith('+') for k in keywords) + + # 检查是否包含至少一个消极关键词(前缀为-) + assert any(k.startswith('-') for k in keywords) + + # 检查是否包含至少一个事件关键词(前缀为#) + assert any(k.startswith('#') for k in keywords) + +def test_rss_item_full_extraction(): + """Test full entity, industry and keyword extraction.""" + # 创建一个包含多种信息的测试新闻并跳过数据库初始化 + class TestRssItem(RssItem): + _skip_db_init = True + + item = TestRssItem( + id="test_news", + title="半导体行业利好:某公司重大突破", + link="http://example.com", + published=datetime.now(timezone.utc), + summary="集成电路领域取得重大进展,业绩超预期,宣布增持计划", + source="test" + ) + + # 提取所有信息 + item.extract_entities() # 这将自动调用行业和关键词提取 + + # 验证结果的完整性 + assert item.industries # 应该至少识别出半导体行业 + assert item.important_keywords # 应该找到关键词 + + # 验证关键词类型的完整性 + keywords = set(item.important_keywords) + assert any(k.startswith('+') for k in keywords) # 积极关键词 + assert any(k.startswith('#') for k in keywords) # 事件关键词