# app.py (safe, use /tmp for cache)
import os
import logging
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import tempfile

# --- Put caches in a writable temp dir to avoid permission errors ---
TMP_CACHE = os.environ.get("HF_CACHE_DIR", os.path.join(tempfile.gettempdir(), "hf_cache"))
try:
    os.makedirs(TMP_CACHE, exist_ok=True)
except Exception:
    # if even this fails, fall back to the system temp dir
    TMP_CACHE = tempfile.gettempdir()

# export environment vars before importing transformers
os.environ["TRANSFORMERS_CACHE"] = TMP_CACHE
os.environ["HF_HOME"] = TMP_CACHE
os.environ["HF_DATASETS_CACHE"] = TMP_CACHE
os.environ["HF_METRICS_CACHE"] = TMP_CACHE

app = FastAPI(title="DirectEd LoRA API (safe startup)")

@app.get("/health")
def health():
    return {"ok": True}

@app.get("/")
def root():
    return {"Status": "AI backend is running"}

class Request(BaseModel):
    prompt: str
    max_new_tokens: int = 150
    temperature: float = 0.7

# Global generation pipeline; stays None if model loading fails at startup.
pipe = None

# Load the heavy model once at startup. (on_event is deprecated in newer FastAPI
# in favor of lifespan handlers, but it still works here.)
@app.on_event("startup")
def load_model():
    global pipe
    try:
        # heavy imports done during startup
        from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
        from peft import PeftModel

        BASE_MODEL = "unsloth/llama-3-8b-Instruct-bnb-4bit"
        ADAPTER_REPO = "rayymaxx/DirectEd-AI-LoRA"  # <-- replace with your adapter repo

        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            low_cpu_mem_usage=True,
            torch_dtype="auto",
        )

        # Apply the LoRA adapter weights from the Hub on top of the base model
        model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
        model.eval()

        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
        logging.info("Model and adapter loaded successfully.")
    except Exception as e:
        logging.exception("Failed to load model at startup: %s", e)
        pipe = None

@app.post("/generate")
def generate(req: Request):
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
    try:
        out = pipe(req.prompt, max_new_tokens=req.max_new_tokens, temperature=req.temperature, do_sample=True)
        return {"response": out[0]["generated_text"]}
    except Exception as e:
        logging.exception("Generation failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
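
# Example usage (a sketch; assumes the app is served locally with uvicorn on port 8000):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain LoRA in one sentence.", "max_new_tokens": 80}'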