# MTG_Drafting_AI / src/draft_model.py
from pathlib import Path
import json
import torch
from typing import List, Dict
from huggingface_hub import hf_hub_download
from src.models.winrate_model import Winrate_Model
from src.training import train_mlp
from src.utils import utils
MODEL_REPO = "TimoBertram/MTG_Model_FIN/"
CFG_FILE = "config.json"
MODEL_FILE = "network.pt"
DATA_REPO = "TimoBertram/MTG_Drafting_Dataset/"
CARD_FILE = "cards_eoe.json"
ENCODING_FILE = "card_encodings.pt"
class DraftModel:
    def __init__(self, model_path: str):
        # model_path is kept for API compatibility; config, weights and
        # embeddings are all fetched from the Hugging Face Hub below.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        cfg_path = hf_hub_download(
            repo_id=MODEL_REPO, filename=CFG_FILE, repo_type="model"
        )
        weight_path = hf_hub_download(
            repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
        )

        # ---- load network ---------------------------------------------------
        with open(cfg_path, "r") as f:
            cfg = json.load(f)
        cfg.pop("name", None)  # "name" is metadata, not a constructor argument
        self.net = train_mlp.MLP_CrossAttention(**cfg).to(self.device)
        self.net.load_state_dict(torch.load(weight_path, map_location=self.device))
        self.net.eval()

        # ---- embeddings – one-time load ------------------------------------
        self.embed_dict = utils.get_embedding_dict(
            hf_hub_download(repo_id=DATA_REPO, filename=ENCODING_FILE, repo_type="dataset"),
            add_nontransformed=True,
        )

    # --------------------------------------------------------------------- #
    #  Public API expected by streamlit_app.py                              #
    # --------------------------------------------------------------------- #
    @torch.no_grad()
    def predict(self, pack: List[Dict], picks: List[Dict], deck: List[Dict]) -> List[float]:
        """Return a pick probability for every card in the current pack."""
        names = [c["name"] for c in pack]

        def embed(name):  # helper: look up the pre-computed embedding of one card
            return utils.get_card_embeddings((name,), embedding_dict=self.embed_dict)[0]

        card_t = torch.stack([embed(n) for n in names]).unsqueeze(0).to(self.device)
        # picks/deck are currently unused: the deck context is an all-zero
        # placeholder, assumed to share the embedding width of the pack cards.
        deck_t = torch.zeros((1, 45, card_t.shape[-1]), device=self.device)
        return torch.softmax(self.net(card_t, deck_t), dim=1).squeeze(0).cpu().numpy().tolist()
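

# --------------------------------------------------------------------- #
#  Minimal usage sketch (illustrative only, not part of the Streamlit   #
#  app).  The card names below are placeholders and must exist in the   #
#  downloaded embedding dictionary; model_path is kept only for API     #
#  compatibility and is not read here.                                  #
# --------------------------------------------------------------------- #
if __name__ == "__main__":
    model = DraftModel(model_path=".")
    pack = [{"name": "Placeholder Card A"}, {"name": "Placeholder Card B"}]
    probs = model.predict(pack, picks=[], deck=[])
    for card, p in zip(pack, probs):
        print(f"{card['name']}: {p:.3f}")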