Better names for Expander etc
Browse files- README.md +1 -1
- completions.py +1 -2
- expand.py +14 -14
- expand_llm.py +4 -4
- expand_test.py +28 -28
- run.py +2 -2
README.md
CHANGED
|
@@ -173,7 +173,7 @@ In my case, I stop when the budget is exhausted, and I also stop if the expansio
|
|
| 173 |
|
| 174 |
Given the batch and the stopping criterion, we can call the expander:
|
| 175 |
```python
|
| 176 |
-
expander =
|
| 177 |
expanded = expand(batch, expander, stopping_criterion)
|
| 178 |
```
|
| 179 |
|
|
|
|
| 173 |
|
| 174 |
Given the batch and the stopping criterion, we can call the expander:
|
| 175 |
```python
|
| 176 |
+
expander = LLMBatchExpander(model, tokenizer)
|
| 177 |
expanded = expand(batch, expander, stopping_criterion)
|
| 178 |
```
|
| 179 |
|
completions.py
CHANGED
|
@@ -92,8 +92,7 @@ def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, de
|
|
| 92 |
|
| 93 |
contexts = [word.context for _, word in low_prob_words]
|
| 94 |
|
| 95 |
-
|
| 96 |
-
expander = ExpanderOneBatchLLM(model, tokenizer)
|
| 97 |
|
| 98 |
#%%
|
| 99 |
series = []
|
|
|
|
| 92 |
|
| 93 |
contexts = [word.context for _, word in low_prob_words]
|
| 94 |
|
| 95 |
+
expander = LLMBatchExpander(model, tokenizer)
|
|
|
|
| 96 |
|
| 97 |
#%%
|
| 98 |
series = []
|
expand.py
CHANGED
|
@@ -25,28 +25,28 @@ class Batch:
|
|
| 25 |
items: list[Series]
|
| 26 |
|
| 27 |
@dataclass
|
| 28 |
-
class
|
| 29 |
series: Series
|
| 30 |
expansions: list[Expansion]
|
| 31 |
|
| 32 |
@dataclass
|
| 33 |
-
class
|
| 34 |
-
items: list[
|
| 35 |
|
| 36 |
# A fundamental operation that we can implement both using an LLM and using a list of hardcoded sequences, for testing
|
| 37 |
-
class
|
| 38 |
-
def expand(self, batch: Batch) ->
|
| 39 |
|
| 40 |
@dataclass
|
| 41 |
-
class
|
| 42 |
series: Series
|
| 43 |
expansions: list[list[Expansion]]
|
| 44 |
|
| 45 |
@dataclass
|
| 46 |
-
class
|
| 47 |
-
items: list[
|
| 48 |
|
| 49 |
-
def compute_new_series(result:
|
| 50 |
new_series_batch = []
|
| 51 |
for expansion in result.expansions:
|
| 52 |
if not stopping_criterion(result.series, expansion):
|
|
@@ -60,7 +60,7 @@ def compute_new_series(result: ExpansionOneResult, stopping_criterion: Callable[
|
|
| 60 |
completed_series = [result.series] if len(new_series_batch) == 0 else []
|
| 61 |
return new_series_batch, completed_series
|
| 62 |
|
| 63 |
-
def compute_expansions(original_series: list[Series], expanded_series: list[Series]) ->
|
| 64 |
# check that ids in original_series are unique
|
| 65 |
assert len(original_series) == len({s.id for s in original_series})
|
| 66 |
# group original series by id
|
|
@@ -73,15 +73,15 @@ def compute_expansions(original_series: list[Series], expanded_series: list[Seri
|
|
| 73 |
results = []
|
| 74 |
for id, s in original_series_by_id.items():
|
| 75 |
expansions = expanded_series_by_id[id]
|
| 76 |
-
expansion_result =
|
| 77 |
results.append(expansion_result)
|
| 78 |
-
return
|
| 79 |
|
| 80 |
def default_completion_criterion(series: Series, expansion: Expansion) -> bool:
|
| 81 |
return series.get_remaining_budget() + expansion.cost < 0
|
| 82 |
|
| 83 |
-
# A compound operation that we can implement generically, relying on
|
| 84 |
-
def expand(batch: Batch, expander:
|
| 85 |
completed_series: list[Series] = []
|
| 86 |
current_batch = batch
|
| 87 |
while len(current_batch.items) > 0:
|
|
|
|
| 25 |
items: list[Series]
|
| 26 |
|
| 27 |
@dataclass
|
| 28 |
+
class TokenCandidates:
|
| 29 |
series: Series
|
| 30 |
expansions: list[Expansion]
|
| 31 |
|
| 32 |
@dataclass
|
| 33 |
+
class BatchCandidates:
|
| 34 |
+
items: list[TokenCandidates]
|
| 35 |
|
| 36 |
# A fundamental operation that we can implement both using an LLM and using a list of hardcoded sequences, for testing
|
| 37 |
+
class BatchExpander(Protocol):
|
| 38 |
+
def expand(self, batch: Batch) -> BatchCandidates: ...
|
| 39 |
|
| 40 |
@dataclass
|
| 41 |
+
class CompletedSequence:
|
| 42 |
series: Series
|
| 43 |
expansions: list[list[Expansion]]
|
| 44 |
|
| 45 |
@dataclass
|
| 46 |
+
class CompletedBatch:
|
| 47 |
+
items: list[CompletedSequence]
|
| 48 |
|
| 49 |
+
def compute_new_series(result: TokenCandidates, stopping_criterion: Callable[[Series, Expansion], bool]) -> tuple[list[Series], list[Series]]:
|
| 50 |
new_series_batch = []
|
| 51 |
for expansion in result.expansions:
|
| 52 |
if not stopping_criterion(result.series, expansion):
|
|
|
|
| 60 |
completed_series = [result.series] if len(new_series_batch) == 0 else []
|
| 61 |
return new_series_batch, completed_series
|
| 62 |
|
| 63 |
+
def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> CompletedBatch:
|
| 64 |
# check that ids in original_series are unique
|
| 65 |
assert len(original_series) == len({s.id for s in original_series})
|
| 66 |
# group original series by id
|
|
|
|
| 73 |
results = []
|
| 74 |
for id, s in original_series_by_id.items():
|
| 75 |
expansions = expanded_series_by_id[id]
|
| 76 |
+
expansion_result = CompletedSequence(series=s, expansions=expansions)
|
| 77 |
results.append(expansion_result)
|
| 78 |
+
return CompletedBatch(items=results)
|
| 79 |
|
| 80 |
def default_completion_criterion(series: Series, expansion: Expansion) -> bool:
|
| 81 |
return series.get_remaining_budget() + expansion.cost < 0
|
| 82 |
|
| 83 |
+
# A compound operation that we can implement generically, relying on a BatchExpander
|
| 84 |
+
def expand(batch: Batch, expander: BatchExpander, completion_criterion: Callable[[Series, Expansion], bool] = default_completion_criterion) -> CompletedBatch:
|
| 85 |
completed_series: list[Series] = []
|
| 86 |
current_batch = batch
|
| 87 |
while len(current_batch.items) > 0:
|
expand_llm.py
CHANGED
|
@@ -22,18 +22,18 @@ def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torc
|
|
| 22 |
return tokenizer(texts, return_tensors="pt", padding=True).to(device)
|
| 23 |
|
| 24 |
@dataclass
|
| 25 |
-
class
|
| 26 |
model: PreTrainedModel
|
| 27 |
tokenizer: Tokenizer
|
| 28 |
|
| 29 |
-
def expand(self, batch: Batch) ->
|
| 30 |
inputs = prepare_inputs([s.get_all_tokens() for s in batch.items], self.tokenizer, self.model.device)
|
| 31 |
next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
|
| 32 |
results = []
|
| 33 |
for s, next_tokens in zip(batch.items, next_tokens):
|
| 34 |
expansions = [Expansion(token=token, cost=cost) for token, cost in next_tokens]
|
| 35 |
-
results.append(
|
| 36 |
-
return
|
| 37 |
|
| 38 |
def create_stopping_criterion_llm(tokenizer: Tokenizer) -> Callable[[Series, Expansion], bool]:
|
| 39 |
def stopping_criterion(series: Series, expansion: Expansion) -> bool:
|
|
|
|
| 22 |
return tokenizer(texts, return_tensors="pt", padding=True).to(device)
|
| 23 |
|
| 24 |
@dataclass
|
| 25 |
+
class LLMBatchExpander(BatchExpander):
|
| 26 |
model: PreTrainedModel
|
| 27 |
tokenizer: Tokenizer
|
| 28 |
|
| 29 |
+
def expand(self, batch: Batch) -> BatchCandidates:
|
| 30 |
inputs = prepare_inputs([s.get_all_tokens() for s in batch.items], self.tokenizer, self.model.device)
|
| 31 |
next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
|
| 32 |
results = []
|
| 33 |
for s, next_tokens in zip(batch.items, next_tokens):
|
| 34 |
expansions = [Expansion(token=token, cost=cost) for token, cost in next_tokens]
|
| 35 |
+
results.append(TokenCandidates(series=s, expansions=expansions))
|
| 36 |
+
return BatchCandidates(items=results)
|
| 37 |
|
| 38 |
def create_stopping_criterion_llm(tokenizer: Tokenizer) -> Callable[[Series, Expansion], bool]:
|
| 39 |
def stopping_criterion(series: Series, expansion: Expansion) -> bool:
|
expand_test.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
-
from expand import Series,
|
| 3 |
|
| 4 |
possible_sequences = [
|
| 5 |
[1, 21, 31, 41],
|
|
@@ -16,21 +16,21 @@ def expand_series(series: Series) -> list[Expansion]:
|
|
| 16 |
candidates = [Expansion(token=l, cost=-1.0) for l in dict.fromkeys(items)]
|
| 17 |
return candidates
|
| 18 |
|
| 19 |
-
class
|
| 20 |
-
def expand(self, batch: Batch) ->
|
| 21 |
result = []
|
| 22 |
for s in batch.items:
|
| 23 |
expansions = expand_series(s)
|
| 24 |
-
result.append(
|
| 25 |
-
return
|
| 26 |
|
| 27 |
-
expander =
|
| 28 |
|
| 29 |
def test_expander_zero_budget():
|
| 30 |
s = Series(id=0, tokens=[1], budget=0.0)
|
| 31 |
expanded = expander.expand(Batch(items=[s]))
|
| 32 |
-
expected =
|
| 33 |
-
items=[
|
| 34 |
Expansion(token=21, cost=-1.0),
|
| 35 |
Expansion(token=22, cost=-1.0),
|
| 36 |
])]
|
|
@@ -40,8 +40,8 @@ def test_expander_zero_budget():
|
|
| 40 |
def test_expander_budget_one():
|
| 41 |
s = Series(id=0, tokens=[1], budget=1.0)
|
| 42 |
expanded = expander.expand(Batch(items=[s]))
|
| 43 |
-
expected =
|
| 44 |
-
items=[
|
| 45 |
Expansion(token=21, cost=-1.0),
|
| 46 |
Expansion(token=22, cost=-1.0),
|
| 47 |
])]
|
|
@@ -51,8 +51,8 @@ def test_expander_budget_one():
|
|
| 51 |
def test_expander_budget_two():
|
| 52 |
s = Series(id=0, tokens=[1], budget=2.0)
|
| 53 |
expanded = expander.expand(Batch(items=[s]))
|
| 54 |
-
expected =
|
| 55 |
-
items=[
|
| 56 |
Expansion(token=21, cost=-1.0),
|
| 57 |
Expansion(token=22, cost=-1.0),
|
| 58 |
])]
|
|
@@ -62,16 +62,16 @@ def test_expander_budget_two():
|
|
| 62 |
def test_expander_budget_one_no_expansion():
|
| 63 |
s = Series(id=0, tokens=[1, 20], budget=1.0)
|
| 64 |
expanded = expander.expand(Batch(items=[s]))
|
| 65 |
-
expected =
|
| 66 |
-
items=[
|
| 67 |
)
|
| 68 |
assert expected == expanded
|
| 69 |
|
| 70 |
def test_expander_budget_one_two_tokens():
|
| 71 |
s = Series(id=0, tokens=[1, 22], budget=1.0)
|
| 72 |
expanded = expander.expand(Batch(items=[s]))
|
| 73 |
-
expected =
|
| 74 |
-
items=[
|
| 75 |
Expansion(token=33, cost=-1.0),
|
| 76 |
Expansion(token=34, cost=-1.0),
|
| 77 |
])]
|
|
@@ -82,13 +82,13 @@ def test_expander_budget_one_two_tokens_two_series():
|
|
| 82 |
s1 = Series(id=0, tokens=[1, 21, 31], budget=1.0)
|
| 83 |
s2 = Series(id=1, tokens=[1, 22], budget=1.0)
|
| 84 |
expanded = expander.expand(Batch(items=[s1, s2]))
|
| 85 |
-
expected =
|
| 86 |
items=[
|
| 87 |
-
|
| 88 |
Expansion(token=41, cost=-1.0),
|
| 89 |
Expansion(token=42, cost=-1.0),
|
| 90 |
]),
|
| 91 |
-
|
| 92 |
Expansion(token=33, cost=-1.0),
|
| 93 |
Expansion(token=34, cost=-1.0),
|
| 94 |
])
|
|
@@ -102,15 +102,15 @@ def test_expand_01():
|
|
| 102 |
Series(id=1, tokens=[1, 22], budget=1.0),
|
| 103 |
])
|
| 104 |
expanded = expand(batch, expander)
|
| 105 |
-
assert expanded ==
|
| 106 |
-
|
| 107 |
series=Series(id=0, tokens=[1, 21], budget=1.0),
|
| 108 |
expansions=[
|
| 109 |
[Expansion(token=31, cost=-1.0)],
|
| 110 |
[Expansion(token=32, cost=-1.0)],
|
| 111 |
]
|
| 112 |
),
|
| 113 |
-
|
| 114 |
series=Series(id=1, tokens=[1, 22], budget=1.0),
|
| 115 |
expansions=[
|
| 116 |
[Expansion(token=33, cost=-1.0)],
|
|
@@ -125,8 +125,8 @@ def test_expand_02():
|
|
| 125 |
Series(id=1, tokens=[1, 22], budget=1.0),
|
| 126 |
])
|
| 127 |
expanded = expand(batch, expander)
|
| 128 |
-
assert expanded ==
|
| 129 |
-
|
| 130 |
series=Series(id=0, tokens=[1, 21], budget=2.0),
|
| 131 |
expansions=[
|
| 132 |
[Expansion(token=31, cost=-1.0), Expansion(token=41, cost=-1.0)],
|
|
@@ -134,7 +134,7 @@ def test_expand_02():
|
|
| 134 |
[Expansion(token=32, cost=-1.0), Expansion(token=41, cost=-1.0)],
|
| 135 |
]
|
| 136 |
),
|
| 137 |
-
|
| 138 |
series=Series(id=1, tokens=[1, 22], budget=1.0),
|
| 139 |
expansions=[
|
| 140 |
[Expansion(token=33, cost=-1.0)],
|
|
@@ -149,8 +149,8 @@ def test_expand_03():
|
|
| 149 |
Series(id=1, tokens=[1, 22], budget=0.0),
|
| 150 |
])
|
| 151 |
expanded = expand(batch, expander)
|
| 152 |
-
assert expanded ==
|
| 153 |
-
|
| 154 |
series=Series(id=0, tokens=[1, 21], budget=3.0),
|
| 155 |
expansions=[
|
| 156 |
[Expansion(token=31, cost=-1.0), Expansion(token=41, cost=-1.0)],
|
|
@@ -158,7 +158,7 @@ def test_expand_03():
|
|
| 158 |
[Expansion(token=32, cost=-1.0), Expansion(token=41, cost=-1.0), Expansion(token=51, cost=-1.0)],
|
| 159 |
]
|
| 160 |
),
|
| 161 |
-
|
| 162 |
series=Series(id=1, tokens=[1, 22], budget=0.0),
|
| 163 |
expansions=[],
|
| 164 |
),
|
|
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
+
from expand import Series, BatchExpander, Expansion, Batch, TokenCandidates, BatchCandidates, CompletedSequence, CompletedBatch, expand
|
| 3 |
|
| 4 |
possible_sequences = [
|
| 5 |
[1, 21, 31, 41],
|
|
|
|
| 16 |
candidates = [Expansion(token=l, cost=-1.0) for l in dict.fromkeys(items)]
|
| 17 |
return candidates
|
| 18 |
|
| 19 |
+
class PredefinedSequenceExpander(BatchExpander):
|
| 20 |
+
def expand(self, batch: Batch) -> BatchCandidates:
|
| 21 |
result = []
|
| 22 |
for s in batch.items:
|
| 23 |
expansions = expand_series(s)
|
| 24 |
+
result.append(TokenCandidates(series=s, expansions=expansions))
|
| 25 |
+
return BatchCandidates(items=result)
|
| 26 |
|
| 27 |
+
expander = PredefinedSequenceExpander()
|
| 28 |
|
| 29 |
def test_expander_zero_budget():
|
| 30 |
s = Series(id=0, tokens=[1], budget=0.0)
|
| 31 |
expanded = expander.expand(Batch(items=[s]))
|
| 32 |
+
expected = BatchCandidates(
|
| 33 |
+
items=[TokenCandidates(series=s, expansions=[
|
| 34 |
Expansion(token=21, cost=-1.0),
|
| 35 |
Expansion(token=22, cost=-1.0),
|
| 36 |
])]
|
|
|
|
| 40 |
def test_expander_budget_one():
|
| 41 |
s = Series(id=0, tokens=[1], budget=1.0)
|
| 42 |
expanded = expander.expand(Batch(items=[s]))
|
| 43 |
+
expected = BatchCandidates(
|
| 44 |
+
items=[TokenCandidates(series=s, expansions=[
|
| 45 |
Expansion(token=21, cost=-1.0),
|
| 46 |
Expansion(token=22, cost=-1.0),
|
| 47 |
])]
|
|
|
|
| 51 |
def test_expander_budget_two():
|
| 52 |
s = Series(id=0, tokens=[1], budget=2.0)
|
| 53 |
expanded = expander.expand(Batch(items=[s]))
|
| 54 |
+
expected = BatchCandidates(
|
| 55 |
+
items=[TokenCandidates(series=s, expansions=[
|
| 56 |
Expansion(token=21, cost=-1.0),
|
| 57 |
Expansion(token=22, cost=-1.0),
|
| 58 |
])]
|
|
|
|
| 62 |
def test_expander_budget_one_no_expansion():
|
| 63 |
s = Series(id=0, tokens=[1, 20], budget=1.0)
|
| 64 |
expanded = expander.expand(Batch(items=[s]))
|
| 65 |
+
expected = BatchCandidates(
|
| 66 |
+
items=[TokenCandidates(series=s, expansions=[])]
|
| 67 |
)
|
| 68 |
assert expected == expanded
|
| 69 |
|
| 70 |
def test_expander_budget_one_two_tokens():
|
| 71 |
s = Series(id=0, tokens=[1, 22], budget=1.0)
|
| 72 |
expanded = expander.expand(Batch(items=[s]))
|
| 73 |
+
expected = BatchCandidates(
|
| 74 |
+
items=[TokenCandidates(series=s, expansions=[
|
| 75 |
Expansion(token=33, cost=-1.0),
|
| 76 |
Expansion(token=34, cost=-1.0),
|
| 77 |
])]
|
|
|
|
| 82 |
s1 = Series(id=0, tokens=[1, 21, 31], budget=1.0)
|
| 83 |
s2 = Series(id=1, tokens=[1, 22], budget=1.0)
|
| 84 |
expanded = expander.expand(Batch(items=[s1, s2]))
|
| 85 |
+
expected = BatchCandidates(
|
| 86 |
items=[
|
| 87 |
+
TokenCandidates(series=s1, expansions=[
|
| 88 |
Expansion(token=41, cost=-1.0),
|
| 89 |
Expansion(token=42, cost=-1.0),
|
| 90 |
]),
|
| 91 |
+
TokenCandidates(series=s2, expansions=[
|
| 92 |
Expansion(token=33, cost=-1.0),
|
| 93 |
Expansion(token=34, cost=-1.0),
|
| 94 |
])
|
|
|
|
| 102 |
Series(id=1, tokens=[1, 22], budget=1.0),
|
| 103 |
])
|
| 104 |
expanded = expand(batch, expander)
|
| 105 |
+
assert expanded == CompletedBatch(items=[
|
| 106 |
+
CompletedSequence(
|
| 107 |
series=Series(id=0, tokens=[1, 21], budget=1.0),
|
| 108 |
expansions=[
|
| 109 |
[Expansion(token=31, cost=-1.0)],
|
| 110 |
[Expansion(token=32, cost=-1.0)],
|
| 111 |
]
|
| 112 |
),
|
| 113 |
+
CompletedSequence(
|
| 114 |
series=Series(id=1, tokens=[1, 22], budget=1.0),
|
| 115 |
expansions=[
|
| 116 |
[Expansion(token=33, cost=-1.0)],
|
|
|
|
| 125 |
Series(id=1, tokens=[1, 22], budget=1.0),
|
| 126 |
])
|
| 127 |
expanded = expand(batch, expander)
|
| 128 |
+
assert expanded == CompletedBatch(items=[
|
| 129 |
+
CompletedSequence(
|
| 130 |
series=Series(id=0, tokens=[1, 21], budget=2.0),
|
| 131 |
expansions=[
|
| 132 |
[Expansion(token=31, cost=-1.0), Expansion(token=41, cost=-1.0)],
|
|
|
|
| 134 |
[Expansion(token=32, cost=-1.0), Expansion(token=41, cost=-1.0)],
|
| 135 |
]
|
| 136 |
),
|
| 137 |
+
CompletedSequence(
|
| 138 |
series=Series(id=1, tokens=[1, 22], budget=1.0),
|
| 139 |
expansions=[
|
| 140 |
[Expansion(token=33, cost=-1.0)],
|
|
|
|
| 149 |
Series(id=1, tokens=[1, 22], budget=0.0),
|
| 150 |
])
|
| 151 |
expanded = expand(batch, expander)
|
| 152 |
+
assert expanded == CompletedBatch(items=[
|
| 153 |
+
CompletedSequence(
|
| 154 |
series=Series(id=0, tokens=[1, 21], budget=3.0),
|
| 155 |
expansions=[
|
| 156 |
[Expansion(token=31, cost=-1.0), Expansion(token=41, cost=-1.0)],
|
|
|
|
| 158 |
[Expansion(token=32, cost=-1.0), Expansion(token=41, cost=-1.0), Expansion(token=51, cost=-1.0)],
|
| 159 |
]
|
| 160 |
),
|
| 161 |
+
CompletedSequence(
|
| 162 |
series=Series(id=1, tokens=[1, 22], budget=0.0),
|
| 163 |
expansions=[],
|
| 164 |
),
|
run.py
CHANGED
|
@@ -24,7 +24,7 @@ low_prob_words = [(i, word) for i, word in enumerate(words) if word.logprob < lo
|
|
| 24 |
contexts = [word.context for _, word in low_prob_words]
|
| 25 |
|
| 26 |
#%%
|
| 27 |
-
expander =
|
| 28 |
|
| 29 |
#%%
|
| 30 |
series = []
|
|
@@ -41,7 +41,7 @@ stopping_criterion = create_stopping_criterion_llm(tokenizer)
|
|
| 41 |
expanded = expand(batch, expander, stopping_criterion)
|
| 42 |
|
| 43 |
# %%
|
| 44 |
-
def print_expansions(expansions:
|
| 45 |
for result in expansions.items:
|
| 46 |
for expansion in result.expansions:
|
| 47 |
# convert tokens to string
|
|
|
|
| 24 |
contexts = [word.context for _, word in low_prob_words]
|
| 25 |
|
| 26 |
#%%
|
| 27 |
+
expander = LLMBatchExpander(model, tokenizer)
|
| 28 |
|
| 29 |
#%%
|
| 30 |
series = []
|
|
|
|
| 41 |
expanded = expand(batch, expander, stopping_criterion)
|
| 42 |
|
| 43 |
# %%
|
| 44 |
+
def print_expansions(expansions: CompletedBatch):
|
| 45 |
for result in expansions.items:
|
| 46 |
for expansion in result.expansions:
|
| 47 |
# convert tokens to string
|