patchbanks commited on
Commit
8bd5b8a
·
verified ·
1 Parent(s): 975042a

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +4 -0
model.py CHANGED
@@ -292,6 +292,10 @@ class GPT(nn.Module):
292
  logits[logits < v[:, [-1]]] = -float('Inf')
293
  probs = F.softmax(logits, dim=-1)
294
  idx_next = torch.multinomial(probs, num_samples=1)
 
 
 
 
295
  idx = torch.cat((idx, idx_next), dim=1)
296
 
297
  return idx
 
292
  logits[logits < v[:, [-1]]] = -float('Inf')
293
  probs = F.softmax(logits, dim=-1)
294
  idx_next = torch.multinomial(probs, num_samples=1)
295
+
296
+ if idx_next.item() == 0: # stop token
297
+ break
298
+
299
  idx = torch.cat((idx, idx_next), dim=1)
300
 
301
  return idx