atnikos commited on
Commit
8fb911c
·
1 Parent(s): e752355

fixes for speed

Browse files
Files changed (1) hide show
  1. app.py +96 -34
app.py CHANGED
@@ -369,6 +369,26 @@ def get_words(text):
369
  SOURCE_WORDS_CACHE[text] = words
370
  return words
371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  def search_motions_two_actions(action1, action2):
373
  """Enhanced substring search with synonym expansion"""
374
  # Create a cache key for this query
@@ -498,7 +518,7 @@ def search_gpt_semantic(action, top_k=1):
498
  return result
499
 
500
  def search_motions_combined(action1, action2, n_motions):
501
- """Optimized combined search approach with synonym expansion"""
502
  # Create a cache key for this query
503
  cache_key = f"{action1.lower().strip()}_{action2.lower().strip()}_{n_motions}"
504
 
@@ -506,49 +526,67 @@ def search_motions_combined(action1, action2, n_motions):
506
  if cache_key in SEARCH_RESULTS_CACHE:
507
  return SEARCH_RESULTS_CACHE[cache_key]
508
 
509
- # Perform the search
 
 
 
 
 
 
 
 
 
510
  string_results = search_motions_two_actions(action1, action2)
511
 
512
- if len(string_results) == 0:
513
- # Fallback to purely semantic
514
- semantic_res, sem_scores = search_motions_semantic(action1, action2, top_k=2*n_motions)
515
- if not semantic_res:
516
- result = (get_random_motions(n_motions), ['NA']*n_motions)
517
- else:
518
- result = (semantic_res[:n_motions], sem_scores[:n_motions])
 
 
 
 
 
 
 
 
519
  else:
520
- if len(string_results) >= n_motions:
521
- result = (random.sample(string_results, n_motions), ['NA']*n_motions)
522
- else:
523
- needed = n_motions - len(string_results)
524
- final_list = list(string_results)
525
- scores_ret = ['NA']*len(final_list)
 
 
526
 
527
- # Fill from semantic
528
- sem_list, sem_score_list = search_motions_semantic(action1, action2, top_k=2*n_motions)
529
- used_combo = {m["motion_combo"] for m in final_list}
530
 
531
  for item, score in zip(sem_list, sem_score_list):
532
  if item["motion_combo"] not in used_combo:
533
- final_list.append(item)
534
- scores_ret.append(score)
535
  used_combo.add(item["motion_combo"])
536
- if len(final_list) == n_motions:
537
  break
538
 
539
  # Still short? Fill with random
540
- if len(final_list) < n_motions:
541
- needed2 = n_motions - len(final_list)
542
  rnd = get_random_motions(needed2)
543
  for r in rnd:
544
  if r["motion_combo"] not in used_combo:
545
- final_list.append(r)
546
- scores_ret.append('NA')
547
  used_combo.add(r["motion_combo"])
548
- if len(final_list) == n_motions:
549
  break
550
-
551
- result = (final_list[:n_motions], scores_ret[:n_motions])
552
 
553
  # Cache the results
554
  SEARCH_RESULTS_CACHE[cache_key] = result
@@ -556,9 +594,21 @@ def search_motions_combined(action1, action2, n_motions):
556
  return result
557
 
558
  def safe_video_update(motion_data, semantic_score, visible=True):
559
- """Optimized video update without unnecessary network checks"""
560
- ssim = str(round(semantic_score, 2)) if semantic_score != 'NA' else ''
561
- actual_annot = f"{motion_data['annotation']} | text sim. : {ssim}"
 
 
 
 
 
 
 
 
 
 
 
 
562
 
563
  return [
564
  gr.update(value=url, visible=visible)
@@ -610,8 +660,20 @@ def update_videos(motions, n_visible, semantic_scores):
610
  if i < len(motions[:n_visible]):
611
  motion = motions[i]
612
  score = semantic_scores[i]
613
- ssim = str(round(score, 2)) if score != 'NA' else ''
614
- actual_annot = f"{motion['annotation']} | text sim. : {ssim}"
 
 
 
 
 
 
 
 
 
 
 
 
615
  updates.extend([
616
  gr.update(value=motion["motion_combo"], visible=True),
617
  gr.update(value=motion["motion_a"], visible=True),
@@ -909,7 +971,7 @@ def prefetch_videos():
909
  threading.Thread(target=prefetch_videos).start()
910
 
911
  # Print ready message
912
- print("Demo ready! Optimized code running with synonym-enhanced TF-IDF similarity.")
913
 
914
  # Launch the demo
915
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
369
  SOURCE_WORDS_CACHE[text] = words
370
  return words
371
 
372
+ def exact_string_search(action1, action2):
373
+ """Search for exact string matches first"""
374
+ exact_results = []
375
+
376
+ action1_lower = action1.lower().strip()
377
+ action2_lower = action2.lower().strip()
378
+
379
+ for k, v in motion_dict.items():
380
+ source_lower = v["source_annot"].lower()
381
+ target_lower = v["target_annot"].lower()
382
+
383
+ # Check for exact matches in either annotation
384
+ cond1 = action1_lower in source_lower or action1_lower in target_lower
385
+ cond2 = action2_lower in source_lower or action2_lower in target_lower
386
+
387
+ if cond1 and cond2:
388
+ exact_results.append(v)
389
+
390
+ return exact_results
391
+
392
  def search_motions_two_actions(action1, action2):
393
  """Enhanced substring search with synonym expansion"""
394
  # Create a cache key for this query
 
518
  return result
519
 
520
  def search_motions_combined(action1, action2, n_motions):
521
+ """Improved combined search approach that prioritizes exact matches"""
522
  # Create a cache key for this query
523
  cache_key = f"{action1.lower().strip()}_{action2.lower().strip()}_{n_motions}"
524
 
 
526
  if cache_key in SEARCH_RESULTS_CACHE:
527
  return SEARCH_RESULTS_CACHE[cache_key]
528
 
529
+ # 1. First try exact string matches
530
+ exact_results = exact_string_search(action1, action2)
531
+
532
+ if len(exact_results) >= n_motions:
533
+ # If we have enough exact matches, return them
534
+ result = (random.sample(exact_results, n_motions), ['EXACT']*n_motions)
535
+ SEARCH_RESULTS_CACHE[cache_key] = result
536
+ return result
537
+
538
+ # 2. If not enough exact matches, try the enhanced substring search with synonyms
539
  string_results = search_motions_two_actions(action1, action2)
540
 
541
+ # Filter out any results that are already in exact_results
542
+ string_results = [r for r in string_results if r not in exact_results]
543
+
544
+ # Combine exact_results with string_results
545
+ combined_results = list(exact_results)
546
+ combined_scores = ['EXACT'] * len(exact_results)
547
+
548
+ if len(combined_results) + len(string_results) >= n_motions:
549
+ # If we have enough combined results, use them
550
+ needed = n_motions - len(combined_results)
551
+ if needed > 0:
552
+ combined_results.extend(random.sample(string_results, needed))
553
+ combined_scores.extend(['SUBSTR'] * needed)
554
+
555
+ result = (combined_results[:n_motions], combined_scores[:n_motions])
556
  else:
557
+ # 3. If still not enough, add all substring matches and then use semantic search
558
+ combined_results.extend(string_results)
559
+ combined_scores.extend(['SUBSTR'] * len(string_results))
560
+
561
+ # Use semantic search for the remaining needed motions
562
+ needed = n_motions - len(combined_results)
563
+ if needed > 0:
564
+ sem_list, sem_score_list = search_motions_semantic(action1, action2, top_k=2*needed)
565
 
566
+ # Filter out duplicates
567
+ used_combo = {m["motion_combo"] for m in combined_results}
 
568
 
569
  for item, score in zip(sem_list, sem_score_list):
570
  if item["motion_combo"] not in used_combo:
571
+ combined_results.append(item)
572
+ combined_scores.append(score)
573
  used_combo.add(item["motion_combo"])
574
+ if len(combined_results) == n_motions:
575
  break
576
 
577
  # Still short? Fill with random
578
+ if len(combined_results) < n_motions:
579
+ needed2 = n_motions - len(combined_results)
580
  rnd = get_random_motions(needed2)
581
  for r in rnd:
582
  if r["motion_combo"] not in used_combo:
583
+ combined_results.append(r)
584
+ combined_scores.append('RANDOM')
585
  used_combo.add(r["motion_combo"])
586
+ if len(combined_results) == n_motions:
587
  break
588
+
589
+ result = (combined_results[:n_motions], combined_scores[:n_motions])
590
 
591
  # Cache the results
592
  SEARCH_RESULTS_CACHE[cache_key] = result
 
594
  return result
595
 
596
  def safe_video_update(motion_data, semantic_score, visible=True):
597
+ """Optimized video update with match type display"""
598
+
599
+ # Prepare the annotation text based on the match type
600
+ if semantic_score == 'EXACT':
601
+ match_info = "Exact Match"
602
+ elif semantic_score == 'SUBSTR':
603
+ match_info = "Substring Match"
604
+ elif semantic_score == 'RANDOM':
605
+ match_info = "Random Result"
606
+ else:
607
+ # For semantic matches, round to 2 decimal places
608
+ ssim = str(round(semantic_score, 2)) if semantic_score != 'NA' else ''
609
+ match_info = f"Semantic Match (sim: {ssim})"
610
+
611
+ actual_annot = f"{motion_data['annotation']} | {match_info}"
612
 
613
  return [
614
  gr.update(value=url, visible=visible)
 
660
  if i < len(motions[:n_visible]):
661
  motion = motions[i]
662
  score = semantic_scores[i]
663
+
664
+ # Handle different score types
665
+ if score == 'EXACT':
666
+ match_info = "Exact Match"
667
+ elif score == 'SUBSTR':
668
+ match_info = "Substring Match"
669
+ elif score == 'RANDOM':
670
+ match_info = "Random Result"
671
+ else:
672
+ # For semantic matches, round to 2 decimal places
673
+ ssim = str(round(score, 2)) if score != 'NA' else ''
674
+ match_info = f"Semantic Match (sim: {ssim})"
675
+
676
+ actual_annot = f"{motion['annotation']} | {match_info}"
677
  updates.extend([
678
  gr.update(value=motion["motion_combo"], visible=True),
679
  gr.update(value=motion["motion_a"], visible=True),
 
971
  threading.Thread(target=prefetch_videos).start()
972
 
973
  # Print ready message
974
+ print("Demo ready! Optimized code running with exact matching prioritized over synonym-enhanced TF-IDF similarity.")
975
 
976
  # Launch the demo
977
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)