mjkmain committed on
Commit
cfa99fc
·
verified ·
1 Parent(s): ec64b46

Update _modeling_kormo.py

Browse files
Files changed (1) hide show
  1. _modeling_kormo.py +1 -7
_modeling_kormo.py CHANGED
@@ -1,7 +1,6 @@
1
  from typing import Callable, List, Optional, Tuple, Union
2
 
3
  import torch
4
- import torch.utils.checkpoint ### ADD
5
  from torch import nn
6
 
7
  from transformers.activations import ACT2FN
@@ -17,13 +16,10 @@ from transformers.modeling_outputs import (
17
  )
18
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
19
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
20
- # from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
21
  from transformers.processing_utils import Unpack
22
- from transformers.utils import auto_docstring, can_return_tuple, logging
23
  from ._configuration_kormo import KORMoConfig
24
 
25
- # from ._flash_attn3_doc import flash_attention_3_doc_forward
26
- # ALL_ATTENTION_FUNCTIONS._global_mapping.update({'flash_attention_3_doc': flash_attention_3_doc_forward})
27
 
28
  logger = logging.get_logger(__name__)
29
 
@@ -421,8 +417,6 @@ class KORMoModel(KORMoPreTrainedModel):
421
  )
422
 
423
 
424
- # class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
425
-
426
  class KORMoForCausalLM(KORMoPreTrainedModel, GenerationMixin):
427
  _tied_weights_keys = ["lm_head.weight"]
428
  _tp_plan = {"lm_head": "colwise_rep"}
 
1
  from typing import Callable, List, Optional, Tuple, Union
2
 
3
  import torch
 
4
  from torch import nn
5
 
6
  from transformers.activations import ACT2FN
 
16
  )
17
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
18
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 
19
  from transformers.processing_utils import Unpack
20
+ from transformers.utils import can_return_tuple, logging
21
  from ._configuration_kormo import KORMoConfig
22
 
 
 
23
 
24
  logger = logging.get_logger(__name__)
25
 
 
417
  )
418
 
419
 
 
 
420
  class KORMoForCausalLM(KORMoPreTrainedModel, GenerationMixin):
421
  _tied_weights_keys = ["lm_head.weight"]
422
  _tp_plan = {"lm_head": "colwise_rep"}