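"""Evaluation metrics comparing predicted text against reference (true) text.

Provides the mean normalized edit distance (MNED) between predictions and
references, token-level correct/wrong counts for transformer outputs, and
word-level TP/FP/FN counts computed by aligning the wrong, predicted and true
texts with utils.alignment.needle.
"""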
import multiprocessing
from multiprocessing import Pool
import time

import numpy as np
from rapidfuzz.distance.Levenshtein import normalized_distance

import utils.alignment as alignment
from utils.logger import get_logger
def _get_mned_metric_from_TruePredict(true_text, predict_text):
    # Normalized Levenshtein distance between a single prediction and its reference.
    return normalized_distance(predict_text, true_text)
def get_mned_metric_from_TruePredict(batch_true_text, batch_predict_text):
    # Mean normalized edit distance (MNED) over a batch of (true, predicted) pairs.
    total_NMED = 0.0
    count = 0
    for true_text, predict_text in zip(batch_true_text, batch_predict_text):
        total_NMED += _get_mned_metric_from_TruePredict(true_text, predict_text)
        count += 1
    return total_NMED / count
def get_metric_for_tfm(batch_predicts, batch_targets, batch_length):
    # Token-level correct/wrong counts. The first and last tokens of every
    # sequence (e.g. BOS/EOS) are stripped before the first `length` tokens
    # are compared position by position.
    num_correct, num_wrong = 0, 0
    for predict, target, length in zip(batch_predicts, batch_targets, batch_length):
        predict = predict[1:-1]
        target = target[1:-1]
        predict = np.array(predict[0:length])
        target = np.array(target[0:length])
        num_correct += np.sum(predict == target)
        num_wrong += np.sum(predict != target)
    return num_correct, num_wrong
def allign_seq2trueseq(seq, true_seq, gap_symbol="-"):
    # Split a pair of character-aligned sequences (same length, gaps marked with
    # gap_symbol) into word pairs, using the spaces of true_seq as separators.
    # Predicted words that border a gap separator are marked with "@@" on that side.
    prev_sep = None
    next_sep = None
    seq_list = []
    true_list = []
    accumulate_true_word = ""
    accumulate_pred_word = ""
    assert len(true_seq) == len(seq)
    for i in range(len(true_seq)):
        if true_seq[i] != " ":
            accumulate_true_word += true_seq[i]
            accumulate_pred_word += seq[i]
        else:
            if seq[i] == gap_symbol:
                next_sep = gap_symbol
            else:
                next_sep = " "
            if prev_sep is not None and prev_sep == gap_symbol:
                accumulate_pred_word = "@@" + accumulate_pred_word
            if next_sep is not None and next_sep == gap_symbol:
                accumulate_pred_word = accumulate_pred_word + "@@"
            true_list.append(accumulate_true_word.replace(gap_symbol, ""))
            seq_list.append(accumulate_pred_word)
            accumulate_pred_word = ""
            accumulate_true_word = ""
            prev_sep = next_sep
            next_sep = None
    # Note: any trailing word after the last space in true_seq is not emitted.
    return seq_list, true_list
def align_2seq2trueseq(wrong_text, pred_text, true_text, gap_symbol="-"):
    # Align both the wrong text and the predicted text against the true text and
    # return a list of (wrong_word, predicted_word, true_word) triples.
    assert gap_symbol is not None and len(gap_symbol) == 1
    seq1, true_seq = alignment.needle(wrong_text, true_text, gap_symbol)
    seq1_list, true_list = allign_seq2trueseq(seq1, true_seq, gap_symbol)
    seq2, true_seq = alignment.needle(pred_text, true_text, gap_symbol)
    seq2_list, _ = allign_seq2trueseq(seq2, true_seq, gap_symbol)
    return list(zip(seq1_list, seq2_list, true_list))
def _get_metric_from_TrueWrongPredictV3(true_text, wrong_text, predict_text, vocab=None):
    # Pick a gap symbol that does not occur in any of the three texts, preferring "-".
    gap_symbol = None
    if vocab is not None:
        all_symbols = set(list(vocab.chartoken2idx.keys())[4:])  # skip the first 4 entries, presumably special tokens
        symbols = set(list(wrong_text + predict_text + true_text))
        usable_symbols = all_symbols.difference(symbols)
        assert len(usable_symbols) > 0
        if "-" not in usable_symbols:
            gap_symbol = usable_symbols.pop()
        else:
            gap_symbol = "-"
    gap_symbol = gap_symbol if gap_symbol is not None else "-"
    # `aligned_triples` renamed from `alignment` to avoid shadowing the imported module.
    aligned_triples = align_2seq2trueseq(wrong_text, predict_text, true_text, gap_symbol)
    TP, FP, FN = 0, 0, 0
    for wrong, predict, true in aligned_triples:
        if wrong == true:
            # The wrong word was already correct.
            if predict[:-2] == true:
                # Prediction matches after stripping its last two characters
                # (the "@@" marker), so it is not counted as an error.
                pass
            elif predict != true:
                # The model corrupted a correct word: false positive(s).
                if len(predict.split(" ")) == len(true.split(" ")):
                    FP += 1
                else:
                    penalty = len(predict.split(" ")) - len(true.split(" "))
                    assert penalty > 0
                    FP += penalty
        else:
            # The wrong word contained an error.
            if predict == true:
                TP += 1
            else:
                if len(predict.split(" ")) == len(true.split(" ")):
                    FN += 1
                else:
                    penalty = len(predict.split(" ")) - len(true.split(" "))
                    assert penalty > 0
                    FN += penalty
    return TP, FP, FN
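
# Hand-traced example of the counting rules above (added for illustration, not
# part of the original code). Given aligned word triples (wrong, predicted, true):
#   ("teh", "the", "the") -> the wrong word had an error and the prediction fixed it: TP += 1
#   ("cat", "cta", "cat") -> the wrong word was correct but the prediction broke it:  FP += 1
#   ("teh", "teh", "the") -> the error was left uncorrected:                          FN += 1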
def worker_task(true_text, wrong_text, predict_text, vocab):
    # Multiprocessing helper: score a single (true, wrong, predicted) triple.
    _TP, _FP, _FN = _get_metric_from_TrueWrongPredictV3(true_text, wrong_text, predict_text, vocab)
    return (_TP, _FP, _FN)
def get_metric_from_TrueWrongPredictV3(batch_true_text, batch_wrong_text, batch_predict_text, vocab, twp_logger):
    assert vocab is not None
    TPs, FPs, FNs = 0, 0, 0
    # Use roughly a third of the available cores, but always at least one process.
    with Pool(max(1, multiprocessing.cpu_count() // 3)) as pool:
        data = [
            (true_text, wrong_text, pred_text, vocab)
            for true_text, wrong_text, pred_text in zip(batch_true_text, batch_wrong_text, batch_predict_text)
        ]
        async_result = pool.starmap_async(worker_task, data)
        for i, (tp, fp, fn) in enumerate(async_result.get()):
            TPs += tp
            FPs += fp
            FNs += fn
            if twp_logger:
                twp_logger.log(batch_true_text[i], file_only=True)
                twp_logger.log(batch_wrong_text[i], file_only=True)
                twp_logger.log(batch_predict_text[i], file_only=True)
                twp_logger.log(f"{tp} - {fp} - {fn}", file_only=True)
    return TPs, FPs, FNs
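
if __name__ == "__main__":
    # Minimal usage sketch (added for illustration; the values below are made up).
    # MNED over a toy batch of reference/prediction pairs.
    refs = ["hello world", "good morning"]
    preds = ["helo world", "good morning"]
    print("MNED:", get_mned_metric_from_TruePredict(refs, preds))

    # Token-level correct/wrong counts: the leading and trailing special tokens
    # are stripped before the first `length` tokens are compared.
    batch_predicts = [[101, 5, 6, 7, 102]]
    batch_targets = [[101, 5, 6, 8, 102]]
    batch_length = [3]
    print("correct/wrong:", get_metric_for_tfm(batch_predicts, batch_targets, batch_length))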