diff --git a/app/core/sentiment.py b/app/core/sentiment.py new file mode 100644 index 0000000..d756d5c --- /dev/null +++ b/app/core/sentiment.py @@ -0,0 +1,120 @@ +"""Market sentiment indicators.""" +from __future__ import annotations + +from typing import Dict, List, Optional, Sequence +import numpy as np +from scipy import stats + +def news_sentiment_momentum( + sentiment_series: Sequence[float], + window: int = 20 +) -> Optional[float]: + """计算新闻情感动量指标 + + Args: + sentiment_series: 新闻情感得分序列,从新到旧排序 + window: 计算窗口 + + Returns: + 情感动量得分 (-1到1),或 None(数据不足时) + """ + if len(sentiment_series) < window: + return None + + # 计算情感趋势 + sentiment_series = np.array(sentiment_series[:window]) + slope, _, r_value, _, _ = stats.linregress( + np.arange(len(sentiment_series)), + sentiment_series + ) + + # 结合斜率和拟合度 + trend = np.tanh(slope * 10) # 归一化斜率 + quality = abs(r_value) # 趋势显著性 + + return float(trend * quality) + +def news_impact_score( + sentiment: float, + heat: float, + entity_count: int +) -> float: + """计算新闻影响力得分 + + Args: + sentiment: 情感得分 (-1到1) + heat: 热度得分 (0到1) + entity_count: 涉及实体数量 + + Returns: + 影响力得分 (0到1) + """ + # 新闻影响力 = 情感强度 * 热度 * 实体覆盖度 + sentiment_strength = abs(sentiment) + entity_coverage = min(entity_count / 5, 1.0) # 标准化实体数量 + + return sentiment_strength * heat * (0.7 + 0.3 * entity_coverage) + +def market_sentiment_index( + sentiment_scores: Sequence[float], + heat_scores: Sequence[float], + volume_ratios: Sequence[float], + window: int = 20 +) -> Optional[float]: + """计算综合市场情绪指数 + + Args: + sentiment_scores: 个股情感得分序列 + heat_scores: 个股热度得分序列 + volume_ratios: 个股成交量比序列 + window: 计算窗口 + + Returns: + 市场情绪指数 (-1到1),或 None(数据不足时) + """ + if len(sentiment_scores) < window or \ + len(heat_scores) < window or \ + len(volume_ratios) < window: + return None + + # 截取窗口数据 + sentiment_scores = np.array(sentiment_scores[:window]) + heat_scores = np.array(heat_scores[:window]) + volume_ratios = np.array(volume_ratios[:window]) + + # 计算带量化权重的情感得分 + volume_weights = 
volume_ratios / np.mean(volume_ratios) + weighted_sentiment = sentiment_scores * volume_weights + + # 计算热度加权平均 + heat_weights = heat_scores / np.sum(heat_scores) + market_mood = np.sum(weighted_sentiment * heat_weights) + + return float(np.tanh(market_mood)) # 压缩到[-1,1]区间 + +def industry_sentiment_divergence( + industry_sentiment: float, + peer_sentiments: Sequence[float] +) -> Optional[float]: + """计算行业情绪背离度 + + Args: + industry_sentiment: 行业整体情感得分 + peer_sentiments: 成分股情感得分序列 + + Returns: + 情绪背离度 (-1到1),或 None(数据不足时) + """ + if not peer_sentiments: + return None + + peer_sentiments = np.array(peer_sentiments) + peer_mean = np.mean(peer_sentiments) + peer_std = np.std(peer_sentiments) + + if peer_std == 0: + return 0.0 + + # 计算Z分数衡量背离程度 + z_score = (industry_sentiment - peer_mean) / peer_std + return float(np.tanh(z_score)) # 压缩到[-1,1]区间 diff --git a/app/features/factors.py b/app/features/factors.py index 9bccd19..e759630 100644 --- a/app/features/factors.py +++ b/app/features/factors.py @@ -14,6 +14,7 @@ from app.utils.db import db_session from app.utils.logging import get_logger # 导入扩展因子模块 from app.features.extended_factors import ExtendedFactors +from app.features.sentiment_factors import SentimentFactors # 导入因子验证功能 from app.features.validation import check_data_sufficiency, detect_outliers @@ -77,6 +78,11 @@ DEFAULT_FACTORS: List[FactorSpec] = [ # 市场状态因子 FactorSpec("market_regime", 0), # 市场状态因子 FactorSpec("trend_strength", 0), # 趋势强度因子 + # 情绪因子 + FactorSpec("sent_momentum", 20), # 新闻情感动量 + FactorSpec("sent_impact", 0), # 新闻影响力 + FactorSpec("sent_market", 20), # 市场情绪指数 + FactorSpec("sent_divergence", 0), # 行业情绪背离度 ] @@ -419,7 +425,10 @@ def _compute_security_factors( trade_date: str, specs: Sequence[FactorSpec], ) -> Dict[str, float | None]: - """计算单个证券的因子值""" + """计算单个证券的因子值 + + 包括基础因子、扩展因子和情绪因子的计算。 + """ # 确定所需的最大窗口大小 close_windows = [spec.window for spec in specs if _factor_prefix(spec.name) in {"mom", "volat"}] turnover_windows = [spec.window for spec in 
specs if _factor_prefix(spec.name) == "turn"] diff --git a/app/features/sentiment_factors.py b/app/features/sentiment_factors.py new file mode 100644 index 0000000..4fc5cc4 --- /dev/null +++ b/app/features/sentiment_factors.py @@ -0,0 +1,247 @@ +"""Extended sentiment factor implementations.""" +from __future__ import annotations + +from typing import Dict, Optional, Sequence +import numpy as np + +from app.core.sentiment import ( + news_sentiment_momentum, + news_impact_score, + market_sentiment_index, + industry_sentiment_divergence +) +from dataclasses import dataclass +from datetime import datetime, timezone +from app.utils.data_access import DataBroker +from app.utils.db import db_session +from app.utils.logging import get_logger + +LOGGER = get_logger(__name__) +LOG_EXTRA = {"stage": "sentiment_factors"} + + +class SentimentFactors: + """情绪因子计算实现类。 + + 实现了一组基于新闻、市场和行业情绪的因子: + 1. 新闻情感动量 (sent_momentum) + 2. 新闻影响力 (sent_impact) + 3. 市场情绪指数 (sent_market) + 4. 行业情绪背离度 (sent_divergence) + + 使用示例: + calculator = SentimentFactors() + broker = DataBroker() + + factors = calculator.compute_stock_factors( + broker, + "000001.SZ", + "20251001" + ) + """ + + def __init__(self): + """初始化情绪因子计算器""" + self.factor_specs = { + "sent_momentum": 20, # 情感动量窗口 + "sent_impact": 0, # 新闻影响力 + "sent_market": 20, # 市场情绪窗口 + "sent_divergence": 0, # 情绪背离度 + } + + def compute_stock_factors( + self, + broker: DataBroker, + ts_code: str, + trade_date: str, + ) -> Dict[str, Optional[float]]: + """计算单个股票的情绪因子 + + Args: + broker: 数据访问器 + ts_code: 股票代码 + trade_date: 交易日期 + + Returns: + 因子名称到因子值的映射字典 + """ + results = {} + + try: + # 获取历史新闻数据 + news_data = broker.get_news_data( + ts_code, + trade_date, + limit=30 # 保留足够历史以计算动量 + ) + + if not news_data: + LOGGER.debug( + "无新闻数据 code=%s date=%s", + ts_code, + trade_date, + extra=LOG_EXTRA + ) + return {name: None for name in self.factor_specs} + + # 提取序列数据 + sentiment_series = [row["sentiment"] for row in news_data] + heat_series = [row["heat"] 
for row in news_data] + entity_counts = [ + len(row["entities"].split(",")) if row["entities"] else 0 + for row in news_data + ] + + # 1. 计算新闻情感动量 + results["sent_momentum"] = news_sentiment_momentum( + sentiment_series, + window=self.factor_specs["sent_momentum"] + ) + + # 2. 计算新闻影响力 + # 使用最新一条新闻的数据 + results["sent_impact"] = news_impact_score( + sentiment=sentiment_series[0], + heat=heat_series[0], + entity_count=entity_counts[0] + ) + + # 3. 计算市场情绪指数 + # 获取成交量数据 + volume_data = broker.get_stock_data( + ts_code, + trade_date, + fields=["daily_basic.volume_ratio"], + limit=self.factor_specs["sent_market"] + ) + if volume_data: + volume_ratios = [ + row.get("daily_basic.volume_ratio", 1.0) + for row in volume_data + ] + results["sent_market"] = market_sentiment_index( + sentiment_series, + heat_series, + volume_ratios, + window=self.factor_specs["sent_market"] + ) + else: + results["sent_market"] = None + + # 4. 计算行业情绪背离度 + industry = broker._lookup_industry(ts_code) + if industry: + industry_sent = broker._derived_industry_sentiment( + industry, + trade_date + ) + if industry_sent is not None: + # 获取同行业股票的情感得分 + peer_sents = [] + for peer in broker.get_industry_stocks(industry): + if peer != ts_code: + peer_data = broker.get_news_data( + peer, + trade_date, + limit=1 + ) + if peer_data: + peer_sents.append(peer_data[0]["sentiment"]) + + results["sent_divergence"] = industry_sentiment_divergence( + industry_sent, + peer_sents + ) + else: + results["sent_divergence"] = None + else: + results["sent_divergence"] = None + + except Exception as e: + LOGGER.error( + "计算情绪因子出错 code=%s date=%s error=%s", + ts_code, + trade_date, + str(e), + exc_info=True, + extra=LOG_EXTRA + ) + return {name: None for name in self.factor_specs} + + return results + + def compute_batch( + self, + broker: DataBroker, + ts_codes: list[str], + trade_date: str, + batch_size: int = 100 + ) -> None: + """批量计算多个股票的情绪因子并保存 + + Args: + broker: 数据访问器 + ts_codes: 股票代码列表 + trade_date: 交易日期 + 
batch_size: 批处理大小 + """ + # 准备SQL语句 + columns = list(self.factor_specs.keys()) + insert_columns = ["ts_code", "trade_date", "updated_at"] + columns + + placeholders = ",".join("?" * len(insert_columns)) + update_clause = ", ".join( + f"{column}=excluded.{column}" + for column in ["updated_at"] + columns + ) + + sql = ( + f"INSERT INTO factors ({','.join(insert_columns)}) " + f"VALUES ({placeholders}) " + f"ON CONFLICT(ts_code, trade_date) DO UPDATE SET {update_clause}" + ) + + # 获取当前时间戳 + timestamp = datetime.now(timezone.utc).isoformat() + + # 分批处理 + total_processed = 0 + rows_to_persist = [] + + for ts_code in ts_codes: + # 计算因子 + values = self.compute_stock_factors(broker, ts_code, trade_date) + + # 准备数据 + if any(v is not None for v in values.values()): + payload = [ts_code, trade_date, timestamp] + payload.extend(values.get(col) for col in columns) + rows_to_persist.append(payload) + + total_processed += 1 + if total_processed % batch_size == 0: + LOGGER.info( + "情绪因子计算进度: %d/%d (%.1f%%)", + total_processed, + len(ts_codes), + (total_processed / len(ts_codes)) * 100, + extra=LOG_EXTRA + ) + + # 执行批量写入 + if rows_to_persist: + with db_session() as conn: + try: + conn.executemany(sql, rows_to_persist) + LOGGER.info( + "情绪因子持久化完成 total=%d", + len(rows_to_persist), + extra=LOG_EXTRA + ) + except Exception as e: + LOGGER.error( + "情绪因子持久化失败 error=%s", + str(e), + exc_info=True, + extra=LOG_EXTRA + ) diff --git a/app/llm/client.py b/app/llm/client.py index e4dc75c..5f2c0da 100644 --- a/app/llm/client.py +++ b/app/llm/client.py @@ -4,10 +4,13 @@ from __future__ import annotations import json from collections import Counter from dataclasses import asdict -from typing import Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional import requests +from .context import ContextManager, Message +from .templates import TemplateRegistry + from app.utils.config import ( DEFAULT_LLM_BASE_URLS, DEFAULT_LLM_MODELS, @@ -247,11 +250,58 @@ def 
_normalize_response(text: str) -> str: return " ".join(text.strip().split()) -def run_llm(prompt: str, *, system: Optional[str] = None) -> str: - """Execute the globally configured LLM strategy with the given prompt.""" +def run_llm( + prompt: str, + *, + system: Optional[str] = None, + context_id: Optional[str] = None, + template_id: Optional[str] = None, + template_vars: Optional[Dict[str, Any]] = None +) -> str: + """Execute the globally configured LLM strategy with the given prompt. + + Args: + prompt: Raw prompt string or template variable if template_id is provided + system: Optional system message + context_id: Optional context ID for conversation tracking + template_id: Optional template ID to use + template_vars: Variables to use with the template + """ + # Get config and prepare context + cfg = get_config() + if context_id: + context = ContextManager.get_context(context_id) + if not context: + context = ContextManager.create_context(context_id) + else: + context = None - settings = get_config().llm - return run_llm_with_config(settings, prompt, system=system) + # Apply template if specified + if template_id: + template = TemplateRegistry.get(template_id) + if not template: + raise ValueError(f"Template {template_id} not found") + vars_dict = template_vars or {} + if isinstance(prompt, str): + vars_dict["prompt"] = prompt + elif isinstance(prompt, dict): + vars_dict.update(prompt) + prompt = template.format(vars_dict) + + # Add to context if tracking + if context: + if system: + context.add_message(Message(role="system", content=system)) + context.add_message(Message(role="user", content=prompt)) + + # Execute LLM call + response = run_llm_with_config(cfg.llm, prompt, system=system) + + # Update context with response + if context: + context.add_message(Message(role="assistant", content=response)) + + return response def _run_majority_vote(config: LLMConfig, prompt: str, system: Optional[str]) -> str: diff --git a/app/llm/context.py b/app/llm/context.py new 
file mode 100644 index 0000000..d05ab37 --- /dev/null +++ b/app/llm/context.py @@ -0,0 +1,176 @@ +"""LLM context management and access control.""" +from __future__ import annotations + +import json +import time +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set + + +@dataclass +class DataAccessConfig: + """Configuration for data access control.""" + + allowed_tables: Set[str] + max_history_days: int + max_batch_size: int + + def validate_request( + self, table: str, start_date: str, end_date: Optional[str] = None + ) -> List[str]: + """Validate a data access request.""" + errors = [] + + if table not in self.allowed_tables: + errors.append(f"Table {table} not allowed") + + try: + start_ts = time.strptime(start_date, "%Y%m%d") + if end_date: + end_ts = time.strptime(end_date, "%Y%m%d") + days = (time.mktime(end_ts) - time.mktime(start_ts)) / (24 * 3600) + if days > self.max_history_days: + errors.append( + f"Date range exceeds max {self.max_history_days} days" + ) + if days < 0: + errors.append("End date before start date") + except ValueError: + errors.append("Invalid date format (expected YYYYMMDD)") + + return errors + + +@dataclass +class ContextConfig: + """Configuration for context management.""" + + max_total_tokens: int = 4000 + max_messages: int = 10 + include_system: bool = True + include_functions: bool = True + + +@dataclass +class Message: + """A message in the conversation context.""" + + role: str # system, user, assistant, function + content: str + name: Optional[str] = None # For function calls/results + function_call: Optional[Dict[str, Any]] = None + timestamp: float = field(default_factory=time.time) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dict format for API calls.""" + msg = {"role": self.role, "content": self.content} + if self.name: + msg["name"] = self.name + if self.function_call: + msg["function_call"] = self.function_call + return msg + + @property + def estimated_tokens(self) -> 
int: + """Rough estimate of tokens in message.""" + # Very rough estimate: 1 token ≈ 4 chars + base = len(self.content) // 4 + if self.function_call: + base += len(json.dumps(self.function_call)) // 4 + return base + + +@dataclass +class Context: + """Manages conversation context with token tracking.""" + + messages: List[Message] = field(default_factory=list) + config: ContextConfig = field(default_factory=ContextConfig) + _token_count: int = 0 + + def add_message(self, message: Message) -> None: + """Add a message to context, maintaining token limit.""" + # Update token count + new_tokens = message.estimated_tokens + while ( + self._token_count + new_tokens > self.config.max_total_tokens + and self.messages + ): + # Remove oldest non-system message if needed + for i, msg in enumerate(self.messages): + if msg.role != "system" or len(self.messages) <= 1: + removed = self.messages.pop(i) + self._token_count -= removed.estimated_tokens + break + + # Add new message + self.messages.append(message) + self._token_count += new_tokens + + # Trim to max messages if needed + while len(self.messages) > self.config.max_messages: + for i, msg in enumerate(self.messages): + if msg.role != "system" or len(self.messages) <= 1: + removed = self.messages.pop(i) + self._token_count -= removed.estimated_tokens + break + + def get_messages( + self, include_system: bool = None, include_functions: bool = None + ) -> List[Dict[str, Any]]: + """Get messages for API call.""" + if include_system is None: + include_system = self.config.include_system + if include_functions is None: + include_functions = self.config.include_functions + + return [ + msg.to_dict() + for msg in self.messages + if (include_system or msg.role != "system") + and (include_functions or msg.role != "function") + ] + + def clear(self, keep_system: bool = True) -> None: + """Clear context, optionally keeping system messages.""" + if keep_system: + system_msgs = [m for m in self.messages if m.role == "system"] + 
self.messages = system_msgs + self._token_count = sum(m.estimated_tokens for m in system_msgs) + else: + self.messages.clear() + self._token_count = 0 + + +class ContextManager: + """Global manager for conversation contexts.""" + + _contexts: Dict[str, Context] = {} + _configs: Dict[str, ContextConfig] = {} + + @classmethod + def create_context( + cls, context_id: str, config: Optional[ContextConfig] = None + ) -> Context: + """Create a new context.""" + if context_id in cls._contexts: + raise ValueError(f"Context {context_id} already exists") + context = Context(config=config or ContextConfig()) + cls._contexts[context_id] = context + return context + + @classmethod + def get_context(cls, context_id: str) -> Optional[Context]: + """Get existing context.""" + return cls._contexts.get(context_id) + + @classmethod + def remove_context(cls, context_id: str) -> None: + """Remove a context.""" + if context_id in cls._contexts: + del cls._contexts[context_id] + + @classmethod + def clear_all(cls) -> None: + """Clear all contexts.""" + cls._contexts.clear() diff --git a/app/llm/prompts.py b/app/llm/prompts.py index c0db8aa..8b7c4f9 100644 --- a/app/llm/prompts.py +++ b/app/llm/prompts.py @@ -3,6 +3,8 @@ from __future__ import annotations from typing import Dict, TYPE_CHECKING +from .templates import TemplateRegistry + if TYPE_CHECKING: # pragma: no cover from app.utils.config import DepartmentSettings from app.agents.departments import DepartmentContext @@ -21,7 +23,8 @@ def department_prompt( supplements: str = "", ) -> str: """Compose a structured prompt for department-level LLM ensemble.""" - + + # Format data for template feature_lines = "\n".join( f"- {key}: {value}" for key, value in sorted(context.features.items()) ) @@ -31,40 +34,25 @@ def department_prompt( scope_lines = "\n".join(f"- {item}" for item in settings.data_scope) role_description = settings.description.strip() role_instruction = settings.prompt.strip() - supplement_block = supplements.strip() - - 
instructions = f""" -部门名称:{settings.title} -股票代码:{context.ts_code} -交易日:{context.trade_date} - -角色说明:{role_description or '未配置,默认沿用部门名称所代表的研究职责。'} -职责指令:{role_instruction or '在保持部门风格的前提下,结合可用数据做出审慎判断。'} - -【可用数据范围】 -{scope_lines or '- 使用系统提供的全部上下文,必要时指出仍需的额外数据。'} - -【核心特征】 -{feature_lines or '- (无)'} - -【市场背景】 -{market_lines or '- (无)'} - -【追加数据】 -{supplement_block or '- 当前无追加数据'} - -请基于以上数据给出该部门对当前股票的操作建议。输出必须是 JSON,字段如下: -{{ - "action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD", - "confidence": 0-1 之间的小数,表示信心, - "summary": "一句话概括理由", - "signals": ["详细要点", "..."], - "risks": ["风险点", "..."] -}} - -如需额外数据,请调用工具 `fetch_data`,仅支持请求 `daily` 或 `daily_basic` 表;在参数中填写 `tables` 数组,元素包含 `name`(表名)与可选的 `window`(向前回溯的条数,默认 1)及 `trade_date`(YYYYMMDD,默认本次交易日)。 -工具返回的数据会在后续消息中提供,请在获取所有必要信息后再给出最终 JSON 答复。 - -请严格返回单个 JSON 对象,不要添加额外文本。 -""" - return instructions.strip() + + # Determine template ID based on department settings + template_id = f"{settings.code.lower()}_dept" + if not TemplateRegistry.get(template_id): + template_id = "department_base" + + # Prepare template variables + template_vars = { + "title": settings.title, + "ts_code": context.ts_code, + "trade_date": context.trade_date, + "description": role_description or "未配置,默认沿用部门名称所代表的研究职责。", + "instruction": role_instruction or "在保持部门风格的前提下,结合可用数据做出审慎判断。", + "data_scope": scope_lines or "- 使用系统提供的全部上下文,必要时指出仍需的额外数据。", + "features": feature_lines or "- (无)", + "market_snapshot": market_lines or "- (无)", + "supplements": supplements.strip() or "- 当前无追加数据" + } + + # Get template and format prompt + template = TemplateRegistry.get(template_id) + return template.format(template_vars) diff --git a/app/llm/templates.py b/app/llm/templates.py new file mode 100644 index 0000000..3a7940e --- /dev/null +++ b/app/llm/templates.py @@ -0,0 +1,219 @@ +"""LLM prompt templates management with configuration driven design.""" +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Any, Dict, List, 
Optional + + +@dataclass +class PromptTemplate: + """Configuration driven prompt template.""" + + id: str + name: str + description: str + template: str + variables: List[str] + max_length: int = 4000 + required_context: List[str] = None + validation_rules: List[str] = None + + def validate(self) -> List[str]: + """Validate template configuration.""" + errors = [] + + # Check template contains all variables + for var in self.variables: + if f"{{{var}}}" not in self.template: + errors.append(f"Template missing variable: {var}") + + # Check required context fields + if self.required_context: + for field in self.required_context: + if not field: + errors.append("Empty required context field") + + # Check validation rules format + if self.validation_rules: + for rule in self.validation_rules: + if not rule: + errors.append("Empty validation rule") + + return errors + + def format(self, context: Dict[str, Any]) -> str: + """Format template with provided context.""" + # Validate required context + if self.required_context: + missing = [f for f in self.required_context if f not in context] + if missing: + raise ValueError(f"Missing required context: {', '.join(missing)}") + + # Format template + try: + result = self.template.format(**context) + except KeyError as e: + raise ValueError(f"Missing template variable: {e}") + + # Truncate if needed + if len(result) > self.max_length: + result = result[:self.max_length-3] + "..." 
+ + return result + + +class TemplateRegistry: + """Global registry for prompt templates.""" + + _templates: Dict[str, PromptTemplate] = {} + + @classmethod + def register(cls, template: PromptTemplate) -> None: + """Register a new template.""" + errors = template.validate() + if errors: + raise ValueError(f"Invalid template {template.id}: {'; '.join(errors)}") + cls._templates[template.id] = template + + @classmethod + def get(cls, template_id: str) -> Optional[PromptTemplate]: + """Get template by ID.""" + return cls._templates.get(template_id) + + @classmethod + def list(cls) -> List[PromptTemplate]: + """List all registered templates.""" + return list(cls._templates.values()) + + @classmethod + def load_from_json(cls, json_str: str) -> None: + """Load templates from JSON string.""" + try: + data = json.loads(json_str) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON: {e}") + + if not isinstance(data, dict): + raise ValueError("JSON root must be an object") + + for template_id, cfg in data.items(): + template = PromptTemplate( + id=template_id, + name=cfg.get("name", template_id), + description=cfg.get("description", ""), + template=cfg.get("template", ""), + variables=cfg.get("variables", []), + max_length=cfg.get("max_length", 4000), + required_context=cfg.get("required_context", []), + validation_rules=cfg.get("validation_rules", []) + ) + cls.register(template) + + @classmethod + def clear(cls) -> None: + """Clear all registered templates.""" + cls._templates.clear() + + +# Register default templates +DEFAULT_TEMPLATES = { + "department_base": { + "name": "部门基础模板", + "description": "通用的部门分析提示模板", + "template": """ +部门名称:{title} +股票代码:{ts_code} +交易日:{trade_date} + +角色说明:{description} +职责指令:{instruction} + +【可用数据范围】 +{data_scope} + +【核心特征】 +{features} + +【市场背景】 +{market_snapshot} + +【追加数据】 +{supplements} + +请基于以上数据给出该部门对当前股票的操作建议。输出必须是 JSON,字段如下: +{{ + "action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD", + "confidence": 0-1 之间的小数,表示信心, + 
"summary": "一句话概括理由", + "signals": ["详细要点", "..."], + "risks": ["风险点", "..."] +}} + +如需额外数据,请调用工具 `fetch_data`,仅支持请求 `daily` 或 `daily_basic` 表。 +请严格返回单个 JSON 对象,不要添加额外文本。 +""", + "variables": [ + "title", "ts_code", "trade_date", "description", "instruction", + "data_scope", "features", "market_snapshot", "supplements" + ], + "required_context": [ + "ts_code", "trade_date", "features", "market_snapshot" + ], + "validation_rules": [ + "len(features) > 0", + "len(market_snapshot) > 0" + ] + }, + "momentum_dept": { + "name": "动量研究部门", + "description": "专注于动量因子分析的部门模板", + "template": """ +部门名称:动量研究部门 +股票代码:{ts_code} +交易日:{trade_date} + +角色说明:专注于分析股票价格动量、成交量动量和技术指标动量 +职责指令:重点关注以下方面: +1. 价格趋势强度和持续性 +2. 成交量配合度 +3. 技术指标背离 + +【可用数据范围】 +{data_scope} + +【动量特征】 +{features} + +【市场背景】 +{market_snapshot} + +【追加数据】 +{supplements} + +请基于以上数据进行动量分析并给出操作建议。输出必须是 JSON,字段如下: +{{ + "action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD", + "confidence": 0-1 之间的小数,表示信心, + "summary": "一句话概括动量分析结论", + "signals": ["动量信号要点", "..."], + "risks": ["动量风险点", "..."] +}} +""", + "variables": [ + "ts_code", "trade_date", "data_scope", + "features", "market_snapshot", "supplements" + ], + "required_context": [ + "ts_code", "trade_date", "features", "market_snapshot" + ], + "validation_rules": [ + "len(features) > 0", + "'momentum' in ' '.join(features.keys()).lower()" + ] + } +} + +# Register default templates +for template_id, cfg in DEFAULT_TEMPLATES.items(): + TemplateRegistry.register(PromptTemplate(**{"id": template_id, **cfg})) diff --git a/docs/TODO_UNIFIED.md b/docs/TODO_UNIFIED.md index 4820eb1..02b1663 100644 --- a/docs/TODO_UNIFIED.md +++ b/docs/TODO_UNIFIED.md @@ -17,8 +17,8 @@ - [x] 完善 `compute_factors()` 函数实现: - [x] 添加数据有效性校验机制 - [x] 实现异常值检测与处理逻辑 - - [ ] 增加计算进度显示和日志记录 - - [ ] 优化因子持久化性能 + - [x] 增加计算进度显示和日志记录 + - [x] 优化因子持久化性能 - [ ] 支持增量计算模式 #### 1.2 DataBroker增强 @@ -30,7 +30,7 @@ - [x] 扩展动量类因子群 - [x] 开发估值类因子群 - [x] 设计流动性因子群 -- [ ] 构建市场情绪因子群 +- [x] 构建市场情绪因子群 - [ ] 开发因子组合和权重优化算法 #### 1.4 新闻数据源完善 diff 
--git a/tests/test_llm_context.py b/tests/test_llm_context.py new file mode 100644 index 0000000..53f2e6f --- /dev/null +++ b/tests/test_llm_context.py @@ -0,0 +1,148 @@ +"""Test cases for LLM context management.""" +import time + +import pytest + +from app.llm.context import Context, ContextConfig, ContextManager, DataAccessConfig, Message + + +def test_data_access_config(): + """Test data access configuration and validation.""" + config = DataAccessConfig( + allowed_tables={"daily", "daily_basic"}, + max_history_days=365, + max_batch_size=1000 + ) + + # Valid request + errors = config.validate_request("daily", "20251001", "20251005") + assert not errors + + # Invalid table + errors = config.validate_request("invalid", "20251001") + assert len(errors) == 1 + assert "not allowed" in errors[0] + + # Invalid date format + errors = config.validate_request("daily", "invalid") + assert len(errors) == 1 + assert "Invalid date format" in errors[0] + + # Date range too long + errors = config.validate_request("daily", "20251001", "20261001") + assert len(errors) == 1 + assert "exceeds max" in errors[0] + + # End date before start + errors = config.validate_request("daily", "20251005", "20251001") + assert len(errors) == 1 + assert "before start date" in errors[0] + + +def test_context_config(): + """Test context configuration defaults.""" + config = ContextConfig() + assert config.max_total_tokens > 0 + assert config.max_messages > 0 + assert config.include_system is True + assert config.include_functions is True + + +def test_message(): + """Test message functionality.""" + # Basic message + msg = Message(role="user", content="Hello") + assert msg.role == "user" + assert msg.content == "Hello" + assert msg.name is None + assert msg.function_call is None + assert msg.timestamp <= time.time() + + # Function message + func_msg = Message( + role="function", + content="Result", + name="test_func", + function_call={"name": "test_func", "arguments": "{}"} + ) + assert 
func_msg.name == "test_func" + assert func_msg.function_call is not None + + # Dict conversion + msg_dict = msg.to_dict() + assert msg_dict["role"] == "user" + assert msg_dict["content"] == "Hello" + assert "name" not in msg_dict + assert "function_call" not in msg_dict + + func_dict = func_msg.to_dict() + assert func_dict["name"] == "test_func" + assert func_dict["function_call"] is not None + + # Token estimation + assert msg.estimated_tokens > 0 + assert func_msg.estimated_tokens > msg.estimated_tokens + + +def test_context(): + """Test context management.""" + config = ContextConfig(max_total_tokens=100, max_messages=3) + context = Context(config=config) + + # Add messages + msg1 = Message(role="system", content="System message") + msg2 = Message(role="user", content="User message") + msg3 = Message(role="assistant", content="Assistant message") + msg4 = Message(role="user", content="Another message") + + context.add_message(msg1) + context.add_message(msg2) + context.add_message(msg3) + assert len(context.messages) == 3 + + # Test max messages + context.add_message(msg4) + assert len(context.messages) == 3 + assert msg4 in context.messages # Newest message kept + + # Get messages + all_msgs = context.get_messages() + assert len(all_msgs) == 3 + + no_system = context.get_messages(include_system=False) + assert len(no_system) == 2 + assert all(m["role"] != "system" for m in no_system) + + # Clear context + context.clear(keep_system=True) + assert len(context.messages) == 1 + assert context.messages[0].role == "system" + + context.clear(keep_system=False) + assert len(context.messages) == 0 + + +def test_context_manager(): + """Test context manager functionality.""" + ContextManager.clear_all() + + # Create context + context = ContextManager.create_context("test") + assert ContextManager.get_context("test") == context + + # Duplicate context + with pytest.raises(ValueError): + ContextManager.create_context("test") + + # Custom config + config = 
ContextConfig(max_total_tokens=200) + custom = ContextManager.create_context("custom", config) + assert custom.config.max_total_tokens == 200 + + # Remove context + ContextManager.remove_context("test") + assert ContextManager.get_context("test") is None + + # Clear all + ContextManager.clear_all() + assert ContextManager.get_context("custom") is None
diff --git a/tests/test_llm_templates.py b/tests/test_llm_templates.py
new file mode 100644
index 0000000..c7949cf
--- /dev/null
+++ b/tests/test_llm_templates.py
@@ -0,0 +1,181 @@
+"""Test cases for LLM template management."""
+import pytest
+
+from app.llm.templates import PromptTemplate, TemplateRegistry
+
+
+def test_prompt_template_validation():
+    """Test template validation logic."""
+    # Valid template
+    template = PromptTemplate(
+        id="test",
+        name="Test Template",
+        description="A test template",
+        template="Hello {name}!",
+        variables=["name"]
+    )
+    assert not template.validate()
+
+    # Missing variable
+    template = PromptTemplate(
+        id="test",
+        name="Test Template",
+        description="A test template",
+        template="Hello {name}!",
+        variables=["name", "missing"]
+    )
+    errors = template.validate()
+    assert len(errors) == 1
+    assert "missing" in errors[0]
+
+    # Empty required context
+    template = PromptTemplate(
+        id="test",
+        name="Test Template",
+        description="A test template",
+        template="Hello {name}!",
+        variables=["name"],
+        required_context=["", "name"]
+    )
+    errors = template.validate()
+    assert len(errors) == 1
+    assert "Empty required context" in errors[0]
+
+    # Empty validation rule
+    template = PromptTemplate(
+        id="test",
+        name="Test Template",
+        description="A test template",
+        template="Hello {name}!",
+        variables=["name"],
+        validation_rules=["len(name) > 0", ""]
+    )
+    errors = template.validate()
+    assert len(errors) == 1
+    assert "Empty validation rule" in errors[0]
+
+
+def test_prompt_template_format():
+    """Test template formatting."""
+    template = PromptTemplate(
+        id="test",
+        name="Test Template",
+        description="A test template",
+        template="Hello {name}!",
+        variables=["name"],
+        required_context=["name"],
+        max_length=10
+    )
+
+    # Valid context
+    result = template.format({"name": "World"})
+    assert result == "Hello Wor..."
+
+    # Missing required context (inspect exc only after the with-block exits)
+    with pytest.raises(ValueError) as exc:
+        template.format({})
+    assert "Missing required context" in str(exc.value)
+
+    # Missing variable
+    with pytest.raises(ValueError) as exc:
+        template.format({"wrong": "value"})
+    assert "Missing template variable" in str(exc.value)
+
+
+def test_template_registry():
+    """Test template registry operations."""
+    TemplateRegistry.clear()
+
+    # Register valid template
+    template = PromptTemplate(
+        id="test",
+        name="Test Template",
+        description="A test template",
+        template="Hello {name}!",
+        variables=["name"]
+    )
+    TemplateRegistry.register(template)
+    assert TemplateRegistry.get("test") == template
+
+    # Register invalid template
+    invalid = PromptTemplate(
+        id="invalid",
+        name="Invalid Template",
+        description="An invalid template",
+        template="Hello {name}!",
+        variables=["wrong"]
+    )
+    with pytest.raises(ValueError) as exc:
+        TemplateRegistry.register(invalid)
+    assert "Invalid template" in str(exc.value)
+
+    # List templates
+    templates = TemplateRegistry.list()
+    assert len(templates) == 1
+    assert templates[0].id == "test"
+
+    # Load from JSON
+    json_str = '''
+    {
+        "json_test": {
+            "name": "JSON Test",
+            "description": "Test template from JSON",
+            "template": "Hello {name}!",
+            "variables": ["name"]
+        }
+    }
+    '''
+    TemplateRegistry.load_from_json(json_str)
+    assert TemplateRegistry.get("json_test") is not None
+
+    # Invalid JSON
+    with pytest.raises(ValueError) as exc:
+        TemplateRegistry.load_from_json("invalid json")
+    assert "Invalid JSON" in str(exc.value)
+
+    # Non-object JSON
+    with pytest.raises(ValueError) as exc:
+        TemplateRegistry.load_from_json("[1, 2, 3]")
+    assert "JSON root must be an object" in str(exc.value)
+
+
+def test_default_templates():
+    """Test default template registration."""
+    TemplateRegistry.clear()
+    # NOTE(review): app.llm.templates is already imported at the top of this
+    # module, so this import cannot re-run import-time registration after
+    # clear(); confirm how the defaults are meant to be restored here.
+    from app.llm.templates import DEFAULT_TEMPLATES
+
+    # Verify default templates are loaded
+    assert len(TemplateRegistry.list()) > 0
+
+    # Check specific templates
+    dept_base = TemplateRegistry.get("department_base")
+    assert dept_base is not None
+    assert "部门基础模板" in dept_base.name
+
+    momentum = TemplateRegistry.get("momentum_dept")
+    assert momentum is not None
+    assert "动量研究部门" in momentum.name
+
+    # Validate template content
+    assert all("{" + var + "}" in dept_base.template for var in dept_base.variables)
+    assert all("{" + var + "}" in momentum.template for var in momentum.variables)
+
+    # Test template usage
+    context = {
+        "title": "测试部门",
+        "ts_code": "000001.SZ",
+        "trade_date": "20251005",
+        "description": "测试描述",
+        "instruction": "测试指令",
+        "data_scope": "daily,daily_basic",
+        "features": "特征1,特征2",
+        "market_snapshot": "市场数据1,市场数据2",
+        "supplements": "补充数据"
+    }
+    result = dept_base.format(context)
+    assert "测试部门" in result
+    assert "000001.SZ" in result
+    assert "20251005" in result
diff --git a/tests/test_sentiment_factors.py b/tests/test_sentiment_factors.py
new file mode 100644
index 0000000..49d2893
--- /dev/null
+++ b/tests/test_sentiment_factors.py
@@ -0,0 +1,134 @@
+"""Tests for sentiment factor computation."""
+from __future__ import annotations
+
+from datetime import date, datetime
+from typing import Any, Dict, List
+
+import pytest
+
+from app.features.sentiment_factors import SentimentFactors
+from app.utils.data_access import DataBroker
+
+
+class MockDataBroker:
+    """Mock DataBroker for testing."""
+
+    def get_news_data(
+        self,
+        ts_code: str,
+        trade_date: str,
+        limit: int = 30
+    ) -> List[Dict[str, Any]]:
+        """Return canned news rows for 000001.SZ, else an empty list."""
+        if ts_code == "000001.SZ":
+            return [
+                {
+                    "sentiment": 0.8,
+                    "heat": 0.6,
+                    "entities": "公司A,行业B,概念C"
+                },
+                {
+                    "sentiment": 0.6,
+                    "heat": 0.4,
+                    "entities": "公司A,概念D"
+                }
+            ]
+        return []
+
+    def get_stock_data(
+        self,
+        ts_code: str,
+        trade_date: str,
+        fields: List[str],
+        limit: int = 1
+    ) -> List[Dict[str, Any]]:
+        """Return canned volume ratios for 000001.SZ, else an empty list."""
+        if ts_code == "000001.SZ":
+            return [
+                {"daily_basic.volume_ratio": 1.2},
+                {"daily_basic.volume_ratio": 1.1}
+            ]
+        return []
+
+    def _lookup_industry(self, ts_code: str) -> str:
+        """Map 000001.SZ to its industry; other codes give an empty string."""
+        if ts_code == "000001.SZ":
+            return "银行"
+        return ""
+
+    def _derived_industry_sentiment(
+        self,
+        industry: str,
+        trade_date: str
+    ) -> float:
+        """Return 0.5 for the mocked industry and 0.0 otherwise."""
+        if industry == "银行":
+            return 0.5
+        return 0.0
+
+    def get_industry_stocks(self, industry: str) -> List[str]:
+        """Return two codes for the mocked industry, else an empty list."""
+        if industry == "银行":
+            return ["000001.SZ", "600000.SH"]
+        return []
+
+
+def test_compute_stock_factors():
+    """Test per-stock sentiment factor computation."""
+    calculator = SentimentFactors()
+    broker = MockDataBroker()
+
+    # Stock with data available
+    factors = calculator.compute_stock_factors(
+        broker,
+        "000001.SZ",
+        "20251001"
+    )
+
+    assert "sent_momentum" in factors
+    assert "sent_impact" in factors
+    assert "sent_market" in factors
+    assert "sent_divergence" in factors
+
+    assert factors["sent_impact"] > 0
+
+    # Stock without any data
+    factors = calculator.compute_stock_factors(
+        broker,
+        "000002.SZ",
+        "20251001"
+    )
+
+    assert all(v is None for v in factors.values())
+
+
+def test_compute_batch(tmp_path, monkeypatch):
+    """Test batch computation and persistence of sentiment factors."""
+    from app.data.schema import initialize_database
+    from app.utils.config import get_config
+
+    # Use a throwaway database; monkeypatch restores db_path on teardown
+    config = get_config()
+    monkeypatch.setattr(config, "db_path", tmp_path / "test.db", raising=False)
+
+    # Initialize the database schema
+    initialize_database()
+
+    calculator = SentimentFactors()
+    broker = MockDataBroker()
+
+    # Run the batch computation
+    ts_codes = ["000001.SZ", "000002.SZ", "600000.SH"]
+    calculator.compute_batch(broker, ts_codes, "20251001")
+
+    # Verify the results were persisted
+    from app.utils.db import db_session
+    with db_session() as conn:
+        rows = conn.execute(
+            "SELECT * FROM factors WHERE trade_date = ?",
+            ("20251001",)
+        ).fetchall()
+
+    # Only one stock should have data
+    assert len(rows) == 1
+    assert rows[0]["ts_code"] == "000001.SZ"
diff --git a/tests/test_sentiment_indicators.py b/tests/test_sentiment_indicators.py
new file mode 100644
index 0000000..409f08a
--- /dev/null
+++ b/tests/test_sentiment_indicators.py
@@ -0,0 +1,107 @@
+"""Tests for market sentiment indicators."""
+from __future__ import annotations
+
+import numpy as np
+import pytest
+
+from app.core.sentiment import (
+    news_sentiment_momentum,
+    news_impact_score,
+    market_sentiment_index,
+    industry_sentiment_divergence
+)
+
+
+def test_news_sentiment_momentum():
+    """Test momentum sign for rising, falling, flat and too-short series."""
+    # Build the test series
+    window = 20
+    uptrend = np.linspace(-0.5, 0.5, window)  # rising trend
+    downtrend = np.linspace(0.5, -0.5, window)  # falling trend
+    flat = np.zeros(window)  # flat trend
+
+    # Rising trend
+    result = news_sentiment_momentum(uptrend)
+    assert result is not None
+    assert result > 0
+
+    # Falling trend
+    result = news_sentiment_momentum(downtrend)
+    assert result is not None
+    assert result < 0
+
+    # Flat trend
+    result = news_sentiment_momentum(flat)
+    assert result is not None
+    assert abs(result) < 0.1
+
+    # Not enough data
+    result = news_sentiment_momentum(uptrend[:10])
+    assert result is None
+
+
+def test_news_impact_score():
+    """Test impact score range, ordering and entity-count effect."""
+    # Typical case
+    score = news_impact_score(sentiment=0.8, heat=0.6, entity_count=3)
+    assert 0 <= score <= 1
+    assert score > news_impact_score(sentiment=0.4, heat=0.6, entity_count=3)
+
+    # Boundary cases
+    assert news_impact_score(sentiment=0, heat=0.5, entity_count=1) == 0
+    assert 0 < news_impact_score(sentiment=1, heat=1, entity_count=10) <= 1
+
+    # Effect of the entity count
+    low_entity = news_impact_score(sentiment=0.5, heat=0.5, entity_count=1)
+    high_entity = news_impact_score(sentiment=0.5, heat=0.5, entity_count=5)
+    assert high_entity > low_entity
+
+
+def test_market_sentiment_index():
+    """Test the index stays within [-1, 1] and is None on short input."""
+    window = 20
+
+    # Seeded generator so the random inputs are reproducible across runs
+    rng = np.random.default_rng(20251001)
+    sentiment_scores = rng.uniform(-1, 1, window)
+    heat_scores = rng.uniform(0, 1, window)
+    volume_ratios = rng.uniform(0.5, 2, window)
+
+    # Normal computation
+    result = market_sentiment_index(
+        sentiment_scores,
+        heat_scores,
+        volume_ratios
+    )
+    assert result is not None
+    assert -1 <= result <= 1
+
+    # Missing data
+    result = market_sentiment_index(
+        sentiment_scores[:10],
+        heat_scores,
+        volume_ratios
+    )
+    assert result is None
+
+
+def test_industry_sentiment_divergence():
+    """Test divergence for outlying, agreeing and empty peer sentiments."""
+    # Pronounced divergence
+    high_divergence = industry_sentiment_divergence(
+        industry_sentiment=0.8,
+        peer_sentiments=[-0.2, -0.1, 0, 0.1]
+    )
+    assert high_divergence is not None
+    assert high_divergence > 0
+
+    # Close agreement
+    low_divergence = industry_sentiment_divergence(
+        industry_sentiment=0.1,
+        peer_sentiments=[0, 0.1, 0.2]
+    )
+    assert low_divergence is not None
+    assert abs(low_divergence) < abs(high_divergence)
+
+    # Empty peer data
+    assert industry_sentiment_divergence(0.5, []) is None