commit a619a24440 (parent adfc8ee148)
update

app/core/sentiment.py (new file, 120 lines)
@@ -0,0 +1,120 @@
"""Market sentiment indicators."""

from __future__ import annotations

from typing import Dict, List, Optional, Sequence

import numpy as np

from scipy import stats


def news_sentiment_momentum(
    sentiment_series: Sequence[float],
    window: int = 20
) -> Optional[float]:
    """Compute the news sentiment momentum indicator.

    Args:
        sentiment_series: News sentiment scores, ordered from newest to oldest
        window: Calculation window

    Returns:
        Sentiment momentum score in [-1, 1], or None when data is insufficient
    """
    if len(sentiment_series) < window:
        return None

    # Fit the sentiment trend
    sentiment_series = np.array(sentiment_series[:window])
    slope, _, r_value, _, _ = stats.linregress(
        np.arange(len(sentiment_series)),
        sentiment_series
    )

    # Combine the slope with the goodness of fit
    trend = np.tanh(slope * 10)  # normalized slope
    quality = abs(r_value)  # trend significance

    return float(trend * quality)


def news_impact_score(
    sentiment: float,
    heat: float,
    entity_count: int
) -> float:
    """Compute the news impact score.

    Args:
        sentiment: Sentiment score (-1 to 1)
        heat: Heat score (0 to 1)
        entity_count: Number of entities mentioned

    Returns:
        Impact score (0 to 1)
    """
    # News impact = sentiment strength * heat * entity coverage
    sentiment_strength = abs(sentiment)
    entity_coverage = min(entity_count / 5, 1.0)  # normalized entity count

    return sentiment_strength * heat * (0.7 + 0.3 * entity_coverage)


def market_sentiment_index(
    sentiment_scores: Sequence[float],
    heat_scores: Sequence[float],
    volume_ratios: Sequence[float],
    window: int = 20
) -> Optional[float]:
    """Compute the composite market sentiment index.

    Args:
        sentiment_scores: Per-stock sentiment score series
        heat_scores: Per-stock heat score series
        volume_ratios: Per-stock volume ratio series
        window: Calculation window

    Returns:
        Market sentiment index in [-1, 1], or None when data is insufficient
    """
    if len(sentiment_scores) < window or \
            len(heat_scores) < window or \
            len(volume_ratios) < window:
        return None

    # Slice to the calculation window
    sentiment_scores = np.array(sentiment_scores[:window])
    heat_scores = np.array(heat_scores[:window])
    volume_ratios = np.array(volume_ratios[:window])

    # Volume-weighted sentiment scores
    volume_weights = volume_ratios / np.mean(volume_ratios)
    weighted_sentiment = sentiment_scores * volume_weights

    # Heat-weighted average
    heat_weights = heat_scores / np.sum(heat_scores)
    market_mood = np.sum(weighted_sentiment * heat_weights)

    return float(np.tanh(market_mood))  # squash into [-1, 1]


def industry_sentiment_divergence(
    industry_sentiment: float,
    peer_sentiments: Sequence[float]
) -> Optional[float]:
    """Compute the industry sentiment divergence.

    Args:
        industry_sentiment: Aggregate industry sentiment score
        peer_sentiments: Sentiment scores of the constituent stocks

    Returns:
        Divergence score in [-1, 1], or None when data is insufficient
    """
    if not peer_sentiments:
        return None

    peer_sentiments = np.array(peer_sentiments)
    peer_mean = np.mean(peer_sentiments)
    peer_std = np.std(peer_sentiments)

    if peer_std == 0:
        return 0.0

    # Z-score measures how far the industry deviates from its peers
    z_score = (industry_sentiment - peer_mean) / peer_std
    return float(np.tanh(z_score))  # squash into [-1, 1]
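Note: a quick illustrative sketch of how these four indicators are driven, on synthetic inputs only; it is not part of the commit and assumes nothing beyond numpy and the functions above.

    import numpy as np
    from app.core.sentiment import (
        news_sentiment_momentum,
        news_impact_score,
        market_sentiment_index,
        industry_sentiment_divergence,
    )

    window = 20
    sentiments = list(np.linspace(-0.2, 0.6, window))    # a drifting sentiment series
    heats = list(np.random.uniform(0.2, 0.9, window))
    vol_ratios = list(np.random.uniform(0.5, 2.0, window))

    print(news_sentiment_momentum(sentiments, window=window))   # slope * |r|, squashed into [-1, 1]
    print(news_impact_score(sentiment=sentiments[0], heat=heats[0], entity_count=3))
    print(market_sentiment_index(sentiments, heats, vol_ratios, window=window))
    print(industry_sentiment_divergence(0.4, [0.1, 0.0, 0.2]))  # tanh of the peer z-score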
@@ -14,6 +14,7 @@ from app.utils.db import db_session
 from app.utils.logging import get_logger
 # Import extended factor modules
 from app.features.extended_factors import ExtendedFactors
+from app.features.sentiment_factors import SentimentFactors
 # Import factor validation helpers
 from app.features.validation import check_data_sufficiency, detect_outliers
 
@@ -77,6 +78,11 @@ DEFAULT_FACTORS: List[FactorSpec] = [
     # Market regime factors
     FactorSpec("market_regime", 0),    # market regime factor
     FactorSpec("trend_strength", 0),   # trend strength factor
+    # Sentiment factors
+    FactorSpec("sent_momentum", 20),   # news sentiment momentum
+    FactorSpec("sent_impact", 0),      # news impact
+    FactorSpec("sent_market", 20),     # market sentiment index
+    FactorSpec("sent_divergence", 0),  # industry sentiment divergence
 ]
 
 
@@ -419,7 +425,10 @@ def _compute_security_factors(
     trade_date: str,
     specs: Sequence[FactorSpec],
 ) -> Dict[str, float | None]:
-    """Compute factor values for a single security."""
+    """Compute factor values for a single security.
+
+    Covers basic, extended, and sentiment factors.
+    """
     # Determine the largest window sizes required
     close_windows = [spec.window for spec in specs if _factor_prefix(spec.name) in {"mom", "volat"}]
     turnover_windows = [spec.window for spec in specs if _factor_prefix(spec.name) == "turn"]
app/features/sentiment_factors.py (new file, 247 lines)
@@ -0,0 +1,247 @@
"""Extended sentiment factor implementations."""

from __future__ import annotations

from typing import Dict, Optional, Sequence

import numpy as np

from app.core.sentiment import (
    news_sentiment_momentum,
    news_impact_score,
    market_sentiment_index,
    industry_sentiment_divergence
)
from dataclasses import dataclass
from datetime import datetime, timezone
from app.utils.data_access import DataBroker
from app.utils.db import db_session
from app.utils.logging import get_logger

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "sentiment_factors"}


class SentimentFactors:
    """Sentiment factor calculator.

    Implements a set of factors based on news, market, and industry sentiment:
    1. News sentiment momentum (sent_momentum)
    2. News impact (sent_impact)
    3. Market sentiment index (sent_market)
    4. Industry sentiment divergence (sent_divergence)

    Usage example:
        calculator = SentimentFactors()
        broker = DataBroker()

        factors = calculator.compute_stock_factors(
            broker,
            "000001.SZ",
            "20251001"
        )
    """

    def __init__(self):
        """Initialize the sentiment factor calculator."""
        self.factor_specs = {
            "sent_momentum": 20,   # sentiment momentum window
            "sent_impact": 0,      # news impact
            "sent_market": 20,     # market sentiment window
            "sent_divergence": 0,  # sentiment divergence
        }

    def compute_stock_factors(
        self,
        broker: DataBroker,
        ts_code: str,
        trade_date: str,
    ) -> Dict[str, Optional[float]]:
        """Compute sentiment factors for a single stock.

        Args:
            broker: Data access broker
            ts_code: Stock code
            trade_date: Trade date

        Returns:
            Mapping from factor name to factor value
        """
        results = {}

        try:
            # Fetch historical news data
            news_data = broker.get_news_data(
                ts_code,
                trade_date,
                limit=30  # keep enough history to compute momentum
            )

            if not news_data:
                LOGGER.debug(
                    "no news data code=%s date=%s",
                    ts_code,
                    trade_date,
                    extra=LOG_EXTRA
                )
                return {name: None for name in self.factor_specs}

            # Extract the series
            sentiment_series = [row["sentiment"] for row in news_data]
            heat_series = [row["heat"] for row in news_data]
            entity_counts = [
                len(row["entities"].split(",")) if row["entities"] else 0
                for row in news_data
            ]

            # 1. News sentiment momentum
            results["sent_momentum"] = news_sentiment_momentum(
                sentiment_series,
                window=self.factor_specs["sent_momentum"]
            )

            # 2. News impact
            # Use the most recent news item
            results["sent_impact"] = news_impact_score(
                sentiment=sentiment_series[0],
                heat=heat_series[0],
                entity_count=entity_counts[0]
            )

            # 3. Market sentiment index
            # Fetch volume data
            volume_data = broker.get_stock_data(
                ts_code,
                trade_date,
                fields=["daily_basic.volume_ratio"],
                limit=self.factor_specs["sent_market"]
            )
            if volume_data:
                volume_ratios = [
                    row.get("daily_basic.volume_ratio", 1.0)
                    for row in volume_data
                ]
                results["sent_market"] = market_sentiment_index(
                    sentiment_series,
                    heat_series,
                    volume_ratios,
                    window=self.factor_specs["sent_market"]
                )
            else:
                results["sent_market"] = None

            # 4. Industry sentiment divergence
            industry = broker._lookup_industry(ts_code)
            if industry:
                industry_sent = broker._derived_industry_sentiment(
                    industry,
                    trade_date
                )
                if industry_sent is not None:
                    # Sentiment scores of peers in the same industry
                    peer_sents = []
                    for peer in broker.get_industry_stocks(industry):
                        if peer != ts_code:
                            peer_data = broker.get_news_data(
                                peer,
                                trade_date,
                                limit=1
                            )
                            if peer_data:
                                peer_sents.append(peer_data[0]["sentiment"])

                    results["sent_divergence"] = industry_sentiment_divergence(
                        industry_sent,
                        peer_sents
                    )
                else:
                    results["sent_divergence"] = None
            else:
                results["sent_divergence"] = None

        except Exception as e:
            LOGGER.error(
                "sentiment factor computation failed code=%s date=%s error=%s",
                ts_code,
                trade_date,
                str(e),
                exc_info=True,
                extra=LOG_EXTRA
            )
            return {name: None for name in self.factor_specs}

        return results

    def compute_batch(
        self,
        broker: DataBroker,
        ts_codes: list[str],
        trade_date: str,
        batch_size: int = 100
    ) -> None:
        """Compute sentiment factors for multiple stocks and persist them.

        Args:
            broker: Data access broker
            ts_codes: List of stock codes
            trade_date: Trade date
            batch_size: Batch size (progress is logged every batch_size stocks)
        """
        # Prepare the SQL statement
        columns = list(self.factor_specs.keys())
        insert_columns = ["ts_code", "trade_date", "updated_at"] + columns

        placeholders = ",".join("?" * len(insert_columns))
        update_clause = ", ".join(
            f"{column}=excluded.{column}"
            for column in ["updated_at"] + columns
        )

        sql = (
            f"INSERT INTO factors ({','.join(insert_columns)}) "
            f"VALUES ({placeholders}) "
            f"ON CONFLICT(ts_code, trade_date) DO UPDATE SET {update_clause}"
        )

        # Current timestamp
        timestamp = datetime.now(timezone.utc).isoformat()

        # Process in batches
        total_processed = 0
        rows_to_persist = []

        for ts_code in ts_codes:
            # Compute factors
            values = self.compute_stock_factors(broker, ts_code, trade_date)

            # Collect the row
            if any(v is not None for v in values.values()):
                payload = [ts_code, trade_date, timestamp]
                payload.extend(values.get(col) for col in columns)
                rows_to_persist.append(payload)

            total_processed += 1
            if total_processed % batch_size == 0:
                LOGGER.info(
                    "sentiment factor progress: %d/%d (%.1f%%)",
                    total_processed,
                    len(ts_codes),
                    (total_processed / len(ts_codes)) * 100,
                    extra=LOG_EXTRA
                )

        # Bulk write
        if rows_to_persist:
            with db_session() as conn:
                try:
                    conn.executemany(sql, rows_to_persist)
                    LOGGER.info(
                        "sentiment factors persisted total=%d",
                        len(rows_to_persist),
                        extra=LOG_EXTRA
                    )
                except Exception as e:
                    LOGGER.error(
                        "sentiment factor persistence failed error=%s",
                        str(e),
                        exc_info=True,
                        extra=LOG_EXTRA
                    )
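Note: for readers skimming compute_batch, the upsert it assembles for the new columns is reconstructed below, outside the class, purely for illustration.

    columns = ["sent_momentum", "sent_impact", "sent_market", "sent_divergence"]
    insert_columns = ["ts_code", "trade_date", "updated_at"] + columns
    placeholders = ",".join("?" * len(insert_columns))
    update_clause = ", ".join(f"{c}=excluded.{c}" for c in ["updated_at"] + columns)
    sql = (
        f"INSERT INTO factors ({','.join(insert_columns)}) "
        f"VALUES ({placeholders}) "
        f"ON CONFLICT(ts_code, trade_date) DO UPDATE SET {update_clause}"
    )
    # Resulting statement:
    # INSERT INTO factors (ts_code,trade_date,updated_at,sent_momentum,sent_impact,sent_market,sent_divergence)
    # VALUES (?,?,?,?,?,?,?)
    # ON CONFLICT(ts_code, trade_date) DO UPDATE SET updated_at=excluded.updated_at, sent_momentum=excluded.sent_momentum, ...
    # so rerunning the batch for the same (ts_code, trade_date) updates rows in place.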
@@ -4,10 +4,13 @@ from __future__ import annotations
 import json
 from collections import Counter
 from dataclasses import asdict
-from typing import Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional
 
 import requests
 
+from .context import ContextManager, Message
+from .templates import TemplateRegistry
+
 from app.utils.config import (
     DEFAULT_LLM_BASE_URLS,
     DEFAULT_LLM_MODELS,
@@ -247,11 +250,58 @@ def _normalize_response(text: str) -> str:
     return " ".join(text.strip().split())
 
 
-def run_llm(prompt: str, *, system: Optional[str] = None) -> str:
-    """Execute the globally configured LLM strategy with the given prompt."""
-    settings = get_config().llm
-    return run_llm_with_config(settings, prompt, system=system)
+def run_llm(
+    prompt: str,
+    *,
+    system: Optional[str] = None,
+    context_id: Optional[str] = None,
+    template_id: Optional[str] = None,
+    template_vars: Optional[Dict[str, Any]] = None
+) -> str:
+    """Execute the globally configured LLM strategy with the given prompt.
+
+    Args:
+        prompt: Raw prompt string, or template variable if template_id is provided
+        system: Optional system message
+        context_id: Optional context ID for conversation tracking
+        template_id: Optional template ID to use
+        template_vars: Variables to use with the template
+    """
+    # Get config and prepare context
+    cfg = get_config()
+    if context_id:
+        context = ContextManager.get_context(context_id)
+        if not context:
+            context = ContextManager.create_context(context_id)
+    else:
+        context = None
+
+    # Apply template if specified
+    if template_id:
+        template = TemplateRegistry.get(template_id)
+        if not template:
+            raise ValueError(f"Template {template_id} not found")
+        vars_dict = template_vars or {}
+        if isinstance(prompt, str):
+            vars_dict["prompt"] = prompt
+        elif isinstance(prompt, dict):
+            vars_dict.update(prompt)
+        prompt = template.format(vars_dict)
+
+    # Add to context if tracking
+    if context:
+        if system:
+            context.add_message(Message(role="system", content=system))
+        context.add_message(Message(role="user", content=prompt))
+
+    # Execute LLM call
+    response = run_llm_with_config(cfg.llm, prompt, system=system)
+
+    # Update context with response
+    if context:
+        context.add_message(Message(role="assistant", content=response))
+
+    return response
 
 
 def _run_majority_vote(config: LLMConfig, prompt: str, system: Optional[str]) -> str:
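Note: an illustrative call of the widened run_llm() signature. The context and template IDs below are hypothetical, the call needs a configured LLM endpoint, and the module that exports run_llm is not named in this diff.

    # run_llm is assumed to be imported from its defining module (file name not shown in this diff).
    answer = run_llm(
        "",                                   # folded into template_vars["prompt"] when a template is used
        system="You are a cautious market analyst.",
        context_id="demo-session-1",          # hypothetical; reusing the ID keeps conversation history
        template_id="department_base",        # registered in app/llm/templates.py
        template_vars={
            "title": "测试部门",
            "ts_code": "000001.SZ",
            "trade_date": "20251001",
            "description": "-",
            "instruction": "-",
            "data_scope": "- daily",
            "features": "- mom_20: 0.12",
            "market_snapshot": "- (无)",
            "supplements": "- 当前无追加数据",
        },
    )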
app/llm/context.py (new file, 176 lines)
@@ -0,0 +1,176 @@
"""LLM context management and access control."""

from __future__ import annotations

import json
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set


@dataclass
class DataAccessConfig:
    """Configuration for data access control."""

    allowed_tables: Set[str]
    max_history_days: int
    max_batch_size: int

    def validate_request(
        self, table: str, start_date: str, end_date: Optional[str] = None
    ) -> List[str]:
        """Validate a data access request."""
        errors = []

        if table not in self.allowed_tables:
            errors.append(f"Table {table} not allowed")

        try:
            start_ts = time.strptime(start_date, "%Y%m%d")
            if end_date:
                end_ts = time.strptime(end_date, "%Y%m%d")
                days = (time.mktime(end_ts) - time.mktime(start_ts)) / (24 * 3600)
                if days > self.max_history_days:
                    errors.append(
                        f"Date range exceeds max {self.max_history_days} days"
                    )
                if days < 0:
                    errors.append("End date before start date")
        except ValueError:
            errors.append("Invalid date format (expected YYYYMMDD)")

        return errors


@dataclass
class ContextConfig:
    """Configuration for context management."""

    max_total_tokens: int = 4000
    max_messages: int = 10
    include_system: bool = True
    include_functions: bool = True


@dataclass
class Message:
    """A message in the conversation context."""

    role: str  # system, user, assistant, function
    content: str
    name: Optional[str] = None  # For function calls/results
    function_call: Optional[Dict[str, Any]] = None
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dict format for API calls."""
        msg = {"role": self.role, "content": self.content}
        if self.name:
            msg["name"] = self.name
        if self.function_call:
            msg["function_call"] = self.function_call
        return msg

    @property
    def estimated_tokens(self) -> int:
        """Rough estimate of tokens in message."""
        # Very rough estimate: 1 token ≈ 4 chars
        base = len(self.content) // 4
        if self.function_call:
            base += len(json.dumps(self.function_call)) // 4
        return base


@dataclass
class Context:
    """Manages conversation context with token tracking."""

    messages: List[Message] = field(default_factory=list)
    config: ContextConfig = field(default_factory=ContextConfig)
    _token_count: int = 0

    def add_message(self, message: Message) -> None:
        """Add a message to context, maintaining token limit."""
        # Update token count
        new_tokens = message.estimated_tokens
        while (
            self._token_count + new_tokens > self.config.max_total_tokens
            and self.messages
        ):
            # Remove oldest non-system message if needed
            for i, msg in enumerate(self.messages):
                if msg.role != "system" or len(self.messages) <= 1:
                    removed = self.messages.pop(i)
                    self._token_count -= removed.estimated_tokens
                    break

        # Add new message
        self.messages.append(message)
        self._token_count += new_tokens

        # Trim to max messages if needed
        while len(self.messages) > self.config.max_messages:
            for i, msg in enumerate(self.messages):
                if msg.role != "system" or len(self.messages) <= 1:
                    removed = self.messages.pop(i)
                    self._token_count -= removed.estimated_tokens
                    break

    def get_messages(
        self, include_system: bool = None, include_functions: bool = None
    ) -> List[Dict[str, Any]]:
        """Get messages for API call."""
        if include_system is None:
            include_system = self.config.include_system
        if include_functions is None:
            include_functions = self.config.include_functions

        return [
            msg.to_dict()
            for msg in self.messages
            if (include_system or msg.role != "system")
            and (include_functions or msg.role != "function")
        ]

    def clear(self, keep_system: bool = True) -> None:
        """Clear context, optionally keeping system messages."""
        if keep_system:
            system_msgs = [m for m in self.messages if m.role == "system"]
            self.messages = system_msgs
            self._token_count = sum(m.estimated_tokens for m in system_msgs)
        else:
            self.messages.clear()
            self._token_count = 0


class ContextManager:
    """Global manager for conversation contexts."""

    _contexts: Dict[str, Context] = {}
    _configs: Dict[str, ContextConfig] = {}

    @classmethod
    def create_context(
        cls, context_id: str, config: Optional[ContextConfig] = None
    ) -> Context:
        """Create a new context."""
        if context_id in cls._contexts:
            raise ValueError(f"Context {context_id} already exists")
        context = Context(config=config or ContextConfig())
        cls._contexts[context_id] = context
        return context

    @classmethod
    def get_context(cls, context_id: str) -> Optional[Context]:
        """Get existing context."""
        return cls._contexts.get(context_id)

    @classmethod
    def remove_context(cls, context_id: str) -> None:
        """Remove a context."""
        if context_id in cls._contexts:
            del cls._contexts[context_id]

    @classmethod
    def clear_all(cls) -> None:
        """Clear all contexts."""
        cls._contexts.clear()
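Note: a small illustrative sketch of the trimming behaviour (token budget first, then max_messages, with system messages kept whenever possible).

    from app.llm.context import Context, ContextConfig, Message

    ctx = Context(config=ContextConfig(max_total_tokens=12, max_messages=3))
    ctx.add_message(Message(role="system", content="keep me around"))       # ~3 estimated tokens
    ctx.add_message(Message(role="user", content="first question here"))    # ~4
    ctx.add_message(Message(role="assistant", content="first answer txt"))  # ~4
    ctx.add_message(Message(role="user", content="second question!!"))      # evicts the oldest non-system turn
    print([m["role"] for m in ctx.get_messages()])  # ['system', 'assistant', 'user']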
@@ -3,6 +3,8 @@ from __future__ import annotations
 
 from typing import Dict, TYPE_CHECKING
 
+from .templates import TemplateRegistry
+
 if TYPE_CHECKING:  # pragma: no cover
     from app.utils.config import DepartmentSettings
     from app.agents.departments import DepartmentContext
 
@@ -21,7 +23,8 @@ def department_prompt(
     supplements: str = "",
 ) -> str:
     """Compose a structured prompt for department-level LLM ensemble."""
 
+    # Format data for template
     feature_lines = "\n".join(
         f"- {key}: {value}" for key, value in sorted(context.features.items())
     )
 
@@ -31,40 +34,25 @@ def department_prompt(
     scope_lines = "\n".join(f"- {item}" for item in settings.data_scope)
     role_description = settings.description.strip()
     role_instruction = settings.prompt.strip()
-    supplement_block = supplements.strip()
 
-    instructions = f"""
-部门名称:{settings.title}
-股票代码:{context.ts_code}
-交易日:{context.trade_date}
-
-角色说明:{role_description or '未配置,默认沿用部门名称所代表的研究职责。'}
-职责指令:{role_instruction or '在保持部门风格的前提下,结合可用数据做出审慎判断。'}
-
-【可用数据范围】
-{scope_lines or '- 使用系统提供的全部上下文,必要时指出仍需的额外数据。'}
-
-【核心特征】
-{feature_lines or '- (无)'}
-
-【市场背景】
-{market_lines or '- (无)'}
-
-【追加数据】
-{supplement_block or '- 当前无追加数据'}
-
-请基于以上数据给出该部门对当前股票的操作建议。输出必须是 JSON,字段如下:
-{{
-"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
-"confidence": 0-1 之间的小数,表示信心,
-"summary": "一句话概括理由",
-"signals": ["详细要点", "..."],
-"risks": ["风险点", "..."]
-}}
-
-如需额外数据,请调用工具 `fetch_data`,仅支持请求 `daily` 或 `daily_basic` 表;在参数中填写 `tables` 数组,元素包含 `name`(表名)与可选的 `window`(向前回溯的条数,默认 1)及 `trade_date`(YYYYMMDD,默认本次交易日)。
-工具返回的数据会在后续消息中提供,请在获取所有必要信息后再给出最终 JSON 答复。
-
-请严格返回单个 JSON 对象,不要添加额外文本。
-"""
-    return instructions.strip()
+    # Determine template ID based on department settings
+    template_id = f"{settings.code.lower()}_dept"
+    if not TemplateRegistry.get(template_id):
+        template_id = "department_base"
+
+    # Prepare template variables
+    template_vars = {
+        "title": settings.title,
+        "ts_code": context.ts_code,
+        "trade_date": context.trade_date,
+        "description": role_description or "未配置,默认沿用部门名称所代表的研究职责。",
+        "instruction": role_instruction or "在保持部门风格的前提下,结合可用数据做出审慎判断。",
+        "data_scope": scope_lines or "- 使用系统提供的全部上下文,必要时指出仍需的额外数据。",
+        "features": feature_lines or "- (无)",
+        "market_snapshot": market_lines or "- (无)",
+        "supplements": supplements.strip() or "- 当前无追加数据"
+    }
+
+    # Get template and format prompt
+    template = TemplateRegistry.get(template_id)
+    return template.format(template_vars)
app/llm/templates.py (new file, 219 lines)
@@ -0,0 +1,219 @@
"""LLM prompt templates management with configuration driven design."""

from __future__ import annotations

import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class PromptTemplate:
    """Configuration driven prompt template."""

    id: str
    name: str
    description: str
    template: str
    variables: List[str]
    max_length: int = 4000
    required_context: List[str] = None
    validation_rules: List[str] = None

    def validate(self) -> List[str]:
        """Validate template configuration."""
        errors = []

        # Check template contains all variables
        for var in self.variables:
            if f"{{{var}}}" not in self.template:
                errors.append(f"Template missing variable: {var}")

        # Check required context fields
        if self.required_context:
            for field in self.required_context:
                if not field:
                    errors.append("Empty required context field")

        # Check validation rules format
        if self.validation_rules:
            for rule in self.validation_rules:
                if not rule:
                    errors.append("Empty validation rule")

        return errors

    def format(self, context: Dict[str, Any]) -> str:
        """Format template with provided context."""
        # Validate required context
        if self.required_context:
            missing = [f for f in self.required_context if f not in context]
            if missing:
                raise ValueError(f"Missing required context: {', '.join(missing)}")

        # Format template
        try:
            result = self.template.format(**context)
        except KeyError as e:
            raise ValueError(f"Missing template variable: {e}")

        # Truncate if needed
        if len(result) > self.max_length:
            result = result[:self.max_length - 3] + "..."

        return result


class TemplateRegistry:
    """Global registry for prompt templates."""

    _templates: Dict[str, PromptTemplate] = {}

    @classmethod
    def register(cls, template: PromptTemplate) -> None:
        """Register a new template."""
        errors = template.validate()
        if errors:
            raise ValueError(f"Invalid template {template.id}: {'; '.join(errors)}")
        cls._templates[template.id] = template

    @classmethod
    def get(cls, template_id: str) -> Optional[PromptTemplate]:
        """Get template by ID."""
        return cls._templates.get(template_id)

    @classmethod
    def list(cls) -> List[PromptTemplate]:
        """List all registered templates."""
        return list(cls._templates.values())

    @classmethod
    def load_from_json(cls, json_str: str) -> None:
        """Load templates from JSON string."""
        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON: {e}")

        if not isinstance(data, dict):
            raise ValueError("JSON root must be an object")

        for template_id, cfg in data.items():
            template = PromptTemplate(
                id=template_id,
                name=cfg.get("name", template_id),
                description=cfg.get("description", ""),
                template=cfg.get("template", ""),
                variables=cfg.get("variables", []),
                max_length=cfg.get("max_length", 4000),
                required_context=cfg.get("required_context", []),
                validation_rules=cfg.get("validation_rules", [])
            )
            cls.register(template)

    @classmethod
    def clear(cls) -> None:
        """Clear all registered templates."""
        cls._templates.clear()


# Register default templates
DEFAULT_TEMPLATES = {
    "department_base": {
        "name": "部门基础模板",
        "description": "通用的部门分析提示模板",
        "template": """
部门名称:{title}
股票代码:{ts_code}
交易日:{trade_date}

角色说明:{description}
职责指令:{instruction}

【可用数据范围】
{data_scope}

【核心特征】
{features}

【市场背景】
{market_snapshot}

【追加数据】
{supplements}

请基于以上数据给出该部门对当前股票的操作建议。输出必须是 JSON,字段如下:
{{
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
"confidence": 0-1 之间的小数,表示信心,
"summary": "一句话概括理由",
"signals": ["详细要点", "..."],
"risks": ["风险点", "..."]
}}

如需额外数据,请调用工具 `fetch_data`,仅支持请求 `daily` 或 `daily_basic` 表。
请严格返回单个 JSON 对象,不要添加额外文本。
""",
        "variables": [
            "title", "ts_code", "trade_date", "description", "instruction",
            "data_scope", "features", "market_snapshot", "supplements"
        ],
        "required_context": [
            "ts_code", "trade_date", "features", "market_snapshot"
        ],
        "validation_rules": [
            "len(features) > 0",
            "len(market_snapshot) > 0"
        ]
    },
    "momentum_dept": {
        "name": "动量研究部门",
        "description": "专注于动量因子分析的部门模板",
        "template": """
部门名称:动量研究部门
股票代码:{ts_code}
交易日:{trade_date}

角色说明:专注于分析股票价格动量、成交量动量和技术指标动量
职责指令:重点关注以下方面:
1. 价格趋势强度和持续性
2. 成交量配合度
3. 技术指标背离

【可用数据范围】
{data_scope}

【动量特征】
{features}

【市场背景】
{market_snapshot}

【追加数据】
{supplements}

请基于以上数据进行动量分析并给出操作建议。输出必须是 JSON,字段如下:
{{
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
"confidence": 0-1 之间的小数,表示信心,
"summary": "一句话概括动量分析结论",
"signals": ["动量信号要点", "..."],
"risks": ["动量风险点", "..."]
}}
""",
        "variables": [
            "ts_code", "trade_date", "data_scope",
            "features", "market_snapshot", "supplements"
        ],
        "required_context": [
            "ts_code", "trade_date", "features", "market_snapshot"
        ],
        "validation_rules": [
            "len(features) > 0",
            "'momentum' in ' '.join(features.keys()).lower()"
        ]
    }
}

# Register default templates
for template_id, cfg in DEFAULT_TEMPLATES.items():
    TemplateRegistry.register(PromptTemplate(**{"id": template_id, **cfg}))
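Note: a sketch of how a department-specific override is picked up by the department_prompt() change earlier in this diff. The "RISK" department code is hypothetical; the lookup key is f"{settings.code.lower()}_dept", falling back to "department_base".

    from app.llm.templates import PromptTemplate, TemplateRegistry

    # Hypothetical: a department whose settings.code is "RISK" resolves to template ID "risk_dept",
    # so registering this template overrides the "department_base" fallback for that department.
    TemplateRegistry.register(PromptTemplate(
        id="risk_dept",
        name="风险研究部门",
        description="Risk-focused variant of the base department template",
        template="部门名称:{title}\n股票代码:{ts_code}\n交易日:{trade_date}\n【核心特征】\n{features}\n【市场背景】\n{market_snapshot}",
        variables=["title", "ts_code", "trade_date", "features", "market_snapshot"],
        required_context=["ts_code", "trade_date", "features", "market_snapshot"],
    ))

    print(TemplateRegistry.get("risk_dept").format({
        "title": "风险研究部门",
        "ts_code": "000001.SZ",
        "trade_date": "20251001",
        "features": "- volat_20: 0.31",
        "market_snapshot": "- (无)",
    }))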
@@ -17,8 +17,8 @@
 - [x] Refine the `compute_factors()` implementation:
 - [x] Add data validity checks
 - [x] Implement outlier detection and handling
-- [ ] Add computation progress reporting and logging
-- [ ] Optimize factor persistence performance
+- [x] Add computation progress reporting and logging
+- [x] Optimize factor persistence performance
 - [ ] Support incremental computation mode
 
 #### 1.2 DataBroker enhancements

@@ -30,7 +30,7 @@
 - [x] Extend the momentum factor group
 - [x] Develop the valuation factor group
 - [x] Design the liquidity factor group
-- [ ] Build the market sentiment factor group
+- [x] Build the market sentiment factor group
 - [ ] Develop factor combination and weight optimization algorithms
 
 #### 1.4 News data source improvements
tests/test_llm_context.py (new file, 148 lines)
@@ -0,0 +1,148 @@
"""Test cases for LLM context management."""

import time

import pytest

from app.llm.context import Context, ContextConfig, ContextManager, DataAccessConfig, Message


def test_data_access_config():
    """Test data access configuration and validation."""
    config = DataAccessConfig(
        allowed_tables={"daily", "daily_basic"},
        max_history_days=365,
        max_batch_size=1000
    )

    # Valid request
    errors = config.validate_request("daily", "20251001", "20251005")
    assert not errors

    # Invalid table
    errors = config.validate_request("invalid", "20251001")
    assert len(errors) == 1
    assert "not allowed" in errors[0]

    # Invalid date format
    errors = config.validate_request("daily", "invalid")
    assert len(errors) == 1
    assert "Invalid date format" in errors[0]

    # Date range too long (well over max_history_days)
    errors = config.validate_request("daily", "20251001", "20270101")
    assert len(errors) == 1
    assert "exceeds max" in errors[0]

    # End date before start
    errors = config.validate_request("daily", "20251005", "20251001")
    assert len(errors) == 1
    assert "before start date" in errors[0]


def test_context_config():
    """Test context configuration defaults."""
    config = ContextConfig()
    assert config.max_total_tokens > 0
    assert config.max_messages > 0
    assert config.include_system is True
    assert config.include_functions is True


def test_message():
    """Test message functionality."""
    # Basic message
    msg = Message(role="user", content="Hello")
    assert msg.role == "user"
    assert msg.content == "Hello"
    assert msg.name is None
    assert msg.function_call is None
    assert msg.timestamp <= time.time()

    # Function message
    func_msg = Message(
        role="function",
        content="Result",
        name="test_func",
        function_call={"name": "test_func", "arguments": "{}"}
    )
    assert func_msg.name == "test_func"
    assert func_msg.function_call is not None

    # Dict conversion
    msg_dict = msg.to_dict()
    assert msg_dict["role"] == "user"
    assert msg_dict["content"] == "Hello"
    assert "name" not in msg_dict
    assert "function_call" not in msg_dict

    func_dict = func_msg.to_dict()
    assert func_dict["name"] == "test_func"
    assert func_dict["function_call"] is not None

    # Token estimation
    assert msg.estimated_tokens > 0
    assert func_msg.estimated_tokens > msg.estimated_tokens


def test_context():
    """Test context management."""
    config = ContextConfig(max_total_tokens=100, max_messages=3)
    context = Context(config=config)

    # Add messages
    msg1 = Message(role="system", content="System message")
    msg2 = Message(role="user", content="User message")
    msg3 = Message(role="assistant", content="Assistant message")
    msg4 = Message(role="user", content="Another message")

    context.add_message(msg1)
    context.add_message(msg2)
    context.add_message(msg3)
    assert len(context.messages) == 3

    # Test max messages
    context.add_message(msg4)
    assert len(context.messages) == 3
    assert msg4 in context.messages  # Newest message kept

    # Get messages
    all_msgs = context.get_messages()
    assert len(all_msgs) == 3

    no_system = context.get_messages(include_system=False)
    assert len(no_system) == 2
    assert all(m["role"] != "system" for m in no_system)

    # Clear context
    context.clear(keep_system=True)
    assert len(context.messages) == 1
    assert context.messages[0].role == "system"

    context.clear(keep_system=False)
    assert len(context.messages) == 0


def test_context_manager():
    """Test context manager functionality."""
    ContextManager.clear_all()

    # Create context
    context = ContextManager.create_context("test")
    assert ContextManager.get_context("test") == context

    # Duplicate context
    with pytest.raises(ValueError):
        ContextManager.create_context("test")

    # Custom config
    config = ContextConfig(max_total_tokens=200)
    custom = ContextManager.create_context("custom", config)
    assert custom.config.max_total_tokens == 200

    # Remove context
    ContextManager.remove_context("test")
    assert ContextManager.get_context("test") is None

    # Clear all
    ContextManager.clear_all()
    assert ContextManager.get_context("custom") is None
tests/test_llm_templates.py (new file, 178 lines)
@@ -0,0 +1,178 @@
"""Test cases for LLM template management."""

import pytest

from app.llm.templates import PromptTemplate, TemplateRegistry


def test_prompt_template_validation():
    """Test template validation logic."""
    # Valid template
    template = PromptTemplate(
        id="test",
        name="Test Template",
        description="A test template",
        template="Hello {name}!",
        variables=["name"]
    )
    assert not template.validate()

    # Missing variable
    template = PromptTemplate(
        id="test",
        name="Test Template",
        description="A test template",
        template="Hello {name}!",
        variables=["name", "missing"]
    )
    errors = template.validate()
    assert len(errors) == 1
    assert "missing" in errors[0]

    # Empty required context
    template = PromptTemplate(
        id="test",
        name="Test Template",
        description="A test template",
        template="Hello {name}!",
        variables=["name"],
        required_context=["", "name"]
    )
    errors = template.validate()
    assert len(errors) == 1
    assert "Empty required context" in errors[0]

    # Empty validation rule
    template = PromptTemplate(
        id="test",
        name="Test Template",
        description="A test template",
        template="Hello {name}!",
        variables=["name"],
        validation_rules=["len(name) > 0", ""]
    )
    errors = template.validate()
    assert len(errors) == 1
    assert "Empty validation rule" in errors[0]


def test_prompt_template_format():
    """Test template formatting."""
    template = PromptTemplate(
        id="test",
        name="Test Template",
        description="A test template",
        template="Hello {name}!",
        variables=["name"],
        required_context=["name"],
        max_length=10
    )

    # Valid context (truncated to max_length: 7 characters plus "...")
    result = template.format({"name": "World"})
    assert result == "Hello W..."

    # Missing required context
    with pytest.raises(ValueError) as exc:
        template.format({})
    assert "Missing required context" in str(exc.value)

    # Missing template variable: use a template without required_context so the
    # KeyError branch in format() is the one that fires
    loose = PromptTemplate(
        id="loose",
        name="Loose Template",
        description="A template without required_context",
        template="Hello {name}!",
        variables=["name"]
    )
    with pytest.raises(ValueError) as exc:
        loose.format({"wrong": "value"})
    assert "Missing template variable" in str(exc.value)


def test_template_registry():
    """Test template registry operations."""
    TemplateRegistry.clear()

    # Register valid template
    template = PromptTemplate(
        id="test",
        name="Test Template",
        description="A test template",
        template="Hello {name}!",
        variables=["name"]
    )
    TemplateRegistry.register(template)
    assert TemplateRegistry.get("test") == template

    # Register invalid template
    invalid = PromptTemplate(
        id="invalid",
        name="Invalid Template",
        description="An invalid template",
        template="Hello {name}!",
        variables=["wrong"]
    )
    with pytest.raises(ValueError) as exc:
        TemplateRegistry.register(invalid)
    assert "Invalid template" in str(exc.value)

    # List templates
    templates = TemplateRegistry.list()
    assert len(templates) == 1
    assert templates[0].id == "test"

    # Load from JSON
    json_str = '''
    {
        "json_test": {
            "name": "JSON Test",
            "description": "Test template from JSON",
            "template": "Hello {name}!",
            "variables": ["name"]
        }
    }
    '''
    TemplateRegistry.load_from_json(json_str)
    assert TemplateRegistry.get("json_test") is not None

    # Invalid JSON
    with pytest.raises(ValueError) as exc:
        TemplateRegistry.load_from_json("invalid json")
    assert "Invalid JSON" in str(exc.value)

    # Non-object JSON
    with pytest.raises(ValueError) as exc:
        TemplateRegistry.load_from_json("[1, 2, 3]")
    assert "JSON root must be an object" in str(exc.value)


def test_default_templates():
    """Test default template registration."""
    from app.llm.templates import DEFAULT_TEMPLATES

    # Earlier tests clear the registry, and re-importing the module does not
    # re-run its registration side effects, so re-register the defaults here.
    TemplateRegistry.clear()
    for template_id, cfg in DEFAULT_TEMPLATES.items():
        TemplateRegistry.register(PromptTemplate(**{"id": template_id, **cfg}))

    # Verify default templates are loaded
    assert len(TemplateRegistry.list()) > 0

    # Check specific templates
    dept_base = TemplateRegistry.get("department_base")
    assert dept_base is not None
    assert "部门基础模板" in dept_base.name

    momentum = TemplateRegistry.get("momentum_dept")
    assert momentum is not None
    assert "动量研究部门" in momentum.name

    # Validate template content
    assert all("{" + var + "}" in dept_base.template for var in dept_base.variables)
    assert all("{" + var + "}" in momentum.template for var in momentum.variables)

    # Test template usage
    context = {
        "title": "测试部门",
        "ts_code": "000001.SZ",
        "trade_date": "20251005",
        "description": "测试描述",
        "instruction": "测试指令",
        "data_scope": "daily,daily_basic",
        "features": "特征1,特征2",
        "market_snapshot": "市场数据1,市场数据2",
        "supplements": "补充数据"
    }
    result = dept_base.format(context)
    assert "测试部门" in result
    assert "000001.SZ" in result
    assert "20251005" in result
tests/test_sentiment_factors.py (new file, 133 lines)
@@ -0,0 +1,133 @@
"""Tests for sentiment factor computation."""

from __future__ import annotations

from datetime import date, datetime
from typing import Any, Dict, List

import pytest

from app.features.sentiment_factors import SentimentFactors
from app.utils.data_access import DataBroker


class MockDataBroker:
    """Mock DataBroker for testing."""

    def get_news_data(
        self,
        ts_code: str,
        trade_date: str,
        limit: int = 30
    ) -> List[Dict[str, Any]]:
        """Mocked news data."""
        if ts_code == "000001.SZ":
            return [
                {
                    "sentiment": 0.8,
                    "heat": 0.6,
                    "entities": "公司A,行业B,概念C"
                },
                {
                    "sentiment": 0.6,
                    "heat": 0.4,
                    "entities": "公司A,概念D"
                }
            ]
        return []

    def get_stock_data(
        self,
        ts_code: str,
        trade_date: str,
        fields: List[str],
        limit: int = 1
    ) -> List[Dict[str, Any]]:
        """Mocked daily stock data."""
        if ts_code == "000001.SZ":
            return [
                {"daily_basic.volume_ratio": 1.2},
                {"daily_basic.volume_ratio": 1.1}
            ]
        return []

    def _lookup_industry(self, ts_code: str) -> str:
        """Mocked industry lookup."""
        if ts_code == "000001.SZ":
            return "银行"
        return ""

    def _derived_industry_sentiment(
        self,
        industry: str,
        trade_date: str
    ) -> float:
        """Mocked industry sentiment."""
        if industry == "银行":
            return 0.5
        return 0.0

    def get_industry_stocks(self, industry: str) -> List[str]:
        """Mocked industry constituents."""
        if industry == "银行":
            return ["000001.SZ", "600000.SH"]
        return []


def test_compute_stock_factors():
    """Test single-stock sentiment factor computation."""
    calculator = SentimentFactors()
    broker = MockDataBroker()

    # Case with news data
    factors = calculator.compute_stock_factors(
        broker,
        "000001.SZ",
        "20251001"
    )

    assert "sent_momentum" in factors
    assert "sent_impact" in factors
    assert "sent_market" in factors
    assert "sent_divergence" in factors

    assert factors["sent_impact"] > 0

    # Case without news data
    factors = calculator.compute_stock_factors(
        broker,
        "000002.SZ",
        "20251001"
    )

    assert all(v is None for v in factors.values())


def test_compute_batch(tmp_path):
    """Test batch computation and persistence."""
    from app.data.schema import initialize_database
    from app.utils.config import get_config

    # Point the config at a test database
    config = get_config()
    config.db_path = tmp_path / "test.db"

    # Initialize the database
    initialize_database()

    calculator = SentimentFactors()
    broker = MockDataBroker()

    # Run the batch computation
    ts_codes = ["000001.SZ", "000002.SZ", "600000.SH"]
    calculator.compute_batch(broker, ts_codes, "20251001")

    # Verify the data was persisted
    from app.utils.db import db_session
    with db_session() as conn:
        rows = conn.execute(
            "SELECT * FROM factors WHERE trade_date = ?",
            ("20251001",)
        ).fetchall()

        # Only one stock should have data
        assert len(rows) == 1
        assert rows[0]["ts_code"] == "000001.SZ"
tests/test_sentiment_indicators.py (new file, 102 lines)
@@ -0,0 +1,102 @@
"""Tests for market sentiment indicators."""

from __future__ import annotations

import numpy as np
import pytest

from app.core.sentiment import (
    news_sentiment_momentum,
    news_impact_score,
    market_sentiment_index,
    industry_sentiment_divergence
)


def test_news_sentiment_momentum():
    # Generate test data
    window = 20
    uptrend = np.linspace(-0.5, 0.5, window)    # rising trend
    downtrend = np.linspace(0.5, -0.5, window)  # falling trend
    flat = np.zeros(window)                     # flat trend

    # Rising trend
    result = news_sentiment_momentum(uptrend)
    assert result is not None
    assert result > 0

    # Falling trend
    result = news_sentiment_momentum(downtrend)
    assert result is not None
    assert result < 0

    # Flat trend
    result = news_sentiment_momentum(flat)
    assert result is not None
    assert abs(result) < 0.1

    # Not enough data
    result = news_sentiment_momentum(uptrend[:10])
    assert result is None


def test_news_impact_score():
    # Typical case
    score = news_impact_score(sentiment=0.8, heat=0.6, entity_count=3)
    assert 0 <= score <= 1
    assert score > news_impact_score(sentiment=0.4, heat=0.6, entity_count=3)

    # Boundary cases
    assert news_impact_score(sentiment=0, heat=0.5, entity_count=1) == 0
    assert 0 < news_impact_score(sentiment=1, heat=1, entity_count=10) <= 1

    # Effect of entity count
    low_entity = news_impact_score(sentiment=0.5, heat=0.5, entity_count=1)
    high_entity = news_impact_score(sentiment=0.5, heat=0.5, entity_count=5)
    assert high_entity > low_entity


def test_market_sentiment_index():
    window = 20

    # Generate test data
    sentiment_scores = np.random.uniform(-1, 1, window)
    heat_scores = np.random.uniform(0, 1, window)
    volume_ratios = np.random.uniform(0.5, 2, window)

    # Normal computation
    result = market_sentiment_index(
        sentiment_scores,
        heat_scores,
        volume_ratios
    )
    assert result is not None
    assert -1 <= result <= 1

    # Missing data
    result = market_sentiment_index(
        sentiment_scores[:10],
        heat_scores,
        volume_ratios
    )
    assert result is None


def test_industry_sentiment_divergence():
    # Strong divergence
    high_divergence = industry_sentiment_divergence(
        industry_sentiment=0.8,
        peer_sentiments=[-0.2, -0.1, 0, 0.1]
    )
    assert high_divergence is not None
    assert high_divergence > 0

    # Good agreement with peers
    low_divergence = industry_sentiment_divergence(
        industry_sentiment=0.1,
        peer_sentiments=[0, 0.1, 0.2]
    )
    assert low_divergence is not None
    assert abs(low_divergence) < abs(high_divergence)

    # Empty data
    assert industry_sentiment_divergence(0.5, []) is None