stock-monitor/app/services/market_data_service.py
ycg 569c1c8813 重构股票监控系统:数据库架构升级与功能完善
- 重构数据访问层:引入DAO模式,支持MySQL/SQLite双数据库
- 新增数据库架构:完整的股票数据、AI分析、自选股管理表结构
- 升级AI分析服务:集成豆包大模型,支持多维度分析
- 优化API路由:分离市场数据API,提供更清晰的接口设计
- 完善项目文档:添加数据库迁移指南、新功能指南等
- 清理冗余文件:删除旧的缓存文件和无用配置
- 新增调度器:支持定时任务和数据自动更新
- 改进前端模板:简化的股票展示页面

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-01 15:44:25 +08:00

400 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
全市场股票数据服务
获取和管理所有A股股票的基础数据、行业分类、K线数据等
"""
import pandas as pd
import logging
from datetime import datetime, date, timedelta
from typing import List, Dict, Optional, Tuple
from app import pro
from app.dao import StockDAO
from app.database import DatabaseManager
logger = logging.getLogger(__name__)
class MarketDataService:
def __init__(self):
self.stock_dao = StockDAO()
self.db_manager = DatabaseManager()
self.logger = logging.getLogger(__name__)
def get_all_stock_list(self, force_refresh: bool = False) -> List[Dict]:
"""获取所有A股股票列表"""
try:
# 如果不是强制刷新,先从数据库获取
if not force_refresh:
stocks = self._get_stock_list_from_db()
if stocks:
self.logger.info(f"从数据库获取到 {len(stocks)} 只股票")
return stocks
# 从tushare获取最新的股票列表
self.logger.info("从tushare获取股票列表...")
return self._fetch_stock_list_from_api()
except Exception as e:
self.logger.error(f"获取股票列表失败: {e}")
return []
def _get_stock_list_from_db(self) -> List[Dict]:
"""从数据库获取股票列表"""
try:
with self.db_manager.get_connection() as conn:
cursor = conn.cursor(dictionary=True)
query = """
SELECT s.*, i.industry_name,
GROUP_CONCAT(DISTINCT sec.sector_name) as sector_names
FROM stocks s
LEFT JOIN industries i ON s.industry_code = i.industry_code
LEFT JOIN stock_sector_relations ssr ON s.stock_code = ssr.stock_code
LEFT JOIN sectors sec ON ssr.sector_code = sec.sector_code
WHERE s.is_active = TRUE
GROUP BY s.stock_code
ORDER BY s.stock_code
"""
cursor.execute(query)
stocks = cursor.fetchall()
cursor.close()
return stocks
except Exception as e:
self.logger.error(f"从数据库获取股票列表失败: {e}")
return []
def _fetch_stock_list_from_api(self) -> List[Dict]:
"""从tushare API获取股票列表"""
try:
all_stocks = []
# 获取A股列表
stock_basic = pro.stock_basic(
exchange='',
list_status='L', # L代表上市
fields='ts_code,symbol,name,area,industry,market,list_date'
)
if stock_basic.empty:
self.logger.warning("未获取到股票数据")
return []
self.logger.info(f"获取到 {len(stock_basic)} 只股票基础信息")
# 处理每只股票
for _, row in stock_basic.iterrows():
try:
stock_info = {
'stock_code': row['symbol'], # 股票代码 (6位)
'stock_name': row['name'],
'market': row['market'], # 市场主板/创业板等
'industry_code': self._map_industry_code(row['industry']),
'area': row.get('area', ''),
'list_date': pd.to_datetime(row['list_date']).date() if pd.notna(row['list_date']) else None,
'market_type': self._get_market_type(row['symbol'], row['market']),
'is_active': True
}
# 保存到数据库
self._save_stock_to_db(stock_info)
all_stocks.append(stock_info)
except Exception as e:
self.logger.error(f"处理股票 {row.get('symbol', 'unknown')} 失败: {e}")
continue
self.logger.info(f"成功保存 {len(all_stocks)} 只股票到数据库")
return all_stocks
except Exception as e:
self.logger.error(f"从API获取股票列表失败: {e}")
return []
def _save_stock_to_db(self, stock_info: Dict) -> bool:
"""保存股票信息到数据库"""
try:
with self.db_manager.get_connection() as conn:
cursor = conn.cursor()
# 使用INSERT ... ON DUPLICATE KEY UPDATE
query = """
INSERT INTO stocks (
stock_code, stock_name, market, industry_code, area,
list_date, market_type, is_active, created_at
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, NOW())
ON DUPLICATE KEY UPDATE
stock_name = VALUES(stock_name),
market = VALUES(market),
industry_code = VALUES(industry_code),
area = VALUES(area),
list_date = VALUES(list_date),
market_type = VALUES(market_type),
is_active = VALUES(is_active),
updated_at = NOW()
"""
cursor.execute(query, (
stock_info['stock_code'],
stock_info['stock_name'],
stock_info['market'],
stock_info['industry_code'],
stock_info['area'],
stock_info['list_date'],
stock_info['market_type'],
stock_info['is_active']
))
conn.commit()
cursor.close()
return True
except Exception as e:
self.logger.error(f"保存股票信息失败: {stock_info['stock_code']}, 错误: {e}")
return False
def _map_industry_code(self, industry_name: str) -> Optional[str]:
"""将行业名称映射到行业代码"""
if pd.isna(industry_name) or not industry_name:
return None
industry_mapping = {
'计算机': 'I09',
'通信': 'I09',
'软件和信息技术服务业': 'I09',
'医药生物': 'Q17',
'生物医药': 'Q17',
'医疗器械': 'Q17',
'电子': 'C03',
'机械设备': 'C03',
'化工': 'C03',
'汽车': 'C03',
'房地产': 'K11',
'银行': 'J10',
'非银金融': 'J10',
'食品饮料': 'C03',
'农林牧渔': 'A01',
'采掘': 'B02',
'钢铁': 'C03',
'有色金属': 'C03',
'建筑材料': 'C03',
'建筑装饰': 'E05',
'电气设备': 'C03',
'国防军工': 'M13',
'交通运输': 'G07',
'公用事业': 'D04',
'传媒': 'R18',
'休闲服务': 'R18',
'家用电器': 'C03',
'纺织服装': 'C03',
'轻工制造': 'C03',
'商业贸易': 'F06',
'综合': 'S19'
}
# 精确匹配
if industry_name in industry_mapping:
return industry_mapping[industry_name]
# 模糊匹配
for key, code in industry_mapping.items():
if key in industry_name or industry_name in key:
return code
return 'C03' # 默认制造业
def _get_market_type(self, stock_code: str, market: str) -> str:
"""获取市场类型"""
if stock_code.startswith('688'):
return '科创板'
elif stock_code.startswith('300'):
return '创业板'
elif stock_code.startswith('600') or stock_code.startswith('601') or stock_code.startswith('603') or stock_code.startswith('605'):
return '主板'
elif stock_code.startswith('000') or stock_code.startswith('001') or stock_code.startswith('002') or stock_code.startswith('003'):
return '主板'
elif stock_code.startswith('8') or stock_code.startswith('43'):
return '新三板'
else:
return market or '其他'
def update_stock_sectors(self, stock_codes: List[str] = None) -> int:
"""更新股票概念板块信息"""
try:
if stock_codes is None:
# 获取所有股票
stock_codes = [stock['stock_code'] for stock in self._get_stock_list_from_db()]
updated_count = 0
total_count = len(stock_codes)
for stock_code in stock_codes:
try:
# 这里可以调用概念板块API获取股票所属概念
# 由于tushare概念接口限制这里先做一些基础映射
self._update_stock_concepts(stock_code)
updated_count += 1
if updated_count % 100 == 0:
self.logger.info(f"已更新 {updated_count}/{total_count} 只股票的概念信息")
except Exception as e:
self.logger.error(f"更新股票 {stock_code} 概念信息失败: {e}")
continue
self.logger.info(f"完成更新 {updated_count} 只股票的概念信息")
return updated_count
except Exception as e:
self.logger.error(f"批量更新股票概念信息失败: {e}")
return 0
def _update_stock_concepts(self, stock_code: str):
"""更新单个股票的概念信息"""
try:
# 基于股票代码做一些基础的概念分类
concepts = []
# 根据股票代码前缀推断概念
if stock_code.startswith('688'):
concepts.append('BK0500') # 半导体
elif stock_code.startswith('300'):
concepts.append('BK0896') # 国产软件
concepts.append('BK0735') # 新基建
# 这里可以扩展更多的概念匹配逻辑
# 也可以调用第三方API获取更准确的概念分类
if concepts:
self._save_stock_concepts(stock_code, concepts)
except Exception as e:
self.logger.error(f"更新股票 {stock_code} 概念失败: {e}")
def _save_stock_concepts(self, stock_code: str, concept_codes: List[str]):
"""保存股票概念关联关系"""
try:
with self.db_manager.get_connection() as conn:
cursor = conn.cursor()
# 先删除现有的概念关联
cursor.execute("DELETE FROM stock_sector_relations WHERE stock_code = %s", (stock_code,))
# 添加新的概念关联
for concept_code in concept_codes:
query = """
INSERT IGNORE INTO stock_sector_relations (stock_code, sector_code)
VALUES (%s, %s)
"""
cursor.execute(query, (stock_code, concept_code))
conn.commit()
cursor.close()
except Exception as e:
self.logger.error(f"保存股票概念关联失败: {stock_code}, 错误: {e}")
def get_stock_by_industry(self, industry_code: str = None, limit: int = 100) -> List[Dict]:
"""根据行业获取股票列表"""
try:
with self.db_manager.get_connection() as conn:
cursor = conn.cursor(dictionary=True)
if industry_code:
query = """
SELECT s.*, i.industry_name
FROM stocks s
LEFT JOIN industries i ON s.industry_code = i.industry_code
WHERE s.industry_code = %s AND s.is_active = TRUE
ORDER BY s.stock_code
LIMIT %s
"""
cursor.execute(query, (industry_code, limit))
else:
query = """
SELECT s.*, i.industry_name
FROM stocks s
LEFT JOIN industries i ON s.industry_code = i.industry_code
WHERE s.is_active = TRUE
ORDER BY s.stock_code
LIMIT %s
"""
cursor.execute(query, (limit,))
stocks = cursor.fetchall()
cursor.close()
return stocks
except Exception as e:
self.logger.error(f"根据行业获取股票失败: {e}")
return []
def get_stock_by_sector(self, sector_code: str, limit: int = 100) -> List[Dict]:
"""根据概念板块获取股票列表"""
try:
with self.db_manager.get_connection() as conn:
cursor = conn.cursor(dictionary=True)
query = """
SELECT s.*, sec.sector_name
FROM stocks s
JOIN stock_sector_relations ssr ON s.stock_code = ssr.stock_code
JOIN sectors sec ON ssr.sector_code = sec.sector_code
WHERE ssr.sector_code = %s AND s.is_active = TRUE
ORDER BY s.stock_code
LIMIT %s
"""
cursor.execute(query, (sector_code, limit))
stocks = cursor.fetchall()
cursor.close()
return stocks
except Exception as e:
self.logger.error(f"根据概念获取股票失败: {e}")
return []
def get_industry_list(self) -> List[Dict]:
"""获取所有行业列表"""
try:
with self.db_manager.get_connection() as conn:
cursor = conn.cursor(dictionary=True)
query = """
SELECT i.industry_code, i.industry_name, i.level,
COUNT(s.stock_code) as stock_count
FROM industries i
LEFT JOIN stocks s ON i.industry_code = s.industry_code AND s.is_active = TRUE
GROUP BY i.industry_code, i.industry_name, i.level
ORDER BY i.industry_code
"""
cursor.execute(query)
industries = cursor.fetchall()
cursor.close()
return industries
except Exception as e:
self.logger.error(f"获取行业列表失败: {e}")
return []
def get_sector_list(self) -> List[Dict]:
"""获取所有概念板块列表"""
try:
with self.db_manager.get_connection() as conn:
cursor = conn.cursor(dictionary=True)
query = """
SELECT s.sector_code, s.sector_name, s.description,
COUNT(ssr.stock_code) as stock_count
FROM sectors s
LEFT JOIN stock_sector_relations ssr ON s.sector_code = ssr.sector_code
LEFT JOIN stocks st ON ssr.stock_code = st.stock_code AND st.is_active = TRUE
GROUP BY s.sector_code, s.sector_name, s.description
ORDER BY s.sector_code
"""
cursor.execute(query)
sectors = cursor.fetchall()
cursor.close()
return sectors
except Exception as e:
self.logger.error(f"获取概念板块列表失败: {e}")
return []