Spaces:
Running
Running
Update model.py
Browse files
model.py
CHANGED
|
@@ -292,6 +292,10 @@ class GPT(nn.Module):
|
|
| 292 |
logits[logits < v[:, [-1]]] = -float('Inf')
|
| 293 |
probs = F.softmax(logits, dim=-1)
|
| 294 |
idx_next = torch.multinomial(probs, num_samples=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
idx = torch.cat((idx, idx_next), dim=1)
|
| 296 |
|
| 297 |
return idx
|
|
|
|
| 292 |
logits[logits < v[:, [-1]]] = -float('Inf')
|
| 293 |
probs = F.softmax(logits, dim=-1)
|
| 294 |
idx_next = torch.multinomial(probs, num_samples=1)
|
| 295 |
+
|
| 296 |
+
if idx_next.item() == 0: # stop token
|
| 297 |
+
break
|
| 298 |
+
|
| 299 |
idx = torch.cat((idx, idx_next), dim=1)
|
| 300 |
|
| 301 |
return idx
|