# Source listing metadata (page-scrape residue): 273 lines, 7.0 KiB, Python
"""
测试配置文件

定义测试用的fixture和配置

Test configuration: shared pytest fixtures and settings used by the test suite.
"""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from pathlib import Path
|
|
from typing import Generator, AsyncGenerator
|
|
|
|
from src.config.settings import Settings
|
|
from src.storage.database import DatabaseManager
|
|
from src.data.data_processor import DataProcessor
|
|
from src.utils.logger import LogManager
|
|
|
|
|
|
@pytest.fixture(scope="session")
def test_settings() -> Settings:
    """Session-wide ``Settings`` object shared by every test.

    It is constructed with no arguments, so it carries whatever defaults
    ``Settings`` itself applies (its sources are not visible from here).
    """
    settings = Settings()
    return settings
|
|
|
|
|
|
@pytest.fixture(scope="session")
def test_log_manager(test_settings) -> LogManager:
    """Session-wide ``LogManager`` writing into ``./test_logs``.

    The log level is taken from ``test_settings.log.log_level``.

    NOTE(review): the directory is relative, so it resolves against the
    current working directory pytest was started from.
    """
    level = test_settings.log.log_level
    manager = LogManager(log_dir="./test_logs", log_level=level)
    return manager
|
|
|
|
|
|
def create_test_database_manager():
    """Build an isolated, in-memory SQLite database manager for tests.

    The returned object owns its own declarative ``Base`` carrying test copies
    of the ``stock_basic``, ``daily_kline`` and ``financial_report`` models, so
    the models are not registered on any other metadata and nothing is written
    to disk.

    Returns:
        An instance of the local ``TestDatabaseManager`` class exposing
        ``engine``, ``SessionLocal``, ``Base``, ``StockBasic``, ``DailyKline``,
        ``FinancialReport`` and the methods ``create_tables()``,
        ``drop_tables()``, ``get_session()`` and ``dispose()``.
    """
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker
    from sqlalchemy.pool import StaticPool

    try:
        # Canonical location since SQLAlchemy 1.4; the old
        # sqlalchemy.ext.declarative alias is deprecated (MovedIn20Warning).
        from sqlalchemy.orm import declarative_base
    except ImportError:  # SQLAlchemy < 1.4
        from sqlalchemy.ext.declarative import declarative_base

    class TestDatabaseManager:
        """Owns an in-memory SQLite engine plus test copies of the ORM models."""

        def __init__(self):
            # Every new DBAPI connection to ":memory:" is a brand-new, empty
            # database. StaticPool keeps exactly one connection and hands it
            # to every session, so tables created by create_tables() are
            # visible everywhere -- including sessions opened on another
            # thread, which check_same_thread=False permits.
            self.engine = create_engine(
                "sqlite:///:memory:",
                connect_args={"check_same_thread": False},
                poolclass=StaticPool,
            )
            self.SessionLocal = sessionmaker(bind=self.engine)
            # A private Base keeps these test models off any other metadata.
            self.Base = declarative_base()
            self._import_models()

        def _import_models(self):
            """Define the test models on ``self.Base`` and expose them as attributes.

            NOTE(review): these are hand-written redefinitions of the project's
            models; keep them in sync with the production schema manually.
            """
            from sqlalchemy import (
                BigInteger,
                Boolean,
                Column,
                Date,
                DateTime,
                Float,
                ForeignKey,
                Index,
                Integer,
                String,
            )
            from sqlalchemy.sql import func

            class StockBasic(self.Base):
                __tablename__ = "stock_basic"

                id = Column(Integer, primary_key=True, autoincrement=True)
                code = Column(String(10), nullable=False, unique=True)
                name = Column(String(50), nullable=False)
                market = Column(String(10), nullable=False)
                company_name = Column(String(100))
                industry = Column(String(50))
                area = Column(String(50))
                ipo_date = Column(Date)
                listing_status = Column(Boolean, default=True)
                data_source = Column(String(50), default="akshare")
                created_at = Column(DateTime, server_default=func.now())
                updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())

                __table_args__ = (
                    Index("idx_code", "code"),
                    Index("idx_market", "market"),
                    Index("idx_industry", "industry"),
                    Index("idx_ipo_date", "ipo_date"),
                )

            class DailyKline(self.Base):
                __tablename__ = "daily_kline"

                id = Column(Integer, primary_key=True, autoincrement=True)
                stock_code = Column(String(10), ForeignKey("stock_basic.code"), nullable=False)
                trade_date = Column(Date, nullable=False)
                open_price = Column(Float, nullable=False)
                high_price = Column(Float, nullable=False)
                low_price = Column(Float, nullable=False)
                close_price = Column(Float, nullable=False)
                volume = Column(BigInteger)
                amount = Column(Float)
                change = Column(Float)
                pct_change = Column(Float)
                created_at = Column(DateTime, server_default=func.now())

                __table_args__ = (
                    Index("idx_stock_code_date", "stock_code", "trade_date"),
                    Index("idx_trade_date", "trade_date"),
                )

            class FinancialReport(self.Base):
                __tablename__ = "financial_report"

                id = Column(Integer, primary_key=True, autoincrement=True)
                stock_code = Column(String(10), ForeignKey("stock_basic.code"), nullable=False)
                report_date = Column(Date, nullable=False)
                report_type = Column(String(20), nullable=False)
                report_year = Column(Integer, nullable=False)
                report_quarter = Column(Integer)
                eps = Column(Float)
                net_profit = Column(Float)
                revenue = Column(Float)
                total_assets = Column(Float)
                total_liabilities = Column(Float)
                equity = Column(Float)
                roe = Column(Float)
                gross_profit_margin = Column(Float)
                net_profit_margin = Column(Float)
                debt_to_asset_ratio = Column(Float)
                current_ratio = Column(Float)
                quick_ratio = Column(Float)
                created_at = Column(DateTime, server_default=func.now())

                __table_args__ = (
                    Index("idx_stock_code_report", "stock_code", "report_date"),
                    Index("idx_report_date", "report_date"),
                    Index("idx_report_type", "report_type"),
                )

            self.StockBasic = StockBasic
            self.DailyKline = DailyKline
            self.FinancialReport = FinancialReport

        def create_tables(self):
            """Create every table registered on ``self.Base``.

            Returns:
                True on success; False if creation raised (the error is printed).
            """
            try:
                self.Base.metadata.create_all(bind=self.engine)
            except Exception as e:
                # Best-effort by design: callers inspect the boolean result.
                print(f"创建数据库表失败: {str(e)}")
                return False
            return True

        def drop_tables(self):
            """Drop every table registered on ``self.Base``.

            Returns:
                True on success; False if dropping raised (the error is printed).
            """
            try:
                self.Base.metadata.drop_all(bind=self.engine)
            except Exception as e:
                print(f"删除数据库表失败: {str(e)}")
                return False
            return True

        def get_session(self):
            """Open a new ORM session bound to the in-memory engine.

            The session is returned open; the caller is responsible for
            closing it.

            Raises:
                Exception: whatever ``SessionLocal()`` raises, after printing it.
            """
            try:
                return self.SessionLocal()
            except Exception as e:
                print(f"获取数据库会话失败: {str(e)}")
                raise

        def dispose(self):
            """Close the engine's pooled connection.

            For this in-memory database, closing the single pooled connection
            also discards its contents.
            """
            self.engine.dispose()

    return TestDatabaseManager()
|
|
|
|
|
|
@pytest.fixture(scope="function")
def db_manager():
    """Provide a fresh in-memory test database manager for each test.

    Tables are created before the test runs. Afterwards they are dropped and
    the engine is disposed, so the pooled SQLite connection is released.

    Raises:
        RuntimeError: if the schema could not be created. The test then errors
            during setup instead of failing later on a missing table.
    """
    manager = create_test_database_manager()

    # create_tables() swallows exceptions and reports failure only through its
    # return value; surface that here rather than running on an empty database.
    if not manager.create_tables():
        manager.engine.dispose()
        raise RuntimeError("failed to create test database tables")

    try:
        yield manager
    finally:
        manager.drop_tables()
        # Release the connection so in-memory databases don't pile up across tests.
        manager.engine.dispose()
|
|
|
|
|
|
@pytest.fixture(scope="session")
def test_data_processor() -> DataProcessor:
    """One default-constructed ``DataProcessor`` shared for the whole session.

    NOTE(review): with session scope, any state the processor keeps is shared
    across tests -- confirm that DataProcessor is stateless.
    """
    processor = DataProcessor()
    return processor
|
|
|
|
|
|
@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Provide one asyncio event loop shared by the whole test session.

    The fixture name matches the ``event_loop`` fixture that pytest-asyncio
    lets projects override (presumably the async plugin in use -- confirm).
    NOTE(review): recent pytest-asyncio releases deprecate or ignore this
    override in favour of loop-scope configuration.

    On teardown, async generators still open on the loop are finalized before
    the loop is closed.
    """
    # Equivalent to get_event_loop_policy().new_event_loop(), without calling
    # the policy API that Python 3.14 deprecates.
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        if not loop.is_closed():
            try:
                # Let async generators left suspended by tests run their
                # cleanup before the loop disappears.
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()
|
|
|
|
|
|
@pytest.fixture
def sample_stock_basic_data() -> list:
    """Two well-formed stock-basic records (平安银行 000001, 浦发银行 600000).

    Each record is a plain dict with keys ``code``, ``name``, ``market``,
    ``industry``, ``area`` and ``ipo_date`` (ISO ``YYYY-MM-DD`` string).
    """
    fields = ("code", "name", "market", "industry", "area", "ipo_date")
    rows = (
        ("000001", "平安银行", "主板", "银行", "广东", "1991-04-03"),
        ("600000", "浦发银行", "主板", "银行", "上海", "1999-11-10"),
    )
    return [dict(zip(fields, row)) for row in rows]
|
|
|
|
|
|
@pytest.fixture
def sample_kline_data() -> list:
    """Two consecutive daily bars for stock ``000001`` (2024-01-15 and 2024-01-16).

    Each record has ``code``, ``date``, the four OHLC prices, ``volume`` and
    ``amount``; in both rows ``amount`` equals ``close * volume``.
    """
    keys = ("code", "date", "open", "high", "low", "close", "volume", "amount")
    bars = [
        ("000001", "2024-01-15", 10.5, 11.2, 10.3, 10.8, 1000000, 10800000),
        ("000001", "2024-01-16", 10.8, 11.5, 10.7, 11.2, 1200000, 13440000),
    ]
    return [dict(zip(keys, bar)) for bar in bars]
|
|
|
|
|
|
@pytest.fixture
def sample_financial_data() -> list:
    """A single financial record for stock ``000001`` dated 2023-12-31."""
    report = dict(
        code="000001",
        report_date="2023-12-31",
        eps=1.5,
        net_profit=1_500_000_000,
        revenue=5_000_000_000,
        total_assets=10_000_000_000,
    )
    return [report]
|
|
|
|
|
|
@pytest.fixture
def invalid_stock_data() -> list:
    """Deliberately malformed records for negative-path validation tests.

    - First record: malformed code, empty name, unknown market and an IPO
      date far in the future.
    - Second record: a plausible code, but a negative open price and a high
      below the low.
    """
    bad_basic_info = dict(
        code="invalid_code",
        name="",
        market="无效市场",
        ipo_date="2099-01-01",  # future listing date
    )
    bad_prices = dict(
        code="000001",
        name="测试股票",
        open=-10.5,  # negative price
        high=9.0,    # high is below low
        low=11.0,
        close=10.0,
    )
    return [bad_basic_info, bad_prices]