#!/usr/bin/env python3 """ 数据库功能测试脚本 """ import sys import os from pathlib import Path # 添加项目根目录到Python路径 project_root = Path(__file__).parent sys.path.insert(0, str(project_root)) from app.dao import StockDAO, WatchlistDAO, AIAnalysisDAO, ConfigDAO from app.services.stock_service_db import StockServiceDB from app.services.ai_analysis_service_db import AIAnalysisServiceDB def test_database_connection(): """测试数据库连接""" print("1. 测试数据库连接...") try: from app.database import DatabaseManager db_manager = DatabaseManager() with db_manager.get_cursor() as cursor: cursor.execute("SELECT 1 as test") result = cursor.fetchone() if result and result['test'] == 1: print(" ✓ 数据库连接正常") return True else: print(" ✗ 数据库连接异常") return False except Exception as e: print(f" ✗ 数据库连接失败: {e}") return False def test_dao_functions(): """测试DAO层功能""" print("\n2. 测试DAO层功能...") try: # 测试各个DAO stock_dao = StockDAO() watchlist_dao = WatchlistDAO() ai_dao = AIAnalysisDAO() config_dao = ConfigDAO() # 测试基础查询 stock_count = stock_dao.get_stock_count() watchlist_count = watchlist_dao.get_watchlist_count() ai_count = ai_dao.get_analysis_count() print(f" ✓ 股票数量: {stock_count}") print(f" ✓ 监控列表: {watchlist_count}") print(f" ✓ AI分析: {ai_count}") # 测试配置读写 config_dao.set_config('test_key', 'test_value', 'string') test_value = config_dao.get_config('test_key') if test_value == 'test_value': print(" ✓ 配置读写正常") else: print(" ✗ 配置读写异常") return False # 清理测试数据 config_dao.delete_config('test_key') return True except Exception as e: print(f" ✗ DAO层测试失败: {e}") return False def test_stock_service(): """测试股票服务""" print("\n3. 测试股票服务...") try: stock_service = StockServiceDB() # 测试监控列表功能 watchlist = stock_service.get_watchlist() print(f" ✓ 获取监控列表: {len(watchlist)} 项") if watchlist: # 测试获取股票信息(使用第一只股票) stock_code = watchlist[0].get('stock_code') or watchlist[0].get('code') if stock_code: print(f" ✓ 测试股票: {stock_code}") # 测试获取股票信息 stock_info = stock_service.get_stock_info(stock_code) if 'error' not in stock_info: print(" ✓ 股票信息获取正常") else: print(f" ✗ 股票信息获取失败: {stock_info.get('error')}") return False # 测试指数信息 index_info = stock_service.get_index_info() if index_info: print(f" ✓ 指数信息获取正常: {len(index_info)} 个指数") else: print(" ✗ 指数信息获取失败") return False return True except Exception as e: print(f" ✗ 股票服务测试失败: {e}") return False def test_ai_service(): """测试AI分析服务""" print("\n4. 测试AI分析服务...") try: ai_service = AIAnalysisServiceDB() stock_service = StockServiceDB() # 获取一只测试股票 watchlist = stock_service.get_watchlist() if not watchlist: print(" ⚠️ 监控列表为空,跳过AI服务测试") return True stock_code = watchlist[0].get('stock_code') or watchlist[0].get('code') # 测试价值分析数据获取 value_data = stock_service.get_value_analysis_data(stock_code) if 'error' not in value_data: print(" ✓ 价值分析数据获取正常") else: print(f" ✗ 价值分析数据获取失败: {value_data.get('error')}") return False # 测试AI分析历史记录 history = ai_service.get_analysis_history(stock_code, 'stock', 7) print(f" ✓ AI分析历史记录: {len(history)} 条") return True except Exception as e: print(f" ✗ AI服务测试失败: {e}") return False def test_api_compatibility(): """测试API兼容性""" print("\n5. 测试API兼容性...") try: from app.services.stock_service_db import StockServiceDB from app.services.ai_analysis_service_db import AIAnalysisServiceDB # 测试服务实例化 stock_service = StockServiceDB() ai_service = AIAnalysisServiceDB() print(" ✓ 数据库服务实例化正常") # 测试方法是否存在 required_methods = [ 'get_stock_info', 'get_watchlist', 'add_watch', 'remove_watch', 'update_target', 'get_index_info' ] for method in required_methods: if hasattr(stock_service, method): print(f" ✓ 方法存在: {method}") else: print(f" ✗ 方法缺失: {method}") return False return True except Exception as e: print(f" ✗ API兼容性测试失败: {e}") return False def main(): """主测试函数""" print("=" * 60) print("股票监控系统数据库功能测试") print("=" * 60) tests = [ test_database_connection, test_dao_functions, test_stock_service, test_ai_service, test_api_compatibility ] passed = 0 total = len(tests) for test in tests: try: if test(): passed += 1 except Exception as e: print(f" 测试异常: {e}") print("\n" + "=" * 60) print(f"测试完成!") print(f"通过: {passed}/{total}") print("=" * 60) if passed == total: print("🎉 所有测试通过!数据库迁移成功!") print("\n系统现在可以正常使用数据库存储。") print("如需回滚到JSON文件存储,请参考 DATABASE_MIGRATION_GUIDE.md") else: print("⚠️ 部分测试未通过,请检查配置和数据库连接。") return passed == total if __name__ == "__main__": try: success = main() sys.exit(0 if success else 1) except KeyboardInterrupt: print("\n测试被用户中断") sys.exit(1) except Exception as e: print(f"\n测试过程中发生错误: {e}") sys.exit(1)