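"""Cache wikipedia token features for the layers of a language model.

For each requested layer, hidden-state feature vectors are extracted from
tokenized wikipedia samples and saved as a pickle file under the cache
directory (one file per layer).
"""
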
import os
import argparse
import numpy as np
from tqdm import tqdm

from util import utils
from util import extraction, evaluation
from dsets import wikipedia


def cache_wikipedia(
    model_name,
    model,
    tok,
    hparams,
    max_len,
    exclude_front=0,
    sample_size=10000,
    take_single=False,
    exclude_path=None,
    layers=None,
    cache_path=None,
):
    # load wikipedia dataset, tokenized to the requested maximum length
    if max_len is not None:
        raw_ds, tok_ds = wikipedia.get_ds(tok, maxlen=max_len)
    else:
        print('Finding maximum context length from model config...')
        try:
            raw_ds, tok_ds = wikipedia.get_ds(tok, maxlen=model.config.n_positions)
        except AttributeError:
            # config has no n_positions attribute; fall back to a fixed length
            raw_ds, tok_ds = wikipedia.get_ds(tok, maxlen=4096)

    # extract features from each requested layer
    for l in layers:
        print('\n\nExtracting wikipedia token features for model layer:', l)
        output_file = os.path.join(cache_path, f'wikipedia_features_{model_name}_layer{l}_w1.pickle')
        if os.path.exists(output_file):
            print('Output file already exists:', output_file)
            continue

        # optionally skip sample indices already used by a previous extraction run
        if exclude_path is not None:
            exclude_file = os.path.join(exclude_path, f'wikipedia_features_{model_name}_layer{l}_w1.pickle')
            exclude_indices = utils.loadpickle(exclude_file)['sampled_indices']
        else:
            exclude_indices = []

        features, params = extraction.extract_tokdataset_features(
            model,
            tok_ds,
            layer=l,
            hparams=hparams,
            exclude_front=exclude_front,
            sample_size=sample_size,
            take_single=take_single,
            exclude_indices=exclude_indices,
            verbose=True,
        )

        # save features alongside the extraction parameters
        params['features'] = features.cpu().numpy()
        utils.savepickle(output_file, params)
        print('Features saved:', output_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
    parser.add_argument(
        '--sample_size', type=int, default=10000, help='number of feature vectors to extract')
    parser.add_argument(
        '--max_len', type=int, default=None, help='maximum token length')
    parser.add_argument(
        '--exclude_front', type=int, default=0, help='number of tokens to exclude from the front of each sample')
    parser.add_argument(
        '--take_single', type=int, default=0, help='take a single feature vector from each wikipedia sample text (0 or 1)')
    parser.add_argument(
        '--layer', type=int, default=None, help='extract features from this single layer (default: all layers)')
    parser.add_argument(
        '--exclude_path', type=str, default=None, help='directory of previously cached features whose sampled indices should be excluded')
    parser.add_argument(
        '--cache_path', type=str, default='./cache/wiki_train/', help='output directory')
    args = parser.parse_args()

    # load hyperparameters for the selected model
    hparams_path = f'./hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)

    # ensure save path exists
    utils.assure_path_exists(args.cache_path)

    # load model and tokenizer
    model, tok = utils.load_model_tok(args.model)

    # extract from a single layer if specified, otherwise from all of the model's layers
    if args.layer is not None:
        layers = [args.layer]
    else:
        layers = evaluation.model_layer_indices[args.model]

    # main function
    cache_wikipedia(
        model_name=args.model,
        model=model,
        tok=tok,
        hparams=hparams,
        max_len=args.max_len,
        layers=layers,
        exclude_front=args.exclude_front,
        sample_size=args.sample_size,
        take_single=bool(args.take_single),
        cache_path=args.cache_path,
        exclude_path=args.exclude_path,
    )
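
# Example invocations (the script filename is illustrative; the flags match the parser above):
#   python cache_wikipedia_features.py --model gpt-j-6b --layer 5 --sample_size 10000
#   python cache_wikipedia_features.py --model gpt-j-6b --max_len 1024 --cache_path ./cache/wiki_train/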