Spaces:
Runtime error
Runtime error
fix: model config
Browse files
seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -282,8 +282,6 @@ class CustomFlaxBartModule(FlaxBartModule):
|
|
| 282 |
# the decoder has a different config
|
| 283 |
decoder_config = BartConfig(self.config.to_dict())
|
| 284 |
decoder_config.max_position_embeddings = OUTPUT_LENGTH
|
| 285 |
-
decoder_config.min_length = OUTPUT_LENGTH
|
| 286 |
-
decoder_config.max_length = OUTPUT_LENGTH
|
| 287 |
decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
|
| 288 |
self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
|
| 289 |
|
|
@@ -440,8 +438,8 @@ def main():
|
|
| 440 |
config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
|
| 441 |
config.forced_bos_token_id = None # we don't need this token
|
| 442 |
config.forced_eos_token_id = None # we don't need this token
|
| 443 |
-
|
| 444 |
-
|
| 445 |
|
| 446 |
print(f"TPUs: {jax.device_count()}")
|
| 447 |
assert jax.device_count() == 8, "TPUs in use, please check running processes"
|
|
|
|
| 282 |
# the decoder has a different config
|
| 283 |
decoder_config = BartConfig(self.config.to_dict())
|
| 284 |
decoder_config.max_position_embeddings = OUTPUT_LENGTH
|
|
|
|
|
|
|
| 285 |
decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
|
| 286 |
self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
|
| 287 |
|
|
|
|
| 438 |
config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
|
| 439 |
config.forced_bos_token_id = None # we don't need this token
|
| 440 |
config.forced_eos_token_id = None # we don't need this token
|
| 441 |
+
config.min_length = data_args.max_target_length
|
| 442 |
+
config.max_length = data_args.max_target_length
|
| 443 |
|
| 444 |
print(f"TPUs: {jax.device_count()}")
|
| 445 |
assert jax.device_count() == 8, "TPUs in use, please check running processes"
|