This commit is contained in:
sam 2025-09-29 16:01:37 +08:00
parent a6564cdced
commit b3f2f5b4fc
6 changed files with 530 additions and 123 deletions

View File

@ -74,7 +74,15 @@ class DepartmentDecision:
class DepartmentAgent: class DepartmentAgent:
"""Wraps LLM ensemble logic for a single analytical department.""" """Wraps LLM ensemble logic for a single analytical department."""
ALLOWED_TABLES: ClassVar[List[str]] = ["daily", "daily_basic"] ALLOWED_TABLES: ClassVar[List[str]] = [
"daily",
"daily_basic",
"stk_limit",
"suspend",
"heat_daily",
"news",
"index_daily",
]
def __init__( def __init__(
self, self,

View File

@ -4,7 +4,6 @@ from __future__ import annotations
import json import json
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import date from datetime import date
from statistics import mean, pstdev
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
from app.agents.base import AgentAction, AgentContext from app.agents.base import AgentAction, AgentContext
@ -16,50 +15,13 @@ from app.utils.data_access import DataBroker
from app.utils.config import get_config from app.utils.config import get_config
from app.utils.db import db_session from app.utils.db import db_session
from app.utils.logging import get_logger from app.utils.logging import get_logger
from app.core.indicators import momentum, normalize, rolling_mean, volatility
LOGGER = get_logger(__name__) LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "backtest"} LOG_EXTRA = {"stage": "backtest"}
def _compute_momentum(values: List[float], window: int) -> float:
if window <= 0 or len(values) < window:
return 0.0
latest = values[0]
past = values[window - 1]
if past is None or past == 0:
return 0.0
try:
return (latest / past) - 1.0
except ZeroDivisionError:
return 0.0
def _compute_volatility(values: List[float], window: int) -> float:
if len(values) < 2 or window <= 1:
return 0.0
limit = min(window, len(values) - 1)
returns: List[float] = []
for idx in range(limit):
current = values[idx]
previous = values[idx + 1]
if previous is None or previous == 0:
continue
returns.append((current / previous) - 1.0)
if len(returns) < 2:
return 0.0
return float(pstdev(returns))
def _normalize(value: Any, factor: float) -> float:
try:
numeric = float(value)
except (TypeError, ValueError):
return 0.0
if factor <= 0:
return max(0.0, min(1.0, numeric))
return max(0.0, min(1.0, numeric / factor))
@dataclass @dataclass
class BtConfig: class BtConfig:
@ -143,9 +105,9 @@ class BacktestEngine:
window=60, window=60,
) )
close_values = [value for _date, value in closes] close_values = [value for _date, value in closes]
mom20 = _compute_momentum(close_values, 20) mom20 = momentum(close_values, 20)
mom60 = _compute_momentum(close_values, 60) mom60 = momentum(close_values, 60)
volat20 = _compute_volatility(close_values, 20) volat20 = volatility(close_values, 20)
turnover_series = self.data_broker.fetch_series( turnover_series = self.data_broker.fetch_series(
"daily_basic", "daily_basic",
@ -155,10 +117,31 @@ class BacktestEngine:
window=20, window=20,
) )
turnover_values = [value for _date, value in turnover_series] turnover_values = [value for _date, value in turnover_series]
turn20 = mean(turnover_values) if turnover_values else 0.0 turn20 = rolling_mean(turnover_values, 20)
liquidity_score = _normalize(turn20, factor=20.0) liquidity_score = normalize(turn20, factor=20.0)
cost_penalty = _normalize(scope_values.get("daily_basic.volume_ratio", 0.0), factor=50.0) cost_penalty = normalize(
scope_values.get("daily_basic.volume_ratio", 0.0),
factor=50.0,
)
scope_values.setdefault("factors.mom_20", mom20)
scope_values.setdefault("factors.mom_60", mom60)
scope_values.setdefault("factors.volat_20", volat20)
scope_values.setdefault("factors.turn_20", turn20)
scope_values.setdefault("news.sentiment_index", 0.0)
scope_values.setdefault("news.heat_score", 0.0)
if scope_values.get("macro.industry_heat") is None:
scope_values["macro.industry_heat"] = 0.5
if scope_values.get("macro.relative_strength") is None:
peer_strength = scope_values.get("index.performance_peers")
if peer_strength is None:
peer_strength = 0.5
scope_values["macro.relative_strength"] = peer_strength
scope_values.setdefault(
"index.performance_peers",
scope_values.get("macro.relative_strength", 0.5),
)
latest_close = scope_values.get("daily.close", 0.0) latest_close = scope_values.get("daily.close", 0.0)
latest_pct = scope_values.get("daily.pct_chg", 0.0) latest_pct = scope_values.get("daily.pct_chg", 0.0)

1
app/core/__init__.py Normal file
View File

@ -0,0 +1 @@
"""Core utilities shared across application layers."""

86
app/core/indicators.py Normal file
View File

@ -0,0 +1,86 @@
"""Reusable quantitative indicator helpers."""
from __future__ import annotations
from statistics import pstdev
from typing import Iterable, Sequence
def _to_float_list(values: Iterable[object]) -> list[float]:
cleaned: list[float] = []
for value in values:
try:
cleaned.append(float(value))
except (TypeError, ValueError):
continue
return cleaned
def momentum(series: Sequence[object], window: int) -> float:
"""Return simple momentum ratio over ``window`` periods.
``series`` is expected to be ordered from most recent to oldest. ``0.0`` is
returned when insufficient history or the denominator is invalid.
"""
if window <= 0:
return 0.0
numeric = _to_float_list(series)
if len(numeric) < window:
return 0.0
latest = numeric[0]
past = numeric[window - 1]
if past == 0.0:
return 0.0
try:
return (latest / past) - 1.0
except ZeroDivisionError:
return 0.0
def volatility(series: Sequence[object], window: int) -> float:
"""Compute population standard deviation of simple returns."""
if window <= 1:
return 0.0
numeric = _to_float_list(series)
if len(numeric) < 2:
return 0.0
limit = min(window, len(numeric) - 1)
returns: list[float] = []
for idx in range(limit):
current = numeric[idx]
previous = numeric[idx + 1]
if previous == 0.0:
continue
returns.append((current / previous) - 1.0)
if len(returns) < 2:
return 0.0
return float(pstdev(returns))
def rolling_mean(series: Sequence[object], window: int) -> float:
"""Return the arithmetic mean over the latest ``window`` observations."""
if window <= 0:
return 0.0
numeric = _to_float_list(series)
if not numeric:
return 0.0
subset = numeric[: min(window, len(numeric))]
if not subset:
return 0.0
return float(sum(subset) / len(subset))
def normalize(value: object, *, factor: float | None = None, clamp: tuple[float, float] = (0.0, 1.0)) -> float:
"""Clamp ``value`` into the ``clamp`` interval after optional scaling."""
if clamp[0] > clamp[1]:
raise ValueError("clamp minimum cannot exceed maximum")
try:
numeric = float(value)
except (TypeError, ValueError):
return clamp[0]
if factor and factor > 0:
numeric = numeric / factor
return max(clamp[0], min(clamp[1], numeric))

View File

@ -243,6 +243,7 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
"daily_basic.turnover_rate", "daily_basic.turnover_rate",
"factors.mom_20", "factors.mom_20",
"factors.mom_60", "factors.mom_60",
"factors.volat_20",
], ],
"prompt": "你主导动量风格研究,关注价格与成交量的加速变化,需在保持纪律的前提下判定短期多空倾向。", "prompt": "你主导动量风格研究,关注价格与成交量的加速变化,需在保持纪律的前提下判定短期多空倾向。",
}, },
@ -253,8 +254,9 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
"data_scope": [ "data_scope": [
"daily_basic.pe", "daily_basic.pe",
"daily_basic.pb", "daily_basic.pb",
"daily_basic.roe", "daily_basic.ps",
"fundamental.growth", "daily_basic.dv_ratio",
"factors.turn_20",
], ],
"prompt": "你负责价值与质量评估,应结合估值分位、盈利持续性及安全边际给出配置建议。", "prompt": "你负责价值与质量评估,应结合估值分位、盈利持续性及安全边际给出配置建议。",
}, },
@ -265,7 +267,6 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
"data_scope": [ "data_scope": [
"news.sentiment_index", "news.sentiment_index",
"news.heat_score", "news.heat_score",
"events.latest_headlines",
], ],
"prompt": "你专注新闻和事件驱动,应评估正负面舆情对标的短线波动的可能影响。", "prompt": "你专注新闻和事件驱动,应评估正负面舆情对标的短线波动的可能影响。",
}, },
@ -275,8 +276,11 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
"description": "衡量成交活跃度与交易成本,控制进出场的实现可能性。", "description": "衡量成交活跃度与交易成本,控制进出场的实现可能性。",
"data_scope": [ "data_scope": [
"daily_basic.volume_ratio", "daily_basic.volume_ratio",
"daily_basic.turnover_rate",
"daily_basic.turnover_rate_f", "daily_basic.turnover_rate_f",
"market.spread_estimate", "factors.turn_20",
"stk_limit.up_limit",
"stk_limit.down_limit",
], ],
"prompt": "你负责评估该标的的流动性与滑点风险,需要提出可执行的仓位调整建议。", "prompt": "你负责评估该标的的流动性与滑点风险,需要提出可执行的仓位调整建议。",
}, },
@ -286,8 +290,8 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
"description": "追踪宏观与行业景气度,为行业配置和风险偏好提供参考。", "description": "追踪宏观与行业景气度,为行业配置和风险偏好提供参考。",
"data_scope": [ "data_scope": [
"macro.industry_heat", "macro.industry_heat",
"macro.liquidity_cycle",
"index.performance_peers", "index.performance_peers",
"macro.relative_strength",
], ],
"prompt": "你负责宏观与行业研判,应结合宏观周期、行业景气与相对强弱给出方向性意见。", "prompt": "你负责宏观与行业研判,应结合宏观周期、行业景气与相对强弱给出方向性意见。",
}, },
@ -296,9 +300,10 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
"title": "风险控制部门", "title": "风险控制部门",
"description": "监控极端风险、合规与交易限制,必要时行使否决。", "description": "监控极端风险、合规与交易限制,必要时行使否决。",
"data_scope": [ "data_scope": [
"market.limit_flags", "daily.pct_chg",
"portfolio.position", "suspend.suspend_type",
"risk.alerts", "stk_limit.up_limit",
"stk_limit.down_limit",
], ],
"prompt": "你负责风险控制,应识别停牌、涨跌停、持仓约束等因素,必要时提出减仓或观望建议。", "prompt": "你负责风险控制,应识别停牌、涨跌停、持仓约束等因素,必要时提出减仓或观望建议。",
}, },
@ -608,9 +613,8 @@ def save_config(cfg: AppConfig | None = None) -> None:
try: try:
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = path.with_suffix(path.suffix + ".tmp") if path.suffix else path.with_name(path.name + ".tmp") with path.open("w", encoding="utf-8") as fh:
tmp_path.write_text(serialized, encoding="utf-8") fh.write(serialized)
tmp_path.replace(path)
LOGGER.info("配置已写入:%s", path) LOGGER.info("配置已写入:%s", path)
except OSError: except OSError:
LOGGER.exception("配置写入失败:%s", path) LOGGER.exception("配置写入失败:%s", path)

View File

@ -2,11 +2,14 @@
from __future__ import annotations from __future__ import annotations
import re import re
import sqlite3
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar, Dict, Iterable, List, Optional, Sequence, Tuple from datetime import datetime, timedelta
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Sequence, Tuple
from .db import db_session from .db import db_session
from .logging import get_logger from .logging import get_logger
from app.core.indicators import momentum, normalize, rolling_mean, volatility
LOGGER = get_logger(__name__) LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "data_broker"} LOG_EXTRA = {"stage": "data_broker"}
@ -38,6 +41,27 @@ def parse_field_path(path: str) -> Tuple[str, str] | None:
return _safe_split(path) return _safe_split(path)
def _parse_trade_date(value: object) -> Optional[datetime]:
if value is None:
return None
text = str(value).strip()
if not text:
return None
text = text.replace("-", "")
try:
return datetime.strptime(text[:8], "%Y%m%d")
except ValueError:
return None
def _start_of_day(dt: datetime) -> str:
return dt.strftime("%Y-%m-%d 00:00:00")
def _end_of_day(dt: datetime) -> str:
return dt.strftime("%Y-%m-%d 23:59:59")
@dataclass @dataclass
class DataBroker: class DataBroker:
"""Lightweight data access helper for agent/LLM consumption.""" """Lightweight data access helper for agent/LLM consumption."""
@ -65,60 +89,77 @@ class DataBroker:
}, },
} }
MAX_WINDOW: ClassVar[int] = 120 MAX_WINDOW: ClassVar[int] = 120
BENCHMARK_INDEX: ClassVar[str] = "000300.SH"
def fetch_latest( def fetch_latest(
self, self,
ts_code: str, ts_code: str,
trade_date: str, trade_date: str,
fields: Iterable[str], fields: Iterable[str],
) -> Dict[str, float]: ) -> Dict[str, Any]:
"""Fetch the latest value (<= trade_date) for each requested field.""" """Fetch the latest value (<= trade_date) for each requested field."""
grouped: Dict[str, List[str]] = {} grouped: Dict[str, List[str]] = {}
field_map: Dict[Tuple[str, str], List[str]] = {} field_map: Dict[Tuple[str, str], List[str]] = {}
derived_cache: Dict[str, Any] = {}
results: Dict[str, Any] = {}
for item in fields: for item in fields:
if not item: if not item:
continue continue
resolved = self.resolve_field(str(item)) field_name = str(item)
resolved = self.resolve_field(field_name)
if not resolved: if not resolved:
derived = self._resolve_derived_field(
ts_code,
trade_date,
field_name,
derived_cache,
)
if derived is not None:
results[field_name] = derived
continue continue
table, column = resolved table, column = resolved
grouped.setdefault(table, []) grouped.setdefault(table, [])
if column not in grouped[table]: if column not in grouped[table]:
grouped[table].append(column) grouped[table].append(column)
field_map.setdefault((table, column), []).append(str(item)) field_map.setdefault((table, column), []).append(field_name)
if not grouped: if not grouped:
return {} return results
results: Dict[str, float] = {} try:
with db_session(read_only=True) as conn: with db_session(read_only=True) as conn:
for table, columns in grouped.items(): for table, columns in grouped.items():
joined_cols = ", ".join(columns) joined_cols = ", ".join(columns)
query = ( query = (
f"SELECT trade_date, {joined_cols} FROM {table} " f"SELECT trade_date, {joined_cols} FROM {table} "
"WHERE ts_code = ? AND trade_date <= ? " "WHERE ts_code = ? AND trade_date <= ? "
"ORDER BY trade_date DESC LIMIT 1" "ORDER BY trade_date DESC LIMIT 1"
)
try:
row = conn.execute(query, (ts_code, trade_date)).fetchone()
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"查询失败 table=%s fields=%s err=%s",
table,
columns,
exc,
extra=LOG_EXTRA,
) )
continue try:
if not row: row = conn.execute(query, (ts_code, trade_date)).fetchone()
continue except Exception as exc: # noqa: BLE001
for column in columns: LOGGER.debug(
value = row[column] "查询失败 table=%s fields=%s err=%s",
if value is None: table,
columns,
exc,
extra=LOG_EXTRA,
)
continue continue
for original in field_map.get((table, column), [f"{table}.{column}"]): if not row:
results[original] = float(value) continue
for column in columns:
value = row[column]
if value is None:
continue
for original in field_map.get((table, column), [f"{table}.{column}"]):
try:
results[original] = float(value)
except (TypeError, ValueError):
results[original] = value
except sqlite3.OperationalError as exc:
LOGGER.debug("数据库只读连接失败:%s", exc, extra=LOG_EXTRA)
return results return results
def fetch_series( def fetch_series(
@ -149,18 +190,28 @@ class DataBroker:
"WHERE ts_code = ? AND trade_date <= ? " "WHERE ts_code = ? AND trade_date <= ? "
"ORDER BY trade_date DESC LIMIT ?" "ORDER BY trade_date DESC LIMIT ?"
) )
with db_session(read_only=True) as conn: try:
try: with db_session(read_only=True) as conn:
rows = conn.execute(query, (ts_code, end_date, window)).fetchall() try:
except Exception as exc: # noqa: BLE001 rows = conn.execute(query, (ts_code, end_date, window)).fetchall()
LOGGER.debug( except Exception as exc: # noqa: BLE001
"时间序列查询失败 table=%s column=%s err=%s", LOGGER.debug(
table, "时间序列查询失败 table=%s column=%s err=%s",
column, table,
exc, column,
extra=LOG_EXTRA, exc,
) extra=LOG_EXTRA,
return [] )
return []
except sqlite3.OperationalError as exc:
LOGGER.debug(
"时间序列连接失败 table=%s column=%s err=%s",
table,
column,
exc,
extra=LOG_EXTRA,
)
return []
series: List[Tuple[str, float]] = [] series: List[Tuple[str, float]] = []
for row in rows: for row in rows:
value = row[resolved] value = row[resolved]
@ -185,18 +236,27 @@ class DataBroker:
f"SELECT 1 FROM {table} WHERE ts_code = ? AND {where_clause} LIMIT 1" f"SELECT 1 FROM {table} WHERE ts_code = ? AND {where_clause} LIMIT 1"
) )
bind_params = (ts_code, *params) bind_params = (ts_code, *params)
with db_session(read_only=True) as conn: try:
try: with db_session(read_only=True) as conn:
row = conn.execute(query, bind_params).fetchone() try:
except Exception as exc: # noqa: BLE001 row = conn.execute(query, bind_params).fetchone()
LOGGER.debug( except Exception as exc: # noqa: BLE001
"flag 查询失败 table=%s where=%s err=%s", LOGGER.debug(
table, "flag 查询失败 table=%s where=%s err=%s",
where_clause, table,
exc, where_clause,
extra=LOG_EXTRA, exc,
) extra=LOG_EXTRA,
return False )
return False
except sqlite3.OperationalError as exc:
LOGGER.debug(
"flag 查询连接失败 table=%s err=%s",
table,
exc,
extra=LOG_EXTRA,
)
return False
return row is not None return row is not None
def fetch_table_rows( def fetch_table_rows(
@ -231,23 +291,288 @@ class DataBroker:
params = (ts_code, window) params = (ts_code, window)
results: List[Dict[str, object]] = [] results: List[Dict[str, object]] = []
with db_session(read_only=True) as conn: try:
try: with db_session(read_only=True) as conn:
rows = conn.execute(query, params).fetchall() try:
except Exception as exc: # noqa: BLE001 rows = conn.execute(query, params).fetchall()
LOGGER.debug( except Exception as exc: # noqa: BLE001
"表查询失败 table=%s err=%s", LOGGER.debug(
table, "表查询失败 table=%s err=%s",
exc, table,
extra=LOG_EXTRA, exc,
) extra=LOG_EXTRA,
return [] )
return []
except sqlite3.OperationalError as exc:
LOGGER.debug(
"表连接失败 table=%s err=%s",
table,
exc,
extra=LOG_EXTRA,
)
return []
for row in rows: for row in rows:
record = {col: row[col] for col in columns} record = {col: row[col] for col in columns}
results.append(record) results.append(record)
return results return results
def _resolve_derived_field(
self,
ts_code: str,
trade_date: str,
field: str,
cache: Dict[str, Any],
) -> Optional[Any]:
if field in cache:
return cache[field]
value: Optional[Any] = None
if field == "factors.mom_20":
value = self._derived_price_momentum(ts_code, trade_date, 20)
elif field == "factors.mom_60":
value = self._derived_price_momentum(ts_code, trade_date, 60)
elif field == "factors.volat_20":
value = self._derived_price_volatility(ts_code, trade_date, 20)
elif field == "factors.turn_20":
value = self._derived_turnover_mean(ts_code, trade_date, 20)
elif field == "news.sentiment_index":
rows = cache.get("__news_rows__")
if rows is None:
rows = self._fetch_recent_news(ts_code, trade_date)
cache["__news_rows__"] = rows
value = self._news_sentiment_from_rows(rows)
elif field == "news.heat_score":
rows = cache.get("__news_rows__")
if rows is None:
rows = self._fetch_recent_news(ts_code, trade_date)
cache["__news_rows__"] = rows
value = self._news_heat_from_rows(rows)
elif field == "macro.industry_heat":
value = self._derived_industry_heat(ts_code, trade_date)
elif field in {"macro.relative_strength", "index.performance_peers"}:
value = self._derived_relative_strength(ts_code, trade_date, cache)
cache[field] = value
return value
def _derived_price_momentum(
self,
ts_code: str,
trade_date: str,
window: int,
) -> Optional[float]:
series = self.fetch_series("daily", "close", ts_code, trade_date, window)
values = [value for _dt, value in series]
if not values:
return None
return momentum(values, window)
def _derived_price_volatility(
self,
ts_code: str,
trade_date: str,
window: int,
) -> Optional[float]:
series = self.fetch_series("daily", "close", ts_code, trade_date, window)
values = [value for _dt, value in series]
if len(values) < 2:
return None
return volatility(values, window)
def _derived_turnover_mean(
self,
ts_code: str,
trade_date: str,
window: int,
) -> Optional[float]:
series = self.fetch_series(
"daily_basic",
"turnover_rate",
ts_code,
trade_date,
window,
)
values = [value for _dt, value in series]
if not values:
return None
return rolling_mean(values, window)
def _fetch_recent_news(
self,
ts_code: str,
trade_date: str,
days: int = 3,
limit: int = 120,
) -> List[Dict[str, Any]]:
baseline = _parse_trade_date(trade_date)
if baseline is None:
return []
start = _start_of_day(baseline - timedelta(days=days))
end = _end_of_day(baseline)
query = (
"SELECT sentiment, heat FROM news "
"WHERE ts_code = ? AND pub_time BETWEEN ? AND ? "
"ORDER BY pub_time DESC LIMIT ?"
)
try:
with db_session(read_only=True) as conn:
rows = conn.execute(query, (ts_code, start, end, limit)).fetchall()
except sqlite3.OperationalError as exc:
LOGGER.debug(
"新闻查询连接失败 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
return []
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"新闻查询失败 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
return []
return [dict(row) for row in rows]
@staticmethod
def _news_sentiment_from_rows(rows: List[Dict[str, Any]]) -> Optional[float]:
sentiments: List[float] = []
for row in rows:
value = row.get("sentiment")
if value is None:
continue
try:
sentiments.append(float(value))
except (TypeError, ValueError):
continue
if not sentiments:
return None
avg = sum(sentiments) / len(sentiments)
return max(-1.0, min(1.0, avg))
@staticmethod
def _news_heat_from_rows(rows: List[Dict[str, Any]]) -> Optional[float]:
if not rows:
return None
total_heat = 0.0
for row in rows:
value = row.get("heat")
if value is None:
continue
try:
total_heat += max(float(value), 0.0)
except (TypeError, ValueError):
continue
if total_heat > 0:
return normalize(total_heat, factor=100.0)
return normalize(len(rows), factor=20.0)
def _derived_industry_heat(self, ts_code: str, trade_date: str) -> Optional[float]:
industry = self._lookup_industry(ts_code)
if not industry:
return None
query = (
"SELECT heat FROM heat_daily "
"WHERE scope = ? AND key = ? AND trade_date <= ? "
"ORDER BY trade_date DESC LIMIT 1"
)
try:
with db_session(read_only=True) as conn:
row = conn.execute(query, ("industry", industry, trade_date)).fetchone()
except sqlite3.OperationalError as exc:
LOGGER.debug(
"行业热度查询失败 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
return None
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"行业热度读取异常 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
return None
if not row:
return None
heat_value = row["heat"]
if heat_value is None:
return None
return normalize(heat_value, factor=100.0)
def _lookup_industry(self, ts_code: str) -> Optional[str]:
cache = getattr(self, "_industry_cache", None)
if cache is None:
cache = {}
self._industry_cache = cache
if ts_code in cache:
return cache[ts_code]
query = "SELECT industry FROM stock_basic WHERE ts_code = ?"
try:
with db_session(read_only=True) as conn:
row = conn.execute(query, (ts_code,)).fetchone()
except sqlite3.OperationalError as exc:
LOGGER.debug(
"行业查询连接失败 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
cache[ts_code] = None
return None
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"行业查询失败 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
cache[ts_code] = None
return None
industry = None
if row:
industry = row["industry"]
cache[ts_code] = industry
return industry
def _derived_relative_strength(
self,
ts_code: str,
trade_date: str,
cache: Dict[str, Any],
) -> Optional[float]:
window = 20
series = self.fetch_series("daily", "close", ts_code, trade_date, max(window, 30))
values = [value for _dt, value in series]
if not values:
return None
stock_momentum = momentum(values, window)
bench_key = f"__benchmark_mom_{window}"
benchmark = cache.get(bench_key)
if benchmark is None:
benchmark = self._index_momentum(trade_date, window)
cache[bench_key] = benchmark
diff = stock_momentum if benchmark is None else stock_momentum - benchmark
diff = max(-0.2, min(0.2, diff))
return (diff + 0.2) / 0.4
def _index_momentum(self, trade_date: str, window: int) -> Optional[float]:
series = self.fetch_series(
"index_daily",
"close",
self.BENCHMARK_INDEX,
trade_date,
window,
)
values = [value for _dt, value in series]
if not values:
return None
return momentum(values, window)
def resolve_field(self, field: str) -> Optional[Tuple[str, str]]: def resolve_field(self, field: str) -> Optional[Tuple[str, str]]:
normalized = _safe_split(field) normalized = _safe_split(field)
if not normalized: if not normalized: