patchbanks commited on
Commit
cea3d05
·
verified ·
1 Parent(s): 83171b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -23,7 +23,7 @@ ckpt_load = 'model.pt'
23
 
24
  start = "000000000000\n"
25
  num_samples = 1
26
- max_new_tokens = 384
27
 
28
  seed = random.randint(1, 100000)
29
  torch.manual_seed(seed)
@@ -81,7 +81,7 @@ def clear_midi(dir):
81
 
82
  clear_midi(temp_dir)
83
 
84
-
85
  def generate_midi(temperature, top_k):
86
  start_ids = encode(start)
87
  x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
@@ -122,7 +122,7 @@ def generate_midi(temperature, top_k):
122
  for sequence in midi_events:
123
  filtered_sequence = []
124
  for event in sequence:
125
- if event['start'] < 768 and event['end'] <= 768:
126
  filtered_sequence.append(event)
127
  if filtered_sequence:
128
  round_bars.append(filtered_sequence)
 
23
 
24
  start = "000000000000\n"
25
  num_samples = 1
26
+ max_new_tokens = 1152
27
 
28
  seed = random.randint(1, 100000)
29
  torch.manual_seed(seed)
 
81
 
82
  clear_midi(temp_dir)
83
 
84
+ @spaces.GPU(duration=15)
85
  def generate_midi(temperature, top_k):
86
  start_ids = encode(start)
87
  x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
 
122
  for sequence in midi_events:
123
  filtered_sequence = []
124
  for event in sequence:
125
+ if event['start'] < 1536 and event['end'] <= 1536:
126
  filtered_sequence.append(event)
127
  if filtered_sequence:
128
  round_bars.append(filtered_sequence)