Spaces:
Runtime error
Runtime error
move to(device) outside of collate
Browse files
utils.py
CHANGED
|
@@ -226,7 +226,7 @@ def collate_fn(examples, tokenizer=None, padding=None, device=None):
|
|
| 226 |
batch[k].append(v)
|
| 227 |
|
| 228 |
return {
|
| 229 |
-
k: torch.tensor(v, dtype=torch.long
|
| 230 |
}
|
| 231 |
|
| 232 |
@torch.inference_mode()
|
|
@@ -332,8 +332,8 @@ def batch_embed(
|
|
| 332 |
drop_last=False,
|
| 333 |
collate_fn=partial(collate_fn, device=device)
|
| 334 |
):
|
| 335 |
-
ids = batch["input_ids"]
|
| 336 |
-
mask = batch["attention_mask"]
|
| 337 |
|
| 338 |
t_ids = torch.zeros_like(ids)
|
| 339 |
|
|
|
|
| 226 |
batch[k].append(v)
|
| 227 |
|
| 228 |
return {
|
| 229 |
+
k: torch.tensor(v, dtype=torch.long) if k in {"attention_mask", "input_ids"} else v for k, v in batch.items()
|
| 230 |
}
|
| 231 |
|
| 232 |
@torch.inference_mode()
|
|
|
|
| 332 |
drop_last=False,
|
| 333 |
collate_fn=partial(collate_fn, device=device)
|
| 334 |
):
|
| 335 |
+
ids = batch["input_ids"].to(device)
|
| 336 |
+
mask = batch["attention_mask"].to(device)
|
| 337 |
|
| 338 |
t_ids = torch.zeros_like(ids)
|
| 339 |
|