Update models/unet.py
Browse files- models/unet.py +4 -4
models/unet.py
CHANGED
|
@@ -857,10 +857,10 @@ class MotionCLR(nn.Module):
|
|
| 857 |
edit_config=edit_config,
|
| 858 |
)
|
| 859 |
|
| 860 |
-
self.embed_text = self.embed_text.
|
| 861 |
-
self.textTransEncoder = self.textTransEncoder.
|
| 862 |
-
self.text_ln = self.text_ln.
|
| 863 |
-
self.unet = self.unet.
|
| 864 |
|
| 865 |
def encode_text(self, raw_text, device):
|
| 866 |
with torch.no_grad():
|
|
|
|
| 857 |
edit_config=edit_config,
|
| 858 |
)
|
| 859 |
|
| 860 |
+
self.embed_text = self.embed_text.cuda()
|
| 861 |
+
self.textTransEncoder = self.textTransEncoder.cuda()
|
| 862 |
+
self.text_ln = self.text_ln.cuda()
|
| 863 |
+
self.unet = self.unet.cuda()
|
| 864 |
|
| 865 |
def encode_text(self, raw_text, device):
|
| 866 |
with torch.no_grad():
|