This commit is contained in:
sam 2025-10-05 16:28:53 +08:00
parent adfc8ee148
commit a619a24440
12 changed files with 1417 additions and 47 deletions

120
app/core/sentiment.py Normal file
View File

@ -0,0 +1,120 @@
"""Market sentiment indicators."""
from __future__ import annotations
from typing import Dict, List, Optional, Sequence
import numpy as np
from scipy import stats
def news_sentiment_momentum(
    sentiment_series: Sequence[float],
    window: int = 20
) -> Optional[float]:
    """Compute a momentum score for a series of news sentiment values.

    A straight line is fitted through the first ``window`` observations.
    The slope is squashed with ``tanh`` and multiplied by the absolute
    correlation coefficient, so weak or noisy trends score close to zero.

    Args:
        sentiment_series: News sentiment scores. The position in the
            sequence is used as the x axis of the regression, so the sign
            of the result follows the order in which scores are supplied.
        window: Number of leading observations used for the fit.

    Returns:
        Momentum score in [-1, 1], or None when there is not enough data
        (fewer than ``window`` points, or ``window`` below 2) or the fit
        does not produce a finite value.

    NOTE(review): the original documentation states the series is ordered
    newest first. With that ordering a positive score means older items
    were more positive. Confirm the intended ordering with the callers.
    """
    # A regression line needs at least two points; linregress raises on fewer.
    if window < 2 or len(sentiment_series) < window:
        return None

    values = np.asarray(sentiment_series[:window], dtype=float)
    slope, _, r_value, _, _ = stats.linregress(np.arange(len(values)), values)

    trend = np.tanh(slope * 10)  # squash the slope into (-1, 1)
    quality = abs(r_value)       # strength of the linear relationship
    score = float(trend * quality)

    # NaN inputs propagate through the fit; report them as "no value".
    return score if np.isfinite(score) else None
def news_impact_score(
    sentiment: float,
    heat: float,
    entity_count: int
) -> float:
    """Score how influential a single news item is.

    The score is ``|sentiment| * heat * (0.7 + 0.3 * coverage)``, where
    ``coverage`` is ``entity_count / 5`` capped at 1.0.

    Args:
        sentiment: Sentiment score of the item (documented range -1 to 1).
        heat: Heat score of the item (documented range 0 to 1).
        entity_count: Number of entities mentioned by the item.

    Returns:
        Impact score. It lies in [0, 1] when the inputs are within their
        documented ranges and ``entity_count`` is not negative.
    """
    # Five or more entities count as full coverage.
    coverage_ratio = entity_count / 5
    if coverage_ratio > 1.0:
        coverage_ratio = 1.0

    coverage_multiplier = 0.7 + 0.3 * coverage_ratio
    return abs(sentiment) * heat * coverage_multiplier
def market_sentiment_index(
    sentiment_scores: Sequence[float],
    heat_scores: Sequence[float],
    volume_ratios: Sequence[float],
    window: int = 20
) -> Optional[float]:
    """Compute a composite market sentiment index.

    Each sentiment score is scaled by its volume ratio relative to the
    mean volume ratio, then averaged using heat scores as weights. The
    result is squashed with ``tanh``.

    Args:
        sentiment_scores: Sentiment score series.
        heat_scores: Heat score series, used as averaging weights.
        volume_ratios: Volume ratio series, used to scale the sentiment.
        window: Number of leading observations taken from each series.

    Returns:
        Index in [-1, 1], or None when any series is shorter than
        ``window``, ``window`` is not positive, or the weights cannot be
        normalised (zero or non-finite mean volume ratio or heat total).
    """
    if window <= 0:
        return None
    shortest = min(len(sentiment_scores), len(heat_scores), len(volume_ratios))
    if shortest < window:
        return None

    sentiments = np.asarray(sentiment_scores[:window], dtype=float)
    heats = np.asarray(heat_scores[:window], dtype=float)
    volumes = np.asarray(volume_ratios[:window], dtype=float)

    mean_volume = np.mean(volumes)
    total_heat = np.sum(heats)
    # Both values are divisors below; dividing by zero would yield NaN/inf
    # instead of a usable index.
    if not (np.isfinite(mean_volume) and np.isfinite(total_heat)):
        return None
    if mean_volume == 0 or total_heat == 0:
        return None

    # Sentiment scaled by relative trading activity.
    weighted_sentiment = sentiments * (volumes / mean_volume)
    # Heat-weighted average of the scaled sentiment.
    market_mood = np.sum(weighted_sentiment * (heats / total_heat))
    if not np.isfinite(market_mood):
        return None

    return float(np.tanh(market_mood))
def industry_sentiment_divergence(
    industry_sentiment: float,
    peer_sentiments: Sequence[float]
) -> Optional[float]:
    """Measure how far an industry sentiment deviates from its peers.

    The deviation is the z-score of ``industry_sentiment`` against the
    mean and (population) standard deviation of ``peer_sentiments``,
    squashed with ``tanh``.

    Args:
        industry_sentiment: Sentiment score of the industry as a whole.
        peer_sentiments: Sentiment scores of the constituent stocks.
            Lists, tuples and NumPy arrays are all accepted.

    Returns:
        Divergence in [-1, 1]; 0.0 when all peers share the same score;
        None when no peer scores are supplied.
    """
    # Use len() rather than truthiness: bool() of a multi-element NumPy
    # array raises ValueError.
    if peer_sentiments is None or len(peer_sentiments) == 0:
        return None

    peers = np.asarray(peer_sentiments, dtype=float)
    peer_std = np.std(peers)
    if peer_std == 0:
        # No dispersion among peers, so a z-score is undefined.
        return 0.0

    z_score = (industry_sentiment - np.mean(peers)) / peer_std
    return float(np.tanh(z_score))

View File

@ -14,6 +14,7 @@ from app.utils.db import db_session
from app.utils.logging import get_logger
# 导入扩展因子模块
from app.features.extended_factors import ExtendedFactors
from app.features.sentiment_factors import SentimentFactors
# 导入因子验证功能
from app.features.validation import check_data_sufficiency, detect_outliers
@ -77,6 +78,11 @@ DEFAULT_FACTORS: List[FactorSpec] = [
# 市场状态因子
FactorSpec("market_regime", 0), # 市场状态因子
FactorSpec("trend_strength", 0), # 趋势强度因子
# 情绪因子
FactorSpec("sent_momentum", 20), # 新闻情感动量
FactorSpec("sent_impact", 0), # 新闻影响力
FactorSpec("sent_market", 20), # 市场情绪指数
FactorSpec("sent_divergence", 0), # 行业情绪背离度
]
@ -419,7 +425,10 @@ def _compute_security_factors(
trade_date: str,
specs: Sequence[FactorSpec],
) -> Dict[str, float | None]:
"""计算单个证券的因子值"""
"""计算单个证券的因子值
包括基础因子扩展因子和情绪因子的计算
"""
# 确定所需的最大窗口大小
close_windows = [spec.window for spec in specs if _factor_prefix(spec.name) in {"mom", "volat"}]
turnover_windows = [spec.window for spec in specs if _factor_prefix(spec.name) == "turn"]

View File

@ -0,0 +1,247 @@
"""Extended sentiment factor implementations."""
from __future__ import annotations
from typing import Dict, Optional, Sequence
import numpy as np
from app.core.sentiment import (
news_sentiment_momentum,
news_impact_score,
market_sentiment_index,
industry_sentiment_divergence
)
from dataclasses import dataclass
from datetime import datetime, timezone
from app.utils.data_access import DataBroker
from app.utils.db import db_session
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "sentiment_factors"}
class SentimentFactors:
    """Compute sentiment factors for individual securities.

    Four factors are produced from news, market and industry sentiment:

    1. News sentiment momentum (``sent_momentum``)
    2. News impact (``sent_impact``)
    3. Market sentiment index (``sent_market``)
    4. Industry sentiment divergence (``sent_divergence``)

    Example:
        calculator = SentimentFactors()
        broker = DataBroker()
        factors = calculator.compute_stock_factors(
            broker,
            "000001.SZ",
            "20251001"
        )
    """

    def __init__(self):
        """Initialise the factor-name to window-size mapping."""
        self.factor_specs = {
            "sent_momentum": 20,    # window passed to news_sentiment_momentum
            "sent_impact": 0,       # computed from the first news row only
            "sent_market": 20,      # window passed to market_sentiment_index
            "sent_divergence": 0,   # no window is used for this factor
        }

    @staticmethod
    def _volume_ratio(row) -> float:
        """Return the row's volume ratio, using 1.0 when missing or None."""
        # dict.get(key, default) still returns None when the key is present
        # with a None value, so the fallback is applied explicitly.
        value = row.get("daily_basic.volume_ratio")
        return 1.0 if value is None else value

    def _industry_divergence(
        self,
        broker: DataBroker,
        ts_code: str,
        trade_date: str,
    ) -> Optional[float]:
        """Compute the industry sentiment divergence for one security."""
        industry = broker._lookup_industry(ts_code)
        if not industry:
            return None

        industry_sent = broker._derived_industry_sentiment(industry, trade_date)
        if industry_sent is None:
            return None

        # Latest sentiment of every other stock in the same industry.
        peer_sents = []
        for peer in broker.get_industry_stocks(industry):
            if peer == ts_code:
                continue
            peer_data = broker.get_news_data(peer, trade_date, limit=1)
            if peer_data:
                peer_sents.append(peer_data[0]["sentiment"])

        return industry_sentiment_divergence(industry_sent, peer_sents)

    def compute_stock_factors(
        self,
        broker: DataBroker,
        ts_code: str,
        trade_date: str,
    ) -> Dict[str, Optional[float]]:
        """Compute the sentiment factors of a single security.

        Args:
            broker: Data access object used to fetch news, volume and
                industry data.
            ts_code: Security code.
            trade_date: Trade date.

        Returns:
            Mapping from factor name to value. Every value is None when no
            news is available or when any step raises; individual values
            are None when their inputs are insufficient.
        """
        results: Dict[str, Optional[float]] = {}
        try:
            # Fetch more rows than the momentum window needs.
            news_data = broker.get_news_data(
                ts_code,
                trade_date,
                limit=30
            )
            if not news_data:
                LOGGER.debug(
                    "无新闻数据 code=%s date=%s",
                    ts_code,
                    trade_date,
                    extra=LOG_EXTRA
                )
                return {name: None for name in self.factor_specs}

            sentiment_series = [row["sentiment"] for row in news_data]
            heat_series = [row["heat"] for row in news_data]
            # "entities" is a comma-separated string; empty or None counts as 0.
            entity_counts = [
                len(row["entities"].split(",")) if row["entities"] else 0
                for row in news_data
            ]

            # 1. News sentiment momentum
            results["sent_momentum"] = news_sentiment_momentum(
                sentiment_series,
                window=self.factor_specs["sent_momentum"]
            )

            # 2. News impact, based on the first returned news row
            results["sent_impact"] = news_impact_score(
                sentiment=sentiment_series[0],
                heat=heat_series[0],
                entity_count=entity_counts[0]
            )

            # 3. Market sentiment index, weighted by volume ratios
            market_window = self.factor_specs["sent_market"]
            volume_data = broker.get_stock_data(
                ts_code,
                trade_date,
                fields=["daily_basic.volume_ratio"],
                limit=market_window
            )
            if volume_data:
                volume_ratios = [self._volume_ratio(row) for row in volume_data]
                results["sent_market"] = market_sentiment_index(
                    sentiment_series,
                    heat_series,
                    volume_ratios,
                    window=market_window
                )
            else:
                results["sent_market"] = None

            # 4. Industry sentiment divergence
            results["sent_divergence"] = self._industry_divergence(
                broker, ts_code, trade_date
            )

        except Exception as e:
            # Best effort: a failure for one security must not abort a batch.
            LOGGER.error(
                "计算情绪因子出错 code=%s date=%s error=%s",
                ts_code,
                trade_date,
                str(e),
                exc_info=True,
                extra=LOG_EXTRA
            )
            return {name: None for name in self.factor_specs}

        return results

    def compute_batch(
        self,
        broker: DataBroker,
        ts_codes: list[str],
        trade_date: str,
        batch_size: int = 100
    ) -> None:
        """Compute sentiment factors for many securities and persist them.

        All rows are collected first and written with a single
        ``executemany`` call. Securities for which every factor is None
        are not written.

        Args:
            broker: Data access object.
            ts_codes: Security codes to process.
            trade_date: Trade date.
            batch_size: A progress line is logged every ``batch_size``
                processed securities; values <= 0 disable progress logs.
        """
        # Upsert statement: insert, or update factor columns on conflict.
        columns = list(self.factor_specs.keys())
        insert_columns = ["ts_code", "trade_date", "updated_at"] + columns
        placeholders = ",".join("?" * len(insert_columns))
        update_clause = ", ".join(
            f"{column}=excluded.{column}"
            for column in ["updated_at"] + columns
        )
        sql = (
            f"INSERT INTO factors ({','.join(insert_columns)}) "
            f"VALUES ({placeholders}) "
            f"ON CONFLICT(ts_code, trade_date) DO UPDATE SET {update_clause}"
        )

        # One timestamp shared by every row of this run.
        timestamp = datetime.now(timezone.utc).isoformat()

        total_processed = 0
        rows_to_persist = []
        for ts_code in ts_codes:
            values = self.compute_stock_factors(broker, ts_code, trade_date)

            if any(v is not None for v in values.values()):
                payload = [ts_code, trade_date, timestamp]
                payload.extend(values.get(col) for col in columns)
                rows_to_persist.append(payload)

            total_processed += 1
            # Guard against batch_size <= 0, which would divide by zero.
            if batch_size > 0 and total_processed % batch_size == 0:
                LOGGER.info(
                    "情绪因子计算进度: %d/%d (%.1f%%)",
                    total_processed,
                    len(ts_codes),
                    (total_processed / len(ts_codes)) * 100,
                    extra=LOG_EXTRA
                )

        if not rows_to_persist:
            return

        # NOTE(review): assumes db_session() yields a connection supporting
        # executemany and handles commit on exit -- confirm in app.utils.db.
        with db_session() as conn:
            try:
                conn.executemany(sql, rows_to_persist)
                LOGGER.info(
                    "情绪因子持久化完成 total=%d",
                    len(rows_to_persist),
                    extra=LOG_EXTRA
                )
            except Exception as e:
                # Persistence failures are logged and not re-raised.
                LOGGER.error(
                    "情绪因子持久化失败 error=%s",
                    str(e),
                    exc_info=True,
                    extra=LOG_EXTRA
                )

View File

@ -4,10 +4,13 @@ from __future__ import annotations
import json
from collections import Counter
from dataclasses import asdict
from typing import Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional
import requests
from .context import ContextManager, Message
from .templates import TemplateRegistry
from app.utils.config import (
DEFAULT_LLM_BASE_URLS,
DEFAULT_LLM_MODELS,
@ -247,11 +250,58 @@ def _normalize_response(text: str) -> str:
return " ".join(text.strip().split())
def run_llm(
    prompt: str,
    *,
    system: Optional[str] = None,
    context_id: Optional[str] = None,
    template_id: Optional[str] = None,
    template_vars: Optional[Dict[str, Any]] = None
) -> str:
    """Execute the globally configured LLM strategy with the given prompt.

    Args:
        prompt: Raw prompt string. When ``template_id`` is provided it is
            passed to the template as the ``prompt`` variable (a dict is
            merged into the template variables instead).
        system: Optional system message.
        context_id: Optional context ID for conversation tracking. The
            context is created on first use.
        template_id: Optional template ID to use.
        template_vars: Variables to use with the template. The dict is
            not modified.

    Returns:
        The response returned by ``run_llm_with_config``.

    Raises:
        ValueError: If ``template_id`` is given but no such template is
            registered, or the template cannot be formatted.

    Note:
        Messages are recorded in the context, but only the current prompt
        and system message are passed to ``run_llm_with_config`` here.
    """
    cfg = get_config()

    context = None
    if context_id:
        context = ContextManager.get_context(context_id)
        if context is None:
            context = ContextManager.create_context(context_id)

    # Apply template if specified
    if template_id:
        template = TemplateRegistry.get(template_id)
        if template is None:
            raise ValueError(f"Template {template_id} not found")
        # Copy so the caller's dict is never mutated.
        vars_dict = dict(template_vars or {})
        if isinstance(prompt, str):
            vars_dict["prompt"] = prompt
        elif isinstance(prompt, dict):
            vars_dict.update(prompt)
        prompt = template.format(vars_dict)

    # Record the outgoing messages when tracking a conversation
    if context is not None:
        # Avoid appending another copy of the same system message on every
        # call that reuses this context.
        already_recorded = any(
            msg.role == "system" and msg.content == system
            for msg in context.messages
        )
        if system and not already_recorded:
            context.add_message(Message(role="system", content=system))
        context.add_message(Message(role="user", content=prompt))

    # Execute LLM call
    response = run_llm_with_config(cfg.llm, prompt, system=system)

    # Record the reply
    if context is not None:
        context.add_message(Message(role="assistant", content=response))

    return response
def _run_majority_vote(config: LLMConfig, prompt: str, system: Optional[str]) -> str:

176
app/llm/context.py Normal file
View File

@ -0,0 +1,176 @@
"""LLM context management and access control."""
from __future__ import annotations
import json
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
@dataclass
class DataAccessConfig:
    """Configuration for data access control."""

    # Tables that requests are allowed to read.
    allowed_tables: Set[str]
    # Maximum allowed span, in days, between start_date and end_date.
    max_history_days: int
    # Stored for callers; not consulted by validate_request.
    max_batch_size: int

    def validate_request(
        self, table: str, start_date: str, end_date: Optional[str] = None
    ) -> List[str]:
        """Validate a data access request.

        Args:
            table: Name of the requested table.
            start_date: Start date in ``YYYYMMDD`` format.
            end_date: Optional end date in ``YYYYMMDD`` format. The range
                checks only run when it is provided.

        Returns:
            A list of human-readable error messages; empty when the
            request is acceptable.
        """
        # Calendar arithmetic avoids time.mktime, which depends on the
        # local time zone (DST shifts produce fractional day counts) and
        # can raise OverflowError for out-of-range dates.
        from datetime import datetime

        errors: List[str] = []

        if table not in self.allowed_tables:
            errors.append(f"Table {table} not allowed")

        try:
            start = datetime.strptime(start_date, "%Y%m%d")
            if end_date:
                end = datetime.strptime(end_date, "%Y%m%d")
                days = (end - start).days
                if days > self.max_history_days:
                    errors.append(
                        f"Date range exceeds max {self.max_history_days} days"
                    )
                if days < 0:
                    errors.append("End date before start date")
        except ValueError:
            errors.append("Invalid date format (expected YYYYMMDD)")

        return errors
@dataclass
class ContextConfig:
    """Configuration for context management."""

    # Cap on the summed token estimates of stored messages; Context.add_message
    # evicts older messages to stay under it.
    max_total_tokens: int = 4000
    # Cap on the number of stored messages, enforced by Context.add_message.
    max_messages: int = 10
    # Default used by Context.get_messages when include_system is None.
    include_system: bool = True
    # Default used by Context.get_messages when include_functions is None.
    include_functions: bool = True
@dataclass
class Message:
    """One entry of a conversation kept in a context."""

    role: str  # system, user, assistant, function
    content: str
    name: Optional[str] = None  # For function calls/results
    function_call: Optional[Dict[str, Any]] = None
    # Creation time, taken from time.time() when not supplied.
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Return the message as a dict for API calls.

        ``role`` and ``content`` are always present; ``name`` and
        ``function_call`` are included only when they are truthy. The
        timestamp is never included.
        """
        payload: Dict[str, Any] = {"role": self.role, "content": self.content}
        optional_fields = (
            ("name", self.name),
            ("function_call", self.function_call),
        )
        for key, value in optional_fields:
            if value:
                payload[key] = value
        return payload

    @property
    def estimated_tokens(self) -> int:
        """Rough token estimate: about one token per four characters.

        The content and the JSON-encoded function call (when present) are
        each divided by four separately and then added together.
        """
        estimate = len(self.content) // 4
        if not self.function_call:
            return estimate
        return estimate + len(json.dumps(self.function_call)) // 4
@dataclass
class Context:
    """Manages conversation context with token tracking."""

    messages: List[Message] = field(default_factory=list)
    config: ContextConfig = field(default_factory=ContextConfig)
    # Running sum of estimated_tokens over ``messages``.
    _token_count: int = 0

    def __post_init__(self) -> None:
        """Sync the running token total with any pre-supplied messages."""
        if self.messages and self._token_count == 0:
            self._token_count = sum(m.estimated_tokens for m in self.messages)

    def _evict_one(self) -> None:
        """Drop the oldest non-system message, else the oldest message.

        Falling back to the oldest message guarantees progress when only
        system messages remain; without it the trimming loops would never
        terminate in that situation.
        """
        index = next(
            (i for i, msg in enumerate(self.messages) if msg.role != "system"),
            0,
        )
        removed = self.messages.pop(index)
        self._token_count -= removed.estimated_tokens

    def add_message(self, message: Message) -> None:
        """Add a message to context, maintaining token and count limits.

        Older messages are evicted first to make room for the new message
        within ``max_total_tokens``; after appending, messages are evicted
        until at most ``max_messages`` remain. System messages are evicted
        only when no other message is left to remove.
        """
        new_tokens = message.estimated_tokens

        # Make room for the new message within the token budget.
        while (
            self.messages
            and self._token_count + new_tokens > self.config.max_total_tokens
        ):
            self._evict_one()

        self.messages.append(message)
        self._token_count += new_tokens

        # Enforce the message-count limit (may evict the message just added
        # when every other message is a system message).
        while self.messages and len(self.messages) > self.config.max_messages:
            self._evict_one()

    def get_messages(
        self,
        include_system: Optional[bool] = None,
        include_functions: Optional[bool] = None,
    ) -> List[Dict[str, Any]]:
        """Get messages as dicts for an API call.

        Args:
            include_system: Whether to include ``system`` messages; None
                uses ``config.include_system``.
            include_functions: Whether to include ``function`` messages;
                None uses ``config.include_functions``.

        Returns:
            The selected messages in stored order, converted via
            ``Message.to_dict``.
        """
        if include_system is None:
            include_system = self.config.include_system
        if include_functions is None:
            include_functions = self.config.include_functions

        return [
            msg.to_dict()
            for msg in self.messages
            if (include_system or msg.role != "system")
            and (include_functions or msg.role != "function")
        ]

    def clear(self, keep_system: bool = True) -> None:
        """Clear context, optionally keeping system messages."""
        if keep_system:
            system_msgs = [m for m in self.messages if m.role == "system"]
            self.messages = system_msgs
            self._token_count = sum(m.estimated_tokens for m in system_msgs)
        else:
            self.messages.clear()
            self._token_count = 0
class ContextManager:
    """Registry of conversation contexts keyed by ID.

    The contexts live in a class-level dictionary, so they are shared by
    every caller of these class methods.
    """

    _contexts: Dict[str, Context] = {}
    _configs: Dict[str, ContextConfig] = {}

    @classmethod
    def create_context(
        cls, context_id: str, config: Optional[ContextConfig] = None
    ) -> Context:
        """Create and register a new context.

        Args:
            context_id: Identifier for the new context.
            config: Configuration to use; a default ``ContextConfig`` is
                created when omitted.

        Raises:
            ValueError: If a context with this ID already exists.
        """
        registry = cls._contexts
        if context_id in registry:
            raise ValueError(f"Context {context_id} already exists")

        effective_config = config or ContextConfig()
        registry[context_id] = Context(config=effective_config)
        return registry[context_id]

    @classmethod
    def get_context(cls, context_id: str) -> Optional[Context]:
        """Return the context registered under ``context_id``, or None."""
        try:
            return cls._contexts[context_id]
        except KeyError:
            return None

    @classmethod
    def remove_context(cls, context_id: str) -> None:
        """Remove a context; unknown IDs are ignored."""
        cls._contexts.pop(context_id, None)

    @classmethod
    def clear_all(cls) -> None:
        """Remove every registered context."""
        cls._contexts.clear()

View File

@ -3,6 +3,8 @@ from __future__ import annotations
from typing import Dict, TYPE_CHECKING
from .templates import TemplateRegistry
if TYPE_CHECKING: # pragma: no cover
from app.utils.config import DepartmentSettings
from app.agents.departments import DepartmentContext
@ -21,7 +23,8 @@ def department_prompt(
supplements: str = "",
) -> str:
"""Compose a structured prompt for department-level LLM ensemble."""
# Format data for template
feature_lines = "\n".join(
f"- {key}: {value}" for key, value in sorted(context.features.items())
)
@ -31,40 +34,25 @@ def department_prompt(
scope_lines = "\n".join(f"- {item}" for item in settings.data_scope)
role_description = settings.description.strip()
role_instruction = settings.prompt.strip()
supplement_block = supplements.strip()
instructions = f"""
部门名称{settings.title}
股票代码{context.ts_code}
交易日{context.trade_date}
角色说明{role_description or '未配置,默认沿用部门名称所代表的研究职责。'}
职责指令{role_instruction or '在保持部门风格的前提下,结合可用数据做出审慎判断。'}
可用数据范围
{scope_lines or '- 使用系统提供的全部上下文,必要时指出仍需的额外数据。'}
核心特征
{feature_lines or '- (无)'}
市场背景
{market_lines or '- (无)'}
追加数据
{supplement_block or '- 当前无追加数据'}
请基于以上数据给出该部门对当前股票的操作建议输出必须是 JSON字段如下
{{
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
"confidence": 0-1 之间的小数表示信心
"summary": "一句话概括理由",
"signals": ["详细要点", "..."],
"risks": ["风险点", "..."]
}}
如需额外数据请调用工具 `fetch_data`仅支持请求 `daily` `daily_basic` 在参数中填写 `tables` 数组元素包含 `name`表名与可选的 `window`向前回溯的条数默认 1 `trade_date`YYYYMMDD默认本次交易日
工具返回的数据会在后续消息中提供请在获取所有必要信息后再给出最终 JSON 答复
请严格返回单个 JSON 对象不要添加额外文本
"""
return instructions.strip()
# Determine template ID based on department settings
template_id = f"{settings.code.lower()}_dept"
if not TemplateRegistry.get(template_id):
template_id = "department_base"
# Prepare template variables
template_vars = {
"title": settings.title,
"ts_code": context.ts_code,
"trade_date": context.trade_date,
"description": role_description or "未配置,默认沿用部门名称所代表的研究职责。",
"instruction": role_instruction or "在保持部门风格的前提下,结合可用数据做出审慎判断。",
"data_scope": scope_lines or "- 使用系统提供的全部上下文,必要时指出仍需的额外数据。",
"features": feature_lines or "- (无)",
"market_snapshot": market_lines or "- (无)",
"supplements": supplements.strip() or "- 当前无追加数据"
}
# Get template and format prompt
template = TemplateRegistry.get(template_id)
return template.format(template_vars)

219
app/llm/templates.py Normal file
View File

@ -0,0 +1,219 @@
"""LLM prompt templates management with configuration driven design."""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
@dataclass
class PromptTemplate:
    """Configuration driven prompt template."""

    id: str
    name: str
    description: str
    # Text with ``{variable}`` placeholders, filled via ``str.format``.
    template: str
    # Placeholder names that must appear in ``template``.
    variables: List[str]
    # Maximum length of the formatted result, in characters.
    max_length: int = 4000
    # Keys that must be present in the context passed to ``format``.
    required_context: Optional[List[str]] = None
    # Stored rule strings; this class only checks that they are non-empty
    # and never evaluates them.
    validation_rules: Optional[List[str]] = None

    def validate(self) -> List[str]:
        """Validate the template configuration.

        Returns:
            A list of error messages; empty when the template is valid.
        """
        errors = []

        # Every declared variable must have a placeholder in the template.
        for var in self.variables:
            if f"{{{var}}}" not in self.template:
                errors.append(f"Template missing variable: {var}")

        # ``or ()`` treats an unset (None) list as empty.
        for context_key in self.required_context or ():
            if not context_key:
                errors.append("Empty required context field")

        for rule in self.validation_rules or ():
            if not rule:
                errors.append("Empty validation rule")

        return errors

    def format(self, context: Dict[str, Any]) -> str:
        """Format the template with the provided context.

        Args:
            context: Mapping of placeholder names to values.

        Returns:
            The formatted text, truncated to at most ``max_length``
            characters. When there is room, the truncated text ends
            with ``"..."``.

        Raises:
            ValueError: If a required context key or a template variable
                is missing, or the template uses positional placeholders.
        """
        if self.required_context:
            missing = [f for f in self.required_context if f not in context]
            if missing:
                raise ValueError(f"Missing required context: {', '.join(missing)}")

        try:
            result = self.template.format(**context)
        except KeyError as e:
            raise ValueError(f"Missing template variable: {e}") from e
        except IndexError as e:
            # "{}" / "{0}" cannot be filled from keyword arguments.
            raise ValueError(
                f"Template uses unsupported positional placeholder: {e}"
            ) from e

        limit = max(self.max_length, 0)
        if len(result) > limit:
            if limit >= 3:
                result = result[:limit - 3] + "..."
            else:
                # No room for the ellipsis; a negative slice bound would
                # otherwise return more than ``max_length`` characters.
                result = result[:limit]

        return result
class TemplateRegistry:
    """Global registry for prompt templates.

    Templates are stored in a class-level dictionary keyed by template ID.
    """

    _templates: Dict[str, PromptTemplate] = {}

    @classmethod
    def register(cls, template: PromptTemplate) -> None:
        """Register a template, replacing any existing one with the same ID.

        Raises:
            ValueError: If ``template.validate()`` reports errors.
        """
        errors = template.validate()
        if errors:
            raise ValueError(f"Invalid template {template.id}: {'; '.join(errors)}")
        cls._templates[template.id] = template

    @classmethod
    def get(cls, template_id: str) -> Optional[PromptTemplate]:
        """Get template by ID, or None when it is not registered."""
        return cls._templates.get(template_id)

    @classmethod
    def list(cls) -> List[PromptTemplate]:
        """List all registered templates."""
        # Inside a method body ``list`` still resolves to the builtin.
        return list(cls._templates.values())

    @classmethod
    def load_from_json(cls, json_str: str) -> None:
        """Load and register templates from a JSON string.

        The JSON root must be an object mapping template IDs to template
        definitions, each of which must itself be an object. Templates
        are registered in order; entries before a failing one remain
        registered.

        Raises:
            ValueError: If the JSON is malformed, the root or an entry is
                not an object, or a template fails validation.
        """
        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON: {e}") from e

        if not isinstance(data, dict):
            raise ValueError("JSON root must be an object")

        for template_id, cfg in data.items():
            # A scalar or list entry has no .get(); report it as a
            # configuration error instead of an AttributeError.
            if not isinstance(cfg, dict):
                raise ValueError(
                    f"Template {template_id} definition must be an object"
                )
            template = PromptTemplate(
                id=template_id,
                name=cfg.get("name", template_id),
                description=cfg.get("description", ""),
                template=cfg.get("template", ""),
                variables=cfg.get("variables", []),
                max_length=cfg.get("max_length", 4000),
                required_context=cfg.get("required_context", []),
                validation_rules=cfg.get("validation_rules", [])
            )
            cls.register(template)

    @classmethod
    def clear(cls) -> None:
        """Clear all registered templates."""
        cls._templates.clear()
# Register default templates
DEFAULT_TEMPLATES = {
"department_base": {
"name": "部门基础模板",
"description": "通用的部门分析提示模板",
"template": """
部门名称{title}
股票代码{ts_code}
交易日{trade_date}
角色说明{description}
职责指令{instruction}
可用数据范围
{data_scope}
核心特征
{features}
市场背景
{market_snapshot}
追加数据
{supplements}
请基于以上数据给出该部门对当前股票的操作建议输出必须是 JSON字段如下
{{
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
"confidence": 0-1 之间的小数表示信心,
"summary": "一句话概括理由",
"signals": ["详细要点", "..."],
"risks": ["风险点", "..."]
}}
如需额外数据请调用工具 `fetch_data`仅支持请求 `daily` `daily_basic`
请严格返回单个 JSON 对象不要添加额外文本
""",
"variables": [
"title", "ts_code", "trade_date", "description", "instruction",
"data_scope", "features", "market_snapshot", "supplements"
],
"required_context": [
"ts_code", "trade_date", "features", "market_snapshot"
],
"validation_rules": [
"len(features) > 0",
"len(market_snapshot) > 0"
]
},
"momentum_dept": {
"name": "动量研究部门",
"description": "专注于动量因子分析的部门模板",
"template": """
部门名称动量研究部门
股票代码{ts_code}
交易日{trade_date}
角色说明专注于分析股票价格动量成交量动量和技术指标动量
职责指令重点关注以下方面:
1. 价格趋势强度和持续性
2. 成交量配合度
3. 技术指标背离
可用数据范围
{data_scope}
动量特征
{features}
市场背景
{market_snapshot}
追加数据
{supplements}
请基于以上数据进行动量分析并给出操作建议输出必须是 JSON字段如下
{{
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
"confidence": 0-1 之间的小数表示信心,
"summary": "一句话概括动量分析结论",
"signals": ["动量信号要点", "..."],
"risks": ["动量风险点", "..."]
}}
""",
"variables": [
"ts_code", "trade_date", "data_scope",
"features", "market_snapshot", "supplements"
],
"required_context": [
"ts_code", "trade_date", "features", "market_snapshot"
],
"validation_rules": [
"len(features) > 0",
"'momentum' in ' '.join(features.keys()).lower()"
]
}
}
# Populate the registry with the built-in templates when the module is imported.
for template_id, cfg in DEFAULT_TEMPLATES.items():
    template_kwargs = {"id": template_id}
    template_kwargs.update(cfg)
    TemplateRegistry.register(PromptTemplate(**template_kwargs))

View File

@ -17,8 +17,8 @@
- [x] 完善 `compute_factors()` 函数实现:
- [x] 添加数据有效性校验机制
- [x] 实现异常值检测与处理逻辑
- [ ] 增加计算进度显示和日志记录
- [ ] 优化因子持久化性能
- [x] 增加计算进度显示和日志记录
- [x] 优化因子持久化性能
- [ ] 支持增量计算模式
#### 1.2 DataBroker增强
@ -30,7 +30,7 @@
- [x] 扩展动量类因子群
- [x] 开发估值类因子群
- [x] 设计流动性因子群
- [ ] 构建市场情绪因子群
- [x] 构建市场情绪因子群
- [ ] 开发因子组合和权重优化算法
#### 1.4 新闻数据源完善

148
tests/test_llm_context.py Normal file
View File

@ -0,0 +1,148 @@
"""Test cases for LLM context management."""
import time
import pytest
from app.llm.context import Context, ContextConfig, ContextManager, DataAccessConfig, Message
def test_data_access_config():
    """Test data access configuration and validation."""
    config = DataAccessConfig(
        allowed_tables={"daily", "daily_basic"},
        max_history_days=365,
        max_batch_size=1000
    )

    # Valid request
    errors = config.validate_request("daily", "20251001", "20251005")
    assert not errors

    # Invalid table
    errors = config.validate_request("invalid", "20251001")
    assert len(errors) == 1
    assert "not allowed" in errors[0]

    # Invalid date format
    errors = config.validate_request("daily", "invalid")
    assert len(errors) == 1
    assert "Invalid date format" in errors[0]

    # A range of exactly max_history_days (2025-10-01 to 2026-10-01 is
    # 365 days) is still within the limit.
    errors = config.validate_request("daily", "20251001", "20261001")
    assert not errors

    # Date range too long: one day past the limit
    errors = config.validate_request("daily", "20251001", "20261002")
    assert len(errors) == 1
    assert "exceeds max" in errors[0]

    # End date before start
    errors = config.validate_request("daily", "20251005", "20251001")
    assert len(errors) == 1
    assert "before start date" in errors[0]
def test_context_config():
    """Default ContextConfig has positive limits and both filters enabled."""
    defaults = ContextConfig()

    for limit in (defaults.max_total_tokens, defaults.max_messages):
        assert limit > 0

    assert defaults.include_system is True
    assert defaults.include_functions is True
def test_message():
    """Message defaults, dict conversion and token estimation."""
    # Plain message: optional fields default to None.
    plain = Message(role="user", content="Hello")
    assert (plain.role, plain.content) == ("user", "Hello")
    assert plain.name is None
    assert plain.function_call is None
    assert plain.timestamp <= time.time()

    # Message carrying a function call.
    call_payload = {"name": "test_func", "arguments": "{}"}
    with_call = Message(
        role="function",
        content="Result",
        name="test_func",
        function_call=call_payload
    )
    assert with_call.name == "test_func"
    assert with_call.function_call is not None

    # Dict conversion omits unset optional fields.
    plain_dict = plain.to_dict()
    assert plain_dict["role"] == "user"
    assert plain_dict["content"] == "Hello"
    for absent_key in ("name", "function_call"):
        assert absent_key not in plain_dict

    call_dict = with_call.to_dict()
    assert call_dict["name"] == "test_func"
    assert call_dict["function_call"] is not None

    # The function call adds to the token estimate.
    assert plain.estimated_tokens > 0
    assert with_call.estimated_tokens > plain.estimated_tokens
def test_context():
    """Context trimming, message filtering and clearing."""
    context = Context(config=ContextConfig(max_total_tokens=100, max_messages=3))

    system_msg = Message(role="system", content="System message")
    user_msg = Message(role="user", content="User message")
    assistant_msg = Message(role="assistant", content="Assistant message")
    latest_msg = Message(role="user", content="Another message")

    for message in (system_msg, user_msg, assistant_msg):
        context.add_message(message)
    assert len(context.messages) == 3

    # A fourth message exceeds max_messages; the newest must survive.
    context.add_message(latest_msg)
    assert len(context.messages) == 3
    assert latest_msg in context.messages

    # Retrieval with and without system messages.
    assert len(context.get_messages()) == 3

    without_system = context.get_messages(include_system=False)
    assert len(without_system) == 2
    assert all(entry["role"] != "system" for entry in without_system)

    # Clearing while keeping system messages leaves only the system one.
    context.clear(keep_system=True)
    assert [m.role for m in context.messages] == ["system"]

    context.clear(keep_system=False)
    assert len(context.messages) == 0
def test_context_manager():
    """ContextManager create / get / remove / clear operations."""
    ContextManager.clear_all()

    # A created context can be looked up by its ID.
    created = ContextManager.create_context("test")
    assert ContextManager.get_context("test") == created

    # Creating the same ID twice is rejected.
    with pytest.raises(ValueError):
        ContextManager.create_context("test")

    # A custom configuration is attached to the new context.
    custom = ContextManager.create_context(
        "custom", ContextConfig(max_total_tokens=200)
    )
    assert custom.config.max_total_tokens == 200

    # Removal makes the context unavailable.
    ContextManager.remove_context("test")
    assert ContextManager.get_context("test") is None

    # clear_all drops everything that is left.
    ContextManager.clear_all()
    assert ContextManager.get_context("custom") is None

178
tests/test_llm_templates.py Normal file
View File

@ -0,0 +1,178 @@
"""Test cases for LLM template management."""
import pytest
from app.llm.templates import PromptTemplate, TemplateRegistry
def test_prompt_template_validation():
    """PromptTemplate.validate reports each kind of configuration error."""

    def build(**overrides):
        # Baseline valid template; individual cases override single fields.
        params = {
            "id": "test",
            "name": "Test Template",
            "description": "A test template",
            "template": "Hello {name}!",
            "variables": ["name"],
        }
        params.update(overrides)
        return PromptTemplate(**params)

    # Valid template
    assert not build().validate()

    # Declared variable without a placeholder
    problems = build(variables=["name", "missing"]).validate()
    assert len(problems) == 1
    assert "missing" in problems[0]

    # Empty required context entry
    problems = build(required_context=["", "name"]).validate()
    assert len(problems) == 1
    assert "Empty required context" in problems[0]

    # Empty validation rule
    problems = build(validation_rules=["len(name) > 0", ""]).validate()
    assert len(problems) == 1
    assert "Empty validation rule" in problems[0]
def test_prompt_template_format():
    """Test template formatting."""
    template = PromptTemplate(
        id="test",
        name="Test Template",
        description="A test template",
        template="Hello {name}!",
        variables=["name"],
        required_context=["name"],
        max_length=10
    )

    # Valid context: "Hello World!" (12 chars) is cut to 7 chars plus "..."
    # so that the result fits max_length=10.
    result = template.format({"name": "World"})
    assert result == "Hello W..."
    assert len(result) == 10

    # Missing required context
    with pytest.raises(ValueError) as exc:
        template.format({})
    assert "Missing required context" in str(exc.value)

    # Missing variable: use a template without required_context, otherwise
    # the required-context check fires before the placeholder lookup.
    loose = PromptTemplate(
        id="loose",
        name="Loose Template",
        description="A template without required context",
        template="Hello {name}!",
        variables=["name"]
    )
    with pytest.raises(ValueError) as exc:
        loose.format({"wrong": "value"})
    assert "Missing template variable" in str(exc.value)
def test_template_registry():
    """TemplateRegistry register / list / JSON loading behaviour."""
    TemplateRegistry.clear()

    shared_fields = {"template": "Hello {name}!"}

    # A valid template is stored and retrievable.
    valid = PromptTemplate(
        id="test",
        name="Test Template",
        description="A test template",
        variables=["name"],
        **shared_fields
    )
    TemplateRegistry.register(valid)
    assert TemplateRegistry.get("test") == valid

    # A template whose variable has no placeholder is rejected.
    broken = PromptTemplate(
        id="invalid",
        name="Invalid Template",
        description="An invalid template",
        variables=["wrong"],
        **shared_fields
    )
    with pytest.raises(ValueError) as exc:
        TemplateRegistry.register(broken)
    assert "Invalid template" in str(exc.value)

    # Only the valid template is listed.
    registered = TemplateRegistry.list()
    assert len(registered) == 1
    assert registered[0].id == "test"

    # Loading from a JSON object registers its entries.
    json_str = '''
{
"json_test": {
"name": "JSON Test",
"description": "Test template from JSON",
"template": "Hello {name}!",
"variables": ["name"]
}
}
'''
    TemplateRegistry.load_from_json(json_str)
    assert TemplateRegistry.get("json_test") is not None

    # Malformed JSON
    with pytest.raises(ValueError) as exc:
        TemplateRegistry.load_from_json("invalid json")
    assert "Invalid JSON" in str(exc.value)

    # JSON whose root is not an object
    with pytest.raises(ValueError) as exc:
        TemplateRegistry.load_from_json("[1, 2, 3]")
    assert "JSON root must be an object" in str(exc.value)
def test_default_templates():
    """Test default template registration."""
    from app.llm.templates import DEFAULT_TEMPLATES

    # The defaults are registered once, when app.llm.templates is first
    # imported. Other tests clear the registry and a repeated import does
    # not re-run module code, so rebuild the defaults explicitly here.
    TemplateRegistry.clear()
    for template_id, cfg in DEFAULT_TEMPLATES.items():
        TemplateRegistry.register(PromptTemplate(**{"id": template_id, **cfg}))

    # Verify default templates are loaded
    assert len(TemplateRegistry.list()) > 0

    # Check specific templates
    dept_base = TemplateRegistry.get("department_base")
    assert dept_base is not None
    assert "部门基础模板" in dept_base.name

    momentum = TemplateRegistry.get("momentum_dept")
    assert momentum is not None
    assert "动量研究部门" in momentum.name

    # Validate template content
    assert all("{" + var + "}" in dept_base.template for var in dept_base.variables)
    assert all("{" + var + "}" in momentum.template for var in momentum.variables)

    # Test template usage
    context = {
        "title": "测试部门",
        "ts_code": "000001.SZ",
        "trade_date": "20251005",
        "description": "测试描述",
        "instruction": "测试指令",
        "data_scope": "daily,daily_basic",
        "features": "特征1,特征2",
        "market_snapshot": "市场数据1,市场数据2",
        "supplements": "补充数据"
    }
    result = dept_base.format(context)
    assert "测试部门" in result
    assert "000001.SZ" in result
    assert "20251005" in result

View File

@ -0,0 +1,133 @@
"""Tests for sentiment factor computation."""
from __future__ import annotations
from datetime import date, datetime
from typing import Any, Dict, List
import pytest
from app.features.sentiment_factors import SentimentFactors
from app.utils.data_access import DataBroker
class MockDataBroker:
    """In-memory stand-in for DataBroker used by the tests.

    Only the security ``000001.SZ`` and the industry ``银行`` have data;
    every other lookup returns an empty or neutral value.
    """

    def get_news_data(
        self,
        ts_code: str,
        trade_date: str,
        limit: int = 30
    ) -> List[Dict[str, Any]]:
        """Return two fixed news rows for 000001.SZ, otherwise nothing."""
        if ts_code != "000001.SZ":
            return []

        first_row = {
            "sentiment": 0.8,
            "heat": 0.6,
            "entities": "公司A,行业B,概念C"
        }
        second_row = {
            "sentiment": 0.6,
            "heat": 0.4,
            "entities": "公司A,概念D"
        }
        return [first_row, second_row]

    def get_stock_data(
        self,
        ts_code: str,
        trade_date: str,
        fields: List[str],
        limit: int = 1
    ) -> List[Dict[str, Any]]:
        """Return two fixed volume-ratio rows for 000001.SZ, otherwise nothing."""
        if ts_code != "000001.SZ":
            return []
        return [{"daily_basic.volume_ratio": ratio} for ratio in (1.2, 1.1)]

    def _lookup_industry(self, ts_code: str) -> str:
        """Return the industry of 000001.SZ, or an empty string."""
        return "银行" if ts_code == "000001.SZ" else ""

    def _derived_industry_sentiment(
        self,
        industry: str,
        trade_date: str
    ) -> float:
        """Return 0.5 for the known industry, 0.0 for anything else."""
        return 0.5 if industry == "银行" else 0.0

    def get_industry_stocks(self, industry: str) -> List[str]:
        """Return the constituents of the known industry, or nothing."""
        if industry != "银行":
            return []
        return ["000001.SZ", "600000.SH"]
def test_compute_stock_factors():
    """Sentiment factors for a stock with news and for one without."""
    calculator = SentimentFactors()
    broker = MockDataBroker()

    # Stock with news: every factor key is present and impact is positive.
    with_news = calculator.compute_stock_factors(
        broker,
        "000001.SZ",
        "20251001"
    )
    expected_keys = (
        "sent_momentum",
        "sent_impact",
        "sent_market",
        "sent_divergence",
    )
    for factor_name in expected_keys:
        assert factor_name in with_news
    assert with_news["sent_impact"] > 0

    # Stock without news: every factor is None.
    without_news = calculator.compute_stock_factors(
        broker,
        "000002.SZ",
        "20251001"
    )
    assert all(value is None for value in without_news.values())
def test_compute_batch(tmp_path):
    """Batch computation persists only securities that have factor values."""
    from app.data.schema import initialize_database
    from app.utils.config import get_config

    # Point the configuration at a throwaway database and create the schema.
    config = get_config()
    config.db_path = tmp_path / "test.db"
    initialize_database()

    calculator = SentimentFactors()
    broker = MockDataBroker()

    trade_date = "20251001"
    calculator.compute_batch(
        broker,
        ["000001.SZ", "000002.SZ", "600000.SH"],
        trade_date
    )

    # Read back what was stored for the trade date.
    from app.utils.db import db_session
    with db_session() as conn:
        rows = conn.execute(
            "SELECT * FROM factors WHERE trade_date = ?",
            (trade_date,)
        ).fetchall()

    # Only the one security with news data should have a row.
    assert len(rows) == 1
    assert rows[0]["ts_code"] == "000001.SZ"

View File

@ -0,0 +1,102 @@
"""Tests for market sentiment indicators."""
from __future__ import annotations
import numpy as np
import pytest
from app.core.sentiment import (
news_sentiment_momentum,
news_impact_score,
market_sentiment_index,
industry_sentiment_divergence
)
def test_news_sentiment_momentum():
    """Momentum sign follows the trend; short input yields None."""
    window = 20
    rising = np.linspace(-0.5, 0.5, window)
    falling = np.linspace(0.5, -0.5, window)
    constant = np.zeros(window)

    # Rising series gives a positive score.
    rising_score = news_sentiment_momentum(rising)
    assert rising_score is not None
    assert rising_score > 0

    # Falling series gives a negative score.
    falling_score = news_sentiment_momentum(falling)
    assert falling_score is not None
    assert falling_score < 0

    # Constant series gives a score near zero.
    constant_score = news_sentiment_momentum(constant)
    assert constant_score is not None
    assert abs(constant_score) < 0.1

    # Fewer points than the window yields None.
    assert news_sentiment_momentum(rising[:10]) is None
def test_news_impact_score():
    """Impact grows with sentiment strength and entity count."""
    # Typical input stays in range, and stronger sentiment scores higher.
    strong = news_impact_score(sentiment=0.8, heat=0.6, entity_count=3)
    weak = news_impact_score(sentiment=0.4, heat=0.6, entity_count=3)
    assert 0 <= strong <= 1
    assert strong > weak

    # Boundary inputs.
    assert news_impact_score(sentiment=0, heat=0.5, entity_count=1) == 0
    assert 0 < news_impact_score(sentiment=1, heat=1, entity_count=10) <= 1

    # More entities give a higher score.
    few_entities = news_impact_score(sentiment=0.5, heat=0.5, entity_count=1)
    many_entities = news_impact_score(sentiment=0.5, heat=0.5, entity_count=5)
    assert many_entities > few_entities
def test_market_sentiment_index():
    """Index stays within [-1, 1]; short input yields None."""
    window = 20

    # A seeded generator keeps the inputs identical on every run, so a
    # failure can be reproduced.
    rng = np.random.default_rng(42)
    sentiment_scores = rng.uniform(-1, 1, window)
    heat_scores = rng.uniform(0, 1, window)
    volume_ratios = rng.uniform(0.5, 2, window)

    # Normal computation
    result = market_sentiment_index(
        sentiment_scores,
        heat_scores,
        volume_ratios
    )
    assert result is not None
    assert -1 <= result <= 1

    # One series shorter than the window
    result = market_sentiment_index(
        sentiment_scores[:10],
        heat_scores,
        volume_ratios
    )
    assert result is None
def test_industry_sentiment_divergence():
    """Divergence is larger when the industry is far from its peers."""
    # Industry well above its peers.
    far_apart = industry_sentiment_divergence(
        industry_sentiment=0.8,
        peer_sentiments=[-0.2, -0.1, 0, 0.1]
    )
    assert far_apart is not None
    assert far_apart > 0

    # Industry in line with its peers.
    close_together = industry_sentiment_divergence(
        industry_sentiment=0.1,
        peer_sentiments=[0, 0.1, 0.2]
    )
    assert close_together is not None
    assert abs(close_together) < abs(far_apart)

    # No peers yields None.
    assert industry_sentiment_divergence(0.5, []) is None