Spaces:
Runtime error
Runtime error
feat: shard by host is optional
Browse files- dalle_mini/data.py +7 -2
- tools/train/train.py +8 -2
dalle_mini/data.py
CHANGED
|
@@ -27,6 +27,7 @@ class Dataset:
|
|
| 27 |
do_train: bool = False
|
| 28 |
do_eval: bool = True
|
| 29 |
seed_dataset: int = None
|
|
|
|
| 30 |
train_dataset: Dataset = field(init=False)
|
| 31 |
eval_dataset: Dataset = field(init=False)
|
| 32 |
rng_dataset: jnp.ndarray = field(init=False)
|
|
@@ -42,7 +43,11 @@ class Dataset:
|
|
| 42 |
if isinstance(f, str):
|
| 43 |
setattr(self, k, list(braceexpand(f)))
|
| 44 |
# for list of files, split training data shards by host
|
| 45 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
self.train_file = self.train_file[
|
| 47 |
jax.process_index() :: jax.process_count()
|
| 48 |
]
|
|
@@ -185,7 +190,7 @@ class Dataset:
|
|
| 185 |
first_loop = True
|
| 186 |
while self.multi_hosts or first_loop:
|
| 187 |
# in multi-host, we run forever (no epoch) as hosts need to stop
|
| 188 |
-
# at same
|
| 189 |
if not first_loop:
|
| 190 |
# multi-host setting, we reshuffle shards
|
| 191 |
epoch += 1
|
|
|
|
| 27 |
do_train: bool = False
|
| 28 |
do_eval: bool = True
|
| 29 |
seed_dataset: int = None
|
| 30 |
+
shard_by_host: bool = False
|
| 31 |
train_dataset: Dataset = field(init=False)
|
| 32 |
eval_dataset: Dataset = field(init=False)
|
| 33 |
rng_dataset: jnp.ndarray = field(init=False)
|
|
|
|
| 43 |
if isinstance(f, str):
|
| 44 |
setattr(self, k, list(braceexpand(f)))
|
| 45 |
# for list of files, split training data shards by host
|
| 46 |
+
if (
|
| 47 |
+
isinstance(self.train_file, list)
|
| 48 |
+
and self.multi_hosts
|
| 49 |
+
and self.shard_by_host
|
| 50 |
+
):
|
| 51 |
self.train_file = self.train_file[
|
| 52 |
jax.process_index() :: jax.process_count()
|
| 53 |
]
|
|
|
|
| 190 |
first_loop = True
|
| 191 |
while self.multi_hosts or first_loop:
|
| 192 |
# in multi-host, we run forever (no epoch) as hosts need to stop
|
| 193 |
+
# at the same time and we don't know how much data is on each host
|
| 194 |
if not first_loop:
|
| 195 |
# multi-host setting, we reshuffle shards
|
| 196 |
epoch += 1
|
tools/train/train.py
CHANGED
|
@@ -112,16 +112,22 @@ class DataTrainingArguments:
|
|
| 112 |
metadata={"help": "An optional input evaluation data file (glob acceptable)."},
|
| 113 |
)
|
| 114 |
# data loading should not be a bottleneck so we use "streaming" mode by default
|
| 115 |
-
streaming: bool = field(
|
| 116 |
default=True,
|
| 117 |
metadata={"help": "Whether to stream the dataset."},
|
| 118 |
)
|
| 119 |
-
use_auth_token: bool = field(
|
| 120 |
default=False,
|
| 121 |
metadata={
|
| 122 |
"help": "Whether to use the authentication token for private datasets."
|
| 123 |
},
|
| 124 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
max_train_samples: Optional[int] = field(
|
| 126 |
default=None,
|
| 127 |
metadata={
|
|
|
|
| 112 |
metadata={"help": "An optional input evaluation data file (glob acceptable)."},
|
| 113 |
)
|
| 114 |
# data loading should not be a bottleneck so we use "streaming" mode by default
|
| 115 |
+
streaming: Optional[bool] = field(
|
| 116 |
default=True,
|
| 117 |
metadata={"help": "Whether to stream the dataset."},
|
| 118 |
)
|
| 119 |
+
use_auth_token: Optional[bool] = field(
|
| 120 |
default=False,
|
| 121 |
metadata={
|
| 122 |
"help": "Whether to use the authentication token for private datasets."
|
| 123 |
},
|
| 124 |
)
|
| 125 |
+
shard_by_host: Optional[bool] = field(
|
| 126 |
+
default=False,
|
| 127 |
+
metadata={
|
| 128 |
+
"help": "Whether to shard data files by host in multi-host environments."
|
| 129 |
+
},
|
| 130 |
+
)
|
| 131 |
max_train_samples: Optional[int] = field(
|
| 132 |
default=None,
|
| 133 |
metadata={
|