Use the flagging threshold to filter out uninteresting tokens
completions.py +1 -1
expand_llm.py +3 -3
completions.py CHANGED
@@ -92,7 +92,7 @@ def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, de

 contexts = [word.context for _, word in low_prob_words]

-expander = LLMBatchExpander(model, tokenizer)
+expander = LLMBatchExpander(model, tokenizer, threshold=log_prob_threshold)

 #%%
 series = []
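With this change the expander receives the same cutoff that check_text uses to flag low-probability words, rather than a constant buried in expand_llm.py. For a sense of scale, a minimal sketch assuming the old hardcoded default of -10.0 as the value of log_prob_threshold:

import math

# Assumed value: the constant that was previously hardcoded in expand_llm.py.
log_prob_threshold = -10.0

# On the raw-probability scale, a log-prob cutoff of -10.0 means candidate
# tokens less likely than roughly 4.54e-05 are dropped.
print(math.exp(log_prob_threshold))  # ~4.54e-05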
expand_llm.py CHANGED
@@ -6,7 +6,7 @@ import time

 type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast

-def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: Tokenizer) -> list[list[tuple[int, float]]]:
+def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, threshold: float) -> list[list[tuple[int, float]]]:
     input_ids = inputs["input_ids"]
     attention_mask = inputs["attention_mask"]
     print("Running inference")
@@ -21,7 +21,6 @@ def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: T
     start_time = time.time()
     result = []
     print(f"Resulting tensor: {log_probs.shape}")
-    threshold = -10.0
     for probs in log_probs:
         # Filter out low probability tokens for efficiency
         above_threshold = torch.where(probs > threshold)
@@ -39,10 +38,11 @@ def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torc
 class LLMBatchExpander(BatchExpander):
     model: PreTrainedModel
     tokenizer: Tokenizer
+    threshold: float

     def expand(self, batch: Batch) -> BatchCandidates:
         inputs = prepare_inputs([s.get_all_tokens() for s in batch.items], self.tokenizer, self.model.device)
-        next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
+        next_tokens = find_next_tokens(self.model, inputs, self.threshold)
         start_time = time.time()
         results = []
         print(f"Batch size: {len(batch.items)}, next tokens size: {len(next_tokens)}")
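For reference, the filtering step that the threshold now parameterizes reduces to a log-softmax followed by a thresholded torch.where. A minimal self-contained sketch, assuming logits of shape (batch, vocab); filter_next_tokens is a hypothetical stand-in for the part of find_next_tokens that runs after the model forward pass:

import torch

def filter_next_tokens(logits: torch.Tensor, threshold: float) -> list[list[tuple[int, float]]]:
    # Turn raw logits of shape (batch, vocab) into log-probabilities.
    log_probs = torch.log_softmax(logits, dim=-1)
    result = []
    for probs in log_probs:
        # Keep only candidates whose log-probability clears the cutoff,
        # mirroring the torch.where(probs > threshold) line in the diff.
        (indices,) = torch.where(probs > threshold)
        result.append([(int(i), float(probs[i])) for i in indices])
    return result

# Tiny smoke test over a fake 32-token vocabulary.
batch_logits = torch.randn(2, 32) * 5
for candidates in filter_next_tokens(batch_logits, threshold=-10.0):
    print(len(candidates), candidates[:3])

Passing the threshold in, instead of redefining it inside the loop, is what lets completions.py reuse its flagging cutoff when expanding candidates.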