Spaces:
Build error
Build error
Commit
·
2012098
1
Parent(s):
9629d26
Minor
Browse files
models/mask_transformer/transformer.py
CHANGED
|
@@ -179,9 +179,10 @@ class MaskTransformer(nn.Module):
|
|
| 179 |
clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
|
| 180 |
jit=False) # Must set jit=False for training
|
| 181 |
# Cannot run on cpu
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
|
|
|
| 185 |
|
| 186 |
# Freeze CLIP weights
|
| 187 |
clip_model.eval()
|
|
@@ -731,9 +732,10 @@ class ResidualTransformer(nn.Module):
|
|
| 731 |
clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
|
| 732 |
jit=False) # Must set jit=False for training
|
| 733 |
# Cannot run on cpu
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
|
|
|
|
| 737 |
|
| 738 |
# Freeze CLIP weights
|
| 739 |
clip_model.eval()
|
|
|
|
| 179 |
clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
|
| 180 |
jit=False) # Must set jit=False for training
|
| 181 |
# Cannot run on cpu
|
| 182 |
+
if str(self.evice) != "cpu":
|
| 183 |
+
clip.model.convert_weights(
|
| 184 |
+
clip_model) # Actually this line is unnecessary since clip by default already on float16
|
| 185 |
+
# Date 0707: It's necessary, only unecessary when load directly to gpu. Disable if need to run on cpu
|
| 186 |
|
| 187 |
# Freeze CLIP weights
|
| 188 |
clip_model.eval()
|
|
|
|
| 732 |
clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
|
| 733 |
jit=False) # Must set jit=False for training
|
| 734 |
# Cannot run on cpu
|
| 735 |
+
if str(self.evice) != "cpu":
|
| 736 |
+
clip.model.convert_weights(
|
| 737 |
+
clip_model) # Actually this line is unnecessary since clip by default already on float16
|
| 738 |
+
# Date 0707: It's necessary, only unecessary when load directly to gpu. Disable if need to run on cpu
|
| 739 |
|
| 740 |
# Freeze CLIP weights
|
| 741 |
clip_model.eval()
|
options/base_option.py
CHANGED
|
@@ -12,7 +12,7 @@ class BaseOptions():
|
|
| 12 |
|
| 13 |
self.parser.add_argument('--vq_name', type=str, default="rvq_nq1_dc512_nc512", help='Name of the rvq model.')
|
| 14 |
|
| 15 |
-
self.parser.add_argument("--gpu_id", type=int, default
|
| 16 |
self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name, {t2m} for humanml3d, {kit} for kit-ml')
|
| 17 |
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here.')
|
| 18 |
|
|
|
|
| 12 |
|
| 13 |
self.parser.add_argument('--vq_name', type=str, default="rvq_nq1_dc512_nc512", help='Name of the rvq model.')
|
| 14 |
|
| 15 |
+
self.parser.add_argument("--gpu_id", type=int, default=-1, help='GPU id')
|
| 16 |
self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name, {t2m} for humanml3d, {kit} for kit-ml')
|
| 17 |
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here.')
|
| 18 |
|