import os
import time
import uuid
from typing import List, Optional, Dict, Any

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = os.getenv("MODEL_ID", "LiquidAI/LFM2-1.2B")
DEFAULT_MAX_TOKENS = int(os.getenv("MAX_TOKENS", "256"))

app = FastAPI(title="OpenAI-compatible API for LiquidAI/LFM2-1.2B")

tokenizer = None
model = None

def get_dtype() -> torch.dtype:
    if torch.cuda.is_available():
        # Prefer bfloat16 if supported; else float16
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16
    # CPU
    return torch.float32

def load_model():
    global tokenizer, model
    dtype = get_dtype()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
    )
    # Ensure eos/pad tokens exist
    if tokenizer.eos_token is None:
        tokenizer.eos_token = tokenizer.sep_token or tokenizer.pad_token or "</s>"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token


# Load the model once at startup so every request reuses the same weights.
@app.on_event("startup")
def startup() -> None:
    load_model()

class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    model: Optional[str] = Field(default=MODEL_ID)
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    max_tokens: Optional[int] = None
    stop: Optional[List[str] | str] = None
    n: Optional[int] = 1


class CompletionRequest(BaseModel):
    model: Optional[str] = Field(default=MODEL_ID)
    prompt: str | List[str]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    max_tokens: Optional[int] = None
    stop: Optional[List[str] | str] = None
    n: Optional[int] = 1


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
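
# Example request body for POST /v1/chat/completions (fields map directly onto
# ChatCompletionRequest above; values are illustrative only):
#   {
#     "model": "LiquidAI/LFM2-1.2B",
#     "messages": [{"role": "user", "content": "Hello!"}],
#     "temperature": 0.7,
#     "max_tokens": 128
#   }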

# Simple chat prompt formatter
def build_chat_prompt(messages: List[ChatMessage]) -> str:
    system_prefix = "You are a helpful assistant."
    system_msgs = [m.content for m in messages if m.role == "system"]
    if system_msgs:
        system_prefix = system_msgs[-1]
    conv: List[str] = [f"System: {system_prefix}"]
    for m in messages:
        if m.role == "system":
            continue
        role = "User" if m.role == "user" else ("Assistant" if m.role == "assistant" else m.role.capitalize())
        conv.append(f"{role}: {m.content}")
    conv.append("Assistant:")
    return "\n".join(conv)

def apply_stop_sequences(text: str, stop: Optional[List[str] | str]) -> str:
    if stop is None:
        return text
    stops = stop if isinstance(stop, list) else [stop]
    cut = len(text)
    for s in stops:
        if not s:
            continue
        idx = text.find(s)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]
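
# For example, apply_stop_sequences("Hi there\nUser: next question", "\nUser:")
# returns "Hi there": the text is cut at the earliest stop sequence found.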

def generate_once(prompt: str, temperature: float, top_p: float, max_new_tokens: int) -> Dict[str, Any]:
    assert tokenizer is not None and model is not None, "Model not loaded"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    do_sample = bool(temperature and temperature > 0)
    gen_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        # Sampling parameters only matter when sampling; pass neutral values
        # under greedy decoding to avoid warnings.
        temperature=float(temperature) if do_sample else 1.0,
        top_p=float(top_p or 1.0) if do_sample else 1.0,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Decode only the tokens generated after the prompt.
    out = tokenizer.decode(gen_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
    return {
        "text": out,
        "prompt_tokens": inputs["input_ids"].numel(),
        "completion_tokens": gen_ids[0].shape[0] - inputs["input_ids"].shape[-1],
    }

@app.get("/")
def root():
    return RedirectResponse(url="/docs")


@app.get("/health")
def health():
    return {"status": "ok", "model": MODEL_ID}

@app.post("/v1/chat/completions")
def chat_completions(req: ChatCompletionRequest):
    if req.n and req.n > 1:
        raise HTTPException(status_code=400, detail="Only n=1 is supported in this simple server.")
    max_new = req.max_tokens or DEFAULT_MAX_TOKENS
    prompt = build_chat_prompt(req.messages)
    g = generate_once(prompt, req.temperature or 0.7, req.top_p or 0.95, max_new)
    text = apply_stop_sequences(g["text"], req.stop)
    created = int(time.time())
    comp_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
    usage = Usage(
        prompt_tokens=g["prompt_tokens"],
        completion_tokens=g["completion_tokens"],
        total_tokens=g["prompt_tokens"] + g["completion_tokens"],
    )
    return {
        "id": comp_id,
        "object": "chat.completion",
        "created": created,
        "model": req.model or MODEL_ID,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop",
            }
        ],
        "usage": usage.dict(),
    }

@app.post("/v1/completions")
def completions(req: CompletionRequest):
    if req.n and req.n > 1:
        raise HTTPException(status_code=400, detail="Only n=1 is supported in this simple server.")
    prompts = req.prompt if isinstance(req.prompt, list) else [req.prompt]
    if len(prompts) != 1:
        raise HTTPException(status_code=400, detail="Only a single prompt is supported in this simple server.")
    max_new = req.max_tokens or DEFAULT_MAX_TOKENS
    g = generate_once(prompts[0], req.temperature or 0.7, req.top_p or 0.95, max_new)
    text = apply_stop_sequences(g["text"], req.stop)
    created = int(time.time())
    comp_id = f"cmpl-{uuid.uuid4().hex[:24]}"
    usage = Usage(
        prompt_tokens=g["prompt_tokens"],
        completion_tokens=g["completion_tokens"],
        total_tokens=g["prompt_tokens"] + g["completion_tokens"],
    )
    return {
        "id": comp_id,
        "object": "text_completion",
        "created": created,
        "model": req.model or MODEL_ID,
        "choices": [
            {
                "index": 0,
                "text": text,
                "finish_reason": "stop",
                "logprobs": None,
            }
        ],
        "usage": usage.dict(),
    }

if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", "7860"))
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)