# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

# Text-to-Image generation models from the Hugging Face community.
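# The loaders below assume a CUDA device and, where used, the optional
# `xformers` and `bitsandbytes` extras; each one makes a different
# memory/speed trade-off (fp16/bf16, CPU offload, 4-bit quantization).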

import os
from abc import ABC, abstractmethod

import torch
from diffusers import (
    ChromaPipeline,
    Cosmos2TextToImagePipeline,
    DPMSolverMultistepScheduler,
    FluxPipeline,
    KolorsPipeline,
    StableDiffusion3Pipeline,
)
from diffusers.quantizers import PipelineQuantizationConfig
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoModelForCausalLM, SiglipProcessor

__all__ = [
    "build_hf_image_pipeline",
]


class BasePipelineLoader(ABC):
    """Builds and configures a model-specific diffusers pipeline."""

    def __init__(self, device="cuda"):
        self.device = device

    @abstractmethod
    def load(self):
        """Return a pipeline ready for inference on `self.device`."""


class BasePipelineRunner(ABC):
    """Wraps a loaded pipeline behind a uniform `run()` interface."""

    def __init__(self, pipe):
        self.pipe = pipe

    @abstractmethod
    def run(self, prompt: str, **kwargs) -> list[Image.Image]:
        """Generate a list of images for the given prompt."""
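
# A new backend implements this pair, for example (hypothetical sketch):
#
#   class MyLoader(BasePipelineLoader):
#       def load(self):
#           return SomePipeline.from_pretrained("org/model").to(self.device)
#
#   class MyRunner(BasePipelineRunner):
#       def run(self, prompt, **kwargs):
#           return self.pipe(prompt=prompt, **kwargs).images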


# ===== SD3.5-medium =====
class SD35Loader(BasePipelineLoader):
    def load(self):
        pipe = StableDiffusion3Pipeline.from_pretrained(
            "stabilityai/stable-diffusion-3.5-medium",
            torch_dtype=torch.float16,
        )
        # CPU offload manages device placement itself; moving the whole
        # pipeline with `.to(device)` first would undo the offload hooks.
        pipe.enable_model_cpu_offload(device=self.device)
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_attention_slicing()
        return pipe


class SD35Runner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> list[Image.Image]:
        return self.pipe(prompt=prompt, **kwargs).images


# ===== Cosmos2 =====
class CosmosLoader(BasePipelineLoader):
    def __init__(
        self,
        model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
        local_dir="weights/cosmos2",
        device="cuda",
    ):
        super().__init__(device)
        self.model_id = model_id
        self.local_dir = local_dir

    def _patch(self):
        """Force efficient defaults on components the pipeline loads lazily.

        Any `AutoModelForCausalLM` pulled in by the pipeline is loaded with
        flash-attention-2 in bf16, and `SiglipProcessor` uses its fast
        implementation. `setdefault` keeps explicit caller kwargs intact.
        """

        def patch_model(cls):
            orig = cls.from_pretrained

            def new(*args, **kwargs):
                kwargs.setdefault("attn_implementation", "flash_attention_2")
                kwargs.setdefault("torch_dtype", torch.bfloat16)
                return orig(*args, **kwargs)

            cls.from_pretrained = new

        def patch_processor(cls):
            orig = cls.from_pretrained

            def new(*args, **kwargs):
                kwargs.setdefault("use_fast", True)
                return orig(*args, **kwargs)

            cls.from_pretrained = new

        patch_model(AutoModelForCausalLM)
        patch_processor(SiglipProcessor)

    def load(self):
        self._patch()
        # Fetch the full repo up front so every component resolves locally;
        # downloads resume automatically in recent huggingface_hub.
        snapshot_download(
            repo_id=self.model_id,
            local_dir=self.local_dir,
        )
        # 4-bit NF4 quantization keeps the 2B transformer and the large
        # text encoder within a single-GPU memory budget.
        config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
                "bnb_4bit_use_double_quant": True,
            },
            components_to_quantize=["text_encoder", "transformer"],
        )
        pipe = Cosmos2TextToImagePipeline.from_pretrained(
            self.model_id,
            torch_dtype=torch.bfloat16,
            quantization_config=config,
            use_safetensors=True,
            safety_checker=None,
            requires_safety_checker=False,
        ).to(self.device)
        return pipe


class CosmosRunner(BasePipelineRunner):
    def run(
        self, prompt: str, negative_prompt=None, **kwargs
    ) -> list[Image.Image]:
        return self.pipe(
            prompt=prompt, negative_prompt=negative_prompt, **kwargs
        ).images


# ===== Kolors =====
class KolorsLoader(BasePipelineLoader):
    def load(self):
        pipe = KolorsPipeline.from_pretrained(
            "Kwai-Kolors/Kolors-diffusers",
            torch_dtype=torch.float16,
            variant="fp16",
        )
        # As above, let CPU offload handle device placement rather than
        # moving the pipeline to the GPU first.
        pipe.enable_model_cpu_offload(device=self.device)
        pipe.enable_xformers_memory_efficient_attention()
        # Karras sigmas with DPM-Solver++ generally improve sample quality
        # at low step counts.
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config, use_karras_sigmas=True
        )
        return pipe


class KolorsRunner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> list[Image.Image]:
        return self.pipe(prompt=prompt, **kwargs).images


# ===== Flux =====
class FluxLoader(BasePipelineLoader):
    def load(self):
        # Reduce fragmentation-related OOMs in the CUDA caching allocator.
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
        )
        pipe.enable_model_cpu_offload(device=self.device)
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_attention_slicing()
        return pipe


class FluxRunner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> list[Image.Image]:
        return self.pipe(prompt=prompt, **kwargs).images


# ===== Chroma =====
class ChromaLoader(BasePipelineLoader):
    def load(self):
        return ChromaPipeline.from_pretrained(
            "lodestones/Chroma", torch_dtype=torch.bfloat16
        ).to(self.device)


class ChromaRunner(BasePipelineRunner):
    def run(
        self, prompt: str, negative_prompt=None, **kwargs
    ) -> list[Image.Image]:
        return self.pipe(
            prompt=prompt, negative_prompt=negative_prompt, **kwargs
        ).images


PIPELINE_REGISTRY = {
    "sd35": (SD35Loader, SD35Runner),
    "cosmos": (CosmosLoader, CosmosRunner),
    "kolors": (KolorsLoader, KolorsRunner),
    "flux": (FluxLoader, FluxRunner),
    "chroma": (ChromaLoader, ChromaRunner),
}
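
# Supporting another model only takes a Loader/Runner pair plus one entry
# here, e.g. "mymodel": (MyLoader, MyRunner) (hypothetical names).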


def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
    """Instantiate the loader/runner pair registered under `name`."""
    if name not in PIPELINE_REGISTRY:
        raise ValueError(
            f"Unsupported model: {name}. "
            f"Available: {sorted(PIPELINE_REGISTRY)}"
        )
    loader_cls, runner_cls = PIPELINE_REGISTRY[name]
    pipe = loader_cls(device=device).load()
    return runner_cls(pipe)


if __name__ == "__main__":
    model_name = "sd35"
    runner = build_hf_image_pipeline(model_name)
    # NOTE: For pipeline testing only; generation quality at this low
    # resolution and step count is poor.
    images = runner.run(
        prompt="A robot holding a sign that says 'Hello'",
        height=512,
        width=512,
        num_inference_steps=10,
        guidance_scale=6,
        num_images_per_prompt=1,
    )
    for i, img in enumerate(images):
        img.save(f"image_{model_name}_{i}.jpg")
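
# The same entry point drives every registered backend; a sketch for Flux:
#   runner = build_hf_image_pipeline("flux")
#   images = runner.run(
#       prompt="A robot holding a sign that says 'Hello'",
#       num_inference_steps=4,  # FLUX.1-schnell is distilled for few steps
#       guidance_scale=0.0,
#   )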