This commit is contained in:
sam 2025-10-05 14:48:41 +08:00
parent 69a3cc69c6
commit 16a5fae732
6 changed files with 845 additions and 82 deletions

View File

@ -8,10 +8,48 @@ end-to-end automated decision-making requirements.
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Sequence, Optional from typing import Dict, List, Sequence, Optional, Any
import functools
import numpy as np import numpy as np
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "extended_factors"}
def handle_factor_errors(func: Any) -> Any:
    """Decorator: catch and log any error raised while computing a factor.

    Args:
        func: The factor-computation callable, typically a bound method
            whose positional arguments are ``(self, factor_name, ...)``.

    Returns:
        A wrapper that returns the wrapped function's result on success
        and ``None`` on any exception, after logging the failure.
    """
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Optional[float]:
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # Best-effort recovery of the factor name for the log record.
            # For a method call args are (self, factor_name, ...), so the
            # name sits at index 1 whenever at least two positional args
            # are present.  The original `len(args) > 2` check missed
            # calls that pass the series as keyword arguments.
            factor_name = "unknown"
            if len(args) > 1 and isinstance(args[1], str):
                factor_name = args[1]
            elif "factor_name" in kwargs:
                factor_name = kwargs["factor_name"]
            LOGGER.error(
                "计算因子出错 name=%s error=%s",
                factor_name,
                str(e),
                exc_info=True,
                extra=LOG_EXTRA
            )
            return None
    return wrapper
from app.core.indicators import momentum, rolling_mean, normalize from app.core.indicators import momentum, rolling_mean, normalize
from app.core.technical import ( from app.core.technical import (
rsi, macd, bollinger_bands, obv_momentum, price_volume_trend rsi, macd, bollinger_bands, obv_momentum, price_volume_trend
@ -101,12 +139,21 @@ class ExtendedFactors:
) )
all_factors = calculator.compute_all_factors(close_series, volume_series) all_factors = calculator.compute_all_factors(close_series, volume_series)
normalized = calculator.normalize_factors(all_factors) normalized = calculator.normalize_factors(all_factors)
属性:
factor_specs: Dict[str, FactorSpec], 因子名称到因子规格的映射
""" """
def __init__(self): def __init__(self):
"""初始化因子计算器""" """初始化因子计算器,构建因子规格映射"""
self.factor_specs = {spec.name: spec for spec in EXTENDED_FACTORS} self.factor_specs = {spec.name: spec for spec in EXTENDED_FACTORS}
LOGGER.info(
"初始化因子计算器,加载因子数量: %d",
len(self.factor_specs),
extra=LOG_EXTRA
)
@handle_factor_errors
def compute_factor(self, def compute_factor(self,
factor_name: str, factor_name: str,
close_series: Sequence[float], close_series: Sequence[float],
@ -114,21 +161,24 @@ class ExtendedFactors:
"""计算单个因子值 """计算单个因子值
Args: Args:
factor_name: 因子名称 factor_name: 因子名称必须是已注册的因子
close_series: 收盘价序列从新到旧排序 close_series: 收盘价序列从新到旧排序
volume_series: 成交量序列从新到旧排序 volume_series: 成交量序列从新到旧排序
Returns: Returns:
因子值如果计算失败则返回None factor_value: Optional[float], 计算得到的因子值失败时返回None
Raises:
ValueError: 当因子名称未知或数据不足时抛出
""" """
try:
spec = self.factor_specs.get(factor_name) spec = self.factor_specs.get(factor_name)
if spec is None: if spec is None:
print(f"Unknown factor: {factor_name}") raise ValueError(f"未知的因子名称: {factor_name}")
return None
if len(close_series) < spec.window: if len(close_series) < spec.window:
return None raise ValueError(
f"数据长度不足: 需要{spec.window},实际{len(close_series)}"
)
# 技术分析因子 # 技术分析因子
if factor_name == "tech_rsi_14": if factor_name == "tech_rsi_14":
@ -189,43 +239,88 @@ class ExtendedFactors:
ma = rolling_mean(volume_series, window) ma = rolling_mean(volume_series, window)
return volume_series[0] / ma if ma > 0 else None return volume_series[0] / ma if ma > 0 else None
return None raise ValueError(f"因子 {factor_name} 没有对应的计算实现")
except Exception as e:
print(f"Error computing factor {factor_name}: {str(e)}")
return None
def compute_all_factors(self, def compute_all_factors(self,
close_series: Sequence[float], close_series: Sequence[float],
volume_series: Sequence[float]) -> Dict[str, float]: volume_series: Sequence[float]) -> Dict[str, float]:
"""计算所有扩展因子值 """计算所有已注册的扩展因子值
Args: Args:
close_series: 收盘价序列从新到旧排序 close_series: 收盘价序列从新到旧排序
volume_series: 成交量序列从新到旧排序 volume_series: 成交量序列从新到旧排序
Returns: Returns:
因子名称到因子值的映射字典 Dict[str, float]: 因子名称到因子值的映射字典
只包含成功计算的因子值
Note:
该方法会尝试计算所有已注册的因子失败的因子将被忽略
如果所有因子计算都失败将返回空字典
""" """
results = {} results = {}
success_count = 0
total_count = len(self.factor_specs)
for factor_name in self.factor_specs: for factor_name in self.factor_specs:
value = self.compute_factor(factor_name, close_series, volume_series) value = self.compute_factor(factor_name, close_series, volume_series)
if value is not None: if value is not None:
results[factor_name] = value results[factor_name] = value
success_count += 1
LOGGER.info(
"因子计算完成 total=%d success=%d failed=%d",
total_count,
success_count,
total_count - success_count,
extra=LOG_EXTRA
)
return results return results
def normalize_factors(self, factors: Dict[str, float]) -> Dict[str, float]:
def normalize_factors(self,
factors: Dict[str, float],
clip_threshold: float = 3.0) -> Dict[str, float]:
"""标准化因子值到[-1,1]区间 """标准化因子值到[-1,1]区间
Args: Args:
factors: 原始因子值字典 factors: 原始因子值字典
clip_threshold: float, 标准化时的截断阈值默认为3.0
Returns: Returns:
标准化后的因子值字典 Dict[str, float]: 标准化后的因子值字典
只包含成功标准化的因子值
Note:
标准化过程包括:
1. Z-score标准化
2. 使用tanh压缩到[-1,1]区间
3. 异常值处理截断
""" """
results = {} results = {}
success_count = 0
for name, value in factors.items(): for name, value in factors.items():
if value is not None: if value is not None:
results[name] = normalize(value) try:
normalized = normalize(value, clip_threshold)
if not np.isnan(normalized):
results[name] = normalized
success_count += 1
except Exception as e:
LOGGER.warning(
"因子标准化失败 name=%s error=%s",
name,
str(e),
extra=LOG_EXTRA
)
LOGGER.info(
"因子标准化完成 total=%d success=%d failed=%d",
len(factors),
success_count,
len(factors) - success_count,
extra=LOG_EXTRA
)
return results return results

View File

@ -0,0 +1,145 @@
"""Stock code mapping and entity recognition utilities."""
from __future__ import annotations
import re
from typing import Dict, List, Optional, Set, Tuple
# 股票代码正则表达式
A_SH_CODE_PATTERN = re.compile(r"\b(\d{6})(\.(?:SH|SZ))?\b", re.IGNORECASE)
HK_CODE_PATTERN = re.compile(r"\b(\d{4})\.HK\b", re.IGNORECASE)
def normalize_stock_code(code: str, explicit_market: Optional[str] = None) -> str:
    """Normalize a stock code to the canonical ``XXXXXX.MK`` format.

    Args:
        code: Raw stock code, with or without a market suffix.
        explicit_market: Explicitly specified market, e.g. 'SH' or 'SZ'.
            (Original annotation was ``str = None``; fixed to Optional.)

    Returns:
        Canonical code such as '000001.SZ'.
    """
    # Already carries a market suffix: just canonicalize the case.
    if '.' in code:
        return code.upper()
    if explicit_market:
        return f"{code}.{explicit_market.upper()}"
    # Infer the market from the leading digit of the code.
    if code.startswith('6'):
        return f"{code}.SH"
    elif code.startswith(('0', '3')):
        return f"{code}.SZ"
    else:
        # NOTE(review): codes starting with '4'/'8' (NEEQ/BSE) also fall
        # through to here; Shanghai is only a best-effort default —
        # confirm whether BSE codes should map to '.BJ'.
        return f"{code}.SH"
# Company-name suffixes stripped when generating short variants.
# Matched longest-first inside _generate_name_variants.
COMPANY_SUFFIXES = ["股份", "科技", "公司", "集团", "股份有限公司", "有限公司"]


class CompanyNameMapper:
    """Map company names to stock codes with fuzzy matching."""

    def __init__(self):
        self.name_to_code: Dict[str, str] = {}  # full registered name -> code
        self.short_names: Dict[str, str] = {}   # official short name -> code
        self.aliases: Dict[str, str] = {}       # alias / generated variant -> code

    def add_company(self, ts_code: str, full_name: str, short_name: str,
                    aliases: Optional[List[str]] = None):
        """Add a company to the mapping.

        Args:
            ts_code: Stock code in format like '000001.SZ'
            full_name: Full registered company name
            short_name: Official short name
            aliases: List of alternative names (was annotated
                ``List[str] = None``; fixed to Optional)
        """
        # Store the full-name mapping.
        self.name_to_code[full_name] = ts_code
        # Store the short-name mapping.
        self.short_names[short_name] = ts_code
        # Generate and store name variants; first writer wins so that
        # earlier registrations are not silently overwritten.
        for variant in self._generate_name_variants(full_name):
            if variant not in self.aliases:
                self.aliases[variant] = ts_code
        # Store any explicitly supplied aliases.
        if aliases:
            for alias in aliases:
                if alias not in self.aliases:
                    self.aliases[alias] = ts_code

    def _generate_name_variants(self, full_name: str) -> Set[str]:
        """Generate possible variants of a company name.

        Strips exactly one whole company-type suffix.  Suffixes are
        tried longest-first: the original first-match order stripped the
        short "公司" from "...股份有限公司", leaving a truncated
        "...股份有限" variant instead of the intended short form.
        """
        variants = set()
        for suffix in sorted(COMPANY_SUFFIXES, key=len, reverse=True):
            if full_name.endswith(suffix):
                variant = full_name[:-len(suffix)].strip()
                if len(variant) > 2:  # skip variants that are too short
                    variants.add(variant)
                break
        return variants

    def find_codes(self, text: str) -> List[Tuple[str, str, str]]:
        """Find company mentions and corresponding stock codes in text.

        Returns:
            List of tuples (matched_text, stock_code, match_type)
            where match_type is one of 'code', 'full_name', 'short_name', 'alias'
        """
        matches = []
        # 1. Direct stock-code hits (A-share, then HK).
        for match in A_SH_CODE_PATTERN.finditer(text):
            code = match.group(1)
            explicit_market = match.group(2)[1:] if match.group(2) else None
            ts_code = normalize_stock_code(code, explicit_market)
            matches.append((match.group(), ts_code, 'code'))
        for match in HK_CODE_PATTERN.finditer(text):
            ts_code = match.group()
            matches.append((match.group(), ts_code, 'code'))
        # 2. Name lookups in priority order: full name first...
        for name, code in self.name_to_code.items():
            if name in text:
                matches.append((name, code, 'full_name'))
        # ...then official short names...
        for name, code in self.short_names.items():
            if name in text:
                matches.append((name, code, 'short_name'))
        # ...and finally aliases.
        for alias, code in self.aliases.items():
            if alias in text:
                matches.append((alias, code, 'alias'))
        return matches
# Module-level singleton shared by the RSS ingestion pipeline.
company_mapper = CompanyNameMapper()


def initialize_company_mapping(db_connection) -> None:
    """Load the company-name mapping from the database.

    Args:
        db_connection: SQLite database connection.
    """
    cursor = db_connection.cursor()
    try:
        cursor.execute("""
            SELECT ts_code, name, short_name
            FROM stock_company
            WHERE name IS NOT NULL
        """)
        # Only rows with both names are usable for full-name and
        # short-name matching.
        for ts_code, name, short_name in cursor.fetchall():
            if name and short_name:
                company_mapper.add_company(ts_code, name, short_name)
    finally:
        # Close the cursor even when the query raises (original leaked it).
        cursor.close()

View File

@ -18,6 +18,8 @@ import hashlib
import random import random
import time import time
from app.ingest.entity_recognition import company_mapper, initialize_company_mapping
try: # pragma: no cover - optional dependency at runtime try: # pragma: no cover - optional dependency at runtime
import feedparser # type: ignore[import-not-found] import feedparser # type: ignore[import-not-found]
except ImportError: # pragma: no cover - graceful fallback except ImportError: # pragma: no cover - graceful fallback
@ -95,6 +97,15 @@ class RssFeedConfig:
max_items: int = 50 max_items: int = 50
@dataclass
class StockMention:
    """A single mention of a stock found in a piece of news text."""
    matched_text: str  # exact substring of the source text that matched
    ts_code: str  # normalized stock code, e.g. '000001.SZ'
    match_type: str  # 'code', 'full_name', 'short_name', 'alias'
    context: str  # sentence-level context snippet around the match
    confidence: float  # match confidence score, clamped to [0, 1]
@dataclass @dataclass
class RssItem: class RssItem:
"""Structured representation of an RSS entry.""" """Structured representation of an RSS entry."""
@ -106,8 +117,214 @@ class RssItem:
summary: str summary: str
source: str source: str
ts_codes: List[str] = field(default_factory=list) ts_codes: List[str] = field(default_factory=list)
industries: List[str] = field(default_factory=list) # 新增:相关行业列表 stock_mentions: List[StockMention] = field(default_factory=list)
important_keywords: List[str] = field(default_factory=list) # 新增:重要关键词列表 industries: List[str] = field(default_factory=list)
important_keywords: List[str] = field(default_factory=list)
def __post_init__(self):
    """Initialize company mapper if not already initialized."""
    # Test environments set a _skip_db_init attribute (via subclass) to
    # avoid touching the database during construction.
    if not hasattr(self, '_skip_db_init'):  # only outside tests
        from app.utils.db import db_session
        # Lazily populate the shared singleton mapper exactly once.
        if not company_mapper.name_to_code:
            with db_session() as conn:
                initialize_company_mapping(conn)
def extract_entities(self) -> None:
    """Extract and validate entity mentions from title and summary."""
    # Scan title and summary separately so title hits can be weighted higher.
    title_matches = company_mapper.find_codes(self.title)
    summary_matches = company_mapper.find_codes(self.summary)
    # Keep only the single best match per stock code.
    code_best_matches = {}  # ts_code -> (matched_text, match_type, is_title, context)
    # Priority order: code > full name > short name > alias.
    priority = {'code': 0, 'full_name': 1, 'short_name': 2, 'alias': 3}
    for matches, text, is_title in [(title_matches, self.title, True),
    (summary_matches, self.summary, False)]:
        for matched_text, ts_code, match_type in matches:
            # Pull the sentence-level context surrounding the hit.
            context = self._extract_context(text, matched_text)
            # Record a new code, or replace with a higher-priority match.
            if (ts_code not in code_best_matches or
            priority[match_type] < priority[code_best_matches[ts_code][1]]):
                code_best_matches[ts_code] = (matched_text, match_type, is_title, context)
    # Build the StockMention list from the winners.
    for ts_code, (matched_text, match_type, is_title, context) in code_best_matches.items():
        confidence = self._calculate_confidence(match_type, matched_text, context, is_title)
        mention = StockMention(
            matched_text=matched_text,
            ts_code=ts_code,
            match_type=match_type,
            context=context,
            confidence=confidence
        )
        self.stock_mentions.append(mention)
    # Refresh ts_codes, keeping only high-confidence mentions.
    self.ts_codes = list(set(
        mention.ts_code
        for mention in self.stock_mentions
        if mention.confidence > 0.7  # keep only confident matches
    ))
    # Derive industry tags.
    self.extract_industries()
    # Derive important keywords.
    self.extract_important_keywords()
def _extract_context(self, text: str, matched_text: str) -> str:
"""提取匹配文本的上下文,尽量提取完整的句子."""
# 找到匹配文本的位置
start_pos = text.find(matched_text)
if start_pos == -1:
return ""
# 向前找到句子开始(句号、问号、感叹号或换行符之后)
sent_start = start_pos
while sent_start > 0:
if text[sent_start-1] in '。?!\n':
break
sent_start -= 1
# 向后找到句子结束
sent_end = start_pos + len(matched_text)
while sent_end < len(text):
if text[sent_end] in '。?!\n':
sent_end += 1
break
sent_end += 1
# 如果上下文太长,则截取固定长度
context = text[sent_start:sent_end].strip()
if len(context) > 100: # 最大上下文长度
start = max(0, start_pos - 30)
end = min(len(text), start_pos + len(matched_text) + 30)
context = text[start:end].strip()
return context
def extract_industries(self) -> None:
    """Populate self.industries from keyword hits in the title and summary."""
    haystack = f"{self.title} {self.summary}".lower()
    # An industry counts as mentioned when any one of its keywords appears.
    hits = {
        industry
        for industry, keywords in INDUSTRY_KEYWORDS.items()
        if any(keyword.lower() in haystack for keyword in keywords)
    }
    self.industries = list(hits)
def extract_important_keywords(self) -> None:
    """Collect sentiment and event keywords from the title and summary.

    Tag prefixes: '+' positive sentiment, '-' negative sentiment,
    '#' specific event.
    """
    haystack = f"{self.title} {self.summary}".lower()
    found = set()
    # 1-2. Sentiment keywords, tagged with their polarity prefix.
    for wordlist, prefix in ((POSITIVE_KEYWORDS, "+"), (NEGATIVE_KEYWORDS, "-")):
        for keyword in wordlist:
            if keyword.lower() in haystack:
                found.add(f"{prefix}{keyword}")
    # 3. Specific event triggers mapped to canonical event tags.
    event_keywords = {
        # corporate actions
        "收购": "M&A",
        "并购": "M&A",
        "重组": "重组",
        "分拆": "分拆",
        "上市": "IPO",
        # financial events
        "业绩": "业绩",
        "亏损": "业绩预警",
        "盈利": "业绩预增",
        "分红": "分红",
        "回购": "回购",
        # regulatory events
        "立案": "监管",
        "调查": "监管",
        "问询": "监管",
        "处罚": "处罚",
        # major projects
        "中标": "中标",
        "签约": "签约",
        "战略合作": "合作",
    }
    for trigger, event in event_keywords.items():
        if trigger in haystack:
            found.add(f"#{event}")
    self.important_keywords = list(found)
def _calculate_confidence(self, match_type: str, matched_text: str, context: str, is_title: bool = False) -> float:
"""计算实体匹配的置信度.
考虑以下因素
1. 匹配类型的基础置信度
2. 实体在文本中的位置标题/开头更重要
3. 上下文关键词
4. 股票相关动词
5. 实体的完整性
"""
# 基础置信度
base_confidence = {
'code': 0.9, # 直接的股票代码匹配
'full_name': 0.85,# 完整公司名称匹配
'short_name': 0.7,# 公司简称匹配
'alias': 0.6 # 别名匹配
}.get(match_type, 0.5)
confidence = base_confidence
context_lower = context.lower()
# 1. 位置加权
if is_title:
confidence += 0.1
if context.startswith(matched_text):
confidence += 0.05
# 2. 实体完整性检查
if match_type == 'code' and '.' in matched_text: # 完整股票代码(带市场后缀)
confidence += 0.05
elif match_type == 'full_name' and any(suffix in matched_text for suffix in ["股份有限公司", "有限公司"]):
confidence += 0.05
# 3. 上下文关键词
context_bonus = 0.0
corporate_terms = ["公司", "集团", "企业", "上市", "控股", "总部"]
if any(term in context_lower for term in corporate_terms):
context_bonus += 0.1
# 4. 股票相关动词
stock_verbs = ["发布", "公告", "披露", "表示", "报告", "投资", "回购", "增持", "减持"]
if any(verb in context_lower for verb in stock_verbs):
context_bonus += 0.05
# 5. 财务/业务相关词汇
business_terms = ["业绩", "营收", "利润", "股价", "市值", "经营", "产品", "服务", "战略"]
if any(term in context_lower for term in business_terms):
context_bonus += 0.05
# 限制上下文加成的最大值
confidence += min(context_bonus, 0.2)
# 确保置信度在0-1之间
return min(1.0, max(0.0, confidence))
DEFAULT_RSS_SOURCES: Tuple[RssFeedConfig, ...] = () DEFAULT_RSS_SOURCES: Tuple[RssFeedConfig, ...] = ()
@ -255,7 +472,7 @@ def _fetch_feed_items(
def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]: def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
"""Drop duplicate stories by link/id fingerprint.""" """Drop duplicate stories by link/id fingerprint and process entities."""
seen = set() seen = set()
unique: List[RssItem] = [] unique: List[RssItem] = []
@ -264,7 +481,14 @@ def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
if key in seen: if key in seen:
continue continue
seen.add(key) seen.add(key)
# 提取实体和相关信息
item.extract_entities()
# 如果找到了相关股票,则保留这条新闻
if item.stock_mentions:
unique.append(item) unique.append(item)
return unique return unique

106
docs/TODO_UNIFIED.md Normal file
View File

@ -0,0 +1,106 @@
# LLM量化交易助理系统开发计划
> 项目愿景:开发一个可实战的投资助理工具,使其专业水平达到投资领域前列。核心是通过多智能体协作提供高质量的投资决策支持。
> 开发进度2025-10-05
> ✓ 基础因子计算框架
> ✓ 数据访问与监控
> ✓ 核心回测系统
> ✓ LLM基础集成
> △ RSS新闻处理
> △ UI与监控系统
## 一、核心功能模块优先级排序
### 1. 数据与特征层P0
#### 1.1 因子计算模块优化
- [x] 完善 `compute_factors()` 函数实现:
- [x] 添加数据有效性校验机制
- [x] 实现异常值检测与处理逻辑
- [ ] 增加计算进度显示和日志记录
- [ ] 优化因子持久化性能
- [ ] 支持增量计算模式
#### 1.2 DataBroker增强
- [x] 开发数据请求失败的自动重试机制
- [x] 增加数据源健康状态监控
- [ ] 设计数据质量评估指标系统
#### 1.3 因子库扩展
- [x] 扩展动量类因子群
- [x] 开发估值类因子群
- [x] 设计流动性因子群
- [ ] 构建市场情绪因子群
- [ ] 开发因子组合和权重优化算法
#### 1.4 新闻数据源完善
- [x] 完成RSS数据获取和解析
- [x] 增强情感分析能力
- [ ] 改进实体识别准确率
- [ ] 实现新闻时效性评分
### 2. 决策优化P1
#### 2.1 决策环境增强
- [x] 扩展DecisionEnv动作空间
- [x] 支持提示版本选择
- [x] 允许调节部门温度
- [ ] 优化function调用策略
- [x] 增加环境观测维度:
- [x] 加入换手率指标
- [x] 纳入风险事件统计
- [ ] 补充市场情绪指标
#### 2.2 回测系统完善
- [x] 优化成交撮合逻辑:
- [x] 统一仓位限制
- [x] 考虑换手约束
- [x] 加入滑点模拟
- [x] 计算交易成本
- [x] 完善风险控制:
- [x] 实现止损机制
- [x] 添加波动率限制
- [x] 设置集中度控制
### 3. LLM协同P1
- [x] 精简和优化Provider管理
- [x] 增强function-calling架构
- [x] 完善错误处理和重试策略
- [ ] 优化提示工程:
- [ ] 设计配置化角色提示
- [ ] 优化数据范围控制
- [ ] 改进上下文管理
### 4. UI与监控P2
#### 4.1 功能增强
- [ ] 实现"一键重评估"功能
- [ ] 开发多版本实验对比
- [x] 添加实时指标面板
- [ ] 设计异常日志钻取功能
#### 4.2 监控增强
- [ ] 开发"仅监控不干预"模式
- [x] 实现策略实时评估
- [x] 添加风险预警功能
- [ ] 设计绩效归因分析
### 5. 测试与部署P2
- [x] 补充核心路径单元测试
- [ ] 建立端到端集成测试
- [ ] 完善日志收集机制
## 二、近期开发重点2025 Q4
1. ✓ 完成因子计算模块的优化和重构
2. ✓ 实现基础因子库的扩展
3. ✓ 优化DataBroker的数据访问性能
4. △ 完善RSS新闻数据源的接入
5. ✓ 开始着手决策环境的增强
## 三、开发原则
1. 保持简单:每个模块只实现最核心的功能
2. 重视可靠性:核心功能必须稳定可靠
3. 易于使用:交互界面简单直观
4. 容错设计:关键节点预留人工介入的可能
> 注此计划将根据实际使用体验持续优化始终保持简单实用的原则。上次更新2025-10-05

View File

@ -0,0 +1,110 @@
"""Test improved entity recognition in RSS processing."""
from datetime import datetime, timezone
import pytest
from app.ingest.entity_recognition import CompanyNameMapper, company_mapper
from app.ingest.rss import RssItem, StockMention
def test_company_name_mapper():
    """Exercise every match type on a locally constructed mapper."""
    mapper = CompanyNameMapper()
    # Register one sample company with explicit aliases.
    mapper.add_company(
        ts_code="000001.SZ",
        full_name="平安银行股份有限公司",
        short_name="平安银行",
        aliases=["平安", "PAB"]
    )
    # Full-name matching reports every overlapping hit.
    hits = mapper.find_codes("平安银行股份有限公司公布2025年业绩")
    assert sum(1 for h in hits if h[1] == "000001.SZ") >= 1
    assert any(h[2] == "full_name" and h[1] == "000001.SZ" for h in hits)
    # Short-name matching.
    hits = mapper.find_codes("平安银行发布公告")
    assert any(h[1] == "000001.SZ" and h[2] == "short_name" for h in hits)
    # Alias matching.
    hits = mapper.find_codes("PAB发布新产品")
    assert any(h[1] == "000001.SZ" and h[2] == "alias" for h in hits)
    # Direct stock-code matching yields exactly one hit.
    hits = mapper.find_codes("000001.SZ开盘上涨")
    assert len(hits) == 1
    assert hits[0][1] == "000001.SZ"
    assert hits[0][2] == "code"
def test_rss_item_entity_extraction():
    """End-to-end entity extraction on a constructed RssItem."""
    # Fully reset the shared global mapper.  The original test cleared
    # only name_to_code, so stale entries in short_names and aliases
    # left over from other tests could leak into this one.
    company_mapper.name_to_code.clear()
    company_mapper.short_names.clear()
    company_mapper.aliases.clear()
    company_mapper.add_company(
        ts_code="000001.SZ",
        full_name="平安银行股份有限公司",
        short_name="平安银行",
        aliases=["平安", "PAB"]
    )
    # Subclass flag skips the database-backed initialization.
    class TestRssItem(RssItem):
        _skip_db_init = True

    item = TestRssItem(
        id="test_news",
        title="平安银行发布2025年业绩预告",
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary="平安银行股份有限公司000001.SZ今日发布2025年业绩预告",
        source="test"
    )
    item.extract_entities()
    # The priority mechanism keeps only the best match per code.
    matched_types = set(m.match_type for m in item.stock_mentions)
    assert "code" in matched_types or "full_name" in matched_types
    # Exactly one unique stock code should survive.
    assert len(item.ts_codes) == 1
    assert item.ts_codes[0] == "000001.SZ"
    # Confidence scoring should yield at least one strong match.
    high_confidence = [m for m in item.stock_mentions if m.confidence > 0.7]
    assert len(high_confidence) >= 1
def test_rss_item_context_extraction():
    """Verify sentence-level context is attached to stock mentions."""
    # Fully reset the shared global mapper (the original cleared only
    # name_to_code, leaving stale short-name/alias entries behind).
    company_mapper.name_to_code.clear()
    company_mapper.short_names.clear()
    company_mapper.aliases.clear()
    company_mapper.add_company(
        ts_code="000001.SZ",
        full_name="平安银行股份有限公司",
        short_name="平安银行"
    )
    # Subclass flag skips the database-backed initialization.
    class TestRssItem(RssItem):
        _skip_db_init = True

    item = TestRssItem(
        id="test_news",
        title="多家银行业绩报告",
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary="在银行业整体向好的背景下,平安银行表现突出,营收增长明显",
        source="test"
    )
    item.extract_entities()
    # Context should be a bounded snippet that keeps the surroundings.
    assert len(item.stock_mentions) > 0
    mention = item.stock_mentions[0]
    assert len(mention.context) <= 70  # bounded window around the match
    assert "平安银行" in mention.context
    assert "银行业" in mention.context  # leading context preserved
    assert "营收增长" in mention.context  # trailing context preserved

View File

@ -0,0 +1,83 @@
"""Test industry and keyword extraction in RSS processing."""
from datetime import datetime, timezone
from app.ingest.rss import RssItem
def test_industry_extraction():
    """Test industry keyword extraction."""
    # The subclass flag skips database-backed mapper initialization.
    class TestRssItem(RssItem):
        _skip_db_init = True

    news = TestRssItem(
        id="test_news",
        title="某半导体公司推出新一代芯片",
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary="该公司在集成电路领域取得重大突破新产品将用于5G通信",
        source="test"
    )
    # Run only the industry-extraction stage.
    news.extract_industries()
    assert "半导体" in news.industries
    assert len(news.industries) >= 1
def test_important_keyword_extraction():
    """Test important keyword extraction."""
    # News that carries positive, negative and event triggers at once;
    # the subclass flag skips database initialization.
    class TestRssItem(RssItem):
        _skip_db_init = True

    news = TestRssItem(
        id="test_news",
        title="某公司业绩超预期,同时宣布重大收购计划",
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary="营收增长显著,但部分业务亏损,将通过并购扩张",
        source="test"
    )
    news.extract_important_keywords()
    tags = set(news.important_keywords)
    # Expect at least one tag from each family.
    assert any(tag.startswith('+') for tag in tags)  # positive ('+' prefix)
    assert any(tag.startswith('-') for tag in tags)  # negative ('-' prefix)
    assert any(tag.startswith('#') for tag in tags)  # event ('#' prefix)
def test_rss_item_full_extraction():
    """Test full entity, industry and keyword extraction."""
    # The subclass flag skips database-backed initialization.
    class TestRssItem(RssItem):
        _skip_db_init = True

    news = TestRssItem(
        id="test_news",
        title="半导体行业利好:某公司重大突破",
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary="集成电路领域取得重大进展,业绩超预期,宣布增持计划",
        source="test"
    )
    # extract_entities also triggers industry and keyword extraction.
    news.extract_entities()
    # Both downstream stages should have produced output.
    assert news.industries          # at least the semiconductor industry
    assert news.important_keywords  # some tagged keywords found
    tags = set(news.important_keywords)
    assert any(tag.startswith('+') for tag in tags)  # positive keyword
    assert any(tag.startswith('#') for tag in tags)  # event keyword