Timo commited on
Commit
ebe0afd
·
1 Parent(s): 83f964b

weird fixes

Browse files
Files changed (1) hide show
  1. src/draft_model.py +2 -2
src/draft_model.py CHANGED
@@ -20,7 +20,7 @@ ENCODING_FILE = "card_encodings.pt"
20
 
21
 
22
  class DraftModel:
23
- def __init__(self, model_path: str):
24
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
  cfg_path = hf_hub_download(
@@ -35,7 +35,7 @@ class DraftModel:
35
  cfg.pop("name", None)
36
 
37
  self.net = MLP_CrossAttention(**cfg).to(self.device)
38
- self.net.load_state_dict(torch.load(Path(model_path) / "network.pt", map_location=self.device))
39
  self.net.eval()
40
 
41
  # ---- embeddings – one-time load ------------------------------------
 
20
 
21
 
22
  class DraftModel:
23
+ def __init__(self):
24
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
  cfg_path = hf_hub_download(
 
35
  cfg.pop("name", None)
36
 
37
  self.net = MLP_CrossAttention(**cfg).to(self.device)
38
+ self.net.load_state_dict(weight_path, map_location=self.device))
39
  self.net.eval()
40
 
41
  # ---- embeddings – one-time load ------------------------------------