Spaces:
Sleeping
Sleeping
fix: decode
Browse files
- modules/ai_model.py +7 -18
modules/ai_model.py
CHANGED
|
@@ -163,23 +163,16 @@ class AIModel:
|
|
| 163 |
return_tensors="pt"
|
| 164 |
).to(self.model.device, dtype=torch.bfloat16)
|
| 165 |
|
| 166 |
-
|
| 167 |
-
if hasattr(inputs, 'input_ids') and inputs.input_ids.shape[-1] > 512:
|
| 168 |
-
log.warning(f"⚠️ 截断过长输入: {inputs.input_ids.shape[-1]} -> 512")
|
| 169 |
-
inputs.input_ids = inputs.input_ids[:, :512]
|
| 170 |
-
if hasattr(inputs, 'attention_mask'):
|
| 171 |
-
inputs.attention_mask = inputs.attention_mask[:, :512]
|
| 172 |
-
|
| 173 |
-
# --- 这是关键的修改 ---
|
| 174 |
with torch.inference_mode():
|
| 175 |
generation_args = {
|
| 176 |
-
"max_new_tokens":
|
| 177 |
"pad_token_id": self.processor.tokenizer.eos_token_id,
|
| 178 |
"use_cache": True
|
| 179 |
}
|
| 180 |
|
| 181 |
# 如果 temperature 接近0,使用贪心解码 (用于分类等确定性任务)
|
| 182 |
-
if temperature < 1e-6:
|
| 183 |
log.info("▶️ 使用贪心解码 (do_sample=False) 以获得确定性输出。")
|
| 184 |
generation_args["do_sample"] = False
|
| 185 |
# 否则,使用采样解码 (用于创造性生成任务)
|
|
@@ -194,15 +187,11 @@ class AIModel:
|
|
| 194 |
**inputs,
|
| 195 |
**generation_args
|
| 196 |
)
|
| 197 |
-
|
| 198 |
-
decoded = self.processor.tokenizer.decode(
|
| 199 |
-
|
| 200 |
-
# 移除prompt部分
|
| 201 |
-
if prompt in decoded:
|
| 202 |
-
decoded = decoded.replace(prompt, "").strip()
|
| 203 |
-
|
| 204 |
return decoded if decoded else "我理解了您的问题,请告诉我更多具体信息。"
|
| 205 |
-
|
| 206 |
except RuntimeError as e:
|
| 207 |
if "shape" in str(e):
|
| 208 |
log.error(f"❌ Tensor形状错误: {e}")
|
|
|
|
| 163 |
return_tensors="pt"
|
| 164 |
).to(self.model.device, dtype=torch.bfloat16)
|
| 165 |
|
| 166 |
+
input_len = inputs.input_ids.shape[-1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
with torch.inference_mode():
|
| 168 |
generation_args = {
|
| 169 |
+
"max_new_tokens": 512,
|
| 170 |
"pad_token_id": self.processor.tokenizer.eos_token_id,
|
| 171 |
"use_cache": True
|
| 172 |
}
|
| 173 |
|
| 174 |
# 如果 temperature 接近0,使用贪心解码 (用于分类等确定性任务)
|
| 175 |
+
if temperature < 1e-6:
|
| 176 |
log.info("▶️ 使用贪心解码 (do_sample=False) 以获得确定性输出。")
|
| 177 |
generation_args["do_sample"] = False
|
| 178 |
# 否则,使用采样解码 (用于创造性生成任务)
|
|
|
|
| 187 |
**inputs,
|
| 188 |
**generation_args
|
| 189 |
)
|
| 190 |
+
generated_tokens = outputs[0][input_len:]
|
| 191 |
+
decoded = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
|
| 192 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
return decoded if decoded else "我理解了您的问题,请告诉我更多具体信息。"
|
| 194 |
+
|
| 195 |
except RuntimeError as e:
|
| 196 |
if "shape" in str(e):
|
| 197 |
log.error(f"❌ Tensor形状错误: {e}")
|