Spaces: Running on Zero
import unittest

import torch
import torch.nn.functional as F

from diffusers import VQDiffusionScheduler

from .test_schedulers import SchedulerCommonTest
class VQDiffusionSchedulerTest(SchedulerCommonTest):
    """Scheduler test suite specialized for ``VQDiffusionScheduler``.

    Supplies the scheduler config, dummy inputs and a dummy model consumed by
    the shared checks inherited from ``SchedulerCommonTest`` (e.g.
    ``check_over_configs`` / ``check_over_forward``), and opts out of an
    inherited test that is not supported here.
    """

    scheduler_classes = (VQDiffusionScheduler,)

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config, with ``kwargs`` overriding it.

        Defaults are ``num_vec_classes=4097`` and ``num_train_timesteps=100``;
        any keyword passed in replaces (or adds to) those entries.
        """
        config = {
            "num_vec_classes": 4097,
            "num_train_timesteps": 100,
        }
        config.update(kwargs)
        return config

    def dummy_sample(self, num_vec_classes):
        """Return a random integer sample tensor of shape ``(4, 64)``.

        Values are drawn uniformly from ``[0, num_vec_classes)`` (the upper
        bound of ``torch.randint`` is exclusive). The second dimension is a
        flattened 8x8 grid: one row of ``height * width`` entries per item.
        """
        batch_size = 4
        height = 8
        width = 8
        return torch.randint(0, num_vec_classes, (batch_size, height * width))

    def dummy_sample_deter(self):
        """No deterministic dummy sample is provided; always raises.

        Raises:
            AssertionError: unconditionally.
        """
        # Explicit raise rather than ``assert False``: asserts are stripped under
        # ``python -O``, which would make this silently return ``None``.
        # NOTE(review): if SchedulerCommonTest reads ``dummy_sample_deter`` as an
        # attribute/property, this override likely needs ``@property`` — confirm.
        raise AssertionError("dummy_sample_deter is not supported for VQDiffusionScheduler")

    def dummy_model(self, num_vec_classes):
        """Return a stand-in model that emits random log-probabilities.

        The returned callable ``model(sample, t, *args)`` ignores ``t`` and any
        extra positional arguments. ``sample`` must be 2-D,
        ``(batch_size, num_latent_pixels)``. The result is a float32 tensor of
        shape ``(batch_size, num_vec_classes - 1, num_latent_pixels)``,
        ``log_softmax``-normalized over dim 1.
        """

        def model(sample, t, *args):
            batch_size, num_latent_pixels = sample.shape
            # NOTE(review): one fewer output class than ``num_vec_classes`` —
            # presumably the scheduler's mask class is not predicted; confirm.
            logits = torch.rand((batch_size, num_vec_classes - 1, num_latent_pixels))
            # Normalize in float64 (presumably for precision), then cast to float32.
            return F.log_softmax(logits.double(), dim=1).float()

        return model

    def test_timesteps(self):
        """Run the shared config checks across several ``num_train_timesteps``."""
        for timesteps in [2, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_num_vec_classes(self):
        """Run the shared config checks across several ``num_vec_classes``."""
        for num_vec_classes in [5, 100, 1000, 4000]:
            self.check_over_configs(num_vec_classes=num_vec_classes)

    def test_time_indices(self):
        """Run the shared forward checks at the first, middle and last timestep."""
        for t in [0, 50, 99]:
            self.check_over_forward(time_step=t)

    @unittest.skip("Test not supported.")
    def test_add_noise_device(self):
        """Disabled for this scheduler; reported as skipped rather than passed."""