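"""Select cases that the model already completes correctly.

For each sample in the chosen dataset (mcf or zsre), the script runs the model
on the sample's original prompt and subject, decodes the first generated token,
and keeps the case ids whose decoded output is a prefix of the true target
string. The selected case ids are cached as a JSON file under
<cache_path>/selection/.
"""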
import os
import argparse
import numpy as np
from tqdm import tqdm
from util import utils
from util import inference
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def find_selection(
    model,
    tok,
    ds,
    batch_size,
):
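    """Return the case ids whose decoded first output token is a prefix of the
    true target string, running inference in batches of `batch_size`."""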
    # find case ids
    case_ids = np.array([r['case_id'] for r in ds.data])
    # find original prompts and subjects of each data sample
    prompts = [sample['requested_rewrite']['prompt'] for sample in ds.data]
    subjects = [sample['requested_rewrite']['subject'] for sample in ds.data]
    # perform inference to get the first predicted token for each prompt
    om_output_tokens = inference.inference_batch(
        model,
        tok,
        all_subjects=subjects,
        all_prompts=prompts,
        disable_tqdms=False,
        batch_size=batch_size,
    )
    # decode outputs
    outputs_decoded = np.array([tok.decode(t).strip() for t in om_output_tokens])
    # find all true targets
    target_trues = np.array([
        sample['requested_rewrite']['target_true']['str'] for sample in ds.data])
    # keep cases where the decoded output is a prefix of the true target
    matching = [target_trues[i].startswith(outputs_decoded[i]) for i in range(len(outputs_decoded))]
    matching_case_ids = case_ids[matching]
    # count unique true targets among the matching cases
    num_unique_matching = len(np.unique(target_trues[matching]))
    num_unique = len(np.unique(target_trues))
    print(f'Number of unique matching: {num_unique_matching}/{num_unique}')
    return matching_case_ids.tolist()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
    parser.add_argument(
        '--dataset', default="mcf", type=str, choices=['mcf', 'zsre'], help='dataset for evaluation')
    parser.add_argument(
        '--batch_size', type=int, default=64, help='batch size for extraction')
    parser.add_argument('--cache_path', type=str, default='./cache/', help='cache directory')
    args = parser.parse_args()

    # ensure results path exists
    args.cache_path = os.path.join(args.cache_path, 'selection/')
    utils.assure_path_exists(args.cache_path)

    # find output path and skip the run if a selection is already cached
    output_file = os.path.join(args.cache_path, f'{args.dataset}_{args.model}_subject_selection.json')
    if os.path.exists(output_file):
        print(f'Selection already exists: {output_file}')
        exit()

    # load model and tokenizer
    model, tok = utils.load_model_tok(model_name=args.model)

    # load dataset
    ds, _, _ = utils.load_dataset(tok, ds_name=args.dataset)

    # find selection
    selected_case_ids = find_selection(model, tok, ds, args.batch_size)

    # save json file of selected case ids
    utils.savejson(output_file, {'case_ids': selected_case_ids})
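
# Example invocation (the script's filename is not given in the source, so
# "find_selection.py" below is only a placeholder):
#   python find_selection.py --model gpt-j-6b --dataset mcf --batch_size 64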