add index membership schema and default indices initialization
This commit is contained in:
parent
4b7d64f915
commit
8f74525875
@ -5,6 +5,7 @@ import sqlite3
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterable, List
|
||||
|
||||
from app.data.schema_index import initialize_index_membership_tables, add_default_indices
|
||||
from app.utils.db import db_session
|
||||
|
||||
|
||||
@ -565,6 +566,10 @@ def initialize_database() -> MigrationResult:
|
||||
with db_session() as session:
|
||||
cursor = session.cursor()
|
||||
|
||||
# 初始化指数相关表
|
||||
initialize_index_membership_tables(session)
|
||||
add_default_indices()
|
||||
|
||||
# 创建表
|
||||
for statement in SCHEMA_STATEMENTS:
|
||||
try:
|
||||
|
||||
42
app/data/schema_index.py
Normal file
42
app/data/schema_index.py
Normal file
@ -0,0 +1,42 @@
|
||||
"""SQL schema for index membership."""
|
||||
from app.utils.db import db_session
|
||||
|
||||
def initialize_index_membership_tables(conn):
|
||||
"""Create tables for tracking index membership."""
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS index_weight (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
index_code VARCHAR(10) NOT NULL,
|
||||
trade_date VARCHAR(8) NOT NULL,
|
||||
ts_code VARCHAR(10) NOT NULL,
|
||||
weight FLOAT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
INDEX idx_index_weight_lookup (index_code, trade_date)
|
||||
)
|
||||
""")
|
||||
|
||||
def add_default_indices():
|
||||
"""Add default index list."""
|
||||
indices = [
|
||||
("000300.SH", "沪深300"),
|
||||
("000905.SH", "中证500"),
|
||||
("000852.SH", "中证1000")
|
||||
]
|
||||
with db_session() as conn:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS indices (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
index_code VARCHAR(10) NOT NULL UNIQUE,
|
||||
name VARCHAR(50) NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
for code, name in indices:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO indices (index_code, name)
|
||||
VALUES (?, ?)
|
||||
""",
|
||||
(code, name)
|
||||
)
|
||||
@ -48,6 +48,20 @@ def render_stock_evaluation() -> None:
|
||||
"市场类因子": [f for f in DEFAULT_FACTORS if f.name.startswith("market_")]
|
||||
}
|
||||
|
||||
# 定义默认选中的关键常用因子
|
||||
DEFAULT_SELECTED_FACTORS = {
|
||||
"mom_5", # 5日动量
|
||||
"mom_20", # 20日动量
|
||||
"mom_60", # 60日动量
|
||||
"volat_20", # 20日波动率
|
||||
"turn_5", # 5日换手率
|
||||
"turn_20", # 20日换手率
|
||||
"val_pe_score", # PE评分
|
||||
"val_pb_score", # PB评分
|
||||
"volume_ratio_score", # 量比评分
|
||||
"risk_penalty" # 风险惩罚项
|
||||
}
|
||||
|
||||
selected_factors = []
|
||||
for group_name, factors in factor_groups.items():
|
||||
if factors:
|
||||
@ -56,7 +70,7 @@ def render_stock_evaluation() -> None:
|
||||
for i, factor in enumerate(factors):
|
||||
if cols[i % 3].checkbox(
|
||||
factor.name,
|
||||
value=factor.name in selected_factors,
|
||||
value=factor.name in DEFAULT_SELECTED_FACTORS,
|
||||
help=factor.description if hasattr(factor, 'description') else None
|
||||
):
|
||||
selected_factors.append(factor.name)
|
||||
|
||||
@ -1100,6 +1100,53 @@ class DataBroker:
|
||||
LOGGER.exception("强制刷新数据失败: %s", exc, extra=LOG_EXTRA)
|
||||
return False
|
||||
|
||||
def get_index_stocks(
|
||||
self,
|
||||
index_code: str,
|
||||
trade_date: str,
|
||||
min_weight: float = 0.0
|
||||
) -> List[str]:
|
||||
"""获取指数成分股列表。
|
||||
|
||||
Args:
|
||||
index_code: 指数代码(如 000300.SH)
|
||||
trade_date: 交易日期
|
||||
min_weight: 最小权重筛选
|
||||
|
||||
Returns:
|
||||
成分股代码列表
|
||||
"""
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
# 获取小于等于给定日期的最新一期成分股
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT DISTINCT ts_code
|
||||
FROM index_weight
|
||||
WHERE index_code = ?
|
||||
AND trade_date = (
|
||||
SELECT MAX(trade_date)
|
||||
FROM index_weight
|
||||
WHERE index_code = ?
|
||||
AND trade_date <= ?
|
||||
)
|
||||
AND weight >= ?
|
||||
ORDER BY weight DESC
|
||||
""",
|
||||
(index_code, index_code, trade_date, min_weight)
|
||||
).fetchall()
|
||||
|
||||
return [row["ts_code"] for row in rows if row and row["ts_code"]]
|
||||
except Exception as exc:
|
||||
LOGGER.exception(
|
||||
"获取指数成分股失败 index=%s date=%s err=%s",
|
||||
index_code,
|
||||
trade_date,
|
||||
exc,
|
||||
extra=LOG_EXTRA
|
||||
)
|
||||
return []
|
||||
|
||||
def get_refresh_status(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""获取当前所有补数任务的状态。
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user