File size: 5,321 Bytes
900a36d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc04327
 
 
f5215d5
900a36d
3cf6bb7
fc04327
f5215d5
ecc1ef6
 
 
 
 
 
f5215d5
ecc1ef6
 
 
 
 
 
 
 
f5215d5
 
 
900a36d
f5215d5
 
 
 
 
ecc1ef6
f5215d5
ecc1ef6
900a36d
 
 
 
 
 
 
 
 
 
ecc1ef6
f5215d5
ecc1ef6
 
 
1eb8b02
94763f9
900a36d
f5215d5
 
900a36d
 
6010b3c
1eb8b02
 
 
 
 
 
 
f5215d5
900a36d
 
 
 
f5215d5
 
1eb8b02
900a36d
f5215d5
94763f9
f5215d5
900a36d
 
f5215d5
 
 
900a36d
 
 
 
 
94763f9
f5215d5
 
a4801e5
ecc1ef6
1eb8b02
ecc1ef6
fc04327
a4801e5
 
f5215d5
fc04327
 
 
a4801e5
fc04327
 
 
 
a4801e5
fc04327
 
 
 
94763f9
fc04327
 
 
 
ecc1ef6
fc04327
900a36d
 
fc04327
 
ecc1ef6
fc04327
 
 
 
 
 
 
 
 
 
3316b4a
fc04327
 
 
 
94763f9
fc04327
 
 
 
 
 
 
 
 
 
 
 
 
 
f5215d5
 
fc04327
 
 
 
 
43d9eff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# Install required dependencies
import subprocess
import sys

def install_packages():
    """Install any missing runtime dependencies with pip.

    Probes each package by importing its *module* name, which can differ
    from the pip distribution name (e.g. the ``protobuf`` distribution is
    imported as ``google.protobuf``). Probing the pip name, as the old
    code did, made the protobuf check fail every time and triggered a
    pointless reinstall on every startup.
    """
    # pip distribution name -> importable module name
    packages = {
        "sentencepiece": "sentencepiece",
        "protobuf": "google.protobuf",
        "transformers": "transformers",
        "torch": "torch",
        "accelerate": "accelerate",
    }
    for pip_name, module_name in packages.items():
        try:
            __import__(module_name)
        except ImportError:
            print(f"Installing {pip_name}...")
            # shell=False list form; sys.executable pins the right interpreter
            subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name])

# Run at import time so the heavyweight imports below can't fail on a fresh box.
install_packages()

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
from fastapi.middleware.cors import CORSMiddleware

# Initialize FastAPI app
app = FastAPI(
    title="YAH Tech AI API",
    description="AI Assistant API for testing",
    version="1.0.0"
)

# Add CORS middleware
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers per the CORS spec (credentials require an explicit
# origin). Fine for open testing; tighten before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class YAHBot:
    """Chat bot backed by a causal LM pulled from the Hugging Face Hub.

    The model is loaded eagerly in ``__init__``; on any load failure the
    model and tokenizer stay ``None`` and ``generate_response`` degrades
    to a fixed "not available" message instead of raising.
    """

    def __init__(self):
        self.repo_id = "Adedoyinjames/brain-ai"  # HF Hub repo to load from
        self.tokenizer = None
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load tokenizer and model from the Hugging Face repo (best effort)."""
        try:
            print(f"πŸ”„ Loading AI model from {self.repo_id}...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.repo_id,
                trust_remote_code=True  # Required for phi-3
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.repo_id,
                trust_remote_code=True,  # Required for phi-3
                torch_dtype=torch.float16,
                device_map="auto"
            )
            print("βœ… AI model loaded successfully from HF repo!")
        except Exception as e:
            # Keep the service up even if the model can't be fetched.
            print(f"❌ Failed to load AI model from repo: {e}")
            self.model = None
            self.tokenizer = None

    def generate_response(self, user_input):
        """Generate a reply for *user_input*, or a fallback string.

        Never raises to the caller: model errors are logged and mapped to
        an apology string; a missing model yields a fixed message.
        """
        if self.model and self.tokenizer:
            try:
                # phi-3 style chat template for a single user turn.
                prompt = f"<|user|>\n{user_input}<|end|>\n<|assistant|>\n"

                encoded = self.tokenizer(
                    prompt,
                    return_tensors="pt",
                    max_length=512,
                    truncation=True,
                    padding=True
                )

                # Move tensors to the same device as the model weights.
                device = next(self.model.parameters()).device
                encoded = {k: v.to(device) for k, v in encoded.items()}

                # BUG FIX: after the dict comprehension above, `encoded` is a
                # plain dict, so the old attribute access (`inputs.input_ids`)
                # raised AttributeError on every call — index by key instead.
                input_ids = encoded["input_ids"]

                with torch.no_grad():
                    outputs = self.model.generate(
                        input_ids,
                        # Pass the mask explicitly so padded positions are ignored.
                        attention_mask=encoded.get("attention_mask"),
                        max_new_tokens=150,
                        num_return_sequences=1,
                        temperature=0.7,
                        do_sample=True,
                        pad_token_id=self.tokenizer.eos_token_id,  # Use EOS token for padding
                        eos_token_id=self.tokenizer.eos_token_id,
                    )

                # Decode only the newly generated tokens. String-replacing the
                # prompt (the old approach) is unreliable: skip_special_tokens
                # strips the <|user|>/<|end|> markers, so the prompt text never
                # matches exactly and leaks into the response.
                new_tokens = outputs[0][input_ids.shape[1]:]
                return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

            except Exception as e:
                print(f"Model error: {str(e)}")
                return "I apologize, but I'm having trouble processing your question right now."

        return "AI model is not available."

# Initialize the bot globally
# NOTE: the model download/load happens here at import time; a failed load
# leaves yah_bot.model as None and the chat endpoint degrades gracefully.
yah_bot = YAHBot()

# Request/Response models
class ChatRequest(BaseModel):
    """Request body for POST /api/chat."""
    message: str  # the user's chat message to answer

class ChatResponse(BaseModel):
    """Response body for POST /api/chat."""
    response: str  # generated assistant reply (or fallback message)
    status: str  # "success" on the happy path
    timestamp: float  # Unix epoch seconds at response time

class HealthResponse(BaseModel):
    """Response body for GET /api/health."""
    status: str  # always "healthy" when the process is serving
    service: str  # service display name
    timestamp: float  # Unix epoch seconds at response time

# API Endpoints
@app.get("/")
async def root():
    """Service metadata and an endpoint directory, for quick smoke checks."""
    info = {
        "message": "YAH Tech AI API is running",
        "status": "active",
        "model_repo": yah_bot.repo_id,
        "model_type": "causal_lm",
    }
    info["endpoints"] = {
        "chat": "POST /api/chat",
        "health": "GET /api/health",
    }
    return info

@app.post("/api/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """
    Main chat endpoint - Send a message and get AI response
    """
    try:
        reply = yah_bot.generate_response(request.message)
        result = ChatResponse(
            response=reply,
            status="success",
            timestamp=time.time(),
        )
    except Exception as e:
        # Map any unexpected failure to a 500 with the underlying reason.
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
    return result

@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """Lightweight liveness probe for uptime monitors."""
    payload = {
        "status": "healthy",
        "service": "YAH Tech AI API",
        "timestamp": time.time(),
    }
    return HealthResponse(**payload)

# For Hugging Face Spaces
def get_app():
    """Return the module-level FastAPI app (ASGI factory hook for Spaces)."""
    return app

if __name__ == "__main__":
    # Port 7860 is the port Hugging Face Spaces exposes by default.
    uvicorn.run(
        app, 
        host="0.0.0.0", 
        port=7860,
        log_level="info"
    )