update
parent 69a3cc69c6
commit 16a5fae732
@@ -8,10 +8,48 @@ end-to-end automated decision-making requirements.
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import Dict, List, Sequence, Optional
+from typing import Dict, List, Sequence, Optional, Any
+
+import functools

 import numpy as np

+from app.utils.logging import get_logger
+
+LOGGER = get_logger(__name__)
+
+LOG_EXTRA = {"stage": "extended_factors"}
+
+
+def handle_factor_errors(func: Any) -> Any:
+    """Decorator: handle errors raised during factor computation.
+
+    Args:
+        func: The function to wrap.
+
+    Returns:
+        The wrapped function.
+    """
+    @functools.wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Optional[float]:
+        try:
+            return func(*args, **kwargs)
+        except Exception as e:
+            # Recover the factor name if possible.
+            factor_name = "unknown"
+            if len(args) > 2 and isinstance(args[1], str):
+                factor_name = args[1]
+            elif "factor_name" in kwargs:
+                factor_name = kwargs["factor_name"]
+
+            LOGGER.error(
+                "计算因子出错 name=%s error=%s",
+                factor_name,
+                str(e),
+                exc_info=True,
+                extra=LOG_EXTRA
+            )
+            return None
+    return wrapper
+
+
 from app.core.indicators import momentum, rolling_mean, normalize
 from app.core.technical import (
     rsi, macd, bollinger_bands, obv_momentum, price_volume_trend
@@ -101,12 +139,21 @@ class ExtendedFactors:
         )
         all_factors = calculator.compute_all_factors(close_series, volume_series)
         normalized = calculator.normalize_factors(all_factors)
+
+    Attributes:
+        factor_specs: Dict[str, FactorSpec], mapping from factor name to factor spec.
     """

     def __init__(self):
-        """Initialize the factor calculator."""
+        """Initialize the factor calculator and build the factor-spec mapping."""
         self.factor_specs = {spec.name: spec for spec in EXTENDED_FACTORS}
+        LOGGER.info(
+            "初始化因子计算器,加载因子数量: %d",
+            len(self.factor_specs),
+            extra=LOG_EXTRA
+        )

+    @handle_factor_errors
     def compute_factor(self,
                        factor_name: str,
                        close_series: Sequence[float],
@@ -114,118 +161,166 @@ class ExtendedFactors:
         """Compute a single factor value.

         Args:
-            factor_name: The factor name.
+            factor_name: The factor name; must be a registered factor.
             close_series: Close-price series, ordered from newest to oldest.
             volume_series: Volume series, ordered from newest to oldest.

         Returns:
-            The factor value, or None if the computation fails.
+            factor_value: Optional[float], the computed factor value, or None on failure.
+
+        Raises:
+            ValueError: If the factor name is unknown or the data is too short.
         """
-        try:
-            spec = self.factor_specs.get(factor_name)
-            if spec is None:
-                print(f"Unknown factor: {factor_name}")
-                return None
-
-            if len(close_series) < spec.window:
-                return None
+        spec = self.factor_specs.get(factor_name)
+        if spec is None:
+            raise ValueError(f"未知的因子名称: {factor_name}")
+
+        if len(close_series) < spec.window:
+            raise ValueError(
+                f"数据长度不足: 需要{spec.window},实际{len(close_series)}"
+            )

         # Technical-analysis factors
         if factor_name == "tech_rsi_14":
             return rsi(close_series, 14)

         elif factor_name == "tech_macd_signal":
             _, signal = macd(close_series)
             return signal

         elif factor_name == "tech_bb_position":
             upper, lower = bollinger_bands(close_series, 20)
             pos = (close_series[0] - lower) / (upper - lower + 1e-8)
             return pos

         elif factor_name == "tech_obv_momentum":
             return obv_momentum(close_series, volume_series, 20)

         elif factor_name == "tech_pv_trend":
             return price_volume_trend(close_series, volume_series, 20)

         # Trend-following factors
         elif factor_name == "trend_ma_cross":
             ma_5 = rolling_mean(close_series, 5)
             ma_20 = rolling_mean(close_series, 20)
             return ma_5 - ma_20

         # Volatility-forecasting factors
         elif factor_name == "vol_garch":
             return garch_volatility(close_series, 20)

         elif factor_name == "vol_regime":
             regime, _ = volatility_regime(close_series, volume_series, 20)
             return regime

         # Joint volume-price factors
         elif factor_name == "volume_price_corr":
             return volume_price_correlation(close_series, volume_series, 20)

         # Enhanced momentum factors
         elif factor_name == "momentum_adaptive":
             return adaptive_momentum(close_series, volume_series, 20)

         elif factor_name == "momentum_regime":
             return momentum_regime(close_series, volume_series, 20)

         elif factor_name == "momentum_quality":
             return momentum_quality(close_series, 20)

         # Moving-average ratio factors
         elif factor_name.endswith("_ratio"):
             if "price_ma" in factor_name:
                 window = int(factor_name.split("_")[2])
                 ma = rolling_mean(close_series, window)
                 return close_series[0] / ma if ma > 0 else None

             elif "volume_ma" in factor_name:
                 window = int(factor_name.split("_")[2])
                 ma = rolling_mean(volume_series, window)
                 return volume_series[0] / ma if ma > 0 else None

-            return None
-
-        except Exception as e:
-            print(f"Error computing factor {factor_name}: {str(e)}")
-            return None
+        raise ValueError(f"因子 {factor_name} 没有对应的计算实现")

     def compute_all_factors(self,
                             close_series: Sequence[float],
                             volume_series: Sequence[float]) -> Dict[str, float]:
-        """Compute all extended factor values.
+        """Compute all registered extended factor values.

         Args:
             close_series: Close-price series, ordered from newest to oldest.
             volume_series: Volume series, ordered from newest to oldest.

         Returns:
-            Mapping from factor name to factor value.
+            Dict[str, float]: Mapping from factor name to factor value,
+                containing only the factors that were computed successfully.
+
+        Note:
+            The method tries every registered factor; failed factors are skipped.
+            If every factor fails, an empty dict is returned.
         """
         results = {}
+        success_count = 0
+        total_count = len(self.factor_specs)

         for factor_name in self.factor_specs:
             value = self.compute_factor(factor_name, close_series, volume_series)
             if value is not None:
                 results[factor_name] = value
+                success_count += 1

+        LOGGER.info(
+            "因子计算完成 total=%d success=%d failed=%d",
+            total_count,
+            success_count,
+            total_count - success_count,
+            extra=LOG_EXTRA
+        )

         return results

-    def normalize_factors(self, factors: Dict[str, float]) -> Dict[str, float]:
+    def normalize_factors(self,
+                          factors: Dict[str, float],
+                          clip_threshold: float = 3.0) -> Dict[str, float]:
         """Normalize factor values into the [-1, 1] range.

         Args:
             factors: Dict of raw factor values.
+            clip_threshold: float, clipping threshold used during normalization; defaults to 3.0.

         Returns:
-            Dict of normalized factor values.
+            Dict[str, float]: Dict of normalized factor values,
+                containing only the factors that were normalized successfully.
+
+        Note:
+            Normalization consists of:
+            1. Z-score standardization
+            2. tanh compression into the [-1, 1] range
+            3. Outlier handling (clipping)
         """
         results = {}
+        success_count = 0

         for name, value in factors.items():
             if value is not None:
-                results[name] = normalize(value)
+                try:
+                    normalized = normalize(value, clip_threshold)
+                    if not np.isnan(normalized):
+                        results[name] = normalized
+                        success_count += 1
+                except Exception as e:
+                    LOGGER.warning(
+                        "因子标准化失败 name=%s error=%s",
+                        name,
+                        str(e),
+                        extra=LOG_EXTRA
+                    )

+        LOGGER.info(
+            "因子标准化完成 total=%d success=%d failed=%d",
+            len(factors),
+            success_count,
+            len(factors) - success_count,
+            extra=LOG_EXTRA
+        )

         return results
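The Note above describes the intended normalization pipeline (Z-score, tanh squashing, clipping), while the `normalize` helper itself lives in `app.core.indicators` and is not shown in this diff. A minimal illustrative sketch of that scheme, assuming the caller supplies the distribution's mean and standard deviation (the name `normalize_value` and its parameters are hypothetical, not part of the commit):

import numpy as np

def normalize_value(value: float, mean: float, std: float, clip_threshold: float = 3.0) -> float:
    """Illustrative only: z-score the value, clip outliers, then squash with tanh into [-1, 1]."""
    if std <= 0 or np.isnan(value):
        return float("nan")
    z = (value - mean) / std                                # 1. Z-score standardization
    z = float(np.clip(z, -clip_threshold, clip_threshold))  # 3. outlier handling (clipping)
    return float(np.tanh(z))                                # 2. tanh compression into [-1, 1]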
app/ingest/entity_recognition.py  (new file, 145 lines)
@@ -0,0 +1,145 @@
"""Stock code mapping and entity recognition utilities."""
from __future__ import annotations

import re
from typing import Dict, List, Optional, Set, Tuple

# Stock-code regular expressions.
A_SH_CODE_PATTERN = re.compile(r"\b(\d{6})(\.(?:SH|SZ))?\b", re.IGNORECASE)
HK_CODE_PATTERN = re.compile(r"\b(\d{4})\.HK\b", re.IGNORECASE)


def normalize_stock_code(code: str, explicit_market: str = None) -> str:
    """Normalize a stock code into its canonical format.

    Args:
        code: Raw stock code.
        explicit_market: Explicitly specified market, e.g. 'SH' or 'SZ'.

    Returns:
        The stock code in canonical form, e.g. '000001.SZ'.
    """
    if '.' in code:
        return code.upper()

    if explicit_market:
        return f"{code}.{explicit_market.upper()}"

    # Infer the market from the code prefix.
    if code.startswith('6'):
        return f"{code}.SH"
    elif code.startswith(('0', '3')):
        return f"{code}.SZ"
    else:
        return f"{code}.SH"  # Default to the Shanghai exchange.


# Company-name variant suffixes.
COMPANY_SUFFIXES = ["股份", "科技", "公司", "集团", "股份有限公司", "有限公司"]


class CompanyNameMapper:
    """Map company names to stock codes with fuzzy matching."""

    def __init__(self):
        self.name_to_code: Dict[str, str] = {}  # full name -> code
        self.short_names: Dict[str, str] = {}   # short name -> code
        self.aliases: Dict[str, str] = {}       # alias -> code

    def add_company(self, ts_code: str, full_name: str, short_name: str, aliases: List[str] = None):
        """Add a company to the mapping.

        Args:
            ts_code: Stock code in format like '000001.SZ'
            full_name: Full registered company name
            short_name: Official short name
            aliases: List of alternative names
        """
        # Store the full-name mapping.
        self.name_to_code[full_name] = ts_code

        # Store the short-name mapping.
        self.short_names[short_name] = ts_code

        # Generate and store name variants.
        name_variants = self._generate_name_variants(full_name)
        for variant in name_variants:
            if variant not in self.aliases:
                self.aliases[variant] = ts_code

        # Store any extra aliases.
        if aliases:
            for alias in aliases:
                if alias not in self.aliases:
                    self.aliases[alias] = ts_code

    def _generate_name_variants(self, full_name: str) -> Set[str]:
        """Generate possible variants of a company name."""
        variants = set()

        # Only strip a whole company-type suffix.
        for suffix in COMPANY_SUFFIXES:
            if full_name.endswith(suffix):
                variant = full_name[:-len(suffix)].strip()
                if len(variant) > 2:  # Skip variants that are too short.
                    variants.add(variant)
                break

        return variants

    def find_codes(self, text: str) -> List[Tuple[str, str, str]]:
        """Find company mentions and corresponding stock codes in text.

        Returns:
            List of tuples (matched_text, stock_code, match_type)
            where match_type is one of 'code', 'full_name', 'short_name', 'alias'
        """
        matches = []

        # 1. Look for literal stock codes.
        for match in A_SH_CODE_PATTERN.finditer(text):
            code = match.group(1)
            explicit_market = match.group(2)[1:] if match.group(2) else None
            ts_code = normalize_stock_code(code, explicit_market)
            matches.append((match.group(), ts_code, 'code'))

        for match in HK_CODE_PATTERN.finditer(text):
            ts_code = match.group()
            matches.append((match.group(), ts_code, 'code'))

        # 2. Look for company names in priority order.
        # Full names have the highest priority.
        for name, code in self.name_to_code.items():
            if name in text:
                matches.append((name, code, 'full_name'))

        # Then short names.
        for name, code in self.short_names.items():
            if name in text:
                matches.append((name, code, 'short_name'))

        # Finally aliases.
        for alias, code in self.aliases.items():
            if alias in text:
                matches.append((alias, code, 'alias'))

        return matches


# Module-level singleton instance.
company_mapper = CompanyNameMapper()


def initialize_company_mapping(db_connection) -> None:
    """Load the company-name mapping from the database.

    Args:
        db_connection: SQLite database connection.
    """
    cursor = db_connection.cursor()
    cursor.execute("""
        SELECT ts_code, name, short_name
        FROM stock_company
        WHERE name IS NOT NULL
    """)

    for ts_code, name, short_name in cursor.fetchall():
        if name and short_name:
            company_mapper.add_company(ts_code, name, short_name)

    cursor.close()
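As a quick usage sketch of the module above (example data only, not part of the commit), the mapper can be exercised without a database:

from app.ingest.entity_recognition import CompanyNameMapper, normalize_stock_code

mapper = CompanyNameMapper()
mapper.add_company("000001.SZ", "平安银行股份有限公司", "平安银行", aliases=["PAB"])

print(normalize_stock_code("600000"))  # -> "600000.SH": a leading '6' is inferred as Shanghai
matches = mapper.find_codes("平安银行(000001.SZ)发布公告")
# Expected: a 'code' match for 000001.SZ and a 'short_name' match for 平安银行.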
@@ -18,6 +18,8 @@ import hashlib
 import random
 import time

+from app.ingest.entity_recognition import company_mapper, initialize_company_mapping
+
 try:  # pragma: no cover - optional dependency at runtime
     import feedparser  # type: ignore[import-not-found]
 except ImportError:  # pragma: no cover - graceful fallback
@@ -95,6 +97,15 @@ class RssFeedConfig:
     max_items: int = 50


+@dataclass
+class StockMention:
+    """A mention of a stock in text."""
+    matched_text: str
+    ts_code: str
+    match_type: str  # 'code', 'full_name', 'short_name', 'alias'
+    context: str  # the surrounding context snippet
+    confidence: float  # confidence of the match
+
 @dataclass
 class RssItem:
     """Structured representation of an RSS entry."""
@@ -106,8 +117,214 @@ class RssItem:
     summary: str
     source: str
     ts_codes: List[str] = field(default_factory=list)
-    industries: List[str] = field(default_factory=list)  # new: list of related industries
-    important_keywords: List[str] = field(default_factory=list)  # new: list of important keywords
+    stock_mentions: List[StockMention] = field(default_factory=list)
+    industries: List[str] = field(default_factory=list)
+    important_keywords: List[str] = field(default_factory=list)
+
+    def __post_init__(self):
+        """Initialize company mapper if not already initialized."""
+        # Skip database initialization in test environments.
+        if not hasattr(self, '_skip_db_init'):  # Only initialize outside tests.
+            from app.utils.db import db_session
+
+            # If company_mapper has no data yet, initialize it.
+            if not company_mapper.name_to_code:
+                with db_session() as conn:
+                    initialize_company_mapping(conn)
+
+    def extract_entities(self) -> None:
+        """Extract and validate entity mentions from title and summary."""
+        # Process the title and the summary separately.
+        title_matches = company_mapper.find_codes(self.title)
+        summary_matches = company_mapper.find_codes(self.summary)
+
+        # Merge de-duplicated matches by priority.
+        code_best_matches = {}  # ts_code -> (matched_text, match_type, is_title, context)
+
+        # Priority order: code > full name > short name > alias.
+        priority = {'code': 0, 'full_name': 1, 'short_name': 2, 'alias': 3}
+
+        for matches, text, is_title in [(title_matches, self.title, True),
+                                        (summary_matches, self.summary, False)]:
+            for matched_text, ts_code, match_type in matches:
+                # Extract the surrounding context.
+                context = self._extract_context(text, matched_text)
+
+                # Keep the match if the code is new or this match has a higher priority.
+                if (ts_code not in code_best_matches or
+                        priority[match_type] < priority[code_best_matches[ts_code][1]]):
+                    code_best_matches[ts_code] = (matched_text, match_type, is_title, context)
+
+        # Build the list of stock mentions.
+        for ts_code, (matched_text, match_type, is_title, context) in code_best_matches.items():
+            confidence = self._calculate_confidence(match_type, matched_text, context, is_title)
+
+            mention = StockMention(
+                matched_text=matched_text,
+                ts_code=ts_code,
+                match_type=match_type,
+                context=context,
+                confidence=confidence
+            )
+            self.stock_mentions.append(mention)
+
+        # Update ts_codes, keeping only high-confidence matches.
+        self.ts_codes = list(set(
+            mention.ts_code
+            for mention in self.stock_mentions
+            if mention.confidence > 0.7  # Keep only high-confidence matches.
+        ))
+
+        # Extract industry keywords.
+        self.extract_industries()
+
+        # Extract important keywords.
+        self.extract_important_keywords()
+
+    def _extract_context(self, text: str, matched_text: str) -> str:
+        """Extract the context around the matched text, preferring a complete sentence."""
+        # Locate the matched text.
+        start_pos = text.find(matched_text)
+        if start_pos == -1:
+            return ""
+
+        # Scan backwards to the start of the sentence (after 。?! or a newline).
+        sent_start = start_pos
+        while sent_start > 0:
+            if text[sent_start-1] in '。?!\n':
+                break
+            sent_start -= 1
+
+        # Scan forwards to the end of the sentence.
+        sent_end = start_pos + len(matched_text)
+        while sent_end < len(text):
+            if text[sent_end] in '。?!\n':
+                sent_end += 1
+                break
+            sent_end += 1
+
+        # If the context is too long, fall back to a fixed-length window.
+        context = text[sent_start:sent_end].strip()
+        if len(context) > 100:  # Maximum context length.
+            start = max(0, start_pos - 30)
+            end = min(len(text), start_pos + len(matched_text) + 30)
+            context = text[start:end].strip()
+
+        return context
+
+    def extract_industries(self) -> None:
+        """Extract industry keywords from the news title and summary."""
+        content = f"{self.title} {self.summary}".lower()
+        found_industries = set()
+
+        # Check each industry's keywords.
+        for industry, keywords in INDUSTRY_KEYWORDS.items():
+            # If any keyword is present, the item is tagged with that industry.
+            if any(keyword.lower() in content for keyword in keywords):
+                found_industries.add(industry)
+
+        self.industries = list(found_industries)
+
+    def extract_important_keywords(self) -> None:
+        """Extract important keywords: positive/negative sentiment words and specific events."""
+        content = f"{self.title} {self.summary}".lower()
+        found_keywords = set()
+
+        # 1. Positive keywords.
+        for keyword in POSITIVE_KEYWORDS:
+            if keyword.lower() in content:
+                found_keywords.add(f"+{keyword}")  # '+' prefix marks positive.
+
+        # 2. Negative keywords.
+        for keyword in NEGATIVE_KEYWORDS:
+            if keyword.lower() in content:
+                found_keywords.add(f"-{keyword}")  # '-' prefix marks negative.
+
+        # 3. Specific event keywords.
+        event_keywords = {
+            # Corporate actions
+            "收购": "M&A",
+            "并购": "M&A",
+            "重组": "重组",
+            "分拆": "分拆",
+            "上市": "IPO",
+            # Financial events
+            "业绩": "业绩",
+            "亏损": "业绩预警",
+            "盈利": "业绩预增",
+            "分红": "分红",
+            "回购": "回购",
+            # Regulatory events
+            "立案": "监管",
+            "调查": "监管",
+            "问询": "监管",
+            "处罚": "处罚",
+            # Major projects
+            "中标": "中标",
+            "签约": "签约",
+            "战略合作": "合作",
+        }
+
+        for trigger, event in event_keywords.items():
+            if trigger in content:
+                found_keywords.add(f"#{event}")  # '#' prefix marks an event.
+
+        self.important_keywords = list(found_keywords)
+
+    def _calculate_confidence(self, match_type: str, matched_text: str, context: str, is_title: bool = False) -> float:
+        """Calculate the confidence of an entity match.
+
+        The following factors are considered:
+        1. Base confidence of the match type
+        2. Position of the entity in the text (title/opening matters more)
+        3. Context keywords
+        4. Stock-related verbs
+        5. Completeness of the entity
+        """
+        # Base confidence.
+        base_confidence = {
+            'code': 0.9,        # direct stock-code match
+            'full_name': 0.85,  # full company-name match
+            'short_name': 0.7,  # short-name match
+            'alias': 0.6        # alias match
+        }.get(match_type, 0.5)
+
+        confidence = base_confidence
+        context_lower = context.lower()
+
+        # 1. Position bonus.
+        if is_title:
+            confidence += 0.1
+        if context.startswith(matched_text):
+            confidence += 0.05
+
+        # 2. Entity-completeness check.
+        if match_type == 'code' and '.' in matched_text:  # Full stock code (with market suffix).
+            confidence += 0.05
+        elif match_type == 'full_name' and any(suffix in matched_text for suffix in ["股份有限公司", "有限公司"]):
+            confidence += 0.05
+
+        # 3. Context keywords.
+        context_bonus = 0.0
+        corporate_terms = ["公司", "集团", "企业", "上市", "控股", "总部"]
+        if any(term in context_lower for term in corporate_terms):
+            context_bonus += 0.1
+
+        # 4. Stock-related verbs.
+        stock_verbs = ["发布", "公告", "披露", "表示", "报告", "投资", "回购", "增持", "减持"]
+        if any(verb in context_lower for verb in stock_verbs):
+            context_bonus += 0.05
+
+        # 5. Financial / business terms.
+        business_terms = ["业绩", "营收", "利润", "股价", "市值", "经营", "产品", "服务", "战略"]
+        if any(term in context_lower for term in business_terms):
+            context_bonus += 0.05
+
+        # Cap the context bonus.
+        confidence += min(context_bonus, 0.2)
+
+        # Clamp confidence to [0, 1].
+        return min(1.0, max(0.0, confidence))
+
+
 DEFAULT_RSS_SOURCES: Tuple[RssFeedConfig, ...] = ()
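To make the scoring in `_calculate_confidence` concrete, here is a small worked example (hypothetical input, following the weights above):

# The short name "平安银行" matched in a title whose extracted context is "平安银行发布公告":
confidence = 0.7              # base confidence for a 'short_name' match
confidence += 0.1             # the match was found in the title
confidence += 0.05            # the context starts with the matched text
confidence += min(0.05, 0.2)  # context bonus: a stock-related verb ("发布") appears; no corporate or business terms
assert abs(confidence - 0.90) < 1e-9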
@@ -255,7 +472,7 @@ def _fetch_feed_items(


 def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
-    """Drop duplicate stories by link/id fingerprint."""
+    """Drop duplicate stories by link/id fingerprint and process entities."""

     seen = set()
     unique: List[RssItem] = []
@@ -264,7 +481,14 @@ def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
         if key in seen:
             continue
         seen.add(key)
-        unique.append(item)
+
+        # Extract entities and related information.
+        item.extract_entities()
+
+        # Keep the story only if it mentions a tracked stock.
+        if item.stock_mentions:
+            unique.append(item)
+
     return unique
docs/TODO_UNIFIED.md  (new file, 106 lines)
@@ -0,0 +1,106 @@
# LLM Quantitative Trading Assistant — Development Plan

> Project vision: build an investment-assistant tool fit for real trading, with domain competence at the leading edge of investing. The core is providing high-quality investment decision support through multi-agent collaboration.

> Progress (2025-10-05):
> ✓ Basic factor-computation framework
> ✓ Data access and monitoring
> ✓ Core backtesting system
> ✓ Basic LLM integration
> △ RSS news processing
> △ UI and monitoring system

## I. Core Feature Modules by Priority

### 1. Data & Feature Layer (P0)
#### 1.1 Factor-computation module optimization
- [x] Complete the `compute_factors()` implementation:
  - [x] Add data-validity checks
  - [x] Implement outlier detection and handling
  - [ ] Add computation progress display and logging
- [ ] Improve factor persistence performance
- [ ] Support incremental computation mode

#### 1.2 DataBroker enhancements
- [x] Automatic retry for failed data requests
- [x] Data-source health monitoring
- [ ] Design a data-quality metrics system

#### 1.3 Factor library expansion
- [x] Expand the momentum factor group
- [x] Develop the valuation factor group
- [x] Design the liquidity factor group
- [ ] Build the market-sentiment factor group
- [ ] Develop factor-combination and weight-optimization algorithms

#### 1.4 News data-source improvements
- [x] RSS fetching and parsing
- [x] Stronger sentiment analysis
- [ ] Improve entity-recognition accuracy
- [ ] Implement news-recency scoring

### 2. Decision Optimization (P1)
#### 2.1 Decision-environment enhancements
- [x] Extend the DecisionEnv action space:
  - [x] Support prompt-version selection
  - [x] Allow tuning per-department temperature
  - [ ] Optimize the function-calling strategy
- [x] Add observation dimensions to the environment:
  - [x] Include turnover metrics
  - [x] Include risk-event statistics
  - [ ] Add market-sentiment metrics

#### 2.2 Backtesting improvements
- [x] Improve order-matching logic:
  - [x] Unified position limits
  - [x] Respect turnover constraints
  - [x] Add slippage simulation
  - [x] Account for transaction costs
- [x] Strengthen risk controls:
  - [x] Implement a stop-loss mechanism
  - [x] Add volatility limits
  - [x] Set concentration limits

### 3. LLM Collaboration (P1)
- [x] Streamline and optimize provider management
- [x] Strengthen the function-calling architecture
- [x] Improve error handling and retry strategy
- [ ] Optimize prompt engineering:
  - [ ] Design configurable role prompts
  - [ ] Improve data-scope control
  - [ ] Improve context management

### 4. UI & Monitoring (P2)
#### 4.1 Feature enhancements
- [ ] One-click re-evaluation
- [ ] Multi-version experiment comparison
- [x] Real-time metrics panel
- [ ] Anomaly-log drill-down

#### 4.2 Monitoring enhancements
- [ ] "Monitor-only, no intervention" mode
- [x] Real-time strategy evaluation
- [x] Risk alerting
- [ ] Performance-attribution analysis

### 5. Testing & Deployment (P2)
- [x] Unit tests for core paths
- [ ] End-to-end integration tests
- [ ] Better log collection

## II. Near-Term Focus (2025 Q4)

1. ✓ Optimize and refactor the factor-computation module
2. ✓ Expand the basic factor library
3. ✓ Improve DataBroker data-access performance
4. △ Finish RSS news data-source integration
5. ✓ Begin decision-environment enhancements

## III. Development Principles

1. Keep it simple: each module implements only its core functionality
2. Prioritize reliability: core features must be stable and dependable
3. Easy to use: the interface should be simple and intuitive
4. Fault-tolerant design: leave room for manual intervention at key points

> Note: this plan will keep evolving with real-world usage, always favouring simplicity and practicality. Last updated: 2025-10-05
tests/test_entity_recognition.py  (new file, 110 lines)
@@ -0,0 +1,110 @@
"""Test improved entity recognition in RSS processing."""
from datetime import datetime, timezone
import pytest

from app.ingest.entity_recognition import CompanyNameMapper, company_mapper
from app.ingest.rss import RssItem, StockMention


def test_company_name_mapper():
    mapper = CompanyNameMapper()

    # Add a test company.
    mapper.add_company(
        ts_code="000001.SZ",
        full_name="平安银行股份有限公司",
        short_name="平安银行",
        aliases=["平安", "PAB"]
    )

    # Name matching finds all possible matches.
    matches = mapper.find_codes("平安银行股份有限公司公布2025年业绩")
    assert len([m for m in matches if m[1] == "000001.SZ"]) >= 1
    # There should be a full-name match.
    assert any(m[2] == "full_name" and m[1] == "000001.SZ" for m in matches)

    # Short-name matching.
    matches = mapper.find_codes("平安银行发布公告")
    # The short name should match; other matches may appear as well.
    assert any(m[1] == "000001.SZ" and m[2] == "short_name" for m in matches)

    # Alias matching.
    matches = mapper.find_codes("PAB发布新产品")
    # At least the alias should match.
    assert any(m[1] == "000001.SZ" and m[2] == "alias" for m in matches)

    # Direct stock-code matching.
    matches = mapper.find_codes("000001.SZ开盘上涨")
    assert len(matches) == 1
    assert matches[0][1] == "000001.SZ"
    assert matches[0][2] == "code"


def test_rss_item_entity_extraction():
    # Use the global company_mapper.
    company_mapper.name_to_code.clear()  # Clear previous data.
    company_mapper.add_company(
        ts_code="000001.SZ",
        full_name="平安银行股份有限公司",
        short_name="平安银行",
        aliases=["平安", "PAB"]
    )

    # Build a test item and skip database initialization.
    class TestRssItem(RssItem):
        _skip_db_init = True

    item = TestRssItem(
        id="test_news",
        title="平安银行发布2025年业绩预告",
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary="平安银行股份有限公司(000001.SZ)今日发布2025年业绩预告",
        source="test"
    )

    # Extract entities.
    item.extract_entities()

    # Because of the priority rules, only the best match per code is kept.
    matched_types = set(m.match_type for m in item.stock_mentions)
    assert "code" in matched_types or "full_name" in matched_types  # At least a code or full-name match.

    # Check the unique stock codes.
    assert len(item.ts_codes) == 1  # Only one unique stock code.
    assert item.ts_codes[0] == "000001.SZ"

    # Check the confidence calculation.
    high_confidence = [m for m in item.stock_mentions if m.confidence > 0.7]
    assert len(high_confidence) >= 1  # At least one high-confidence match.


def test_rss_item_context_extraction():
    # Use the global company_mapper.
    company_mapper.name_to_code.clear()  # Clear previous data.
    company_mapper.add_company(
        ts_code="000001.SZ",
        full_name="平安银行股份有限公司",
        short_name="平安银行"
    )

    # Build a test item with context and skip database initialization.
    class TestRssItem(RssItem):
        _skip_db_init = True

    item = TestRssItem(
        id="test_news",
        title="多家银行业绩报告",
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary="在银行业整体向好的背景下,平安银行表现突出,营收增长明显",
        source="test"
    )

    # Extract entities.
    item.extract_entities()

    # Check context extraction.
    assert len(item.stock_mentions) > 0
    mention = item.stock_mentions[0]
    assert len(mention.context) <= 70  # Context length limit (30 characters on each side).
    assert "平安银行" in mention.context
    assert "银行业" in mention.context  # Should include the preceding text.
    assert "营收增长" in mention.context  # Should include the following text.
tests/test_rss_item_industry_keywords.py  (new file, 83 lines)
@@ -0,0 +1,83 @@
"""Test industry and keyword extraction in RSS processing."""
from datetime import datetime, timezone

from app.ingest.rss import RssItem


def test_industry_extraction():
    """Test industry keyword extraction."""
    # Build a test item and skip database initialization.
    class TestRssItem(RssItem):
        _skip_db_init = True

    item = TestRssItem(
        id="test_news",
        title="某半导体公司推出新一代芯片",
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary="该公司在集成电路领域取得重大突破,新产品将用于5G通信",
        source="test"
    )

    # Extract industry keywords.
    item.extract_industries()

    # Check the result.
    assert "半导体" in item.industries
    assert len(item.industries) >= 1


def test_important_keyword_extraction():
    """Test important keyword extraction."""
    # Build a test item containing positive, negative and event keywords; skip database initialization.
    class TestRssItem(RssItem):
        _skip_db_init = True

    item = TestRssItem(
        id="test_news",
        title="某公司业绩超预期,同时宣布重大收购计划",
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary="营收增长显著,但部分业务亏损,将通过并购扩张",
        source="test"
    )

    # Extract important keywords.
    item.extract_important_keywords()

    # The result should contain positive, negative and event keywords.
    keywords = set(item.important_keywords)

    # At least one positive keyword (prefixed with '+').
    assert any(k.startswith('+') for k in keywords)

    # At least one negative keyword (prefixed with '-').
    assert any(k.startswith('-') for k in keywords)

    # At least one event keyword (prefixed with '#').
    assert any(k.startswith('#') for k in keywords)


def test_rss_item_full_extraction():
    """Test full entity, industry and keyword extraction."""
    # Build a test item covering several kinds of information; skip database initialization.
    class TestRssItem(RssItem):
        _skip_db_init = True

    item = TestRssItem(
        id="test_news",
        title="半导体行业利好:某公司重大突破",
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary="集成电路领域取得重大进展,业绩超预期,宣布增持计划",
        source="test"
    )

    # Extract everything.
    item.extract_entities()  # This also triggers industry and keyword extraction.

    # Check completeness of the result.
    assert item.industries  # Should at least detect the semiconductor industry.
    assert item.important_keywords  # Should find keywords.

    # Check the keyword types.
    keywords = set(item.important_keywords)
    assert any(k.startswith('+') for k in keywords)  # Positive keyword.
    assert any(k.startswith('#') for k in keywords)  # Event keyword.