Update custom_generate/generate.py
Browse files
custom_generate/generate.py
CHANGED
|
@@ -282,14 +282,7 @@ def _contrastive_search(
|
|
| 282 |
f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
|
| 283 |
"for contrastive search."
|
| 284 |
)
|
| 285 |
-
|
| 286 |
-
not isinstance(past_key_values[0], (tuple, torch.Tensor))
|
| 287 |
-
or past_key_values[0][0].shape[0] != batch_size
|
| 288 |
-
):
|
| 289 |
-
raise ValueError(
|
| 290 |
-
f"{model.__class__.__name__} does not have a standard cache format and therefore **can't** be "
|
| 291 |
-
"used for contrastive search without further modifications."
|
| 292 |
-
)
|
| 293 |
|
| 294 |
# contrastive_search main logic start:
|
| 295 |
# contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
|
|
|
|
| 282 |
f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
|
| 283 |
"for contrastive search."
|
| 284 |
)
|
| 285 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
# contrastive_search main logic start:
|
| 288 |
# contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
|