Update app.py
app.py CHANGED

@@ -147,6 +147,44 @@ Then upload 'best_model_cpu.pkl' to this Space and rename it to 'best_model.pkl'
         loaded_model.eval()
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         loaded_model = loaded_model.to(device)
+
+        # Fix generation config compatibility issues
+        if hasattr(loaded_model, 'generation_config'):
+            gen_config = loaded_model.generation_config
+
+            # Remove problematic attributes that don't exist in current transformers version
+            problematic_attrs = [
+                'forced_decoder_ids', 'forced_bos_token_id', 'forced_eos_token_id',
+                'suppress_tokens', 'begin_suppress_tokens', 'decoder_start_token_id'
+            ]
+
+            for attr in problematic_attrs:
+                if hasattr(gen_config, attr):
+                    try:
+                        delattr(gen_config, attr)
+                    except:
+                        pass
+
+            # Ensure required attributes exist with safe defaults
+            if not hasattr(gen_config, 'pad_token_id') or gen_config.pad_token_id is None:
+                gen_config.pad_token_id = loaded_tokenizer.eos_token_id if loaded_tokenizer else 50256
+
+            if not hasattr(gen_config, 'eos_token_id') or gen_config.eos_token_id is None:
+                gen_config.eos_token_id = loaded_tokenizer.eos_token_id if loaded_tokenizer else 50256
+
+            if not hasattr(gen_config, 'bos_token_id'):
+                gen_config.bos_token_id = loaded_tokenizer.bos_token_id if loaded_tokenizer else 50256
+
+        else:
+            # Create a basic generation config if missing
+            from transformers import GenerationConfig
+            loaded_model.generation_config = GenerationConfig(
+                pad_token_id=loaded_tokenizer.eos_token_id if loaded_tokenizer else 50256,
+                eos_token_id=loaded_tokenizer.eos_token_id if loaded_tokenizer else 50256,
+                do_sample=True,
+                max_length=512
+            )
+
     except Exception as e:
         return (f"❌ Error preparing model for inference: {str(e)}\n\n"
                 "This can happen if the saved object is not a proper torch.nn.Module or if tensors couldn't be mapped to the current device.")

@@ -264,19 +302,54 @@ def generate_code_from_pseudo(pseudo_code, max_length, temperature, top_k, top_p
     # Generate (ensure type safety for parameters)
     with torch.no_grad():
         try:
-            [… 11 removed lines, truncated in this view: the previous generation call's
-             arguments — max_length, temperature, top_k, top_p, do_sample,
-             num_return_sequences, pad_token_id, eos_token_id …]
+            # Create generation kwargs with compatibility handling
+            generation_kwargs = {
+                'max_length': int(max_length),
+                'temperature': float(temperature),
+                'top_k': int(top_k),
+                'top_p': float(top_p),
+                'do_sample': True,
+                'num_return_sequences': int(num_sequences),
+                'pad_token_id': loaded_tokenizer.pad_token_id,
+                'eos_token_id': loaded_tokenizer.eos_token_id,
+            }
+
+            # Remove any None values that might cause issues
+            generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
+
+            # Add input_ids explicitly
+            generation_kwargs.update(inputs)
+
+            # Try generation with comprehensive error handling
+            try:
+                outputs = loaded_model.generate(**generation_kwargs)
+            except Exception as gen_error:
+                # First fallback: try without problematic parameters
+                if 'forced_decoder_ids' in str(gen_error) or 'GenerationConfig' in str(gen_error):
+                    # Reset generation config to minimal safe version
+                    if hasattr(loaded_model, 'generation_config'):
+                        from transformers import GenerationConfig
+                        loaded_model.generation_config = GenerationConfig(
+                            pad_token_id=loaded_tokenizer.pad_token_id,
+                            eos_token_id=loaded_tokenizer.eos_token_id,
+                            do_sample=True
+                        )
+
+                    # Try again with minimal parameters
+                    minimal_kwargs = {
+                        'max_length': int(max_length),
+                        'do_sample': True,
+                        'temperature': float(temperature),
+                        'pad_token_id': loaded_tokenizer.pad_token_id,
+                        'eos_token_id': loaded_tokenizer.eos_token_id,
+                    }
+                    minimal_kwargs.update(inputs)
+                    outputs = loaded_model.generate(**minimal_kwargs)
+                else:
+                    raise gen_error
+
         except Exception as generation_error:
-            return f"❌ Generation failed: {str(generation_error)}", "", "", ""
+            return f"❌ Generation failed: {str(generation_error)}\n\nTry using default parameters or check model compatibility.", "", "", ""

         generation_time = time.time() - start_time

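
For readers who want to try the first hunk's idea outside this Space, here is a minimal, self-contained sketch of sanitizing a model's generation config before calling generate(). It is an illustration only, not the app's code: the gpt2 checkpoint, the sanitize_generation_config name, and the fallback id 50256 (GPT-2's end-of-text token) are stand-ins chosen for the example.

from transformers import AutoTokenizer, GenerationConfig

PROBLEMATIC_ATTRS = (
    'forced_decoder_ids', 'forced_bos_token_id', 'forced_eos_token_id',
    'suppress_tokens', 'begin_suppress_tokens', 'decoder_start_token_id',
)

def sanitize_generation_config(gen_config, tokenizer=None, fallback_id=50256):
    # Drop attributes that older checkpoints may carry but that can clash
    # with newer transformers releases when generate() validates its config.
    for attr in PROBLEMATIC_ATTRS:
        if hasattr(gen_config, attr):
            try:
                delattr(gen_config, attr)
            except AttributeError:
                pass
    # Make sure pad/eos ids exist, preferring the tokenizer's eos token.
    eos_id = tokenizer.eos_token_id if tokenizer is not None else fallback_id
    if getattr(gen_config, 'pad_token_id', None) is None:
        gen_config.pad_token_id = eos_id
    if getattr(gen_config, 'eos_token_id', None) is None:
        gen_config.eos_token_id = eos_id
    return gen_config

tokenizer = AutoTokenizer.from_pretrained('gpt2')   # stand-in checkpoint
config = sanitize_generation_config(GenerationConfig(), tokenizer)
print(config.pad_token_id, config.eos_token_id)     # 50256 50256

Which attributes are actually problematic depends on the installed transformers version; the list above simply mirrors the one the commit uses.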
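
Likewise, a hedged, self-contained sketch of the second hunk's pattern: build the generation kwargs up front, try generate(), and fall back to a minimal configuration if the first attempt fails. distilgpt2, the prompt, and the parameter values are placeholders for illustration; the Space's real model and tokenizer are loaded elsewhere in app.py.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # stand-in model
model = AutoModelForCausalLM.from_pretrained('distilgpt2')
model.eval()

inputs = tokenizer("function that adds two numbers:", return_tensors='pt')

generation_kwargs = {
    'max_length': 64,
    'temperature': 0.8,
    'top_k': 50,
    'top_p': 0.95,
    'do_sample': True,
    'num_return_sequences': 1,
    'pad_token_id': tokenizer.eos_token_id,   # GPT-2 has no pad token
    'eos_token_id': tokenizer.eos_token_id,
}
generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}

with torch.no_grad():
    try:
        outputs = model.generate(**inputs, **generation_kwargs)
    except Exception:
        # Fallback in the same spirit as the commit: reset to a minimal
        # generation config and retry with only the essential parameters.
        model.generation_config = GenerationConfig(
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=True,
        )
        outputs = model.generate(**inputs, max_length=64, do_sample=True,
                                 pad_token_id=tokenizer.eos_token_id)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))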