Spaces:
Runtime error
Runtime error
File size: 5,036 Bytes
fdf2541 0516e0c fdf2541 0516e0c fdf2541 0516e0c fdf2541 0516e0c fdf2541 0516e0c fdf2541 0516e0c fdf2541 0516e0c fdf2541 0516e0c fdf2541 0516e0c fdf2541 0516e0c fdf2541 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
# ------------------------------
# Utility: Device Detection
# ------------------------------
# Pick the execution device once at import time: half precision on a
# CUDA GPU, full float32 precision when running on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
    TORCH_DTYPE = torch.float16
else:
    DEVICE = "cpu"
    TORCH_DTYPE = torch.float32
# ------------------------------
# Unified Super Model Class
# ------------------------------
class CASS3Beta:
    """Lazily load and run a fixed set of image and text generation models.

    Each loaded pipeline/model is cached on the instance, so a given
    checkpoint is fetched with ``from_pretrained`` and moved to ``DEVICE``
    at most once per instance.
    """

    # Display name -> (Hugging Face repo id, whether the repo is loaded with
    # ``trust_remote_code=True``). Insertion order is the order in which
    # ``generate_all`` runs the image models.
    # NOTE(review): Lucy-Edit and Wan2.2-Animate look like editing/animation
    # checkpoints and the Wan2.1 repo name suggests GGUF weights; calling them
    # as plain prompt->image pipelines may not work — confirm per model card.
    IMAGE_MODELS = {
        "Lucy": ("decart-ai/Lucy-Edit-Dev", True),
        "Wan2.2": ("Wan-AI/Wan2.2-Animate-14B", True),
        "OpenJourney": ("prompthero/openjourney", False),
        "StableXL": ("stabilityai/stable-diffusion-xl-base-1.0", False),
        "Wan2.1": ("samuelchristlie/Wan2.1-VACE-1.3B-GGUF", False),
    }
    # Text models run by ``generate_all``, in output order.
    TEXT_MODELS = ("Qwen", "Isaac")

    def __init__(self):
        self.image_pipes = {}  # display name -> loaded diffusion pipeline
        self.text_models = {}  # display name -> (tokenizer or None, model)

    def load_image_model(self, model_name):
        """Return the pipeline for *model_name*, loading and caching it on first use.

        Raises:
            ValueError: if *model_name* is not a key of ``IMAGE_MODELS``.
        """
        if model_name in self.image_pipes:
            return self.image_pipes[model_name]
        try:
            repo_id, needs_remote_code = self.IMAGE_MODELS[model_name]
        except KeyError:
            raise ValueError(f"Unknown image model: {model_name}") from None
        kwargs = {"torch_dtype": TORCH_DTYPE}
        if needs_remote_code:
            # Only passed for repos that ship custom pipeline code.
            kwargs["trust_remote_code"] = True
        pipe = DiffusionPipeline.from_pretrained(repo_id, **kwargs).to(DEVICE)
        self.image_pipes[model_name] = pipe
        return pipe

    def load_text_model(self, model_name):
        """Return ``(tokenizer, model)`` for *model_name*, loading and caching on first use.

        ``tokenizer`` is ``None`` for Isaac.

        Raises:
            ValueError: if *model_name* is neither ``"Qwen"`` nor ``"Isaac"``.
        """
        if model_name in self.text_models:
            return self.text_models[model_name]
        if model_name == "Qwen":
            tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
            model = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen3-0.6B",
                torch_dtype=TORCH_DTYPE,
            ).to(DEVICE)
        elif model_name == "Isaac":
            # NOTE(review): no tokenizer is loaded; this assumes Isaac handles
            # tokenization itself — confirm against the model card.
            tokenizer = None
            model = AutoModelForCausalLM.from_pretrained(
                "PerceptronAI/Isaac-0.1",
                trust_remote_code=True,
                torch_dtype=TORCH_DTYPE,
            ).to(DEVICE)
        else:
            raise ValueError(f"Unknown text model: {model_name}")
        self.text_models[model_name] = (tokenizer, model)
        return tokenizer, model

    def generate_image(self, prompt, model_name):
        """Run the *model_name* pipeline on *prompt* and return its first image."""
        pipe = self.load_image_model(model_name)
        return pipe(prompt).images[0]

    def generate_text(self, prompt, model_name, max_new_tokens=50):
        """Generate a continuation of *prompt* with the *model_name* text model.

        Args:
            prompt: Input text.
            model_name: ``"Qwen"`` or ``"Isaac"``.
            max_new_tokens: Upper bound on generated tokens (default 50).

        Returns:
            The decoded generated text.
        """
        tokenizer, model = self.load_text_model(model_name)
        # Explicit None check: tokenizers define __len__, so plain truthiness
        # would really be testing "vocabulary is non-empty".
        if tokenizer is not None:
            inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
            # For a causal LM, generate() typically returns prompt + new tokens,
            # so the decoded text likely echoes the prompt.
            return tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Isaac path (no tokenizer loaded).
        # NOTE(review): calling prepare_inputs_for_generation() with a raw string
        # and model.decode() are not standard transformers model APIs; presumably
        # they come from Isaac's remote code — verify, this path may fail.
        inputs = model.prepare_inputs_for_generation(prompt)
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
        return model.decode(outputs)

    def generate_all(self, prompt):
        """Run every image model, then every text model, on *prompt*.

        Returns:
            ``(images, texts)`` — dicts keyed by display name, in the order of
            ``IMAGE_MODELS`` and ``TEXT_MODELS`` respectively.
        """
        images = {name: self.generate_image(prompt, name) for name in self.IMAGE_MODELS}
        texts = {name: self.generate_text(prompt, name) for name in self.TEXT_MODELS}
        return images, texts
# ------------------------------
# Instantiate Super Model
# ------------------------------
# One shared instance so every request reuses the models cached by its
# lazy loaders instead of loading them again.
cass3 = CASS3Beta()


# ------------------------------
# Gradio Interface
# ------------------------------
def run_cass3(prompt):
    """Gradio callback: run every model on *prompt*.

    Returns:
        ``(image_list, text_output)``. ``image_list`` holds the generated
        images in the order ``generate_all`` produced them. ``text_output``
        is one string containing, for each text model, a "<name>:" line
        followed by that model's output, with sections separated by a
        blank line.
    """
    images, texts = cass3.generate_all(prompt)
    # Follow generate_all's ordering instead of re-listing model names here,
    # so adding or renaming a model cannot cause a KeyError or a silently
    # missing output in this callback.
    image_list = list(images.values())
    text_output = "\n\n".join(f"{name}:\n{text}" for name, text in texts.items())
    return image_list, text_output
# Single-textbox UI: the prompt goes in; a gallery of all generated images
# and a textbox with the combined LLM outputs come out.
iface = gr.Interface(
    fn=run_cass3,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
    outputs=[
        # Component.style() was removed in Gradio 4 (calling it raises
        # AttributeError); layout options are now constructor arguments,
        # with the old grid=[3] expressed as columns=3.
        gr.Gallery(
            label="Generated Images",
            show_label=True,
            elem_id="image_gallery",
            columns=3,
            height="auto",
        ),
        gr.Textbox(label="Generated Texts", lines=10),
    ],
    title="CASS3.0beta - Unified Super Model",
    description="All five image models + both LLMs in one Hugging Face Space.",
)
iface.launch()
|