From 6810712232f3f3538adc1ce273475b11713a7a2a Mon Sep 17 00:00:00 2001 From: sam Date: Sat, 11 Oct 2025 19:16:42 +0800 Subject: [PATCH] add stock name and industry fields to investment pool --- app/data/schema.py | 2 ++ app/ui/views/stock_eval.py | 14 ++++++++++++-- app/utils/portfolio.py | 2 +- app/utils/portfolio_init.py | 2 ++ tests/test_portfolio_config.py | 12 +++++++----- 5 files changed, 24 insertions(+), 8 deletions(-) diff --git a/app/data/schema.py b/app/data/schema.py index 964ef24..3ee2953 100644 --- a/app/data/schema.py +++ b/app/data/schema.py @@ -474,6 +474,8 @@ SCHEMA_STATEMENTS: Iterable[str] = ( rationale TEXT, tags TEXT, metadata TEXT, + name TEXT, + industry TEXT, created_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')), PRIMARY KEY (trade_date, ts_code) ); diff --git a/app/ui/views/stock_eval.py b/app/ui/views/stock_eval.py index 49cb9d3..3c6651c 100644 --- a/app/ui/views/stock_eval.py +++ b/app/ui/views/stock_eval.py @@ -475,6 +475,7 @@ def _add_to_stock_pool( ) -> None: """将股票评分结果写入投资池。""" + broker = DataBroker() trade_date = eval_date.strftime("%Y%m%d") payload: List[tuple] = [] ranked_df = score_df.reset_index(drop=True) @@ -489,6 +490,11 @@ def _add_to_stock_pool( }, ensure_ascii=False, ) + # 获取股票基本信息 + stock_info = broker.get_stock_info(row["股票代码"], trade_date) + stock_name = stock_info.get("name", "") if stock_info else "" + stock_industry = stock_info.get("industry", "") if stock_info else "" + payload.append( ( trade_date, @@ -498,6 +504,8 @@ def _add_to_stock_pool( "factor_evaluation_top20", tags, metadata, + stock_name, + stock_industry, ) ) @@ -513,8 +521,10 @@ def _add_to_stock_pool( status, rationale, tags, - metadata - ) VALUES (?, ?, ?, ?, ?, ?, ?) + metadata, + name, + industry + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, payload, ) diff --git a/app/utils/portfolio.py b/app/utils/portfolio.py index 1ef79fa..3748b25 100644 --- a/app/utils/portfolio.py +++ b/app/utils/portfolio.py @@ -45,7 +45,7 @@ def list_investment_pool( """Return investment candidates for the given trade date (latest if None).""" query = [ - "SELECT trade_date, ts_code, score, status, rationale, tags, metadata", + "SELECT trade_date, ts_code, score, status, rationale, tags, metadata, name, industry", "FROM investment_pool", ] params: List[Any] = [] diff --git a/app/utils/portfolio_init.py b/app/utils/portfolio_init.py index 6995e01..e89df15 100644 --- a/app/utils/portfolio_init.py +++ b/app/utils/portfolio_init.py @@ -76,6 +76,8 @@ SCHEMA_STATEMENTS = [ rationale TEXT, tags TEXT, -- JSON array metadata TEXT, -- JSON object + name TEXT, + industry TEXT, PRIMARY KEY (trade_date, ts_code) ); """, diff --git a/tests/test_portfolio_config.py b/tests/test_portfolio_config.py index fc4aa39..65d0b97 100644 --- a/tests/test_portfolio_config.py +++ b/tests/test_portfolio_config.py @@ -133,19 +133,21 @@ def test_list_investment_pool_orders_without_nulls(tmp_path): rationale TEXT, tags TEXT, metadata TEXT, + name TEXT, + industry TEXT, PRIMARY KEY (trade_date, ts_code) ) """ ) conn.executemany( """ - INSERT INTO investment_pool (trade_date, ts_code, score, status, rationale, tags, metadata) - VALUES (?, ?, ?, ?, ?, ?, ?) + INSERT INTO investment_pool (trade_date, ts_code, score, status, rationale, tags, metadata, name, industry) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, [ - ("2024-01-01", "AAA", 0.8, "buy", "", None, None), - ("2024-01-01", "BBB", None, "hold", "", None, None), - ("2024-01-01", "CCC", 0.9, "buy", "", None, None), + ("2024-01-01", "AAA", 0.8, "buy", "", None, None, "Company A", "Technology"), + ("2024-01-01", "BBB", None, "hold", "", None, None, "Company B", "Finance"), + ("2024-01-01", "CCC", 0.9, "buy", "", None, None, "Company C", "Healthcare"), ], )