update device
Browse files- Dyna-1/model/model.py +1 -6
Dyna-1/model/model.py
CHANGED
|
@@ -99,8 +99,7 @@ class ESM_model(nn.Module):
|
|
| 99 |
self.method = method
|
| 100 |
self.layer = layer
|
| 101 |
if 'esm3' in self.method:
|
| 102 |
-
|
| 103 |
-
self.model = ESM3.from_pretrained("esm3_sm_open_v1").to(DEVICE,non_blocking=True).to(torch.float32)
|
| 104 |
'''except GatedRepoError as e:
|
| 105 |
print(f"No access to gated repository: {e}")
|
| 106 |
except OSError as e:
|
|
@@ -109,10 +108,6 @@ class ESM_model(nn.Module):
|
|
| 109 |
else:
|
| 110 |
print(f"Other error occurred: {e}")'''
|
| 111 |
|
| 112 |
-
self.n_layers = len(self.model.transformer.blocks)
|
| 113 |
-
self.hidden_size = self.model.transformer.blocks[0].attn.d_model
|
| 114 |
-
elif 'esmc' in self.method:
|
| 115 |
-
self.model = ESMC.from_pretrained("esmc_300m").to(DEVICE,non_blocking=True).to(torch.float32)
|
| 116 |
self.n_layers = len(self.model.transformer.blocks)
|
| 117 |
self.hidden_size = self.model.transformer.blocks[0].attn.d_model
|
| 118 |
elif self.method == 'esm2':
|
|
|
|
| 99 |
self.method = method
|
| 100 |
self.layer = layer
|
| 101 |
if 'esm3' in self.method:
|
| 102 |
+
self.model = ESM3.from_pretrained("esm3_sm_open_v1").to(torch.float32)
|
|
|
|
| 103 |
'''except GatedRepoError as e:
|
| 104 |
print(f"No access to gated repository: {e}")
|
| 105 |
except OSError as e:
|
|
|
|
| 108 |
else:
|
| 109 |
print(f"Other error occurred: {e}")'''
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
self.n_layers = len(self.model.transformer.blocks)
|
| 112 |
self.hidden_size = self.model.transformer.blocks[0].attn.d_model
|
| 113 |
elif self.method == 'esm2':
|