Spaces:
Runtime error
Runtime error
RohitGandikota
commited on
Commit
Β·
68e2466
1
Parent(s):
4d3c7dc
fixing custom slider training
Browse files
app.py
CHANGED
|
@@ -245,7 +245,7 @@ class Demo:
|
|
| 245 |
save_name = f"{randn}_{target_concept.replace(',','').replace(' ','').replace('.','')[:10]}_{positive_prompt.replace(',','').replace(' ','').replace('.','')[:10]}"
|
| 246 |
save_name += f'_alpha-{1}'
|
| 247 |
save_name += f'_noxattn'
|
| 248 |
-
save_name += f'_rank_{rank}.pt'
|
| 249 |
|
| 250 |
# if torch.cuda.get_device_properties(0).total_memory * 1e-9 < 40:
|
| 251 |
# return [gr.update(interactive=True, value='Train'), gr.update(value='GPU Memory is not enough for training... Please upgrade to GPU atleast 40GB or clone the repo to your local machine.'), None, gr.update()]
|
|
@@ -257,7 +257,7 @@ class Demo:
|
|
| 257 |
attributes = 'white, black, asian, hispanic, indian, male, female'
|
| 258 |
|
| 259 |
self.training = True
|
| 260 |
-
train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=rank, device=self.device, attributes=attributes, save_name=save_name)
|
| 261 |
self.training = False
|
| 262 |
|
| 263 |
torch.cuda.empty_cache()
|
|
@@ -293,7 +293,7 @@ class Demo:
|
|
| 293 |
rank = 4
|
| 294 |
alpha = 1
|
| 295 |
if 'rank' in model_path:
|
| 296 |
-
rank = int(model_path.split('_')[-1].replace('.pt',''))
|
| 297 |
if 'alpha1' in model_path:
|
| 298 |
alpha = 1.0
|
| 299 |
network = LoRANetwork(
|
|
|
|
| 245 |
save_name = f"{randn}_{target_concept.replace(',','').replace(' ','').replace('.','')[:10]}_{positive_prompt.replace(',','').replace(' ','').replace('.','')[:10]}"
|
| 246 |
save_name += f'_alpha-{1}'
|
| 247 |
save_name += f'_noxattn'
|
| 248 |
+
save_name += f'_rank_{int(rank)}.pt'
|
| 249 |
|
| 250 |
# if torch.cuda.get_device_properties(0).total_memory * 1e-9 < 40:
|
| 251 |
# return [gr.update(interactive=True, value='Train'), gr.update(value='GPU Memory is not enough for training... Please upgrade to GPU atleast 40GB or clone the repo to your local machine.'), None, gr.update()]
|
|
|
|
| 257 |
attributes = 'white, black, asian, hispanic, indian, male, female'
|
| 258 |
|
| 259 |
self.training = True
|
| 260 |
+
train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=int(rank), device=self.device, attributes=attributes, save_name=save_name)
|
| 261 |
self.training = False
|
| 262 |
|
| 263 |
torch.cuda.empty_cache()
|
|
|
|
| 293 |
rank = 4
|
| 294 |
alpha = 1
|
| 295 |
if 'rank' in model_path:
|
| 296 |
+
rank = int(float(model_path.split('_')[-1].replace('.pt','')))
|
| 297 |
if 'alpha1' in model_path:
|
| 298 |
alpha = 1.0
|
| 299 |
network = LoRANetwork(
|