"""
Test configuration.

Defines the fixtures and configuration shared by the test suite.
"""
import pytest
import asyncio
from pathlib import Path
from typing import Generator, AsyncGenerator

from src.config.settings import Settings
from src.storage.database import DatabaseManager
from src.data.data_processor import DataProcessor
from src.utils.logger import LogManager


@pytest.fixture(scope="session")
def test_settings() -> Settings:
    """Return a default-constructed ``Settings`` instance for the test session."""
    return Settings()


@pytest.fixture(scope="session")
def test_log_manager(test_settings) -> LogManager:
    """Return a ``LogManager`` writing to ``./test_logs``.

    The log level is taken from ``test_settings.log.log_level``.
    """
    return LogManager(
        log_dir="./test_logs",
        log_level=test_settings.log.log_level
    )


def create_test_database_manager():
    """Create a self-contained database manager backed by in-memory SQLite.

    The ORM models are re-declared on a private declarative ``Base`` so the
    test schema does not share metadata with the application's models.

    Returns:
        A ``TestDatabaseManager`` exposing ``engine``, ``SessionLocal``,
        ``Base``, the model classes ``StockBasic`` / ``DailyKline`` /
        ``FinancialReport``, and the helpers ``create_tables`` /
        ``drop_tables`` / ``get_session``.
    """
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker
    from sqlalchemy.pool import StaticPool

    try:
        # Canonical location since SQLAlchemy 1.4; the ext.declarative one is
        # deprecated.
        from sqlalchemy.orm import declarative_base
    except ImportError:  # SQLAlchemy < 1.4
        from sqlalchemy.ext.declarative import declarative_base

    class TestDatabaseManager:
        def __init__(self):
            # Every new connection to "sqlite:///:memory:" opens a separate,
            # empty database. StaticPool pins one connection so that every
            # session (from any thread, hence check_same_thread=False) sees
            # the tables created by create_tables().
            self.engine = create_engine(
                "sqlite:///:memory:",
                connect_args={"check_same_thread": False},
                poolclass=StaticPool,
            )
            self.SessionLocal = sessionmaker(bind=self.engine)
            # Independent Base so the test schema is isolated.
            self.Base = declarative_base()
            self._import_models()

        def _import_models(self):
            """Declare the test copies of the ORM models on ``self.Base``."""
            from sqlalchemy import (
                Column, Integer, String, Float, Date, DateTime, Text,
                Boolean, BigInteger, ForeignKey, Index,
            )
            from sqlalchemy.sql import func

            # Test copy of the StockBasic model.
            class StockBasic(self.Base):
                __tablename__ = "stock_basic"

                id = Column(Integer, primary_key=True, autoincrement=True)
                code = Column(String(10), nullable=False, unique=True)
                name = Column(String(50), nullable=False)
                market = Column(String(10), nullable=False)
                company_name = Column(String(100))
                industry = Column(String(50))
                area = Column(String(50))
                ipo_date = Column(Date)
                listing_status = Column(Boolean, default=True)
                data_source = Column(String(50), default="akshare")
                created_at = Column(DateTime, server_default=func.now())
                updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())

                __table_args__ = (
                    # NOTE(review): "code" is already unique=True, so idx_code
                    # looks redundant; presumably kept to mirror the
                    # production schema — confirm.
                    Index("idx_code", "code"),
                    Index("idx_market", "market"),
                    Index("idx_industry", "industry"),
                    Index("idx_ipo_date", "ipo_date")
                )

            # Test copy of the DailyKline model.
            class DailyKline(self.Base):
                __tablename__ = "daily_kline"

                id = Column(Integer, primary_key=True, autoincrement=True)
                stock_code = Column(String(10), ForeignKey("stock_basic.code"), nullable=False)
                trade_date = Column(Date, nullable=False)
                open_price = Column(Float, nullable=False)
                high_price = Column(Float, nullable=False)
                low_price = Column(Float, nullable=False)
                close_price = Column(Float, nullable=False)
                volume = Column(BigInteger)
                amount = Column(Float)
                change = Column(Float)
                pct_change = Column(Float)
                created_at = Column(DateTime, server_default=func.now())

                __table_args__ = (
                    Index("idx_stock_code_date", "stock_code", "trade_date"),
                    Index("idx_trade_date", "trade_date")
                )

            # Test copy of the FinancialReport model.
            class FinancialReport(self.Base):
                __tablename__ = "financial_report"

                id = Column(Integer, primary_key=True, autoincrement=True)
                stock_code = Column(String(10), ForeignKey("stock_basic.code"), nullable=False)
                report_date = Column(Date, nullable=False)
                report_type = Column(String(20), nullable=False)
                report_year = Column(Integer, nullable=False)
                report_quarter = Column(Integer)
                eps = Column(Float)
                net_profit = Column(Float)
                revenue = Column(Float)
                total_assets = Column(Float)
                total_liabilities = Column(Float)
                equity = Column(Float)
                roe = Column(Float)
                gross_profit_margin = Column(Float)
                net_profit_margin = Column(Float)
                debt_to_asset_ratio = Column(Float)
                current_ratio = Column(Float)
                quick_ratio = Column(Float)
                created_at = Column(DateTime, server_default=func.now())

                __table_args__ = (
                    Index("idx_stock_code_report", "stock_code", "report_date"),
                    Index("idx_report_date", "report_date"),
                    Index("idx_report_type", "report_type")
                )

            self.StockBasic = StockBasic
            self.DailyKline = DailyKline
            self.FinancialReport = FinancialReport

        def create_tables(self):
            """Create all tables; return True on success, False (after printing) on error."""
            try:
                self.Base.metadata.create_all(bind=self.engine)
                return True
            except Exception as e:
                print(f"创建数据库表失败: {str(e)}")
                return False

        def drop_tables(self):
            """Drop all tables; return True on success, False (after printing) on error."""
            try:
                self.Base.metadata.drop_all(bind=self.engine)
                return True
            except Exception as e:
                print(f"删除数据库表失败: {str(e)}")
                return False

        def get_session(self):
            """Return a new session from ``SessionLocal``; print and re-raise on failure.

            The caller is responsible for closing the returned session.
            """
            try:
                session = self.SessionLocal()
                return session
            except Exception as e:
                print(f"获取数据库会话失败: {str(e)}")
                raise

    return TestDatabaseManager()


@pytest.fixture(scope="function")
def db_manager():
    """Yield a fresh in-memory test database manager with all tables created.

    After each test the tables are dropped and the engine is disposed.
    """
    db_manager = create_test_database_manager()
    # create_tables() swallows its exception and returns False; fail here
    # instead of letting the test run against a database with no schema.
    if not db_manager.create_tables():
        db_manager.engine.dispose()
        pytest.fail("创建测试数据库表失败")
    yield db_manager
    # Teardown: pytest resumes the generator here whether or not the test passed.
    db_manager.drop_tables()
    db_manager.engine.dispose()


@pytest.fixture(scope="session")
def test_data_processor() -> DataProcessor:
    """Return a default-constructed ``DataProcessor`` for the test session."""
    return DataProcessor()


@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Provide one event loop for the whole session and close it afterwards.

    NOTE(review): overriding ``event_loop`` is deprecated in newer
    pytest-asyncio releases — confirm against the installed version.
    """
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


@pytest.fixture
def sample_stock_basic_data() -> list:
    """Sample stock basic-info records."""
    return [
        {
            "code": "000001",
            "name": "平安银行",
            "market": "主板",
            "industry": "银行",
            "area": "广东",
            "ipo_date": "1991-04-03"
        },
        {
            "code": "600000",
            "name": "浦发银行",
            "market": "主板",
            "industry": "银行",
            "area": "上海",
            "ipo_date": "1999-11-10"
        }
    ]


@pytest.fixture
def sample_kline_data() -> list:
    """Sample daily K-line records."""
    return [
        {
            "code": "000001",
            "date": "2024-01-15",
            "open": 10.5,
            "high": 11.2,
            "low": 10.3,
            "close": 10.8,
            "volume": 1000000,
            "amount": 10800000
        },
        {
            "code": "000001",
            "date": "2024-01-16",
            "open": 10.8,
            "high": 11.5,
            "low": 10.7,
            "close": 11.2,
            "volume": 1200000,
            "amount": 13440000
        }
    ]


@pytest.fixture
def sample_financial_data() -> list:
    """Sample financial-report records."""
    return [
        {
            "code": "000001",
            "report_date": "2023-12-31",
            "eps": 1.5,
            "net_profit": 1500000000,
            "revenue": 5000000000,
            "total_assets": 10000000000
        }
    ]


@pytest.fixture
def invalid_stock_data() -> list:
    """Deliberately invalid stock records for validation tests."""
    return [
        {
            "code": "invalid_code",
            "name": "",
            "market": "无效市场",
            "ipo_date": "2099-01-01"  # date in the future
        },
        {
            "code": "000001",
            "name": "测试股票",
            "open": -10.5,  # negative price
            "high": 9.0,  # high is below low
            "low": 11.0,
            "close": 10.0
        }
    ]