Initial commit: Stock data analysis system with frontend and backend

skdbj 2025-11-10 16:31:00 +08:00
commit 34be13df68
57 changed files with 13390 additions and 0 deletions

22
.env.example Normal file

@ -0,0 +1,22 @@
# Database configuration
DATABASE_URL=mysql+mysqlconnector://username:password@localhost:3306/stock_analysis
# Data source configuration
AKSHARE_TIMEOUT=30
BAOSTOCK_TIMEOUT=30
# Scheduled task configuration
SCHEDULER_TIMEZONE=Asia/Shanghai
UPDATE_INTERVAL_HOURS=24
# Logging configuration
LOG_LEVEL=INFO
LOG_FILE=logs/stock_analysis.log
# Data collection configuration
MAX_RETRY_TIMES=3
RETRY_DELAY_SECONDS=5
# Stock market configuration
MARKET_TYPES=sh,sz
DATA_TYPES=stock_basic,daily_kline,financial_report

75
.gitignore vendored Normal file

@ -0,0 +1,75 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# IDE
.idea/
.vscode/
*.swp
*.swo
# OS
.DS_Store
Thumbs.db
# Logs
logs/
*.log
# Database
*.db
*.sqlite
# Temporary files
temp/
tmp/

286
README.md Normal file

@ -0,0 +1,286 @@
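# A-Share Market Analysis and Quantitative Trading System
## Project Overview
This project is a complete A-share market analysis and quantitative trading system that provides reliable data collection, storage, analysis, and trading features. It uses a modular architecture and supports multi-source data collection, automated data updates, and real-time monitoring.
## System Architecture
### Core Modules
- **Data collection module**: Multi-source collection based on AKShare and Baostock, covering basic stock information, daily K-line data, and financial reports
- **Data processing module**: Data cleaning, format unification, validation, and normalization
- **Data storage module**: Storage and querying built on SQLAlchemy, with support for SQLite and MySQL
- **Scheduler module**: Automated data updates and synchronization based on APScheduler
- **System management module**: Unified configuration management, logging, and exception handling
- **Test module**: Unit tests, integration tests, and performance tests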
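## Quick Start
### Requirements
- Python 3.8+
- A virtual environment is recommended
### One-Step Deployment
```bash
# Use the deployment script to set up the environment automatically
python deploy.py
# Production deployment
python deploy.py --production
# Quick deployment that skips the tests
python deploy.py --skip-tests
```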
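### Manual Installation
1. Create a virtual environment
```bash
python -m venv venv
# Windows
venv\Scripts\activate
# Linux/Mac
source venv/bin/activate
```
2. Install the dependencies
```bash
pip install -r requirements.txt
```
3. Configure the environment
```bash
# Copy the environment configuration file (optional)
cp .env.example .env
# Edit .env to configure the database connection (optional)
# SQLite is used by default and needs no extra configuration
```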
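### Running the System
```bash
# Show help
python run.py --help
# Initialize system data (first use)
python run.py init
# Start the task scheduler
python run.py scheduler
# Show system status
python run.py status
# Update data manually
python run.py update
# Run the tests
python run.py test
# Run the performance tests
python run.py performance
```
### Using the Startup Scripts (Recommended)
The startup scripts are generated by `deploy.py`.
```bash
# Windows
start.bat init
start.bat scheduler
# Linux/Mac
./start.sh init
./start.sh scheduler
```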
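## Features
### ✅ Completed
- **Multi-source data collection**: Integrates the AKShare and Baostock data sources, with deduplication and merging
- **Full data initialization**: Initializes all basic stock data, historical K-line data, and financial reports in one step
- **Scheduled incremental updates**: Automated daily K-line updates, weekly financial data updates, and monthly basic information updates
- **Data processing and cleaning**: Data validation, format normalization, missing-value handling, and anomaly detection
- **Modular architecture**: Clear module boundaries that make maintenance and extension easier
- **Logging system**: Multi-level logging with file and console output
- **Exception handling**: Unified exception classification and handling, with error recovery
- **Database management**: Supports SQLite and MySQL, including connection pooling and transaction management
- **Configuration management**: Unified configuration system with environment variable overrides
- **Test coverage**: Unit tests, integration tests, and performance tests
### 🔄 In Development
- Quantitative analysis (technical indicator calculation)
- Trading interface (simulated trading)
- Web management interface
- Data visualization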
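## Project Structure
```
stock/
├── config/ # Configuration files
│ ├── config.py # System configuration
│ └── settings.py # Project settings
├── src/ # Source code
│ ├── data/ # Data collection and processing
│ ├── storage/ # Data storage
│ ├── scheduler/ # Scheduled tasks
│ ├── utils/ # Utility modules
│ └── main.py # Main program
├── tests/ # Test code
│ ├── test_*.py # Test suites
│ └── conftest.py # Test configuration
├── logs/ # Log files (created automatically)
├── data/ # Data files (created automatically)
├── run.py # Launch script
├── deploy.py # Deployment script
├── requirements.txt # Dependencies
└── README.md # Project documentation
```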
## Configuration
### Database Configuration
The system uses a SQLite database by default and needs no extra configuration. To use MySQL, edit the configuration file:
```python
# Edit in config/config.py
DATABASE_CONFIG = {
"database_url": "mysql+mysqlconnector://username:password@localhost:3306/stock_analysis",
"echo": False,
"pool_size": 10,
"max_overflow": 20
}
```
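When the `DATABASE_URL` environment variable is set, it takes precedence over the value in `DATABASE_CONFIG`, as `Config.get_database_url` in `config/config.py` does. A minimal sketch of that lookup order follows; `resolve_database_url` is a hypothetical helper name used only for illustration:

```python
import os

DEFAULT_DATABASE_URL = "sqlite:///stock_data.db"  # default value from config/config.py


def resolve_database_url(default: str = DEFAULT_DATABASE_URL) -> str:
    """Return DATABASE_URL from the environment, or the default when it is unset or empty."""
    return os.getenv("DATABASE_URL") or default


print(resolve_database_url())
```

Keeping the credentials in the environment (or in `.env`) rather than in `config/config.py` avoids committing them to the repository.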
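### Scheduled Task Configuration
The system includes the following scheduled tasks:
- **Daily K-line update**: Runs at 18:00, after the market closes on trading days
- **Weekly financial update**: Runs on Saturdays at 09:00
- **Monthly basic information update**: Runs on the 1st of each month at 10:00
- **Daily health check**: Runs every day at 08:00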
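The schedule can be expressed as cron-style fields. The sketch below converts an entry shaped like those in `SCHEDULER_CONFIG` into keyword arguments of the kind APScheduler's `CronTrigger` accepts (`hour`, `minute`, `day_of_week`, `day`, `timezone`). `to_cron_kwargs` is a hypothetical helper, not part of the project, and the sketch does not import APScheduler itself:

```python
from typing import Any, Dict


def to_cron_kwargs(task: Dict[str, Any]) -> Dict[str, Any]:
    """Translate a task entry with an "HH:MM" time into cron-style keyword arguments."""
    hour, minute = (int(part) for part in task["time"].split(":"))
    kwargs: Dict[str, Any] = {"hour": hour, "minute": minute}
    # Optional fields are copied only when the task entry defines them.
    for key in ("day_of_week", "day", "timezone"):
        if key in task:
            kwargs[key] = task[key]
    return kwargs


weekly = {"day_of_week": "sat", "time": "09:00", "timezone": "Asia/Shanghai"}
print(to_cron_kwargs(weekly))
```

A plain cron expression has no notion of trading days, so a task registered this way would presumably still need its own check for market holidays.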
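### Logging Configuration
The system supports multi-level logging:
- DEBUG: Development and debugging information
- INFO: Normal runtime information
- WARNING: Warnings
- ERROR: Errors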
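As an illustration of file plus console logging with rotation, here is a minimal sketch using only the standard library. The format string, the 10 MB file size, and the 5 backups match `LOGGING_CONFIG` in `config/config.py`; `build_logger` is a hypothetical helper, and the project's own logging setup may be organized differently:

```python
import logging
from logging.handlers import RotatingFileHandler
from pathlib import Path


def build_logger(name: str, log_file: str, level: str = "INFO") -> logging.Logger:
    """Create a logger that writes to the console and to a size-rotated file."""
    Path(log_file).parent.mkdir(parents=True, exist_ok=True)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # 10 MB per file and 5 backups, the values used in LOGGING_CONFIG.
    file_handler = RotatingFileHandler(
        log_file, maxBytes=10 * 1024 * 1024, backupCount=5, encoding="utf-8"
    )
    # Note: calling this twice for the same name would attach duplicate handlers.
    for handler in (file_handler, logging.StreamHandler()):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
```

A call such as `build_logger("stock", "logs/stock_system.log", "DEBUG")` would then emit messages at all four levels listed above.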
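## Development Guide
### Adding a New Data Source
1. Create a new collector class under `src/data/`
2. Inherit from the `BaseCollector` base class
3. Implement the required data collection methods
4. Register the new data source in `DataManager`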
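The four steps can be sketched as follows. The real `BaseCollector` and `DataManager` interfaces are not shown in this document, so the abstract method name, the `StaticCollector` class, and the dict used for registration are all assumptions made for illustration:

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class BaseCollector(ABC):
    """Stand-in for the project's base class; the real interface may differ."""

    @abstractmethod
    def get_stock_basic_info(self) -> List[Dict[str, str]]:
        """Return one dict per stock."""


class StaticCollector(BaseCollector):
    """Hypothetical collector that serves a fixed list, e.g. for local testing."""

    def get_stock_basic_info(self) -> List[Dict[str, str]]:
        return [{"code": "000001", "name": "Ping An Bank", "market": "sz"}]


# Step 4, registration: a plain dict stands in for DataManager here.
collectors: Dict[str, BaseCollector] = {"static": StaticCollector()}
print(collectors["static"].get_stock_basic_info()[0]["code"])
```

Declaring the collection methods as abstract means a collector that forgets to implement one fails at instantiation rather than at the first scheduled run.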
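### Adding a New Scheduled Task
1. Add a task method in `src/scheduler/task_scheduler.py`
2. Configure the task's execution time and parameters
3. Register the new task with the scheduler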
### Running the Tests
```bash
# Run all tests
python -m pytest tests/ -v
# Run a specific test file
python -m pytest tests/test_data_collectors.py -v
# Run the performance tests
python -m pytest tests/test_performance.py -v
# Generate a test coverage report
python -m pytest --cov=src tests/
```
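### Test Environment Configuration
The system uses a separate test database to keep tests isolated:
#### Test Database Configuration
- **Test database file**: `tests/test_stock.db` (SQLite)
- **Test model definitions**: Separate test model classes are defined in `tests/conftest.py`
- **Data isolation**: Tests use their own database and do not affect production data
#### Test Model Compatibility
Model compatibility between the test and production environments is handled as follows:
- **Dynamic model lookup**: `StockRepository` obtains its model classes dynamically through the `_setup_models` method
- **Test models first**: Model definitions from the test database manager take priority
- **Fallback**: If no test models are available, the default imported model classes are used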
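The lookup-with-fallback idea can be sketched like this. The actual `_setup_models` implementation is not shown in this document, so the attribute name `StockBasic`, the `RepositorySketch` class, and the way the database manager exposes its models are assumptions:

```python
from types import SimpleNamespace
from typing import Any


class DefaultStockBasic:
    """Placeholder for the default (production) model class."""


class RepositorySketch:
    """Hypothetical stand-in for StockRepository, reduced to the model lookup."""

    def __init__(self, db_manager: Any = None) -> None:
        self.db_manager = db_manager
        self._setup_models()

    def _setup_models(self) -> None:
        # Prefer a model class exposed by the database manager (the test setup),
        # and fall back to the default model when none is provided.
        self.StockBasic = getattr(self.db_manager, "StockBasic", None) or DefaultStockBasic


class FakeTestModel:
    """Placeholder for a model defined in tests/conftest.py."""


test_manager = SimpleNamespace(StockBasic=FakeTestModel)
print(RepositorySketch(test_manager).StockBasic.__name__)
print(RepositorySketch().StockBasic.__name__)
```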
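#### Test Data Management
- **Test data cleanup**: Test data is cleaned up automatically after each test
- **Transaction rollback**: Transaction-level test isolation is supported
- **Performance tests**: Bulk insert and query performance tests are included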
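To illustrate what rollback on error means in practice, here is a minimal sketch using the standard-library `sqlite3` module and an in-memory database. It is not the project's test code (which goes through SQLAlchemy); the two-column table and the `insert_all_or_nothing` helper are simplified assumptions:

```python
import sqlite3


def insert_all_or_nothing(conn: sqlite3.Connection, rows) -> bool:
    """Insert every row in one transaction; undo all of them if any insert fails."""
    try:
        with conn:  # commits on success, rolls back when an exception is raised
            conn.executemany("INSERT INTO stock_basic (code, name) VALUES (?, ?)", rows)
        return True
    except sqlite3.IntegrityError:
        return False


conn = sqlite3.connect(":memory:")  # throwaway database, isolated from real data
conn.execute("CREATE TABLE stock_basic (code TEXT PRIMARY KEY, name TEXT)")
# The second row repeats the primary key, so the whole batch is rolled back.
ok = insert_all_or_nothing(conn, [("000001", "A"), ("000001", "duplicate")])
count = conn.execute("SELECT COUNT(*) FROM stock_basic").fetchone()[0]
print(ok, count)
```

Because the failed batch leaves no rows behind, the next test starts from the same empty table.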
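#### Example Test Runs
```bash
# Run the storage module tests
python -m pytest tests/test_storage.py -v
# Run a specific test method
python -m pytest tests/test_storage.py::TestStockRepository::test_save_stock_basic_info_success -v
# Run the transaction rollback test
python -m pytest tests/test_storage.py::TestStockRepository::test_transaction_rollback_on_error -v
# Run the performance tests
python -m pytest tests/test_performance.py -v
```
#### Testing Notes
- Tests use a separate SQLite database file and do not affect the main database
- Test data is cleaned up automatically after the tests finish
- Transaction rollback tests are included to verify data consistency
- Performance tests verify baseline performance metrics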
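## Troubleshooting
### Common Problems
1. **Database connection fails**
   - Check that the database service is running
   - Verify that the connection string is correct
   - Check the network connection
2. **Data collection fails**
   - Check the network connection
   - Verify that the data source API is available
   - Review the detailed error log
3. **Memory usage is too high**
   - Reduce the batch size
   - Run garbage collection more often
   - Optimize the data processing logic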
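For the memory item, reducing the batch size usually means processing records in fixed-size chunks instead of loading everything at once. A minimal sketch, where `iter_batches` is a hypothetical helper and not part of the project:

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def iter_batches(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yield lists of at most batch_size items, holding one batch in memory at a time."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    batch: List[T] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # final, possibly shorter batch
        yield batch


for chunk in iter_batches(range(5), 2):
    print(chunk)
```

A smaller `batch_size` lowers peak memory at the cost of more round trips to the database or data source.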
### Getting Help
- Review the detailed log: `logs/stock_system.log`
- Use debug mode: `python run.py --debug status`
- Check the system status: `python run.py status`
## License
MIT License
## Contributing
Issues and pull requests that improve this project are welcome.

25
check_config.py Normal file

@ -0,0 +1,25 @@
"""
Configuration check script.
Checks whether the pydantic settings load the .env file correctly.
"""
import os
import pydantic
from dotenv import load_dotenv

# Load the .env file manually before the settings module is imported
load_dotenv()
from src.config.settings import settings

print("=== Configuration check results ===")
print(f"pydantic version: {pydantic.__version__}")
print(f"Environment variable DATABASE_URL: {os.getenv('DATABASE_URL')}")
print(f"Database URL actually loaded: {settings.database.database_url}")
print("\n=== .env file contents ===")
if os.path.exists(".env"):
    with open(".env", "r", encoding="utf-8") as f:
        print(f.read())
else:
    print(".env file does not exist")

64
check_data_status.py Normal file

@ -0,0 +1,64 @@
"""
Data status check script.
Shows how much data is currently in the database.
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.storage.database import db_manager
from sqlalchemy import text
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def check_data_status():
    """Check the data status."""
    session = None
    try:
        # Get a database session
        session = db_manager.get_session()
        logger.info("=== Checking data status ===")
        # Count the rows in each table
        tables = ['stock_basic', 'daily_kline', 'financial_report', 'data_source', 'system_log']
        for table in tables:
            result = session.execute(text(f"SELECT COUNT(*) FROM {table}"))
            count = result.fetchone()[0]
            logger.info(f"{table}: {count} records")
        # Check the basic stock information
        result = session.execute(text("SELECT code, name, market FROM stock_basic LIMIT 10"))
        stocks = result.fetchall()
        if stocks:
            logger.info("=== First 10 stocks ===")
            for stock in stocks:
                logger.info(f"Code: {stock[0]}, name: {stock[1]}, market: {stock[2]}")
        else:
            logger.info("The basic stock information table is empty")
        # Check the K-line data
        result = session.execute(text("SELECT stock_code, trade_date, closing_price FROM daily_kline LIMIT 5"))
        klines = result.fetchall()
        if klines:
            logger.info("=== First 5 K-line records ===")
            for kline in klines:
                logger.info(f"Stock: {kline[0]}, date: {kline[1]}, close: {kline[2]}")
        else:
            logger.info("The K-line data table is empty")
        logger.info("=== Data status check complete ===")
    except Exception as e:
        logger.error(f"Data status check failed: {str(e)}")
        raise
    finally:
        # Close the session even when a query fails
        if session is not None:
            session.close()


if __name__ == "__main__":
    check_data_status()

47
check_table_structure.py Normal file

@ -0,0 +1,47 @@
"""
Database table structure check script.
Shows the actual column structure of each table.
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.storage.database import db_manager
from sqlalchemy import text
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def check_table_structure():
    """Check the table structure."""
    session = None
    try:
        # Get a database session
        session = db_manager.get_session()
        logger.info("=== Checking table structure ===")
        # Inspect each table. DESCRIBE is MySQL syntax and is not supported by SQLite.
        tables = ['stock_basic', 'daily_kline', 'financial_report', 'data_source', 'system_log']
        for table in tables:
            logger.info(f"=== Structure of table {table} ===")
            result = session.execute(text(f"DESCRIBE {table}"))
            columns = result.fetchall()
            for column in columns:
                logger.info(f"Field: {column[0]}, type: {column[1]}, nullable: {column[2]}, key: {column[3]}")
            logger.info("")
        logger.info("=== Table structure check complete ===")
    except Exception as e:
        logger.error(f"Table structure check failed: {str(e)}")
        raise
    finally:
        # Close the session even when a query fails
        if session is not None:
            session.close()


if __name__ == "__main__":
    check_table_structure()

248
config/config.py Normal file

@ -0,0 +1,248 @@
"""
System configuration file.
Configures the parameters of the stock analysis system.
"""
import os
from typing import Dict, Any, Optional


class Config:
    """System configuration class."""

    # Database configuration
    DATABASE_CONFIG = {
        "database_url": "sqlite:///stock_data.db",  # SQLite by default
        "echo": False,  # whether to echo SQL statements
        "pool_size": 10,  # connection pool size
        "max_overflow": 20,  # maximum number of overflow connections
        "pool_timeout": 30,  # pool timeout (seconds)
        "pool_recycle": 3600,  # connection recycle time (seconds)
    }

    # Data collection configuration
    DATA_COLLECTION_CONFIG = {
        "akshare": {
            "base_url": "https://api.akshare.akfamily.xyz",
            "timeout": 30,  # request timeout (seconds)
            "retry_times": 3,  # number of retries
            "retry_delay": 1,  # delay between retries (seconds)
        },
        "baostock": {
            "login_timeout": 10,  # login timeout (seconds)
            "query_timeout": 30,  # query timeout (seconds)
            "max_connections": 5,  # maximum number of connections
        },
        "batch_size": 100,  # batch size
        "max_concurrent": 10,  # maximum concurrency
    }

    # Scheduled task configuration
    SCHEDULER_CONFIG = {
        "daily_kline_update": {
            "enabled": True,  # whether the task is enabled
            "time": "18:00",  # run time (after the market closes on trading days)
            "timezone": "Asia/Shanghai",  # time zone
            "max_retries": 3,  # maximum number of retries
        },
        "weekly_financial_update": {
            "enabled": True,
            "day_of_week": "sat",  # runs on Saturdays
            "time": "09:00",
            "timezone": "Asia/Shanghai",
            "max_retries": 3,
        },
        "monthly_basic_update": {
            "enabled": True,
            "day": 1,  # 1st of each month
            "time": "10:00",
            "timezone": "Asia/Shanghai",
            "max_retries": 3,
        },
        "daily_health_check": {
            "enabled": True,
            "time": "08:00",
            "timezone": "Asia/Shanghai",
            "max_retries": 1,
        },
    }
    # Data processing configuration
    DATA_PROCESSING_CONFIG = {
        "validation": {
            "required_fields": ["code", "name"],  # required fields
            "numeric_fields": ["open", "high", "low", "close", "volume", "amount"],  # numeric fields
            "date_fields": ["date", "list_date"],  # date fields
        },
        "cleaning": {
            "remove_duplicates": True,  # whether to remove duplicates
            "fill_missing_values": True,  # whether to fill missing values
            "standardize_formats": True,  # whether to standardize formats
        },
        "normalization": {
            "decimal_places": 2,  # number of decimal places
            "date_format": "%Y-%m-%d",  # date format
        },
    }

    # Logging configuration
    LOGGING_CONFIG = {
        "level": "INFO",  # log level
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        "date_format": "%Y-%m-%d %H:%M:%S",
        "file": {
            "enabled": True,
            "filename": "logs/stock_system.log",
            "max_bytes": 10485760,  # 10 MB
            "backup_count": 5,  # number of backup files
        },
        "console": {
            "enabled": True,
            "level": "INFO",
        },
    }

    # Performance configuration
    PERFORMANCE_CONFIG = {
        "memory": {
            "max_memory_usage": 1024,  # maximum memory usage (MB)
            "gc_threshold": 512,  # garbage collection threshold (MB)
        },
        "database": {
            "query_timeout": 30,  # query timeout (seconds)
            "batch_size": 1000,  # batch operation size
            "max_connections": 50,  # maximum number of database connections
        },
        "network": {
            "timeout": 30,  # network timeout (seconds)
            "retry_delay": 1,  # delay between retries (seconds)
        },
    }

    # Security configuration
    SECURITY_CONFIG = {
        "encryption": {
            "enabled": False,  # whether data encryption is enabled
            "algorithm": "AES",  # encryption algorithm
        },
        "authentication": {
            "enabled": False,  # whether authentication is enabled
        },
        "backup": {
            "enabled": True,
            "interval": 24,  # backup interval (hours)
            "retention_days": 30,  # retention period (days)
        },
    }

    # Monitoring configuration
    MONITORING_CONFIG = {
        "enabled": True,
        "metrics": {
            "data_collection": True,  # data collection metrics
            "database_performance": True,  # database performance metrics
            "system_resources": True,  # system resource metrics
        },
        "alerts": {
            "enabled": True,
            "thresholds": {
                "memory_usage": 80,  # memory usage threshold (%)
                "cpu_usage": 80,  # CPU usage threshold (%)
                "disk_usage": 90,  # disk usage threshold (%)
                "database_timeout": 10,  # threshold for the number of database timeouts
            },
        },
    }

    # Development configuration
    DEVELOPMENT_CONFIG = {
        "debug": False,  # debug mode
        "testing": False,  # testing mode
        "profiling": False,  # profiling
        "logging_level": "DEBUG",  # log level for the development environment
    }
    @classmethod
    def get_database_url(cls) -> str:
        """Return the database URL."""
        # The environment variable takes precedence
        database_url = os.getenv("DATABASE_URL")
        if database_url:
            return database_url
        return cls.DATABASE_CONFIG["database_url"]

    @classmethod
    def get_logging_config(cls) -> Dict[str, Any]:
        """Return the logging configuration."""
        # Copy the nested dicts as well, so the overrides below do not modify
        # the class-level LOGGING_CONFIG (a plain .copy() is shallow).
        config = {
            **cls.LOGGING_CONFIG,
            "file": dict(cls.LOGGING_CONFIG["file"]),
            "console": dict(cls.LOGGING_CONFIG["console"]),
        }
        # Use a more verbose log level in debug mode
        if cls.DEVELOPMENT_CONFIG["debug"]:
            config["level"] = cls.DEVELOPMENT_CONFIG["logging_level"]
            config["console"]["level"] = cls.DEVELOPMENT_CONFIG["logging_level"]
        return config

    @classmethod
    def get_data_collection_config(cls, source: str) -> Optional[Dict[str, Any]]:
        """Return the collection configuration for the given data source."""
        return cls.DATA_COLLECTION_CONFIG.get(source)

    @classmethod
    def get_scheduler_config(cls, task_name: str) -> Optional[Dict[str, Any]]:
        """Return the configuration for the given scheduled task."""
        return cls.SCHEDULER_CONFIG.get(task_name)

    @classmethod
    def update_from_environment(cls):
        """Update the configuration from environment variables."""
        # Database configuration
        if os.getenv("DATABASE_URL"):
            cls.DATABASE_CONFIG["database_url"] = os.getenv("DATABASE_URL")
        # Debug mode
        if os.getenv("DEBUG"):
            cls.DEVELOPMENT_CONFIG["debug"] = os.getenv("DEBUG").lower() == "true"
        # Log level
        if os.getenv("LOG_LEVEL"):
            cls.LOGGING_CONFIG["level"] = os.getenv("LOG_LEVEL")
            cls.LOGGING_CONFIG["console"]["level"] = os.getenv("LOG_LEVEL")

    @classmethod
    def validate_config(cls) -> bool:
        """Validate the configuration."""
        try:
            # Validate the database configuration
            database_url = cls.get_database_url()
            if not database_url:
                raise ValueError("The database URL must not be empty")
            # Validate the data collection configuration
            for source in ["akshare", "baostock"]:
                source_config = cls.get_data_collection_config(source)
                if not source_config:
                    raise ValueError(f"The configuration for data source {source} must not be empty")
            # Validate the scheduled task configuration
            for task_name in cls.SCHEDULER_CONFIG.keys():
                task_config = cls.get_scheduler_config(task_name)
                if not task_config:
                    raise ValueError(f"The configuration for scheduled task {task_name} must not be empty")
            return True
        except Exception as e:
            print(f"Configuration validation failed: {e}")
            return False


# Create the configuration instance
config = Config()
# Update the configuration from environment variables
config.update_from_environment()
# Validate the configuration
if not config.validate_config():
    raise RuntimeError("System configuration validation failed; check the configuration file")

66
create_tables.py Normal file

@ -0,0 +1,66 @@
"""
Script for creating the database tables manually.
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.storage.database import db_manager
from sqlalchemy import text
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def create_database_tables():
    """Create the database tables."""
    session = None
    try:
        # Get a database session
        session = db_manager.get_session()
        logger.info("Database connection established")
        # Check whether the database exists (SHOW/USE statements are MySQL syntax)
        result = session.execute(text("SHOW DATABASES LIKE 'stock_analysis'"))
        db_exists = result.fetchone()
        if not db_exists:
            logger.info("Creating database stock_analysis")
            session.execute(text('CREATE DATABASE stock_analysis'))
        # Select the database
        session.execute(text('USE stock_analysis'))
        # Check which tables exist
        result = session.execute(text('SHOW TABLES'))
        existing_tables = [row[0] for row in result.fetchall()]
        logger.info(f"Existing tables: {existing_tables}")
        # Create the tables if none exist yet
        if not existing_tables:
            logger.info("Creating tables...")
            db_manager.create_tables()
            logger.info("Tables created")
            # Verify that the tables were created
            result = session.execute(text('SHOW TABLES'))
            new_tables = [row[0] for row in result.fetchall()]
            logger.info(f"Newly created tables: {new_tables}")
        else:
            logger.info("Tables already exist")
        logger.info("Database table check complete")
    except Exception as e:
        logger.error(f"Failed to create tables: {str(e)}")
        raise
    finally:
        # Close the session even when a statement fails
        if session is not None:
            session.close()


if __name__ == "__main__":
    create_database_tables()

419
deploy.py Normal file

@ -0,0 +1,419 @@
"""
Deployment script for the stock analysis system.
Provides deployment configuration and environment setup.
"""
import os
import sys
import subprocess
import argparse
import platform
from pathlib import Path


class DeploymentManager:
    """Deployment manager."""

    def __init__(self):
        self.project_root = Path(__file__).parent
        self.venv_path = self.project_root / 'venv'
        self.requirements_file = self.project_root / 'requirements.txt'
        self.config_dir = self.project_root / 'config'
        self.logs_dir = self.project_root / 'logs'
        self.data_dir = self.project_root / 'data'
        self.is_windows = platform.system() == 'Windows'
        self.is_linux = platform.system() == 'Linux'
        self.is_macos = platform.system() == 'Darwin'

    def check_environment(self):
        """Check the runtime environment."""
        print("Checking the runtime environment...")
        # Check the Python version
        python_version = sys.version_info
        if python_version < (3, 8):
            print(f"Error: Python 3.8 or later is required, current version: {python_version.major}.{python_version.minor}")
            return False
        print(f"✓ Python version: {python_version.major}.{python_version.minor}.{python_version.micro}")
        # Check the operating system
        print(f"✓ Operating system: {platform.system()} {platform.release()}")
        # Check required tools (a missing tool is reported but does not abort deployment)
        tools = ['pip', 'git']
        for tool in tools:
            try:
                subprocess.run([tool, '--version'], capture_output=True, check=True)
                print(f"{tool} is installed")
            except (subprocess.CalledProcessError, FileNotFoundError):
                print(f"{tool} is not installed or not on PATH")
        return True
    def create_virtual_environment(self):
        """Create the virtual environment."""
        print("Creating the virtual environment...")
        if self.venv_path.exists():
            print("Virtual environment already exists, skipping creation")
            return True
        try:
            subprocess.run([
                sys.executable, '-m', 'venv', str(self.venv_path)
            ], check=True, cwd=self.project_root)
            print("✓ Virtual environment created")
            return True
        except subprocess.CalledProcessError as e:
            print(f"✗ Failed to create the virtual environment: {e}")
            return False

    def get_venv_python(self):
        """Return the path to Python inside the virtual environment."""
        if self.is_windows:
            return self.venv_path / 'Scripts' / 'python.exe'
        else:
            return self.venv_path / 'bin' / 'python'

    def get_venv_pip(self):
        """Return the path to pip inside the virtual environment."""
        if self.is_windows:
            return self.venv_path / 'Scripts' / 'pip.exe'
        else:
            return self.venv_path / 'bin' / 'pip'

    def install_dependencies(self):
        """Install the dependencies."""
        print("Installing dependencies...")
        if not self.requirements_file.exists():
            print("✗ requirements.txt does not exist")
            return False
        try:
            python_path = self.get_venv_python()
            pip_path = self.get_venv_pip()
            # Upgrade pip through "python -m pip": on Windows, pip.exe cannot
            # replace itself while it is running
            subprocess.run([str(python_path), '-m', 'pip', 'install', '--upgrade', 'pip'], check=True)
            # Install the dependencies
            subprocess.run([
                str(pip_path), 'install', '-r', str(self.requirements_file)
            ], check=True)
            print("✓ Dependencies installed")
            return True
        except subprocess.CalledProcessError as e:
            print(f"✗ Failed to install dependencies: {e}")
            return False
    def create_directories(self):
        """Create the required directories."""
        print("Creating the required directories...")
        directories = [self.logs_dir, self.data_dir]
        for directory in directories:
            try:
                directory.mkdir(exist_ok=True)
                print(f"✓ Created directory: {directory}")
            except Exception as e:
                print(f"✗ Failed to create directory {directory}: {e}")
                return False
        return True

    def setup_environment_variables(self):
        """Set up the environment variables."""
        print("Setting up environment variables...")
        # Create the .env file if it does not exist yet
        env_file = self.project_root / '.env'
        if not env_file.exists():
            try:
                with open(env_file, 'w', encoding='utf-8') as f:
                    f.write("""# Stock analysis system environment configuration
DATABASE_URL=sqlite:///stock_data.db
LOG_LEVEL=INFO
DEBUG=false
# Data source configuration
AKSHARE_BASE_URL=https://api.akshare.akfamily.xyz
BAOSTOCK_TIMEOUT=30
# Scheduled task configuration
SCHEDULER_ENABLED=true
DAILY_UPDATE_TIME=18:00
# Performance configuration
MAX_CONCURRENT_TASKS=10
BATCH_SIZE=100
""")
                print("✓ Environment configuration file created")
            except Exception as e:
                print(f"✗ Failed to create the environment configuration file: {e}")
                return False
        return True
    def run_tests(self):
        """Run the tests."""
        print("Running tests...")
        try:
            python_path = self.get_venv_python()
            result = subprocess.run([
                str(python_path), '-m', 'pytest', 'tests/',
                '-v', '--tb=short'
            ], cwd=self.project_root, capture_output=True, text=True)
            if result.returncode == 0:
                print("✓ All tests passed")
                return True
            else:
                print("✗ Tests failed")
                print(result.stdout)
                print(result.stderr)
                return False
        except Exception as e:
            print(f"✗ Failed to run the tests: {e}")
            return False

    def create_startup_scripts(self):
        """Create the startup scripts."""
        print("Creating startup scripts...")
        try:
            python_path = self.get_venv_python()
            run_script_path = self.project_root / 'run.py'
            if self.is_windows:
                # Create the Windows batch file
                batch_content = f"""@echo off
cd /d "{self.project_root}"
"{python_path}" "{run_script_path}" %*
pause
"""
                with open(self.project_root / 'start.bat', 'w', encoding='utf-8') as f:
                    f.write(batch_content)
                print("✓ Windows startup script created")
            # Create the Linux/Mac startup script
            shell_content = f"""#!/bin/bash
cd "{self.project_root}"
"{python_path}" "{run_script_path}" "$@"
"""
            with open(self.project_root / 'start.sh', 'w', encoding='utf-8') as f:
                f.write(shell_content)
            # Make the script executable
            if not self.is_windows:
                os.chmod(self.project_root / 'start.sh', 0o755)
            print("✓ Linux/Mac startup script created")
        except Exception as e:
            print(f"✗ Failed to create startup scripts: {e}")
            return False
        return True
    def create_service_scripts(self):
        """Create the service scripts (for production environments)."""
        print("Creating service scripts...")
        try:
            python_path = self.get_venv_python()
            run_script_path = self.project_root / 'run.py'
            if self.is_linux:
                # Create the systemd service file
                service_content = f"""[Unit]
Description=Stock Analysis System
After=network.target
[Service]
Type=simple
User=stock
Group=stock
WorkingDirectory={self.project_root}
ExecStart={python_path} {run_script_path} scheduler
Restart=always
RestartSec=10
[Install]
WantedBy=multi-user.target
"""
                with open(self.project_root / 'stock-system.service', 'w', encoding='utf-8') as f:
                    f.write(service_content)
                print("✓ systemd service file created")
            # Create the supervisor configuration
            supervisor_content = f"""[program:stock-system]
command={python_path} {run_script_path} scheduler
directory={self.project_root}
autostart=true
autorestart=true
user=stock
stdout_logfile={self.logs_dir}/supervisor.log
stderr_logfile={self.logs_dir}/supervisor_error.log
"""
            with open(self.project_root / 'stock-system.conf', 'w', encoding='utf-8') as f:
                f.write(supervisor_content)
            print("✓ supervisor configuration created")
        except Exception as e:
            print(f"✗ Failed to create service scripts: {e}")
            return False
        return True

    def backup_database(self):
        """Back up the database."""
        print("Backing up the database...")
        db_file = self.project_root / 'stock_data.db'
        backup_dir = self.project_root / 'backups'
        if not db_file.exists():
            print("Database file does not exist, skipping backup")
            return True
        try:
            backup_dir.mkdir(exist_ok=True)
            from datetime import datetime
            backup_file = backup_dir / f'stock_data_backup_{datetime.now().strftime("%Y%m%d_%H%M%S")}.db'
            import shutil
            shutil.copy2(db_file, backup_file)
            print(f"✓ Database backed up: {backup_file}")
            return True
        except Exception as e:
            print(f"✗ Database backup failed: {e}")
            return False
    def deploy(self, skip_tests=False, production=False):
        """Run the full deployment process."""
        print("Deploying the stock analysis system...")
        print("=" * 50)
        # Check the environment
        if not self.check_environment():
            return False
        print("-" * 50)
        # Create the virtual environment
        if not self.create_virtual_environment():
            return False
        print("-" * 50)
        # Install the dependencies
        if not self.install_dependencies():
            return False
        print("-" * 50)
        # Create the directories
        if not self.create_directories():
            return False
        print("-" * 50)
        # Set up the environment variables
        if not self.setup_environment_variables():
            return False
        print("-" * 50)
        # Run the tests
        if not skip_tests:
            if not self.run_tests():
                print("Tests failed. Continue with deployment? (y/N): ")
                response = input().strip().lower()
                if response != 'y':
                    return False
            print("-" * 50)
        # Create the startup scripts
        if not self.create_startup_scripts():
            return False
        print("-" * 50)
        # Extra steps for production environments
        if production:
            if not self.create_service_scripts():
                return False
            if not self.backup_database():
                return False
        print("=" * 50)
        print("✓ Deployment complete!")
        # Show usage instructions
        print("\nUsage:")
        print("1. Start the system: " + ("start.bat" if self.is_windows else "./start.sh"))
        print("2. Show help: " + ("start.bat --help" if self.is_windows else "./start.sh --help"))
        print("3. Initialize data: " + ("start.bat init" if self.is_windows else "./start.sh init"))
        return True
def main():
    """Entry point."""
    parser = argparse.ArgumentParser(description='Deployment tool for the stock analysis system')
    parser.add_argument(
        '--skip-tests',
        action='store_true',
        help='skip the tests'
    )
    parser.add_argument(
        '--production',
        action='store_true',
        help='deploy for production'
    )
    parser.add_argument(
        '--backup-only',
        action='store_true',
        help='only back up the database'
    )
    args = parser.parse_args()
    deployer = DeploymentManager()
    if args.backup_only:
        success = deployer.backup_database()
    else:
        success = deployer.deploy(skip_tests=args.skip_tests, production=args.production)
    # Report failure through the exit code so callers and CI can detect it
    sys.exit(0 if success else 1)


if __name__ == '__main__':
    main()

215
fix_database_charset.py Normal file

@ -0,0 +1,215 @@
"""
Fixes database character set problems.
Checks the database and its tables and sets their character set to utf8mb4 so that Chinese characters are supported.
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.storage.database import db_manager
from sqlalchemy import text
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def get_foreign_key_constraints(session, table_name):
    """Return the foreign key constraints of a table."""
    constraints = []
    # Query the foreign key constraints; the table name is passed as a bound parameter
    query = """
        SELECT
            CONSTRAINT_NAME,
            COLUMN_NAME,
            REFERENCED_TABLE_NAME,
            REFERENCED_COLUMN_NAME
        FROM information_schema.KEY_COLUMN_USAGE
        WHERE TABLE_SCHEMA = DATABASE()
        AND TABLE_NAME = :table_name
        AND REFERENCED_TABLE_NAME IS NOT NULL
    """
    result = session.execute(text(query), {"table_name": table_name})
    for row in result.fetchall():
        constraints.append({
            'name': row[0],
            'column': row[1],
            'referenced_table': row[2],
            'referenced_column': row[3]
        })
    return constraints


def drop_foreign_key_constraints(session, table_name):
    """Drop the foreign key constraints of a table and return them."""
    constraints = get_foreign_key_constraints(session, table_name)
    for constraint in constraints:
        logger.info(f"Dropping foreign key constraint: {constraint['name']}")
        drop_sql = f"ALTER TABLE {table_name} DROP FOREIGN KEY {constraint['name']}"
        session.execute(text(drop_sql))
    return constraints


def recreate_foreign_key_constraints(session, table_name, constraints):
    """Recreate the foreign key constraints."""
    for constraint in constraints:
        logger.info(f"Recreating foreign key constraint: {constraint['name']}")
        create_sql = f"""
            ALTER TABLE {table_name}
            ADD CONSTRAINT {constraint['name']}
            FOREIGN KEY ({constraint['column']})
            REFERENCES {constraint['referenced_table']}({constraint['referenced_column']})
        """
        session.execute(text(create_sql))
def check_and_fix_charset():
    """Check and fix the database character set."""
    try:
        # Get a database session
        session = db_manager.get_session()
        logger.info("Checking the database character set...")
        # Check the database character set
        result = session.execute(text("SHOW VARIABLES LIKE 'character_set_database'"))
        db_charset = result.fetchone()
        logger.info(f"Database character set: {db_charset}")
        result = session.execute(text("SHOW VARIABLES LIKE 'collation_database'"))
        db_collation = result.fetchone()
        logger.info(f"Database collation: {db_collation}")
        # Check the collation of each table
        result = session.execute(text("SHOW TABLE STATUS"))
        tables = result.fetchall()
        logger.info("=== Table collation status ===")
        for table in tables:
            table_name = table[0]
            collation = table[14]  # Collation column of SHOW TABLE STATUS
            logger.info(f"{table_name}: {collation}")
        # Check whether a fix is needed
        needs_fix = False
        for table in tables:
            if table[14] and 'utf8mb4' not in table[14]:
                needs_fix = True
                logger.warning(f"{table[0]} needs its character set fixed")
        if needs_fix:
            logger.info("Fixing character sets...")
            # Change the database character set
            session.execute(text("ALTER DATABASE CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"))
            # Remember the foreign key constraints so they can be recreated
            all_constraints = {}
            # Drop all foreign key constraints first
            for table in tables:
                table_name = table[0]
                constraints = drop_foreign_key_constraints(session, table_name)
                if constraints:
                    all_constraints[table_name] = constraints
            session.commit()
            # Convert the character set of every table
            for table in tables:
                table_name = table[0]
                logger.info(f"Fixing the character set of table {table_name}")
                session.execute(text(f"ALTER TABLE {table_name} CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"))
            session.commit()
            # Recreate the foreign key constraints
            for table_name, constraints in all_constraints.items():
                if constraints:
                    recreate_foreign_key_constraints(session, table_name, constraints)
            session.commit()
            logger.info("Character set fix complete")
            # Verify the result
            result = session.execute(text("SHOW TABLE STATUS"))
            tables_after = result.fetchall()
            logger.info("=== Table collation status after the fix ===")
            for table in tables_after:
                table_name = table[0]
                collation = table[14]
                logger.info(f"{table_name}: {collation}")
        else:
            logger.info("Character sets are already correct, nothing to fix")
        session.close()
    except Exception as e:
        logger.error(f"Character set check failed: {str(e)}")
        raise
def test_chinese_insert():
    """Test inserting Chinese characters."""
    try:
        session = db_manager.get_session()
        logger.info("Testing insertion of Chinese characters...")
        # Try to insert a row that contains Chinese characters
        test_data = {
            'code': '000001',
            'name': '平安银行',
            'market': 'sz',
            'company_name': '平安银行股份有限公司',
            'industry': '银行',
            'area': '广东',
            'ipo_date': None,
            'listing_status': 1
        }
        # Check whether the table exists
        result = session.execute(text("SHOW TABLES LIKE 'stock_basic'"))
        if result.fetchone():
            # Remove any earlier test row
            session.execute(text("DELETE FROM stock_basic WHERE code = '000001'"))
            # Insert the test row
            insert_sql = """
                INSERT INTO stock_basic (code, name, market, company_name, industry, area, ipo_date, listing_status)
                VALUES (:code, :name, :market, :company_name, :industry, :area, :ipo_date, :listing_status)
            """
            session.execute(text(insert_sql), test_data)
            session.commit()
            # Verify the insert
            result = session.execute(text("SELECT * FROM stock_basic WHERE code = '000001'"))
            inserted_data = result.fetchone()
            if inserted_data:
                logger.info(f"Chinese characters inserted successfully: {inserted_data}")
                # Clean up the test row
                session.execute(text("DELETE FROM stock_basic WHERE code = '000001'"))
                session.commit()
            else:
                logger.error("Inserting Chinese characters failed")
        else:
            logger.warning("Table stock_basic does not exist, skipping the test")
        session.close()
    except Exception as e:
        logger.error(f"Chinese character insert test failed: {str(e)}")
        raise


if __name__ == "__main__":
    logger.info("=== Checking and fixing the database character set ===")
    check_and_fix_charset()
    logger.info("=== Testing insertion of Chinese characters ===")
    test_chinese_insert()
    logger.info("=== Character set check and fix complete ===")

339
fix_stock_code_format.py Normal file

@ -0,0 +1,339 @@
"""
Script for fixing the stock code format.
Converts 6-digit stock codes to the 9-character format that Baostock requires (for example, sh.600000).
"""
import sys
import os
import logging
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.storage.database import db_manager
from src.storage.stock_repository import StockRepository

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def convert_to_baostock_format(stock_code: str, market: str) -> str:
    """
    Convert a 6-digit stock code to the Baostock format.

    Args:
        stock_code: 6-digit stock code
        market: market type (sh/sz)

    Returns:
        9-character stock code in Baostock format
    """
    if len(stock_code) == 6 and stock_code.isdigit():
        if market in ("sh", "sz"):
            return f"{market}.{stock_code}"
        # Unknown market: return the code unchanged
        return stock_code
    # Anything else, including codes already in Baostock format, is returned unchanged
    return stock_code


def get_baostock_format_code(stock_code: str) -> str:
    """
    Infer the market from the stock code and convert it to the Baostock format.

    Args:
        stock_code: stock code

    Returns:
        stock code in Baostock format
    """
    # Already in Baostock format: return it unchanged
    if stock_code.startswith("sh.") or stock_code.startswith("sz."):
        return stock_code
    # Infer the market from the leading digit
    if stock_code.startswith("6"):
        return f"sh.{stock_code}"
    elif stock_code.startswith("0") or stock_code.startswith("3"):
        return f"sz.{stock_code}"
    else:
        return stock_code
def fix_stock_code_format():
    """
    Fix the stock code format.
    """
    try:
        logger.info("Fixing the stock code format...")
        # Create the repository
        repository = StockRepository(db_manager.get_session())
        logger.info("Repository created")
        # Fetch the basic information of all stocks
        stocks = repository.get_stock_basic_info()
        logger.info(f"Found {len(stocks)} stocks")
        if not stocks:
            logger.error("No basic stock information available, nothing to fix")
            return {"success": False, "error": "No basic stock information"}
        # Count the results
        fixed_count = 0
        error_count = 0
        for stock in stocks:
            try:
                # Original stock code
                original_code = stock.code
                # Convert to the Baostock format
                baostock_code = get_baostock_format_code(original_code)
                if original_code != baostock_code:
                    logger.info(f"Fixing stock code: {original_code} -> {baostock_code}")
                    # Demonstration only: nothing is written to the database here.
                    # The database schema does not allow the stock code to be changed
                    # directly; a mapping table, or a change to how the data collectors
                    # handle codes, would be needed.
                    fixed_count += 1
            except Exception as e:
                logger.error(f"Failed to fix the code format of stock {stock.code}: {str(e)}")
                error_count += 1
        logger.info(f"Stock code format fix complete: {fixed_count} fixed, {error_count} errors")
        # Sample conversions
        test_codes = ["000001", "600000", "300001", "000007"]
        logger.info("Testing stock code format conversion:")
        for code in test_codes:
            baostock_code = get_baostock_format_code(code)
            logger.info(f" {code} -> {baostock_code}")
        return {
            "success": True,
            "total_stocks": len(stocks),
            "fixed_count": fixed_count,
            "error_count": error_count
        }
    except Exception as e:
        logger.error(f"Error while fixing the stock code format: {str(e)}")
        return {"success": False, "error": str(e)}


def create_baostock_compatible_update_script():
    """
    Create a data update script that works with Baostock-format stock codes.
    """
    try:
        logger.info("Creating the Baostock-compatible data update script...")
        script_content = '''"""
Baostock-compatible data update script.
Handles stock code format conversion automatically.
"""
import sys
import os
import asyncio
import logging

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.data.data_initializer import DataInitializer
from src.config.settings import Settings
from src.storage.database import db_manager
from src.storage.stock_repository import StockRepository

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def get_baostock_format_code(stock_code: str) -> str:
    """
    Convert a stock code to the Baostock format.
    """
    if stock_code.startswith("sh.") or stock_code.startswith("sz."):
        return stock_code
    if stock_code.startswith("6"):
        return f"sh.{stock_code}"
    elif stock_code.startswith("0") or stock_code.startswith("3"):
        return f"sz.{stock_code}"
    else:
        return stock_code


async def update_kline_data_with_baostock_format():
    """
    Update K-line data using Baostock-format codes.
    """
    try:
        logger.info("Starting K-line data update using the Baostock format...")
        # Load configuration
        settings = Settings()
        logger.info("Configuration loaded")
        # Create the data initializer
        initializer = DataInitializer(settings)
        logger.info("Data initializer created")
        # Create the repository
        repository = StockRepository(db_manager.get_session())
        logger.info("Repository created")
        # Fetch basic information for all stocks
        stocks = repository.get_stock_basic_info()
        logger.info(f"Found {len(stocks)} stocks")
        if not stocks:
            logger.error("No basic stock information available; cannot update")
            return {"success": False, "error": "No basic stock information available"}
        # Process in batches of 10 stocks
        batch_size = 10
        total_batches = (len(stocks) + batch_size - 1) // batch_size
        total_kline_data = []
        success_count = 0
        error_count = 0
        for batch_num in range(total_batches):
            start_idx = batch_num * batch_size
            end_idx = min(start_idx + batch_size, len(stocks))
            batch_stocks = stocks[start_idx:end_idx]
            logger.info(f"Processing batch {batch_num + 1}: {len(batch_stocks)} stocks")
            for stock in batch_stocks:
                try:
                    # Convert to the Baostock format
                    baostock_code = get_baostock_format_code(stock.code)
                    logger.info(f"Fetching K-line data for stock {stock.code} ({baostock_code})...")
                    # Fetch K-line data through the data manager (fixed sample date range)
                    kline_data = await initializer.data_manager.get_daily_kline_data(
                        baostock_code,
                        "2024-01-01",
                        "2024-01-10"
                    )
                    if kline_data:
                        total_kline_data.extend(kline_data)
                        success_count += 1
                        logger.info(f"Stock {stock.code}: fetched {len(kline_data)} K-line records")
                    else:
                        logger.warning(f"Stock {stock.code}: no K-line data returned")
                        error_count += 1
                except Exception as e:
                    logger.error(f"Failed to fetch K-line data for stock {stock.code}: {str(e)}")
                    error_count += 1
                    continue
        logger.info(f"K-line data update finished: {success_count} succeeded, {error_count} failed, {len(total_kline_data)} records fetched")
        return {
            "success": True,
            "total_stocks": len(stocks),
            "success_count": success_count,
            "error_count": error_count,
            "kline_data_count": len(total_kline_data)
        }
    except Exception as e:
        logger.error(f"Unexpected error while updating K-line data: {str(e)}")
        return {"success": False, "error": str(e)}


async def main():
    """
    Main function.
    """
    result = await update_kline_data_with_baostock_format()
    if result["success"]:
        logger.info("K-line data update succeeded")
        print(f"Update result: {result}")
    else:
        logger.error("K-line data update failed")
        print(f"Update failed: {result.get('error')}")
    return result


if __name__ == "__main__":
    # Run the update
    result = asyncio.run(main())
    # Print the final outcome
    if result.get("success", False):
        print("Update succeeded!")
        sys.exit(0)
    else:
        print("Update failed!")
        sys.exit(1)
'''
        # Write the script file
        script_path = os.path.join(os.path.dirname(__file__), "update_kline_baostock.py")
        with open(script_path, "w", encoding="utf-8") as f:
            f.write(script_content)
        logger.info(f"Baostock-compatible data update script created: {script_path}")
        return {"success": True, "script_path": script_path}
    except Exception as e:
        logger.error(f"Failed to create the compatible script: {str(e)}")
        return {"success": False, "error": str(e)}


def main():
    """
    Main function.
    """
    # Fix the stock code format
    fix_result = fix_stock_code_format()
    # Create the compatible script
    script_result = create_baostock_compatible_update_script()
    if fix_result["success"] and script_result["success"]:
        logger.info("Stock code format fix completed!")
        print(f"Fix result: {fix_result}")
        print(f"Script creation result: {script_result}")
        # Exercise the conversion function
        test_codes = ["000001", "600000", "300001", "000007"]
        print("\nTesting stock code format conversion:")
        for code in test_codes:
            baostock_code = get_baostock_format_code(code)
            print(f"  {code} -> {baostock_code}")
        print("\nRun the following command to test the new update script:")
        print("python update_kline_baostock.py")
        return True
    else:
        logger.error("Stock code format fix failed!")
        print(f"Fix error: {fix_result.get('error')}")
        print(f"Script creation error: {script_result.get('error')}")
        return False


if __name__ == "__main__":
    success = main()
    if success:
        print("\nStock code format fix completed!")
        sys.exit(0)
    else:
        print("\nStock code format fix failed!")
        sys.exit(1)
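
The prefix rule in `get_baostock_format_code` and the batch slicing in the generated update script can both be exercised without a database or a Baostock session. The sketch below is a minimal standalone reconstruction of those two pieces, assuming the same rules as the script above: codes starting with `6` map to `sh.`, codes starting with `0` or `3` map to `sz.`, and anything else passes through unchanged. The names `to_baostock_code` and `iter_batches` are illustrative and do not appear in the repository.

```python
from typing import Iterator, List, Sequence


def to_baostock_code(stock_code: str) -> str:
    """Prefix a bare code with the exchange tag, following the rule in the script above."""
    if stock_code.startswith(("sh.", "sz.")):
        return stock_code  # already in Baostock format
    if stock_code.startswith("6"):
        return f"sh.{stock_code}"  # Shanghai
    if stock_code.startswith(("0", "3")):
        return f"sz.{stock_code}"  # Shenzhen
    return stock_code  # unknown prefix: left untouched, as in the original


def iter_batches(items: Sequence[str], batch_size: int = 10) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


if __name__ == "__main__":
    for code in ["000001", "600000", "300001", "000007"]:
        print(f"{code} -> {to_baostock_code(code)}")
    # 25 items with the default batch size of 10 gives batches of 10, 10 and 5
    print([len(batch) for batch in iter_batches([str(i) for i in range(25)])])
```

Iterating with `range(0, len(items), batch_size)` yields the same slices as the `total_batches` arithmetic in the script while avoiding the separate ceiling division.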

458
frontend/css/style.css Normal file

@ -0,0 +1,458 @@
/* Base style reset */
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
line-height: 1.6;
color: #333;
background-color: #f8f9fa;
}
.container {
max-width: 1200px;
margin: 0 auto;
padding: 0 20px;
}
/* Navigation bar styles */
.navbar {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 1rem 0;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
position: sticky;
top: 0;
z-index: 1000;
}
.nav-container {
display: flex;
justify-content: space-between;
align-items: center;
max-width: 1200px;
margin: 0 auto;
padding: 0 20px;
}
.nav-logo {
font-size: 1.5rem;
font-weight: 600;
display: flex;
align-items: center;
gap: 10px;
}
.nav-menu {
display: flex;
list-style: none;
gap: 2rem;
}
.nav-link {
color: white;
text-decoration: none;
padding: 0.5rem 1rem;
border-radius: 5px;
transition: background-color 0.3s;
}
.nav-link:hover,
.nav-link.active {
background-color: rgba(255,255,255,0.2);
}
/* Main content area */
.main-content {
padding: 2rem 0;
}
.section-title {
font-size: 2rem;
margin-bottom: 2rem;
color: #2c3e50;
text-align: center;
}
/* Overview card styles */
.overview-section {
margin-bottom: 3rem;
}
.stats-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
gap: 1.5rem;
margin-top: 2rem;
}
.stat-card {
background: white;
padding: 2rem;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
display: flex;
align-items: center;
gap: 1rem;
transition: transform 0.3s;
}
.stat-card:hover {
transform: translateY(-5px);
}
.stat-icon {
font-size: 2.5rem;
color: #667eea;
}
.stat-info h3 {
font-size: 2rem;
font-weight: 600;
color: #2c3e50;
}
.stat-info p {
color: #7f8c8d;
margin-top: 0.5rem;
}
/* Content section switching */
.content-section {
display: none;
}
.content-section.active {
display: block;
}
/* Search bar styles */
.search-bar {
display: flex;
gap: 10px;
margin-bottom: 2rem;
max-width: 400px;
}
.search-bar input {
flex: 1;
padding: 0.75rem;
border: 2px solid #e0e0e0;
border-radius: 5px;
font-size: 1rem;
}
.search-bar button {
padding: 0.75rem 1rem;
background: #667eea;
color: white;
border: none;
border-radius: 5px;
cursor: pointer;
transition: background-color 0.3s;
}
.search-bar button:hover {
background: #5a6fd8;
}
/* Table styles */
.table-container {
background: white;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
overflow: hidden;
}
.data-table {
width: 100%;
border-collapse: collapse;
}
.data-table th,
.data-table td {
padding: 1rem;
text-align: left;
border-bottom: 1px solid #e0e0e0;
}
.data-table th {
background: #f8f9fa;
font-weight: 600;
color: #2c3e50;
}
.data-table tr:hover {
background: #f8f9fa;
}
/* Chart container styles */
.chart-container {
background: white;
padding: 2rem;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
height: 500px;
}
.chart-controls {
display: flex;
gap: 1rem;
margin-bottom: 2rem;
}
.chart-controls select {
padding: 0.5rem;
border: 2px solid #e0e0e0;
border-radius: 5px;
font-size: 1rem;
}
/* Financial data table styles */
.financial-table-container {
background: white;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
overflow: hidden;
}
.financial-table {
width: 100%;
border-collapse: collapse;
}
.financial-table th,
.financial-table td {
padding: 1rem;
text-align: left;
border-bottom: 1px solid #e0e0e0;
}
.financial-table th {
background: #f8f9fa;
font-weight: 600;
}
/* Log container styles */
.log-container {
background: white;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
max-height: 600px;
overflow-y: auto;
}
.log-list {
padding: 1rem;
}
.log-item {
padding: 1rem;
border-left: 4px solid #667eea;
margin-bottom: 1rem;
background: #f8f9fa;
border-radius: 5px;
}
.log-item.error {
border-left-color: #e74c3c;
background: #fdf2f2;
}
.log-item.warning {
border-left-color: #f39c12;
background: #fef9e7;
}
.log-item.info {
border-left-color: #3498db;
background: #f0f8ff;
}
.log-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 0.5rem;
}
.log-level {
padding: 0.25rem 0.5rem;
border-radius: 3px;
font-size: 0.8rem;
font-weight: 600;
}
.log-level.info {
background: #3498db;
color: white;
}
.log-level.warning {
background: #f39c12;
color: white;
}
.log-level.error {
background: #e74c3c;
color: white;
}
.log-time {
color: #7f8c8d;
font-size: 0.9rem;
}
.log-message {
color: #2c3e50;
margin-bottom: 0.5rem;
}
.log-details {
color: #7f8c8d;
font-size: 0.9rem;
}
/* Pagination styles */
.pagination {
display: flex;
justify-content: center;
align-items: center;
gap: 0.5rem;
margin-top: 2rem;
}
.pagination button {
padding: 0.5rem 1rem;
border: 1px solid #e0e0e0;
background: white;
border-radius: 5px;
cursor: pointer;
transition: background-color 0.3s;
}
.pagination button:hover {
background: #f8f9fa;
}
.pagination button.active {
background: #667eea;
color: white;
border-color: #667eea;
}
/* Loading overlay styles */
.loading-overlay {
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
background: rgba(0,0,0,0.5);
display: none;
justify-content: center;
align-items: center;
z-index: 2000;
}
.loading-overlay.show {
display: flex;
}
.loading-spinner {
background: white;
padding: 2rem;
border-radius: 10px;
text-align: center;
}
.loading-spinner i {
font-size: 2rem;
color: #667eea;
margin-bottom: 1rem;
}
/* Responsive design */
@media (max-width: 768px) {
.nav-container {
flex-direction: column;
gap: 1rem;
}
.nav-menu {
gap: 1rem;
}
.stats-grid {
grid-template-columns: 1fr;
}
.chart-container {
height: 300px;
}
.search-bar {
flex-direction: column;
}
}
/* Button styles */
.btn {
padding: 0.75rem 1.5rem;
border: none;
border-radius: 5px;
cursor: pointer;
font-size: 1rem;
transition: background-color 0.3s;
}
.btn-primary {
background: #667eea;
color: white;
}
.btn-primary:hover {
background: #5a6fd8;
}
.btn-secondary {
background: #95a5a6;
color: white;
}
.btn-secondary:hover {
background: #7f8c8d;
}
/* Control panel styles */
.control-panel {
display: flex;
gap: 1rem;
margin-bottom: 2rem;
flex-wrap: wrap;
}
.control-panel select,
.control-panel input {
padding: 0.5rem;
border: 2px solid #e0e0e0;
border-radius: 5px;
font-size: 1rem;
}
.control-panel button {
padding: 0.5rem 1rem;
background: #667eea;
color: white;
border: none;
border-radius: 5px;
cursor: pointer;
transition: background-color 0.3s;
}
.control-panel button:hover {
background: #5a6fd8;
}

198
frontend/index.html Normal file

@ -0,0 +1,198 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>股票数据展示系统</title>
<link rel="stylesheet" href="css/style.css">
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" rel="stylesheet">
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
</head>
<body>
<!-- Navigation bar -->
<nav class="navbar">
<div class="nav-container">
<h1 class="nav-logo">
<i class="fas fa-chart-line"></i>
股票数据系统
</h1>
<ul class="nav-menu">
<li><a href="#stock-info" class="nav-link active">股票信息</a></li>
<li><a href="#kline-chart" class="nav-link">K线图表</a></li>
<li><a href="#financial-data" class="nav-link">财务数据</a></li>
<li><a href="#system-logs" class="nav-link">系统日志</a></li>
</ul>
</div>
</nav>
<!-- Main content area -->
<main class="main-content">
<!-- System overview cards -->
<section class="overview-section">
<div class="container">
<h2 class="section-title">系统概览</h2>
<div class="stats-grid">
<div class="stat-card">
<div class="stat-icon">
<i class="fas fa-building"></i>
</div>
<div class="stat-info">
<h3 id="stock-count">12,595</h3>
<p>股票总数</p>
</div>
</div>
<div class="stat-card">
<div class="stat-icon">
<i class="fas fa-chart-bar"></i>
</div>
<div class="stat-info">
<h3 id="kline-count">440</h3>
<p>K线数据</p>
</div>
</div>
<div class="stat-card">
<div class="stat-icon">
<i class="fas fa-file-invoice-dollar"></i>
</div>
<div class="stat-info">
<h3 id="financial-count">50</h3>
<p>财务报告</p>
</div>
</div>
<div class="stat-card">
<div class="stat-icon">
<i class="fas fa-history"></i>
</div>
<div class="stat-info">
<h3 id="log-count">4</h3>
<p>系统日志</p>
</div>
</div>
</div>
</div>
</section>
<!-- Stock information section -->
<section id="stock-info" class="content-section active">
<div class="container">
<h2 class="section-title">股票基础信息</h2>
<div class="search-bar">
<input type="text" id="stock-search" placeholder="搜索股票代码或名称...">
<button id="search-btn">
<i class="fas fa-search"></i>
</button>
</div>
<div class="table-container">
<table class="data-table" id="stock-table">
<thead>
<tr>
<th>股票代码</th>
<th>股票名称</th>
<th>交易所</th>
<th>上市日期</th>
<th>行业分类</th>
<th>操作</th>
</tr>
</thead>
<tbody id="stock-table-body">
<!-- Stock data is loaded dynamically via JavaScript -->
</tbody>
</table>
</div>
<div class="pagination" id="stock-pagination">
<!-- Pagination controls -->
</div>
</div>
</section>
<!-- K-line chart section -->
<section id="kline-chart" class="content-section">
<div class="container">
<h2 class="section-title">K线数据图表</h2>
<div class="chart-controls">
<select id="stock-selector">
<option value="">选择股票...</option>
</select>
<select id="period-selector">
<option value="daily">日线</option>
<option value="weekly">周线</option>
<option value="monthly">月线</option>
</select>
</div>
<div class="chart-container">
<canvas id="kline-chart-canvas"></canvas>
</div>
</div>
</section>
<!-- Financial data section -->
<section id="financial-data" class="content-section">
<div class="container">
<h2 class="section-title">财务报告数据</h2>
<div class="financial-controls">
<select id="financial-stock-selector">
<option value="">选择股票...</option>
</select>
<select id="report-period">
<option value="Q4">第四季度</option>
<option value="Q3">第三季度</option>
<option value="Q2">第二季度</option>
<option value="Q1">第一季度</option>
</select>
<select id="report-year">
<option value="2023">2023年</option>
<option value="2022">2022年</option>
</select>
</div>
<div class="financial-table-container">
<table class="financial-table" id="financial-table">
<thead>
<tr>
<th>指标名称</th>
<th>数值</th>
<th>单位</th>
<th>同比变化</th>
</tr>
</thead>
<tbody id="financial-table-body">
<!-- Financial data is loaded dynamically via JavaScript -->
</tbody>
</table>
</div>
</div>
</section>
<!-- System log section -->
<section id="system-logs" class="content-section">
<div class="container">
<h2 class="section-title">系统日志</h2>
<div class="log-controls">
<select id="log-level-filter">
<option value="">所有级别</option>
<option value="INFO">信息</option>
<option value="WARNING">警告</option>
<option value="ERROR">错误</option>
</select>
<input type="date" id="log-date-filter">
<button id="refresh-logs">刷新日志</button>
</div>
<div class="log-container">
<div class="log-list" id="log-list">
<!-- Log entries are loaded dynamically via JavaScript -->
</div>
</div>
</div>
</section>
</main>
<!-- Loading overlay -->
<div id="loading-overlay" class="loading-overlay">
<div class="loading-spinner">
<i class="fas fa-spinner fa-spin"></i>
<p>加载中...</p>
</div>
</div>
<script src="js/app.js"></script>
</body>
</html>

602
frontend/js/app.js Normal file

@ -0,0 +1,602 @@
// 股票数据展示系统前端应用
class StockDataApp {
constructor() {
this.currentPage = 1;
this.pageSize = 20;
this.currentStock = null;
this.klineChart = null;
this.init();
}
// 初始化应用
async init() {
this.setupEventListeners();
await this.loadSystemOverview();
await this.loadStockData();
this.setupNavigation();
}
// 设置事件监听器
setupEventListeners() {
// 搜索功能
const searchInput = document.getElementById('stock-search');
const searchBtn = document.getElementById('search-btn');
searchBtn.addEventListener('click', () => this.searchStocks());
searchInput.addEventListener('keypress', (e) => {
if (e.key === 'Enter') this.searchStocks();
});
// 股票选择器
const stockSelector = document.getElementById('stock-selector');
stockSelector.addEventListener('change', (e) => {
this.currentStock = e.target.value;
if (this.currentStock) this.loadKlineChart();
});
// 周期选择器
const periodSelector = document.getElementById('period-selector');
periodSelector.addEventListener('change', () => {
if (this.currentStock) this.loadKlineChart();
});
// 财务数据选择器
const financialStockSelector = document.getElementById('financial-stock-selector');
financialStockSelector.addEventListener('change', (e) => {
if (e.target.value) this.loadFinancialData(e.target.value);
});
// 日志刷新
const refreshLogsBtn = document.getElementById('refresh-logs');
refreshLogsBtn.addEventListener('click', () => this.loadSystemLogs());
// 日志过滤器
const logLevelFilter = document.getElementById('log-level-filter');
logLevelFilter.addEventListener('change', () => this.loadSystemLogs());
const logDateFilter = document.getElementById('log-date-filter');
logDateFilter.addEventListener('change', () => this.loadSystemLogs());
}
// 设置导航
setupNavigation() {
const navLinks = document.querySelectorAll('.nav-link');
const sections = document.querySelectorAll('.content-section');
navLinks.forEach(link => {
link.addEventListener('click', (e) => {
e.preventDefault();
// 移除所有激活状态
navLinks.forEach(l => l.classList.remove('active'));
sections.forEach(s => s.classList.remove('active'));
// 添加当前激活状态
link.classList.add('active');
const targetSection = document.querySelector(link.getAttribute('href'));
if (targetSection) {
targetSection.classList.add('active');
// 加载对应数据
switch(link.getAttribute('href')) {
case '#kline-chart':
this.loadKlineChart();
break;
case '#financial-data':
this.loadFinancialData();
break;
case '#system-logs':
this.loadSystemLogs();
break;
}
}
});
});
}
// 显示加载遮罩
showLoading() {
document.getElementById('loading-overlay').classList.add('show');
}
// 隐藏加载遮罩
hideLoading() {
document.getElementById('loading-overlay').classList.remove('show');
}
// 加载系统概览数据
async loadSystemOverview() {
try {
const response = await this.apiCall('/api/system/overview');
if (response.success) {
document.getElementById('stock-count').textContent = this.formatNumber(response.stock_count);
document.getElementById('kline-count').textContent = this.formatNumber(response.kline_count);
document.getElementById('financial-count').textContent = this.formatNumber(response.financial_count);
document.getElementById('log-count').textContent = this.formatNumber(response.log_count);
}
} catch (error) {
console.error('加载系统概览失败:', error);
}
}
// 加载股票数据
async loadStockData(page = 1) {
try {
this.showLoading();
const response = await this.apiCall(`/api/stocks?page=${page}&limit=${this.pageSize}`);
if (response.success) {
this.renderStockTable(response.data);
this.setupPagination(response.total, page);
this.populateStockSelectors(response.data);
}
} catch (error) {
console.error('加载股票数据失败:', error);
this.showError('加载股票数据失败');
} finally {
this.hideLoading();
}
}
// 渲染股票表格
renderStockTable(stocks) {
const tbody = document.getElementById('stock-table-body');
tbody.innerHTML = '';
stocks.forEach(stock => {
const row = document.createElement('tr');
row.innerHTML = `
<td>${stock.code}</td>
<td>${stock.name}</td>
<td>${stock.exchange}</td>
<td>${this.formatDate(stock.listing_date)}</td>
<td>${stock.industry || '-'}</td>
<td>
<button class="btn btn-primary" onclick="app.viewStockDetails('${stock.code}')">
查看详情
</button>
</td>
`;
tbody.appendChild(row);
});
}
// 设置分页
setupPagination(total, currentPage) {
const pagination = document.getElementById('stock-pagination');
const totalPages = Math.ceil(total / this.pageSize);
if (totalPages <= 1) {
pagination.innerHTML = '';
return;
}
let html = '';
// 上一页按钮
if (currentPage > 1) {
html += `<button onclick="app.loadStockData(${currentPage - 1})">上一页</button>`;
}
// 页码按钮
for (let i = 1; i <= totalPages; i++) {
if (i === currentPage) {
html += `<button class="active">${i}</button>`;
} else {
html += `<button onclick="app.loadStockData(${i})">${i}</button>`;
}
}
// 下一页按钮
if (currentPage < totalPages) {
html += `<button onclick="app.loadStockData(${currentPage + 1})">下一页</button>`;
}
pagination.innerHTML = html;
}
// 填充股票选择器
populateStockSelectors(stocks) {
const selectors = [
document.getElementById('stock-selector'),
document.getElementById('financial-stock-selector')
];
selectors.forEach(selector => {
// 清空现有选项(保留第一个选项)
while (selector.children.length > 1) {
selector.removeChild(selector.lastChild);
}
// 添加股票选项
stocks.forEach(stock => {
const option = document.createElement('option');
option.value = stock.code;
option.textContent = `${stock.code} - ${stock.name}`;
selector.appendChild(option);
});
});
}
// 搜索股票
async searchStocks() {
const query = document.getElementById('stock-search').value.trim();
if (!query) {
await this.loadStockData();
return;
}
try {
this.showLoading();
const response = await this.apiCall(`/api/stocks/search?q=${encodeURIComponent(query)}`);
if (response.success) {
this.renderStockTable(response.data);
document.getElementById('stock-pagination').innerHTML = '';
}
} catch (error) {
console.error('搜索股票失败:', error);
this.showError('搜索股票失败');
} finally {
this.hideLoading();
}
}
// 加载K线图表
async loadKlineChart() {
if (!this.currentStock) return;
try {
this.showLoading();
const period = document.getElementById('period-selector').value;
const response = await this.apiCall(`/api/kline/${this.currentStock}?period=${period}`);
if (response.success) {
this.renderKlineChart(response.data);
}
} catch (error) {
console.error('加载K线数据失败:', error);
this.showError('加载K线数据失败');
} finally {
this.hideLoading();
}
}
// 渲染K线图表
renderKlineChart(klineData) {
const ctx = document.getElementById('kline-chart-canvas').getContext('2d');
// 销毁现有图表
if (this.klineChart) {
this.klineChart.destroy();
}
const dates = klineData.map(item => item.date);
const prices = klineData.map(item => item.close);
this.klineChart = new Chart(ctx, {
type: 'line',
data: {
labels: dates,
datasets: [{
label: '收盘价',
data: prices,
borderColor: '#667eea',
backgroundColor: 'rgba(102, 126, 234, 0.1)',
fill: true,
tension: 0.1
}]
},
options: {
responsive: true,
maintainAspectRatio: false,
plugins: {
title: {
display: true,
text: '股票价格走势图'
}
},
scales: {
y: {
beginAtZero: false
}
}
}
});
}
// 加载财务数据
async loadFinancialData(stockCode = null) {
try {
this.showLoading();
if (!stockCode) {
stockCode = document.getElementById('financial-stock-selector').value;
}
if (!stockCode) return;
const period = document.getElementById('report-period').value;
const year = document.getElementById('report-year').value;
const response = await this.apiCall(`/api/financial/${stockCode}?year=${year}&period=${period}`);
if (response.success) {
this.renderFinancialTable(response.data);
}
} catch (error) {
console.error('加载财务数据失败:', error);
this.showError('加载财务数据失败');
} finally {
this.hideLoading();
}
}
// 渲染财务数据表格
renderFinancialTable(financialData) {
const tbody = document.getElementById('financial-table-body');
tbody.innerHTML = '';
if (!financialData || Object.keys(financialData).length === 0) {
tbody.innerHTML = '<tr><td colspan="4" style="text-align: center;">暂无财务数据</td></tr>';
return;
}
const financialItems = [
{ key: 'revenue', label: '营业收入', unit: '万元' },
{ key: 'net_profit', label: '净利润', unit: '万元' },
{ key: 'total_assets', label: '总资产', unit: '万元' },
{ key: 'total_liabilities', label: '总负债', unit: '万元' },
{ key: 'eps', label: '每股收益', unit: '元' },
{ key: 'roe', label: '净资产收益率', unit: '%' }
];
financialItems.forEach(item => {
if (financialData[item.key] !== undefined) {
const row = document.createElement('tr');
row.innerHTML = `
<td>${item.label}</td>
<td>${this.formatNumber(financialData[item.key])}</td>
<td>${item.unit}</td>
<td>${this.calculateChange(financialData[item.key])}</td>
`;
tbody.appendChild(row);
}
});
}
// 加载系统日志
async loadSystemLogs() {
try {
this.showLoading();
const level = document.getElementById('log-level-filter').value;
const date = document.getElementById('log-date-filter').value;
let url = '/api/system/logs';
const params = [];
if (level) params.push(`level=${level}`);
if (date) params.push(`date=${date}`);
if (params.length > 0) {
url += '?' + params.join('&');
}
const response = await this.apiCall(url);
if (response.success) {
this.renderSystemLogs(response.data);
}
} catch (error) {
console.error('加载系统日志失败:', error);
this.showError('加载系统日志失败');
} finally {
this.hideLoading();
}
}
// 渲染系统日志
renderSystemLogs(logs) {
const logList = document.getElementById('log-list');
logList.innerHTML = '';
if (!logs || logs.length === 0) {
logList.innerHTML = '<div class="log-item">暂无系统日志</div>';
return;
}
logs.forEach(log => {
const logItem = document.createElement('div');
logItem.className = `log-item ${log.level.toLowerCase()}`;
logItem.innerHTML = `
<div class="log-header">
<span class="log-level ${log.level.toLowerCase()}">${log.level}</span>
<span class="log-time">${this.formatDateTime(log.timestamp)}</span>
</div>
<div class="log-message">${log.message}</div>
<div class="log-details">
模块: ${log.module_name} | 事件: ${log.event_type}
${log.exception_type ? ` | 异常: ${log.exception_type}` : ''}
</div>
`;
logList.appendChild(logItem);
});
}
// 查看股票详情
viewStockDetails(stockCode) {
alert(`查看股票 ${stockCode} 的详细信息`);
// 这里可以扩展为显示详细模态框
}
// API call wrapper
async apiCall(url) {
// Simulated API call; a real deployment must connect to the backend API
return new Promise((resolve) => {
setTimeout(() => {
// Mock data
const mockData = this.getMockData(url);
resolve(mockData);
}, 500);
});
}
// 获取模拟数据
getMockData(url) {
if (url.includes('/api/system/overview')) {
return {
success: true,
stock_count: 12595,
kline_count: 440,
financial_count: 50,
log_count: 4
};
}
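if (url.includes('/api/stocks')) {
// Mock stock data
const mockStocks = [
{ code: '000001', name: '平安银行', exchange: 'SZ', listing_date: '1991-04-03', industry: '银行' },
{ code: '000002', name: '万科A', exchange: 'SZ', listing_date: '1991-01-29', industry: '房地产' },
{ code: '600000', name: '浦发银行', exchange: 'SH', listing_date: '1999-11-10', industry: '银行' },
{ code: '600036', name: '招商银行', exchange: 'SH', listing_date: '2002-04-09', industry: '银行' },
{ code: '601318', name: '中国平安', exchange: 'SH', listing_date: '2007-03-01', industry: '保险' }
];
// '/api/stocks/search' also matches this branch, so filter by the `q` parameter here;
// otherwise every search would return the full mock list
if (url.includes('/api/stocks/search')) {
// The base URL is a dummy value that only lets the relative URL be parsed
const query = (new URL(url, 'http://localhost').searchParams.get('q') || '').toLowerCase();
const matches = mockStocks.filter(stock =>
stock.code.toLowerCase().includes(query) || stock.name.toLowerCase().includes(query));
return {
success: true,
data: matches
};
}
return {
success: true,
data: mockStocks,
total: 12595
};
}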
if (url.includes('/api/kline/')) {
// Mock K-line data: 31 daily bars ending today
const prices = [];
const basePrice = 10;
for (let i = 30; i >= 0; i--) {
const date = new Date();
date.setDate(date.getDate() - i);
// Simulated price fluctuation
const price = basePrice + Math.random() * 5;
prices.push({
date: date.toISOString().split('T')[0],
open: price - 0.5,
high: price + 0.8,
low: price - 0.8,
close: price,
volume: Math.floor(Math.random() * 1000000)
});
}
return {
success: true,
data: prices
};
}
if (url.includes('/api/financial/')) {
// 模拟财务数据
return {
success: true,
data: {
revenue: 500000,
net_profit: 80000,
total_assets: 2000000,
total_liabilities: 1200000,
eps: 1.5,
roe: 15.2
}
};
}
if (url.includes('/api/system/logs')) {
// 模拟系统日志
return {
success: true,
data: [
{
id: 1,
timestamp: new Date().toISOString(),
level: 'INFO',
module_name: 'System',
event_type: 'STARTUP',
message: '系统启动成功',
exception_type: null
},
{
id: 2,
timestamp: new Date(Date.now() - 3600000).toISOString(),
level: 'INFO',
module_name: 'DataCollector',
event_type: 'DATA_COLLECTION',
message: '开始采集股票数据',
exception_type: null
},
{
id: 3,
timestamp: new Date(Date.now() - 1800000).toISOString(),
level: 'ERROR',
module_name: 'Database',
event_type: 'CONNECTION_ERROR',
message: '数据库连接失败',
exception_type: 'ConnectionError'
},
{
id: 4,
timestamp: new Date(Date.now() - 900000).toISOString(),
level: 'WARNING',
module_name: 'DataProcessor',
event_type: 'DATA_FORMAT',
message: '数据格式异常,已自动修复',
exception_type: 'FormatError'
}
]
};
}
return { success: false, message: 'API endpoint not found' };
}
// 工具函数
formatNumber(num) {
if (num === null || num === undefined) return '-';
return new Intl.NumberFormat('zh-CN').format(num);
}
formatDate(dateString) {
if (!dateString) return '-';
return new Date(dateString).toLocaleDateString('zh-CN');
}
formatDateTime(dateString) {
if (!dateString) return '-';
return new Date(dateString).toLocaleString('zh-CN');
}
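calculateChange(value) {
// Placeholder: `value` is ignored and a random change in [-10%, +10%) is displayed.
// Replace with a real year-over-year calculation once prior-period data is available.
const change = (Math.random() - 0.5) * 20;
const sign = change >= 0 ? '+' : '';
const color = change >= 0 ? '#27ae60' : '#e74c3c';
return `<span style="color: ${color}">${sign}${change.toFixed(2)}%</span>`;
}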
showError(message) {
alert(`错误: ${message}`);
}
}
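// Global application instance. It is created (and initialized) immediately,
// which relies on this script being loaded at the end of <body>.
const app = new StockDataApp();
// Log a message once the DOM has finished loading
document.addEventListener('DOMContentLoaded', () => {
console.log('Stock data display system loaded');
});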
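
`setupPagination` above emits one button per page. With the mock total of 12,595 rows and a page size of 20, that is 630 buttons. One possible refinement is to render only a window of pages around the current one. The sketch below shows that arithmetic in Python (the backend's language) so it can be checked independently of the DOM. `page_window` and its `radius` parameter are hypothetical names and are not part of the repository; a zero `page_size` is not handled.

```python
import math
from typing import List


def page_window(total: int, page_size: int, current: int, radius: int = 2) -> List[int]:
    """Return the 1-based page numbers to display around `current`."""
    total_pages = max(1, math.ceil(total / page_size))
    current = min(max(1, current), total_pages)  # clamp out-of-range requests
    first = max(1, current - radius)
    last = min(total_pages, current + radius)
    return list(range(first, last + 1))


if __name__ == "__main__":
    print(page_window(12595, 20, 1))    # window at the first page
    print(page_window(12595, 20, 315))  # window in the middle
    print(page_window(12595, 20, 630))  # window at the last page
```

The frontend could apply the same bounds to the loop in `setupPagination`, keeping the previous and next buttons as they are.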

437
frontend/server.py Normal file

@ -0,0 +1,437 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Backend server for the stock data display system.
Provides RESTful API endpoints that support the frontend data views.
"""
import os
import sys
import json
from datetime import datetime, timedelta
from flask import Flask, jsonify, request
from flask_cors import CORS
# 添加项目根目录到Python路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from src.storage.stock_repository import StockRepository
from src.storage.database import db_manager
# Config import removed; the default configuration is used directly
class StockDataServer:
"""股票数据服务器类"""
def __init__(self):
"""初始化服务器"""
self.app = Flask(__name__)
self.repository = None
self.setup_cors()
self.setup_routes()
self.connect_database()
def setup_cors(self):
"""设置CORS支持"""
CORS(self.app)
def setup_routes(self):
"""设置API路由"""
@self.app.route('/')
def index():
"""首页重定向到前端页面"""
return self.app.send_static_file('index.html')
@self.app.route('/api/system/overview')
def system_overview():
"""获取系统概览数据"""
try:
if not self.repository:
return jsonify({
'success': True,
'stock_count': 12595,
'kline_count': 440,
'financial_count': 50,
'log_count': 4
})
# 获取真实数据统计
stock_count = self.repository.get_stock_count()
kline_count = self.repository.get_kline_count()
financial_count = self.repository.get_financial_count()
log_count = self.repository.get_log_count()
return jsonify({
'success': True,
'stock_count': stock_count,
'kline_count': kline_count,
'financial_count': financial_count,
'log_count': log_count
})
except Exception as e:
return jsonify({
'success': False,
'message': f'获取系统概览失败: {str(e)}'
}), 500
@self.app.route('/api/stocks')
def get_stocks():
"""获取股票列表"""
try:
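# Clamp paging parameters to at least 1; non-numeric values fall back to the defaults
page = max(1, request.args.get('page', 1, type=int))
limit = max(1, request.args.get('limit', 20, type=int))
offset = (page - 1) * limit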
if not self.repository:
# 返回模拟数据
mock_stocks = self.get_mock_stocks()
return jsonify({
'success': True,
'data': mock_stocks[offset:offset + limit],
'total': len(mock_stocks)
})
# 获取真实股票数据
stocks = self.repository.get_stocks(limit=limit, offset=offset)
total = self.repository.get_stock_count()
# 格式化数据
formatted_stocks = []
for stock in stocks:
formatted_stocks.append({
'code': stock.code,
'name': stock.name,
'exchange': stock.exchange,
'listing_date': stock.listing_date.isoformat() if stock.listing_date else None,
'industry': stock.industry
})
return jsonify({
'success': True,
'data': formatted_stocks,
'total': total
})
except Exception as e:
return jsonify({
'success': False,
'message': f'获取股票列表失败: {str(e)}'
}), 500
@self.app.route('/api/stocks/search')
def search_stocks():
"""搜索股票"""
try:
query = request.args.get('q', '').strip()
if not query:
return jsonify({
'success': False,
'message': '搜索关键词不能为空'
}), 400
if not self.repository:
# 返回模拟数据
mock_stocks = self.get_mock_stocks()
filtered_stocks = [
stock for stock in mock_stocks
if query.lower() in stock['code'].lower() or query.lower() in stock['name'].lower()
]
return jsonify({
'success': True,
'data': filtered_stocks
})
# 搜索真实数据
stocks = self.repository.search_stocks(query)
formatted_stocks = []
for stock in stocks:
formatted_stocks.append({
'code': stock.code,
'name': stock.name,
'exchange': stock.exchange,
'listing_date': stock.listing_date.isoformat() if stock.listing_date else None,
'industry': stock.industry
})
return jsonify({
'success': True,
'data': formatted_stocks
})
except Exception as e:
return jsonify({
'success': False,
'message': f'搜索股票失败: {str(e)}'
}), 500
@self.app.route('/api/kline/<stock_code>')
def get_kline_data(stock_code):
"""获取K线数据"""
try:
period = request.args.get('period', 'daily')
days = 30 # 默认显示30天数据
if not self.repository:
# 返回模拟K线数据
mock_kline = self.get_mock_kline_data(stock_code, days)
return jsonify({
'success': True,
'data': mock_kline
})
# 获取真实K线数据
end_date = datetime.now()
start_date = end_date - timedelta(days=days)
kline_data = self.repository.get_kline_data(
stock_code=stock_code,
start_date=start_date,
end_date=end_date,
period=period
)
formatted_data = []
for kline in kline_data:
formatted_data.append({
'date': kline.trade_date.isoformat(),
'open': float(kline.open_price),
'high': float(kline.high_price),
'low': float(kline.low_price),
'close': float(kline.close_price),
'volume': int(kline.volume)
})
return jsonify({
'success': True,
'data': formatted_data
})
except Exception as e:
return jsonify({
'success': False,
'message': f'获取K线数据失败: {str(e)}'
}), 500
@self.app.route('/api/financial/<stock_code>')
def get_financial_data(stock_code):
"""获取财务数据"""
try:
year = request.args.get('year', '2023')
period = request.args.get('period', 'Q4')
if not self.repository:
# 返回模拟财务数据
mock_financial = self.get_mock_financial_data()
return jsonify({
'success': True,
'data': mock_financial
})
# 获取真实财务数据
financial_data = self.repository.get_financial_data(
stock_code=stock_code,
year=year,
period=period
)
if not financial_data:
return jsonify({
'success': True,
'data': {}
})
formatted_data = {
'revenue': float(financial_data.revenue) if financial_data.revenue else 0,
'net_profit': float(financial_data.net_profit) if financial_data.net_profit else 0,
'total_assets': float(financial_data.total_assets) if financial_data.total_assets else 0,
'total_liabilities': float(financial_data.total_liabilities) if financial_data.total_liabilities else 0,
'eps': float(financial_data.eps) if financial_data.eps else 0,
'roe': float(financial_data.roe) if financial_data.roe else 0
}
return jsonify({
'success': True,
'data': formatted_data
})
except Exception as e:
return jsonify({
'success': False,
'message': f'获取财务数据失败: {str(e)}'
}), 500
@self.app.route('/api/system/logs')
def get_system_logs():
"""获取系统日志"""
try:
level = request.args.get('level', '')
date_str = request.args.get('date', '')
if not self.repository:
# 返回模拟日志数据
mock_logs = self.get_mock_system_logs()
return jsonify({
'success': True,
'data': mock_logs
})
# 获取真实系统日志
logs = self.repository.get_system_logs(level=level, date_str=date_str)
formatted_logs = []
for log in logs:
formatted_logs.append({
'id': log.id,
'timestamp': log.timestamp.isoformat(),
'level': log.level,
'module_name': log.module_name,
'event_type': log.event_type,
'message': log.message,
'exception_type': log.exception_type
})
return jsonify({
'success': True,
'data': formatted_logs
})
except Exception as e:
return jsonify({
'success': False,
'message': f'获取系统日志失败: {str(e)}'
}), 500
# 静态文件服务
@self.app.route('/<path:path>')
def serve_static(path):
"""服务静态文件"""
try:
return self.app.send_static_file(path)
except Exception:
return jsonify({
'success': False,
'message': '文件未找到'
}), 404
def connect_database(self):
"""连接数据库"""
try:
session = db_manager.get_session()
self.repository = StockRepository(session)
print("数据库连接成功")
except Exception as e:
print(f"数据库连接失败: {e}")
self.repository = None
def get_mock_stocks(self):
"""获取模拟股票数据"""
return [
{'code': '000001', 'name': '平安银行', 'exchange': 'SZ', 'listing_date': '1991-04-03', 'industry': '银行'},
{'code': '000002', 'name': '万科A', 'exchange': 'SZ', 'listing_date': '1991-01-29', 'industry': '房地产'},
{'code': '600000', 'name': '浦发银行', 'exchange': 'SH', 'listing_date': '1999-11-10', 'industry': '银行'},
{'code': '600036', 'name': '招商银行', 'exchange': 'SH', 'listing_date': '2002-04-09', 'industry': '银行'},
{'code': '601318', 'name': '中国平安', 'exchange': 'SH', 'listing_date': '2007-03-01', 'industry': '保险'}
]
def get_mock_kline_data(self, stock_code, days):
"""获取模拟K线数据"""
kline_data = []
base_price = 10 + hash(stock_code) % 20 # 基于股票代码生成基础价格
for i in range(days, 0, -1):
date = datetime.now() - timedelta(days=i)
price_variation = (hash(f"{stock_code}{i}") % 100 - 50) / 100 # 价格波动
open_price = base_price + price_variation
close_price = open_price + (hash(f"close{stock_code}{i}") % 100 - 50) / 200
high_price = max(open_price, close_price) + abs(hash(f"high{stock_code}{i}") % 100) / 200
low_price = min(open_price, close_price) - abs(hash(f"low{stock_code}{i}") % 100) / 200
volume = abs(hash(f"volume{stock_code}{i}") % 1000000)
kline_data.append({
'date': date.strftime('%Y-%m-%d'),
'open': round(open_price, 2),
'high': round(high_price, 2),
'low': round(low_price, 2),
'close': round(close_price, 2),
'volume': volume
})
return kline_data
def get_mock_financial_data(self):
"""获取模拟财务数据"""
return {
'revenue': 500000,
'net_profit': 80000,
'total_assets': 2000000,
'total_liabilities': 1200000,
'eps': 1.5,
'roe': 15.2
}
def get_mock_system_logs(self):
"""获取模拟系统日志"""
return [
{
'id': 1,
'timestamp': datetime.now().isoformat(),
'level': 'INFO',
'module_name': 'System',
'event_type': 'STARTUP',
'message': '系统启动成功',
'exception_type': None
},
{
'id': 2,
'timestamp': (datetime.now() - timedelta(hours=1)).isoformat(),
'level': 'INFO',
'module_name': 'DataCollector',
'event_type': 'DATA_COLLECTION',
'message': '开始采集股票数据',
'exception_type': None
},
{
'id': 3,
'timestamp': (datetime.now() - timedelta(minutes=30)).isoformat(),
'level': 'ERROR',
'module_name': 'Database',
'event_type': 'CONNECTION_ERROR',
'message': '数据库连接失败',
'exception_type': 'ConnectionError'
},
{
'id': 4,
'timestamp': (datetime.now() - timedelta(minutes=15)).isoformat(),
'level': 'WARNING',
'module_name': 'DataProcessor',
'event_type': 'DATA_FORMAT',
'message': '数据格式异常,已自动修复',
'exception_type': 'FormatError'
}
]
def run(self, host='127.0.0.1', port=5000, debug=True):
"""运行服务器"""
print(f"启动股票数据服务器: http://{host}:{port}")
print("API端点:")
print(" GET /api/system/overview - 系统概览")
print(" GET /api/stocks - 股票列表")
print(" GET /api/stocks/search - 搜索股票")
print(" GET /api/kline/<code> - K线数据")
print(" GET /api/financial/<code> - 财务数据")
print(" GET /api/system/logs - 系统日志")
self.app.static_folder = os.path.dirname(os.path.abspath(__file__))
self.app.run(host=host, port=port, debug=debug)
def main():
"""主函数"""
server = StockDataServer()
server.run()
if __name__ == '__main__':
main()
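
The mock K-line generator above derives its prices from the built-in `hash()`. For `str` inputs Python salts that hash per process unless `PYTHONHASHSEED` is fixed, so the same stock code can yield different mock prices after a restart. A minimal sketch of a stable alternative follows; `stable_hash` and `mock_kline` are hypothetical names that do not exist in the original file, and the price formula simply mirrors the one above.

```python
import hashlib
from datetime import date, timedelta


def stable_hash(text: str) -> int:
    """Return a non-negative integer that is identical across interpreter runs."""
    digest = hashlib.sha256(text.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big")


def mock_kline(stock_code: str, days: int, today: date) -> list:
    """Build deterministic mock daily bars (hypothetical stand-in for get_mock_kline_data)."""
    base_price = 10 + stable_hash(stock_code) % 20
    bars = []
    for i in range(days, 0, -1):
        open_price = base_price + (stable_hash(f"{stock_code}{i}") % 100 - 50) / 100
        close_price = open_price + (stable_hash(f"close{stock_code}{i}") % 100 - 50) / 200
        high_price = max(open_price, close_price) + (stable_hash(f"high{stock_code}{i}") % 100) / 200
        low_price = min(open_price, close_price) - (stable_hash(f"low{stock_code}{i}") % 100) / 200
        bars.append({
            "date": (today - timedelta(days=i)).isoformat(),
            "open": round(open_price, 2),
            "high": round(high_price, 2),
            "low": round(low_price, 2),
            "close": round(close_price, 2),
            "volume": stable_hash(f"volume{stock_code}{i}") % 1_000_000,
        })
    return bars
```

Passing `today` explicitly also keeps the output reproducible in tests, which the `datetime.now()` call in the original does not allow.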

214
log_system_events.py Normal file

@ -0,0 +1,214 @@
"""
System event logging script.
Records important events and exception details while the system is running.
"""
import sys
import os
import logging
from datetime import datetime
# This script lives in the project root, so its own directory is the import root for `src`
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from src.storage.database import db_manager
from src.storage.stock_repository import StockRepository
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def log_system_event(level, module, message, exception_type=None, exception_message=None, traceback=None, stock_code=None, data_type=None):
"""
记录系统事件到数据库
Args:
level: 日志级别 (INFO, WARNING, ERROR, DEBUG)
module: 模块名称
message: 日志消息
exception_type: 异常类型 (可选)
exception_message: 异常消息 (可选)
traceback: 异常堆栈 (可选)
stock_code: 关联股票代码 (可选)
data_type: 数据类型 (可选)
Returns:
记录是否成功
"""
try:
# 创建存储库
repository = StockRepository(db_manager.get_session())
# 创建日志记录
log_data = {
"log_level": level,
"module_name": module,
"message": message,
"exception_type": exception_type,
"exception_message": exception_message,
"traceback": traceback,
"stock_code": stock_code,
"data_type": data_type
}
# 保存到数据库
success = repository.save_system_log(log_data)
if success:
logger.info(f"系统事件记录成功: {level} - {module} - {message}")
else:
logger.error(f"系统事件记录失败: {level} - {module} - {message}")
return success
except Exception as e:
logger.error(f"记录系统事件异常: {str(e)}")
return False
def log_data_collection_start(stock_codes, data_type, source):
"""
记录数据采集开始事件
Args:
stock_codes: 股票代码列表
data_type: 数据类型 (kline/financial)
source: 数据源 (baostock/akshare)
"""
message = f"开始采集{data_type}数据,股票数量: {len(stock_codes)},数据源: {source}"
return log_system_event("INFO", "data_collection", message, stock_code=",".join(stock_codes[:5]) if stock_codes else None, data_type=data_type)
def log_data_collection_complete(stock_codes, data_type, source, success_count, error_count):
"""
记录数据采集完成事件
Args:
stock_codes: 股票代码列表
data_type: 数据类型 (kline/financial)
source: 数据源 (baostock/akshare)
success_count: 成功数量
error_count: 失败数量
"""
message = f"{data_type}数据采集完成,成功: {success_count},失败: {error_count},数据源: {source}"
return log_system_event("INFO", "data_collection", message, stock_code=",".join(stock_codes[:5]) if stock_codes else None, data_type=data_type)
def log_database_operation(operation, table, affected_rows):
"""
记录数据库操作事件
Args:
operation: 操作类型 (insert/update/delete)
table: 表名
affected_rows: 影响行数
"""
message = f"数据库{operation}操作,表: {table},影响行数: {affected_rows}"
return log_system_event("INFO", "database", message)
def log_system_error(module, error_message, exception=None, stock_code=None):
"""
记录系统错误事件
Args:
module: 模块名称
error_message: 错误消息
exception: 异常对象 (可选)
stock_code: 关联股票代码 (可选)
"""
if exception:
return log_system_event("ERROR", module, error_message,
exception_type=type(exception).__name__,
exception_message=str(exception),
stock_code=stock_code)
else:
return log_system_event("ERROR", module, error_message, stock_code=stock_code)
def get_system_logs(level=None, module=None, start_date=None, end_date=None, limit=100):
"""
查询系统日志
Args:
level: 日志级别过滤 (可选)
module: 模块名称过滤 (可选)
start_date: 开始日期 (可选)
end_date: 结束日期 (可选)
limit: 返回记录数限制
Returns:
日志记录列表
"""
try:
repository = StockRepository(db_manager.get_session())
# 构建查询条件
query = repository.session.query(repository.SystemLog)
if level:
query = query.filter(repository.SystemLog.log_level == level)
if module:
query = query.filter(repository.SystemLog.module_name == module)
if start_date:
query = query.filter(repository.SystemLog.created_at >= start_date)
if end_date:
query = query.filter(repository.SystemLog.created_at <= end_date)
# 按时间倒序排列
query = query.order_by(repository.SystemLog.created_at.desc())
# 限制返回数量
logs = query.limit(limit).all()
logger.info(f"查询到{len(logs)}条系统日志")
return logs
except Exception as e:
logger.error(f"查询系统日志异常: {str(e)}")
return []
def main():
"""
主函数 - 测试系统日志功能
"""
logger.info("开始测试系统日志功能...")
# 测试记录各种类型的事件
test_events = [
("INFO", "system", "系统启动", None, None, None, None, None),
("INFO", "data_collection", "数据采集开始", None, None, None, "000001", "kline"),
("ERROR", "database", "数据库连接失败", "ConnectionError", "无法连接数据库", "traceback info", None, None),
("WARNING", "data_processing", "数据格式异常", None, None, None, "000002", "financial")
]
success_count = 0
for event in test_events:
if log_system_event(*event):
success_count += 1
logger.info(f"系统日志测试完成,成功记录{success_count}/{len(test_events)}条事件")
# 查询并显示最近的系统日志
logger.info("查询最近的系统日志:")
recent_logs = get_system_logs(limit=10)
for i, log in enumerate(recent_logs):
logger.info(f" {i+1}. [{log.log_level}] {log.module_name}: {log.message} ({log.created_at})")
return success_count == len(test_events)
if __name__ == "__main__":
# 运行测试
success = main()
if success:
print("系统日志功能测试成功!")
sys.exit(0)
else:
print("系统日志功能测试失败!")
sys.exit(1)
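
`log_system_error` above stores the exception type and message but leaves the `traceback` column empty. The sketch below shows one way to fill all three exception fields from a caught exception using only the standard library. `exception_log_fields` is a hypothetical helper, not part of the original script; its keys are chosen to match the keyword arguments of `log_system_event`.

```python
import traceback


def exception_log_fields(exc: BaseException) -> dict:
    """Collect the exception columns used by log_system_event (hypothetical helper)."""
    return {
        "exception_type": type(exc).__name__,
        "exception_message": str(exc),
        # format_exception returns a list of strings; join them into one text blob
        "traceback": "".join(
            traceback.format_exception(type(exc), exc, exc.__traceback__)
        ),
    }


try:
    int("not-a-number")
except ValueError as err:
    fields = exception_log_fields(err)
```

Under that assumption the result could be forwarded as `log_system_event("ERROR", module, message, **fields)`.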

36
requirements.txt Normal file

@ -0,0 +1,36 @@
# Dependencies for the A-share market analysis and quantitative trading system
# Data collection
akshare>=1.10.0
baostock>=0.8.80
# Data processing and analysis
pandas>=2.0.0
numpy>=1.24.0
# Database
sqlalchemy>=2.0.0
mysql-connector-python>=8.0.0
# Scheduled tasks
apscheduler>=3.10.0
# Logging
loguru>=0.7.0
# HTTP requests
requests>=2.28.0
# Configuration management
pydantic>=2.0.0
pydantic-settings>=2.11.0
python-dotenv>=1.0.0
# Testing
pytest>=7.0.0
pytest-asyncio>=0.21.0
# Development tools
black>=23.0.0
flake8>=6.0.0
mypy>=1.0.0

271
run.py Normal file

@ -0,0 +1,271 @@
"""
Startup script for the stock analysis system.
Provides a command-line interface for running system functions.
"""
import argparse
import asyncio
import sys
import os
# Add the project root and the src directory to the Python path.
# abspath keeps this correct when __file__ is a relative path.
project_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, project_root)
sys.path.insert(0, os.path.join(project_root, 'src'))
from src.main import StockAnalysisSystem
# Load the Config class directly from config/config.py.
# importlib.util must be imported explicitly; a bare `import importlib` does not guarantee it.
import importlib.util
config_path = os.path.join(project_root, 'config', 'config.py')
spec = importlib.util.spec_from_file_location("config", config_path)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
Config = config_module.Config
class CommandLineInterface:
"""命令行接口类"""
def __init__(self):
self.parser = argparse.ArgumentParser(
description='股票分析系统 - 数据采集、处理和分析平台',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog='''
使用示例:
python run.py init # 初始化系统数据
python run.py scheduler # 启动定时任务调度器
python run.py status # 查看系统状态
python run.py update # 手动更新数据
python run.py test # 运行测试
python run.py performance # 运行性能测试
'''
)
self.parser.add_argument(
'command',
choices=['init', 'scheduler', 'status', 'update', 'test', 'performance'],
help='要执行的命令'
)
self.parser.add_argument(
'--config',
default='config/config.py',
help='配置文件路径'
)
self.parser.add_argument(
'--database',
help='数据库URL (覆盖配置文件中的设置)'
)
self.parser.add_argument(
'--debug',
action='store_true',
help='启用调试模式'
)
self.parser.add_argument(
'--log-level',
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
default='INFO',
help='日志级别'
)
async def run(self, args):
"""运行命令"""
try:
# 配置系统
system = await self._configure_system(args)
# 执行命令
if args.command == 'init':
await self._run_init(system)
elif args.command == 'scheduler':
await self._run_scheduler(system)
elif args.command == 'status':
await self._run_status(system)
elif args.command == 'update':
await self._run_update(system)
elif args.command == 'test':
await self._run_test()
elif args.command == 'performance':
await self._run_performance()
except KeyboardInterrupt:
print("\n用户中断操作")
except Exception as e:
print(f"执行命令时发生错误: {e}")
if args.debug:
import traceback
traceback.print_exc()
sys.exit(1)
async def _configure_system(self, args):
"""配置系统"""
print("正在配置股票分析系统...")
# 创建系统实例
system = StockAnalysisSystem()
# 配置数据库
if args.database:
system.db_config['database_url'] = args.database
# 配置调试模式
if args.debug:
Config.DEVELOPMENT_CONFIG['debug'] = True
Config.LOGGING_CONFIG['level'] = args.log_level
Config.LOGGING_CONFIG['console']['level'] = args.log_level
# 系统已在构造函数中初始化完成
print("系统配置完成")
return system
async def _run_init(self, system):
"""运行初始化命令"""
print("开始初始化系统数据...")
# 自动确认操作(用于自动化测试)
print("此操作将清空现有数据并重新初始化")
print("自动确认: 继续执行")
# 执行初始化
result = await system.initialize_data()
if result:
print("系统数据初始化完成")
else:
print("系统数据初始化失败")
async def _run_scheduler(self, system):
"""运行调度器命令"""
print("启动定时任务调度器...")
print("按 Ctrl+C 停止调度器")
try:
# 启动调度器
await system.start_scheduler()
# 保持运行
while True:
await asyncio.sleep(1)
except KeyboardInterrupt:
print("\n正在停止调度器...")
await system.stop_scheduler()
print("调度器已停止")
async def _run_status(self, system):
"""运行状态命令"""
print("正在检查系统状态...")
# 获取系统状态
status = await system.check_system_status()
# 显示状态信息
print("\n=== 系统状态报告 ===")
print(f"数据库连接: {'正常' if status['database'] else '异常'}")
print(f"数据采集器: {'正常' if status['collectors'] else '异常'}")
print(f"定时任务: {'运行中' if status['scheduler'] else '已停止'}")
if 'data_counts' in status:
print("\n=== 数据统计 ===")
for data_type, count in status['data_counts'].items():
print(f"{data_type}: {count} 条记录")
if 'last_update' in status:
print(f"\n最后更新时间: {status['last_update']}")
async def _run_update(self, system):
"""运行更新命令"""
print("开始手动更新数据...")
# 选择更新类型
print("请选择要更新的数据类型:")
print("1. 股票基础信息")
print("2. 日K线数据")
print("3. 财务报告数据")
print("4. 全部数据")
choice = input("请输入选择 (1-4): ").strip()
update_types = {
'1': 'basic',
'2': 'kline',
'3': 'financial',
'4': 'all'
}
if choice not in update_types:
print("无效选择")
return
update_type = update_types[choice]
# 执行更新
result = await system.update_data(update_type)
if result:
print(f"{update_type} 数据更新完成")
else:
print(f"{update_type} 数据更新失败")
async def _run_test(self):
"""运行测试命令"""
print("开始运行测试...")
# 运行pytest
import subprocess
try:
result = subprocess.run([
sys.executable, '-m', 'pytest', 'tests/',
'-v', '--tb=short', '--color=yes'
], cwd=os.path.dirname(os.path.abspath(__file__)))
if result.returncode == 0:
print("所有测试通过")
else:
print("部分测试失败")
except Exception as e:
print(f"运行测试时发生错误: {e}")
async def _run_performance(self):
"""运行性能测试命令"""
print("开始运行性能测试...")
# 运行性能测试
import subprocess
try:
result = subprocess.run([
sys.executable, '-m', 'pytest',
'tests/test_performance.py',
'-v', '--tb=short', '--color=yes'
], cwd=os.path.dirname(os.path.abspath(__file__)))
if result.returncode == 0:
print("性能测试完成")
else:
print("性能测试失败")
except Exception as e:
print(f"运行性能测试时发生错误: {e}")
def main():
"""主函数"""
cli = CommandLineInterface()
args = cli.parser.parse_args()
# 运行命令
asyncio.run(cli.run(args))
if __name__ == '__main__':
main()
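
`run.py` loads `Config` from a file path rather than through a package import. The technique can be sketched in isolation as below. `load_module_from_path` is a hypothetical helper name, and the temporary `config.py` written here is illustrative only; it is not the project's real configuration.

```python
import importlib.util
import os
import tempfile


def load_module_from_path(name: str, path: str):
    """Import a Python source file by path and return the module object."""
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load {name!r} from {path!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# Demonstration with a throwaway config file (contents are illustrative only)
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "config.py")
    with open(demo_path, "w", encoding="utf-8") as fh:
        fh.write("class Config:\n    LOGGING_CONFIG = {'level': 'INFO'}\n")
    demo_config = load_module_from_path("demo_config", demo_path).Config
```

Checking `spec` and `spec.loader` for `None` gives a clear `ImportError` when the path is wrong, instead of an `AttributeError` further down.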

54
simple_check.py Normal file

@ -0,0 +1,54 @@
"""
Simple data status check script.
"""
import sys
import os
# This script lives in the project root, so its own directory is the import root for `src`
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from src.storage.database import db_manager
from sqlalchemy import text
import logging
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def simple_check():
"""简单检查数据状态"""
try:
# 获取数据库会话
session = db_manager.get_session()
logger.info("=== 简单数据检查 ===")
# 检查各表是否存在数据
tables = ['stock_basic', 'daily_kline', 'financial_report', 'data_source', 'system_log']
for table in tables:
# 先检查表是否存在
result = session.execute(text(f"SHOW TABLES LIKE '{table}'"))
if result.fetchone():
# 检查数据量
result = session.execute(text(f"SELECT COUNT(*) FROM {table}"))
count = result.fetchone()[0]
logger.info(f"{table}: {count} 条记录")
# 如果是stock_basic表且有数据显示前几条
if table == 'stock_basic' and count > 0:
result = session.execute(text(f"SELECT code, name FROM {table} LIMIT 5"))
stocks = result.fetchall()
for stock in stocks:
logger.info(f" 股票: {stock[0]}, 名称: {stock[1]}")
else:
logger.info(f"{table} 不存在")
session.close()
logger.info("=== 检查完成 ===")
except Exception as e:
logger.error(f"检查失败: {str(e)}")
raise
if __name__ == "__main__":
simple_check()
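
`SHOW TABLES LIKE` in the check above is MySQL syntax, while the README lists SQLite as a supported backend as well. A minimal sketch of the equivalent existence check for SQLite follows, using only the standard-library `sqlite3` module. `sqlite_table_exists` is a hypothetical helper and the in-memory table is illustrative.

```python
import sqlite3


def sqlite_table_exists(conn: sqlite3.Connection, table: str) -> bool:
    """Return True if `table` exists in the connected SQLite database."""
    row = conn.execute(
        "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?",
        (table,),
    ).fetchone()
    return row is not None


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE stock_basic (code TEXT, name TEXT)")
found = sqlite_table_exists(conn, "stock_basic")
missing = sqlite_table_exists(conn, "daily_kline")
conn.close()
```

Since the project already depends on SQLAlchemy, its inspector (`inspect(engine).has_table(name)`) may be a dialect-independent alternative; that route is not shown here.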

1
src/__init__.py Normal file

@ -0,0 +1 @@
# Main package of the A-share market analysis and quantitative trading system

1
src/config/__init__.py Normal file

@ -0,0 +1 @@
# Configuration management module

74
src/config/settings.py Normal file

@ -0,0 +1,74 @@
"""
System configuration module.
Loads and manages all system configuration parameters.
"""
import os
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
# 手动加载.env文件
from dotenv import load_dotenv
load_dotenv()
class DatabaseSettings(BaseSettings):
"""数据库配置类"""
database_url: str = Field(
default=os.getenv("DATABASE_URL", "mysql+mysqlconnector://username:password@localhost:3306/stock"),
description="数据库连接URL"
)
pool_size: int = Field(default=int(os.getenv("DATABASE_POOL_SIZE", "10")), description="连接池大小")
max_overflow: int = Field(default=int(os.getenv("DATABASE_MAX_OVERFLOW", "20")), description="最大溢出连接数")
pool_timeout: int = Field(default=int(os.getenv("DATABASE_POOL_TIMEOUT", "30")), description="连接池超时时间(秒)")
class DataSourceSettings(BaseSettings):
"""数据源配置类"""
akshare_timeout: int = Field(default=int(os.getenv("DATA_SOURCE_AKSHARE_TIMEOUT", "30")), description="AKshare接口超时时间(秒)")
baostock_timeout: int = Field(default=int(os.getenv("DATA_SOURCE_BAOSTOCK_TIMEOUT", "30")), description="Baostock接口超时时间(秒)")
max_retry_times: int = Field(default=int(os.getenv("DATA_SOURCE_MAX_RETRY_TIMES", "3")), description="最大重试次数")
retry_delay_seconds: int = Field(default=int(os.getenv("DATA_SOURCE_RETRY_DELAY_SECONDS", "5")), description="重试延迟时间(秒)")
class SchedulerSettings(BaseSettings):
"""定时任务配置类"""
timezone: str = Field(default=os.getenv("SCHEDULER_TIMEZONE", "Asia/Shanghai"), description="时区设置")
update_interval_hours: int = Field(default=int(os.getenv("SCHEDULER_UPDATE_INTERVAL_HOURS", "24")), description="数据更新间隔(小时)")
class LogSettings(BaseSettings):
"""日志配置类"""
log_level: str = Field(default=os.getenv("LOG_LEVEL", "INFO"), description="日志级别")
log_file: str = Field(default=os.getenv("LOG_FILE", "logs/stock_analysis.log"), description="日志文件路径")
max_file_size: str = Field(default=os.getenv("LOG_MAX_FILE_SIZE", "10MB"), description="最大日志文件大小")
backup_count: int = Field(default=int(os.getenv("LOG_BACKUP_COUNT", "5")), description="备份文件数量")
class MarketSettings(BaseSettings):
"""市场配置类"""
market_types: list[str] = Field(default=os.getenv("MARKET_MARKET_TYPES", "sh,sz").split(","), description="市场类型列表")
data_types: list[str] = Field(
default=os.getenv("MARKET_DATA_TYPES", "stock_basic,daily_kline,financial_report").split(","),
description="数据类型列表"
)
class Settings(BaseSettings):
"""系统总配置类"""
database: DatabaseSettings = DatabaseSettings()
data_source: DataSourceSettings = DataSourceSettings()
scheduler: SchedulerSettings = SchedulerSettings()
log: LogSettings = LogSettings()
market: MarketSettings = MarketSettings()
# 全局配置实例
settings = Settings()
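
The `.env.example` file stores list-valued settings as comma-separated text (`MARKET_TYPES=sh,sz`). pydantic-settings, by default, treats environment values for complex field types such as lists as JSON, so a plain `sh,sz` value matched to the `market_types` field may fail to parse. This is worth verifying against the installed version. The sketch below shows explicit comma splitting with the standard library only; `csv_env` is a hypothetical helper and the `environ` parameter exists just to make it testable.

```python
import os
from typing import List, Mapping, Optional


def csv_env(name: str, default: str, environ: Optional[Mapping[str, str]] = None) -> List[str]:
    """Read a comma-separated variable and return its non-empty, stripped items."""
    source = os.environ if environ is None else environ
    raw = source.get(name, default)
    return [item.strip() for item in raw.split(",") if item.strip()]


market_types = csv_env("MARKET_TYPES", "sh,sz", environ={"MARKET_TYPES": "sh, sz,"})
data_types = csv_env("DATA_TYPES", "stock_basic,daily_kline,financial_report", environ={})
```

Stripping whitespace and dropping empty items makes the parser tolerant of trailing commas and spaces in hand-edited `.env` files.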

1
src/data/__init__.py Normal file

@ -0,0 +1 @@
# Data collection module


@ -0,0 +1,172 @@
"""
AKshare data collector.
Implements stock data collection on top of the AKshare API.
"""
import akshare as ak
from typing import Any, Dict, List
from loguru import logger
from .base_collector import BaseDataCollector
class AKshareCollector(BaseDataCollector):
"""AKshare数据采集器"""
def __init__(self):
"""初始化AKshare采集器"""
super().__init__("AKshare采集器")
async def get_stock_basic_info(self) -> List[Dict[str, Any]]:
"""
获取股票基础信息
Returns:
股票基础信息列表
"""
logger.info("开始获取股票基础信息")
async def _fetch_data():
try:
# 获取A股基础信息
stock_info_a_code_name = ak.stock_info_a_code_name()
# 转换为标准格式
result = []
for _, row in stock_info_a_code_name.iterrows():
result.append({
"code": row["code"],
"name": row["name"],
"market": self._get_market_type(row["code"])
})
logger.info(f"成功获取{len(result)}只股票基础信息")
return result
except Exception as e:
logger.error(f"获取股票基础信息失败: {str(e)}")
raise
return await self._retry_request(_fetch_data)
async def get_daily_kline_data(
self,
stock_code: str,
start_date: str,
end_date: str
) -> List[Dict[str, Any]]:
"""
获取日K线数据
Args:
stock_code: 股票代码
start_date: 开始日期(YYYY-MM-DD)
end_date: 结束日期(YYYY-MM-DD)
Returns:
日K线数据列表
"""
logger.info(f"开始获取{stock_code}的K线数据")
async def _fetch_data():
try:
# 获取日K线数据
stock_zh_a_hist_df = ak.stock_zh_a_hist(
symbol=stock_code,
period="daily",
start_date=start_date,
end_date=end_date,
adjust=""
)
# 转换为标准格式
result = []
for _, row in stock_zh_a_hist_df.iterrows():
result.append({
"code": stock_code,
"date": row["日期"].strftime("%Y-%m-%d"),
"open": float(row["开盘"]),
"high": float(row["最高"]),
"low": float(row["最低"]),
"close": float(row["收盘"]),
"volume": int(row["成交量"]),
"amount": float(row["成交额"])
})
logger.info(f"成功获取{stock_code}{len(result)}条K线数据")
return result
except Exception as e:
logger.error(f"获取{stock_code}K线数据失败: {str(e)}")
raise
return await self._retry_request(_fetch_data)
async def get_financial_report(
self,
stock_code: str,
year: int,
quarter: int
) -> List[Dict[str, Any]]:
"""
获取财务报告数据
Args:
stock_code: 股票代码
year: 年份
quarter: 季度(1-4)
Returns:
财务报告数据列表
"""
logger.info(f"开始获取{stock_code}的财务报告")
async def _fetch_data():
try:
# 获取财务指标数据
stock_financial_analysis_indicator_df = ak.stock_financial_analysis_indicator(
symbol=stock_code
)
# 过滤指定年份和季度的数据
filtered_df = stock_financial_analysis_indicator_df[
(stock_financial_analysis_indicator_df["日期"].str.startswith(f"{year}")) &
(stock_financial_analysis_indicator_df["日期"].str.contains(f"Q{quarter}"))
]
# 转换为标准格式
result = []
for _, row in filtered_df.iterrows():
result.append({
"code": stock_code,
"report_date": row["日期"],
"eps": float(row.get("基本每股收益", 0)),
"net_profit": float(row.get("净利润", 0)),
"revenue": float(row.get("营业收入", 0)),
"total_assets": float(row.get("总资产", 0))
})
logger.info(f"成功获取{stock_code}{len(result)}条财务数据")
return result
except Exception as e:
logger.error(f"获取{stock_code}财务报告失败: {str(e)}")
raise
return await self._retry_request(_fetch_data)
def _get_market_type(self, code: str) -> str:
"""
根据股票代码判断市场类型
Args:
code: 股票代码
Returns:
市场类型(sh/sz)
"""
if code.startswith("6"):
return "sh"
elif code.startswith("0") or code.startswith("3"):
return "sz"
else:
return "other"
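
The financial-report filter above keeps rows whose date string contains `Q{quarter}`. If the source reports periods as quarter-end dates such as `2023-12-31` (an assumption that should be checked against the actual DataFrame), that filter would match nothing. The sketch below maps a year and quarter to its quarter-end date, and also converts `YYYY-MM-DD` to the compact `YYYYMMDD` form that some AKShare endpoints document for date arguments. Both helper names are hypothetical.

```python
from datetime import date

# Last month and day of each calendar quarter
_QUARTER_END = {1: (3, 31), 2: (6, 30), 3: (9, 30), 4: (12, 31)}


def quarter_end(year: int, quarter: int) -> date:
    """Return the last calendar day of the given quarter."""
    if quarter not in _QUARTER_END:
        raise ValueError(f"quarter must be 1-4, got {quarter}")
    month, day = _QUARTER_END[quarter]
    return date(year, month, day)


def compact_date(iso_date: str) -> str:
    """Convert 'YYYY-MM-DD' to 'YYYYMMDD'."""
    return date.fromisoformat(iso_date).strftime("%Y%m%d")
```

With these, a filter could compare the report-date column against `quarter_end(year, quarter).isoformat()` instead of searching for a `Q` marker.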


@ -0,0 +1,258 @@
"""
Baostock data collector.
Implements stock data collection on top of the Baostock API.
"""
import baostock as bs
import pandas as pd
from typing import Any, Dict, List
from loguru import logger
from .base_collector import BaseDataCollector
class BaostockCollector(BaseDataCollector):
"""Baostock数据采集器"""
def __init__(self):
"""初始化Baostock采集器"""
super().__init__("Baostock采集器")
self._is_logged_in = False
async def login(self) -> bool:
"""
登录Baostock系统
Returns:
登录是否成功
"""
try:
lg = bs.login()
if lg.error_code == "0":
self._is_logged_in = True
logger.info("Baostock登录成功")
return True
else:
logger.error(f"Baostock登录失败: {lg.error_msg}")
return False
except Exception as e:
logger.error(f"Baostock登录异常: {str(e)}")
return False
async def logout(self):
"""登出Baostock系统"""
try:
if self._is_logged_in:
bs.logout()
self._is_logged_in = False
logger.info("Baostock登出成功")
except Exception as e:
logger.error(f"Baostock登出异常: {str(e)}")
async def get_stock_basic_info(self) -> List[Dict[str, Any]]:
"""
获取股票基础信息
Returns:
股票基础信息列表
"""
logger.info("开始获取股票基础信息")
if not await self.login():
return []
async def _fetch_data():
try:
# 获取股票基础信息
rs = bs.query_stock_basic()
if rs.error_code != "0":
raise Exception(f"查询失败: {rs.error_msg}")
# 转换为DataFrame
data_list = []
while rs.error_code == "0" and rs.next():
data_list.append(rs.get_row_data())
result_df = pd.DataFrame(
data_list,
columns=rs.fields
)
# 转换为标准格式
result = []
for _, row in result_df.iterrows():
result.append({
"code": row["code"],
"name": row["code_name"],
"market": self._get_market_type(row["code"]),
"ipo_date": row.get("ipoDate", ""),
"industry": row.get("industry", ""),
"area": row.get("area", "")
})
logger.info(f"成功获取{len(result)}只股票基础信息")
return result
except Exception as e:
logger.error(f"获取股票基础信息失败: {str(e)}")
raise
finally:
await self.logout()
return await self._retry_request(_fetch_data)
async def get_daily_kline_data(
self,
stock_code: str,
start_date: str,
end_date: str
) -> List[Dict[str, Any]]:
"""
获取日K线数据
Args:
stock_code: 股票代码
start_date: 开始日期(YYYY-MM-DD)
end_date: 结束日期(YYYY-MM-DD)
Returns:
日K线数据列表
"""
logger.info(f"开始获取{stock_code}的K线数据")
if not await self.login():
return []
async def _fetch_data():
try:
# 获取日K线数据
rs = bs.query_history_k_data_plus(
stock_code,
"date,code,open,high,low,close,volume,amount",
start_date=start_date,
end_date=end_date,
frequency="d",
adjustflag="3"
)
if rs.error_code != "0":
raise Exception(f"查询失败: {rs.error_msg}")
# 转换为DataFrame
data_list = []
while rs.error_code == "0" and rs.next():
data_list.append(rs.get_row_data())
result_df = pd.DataFrame(
data_list,
columns=rs.fields
)
# 转换为标准格式
result = []
for _, row in result_df.iterrows():
result.append({
"code": row["code"],
"date": row["date"],
"open": float(row["open"]),
"high": float(row["high"]),
"low": float(row["low"]),
"close": float(row["close"]),
"volume": int(row["volume"]),
"amount": float(row["amount"])
})
logger.info(f"成功获取{stock_code}{len(result)}条K线数据")
return result
except Exception as e:
logger.error(f"获取{stock_code}K线数据失败: {str(e)}")
raise
finally:
await self.logout()
return await self._retry_request(_fetch_data)
async def get_financial_report(
self,
stock_code: str,
year: int,
quarter: int
) -> List[Dict[str, Any]]:
"""
获取财务报告数据
Args:
stock_code: 股票代码
year: 年份
quarter: 季度(1-4)
Returns:
财务报告数据列表
"""
logger.info(f"开始获取{stock_code}的财务报告")
if not await self.login():
return []
async def _fetch_data():
try:
# 获取财务指标数据
rs = bs.query_profit_data(
code=stock_code,
year=year,
quarter=quarter
)
if rs.error_code != "0":
raise Exception(f"查询失败: {rs.error_msg}")
# 转换为DataFrame
data_list = []
while rs.error_code == "0" and rs.next():
data_list.append(rs.get_row_data())
result_df = pd.DataFrame(
data_list,
columns=rs.fields
)
# 转换为标准格式
result = []
for _, row in result_df.iterrows():
result.append({
"code": stock_code,
"report_date": f"{year}-Q{quarter}",
"eps": float(row.get("eps", 0)),
"net_profit": float(row.get("netProfit", 0)),
"revenue": float(row.get("revenue", 0)),
"total_assets": float(row.get("totalAssets", 0))
})
logger.info(f"成功获取{stock_code}{len(result)}条财务数据")
return result
except Exception as e:
logger.error(f"获取{stock_code}财务报告失败: {str(e)}")
raise
finally:
await self.logout()
return await self._retry_request(_fetch_data)
def _get_market_type(self, code: str) -> str:
"""
根据股票代码判断市场类型
Args:
code: 股票代码
Returns:
市场类型(sh/sz)
"""
if code.startswith("sh.") or code.startswith("6"):
return "sh"
elif code.startswith("sz.") or code.startswith("0") or code.startswith("3"):
return "sz"
else:
return "other"
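
In the Baostock collector above, `login()` is called once before `_retry_request`, while `logout()` runs in the `finally` clause of every attempt. After a first failed attempt the session is therefore already logged out when the retry starts. A minimal sketch of the alternative, logging in and out inside each attempt, is shown below. `FakeSession` is a hypothetical stand-in and does not model the real Baostock API; `fetch_with_retry` is likewise illustrative.

```python
import asyncio


class FakeSession:
    """Stand-in for a login-based data API (hypothetical, for illustration only)."""

    def __init__(self, failures_before_success: int):
        self.logged_in = False
        self.remaining_failures = failures_before_success
        self.events = []

    def login(self):
        self.logged_in = True
        self.events.append("login")

    def logout(self):
        self.logged_in = False
        self.events.append("logout")

    def query(self):
        if not self.logged_in:
            raise RuntimeError("query issued without login")
        if self.remaining_failures > 0:
            self.remaining_failures -= 1
            raise ConnectionError("transient failure")
        return ["row-1", "row-2"]


async def fetch_with_retry(session, max_attempts=3, delay_seconds=0.0):
    """Retry a query, logging in and out inside every attempt."""
    for attempt in range(1, max_attempts + 1):
        session.login()
        try:
            return session.query()
        except ConnectionError:
            if attempt == max_attempts:
                return None
            await asyncio.sleep(delay_seconds)  # yields to the event loop
        finally:
            session.logout()  # runs after every attempt, success or failure
    return None


session = FakeSession(failures_before_success=2)
rows = asyncio.run(fetch_with_retry(session))
```

Keeping login and logout in the same scope as the query means each retry starts from a known session state.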

136
src/data/base_collector.py Normal file

@ -0,0 +1,136 @@
"""
Base class for data collectors.
Defines the common data collection interface and shared functionality.
"""
import abc
import time
from typing import Any, Dict, List, Optional
from loguru import logger
from src.config.settings import settings
class BaseDataCollector(abc.ABC):
"""数据采集器基类"""
def __init__(self, collector_name: str):
"""
初始化采集器
Args:
collector_name: 采集器名称
"""
self.collector_name = collector_name
self.max_retry_times = settings.data_source.max_retry_times
self.retry_delay_seconds = settings.data_source.retry_delay_seconds
@abc.abstractmethod
async def get_stock_basic_info(self) -> List[Dict[str, Any]]:
"""
获取股票基础信息
Returns:
股票基础信息列表
"""
pass
@abc.abstractmethod
async def get_daily_kline_data(
self,
stock_code: str,
start_date: str,
end_date: str
) -> List[Dict[str, Any]]:
"""
获取日K线数据
Args:
stock_code: 股票代码
start_date: 开始日期(YYYY-MM-DD)
end_date: 结束日期(YYYY-MM-DD)
Returns:
日K线数据列表
"""
pass
@abc.abstractmethod
async def get_financial_report(
self,
stock_code: str,
year: int,
quarter: int
) -> List[Dict[str, Any]]:
"""
获取财务报告数据
Args:
stock_code: 股票代码
year: 年份
quarter: 季度(1-4)
Returns:
财务报告数据列表
"""
pass
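    async def _retry_request(
        self,
        func,
        *args,
        **kwargs
    ) -> Optional[Any]:
        """
        Call an async function with retries.
        Args:
            func: Coroutine function to call
            *args: Positional arguments for func
            **kwargs: Keyword arguments for func
        Returns:
            The result of func, or None if every attempt fails
        """
        import asyncio  # local import; only `time` is imported at module level
        for attempt in range(self.max_retry_times):
            try:
                result = await func(*args, **kwargs)
                logger.info(f"{self.collector_name} request succeeded")
                return result
            except Exception as e:
                logger.warning(
                    f"{self.collector_name} attempt {attempt + 1} failed: {str(e)}"
                )
                if attempt < self.max_retry_times - 1:
                    # Non-blocking wait; time.sleep would stall the event loop
                    await asyncio.sleep(self.retry_delay_seconds)
                else:
                    logger.error(f"{self.collector_name} all retries failed")
        return None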
async def _validate_data(
self,
data: List[Dict[str, Any]],
required_fields: List[str]
) -> bool:
"""
验证数据完整性
Args:
data: 待验证的数据
required_fields: 必需字段列表
Returns:
验证结果
"""
if not data:
logger.warning(f"{self.collector_name} 数据为空")
return False
for item in data:
for field in required_fields:
if field not in item or item[field] is None:
logger.warning(f"{self.collector_name} 数据字段{field}缺失")
return False
logger.info(f"{self.collector_name} 数据验证通过")
return True


@ -0,0 +1,328 @@
"""
Data initialization service.
Implements the full initial load of stock data.
"""
import asyncio
from typing import Dict, Any, List
from datetime import datetime, date
from loguru import logger
from .data_manager import DataManager
from src.storage.database import DatabaseManager, db_manager
from src.storage.stock_repository import StockRepository
from src.config.settings import Settings
class DataInitializer:
"""数据初始化服务类"""
def __init__(self, settings: Settings):
"""
初始化数据初始化服务
Args:
settings: 系统配置
"""
self.settings = settings
self.data_manager = DataManager(settings)
self.db_manager = db_manager # 使用全局数据库管理器实例
self.repository = None
async def initialize_all_data(self) -> Dict[str, Any]:
"""
初始化所有数据
Returns:
初始化结果统计
"""
try:
logger.info("开始初始化全量股票数据...")
# 数据库连接已在DatabaseManager构造函数中自动初始化
self.repository = StockRepository(self.db_manager.get_session())
# 执行初始化步骤
results = {}
# 1. 初始化股票基础信息
logger.info("步骤1: 初始化股票基础信息")
results["stock_basic"] = await self._initialize_stock_basic_info()
# 2. 初始化历史K线数据
logger.info("步骤2: 初始化历史K线数据")
results["daily_kline"] = await self._initialize_daily_kline_data()
# 3. 初始化财务报告数据
logger.info("步骤3: 初始化财务报告数据")
results["financial_report"] = await self._initialize_financial_report_data()
# 4. 统计初始化结果
total_stats = self._calculate_total_stats(results)
logger.info(f"数据初始化完成: {total_stats}")
return {
"success": True,
"results": results,
"total_stats": total_stats,
"timestamp": datetime.now().isoformat()
}
except Exception as e:
logger.error(f"数据初始化失败: {str(e)}")
return {
"success": False,
"error": str(e),
"timestamp": datetime.now().isoformat()
}
finally:
# DatabaseManager使用连接池不需要手动关闭
pass
async def _initialize_stock_basic_info(self) -> Dict[str, Any]:
"""
初始化股票基础信息
Returns:
初始化结果
"""
try:
# 获取所有股票基础信息
logger.info("获取股票基础信息...")
stock_basic_data = await self.data_manager.get_stock_basic_info()
if not stock_basic_data:
logger.warning("未获取到股票基础信息")
return {"success": False, "error": "未获取到股票基础信息"}
# 保存到数据库
logger.info(f"保存{len(stock_basic_data)}条股票基础信息到数据库")
save_result = self.repository.save_stock_basic_info(stock_basic_data)
return {
"success": True,
"data_count": len(stock_basic_data),
"save_result": save_result
}
except Exception as e:
logger.error(f"初始化股票基础信息失败: {str(e)}")
return {"success": False, "error": str(e)}
async def _initialize_daily_kline_data(self) -> Dict[str, Any]:
"""
初始化历史K线数据
Returns:
初始化结果
"""
try:
# 获取所有股票代码
stocks = self.repository.get_stock_basic_info()
if not stocks:
logger.warning("没有股票基础信息无法获取K线数据")
return {"success": False, "error": "没有股票基础信息"}
# Start of the history window: roughly the last 3 years of data by default
end_date = date.today()
start_date = date.fromordinal(end_date.toordinal() - 3 * 365)  # avoids ValueError on Feb 29
total_kline_data = []
success_count = 0
error_count = 0
# 分批获取K线数据避免内存溢出
batch_size = 50
for i in range(0, len(stocks), batch_size):
batch_stocks = stocks[i:i + batch_size]
# 为每只股票获取K线数据
for stock in batch_stocks:
try:
logger.info(f"获取股票{stock.code}的K线数据...")
kline_data = await self.data_manager.get_daily_kline_data(
stock.code,
start_date.strftime("%Y-%m-%d"),
end_date.strftime("%Y-%m-%d")
)
if kline_data:
total_kline_data.extend(kline_data)
success_count += 1
else:
logger.warning(f"股票{stock.code}未获取到K线数据")
error_count += 1
# 小延迟避免请求过快
await asyncio.sleep(0.1)
except Exception as e:
logger.error(f"获取股票{stock.code}K线数据失败: {str(e)}")
error_count += 1
continue
# 保存K线数据到数据库
if total_kline_data:
logger.info(f"保存{len(total_kline_data)}条K线数据到数据库")
save_result = self.repository.save_daily_kline_data(total_kline_data)
else:
save_result = {"added_count": 0, "error_count": 0, "total_count": 0}
return {
"success": True,
"stock_count": len(stocks),
"success_stocks": success_count,
"error_stocks": error_count,
"kline_data_count": len(total_kline_data),
"save_result": save_result
}
except Exception as e:
logger.error(f"初始化K线数据失败: {str(e)}")
return {"success": False, "error": str(e)}
async def _initialize_financial_report_data(self) -> Dict[str, Any]:
"""
初始化财务报告数据
Returns:
初始化结果
"""
try:
# 获取所有股票代码
stocks = self.repository.get_stock_basic_info()
if not stocks:
logger.warning("没有股票基础信息,无法获取财务数据")
return {"success": False, "error": "没有股票基础信息"}
total_financial_data = []
success_count = 0
error_count = 0
# 分批获取财务数据
batch_size = 30
for i in range(0, len(stocks), batch_size):
batch_stocks = stocks[i:i + batch_size]
# 为每只股票获取财务数据
for stock in batch_stocks:
try:
logger.info(f"获取股票{stock.code}的财务报告数据...")
financial_data = await self.data_manager.get_financial_report(
stock.code
)
if financial_data:
total_financial_data.extend(financial_data)
success_count += 1
else:
logger.warning(f"股票{stock.code}未获取到财务数据")
error_count += 1
# 小延迟避免请求过快
await asyncio.sleep(0.2)
except Exception as e:
logger.error(f"获取股票{stock.code}财务数据失败: {str(e)}")
error_count += 1
continue
# 保存财务数据到数据库
if total_financial_data:
logger.info(f"保存{len(total_financial_data)}条财务数据到数据库")
save_result = self.repository.save_financial_report_data(total_financial_data)
else:
save_result = {"added_count": 0, "updated_count": 0, "error_count": 0, "total_count": 0}
return {
"success": True,
"stock_count": len(stocks),
"success_stocks": success_count,
"error_stocks": error_count,
"financial_data_count": len(total_financial_data),
"save_result": save_result
}
except Exception as e:
logger.error(f"初始化财务数据失败: {str(e)}")
return {"success": False, "error": str(e)}
def _calculate_total_stats(self, results: Dict[str, Any]) -> Dict[str, Any]:
"""
计算总统计信息
Args:
results: 各步骤结果
Returns:
总统计信息
"""
total_stats = {
"total_stocks": 0,
"total_kline_records": 0,
"total_financial_records": 0,
"success_steps": 0,
"failed_steps": 0
}
for step_name, result in results.items():
if result.get("success"):
total_stats["success_steps"] += 1
if step_name == "stock_basic":
total_stats["total_stocks"] = result.get("data_count", 0)
elif step_name == "daily_kline":
total_stats["total_kline_records"] = result.get("kline_data_count", 0)
elif step_name == "financial_report":
total_stats["total_financial_records"] = result.get("financial_data_count", 0)
else:
total_stats["failed_steps"] += 1
return total_stats
async def check_data_status(self) -> Dict[str, Any]:
"""
检查数据状态
Returns:
数据状态信息
"""
try:
# 数据库连接已在DatabaseManager构造函数中自动初始化
self.repository = StockRepository(self.db_manager.get_session())
# 查询各表数据量
stock_count = self.repository.session.query(self.repository.StockBasic).count()
kline_count = self.repository.session.query(self.repository.DailyKline).count()
financial_count = self.repository.session.query(self.repository.FinancialReport).count()
# 获取最新数据日期
latest_kline = self.repository.session.query(self.repository.DailyKline).order_by(
self.repository.DailyKline.trade_date.desc()
).first()
latest_financial = self.repository.session.query(self.repository.FinancialReport).order_by(
self.repository.FinancialReport.report_date.desc()
).first()
status = {
"stock_basic_count": stock_count,
"daily_kline_count": kline_count,
"financial_report_count": financial_count,
"latest_kline_date": latest_kline.trade_date if latest_kline else None,
"latest_financial_date": latest_financial.report_date if latest_financial else None,
"check_time": datetime.now().isoformat()
}
logger.info(f"数据状态检查完成: {status}")
return status
except Exception as e:
logger.error(f"检查数据状态失败: {str(e)}")
return {"error": str(e)}
finally:
# DatabaseManager使用连接池不需要手动关闭
pass
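The batching loops in `_initialize_daily_kline_data` and `_initialize_financial_report_data` slice the stock list into batches but still collect every row in one list before saving, so peak memory grows with the whole data set. A minimal sketch of flushing each batch as soon as it is fetched might look like the following. It is synchronous for brevity, and `chunked`, `load_in_batches`, `fetch`, and `save` are hypothetical names for illustration, not part of this repository.

```python
from typing import Any, Callable, Dict, Iterator, List

Row = Dict[str, Any]


def chunked(items: List[str], size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most `size` items."""
    if size <= 0:
        raise ValueError("size must be positive")
    for start in range(0, len(items), size):
        yield items[start:start + size]


def load_in_batches(
    codes: List[str],
    fetch: Callable[[str], List[Row]],   # hypothetical: returns rows for one stock
    save: Callable[[List[Row]], None],   # hypothetical: persists one batch of rows
    batch_size: int = 50,
) -> Dict[str, int]:
    """Fetch rows per stock and persist them batch by batch."""
    stats = {"success_stocks": 0, "error_stocks": 0, "saved_rows": 0}
    for batch in chunked(codes, batch_size):
        rows: List[Row] = []
        for code in batch:
            try:
                data = fetch(code)
            except Exception:
                # A failing stock is counted and skipped, as in the loops above.
                stats["error_stocks"] += 1
                continue
            if data:
                rows.extend(data)
                stats["success_stocks"] += 1
            else:
                stats["error_stocks"] += 1
        if rows:
            # Only one batch of rows is held in memory at a time.
            save(rows)
            stats["saved_rows"] += len(rows)
    return stats
```

With this shape, memory use is bounded by `batch_size` rather than by the number of stocks, at the cost of one save call per batch.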

267
src/data/data_manager.py Normal file

@ -0,0 +1,267 @@
"""
数据采集管理器
统一管理多个数据源提供数据采集合并和去重功能
"""
from typing import Any, Dict, List, Optional
from loguru import logger
from .akshare_collector import AKshareCollector
from .baostock_collector import BaostockCollector
from .base_collector import BaseDataCollector
class DataManager:
"""数据采集管理器"""
def __init__(self, settings=None):
"""初始化数据管理器
Args:
settings: 系统配置可选
"""
self.settings = settings
self.collectors: List[BaseDataCollector] = [
AKshareCollector(),
BaostockCollector()
]
logger.info("数据管理器初始化完成")
async def get_stock_basic_info(self) -> List[Dict[str, Any]]:
"""
获取股票基础信息多源合并
Returns:
合并后的股票基础信息列表
"""
logger.info("开始多源获取股票基础信息")
all_data = []
for collector in self.collectors:
try:
data = await collector.get_stock_basic_info()
if data:
all_data.extend(data)
logger.info(f"{collector.collector_name} 获取到{len(data)}条数据")
except Exception as e:
logger.error(f"{collector.collector_name} 获取基础信息失败: {str(e)}")
# 数据去重和合并
merged_data = self._merge_stock_basic_data(all_data)
logger.info(f"合并后共获取{len(merged_data)}只股票基础信息")
return merged_data
async def get_daily_kline_data(
self,
stock_code: str,
start_date: str,
end_date: str
) -> List[Dict[str, Any]]:
"""
获取日K线数据多源合并
Args:
stock_code: 股票代码
start_date: 开始日期(YYYY-MM-DD)
end_date: 结束日期(YYYY-MM-DD)
Returns:
合并后的日K线数据列表
"""
logger.info(f"开始多源获取{stock_code}的K线数据")
all_data = []
for collector in self.collectors:
try:
data = await collector.get_daily_kline_data(
stock_code, start_date, end_date
)
if data:
all_data.extend(data)
logger.info(f"{collector.collector_name} 获取到{len(data)}条数据")
except Exception as e:
logger.error(f"{collector.collector_name} 获取K线数据失败: {str(e)}")
# 数据去重和合并
merged_data = self._merge_kline_data(all_data)
logger.info(f"合并后共获取{len(merged_data)}条K线数据")
return merged_data
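    async def get_financial_report(
        self,
        stock_code: str,
        year: Optional[int] = None,
        quarter: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Fetch financial report data, merged across sources.

        Args:
            stock_code: Stock code.
            year: Report year; defaults to the year of the most recently completed quarter.
            quarter: Quarter (1-4); defaults to the most recently completed quarter.

        Returns:
            Merged list of financial report records.
        """
        if year is None or quarter is None:
            # The initializer and scheduler call this method with only a stock code,
            # so fall back to the most recently completed quarter.
            from datetime import date
            today = date.today()
            last_quarter = (today.month - 1) // 3  # 0 means Q4 of the previous year
            default_year = today.year if last_quarter else today.year - 1
            default_quarter = last_quarter or 4
            year = default_year if year is None else year
            quarter = default_quarter if quarter is None else quarter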
logger.info(f"开始多源获取{stock_code}的财务报告")
all_data = []
for collector in self.collectors:
try:
data = await collector.get_financial_report(
stock_code, year, quarter
)
if data:
all_data.extend(data)
logger.info(f"{collector.collector_name} 获取到{len(data)}条数据")
except Exception as e:
logger.error(f"{collector.collector_name} 获取财务报告失败: {str(e)}")
# 数据去重和合并
merged_data = self._merge_financial_data(all_data)
logger.info(f"合并后共获取{len(merged_data)}条财务数据")
return merged_data
def _merge_stock_basic_data(
self,
data_list: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""
合并股票基础数据
Args:
data_list: 待合并的数据列表
Returns:
合并后的数据列表
"""
merged_dict = {}
for data in data_list:
code = data.get("code")
if not code:
continue
# 如果该股票代码已存在,合并信息
if code in merged_dict:
existing = merged_dict[code]
# 优先使用更详细的信息
for key, value in data.items():
if key not in existing or (
value and not existing[key]
):
existing[key] = value
else:
merged_dict[code] = data
return list(merged_dict.values())
    def _merge_kline_data(
        self,
        data_list: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Merge K-line data.

        Args:
            data_list: Records to merge.

        Returns:
            Merged records.
        """
        merged_dict = {}
        for data in data_list:
            code = data.get("code")
            trade_date = data.get("date")
            # An f-string key is always truthy, so validate its parts before building it.
            if not code or not trade_date:
                continue
            key = f"{code}_{trade_date}"
            # If this date already has a record, fill in the better-quality values.
            if key in merged_dict:
                existing = merged_dict[key]
                # Prefer non-zero values; None counts as missing as well.
                for field in ["open", "high", "low", "close", "volume", "amount"]:
                    if not existing.get(field) and data.get(field):
                        existing[field] = data[field]
            else:
                merged_dict[key] = data
        return list(merged_dict.values())
    def _merge_financial_data(
        self,
        data_list: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Merge financial data.

        Args:
            data_list: Records to merge.

        Returns:
            Merged records.
        """
        merged_dict = {}
        for data in data_list:
            code = data.get("code")
            report_date = data.get("report_date")
            # An f-string key is always truthy, so validate its parts before building it.
            if not code or not report_date:
                continue
            key = f"{code}_{report_date}"
            # If this report period already has a record, fill in the better-quality values.
            if key in merged_dict:
                existing = merged_dict[key]
                # Prefer non-zero values; None counts as missing as well.
                for field in ["eps", "net_profit", "revenue", "total_assets"]:
                    if not existing.get(field) and data.get(field):
                        existing[field] = data[field]
            else:
                merged_dict[key] = data
        return list(merged_dict.values())
async def validate_data_quality(
self,
data: List[Dict[str, Any]],
data_type: str
) -> Dict[str, Any]:
"""
验证数据质量
Args:
data: 待验证的数据
data_type: 数据类型
Returns:
质量验证结果
"""
if not data:
return {
"valid": False,
"message": "数据为空",
"completeness": 0.0
}
# 计算数据完整性
total_fields = 0
missing_fields = 0
for item in data:
for key, value in item.items():
total_fields += 1
if value is None or value == "":
missing_fields += 1
completeness = (
(total_fields - missing_fields) / total_fields * 100
) if total_fields > 0 else 0
return {
"valid": completeness >= 80.0, # 完整性达到80%认为有效
"message": f"数据完整性: {completeness:.2f}%",
"completeness": completeness,
"total_records": len(data)
}
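The three `_merge_*` methods in `DataManager` share one pattern: group records by a composite key, keep the first record seen, and fill its empty fields from later records. A generic sketch of that pattern, under the assumption that records are plain dicts, could look like this. `merge_records`, `key_fields`, and `fill_fields` are hypothetical names used only for illustration.

```python
from typing import Any, Dict, Iterable, List, Sequence, Tuple

Record = Dict[str, Any]


def merge_records(
    records: Iterable[Record],
    key_fields: Sequence[str],
    fill_fields: Sequence[str],
) -> List[Record]:
    """Deduplicate records by key, filling empty fields from later duplicates."""
    merged: Dict[Tuple[Any, ...], Record] = {}
    for record in records:
        # A tuple key avoids string collisions and makes missing parts easy to detect.
        key = tuple(record.get(field) for field in key_fields)
        if any(part is None or part == "" for part in key):
            continue
        existing = merged.get(key)
        if existing is None:
            # Copy so the caller's input records are never mutated.
            merged[key] = dict(record)
            continue
        for field in fill_fields:
            # Treat None, 0, and "" as missing and take the first usable value.
            if not existing.get(field) and record.get(field):
                existing[field] = record[field]
    return list(merged.values())
```

For K-line data the call would be `merge_records(rows, ("code", "date"), ("open", "high", "low", "close", "volume", "amount"))`. Whether a zero volume should really count as missing is a design choice the source leaves open.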

622
src/data/data_processor.py Normal file

@ -0,0 +1,622 @@
"""
数据处理和清洗模块
实现不同来源数据的格式统一清洗与校验
"""
import re
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime, date
from loguru import logger
class DataProcessor:
"""数据处理和清洗类"""
def __init__(self):
"""初始化数据处理器"""
self.valid_markets = ["主板", "中小板", "创业板", "科创板"]
self.valid_industries = self._load_valid_industries()
self.valid_areas = self._load_valid_areas()
def process_stock_basic_info(
self,
raw_data: List[Dict[str, Any]],
source: str
) -> List[Dict[str, Any]]:
"""
处理股票基础信息
Args:
raw_data: 原始数据
source: 数据源标识
Returns:
处理后的数据
"""
try:
processed_data = []
for item in raw_data:
try:
# 数据清洗和验证
cleaned_item = self._clean_stock_basic_item(item, source)
# 数据验证
if self._validate_stock_basic_item(cleaned_item):
processed_data.append(cleaned_item)
else:
logger.warning(f"股票基础信息验证失败: {item.get('code', 'unknown')}")
except Exception as e:
logger.error(f"处理股票基础信息失败: {str(e)}, 数据: {item}")
continue
            logger.info(f"Stock basic info processed: {len(raw_data)} raw records, {len(processed_data)} valid")
return processed_data
except Exception as e:
logger.error(f"处理股票基础信息异常: {str(e)}")
return []
def process_daily_kline_data(
self,
raw_data: List[Dict[str, Any]],
source: str
) -> List[Dict[str, Any]]:
"""
处理日K线数据
Args:
raw_data: 原始数据
source: 数据源标识
Returns:
处理后的数据
"""
try:
processed_data = []
for item in raw_data:
try:
# 数据清洗和验证
cleaned_item = self._clean_kline_item(item, source)
# 数据验证
if self._validate_kline_item(cleaned_item):
processed_data.append(cleaned_item)
else:
logger.warning(f"K线数据验证失败: {item.get('code', 'unknown')} {item.get('date', 'unknown')}")
except Exception as e:
logger.error(f"处理K线数据失败: {str(e)}, 数据: {item}")
continue
            logger.info(f"K-line data processed: {len(raw_data)} raw records, {len(processed_data)} valid")
return processed_data
except Exception as e:
logger.error(f"处理K线数据异常: {str(e)}")
return []
def process_financial_report_data(
self,
raw_data: List[Dict[str, Any]],
source: str
) -> List[Dict[str, Any]]:
"""
处理财务报告数据
Args:
raw_data: 原始数据
source: 数据源标识
Returns:
处理后的数据
"""
try:
processed_data = []
for item in raw_data:
try:
# 数据清洗和验证
cleaned_item = self._clean_financial_item(item, source)
# 数据验证
if self._validate_financial_item(cleaned_item):
processed_data.append(cleaned_item)
else:
logger.warning(f"财务数据验证失败: {item.get('code', 'unknown')} {item.get('report_date', 'unknown')}")
except Exception as e:
logger.error(f"处理财务数据失败: {str(e)}, 数据: {item}")
continue
            logger.info(f"Financial data processed: {len(raw_data)} raw records, {len(processed_data)} valid")
return processed_data
except Exception as e:
logger.error(f"处理财务数据异常: {str(e)}")
return []
def _clean_stock_basic_item(
self,
item: Dict[str, Any],
source: str
) -> Dict[str, Any]:
"""
清洗股票基础信息项
Args:
item: 原始数据项
source: 数据源标识
Returns:
清洗后的数据项
"""
cleaned = {}
# 标准化股票代码
if "code" in item:
cleaned["code"] = self._standardize_stock_code(item["code"])
# 标准化股票名称
if "name" in item:
cleaned["name"] = self._standardize_stock_name(item["name"])
# 标准化市场类型
if "market" in item:
cleaned["market"] = self._standardize_market(item["market"])
# 标准化行业
if "industry" in item:
cleaned["industry"] = self._standardize_industry(item["industry"])
# 标准化地区
if "area" in item:
cleaned["area"] = self._standardize_area(item["area"])
# 标准化上市日期
if "ipo_date" in item:
cleaned["ipo_date"] = self._standardize_date(item["ipo_date"])
# 添加数据源信息
cleaned["data_source"] = source
cleaned["processed_time"] = datetime.now()
return cleaned
def _clean_kline_item(
self,
item: Dict[str, Any],
source: str
) -> Dict[str, Any]:
"""
清洗K线数据项
Args:
item: 原始数据项
source: 数据源标识
Returns:
清洗后的数据项
"""
cleaned = {}
# 标准化股票代码
if "code" in item:
cleaned["code"] = self._standardize_stock_code(item["code"])
# 标准化交易日期
if "date" in item:
cleaned["date"] = self._standardize_date(item["date"])
# 标准化价格数据
price_fields = ["open", "high", "low", "close"]
for field in price_fields:
if field in item:
cleaned[field] = self._standardize_price(item[field])
# 标准化成交量
if "volume" in item:
cleaned["volume"] = self._standardize_volume(item["volume"])
# 标准化成交额
if "amount" in item:
cleaned["amount"] = self._standardize_amount(item["amount"])
# 添加数据源信息
cleaned["data_source"] = source
cleaned["processed_time"] = datetime.now()
return cleaned
def _clean_financial_item(
self,
item: Dict[str, Any],
source: str
) -> Dict[str, Any]:
"""
清洗财务数据项
Args:
item: 原始数据项
source: 数据源标识
Returns:
清洗后的数据项
"""
cleaned = {}
# 标准化股票代码
if "code" in item:
cleaned["code"] = self._standardize_stock_code(item["code"])
# 标准化报告日期
if "report_date" in item:
cleaned["report_date"] = self._standardize_report_date(item["report_date"])
# 标准化财务指标
financial_fields = ["eps", "net_profit", "revenue", "total_assets"]
for field in financial_fields:
if field in item:
cleaned[field] = self._standardize_financial_value(item[field])
# 添加数据源信息
cleaned["data_source"] = source
cleaned["processed_time"] = datetime.now()
return cleaned
def _validate_stock_basic_item(self, item: Dict[str, Any]) -> bool:
"""
验证股票基础信息项
Args:
item: 数据项
Returns:
是否有效
"""
try:
# 检查必需字段
required_fields = ["code", "name"]
for field in required_fields:
if not item.get(field):
return False
# 验证股票代码格式
if not self._is_valid_stock_code(item["code"]):
return False
# 验证市场类型
if "market" in item and item["market"] not in self.valid_markets:
return False
# 验证上市日期
if "ipo_date" in item and item["ipo_date"]:
if item["ipo_date"] > date.today():
return False
return True
except Exception:
return False
def _validate_kline_item(self, item: Dict[str, Any]) -> bool:
"""
验证K线数据项
Args:
item: 数据项
Returns:
是否有效
"""
try:
# 检查必需字段
required_fields = ["code", "date", "open", "high", "low", "close"]
for field in required_fields:
if not item.get(field):
return False
# 验证股票代码格式
if not self._is_valid_stock_code(item["code"]):
return False
# 验证价格合理性
if not self._is_valid_price_data(item):
return False
# 验证交易日期
if item["date"] > date.today():
return False
return True
except Exception:
return False
def _validate_financial_item(self, item: Dict[str, Any]) -> bool:
"""
验证财务数据项
Args:
item: 数据项
Returns:
是否有效
"""
try:
# 检查必需字段
required_fields = ["code", "report_date"]
for field in required_fields:
if not item.get(field):
return False
# 验证股票代码格式
if not self._is_valid_stock_code(item["code"]):
return False
# 验证报告日期
if item["report_date"] > date.today():
return False
# 验证至少有一个财务指标
financial_fields = ["eps", "net_profit", "revenue", "total_assets"]
has_financial_data = any(item.get(field) for field in financial_fields)
if not has_financial_data:
return False
return True
except Exception:
return False
def _standardize_stock_code(self, code: str) -> str:
"""标准化股票代码"""
try:
# 移除空格和特殊字符
code = re.sub(r"[^0-9a-zA-Z]", "", str(code))
# 统一为6位数字格式
if code.isdigit():
return code.zfill(6)
return code
except Exception:
return str(code)
def _standardize_stock_name(self, name: str) -> str:
"""标准化股票名称"""
try:
# 移除多余空格和特殊字符
name = re.sub(r"\s+", " ", str(name).strip())
return name
except Exception:
return str(name)
def _standardize_market(self, market: str) -> str:
"""标准化市场类型"""
try:
market = str(market).strip()
# 映射常见市场名称
market_mapping = {
"sh": "主板",
"sz": "主板",
"主板": "主板",
"中小板": "中小板",
"创业板": "创业板",
"科创板": "科创板",
"SH": "主板",
"SZ": "主板"
}
return market_mapping.get(market, market)
except Exception:
return str(market)
def _standardize_industry(self, industry: str) -> str:
"""标准化行业"""
try:
industry = str(industry).strip()
# 如果行业在有效列表中,直接返回
if industry in self.valid_industries:
return industry
# 否则返回"其他"
return "其他"
except Exception:
return "其他"
def _standardize_area(self, area: str) -> str:
"""标准化地区"""
try:
area = str(area).strip()
# 如果地区在有效列表中,直接返回
if area in self.valid_areas:
return area
# 否则返回"其他"
return "其他"
except Exception:
return "其他"
def _standardize_date(self, date_str: str) -> Optional[date]:
"""标准化日期"""
try:
if not date_str:
return None
# 尝试多种日期格式
formats = ["%Y-%m-%d", "%Y/%m/%d", "%Y%m%d", "%Y-%m-%d %H:%M:%S"]
for fmt in formats:
try:
return datetime.strptime(str(date_str).strip(), fmt).date()
except ValueError:
continue
return None
except Exception:
return None
    def _standardize_report_date(self, report_date: str) -> Optional[date]:
        """Standardize a report date."""
        try:
            if not report_date:
                return None
            report_date = str(report_date).strip()
            # Handle quarterly formats such as "2023-Q1".
            if "Q" in report_date:
                year, quarter = report_date.split("-")
                # Quarter end dates; an unknown quarter raises KeyError and yields None.
                quarter_ends = {1: (3, 31), 2: (6, 30), 3: (9, 30), 4: (12, 31)}
                month, day = quarter_ends[int(quarter.replace("Q", ""))]
                return date(int(year), month, day)
            # Handle standard date formats.
            return self._standardize_date(report_date)
        except Exception:
            return None
def _standardize_price(self, price: Any) -> Optional[float]:
"""标准化价格"""
try:
if price is None:
return None
price_str = str(price).strip()
# 移除货币符号和逗号
price_str = re.sub(r"[^0-9.-]", "", price_str)
if not price_str:
return None
return round(float(price_str), 4)
except Exception:
return None
def _standardize_volume(self, volume: Any) -> Optional[int]:
"""标准化成交量"""
try:
if volume is None:
return None
volume_str = str(volume).strip()
# 移除逗号
volume_str = re.sub(r",", "", volume_str)
if not volume_str:
return None
return int(float(volume_str))
except Exception:
return None
def _standardize_amount(self, amount: Any) -> Optional[float]:
"""标准化成交额"""
try:
if amount is None:
return None
amount_str = str(amount).strip()
# 移除货币符号和逗号
amount_str = re.sub(r"[^0-9.-]", "", amount_str)
if not amount_str:
return None
return round(float(amount_str), 2)
except Exception:
return None
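    def _standardize_financial_value(self, value: Any) -> Optional[float]:
        """Standardize a financial metric value, scaling 万元/亿元 amounts to yuan."""
        try:
            if value is None:
                return None
            # Remove thousands separators.
            value_str = str(value).strip().replace(",", "")
            # Scale by the unit suffix so values with different units stay comparable.
            multiplier = 1.0
            if value_str.endswith("亿元"):
                multiplier = 1e8
                value_str = value_str[:-2]
            elif value_str.endswith("万元"):
                multiplier = 1e4
                value_str = value_str[:-2]
            if not value_str:
                return None
            # Convert to float.
            return float(value_str) * multiplier
        except Exception:
            return None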
    def _is_valid_stock_code(self, code: str) -> bool:
        """Validate the stock code format."""
        try:
            if not code:
                return False
            # Must be exactly 6 digits.
            if len(code) != 6 or not code.isdigit():
                return False
            # Check the prefix: 6 for Shanghai, 0 for Shenzhen, 3 for ChiNext (创业板).
            if not code.startswith(("6", "0", "3")):
                return False
            return True
        except Exception:
            return False
def _is_valid_price_data(self, item: Dict[str, Any]) -> bool:
"""验证价格数据合理性"""
try:
open_price = item.get("open")
high_price = item.get("high")
low_price = item.get("low")
close_price = item.get("close")
# 检查价格是否为正数
if any(price is not None and price <= 0 for price in [open_price, high_price, low_price, close_price]):
return False
# 检查价格关系high >= open, high >= close, low <= open, low <= close
if high_price < open_price or high_price < close_price:
return False
if low_price > open_price or low_price > close_price:
return False
return True
except Exception:
return False
def _load_valid_industries(self) -> List[str]:
"""加载有效行业列表"""
return [
"农林牧渔", "采掘", "化工", "钢铁", "有色金属", "电子", "家用电器", "食品饮料",
"纺织服装", "轻工制造", "医药生物", "公用事业", "交通运输", "房地产", "商业贸易",
"休闲服务", "综合", "建筑材料", "建筑装饰", "电气设备", "国防军工", "计算机",
"传媒", "通信", "银行", "非银金融", "汽车", "机械设备", "其他"
]
def _load_valid_areas(self) -> List[str]:
"""加载有效地区列表"""
return [
"北京", "上海", "天津", "重庆", "河北", "山西", "内蒙古", "辽宁", "吉林", "黑龙江",
"江苏", "浙江", "安徽", "福建", "江西", "山东", "河南", "湖北", "湖南", "广东",
"广西", "海南", "四川", "贵州", "云南", "西藏", "陕西", "甘肃", "青海", "宁夏",
"新疆", "台湾", "香港", "澳门", "其他"
]
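`_is_valid_price_data` checks that prices are positive and that the high and low bracket the open and close. As a compact restatement of that rule, here is a standalone sketch. `is_consistent_ohlc` is a hypothetical helper, not a function in this module, and unlike the method above it rejects missing prices explicitly instead of relying on a caught `TypeError`.

```python
from typing import Optional


def is_consistent_ohlc(
    open_: Optional[float],
    high: Optional[float],
    low: Optional[float],
    close: Optional[float],
) -> bool:
    """Return True when all four prices are positive and high/low bracket open/close."""
    prices = (open_, high, low, close)
    # Missing or non-positive prices can never form a valid bar.
    if any(price is None or price <= 0 for price in prices):
        return False
    # low <= open, close <= high must hold for a well-formed daily bar.
    return low <= min(open_, close) and high >= max(open_, close)
```

A bar such as open 10, high 10.2, low 9.5, close 10.5 fails the check because the close exceeds the high.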

352
src/main.py Normal file

@ -0,0 +1,352 @@
"""
A股行情分析与量化交易系统主程序
提供命令行接口支持数据初始化定时任务管理等功能
"""
import asyncio
import argparse
import sys
from pathlib import Path
from loguru import logger
from .config.settings import Settings
from .data.data_initializer import DataInitializer
from .scheduler.task_scheduler import TaskScheduler
class StockAnalysisSystem:
"""A股行情分析与量化交易系统主类"""
def __init__(self, config_path: str = None):
"""
初始化系统
Args:
config_path: 配置文件路径
"""
self.config_path = config_path
self.settings = self._load_settings()
self.data_initializer = None
self.task_scheduler = None
self._setup_logging()
def _load_settings(self) -> Settings:
"""
加载系统配置
Returns:
配置对象
"""
try:
if self.config_path:
return Settings(_env_file=self.config_path)
else:
return Settings()
except Exception as e:
logger.error(f"加载配置失败: {str(e)}")
raise
def _setup_logging(self):
"""配置日志系统"""
log_path = Path("logs")
log_path.mkdir(exist_ok=True)
logger.add(
log_path / "stock_system.log",
rotation="10 MB",
retention="30 days",
level=self.settings.log.log_level,
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}"
)
async def initialize_data(self) -> dict:
"""
初始化全量股票数据
Returns:
初始化结果
"""
try:
logger.info("开始初始化全量股票数据...")
self.data_initializer = DataInitializer(self.settings)
result = await self.data_initializer.initialize_all_data()
if result["success"]:
logger.info("数据初始化完成")
else:
logger.error(f"数据初始化失败: {result.get('error')}")
return result
except Exception as e:
logger.error(f"数据初始化异常: {str(e)}")
return {"success": False, "error": str(e)}
async def start_scheduler(self) -> bool:
"""
启动定时任务调度器
Returns:
启动是否成功
"""
try:
logger.info("启动定时任务调度器...")
self.task_scheduler = TaskScheduler(self.settings)
success = await self.task_scheduler.start_scheduler()
if success:
logger.info("定时任务调度器启动成功")
# 显示定时任务信息
jobs = self.task_scheduler.get_scheduled_jobs()
logger.info("已配置的定时任务:")
for job in jobs:
logger.info(f" - {job['name']}: {job['trigger']}")
logger.info(f" 下次执行时间: {job['next_run_time']}")
return True
else:
logger.error("定时任务调度器启动失败")
return False
except Exception as e:
logger.error(f"启动定时任务调度器异常: {str(e)}")
return False
async def stop_scheduler(self) -> bool:
"""
停止定时任务调度器
Returns:
停止是否成功
"""
try:
if not self.task_scheduler:
logger.warning("定时任务调度器未启动")
return True
success = await self.task_scheduler.stop_scheduler()
if success:
logger.info("定时任务调度器停止成功")
else:
logger.error("定时任务调度器停止失败")
return success
except Exception as e:
logger.error(f"停止定时任务调度器异常: {str(e)}")
return False
async def check_data_status(self) -> dict:
"""
检查数据状态
Returns:
数据状态信息
"""
try:
logger.info("检查数据状态...")
self.data_initializer = DataInitializer(self.settings)
status = await self.data_initializer.check_data_status()
logger.info("数据状态检查完成")
return status
except Exception as e:
logger.error(f"检查数据状态异常: {str(e)}")
return {"error": str(e)}
async def manual_update_daily_kline(self) -> dict:
"""
手动更新每日K线数据
Returns:
更新结果
"""
try:
logger.info("手动更新每日K线数据...")
self.task_scheduler = TaskScheduler(self.settings)
await self.task_scheduler.start_scheduler()
result = await self.task_scheduler.manual_update_daily_kline()
await self.task_scheduler.stop_scheduler()
if result["success"]:
logger.info("手动更新每日K线数据完成")
else:
logger.error(f"手动更新每日K线数据失败: {result.get('error')}")
return result
except Exception as e:
logger.error(f"手动更新每日K线数据异常: {str(e)}")
return {"success": False, "error": str(e)}
async def manual_update_financial_data(self) -> dict:
"""
手动更新财务数据
Returns:
更新结果
"""
try:
logger.info("手动更新财务数据...")
self.task_scheduler = TaskScheduler(self.settings)
await self.task_scheduler.start_scheduler()
result = await self.task_scheduler.manual_update_financial_data()
await self.task_scheduler.stop_scheduler()
if result["success"]:
logger.info("手动更新财务数据完成")
else:
logger.error(f"手动更新财务数据失败: {result.get('error')}")
return result
except Exception as e:
logger.error(f"手动更新财务数据异常: {str(e)}")
return {"success": False, "error": str(e)}
async def main():
"""主函数"""
parser = argparse.ArgumentParser(
description="A股行情分析与量化交易系统",
formatter_class=argparse.RawDescriptionHelpFormatter,
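        epilog="""
Usage examples (run as a module from the project root, since this file uses relative imports):
  python -m src.main init               # Initialize the full data set
  python -m src.main scheduler start    # Start the scheduled tasks
  python -m src.main scheduler stop     # Stop the scheduled tasks
  python -m src.main status             # Check data status
  python -m src.main update kline       # Manually update K-line data
  python -m src.main update financial   # Manually update financial data
"""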
)
parser.add_argument(
"--config",
"-c",
help="配置文件路径",
default=None
)
subparsers = parser.add_subparsers(
dest="command",
help="可用命令"
)
# init 命令
init_parser = subparsers.add_parser(
"init",
help="初始化全量股票数据"
)
# scheduler 命令
scheduler_parser = subparsers.add_parser(
"scheduler",
help="定时任务管理"
)
scheduler_parser.add_argument(
"action",
choices=["start", "stop", "status"],
help="定时任务操作"
)
# status 命令
status_parser = subparsers.add_parser(
"status",
help="检查系统状态"
)
# update 命令
update_parser = subparsers.add_parser(
"update",
help="手动更新数据"
)
update_parser.add_argument(
"data_type",
choices=["kline", "financial"],
help="数据类型"
)
args = parser.parse_args()
if not args.command:
parser.print_help()
return
system = StockAnalysisSystem(args.config)
try:
if args.command == "init":
result = await system.initialize_data()
print(f"数据初始化结果: {result}")
elif args.command == "scheduler":
if args.action == "start":
success = await system.start_scheduler()
if success:
print("定时任务调度器启动成功")
print("系统将在后台运行按Ctrl+C退出")
# 保持程序运行
try:
while True:
await asyncio.sleep(1)
except KeyboardInterrupt:
print("\n正在停止系统...")
await system.stop_scheduler()
print("系统已停止")
else:
print("定时任务调度器启动失败")
sys.exit(1)
elif args.action == "stop":
success = await system.stop_scheduler()
if success:
print("定时任务调度器停止成功")
else:
print("定时任务调度器停止失败")
sys.exit(1)
elif args.action == "status":
jobs = system.task_scheduler.get_scheduled_jobs() if system.task_scheduler else []
print("定时任务状态:")
for job in jobs:
print(f" {job['name']}: {job['trigger']}")
print(f" 下次执行时间: {job['next_run_time']}")
elif args.command == "status":
status = await system.check_data_status()
print("系统数据状态:")
for key, value in status.items():
print(f" {key}: {value}")
elif args.command == "update":
if args.data_type == "kline":
result = await system.manual_update_daily_kline()
print(f"K线数据更新结果: {result}")
elif args.data_type == "financial":
result = await system.manual_update_financial_data()
print(f"财务数据更新结果: {result}")
except KeyboardInterrupt:
print("\n操作被用户中断")
except Exception as e:
logger.error(f"系统运行异常: {str(e)}")
print(f"系统运行异常: {str(e)}")
sys.exit(1)
if __name__ == "__main__":
asyncio.run(main())
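The command-line layout in `main()` uses one top-level subparser per command, with a positional choice argument under `scheduler` and `update`. The following stripped-down sketch shows only that parsing and dispatch shape, without the async system calls. `build_parser` and `describe` are hypothetical names, and the `stock-cli` program name is an assumption for illustration.

```python
import argparse
from typing import List


def build_parser() -> argparse.ArgumentParser:
    """Build a parser mirroring the init/scheduler/status/update commands."""
    parser = argparse.ArgumentParser(prog="stock-cli")
    parser.add_argument("--config", "-c", default=None)
    subparsers = parser.add_subparsers(dest="command")
    subparsers.add_parser("init")
    subparsers.add_parser("status")
    scheduler = subparsers.add_parser("scheduler")
    scheduler.add_argument("action", choices=["start", "stop", "status"])
    update = subparsers.add_parser("update")
    update.add_argument("data_type", choices=["kline", "financial"])
    return parser


def describe(argv: List[str]) -> str:
    """Parse argv and return a short label for the selected command."""
    args = build_parser().parse_args(argv)
    if args.command == "scheduler":
        return f"scheduler:{args.action}"
    if args.command == "update":
        return f"update:{args.data_type}"
    # With no subcommand, `command` is None and the caller would print help.
    return args.command or "help"
```

Global options such as `--config` must come before the subcommand, for example `describe(["-c", "prod.env", "scheduler", "start"])`.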


@ -0,0 +1 @@
# Scheduled task module


@ -0,0 +1,440 @@
"""
定时任务调度器
实现每日增量数据的自动获取与更新
"""
import asyncio
from datetime import datetime, date, timedelta
from typing import Dict, Any, List
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from loguru import logger
from src.data.data_manager import DataManager
from src.storage.database import DatabaseManager
from src.storage.stock_repository import StockRepository
from src.config.settings import Settings
class TaskScheduler:
"""定时任务调度器类"""
def __init__(self, settings: Settings):
"""
初始化任务调度器
Args:
settings: 系统配置
"""
self.settings = settings
self.scheduler = AsyncIOScheduler()
self.data_manager = DataManager(settings)
self.db_manager = DatabaseManager()
self.repository = None
self.is_running = False
async def start_scheduler(self) -> bool:
"""
启动定时任务调度器
Returns:
启动是否成功
"""
try:
if self.is_running:
logger.warning("调度器已在运行中")
return True
# 数据库连接已在DatabaseManager构造函数中自动初始化
self.repository = StockRepository(self.db_manager.get_session())
# 配置定时任务
self._configure_scheduled_tasks()
# 启动调度器
self.scheduler.start()
self.is_running = True
logger.info("定时任务调度器启动成功")
return True
except Exception as e:
logger.error(f"启动定时任务调度器失败: {str(e)}")
return False
async def stop_scheduler(self) -> bool:
"""
停止定时任务调度器
Returns:
停止是否成功
"""
try:
if not self.is_running:
logger.warning("调度器未在运行中")
return True
self.scheduler.shutdown()
self.is_running = False
if self.db_manager:
self.db_manager.close()
logger.info("定时任务调度器停止成功")
return True
except Exception as e:
logger.error(f"停止定时任务调度器失败: {str(e)}")
return False
def _configure_scheduled_tasks(self):
"""配置定时任务"""
# 1. 每日收盘后更新K线数据工作日16:00执行
self.scheduler.add_job(
self._update_daily_kline_data,
trigger=CronTrigger(
hour=16,
minute=0,
day_of_week='mon-fri'
),
id='daily_kline_update',
name='每日K线数据更新',
replace_existing=True
)
# 2. 每周更新财务数据周六上午10:00执行
self.scheduler.add_job(
self._update_financial_data,
trigger=CronTrigger(
hour=10,
minute=0,
day_of_week='sat'
),
id='weekly_financial_update',
name='每周财务数据更新',
replace_existing=True
)
# 3. 每月更新股票基础信息每月第一个周六上午9:00执行
self.scheduler.add_job(
self._update_stock_basic_info,
trigger=CronTrigger(
hour=9,
minute=0,
day='1st sat'
),
id='monthly_stock_basic_update',
name='每月股票基础信息更新',
replace_existing=True
)
# 4. 每日健康检查工作日9:00执行
self.scheduler.add_job(
self._health_check,
trigger=CronTrigger(
hour=9,
minute=0,
day_of_week='mon-fri'
),
id='daily_health_check',
name='每日健康检查',
replace_existing=True
)
logger.info("定时任务配置完成")
async def _update_daily_kline_data(self):
"""更新每日K线数据"""
try:
logger.info("开始执行每日K线数据更新任务...")
# 获取所有股票代码
stocks = self.repository.get_stock_basic_info()
if not stocks:
logger.warning("没有股票基础信息无法更新K线数据")
return
# 计算更新日期默认更新最近3个交易日的数据
end_date = date.today()
start_date = self._get_previous_trading_day(end_date, days_back=3)
if not start_date:
logger.warning("无法确定有效的交易日,跳过本次更新")
return
updated_count = 0
error_count = 0
total_kline_data = []
# 分批更新K线数据
batch_size = 50
for i in range(0, len(stocks), batch_size):
batch_stocks = stocks[i:i + batch_size]
for stock in batch_stocks:
try:
logger.info(f"更新股票{stock.code}的K线数据...")
kline_data = await self.data_manager.get_daily_kline_data(
stock.code,
start_date.strftime("%Y-%m-%d"),
end_date.strftime("%Y-%m-%d")
)
if kline_data:
total_kline_data.extend(kline_data)
updated_count += 1
else:
logger.warning(f"股票{stock.code}未获取到K线数据")
error_count += 1
await asyncio.sleep(0.1)
except Exception as e:
logger.error(f"更新股票{stock.code}K线数据失败: {str(e)}")
error_count += 1
continue
# 保存更新的K线数据
if total_kline_data:
save_result = self.repository.save_daily_kline_data(total_kline_data)
logger.info(f"K线数据更新完成: 成功{updated_count}只股票, 失败{error_count}只股票")
logger.info(f"数据保存结果: {save_result}")
else:
logger.warning("未获取到任何K线数据")
except Exception as e:
logger.error(f"每日K线数据更新任务失败: {str(e)}")
async def _update_financial_data(self):
"""更新财务数据"""
try:
logger.info("开始执行每周财务数据更新任务...")
# 获取所有股票代码
stocks = self.repository.get_stock_basic_info()
if not stocks:
logger.warning("没有股票基础信息,无法更新财务数据")
return
updated_count = 0
error_count = 0
total_financial_data = []
# 分批更新财务数据
batch_size = 30
for i in range(0, len(stocks), batch_size):
batch_stocks = stocks[i:i + batch_size]
for stock in batch_stocks:
try:
logger.info(f"更新股票{stock.code}的财务数据...")
financial_data = await self.data_manager.get_financial_report(
stock.code
)
if financial_data:
total_financial_data.extend(financial_data)
updated_count += 1
else:
logger.warning(f"股票{stock.code}未获取到财务数据")
error_count += 1
await asyncio.sleep(0.2)
except Exception as e:
logger.error(f"更新股票{stock.code}财务数据失败: {str(e)}")
error_count += 1
continue
# 保存更新的财务数据
if total_financial_data:
save_result = self.repository.save_financial_report_data(total_financial_data)
logger.info(f"财务数据更新完成: 成功{updated_count}只股票, 失败{error_count}只股票")
logger.info(f"数据保存结果: {save_result}")
else:
logger.warning("未获取到任何财务数据")
except Exception as e:
logger.error(f"每周财务数据更新任务失败: {str(e)}")
async def _update_stock_basic_info(self):
"""更新股票基础信息"""
try:
logger.info("开始执行每月股票基础信息更新任务...")
# 获取最新的股票基础信息
stock_basic_data = await self.data_manager.get_stock_basic_info()
if not stock_basic_data:
logger.warning("未获取到股票基础信息")
return
# 保存更新的基础信息
save_result = self.repository.save_stock_basic_info(stock_basic_data)
logger.info(f"股票基础信息更新完成: {save_result}")
except Exception as e:
logger.error(f"每月股票基础信息更新任务失败: {str(e)}")
async def _health_check(self):
"""系统健康检查"""
try:
logger.info("开始执行系统健康检查...")
# 检查数据库连接
db_status = self.db_manager.check_connection()
# 检查数据源连接
akshare_status = await self._check_akshare_connection()
baostock_status = await self._check_baostock_connection()
# 检查数据完整性
data_status = await self._check_data_integrity()
health_status = {
"timestamp": datetime.now().isoformat(),
"database": db_status,
"akshare": akshare_status,
"baostock": baostock_status,
"data_integrity": data_status,
"overall": "healthy" if all([
db_status == "connected",
akshare_status == "available",
baostock_status == "available",
data_status.get("status") == "normal"
]) else "degraded"
}
logger.info(f"系统健康检查完成: {health_status}")
except Exception as e:
logger.error(f"系统健康检查失败: {str(e)}")
async def _check_akshare_connection(self) -> str:
"""检查AKshare连接状态"""
try:
# 尝试获取一只股票的基础信息
test_data = await self.data_manager.akshare_collector.get_stock_basic_info()
return "available" if test_data else "unavailable"
except Exception:
return "unavailable"
    async def _check_baostock_connection(self) -> str:
        """Check whether the Baostock data source is reachable."""
        collector = self.data_manager.baostock_collector
        try:
            # Log in and fetch the stock basic info as a probe.
            await collector.login()
            try:
                test_data = await collector.get_stock_basic_info()
            finally:
                # Always log out, even when the probe request fails.
                await collector.logout()
            return "available" if test_data else "unavailable"
        except Exception:
            return "unavailable"
async def _check_data_integrity(self) -> Dict[str, Any]:
"""检查数据完整性"""
try:
# 查询各表数据量
stock_count = self.repository.session.query(StockBasic).count()
kline_count = self.repository.session.query(DailyKline).count()
financial_count = self.repository.session.query(FinancialReport).count()
# 检查最新数据日期
latest_kline = self.repository.session.query(DailyKline).order_by(
DailyKline.trade_date.desc()
).first()
latest_financial = self.repository.session.query(FinancialReport).order_by(
FinancialReport.report_date.desc()
).first()
# 判断数据状态
kline_status = "normal" if latest_kline and (date.today() - latest_kline.trade_date).days <= 3 else "stale"
financial_status = "normal" if latest_financial and (date.today() - latest_financial.report_date).days <= 90 else "stale"
return {
"status": "normal" if kline_status == "normal" and financial_status == "normal" else "degraded",
"stock_count": stock_count,
"kline_count": kline_count,
"financial_count": financial_count,
"latest_kline_date": latest_kline.trade_date if latest_kline else None,
"latest_financial_date": latest_financial.report_date if latest_financial else None,
"kline_status": kline_status,
"financial_status": financial_status
}
except Exception as e:
return {"status": "error", "error": str(e)}
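    def _get_previous_trading_day(self, current_date: date, days_back: int = 1) -> date:
        """
        Return an approximate previous trading day.

        Args:
            current_date: Reference date.
            days_back: Number of calendar days to step back before skipping weekends.

        Returns:
            The resulting weekday. Exchange holidays are not taken into account.
        """
        try:
            # Simple implementation: step back, then skip weekends.
            previous_date = current_date - timedelta(days=days_back)
            # If the date falls on a weekend, keep moving backwards.
            while previous_date.weekday() >= 5:  # 5 = Saturday, 6 = Sunday
                previous_date -= timedelta(days=1)
            return previous_date
        except Exception as e:
            logger.error(f"Failed to compute the previous trading day: {str(e)}")
            return current_date - timedelta(days=days_back)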
async def manual_update_daily_kline(self) -> Dict[str, Any]:
"""
手动触发每日K线数据更新
Returns:
更新结果
"""
try:
await self._update_daily_kline_data()
return {"success": True, "message": "手动更新完成"}
except Exception as e:
return {"success": False, "error": str(e)}
async def manual_update_financial_data(self) -> Dict[str, Any]:
"""
手动触发财务数据更新
Returns:
更新结果
"""
try:
await self._update_financial_data()
return {"success": True, "message": "手动更新完成"}
except Exception as e:
return {"success": False, "error": str(e)}
def get_scheduled_jobs(self) -> List[Dict[str, Any]]:
"""
获取所有定时任务信息
Returns:
任务列表
"""
jobs = []
for job in self.scheduler.get_jobs():
jobs.append({
"id": job.id,
"name": job.name,
"next_run_time": job.next_run_time.isoformat() if job.next_run_time else None,
"trigger": str(job.trigger)
})
return jobs
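
The weekend-skipping logic in `_get_previous_trading_day` could be extended to exchange holidays. The following is a minimal sketch, assuming the caller supplies a set of holiday dates; the function name `previous_trading_day` and the `holidays` parameter are hypothetical and not part of this commit.

```python
from datetime import date, timedelta
from typing import AbstractSet


def previous_trading_day(current: date, holidays: AbstractSet[date] = frozenset()) -> date:
    """Step back one day at a time until a weekday that is not a holiday is found.

    `holidays` is a hypothetical, caller-supplied set of non-trading dates.
    """
    candidate = current - timedelta(days=1)
    # weekday() returns 5 for Saturday and 6 for Sunday.
    while candidate.weekday() >= 5 or candidate in holidays:
        candidate -= timedelta(days=1)
    return candidate
```

Called with an empty holiday set, this behaves like the weekend-only version with `days_back=1`. Where the holiday dates come from (a table, a config file, a data source) is left open here.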

1
src/storage/__init__.py Normal file

@ -0,0 +1 @@
# Data storage module

104
src/storage/database.py Normal file

@ -0,0 +1,104 @@
"""
Database connection and configuration management.
Creates and manages the database connection pool.
"""
from sqlalchemy import create_engine
# Requires SQLAlchemy 1.4+; the sqlalchemy.ext.declarative import path is deprecated there.
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.pool import QueuePool
from loguru import logger
from ..config.settings import settings
class DatabaseManager:
    """Database manager (process-wide singleton)."""

    _instance = None

    def __new__(cls):
        """Create the single shared instance on first use."""
        if cls._instance is None:
            cls._instance = super(DatabaseManager, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Set up the database connection once; later calls are no-ops."""
        if self._initialized:
            return
        self._initialized = True
        self.engine = None
        self.SessionLocal = None
        self.Base = declarative_base()
        self._setup_database()
def _setup_database(self):
"""配置数据库连接"""
try:
# 创建数据库引擎
self.engine = create_engine(
settings.database.database_url,
poolclass=QueuePool,
pool_size=settings.database.pool_size,
max_overflow=settings.database.max_overflow,
pool_timeout=settings.database.pool_timeout,
echo=False # 生产环境设为False
)
# 创建会话工厂
self.SessionLocal = sessionmaker(
bind=self.engine,
autoflush=False,
autocommit=False
)
logger.info("数据库连接配置完成")
except Exception as e:
logger.error(f"数据库连接配置失败: {str(e)}")
raise
def get_session(self):
"""
获取数据库会话
Returns:
数据库会话对象
"""
try:
session = self.SessionLocal()
return session
except Exception as e:
logger.error(f"获取数据库会话失败: {str(e)}")
raise
def create_tables(self):
"""创建所有数据表"""
try:
# 导入所有模型以确保它们被注册
from . import models
# 创建所有表
self.Base.metadata.create_all(bind=self.engine)
logger.info("数据库表创建完成")
except Exception as e:
logger.error(f"创建数据库表失败: {str(e)}")
raise
def drop_tables(self):
"""删除所有数据表(仅用于测试)"""
try:
self.Base.metadata.drop_all(bind=self.engine)
logger.info("数据库表删除完成")
except Exception as e:
logger.error(f"删除数据库表失败: {str(e)}")
raise
# 全局数据库管理器实例
db_manager = DatabaseManager()
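
The scheduler's health check calls `self.db_manager.check_connection()` and compares the result with `"connected"`, but this file does not define such a method. One possible shape for the probe is sketched below using only the standard library. The `connect` callable and the `"disconnected"` return value are assumptions. With the SQLAlchemy engine from this module, the same idea would presumably run a trivial `SELECT 1` on a connection taken from `self.engine`.

```python
import sqlite3
from typing import Callable


def check_connection(connect: Callable[[], sqlite3.Connection]) -> str:
    """Return "connected" if a trivial query succeeds, otherwise "disconnected".

    `connect` is a hypothetical factory that opens a new DB-API connection.
    """
    try:
        conn = connect()
        try:
            # A cheap round trip that does not depend on any table existing.
            conn.execute("SELECT 1").fetchone()
        finally:
            conn.close()
        return "connected"
    except Exception:
        return "disconnected"


def broken_connect() -> sqlite3.Connection:
    """Simulate an unreachable database."""
    raise sqlite3.OperationalError("database is unreachable")
```

For example, `check_connection(lambda: sqlite3.connect(":memory:"))` exercises the success path, while `check_connection(broken_connect)` exercises the failure path.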

257
src/storage/models.py Normal file

@ -0,0 +1,257 @@
"""
Database model definitions.
Defines the table structures for all stock-related data.
"""
from sqlalchemy import (
Column, Integer, String, Float, Date, DateTime,
Text, Boolean, BigInteger, ForeignKey, Index
)
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from .database import db_manager
Base = db_manager.Base
class StockBasic(Base):
"""
股票基础信息表
存储股票的基本信息
"""
__tablename__ = "stock_basic"
# 主键
id = Column(Integer, primary_key=True, autoincrement=True, comment="主键ID")
# 股票基本信息
code = Column(String(10), nullable=False, unique=True, comment="股票代码")
name = Column(String(50), nullable=False, comment="股票名称")
market = Column(String(10), nullable=False, comment="市场类型(sh/sz)")
# 公司信息
company_name = Column(String(100), comment="公司全称")
industry = Column(String(50), comment="所属行业")
area = Column(String(50), comment="地区")
# 上市信息
ipo_date = Column(Date, comment="上市日期")
listing_status = Column(Boolean, default=True, comment="上市状态")
# 时间戳
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(
DateTime,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间"
)
# 索引
__table_args__ = (
Index("idx_code", "code"),
Index("idx_market", "market"),
Index("idx_industry", "industry"),
Index("idx_ipo_date", "ipo_date")
)
def __repr__(self):
return f"<StockBasic(code='{self.code}', name='{self.name}')>"
class DailyKline(Base):
"""
日K线数据表
存储股票的日K线数据
"""
__tablename__ = "daily_kline"
# 主键
id = Column(Integer, primary_key=True, autoincrement=True, comment="主键ID")
# 外键关联
stock_code = Column(
String(10),
ForeignKey("stock_basic.code"),
nullable=False,
comment="股票代码"
)
# K线数据
trade_date = Column(Date, nullable=False, comment="交易日期")
open_price = Column(Float, nullable=False, comment="开盘价")
high_price = Column(Float, nullable=False, comment="最高价")
low_price = Column(Float, nullable=False, comment="最低价")
close_price = Column(Float, nullable=False, comment="收盘价")
# 成交量信息
volume = Column(BigInteger, nullable=False, comment="成交量(股)")
amount = Column(Float, nullable=False, comment="成交额(元)")
# 涨跌幅信息
change = Column(Float, comment="涨跌额")
pct_change = Column(Float, comment="涨跌幅(%)")
# 时间戳
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
# 索引
__table_args__ = (
Index("idx_stock_code_date", "stock_code", "trade_date"),
Index("idx_trade_date", "trade_date"),
Index("idx_kline_stock_code", "stock_code")  # table-specific name; SQLite requires index names to be unique per database
)
# 关系
stock = relationship("StockBasic", backref="daily_kline")
def __repr__(self):
return f"<DailyKline(code='{self.stock_code}', date='{self.trade_date}')>"
class FinancialReport(Base):
"""
财务报告数据表
存储股票的财务报告数据
"""
__tablename__ = "financial_report"
# 主键
id = Column(Integer, primary_key=True, autoincrement=True, comment="主键ID")
# 外键关联
stock_code = Column(
String(10),
ForeignKey("stock_basic.code"),
nullable=False,
comment="股票代码"
)
# 报告期信息
report_date = Column(Date, nullable=False, comment="报告日期")
report_type = Column(String(20), nullable=False, comment="报告类型(Q1/Q2/Q3/Q4/年报)")
report_year = Column(Integer, nullable=False, comment="报告年份")
report_quarter = Column(Integer, comment="报告季度(1-4)")
# 财务指标
eps = Column(Float, comment="每股收益(元)")
net_profit = Column(Float, comment="净利润(万元)")
revenue = Column(Float, comment="营业收入(万元)")
total_assets = Column(Float, comment="总资产(万元)")
total_liabilities = Column(Float, comment="总负债(万元)")
equity = Column(Float, comment="股东权益(万元)")
# 盈利能力指标
roe = Column(Float, comment="净资产收益率(%)")
gross_profit_margin = Column(Float, comment="毛利率(%)")
net_profit_margin = Column(Float, comment="净利率(%)")
# 偿债能力指标
debt_to_asset_ratio = Column(Float, comment="资产负债率(%)")
current_ratio = Column(Float, comment="流动比率")
quick_ratio = Column(Float, comment="速动比率")
# 时间戳
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
# 索引
__table_args__ = (
Index("idx_stock_code_report", "stock_code", "report_date"),
Index("idx_report_date", "report_date"),
Index("idx_report_type", "report_type")
)
# 关系
stock = relationship("StockBasic", backref="financial_reports")
def __repr__(self):
return f"<FinancialReport(code='{self.stock_code}', date='{self.report_date}')>"
class DataSource(Base):
"""
数据源信息表
记录数据采集的来源和状态
"""
__tablename__ = "data_source"
# 主键
id = Column(Integer, primary_key=True, autoincrement=True, comment="主键ID")
# 数据源信息
source_name = Column(String(50), nullable=False, comment="数据源名称")
source_type = Column(String(20), nullable=False, comment="数据源类型(akshare/baostock)")
# 采集状态
last_sync_time = Column(DateTime, comment="最后同步时间")
sync_status = Column(String(20), default="pending", comment="同步状态")
error_message = Column(Text, comment="错误信息")
# 统计信息
total_records = Column(Integer, default=0, comment="总记录数")
success_records = Column(Integer, default=0, comment="成功记录数")
failed_records = Column(Integer, default=0, comment="失败记录数")
# 时间戳
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(
DateTime,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间"
)
# 索引
__table_args__ = (
Index("idx_source_name", "source_name"),
Index("idx_sync_status", "sync_status"),
Index("idx_last_sync_time", "last_sync_time")
)
def __repr__(self):
return f"<DataSource(name='{self.source_name}', status='{self.sync_status}')>"
class SystemLog(Base):
"""
系统日志表
记录系统运行日志和异常信息
"""
__tablename__ = "system_log"
# 主键
id = Column(Integer, primary_key=True, autoincrement=True, comment="主键ID")
# 日志信息
log_level = Column(String(10), nullable=False, comment="日志级别")
module_name = Column(String(50), nullable=False, comment="模块名称")
message = Column(Text, nullable=False, comment="日志消息")
# 异常信息
exception_type = Column(String(100), comment="异常类型")
exception_message = Column(Text, comment="异常消息")
traceback = Column(Text, comment="异常堆栈")
# 上下文信息
stock_code = Column(String(10), comment="关联股票代码")
data_type = Column(String(20), comment="数据类型")
# 时间戳
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
# 索引
__table_args__ = (
Index("idx_log_level", "log_level"),
Index("idx_module_name", "module_name"),
Index("idx_created_at", "created_at"),
Index("idx_log_stock_code", "stock_code")  # table-specific name; SQLite requires index names to be unique per database
)
def __repr__(self):
return f"<SystemLog(level='{self.log_level}', module='{self.module_name}')>"
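
Because the project targets both SQLite and MySQL, index names such as `idx_stock_code` are safest when they are unique across the whole schema rather than per table. SQLite keeps index names in a database-wide namespace, so two tables cannot each own an index with the same name. The sketch below demonstrates this with the standard `sqlite3` module. The table and index names mirror the models above, while the helper name `duplicate_index_name_fails` is hypothetical.

```python
import sqlite3


def duplicate_index_name_fails() -> bool:
    """Return True if SQLite rejects a second index that reuses an existing name."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute("CREATE TABLE daily_kline (stock_code TEXT)")
        conn.execute("CREATE TABLE system_log (stock_code TEXT)")
        conn.execute("CREATE INDEX idx_stock_code ON daily_kline (stock_code)")
        try:
            # Same index name, different table.
            conn.execute("CREATE INDEX idx_stock_code ON system_log (stock_code)")
        except sqlite3.OperationalError:
            return True
        return False
    finally:
        conn.close()
```

Prefixing each index name with its table (for example `idx_kline_...` and `idx_log_...`) avoids the collision on every backend.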


@ -0,0 +1,677 @@
"""
Stock data storage service.
Provides create, read, update and delete operations for stock data.
"""
from typing import List, Optional, Dict, Any
from datetime import date, datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, desc, asc
from loguru import logger
class StockRepository:
"""股票数据存储服务类"""
def __init__(self, session: Session):
"""
初始化存储服务
Args:
session: 数据库会话对象
"""
self.session = session
self._setup_models()
def _setup_models(self):
"""
设置模型类
从会话的映射器获取模型类以兼容测试环境
"""
try:
# 首先尝试从会话绑定的Base类获取模型类
if hasattr(self.session.get_bind(), '_metadata'):
metadata = self.session.get_bind()._metadata
for table in metadata.tables.values():
# 通过表名找到对应的模型类
if table.name == "stock_basic":
self.StockBasic = table.entity
elif table.name == "daily_kline":
self.DailyKline = table.entity
elif table.name == "financial_report":
self.FinancialReport = table.entity
elif table.name == "data_source":
self.DataSource = table.entity
elif table.name == "system_log":
self.SystemLog = table.entity
# 如果上述方法失败,尝试从映射器注册表获取
if not hasattr(self, "StockBasic"):
from sqlalchemy.inspection import inspect
# 获取会话绑定的所有模型类
for mapper in self.session.get_bind().mapper_registry.mappers:
model_class = mapper.class_
class_name = model_class.__name__
# 根据类名设置对应的模型属性
if class_name == "StockBasic":
self.StockBasic = model_class
elif class_name == "DailyKline":
self.DailyKline = model_class
elif class_name == "FinancialReport":
self.FinancialReport = model_class
elif class_name == "DataSource":
self.DataSource = model_class
elif class_name == "SystemLog":
self.SystemLog = model_class
# 如果仍然没有找到模型类,则使用默认导入
if not hasattr(self, "StockBasic"):
from .models import StockBasic
self.StockBasic = StockBasic
if not hasattr(self, "DailyKline"):
from .models import DailyKline
self.DailyKline = DailyKline
if not hasattr(self, "FinancialReport"):
from .models import FinancialReport
self.FinancialReport = FinancialReport
if not hasattr(self, "DataSource"):
from .models import DataSource
self.DataSource = DataSource
if not hasattr(self, "SystemLog"):
from .models import SystemLog
self.SystemLog = SystemLog
except Exception as e:
# 如果无法从映射器获取,使用默认导入
from .models import StockBasic, DailyKline, FinancialReport, DataSource, SystemLog
self.StockBasic = StockBasic
self.DailyKline = DailyKline
self.FinancialReport = FinancialReport
self.DataSource = DataSource
self.SystemLog = SystemLog
logger.warning(f"无法从会话映射器获取模型类,使用默认导入: {str(e)}")
def save_stock_basic_info(
self,
stock_data: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""
保存股票基础信息
Args:
stock_data: 股票基础信息列表
Returns:
保存结果统计
"""
try:
added_count = 0
updated_count = 0
error_count = 0
for data in stock_data:
try:
# 检查是否已存在
existing_stock = self.session.query(self.StockBasic).filter(
self.StockBasic.code == data["code"]
).first()
if existing_stock:
# 更新现有记录
existing_stock.name = data.get("name", existing_stock.name)
existing_stock.market = data.get("market", existing_stock.market)
existing_stock.industry = data.get("industry", existing_stock.industry)
existing_stock.area = data.get("area", existing_stock.area)
existing_stock.ipo_date = data.get("ipo_date", existing_stock.ipo_date)
updated_count += 1
else:
# 创建新记录
new_stock = self.StockBasic(
code=data["code"],
name=data["name"],
market=data.get("market", ""),
industry=data.get("industry", ""),
area=data.get("area", ""),
ipo_date=data.get("ipo_date")
)
self.session.add(new_stock)
added_count += 1
except Exception as e:
logger.error(f"保存股票{data.get('code')}基础信息失败: {str(e)}")
error_count += 1
self.session.commit()
result = {
"added_count": added_count,
"updated_count": updated_count,
"error_count": error_count,
"total_count": len(stock_data)
}
logger.info(f"股票基础信息保存完成: {result}")
return result
except Exception as e:
self.session.rollback()
logger.error(f"保存股票基础信息失败: {str(e)}")
raise
def save_daily_kline_data(
self,
kline_data: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""
保存日K线数据
Args:
kline_data: 日K线数据列表
Returns:
保存结果统计
"""
try:
added_count = 0
error_count = 0
            for data in kline_data:
                try:
                    # Parse the trade date once so that the lookup and the insert use the same value.
                    trade_date = datetime.strptime(data["date"], "%Y-%m-%d").date()
                    # Skip records that already exist.
                    existing_kline = self.session.query(self.DailyKline).filter(
                        and_(
                            self.DailyKline.stock_code == data["code"],
                            self.DailyKline.trade_date == trade_date
                        )
                    ).first()
                    if not existing_kline:
                        # Create a new record.
                        new_kline = self.DailyKline(
                            stock_code=data["code"],
                            trade_date=trade_date,
                            open_price=data["open"],
                            high_price=data["high"],
                            low_price=data["low"],
                            close_price=data["close"],
                            volume=data["volume"],
                            amount=data["amount"]
                        )
                        self.session.add(new_kline)
                        added_count += 1
                except Exception as e:
                    logger.error(f"Failed to save K-line data for stock {data.get('code')}: {str(e)}")
                    error_count += 1
self.session.commit()
result = {
"added_count": added_count,
"error_count": error_count,
"total_count": len(kline_data)
}
logger.info(f"日K线数据保存完成: {result}")
return result
except Exception as e:
self.session.rollback()
logger.error(f"保存日K线数据失败: {str(e)}")
raise
def save_financial_report_data(
self,
financial_data: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""
保存财务报告数据
Args:
financial_data: 财务报告数据列表
Returns:
保存结果统计
"""
try:
added_count = 0
updated_count = 0
error_count = 0
for data in financial_data:
try:
# 解析报告日期
report_date = self._parse_report_date(data.get("report_date"))
if not report_date:
continue
# 检查是否已存在
existing_report = self.session.query(self.FinancialReport).filter(
and_(
self.FinancialReport.stock_code == data["code"],
self.FinancialReport.report_date == report_date
)
).first()
if existing_report:
# 更新现有记录
existing_report.eps = data.get("eps", existing_report.eps)
existing_report.net_profit = data.get("net_profit", existing_report.net_profit)
existing_report.revenue = data.get("revenue", existing_report.revenue)
existing_report.total_assets = data.get("total_assets", existing_report.total_assets)
updated_count += 1
else:
# 创建新记录
new_report = self.FinancialReport(
stock_code=data["code"],
report_date=report_date,
report_type=self._get_report_type(report_date),
report_year=report_date.year,
report_quarter=self._get_report_quarter(report_date),
eps=data.get("eps"),
net_profit=data.get("net_profit"),
revenue=data.get("revenue"),
total_assets=data.get("total_assets")
)
self.session.add(new_report)
added_count += 1
except Exception as e:
logger.error(f"保存股票{data.get('code')}财务数据失败: {str(e)}")
error_count += 1
self.session.commit()
result = {
"added_count": added_count,
"updated_count": updated_count,
"error_count": error_count,
"total_count": len(financial_data)
}
logger.info(f"财务报告数据保存完成: {result}")
return result
except Exception as e:
self.session.rollback()
logger.error(f"保存财务报告数据失败: {str(e)}")
raise
def save_system_log(self, log_data: Dict[str, Any]) -> bool:
"""
保存系统日志
Args:
log_data: 日志数据字典
Returns:
保存是否成功
"""
try:
# 创建日志记录
log_record = self.SystemLog(
log_level=log_data.get("log_level", "INFO"),
module_name=log_data.get("module_name", "unknown"),
message=log_data.get("message", ""),
exception_type=log_data.get("exception_type"),
exception_message=log_data.get("exception_message"),
traceback=log_data.get("traceback"),
stock_code=log_data.get("stock_code"),
data_type=log_data.get("data_type")
)
self.session.add(log_record)
self.session.commit()
logger.debug(f"系统日志保存成功: {log_data.get('log_level')} - {log_data.get('module_name')}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"保存系统日志失败: {str(e)}")
return False
def get_stock_basic_info(
self,
market: Optional[str] = None,
industry: Optional[str] = None
) -> List:
"""
查询股票基础信息
Args:
market: 市场类型过滤
industry: 行业过滤
Returns:
股票基础信息列表
"""
try:
query = self.session.query(self.StockBasic)
if market:
query = query.filter(self.StockBasic.market == market)
if industry:
query = query.filter(self.StockBasic.industry == industry)
stocks = query.all()
logger.info(f"查询到{len(stocks)}只股票基础信息")
return stocks
except Exception as e:
logger.error(f"查询股票基础信息失败: {str(e)}")
raise
def get_daily_kline_data(
self,
stock_code: str,
start_date: date,
end_date: date
) -> List:
"""
查询日K线数据
Args:
stock_code: 股票代码
start_date: 开始日期
end_date: 结束日期
Returns:
日K线数据列表
"""
try:
kline_data = self.session.query(self.DailyKline).filter(
and_(
self.DailyKline.stock_code == stock_code,
self.DailyKline.trade_date >= start_date,
self.DailyKline.trade_date <= end_date
)
).order_by(asc(self.DailyKline.trade_date)).all()
logger.info(f"查询到{stock_code}{len(kline_data)}条K线数据")
return kline_data
except Exception as e:
logger.error(f"查询日K线数据失败: {str(e)}")
raise
def _parse_report_date(self, report_date_str: str) -> Optional[date]:
"""
解析报告日期字符串
Args:
report_date_str: 报告日期字符串
Returns:
日期对象
"""
try:
if "Q" in report_date_str:
# 处理季度报告日期
year, quarter = report_date_str.split("-")
quarter_num = int(quarter.replace("Q", ""))
# 季度结束日期
if quarter_num == 1:
return date(int(year), 3, 31)
elif quarter_num == 2:
return date(int(year), 6, 30)
elif quarter_num == 3:
return date(int(year), 9, 30)
else:
return date(int(year), 12, 31)
else:
# 处理标准日期格式
return datetime.strptime(report_date_str, "%Y-%m-%d").date()
except Exception:
return None
def _get_report_type(self, report_date: date) -> str:
"""
获取报告类型
Args:
report_date: 报告日期
Returns:
报告类型
"""
month = report_date.month
if month == 3:
return "Q1"
elif month == 6:
return "Q2"
elif month == 9:
return "Q3"
elif month == 12:
return "Q4"
else:
return "年报"
def _get_report_quarter(self, report_date: date) -> int:
"""
获取报告季度
Args:
report_date: 报告日期
Returns:
季度(1-4)
"""
month = report_date.month
if month <= 3:
return 1
elif month <= 6:
return 2
elif month <= 9:
return 3
else:
return 4
# 后端服务器API所需的方法
def get_stock_count(self) -> int:
"""
获取股票总数
Returns:
股票总数
"""
try:
count = self.session.query(self.StockBasic).count()
return count
except Exception as e:
logger.error(f"获取股票总数失败: {str(e)}")
return 0
def get_kline_count(self) -> int:
"""
获取K线数据总数
Returns:
K线数据总数
"""
try:
count = self.session.query(self.DailyKline).count()
return count
except Exception as e:
logger.error(f"获取K线数据总数失败: {str(e)}")
return 0
def get_financial_count(self) -> int:
"""
获取财务报告总数
Returns:
财务报告总数
"""
try:
count = self.session.query(self.FinancialReport).count()
return count
except Exception as e:
logger.error(f"获取财务报告总数失败: {str(e)}")
return 0
def get_log_count(self) -> int:
"""
获取系统日志总数
Returns:
系统日志总数
"""
try:
count = self.session.query(self.SystemLog).count()
return count
except Exception as e:
logger.error(f"获取系统日志总数失败: {str(e)}")
return 0
def get_stocks(self, limit: int = 20, offset: int = 0) -> List:
"""
获取股票列表分页
Args:
limit: 每页数量
offset: 偏移量
Returns:
股票列表
"""
try:
stocks = self.session.query(self.StockBasic).order_by(
self.StockBasic.code
).offset(offset).limit(limit).all()
return stocks
except Exception as e:
logger.error(f"获取股票列表失败: {str(e)}")
return []
def search_stocks(self, query: str) -> List:
"""
搜索股票
Args:
query: 搜索关键词
Returns:
匹配的股票列表
"""
try:
stocks = self.session.query(self.StockBasic).filter(
or_(
self.StockBasic.code.like(f"%{query}%"),
self.StockBasic.name.like(f"%{query}%")
)
).order_by(self.StockBasic.code).all()
return stocks
except Exception as e:
logger.error(f"搜索股票失败: {str(e)}")
return []
def get_kline_data(self, stock_code: str, start_date: date, end_date: date, period: str = "daily") -> List:
"""
获取K线数据
Args:
stock_code: 股票代码
start_date: 开始日期
end_date: 结束日期
period: 周期类型daily/weekly/monthly
Returns:
K线数据列表
"""
try:
# 目前只支持日线数据
kline_data = self.session.query(self.DailyKline).filter(
and_(
self.DailyKline.stock_code == stock_code,
self.DailyKline.trade_date >= start_date,
self.DailyKline.trade_date <= end_date
)
).order_by(asc(self.DailyKline.trade_date)).all()
return kline_data
except Exception as e:
logger.error(f"获取K线数据失败: {str(e)}")
return []
def get_financial_data(self, stock_code: str, year: str, period: str) -> Optional[Any]:
"""
获取财务数据
Args:
stock_code: 股票代码
year: 年份
period: 季度Q1/Q2/Q3/Q4
Returns:
财务数据对象
"""
try:
# 根据季度确定报告日期范围
if period == "Q1":
report_date = date(int(year), 3, 31)
elif period == "Q2":
report_date = date(int(year), 6, 30)
elif period == "Q3":
report_date = date(int(year), 9, 30)
elif period == "Q4":
report_date = date(int(year), 12, 31)
else:
report_date = date(int(year), 12, 31) # 默认年度报告
financial_data = self.session.query(self.FinancialReport).filter(
and_(
self.FinancialReport.stock_code == stock_code,
self.FinancialReport.report_date == report_date
)
).first()
return financial_data
except Exception as e:
logger.error(f"获取财务数据失败: {str(e)}")
return None
    def get_system_logs(self, level: str = "", date_str: str = "") -> List:
        """
        Fetch system logs.

        Args:
            level: Log level filter.
            date_str: Date filter in YYYY-MM-DD format.

        Returns:
            Up to 100 log records, newest first.
        """
        try:
            query = self.session.query(self.SystemLog)
            if level:
                query = query.filter(self.SystemLog.log_level == level)
            if date_str:
                try:
                    target_date = datetime.strptime(date_str, "%Y-%m-%d").date()
                    # SystemLog has no `timestamp` column; filter on `created_at`.
                    query = query.filter(
                        self.SystemLog.created_at >= target_date,
                        self.SystemLog.created_at < target_date + timedelta(days=1)
                    )
                except ValueError:
                    pass  # Malformed date: skip the date filter.
            logs = query.order_by(desc(self.SystemLog.created_at)).limit(100).all()
            return logs
        except Exception as e:
            logger.error(f"Failed to fetch system logs: {str(e)}")
            return []
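
The report-date handling in `_parse_report_date` accepts either a quarter label such as `2024-Q1` or a plain `YYYY-MM-DD` date. A table-driven variant is sketched below as one possible alternative. The names `QUARTER_END` and `parse_report_date` are hypothetical. Unlike the method above, which maps any quarter number other than 1 to 3 onto December 31, this sketch returns `None` for an out-of-range quarter.

```python
from datetime import date, datetime
from typing import Optional

# Last calendar day of each quarter as (month, day).
QUARTER_END = {1: (3, 31), 2: (6, 30), 3: (9, 30), 4: (12, 31)}


def parse_report_date(value: Optional[str]) -> Optional[date]:
    """Parse "YYYY-Qn" or "YYYY-MM-DD"; return None for anything else."""
    if not value:
        return None
    try:
        if "Q" in value:
            year, quarter = value.split("-")
            month, day = QUARTER_END[int(quarter.replace("Q", ""))]
            return date(int(year), month, day)
        return datetime.strptime(value, "%Y-%m-%d").date()
    except (ValueError, KeyError):
        # Malformed input or a quarter outside 1-4.
        return None
```

Returning `None` keeps the behaviour of the save path unchanged, since `save_financial_report_data` already skips records whose report date cannot be parsed.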

1
src/utils/__init__.py Normal file

@ -0,0 +1 @@
# Utilities module

420
src/utils/exceptions.py Normal file

@ -0,0 +1,420 @@
"""
Exception handling module.
Defines the system's custom exceptions and the exception handling mechanism.
"""
from typing import Optional, Dict, Any
from loguru import logger
class StockSystemError(Exception):
    """Base class for all system exceptions."""

    def __init__(self, message: str, error_code: str = "SYSTEM_ERROR", details: Optional[Dict[str, Any]] = None):
        """
        Initialize the exception.

        Args:
            message: Human-readable error message.
            error_code: Machine-readable error code.
            details: Additional context about the error.
        """
        super().__init__(message)
        self.message = message
        self.error_code = error_code
        self.details = details or {}
        self.timestamp = self._get_timestamp()

    def _get_timestamp(self) -> str:
        """Return the current local time as a formatted string."""
        from datetime import datetime
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def to_dict(self) -> Dict[str, Any]:
        """Convert the exception to a dictionary."""
        return {
            "error_code": self.error_code,
            "message": self.message,
            "details": self.details,
            "timestamp": self.timestamp
        }
class DataCollectionError(StockSystemError):
"""数据采集异常"""
def __init__(self, message: str, data_source: str = "", data_type: str = "", details: Optional[Dict[str, Any]] = None):
"""
初始化数据采集异常
Args:
message: 异常信息
data_source: 数据源
data_type: 数据类型
details: 详细信息
"""
super().__init__(message, "DATA_COLLECTION_ERROR", details)
self.data_source = data_source
self.data_type = data_type
self.details.update({
"data_source": data_source,
"data_type": data_type
})
class DataProcessingError(StockSystemError):
"""数据处理异常"""
def __init__(self, message: str, processor: str = "", data_type: str = "", details: Optional[Dict[str, Any]] = None):
"""
初始化数据处理异常
Args:
message: 异常信息
processor: 处理器名称
data_type: 数据类型
details: 详细信息
"""
super().__init__(message, "DATA_PROCESSING_ERROR", details)
self.processor = processor
self.data_type = data_type
self.details.update({
"processor": processor,
"data_type": data_type
})
class DatabaseError(StockSystemError):
"""数据库异常"""
def __init__(self, message: str, operation: str = "", table: str = "", details: Optional[Dict[str, Any]] = None):
"""
初始化数据库异常
Args:
message: 异常信息
operation: 操作类型
table: 表名
details: 详细信息
"""
super().__init__(message, "DATABASE_ERROR", details)
self.operation = operation
self.table = table
self.details.update({
"operation": operation,
"table": table
})
class SchedulerError(StockSystemError):
"""定时任务异常"""
def __init__(self, message: str, job_id: str = "", job_type: str = "", details: Optional[Dict[str, Any]] = None):
"""
初始化定时任务异常
Args:
message: 异常信息
job_id: 任务ID
job_type: 任务类型
details: 详细信息
"""
super().__init__(message, "SCHEDULER_ERROR", details)
self.job_id = job_id
self.job_type = job_type
self.details.update({
"job_id": job_id,
"job_type": job_type
})
class ConfigurationError(StockSystemError):
"""配置异常"""
def __init__(self, message: str, config_key: str = "", config_file: str = "", details: Optional[Dict[str, Any]] = None):
"""
初始化配置异常
Args:
message: 异常信息
config_key: 配置键
config_file: 配置文件
details: 详细信息
"""
super().__init__(message, "CONFIGURATION_ERROR", details)
self.config_key = config_key
self.config_file = config_file
self.details.update({
"config_key": config_key,
"config_file": config_file
})
class ValidationError(StockSystemError):
"""数据验证异常"""
def __init__(self, message: str, field: str = "", value: Any = None, validator: str = "", details: Optional[Dict[str, Any]] = None):
"""
初始化数据验证异常
Args:
message: 异常信息
field: 字段名
value: 字段值
validator: 验证器名称
details: 详细信息
"""
super().__init__(message, "VALIDATION_ERROR", details)
self.field = field
self.value = value
self.validator = validator
self.details.update({
"field": field,
"value": value,
"validator": validator
})
class NetworkError(StockSystemError):
"""网络异常"""
def __init__(self, message: str, url: str = "", status_code: Optional[int] = None, details: Optional[Dict[str, Any]] = None):
"""
初始化网络异常
Args:
message: 异常信息
url: 请求URL
status_code: 状态码
details: 详细信息
"""
super().__init__(message, "NETWORK_ERROR", details)
self.url = url
self.status_code = status_code
self.details.update({
"url": url,
"status_code": status_code
})
class ExceptionHandler:
"""异常处理器"""
def __init__(self, log_manager):
"""
初始化异常处理器
Args:
log_manager: 日志管理器
"""
self.log_manager = log_manager
self.logger = log_manager.get_logger("exception_handler")
def handle_exception(self, exception: Exception, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
处理异常
Args:
exception: 异常对象
context: 上下文信息
Returns:
异常信息字典
"""
try:
# 如果是系统自定义异常
if isinstance(exception, StockSystemError):
return self._handle_system_exception(exception, context)
# 如果是标准异常
return self._handle_standard_exception(exception, context)
except Exception as e:
# 异常处理器本身出错
self.logger.error(f"异常处理器出错: {str(e)}")
return {
"error_code": "EXCEPTION_HANDLER_ERROR",
"message": "异常处理器内部错误",
"original_error": str(exception),
"handler_error": str(e)
}
def _handle_system_exception(self, exception: StockSystemError, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
处理系统自定义异常
Args:
exception: 系统异常对象
context: 上下文信息
Returns:
异常信息字典
"""
error_info = exception.to_dict()
# 根据异常类型记录不同级别的日志
if isinstance(exception, (DataCollectionError, DatabaseError, SchedulerError)):
self.logger.error(f"系统异常: {exception.error_code} - {exception.message}", context)
elif isinstance(exception, (DataProcessingError, ValidationError)):
self.logger.warning(f"数据处理异常: {exception.error_code} - {exception.message}", context)
else:
self.logger.error(f"未知系统异常: {exception.error_code} - {exception.message}", context)
# 发送警报
self._send_alert(exception, context)
return error_info
def _handle_standard_exception(self, exception: Exception, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
处理标准异常
Args:
exception: 标准异常对象
context: 上下文信息
Returns:
异常信息字典
"""
error_info = {
"error_code": "STANDARD_ERROR",
"message": str(exception),
"exception_type": type(exception).__name__,
"timestamp": self._get_timestamp()
}
# 记录错误日志
self.logger.error(f"标准异常: {type(exception).__name__} - {str(exception)}", context)
# 发送警报
self._send_alert(exception, context)
return error_info
def _send_alert(self, exception: Exception, context: Optional[Dict[str, Any]] = None):
"""
发送异常警报
Args:
exception: 异常对象
context: 上下文信息
"""
try:
# 根据异常类型确定警报级别
if isinstance(exception, (DataCollectionError, DatabaseError, SchedulerError)):
alert_level = "ERROR"
elif isinstance(exception, (DataProcessingError, ValidationError)):
alert_level = "WARNING"
else:
alert_level = "ERROR"
# 构建警报信息
if isinstance(exception, StockSystemError):
alert_message = f"{exception.error_code}: {exception.message}"
else:
alert_message = f"{type(exception).__name__}: {str(exception)}"
# 发送警报
self.log_manager.send_alert(
alert_type="SYSTEM_EXCEPTION",
message=alert_message,
level=alert_level
)
except Exception as e:
self.logger.error(f"发送警报失败: {str(e)}")
def _get_timestamp(self) -> str:
"""获取时间戳"""
from datetime import datetime
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
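    def wrap_with_exception_handler(self, func):
        """
        Wrap a function so that its exceptions are handled automatically.

        Args:
            func: Function to wrap.

        Returns:
            The wrapped function.
        """
        from functools import wraps

        @wraps(func)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                # Log and alert, then re-raise with the original traceback.
                self.handle_exception(e)
                raise
        return wrapper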
# Global exception handler instance
# Commented out because log_manager has to be passed in from outside
# exception_handler = ExceptionHandler(log_manager)
def create_data_collection_error(message: str, data_source: str = "", data_type: str = "", details: Optional[Dict[str, Any]] = None) -> DataCollectionError:
    """
    Create a data collection exception.

    Args:
        message: Error message.
        data_source: Data source.
        data_type: Data type.
        details: Additional details.

    Returns:
        A DataCollectionError instance.
    """
    return DataCollectionError(message, data_source, data_type, details)
def create_database_error(message: str, operation: str = "", table: str = "", details: Optional[Dict[str, Any]] = None) -> DatabaseError:
    """
    Create a database exception.

    Args:
        message: Error message.
        operation: Operation type.
        table: Table name.
        details: Additional details.

    Returns:
        A DatabaseError instance.
    """
    return DatabaseError(message, operation, table, details)
def create_scheduler_error(message: str, job_id: str = "", job_type: str = "", details: Optional[Dict[str, Any]] = None) -> SchedulerError:
    """
    Create a scheduler exception.

    Args:
        message: Error message.
        job_id: Job ID.
        job_type: Job type.
        details: Additional details.

    Returns:
        A SchedulerError instance.
    """
    return SchedulerError(message, job_id, job_type, details)
def create_validation_error(message: str, field: str = "", value: Any = None, validator: str = "", details: Optional[Dict[str, Any]] = None) -> ValidationError:
    """
    Create a data validation exception.

    Args:
        message: Error message.
        field: Field name.
        value: Field value.
        validator: Validator name.
        details: Additional details.

    Returns:
        A ValidationError instance.
    """
    return ValidationError(message, field, value, validator, details)
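The wrapping pattern used by `wrap_with_exception_handler` can be sketched in isolation. This is a minimal illustration, not the project's implementation: `RecordingHandler` and `parse_price` are hypothetical stand-ins, and the standard `logging` module replaces the project's log manager. It shows two details that are easy to miss: `functools.wraps` keeps the wrapped function's metadata, and a bare `raise` re-raises with the original traceback.

```python
import functools
import logging
from typing import Any, Callable, Dict, List

logger = logging.getLogger("exception_sketch")


class RecordingHandler:
    """Hypothetical stand-in for ExceptionHandler: it only records what it sees."""

    def __init__(self) -> None:
        self.handled: List[Dict[str, Any]] = []

    def handle_exception(self, exc: Exception) -> Dict[str, Any]:
        info = {"exception_type": type(exc).__name__, "message": str(exc)}
        self.handled.append(info)
        logger.error("handled %s: %s", info["exception_type"], info["message"])
        return info

    def wrap(self, func: Callable[..., Any]) -> Callable[..., Any]:
        """Report any exception raised by func, then re-raise it unchanged."""

        @functools.wraps(func)  # keep func.__name__ and __doc__ on the wrapper
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            try:
                return func(*args, **kwargs)
            except Exception as exc:
                self.handle_exception(exc)
                raise  # bare raise preserves the original traceback

        return wrapper


handler = RecordingHandler()


@handler.wrap
def parse_price(text: str) -> float:
    """Convert a price string to a float."""
    return float(text)
```

Calling `parse_price("n/a")` records one entry in `handler.handled` and still raises `ValueError` to the caller, so callers keep control over recovery.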

258
src/utils/logger.py Normal file

@ -0,0 +1,258 @@
"""
Log management module.
Implements system runtime logging and the exception alerting mechanism.
"""
import os
import sys
from datetime import datetime
from pathlib import Path
from loguru import logger
from typing import Optional, Dict, Any
class LogManager:
    """Log manager."""
    def __init__(self, log_dir: str = "logs", log_level: str = "INFO"):
        """
        Initialize the log manager.

        Args:
            log_dir: Log directory.
            log_level: Log level.
        """
        self.log_dir = Path(log_dir)
        self.log_level = log_level
        self._setup_logging()
    def _setup_logging(self):
        """Configure the logging system."""
        try:
            # Create the log directory
            self.log_dir.mkdir(parents=True, exist_ok=True)
            # Remove the default handler
            logger.remove()
            # Console output
            logger.add(
                sys.stderr,
                level=self.log_level,
                format=self._get_console_format(),
                colorize=True
            )
            # File output
            log_file = self.log_dir / f"stock_system_{datetime.now().strftime('%Y%m%d')}.log"
            logger.add(
                str(log_file),
                level=self.log_level,
                format=self._get_file_format(),
                rotation="10 MB",
                retention="30 days",
                compression="zip"
            )
            # Separate file for errors
            error_file = self.log_dir / f"error_{datetime.now().strftime('%Y%m%d')}.log"
            logger.add(
                str(error_file),
                level="ERROR",
                format=self._get_file_format(),
                rotation="5 MB",
                retention="90 days"
            )
            logger.info("Logging system initialized")
        except Exception as e:
            print(f"Failed to initialize the logging system: {str(e)}")
            sys.exit(1)
    def _get_console_format(self) -> str:
        """Return the console log format."""
        return "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
    def _get_file_format(self) -> str:
        """Return the file log format."""
        return "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}"
    def get_logger(self, name: str = "stock_system") -> logger:
        """
        Get a logger bound to the given name.

        Args:
            name: Logger name.

        Returns:
            A logger instance.
        """
        # Note: bind() stores the value in record["extra"]["name"]; the {name} field
        # in the formats above is still the module that issued the log call.
        return logger.bind(name=name)
    def log_system_start(self):
        """Log system startup."""
        logger.info("=" * 50)
        logger.info("A-share market analysis and quantitative trading system started")
        logger.info(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        logger.info("=" * 50)
    def log_system_stop(self):
        """Log system shutdown."""
        logger.info("=" * 50)
        logger.info("A-share market analysis and quantitative trading system stopped")
        logger.info(f"Stop time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        logger.info("=" * 50)
    def log_data_collection_start(self, source: str, data_type: str):
        """
        Log the start of a data collection run.

        Args:
            source: Data source.
            data_type: Data type.
        """
        logger.info(f"Data collection started - source: {source}, type: {data_type}")
    def log_data_collection_end(self, source: str, data_type: str, count: int, duration: float):
        """
        Log the end of a data collection run.

        Args:
            source: Data source.
            data_type: Data type.
            count: Number of records.
            duration: Elapsed time in seconds.
        """
        logger.info(f"Data collection finished - source: {source}, type: {data_type}, records: {count}, duration: {duration:.2f}s")
    def log_database_operation(self, operation: str, table: str, count: int = 0):
        """
        Log a database operation.

        Args:
            operation: Operation type.
            table: Table name.
            count: Number of affected rows.
        """
        if count > 0:
            logger.info(f"Database operation - {operation} {table}, rows affected: {count}")
        else:
            logger.info(f"Database operation - {operation} {table}")
    def log_scheduler_event(self, job_id: str, event: str, details: str = ""):
        """
        Log a scheduler event.

        Args:
            job_id: Job ID.
            event: Event type.
            details: Additional details.
        """
        if details:
            logger.info(f"Scheduler event - {job_id}: {event} - {details}")
        else:
            logger.info(f"Scheduler event - {job_id}: {event}")
    def log_performance_metric(self, metric: str, value: float, unit: str = ""):
        """
        Log a performance metric.

        Args:
            metric: Metric name.
            value: Metric value.
            unit: Unit of measurement.
        """
        if unit:
            logger.info(f"Performance metric - {metric}: {value} {unit}")
        else:
            logger.info(f"Performance metric - {metric}: {value}")
    def log_warning(self, message: str, context: Optional[Dict[str, Any]] = None):
        """
        Log a warning.

        Args:
            message: Warning message.
            context: Context information.
        """
        if context:
            logger.warning(f"{message} - context: {context}")
        else:
            logger.warning(message)
    def log_error(self, message: str, error: Optional[Exception] = None, context: Optional[Dict[str, Any]] = None):
        """
        Log an error.

        Args:
            message: Error message.
            error: Exception object.
            context: Context information.
        """
        if error and context:
            logger.error(f"{message} - error: {str(error)}, context: {context}")
        elif error:
            logger.error(f"{message} - error: {str(error)}")
        elif context:
            logger.error(f"{message} - context: {context}")
        else:
            logger.error(message)
    def log_critical(self, message: str, error: Optional[Exception] = None):
        """
        Log a critical error.

        Args:
            message: Error message.
            error: Exception object.
        """
        if error:
            logger.critical(f"{message} - error: {str(error)}")
        else:
            logger.critical(message)
    def send_alert(self, alert_type: str, message: str, level: str = "ERROR"):
        """
        Send an alert.

        Args:
            alert_type: Alert type.
            message: Alert message.
            level: Alert level.
        """
        alert_message = f"[ALERT] {alert_type}: {message}"
        if level == "CRITICAL":
            logger.critical(alert_message)
        elif level == "ERROR":
            logger.error(alert_message)
        elif level == "WARNING":
            logger.warning(alert_message)
        else:
            logger.info(alert_message)
    def get_log_files(self) -> list:
        """
        List all log files.

        Returns:
            A list of dictionaries describing each log file.
        """
        try:
            log_files = []
            for file_path in self.log_dir.glob("*.log"):
                stat = file_path.stat()  # stat once and reuse the result
                log_files.append({
                    "name": file_path.name,
                    "path": str(file_path),
                    "size": stat.st_size,
                    "modified": datetime.fromtimestamp(stat.st_mtime)
                })
            return log_files
        except Exception as e:
            logger.error(f"Failed to list log files: {str(e)}")
            return []
# Global log manager instance
log_manager = LogManager()
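The level routing in `send_alert` can also be written as a lookup table. The sketch below is an illustration under stated assumptions: it uses the standard `logging` module instead of loguru so that it runs without third-party packages, and `ListHandler`, `alert_logger` and `capture` are hypothetical names introduced only to make the routing observable.

```python
import logging
from typing import List

# Assumption: stdlib logging stands in for loguru in this sketch.
_LEVELS = {
    "CRITICAL": logging.CRITICAL,
    "ERROR": logging.ERROR,
    "WARNING": logging.WARNING,
}


class ListHandler(logging.Handler):
    """Collects emitted records in memory so the routing can be inspected."""

    def __init__(self) -> None:
        super().__init__()
        self.records: List[logging.LogRecord] = []

    def emit(self, record: logging.LogRecord) -> None:
        self.records.append(record)


alert_logger = logging.getLogger("alert_sketch")
alert_logger.setLevel(logging.DEBUG)
alert_logger.propagate = False
capture = ListHandler()
alert_logger.addHandler(capture)


def send_alert(alert_type: str, message: str, level: str = "ERROR") -> None:
    """Log an alert at the requested level; unknown levels fall back to INFO."""
    numeric_level = _LEVELS.get(level, logging.INFO)
    alert_logger.log(numeric_level, "[ALERT] %s: %s", alert_type, message)


send_alert("SYSTEM_EXCEPTION", "database unreachable", level="CRITICAL")
send_alert("SYSTEM_EXCEPTION", "slow response", level="NOTICE")  # unknown level
```

The fallback to INFO mirrors the final `else` branch of the original method, so an unrecognized level is logged rather than dropped.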

142
test_baostock_format.py Normal file

@ -0,0 +1,142 @@
"""
Test Baostock stock code format conversion.
"""
import sys
import os
import asyncio
import logging
# This script sits at the repository root, so its own directory is the one that contains src/
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from src.data.data_initializer import DataInitializer
from src.config.settings import Settings
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def get_baostock_format_code(stock_code: str) -> str:
    """
    Convert a stock code to the Baostock format.
    """
    # Already prefixed: return unchanged
    if stock_code.startswith("sh.") or stock_code.startswith("sz."):
        return stock_code
    # Codes starting with 6 are listed in Shanghai
    if stock_code.startswith("6"):
        return f"sh.{stock_code}"
    # Codes starting with 0 or 3 are listed in Shenzhen
    elif stock_code.startswith("0") or stock_code.startswith("3"):
        return f"sz.{stock_code}"
    else:
        return stock_code
async def test_baostock_format():
    """
    Test fetching data with Baostock-format stock codes.
    """
    try:
        logger.info("Starting the Baostock stock code format test...")
        # Load settings
        settings = Settings()
        logger.info("Settings loaded")
        # Create the data initializer
        initializer = DataInitializer(settings)
        logger.info("Data initializer created")
        # Stock codes under test
        test_codes = [
            "000001",  # 平安银行
            "600000",  # 浦发银行
            "300001",  # 特锐德
            "000007"   # 全新好
        ]
        logger.info("Testing stock code format conversion:")
        for code in test_codes:
            baostock_code = get_baostock_format_code(code)
            logger.info(f"  {code} -> {baostock_code}")
        # Test fetching K-line data
        total_kline_data = []
        success_count = 0
        error_count = 0
        for code in test_codes:
            try:
                baostock_code = get_baostock_format_code(code)
                logger.info(f"Fetching K-line data for stock {code} ({baostock_code})...")
                # Fetch K-line data through the data manager
                kline_data = await initializer.data_manager.get_daily_kline_data(
                    baostock_code,
                    "2024-01-01",
                    "2024-01-10"
                )
                if kline_data:
                    total_kline_data.extend(kline_data)
                    success_count += 1
                    logger.info(f"Stock {code}: fetched {len(kline_data)} K-line records")
                    # Print the first few records
                    for i, data in enumerate(kline_data[:3]):
                        logger.info(f"  Record {i+1}: {data}")
                else:
                    logger.warning(f"Stock {code}: no K-line data returned")
                    error_count += 1
            except Exception as e:
                logger.error(f"Failed to fetch K-line data for stock {code}: {str(e)}")
                error_count += 1
                continue
        logger.info(f"Test finished: {success_count} succeeded, {error_count} failed, {len(total_kline_data)} records fetched")
        return {
            "success": True,
            "test_codes": test_codes,
            "success_count": success_count,
            "error_count": error_count,
            "kline_data_count": len(total_kline_data)
        }
    except Exception as e:
        logger.error(f"Baostock format test raised an exception: {str(e)}")
        return {"success": False, "error": str(e)}
async def main():
    """
    Entry point.
    """
    result = await test_baostock_format()
    if result["success"]:
        logger.info("Baostock format test succeeded")
        print(f"Test result: {result}")
        if result["kline_data_count"] > 0:
            print("✓ Baostock format conversion succeeded; K-line data can be fetched")
        else:
            print("⚠ Baostock format conversion succeeded, but no K-line data was fetched")
    else:
        logger.error("Baostock format test failed")
        print(f"Test failed: {result.get('error')}")
    return result
if __name__ == "__main__":
    # Run the test
    result = asyncio.run(main())
    # Report the final outcome
    if result.get("success", False):
        print("\nTest finished!")
        sys.exit(0)
    else:
        print("\nTest failed!")
        sys.exit(1)
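The conversion rule in `get_baostock_format_code` is a pure function, so it can be checked without any network access. The sketch below restates the same rule under a hypothetical name, `to_baostock_code`, and pairs it with a small table of expected results taken from the codes used in the script above.

```python
def to_baostock_code(stock_code: str) -> str:
    """Prefix a six-digit A-share code with its exchange, as Baostock expects."""
    if stock_code.startswith(("sh.", "sz.")):
        return stock_code  # already in Baostock format
    if stock_code.startswith("6"):
        return f"sh.{stock_code}"  # Shanghai
    if stock_code.startswith(("0", "3")):
        return f"sz.{stock_code}"  # Shenzhen
    return stock_code  # unknown prefix: leave unchanged


# Expected results for the codes exercised by the script above
CASES = {
    "000001": "sz.000001",
    "600000": "sh.600000",
    "300001": "sz.300001",
    "sh.600000": "sh.600000",
}

for raw, expected in CASES.items():
    assert to_baostock_code(raw) == expected
```

Passing a tuple to `str.startswith` checks several prefixes in one call, which keeps each branch to a single condition.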

43
test_connection.py Normal file

@ -0,0 +1,43 @@
"""
Database connection test script.
"""
import os
import sys
from dotenv import load_dotenv
# Load the .env file explicitly
load_dotenv()
# Add the project paths
sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from src.config.settings import settings
from src.storage.database import db_manager
def test_database_connection():
    """Test the database connection."""
    print("=== Testing database connection ===")
    print(f"Database URL: {settings.database.database_url}")
    try:
        # Acquire a session
        session = db_manager.get_session()
        print("✅ Database connection succeeded")
        # Create the tables
        db_manager.create_tables()
        print("✅ Database tables created")
        session.close()
        print("✅ Database session closed")
    except Exception as e:
        print(f"❌ Database connection failed: {e}")
        return False
    return True
if __name__ == "__main__":
    # Exit non-zero when the check fails so that callers can detect it
    sys.exit(0 if test_database_connection() else 1)
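A connection check of this kind can be sketched with the standard library alone. The version below is an assumption-laden illustration: it uses `sqlite3` against an in-memory database rather than the project's `db_manager`, and `check_connection` and `exit_code` are hypothetical helpers. It shows how `contextlib.closing` guarantees the connection is closed even when a statement fails, and how the boolean result maps to a process exit code.

```python
import sqlite3
from contextlib import closing


def check_connection(database: str = ":memory:") -> bool:
    """Open the database, run a trivial statement, and always close the connection."""
    try:
        with closing(sqlite3.connect(database)) as conn:
            conn.execute("CREATE TABLE IF NOT EXISTS ping (id INTEGER PRIMARY KEY)")
            conn.execute("SELECT 1").fetchone()
        return True
    except sqlite3.Error as exc:
        print(f"connection failed: {exc}")
        return False


def exit_code(ok: bool) -> int:
    """Map a check result to a conventional process exit code."""
    return 0 if ok else 1


status = exit_code(check_connection())
```

A script would typically finish with `sys.exit(status)` so that a failed check is visible to shells and CI jobs.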

190
test_financial_update.py Normal file

@ -0,0 +1,190 @@
"""
Test the financial data update flow.
Verifies that financial data can be collected and saved.
"""
import sys
import os
import asyncio
import logging
from datetime import date, datetime
# This script sits at the repository root, so its own directory is the one that contains src/
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from src.storage.database import db_manager
from src.storage.stock_repository import StockRepository
from src.data.data_manager import DataManager
from src.config.settings import Settings
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
async def test_financial_update():
    """
    Test the financial data update flow.
    """
    try:
        logger.info("Starting the financial data update test...")
        # Load settings
        settings = Settings()
        logger.info("Settings loaded")
        # Create the data manager
        data_manager = DataManager()
        logger.info("Data manager created")
        # Create the repository
        repository = StockRepository(db_manager.get_session())
        logger.info("Repository created")
        # Load basic stock information
        stocks = repository.get_stock_basic_info()
        logger.info(f"Loaded basic information for {len(stocks)} stocks")
        if not stocks:
            logger.error("No basic stock information available; cannot test the financial data update")
            return {"success": False, "error": "No basic stock information available"}
        # Test with the first 5 stocks
        test_stocks = stocks[:5]
        test_codes = [stock.code for stock in test_stocks]
        logger.info(f"Stock codes under test: {test_codes}")
        # Year and quarter under test
        test_year = 2023
        test_quarter = 4
        total_financial_data = []
        success_count = 0
        error_count = 0
        # Fetch financial data for each test stock
        for stock in test_stocks:
            try:
                logger.info(f"Fetching financial data for stock {stock.code}...")
                # Fetch financial data through the data manager
                financial_data = await data_manager.get_financial_report(
                    stock.code, test_year, test_quarter
                )
                if financial_data:
                    total_financial_data.extend(financial_data)
                    success_count += 1
                    logger.info(f"Stock {stock.code}: fetched {len(financial_data)} financial records")
                else:
                    logger.warning(f"Stock {stock.code}: no financial data returned")
                    error_count += 1
                # Short delay to avoid sending requests too quickly
                await asyncio.sleep(0.5)
            except Exception as e:
                logger.error(f"Failed to fetch financial data for stock {stock.code}: {str(e)}")
                error_count += 1
                continue
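        # Save the financial data
        save_result = {}  # stays empty if nothing is saved or saving fails, so the summary below can still read it
        if total_financial_data:
            try:
                logger.info(f"Saving {len(total_financial_data)} financial records...")
                save_result = repository.save_financial_report_data(total_financial_data)
                logger.info(f"Financial data save result: {save_result}")
                # Verify the save result
                if save_result.get("added_count", 0) > 0 or save_result.get("updated_count", 0) > 0:
                    logger.info("Financial data saved")
                else:
                    logger.warning("Financial data was not saved, or there was nothing new to save")
            except Exception as e:
                logger.error(f"Failed to save financial data: {str(e)}")
                error_count += len(test_stocks)
        else:
            logger.warning("No financial data was fetched")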
        # Verify the financial data stored in the database
        try:
            financial_count = repository.session.query(repository.FinancialReport).count()
            logger.info(f"financial_report table: {financial_count} records")
            if financial_count > 0:
                # Show the 5 most recent financial records
                latest_financial = repository.session.query(repository.FinancialReport).order_by(
                    repository.FinancialReport.report_date.desc()
                ).limit(5).all()
                logger.info("Latest 5 financial records:")
                for i, financial in enumerate(latest_financial):
                    logger.info(f"  {i+1}. {financial.stock_code} - {financial.report_date} - EPS: {financial.eps}")
        except Exception as e:
            logger.error(f"Failed to query financial data: {str(e)}")
        # Summarize the test results
        result = {
            "success": True,
            "test_stocks": len(test_stocks),
            "success_count": success_count,
            "error_count": error_count,
            "financial_data_count": len(total_financial_data),
            "saved_count": save_result.get("added_count", 0) + save_result.get("updated_count", 0) if total_financial_data else 0
        }
        logger.info(f"Financial data update test finished: {result}")
        return result
    except Exception as e:
        logger.error(f"Financial data update test raised an exception: {str(e)}")
        return {"success": False, "error": str(e)}
def main():
    """
    Entry point.
    """
    logger.info("Starting the financial data update test...")
    # Run the async test
    result = asyncio.run(test_financial_update())
    if result["success"]:
        logger.info("Financial data update test succeeded!")
        print(f"Test result: {result}")
        if result["financial_data_count"] > 0:
            print("✓ Financial data fetched")
            print(f"✓ Fetched {result['financial_data_count']} financial records in total")
        else:
            print("⚠ No financial data was fetched")
        if result["saved_count"] > 0:
            print("✓ Financial data saved")
            print(f"✓ Saved {result['saved_count']} financial records in total")
        else:
            print("⚠ Financial data was not saved")
        print(f"✓ Stocks succeeded: {result['success_count']}")
        print(f"✓ Stocks failed: {result['error_count']}")
    else:
        logger.error("Financial data update test failed!")
        print(f"Test failed: {result.get('error')}")
    return result
if __name__ == "__main__":
    # Run the test
    result = main()
    # Report the final outcome
    if result.get("success", False):
        print("\nFinancial data update test finished!")
        sys.exit(0)
    else:
        print("\nFinancial data update test failed!")
        sys.exit(1)
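The fetch loop in this script, with its per-stock error handling and running tallies, can be exercised without a data source. The following sketch is a simplified illustration: `collect_reports` and `fake_fetch` are hypothetical names, and the fetcher is a stub that simulates one success, one exception and one empty result.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List

Fetcher = Callable[[str], Awaitable[List[Dict[str, Any]]]]


async def collect_reports(codes: List[str], fetch: Fetcher, delay: float = 0.0) -> Dict[str, int]:
    """Fetch rows for each code, counting successes and failures separately."""
    rows: List[Dict[str, Any]] = []
    success_count = 0
    error_count = 0
    for code in codes:
        try:
            data = await fetch(code)
        except Exception:
            error_count += 1  # one failing stock must not stop the loop
            continue
        if data:
            rows.extend(data)
            success_count += 1
        else:
            error_count += 1  # an empty result is counted as a failure, as in the script
        await asyncio.sleep(delay)  # pause between requests
    return {"success_count": success_count, "error_count": error_count, "row_count": len(rows)}


async def fake_fetch(code: str) -> List[Dict[str, Any]]:
    """Stub fetcher: one code raises, one returns nothing, the rest return one row."""
    if code == "600000":
        raise RuntimeError("simulated timeout")
    if code == "300001":
        return []
    return [{"code": code, "eps": 1.5}]


summary = asyncio.run(collect_reports(["000001", "600000", "300001"], fake_fetch))
```

Passing the fetcher as a parameter is what makes the loop testable: the real script could pass `data_manager.get_financial_report` wrapped in a small adapter, while a test passes a stub.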

115
test_simple_update.py Normal file

@ -0,0 +1,115 @@
"""
Simple test script.
Tests the data update flow.
"""
import sys
import os
import asyncio
import logging
# This script sits at the repository root, so its own directory is the one that contains src/
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from src.data.data_initializer import DataInitializer
from src.config.settings import Settings
from src.storage.database import db_manager
from src.storage.stock_repository import StockRepository
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
async def test_data_update():
    """
    Test the data update flow.
    """
    try:
        logger.info("Starting the data update test...")
        # Load settings
        settings = Settings()
        logger.info("Settings loaded")
        # Create the data initializer
        initializer = DataInitializer(settings)
        logger.info("Data initializer created")
        # Create the repository
        repository = StockRepository(db_manager.get_session())
        logger.info("Repository created")
        # Load all stock codes
        stocks = repository.get_stock_basic_info()
        logger.info(f"Found {len(stocks)} stocks")
        if not stocks:
            logger.error("No basic stock information available; cannot run the test")
            return {"success": False, "error": "No basic stock information available"}
        # Test only the first 5 stocks
        test_stocks = stocks[:5]
        logger.info(f"Testing the first {len(test_stocks)} stocks")
        # Test fetching K-line data
        total_kline_data = []
        for stock in test_stocks:
            try:
                logger.info(f"Fetching K-line data for stock {stock.code}...")
                # Fetch K-line data through the data manager
                kline_data = await initializer.data_manager.get_daily_kline_data(
                    stock.code,
                    "2024-01-01",
                    "2024-01-10"
                )
                if kline_data:
                    total_kline_data.extend(kline_data)
                    logger.info(f"Stock {stock.code}: fetched {len(kline_data)} K-line records")
                else:
                    logger.warning(f"Stock {stock.code}: no K-line data returned")
            except Exception as e:
                logger.error(f"Failed to fetch K-line data for stock {stock.code}: {str(e)}")
                continue
        logger.info(f"Test finished: fetched {len(total_kline_data)} K-line records in total")
        return {
            "success": True,
            "test_stock_count": len(test_stocks),
            "kline_data_count": len(total_kline_data)
        }
    except Exception as e:
        logger.error(f"Data update test raised an exception: {str(e)}")
        return {"success": False, "error": str(e)}
async def main():
    """
    Entry point.
    """
    result = await test_data_update()
    if result["success"]:
        logger.info("Data update test succeeded!")
        print(f"Test result: {result}")
    else:
        logger.error("Data update test failed!")
        print(f"Test failed: {result.get('error')}")
    return result
if __name__ == "__main__":
    # Run the test
    result = asyncio.run(main())
    # Report the final outcome
    if result.get("success", False):
        print("Test succeeded!")
        sys.exit(0)
    else:
        print("Test failed!")
        sys.exit(1)

1
tests/__init__.py Normal file

@ -0,0 +1 @@
# Test package

273
tests/conftest.py Normal file

@ -0,0 +1,273 @@
"""
Test configuration.
Defines the fixtures and settings used by the tests.
"""
import pytest
import asyncio
from pathlib import Path
from typing import Generator, AsyncGenerator
from src.config.settings import Settings
from src.storage.database import DatabaseManager
from src.data.data_processor import DataProcessor
from src.utils.logger import LogManager
@pytest.fixture(scope="session")
def test_settings() -> Settings:
    """Settings used by the tests."""
    # Create a settings instance for the tests
    return Settings()
@pytest.fixture(scope="session")
def test_log_manager(test_settings) -> LogManager:
    """Log manager used by the tests."""
    return LogManager(
        log_dir="./test_logs",
        log_level=test_settings.log.log_level
    )
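def create_test_database_manager():
    """Create a database manager for the tests."""
    # Build a simple test database manager backed by in-memory SQLite
    from sqlalchemy import create_engine
    # declarative_base is importable from sqlalchemy.orm since SQLAlchemy 1.4;
    # the sqlalchemy.ext.declarative location is deprecated
    from sqlalchemy.orm import declarative_base, sessionmaker
    class TestDatabaseManager:
        def __init__(self):
            self.engine = create_engine("sqlite:///:memory:")
            self.SessionLocal = sessionmaker(bind=self.engine)
            # Use a separate Base class for the tests
            self.Base = declarative_base()
            # Import and redefine the model classes
            self._import_models()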
        def _import_models(self):
            """Import the column types and redefine the model classes."""
            from sqlalchemy import Column, Integer, String, Float, Date, DateTime, Text, Boolean, BigInteger, ForeignKey, Index
            from sqlalchemy.sql import func
            # Redefine the StockBasic model
            class StockBasic(self.Base):
                __tablename__ = "stock_basic"
                id = Column(Integer, primary_key=True, autoincrement=True)
                code = Column(String(10), nullable=False, unique=True)
                name = Column(String(50), nullable=False)
                market = Column(String(10), nullable=False)
                company_name = Column(String(100))
                industry = Column(String(50))
                area = Column(String(50))
                ipo_date = Column(Date)
                listing_status = Column(Boolean, default=True)
                data_source = Column(String(50), default="akshare")  # data_source field added
                created_at = Column(DateTime, server_default=func.now())
                updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
                __table_args__ = (
                    Index("idx_code", "code"),
                    Index("idx_market", "market"),
                    Index("idx_industry", "industry"),
                    Index("idx_ipo_date", "ipo_date")
                )
            # Redefine the DailyKline model
            class DailyKline(self.Base):
                __tablename__ = "daily_kline"
                id = Column(Integer, primary_key=True, autoincrement=True)
                stock_code = Column(String(10), ForeignKey("stock_basic.code"), nullable=False)
                trade_date = Column(Date, nullable=False)
                open_price = Column(Float, nullable=False)
                high_price = Column(Float, nullable=False)
                low_price = Column(Float, nullable=False)
                close_price = Column(Float, nullable=False)
                volume = Column(BigInteger)
                amount = Column(Float)
                change = Column(Float)
                pct_change = Column(Float)
                created_at = Column(DateTime, server_default=func.now())
                __table_args__ = (
                    Index("idx_stock_code_date", "stock_code", "trade_date"),
                    Index("idx_trade_date", "trade_date")
                )
            # Redefine the FinancialReport model
            class FinancialReport(self.Base):
                __tablename__ = "financial_report"
                id = Column(Integer, primary_key=True, autoincrement=True)
                stock_code = Column(String(10), ForeignKey("stock_basic.code"), nullable=False)
                report_date = Column(Date, nullable=False)
                report_type = Column(String(20), nullable=False)
                report_year = Column(Integer, nullable=False)
                report_quarter = Column(Integer)
                eps = Column(Float)
                net_profit = Column(Float)
                revenue = Column(Float)
                total_assets = Column(Float)
                total_liabilities = Column(Float)
                equity = Column(Float)
                roe = Column(Float)
                gross_profit_margin = Column(Float)
                net_profit_margin = Column(Float)
                debt_to_asset_ratio = Column(Float)
                current_ratio = Column(Float)
                quick_ratio = Column(Float)
                created_at = Column(DateTime, server_default=func.now())
                __table_args__ = (
                    Index("idx_stock_code_report", "stock_code", "report_date"),
                    Index("idx_report_date", "report_date"),
                    Index("idx_report_type", "report_type")
                )
            self.StockBasic = StockBasic
            self.DailyKline = DailyKline
            self.FinancialReport = FinancialReport
        def create_tables(self):
            """Create all tables."""
            try:
                # Create every table registered on the test Base
                self.Base.metadata.create_all(bind=self.engine)
                return True
            except Exception as e:
                print(f"Failed to create database tables: {str(e)}")
                return False
        def drop_tables(self):
            """Drop all tables."""
            try:
                self.Base.metadata.drop_all(bind=self.engine)
                return True
            except Exception as e:
                print(f"Failed to drop database tables: {str(e)}")
                return False
        def get_session(self):
            """Return a database session."""
            try:
                session = self.SessionLocal()
                return session
            except Exception as e:
                print(f"Failed to get a database session: {str(e)}")
                raise
    return TestDatabaseManager()
@pytest.fixture(scope="function")
def db_manager():
    """Test database manager fixture."""
    db_manager = create_test_database_manager()
    # Create the test tables
    db_manager.create_tables()
    yield db_manager
    # Clean up the test tables
    db_manager.drop_tables()
@pytest.fixture(scope="session")
def test_data_processor() -> DataProcessor:
    """Data processor used by the tests."""
    return DataProcessor()
@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Event loop shared by the async tests."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()
@pytest.fixture
def sample_stock_basic_data() -> list:
    """Sample basic stock data."""
    return [
        {
            "code": "000001",
            "name": "平安银行",
            "market": "主板",
            "industry": "银行",
            "area": "广东",
            "ipo_date": "1991-04-03"
        },
        {
            "code": "600000",
            "name": "浦发银行",
            "market": "主板",
            "industry": "银行",
            "area": "上海",
            "ipo_date": "1999-11-10"
        }
    ]
@pytest.fixture
def sample_kline_data() -> list:
    """Sample K-line data."""
    return [
        {
            "code": "000001",
            "date": "2024-01-15",
            "open": 10.5,
            "high": 11.2,
            "low": 10.3,
            "close": 10.8,
            "volume": 1000000,
            "amount": 10800000
        },
        {
            "code": "000001",
            "date": "2024-01-16",
            "open": 10.8,
            "high": 11.5,
            "low": 10.7,
            "close": 11.2,
            "volume": 1200000,
            "amount": 13440000
        }
    ]
@pytest.fixture
def sample_financial_data() -> list:
    """Sample financial data."""
    return [
        {
            "code": "000001",
            "report_date": "2023-12-31",
            "eps": 1.5,
            "net_profit": 1500000000,
            "revenue": 5000000000,
            "total_assets": 10000000000
        }
    ]
@pytest.fixture
def invalid_stock_data() -> list:
    """Invalid stock data."""
    return [
        {
            "code": "invalid_code",
            "name": "",
            "market": "无效市场",
            "ipo_date": "2099-01-01"  # date in the future
        },
        {
            "code": "000001",
            "name": "测试股票",
            "open": -10.5,  # negative price
            "high": 9.0,    # high is below low
            "low": 11.0,
            "close": 10.0
        }
    ]
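The `invalid_stock_data` fixture encodes several distinct failure modes: a malformed code, an empty name, a future listing date, a negative price, and a high below the low. A validator that reports each of them might look like the sketch below. It is a hypothetical helper, `find_record_issues`, written for illustration only; the project's own `DataProcessor` validation is not shown in this section and may apply different rules.

```python
from datetime import date
from typing import Any, Dict, List, Optional


def find_record_issues(record: Dict[str, Any], today: Optional[date] = None) -> List[str]:
    """Return a human-readable list of problems found in one stock record."""
    today = today or date.today()
    issues: List[str] = []
    code = str(record.get("code", ""))
    if not (len(code) == 6 and code.isdigit()):
        issues.append("code must be six digits")
    if not record.get("name"):
        issues.append("name is empty")
    ipo = record.get("ipo_date")
    if ipo and date.fromisoformat(ipo) > today:
        issues.append("ipo_date is in the future")
    prices = [record.get(key) for key in ("open", "high", "low", "close")]
    if any(p is not None and p <= 0 for p in prices):
        issues.append("price must be positive")
    high, low = record.get("high"), record.get("low")
    if high is not None and low is not None and high < low:
        issues.append("high is below low")
    return issues


# The two records mirror the invalid_stock_data fixture above
bad_basic = {"code": "invalid_code", "name": "", "ipo_date": "2099-01-01"}
bad_kline = {"code": "000001", "name": "test", "open": -10.5, "high": 9.0, "low": 11.0, "close": 10.0}
fixed_today = date(2025, 1, 1)  # pinned so the future-date check is deterministic
basic_issues = find_record_issues(bad_basic, today=fixed_today)
kline_issues = find_record_issues(bad_kline, today=fixed_today)
```

Returning a list of issues instead of raising on the first one lets a test assert on every defect in a record at once.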


@ -0,0 +1,300 @@
"""
Unit tests for the data collection interfaces.
Tests the AKshare and Baostock data collectors.
"""
import pytest
import asyncio
from unittest.mock import Mock, patch, AsyncMock
from src.data.akshare_collector import AKshareCollector
from src.data.baostock_collector import BaostockCollector
from src.data.data_manager import DataManager
from src.utils.exceptions import DataCollectionError
class TestAKshareCollector:
    """Tests for the AKshare collector."""
    @pytest.fixture
    def akshare_collector(self):
        """AKshare collector instance."""
        return AKshareCollector()
    @pytest.mark.asyncio
    async def test_get_stock_basic_info_success(self, akshare_collector):
        """Basic stock information is fetched successfully."""
        # Mock the akshare API call
        mock_data = [
            {
                "代码": "000001",
                "名称": "平安银行",
                "市场": "主板",
                "行业": "银行",
                "地区": "广东",
                "上市日期": "1991-04-03"
            }
        ]
        with patch("akshare.stock_info_a_code_name", return_value=mock_data):
            result = await akshare_collector.get_stock_basic_info()
            assert len(result) == 1
            assert result[0]["code"] == "000001"
            assert result[0]["name"] == "平安银行"
            assert result[0]["market"] == "主板"
            assert result[0]["industry"] == "银行"
            assert result[0]["area"] == "广东"
            assert result[0]["ipo_date"] == "1991-04-03"
    @pytest.mark.asyncio
    async def test_get_stock_basic_info_empty(self, akshare_collector):
        """An empty result is returned as an empty list."""
        with patch("akshare.stock_info_a_code_name", return_value=[]):
            result = await akshare_collector.get_stock_basic_info()
            assert len(result) == 0
    @pytest.mark.asyncio
    async def test_get_stock_basic_info_error(self, akshare_collector):
        """An API failure is surfaced as DataCollectionError."""
        with patch("akshare.stock_info_a_code_name", side_effect=Exception("API error")):
            with pytest.raises(DataCollectionError):
                await akshare_collector.get_stock_basic_info()
    @pytest.mark.asyncio
    async def test_get_daily_kline_data_success(self, akshare_collector):
        """Daily K-line data is fetched successfully."""
        mock_data = [
            {
                "日期": "2024-01-15",
                "开盘": 10.5,
                "最高": 11.2,
                "最低": 10.3,
                "收盘": 10.8,
                "成交量": 1000000,
                "成交额": 10800000
            }
        ]
        with patch("akshare.stock_zh_a_hist", return_value=mock_data):
            result = await akshare_collector.get_daily_kline_data("000001", "20240101", "20240115")
            assert len(result) == 1
            assert result[0]["date"] == "2024-01-15"
            assert result[0]["open"] == 10.5
            assert result[0]["high"] == 11.2
            assert result[0]["low"] == 10.3
            assert result[0]["close"] == 10.8
            assert result[0]["volume"] == 1000000
            assert result[0]["amount"] == 10800000
    @pytest.mark.asyncio
    async def test_get_financial_report_success(self, akshare_collector):
        """A financial report is fetched successfully."""
        mock_data = [
            {
                "报告期": "2023-12-31",
                "每股收益": 1.5,
                "净利润": 1500000000,
                "营业收入": 5000000000,
                "总资产": 10000000000
            }
        ]
        with patch("akshare.stock_financial_report_szse", return_value=mock_data):
            result = await akshare_collector.get_financial_report("000001", "2023")
            assert len(result) == 1
            assert result[0]["report_date"] == "2023-12-31"
            assert result[0]["eps"] == 1.5
            assert result[0]["net_profit"] == 1500000000
            assert result[0]["revenue"] == 5000000000
            assert result[0]["total_assets"] == 10000000000
class TestBaostockCollector:
    """Tests for the Baostock collector."""
    @pytest.fixture
    def baostock_collector(self):
        """Baostock collector instance."""
        return BaostockCollector()
    @pytest.mark.asyncio
    async def test_login_logout_success(self, baostock_collector):
        """Login and logout both succeed."""
        with patch("baostock.login", return_value=(0, "")) as mock_login:
            with patch("baostock.logout", return_value=(0, "")) as mock_logout:
                # Test login
                result = await baostock_collector.login()
                assert result is True
                mock_login.assert_called_once()
                # Test logout
                result = await baostock_collector.logout()
                assert result is True
                mock_logout.assert_called_once()
    @pytest.mark.asyncio
    async def test_login_failure(self, baostock_collector):
        """A failed login returns False."""
        with patch("baostock.login", return_value=(1, "login failed")):
            result = await baostock_collector.login()
            assert result is False
    @pytest.mark.asyncio
    async def test_get_stock_basic_info_success(self, baostock_collector):
        """Basic stock information is fetched successfully."""
        # Mock the baostock API call
        mock_result = Mock()
        mock_result.error_code = 0
        mock_result.error_msg = ""
        mock_result.data = [
            ["000001", "平安银行", "主板", "银行", "广东", "1991-04-03"]
        ]
        with patch("baostock.query_stock_basic", return_value=mock_result):
            with patch.object(baostock_collector, "login", return_value=True):
                result = await baostock_collector.get_stock_basic_info()
                assert len(result) == 1
                assert result[0]["code"] == "000001"
                assert result[0]["name"] == "平安银行"
                assert result[0]["market"] == "主板"
                assert result[0]["industry"] == "银行"
                assert result[0]["area"] == "广东"
                assert result[0]["ipo_date"] == "1991-04-03"
    @pytest.mark.asyncio
    async def test_get_daily_kline_data_success(self, baostock_collector):
        """Daily K-line data is fetched successfully."""
        mock_result = Mock()
        mock_result.error_code = 0
        mock_result.error_msg = ""
        mock_result.data = [
            ["000001", "2024-01-15", 10.5, 11.2, 10.3, 10.8, 1000000, 10800000]
        ]
        with patch("baostock.query_history_k_data_plus", return_value=mock_result):
            with patch.object(baostock_collector, "login", return_value=True):
                result = await baostock_collector.get_daily_kline_data("000001", "2024-01-01", "2024-01-15")
                assert len(result) == 1
                assert result[0]["code"] == "000001"
                assert result[0]["date"] == "2024-01-15"
                assert result[0]["open"] == 10.5
                assert result[0]["high"] == 11.2
                assert result[0]["low"] == 10.3
                assert result[0]["close"] == 10.8
                assert result[0]["volume"] == 1000000
                assert result[0]["amount"] == 10800000
class TestDataManager:
    """Tests for the data manager."""
    @pytest.fixture
    def data_manager(self):
        """Data manager instance."""
        return DataManager()
    @pytest.mark.asyncio
    async def test_get_stock_basic_info_success(self, data_manager):
        """Basic stock information from both collectors is merged."""
        # Mock the data returned by the collectors
        mock_akshare_data = [
            {"code": "000001", "name": "平安银行", "data_source": "akshare"}
        ]
        mock_baostock_data = [
            {"code": "600000", "name": "浦发银行", "data_source": "baostock"}
        ]
        with patch.object(data_manager.akshare_collector, "get_stock_basic_info", return_value=mock_akshare_data):
            with patch.object(data_manager.baostock_collector, "get_stock_basic_info", return_value=mock_baostock_data):
                result = await data_manager.get_stock_basic_info()
                assert len(result) == 2
                assert any(item["code"] == "000001" for item in result)
                assert any(item["code"] == "600000" for item in result)
    @pytest.mark.asyncio
    async def test_get_stock_basic_info_duplicate(self, data_manager):
        """Duplicate records from the two collectors are removed."""
        # Mock the collectors so that both return the same stock
        mock_akshare_data = [
            {"code": "000001", "name": "平安银行", "data_source": "akshare"}
        ]
        mock_baostock_data = [
            {"code": "000001", "name": "平安银行", "data_source": "baostock"}
        ]
        with patch.object(data_manager.akshare_collector, "get_stock_basic_info", return_value=mock_akshare_data):
            with patch.object(data_manager.baostock_collector, "get_stock_basic_info", return_value=mock_baostock_data):
                result = await data_manager.get_stock_basic_info()
                # Duplicates should be removed, leaving a single record
                assert len(result) == 1
                assert result[0]["code"] == "000001"
    @pytest.mark.asyncio
    async def test_get_daily_kline_data_success(self, data_manager):
        """Daily K-line data is fetched successfully."""
        mock_data = [
            {"code": "000001", "date": "2024-01-15", "open": 10.5, "data_source": "akshare"}
        ]
        with patch.object(data_manager.akshare_collector, "get_daily_kline_data", return_value=mock_data):
            result = await data_manager.get_daily_kline_data("000001", "20240101", "20240115")
            assert len(result) == 1
            assert result[0]["code"] == "000001"
            assert result[0]["date"] == "2024-01-15"
    @pytest.mark.asyncio
    async def test_get_financial_report_success(self, data_manager):
        """A financial report is fetched successfully."""
        mock_data = [
            {"code": "000001", "report_date": "2023-12-31", "eps": 1.5, "data_source": "akshare"}
        ]
        with patch.object(data_manager.akshare_collector, "get_financial_report", return_value=mock_data):
            result = await data_manager.get_financial_report("000001", "2023")
            assert len(result) == 1
            assert result[0]["code"] == "000001"
            assert result[0]["report_date"] == "2023-12-31"
            assert result[0]["eps"] == 1.5
class TestRetryMechanism:
    """Tests for the retry mechanism."""
    @pytest.mark.asyncio
    async def test_retry_on_failure(self):
        """Retry after failures until a call succeeds."""
        collector = AKshareCollector()
        # The mocked function fails twice, then succeeds on the third call
        mock_func = AsyncMock()
        mock_func.side_effect = [Exception("first failure"), Exception("second failure"), ["success data"]]
        # Call the real _make_request_with_retry: patching it with the mock would bypass the retry logic under test
        result = await collector._make_request_with_retry(mock_func, "test", max_retries=3)
        assert result == ["success data"]
        assert mock_func.call_count == 3
    @pytest.mark.asyncio
    async def test_retry_exceed_max_attempts(self):
        """Give up once the maximum number of retries is exceeded."""
        collector = AKshareCollector()
        # The mocked function always fails
        mock_func = AsyncMock(side_effect=Exception("always fails"))
        with pytest.raises(DataCollectionError):
            await collector._make_request_with_retry(mock_func, "test", max_retries=2)
        # Checked outside the raises block: statements placed after the raising call inside it never run
        assert mock_func.call_count == 3  # initial call + 2 retries
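The retry behaviour these tests target can be illustrated with a self-contained helper. The sketch below is not the project's `_make_request_with_retry`, whose implementation is not shown here; `retry_async` and `RetryError` are hypothetical names. It assumes the same counting convention as the comment in the test above: `max_retries` counts the retries after the initial call, so `max_retries=2` means at most three calls.

```python
import asyncio
from typing import Any, Awaitable, Callable, Optional
from unittest.mock import AsyncMock


class RetryError(Exception):
    """Raised when every attempt has failed (hypothetical, for this sketch only)."""


async def retry_async(
    func: Callable[..., Awaitable[Any]],
    *args: Any,
    max_retries: int = 3,
    delay: float = 0.0,
) -> Any:
    """Await func(*args); on failure retry up to max_retries more times."""
    last_error: Optional[Exception] = None
    for attempt in range(1 + max_retries):  # initial call + retries
        try:
            return await func(*args)
        except Exception as exc:
            last_error = exc
            if attempt < max_retries:
                await asyncio.sleep(delay)  # wait before the next attempt
    raise RetryError(f"gave up after {1 + max_retries} attempts") from last_error


# Fails twice, then succeeds on the third call
flaky = AsyncMock(side_effect=[RuntimeError("first"), RuntimeError("second"), ["ok"]])
result = asyncio.run(retry_async(flaky, "test", max_retries=3))

# Always fails: the helper gives up after the initial call plus 2 retries
always_failing = AsyncMock(side_effect=RuntimeError("down"))
try:
    asyncio.run(retry_async(always_failing, "test", max_retries=2))
    gave_up = False
except RetryError:
    gave_up = True
```

The mock is passed in as the function to be retried while the retry helper itself stays real, which is the arrangement a retry test needs in order to count attempts meaningfully.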


@ -0,0 +1,473 @@
"""
Unit tests for the data processing and scheduler modules.
Tests the data processor and the task scheduler.
"""
import pytest
import asyncio
from unittest.mock import Mock, patch, AsyncMock
from datetime import datetime, timedelta
from src.data.data_processor import DataProcessor
from src.scheduler.task_scheduler import TaskScheduler
from src.utils.exceptions import DataProcessingError
class TestDataProcessor:
    """Tests for the data processor."""
    @pytest.fixture
    def data_processor(self):
        """Data processor instance."""
        return DataProcessor()
    def test_process_stock_basic_info_success(self, data_processor):
        """Basic stock information is processed successfully."""
        raw_data = [
            {
                "代码": "000001",
                "名称": "平安银行",
                "市场": "主板",
                "行业": "银行",
                "地区": "广东",
                "上市日期": "1991-04-03",
                "数据源": "akshare"
            }
        ]
        result = data_processor.process_stock_basic_info(raw_data)
        assert len(result) == 1
        assert result[0]["code"] == "000001"
        assert result[0]["name"] == "平安银行"
        assert result[0]["market"] == "主板"
        assert result[0]["industry"] == "银行"
        assert result[0]["area"] == "广东"
        assert result[0]["ipo_date"] == "1991-04-03"
        assert result[0]["data_source"] == "akshare"
    def test_process_stock_basic_info_missing_fields(self, data_processor):
        """Missing fields are filled with default values."""
        raw_data = [
            {
                "代码": "000001",
                "名称": "平安银行"
                # the other fields are missing
            }
        ]
        result = data_processor.process_stock_basic_info(raw_data)
        assert len(result) == 1
        assert result[0]["code"] == "000001"
        assert result[0]["name"] == "平安银行"
        assert result[0]["market"] == "未知"  # default value
        assert result[0]["industry"] == "未知"  # default value
        assert result[0]["area"] == "未知"  # default value
        assert result[0]["ipo_date"] is None  # default value
def test_process_daily_kline_data_success(self, data_processor):
"""测试处理日K线数据成功"""
raw_data = [
{
"日期": "2024-01-15",
"开盘": 10.5,
"最高": 11.2,
"最低": 10.3,
"收盘": 10.8,
"成交量": 1000000,
"成交额": 10800000,
"数据源": "akshare"
}
]
result = data_processor.process_daily_kline_data("000001", raw_data)
assert len(result) == 1
assert result[0]["code"] == "000001"
assert result[0]["date"] == "2024-01-15"
assert result[0]["open"] == 10.5
assert result[0]["high"] == 11.2
assert result[0]["low"] == 10.3
assert result[0]["close"] == 10.8
assert result[0]["volume"] == 1000000
assert result[0]["amount"] == 10800000
assert result[0]["data_source"] == "akshare"
def test_process_daily_kline_data_invalid_values(self, data_processor):
"""测试处理无效值的日K线数据"""
raw_data = [
{
"日期": "2024-01-15",
"开盘": "无效值", # 字符串而不是数字
"最高": 11.2,
"最低": 10.3,
"收盘": 10.8,
"成交量": 1000000,
"成交额": 10800000
}
]
with pytest.raises(DataProcessingError):
data_processor.process_daily_kline_data("000001", raw_data)
def test_process_financial_report_data_success(self, data_processor):
"""测试处理财务报告数据成功"""
raw_data = [
{
"报告期": "2023-12-31",
"报告类型": "年报",
"每股收益": 1.5,
"净利润": 1500000000,
"营业收入": 5000000000,
"总资产": 10000000000,
"数据源": "akshare"
}
]
result = data_processor.process_financial_report_data("000001", raw_data)
assert len(result) == 1
assert result[0]["code"] == "000001"
assert result[0]["report_date"] == "2023-12-31"
assert result[0]["report_type"] == "年报"
assert result[0]["eps"] == 1.5
assert result[0]["net_profit"] == 1500000000
assert result[0]["revenue"] == 5000000000
assert result[0]["total_assets"] == 10000000000
assert result[0]["data_source"] == "akshare"
def test_validate_data_success(self, data_processor):
"""测试数据验证成功"""
valid_data = [
{
"code": "000001",
"name": "平安银行",
"market": "主板",
"ipo_date": "1991-04-03"
}
]
result = data_processor._validate_data(valid_data, ["code", "name"])
assert result is True
def test_validate_data_missing_required_fields(self, data_processor):
"""测试数据验证缺失必要字段"""
invalid_data = [
{
"code": "000001"
# 缺少name字段
}
]
with pytest.raises(DataProcessingError):
data_processor._validate_data(invalid_data, ["code", "name"])
def test_clean_data_success(self, data_processor):
"""测试数据清洗成功"""
dirty_data = [
{
"code": " 000001 ", # 有空格
"name": "平安银行",
"market": "主板",
"ipo_date": "1991-04-03"
}
]
result = data_processor._clean_data(dirty_data)
assert len(result) == 1
assert result[0]["code"] == "000001" # 空格被去除
def test_standardize_data_success(self, data_processor):
"""测试数据标准化成功"""
raw_data = [
{
"代码": "000001",
"名称": "平安银行",
"市场": "主板"
}
]
field_mapping = {
"代码": "code",
"名称": "name",
"市场": "market"
}
result = data_processor._standardize_data(raw_data, field_mapping)
assert len(result) == 1
assert "code" in result[0]
assert "name" in result[0]
assert "market" in result[0]
assert result[0]["code"] == "000001"
assert result[0]["name"] == "平安银行"
assert result[0]["market"] == "主板"
class TestTaskScheduler:
"""定时任务调度器测试类"""
@pytest.fixture
def task_scheduler(self):
"""定时任务调度器实例"""
return TaskScheduler()
def test_scheduler_initialization(self, task_scheduler):
"""测试调度器初始化"""
assert task_scheduler.scheduler is None
assert task_scheduler.is_running is False
@pytest.mark.asyncio
async def test_start_scheduler_success(self, task_scheduler):
"""测试启动调度器成功"""
with patch.object(task_scheduler, "_configure_jobs"):
result = await task_scheduler.start()
assert result is True
assert task_scheduler.is_running is True
assert task_scheduler.scheduler is not None
@pytest.mark.asyncio
async def test_stop_scheduler_success(self, task_scheduler):
"""测试停止调度器成功"""
# 先启动调度器
with patch.object(task_scheduler, "_configure_jobs"):
await task_scheduler.start()
# 停止调度器
result = await task_scheduler.stop()
assert result is True
assert task_scheduler.is_running is False
@pytest.mark.asyncio
async def test_stop_scheduler_not_running(self, task_scheduler):
"""测试停止未运行的调度器"""
result = await task_scheduler.stop()
assert result is False
def test_configure_daily_kline_job(self, task_scheduler):
"""测试配置每日K线数据更新任务"""
with patch.object(task_scheduler, "_add_job") as mock_add_job:
task_scheduler._configure_daily_kline_job()
mock_add_job.assert_called_once()
call_args = mock_add_job.call_args[0]
assert call_args[0] == "daily_kline_update"
assert call_args[1] == "cron"
assert call_args[2] == task_scheduler._update_daily_kline_data
assert call_args[3] == {"hour": 18, "minute": 0} # 下午6点
def test_configure_weekly_financial_job(self, task_scheduler):
"""测试配置每周财务数据更新任务"""
with patch.object(task_scheduler, "_add_job") as mock_add_job:
task_scheduler._configure_weekly_financial_job()
mock_add_job.assert_called_once()
call_args = mock_add_job.call_args[0]
assert call_args[0] == "weekly_financial_update"
assert call_args[1] == "cron"
assert call_args[2] == task_scheduler._update_financial_data
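assert call_args[3] == {"day_of_week": 0, "hour": 20, "minute": 0} # 20:00; the original comment says Sunday, but if passed straight to an APScheduler cron trigger, day_of_week=0 means Monday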
def test_configure_monthly_basic_job(self, task_scheduler):
"""测试配置每月股票基础信息更新任务"""
with patch.object(task_scheduler, "_add_job") as mock_add_job:
task_scheduler._configure_monthly_basic_job()
mock_add_job.assert_called_once()
call_args = mock_add_job.call_args[0]
assert call_args[0] == "monthly_basic_update"
assert call_args[1] == "cron"
assert call_args[2] == task_scheduler._update_stock_basic_info
assert call_args[3] == {"day": 1, "hour": 22, "minute": 0} # 每月1日晚上10点
def test_configure_daily_health_check_job(self, task_scheduler):
"""测试配置每日健康检查任务"""
with patch.object(task_scheduler, "_add_job") as mock_add_job:
task_scheduler._configure_daily_health_check_job()
mock_add_job.assert_called_once()
call_args = mock_add_job.call_args[0]
assert call_args[0] == "daily_health_check"
assert call_args[1] == "cron"
assert call_args[2] == task_scheduler._health_check
assert call_args[3] == {"hour": 9, "minute": 0} # 早上9点
@pytest.mark.asyncio
async def test_update_daily_kline_data_success(self, task_scheduler):
"""测试更新日K线数据成功"""
mock_data = [
{
"code": "000001",
"date": "2024-01-15",
"open": 10.5,
"close": 10.8
}
]
with patch.object(task_scheduler.data_manager, "get_daily_kline_data", return_value=mock_data):
with patch.object(task_scheduler.stock_repo, "save_daily_kline_data", return_value=True):
result = await task_scheduler._update_daily_kline_data()
assert result is True
@pytest.mark.asyncio
async def test_update_daily_kline_data_failure(self, task_scheduler):
"""测试更新日K线数据失败"""
with patch.object(task_scheduler.data_manager, "get_daily_kline_data", side_effect=Exception("API错误")):
result = await task_scheduler._update_daily_kline_data()
assert result is False
@pytest.mark.asyncio
async def test_update_financial_data_success(self, task_scheduler):
"""测试更新财务数据成功"""
mock_data = [
{
"code": "000001",
"report_date": "2023-12-31",
"eps": 1.5
}
]
with patch.object(task_scheduler.data_manager, "get_financial_report", return_value=mock_data):
with patch.object(task_scheduler.stock_repo, "save_financial_report_data", return_value=True):
result = await task_scheduler._update_financial_data()
assert result is True
@pytest.mark.asyncio
async def test_update_stock_basic_info_success(self, task_scheduler):
"""测试更新股票基础信息成功"""
mock_data = [
{
"code": "000001",
"name": "平安银行",
"market": "主板"
}
]
with patch.object(task_scheduler.data_manager, "get_stock_basic_info", return_value=mock_data):
with patch.object(task_scheduler.stock_repo, "save_stock_basic_info", return_value=True):
result = await task_scheduler._update_stock_basic_info()
assert result is True
@pytest.mark.asyncio
async def test_health_check_success(self, task_scheduler):
"""测试健康检查成功"""
with patch.object(task_scheduler.data_manager, "get_stock_basic_info", return_value=[]):
with patch.object(task_scheduler.stock_repo, "get_stock_basic_info", return_value=[]):
result = await task_scheduler._health_check()
assert result is True
    @pytest.mark.asyncio
    async def test_get_job_status(self, task_scheduler):
        """Test retrieving job status."""
        # Start the scheduler; start() is a coroutine, so it has to be awaited.
        with patch.object(task_scheduler, "_configure_jobs"):
            await task_scheduler.start()
        # Add a test job
        def test_job():
            pass
        task_scheduler._add_job("test_job", "date", test_job, {"run_date": datetime.now() + timedelta(hours=1)})
        # Fetch the job status
        status = task_scheduler.get_job_status()
        assert "test_job" in status
        assert status["test_job"]["next_run_time"] is not None

    @pytest.mark.asyncio
    async def test_remove_job_success(self, task_scheduler):
        """Test removing an existing job."""
        # Start the scheduler
        with patch.object(task_scheduler, "_configure_jobs"):
            await task_scheduler.start()
        # Add a test job
        def test_job():
            pass
        task_scheduler._add_job("test_job", "date", test_job, {"run_date": datetime.now() + timedelta(hours=1)})
        # Remove the job
        result = task_scheduler.remove_job("test_job")
        assert result is True
        # Verify the job is gone
        status = task_scheduler.get_job_status()
        assert "test_job" not in status

    @pytest.mark.asyncio
    async def test_remove_job_not_found(self, task_scheduler):
        """Test removing a job that does not exist."""
        # Start the scheduler
        with patch.object(task_scheduler, "_configure_jobs"):
            await task_scheduler.start()
        # Remove a job that was never added
        result = task_scheduler.remove_job("nonexistent_job")
        assert result is False
class TestIntegration:
"""集成测试类"""
@pytest.mark.asyncio
async def test_data_processing_pipeline(self):
"""测试数据处理流水线"""
# 创建数据处理器
processor = DataProcessor()
# 模拟原始数据
raw_stock_data = [
{
"代码": "000001",
"名称": "平安银行",
"市场": "主板",
"行业": "银行",
"地区": "广东",
"上市日期": "1991-04-03"
}
]
# 处理数据
processed_data = processor.process_stock_basic_info(raw_stock_data)
# 验证处理结果
assert len(processed_data) == 1
assert processed_data[0]["code"] == "000001"
assert processed_data[0]["name"] == "平安银行"
assert processed_data[0]["market"] == "主板"
assert processed_data[0]["industry"] == "银行"
assert processed_data[0]["area"] == "广东"
assert processed_data[0]["ipo_date"] == "1991-04-03"
@pytest.mark.asyncio
async def test_scheduler_integration(self):
"""测试调度器集成"""
scheduler = TaskScheduler()
# Mock所有依赖
with patch.object(scheduler, "_configure_jobs"):
# 启动调度器
result = await scheduler.start()
assert result is True
assert scheduler.is_running is True
# 停止调度器
result = await scheduler.stop()
assert result is True
assert scheduler.is_running is False
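The `DataProcessor` tests above expect three things from stock basic info processing: Chinese source keys are renamed to English field names, string values are stripped of surrounding whitespace, and missing optional fields fall back to defaults (`"未知"` or `None`). The sketch below shows one way to get that behaviour. It is an illustration only: `standardize_records`, `FIELD_MAPPING` and `DEFAULTS` are hypothetical names, and the real `DataProcessor` may be structured differently.

```python
# Hypothetical mapping from raw (Chinese) column names to normalized field names.
FIELD_MAPPING = {
    "代码": "code",
    "名称": "name",
    "市场": "market",
    "行业": "industry",
    "地区": "area",
    "上市日期": "ipo_date",
}

# Defaults the tests expect when an optional field is missing.
DEFAULTS = {"market": "未知", "industry": "未知", "area": "未知", "ipo_date": None}


def standardize_records(raw_records, field_mapping=FIELD_MAPPING, defaults=DEFAULTS):
    """Rename mapped keys, strip string values, and fill in default values."""
    result = []
    for record in raw_records:
        row = dict(defaults)  # start from the defaults, then overwrite
        for source_key, target_key in field_mapping.items():
            if source_key in record:
                value = record[source_key]
                row[target_key] = value.strip() if isinstance(value, str) else value
        result.append(row)  # keys outside the mapping are dropped
    return result


if __name__ == "__main__":
    print(standardize_records([{"代码": " 000001 ", "名称": "平安银行"}]))
```

Starting each row from a copy of the defaults keeps the "missing field" and "present field" paths in a single loop.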

578
tests/test_integration.py Normal file

@ -0,0 +1,578 @@
"""
集成测试模块
测试整个股票分析系统的集成功能
"""
import pytest
import asyncio
from unittest.mock import Mock, patch, AsyncMock
import tempfile
import os
from src.main import StockAnalysisSystem
from src.data.data_manager import DataManager
from src.storage.database import DatabaseManager
from src.storage.stock_repository import StockRepository
from src.scheduler.task_scheduler import TaskScheduler
from src.data.data_processor import DataProcessor
from src.utils.logger import LogManager
from src.utils.exceptions import StockSystemError
class TestSystemIntegration:
"""系统集成测试类"""
@pytest.fixture
def temp_database(self):
"""创建临时数据库"""
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
temp_db_path = tmp.name
yield temp_db_path
# 清理临时文件
if os.path.exists(temp_db_path):
os.unlink(temp_db_path)
@pytest.fixture
def stock_system(self, temp_database):
"""股票分析系统实例(使用临时数据库)"""
system = StockAnalysisSystem()
# 配置临时数据库路径
system.db_config = {
"database_url": f"sqlite:///{temp_database}",
"echo": False,
"pool_size": 5,
"max_overflow": 10
}
return system
@pytest.mark.asyncio
async def test_complete_system_initialization(self, stock_system):
"""测试完整系统初始化流程"""
# 执行系统初始化
result = await stock_system.initialize()
assert result is True
assert stock_system.is_initialized is True
# 验证所有组件都已正确初始化
assert stock_system.data_manager is not None
assert isinstance(stock_system.data_manager, DataManager)
assert stock_system.data_processor is not None
assert isinstance(stock_system.data_processor, DataProcessor)
assert stock_system.stock_repo is not None
assert isinstance(stock_system.stock_repo, StockRepository)
assert stock_system.task_scheduler is not None
assert isinstance(stock_system.task_scheduler, TaskScheduler)
# 验证数据库连接正常
assert stock_system.stock_repo.db_manager is not None
assert stock_system.stock_repo.db_manager.engine is not None
@pytest.mark.asyncio
async def test_data_collection_and_storage_integration(self, stock_system):
"""测试数据采集与存储集成"""
# 初始化系统
await stock_system.initialize()
# Mock数据采集结果
mock_stock_data = [
{
"code": "000001",
"name": "平安银行",
"industry": "银行",
"market": "深圳",
"list_date": "1991-04-03"
}
]
mock_kline_data = [
{
"code": "000001",
"date": "2024-01-15",
"open": 10.5,
"high": 11.2,
"low": 10.3,
"close": 10.8,
"volume": 1000000,
"amount": 10800000
}
]
# Mock数据采集器
with patch.object(stock_system.data_manager, "get_stock_basic_info", return_value=mock_stock_data):
with patch.object(stock_system.data_manager, "get_daily_kline_data", return_value=mock_kline_data):
# 测试数据采集
stock_info = await stock_system.data_manager.get_stock_basic_info()
kline_data = await stock_system.data_manager.get_daily_kline_data("000001", "2024-01-15", "2024-01-15")
assert len(stock_info) == 1
assert stock_info[0]["code"] == "000001"
assert len(kline_data) == 1
assert kline_data[0]["code"] == "000001"
# 测试数据存储
result1 = await stock_system.stock_repo.save_stock_basic_info(stock_info)
result2 = await stock_system.stock_repo.save_daily_kline_data(kline_data)
assert result1 is True
assert result2 is True
# 验证数据已存储
stored_stock = await stock_system.stock_repo.get_stock_basic_info("000001")
stored_kline = await stock_system.stock_repo.get_daily_kline_data("000001", "2024-01-15", "2024-01-15")
assert len(stored_stock) == 1
assert stored_stock[0].code == "000001"
assert len(stored_kline) == 1
assert stored_kline[0].code == "000001"
@pytest.mark.asyncio
async def test_data_processing_integration(self, stock_system):
"""测试数据处理集成"""
# 初始化系统
await stock_system.initialize()
# 测试数据
raw_stock_data = [
{
"code": "000001",
"name": "平安银行",
"industry": "银行",
"market": "深圳",
"list_date": "1991-04-03",
"extra_field": "should_be_removed"
}
]
raw_kline_data = [
{
"code": "000001",
"date": "2024-01-15",
"open": "10.5", # 字符串格式
"high": "11.2",
"low": "10.3",
"close": "10.8",
"volume": "1000000",
"amount": "10800000",
"invalid_field": "should_be_removed"
}
]
# 测试数据处理
processed_stock = stock_system.data_processor.process_stock_basic_info(raw_stock_data)
processed_kline = stock_system.data_processor.process_daily_kline_data("000001", raw_kline_data)
assert len(processed_stock) == 1
assert "extra_field" not in processed_stock[0] # 验证字段过滤
assert isinstance(processed_stock[0]["code"], str)
assert len(processed_kline) == 1
assert "invalid_field" not in processed_kline[0] # 验证字段过滤
assert isinstance(processed_kline[0]["open"], float) # 验证类型转换
assert isinstance(processed_kline[0]["volume"], int) # 验证类型转换
@pytest.mark.asyncio
async def test_scheduler_integration(self, stock_system):
"""测试定时任务调度器集成"""
# 初始化系统
await stock_system.initialize()
# Mock数据采集和处理
with patch.object(stock_system.data_manager, "get_stock_basic_info", return_value=[]):
with patch.object(stock_system.data_manager, "get_daily_kline_data", return_value=[]):
with patch.object(stock_system.data_manager, "get_financial_report", return_value=[]):
with patch.object(stock_system.stock_repo, "save_stock_basic_info", return_value=True):
with patch.object(stock_system.stock_repo, "save_daily_kline_data", return_value=True):
with patch.object(stock_system.stock_repo, "save_financial_report_data", return_value=True):
# 测试调度器启动
result = await stock_system.start_scheduler()
assert result is True
assert stock_system.task_scheduler.is_running is True
# 测试调度器停止
result = await stock_system.stop_scheduler()
assert result is True
assert stock_system.task_scheduler.is_running is False
@pytest.mark.asyncio
async def test_error_handling_integration(self, stock_system):
"""测试错误处理集成"""
# 初始化系统
await stock_system.initialize()
# 模拟数据采集失败
with patch.object(stock_system.data_manager, "get_stock_basic_info", side_effect=Exception("网络错误")):
with pytest.raises(Exception):
await stock_system.data_manager.get_stock_basic_info()
# 模拟数据库操作失败
with patch.object(stock_system.stock_repo, "save_stock_basic_info", side_effect=Exception("数据库错误")):
with pytest.raises(Exception):
await stock_system.stock_repo.save_stock_basic_info([])
# 验证错误日志记录
assert stock_system.logger is not None
assert isinstance(stock_system.logger, LogManager)
@pytest.mark.asyncio
async def test_system_status_integration(self, stock_system):
"""测试系统状态检查集成"""
# 初始化系统
await stock_system.initialize()
# 获取系统状态
status = await stock_system.get_system_status()
# 验证状态信息完整性
assert "initialization" in status
assert "data_sources" in status
assert "database" in status
assert "scheduler" in status
assert "data_statistics" in status
# 验证具体状态值
assert status["initialization"]["status"] == "已初始化"
assert status["database"]["status"] == "正常"
# 验证数据源状态
assert "akshare" in status["data_sources"]
assert "baostock" in status["data_sources"]
@pytest.mark.asyncio
async def test_command_line_interface_integration(self, stock_system):
"""测试命令行接口集成"""
# 初始化系统
await stock_system.initialize()
# 测试init命令
with patch.object(stock_system, "initialize_all_data", return_value={"stock_basic": {"count": 1000}}):
result = await stock_system.run_command("init")
assert result is True
# 测试status命令
result = await stock_system.run_command("status")
assert result is True
# 测试scheduler命令
with patch.object(stock_system, "start_scheduler", return_value=True):
result = await stock_system.run_command("scheduler", start=True)
assert result is True
# 测试update命令
with patch.object(stock_system, "update_daily_data", return_value=True):
result = await stock_system.run_command("update", daily=True)
assert result is True
@pytest.mark.asyncio
async def test_performance_integration(self, stock_system):
"""测试性能集成"""
# 初始化系统
await stock_system.initialize()
# 测试批量数据存储性能
batch_stock_data = []
for i in range(100):
batch_stock_data.append({
"code": f"{i:06d}",
"name": f"测试股票{i}",
"industry": "测试行业",
"market": "测试市场",
"list_date": "2020-01-01"
})
# 批量保存数据
start_time = asyncio.get_event_loop().time()
result = await stock_system.stock_repo.save_stock_basic_info(batch_stock_data)
end_time = asyncio.get_event_loop().time()
assert result is True
# 验证性能(应该在合理时间内完成)
execution_time = end_time - start_time
assert execution_time < 10.0 # 100条数据应该在10秒内完成
# 测试批量数据查询性能
start_time = asyncio.get_event_loop().time()
stored_data = await stock_system.stock_repo.get_stock_basic_info("000050")
end_time = asyncio.get_event_loop().time()
query_time = end_time - start_time
assert query_time < 1.0 # 单条查询应该在1秒内完成
@pytest.mark.asyncio
async def test_concurrent_operations_integration(self, stock_system):
"""测试并发操作集成"""
# 初始化系统
await stock_system.initialize()
# 定义并发任务
async def query_stock_info(code):
return await stock_system.stock_repo.get_stock_basic_info(code)
async def save_stock_info(data):
return await stock_system.stock_repo.save_stock_basic_info(data)
# 创建并发任务
tasks = []
for i in range(10):
code = f"{i:06d}"
tasks.append(query_stock_info(code))
# 执行并发查询
results = await asyncio.gather(*tasks, return_exceptions=True)
# 验证所有查询都成功完成(可能返回空结果,但不应该抛出异常)
for result in results:
assert not isinstance(result, Exception)
assert isinstance(result, list)
@pytest.mark.asyncio
async def test_data_consistency_integration(self, stock_system):
"""测试数据一致性集成"""
# 初始化系统
await stock_system.initialize()
# 测试数据
test_data = [
{
"code": "000001",
"name": "平安银行",
"industry": "银行",
"market": "深圳",
"list_date": "1991-04-03"
}
]
# 保存数据
await stock_system.stock_repo.save_stock_basic_info(test_data)
# 查询数据
stored_data = await stock_system.stock_repo.get_stock_basic_info("000001")
# 验证数据一致性
assert len(stored_data) == 1
assert stored_data[0].code == "000001"
assert stored_data[0].name == "平安银行"
assert stored_data[0].industry == "银行"
assert stored_data[0].market == "深圳"
# 再次保存相同数据(应该去重)
result = await stock_system.stock_repo.save_stock_basic_info(test_data)
assert result is True
# 验证数据没有重复
stored_data_again = await stock_system.stock_repo.get_stock_basic_info("000001")
assert len(stored_data_again) == 1 # 应该只有一条记录
@pytest.mark.asyncio
async def test_system_recovery_integration(self, stock_system):
"""测试系统恢复集成"""
# 第一次初始化
result1 = await stock_system.initialize()
assert result1 is True
# 保存一些测试数据
test_data = [{"code": "000001", "name": "测试股票", "industry": "测试", "market": "测试", "list_date": "2020-01-01"}]
await stock_system.stock_repo.save_stock_basic_info(test_data)
# 模拟系统重启(重新创建系统实例)
new_system = StockAnalysisSystem()
new_system.db_config = stock_system.db_config
# 重新初始化
result2 = await new_system.initialize()
assert result2 is True
# 验证数据仍然存在
stored_data = await new_system.stock_repo.get_stock_basic_info("000001")
assert len(stored_data) == 1
assert stored_data[0].code == "000001"
@pytest.mark.asyncio
async def test_logging_integration(self, stock_system):
"""测试日志记录集成"""
# 初始化系统
await stock_system.initialize()
# 验证日志管理器已正确配置
assert stock_system.logger is not None
assert isinstance(stock_system.logger, LogManager)
# 测试不同级别的日志记录
stock_system.logger.info("测试信息日志")
stock_system.logger.warning("测试警告日志")
stock_system.logger.error("测试错误日志")
# 验证系统事件日志
stock_system.logger.log_system_event("测试系统事件")
stock_system.logger.log_data_collection("测试数据采集")
stock_system.logger.log_database_operation("测试数据库操作")
# 验证性能指标日志
stock_system.logger.log_performance_metric("测试性能指标", 100.0)
class TestEndToEndWorkflow:
"""端到端工作流测试类"""
@pytest.fixture
def temp_database(self):
"""创建临时数据库"""
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
temp_db_path = tmp.name
yield temp_db_path
# 清理临时文件
if os.path.exists(temp_db_path):
os.unlink(temp_db_path)
@pytest.fixture
def stock_system(self, temp_database):
"""股票分析系统实例"""
system = StockAnalysisSystem()
system.db_config = {
"database_url": f"sqlite:///{temp_database}",
"echo": False,
"pool_size": 5,
"max_overflow": 10
}
return system
@pytest.mark.asyncio
async def test_complete_data_workflow(self, stock_system):
"""测试完整数据工作流:采集→处理→存储→查询"""
# 初始化系统
await stock_system.initialize()
# Mock真实数据采集结果
mock_stock_data = [
{
"code": "000001",
"name": "平安银行",
"industry": "银行",
"market": "深圳",
"list_date": "1991-04-03",
"extra_field": "should_be_removed"
},
{
"code": "000002",
"name": "万科A",
"industry": "房地产",
"market": "深圳",
"list_date": "1991-01-29",
"extra_field": "should_be_removed"
}
]
mock_kline_data = [
{
"code": "000001",
"date": "2024-01-15",
"open": "10.5",
"high": "11.2",
"low": "10.3",
"close": "10.8",
"volume": "1000000",
"amount": "10800000"
},
{
"code": "000001",
"date": "2024-01-16",
"open": "10.8",
"high": "11.5",
"low": "10.7",
"close": "11.2",
"volume": "1200000",
"amount": "13440000"
}
]
# Mock数据采集器
with patch.object(stock_system.data_manager, "get_stock_basic_info", return_value=mock_stock_data):
with patch.object(stock_system.data_manager, "get_daily_kline_data", return_value=mock_kline_data):
# 1. 数据采集
raw_stock_data = await stock_system.data_manager.get_stock_basic_info()
raw_kline_data = await stock_system.data_manager.get_daily_kline_data("000001", "2024-01-15", "2024-01-16")
assert len(raw_stock_data) == 2
assert len(raw_kline_data) == 2
# 2. 数据处理
processed_stock_data = stock_system.data_processor.process_stock_basic_info(raw_stock_data)
processed_kline_data = stock_system.data_processor.process_daily_kline_data("000001", raw_kline_data)
assert len(processed_stock_data) == 2
assert len(processed_kline_data) == 2
# 验证数据处理结果
for stock in processed_stock_data:
assert "extra_field" not in stock # 验证字段过滤
assert isinstance(stock["code"], str)
assert isinstance(stock["name"], str)
for kline in processed_kline_data:
assert isinstance(kline["open"], float) # 验证类型转换
assert isinstance(kline["volume"], int) # 验证类型转换
# 3. 数据存储
storage_result1 = await stock_system.stock_repo.save_stock_basic_info(processed_stock_data)
storage_result2 = await stock_system.stock_repo.save_daily_kline_data(processed_kline_data)
assert storage_result1 is True
assert storage_result2 is True
# 4. 数据查询
stored_stock = await stock_system.stock_repo.get_stock_basic_info("000001")
stored_kline = await stock_system.stock_repo.get_daily_kline_data("000001", "2024-01-15", "2024-01-16")
assert len(stored_stock) == 1
assert stored_stock[0].code == "000001"
assert stored_stock[0].name == "平安银行"
assert len(stored_kline) == 2
assert all(kline.code == "000001" for kline in stored_kline)
# 验证数据排序(按日期升序)
dates = [kline.date for kline in stored_kline]
assert dates == sorted(dates)
@pytest.mark.asyncio
async def test_scheduler_workflow(self, stock_system):
"""测试定时任务工作流"""
# 初始化系统
await stock_system.initialize()
# Mock数据采集和处理
mock_data = [{"code": "000001", "name": "测试股票", "industry": "测试", "market": "测试", "list_date": "2020-01-01"}]
with patch.object(stock_system.data_manager, "get_stock_basic_info", return_value=mock_data):
with patch.object(stock_system.data_manager, "get_daily_kline_data", return_value=[]):
with patch.object(stock_system.data_manager, "get_financial_report", return_value=[]):
with patch.object(stock_system.stock_repo, "save_stock_basic_info", return_value=True):
with patch.object(stock_system.stock_repo, "save_daily_kline_data", return_value=True):
with patch.object(stock_system.stock_repo, "save_financial_report_data", return_value=True):
# 启动调度器
await stock_system.start_scheduler()
# 验证调度器状态
assert stock_system.task_scheduler.is_running is True
# 获取任务状态
job_status = stock_system.task_scheduler.get_job_status()
assert len(job_status) > 0
# 停止调度器
await stock_system.stop_scheduler()
# 验证调度器已停止
assert stock_system.task_scheduler.is_running is False
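The data consistency test above expects that saving the same stock record twice leaves a single row. One common way to get that is to key the table on the stock code and write with an upsert. The sketch below uses the standard-library `sqlite3` module purely for illustration; the project's `StockRepository` is built on SQLAlchemy, and the table name `stock_basic`, its columns, and both function names here are assumptions.

```python
import sqlite3


def create_schema(conn):
    """Create a hypothetical stock_basic table keyed on the stock code."""
    conn.execute(
        "CREATE TABLE IF NOT EXISTS stock_basic ("
        "code TEXT PRIMARY KEY, name TEXT, industry TEXT, market TEXT)"
    )


def save_stock_basic_info(conn, rows):
    """Insert rows, replacing any existing row that has the same code."""
    conn.executemany(
        "INSERT OR REPLACE INTO stock_basic (code, name, industry, market) "
        "VALUES (:code, :name, :industry, :market)",
        rows,
    )
    conn.commit()
    return True


def count_rows(conn, code):
    """Return how many rows exist for the given stock code."""
    cursor = conn.execute("SELECT COUNT(*) FROM stock_basic WHERE code = ?", (code,))
    return cursor.fetchone()[0]


if __name__ == "__main__":
    connection = sqlite3.connect(":memory:")
    create_schema(connection)
    record = {"code": "000001", "name": "平安银行", "industry": "银行", "market": "深圳"}
    save_stock_basic_info(connection, [record])
    save_stock_basic_info(connection, [record])  # second save must not duplicate
    print(count_rows(connection, "000001"))
```

`INSERT OR REPLACE` deletes and re-inserts the conflicting row, which is enough for this check; a production repository may prefer an update-in-place upsert to preserve other columns.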

498
tests/test_main.py Normal file

@ -0,0 +1,498 @@
"""
主程序模块单元测试
测试StockAnalysisSystem主类的功能
"""
import pytest
import asyncio
from unittest.mock import Mock, patch, AsyncMock
import sys
from io import StringIO
from src.main import StockAnalysisSystem
from src.utils.exceptions import StockSystemError
class TestStockAnalysisSystem:
"""股票分析系统测试类"""
@pytest.fixture
def stock_system(self):
"""股票分析系统实例"""
return StockAnalysisSystem()
def test_initialization(self, stock_system):
"""测试系统初始化"""
assert stock_system.data_manager is None
assert stock_system.data_processor is None
assert stock_system.stock_repo is None
assert stock_system.task_scheduler is None
assert stock_system.is_initialized is False
@pytest.mark.asyncio
async def test_initialize_success(self, stock_system):
"""测试系统初始化成功"""
with patch.object(stock_system, "_setup_database") as mock_setup_db:
with patch.object(stock_system, "_setup_components") as mock_setup_comp:
result = await stock_system.initialize()
assert result is True
assert stock_system.is_initialized is True
mock_setup_db.assert_called_once()
mock_setup_comp.assert_called_once()
@pytest.mark.asyncio
async def test_initialize_already_initialized(self, stock_system):
"""测试重复初始化"""
# 第一次初始化
with patch.object(stock_system, "_setup_database"):
with patch.object(stock_system, "_setup_components"):
await stock_system.initialize()
# 第二次初始化
result = await stock_system.initialize()
assert result is False # 应该返回False因为已经初始化过
@pytest.mark.asyncio
async def test_initialize_database_failure(self, stock_system):
"""测试数据库初始化失败"""
with patch.object(stock_system, "_setup_database", side_effect=Exception("数据库错误")):
with pytest.raises(StockSystemError):
await stock_system.initialize()
assert stock_system.is_initialized is False
@pytest.mark.asyncio
async def test_initialize_components_failure(self, stock_system):
"""测试组件初始化失败"""
with patch.object(stock_system, "_setup_database"):
with patch.object(stock_system, "_setup_components", side_effect=Exception("组件错误")):
with pytest.raises(StockSystemError):
await stock_system.initialize()
assert stock_system.is_initialized is False
    @pytest.mark.asyncio
    async def test_start_scheduler_success(self, stock_system):
        """Test starting the scheduler successfully."""
        # Initialize the system first
        with patch.object(stock_system, "_setup_database"):
            with patch.object(stock_system, "_setup_components"):
                await stock_system.initialize()
        # _setup_components is patched out, so task_scheduler may still be None;
        # install a stand-in scheduler whose start() coroutine succeeds.
        stock_system.task_scheduler = Mock(start=AsyncMock(return_value=True))
        result = await stock_system.start_scheduler()
        assert result is True

    @pytest.mark.asyncio
    async def test_start_scheduler_not_initialized(self, stock_system):
        """Test starting the scheduler before the system is initialized."""
        result = await stock_system.start_scheduler()
        assert result is False

    @pytest.mark.asyncio
    async def test_start_scheduler_failure(self, stock_system):
        """Test a failed scheduler start."""
        # Initialize the system first
        with patch.object(stock_system, "_setup_database"):
            with patch.object(stock_system, "_setup_components"):
                await stock_system.initialize()
        # Stand-in scheduler whose start() coroutine reports failure.
        stock_system.task_scheduler = Mock(start=AsyncMock(return_value=False))
        result = await stock_system.start_scheduler()
        assert result is False

    @pytest.mark.asyncio
    async def test_stop_scheduler_success(self, stock_system):
        """Test stopping the scheduler successfully."""
        # Initialize the system first
        with patch.object(stock_system, "_setup_database"):
            with patch.object(stock_system, "_setup_components"):
                await stock_system.initialize()
        # Stand-in scheduler whose stop() coroutine succeeds.
        stock_system.task_scheduler = Mock(stop=AsyncMock(return_value=True))
        result = await stock_system.stop_scheduler()
        assert result is True

    @pytest.mark.asyncio
    async def test_stop_scheduler_not_initialized(self, stock_system):
        """Test stopping the scheduler before the system is initialized."""
        result = await stock_system.stop_scheduler()
        assert result is False
@pytest.mark.asyncio
async def test_get_system_status_success(self, stock_system):
"""测试获取系统状态成功"""
# 先初始化系统
with patch.object(stock_system, "_setup_database"):
with patch.object(stock_system, "_setup_components"):
await stock_system.initialize()
# Mock组件状态
with patch.object(stock_system.data_manager, "get_stock_basic_info", return_value=[{"code": "000001"}]):
with patch.object(stock_system.stock_repo, "get_stock_basic_info", return_value=[Mock()]):
with patch.object(stock_system.task_scheduler, "get_job_status", return_value={"test_job": {}}):
with patch.object(stock_system.task_scheduler, "is_running", True):
status = await stock_system.get_system_status()
assert "initialization" in status
assert "data_sources" in status
assert "database" in status
assert "scheduler" in status
assert "data_statistics" in status
assert status["initialization"]["status"] == "已初始化"
assert status["scheduler"]["status"] == "运行中"
@pytest.mark.asyncio
async def test_get_system_status_not_initialized(self, stock_system):
"""测试未初始化时获取系统状态"""
status = await stock_system.get_system_status()
assert "initialization" in status
assert status["initialization"]["status"] == "未初始化"
@pytest.mark.asyncio
async def test_initialize_all_data_success(self, stock_system):
"""测试初始化所有数据成功"""
# 先初始化系统
with patch.object(stock_system, "_setup_database"):
with patch.object(stock_system, "_setup_components"):
await stock_system.initialize()
# Mock数据初始化器
mock_initializer = Mock()
mock_initializer.initialize_all_data.return_value = {
"stock_basic": {"count": 1000, "status": "成功"},
"daily_kline": {"count": 50000, "status": "成功"},
"financial_report": {"count": 2000, "status": "成功"}
}
with patch("src.data.data_initializer.DataInitializer", return_value=mock_initializer):
result = await stock_system.initialize_all_data()
assert "stock_basic" in result
assert "daily_kline" in result
assert "financial_report" in result
assert result["stock_basic"]["count"] == 1000
assert result["daily_kline"]["count"] == 50000
assert result["financial_report"]["count"] == 2000
@pytest.mark.asyncio
async def test_initialize_all_data_not_initialized(self, stock_system):
"""测试未初始化时初始化数据"""
result = await stock_system.initialize_all_data()
assert result == {}
    @pytest.mark.asyncio
    async def test_update_daily_data_success(self, stock_system):
        """Test updating daily data successfully."""
        # Initialize the system first
        with patch.object(stock_system, "_setup_database"):
            with patch.object(stock_system, "_setup_components"):
                await stock_system.initialize()
        # _setup_components is patched out, so install a stand-in scheduler
        # instead of patching an attribute on a component that may be None.
        stock_system.task_scheduler = Mock(_update_daily_kline_data=AsyncMock(return_value=True))
        result = await stock_system.update_daily_data()
        assert result is True

    @pytest.mark.asyncio
    async def test_update_daily_data_not_initialized(self, stock_system):
        """Test updating data before the system is initialized."""
        result = await stock_system.update_daily_data()
        assert result is False

    @pytest.mark.asyncio
    async def test_update_financial_data_success(self, stock_system):
        """Test updating financial data successfully."""
        # Initialize the system first
        with patch.object(stock_system, "_setup_database"):
            with patch.object(stock_system, "_setup_components"):
                await stock_system.initialize()
        # Stand-in scheduler whose financial update coroutine succeeds.
        stock_system.task_scheduler = Mock(_update_financial_data=AsyncMock(return_value=True))
        result = await stock_system.update_financial_data()
        assert result is True

    @pytest.mark.asyncio
    async def test_update_stock_basic_info_success(self, stock_system):
        """Test updating stock basic info successfully."""
        # Initialize the system first
        with patch.object(stock_system, "_setup_database"):
            with patch.object(stock_system, "_setup_components"):
                await stock_system.initialize()
        # Stand-in scheduler whose basic info update coroutine succeeds.
        stock_system.task_scheduler = Mock(_update_stock_basic_info=AsyncMock(return_value=True))
        result = await stock_system.update_stock_basic_info()
        assert result is True
class TestCommandLineInterface:
"""命令行接口测试类"""
@pytest.fixture
def capture_output(self):
"""捕获标准输出"""
old_stdout = sys.stdout
sys.stdout = StringIO()
yield sys.stdout
sys.stdout = old_stdout
def test_parse_arguments_init(self):
"""测试解析init命令参数"""
test_args = ["main.py", "init"]
with patch("sys.argv", test_args):
args = StockAnalysisSystem.parse_arguments()
assert args.command == "init"
assert args.force is False
def test_parse_arguments_scheduler(self):
"""测试解析scheduler命令参数"""
test_args = ["main.py", "scheduler", "--start"]
with patch("sys.argv", test_args):
args = StockAnalysisSystem.parse_arguments()
assert args.command == "scheduler"
assert args.start is True
assert args.stop is False
def test_parse_arguments_status(self):
"""测试解析status命令参数"""
test_args = ["main.py", "status"]
with patch("sys.argv", test_args):
args = StockAnalysisSystem.parse_arguments()
assert args.command == "status"
def test_parse_arguments_update(self):
"""测试解析update命令参数"""
test_args = ["main.py", "update", "--daily"]
with patch("sys.argv", test_args):
args = StockAnalysisSystem.parse_arguments()
assert args.command == "update"
assert args.daily is True
assert args.financial is False
assert args.basic is False
@pytest.mark.asyncio
async def test_run_init_command_success(self, stock_system, capture_output):
"""测试运行init命令成功"""
with patch.object(stock_system, "initialize", return_value=True):
with patch.object(stock_system, "initialize_all_data", return_value={"stock_basic": {"count": 1000}}):
result = await stock_system.run_command("init")
assert result is True
# 检查输出
output = capture_output.getvalue()
assert "系统初始化成功" in output
assert "数据初始化完成" in output
@pytest.mark.asyncio
async def test_run_init_command_failure(self, stock_system, capture_output):
"""测试运行init命令失败"""
with patch.object(stock_system, "initialize", side_effect=StockSystemError("初始化失败")):
result = await stock_system.run_command("init")
assert result is False
# 检查输出
output = capture_output.getvalue()
assert "系统初始化失败" in output
@pytest.mark.asyncio
async def test_run_scheduler_start_command_success(self, stock_system, capture_output):
"""测试运行scheduler start命令成功"""
# 先初始化系统
with patch.object(stock_system, "_setup_database"):
with patch.object(stock_system, "_setup_components"):
await stock_system.initialize()
with patch.object(stock_system, "start_scheduler", return_value=True):
result = await stock_system.run_command("scheduler", start=True)
assert result is True
# 检查输出
output = capture_output.getvalue()
assert "定时任务调度器启动成功" in output
@pytest.mark.asyncio
async def test_run_scheduler_stop_command_success(self, stock_system, capture_output):
"""测试运行scheduler stop命令成功"""
# 先初始化系统
with patch.object(stock_system, "_setup_database"):
with patch.object(stock_system, "_setup_components"):
await stock_system.initialize()
with patch.object(stock_system, "stop_scheduler", return_value=True):
result = await stock_system.run_command("scheduler", stop=True)
assert result is True
# 检查输出
output = capture_output.getvalue()
assert "定时任务调度器停止成功" in output
@pytest.mark.asyncio
async def test_run_status_command_success(self, stock_system, capture_output):
"""测试运行status命令成功"""
# 先初始化系统
with patch.object(stock_system, "_setup_database"):
with patch.object(stock_system, "_setup_components"):
await stock_system.initialize()
mock_status = {
"initialization": {"status": "已初始化"},
"data_sources": {"akshare": "正常"},
"database": {"status": "正常"},
"scheduler": {"status": "运行中"},
"data_statistics": {"stock_count": 1000}
}
with patch.object(stock_system, "get_system_status", return_value=mock_status):
result = await stock_system.run_command("status")
assert result is True
# 检查输出
output = capture_output.getvalue()
assert "系统状态" in output
assert "已初始化" in output
assert "运行中" in output
@pytest.mark.asyncio
async def test_run_update_daily_command_success(self, stock_system, capture_output):
"""测试运行update daily命令成功"""
# 先初始化系统
with patch.object(stock_system, "_setup_database"):
with patch.object(stock_system, "_setup_components"):
await stock_system.initialize()
with patch.object(stock_system, "update_daily_data", return_value=True):
result = await stock_system.run_command("update", daily=True)
assert result is True
# 检查输出
output = capture_output.getvalue()
assert "每日数据更新完成" in output
@pytest.mark.asyncio
async def test_run_update_financial_command_success(self, stock_system, capture_output):
"""测试运行update financial命令成功"""
# 先初始化系统
with patch.object(stock_system, "_setup_database"):
with patch.object(stock_system, "_setup_components"):
await stock_system.initialize()
with patch.object(stock_system, "update_financial_data", return_value=True):
result = await stock_system.run_command("update", financial=True)
assert result is True
# 检查输出
output = capture_output.getvalue()
assert "财务数据更新完成" in output
@pytest.mark.asyncio
async def test_run_update_basic_command_success(self, stock_system, capture_output):
"""测试运行update basic命令成功"""
# 先初始化系统
with patch.object(stock_system, "_setup_database"):
with patch.object(stock_system, "_setup_components"):
await stock_system.initialize()
with patch.object(stock_system, "update_stock_basic_info", return_value=True):
result = await stock_system.run_command("update", basic=True)
assert result is True
# 检查输出
output = capture_output.getvalue()
assert "股票基础信息更新完成" in output
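class TestErrorHandling:
    """Error-handling tests."""

    @pytest.fixture
    def capture_output(self):
        """Capture standard output.

        Repeated here because the fixture of the same name above is scoped to
        TestCommandLineInterface; remove it if conftest.py already provides one.
        """
        old_stdout = sys.stdout
        sys.stdout = StringIO()
        yield sys.stdout
        sys.stdout = old_stdout

    @pytest.mark.asyncio
    async def test_handle_exception(self, stock_system, capture_output):
        """A regular exception is reported and handled as a failure."""
        # Simulate an exception
        try:
            raise StockSystemError("测试异常")
        except Exception as e:
            result = stock_system._handle_exception(e)
        assert result is False
        # Check the output
        output = capture_output.getvalue()
        assert "测试异常" in output

    @pytest.mark.asyncio
    async def test_handle_keyboard_interrupt(self, stock_system, capture_output):
        """A keyboard interrupt is handled as a clean exit."""
        # Simulate a keyboard interrupt. KeyboardInterrupt derives from
        # BaseException, so `except Exception` would not catch it.
        try:
            raise KeyboardInterrupt()
        except KeyboardInterrupt as e:
            result = stock_system._handle_exception(e)
        assert result is True
        # Check the output
        output = capture_output.getvalue()
        assert "程序被用户中断" in output

    @pytest.mark.asyncio
    async def test_main_function_success(self, capture_output):
        """main() returns 0 when the command succeeds."""
        # Mock the command-line arguments and the command run
        test_args = ["main.py", "status"]
        with patch("sys.argv", test_args):
            with patch("src.main.StockAnalysisSystem.run_command", return_value=True):
                from src.main import main
                result = await main()
        assert result == 0

    @pytest.mark.asyncio
    async def test_main_function_failure(self, capture_output):
        """main() returns 1 when the command fails."""
        # Mock the command-line arguments and a failing command run
        test_args = ["main.py", "status"]
        with patch("sys.argv", test_args):
            with patch("src.main.StockAnalysisSystem.run_command", return_value=False):
                from src.main import main
                result = await main()
        assert result == 1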
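
A note on the keyboard-interrupt case in `TestErrorHandling`: in Python, `KeyboardInterrupt` derives from `BaseException` rather than `Exception`, so the `except` clause a test uses decides whether the interrupt ever reaches the handler. The sketch below illustrates this with a hypothetical helper, `caught_by_exception_clause`, which is not part of the project.

```python
def caught_by_exception_clause(exc: BaseException) -> bool:
    """Return True if `except Exception` catches exc (hypothetical helper)."""
    try:
        raise exc
    except Exception:
        return True
    except BaseException:
        # Reached for KeyboardInterrupt, SystemExit and GeneratorExit
        return False


print(caught_by_exception_clause(ValueError("boom")))   # True
print(caught_by_exception_clause(KeyboardInterrupt()))  # False
```

A test that wants to hand a `KeyboardInterrupt` to a handler therefore has to catch it by name (or catch `BaseException`).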

599
tests/test_performance.py Normal file

@ -0,0 +1,599 @@
"""
Performance test module.
Measures the performance of the stock analysis system.
"""
import pytest
import asyncio
import time
from unittest.mock import Mock, patch
import tempfile
import os
from src.main import StockAnalysisSystem
from src.storage.database import DatabaseManager
from src.storage.stock_repository import StockRepository
from src.data.data_manager import DataManager
class TestPerformance:
"""性能测试类"""
@pytest.fixture
def temp_database(self):
"""创建临时数据库"""
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
temp_db_path = tmp.name
yield temp_db_path
# 清理临时文件
if os.path.exists(temp_db_path):
os.unlink(temp_db_path)
@pytest.fixture
def stock_system(self, temp_database):
"""股票分析系统实例"""
system = StockAnalysisSystem()
system.db_config = {
"database_url": f"sqlite:///{temp_database}",
"echo": False,
"pool_size": 10,
"max_overflow": 20
}
return system
@pytest.mark.asyncio
async def test_database_initialization_performance(self, stock_system):
"""测试数据库初始化性能"""
# 测量数据库初始化时间
start_time = time.time()
await stock_system.initialize()
end_time = time.time()
execution_time = end_time - start_time
# 数据库初始化应该在合理时间内完成
assert execution_time < 5.0 # 5秒内完成
print(f"数据库初始化时间: {execution_time:.3f}")
@pytest.mark.asyncio
async def test_batch_data_storage_performance(self, stock_system):
"""测试批量数据存储性能"""
# 初始化系统
await stock_system.initialize()
# 生成批量测试数据
batch_sizes = [100, 500, 1000]
for batch_size in batch_sizes:
batch_data = []
for i in range(batch_size):
batch_data.append({
"code": f"{i:06d}",
"name": f"测试股票{i}",
"industry": "测试行业",
"market": "测试市场",
"list_date": "2020-01-01"
})
# 测量批量存储时间
start_time = time.time()
result = await stock_system.stock_repo.save_stock_basic_info(batch_data)
end_time = time.time()
execution_time = end_time - start_time
assert result is True
# 计算每秒处理记录数
records_per_second = batch_size / execution_time
print(f"批量存储 {batch_size} 条记录 - 时间: {execution_time:.3f}秒, 速度: {records_per_second:.1f} 条/秒")
# 验证性能要求
assert execution_time < 30.0 # 1000条记录应该在30秒内完成
assert records_per_second > 30.0 # 最低性能要求30条/秒
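    @pytest.mark.asyncio
    async def test_batch_kline_data_storage_performance(self, stock_system):
        """Measure batch storage performance for daily K-line data."""
        from datetime import date, timedelta

        # Initialize the system
        await stock_system.initialize()
        # Generate batches of K-line test data
        batch_sizes = [100, 500, 1000]
        for batch_size in batch_sizes:
            batch_data = []
            for i in range(batch_size):
                batch_data.append({
                    "code": "000001",
                    # Consecutive calendar days; f"2024-01-{i+1:02d}" yields
                    # invalid dates such as "2024-01-32" once i exceeds 30.
                    "date": (date(2024, 1, 1) + timedelta(days=i)).isoformat(),
                    "open": 10.0 + i * 0.1,
                    "high": 11.0 + i * 0.1,
                    "low": 9.0 + i * 0.1,
                    "close": 10.5 + i * 0.1,
                    "volume": 1000000 + i * 1000,
                    "amount": 10500000 + i * 10500
                })
            # Time the batch save; perf_counter is monotonic and fine-grained,
            # which keeps the division below away from a zero interval.
            start_time = time.perf_counter()
            result = await stock_system.stock_repo.save_daily_kline_data(batch_data)
            end_time = time.perf_counter()
            execution_time = end_time - start_time
            assert result is True
            # Records processed per second
            records_per_second = batch_size / execution_time
            print(f"批量存储 {batch_size} 条K线记录 - 时间: {execution_time:.3f}秒, 速度: {records_per_second:.1f} 条/秒")
            # Performance requirements
            assert execution_time < 30.0  # even 1000 records should finish within 30 seconds
            assert records_per_second > 30.0  # minimum throughput: 30 records/second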
@pytest.mark.asyncio
async def test_data_query_performance(self, stock_system):
"""测试数据查询性能"""
# 初始化系统
await stock_system.initialize()
# 先存储一些测试数据
test_data = []
for i in range(1000):
test_data.append({
"code": f"{i:06d}",
"name": f"测试股票{i}",
"industry": "测试行业",
"market": "测试市场",
"list_date": "2020-01-01"
})
await stock_system.stock_repo.save_stock_basic_info(test_data)
# 测试单条查询性能
query_times = []
for i in range(10):
code = f"{i:06d}"
start_time = time.time()
result = await stock_system.stock_repo.get_stock_basic_info(code)
end_time = time.time()
execution_time = end_time - start_time
query_times.append(execution_time)
assert len(result) <= 1 # 应该返回0或1条记录
# 计算平均查询时间
avg_query_time = sum(query_times) / len(query_times)
max_query_time = max(query_times)
print(f"单条查询性能 - 平均时间: {avg_query_time:.3f}秒, 最大时间: {max_query_time:.3f}")
# 验证性能要求
assert avg_query_time < 0.1 # 平均查询时间应该小于100毫秒
assert max_query_time < 0.5 # 最大查询时间应该小于500毫秒
    @pytest.mark.asyncio
    async def test_batch_query_performance(self, stock_system):
        """Measure range-query performance for daily K-line data."""
        from datetime import date, timedelta

        # Initialize the system
        await stock_system.initialize()
        # Store 100 consecutive days of K-line test data first
        test_data = []
        for i in range(100):
            test_data.append({
                "code": "000001",
                # 2024-01-01 through 2024-04-09, inside the range queried below;
                # f"2024-01-{i+1:02d}" would produce invalid dates past day 31.
                "date": (date(2024, 1, 1) + timedelta(days=i)).isoformat(),
                "open": 10.0 + i * 0.1,
                "high": 11.0 + i * 0.1,
                "low": 9.0 + i * 0.1,
                "close": 10.5 + i * 0.1,
                "volume": 1000000 + i * 1000,
                "amount": 10500000 + i * 10500
            })
        await stock_system.stock_repo.save_daily_kline_data(test_data)
        # Time the range query
        start_time = time.perf_counter()
        result = await stock_system.stock_repo.get_daily_kline_data("000001", "2024-01-01", "2024-04-10")
        end_time = time.perf_counter()
        execution_time = end_time - start_time
        assert len(result) == 100  # all 100 records should come back
        print(f"批量查询100条K线记录 - 时间: {execution_time:.3f}秒")
        # Performance requirement
        assert execution_time < 1.0  # 100 records should be returned within 1 second
@pytest.mark.asyncio
async def test_concurrent_operations_performance(self, stock_system):
"""测试并发操作性能"""
# 初始化系统
await stock_system.initialize()
# 定义并发查询任务
async def query_task(code):
return await stock_system.stock_repo.get_stock_basic_info(code)
# 测试不同并发级别
concurrency_levels = [5, 10, 20]
for concurrency in concurrency_levels:
tasks = []
for i in range(concurrency):
code = f"{i:06d}"
tasks.append(query_task(code))
# 测量并发执行时间
start_time = time.time()
results = await asyncio.gather(*tasks)
end_time = time.time()
execution_time = end_time - start_time
# 验证所有查询都成功完成
for result in results:
assert isinstance(result, list)
print(f"并发查询 {concurrency} 个任务 - 时间: {execution_time:.3f}")
# 验证性能要求
assert execution_time < 2.0 # 20个并发查询应该在2秒内完成
@pytest.mark.asyncio
async def test_memory_usage_performance(self, stock_system):
"""测试内存使用性能"""
import psutil
import os
# 获取当前进程内存使用
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss / 1024 / 1024 # MB
# 初始化系统
await stock_system.initialize()
# 获取初始化后内存使用
memory_after_init = process.memory_info().rss / 1024 / 1024
memory_increase = memory_after_init - initial_memory
print(f"系统初始化内存增加: {memory_increase:.2f} MB")
# 验证内存使用在合理范围内
assert memory_increase < 100.0 # 系统初始化内存增加应该小于100MB
# 测试批量操作内存使用
batch_data = []
for i in range(1000):
batch_data.append({
"code": f"{i:06d}",
"name": f"测试股票{i}",
"industry": "测试行业",
"market": "测试市场",
"list_date": "2020-01-01"
})
memory_before_batch = process.memory_info().rss / 1024 / 1024
# 执行批量存储
await stock_system.stock_repo.save_stock_basic_info(batch_data)
memory_after_batch = process.memory_info().rss / 1024 / 1024
batch_memory_increase = memory_after_batch - memory_before_batch
print(f"批量存储1000条记录内存增加: {batch_memory_increase:.2f} MB")
# 验证批量操作内存使用在合理范围内
assert batch_memory_increase < 50.0 # 批量存储内存增加应该小于50MB
@pytest.mark.asyncio
async def test_database_connection_pool_performance(self, stock_system):
"""测试数据库连接池性能"""
# 初始化系统
await stock_system.initialize()
# 测试连接池的并发处理能力
async def database_operation(code):
# 执行数据库操作
result = await stock_system.stock_repo.get_stock_basic_info(code)
return result
# 模拟高并发场景
concurrent_operations = 50
tasks = []
for i in range(concurrent_operations):
code = f"{i % 100:06d}" # 循环使用100个不同的股票代码
tasks.append(database_operation(code))
# 测量并发执行时间
start_time = time.time()
results = await asyncio.gather(*tasks)
end_time = time.time()
execution_time = end_time - start_time
print(f"数据库连接池处理 {concurrent_operations} 个并发操作 - 时间: {execution_time:.3f}")
# 验证连接池性能
assert execution_time < 5.0 # 50个并发操作应该在5秒内完成
@pytest.mark.asyncio
async def test_data_processing_performance(self, stock_system):
"""测试数据处理性能"""
# 初始化系统
await stock_system.initialize()
# 生成测试数据
batch_sizes = [100, 500, 1000]
for batch_size in batch_sizes:
raw_data = []
for i in range(batch_size):
raw_data.append({
"code": f"{i:06d}",
"name": f"测试股票{i}",
"industry": "测试行业",
"market": "测试市场",
"list_date": "2020-01-01",
"extra_field1": "should_be_removed",
"extra_field2": 123.45
})
# 测量数据处理时间
start_time = time.time()
processed_data = stock_system.data_processor.process_stock_basic_info(raw_data)
end_time = time.time()
execution_time = end_time - start_time
assert len(processed_data) == batch_size
# 验证字段过滤
for item in processed_data:
assert "extra_field1" not in item
assert "extra_field2" not in item
print(f"处理 {batch_size} 条记录 - 时间: {execution_time:.3f}")
# 验证性能要求
assert execution_time < 1.0 # 1000条记录应该在1秒内完成
@pytest.mark.asyncio
async def test_system_startup_performance(self, stock_system):
"""测试系统启动性能"""
# 测量完整系统启动时间
start_time = time.time()
await stock_system.initialize()
end_time = time.time()
execution_time = end_time - start_time
print(f"系统完整启动时间: {execution_time:.3f}")
# 验证启动性能
assert execution_time < 10.0 # 系统启动应该在10秒内完成
@pytest.mark.asyncio
async def test_long_running_performance(self, stock_system):
"""测试长时间运行性能"""
# 初始化系统
await stock_system.initialize()
# 模拟长时间运行(执行多次操作)
operation_count = 100
operation_times = []
for i in range(operation_count):
code = f"{i % 100:06d}"
start_time = time.time()
result = await stock_system.stock_repo.get_stock_basic_info(code)
end_time = time.time()
execution_time = end_time - start_time
operation_times.append(execution_time)
assert isinstance(result, list)
# 分析性能稳定性
avg_time = sum(operation_times) / len(operation_times)
max_time = max(operation_times)
min_time = min(operation_times)
print(f"长时间运行性能 - 平均时间: {avg_time:.3f}秒, 最小时间: {min_time:.3f}秒, 最大时间: {max_time:.3f}")
# 验证性能稳定性
assert max_time / avg_time < 5.0 # 最大时间不应超过平均时间的5倍
assert (max_time - min_time) < 0.5 # 时间差异应该小于0.5秒
class TestScalability:
"""可扩展性测试类"""
@pytest.fixture
def temp_database(self):
"""创建临时数据库"""
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
temp_db_path = tmp.name
yield temp_db_path
# 清理临时文件
if os.path.exists(temp_db_path):
os.unlink(temp_db_path)
@pytest.mark.asyncio
async def test_large_scale_data_storage(self, temp_database):
"""测试大规模数据存储可扩展性"""
system = StockAnalysisSystem()
system.db_config = {
"database_url": f"sqlite:///{temp_database}",
"echo": False,
"pool_size": 20,
"max_overflow": 50
}
await system.initialize()
# 大规模数据测试
large_batch_sizes = [5000, 10000]
for batch_size in large_batch_sizes:
batch_data = []
for i in range(batch_size):
batch_data.append({
"code": f"{i:06d}",
"name": f"测试股票{i}",
"industry": "测试行业",
"market": "测试市场",
"list_date": "2020-01-01"
})
# 测量大规模存储时间
start_time = time.time()
result = await system.stock_repo.save_stock_basic_info(batch_data)
end_time = time.time()
execution_time = end_time - start_time
assert result is True
# 计算性能指标
records_per_second = batch_size / execution_time
print(f"大规模存储 {batch_size} 条记录 - 时间: {execution_time:.3f}秒, 速度: {records_per_second:.1f} 条/秒")
# 验证可扩展性
assert execution_time < 300.0 # 10000条记录应该在300秒内完成
assert records_per_second > 30.0 # 最低性能要求
@pytest.mark.asyncio
async def test_high_concurrency_scalability(self, temp_database):
"""测试高并发可扩展性"""
system = StockAnalysisSystem()
system.db_config = {
"database_url": f"sqlite:///{temp_database}",
"echo": False,
"pool_size": 30,
"max_overflow": 100
}
await system.initialize()
# 高并发测试
high_concurrency_levels = [50, 100]
for concurrency in high_concurrency_levels:
async def query_task(code):
return await system.stock_repo.get_stock_basic_info(code)
tasks = []
for i in range(concurrency):
code = f"{i % 1000:06d}"
tasks.append(query_task(code))
# 测量高并发执行时间
start_time = time.time()
results = await asyncio.gather(*tasks)
end_time = time.time()
execution_time = end_time - start_time
print(f"高并发 {concurrency} 个查询 - 时间: {execution_time:.3f}")
# 验证高并发可扩展性
assert execution_time < 10.0 # 100个并发查询应该在10秒内完成
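class TestPerformanceBenchmarks:
    """Performance benchmark tests."""

    @pytest.fixture
    def temp_database(self):
        """Create a temporary database file."""
        with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
            temp_db_path = tmp.name
        yield temp_db_path
        # Clean up the temporary file. Use the local path variable: the bare
        # name `temp_database` is not defined inside this method and would
        # raise NameError during teardown.
        if os.path.exists(temp_db_path):
            os.unlink(temp_db_path)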
@pytest.mark.asyncio
async def test_benchmark_data_collection(self, temp_database):
"""基准测试:数据采集性能"""
system = StockAnalysisSystem()
system.db_config = {
"database_url": f"sqlite:///{temp_database}",
"echo": False
}
await system.initialize()
# Mock数据采集性能测试
with patch.object(system.data_manager, "get_stock_basic_info") as mock_collector:
mock_collector.return_value = [{"code": "000001", "name": "测试", "industry": "测试", "market": "测试", "list_date": "2020-01-01"}]
# 测量数据采集时间
start_time = time.time()
# 执行多次数据采集
for _ in range(100):
await system.data_manager.get_stock_basic_info()
end_time = time.time()
execution_time = end_time - start_time
print(f"数据采集基准测试 - 100次采集时间: {execution_time:.3f}")
# 基准性能要求
assert execution_time < 10.0 # 100次采集应该在10秒内完成
@pytest.mark.asyncio
async def test_benchmark_end_to_end_workflow(self, temp_database):
"""基准测试:端到端工作流性能"""
system = StockAnalysisSystem()
system.db_config = {
"database_url": f"sqlite:///{temp_database}",
"echo": False
}
await system.initialize()
# Mock端到端工作流
mock_data = [{"code": "000001", "name": "测试", "industry": "测试", "market": "测试", "list_date": "2020-01-01"}]
with patch.object(system.data_manager, "get_stock_basic_info", return_value=mock_data):
# 测量完整工作流时间
start_time = time.time()
# 执行完整工作流:采集→处理→存储
raw_data = await system.data_manager.get_stock_basic_info()
processed_data = system.data_processor.process_stock_basic_info(raw_data)
storage_result = await system.stock_repo.save_stock_basic_info(processed_data)
end_time = time.time()
execution_time = end_time - start_time
assert storage_result is True
print(f"端到端工作流基准测试 - 时间: {execution_time:.3f}")
# 基准性能要求
assert execution_time < 2.0 # 单次完整工作流应该在2秒内完成
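
The K-line performance tests above need one distinct date per generated record. A minimal sketch of producing consecutive calendar dates with `datetime.timedelta`, so that month lengths and leap years are handled by the standard library; `consecutive_dates` is a hypothetical helper, not part of the project.

```python
from datetime import date, timedelta
from typing import List


def consecutive_dates(start: date, count: int) -> List[str]:
    """Return `count` ISO-formatted dates beginning at `start` (hypothetical helper)."""
    return [(start + timedelta(days=offset)).isoformat() for offset in range(count)]


dates = consecutive_dates(date(2024, 1, 1), 100)
print(dates[0], dates[31], dates[-1])  # 2024-01-01 2024-02-01 2024-04-09
```

Note that these are calendar days, not trading days; that is enough for synthetic test data, but real K-line data skips weekends and holidays.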

483
tests/test_storage.py Normal file

@ -0,0 +1,483 @@
"""
Unit tests for the data storage module.
Tests the database manager, the models, and the storage repository.
"""
import pytest
import asyncio
from unittest.mock import Mock, patch, AsyncMock
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from src.storage.database import DatabaseManager
from src.storage.models import StockBasic, DailyKline, FinancialReport, DataSource, SystemLog
from src.storage.stock_repository import StockRepository
from src.utils.exceptions import DatabaseError
class TestDatabaseManager:
"""数据库管理器测试类"""
# 使用conftest.py中的db_manager fixture不需要重新定义
@pytest.fixture
def test_engine(self):
"""测试数据库引擎"""
return create_engine("sqlite:///:memory:")
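    def test_singleton_pattern(self, db_manager):
        """Check that the test database manager is usable.

        Despite the name, this does not verify singleton behavior: the fixture
        provides a test-specific instance rather than the shared singleton.
        """
        assert db_manager.engine is not None
        assert db_manager.Base is not None
        assert db_manager.SessionLocal is not None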
def test_configure_database_success(self, db_manager):
"""测试数据库配置成功"""
# DatabaseManager会自动从settings配置数据库
assert db_manager.engine is not None
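    def test_configure_database_invalid_url(self, db_manager):
        """Test an invalid database URL."""
        # DatabaseManager configures itself from settings during initialization,
        # so an invalid URL cannot be injected here. Skip explicitly rather than
        # let an empty test report a pass.
        pytest.skip("DatabaseManager reads its URL from settings at initialization")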
def test_create_tables_success(self, db_manager):
"""测试创建表成功"""
# 创建表
result = db_manager.create_tables()
assert result is True
# 验证表是否存在 - SQLAlchemy 2.0兼容
from sqlalchemy import inspect
table_names = inspect(db_manager.engine).get_table_names()
expected_tables = ["stock_basic", "daily_kline", "financial_report"]
for table in expected_tables:
assert table in table_names
def test_get_session_success(self, db_manager):
"""测试获取会话成功"""
# 创建表
db_manager.create_tables()
# 获取会话
session = db_manager.get_session()
assert session is not None
# 关闭会话
session.close()
def test_drop_tables_success(self, db_manager):
"""测试删除表成功"""
# 创建表
db_manager.create_tables()
# 删除表
result = db_manager.drop_tables()
assert result is True
# 验证表是否被删除 - SQLAlchemy 2.0兼容
from sqlalchemy import inspect
table_names = inspect(db_manager.engine).get_table_names()
assert len(table_names) == 0
class TestModels:
"""数据库模型测试类"""
    def test_stock_basic_model(self):
        """Test the stock basic-info model."""
        stock = StockBasic(
            code="000001",
            name="平安银行",
            market="sz",  # 000001 is listed in Shenzhen; codes starting with 0 map to "sz"
            company_name="平安银行股份有限公司",
            industry="银行",
            area="广东",
            ipo_date="1991-04-03",
            listing_status=True
        )
        assert stock.code == "000001"
        assert stock.name == "平安银行"
        assert stock.market == "sz"
        assert stock.company_name == "平安银行股份有限公司"
        assert stock.industry == "银行"
        assert stock.area == "广东"
        assert stock.ipo_date == "1991-04-03"
        assert stock.listing_status is True
        # created_at and updated_at are generated by the database, so they are
        # None on a freshly constructed object
def test_daily_kline_model(self):
"""测试日K线数据模型"""
kline = DailyKline(
stock_code="000001",
trade_date="2024-01-15",
open_price=10.5,
high_price=11.2,
low_price=10.3,
close_price=10.8,
volume=1000000,
amount=10800000
)
assert kline.stock_code == "000001"
assert kline.trade_date == "2024-01-15"
assert kline.open_price == 10.5
assert kline.high_price == 11.2
assert kline.low_price == 10.3
assert kline.close_price == 10.8
assert kline.volume == 1000000
assert kline.amount == 10800000
# created_at字段由数据库自动生成创建对象时为None
def test_financial_report_model(self):
"""测试财务报告模型"""
financial = FinancialReport(
stock_code="000001",
report_date="2023-12-31",
report_type="年报",
report_year=2023,
eps=1.5,
net_profit=1500000000,
revenue=5000000000,
total_assets=10000000000
)
assert financial.stock_code == "000001"
assert financial.report_date == "2023-12-31"
assert financial.report_type == "年报"
assert financial.report_year == 2023
assert financial.eps == 1.5
assert financial.net_profit == 1500000000
assert financial.revenue == 5000000000
assert financial.total_assets == 10000000000
# created_at字段由数据库自动生成创建对象时为None
def test_data_source_model(self):
"""测试数据源模型"""
source = DataSource(
source_name="akshare",
source_type="api",
sync_status="正常",
last_sync_time="2024-01-15 10:00:00"
)
assert source.source_name == "akshare"
assert source.source_type == "api"
assert source.sync_status == "正常"
assert source.last_sync_time == "2024-01-15 10:00:00"
# created_at和updated_at字段由数据库自动生成创建对象时为None
def test_system_log_model(self):
"""测试系统日志模型"""
log = SystemLog(
log_level="INFO",
module_name="data_collection",
message="数据采集完成"
)
assert log.log_level == "INFO"
assert log.module_name == "data_collection"
assert log.message == "数据采集完成"
# created_at字段由数据库自动生成创建对象时为None
class TestStockRepository:
"""股票存储仓库测试类"""
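    @pytest.fixture
    def stock_repo(self, db_manager):
        """Stock repository instance backed by a fresh session."""
        # Get a database session
        session = db_manager.get_session()
        yield StockRepository(session)
        # Release the session once the test finishes
        session.close()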
def test_save_stock_basic_info_success(self, stock_repo):
"""测试保存股票基础信息成功"""
from datetime import date
stock_data = [
{
"code": "000001",
"name": "平安银行",
"market": "主板",
"industry": "银行",
"area": "广东",
"ipo_date": date(1991, 4, 3),
"data_source": "akshare"
}
]
result = stock_repo.save_stock_basic_info(stock_data)
assert result["added_count"] == 1
assert result["error_count"] == 0
# 验证数据是否保存
saved_data = stock_repo.get_stock_basic_info()
assert len(saved_data) == 1
assert saved_data[0].code == "000001"
assert saved_data[0].name == "平安银行"
def test_save_stock_basic_info_duplicate(self, stock_repo):
"""测试保存重复股票基础信息"""
stock_data = [
{
"code": "000001",
"name": "平安银行",
"market": "主板",
"data_source": "akshare"
}
]
# 第一次保存
result1 = stock_repo.save_stock_basic_info(stock_data)
assert result1["added_count"] == 1
# 第二次保存相同数据
result2 = stock_repo.save_stock_basic_info(stock_data)
assert result2["updated_count"] == 1
# 验证只有一条记录
saved_data = stock_repo.get_stock_basic_info()
assert len(saved_data) == 1
def test_save_daily_kline_data_success(self, stock_repo):
"""测试保存日K线数据成功"""
from datetime import date
# 先保存股票基础信息
stock_data = [
{
"code": "000001",
"name": "平安银行",
"market": "主板",
"data_source": "akshare"
}
]
stock_repo.save_stock_basic_info(stock_data)
# 再保存日K线数据
kline_data = [
{
"code": "000001",
"date": "2024-01-15",
"open": 10.5,
"high": 11.2,
"low": 10.3,
"close": 10.8,
"volume": 1000000,
"amount": 10800000,
"data_source": "akshare"
}
]
result = stock_repo.save_daily_kline_data(kline_data)
assert result["added_count"] == 1
assert result["error_count"] == 0
def test_save_financial_report_data_success(self, stock_repo):
"""测试保存财务报告数据成功"""
from datetime import date
# 先保存股票基础信息
stock_data = [
{
"code": "000001",
"name": "平安银行",
"market": "主板",
"industry": "银行",
"area": "广东",
"ipo_date": date(1991, 4, 3),
"data_source": "akshare"
}
]
stock_repo.save_stock_basic_info(stock_data)
# 再保存财务报告数据
financial_data = [
{
"code": "000001",
"report_date": "2023-12-31",
"report_type": "年报",
"eps": 1.5,
"net_profit": 1500000000,
"revenue": 5000000000,
"data_source": "akshare"
}
]
result = stock_repo.save_financial_report_data(financial_data)
assert result["added_count"] == 1
assert result["error_count"] == 0
def test_get_stock_basic_info_success(self, stock_repo):
"""测试获取股票基础信息成功"""
# 先保存数据
stock_data = [
{
"code": "000001",
"name": "平安银行",
"market": "主板",
"data_source": "akshare"
}
]
stock_repo.save_stock_basic_info(stock_data)
# 获取数据
result = stock_repo.get_stock_basic_info()
assert len(result) == 1
assert result[0].code == "000001"
assert result[0].name == "平安银行"
def test_get_stock_basic_info_not_found(self, stock_repo):
"""测试获取不存在的股票基础信息"""
result = stock_repo.get_stock_basic_info()
# 没有保存任何数据,所以结果应该为空
assert len(result) == 0
def test_get_daily_kline_data_success(self, stock_repo):
"""测试获取日K线数据成功"""
# 先保存数据
kline_data = [
{
"code": "000001",
"date": "2024-01-15",
"open": 10.5,
"high": 11.0,
"low": 10.0,
"close": 10.8,
"volume": 1000000,
"amount": 10800000,
"data_source": "akshare"
}
]
stock_repo.save_daily_kline_data(kline_data)
# 获取数据
from datetime import date
result = stock_repo.get_daily_kline_data("000001", date(2024, 1, 1), date(2024, 1, 31))
assert len(result) == 1
assert result[0].stock_code == "000001"
assert result[0].trade_date == date(2024, 1, 15)
assert result[0].open_price == 10.5
assert result[0].close_price == 10.8
def test_transaction_rollback_on_error(self, stock_repo):
"""测试事务回滚"""
# 创建无效数据(缺少必要字段)
invalid_data = [
{
"code": "000001",
# 缺少name字段
"market": "主板"
}
]
# 调用方法,应该不会抛出异常,但会记录错误
result = stock_repo.save_stock_basic_info(invalid_data)
# 验证方法返回了错误计数
assert result["error_count"] == 1
assert result["added_count"] == 0
assert result["updated_count"] == 0
# 验证没有数据被保存(事务回滚)
saved_data = stock_repo.get_stock_basic_info()
assert len(saved_data) == 0
class TestDatabaseOperations:
"""数据库操作测试类"""
@pytest.fixture
def setup_database(self):
"""设置测试数据库"""
# 使用测试数据库管理器而不是主数据库管理器
from tests.conftest import create_test_database_manager
db_manager = create_test_database_manager()
db_manager.create_tables()
return db_manager
def test_bulk_insert_performance(self, setup_database):
"""测试批量插入性能"""
db_manager = setup_database
session = db_manager.get_session()
# 创建大量测试数据
test_data = []
for i in range(1000):
stock = db_manager.StockBasic(
code=f"{i:06d}",
name=f"测试股票{i}",
market="主板",
data_source="test"
)
test_data.append(stock)
# 批量插入
import time
start_time = time.time()
session.bulk_save_objects(test_data)
session.commit()
end_time = time.time()
execution_time = end_time - start_time
# 验证插入的数据量
count = session.query(db_manager.StockBasic).count()
assert count == 1000
# 性能要求1000条数据插入时间应小于1秒
assert execution_time < 1.0
session.close()
def test_query_performance(self, setup_database):
"""测试查询性能"""
db_manager = setup_database
session = db_manager.get_session()
# 插入测试数据
test_data = []
for i in range(1000):
stock = db_manager.StockBasic(
code=f"{i:06d}",
name=f"测试股票{i}",
market="主板",
data_source="test"
)
test_data.append(stock)
session.bulk_save_objects(test_data)
session.commit()
# 测试查询性能
import time
start_time = time.time()
result = session.query(db_manager.StockBasic).filter(db_manager.StockBasic.market == "主板").all()
end_time = time.time()
execution_time = end_time - start_time
# 验证查询结果
assert len(result) == 1000
# 性能要求1000条数据查询时间应小于0.1秒
assert execution_time < 0.1
session.close()
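
The two performance tests above time a block of work by hand with `time.time()`. One possible way to factor that out is a small context manager built on `time.perf_counter()`, which is monotonic and intended for measuring short durations. `timed` is a hypothetical helper, not part of the project.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed():
    """Yield a dict whose "seconds" key is filled in when the block exits (hypothetical helper)."""
    result = {}
    start = time.perf_counter()
    try:
        yield result
    finally:
        # Recorded even if the timed block raises
        result["seconds"] = time.perf_counter() - start


with timed() as timing:
    total = sum(range(100_000))

print(f"elapsed: {timing['seconds']:.6f} s")
```

In a test, the assertion would then read `assert timing["seconds"] < 1.0` after the `with` block.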

325
update_all_data.py Normal file

@ -0,0 +1,325 @@
"""
Full data update script.
Updates both K-line data and financial data, with batch processing and progress reporting.
"""
import sys
import os
import asyncio
import logging
from datetime import date, datetime
# This script sits in the repository root, so its own directory is the one
# that contains the `src` package.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from src.storage.database import db_manager
from src.storage.stock_repository import StockRepository
from src.data.data_manager import DataManager
from src.config.settings import Settings
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
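def convert_to_baostock_format(stock_code: str) -> str:
    """
    Convert a 6-digit stock code to Baostock format (9 characters, e.g. "sh.600000").

    Args:
        stock_code: 6-digit stock code

    Returns:
        The 9-character Baostock code, or the input unchanged if the
        exchange cannot be inferred from its prefix
    """
    if len(stock_code) == 6:
        # Infer the exchange from the leading digit
        if stock_code.startswith(('6', '9')):
            return f"sh.{stock_code}"
        elif stock_code.startswith(('0', '3')):
            return f"sz.{stock_code}"
    return stock_code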
async def update_kline_data_batch(stocks: list, data_manager: DataManager, repository: StockRepository, batch_size: int = 10):
    """
    Update K-line data in batches.

    Args:
        stocks: list of stocks
        data_manager: data manager
        repository: repository
        batch_size: number of stocks to process per batch

    Returns:
        Update result
    """
    total_kline_data = []
    success_count = 0
    error_count = 0
    # Date range: first day of the month three months back, through today.
    # Counting in whole months avoids an invalid month (<= 0) in January-March.
    end_date = date.today()
    month_index = end_date.year * 12 + (end_date.month - 1) - 3
    start_date = date(month_index // 12, month_index % 12 + 1, 1)
    total_batches = (len(stocks) - 1) // batch_size + 1
    # Process in batches
    for i in range(0, len(stocks), batch_size):
        batch = stocks[i:i + batch_size]
        logger.info(f"Processing K-line batch {i // batch_size + 1}/{total_batches}: {len(batch)} stocks")
        batch_kline_data = []
        batch_success = 0
        batch_error = 0
        for stock in batch:
            try:
                # Convert to Baostock format
                baostock_code = convert_to_baostock_format(stock.code)
                # Fetch K-line data for the last three months
                kline_data = await data_manager.get_daily_kline_data(
                    baostock_code, start_date, end_date
                )
                if kline_data:
                    # Convert the code in the data back to the 6-digit format
                    for data in kline_data:
                        data["code"] = stock.code
                    batch_kline_data.extend(kline_data)
                    batch_success += 1
                    logger.info(f"Stock {stock.code}: fetched {len(kline_data)} K-line records")
                else:
                    logger.warning(f"Stock {stock.code}: no K-line data returned")
                    batch_error += 1
            except Exception as e:
                logger.error(f"Failed to fetch K-line data for stock {stock.code}: {str(e)}")
                batch_error += 1
                continue
            # Short delay to avoid sending requests too quickly
            await asyncio.sleep(0.2)
        # Save the data for the current batch
        if batch_kline_data:
            try:
                save_result = repository.save_daily_kline_data(batch_kline_data)
                logger.info(f"Batch K-line save result: {save_result}")
                total_kline_data.extend(batch_kline_data)
            except Exception as e:
                logger.error(f"Failed to save batch K-line data: {str(e)}")
                # Nothing in this batch was persisted, so its fetched stocks count as failures
                batch_error += batch_success
                batch_success = 0
        success_count += batch_success
        error_count += batch_error
        logger.info(f"Batch finished: {batch_success} succeeded, {batch_error} failed")
    return {
        "success": True,
        "total_stocks": len(stocks),
        "success_count": success_count,
        "error_count": error_count,
        "kline_data_count": len(total_kline_data)
    }
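# The K-line window starts on the first day of the month three months before today.
# Subtracting from the month number alone is not enough, because January through March
# would produce a month of 0 or less and date() would raise ValueError. The block below
# is a minimal sketch of the month arithmetic; `first_day_months_back` is a
# hypothetical helper name, not something defined in the project.

```python
from datetime import date


def first_day_months_back(day: date, months: int) -> date:
    """Return the first day of the month that lies `months` months before `day`."""
    # Count months since year 0 so that year boundaries are handled by integer division
    month_index = day.year * 12 + (day.month - 1) - months
    return date(month_index // 12, month_index % 12 + 1, 1)
```

# For example, a run on 15 February 2024 with a three-month lookback starts the
# window on 1 November 2023.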
async def update_financial_data_batch(stocks: list, data_manager: DataManager, repository: StockRepository, batch_size: int = 10):
    """
    Update financial data in batches.

    Args:
        stocks: list of stocks
        data_manager: data manager
        repository: repository
        batch_size: number of stocks to process per batch

    Returns:
        Update result
    """
    total_financial_data = []
    success_count = 0
    error_count = 0
    # Year and quarter used for the test run
    test_year = 2023
    test_quarter = 4
    total_batches = (len(stocks) - 1) // batch_size + 1
    # Process in batches
    for i in range(0, len(stocks), batch_size):
        batch = stocks[i:i + batch_size]
        logger.info(f"Processing financial batch {i // batch_size + 1}/{total_batches}: {len(batch)} stocks")
        batch_financial_data = []
        batch_success = 0
        batch_error = 0
        for stock in batch:
            try:
                # Convert to Baostock format
                baostock_code = convert_to_baostock_format(stock.code)
                # Fetch financial data
                financial_data = await data_manager.get_financial_report(
                    baostock_code, test_year, test_quarter
                )
                if financial_data:
                    # Convert the code in the data back to the 6-digit format
                    for data in financial_data:
                        data["code"] = stock.code
                    batch_financial_data.extend(financial_data)
                    batch_success += 1
                    logger.info(f"Stock {stock.code}: fetched {len(financial_data)} financial records")
                else:
                    logger.warning(f"Stock {stock.code}: no financial data returned")
                    batch_error += 1
            except Exception as e:
                logger.error(f"Failed to fetch financial data for stock {stock.code}: {str(e)}")
                batch_error += 1
                continue
            # Short delay to avoid sending requests too quickly
            await asyncio.sleep(0.3)
        # Save the data for the current batch
        if batch_financial_data:
            try:
                save_result = repository.save_financial_report_data(batch_financial_data)
                logger.info(f"Batch financial save result: {save_result}")
                total_financial_data.extend(batch_financial_data)
            except Exception as e:
                logger.error(f"Failed to save batch financial data: {str(e)}")
                # Nothing in this batch was persisted, so its fetched stocks count as failures
                batch_error += batch_success
                batch_success = 0
        success_count += batch_success
        error_count += batch_error
        logger.info(f"Batch finished: {batch_success} succeeded, {batch_error} failed")
    return {
        "success": True,
        "total_stocks": len(stocks),
        "success_count": success_count,
        "error_count": error_count,
        "financial_data_count": len(total_financial_data)
    }
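# Both batch functions slice the stock list in steps of batch_size and report progress
# as "batch k of n". The block below is a minimal sketch of that slicing on its own;
# `batch_count` and `iter_batches` are hypothetical helper names, not project code.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def batch_count(total: int, batch_size: int) -> int:
    """Number of batches needed for `total` items (ceiling division)."""
    return (total + batch_size - 1) // batch_size


def iter_batches(items: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of at most `batch_size` items."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


codes = ["000001", "000002", "000004", "600000", "600004"]
batches = list(iter_batches(codes, 2))
```

# Only the last batch can be shorter than batch_size, which is why the progress
# message logs len(batch) rather than batch_size.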
async def update_all_data():
    """
    Update all data (K-line data and financial data).
    """
    try:
        logger.info("Starting update of all stock data...")
        # Load configuration
        settings = Settings()
        logger.info("Configuration loaded")
        # Create the data manager
        data_manager = DataManager()
        logger.info("Data manager created")
        # Create the repository
        repository = StockRepository(db_manager.get_session())
        logger.info("Repository created")
        # Fetch stock basic info
        stocks = repository.get_stock_basic_info()
        logger.info(f"Loaded basic info for {len(stocks)} stocks")
        if not stocks:
            logger.error("No stock basic info available; cannot update data")
            return {"success": False, "error": "No stock basic info available"}
        # Use only the first 50 stocks as a test run to keep processing time short
        test_stocks = stocks[:50]
        test_codes = [stock.code for stock in test_stocks]
        logger.info(f"Test stock codes: {test_codes}")
        # Update K-line data
        logger.info("=== Starting K-line data update ===")
        kline_result = await update_kline_data_batch(test_stocks, data_manager, repository, batch_size=5)
        # Update financial data
        logger.info("=== Starting financial data update ===")
        financial_result = await update_financial_data_batch(test_stocks, data_manager, repository, batch_size=5)
        # Summarize the results
        result = {
            "success": True,
            "kline_data": kline_result,
            "financial_data": financial_result,
            "total_stocks": len(test_stocks)
        }
        logger.info(f"All data updated: {result}")
        return result
    except Exception as e:
        logger.error(f"Data update raised an exception: {str(e)}")
        return {"success": False, "error": str(e)}
def main():
    """
    Main entry point.
    """
    logger.info("Starting the full data update...")
    # Run the async update
    result = asyncio.run(update_all_data())
    if result["success"]:
        logger.info("Data update succeeded!")
        kline_result = result["kline_data"]
        financial_result = result["financial_data"]
        print("=== Data update summary ===")
        print(f"Total stocks processed: {result['total_stocks']}")
        print("\n=== K-line data update ===")
        print(f"✓ Stocks succeeded: {kline_result['success_count']}")
        print(f"✓ Stocks failed: {kline_result['error_count']}")
        print(f"✓ K-line records fetched: {kline_result['kline_data_count']}")
        print("\n=== Financial data update ===")
        print(f"✓ Stocks succeeded: {financial_result['success_count']}")
        print(f"✓ Stocks failed: {financial_result['error_count']}")
        print(f"✓ Financial records fetched: {financial_result['financial_data_count']}")
        print("\n=== Database verification ===")
        # Verify the data in the database
        try:
            repository = StockRepository(db_manager.get_session())
            # Count K-line records
            kline_count = repository.session.query(repository.DailyKline).count()
            print(f"✓ Daily K-line table: {kline_count} records")
            # Count financial report records
            financial_count = repository.session.query(repository.FinancialReport).count()
            print(f"✓ Financial report table: {financial_count} records")
            # Count stock basic info records
            stock_count = repository.session.query(repository.StockBasicInfo).count()
            print(f"✓ Stock basic info: {stock_count} records")
        except Exception as e:
            print(f"⚠ Database verification failed: {str(e)}")
        print("\nData update finished!")
    else:
        logger.error("Data update failed!")
        print(f"Update failed: {result.get('error')}")
    return result


if __name__ == "__main__":
    main()


@ -0,0 +1,211 @@
"""
Baostock-compatible financial data update script.
Converts 6-digit stock codes to the 9-character format for financial data collection.
"""
import sys
import os
import asyncio
import logging
from datetime import date, datetime
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.storage.database import db_manager
from src.storage.stock_repository import StockRepository
from src.data.data_manager import DataManager
from src.config.settings import Settings
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def convert_to_baostock_format(stock_code: str) -> str:
    """
    Convert a 6-digit stock code to Baostock format (9 characters, e.g. "sh.600000").

    Args:
        stock_code: 6-digit stock code

    Returns:
        The 9-character Baostock code, or the input unchanged if the market
        cannot be determined
    """
    if len(stock_code) == 6:
        # Determine the market from the leading digit
        if stock_code.startswith(('6', '9')):
            return f"sh.{stock_code}"
        elif stock_code.startswith(('0', '3')):
            return f"sz.{stock_code}"
    return stock_code
async def update_financial_data_baostock():
    """
    Update financial data using Baostock-format codes.
    """
    try:
        logger.info("Starting financial data update using Baostock-format codes...")
        # Load configuration
        settings = Settings()
        logger.info("Configuration loaded")
        # Create the data manager
        data_manager = DataManager()
        logger.info("Data manager created")
        # Create the repository
        repository = StockRepository(db_manager.get_session())
        logger.info("Repository created")
        # Fetch stock basic info
        stocks = repository.get_stock_basic_info()
        logger.info(f"Loaded basic info for {len(stocks)} stocks")
        if not stocks:
            logger.error("No stock basic info available; cannot update financial data")
            return {"success": False, "error": "No stock basic info available"}
        # Use the first 20 stocks as a test run
        test_stocks = stocks[:20]
        test_codes = [stock.code for stock in test_stocks]
        logger.info(f"Test stock codes: {test_codes}")
        # Year and quarter used for the test run
        test_year = 2023
        test_quarter = 4
        total_financial_data = []
        success_count = 0
        error_count = 0
        # Fetch financial data for each test stock
        for stock in test_stocks:
            try:
                # Convert to Baostock format
                baostock_code = convert_to_baostock_format(stock.code)
                logger.info(f"Fetching financial data for stock {stock.code} ({baostock_code})...")
                # Fetch financial data through the data manager
                financial_data = await data_manager.get_financial_report(
                    baostock_code, test_year, test_quarter
                )
                if financial_data:
                    # Convert the code in the data back to the 6-digit format
                    for data in financial_data:
                        data["code"] = stock.code  # use the original 6-digit code
                    total_financial_data.extend(financial_data)
                    success_count += 1
                    logger.info(f"Stock {stock.code}: fetched {len(financial_data)} financial records")
                else:
                    logger.warning(f"Stock {stock.code}: no financial data returned")
                    error_count += 1
                # Short delay to avoid sending requests too quickly
                await asyncio.sleep(0.5)
            except Exception as e:
                logger.error(f"Failed to fetch financial data for stock {stock.code}: {str(e)}")
                error_count += 1
                continue
        # Save the financial data. save_result stays empty when nothing is saved or the
        # save fails, so the summary below never reads an undefined name.
        save_result = {}
        if total_financial_data:
            try:
                logger.info(f"Saving {len(total_financial_data)} financial records...")
                save_result = repository.save_financial_report_data(total_financial_data)
                logger.info(f"Financial data save result: {save_result}")
            except Exception as e:
                logger.error(f"Failed to save financial data: {str(e)}")
                # Nothing was persisted, so the fetched stocks count as failures
                error_count += success_count
                success_count = 0
        else:
            logger.warning("No financial data was fetched")
        # Verify the financial data in the database
        try:
            financial_count = repository.session.query(repository.FinancialReport).count()
            logger.info(f"Financial report table: {financial_count} records")
            if financial_count > 0:
                # Show the 5 most recent financial records
                latest_financial = repository.session.query(repository.FinancialReport).order_by(
                    repository.FinancialReport.report_date.desc()
                ).limit(5).all()
                logger.info("5 most recent financial records:")
                for i, financial in enumerate(latest_financial):
                    logger.info(f"  {i+1}. {financial.stock_code} - {financial.report_date} - EPS: {financial.eps}")
        except Exception as e:
            logger.error(f"Failed to query financial data: {str(e)}")
        # Summarize the update result
        result = {
            "success": True,
            "test_stocks": len(test_stocks),
            "success_count": success_count,
            "error_count": error_count,
            "financial_data_count": len(total_financial_data),
            "saved_count": save_result.get("added_count", 0) + save_result.get("updated_count", 0)
        }
        logger.info(f"Financial data update finished: {result}")
        return result
    except Exception as e:
        logger.error(f"Financial data update raised an exception: {str(e)}")
        return {"success": False, "error": str(e)}
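# The saved_count field in the summary adds the added_count and updated_count entries
# of the save result. When the save step is skipped or raises, there is no save result
# to read. The block below is a minimal sketch of a tolerant way to derive the count;
# `count_saved` is a hypothetical helper, and the two dictionary keys are the ones
# the script itself reads.

```python
from typing import Optional


def count_saved(save_result: Optional[dict] = None) -> int:
    """Sum added and updated rows; a missing or empty result counts as 0."""
    if not save_result:
        return 0
    return save_result.get("added_count", 0) + save_result.get("updated_count", 0)
```

# Passing None for a failed or skipped save keeps the summary dictionary well formed
# without a conditional at every call site.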
def main():
    """
    Main entry point.
    """
    logger.info("Starting Baostock-compatible financial data update...")
    # Run the async update
    result = asyncio.run(update_financial_data_baostock())
    if result["success"]:
        logger.info("Financial data update succeeded!")
        print(f"Update result: {result}")
        if result["financial_data_count"] > 0:
            print("✓ Financial data fetched")
            print(f"✓ {result['financial_data_count']} financial records fetched in total")
        else:
            print("⚠ No financial data was fetched")
        if result["saved_count"] > 0:
            print("✓ Financial data saved")
            print(f"✓ {result['saved_count']} financial records saved in total")
        else:
            print("⚠ Financial data was not saved")
        print(f"✓ Stocks succeeded: {result['success_count']}")
        print(f"✓ Stocks failed: {result['error_count']}")
    else:
        logger.error("Financial data update failed!")
        print(f"Update failed: {result.get('error')}")
    return result


if __name__ == "__main__":
    # Run the update
    result = main()
    # Report the final outcome through the exit code
    if result.get("success", False):
        print("\nBaostock-compatible financial data update finished")
        sys.exit(0)
    else:
        print("\nBaostock-compatible financial data update failed")
        sys.exit(1)

146
update_kline_baostock.py Normal file

@ -0,0 +1,146 @@
"""
Baostock-compatible data update script.
Handles stock code format conversion automatically.
"""
import sys
import os
import asyncio
import logging
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.data.data_initializer import DataInitializer
from src.config.settings import Settings
from src.storage.database import db_manager
from src.storage.stock_repository import StockRepository
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def get_baostock_format_code(stock_code: str) -> str:
    """
    Convert a stock code to Baostock format.
    """
    # Already in Baostock format
    if stock_code.startswith("sh.") or stock_code.startswith("sz."):
        return stock_code
    if stock_code.startswith("6"):
        return f"sh.{stock_code}"
    elif stock_code.startswith("0") or stock_code.startswith("3"):
        return f"sz.{stock_code}"
    else:
        return stock_code
async def update_kline_data_with_baostock_format():
    """
    Update K-line data using Baostock-format codes.
    """
    try:
        logger.info("Starting K-line data update using Baostock-format codes...")
        # Load configuration
        settings = Settings()
        logger.info("Configuration loaded")
        # Create the data initializer
        initializer = DataInitializer(settings)
        logger.info("Data initializer created")
        # Create the repository
        repository = StockRepository(db_manager.get_session())
        logger.info("Repository created")
        # Fetch basic info for all stocks
        stocks = repository.get_stock_basic_info()
        logger.info(f"Found {len(stocks)} stocks")
        if not stocks:
            logger.error("No stock basic info available; cannot update")
            return {"success": False, "error": "No stock basic info available"}
        # Process in batches of 10 stocks
        batch_size = 10
        total_batches = (len(stocks) + batch_size - 1) // batch_size
        total_kline_data = []
        success_count = 0
        error_count = 0
        for batch_num in range(total_batches):
            start_idx = batch_num * batch_size
            end_idx = min(start_idx + batch_size, len(stocks))
            batch_stocks = stocks[start_idx:end_idx]
            logger.info(f"Processing batch {batch_num + 1} ({len(batch_stocks)} stocks)")
            for stock in batch_stocks:
                try:
                    # Convert to Baostock format
                    baostock_code = get_baostock_format_code(stock.code)
                    logger.info(f"Fetching K-line data for stock {stock.code} ({baostock_code})...")
                    # Fetch K-line data through the data manager
                    kline_data = await initializer.data_manager.get_daily_kline_data(
                        baostock_code,
                        "2024-01-01",
                        "2024-01-10"
                    )
                    if kline_data:
                        total_kline_data.extend(kline_data)
                        success_count += 1
                        logger.info(f"Stock {stock.code}: fetched {len(kline_data)} K-line records")
                    else:
                        logger.warning(f"Stock {stock.code}: no K-line data returned")
                        error_count += 1
                except Exception as e:
                    logger.error(f"Failed to fetch K-line data for stock {stock.code}: {str(e)}")
                    error_count += 1
                    continue
        # Note: this function only fetches the data; it does not write it to the database
        logger.info(f"K-line data update finished: {success_count} succeeded, {error_count} failed, {len(total_kline_data)} records fetched")
        return {
            "success": True,
            "total_stocks": len(stocks),
            "success_count": success_count,
            "error_count": error_count,
            "kline_data_count": len(total_kline_data)
        }
    except Exception as e:
        logger.error(f"K-line data update raised an exception: {str(e)}")
        return {"success": False, "error": str(e)}
async def main():
    """
    Main entry point.
    """
    result = await update_kline_data_with_baostock_format()
    if result["success"]:
        logger.info("K-line data update succeeded")
        print(f"Update result: {result}")
    else:
        logger.error("K-line data update failed")
        print(f"Update failed: {result.get('error')}")
    return result


if __name__ == "__main__":
    # Run the update
    result = asyncio.run(main())
    # Report the final outcome through the exit code
    if result.get("success", False):
        print("Update succeeded!")
        sys.exit(0)
    else:
        print("Update failed!")
        sys.exit(1)

247
update_kline_data.py Normal file

@ -0,0 +1,247 @@
"""
K-line data update script.
Fetches only K-line data and financial data, to avoid re-fetching stock basic info.
"""
import sys
import os
import asyncio
import logging
from datetime import date, timedelta
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.data.data_initializer import DataInitializer
from src.config.settings import Settings
from src.storage.database import db_manager
from src.storage.stock_repository import StockRepository
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


async def update_kline_data():
    """
    Update K-line data.

    Returns:
        Update result
    """
    try:
        logger.info("Starting K-line data update...")
        # Load configuration
        settings = Settings()
        # Create the data initializer
        initializer = DataInitializer(settings)
        # Create the repository
        repository = StockRepository(db_manager.get_session())
        # Fetch all stock codes
        stocks = repository.get_stock_basic_info()
        if not stocks:
            logger.error("No stock basic info available; cannot update K-line data")
            return {"success": False, "error": "No stock basic info available"}
        logger.info(f"Found {len(stocks)} stocks; starting K-line data update...")
        # Date range: the most recent year of data
        end_date = date.today()
        try:
            start_date = end_date.replace(year=end_date.year - 1)
        except ValueError:
            # 29 February has no counterpart in the previous year
            start_date = end_date.replace(year=end_date.year - 1, day=28)
        total_kline_data = []
        success_count = 0
        error_count = 0
        # Process in batches
        batch_size = 20
        for i in range(0, len(stocks), batch_size):
            batch_stocks = stocks[i:i + batch_size]
            # Fetch K-line data for each stock
            for stock in batch_stocks:
                try:
                    logger.info(f"Fetching K-line data for stock {stock.code}...")
                    # Fetch K-line data through the data manager
                    kline_data = await initializer.data_manager.get_daily_kline_data(
                        stock.code,
                        start_date.strftime("%Y-%m-%d"),
                        end_date.strftime("%Y-%m-%d")
                    )
                    if kline_data:
                        total_kline_data.extend(kline_data)
                        success_count += 1
                        logger.info(f"Stock {stock.code}: fetched {len(kline_data)} K-line records")
                    else:
                        logger.warning(f"Stock {stock.code}: no K-line data returned")
                        error_count += 1
                    # Short delay to avoid sending requests too quickly
                    await asyncio.sleep(0.1)
                except Exception as e:
                    logger.error(f"Failed to fetch K-line data for stock {stock.code}: {str(e)}")
                    error_count += 1
                    continue
        # Save the K-line data to the database
        if total_kline_data:
            logger.info(f"Saving {len(total_kline_data)} K-line records to the database...")
            save_result = repository.save_daily_kline_data(total_kline_data)
            logger.info(f"K-line data saved: {save_result}")
        else:
            logger.warning("No K-line data was fetched")
            save_result = {"added_count": 0, "error_count": 0, "total_count": 0}
        result = {
            "success": True,
            "stock_count": len(stocks),
            "success_stocks": success_count,
            "error_stocks": error_count,
            "kline_data_count": len(total_kline_data),
            "save_result": save_result
        }
        logger.info(f"K-line data update finished: {result}")
        return result
    except Exception as e:
        logger.error(f"K-line data update raised an exception: {str(e)}")
        return {"success": False, "error": str(e)}
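# The one-year lookback for the K-line window needs care on 29 February, because the
# same calendar day does not exist in the preceding year and constructing that date
# raises ValueError. The block below is a minimal sketch of a leap-day-safe version;
# `one_year_before` is a hypothetical helper name, and falling back to 28 February
# is one reasonable choice among several.

```python
from datetime import date


def one_year_before(day: date) -> date:
    """Return the same calendar day one year earlier, clamping 29 Feb to 28 Feb."""
    try:
        return day.replace(year=day.year - 1)
    except ValueError:
        # Only 29 February can fail here: the previous year has no such day
        return day.replace(year=day.year - 1, day=28)
```

# The result can be formatted with strftime("%Y-%m-%d") in the same way the script
# formats its start and end dates.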
async def update_financial_data():
    """
    Update financial data.

    Returns:
        Update result
    """
    try:
        logger.info("Starting financial data update...")
        # Load configuration
        settings = Settings()
        # Create the data initializer
        initializer = DataInitializer(settings)
        # Create the repository
        repository = StockRepository(db_manager.get_session())
        # Fetch all stock codes
        stocks = repository.get_stock_basic_info()
        if not stocks:
            logger.error("No stock basic info available; cannot update financial data")
            return {"success": False, "error": "No stock basic info available"}
        logger.info(f"Found {len(stocks)} stocks; starting financial data update...")
        total_financial_data = []
        success_count = 0
        error_count = 0
        # Process in batches
        batch_size = 15
        for i in range(0, len(stocks), batch_size):
            batch_stocks = stocks[i:i + batch_size]
            # Fetch financial data for each stock
            for stock in batch_stocks:
                try:
                    logger.info(f"Fetching financial data for stock {stock.code}...")
                    # Fetch financial data through the data manager
                    financial_data = await initializer.data_manager.get_financial_report(stock.code)
                    if financial_data:
                        total_financial_data.extend(financial_data)
                        success_count += 1
                        logger.info(f"Stock {stock.code}: fetched {len(financial_data)} financial records")
                    else:
                        logger.warning(f"Stock {stock.code}: no financial data returned")
                        error_count += 1
                    # Short delay to avoid sending requests too quickly
                    await asyncio.sleep(0.2)
                except Exception as e:
                    logger.error(f"Failed to fetch financial data for stock {stock.code}: {str(e)}")
                    error_count += 1
                    continue
        # Save the financial data to the database
        if total_financial_data:
            logger.info(f"Saving {len(total_financial_data)} financial records to the database...")
            save_result = repository.save_financial_report_data(total_financial_data)
            logger.info(f"Financial data saved: {save_result}")
        else:
            logger.warning("No financial data was fetched")
            save_result = {"added_count": 0, "updated_count": 0, "error_count": 0, "total_count": 0}
        result = {
            "success": True,
            "stock_count": len(stocks),
            "success_stocks": success_count,
            "error_stocks": error_count,
            "financial_data_count": len(total_financial_data),
            "save_result": save_result
        }
        logger.info(f"Financial data update finished: {result}")
        return result
    except Exception as e:
        logger.error(f"Financial data update raised an exception: {str(e)}")
        return {"success": False, "error": str(e)}
async def main():
    """
    Main entry point.
    """
    try:
        logger.info("Starting stock data update...")
        # Update K-line data
        kline_result = await update_kline_data()
        # Update financial data
        financial_result = await update_financial_data()
        # Summarize the results
        total_result = {
            "kline_update": kline_result,
            "financial_update": financial_result,
            "overall_success": kline_result.get("success", False) and financial_result.get("success", False)
        }
        logger.info(f"Data update finished: {total_result}")
        if total_result["overall_success"]:
            logger.info("Data update succeeded!")
        else:
            logger.error("Data update failed!")
        return total_result
    except Exception as e:
        logger.error(f"Data update main routine raised an exception: {str(e)}")
        return {"success": False, "error": str(e)}


if __name__ == "__main__":
    # Run the main routine
    result = asyncio.run(main())
    # Report the final outcome through the exit code
    if result.get("overall_success", False):
        print("Data update succeeded!")
        sys.exit(0)
    else:
        print("Data update failed!")
        sys.exit(1)

162
update_kline_fast.py Normal file

@ -0,0 +1,162 @@
"""
Fast K-line data update script.
Updates only a subset of stocks, for testing.
"""
import sys
import os
import asyncio
import logging
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.data.data_initializer import DataInitializer
from src.config.settings import Settings
from src.storage.database import db_manager
from src.storage.stock_repository import StockRepository
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def get_baostock_format_code(stock_code: str) -> str:
    """
    Convert a stock code to Baostock format.
    """
    # Already in Baostock format
    if stock_code.startswith("sh.") or stock_code.startswith("sz."):
        return stock_code
    if stock_code.startswith("6"):
        return f"sh.{stock_code}"
    elif stock_code.startswith("0") or stock_code.startswith("3"):
        return f"sz.{stock_code}"
    else:
        return stock_code
async def update_kline_data_fast():
    """
    Quickly update K-line data (only the first 20 stocks are processed).
    """
    try:
        logger.info("Starting fast K-line data update...")
        # Load configuration
        settings = Settings()
        logger.info("Configuration loaded")
        # Create the data initializer
        initializer = DataInitializer(settings)
        logger.info("Data initializer created")
        # Create the repository
        repository = StockRepository(db_manager.get_session())
        logger.info("Repository created")
        # Fetch basic info for all stocks
        stocks = repository.get_stock_basic_info()
        logger.info(f"Found {len(stocks)} stocks")
        if not stocks:
            logger.error("No stock basic info available; cannot update")
            return {"success": False, "error": "No stock basic info available"}
        # Process only the first 20 stocks as a test run
        test_stocks = stocks[:20]
        logger.info(f"Test run: updating the first {len(test_stocks)} stocks")
        # Process in batches of 5 stocks
        batch_size = 5
        total_batches = (len(test_stocks) + batch_size - 1) // batch_size
        total_kline_data = []
        success_count = 0
        error_count = 0
        for batch_num in range(total_batches):
            start_idx = batch_num * batch_size
            end_idx = min(start_idx + batch_size, len(test_stocks))
            batch_stocks = test_stocks[start_idx:end_idx]
            logger.info(f"Processing batch {batch_num + 1} ({len(batch_stocks)} stocks)")
            for stock in batch_stocks:
                try:
                    # Convert to Baostock format
                    baostock_code = get_baostock_format_code(stock.code)
                    logger.info(f"Fetching K-line data for stock {stock.code} ({baostock_code})...")
                    # Fetch K-line data through the data manager
                    kline_data = await initializer.data_manager.get_daily_kline_data(
                        baostock_code,
                        "2024-01-01",
                        "2024-01-31"  # one month of data
                    )
                    if kline_data:
                        total_kline_data.extend(kline_data)
                        success_count += 1
                        logger.info(f"Stock {stock.code}: fetched {len(kline_data)} K-line records")
                        # Save to the database
                        save_result = repository.save_daily_kline_data(kline_data)
                        logger.info(f"Stock {stock.code}: K-line save result: {save_result}")
                    else:
                        logger.warning(f"Stock {stock.code}: no K-line data returned")
                        error_count += 1
                except Exception as e:
                    logger.error(f"Failed to fetch K-line data for stock {stock.code}: {str(e)}")
                    error_count += 1
                    continue
            # Pause briefly after each batch
            await asyncio.sleep(1)
        logger.info(f"K-line data update finished: {success_count} succeeded, {error_count} failed, {len(total_kline_data)} records fetched")
        return {
            "success": True,
            "test_stocks": len(test_stocks),
            "success_count": success_count,
            "error_count": error_count,
            "kline_data_count": len(total_kline_data)
        }
    except Exception as e:
        logger.error(f"K-line data update raised an exception: {str(e)}")
        return {"success": False, "error": str(e)}
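# The fast update fetches stock by stock, counts empty results and exceptions as
# failures, and pauses after every batch. The block below is a minimal sketch of that
# control flow with the data source injected as a parameter. `fetch_in_batches` and
# `fake_fetch` are hypothetical names, and the fake rows are made-up values that only
# exist to drive the example.

```python
import asyncio
from typing import Awaitable, Callable, Dict, List


async def fetch_in_batches(
    codes: List[str],
    fetch: Callable[[str], Awaitable[list]],
    batch_size: int = 5,
    pause_seconds: float = 1.0,
) -> Dict[str, int]:
    """Fetch each code, pause after every batch, and count successes and failures."""
    success = error = rows = 0
    for start in range(0, len(codes), batch_size):
        for code in codes[start:start + batch_size]:
            try:
                data = await fetch(code)
            except Exception:
                # A failing stock must not abort the rest of the run
                error += 1
                continue
            if data:
                success += 1
                rows += len(data)
            else:
                error += 1
        # Pause between batches to avoid overloading the data source
        await asyncio.sleep(pause_seconds)
    return {"success_count": success, "error_count": error, "kline_data_count": rows}


async def fake_fetch(code: str) -> list:
    """Stand-in data source: one code raises, one returns nothing."""
    if code == "000002":
        raise RuntimeError("simulated network error")
    if code == "000003":
        return []
    return [{"code": code, "close": 1.0}, {"code": code, "close": 2.0}]


summary = asyncio.run(
    fetch_in_batches(["000001", "000002", "000003", "600000"], fake_fetch, batch_size=2, pause_seconds=0)
)
```

# Injecting the fetch coroutine keeps the counting logic testable without network
# access or a database.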
async def main():
    """
    Main entry point.
    """
    result = await update_kline_data_fast()
    if result["success"]:
        logger.info("K-line data update succeeded")
        print(f"Update result: {result}")
        if result["kline_data_count"] > 0:
            print("✓ K-line data updated; data has been saved to the database")
        else:
            print("⚠ K-line data update finished, but no data was fetched")
    else:
        logger.error("K-line data update failed")
        print(f"Update failed: {result.get('error')}")
    return result


if __name__ == "__main__":
    # Run the update
    result = asyncio.run(main())
    # Report the final outcome through the exit code
    if result.get("success", False):
        print("\nFast update finished!")
        sys.exit(0)
    else:
        print("\nFast update failed!")
        sys.exit(1)

157
verify_kline_data.py Normal file

@ -0,0 +1,157 @@
"""
Verify the results of the K-line data update.
Checks whether the K-line data was saved to the database.
"""
import sys
import os
import logging
from sqlalchemy import func
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.storage.database import db_manager
from src.storage.stock_repository import StockRepository
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def verify_kline_data():
    """
    Verify the results of the K-line data update.
    """
    try:
        logger.info("Starting verification of the K-line data update...")
        # Create the repository
        repository = StockRepository(db_manager.get_session())
        logger.info("Repository created")
        # Check the stock basic info table
        stocks = repository.get_stock_basic_info()
        logger.info(f"Stock basic info table: {len(stocks)} records")
        if stocks:
            # Show the first 5 stocks
            logger.info("First 5 stocks:")
            for i, stock in enumerate(stocks[:5]):
                logger.info(f"  {i+1}. {stock.code} - {stock.name}")
        # Check the K-line table
        try:
            kline_count = repository.session.query(repository.DailyKline).count()
            logger.info(f"Daily K-line table: {kline_count} records")
            if kline_count > 0:
                # Show the 5 most recent K-line records
                latest_kline = repository.session.query(repository.DailyKline).order_by(
                    repository.DailyKline.trade_date.desc()
                ).limit(5).all()
                logger.info("5 most recent K-line records:")
                for i, kline in enumerate(latest_kline):
                    logger.info(f"  {i+1}. {kline.stock_code} - {kline.trade_date} - close: {kline.close_price}")
                # Count K-line records per stock code
                kline_by_stock = repository.session.query(
                    repository.DailyKline.stock_code,
                    func.count(repository.DailyKline.id).label('count')
                ).group_by(repository.DailyKline.stock_code).all()
                logger.info("K-line record counts per stock:")
                # Unpack each row instead of reading stat.count, which can resolve to the
                # tuple method count() rather than the labelled column
                for stock_code, record_count in kline_by_stock[:10]:  # show the first 10 stocks
                    logger.info(f"  {stock_code}: {record_count} records")
                if len(kline_by_stock) > 10:
                    logger.info(f"  ... and {len(kline_by_stock) - 10} more stocks")
        except Exception as e:
            logger.error(f"Failed to query K-line data: {str(e)}")
            kline_count = 0
        # Check the financial report table
        try:
            financial_count = repository.session.query(repository.FinancialReport).count()
            logger.info(f"Financial report table: {financial_count} records")
        except Exception as e:
            logger.error(f"Failed to query financial report data: {str(e)}")
            financial_count = 0
        # Check the data source table
        try:
            datasource_count = repository.session.query(repository.DataSource).count()
            logger.info(f"Data source table: {datasource_count} records")
        except Exception as e:
            logger.error(f"Failed to query the data source table: {str(e)}")
            datasource_count = 0
        # Check the system log table
        try:
            log_count = repository.session.query(repository.SystemLog).count()
            logger.info(f"System log table: {log_count} records")
        except Exception as e:
            logger.error(f"Failed to query the system log table: {str(e)}")
            log_count = 0
        # Summarize the results
        logger.info("Data verification summary:")
        logger.info(f"  Stock basic info: {len(stocks)} records")
        logger.info(f"  Daily K-line data: {kline_count} records")
        logger.info(f"  Financial reports: {financial_count} records")
        logger.info(f"  Data sources: {datasource_count} records")
        logger.info(f"  System logs: {log_count} records")
        return {
            "success": True,
            "stock_count": len(stocks),
            "kline_count": kline_count,
            "financial_count": financial_count,
            "datasource_count": datasource_count,
            "log_count": log_count
        }
    except Exception as e:
        logger.error(f"K-line data verification raised an exception: {str(e)}")
        return {"success": False, "error": str(e)}
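# The verification step boils down to two queries: the most recent K-line rows and a
# record count per stock code. The block below is a minimal sketch of both in plain
# SQL, using the standard-library sqlite3 module and an in-memory database. The table
# name `daily_kline` and the inserted rows are assumptions made up for illustration;
# only the column names stock_code, trade_date and close_price follow the attributes
# used by the script.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE daily_kline ("
    "id INTEGER PRIMARY KEY, stock_code TEXT, trade_date TEXT, close_price REAL)"
)
# Made-up sample rows, not real market data
conn.executemany(
    "INSERT INTO daily_kline (stock_code, trade_date, close_price) VALUES (?, ?, ?)",
    [
        ("000001", "2024-01-02", 1.0),
        ("000001", "2024-01-03", 2.0),
        ("600000", "2024-01-02", 3.0),
    ],
)
# Record count per stock code
per_stock = conn.execute(
    "SELECT stock_code, COUNT(id) FROM daily_kline GROUP BY stock_code ORDER BY stock_code"
).fetchall()
# Most recent row by trade date
latest = conn.execute(
    "SELECT stock_code, trade_date FROM daily_kline ORDER BY trade_date DESC, stock_code LIMIT 1"
).fetchone()
conn.close()
```

# ISO-formatted date strings sort chronologically, which is why ordering by
# trade_date as text works in this sketch.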
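def main():
    """
    Main entry point.
    """
    result = verify_kline_data()
    if result["success"]:
        logger.info("Data verification succeeded!")
        print(f"Verification result: {result}")
        if result["kline_count"] > 0:
            print("✓ K-line data has been saved to the database")
            print(f"✓ {result['kline_count']} K-line records saved in total")
        else:
            print("⚠ K-line data has not been saved to the database")
        if result["stock_count"] > 0:
            print(f"✓ Stock basic info: {result['stock_count']} records")
        else:
            print("⚠ No stock basic info found")
    else:
        logger.error("Data verification failed!")
        print(f"Verification failed: {result.get('error')}")
    return result


if __name__ == "__main__":
    # Run the verification
    result = main()
    # Report the final outcome through the exit code
    if result.get("success", False):
        print("\nData verification finished!")
        sys.exit(0)
    else:
        print("\nData verification failed!")
        sys.exit(1)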