stock-monitor/app/api/market_routes.py
ycg 569c1c8813 重构股票监控系统:数据库架构升级与功能完善
- 重构数据访问层:引入DAO模式,支持MySQL/SQLite双数据库
- 新增数据库架构:完整的股票数据、AI分析、自选股管理表结构
- 升级AI分析服务:集成豆包大模型,支持多维度分析
- 优化API路由:分离市场数据API,提供更清晰的接口设计
- 完善项目文档:添加数据库迁移指南、新功能指南等
- 清理冗余文件:删除旧的缓存文件和无用配置
- 新增调度器:支持定时任务和数据自动更新
- 改进前端模板:简化的股票展示页面

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-01 15:44:25 +08:00

355 lines
12 KiB
Python

"""
市场数据和股票浏览API路由
"""
from fastapi import APIRouter, Query
from typing import Optional, List
from app.services.market_data_service import MarketDataService
from app.services.kline_service import KlineService
from app.scheduler import run_manual_task, get_scheduler_status
from datetime import datetime
import logging
logger = logging.getLogger(name=__name__)

# Every route registered below is served under the /api/market prefix.
router = APIRouter(prefix="/api/market")

# Service objects are created once at import time and shared by all handlers.
market_service = MarketDataService()
kline_service = KlineService()
@router.get("/stocks")
async def get_all_stocks(
    page: int = Query(1, ge=1, description="页码"),
    size: int = Query(50, ge=1, description="每页数量"),
    industry: Optional[str] = Query(None, description="行业代码"),
    sector: Optional[str] = Query(None, description="概念板块代码"),
    search: Optional[str] = Query(None, description="搜索关键词")
):
    """List stocks with pagination and optional industry/sector/keyword filters.

    The full stock list is loaded from the database, every filter is applied
    in memory, and then the requested page is sliced out.

    Args:
        page: 1-based page number (validated ``>= 1``).
        size: Number of stocks per page (validated ``>= 1``).
        industry: Keep only stocks whose ``industry_code`` equals this value.
        sector: Keep only stocks linked to this ``sector_code`` in the
            ``stock_sector_relations`` table.
        search: Case-insensitive substring matched against both the stock
            name and the stock code.

    Returns:
        ``{"total", "page", "size", "pages", "data"}`` on success, or
        ``{"error": ...}`` if anything fails.
    """
    try:
        # Treat a missing result as "no stocks" rather than failing on len(None).
        stocks = market_service._get_stock_list_from_db() or []

        if industry:
            stocks = [s for s in stocks if s.get('industry_code') == industry]

        if sector:
            # Sector membership lives in a relation table, not on the stock row.
            from app.database import DatabaseManager
            db_manager = DatabaseManager()
            with db_manager.get_connection() as conn:
                cursor = conn.cursor()
                try:
                    # NOTE(review): `%s` is the MySQL driver paramstyle; a SQLite
                    # backend would need `?` -- confirm against DatabaseManager.
                    cursor.execute(
                        "SELECT stock_code FROM stock_sector_relations WHERE sector_code = %s",
                        (sector,),
                    )
                    sector_stocks = {row[0] for row in cursor.fetchall()}
                finally:
                    # Release the cursor even if the query fails.
                    cursor.close()
            stocks = [s for s in stocks if s.get('stock_code') in sector_stocks]

        if search:
            # Lower-case both sides for name AND code so the match is really
            # case-insensitive; None/missing fields count as empty strings.
            needle = search.lower()
            stocks = [
                s for s in stocks
                if needle in str(s.get('stock_name') or '').lower()
                or needle in str(s.get('stock_code') or '').lower()
            ]

        total = len(stocks)
        start = (page - 1) * size
        return {
            "total": total,
            "page": page,
            "size": size,
            "pages": (total + size - 1) // size,  # ceiling division; size >= 1
            "data": stocks[start:start + size],
        }
    except Exception as e:
        logger.error(f"获取股票列表失败: {e}")
        return {"error": f"获取股票列表失败: {str(e)}"}
@router.get("/industries")
async def get_industries():
    """Return every industry classification known to the market service.

    Returns:
        ``{"data": ...}`` with the result of
        ``market_service.get_industry_list()``, or ``{"error": "..."}`` if
        that call raises.
    """
    try:
        return {"data": market_service.get_industry_list()}
    except Exception as exc:
        logger.error(f"获取行业列表失败: {exc}")
        return {"error": f"获取行业列表失败: {str(exc)}"}
@router.get("/sectors")
async def get_sectors():
    """Return every concept sector known to the market service.

    Returns:
        ``{"data": ...}`` with the result of
        ``market_service.get_sector_list()``, or ``{"error": "..."}`` if that
        call raises.
    """
    try:
        sector_list = market_service.get_sector_list()
    except Exception as exc:
        logger.error(f"获取概念板块失败: {exc}")
        return {"error": f"获取概念板块失败: {str(exc)}"}
    else:
        return {"data": sector_list}
@router.get("/stocks/{stock_code}")
async def get_stock_detail(stock_code: str):
    """Return one stock's row together with its industry and sector names.

    Joins ``stocks`` with ``industries`` to get ``industry_name``, and with
    ``stock_sector_relations``/``sectors``. All linked sector names are folded
    into one ``sector_names`` value via ``GROUP_CONCAT(DISTINCT ...)``.

    Args:
        stock_code: Code to look up in ``stocks.stock_code``.

    Returns:
        ``{"data": row}`` with the row fetched from the ``dictionary=True``
        cursor, ``{"error": "股票不存在"}`` when no row matches, or
        ``{"error": ...}`` on any failure.
    """
    query = """
        SELECT s.*, i.industry_name,
               GROUP_CONCAT(DISTINCT sec.sector_name) as sector_names
        FROM stocks s
        LEFT JOIN industries i ON s.industry_code = i.industry_code
        LEFT JOIN stock_sector_relations ssr ON s.stock_code = ssr.stock_code
        LEFT JOIN sectors sec ON ssr.sector_code = sec.sector_code
        WHERE s.stock_code = %s
        GROUP BY s.stock_code
    """
    try:
        from app.database import DatabaseManager
        db_manager = DatabaseManager()
        with db_manager.get_connection() as conn:
            # NOTE(review): `dictionary=True` and `%s` placeholders are
            # mysql-connector conventions -- confirm for every configured backend.
            cursor = conn.cursor(dictionary=True)
            try:
                cursor.execute(query, (stock_code,))
                stock = cursor.fetchone()
            finally:
                # Close the cursor even if execute/fetch raises.
                cursor.close()
        if not stock:
            return {"error": "股票不存在"}
        return {"data": stock}
    except Exception as e:
        logger.error(f"获取股票详情失败: {stock_code}, 错误: {e}")
        return {"error": f"获取股票详情失败: {str(e)}"}
@router.get("/stocks/{stock_code}/kline")
async def get_kline_data(
    stock_code: str,
    kline_type: str = Query("daily", description="K线类型: daily/weekly/monthly"),
    days: int = Query(30, description="获取天数"),
    start_date: Optional[str] = Query(None, description="开始日期 YYYYMMDD"),
    end_date: Optional[str] = Query(None, description="结束日期 YYYYMMDD")
):
    """Return K-line rows for one stock together with its basic info.

    Args:
        stock_code: Stock whose K-lines are requested.
        kline_type: K-line period (daily/weekly/monthly).
        days: Row limit used when no full date window is given.
        start_date: Optional window start, ``YYYYMMDD``.
        end_date: Optional window end, ``YYYYMMDD``.

    Returns:
        ``{"stock_info", "kline_type", "data"}`` on success, or
        ``{"error": ...}`` on failure.
    """
    try:
        # A complete date window replaces `days` with a fixed cap of 1000 rows.
        row_limit = 1000 if (start_date and end_date) else days

        rows = kline_service.get_kline_data(
            stock_code=stock_code,
            kline_type=kline_type,
            start_date=start_date,
            end_date=end_date,
            limit=row_limit,
        )

        from app.services.stock_service_db import StockServiceDB
        info = StockServiceDB().get_stock_info(stock_code)

        return {"stock_info": info, "kline_type": kline_type, "data": rows}
    except Exception as exc:
        logger.error(f"获取K线数据失败: {stock_code}, 错误: {exc}")
        return {"error": f"获取K线数据失败: {str(exc)}"}
@router.get("/overview")
async def get_market_overview():
    """Return the market overview computed by the K-line service.

    Returns:
        ``{"data": ...}`` with ``kline_service.get_market_overview()``'s
        result, or ``{"error": "..."}`` if that call raises.
    """
    try:
        snapshot = kline_service.get_market_overview()
    except Exception as exc:
        logger.error(f"获取市场概览失败: {exc}")
        return {"error": f"获取市场概览失败: {str(exc)}"}
    return {"data": snapshot}
@router.get("/hot-stocks")
async def get_hot_stocks(
    rank_type: str = Query("volume", description="排行榜类型: volume/amount/change"),
    limit: int = Query(20, ge=1, description="返回数量")
):
    """Return the hot-stock ranking for the most recent daily K-line session.

    Rows come from ``kline_data`` (``kline_type = 'daily'``) joined with
    ``stocks`` and ``industries``. They are taken for the latest stored
    ``trade_date`` that is not after today. Using the latest stored session
    instead of the calendar date keeps the ranking populated on weekends,
    holidays, and before the day's data has been synced.

    Args:
        rank_type: ``"volume"`` or ``"amount"`` (descending), or ``"change"``
            (descending ``change_percent``, rows with NULL
            ``change_percent`` excluded).
        limit: Maximum number of rows to return (validated ``>= 1``).

    Returns:
        ``{"data": [...], "rank_type": ..., "trade_date": ...}`` on success.
        ``trade_date`` is the session used; it is ``None`` and ``data`` is
        empty when no daily K-line exists yet. Returns ``{"error": ...}`` for
        an unknown ``rank_type`` or any failure.
    """
    # rank_type -> (extra WHERE condition, ORDER BY column). Only these
    # whitelisted fragments are interpolated into the SQL text; every
    # request-supplied value is passed as a driver parameter.
    rankings = {
        "volume": ("", "k.volume"),
        "amount": ("", "k.amount"),
        "change": (" AND k.change_percent IS NOT NULL", "k.change_percent"),
    }
    try:
        # Validate before acquiring a DB connection.
        if rank_type not in rankings:
            return {"error": "不支持的排行榜类型"}
        extra_filter, order_column = rankings[rank_type]

        from app.database import DatabaseManager
        db_manager = DatabaseManager()
        with db_manager.get_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            try:
                today = datetime.now().strftime('%Y-%m-%d')
                # NOTE(review): this MAX() benefits from an index on
                # (kline_type, trade_date) -- confirm one exists.
                cursor.execute(
                    """
                    SELECT MAX(trade_date) AS latest_date
                    FROM kline_data
                    WHERE kline_type = 'daily' AND trade_date <= %s
                    """,
                    (today,),
                )
                row = cursor.fetchone()
                trade_date = row['latest_date'] if row else None

                if trade_date is None:
                    stocks = []
                else:
                    cursor.execute(
                        f"""
                        SELECT s.stock_code, s.stock_name, k.close_price, k.volume,
                               k.change_percent, k.amount, i.industry_name
                        FROM kline_data k
                        JOIN stocks s ON k.stock_code = s.stock_code
                        LEFT JOIN industries i ON s.industry_code = i.industry_code
                        WHERE k.kline_type = 'daily' AND k.trade_date = %s{extra_filter}
                        ORDER BY {order_column} DESC
                        LIMIT %s
                        """,
                        (trade_date, limit),
                    )
                    stocks = cursor.fetchall()
            finally:
                # Close the cursor even if a query fails.
                cursor.close()
        return {"data": stocks, "rank_type": rank_type, "trade_date": trade_date}
    except Exception as e:
        logger.error(f"获取热门股票失败: {e}")
        return {"error": f"获取热门股票失败: {str(e)}"}
@router.post("/tasks/{task_name}")
async def run_manual_task(task_name: str):
    """Run a scheduler task on demand.

    Args:
        task_name: Name of the task, forwarded to
            ``app.scheduler.run_manual_task``.

    Returns:
        ``{"data": result}`` with the scheduler's return value, or
        ``{"error": ...}`` if the task raises.
    """
    try:
        # This handler has the same name as the scheduler function imported at
        # module level, so the bare name `run_manual_task` here would refer to
        # this coroutine function itself. Import the scheduler's version under
        # an alias so the call actually runs the task.
        from app.scheduler import run_manual_task as scheduler_run_manual_task
        result = scheduler_run_manual_task(task_name)
        return {"data": result}
    except Exception as e:
        logger.error(f"手动执行任务失败: {task_name}, 错误: {e}")
        return {"error": f"手动执行任务失败: {str(e)}"}
@router.get("/tasks/status")
async def get_task_status(
    task_type: Optional[str] = Query(None, description="任务类型"),
    days: int = Query(7, description="查询天数")
):
    """Report recent scheduler task executions.

    Args:
        task_type: Optional task-type filter, passed straight to
            ``get_scheduler_status``.
        days: Look-back window in days, passed straight to
            ``get_scheduler_status``.

    Returns:
        ``{"data": ...}`` with ``get_scheduler_status``'s result, or
        ``{"error": "..."}`` if it raises.
    """
    try:
        status_rows = get_scheduler_status(task_type, days)
    except Exception as exc:
        logger.error(f"获取任务状态失败: {exc}")
        return {"error": f"获取任务状态失败: {str(exc)}"}
    return {"data": status_rows}
@router.post("/sync")
async def sync_market_data():
    """Refresh the stock list, the concept sectors and the latest daily K-lines.

    The three steps run in order:
    ``market_service.get_all_stock_list(force_refresh=True)``,
    ``market_service.update_stock_sectors()``, and
    ``kline_service.batch_update_kline_data(days_back=1)``.

    Returns:
        ``{"message", "stocks_updated", "concepts_updated", "kline_updated"}``
        on success, or ``{"error": ...}`` if any step raises.
    """
    # NOTE(review): these service calls are not awaited, so they appear to be
    # synchronous. A long sync therefore holds the event loop and stalls every
    # other request. Consider offloading to a thread pool once the services
    # are confirmed thread-safe.
    try:
        refreshed = market_service.get_all_stock_list(force_refresh=True)
        # Dict values evaluate left to right, preserving the step order above.
        return {
            "message": "市场数据同步完成",
            "stocks_updated": len(refreshed),
            "concepts_updated": market_service.update_stock_sectors(),
            "kline_updated": kline_service.batch_update_kline_data(days_back=1),
        }
    except Exception as exc:
        logger.error(f"同步市场数据失败: {exc}")
        return {"error": f"同步市场数据失败: {str(exc)}"}
@router.get("/statistics")
async def get_market_statistics(
    days: int = Query(30, ge=0, description="统计天数")
):
    """Return market statistics plus industry and market-type distributions.

    Args:
        days: Look-back window; ``market_statistics`` rows with
            ``stat_date`` on or after ``today - days`` are returned
            (validated ``>= 0``).

    Returns:
        On success, a dict with these keys:

        - ``statistics``: the windowed ``market_statistics`` rows, newest
          first.
        - ``industry_distribution``: active-stock counts per industry name.
        - ``market_type_distribution``: active-stock counts per
          ``market_type``.
        - ``period_days``: the requested ``days``.

        On any failure, ``{"error": ...}``.
    """
    try:
        # `datetime` is already imported at module level; only timedelta is needed.
        from datetime import timedelta
        from app.database import DatabaseManager

        start_date = (datetime.now() - timedelta(days=days)).date()
        db_manager = DatabaseManager()
        with db_manager.get_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            try:
                # Daily per-market statistics within the window.
                cursor.execute("""
                    SELECT stat_date, market_code, total_stocks, up_stocks, down_stocks,
                           flat_stocks, total_amount, total_volume
                    FROM market_statistics
                    WHERE stat_date >= %s
                    ORDER BY stat_date DESC, market_code
                """, (start_date,))
                stats = cursor.fetchall()

                # Active stocks per industry (stocks without an industry are skipped).
                cursor.execute("""
                    SELECT i.industry_name, COUNT(s.stock_code) as stock_count
                    FROM stocks s
                    LEFT JOIN industries i ON s.industry_code = i.industry_code
                    WHERE s.is_active = TRUE AND i.industry_name IS NOT NULL
                    GROUP BY i.industry_name
                    ORDER BY stock_count DESC
                """)
                industry_stats = cursor.fetchall()

                # Active stocks per market type.
                cursor.execute("""
                    SELECT market_type, COUNT(*) as stock_count
                    FROM stocks
                    WHERE is_active = TRUE
                    GROUP BY market_type
                """)
                market_type_stats = cursor.fetchall()
            finally:
                # Close the cursor even if one of the queries fails.
                cursor.close()
        return {
            "statistics": stats,
            "industry_distribution": industry_stats,
            "market_type_distribution": market_type_stats,
            "period_days": days
        }
    except Exception as e:
        logger.error(f"获取市场统计数据失败: {e}")
        return {"error": f"获取市场统计数据失败: {str(e)}"}