339 lines
8.9 KiB
Python
339 lines
8.9 KiB
Python
"""
|
||
修复股票代码格式脚本
|
||
将6位股票代码转换为Baostock需要的9位格式
|
||
"""
|
||
|
||
import sys
|
||
import os
|
||
import logging
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from src.storage.database import db_manager
|
||
from src.storage.stock_repository import StockRepository
|
||
|
||
# 配置日志
|
||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def convert_to_baostock_format(stock_code: str, market: str) -> str:
|
||
"""
|
||
将6位股票代码转换为Baostock格式
|
||
|
||
Args:
|
||
stock_code: 6位股票代码
|
||
market: 市场类型(sh/sz)
|
||
|
||
Returns:
|
||
9位Baostock格式股票代码
|
||
"""
|
||
if len(stock_code) == 6 and stock_code.isdigit():
|
||
if market == "sh":
|
||
return f"sh.{stock_code}"
|
||
elif market == "sz":
|
||
return f"sz.{stock_code}"
|
||
else:
|
||
return stock_code
|
||
else:
|
||
# 如果已经是9位格式,直接返回
|
||
return stock_code
|
||
|
||
|
||
def get_baostock_format_code(stock_code: str) -> str:
|
||
"""
|
||
根据股票代码判断市场类型并转换为Baostock格式
|
||
|
||
Args:
|
||
stock_code: 股票代码
|
||
|
||
Returns:
|
||
Baostock格式股票代码
|
||
"""
|
||
# 如果已经是Baostock格式,直接返回
|
||
if stock_code.startswith("sh.") or stock_code.startswith("sz."):
|
||
return stock_code
|
||
|
||
# 根据股票代码前缀判断市场类型
|
||
if stock_code.startswith("6"):
|
||
return f"sh.{stock_code}"
|
||
elif stock_code.startswith("0") or stock_code.startswith("3"):
|
||
return f"sz.{stock_code}"
|
||
else:
|
||
return stock_code
|
||
|
||
|
||
def fix_stock_code_format():
|
||
"""
|
||
修复股票代码格式
|
||
"""
|
||
try:
|
||
logger.info("开始修复股票代码格式...")
|
||
|
||
# 创建存储库
|
||
repository = StockRepository(db_manager.get_session())
|
||
logger.info("存储库创建成功")
|
||
|
||
# 获取所有股票基础信息
|
||
stocks = repository.get_stock_basic_info()
|
||
logger.info(f"找到{len(stocks)}只股票")
|
||
|
||
if not stocks:
|
||
logger.error("没有股票基础信息,无法修复")
|
||
return {"success": False, "error": "没有股票基础信息"}
|
||
|
||
# 统计修复情况
|
||
fixed_count = 0
|
||
error_count = 0
|
||
|
||
for stock in stocks:
|
||
try:
|
||
# 获取原始股票代码
|
||
original_code = stock.code
|
||
|
||
# 转换为Baostock格式
|
||
baostock_code = get_baostock_format_code(original_code)
|
||
|
||
if original_code != baostock_code:
|
||
logger.info(f"修复股票代码: {original_code} -> {baostock_code}")
|
||
|
||
# 更新股票代码(这里只是演示,实际需要修改数据库结构)
|
||
# 由于数据库结构限制,我们无法直接修改股票代码
|
||
# 需要创建一个映射表或修改数据采集器的处理逻辑
|
||
fixed_count += 1
|
||
|
||
except Exception as e:
|
||
logger.error(f"修复股票{stock.code}代码格式失败: {str(e)}")
|
||
error_count += 1
|
||
|
||
logger.info(f"股票代码格式修复完成: 修复{fixed_count}只, 错误{error_count}只")
|
||
|
||
# 创建测试数据
|
||
test_codes = ["000001", "600000", "300001", "000007"]
|
||
logger.info("测试股票代码格式转换:")
|
||
for code in test_codes:
|
||
baostock_code = get_baostock_format_code(code)
|
||
logger.info(f" {code} -> {baostock_code}")
|
||
|
||
return {
|
||
"success": True,
|
||
"total_stocks": len(stocks),
|
||
"fixed_count": fixed_count,
|
||
"error_count": error_count
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f"修复股票代码格式异常: {str(e)}")
|
||
return {"success": False, "error": str(e)}
|
||
|
||
|
||
def create_baostock_compatible_update_script():
|
||
"""
|
||
创建兼容Baostock格式的数据更新脚本
|
||
"""
|
||
try:
|
||
logger.info("创建兼容Baostock格式的数据更新脚本...")
|
||
|
||
script_content = '''"""
|
||
兼容Baostock格式的数据更新脚本
|
||
自动处理股票代码格式转换
|
||
"""
|
||
|
||
import sys
|
||
import os
|
||
import asyncio
|
||
import logging
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from src.data.data_initializer import DataInitializer
|
||
from src.config.settings import Settings
|
||
from src.storage.database import db_manager
|
||
from src.storage.stock_repository import StockRepository
|
||
|
||
# 配置日志
|
||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def get_baostock_format_code(stock_code: str) -> str:
|
||
"""
|
||
将股票代码转换为Baostock格式
|
||
"""
|
||
if stock_code.startswith("sh.") or stock_code.startswith("sz."):
|
||
return stock_code
|
||
|
||
if stock_code.startswith("6"):
|
||
return f"sh.{stock_code}"
|
||
elif stock_code.startswith("0") or stock_code.startswith("3"):
|
||
return f"sz.{stock_code}"
|
||
else:
|
||
return stock_code
|
||
|
||
|
||
async def update_kline_data_with_baostock_format():
|
||
"""
|
||
使用Baostock格式更新K线数据
|
||
"""
|
||
try:
|
||
logger.info("开始使用Baostock格式更新K线数据...")
|
||
|
||
# 加载配置
|
||
settings = Settings()
|
||
logger.info("配置加载成功")
|
||
|
||
# 创建数据初始化器
|
||
initializer = DataInitializer(settings)
|
||
logger.info("数据初始化器创建成功")
|
||
|
||
# 创建存储库
|
||
repository = StockRepository(db_manager.get_session())
|
||
logger.info("存储库创建成功")
|
||
|
||
# 获取所有股票基础信息
|
||
stocks = repository.get_stock_basic_info()
|
||
logger.info(f"找到{len(stocks)}只股票")
|
||
|
||
if not stocks:
|
||
logger.error("没有股票基础信息,无法更新")
|
||
return {"success": False, "error": "没有股票基础信息"}
|
||
|
||
# 分批处理,每次处理10只股票
|
||
batch_size = 10
|
||
total_batches = (len(stocks) + batch_size - 1) // batch_size
|
||
|
||
total_kline_data = []
|
||
success_count = 0
|
||
error_count = 0
|
||
|
||
for batch_num in range(total_batches):
|
||
start_idx = batch_num * batch_size
|
||
end_idx = min(start_idx + batch_size, len(stocks))
|
||
batch_stocks = stocks[start_idx:end_idx]
|
||
|
||
logger.info(f"处理第{batch_num + 1}批股票,共{len(batch_stocks)}只")
|
||
|
||
for stock in batch_stocks:
|
||
try:
|
||
# 转换为Baostock格式
|
||
baostock_code = get_baostock_format_code(stock.code)
|
||
logger.info(f"获取股票{stock.code}({baostock_code})的K线数据...")
|
||
|
||
# 使用数据管理器获取K线数据
|
||
kline_data = await initializer.data_manager.get_daily_kline_data(
|
||
baostock_code,
|
||
"2024-01-01",
|
||
"2024-01-10"
|
||
)
|
||
|
||
if kline_data:
|
||
total_kline_data.extend(kline_data)
|
||
success_count += 1
|
||
logger.info(f"股票{stock.code}获取到{len(kline_data)}条K线数据")
|
||
else:
|
||
logger.warning(f"股票{stock.code}未获取到K线数据")
|
||
error_count += 1
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取股票{stock.code}K线数据失败: {str(e)}")
|
||
error_count += 1
|
||
continue
|
||
|
||
logger.info(f"K线数据更新完成: 成功{success_count}只, 失败{error_count}只, 共获取{len(total_kline_data)}条数据")
|
||
|
||
return {
|
||
"success": True,
|
||
"total_stocks": len(stocks),
|
||
"success_count": success_count,
|
||
"error_count": error_count,
|
||
"kline_data_count": len(total_kline_data)
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f"更新K线数据异常: {str(e)}")
|
||
return {"success": False, "error": str(e)}
|
||
|
||
|
||
async def main():
|
||
"""
|
||
主函数
|
||
"""
|
||
result = await update_kline_data_with_baostock_format()
|
||
|
||
if result["success"]:
|
||
logger.info("K线数据更新成功!")
|
||
print(f"更新结果: {result}")
|
||
else:
|
||
logger.error("K线数据更新失败!")
|
||
print(f"更新失败: {result.get('error')}")
|
||
|
||
return result
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 运行更新
|
||
result = asyncio.run(main())
|
||
|
||
# 输出最终结果
|
||
if result.get("success", False):
|
||
print("更新成功!")
|
||
sys.exit(0)
|
||
else:
|
||
print("更新失败!")
|
||
sys.exit(1)
|
||
'''
|
||
|
||
# 写入脚本文件
|
||
script_path = os.path.join(os.path.dirname(__file__), "update_kline_baostock.py")
|
||
with open(script_path, "w", encoding="utf-8") as f:
|
||
f.write(script_content)
|
||
|
||
logger.info(f"兼容Baostock格式的数据更新脚本已创建: {script_path}")
|
||
|
||
return {"success": True, "script_path": script_path}
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建兼容脚本失败: {str(e)}")
|
||
return {"success": False, "error": str(e)}
|
||
|
||
|
||
def main():
|
||
"""
|
||
主函数
|
||
"""
|
||
# 修复股票代码格式
|
||
fix_result = fix_stock_code_format()
|
||
|
||
# 创建兼容脚本
|
||
script_result = create_baostock_compatible_update_script()
|
||
|
||
if fix_result["success"] and script_result["success"]:
|
||
logger.info("股票代码格式修复完成!")
|
||
print(f"修复结果: {fix_result}")
|
||
print(f"脚本创建结果: {script_result}")
|
||
|
||
# 测试转换函数
|
||
test_codes = ["000001", "600000", "300001", "000007"]
|
||
print("\n测试股票代码格式转换:")
|
||
for code in test_codes:
|
||
baostock_code = get_baostock_format_code(code)
|
||
print(f" {code} -> {baostock_code}")
|
||
|
||
print("\n请运行以下命令测试新的更新脚本:")
|
||
print("python update_kline_baostock.py")
|
||
|
||
return True
|
||
else:
|
||
logger.error("股票代码格式修复失败!")
|
||
print(f"修复失败: {fix_result.get('error')}")
|
||
print(f"脚本创建失败: {script_result.get('error')}")
|
||
return False
|
||
|
||
|
||
if __name__ == "__main__":
|
||
success = main()
|
||
|
||
if success:
|
||
print("\n股票代码格式修复完成!")
|
||
sys.exit(0)
|
||
else:
|
||
print("\n股票代码格式修复失败!")
|
||
sys.exit(1) |