Snapshot
- app.py +3 -3
- text_processing.py +13 -11
app.py CHANGED

@@ -21,7 +21,7 @@ def tokenize(input_text: str, tokenizer: Tokenizer, device: torch.device) -> tup
     attention_mask = cast(torch.Tensor, inputs["attention_mask"])
     return input_ids, attention_mask
 
-def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> list[tuple[str, float]]:
+def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> list[tuple[int, float]]:
     with torch.no_grad():
         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
     # B x T x V
@@ -31,8 +31,8 @@ def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, in
     # T - 1
     token_log_probs: torch.Tensor = log_probs[0, range(log_probs.shape[1]), input_ids[0][1:]]
     # T - 1
-    tokens:
-    return list(zip(tokens, token_log_probs.tolist()))
+    tokens: torch.Tensor = input_ids[0][1:]
+    return list(zip(tokens.tolist(), token_log_probs.tolist()))
 
 
 def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix_tokens: list[int], device: torch.device, num_samples: int = 5) -> list[str]:
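The hunks above show only part of calculate_log_probabilities, so here is a minimal, self-contained sketch of the shifted-indexing and pairing logic it ends with. The toy vocabulary size, the random logits standing in for the model output, and the [:, :-1, :] slice that yields the T - 1 positions are assumptions made for the example, not lines taken from app.py.

import torch

torch.manual_seed(0)

vocab_size = 10                                           # toy vocabulary (assumption)
input_ids = torch.tensor([[2, 5, 7, 1]])                  # B x T, here 1 x 4
logits = torch.randn(1, input_ids.shape[1], vocab_size)   # stand-in for outputs.logits, B x T x V

# Position t predicts token t + 1, so drop the last position: B x (T-1) x V.
log_probs = torch.log_softmax(logits, dim=-1)[:, :-1, :]

# For each position, pick the log-prob assigned to the token that actually followed it.  # T - 1
token_log_probs = log_probs[0, range(log_probs.shape[1]), input_ids[0][1:]]

# After this change the function pairs token ids (not decoded strings) with their
# log-probs, matching the new list[tuple[int, float]] return annotation.
tokens = input_ids[0][1:]
pairs = list(zip(tokens.tolist(), token_log_probs.tolist()))
print(pairs)  # [(5, ...), (7, ...), (1, ...)] with float log-probs

Returning ids instead of strings is what lets the updated split_into_words in text_processing.py decide for itself how to decode single tokens versus whole words.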
text_processing.py CHANGED

@@ -1,30 +1,32 @@
 from dataclasses import dataclass
+from tokenizers import Tokenizer
 
 @dataclass
 class Word:
-    tokens: list[str]
+    tokens: list[int]
     text: str
     logprob: float
     first_token_index: int
 
-def split_into_words(token_probs: list[tuple[str, float]]) -> list[Word]:
-    words = []
-    current_word = []
-    current_log_probs = []
-    current_word_first_token_index = 0
+def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer) -> list[Word]:
+    words: list[Word] = []
+    current_word: list[int] = []
+    current_log_probs: list[float] = []
+    current_word_first_token_index: int = 0
 
-    for i, (token, logprob) in enumerate(token_probs):
+    for i, (token_id, logprob) in enumerate(token_probs):
+        token: str = tokenizer.decode([token_id])
         if not token.startswith(chr(9601)) and token.isalpha():
-            current_word.append(token)
+            current_word.append(token_id)
             current_log_probs.append(logprob)
         else:
             if current_word:
-                words.append(Word(current_word,
-            current_word = [token]
+                words.append(Word(current_word, tokenizer.decode(current_word), sum(current_log_probs), current_word_first_token_index))
+            current_word = [token_id]
             current_log_probs = [logprob]
             current_word_first_token_index = i
 
     if current_word:
-        words.append(Word(current_word,
+        words.append(Word(current_word, tokenizer.decode(current_word), sum(current_log_probs), current_word_first_token_index))
 
     return words
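As a rough usage sketch of the new id-based interface, the snippet below feeds (token_id, logprob) pairs, shaped like the output of calculate_log_probabilities, through split_into_words. The tiny stand-in tokenizer (a fixed id-to-piece table with a decode() method, using chr(9601), i.e. "\u2581", as the word-start marker) is an assumption so the example runs without loading a real model; in the app a tokenizers.Tokenizer instance is passed instead.

from text_processing import split_into_words

class StubTokenizer:
    # Hypothetical SentencePiece-style pieces; chr(9601) == "\u2581" marks a word start.
    PIECES = {0: "\u2581The", 1: "\u2581quick", 2: "\u2581fo", 3: "x", 4: "."}

    def decode(self, ids: list[int]) -> str:
        return "".join(self.PIECES[i] for i in ids)

# (token_id, logprob) pairs, as calculate_log_probabilities now returns them.
token_probs = [(0, -0.5), (1, -1.2), (2, -0.8), (3, -0.3), (4, -2.0)]

for word in split_into_words(token_probs, StubTokenizer()):
    print(word.text, word.tokens, word.logprob, word.first_token_index)

# "fo" and "x" end up in one Word because "x" is alphabetic and has no "\u2581"
# prefix; "." is non-alphabetic, so it closes the current word and starts its own.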