hamxaameer committed
Commit 2711df5 · verified · 1 Parent(s): b7bd99f

Update app.py

Files changed (1)
  1. app.py +85 -12
app.py CHANGED
@@ -147,6 +147,44 @@ Then upload 'best_model_cpu.pkl' to this Space and rename it to 'best_model.pkl'
         loaded_model.eval()
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         loaded_model = loaded_model.to(device)
+
+        # Fix generation config compatibility issues
+        if hasattr(loaded_model, 'generation_config'):
+            gen_config = loaded_model.generation_config
+
+            # Remove problematic attributes that don't exist in current transformers version
+            problematic_attrs = [
+                'forced_decoder_ids', 'forced_bos_token_id', 'forced_eos_token_id',
+                'suppress_tokens', 'begin_suppress_tokens', 'decoder_start_token_id'
+            ]
+
+            for attr in problematic_attrs:
+                if hasattr(gen_config, attr):
+                    try:
+                        delattr(gen_config, attr)
+                    except:
+                        pass
+
+            # Ensure required attributes exist with safe defaults
+            if not hasattr(gen_config, 'pad_token_id') or gen_config.pad_token_id is None:
+                gen_config.pad_token_id = loaded_tokenizer.eos_token_id if loaded_tokenizer else 50256
+
+            if not hasattr(gen_config, 'eos_token_id') or gen_config.eos_token_id is None:
+                gen_config.eos_token_id = loaded_tokenizer.eos_token_id if loaded_tokenizer else 50256
+
+            if not hasattr(gen_config, 'bos_token_id'):
+                gen_config.bos_token_id = loaded_tokenizer.bos_token_id if loaded_tokenizer else 50256
+
+        else:
+            # Create a basic generation config if missing
+            from transformers import GenerationConfig
+            loaded_model.generation_config = GenerationConfig(
+                pad_token_id=loaded_tokenizer.eos_token_id if loaded_tokenizer else 50256,
+                eos_token_id=loaded_tokenizer.eos_token_id if loaded_tokenizer else 50256,
+                do_sample=True,
+                max_length=512
+            )
+
     except Exception as e:
         return (f"❌ Error preparing model for inference: {str(e)}\n\n"
                 "This can happen if the saved object is not a proper torch.nn.Module or if tensors couldn't be mapped to the current device.")
@@ -264,19 +302,54 @@ def generate_code_from_pseudo(pseudo_code, max_length, temperature, top_k, top_p
     # Generate (ensure type safety for parameters)
     with torch.no_grad():
         try:
-            outputs = loaded_model.generate(
-                **inputs,
-                max_length=int(max_length),
-                temperature=float(temperature),
-                top_k=int(top_k),
-                top_p=float(top_p),
-                do_sample=True,
-                num_return_sequences=int(num_sequences),
-                pad_token_id=loaded_tokenizer.pad_token_id,
-                eos_token_id=loaded_tokenizer.eos_token_id,
-            )
+            # Create generation kwargs with compatibility handling
+            generation_kwargs = {
+                'max_length': int(max_length),
+                'temperature': float(temperature),
+                'top_k': int(top_k),
+                'top_p': float(top_p),
+                'do_sample': True,
+                'num_return_sequences': int(num_sequences),
+                'pad_token_id': loaded_tokenizer.pad_token_id,
+                'eos_token_id': loaded_tokenizer.eos_token_id,
+            }
+
+            # Remove any None values that might cause issues
+            generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
+
+            # Add input_ids explicitly
+            generation_kwargs.update(inputs)
+
+            # Try generation with comprehensive error handling
+            try:
+                outputs = loaded_model.generate(**generation_kwargs)
+            except Exception as gen_error:
+                # First fallback: try without problematic parameters
+                if 'forced_decoder_ids' in str(gen_error) or 'GenerationConfig' in str(gen_error):
+                    # Reset generation config to minimal safe version
+                    if hasattr(loaded_model, 'generation_config'):
+                        from transformers import GenerationConfig
+                        loaded_model.generation_config = GenerationConfig(
+                            pad_token_id=loaded_tokenizer.pad_token_id,
+                            eos_token_id=loaded_tokenizer.eos_token_id,
+                            do_sample=True
+                        )
+
+                    # Try again with minimal parameters
+                    minimal_kwargs = {
+                        'max_length': int(max_length),
+                        'do_sample': True,
+                        'temperature': float(temperature),
+                        'pad_token_id': loaded_tokenizer.pad_token_id,
+                        'eos_token_id': loaded_tokenizer.eos_token_id,
+                    }
+                    minimal_kwargs.update(inputs)
+                    outputs = loaded_model.generate(**minimal_kwargs)
+                else:
+                    raise gen_error
+
         except Exception as generation_error:
-            return f"❌ Generation failed: {str(generation_error)}", "", "", ""
+            return f"❌ Generation failed: {str(generation_error)}\n\nTry using default parameters or check model compatibility.", "", "", ""

     generation_time = time.time() - start_time

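
For reference, the config-sanitizing idea in the first hunk can be exercised on its own. The sketch below is illustrative and not code from app.py: the helper name sanitize_generation_config, the fallback token id 50256, and the simulated stale attribute are assumptions; it only needs transformers installed and no model download.

# Minimal sketch (assumed helper, not part of app.py): strip attributes that a
# newer transformers release may reject and backfill the token ids generate() needs.
from transformers import GenerationConfig

PROBLEMATIC_ATTRS = [
    'forced_decoder_ids', 'forced_bos_token_id', 'forced_eos_token_id',
    'suppress_tokens', 'begin_suppress_tokens', 'decoder_start_token_id',
]

def sanitize_generation_config(gen_config, fallback_token_id=50256):
    # Drop stale attributes carried over from an older checkpoint.
    for attr in PROBLEMATIC_ATTRS:
        if hasattr(gen_config, attr):
            try:
                delattr(gen_config, attr)
            except AttributeError:
                pass
    # Backfill the ids generate() expects; 50256 is GPT-2's eos id (assumption).
    if getattr(gen_config, 'pad_token_id', None) is None:
        gen_config.pad_token_id = fallback_token_id
    if getattr(gen_config, 'eos_token_id', None) is None:
        gen_config.eos_token_id = fallback_token_id
    return gen_config

# Demo: simulate a config restored from an old pickle, then clean it.
cfg = GenerationConfig(do_sample=True)
cfg.forced_decoder_ids = [[1, 50259]]   # stale attribute, set only for this demo
cfg = sanitize_generation_config(cfg)
print(hasattr(cfg, 'forced_decoder_ids'), cfg.pad_token_id, cfg.eos_token_id)

In app.py the commit applies the same cleanup directly to loaded_model.generation_config right after the model is moved to the device, and falls back to a minimal GenerationConfig plus reduced generate() kwargs if generation still fails.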