# Install required dependencies
import subprocess
import sys

def install_packages():
    # Map pip package names to their importable module names
    # (e.g. the "protobuf" package is imported as "google.protobuf")
    packages = {
        "sentencepiece": "sentencepiece",
        "protobuf": "google.protobuf",
        "transformers": "transformers",
        "torch": "torch",
        "accelerate": "accelerate",
    }
    for package, module in packages.items():
        try:
            __import__(module)
        except ImportError:
            print(f"Installing {package}...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])

install_packages()
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
# Initialize FastAPI app
app = FastAPI(
    title="YAH Tech AI API",
    description="AI Assistant API for testing",
    version="1.0.0",
)

# Add CORS middleware so the API can be called from any origin
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class YAHBot:
    def __init__(self):
        self.repo_id = "Adedoyinjames/brain-ai"
        self.tokenizer = None
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load the model and tokenizer from the Hugging Face repo"""
        try:
            print(f"Loading AI model from {self.repo_id}...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.repo_id,
                trust_remote_code=True,  # required for phi-3
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.repo_id,
                trust_remote_code=True,  # required for phi-3
                torch_dtype=torch.float16,
                device_map="auto",
            )
            print("AI model loaded successfully from HF repo!")
        except Exception as e:
            print(f"Failed to load AI model from repo: {e}")
            self.model = None
            self.tokenizer = None
    def generate_response(self, user_input):
        """Generate a response using the causal language model"""
        if self.model is None or self.tokenizer is None:
            return "AI model is not available."
        try:
            # Format the prompt using the phi-3 chat template
            prompt = f"<|user|>\n{user_input}<|end|>\n<|assistant|>\n"
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
            )
            # Move inputs to the same device as the model
            device = next(self.model.parameters()).device
            inputs = {k: v.to(device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=150,
                    num_return_sequences=1,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,  # use EOS token for padding
                    eos_token_id=self.tokenizer.eos_token_id,
                )
            # Decode only the newly generated tokens so the prompt is not echoed back
            new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        except Exception as e:
            print(f"Model error: {str(e)}")
            return "I apologize, but I'm having trouble processing your question right now."
# Initialize the bot globally
yah_bot = YAHBot()

# Request/Response models
class ChatRequest(BaseModel):
    message: str

class ChatResponse(BaseModel):
    response: str
    status: str
    timestamp: float

class HealthResponse(BaseModel):
    status: str
    service: str
    timestamp: float
# API Endpoints
@app.get("/")
async def root():
    return {
        "message": "YAH Tech AI API is running",
        "status": "active",
        "model_repo": yah_bot.repo_id,
        "model_type": "causal_lm",
        "endpoints": {
            "chat": "POST /api/chat",
            "health": "GET /api/health",
        },
    }
@app.post("/api/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Main chat endpoint - send a message and get an AI response"""
    try:
        response = yah_bot.generate_response(request.message)
        return ChatResponse(
            response=response,
            status="success",
            timestamp=time.time(),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    return HealthResponse(
        status="healthy",
        service="YAH Tech AI API",
        timestamp=time.time(),
    )
# For Hugging Face Spaces
def get_app():
    return app

if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
    )
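
# Example client usage, as a minimal sketch: it assumes the server is running
# locally on port 7860 and that the `requests` package is installed
# (it is not one of the dependencies installed above).
#
#   import requests
#   r = requests.post(
#       "http://localhost:7860/api/chat",
#       json={"message": "Hello!"},
#   )
#   print(r.json()["response"])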