This commit is contained in:
sam 2025-09-30 15:42:27 +08:00
parent 8083c9ffab
commit 5228ea1c41
8 changed files with 1109 additions and 30 deletions

View File

@ -1,8 +1,8 @@
# Multi-Agent Investment Assistant Skeleton
# Multi-Agent Personal Investment Assistant
## Project Overview
This repository provides a prototype multi-agent investment assistant for daily-frequency A-share data, covering data ingestion, feature extraction, strategy gaming, backtest presentation, and an LLM explanation pipeline. The code is laid out as a modular skeleton so an end-to-end quantitative research and visual decision workflow can be assembled quickly on a single machine.
This repository provides a prototype multi-agent personal investment assistant for daily-frequency A-share data, covering data ingestion, feature extraction, strategy gaming, backtest presentation, and an LLM explanation pipeline. The code is laid out as a modular skeleton so an end-to-end quantitative research and visual decision workflow can be assembled quickly on a single machine.
## Architecture Overview

View File

@ -1,25 +1,237 @@
"""RSS ingestion for news and heat scores."""
"""RSS ingestion utilities for news sentiment and heat scoring."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
from typing import Iterable, List
import json
import re
import sqlite3
from dataclasses import dataclass, replace
from datetime import datetime, timedelta, timezone
from email.utils import parsedate_to_datetime
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
from urllib.parse import urlparse, urljoin
from xml.etree import ElementTree as ET
import requests
from requests import RequestException
import hashlib
import random
import time
try: # pragma: no cover - optional dependency at runtime
import feedparser # type: ignore[import-not-found]
except ImportError: # pragma: no cover - graceful fallback
feedparser = None # type: ignore[assignment]
from app.data.schema import initialize_database
from app.utils import alerts
from app.utils.config import get_config
from app.utils.db import db_session
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "rss_ingest"}
DEFAULT_TIMEOUT = 10.0
MAX_SUMMARY_LENGTH = 1500
POSITIVE_KEYWORDS: Tuple[str, ...] = (
"利好",
"增长",
"超预期",
"创新高",
"增持",
"回购",
"盈利",
"strong",
"beat",
"upgrade",
)
NEGATIVE_KEYWORDS: Tuple[str, ...] = (
"利空",
"下跌",
"亏损",
"裁员",
"违约",
"处罚",
"暴跌",
"减持",
"downgrade",
"miss",
)
A_SH_CODE_PATTERN = re.compile(r"\b(\d{6})(?:\.(SH|SZ))?\b", re.IGNORECASE)
HK_CODE_PATTERN = re.compile(r"\b(\d{4})\.HK\b", re.IGNORECASE)
@dataclass
class RssFeedConfig:
"""Configuration describing a single RSS source."""
url: str
source: str
ts_codes: Tuple[str, ...] = ()
keywords: Tuple[str, ...] = ()
hours_back: int = 48
max_items: int = 50
@dataclass
class RssItem:
"""Structured representation of an RSS entry."""
id: str
title: str
link: str
published: datetime
summary: str
source: str
ts_codes: Tuple[str, ...] = ()
def fetch_rss_feed(url: str) -> List[RssItem]:
DEFAULT_RSS_SOURCES: Tuple[RssFeedConfig, ...] = ()
def fetch_rss_feed(
url: str,
*,
source: Optional[str] = None,
hours_back: int = 48,
max_items: int = 50,
timeout: float = DEFAULT_TIMEOUT,
max_retries: int = 5,
retry_backoff: float = 1.5,
retry_jitter: float = 0.3,
) -> List[RssItem]:
"""Download and parse an RSS feed into structured items."""
raise NotImplementedError
return _fetch_feed_items(
url,
source=source,
hours_back=hours_back,
max_items=max_items,
timeout=timeout,
max_retries=max_retries,
retry_backoff=retry_backoff,
retry_jitter=retry_jitter,
allow_html_redirect=True,
)
def _fetch_feed_items(
url: str,
*,
source: Optional[str],
hours_back: int,
max_items: int,
timeout: float,
max_retries: int,
retry_backoff: float,
retry_jitter: float,
allow_html_redirect: bool,
) -> List[RssItem]:
content = _download_feed(
url,
timeout,
max_retries=max_retries,
retry_backoff=retry_backoff,
retry_jitter=retry_jitter,
)
if content is None:
return []
if allow_html_redirect:
feed_links = _extract_html_feed_links(content, url)
if feed_links:
LOGGER.info(
"RSS 页面包含子订阅 %s 个,自动展开",
len(feed_links),
extra=LOG_EXTRA,
)
aggregated: List[RssItem] = []
for feed_url in feed_links:
sub_items = _fetch_feed_items(
feed_url,
source=source,
hours_back=hours_back,
max_items=max_items,
timeout=timeout,
max_retries=max_retries,
retry_backoff=retry_backoff,
retry_jitter=retry_jitter,
allow_html_redirect=False,
)
aggregated.extend(sub_items)
if max_items > 0 and len(aggregated) >= max_items:
return aggregated[:max_items]
if aggregated:
alerts.clear_warnings(_rss_source_key(url))
else:
alerts.add_warning(
_rss_source_key(url),
"聚合页未返回内容",
)
return aggregated
parsed_entries = _parse_feed_content(content)
total_entries = len(parsed_entries)
LOGGER.info(
"RSS 源获取完成 url=%s raw_entries=%s",
url,
total_entries,
extra=LOG_EXTRA,
)
if not parsed_entries:
LOGGER.warning(
"RSS 无可解析条目 url=%s snippet=%s",
url,
_safe_snippet(content),
extra=LOG_EXTRA,
)
return []
cutoff = datetime.utcnow() - timedelta(hours=max(1, hours_back))
source_name = source or _source_from_url(url)
items: List[RssItem] = []
seen_ids: set[str] = set()
for entry in parsed_entries:
published = entry.get("published") or datetime.utcnow()
if published < cutoff:
continue
title = _clean_text(entry.get("title", ""))
summary = _clean_text(entry.get("summary", ""))
link = entry.get("link", "")
raw_id = entry.get("id") or link
item_id = _normalise_item_id(raw_id, link, title, published)
if item_id in seen_ids:
continue
seen_ids.add(item_id)
items.append(
RssItem(
id=item_id,
title=title,
link=link,
published=published,
summary=_truncate(summary, MAX_SUMMARY_LENGTH),
source=source_name,
)
)
if len(items) >= max_items > 0:
break
LOGGER.info(
"RSS 过滤结果 url=%s within_window=%s unique=%s",
url,
sum(1 for entry in parsed_entries if (entry.get("published") or datetime.utcnow()) >= cutoff),
len(items),
extra=LOG_EXTRA,
)
if items:
alerts.clear_warnings(_rss_source_key(url))
return items
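# Illustrative usage sketch (assumes network access; the URL is a placeholder
# reused from the UI defaults): fetch one feed and print the surviving headlines.
def _example_fetch_single_feed() -> None:
    items = fetch_rss_feed(
        "https://rsshub.app/cls/depth/1000",
        hours_back=24,
        max_items=20,
    )
    for item in items:
        print(item.published.isoformat(), item.source, item.title)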
def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
@ -36,7 +248,617 @@ def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
return unique
def save_news_items(items: Iterable[RssItem]) -> None:
def save_news_items(items: Iterable[RssItem]) -> int:
"""Persist RSS items into the `news` table."""
raise NotImplementedError
initialize_database()
now = datetime.utcnow()
rows: List[Tuple[object, ...]] = []
processed = 0
for item in items:
text_payload = f"{item.title}\n{item.summary}"
sentiment = _estimate_sentiment(text_payload)
base_codes = tuple(code for code in item.ts_codes if code)
heat = _estimate_heat(item.published, now, len(base_codes), sentiment)
entities = json.dumps(
{
"ts_codes": list(base_codes),
"source_url": item.link,
},
ensure_ascii=False,
)
resolved_codes = base_codes or (None,)
for ts_code in resolved_codes:
row_id = item.id if ts_code is None else f"{item.id}::{ts_code}"
rows.append(
(
row_id,
ts_code,
item.published.replace(tzinfo=timezone.utc).isoformat(),
item.source,
item.title,
item.summary,
item.link,
entities,
sentiment,
heat,
)
)
processed += 1
if not rows:
return 0
inserted = 0
try:
with db_session() as conn:
conn.executemany(
"""
INSERT OR IGNORE INTO news
(id, ts_code, pub_time, source, title, summary, url, entities, sentiment, heat)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
rows,
)
inserted = conn.total_changes
except sqlite3.OperationalError:
LOGGER.exception("写入新闻数据失败,表结构可能未初始化", extra=LOG_EXTRA)
return 0
except Exception: # pragma: no cover - guard unexpected sqlite errors
LOGGER.exception("写入新闻数据异常", extra=LOG_EXTRA)
return 0
LOGGER.info(
"RSS 新闻落库完成 processed=%s inserted=%s",
processed,
inserted,
extra=LOG_EXTRA,
)
return inserted
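# Persistence layout sketch for save_news_items above: an item with id "abc"
# and ts_codes=("000001.SZ", "600519.SH") fans out into two rows keyed
# "abc::000001.SZ" and "abc::600519.SH", while an item with no codes is stored
# once under its bare id with a NULL ts_code. INSERT OR IGNORE keeps re-runs
# idempotent, so previously stored ids are left untouched.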
def ingest_configured_rss(
*,
hours_back: Optional[int] = None,
max_items_per_feed: Optional[int] = None,
max_retries: int = 5,
retry_backoff: float = 2.0,
retry_jitter: float = 0.5,
) -> int:
"""Ingest all configured RSS feeds into the news store."""
configs = resolve_rss_sources()
if not configs:
LOGGER.info("未配置 RSS 来源,跳过新闻拉取", extra=LOG_EXTRA)
return 0
aggregated: List[RssItem] = []
fetched_count = 0
for index, cfg in enumerate(configs, start=1):
window = hours_back or cfg.hours_back
limit = max_items_per_feed or cfg.max_items
LOGGER.info(
"开始拉取 RSS%s (window=%sh, limit=%s)",
cfg.url,
window,
limit,
extra=LOG_EXTRA,
)
items = fetch_rss_feed(
cfg.url,
source=cfg.source,
hours_back=window,
max_items=limit,
max_retries=max_retries,
retry_backoff=retry_backoff,
retry_jitter=retry_jitter,
)
if not items:
LOGGER.info("RSS 来源无新内容:%s", cfg.url, extra=LOG_EXTRA)
continue
enriched: List[RssItem] = []
for item in items:
codes = _assign_ts_codes(item, cfg.ts_codes, cfg.keywords)
enriched.append(replace(item, ts_codes=tuple(codes)))
aggregated.extend(enriched)
fetched_count += len(enriched)
if fetched_count and index < len(configs):
time.sleep(2.0)
if not aggregated:
LOGGER.info("RSS 来源未产生有效新闻", extra=LOG_EXTRA)
alerts.add_warning("RSS", "未获取到任何 RSS 新闻")
return 0
deduped = deduplicate_items(aggregated)
LOGGER.info(
"RSS 聚合完成 total_fetched=%s unique=%s",
fetched_count,
len(deduped),
extra=LOG_EXTRA,
)
return save_news_items(deduped)
def resolve_rss_sources() -> List[RssFeedConfig]:
"""Resolve RSS feed configuration from persisted settings."""
cfg = get_config()
raw = getattr(cfg, "rss_sources", None) or {}
feeds: Dict[str, RssFeedConfig] = {}
def _add_feed(url: str, **kwargs: object) -> None:
clean_url = url.strip()
if not clean_url:
return
key = clean_url.lower()
if key in feeds:
return
source_name = kwargs.get("source") or _source_from_url(clean_url)
feeds[key] = RssFeedConfig(
url=clean_url,
source=str(source_name),
ts_codes=tuple(kwargs.get("ts_codes", ()) or ()),
keywords=tuple(kwargs.get("keywords", ()) or ()),
hours_back=int(kwargs.get("hours_back", 48) or 48),
max_items=int(kwargs.get("max_items", 50) or 50),
)
if isinstance(raw, dict):
for key, value in raw.items():
if isinstance(value, dict):
if not value.get("enabled", True):
continue
url = str(value.get("url") or key)
ts_codes = [
str(code).strip().upper()
for code in value.get("ts_codes", [])
if str(code).strip()
]
keywords = [
str(token).strip()
for token in value.get("keywords", [])
if str(token).strip()
]
_add_feed(
url,
ts_codes=ts_codes,
keywords=keywords,
hours_back=value.get("hours_back", 48),
max_items=value.get("max_items", 50),
source=value.get("source") or value.get("label"),
)
continue
if not value:
continue
url = key
ts_codes: List[str] = []
if "|" in key:
prefix, url = key.split("|", 1)
ts_codes = [
token.strip().upper()
for token in prefix.replace(",", ":").split(":")
if token.strip()
]
_add_feed(url, ts_codes=ts_codes)
if feeds:
return list(feeds.values())
return list(DEFAULT_RSS_SOURCES)
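# Illustrative config shapes accepted by resolve_rss_sources(); the URLs,
# labels, codes and keywords below are assumptions for demonstration only.
_EXAMPLE_RSS_SOURCES: Dict[str, object] = {
    # bare enable flag: defaults apply (48h window, 50 items)
    "https://example.com/markets.xml": True,
    # "CODES|url" shorthand: the prefixed codes are attached to every entry
    "600519.SH|https://example.com/company.xml": True,
    # full per-feed configuration
    "cls_depth": {
        "url": "https://rsshub.app/cls/depth/1000",
        "source": "财联社深度",
        "ts_codes": ["000001.SZ"],
        "keywords": ["银行"],
        "hours_back": 24,
        "max_items": 30,
    },
}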
def _download_feed(
url: str,
timeout: float,
*,
max_retries: int,
retry_backoff: float,
retry_jitter: float,
) -> Optional[bytes]:
headers = {
"User-Agent": "llm-quant/0.1 (+https://github.com/qiang/llm_quant)",
"Accept": "application/rss+xml, application/atom+xml, application/xml;q=0.9, */*;q=0.8",
}
attempt = 0
delay = max(0.5, retry_backoff)
while attempt <= max_retries:
try:
response = requests.get(url, headers=headers, timeout=timeout)
except RequestException as exc:
attempt += 1
if attempt > max_retries:
message = f"源请求失败:{url}"
LOGGER.warning("RSS 请求失败:%s err=%s", url, exc, extra=LOG_EXTRA)
alerts.add_warning(_rss_source_key(url), message, str(exc))
return None
wait = delay + random.uniform(0, retry_jitter)
LOGGER.info(
"RSS 请求异常,%.2f 秒后重试 url=%s attempt=%s/%s",
wait,
url,
attempt,
max_retries,
extra=LOG_EXTRA,
)
time.sleep(max(wait, 0.1))
delay *= max(1.1, retry_backoff)
continue
status = response.status_code
if 200 <= status < 300:
return response.content
if status in {429, 503}:
attempt += 1
if attempt > max_retries:
LOGGER.warning(
"RSS 请求失败:%s status=%s 已达到最大重试次数",
url,
status,
extra=LOG_EXTRA,
)
alerts.add_warning(
_rss_source_key(url),
"源限流",
f"HTTP {status}",
)
return None
retry_after = response.headers.get("Retry-After")
if retry_after:
try:
wait = float(retry_after)
except ValueError:
wait = delay
else:
wait = delay
wait += random.uniform(0, retry_jitter)
LOGGER.info(
"RSS 命中限流 status=%s%.2f 秒后重试 url=%s attempt=%s/%s",
status,
wait,
url,
attempt,
max_retries,
extra=LOG_EXTRA,
)
time.sleep(max(wait, 0.1))
delay *= max(1.1, retry_backoff)
continue
LOGGER.warning(
"RSS 请求失败:%s status=%s",
url,
status,
extra=LOG_EXTRA,
)
alerts.add_warning(
_rss_source_key(url),
"源响应异常",
f"HTTP {status}",
)
return None
LOGGER.warning("RSS 请求失败:%s 未获取内容", url, extra=LOG_EXTRA)
alerts.add_warning(_rss_source_key(url), "未获取内容")
return None
def _extract_html_feed_links(content: bytes, base_url: str) -> List[str]:
sample = content[:1024].lower()
if b"<rss" in sample or b"<feed" in sample:
return []
for encoding in ("utf-8", "gb18030", "gb2312"):
try:
text = content.decode(encoding)
break
except UnicodeDecodeError:
continue  # try the next candidate encoding before falling back
else:
text = content.decode("utf-8", errors="ignore")
if "<link" not in text and ".xml" not in text:
return []
feed_urls: List[str] = []
alternates = re.compile(
r"<link[^>]+rel=[\"']alternate[\"'][^>]+type=[\"']application/(?:rss|atom)\+xml[\"'][^>]*href=[\"']([^\"']+)[\"']",
re.IGNORECASE,
)
for match in alternates.finditer(text):
href = match.group(1).strip()
if href:
feed_urls.append(urljoin(base_url, href))
if not feed_urls:
anchors = re.compile(r"href=[\"']([^\"']+\.xml)[\"']", re.IGNORECASE)
for match in anchors.finditer(text):
href = match.group(1).strip()
if href:
feed_urls.append(urljoin(base_url, href))
unique_urls: List[str] = []
seen = set()
for href in feed_urls:
if href not in seen and href != base_url:
seen.add(href)
unique_urls.append(href)
return unique_urls
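# Feed auto-discovery above targets tags of the form
# <link rel="alternate" type="application/rss+xml" href="/feed.xml" />
# and falls back to any href="....xml" anchor when no alternate links exist;
# relative hrefs are resolved against the page URL via urljoin.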
def _safe_snippet(content: bytes, limit: int = 160) -> str:
try:
text = content.decode("utf-8")
except UnicodeDecodeError:
try:
text = content.decode("gb18030", errors="ignore")
except UnicodeDecodeError:
text = content.decode("latin-1", errors="ignore")
cleaned = re.sub(r"\s+", " ", text)
if len(cleaned) > limit:
return cleaned[: limit - 3] + "..."
return cleaned
def _parse_feed_content(content: bytes) -> List[Dict[str, object]]:
if feedparser is not None:
parsed = feedparser.parse(content)
entries = []
for entry in getattr(parsed, "entries", []) or []:
entries.append(
{
"id": getattr(entry, "id", None) or getattr(entry, "guid", None),
"title": getattr(entry, "title", ""),
"link": getattr(entry, "link", ""),
"summary": getattr(entry, "summary", "") or getattr(entry, "description", ""),
"published": _parse_datetime(
getattr(entry, "published", None)
or getattr(entry, "updated", None)
or getattr(entry, "issued", None)
),
}
)
if entries:
return entries
else: # pragma: no cover - log helpful info when dependency missing
LOGGER.warning(
"feedparser 未安装,使用简易 XML 解析器回退处理 RSS",
extra=LOG_EXTRA,
)
return _parse_feed_xml(content)
def _parse_feed_xml(content: bytes) -> List[Dict[str, object]]:
try:
xml_text = content.decode("utf-8")
except UnicodeDecodeError:
xml_text = content.decode("utf-8", errors="ignore")
try:
root = ET.fromstring(xml_text)
except ET.ParseError as exc: # pragma: no cover - depends on remote feed
LOGGER.warning("RSS XML 解析失败 err=%s", exc, extra=LOG_EXTRA)
return _lenient_parse_items(xml_text)
tag = _local_name(root.tag)
if tag == "rss":
candidates = root.findall(".//item")
elif tag == "feed":
candidates = root.findall(".//{*}entry")
else: # fallback
candidates = root.findall(".//item") or root.findall(".//{*}entry")
entries: List[Dict[str, object]] = []
for node in candidates:
entries.append(
{
"id": _child_text(node, {"id", "guid"}),
"title": _child_text(node, {"title"}) or "",
"link": _child_text(node, {"link"}) or "",
"summary": _child_text(node, {"summary", "description"}) or "",
"published": _parse_datetime(
_child_text(node, {"pubDate", "published", "updated"})
),
}
)
if not entries and "<item" in xml_text.lower():
return _lenient_parse_items(xml_text)
return entries
def _lenient_parse_items(xml_text: str) -> List[Dict[str, object]]:
"""Fallback parser that tolerates malformed RSS by using regular expressions."""
items: List[Dict[str, object]] = []
pattern = re.compile(r"<(item|entry)[^>]*>(.+?)</\\1>", re.IGNORECASE | re.DOTALL)
for match in pattern.finditer(xml_text):
block = match.group(0)
title = _extract_tag_text(block, ["title"]) or ""
link = _extract_link(block)
summary = _extract_tag_text(block, ["summary", "description"]) or ""
published_text = _extract_tag_text(block, ["pubDate", "published", "updated"])
items.append(
{
"id": _extract_tag_text(block, ["id", "guid"]) or link,
"title": title,
"link": link,
"summary": summary,
"published": _parse_datetime(published_text),
}
)
if items:
LOGGER.info("RSS 采用宽松解析提取 %s 条记录", len(items), extra=LOG_EXTRA)
return items
def _extract_tag_text(block: str, names: Sequence[str]) -> Optional[str]:
for name in names:
pattern = re.compile(rf"<{name}[^>]*>(.*?)</{name}>", re.IGNORECASE | re.DOTALL)
match = pattern.search(block)
if match:
text = re.sub(r"<[^>]+>", " ", match.group(1))
return _clean_text(text)
return None
def _extract_link(block: str) -> str:
href_pattern = re.compile(r"<link[^>]*href=\"([^\"]+)\"[^>]*>", re.IGNORECASE)
match = href_pattern.search(block)
if match:
return match.group(1).strip()
inline_pattern = re.compile(r"<link[^>]*>(.*?)</link>", re.IGNORECASE | re.DOTALL)
match = inline_pattern.search(block)
if match:
return match.group(1).strip()
return ""
def _assign_ts_codes(
item: RssItem,
base_codes: Sequence[str],
keywords: Sequence[str],
) -> List[str]:
matches: set[str] = set()
text = f"{item.title} {item.summary}".lower()
if keywords:
for keyword in keywords:
token = keyword.lower().strip()
if token and token in text:
matches.update(code.strip().upper() for code in base_codes if code)
break
else:
matches.update(code.strip().upper() for code in base_codes if code)
detected = _detect_ts_codes(text)
matches.update(detected)
return [code for code in matches if code]
def _detect_ts_codes(text: str) -> List[str]:
codes: set[str] = set()
for match in A_SH_CODE_PATTERN.finditer(text):
digits, suffix = match.groups()
if suffix:
codes.add(f"{digits}.{suffix.upper()}")
else:
exchange = "SH" if digits.startswith(tuple("569")) else "SZ"
codes.add(f"{digits}.{exchange}")
for match in HK_CODE_PATTERN.finditer(text):
digits = match.group(1)
codes.add(f"{digits.zfill(4)}.HK")
return sorted(codes)
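# Expected behaviour of the detector above (illustrative inputs):
#   "公告提到 600519 与 000001.SZ" -> ["000001.SZ", "600519.SH"]
#   "0700.HK 收盘上涨"             -> ["0700.HK"]
# Bare six-digit codes are mapped to SH when they start with 5/6/9 and to SZ
# otherwise; this is a heuristic rather than an exchange lookup.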
def _estimate_sentiment(text: str) -> float:
normalized = text.lower()
score = 0
for keyword in POSITIVE_KEYWORDS:
if keyword.lower() in normalized:
score += 1
for keyword in NEGATIVE_KEYWORDS:
if keyword.lower() in normalized:
score -= 1
if score == 0:
return 0.0
return max(-1.0, min(1.0, score / 3.0))
def _estimate_heat(
published: datetime,
now: datetime,
code_count: int,
sentiment: float,
) -> float:
delta_hours = max(0.0, (now - published).total_seconds() / 3600.0)
recency = max(0.0, 1.0 - min(delta_hours, 72.0) / 72.0)
coverage_bonus = min(code_count, 3) * 0.05
sentiment_bonus = min(abs(sentiment) * 0.1, 0.2)
heat = recency + coverage_bonus + sentiment_bonus
return max(0.0, min(1.0, round(heat, 4)))
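# Worked example for the heuristic above: a story published 36 hours ago,
# tagged with one ts_code and sentiment -0.33 ->
#   recency         = 1 - 36 / 72          = 0.5
#   coverage_bonus  = min(1, 3) * 0.05     = 0.05
#   sentiment_bonus = min(0.033, 0.2)      = 0.033
#   heat            = 0.5 + 0.05 + 0.033   ≈ 0.583 (clamped to [0, 1])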
def _parse_datetime(value: Optional[str]) -> Optional[datetime]:
if not value:
return None
try:
dt = parsedate_to_datetime(value)
if dt.tzinfo is not None:
dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
return dt
except (TypeError, ValueError):
pass
for fmt in ("%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S"):
try:
return datetime.strptime(value[:19], fmt)
except ValueError:
continue
return None
def _clean_text(value: Optional[str]) -> str:
if not value:
return ""
text = re.sub(r"<[^>]+>", " ", value)
return re.sub(r"\s+", " ", text).strip()
def _truncate(value: str, length: int) -> str:
if len(value) <= length:
return value
return value[: length - 3].rstrip() + "..."
def _normalise_item_id(
raw_id: Optional[str], link: str, title: str, published: datetime
) -> str:
candidate = (raw_id or link or title).strip()
if candidate:
return candidate
fingerprint = f"{title}|{published.isoformat()}"
return hashlib.blake2s(fingerprint.encode("utf-8"), digest_size=16).hexdigest()
def _source_from_url(url: str) -> str:
try:
parsed = urlparse(url)
except ValueError:
return url
host = parsed.netloc or url
return host.lower()
def _local_name(tag: str) -> str:
if "}" in tag:
return tag.rsplit("}", 1)[-1]
return tag
def _child_text(node: ET.Element, candidates: set[str]) -> Optional[str]:
for child in node:
name = _local_name(child.tag)
if name in candidates and child.text:
return child.text.strip()
if name == "link":
href = child.attrib.get("href")
if href:
return href.strip()
return None
__all__ = [
"RssFeedConfig",
"RssItem",
"fetch_rss_feed",
"deduplicate_items",
"save_news_items",
"ingest_configured_rss",
"resolve_rss_sources",
]
def _rss_source_key(url: str) -> str:
return f"RSS|{url}".strip()

View File

@ -16,6 +16,7 @@ try:
except ImportError: # pragma: no cover - 运行时提示
ts = None # type: ignore[assignment]
from app.utils import alerts
from app.utils.config import get_config
from app.utils.db import db_session
from app.data.schema import initialize_database
@ -1601,6 +1602,7 @@ def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]
def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
LOGGER.info("启动 TuShare 拉取任务:%s", job.name, extra=LOG_EXTRA)
try:
ensure_data_coverage(
job.start,
job.end,
@ -1609,4 +1611,9 @@ def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
include_extended=True,
force=True,
)
except Exception as exc:
alerts.add_warning("TuShare", f"拉取任务失败:{job.name}", str(exc))
raise
else:
alerts.clear_warnings("TuShare")
LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA)

View File

@ -36,6 +36,7 @@ from app.llm.metrics import (
reset as reset_llm_metrics,
snapshot as snapshot_llm_metrics,
)
from app.utils import alerts
from app.utils.config import (
ALLOWED_LLM_STRATEGIES,
DEFAULT_LLM_BASE_URLS,
@ -86,6 +87,9 @@ def render_global_dashboard() -> None:
decisions_container = st.sidebar.container()
_DASHBOARD_CONTAINERS = (metrics_container, decisions_container)
_DASHBOARD_ELEMENTS = _ensure_dashboard_elements(metrics_container, decisions_container)
if st.sidebar.button("清除数据告警", key="clear_data_alerts"):
alerts.clear_warnings()
_update_dashboard_sidebar()
if not _SIDEBAR_LISTENER_ATTACHED:
register_llm_metrics_listener(_sidebar_metrics_listener)
_SIDEBAR_LISTENER_ATTACHED = True
@ -132,6 +136,23 @@ def _update_dashboard_sidebar(
else:
model_placeholder.info("暂无模型分布数据。")
warnings_placeholder = elements.get("warnings")
if warnings_placeholder is not None:
warnings_placeholder.empty()
warnings = alerts.get_warnings()
if warnings:
lines = []
for warning in warnings[-10:]:
detail = warning.get("detail")
appendix = f" {detail}" if detail else ""
lines.append(
f"- **{warning['source']}** {warning['message']}{appendix}"
f"\n<small>{warning['timestamp']}</small>"
)
warnings_placeholder.markdown("\n".join(lines), unsafe_allow_html=True)
else:
warnings_placeholder.info("暂无数据告警。")
decisions = metrics.get("recent_decisions") or llm_recent_decisions(10)
if decisions:
lines = []
@ -163,6 +184,8 @@ def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[s
distribution_expander = metrics_container.expander("调用分布", expanded=False)
provider_distribution = distribution_expander.empty()
model_distribution = distribution_expander.empty()
warnings_expander = metrics_container.expander("数据告警", expanded=False)
warnings_placeholder = warnings_expander.empty()
decisions_container.subheader("最新决策")
decisions_list = decisions_container.empty()
@ -173,6 +196,7 @@ def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[s
"metrics_completion": metrics_completion,
"provider_distribution": provider_distribution,
"model_distribution": model_distribution,
"warnings": warnings_placeholder,
"decisions_list": decisions_list,
}
return elements
@ -1649,11 +1673,80 @@ def render_tests() -> None:
except Exception as exc: # noqa: BLE001
LOGGER.exception("示例 TuShare 拉取失败", extra=LOG_EXTRA)
st.error(f"拉取失败:{exc}")
alerts.add_warning("TuShare", "示例拉取失败", str(exc))
_update_dashboard_sidebar()
st.info("注意TuShare 拉取依赖网络与 Token若环境未配置将出现错误提示。")
st.divider()
days = int(st.number_input("检查窗口(天数)", min_value=30, max_value=1095, value=365, step=30))
st.subheader("RSS 数据测试")
st.write("用于验证 RSS 配置是否能够正常抓取新闻并写入数据库。")
rss_url = st.text_input(
"测试 RSS 地址",
value="https://rsshub.app/cls/depth/1000",
help="留空则使用默认配置的全部 RSS 来源。",
).strip()
rss_hours = int(
st.number_input(
"回溯窗口(小时)",
min_value=1,
max_value=168,
value=24,
step=6,
help="仅抓取最近指定小时内的新闻。",
)
)
rss_limit = int(
st.number_input(
"单源抓取条数",
min_value=1,
max_value=200,
value=50,
step=10,
)
)
if st.button("运行 RSS 测试"):
from app.ingest import rss as rss_ingest
LOGGER.info(
"点击 RSS 测试按钮 rss_url=%s hours=%s limit=%s",
rss_url,
rss_hours,
rss_limit,
extra=LOG_EXTRA,
)
with st.spinner("正在抓取 RSS 新闻..."):
try:
if rss_url:
items = rss_ingest.fetch_rss_feed(
rss_url,
hours_back=rss_hours,
max_items=rss_limit,
)
count = rss_ingest.save_news_items(items)
else:
count = rss_ingest.ingest_configured_rss(
hours_back=rss_hours,
max_items_per_feed=rss_limit,
)
st.success(f"RSS 测试完成,新增 {count} 条新闻记录。")
except Exception as exc: # noqa: BLE001
LOGGER.exception("RSS 测试失败", extra=LOG_EXTRA)
st.error(f"RSS 测试失败:{exc}")
alerts.add_warning("RSS", "RSS 测试执行失败", str(exc))
_update_dashboard_sidebar()
st.divider()
days = int(
st.number_input(
"检查窗口(天数)",
min_value=30,
max_value=10950,
value=365,
step=30,
)
)
LOGGER.debug("检查窗口天数=%s", days, extra=LOG_EXTRA)
cfg = get_config()
force_refresh = st.checkbox(
@ -1666,8 +1759,8 @@ def render_tests() -> None:
LOGGER.info("更新 force_refresh=%s", force_refresh, extra=LOG_EXTRA)
save_config()
if st.button("执行开机检查"):
LOGGER.info("点击执行开机检查按钮", extra=LOG_EXTRA)
if st.button("执行手动数据同步"):
LOGGER.info("点击执行手动数据同步按钮", extra=LOG_EXTRA)
progress_bar = st.progress(0.0)
status_placeholder = st.empty()
log_placeholder = st.empty()
@ -1677,23 +1770,25 @@ def render_tests() -> None:
progress_bar.progress(min(max(value, 0.0), 1.0))
status_placeholder.write(message)
messages.append(message)
LOGGER.debug("开机检查进度:%s -> %.2f", message, value, extra=LOG_EXTRA)
LOGGER.debug("手动数据同步进度:%s -> %.2f", message, value, extra=LOG_EXTRA)
with st.spinner("正在执行开机检查..."):
with st.spinner("正在执行手动数据同步..."):
try:
report = run_boot_check(
days=days,
progress_hook=hook,
force_refresh=force_refresh,
)
LOGGER.info("开机检查成功", extra=LOG_EXTRA)
st.success("开机检查完成,以下为数据覆盖摘要。")
LOGGER.info("手动数据同步成功", extra=LOG_EXTRA)
st.success("手动数据同步完成,以下为数据覆盖摘要。")
st.json(report.to_dict())
if messages:
log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages))
except Exception as exc: # noqa: BLE001
LOGGER.exception("开机检查失败", extra=LOG_EXTRA)
st.error(f"开机检查失败:{exc}")
LOGGER.exception("手动数据同步失败", extra=LOG_EXTRA)
st.error(f"手动数据同步失败:{exc}")
alerts.add_warning("数据同步", "手动数据同步失败", str(exc))
_update_dashboard_sidebar()
if messages:
log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages))
finally:
@ -1844,7 +1939,7 @@ def render_tests() -> None:
def main() -> None:
LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA)
st.set_page_config(page_title="多智能体投资助理", layout="wide")
st.set_page_config(page_title="多智能体个人投资助理", layout="wide")
render_global_dashboard()
tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"])
LOGGER.debug("Tabs 初始化完成:%s", ["今日计划", "回测与复盘", "数据与设置", "自检测试"], extra=LOG_EXTRA)

55
app/utils/alerts.py Normal file
View File

@ -0,0 +1,55 @@
"""Runtime data warning registry for surfacing ingestion issues in UI."""
from __future__ import annotations
from datetime import datetime
from threading import Lock
from typing import Dict, List, Optional
_ALERTS: List[Dict[str, str]] = []
_LOCK = Lock()
def add_warning(source: str, message: str, detail: Optional[str] = None) -> None:
"""Register or update a warning entry."""
source = source.strip() or "unknown"
message = message.strip() or "发生未知异常"
timestamp = datetime.utcnow().isoformat(timespec="seconds") + "Z"
with _LOCK:
for alert in _ALERTS:
if alert["source"] == source and alert["message"] == message:
alert["timestamp"] = timestamp
if detail:
alert["detail"] = detail
return
entry = {
"source": source,
"message": message,
"timestamp": timestamp,
}
if detail:
entry["detail"] = detail
_ALERTS.append(entry)
if len(_ALERTS) > 50:
del _ALERTS[:-50]
def get_warnings() -> List[Dict[str, str]]:
"""Return a copy of current warning entries."""
with _LOCK:
return list(_ALERTS)
def clear_warnings(source: Optional[str] = None) -> None:
"""Clear warnings entirely or for a specific source."""
with _LOCK:
if source is None:
_ALERTS.clear()
return
source = source.strip()
_ALERTS[:] = [alert for alert in _ALERTS if alert["source"] != source]
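# Illustrative usage sketch (hypothetical call site): ingestion code records a
# warning when a source fails and clears it once the source recovers; the
# Streamlit sidebar then renders whatever get_warnings() returns.
def _example_alert_roundtrip() -> None:
    add_warning("RSS|https://example.com/feed.xml", "源请求失败", "timeout")
    for warning in get_warnings():
        print(warning["timestamp"], warning["source"], warning["message"])
    clear_warnings("RSS|https://example.com/feed.xml")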

View File

@ -334,7 +334,7 @@ class AppConfig:
"""User configurable settings persisted in a simple structure."""
tushare_token: Optional[str] = None
rss_sources: Dict[str, bool] = field(default_factory=dict)
rss_sources: Dict[str, object] = field(default_factory=dict)
decision_method: str = "nash"
data_paths: DataPaths = field(default_factory=DataPaths)
agent_weights: AgentWeights = field(default_factory=AgentWeights)
@ -402,6 +402,14 @@ def _load_from_file(cfg: AppConfig) -> None:
if "decision_method" in payload:
cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
rss_payload = payload.get("rss_sources")
if isinstance(rss_payload, dict):
resolved_rss: Dict[str, object] = {}
for key, value in rss_payload.items():
if isinstance(value, (bool, dict)):
resolved_rss[str(key)] = value
cfg.rss_sources = resolved_rss
weights_payload = payload.get("agent_weights")
if isinstance(weights_payload, dict):
cfg.agent_weights.update_from_dict(weights_payload)
@ -572,6 +580,7 @@ def save_config(cfg: AppConfig | None = None) -> None:
"tushare_token": cfg.tushare_token,
"force_refresh": cfg.force_refresh,
"decision_method": cfg.decision_method,
"rss_sources": cfg.rss_sources,
"agent_weights": cfg.agent_weights.as_dict(),
"llm": {
"strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",

View File

@ -4,3 +4,5 @@ streamlit>=1.30
tushare>=1.2
requests>=2.31
python-box>=7.0
pytest>=7.0
feedparser>=6.0

89
tests/test_rss_ingest.py Normal file
View File

@ -0,0 +1,89 @@
"""Tests for RSS ingestion utilities."""
from __future__ import annotations
from datetime import datetime, timedelta
import pytest
from app.ingest import rss
from app.utils import alerts
from app.utils.config import DataPaths, get_config
from app.utils.db import db_session
@pytest.fixture()
def isolated_db(tmp_path):
"""Temporarily redirect database paths for isolated writes."""
cfg = get_config()
original_paths = cfg.data_paths
tmp_root = tmp_path / "data"
tmp_root.mkdir(parents=True, exist_ok=True)
cfg.data_paths = DataPaths(root=tmp_root)
alerts.clear_warnings()
try:
yield
finally:
cfg.data_paths = original_paths
def test_fetch_rss_feed_parses_entries(monkeypatch):
# Use a recent timestamp so the hours_back=24 window below keeps the entry.
recent = (datetime.utcnow() - timedelta(hours=1)).strftime("%a, %d %b %Y %H:%M:%S GMT")
sample_feed = (
f"""
<rss version="2.0">
<channel>
<title>Example</title>
<item>
<title>新闻:公司利好公告</title>
<link>https://example.com/a</link>
<description><![CDATA[内容包含 000001.SZ ]]></description>
<pubDate>{recent}</pubDate>
<guid>a</guid>
</item>
</channel>
</rss>
"""
).encode("utf-8")
monkeypatch.setattr(
rss,
"_download_feed",
lambda url, timeout, max_retries, retry_backoff, retry_jitter: sample_feed,
)
items = rss.fetch_rss_feed("https://example.com/rss", hours_back=24)
assert len(items) == 1
item = items[0]
assert item.title.startswith("新闻")
assert item.source == "example.com"
def test_save_news_items_writes_and_deduplicates(isolated_db):
published = datetime.utcnow() - timedelta(hours=1)
rss_item = rss.RssItem(
id="test-id",
title="利好消息推动股价",
link="https://example.com/news/test",
published=published,
summary="这是一条利好消息。",
source="测试来源",
ts_codes=("000001.SZ",),
)
inserted = rss.save_news_items([rss_item])
assert inserted >= 1
with db_session(read_only=True) as conn:
row = conn.execute(
"SELECT ts_code, sentiment, heat FROM news WHERE id = ?",
("test-id::000001.SZ",),
).fetchone()
assert row is not None
assert row["ts_code"] == "000001.SZ"
assert row["sentiment"] >= 0 # 利好关键词应给出非负情绪
assert 0 <= row["heat"] <= 1
# saving the same item a second time should be ignored
duplicate = rss.save_news_items([rss_item])
assert duplicate == 0