Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| import random | |
| import numpy as np | |
| import pandas as pd | |
| import gdown | |
| import base64 | |
| from time import gmtime, strftime | |
| from csv import writer | |
| import json | |
| from datasets import load_dataset | |
| from hfserver import HuggingFaceDatasetSaver, HuggingFaceDatasetJSONSaver | |
# Task (environment) names offered in the UI's "Choose the next task" selector.
ENVS = [
    'ShadowHand',
    'ShadowHandCatchAbreast',
    'ShadowHandOver',
    'ShadowHandBlockStack',
    'ShadowHandCatchUnderarm',
    'ShadowHandCatchOver2Underarm',
    'ShadowHandBottleCap',
    'ShadowHandLiftUnderarm',
    'ShadowHandTwoCatchUnderarm',
    'ShadowHandDoorOpenInward',
    'ShadowHandDoorOpenOutward',
    'ShadowHandDoorCloseInward',
    'ShadowHandDoorCloseOutward',
    'ShadowHandPushBlock',
    'ShadowHandKettle',
    'ShadowHandScissors',
    'ShadowHandPen',
    'ShadowHandSwingCup',
    'ShadowHandGraspAndPlace',
    'ShadowHandSwitch',
]
# Video data source.
# When LOAD_DATA_GOOGLE_DRIVE is true, the zipped clips are fetched from Google
# Drive with the `gdown` command-line tool and unpacked into VIDEO_PATH;
# otherwise a local 'robotinder-data' folder is used as-is.
# (Earlier variants of this script loaded the "quantumiracle-git/robotinder-data"
# dataset from the Hugging Face Hub, or used the 'processed_zip' /
# 'filter_processed_zip' Drive folders instead of 'split_processed_zip'.)
LOAD_DATA_GOOGLE_DRIVE = True
if LOAD_DATA_GOOGLE_DRIVE:  # download data from google drive
    import subprocess
    import zipfile
    from os import listdir
    from os.path import isfile, join, isdir

    # The data is split over multiple Drive folders to cope with retrieval
    # errors.  Folder number k (1-based) is expected to be downloaded into a
    # local directory named str(k).
    urls = [
        'https://drive.google.com/drive/folders/1SF5jQ7HakO3lFXBon57VP83-AwfnrM3F?usp=share_link',  # './split_processed_zip/1' folder in google drive
        'https://drive.google.com/drive/folders/13WuS6ow6sm7ws7A5xzCEhR-2XX_YiIu5?usp=share_link',  # './split_processed_zip/2' folder in google drive
        'https://drive.google.com/drive/folders/1GWLffJDOyLkubF2C03UFcB7iFpzy1aDy?usp=share_link',  # './split_processed_zip/3' folder in google drive
        'https://drive.google.com/drive/folders/1UKAntA7WliD84AUhRN224PkW4vt9agZW?usp=share_link',  # './split_processed_zip/4' folder in google drive
        'https://drive.google.com/drive/folders/11cCQw3qb1vJbviVPfBnOVWVzD_VzHdWs?usp=share_link',  # './split_processed_zip/5' folder in google drive
        'https://drive.google.com/drive/folders/1Wvy604wCxEdXAwE7r3sE0L0ieXvM__u8?usp=share_link',
        'https://drive.google.com/drive/folders/1BTv_pMTNGm7m3hD65IgBrX880v-rLIaf?usp=share_link',
        'https://drive.google.com/drive/folders/12x7F11ln2VQkqi8-Mu3kng74eLgifM0N?usp=share_link',
        'https://drive.google.com/drive/folders/1OWkOul2CCrqynqpt44Fu1CBxzNNfOFE2?usp=share_link',
        'https://drive.google.com/drive/folders/1ukwsfrbSEqCBNmRSuAYvYBHijWCQh2OU?usp=share_link',
        'https://drive.google.com/drive/folders/1EO7zumR6sVfsWQWCS6zfNs5WuO2Se6WX?usp=share_link',
        'https://drive.google.com/drive/folders/1aw0iBWvvZiSKng0ejRK8xbNoHLVUFCFu?usp=share_link',
        'https://drive.google.com/drive/folders/1szIcxlVyT5WJtzpqYWYlue0n82A6-xtk?usp=share_link',
    ]
    output = './'
    VIDEO_PATH = 'split_processed_zip'
    for i, url in enumerate(urls):
        # Keep only the folder id: the last path segment without the
        # '?usp=share_link' query string.
        folder_id = url.rstrip('/').split('/')[-1].split('?')[0]
        # Argument list instead of a shell string: nothing in the id is ever
        # interpreted by a shell.
        try:
            subprocess.run(
                ['gdown', '--id', folder_id, '-O', output,
                 '--folder', '--no-cookies', '--remaining-ok'],
                check=False,
            )
        except OSError as err:  # e.g. the gdown executable is not on PATH
            print(f'could not run gdown for folder {folder_id}: {err}')

        # unzip the zip files into VIDEO_PATH and delete the zip files
        path_to_zip_file = str(i + 1)
        if not isdir(path_to_zip_file):
            # A failed download must not take the remaining folders down with it.
            print(f'download folder {path_to_zip_file} is missing, skip it')
            continue
        for zip_name in listdir(path_to_zip_file):
            if not zip_name.endswith(".zip"):
                continue
            f = join(path_to_zip_file, zip_name)
            directory_to_extract_to = VIDEO_PATH  # extracted file itself contains a folder
            print(f'extract data {f} to {directory_to_extract_to}')
            with zipfile.ZipFile(f, 'r') as zip_ref:
                zip_ref.extractall(directory_to_extract_to)
            os.remove(f)
else:  # local data
    VIDEO_PATH = 'robotinder-data'

# JSON file (inside the data folder) that caches the env -> video paths mapping.
VIDEO_INFO = os.path.join(VIDEO_PATH, 'video_info.json')
def inference(video_path):
    """Build an HTML ``<video>`` snippet that embeds the file at ``video_path``.

    The whole file is read and inlined as a base64 ``data:video/mp4`` URI so
    that the clip can be displayed with autoplay (muted, looping) on Gradio.

    Args:
        video_path: path of the video file to embed.

    Returns:
        The HTML markup as a string.
    """
    with open(video_path, "rb") as video_file:
        encoded = base64.b64encode(video_file.read()).decode()
    return (
        "\n"
        "<video controls autoplay muted loop>\n"
        f'<source src="data:video/mp4;base64,{encoded}" type="video/mp4">\n'
        "</video>\n"
    )
def video_identity(video):
    """Identity function: hand back ``video`` exactly as it was received."""
    unchanged = video
    return unchanged
def nan():
    """Take no arguments and return ``None``."""
    return
# File extension of the clips to use: the first candidate ('mp4') is selected,
# 'gif' is the other known option.
FORMAT = ('mp4', 'gif')[0]
def get_huggingface_dataset():
    """Prepare the Hugging Face dataset repository that stores the user votes.

    The dataset repo 'crowdsourced-robotinder-demo' is created if needed,
    cloned into 'crowdsourced-robotinder-demo/flag/' and pulled so the local
    copy is in sync with the remote.

    The access token is read from the ``HF_TOKEN`` environment variable; a
    credential must never be hard-coded in the source file.

    Returns:
        A ``(repo, log_file)`` tuple: the ``huggingface_hub.Repository``
        object and the path of ``flag_data.csv`` inside its working copy.

    Raises:
        RuntimeError: if the ``HF_TOKEN`` environment variable is not set.
        ImportError: if the ``huggingface_hub`` package is not installed.
    """
    hf_token = os.getenv('HF_TOKEN')
    if not hf_token:
        raise RuntimeError(
            "Environment variable HF_TOKEN is not set; it must hold a "
            "Hugging Face access token with write permission."
        )

    try:
        import huggingface_hub
    except ImportError as err:  # ModuleNotFoundError is a subclass of ImportError
        raise ImportError(
            "Package `huggingface_hub` not found is needed "
            "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
        ) from err

    DATASET_NAME = 'crowdsourced-robotinder-demo'
    FLAGGING_DIR = 'flag/'
    path_to_dataset_repo = huggingface_hub.create_repo(
        repo_id=DATASET_NAME,
        token=hf_token,
        private=False,
        repo_type="dataset",
        exist_ok=True,
    )
    dataset_dir = os.path.join(DATASET_NAME, FLAGGING_DIR)
    repo = huggingface_hub.Repository(
        local_dir=dataset_dir,
        clone_from=path_to_dataset_repo,
        use_auth_token=hf_token,
    )
    repo.git_pull(lfs=True)  # bring the local clone up to date before appending
    log_file = os.path.join(dataset_dir, "flag_data.csv")
    return repo, log_file
def update(user_choice, left, right, choose_env, data_folder=VIDEO_PATH, flag_to_huggingface=True):
    """Log the user's vote for the displayed pair and pick the next pair.

    Args:
        user_choice: value of the "Which one is your favorite?" radio.
        left: current value of the left component (not used here).
        right: current value of the right component (not used here).
        choose_env: task requested for the next pair; 'Random', an empty
            string or no selection (``None``) picks a random task.
        data_folder: kept for interface compatibility (not used here).
        flag_to_huggingface: when true, append the vote to the dataset CSV
            and push it to the Hugging Face Hub on certain minutes.

    Returns:
        ``(left_html, right_html, env_name)`` for the newly selected pair.
    """
    global last_left_video_path
    global last_right_video_path
    global last_infer_left_video_path
    global last_infer_right_video_path

    if flag_to_huggingface:  # log
        # Clips are stored as '<data folder>/<ENV_NAME>/<file>', so the env name
        # is the parent directory; this does not depend on the path separator
        # or on how deep the data folder is.
        env_name = os.path.basename(os.path.dirname(str(last_left_video_path)))
        now = gmtime()
        current_time = strftime("%Y-%m-%d-%H-%M-%S", now)
        info = [env_name, user_choice, last_left_video_path, last_right_video_path, current_time]
        print(info)
        repo, log_file = get_huggingface_dataset()
        # newline='' is what the csv module requires to avoid blank rows on
        # some platforms; the `with` block closes the file.
        with open(log_file, 'a', newline='', encoding='utf-8') as file:  # incremental change of the file
            writer(file).writerow(info)
        if now.tm_min % 5 == 0:  # push only on certain minutes
            commit_message = f"Flagged sample at {current_time}"
            try:
                repo.push_to_hub(commit_message=commit_message)
            except Exception:
                repo.git_pull(lfs=True)  # sync with remote first
                repo.push_to_hub(commit_message=commit_message)

    # A radio without selection delivers a falsy value (None / ''), which is
    # treated like 'Random'.
    if not choose_env or choose_env == 'Random':
        env_name = random.choice(get_env_names())
    else:
        env_name = choose_env

    # choose video
    next_left, next_right = randomly_select_videos(env_name)
    last_left_video_path = next_left
    last_right_video_path = next_right
    last_infer_left_video_path = inference(next_left)
    last_infer_right_video_path = inference(next_right)
    return last_infer_left_video_path, last_infer_right_video_path, env_name
def replay(left, right):
    """Return the two arguments unchanged, as the tuple ``(left, right)``."""
    pair = (left, right)
    return pair
def _checkpoint_range(succeed_iteration, max_iter, default_iter):
    """Turn one 'succeed_iteration' CSV cell into the range of valid checkpoints.

    Returns ``None`` when the cell contains 'unsolved'; otherwise a ``range``
    from the minimum to the maximum iteration (inclusive) in steps of 1000.
    An empty cell means ``default_iter`` .. ``max_iter``, a 'MIN-MAX' cell
    gives both bounds, and a single value is the minimum with ``max_iter``
    as maximum.
    """
    # The null test must come first: an empty cell is a float NaN and does not
    # support the `in` operator.
    if pd.isnull(succeed_iteration):
        lower, upper = default_iter, max_iter
    else:
        # str() also covers cells that pandas parsed as numbers.
        text = str(succeed_iteration).strip()
        if 'unsolved' in text:
            return None
        if '-' in text:
            lower, upper = text.split('-')
        else:
            lower, upper = text, max_iter
    # float() first so that values such as '5000.0' are accepted as well.
    lower, upper = int(float(lower)), int(float(upper))
    return range(lower, upper + 1000, 1000)


def parse_envs(folder=VIDEO_PATH, filter=True, MAX_ITER=20000, DEFAULT_ITER=20000):
    """
    return a dict of env_name: video_paths

    Every sub-directory of ``folder`` is treated as one env; its files ending
    in ``.{FORMAT}`` are the candidate videos.  The resulting dict is also
    written to the VIDEO_INFO JSON file.

    Args:
        folder: data folder holding one sub-directory per env.
        filter: when true, keep only videos whose checkpoint lies in the valid
            range listed for (env, seed) in 'Bidexhands_Video.csv'.  Videos
            that are wrongly named, have no entry in the CSV, or belong to an
            'unsolved' entry are left out.
        MAX_ITER: upper checkpoint bound when the CSV gives none.
        DEFAULT_ITER: lower checkpoint bound when the CSV cell is empty.

    Returns:
        dict mapping env name to the list of selected video paths.
    """
    files = {}
    df = pd.read_csv('Bidexhands_Video.csv') if filter else None

    for env_name in os.listdir(folder):
        env_path = os.path.join(folder, env_name)
        if not os.path.isdir(env_path):
            continue
        video_files = []
        for video in os.listdir(env_path):
            if not video.endswith(f'.{FORMAT}'):
                continue
            if filter:
                # video name rule: EnvName_Alg_Seed_Timestamp_Checkpoint_video-episode-EpisodeID
                parts = video.split('_')
                try:
                    if len(parts) < 6:
                        raise ValueError(video)
                    seed, checkpoint = int(parts[2]), int(parts[4])
                except ValueError:
                    print(f'{video} is wrongly named.')
                    continue
                try:
                    rows = df.loc[(df['seed'] == seed) & (df['env_name'] == str(env_name))]
                    succeed_iteration = rows['succeed_iteration'].iloc[0]
                except IndexError:
                    # No matching row: the video cannot be validated, so skip it
                    # instead of reusing the value of a previous video.
                    print(f'Env {env_name} with seed {parts[2]} not found in Bidexhands_Video.csv')
                    continue
                # check if the checkpoint is in the valid range
                valid_checkpoints = _checkpoint_range(succeed_iteration, MAX_ITER, DEFAULT_ITER)
                if valid_checkpoints is None or checkpoint not in valid_checkpoints:
                    continue
            video_files.append(os.path.join(folder, env_name, video))
        files[env_name] = video_files

    with open(VIDEO_INFO, 'w') as fp:
        json.dump(files, fp)
    return files
def get_env_names():
    """Return the env names (top-level keys) stored in the VIDEO_INFO JSON file."""
    with open(VIDEO_INFO, 'r') as info_file:
        video_info = json.load(info_file)
    env_names = list(video_info)
    return env_names
def randomly_select_videos(env_name):
    """Pick two distinct videos of ``env_name`` at random.

    The candidate paths are read from the VIDEO_INFO JSON file.

    Args:
        env_name: key of the env in the parsed video info.

    Returns:
        ``(left_video_path, right_video_path)``, two different entries.

    Raises:
        KeyError: if ``env_name`` is not present in the video info.
        ValueError: if the env has fewer than two videos, so that no pair
            can be formed.
    """
    # load the parsed video info
    with open(VIDEO_INFO, 'r') as fp:
        env_files = json.load(fp)[env_name]

    if len(env_files) < 2:
        raise ValueError(
            f'Env {env_name} has {len(env_files)} video(s); '
            'at least 2 are needed to form a pair.'
        )

    # randomly choose two videos (without replacement, so they differ)
    left_id, right_id = np.random.choice(len(env_files), 2, replace=False)
    return env_files[left_id], env_files[right_id]
def build_interface(iter=3, data_folder=VIDEO_PATH):
    """Parse the video folder and build the Gradio Blocks app.

    The app shows two videos of one env, a radio for the preferred side, a
    radio for the next task and a "Next" button wired to ``update``.  The
    paths of the pair on display are kept in module-level globals
    (``last_left_video_path`` etc.) so that ``update`` can log them.

    Args:
        iter: kept for interface compatibility (not used here).
        data_folder: kept for interface compatibility (not used here).

    Returns:
        The ``gr.Blocks`` object; the caller is responsible for launching it.

    Raises:
        RuntimeError: if the ``HF_TOKEN`` environment variable is not set, or
            if no env has at least two videos.
    """
    import sys
    import csv
    csv.field_size_limit(sys.maxsize)

    # The token comes from the environment only.  It is a secret, so it is
    # neither printed nor hard-coded here.
    HF_TOKEN = os.getenv('HF_TOKEN')
    if not HF_TOKEN:
        raise RuntimeError(
            "Environment variable HF_TOKEN is not set; it must hold a "
            "Hugging Face access token with write permission."
        )
    # HuggingFace logger instead of a local one
    hf_writer = HuggingFaceDatasetSaver(HF_TOKEN, "crowdsourced-robotinder-demo")
    callback = hf_writer

    # parse the video folder
    files = parse_envs()
    # A pair can only be drawn from an env that has at least two videos.
    playable_envs = [name for name, videos in files.items() if len(videos) >= 2]
    if not playable_envs:
        raise RuntimeError('No env with at least two videos was found.')

    # build gradio interface
    with gr.Blocks() as demo:
        gr.Markdown("## Here is <span style=color:cyan>RoboTinder</span>!")
        gr.Markdown("### Select the best robot behaviour in your choice!")

        # some initial values
        env_name = random.choice(playable_envs)  # random pick an env
        with gr.Row():
            str_env_name = gr.Markdown(f"{env_name}")

        # choose video
        left_video_path, right_video_path = randomly_select_videos(env_name)
        # Defined up front so the globals below can be assigned for every FORMAT.
        infer_left_video_path = None
        infer_right_video_path = None

        with gr.Row():
            if FORMAT == 'mp4':
                infer_left_video_path = inference(left_video_path)
                infer_right_video_path = inference(right_video_path)
                # NOTE(review): `right` is created before `left` here, unlike the
                # branch below; if the row lays components out in creation order
                # the "right" video is shown first.  Order kept as in the
                # original - confirm that this is intended.
                right = gr.HTML(infer_right_video_path, label="right_video")
                left = gr.HTML(infer_left_video_path, label="left_video")
            else:
                left = gr.Image(left_video_path, shape=(1024, 768), label="left_video")
                right = gr.Image(right_video_path, label="right_video")

        # remember the pair on display for the next call of update()
        global last_left_video_path
        last_left_video_path = left_video_path
        global last_right_video_path
        last_right_video_path = right_video_path
        global last_infer_left_video_path
        last_infer_left_video_path = infer_left_video_path
        global last_infer_right_video_path
        last_infer_right_video_path = infer_right_video_path

        user_choice = gr.Radio(["Left", "Right", "Not Sure"], label="Which one is your favorite?")
        choose_env = gr.Radio(["Random"] + ENVS, label="Choose the next task:")
        btn2 = gr.Button("Next")

        # This needs to be called at some point prior to the first call to callback.flag()
        callback.setup([user_choice, left, right], "flagged_data_points")

        # The vote is logged inside update(); gradio flagging is not used for it.
        btn2.click(fn=update, inputs=[user_choice, left, right, choose_env], outputs=[left, right, str_env_name])

    return demo
if __name__ == "__main__":
    # Paths of the pair currently on display; filled in by build_interface()
    # and read/updated by update().
    last_left_video_path = last_right_video_path = None

    demo = build_interface()
    # Launch without a share link (share=True is the alternative).
    demo.launch(share=False)