weiyuchoumou526 commited on
Commit
c83abdd
·
1 Parent(s): 108d76f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -203,6 +203,7 @@ def get_end_number(track_pause_number_slider, video_state, interactive_state):
203
 
204
  return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log, operation_log
205
 
 
206
  # use sam to get the mask
207
  def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
208
  """
@@ -265,6 +266,7 @@ def remove_multi_mask(interactive_state, mask_dropdown):
265
  operation_log = [("",""), ("Remove all masks. Try to add new masks","Normal")]
266
  return interactive_state, gr.update(choices=[],value=[]), operation_log, operation_log
267
 
 
268
  def show_mask(video_state, interactive_state, mask_dropdown):
269
  mask_dropdown.sort()
270
  select_frame = video_state["origin_images"][video_state["select_frame_number"]]
@@ -276,6 +278,7 @@ def show_mask(video_state, interactive_state, mask_dropdown):
276
  operation_log = [("",""), ("Added masks {}. If you want to do the inpainting with current masks, please go to step3, and click the Tracking button first and then Inpainting button.".format(mask_dropdown),"Normal")]
277
  return select_frame, operation_log, operation_log
278
 
 
279
  # tracking vos
280
  def vos_tracking_video(video_state, interactive_state, mask_dropdown):
281
  operation_log = [("",""), ("Tracking finished! Try to click the Inpainting button to get the inpainting result.","Normal")]
@@ -337,6 +340,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
337
  #### shanggao code for mask save
338
  return video_output, video_state, interactive_state, operation_log, operation_log
339
 
 
340
  def inpaint_video(video_state, *_):
341
  operation_log = [("", ""), ("Inpainting finished!", "Normal")]
342
 
@@ -397,6 +401,7 @@ def generate_video_from_frames(frames, output_path, fps=30):
397
  torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
398
  return output_path
399
 
 
400
  def restart():
401
  operation_log = [("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")]
402
  return {
 
203
 
204
  return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log, operation_log
205
 
206
+ @spaces.GPU
207
  # use sam to get the mask
208
  def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
209
  """
 
266
  operation_log = [("",""), ("Remove all masks. Try to add new masks","Normal")]
267
  return interactive_state, gr.update(choices=[],value=[]), operation_log, operation_log
268
 
269
+ @spaces.GPU
270
  def show_mask(video_state, interactive_state, mask_dropdown):
271
  mask_dropdown.sort()
272
  select_frame = video_state["origin_images"][video_state["select_frame_number"]]
 
278
  operation_log = [("",""), ("Added masks {}. If you want to do the inpainting with current masks, please go to step3, and click the Tracking button first and then Inpainting button.".format(mask_dropdown),"Normal")]
279
  return select_frame, operation_log, operation_log
280
 
281
+ @spaces.GPU
282
  # tracking vos
283
  def vos_tracking_video(video_state, interactive_state, mask_dropdown):
284
  operation_log = [("",""), ("Tracking finished! Try to click the Inpainting button to get the inpainting result.","Normal")]
 
340
  #### shanggao code for mask save
341
  return video_output, video_state, interactive_state, operation_log, operation_log
342
 
343
+ @spaces.GPU
344
  def inpaint_video(video_state, *_):
345
  operation_log = [("", ""), ("Inpainting finished!", "Normal")]
346
 
 
401
  torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
402
  return output_path
403
 
404
+ @spaces.GPU
405
  def restart():
406
  operation_log = [("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")]
407
  return {