Spaces: Runtime error

# app.py — FastAPI backend with simple rate limiting and history persistence
import os, json, logging, time
from pathlib import Path
from typing import List, Any

import torch
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import FileResponse, JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# ----- config -----
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("colabmind")

BASE_MODEL = "deepseek-ai/deepseek-coder-6.7b-base"
ADAPTER_REPO = "Agasthya0/colabmind-coder-6.7b-ml-qlora"

ROOT = Path(__file__).parent
STATIC_DIR = ROOT  # index.html at repo root
HISTORY_PATH = ROOT / "history.jsonl"
OFFLOAD_DIR = "/tmp/llm_offload"
os.makedirs(OFFLOAD_DIR, exist_ok=True)

# ----- rate limiter (simple sliding window per IP) -----
RATE_WINDOW = 60    # seconds
MAX_REQUESTS = 12   # per window per IP (adjust as needed)
_client_requests = {}  # ip -> [timestamps...]

def allowed_request(client_ip: str) -> bool:
    now = time.time()
    arr = _client_requests.setdefault(client_ip, [])
    # remove old
    while arr and arr[0] <= now - RATE_WINDOW:
        arr.pop(0)
    if len(arr) >= MAX_REQUESTS:
        return False
    arr.append(now)
    return True
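
# Illustrative only (not part of the original app): a quick sanity check of the
# sliding-window limiter, e.g. from a REPL or a test file.
#
#   for _ in range(MAX_REQUESTS):
#       assert allowed_request("127.0.0.1")           # first 12 calls in the window pass
#   assert allowed_request("127.0.0.1") is False      # the 13th is refused
#   # once RATE_WINDOW seconds elapse, the old timestamps expire and calls pass again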

# ----- fastapi app -----
app = FastAPI(title="ColabMind Coder API")
# mount static assets only if a static/ folder exists (StaticFiles raises at startup otherwise)
if (STATIC_DIR / "static").is_dir():
    app.mount("/static", StaticFiles(directory=STATIC_DIR / "static"), name="static")

class PredictRequest(BaseModel):
    data: List[Any]

# ----- model load -----
logger.info("Preparing BitsAndBytes config for 4-bit QLoRA inference...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
    llm_int8_enable_fp32_cpu_offload=True,
)
| logger.info("Loading tokenizer...") | |
| tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True) | |
| logger.info("Loading base model (4-bit). This may take a while...") | |
| model_base = AutoModelForCausalLM.from_pretrained( | |
| BASE_MODEL, | |
| quantization_config=bnb_config, | |
| device_map="auto", | |
| offload_folder=OFFLOAD_DIR, | |
| trust_remote_code=True, | |
| ) | |
| logger.info("Loading LoRA adapter from repo...") | |
| model = PeftModel.from_pretrained(model_base, ADAPTER_REPO, device_map="auto") | |
| model.eval() | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| logger.info("Model loaded and ready.") | |

# ----- helpers -----
def _extract_args(data: List[Any]):
    prompt = data[0] if len(data) > 0 else ""
    max_new_tokens = int(data[1]) if len(data) > 1 and data[1] else 512
    temperature = float(data[2]) if len(data) > 2 and data[2] else 0.2
    top_p = float(data[3]) if len(data) > 3 and data[3] else 0.9
    max_new_tokens = max(16, min(2048, max_new_tokens))
    # keep temperature strictly positive: generate() rejects temperature == 0 when do_sample=True
    temperature = max(0.01, min(2.0, temperature))
    top_p = max(0.0, min(1.0, top_p))
    return prompt, max_new_tokens, temperature, top_p
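
# Illustrative only: the `data` list is positional, e.g.
#   {"data": ["print('hi')", 256, 0.3, 0.95]}
#     -> prompt="print('hi')", max_new_tokens=256, temperature=0.3, top_p=0.95
# (missing or falsy slots fall back to the defaults above)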

def generate_text(prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    if not prompt or not prompt.strip():
        return "⚠️ Empty prompt"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        )
    # note: the decoded text still includes the original prompt
    text = tokenizer.decode(out[0], skip_special_tokens=True)
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:
        pass
    return text
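
# Illustrative only: since generate_text() returns prompt + completion, a caller that wants
# just the completion could strip the echoed prompt, e.g.:
#
#   full = generate_text(prompt, 256, 0.2, 0.9)
#   completion = full[len(prompt):] if full.startswith(prompt) else full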

def append_history(prompt, code):
    entry = {"time": time.strftime("%Y-%m-%d %H:%M:%S"), "prompt": prompt, "code": code}
    try:
        with open(HISTORY_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")
    except Exception as e:
        logger.warning("Failed to append history: %s", e)

# ----- routes -----
# NOTE: the route paths below are assumed; adjust them to whatever URLs the frontend calls.
@app.get("/")
async def homepage():
    p = ROOT / "index.html"
    if p.exists():
        return FileResponse(str(p))
    return HTMLResponse("<h3>Upload index.html to the Space repo root.</h3>", status_code=404)

@app.post("/run/predict")
async def run_predict(req: PredictRequest, request: Request):
    client_ip = request.client.host if request.client else "unknown"
    if not allowed_request(client_ip):
        raise HTTPException(429, detail="Too many requests — slow down.")
    try:
        prompt, max_new_tokens, temperature, top_p = _extract_args(req.data)
        logger.info("Predict request: len(prompt)=%d ip=%s", len(prompt), client_ip)
        out = generate_text(prompt, max_new_tokens, temperature, top_p)
        # persist history (best-effort; failures are ignored)
        try:
            append_history(prompt, out)
        except Exception:
            pass
        return JSONResponse({"data": [out]})
    except Exception as e:
        logger.exception("Predict failure")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/run/explain")
async def run_explain(req: PredictRequest, request: Request):
    client_ip = request.client.host if request.client else "unknown"
    if not allowed_request(client_ip):
        raise HTTPException(429, detail="Too many requests — slow down.")
    try:
        code_text = req.data[0] if req.data else ""
        prompt = (
            "### Instruction:\n"
            "Explain the following Python code in detail, line-by-line, and highlight potential bugs or improvements.\n\n"
            f"### Code:\n{code_text}\n\n### Explanation:\n"
        )
        max_new_tokens = int(req.data[1]) if len(req.data) > 1 and req.data[1] else 512
        temperature = float(req.data[2]) if len(req.data) > 2 and req.data[2] else 0.2
        top_p = float(req.data[3]) if len(req.data) > 3 and req.data[3] else 0.9
        out = generate_text(prompt, max_new_tokens, temperature, top_p)
        return JSONResponse({"data": [out]})
    except Exception as e:
        logger.exception("Explain failure")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/history/append")
async def history_append(payload: dict):
    try:
        append_history(payload.get("prompt", ""), payload.get("code", ""))
        return JSONResponse({"ok": True})
    except Exception as e:
        logger.exception("history append failed")
        raise HTTPException(500, detail=str(e))

@app.post("/history/clear")
async def history_clear():
    try:
        if HISTORY_PATH.exists():
            HISTORY_PATH.unlink()
        return JSONResponse({"ok": True})
    except Exception as e:
        logger.exception("history clear failed")
        raise HTTPException(500, detail=str(e))

@app.get("/health")
async def health():
    try:
        device = str(model.device) if hasattr(model, "device") else "unknown"
        return {"status": "ok", "device": device}
    except Exception as e:
        return {"status": "error", "error": str(e)}