import gradio as gr
import random
import json
import re
import os
import numpy as np
from collections import Counter
from sklearn.feature_extraction.text import TfidfVectorizer
import functools
from concurrent.futures import ThreadPoolExecutor
import threading
import nltk
from nltk.corpus import wordnet
from nltk.stem import WordNetLemmatizer
# --- NLTK data setup ---
# Point NLTK at the data directories shipped alongside this script
current_dir = os.getcwd()
print(f"Current directory: {current_dir}")
nltk_data_path = os.path.join(current_dir, "nltk_data")
print(f"Setting NLTK data path to: {nltk_data_path}")
# Insert at position 0 so the local directory is searched first
nltk.data.path.insert(0, nltk_data_path)
# Print all paths for debugging
print(f"NLTK will search in: {nltk.data.path}")
# Try to load the tagger from the local data directory
try:
    # Try to load the tagger model directly
    from nltk.tag.perceptron import PerceptronTagger
    tagger = PerceptronTagger()
    print("Successfully loaded PerceptronTagger")
except Exception as e:
    print(f"Error loading tagger: {e}")
    # nltk.download('averaged_perceptron_tagger_eng')
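# A minimal fallback sketch (not enabled by default): download the NLTK
# resources this script relies on into the local nltk_data directory. The
# resource names below are the standard NLTK package identifiers; adjust
# them if your NLTK version uses different ones (e.g. the '_eng' variant above).
def download_nltk_resources(target_dir=nltk_data_path):
    """Best-effort download of the corpora/taggers used in this demo."""
    for resource in ("punkt", "wordnet", "omw-1.4", "averaged_perceptron_tagger"):
        try:
            nltk.download(resource, download_dir=target_dir, quiet=True)
        except Exception as err:
            print(f"Could not download NLTK resource '{resource}': {err}")

# Uncomment to fetch missing resources at startup:
# download_nltk_resources()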
# Header HTML for the demo page
WEBSITE = ("""<div class="embed_hidden" style="text-align: center;">
    <h1>SINC: Spatial Composition of 3D Human Motions for Simultaneous Action Generation</h1>
    <h2 style="margin: 1em 0; font-size: 2em;">
        <span style="font-weight: normal; font-style: italic;">ICCV 2023</span>
    </h2>
    <h3>
        <a href="https://atnikos.github.io/" target="_blank" rel="noopener noreferrer">Nikos Athanasiou</a><sup>*1</sup>,
        <a href="https://mathis.petrovich.fr/" target="_blank" rel="noopener noreferrer">Mathis Petrovich</a><sup>*1,2</sup>,
        <br>
        <a href="https://ps.is.mpg.de/person/black" target="_blank" rel="noopener noreferrer">Michael J. Black</a><sup>1</sup>,
        <a href="https://gulvarol.github.io/" target="_blank" rel="noopener noreferrer">Gül Varol</a><sup>2</sup>
    </h3>
    <h3>
        <sup>1</sup>MPI for Intelligent Systems, Tübingen, Germany<br>
        <sup>2</sup>LIGM, École des Ponts, Univ Gustave Eiffel, CNRS, France
    </h3>
</div>
<div style="display:flex; gap: 0.3rem; justify-content: center; align-items: center;" align="center">
    <a href='https://arxiv.org/abs/2304.10417'><img src='https://img.shields.io/badge/Arxiv-2304.10417-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a>
    <a href='https://sinc.is.tue.mpg.de'><img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a>
</div>
<h2 align="center">
    Download the
    <a href="https://drive.google.com/drive/folders/1ks9wvNN_arrgBcd0GxN5nRLf5ASPkUgc?usp=sharing" target="_blank" rel="noopener noreferrer">SINC synthetic data</a>
    if you want to train your models with spatial compositions from AMASS.
    <br>
    The data explored in this demo were created with the
    <a href='https://github.com/atnikos/sinc/blob/main/create_synthetic_babel.py' target="_blank" rel="noopener noreferrer">code in our repo that composes motions from AMASS</a>.<sup>**</sup>
</h2>
""")
# Action examples
ACTION_EXAMPLES = [
    "walk forward on balance beam", "walk counterclockwise", "sit on chair", "kick a ball", "jump up",
    "hold on to rail with right hand", "pick up an object", "wave with the right hand", "throw a ball", "bow"
]
ACTION_EXAMPLES_SIMULTANEOUS = [
    "walk forward on balance beam while holding rail with right hand",
    "walk counterclockwise while waving with left hand",
    "sit on chair and wave with left hand",
    "pick up an object while bowing",
    "walk forward on balance beam while waving left hand"
]
# Global caches for expensive operations
SIMILARITY_CACHE = {}
SEARCH_RESULTS_CACHE = {}
GPT_SIMILARITY_CACHE = {}
GPT_SEARCH_RESULTS_CACHE = {}
SYNONYM_CACHE = {}
MAX_WORKERS = 4  # For ThreadPoolExecutor

# Caches for TF-IDF
TFIDF_VECTORIZER = None
TFIDF_MATRIX = None
MOTION_TEXTS = []
MOTION_KEYS = []
GPT_TEXTS = []
GPT_KEYS = []

# Initialize lemmatizer
lemmatizer = WordNetLemmatizer()
# Movement action word mappings - manually defined synonyms for common motion words
ACTION_SYNONYMS = {
    'walk': ['move', 'stroll', 'pace', 'stride', 'wander', 'stalk', 'amble', 'saunter', 'tread', 'step'],
    'run': ['sprint', 'jog', 'dash', 'race', 'bolt', 'scamper', 'rush', 'hurry'],
    'jump': ['leap', 'hop', 'spring', 'bounce', 'vault', 'bound', 'skip'],
    'turn': ['rotate', 'spin', 'twist', 'revolve', 'pivot', 'swivel', 'whirl'],
    'wave': ['signal', 'gesture', 'flap', 'flutter', 'waggle', 'shake', 'brandish'],
    'sit': ['perch', 'recline', 'rest', 'squat'],
    'stand': ['rise', 'upright', 'erect', 'vertical'],
    'throw': ['toss', 'hurl', 'fling', 'chuck', 'lob', 'pitch', 'cast'],
    'grab': ['grasp', 'clutch', 'seize', 'grip', 'hold', 'take', 'catch'],
    'pick': ['lift', 'raise', 'hoist', 'elevate'],
    'kick': ['boot', 'punt', 'strike'],
    'bow': ['bend', 'stoop', 'incline', 'nod'],
    'dance': ['twirl', 'sway', 'shimmy', 'boogie', 'groove', 'swing'],
    'balance': ['steady', 'stabilize', 'poise', 'equilibrium'],
    'forward': ['ahead', 'onward', 'frontward', 'forth'],
    'backward': ['back', 'rearward', 'reverse', 'retreat'],
    'clockwise': ['right', 'rightward', 'rightways'],
    'counterclockwise': ['left', 'leftward', 'leftways', 'anticlockwise'],
    'hold': ['grip', 'grasp', 'clutch', 'clasp', 'clench', 'possess']
}
# Build a reverse mapping for faster lookups
REVERSE_SYNONYMS = {}
for word, synonyms in ACTION_SYNONYMS.items():
    REVERSE_SYNONYMS[word] = word  # A word is its own synonym
    for synonym in synonyms:
        REVERSE_SYNONYMS[synonym] = word
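# Example: every listed synonym maps back to its canonical action, so
# REVERSE_SYNONYMS["stroll"] == "walk" and REVERSE_SYNONYMS["walk"] == "walk".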
def get_wordnet_pos(word):
    """Map a POS tag to the first character used by the WordNet lemmatizer,
    with a fallback for errors."""
    try:
        tag = nltk.tag.pos_tag([word])[0][1][0].upper()
        tag_dict = {"J": wordnet.ADJ,
                    "N": wordnet.NOUN,
                    "V": wordnet.VERB,
                    "R": wordnet.ADV}
        return tag_dict.get(tag, wordnet.NOUN)
    except Exception as e:
        print(f"POS tagging error for word '{word}': {e}")
        # Default to NOUN as a fallback
        return wordnet.NOUN
def get_synonyms(word):
    """Get all synonyms for a word using WordNet and the custom action mappings."""
    if word in SYNONYM_CACHE:
        return SYNONYM_CACHE[word]
    synonyms = set()
    # The word is always a synonym of itself
    synonyms.add(word)
    # Check the custom action mappings first (faster and more domain-specific)
    if word in REVERSE_SYNONYMS:
        canonical_word = REVERSE_SYNONYMS[word]
        synonyms.add(canonical_word)
        synonyms.update(ACTION_SYNONYMS.get(canonical_word, []))
    # Then check WordNet (more general but can be noisy)
    try:
        word_lemma = lemmatizer.lemmatize(word, get_wordnet_pos(word))
        for syn in wordnet.synsets(word_lemma):
            for lemma in syn.lemmas():
                synonyms.add(lemma.name().lower().replace('_', ' '))
    except Exception as e:
        print(f"Error getting WordNet synonyms for '{word}': {e}")
    SYNONYM_CACHE[word] = synonyms
    return synonyms
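# Illustrative usage (the exact set depends on the installed WordNet data):
# get_synonyms("walk") returns a set containing "walk", the manual synonyms
# above ("stroll", "pace", ...), and any WordNet lemmas for the word.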
def expand_query_with_synonyms(query):
    """Expand a query with synonyms for each term."""
    try:
        words = nltk.word_tokenize(query.lower())
    except Exception as e:
        print(f"Tokenization error: {e}")
        # Fall back to a simple split if tokenization fails
        words = query.lower().split()
    expanded_terms = []
    for word in words:
        if len(word) > 2:  # Only expand words longer than 2 chars to avoid stop words
            synonyms = get_synonyms(word)
            expanded_terms.extend(synonyms)
        else:
            expanded_terms.append(word)
    # Join back into a space-separated string
    return ' '.join(expanded_terms)
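# Illustrative usage: expand_query_with_synonyms("jump up") concatenates the
# synonym set of each token, yielding something like "jump leap hop spring ... up".
# Word order is irrelevant here since TF-IDF treats the string as a bag of
# words and bigrams.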
def create_example_buttons(textbox, loftexts):
    """Create clickable buttons for example actions."""
    return gr.Examples(
        examples=loftexts,
        inputs=textbox,
        label="Example Actions"
    )
# Load motion data
def load_json_dict(file_path):
    with open(file_path, "r") as f:
        return json.load(f)

# Load data at startup
print("Loading motion data...")
motion_dict = load_json_dict("for_website_v4.json")
# Drop entries whose annotations mention "guide forward walk"
motion_dict = {
    key: value for key, value in motion_dict.items()
    if "guide forward walk" not in value['source_annot'].lower()
    and "guide forward walk" not in value['target_annot'].lower()
}
print("Loading GPT labels...")
GPT_LABELS_LIST = load_json_dict('gpt3-labels-list.json')
# Keep only the third element (the label text) of each entry
GPT_LABELS_LIST = {k: v[2] for k, v in GPT_LABELS_LIST.items()}
# TF-IDF based similarity implementation with synonym expansion
def initialize_tfidf():
    """Initialize the TF-IDF vectorizer and precompute the motion matrix."""
    global TFIDF_VECTORIZER, TFIDF_MATRIX, MOTION_TEXTS, MOTION_KEYS
    print("Initializing TF-IDF vectorizer...")
    # Extract text descriptions from the motion dictionary for TF-IDF
    MOTION_TEXTS = []
    MOTION_KEYS = []
    for key, motion in motion_dict.items():
        # Combine source and target annotations
        text = f"{motion['source_annot']} {motion['target_annot']}".lower()
        MOTION_TEXTS.append(text)
        MOTION_KEYS.append(key)
    # Initialize the TF-IDF vectorizer
    TFIDF_VECTORIZER = TfidfVectorizer(
        lowercase=True,
        stop_words='english',
        ngram_range=(1, 2),  # Include bigrams for better matching
        max_features=20000,  # Increased to accommodate synonym expansions
        min_df=1             # Lower threshold to catch less frequent terms
    )
    # Fit and transform to get TF-IDF vectors
    TFIDF_MATRIX = TFIDF_VECTORIZER.fit_transform(MOTION_TEXTS)
    print(f"TF-IDF matrix created with shape {TFIDF_MATRIX.shape}")
    # Also load the GPT label texts (vectorized per query later)
    initialize_gpt_tfidf()
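# Because TfidfVectorizer L2-normalizes its rows by default (norm='l2'),
# the sparse dot products used below are cosine similarities without any
# extra normalization step.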
def initialize_gpt_tfidf():
    """Load the GPT label texts; their TF-IDF vectors are computed on demand
    in compute_gpt_tfidf_similarity."""
    global GPT_TEXTS, GPT_KEYS
    print("Loading GPT label texts for TF-IDF...")
    GPT_TEXTS = []
    GPT_KEYS = []
    for key, text in GPT_LABELS_LIST.items():
        GPT_TEXTS.append(text.lower())
        GPT_KEYS.append(key)
def compute_tfidf_similarity(query, top_k=10):
    """Compute similarity using TF-IDF vectors with synonym expansion."""
    global TFIDF_VECTORIZER, TFIDF_MATRIX, MOTION_TEXTS, MOTION_KEYS
    # Normalized query for the cache key
    original_query = query.lower().strip()
    # Check the cache first
    cache_key = f"tfidf_{original_query}_{top_k}"
    if cache_key in SIMILARITY_CACHE:
        return SIMILARITY_CACHE[cache_key]
    try:
        # Expand the query with synonyms
        expanded_query = expand_query_with_synonyms(original_query)
        # Transform the query into TF-IDF space
        query_vector = TFIDF_VECTORIZER.transform([expanded_query])
        # Sparse matrix product; rows are L2-normalized, so this is cosine similarity
        similarities = (query_vector @ TFIDF_MATRIX.T).toarray().flatten()
        # Indices of the top_k highest similarity scores
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        # The corresponding entries and scores
        top_entries = [motion_dict[MOTION_KEYS[idx]] for idx in top_indices]
        top_scores = [similarities[idx] for idx in top_indices]
        result = (top_entries, top_scores)
    except Exception as e:
        print(f"Error in TF-IDF similarity computation: {e}")
        # Fall back to random motions if TF-IDF fails
        result = (get_random_motions(top_k), ['NA'] * top_k)
    SIMILARITY_CACHE[cache_key] = result
    return result
def compute_gpt_tfidf_similarity(query):
    """Compute similarity between the query and the GPT labels using TF-IDF
    with synonym expansion."""
    global TFIDF_VECTORIZER, GPT_TEXTS, GPT_KEYS
    # Normalized query for the cache key
    original_query = query.lower().strip()
    # Check the cache first
    cache_key = f"gpt_tfidf_{original_query}"
    if cache_key in GPT_SIMILARITY_CACHE:
        return GPT_SIMILARITY_CACHE[cache_key]
    try:
        # Expand the query with synonyms
        expanded_query = expand_query_with_synonyms(original_query)
        # Transform the query and all GPT texts into TF-IDF space
        # (note: the GPT texts are re-transformed on every cache miss)
        query_vector = TFIDF_VECTORIZER.transform([expanded_query])
        gpt_vectors = TFIDF_VECTORIZER.transform(GPT_TEXTS)
        # Cosine similarity between the query and all GPT texts
        similarities = (query_vector @ gpt_vectors.T).toarray().flatten()
        # Index of the highest similarity score
        best_idx = np.argmax(similarities)
        best_key = GPT_KEYS[best_idx]
        best_text = GPT_LABELS_LIST[best_key]
        best_sim = similarities[best_idx]
        result = (best_key, best_text, best_sim)
    except Exception as e:
        print(f"Error in GPT TF-IDF similarity computation: {e}")
        # Fall back to the first GPT label if the computation fails
        if GPT_KEYS:
            result = (GPT_KEYS[0], GPT_LABELS_LIST[GPT_KEYS[0]], 0.5)
        else:
            result = (None, None, 0)
    GPT_SIMILARITY_CACHE[cache_key] = result
    return result
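# A minimal sketch of a one-time precomputation for the GPT label vectors,
# an optional optimization that is not wired into the function above: it
# would avoid re-transforming GPT_TEXTS on every cache miss.
GPT_TFIDF_MATRIX = None

def get_gpt_tfidf_matrix():
    """Lazily build and memoize the TF-IDF matrix for the GPT label texts."""
    global GPT_TFIDF_MATRIX
    if GPT_TFIDF_MATRIX is None:
        GPT_TFIDF_MATRIX = TFIDF_VECTORIZER.transform(GPT_TEXTS)
    return GPT_TFIDF_MATRIX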
# Precompiled regex pattern for word tokenization
WORD_PATTERN = re.compile(r'\b\w+\b')

# Caches of tokenized annotations to avoid repeated tokenization
SOURCE_WORDS_CACHE = {}
TARGET_WORDS_CACHE = {}

def get_words(text):
    """Tokenize text and cache the result (all texts share SOURCE_WORDS_CACHE)."""
    if text in SOURCE_WORDS_CACHE:
        return SOURCE_WORDS_CACHE[text]
    words = set(WORD_PATTERN.findall(text.lower()))
    SOURCE_WORDS_CACHE[text] = words
    return words
def exact_string_search(action1, action2):
    """Search for exact substring matches first."""
    exact_results = []
    action1_lower = action1.lower().strip()
    action2_lower = action2.lower().strip()
    for k, v in motion_dict.items():
        source_lower = v["source_annot"].lower()
        target_lower = v["target_annot"].lower()
        # Check for exact matches in either annotation
        cond1 = action1_lower in source_lower or action1_lower in target_lower
        cond2 = action2_lower in source_lower or action2_lower in target_lower
        if cond1 and cond2:
            exact_results.append(v)
    return exact_results
def _synonym_match_fraction(action_words, source_words, target_words):
    """Fraction of action words found (directly, or via a synonym for words
    longer than 2 chars) in either annotation."""
    matches = 0
    for word in action_words:
        if len(word) <= 2:  # For short words, only check exact matches
            word_matches = word in source_words or word in target_words
        else:  # For longer words, check all synonyms
            word_matches = any(
                syn in source_words or syn in target_words
                for syn in get_synonyms(word)
            )
        if word_matches:
            matches += 1
    return matches / len(action_words)

def search_motions_two_actions(action1, action2):
    """Enhanced substring search with synonym expansion."""
    # Cache key for this query
    cache_key = f"{action1.lower().strip()}_{action2.lower().strip()}"
    # Return cached results if we already have them
    if cache_key in SEARCH_RESULTS_CACHE:
        return SEARCH_RESULTS_CACHE[cache_key]
    try:
        # Split the actions into word sets
        action1_words = set(action1.lower().strip().split())
        action2_words = set(action2.lower().strip().split())
        results = []
        for k, v in motion_dict.items():
            # Get or compute tokenized words from the caches
            if v["source_annot"] not in SOURCE_WORDS_CACHE:
                SOURCE_WORDS_CACHE[v["source_annot"]] = set(WORD_PATTERN.findall(v["source_annot"].lower()))
            if v["target_annot"] not in TARGET_WORDS_CACHE:
                TARGET_WORDS_CACHE[v["target_annot"]] = set(WORD_PATTERN.findall(v["target_annot"].lower()))
            source_words = SOURCE_WORDS_CACHE[v["source_annot"]]
            target_words = TARGET_WORDS_CACHE[v["target_annot"]]
            # Consider an action matched if at least 70% of its words
            # (or their synonyms) are found in either annotation
            cond1 = (not action1_words or
                     _synonym_match_fraction(action1_words, source_words, target_words) >= 0.7)
            cond2 = (not action2_words or
                     _synonym_match_fraction(action2_words, source_words, target_words) >= 0.7)
            if cond1 and cond2:
                results.append(v)
    except Exception as e:
        print(f"Error in substring search: {e}")
        results = []
    # Cache the results
    SEARCH_RESULTS_CACHE[cache_key] = results
    return results
def search_motions_semantic(action1, action2, top_k=10):
    """Semantic search using TF-IDF similarity with synonym expansion."""
    query_text = (action1.strip() + " " + action2.strip()).strip().lower()
    if not query_text:
        return [], []
    # compute_tfidf_similarity caches internally, so no extra caching is needed here
    return compute_tfidf_similarity(query_text, top_k)
def get_random_motions(n_motions):
    all_vals = list(motion_dict.values())
    return random.sample(all_vals, min(n_motions, len(all_vals)))
def search_gpt_semantic(action, top_k=1):
    """Search the GPT labels using TF-IDF similarity with synonym expansion.
    (top_k is currently unused; only the single best match is returned.)"""
    query_text = action.strip().lower()
    if not query_text:
        return None, None, None
    # Check the cache first
    if query_text in GPT_SEARCH_RESULTS_CACHE:
        return GPT_SEARCH_RESULTS_CACHE[query_text]
    # Use TF-IDF similarity for the GPT labels
    result = compute_gpt_tfidf_similarity(query_text)
    GPT_SEARCH_RESULTS_CACHE[query_text] = result
    return result
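# Example (illustrative): search_gpt_semantic("walk forward") returns a
# (key, label_text, similarity) tuple for the closest GPT label in
# gpt3-labels-list.json.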
def search_motions_combined(action1, action2, n_motions):
    """Combined search that prioritizes exact matches, then substring/synonym
    matches, then semantic matches, and finally random fill."""
    # Cache key for this query
    cache_key = f"{action1.lower().strip()}_{action2.lower().strip()}_{n_motions}"
    # Return cached results if we already have them
    if cache_key in SEARCH_RESULTS_CACHE:
        return SEARCH_RESULTS_CACHE[cache_key]
    # 1. First try exact string matches
    exact_results = exact_string_search(action1, action2)
    if len(exact_results) >= n_motions:
        # If we have enough exact matches, return them
        result = (random.sample(exact_results, n_motions), ['EXACT'] * n_motions)
        SEARCH_RESULTS_CACHE[cache_key] = result
        return result
    # 2. Not enough exact matches: try the enhanced substring search with synonyms
    string_results = search_motions_two_actions(action1, action2)
    # Filter out any results that are already in exact_results
    string_results = [r for r in string_results if r not in exact_results]
    # Combine exact_results with string_results
    combined_results = list(exact_results)
    combined_scores = ['EXACT'] * len(exact_results)
    if len(combined_results) + len(string_results) >= n_motions:
        # If we have enough combined results, use them
        needed = n_motions - len(combined_results)
        if needed > 0:
            combined_results.extend(random.sample(string_results, needed))
            combined_scores.extend(['SUBSTR'] * needed)
        result = (combined_results[:n_motions], combined_scores[:n_motions])
    else:
        # 3. Still not enough: add all substring matches, then use semantic search
        combined_results.extend(string_results)
        combined_scores.extend(['SUBSTR'] * len(string_results))
        # Use semantic search for the remaining motions
        needed = n_motions - len(combined_results)
        used_combo = {m["motion_combo"] for m in combined_results}
        if needed > 0:
            sem_list, sem_score_list = search_motions_semantic(action1, action2, top_k=2 * needed)
            # Filter out duplicates
            for item, score in zip(sem_list, sem_score_list):
                if item["motion_combo"] not in used_combo:
                    combined_results.append(item)
                    combined_scores.append(score)
                    used_combo.add(item["motion_combo"])
                    if len(combined_results) == n_motions:
                        break
        # Still short? Fill with random motions
        if len(combined_results) < n_motions:
            needed2 = n_motions - len(combined_results)
            rnd = get_random_motions(needed2)
            for r in rnd:
                if r["motion_combo"] not in used_combo:
                    combined_results.append(r)
                    combined_scores.append('RANDOM')
                    used_combo.add(r["motion_combo"])
                    if len(combined_results) == n_motions:
                        break
        result = (combined_results[:n_motions], combined_scores[:n_motions])
    # Cache the results
    SEARCH_RESULTS_CACHE[cache_key] = result
    return result
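# Example (illustrative): search_motions_combined("walk", "wave", 4) first
# collects exact substring hits (scored 'EXACT'), then synonym-based hits
# ('SUBSTR'), then TF-IDF matches (numeric cosine scores), and finally pads
# with random motions ('RANDOM') until 4 results are gathered.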
def safe_video_update(motion_data, semantic_score, visible=True):
    """Build the four component updates for one motion slot, including the match type."""
    # Prepare the annotation text based on the match type
    if semantic_score == 'EXACT':
        match_info = "Exact Match"
    elif semantic_score == 'SUBSTR':
        match_info = "Substring Match"
    elif semantic_score == 'RANDOM':
        match_info = "Random Result"
    else:
        # For semantic matches, round to 2 decimal places
        ssim = str(round(semantic_score, 2)) if semantic_score != 'NA' else ''
        match_info = f"Semantic Match (sim: {ssim})"
    actual_annot = f"{motion_data['annotation']} | {match_info}"
    return [
        gr.update(value=url, visible=visible)
        for url in (motion_data["motion_combo"],
                    motion_data["motion_a"],
                    motion_data["motion_b"])
    ] + [gr.update(value=actual_annot, visible=visible)]
def update_videos(motions, n_visible, semantic_scores):
    """Update the video components with motion data, preparing slots in parallel."""
    updates = []
    if not motions:
        # First slot: hide the videos and show the message in the text box
        # (each slot is [video_combo, video_a, video_b, text], so four updates)
        updates.extend([
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False),
            gr.update(value='incompatible combination', visible=True)
        ])
        # Hide the remaining 7 slots
        for _ in range(7):
            updates.extend([
                gr.update(value=None, visible=False),
                gr.update(value=None, visible=False),
                gr.update(value=None, visible=False),
                gr.update(value=None, visible=False)
            ])
    else:
        try:
            # Prepare all slot updates in parallel using a ThreadPoolExecutor
            with ThreadPoolExecutor(max_workers=min(8, n_visible)) as executor:
                # Submit all video update tasks
                future_updates = [
                    executor.submit(safe_video_update, motion, semantic_scores[jj], True)
                    for jj, motion in enumerate(motions[:n_visible])
                ]
                # Collect the updates in submission order
                for future in future_updates:
                    updates.extend(future.result())
            remaining = 8 - len(motions[:n_visible])
            for _ in range(remaining):
                updates.extend([
                    gr.update(value=None, visible=False),
                    gr.update(value=None, visible=False),
                    gr.update(value=None, visible=False),
                    gr.update(value=None, visible=False)
                ])
        except Exception as e:
            print(f"Error updating videos: {e}")
            # Sequential fallback if parallel processing fails
            updates = []
            for i in range(8):
                if i < len(motions[:n_visible]):
                    motion = motions[i]
                    score = semantic_scores[i]
                    # Handle the different score types
                    if score == 'EXACT':
                        match_info = "Exact Match"
                    elif score == 'SUBSTR':
                        match_info = "Substring Match"
                    elif score == 'RANDOM':
                        match_info = "Random Result"
                    else:
                        # For semantic matches, round to 2 decimal places
                        ssim = str(round(score, 2)) if score != 'NA' else ''
                        match_info = f"Semantic Match (sim: {ssim})"
                    actual_annot = f"{motion['annotation']} | {match_info}"
                    updates.extend([
                        gr.update(value=motion["motion_combo"], visible=True),
                        gr.update(value=motion["motion_a"], visible=True),
                        gr.update(value=motion["motion_b"], visible=True),
                        gr.update(value=actual_annot, visible=True)
                    ])
                else:
                    updates.extend([
                        gr.update(value=None, visible=False),
                        gr.update(value=None, visible=False),
                        gr.update(value=None, visible=False),
                        gr.update(value=None, visible=False)
                    ])
    return updates
def parse_gpt_labels(text):
    """Parse GPT labels from text."""
    if text.startswith("Answer: "):
        text = text[len("Answer: "):]  # Remove the "Answer: " prefix
    return text.split("\n")  # Split on newlines
def failure_update(message, n_motions=None):
    """Create UI updates for failure cases."""
    updates = []
    # First slot: hide the videos and show the message in the text box
    updates.append(gr.update(value=None, visible=False))    # video_combo for motion 1
    updates.append(gr.update(value=None, visible=False))    # video_a for motion 1
    updates.append(gr.update(value=None, visible=False))    # video_b for motion 1
    updates.append(gr.update(value=message, visible=True))  # annotation text for motion 1
    # Hide all components of the remaining 7 slots
    for _ in range(7):
        updates.extend([
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False)
        ])
    return updates
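# Each of the 8 motion slots maps to 4 output components (combo video,
# motion A video, motion B video, annotation text), so update_videos() and
# failure_update() must always return exactly 32 gr.update objects.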
def handle_interaction(action1, action2, n_motions):
    """Handle a user interaction, with caching for faster repeat queries."""
    # Cache key for the entire interaction
    cache_key = f"interaction_{action1.strip().lower()}_{action2.strip().lower()}_{n_motions}"
    # Return cached results if we already have them
    if cache_key in SEARCH_RESULTS_CACHE:
        return SEARCH_RESULTS_CACHE[cache_key]
    try:
        if not action1.strip() and not action2.strip():
            # Both inputs empty => show random motions
            motions = get_random_motions(n_motions)
            result = update_videos(motions, n_motions, ['NA'] * len(motions))
        else:
            # Resolve the GPT labels for both actions in parallel
            with ThreadPoolExecutor(max_workers=2) as executor:
                if action1 in GPT_LABELS_LIST:
                    future_act1 = executor.submit(lambda: parse_gpt_labels(GPT_LABELS_LIST[action1]))
                else:
                    future_act1 = executor.submit(search_gpt_semantic, action1, 1)
                if action2 in GPT_LABELS_LIST:
                    future_act2 = executor.submit(lambda: parse_gpt_labels(GPT_LABELS_LIST[action2]))
                else:
                    future_act2 = executor.submit(search_gpt_semantic, action2, 1)
                # Collect the results
                try:
                    if action1 in GPT_LABELS_LIST:
                        gpt_act1 = future_act1.result()
                    else:
                        best_key, best_text, best_sim = future_act1.result()
                        if not best_text:
                            result = failure_update("Action 1 not recognized.")
                            SEARCH_RESULTS_CACHE[cache_key] = result
                            return result
                        gpt_act1 = parse_gpt_labels(best_text)
                    if action2 in GPT_LABELS_LIST:
                        gpt_act2 = future_act2.result()
                    else:
                        best_key, best_text, best_sim = future_act2.result()
                        if not best_text:
                            result = failure_update("Action 2 not recognized.")
                            SEARCH_RESULTS_CACHE[cache_key] = result
                            return result
                        gpt_act2 = parse_gpt_labels(best_text)
                except Exception as e:
                    print(f"Error processing GPT labels: {e}")
                    result = failure_update("Error processing actions. Please try again.")
                    SEARCH_RESULTS_CACHE[cache_key] = result
                    return result
            # Check for conflicts: actions sharing a GPT label entry are treated as incompatible
            if bool(set(gpt_act1) & set(gpt_act2)):
                failure_message = "Incompatible action pair. Please select actions that are not conflicting."
                result = failure_update(failure_message)
            else:
                motions, sem_mot_scores = search_motions_combined(action1, action2, n_motions)
                result = update_videos(motions, n_motions, sem_mot_scores)
    except Exception as e:
        print(f"Error in handle_interaction: {e}")
        result = failure_update("An error occurred. Please try again.")
    # Cache the result
    SEARCH_RESULTS_CACHE[cache_key] = result
    return result
# Custom CSS
CUSTOM_CSS = """
button.compact-button {
    width: auto !important;       /* Let the button shrink to fit the text */
    min-width: unset !important;  /* Remove any forced min-width */
    padding: 4px 8px !important;
    font-size: 20px !important;
    line-height: 1 !important;
}
"""
# Build the Gradio UI
with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.HTML(WEBSITE)
    with gr.Tabs():
        with gr.Tab("SINC-Synth exploration"):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        with gr.Column():
                            action1_textbox = gr.Textbox(
                                label="Action 1",
                                placeholder="Select an action or type the first action, e.g. 'walk'",
                            )
                            create_example_buttons(action1_textbox, ACTION_EXAMPLES[:5])
                        with gr.Column():
                            action2_textbox = gr.Textbox(
                                label="Action 2",
                                placeholder="Select an action or type the second action, e.g. 'wave'"
                            )
                            create_example_buttons(action2_textbox, ACTION_EXAMPLES[5:])
                with gr.Column():
                    n_motions_radio = gr.Radio(
                        choices=[2, 4, 6, 8],
                        label="Number of motions to be shown from the SINC-Synthetic data",
                        value=2,
                        show_label=True,
                        container=True,
                    )
                    with gr.Row():
                        search_button = gr.Button("Search",
                                                  elem_classes=["compact-button"])
                        random_button = gr.Button("Random",
                                                  elem_classes=["compact-button"])
            # Up to 8 motion slots, laid out 2 per row
            motion_components = []
            videos_per_row = 2
            max_motions = 8
            num_rows = (max_motions + videos_per_row - 1) // videos_per_row  # Ceiling division
            for i in range(num_rows):
                with gr.Row():
                    for j in range(videos_per_row):
                        motion_index = i * videos_per_row + j
                        if motion_index >= max_motions:
                            break
                        with gr.Column():
                            video_combo = gr.Video(
                                label=f"Motion {motion_index + 1}",
                                visible=False,
                                width=480,
                                height=384,
                                loop=True
                            )
                            with gr.Row():
                                video_a = gr.Video(
                                    label="Motion A",
                                    visible=False,
                                    width=320,
                                    height=256,
                                    loop=True
                                )
                                video_b = gr.Video(
                                    label="Motion B",
                                    visible=False,
                                    width=320,
                                    height=256,
                                    loop=True
                                )
                            text = gr.Textbox(
                                visible=False,
                                interactive=False
                            )
                            motion_components.extend([video_combo, video_a, video_b, text])
            search_button.click(
                fn=handle_interaction,
                inputs=[action1_textbox, action2_textbox, n_motions_radio],
                outputs=motion_components
            )
            random_button.click(
                fn=lambda n: handle_interaction("", "", n),
                inputs=[n_motions_radio],
                outputs=motion_components
            )
| gr.HTML((""" | |
| <div style='text-align: center; margin-top: 20px; font-size: 16px;'> | |
| <p><sup>**</sup>Our data in the official paper are using on the fly compositions, | |
| which means than are not computed and filtered offline. This is a minimally | |
| processed version of ~124k motions ranging between 3-7 seconds.</p> | |
| <p>Made with ❤️ by Nikos Athanasiou</p> | |
| </div> | |
| """) | |
| ) | |
| with gr.Tab("Simultaneous Motion Generation with SINC model"): | |
| gr.HTML("<h2>Motion Generation from Text [TBD. Currenly under construction.]</h2>") | |
| with gr.Row(): | |
| text_input_gen = gr.Textbox( | |
| label="Motion Description", | |
| placeholder="Describe the motion, e.g. 'A person walking forward while waving'" | |
| ) | |
| create_example_buttons(text_input_gen, ACTION_EXAMPLES_SIMULTANEOUS) | |
| generate_button = gr.Button("Generate Motion", | |
| elem_classes=["compact-button"]) | |
| with gr.Row(): | |
| output_video = gr.Video( | |
| label="Generated Motion", | |
| visible=True, | |
| width=320, | |
| height=180 | |
| ) | |
| def generate_motion(text): | |
| # Placeholder function - replace with actual model inference | |
| # Return None instead of a string path to avoid schema conversion issues | |
| return None | |
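            # A hypothetical sketch of wiring in real inference once a
            # checkpoint is available. `load_sinc_model` and `.generate` are
            # placeholder names for illustration, not the actual SINC API:
            #
            # SINC_MODEL = None
            #
            # def generate_motion(text):
            #     global SINC_MODEL
            #     if SINC_MODEL is None:
            #         SINC_MODEL = load_sinc_model("path/to/checkpoint.ckpt")  # hypothetical loader
            #     return SINC_MODEL.generate(text)  # hypothetical: path to a rendered mp4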
            generate_button.click(
                fn=generate_motion,
                inputs=[text_input_gen],
                outputs=[output_video]
            )
# Initialize TF-IDF at startup
initialize_tfidf()

# Precompute synonyms for common action words
print("Precomputing synonyms for common action words...")
for action in ACTION_SYNONYMS:
    get_synonyms(action)
# Cache warming ("prefetching")
def prefetch_videos():
    """Warm up the search caches for a few common queries. Note that this only
    exercises the search path and touches the video URLs; the video files
    themselves are fetched by the browser when displayed."""
    print("Prefetching common videos...")
    try:
        # A small set of motions and queries to warm up
        random_motions = get_random_motions(4)
        common_actions = [("walk", "wave"), ("sit", "bow"), ("jump", "throw")]
        with ThreadPoolExecutor(max_workers=8) as executor:
            futures = []
            # Touch the URLs of the random motions
            for motion in random_motions:
                futures.append(executor.submit(
                    lambda m: (m["motion_combo"], m["motion_a"], m["motion_b"]),
                    motion
                ))
            # Run the common action combinations through the (caching) search
            for act1, act2 in common_actions:
                motions, _ = search_motions_combined(act1, act2, 2)
                if motions:
                    for motion in motions:
                        futures.append(executor.submit(
                            lambda m: (m["motion_combo"], m["motion_a"], m["motion_b"]),
                            motion
                        ))
            # Wait for all tasks to complete
            for future in futures:
                future.result()
        print("Video prefetching complete")
    except Exception as e:
        print(f"Error in video prefetching: {e}")

# Run the cache warming in a separate thread so startup is not blocked
threading.Thread(target=prefetch_videos).start()
# Print ready message
print("Demo ready! Optimized code running with exact matching prioritized over synonym-enhanced TF-IDF similarity.")

# Launch the demo
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)