Commit
·
bd73f2a
1
Parent(s):
114c79c
Do not assume 8 devices in JAX (#154)
Browse files- Do not assume 8 devices in JAX (e124bbdca2dab1af0cdce19d575f8043eab9341e)
Co-authored-by: Pedro Cuenca <pcuenq@users.noreply.huggingface.co>
README.md
CHANGED
|
@@ -154,7 +154,7 @@ prompt_ids = pipeline.prepare_inputs(prompt)
|
|
| 154 |
|
| 155 |
# shard inputs and rng
|
| 156 |
params = replicate(params)
|
| 157 |
-
prng_seed = jax.random.split(prng_seed,
|
| 158 |
prompt_ids = shard(prompt_ids)
|
| 159 |
|
| 160 |
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
|
@@ -187,7 +187,7 @@ prompt_ids = pipeline.prepare_inputs(prompt)
|
|
| 187 |
|
| 188 |
# shard inputs and rng
|
| 189 |
params = replicate(params)
|
| 190 |
-
prng_seed = jax.random.split(prng_seed,
|
| 191 |
prompt_ids = shard(prompt_ids)
|
| 192 |
|
| 193 |
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
|
|
|
| 154 |
|
| 155 |
# shard inputs and rng
|
| 156 |
params = replicate(params)
|
| 157 |
+
prng_seed = jax.random.split(prng_seed, num_samples)
|
| 158 |
prompt_ids = shard(prompt_ids)
|
| 159 |
|
| 160 |
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
|
|
|
| 187 |
|
| 188 |
# shard inputs and rng
|
| 189 |
params = replicate(params)
|
| 190 |
+
prng_seed = jax.random.split(prng_seed, num_samples)
|
| 191 |
prompt_ids = shard(prompt_ids)
|
| 192 |
|
| 193 |
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|