Add a quick and dirty show-next-token-logits page
custom_llm.py CHANGED (+53 -0)
@@ -203,6 +203,59 @@ def continue_messages(request: ContinueMessagesRequest):
     }
 
 
+@app.post('/api/logprobs')
+def logprobs(request: ContinueMessagesRequest):
+
+    messages = [{"role": m.role, "content": m.content} for m in request.messages]
+    if len(messages) == 0:
+        raise HTTPException(status_code=400, detail="At least one message must be provided.")
+    n_branch_tokens = request.n_branch_tokens
+    n_future_tokens = request.n_future_tokens
+
+    model = ml_models['llm']['model']
+    tokenizer = ml_models['llm']['tokenizer']
+
+    device = model.device
+
+    tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", continue_final_message=True).to(model.device)
+
+    # Compute all logits
+    with torch.no_grad():
+        logits = model(tokenized_chat).logits
+
+    k = request.n_branch_tokens
+
+    # Return a list of tokens:
+    # {
+    #   "token": "the",
+    #   "logprobs": [{"the": -0.1, "a": -0.2, ...}]
+    # }
+    # logprobs are the top-k logprobs for each token, plus the chosen token in case it is not in the top-k
+    # The very first token will have no logprobs, since it is the beginning of the document
+    # The very last token will have "token" set to None, and "logprobs" will be the logprobs for the next token
+
+    all_logprobs = []
+    for idx in range(len(tokenized_chat[0]) + 1):
+        if idx == len(tokenized_chat[0]):
+            actual_token_id = None
+            token = None
+        else:
+            actual_token_id = tokenized_chat[0, idx].item()
+            token = tokenizer.decode(actual_token_id)
+
+        if idx == 0:
+            token_logprobs = []
+        else:
+            logprobs = logits[0, idx - 1].log_softmax(dim=-1)
+            token_ids_to_return = logprobs.topk(k).indices.cpu().numpy().tolist()
+            if actual_token_id is not None and actual_token_id not in token_ids_to_return:
+                token_ids_to_return.append(actual_token_id)
+            token_logprobs = {tokenizer.decode(i): logprobs[i].item() for i in token_ids_to_return}
+        all_logprobs.append(dict(token=token, logprobs=token_logprobs))
+
+    return {
+        'logprobs': all_logprobs
+    }
 
 if __name__ == "__main__":
     uvicorn.run(app, host="localhost", port=PORT)
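For quick testing, a minimal client-side sketch of how the new endpoint might be called. The field names mirror what the handler reads off ContinueMessagesRequest (messages, n_branch_tokens, n_future_tokens); the port number and the example values are assumptions, since the PORT constant and the full request model are not shown in this diff. Note how the response mirrors the handler's shift: the logprobs for the token at position idx come from the logits at position idx - 1, so the first entry has no logprobs and the final entry (token set to None) carries the distribution over the next, not-yet-present token.

# Hypothetical client for the /api/logprobs endpoint added above.
# Assumes the server in custom_llm.py is running locally; the port (8000)
# and any request fields beyond those used in the handler are guesses.
import requests

resp = requests.post(
    "http://localhost:8000/api/logprobs",
    json={
        "messages": [
            {"role": "user", "content": "Tell me a joke."},
            {"role": "assistant", "content": "Why did the"},
        ],
        "n_branch_tokens": 5,   # how many top-k alternatives to return per position
        "n_future_tokens": 1,   # read by the handler but not used in this commit
    },
    timeout=60,
)
resp.raise_for_status()

for entry in resp.json()["logprobs"]:
    # entry["token"] is the decoded input token (None for the final,
    # next-token position); entry["logprobs"] is empty for the first token
    # and otherwise maps candidate token strings to their log-probabilities.
    print(entry["token"], entry["logprobs"])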