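# Inference wrapper around the JAM song-generation model (declare-lab/jam-0.5).
# The Jamify class downloads the checkpoint and auxiliary files from the Hugging Face Hub,
# builds the CFM/DiT generator, the Stable Audio Open VAE, and the MuQ-MuLan style encoder,
# and exposes a predict() method that turns timed lyrics plus a style reference into an MP3.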
import torch
import torchaudio
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download
import numpy as np
import json
import os
from safetensors.torch import load_file

# Imports from the jamify library
from jam.model.cfm import CFM
from jam.model.dit import DiT
from jam.model.vae import StableAudioOpenVAE
from jam.dataset import DiffusionWebDataset, enhance_webdataset_config
from muq import MuQMuLan

# Helper functions adapted from jamify/src/jam/infer.py
def get_negative_style_prompt(device, file_path):
    vocal_style = np.load(file_path)
    vocal_style = torch.from_numpy(vocal_style).to(device)
    return vocal_style.half()


def normalize_audio(audio):
    audio = audio - audio.mean(-1, keepdim=True)
    audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
    return audio

class Jamify:
    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.device = torch.device(device)

        # Download the checkpoint, config, and auxiliary files from the Hugging Face Hub
        print("Downloading main model checkpoint...")
        model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
        self.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")
        config_path = os.path.join(model_repo_path, "jam_infer.yaml")
        self.negative_style_prompt_path = os.path.join(model_repo_path, "vocal.npy")
        tokenizer_path = os.path.join(model_repo_path, "en_us_cmudict_ipa_forward.pt")
        silence_latent_path = os.path.join(model_repo_path, "silience_latent.pt")
| print("Loading configuration...") | |
| self.config = OmegaConf.load(config_path) | |
| self.config.data.train_dataset.silence_latent_path = silence_latent_path | |
| # --- FIX: Override the relative paths in the config with absolute paths --- | |
| self.config.data.train_dataset.tokenizer_path = tokenizer_path | |
| self.config.evaluation.dataset.tokenizer_path = tokenizer_path | |
| self.config.data.train_dataset.phonemizer_checkpoint = tokenizer_path | |
| print("Loading VAE model...") | |
| self.vae = StableAudioOpenVAE().to(self.device).eval() | |
| print("Loading CFM model...") | |
| self.cfm_model = self._load_cfm_model(self.config.model, self.checkpoint_path) | |
| print("Loading MuQ style model...") | |
| self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(self.device).eval() | |
| print("Setting up dataset processor...") | |
| dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset) | |
| enhance_webdataset_config(dataset_cfg) | |
| dataset_cfg.multiple_styles = False | |
| self.dataset_processor = DiffusionWebDataset(**dataset_cfg) | |
| print("Jamify model loaded successfully.") | |
    def _load_cfm_model(self, model_config, checkpoint_path):
        dit_config = model_config["dit"].copy()
        if "text_num_embeds" not in dit_config:
            dit_config["text_num_embeds"] = 256
        model = CFM(
            transformer=DiT(**dit_config),
            **model_config["cfm"]
        ).to(self.device)
        state_dict = load_file(checkpoint_path)
        model.load_state_dict(state_dict, strict=False)
        return model.eval()
    def _generate_style_embedding_from_audio(self, audio_path):
        # Resample to 24 kHz, downmix to mono, and truncate to 30 seconds
        # before computing the MuQ-MuLan style embedding
        waveform, sample_rate = torchaudio.load(audio_path)
        if sample_rate != 24000:
            resampler = torchaudio.transforms.Resample(sample_rate, 24000)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform.squeeze(0).to(self.device)
        with torch.inference_mode():
            style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., :24000 * 30])
        return style_embedding[0]
    def _generate_style_embedding_from_prompt(self, prompt):
        with torch.inference_mode():
            style_embedding = self.muq_model(texts=[prompt]).squeeze(0)
        return style_embedding
    def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration_sec=30, steps=50):
        print("Starting prediction...")
        # Build the style embedding from a reference track, a text prompt, or fall back to zeros
        if reference_audio_path:
            print(f"Generating style from audio: {reference_audio_path}")
            style_embedding = self._generate_style_embedding_from_audio(reference_audio_path)
        elif style_prompt:
            print(f"Generating style from prompt: '{style_prompt}'")
            style_embedding = self._generate_style_embedding_from_prompt(style_prompt)
        else:
            print("No style provided, using zero embedding.")
            style_embedding = torch.zeros(512, device=self.device)

        print(f"Loading lyrics from: {lyrics_json_path}")
        with open(lyrics_json_path, 'r') as f:
            lrc_data = json.load(f)
        if 'word' not in lrc_data:
            lrc_data = {'word': lrc_data}

        # The dataset processor expects an (id, latent, style, lyrics) tuple; the latent is
        # a placeholder sized at 21.5 latent frames per second of requested audio
        frame_rate = 21.5
        num_frames = int(duration_sec * frame_rate)
        fake_latent = torch.randn(128, num_frames)
        sample_tuple = ("user_song", fake_latent, style_embedding, lrc_data)
| print("Processing sample...") | |
| processed_sample = self.dataset_processor.process_sample_safely(sample_tuple) | |
| if processed_sample is None: | |
| raise ValueError("Failed to process the provided lyrics and style.") | |
| batch = self.dataset_processor.custom_collate_fn([processed_sample]) | |
| for key, value in batch.items(): | |
| if isinstance(value, torch.Tensor): | |
| batch[key] = value.to(self.device) | |
| print("Generating audio latent...") | |
| with torch.inference_mode(): | |
| batch_size = 1 | |
| text = batch["lrc"] | |
| style_prompt_tensor = batch["prompt"] | |
| start_time = batch["start_time"] | |
| duration_abs = batch["duration_abs"] | |
| duration_rel = batch["duration_rel"] | |
| cond = torch.zeros(batch_size, self.cfm_model.max_frames, 64).to(self.device) | |
| pred_frames = [(0, self.cfm_model.max_frames)] | |
| negative_style_prompt = get_negative_style_prompt(self.device, self.negative_style_prompt_path) | |
| negative_style_prompt = negative_style_prompt.repeat(batch_size, 1) | |
| sample_kwargs = self.config.evaluation.sample_kwargs | |
| sample_kwargs.steps = steps | |
| latents, _ = self.cfm_model.sample( | |
| cond=cond, text=text, style_prompt=style_prompt_tensor, | |
| duration_abs=duration_abs, duration_rel=duration_rel, | |
| negative_style_prompt=negative_style_prompt, start_time=start_time, | |
| latent_pred_segments=pred_frames, **sample_kwargs) | |
| latent = latents[0][0] | |
| print("Decoding latent to audio...") | |
| latent_for_vae = latent.transpose(0, 1).unsqueeze(0) | |
| pred_audio = self.vae.decode(latent_for_vae).sample.squeeze(0).detach().cpu() | |
        pred_audio = normalize_audio(pred_audio)
        # The VAE decodes to 44.1 kHz audio; trim to the requested duration
        sample_rate = 44100
        trim_samples = int(duration_sec * sample_rate)
        if pred_audio.shape[1] > trim_samples:
            pred_audio = pred_audio[:, :trim_samples]

        output_path = "generated_song.mp3"
        print(f"Saving audio to {output_path}")
        torchaudio.save(output_path, pred_audio, sample_rate, format="mp3")
        return output_path
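
# Example usage (a minimal sketch): the paths below are placeholders, and the lyrics JSON
# must either carry a top-level "word" key or be the bare word-level structure that
# predict() wraps into one (the exact schema is defined by jam.dataset.DiffusionWebDataset).
if __name__ == "__main__":
    jamify = Jamify()
    song_path = jamify.predict(
        reference_audio_path="reference.mp3",  # placeholder: style reference track (or None)
        lyrics_json_path="lyrics.json",        # placeholder: word-level timed lyrics
        style_prompt=None,                     # alternatively, pass a text style prompt instead
        duration_sec=30,
        steps=50,
    )
    print(f"Generated song written to {song_path}")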