File size: 3,253 Bytes
37197da
 
2499bf4
37197da
ad298b5
5251795
2499bf4
 
55da0b7
 
 
8f89a48
5251795
 
 
 
37197da
8f89a48
2984b8e
 
b98c632
37197da
2499bf4
5251795
ad298b5
5251795
 
2499bf4
2984b8e
5251795
2984b8e
5251795
4707822
 
37197da
2984b8e
066dad9
 
ad298b5
066dad9
8f89a48
2499bf4
55da0b7
2499bf4
 
 
 
 
 
55da0b7
2499bf4
 
 
37197da
8f89a48
55da0b7
5251795
2984b8e
2499bf4
5251795
55da0b7
 
 
 
 
 
2499bf4
55da0b7
5251795
2984b8e
5251795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2499bf4
 
 
55da0b7
37197da
 
 
 
ad298b5
37197da
 
 
 
 
8f89a48
8b416f9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
import torch
import torch.multiprocessing as mp
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
import logging

# Set multiprocessing to 'spawn' for ZeroGPU compatibility
# (ZeroGPU workers fork-unfriendly; 'spawn' avoids CUDA re-init issues in children)
try:
    mp.set_start_method('spawn', force=True)
except RuntimeError:
    # The start method was already fixed for this interpreter; nothing to do.
    pass

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
model = None  # assigned by the startup model-loading block below
tokenizer = None  # assigned by the startup tokenizer-loading block below
MODEL_NAME = "ubiodee/plutus_llm"  # Hugging Face Hub repo id loaded at startup

# Load tokenizer at startup.
# Runs at import time so a bad repo id / network failure aborts the Space
# immediately instead of failing on the first user request.
try:
    logger.info("Loading tokenizer at startup for %s...", MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        use_fast=True,
        trust_remote_code=True,  # NOTE(review): executes repo code — acceptable only for a trusted model repo
    )
    logger.info("Primary tokenizer loaded successfully.")
except Exception as e:
    logger.error(f"Tokenizer loading failed: {str(e)}")
    raise  # re-raise so the Space fails fast rather than serving without a tokenizer

# Set pad token: some causal-LM tokenizers ship without one, but
# tokenizer(..., padding=True) and model.generate(pad_token_id=...) need it.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
    logger.info("Set pad_token_id to eos_token_id: %s", tokenizer.eos_token_id)

# Load model at startup (on CPU; it is moved to GPU per-request inside
# generate_response, which is the ZeroGPU usage pattern).
try:
    logger.info("Loading model %s with torch.float16...", MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,  # halves memory vs float32; fine for inference
        trust_remote_code=True,  # NOTE(review): executes repo code — acceptable only for a trusted model repo
    )
    model.eval()  # disable dropout etc. for deterministic-inference layers
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.error(f"Model loading failed: {str(e)}")
    raise  # fail fast at startup rather than serving a broken app

# Response function
@spaces.GPU(duration=120)
def generate_response(prompt, progress=gr.Progress()):
    """Generate a completion for *prompt* with the startup-loaded model.

    Runs inside a ZeroGPU slot: the model is moved to CUDA on entry (when
    available) and the CUDA cache is cleared on exit. Returns the generated
    text with the prompt removed, or an error string if inference fails.

    Args:
        prompt: User prompt text.
        progress: Gradio progress tracker (default instance is the
            documented Gradio idiom for enabling the progress bar).
    """
    global model
    progress(0.1, desc="Moving model to GPU...")
    try:
        if torch.cuda.is_available():
            model = model.to("cuda")
            logger.info("Model moved to GPU.")
        else:
            logger.warning("GPU not available; using CPU.")

        progress(0.3, desc="Tokenizing input...")
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)

        progress(0.6, desc="Generating response...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Decode only the newly generated tokens. The previous
        # `response.startswith(prompt)` check silently failed whenever
        # re-decoding the prompt tokens did not reproduce the prompt string
        # byte-for-byte (BPE merges, special characters, whitespace),
        # leaving the prompt echoed in the output. Slicing by input length
        # is exact: generate() returns [prompt tokens | new tokens].
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

        progress(1.0, desc="Done!")
        return response
    except Exception as e:
        # Boundary handler: surface the error to the UI instead of a 500.
        logger.error(f"Inference failed: {str(e)}")
        return f"Error during generation: {str(e)}"
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.info("GPU memory cleared.")

# Gradio UI: a single prompt box in, a single response box out.
prompt_box = gr.Textbox(
    label="Enter your prompt",
    lines=4,
    placeholder="Ask about Plutus smart contracts...",
)
response_box = gr.Textbox(label="Model Response")

demo = gr.Interface(
    fn=generate_response,
    inputs=prompt_box,
    outputs=response_box,
    title="Cardano Plutus AI Assistant",
    description="Write Plutus smart contracts on Cardano blockchain.",
)

# Launch the app (blocking call; serves the interface).
demo.launch()