- 重构数据访问层:引入DAO模式,支持MySQL/SQLite双数据库 - 新增数据库架构:完整的股票数据、AI分析、自选股管理表结构 - 升级AI分析服务:集成豆包大模型,支持多维度分析 - 优化API路由:分离市场数据API,提供更清晰的接口设计 - 完善项目文档:添加数据库迁移指南、新功能指南等 - 清理冗余文件:删除旧的缓存文件和无用配置 - 新增调度器:支持定时任务和数据自动更新 - 改进前端模板:简化的股票展示页面 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
400 lines
15 KiB
Python
400 lines
15 KiB
Python
"""
|
||
全市场股票数据服务
|
||
获取和管理所有A股股票的基础数据、行业分类、K线数据等
|
||
"""
|
||
import pandas as pd
|
||
import logging
|
||
from datetime import datetime, date, timedelta
|
||
from typing import List, Dict, Optional, Tuple
|
||
from app import pro
|
||
from app.dao import StockDAO
|
||
from app.database import DatabaseManager
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class MarketDataService:
|
||
def __init__(self):
|
||
self.stock_dao = StockDAO()
|
||
self.db_manager = DatabaseManager()
|
||
self.logger = logging.getLogger(__name__)
|
||
|
||
def get_all_stock_list(self, force_refresh: bool = False) -> List[Dict]:
|
||
"""获取所有A股股票列表"""
|
||
try:
|
||
# 如果不是强制刷新,先从数据库获取
|
||
if not force_refresh:
|
||
stocks = self._get_stock_list_from_db()
|
||
if stocks:
|
||
self.logger.info(f"从数据库获取到 {len(stocks)} 只股票")
|
||
return stocks
|
||
|
||
# 从tushare获取最新的股票列表
|
||
self.logger.info("从tushare获取股票列表...")
|
||
return self._fetch_stock_list_from_api()
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"获取股票列表失败: {e}")
|
||
return []
|
||
|
||
def _get_stock_list_from_db(self) -> List[Dict]:
|
||
"""从数据库获取股票列表"""
|
||
try:
|
||
with self.db_manager.get_connection() as conn:
|
||
cursor = conn.cursor(dictionary=True)
|
||
query = """
|
||
SELECT s.*, i.industry_name,
|
||
GROUP_CONCAT(DISTINCT sec.sector_name) as sector_names
|
||
FROM stocks s
|
||
LEFT JOIN industries i ON s.industry_code = i.industry_code
|
||
LEFT JOIN stock_sector_relations ssr ON s.stock_code = ssr.stock_code
|
||
LEFT JOIN sectors sec ON ssr.sector_code = sec.sector_code
|
||
WHERE s.is_active = TRUE
|
||
GROUP BY s.stock_code
|
||
ORDER BY s.stock_code
|
||
"""
|
||
cursor.execute(query)
|
||
stocks = cursor.fetchall()
|
||
cursor.close()
|
||
return stocks
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"从数据库获取股票列表失败: {e}")
|
||
return []
|
||
|
||
def _fetch_stock_list_from_api(self) -> List[Dict]:
|
||
"""从tushare API获取股票列表"""
|
||
try:
|
||
all_stocks = []
|
||
|
||
# 获取A股列表
|
||
stock_basic = pro.stock_basic(
|
||
exchange='',
|
||
list_status='L', # L代表上市
|
||
fields='ts_code,symbol,name,area,industry,market,list_date'
|
||
)
|
||
|
||
if stock_basic.empty:
|
||
self.logger.warning("未获取到股票数据")
|
||
return []
|
||
|
||
self.logger.info(f"获取到 {len(stock_basic)} 只股票基础信息")
|
||
|
||
# 处理每只股票
|
||
for _, row in stock_basic.iterrows():
|
||
try:
|
||
stock_info = {
|
||
'stock_code': row['symbol'], # 股票代码 (6位)
|
||
'stock_name': row['name'],
|
||
'market': row['market'], # 市场主板/创业板等
|
||
'industry_code': self._map_industry_code(row['industry']),
|
||
'area': row.get('area', ''),
|
||
'list_date': pd.to_datetime(row['list_date']).date() if pd.notna(row['list_date']) else None,
|
||
'market_type': self._get_market_type(row['symbol'], row['market']),
|
||
'is_active': True
|
||
}
|
||
|
||
# 保存到数据库
|
||
self._save_stock_to_db(stock_info)
|
||
all_stocks.append(stock_info)
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"处理股票 {row.get('symbol', 'unknown')} 失败: {e}")
|
||
continue
|
||
|
||
self.logger.info(f"成功保存 {len(all_stocks)} 只股票到数据库")
|
||
return all_stocks
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"从API获取股票列表失败: {e}")
|
||
return []
|
||
|
||
def _save_stock_to_db(self, stock_info: Dict) -> bool:
|
||
"""保存股票信息到数据库"""
|
||
try:
|
||
with self.db_manager.get_connection() as conn:
|
||
cursor = conn.cursor()
|
||
|
||
# 使用INSERT ... ON DUPLICATE KEY UPDATE
|
||
query = """
|
||
INSERT INTO stocks (
|
||
stock_code, stock_name, market, industry_code, area,
|
||
list_date, market_type, is_active, created_at
|
||
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, NOW())
|
||
ON DUPLICATE KEY UPDATE
|
||
stock_name = VALUES(stock_name),
|
||
market = VALUES(market),
|
||
industry_code = VALUES(industry_code),
|
||
area = VALUES(area),
|
||
list_date = VALUES(list_date),
|
||
market_type = VALUES(market_type),
|
||
is_active = VALUES(is_active),
|
||
updated_at = NOW()
|
||
"""
|
||
|
||
cursor.execute(query, (
|
||
stock_info['stock_code'],
|
||
stock_info['stock_name'],
|
||
stock_info['market'],
|
||
stock_info['industry_code'],
|
||
stock_info['area'],
|
||
stock_info['list_date'],
|
||
stock_info['market_type'],
|
||
stock_info['is_active']
|
||
))
|
||
|
||
conn.commit()
|
||
cursor.close()
|
||
return True
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"保存股票信息失败: {stock_info['stock_code']}, 错误: {e}")
|
||
return False
|
||
|
||
def _map_industry_code(self, industry_name: str) -> Optional[str]:
|
||
"""将行业名称映射到行业代码"""
|
||
if pd.isna(industry_name) or not industry_name:
|
||
return None
|
||
|
||
industry_mapping = {
|
||
'计算机': 'I09',
|
||
'通信': 'I09',
|
||
'软件和信息技术服务业': 'I09',
|
||
'医药生物': 'Q17',
|
||
'生物医药': 'Q17',
|
||
'医疗器械': 'Q17',
|
||
'电子': 'C03',
|
||
'机械设备': 'C03',
|
||
'化工': 'C03',
|
||
'汽车': 'C03',
|
||
'房地产': 'K11',
|
||
'银行': 'J10',
|
||
'非银金融': 'J10',
|
||
'食品饮料': 'C03',
|
||
'农林牧渔': 'A01',
|
||
'采掘': 'B02',
|
||
'钢铁': 'C03',
|
||
'有色金属': 'C03',
|
||
'建筑材料': 'C03',
|
||
'建筑装饰': 'E05',
|
||
'电气设备': 'C03',
|
||
'国防军工': 'M13',
|
||
'交通运输': 'G07',
|
||
'公用事业': 'D04',
|
||
'传媒': 'R18',
|
||
'休闲服务': 'R18',
|
||
'家用电器': 'C03',
|
||
'纺织服装': 'C03',
|
||
'轻工制造': 'C03',
|
||
'商业贸易': 'F06',
|
||
'综合': 'S19'
|
||
}
|
||
|
||
# 精确匹配
|
||
if industry_name in industry_mapping:
|
||
return industry_mapping[industry_name]
|
||
|
||
# 模糊匹配
|
||
for key, code in industry_mapping.items():
|
||
if key in industry_name or industry_name in key:
|
||
return code
|
||
|
||
return 'C03' # 默认制造业
|
||
|
||
def _get_market_type(self, stock_code: str, market: str) -> str:
|
||
"""获取市场类型"""
|
||
if stock_code.startswith('688'):
|
||
return '科创板'
|
||
elif stock_code.startswith('300'):
|
||
return '创业板'
|
||
elif stock_code.startswith('600') or stock_code.startswith('601') or stock_code.startswith('603') or stock_code.startswith('605'):
|
||
return '主板'
|
||
elif stock_code.startswith('000') or stock_code.startswith('001') or stock_code.startswith('002') or stock_code.startswith('003'):
|
||
return '主板'
|
||
elif stock_code.startswith('8') or stock_code.startswith('43'):
|
||
return '新三板'
|
||
else:
|
||
return market or '其他'
|
||
|
||
def update_stock_sectors(self, stock_codes: List[str] = None) -> int:
|
||
"""更新股票概念板块信息"""
|
||
try:
|
||
if stock_codes is None:
|
||
# 获取所有股票
|
||
stock_codes = [stock['stock_code'] for stock in self._get_stock_list_from_db()]
|
||
|
||
updated_count = 0
|
||
total_count = len(stock_codes)
|
||
|
||
for stock_code in stock_codes:
|
||
try:
|
||
# 这里可以调用概念板块API获取股票所属概念
|
||
# 由于tushare概念接口限制,这里先做一些基础映射
|
||
self._update_stock_concepts(stock_code)
|
||
updated_count += 1
|
||
|
||
if updated_count % 100 == 0:
|
||
self.logger.info(f"已更新 {updated_count}/{total_count} 只股票的概念信息")
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"更新股票 {stock_code} 概念信息失败: {e}")
|
||
continue
|
||
|
||
self.logger.info(f"完成更新 {updated_count} 只股票的概念信息")
|
||
return updated_count
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"批量更新股票概念信息失败: {e}")
|
||
return 0
|
||
|
||
def _update_stock_concepts(self, stock_code: str):
|
||
"""更新单个股票的概念信息"""
|
||
try:
|
||
# 基于股票代码做一些基础的概念分类
|
||
concepts = []
|
||
|
||
# 根据股票代码前缀推断概念
|
||
if stock_code.startswith('688'):
|
||
concepts.append('BK0500') # 半导体
|
||
elif stock_code.startswith('300'):
|
||
concepts.append('BK0896') # 国产软件
|
||
concepts.append('BK0735') # 新基建
|
||
|
||
# 这里可以扩展更多的概念匹配逻辑
|
||
# 也可以调用第三方API获取更准确的概念分类
|
||
|
||
if concepts:
|
||
self._save_stock_concepts(stock_code, concepts)
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"更新股票 {stock_code} 概念失败: {e}")
|
||
|
||
def _save_stock_concepts(self, stock_code: str, concept_codes: List[str]):
|
||
"""保存股票概念关联关系"""
|
||
try:
|
||
with self.db_manager.get_connection() as conn:
|
||
cursor = conn.cursor()
|
||
|
||
# 先删除现有的概念关联
|
||
cursor.execute("DELETE FROM stock_sector_relations WHERE stock_code = %s", (stock_code,))
|
||
|
||
# 添加新的概念关联
|
||
for concept_code in concept_codes:
|
||
query = """
|
||
INSERT IGNORE INTO stock_sector_relations (stock_code, sector_code)
|
||
VALUES (%s, %s)
|
||
"""
|
||
cursor.execute(query, (stock_code, concept_code))
|
||
|
||
conn.commit()
|
||
cursor.close()
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"保存股票概念关联失败: {stock_code}, 错误: {e}")
|
||
|
||
def get_stock_by_industry(self, industry_code: str = None, limit: int = 100) -> List[Dict]:
|
||
"""根据行业获取股票列表"""
|
||
try:
|
||
with self.db_manager.get_connection() as conn:
|
||
cursor = conn.cursor(dictionary=True)
|
||
|
||
if industry_code:
|
||
query = """
|
||
SELECT s.*, i.industry_name
|
||
FROM stocks s
|
||
LEFT JOIN industries i ON s.industry_code = i.industry_code
|
||
WHERE s.industry_code = %s AND s.is_active = TRUE
|
||
ORDER BY s.stock_code
|
||
LIMIT %s
|
||
"""
|
||
cursor.execute(query, (industry_code, limit))
|
||
else:
|
||
query = """
|
||
SELECT s.*, i.industry_name
|
||
FROM stocks s
|
||
LEFT JOIN industries i ON s.industry_code = i.industry_code
|
||
WHERE s.is_active = TRUE
|
||
ORDER BY s.stock_code
|
||
LIMIT %s
|
||
"""
|
||
cursor.execute(query, (limit,))
|
||
|
||
stocks = cursor.fetchall()
|
||
cursor.close()
|
||
return stocks
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"根据行业获取股票失败: {e}")
|
||
return []
|
||
|
||
def get_stock_by_sector(self, sector_code: str, limit: int = 100) -> List[Dict]:
|
||
"""根据概念板块获取股票列表"""
|
||
try:
|
||
with self.db_manager.get_connection() as conn:
|
||
cursor = conn.cursor(dictionary=True)
|
||
|
||
query = """
|
||
SELECT s.*, sec.sector_name
|
||
FROM stocks s
|
||
JOIN stock_sector_relations ssr ON s.stock_code = ssr.stock_code
|
||
JOIN sectors sec ON ssr.sector_code = sec.sector_code
|
||
WHERE ssr.sector_code = %s AND s.is_active = TRUE
|
||
ORDER BY s.stock_code
|
||
LIMIT %s
|
||
"""
|
||
cursor.execute(query, (sector_code, limit))
|
||
|
||
stocks = cursor.fetchall()
|
||
cursor.close()
|
||
return stocks
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"根据概念获取股票失败: {e}")
|
||
return []
|
||
|
||
def get_industry_list(self) -> List[Dict]:
|
||
"""获取所有行业列表"""
|
||
try:
|
||
with self.db_manager.get_connection() as conn:
|
||
cursor = conn.cursor(dictionary=True)
|
||
|
||
query = """
|
||
SELECT i.industry_code, i.industry_name, i.level,
|
||
COUNT(s.stock_code) as stock_count
|
||
FROM industries i
|
||
LEFT JOIN stocks s ON i.industry_code = s.industry_code AND s.is_active = TRUE
|
||
GROUP BY i.industry_code, i.industry_name, i.level
|
||
ORDER BY i.industry_code
|
||
"""
|
||
cursor.execute(query)
|
||
industries = cursor.fetchall()
|
||
cursor.close()
|
||
return industries
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"获取行业列表失败: {e}")
|
||
return []
|
||
|
||
def get_sector_list(self) -> List[Dict]:
|
||
"""获取所有概念板块列表"""
|
||
try:
|
||
with self.db_manager.get_connection() as conn:
|
||
cursor = conn.cursor(dictionary=True)
|
||
|
||
query = """
|
||
SELECT s.sector_code, s.sector_name, s.description,
|
||
COUNT(ssr.stock_code) as stock_count
|
||
FROM sectors s
|
||
LEFT JOIN stock_sector_relations ssr ON s.sector_code = ssr.sector_code
|
||
LEFT JOIN stocks st ON ssr.stock_code = st.stock_code AND st.is_active = TRUE
|
||
GROUP BY s.sector_code, s.sector_name, s.description
|
||
ORDER BY s.sector_code
|
||
"""
|
||
cursor.execute(query)
|
||
sectors = cursor.fetchall()
|
||
cursor.close()
|
||
return sectors
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"获取概念板块列表失败: {e}")
|
||
return [] |