Spaces: Running on Zero
import torch
import torch.nn.functional as F


def top_k_filtering(logits, top_k: int = 1):
    """
    Filter a distribution of logits using top-k filtering.

    The input logits tensor is modified in-place.

    Args:
        logits: A tensor of logits to be filtered. Expected shape is [..., vocab_size].
        top_k: If > 0, only keep the top k tokens with highest probability.

    Returns:
        A tensor of logits where values outside the top-k threshold are set to -inf.
    """
    if top_k > 0:
        # Mask every logit strictly smaller than the k-th largest logit
        # along the vocabulary axis.
        idx_to_remove = logits < logits.topk(top_k, largest=True, sorted=False, dim=-1)[
            0
        ].amin(dim=-1, keepdim=True)
        logits.masked_fill_(idx_to_remove, -torch.inf)
    return logits
def process_logits(
    logits,
    top_k: int = 1,
):
    """
    Process logits by optionally applying top-k filtering.

    The final probabilities are returned after applying softmax on the filtered logits.

    Args:
        logits: A tensor of logits to process. Expected shape is [..., vocab_size].
        top_k: If > 0, only keep the top k tokens with highest probability.

    Returns:
        A tensor of probabilities after filtering, with the same shape as the input logits.
    """
    logits = top_k_filtering(logits, top_k=top_k)
    probs = F.softmax(logits, dim=-1)
    return probs
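For illustration, here is a minimal, self-contained sketch of the same top-k idea on a toy batch: logits below the k-th largest value are masked to `-inf`, so after the softmax only the surviving tokens carry probability mass. The tensor values and `top_k = 2` are made-up example inputs, not part of the original code.

```python
import torch
import torch.nn.functional as F

# Toy batch of logits over a 5-token vocabulary (hypothetical values).
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -2.0]])

# Keep only the top-2 logits; everything else is masked to -inf.
top_k = 2
thresh = logits.topk(top_k, dim=-1).values.amin(dim=-1, keepdim=True)
filtered = logits.masked_fill(logits < thresh, -torch.inf)

# Softmax over the filtered logits: masked positions get probability 0.
probs = F.softmax(filtered, dim=-1)
print(probs)
```

Note that masking with `-inf` rather than renormalizing by hand keeps the computation numerically stable and leaves the tensor shape unchanged, which is why the functions above return a full `[..., vocab_size]` distribution instead of a truncated one.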