stock-monitor/docs/database/test_database.py
ycg 569c1c8813 重构股票监控系统:数据库架构升级与功能完善
- 重构数据访问层:引入DAO模式,支持MySQL/SQLite双数据库
- 新增数据库架构:完整的股票数据、AI分析、自选股管理表结构
- 升级AI分析服务:集成豆包大模型,支持多维度分析
- 优化API路由:分离市场数据API,提供更清晰的接口设计
- 完善项目文档:添加数据库迁移指南、新功能指南等
- 清理冗余文件:删除旧的缓存文件和无用配置
- 新增调度器:支持定时任务和数据自动更新
- 改进前端模板:简化的股票展示页面

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-01 15:44:25 +08:00

239 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
数据库功能测试脚本
"""
import sys
import os
from pathlib import Path
# 添加项目根目录到Python路径
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from app.dao import StockDAO, WatchlistDAO, AIAnalysisDAO, ConfigDAO
from app.services.stock_service_db import StockServiceDB
from app.services.ai_analysis_service_db import AIAnalysisServiceDB
def test_database_connection():
"""测试数据库连接"""
print("1. 测试数据库连接...")
try:
from app.database import DatabaseManager
db_manager = DatabaseManager()
with db_manager.get_cursor() as cursor:
cursor.execute("SELECT 1 as test")
result = cursor.fetchone()
if result and result['test'] == 1:
print(" ✓ 数据库连接正常")
return True
else:
print(" ✗ 数据库连接异常")
return False
except Exception as e:
print(f" ✗ 数据库连接失败: {e}")
return False
def test_dao_functions():
"""测试DAO层功能"""
print("\n2. 测试DAO层功能...")
try:
# 测试各个DAO
stock_dao = StockDAO()
watchlist_dao = WatchlistDAO()
ai_dao = AIAnalysisDAO()
config_dao = ConfigDAO()
# 测试基础查询
stock_count = stock_dao.get_stock_count()
watchlist_count = watchlist_dao.get_watchlist_count()
ai_count = ai_dao.get_analysis_count()
print(f" ✓ 股票数量: {stock_count}")
print(f" ✓ 监控列表: {watchlist_count}")
print(f" ✓ AI分析: {ai_count}")
# 测试配置读写
config_dao.set_config('test_key', 'test_value', 'string')
test_value = config_dao.get_config('test_key')
if test_value == 'test_value':
print(" ✓ 配置读写正常")
else:
print(" ✗ 配置读写异常")
return False
# 清理测试数据
config_dao.delete_config('test_key')
return True
except Exception as e:
print(f" ✗ DAO层测试失败: {e}")
return False
def test_stock_service():
"""测试股票服务"""
print("\n3. 测试股票服务...")
try:
stock_service = StockServiceDB()
# 测试监控列表功能
watchlist = stock_service.get_watchlist()
print(f" ✓ 获取监控列表: {len(watchlist)}")
if watchlist:
# 测试获取股票信息(使用第一只股票)
stock_code = watchlist[0].get('stock_code') or watchlist[0].get('code')
if stock_code:
print(f" ✓ 测试股票: {stock_code}")
# 测试获取股票信息
stock_info = stock_service.get_stock_info(stock_code)
if 'error' not in stock_info:
print(" ✓ 股票信息获取正常")
else:
print(f" ✗ 股票信息获取失败: {stock_info.get('error')}")
return False
# 测试指数信息
index_info = stock_service.get_index_info()
if index_info:
print(f" ✓ 指数信息获取正常: {len(index_info)} 个指数")
else:
print(" ✗ 指数信息获取失败")
return False
return True
except Exception as e:
print(f" ✗ 股票服务测试失败: {e}")
return False
def test_ai_service():
"""测试AI分析服务"""
print("\n4. 测试AI分析服务...")
try:
ai_service = AIAnalysisServiceDB()
stock_service = StockServiceDB()
# 获取一只测试股票
watchlist = stock_service.get_watchlist()
if not watchlist:
print(" ⚠️ 监控列表为空跳过AI服务测试")
return True
stock_code = watchlist[0].get('stock_code') or watchlist[0].get('code')
# 测试价值分析数据获取
value_data = stock_service.get_value_analysis_data(stock_code)
if 'error' not in value_data:
print(" ✓ 价值分析数据获取正常")
else:
print(f" ✗ 价值分析数据获取失败: {value_data.get('error')}")
return False
# 测试AI分析历史记录
history = ai_service.get_analysis_history(stock_code, 'stock', 7)
print(f" ✓ AI分析历史记录: {len(history)}")
return True
except Exception as e:
print(f" ✗ AI服务测试失败: {e}")
return False
def test_api_compatibility():
"""测试API兼容性"""
print("\n5. 测试API兼容性...")
try:
from app.services.stock_service_db import StockServiceDB
from app.services.ai_analysis_service_db import AIAnalysisServiceDB
# 测试服务实例化
stock_service = StockServiceDB()
ai_service = AIAnalysisServiceDB()
print(" ✓ 数据库服务实例化正常")
# 测试方法是否存在
required_methods = [
'get_stock_info', 'get_watchlist', 'add_watch', 'remove_watch',
'update_target', 'get_index_info'
]
for method in required_methods:
if hasattr(stock_service, method):
print(f" ✓ 方法存在: {method}")
else:
print(f" ✗ 方法缺失: {method}")
return False
return True
except Exception as e:
print(f" ✗ API兼容性测试失败: {e}")
return False
def main():
"""主测试函数"""
print("=" * 60)
print("股票监控系统数据库功能测试")
print("=" * 60)
tests = [
test_database_connection,
test_dao_functions,
test_stock_service,
test_ai_service,
test_api_compatibility
]
passed = 0
total = len(tests)
for test in tests:
try:
if test():
passed += 1
except Exception as e:
print(f" 测试异常: {e}")
print("\n" + "=" * 60)
print(f"测试完成!")
print(f"通过: {passed}/{total}")
print("=" * 60)
if passed == total:
print("🎉 所有测试通过!数据库迁移成功!")
print("\n系统现在可以正常使用数据库存储。")
print("如需回滚到JSON文件存储请参考 DATABASE_MIGRATION_GUIDE.md")
else:
print("⚠️ 部分测试未通过,请检查配置和数据库连接。")
return passed == total
if __name__ == "__main__":
try:
success = main()
sys.exit(0 if success else 1)
except KeyboardInterrupt:
print("\n测试被用户中断")
sys.exit(1)
except Exception as e:
print(f"\n测试过程中发生错误: {e}")
sys.exit(1)