Spaces:
Running
on
Zero
Running
on
Zero
Update models/unet.py
Browse files- models/unet.py +8 -0
models/unet.py
CHANGED
|
@@ -759,24 +759,32 @@ class CondUnet1D(nn.Module):
|
|
| 759 |
cond,
|
| 760 |
cond_indices,
|
| 761 |
):
|
|
|
|
| 762 |
temb = self.time_mlp(t)
|
| 763 |
|
| 764 |
h = []
|
| 765 |
for block1, block2, downsample in self.downs:
|
|
|
|
|
|
|
| 766 |
x = block1(x, temb, cond, cond_indices)
|
| 767 |
x = block2(x, temb, cond, cond_indices)
|
| 768 |
h.append(x)
|
| 769 |
x = downsample(x)
|
| 770 |
|
|
|
|
|
|
|
| 771 |
x = self.mid_block1(x, temb, cond, cond_indices)
|
| 772 |
x = self.mid_block2(x, temb, cond, cond_indices)
|
| 773 |
|
| 774 |
for upsample, block1, block2 in self.ups:
|
| 775 |
x = upsample(x)
|
| 776 |
x = torch.cat((x, h.pop()), dim=1)
|
|
|
|
|
|
|
| 777 |
x = block1(x, temb, cond, cond_indices)
|
| 778 |
x = block2(x, temb, cond, cond_indices)
|
| 779 |
|
|
|
|
| 780 |
x = self.final_conv(x)
|
| 781 |
return x
|
| 782 |
|
|
|
|
| 759 |
cond,
|
| 760 |
cond_indices,
|
| 761 |
):
|
| 762 |
+
self.time_mlp = self.time_mlp.cuda()
|
| 763 |
temb = self.time_mlp(t)
|
| 764 |
|
| 765 |
h = []
|
| 766 |
for block1, block2, downsample in self.downs:
|
| 767 |
+
block1 = block1.cuda()
|
| 768 |
+
block2 = block2.cuda()
|
| 769 |
x = block1(x, temb, cond, cond_indices)
|
| 770 |
x = block2(x, temb, cond, cond_indices)
|
| 771 |
h.append(x)
|
| 772 |
x = downsample(x)
|
| 773 |
|
| 774 |
+
self.mid_block1 = self.mid_block1.cuda()
|
| 775 |
+
self.mid_block2 = self.mid_block2.cuda()
|
| 776 |
x = self.mid_block1(x, temb, cond, cond_indices)
|
| 777 |
x = self.mid_block2(x, temb, cond, cond_indices)
|
| 778 |
|
| 779 |
for upsample, block1, block2 in self.ups:
|
| 780 |
x = upsample(x)
|
| 781 |
x = torch.cat((x, h.pop()), dim=1)
|
| 782 |
+
block1 = block1.cuda()
|
| 783 |
+
block2 = block2.cuda()
|
| 784 |
x = block1(x, temb, cond, cond_indices)
|
| 785 |
x = block2(x, temb, cond, cond_indices)
|
| 786 |
|
| 787 |
+
self.final_conv = self.final_conv.cuda()
|
| 788 |
x = self.final_conv(x)
|
| 789 |
return x
|
| 790 |
|