add index membership schema and default indices initialization

This commit is contained in:
Your Name 2025-10-08 11:50:59 +08:00
parent 4b7d64f915
commit 8f74525875
4 changed files with 109 additions and 1 deletions

View File

@ -5,6 +5,7 @@ import sqlite3
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Iterable, List from typing import Iterable, List
from app.data.schema_index import initialize_index_membership_tables, add_default_indices
from app.utils.db import db_session from app.utils.db import db_session
@ -565,6 +566,10 @@ def initialize_database() -> MigrationResult:
with db_session() as session: with db_session() as session:
cursor = session.cursor() cursor = session.cursor()
# 初始化指数相关表
initialize_index_membership_tables(session)
add_default_indices()
# 创建表 # 创建表
for statement in SCHEMA_STATEMENTS: for statement in SCHEMA_STATEMENTS:
try: try:

42
app/data/schema_index.py Normal file
View File

@ -0,0 +1,42 @@
"""SQL schema for index membership."""
from app.utils.db import db_session
def initialize_index_membership_tables(conn):
"""Create tables for tracking index membership."""
conn.execute("""
CREATE TABLE IF NOT EXISTS index_weight (
id INTEGER PRIMARY KEY AUTOINCREMENT,
index_code VARCHAR(10) NOT NULL,
trade_date VARCHAR(8) NOT NULL,
ts_code VARCHAR(10) NOT NULL,
weight FLOAT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_index_weight_lookup (index_code, trade_date)
)
""")
def add_default_indices():
"""Add default index list."""
indices = [
("000300.SH", "沪深300"),
("000905.SH", "中证500"),
("000852.SH", "中证1000")
]
with db_session() as conn:
conn.execute("""
CREATE TABLE IF NOT EXISTS indices (
id INTEGER PRIMARY KEY AUTOINCREMENT,
index_code VARCHAR(10) NOT NULL UNIQUE,
name VARCHAR(50) NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
for code, name in indices:
conn.execute(
"""
INSERT OR IGNORE INTO indices (index_code, name)
VALUES (?, ?)
""",
(code, name)
)

View File

@ -48,6 +48,20 @@ def render_stock_evaluation() -> None:
"市场类因子": [f for f in DEFAULT_FACTORS if f.name.startswith("market_")] "市场类因子": [f for f in DEFAULT_FACTORS if f.name.startswith("market_")]
} }
# 定义默认选中的关键常用因子
DEFAULT_SELECTED_FACTORS = {
"mom_5", # 5日动量
"mom_20", # 20日动量
"mom_60", # 60日动量
"volat_20", # 20日波动率
"turn_5", # 5日换手率
"turn_20", # 20日换手率
"val_pe_score", # PE评分
"val_pb_score", # PB评分
"volume_ratio_score", # 量比评分
"risk_penalty" # 风险惩罚项
}
selected_factors = [] selected_factors = []
for group_name, factors in factor_groups.items(): for group_name, factors in factor_groups.items():
if factors: if factors:
@ -56,7 +70,7 @@ def render_stock_evaluation() -> None:
for i, factor in enumerate(factors): for i, factor in enumerate(factors):
if cols[i % 3].checkbox( if cols[i % 3].checkbox(
factor.name, factor.name,
value=factor.name in selected_factors, value=factor.name in DEFAULT_SELECTED_FACTORS,
help=factor.description if hasattr(factor, 'description') else None help=factor.description if hasattr(factor, 'description') else None
): ):
selected_factors.append(factor.name) selected_factors.append(factor.name)

View File

@ -1100,6 +1100,53 @@ class DataBroker:
LOGGER.exception("强制刷新数据失败: %s", exc, extra=LOG_EXTRA) LOGGER.exception("强制刷新数据失败: %s", exc, extra=LOG_EXTRA)
return False return False
def get_index_stocks(
self,
index_code: str,
trade_date: str,
min_weight: float = 0.0
) -> List[str]:
"""获取指数成分股列表。
Args:
index_code: 指数代码( 000300.SH)
trade_date: 交易日期
min_weight: 最小权重筛选
Returns:
成分股代码列表
"""
try:
with db_session(read_only=True) as conn:
# 获取小于等于给定日期的最新一期成分股
rows = conn.execute(
"""
SELECT DISTINCT ts_code
FROM index_weight
WHERE index_code = ?
AND trade_date = (
SELECT MAX(trade_date)
FROM index_weight
WHERE index_code = ?
AND trade_date <= ?
)
AND weight >= ?
ORDER BY weight DESC
""",
(index_code, index_code, trade_date, min_weight)
).fetchall()
return [row["ts_code"] for row in rows if row and row["ts_code"]]
except Exception as exc:
LOGGER.exception(
"获取指数成分股失败 index=%s date=%s err=%s",
index_code,
trade_date,
exc,
extra=LOG_EXTRA
)
return []
def get_refresh_status(self) -> Dict[str, Dict[str, Any]]: def get_refresh_status(self) -> Dict[str, Dict[str, Any]]:
"""获取当前所有补数任务的状态。 """获取当前所有补数任务的状态。