Spaces · Runtime error
Commit 007806e · Parent(s): fd16ff8
update space

app.py CHANGED
@@ -153,9 +153,14 @@ def main():
     editing_types = ['rainbow', 'santa_hat', 'lego', 'golden', 'wooden', 'cyber']
     # prepare models
     for editing_type in editing_types:
-        tmp_model =
+        tmp_model = model_from_config(load_config('text300M'), device=device)
+        # print(model_name, kwargs)
+        # print(model)
+
+        # xm = load_model('transmitter', de
+        tmp_model = load_model('text300M', device=device)
         with torch.no_grad():
-            new_proj = nn.Linear(1024 * 2, 1024, device=
+            new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=tmp_model.wrapped.input_proj.weight.dtype)
             new_proj.weight = nn.Parameter(torch.zeros_like(new_proj.weight))
             new_proj.weight[:, :1024].copy_(tmp_model.wrapped.input_proj.weight) #
             new_proj.bias = nn.Parameter(torch.zeros_like(new_proj.bias))
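The hunk above widens the editor's input projection from 1024 to 2048 input features so it can consume the original latent concatenated with a conditioning vector, zero-initializing the new columns so the widened layer initially behaves like the pretrained one. A minimal self-contained sketch of that zero-init widening, with old_proj as a hypothetical stand-in for tmp_model.wrapped.input_proj:

import torch
import torch.nn as nn

old_proj = nn.Linear(1024, 1024)  # stand-in for the pretrained input_proj

# Widen to accept [latent ; conditioning] and start out equivalent to old_proj.
new_proj = nn.Linear(1024 * 2, 1024)
with torch.no_grad():
    new_proj.weight.zero_()                           # new columns contribute nothing yet
    new_proj.weight[:, :1024].copy_(old_proj.weight)  # copy pretrained weights into the first half
    new_proj.bias.zero_()                             # the commit zero-initializes the bias too

x = torch.randn(1, 1024)
pad = torch.zeros(1, 1024)
# With zero conditioning, the widened layer reproduces the pretrained weight
# matrix (the zeroed bias is the only difference from old_proj's output).
assert torch.allclose(new_proj(torch.cat([x, pad], dim=-1)), x @ old_proj.weight.T, atol=1e-5)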
@@ -164,10 +169,13 @@ def main():
 
         ckp = torch.load(hf_hub_download(repo_id='silentchen/Shap_Editor', subfolder='single', filename='{}.pt'.format(editing_type)), map_location='cpu')
         tmp_model.load_state_dict(ckp['model'])
-
+        tmp_model.eval()
+        # print("loaded latent model")
+        tmp_model.to(device)
+        noise_initial = ckp['initial_noise']['noise'].to(device)
         initial_noise[editing_type] = noise_initial
         noise_start_t[editing_type] = ckp['t_start']
-        models[editing_type] = tmp_model
+        models[editing_type] = tmp_model.to(device)
     @torch.no_grad()
     def optimize_all(prompt, instruction,
                      rand_seed):
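The added lines load each editing checkpoint from the Hub and cache the model, its stored initial noise, and its start timestep in per-type dicts; the second .to(device) when storing into models is a harmless no-op, since nn.Module.to moves parameters in place and returns self. A condensed sketch of that pattern as a helper, assuming shap-e's load_model (from shap_e.models.download) and the checkpoint keys the commit reads; load_editor is a hypothetical name:

import torch
from huggingface_hub import hf_hub_download
from shap_e.models.download import load_model

def load_editor(editing_type, device):
    model = load_model('text300M', device=device)
    # ...widen model.wrapped.input_proj here, as in the first hunk...
    path = hf_hub_download(repo_id='silentchen/Shap_Editor',
                           subfolder='single',
                           filename='{}.pt'.format(editing_type))
    ckp = torch.load(path, map_location='cpu')
    model.load_state_dict(ckp['model'])
    model.eval()  # inference only from here on
    noise = ckp['initial_noise']['noise'].to(device)
    return model.to(device), noise, ckp['t_start']

# Usage mirrors the loop above:
# models[t], initial_noise[t], noise_start_t[t] = load_editor(t, device)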
@@ -279,12 +287,14 @@ def main():
         os.makedirs(general_save_path, exist_ok=True)
         for i, latent in enumerate(state['latent']):
             latent = latent.to(device)
-            text_embeddings_clip = model.cached_model_kwargs(1, dict(texts=[instruction]))
+            text_embeddings_clip = model.cached_model_kwargs(1, dict(texts=[instruction])).to(device)
             print("shape of latent: ", latent.clone().unsqueeze(0).shape, "instruction: ", instruction)
             ref_latent = latent.clone().unsqueeze(0).to(device)
             t_1 = torch.randint(noise_start_t_e_type, noise_start_t_e_type + 1, (1,), device=device).long()
 
             noise_input = diffusion.q_sample(ref_latent, t_1, noise=noise_initial)
+            print("noise_input:", noise_input.device)
+
             out_1 = diffusion.p_mean_variance(model, noise_input, t_1, clip_denoised=True,
                                               model_kwargs=text_embeddings_clip,
                                               condition_latents=ref_latent)
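This last hunk is the editing step itself: the reference latent is noised to a fixed timestep with the checkpoint's stored noise via q_sample, then denoised one step with p_mean_variance. Since torch.randint(a, a + 1, ...) always returns a, the timestep is deterministic and torch.full is equivalent. A sketch under guided-diffusion-style q_sample/p_mean_variance signatures (condition_latents is this Space's custom kwarg rather than stock shap-e; edit_step is a hypothetical name):

import torch

@torch.no_grad()
def edit_step(diffusion, model, latent, noise, t_start, model_kwargs, device):
    ref_latent = latent.unsqueeze(0).to(device)                     # (1, ...) reference latent
    t = torch.full((1,), t_start, device=device, dtype=torch.long)  # fixed, deterministic timestep
    x_t = diffusion.q_sample(ref_latent, t, noise=noise)            # partially noise to t_start
    return diffusion.p_mean_variance(model, x_t, t, clip_denoised=True,
                                     model_kwargs=model_kwargs,
                                     condition_latents=ref_latent)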
|