From 5228ea1c417edd5bda2cf270085e7daf34f010f8 Mon Sep 17 00:00:00 2001 From: sam Date: Tue, 30 Sep 2025 15:42:27 +0800 Subject: [PATCH] update --- README.md | 4 +- app/ingest/rss.py | 838 ++++++++++++++++++++++++++++++++++++++- app/ingest/tushare.py | 25 +- app/ui/streamlit_app.py | 115 +++++- app/utils/alerts.py | 55 +++ app/utils/config.py | 11 +- requirements.txt | 2 + tests/test_rss_ingest.py | 89 +++++ 8 files changed, 1109 insertions(+), 30 deletions(-) create mode 100644 app/utils/alerts.py create mode 100644 tests/test_rss_ingest.py diff --git a/README.md b/README.md index ca187d6..275d2ff 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ -# 多智能体投资助理骨架 +# 多智能体个人投资助理 ## 项目简介 -本仓库提供一个面向 A 股日线级别的多智能体投资助理原型,覆盖数据采集、特征抽取、策略博弈、回测展示和 LLM 解释链路。代码以模块化骨架形式呈现,方便在单机环境下快速搭建端到端的量化研究和可视化决策流程。 +本仓库提供一个面向 A 股日线级别的多智能体个人投资助理原型,覆盖数据采集、特征抽取、策略博弈、回测展示和 LLM 解释链路。代码以模块化骨架形式呈现,方便在单机环境下快速搭建端到端的量化研究和可视化决策流程。 ## 架构总览 diff --git a/app/ingest/rss.py b/app/ingest/rss.py index 50b97d4..a1dcc03 100644 --- a/app/ingest/rss.py +++ b/app/ingest/rss.py @@ -1,25 +1,237 @@ -"""RSS ingestion for news and heat scores.""" +"""RSS ingestion utilities for news sentiment and heat scoring.""" from __future__ import annotations -from dataclasses import dataclass -from datetime import datetime -from typing import Iterable, List +import json +import re +import sqlite3 +from dataclasses import dataclass, replace +from datetime import datetime, timedelta, timezone +from email.utils import parsedate_to_datetime +from typing import Dict, Iterable, List, Optional, Sequence, Tuple +from urllib.parse import urlparse, urljoin +from xml.etree import ElementTree as ET + +import requests +from requests import RequestException + +import hashlib +import random +import time + +try: # pragma: no cover - optional dependency at runtime + import feedparser # type: ignore[import-not-found] +except ImportError: # pragma: no cover - graceful fallback + feedparser = None # type: ignore[assignment] + +from app.data.schema import initialize_database +from app.utils import alerts +from app.utils.config import get_config +from app.utils.db import db_session +from app.utils.logging import get_logger + + +LOGGER = get_logger(__name__) +LOG_EXTRA = {"stage": "rss_ingest"} + +DEFAULT_TIMEOUT = 10.0 +MAX_SUMMARY_LENGTH = 1500 + +POSITIVE_KEYWORDS: Tuple[str, ...] = ( + "利好", + "增长", + "超预期", + "创新高", + "增持", + "回购", + "盈利", + "strong", + "beat", + "upgrade", +) +NEGATIVE_KEYWORDS: Tuple[str, ...] = ( + "利空", + "下跌", + "亏损", + "裁员", + "违约", + "处罚", + "暴跌", + "减持", + "downgrade", + "miss", +) + +A_SH_CODE_PATTERN = re.compile(r"\b(\d{6})(?:\.(SH|SZ))?\b", re.IGNORECASE) +HK_CODE_PATTERN = re.compile(r"\b(\d{4})\.HK\b", re.IGNORECASE) + + +@dataclass +class RssFeedConfig: + """Configuration describing a single RSS source.""" + + url: str + source: str + ts_codes: Tuple[str, ...] = () + keywords: Tuple[str, ...] = () + hours_back: int = 48 + max_items: int = 50 @dataclass class RssItem: + """Structured representation of an RSS entry.""" + id: str title: str link: str published: datetime summary: str source: str + ts_codes: Tuple[str, ...] = () -def fetch_rss_feed(url: str) -> List[RssItem]: +DEFAULT_RSS_SOURCES: Tuple[RssFeedConfig, ...] 
= () + + +def fetch_rss_feed( + url: str, + *, + source: Optional[str] = None, + hours_back: int = 48, + max_items: int = 50, + timeout: float = DEFAULT_TIMEOUT, + max_retries: int = 5, + retry_backoff: float = 1.5, + retry_jitter: float = 0.3, +) -> List[RssItem]: """Download and parse an RSS feed into structured items.""" - raise NotImplementedError + return _fetch_feed_items( + url, + source=source, + hours_back=hours_back, + max_items=max_items, + timeout=timeout, + max_retries=max_retries, + retry_backoff=retry_backoff, + retry_jitter=retry_jitter, + allow_html_redirect=True, + ) + + +def _fetch_feed_items( + url: str, + *, + source: Optional[str], + hours_back: int, + max_items: int, + timeout: float, + max_retries: int, + retry_backoff: float, + retry_jitter: float, + allow_html_redirect: bool, +) -> List[RssItem]: + + content = _download_feed( + url, + timeout, + max_retries=max_retries, + retry_backoff=retry_backoff, + retry_jitter=retry_jitter, + ) + if content is None: + return [] + + if allow_html_redirect: + feed_links = _extract_html_feed_links(content, url) + if feed_links: + LOGGER.info( + "RSS 页面包含子订阅 %s 个,自动展开", + len(feed_links), + extra=LOG_EXTRA, + ) + aggregated: List[RssItem] = [] + for feed_url in feed_links: + sub_items = _fetch_feed_items( + feed_url, + source=source, + hours_back=hours_back, + max_items=max_items, + timeout=timeout, + max_retries=max_retries, + retry_backoff=retry_backoff, + retry_jitter=retry_jitter, + allow_html_redirect=False, + ) + aggregated.extend(sub_items) + if max_items > 0 and len(aggregated) >= max_items: + return aggregated[:max_items] + if aggregated: + alerts.clear_warnings(_rss_source_key(url)) + else: + alerts.add_warning( + _rss_source_key(url), + "聚合页未返回内容", + ) + return aggregated + + parsed_entries = _parse_feed_content(content) + total_entries = len(parsed_entries) + LOGGER.info( + "RSS 源获取完成 url=%s raw_entries=%s", + url, + total_entries, + extra=LOG_EXTRA, + ) + if not parsed_entries: + LOGGER.warning( + "RSS 无可解析条目 url=%s snippet=%s", + url, + _safe_snippet(content), + extra=LOG_EXTRA, + ) + return [] + + cutoff = datetime.utcnow() - timedelta(hours=max(1, hours_back)) + source_name = source or _source_from_url(url) + items: List[RssItem] = [] + seen_ids: set[str] = set() + for entry in parsed_entries: + published = entry.get("published") or datetime.utcnow() + if published < cutoff: + continue + title = _clean_text(entry.get("title", "")) + summary = _clean_text(entry.get("summary", "")) + link = entry.get("link", "") + raw_id = entry.get("id") or link + item_id = _normalise_item_id(raw_id, link, title, published) + if item_id in seen_ids: + continue + seen_ids.add(item_id) + items.append( + RssItem( + id=item_id, + title=title, + link=link, + published=published, + summary=_truncate(summary, MAX_SUMMARY_LENGTH), + source=source_name, + ) + ) + if len(items) >= max_items > 0: + break + + LOGGER.info( + "RSS 过滤结果 url=%s within_window=%s unique=%s", + url, + sum(1 for entry in parsed_entries if (entry.get("published") or datetime.utcnow()) >= cutoff), + len(items), + extra=LOG_EXTRA, + ) + if items: + alerts.clear_warnings(_rss_source_key(url)) + + return items def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]: @@ -36,7 +248,617 @@ def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]: return unique -def save_news_items(items: Iterable[RssItem]) -> None: +def save_news_items(items: Iterable[RssItem]) -> int: """Persist RSS items into the `news` table.""" - raise NotImplementedError + 
initialize_database() + now = datetime.utcnow() + rows: List[Tuple[object, ...]] = [] + + processed = 0 + for item in items: + text_payload = f"{item.title}\n{item.summary}" + sentiment = _estimate_sentiment(text_payload) + base_codes = tuple(code for code in item.ts_codes if code) + heat = _estimate_heat(item.published, now, len(base_codes), sentiment) + entities = json.dumps( + { + "ts_codes": list(base_codes), + "source_url": item.link, + }, + ensure_ascii=False, + ) + resolved_codes = base_codes or (None,) + for ts_code in resolved_codes: + row_id = item.id if ts_code is None else f"{item.id}::{ts_code}" + rows.append( + ( + row_id, + ts_code, + item.published.replace(tzinfo=timezone.utc).isoformat(), + item.source, + item.title, + item.summary, + item.link, + entities, + sentiment, + heat, + ) + ) + processed += 1 + + if not rows: + return 0 + + inserted = 0 + try: + with db_session() as conn: + conn.executemany( + """ + INSERT OR IGNORE INTO news + (id, ts_code, pub_time, source, title, summary, url, entities, sentiment, heat) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + rows, + ) + inserted = conn.total_changes + except sqlite3.OperationalError: + LOGGER.exception("写入新闻数据失败,表结构可能未初始化", extra=LOG_EXTRA) + return 0 + except Exception: # pragma: no cover - guard unexpected sqlite errors + LOGGER.exception("写入新闻数据异常", extra=LOG_EXTRA) + return 0 + + LOGGER.info( + "RSS 新闻落库完成 processed=%s inserted=%s", + processed, + inserted, + extra=LOG_EXTRA, + ) + return inserted + + +def ingest_configured_rss( + *, + hours_back: Optional[int] = None, + max_items_per_feed: Optional[int] = None, + max_retries: int = 5, + retry_backoff: float = 2.0, + retry_jitter: float = 0.5, +) -> int: + """Ingest all configured RSS feeds into the news store.""" + + configs = resolve_rss_sources() + if not configs: + LOGGER.info("未配置 RSS 来源,跳过新闻拉取", extra=LOG_EXTRA) + return 0 + + aggregated: List[RssItem] = [] + fetched_count = 0 + for index, cfg in enumerate(configs, start=1): + window = hours_back or cfg.hours_back + limit = max_items_per_feed or cfg.max_items + LOGGER.info( + "开始拉取 RSS:%s (window=%sh, limit=%s)", + cfg.url, + window, + limit, + extra=LOG_EXTRA, + ) + items = fetch_rss_feed( + cfg.url, + source=cfg.source, + hours_back=window, + max_items=limit, + max_retries=max_retries, + retry_backoff=retry_backoff, + retry_jitter=retry_jitter, + ) + if not items: + LOGGER.info("RSS 来源无新内容:%s", cfg.url, extra=LOG_EXTRA) + continue + enriched: List[RssItem] = [] + for item in items: + codes = _assign_ts_codes(item, cfg.ts_codes, cfg.keywords) + enriched.append(replace(item, ts_codes=tuple(codes))) + aggregated.extend(enriched) + fetched_count += len(enriched) + if fetched_count and index < len(configs): + time.sleep(2.0) + + if not aggregated: + LOGGER.info("RSS 来源未产生有效新闻", extra=LOG_EXTRA) + alerts.add_warning("RSS", "未获取到任何 RSS 新闻") + return 0 + + deduped = deduplicate_items(aggregated) + LOGGER.info( + "RSS 聚合完成 total_fetched=%s unique=%s", + fetched_count, + len(deduped), + extra=LOG_EXTRA, + ) + return save_news_items(deduped) + + +def resolve_rss_sources() -> List[RssFeedConfig]: + """Resolve RSS feed configuration from persisted settings.""" + + cfg = get_config() + raw = getattr(cfg, "rss_sources", None) or {} + feeds: Dict[str, RssFeedConfig] = {} + + def _add_feed(url: str, **kwargs: object) -> None: + clean_url = url.strip() + if not clean_url: + return + key = clean_url.lower() + if key in feeds: + return + source_name = kwargs.get("source") or _source_from_url(clean_url) + feeds[key] = 
RssFeedConfig(
+            url=clean_url,
+            source=str(source_name),
+            ts_codes=tuple(kwargs.get("ts_codes", ()) or ()),
+            keywords=tuple(kwargs.get("keywords", ()) or ()),
+            hours_back=int(kwargs.get("hours_back", 48) or 48),
+            max_items=int(kwargs.get("max_items", 50) or 50),
+        )
+
+    if isinstance(raw, dict):
+        for key, value in raw.items():
+            if isinstance(value, dict):
+                if not value.get("enabled", True):
+                    continue
+                url = str(value.get("url") or key)
+                ts_codes = [
+                    str(code).strip().upper()
+                    for code in value.get("ts_codes", [])
+                    if str(code).strip()
+                ]
+                keywords = [
+                    str(token).strip()
+                    for token in value.get("keywords", [])
+                    if str(token).strip()
+                ]
+                _add_feed(
+                    url,
+                    ts_codes=ts_codes,
+                    keywords=keywords,
+                    hours_back=value.get("hours_back", 48),
+                    max_items=value.get("max_items", 50),
+                    source=value.get("source") or value.get("label"),
+                )
+                continue
+
+            if not value:
+                continue
+            url = key
+            ts_codes: List[str] = []
+            if "|" in key:
+                prefix, url = key.split("|", 1)
+                ts_codes = [
+                    token.strip().upper()
+                    for token in prefix.replace(",", ":").split(":")
+                    if token.strip()
+                ]
+            _add_feed(url, ts_codes=ts_codes)
+
+    if feeds:
+        return list(feeds.values())
+
+    return list(DEFAULT_RSS_SOURCES)
+
+
+def _download_feed(
+    url: str,
+    timeout: float,
+    *,
+    max_retries: int,
+    retry_backoff: float,
+    retry_jitter: float,
+) -> Optional[bytes]:
+    headers = {
+        "User-Agent": "llm-quant/0.1 (+https://github.com/qiang/llm_quant)",
+        "Accept": "application/rss+xml, application/atom+xml, application/xml;q=0.9, */*;q=0.8",
+    }
+    attempt = 0
+    delay = max(0.5, retry_backoff)
+    while attempt <= max_retries:
+        try:
+            response = requests.get(url, headers=headers, timeout=timeout)
+        except RequestException as exc:
+            attempt += 1
+            if attempt > max_retries:
+                message = f"源请求失败:{url}"
+                LOGGER.warning("RSS 请求失败:%s err=%s", url, exc, extra=LOG_EXTRA)
+                alerts.add_warning(_rss_source_key(url), message, str(exc))
+                return None
+            wait = delay + random.uniform(0, retry_jitter)
+            LOGGER.info(
+                "RSS 请求异常,%.2f 秒后重试 url=%s attempt=%s/%s",
+                wait,
+                url,
+                attempt,
+                max_retries,
+                extra=LOG_EXTRA,
+            )
+            time.sleep(max(wait, 0.1))
+            delay *= max(1.1, retry_backoff)
+            continue
+
+        status = response.status_code
+        if 200 <= status < 300:
+            return response.content
+
+        if status in {429, 503}:
+            attempt += 1
+            if attempt > max_retries:
+                LOGGER.warning(
+                    "RSS 请求失败:%s status=%s 已达到最大重试次数",
+                    url,
+                    status,
+                    extra=LOG_EXTRA,
+                )
+                alerts.add_warning(
+                    _rss_source_key(url),
+                    "源限流",
+                    f"HTTP {status}",
+                )
+                return None
+            retry_after = response.headers.get("Retry-After")
+            if retry_after:
+                try:
+                    wait = float(retry_after)
+                except ValueError:
+                    wait = delay
+            else:
+                wait = delay
+            wait += random.uniform(0, retry_jitter)
+            LOGGER.info(
+                "RSS 命中限流 status=%s,%.2f 秒后重试 url=%s attempt=%s/%s",
+                status,
+                wait,
+                url,
+                attempt,
+                max_retries,
+                extra=LOG_EXTRA,
+            )
+            time.sleep(max(wait, 0.1))
+            delay *= max(1.1, retry_backoff)
+            continue
+
+        LOGGER.warning(
+            "RSS 请求失败:%s status=%s",
+            url,
+            status,
+            extra=LOG_EXTRA,
+        )
+        alerts.add_warning(
+            _rss_source_key(url),
+            "源响应异常",
+            f"HTTP {status}",
+        )
+        return None
+
+    LOGGER.warning("RSS 请求失败:%s 未获取内容", url, extra=LOG_EXTRA)
+    alerts.add_warning(_rss_source_key(url), "未获取内容")
+    return None
+
+
+def _extract_html_feed_links(content: bytes, base_url: str) -> List[str]:
+    sample = content[:1024].lower()
+    if b"<html" not in sample and b"<!doctype" not in sample:
+        return []
+    try:
+        text = content.decode("utf-8")
+    except UnicodeDecodeError:
+        text = content.decode("utf-8", errors="ignore")
+    feed_urls: List[str] = []
+    alternates = re.compile(
+        r"<link[^>]+rel=[\"']alternate[\"'][^>]+type=[\"']application/(?:rss|atom)\+xml[\"'][^>]*href=[\"']([^\"']+)[\"']",
+        re.IGNORECASE,
+    )
+    for match in alternates.finditer(text):
+        href = match.group(1).strip()
+        if href:
+            feed_urls.append(urljoin(base_url, href))
+
+    if not feed_urls:
+        anchors = re.compile(r"href=[\"']([^\"']+\.xml)[\"']", re.IGNORECASE)
+        for match in anchors.finditer(text):
+            href = match.group(1).strip()
+            if href:
+                feed_urls.append(urljoin(base_url, href))
+
+    unique_urls: List[str] = []
+    seen = set()
+    for href in feed_urls:
+        if href not in seen and href != base_url:
+            seen.add(href)
+            unique_urls.append(href)
+    return unique_urls
+
+
+def _safe_snippet(content: bytes, limit: int = 160) -> str:
+    try:
+        text = content.decode("utf-8")
+    except UnicodeDecodeError:
+        try:
+            text = content.decode("gb18030", errors="ignore")
+        except UnicodeDecodeError:
+            text = content.decode("latin-1", errors="ignore")
+    cleaned = re.sub(r"\s+", " ", text)
+    if len(cleaned) > limit:
+        return cleaned[: limit - 3] + "..."
+    return cleaned
+
+
+def _parse_feed_content(content: bytes) -> List[Dict[str, object]]:
+    if feedparser is not None:
+        parsed = feedparser.parse(content)
+        entries = []
+        for entry in getattr(parsed, "entries", []) or []:
+            entries.append(
+                {
+                    "id": getattr(entry, "id", None) or getattr(entry, "guid", None),
+                    "title": getattr(entry, "title", ""),
+                    "link": getattr(entry, "link", ""),
+                    "summary": getattr(entry, "summary", "") or getattr(entry, "description", ""),
+                    "published": _parse_datetime(
+                        getattr(entry, "published", None)
+                        or getattr(entry, "updated", None)
+                        or getattr(entry, "issued", None)
+                    ),
+                }
+            )
+        if entries:
+            return entries
+    else:  # pragma: no cover - log helpful info when dependency missing
+        LOGGER.warning(
+            "feedparser 未安装,使用简易 XML 解析器回退处理 RSS",
+            extra=LOG_EXTRA,
+        )
+
+    return _parse_feed_xml(content)
+
+
+def _parse_feed_xml(content: bytes) -> List[Dict[str, object]]:
+    try:
+        xml_text = content.decode("utf-8")
+    except UnicodeDecodeError:
+        xml_text = content.decode("utf-8", errors="ignore")
+
+    try:
+        root = ET.fromstring(xml_text)
+    except ET.ParseError as exc:  # pragma: no cover - depends on remote feed
+        LOGGER.warning("RSS XML 解析失败 err=%s", exc, extra=LOG_EXTRA)
+        return _lenient_parse_items(xml_text)
+
+    tag = _local_name(root.tag)
+    if tag == "rss":
+        candidates = root.findall(".//item")
+    elif tag == "feed":
+        candidates = root.findall(".//{*}entry")
+    else:  # fallback
+        candidates = root.findall(".//item") or root.findall(".//{*}entry")
+
+    entries: List[Dict[str, object]] = []
+    for node in candidates:
+        entries.append(
+            {
+                "id": _child_text(node, {"id", "guid"}),
+                "title": _child_text(node, {"title"}) or "",
+                "link": _child_text(node, {"link"}) or "",
+                "summary": _child_text(node, {"summary", "description"}) or "",
+                "published": _parse_datetime(
+                    _child_text(node, {"pubDate", "published", "updated"})
+                ),
+            }
+        )
+    if not entries and "<item" in xml_text.lower():
+        return _lenient_parse_items(xml_text)
+    return entries
+
+
+def _lenient_parse_items(xml_text: str) -> List[Dict[str, object]]:
+    """Fallback parser that tolerates malformed RSS by using regular expressions."""
+
+    items: List[Dict[str, object]] = []
+    pattern = re.compile(r"<(item|entry)[^>]*>(.+?)</\1>", re.IGNORECASE | re.DOTALL)
+    for match in pattern.finditer(xml_text):
+        block = match.group(0)
+        title = _extract_tag_text(block, ["title"]) or ""
+        link = _extract_link(block)
+        summary = _extract_tag_text(block, ["summary", "description"]) or ""
+        published_text = _extract_tag_text(block, ["pubDate", "published", "updated"])
+        items.append(
+            {
+                "id": _extract_tag_text(block, ["id", "guid"]) or link,
+                "title": title,
+                "link": link,
+                "summary": summary,
+                "published": _parse_datetime(published_text),
+            }
+        )
+    if items:
+        LOGGER.info("RSS 采用宽松解析提取 %s 条记录", len(items), extra=LOG_EXTRA)
+    return items
+
+
+def _extract_tag_text(block: str, names: Sequence[str]) -> Optional[str]:
+    for name in names:
+        pattern = re.compile(rf"<{name}[^>]*>(.*?)</{name}>", re.IGNORECASE | re.DOTALL)
+        match = pattern.search(block)
+        if match:
+            text = re.sub(r"<[^>]+>", " ", match.group(1))
+            return _clean_text(text)
+    return None
+
+
+def _extract_link(block: str) -> str:
+    href_pattern = re.compile(r"<link[^>]*href=\"([^\"]+)\"[^>]*>", re.IGNORECASE)
+    match = href_pattern.search(block)
+    if match:
+        return match.group(1).strip()
+    inline_pattern = re.compile(r"<link[^>]*>(.*?)</link>", re.IGNORECASE | re.DOTALL)
+    match = inline_pattern.search(block)
+    if match:
+        return match.group(1).strip()
+    return ""
+
+
+def _assign_ts_codes(
+    item: RssItem,
+    base_codes: Sequence[str],
+    keywords: Sequence[str],
+) -> List[str]:
+    matches: set[str] = set()
+    text = f"{item.title} {item.summary}".lower()
+    if keywords:
+        for keyword in keywords:
+            token = keyword.lower().strip()
+            if token and token in text:
+                matches.update(code.strip().upper() for code in base_codes if code)
+                break
+    else:
+        matches.update(code.strip().upper() for code in base_codes if code)
+
+    detected = _detect_ts_codes(text)
+    matches.update(detected)
+    return [code for code in matches if code]
+
+
+def _detect_ts_codes(text: str) -> List[str]:
+    codes: set[str] = set()
+    for match in A_SH_CODE_PATTERN.finditer(text):
+        digits, suffix = match.groups()
+        if suffix:
+            codes.add(f"{digits}.{suffix.upper()}")
+        else:
+            exchange = "SH" if digits.startswith(tuple("569")) else "SZ"
+            codes.add(f"{digits}.{exchange}")
+    for match in HK_CODE_PATTERN.finditer(text):
+        digits = match.group(1)
+        codes.add(f"{digits.zfill(4)}.HK")
+    return sorted(codes)
+
+
+def _estimate_sentiment(text: str) -> float:
+    normalized = text.lower()
+    score = 0
+    for keyword in POSITIVE_KEYWORDS:
+        if keyword.lower() in normalized:
+            score += 1
+    for keyword in NEGATIVE_KEYWORDS:
+        if keyword.lower() in normalized:
+            score -= 1
+    if score == 0:
+        return 0.0
+    return max(-1.0, min(1.0, score / 3.0))
+
+
+def _estimate_heat(
+    published: datetime,
+    now: datetime,
+    code_count: int,
+    sentiment: float,
+) -> float:
+    delta_hours = max(0.0, (now - published).total_seconds() / 3600.0)
+    recency = max(0.0, 1.0 - min(delta_hours, 72.0) / 72.0)
+    coverage_bonus = min(code_count, 3) * 0.05
+    sentiment_bonus = min(abs(sentiment) * 0.1, 0.2)
+    heat = recency + coverage_bonus + sentiment_bonus
+    return max(0.0, min(1.0, round(heat, 4)))
+
+
+def _parse_datetime(value: Optional[str]) -> Optional[datetime]:
+    if not value:
+        return None
+    try:
+        dt = parsedate_to_datetime(value)
+        if dt.tzinfo is not None:
+            dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
+        return dt
+    except (TypeError, ValueError):
+        pass
+
+    for fmt in ("%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S"):
+        try:
+            return datetime.strptime(value[:19], fmt)
+        except ValueError:
+            continue
+    return None
+
+
+def _clean_text(value: Optional[str]) -> str:
+    if not value:
+        return ""
+    text = re.sub(r"<[^>]+>", " ", value)
+    return re.sub(r"\s+", " ", text).strip()
+
+
+def _truncate(value: str, length: int) -> str:
+    if len(value) <= length:
+        return value
+    return value[: length - 3].rstrip() + "..."
+ + +def _normalise_item_id( + raw_id: Optional[str], link: str, title: str, published: datetime +) -> str: + candidate = (raw_id or link or title).strip() + if candidate: + return candidate + fingerprint = f"{title}|{published.isoformat()}" + return hashlib.blake2s(fingerprint.encode("utf-8"), digest_size=16).hexdigest() + + +def _source_from_url(url: str) -> str: + try: + parsed = urlparse(url) + except ValueError: + return url + host = parsed.netloc or url + return host.lower() + + +def _local_name(tag: str) -> str: + if "}" in tag: + return tag.rsplit("}", 1)[-1] + return tag + + +def _child_text(node: ET.Element, candidates: set[str]) -> Optional[str]: + for child in node: + name = _local_name(child.tag) + if name in candidates and child.text: + return child.text.strip() + if name == "link": + href = child.attrib.get("href") + if href: + return href.strip() + return None + + +__all__ = [ + "RssFeedConfig", + "RssItem", + "fetch_rss_feed", + "deduplicate_items", + "save_news_items", + "ingest_configured_rss", + "resolve_rss_sources", +] +def _rss_source_key(url: str) -> str: + return f"RSS|{url}".strip() diff --git a/app/ingest/tushare.py b/app/ingest/tushare.py index 45b97e8..816d6a2 100644 --- a/app/ingest/tushare.py +++ b/app/ingest/tushare.py @@ -16,6 +16,7 @@ try: except ImportError: # pragma: no cover - 运行时提示 ts = None # type: ignore[assignment] +from app.utils import alerts from app.utils.config import get_config from app.utils.db import db_session from app.data.schema import initialize_database @@ -1601,12 +1602,18 @@ def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object] def run_ingestion(job: FetchJob, include_limits: bool = True) -> None: LOGGER.info("启动 TuShare 拉取任务:%s", job.name, extra=LOG_EXTRA) - ensure_data_coverage( - job.start, - job.end, - ts_codes=job.ts_codes, - include_limits=include_limits, - include_extended=True, - force=True, - ) - LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA) + try: + ensure_data_coverage( + job.start, + job.end, + ts_codes=job.ts_codes, + include_limits=include_limits, + include_extended=True, + force=True, + ) + except Exception as exc: + alerts.add_warning("TuShare", f"拉取任务失败:{job.name}", str(exc)) + raise + else: + alerts.clear_warnings("TuShare") + LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA) diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index e6696c0..fba67d3 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -36,6 +36,7 @@ from app.llm.metrics import ( reset as reset_llm_metrics, snapshot as snapshot_llm_metrics, ) +from app.utils import alerts from app.utils.config import ( ALLOWED_LLM_STRATEGIES, DEFAULT_LLM_BASE_URLS, @@ -86,6 +87,9 @@ def render_global_dashboard() -> None: decisions_container = st.sidebar.container() _DASHBOARD_CONTAINERS = (metrics_container, decisions_container) _DASHBOARD_ELEMENTS = _ensure_dashboard_elements(metrics_container, decisions_container) + if st.sidebar.button("清除数据告警", key="clear_data_alerts"): + alerts.clear_warnings() + _update_dashboard_sidebar() if not _SIDEBAR_LISTENER_ATTACHED: register_llm_metrics_listener(_sidebar_metrics_listener) _SIDEBAR_LISTENER_ATTACHED = True @@ -132,6 +136,23 @@ def _update_dashboard_sidebar( else: model_placeholder.info("暂无模型分布数据。") + warnings_placeholder = elements.get("warnings") + if warnings_placeholder is not None: + warnings_placeholder.empty() + warnings = alerts.get_warnings() + if warnings: + lines = [] + for warning in warnings[-10:]: + detail = warning.get("detail") + appendix = f" 
{detail}" if detail else "" + lines.append( + f"- **{warning['source']}** {warning['message']}{appendix}" + f"\n{warning['timestamp']}" + ) + warnings_placeholder.markdown("\n".join(lines), unsafe_allow_html=True) + else: + warnings_placeholder.info("暂无数据告警。") + decisions = metrics.get("recent_decisions") or llm_recent_decisions(10) if decisions: lines = [] @@ -163,6 +184,8 @@ def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[s distribution_expander = metrics_container.expander("调用分布", expanded=False) provider_distribution = distribution_expander.empty() model_distribution = distribution_expander.empty() + warnings_expander = metrics_container.expander("数据告警", expanded=False) + warnings_placeholder = warnings_expander.empty() decisions_container.subheader("最新决策") decisions_list = decisions_container.empty() @@ -173,6 +196,7 @@ def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[s "metrics_completion": metrics_completion, "provider_distribution": provider_distribution, "model_distribution": model_distribution, + "warnings": warnings_placeholder, "decisions_list": decisions_list, } return elements @@ -1649,11 +1673,80 @@ def render_tests() -> None: except Exception as exc: # noqa: BLE001 LOGGER.exception("示例 TuShare 拉取失败", extra=LOG_EXTRA) st.error(f"拉取失败:{exc}") + alerts.add_warning("TuShare", "示例拉取失败", str(exc)) + _update_dashboard_sidebar() st.info("注意:TuShare 拉取依赖网络与 Token,若环境未配置将出现错误提示。") st.divider() - days = int(st.number_input("检查窗口(天数)", min_value=30, max_value=1095, value=365, step=30)) + + st.subheader("RSS 数据测试") + st.write("用于验证 RSS 配置是否能够正常抓取新闻并写入数据库。") + rss_url = st.text_input( + "测试 RSS 地址", + value="https://rsshub.app/cls/depth/1000", + help="留空则使用默认配置的全部 RSS 来源。", + ).strip() + rss_hours = int( + st.number_input( + "回溯窗口(小时)", + min_value=1, + max_value=168, + value=24, + step=6, + help="仅抓取最近指定小时内的新闻。", + ) + ) + rss_limit = int( + st.number_input( + "单源抓取条数", + min_value=1, + max_value=200, + value=50, + step=10, + ) + ) + if st.button("运行 RSS 测试"): + from app.ingest import rss as rss_ingest + + LOGGER.info( + "点击 RSS 测试按钮 rss_url=%s hours=%s limit=%s", + rss_url, + rss_hours, + rss_limit, + extra=LOG_EXTRA, + ) + with st.spinner("正在抓取 RSS 新闻..."): + try: + if rss_url: + items = rss_ingest.fetch_rss_feed( + rss_url, + hours_back=rss_hours, + max_items=rss_limit, + ) + count = rss_ingest.save_news_items(items) + else: + count = rss_ingest.ingest_configured_rss( + hours_back=rss_hours, + max_items_per_feed=rss_limit, + ) + st.success(f"RSS 测试完成,新增 {count} 条新闻记录。") + except Exception as exc: # noqa: BLE001 + LOGGER.exception("RSS 测试失败", extra=LOG_EXTRA) + st.error(f"RSS 测试失败:{exc}") + alerts.add_warning("RSS", "RSS 测试执行失败", str(exc)) + _update_dashboard_sidebar() + + st.divider() + days = int( + st.number_input( + "检查窗口(天数)", + min_value=30, + max_value=10950, + value=365, + step=30, + ) + ) LOGGER.debug("检查窗口天数=%s", days, extra=LOG_EXTRA) cfg = get_config() force_refresh = st.checkbox( @@ -1666,8 +1759,8 @@ def render_tests() -> None: LOGGER.info("更新 force_refresh=%s", force_refresh, extra=LOG_EXTRA) save_config() - if st.button("执行开机检查"): - LOGGER.info("点击执行开机检查按钮", extra=LOG_EXTRA) + if st.button("执行手动数据同步"): + LOGGER.info("点击执行手动数据同步按钮", extra=LOG_EXTRA) progress_bar = st.progress(0.0) status_placeholder = st.empty() log_placeholder = st.empty() @@ -1677,23 +1770,25 @@ def render_tests() -> None: progress_bar.progress(min(max(value, 0.0), 1.0)) status_placeholder.write(message) messages.append(message) - 
LOGGER.debug("开机检查进度:%s -> %.2f", message, value, extra=LOG_EXTRA) + LOGGER.debug("手动数据同步进度:%s -> %.2f", message, value, extra=LOG_EXTRA) - with st.spinner("正在执行开机检查..."): + with st.spinner("正在执行手动数据同步..."): try: report = run_boot_check( days=days, progress_hook=hook, force_refresh=force_refresh, ) - LOGGER.info("开机检查成功", extra=LOG_EXTRA) - st.success("开机检查完成,以下为数据覆盖摘要。") + LOGGER.info("手动数据同步成功", extra=LOG_EXTRA) + st.success("手动数据同步完成,以下为数据覆盖摘要。") st.json(report.to_dict()) if messages: log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages)) except Exception as exc: # noqa: BLE001 - LOGGER.exception("开机检查失败", extra=LOG_EXTRA) - st.error(f"开机检查失败:{exc}") + LOGGER.exception("手动数据同步失败", extra=LOG_EXTRA) + st.error(f"手动数据同步失败:{exc}") + alerts.add_warning("数据同步", "手动数据同步失败", str(exc)) + _update_dashboard_sidebar() if messages: log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages)) finally: @@ -1844,7 +1939,7 @@ def render_tests() -> None: def main() -> None: LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA) - st.set_page_config(page_title="多智能体投资助理", layout="wide") + st.set_page_config(page_title="多智能体个人投资助理", layout="wide") render_global_dashboard() tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"]) LOGGER.debug("Tabs 初始化完成:%s", ["今日计划", "回测与复盘", "数据与设置", "自检测试"], extra=LOG_EXTRA) diff --git a/app/utils/alerts.py b/app/utils/alerts.py new file mode 100644 index 0000000..a403f7b --- /dev/null +++ b/app/utils/alerts.py @@ -0,0 +1,55 @@ +"""Runtime data warning registry for surfacing ingestion issues in UI.""" +from __future__ import annotations + +from datetime import datetime +from threading import Lock +from typing import Dict, List, Optional + + +_ALERTS: List[Dict[str, str]] = [] +_LOCK = Lock() + + +def add_warning(source: str, message: str, detail: Optional[str] = None) -> None: + """Register or update a warning entry.""" + + source = source.strip() or "unknown" + message = message.strip() or "发生未知异常" + timestamp = datetime.utcnow().isoformat(timespec="seconds") + "Z" + + with _LOCK: + for alert in _ALERTS: + if alert["source"] == source and alert["message"] == message: + alert["timestamp"] = timestamp + if detail: + alert["detail"] = detail + return + entry = { + "source": source, + "message": message, + "timestamp": timestamp, + } + if detail: + entry["detail"] = detail + _ALERTS.append(entry) + if len(_ALERTS) > 50: + del _ALERTS[:-50] + + +def get_warnings() -> List[Dict[str, str]]: + """Return a copy of current warning entries.""" + + with _LOCK: + return list(_ALERTS) + + +def clear_warnings(source: Optional[str] = None) -> None: + """Clear warnings entirely or for a specific source.""" + + with _LOCK: + if source is None: + _ALERTS.clear() + return + source = source.strip() + _ALERTS[:] = [alert for alert in _ALERTS if alert["source"] != source] + diff --git a/app/utils/config.py b/app/utils/config.py index 5834dab..9d6963d 100644 --- a/app/utils/config.py +++ b/app/utils/config.py @@ -334,7 +334,7 @@ class AppConfig: """User configurable settings persisted in a simple structure.""" tushare_token: Optional[str] = None - rss_sources: Dict[str, bool] = field(default_factory=dict) + rss_sources: Dict[str, object] = field(default_factory=dict) decision_method: str = "nash" data_paths: DataPaths = field(default_factory=DataPaths) agent_weights: AgentWeights = field(default_factory=AgentWeights) @@ -402,6 +402,14 @@ def _load_from_file(cfg: AppConfig) -> None: if "decision_method" in payload: cfg.decision_method = str(payload.get("decision_method") or 
cfg.decision_method)
 
+    rss_payload = payload.get("rss_sources")
+    if isinstance(rss_payload, dict):
+        resolved_rss: Dict[str, object] = {}
+        for key, value in rss_payload.items():
+            if isinstance(value, (bool, dict)):
+                resolved_rss[str(key)] = value
+        cfg.rss_sources = resolved_rss
+
     weights_payload = payload.get("agent_weights")
     if isinstance(weights_payload, dict):
         cfg.agent_weights.update_from_dict(weights_payload)
@@ -572,6 +580,7 @@ def save_config(cfg: AppConfig | None = None) -> None:
         "tushare_token": cfg.tushare_token,
         "force_refresh": cfg.force_refresh,
         "decision_method": cfg.decision_method,
+        "rss_sources": cfg.rss_sources,
         "agent_weights": cfg.agent_weights.as_dict(),
         "llm": {
             "strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
diff --git a/requirements.txt b/requirements.txt
index 24f881a..f574244 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,5 @@ streamlit>=1.30
 tushare>=1.2
 requests>=2.31
 python-box>=7.0
+pytest>=7.0
+feedparser>=6.0
diff --git a/tests/test_rss_ingest.py b/tests/test_rss_ingest.py
new file mode 100644
index 0000000..e7dab9e
--- /dev/null
+++ b/tests/test_rss_ingest.py
@@ -0,0 +1,89 @@
+"""Tests for RSS ingestion utilities."""
+from __future__ import annotations
+
+from datetime import datetime, timedelta, timezone
+from email.utils import format_datetime
+
+import pytest
+
+from app.ingest import rss
+from app.utils import alerts
+from app.utils.config import DataPaths, get_config
+from app.utils.db import db_session
+
+
+@pytest.fixture()
+def isolated_db(tmp_path):
+    """Temporarily redirect database paths for isolated writes."""
+
+    cfg = get_config()
+    original_paths = cfg.data_paths
+    tmp_root = tmp_path / "data"
+    tmp_root.mkdir(parents=True, exist_ok=True)
+    cfg.data_paths = DataPaths(root=tmp_root)
+    alerts.clear_warnings()
+    try:
+        yield
+    finally:
+        cfg.data_paths = original_paths
+
+
+def test_fetch_rss_feed_parses_entries(monkeypatch):
+    # Use a pubDate relative to "now" so the item always falls inside the
+    # 24-hour window checked below; a fixed date would make the test expire.
+    recent_pub = format_datetime(datetime.now(timezone.utc) - timedelta(hours=1))
+    sample_feed = (
+        f"""<?xml version="1.0" encoding="UTF-8"?>
+        <rss version="2.0">
+          <channel>
+            <title>Example</title>
+            <item>
+              <title>新闻:公司利好公告</title>
+              <link>https://example.com/a</link>
+              <description>公司发布利好公告,业绩超预期。</description>
+              <pubDate>{recent_pub}</pubDate>
+              <guid>a</guid>
+            </item>
+          </channel>
+        </rss>
+        """
+    ).encode("utf-8")
+
+    monkeypatch.setattr(
+        rss,
+        "_download_feed",
+        lambda url, timeout, max_retries, retry_backoff, retry_jitter: sample_feed,
+    )
+
+    items = rss.fetch_rss_feed("https://example.com/rss", hours_back=24)
+
+    assert len(items) == 1
+    item = items[0]
+    assert item.title.startswith("新闻")
+    assert item.source == "example.com"
+
+
+def test_save_news_items_writes_and_deduplicates(isolated_db):
+    published = datetime.utcnow() - timedelta(hours=1)
+    rss_item = rss.RssItem(
+        id="test-id",
+        title="利好消息推动股价",
+        link="https://example.com/news/test",
+        published=published,
+        summary="这是一条利好消息。",
+        source="测试来源",
+        ts_codes=("000001.SZ",),
+    )
+
+    inserted = rss.save_news_items([rss_item])
+    assert inserted >= 1
+
+    with db_session(read_only=True) as conn:
+        row = conn.execute(
+            "SELECT ts_code, sentiment, heat FROM news WHERE id = ?",
+            ("test-id::000001.SZ",),
+        ).fetchone()
+        assert row is not None
+        assert row["ts_code"] == "000001.SZ"
+        assert row["sentiment"] >= 0  # 利好关键词应给出非负情绪
+        assert 0 <= row["heat"] <= 1
+
+    # 再次保存同一条新闻应被忽略
+    duplicate = rss.save_news_items([rss_item])
+    assert duplicate == 0
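
A minimal end-to-end sketch of how the ingestion path added in this patch could be driven headless (illustrative only, not part of the patch; it assumes the repository's config and SQLite schema utilities are importable and already initialise the data paths):

    """Ad-hoc driver: pull all configured RSS feeds once and surface data warnings."""
    from app.ingest.rss import ingest_configured_rss, resolve_rss_sources
    from app.utils import alerts


    def run_rss_once() -> None:
        feeds = resolve_rss_sources()
        print(f"configured feeds: {len(feeds)}")
        # Fetch every configured feed, deduplicate and persist into the `news` table.
        inserted = ingest_configured_rss(hours_back=24, max_items_per_feed=50)
        print(f"news rows inserted: {inserted}")
        # Fetch or persistence problems are surfaced through the in-memory alert registry.
        for warning in alerts.get_warnings():
            print(f"[{warning['timestamp']}] {warning['source']}: {warning['message']}")


    if __name__ == "__main__":
        run_rss_once()

The Streamlit "自检测试" tab added in this patch exercises the same path (it calls `ingest_configured_rss` when no test URL is given), so a script like the above is only a convenience for cron or CI runs.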