# Install required dependencies
import subprocess
import sys

def install_packages():
    # Map pip package names to their import names where they differ
    # (the "protobuf" wheel is imported as "google.protobuf", so checking
    # __import__("protobuf") would reinstall it on every start).
    packages = {
        "sentencepiece": "sentencepiece",
        "protobuf": "google.protobuf",
        "transformers": "transformers",
        "torch": "torch",
        "accelerate": "accelerate",
    }
    for pip_name, import_name in packages.items():
        try:
            __import__(import_name)
        except ImportError:
            print(f"Installing {pip_name}...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name])

install_packages()
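# Note: on Hugging Face Spaces these dependencies are usually pinned in
# requirements.txt and installed at build time; the runtime installer
# above is a fallback for ad-hoc environments.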
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
from fastapi.middleware.cors import CORSMiddleware
# Initialize FastAPI app
app = FastAPI(
    title="YAH Tech AI API",
    description="AI Assistant API for testing",
    version="1.0.0",
)
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
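# Note: allow_origins=["*"] is convenient while testing but accepts
# requests from any site; restrict it to known frontend origins before
# production use.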
class YAHBot:
    def __init__(self):
        self.repo_id = "Adedoyinjames/brain-ai"
        self.tokenizer = None
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load the model from your Hugging Face repo"""
        try:
            print(f"🚀 Loading AI model from {self.repo_id}...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.repo_id,
                trust_remote_code=True,  # Required for phi-3
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.repo_id,
                trust_remote_code=True,  # Required for phi-3
                torch_dtype=torch.float16,
                device_map="auto",
            )
            print("✅ AI model loaded successfully from HF repo!")
        except Exception as e:
            print(f"❌ Failed to load AI model from repo: {e}")
            self.model = None
            self.tokenizer = None
    def generate_response(self, user_input):
        """Generate response using causal language model"""
        if self.model and self.tokenizer:
            try:
                # Format prompt for phi-3 (causal LM)
                prompt = f"<|user|>\n{user_input}<|end|>\n<|assistant|>\n"
                inputs = self.tokenizer(
                    prompt,
                    return_tensors="pt",
                    max_length=512,
                    truncation=True,
                )
                # Move to same device as model
                device = next(self.model.parameters()).device
                inputs = {k: v.to(device) for k, v in inputs.items()}
                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,  # input_ids plus attention_mask
                        max_new_tokens=150,
                        num_return_sequences=1,
                        temperature=0.7,
                        do_sample=True,
                        pad_token_id=self.tokenizer.eos_token_id,  # Use EOS token for padding
                        eos_token_id=self.tokenizer.eos_token_id,
                    )
                # Decode only the newly generated tokens so the prompt is
                # not echoed back in the response
                prompt_length = inputs["input_ids"].shape[1]
                response = self.tokenizer.decode(
                    outputs[0][prompt_length:], skip_special_tokens=True
                )
                return response.strip()
            except Exception as e:
                print(f"Model error: {str(e)}")
                return "I apologize, but I'm having trouble processing your question right now."
        return "AI model is not available."
# Initialize the bot globally
yah_bot = YAHBot()

# Request/Response models
class ChatRequest(BaseModel):
    message: str

class ChatResponse(BaseModel):
    response: str
    status: str
    timestamp: float

class HealthResponse(BaseModel):
    status: str
    service: str
    timestamp: float
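# Illustrative request/response bodies for POST /api/chat, matching the
# models above (values are examples, not real output):
#   request:  {"message": "Hello!"}
#   response: {"response": "Hi! How can I help?", "status": "success",
#              "timestamp": 1700000000.0}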
# API Endpoints
@app.get("/")
async def root():
    return {
        "message": "YAH Tech AI API is running",
        "status": "active",
        "model_repo": yah_bot.repo_id,
        "model_type": "causal_lm",
        "endpoints": {
            "chat": "POST /api/chat",
            "health": "GET /api/health",
        },
    }
@app.post("/api/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
"""
Main chat endpoint - Send a message and get AI response
"""
try:
response = yah_bot.generate_response(request.message)
return ChatResponse(
response=response,
status="success",
timestamp=time.time()
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
@app.get("/api/health", response_model=HealthResponse)
async def health_check():
return HealthResponse(
status="healthy",
service="YAH Tech AI API",
timestamp=time.time()
)
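# Example calls once the server is running (host/port per the uvicorn
# settings below; adjust for your deployment):
#   curl http://localhost:7860/api/health
#   curl -X POST http://localhost:7860/api/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello!"}'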
# For Hugging Face Spaces
def get_app():
    return app

if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
    )