Gijs Wijngaard commited on
Commit
49228ab
·
1 Parent(s): e41bd20
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -18,7 +18,7 @@ model.disable_talker()
18
  processor = Qwen2_5OmniProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
19
 
20
  @spaces.GPU
21
- def run_omni(audio_path: str, instruction: str, max_tokens: int = 512) -> str:
22
  if not audio_path:
23
  return "Please upload an audio file."
24
 
@@ -53,7 +53,7 @@ def run_omni(audio_path: str, instruction: str, max_tokens: int = 512) -> str:
53
  )
54
  inputs = inputs.to(model.device)
55
 
56
- output_ids = model.generate(**inputs, max_new_tokens=int(max_tokens))
57
  output_ids = output_ids[:, inputs["input_ids"].shape[1]:]
58
  response = processor.batch_decode(
59
  output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
@@ -72,12 +72,11 @@ with gr.Blocks(title="Qwen2.5 Omni (Audio) Demo") as demo:
72
  label="Instruction",
73
  value="Transcribe the audio, then summarize it in one sentence.",
74
  )
75
- max_tokens = gr.Slider(128, 2048, value=512, step=64, label="Max Output Tokens")
76
  submit_btn = gr.Button("Run", variant="primary")
77
  with gr.Column():
78
  output_text = gr.Textbox(label="Response", lines=14)
79
 
80
- submit_btn.click(run_omni, [audio_input, instruction, max_tokens], output_text)
81
 
82
 
83
  if __name__ == "__main__":
 
18
  processor = Qwen2_5OmniProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
19
 
20
  @spaces.GPU
21
+ def run_omni(audio_path: str, instruction: str) -> str:
22
  if not audio_path:
23
  return "Please upload an audio file."
24
 
 
53
  )
54
  inputs = inputs.to(model.device)
55
 
56
+ output_ids = model.generate(**inputs, max_new_tokens=4096)
57
  output_ids = output_ids[:, inputs["input_ids"].shape[1]:]
58
  response = processor.batch_decode(
59
  output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
 
72
  label="Instruction",
73
  value="Transcribe the audio, then summarize it in one sentence.",
74
  )
 
75
  submit_btn = gr.Button("Run", variant="primary")
76
  with gr.Column():
77
  output_text = gr.Textbox(label="Response", lines=14)
78
 
79
+ submit_btn.click(run_omni, [audio_input, instruction], output_text)
80
 
81
 
82
  if __name__ == "__main__":