Spaces:
Running
on
Zero
Running
on
Zero
Pipeline add torch dtype code
Browse files- pipeline_objectclear.py +3 -1
pipeline_objectclear.py
CHANGED
|
@@ -463,7 +463,6 @@ class ObjectClearPipeline(
|
|
| 463 |
)
|
| 464 |
|
| 465 |
|
| 466 |
-
@classmethod
|
| 467 |
@classmethod
|
| 468 |
def from_pretrained_with_custom_modules(
|
| 469 |
cls,
|
|
@@ -500,6 +499,9 @@ class ObjectClearPipeline(
|
|
| 500 |
cache_dir=cache_dir,
|
| 501 |
**kwargs,
|
| 502 |
)
|
|
|
|
|
|
|
|
|
|
| 503 |
|
| 504 |
return pipe
|
| 505 |
|
|
|
|
| 463 |
)
|
| 464 |
|
| 465 |
|
|
|
|
| 466 |
@classmethod
|
| 467 |
def from_pretrained_with_custom_modules(
|
| 468 |
cls,
|
|
|
|
| 499 |
cache_dir=cache_dir,
|
| 500 |
**kwargs,
|
| 501 |
)
|
| 502 |
+
|
| 503 |
+
if torch_dtype is not None:
|
| 504 |
+
pipe.to(dtype=torch_dtype)
|
| 505 |
|
| 506 |
return pipe
|
| 507 |
|