Spaces:
Running
Running
hysts
commited on
Commit
·
6be65e1
1
Parent(s):
0e88f89
Clean up
Browse files
model.py
CHANGED
|
@@ -71,22 +71,21 @@ class Model:
|
|
| 71 |
self.model_names = LIGHTWEIGHT_MODEL_NAMES
|
| 72 |
self.weight_root = LIGHTWEIGHT_WEIGHT_ROOT
|
| 73 |
base_model_url = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
|
| 74 |
-
self.
|
| 75 |
-
base_model_path = self.model_dir / base_model_url.split('/')[-1]
|
| 76 |
-
self.load_base_model(base_model_path)
|
| 77 |
else:
|
| 78 |
self.model_names = ORIGINAL_MODEL_NAMES
|
| 79 |
self.weight_root = ORIGINAL_WEIGHT_ROOT
|
| 80 |
self.download_models()
|
| 81 |
|
| 82 |
-
def download_base_model(self,
|
| 83 |
-
model_name =
|
| 84 |
out_path = self.model_dir / model_name
|
| 85 |
-
if out_path.exists():
|
| 86 |
-
|
| 87 |
-
|
| 88 |
|
| 89 |
-
def load_base_model(self,
|
|
|
|
| 90 |
self.model.load_state_dict(load_state_dict(model_path,
|
| 91 |
location=self.device.type),
|
| 92 |
strict=False)
|
|
|
|
| 71 |
self.model_names = LIGHTWEIGHT_MODEL_NAMES
|
| 72 |
self.weight_root = LIGHTWEIGHT_WEIGHT_ROOT
|
| 73 |
base_model_url = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
|
| 74 |
+
self.load_base_model(base_model_url)
|
|
|
|
|
|
|
| 75 |
else:
|
| 76 |
self.model_names = ORIGINAL_MODEL_NAMES
|
| 77 |
self.weight_root = ORIGINAL_WEIGHT_ROOT
|
| 78 |
self.download_models()
|
| 79 |
|
| 80 |
+
def download_base_model(self, model_url: str) -> pathlib.Path:
|
| 81 |
+
model_name = model_url.split('/')[-1]
|
| 82 |
out_path = self.model_dir / model_name
|
| 83 |
+
if not out_path.exists():
|
| 84 |
+
subprocess.run(shlex.split(f'wget {model_url} -O {out_path}'))
|
| 85 |
+
return out_path
|
| 86 |
|
| 87 |
+
def load_base_model(self, model_url: str) -> None:
|
| 88 |
+
model_path = self.download_base_model(model_url)
|
| 89 |
self.model.load_state_dict(load_state_dict(model_path,
|
| 90 |
location=self.device.type),
|
| 91 |
strict=False)
|