Update modeling_gpt_refact.py
Browse files- modeling_gpt_refact.py +2 -2
modeling_gpt_refact.py
CHANGED
|
@@ -508,9 +508,9 @@ class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
|
|
| 508 |
import transformers
|
| 509 |
from packaging import version
|
| 510 |
|
| 511 |
-
def _set_gradient_checkpointing(module,
|
| 512 |
if isinstance(module, GPTRefactModel):
|
| 513 |
-
module.gradient_checkpointing =
|
| 514 |
|
| 515 |
v = version.parse(transformers.__version__)
|
| 516 |
if v.major <= 4 and v.minor < 35:
|
|
|
|
| 508 |
import transformers
|
| 509 |
from packaging import version
|
| 510 |
|
| 511 |
+
def _set_gradient_checkpointing(module, value=False):
|
| 512 |
if isinstance(module, GPTRefactModel):
|
| 513 |
+
module.gradient_checkpointing = value
|
| 514 |
|
| 515 |
v = version.parse(transformers.__version__)
|
| 516 |
if v.major <= 4 and v.minor < 35:
|