import torch
import torch.nn as nn
import pickle


class Deck_Attention(nn.Module):
    """Transformer encoder over a (padded) 45-card deck.

    A learned CLS token is prepended and its final hidden state is projected
    to `output_dim`.
    """

    def __init__(self, input_size, output_dim, num_heads=8, num_layers=3,
                 output_layers=2, dropout=0.2):
        super(Deck_Attention, self).__init__()
        # Input projection and normalization
        self.hidden_dim = 1024
        self.input_proj = nn.Linear(input_size, self.hidden_dim, bias=False)
        self.input_norm = nn.LayerNorm(self.hidden_dim, bias=False)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_dim))
        self.pos_encoding = nn.Embedding(45, self.hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_dim,
            nhead=num_heads,
            dim_feedforward=self.hidden_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.layers = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
            enable_nested_tensor=False,
        )
        self.transformer_norm = nn.LayerNorm(self.hidden_dim, bias=False)

        # Output projection: residual MLP blocks followed by a final head
        self.output_proj = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.hidden_dim, self.hidden_dim, bias=False),
                nn.GELU(),
                nn.LayerNorm(self.hidden_dim, bias=False),
                nn.Dropout(dropout),
            )
            for _ in range(output_layers)
        ])
        self.final_layer = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim, bias=False),
            nn.LayerNorm(self.hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(self.hidden_dim, output_dim, bias=False),
        )

    def forward(self, x, lens=None):
        # Reshape input if needed: [..., 45, input_size] -> [B, 45, input_size]
        x = x.view(x.size(0), x.size(-2), x.size(-1))
        batch_size = x.size(0)

        # Create padding mask (True = padded slot); prepend False for the CLS token
        padding_mask = None
        if lens is not None:
            lens = lens.to(x.device)
            padding_mask = torch.arange(45, device=x.device).expand(batch_size, 45) >= lens.unsqueeze(1)
            padding_mask = torch.cat(
                (torch.zeros(padding_mask.shape[0], 1, device=padding_mask.device).bool(), padding_mask),
                dim=1,
            )

        # Initial projection and add position embeddings
        x = self.input_proj(x)
        pos = torch.arange(45, device=x.device).expand(batch_size, 45)
        pos = self.pos_encoding(pos)
        x = x + pos

        # Prepend the CLS token and run the transformer encoder
        x = torch.cat([self.cls_token.expand(batch_size, -1, -1), x], dim=1)
        x = self.input_norm(x)
        x = self.layers(x, src_key_padding_mask=padding_mask)
        x = self.transformer_norm(x)

        # Pool via the CLS token, then apply residual output blocks and the final head
        x = x[:, 0, :]
        for layer in self.output_proj:
            x = x + layer(x)
        x = self.final_layer(x)
        return x


class Card_Preprocessing(nn.Module):
    """Per-card MLP encoder with gated residual hidden layers."""

    def __init__(self, num_layers, input_size, output_size, nonlinearity=nn.GELU,
                 internal_size=1024, dropout=0):
        super(Card_Preprocessing, self).__init__()
        self.internal_size = internal_size
        self.input = nn.Sequential(
            nn.Linear(input_size, internal_size, bias=False),
            nonlinearity(),
            nn.LayerNorm(internal_size, bias=False),
            nn.Dropout(dropout),
        )
        self.hidden_layers = nn.ModuleList()
        self.dropout_rate = dropout
        for i in range(num_layers):
            self.hidden_layers.append(nn.Sequential(
                nn.Linear(internal_size, internal_size, bias=False),
                nonlinearity(),
                nn.LayerNorm(internal_size, bias=False),
                nn.Dropout(dropout),
            ))
        self.output = nn.Sequential(
            nn.Linear(internal_size, output_size, bias=False),
            nonlinearity(),
            nn.LayerNorm(output_size, bias=False),
        )
        # Per-layer learned gates for the residual mixing in forward()
        self.gammas = nn.ParameterList([
            nn.Parameter(torch.ones(1, internal_size), requires_grad=True)
            for i in range(num_layers)
        ])

    def forward(self, x):
        x = self.input(x)
        for i, layer in enumerate(self.hidden_layers):
            # Gated residual: sigmoid(gamma) interpolates between identity and the layer output
            gamma = torch.sigmoid(self.gammas[i])
            x = gamma * x + (1 - gamma) * layer(x)
        x = self.output(x)
        return x
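# Usage sketch for the two encoders above (a non-authoritative example; the
# batch size and the 1330-dim card features are illustrative choices, only the
# 45-slot deck length is fixed by the code):
#
#   deck_encoder = Deck_Attention(input_size=1330, output_dim=256)
#   decks = torch.randn(8, 45, 1330)          # [batch, deck_slots, card_features]
#   lens = torch.randint(1, 46, (8,))         # number of real (non-padded) cards per deck
#   deck_vec = deck_encoder(decks, lens)      # [8, 256], pooled from the CLS token
#
#   card_encoder = Card_Preprocessing(num_layers=3, input_size=1330, output_size=256)
#   card_vecs = card_encoder(torch.randn(8, 14, 1330))   # applied per card -> [8, 14, 256]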
class CrossAttnBlock(nn.Module):
    """One deck→pack cross-attention block, Pre-LayerNorm style.

    cards : [B, K, d] (queries)
    deck  : [B, D, d] (keys / values)
    returns updated cards [B, K, d]
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float):
        super().__init__()
        self.ln_q = nn.LayerNorm(d_model)
        self.ln_k = nn.LayerNorm(d_model)
        self.ln_v = nn.LayerNorm(d_model)
        self.xattn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True)
        self.ln_ff = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )
        self.dropout_attn = nn.Dropout(dropout)

    def forward(self, cards, deck, mask=None):
        # 1) deck → card cross-attention
        q = self.ln_q(cards)
        k = self.ln_k(deck)
        v = self.ln_v(deck)
        attn_out, _ = self.xattn(q, k, v, key_padding_mask=mask)  # [B, K, d]
        x = cards + self.dropout_attn(attn_out)  # residual

        # 2) position-wise feed-forward
        y = self.ffn(self.ln_ff(x))
        return x + y


class MLP_CrossAttention(nn.Module):
    """Scores each card in a pack against the current deck via cross-attention."""

    def __init__(self, input_size, num_card_layers, card_output_dim, dropout, **kwargs):
        super(MLP_CrossAttention, self).__init__()
        self.input_size = input_size
        self.card_encoder = Card_Preprocessing(
            num_card_layers,
            input_size=input_size,
            internal_size=1024,
            output_size=card_output_dim,
            dropout=dropout,
        )
        self.attention_layers = nn.ModuleList([
            CrossAttnBlock(card_output_dim, n_heads=4, dropout=dropout)
            for _ in range(10)
        ])
        self.output_layer = nn.Sequential(
            nn.Linear(card_output_dim, card_output_dim * 2),
            nn.ReLU(),
            nn.LayerNorm(card_output_dim * 2, bias=False),
            nn.Dropout(dropout),
            nn.Linear(card_output_dim * 2, card_output_dim * 4),
            nn.ReLU(),
            nn.LayerNorm(card_output_dim * 4, bias=False),
            nn.Dropout(dropout),
            nn.Linear(card_output_dim * 4, card_output_dim),
            nn.ReLU(),
            nn.LayerNorm(card_output_dim, bias=False),
            nn.Linear(card_output_dim, 1),
        )

        if kwargs['path'] is not None:
            self.load_state_dict(torch.load(f"{kwargs['path']}/network.pt", map_location='cpu'))
            print(f"Loaded model from {kwargs['path']}/network.pt")

    def forward(self, deck, cards, get_embeddings=False, no_attention=False):
        batch_size, deck_size, card_size = deck.shape
        deck = deck.view(batch_size * deck_size, card_size)
        deck_encoded = self.card_encoder(deck)
        deck_encoded = deck_encoded.view(batch_size, deck_size, -1)

        # Identify padded cards (all-zero feature vectors)
        mask = (cards.sum(dim=-1) != 0)
        cards_encoded = self.card_encoder(cards)

        if not no_attention:
            # Cross-attention from the pack cards onto the deck
            for layer in self.attention_layers:
                cards_encoded = layer(cards_encoded, deck_encoded)

        if get_embeddings:
            # Run all but the last three output layers to get card embeddings
            for layer in self.output_layer[:-3]:
                cards_encoded = layer(cards_encoded)
            return cards_encoded

        # Output layer
        logits = self.output_layer(cards_encoded)

        # Mask out padded cards
        logits = logits.masked_fill(~mask.unsqueeze(-1), float('-inf'))
        return logits.squeeze(-1)

    def get_card_embedding(self, card_embedding):
        # Encode a single card with the shared card encoder.
        card_embedding = card_embedding.view(1, 1, -1)
        return self.card_encoder(card_embedding).squeeze()
        # Unreachable alternative kept from the original code: run the card
        # through the full network against an empty deck instead.
        # empty_deck = torch.zeros((1, 45, self.input_size)).to(card_embedding.device)
        # return self(deck=empty_deck, cards=card_embedding,
        #             get_embeddings=True, no_attention=True).squeeze(0)


def get_embedding_dict(path, add_nontransformed=False):
    """Load a {card name -> embedding} dict; optionally index split cards by their front face."""
    with open(path, 'rb') as f:
        embedding_dict = pickle.load(f)
    if add_nontransformed:
        embedding_dict_tmp = {}
        for k, v in embedding_dict.items():
            embedding_dict_tmp[k] = v
            if '//' in k:
                embedding_dict_tmp[k.split(' // ')[0]] = v
        embedding_dict = embedding_dict_tmp
    return embedding_dict
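# Usage sketch for MLP_CrossAttention (a non-authoritative example; deck and
# pack sizes are illustrative, `path=None` skips checkpoint loading, and
# 'embeddings.pkl' is a hypothetical pickle produced elsewhere):
#
#   picker = MLP_CrossAttention(input_size=1330, num_card_layers=3,
#                               card_output_dim=256, dropout=0.1, path=None)
#   deck = torch.randn(8, 45, 1330)           # current deck, zero-padded
#   pack = torch.randn(8, 14, 1330)           # cards available to pick
#   scores = picker(deck, pack)               # [8, 14] pick logits, -inf on padded cards
#
#   embedding_dict = get_embedding_dict('embeddings.pkl', add_nontransformed=True)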
def get_card_embeddings(card_names, embedding_dict, embedding_size=1330):
    """Look up embeddings for a list of card names (or lists of names for whole decks)."""
    embeddings = []
    for card in card_names:
        if card == '':
            embeddings.append([])
        elif card == []:
            # Empty deck: return a single all-zero embedding
            if type(embedding_size) == tuple:
                channels, height, width = embedding_size
                new_embedding = torch.zeros(1, channels, height, width)
            else:
                new_embedding = torch.zeros(1, embedding_size)
            embeddings.append(new_embedding)
        elif isinstance(card, list):
            if len(card) == 0:
                embeddings.append(None)
                continue
            deck_embedding = []
            for c in card:
                embedding, got_new = get_embedding_of_card(c, embedding_dict)
                deck_embedding.append(embedding)
            num_cards = len(deck_embedding)
            deck_embedding = torch.stack(deck_embedding)
            if type(embedding_size) == tuple:
                channels, height, width = embedding_size
                deck_embedding = deck_embedding.view(num_cards, channels, height, width)
            else:
                deck_embedding = deck_embedding.view(num_cards, -1)
            embeddings.append(deck_embedding)
        else:
            embedding, got_new = get_embedding_of_card(card, embedding_dict)
            embeddings.append(embedding)
    return embeddings


def check_for_basics(card_name, embedding_dict):
    """Map numbered basic land variants (e.g. 'Mountain_3') back to the plain basic name."""
    ints = ['1', '2', '3', '4', '5']
    basics = ['Mountain', 'Forest', 'Swamp', 'Island', 'Plains']
    for b in basics:
        if b in card_name:
            for i in ints:
                if card_name == f'{b}_{i}':
                    return b
    return card_name


def get_embedding_of_card(card_name, embedding_dict):
    """Return (embedding, got_new) for a card, trying several name normalizations."""
    try:
        card_name = check_for_basics(card_name, embedding_dict)
        card_name = card_name.replace('_', ' ')
        card_name = card_name.replace("Sol'kanar", "Sol'Kanar")
        # Try the name as-is, then the front face of a split card, then the
        # name without the Alchemy 'A-' prefix.
        candidates = [card_name, card_name.split(' // ')[0], card_name.replace('A-', '')]
        for key in candidates:
            if key in embedding_dict:
                return torch.Tensor(embedding_dict[key]), False
        # Requesting a fresh embedding is disabled:
        # attributes, text = get_card_representation(card_name=card_name)
        # text_embedding = embedd_text([text]).squeeze()
        # return torch.Tensor(np.concatenate((attributes, text_embedding), axis=0)), True
        raise Exception(f'Could not find {card_name}')
    except Exception as e:
        print(f'Could not find {card_name}')
        print(e)
        raise e
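if __name__ == '__main__':
    # Minimal smoke test with random tensors and an in-memory embedding dict
    # (a sketch, not part of the training pipeline). The 1330-dim card features
    # match the default embedding_size above; all other dimensions and card
    # names are illustrative.
    emb_size = 1330
    embedding_dict = {
        'Lightning Bolt': torch.randn(emb_size).tolist(),
        'Fire // Ice': torch.randn(emb_size).tolist(),
        'Mountain': torch.randn(emb_size).tolist(),
    }
    embeddings = get_card_embeddings(
        ['Lightning Bolt', 'Mountain_3', ['Fire // Ice', 'Mountain']],
        embedding_dict, embedding_size=emb_size)
    print([e.shape for e in embeddings])  # [torch.Size([1330]), torch.Size([1330]), torch.Size([2, 1330])]

    deck = torch.randn(2, 45, emb_size)
    pack = torch.randn(2, 14, emb_size)

    deck_encoder = Deck_Attention(input_size=emb_size, output_dim=128)
    print(deck_encoder(deck, lens=torch.tensor([30, 45])).shape)  # torch.Size([2, 128])

    picker = MLP_CrossAttention(input_size=emb_size, num_card_layers=2,
                                card_output_dim=256, dropout=0.1, path=None)
    print(picker(deck, pack).shape)  # torch.Size([2, 14])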