156 lines
5.5 KiB
Python
156 lines
5.5 KiB
Python
"""Test portfolio configuration and initialization."""
|
|
import json
|
|
from dataclasses import replace
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
import pytest
|
|
|
|
from app.utils import config as config_module
|
|
from app.utils.config import AppConfig, DataPaths, get_config
|
|
from app.utils.portfolio_init import update_portfolio_config
|
|
|
|
from app.utils.portfolio import get_latest_snapshot, list_investment_pool
|
|
from app.utils.db import db_session
|
|
|
|
|
|
def test_default_portfolio_config():
    """Default portfolio settings (1,000,000 CNY) seed the latest snapshot.

    ``app.utils.portfolio.get_portfolio_config`` is patched to return the
    default settings, and ``app.utils.portfolio.db_session`` is replaced by a
    mock whose ``execute(...).fetchone()`` yields ``None`` (no stored row).
    The snapshot from ``get_latest_snapshot`` is then expected to report the
    initial capital as both total value and cash, and to echo the capital
    and currency in its metadata.
    """
    fake_session = MagicMock()
    # db_session is mocked as a context manager: entering it yields the same
    # mock, and a falsy __exit__ result lets exceptions propagate.
    fake_session.__enter__.return_value = fake_session
    fake_session.__exit__.return_value = None
    # Any query result on the fake session reports "no row found".
    fake_session.execute.return_value.fetchone.return_value = None

    default_settings = {
        "initial_capital": 1000000,
        "currency": "CNY",
    }

    # Use the default configuration.
    with patch(
        "app.utils.portfolio.get_portfolio_config",
        return_value=default_settings,
    ):
        with patch("app.utils.portfolio.db_session", return_value=fake_session):
            snapshot = get_latest_snapshot()

            assert snapshot is not None
            assert snapshot.total_value == 1000000
            assert snapshot.cash == 1000000
            assert snapshot.metadata["initial_capital"] == 1000000
            assert snapshot.metadata["currency"] == "CNY"
|
|
|
|
|
|
def test_custom_portfolio_config():
    """Custom portfolio settings (2,000,000 USD) seed the latest snapshot.

    Mirrors the default-config test with different values:
    ``get_portfolio_config`` is patched to return the custom settings, and
    ``db_session`` is replaced by a mock whose ``execute(...).fetchone()``
    yields ``None``. The resulting snapshot is expected to carry the custom
    capital as total value and cash, plus the custom capital and currency in
    its metadata.
    """
    session_stub = MagicMock()
    # Entering the mocked db_session context yields the stub itself; the
    # falsy __exit__ return value does not suppress exceptions.
    session_stub.__enter__.return_value = session_stub
    session_stub.__exit__.return_value = None
    # No persisted snapshot row is returned by any query.
    session_stub.execute.return_value.fetchone.return_value = None

    custom_settings = {
        "initial_capital": 2000000,
        "currency": "USD",
    }

    # Use a custom configuration.
    with patch(
        "app.utils.portfolio.get_portfolio_config",
        return_value=custom_settings,
    ):
        with patch("app.utils.portfolio.db_session", return_value=session_stub):
            snapshot = get_latest_snapshot()

            assert snapshot is not None
            assert snapshot.total_value == 2000000
            assert snapshot.cash == 2000000
            assert snapshot.metadata["initial_capital"] == 2000000
            assert snapshot.metadata["currency"] == "USD"
|
|
|
|
|
|
def test_update_portfolio_config_persists(tmp_path):
    """Portfolio updates are written to the config file and reload correctly.

    The global config is temporarily pointed at a ``DataPaths`` rooted under
    ``tmp_path``. After ``update_portfolio_config`` runs, the JSON config file
    must contain the new ``portfolio`` section, including its nested
    ``position_limits``. A brand-new ``AppConfig`` loaded from that file via
    ``_load_from_file`` must expose the same values as flat portfolio
    attributes. The global config's ``data_paths`` and ``portfolio`` are
    restored afterwards.
    """
    settings = get_config()
    saved_paths = settings.data_paths
    # Copy taken before the update so the global portfolio can be put back.
    saved_portfolio = replace(settings.portfolio)

    sandbox = DataPaths(root=tmp_path / "data")
    settings.data_paths = sandbox

    changes = {
        "initial_capital": 3_000_000,
        "currency": "USD",
        "position_limits": {
            "max_position": 0.15,
            "min_position": 0.03,
            "max_total_positions": 12,
            "max_sector_exposure": 0.4,
        },
    }

    try:
        update_portfolio_config(changes)

        on_disk = json.loads(sandbox.config_file.read_text(encoding="utf-8"))
        portfolio_section = on_disk["portfolio"]
        assert portfolio_section["initial_capital"] == 3_000_000
        assert portfolio_section["currency"] == "USD"

        limits = portfolio_section["position_limits"]
        assert limits["max_position"] == pytest.approx(0.15)
        assert limits["min_position"] == pytest.approx(0.03)
        assert limits["max_total_positions"] == 12
        assert limits["max_sector_exposure"] == pytest.approx(0.4)

        # Load into an independent AppConfig to show the values round-trip
        # through the file, not just through the in-memory global config.
        reloaded = AppConfig()
        reloaded.data_paths = sandbox
        config_module._load_from_file(reloaded)

        restored = reloaded.portfolio
        assert restored.initial_capital == pytest.approx(3_000_000.0)
        assert restored.currency == "USD"
        assert restored.max_position == pytest.approx(0.15)
        assert restored.min_position == pytest.approx(0.03)
        assert restored.max_total_positions == 12
        assert restored.max_sector_exposure == pytest.approx(0.4)
    finally:
        settings.data_paths = saved_paths
        settings.portfolio = saved_portfolio
        written_file = sandbox.config_file
        if written_file.exists():
            written_file.unlink()
|
|
|
|
|
|
def test_list_investment_pool_orders_without_nulls(tmp_path):
    """``list_investment_pool`` ranks scored entries first and NULL scores last.

    Three rows for trade date 2024-01-01 are inserted into an
    ``investment_pool`` table, with scores 0.8 (AAA), NULL (BBB) and
    0.9 (CCC). The listing for that date is expected to come back as
    CCC, AAA, BBB. The global config's ``data_paths`` is pointed under
    ``tmp_path`` for the duration of the test and restored afterwards.
    """
    settings = get_config()
    saved_paths = settings.data_paths
    settings.data_paths = DataPaths(root=tmp_path / "data")

    # IF NOT EXISTS: if the application already created this table, its
    # schema is kept and only the INSERT below applies.
    create_table_sql = """
        CREATE TABLE IF NOT EXISTS investment_pool (
            trade_date TEXT,
            ts_code TEXT,
            score REAL,
            status TEXT,
            rationale TEXT,
            tags TEXT,
            metadata TEXT,
            PRIMARY KEY (trade_date, ts_code)
        )
    """
    insert_sql = """
        INSERT INTO investment_pool (trade_date, ts_code, score, status, rationale, tags, metadata)
        VALUES (?, ?, ?, ?, ?, ?, ?)
    """
    seed_rows = [
        ("2024-01-01", "AAA", 0.8, "buy", "", None, None),
        ("2024-01-01", "BBB", None, "hold", "", None, None),
        ("2024-01-01", "CCC", 0.9, "buy", "", None, None),
    ]

    try:
        # NOTE(review): assumes db_session() resolves its database from the
        # current config's data_paths and commits on exit, so these rows are
        # visible to list_investment_pool — confirm in app.utils.db.
        with db_session() as conn:
            conn.execute(create_table_sql)
            conn.executemany(insert_sql, seed_rows)

        pool = list_investment_pool(trade_date="2024-01-01")
        assert [entry.ts_code for entry in pool] == ["CCC", "AAA", "BBB"]
    finally:
        settings.data_paths = saved_paths
|