""" 全市场股票数据服务 获取和管理所有A股股票的基础数据、行业分类、K线数据等 """ import pandas as pd import logging from datetime import datetime, date, timedelta from typing import List, Dict, Optional, Tuple from app import pro from app.dao import StockDAO from app.database import DatabaseManager logger = logging.getLogger(__name__) class MarketDataService: def __init__(self): self.stock_dao = StockDAO() self.db_manager = DatabaseManager() self.logger = logging.getLogger(__name__) def get_all_stock_list(self, force_refresh: bool = False) -> List[Dict]: """获取所有A股股票列表""" try: # 如果不是强制刷新,先从数据库获取 if not force_refresh: stocks = self._get_stock_list_from_db() if stocks: self.logger.info(f"从数据库获取到 {len(stocks)} 只股票") return stocks # 从tushare获取最新的股票列表 self.logger.info("从tushare获取股票列表...") return self._fetch_stock_list_from_api() except Exception as e: self.logger.error(f"获取股票列表失败: {e}") return [] def _get_stock_list_from_db(self) -> List[Dict]: """从数据库获取股票列表""" try: with self.db_manager.get_connection() as conn: cursor = conn.cursor(dictionary=True) query = """ SELECT s.*, i.industry_name, GROUP_CONCAT(DISTINCT sec.sector_name) as sector_names FROM stocks s LEFT JOIN industries i ON s.industry_code = i.industry_code LEFT JOIN stock_sector_relations ssr ON s.stock_code = ssr.stock_code LEFT JOIN sectors sec ON ssr.sector_code = sec.sector_code WHERE s.is_active = TRUE GROUP BY s.stock_code ORDER BY s.stock_code """ cursor.execute(query) stocks = cursor.fetchall() cursor.close() return stocks except Exception as e: self.logger.error(f"从数据库获取股票列表失败: {e}") return [] def _fetch_stock_list_from_api(self) -> List[Dict]: """从tushare API获取股票列表""" try: all_stocks = [] # 获取A股列表 stock_basic = pro.stock_basic( exchange='', list_status='L', # L代表上市 fields='ts_code,symbol,name,area,industry,market,list_date' ) if stock_basic.empty: self.logger.warning("未获取到股票数据") return [] self.logger.info(f"获取到 {len(stock_basic)} 只股票基础信息") # 处理每只股票 for _, row in stock_basic.iterrows(): try: stock_info = { 'stock_code': row['symbol'], # 股票代码 (6位) 'stock_name': row['name'], 'market': row['market'], # 市场主板/创业板等 'industry_code': self._map_industry_code(row['industry']), 'area': row.get('area', ''), 'list_date': pd.to_datetime(row['list_date']).date() if pd.notna(row['list_date']) else None, 'market_type': self._get_market_type(row['symbol'], row['market']), 'is_active': True } # 保存到数据库 self._save_stock_to_db(stock_info) all_stocks.append(stock_info) except Exception as e: self.logger.error(f"处理股票 {row.get('symbol', 'unknown')} 失败: {e}") continue self.logger.info(f"成功保存 {len(all_stocks)} 只股票到数据库") return all_stocks except Exception as e: self.logger.error(f"从API获取股票列表失败: {e}") return [] def _save_stock_to_db(self, stock_info: Dict) -> bool: """保存股票信息到数据库""" try: with self.db_manager.get_connection() as conn: cursor = conn.cursor() # 使用INSERT ... ON DUPLICATE KEY UPDATE query = """ INSERT INTO stocks ( stock_code, stock_name, market, industry_code, area, list_date, market_type, is_active, created_at ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, NOW()) ON DUPLICATE KEY UPDATE stock_name = VALUES(stock_name), market = VALUES(market), industry_code = VALUES(industry_code), area = VALUES(area), list_date = VALUES(list_date), market_type = VALUES(market_type), is_active = VALUES(is_active), updated_at = NOW() """ cursor.execute(query, ( stock_info['stock_code'], stock_info['stock_name'], stock_info['market'], stock_info['industry_code'], stock_info['area'], stock_info['list_date'], stock_info['market_type'], stock_info['is_active'] )) conn.commit() cursor.close() return True except Exception as e: self.logger.error(f"保存股票信息失败: {stock_info['stock_code']}, 错误: {e}") return False def _map_industry_code(self, industry_name: str) -> Optional[str]: """将行业名称映射到行业代码""" if pd.isna(industry_name) or not industry_name: return None industry_mapping = { '计算机': 'I09', '通信': 'I09', '软件和信息技术服务业': 'I09', '医药生物': 'Q17', '生物医药': 'Q17', '医疗器械': 'Q17', '电子': 'C03', '机械设备': 'C03', '化工': 'C03', '汽车': 'C03', '房地产': 'K11', '银行': 'J10', '非银金融': 'J10', '食品饮料': 'C03', '农林牧渔': 'A01', '采掘': 'B02', '钢铁': 'C03', '有色金属': 'C03', '建筑材料': 'C03', '建筑装饰': 'E05', '电气设备': 'C03', '国防军工': 'M13', '交通运输': 'G07', '公用事业': 'D04', '传媒': 'R18', '休闲服务': 'R18', '家用电器': 'C03', '纺织服装': 'C03', '轻工制造': 'C03', '商业贸易': 'F06', '综合': 'S19' } # 精确匹配 if industry_name in industry_mapping: return industry_mapping[industry_name] # 模糊匹配 for key, code in industry_mapping.items(): if key in industry_name or industry_name in key: return code return 'C03' # 默认制造业 def _get_market_type(self, stock_code: str, market: str) -> str: """获取市场类型""" if stock_code.startswith('688'): return '科创板' elif stock_code.startswith('300'): return '创业板' elif stock_code.startswith('600') or stock_code.startswith('601') or stock_code.startswith('603') or stock_code.startswith('605'): return '主板' elif stock_code.startswith('000') or stock_code.startswith('001') or stock_code.startswith('002') or stock_code.startswith('003'): return '主板' elif stock_code.startswith('8') or stock_code.startswith('43'): return '新三板' else: return market or '其他' def update_stock_sectors(self, stock_codes: List[str] = None) -> int: """更新股票概念板块信息""" try: if stock_codes is None: # 获取所有股票 stock_codes = [stock['stock_code'] for stock in self._get_stock_list_from_db()] updated_count = 0 total_count = len(stock_codes) for stock_code in stock_codes: try: # 这里可以调用概念板块API获取股票所属概念 # 由于tushare概念接口限制,这里先做一些基础映射 self._update_stock_concepts(stock_code) updated_count += 1 if updated_count % 100 == 0: self.logger.info(f"已更新 {updated_count}/{total_count} 只股票的概念信息") except Exception as e: self.logger.error(f"更新股票 {stock_code} 概念信息失败: {e}") continue self.logger.info(f"完成更新 {updated_count} 只股票的概念信息") return updated_count except Exception as e: self.logger.error(f"批量更新股票概念信息失败: {e}") return 0 def _update_stock_concepts(self, stock_code: str): """更新单个股票的概念信息""" try: # 基于股票代码做一些基础的概念分类 concepts = [] # 根据股票代码前缀推断概念 if stock_code.startswith('688'): concepts.append('BK0500') # 半导体 elif stock_code.startswith('300'): concepts.append('BK0896') # 国产软件 concepts.append('BK0735') # 新基建 # 这里可以扩展更多的概念匹配逻辑 # 也可以调用第三方API获取更准确的概念分类 if concepts: self._save_stock_concepts(stock_code, concepts) except Exception as e: self.logger.error(f"更新股票 {stock_code} 概念失败: {e}") def _save_stock_concepts(self, stock_code: str, concept_codes: List[str]): """保存股票概念关联关系""" try: with self.db_manager.get_connection() as conn: cursor = conn.cursor() # 先删除现有的概念关联 cursor.execute("DELETE FROM stock_sector_relations WHERE stock_code = %s", (stock_code,)) # 添加新的概念关联 for concept_code in concept_codes: query = """ INSERT IGNORE INTO stock_sector_relations (stock_code, sector_code) VALUES (%s, %s) """ cursor.execute(query, (stock_code, concept_code)) conn.commit() cursor.close() except Exception as e: self.logger.error(f"保存股票概念关联失败: {stock_code}, 错误: {e}") def get_stock_by_industry(self, industry_code: str = None, limit: int = 100) -> List[Dict]: """根据行业获取股票列表""" try: with self.db_manager.get_connection() as conn: cursor = conn.cursor(dictionary=True) if industry_code: query = """ SELECT s.*, i.industry_name FROM stocks s LEFT JOIN industries i ON s.industry_code = i.industry_code WHERE s.industry_code = %s AND s.is_active = TRUE ORDER BY s.stock_code LIMIT %s """ cursor.execute(query, (industry_code, limit)) else: query = """ SELECT s.*, i.industry_name FROM stocks s LEFT JOIN industries i ON s.industry_code = i.industry_code WHERE s.is_active = TRUE ORDER BY s.stock_code LIMIT %s """ cursor.execute(query, (limit,)) stocks = cursor.fetchall() cursor.close() return stocks except Exception as e: self.logger.error(f"根据行业获取股票失败: {e}") return [] def get_stock_by_sector(self, sector_code: str, limit: int = 100) -> List[Dict]: """根据概念板块获取股票列表""" try: with self.db_manager.get_connection() as conn: cursor = conn.cursor(dictionary=True) query = """ SELECT s.*, sec.sector_name FROM stocks s JOIN stock_sector_relations ssr ON s.stock_code = ssr.stock_code JOIN sectors sec ON ssr.sector_code = sec.sector_code WHERE ssr.sector_code = %s AND s.is_active = TRUE ORDER BY s.stock_code LIMIT %s """ cursor.execute(query, (sector_code, limit)) stocks = cursor.fetchall() cursor.close() return stocks except Exception as e: self.logger.error(f"根据概念获取股票失败: {e}") return [] def get_industry_list(self) -> List[Dict]: """获取所有行业列表""" try: with self.db_manager.get_connection() as conn: cursor = conn.cursor(dictionary=True) query = """ SELECT i.industry_code, i.industry_name, i.level, COUNT(s.stock_code) as stock_count FROM industries i LEFT JOIN stocks s ON i.industry_code = s.industry_code AND s.is_active = TRUE GROUP BY i.industry_code, i.industry_name, i.level ORDER BY i.industry_code """ cursor.execute(query) industries = cursor.fetchall() cursor.close() return industries except Exception as e: self.logger.error(f"获取行业列表失败: {e}") return [] def get_sector_list(self) -> List[Dict]: """获取所有概念板块列表""" try: with self.db_manager.get_connection() as conn: cursor = conn.cursor(dictionary=True) query = """ SELECT s.sector_code, s.sector_name, s.description, COUNT(ssr.stock_code) as stock_count FROM sectors s LEFT JOIN stock_sector_relations ssr ON s.sector_code = ssr.sector_code LEFT JOIN stocks st ON ssr.stock_code = st.stock_code AND st.is_active = TRUE GROUP BY s.sector_code, s.sector_name, s.description ORDER BY s.sector_code """ cursor.execute(query) sectors = cursor.fetchall() cursor.close() return sectors except Exception as e: self.logger.error(f"获取概念板块列表失败: {e}") return []