[refactor]: replace pyplot usage with matplotlib Figure
Browse files
app.py
CHANGED
|
@@ -2,13 +2,13 @@ import random
|
|
| 2 |
from typing import *
|
| 3 |
|
| 4 |
import gradio as gr
|
| 5 |
-
import matplotlib.pyplot as plt
|
| 6 |
import numpy as np
|
| 7 |
import seaborn as sns
|
| 8 |
import sentencepiece as sp
|
| 9 |
import torch
|
| 10 |
|
| 11 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 12 |
from torchtext.datasets import Multi30k
|
| 13 |
|
| 14 |
from models import Seq2Seq
|
|
@@ -31,19 +31,18 @@ normalize = lambda sample: (sample[0].lower().strip(), sample[1].lower().strip()
|
|
| 31 |
test_source, _ = zip(*map(normalize, Multi30k(split="test", language_pair=("de", "en"))))
|
| 32 |
|
| 33 |
|
| 34 |
-
def attention_heatmap(input_tokens: List[str], output_tokens: List[str], weights: np.ndarray) ->
|
| 35 |
-
figure =
|
| 36 |
-
axes =
|
| 37 |
-
axes.
|
| 38 |
-
axes.
|
| 39 |
-
axes.tick_params(axis="
|
| 40 |
axes.xaxis.tick_top()
|
| 41 |
-
plt.close()
|
| 42 |
return figure
|
| 43 |
|
| 44 |
|
| 45 |
@torch.inference_mode()
|
| 46 |
-
def run(input: str) -> Tuple[str,
|
| 47 |
"""Run inference on a single sentence. Returns prediction and attention heatmap."""""
|
| 48 |
input = input.lower().strip().rstrip(".") + "."
|
| 49 |
input_tensor = torch.tensor(source_spm.encode(input), dtype=torch.int64)
|
|
|
|
| 2 |
from typing import *
|
| 3 |
|
| 4 |
import gradio as gr
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
import seaborn as sns
|
| 7 |
import sentencepiece as sp
|
| 8 |
import torch
|
| 9 |
|
| 10 |
from huggingface_hub import hf_hub_download
|
| 11 |
+
from matplotlib.figure import Figure
|
| 12 |
from torchtext.datasets import Multi30k
|
| 13 |
|
| 14 |
from models import Seq2Seq
|
|
|
|
| 31 |
test_source, _ = zip(*map(normalize, Multi30k(split="test", language_pair=("de", "en"))))
|
| 32 |
|
| 33 |
|
| 34 |
+
def attention_heatmap(input_tokens: List[str], output_tokens: List[str], weights: np.ndarray) -> Figure:
|
| 35 |
+
figure = Figure(dpi=800, tight_layout=True)
|
| 36 |
+
axes = figure.add_subplot()
|
| 37 |
+
axes = sns.heatmap(weights, ax=axes, xticklabels=input_tokens, yticklabels=output_tokens, cmap="gray", cbar=False)
|
| 38 |
+
axes.tick_params(axis="x", rotation=90, length=0)
|
| 39 |
+
axes.tick_params(axis="y", rotation=0, length=0)
|
| 40 |
axes.xaxis.tick_top()
|
|
|
|
| 41 |
return figure
|
| 42 |
|
| 43 |
|
| 44 |
@torch.inference_mode()
|
| 45 |
+
def run(input: str) -> Tuple[str, Figure]:
|
| 46 |
"""Run inference on a single sentence. Returns prediction and attention heatmap."""""
|
| 47 |
input = input.lower().strip().rstrip(".") + "."
|
| 48 |
input_tensor = torch.tensor(source_spm.encode(input), dtype=torch.int64)
|