111 lines
3.9 KiB
Python
111 lines
3.9 KiB
Python
"""Test improved entity recognition in RSS processing."""
|
||
from datetime import datetime, timezone
|
||
import pytest
|
||
|
||
from app.ingest.entity_recognition import CompanyNameMapper, company_mapper
|
||
from app.ingest.rss import RssItem, StockMention
|
||
|
||
def test_company_name_mapper():
    """Exercise CompanyNameMapper.find_codes for each supported match type.

    A single company is registered, then text containing its full name,
    short name, an alias, and its raw stock code is searched in turn.
    Entries returned by ``find_codes`` are indexed positionally: index 1
    is compared with the stock code and index 2 with the match-type label.
    """
    ts_code = "000001.SZ"
    mapper = CompanyNameMapper()

    # Register the company used throughout this test.
    mapper.add_company(
        ts_code=ts_code,
        full_name="平安银行股份有限公司",
        short_name="平安银行",
        aliases=["平安", "PAB"],
    )

    # Full-name text: every possible match is reported, so only require
    # that at least one entry refers to our code ...
    full_name_hits = mapper.find_codes("平安银行股份有限公司公布2025年业绩")
    hits_for_code = [hit for hit in full_name_hits if hit[1] == ts_code]
    assert len(hits_for_code) >= 1
    # ... and that one of them is labelled as a full-name match.
    assert any(
        hit[2] == "full_name" and hit[1] == ts_code for hit in full_name_hits
    )

    # Short-name text: a short-name match must be present (others may be too).
    short_name_hits = mapper.find_codes("平安银行发布公告")
    assert any(
        hit[1] == ts_code and hit[2] == "short_name" for hit in short_name_hits
    )

    # Alias text: at least the alias match must be found.
    alias_hits = mapper.find_codes("PAB发布新产品")
    assert any(hit[1] == ts_code and hit[2] == "alias" for hit in alias_hits)

    # A literal stock code in the text yields exactly one "code" match.
    code_hits = mapper.find_codes("000001.SZ开盘上涨")
    assert len(code_hits) == 1
    sole_hit = code_hits[0]
    assert sole_hit[1] == ts_code
    assert sole_hit[2] == "code"
|
||
|
||
def test_rss_item_entity_extraction():
    """Verify that RssItem.extract_entities finds and de-duplicates a stock.

    The news item mentions the same company by short name (title) and by
    full name plus stock code (summary). After extraction the item must
    report a code or full-name match, exactly one unique ts_code, and at
    least one mention with confidence above 0.7.
    """
    ts_code = "000001.SZ"

    # Work on the module-level mapper: drop whatever was registered before
    # and add only the company under test.
    # NOTE(review): the cleared mapping is not restored afterwards, so later
    # tests sharing `company_mapper` see this state — confirm that is intended.
    company_mapper.name_to_code.clear()
    company_mapper.add_company(
        ts_code=ts_code,
        full_name="平安银行股份有限公司",
        short_name="平安银行",
        aliases=["平安", "PAB"],
    )

    # Subclass used to build the test news item; the flag is meant to skip
    # database initialisation (per its name — handled inside RssItem).
    class TestRssItem(RssItem):
        _skip_db_init = True

    headline = "平安银行发布2025年业绩预告"
    body = "平安银行股份有限公司(000001.SZ)今日发布2025年业绩预告"
    item = TestRssItem(
        id="test_news",
        title=headline,
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary=body,
        source="test",
    )

    # Run entity extraction on the item.
    item.extract_entities()

    # Because of the priority mechanism only the best matches are kept, so
    # accept either a code match or a full-name match.
    matched_types = {mention.match_type for mention in item.stock_mentions}
    assert matched_types & {"code", "full_name"}

    # Exactly one unique stock code must be reported.
    assert len(item.ts_codes) == 1
    assert item.ts_codes[0] == ts_code

    # Confidence scoring: at least one mention must be high-confidence.
    confident_mentions = [
        mention for mention in item.stock_mentions if mention.confidence > 0.7
    ]
    assert len(confident_mentions) >= 1
|
||
|
||
def test_rss_item_context_extraction():
    """Verify that extracted stock mentions carry surrounding context text.

    The company is mentioned only in the summary, in the middle of a
    sentence. The first extracted mention must have a context of at most
    70 characters that contains the company name together with text from
    before and after it.
    """
    # Work on the module-level mapper: drop whatever was registered before
    # and add only the company under test (no aliases this time).
    # NOTE(review): the cleared mapping is not restored afterwards, so later
    # tests sharing `company_mapper` see this state — confirm that is intended.
    company_mapper.name_to_code.clear()
    company_mapper.add_company(
        ts_code="000001.SZ",
        full_name="平安银行股份有限公司",
        short_name="平安银行",
    )

    # Subclass used to build the test news item; the flag is meant to skip
    # database initialisation (per its name — handled inside RssItem).
    class TestRssItem(RssItem):
        _skip_db_init = True

    headline = "多家银行业绩报告"
    body = "在银行业整体向好的背景下,平安银行表现突出,营收增长明显"
    item = TestRssItem(
        id="test_news",
        title=headline,
        link="http://example.com",
        published=datetime.now(timezone.utc),
        summary=body,
        source="test",
    )

    # Run entity extraction on the item.
    item.extract_entities()

    # Check the context captured for the first mention.
    assert len(item.stock_mentions) > 0
    mention = item.stock_mentions[0]
    # Context length cap (30 characters on either side of the match).
    assert len(mention.context) <= 70

    expected_fragments = (
        "平安银行",  # the matched company name itself
        "银行业",  # text preceding the match
        "营收增长",  # text following the match
    )
    for fragment in expected_fragments:
        assert fragment in mention.context
|