Spaces:
Running
on
Zero
Running
on
Zero
Fix dynamic shapes
Browse files- optimization.py +3 -1
optimization.py
CHANGED
|
@@ -7,6 +7,7 @@ from typing import ParamSpec
|
|
| 7 |
|
| 8 |
import spaces
|
| 9 |
import torch
|
|
|
|
| 10 |
|
| 11 |
from pipeline_utils import capture_component_call
|
| 12 |
from zerogpu import aoti_compile
|
|
@@ -34,7 +35,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
|
|
| 34 |
pipeline(*args, **kwargs)
|
| 35 |
|
| 36 |
hidden_dim = torch.export.Dim('hidden', min=4096, max=8212)
|
| 37 |
-
dynamic_shapes =
|
|
|
|
| 38 |
'hidden_states': {1: hidden_dim},
|
| 39 |
'img_ids': {0: hidden_dim},
|
| 40 |
}
|
|
|
|
| 7 |
|
| 8 |
import spaces
|
| 9 |
import torch
|
| 10 |
+
from torch.utils._pytree import tree_map_only
|
| 11 |
|
| 12 |
from pipeline_utils import capture_component_call
|
| 13 |
from zerogpu import aoti_compile
|
|
|
|
| 35 |
pipeline(*args, **kwargs)
|
| 36 |
|
| 37 |
hidden_dim = torch.export.Dim('hidden', min=4096, max=8212)
|
| 38 |
+
dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
|
| 39 |
+
dynamic_shapes |= {
|
| 40 |
'hidden_states': {1: hidden_dim},
|
| 41 |
'img_ids': {0: hidden_dim},
|
| 42 |
}
|