# CASS3.0Beta / app.py
# Author: DSDUDEd
# Last commit: 0516e0c "Update app.py" (verified)
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
# ------------------------------
# Utility: Device Detection
# ------------------------------
# Run on the GPU when CUDA is available. Half precision is used only on
# CUDA; the CPU fallback stays in full float32.
if torch.cuda.is_available():
    DEVICE = "cuda"
    TORCH_DTYPE = torch.float16
else:
    DEVICE = "cpu"
    TORCH_DTYPE = torch.float32
# ------------------------------
# Unified Super Model Class
# ------------------------------
class CASS3Beta:
    """Lazily load, cache and run a fixed set of image and text models.

    Each model is loaded with ``from_pretrained`` and moved to ``DEVICE`` the
    first time it is requested, then kept in a per-instance dict so later
    prompts reuse it.

    NOTE(review): cached models are never evicted, so every model that has
    been used stays resident on ``DEVICE`` at the same time. Confirm the
    target hardware has room for all of them.
    """

    # Display name -> (Hugging Face Hub repo id, whether trust_remote_code is
    # needed). Insertion order is the order generate_all() produces images in.
    # NOTE(review): some of these repos look like video / video-editing models
    # or non-diffusers checkpoints (e.g. a GGUF repo). Confirm that
    # DiffusionPipeline.from_pretrained loads each one and that its output
    # exposes ``.images``.
    IMAGE_MODELS = {
        "Lucy": ("decart-ai/Lucy-Edit-Dev", True),
        "Wan2.2": ("Wan-AI/Wan2.2-Animate-14B", True),
        "OpenJourney": ("prompthero/openjourney", False),
        "StableXL": ("stabilityai/stable-diffusion-xl-base-1.0", False),
        "Wan2.1": ("samuelchristlie/Wan2.1-VACE-1.3B-GGUF", False),
    }

    # Text model names, in the order generate_all() runs them.
    TEXT_MODELS = ("Qwen", "Isaac")

    def __init__(self):
        # model name -> loaded diffusion pipeline
        self.image_pipes = {}
        # model name -> (tokenizer or None, causal LM)
        self.text_models = {}

    @staticmethod
    def _pretrained_kwargs(trust_remote_code):
        """Build the keyword arguments shared by every ``from_pretrained`` call.

        ``trust_remote_code`` is passed only when it is True. Repos that do
        not need it are therefore loaded with the library's own default,
        exactly as if the argument were omitted.
        """
        kwargs = {"torch_dtype": TORCH_DTYPE}
        if trust_remote_code:
            kwargs["trust_remote_code"] = True
        return kwargs

    def load_image_model(self, model_name):
        """Return the pipeline for ``model_name``, loading it on first use.

        Raises:
            ValueError: if ``model_name`` is not a key of ``IMAGE_MODELS``.
        """
        if model_name in self.image_pipes:
            return self.image_pipes[model_name]
        try:
            repo_id, trust_remote_code = self.IMAGE_MODELS[model_name]
        except KeyError:
            raise ValueError(f"Unknown image model: {model_name}") from None
        pipe = DiffusionPipeline.from_pretrained(
            repo_id, **self._pretrained_kwargs(trust_remote_code)
        ).to(DEVICE)
        self.image_pipes[model_name] = pipe
        return pipe

    def load_text_model(self, model_name):
        """Return ``(tokenizer, model)`` for ``model_name``, loading on first use.

        For Isaac the tokenizer is ``None``, and :meth:`generate_text` drives
        the model through its own methods instead.

        Raises:
            ValueError: if ``model_name`` is neither ``"Qwen"`` nor ``"Isaac"``.
        """
        if model_name in self.text_models:
            return self.text_models[model_name]
        if model_name == "Qwen":
            repo_id = "Qwen/Qwen3-0.6B"
            tokenizer = AutoTokenizer.from_pretrained(repo_id)
            model = AutoModelForCausalLM.from_pretrained(
                repo_id, **self._pretrained_kwargs(False)
            ).to(DEVICE)
        elif model_name == "Isaac":
            # Original author's note: Isaac handles tokenization internally.
            tokenizer = None
            model = AutoModelForCausalLM.from_pretrained(
                "PerceptronAI/Isaac-0.1", **self._pretrained_kwargs(True)
            ).to(DEVICE)
        else:
            raise ValueError(f"Unknown text model: {model_name}")
        self.text_models[model_name] = (tokenizer, model)
        return tokenizer, model

    def generate_image(self, prompt, model_name):
        """Run image model ``model_name`` on ``prompt`` and return its first image."""
        pipe = self.load_image_model(model_name)
        return pipe(prompt).images[0]

    def generate_text(self, prompt, model_name, max_new_tokens=50):
        """Generate text for ``prompt`` with text model ``model_name``.

        Args:
            prompt: Input text.
            model_name: ``"Qwen"`` or ``"Isaac"``.
            max_new_tokens: Upper bound on newly generated tokens. The default
                of 50 matches the previously hard-coded value.

        Returns:
            The decoded output string. For tokenizer-backed models the whole
            first returned sequence is decoded with special tokens skipped.
            The prompt tokens are not stripped.
        """
        tokenizer, model = self.load_text_model(model_name)
        if tokenizer is not None:
            inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
            return tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Isaac path: relies on methods supplied by the model's remote code.
        # NOTE(review): in stock transformers, prepare_inputs_for_generation
        # expects token ids rather than raw text, and PreTrainedModel has no
        # decode(). Confirm both against Isaac-0.1's remote code.
        inputs = model.prepare_inputs_for_generation(prompt)
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
        return model.decode(outputs)

    def generate_all(self, prompt):
        """Run ``prompt`` through every image model, then every text model.

        Returns:
            ``(images, texts)``: two dicts keyed by model name. ``images``
            follows ``IMAGE_MODELS`` order and ``texts`` follows
            ``TEXT_MODELS`` order.
        """
        images = {name: self.generate_image(prompt, name) for name in self.IMAGE_MODELS}
        texts = {name: self.generate_text(prompt, name) for name in self.TEXT_MODELS}
        return images, texts
# ------------------------------
# Instantiate Super Model
# ------------------------------
# Single shared instance used by the Gradio callback below. Constructing it
# only sets up empty caches; each model is loaded the first time it is used.
cass3 = CASS3Beta()
# ------------------------------
# Gradio Interface
# ------------------------------
def run_cass3(prompt):
    """Gradio callback: run every model on ``prompt``.

    Returns:
        A ``(gallery_images, text)`` pair. ``gallery_images`` lists the five
        images in a fixed display order. ``text`` shows the Qwen and Isaac
        outputs, each under its own header.
    """
    display_order = ("Lucy", "Wan2.2", "OpenJourney", "StableXL", "Wan2.1")
    images, texts = cass3.generate_all(prompt)
    gallery_images = list(map(images.__getitem__, display_order))
    return gallery_images, f"Qwen:\n{texts['Qwen']}\n\nIsaac:\n{texts['Isaac']}"
# Single-prompt UI: one textbox in; an image gallery plus combined LLM text out.
iface = gr.Interface(
    fn=run_cass3,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
    outputs=[
        # Gallery layout is set through constructor arguments. The chained
        # ``.style(grid=..., height=...)`` helper was deprecated in Gradio 3.x
        # and removed in 4.0, where calling it raises AttributeError.
        gr.Gallery(
            label="Generated Images",
            show_label=True,
            elem_id="image_gallery",
            columns=3,
            height="auto",
        ),
        gr.Textbox(label="Generated Texts", lines=10),
    ],
    title="CASS3.0beta - Unified Super Model",
    description="All five image models + both LLMs in one Hugging Face Space.",
)
iface.launch()