Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| from datetime import datetime, timedelta, date | |
| import numpy as np | |
| import asyncio | |
| import threading | |
| import time | |
| import yfinance as yf | |
# Placeholders for the US stock index data.
# They stay None until init_stock_index_data() assigns DataFrames at runtime.
index_us_stock_index_INX = index_us_stock_index_DJI = None
index_us_stock_index_IXIC = index_us_stock_index_NDX = None
def init_stock_index_data():
    """Load recent history for four US stock indices through yfinance.

    Requests the last eight weeks of data for the S&P 500, Dow Jones,
    NASDAQ Composite and NASDAQ 100.  Each non-empty result is reshaped
    into a DataFrame with the columns ``date``, ``开盘``, ``收盘``, ``最高``,
    ``最低``, ``成交量`` and ``成交额`` and stored in the matching
    ``index_us_stock_index_*`` module global.  An index whose request
    fails or returns no rows is stored as an empty DataFrame, and any
    unexpected error resets all four globals to empty DataFrames.
    """
    global index_us_stock_index_INX, index_us_stock_index_DJI, index_us_stock_index_IXIC, index_us_stock_index_NDX

    try:
        # Date window: from eight weeks ago up to now.
        window_end = datetime.now()
        window_start = window_end - timedelta(weeks=8)

        # (yfinance symbol, short name used for the module globals)
        symbol_pairs = (
            ('^GSPC', 'INX'),   # S&P 500
            ('^DJI', 'DJI'),    # Dow Jones
            ('^IXIC', 'IXIC'),  # NASDAQ Composite
            ('^NDX', 'NDX'),    # NASDAQ 100
        )

        frames = {}
        for yf_symbol, var_name in symbol_pairs:
            try:
                print(f"Fetching {var_name} data using yfinance...")
                history = yf.Ticker(yf_symbol).history(start=window_start, end=window_end)

                if history.empty:
                    print(f"No data for {yf_symbol}")
                    frames[var_name] = pd.DataFrame()
                    continue

                # Reshape into the column layout the rest of the code expects.
                table = pd.DataFrame({
                    'date': history.index.strftime('%Y-%m-%d'),
                    '开盘': history['Open'].values,
                    '收盘': history['Close'].values,
                    '最高': history['High'].values,
                    '最低': history['Low'].values,
                    '成交量': history['Volume'].values,
                    # Turnover is computed as close price times volume.
                    '成交额': (history['Close'] * history['Volume']).values
                })
                frames[var_name] = table
                print(f"Successfully fetched {var_name}: {len(table)} records")
            except Exception as e:
                print(f"Error fetching {yf_symbol}: {e}")
                frames[var_name] = pd.DataFrame()

        # Publish the results through the module globals.
        (
            index_us_stock_index_INX,
            index_us_stock_index_DJI,
            index_us_stock_index_IXIC,
            index_us_stock_index_NDX,
        ) = (
            frames.get(short_name, pd.DataFrame())
            for short_name in ('INX', 'DJI', 'IXIC', 'NDX')
        )
        print("Stock indices initialized successfully using yfinance")
    except Exception as e:
        print(f"Error initializing stock indices: {e}")
        # Fall back to empty DataFrames for every index.
        (
            index_us_stock_index_INX,
            index_us_stock_index_DJI,
            index_us_stock_index_IXIC,
            index_us_stock_index_NDX,
        ) = (pd.DataFrame() for _ in range(4))
def delayed_init_indices(delay_seconds=5):
    """Sleep for a while, then initialise the stock index data.

    Args:
        delay_seconds: Number of seconds to wait before calling
            init_stock_index_data().  Defaults to 5, the delay that
            used to be hard-coded, so existing no-argument callers
            behave as before.
    """
    time.sleep(delay_seconds)
    init_stock_index_data()
# Run the delayed initialisation in a background daemon thread.
init_thread = threading.Thread(
    target=delayed_init_indices,
    daemon=True,
)
init_thread.start()
| # 下面是原有的其他函数,保持不变... | |
| # 新的文本时间处理函数 | |
def parse_time(time_str):
    """Normalise a date expression to a ``YYYY-MM-DD`` string.

    The relative Chinese expressions 昨天/昨日 (yesterday), 今天/今日
    (today), 前天 (two days ago), 上周 (one week ago) and 上月 (taken as
    30 days ago) are recognised anywhere in the text.  Otherwise the
    text, with surrounding whitespace removed, is tried against a fixed
    list of date formats in order; ambiguous slash/dash dates therefore
    resolve as month-first before day-first.

    Args:
        time_str: Date text to interpret.  ``date`` / ``datetime``
            objects are also accepted and formatted directly.

    Returns:
        The normalised date string; ``None`` when ``time_str`` is empty
        or ``None``; today's date when the input cannot be interpreted.
    """
    if not time_str:
        return None

    today = date.today()

    # datetime is a subclass of date, so this covers both.
    if isinstance(time_str, date):
        return time_str.strftime('%Y-%m-%d')
    if not isinstance(time_str, str):
        # Unsupported input type: use the same fallback as unparseable text.
        return today.strftime('%Y-%m-%d')

    # Relative expressions, checked in the original priority order.
    relative_offsets = (
        (('昨天', '昨日'), timedelta(days=1)),
        (('今天', '今日'), timedelta(0)),
        (('前天',), timedelta(days=2)),
        (('上周',), timedelta(weeks=1)),
        (('上月',), timedelta(days=30)),
    )
    for keywords, offset in relative_offsets:
        if any(keyword in time_str for keyword in keywords):
            return (today - offset).strftime('%Y-%m-%d')

    # Explicit dates: strip first so padded input is not silently
    # replaced by today's date.
    candidate = time_str.strip()
    formats = ('%Y-%m-%d', '%Y/%m/%d', '%m/%d/%Y', '%m-%d-%Y', '%d/%m/%Y', '%d-%m-%Y')
    for fmt in formats:
        try:
            return datetime.strptime(candidate, fmt).strftime('%Y-%m-%d')
        except ValueError:
            continue

    # Nothing matched: fall back to today's date.
    return today.strftime('%Y-%m-%d')
| # 原有的其他函数... | |
def preprocess_news_text(text):
    """Clean news text: collapse whitespace runs to single spaces and lowercase."""
    # split() with no argument drops leading/trailing whitespace and
    # treats any run of whitespace as one separator.
    return ' '.join(text.split()).lower()
def extract_sentiment_score(text):
    """Return a crude keyword-based sentiment score in ``[-1.0, 1.0]``.

    This is a placeholder for a real sentiment model.  The text is split
    into lowercase alphanumeric tokens and each lexicon keyword counts at
    most once when it appears as a whole token, optionally with a simple
    inflection suffix (``s``, ``es``, ``d``, ``ed``, ``ing``).  Matching
    whole tokens avoids false hits such as 'up' inside "group" or 'gain'
    inside "against".

    Args:
        text: Text to score.

    Returns:
        ``0.2`` per matched positive keyword (capped at ``1.0``) when
        positives outnumber negatives, ``-0.2`` per matched negative
        keyword (floored at ``-1.0``) when negatives outnumber positives,
        and ``0.0`` for a tie or empty input.
    """
    if not text:
        return 0.0

    positive_words = ['good', 'great', 'excellent', 'positive', 'growth', 'profit', 'gain', 'rise', 'up']
    negative_words = ['bad', 'poor', 'negative', 'loss', 'decline', 'fall', 'down', 'crash']
    suffixes = ('', 's', 'es', 'd', 'ed', 'ing')

    # Replace every non-alphanumeric character with a space, then split,
    # so punctuation never sticks to a word.
    tokens = set(
        ''.join(ch if ch.isalnum() else ' ' for ch in text.lower()).split()
    )

    def count_hits(words):
        return sum(
            1 for word in words
            if any(word + suffix in tokens for suffix in suffixes)
        )

    positive_count = count_hits(positive_words)
    negative_count = count_hits(negative_words)

    if positive_count > negative_count:
        return min(1.0, positive_count * 0.2)
    if negative_count > positive_count:
        return max(-1.0, -negative_count * 0.2)
    return 0.0
def _latest_rsi(prices, window=14):
    """Return the latest RSI of *prices* using simple rolling averages.

    Returns the neutral value 50.0 when fewer than *window* prices are
    available or when the latest window contains no price movement.
    """
    if len(prices) < window:
        return 50.0

    delta = prices.diff()
    gain = delta.where(delta > 0, 0)
    loss = -delta.where(delta < 0, 0)
    avg_gain = gain.rolling(window=window).mean().iloc[-1]
    avg_loss = loss.rolling(window=window).mean().iloc[-1]

    if avg_loss == 0:
        # Avoid dividing by zero: all gains means RSI 100, a completely
        # flat window (0/0) is treated as neutral instead of NaN.
        return 100.0 if avg_gain > 0 else 50.0

    rs = avg_gain / avg_loss
    return float(100 - (100 / (1 + rs)))


def calculate_technical_indicators(price_data):
    """Compute basic technical indicators from a price DataFrame.

    Args:
        price_data: DataFrame with a ``close`` column, oldest row first.

    Returns:
        An empty dict when ``price_data`` is ``None`` or empty, otherwise
        a dict of plain floats with the keys:

        * ``sma_5`` / ``sma_10``: latest 5- and 10-period simple moving
          average of the close, or the latest close when there are not
          enough rows.
        * ``rsi``: latest 14-period RSI (50.0 when it cannot be computed
          from the available rows or the window is flat).
        * ``price_change_pct``: percentage change from the first to the
          last close; 0.0 when there is only one row or the first close
          is zero, where the change is undefined.

    Raises:
        KeyError: If ``price_data`` has no ``close`` column.
    """
    if price_data is None or price_data.empty:
        return {}

    close_prices = price_data['close']
    row_count = len(close_prices)
    first_close = close_prices.iloc[0]
    last_close = close_prices.iloc[-1]

    # Simple moving averages
    sma_5 = close_prices.rolling(window=5).mean().iloc[-1] if row_count >= 5 else last_close
    sma_10 = close_prices.rolling(window=10).mean().iloc[-1] if row_count >= 10 else last_close

    # Relative strength index
    rsi = _latest_rsi(close_prices)

    # Percentage price change over the whole series
    if row_count > 1 and first_close != 0:
        price_change = (last_close - first_close) / first_close * 100
    else:
        price_change = 0.0

    return {
        'sma_5': float(sma_5),
        'sma_10': float(sma_10),
        'rsi': float(rsi),
        'price_change_pct': float(price_change)
    }
def normalize_features(features_dict):
    """Squash feature values into the ``[-1, 1]`` range.

    Args:
        features_dict: Mapping of feature name to value.

    Returns:
        A new dict with the same keys.  Numeric, non-NaN values
        (including NumPy integer and floating scalars) are scaled by key:

        * ``'rsi'``: ``(value - 50) / 50``.
        * keys ending in ``'_pct'``: ``tanh(value / 100)``.
        * everything else: ``tanh(value / 1000)``.

        Non-numeric or NaN values become ``0.0``.
    """
    # NumPy integer scalars such as np.int64 are not instances of the
    # builtin int, so they are listed explicitly; otherwise real values
    # would be silently replaced by 0.0.
    numeric_types = (int, float, np.integer, np.floating)

    normalized = {}
    for key, value in features_dict.items():
        if not isinstance(value, numeric_types) or pd.isna(value):
            normalized[key] = 0.0
        elif key == 'rsi':
            normalized[key] = float((value - 50) / 50)
        elif key.endswith('_pct'):
            normalized[key] = float(np.tanh(value / 100))
        else:
            normalized[key] = float(np.tanh(value / 1000))
    return normalized
| # 主要的预处理函数 | |
def preprocess_for_model(news_text, stock_symbol, news_date):
    """Build the basic feature dict the model consumes for one news item.

    Cleans the text, normalises the date and derives a sentiment score.
    If any of those steps raises, the error is printed and a fallback
    dict is returned holding the raw text, a 0.0 score and today's date.
    Stock price features are not added here, to avoid a circular import.

    Returns:
        A dict with the keys ``processed_text``, ``sentiment_score``,
        ``news_date`` and ``stock_symbol``.
    """
    try:
        cleaned_text = preprocess_news_text(news_text)
        normalized_date = parse_time(news_date)
        score = extract_sentiment_score(cleaned_text)
    except Exception as e:
        print(f"Error in preprocess_for_model: {e}")
        fallback = {
            'processed_text': news_text,
            'sentiment_score': 0.0,
            'news_date': date.today().strftime('%Y-%m-%d'),
            'stock_symbol': stock_symbol
        }
        return fallback

    features = {
        'processed_text': cleaned_text,
        'sentiment_score': score,
        'news_date': normalized_date,
        'stock_symbol': stock_symbol
    }
    return features
if __name__ == "__main__":
    # Smoke test: run the preprocessing once on a sample headline.
    sample_headline = "Apple Inc. reported strong quarterly earnings, beating expectations."
    outcome = preprocess_for_model(sample_headline, "AAPL", "2024-02-14")
    print(f"Preprocessing result: {outcome}")