update
This commit is contained in:
parent
adfc8ee148
commit
a619a24440
120
app/core/sentiment.py
Normal file
120
app/core/sentiment.py
Normal file
@ -0,0 +1,120 @@
|
||||
"""Market sentiment indicators."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
|
||||
def news_sentiment_momentum(
|
||||
sentiment_series: Sequence[float],
|
||||
window: int = 20
|
||||
) -> Optional[float]:
|
||||
"""计算新闻情感动量指标
|
||||
|
||||
Args:
|
||||
sentiment_series: 新闻情感得分序列,从新到旧排序
|
||||
window: 计算窗口
|
||||
|
||||
Returns:
|
||||
情感动量得分 (-1到1),或 None(数据不足时)
|
||||
"""
|
||||
if len(sentiment_series) < window:
|
||||
return None
|
||||
|
||||
# 计算情感趋势
|
||||
sentiment_series = np.array(sentiment_series[:window])
|
||||
slope, _, r_value, _, _ = stats.linregress(
|
||||
np.arange(len(sentiment_series)),
|
||||
sentiment_series
|
||||
)
|
||||
|
||||
# 结合斜率和拟合度
|
||||
trend = np.tanh(slope * 10) # 归一化斜率
|
||||
quality = abs(r_value) # 趋势显著性
|
||||
|
||||
return float(trend * quality)
|
||||
|
||||
def news_impact_score(
|
||||
sentiment: float,
|
||||
heat: float,
|
||||
entity_count: int
|
||||
) -> float:
|
||||
"""计算新闻影响力得分
|
||||
|
||||
Args:
|
||||
sentiment: 情感得分 (-1到1)
|
||||
heat: 热度得分 (0到1)
|
||||
entity_count: 涉及实体数量
|
||||
|
||||
Returns:
|
||||
影响力得分 (0到1)
|
||||
"""
|
||||
# 新闻影响力 = 情感强度 * 热度 * 实体覆盖度
|
||||
sentiment_strength = abs(sentiment)
|
||||
entity_coverage = min(entity_count / 5, 1.0) # 标准化实体数量
|
||||
|
||||
return sentiment_strength * heat * (0.7 + 0.3 * entity_coverage)
|
||||
|
||||
def market_sentiment_index(
|
||||
sentiment_scores: Sequence[float],
|
||||
heat_scores: Sequence[float],
|
||||
volume_ratios: Sequence[float],
|
||||
window: int = 20
|
||||
) -> Optional[float]:
|
||||
"""计算综合市场情绪指数
|
||||
|
||||
Args:
|
||||
sentiment_scores: 个股情感得分序列
|
||||
heat_scores: 个股热度得分序列
|
||||
volume_ratios: 个股成交量比序列
|
||||
window: 计算窗口
|
||||
|
||||
Returns:
|
||||
市场情绪指数 (-1到1),或 None(数据不足时)
|
||||
"""
|
||||
if len(sentiment_scores) < window or \
|
||||
len(heat_scores) < window or \
|
||||
len(volume_ratios) < window:
|
||||
return None
|
||||
|
||||
# 截取窗口数据
|
||||
sentiment_scores = np.array(sentiment_scores[:window])
|
||||
heat_scores = np.array(heat_scores[:window])
|
||||
volume_ratios = np.array(volume_ratios[:window])
|
||||
|
||||
# 计算带量化权重的情感得分
|
||||
volume_weights = volume_ratios / np.mean(volume_ratios)
|
||||
weighted_sentiment = sentiment_scores * volume_weights
|
||||
|
||||
# 计算热度加权平均
|
||||
heat_weights = heat_scores / np.sum(heat_scores)
|
||||
market_mood = np.sum(weighted_sentiment * heat_weights)
|
||||
|
||||
return float(np.tanh(market_mood)) # 压缩到[-1,1]区间
|
||||
|
||||
def industry_sentiment_divergence(
|
||||
industry_sentiment: float,
|
||||
peer_sentiments: Sequence[float]
|
||||
) -> Optional[float]:
|
||||
"""计算行业情绪背离度
|
||||
|
||||
Args:
|
||||
industry_sentiment: 行业整体情感得分
|
||||
peer_sentiments: 成分股情感得分序列
|
||||
|
||||
Returns:
|
||||
情绪背离度 (-1到1),或 None(数据不足时)
|
||||
"""
|
||||
if not peer_sentiments:
|
||||
return None
|
||||
|
||||
peer_sentiments = np.array(peer_sentiments)
|
||||
peer_mean = np.mean(peer_sentiments)
|
||||
peer_std = np.std(peer_sentiments)
|
||||
|
||||
if peer_std == 0:
|
||||
return 0.0
|
||||
|
||||
# 计算Z分数衡量背离程度
|
||||
z_score = (industry_sentiment - peer_mean) / peer_std
|
||||
return float(np.tanh(z_score)) # 压缩到[-1,1]区间
|
||||
@ -14,6 +14,7 @@ from app.utils.db import db_session
|
||||
from app.utils.logging import get_logger
|
||||
# 导入扩展因子模块
|
||||
from app.features.extended_factors import ExtendedFactors
|
||||
from app.features.sentiment_factors import SentimentFactors
|
||||
# 导入因子验证功能
|
||||
from app.features.validation import check_data_sufficiency, detect_outliers
|
||||
|
||||
@ -77,6 +78,11 @@ DEFAULT_FACTORS: List[FactorSpec] = [
|
||||
# 市场状态因子
|
||||
FactorSpec("market_regime", 0), # 市场状态因子
|
||||
FactorSpec("trend_strength", 0), # 趋势强度因子
|
||||
# 情绪因子
|
||||
FactorSpec("sent_momentum", 20), # 新闻情感动量
|
||||
FactorSpec("sent_impact", 0), # 新闻影响力
|
||||
FactorSpec("sent_market", 20), # 市场情绪指数
|
||||
FactorSpec("sent_divergence", 0), # 行业情绪背离度
|
||||
]
|
||||
|
||||
|
||||
@ -419,7 +425,10 @@ def _compute_security_factors(
|
||||
trade_date: str,
|
||||
specs: Sequence[FactorSpec],
|
||||
) -> Dict[str, float | None]:
|
||||
"""计算单个证券的因子值"""
|
||||
"""计算单个证券的因子值
|
||||
|
||||
包括基础因子、扩展因子和情绪因子的计算。
|
||||
"""
|
||||
# 确定所需的最大窗口大小
|
||||
close_windows = [spec.window for spec in specs if _factor_prefix(spec.name) in {"mom", "volat"}]
|
||||
turnover_windows = [spec.window for spec in specs if _factor_prefix(spec.name) == "turn"]
|
||||
|
||||
247
app/features/sentiment_factors.py
Normal file
247
app/features/sentiment_factors.py
Normal file
@ -0,0 +1,247 @@
|
||||
"""Extended sentiment factor implementations."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional, Sequence
|
||||
import numpy as np
|
||||
|
||||
from app.core.sentiment import (
|
||||
news_sentiment_momentum,
|
||||
news_impact_score,
|
||||
market_sentiment_index,
|
||||
industry_sentiment_divergence
|
||||
)
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from app.utils.data_access import DataBroker
|
||||
from app.utils.db import db_session
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
LOG_EXTRA = {"stage": "sentiment_factors"}
|
||||
|
||||
|
||||
class SentimentFactors:
|
||||
"""情绪因子计算实现类。
|
||||
|
||||
实现了一组基于新闻、市场和行业情绪的因子:
|
||||
1. 新闻情感动量 (sent_momentum)
|
||||
2. 新闻影响力 (sent_impact)
|
||||
3. 市场情绪指数 (sent_market)
|
||||
4. 行业情绪背离度 (sent_divergence)
|
||||
|
||||
使用示例:
|
||||
calculator = SentimentFactors()
|
||||
broker = DataBroker()
|
||||
|
||||
factors = calculator.compute_stock_factors(
|
||||
broker,
|
||||
"000001.SZ",
|
||||
"20251001"
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化情绪因子计算器"""
|
||||
self.factor_specs = {
|
||||
"sent_momentum": 20, # 情感动量窗口
|
||||
"sent_impact": 0, # 新闻影响力
|
||||
"sent_market": 20, # 市场情绪窗口
|
||||
"sent_divergence": 0, # 情绪背离度
|
||||
}
|
||||
|
||||
def compute_stock_factors(
|
||||
self,
|
||||
broker: DataBroker,
|
||||
ts_code: str,
|
||||
trade_date: str,
|
||||
) -> Dict[str, Optional[float]]:
|
||||
"""计算单个股票的情绪因子
|
||||
|
||||
Args:
|
||||
broker: 数据访问器
|
||||
ts_code: 股票代码
|
||||
trade_date: 交易日期
|
||||
|
||||
Returns:
|
||||
因子名称到因子值的映射字典
|
||||
"""
|
||||
results = {}
|
||||
|
||||
try:
|
||||
# 获取历史新闻数据
|
||||
news_data = broker.get_news_data(
|
||||
ts_code,
|
||||
trade_date,
|
||||
limit=30 # 保留足够历史以计算动量
|
||||
)
|
||||
|
||||
if not news_data:
|
||||
LOGGER.debug(
|
||||
"无新闻数据 code=%s date=%s",
|
||||
ts_code,
|
||||
trade_date,
|
||||
extra=LOG_EXTRA
|
||||
)
|
||||
return {name: None for name in self.factor_specs}
|
||||
|
||||
# 提取序列数据
|
||||
sentiment_series = [row["sentiment"] for row in news_data]
|
||||
heat_series = [row["heat"] for row in news_data]
|
||||
entity_counts = [
|
||||
len(row["entities"].split(",")) if row["entities"] else 0
|
||||
for row in news_data
|
||||
]
|
||||
|
||||
# 1. 计算新闻情感动量
|
||||
results["sent_momentum"] = news_sentiment_momentum(
|
||||
sentiment_series,
|
||||
window=self.factor_specs["sent_momentum"]
|
||||
)
|
||||
|
||||
# 2. 计算新闻影响力
|
||||
# 使用最新一条新闻的数据
|
||||
results["sent_impact"] = news_impact_score(
|
||||
sentiment=sentiment_series[0],
|
||||
heat=heat_series[0],
|
||||
entity_count=entity_counts[0]
|
||||
)
|
||||
|
||||
# 3. 计算市场情绪指数
|
||||
# 获取成交量数据
|
||||
volume_data = broker.get_stock_data(
|
||||
ts_code,
|
||||
trade_date,
|
||||
fields=["daily_basic.volume_ratio"],
|
||||
limit=self.factor_specs["sent_market"]
|
||||
)
|
||||
if volume_data:
|
||||
volume_ratios = [
|
||||
row.get("daily_basic.volume_ratio", 1.0)
|
||||
for row in volume_data
|
||||
]
|
||||
results["sent_market"] = market_sentiment_index(
|
||||
sentiment_series,
|
||||
heat_series,
|
||||
volume_ratios,
|
||||
window=self.factor_specs["sent_market"]
|
||||
)
|
||||
else:
|
||||
results["sent_market"] = None
|
||||
|
||||
# 4. 计算行业情绪背离度
|
||||
industry = broker._lookup_industry(ts_code)
|
||||
if industry:
|
||||
industry_sent = broker._derived_industry_sentiment(
|
||||
industry,
|
||||
trade_date
|
||||
)
|
||||
if industry_sent is not None:
|
||||
# 获取同行业股票的情感得分
|
||||
peer_sents = []
|
||||
for peer in broker.get_industry_stocks(industry):
|
||||
if peer != ts_code:
|
||||
peer_data = broker.get_news_data(
|
||||
peer,
|
||||
trade_date,
|
||||
limit=1
|
||||
)
|
||||
if peer_data:
|
||||
peer_sents.append(peer_data[0]["sentiment"])
|
||||
|
||||
results["sent_divergence"] = industry_sentiment_divergence(
|
||||
industry_sent,
|
||||
peer_sents
|
||||
)
|
||||
else:
|
||||
results["sent_divergence"] = None
|
||||
else:
|
||||
results["sent_divergence"] = None
|
||||
|
||||
except Exception as e:
|
||||
LOGGER.error(
|
||||
"计算情绪因子出错 code=%s date=%s error=%s",
|
||||
ts_code,
|
||||
trade_date,
|
||||
str(e),
|
||||
exc_info=True,
|
||||
extra=LOG_EXTRA
|
||||
)
|
||||
return {name: None for name in self.factor_specs}
|
||||
|
||||
return results
|
||||
|
||||
def compute_batch(
|
||||
self,
|
||||
broker: DataBroker,
|
||||
ts_codes: list[str],
|
||||
trade_date: str,
|
||||
batch_size: int = 100
|
||||
) -> None:
|
||||
"""批量计算多个股票的情绪因子并保存
|
||||
|
||||
Args:
|
||||
broker: 数据访问器
|
||||
ts_codes: 股票代码列表
|
||||
trade_date: 交易日期
|
||||
batch_size: 批处理大小
|
||||
"""
|
||||
# 准备SQL语句
|
||||
columns = list(self.factor_specs.keys())
|
||||
insert_columns = ["ts_code", "trade_date", "updated_at"] + columns
|
||||
|
||||
placeholders = ",".join("?" * len(insert_columns))
|
||||
update_clause = ", ".join(
|
||||
f"{column}=excluded.{column}"
|
||||
for column in ["updated_at"] + columns
|
||||
)
|
||||
|
||||
sql = (
|
||||
f"INSERT INTO factors ({','.join(insert_columns)}) "
|
||||
f"VALUES ({placeholders}) "
|
||||
f"ON CONFLICT(ts_code, trade_date) DO UPDATE SET {update_clause}"
|
||||
)
|
||||
|
||||
# 获取当前时间戳
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# 分批处理
|
||||
total_processed = 0
|
||||
rows_to_persist = []
|
||||
|
||||
for ts_code in ts_codes:
|
||||
# 计算因子
|
||||
values = self.compute_stock_factors(broker, ts_code, trade_date)
|
||||
|
||||
# 准备数据
|
||||
if any(v is not None for v in values.values()):
|
||||
payload = [ts_code, trade_date, timestamp]
|
||||
payload.extend(values.get(col) for col in columns)
|
||||
rows_to_persist.append(payload)
|
||||
|
||||
total_processed += 1
|
||||
if total_processed % batch_size == 0:
|
||||
LOGGER.info(
|
||||
"情绪因子计算进度: %d/%d (%.1f%%)",
|
||||
total_processed,
|
||||
len(ts_codes),
|
||||
(total_processed / len(ts_codes)) * 100,
|
||||
extra=LOG_EXTRA
|
||||
)
|
||||
|
||||
# 执行批量写入
|
||||
if rows_to_persist:
|
||||
with db_session() as conn:
|
||||
try:
|
||||
conn.executemany(sql, rows_to_persist)
|
||||
LOGGER.info(
|
||||
"情绪因子持久化完成 total=%d",
|
||||
len(rows_to_persist),
|
||||
extra=LOG_EXTRA
|
||||
)
|
||||
except Exception as e:
|
||||
LOGGER.error(
|
||||
"情绪因子持久化失败 error=%s",
|
||||
str(e),
|
||||
exc_info=True,
|
||||
extra=LOG_EXTRA
|
||||
)
|
||||
@ -4,10 +4,13 @@ from __future__ import annotations
|
||||
import json
|
||||
from collections import Counter
|
||||
from dataclasses import asdict
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from .context import ContextManager, Message
|
||||
from .templates import TemplateRegistry
|
||||
|
||||
from app.utils.config import (
|
||||
DEFAULT_LLM_BASE_URLS,
|
||||
DEFAULT_LLM_MODELS,
|
||||
@ -247,11 +250,58 @@ def _normalize_response(text: str) -> str:
|
||||
return " ".join(text.strip().split())
|
||||
|
||||
|
||||
def run_llm(prompt: str, *, system: Optional[str] = None) -> str:
|
||||
"""Execute the globally configured LLM strategy with the given prompt."""
|
||||
def run_llm(
|
||||
prompt: str,
|
||||
*,
|
||||
system: Optional[str] = None,
|
||||
context_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
template_vars: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""Execute the globally configured LLM strategy with the given prompt.
|
||||
|
||||
Args:
|
||||
prompt: Raw prompt string or template variable if template_id is provided
|
||||
system: Optional system message
|
||||
context_id: Optional context ID for conversation tracking
|
||||
template_id: Optional template ID to use
|
||||
template_vars: Variables to use with the template
|
||||
"""
|
||||
# Get config and prepare context
|
||||
cfg = get_config()
|
||||
if context_id:
|
||||
context = ContextManager.get_context(context_id)
|
||||
if not context:
|
||||
context = ContextManager.create_context(context_id)
|
||||
else:
|
||||
context = None
|
||||
|
||||
settings = get_config().llm
|
||||
return run_llm_with_config(settings, prompt, system=system)
|
||||
# Apply template if specified
|
||||
if template_id:
|
||||
template = TemplateRegistry.get(template_id)
|
||||
if not template:
|
||||
raise ValueError(f"Template {template_id} not found")
|
||||
vars_dict = template_vars or {}
|
||||
if isinstance(prompt, str):
|
||||
vars_dict["prompt"] = prompt
|
||||
elif isinstance(prompt, dict):
|
||||
vars_dict.update(prompt)
|
||||
prompt = template.format(vars_dict)
|
||||
|
||||
# Add to context if tracking
|
||||
if context:
|
||||
if system:
|
||||
context.add_message(Message(role="system", content=system))
|
||||
context.add_message(Message(role="user", content=prompt))
|
||||
|
||||
# Execute LLM call
|
||||
response = run_llm_with_config(cfg.llm, prompt, system=system)
|
||||
|
||||
# Update context with response
|
||||
if context:
|
||||
context.add_message(Message(role="assistant", content=response))
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def _run_majority_vote(config: LLMConfig, prompt: str, system: Optional[str]) -> str:
|
||||
|
||||
176
app/llm/context.py
Normal file
176
app/llm/context.py
Normal file
@ -0,0 +1,176 @@
|
||||
"""LLM context management and access control."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataAccessConfig:
|
||||
"""Configuration for data access control."""
|
||||
|
||||
allowed_tables: Set[str]
|
||||
max_history_days: int
|
||||
max_batch_size: int
|
||||
|
||||
def validate_request(
|
||||
self, table: str, start_date: str, end_date: Optional[str] = None
|
||||
) -> List[str]:
|
||||
"""Validate a data access request."""
|
||||
errors = []
|
||||
|
||||
if table not in self.allowed_tables:
|
||||
errors.append(f"Table {table} not allowed")
|
||||
|
||||
try:
|
||||
start_ts = time.strptime(start_date, "%Y%m%d")
|
||||
if end_date:
|
||||
end_ts = time.strptime(end_date, "%Y%m%d")
|
||||
days = (time.mktime(end_ts) - time.mktime(start_ts)) / (24 * 3600)
|
||||
if days > self.max_history_days:
|
||||
errors.append(
|
||||
f"Date range exceeds max {self.max_history_days} days"
|
||||
)
|
||||
if days < 0:
|
||||
errors.append("End date before start date")
|
||||
except ValueError:
|
||||
errors.append("Invalid date format (expected YYYYMMDD)")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextConfig:
|
||||
"""Configuration for context management."""
|
||||
|
||||
max_total_tokens: int = 4000
|
||||
max_messages: int = 10
|
||||
include_system: bool = True
|
||||
include_functions: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""A message in the conversation context."""
|
||||
|
||||
role: str # system, user, assistant, function
|
||||
content: str
|
||||
name: Optional[str] = None # For function calls/results
|
||||
function_call: Optional[Dict[str, Any]] = None
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dict format for API calls."""
|
||||
msg = {"role": self.role, "content": self.content}
|
||||
if self.name:
|
||||
msg["name"] = self.name
|
||||
if self.function_call:
|
||||
msg["function_call"] = self.function_call
|
||||
return msg
|
||||
|
||||
@property
|
||||
def estimated_tokens(self) -> int:
|
||||
"""Rough estimate of tokens in message."""
|
||||
# Very rough estimate: 1 token ≈ 4 chars
|
||||
base = len(self.content) // 4
|
||||
if self.function_call:
|
||||
base += len(json.dumps(self.function_call)) // 4
|
||||
return base
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
"""Manages conversation context with token tracking."""
|
||||
|
||||
messages: List[Message] = field(default_factory=list)
|
||||
config: ContextConfig = field(default_factory=ContextConfig)
|
||||
_token_count: int = 0
|
||||
|
||||
def add_message(self, message: Message) -> None:
|
||||
"""Add a message to context, maintaining token limit."""
|
||||
# Update token count
|
||||
new_tokens = message.estimated_tokens
|
||||
while (
|
||||
self._token_count + new_tokens > self.config.max_total_tokens
|
||||
and self.messages
|
||||
):
|
||||
# Remove oldest non-system message if needed
|
||||
for i, msg in enumerate(self.messages):
|
||||
if msg.role != "system" or len(self.messages) <= 1:
|
||||
removed = self.messages.pop(i)
|
||||
self._token_count -= removed.estimated_tokens
|
||||
break
|
||||
|
||||
# Add new message
|
||||
self.messages.append(message)
|
||||
self._token_count += new_tokens
|
||||
|
||||
# Trim to max messages if needed
|
||||
while len(self.messages) > self.config.max_messages:
|
||||
for i, msg in enumerate(self.messages):
|
||||
if msg.role != "system" or len(self.messages) <= 1:
|
||||
removed = self.messages.pop(i)
|
||||
self._token_count -= removed.estimated_tokens
|
||||
break
|
||||
|
||||
def get_messages(
|
||||
self, include_system: bool = None, include_functions: bool = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get messages for API call."""
|
||||
if include_system is None:
|
||||
include_system = self.config.include_system
|
||||
if include_functions is None:
|
||||
include_functions = self.config.include_functions
|
||||
|
||||
return [
|
||||
msg.to_dict()
|
||||
for msg in self.messages
|
||||
if (include_system or msg.role != "system")
|
||||
and (include_functions or msg.role != "function")
|
||||
]
|
||||
|
||||
def clear(self, keep_system: bool = True) -> None:
|
||||
"""Clear context, optionally keeping system messages."""
|
||||
if keep_system:
|
||||
system_msgs = [m for m in self.messages if m.role == "system"]
|
||||
self.messages = system_msgs
|
||||
self._token_count = sum(m.estimated_tokens for m in system_msgs)
|
||||
else:
|
||||
self.messages.clear()
|
||||
self._token_count = 0
|
||||
|
||||
|
||||
class ContextManager:
|
||||
"""Global manager for conversation contexts."""
|
||||
|
||||
_contexts: Dict[str, Context] = {}
|
||||
_configs: Dict[str, ContextConfig] = {}
|
||||
|
||||
@classmethod
|
||||
def create_context(
|
||||
cls, context_id: str, config: Optional[ContextConfig] = None
|
||||
) -> Context:
|
||||
"""Create a new context."""
|
||||
if context_id in cls._contexts:
|
||||
raise ValueError(f"Context {context_id} already exists")
|
||||
context = Context(config=config or ContextConfig())
|
||||
cls._contexts[context_id] = context
|
||||
return context
|
||||
|
||||
@classmethod
|
||||
def get_context(cls, context_id: str) -> Optional[Context]:
|
||||
"""Get existing context."""
|
||||
return cls._contexts.get(context_id)
|
||||
|
||||
@classmethod
|
||||
def remove_context(cls, context_id: str) -> None:
|
||||
"""Remove a context."""
|
||||
if context_id in cls._contexts:
|
||||
del cls._contexts[context_id]
|
||||
|
||||
@classmethod
|
||||
def clear_all(cls) -> None:
|
||||
"""Clear all contexts."""
|
||||
cls._contexts.clear()
|
||||
@ -3,6 +3,8 @@ from __future__ import annotations
|
||||
|
||||
from typing import Dict, TYPE_CHECKING
|
||||
|
||||
from .templates import TemplateRegistry
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from app.utils.config import DepartmentSettings
|
||||
from app.agents.departments import DepartmentContext
|
||||
@ -21,7 +23,8 @@ def department_prompt(
|
||||
supplements: str = "",
|
||||
) -> str:
|
||||
"""Compose a structured prompt for department-level LLM ensemble."""
|
||||
|
||||
|
||||
# Format data for template
|
||||
feature_lines = "\n".join(
|
||||
f"- {key}: {value}" for key, value in sorted(context.features.items())
|
||||
)
|
||||
@ -31,40 +34,25 @@ def department_prompt(
|
||||
scope_lines = "\n".join(f"- {item}" for item in settings.data_scope)
|
||||
role_description = settings.description.strip()
|
||||
role_instruction = settings.prompt.strip()
|
||||
supplement_block = supplements.strip()
|
||||
|
||||
instructions = f"""
|
||||
部门名称:{settings.title}
|
||||
股票代码:{context.ts_code}
|
||||
交易日:{context.trade_date}
|
||||
|
||||
角色说明:{role_description or '未配置,默认沿用部门名称所代表的研究职责。'}
|
||||
职责指令:{role_instruction or '在保持部门风格的前提下,结合可用数据做出审慎判断。'}
|
||||
|
||||
【可用数据范围】
|
||||
{scope_lines or '- 使用系统提供的全部上下文,必要时指出仍需的额外数据。'}
|
||||
|
||||
【核心特征】
|
||||
{feature_lines or '- (无)'}
|
||||
|
||||
【市场背景】
|
||||
{market_lines or '- (无)'}
|
||||
|
||||
【追加数据】
|
||||
{supplement_block or '- 当前无追加数据'}
|
||||
|
||||
请基于以上数据给出该部门对当前股票的操作建议。输出必须是 JSON,字段如下:
|
||||
{{
|
||||
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
|
||||
"confidence": 0-1 之间的小数,表示信心,
|
||||
"summary": "一句话概括理由",
|
||||
"signals": ["详细要点", "..."],
|
||||
"risks": ["风险点", "..."]
|
||||
}}
|
||||
|
||||
如需额外数据,请调用工具 `fetch_data`,仅支持请求 `daily` 或 `daily_basic` 表;在参数中填写 `tables` 数组,元素包含 `name`(表名)与可选的 `window`(向前回溯的条数,默认 1)及 `trade_date`(YYYYMMDD,默认本次交易日)。
|
||||
工具返回的数据会在后续消息中提供,请在获取所有必要信息后再给出最终 JSON 答复。
|
||||
|
||||
请严格返回单个 JSON 对象,不要添加额外文本。
|
||||
"""
|
||||
return instructions.strip()
|
||||
|
||||
# Determine template ID based on department settings
|
||||
template_id = f"{settings.code.lower()}_dept"
|
||||
if not TemplateRegistry.get(template_id):
|
||||
template_id = "department_base"
|
||||
|
||||
# Prepare template variables
|
||||
template_vars = {
|
||||
"title": settings.title,
|
||||
"ts_code": context.ts_code,
|
||||
"trade_date": context.trade_date,
|
||||
"description": role_description or "未配置,默认沿用部门名称所代表的研究职责。",
|
||||
"instruction": role_instruction or "在保持部门风格的前提下,结合可用数据做出审慎判断。",
|
||||
"data_scope": scope_lines or "- 使用系统提供的全部上下文,必要时指出仍需的额外数据。",
|
||||
"features": feature_lines or "- (无)",
|
||||
"market_snapshot": market_lines or "- (无)",
|
||||
"supplements": supplements.strip() or "- 当前无追加数据"
|
||||
}
|
||||
|
||||
# Get template and format prompt
|
||||
template = TemplateRegistry.get(template_id)
|
||||
return template.format(template_vars)
|
||||
|
||||
219
app/llm/templates.py
Normal file
219
app/llm/templates.py
Normal file
@ -0,0 +1,219 @@
|
||||
"""LLM prompt templates management with configuration driven design."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptTemplate:
|
||||
"""Configuration driven prompt template."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
template: str
|
||||
variables: List[str]
|
||||
max_length: int = 4000
|
||||
required_context: List[str] = None
|
||||
validation_rules: List[str] = None
|
||||
|
||||
def validate(self) -> List[str]:
|
||||
"""Validate template configuration."""
|
||||
errors = []
|
||||
|
||||
# Check template contains all variables
|
||||
for var in self.variables:
|
||||
if f"{{{var}}}" not in self.template:
|
||||
errors.append(f"Template missing variable: {var}")
|
||||
|
||||
# Check required context fields
|
||||
if self.required_context:
|
||||
for field in self.required_context:
|
||||
if not field:
|
||||
errors.append("Empty required context field")
|
||||
|
||||
# Check validation rules format
|
||||
if self.validation_rules:
|
||||
for rule in self.validation_rules:
|
||||
if not rule:
|
||||
errors.append("Empty validation rule")
|
||||
|
||||
return errors
|
||||
|
||||
def format(self, context: Dict[str, Any]) -> str:
|
||||
"""Format template with provided context."""
|
||||
# Validate required context
|
||||
if self.required_context:
|
||||
missing = [f for f in self.required_context if f not in context]
|
||||
if missing:
|
||||
raise ValueError(f"Missing required context: {', '.join(missing)}")
|
||||
|
||||
# Format template
|
||||
try:
|
||||
result = self.template.format(**context)
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing template variable: {e}")
|
||||
|
||||
# Truncate if needed
|
||||
if len(result) > self.max_length:
|
||||
result = result[:self.max_length-3] + "..."
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class TemplateRegistry:
|
||||
"""Global registry for prompt templates."""
|
||||
|
||||
_templates: Dict[str, PromptTemplate] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, template: PromptTemplate) -> None:
|
||||
"""Register a new template."""
|
||||
errors = template.validate()
|
||||
if errors:
|
||||
raise ValueError(f"Invalid template {template.id}: {'; '.join(errors)}")
|
||||
cls._templates[template.id] = template
|
||||
|
||||
@classmethod
|
||||
def get(cls, template_id: str) -> Optional[PromptTemplate]:
|
||||
"""Get template by ID."""
|
||||
return cls._templates.get(template_id)
|
||||
|
||||
@classmethod
|
||||
def list(cls) -> List[PromptTemplate]:
|
||||
"""List all registered templates."""
|
||||
return list(cls._templates.values())
|
||||
|
||||
@classmethod
|
||||
def load_from_json(cls, json_str: str) -> None:
|
||||
"""Load templates from JSON string."""
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON: {e}")
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("JSON root must be an object")
|
||||
|
||||
for template_id, cfg in data.items():
|
||||
template = PromptTemplate(
|
||||
id=template_id,
|
||||
name=cfg.get("name", template_id),
|
||||
description=cfg.get("description", ""),
|
||||
template=cfg.get("template", ""),
|
||||
variables=cfg.get("variables", []),
|
||||
max_length=cfg.get("max_length", 4000),
|
||||
required_context=cfg.get("required_context", []),
|
||||
validation_rules=cfg.get("validation_rules", [])
|
||||
)
|
||||
cls.register(template)
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
"""Clear all registered templates."""
|
||||
cls._templates.clear()
|
||||
|
||||
|
||||
# Register default templates
|
||||
DEFAULT_TEMPLATES = {
|
||||
"department_base": {
|
||||
"name": "部门基础模板",
|
||||
"description": "通用的部门分析提示模板",
|
||||
"template": """
|
||||
部门名称:{title}
|
||||
股票代码:{ts_code}
|
||||
交易日:{trade_date}
|
||||
|
||||
角色说明:{description}
|
||||
职责指令:{instruction}
|
||||
|
||||
【可用数据范围】
|
||||
{data_scope}
|
||||
|
||||
【核心特征】
|
||||
{features}
|
||||
|
||||
【市场背景】
|
||||
{market_snapshot}
|
||||
|
||||
【追加数据】
|
||||
{supplements}
|
||||
|
||||
请基于以上数据给出该部门对当前股票的操作建议。输出必须是 JSON,字段如下:
|
||||
{{
|
||||
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
|
||||
"confidence": 0-1 之间的小数,表示信心,
|
||||
"summary": "一句话概括理由",
|
||||
"signals": ["详细要点", "..."],
|
||||
"risks": ["风险点", "..."]
|
||||
}}
|
||||
|
||||
如需额外数据,请调用工具 `fetch_data`,仅支持请求 `daily` 或 `daily_basic` 表。
|
||||
请严格返回单个 JSON 对象,不要添加额外文本。
|
||||
""",
|
||||
"variables": [
|
||||
"title", "ts_code", "trade_date", "description", "instruction",
|
||||
"data_scope", "features", "market_snapshot", "supplements"
|
||||
],
|
||||
"required_context": [
|
||||
"ts_code", "trade_date", "features", "market_snapshot"
|
||||
],
|
||||
"validation_rules": [
|
||||
"len(features) > 0",
|
||||
"len(market_snapshot) > 0"
|
||||
]
|
||||
},
|
||||
"momentum_dept": {
|
||||
"name": "动量研究部门",
|
||||
"description": "专注于动量因子分析的部门模板",
|
||||
"template": """
|
||||
部门名称:动量研究部门
|
||||
股票代码:{ts_code}
|
||||
交易日:{trade_date}
|
||||
|
||||
角色说明:专注于分析股票价格动量、成交量动量和技术指标动量
|
||||
职责指令:重点关注以下方面:
|
||||
1. 价格趋势强度和持续性
|
||||
2. 成交量配合度
|
||||
3. 技术指标背离
|
||||
|
||||
【可用数据范围】
|
||||
{data_scope}
|
||||
|
||||
【动量特征】
|
||||
{features}
|
||||
|
||||
【市场背景】
|
||||
{market_snapshot}
|
||||
|
||||
【追加数据】
|
||||
{supplements}
|
||||
|
||||
请基于以上数据进行动量分析并给出操作建议。输出必须是 JSON,字段如下:
|
||||
{{
|
||||
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
|
||||
"confidence": 0-1 之间的小数,表示信心,
|
||||
"summary": "一句话概括动量分析结论",
|
||||
"signals": ["动量信号要点", "..."],
|
||||
"risks": ["动量风险点", "..."]
|
||||
}}
|
||||
""",
|
||||
"variables": [
|
||||
"ts_code", "trade_date", "data_scope",
|
||||
"features", "market_snapshot", "supplements"
|
||||
],
|
||||
"required_context": [
|
||||
"ts_code", "trade_date", "features", "market_snapshot"
|
||||
],
|
||||
"validation_rules": [
|
||||
"len(features) > 0",
|
||||
"'momentum' in ' '.join(features.keys()).lower()"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# Register default templates
|
||||
for template_id, cfg in DEFAULT_TEMPLATES.items():
|
||||
TemplateRegistry.register(PromptTemplate(**{"id": template_id, **cfg}))
|
||||
@ -17,8 +17,8 @@
|
||||
- [x] 完善 `compute_factors()` 函数实现:
|
||||
- [x] 添加数据有效性校验机制
|
||||
- [x] 实现异常值检测与处理逻辑
|
||||
- [ ] 增加计算进度显示和日志记录
|
||||
- [ ] 优化因子持久化性能
|
||||
- [x] 增加计算进度显示和日志记录
|
||||
- [x] 优化因子持久化性能
|
||||
- [ ] 支持增量计算模式
|
||||
|
||||
#### 1.2 DataBroker增强
|
||||
@ -30,7 +30,7 @@
|
||||
- [x] 扩展动量类因子群
|
||||
- [x] 开发估值类因子群
|
||||
- [x] 设计流动性因子群
|
||||
- [ ] 构建市场情绪因子群
|
||||
- [x] 构建市场情绪因子群
|
||||
- [ ] 开发因子组合和权重优化算法
|
||||
|
||||
#### 1.4 新闻数据源完善
|
||||
|
||||
148
tests/test_llm_context.py
Normal file
148
tests/test_llm_context.py
Normal file
@ -0,0 +1,148 @@
|
||||
"""Test cases for LLM context management."""
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from app.llm.context import Context, ContextConfig, ContextManager, DataAccessConfig, Message
|
||||
|
||||
|
||||
def test_data_access_config():
|
||||
"""Test data access configuration and validation."""
|
||||
config = DataAccessConfig(
|
||||
allowed_tables={"daily", "daily_basic"},
|
||||
max_history_days=365,
|
||||
max_batch_size=1000
|
||||
)
|
||||
|
||||
# Valid request
|
||||
errors = config.validate_request("daily", "20251001", "20251005")
|
||||
assert not errors
|
||||
|
||||
# Invalid table
|
||||
errors = config.validate_request("invalid", "20251001")
|
||||
assert len(errors) == 1
|
||||
assert "not allowed" in errors[0]
|
||||
|
||||
# Invalid date format
|
||||
errors = config.validate_request("daily", "invalid")
|
||||
assert len(errors) == 1
|
||||
assert "Invalid date format" in errors[0]
|
||||
|
||||
# Date range too long
|
||||
errors = config.validate_request("daily", "20251001", "20261001")
|
||||
assert len(errors) == 1
|
||||
assert "exceeds max" in errors[0]
|
||||
|
||||
# End date before start
|
||||
errors = config.validate_request("daily", "20251005", "20251001")
|
||||
assert len(errors) == 1
|
||||
assert "before start date" in errors[0]
|
||||
|
||||
|
||||
def test_context_config():
|
||||
"""Test context configuration defaults."""
|
||||
config = ContextConfig()
|
||||
assert config.max_total_tokens > 0
|
||||
assert config.max_messages > 0
|
||||
assert config.include_system is True
|
||||
assert config.include_functions is True
|
||||
|
||||
|
||||
def test_message():
|
||||
"""Test message functionality."""
|
||||
# Basic message
|
||||
msg = Message(role="user", content="Hello")
|
||||
assert msg.role == "user"
|
||||
assert msg.content == "Hello"
|
||||
assert msg.name is None
|
||||
assert msg.function_call is None
|
||||
assert msg.timestamp <= time.time()
|
||||
|
||||
# Function message
|
||||
func_msg = Message(
|
||||
role="function",
|
||||
content="Result",
|
||||
name="test_func",
|
||||
function_call={"name": "test_func", "arguments": "{}"}
|
||||
)
|
||||
assert func_msg.name == "test_func"
|
||||
assert func_msg.function_call is not None
|
||||
|
||||
# Dict conversion
|
||||
msg_dict = msg.to_dict()
|
||||
assert msg_dict["role"] == "user"
|
||||
assert msg_dict["content"] == "Hello"
|
||||
assert "name" not in msg_dict
|
||||
assert "function_call" not in msg_dict
|
||||
|
||||
func_dict = func_msg.to_dict()
|
||||
assert func_dict["name"] == "test_func"
|
||||
assert func_dict["function_call"] is not None
|
||||
|
||||
# Token estimation
|
||||
assert msg.estimated_tokens > 0
|
||||
assert func_msg.estimated_tokens > msg.estimated_tokens
|
||||
|
||||
|
||||
def test_context():
|
||||
"""Test context management."""
|
||||
config = ContextConfig(max_total_tokens=100, max_messages=3)
|
||||
context = Context(config=config)
|
||||
|
||||
# Add messages
|
||||
msg1 = Message(role="system", content="System message")
|
||||
msg2 = Message(role="user", content="User message")
|
||||
msg3 = Message(role="assistant", content="Assistant message")
|
||||
msg4 = Message(role="user", content="Another message")
|
||||
|
||||
context.add_message(msg1)
|
||||
context.add_message(msg2)
|
||||
context.add_message(msg3)
|
||||
assert len(context.messages) == 3
|
||||
|
||||
# Test max messages
|
||||
context.add_message(msg4)
|
||||
assert len(context.messages) == 3
|
||||
assert msg4 in context.messages # Newest message kept
|
||||
|
||||
# Get messages
|
||||
all_msgs = context.get_messages()
|
||||
assert len(all_msgs) == 3
|
||||
|
||||
no_system = context.get_messages(include_system=False)
|
||||
assert len(no_system) == 2
|
||||
assert all(m["role"] != "system" for m in no_system)
|
||||
|
||||
# Clear context
|
||||
context.clear(keep_system=True)
|
||||
assert len(context.messages) == 1
|
||||
assert context.messages[0].role == "system"
|
||||
|
||||
context.clear(keep_system=False)
|
||||
assert len(context.messages) == 0
|
||||
|
||||
|
||||
def test_context_manager():
|
||||
"""Test context manager functionality."""
|
||||
ContextManager.clear_all()
|
||||
|
||||
# Create context
|
||||
context = ContextManager.create_context("test")
|
||||
assert ContextManager.get_context("test") == context
|
||||
|
||||
# Duplicate context
|
||||
with pytest.raises(ValueError):
|
||||
ContextManager.create_context("test")
|
||||
|
||||
# Custom config
|
||||
config = ContextConfig(max_total_tokens=200)
|
||||
custom = ContextManager.create_context("custom", config)
|
||||
assert custom.config.max_total_tokens == 200
|
||||
|
||||
# Remove context
|
||||
ContextManager.remove_context("test")
|
||||
assert ContextManager.get_context("test") is None
|
||||
|
||||
# Clear all
|
||||
ContextManager.clear_all()
|
||||
assert ContextManager.get_context("custom") is None
|
||||
178
tests/test_llm_templates.py
Normal file
178
tests/test_llm_templates.py
Normal file
@ -0,0 +1,178 @@
|
||||
"""Test cases for LLM template management."""
|
||||
import pytest
|
||||
|
||||
from app.llm.templates import PromptTemplate, TemplateRegistry
|
||||
|
||||
|
||||
def test_prompt_template_validation():
|
||||
"""Test template validation logic."""
|
||||
# Valid template
|
||||
template = PromptTemplate(
|
||||
id="test",
|
||||
name="Test Template",
|
||||
description="A test template",
|
||||
template="Hello {name}!",
|
||||
variables=["name"]
|
||||
)
|
||||
assert not template.validate()
|
||||
|
||||
# Missing variable
|
||||
template = PromptTemplate(
|
||||
id="test",
|
||||
name="Test Template",
|
||||
description="A test template",
|
||||
template="Hello {name}!",
|
||||
variables=["name", "missing"]
|
||||
)
|
||||
errors = template.validate()
|
||||
assert len(errors) == 1
|
||||
assert "missing" in errors[0]
|
||||
|
||||
# Empty required context
|
||||
template = PromptTemplate(
|
||||
id="test",
|
||||
name="Test Template",
|
||||
description="A test template",
|
||||
template="Hello {name}!",
|
||||
variables=["name"],
|
||||
required_context=["", "name"]
|
||||
)
|
||||
errors = template.validate()
|
||||
assert len(errors) == 1
|
||||
assert "Empty required context" in errors[0]
|
||||
|
||||
# Empty validation rule
|
||||
template = PromptTemplate(
|
||||
id="test",
|
||||
name="Test Template",
|
||||
description="A test template",
|
||||
template="Hello {name}!",
|
||||
variables=["name"],
|
||||
validation_rules=["len(name) > 0", ""]
|
||||
)
|
||||
errors = template.validate()
|
||||
assert len(errors) == 1
|
||||
assert "Empty validation rule" in errors[0]
|
||||
|
||||
|
||||
def test_prompt_template_format():
|
||||
"""Test template formatting."""
|
||||
template = PromptTemplate(
|
||||
id="test",
|
||||
name="Test Template",
|
||||
description="A test template",
|
||||
template="Hello {name}!",
|
||||
variables=["name"],
|
||||
required_context=["name"],
|
||||
max_length=10
|
||||
)
|
||||
|
||||
# Valid context
|
||||
result = template.format({"name": "World"})
|
||||
assert result == "Hello Wor..."
|
||||
|
||||
# Missing required context
|
||||
with pytest.raises(ValueError) as exc:
|
||||
template.format({})
|
||||
assert "Missing required context" in str(exc.value)
|
||||
|
||||
# Missing variable
|
||||
with pytest.raises(ValueError) as exc:
|
||||
template.format({"wrong": "value"})
|
||||
assert "Missing template variable" in str(exc.value)
|
||||
|
||||
|
||||
def test_template_registry():
|
||||
"""Test template registry operations."""
|
||||
TemplateRegistry.clear()
|
||||
|
||||
# Register valid template
|
||||
template = PromptTemplate(
|
||||
id="test",
|
||||
name="Test Template",
|
||||
description="A test template",
|
||||
template="Hello {name}!",
|
||||
variables=["name"]
|
||||
)
|
||||
TemplateRegistry.register(template)
|
||||
assert TemplateRegistry.get("test") == template
|
||||
|
||||
# Register invalid template
|
||||
invalid = PromptTemplate(
|
||||
id="invalid",
|
||||
name="Invalid Template",
|
||||
description="An invalid template",
|
||||
template="Hello {name}!",
|
||||
variables=["wrong"]
|
||||
)
|
||||
with pytest.raises(ValueError) as exc:
|
||||
TemplateRegistry.register(invalid)
|
||||
assert "Invalid template" in str(exc.value)
|
||||
|
||||
# List templates
|
||||
templates = TemplateRegistry.list()
|
||||
assert len(templates) == 1
|
||||
assert templates[0].id == "test"
|
||||
|
||||
# Load from JSON
|
||||
json_str = '''
|
||||
{
|
||||
"json_test": {
|
||||
"name": "JSON Test",
|
||||
"description": "Test template from JSON",
|
||||
"template": "Hello {name}!",
|
||||
"variables": ["name"]
|
||||
}
|
||||
}
|
||||
'''
|
||||
TemplateRegistry.load_from_json(json_str)
|
||||
assert TemplateRegistry.get("json_test") is not None
|
||||
|
||||
# Invalid JSON
|
||||
with pytest.raises(ValueError) as exc:
|
||||
TemplateRegistry.load_from_json("invalid json")
|
||||
assert "Invalid JSON" in str(exc.value)
|
||||
|
||||
# Non-object JSON
|
||||
with pytest.raises(ValueError) as exc:
|
||||
TemplateRegistry.load_from_json("[1, 2, 3]")
|
||||
assert "JSON root must be an object" in str(exc.value)
|
||||
|
||||
|
||||
def test_default_templates():
|
||||
"""Test default template registration."""
|
||||
TemplateRegistry.clear()
|
||||
from app.llm.templates import DEFAULT_TEMPLATES
|
||||
|
||||
# Verify default templates are loaded
|
||||
assert len(TemplateRegistry.list()) > 0
|
||||
|
||||
# Check specific templates
|
||||
dept_base = TemplateRegistry.get("department_base")
|
||||
assert dept_base is not None
|
||||
assert "部门基础模板" in dept_base.name
|
||||
|
||||
momentum = TemplateRegistry.get("momentum_dept")
|
||||
assert momentum is not None
|
||||
assert "动量研究部门" in momentum.name
|
||||
|
||||
# Validate template content
|
||||
assert all("{" + var + "}" in dept_base.template for var in dept_base.variables)
|
||||
assert all("{" + var + "}" in momentum.template for var in momentum.variables)
|
||||
|
||||
# Test template usage
|
||||
context = {
|
||||
"title": "测试部门",
|
||||
"ts_code": "000001.SZ",
|
||||
"trade_date": "20251005",
|
||||
"description": "测试描述",
|
||||
"instruction": "测试指令",
|
||||
"data_scope": "daily,daily_basic",
|
||||
"features": "特征1,特征2",
|
||||
"market_snapshot": "市场数据1,市场数据2",
|
||||
"supplements": "补充数据"
|
||||
}
|
||||
result = dept_base.format(context)
|
||||
assert "测试部门" in result
|
||||
assert "000001.SZ" in result
|
||||
assert "20251005" in result
|
||||
133
tests/test_sentiment_factors.py
Normal file
133
tests/test_sentiment_factors.py
Normal file
@ -0,0 +1,133 @@
|
||||
"""Tests for sentiment factor computation."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from app.features.sentiment_factors import SentimentFactors
|
||||
from app.utils.data_access import DataBroker
|
||||
|
||||
|
||||
class MockDataBroker:
|
||||
"""Mock DataBroker for testing."""
|
||||
|
||||
def get_news_data(
|
||||
self,
|
||||
ts_code: str,
|
||||
trade_date: str,
|
||||
limit: int = 30
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""模拟新闻数据"""
|
||||
if ts_code == "000001.SZ":
|
||||
return [
|
||||
{
|
||||
"sentiment": 0.8,
|
||||
"heat": 0.6,
|
||||
"entities": "公司A,行业B,概念C"
|
||||
},
|
||||
{
|
||||
"sentiment": 0.6,
|
||||
"heat": 0.4,
|
||||
"entities": "公司A,概念D"
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
def get_stock_data(
|
||||
self,
|
||||
ts_code: str,
|
||||
trade_date: str,
|
||||
fields: List[str],
|
||||
limit: int = 1
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""模拟股票数据"""
|
||||
if ts_code == "000001.SZ":
|
||||
return [
|
||||
{"daily_basic.volume_ratio": 1.2},
|
||||
{"daily_basic.volume_ratio": 1.1}
|
||||
]
|
||||
return []
|
||||
|
||||
def _lookup_industry(self, ts_code: str) -> str:
|
||||
"""模拟行业查询"""
|
||||
if ts_code == "000001.SZ":
|
||||
return "银行"
|
||||
return ""
|
||||
|
||||
def _derived_industry_sentiment(
|
||||
self,
|
||||
industry: str,
|
||||
trade_date: str
|
||||
) -> float:
|
||||
"""模拟行业情绪"""
|
||||
if industry == "银行":
|
||||
return 0.5
|
||||
return 0.0
|
||||
|
||||
def get_industry_stocks(self, industry: str) -> List[str]:
|
||||
"""模拟行业成分股"""
|
||||
if industry == "银行":
|
||||
return ["000001.SZ", "600000.SH"]
|
||||
return []
|
||||
|
||||
|
||||
def test_compute_stock_factors():
|
||||
"""测试股票情绪因子计算"""
|
||||
calculator = SentimentFactors()
|
||||
broker = MockDataBroker()
|
||||
|
||||
# 测试有数据的情况
|
||||
factors = calculator.compute_stock_factors(
|
||||
broker,
|
||||
"000001.SZ",
|
||||
"20251001"
|
||||
)
|
||||
|
||||
assert "sent_momentum" in factors
|
||||
assert "sent_impact" in factors
|
||||
assert "sent_market" in factors
|
||||
assert "sent_divergence" in factors
|
||||
|
||||
assert factors["sent_impact"] > 0
|
||||
|
||||
# 测试无数据的情况
|
||||
factors = calculator.compute_stock_factors(
|
||||
broker,
|
||||
"000002.SZ",
|
||||
"20251001"
|
||||
)
|
||||
|
||||
assert all(v is None for v in factors.values())
|
||||
|
||||
def test_compute_batch(tmp_path):
|
||||
"""测试批量计算功能"""
|
||||
from app.data.schema import initialize_database
|
||||
from app.utils.config import get_config
|
||||
|
||||
# 配置测试数据库
|
||||
config = get_config()
|
||||
config.db_path = tmp_path / "test.db"
|
||||
|
||||
# 初始化数据库
|
||||
initialize_database()
|
||||
|
||||
calculator = SentimentFactors()
|
||||
broker = MockDataBroker()
|
||||
|
||||
# 测试批量计算
|
||||
ts_codes = ["000001.SZ", "000002.SZ", "600000.SH"]
|
||||
calculator.compute_batch(broker, ts_codes, "20251001")
|
||||
|
||||
# 验证数据已保存
|
||||
from app.utils.db import db_session
|
||||
with db_session() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM factors WHERE trade_date = ?",
|
||||
("20251001",)
|
||||
).fetchall()
|
||||
|
||||
# 应该只有一个股票有数据
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["ts_code"] == "000001.SZ"
|
||||
102
tests/test_sentiment_indicators.py
Normal file
102
tests/test_sentiment_indicators.py
Normal file
@ -0,0 +1,102 @@
|
||||
"""Tests for market sentiment indicators."""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from app.core.sentiment import (
|
||||
news_sentiment_momentum,
|
||||
news_impact_score,
|
||||
market_sentiment_index,
|
||||
industry_sentiment_divergence
|
||||
)
|
||||
|
||||
|
||||
def test_news_sentiment_momentum():
|
||||
# 生成测试数据
|
||||
window = 20
|
||||
uptrend = np.linspace(-0.5, 0.5, window) # 上升趋势
|
||||
downtrend = np.linspace(0.5, -0.5, window) # 下降趋势
|
||||
flat = np.zeros(window) # 平稳趋势
|
||||
|
||||
# 测试上升趋势
|
||||
result = news_sentiment_momentum(uptrend)
|
||||
assert result is not None
|
||||
assert result > 0
|
||||
|
||||
# 测试下降趋势
|
||||
result = news_sentiment_momentum(downtrend)
|
||||
assert result is not None
|
||||
assert result < 0
|
||||
|
||||
# 测试平稳趋势
|
||||
result = news_sentiment_momentum(flat)
|
||||
assert result is not None
|
||||
assert abs(result) < 0.1
|
||||
|
||||
# 测试数据不足
|
||||
result = news_sentiment_momentum(uptrend[:10])
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_news_impact_score():
|
||||
# 测试典型场景
|
||||
score = news_impact_score(sentiment=0.8, heat=0.6, entity_count=3)
|
||||
assert 0 <= score <= 1
|
||||
assert score > news_impact_score(sentiment=0.4, heat=0.6, entity_count=3)
|
||||
|
||||
# 测试边界情况
|
||||
assert news_impact_score(sentiment=0, heat=0.5, entity_count=1) == 0
|
||||
assert 0 < news_impact_score(sentiment=1, heat=1, entity_count=10) <= 1
|
||||
|
||||
# 测试实体数量影响
|
||||
low_entity = news_impact_score(sentiment=0.5, heat=0.5, entity_count=1)
|
||||
high_entity = news_impact_score(sentiment=0.5, heat=0.5, entity_count=5)
|
||||
assert high_entity > low_entity
|
||||
|
||||
|
||||
def test_market_sentiment_index():
|
||||
window = 20
|
||||
|
||||
# 生成测试数据
|
||||
sentiment_scores = np.random.uniform(-1, 1, window)
|
||||
heat_scores = np.random.uniform(0, 1, window)
|
||||
volume_ratios = np.random.uniform(0.5, 2, window)
|
||||
|
||||
# 测试正常计算
|
||||
result = market_sentiment_index(
|
||||
sentiment_scores,
|
||||
heat_scores,
|
||||
volume_ratios
|
||||
)
|
||||
assert result is not None
|
||||
assert -1 <= result <= 1
|
||||
|
||||
# 测试数据缺失
|
||||
result = market_sentiment_index(
|
||||
sentiment_scores[:10],
|
||||
heat_scores,
|
||||
volume_ratios
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_industry_sentiment_divergence():
|
||||
# 测试显著背离
|
||||
high_divergence = industry_sentiment_divergence(
|
||||
industry_sentiment=0.8,
|
||||
peer_sentiments=[-0.2, -0.1, 0, 0.1]
|
||||
)
|
||||
assert high_divergence is not None
|
||||
assert high_divergence > 0
|
||||
|
||||
# 测试一致性好
|
||||
low_divergence = industry_sentiment_divergence(
|
||||
industry_sentiment=0.1,
|
||||
peer_sentiments=[0, 0.1, 0.2]
|
||||
)
|
||||
assert low_divergence is not None
|
||||
assert abs(low_divergence) < abs(high_divergence)
|
||||
|
||||
# 测试空数据
|
||||
assert industry_sentiment_divergence(0.5, []) is None
|
||||
Loading…
Reference in New Issue
Block a user