Spaces: Runtime error
```python
from fastapi import FastAPI, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
from pydantic import BaseModel
import os
import tarfile

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Debug environment variables (mask anything that looks like a secret)
logger.info(
    "Environment variables: %s",
    {k: "****" if "TOKEN" in k or k == "granite" else v for k, v in os.environ.items()},
)

app = FastAPI()

model_tarball = "/app/granite-8b-finetuned-ascii.tar.gz"
model_path = "/app/granite-8b-finetuned-ascii"

# Extract the tarball if the model directory doesn't exist yet
if not os.path.exists(model_path):
    logger.info(f"Extracting model tarball: {model_tarball}")
    try:
        with tarfile.open(model_tarball, "r:gz") as tar:
            tar.extractall(path="/app")
        logger.info("Model tarball extracted successfully")
    except Exception as e:
        logger.error(f"Failed to extract model tarball: {e}")
        # This runs at import time, before any request exists, so an
        # HTTPException would never reach a client; fail fast instead.
        raise RuntimeError(f"Model tarball extraction failed: {e}") from e

try:
    logger.info("Loading tokenizer and model")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.padding_side = "right"
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    logger.info("Model and tokenizer loaded successfully")
except Exception as e:
    logger.error(f"Failed to load model or tokenizer: {e}")
    # Same as above: startup failure, not a request error.
    raise RuntimeError(f"Model initialization failed: {e}") from e

class EditRequest(BaseModel):
    text: str

@app.get("/")  # decorator was missing in the pasted code; "/" assumed
def greet_json():
    return {"status": "Model is ready", "model": model_path}

@app.post("/generate")  # decorator was missing in the pasted code; path assumed
async def generate(request: EditRequest):
    try:
        prompt = f"Edit this AsciiDoc sentence: {request.text}"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # Note: max_length counts the prompt tokens too;
        # max_new_tokens=... is often what's actually intended here.
        outputs = model.generate(**inputs, max_length=200)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Generated response for prompt: {prompt}")
        return {"response": response}
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
```
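Once the Space starts, a quick smoke test can confirm both endpoints respond before digging further into the runtime error. Here is a minimal client sketch, assuming the app is served by uvicorn at `http://localhost:7860` (the usual port for a Docker Space) and uses the `/` and `/generate` routes from the decorators above:

```python
import requests

# Assumed base URL; replace with your Space's actual host/port.
base = "http://localhost:7860"

# Health check: should return {"status": "Model is ready", ...}
print(requests.get(f"{base}/", timeout=30).json())

# Ask the model to edit an AsciiDoc sentence.
payload = {"text": "This are a sample AsciiDoc sentence."}
resp = requests.post(f"{base}/generate", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["response"])
```

If the health check itself fails, the problem is in startup (tarball extraction or model loading) rather than in the generation route, which narrows down where to look in the Space logs.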