import gradio as gr
import torch
import torch.multiprocessing as mp
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
import logging
# Set multiprocessing to 'spawn' for ZeroGPU compatibility
try:
    mp.set_start_method('spawn', force=True)
except RuntimeError:
    pass  # start method was already set; safe to ignore
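# ZeroGPU runs @spaces.GPU-decorated functions in a separate process, and CUDA
# cannot be re-initialized in a forked child, so 'spawn' avoids that failure mode.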
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global variables
model = None
tokenizer = None
MODEL_NAME = "ubiodee/plutus_llm"
# Load tokenizer at startup
try:
    logger.info("Loading tokenizer at startup for %s...", MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        use_fast=True,
        trust_remote_code=True,
    )
    logger.info("Tokenizer loaded successfully.")
except Exception as e:
    logger.error("Tokenizer loading failed: %s", e)
    raise
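# Optional fallback (sketch, not part of the original flow): if the checkpoint
# lacks fast-tokenizer files, a retry with the slow implementation could look
# like this.
# tokenizer = AutoTokenizer.from_pretrained(
#     MODEL_NAME,
#     use_fast=False,
#     trust_remote_code=True,
# )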
# Set pad token
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
    logger.info("Set pad_token_id to eos_token_id: %s", tokenizer.eos_token_id)
# Load model at startup
try:
    logger.info("Loading model %s with torch.float16...", MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    model.eval()
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.error("Model loading failed: %s", e)
    raise
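# Memory-saving alternative (sketch): low_cpu_mem_usage=True streams weights in
# instead of materializing a second full copy in RAM (requires 'accelerate').
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_NAME,
#     torch_dtype=torch.float16,
#     trust_remote_code=True,
#     low_cpu_mem_usage=True,
# )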
# Response function; @spaces.GPU requests a ZeroGPU slot for up to 120 seconds
@spaces.GPU(duration=120)
def generate_response(prompt, progress=gr.Progress()):
    global model
    progress(0.1, desc="Moving model to GPU...")
    try:
        if torch.cuda.is_available():
            model = model.to("cuda")
            logger.info("Model moved to GPU.")
        else:
            # float16 support on CPU is limited; inference may be slow or fail
            logger.warning("GPU not available; using CPU.")
        progress(0.3, desc="Tokenizing input...")
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(model.device)
        progress(0.6, desc="Generating response...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The decoded output echoes the prompt; strip it so only the completion remains
        if response.startswith(prompt):
            response = response[len(prompt):].strip()
        progress(1.0, desc="Done!")
        return response
    except Exception as e:
        logger.error("Inference failed: %s", e)
        return f"Error during generation: {e}"
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.info("GPU memory cleared.")
# Gradio UI
demo = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(label="Enter your prompt", lines=4, placeholder="Ask about Plutus smart contracts..."),
    outputs=gr.Textbox(label="Model Response"),
    title="Cardano Plutus AI Assistant",
    description="Ask for help writing Plutus smart contracts on the Cardano blockchain.",
)
# Launch
demo.launch()
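# Remote usage (sketch): once the Space is running, the endpoint can be called
# with gradio_client; "user/space-name" below is a hypothetical Space id.
# from gradio_client import Client
# client = Client("user/space-name")
# print(client.predict("Write a minimal Plutus validator.", api_name="/predict"))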