Spaces:
Running
on
Zero
Running
on
Zero
fix f16
Browse files
app.py
CHANGED
|
@@ -215,6 +215,11 @@ def run_classifier(image: Image.Image, threshold):
|
|
| 215 |
img = image.convert('RGBA')
|
| 216 |
tensor = transform(img).unsqueeze(0)
|
| 217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
with torch.no_grad():
|
| 219 |
probits = model(tensor)[0] # type: torch.Tensor
|
| 220 |
values, indices = probits.cpu().topk(250)
|
|
@@ -238,6 +243,13 @@ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
|
|
| 238 |
target_tag_index = tags[evt.value]
|
| 239 |
tensor = transform(img).unsqueeze(0)
|
| 240 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
gradients = {}
|
| 242 |
activations = {}
|
| 243 |
|
|
@@ -339,7 +351,10 @@ def process_images(images, threshold):
|
|
| 339 |
all_results = []
|
| 340 |
with torch.no_grad():
|
| 341 |
for batch, filenames in dataloader:
|
| 342 |
-
|
|
|
|
|
|
|
|
|
|
| 343 |
probabilities = model(batch)
|
| 344 |
for i, prob in enumerate(probabilities):
|
| 345 |
indices = torch.where(prob > threshold)[0]
|
|
|
|
| 215 |
img = image.convert('RGBA')
|
| 216 |
tensor = transform(img).unsqueeze(0)
|
| 217 |
|
| 218 |
+
if torch.cuda.is_available():
|
| 219 |
+
tensor = tensor.to(device, dtype=torch.float16)
|
| 220 |
+
else:
|
| 221 |
+
tensor = tensor.to(device)
|
| 222 |
+
|
| 223 |
with torch.no_grad():
|
| 224 |
probits = model(tensor)[0] # type: torch.Tensor
|
| 225 |
values, indices = probits.cpu().topk(250)
|
|
|
|
| 243 |
target_tag_index = tags[evt.value]
|
| 244 |
tensor = transform(img).unsqueeze(0)
|
| 245 |
|
| 246 |
+
if torch.cuda.is_available():
|
| 247 |
+
tensor = tensor.to(device, dtype=torch.float16)
|
| 248 |
+
else:
|
| 249 |
+
tensor = tensor.to(device)
|
| 250 |
+
|
| 251 |
+
tensor.requires_grad_()
|
| 252 |
+
|
| 253 |
gradients = {}
|
| 254 |
activations = {}
|
| 255 |
|
|
|
|
| 351 |
all_results = []
|
| 352 |
with torch.no_grad():
|
| 353 |
for batch, filenames in dataloader:
|
| 354 |
+
if torch.cuda.is_available():
|
| 355 |
+
batch = batch.to(device, dtype=torch.float16)
|
| 356 |
+
else:
|
| 357 |
+
batch = batch.to(device)
|
| 358 |
probabilities = model(batch)
|
| 359 |
for i, prob in enumerate(probabilities):
|
| 360 |
indices = torch.where(prob > threshold)[0]
|