import gradio as gr
import torch
from diffusers import DiffusionPipeline
from transformers import AutoTokenizer, AutoModelForCausalLM

# ------------------------------
# Utility: Device Detection
# ------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
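
# Quick visibility into what was detected (illustrative, not required by the app):
print(f"[CASS3] device={DEVICE}, dtype={TORCH_DTYPE}")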

# ------------------------------
# Unified Super Model Class
# ------------------------------
class CASS3Beta:
    def __init__(self):
        # Caches so each model is loaded at most once (lazy loading)
        self.image_pipes = {}
        self.text_models = {}

    # Lazy-load image model
    def load_image_model(self, model_name):
        if model_name in self.image_pipes:
            return self.image_pipes[model_name]
        if model_name == "Lucy":
            pipe = DiffusionPipeline.from_pretrained(
                "decart-ai/Lucy-Edit-Dev",
                trust_remote_code=True,
                torch_dtype=TORCH_DTYPE
            ).to(DEVICE)
        elif model_name == "Wan2.2":
            pipe = DiffusionPipeline.from_pretrained(
                "Wan-AI/Wan2.2-Animate-14B",
                trust_remote_code=True,
                torch_dtype=TORCH_DTYPE
            ).to(DEVICE)
        elif model_name == "OpenJourney":
            pipe = DiffusionPipeline.from_pretrained(
                "prompthero/openjourney",
                torch_dtype=TORCH_DTYPE
            ).to(DEVICE)
        elif model_name == "StableXL":
            pipe = DiffusionPipeline.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0",
                torch_dtype=TORCH_DTYPE
            ).to(DEVICE)
        elif model_name == "Wan2.1":
            # NOTE: this repo ships GGUF weights, which diffusers loads via
            # from_single_file with a GGUFQuantizationConfig; from_pretrained
            # may fail on it.
            pipe = DiffusionPipeline.from_pretrained(
                "samuelchristlie/Wan2.1-VACE-1.3B-GGUF",
                torch_dtype=TORCH_DTYPE
            ).to(DEVICE)
        else:
            raise ValueError(f"Unknown image model: {model_name}")
        self.image_pipes[model_name] = pipe
        return pipe
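
    # Loading all five pipelines onto one GPU is memory-heavy. One standard
    # mitigation, assuming the `accelerate` package is installed, is diffusers'
    # CPU offload; it would replace the .to(DEVICE) move above (shown as an
    # option, not part of the original design):
    #
    #     pipe = DiffusionPipeline.from_pretrained(..., torch_dtype=TORCH_DTYPE)
    #     pipe.enable_model_cpu_offload()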

    # Lazy-load text model
    def load_text_model(self, model_name):
        if model_name in self.text_models:
            return self.text_models[model_name]
        if model_name == "Qwen":
            tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
            model = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen3-0.6B",
                torch_dtype=TORCH_DTYPE
            ).to(DEVICE)
        elif model_name == "Isaac":
            tokenizer = None  # Isaac handles tokenization internally
            model = AutoModelForCausalLM.from_pretrained(
                "PerceptronAI/Isaac-0.1",
                trust_remote_code=True,
                torch_dtype=TORCH_DTYPE
            ).to(DEVICE)
        else:
            raise ValueError(f"Unknown text model: {model_name}")
        self.text_models[model_name] = (tokenizer, model)
        return tokenizer, model

    # Generate a single image
    def generate_image(self, prompt, model_name):
        pipe = self.load_image_model(model_name)
        # NOTE: Lucy and the Wan checkpoints are video-oriented pipelines; their
        # outputs may expose .frames rather than .images, which would raise here.
        return pipe(prompt).images[0]
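
    # Standard Stable Diffusion-style knobs, shown as a sketch (values are
    # illustrative; supported kwargs vary by pipeline):
    #
    #     pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]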

    # Generate text
    def generate_text(self, prompt, model_name):
        tokenizer, model = self.load_text_model(model_name)
        if tokenizer:
            inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
            outputs = model.generate(**inputs, max_new_tokens=50)
            return tokenizer.decode(outputs[0], skip_special_tokens=True)
        else:
            # Isaac model: these helper calls assume Isaac's remote code exposes
            # string-in/string-out wrappers; check the model card if they differ.
            inputs = model.prepare_inputs_for_generation(prompt)
            outputs = model.generate(**inputs, max_new_tokens=50)
            return model.decode(outputs)
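
    # Qwen3 is chat-tuned; wrapping the prompt with the tokenizer's chat template
    # usually gives better results (a sketch using the standard transformers API):
    #
    #     messages = [{"role": "user", "content": prompt}]
    #     text = tokenizer.apply_chat_template(
    #         messages, tokenize=False, add_generation_prompt=True
    #     )
    #     inputs = tokenizer(text, return_tensors="pt").to(DEVICE)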

    # Generate outputs from all models
    def generate_all(self, prompt):
        image_names = ["Lucy", "Wan2.2", "OpenJourney", "StableXL", "Wan2.1"]
        images = {name: self.generate_image(prompt, name) for name in image_names}
        texts = {
            "Qwen": self.generate_text(prompt, "Qwen"),
            "Isaac": self.generate_text(prompt, "Isaac")
        }
        return images, texts

# ------------------------------
# Instantiate Super Model
# ------------------------------
cass3 = CASS3Beta()
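
# Direct usage of the unified model, outside Gradio (a sketch; the prompt is
# illustrative, and the image pipelines return PIL images that support .save()):
#
#     images, texts = cass3.generate_all("a lighthouse at dusk")
#     images["StableXL"].save("stablexl.png")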

# ------------------------------
# Gradio Interface
# ------------------------------
def run_cass3(prompt):
    images, texts = cass3.generate_all(prompt)
    # List of images in fixed order
    image_list = [images[name] for name in ["Lucy", "Wan2.2", "OpenJourney", "StableXL", "Wan2.1"]]
    # Combine text outputs
    text_output = f"Qwen:\n{texts['Qwen']}\n\nIsaac:\n{texts['Isaac']}"
    return image_list, text_output

iface = gr.Interface(
    fn=run_cass3,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
    outputs=[
        # Gallery.style() was removed in Gradio 4; grid/height are constructor args now
        gr.Gallery(label="Generated Images", show_label=True, elem_id="image_gallery",
                   columns=3, height="auto"),
        gr.Textbox(label="Generated Texts", lines=10)
    ],
    title="CASS3.0beta - Unified Super Model",
    description="All five image models + both LLMs in one Hugging Face Space."
)

iface.launch()
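
# On Spaces the script's launch() call is picked up automatically; to handle
# concurrent users, the interface can be queued first (standard Gradio, shown
# as an option):
#
#     iface.queue().launch()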