Revert "add suggestions cleanning"
Browse files
This reverts commit a94f020eba50b5aec3746c327d42f6b8c6344ac1.
app.py
CHANGED
|
@@ -10,7 +10,6 @@ from termcolor import cprint
|
|
| 10 |
|
| 11 |
# 初始化簡體到繁體轉換器
|
| 12 |
cc = OpenCC('s2t')
|
| 13 |
-
tokenizer = None
|
| 14 |
|
| 15 |
# 可選模型列表
|
| 16 |
MODEL_LIST = [
|
|
@@ -28,82 +27,10 @@ MODEL_LIST = [
|
|
| 28 |
"Epiculous/Violet_Twilight-v0.2",
|
| 29 |
]
|
| 30 |
|
| 31 |
-
def clean_suggestions(suggestions: list[str], max_levels: int) -> list[str]:
|
| 32 |
-
"""
|
| 33 |
-
清洗建议列表:
|
| 34 |
-
1. 对每条建议用 tokenizer.tokenize 得到 token 序列。
|
| 35 |
-
2. 构建前缀树,将所有 token 序列插入。
|
| 36 |
-
3. 遍历前缀树,仅在深度 <= max_levels 且该节点有子节点时,提取对应 token 前缀。
|
| 37 |
-
4. 将这些 token 前缀转换回文本并去重,返回列表。
|
| 38 |
-
"""
|
| 39 |
-
# 定义 Trie 节点结构
|
| 40 |
-
class TrieNode:
|
| 41 |
-
__slots__ = ("children", "count")
|
| 42 |
-
def __init__(self):
|
| 43 |
-
self.children: dict[str, TrieNode] = {}
|
| 44 |
-
self.count: int = 0 # 可以记录有多少序列经过此节点(可选)
|
| 45 |
-
|
| 46 |
-
# 构建前缀树
|
| 47 |
-
root = TrieNode()
|
| 48 |
-
token_seqs: list[list[str]] = []
|
| 49 |
-
|
| 50 |
-
for text in suggestions:
|
| 51 |
-
# tokenizer.tokenize 可能返回子词 token 列表
|
| 52 |
-
try:
|
| 53 |
-
toks = tokenizer.tokenize(text)
|
| 54 |
-
except Exception:
|
| 55 |
-
# 如果 tokenizer 不支持直接 tokenize raw text,可以先用 basic tokenization,如按空白分割
|
| 56 |
-
toks = text.split()
|
| 57 |
-
if not toks:
|
| 58 |
-
continue
|
| 59 |
-
token_seqs.append(toks)
|
| 60 |
-
node = root
|
| 61 |
-
node.count += 1
|
| 62 |
-
for tok in toks:
|
| 63 |
-
if tok not in node.children:
|
| 64 |
-
node.children[tok] = TrieNode()
|
| 65 |
-
node = node.children[tok]
|
| 66 |
-
node.count += 1
|
| 67 |
-
|
| 68 |
-
# 遍历 Trie,收集深度 <= max_levels 且有子节点的前缀序列
|
| 69 |
-
results_prefix_tokens: set[tuple[str, ...]] = set()
|
| 70 |
-
|
| 71 |
-
def dfs(node: TrieNode, path: list[str], depth: int):
|
| 72 |
-
# node: 当前 TrieNode; path: 已走过的 token 列表; depth: len(path)
|
| 73 |
-
if depth > max_levels:
|
| 74 |
-
return
|
| 75 |
-
# 如果当前节点有子节点,且 depth>0 (排除根节点本身),则为一个候选前缀
|
| 76 |
-
if depth > 0 and node.children:
|
| 77 |
-
results_prefix_tokens.add(tuple(path))
|
| 78 |
-
# 继续往下遍历,直到 depth == max_levels
|
| 79 |
-
if depth == max_levels:
|
| 80 |
-
return
|
| 81 |
-
for tok, child in node.children.items():
|
| 82 |
-
path.append(tok)
|
| 83 |
-
dfs(child, path, depth + 1)
|
| 84 |
-
path.pop()
|
| 85 |
-
|
| 86 |
-
dfs(root, [], 0)
|
| 87 |
-
|
| 88 |
-
# 将 token 前缀转换回字符串
|
| 89 |
-
cleaned: set[str] = set()
|
| 90 |
-
for tok_prefix in results_prefix_tokens:
|
| 91 |
-
try:
|
| 92 |
-
# tokenizer.convert_tokens_to_string 在大多数 tokenizer 支持
|
| 93 |
-
text_pref = tokenizer.convert_tokens_to_string(list(tok_prefix)).strip()
|
| 94 |
-
except Exception:
|
| 95 |
-
# fallback: 直接拼接 token(可能需要根据 tokenizer 规范加空格或直接连起来)
|
| 96 |
-
text_pref = "".join(tok_prefix).strip()
|
| 97 |
-
if text_pref:
|
| 98 |
-
cleaned.add(text_pref)
|
| 99 |
-
|
| 100 |
-
# 返回去重之后的列表
|
| 101 |
-
return list(cleaned)
|
| 102 |
|
| 103 |
@lru_cache(maxsize=8)
|
| 104 |
def get_pipeline(model_name):
|
| 105 |
-
|
| 106 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 107 |
mdl = AutoModelForCausalLM.from_pretrained(
|
| 108 |
model_name, weights_only=False, trust_remote_code=True
|
| 109 |
)
|
|
@@ -111,10 +38,10 @@ def get_pipeline(model_name):
|
|
| 111 |
mdl.to("cuda")
|
| 112 |
except Exception as e:
|
| 113 |
print(f'Error: {e}')
|
| 114 |
-
return pipeline("text-generation", model=mdl, tokenizer=
|
| 115 |
|
| 116 |
@spaces.GPU
|
| 117 |
-
def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty
|
| 118 |
"""
|
| 119 |
使用 Diverse Beam Search 產生 m 條候選:
|
| 120 |
- num_beams = m
|
|
@@ -131,7 +58,7 @@ def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty, max
|
|
| 131 |
"early_stopping": True,
|
| 132 |
}
|
| 133 |
if diversity_penalty and diversity_penalty > 0:
|
| 134 |
-
valid_group =
|
| 135 |
gen_kwargs["num_beam_groups"] = valid_group
|
| 136 |
gen_kwargs["diversity_penalty"] = float(diversity_penalty)
|
| 137 |
|
|
@@ -146,7 +73,6 @@ def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty, max
|
|
| 146 |
converted = cc.convert(snippet).strip()
|
| 147 |
suggestions.add(converted)
|
| 148 |
suggestions = list(suggestions)
|
| 149 |
-
suggestions = clean_suggestions(suggestions, max_prefix_levels)
|
| 150 |
|
| 151 |
return update(choices=suggestions, value=None)
|
| 152 |
|
|
@@ -269,10 +195,6 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 269 |
minimum=0.0, maximum=2.0, step=0.1, value=1.0,
|
| 270 |
label="多樣性懲罰 (diversity_penalty)"
|
| 271 |
)
|
| 272 |
-
prefix_levels_slider = gr.Slider(
|
| 273 |
-
minimum=1, maximum=5, step=1, value=2,
|
| 274 |
-
label="Clean 前綴深度 (max_levels)"
|
| 275 |
-
)
|
| 276 |
|
| 277 |
# 綁定事件
|
| 278 |
predict_button.click(
|
|
@@ -283,14 +205,13 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 283 |
k_slider,
|
| 284 |
m_slider,
|
| 285 |
group_slider,
|
| 286 |
-
diversity_penalty_slider
|
| 287 |
-
prefix_levels_slider # 新增
|
| 288 |
],
|
| 289 |
outputs=suggestions,
|
| 290 |
)
|
| 291 |
input_text.change(
|
| 292 |
-
fn=lambda txt, mdl, k, m, g, d, auto
|
| 293 |
-
suggest_next(txt, mdl, k, m, g, d
|
| 294 |
if auto else update(choices=[], value=None)
|
| 295 |
),
|
| 296 |
inputs=[
|
|
@@ -300,8 +221,7 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 300 |
m_slider,
|
| 301 |
group_slider,
|
| 302 |
diversity_penalty_slider,
|
| 303 |
-
auto_predict
|
| 304 |
-
prefix_levels_slider # 新增
|
| 305 |
],
|
| 306 |
outputs=suggestions,
|
| 307 |
)
|
|
|
|
| 10 |
|
| 11 |
# 初始化簡體到繁體轉換器
|
| 12 |
cc = OpenCC('s2t')
|
|
|
|
| 13 |
|
| 14 |
# 可選模型列表
|
| 15 |
MODEL_LIST = [
|
|
|
|
| 27 |
"Epiculous/Violet_Twilight-v0.2",
|
| 28 |
]
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
@lru_cache(maxsize=8)
|
| 32 |
def get_pipeline(model_name):
|
| 33 |
+
tok = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
| 34 |
mdl = AutoModelForCausalLM.from_pretrained(
|
| 35 |
model_name, weights_only=False, trust_remote_code=True
|
| 36 |
)
|
|
|
|
| 38 |
mdl.to("cuda")
|
| 39 |
except Exception as e:
|
| 40 |
print(f'Error: {e}')
|
| 41 |
+
return pipeline("text-generation", model=mdl, tokenizer=tok, device=0)
|
| 42 |
|
| 43 |
@spaces.GPU
|
| 44 |
+
def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
|
| 45 |
"""
|
| 46 |
使用 Diverse Beam Search 產生 m 條候選:
|
| 47 |
- num_beams = m
|
|
|
|
| 58 |
"early_stopping": True,
|
| 59 |
}
|
| 60 |
if diversity_penalty and diversity_penalty > 0:
|
| 61 |
+
valid_group = gcd(m, num_beam_groups)
|
| 62 |
gen_kwargs["num_beam_groups"] = valid_group
|
| 63 |
gen_kwargs["diversity_penalty"] = float(diversity_penalty)
|
| 64 |
|
|
|
|
| 73 |
converted = cc.convert(snippet).strip()
|
| 74 |
suggestions.add(converted)
|
| 75 |
suggestions = list(suggestions)
|
|
|
|
| 76 |
|
| 77 |
return update(choices=suggestions, value=None)
|
| 78 |
|
|
|
|
| 195 |
minimum=0.0, maximum=2.0, step=0.1, value=1.0,
|
| 196 |
label="多樣性懲罰 (diversity_penalty)"
|
| 197 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
# 綁定事件
|
| 200 |
predict_button.click(
|
|
|
|
| 205 |
k_slider,
|
| 206 |
m_slider,
|
| 207 |
group_slider,
|
| 208 |
+
diversity_penalty_slider
|
|
|
|
| 209 |
],
|
| 210 |
outputs=suggestions,
|
| 211 |
)
|
| 212 |
input_text.change(
|
| 213 |
+
fn=lambda txt, mdl, k, m, g, d, auto: (
|
| 214 |
+
suggest_next(txt, mdl, k, m, g, d)
|
| 215 |
if auto else update(choices=[], value=None)
|
| 216 |
),
|
| 217 |
inputs=[
|
|
|
|
| 221 |
m_slider,
|
| 222 |
group_slider,
|
| 223 |
diversity_penalty_slider,
|
| 224 |
+
auto_predict
|
|
|
|
| 225 |
],
|
| 226 |
outputs=suggestions,
|
| 227 |
)
|