Spaces:
Running
Running
import logging
import os

from fastapi import FastAPI
from fastapi.middleware.wsgi import WSGIMiddleware
from pydantic import BaseModel
from transformers import pipeline

from RequestModel import PredictRequest
from us_stock import fetch_symbols
# Create the FastAPI application instance used by the rest of this module.
app = FastAPI()
# Request model: the body accepted by the text-processing handlers below.
class TextRequest(BaseModel):
    # The input string that the handlers append a suffix to.
    text: str
# API handler: appends the literal suffix 'aaa' to the submitted text and
# returns it under the "result" key.
# NOTE(review): no route decorator is visible on this function in this view;
# confirm it is actually registered on `app`.
async def api_aaa_post(request: TextRequest):
    suffixed = request.text + "aaa"
    return dict(result=suffixed)
# API handler: same contract as the other "aaa" handlers in this module;
# returns {"result": <text + 'aaa'>}.
# NOTE(review): no route decorator is visible on this function in this view;
# confirm it is actually registered on `app`.
async def aaa(request: TextRequest):
    return {"result": request.text + "aaa"}
# API handler: reads `request.text`, appends 'aaa', and wraps the outcome in
# a single-key response dictionary.
# NOTE(review): no route decorator is visible on this function in this view;
# confirm it is actually registered on `app`.
async def api_aaa_get(request: TextRequest):
    response = {}
    response["result"] = request.text + "aaa"
    return response
# API handler: appends the literal suffix 'bbb' to the submitted text and
# returns it under the "result" key.
# NOTE(review): no route decorator is visible on this function in this view;
# confirm it is actually registered on `app`.
async def api_bbb(request: TextRequest):
    text_with_suffix = request.text + "bbb"
    return {"result": text_with_suffix}
# Awaits `fetch_symbols()` so that symbol data is initialised; the original
# comment states this is meant to run when FastAPI starts.
# NOTE(review): no startup-hook registration is visible in this view; confirm
# this coroutine is actually wired into the application's startup.
async def initialize_symbols():
    await fetch_symbols()
# Prediction handler: forwards the request's text and stock codes to
# `blkeras.predict` and returns its result.
#
# Parameters:
#   request: parsed PredictRequest; its `text` and `stock_codes` attributes
#            are read.
# Returns:
#   Whatever `blkeras.predict` returns, or {"error": <message>} when any
#   exception is raised while reading the request or running the prediction.
# NOTE(review): no route decorator is visible on this function in this view;
# confirm it is actually registered on `app`.
async def predict(request: PredictRequest):
    # Imported at call time under an alias so the model function does not
    # reuse this handler's own name inside its body.
    from blkeras import predict as run_prediction

    try:
        input_text = request.text
        affected_stock_codes = request.stock_codes
        # Slicing already clamps to the string length, so no length check is
        # needed to show at most the first 200 characters.
        print("Input text:", input_text[:200])
        print("Affected stock codes:", affected_stock_codes)
        return run_prediction(input_text, affected_stock_codes)
    except Exception as e:
        # Keep the best-effort error payload for callers, but record the full
        # traceback so failures are diagnosable on the server side.
        logging.getLogger(__name__).exception("predict handler failed")
        return {"error": str(e)}
# Landing handler: returns a static welcome message that points callers at
# the /api/aaa and /api/bbb paths.
# NOTE(review): no route decorator is visible on this function in this view;
# confirm it is actually registered on `app`.
async def root():
    welcome = "Welcome to the API. Use /api/aaa or /api/bbb for processing."
    return {"message": welcome}
# Script entry point: when this file is executed directly, serve `app` with
# uvicorn on all interfaces (0.0.0.0), port 7860.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )