update
This commit is contained in:
parent
f29bb99b68
commit
be12ba35a6
@ -431,6 +431,38 @@ def _load_from_file(cfg: AppConfig) -> None:
|
||||
if isinstance(weights_payload, dict):
|
||||
cfg.agent_weights.update_from_dict(weights_payload)
|
||||
|
||||
portfolio_payload = payload.get("portfolio")
|
||||
if isinstance(portfolio_payload, dict):
|
||||
limits_payload = portfolio_payload.get("position_limits")
|
||||
if not isinstance(limits_payload, dict):
|
||||
limits_payload = portfolio_payload
|
||||
|
||||
current = cfg.portfolio
|
||||
|
||||
def _float_value(container: Dict[str, object], key: str, fallback: float) -> float:
|
||||
value = container.get(key) if isinstance(container, dict) else None
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return fallback
|
||||
|
||||
def _int_value(container: Dict[str, object], key: str, fallback: int) -> int:
|
||||
value = container.get(key) if isinstance(container, dict) else None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return fallback
|
||||
|
||||
updated_portfolio = PortfolioSettings(
|
||||
initial_capital=_float_value(portfolio_payload, "initial_capital", current.initial_capital),
|
||||
currency=str(portfolio_payload.get("currency") or current.currency),
|
||||
max_position=_float_value(limits_payload, "max_position", current.max_position),
|
||||
min_position=_float_value(limits_payload, "min_position", current.min_position),
|
||||
max_total_positions=_int_value(limits_payload, "max_total_positions", current.max_total_positions),
|
||||
max_sector_exposure=_float_value(limits_payload, "max_sector_exposure", current.max_sector_exposure),
|
||||
)
|
||||
cfg.portfolio = updated_portfolio
|
||||
|
||||
legacy_profiles: Dict[str, Dict[str, object]] = {}
|
||||
legacy_routes: Dict[str, Dict[str, object]] = {}
|
||||
|
||||
@ -600,6 +632,16 @@ def save_config(cfg: AppConfig | None = None) -> None:
|
||||
"decision_method": cfg.decision_method,
|
||||
"rss_sources": cfg.rss_sources,
|
||||
"agent_weights": cfg.agent_weights.as_dict(),
|
||||
"portfolio": {
|
||||
"initial_capital": cfg.portfolio.initial_capital,
|
||||
"currency": cfg.portfolio.currency,
|
||||
"position_limits": {
|
||||
"max_position": cfg.portfolio.max_position,
|
||||
"min_position": cfg.portfolio.min_position,
|
||||
"max_total_positions": cfg.portfolio.max_total_positions,
|
||||
"max_sector_exposure": cfg.portfolio.max_sector_exposure,
|
||||
},
|
||||
},
|
||||
"llm": {
|
||||
"strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
|
||||
"majority_threshold": cfg.llm.majority_threshold,
|
||||
|
||||
@ -61,7 +61,7 @@ def list_investment_pool(
|
||||
query.append(f"AND status IN ({placeholders})")
|
||||
params.extend(list(status))
|
||||
|
||||
query.append("ORDER BY score DESC NULLS LAST, ts_code")
|
||||
query.append("ORDER BY (score IS NULL), score DESC, ts_code")
|
||||
query.append("LIMIT ?")
|
||||
params.append(int(limit))
|
||||
|
||||
|
||||
@ -1,7 +1,15 @@
|
||||
"""Test portfolio configuration and initialization."""
|
||||
import json
|
||||
from dataclasses import replace
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from app.utils.portfolio import get_latest_snapshot
|
||||
import pytest
|
||||
|
||||
from app.utils import config as config_module
|
||||
from app.utils.config import AppConfig, DataPaths, get_config
|
||||
from app.utils.portfolio_init import update_portfolio_config
|
||||
|
||||
from app.utils.portfolio import get_latest_snapshot, list_investment_pool
|
||||
from app.utils.db import db_session
|
||||
|
||||
|
||||
@ -55,3 +63,93 @@ def test_custom_portfolio_config():
|
||||
assert snapshot.cash == 2000000
|
||||
assert snapshot.metadata["initial_capital"] == 2000000
|
||||
assert snapshot.metadata["currency"] == "USD"
|
||||
|
||||
|
||||
def test_update_portfolio_config_persists(tmp_path):
|
||||
cfg = get_config()
|
||||
original_paths = cfg.data_paths
|
||||
original_portfolio = replace(cfg.portfolio)
|
||||
|
||||
temp_root = tmp_path / "data"
|
||||
temp_paths = DataPaths(root=temp_root)
|
||||
cfg.data_paths = temp_paths
|
||||
|
||||
updates = {
|
||||
"initial_capital": 3_000_000,
|
||||
"currency": "USD",
|
||||
"position_limits": {
|
||||
"max_position": 0.15,
|
||||
"min_position": 0.03,
|
||||
"max_total_positions": 12,
|
||||
"max_sector_exposure": 0.4,
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
update_portfolio_config(updates)
|
||||
|
||||
payload = json.loads(temp_paths.config_file.read_text(encoding="utf-8"))
|
||||
assert payload["portfolio"]["initial_capital"] == 3_000_000
|
||||
assert payload["portfolio"]["currency"] == "USD"
|
||||
limits = payload["portfolio"]["position_limits"]
|
||||
assert limits["max_position"] == pytest.approx(0.15)
|
||||
assert limits["min_position"] == pytest.approx(0.03)
|
||||
assert limits["max_total_positions"] == 12
|
||||
assert limits["max_sector_exposure"] == pytest.approx(0.4)
|
||||
|
||||
fresh_cfg = AppConfig()
|
||||
fresh_cfg.data_paths = temp_paths
|
||||
config_module._load_from_file(fresh_cfg)
|
||||
assert fresh_cfg.portfolio.initial_capital == pytest.approx(3_000_000.0)
|
||||
assert fresh_cfg.portfolio.currency == "USD"
|
||||
assert fresh_cfg.portfolio.max_position == pytest.approx(0.15)
|
||||
assert fresh_cfg.portfolio.min_position == pytest.approx(0.03)
|
||||
assert fresh_cfg.portfolio.max_total_positions == 12
|
||||
assert fresh_cfg.portfolio.max_sector_exposure == pytest.approx(0.4)
|
||||
finally:
|
||||
cfg.data_paths = original_paths
|
||||
cfg.portfolio = original_portfolio
|
||||
if temp_paths.config_file.exists():
|
||||
temp_paths.config_file.unlink()
|
||||
|
||||
|
||||
def test_list_investment_pool_orders_without_nulls(tmp_path):
|
||||
cfg = get_config()
|
||||
original_paths = cfg.data_paths
|
||||
|
||||
temp_root = tmp_path / "data"
|
||||
temp_paths = DataPaths(root=temp_root)
|
||||
cfg.data_paths = temp_paths
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS investment_pool (
|
||||
trade_date TEXT,
|
||||
ts_code TEXT,
|
||||
score REAL,
|
||||
status TEXT,
|
||||
rationale TEXT,
|
||||
tags TEXT,
|
||||
metadata TEXT,
|
||||
PRIMARY KEY (trade_date, ts_code)
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT INTO investment_pool (trade_date, ts_code, score, status, rationale, tags, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
[
|
||||
("2024-01-01", "AAA", 0.8, "buy", "", None, None),
|
||||
("2024-01-01", "BBB", None, "hold", "", None, None),
|
||||
("2024-01-01", "CCC", 0.9, "buy", "", None, None),
|
||||
],
|
||||
)
|
||||
|
||||
rows = list_investment_pool(trade_date="2024-01-01")
|
||||
assert [row.ts_code for row in rows] == ["CCC", "AAA", "BBB"]
|
||||
finally:
|
||||
cfg.data_paths = original_paths
|
||||
|
||||
Loading…
Reference in New Issue
Block a user