# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path
import typing as tp

import torch
import torchmetrics
from transformers import RobertaTokenizer  # type: ignore

from ..data.audio_utils import convert_audio
from ..environment import AudioCraftEnvironment
from ..utils.utils import load_clap_state_dict

try:
    import laion_clap  # type: ignore
except ImportError:
    laion_clap = None


class TextConsistencyMetric(torchmetrics.Metric):
    """Text consistency metric measuring consistency between audio and text pairs."""

    def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        raise NotImplementedError("implement how to update the metric from the audio and text pairs.")

    def compute(self):
        raise NotImplementedError("implement how to compute the final metric score.")


class CLAPTextConsistencyMetric(TextConsistencyMetric):
    """Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).

    This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf)
    or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).

    As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the
    similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as
    well as the generated audio based on them, and define the MCC metric as the average cosine similarity
    between these embeddings.

    Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP
    """
    def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
        super().__init__()
        if laion_clap is None:
            raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
        self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
        self._initialize_model(model_path, model_arch, enable_fusion)

    def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
        model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
        self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
        self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
        self.model_sample_rate = 48_000
        load_clap_state_dict(self.model, model_path)
        self.model.eval()

    def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
        # we use the default params from CLAP module here as well
        return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")

    def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset."""
        assert audio.size(0) == len(text), "Number of audio and text samples should match"
        assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate"
        sample_rate = int(sample_rates[0].item())
        # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
        audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1)
        audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
        text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
        # cosine similarity between the text and the audio embedding
        cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
        self.cosine_sum += cosine_sim.sum(dim=0)
        self.weight += torch.tensor(cosine_sim.size(0))
    def compute(self):
        """Computes the average cosine similarity across all audio/text pairs."""
        assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
        return (self.cosine_sum / self.weight).item()  # type: ignore
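

# Minimal usage sketch (the checkpoint path below is a placeholder; `laion_clap` must be
# installed, and the audio is assumed to be a batch of generated waveforms shaped [B, C, T]):
#
#     metric = CLAPTextConsistencyMetric(model_path='/path/to/clap_checkpoint.pt')
#     audio = torch.randn(4, 1, 10 * 32_000)              # [B, C, T] generated audio
#     text = ["calm piano melody over soft rain"] * 4     # matching text descriptions
#     sizes = torch.full((4,), audio.shape[-1])           # per-item lengths in samples
#     sample_rates = torch.full((4,), 32_000)             # per-item sample rates
#     metric.update(audio, text, sizes, sample_rates)
#     print(metric.compute())                             # average CLAP cosine similarity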