""" 数据存储模块单元测试 测试数据库管理、模型和存储仓库的功能 """ import pytest import asyncio from unittest.mock import Mock, patch, AsyncMock from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from src.storage.database import DatabaseManager from src.storage.models import StockBasic, DailyKline, FinancialReport, DataSource, SystemLog from src.storage.stock_repository import StockRepository from src.utils.exceptions import DatabaseError class TestDatabaseManager: """数据库管理器测试类""" # 使用conftest.py中的db_manager fixture,不需要重新定义 @pytest.fixture def test_engine(self): """测试数据库引擎""" return create_engine("sqlite:///:memory:") def test_singleton_pattern(self, db_manager): """测试单例模式""" # 测试数据库管理器不是单例模式,而是测试专用的实例 # 验证测试数据库管理器功能正常 assert db_manager.engine is not None assert db_manager.Base is not None assert db_manager.SessionLocal is not None def test_configure_database_success(self, db_manager): """测试数据库配置成功""" # DatabaseManager会自动从settings配置数据库 assert db_manager.engine is not None def test_configure_database_invalid_url(self, db_manager): """测试无效数据库URL""" # DatabaseManager会自动从settings配置数据库,无法测试无效URL # 因为配置在初始化时就已经完成 pass def test_create_tables_success(self, db_manager): """测试创建表成功""" # 创建表 result = db_manager.create_tables() assert result is True # 验证表是否存在 - SQLAlchemy 2.0兼容 from sqlalchemy import inspect table_names = inspect(db_manager.engine).get_table_names() expected_tables = ["stock_basic", "daily_kline", "financial_report"] for table in expected_tables: assert table in table_names def test_get_session_success(self, db_manager): """测试获取会话成功""" # 创建表 db_manager.create_tables() # 获取会话 session = db_manager.get_session() assert session is not None # 关闭会话 session.close() def test_drop_tables_success(self, db_manager): """测试删除表成功""" # 创建表 db_manager.create_tables() # 删除表 result = db_manager.drop_tables() assert result is True # 验证表是否被删除 - SQLAlchemy 2.0兼容 from sqlalchemy import inspect table_names = inspect(db_manager.engine).get_table_names() assert len(table_names) == 0 class TestModels: """数据库模型测试类""" def test_stock_basic_model(self): """测试股票基础信息模型""" stock = StockBasic( code="000001", name="平安银行", market="sh", company_name="平安银行股份有限公司", industry="银行", area="广东", ipo_date="1991-04-03", listing_status=True ) assert stock.code == "000001" assert stock.name == "平安银行" assert stock.market == "sh" assert stock.company_name == "平安银行股份有限公司" assert stock.industry == "银行" assert stock.area == "广东" assert stock.ipo_date == "1991-04-03" assert stock.listing_status == True # created_at和updated_at字段由数据库自动生成,创建对象时为None def test_daily_kline_model(self): """测试日K线数据模型""" kline = DailyKline( stock_code="000001", trade_date="2024-01-15", open_price=10.5, high_price=11.2, low_price=10.3, close_price=10.8, volume=1000000, amount=10800000 ) assert kline.stock_code == "000001" assert kline.trade_date == "2024-01-15" assert kline.open_price == 10.5 assert kline.high_price == 11.2 assert kline.low_price == 10.3 assert kline.close_price == 10.8 assert kline.volume == 1000000 assert kline.amount == 10800000 # created_at字段由数据库自动生成,创建对象时为None def test_financial_report_model(self): """测试财务报告模型""" financial = FinancialReport( stock_code="000001", report_date="2023-12-31", report_type="年报", report_year=2023, eps=1.5, net_profit=1500000000, revenue=5000000000, total_assets=10000000000 ) assert financial.stock_code == "000001" assert financial.report_date == "2023-12-31" assert financial.report_type == "年报" assert financial.report_year == 2023 assert financial.eps == 1.5 assert financial.net_profit == 1500000000 assert financial.revenue == 5000000000 assert financial.total_assets == 10000000000 # created_at字段由数据库自动生成,创建对象时为None def test_data_source_model(self): """测试数据源模型""" source = DataSource( source_name="akshare", source_type="api", sync_status="正常", last_sync_time="2024-01-15 10:00:00" ) assert source.source_name == "akshare" assert source.source_type == "api" assert source.sync_status == "正常" assert source.last_sync_time == "2024-01-15 10:00:00" # created_at和updated_at字段由数据库自动生成,创建对象时为None def test_system_log_model(self): """测试系统日志模型""" log = SystemLog( log_level="INFO", module_name="data_collection", message="数据采集完成" ) assert log.log_level == "INFO" assert log.module_name == "data_collection" assert log.message == "数据采集完成" # created_at字段由数据库自动生成,创建对象时为None class TestStockRepository: """股票存储仓库测试类""" @pytest.fixture def stock_repo(self, db_manager): """股票存储仓库实例""" # 获取数据库会话 session = db_manager.get_session() return StockRepository(session) def test_save_stock_basic_info_success(self, stock_repo): """测试保存股票基础信息成功""" from datetime import date stock_data = [ { "code": "000001", "name": "平安银行", "market": "主板", "industry": "银行", "area": "广东", "ipo_date": date(1991, 4, 3), "data_source": "akshare" } ] result = stock_repo.save_stock_basic_info(stock_data) assert result["added_count"] == 1 assert result["error_count"] == 0 # 验证数据是否保存 saved_data = stock_repo.get_stock_basic_info() assert len(saved_data) == 1 assert saved_data[0].code == "000001" assert saved_data[0].name == "平安银行" def test_save_stock_basic_info_duplicate(self, stock_repo): """测试保存重复股票基础信息""" stock_data = [ { "code": "000001", "name": "平安银行", "market": "主板", "data_source": "akshare" } ] # 第一次保存 result1 = stock_repo.save_stock_basic_info(stock_data) assert result1["added_count"] == 1 # 第二次保存相同数据 result2 = stock_repo.save_stock_basic_info(stock_data) assert result2["updated_count"] == 1 # 验证只有一条记录 saved_data = stock_repo.get_stock_basic_info() assert len(saved_data) == 1 def test_save_daily_kline_data_success(self, stock_repo): """测试保存日K线数据成功""" from datetime import date # 先保存股票基础信息 stock_data = [ { "code": "000001", "name": "平安银行", "market": "主板", "data_source": "akshare" } ] stock_repo.save_stock_basic_info(stock_data) # 再保存日K线数据 kline_data = [ { "code": "000001", "date": "2024-01-15", "open": 10.5, "high": 11.2, "low": 10.3, "close": 10.8, "volume": 1000000, "amount": 10800000, "data_source": "akshare" } ] result = stock_repo.save_daily_kline_data(kline_data) assert result["added_count"] == 1 assert result["error_count"] == 0 def test_save_financial_report_data_success(self, stock_repo): """测试保存财务报告数据成功""" from datetime import date # 先保存股票基础信息 stock_data = [ { "code": "000001", "name": "平安银行", "market": "主板", "industry": "银行", "area": "广东", "ipo_date": date(1991, 4, 3), "data_source": "akshare" } ] stock_repo.save_stock_basic_info(stock_data) # 再保存财务报告数据 financial_data = [ { "code": "000001", "report_date": "2023-12-31", "report_type": "年报", "eps": 1.5, "net_profit": 1500000000, "revenue": 5000000000, "data_source": "akshare" } ] result = stock_repo.save_financial_report_data(financial_data) assert result["added_count"] == 1 assert result["error_count"] == 0 def test_get_stock_basic_info_success(self, stock_repo): """测试获取股票基础信息成功""" # 先保存数据 stock_data = [ { "code": "000001", "name": "平安银行", "market": "主板", "data_source": "akshare" } ] stock_repo.save_stock_basic_info(stock_data) # 获取数据 result = stock_repo.get_stock_basic_info() assert len(result) == 1 assert result[0].code == "000001" assert result[0].name == "平安银行" def test_get_stock_basic_info_not_found(self, stock_repo): """测试获取不存在的股票基础信息""" result = stock_repo.get_stock_basic_info() # 没有保存任何数据,所以结果应该为空 assert len(result) == 0 def test_get_daily_kline_data_success(self, stock_repo): """测试获取日K线数据成功""" # 先保存数据 kline_data = [ { "code": "000001", "date": "2024-01-15", "open": 10.5, "high": 11.0, "low": 10.0, "close": 10.8, "volume": 1000000, "amount": 10800000, "data_source": "akshare" } ] stock_repo.save_daily_kline_data(kline_data) # 获取数据 from datetime import date result = stock_repo.get_daily_kline_data("000001", date(2024, 1, 1), date(2024, 1, 31)) assert len(result) == 1 assert result[0].stock_code == "000001" assert result[0].trade_date == date(2024, 1, 15) assert result[0].open_price == 10.5 assert result[0].close_price == 10.8 def test_transaction_rollback_on_error(self, stock_repo): """测试事务回滚""" # 创建无效数据(缺少必要字段) invalid_data = [ { "code": "000001", # 缺少name字段 "market": "主板" } ] # 调用方法,应该不会抛出异常,但会记录错误 result = stock_repo.save_stock_basic_info(invalid_data) # 验证方法返回了错误计数 assert result["error_count"] == 1 assert result["added_count"] == 0 assert result["updated_count"] == 0 # 验证没有数据被保存(事务回滚) saved_data = stock_repo.get_stock_basic_info() assert len(saved_data) == 0 class TestDatabaseOperations: """数据库操作测试类""" @pytest.fixture def setup_database(self): """设置测试数据库""" # 使用测试数据库管理器而不是主数据库管理器 from tests.conftest import create_test_database_manager db_manager = create_test_database_manager() db_manager.create_tables() return db_manager def test_bulk_insert_performance(self, setup_database): """测试批量插入性能""" db_manager = setup_database session = db_manager.get_session() # 创建大量测试数据 test_data = [] for i in range(1000): stock = db_manager.StockBasic( code=f"{i:06d}", name=f"测试股票{i}", market="主板", data_source="test" ) test_data.append(stock) # 批量插入 import time start_time = time.time() session.bulk_save_objects(test_data) session.commit() end_time = time.time() execution_time = end_time - start_time # 验证插入的数据量 count = session.query(db_manager.StockBasic).count() assert count == 1000 # 性能要求:1000条数据插入时间应小于1秒 assert execution_time < 1.0 session.close() def test_query_performance(self, setup_database): """测试查询性能""" db_manager = setup_database session = db_manager.get_session() # 插入测试数据 test_data = [] for i in range(1000): stock = db_manager.StockBasic( code=f"{i:06d}", name=f"测试股票{i}", market="主板", data_source="test" ) test_data.append(stock) session.bulk_save_objects(test_data) session.commit() # 测试查询性能 import time start_time = time.time() result = session.query(db_manager.StockBasic).filter(db_manager.StockBasic.market == "主板").all() end_time = time.time() execution_time = end_time - start_time # 验证查询结果 assert len(result) == 1000 # 性能要求:1000条数据查询时间应小于0.1秒 assert execution_time < 0.1 session.close()