fix: position embedding for generate method
src/dalle_mini/model/modeling.py  CHANGED
@@ -371,7 +371,8 @@ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationMod
     def setup(self):
         self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
         self.lm_head = nn.Dense(
-            self.config.image_vocab_size,
+            self.config.image_vocab_size
+            + 1,  # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
             use_bias=False,
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(self.config.init_std),
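The first hunk grows the lm_head output by one: the logits must cover the image vocabulary plus the BOS token, which also gives the head's kernel the same shape as the decoder's input embedding so the two can be sharded identically. A minimal standalone sketch of such a head (illustrative only; the 16384 codebook size and 1024 hidden size are assumed values, not read from the repo's config):

import jax
import jax.numpy as jnp
import flax.linen as nn

image_vocab_size = 16384  # assumed VQGAN codebook size, for illustration

class TinyLMHead(nn.Module):
    @nn.compact
    def __call__(self, hidden_states):
        # image_vocab_size + 1 output logits: every image token plus BOS,
        # matching the row count of the decoder's embedding table.
        return nn.Dense(
            image_vocab_size + 1,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(0.02),
        )(hidden_states)

head = TinyLMHead()
x = jnp.ones((1, 5, 1024))  # (batch, decoder length, hidden size)
params = head.init(jax.random.PRNGKey(0), x)
print(head.apply(params, x).shape)  # (1, 5, 16385)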
@@ -437,6 +438,8 @@ class DalleBart(
     - uses custom FlaxBartPreTrainedModel
     - uses custom FlaxBartForConditionalGenerationModule
     - no bias in decode method
+    - custom prepare_inputs_for_generation using "max_length - 1" to avoid issues
+      related to position embedding during model.generate()
     """

     module_class = FlaxBartForConditionalGenerationModule
@@ -572,3 +575,38 @@ class DalleBart(
         outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

         return outputs
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        max_length,
+        attention_mask: Optional[jnp.DeviceArray] = None,
+        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        # initializing the cache
+        batch_size, seq_length = decoder_input_ids.shape
+
+        past_key_values = self.init_cache(batch_size, max_length - 1, encoder_outputs)
+        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+        # But since the decoder uses a causal mask, those positions are masked anyways.
+        # Thus we can create a single static attention_mask here, which is more efficient for compilation
+        extended_attention_mask = jnp.ones((batch_size, max_length - 1), dtype="i4")
+        if decoder_attention_mask is not None:
+            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
+            extended_attention_mask = lax.dynamic_update_slice(
+                extended_attention_mask, decoder_attention_mask, (0, 0)
+            )
+        else:
+            position_ids = jnp.broadcast_to(
+                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
+            )
+
+        return {
+            "past_key_values": past_key_values,
+            "encoder_outputs": encoder_outputs,
+            "encoder_attention_mask": attention_mask,
+            "decoder_attention_mask": extended_attention_mask,
+            "decoder_position_ids": position_ids,
+        }
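Two details in the last hunk carry the fix. First, the cache and the static attention mask are allocated with max_length - 1 slots rather than max_length, so the decoder positions derived from them never index past the learned position-embedding table during model.generate(). Second, position ids are taken as the cumulative sum of the attention mask minus one, so padded slots do not advance the position counter. A standalone sketch of just that arithmetic, with toy shapes picked for illustration:

import jax.numpy as jnp
from jax import lax

batch_size, seq_length, max_length = 2, 3, 8  # toy sizes, not model values

# One fully-used prompt and one with a padded last slot.
decoder_attention_mask = jnp.array([[1, 1, 1], [1, 1, 0]], dtype="i4")

# Static mask of max_length - 1 slots, as in the diff: future positions stay 1
# because the causal mask hides them anyway, keeping the shape compile-friendly.
extended_attention_mask = jnp.ones((batch_size, max_length - 1), dtype="i4")
extended_attention_mask = lax.dynamic_update_slice(
    extended_attention_mask, decoder_attention_mask, (0, 0)
)

# cumsum - 1: padding does not advance the position counter.
position_ids = decoder_attention_mask.cumsum(axis=-1) - 1

print(extended_attention_mask)  # [[1 1 1 1 1 1 1] [1 1 0 1 1 1 1]]
print(position_ids)             # [[0 1 2] [0 1 1]]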
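For completeness, a sketch of how the patched method is exercised: the Flax generation mixin calls prepare_inputs_for_generation internally when model.generate() runs. Here model, params, input_ids, and attention_mask are placeholders for a loaded DalleBart checkpoint and a tokenized prompt, and 257 (256 image tokens plus BOS) is an assumed max_length:

# Hypothetical usage; names below are placeholders, not repo code.
image_ids = model.generate(
    input_ids,                      # tokenized text prompt
    attention_mask=attention_mask,
    max_length=257,                 # assumed: 256 image tokens + BOS
    do_sample=True,
    prng_key=jax.random.PRNGKey(0),
    params=params,
).sequences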