import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for model and tokenizer (lazy loading)
model = None
tokenizer = None
MODEL_NAME = "ubiodee/Test_Plutus"
FALLBACK_TOKENIZER = "gpt2"
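# Note: the gpt2 tokenizer is only a last-resort fallback; its vocabulary does
# not necessarily match MODEL_NAME, so generations produced with it are best-effort.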

# Load tokenizer at startup (lightweight, no model yet)
try:
    logger.info("Loading tokenizer at startup with legacy versions...")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        use_fast=False,
        trust_remote_code=True,
    )
    logger.info("Primary tokenizer loaded successfully.")
except Exception as e:
    logger.warning(f"Primary tokenizer failed: {str(e)}. Using fallback.")
    tokenizer = AutoTokenizer.from_pretrained(
        FALLBACK_TOKENIZER,
        use_fast=False,
        trust_remote_code=True,
    )
    logger.info("Fallback tokenizer loaded.")

# Set pad token
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
    logger.info("Set pad_token_id to eos_token_id.")

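# Lazy loading: on ZeroGPU Spaces the GPU is only attached while a
# @spaces.GPU-decorated function is running, so the heavy model load is
# deferred to the first request instead of happening at import time.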
def load_model():
    """Load model inside GPU context to enable quantization."""
    global model
    if model is None:
        try:
            logger.info("Loading model with CPU fallback (full precision)...")
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                torch_dtype=torch.float16,  # Use fp16 for memory efficiency without bitsandbytes
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            model.eval()
            if torch.cuda.is_available():
                model.to("cuda")
                logger.info("Model loaded and moved to GPU.")
            else:
                logger.warning("GPU not available; using CPU.")
        except Exception as e:
            logger.error(f"Model loading failed: {str(e)}")
            raise
    return model

# Response function: Load model on first call, then reuse
@spaces.GPU(duration=300)  # Allow up to 5min for loading + inference
def generate_response(prompt, progress=gr.Progress()):
    global model
    progress(0.1, desc="Loading model if needed...")
    model = load_model()  # Ensures model is loaded in GPU context
    
    progress(0.3, desc="Tokenizing input...")
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        
        progress(0.6, desc="Generating response...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Remove prompt from output
        if response.startswith(prompt):
            response = response[len(prompt):].strip()
        
        progress(1.0, desc="Done!")
        return response
    except Exception as e:
        logger.error(f"Inference failed: {str(e)}")
        return f"Error during generation: {str(e)}"

# Gradio UI
demo = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(label="Enter your prompt", lines=4, placeholder="Ask about Plutus..."),
    outputs=gr.Textbox(label="Model Response"),
    title="Cardano Plutus AI Assistant",
    description="Write Plutus smart contracts on Cardano blockchain."
)

# Launch with queueing
demo.queue(max_size=5).launch(max_threads=1)