update
This commit is contained in:
parent
69a3cc69c6
commit
16a5fae732
@ -8,10 +8,48 @@ end-to-end automated decision-making requirements.
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Sequence, Optional
|
||||
from typing import Dict, List, Sequence, Optional, Any
|
||||
import functools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
LOG_EXTRA = {"stage": "extended_factors"}
|
||||
|
||||
|
||||
def handle_factor_errors(func: Any) -> Any:
    """Decorator that converts any exception from a factor computation into ``None``.

    The wrapped callable is invoked unchanged. If it raises, the failure is
    logged at ERROR level (with traceback) through the module ``LOGGER`` and
    ``None`` is returned to the caller instead of propagating the exception.

    Args:
        func: The callable to wrap.

    Returns:
        The wrapping callable; ``functools.wraps`` preserves ``func``'s metadata.
    """
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Optional[float]:
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # Best-effort recovery of the factor name for the log line. For a
            # method call the positional layout is (self, factor_name, ...),
            # so the name is available as soon as two positionals are present
            # -- even when the price series are passed by keyword.
            factor_name = "unknown"
            if len(args) > 1 and isinstance(args[1], str):
                factor_name = args[1]
            elif "factor_name" in kwargs:
                factor_name = kwargs["factor_name"]

            LOGGER.error(
                "计算因子出错 name=%s error=%s",
                factor_name,
                str(e),
                exc_info=True,
                extra=LOG_EXTRA,
            )
            return None

    return wrapper
|
||||
|
||||
from app.core.indicators import momentum, rolling_mean, normalize
|
||||
from app.core.technical import (
|
||||
rsi, macd, bollinger_bands, obv_momentum, price_volume_trend
|
||||
@ -101,12 +139,21 @@ class ExtendedFactors:
|
||||
)
|
||||
all_factors = calculator.compute_all_factors(close_series, volume_series)
|
||||
normalized = calculator.normalize_factors(all_factors)
|
||||
|
||||
属性:
|
||||
factor_specs: Dict[str, FactorSpec], 因子名称到因子规格的映射
|
||||
"""
|
||||
|
||||
def __init__(self):
    """Build the factor-name -> spec lookup and log how many specs were loaded.

    The specs come from the module-level ``EXTENDED_FACTORS`` collection; each
    entry is indexed by its ``name`` attribute.
    """
    specs_by_name = {}
    for spec in EXTENDED_FACTORS:
        specs_by_name[spec.name] = spec
    self.factor_specs = specs_by_name

    LOGGER.info(
        "初始化因子计算器,加载因子数量: %d",
        len(self.factor_specs),
        extra=LOG_EXTRA,
    )
|
||||
|
||||
@handle_factor_errors
|
||||
def compute_factor(self,
|
||||
factor_name: str,
|
||||
close_series: Sequence[float],
|
||||
@ -114,21 +161,24 @@ class ExtendedFactors:
|
||||
"""计算单个因子值
|
||||
|
||||
Args:
|
||||
factor_name: 因子名称
|
||||
factor_name: 因子名称,必须是已注册的因子
|
||||
close_series: 收盘价序列,从新到旧排序
|
||||
volume_series: 成交量序列,从新到旧排序
|
||||
|
||||
Returns:
|
||||
因子值,如果计算失败则返回None
|
||||
factor_value: Optional[float], 计算得到的因子值,失败时返回None
|
||||
|
||||
Raises:
|
||||
ValueError: 当因子名称未知或数据不足时抛出
|
||||
"""
|
||||
try:
|
||||
spec = self.factor_specs.get(factor_name)
|
||||
if spec is None:
|
||||
print(f"Unknown factor: {factor_name}")
|
||||
return None
|
||||
raise ValueError(f"未知的因子名称: {factor_name}")
|
||||
|
||||
if len(close_series) < spec.window:
|
||||
return None
|
||||
raise ValueError(
|
||||
f"数据长度不足: 需要{spec.window},实际{len(close_series)}"
|
||||
)
|
||||
|
||||
# 技术分析因子
|
||||
if factor_name == "tech_rsi_14":
|
||||
@ -189,43 +239,88 @@ class ExtendedFactors:
|
||||
ma = rolling_mean(volume_series, window)
|
||||
return volume_series[0] / ma if ma > 0 else None
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error computing factor {factor_name}: {str(e)}")
|
||||
return None
|
||||
raise ValueError(f"因子 {factor_name} 没有对应的计算实现")
|
||||
|
||||
def compute_all_factors(self,
                        close_series: Sequence[float],
                        volume_series: Sequence[float]) -> Dict[str, float]:
    """Compute every registered extended factor for one price/volume history.

    Args:
        close_series: Closing prices, ordered newest first.
        volume_series: Traded volumes, ordered newest first.

    Returns:
        Mapping of factor name to value containing only the factors for which
        ``compute_factor`` returned something other than ``None``. The mapping
        is empty when no factor could be computed.
    """
    total = len(self.factor_specs)

    # Evaluate lazily, in registration order, and keep only usable values.
    evaluated = (
        (name, self.compute_factor(name, close_series, volume_series))
        for name in self.factor_specs
    )
    results = {name: value for name, value in evaluated if value is not None}
    succeeded = len(results)

    LOGGER.info(
        "因子计算完成 total=%d success=%d failed=%d",
        total,
        succeeded,
        total - succeeded,
        extra=LOG_EXTRA,
    )
    return results
|
||||
def normalize_factors(self,
                      factors: Dict[str, float],
                      clip_threshold: float = 3.0) -> Dict[str, float]:
    """Normalize raw factor values, dropping the ones that cannot be normalized.

    Each non-``None`` value is passed to ``normalize(value, clip_threshold)``.
    A factor is omitted from the result when its raw value is ``None``, when
    the normalized value is NaN, or when ``normalize`` raises (the latter is
    logged as a warning).

    Args:
        factors: Raw factor values keyed by factor name.
        clip_threshold: Clipping threshold forwarded to ``normalize``.
            NOTE(review): the exact scaling and clipping (and the resulting
            output range) are defined by ``normalize`` -- confirm there.

    Returns:
        Normalized values keyed by factor name, successful entries only.
    """
    results = {}

    for name, raw in factors.items():
        if raw is None:
            continue
        try:
            scaled = normalize(raw, clip_threshold)
        except Exception as e:
            LOGGER.warning(
                "因子标准化失败 name=%s error=%s",
                name,
                str(e),
                extra=LOG_EXTRA,
            )
            continue
        if np.isnan(scaled):
            continue
        results[name] = scaled

    ok = len(results)
    LOGGER.info(
        "因子标准化完成 total=%d success=%d failed=%d",
        len(factors),
        ok,
        len(factors) - ok,
        extra=LOG_EXTRA,
    )
    return results
|
||||
145
app/ingest/entity_recognition.py
Normal file
145
app/ingest/entity_recognition.py
Normal file
@ -0,0 +1,145 @@
|
||||
"""Stock code mapping and entity recognition utilities."""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
# Stock-code regular expressions.
#
# In Python 3, ``\b`` treats CJK characters as word characters, so a code that
# touches Chinese text ("000001.SZ开盘", "平安银行000001.SZ") has no word
# boundary next to it. Suffix-qualified codes are therefore delimited with
# ASCII-alphanumeric lookarounds instead. Bare six-digit numbers keep the
# stricter ``\b`` rule so amounts such as "100000元" are still ignored.
# Group 1 is the digits; group 2 is the optional ".SH"/".SZ" suffix.
A_SH_CODE_PATTERN = re.compile(
    r"(?:\b|(?<![0-9A-Za-z])(?=\d{6}\.(?:SH|SZ)(?![0-9A-Za-z])))"
    r"(\d{6})"
    r"(?:(\.(?:SH|SZ))(?![0-9A-Za-z])|\b)",
    re.IGNORECASE,
)
# Hong Kong codes: accept both the 4-digit ("0700.HK") and the zero-padded
# 5-digit ("00700.HK") spellings. Group 1 is the digits.
HK_CODE_PATTERN = re.compile(
    r"(?<![0-9A-Za-z])(\d{4,5})\.HK(?![0-9A-Za-z])",
    re.IGNORECASE,
)
|
||||
|
||||
def normalize_stock_code(code: str, explicit_market: Optional[str] = None) -> str:
    """Return ``code`` in ``<digits>.<MARKET>`` form, e.g. ``'000001.SZ'``.

    Args:
        code: Raw stock code, with or without a market suffix. Surrounding
            whitespace is ignored.
        explicit_market: Market to append when ``code`` carries no suffix,
            e.g. ``'SH'`` or ``'SZ'``. Ignored when ``code`` already has one.

    Returns:
        The upper-cased code with a market suffix. When no market is given,
        it is inferred from the leading digits, defaulting to Shanghai.
    """
    code = code.strip()

    # A suffix already present wins over everything else.
    if "." in code:
        return code.upper()

    if explicit_market:
        return f"{code}.{explicit_market.strip().upper()}"

    # Infer the exchange from the numbering plan.
    if code.startswith(("4", "8", "92")):
        # Beijing Stock Exchange segments (43/83/87/88/920...).
        return f"{code}.BJ"
    if code.startswith(("0", "2", "3")):
        # Shenzhen main board / B shares / ChiNext.
        return f"{code}.SZ"
    # "6" (and anything unrecognized) is treated as Shanghai.
    return f"{code}.SH"
|
||||
|
||||
# 公司名称变体模式
|
||||
COMPANY_SUFFIXES = ["股份", "科技", "公司", "集团", "股份有限公司", "有限公司"]
|
||||
|
||||
class CompanyNameMapper:
    """Map company names and stock codes found in free text to stock codes.

    Three lookup tables are kept: registered full names, official short
    names, and aliases (caller-supplied ones plus a variant derived by
    stripping a company-type suffix from the full name). Name lookup is a
    plain substring test against the text.
    """

    def __init__(self):
        self.name_to_code: Dict[str, str] = {}  # full name -> code
        self.short_names: Dict[str, str] = {}   # short name -> code
        self.aliases: Dict[str, str] = {}       # alias -> code

    def add_company(self, ts_code: str, full_name: str, short_name: str,
                    aliases: Optional[List[str]] = None):
        """Register a company under its full name, short name and aliases.

        Args:
            ts_code: Stock code in a format like ``'000001.SZ'``.
            full_name: Full registered company name.
            short_name: Official short name.
            aliases: Optional additional names.

        An alias that is already registered keeps its existing code; full and
        short names overwrite any previous entry with the same key.
        """
        self.name_to_code[full_name] = ts_code
        self.short_names[short_name] = ts_code

        # Derived variants first, then explicit aliases; first writer wins.
        for variant in self._generate_name_variants(full_name):
            self.aliases.setdefault(variant, ts_code)
        for alias in aliases or ():
            self.aliases.setdefault(alias, ts_code)

    def _generate_name_variants(self, full_name: str) -> Set[str]:
        """Return the name with its company-type suffix removed, if any.

        Suffixes are tried longest first so that "股份有限公司" is stripped as
        a whole rather than only its trailing "公司". Only the first matching
        suffix is considered, and the variant is kept only when more than two
        characters remain.
        """
        variants: Set[str] = set()
        for suffix in sorted(COMPANY_SUFFIXES, key=len, reverse=True):
            if full_name.endswith(suffix):
                variant = full_name[:-len(suffix)].strip()
                if len(variant) > 2:  # avoid overly short, ambiguous variants
                    variants.add(variant)
                break
        return variants

    def find_codes(self, text: str) -> List[Tuple[str, str, str]]:
        """Find stock codes and company mentions in ``text``.

        Returns:
            List of ``(matched_text, stock_code, match_type)`` tuples where
            ``match_type`` is one of ``'code'``, ``'full_name'``,
            ``'short_name'`` or ``'alias'``. Results are grouped in that
            order. The same company may appear several times when it matches
            through more than one table.
        """
        matches: List[Tuple[str, str, str]] = []

        # 1. Literal stock codes.
        for match in A_SH_CODE_PATTERN.finditer(text):
            suffix = match.group(2)
            explicit_market = suffix[1:] if suffix else None  # drop the "."
            ts_code = normalize_stock_code(match.group(1), explicit_market)
            matches.append((match.group(), ts_code, 'code'))

        for match in HK_CODE_PATTERN.finditer(text):
            # The pattern is case-insensitive; canonical codes are upper-case.
            matches.append((match.group(), match.group().upper(), 'code'))

        # 2. Names, from most to least specific table.
        for table, match_type in (
            (self.name_to_code, 'full_name'),
            (self.short_names, 'short_name'),
            (self.aliases, 'alias'),
        ):
            for name, code in table.items():
                if name in text:
                    matches.append((name, code, match_type))

        return matches
|
||||
|
||||
# 创建全局单例实例
|
||||
company_mapper = CompanyNameMapper()
|
||||
|
||||
def initialize_company_mapping(db_connection) -> None:
    """Load company names from the database into the global ``company_mapper``.

    Reads ``ts_code``, ``name`` and ``short_name`` from the ``stock_company``
    table and registers every row that has both a name and a short name.

    Args:
        db_connection: DB-API connection (SQLite, per the original docstring)
            exposing ``cursor()``.
    """
    cursor = db_connection.cursor()
    try:
        cursor.execute("""
            SELECT ts_code, name, short_name
            FROM stock_company
            WHERE name IS NOT NULL
        """)

        for ts_code, name, short_name in cursor.fetchall():
            # Rows with an empty name or short name cannot be matched on.
            if name and short_name:
                company_mapper.add_company(ts_code, name, short_name)
    finally:
        # Release the cursor even when the query or registration fails.
        cursor.close()
|
||||
@ -18,6 +18,8 @@ import hashlib
|
||||
import random
|
||||
import time
|
||||
|
||||
from app.ingest.entity_recognition import company_mapper, initialize_company_mapping
|
||||
|
||||
try: # pragma: no cover - optional dependency at runtime
|
||||
import feedparser # type: ignore[import-not-found]
|
||||
except ImportError: # pragma: no cover - graceful fallback
|
||||
@ -95,6 +97,15 @@ class RssFeedConfig:
|
||||
max_items: int = 50
|
||||
|
||||
|
||||
@dataclass
class StockMention:
    """One reference to a stock found in a piece of text.

    Attributes:
        matched_text: The exact substring that matched.
        ts_code: Stock code the match resolves to.
        match_type: How it matched: 'code', 'full_name', 'short_name' or 'alias'.
        context: Snippet of surrounding text for the match.
        confidence: Confidence score assigned to the match.
    """

    matched_text: str
    ts_code: str
    match_type: str
    context: str
    confidence: float
|
||||
|
||||
@dataclass
|
||||
class RssItem:
|
||||
"""Structured representation of an RSS entry."""
|
||||
@ -106,8 +117,214 @@ class RssItem:
|
||||
summary: str
|
||||
source: str
|
||||
ts_codes: List[str] = field(default_factory=list)
|
||||
industries: List[str] = field(default_factory=list) # 新增:相关行业列表
|
||||
important_keywords: List[str] = field(default_factory=list) # 新增:重要关键词列表
|
||||
stock_mentions: List[StockMention] = field(default_factory=list)
|
||||
industries: List[str] = field(default_factory=list)
|
||||
important_keywords: List[str] = field(default_factory=list)
|
||||
|
||||
def __post_init__(self):
    """Populate the global company mapper from the database on first use.

    Nothing happens when the instance (or its class) defines
    ``_skip_db_init`` -- the hook the tests use to avoid touching the
    database -- or when the mapper already holds full-name entries.
    """
    if hasattr(self, '_skip_db_init'):
        return

    # Imported lazily so the skip path never touches the database layer.
    from app.utils.db import db_session

    if company_mapper.name_to_code:
        return  # already loaded

    with db_session() as conn:
        initialize_company_mapping(conn)
|
||||
|
||||
def extract_entities(self) -> None:
    """Extract stock mentions from title and summary, then industries and keywords.

    For every stock code found, only the best match is kept. Preference is
    code > full name > short name > alias, and on equal priority a title
    match beats a summary match. ``stock_mentions`` and ``ts_codes`` are
    rebuilt from scratch on each call, so calling this method repeatedly does
    not accumulate duplicates. ``ts_codes`` holds only codes whose mention
    has confidence above 0.7.
    """
    # Lower number = higher priority.
    priority = {'code': 0, 'full_name': 1, 'short_name': 2, 'alias': 3}

    # ts_code -> (matched_text, match_type, is_title, context)
    best = {}
    for text, is_title in ((self.title, True), (self.summary, False)):
        for matched_text, ts_code, match_type in company_mapper.find_codes(text):
            current = best.get(ts_code)
            # Strict comparison: an equally ranked later match never replaces
            # the earlier one, which is what makes title matches win ties.
            if current is not None and priority[match_type] >= priority[current[1]]:
                continue
            context = self._extract_context(text, matched_text)
            best[ts_code] = (matched_text, match_type, is_title, context)

    mentions = []
    for ts_code, (matched_text, match_type, is_title, context) in best.items():
        confidence = self._calculate_confidence(match_type, matched_text, context, is_title)
        mentions.append(StockMention(
            matched_text=matched_text,
            ts_code=ts_code,
            match_type=match_type,
            context=context,
            confidence=confidence,
        ))

    # Replace rather than append so repeated calls stay idempotent.
    self.stock_mentions = mentions

    # Codes are unique per mention already; a list keeps discovery order.
    self.ts_codes = [m.ts_code for m in mentions if m.confidence > 0.7]

    self.extract_industries()
    self.extract_important_keywords()
|
||||
|
||||
def _extract_context(self, text: str, matched_text: str) -> str:
    """Return the sentence around the first occurrence of ``matched_text``.

    A sentence is delimited by "。", "!", "?" (full- or half-width) or a
    newline; the closing delimiter is included. When the sentence is longer
    than 100 characters, a window of up to 30 characters on each side of the
    match is returned instead, clamped so it never reaches into a
    neighbouring sentence.

    Args:
        text: Text to search.
        matched_text: Substring whose context is wanted.

    Returns:
        The stripped context, or ``""`` when either argument is empty or
        ``matched_text`` does not occur in ``text``.
    """
    if not text or not matched_text:
        return ""

    start_pos = text.find(matched_text)
    if start_pos == -1:
        return ""
    match_end = start_pos + len(matched_text)

    # ASCII "." is deliberately excluded: it appears inside codes ("000001.SZ").
    terminators = "。!?!?\n"

    # Walk back to just after the previous sentence terminator.
    sent_start = start_pos
    while sent_start > 0 and text[sent_start - 1] not in terminators:
        sent_start -= 1

    # Walk forward through (and including) the next terminator.
    sent_end = match_end
    while sent_end < len(text):
        sent_end += 1
        if text[sent_end - 1] in terminators:
            break

    context = text[sent_start:sent_end].strip()
    if len(context) > 100:  # too long: fall back to a fixed-size window
        window_start = max(sent_start, start_pos - 30)
        window_end = min(sent_end, match_end + 30)
        context = text[window_start:window_end].strip()

    return context
|
||||
|
||||
def extract_industries(self) -> None:
    """Tag the item with every industry whose keywords occur in title or summary.

    Matching is a case-insensitive substring test against the module-level
    ``INDUSTRY_KEYWORDS`` table; one hit is enough to assign the industry.
    The result replaces ``self.industries``.
    """
    haystack = f"{self.title} {self.summary}".lower()

    matched = {
        industry
        for industry, keywords in INDUSTRY_KEYWORDS.items()
        if any(keyword.lower() in haystack for keyword in keywords)
    }

    self.industries = list(matched)
|
||||
|
||||
def extract_important_keywords(self) -> None:
    """Collect sentiment and event tags found in title and summary.

    Tags are prefixed by kind: ``+`` for entries of ``POSITIVE_KEYWORDS``,
    ``-`` for entries of ``NEGATIVE_KEYWORDS`` and ``#`` for event labels.
    Several trigger words may map to the same event label; duplicates
    collapse. The result replaces ``self.important_keywords``.
    """
    content = f"{self.title} {self.summary}".lower()

    # trigger word -> event label
    event_keywords = {
        # corporate actions
        "收购": "M&A",
        "并购": "M&A",
        "重组": "重组",
        "分拆": "分拆",
        "上市": "IPO",
        # financial events
        "业绩": "业绩",
        "亏损": "业绩预警",
        "盈利": "业绩预增",
        "分红": "分红",
        "回购": "回购",
        # regulatory events
        "立案": "监管",
        "调查": "监管",
        "问询": "监管",
        "处罚": "处罚",
        # major projects
        "中标": "中标",
        "签约": "签约",
        "战略合作": "合作",
    }

    tagged = {f"+{word}" for word in POSITIVE_KEYWORDS if word.lower() in content}
    tagged |= {f"-{word}" for word in NEGATIVE_KEYWORDS if word.lower() in content}
    tagged |= {
        f"#{event}"
        for trigger, event in event_keywords.items()
        if trigger in content
    }

    self.important_keywords = list(tagged)
|
||||
|
||||
def _calculate_confidence(self, match_type: str, matched_text: str, context: str, is_title: bool = False) -> float:
    """Score how likely a match really refers to the stock, in ``[0, 1]``.

    The score starts from a base value per match type and is raised by:
      1. position -- found in the title, or at the very start of its context;
      2. completeness -- a code with a market suffix, or a full name that
         contains a "有限公司" style suffix;
      3. context vocabulary -- corporate terms, disclosure/trading verbs and
         business/financial terms (this part is capped at 0.2 in total).
    """
    base_by_type = {
        'code': 0.9,         # literal stock code
        'full_name': 0.85,   # full registered name
        'short_name': 0.7,   # official short name
        'alias': 0.6,        # alias / derived variant
    }
    confidence = base_by_type.get(match_type, 0.5)
    context_lower = context.lower()

    # 1. Position.
    if is_title:
        confidence += 0.1
    if context.startswith(matched_text):
        confidence += 0.05

    # 2. Completeness of the matched entity.
    is_full_code = match_type == 'code' and '.' in matched_text
    is_suffixed_name = match_type == 'full_name' and any(
        suffix in matched_text for suffix in ["股份有限公司", "有限公司"]
    )
    if is_full_code or is_suffixed_name:
        confidence += 0.05

    # 3. Vocabulary found in the context: (terms, bonus) per group.
    vocabulary_bonuses = (
        (["公司", "集团", "企业", "上市", "控股", "总部"], 0.1),
        (["发布", "公告", "披露", "表示", "报告", "投资", "回购", "增持", "减持"], 0.05),
        (["业绩", "营收", "利润", "股价", "市值", "经营", "产品", "服务", "战略"], 0.05),
    )
    context_bonus = 0.0
    for terms, bonus in vocabulary_bonuses:
        if any(term in context_lower for term in terms):
            context_bonus += bonus
    confidence += min(context_bonus, 0.2)

    # Clamp to the valid range.
    return min(1.0, max(0.0, confidence))
|
||||
|
||||
|
||||
DEFAULT_RSS_SOURCES: Tuple[RssFeedConfig, ...] = ()
|
||||
@ -255,7 +472,7 @@ def _fetch_feed_items(
|
||||
|
||||
|
||||
def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
|
||||
"""Drop duplicate stories by link/id fingerprint."""
|
||||
"""Drop duplicate stories by link/id fingerprint and process entities."""
|
||||
|
||||
seen = set()
|
||||
unique: List[RssItem] = []
|
||||
@ -264,7 +481,14 @@ def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
|
||||
# 提取实体和相关信息
|
||||
item.extract_entities()
|
||||
|
||||
# 如果找到了相关股票,则保留这条新闻
|
||||
if item.stock_mentions:
|
||||
unique.append(item)
|
||||
|
||||
return unique
|
||||
|
||||
|
||||
|
||||
106
docs/TODO_UNIFIED.md
Normal file
106
docs/TODO_UNIFIED.md
Normal file
@ -0,0 +1,106 @@
|
||||
# LLM量化交易助理系统开发计划
|
||||
|
||||
> 项目愿景:开发一个可实战的投资助理工具,其业务水平要处在投资的前列。核心是通过多智能体协作提供高质量的投资决策支持。
|
||||
|
||||
> 开发进度(2025-10-05):
|
||||
> ✓ 基础因子计算框架
|
||||
> ✓ 数据访问与监控
|
||||
> ✓ 核心回测系统
|
||||
> ✓ LLM基础集成
|
||||
> △ RSS新闻处理
|
||||
> △ UI与监控系统
|
||||
|
||||
## 一、核心功能模块优先级排序
|
||||
|
||||
### 1. 数据与特征层(P0)
|
||||
#### 1.1 因子计算模块优化
|
||||
- [x] 完善 `compute_factors()` 函数实现:
|
||||
- [x] 添加数据有效性校验机制
|
||||
- [x] 实现异常值检测与处理逻辑
|
||||
- [ ] 增加计算进度显示和日志记录
|
||||
- [ ] 优化因子持久化性能
|
||||
- [ ] 支持增量计算模式
|
||||
|
||||
#### 1.2 DataBroker增强
|
||||
- [x] 开发数据请求失败的自动重试机制
|
||||
- [x] 增加数据源健康状态监控
|
||||
- [ ] 设计数据质量评估指标系统
|
||||
|
||||
#### 1.3 因子库扩展
|
||||
- [x] 扩展动量类因子群
|
||||
- [x] 开发估值类因子群
|
||||
- [x] 设计流动性因子群
|
||||
- [ ] 构建市场情绪因子群
|
||||
- [ ] 开发因子组合和权重优化算法
|
||||
|
||||
#### 1.4 新闻数据源完善
|
||||
- [x] 完成RSS数据获取和解析
|
||||
- [x] 增强情感分析能力
|
||||
- [ ] 改进实体识别准确率
|
||||
- [ ] 实现新闻时效性评分
|
||||
|
||||
### 2. 决策优化(P1)
|
||||
#### 2.1 决策环境增强
|
||||
- [x] 扩展DecisionEnv动作空间:
|
||||
- [x] 支持提示版本选择
|
||||
- [x] 允许调节部门温度
|
||||
- [ ] 优化function调用策略
|
||||
- [x] 增加环境观测维度:
|
||||
- [x] 加入换手率指标
|
||||
- [x] 纳入风险事件统计
|
||||
- [ ] 补充市场情绪指标
|
||||
|
||||
#### 2.2 回测系统完善
|
||||
- [x] 优化成交撮合逻辑:
|
||||
- [x] 统一仓位限制
|
||||
- [x] 考虑换手约束
|
||||
- [x] 加入滑点模拟
|
||||
- [x] 计算交易成本
|
||||
- [x] 完善风险控制:
|
||||
- [x] 实现止损机制
|
||||
- [x] 添加波动率限制
|
||||
- [x] 设置集中度控制
|
||||
|
||||
### 3. LLM协同(P1)
|
||||
- [x] 精简和优化Provider管理
|
||||
- [x] 增强function-calling架构
|
||||
- [x] 完善错误处理和重试策略
|
||||
- [ ] 优化提示工程:
|
||||
- [ ] 设计配置化角色提示
|
||||
- [ ] 优化数据范围控制
|
||||
- [ ] 改进上下文管理
|
||||
|
||||
### 4. UI与监控(P2)
|
||||
#### 4.1 功能增强
|
||||
- [ ] 实现"一键重评估"功能
|
||||
- [ ] 开发多版本实验对比
|
||||
- [x] 添加实时指标面板
|
||||
- [ ] 设计异常日志钻取功能
|
||||
|
||||
#### 4.2 监控增强
|
||||
- [ ] 开发"仅监控不干预"模式
|
||||
- [x] 实现策略实时评估
|
||||
- [x] 添加风险预警功能
|
||||
- [ ] 设计绩效归因分析
|
||||
|
||||
### 5. 测试与部署(P2)
|
||||
- [x] 补充核心路径单元测试
|
||||
- [ ] 建立端到端集成测试
|
||||
- [ ] 完善日志收集机制
|
||||
|
||||
## 二、近期开发重点(2025 Q4)
|
||||
|
||||
1. ✓ 完成因子计算模块的优化和重构
|
||||
2. ✓ 实现基础因子库的扩展
|
||||
3. ✓ 优化DataBroker的数据访问性能
|
||||
4. △ 完善RSS新闻数据源的接入
|
||||
5. ✓ 开始着手决策环境的增强
|
||||
|
||||
## 三、开发原则
|
||||
|
||||
1. 保持简单:每个模块只实现最核心的功能
|
||||
2. 重视可靠性:核心功能必须稳定可靠
|
||||
3. 易于使用:交互界面简单直观
|
||||
4. 容错设计:关键节点预留人工介入的可能
|
||||
|
||||
> 注:此计划将根据实际使用体验持续优化,始终保持简单实用的原则。上次更新:2025-10-05
|
||||
110
tests/test_entity_recognition.py
Normal file
110
tests/test_entity_recognition.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""Test improved entity recognition in RSS processing."""
|
||||
from datetime import datetime, timezone
|
||||
import pytest
|
||||
|
||||
from app.ingest.entity_recognition import CompanyNameMapper, company_mapper
|
||||
from app.ingest.rss import RssItem, StockMention
|
||||
|
||||
def test_company_name_mapper():
    """CompanyNameMapper resolves full names, short names, aliases and raw codes."""
    code = "000001.SZ"
    mapper = CompanyNameMapper()
    mapper.add_company(
        ts_code=code,
        full_name="平安银行股份有限公司",
        short_name="平安银行",
        aliases=["平安", "PAB"],
    )

    def kinds_for(text):
        # Match types under which `code` was found in `text`.
        return {kind for _, found, kind in mapper.find_codes(text) if found == code}

    # A full-name mention may also match via short name / alias; the
    # full-name match must be among them.
    assert "full_name" in kinds_for("平安银行股份有限公司公布2025年业绩")

    # Short-name mention.
    assert "short_name" in kinds_for("平安银行发布公告")

    # Alias mention.
    assert "alias" in kinds_for("PAB发布新产品")

    # A literal code is the only match in this text.
    direct = mapper.find_codes("000001.SZ开盘上涨")
    assert len(direct) == 1
    assert direct[0][1] == code
    assert direct[0][2] == "code"
|
||||
|
||||
def test_rss_item_entity_extraction():
    """extract_entities keeps one best mention per stock and fills ts_codes."""
    # The mapper is a process-wide singleton: reset ALL of its tables so that
    # short names and aliases registered by other tests cannot leak in.
    company_mapper.name_to_code.clear()
    company_mapper.short_names.clear()
    company_mapper.aliases.clear()
    company_mapper.add_company(
        ts_code="000001.SZ",
        full_name="平安银行股份有限公司",
        short_name="平安银行",
        aliases=["平安", "PAB"]
    )

    # Subclass flag makes __post_init__ skip database initialization.
    class TestRssItem(RssItem):
        _skip_db_init = True

    item = TestRssItem(
        id="test_news",
        title="平安银行发布2025年业绩预告",
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary="平安银行股份有限公司(000001.SZ)今日发布2025年业绩预告",
        source="test"
    )

    item.extract_entities()

    # Only the highest-priority match per stock survives.
    matched_types = set(m.match_type for m in item.stock_mentions)
    assert "code" in matched_types or "full_name" in matched_types

    # Exactly one distinct stock code.
    assert len(item.ts_codes) == 1
    assert item.ts_codes[0] == "000001.SZ"

    # At least one mention clears the confidence threshold.
    high_confidence = [m for m in item.stock_mentions if m.confidence > 0.7]
    assert len(high_confidence) >= 1
|
||||
|
||||
def test_rss_item_context_extraction():
    """The context stored on a mention is the sentence surrounding the match."""
    # The mapper is a process-wide singleton: reset ALL of its tables so that
    # short names and aliases registered by other tests cannot leak in.
    company_mapper.name_to_code.clear()
    company_mapper.short_names.clear()
    company_mapper.aliases.clear()
    company_mapper.add_company(
        ts_code="000001.SZ",
        full_name="平安银行股份有限公司",
        short_name="平安银行"
    )

    # Subclass flag makes __post_init__ skip database initialization.
    class TestRssItem(RssItem):
        _skip_db_init = True

    item = TestRssItem(
        id="test_news",
        title="多家银行业绩报告",
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary="在银行业整体向好的背景下,平安银行表现突出,营收增长明显",
        source="test"
    )

    item.extract_entities()

    assert len(item.stock_mentions) > 0
    mention = item.stock_mentions[0]
    # The summary is a single short sentence, so the whole of it is the context.
    assert len(mention.context) <= 70
    assert "平安银行" in mention.context
    assert "银行业" in mention.context    # text before the match
    assert "营收增长" in mention.context  # text after the match
|
||||
83
tests/test_rss_item_industry_keywords.py
Normal file
83
tests/test_rss_item_industry_keywords.py
Normal file
@ -0,0 +1,83 @@
|
||||
"""Test industry and keyword extraction in RSS processing."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.ingest.rss import RssItem
|
||||
|
||||
def test_industry_extraction():
    """Industry tagging picks up the semiconductor sector from the item text."""
    # Subclass flag makes __post_init__ skip database initialization.
    class TestRssItem(RssItem):
        _skip_db_init = True

    fields = dict(
        id="test_news",
        title="某半导体公司推出新一代芯片",
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary="该公司在集成电路领域取得重大突破,新产品将用于5G通信",
        source="test",
    )
    item = TestRssItem(**fields)

    item.extract_industries()

    assert "半导体" in item.industries
    assert len(item.industries) >= 1
|
||||
|
||||
def test_important_keyword_extraction():
    """Keyword extraction yields positive (+), negative (-) and event (#) tags."""
    # Subclass flag makes __post_init__ skip database initialization.
    class TestRssItem(RssItem):
        _skip_db_init = True

    # The text mixes positive, negative and event vocabulary.
    fields = dict(
        id="test_news",
        title="某公司业绩超预期,同时宣布重大收购计划",
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary="营收增长显著,但部分业务亏损,将通过并购扩张",
        source="test",
    )
    item = TestRssItem(**fields)

    item.extract_important_keywords()

    prefixes = {keyword[:1] for keyword in set(item.important_keywords)}
    assert '+' in prefixes  # at least one positive keyword
    assert '-' in prefixes  # at least one negative keyword
    assert '#' in prefixes  # at least one event keyword
|
||||
|
||||
def test_rss_item_full_extraction():
    """extract_entities also fills industries and important keywords."""
    # Subclass flag makes __post_init__ skip database initialization.
    class TestRssItem(RssItem):
        _skip_db_init = True

    fields = dict(
        id="test_news",
        title="半导体行业利好:某公司重大突破",
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary="集成电路领域取得重大进展,业绩超预期,宣布增持计划",
        source="test",
    )
    item = TestRssItem(**fields)

    # One call drives entity, industry and keyword extraction.
    item.extract_entities()

    assert item.industries          # at least the semiconductor industry
    assert item.important_keywords  # some keywords were found

    prefixes = {keyword[:1] for keyword in set(item.important_keywords)}
    assert '+' in prefixes  # positive keyword present
    assert '#' in prefixes  # event keyword present
|
||||
Loading…
Reference in New Issue
Block a user