update

commit b3f2f5b4fc
parent a6564cdced
@@ -74,7 +74,15 @@ class DepartmentDecision:
 class DepartmentAgent:
     """Wraps LLM ensemble logic for a single analytical department."""
 
-    ALLOWED_TABLES: ClassVar[List[str]] = ["daily", "daily_basic"]
+    ALLOWED_TABLES: ClassVar[List[str]] = [
+        "daily",
+        "daily_basic",
+        "stk_limit",
+        "suspend",
+        "heat_daily",
+        "news",
+        "index_daily",
+    ]
 
     def __init__(
         self,
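Note: this widened whitelist presumably gates which tables department data scopes may query. A minimal sketch of how such a guard is typically consulted (the helper name and error message are illustrative; the actual check is not shown in this diff):

    def _check_table(table: str) -> None:
        # Reject any table outside the class-level whitelist before building a query.
        if table not in DepartmentAgent.ALLOWED_TABLES:
            raise ValueError(f"table not allowed for department data scope: {table}")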
@@ -4,7 +4,6 @@ from __future__ import annotations
 import json
 from dataclasses import dataclass, field
 from datetime import date
-from statistics import mean, pstdev
 from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
 
 from app.agents.base import AgentAction, AgentContext
@@ -16,50 +15,13 @@ from app.utils.data_access import DataBroker
 from app.utils.config import get_config
 from app.utils.db import db_session
 from app.utils.logging import get_logger
+from app.core.indicators import momentum, normalize, rolling_mean, volatility
 
 
 LOGGER = get_logger(__name__)
 LOG_EXTRA = {"stage": "backtest"}
 
 
-def _compute_momentum(values: List[float], window: int) -> float:
-    if window <= 0 or len(values) < window:
-        return 0.0
-    latest = values[0]
-    past = values[window - 1]
-    if past is None or past == 0:
-        return 0.0
-    try:
-        return (latest / past) - 1.0
-    except ZeroDivisionError:
-        return 0.0
-
-
-def _compute_volatility(values: List[float], window: int) -> float:
-    if len(values) < 2 or window <= 1:
-        return 0.0
-    limit = min(window, len(values) - 1)
-    returns: List[float] = []
-    for idx in range(limit):
-        current = values[idx]
-        previous = values[idx + 1]
-        if previous is None or previous == 0:
-            continue
-        returns.append((current / previous) - 1.0)
-    if len(returns) < 2:
-        return 0.0
-    return float(pstdev(returns))
-
-
-def _normalize(value: Any, factor: float) -> float:
-    try:
-        numeric = float(value)
-    except (TypeError, ValueError):
-        return 0.0
-    if factor <= 0:
-        return max(0.0, min(1.0, numeric))
-    return max(0.0, min(1.0, numeric / factor))
-
-
 @dataclass
 class BtConfig:
@@ -143,9 +105,9 @@ class BacktestEngine:
             window=60,
         )
         close_values = [value for _date, value in closes]
-        mom20 = _compute_momentum(close_values, 20)
-        mom60 = _compute_momentum(close_values, 60)
-        volat20 = _compute_volatility(close_values, 20)
+        mom20 = momentum(close_values, 20)
+        mom60 = momentum(close_values, 60)
+        volat20 = volatility(close_values, 20)
 
         turnover_series = self.data_broker.fetch_series(
             "daily_basic",
@@ -155,10 +117,31 @@ class BacktestEngine:
             window=20,
         )
         turnover_values = [value for _date, value in turnover_series]
-        turn20 = mean(turnover_values) if turnover_values else 0.0
+        turn20 = rolling_mean(turnover_values, 20)
 
-        liquidity_score = _normalize(turn20, factor=20.0)
-        cost_penalty = _normalize(scope_values.get("daily_basic.volume_ratio", 0.0), factor=50.0)
+        liquidity_score = normalize(turn20, factor=20.0)
+        cost_penalty = normalize(
+            scope_values.get("daily_basic.volume_ratio", 0.0),
+            factor=50.0,
+        )
+
+        scope_values.setdefault("factors.mom_20", mom20)
+        scope_values.setdefault("factors.mom_60", mom60)
+        scope_values.setdefault("factors.volat_20", volat20)
+        scope_values.setdefault("factors.turn_20", turn20)
+        scope_values.setdefault("news.sentiment_index", 0.0)
+        scope_values.setdefault("news.heat_score", 0.0)
+        if scope_values.get("macro.industry_heat") is None:
+            scope_values["macro.industry_heat"] = 0.5
+        if scope_values.get("macro.relative_strength") is None:
+            peer_strength = scope_values.get("index.performance_peers")
+            if peer_strength is None:
+                peer_strength = 0.5
+            scope_values["macro.relative_strength"] = peer_strength
+        scope_values.setdefault(
+            "index.performance_peers",
+            scope_values.get("macro.relative_strength", 0.5),
+        )
 
         latest_close = scope_values.get("daily.close", 0.0)
         latest_pct = scope_values.get("daily.pct_chg", 0.0)
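Note: a quick sanity check of the shared helpers that replace the local _compute_* functions here, assuming the definitions in the new app/core/indicators.py below. Series are ordered newest-first, matching fetch_series output; the numbers are illustrative:

    from app.core.indicators import momentum, normalize, rolling_mean

    closes = [11.0, 10.5, 10.0]                    # today, yesterday, two days ago
    assert abs(momentum(closes, 3) - 0.1) < 1e-9   # 11.0 / 10.0 - 1.0
    assert rolling_mean(closes, 2) == 10.75        # mean of the latest two values
    assert normalize(30.0, factor=20.0) == 1.0     # 30 / 20 = 1.5, clamped to [0, 1]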
new file: app/core/__init__.py (1 line)
@@ -0,0 +1 @@
+"""Core utilities shared across application layers."""

new file: app/core/indicators.py (86 lines)
@@ -0,0 +1,86 @@
+"""Reusable quantitative indicator helpers."""
+from __future__ import annotations
+
+from statistics import pstdev
+from typing import Iterable, Sequence
+
+
+def _to_float_list(values: Iterable[object]) -> list[float]:
+    cleaned: list[float] = []
+    for value in values:
+        try:
+            cleaned.append(float(value))
+        except (TypeError, ValueError):
+            continue
+    return cleaned
+
+
+def momentum(series: Sequence[object], window: int) -> float:
+    """Return the simple momentum ratio over ``window`` periods.
+
+    ``series`` is expected to be ordered newest to oldest; ``0.0`` is returned
+    when there is insufficient history or the denominator is invalid.
+    """
+
+    if window <= 0:
+        return 0.0
+    numeric = _to_float_list(series)
+    if len(numeric) < window:
+        return 0.0
+    latest = numeric[0]
+    past = numeric[window - 1]
+    if past == 0.0:
+        return 0.0
+    try:
+        return (latest / past) - 1.0
+    except ZeroDivisionError:
+        return 0.0
+
+
+def volatility(series: Sequence[object], window: int) -> float:
+    """Compute the population standard deviation of simple returns."""
+
+    if window <= 1:
+        return 0.0
+    numeric = _to_float_list(series)
+    if len(numeric) < 2:
+        return 0.0
+    limit = min(window, len(numeric) - 1)
+    returns: list[float] = []
+    for idx in range(limit):
+        current = numeric[idx]
+        previous = numeric[idx + 1]
+        if previous == 0.0:
+            continue
+        returns.append((current / previous) - 1.0)
+    if len(returns) < 2:
+        return 0.0
+    return float(pstdev(returns))
+
+
+def rolling_mean(series: Sequence[object], window: int) -> float:
+    """Return the arithmetic mean over the latest ``window`` observations."""
+
+    if window <= 0:
+        return 0.0
+    numeric = _to_float_list(series)
+    if not numeric:
+        return 0.0
+    subset = numeric[: min(window, len(numeric))]
+    if not subset:
+        return 0.0
+    return float(sum(subset) / len(subset))
+
+
+def normalize(value: object, *, factor: float | None = None, clamp: tuple[float, float] = (0.0, 1.0)) -> float:
+    """Clamp ``value`` into the ``clamp`` interval after optional scaling."""
+
+    if clamp[0] > clamp[1]:
+        raise ValueError("clamp minimum cannot exceed maximum")
+    try:
+        numeric = float(value)
+    except (TypeError, ValueError):
+        return clamp[0]
+    if factor and factor > 0:
+        numeric = numeric / factor
+    return max(clamp[0], min(clamp[1], numeric))
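Note: a worked example of the two less obvious helpers, assuming the definitions above. volatility pairs each value with its older neighbour and takes the population stdev of the simple returns; normalize accepts a custom clamp interval:

    from statistics import pstdev
    from app.core.indicators import normalize, volatility

    closes = [10.2, 10.0, 10.0, 9.8]    # newest first
    # pairwise simple returns over a window of 3:
    rets = [10.2 / 10.0 - 1, 10.0 / 10.0 - 1, 10.0 / 9.8 - 1]
    assert volatility(closes, 3) == pstdev(rets)
    assert normalize(-0.5, clamp=(-1.0, 1.0)) == -0.5   # no factor, clamp only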
@@ -243,6 +243,7 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
             "daily_basic.turnover_rate",
             "factors.mom_20",
             "factors.mom_60",
+            "factors.volat_20",
         ],
         "prompt": "你主导动量风格研究,关注价格与成交量的加速变化,需在保持纪律的前提下判定短期多空倾向。",
     },
@@ -253,8 +254,9 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
         "data_scope": [
             "daily_basic.pe",
             "daily_basic.pb",
-            "daily_basic.roe",
-            "fundamental.growth",
+            "daily_basic.ps",
+            "daily_basic.dv_ratio",
+            "factors.turn_20",
         ],
         "prompt": "你负责价值与质量评估,应结合估值分位、盈利持续性及安全边际给出配置建议。",
     },
@@ -265,7 +267,6 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
         "data_scope": [
             "news.sentiment_index",
             "news.heat_score",
-            "events.latest_headlines",
         ],
         "prompt": "你专注新闻和事件驱动,应评估正负面舆情对标的短线波动的可能影响。",
     },
@@ -275,8 +276,11 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
         "description": "衡量成交活跃度与交易成本,控制进出场的实现可能性。",
         "data_scope": [
             "daily_basic.volume_ratio",
+            "daily_basic.turnover_rate",
             "daily_basic.turnover_rate_f",
-            "market.spread_estimate",
+            "factors.turn_20",
+            "stk_limit.up_limit",
+            "stk_limit.down_limit",
         ],
         "prompt": "你负责评估该标的的流动性与滑点风险,需要提出可执行的仓位调整建议。",
     },
@@ -286,8 +290,8 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
         "description": "追踪宏观与行业景气度,为行业配置和风险偏好提供参考。",
         "data_scope": [
             "macro.industry_heat",
-            "macro.liquidity_cycle",
             "index.performance_peers",
+            "macro.relative_strength",
         ],
         "prompt": "你负责宏观与行业研判,应结合宏观周期、行业景气与相对强弱给出方向性意见。",
     },
@@ -296,9 +300,10 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
         "title": "风险控制部门",
         "description": "监控极端风险、合规与交易限制,必要时行使否决。",
         "data_scope": [
-            "market.limit_flags",
-            "portfolio.position",
-            "risk.alerts",
+            "daily.pct_chg",
+            "suspend.suspend_type",
+            "stk_limit.up_limit",
+            "stk_limit.down_limit",
         ],
         "prompt": "你负责风险控制,应识别停牌、涨跌停、持仓约束等因素,必要时提出减仓或观望建议。",
     },
@@ -608,9 +613,8 @@ def save_config(cfg: AppConfig | None = None) -> None:
 
     try:
         path.parent.mkdir(parents=True, exist_ok=True)
-        tmp_path = path.with_suffix(path.suffix + ".tmp") if path.suffix else path.with_name(path.name + ".tmp")
-        tmp_path.write_text(serialized, encoding="utf-8")
-        tmp_path.replace(path)
+        with path.open("w", encoding="utf-8") as fh:
+            fh.write(serialized)
         LOGGER.info("配置已写入:%s", path)
     except OSError:
         LOGGER.exception("配置写入失败:%s", path)
@@ -2,11 +2,14 @@
 from __future__ import annotations
 
 import re
+import sqlite3
 from dataclasses import dataclass
-from typing import ClassVar, Dict, Iterable, List, Optional, Sequence, Tuple
+from datetime import datetime, timedelta
+from typing import Any, ClassVar, Dict, Iterable, List, Optional, Sequence, Tuple
 
 from .db import db_session
 from .logging import get_logger
+from app.core.indicators import momentum, normalize, rolling_mean, volatility
 
 LOGGER = get_logger(__name__)
 LOG_EXTRA = {"stage": "data_broker"}
@@ -38,6 +41,27 @@ def parse_field_path(path: str) -> Tuple[str, str] | None:
     return _safe_split(path)
 
 
+def _parse_trade_date(value: object) -> Optional[datetime]:
+    if value is None:
+        return None
+    text = str(value).strip()
+    if not text:
+        return None
+    text = text.replace("-", "")
+    try:
+        return datetime.strptime(text[:8], "%Y%m%d")
+    except ValueError:
+        return None
+
+
+def _start_of_day(dt: datetime) -> str:
+    return dt.strftime("%Y-%m-%d 00:00:00")
+
+
+def _end_of_day(dt: datetime) -> str:
+    return dt.strftime("%Y-%m-%d 23:59:59")
+
+
 @dataclass
 class DataBroker:
     """Lightweight data access helper for agent/LLM consumption."""
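Note: the new date helpers accept either YYYYMMDD or YYYY-MM-DD trade dates and produce the day bounds used for the news window further down. Illustrative values:

    from datetime import timedelta

    dt = _parse_trade_date("2024-05-06")     # datetime(2024, 5, 6, 0, 0)
    _start_of_day(dt - timedelta(days=3))    # '2024-05-03 00:00:00'
    _end_of_day(dt)                          # '2024-05-06 23:59:59'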
@@ -65,33 +89,45 @@ class DataBroker:
         },
     }
     MAX_WINDOW: ClassVar[int] = 120
+    BENCHMARK_INDEX: ClassVar[str] = "000300.SH"
 
     def fetch_latest(
         self,
         ts_code: str,
         trade_date: str,
         fields: Iterable[str],
-    ) -> Dict[str, float]:
+    ) -> Dict[str, Any]:
         """Fetch the latest value (<= trade_date) for each requested field."""
 
         grouped: Dict[str, List[str]] = {}
         field_map: Dict[Tuple[str, str], List[str]] = {}
+        derived_cache: Dict[str, Any] = {}
+        results: Dict[str, Any] = {}
         for item in fields:
             if not item:
                 continue
-            resolved = self.resolve_field(str(item))
+            field_name = str(item)
+            resolved = self.resolve_field(field_name)
             if not resolved:
+                derived = self._resolve_derived_field(
+                    ts_code,
+                    trade_date,
+                    field_name,
+                    derived_cache,
+                )
+                if derived is not None:
+                    results[field_name] = derived
                 continue
             table, column = resolved
             grouped.setdefault(table, [])
             if column not in grouped[table]:
                 grouped[table].append(column)
-            field_map.setdefault((table, column), []).append(str(item))
+            field_map.setdefault((table, column), []).append(field_name)
 
         if not grouped:
-            return {}
+            return results
 
-        results: Dict[str, float] = {}
-        with db_session(read_only=True) as conn:
-            for table, columns in grouped.items():
-                joined_cols = ", ".join(columns)
+        try:
+            with db_session(read_only=True) as conn:
+                for table, columns in grouped.items():
+                    joined_cols = ", ".join(columns)
@@ -118,7 +154,12 @@ class DataBroker:
                         if value is None:
                             continue
                         for original in field_map.get((table, column), [f"{table}.{column}"]):
-                            results[original] = float(value)
+                            try:
+                                results[original] = float(value)
+                            except (TypeError, ValueError):
+                                results[original] = value
+        except sqlite3.OperationalError as exc:
+            LOGGER.debug("数据库只读连接失败:%s", exc, extra=LOG_EXTRA)
         return results
 
     def fetch_series(
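Note: with the derived-field fallback, a single fetch_latest call can now mix raw table columns with computed factors. A sketch, assuming a DataBroker can be constructed with defaults; the ts_code and date are illustrative:

    broker = DataBroker()
    values = broker.fetch_latest(
        "600000.SH",
        "20240506",
        ["daily.close", "factors.mom_20", "news.sentiment_index"],
    )
    # "daily.close" resolves to a table column; the other two fall through to
    # _resolve_derived_field and are computed on demand (None results are skipped).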
@@ -149,6 +190,7 @@ class DataBroker:
             "WHERE ts_code = ? AND trade_date <= ? "
             "ORDER BY trade_date DESC LIMIT ?"
         )
+        try:
             with db_session(read_only=True) as conn:
                 try:
                     rows = conn.execute(query, (ts_code, end_date, window)).fetchall()
@@ -161,6 +203,15 @@ class DataBroker:
                         extra=LOG_EXTRA,
                     )
                     return []
+        except sqlite3.OperationalError as exc:
+            LOGGER.debug(
+                "时间序列连接失败 table=%s column=%s err=%s",
+                table,
+                column,
+                exc,
+                extra=LOG_EXTRA,
+            )
+            return []
         series: List[Tuple[str, float]] = []
         for row in rows:
             value = row[resolved]
@@ -185,6 +236,7 @@ class DataBroker:
             f"SELECT 1 FROM {table} WHERE ts_code = ? AND {where_clause} LIMIT 1"
         )
         bind_params = (ts_code, *params)
+        try:
             with db_session(read_only=True) as conn:
                 try:
                     row = conn.execute(query, bind_params).fetchone()
@@ -197,6 +249,14 @@ class DataBroker:
                         extra=LOG_EXTRA,
                     )
                     return False
+        except sqlite3.OperationalError as exc:
+            LOGGER.debug(
+                "flag 查询连接失败 table=%s err=%s",
+                table,
+                exc,
+                extra=LOG_EXTRA,
+            )
+            return False
         return row is not None
 
     def fetch_table_rows(
@@ -231,6 +291,7 @@ class DataBroker:
         params = (ts_code, window)
 
         results: List[Dict[str, object]] = []
+        try:
             with db_session(read_only=True) as conn:
                 try:
                     rows = conn.execute(query, params).fetchall()
@@ -242,12 +303,276 @@ class DataBroker:
                         extra=LOG_EXTRA,
                     )
                     return []
+        except sqlite3.OperationalError as exc:
+            LOGGER.debug(
+                "表连接失败 table=%s err=%s",
+                table,
+                exc,
+                extra=LOG_EXTRA,
+            )
+            return []
 
         for row in rows:
             record = {col: row[col] for col in columns}
             results.append(record)
         return results
 
+    def _resolve_derived_field(
+        self,
+        ts_code: str,
+        trade_date: str,
+        field: str,
+        cache: Dict[str, Any],
+    ) -> Optional[Any]:
+        if field in cache:
+            return cache[field]
+
+        value: Optional[Any] = None
+        if field == "factors.mom_20":
+            value = self._derived_price_momentum(ts_code, trade_date, 20)
+        elif field == "factors.mom_60":
+            value = self._derived_price_momentum(ts_code, trade_date, 60)
+        elif field == "factors.volat_20":
+            value = self._derived_price_volatility(ts_code, trade_date, 20)
+        elif field == "factors.turn_20":
+            value = self._derived_turnover_mean(ts_code, trade_date, 20)
+        elif field == "news.sentiment_index":
+            rows = cache.get("__news_rows__")
+            if rows is None:
+                rows = self._fetch_recent_news(ts_code, trade_date)
+                cache["__news_rows__"] = rows
+            value = self._news_sentiment_from_rows(rows)
+        elif field == "news.heat_score":
+            rows = cache.get("__news_rows__")
+            if rows is None:
+                rows = self._fetch_recent_news(ts_code, trade_date)
+                cache["__news_rows__"] = rows
+            value = self._news_heat_from_rows(rows)
+        elif field == "macro.industry_heat":
+            value = self._derived_industry_heat(ts_code, trade_date)
+        elif field in {"macro.relative_strength", "index.performance_peers"}:
+            value = self._derived_relative_strength(ts_code, trade_date, cache)
+
+        cache[field] = value
+        return value
+
+    def _derived_price_momentum(
+        self,
+        ts_code: str,
+        trade_date: str,
+        window: int,
+    ) -> Optional[float]:
+        series = self.fetch_series("daily", "close", ts_code, trade_date, window)
+        values = [value for _dt, value in series]
+        if not values:
+            return None
+        return momentum(values, window)
+
+    def _derived_price_volatility(
+        self,
+        ts_code: str,
+        trade_date: str,
+        window: int,
+    ) -> Optional[float]:
+        series = self.fetch_series("daily", "close", ts_code, trade_date, window)
+        values = [value for _dt, value in series]
+        if len(values) < 2:
+            return None
+        return volatility(values, window)
+
+    def _derived_turnover_mean(
+        self,
+        ts_code: str,
+        trade_date: str,
+        window: int,
+    ) -> Optional[float]:
+        series = self.fetch_series(
+            "daily_basic",
+            "turnover_rate",
+            ts_code,
+            trade_date,
+            window,
+        )
+        values = [value for _dt, value in series]
+        if not values:
+            return None
+        return rolling_mean(values, window)
+
+    def _fetch_recent_news(
+        self,
+        ts_code: str,
+        trade_date: str,
+        days: int = 3,
+        limit: int = 120,
+    ) -> List[Dict[str, Any]]:
+        baseline = _parse_trade_date(trade_date)
+        if baseline is None:
+            return []
+        start = _start_of_day(baseline - timedelta(days=days))
+        end = _end_of_day(baseline)
+        query = (
+            "SELECT sentiment, heat FROM news "
+            "WHERE ts_code = ? AND pub_time BETWEEN ? AND ? "
+            "ORDER BY pub_time DESC LIMIT ?"
+        )
+        try:
+            with db_session(read_only=True) as conn:
+                rows = conn.execute(query, (ts_code, start, end, limit)).fetchall()
+        except sqlite3.OperationalError as exc:
+            LOGGER.debug(
+                "新闻查询连接失败 ts_code=%s err=%s",
+                ts_code,
+                exc,
+                extra=LOG_EXTRA,
+            )
+            return []
+        except Exception as exc:  # noqa: BLE001
+            LOGGER.debug(
+                "新闻查询失败 ts_code=%s err=%s",
+                ts_code,
+                exc,
+                extra=LOG_EXTRA,
+            )
+            return []
+        return [dict(row) for row in rows]
+
+    @staticmethod
+    def _news_sentiment_from_rows(rows: List[Dict[str, Any]]) -> Optional[float]:
+        sentiments: List[float] = []
+        for row in rows:
+            value = row.get("sentiment")
+            if value is None:
+                continue
+            try:
+                sentiments.append(float(value))
+            except (TypeError, ValueError):
+                continue
+        if not sentiments:
+            return None
+        avg = sum(sentiments) / len(sentiments)
+        return max(-1.0, min(1.0, avg))
+
+    @staticmethod
+    def _news_heat_from_rows(rows: List[Dict[str, Any]]) -> Optional[float]:
+        if not rows:
+            return None
+        total_heat = 0.0
+        for row in rows:
+            value = row.get("heat")
+            if value is None:
+                continue
+            try:
+                total_heat += max(float(value), 0.0)
+            except (TypeError, ValueError):
+                continue
+        if total_heat > 0:
+            return normalize(total_heat, factor=100.0)
+        return normalize(len(rows), factor=20.0)
+
+    def _derived_industry_heat(self, ts_code: str, trade_date: str) -> Optional[float]:
+        industry = self._lookup_industry(ts_code)
+        if not industry:
+            return None
+        query = (
+            "SELECT heat FROM heat_daily "
+            "WHERE scope = ? AND key = ? AND trade_date <= ? "
+            "ORDER BY trade_date DESC LIMIT 1"
+        )
+        try:
+            with db_session(read_only=True) as conn:
+                row = conn.execute(query, ("industry", industry, trade_date)).fetchone()
+        except sqlite3.OperationalError as exc:
+            LOGGER.debug(
+                "行业热度查询失败 ts_code=%s err=%s",
+                ts_code,
+                exc,
+                extra=LOG_EXTRA,
+            )
+            return None
+        except Exception as exc:  # noqa: BLE001
+            LOGGER.debug(
+                "行业热度读取异常 ts_code=%s err=%s",
+                ts_code,
+                exc,
+                extra=LOG_EXTRA,
+            )
+            return None
+        if not row:
+            return None
+        heat_value = row["heat"]
+        if heat_value is None:
+            return None
+        return normalize(heat_value, factor=100.0)
+
+    def _lookup_industry(self, ts_code: str) -> Optional[str]:
+        cache = getattr(self, "_industry_cache", None)
+        if cache is None:
+            cache = {}
+            self._industry_cache = cache
+        if ts_code in cache:
+            return cache[ts_code]
+        query = "SELECT industry FROM stock_basic WHERE ts_code = ?"
+        try:
+            with db_session(read_only=True) as conn:
+                row = conn.execute(query, (ts_code,)).fetchone()
+        except sqlite3.OperationalError as exc:
+            LOGGER.debug(
+                "行业查询连接失败 ts_code=%s err=%s",
+                ts_code,
+                exc,
+                extra=LOG_EXTRA,
+            )
+            cache[ts_code] = None
+            return None
+        except Exception as exc:  # noqa: BLE001
+            LOGGER.debug(
+                "行业查询失败 ts_code=%s err=%s",
+                ts_code,
+                exc,
+                extra=LOG_EXTRA,
+            )
+            cache[ts_code] = None
+            return None
+        industry = None
+        if row:
+            industry = row["industry"]
+        cache[ts_code] = industry
+        return industry
+
+    def _derived_relative_strength(
+        self,
+        ts_code: str,
+        trade_date: str,
+        cache: Dict[str, Any],
+    ) -> Optional[float]:
+        window = 20
+        series = self.fetch_series("daily", "close", ts_code, trade_date, max(window, 30))
+        values = [value for _dt, value in series]
+        if not values:
+            return None
+        stock_momentum = momentum(values, window)
+        bench_key = f"__benchmark_mom_{window}"
+        benchmark = cache.get(bench_key)
+        if benchmark is None:
+            benchmark = self._index_momentum(trade_date, window)
+            cache[bench_key] = benchmark
+        diff = stock_momentum if benchmark is None else stock_momentum - benchmark
+        diff = max(-0.2, min(0.2, diff))
+        return (diff + 0.2) / 0.4
+
+    def _index_momentum(self, trade_date: str, window: int) -> Optional[float]:
+        series = self.fetch_series(
+            "index_daily",
+            "close",
+            self.BENCHMARK_INDEX,
+            trade_date,
+            window,
+        )
+        values = [value for _dt, value in series]
+        if not values:
+            return None
+        return momentum(values, window)
+
     def resolve_field(self, field: str) -> Optional[Tuple[str, str]]:
         normalized = _safe_split(field)
         if not normalized:
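Note: _derived_relative_strength maps the stock-versus-benchmark momentum gap, clamped to [-0.2, 0.2], linearly onto [0, 1]: a gap of -0.2 scores 0.0, parity scores 0.5, and +0.2 scores 1.0. For example, a stock five momentum points ahead of the benchmark:

    diff = max(-0.2, min(0.2, 0.05))
    score = (diff + 0.2) / 0.4    # 0.625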