Spaces:
Build error
Build error
| import os | |
| import numpy as np | |
| from tqdm import tqdm | |
| from math import ceil | |
def load_vsec_dataset(base_path, corr_file, incorr_file):
    """Load parallel correct/incorrect sentences as ``(corr, incorr)`` tuples.

    Both files are read as UTF-8. Every line is stripped and blank lines are
    dropped; the remaining lines of the two files are then paired by position.

    Args:
        base_path: Directory that contains both files. When truthy it must
            exist. It is always joined with the file names via
            ``os.path.join``.
        corr_file: Name of the file holding the correct sentences.
        incorr_file: Name of the file holding the incorrect sentences.

    Returns:
        A list of ``(corr, incorr)`` string tuples in file order. Note that
        the correct sentence comes first.

    Raises:
        AssertionError: If ``base_path`` is truthy but does not exist, or if
            the two files contain a different number of non-blank lines.
    """
    if base_path:
        assert os.path.exists(base_path), f"base_path does not exist: {base_path}"

    # `with` guarantees the handles are released even if decoding a line fails.
    with open(os.path.join(base_path, incorr_file), "r", encoding="utf-8") as handle:
        incorr_data = [line.strip() for line in handle if line.strip() != ""]

    with open(os.path.join(base_path, corr_file), "r", encoding="utf-8") as handle:
        corr_data = [line.strip() for line in handle if line.strip() != ""]

    assert len(incorr_data) == len(corr_data), (
        f"line count mismatch: {len(incorr_data)} incorrect vs "
        f"{len(corr_data)} correct sentences"
    )

    # The correct sentence is the first element of every tuple.
    data = list(zip(corr_data, incorr_data))

    # The message states the order that is actually returned.
    print(f"loaded tuples of (corr, incorr) examples from {base_path}")
    return data
def load_dataset(base_path, corr_file, incorr_file, length_file = None):
    """Load aligned ``[corr, incorr, length]`` rows from parallel text files.

    Each file is read as UTF-8 with a ``tqdm`` progress bar. Lines are
    stripped and blank lines are dropped, so the k-th row is built from the
    k-th non-blank line of every file. Each correct sentence produces exactly
    one row.

    Args:
        base_path: Directory that contains the files. When truthy it must
            exist. It is always joined with the file names via
            ``os.path.join``.
        corr_file: Name of the file holding the correct sentences.
        incorr_file: Name of the file holding the incorrect sentences.
        length_file: Name of the file holding one integer per example, or
            ``None`` to load sentence pairs only.

    Returns:
        A list of ``[corr, incorr, length]`` lists, or ``[corr, incorr]``
        lists when ``length_file`` is ``None``. The correct sentence comes
        first in every row.

    Raises:
        AssertionError: If ``base_path`` is truthy but does not exist.
        ValueError: If the files hold a different number of non-blank lines,
            or if a line of ``length_file`` is not an integer.
    """
    if base_path:
        assert os.path.exists(base_path), f"base_path does not exist: {base_path}"

    def read_nonblank(file_name):
        # Pairing by position among non-blank lines keeps the files aligned
        # even when one of them contains a stray empty line.
        with open(os.path.join(base_path, file_name), "r", encoding="utf-8") as handle:
            stripped = (line.strip() for line in tqdm(handle))
            return [text for text in stripped if text]

    columns = [read_nonblank(corr_file), read_nonblank(incorr_file)]
    if length_file is not None:
        columns.append([int(text) for text in read_nonblank(length_file)])

    # Ragged rows would only fail later and far from the cause, so reject
    # mismatched inputs here.
    sizes = [len(column) for column in columns]
    if len(set(sizes)) > 1:
        raise ValueError(f"input files are not aligned; non-blank line counts: {sizes}")

    data = [list(row) for row in zip(*columns)]

    # The message states the column order that is actually returned.
    if length_file is not None:
        print(f"loaded tuples of (corr, incorr, length) examples from {base_path}")
    else:
        print(f"loaded tuples of (corr, incorr) examples from {base_path}")
    return data
def load_epoch_dataset(base_path, corr_file, incorr_file, length_file, epoch: int, num_epoch: int):
    """Load the slice of the dataset that belongs to one epoch.

    The dataset is split into ``num_epoch`` consecutive windows of
    ``ceil(count / num_epoch)`` examples, where ``count`` is the number of
    non-blank lines in ``length_file``. Only the window for ``epoch``
    (1-based) is read into memory. Reading of each file stops as soon as the
    end of the window is reached.

    Lines are stripped and blank lines are ignored in every file, so the k-th
    example is the k-th non-blank line of each file.

    Args:
        base_path: Directory that contains the files. When truthy it must
            exist. It is always joined with the file names via
            ``os.path.join``.
        corr_file: Name of the file holding the correct sentences.
        incorr_file: Name of the file holding the incorrect sentences.
        length_file: Name of the file holding one integer per example.
        epoch: 1-based index of the window to load.
        num_epoch: Total number of windows the dataset is split into.

    Returns:
        A list of ``[corr, incorr, length]`` lists for the selected window.
        The list is empty when the window starts past the end of the data.

    Raises:
        AssertionError: If ``base_path`` is truthy but does not exist, if
            ``num_epoch < 1``, or if ``epoch`` is outside ``[1, num_epoch]``.
        ValueError: If the three files yield a different number of examples
            for the window, or if a selected line of ``length_file`` is not
            an integer.
    """
    if base_path:
        assert os.path.exists(base_path), f"base_path does not exist: {base_path}"
    assert num_epoch >= 1, "num_epoch must be >= 1"
    assert 1 <= epoch <= num_epoch, "epoch must be within [1, num_epoch]"

    def read_window(file_name, first, stop):
        # Collect the non-blank lines whose position falls in [first, stop),
        # leaving the file as soon as the window is exhausted.
        window = []
        with open(os.path.join(base_path, file_name), "r", encoding="utf-8") as handle:
            position = 0
            for line in tqdm(handle):
                text = line.strip()
                if not text:
                    continue
                if position >= stop:
                    break
                if position >= first:
                    window.append(text)
                position += 1
        return window

    ## Count number of data
    with open(os.path.join(base_path, length_file), "r", encoding="utf-8") as handle:
        count = sum(1 for line in tqdm(handle) if line.strip())
    print(f"Number of training datas: {count} examples!")

    # Integer ceiling division: float rounding in `1 / num_epoch * count`
    # can push the result one above the true ceiling.
    epochdataset_examples = (count + num_epoch - 1) // num_epoch
    start_index = epochdataset_examples * (epoch - 1)
    end_index = start_index + epochdataset_examples

    corr_rows = read_window(corr_file, start_index, end_index)
    incorr_rows = read_window(incorr_file, start_index, end_index)
    lengths = [int(text) for text in read_window(length_file, start_index, end_index)]

    sizes = [len(corr_rows), len(incorr_rows), len(lengths)]
    if len(set(sizes)) > 1:
        raise ValueError(
            f"input files are not aligned for epoch {epoch}; window sizes: {sizes}"
        )

    data = [
        [corr, incorr, length]
        for corr, incorr, length in zip(corr_rows, incorr_rows, lengths)
    ]

    # The message states the column order that is actually returned.
    print(f"loaded tuples of (corr, incorr, length) examples from {base_path}")
    return data