Eliot0110 commited on
Commit
8d69a10
·
1 Parent(s): 96512ae

fix: decode

Browse files
Files changed (1) hide show
  1. modules/ai_model.py +7 -18
modules/ai_model.py CHANGED
@@ -163,23 +163,16 @@ class AIModel:
163
  return_tensors="pt"
164
  ).to(self.model.device, dtype=torch.bfloat16)
165
 
166
- # 截断过长的 token
167
- if hasattr(inputs, 'input_ids') and inputs.input_ids.shape[-1] > 512:
168
- log.warning(f"⚠️ 截断过长输入: {inputs.input_ids.shape[-1]} -> 512")
169
- inputs.input_ids = inputs.input_ids[:, :512]
170
- if hasattr(inputs, 'attention_mask'):
171
- inputs.attention_mask = inputs.attention_mask[:, :512]
172
-
173
- # --- 这是关键的修改 ---
174
  with torch.inference_mode():
175
  generation_args = {
176
- "max_new_tokens": 256,
177
  "pad_token_id": self.processor.tokenizer.eos_token_id,
178
  "use_cache": True
179
  }
180
 
181
  # 如果 temperature 接近0,使用贪心解码 (用于分类等确定性任务)
182
- if temperature < 1e-6: # 使用一个很小的数来比较浮点数
183
  log.info("▶️ 使用贪心解码 (do_sample=False) 以获得确定性输出。")
184
  generation_args["do_sample"] = False
185
  # 否则,使用采样解码 (用于创造性生成任务)
@@ -194,15 +187,11 @@ class AIModel:
194
  **inputs,
195
  **generation_args
196
  )
197
-
198
- decoded = self.processor.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
199
-
200
- # 移除prompt部分
201
- if prompt in decoded:
202
- decoded = decoded.replace(prompt, "").strip()
203
-
204
  return decoded if decoded else "我理解了您的问题,请告诉我更多具体信息。"
205
-
206
  except RuntimeError as e:
207
  if "shape" in str(e):
208
  log.error(f"❌ Tensor形状错误: {e}")
 
163
  return_tensors="pt"
164
  ).to(self.model.device, dtype=torch.bfloat16)
165
 
166
+ input_len = inputs.input_ids.shape[-1]
 
 
 
 
 
 
 
167
  with torch.inference_mode():
168
  generation_args = {
169
+ "max_new_tokens": 512,
170
  "pad_token_id": self.processor.tokenizer.eos_token_id,
171
  "use_cache": True
172
  }
173
 
174
  # 如果 temperature 接近0,使用贪心解码 (用于分类等确定性任务)
175
+ if temperature < 1e-6:
176
  log.info("▶️ 使用贪心解码 (do_sample=False) 以获得确定性输出。")
177
  generation_args["do_sample"] = False
178
  # 否则,使用采样解码 (用于创造性生成任务)
 
187
  **inputs,
188
  **generation_args
189
  )
190
+ generated_tokens = outputs[0][input_len:]
191
+ decoded = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
192
+
 
 
 
 
193
  return decoded if decoded else "我理解了您的问题,请告诉我更多具体信息。"
194
+
195
  except RuntimeError as e:
196
  if "shape" in str(e):
197
  log.error(f"❌ Tensor形状错误: {e}")