# Generic Diffusion Framework (GDF)

# Basic usage
GDF is a simple framework for working with diffusion models. It implements the most common diffusion frameworks (DDPM / DDIM, EDM, Rectified Flows, etc.) and makes it very easy to switch between them or to combine parts of different frameworks.

Using GDF is very straightforward. First of all, just define an instance of the GDF class:
```python
from gdf import GDF
from gdf import CosineSchedule
from gdf import VPScaler, EpsilonTarget, CosineTNoiseCond, P2LossWeight

gdf = GDF(
    schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
    input_scaler=VPScaler(), target=EpsilonTarget(),
    noise_cond=CosineTNoiseCond(),
    loss_weight=P2LossWeight(),
)
```
You need to define the following components:
* **Train Schedule**: Returns the logSNR schedule used during training; some of the schedules can be configured. A train schedule is then called with a batch size and randomly samples that many values from the defined distribution (see the short sketch after this list).
* **Sample Schedule**: The schedule used later on when sampling. It may be different from the training schedule.
* **Input Scaler**: Whether to use Variance Preserving scaling or LERP scaling (rectified flows).
* **Target**: The prediction target during training, usually epsilon, x0, or v.
* **Noise Conditioning**: You could pass the logSNR to your model directly, but a normalized value is usually used instead; for example, the EDM framework proposes to use `-logSNR/8`.
* **Loss Weight**: There are many proposed loss-weighting strategies; here you define which one you'll use.
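As a quick illustration of the first point, a train schedule object is simply called with a batch size and returns that many randomly sampled logSNR values. A minimal sketch based on the description above (exact call signature may differ slightly):

```python
# Minimal sketch: sampling logSNRs for a batch during training.
# Assumes the schedule is callable with a batch size, as described above.
train_schedule = CosineSchedule(clamp_range=[0.0001, 0.9999])
logSNR = train_schedule(32)  # one logSNR value per element of a batch of 32
print(logSNR.shape)          # expected: roughly torch.Size([32])
```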
All of these classes are simple, logSNR-centric definitions; for example, the VPScaler is defined as just:
```python
class VPScaler():
    def __call__(self, logSNR):
        a_squared = logSNR.sigmoid()
        a = a_squared.sqrt()
        b = (1 - a_squared).sqrt()
        return a, b
```
So it's very easy to extend this framework with custom schedulers, scalers, targets, loss weights, etc.
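For instance, a custom noise conditioning following the `-logSNR/8` convention mentioned above could look like this (a hedged sketch; `EDMStyleNoiseCond` is a hypothetical name for illustration, not necessarily a class shipped with GDF):

```python
# Hypothetical custom component: EDM-style noise conditioning that feeds
# the model -logSNR/8 instead of the raw logSNR, as mentioned in the
# component list above.
class EDMStyleNoiseCond():
    def __call__(self, logSNR):
        return -logSNR / 8
```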
### Training
When you define your training loop, you can get everything you need by just doing:
```python
shift, loss_shift = 1, 1  # these can be set to higher values, as suggested by the Simple Diffusion paper for high resolutions
for inputs, extra_conditions in dataloader_iterator:
    noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(inputs, shift=shift, loss_shift=loss_shift)
    pred = diffusion_model(noised, noise_cond, extra_conditions)
    loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
    loss_adjusted = (loss * loss_weight).mean()

    loss_adjusted.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```
And that's it: you have a diffusion model training, and every element of the training is easy to customize through the GDF class.
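For example, switching to the hypothetical `EDMStyleNoiseCond` sketched earlier only requires changing the GDF instance; the training loop itself stays exactly the same (again, just a sketch):

```python
# Sketch: swap a single component without touching the training loop.
gdf = GDF(
    schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
    input_scaler=VPScaler(), target=EpsilonTarget(),
    noise_cond=EDMStyleNoiseCond(),  # hypothetical class from the sketch above
    loss_weight=P2LossWeight(),
)
```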
### Sampling
The other important part is sampling. When you want to sample with this framework, you can just do the following:
```python
from gdf import DDPMSampler

shift = 1
sampling_configs = {
    "timesteps": 30, "cfg": 7, "sampler": DDPMSampler(gdf), "shift": shift,
    "schedule": CosineSchedule(clamp_range=[0.0001, 0.9999])
}

*_, (sampled, _, _) = gdf.sample(
    diffusion_model, {"cond": extra_conditions}, latents.shape,
    unconditional_inputs={"cond": torch.zeros_like(extra_conditions)},
    device=device, **sampling_configs
)
```
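The `*_, (sampled, _, _)` pattern implies that `gdf.sample` is a generator: it yields a tuple at every sampling step, and the code above keeps only the first element of the last one. If you also want the intermediate states (e.g. to visualize the denoising trajectory), you can iterate over it explicitly. A sketch assuming that same tuple layout:

```python
# Sketch: collect the intermediate samples instead of only the final one.
# Assumes each yielded tuple's first element is the current sample, as in
# the unpacking above.
intermediates = []
for step_outputs in gdf.sample(
    diffusion_model, {"cond": extra_conditions}, latents.shape,
    unconditional_inputs={"cond": torch.zeros_like(extra_conditions)},
    device=device, **sampling_configs
):
    intermediates.append(step_outputs[0])
sampled = intermediates[-1]
```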
# Available modules
TODO