Spaces:
Runtime error
Runtime error
feat: handle model parallel
Browse files- src/dalle_mini/data.py +6 -1
- src/dalle_mini/model/configuration.py +19 -18
- src/dalle_mini/model/modeling.py +4 -4
- tools/train/train.py +28 -13
src/dalle_mini/data.py
CHANGED
|
@@ -85,7 +85,12 @@ class Dataset:
|
|
| 85 |
else self.eval_dataset.select(range(self.max_eval_samples))
|
| 86 |
)
|
| 87 |
|
| 88 |
-
def preprocess(self, tokenizer,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
if self.streaming:
|
| 90 |
# we need to shuffle early in streaming mode
|
| 91 |
if hasattr(self, "train_dataset"):
|
|
|
|
| 85 |
else self.eval_dataset.select(range(self.max_eval_samples))
|
| 86 |
)
|
| 87 |
|
| 88 |
+
def preprocess(self, tokenizer, config):
|
| 89 |
+
# get required config variables
|
| 90 |
+
decoder_start_token_id = config.decoder_start_token_id
|
| 91 |
+
normalize_text = config.normalize_text
|
| 92 |
+
max_length = config.max_text_length
|
| 93 |
+
|
| 94 |
if self.streaming:
|
| 95 |
# we need to shuffle early in streaming mode
|
| 96 |
if hasattr(self, "train_dataset"):
|
src/dalle_mini/model/configuration.py
CHANGED
|
@@ -59,6 +59,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
|
|
| 59 |
is_encoder_decoder=True,
|
| 60 |
forced_eos_token_id=None,
|
| 61 |
tie_word_embeddings=False, # different modalities and sizes
|
|
|
|
| 62 |
**kwargs,
|
| 63 |
):
|
| 64 |
self.normalize_text = normalize_text
|
|
@@ -87,28 +88,28 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
|
|
| 87 |
scale_embedding # scale factor will be sqrt(d_model) if True
|
| 88 |
)
|
| 89 |
|
| 90 |
-
#
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
kwargs.pop(k, None)
|
| 100 |
|
| 101 |
super().__init__(
|
| 102 |
-
|
| 103 |
-
+ 1, # needed to avoid errors during generation (converted to jnp.array)
|
| 104 |
-
bos_token_id=image_vocab_size + 1, # set to unreachable values
|
| 105 |
-
eos_token_id=image_vocab_size + 1,
|
| 106 |
is_encoder_decoder=is_encoder_decoder,
|
| 107 |
-
decoder_start_token_id=image_vocab_size, # BOS appended to vocab
|
| 108 |
-
forced_eos_token_id=forced_eos_token_id,
|
| 109 |
tie_word_embeddings=tie_word_embeddings,
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
**kwargs,
|
| 113 |
)
|
| 114 |
|
|
|
|
| 59 |
is_encoder_decoder=True,
|
| 60 |
forced_eos_token_id=None,
|
| 61 |
tie_word_embeddings=False, # different modalities and sizes
|
| 62 |
+
do_sample=True,
|
| 63 |
**kwargs,
|
| 64 |
):
|
| 65 |
self.normalize_text = normalize_text
|
|
|
|
| 88 |
scale_embedding # scale factor will be sqrt(d_model) if True
|
| 89 |
)
|
| 90 |
|
| 91 |
+
# special token id's are appended to vocab if not provided
|
| 92 |
+
decoder_start_token_id = kwargs.pop("decoder_start_token_id", image_vocab_size)
|
| 93 |
+
bos_token_id = kwargs.pop("bos_token_id", image_vocab_size)
|
| 94 |
+
pad_token_id = kwargs.pop("pad_token_id", image_vocab_size)
|
| 95 |
+
eos_token_id = kwargs.pop("eos_token_id", image_vocab_size)
|
| 96 |
+
|
| 97 |
+
# we generate to image_length + 1 (for bos) by default
|
| 98 |
+
min_length = kwargs.pop("min_length", image_length + 1)
|
| 99 |
+
max_length = kwargs.pop("max_length", image_length + 1)
|
|
|
|
| 100 |
|
| 101 |
super().__init__(
|
| 102 |
+
# args required in parent class
|
|
|
|
|
|
|
|
|
|
| 103 |
is_encoder_decoder=is_encoder_decoder,
|
|
|
|
|
|
|
| 104 |
tie_word_embeddings=tie_word_embeddings,
|
| 105 |
+
forced_eos_token_id=forced_eos_token_id,
|
| 106 |
+
decoder_start_token_id=decoder_start_token_id,
|
| 107 |
+
bos_token_id=bos_token_id,
|
| 108 |
+
pad_token_id=pad_token_id,
|
| 109 |
+
eos_token_id=eos_token_id,
|
| 110 |
+
min_length=min_length,
|
| 111 |
+
max_length=max_length,
|
| 112 |
+
do_sample=do_sample,
|
| 113 |
**kwargs,
|
| 114 |
)
|
| 115 |
|
src/dalle_mini/model/modeling.py
CHANGED
|
@@ -54,7 +54,7 @@ logger = logging.get_logger(__name__)
|
|
| 54 |
class FlaxBartAttention(FlaxBartAttention):
|
| 55 |
"""
|
| 56 |
Edits:
|
| 57 |
-
- causal mask is used only in decoder and considers image_length
|
| 58 |
"""
|
| 59 |
|
| 60 |
def setup(self) -> None:
|
|
@@ -81,7 +81,7 @@ class FlaxBartAttention(FlaxBartAttention):
|
|
| 81 |
if self.causal:
|
| 82 |
# used only in decoder
|
| 83 |
self.causal_mask = make_causal_mask(
|
| 84 |
-
jnp.ones((1, self.config.image_length
|
| 85 |
)
|
| 86 |
|
| 87 |
|
|
@@ -240,7 +240,7 @@ class FlaxBartDecoder(FlaxBartDecoder):
|
|
| 240 |
"""
|
| 241 |
Edits:
|
| 242 |
- offset set to 0 (no padding token)
|
| 243 |
-
- use image_length
|
| 244 |
- use custom FlaxBartDecoderLayerCollection
|
| 245 |
- embed_tokens cannot be None (issue at compile time)
|
| 246 |
"""
|
|
@@ -258,7 +258,7 @@ class FlaxBartDecoder(FlaxBartDecoder):
|
|
| 258 |
# and adjust num_embeddings appropriately. Other models don't have this hack
|
| 259 |
self.offset = 0
|
| 260 |
self.embed_positions = nn.Embed(
|
| 261 |
-
self.config.image_length +
|
| 262 |
embed_dim,
|
| 263 |
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
| 264 |
)
|
|
|
|
| 54 |
class FlaxBartAttention(FlaxBartAttention):
|
| 55 |
"""
|
| 56 |
Edits:
|
| 57 |
+
- causal mask is used only in decoder and considers image_length
|
| 58 |
"""
|
| 59 |
|
| 60 |
def setup(self) -> None:
|
|
|
|
| 81 |
if self.causal:
|
| 82 |
# used only in decoder
|
| 83 |
self.causal_mask = make_causal_mask(
|
| 84 |
+
jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool"
|
| 85 |
)
|
| 86 |
|
| 87 |
|
|
|
|
| 240 |
"""
|
| 241 |
Edits:
|
| 242 |
- offset set to 0 (no padding token)
|
| 243 |
+
- use image_length instead of max_position_embeddings
|
| 244 |
- use custom FlaxBartDecoderLayerCollection
|
| 245 |
- embed_tokens cannot be None (issue at compile time)
|
| 246 |
"""
|
|
|
|
| 258 |
# and adjust num_embeddings appropriately. Other models don't have this hack
|
| 259 |
self.offset = 0
|
| 260 |
self.embed_positions = nn.Embed(
|
| 261 |
+
self.config.image_length + self.offset, # image length for BOS
|
| 262 |
embed_dim,
|
| 263 |
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
| 264 |
)
|
tools/train/train.py
CHANGED
|
@@ -99,7 +99,7 @@ class ModelArguments:
|
|
| 99 |
|
| 100 |
def __post_init__(self):
|
| 101 |
if self.restore_state:
|
| 102 |
-
assert (
|
| 103 |
"/model-" in self.model_name_or_path
|
| 104 |
), "Restoring state only available with W&B artifact reference"
|
| 105 |
self.state_artifact = self.model_name_or_path.replace(
|
|
@@ -222,12 +222,13 @@ class TrainingArguments:
|
|
| 222 |
)
|
| 223 |
|
| 224 |
per_device_train_batch_size: int = field(
|
| 225 |
-
default=8,
|
|
|
|
| 226 |
)
|
| 227 |
per_device_eval_batch_size: Optional[int] = field(
|
| 228 |
default=None,
|
| 229 |
metadata={
|
| 230 |
-
"help": "Batch size per
|
| 231 |
},
|
| 232 |
)
|
| 233 |
|
|
@@ -523,12 +524,7 @@ def main():
|
|
| 523 |
# Preprocessing the datasets.
|
| 524 |
# We need to normalize and tokenize inputs and targets.
|
| 525 |
|
| 526 |
-
dataset.preprocess(
|
| 527 |
-
tokenizer=tokenizer,
|
| 528 |
-
decoder_start_token_id=model.config.decoder_start_token_id,
|
| 529 |
-
normalize_text=model.config.normalize_text,
|
| 530 |
-
max_length=model.config.max_text_length,
|
| 531 |
-
)
|
| 532 |
|
| 533 |
# Initialize our training
|
| 534 |
rng = jax.random.PRNGKey(training_args.seed_model)
|
|
@@ -874,9 +870,17 @@ def main():
|
|
| 874 |
|
| 875 |
# Define eval fn
|
| 876 |
def eval_step(state, batch):
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 880 |
return loss
|
| 881 |
|
| 882 |
# Create parallel version of the train and eval step
|
|
@@ -946,7 +950,18 @@ def main():
|
|
| 946 |
leave=False,
|
| 947 |
total=eval_steps,
|
| 948 |
):
|
| 949 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 950 |
batch = freeze(batch)
|
| 951 |
# accumulate losses async
|
| 952 |
eval_loss.append(p_eval_step(state, batch))
|
|
|
|
| 99 |
|
| 100 |
def __post_init__(self):
|
| 101 |
if self.restore_state:
|
| 102 |
+
assert self.model_name_or_path is not None and (
|
| 103 |
"/model-" in self.model_name_or_path
|
| 104 |
), "Restoring state only available with W&B artifact reference"
|
| 105 |
self.state_artifact = self.model_name_or_path.replace(
|
|
|
|
| 222 |
)
|
| 223 |
|
| 224 |
per_device_train_batch_size: int = field(
|
| 225 |
+
default=8,
|
| 226 |
+
metadata={"help": "Batch size per data parallel device for training."},
|
| 227 |
)
|
| 228 |
per_device_eval_batch_size: Optional[int] = field(
|
| 229 |
default=None,
|
| 230 |
metadata={
|
| 231 |
+
"help": "Batch size per data parallel device for evaluation. Same as training batch size if not set."
|
| 232 |
},
|
| 233 |
)
|
| 234 |
|
|
|
|
| 524 |
# Preprocessing the datasets.
|
| 525 |
# We need to normalize and tokenize inputs and targets.
|
| 526 |
|
| 527 |
+
dataset.preprocess(tokenizer=tokenizer, config=model.config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 528 |
|
| 529 |
# Initialize our training
|
| 530 |
rng = jax.random.PRNGKey(training_args.seed_model)
|
|
|
|
| 870 |
|
| 871 |
# Define eval fn
|
| 872 |
def eval_step(state, batch):
|
| 873 |
+
def compute_eval_loss(batch):
|
| 874 |
+
batch, labels = batch.pop("labels")
|
| 875 |
+
logits = state.apply_fn(**batch, params=state.params, train=False)[0]
|
| 876 |
+
return loss_fn(logits, labels)
|
| 877 |
+
|
| 878 |
+
# calculate loss independently per dp_device
|
| 879 |
+
loss = jax.vmap(compute_eval_loss, in_axes=(0,), out_axes=0)(batch)
|
| 880 |
+
# ensure they are sharded over dp devices
|
| 881 |
+
loss = with_sharding_constraint(loss, PartitionSpec("batch"))
|
| 882 |
+
# average across all devices
|
| 883 |
+
loss = jnp.mean(loss)
|
| 884 |
return loss
|
| 885 |
|
| 886 |
# Create parallel version of the train and eval step
|
|
|
|
| 950 |
leave=False,
|
| 951 |
total=eval_steps,
|
| 952 |
):
|
| 953 |
+
# reshape data into (dp_devices, batch_per_dp, ...)
|
| 954 |
+
batch = jax.tree_map(
|
| 955 |
+
lambda x: x.reshape(
|
| 956 |
+
(
|
| 957 |
+
training_args.dp_devices,
|
| 958 |
+
training_args.per_device_eval_batch_size,
|
| 959 |
+
)
|
| 960 |
+
+ x.shape[1:]
|
| 961 |
+
),
|
| 962 |
+
batch,
|
| 963 |
+
)
|
| 964 |
+
# freeze batch to pass safely to jax transforms
|
| 965 |
batch = freeze(batch)
|
| 966 |
# accumulate losses async
|
| 967 |
eval_loss.append(p_eval_step(state, batch))
|