initial commit
- app.py +75 -0
- attention.py +53 -0
- models.py +162 -0
- requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,75 @@
import random
from typing import *

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import sentencepiece as sp
import torch

from huggingface_hub import hf_hub_download
from torchtext.datasets import Multi30k

from models import Seq2Seq


# Load model
model_path = hf_hub_download("msarmi9/multi30k", "models/de-en/model.bin")
model = Seq2Seq(vocab_size=8000, hidden_dim=512, bos_idx=1, eos_idx=2, pad_idx=3, temperature=2)
model.load_state_dict(torch.load(model_path))
model.eval()

# Load sentencepiece tokenizers
source_spm_path = hf_hub_download("msarmi9/multi30k", "models/de-en/de8000.model")
target_spm_path = hf_hub_download("msarmi9/multi30k", "models/de-en/en8000.model")
source_spm = sp.SentencePieceProcessor(model_file=source_spm_path, add_eos=True)
target_spm = sp.SentencePieceProcessor(model_file=target_spm_path, add_eos=True)

# Load test set for example inputs
normalize = lambda sample: (sample[0].lower().strip(), sample[1].lower().strip())
test_source, _ = zip(*map(normalize, Multi30k(split="test", language_pair=("de", "en"))))


def attention_heatmap(input_tokens: List[str], output_tokens: List[str], weights: np.ndarray) -> plt.Figure:
    figure = plt.figure(dpi=800, tight_layout=True)
    axes = sns.heatmap(weights, cmap="gray", cbar=False)
    axes.set_xticklabels(input_tokens, rotation=90)
    axes.set_yticklabels(output_tokens, rotation=0)
    axes.tick_params(axis="both", length=0)
    axes.xaxis.tick_top()
    plt.close()
    return figure


@torch.inference_mode()
def run(input: str) -> Tuple[str, plt.Figure]:
    """Run inference on a single sentence. Returns prediction and attention heatmap."""
    input_tensor = torch.tensor(source_spm.encode(input), dtype=torch.int64)
    output, weights = model.decode(input_tensor, max_decode_length=max(len(input_tensor), 80))
    output = target_spm.decode(output.detach().tolist())
    input_tokens = source_spm.encode(input, out_type=str)
    output_tokens = target_spm.encode(output, out_type=str)
    return output, attention_heatmap(input_tokens, output_tokens, weights.detach().numpy())


if __name__ == "__main__":
    interface = gr.Interface(
        run,
        inputs=gr.inputs.Textbox(lines=4, label="German"),
        outputs=[
            gr.outputs.Textbox(label="English"),
            gr.outputs.Image(type="plot", label="Attention Heatmap"),
        ],
        title="Multi30k Translation Widget",
        examples=random.sample(test_source, k=30),
        examples_per_page=10,
        allow_flagging="never",
        theme="huggingface",
        live=True,
    )

    interface.launch(
        enable_queue=True,
        cache_examples=True,
    )
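For a quick smoke test outside of the Gradio UI, calling run() directly should work once the module-level model and tokenizers above have downloaded; the German sentence below is only an illustrative input, not one from the test set:

# Hypothetical smoke test, assuming the module-level model and tokenizers loaded successfully.
translation, figure = run("zwei hunde spielen im schnee.")   # illustrative German input
print(translation)
figure.savefig("attention.png")                              # persist the heatmap returned by run()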
attention.py ADDED
@@ -0,0 +1,53 @@
from typing import *

import torch
import torch.nn as nn
import torch.nn.functional as F

Tensor = torch.Tensor


class Attention(nn.Module):
    """Container for applying an attention scoring function."""

    def __init__(self, score: nn.Module, dropout: nn.Module = None):
        super().__init__()
        self.score = score
        self.dropout = dropout

    def forward(self, decoder_state: Tensor, encoder_state: Tensor, source_mask: Tensor = None) -> Tuple[Tensor, Tensor]:
        """Return context and attention weights. Accepts a boolean mask indicating padding in the source sequence."""
        (B, L, D), (B, T, _) = decoder_state.shape, encoder_state.shape
        scores = self.score(decoder_state, encoder_state)        # (B, L, T)
        if source_mask is not None:                              # (B, T)
            scores.masked_fill_(source_mask.view(B, 1, T), -1e4)
        weights = F.softmax(scores, dim=-1)                      # (B, L, T)
        if self.dropout is not None:
            weights = self.dropout(weights)
        context = weights @ encoder_state                        # (B, L, _)
        return context, weights                                  # (B, L, _), (B, L, T)


class ConcatScore(nn.Module):
    """A two layer network as an attention scoring function. Expects bidirectional encoder."""

    def __init__(self, d: int):
        super().__init__()
        self.w = nn.Linear(3*d, d)
        self.v = nn.Linear(d, 1, bias=False)
        self.initialize_parameters()

    def forward(self, decoder_state: Tensor, encoder_state: Tensor) -> Tensor:
        """Return attention scores."""
        (B, L, D), (B, T, _) = decoder_state.shape, encoder_state.shape   # (B, L, D), (B, T, 2*D)
        decoder_state = decoder_state.repeat_interleave(T, dim=1)         # (B, L*T, D)
        encoder_state = encoder_state.repeat(1, L, 1)                     # (B, L*T, 2*D)
        concatenated = torch.cat((decoder_state, encoder_state), dim=-1)  # (B, L*T, 3*D)
        scores = self.v(torch.tanh(self.w(concatenated)))                 # (B, L*T, 1)
        return scores.view(B, L, T)                                       # (B, L, T)

    @torch.no_grad()
    def initialize_parameters(self):
        nn.init.xavier_uniform_(self.w.weight)
        nn.init.xavier_uniform_(self.v.weight, gain=nn.init.calculate_gain("tanh"))
        nn.init.zeros_(self.w.bias)
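As a sanity check on the expected shapes, the attention module can be exercised on random tensors; the sizes below are illustrative and not taken from the repo:

import torch
from attention import Attention, ConcatScore

B, L, T, D = 2, 1, 7, 512                          # batch, decoder steps, source length, hidden dim (illustrative)
attention = Attention(ConcatScore(D))
decoder_state = torch.randn(B, L, D)               # current decoder hidden state
encoder_state = torch.randn(B, T, 2 * D)           # bidirectional encoder output
source_mask = torch.zeros(B, T, dtype=torch.bool)  # True would mark padding positions
context, weights = attention(decoder_state, encoder_state, source_mask)
print(context.shape, weights.shape)                # torch.Size([2, 1, 1024]) torch.Size([2, 1, 7])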
models.py ADDED
@@ -0,0 +1,162 @@
import random
from collections import OrderedDict
from typing import *

import torch
import torch.nn as nn

from attention import Attention
from attention import ConcatScore

Tensor = torch.Tensor


class Encoder(nn.Module):
    """Single layer recurrent bidirectional encoder."""

    def __init__(self, vocab_size: int, hidden_dim: int, pad_idx: int):
        super().__init__()
        self.embedding = nn.Sequential(
            OrderedDict(
                embedding=nn.Embedding(vocab_size, hidden_dim, padding_idx=pad_idx),
                dropout=nn.Dropout(p=0.33),
            )
        )
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2*hidden_dim, hidden_dim)
        self.initialize_parameters()

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        """Encode a sequence of tokens as a sequence of hidden states."""
        B, T = input.shape
        embedded = self.embedding(input)                    # (B, T, D)
        output, hidden = self.gru(embedded)                 # (B, T, 2*D), (2, B, D)
        hidden = torch.cat((hidden[0], hidden[1]), dim=-1)  # (B, 2*D)
        hidden = torch.tanh(self.fc(hidden))                # (B, D)
        return output, hidden.unsqueeze(0)                  # (B, T, 2*D), (1, B, D)

    @torch.no_grad()
    def initialize_parameters(self):
        """Initialize linear weights uniformly, recurrent weights orthogonally, and bias to zero."""
        for name, parameters in self.named_parameters():
            if "embedding" in name:
                nn.init.xavier_uniform_(parameters)
            elif "weight_ih" in name:
                w_ir, w_iz, w_in = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.xavier_uniform_(w_ir)
                nn.init.xavier_uniform_(w_iz)
                nn.init.xavier_uniform_(w_in)
            elif "weight_hh" in name:
                w_hr, w_hz, w_hn = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.orthogonal_(w_hr)
                nn.init.orthogonal_(w_hz)
                nn.init.orthogonal_(w_hn)
            elif "weight" in name:
                nn.init.xavier_uniform_(parameters)
            elif "bias" in name:
                nn.init.zeros_(parameters)


class Decoder(nn.Module):
    """Single layer recurrent decoder."""

    def __init__(self, vocab_size: int, hidden_dim: int, pad_idx: int, temperature: float = 1.0):
        super().__init__()
        self.embedding = nn.Sequential(
            OrderedDict(
                embedding=nn.Embedding(vocab_size, hidden_dim, padding_idx=pad_idx),
                dropout=nn.Dropout(p=0.33),
            )
        )
        self.attention = Attention(ConcatScore(hidden_dim), nn.Dropout(p=0.1))
        self.gru = nn.GRU(3*hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Sequential(
            OrderedDict(
                fc1=nn.Linear(4*hidden_dim, hidden_dim),
                layer_norm=nn.LayerNorm(hidden_dim),
                gelu=nn.GELU(),
                fc2=nn.Linear(hidden_dim, vocab_size, bias=False),
            )
        )
        self.fc.fc2.weight = self.embedding.embedding.weight
        self.temperature = temperature
        self.initialize_parameters()

    def forward(self, input: Tensor, hidden: Tensor, encoder_output: Tensor, source_mask: Tensor = None) -> Tuple[Tensor, Tensor, Tensor]:
        """Predict the next token given an input token. Returns unnormalized predictions over the vocabulary."""
        B, = input.shape                                                                           # L=1
        embedded = self.embedding(input.view(B, 1))                                                # (B, 1, D)
        context, weights = self.attention(hidden.view(B, 1, -1), encoder_output, source_mask)      # (B, 1, 2*D), (B, 1, T)
        output, hidden = self.gru(torch.cat((embedded, context), dim=-1), hidden)                  # (B, 1, D), (1, B, D)
        predictions = self.fc(torch.cat((embedded, context, output), dim=-1)) / self.temperature   # (B, 1, V)
        return predictions.view(B, -1), hidden, weights.view(B, -1)                                # (B, V), (1, B, D), (B, T)

    @torch.no_grad()
    def initialize_parameters(self):
        """Initialize linear weights uniformly, recurrent weights orthogonally, and bias to zero."""
        for name, parameters in self.named_parameters():
            if "norm" in name:
                continue
            elif "embedding" in name:
                nn.init.xavier_uniform_(parameters)
            elif "weight_ih" in name:
                w_ir, w_iz, w_in = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.xavier_uniform_(w_ir)
                nn.init.xavier_uniform_(w_iz)
                nn.init.xavier_uniform_(w_in)
            elif "weight_hh" in name:
                w_hr, w_hz, w_hn = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.orthogonal_(w_hr)
                nn.init.orthogonal_(w_hz)
                nn.init.orthogonal_(w_hn)
            elif "weight" in name:
                nn.init.xavier_uniform_(parameters)
            elif "bias" in name:
                nn.init.zeros_(parameters)


class Seq2Seq(nn.Module):
    """Seq2seq with attention."""

    def __init__(self, vocab_size: int, hidden_dim: int, bos_idx: int, eos_idx: int, pad_idx: int, teacher_forcing: float = 0.5, temperature: float = 1.0):
        super().__init__()
        self.encoder = Encoder(vocab_size, hidden_dim, pad_idx)
        self.decoder = Decoder(vocab_size, hidden_dim, pad_idx, temperature=temperature)
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx
        self.pad_idx = pad_idx
        self.teacher_forcing = teacher_forcing

    def forward(self, source: Tensor, target: Tensor) -> Tensor:
        """Forward pass at training time. Returns unnormalized predictions over the vocabulary."""
        (B, T), (B, L) = source.shape, target.shape
        encoder_output, hidden = self.encoder(source)                         # (B, T, 2*D), (1, B, D)
        decoder_input = torch.full((B,), self.bos_idx, device=source.device)  # (B,)
        source_mask = source == self.pad_idx                                  # (B, T)

        output = []
        for i in range(L):
            predictions, hidden, _ = self.decoder(decoder_input, hidden, encoder_output, source_mask)  # (B, V), (1, B, D)
            output.append(predictions)
            if self.training and random.random() < self.teacher_forcing:
                decoder_input = target[:, i]                                  # (B,)
            else:
                decoder_input = predictions.argmax(dim=1)                     # (B,)
        return torch.stack(output, dim=1)                                     # (B, L, V)

    @torch.inference_mode()
    def decode(self, source: Tensor, max_decode_length: int) -> Tuple[Tensor, Tensor]:
        """Decode a single sequence at inference time. Returns output sequence and attention weights."""
        B, (T,) = 1, source.shape
        encoder_output, hidden = self.encoder(source.view(B, T))              # (B, T, 2*D), (1, B, D)
        decoder_input = torch.full((B,), self.bos_idx, device=source.device)  # (B,)

        output, attention = [], []
        for i in range(max_decode_length):
            predictions, hidden, weights = self.decoder(decoder_input, hidden, encoder_output)  # (B, V), (1, B, D), (B, T)
            output.append(predictions.argmax(dim=-1))                         # (B,)
            attention.append(weights)                                         # (B, T)
            if output[i] == self.eos_idx:
                break
            else:
                decoder_input = output[i]                                     # (B,)
        return torch.cat(output, dim=0), torch.cat(attention, dim=0)          # (L,), (L, T)
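A minimal end-to-end check of Seq2Seq.decode with a randomly initialized (untrained) model; the token indices and sequence length are made up for illustration:

import torch
from models import Seq2Seq

model = Seq2Seq(vocab_size=8000, hidden_dim=512, bos_idx=1, eos_idx=2, pad_idx=3)
model.eval()
source = torch.randint(4, 8000, (12,))          # one tokenized source sentence, shape (T,)
output, weights = model.decode(source, max_decode_length=20)
print(output.shape, weights.shape)              # (L,) and (L, T), with L <= 20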
requirements.txt ADDED
@@ -0,0 +1,8 @@
gradio
huggingface_hub
matplotlib
numpy
seaborn
sentencepiece
torch
torchtext