Spaces:
Runtime error
Runtime error
Commit
·
958100a
1
Parent(s):
239a0d3
changed ft.pt to prompt-based
Browse files
app.py
CHANGED
|
@@ -262,7 +262,8 @@ class Demo:
|
|
| 262 |
loss.backward()
|
| 263 |
optimizer.step()
|
| 264 |
|
| 265 |
-
|
|
|
|
| 266 |
|
| 267 |
self.finetuner = finetuner.eval().half()
|
| 268 |
|
|
@@ -272,9 +273,9 @@ class Demo:
|
|
| 272 |
|
| 273 |
self.training = False
|
| 274 |
|
| 275 |
-
model_map['Custom'] =
|
| 276 |
|
| 277 |
-
return [gr.update(interactive=True), gr.update(interactive=True),
|
| 278 |
|
| 279 |
|
| 280 |
def inference(self, prompt, seed, pbar = gr.Progress(track_tqdm=True)):
|
|
|
|
| 262 |
loss.backward()
|
| 263 |
optimizer.step()
|
| 264 |
|
| 265 |
+
ft_path = f"{prompt.lower().replace(' ', '')}.pt"
|
| 266 |
+
torch.save(finetuner.state_dict(), ft_path)
|
| 267 |
|
| 268 |
self.finetuner = finetuner.eval().half()
|
| 269 |
|
|
|
|
| 273 |
|
| 274 |
self.training = False
|
| 275 |
|
| 276 |
+
model_map['Custom'] = ft_path
|
| 277 |
|
| 278 |
+
return [gr.update(interactive=True), gr.update(interactive=True), ft_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
|
| 279 |
|
| 280 |
|
| 281 |
def inference(self, prompt, seed, pbar = gr.Progress(track_tqdm=True)):
|