update
This commit is contained in: parent 8083c9ffab, commit 5228ea1c41
@ -1,8 +1,8 @@
# 多智能体投资助理骨架
# 多智能体个人投资助理

## 项目简介

本仓库提供一个面向 A 股日线级别的多智能体投资助理原型,覆盖数据采集、特征抽取、策略博弈、回测展示和 LLM 解释链路。代码以模块化骨架形式呈现,方便在单机环境下快速搭建端到端的量化研究和可视化决策流程。
本仓库提供一个面向 A 股日线级别的多智能体个人投资助理原型,覆盖数据采集、特征抽取、策略博弈、回测展示和 LLM 解释链路。代码以模块化骨架形式呈现,方便在单机环境下快速搭建端到端的量化研究和可视化决策流程。

## 架构总览
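The ingestion layer this commit introduces can be driven end to end with the public helpers exported by app/ingest/rss.py further down in this diff; a minimal sketch, with a placeholder feed URL and illustrative limits:

from app.ingest import rss

# Batch mode: pull every feed from persisted settings (same path as the UI's RSS test button).
inserted = rss.ingest_configured_rss(hours_back=24, max_items_per_feed=50)
print(f"configured feeds -> {inserted} new rows in the news table")

# Manual mode: fetch a single feed (placeholder URL), deduplicate, then persist.
items = rss.fetch_rss_feed("https://example.com/feed.xml", source="example", hours_back=24, max_items=20)
print(f"saved {rss.save_news_items(rss.deduplicate_items(items))} rows")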
@ -1,25 +1,237 @@
|
||||
"""RSS ingestion for news and heat scores."""
|
||||
"""RSS ingestion utilities for news sentiment and heat scoring."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Iterable, List
|
||||
import json
|
||||
import re
|
||||
import sqlite3
|
||||
from dataclasses import dataclass, replace
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
||||
from urllib.parse import urlparse, urljoin
|
||||
from xml.etree import ElementTree as ET
|
||||
|
||||
import requests
|
||||
from requests import RequestException
|
||||
|
||||
import hashlib
|
||||
import random
|
||||
import time
|
||||
|
||||
try: # pragma: no cover - optional dependency at runtime
|
||||
import feedparser # type: ignore[import-not-found]
|
||||
except ImportError: # pragma: no cover - graceful fallback
|
||||
feedparser = None # type: ignore[assignment]
|
||||
|
||||
from app.data.schema import initialize_database
|
||||
from app.utils import alerts
|
||||
from app.utils.config import get_config
|
||||
from app.utils.db import db_session
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
LOG_EXTRA = {"stage": "rss_ingest"}
|
||||
|
||||
DEFAULT_TIMEOUT = 10.0
|
||||
MAX_SUMMARY_LENGTH = 1500
|
||||
|
||||
POSITIVE_KEYWORDS: Tuple[str, ...] = (
|
||||
"利好",
|
||||
"增长",
|
||||
"超预期",
|
||||
"创新高",
|
||||
"增持",
|
||||
"回购",
|
||||
"盈利",
|
||||
"strong",
|
||||
"beat",
|
||||
"upgrade",
|
||||
)
|
||||
NEGATIVE_KEYWORDS: Tuple[str, ...] = (
|
||||
"利空",
|
||||
"下跌",
|
||||
"亏损",
|
||||
"裁员",
|
||||
"违约",
|
||||
"处罚",
|
||||
"暴跌",
|
||||
"减持",
|
||||
"downgrade",
|
||||
"miss",
|
||||
)
|
||||
|
||||
A_SH_CODE_PATTERN = re.compile(r"\b(\d{6})(?:\.(SH|SZ))?\b", re.IGNORECASE)
|
||||
HK_CODE_PATTERN = re.compile(r"\b(\d{4})\.HK\b", re.IGNORECASE)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RssFeedConfig:
|
||||
"""Configuration describing a single RSS source."""
|
||||
|
||||
url: str
|
||||
source: str
|
||||
ts_codes: Tuple[str, ...] = ()
|
||||
keywords: Tuple[str, ...] = ()
|
||||
hours_back: int = 48
|
||||
max_items: int = 50
|
||||
|
||||
|
||||
@dataclass
|
||||
class RssItem:
|
||||
"""Structured representation of an RSS entry."""
|
||||
|
||||
id: str
|
||||
title: str
|
||||
link: str
|
||||
published: datetime
|
||||
summary: str
|
||||
source: str
|
||||
ts_codes: Tuple[str, ...] = ()
|
||||
|
||||
|
||||
def fetch_rss_feed(url: str) -> List[RssItem]:
|
||||
DEFAULT_RSS_SOURCES: Tuple[RssFeedConfig, ...] = ()
|
||||
|
||||
|
||||
def fetch_rss_feed(
|
||||
url: str,
|
||||
*,
|
||||
source: Optional[str] = None,
|
||||
hours_back: int = 48,
|
||||
max_items: int = 50,
|
||||
timeout: float = DEFAULT_TIMEOUT,
|
||||
max_retries: int = 5,
|
||||
retry_backoff: float = 1.5,
|
||||
retry_jitter: float = 0.3,
|
||||
) -> List[RssItem]:
|
||||
"""Download and parse an RSS feed into structured items."""
|
||||
|
||||
raise NotImplementedError
|
||||
return _fetch_feed_items(
|
||||
url,
|
||||
source=source,
|
||||
hours_back=hours_back,
|
||||
max_items=max_items,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
retry_backoff=retry_backoff,
|
||||
retry_jitter=retry_jitter,
|
||||
allow_html_redirect=True,
|
||||
)
|
||||
|
||||
|
||||
def _fetch_feed_items(
|
||||
url: str,
|
||||
*,
|
||||
source: Optional[str],
|
||||
hours_back: int,
|
||||
max_items: int,
|
||||
timeout: float,
|
||||
max_retries: int,
|
||||
retry_backoff: float,
|
||||
retry_jitter: float,
|
||||
allow_html_redirect: bool,
|
||||
) -> List[RssItem]:
|
||||
|
||||
content = _download_feed(
|
||||
url,
|
||||
timeout,
|
||||
max_retries=max_retries,
|
||||
retry_backoff=retry_backoff,
|
||||
retry_jitter=retry_jitter,
|
||||
)
|
||||
if content is None:
|
||||
return []
|
||||
|
||||
if allow_html_redirect:
|
||||
feed_links = _extract_html_feed_links(content, url)
|
||||
if feed_links:
|
||||
LOGGER.info(
|
||||
"RSS 页面包含子订阅 %s 个,自动展开",
|
||||
len(feed_links),
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
aggregated: List[RssItem] = []
|
||||
for feed_url in feed_links:
|
||||
sub_items = _fetch_feed_items(
|
||||
feed_url,
|
||||
source=source,
|
||||
hours_back=hours_back,
|
||||
max_items=max_items,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
retry_backoff=retry_backoff,
|
||||
retry_jitter=retry_jitter,
|
||||
allow_html_redirect=False,
|
||||
)
|
||||
aggregated.extend(sub_items)
|
||||
if max_items > 0 and len(aggregated) >= max_items:
|
||||
return aggregated[:max_items]
|
||||
if aggregated:
|
||||
alerts.clear_warnings(_rss_source_key(url))
|
||||
else:
|
||||
alerts.add_warning(
|
||||
_rss_source_key(url),
|
||||
"聚合页未返回内容",
|
||||
)
|
||||
return aggregated
|
||||
|
||||
parsed_entries = _parse_feed_content(content)
|
||||
total_entries = len(parsed_entries)
|
||||
LOGGER.info(
|
||||
"RSS 源获取完成 url=%s raw_entries=%s",
|
||||
url,
|
||||
total_entries,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
if not parsed_entries:
|
||||
LOGGER.warning(
|
||||
"RSS 无可解析条目 url=%s snippet=%s",
|
||||
url,
|
||||
_safe_snippet(content),
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
return []
|
||||
|
||||
cutoff = datetime.utcnow() - timedelta(hours=max(1, hours_back))
|
||||
source_name = source or _source_from_url(url)
|
||||
items: List[RssItem] = []
|
||||
seen_ids: set[str] = set()
|
||||
for entry in parsed_entries:
|
||||
published = entry.get("published") or datetime.utcnow()
|
||||
if published < cutoff:
|
||||
continue
|
||||
title = _clean_text(entry.get("title", ""))
|
||||
summary = _clean_text(entry.get("summary", ""))
|
||||
link = entry.get("link", "")
|
||||
raw_id = entry.get("id") or link
|
||||
item_id = _normalise_item_id(raw_id, link, title, published)
|
||||
if item_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(item_id)
|
||||
items.append(
|
||||
RssItem(
|
||||
id=item_id,
|
||||
title=title,
|
||||
link=link,
|
||||
published=published,
|
||||
summary=_truncate(summary, MAX_SUMMARY_LENGTH),
|
||||
source=source_name,
|
||||
)
|
||||
)
|
||||
if len(items) >= max_items > 0:
|
||||
break
|
||||
|
||||
LOGGER.info(
|
||||
"RSS 过滤结果 url=%s within_window=%s unique=%s",
|
||||
url,
|
||||
sum(1 for entry in parsed_entries if (entry.get("published") or datetime.utcnow()) >= cutoff),
|
||||
len(items),
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
if items:
|
||||
alerts.clear_warnings(_rss_source_key(url))
|
||||
|
||||
return items
|
||||
|
||||
|
||||
def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
|
||||
@ -36,7 +248,617 @@ def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
|
||||
return unique
|
||||
|
||||
|
||||
def save_news_items(items: Iterable[RssItem]) -> None:
|
||||
def save_news_items(items: Iterable[RssItem]) -> int:
|
||||
"""Persist RSS items into the `news` table."""
|
||||
|
||||
raise NotImplementedError
|
||||
initialize_database()
|
||||
now = datetime.utcnow()
|
||||
rows: List[Tuple[object, ...]] = []
|
||||
|
||||
processed = 0
|
||||
for item in items:
|
||||
text_payload = f"{item.title}\n{item.summary}"
|
||||
sentiment = _estimate_sentiment(text_payload)
|
||||
base_codes = tuple(code for code in item.ts_codes if code)
|
||||
heat = _estimate_heat(item.published, now, len(base_codes), sentiment)
|
||||
entities = json.dumps(
|
||||
{
|
||||
"ts_codes": list(base_codes),
|
||||
"source_url": item.link,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
resolved_codes = base_codes or (None,)
|
||||
for ts_code in resolved_codes:
|
||||
row_id = item.id if ts_code is None else f"{item.id}::{ts_code}"
|
||||
rows.append(
|
||||
(
|
||||
row_id,
|
||||
ts_code,
|
||||
item.published.replace(tzinfo=timezone.utc).isoformat(),
|
||||
item.source,
|
||||
item.title,
|
||||
item.summary,
|
||||
item.link,
|
||||
entities,
|
||||
sentiment,
|
||||
heat,
|
||||
)
|
||||
)
|
||||
processed += 1
|
||||
|
||||
if not rows:
|
||||
return 0
|
||||
|
||||
inserted = 0
|
||||
try:
|
||||
with db_session() as conn:
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT OR IGNORE INTO news
|
||||
(id, ts_code, pub_time, source, title, summary, url, entities, sentiment, heat)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
rows,
|
||||
)
|
||||
inserted = conn.total_changes
|
||||
except sqlite3.OperationalError:
|
||||
LOGGER.exception("写入新闻数据失败,表结构可能未初始化", extra=LOG_EXTRA)
|
||||
return 0
|
||||
except Exception: # pragma: no cover - guard unexpected sqlite errors
|
||||
LOGGER.exception("写入新闻数据异常", extra=LOG_EXTRA)
|
||||
return 0
|
||||
|
||||
LOGGER.info(
|
||||
"RSS 新闻落库完成 processed=%s inserted=%s",
|
||||
processed,
|
||||
inserted,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
return inserted
|
||||
|
||||
|
||||
def ingest_configured_rss(
|
||||
*,
|
||||
hours_back: Optional[int] = None,
|
||||
max_items_per_feed: Optional[int] = None,
|
||||
max_retries: int = 5,
|
||||
retry_backoff: float = 2.0,
|
||||
retry_jitter: float = 0.5,
|
||||
) -> int:
|
||||
"""Ingest all configured RSS feeds into the news store."""
|
||||
|
||||
configs = resolve_rss_sources()
|
||||
if not configs:
|
||||
LOGGER.info("未配置 RSS 来源,跳过新闻拉取", extra=LOG_EXTRA)
|
||||
return 0
|
||||
|
||||
aggregated: List[RssItem] = []
|
||||
fetched_count = 0
|
||||
for index, cfg in enumerate(configs, start=1):
|
||||
window = hours_back or cfg.hours_back
|
||||
limit = max_items_per_feed or cfg.max_items
|
||||
LOGGER.info(
|
||||
"开始拉取 RSS:%s (window=%sh, limit=%s)",
|
||||
cfg.url,
|
||||
window,
|
||||
limit,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
items = fetch_rss_feed(
|
||||
cfg.url,
|
||||
source=cfg.source,
|
||||
hours_back=window,
|
||||
max_items=limit,
|
||||
max_retries=max_retries,
|
||||
retry_backoff=retry_backoff,
|
||||
retry_jitter=retry_jitter,
|
||||
)
|
||||
if not items:
|
||||
LOGGER.info("RSS 来源无新内容:%s", cfg.url, extra=LOG_EXTRA)
|
||||
continue
|
||||
enriched: List[RssItem] = []
|
||||
for item in items:
|
||||
codes = _assign_ts_codes(item, cfg.ts_codes, cfg.keywords)
|
||||
enriched.append(replace(item, ts_codes=tuple(codes)))
|
||||
aggregated.extend(enriched)
|
||||
fetched_count += len(enriched)
|
||||
if fetched_count and index < len(configs):
|
||||
time.sleep(2.0)
|
||||
|
||||
if not aggregated:
|
||||
LOGGER.info("RSS 来源未产生有效新闻", extra=LOG_EXTRA)
|
||||
alerts.add_warning("RSS", "未获取到任何 RSS 新闻")
|
||||
return 0
|
||||
|
||||
deduped = deduplicate_items(aggregated)
|
||||
LOGGER.info(
|
||||
"RSS 聚合完成 total_fetched=%s unique=%s",
|
||||
fetched_count,
|
||||
len(deduped),
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
return save_news_items(deduped)
|
||||
|
||||
|
||||
def resolve_rss_sources() -> List[RssFeedConfig]:
|
||||
"""Resolve RSS feed configuration from persisted settings."""
|
||||
|
||||
cfg = get_config()
|
||||
raw = getattr(cfg, "rss_sources", None) or {}
|
||||
feeds: Dict[str, RssFeedConfig] = {}
|
||||
|
||||
def _add_feed(url: str, **kwargs: object) -> None:
|
||||
clean_url = url.strip()
|
||||
if not clean_url:
|
||||
return
|
||||
key = clean_url.lower()
|
||||
if key in feeds:
|
||||
return
|
||||
source_name = kwargs.get("source") or _source_from_url(clean_url)
|
||||
feeds[key] = RssFeedConfig(
|
||||
url=clean_url,
|
||||
source=str(source_name),
|
||||
ts_codes=tuple(kwargs.get("ts_codes", ()) or ()),
|
||||
keywords=tuple(kwargs.get("keywords", ()) or ()),
|
||||
hours_back=int(kwargs.get("hours_back", 48) or 48),
|
||||
max_items=int(kwargs.get("max_items", 50) or 50),
|
||||
)
|
||||
|
||||
if isinstance(raw, dict):
|
||||
for key, value in raw.items():
|
||||
if isinstance(value, dict):
|
||||
if not value.get("enabled", True):
|
||||
continue
|
||||
url = str(value.get("url") or key)
|
||||
ts_codes = [
|
||||
str(code).strip().upper()
|
||||
for code in value.get("ts_codes", [])
|
||||
if str(code).strip()
|
||||
]
|
||||
keywords = [
|
||||
str(token).strip()
|
||||
for token in value.get("keywords", [])
|
||||
if str(token).strip()
|
||||
]
|
||||
_add_feed(
|
||||
url,
|
||||
ts_codes=ts_codes,
|
||||
keywords=keywords,
|
||||
hours_back=value.get("hours_back", 48),
|
||||
max_items=value.get("max_items", 50),
|
||||
source=value.get("source") or value.get("label"),
|
||||
)
|
||||
continue
|
||||
|
||||
if not value:
|
||||
continue
|
||||
url = key
|
||||
ts_codes: List[str] = []
|
||||
if "|" in key:
|
||||
prefix, url = key.split("|", 1)
|
||||
ts_codes = [
|
||||
token.strip().upper()
|
||||
for token in prefix.replace(",", ":").split(":")
|
||||
if token.strip()
|
||||
]
|
||||
_add_feed(url, ts_codes=ts_codes)
|
||||
|
||||
if feeds:
|
||||
return list(feeds.values())
|
||||
|
||||
return list(DEFAULT_RSS_SOURCES)
|
||||
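For reference, one shape of the rss_sources setting that the parsing above accepts; the URL, codes, and keywords are illustrative placeholders, and the second entry uses the shorthand "CODE|url" string-key form handled at the end of the loop:

rss_sources = {
    # Full form: explicit source label, related codes, keywords, and fetch limits.
    "https://example.com/finance.xml": {
        "enabled": True,
        "source": "ExampleWire",
        "ts_codes": ["000001.SZ", "600000.SH"],
        "keywords": ["银行", "利率"],
        "hours_back": 48,
        "max_items": 50,
    },
    # Shorthand form: "CODE|url" key with a boolean toggle as the value.
    "600519.SH|https://example.com/moutai.xml": True,
}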
|
||||
|
||||
def _download_feed(
|
||||
url: str,
|
||||
timeout: float,
|
||||
*,
|
||||
max_retries: int,
|
||||
retry_backoff: float,
|
||||
retry_jitter: float,
|
||||
) -> Optional[bytes]:
|
||||
headers = {
|
||||
"User-Agent": "llm-quant/0.1 (+https://github.com/qiang/llm_quant)",
|
||||
"Accept": "application/rss+xml, application/atom+xml, application/xml;q=0.9, */*;q=0.8",
|
||||
}
|
||||
attempt = 0
|
||||
delay = max(0.5, retry_backoff)
|
||||
while attempt <= max_retries:
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=timeout)
|
||||
except RequestException as exc:
|
||||
attempt += 1
|
||||
if attempt > max_retries:
|
||||
message = f"源请求失败:{url}"
|
||||
LOGGER.warning("RSS 请求失败:%s err=%s", url, exc, extra=LOG_EXTRA)
|
||||
alerts.add_warning(_rss_source_key(url), message, str(exc))
|
||||
return None
|
||||
wait = delay + random.uniform(0, retry_jitter)
|
||||
LOGGER.info(
|
||||
"RSS 请求异常,%.2f 秒后重试 url=%s attempt=%s/%s",
|
||||
wait,
|
||||
url,
|
||||
attempt,
|
||||
max_retries,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
time.sleep(max(wait, 0.1))
|
||||
delay *= max(1.1, retry_backoff)
|
||||
continue
|
||||
|
||||
status = response.status_code
|
||||
if 200 <= status < 300:
|
||||
return response.content
|
||||
|
||||
if status in {429, 503}:
|
||||
attempt += 1
|
||||
if attempt > max_retries:
|
||||
LOGGER.warning(
|
||||
"RSS 请求失败:%s status=%s 已达到最大重试次数",
|
||||
url,
|
||||
status,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
alerts.add_warning(
|
||||
_rss_source_key(url),
|
||||
"源限流",
|
||||
f"HTTP {status}",
|
||||
)
|
||||
return None
|
||||
retry_after = response.headers.get("Retry-After")
|
||||
if retry_after:
|
||||
try:
|
||||
wait = float(retry_after)
|
||||
except ValueError:
|
||||
wait = delay
|
||||
else:
|
||||
wait = delay
|
||||
wait += random.uniform(0, retry_jitter)
|
||||
LOGGER.info(
|
||||
"RSS 命中限流 status=%s,%.2f 秒后重试 url=%s attempt=%s/%s",
|
||||
status,
|
||||
wait,
|
||||
url,
|
||||
attempt,
|
||||
max_retries,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
time.sleep(max(wait, 0.1))
|
||||
delay *= max(1.1, retry_backoff)
|
||||
continue
|
||||
|
||||
LOGGER.warning(
|
||||
"RSS 请求失败:%s status=%s",
|
||||
url,
|
||||
status,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
alerts.add_warning(
|
||||
_rss_source_key(url),
|
||||
"源响应异常",
|
||||
f"HTTP {status}",
|
||||
)
|
||||
return None
|
||||
|
||||
LOGGER.warning("RSS 请求失败:%s 未获取内容", url, extra=LOG_EXTRA)
|
||||
alerts.add_warning(_rss_source_key(url), "未获取内容")
|
||||
return None
|
||||
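To make the retry pacing above concrete, a small sketch of the wait schedule before jitter, assuming the defaults that ingest_configured_rss passes in (max_retries=5, retry_backoff=2.0, retry_jitter=0.5):

delay = max(0.5, 2.0)
waits = []
for _ in range(5):
    waits.append(delay)
    delay *= max(1.1, 2.0)
print(waits)  # [2.0, 4.0, 8.0, 16.0, 32.0]; on 429/503 a Retry-After header replaces the computed delay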
|
||||
|
||||
def _extract_html_feed_links(content: bytes, base_url: str) -> List[str]:
|
||||
sample = content[:1024].lower()
|
||||
if b"<rss" in sample or b"<feed" in sample:
|
||||
return []
|
||||
|
||||
for encoding in ("utf-8", "gb18030", "gb2312"):
|
||||
try:
|
||||
text = content.decode(encoding)
|
||||
break
|
||||
        except UnicodeDecodeError:
            continue
|
||||
else:
|
||||
text = content.decode("utf-8", errors="ignore")
|
||||
|
||||
if "<link" not in text and ".xml" not in text:
|
||||
return []
|
||||
|
||||
feed_urls: List[str] = []
|
||||
alternates = re.compile(
|
||||
r"<link[^>]+rel=[\"']alternate[\"'][^>]+type=[\"']application/(?:rss|atom)\+xml[\"'][^>]*href=[\"']([^\"']+)[\"']",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
for match in alternates.finditer(text):
|
||||
href = match.group(1).strip()
|
||||
if href:
|
||||
feed_urls.append(urljoin(base_url, href))
|
||||
|
||||
if not feed_urls:
|
||||
anchors = re.compile(r"href=[\"']([^\"']+\.xml)[\"']", re.IGNORECASE)
|
||||
for match in anchors.finditer(text):
|
||||
href = match.group(1).strip()
|
||||
if href:
|
||||
feed_urls.append(urljoin(base_url, href))
|
||||
|
||||
unique_urls: List[str] = []
|
||||
seen = set()
|
||||
for href in feed_urls:
|
||||
if href not in seen and href != base_url:
|
||||
seen.add(href)
|
||||
unique_urls.append(href)
|
||||
return unique_urls
|
||||
|
||||
|
||||
def _safe_snippet(content: bytes, limit: int = 160) -> str:
|
||||
try:
|
||||
text = content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
text = content.decode("gb18030", errors="ignore")
|
||||
except UnicodeDecodeError:
|
||||
text = content.decode("latin-1", errors="ignore")
|
||||
cleaned = re.sub(r"\s+", " ", text)
|
||||
if len(cleaned) > limit:
|
||||
return cleaned[: limit - 3] + "..."
|
||||
return cleaned
|
||||
|
||||
|
||||
def _parse_feed_content(content: bytes) -> List[Dict[str, object]]:
|
||||
if feedparser is not None:
|
||||
parsed = feedparser.parse(content)
|
||||
entries = []
|
||||
for entry in getattr(parsed, "entries", []) or []:
|
||||
entries.append(
|
||||
{
|
||||
"id": getattr(entry, "id", None) or getattr(entry, "guid", None),
|
||||
"title": getattr(entry, "title", ""),
|
||||
"link": getattr(entry, "link", ""),
|
||||
"summary": getattr(entry, "summary", "") or getattr(entry, "description", ""),
|
||||
"published": _parse_datetime(
|
||||
getattr(entry, "published", None)
|
||||
or getattr(entry, "updated", None)
|
||||
or getattr(entry, "issued", None)
|
||||
),
|
||||
}
|
||||
)
|
||||
if entries:
|
||||
return entries
|
||||
else: # pragma: no cover - log helpful info when dependency missing
|
||||
LOGGER.warning(
|
||||
"feedparser 未安装,使用简易 XML 解析器回退处理 RSS",
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
|
||||
return _parse_feed_xml(content)
|
||||
|
||||
|
||||
def _parse_feed_xml(content: bytes) -> List[Dict[str, object]]:
|
||||
try:
|
||||
xml_text = content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
xml_text = content.decode("utf-8", errors="ignore")
|
||||
|
||||
try:
|
||||
root = ET.fromstring(xml_text)
|
||||
except ET.ParseError as exc: # pragma: no cover - depends on remote feed
|
||||
LOGGER.warning("RSS XML 解析失败 err=%s", exc, extra=LOG_EXTRA)
|
||||
return _lenient_parse_items(xml_text)
|
||||
|
||||
tag = _local_name(root.tag)
|
||||
if tag == "rss":
|
||||
candidates = root.findall(".//item")
|
||||
elif tag == "feed":
|
||||
candidates = root.findall(".//{*}entry")
|
||||
else: # fallback
|
||||
candidates = root.findall(".//item") or root.findall(".//{*}entry")
|
||||
|
||||
entries: List[Dict[str, object]] = []
|
||||
for node in candidates:
|
||||
entries.append(
|
||||
{
|
||||
"id": _child_text(node, {"id", "guid"}),
|
||||
"title": _child_text(node, {"title"}) or "",
|
||||
"link": _child_text(node, {"link"}) or "",
|
||||
"summary": _child_text(node, {"summary", "description"}) or "",
|
||||
"published": _parse_datetime(
|
||||
_child_text(node, {"pubDate", "published", "updated"})
|
||||
),
|
||||
}
|
||||
)
|
||||
if not entries and "<item" in xml_text.lower():
|
||||
return _lenient_parse_items(xml_text)
|
||||
return entries
|
||||
|
||||
|
||||
def _lenient_parse_items(xml_text: str) -> List[Dict[str, object]]:
|
||||
"""Fallback parser that tolerates malformed RSS by using regular expressions."""
|
||||
|
||||
items: List[Dict[str, object]] = []
|
||||
pattern = re.compile(r"<(item|entry)[^>]*>(.+?)</\\1>", re.IGNORECASE | re.DOTALL)
|
||||
for match in pattern.finditer(xml_text):
|
||||
block = match.group(0)
|
||||
title = _extract_tag_text(block, ["title"]) or ""
|
||||
link = _extract_link(block)
|
||||
summary = _extract_tag_text(block, ["summary", "description"]) or ""
|
||||
published_text = _extract_tag_text(block, ["pubDate", "published", "updated"])
|
||||
items.append(
|
||||
{
|
||||
"id": _extract_tag_text(block, ["id", "guid"]) or link,
|
||||
"title": title,
|
||||
"link": link,
|
||||
"summary": summary,
|
||||
"published": _parse_datetime(published_text),
|
||||
}
|
||||
)
|
||||
if items:
|
||||
LOGGER.info("RSS 采用宽松解析提取 %s 条记录", len(items), extra=LOG_EXTRA)
|
||||
return items
|
||||
|
||||
|
||||
def _extract_tag_text(block: str, names: Sequence[str]) -> Optional[str]:
|
||||
for name in names:
|
||||
pattern = re.compile(rf"<{name}[^>]*>(.*?)</{name}>", re.IGNORECASE | re.DOTALL)
|
||||
match = pattern.search(block)
|
||||
if match:
|
||||
text = re.sub(r"<[^>]+>", " ", match.group(1))
|
||||
return _clean_text(text)
|
||||
return None
|
||||
|
||||
|
||||
def _extract_link(block: str) -> str:
|
||||
href_pattern = re.compile(r"<link[^>]*href=\"([^\"]+)\"[^>]*>", re.IGNORECASE)
|
||||
match = href_pattern.search(block)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
inline_pattern = re.compile(r"<link[^>]*>(.*?)</link>", re.IGNORECASE | re.DOTALL)
|
||||
match = inline_pattern.search(block)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
return ""
|
||||
|
||||
|
||||
def _assign_ts_codes(
|
||||
item: RssItem,
|
||||
base_codes: Sequence[str],
|
||||
keywords: Sequence[str],
|
||||
) -> List[str]:
|
||||
matches: set[str] = set()
|
||||
text = f"{item.title} {item.summary}".lower()
|
||||
if keywords:
|
||||
for keyword in keywords:
|
||||
token = keyword.lower().strip()
|
||||
if token and token in text:
|
||||
matches.update(code.strip().upper() for code in base_codes if code)
|
||||
break
|
||||
else:
|
||||
matches.update(code.strip().upper() for code in base_codes if code)
|
||||
|
||||
detected = _detect_ts_codes(text)
|
||||
matches.update(detected)
|
||||
return [code for code in matches if code]
|
||||
|
||||
|
||||
def _detect_ts_codes(text: str) -> List[str]:
|
||||
codes: set[str] = set()
|
||||
for match in A_SH_CODE_PATTERN.finditer(text):
|
||||
digits, suffix = match.groups()
|
||||
if suffix:
|
||||
codes.add(f"{digits}.{suffix.upper()}")
|
||||
else:
|
||||
exchange = "SH" if digits.startswith(tuple("569")) else "SZ"
|
||||
codes.add(f"{digits}.{exchange}")
|
||||
for match in HK_CODE_PATTERN.finditer(text):
|
||||
digits = match.group(1)
|
||||
codes.add(f"{digits.zfill(4)}.HK")
|
||||
return sorted(codes)
|
||||
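Two illustrative calls showing how the patterns above resolve codes (bare six-digit A-share codes get an exchange inferred from the leading digit):

_detect_ts_codes("公告:600000 与 000001.SZ 签署合作")  # -> ["000001.SZ", "600000.SH"]
_detect_ts_codes("0700.HK results beat estimates")  # -> ["0700.HK"]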
|
||||
|
||||
def _estimate_sentiment(text: str) -> float:
|
||||
normalized = text.lower()
|
||||
score = 0
|
||||
for keyword in POSITIVE_KEYWORDS:
|
||||
if keyword.lower() in normalized:
|
||||
score += 1
|
||||
for keyword in NEGATIVE_KEYWORDS:
|
||||
if keyword.lower() in normalized:
|
||||
score -= 1
|
||||
if score == 0:
|
||||
return 0.0
|
||||
return max(-1.0, min(1.0, score / 3.0))
|
||||
|
||||
|
||||
def _estimate_heat(
|
||||
published: datetime,
|
||||
now: datetime,
|
||||
code_count: int,
|
||||
sentiment: float,
|
||||
) -> float:
|
||||
delta_hours = max(0.0, (now - published).total_seconds() / 3600.0)
|
||||
recency = max(0.0, 1.0 - min(delta_hours, 72.0) / 72.0)
|
||||
coverage_bonus = min(code_count, 3) * 0.05
|
||||
sentiment_bonus = min(abs(sentiment) * 0.1, 0.2)
|
||||
heat = recency + coverage_bonus + sentiment_bonus
|
||||
return max(0.0, min(1.0, round(heat, 4)))
|
||||
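As a sanity check on the two scoring helpers above, the arithmetic for a hypothetical headline with one positive keyword, published 24 hours ago and tagged with a single code:

sentiment = _estimate_sentiment("公司回购方案获批")  # one positive keyword hit -> 1/3 ≈ 0.3333
heat = _estimate_heat(
    published=datetime.utcnow() - timedelta(hours=24),
    now=datetime.utcnow(),
    code_count=1,
    sentiment=sentiment,
)
# recency = 1 - 24/72 ≈ 0.6667, coverage = 0.05, sentiment bonus ≈ 0.0333 -> heat ≈ 0.75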
|
||||
|
||||
def _parse_datetime(value: Optional[str]) -> Optional[datetime]:
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
dt = parsedate_to_datetime(value)
|
||||
if dt.tzinfo is not None:
|
||||
dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
|
||||
return dt
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
for fmt in ("%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S"):
|
||||
try:
|
||||
return datetime.strptime(value[:19], fmt)
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _clean_text(value: Optional[str]) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
text = re.sub(r"<[^>]+>", " ", value)
|
||||
return re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
|
||||
def _truncate(value: str, length: int) -> str:
|
||||
if len(value) <= length:
|
||||
return value
|
||||
return value[: length - 3].rstrip() + "..."
|
||||
|
||||
|
||||
def _normalise_item_id(
|
||||
raw_id: Optional[str], link: str, title: str, published: datetime
|
||||
) -> str:
|
||||
candidate = (raw_id or link or title).strip()
|
||||
if candidate:
|
||||
return candidate
|
||||
fingerprint = f"{title}|{published.isoformat()}"
|
||||
return hashlib.blake2s(fingerprint.encode("utf-8"), digest_size=16).hexdigest()
|
||||
|
||||
|
||||
def _source_from_url(url: str) -> str:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except ValueError:
|
||||
return url
|
||||
host = parsed.netloc or url
|
||||
return host.lower()
|
||||
|
||||
|
||||
def _local_name(tag: str) -> str:
|
||||
if "}" in tag:
|
||||
return tag.rsplit("}", 1)[-1]
|
||||
return tag
|
||||
|
||||
|
||||
def _child_text(node: ET.Element, candidates: set[str]) -> Optional[str]:
|
||||
for child in node:
|
||||
name = _local_name(child.tag)
|
||||
if name in candidates and child.text:
|
||||
return child.text.strip()
|
||||
if name == "link":
|
||||
href = child.attrib.get("href")
|
||||
if href:
|
||||
return href.strip()
|
||||
return None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RssFeedConfig",
|
||||
"RssItem",
|
||||
"fetch_rss_feed",
|
||||
"deduplicate_items",
|
||||
"save_news_items",
|
||||
"ingest_configured_rss",
|
||||
"resolve_rss_sources",
|
||||
]
|
||||
def _rss_source_key(url: str) -> str:
|
||||
return f"RSS|{url}".strip()
|
||||
|
||||
@ -16,6 +16,7 @@ try:
|
||||
except ImportError: # pragma: no cover - 运行时提示
|
||||
ts = None # type: ignore[assignment]
|
||||
|
||||
from app.utils import alerts
|
||||
from app.utils.config import get_config
|
||||
from app.utils.db import db_session
|
||||
from app.data.schema import initialize_database
|
||||
@ -1601,12 +1602,18 @@ def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]
|
||||
|
||||
def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
|
||||
LOGGER.info("启动 TuShare 拉取任务:%s", job.name, extra=LOG_EXTRA)
|
||||
ensure_data_coverage(
|
||||
job.start,
|
||||
job.end,
|
||||
ts_codes=job.ts_codes,
|
||||
include_limits=include_limits,
|
||||
include_extended=True,
|
||||
force=True,
|
||||
)
|
||||
LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA)
|
||||
try:
|
||||
ensure_data_coverage(
|
||||
job.start,
|
||||
job.end,
|
||||
ts_codes=job.ts_codes,
|
||||
include_limits=include_limits,
|
||||
include_extended=True,
|
||||
force=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
alerts.add_warning("TuShare", f"拉取任务失败:{job.name}", str(exc))
|
||||
raise
|
||||
else:
|
||||
alerts.clear_warnings("TuShare")
|
||||
LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA)
|
||||
|
||||
@ -36,6 +36,7 @@ from app.llm.metrics import (
|
||||
reset as reset_llm_metrics,
|
||||
snapshot as snapshot_llm_metrics,
|
||||
)
|
||||
from app.utils import alerts
|
||||
from app.utils.config import (
|
||||
ALLOWED_LLM_STRATEGIES,
|
||||
DEFAULT_LLM_BASE_URLS,
|
||||
@ -86,6 +87,9 @@ def render_global_dashboard() -> None:
|
||||
decisions_container = st.sidebar.container()
|
||||
_DASHBOARD_CONTAINERS = (metrics_container, decisions_container)
|
||||
_DASHBOARD_ELEMENTS = _ensure_dashboard_elements(metrics_container, decisions_container)
|
||||
if st.sidebar.button("清除数据告警", key="clear_data_alerts"):
|
||||
alerts.clear_warnings()
|
||||
_update_dashboard_sidebar()
|
||||
if not _SIDEBAR_LISTENER_ATTACHED:
|
||||
register_llm_metrics_listener(_sidebar_metrics_listener)
|
||||
_SIDEBAR_LISTENER_ATTACHED = True
|
||||
@ -132,6 +136,23 @@ def _update_dashboard_sidebar(
|
||||
else:
|
||||
model_placeholder.info("暂无模型分布数据。")
|
||||
|
||||
warnings_placeholder = elements.get("warnings")
|
||||
if warnings_placeholder is not None:
|
||||
warnings_placeholder.empty()
|
||||
warnings = alerts.get_warnings()
|
||||
if warnings:
|
||||
lines = []
|
||||
for warning in warnings[-10:]:
|
||||
detail = warning.get("detail")
|
||||
appendix = f" {detail}" if detail else ""
|
||||
lines.append(
|
||||
f"- **{warning['source']}** {warning['message']}{appendix}"
|
||||
f"\n<small>{warning['timestamp']}</small>"
|
||||
)
|
||||
warnings_placeholder.markdown("\n".join(lines), unsafe_allow_html=True)
|
||||
else:
|
||||
warnings_placeholder.info("暂无数据告警。")
|
||||
|
||||
decisions = metrics.get("recent_decisions") or llm_recent_decisions(10)
|
||||
if decisions:
|
||||
lines = []
|
||||
@ -163,6 +184,8 @@ def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[s
|
||||
distribution_expander = metrics_container.expander("调用分布", expanded=False)
|
||||
provider_distribution = distribution_expander.empty()
|
||||
model_distribution = distribution_expander.empty()
|
||||
warnings_expander = metrics_container.expander("数据告警", expanded=False)
|
||||
warnings_placeholder = warnings_expander.empty()
|
||||
|
||||
decisions_container.subheader("最新决策")
|
||||
decisions_list = decisions_container.empty()
|
||||
@ -173,6 +196,7 @@ def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[s
|
||||
"metrics_completion": metrics_completion,
|
||||
"provider_distribution": provider_distribution,
|
||||
"model_distribution": model_distribution,
|
||||
"warnings": warnings_placeholder,
|
||||
"decisions_list": decisions_list,
|
||||
}
|
||||
return elements
|
||||
@ -1649,11 +1673,80 @@ def render_tests() -> None:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception("示例 TuShare 拉取失败", extra=LOG_EXTRA)
|
||||
st.error(f"拉取失败:{exc}")
|
||||
alerts.add_warning("TuShare", "示例拉取失败", str(exc))
|
||||
_update_dashboard_sidebar()
|
||||
|
||||
st.info("注意:TuShare 拉取依赖网络与 Token,若环境未配置将出现错误提示。")
|
||||
|
||||
st.divider()
|
||||
days = int(st.number_input("检查窗口(天数)", min_value=30, max_value=1095, value=365, step=30))
|
||||
|
||||
st.subheader("RSS 数据测试")
|
||||
st.write("用于验证 RSS 配置是否能够正常抓取新闻并写入数据库。")
|
||||
rss_url = st.text_input(
|
||||
"测试 RSS 地址",
|
||||
value="https://rsshub.app/cls/depth/1000",
|
||||
help="留空则使用默认配置的全部 RSS 来源。",
|
||||
).strip()
|
||||
rss_hours = int(
|
||||
st.number_input(
|
||||
"回溯窗口(小时)",
|
||||
min_value=1,
|
||||
max_value=168,
|
||||
value=24,
|
||||
step=6,
|
||||
help="仅抓取最近指定小时内的新闻。",
|
||||
)
|
||||
)
|
||||
rss_limit = int(
|
||||
st.number_input(
|
||||
"单源抓取条数",
|
||||
min_value=1,
|
||||
max_value=200,
|
||||
value=50,
|
||||
step=10,
|
||||
)
|
||||
)
|
||||
if st.button("运行 RSS 测试"):
|
||||
from app.ingest import rss as rss_ingest
|
||||
|
||||
LOGGER.info(
|
||||
"点击 RSS 测试按钮 rss_url=%s hours=%s limit=%s",
|
||||
rss_url,
|
||||
rss_hours,
|
||||
rss_limit,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
with st.spinner("正在抓取 RSS 新闻..."):
|
||||
try:
|
||||
if rss_url:
|
||||
items = rss_ingest.fetch_rss_feed(
|
||||
rss_url,
|
||||
hours_back=rss_hours,
|
||||
max_items=rss_limit,
|
||||
)
|
||||
count = rss_ingest.save_news_items(items)
|
||||
else:
|
||||
count = rss_ingest.ingest_configured_rss(
|
||||
hours_back=rss_hours,
|
||||
max_items_per_feed=rss_limit,
|
||||
)
|
||||
st.success(f"RSS 测试完成,新增 {count} 条新闻记录。")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception("RSS 测试失败", extra=LOG_EXTRA)
|
||||
st.error(f"RSS 测试失败:{exc}")
|
||||
alerts.add_warning("RSS", "RSS 测试执行失败", str(exc))
|
||||
_update_dashboard_sidebar()
|
||||
|
||||
st.divider()
|
||||
days = int(
|
||||
st.number_input(
|
||||
"检查窗口(天数)",
|
||||
min_value=30,
|
||||
max_value=10950,
|
||||
value=365,
|
||||
step=30,
|
||||
)
|
||||
)
|
||||
LOGGER.debug("检查窗口天数=%s", days, extra=LOG_EXTRA)
|
||||
cfg = get_config()
|
||||
force_refresh = st.checkbox(
|
||||
@ -1666,8 +1759,8 @@ def render_tests() -> None:
|
||||
LOGGER.info("更新 force_refresh=%s", force_refresh, extra=LOG_EXTRA)
|
||||
save_config()
|
||||
|
||||
if st.button("执行开机检查"):
|
||||
LOGGER.info("点击执行开机检查按钮", extra=LOG_EXTRA)
|
||||
if st.button("执行手动数据同步"):
|
||||
LOGGER.info("点击执行手动数据同步按钮", extra=LOG_EXTRA)
|
||||
progress_bar = st.progress(0.0)
|
||||
status_placeholder = st.empty()
|
||||
log_placeholder = st.empty()
|
||||
@ -1677,23 +1770,25 @@ def render_tests() -> None:
|
||||
progress_bar.progress(min(max(value, 0.0), 1.0))
|
||||
status_placeholder.write(message)
|
||||
messages.append(message)
|
||||
LOGGER.debug("开机检查进度:%s -> %.2f", message, value, extra=LOG_EXTRA)
|
||||
LOGGER.debug("手动数据同步进度:%s -> %.2f", message, value, extra=LOG_EXTRA)
|
||||
|
||||
with st.spinner("正在执行开机检查..."):
|
||||
with st.spinner("正在执行手动数据同步..."):
|
||||
try:
|
||||
report = run_boot_check(
|
||||
days=days,
|
||||
progress_hook=hook,
|
||||
force_refresh=force_refresh,
|
||||
)
|
||||
LOGGER.info("开机检查成功", extra=LOG_EXTRA)
|
||||
st.success("开机检查完成,以下为数据覆盖摘要。")
|
||||
LOGGER.info("手动数据同步成功", extra=LOG_EXTRA)
|
||||
st.success("手动数据同步完成,以下为数据覆盖摘要。")
|
||||
st.json(report.to_dict())
|
||||
if messages:
|
||||
log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception("开机检查失败", extra=LOG_EXTRA)
|
||||
st.error(f"开机检查失败:{exc}")
|
||||
LOGGER.exception("手动数据同步失败", extra=LOG_EXTRA)
|
||||
st.error(f"手动数据同步失败:{exc}")
|
||||
alerts.add_warning("数据同步", "手动数据同步失败", str(exc))
|
||||
_update_dashboard_sidebar()
|
||||
if messages:
|
||||
log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages))
|
||||
finally:
|
||||
@ -1844,7 +1939,7 @@ def render_tests() -> None:
|
||||
|
||||
def main() -> None:
|
||||
LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA)
|
||||
st.set_page_config(page_title="多智能体投资助理", layout="wide")
|
||||
st.set_page_config(page_title="多智能体个人投资助理", layout="wide")
|
||||
render_global_dashboard()
|
||||
tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"])
|
||||
LOGGER.debug("Tabs 初始化完成:%s", ["今日计划", "回测与复盘", "数据与设置", "自检测试"], extra=LOG_EXTRA)
|
||||
|
||||
55
app/utils/alerts.py
Normal file
@ -0,0 +1,55 @@
|
||||
"""Runtime data warning registry for surfacing ingestion issues in UI."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
_ALERTS: List[Dict[str, str]] = []
|
||||
_LOCK = Lock()
|
||||
|
||||
|
||||
def add_warning(source: str, message: str, detail: Optional[str] = None) -> None:
|
||||
"""Register or update a warning entry."""
|
||||
|
||||
source = source.strip() or "unknown"
|
||||
message = message.strip() or "发生未知异常"
|
||||
timestamp = datetime.utcnow().isoformat(timespec="seconds") + "Z"
|
||||
|
||||
with _LOCK:
|
||||
for alert in _ALERTS:
|
||||
if alert["source"] == source and alert["message"] == message:
|
||||
alert["timestamp"] = timestamp
|
||||
if detail:
|
||||
alert["detail"] = detail
|
||||
return
|
||||
entry = {
|
||||
"source": source,
|
||||
"message": message,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
if detail:
|
||||
entry["detail"] = detail
|
||||
_ALERTS.append(entry)
|
||||
if len(_ALERTS) > 50:
|
||||
del _ALERTS[:-50]
|
||||
|
||||
|
||||
def get_warnings() -> List[Dict[str, str]]:
|
||||
"""Return a copy of current warning entries."""
|
||||
|
||||
with _LOCK:
|
||||
return list(_ALERTS)
|
||||
|
||||
|
||||
def clear_warnings(source: Optional[str] = None) -> None:
|
||||
"""Clear warnings entirely or for a specific source."""
|
||||
|
||||
with _LOCK:
|
||||
if source is None:
|
||||
_ALERTS.clear()
|
||||
return
|
||||
source = source.strip()
|
||||
_ALERTS[:] = [alert for alert in _ALERTS if alert["source"] != source]
|
||||
|
||||
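A minimal usage sketch for this registry, matching how the ingestion and dashboard code in this commit call it; the source label and detail are arbitrary examples:

from app.utils import alerts

alerts.add_warning("RSS|https://example.com/feed.xml", "源请求失败", "HTTP 503")
for warning in alerts.get_warnings():
    print(warning["timestamp"], warning["source"], warning["message"], warning.get("detail", ""))
alerts.clear_warnings("RSS|https://example.com/feed.xml")  # or clear_warnings() to wipe all entries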
@ -334,7 +334,7 @@ class AppConfig:
|
||||
"""User configurable settings persisted in a simple structure."""
|
||||
|
||||
tushare_token: Optional[str] = None
|
||||
rss_sources: Dict[str, bool] = field(default_factory=dict)
|
||||
rss_sources: Dict[str, object] = field(default_factory=dict)
|
||||
decision_method: str = "nash"
|
||||
data_paths: DataPaths = field(default_factory=DataPaths)
|
||||
agent_weights: AgentWeights = field(default_factory=AgentWeights)
|
||||
@ -402,6 +402,14 @@ def _load_from_file(cfg: AppConfig) -> None:
|
||||
if "decision_method" in payload:
|
||||
cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
|
||||
|
||||
rss_payload = payload.get("rss_sources")
|
||||
if isinstance(rss_payload, dict):
|
||||
resolved_rss: Dict[str, object] = {}
|
||||
for key, value in rss_payload.items():
|
||||
if isinstance(value, (bool, dict)):
|
||||
resolved_rss[str(key)] = value
|
||||
cfg.rss_sources = resolved_rss
|
||||
|
||||
weights_payload = payload.get("agent_weights")
|
||||
if isinstance(weights_payload, dict):
|
||||
cfg.agent_weights.update_from_dict(weights_payload)
|
||||
@ -572,6 +580,7 @@ def save_config(cfg: AppConfig | None = None) -> None:
|
||||
"tushare_token": cfg.tushare_token,
|
||||
"force_refresh": cfg.force_refresh,
|
||||
"decision_method": cfg.decision_method,
|
||||
"rss_sources": cfg.rss_sources,
|
||||
"agent_weights": cfg.agent_weights.as_dict(),
|
||||
"llm": {
|
||||
"strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
|
||||
|
||||
@ -4,3 +4,5 @@ streamlit>=1.30
tushare>=1.2
requests>=2.31
python-box>=7.0
pytest>=7.0
feedparser>=6.0
89
tests/test_rss_ingest.py
Normal file
@ -0,0 +1,89 @@
|
||||
"""Tests for RSS ingestion utilities."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
from email.utils import format_datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from app.ingest import rss
|
||||
from app.utils import alerts
|
||||
from app.utils.config import DataPaths, get_config
|
||||
from app.utils.db import db_session
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def isolated_db(tmp_path):
|
||||
"""Temporarily redirect database paths for isolated writes."""
|
||||
|
||||
cfg = get_config()
|
||||
original_paths = cfg.data_paths
|
||||
tmp_root = tmp_path / "data"
|
||||
tmp_root.mkdir(parents=True, exist_ok=True)
|
||||
cfg.data_paths = DataPaths(root=tmp_root)
|
||||
alerts.clear_warnings()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
cfg.data_paths = original_paths
|
||||
|
||||
|
||||
def test_fetch_rss_feed_parses_entries(monkeypatch):
    # Use a dynamic timestamp so the 24-hour window check never goes stale.
    recent_pub = format_datetime(datetime.now(timezone.utc))
    sample_feed = (
        f"""
|
||||
<rss version="2.0">
|
||||
<channel>
|
||||
<title>Example</title>
|
||||
<item>
|
||||
<title>新闻:公司利好公告</title>
|
||||
<link>https://example.com/a</link>
|
||||
<description><![CDATA[内容包含 000001.SZ ]]></description>
|
||||
        <pubDate>{recent_pub}</pubDate>
|
||||
<guid>a</guid>
|
||||
</item>
|
||||
</channel>
|
||||
</rss>
|
||||
"""
|
||||
).encode("utf-8")
|
||||
|
||||
monkeypatch.setattr(
|
||||
rss,
|
||||
"_download_feed",
|
||||
lambda url, timeout, max_retries, retry_backoff, retry_jitter: sample_feed,
|
||||
)
|
||||
|
||||
items = rss.fetch_rss_feed("https://example.com/rss", hours_back=24)
|
||||
|
||||
assert len(items) == 1
|
||||
item = items[0]
|
||||
assert item.title.startswith("新闻")
|
||||
assert item.source == "example.com"
|
||||
|
||||
|
||||
def test_save_news_items_writes_and_deduplicates(isolated_db):
|
||||
published = datetime.utcnow() - timedelta(hours=1)
|
||||
rss_item = rss.RssItem(
|
||||
id="test-id",
|
||||
title="利好消息推动股价",
|
||||
link="https://example.com/news/test",
|
||||
published=published,
|
||||
summary="这是一条利好消息。",
|
||||
source="测试来源",
|
||||
ts_codes=("000001.SZ",),
|
||||
)
|
||||
|
||||
inserted = rss.save_news_items([rss_item])
|
||||
assert inserted >= 1
|
||||
|
||||
with db_session(read_only=True) as conn:
|
||||
row = conn.execute(
|
||||
"SELECT ts_code, sentiment, heat FROM news WHERE id = ?",
|
||||
("test-id::000001.SZ",),
|
||||
).fetchone()
|
||||
assert row is not None
|
||||
assert row["ts_code"] == "000001.SZ"
|
||||
assert row["sentiment"] >= 0 # 利好关键词应给出非负情绪
|
||||
assert 0 <= row["heat"] <= 1
|
||||
|
||||
# 再次保存同一条新闻应被忽略
|
||||
duplicate = rss.save_news_items([rss_item])
|
||||
assert duplicate == 0
|
||||