import os
import shutil
import sys

import torch
import yaml
from diffusers import DiffusionPipeline
from huggingface_hub import HfApi, hf_hub_download

# from .tools import build_dataset_json_from_list


class MOSDiffusionPipeline(DiffusionPipeline):

    def __init__(self, reload_from_ckpt="./qa_mdt/checkpoint_389999.ckpt", base_folder=None):
        """
        Initialize the MOS Diffusion pipeline and download the necessary files/folders.

        Args:
            reload_from_ckpt (str, optional): Checkpoint path to reload from.
            base_folder (str, optional): Base folder to store downloaded files.
                Defaults to the current working directory.
        """
        super().__init__()
        self.base_folder = base_folder if base_folder else os.getcwd()
        self.repo_id = "jadechoghari/qa-mdt"
        self.config_yaml = "./qa_mdt/audioldm_train/config/mos_as_token/qa_mdt.yaml"
        self.reload_from_ckpt = reload_from_ckpt

        # Fetch any missing assets from the Hub before the config is read.
        self.download_required_folders()

        # Load the YAML configuration and point it at the requested checkpoint.
        self.configs = self.load_yaml(self.config_yaml)
        self.configs["reload_from_ckpt"] = self.reload_from_ckpt

        # Derive experiment names from the config path (e.g. "qa_mdt" / "mos_as_token").
        self.exp_name = os.path.splitext(os.path.basename(self.config_yaml))[0]
        self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))

    def download_required_folders(self):
        """
        Download the necessary folders from the Hugging Face Hub if they are not
        already available locally. (A snapshot_download-based alternative is
        sketched after the class definition.)
        """
        api = HfApi()
        files = api.list_repo_files(repo_id=self.repo_id)

        # Only fetch the folders the pipeline actually needs at inference time.
        required_folders = ["audioldm_train", "checkpoints", "infer", "log", "taming", "test_prompts"]
        files_to_download = [f for f in files if any(f.startswith(folder) for folder in required_folders)]

        for file in files_to_download:
            local_file_path = os.path.join(self.base_folder, file)
            if not os.path.exists(local_file_path):
                downloaded_file = hf_hub_download(repo_id=self.repo_id, filename=file)
                os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
                # Copy (rather than rename) so cached symlinks from the Hub cache
                # are dereferenced instead of being moved and broken.
                shutil.copy(downloaded_file, local_file_path)

        # Make the downloaded packages importable (e.g. the `infer` module).
        sys.path.append(self.base_folder)

    def load_yaml(self, yaml_path):
        """
        Helper method to load the YAML configuration.
        """
        with open(yaml_path, "r") as f:
            return yaml.safe_load(f)

    def __call__(self, prompt: str):
        """
        Run the MOS Diffusion pipeline. This calls the infer function from
        infer_mos5.py with the given text prompt.
        """
        from .infer.infer_mos5 import infer

        dataset_key = self.build_dataset_json_from_prompt(prompt)

        # Run inference with the prompt, the loaded configs and the experiment settings.
        return infer(
            dataset_key=dataset_key,
            configs=self.configs,
            config_yaml_path=self.config_yaml,
            exp_group_name="qa_mdt",
            exp_name="mos_as_token",
        )

    def build_dataset_json_from_prompt(self, prompt: str):
        """
        Build the dataset_key dict dynamically from the provided prompt.
        """
        # No wav file is supplied; the caption (prompt) alone drives generation.
        data = [{"wav": "", "caption": prompt}]
        return {"data": data}
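

# A possible simplification of download_required_folders(), sketched here as an
# assumption rather than as part of the original pipeline: huggingface_hub's
# snapshot_download can fetch the same folders in a single call via allow_patterns,
# and local_dir materializes them under base_folder with the repo's layout.
from huggingface_hub import snapshot_download

def prefetch_qa_mdt_assets(base_folder: str, repo_id: str = "jadechoghari/qa-mdt") -> str:
    """Hypothetical helper (not in the original repo) that prefetches the same folders."""
    required_folders = ["audioldm_train", "checkpoints", "infer", "log", "taming", "test_prompts"]
    local_dir = snapshot_download(
        repo_id=repo_id,
        allow_patterns=[f"{folder}/*" for folder in required_folders],
        local_dir=base_folder,
    )
    # Mirror the sys.path handling of download_required_folders() above.
    sys.path.append(local_dir)
    return local_dir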


# Example of how to use the pipeline
if __name__ == "__main__":
    pipe = MOSDiffusionPipeline()
    result = pipe("A modern synthesizer creating futuristic soundscapes.")
    print(result)
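
    # For reference, the dataset_key handed to infer() for the prompt above is:
    #   {"data": [{"wav": "", "caption": "A modern synthesizer creating futuristic soundscapes."}]}
    # A second, hedged usage sketch with explicit arguments (the base_folder value
    # below is hypothetical; the checkpoint path is the constructor default):
    #
    # pipe = MOSDiffusionPipeline(
    #     reload_from_ckpt="./qa_mdt/checkpoint_389999.ckpt",
    #     base_folder="/data/qa_mdt_assets",
    # )
    # pipe("A modern synthesizer creating futuristic soundscapes.")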