Commit 2fcdf98
Parent(s): 6e66a6e

Delete models/tools

- models/tools/__init__.py +0 -4
- models/tools/analysis_toolkits/__init__.py +0 -0
- models/tools/computations/softmax.py +0 -8
- models/tools/data_structures/__init__.py +0 -0
- models/tools/data_structures/trie.py +0 -152
- models/tools/model_utils/__init__.py +0 -0
- models/tools/model_utils/__pycache__/__init__.cpython-38.pyc +0 -0
- models/tools/model_utils/__pycache__/parameter_freeze.cpython-38.pyc +0 -0
- models/tools/model_utils/calibrate.py +0 -202
- models/tools/model_utils/gpt_response.py +0 -138
- models/tools/model_utils/parameter_freeze.py +0 -126
- models/tools/model_utils/uncertainty.py +0 -137
- models/tools/processing_utils/common.py +0 -38
- models/tools/processing_utils/sampler.py +0 -26
- models/tools/processing_utils/tokenizer/JiebaTokenizer.py +0 -24
- models/tools/processing_utils/tokenizer/__init__.py +0 -4
- models/tools/processing_utils/tokenizer/tokenizer_utils.py +0 -19
- models/tools/runner_utils/__init__.py +0 -0
- models/tools/runner_utils/__pycache__/__init__.cpython-38.pyc +0 -0
- models/tools/runner_utils/__pycache__/log_util.cpython-38.pyc +0 -0
- models/tools/runner_utils/conifg_extensive.py +0 -15
- models/tools/runner_utils/log_util.py +0 -30
- models/tools/runner_utils/retrying.py +0 -288
- models/tools/runner_utils/set_seed.py +0 -21
- models/tools/runner_utils/timecost.py +0 -20
models/tools/__init__.py
DELETED
@@ -1,4 +0,0 @@
-# -*- coding: utf-8 -*-
-# @Time : 2021/12/2 5:41 p.m.
-# @Author : JianingWang
-# @File : __init__.py
models/tools/analysis_toolkits/__init__.py
DELETED
File without changes
models/tools/computations/softmax.py
DELETED
@@ -1,8 +0,0 @@
-import torch
-
-"""
-Transform the torch logits into probabilities.
-"""
-def softmax(logits):
-    probs = torch.softmax(torch.from_numpy(logits).float(), -1).numpy()
-    return probs
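Note: the deleted softmax helper takes a numpy array of logits and returns row-normalized probabilities via torch. A minimal usage sketch (illustrative only, not part of this commit):

import numpy as np

logits = np.array([[2.0, 1.0, 0.1]])   # one example, three classes
probs = softmax(logits)                # shape (1, 3), each row sums to 1
print(probs)                           # approx. [[0.659 0.242 0.099]]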
models/tools/data_structures/__init__.py
DELETED
File without changes
models/tools/data_structures/trie.py
DELETED
@@ -1,152 +0,0 @@
-# -*- coding: utf-8 -*-
-# @Time : 2022/2/15 7:57 p.m.
-# @Author : JianingWang
-# @File : trie
-import logging
-from typing import List
-from collections import OrderedDict
-
-logger = logging.getLogger(__name__)
-
-
-class Trie:
-    def __init__(self):
-        self.data = {}
-
-    def add(self, word: str):
-        """
-        Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation.
-        The special key `""` is used to represent termination.
-
-        This function is idempotent, adding twice the same word will leave the trie unchanged
-
-        Example:
-
-        ```python
-        >>> trie = Trie()
-        >>> trie.add("Hello 友達")
-        >>> trie.data
-        {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}
-
-        >>> trie.add("Hello")
-        >>> trie.data
-        {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
-        ```
-        """
-        if not word:
-            # Prevent empty string
-            return
-        ref = self.data
-        for char in word:
-            ref[char] = char in ref and ref[char] or {}
-            ref = ref[char]
-        ref[""] = 1
-
-    def find(self, text: str):
-        states = OrderedDict()
-        offsets = []
-        skip = 0
-        for current, current_char in enumerate(text):
-            if skip and current < skip:
-                continue
-            to_remove = set()
-            reset = False
-            for start, trie_pointer in states.items():
-                if "" in trie_pointer:
-                    for lookstart, looktrie_pointer in states.items():
-                        if lookstart > start:
-                            break
-                        elif lookstart < start:
-                            lookahead_index = current + 1
-                            end = current + 1
-                        else:
-                            lookahead_index = current
-                            end = current
-                        next_char = text[lookahead_index] if lookahead_index < len(text) else None
-                        if "" in looktrie_pointer:
-                            start = lookstart
-                            end = lookahead_index
-                            skip = lookahead_index
-
-                        while next_char in looktrie_pointer:
-                            looktrie_pointer = looktrie_pointer[next_char]
-                            lookahead_index += 1
-                            if "" in looktrie_pointer:
-                                start = lookstart
-                                end = lookahead_index
-                                skip = lookahead_index
-
-                            if lookahead_index == len(text):
-                                break
-                            next_char = text[lookahead_index]
-                    offsets.append([start, end])
-                    reset = True
-                    break
-                elif current_char in trie_pointer:
-                    trie_pointer = trie_pointer[current_char]
-                    states[start] = trie_pointer
-                else:
-                    to_remove.add(start)
-            if reset:
-                states = {}
-            else:
-                for start in to_remove:
-                    del states[start]
-            if current >= skip and current_char in self.data:
-                states[current] = self.data[current_char]
-        for start, trie_pointer in states.items():
-            if "" in trie_pointer:
-                end = len(text)
-                offsets.append([start, end])
-                break
-
-        return offsets
-
-    def split(self, text: str) -> List[str]:
-        """
-        Example:
-
-        ```python
-        >>> trie = Trie()
-        >>> trie.split("[CLS] This is a extra_id_100")
-        ["[CLS] This is a extra_id_100"]
-
-        >>> trie.add("[CLS]")
-        >>> trie.add("extra_id_1")
-        >>> trie.add("extra_id_100")
-        >>> trie.split("[CLS] This is a extra_id_100")
-        ["[CLS]", " This is a ", "extra_id_100"]
-        ```
-        """
-        word_sets = self.find(text)
-        offsets = [0]
-        for w in word_sets:
-            offsets.extend(w)
-        return self.cut_text(text, offsets)
-
-    def cut_text(self, text, offsets):
-        offsets.append(len(text))
-        tokens = []
-        start = 0
-        for end in offsets:
-            if start > end:
-                logger.error(
-                    "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway."
-                )
-                continue
-            elif start == end:
-                continue
-            tokens.append(text[start:end])
-            start = end
-
-        return tokens
-
-    def __reduce__(self):
-        return None
-
-
-if __name__ == "__main__":
-    trie = Trie()
-    for word in ["A", "AB", "BD", "BWA"]:
-        trie.add(word)
-    print(trie.__reduce__())
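Note: the deleted Trie mirrors the one HuggingFace tokenizers use to split text on added special tokens; find() returns [start, end) offsets of the longest matches and split() cuts the text at those offsets. A small sketch re-using the docstring's own example (illustrative, not part of this commit):

trie = Trie()
trie.add("[CLS]")
trie.add("extra_id_100")
print(trie.find("[CLS] This is a extra_id_100"))   # [[0, 5], [16, 28]]
print(trie.split("[CLS] This is a extra_id_100"))  # ["[CLS]", " This is a ", "extra_id_100"]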
models/tools/model_utils/__init__.py
DELETED
File without changes
models/tools/model_utils/__pycache__/__init__.cpython-38.pyc
DELETED
Binary file (139 Bytes)
models/tools/model_utils/__pycache__/parameter_freeze.cpython-38.pyc
DELETED
Binary file (2.8 kB)
models/tools/model_utils/calibrate.py
DELETED
@@ -1,202 +0,0 @@
-# -*- coding: utf-8 -*-
-# @Time : 2023/3/20 8:02 p.m.
-# @Author : Jianing Wang
-# @File : calibrate.py
-
-import os
-import numpy as np
-import torch
-
-"""
-Use LM to classify label words for calibrating CLS
-"""
-class CLSCalibrator:
-    pass
-
-"""
-Use Causal LM to generate label words for calibrating CLS
-e.g., use gpt2 to generate a label word with in-context prompts, and calibrate for the prediction.
-Paper: http://proceedings.mlr.press/v139/zhao21c.html
-"""
-class CausalCLSCalibrator:
-
-    def __init__(self, model, tokenizer) -> None:
-        self.model = model
-        self.tokenizer = tokenizer
-
-    def calibrate(self, all_label_probs, content_free_examples, label2id, mode="diagonal_W"):
-        """Perform calibration for de-biasing and obtain calibrated probability"""
-        p_cf = self.get_content_free_prediction(content_free_examples, label2id)
-
-        num_classes = all_label_probs.shape[1]
-        if p_cf is None:
-            # do not calibrate
-            W = np.identity(num_classes)
-            b = np.zeros([num_classes, 1])
-        else:
-            # calibrate
-            if mode == "diagonal_W":
-                W = np.linalg.inv(np.identity(num_classes) * p_cf)
-                b = np.zeros([num_classes, 1])
-            elif mode == "identity_W":
-                W = np.identity(num_classes)
-                b = -1 * np.expand_dims(p_cf, axis=-1)
-            else:
-                assert False
-
-
-        all_calibrate_label_probs = list()
-        for label_probs in all_label_probs:
-            label_probs = label_probs / np.sum(label_probs) # normalize to 1
-            calibrate_label_probs = np.matmul(W, np.expand_dims(label_probs, axis=-1)) + b
-            all_calibrate_label_probs.append(calibrate_label_probs.squeeze().tolist())
-        return np.array(all_calibrate_label_probs)
-
-
-    def get_content_free_prediction(self, content_free_examples, label2id: dict):
-        """Query model with content free input, return its prediction probability for each label"""
-
-        all_p_y = []
-        for content_free_example in content_free_examples:
-
-            content_free_prompt = content_free_example["content_free_prompt"]
-            p_y = [0] * len(label2id)
-            for answers, i in label2id.items():
-                prob = 0
-                for a in answers:
-                    prob += np.exp(self.get_causal_cls_prediction(content_free_prompt + " " + a, 0, echo=True, num_log_probs=1)['choices'][0]['logprobs']['token_logprobs'][-1])
-                p_y[i] = prob
-            all_p_y.append(p_y)
-
-        p_y = np.mean(np.array(all_p_y), axis=0)
-        p_y = p_y / np.sum(p_y) # normalize
-        return p_y
-
-
-    def get_causal_cls_prediction(self, prompt, l=10, num_log_probs=None, echo=False):
-        ''' This function runs GPT-2 locally but places the outputs into an json that looks just like the one
-        provided by the OpenAI API. '''
-        if isinstance(prompt, str):
-            prompt = [prompt] # the code below assumes a list
-        input_ids = self.tokenizer.batch_encode_plus(prompt, return_tensors="pt", padding=True)
-
-        if l + len(input_ids['input_ids'][0]) > 1020:
-            m = l + len(input_ids['input_ids'][0]) - 1024
-            input_ids['input_ids'] = torch.Tensor([input_ids['input_ids'][0][m:].numpy()]).long()
-            input_ids['attention_mask'] = torch.Tensor([input_ids['attention_mask'][0][m:].numpy()]).long()
-
-        # greedily generate l tokens
-        # print("l=", l)
-        if l > 0:
-            # the generate function can handle left padded inputs automatically in HF
-            # total_sequences is now the input + possible generated output
-            # print("l + len(input_ids[input_ids][0]=", l + len(input_ids['input_ids'][0]))
-            total_sequences = self.model.generate(
-                input_ids=input_ids['input_ids'].to(self.model.device),
-                attention_mask=input_ids['attention_mask'].to(self.model.device),
-                max_length=l + len(input_ids['input_ids'][0]),
-                do_sample=False
-            )
-        else:
-            assert echo == True and l == 0
-            total_sequences = input_ids['input_ids'].to(self.model.device)
-        # print("="*50)
-        # print("total_sequences=", total_sequences) [batch, len+l]
-        # print("total_sequences.shape=", total_sequences.shape)
-
-        # they want the probs of the top tokens
-        if num_log_probs is not None:
-            # we are left padding, so we need to adjust the position IDs
-            attention_mask = (total_sequences != 50256).float()
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            # get the logits for the context and the next l tokens
-            logits = self.model.forward(input_ids=total_sequences, attention_mask=attention_mask, position_ids=position_ids, return_dict=True).logits.detach().cpu()
-            if not echo:
-                # get the top tokens and probs for the generated l tokens
-                probs = torch.softmax(logits[:,-l-1:], dim=2).cpu()
-            else:
-                # get the top tokens and probs for the context and the generated l tokens
-                probs = torch.softmax(logits, dim=2).cpu()
-            top_probs, top_tokens = torch.topk(probs, k=num_log_probs)
-            logprobs = torch.log(probs)
-            top_log_probs = torch.log(top_probs)
-            # print("top_log_probs=", top_log_probs)
-            # print("top_log_probs.shape=", top_log_probs.shape) # [1, 2, 100] [batch, 2, api_num_log_prob]
-
-        # create the return value to resemble OpenAI
-        return_json = {}
-        choices = []
-        # print("="*50)
-        for batch_id in range(len(prompt)):
-            curr_json = {}
-            # text is just the optional context and next l tokens
-            if not echo:
-                curr_json['text'] = self.tokenizer.decode(total_sequences[batch_id][-l:], skip_special_tokens=True)
-            else:
-                curr_json['text'] = self.tokenizer.decode(total_sequences[batch_id], skip_special_tokens=True)
-
-            # fill the return json with the top tokens and probs to match the OpenAI return value.
-            if num_log_probs is not None:
-                curr_json['logprobs'] = {}
-                curr_json['logprobs']['top_logprobs'] = []
-                curr_json['logprobs']['token_logprobs'] = []
-                curr_json['logprobs']['tokens'] = []
-                if not echo:
-                    # cutoff the -1 here because the probs are shifted one over for LMs
-                    for current_element_top_log_probs, current_element_top_tokens in zip(top_log_probs[batch_id][:-1], top_tokens[batch_id][:-1]):
-                        # tokens is a list of the top token at each position
-                        curr_json['logprobs']['tokens'].append(self.tokenizer.decode([current_element_top_tokens[0]]))
-                        # token_logprobs is a list of the logprob of the top token at each position
-                        curr_json['logprobs']['token_logprobs'].append(current_element_top_log_probs[0].item())
-                        # top_logprobs is a list of dicts for the top K tokens. with each entry being {'token_name': log_prob}
-                        temp = {}
-                        for log_prob, token in zip(current_element_top_log_probs, current_element_top_tokens):
-                            temp[self.tokenizer.decode(token.item())] = log_prob.item()
-                        curr_json['logprobs']['top_logprobs'].append(temp)
-                else:
-                    # same as not above but small tweaks
-                    # we add null to the front because for the GPT models, they have null probability for the first token
-                    # (for some reason they don't have an beginning of sentence token)
-                    curr_json['logprobs']['top_logprobs'].append('null')
-                    # cutoff the -1 here because the probs are shifted one over for LMs
-                    for index, (current_element_top_log_probs, current_element_top_tokens) in enumerate(zip(top_log_probs[batch_id][:-1], top_tokens[batch_id][:-1])):
-                        # skip padding tokens
-                        if total_sequences[batch_id][index].item() == 50256:
-                            continue
-                        temp = {}
-                        for log_prob, token in zip(current_element_top_log_probs, current_element_top_tokens):
-                            temp[self.tokenizer.decode(token.item())] = log_prob.item()
-                        curr_json['logprobs']['top_logprobs'].append(temp)
-                    for index in range(len(probs[batch_id])):
-                        curr_json['logprobs']['tokens'].append(self.tokenizer.decode([total_sequences[batch_id][index]]))
-                    curr_json['logprobs']['token_logprobs'].append('null')
-                    for index, log_probs_token_position_j in enumerate(logprobs[batch_id][:-1]):
-                        # probs are left shifted for LMs
-                        curr_json['logprobs']['token_logprobs'].append(log_probs_token_position_j[total_sequences[batch_id][index+1]])
-
-            choices.append(curr_json)
-            # print("curr_json=", curr_json)
-        '''
-        e.g.,
-        num_tokens_to_predict=1
-        curr_json= {
-            'text': ' I', # the top generated token
-            'logprobs': {'top_logprobs': [{' I': -3.4267239570617676, '\n': -3.5073862075805664, ...], # the top-100 tokens and their scores
-            'token_logprobs': [-3.4267239570617676], # the score of the current top token
-            'tokens': [' I']}
-        }
-        num_tokens_to_predict=2
-        curr_json= {
-            'text': '\nThe', # two tokens are returned when two tokens are requested
-            'logprobs': {'top_logprobs': [ # prediction scores for the two positions
-                {'\n': -3.186706304550171, '\xa0': -3.222092390060425, ' We': -6.781067848205566, ...},
-                {'The': -2.5251243114471436, '"': -2.857935667037964, ...],
-            'token_logprobs': [-3.186706304550171, -2.5251243114471436], # scores of the generated tokens
-            'tokens': ['\n', 'The']}
-        }
-        '''
-        return_json['choices'] = choices
-        # print("="*50)
-        # print("return_json=", return_json)
-        return return_json
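Note: the deleted calibrator follows the contextual calibration paper linked in its docstring: a content-free prompt yields a bias estimate p_cf, and in "diagonal_W" mode each prediction is rescaled by W = diag(p_cf)^-1. A minimal numpy sketch of that step (illustrative only, not part of this commit):

import numpy as np

p_cf = np.array([0.7, 0.3])               # model's bias toward each label on a content-free prompt
W = np.linalg.inv(np.identity(2) * p_cf)  # diag(1/0.7, 1/0.3)
b = np.zeros([2, 1])

label_probs = np.array([0.6, 0.4])        # raw normalized label probabilities for one example
calibrated = np.matmul(W, label_probs[:, None]) + b
print(calibrated.squeeze())               # approx. [0.857 1.333]; argmax flips from label 0 to label 1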
models/tools/model_utils/gpt_response.py
DELETED
@@ -1,138 +0,0 @@
-# -*- coding: utf-8 -*-
-# @Time : 2023/3/23 1:02 p.m.
-# @Author : Jianing Wang
-# @File : gpt_response.py
-
-import os
-import sys
-import torch
-import openai
-import time
-
-"""
-Call for GPT-style LLM.
-The output format is the same as OpenAI (e.g., GPT-3.5 text-davinci-003)
-"""
-class GPTResponse:
-
-    def __init__(self, model_type: str, data_path: str) -> None:
-        assert model_type in ["gpt2", "gpt3"]
-        self.model_type = model_type
-        if self.model_type == "gpt3":
-
-            with open(os.path.join(data_path, 'openai_key.txt'), 'r') as f:
-                key = f.readline().strip()
-                openai.api_key = key
-
-    def call_for_gpt3_response(self, prompt, l, model_name, temp=0, num_log_probs=None, echo=False, n=None):
-        """
-        call GPT-3 API until result is provided and then return it
-        """
-        response = None
-        received = False
-        while not received:
-            try:
-                response = openai.Completion.create(engine=model_name, prompt=prompt, max_tokens=l, temperature=temp,
-                                                    logprobs=num_log_probs, echo=echo, stop='\n', n=n)
-                received = True
-            except:
-                error = sys.exc_info()[0]
-                if error == openai.error.InvalidRequestError: # something is wrong: e.g. prompt too long
-                    print(f"InvalidRequestError\nPrompt passed in:\n\n{prompt}\n\n")
-                    assert False
-
-                print("API error:", error)
-                time.sleep(1)
-        return response
-
-    def call_for_gpt2_response(self, gpt2_tokenizer, logits, total_sequences, l=10, num_log_probs=None, echo=False, n=None):
-        """
-        Obtain the prediction logits from gpt2 in local, and convert it to the value that can match the response from OpenAI
-        """
-        if not echo:
-            # get the top tokens and probs for the generated l tokens
-            probs = torch.softmax(logits[:,-l-1:], dim=2).cpu()
-        else:
-            # get the top tokens and probs for the context and the generated l tokens
-            probs = torch.softmax(logits, dim=2).cpu()
-        # print("probs=", probs)
-        top_probs, top_tokens = torch.topk(probs, k=num_log_probs)
-        logprobs = torch.log(probs)
-        top_log_probs = torch.log(top_probs)
-
-        # create the return value to resemble OpenAI
-        return_json = {}
-        choices = []
-        # print("="*50)
-        for batch_id in range(len(logits)):
-            curr_json = {}
-            # text is just the optional context and next l tokens
-            if not echo:
-                curr_json['text'] = gpt2_tokenizer.decode(total_sequences[batch_id][-l:], skip_special_tokens=True)
-            else:
-                curr_json['text'] = gpt2_tokenizer.decode(total_sequences[batch_id], skip_special_tokens=True)
-
-            # fill the return json with the top tokens and probs to match the OpenAI return value.
-            if num_log_probs is not None:
-                curr_json['logprobs'] = {}
-                curr_json['logprobs']['top_logprobs'] = []
-                curr_json['logprobs']['token_logprobs'] = []
-                curr_json['logprobs']['tokens'] = []
-                if not echo:
-                    # cutoff the -1 here because the probs are shifted one over for LMs
-                    for current_element_top_log_probs, current_element_top_tokens in zip(top_log_probs[batch_id][:-1], top_tokens[batch_id][:-1]):
-                        # tokens is a list of the top token at each position
-                        curr_json['logprobs']['tokens'].append(gpt2_tokenizer.decode([current_element_top_tokens[0]]))
-                        # token_logprobs is a list of the logprob of the top token at each position
-                        curr_json['logprobs']['token_logprobs'].append(current_element_top_log_probs[0].item())
-                        # top_logprobs is a list of dicts for the top K tokens. with each entry being {'token_name': log_prob}
-                        temp = {}
-                        for log_prob, token in zip(current_element_top_log_probs, current_element_top_tokens):
-                            temp[gpt2_tokenizer.decode(token.item())] = log_prob.item()
-                        curr_json['logprobs']['top_logprobs'].append(temp)
-                else:
-                    # same as not above but small tweaks
-                    # we add null to the front because for the GPT models, they have null probability for the first token
-                    # (for some reason they don't have an beginning of sentence token)
-                    curr_json['logprobs']['top_logprobs'].append('null')
-                    # cutoff the -1 here because the probs are shifted one over for LMs
-                    for index, (current_element_top_log_probs, current_element_top_tokens) in enumerate(zip(top_log_probs[batch_id][:-1], top_tokens[batch_id][:-1])):
-                        # skip padding tokens
-                        if total_sequences[batch_id][index].item() == 50256:
-                            continue
-                        temp = {}
-                        for log_prob, token in zip(current_element_top_log_probs, current_element_top_tokens):
-                            temp[gpt2_tokenizer.decode(token.item())] = log_prob.item()
-                        curr_json['logprobs']['top_logprobs'].append(temp)
-                    for index in range(len(probs[batch_id])):
-                        curr_json['logprobs']['tokens'].append(gpt2_tokenizer.decode([total_sequences[batch_id][index]]))
-                    curr_json['logprobs']['token_logprobs'].append('null')
-                    for index, log_probs_token_position_j in enumerate(logprobs[batch_id][:-1]):
-                        # probs are left shifted for LMs
-                        curr_json['logprobs']['token_logprobs'].append(log_probs_token_position_j[total_sequences[batch_id][index+1]])
-
-            choices.append(curr_json)
-            # print("curr_json=", curr_json)
-        '''
-        e.g.,
-        num_tokens_to_predict=1
-        curr_json= {
-            'text': ' I', # the top generated token
-            'logprobs': {'top_logprobs': [{' I': -3.4267239570617676, '\n': -3.5073862075805664, ...], # the top-100 tokens and their scores
-            'token_logprobs': [-3.4267239570617676], # the score of the current top token
-            'tokens': [' I']}
-        }
-        num_tokens_to_predict=2
-        curr_json= {
-            'text': '\nThe', # two tokens are returned when two tokens are requested
-            'logprobs': {'top_logprobs': [ # prediction scores for the two positions
-                {'\n': -3.186706304550171, '\xa0': -3.222092390060425, ' We': -6.781067848205566, ...},
-                {'The': -2.5251243114471436, '"': -2.857935667037964, ...],
-            'token_logprobs': [-3.186706304550171, -2.5251243114471436], # scores of the generated tokens
-            'tokens': ['\n', 'The']}
-        }
-        '''
-        return_json['choices'] = choices
-        # print("="*50)
-        # print("return_json=", return_json)
-        return return_json
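Note: both call_for_* methods return an OpenAI-Completion-shaped dict, so downstream code (e.g. the calibrator above) reads log-probabilities the same way regardless of backend. A sketch of how such a response is consumed, with made-up numbers (illustrative only, not part of this commit):

response = {"choices": [{"text": " I",
                         "logprobs": {"tokens": [" I"],
                                      "token_logprobs": [-3.43],
                                      "top_logprobs": [{" I": -3.43, "\n": -3.51}]}}]}
last_logprob = response["choices"][0]["logprobs"]["token_logprobs"][-1]
print(last_logprob)   # -3.43, the value used e.g. for content-free calibration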
models/tools/model_utils/parameter_freeze.py
DELETED
@@ -1,126 +0,0 @@
-# -*- coding: utf-8 -*-
-# @Time : 2023/02/18 02:07 p.m.
-# @Author : JianingWang
-# @File : parameter_freeze.py
-
-import torch
-
-
-"""
-This is use for parameter fixing and unfreezing, which can be viewed as parameter-efficient settings.
-"""
-class ParameterFreeze():
-    # freeze all parameters
-    def freeze_lm(self, model: torch.nn.Module):
-        for name, param in model.named_parameters():
-            param.requires_grad = False
-        return model
-
-    # freeze all parameters without cls / mlm head
-    def freeze_lm_encoder(self, model: torch.nn.Module):
-        for name, param in model.named_parameters():
-            if "lm_head" in name or ("cls" in name):
-                print(name)
-                continue
-            param.requires_grad = False
-        return model
-
-    # freeze all parameters without bias
-    def freeze_lm_finetune_bias(self, model: torch.nn.Module):
-        for name, param in model.named_parameters():
-            if "bias" in name:
-                print(name)
-                continue
-            param.requires_grad = False
-        return model
-
-    # freeze the component that user defined
-    def freeze_lm_component(self, model: torch.nn.Module, component: str):
-        if "attention" in component:
-            for name, param in model.named_parameters():
-                if "attention" in name:
-                    if "output" in component:
-                        if "output" in name:
-                            continue
-                    else:
-                        continue
-                param.requires_grad = False
-            model = self.unfreeze_classification_head(model)
-        elif "feedforward" in component:
-            for name, param in model.named_parameters():
-                if "dense" in name and "attention" not in name:
-                    if "output" in component:
-                        if "output" in name:
-                            continue
-                    else:
-                        if "intermediate" in component:
-                            if "intermediate" in name:
-                                continue
-                param.requires_grad = False
-            model = self.unfreeze_classification_head(model)
-        elif component == "adapter":
-            for name, param in model.named_parameters():
-                if "adapter" in name:
-                    continue
-
-                param.requires_grad = False
-            model = self.unfreeze_classification_head(model)
-        elif "embedding" in component:
-            for name, param in model.named_parameters():
-                if "embedding" in name:
-                    continue
-
-                param.requires_grad = False
-            model = self.unfreeze_classification_head(model)
-        elif "bias" in component:
-            for name, param in model.named_parameters():
-                if "bias" in name:
-                    continue
-                param.requires_grad = False
-            model = self.unfreeze_classification_head(model)
-        elif "head" in component:
-            for name, param in model.named_parameters():
-                param.requires_grad = False
-            model = self.unfreeze_classification_head(model)
-
-        elif "prompt_emb" in component:
-            for name, param in model.named_parameters():
-                if "prompt_emb" in name:
-                    continue
-                param.requires_grad = False
-        return model
-
-    # unfreeze cls head
-    def unfreeze_classification_head(self, model: torch.nn.Module):
-        for name, param in model.named_parameters():
-            if "lm_head" in name or ("cls" in name) or ("classifier" in name):
-                param.requires_grad = True
-        return model
-
-    # freeze k layers
-    def freeze_lm_k_layers(self, model: torch.nn.Module, k):
-        keep_layers = []
-        update_parameters = []
-        for i in range(k):
-            keep_layers.append("layer."+str(23-i))
-
-        for name, param in model.named_parameters():
-            update = False
-            for layer_num in keep_layers:
-                if layer_num in name:
-                    if "dense" in name and "attention" not in name:
-                        if "output" in name:
-                            print(name)
-                            update_parameters.append(name)
-                            update = True
-
-            if not update:
-                param.requires_grad = False
-        model = self.unfreeze_classification_head(model)
-        return model
-
-
-    def unfreeze_lm(self, model: torch.nn.Module):
-        for param in model.parameters():
-            param.requires_grad = True
-        return model
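Note: ParameterFreeze toggles requires_grad on named parameters; for example, freeze_lm_finetune_bias leaves only bias terms trainable (a BitFit-style setting) and unfreeze_classification_head re-enables the head. A usage sketch (the checkpoint name is illustrative, not part of this commit):

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
freezer = ParameterFreeze()
model = freezer.freeze_lm_finetune_bias(model)        # freeze everything except bias parameters
model = freezer.unfreeze_classification_head(model)   # keep the classifier trainable as well

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(len(trainable))   # only bias terms and the classification head remain trainable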
models/tools/model_utils/uncertainty.py
DELETED
@@ -1,137 +0,0 @@
-# -*- coding: utf-8 -*-
-# @Time : 2023/04/18 08:11 p.m.
-# @Author : JianingWang
-# @File : uncertainty.py
-
-from sklearn.utils import shuffle
-import logging
-import numpy as np
-import os
-import random
-
-
-logger = logging.getLogger(__name__)
-
-
-def get_BALD_acquisition(y_T):
-
-    expected_entropy = - np.mean(np.sum(y_T * np.log(y_T + 1e-10), axis=-1), axis=0)
-    expected_p = np.mean(y_T, axis=0)
-    entropy_expected_p = - np.sum(expected_p * np.log(expected_p + 1e-10), axis=-1)
-    return (entropy_expected_p - expected_entropy)
-
-
-def sample_by_bald_difficulty(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
-
-    logger.info ("Sampling by difficulty BALD acquisition function")
-    BALD_acq = get_BALD_acquisition(y_T)
-    p_norm = np.maximum(np.zeros(len(BALD_acq)), BALD_acq)
-    p_norm = p_norm / np.sum(p_norm)
-    indices = np.random.choice(len(X['input_ids']), num_samples, p=p_norm, replace=False)
-    X_s = {"input_ids": X["input_ids"][indices], "token_type_ids": X["token_type_ids"][indices], "attention_mask": X["attention_mask"][indices]}
-    y_s = y[indices]
-    w_s = y_var[indices][:,0]
-    return X_s, y_s, w_s
-
-
-def sample_by_bald_easiness(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
-
-    logger.info ("Sampling by easy BALD acquisition function")
-    BALD_acq = get_BALD_acquisition(y_T)
-    p_norm = np.maximum(np.zeros(len(BALD_acq)), (1. - BALD_acq)/np.sum(1. - BALD_acq))
-    p_norm = p_norm / np.sum(p_norm)
-    logger.info (p_norm[:10])
-    indices = np.random.choice(len(X['input_ids']), num_samples, p=p_norm, replace=False)
-    X_s = {"input_ids": X["input_ids"][indices], "token_type_ids": X["token_type_ids"][indices], "attention_mask": X["attention_mask"][indices]}
-    y_s = y[indices]
-    w_s = y_var[indices][:,0]
-    return X_s, y_s, w_s
-
-
-def sample_by_bald_class_easiness(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
-
-    logger.info ("Sampling by easy BALD acquisition function per class")
-    BALD_acq = get_BALD_acquisition(y_T)
-    BALD_acq = (1. - BALD_acq)/np.sum(1. - BALD_acq)
-    logger.info (BALD_acq)
-    samples_per_class = num_samples // num_classes
-    X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, X_s_mask_pos, y_s, w_s = [], [], [], [], [], []
-
-    for label in range(num_classes):
-        # X_input_ids, X_token_type_ids, X_attention_mask = np.array(X['input_ids'])[y == label], np.array(X['token_type_ids'])[y == label], np.array(X['attention_mask'])[y == label]
-        X_input_ids, X_attention_mask = np.array(X['input_ids'])[y == label], np.array(X['attention_mask'])[y == label]
-        if "token_type_ids" in X.features:
-            X_token_type_ids = np.array(X['token_type_ids'])[y == label]
-        if "mask_pos" in X.features:
-            X_mask_pos = np.array(X['mask_pos'])[y == label]
-        y_ = y[y==label]
-        y_var_ = y_var[y == label]
-        # p = y_mean[y == label]
-        p_norm = BALD_acq[y==label]
-        p_norm = np.maximum(np.zeros(len(p_norm)), p_norm)
-        p_norm = p_norm/np.sum(p_norm)
-        if len(X_input_ids) < samples_per_class:
-            logger.info ("Sampling with replacement.")
-            replace = True
-        else:
-            replace = False
-        if len(X_input_ids) == 0: # add by wjn
-            continue
-        indices = np.random.choice(len(X_input_ids), samples_per_class, p=p_norm, replace=replace)
-        X_s_input_ids.extend(X_input_ids[indices])
-        # X_s_token_type_ids.extend(X_token_type_ids[indices])
-        X_s_attention_mask.extend(X_attention_mask[indices])
-        if "token_type_ids" in X.features:
-            X_s_token_type_ids.extend(X_token_type_ids[indices])
-        if "mask_pos" in X.features:
-            X_s_mask_pos.extend(X_mask_pos[indices])
-        y_s.extend(y_[indices])
-        w_s.extend(y_var_[indices][:,0])
-
-    # X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s)
-    if "token_type_ids" in X.features and "mask_pos" not in X.features:
-        X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s)
-    elif "token_type_ids" not in X.features and "mask_pos" in X.features:
-        X_s_input_ids, X_s_mask_pos, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_mask_pos, X_s_attention_mask, y_s, w_s)
-    elif "token_type_ids" in X.features and "mask_pos" in X.features:
-        X_s_input_ids, X_s_token_type_ids, X_s_mask_pos, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_token_type_ids, X_s_mask_pos, X_s_attention_mask, y_s, w_s)
-    else:
-        X_s_input_ids, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_attention_mask, y_s, w_s)
-
-    pseudo_labeled_input = {
-        'input_ids': np.array(X_s_input_ids),
-        'attention_mask': np.array(X_s_attention_mask)
-    }
-    if "token_type_ids" in X.features:
-        pseudo_labeled_input['token_type_ids'] = np.array(X_s_token_type_ids)
-    if "mask_pos" in X.features:
-        pseudo_labeled_input['mask_pos'] = np.array(X_s_mask_pos)
-    return pseudo_labeled_input, np.array(y_s), np.array(w_s)
-
-
-def sample_by_bald_class_difficulty(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
-
-    logger.info ("Sampling by difficulty BALD acquisition function per class")
-    BALD_acq = get_BALD_acquisition(y_T)
-    samples_per_class = num_samples // num_classes
-    X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s = [], [], [], [], []
-    for label in range(num_classes):
-        X_input_ids, X_token_type_ids, X_attention_mask = X['input_ids'][y == label], X['token_type_ids'][y == label], X['attention_mask'][y == label]
-        y_ = y[y==label]
-        y_var_ = y_var[y == label]
-        p_norm = BALD_acq[y==label]
-        p_norm = np.maximum(np.zeros(len(p_norm)), p_norm)
-        p_norm = p_norm/np.sum(p_norm)
-        if len(X_input_ids) < samples_per_class:
-            replace = True
-            logger.info ("Sampling with replacement.")
-        else:
-            replace = False
-        indices = np.random.choice(len(X_input_ids), samples_per_class, p=p_norm, replace=replace)
-        X_s_input_ids.extend(X_input_ids[indices])
-        X_s_token_type_ids.extend(X_token_type_ids[indices])
-        X_s_attention_mask.extend(X_attention_mask[indices])
-        y_s.extend(y_[indices])
-        w_s.extend(y_var_[indices][:,0])
-    X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s)
-    return {'input_ids': np.array(X_s_input_ids), 'token_type_ids': np.array(X_s_token_type_ids), 'attention_mask': np.array(X_s_attention_mask)}, np.array(y_s), np.array(w_s)
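Note: get_BALD_acquisition takes y_T of shape [T, N, C], i.e. T stochastic (e.g. MC-dropout) softmax predictions for N examples, and returns the mutual information (BALD score) per example; the sample_by_* helpers then turn it into a sampling distribution. A small check of the acquisition itself (illustrative only, not part of this commit):

import numpy as np

# 3 stochastic forward passes, 2 examples, 2 classes
y_T = np.array([
    [[0.9, 0.1], [0.9, 0.1]],
    [[0.5, 0.5], [0.9, 0.1]],
    [[0.1, 0.9], [0.9, 0.1]],
])
bald = get_BALD_acquisition(y_T)
print(bald)   # the first example (passes disagree) scores higher than the second (passes agree)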
models/tools/processing_utils/common.py
DELETED
@@ -1,38 +0,0 @@
-# -*- coding: utf-8 -*-
-# @Time : 2021/12/2 5:41 p.m.
-# @Author : JianingWang
-# @File : common.py
-
-
-def is_chinese_char(cp):
-    """Checks whether CP is the codepoint of a CJK character."""
-    # This defines a "chinese character" as anything in the CJK Unicode block:
-    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
-    #
-    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
-    # despite its name. The modern Korean Hangul alphabet is a different block,
-    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
-    # space-separated words, so they are not treated specially and handled
-    # like the all of the other languages.
-    if (
-        (0x4E00 <= cp <= 0x9FFF)
-        or (0x3400 <= cp <= 0x4DBF)
-        or (0x20000 <= cp <= 0x2A6DF)
-        or (0x2A700 <= cp <= 0x2B73F)
-        or (0x2B740 <= cp <= 0x2B81F)
-        or (0x2B820 <= cp <= 0x2CEAF)
-        or (0xF900 <= cp <= 0xFAFF)
-        or (0x2F800 <= cp <= 0x2FA1F)
-    ):
-        return True
-
-    return False
-
-
-def is_chinese(word: str):
-    # word like "180" or "身高" or "神"
-    for char in word:
-        char = ord(char)
-        if not is_chinese_char(char):
-            return 0
-    return 1
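Note: is_chinese_char tests a Unicode codepoint against the CJK Unified Ideographs blocks, and is_chinese returns 1 only when every character of the word is CJK. For example (illustrative only, not part of this commit):

print(is_chinese("身高"))           # 1: both characters are CJK ideographs
print(is_chinese("身高180"))        # 0: digits are not CJK characters
print(is_chinese_char(ord("神")))   # True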
models/tools/processing_utils/sampler.py
DELETED
@@ -1,26 +0,0 @@
-# -*- coding: utf-8 -*-
-# @Time : 2021/12/2 5:41 p.m.
-# @Author : JianingWang
-# @File : sampler.py
-
-import numpy as np
-from typing import Optional
-
-"""
-random sampling for each label
-"""
-def random_sampling(raw_datasets, num_examples_per_label: Optional[int]=16):
-    label_list = raw_datasets["label"] # [0, 1, 0, 0, ...]
-    label_dict = dict()
-    # denote index of each label
-    for ei, label in enumerate(label_list):
-        if label not in label_dict.keys():
-            label_dict[label] = list()
-        label_dict[label].append(ei)
-    # random sample k examples of each class
-    few_example_ids = list()
-    for label, eid_list in label_dict.items():
-        idxs = np.random.choice(len(eid_list), size=num_examples_per_label, replace=False)
-        selected_eids = [eid_list[i] for i in idxs]
-        few_example_ids.extend(selected_eids)
-    return few_example_ids
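Note: random_sampling builds a few-shot subset by drawing num_examples_per_label indices per class from the "label" column and returning the selected row indices. A usage sketch on a plain dict (illustrative only, not part of this commit):

raw_datasets = {"label": [0, 1, 0, 0, 1, 1, 0, 1]}
ids = random_sampling(raw_datasets, num_examples_per_label=2)
print(sorted(ids))   # four indices in total: two rows with label 0 and two with label 1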
models/tools/processing_utils/tokenizer/JiebaTokenizer.py
DELETED
@@ -1,24 +0,0 @@
-# -*- coding: utf-8 -*-
-# @Time : 2021/12/8 12:07 a.m.
-# @Author : JianingWang
-# @File : JiebaTokenizer
-
-import jieba
-from transformers import BertTokenizer
-
-
-class JiebaTokenizer(BertTokenizer):
-    def __init__(
-        self, pre_tokenizer=lambda x: jieba.cut(x, HMM=False), *args, **kwargs
-    ):
-        super().__init__(*args, **kwargs)
-        self.pre_tokenizer = pre_tokenizer
-
-    def _tokenize(self, text, *arg, **kwargs):
-        split_tokens = []
-        for text in self.pre_tokenizer(text):
-            if text in self.vocab:
-                split_tokens.append(text)
-            else:
-                split_tokens.extend(super()._tokenize(text))
-        return split_tokens
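Note: JiebaTokenizer pre-segments Chinese text with jieba and only falls back to the parent BertTokenizer's WordPiece when a segmented word is not in the vocabulary. A usage sketch (the checkpoint name is illustrative and assumes a vocabulary that contains whole jieba words; not part of this commit):

tokenizer = JiebaTokenizer.from_pretrained("bert-base-chinese")
print(tokenizer.tokenize("今天天气很好"))
# jieba words are kept whole when present in the vocab, otherwise split into WordPiece sub-tokens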
models/tools/processing_utils/tokenizer/__init__.py
DELETED
@@ -1,4 +0,0 @@
-# -*- coding: utf-8 -*-
-# @Time : 2021/12/8 12:07 a.m.
-# @Author : JianingWang
-# @File : __init__.py
models/tools/processing_utils/tokenizer/tokenizer_utils.py
DELETED
@@ -1,19 +0,0 @@
-from transformers import AutoTokenizer
-
-"""
-obtain special tokens
-"""
-def get_special_token_mapping(tokenizer: AutoTokenizer):
-    if "t5" in type(tokenizer).__name__.lower():
-        special_token_mapping = {
-            "cls": 3, "mask": 32099, "sep": tokenizer.eos_token_id,
-            "sep+": tokenizer.eos_token_id,
-            "pseudo_token": tokenizer.unk_token_id
-        }
-    else:
-        special_token_mapping = {
-            "cls": tokenizer.cls_token_id, "mask": tokenizer.mask_token_id, "sep": tokenizer.sep_token_id,
-            "sep+": tokenizer.sep_token_id,
-            "pseudo_token": tokenizer.unk_token_id
-        }
-    return special_token_mapping
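Note: get_special_token_mapping hard-codes the T5 sentinel ids (cls=3, mask=32099, i.e. <extra_id_0>) and otherwise reads the ids from the tokenizer itself. For a BERT-style tokenizer it resolves roughly as follows (illustrative only, not part of this commit):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
mapping = get_special_token_mapping(tok)
print(mapping["cls"], mapping["mask"], mapping["sep"])   # 101 103 102 for bert-base-uncased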
models/tools/runner_utils/__init__.py
DELETED
File without changes
models/tools/runner_utils/__pycache__/__init__.cpython-38.pyc
DELETED
Binary file (140 Bytes)
models/tools/runner_utils/__pycache__/log_util.cpython-38.pyc
DELETED
Binary file (969 Bytes)
models/tools/runner_utils/conifg_extensive.py
DELETED
@@ -1,15 +0,0 @@
-from transformers import AutoConfig
-from config import ModelArguments
-
-
-# add external config.
-def config_extensive(hf_config: AutoConfig, model_config: ModelArguments):
-    hf_config.use_prompt_for_cls = model_config.use_prompt_for_cls
-    hf_config.use_freezing = model_config.use_freezing
-    hf_config.adapter_choice = model_config.adapter_choice
-    hf_config.adapter_dim = model_config.adapter_dim
-    hf_config.pre_seq_len = model_config.pre_seq_len
-    hf_config.prefix_projection = model_config.prefix_projection
-    hf_config.prefix_hidden_size = model_config.prefix_hidden_size
-    hf_config.hidden_dropout_prob = model_config.hidden_dropout_prob
-    return hf_config
models/tools/runner_utils/log_util.py
DELETED
@@ -1,30 +0,0 @@
-import sys
-import logging
-import datasets
-import transformers
-
-
-def init_logger(log_file, log_level, dist_rank):
-    datasets.utils.logging.set_verbosity(log_level)
-    transformers.utils.logging.set_verbosity(log_level)
-    transformers.utils.logging.enable_default_handler()
-    transformers.utils.logging.enable_explicit_format()
-    datasets.utils.logging.disable_propagation()
-    # transformers.utils.logging.enable_propagation()
-
-    logger = logging.getLogger("")
-    log_format = logging.Formatter(fmt="[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-    logger.setLevel(log_level)
-    console_handler = logging.StreamHandler(sys.stderr)
-    console_handler.setFormatter(log_format)
-    logger.addHandler(console_handler)
-    # transformer_logger = logging.getLogger("transformers")
-    # transformer_logger.handlers = []
-    # transformer_logger.propagate = True
-
-    if dist_rank in [-1, 0]:
-        file_handler = logging.FileHandler(log_file, mode="a")
-        file_handler.setLevel(log_level)
-        file_handler.setFormatter(log_format)
-        logger.addHandler(file_handler)
-        logging.getLogger("transformers").addHandler(file_handler)
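Note: init_logger attaches a console handler on every process and a file handler only on rank -1/0, so multi-GPU runs write the log file once. A usage sketch (the path and level are illustrative, not part of this commit):

import logging

init_logger("train.log", logging.INFO, dist_rank=0)
logging.getLogger(__name__).info("logging initialised")   # goes to stderr and, on rank 0, to train.log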
models/tools/runner_utils/retrying.py
DELETED
@@ -1,288 +0,0 @@
-# -*- coding: utf-8 -*-
-# @Time : 2021/12/24 4:05 p.m.
-# @Author : JianingWang
-# @File : retrying.py
-
-import random
-import six
-import sys
-import time
-import traceback
-
-
-MAX_WAIT = 1073741823
-
-
-def _retry_if_exception_of_type(retryable_types):
-    def _retry_if_exception_these_types(exception):
-        return isinstance(exception, retryable_types)
-    return _retry_if_exception_these_types
-
-
-def retry(*dargs, **dkw):
-    """
-    Decorator function that instantiates the Retrying object
-    @param *dargs: positional arguments passed to Retrying object
-    @param **dkw: keyword arguments passed to the Retrying object
-    """
-    # support both @retry and @retry() as valid syntax
-    if len(dargs) == 1 and callable(dargs[0]):
-        def wrap_simple(f):
-
-            @six.wraps(f)
-            def wrapped_f(*args, **kw):
-                return Retrying().call(f, *args, **kw)
-
-            return wrapped_f
-
-        return wrap_simple(dargs[0])
-
-    else:
-        def wrap(f):
-
-            @six.wraps(f)
-            def wrapped_f(*args, **kw):
-                return Retrying(*dargs, **dkw).call(f, *args, **kw)
-
-            return wrapped_f
-
-        return wrap
-
-
-class Retrying(object):
-
-    def __init__(self,
-                 stop=None, wait=None,
-                 stop_max_attempt_number=None,
-                 stop_max_delay=None,
-                 wait_fixed=None,
-                 wait_random_min=None, wait_random_max=None,
-                 wait_incrementing_start=None, wait_incrementing_increment=None,
-                 wait_incrementing_max=None,
-                 wait_exponential_multiplier=None, wait_exponential_max=None,
-                 retry_on_exception=None,
-                 retry_on_result=None,
-                 wrap_exception=False,
-                 stop_func=None,
-                 wait_func=None,
-                 wait_jitter_max=None,
-                 before_attempts=None,
-                 after_attempts=None,
-                 skip_raise=False):
-
-        self._stop_max_attempt_number = 5 if stop_max_attempt_number is None else stop_max_attempt_number
-        self._stop_max_delay = 100 if stop_max_delay is None else stop_max_delay
-        self._wait_fixed = 1000 if wait_fixed is None else wait_fixed
-        self._wait_random_min = 0 if wait_random_min is None else wait_random_min
-        self._wait_random_max = 1000 if wait_random_max is None else wait_random_max
-        self._wait_incrementing_start = 0 if wait_incrementing_start is None else wait_incrementing_start
-        self._wait_incrementing_increment = 100 if wait_incrementing_increment is None else wait_incrementing_increment
-        self._wait_exponential_multiplier = 1 if wait_exponential_multiplier is None else wait_exponential_multiplier
-        self._wait_exponential_max = MAX_WAIT if wait_exponential_max is None else wait_exponential_max
-        self._wait_incrementing_max = MAX_WAIT if wait_incrementing_max is None else wait_incrementing_max
-        self._wait_jitter_max = 0 if wait_jitter_max is None else wait_jitter_max
-        self._before_attempts = before_attempts
-        self._after_attempts = after_attempts
-        self._skip_raise = skip_raise
-
-        # stop behavior
-        stop_funcs = []
-        if stop_max_attempt_number is not None:
-            stop_funcs.append(self.stop_after_attempt)
-
-        if stop_max_delay is not None:
-            stop_funcs.append(self.stop_after_delay)
-
-        if stop_func is not None:
-            self.stop = stop_func
-
-        elif stop is None:
-            self.stop = lambda attempts, delay: any(f(attempts, delay) for f in stop_funcs)
-
-        else:
-            self.stop = getattr(self, stop)
-
-        # wait behavior
-        wait_funcs = [lambda *args, **kwargs: 0]
-        if wait_fixed is not None:
-            wait_funcs.append(self.fixed_sleep)
-
-        if wait_random_min is not None or wait_random_max is not None:
-            wait_funcs.append(self.random_sleep)
-
-        if wait_incrementing_start is not None or wait_incrementing_increment is not None:
-            wait_funcs.append(self.incrementing_sleep)
-
-        if wait_exponential_multiplier is not None or wait_exponential_max is not None:
-            wait_funcs.append(self.exponential_sleep)
-
-        if wait_func is not None:
-            self.wait = wait_func
-
-        elif wait is None:
-            self.wait = lambda attempts, delay: max(f(attempts, delay) for f in wait_funcs)
-
-        else:
-            self.wait = getattr(self, wait)
-
-        # retry on exception filter
-        if retry_on_exception is None:
-            self._retry_on_exception = self.always_reject
-        else:
-            # this allows for providing a tuple of exception types that
-            # should be allowed to retry on, and avoids having to create
-            # a callback that does the same thing
-            if isinstance(retry_on_exception, (tuple)):
-                retry_on_exception = _retry_if_exception_of_type(
-                    retry_on_exception)
-            self._retry_on_exception = retry_on_exception
-
-        # retry on result filter
-        if retry_on_result is None:
-            self._retry_on_result = self.never_reject
-        else:
-            self._retry_on_result = retry_on_result
-
-        self._wrap_exception = wrap_exception
-
-    def stop_after_attempt(self, previous_attempt_number, delay_since_first_attempt_ms):
-        """Stop after the previous attempt >= stop_max_attempt_number."""
-        return previous_attempt_number >= self._stop_max_attempt_number
-
-    def stop_after_delay(self, previous_attempt_number, delay_since_first_attempt_ms):
-        """Stop after the time from the first attempt >= stop_max_delay."""
-        return delay_since_first_attempt_ms >= self._stop_max_delay
-
-    @staticmethod
-    def no_sleep(previous_attempt_number, delay_since_first_attempt_ms):
-        """Don"t sleep at all before retrying."""
-        return 0
-
-    def fixed_sleep(self, previous_attempt_number, delay_since_first_attempt_ms):
-        """Sleep a fixed amount of time between each retry."""
-        return self._wait_fixed
-
-    def random_sleep(self, previous_attempt_number, delay_since_first_attempt_ms):
-        """Sleep a random amount of time between wait_random_min and wait_random_max"""
-        return random.randint(self._wait_random_min, self._wait_random_max)
-
-    def incrementing_sleep(self, previous_attempt_number, delay_since_first_attempt_ms):
-        """
-        Sleep an incremental amount of time after each attempt, starting at
-        wait_incrementing_start and incrementing by wait_incrementing_increment
-        """
-        result = self._wait_incrementing_start + (self._wait_incrementing_increment * (previous_attempt_number - 1))
-        if result > self._wait_incrementing_max:
-            result = self._wait_incrementing_max
-        if result < 0:
-            result = 0
-        return result
-
-    def exponential_sleep(self, previous_attempt_number, delay_since_first_attempt_ms):
-        exp = 2 ** previous_attempt_number
-        result = self._wait_exponential_multiplier * exp
-        if result > self._wait_exponential_max:
-            result = self._wait_exponential_max
-        if result < 0:
-            result = 0
-        return result
-
-    @staticmethod
-    def never_reject(result):
-        return False
-
-    @staticmethod
-    def always_reject(result):
-        return True
-
-    def should_reject(self, attempt):
-        reject = False
-        if attempt.has_exception:
-            reject |= self._retry_on_exception(attempt.value[1])
-        else:
-            reject |= self._retry_on_result(attempt.value)
-
-        return reject
-
-    def call(self, fn, *args, **kwargs):
-        start_time = int(round(time.time() * 1000))
-        attempt_number = 1
-        while True:
-            if self._before_attempts:
-                self._before_attempts(attempt_number)
-
-            try:
-                attempt = Attempt(fn(*args, **kwargs), attempt_number, False)
-            except:
-                tb = sys.exc_info()
-                attempt = Attempt(tb, attempt_number, True)
-
-            if not self.should_reject(attempt):
-                return attempt.get(self._wrap_exception)
-
-            if self._after_attempts:
-                self._after_attempts(attempt_number)
-
-            delay_since_first_attempt_ms = int(round(time.time() * 1000)) - start_time
-            if self.stop(attempt_number, delay_since_first_attempt_ms):
-                if not self._wrap_exception and attempt.has_exception:
-                    # get() on an attempt with an exception should cause it to be raised, but raise just in case
-                    if not self._skip_raise:
-                        raise attempt.get()
-                    else:
-                        break
-                else:
-                    raise RetryError(attempt)
-            else:
-                sleep = self.wait(attempt_number, delay_since_first_attempt_ms)
-                if self._wait_jitter_max:
-                    jitter = random.random() * self._wait_jitter_max
-                    sleep = sleep + max(0, jitter)
-                time.sleep(sleep / 1000.0)
-
-            attempt_number += 1
-
-
-class Attempt(object):
-    """
-    An Attempt encapsulates a call to a target function that may end as a
-    normal return value from the function or an Exception depending on what
|
| 250 |
-
occurred during the execution.
|
| 251 |
-
"""
|
| 252 |
-
|
| 253 |
-
def __init__(self, value, attempt_number, has_exception):
|
| 254 |
-
self.value = value
|
| 255 |
-
self.attempt_number = attempt_number
|
| 256 |
-
self.has_exception = has_exception
|
| 257 |
-
|
| 258 |
-
def get(self, wrap_exception=False):
|
| 259 |
-
"""
|
| 260 |
-
Return the return value of this Attempt instance or raise an Exception.
|
| 261 |
-
If wrap_exception is true, this Attempt is wrapped inside of a
|
| 262 |
-
RetryError before being raised.
|
| 263 |
-
"""
|
| 264 |
-
if self.has_exception:
|
| 265 |
-
if wrap_exception:
|
| 266 |
-
raise RetryError(self)
|
| 267 |
-
else:
|
| 268 |
-
six.reraise(self.value[0], self.value[1], self.value[2])
|
| 269 |
-
else:
|
| 270 |
-
return self.value
|
| 271 |
-
|
| 272 |
-
def __repr__(self):
|
| 273 |
-
if self.has_exception:
|
| 274 |
-
return "Attempts: {0}, Error:\n{1}".format(self.attempt_number, "".join(traceback.format_tb(self.value[2])))
|
| 275 |
-
else:
|
| 276 |
-
return "Attempts: {0}, Value: {1}".format(self.attempt_number, self.value)
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
class RetryError(Exception):
|
| 280 |
-
"""
|
| 281 |
-
A RetryError encapsulates the last Attempt instance right before giving up.
|
| 282 |
-
"""
|
| 283 |
-
|
| 284 |
-
def __init__(self, last_attempt):
|
| 285 |
-
self.last_attempt = last_attempt
|
| 286 |
-
|
| 287 |
-
def __str__(self):
|
| 288 |
-
return "RetryError[{0}]".format(self.last_attempt)
models/tools/runner_utils/set_seed.py
DELETED
@@ -1,21 +0,0 @@
import torch
import random
import numpy as np

from transformers.utils import (
    is_tf_available,
    is_torch_available,
)

def set_seed(seed_value: int):
    """
    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed).

    Args:
        seed (`int`): The seed to set.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    if is_torch_available():
        torch.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
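
A minimal usage sketch for the deleted `set_seed` helper (hypothetical, for illustration only): call it once at startup, before building the model and dataloaders, so that `random`, `numpy`, and `torch` all draw from the same seed.

```python
import numpy as np

# Hypothetical usage sketch (not part of the repository).
set_seed(42)              # seeds random, numpy, and torch (CPU and all GPUs)
print(np.random.rand(3))  # deterministic across runs with the same seed
```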
models/tools/runner_utils/timecost.py
DELETED
@@ -1,20 +0,0 @@
# -*- coding: utf-8 -*-
# @Time : 2022/3/11 3:06 p.m.
# @Author : JianingWang
# @File : time

import time
import logging

logger = logging.getLogger(__name__)


def timecost(method):
    def timed(*args, **kw):
        ts = time.time()
        result = method(*args, **kw)
        te = time.time()
        logger.info("%r %2.2f ms" % (method.__name__, (te - ts) * 1000))
        return result

    return timed
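
A minimal usage sketch for the deleted `timecost` decorator (hypothetical, for illustration only). The timing message goes through the module logger, so logging must be configured at INFO level or lower for it to appear.

```python
import logging
import time

logging.basicConfig(level=logging.INFO)  # make the INFO-level timing message visible

@timecost
def slow_add(a, b):
    time.sleep(0.1)  # simulate work
    return a + b

slow_add(1, 2)  # logs roughly: 'slow_add' 100.12 ms
```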