""" 股票数据访问对象 """ from typing import Dict, List, Optional, Tuple import json from datetime import datetime, date from .base_dao import BaseDAO class StockDAO(BaseDAO): """股票数据访问对象""" def get_stock_by_code(self, stock_code: str) -> Optional[Dict]: """根据股票代码获取股票信息""" query = "SELECT * FROM stocks WHERE stock_code = %s" return self._execute_single_query(query, (stock_code,)) def add_or_update_stock(self, stock_code: str, stock_name: str, market: str) -> int: """添加或更新股票信息""" existing = self.get_stock_by_code(stock_code) if existing: # 更新现有股票 query = """ UPDATE stocks SET stock_name = %s, market = %s, updated_at = CURRENT_TIMESTAMP WHERE stock_code = %s """ self._execute_update(query, (stock_name, market, stock_code)) return existing['id'] else: # 添加新股票 query = """ INSERT INTO stocks (stock_code, stock_name, market) VALUES (%s, %s, %s) """ return self._execute_insert(query, (stock_code, stock_name, market)) def get_stock_data(self, stock_code: str, data_date: str = None) -> Optional[Dict]: """获取股票数据""" if data_date is None: data_date = self.get_today_date() query = """ SELECT sd.*, s.stock_name FROM stock_data sd JOIN stocks s ON sd.stock_code = s.stock_code WHERE sd.stock_code = %s AND sd.data_date = %s """ return self._execute_single_query(query, (stock_code, data_date)) def save_stock_data(self, stock_code: str, stock_info: Dict, data_date: str = None) -> bool: """保存股票数据""" if data_date is None: data_date = self.get_today_date() try: # 确保股票信息存在 self.add_or_update_stock( stock_code, stock_info.get('name', ''), 'SH' if stock_code.startswith('6') else 'SZ' ) # 检查是否已存在当日数据 existing = self.get_stock_data(stock_code, data_date) if existing: # 更新现有数据 query = """ UPDATE stock_data SET price = %s, change_percent = %s, market_value = %s, pe_ratio = %s, pb_ratio = %s, ps_ratio = %s, dividend_yield = %s, roe = %s, gross_profit_margin = %s, net_profit_margin = %s, debt_to_assets = %s, revenue_yoy = %s, net_profit_yoy = %s, bps = %s, ocfps = %s, from_cache = %s, updated_at = CURRENT_TIMESTAMP WHERE stock_code = %s AND data_date = %s """ self._execute_update(query, ( self.parse_float(stock_info.get('price')), self.parse_float(stock_info.get('change_percent')), self.parse_float(stock_info.get('market_value')), self.parse_float(stock_info.get('pe_ratio')), self.parse_float(stock_info.get('pb_ratio')), self.parse_float(stock_info.get('ps_ratio')), self.parse_float(stock_info.get('dividend_yield')), self.parse_float(stock_info.get('roe')), self.parse_float(stock_info.get('gross_profit_margin')), self.parse_float(stock_info.get('net_profit_margin')), self.parse_float(stock_info.get('debt_to_assets')), self.parse_float(stock_info.get('revenue_yoy')), self.parse_float(stock_info.get('net_profit_yoy')), self.parse_float(stock_info.get('bps')), self.parse_float(stock_info.get('ocfps')), bool(stock_info.get('from_cache', False)), stock_code, data_date )) else: # 插入新数据 query = """ INSERT INTO stock_data ( stock_code, data_date, price, change_percent, market_value, pe_ratio, pb_ratio, ps_ratio, dividend_yield, roe, gross_profit_margin, net_profit_margin, debt_to_assets, revenue_yoy, net_profit_yoy, bps, ocfps, from_cache ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) """ self._execute_insert(query, ( stock_code, data_date, self.parse_float(stock_info.get('price')), self.parse_float(stock_info.get('change_percent')), self.parse_float(stock_info.get('market_value')), self.parse_float(stock_info.get('pe_ratio')), self.parse_float(stock_info.get('pb_ratio')), self.parse_float(stock_info.get('ps_ratio')), self.parse_float(stock_info.get('dividend_yield')), self.parse_float(stock_info.get('roe')), self.parse_float(stock_info.get('gross_profit_margin')), self.parse_float(stock_info.get('net_profit_margin')), self.parse_float(stock_info.get('debt_to_assets')), self.parse_float(stock_info.get('revenue_yoy')), self.parse_float(stock_info.get('net_profit_yoy')), self.parse_float(stock_info.get('bps')), self.parse_float(stock_info.get('ocfps')), bool(stock_info.get('from_cache', False)) )) return True except Exception as e: self.logger.error(f"保存股票数据失败: {stock_code}, 错误: {e}") self.log_data_update('stock_data', stock_code, 'failed', str(e)) return False def get_latest_stock_data(self, stock_code: str) -> Optional[Dict]: """获取最新的股票数据""" query = """ SELECT sd.*, s.stock_name FROM stock_data sd JOIN stocks s ON sd.stock_code = s.stock_code WHERE sd.stock_code = %s ORDER BY sd.data_date DESC LIMIT 1 """ return self._execute_single_query(query, (stock_code,)) def get_multiple_stocks_data(self, stock_codes: List[str], data_date: str = None) -> List[Dict]: """批量获取股票数据""" if not stock_codes: return [] if data_date is None: data_date = self.get_today_date() placeholders = ','.join(['%s'] * len(stock_codes)) query = f""" SELECT sd.*, s.stock_name FROM stock_data sd JOIN stocks s ON sd.stock_code = s.stock_code WHERE sd.stock_code IN ({placeholders}) AND sd.data_date = %s """ return self._execute_query(query, tuple(stock_codes + [data_date])) def get_stock_data_history(self, stock_code: str, days: int = 30) -> List[Dict]: """获取股票历史数据""" query = """ SELECT sd.*, s.stock_name FROM stock_data sd JOIN stocks s ON sd.stock_code = s.stock_code WHERE sd.stock_code = %s AND sd.data_date >= DATE_SUB(CURDATE(), INTERVAL %s DAY) ORDER BY sd.data_date DESC """ return self._execute_query(query, (stock_code, days)) def delete_stock_data(self, stock_code: str, before_date: str = None) -> int: """删除股票数据""" if before_date: query = "DELETE FROM stock_data WHERE stock_code = %s AND data_date < %s" return self._execute_update(query, (stock_code, before_date)) else: query = "DELETE FROM stock_data WHERE stock_code = %s" return self._execute_update(query, (stock_code,)) def get_stock_count(self) -> int: """获取股票总数""" query = "SELECT COUNT(*) as count FROM stocks" result = self._execute_single_query(query) return result['count'] if result else 0 def get_data_date_range(self) -> Optional[Dict]: """获取数据的日期范围""" query = "SELECT MIN(data_date) as min_date, MAX(data_date) as max_date FROM stock_data" return self._execute_single_query(query)