llm-quant/tests/test_sentiment_factors.py
2025-10-05 16:28:53 +08:00

134 lines
3.4 KiB
Python

"""Tests for sentiment factor computation."""
from __future__ import annotations
from datetime import date, datetime
from typing import Any, Dict, List
import pytest
from app.features.sentiment_factors import SentimentFactors
from app.utils.data_access import DataBroker
class MockDataBroker:
    """In-memory stand-in for DataBroker that serves canned data.

    Only the stock code "000001.SZ" and the industry "银行" have data;
    every other input yields an empty or zero result.
    """

    def get_news_data(
        self, ts_code: str, trade_date: str, limit: int = 30
    ) -> List[Dict[str, Any]]:
        """Return two canned news records for "000001.SZ", else an empty list.

        ``trade_date`` and ``limit`` are accepted but not used.
        """
        if ts_code != "000001.SZ":
            return []
        first_item = {
            "sentiment": 0.8,
            "heat": 0.6,
            "entities": "公司A,行业B,概念C",
        }
        second_item = {
            "sentiment": 0.6,
            "heat": 0.4,
            "entities": "公司A,概念D",
        }
        return [first_item, second_item]

    def get_stock_data(
        self, ts_code: str, trade_date: str, fields: List[str], limit: int = 1
    ) -> List[Dict[str, Any]]:
        """Return two canned volume-ratio rows for "000001.SZ", else an empty list.

        ``trade_date``, ``fields`` and ``limit`` are accepted but not used.
        """
        if ts_code != "000001.SZ":
            return []
        return [{"daily_basic.volume_ratio": ratio} for ratio in (1.2, 1.1)]

    def _lookup_industry(self, ts_code: str) -> str:
        """Return the canned industry name for a stock, or "" when unknown."""
        return "银行" if ts_code == "000001.SZ" else ""

    def _derived_industry_sentiment(
        self, industry: str, trade_date: str
    ) -> float:
        """Return the canned industry sentiment: 0.5 for "银行", else 0.0."""
        return 0.5 if industry == "银行" else 0.0

    def get_industry_stocks(self, industry: str) -> List[str]:
        """Return the canned constituent codes of an industry, or an empty list."""
        if industry != "银行":
            return []
        return ["000001.SZ", "600000.SH"]
def test_compute_stock_factors():
    """Check per-stock sentiment factors with and without mock data."""
    broker = MockDataBroker()
    calculator = SentimentFactors()
    trade_date = "20251001"

    # Stock with mock data: all four factor keys must be present
    populated = calculator.compute_stock_factors(broker, "000001.SZ", trade_date)
    expected_keys = (
        "sent_momentum",
        "sent_impact",
        "sent_market",
        "sent_divergence",
    )
    for key in expected_keys:
        assert key in populated
    assert populated["sent_impact"] > 0

    # Stock without mock data: every returned factor value must be None
    empty = calculator.compute_stock_factors(broker, "000002.SZ", trade_date)
    unexpected = [value for value in empty.values() if value is not None]
    assert not unexpected
def test_compute_batch(tmp_path):
    """Check that compute_batch stores factors only for stocks with data.

    The database path on the application config is pointed at a file under
    ``tmp_path`` for the duration of the test and restored afterwards, so the
    override cannot leak into tests that run later.

    Args:
        tmp_path: pytest-provided temporary directory for the test database.
    """
    from app.data.schema import initialize_database
    from app.utils.config import get_config

    # Point the config at a throwaway database.
    # NOTE(review): presumably get_config() returns a shared object that
    # initialize_database() and db_session() also read — confirm.
    config = get_config()
    missing = object()
    previous_db_path = getattr(config, "db_path", missing)
    config.db_path = tmp_path / "test.db"
    try:
        # Create the schema in the temporary database
        initialize_database()

        calculator = SentimentFactors()
        broker = MockDataBroker()

        # Run the batch computation over three codes; the mock only has
        # data for "000001.SZ"
        ts_codes = ["000001.SZ", "000002.SZ", "600000.SH"]
        calculator.compute_batch(broker, ts_codes, "20251001")

        # Verify what was persisted
        from app.utils.db import db_session

        with db_session() as conn:
            rows = conn.execute(
                "SELECT * FROM factors WHERE trade_date = ?",
                ("20251001",),
            ).fetchall()

        # Only one stock should have stored data
        assert len(rows) == 1
        assert rows[0]["ts_code"] == "000001.SZ"
    finally:
        # Undo the override: restore the old value, or remove the attribute
        # if it did not exist before this test set it.
        if previous_db_path is missing:
            delattr(config, "db_path")
        else:
            config.db_path = previous_db_path