Spaces:
Build error
Build error
| import pytest | |
| from utils import * | |
# Module-level server handle shared by every test in this file.
# create_server() rebinds it to a freshly built preset.
server = ServerPreset.tinyllama2()
# Filler prompt used to overflow a single slot's context; the tests in this
# file treat it as a 301-token prompt.
LONG_TEXT = "\n".join((
    "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.",
    "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.",
    "Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.",
    "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.",
))
@pytest.fixture(autouse=True)
def create_server():
    """Rebuild the module-level ``server`` preset before every test.

    Registered as an autouse fixture so pytest invokes it automatically.
    Nothing else in this file calls it, so without the decorator the tests
    would run against the default preset instead of the 256-token / 2-slot
    configuration their expected token counts are based on. Function scope
    gives each test a fresh preset, so attributes a test sets on ``server``
    (e.g. ``disable_ctx_shift``) do not leak into the next one.
    """
    global server
    server = ServerPreset.tinyllama2()
    # Two slots share the 256-token context: 128 tokens per slot.
    server.n_ctx = 256
    server.n_slots = 2
def test_ctx_shift_enabled():
    """An over-long prompt still completes when context shifting is on.

    Expected numbers for the server configuration used by this file:
    - LONG_TEXT is 301 tokens;
    - each slot has 256 / 2 = 128 tokens of context;
    - the prompt is truncated so that only its last 109 tokens are kept;
    - all 64 requested tokens are generated, because the context is shifted
      whenever it fills up.
    """
    server.start()
    payload = {
        "n_predict": 64,
        "prompt": LONG_TEXT,
    }
    response = server.make_request("POST", "/completion", data=payload)
    assert response.status_code == 200
    timings = response.body["timings"]
    assert timings["prompt_n"] == 109
    assert timings["predicted_n"] == 64
    assert response.body["truncated"] is True
@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
    # Bounded request: generation stops at n_predict, well inside the slot.
    (64, 64, False),
    # Unbounded request: generation runs until the slot context is full.
    # NOTE(review): 120 assumes the short prompt leaves 120 free tokens in a
    # 128-token slot -- confirm against the tokenizer of the preset model.
    (-1, 120, True),
])
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
    """A short prompt completes with context shifting disabled.

    The test arguments are supplied by ``pytest.mark.parametrize``; without
    it pytest would look for fixtures named after the parameters and fail
    during setup.

    Args:
        n_predict: value sent as ``n_predict`` in the request body.
        n_token_output: expected ``timings.predicted_n`` in the response.
        truncated: expected ``truncated`` flag in the response.
    """
    server.disable_ctx_shift = True
    # Server-side n_predict is set to -1 so only the per-request value applies.
    server.n_predict = -1
    server.start()
    res = server.make_request("POST", "/completion", data={
        "n_predict": n_predict,
        "prompt": "Hi how are you",
    })
    assert res.status_code == 200
    assert res.body["timings"]["predicted_n"] == n_token_output
    assert res.body["truncated"] == truncated
def test_ctx_shift_disabled_long_prompt():
    """A prompt larger than the slot context is rejected when shifting is off.

    The request must fail with a non-200 status, and the error message must
    state that the prompt exceeds the available context size.
    """
    server.disable_ctx_shift = True
    server.start()
    request_body = {"n_predict": 64, "prompt": LONG_TEXT}
    reply = server.make_request("POST", "/completion", data=request_body)
    assert reply.status_code != 200
    body = reply.body
    assert "error" in body
    assert "exceeds the available context size" in body["error"]["message"]