From be12ba35a6d3efe605b8657d0c84ca9ac6f0a192 Mon Sep 17 00:00:00 2001 From: sam Date: Sun, 5 Oct 2025 18:27:39 +0800 Subject: [PATCH] update --- app/utils/config.py | 42 ++++++++++++++ app/utils/portfolio.py | 2 +- tests/test_portfolio_config.py | 100 ++++++++++++++++++++++++++++++++- 3 files changed, 142 insertions(+), 2 deletions(-) diff --git a/app/utils/config.py b/app/utils/config.py index b82f3b3..06e7c83 100644 --- a/app/utils/config.py +++ b/app/utils/config.py @@ -431,6 +431,38 @@ def _load_from_file(cfg: AppConfig) -> None: if isinstance(weights_payload, dict): cfg.agent_weights.update_from_dict(weights_payload) + portfolio_payload = payload.get("portfolio") + if isinstance(portfolio_payload, dict): + limits_payload = portfolio_payload.get("position_limits") + if not isinstance(limits_payload, dict): + limits_payload = portfolio_payload + + current = cfg.portfolio + + def _float_value(container: Dict[str, object], key: str, fallback: float) -> float: + value = container.get(key) if isinstance(container, dict) else None + try: + return float(value) + except (TypeError, ValueError): + return fallback + + def _int_value(container: Dict[str, object], key: str, fallback: int) -> int: + value = container.get(key) if isinstance(container, dict) else None + try: + return int(value) + except (TypeError, ValueError): + return fallback + + updated_portfolio = PortfolioSettings( + initial_capital=_float_value(portfolio_payload, "initial_capital", current.initial_capital), + currency=str(portfolio_payload.get("currency") or current.currency), + max_position=_float_value(limits_payload, "max_position", current.max_position), + min_position=_float_value(limits_payload, "min_position", current.min_position), + max_total_positions=_int_value(limits_payload, "max_total_positions", current.max_total_positions), + max_sector_exposure=_float_value(limits_payload, "max_sector_exposure", current.max_sector_exposure), + ) + cfg.portfolio = updated_portfolio + legacy_profiles: Dict[str, Dict[str, object]] = {} legacy_routes: Dict[str, Dict[str, object]] = {} @@ -600,6 +632,16 @@ def save_config(cfg: AppConfig | None = None) -> None: "decision_method": cfg.decision_method, "rss_sources": cfg.rss_sources, "agent_weights": cfg.agent_weights.as_dict(), + "portfolio": { + "initial_capital": cfg.portfolio.initial_capital, + "currency": cfg.portfolio.currency, + "position_limits": { + "max_position": cfg.portfolio.max_position, + "min_position": cfg.portfolio.min_position, + "max_total_positions": cfg.portfolio.max_total_positions, + "max_sector_exposure": cfg.portfolio.max_sector_exposure, + }, + }, "llm": { "strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single", "majority_threshold": cfg.llm.majority_threshold, diff --git a/app/utils/portfolio.py b/app/utils/portfolio.py index 52c3108..213e4ea 100644 --- a/app/utils/portfolio.py +++ b/app/utils/portfolio.py @@ -61,7 +61,7 @@ def list_investment_pool( query.append(f"AND status IN ({placeholders})") params.extend(list(status)) - query.append("ORDER BY score DESC NULLS LAST, ts_code") + query.append("ORDER BY (score IS NULL), score DESC, ts_code") query.append("LIMIT ?") params.append(int(limit)) diff --git a/tests/test_portfolio_config.py b/tests/test_portfolio_config.py index af7b920..fc4aa39 100644 --- a/tests/test_portfolio_config.py +++ b/tests/test_portfolio_config.py @@ -1,7 +1,15 @@ """Test portfolio configuration and initialization.""" +import json +from dataclasses import replace from unittest.mock import patch, MagicMock -from app.utils.portfolio import get_latest_snapshot +import pytest + +from app.utils import config as config_module +from app.utils.config import AppConfig, DataPaths, get_config +from app.utils.portfolio_init import update_portfolio_config + +from app.utils.portfolio import get_latest_snapshot, list_investment_pool from app.utils.db import db_session @@ -55,3 +63,93 @@ def test_custom_portfolio_config(): assert snapshot.cash == 2000000 assert snapshot.metadata["initial_capital"] == 2000000 assert snapshot.metadata["currency"] == "USD" + + +def test_update_portfolio_config_persists(tmp_path): + cfg = get_config() + original_paths = cfg.data_paths + original_portfolio = replace(cfg.portfolio) + + temp_root = tmp_path / "data" + temp_paths = DataPaths(root=temp_root) + cfg.data_paths = temp_paths + + updates = { + "initial_capital": 3_000_000, + "currency": "USD", + "position_limits": { + "max_position": 0.15, + "min_position": 0.03, + "max_total_positions": 12, + "max_sector_exposure": 0.4, + }, + } + + try: + update_portfolio_config(updates) + + payload = json.loads(temp_paths.config_file.read_text(encoding="utf-8")) + assert payload["portfolio"]["initial_capital"] == 3_000_000 + assert payload["portfolio"]["currency"] == "USD" + limits = payload["portfolio"]["position_limits"] + assert limits["max_position"] == pytest.approx(0.15) + assert limits["min_position"] == pytest.approx(0.03) + assert limits["max_total_positions"] == 12 + assert limits["max_sector_exposure"] == pytest.approx(0.4) + + fresh_cfg = AppConfig() + fresh_cfg.data_paths = temp_paths + config_module._load_from_file(fresh_cfg) + assert fresh_cfg.portfolio.initial_capital == pytest.approx(3_000_000.0) + assert fresh_cfg.portfolio.currency == "USD" + assert fresh_cfg.portfolio.max_position == pytest.approx(0.15) + assert fresh_cfg.portfolio.min_position == pytest.approx(0.03) + assert fresh_cfg.portfolio.max_total_positions == 12 + assert fresh_cfg.portfolio.max_sector_exposure == pytest.approx(0.4) + finally: + cfg.data_paths = original_paths + cfg.portfolio = original_portfolio + if temp_paths.config_file.exists(): + temp_paths.config_file.unlink() + + +def test_list_investment_pool_orders_without_nulls(tmp_path): + cfg = get_config() + original_paths = cfg.data_paths + + temp_root = tmp_path / "data" + temp_paths = DataPaths(root=temp_root) + cfg.data_paths = temp_paths + + try: + with db_session() as conn: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS investment_pool ( + trade_date TEXT, + ts_code TEXT, + score REAL, + status TEXT, + rationale TEXT, + tags TEXT, + metadata TEXT, + PRIMARY KEY (trade_date, ts_code) + ) + """ + ) + conn.executemany( + """ + INSERT INTO investment_pool (trade_date, ts_code, score, status, rationale, tags, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + [ + ("2024-01-01", "AAA", 0.8, "buy", "", None, None), + ("2024-01-01", "BBB", None, "hold", "", None, None), + ("2024-01-01", "CCC", 0.9, "buy", "", None, None), + ], + ) + + rows = list_investment_pool(trade_date="2024-01-01") + assert [row.ts_code for row in rows] == ["CCC", "AAA", "BBB"] + finally: + cfg.data_paths = original_paths