Cleanup

- completions.py +0 -17
- main.py +1 -19
completions.py
CHANGED

@@ -71,23 +71,6 @@ def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, in
     tokens: torch.Tensor = input_ids[0][1:]
     return list(zip(tokens.tolist(), token_log_probs.tolist()))

-def generate_outputs(model: PreTrainedModel, inputs: BatchEncoding, num_samples: int = 5) -> GenerateOutput | torch.LongTensor:
-    input_ids = inputs["input_ids"]
-    attention_mask = inputs["attention_mask"]
-    with torch.no_grad():
-        outputs = model.generate(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            max_new_tokens=4,
-            num_return_sequences=num_samples,
-            temperature=1.0,
-            top_k=50,
-            top_p=0.95,
-            do_sample=True
-            # num_beams=num_samples
-        )
-    return outputs
-
 #%%

 def load_model() -> tuple[PreTrainedModel, Tokenizer, torch.device]:
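Only the last two lines of calculate_log_probabilities are visible in the hunk above. For orientation, here is a minimal sketch of the usual way such per-token log-probabilities are obtained from a causal LM, consistent with the visible return value; the forward pass, log_softmax, and gather steps are assumptions rather than the project's actual code, and the helper name is hypothetical.

import torch
import torch.nn.functional as F

def per_token_log_probs_sketch(model, tokenizer, text: str) -> list[tuple[int, float]]:
    # Hypothetical helper illustrating the standard pattern; only the two
    # trailing lines (the shifted tokens and the zipped return) are taken
    # from the diff above.
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"]
    with torch.no_grad():
        logits = model(**inputs).logits          # (1, seq_len, vocab_size)
    # Distribution over token i+1 as predicted from positions 0..i.
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    # Log-probability the model assigned to each token that actually followed.
    token_log_probs = log_probs.gather(2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)[0]
    # Skip the first token, which has no preceding context to be predicted from,
    # so each token lines up with its own score.
    tokens: torch.Tensor = input_ids[0][1:]
    return list(zip(tokens.tolist(), token_log_probs.tolist()))

check_text presumably builds on per-token scores of this shape; the hard-coded stub removed from main.py below imitated exactly that kind of output.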
main.py
CHANGED

@@ -2,29 +2,13 @@ from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 from functools import lru_cache

-from models import
+from models import CheckResponse
 from completions import check_text, load_model

 app = FastAPI()

 model, tokenizer, device = load_model()

-def check_text_stub(text: str):
-    def rep(i):
-        if i == 3:
-            return -10, [" jumped", " leaps"]
-        if i == 5:
-            return -10, [" calm"]
-        if i == 7:
-            return -10, [" dog", " cat", " bird", " fish"]
-        return -3, []
-
-    result = []
-    for i, w in enumerate(text.split()):
-        logprob, replacements = rep(i)
-        result.append(ApiWord(text=f" {w}", logprob=logprob, replacements=replacements))
-    return result
-
 @lru_cache(maxsize=100)
 def cached_check_text(text: str):
     return check_text(text, model, tokenizer, device)

@@ -34,5 +18,3 @@ def check(text: str):
     return CheckResponse(text=text, words=cached_check_text(text))

 app.mount("/", StaticFiles(directory="frontend/public", html=True))
-
-#%%
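The decorator on check() and the definition of CheckResponse sit outside the hunks shown above. As a rough sketch of how the pieces that survive this cleanup fit together, with the route path, the @app.post decorator, and the CheckResponse fields assumed for illustration only:

from functools import lru_cache

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

class CheckResponse(BaseModel):
    # Assumed shape; the real model lives in models.py and is not shown in this diff.
    text: str
    words: list

app = FastAPI()

@lru_cache(maxsize=100)
def cached_check_text(text: str) -> list:
    # The real function delegates to completions.check_text(text, model, tokenizer, device);
    # lru_cache memoizes results so repeated checks of the same text skip the model call.
    return []  # placeholder so the sketch runs on its own

@app.post("/check")  # assumed decorator; the actual route is outside the shown hunks
def check(text: str) -> CheckResponse:
    return CheckResponse(text=text, words=cached_check_text(text))

# Anything that is not an API route is served from the built frontend.
app.mount("/", StaticFiles(directory="frontend/public", html=True))

With check_text_stub gone, every /check request goes through the cached, model-backed path instead of the hard-coded placeholder data.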