Adityak204 commited on
Commit
34a1570
·
verified ·
1 Parent(s): fce3e3b
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -8,6 +8,7 @@ import torch.nn.functional as F
8
  from src.gpt_base import GPT
9
  import json
10
  from huggingface_hub import hf_hub_download
 
11
 
12
 
13
  # Config class for model parameters
@@ -43,7 +44,7 @@ def decode(l):
43
 
44
  def predict_next_word(text, model, seq_len=50):
45
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
- for _ in range(seq_len):
47
  xb = torch.tensor(encode(text)).unsqueeze(0).to(device)
48
  yb = model(xb)
49
  next_word = yb[0, -1].argmax().item()
 
8
  from src.gpt_base import GPT
9
  import json
10
  from huggingface_hub import hf_hub_download
11
+ from tqdm import tqdm
12
 
13
 
14
  # Config class for model parameters
 
44
 
45
  def predict_next_word(text, model, seq_len=50):
46
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
+ for _ in tqdm(range(seq_len)):
48
  xb = torch.tensor(encode(text)).unsqueeze(0).to(device)
49
  yb = model(xb)
50
  next_word = yb[0, -1].argmax().item()