| from dataclasses import dataclass | |
| from expand import Series, BatchExpander, Expansion, Batch, TokenCandidates, BatchCandidates, CompletedSequence, CompletedBatch, expand | |
# Token sequences the predefined expander can produce, forming a small
# prefix tree rooted at token 1:
#   1 -> 21 -> 31 -> {41, 42}
#   1 -> 21 -> 32 -> 41 -> 51
#   1 -> 22 -> {33, 34} -> 41
# Entries are lists; expand_series slices them and compares each slice with
# a series' current tokens.
possible_sequences = [
    [1, 21, 31, 41],
    [1, 21, 31, 42],
    [1, 21, 32, 41, 51],
    [1, 22, 33, 41],
    [1, 22, 34, 41],
]
def expand_series(series: Series) -> list[Expansion]:
    """Return next-token candidates for ``series`` taken from ``possible_sequences``.

    A candidate is the token that immediately follows the series' current
    tokens in any predefined sequence that starts with exactly those tokens.
    Sequences no longer than the current prefix contribute nothing.
    Candidates are de-duplicated in first-seen order, and every one carries
    cost -1.0.
    """
    # Normalise to a list so that len() works and the slice comparison below
    # matches even if get_all_tokens() returns a tuple or another iterable.
    # A list input yields an equal list, so list callers see no change.
    prefix = list(series.get_all_tokens())
    depth = len(prefix)
    next_tokens = [
        seq[depth]
        for seq in possible_sequences
        if len(seq) > depth and seq[:depth] == prefix
    ]
    # dict.fromkeys drops duplicates while preserving insertion order.
    return [Expansion(token=tok, cost=-1.0) for tok in dict.fromkeys(next_tokens)]
class PredefinedSequenceExpander(BatchExpander):
    """BatchExpander whose candidates come from ``expand_series``.

    The candidates are drawn from the ``possible_sequences`` table. The series'
    budget is not consulted here.
    """

    def expand(self, batch: Batch) -> BatchCandidates:
        """Return one TokenCandidates per series in ``batch``, in batch order."""
        return BatchCandidates(
            items=[
                TokenCandidates(series=series, expansions=expand_series(series))
                for series in batch.items
            ]
        )


# One expander instance shared by the tests below.
expander = PredefinedSequenceExpander()
def test_expander_zero_budget():
    """Prefix [1] yields candidates 21 and 22 even with a zero budget."""
    series = Series(id=0, tokens=[1], budget=0.0)
    got = expander.expand(Batch(items=[series]))
    candidates = [Expansion(token=tok, cost=-1.0) for tok in (21, 22)]
    want = BatchCandidates(
        items=[TokenCandidates(series=series, expansions=candidates)]
    )
    assert want == got
def test_expander_budget_one():
    """Prefix [1] yields candidates 21 and 22 with a budget of 1.0."""
    series = Series(id=0, tokens=[1], budget=1.0)
    got = expander.expand(Batch(items=[series]))
    candidates = [Expansion(token=tok, cost=-1.0) for tok in (21, 22)]
    want = BatchCandidates(
        items=[TokenCandidates(series=series, expansions=candidates)]
    )
    assert want == got
def test_expander_budget_two():
    """Prefix [1] yields candidates 21 and 22 with a budget of 2.0."""
    series = Series(id=0, tokens=[1], budget=2.0)
    got = expander.expand(Batch(items=[series]))
    candidates = [Expansion(token=tok, cost=-1.0) for tok in (21, 22)]
    want = BatchCandidates(
        items=[TokenCandidates(series=series, expansions=candidates)]
    )
    assert want == got
def test_expander_budget_one_no_expansion():
    """A prefix that matches no predefined sequence ([1, 20]) yields no candidates."""
    series = Series(id=0, tokens=[1, 20], budget=1.0)
    got = expander.expand(Batch(items=[series]))
    want = BatchCandidates(items=[TokenCandidates(series=series, expansions=[])])
    assert want == got
def test_expander_budget_one_two_tokens():
    """Prefix [1, 22] yields candidates 33 and 34."""
    series = Series(id=0, tokens=[1, 22], budget=1.0)
    got = expander.expand(Batch(items=[series]))
    candidates = [Expansion(token=tok, cost=-1.0) for tok in (33, 34)]
    want = BatchCandidates(
        items=[TokenCandidates(series=series, expansions=candidates)]
    )
    assert want == got
def test_expander_budget_one_two_tokens_two_series():
    """Each series in a batch gets its own candidates, in batch order.

    [1, 21, 31] continues with 41 or 42; [1, 22] continues with 33 or 34.
    """
    def candidates(*tokens):
        # Build the expected Expansion list for the given next tokens.
        return [Expansion(token=tok, cost=-1.0) for tok in tokens]

    first = Series(id=0, tokens=[1, 21, 31], budget=1.0)
    second = Series(id=1, tokens=[1, 22], budget=1.0)
    got = expander.expand(Batch(items=[first, second]))
    want = BatchCandidates(
        items=[
            TokenCandidates(series=first, expansions=candidates(41, 42)),
            TokenCandidates(series=second, expansions=candidates(33, 34)),
        ]
    )
    assert want == got
def test_expand_01():
    """With budget 1.0, each series is expected to grow by one token on every branch."""
    def paths(*token_paths):
        # Turn tuples of tokens into the expected nested Expansion lists.
        return [
            [Expansion(token=tok, cost=-1.0) for tok in path]
            for path in token_paths
        ]

    batch = Batch(items=[
        Series(id=0, tokens=[1, 21], budget=1.0),
        Series(id=1, tokens=[1, 22], budget=1.0),
    ])
    result = expand(batch, expander)
    # Expected Series are built fresh rather than reusing the batch's objects.
    assert result == CompletedBatch(items=[
        CompletedSequence(
            series=Series(id=0, tokens=[1, 21], budget=1.0),
            expansions=paths((31,), (32,)),
        ),
        CompletedSequence(
            series=Series(id=1, tokens=[1, 22], budget=1.0),
            expansions=paths((33,), (34,)),
        ),
    ])
def test_expand_02():
    """Series with different budgets are expected to be expanded to different depths.

    [1, 21] with budget 2.0 grows by two tokens per branch; [1, 22] with
    budget 1.0 grows by one.
    """
    def paths(*token_paths):
        # Turn tuples of tokens into the expected nested Expansion lists.
        return [
            [Expansion(token=tok, cost=-1.0) for tok in path]
            for path in token_paths
        ]

    batch = Batch(items=[
        Series(id=0, tokens=[1, 21], budget=2.0),
        Series(id=1, tokens=[1, 22], budget=1.0),
    ])
    result = expand(batch, expander)
    # Expected Series are built fresh rather than reusing the batch's objects.
    assert result == CompletedBatch(items=[
        CompletedSequence(
            series=Series(id=0, tokens=[1, 21], budget=2.0),
            expansions=paths((31, 41), (31, 42), (32, 41)),
        ),
        CompletedSequence(
            series=Series(id=1, tokens=[1, 22], budget=1.0),
            expansions=paths((33,), (34,)),
        ),
    ])
def test_expand_03():
    """A budget of 3.0 reaches token 51 on the 32 branch; a zero budget yields nothing."""
    def paths(*token_paths):
        # Turn tuples of tokens into the expected nested Expansion lists.
        return [
            [Expansion(token=tok, cost=-1.0) for tok in path]
            for path in token_paths
        ]

    batch = Batch(items=[
        Series(id=0, tokens=[1, 21], budget=3.0),
        Series(id=1, tokens=[1, 22], budget=0.0),
    ])
    result = expand(batch, expander)
    # Expected Series are built fresh rather than reusing the batch's objects.
    assert result == CompletedBatch(items=[
        CompletedSequence(
            series=Series(id=0, tokens=[1, 21], budget=3.0),
            expansions=paths((31, 41), (31, 42), (32, 41, 51)),
        ),
        CompletedSequence(
            series=Series(id=1, tokens=[1, 22], budget=0.0),
            expansions=[],
        ),
    ])