stock/tests/conftest.py

273 lines
7.0 KiB
Python

"""
测试配置文件
定义测试用的fixture和配置
"""
import pytest
import asyncio
from pathlib import Path
from typing import Generator, AsyncGenerator
from src.config.settings import Settings
from src.storage.database import DatabaseManager
from src.data.data_processor import DataProcessor
from src.utils.logger import LogManager
@pytest.fixture(scope="session")
def test_settings() -> Settings:
    """Session-wide ``Settings`` instance shared by every test.

    Built with the default ``Settings()`` constructor once per test session.
    """
    settings = Settings()
    return settings
@pytest.fixture(scope="session")
def test_log_manager(test_settings) -> LogManager:
    """Session-wide ``LogManager`` for tests.

    Logs go to ``./test_logs`` (relative to the current working directory).
    The level is read from ``test_settings.log.log_level``.
    """
    level = test_settings.log.log_level
    return LogManager(log_dir="./test_logs", log_level=level)
def create_test_database_manager():
    """Build an isolated, in-memory SQLite database manager for tests.

    The returned object exposes ``engine``, ``SessionLocal``, ``Base``, the
    ``StockBasic`` / ``DailyKline`` / ``FinancialReport`` model classes, and
    the ``create_tables`` / ``drop_tables`` / ``get_session`` / ``dispose``
    methods.

    Every call creates its own engine, its own declarative ``Base`` and its
    own model classes, so two managers never share schema or data.

    Returns:
        A new ``TestDatabaseManager`` instance.
    """
    # Imported locally so SQLAlchemy is only loaded when a manager is built.
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker
    from sqlalchemy.pool import StaticPool

    try:
        # Supported location since SQLAlchemy 1.4.
        from sqlalchemy.orm import declarative_base
    except ImportError:  # SQLAlchemy < 1.4
        # Legacy location; deprecated (warns) on SQLAlchemy 2.x.
        from sqlalchemy.ext.declarative import declarative_base

    class TestDatabaseManager:
        """Minimal database manager backed by a private in-memory SQLite DB."""

        # Tell pytest this is not a test class despite the ``Test`` prefix.
        __test__ = False

        def __init__(self):
            # A ":memory:" database only lives as long as its connection.
            # StaticPool hands every checkout the same connection. With
            # check_same_thread=False, sqlite3 may use that connection from
            # other threads. Together they make the tables created here
            # visible to every session, whichever thread opens it.
            self.engine = create_engine(
                "sqlite:///:memory:",
                connect_args={"check_same_thread": False},
                poolclass=StaticPool,
            )
            self.SessionLocal = sessionmaker(bind=self.engine)
            # A Base of its own, separate from any application metadata.
            self.Base = declarative_base()
            # Define the model classes on that private Base.
            self._import_models()

        def _import_models(self):
            """Define the test ORM models on ``self.Base``.

            NOTE(review): these are hand-written copies that presumably
            mirror the project's real models. Keep them in sync with the
            production schema.
            """
            from sqlalchemy import (
                BigInteger,
                Boolean,
                Column,
                Date,
                DateTime,
                Float,
                ForeignKey,
                Index,
                Integer,
                String,
            )
            from sqlalchemy.sql import func

            class StockBasic(self.Base):
                """Basic per-stock information, keyed by unique ``code``."""

                __tablename__ = "stock_basic"
                id = Column(Integer, primary_key=True, autoincrement=True)
                code = Column(String(10), nullable=False, unique=True)
                name = Column(String(50), nullable=False)
                market = Column(String(10), nullable=False)
                company_name = Column(String(100))
                industry = Column(String(50))
                area = Column(String(50))
                ipo_date = Column(Date)
                listing_status = Column(Boolean, default=True)
                data_source = Column(String(50), default="akshare")
                created_at = Column(DateTime, server_default=func.now())
                updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
                __table_args__ = (
                    Index("idx_code", "code"),
                    Index("idx_market", "market"),
                    Index("idx_industry", "industry"),
                    Index("idx_ipo_date", "ipo_date"),
                )

            class DailyKline(self.Base):
                """Daily OHLC bars referencing ``stock_basic.code``."""

                __tablename__ = "daily_kline"
                id = Column(Integer, primary_key=True, autoincrement=True)
                stock_code = Column(String(10), ForeignKey("stock_basic.code"), nullable=False)
                trade_date = Column(Date, nullable=False)
                open_price = Column(Float, nullable=False)
                high_price = Column(Float, nullable=False)
                low_price = Column(Float, nullable=False)
                close_price = Column(Float, nullable=False)
                volume = Column(BigInteger)
                amount = Column(Float)
                change = Column(Float)
                pct_change = Column(Float)
                created_at = Column(DateTime, server_default=func.now())
                __table_args__ = (
                    Index("idx_stock_code_date", "stock_code", "trade_date"),
                    Index("idx_trade_date", "trade_date"),
                )

            class FinancialReport(self.Base):
                """Financial-report figures referencing ``stock_basic.code``."""

                __tablename__ = "financial_report"
                id = Column(Integer, primary_key=True, autoincrement=True)
                stock_code = Column(String(10), ForeignKey("stock_basic.code"), nullable=False)
                report_date = Column(Date, nullable=False)
                report_type = Column(String(20), nullable=False)
                report_year = Column(Integer, nullable=False)
                report_quarter = Column(Integer)
                eps = Column(Float)
                net_profit = Column(Float)
                revenue = Column(Float)
                total_assets = Column(Float)
                total_liabilities = Column(Float)
                equity = Column(Float)
                roe = Column(Float)
                gross_profit_margin = Column(Float)
                net_profit_margin = Column(Float)
                debt_to_asset_ratio = Column(Float)
                current_ratio = Column(Float)
                quick_ratio = Column(Float)
                created_at = Column(DateTime, server_default=func.now())
                __table_args__ = (
                    Index("idx_stock_code_report", "stock_code", "report_date"),
                    Index("idx_report_date", "report_date"),
                    Index("idx_report_type", "report_type"),
                )

            # Expose the model classes so tests can reach them via the manager.
            self.StockBasic = StockBasic
            self.DailyKline = DailyKline
            self.FinancialReport = FinancialReport

        def create_tables(self):
            """Create every table registered on ``self.Base``.

            Returns:
                True on success. On failure, prints the error and returns False.
            """
            try:
                self.Base.metadata.create_all(bind=self.engine)
                return True
            except Exception as e:
                print(f"创建数据库表失败: {str(e)}")
                return False

        def drop_tables(self):
            """Drop every table registered on ``self.Base``.

            Returns:
                True on success. On failure, prints the error and returns False.
            """
            try:
                self.Base.metadata.drop_all(bind=self.engine)
                return True
            except Exception as e:
                print(f"删除数据库表失败: {str(e)}")
                return False

        def get_session(self):
            """Open a new ORM session bound to this manager's engine.

            Raises:
                Whatever ``SessionLocal()`` raises; the error is printed first.
            """
            try:
                session = self.SessionLocal()
                return session
            except Exception as e:
                print(f"获取数据库会话失败: {str(e)}")
                raise

        def dispose(self):
            """Release the engine's pooled connection.

            With StaticPool, that connection holds the in-memory database, so
            its data is discarded too.
            """
            self.engine.dispose()

    return TestDatabaseManager()
@pytest.fixture(scope="function")
def db_manager():
    """Provide a fresh test database manager with all tables created.

    Setup builds a new in-memory manager and creates every table. Teardown
    drops the tables and disposes of the engine, releasing the pooled
    connection after each test.

    Raises:
        RuntimeError: if the tables could not be created. The test then
            errors out during setup instead of failing later on missing tables.
    """
    manager = create_test_database_manager()
    if not manager.create_tables():
        # create_tables() only prints and returns False on failure, so raise here.
        manager.engine.dispose()
        raise RuntimeError("创建测试数据库表失败")
    try:
        yield manager
    finally:
        manager.drop_tables()
        manager.engine.dispose()
@pytest.fixture(scope="session")
def test_data_processor() -> DataProcessor:
    """Session-wide ``DataProcessor`` built with its default constructor."""
    processor = DataProcessor()
    return processor
@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Session-wide asyncio event loop shared by all async tests.

    NOTE(review): recent pytest-asyncio releases deprecate overriding the
    ``event_loop`` fixture in favour of their loop-scope options. Confirm
    which plugin version is in use.
    """
    # Equivalent to get_event_loop_policy().new_event_loop(), without the
    # policy API, which is deprecated as of Python 3.14.
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        if not loop.is_closed():
            try:
                # Finalise async generators left open by tests, as asyncio.run() does.
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()
@pytest.fixture
def sample_stock_basic_data() -> list:
    """Two well-formed stock-basic records (000001 and 600000).

    A new list of dicts is built on every call.
    """
    fields = ("code", "name", "market", "industry", "area", "ipo_date")
    rows = (
        ("000001", "平安银行", "主板", "银行", "广东", "1991-04-03"),
        ("600000", "浦发银行", "主板", "银行", "上海", "1999-11-10"),
    )
    return [dict(zip(fields, row)) for row in rows]
@pytest.fixture
def sample_kline_data() -> list:
    """Two consecutive daily K-line bars for stock ``000001``.

    In both bars ``amount`` equals ``close * volume``.
    """
    columns = ("code", "date", "open", "high", "low", "close", "volume", "amount")
    bars = [
        ("000001", "2024-01-15", 10.5, 11.2, 10.3, 10.8, 1000000, 10800000),
        ("000001", "2024-01-16", 10.8, 11.5, 10.7, 11.2, 1200000, 13440000),
    ]
    return [dict(zip(columns, bar)) for bar in bars]
@pytest.fixture
def sample_financial_data() -> list:
    """A single financial-report record for stock ``000001``, dated 2023-12-31."""
    report = dict(
        code="000001",
        report_date="2023-12-31",
        eps=1.5,
        net_profit=1_500_000_000,
        revenue=5_000_000_000,
        total_assets=10_000_000_000,
    )
    return [report]
@pytest.fixture
def invalid_stock_data() -> list:
    """Deliberately malformed records for negative-path tests.

    The first record breaks basic-info rules: a non-numeric code, an empty
    name, an invalid market name and an IPO date in the future.

    The second record has a well-formed code but inconsistent prices: a
    negative open, and a high below the low.
    """
    bad_basic_info = {
        "code": "invalid_code",
        "name": "",
        "market": "无效市场",
        "ipo_date": "2099-01-01",  # future date
    }
    bad_prices = {
        "code": "000001",
        "name": "测试股票",
        "open": -10.5,  # negative price
        "high": 9.0,  # high is below low
        "low": 11.0,
        "close": 10.0,
    }
    return [bad_basic_info, bad_prices]