# MTG_Drafting_AI / src/draft_model.py
# Timo
# Works now
# 8ef1a8f
import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional

import torch
from huggingface_hub import hf_hub_download

from helpers import get_embedding_dict, get_card_embeddings, MLP_CrossAttention
# Hugging Face Hub *model* repository: network weights + constructor config.
MODEL_REPO: str = "TimoBertram/MTG_Model_FIN"
MODEL_FILE: str = "network.pt"
CFG_FILE: str = "config.json"

# Hugging Face Hub *dataset* repository: card metadata + card embeddings.
DATA_REPO: str = "TimoBertram/MTG_Drafting_Dataset"
CARD_FILE: str = "cards_eoe.json"
ENCODING_FILE: str = "card_encodings.pt"

class DraftModel:
    """Pick recommender backed by a pretrained ``MLP_CrossAttention`` network.

    On construction, the network weights and config are downloaded from the
    ``MODEL_REPO`` model repository on the Hugging Face Hub. Card embeddings
    and card metadata are downloaded from the ``DATA_REPO`` dataset repository.
    Inference runs on the CPU.

    Attributes:
        device: Torch device used for inference (always CPU).
        net: The loaded network, switched to eval mode.
        embed_dict: Card-embedding lookup built by ``get_embedding_dict``.
        emb_size: Size of dimension 0 of one entry of ``embed_dict``.
        cards: Card records grouped as ``{set_code: {card_name: card_dict}}``.
    """

    # Number of zero-filled deck slots fed to the network when nothing has
    # been picked yet. Presumably 3 packs x 15 picks -- TODO confirm against
    # the training setup.
    EMPTY_DECK_SLOTS = 45

    def __init__(self):
        # Inference is pinned to the CPU for compatibility, even when CUDA
        # is available.
        self.device = torch.device("cpu")

        # ---- load network ---------------------------------------------------
        weight_path = hf_hub_download(
            repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
        )
        cfg_path = hf_hub_download(
            repo_id=MODEL_REPO, filename=CFG_FILE, repo_type="model"
        )
        with open(cfg_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        self.net = MLP_CrossAttention(**cfg)
        # NOTE(review): torch.load unpickles the downloaded file; before
        # torch 2.6 this is not restricted to tensors by default. Consider
        # weights_only=True once the minimum supported torch version allows.
        self.net.load_state_dict(torch.load(weight_path, map_location=self.device))
        self.net.eval()
        self.net.to(self.device)

        # ---- embeddings - one-time load -------------------------------------
        self.embed_dict = get_embedding_dict(
            hf_hub_download(repo_id=DATA_REPO, filename=ENCODING_FILE, repo_type="dataset"),
            add_nontransformed=True,
        )
        # Embedding width, read from an arbitrary entry (all entries are
        # assumed to share it).
        self.emb_size = next(iter(self.embed_dict.values())).shape[0]

        # ---- card metadata, grouped by set code -----------------------------
        card_path = hf_hub_download(
            repo_id=DATA_REPO, filename=CARD_FILE, repo_type="dataset"
        )
        # Context manager so the file handle is closed deterministically.
        with open(card_path, "r", encoding="utf-8") as f:
            raw_cards = json.load(f)
        self.cards = defaultdict(dict)
        for card in raw_cards:
            self.cards[card["set"]][card["name"]] = card

    def _embed(self, name):
        """Return the embedding of a single card name via ``get_card_embeddings``."""
        return get_card_embeddings((name,), embedding_dict=self.embed_dict)[0]

    # --------------------------------------------------------------------- #
    #                 Public API expected by streamlit_app.py               #
    # --------------------------------------------------------------------- #
    @torch.no_grad()
    def predict(self, pack: List[str], deck: Optional[List[str]]) -> Dict:
        """Score every card in ``pack`` given the cards already drafted.

        Args:
            pack: Names of the cards available to pick.
            deck: Names of the cards already picked. ``None`` or an empty list
                means nothing has been picked yet; the network then receives
                ``EMPTY_DECK_SLOTS`` zero vectors as the deck.

        Returns:
            A dict with ``"pick"`` (name of the highest-scoring pack card),
            ``"logits"`` (raw network outputs as a list) and ``"scores"``
            (their softmax as a list), presumably one entry per pack card.

        Raises:
            ValueError: If ``pack`` is empty.
        """
        if not pack:
            raise ValueError("pack must contain at least one card")
        card_t = torch.stack([self._embed(c) for c in pack]).unsqueeze(0).to(self.device)
        if not deck:
            # torch.stack cannot build a tensor from an empty list, so an
            # empty deck is treated exactly like ``None``.
            deck_t = torch.zeros(
                (1, self.EMPTY_DECK_SLOTS, self.emb_size), device=self.device
            )
        else:
            deck_t = torch.stack([self._embed(c) for c in deck]).unsqueeze(0).to(self.device)
        vals = self.net(deck=deck_t, cards=card_t)
        # Assumes the network output is (1, len(pack)) -- TODO confirm.
        scores = torch.softmax(vals, dim=1).squeeze(0).cpu().numpy()
        logits = vals.squeeze(0).cpu().numpy()
        return {
            "pick": pack[int(scores.argmax())],
            "logits": logits.tolist(),
            "scores": scores.tolist(),
        }

    @torch.no_grad()
    def get_p1p1(self, set_code: str):
        """Score every card of ``set_code`` as a first pick with an empty deck.

        Returns:
            ``(names, logits)``: the set's card names and the network's raw
            logits for them, as returned by ``predict``, in the same order.

        Raises:
            KeyError: If no cards are loaded for ``set_code``.
        """
        # .get avoids inserting an empty entry into the defaultdict.
        set_cards = self.cards.get(set_code)
        if not set_cards:
            raise KeyError(f"no cards loaded for set code {set_code!r}")
        keys = list(set_cards)
        vals = self.predict(pack=keys, deck=None)["logits"]
        return keys, vals