import argparse
import itertools
import json
import os
import random
import re
import time
from functools import partial

import editdistance as ed
import requests
import torch
import zhconv
from cn_tn import TextNorm
from evaluate_tokenizer import EvaluationTokenizer
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
from transformers.pipelines.audio_utils import ffmpeg_read
from whisper_normalizer.basic import BasicTextNormalizer
from whisper_normalizer.english import EnglishTextNormalizer

english_normalizer = EnglishTextNormalizer()
chinese_normalizer = TextNorm(
    to_banjiao=False,
    to_upper=False,
    to_lower=False,
    remove_fillers=False,
    remove_erhua=False,
    check_chars=False,
    remove_space=False,
    cc_mode='',
)
basic_normalizer = BasicTextNormalizer()

PUNCS = '!,.?;:'

ds_collections = {
    'librispeech': {'path': 'asr/librispeech_eval.jsonl', 'language': 'en'},
    'aishell2': {'path': 'asr/aishell2_eval.jsonl', 'language': 'zh'},
    'cv15_en': {'path': 'asr/cv15_asr_en_eval.jsonl', 'language': 'en'},
    'cv15_zh': {'path': 'asr/cv15_asr_zh_eval.jsonl', 'language': 'zh'},
    'cv15_yue': {'path': 'asr/cv15_asr_yue_eval.jsonl', 'language': 'yue'},
    'cv15_fr': {'path': 'asr/cv15_asr_fr_eval.jsonl', 'language': 'fr'},
    'fleurs_zh': {'path': 'asr/fleurs_asr_zh_eval.jsonl', 'language': 'zh'},
}


class AudioDataset(torch.utils.data.Dataset):

    def __init__(self, ds):
        path = ds['path']
        self.datas = open(path).readlines()

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, idx):
        data = json.loads(self.datas[idx].strip())
        audio = data['audio']
        source = data['source']
        prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>" + data['prompt']
        gt = data['gt']

        return {
            'audio': audio,
            'prompt': prompt,
            'source': source,
            'gt': gt,
        }
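
# A hedged illustration, not taken from the source: each file listed in ds_collections is
# expected to be JSONL with one object per line carrying exactly the keys read in
# __getitem__ above. A hypothetical row (path, source tag and prompt text are placeholders):
#   {"audio": "asr/wavs/example.flac", "source": "librispeech_test_clean",
#    "prompt": "<|en|>Recognize the speech and give the transcription:", "gt": "a reference transcript"}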


def read_audio(audio_path):
    if audio_path.startswith("http://") or audio_path.startswith("https://"):
        # Only treat the path as remote when it has a real URL scheme; otherwise a local
        # file whose name merely starts with "http", e.g. http_huggingface_co.png,
        # could not be loaded.
        inputs = requests.get(audio_path).content
    else:
        with open(audio_path, "rb") as f:
            inputs = f.read()
    return inputs


def collate_fn(inputs, processor):
    input_texts = [_['prompt'] for _ in inputs]
    source = [_['source'] for _ in inputs]
    gt = [_['gt'] for _ in inputs]
    audio_path = [_['audio'] for _ in inputs]
    input_audios = [
        ffmpeg_read(read_audio(_['audio']),
                    sampling_rate=processor.feature_extractor.sampling_rate)
        for _ in inputs
    ]
    inputs = processor(
        text=input_texts,
        audios=input_audios,
        sampling_rate=processor.feature_extractor.sampling_rate,
        return_tensors="pt",
        padding=True,
    )
    return inputs, audio_path, source, gt


class InferenceSampler(torch.utils.data.sampler.Sampler):

    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size,
                                                      self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        shard_size = total_size // world_size
        left = total_size % world_size
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[:rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)
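
# Worked example of the sharding above (illustration only): with size=10 and world_size=4,
# shard_sizes is [3, 3, 2, 2], so rank 0 evaluates indices 0-2, rank 1 indices 3-5,
# rank 2 indices 6-7 and rank 3 indices 8-9; every sample is visited exactly once.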


def remove_sp(text, language):
    gt = re.sub(r"<\|.*?\|>", " ", text)
    gt = re.sub(r"\s+", " ", gt)  # collapse consecutive whitespace into a single space
    gt = re.sub(f" ?([{PUNCS}])", r"\1", gt)
    gt = gt.lstrip(" ")
    if language == "zh":
        gt = re.sub(r"\s+", "", gt)
    return gt
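
# Worked example (illustration only): remove_sp("<|en|> hello ,  world", "en") drops the
# special-token tag, collapses whitespace and re-attaches the punctuation, returning
# "hello, world"; for language == "zh" every remaining space is removed as well.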


def compute_wer(refs, hyps, language):
    distance = 0
    ref_length = 0
    tokenizer = EvaluationTokenizer(
        tokenizer_type="none",
        lowercase=True,
        punctuation_removal=True,
        character_tokenization=False,
    )
    for i in range(len(refs)):
        ref = refs[i]
        pred = hyps[i]
        if language in ["yue"]:
            ref = zhconv.convert(ref, 'zh-cn')
            pred = zhconv.convert(pred, 'zh-cn')
        if language in ["en"]:
            ref = english_normalizer(ref)
            pred = english_normalizer(pred)
        if language in ["zh"]:
            ref = chinese_normalizer(ref)
            pred = chinese_normalizer(pred)
        else:
            ref = basic_normalizer(ref)
            pred = basic_normalizer(pred)
        ref_items = tokenizer.tokenize(ref).split()
        pred_items = tokenizer.tokenize(pred).split()
        if language in ["zh", "yue"]:
            # Score Chinese and Cantonese at character level, i.e. report CER.
            ref_items = [x for x in "".join(ref_items)]
            pred_items = [x for x in "".join(pred_items)]
        if i == 0:
            print(f"ref: {ref}")
            print(f"pred: {pred}")
            print(f"ref_items:\n{ref_items}\n{len(ref_items)}\n{ref_items[0]}")
            print(f"pred_items:\n{pred_items}\n{len(pred_items)}\n{pred_items[0]}")
        distance += ed.eval(ref_items, pred_items)
        ref_length += len(ref_items)
    return distance / ref_length


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, default='Qwen/Qwen2-Audio')
    parser.add_argument('--dataset', type=str, default='')
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--num-workers', type=int, default=1)
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()

    torch.distributed.init_process_group(
        backend='nccl',
        world_size=int(os.getenv('WORLD_SIZE', '1')),
        rank=int(os.getenv('RANK', '0')),
    )
    torch.cuda.set_device(int(os.getenv('LOCAL_RANK', '0')))

    model = Qwen2AudioForConditionalGeneration.from_pretrained(
        args.checkpoint, device_map='cuda', torch_dtype='auto',
        trust_remote_code=True).eval()
    processor = AutoProcessor.from_pretrained(args.checkpoint)
    processor.tokenizer.padding_side = 'left'

    random.seed(args.seed)
    dataset = AudioDataset(
        ds=ds_collections[args.dataset],
    )
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        sampler=InferenceSampler(len(dataset)),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=partial(collate_fn, processor=processor),
    )

    gts = []
    sources = []
    rets = []
    audio_paths = []
    for _, (inputs, audio_path, source, gt) in tqdm(enumerate(data_loader)):
        # Move the whole batch (input_ids, attention_mask, input_features, ...) to the GPU,
        # not only input_ids, so generate() sees all tensors on the same device.
        inputs = inputs.to('cuda')
        output_ids = model.generate(**inputs, max_new_tokens=256,
                                    min_new_tokens=1, do_sample=False)
        output_ids = output_ids[:, inputs.input_ids.size(1):]
        output = processor.batch_decode(output_ids, skip_special_tokens=True,
                                        clean_up_tokenization_spaces=False)
        gts.extend(gt)
        rets.extend(output)
        sources.extend(source)
        audio_paths.extend(audio_path)

    torch.distributed.barrier()

    world_size = torch.distributed.get_world_size()
    merged_gts = [None for _ in range(world_size)]
    merged_sources = [None for _ in range(world_size)]
    merged_responses = [None for _ in range(world_size)]
    merged_audio_paths = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(merged_gts, gts)
    torch.distributed.all_gather_object(merged_sources, sources)
    torch.distributed.all_gather_object(merged_responses, rets)
    torch.distributed.all_gather_object(merged_audio_paths, audio_paths)

    merged_gts = [_ for _ in itertools.chain.from_iterable(merged_gts)]
    merged_sources = [_ for _ in itertools.chain.from_iterable(merged_sources)]
    merged_audio_paths = [_ for _ in itertools.chain.from_iterable(merged_audio_paths)]
    merged_responses = [_ for _ in itertools.chain.from_iterable(merged_responses)]

    if torch.distributed.get_rank() == 0:
        print(f"Evaluating {args.dataset} ...")

        results = []
        for gt, response, source, audio_path in zip(merged_gts, merged_responses,
                                                    merged_sources, merged_audio_paths):
            results.append({
                'gt': gt,
                'response': response,
                'source': source,
                'audio_path': audio_path,
            })
        time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
        results_file = f'{args.dataset}_{time_prefix}.json'
        with open(results_file, 'w') as f:
            json.dump(results, f)

        results_dict = {}
        for item in tqdm(results):
            source = item["source"]
            results_dict.setdefault(source, []).append(item)

        lan = ds_collections[args.dataset]['language']
        for source in results_dict:
            refs, hyps = [], []
            results_list = results_dict[source]
            for result in results_list:
                gt = result["gt"]
                response = result["response"]
                gt = remove_sp(gt, lan)
                response = remove_sp(response, lan)
                refs.append(gt)
                hyps.append(response)
            wer = compute_wer(refs, hyps, lan)
            print(f"source: {source} cnt: {len(refs)} wer: {wer:.4f}")

    torch.distributed.barrier()
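
# Launch sketch (an assumption, not part of the source): the script reads RANK, WORLD_SIZE
# and LOCAL_RANK from the environment before init_process_group, which matches the variables
# torchrun sets. The script name and GPU count below are hypothetical:
#   torchrun --nproc_per_node 8 evaluate_asr.py --checkpoint Qwen/Qwen2-Audio --dataset librispeech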