Spaces:
Runtime error
Runtime error
feat(train): google-cloud-storage is optional
Browse files- tools/train/train.py +12 -1
tools/train/train.py
CHANGED
|
@@ -42,7 +42,6 @@ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
|
| 42 |
from flax.serialization import from_bytes, to_bytes
|
| 43 |
from flax.training import train_state
|
| 44 |
from flax.training.common_utils import onehot
|
| 45 |
-
from google.cloud import storage
|
| 46 |
from jax.experimental import PartitionSpec, maps
|
| 47 |
from jax.experimental.compilation_cache import compilation_cache as cc
|
| 48 |
from jax.experimental.pjit import pjit, with_sharding_constraint
|
|
@@ -58,6 +57,11 @@ from dalle_mini.model import (
|
|
| 58 |
set_partitions,
|
| 59 |
)
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
cc.initialize_cache("./jax_cache", max_cache_size_bytes=10 * 2**30)
|
| 62 |
|
| 63 |
logger = logging.getLogger(__name__)
|
|
@@ -144,6 +148,9 @@ class ModelArguments:
|
|
| 144 |
if self.restore_state.startswith("gs://"):
|
| 145 |
bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
|
| 146 |
bucket, blob_name = str(bucket_path).split("/", 1)
|
|
|
|
|
|
|
|
|
|
| 147 |
client = storage.Client()
|
| 148 |
bucket = client.bucket(bucket)
|
| 149 |
blob = bucket.blob(blob_name)
|
|
@@ -456,6 +463,10 @@ class TrainingArguments:
|
|
| 456 |
assert (
|
| 457 |
jax.local_device_count() == 8
|
| 458 |
), "TPUs in use, please check running processes"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
assert self.optim in [
|
| 460 |
"distributed_shampoo",
|
| 461 |
"adam",
|
|
|
|
| 42 |
from flax.serialization import from_bytes, to_bytes
|
| 43 |
from flax.training import train_state
|
| 44 |
from flax.training.common_utils import onehot
|
|
|
|
| 45 |
from jax.experimental import PartitionSpec, maps
|
| 46 |
from jax.experimental.compilation_cache import compilation_cache as cc
|
| 47 |
from jax.experimental.pjit import pjit, with_sharding_constraint
|
|
|
|
| 57 |
set_partitions,
|
| 58 |
)
|
| 59 |
|
| 60 |
+
try:
|
| 61 |
+
from google.cloud import storage
|
| 62 |
+
except:
|
| 63 |
+
storage = None
|
| 64 |
+
|
| 65 |
cc.initialize_cache("./jax_cache", max_cache_size_bytes=10 * 2**30)
|
| 66 |
|
| 67 |
logger = logging.getLogger(__name__)
|
|
|
|
| 148 |
if self.restore_state.startswith("gs://"):
|
| 149 |
bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
|
| 150 |
bucket, blob_name = str(bucket_path).split("/", 1)
|
| 151 |
+
assert (
|
| 152 |
+
storage is not None
|
| 153 |
+
), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
|
| 154 |
client = storage.Client()
|
| 155 |
bucket = client.bucket(bucket)
|
| 156 |
blob = bucket.blob(blob_name)
|
|
|
|
| 463 |
assert (
|
| 464 |
jax.local_device_count() == 8
|
| 465 |
), "TPUs in use, please check running processes"
|
| 466 |
+
if self.output_dir.startswith("gs://"):
|
| 467 |
+
assert (
|
| 468 |
+
storage is not None
|
| 469 |
+
), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
|
| 470 |
assert self.optim in [
|
| 471 |
"distributed_shampoo",
|
| 472 |
"adam",
|