"""FastAPI service exposing the ``tacab/mt5-beero_somali`` seq2seq model.

A single ``POST /generate`` endpoint accepts ``{"inputs": "<text>"}`` and
returns ``{"generated_text": "<text>"}``.
"""

import os

# Set before the Hugging Face libraries are imported so the cache location
# is in place when they load: use /tmp as cache to avoid permission errors.
os.environ["HF_HOME"] = "/tmp"

import threading  # noqa: E402

import torch  # noqa: E402
from fastapi import FastAPI  # noqa: E402
from fastapi.concurrency import run_in_threadpool  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer  # noqa: E402

app = FastAPI()

model_name = "tacab/mt5-beero_somali"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Prefer the GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Generation now runs in worker threads (see ``generate_text``). The lock
# keeps it one-at-a-time, exactly as it was when the work ran directly on the
# event loop, so the shared tokenizer and model are never used concurrently.
_generation_lock = threading.Lock()


class Query(BaseModel):
    """Request body for ``POST /generate``."""

    # Text handed to the tokenizer as the model input.
    inputs: str


def _generate(prompt: str) -> str:
    """Tokenize *prompt*, run ``model.generate`` and decode the first output.

    This is blocking work; call it from a worker thread, not the event loop.
    """
    with _generation_lock:
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        output_ids = model.generate(input_ids, max_length=128)
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)


@app.post("/generate")
async def generate_text(query: Query):
    """Generate text for ``query.inputs``.

    The blocking tokenization and model inference are dispatched to the
    thread pool so the event loop stays free to serve other requests while a
    generation is in progress.

    Returns:
        A dict of the form ``{"generated_text": <decoded model output>}``.
    """
    response = await run_in_threadpool(_generate, query.inputs)
    return {"generated_text": response}