325 lines
8.9 KiB
Python
325 lines
8.9 KiB
Python
"""
|
||
完整数据更新脚本
|
||
同时更新K线数据和财务数据,支持分批处理和进度显示
|
||
"""
|
||
|
||
import sys
|
||
import os
|
||
import asyncio
|
||
import logging
|
||
from datetime import date, datetime
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from src.storage.database import db_manager
|
||
from src.storage.stock_repository import StockRepository
|
||
from src.data.data_manager import DataManager
|
||
from src.config.settings import Settings
|
||
|
||
# Configure root logging for the script: INFO level, timestamped lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def convert_to_baostock_format(stock_code: str) -> str:
    """Prefix a bare 6-character stock code with its Baostock market tag.

    Codes starting with ``6`` or ``9`` become ``sh.<code>``; codes starting
    with ``0`` or ``3`` become ``sz.<code>``. Any other input (a different
    length, an already prefixed code, another leading digit) is returned
    unchanged.

    Args:
        stock_code: Stock code, normally 6 digits such as ``"600000"``.

    Returns:
        The 9-character Baostock code, or ``stock_code`` itself when no
        prefix rule applies.
    """
    # Only bare 6-character codes are candidates for a market prefix.
    if len(stock_code) != 6:
        return stock_code

    if stock_code.startswith(('6', '9')):
        return f"sh.{stock_code}"
    if stock_code.startswith(('0', '3')):
        return f"sz.{stock_code}"

    # Unrecognised leading digit: leave the code untouched.
    return stock_code
|
||
|
||
|
||
def _kline_start_date(today: date, months: int) -> date:
    """Return the first day of the month that lies ``months`` months before ``today``.

    The arithmetic uses an absolute month index so it rolls over year
    boundaries (e.g. 3 months before 2024-02-15 is 2023-11-01). This is
    unlike ``date(year, month - months, 1)``, which raises ValueError
    whenever ``month - months < 1``.
    """
    month_index = today.year * 12 + (today.month - 1) - months
    return date(month_index // 12, month_index % 12 + 1, 1)


async def update_kline_data_batch(stocks: list, data_manager: DataManager, repository: StockRepository, batch_size: int = 10):
    """Fetch daily K-line data for ``stocks`` in batches and save each batch.

    The requested window runs from the first day of the month three months
    back up to today.

    Args:
        stocks: Stock objects; each must expose a 6-digit ``code`` attribute.
        data_manager: Source queried with ``await get_daily_kline_data(code, start, end)``.
        repository: Sink; each non-empty batch is written with ``save_daily_kline_data``.
        batch_size: Number of stocks fetched before each save. Must be >= 1.

    Returns:
        dict with ``success``, ``total_stocks``, ``success_count``,
        ``error_count`` and ``kline_data_count`` (rows in successfully
        saved batches). ``success_count + error_count == len(stocks)``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # The date window is identical for every stock, so compute it once.
    end_date = date.today()
    start_date = _kline_start_date(end_date, 3)

    saved_row_count = 0
    success_count = 0
    error_count = 0
    total_batches = -(-len(stocks) // batch_size)  # ceiling division

    for i in range(0, len(stocks), batch_size):
        batch = stocks[i:i + batch_size]
        logger.info(f"处理K线数据批次 {i//batch_size + 1}/{total_batches}: {len(batch)}只股票")

        batch_kline_data = []
        batch_success = 0
        batch_error = 0

        for stock in batch:
            try:
                baostock_code = convert_to_baostock_format(stock.code)
                kline_data = await data_manager.get_daily_kline_data(
                    baostock_code, start_date, end_date
                )

                if kline_data:
                    # Store rows under the plain 6-digit code, not the Baostock one.
                    for row in kline_data:
                        row["code"] = stock.code
                    batch_kline_data.extend(kline_data)
                    batch_success += 1
                    logger.info(f"股票{stock.code}获取到{len(kline_data)}条K线数据")
                else:
                    logger.warning(f"股票{stock.code}未获取到K线数据")
                    batch_error += 1
            except Exception as e:
                logger.error(f"获取股票{stock.code}K线数据失败: {str(e)}")
                batch_error += 1

            # Throttle after every request, including failed ones, so errors
            # do not turn into a burst of back-to-back calls.
            await asyncio.sleep(0.2)

        if batch_kline_data:
            try:
                save_result = repository.save_daily_kline_data(batch_kline_data)
                logger.info(f"批次K线数据保存结果: {save_result}")
                saved_row_count += len(batch_kline_data)
            except Exception as e:
                logger.error(f"保存批次K线数据失败: {str(e)}")
                # Treat the whole batch as not persisted: move the fetched
                # stocks from success to error, so success + error stays
                # equal to the batch size (no double counting).
                batch_error += batch_success
                batch_success = 0

        success_count += batch_success
        error_count += batch_error

        logger.info(f"批次完成: 成功{batch_success}只, 失败{batch_error}只")

    return {
        "success": True,
        "total_stocks": len(stocks),
        "success_count": success_count,
        "error_count": error_count,
        "kline_data_count": saved_row_count,
    }
|
||
|
||
|
||
async def update_financial_data_batch(stocks: list, data_manager: DataManager, repository: StockRepository, batch_size: int = 10, year: int = 2023, quarter: int = 4):
    """Fetch one quarterly financial report per stock in batches and save each batch.

    Args:
        stocks: Stock objects; each must expose a 6-digit ``code`` attribute.
        data_manager: Source queried with ``await get_financial_report(code, year, quarter)``.
        repository: Sink; each non-empty batch is written with ``save_financial_report_data``.
        batch_size: Number of stocks fetched before each save. Must be >= 1.
        year: Report year to request. Defaults to 2023, the value that used
            to be hard-coded here.
        quarter: Report quarter, 1-4. Defaults to 4, the value that used to
            be hard-coded here.

    Returns:
        dict with ``success``, ``total_stocks``, ``success_count``,
        ``error_count`` and ``financial_data_count`` (rows in successfully
        saved batches). ``success_count + error_count == len(stocks)``.

    Raises:
        ValueError: if ``batch_size`` < 1 or ``quarter`` is outside 1-4.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not 1 <= quarter <= 4:
        raise ValueError(f"quarter must be between 1 and 4, got {quarter}")

    saved_row_count = 0
    success_count = 0
    error_count = 0
    total_batches = -(-len(stocks) // batch_size)  # ceiling division

    for i in range(0, len(stocks), batch_size):
        batch = stocks[i:i + batch_size]
        logger.info(f"处理财务数据批次 {i//batch_size + 1}/{total_batches}: {len(batch)}只股票")

        batch_financial_data = []
        batch_success = 0
        batch_error = 0

        for stock in batch:
            try:
                baostock_code = convert_to_baostock_format(stock.code)
                financial_data = await data_manager.get_financial_report(
                    baostock_code, year, quarter
                )

                if financial_data:
                    # Store rows under the plain 6-digit code, not the Baostock one.
                    for row in financial_data:
                        row["code"] = stock.code
                    batch_financial_data.extend(financial_data)
                    batch_success += 1
                    logger.info(f"股票{stock.code}获取到{len(financial_data)}条财务数据")
                else:
                    logger.warning(f"股票{stock.code}未获取到财务数据")
                    batch_error += 1
            except Exception as e:
                logger.error(f"获取股票{stock.code}财务数据失败: {str(e)}")
                batch_error += 1

            # Throttle after every request, including failed ones, so errors
            # do not turn into a burst of back-to-back calls.
            await asyncio.sleep(0.3)

        if batch_financial_data:
            try:
                save_result = repository.save_financial_report_data(batch_financial_data)
                logger.info(f"批次财务数据保存结果: {save_result}")
                saved_row_count += len(batch_financial_data)
            except Exception as e:
                logger.error(f"保存批次财务数据失败: {str(e)}")
                # Treat the whole batch as not persisted: move the fetched
                # stocks from success to error, so success + error stays
                # equal to the batch size (no double counting).
                batch_error += batch_success
                batch_success = 0

        success_count += batch_success
        error_count += batch_error

        logger.info(f"批次完成: 成功{batch_success}只, 失败{batch_error}只")

    return {
        "success": True,
        "total_stocks": len(stocks),
        "success_count": success_count,
        "error_count": error_count,
        "financial_data_count": saved_row_count,
    }
|
||
|
||
|
||
async def update_all_data(limit: int = 50, batch_size: int = 5):
    """Update K-line and financial data for the first ``limit`` stocks in the database.

    Args:
        limit: How many stocks, taken from the start of the basic-info list,
            to process. Defaults to 50 to keep a run short.
        batch_size: Stocks per batch, passed to both batch updaters.

    Returns:
        On success: ``{"success": True, "kline_data": ..., "financial_data": ...,
        "total_stocks": n}``, where the nested dicts are the batch updaters'
        results. On failure: ``{"success": False, "error": message}``.
        Exceptions are logged with traceback and converted to the failure dict.
    """
    try:
        logger.info("开始更新所有股票数据...")

        # The Settings instance is not used afterwards. It is still
        # constructed so a configuration problem aborts the run here
        # (assumed to load config on init — TODO confirm).
        Settings()
        logger.info("配置加载成功")

        data_manager = DataManager()
        logger.info("数据管理器创建成功")

        # NOTE(review): this session is never closed. Closing it is left out
        # because get_session()'s type and commit semantics are not visible
        # here — confirm and release it explicitly if appropriate.
        repository = StockRepository(db_manager.get_session())
        logger.info("存储库创建成功")

        # Normalise a None result to [] so len() below cannot raise TypeError.
        stocks = repository.get_stock_basic_info() or []
        logger.info(f"获取到{len(stocks)}只股票基础信息")

        if not stocks:
            logger.error("没有股票基础信息,无法更新数据")
            return {"success": False, "error": "没有股票基础信息"}

        # Restrict to a prefix of the list to bound run time.
        test_stocks = stocks[:limit]
        test_codes = [stock.code for stock in test_stocks]
        logger.info(f"测试股票代码: {test_codes}")

        logger.info("=== 开始更新K线数据 ===")
        kline_result = await update_kline_data_batch(test_stocks, data_manager, repository, batch_size=batch_size)

        logger.info("=== 开始更新财务数据 ===")
        financial_result = await update_financial_data_batch(test_stocks, data_manager, repository, batch_size=batch_size)

        result = {
            "success": True,
            "kline_data": kline_result,
            "financial_data": financial_result,
            "total_stocks": len(test_stocks),
        }

        logger.info(f"所有数据更新完成: {result}")
        return result

    except Exception as e:
        # Top-level boundary of the async pipeline: keep the traceback in the log.
        logger.exception(f"数据更新异常: {str(e)}")
        return {"success": False, "error": str(e)}
|
||
|
||
|
||
def _report_database_counts():
    """Print the row counts of the K-line, financial-report and basic-info tables.

    Any error while opening the repository or running a query is caught and
    printed, so verification never aborts the script.
    """
    try:
        repo = StockRepository(db_manager.get_session())

        # The model classes are read as attributes of the repository instance.
        # NOTE(review): confirm StockRepository exposes DailyKline /
        # FinancialReport / StockBasicInfo; otherwise this always hits except.
        kline_count = repo.session.query(repo.DailyKline).count()
        print(f"✓ 日K线数据表: {kline_count}条记录")

        financial_count = repo.session.query(repo.FinancialReport).count()
        print(f"✓ 财务报告表: {financial_count}条记录")

        stock_count = repo.session.query(repo.StockBasicInfo).count()
        print(f"✓ 股票基础信息: {stock_count}条记录")
    except Exception as e:
        print(f"⚠ 数据库验证失败: {str(e)}")


def main():
    """Run the full update, print a summary, and return the update result.

    Returns:
        The dict produced by ``update_all_data``. It always carries a
        ``success`` key.
    """
    logger.info("开始完整数据更新流程...")

    # Drive the async pipeline to completion on a fresh event loop.
    outcome = asyncio.run(update_all_data())

    # Failure path first: report the error and stop.
    if not outcome["success"]:
        logger.error("数据更新失败!")
        print(f"更新失败: {outcome.get('error')}")
        return outcome

    logger.info("数据更新成功!")

    kline_stats = outcome["kline_data"]
    financial_stats = outcome["financial_data"]

    print("=== 数据更新结果汇总 ===")
    print(f"处理股票总数: {outcome['total_stocks']}")

    print("\n=== K线数据更新结果 ===")
    print(f"✓ 成功股票数: {kline_stats['success_count']}")
    print(f"✓ 失败股票数: {kline_stats['error_count']}")
    print(f"✓ 获取K线数据: {kline_stats['kline_data_count']}条")

    print("\n=== 财务数据更新结果 ===")
    print(f"✓ 成功股票数: {financial_stats['success_count']}")
    print(f"✓ 失败股票数: {financial_stats['error_count']}")
    print(f"✓ 获取财务数据: {financial_stats['financial_data_count']}条")

    print("\n=== 数据库验证 ===")
    # Cross-check what actually landed in the database.
    _report_database_counts()

    print("\n数据更新流程完成!")
    return outcome
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the update only when executed directly, never on import.
    main()