Add files using upload-large-folder tool
- README.md +2 -2
 - __init__.py +26 -0
 - __pycache__/configuration_minimax_m2.cpython-313.pyc +0 -0
 - config.json +6 -2
 - configuration_minimax_m2.py +147 -0
 - modeling_minimax_m2.py +765 -0
 - test_minimax_m2_hf.py +178 -0
 
    	
README.md CHANGED
@@ -4,7 +4,7 @@ license: mit
 ---
 
 <div align="center">
-Upconverted to BFloat16 by <a href='https://x.com/qubitium'>@Qubitum</a> at ModelCloud
+Upconverted to BFloat16 by <a href='https://x.com/qubitium'>@Qubitum</a> at ModelCloud.
 
 <svg width="60%" height="auto" viewBox="0 0 144 48" fill="none" xmlns="http://www.w3.org/2000/svg">
 <path d="M26.6782 7.96523C26.6782 7.02436 25.913 6.26087 24.9739 6.26087C24.0348 6.26087 23.2695 7.0261 23.2695 7.96523V36.2139C23.2695 38.4 21.4904 40.1791 19.3043 40.1791C17.1183 40.1791 15.3391 38.4 15.3391 36.2139V18.0904C15.3391 17.1496 14.5739 16.3861 13.6348 16.3861C12.6956 16.3861 11.9304 17.1513 11.9304 18.0904V25.7722C11.9304 27.9583 10.1513 29.7374 7.96518 29.7374C5.7791 29.7374 4 27.9583 4 25.7722V22.9878C4 22.3635 4.50609 21.8574 5.13043 21.8574C5.75478 21.8574 6.26087 22.3635 6.26087 22.9878V25.7722C6.26087 26.713 7.02605 27.4765 7.96518 27.4765C8.90431 27.4765 9.66954 26.7113 9.66954 25.7722V18.0904C9.66954 15.9044 11.4487 14.1252 13.6348 14.1252C15.8209 14.1252 17.6 15.9044 17.6 18.0904V36.2139C17.6 37.1548 18.3652 37.9183 19.3043 37.9183C20.2435 37.9183 21.0087 37.153 21.0087 36.2139V25.1322V7.96523C21.0087 5.77914 22.7878 4 24.9739 4C27.16 4 28.9391 5.77914 28.9391 7.96523V31.3565C28.9391 31.9809 28.433 32.487 27.8087 32.487C27.1843 32.487 26.6782 31.9809 26.6782 31.3565V7.96523ZM47.6539 14.1252C45.4678 14.1252 43.6887 15.9044 43.6887 18.0904V33.2296C43.6887 34.1704 42.9235 34.9339 41.9843 34.9339C41.0452 34.9339 40.28 34.1687 40.28 33.2296V7.96523C40.28 5.77914 38.5008 4 36.3148 4C34.1287 4 32.3496 5.77914 32.3496 7.96523V40.0348C32.3496 40.9756 31.5843 41.7391 30.6452 41.7391C29.7061 41.7391 28.9409 40.9739 28.9409 40.0348V36.0643C28.9409 35.44 28.4348 34.9339 27.8104 34.9339C27.1861 34.9339 26.68 35.44 26.68 36.0643V40.0348C26.68 42.2209 28.4591 44 30.6452 44C32.8313 44 34.6104 42.2209 34.6104 40.0348V7.96523C34.6104 7.02436 35.3756 6.26087 36.3148 6.26087C37.2539 6.26087 38.0191 7.0261 38.0191 7.96523V33.2296C38.0191 35.4156 39.7982 37.1948 41.9843 37.1948C44.1704 37.1948 45.9496 35.4156 45.9496 33.2296V18.0904C45.9496 17.1496 46.7148 16.3861 47.6539 16.3861C48.593 16.3861 49.3582 17.1513 49.3582 18.0904V31.3565C49.3582 31.9809 49.8643 32.487 50.4887 32.487C51.113 32.487 51.6191 31.9809 51.6191 31.3565V18.0904C51.6191 15.9044 49.84 14.1252 47.6539 14.1252Z" fill="url(#paint0_linear_17_483)"/>
@@ -186,4 +186,4 @@ Please refer to our [Tool Calling Guide](https://huggingface.co/MiniMaxAI/MiniMa
 
 # Contact Us
 
-Contact us at [model@minimax.io](mailto:model@minimax.io).
+Contact us at [model@minimax.io](mailto:model@minimax.io).
    	
__init__.py ADDED
@@ -0,0 +1,26 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+#
+# """MiniMax M2 Hugging Face remote code support."""
+
+from .configuration_minimax_m2 import MiniMaxM2Config
+from .modeling_minimax_m2 import (
+    MiniMaxForCausalLM,
+    MiniMaxM2ForCausalLM,
+    MiniMaxM2Model,
+    MiniMaxM2PreTrainedModel,
+    MiniMaxModel,
+    MiniMaxPreTrainedModel,
+)
+
+__all__ = [
+    "MiniMaxM2Config",
+    "MiniMaxM2PreTrainedModel",
+    "MiniMaxM2Model",
+    "MiniMaxM2ForCausalLM",
+    "MiniMaxPreTrainedModel",
+    "MiniMaxModel",
+    "MiniMaxForCausalLM",
+]
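For local experimentation the same classes can be imported straight from these files as a package. A minimal sketch, assuming the commit is checked out into a package directory named minimax_m2 (a hypothetical name) and a transformers release recent enough to provide transformers.masking_utils:

import minimax_m2  # hypothetical local package directory holding this commit's files

# The export surface defined in __init__.py above.
print(minimax_m2.__all__)
# ['MiniMaxM2Config', 'MiniMaxM2PreTrainedModel', 'MiniMaxM2Model', 'MiniMaxM2ForCausalLM',
#  'MiniMaxPreTrainedModel', 'MiniMaxModel', 'MiniMaxForCausalLM']

# Build a deliberately tiny config rather than the 62-layer / 256-expert default.
cfg = minimax_m2.MiniMaxM2Config(num_hidden_layers=2, num_local_experts=8, num_experts_per_tok=2)
print(cfg.model_type)  # "minimax"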
    	
__pycache__/configuration_minimax_m2.cpython-313.pyc ADDED
Binary file (6.02 kB).
config.json CHANGED
@@ -79,7 +79,7 @@
   "layernorm_mlp_beta": 1.0,
   "max_position_embeddings": 196608,
   "mlp_intermediate_size": 8192,
-  "model_type": "
+  "model_type": "minimax",
   "mtp_transformer_layers": 1,
   "num_attention_heads": 48,
   "num_experts_per_tok": 8,
@@ -105,5 +105,9 @@
   "use_qk_norm": true,
   "use_routing_bias": true,
   "vocab_size": 200064,
-  "torch_dtype": "bfloat16"
+  "torch_dtype": "bfloat16",
+  "auto_map": {
+    "AutoConfig": "configuration_minimax_m2.MiniMaxM2Config",
+    "AutoModelForCausalLM": "modeling_minimax_m2.MiniMaxM2ForCausalLM"
+  }
 }
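With the auto_map entries in place, a stock transformers install can resolve these classes through the Auto API once remote code is trusted. A minimal loading sketch; the repo id below is a placeholder (the actual hosting location is not stated in this commit), and the dtype simply mirrors the torch_dtype recorded above:

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "ModelCloud/MiniMax-M2-BF16"  # hypothetical repo id, replace with the real one

# trust_remote_code=True lets the auto_map above dispatch to the bundled Python files.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)  # -> MiniMaxM2Config
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,   # dispatches to modeling_minimax_m2.MiniMaxM2ForCausalLM
    torch_dtype="bfloat16",   # matches the "torch_dtype" value recorded in config.json
)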
    	
configuration_minimax_m2.py ADDED
@@ -0,0 +1,147 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+"""Configuration for the MiniMax M2 architecture."""
+
+from __future__ import annotations
+
+from typing import List, Optional, Union
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class _QuantizationConfigDict(dict):
+    """Ensure quantization config always exposes a `quant_method`."""
+
+    def __init__(self, data: Optional[dict] = None):
+        if data is None:
+            data = {}
+        super().__init__(data)
+        self.setdefault("quant_method", "none")
+
+    def to_dict(self):
+        return dict(self)
+
+
+class MiniMaxM2Config(PretrainedConfig):
+    model_type = "minimax"
+
+    def __init__(
+        self,
+        vocab_size: int = 200_064,
+        hidden_size: int = 3_072,
+        intermediate_size: int = 1_536,
+        mlp_intermediate_size: int = 8_192,
+        num_hidden_layers: int = 62,
+        num_attention_heads: int = 48,
+        num_key_value_heads: int = 8,
+        head_dim: Optional[int] = 128,
+        num_local_experts: int = 256,
+        num_experts_per_tok: int = 8,
+        attn_type_list: Optional[List[int]] = None,
+        attention_dropout: float = 0.0,
+        hidden_act: str = "silu",
+        rms_norm_eps: float = 1e-6,
+        max_position_embeddings: int = 196_608,
+        rope_theta: float = 5_000_000.0,
+        rotary_dim: int = 64,
+        rope_scaling: Optional[dict] = None,
+        use_qk_norm: bool = True,
+        qk_norm_type: str = "per_layer",
+        use_routing_bias: bool = True,
+        scoring_func: str = "sigmoid",
+        router_aux_loss_coef: float = 0.001,
+        router_jitter_noise: float = 0.0,
+        output_router_logits: bool = False,
+        use_grouped_topk: bool = True,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        routed_scaling_factor: float = 1.0,
+        layernorm_full_attention_beta: float = 1.0,
+        layernorm_linear_attention_beta: float = 1.0,
+        layernorm_mlp_beta: float = 1.0,
+        shared_intermediate_size: int = 0,
+        shared_moe_mode: str = "sigmoid",
+        use_mtp: bool = True,
+        num_mtp_modules: int = 3,
+        mtp_transformer_layers: int = 1,
+        attn_window_size: Optional[Union[int, List[int]]] = None,
+        swa_rope_theta: float = -1.0,
+        sliding_window: Optional[int] = None,
+        initializer_range: float = 0.02,
+        tie_word_embeddings: bool = False,
+        max_model_len: Optional[int] = None,
+        bos_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        use_cache: bool = True,
+        **kwargs,
+    ) -> None:
+        quantization_config = kwargs.pop("quantization_config", None)
+        if quantization_config is None:
+            quantization_config = _QuantizationConfigDict()
+        elif not isinstance(quantization_config, _QuantizationConfigDict):
+            quantization_config = _QuantizationConfigDict(quantization_config)
+        transformers_version = kwargs.pop("transformers_version", None)
+
+        super().__init__(
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            pad_token_id=pad_token_id,
+            **kwargs,
+        )
+
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.mlp_intermediate_size = mlp_intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim or hidden_size // num_attention_heads
+        self.num_local_experts = num_local_experts
+        self.num_experts_per_tok = num_experts_per_tok
+        self.attn_type_list = attn_type_list or [1] * num_hidden_layers
+        self.attention_dropout = attention_dropout
+        self.hidden_act = hidden_act
+        self.rms_norm_eps = rms_norm_eps
+        self.max_position_embeddings = max_position_embeddings
+        self.rope_theta = rope_theta
+        self.rotary_dim = rotary_dim
+        self.rope_scaling = rope_scaling
+        self.use_qk_norm = use_qk_norm
+        self.qk_norm_type = qk_norm_type
+        self.use_routing_bias = use_routing_bias
+        self.scoring_func = scoring_func
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.router_jitter_noise = router_jitter_noise
+        self.output_router_logits = output_router_logits
+        self.use_grouped_topk = use_grouped_topk
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.routed_scaling_factor = routed_scaling_factor
+        self.layernorm_full_attention_beta = layernorm_full_attention_beta
+        self.layernorm_linear_attention_beta = layernorm_linear_attention_beta
+        self.layernorm_mlp_beta = layernorm_mlp_beta
+        self.shared_intermediate_size = shared_intermediate_size
+        self.shared_moe_mode = shared_moe_mode
+        self.use_mtp = use_mtp
+        self.num_mtp_modules = num_mtp_modules
+        self.mtp_transformer_layers = mtp_transformer_layers
+        self.attn_window_size = attn_window_size
+        self.swa_rope_theta = swa_rope_theta
+        self.sliding_window = sliding_window
+        self.initializer_range = initializer_range
+        self.max_model_len = max_model_len
+        self.use_cache = use_cache
+
+        # Convenient accessor used by rotary embedding helper
+        self.partial_rotary_factor = float(self.rotary_dim) / float(self.head_dim)
+        self.quantization_config = quantization_config
+        self.transformers_version = transformers_version
+
+
+__all__ = ["MiniMaxM2Config"]
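A small sketch of the derived attributes this constructor computes; importing the module directly by file name is an assumption about the local setup:

from configuration_minimax_m2 import MiniMaxM2Config

cfg = MiniMaxM2Config()
print(cfg.partial_rotary_factor)  # 64 / 128 = 0.5, consumed by the rotary embedding helper
print(cfg.quantization_config)    # {'quant_method': 'none'}, injected by _QuantizationConfigDict
print(len(cfg.attn_type_list))    # 62 -> one attention-type flag per hidden layer by default
print(cfg.head_dim)               # 128 (explicit default, not hidden_size // num_attention_heads)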
    	
        modeling_minimax_m2.py
    ADDED
    
    | 
         @@ -0,0 +1,765 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
         
     | 
| 2 | 
         
            +
            # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
         
     | 
| 3 | 
         
            +
            # SPDX-License-Identifier: Apache-2.0
         
     | 
| 4 | 
         
            +
            # Contact: qubitium@modelcloud.ai, x.com/qubitium
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            """PyTorch implementation of the MiniMax M2 architecture for Hugging Face Transformers."""
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
            from __future__ import annotations
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            import copy
         
     | 
| 11 | 
         
            +
            import time
         
     | 
| 12 | 
         
            +
            from typing import Optional, Tuple, Union
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
            import torch
         
     | 
| 15 | 
         
            +
            import torch.nn.functional as F
         
     | 
| 16 | 
         
            +
            from torch import nn
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            from transformers.activations import ACT2FN
         
     | 
| 19 | 
         
            +
            from transformers.cache_utils import Cache, DynamicCache
         
     | 
| 20 | 
         
            +
            from transformers.generation import GenerationMixin
         
     | 
| 21 | 
         
            +
            from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
         
     | 
| 22 | 
         
            +
            from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
         
     | 
| 23 | 
         
            +
            from transformers.modeling_utils import PreTrainedModel
         
     | 
| 24 | 
         
            +
            from transformers.utils import logging
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
            from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, repeat_kv, rotate_half
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
            from .configuration_minimax_m2 import MiniMaxM2Config
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            logger = logging.get_logger(__name__)
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
            _CONFIG_FOR_DOC = "MiniMaxM2Config"
         
     | 
| 33 | 
         
            +
            _CHECKPOINT_FOR_DOC = "MiniMaxAI/MiniMax-M2"
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
            def load_balancing_loss_func(
         
     | 
| 37 | 
         
            +
                gate_logits: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
         
     | 
| 38 | 
         
            +
                num_experts: int,
         
     | 
| 39 | 
         
            +
                top_k: int,
         
     | 
| 40 | 
         
            +
                attention_mask: Optional[torch.Tensor] = None,
         
     | 
| 41 | 
         
            +
            ) -> torch.Tensor:
         
     | 
| 42 | 
         
            +
                if gate_logits is None:
         
     | 
| 43 | 
         
            +
                    return torch.tensor(0.0)
         
     | 
| 44 | 
         
            +
                if isinstance(gate_logits, torch.Tensor):
         
     | 
| 45 | 
         
            +
                    logits = gate_logits
         
     | 
| 46 | 
         
            +
                else:
         
     | 
| 47 | 
         
            +
                    logits = torch.cat([layer_gate.to(gate_logits[0].device) for layer_gate in gate_logits], dim=0)
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
                routing_weights = torch.softmax(logits, dim=-1, dtype=torch.float32)
         
     | 
| 50 | 
         
            +
                _, selected = torch.topk(routing_weights, top_k, dim=-1)
         
     | 
| 51 | 
         
            +
                expert_mask = torch.nn.functional.one_hot(selected, num_experts)
         
     | 
| 52 | 
         
            +
             
     | 
| 53 | 
         
            +
                if attention_mask is None:
         
     | 
| 54 | 
         
            +
                    tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
         
     | 
| 55 | 
         
            +
                    router_prob_per_expert = torch.mean(routing_weights, dim=0)
         
     | 
| 56 | 
         
            +
                else:
         
     | 
| 57 | 
         
            +
                    batch_size, seq_len = attention_mask.shape
         
     | 
| 58 | 
         
            +
                    num_layers = logits.shape[0] // (batch_size * seq_len)
         
     | 
| 59 | 
         
            +
             
     | 
| 60 | 
         
            +
                    expanded_mask = (
         
     | 
| 61 | 
         
            +
                        attention_mask[None, :, :, None, None]
         
     | 
| 62 | 
         
            +
                        .expand(num_layers, batch_size, seq_len, top_k, num_experts)
         
     | 
| 63 | 
         
            +
                        .reshape(-1, top_k, num_experts)
         
     | 
| 64 | 
         
            +
                        .to(logits.device)
         
     | 
| 65 | 
         
            +
                    )
         
     | 
| 66 | 
         
            +
                    tokens_per_expert = torch.sum(expert_mask.float() * expanded_mask, dim=0) / torch.sum(expanded_mask, dim=0)
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
                    router_mask = (
         
     | 
| 69 | 
         
            +
                        attention_mask[None, :, :, None]
         
     | 
| 70 | 
         
            +
                        .expand(num_layers, batch_size, seq_len, num_experts)
         
     | 
| 71 | 
         
            +
                        .reshape(-1, num_experts)
         
     | 
| 72 | 
         
            +
                        .to(logits.device)
         
     | 
| 73 | 
         
            +
                    )
         
     | 
| 74 | 
         
            +
                    router_prob_per_expert = torch.sum(routing_weights * router_mask, dim=0) / torch.sum(router_mask, dim=0)
         
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
                loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
         
     | 
| 77 | 
         
            +
                return loss * num_experts
         
     | 
| 78 | 
         
            +
             
     | 
| 79 | 
         
            +
             
     | 
| 80 | 
         
            +
            def apply_rotary_pos_emb_partial(
         
     | 
| 81 | 
         
            +
                q: torch.Tensor,
         
     | 
| 82 | 
         
            +
                k: torch.Tensor,
         
     | 
| 83 | 
         
            +
                cos: torch.Tensor,
         
     | 
| 84 | 
         
            +
                sin: torch.Tensor,
         
     | 
| 85 | 
         
            +
                rotary_dim: int,
         
     | 
| 86 | 
         
            +
                unsqueeze_dim: int = 2,
         
     | 
| 87 | 
         
            +
            ) -> Tuple[torch.Tensor, torch.Tensor]:
         
     | 
| 88 | 
         
            +
                cos = cos.unsqueeze(unsqueeze_dim)[..., :rotary_dim]
         
     | 
| 89 | 
         
            +
                sin = sin.unsqueeze(unsqueeze_dim)[..., :rotary_dim]
         
     | 
| 90 | 
         
            +
                q_rot = q[..., :rotary_dim]
         
     | 
| 91 | 
         
            +
                k_rot = k[..., :rotary_dim]
         
     | 
| 92 | 
         
            +
             
     | 
| 93 | 
         
            +
                q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
         
     | 
| 94 | 
         
            +
                k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin)
         
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
                q = torch.cat((q_rot, q[..., rotary_dim:]), dim=-1)
         
     | 
| 97 | 
         
            +
                k = torch.cat((k_rot, k[..., rotary_dim:]), dim=-1)
         
     | 
| 98 | 
         
            +
                return q, k
         
     | 
| 99 | 
         
            +
             
     | 
| 100 | 
         
            +
             
     | 
| 101 | 
         
            +
            class MiniMaxM2RMSNorm(nn.Module):
         
     | 
| 102 | 
         
            +
                def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
         
     | 
| 103 | 
         
            +
                    super().__init__()
         
     | 
| 104 | 
         
            +
                    self.weight = nn.Parameter(torch.ones(hidden_size))
         
     | 
| 105 | 
         
            +
                    self.variance_epsilon = eps
         
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
                def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         
     | 
| 108 | 
         
            +
                    input_dtype = hidden_states.dtype
         
     | 
| 109 | 
         
            +
                    hidden_states = hidden_states.to(torch.float32)
         
     | 
| 110 | 
         
            +
                    variance = hidden_states.pow(2).mean(-1, keepdim=True)
         
     | 
| 111 | 
         
            +
                    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         
     | 
| 112 | 
         
            +
                    return (self.weight * hidden_states).to(input_dtype)
         
     | 
| 113 | 
         
            +
             
     | 
| 114 | 
         
            +
             
     | 
| 115 | 
         
            +
            class MiniMaxM2MLP(nn.Module):
         
     | 
| 116 | 
         
            +
                def __init__(self, config: MiniMaxM2Config) -> None:
         
     | 
| 117 | 
         
            +
                    super().__init__()
         
     | 
| 118 | 
         
            +
                    self.hidden_size = config.hidden_size
         
     | 
| 119 | 
         
            +
                    self.intermediate_size = config.intermediate_size
         
     | 
| 120 | 
         
            +
             
     | 
| 121 | 
         
            +
                    self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         
     | 
| 122 | 
         
            +
                    self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         
     | 
| 123 | 
         
            +
                    self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         
     | 
| 124 | 
         
            +
                    self.act_fn = ACT2FN[config.hidden_act]
         
     | 
| 125 | 
         
            +
             
     | 
| 126 | 
         
            +
                def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         
     | 
| 127 | 
         
            +
                    gate = self.act_fn(self.w1(hidden_states))
         
     | 
| 128 | 
         
            +
                    up = self.w3(hidden_states)
         
     | 
| 129 | 
         
            +
                    hidden_states = gate * up
         
     | 
| 130 | 
         
            +
                    hidden_states = self.w2(hidden_states)
         
     | 
| 131 | 
         
            +
                    return hidden_states
         

class MiniMaxM2SparseMoeBlock(nn.Module):
    def __init__(self, config: MiniMaxM2Config) -> None:
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.experts = nn.ModuleList([MiniMaxM2MLP(config) for _ in range(config.num_local_experts)])
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        self.jitter_noise = config.router_jitter_noise
        self.use_routing_bias = config.use_routing_bias
        self.scoring_func = getattr(config, "scoring_func", "softmax")
        self.use_grouped_topk = getattr(config, "use_grouped_topk", False)
        self.num_expert_group = getattr(config, "num_expert_group", None)
        self.topk_group = getattr(config, "topk_group", None)
        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

        if self.use_grouped_topk:
            if self.num_expert_group is None or self.num_expert_group <= 0:
                self.num_expert_group = 1
            if self.topk_group is None or self.topk_group <= 0:
                self.topk_group = min(self.num_expert_group, self.top_k)
        else:
            self.num_expert_group = 1
            self.topk_group = 1

        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        if self.use_routing_bias:
            self.e_score_correction_bias = nn.Parameter(torch.zeros(self.num_experts, dtype=torch.float32))
        else:
            self.register_parameter("e_score_correction_bias", None)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_len, hidden_dim = hidden_states.shape
        if self.training and self.jitter_noise > 0:
            noise = torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise,
                1.0 + self.jitter_noise,
            )
            hidden_states = hidden_states * noise

        # Flatten to (num_tokens, hidden_dim); routing is computed per token in float32.
        hidden_states = hidden_states.view(-1, hidden_dim)
        gate_dtype = self.gate.weight.dtype
        router_logits = self.gate(hidden_states.to(gate_dtype)).to(torch.float32)
        if self.e_score_correction_bias is not None:
            # Bias is applied after scoring (see vLLM/SGLang implementations).
            correction_bias = self.e_score_correction_bias.to(router_logits.device, router_logits.dtype)
        else:
            correction_bias = None

        if self.scoring_func == "sigmoid":
            scores = torch.sigmoid(router_logits)
        elif self.scoring_func == "softmax":
            scores = torch.softmax(router_logits, dim=-1)
        else:
            raise ValueError(f"Unsupported scoring function: {self.scoring_func}")

        if correction_bias is not None:
            original_scores = scores
            scores = scores + correction_bias
        else:
            original_scores = scores
        topk_scores: torch.Tensor
        # Grouped top-k: pick the best expert groups first, then the top-k experts inside them.
        if self.use_grouped_topk and self.num_expert_group > 1:
            experts_per_group = scores.size(-1) // self.num_expert_group
            scores_grouped = scores.view(scores.size(0), self.num_expert_group, experts_per_group)
            if correction_bias is not None:
                topk_in_group = min(2, experts_per_group)
                if topk_in_group > 0:
                    group_scores = scores_grouped.topk(topk_in_group, dim=-1)[0].sum(dim=-1)
                else:
                    group_scores = torch.zeros_like(scores_grouped[..., 0])
            else:
                group_scores = scores_grouped.max(dim=-1).values
            group_mask = torch.zeros_like(group_scores)
            selected_groups = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True).indices
            group_mask.scatter_(1, selected_groups, 1.0)
            mask = group_mask.unsqueeze(-1).expand(-1, -1, experts_per_group).reshape(scores.size())
            masked_scores = scores.masked_fill(mask == 0, float("-inf"))
            topk_scores, selected_experts = torch.topk(masked_scores, self.top_k, dim=-1, sorted=True)
        else:
            topk_scores, selected_experts = torch.topk(scores, self.top_k, dim=-1, sorted=True)

        if correction_bias is not None:
            routing_weights = original_scores.gather(1, selected_experts)
        else:
            routing_weights = topk_scores

        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-12)
        if self.routed_scaling_factor != 1.0:
            routing_weights = routing_weights * self.routed_scaling_factor
        routing_weights = routing_weights.to(hidden_states.dtype)
        selected_experts = selected_experts.to(torch.long)

        final_hidden_states = torch.zeros_like(hidden_states)
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
        expert_hit = torch.nonzero(expert_mask.sum(dim=(-1, -2)) > 0, as_tuple=False).flatten()

        # Dispatch: each selected expert runs only on the tokens routed to it; the weighted
        # outputs are scatter-added back into the flat token buffer.
        for expert_idx in expert_hit.tolist():
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
            token_states = hidden_states.index_select(0, top_x)
            expert_output = expert_layer(token_states) * routing_weights[top_x, idx].unsqueeze(-1)
            final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype))

        final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
        return final_hidden_states, router_logits
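
# --- Illustrative sketch (not part of the original MiniMax implementation) ---
# The block above scores every expert per token, keeps the top-k (optionally within
# expert groups and with a score-correction bias), renormalizes the kept scores, and
# uses them to weight the expert outputs. A minimal rendition of the plain softmax
# top-k path, with `gate` and `top_k` standing in for the module's attributes:
def _topk_routing_sketch(token_states: torch.Tensor, gate: nn.Linear, top_k: int = 2):
    scores = torch.softmax(gate(token_states).float(), dim=-1)      # (num_tokens, num_experts)
    topk_scores, topk_experts = torch.topk(scores, top_k, dim=-1)   # best k experts per token
    weights = topk_scores / topk_scores.sum(dim=-1, keepdim=True)   # renormalize over kept experts
    return weights, topk_experts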
         

class MiniMaxM2Attention(nn.Module):
    def __init__(self, config: MiniMaxM2Config, layer_idx: int) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.head_dim = config.head_dim
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // max(1, self.num_key_value_heads)
        self.rotary_dim = config.rotary_dim
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        max_model_len = getattr(config, "max_model_len", None)
        if max_model_len is not None:
            max_position_embeddings = max(max_position_embeddings, max_model_len)

        attn_window_size = getattr(config, "attn_window_size", None)
        if isinstance(attn_window_size, list):
            sliding_window = attn_window_size[layer_idx]
        else:
            sliding_window = attn_window_size
        if sliding_window is not None and sliding_window <= 0:
            sliding_window = None
        self.sliding_window = sliding_window

        swa_rope_theta = getattr(config, "swa_rope_theta", -1.0)
        rope_theta = config.rope_theta
        if self.sliding_window is not None and swa_rope_theta > 0:
            rope_theta = swa_rope_theta

        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

        self.use_qk_norm = config.use_qk_norm
        if self.use_qk_norm:
            self.q_norm = MiniMaxM2RMSNorm(self.num_heads * self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = MiniMaxM2RMSNorm(self.num_key_value_heads * self.head_dim, eps=config.rms_norm_eps)

        rope_config = copy.deepcopy(config)
        rope_config.hidden_size = config.hidden_size
        rope_config.num_attention_heads = config.num_attention_heads
        rope_config.partial_rotary_factor = float(config.rotary_dim) / float(self.head_dim)
        rope_config.rope_theta = rope_theta
        rope_config.max_position_embeddings = max_position_embeddings
        self.rotary_emb = LlamaRotaryEmbedding(rope_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # QK-norm operates on the flattened (num_heads * head_dim) projections, matching the checkpoint layout.
        if self.use_qk_norm:
            q_flat = query_states.transpose(1, 2).reshape(bsz * q_len, -1)
            k_flat = key_states.transpose(1, 2).reshape(bsz * q_len, -1)
            q_flat = self.q_norm(q_flat)
            k_flat = self.k_norm(k_flat)
            query_states = q_flat.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            key_states = k_flat.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings

        query_states, key_states = apply_rotary_pos_emb_partial(
            query_states.transpose(1, 2), key_states.transpose(1, 2), cos, sin, self.rotary_dim
        )
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)

        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Grouped-query attention: expand the KV heads to match the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Prefill-only sliding-window mask: keys more than `sliding_window` positions behind a query are masked out.
        if self.sliding_window is not None and past_key_values is None:
            query_positions = torch.arange(q_len, device=hidden_states.device).view(1, 1, q_len, 1)
            key_positions = torch.arange(key_states.shape[-2], device=hidden_states.device).view(1, 1, 1, -1)
            window_mask = key_positions < (query_positions - self.sliding_window)
            if window_mask.any():
                attn_weights = attn_weights.masked_fill(window_mask, float("-inf"))

        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        if self.training and self.attention_dropout > 0:
            attn_weights = F.dropout(attn_weights, p=self.attention_dropout)

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights
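
# --- Illustrative sketch (not part of the original MiniMax implementation) ---
# The prefill-time sliding-window logic above masks any key that sits more than
# `sliding_window` positions behind its query. The same rule in isolation:
def _sliding_window_mask_sketch(q_len: int, kv_len: int, window: int) -> torch.Tensor:
    q_pos = torch.arange(q_len).view(q_len, 1)
    k_pos = torch.arange(kv_len).view(1, kv_len)
    # True marks (query, key) pairs that fall outside the window and must be masked.
    return k_pos < (q_pos - window)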
         

class MiniMaxM2LogitsProcessor(nn.Module):
    def __init__(self, config: MiniMaxM2Config) -> None:
        super().__init__()
        self.scale = getattr(config, "logits_scale", 1.0)

    def forward(self, lm_head: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = lm_head(hidden_states)
        if self.scale != 1.0:
            logits = logits * self.scale
        return logits
         

class MiniMaxM2DecoderLayer(nn.Module):
    def __init__(self, config: MiniMaxM2Config, layer_idx: int) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = MiniMaxM2Attention(config, layer_idx)
        self.block_sparse_moe = MiniMaxM2SparseMoeBlock(config)
        self.input_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False,
        residual: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        residual_input = hidden_states if residual is None else residual
        hidden_states = self.input_layernorm(hidden_states)

        attn_output, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            output_attentions=output_attentions,
        )
        hidden_states = residual_input + attn_output

        residual_post_attn = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        moe_output, router_logits = self.block_sparse_moe(hidden_states)
        hidden_states = residual_post_attn + moe_output

        return hidden_states, hidden_states, router_logits, attn_weights
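
# Dataflow note: the decoder layer above follows a pre-norm residual scheme,
#   h = x + Attn(RMSNorm(x))
#   y = h + MoE(RMSNorm(h))
# and returns the post-MoE hidden states twice so the caller can thread the same
# tensor through as both the next layer's input and its residual.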
         

class MiniMaxM2PreTrainedModel(PreTrainedModel):
    config_class = MiniMaxM2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MiniMaxM2DecoderLayer"]
    _supports_flash_attn = False
    _supports_sdpa = False
    _supports_attention_backend = False

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _remap_qkv_weights(self, state_dict):
        # Split fused `qkv_proj` checkpoint tensors into the separate q/k/v projections used by this module.
        num_q = self.config.num_attention_heads * self.config.head_dim
        num_kv = self.config.num_key_value_heads * self.config.head_dim

        for layer_idx in range(self.config.num_hidden_layers):
            prefix = f"model.layers.{layer_idx}.self_attn"
            weight_key = f"{prefix}.qkv_proj.weight"
            if weight_key in state_dict:
                qkv_weight = state_dict.pop(weight_key)
                q_weight, k_weight, v_weight = qkv_weight.split([num_q, num_kv, num_kv], dim=0)
                state_dict.setdefault(f"{prefix}.q_proj.weight", q_weight)
                state_dict.setdefault(f"{prefix}.k_proj.weight", k_weight)
                state_dict.setdefault(f"{prefix}.v_proj.weight", v_weight)

    def load_state_dict(self, state_dict, strict: bool = True):
        if not isinstance(state_dict, dict):
            raise TypeError(f"Expected state_dict to be dict, got {type(state_dict)}")

        filtered_state_dict = {}
        # Drop FP8 quantization artifacts (per-tensor scales, amax buffers, ...) so only dense weights are loaded.
        drop_suffixes = ("weight_scale_inv", "weight_scale", "input_scale", "scales", "amax")
        for key, value in state_dict.items():
            if key.endswith(drop_suffixes) or "fp8" in key:
                continue
            filtered_state_dict[key] = value

        self._remap_qkv_weights(filtered_state_dict)

        if logger.isEnabledFor(logging.INFO):
            logger.info(
                "MiniMaxM2: loading %d tensors (filtered from %d original).",
                len(filtered_state_dict),
                len(state_dict),
            )

        load_start = time.perf_counter()
        result = super().load_state_dict(filtered_state_dict, strict=strict)
        load_elapsed = time.perf_counter() - load_start
        if logger.isEnabledFor(logging.INFO):
            logger.info("MiniMaxM2: state_dict load finished in %.2f seconds.", load_elapsed)

        return result
         

class MiniMaxM2Model(MiniMaxM2PreTrainedModel):
    def __init__(self, config: MiniMaxM2Config) -> None:
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [MiniMaxM2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[MoeModelOutputWithPast, Tuple]:
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        if self.config.sliding_window is not None:
            causal_mask = create_sliding_window_causal_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )
        else:
            causal_mask = create_causal_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )

        hidden_states = inputs_embeds

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_router_logits = () if output_router_logits else None

        residual = None
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=None,
                output_attentions=output_attentions,
                residual=residual,
            )

            hidden_states, residual, router_logits, attn_weights = layer_outputs

            if output_router_logits:
                all_router_logits = all_router_logits + (router_logits,)
            if output_attentions:
                all_attentions = all_attentions + (attn_weights,)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            outputs = (hidden_states, past_key_values)
            if output_hidden_states:
                outputs += (all_hidden_states,)
            if output_attentions:
                outputs += (all_attentions,)
            if output_router_logits:
                outputs += (all_router_logits,)
            return outputs

        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            router_logits=all_router_logits,
        )
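
# --- Illustrative sketch (not part of the original MiniMax implementation) ---
# A tiny CPU smoke test of the backbone above. The keyword arguments are assumptions:
# they mirror the attribute names read in this file, and any field not listed must
# have a sensible default in MiniMaxM2Config for this to run.
def _tiny_backbone_smoke_test():
    config = MiniMaxM2Config(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=16,
        rotary_dim=16,
        num_local_experts=4,
        num_experts_per_tok=2,
    )
    model = MiniMaxM2Model(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids=input_ids, use_cache=False)
    return out.last_hidden_state.shape  # expected: torch.Size([1, 8, 64])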
         

class MiniMaxM2ForCausalLM(MiniMaxM2PreTrainedModel, GenerationMixin):
    def __init__(self, config: MiniMaxM2Config) -> None:
        super().__init__(config)
        self.model = MiniMaxM2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.logits_processor = MiniMaxM2LogitsProcessor(config)

        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.embed_tokens

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.model.embed_tokens = value

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -past_key_values.get_seq_length() - 1 :]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "inputs_embeds": inputs_embeds,
        }

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
    ) -> Union[MoeCausalLMOutputWithPast, Tuple]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        model_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=True,
        )

        hidden_states = model_outputs.last_hidden_state
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep > 0 else slice(None)
        logits = self.logits_processor(self.lm_head, hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))

        aux_loss = None
        if output_router_logits and model_outputs.router_logits is not None:
            aux_loss = load_balancing_loss_func(
                model_outputs.router_logits,
                num_experts=self.num_experts,
                top_k=self.num_experts_per_tok,
                attention_mask=attention_mask,
            )
            if loss is not None:
                loss = loss + self.router_aux_loss_coef * aux_loss.to(loss.device)

        if not return_dict:
            output = (logits,) + (model_outputs.past_key_values,)
         
     | 
| 714 | 
         
            +
                        if output_hidden_states:
         
     | 
| 715 | 
         
            +
                            output += (model_outputs.hidden_states,)
         
     | 
| 716 | 
         
            +
                        if output_attentions:
         
     | 
| 717 | 
         
            +
                            output += (model_outputs.attentions,)
         
     | 
| 718 | 
         
            +
                        if output_router_logits:
         
     | 
| 719 | 
         
            +
                            output += (model_outputs.router_logits,)
         
     | 
| 720 | 
         
            +
                        return ((loss,) + output) if loss is not None else output
         
     | 
| 721 | 
         
            +
             
     | 
| 722 | 
         
            +
                    return MoeCausalLMOutputWithPast(
         
     | 
| 723 | 
         
            +
                        loss=loss,
         
     | 
| 724 | 
         
            +
                        aux_loss=aux_loss,
         
     | 
| 725 | 
         
            +
                        logits=logits,
         
     | 
| 726 | 
         
            +
                        past_key_values=model_outputs.past_key_values,
         
     | 
| 727 | 
         
            +
                        hidden_states=model_outputs.hidden_states,
         
     | 
| 728 | 
         
            +
                        attentions=model_outputs.attentions,
         
     | 
| 729 | 
         
            +
                        router_logits=model_outputs.router_logits,
         
     | 
| 730 | 
         
            +
                    )
         
     | 
| 731 | 
         
            +
             
     | 
| 732 | 
         
            +
            # -----------------------------------------------------------------------------
         
     | 
| 733 | 
         
            +
            # Backward compatibility aliases
         
     | 
| 734 | 
         
            +
            # -----------------------------------------------------------------------------
         
     | 
| 735 | 
         
            +
             
     | 
| 736 | 
         
            +
            MiniMaxRMSNorm = MiniMaxM2RMSNorm
         
     | 
| 737 | 
         
            +
            MiniMaxSparseMoeBlock = MiniMaxM2SparseMoeBlock
         
     | 
| 738 | 
         
            +
            MiniMaxAttention = MiniMaxM2Attention
         
     | 
| 739 | 
         
            +
            MiniMaxDecoderLayer = MiniMaxM2DecoderLayer
         
     | 
| 740 | 
         
            +
            MiniMaxMLP = MiniMaxM2MLP
         
     | 
| 741 | 
         
            +
            MiniMaxPreTrainedModel = MiniMaxM2PreTrainedModel
         
     | 
| 742 | 
         
            +
            MiniMaxModel = MiniMaxM2Model
         
     | 
| 743 | 
         
            +
             
     | 
| 744 | 
         
            +
             
     | 
| 745 | 
         
            +
            class MiniMaxForCausalLM(MiniMaxM2ForCausalLM):
         
     | 
| 746 | 
         
            +
                """Alias for compatibility with checkpoints exporting MiniMaxForCausalLM."""
         
     | 
| 747 | 
         
            +
             
     | 
| 748 | 
         
            +
             
     | 
| 749 | 
         
            +
            __all__ = [
         
     | 
| 750 | 
         
            +
                "MiniMaxM2RMSNorm",
         
     | 
| 751 | 
         
            +
                "MiniMaxM2SparseMoeBlock",
         
     | 
| 752 | 
         
            +
                "MiniMaxM2Attention",
         
     | 
| 753 | 
         
            +
                "MiniMaxM2DecoderLayer",
         
     | 
| 754 | 
         
            +
                "MiniMaxM2Model",
         
     | 
| 755 | 
         
            +
                "MiniMaxM2ForCausalLM",
         
     | 
| 756 | 
         
            +
                "MiniMaxM2PreTrainedModel",
         
     | 
| 757 | 
         
            +
                "MiniMaxRMSNorm",
         
     | 
| 758 | 
         
            +
                "MiniMaxSparseMoeBlock",
         
     | 
| 759 | 
         
            +
                "MiniMaxAttention",
         
     | 
| 760 | 
         
            +
                "MiniMaxDecoderLayer",
         
     | 
| 761 | 
         
            +
                "MiniMaxPreTrainedModel",
         
     | 
| 762 | 
         
            +
                "MiniMaxModel",
         
     | 
| 763 | 
         
            +
                "MiniMaxMLP",
         
     | 
| 764 | 
         
            +
                "MiniMaxForCausalLM",
         
     | 
| 765 | 
         
            +
            ]
         
     | 
    	
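The forward pass above only materializes logits for the trailing `logits_to_keep` positions, and the backward-compatibility aliases let callers keep using the older `MiniMax*` class names. A minimal sketch of exercising this, assuming a local checkpoint that ships this module as its remote code (the path below is a placeholder, not from this repository):

# Sketch: load a local MiniMax-M2 checkpoint and compute logits for the last
# position only, mirroring the logits_to_keep slicing in forward() above.
# The checkpoint path is a placeholder; adjust it to a local directory.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "/path/to/MiniMax-M2-bf16"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    dtype="bfloat16",
    device_map="auto",
    trust_remote_code=True,
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
with torch.no_grad():
    # Only the final position's logits are kept; router logits are returned on request.
    outputs = model(**inputs, logits_to_keep=1, output_router_logits=True)

print(outputs.logits.shape)  # (batch, 1, vocab_size)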
test_minimax_m2_hf.py
ADDED
@@ -0,0 +1,178 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

"""
MiniMax-M2 Hugging Face checkpoint sanity check with streaming output.

Usage:
    python test_minimax_m2_hf.py \
        --model-path /monster/data/model/MiniMax-M2-bf16 \
        --question "How many letter A are there in the word Alphabet? Reply with the number only."
"""

from __future__ import annotations

import argparse
import threading
from pathlib import Path

import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from gptqmodel.hf_minimax_m2.modeling_minimax_m2 import (
    MiniMaxAttention,
    MiniMaxDecoderLayer,
    MiniMaxForCausalLM,
    MiniMaxMLP,
    MiniMaxM2Attention,
    MiniMaxM2DecoderLayer,
    MiniMaxM2ForCausalLM,
    MiniMaxM2MLP,
    MiniMaxM2RMSNorm,
    MiniMaxM2SparseMoeBlock,
    MiniMaxRMSNorm,
    MiniMaxSparseMoeBlock,
)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="MiniMax-M2 HF checkpoint smoke test.")
    parser.add_argument(
        "--model-path",
        type=str,
        default="/monster/data/model/MiniMax-M2-bf16",
        help="Path to the MiniMax-M2 Hugging Face checkpoint directory.",
    )
    parser.add_argument(
        "--question",
        type=str,
        default="How many letter A are there in the word Alphabet? Reply with the number only.",
        help="User question to send through the chat template.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=512,
        help="Maximum number of new tokens to sample from the model.",
    )
    return parser.parse_args()


def build_prompt(tokenizer: AutoTokenizer, question: str) -> str:
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": question},
    ]
    return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)


def assert_module_types(model: MiniMaxM2ForCausalLM) -> None:
    causal_lm_types = (MiniMaxM2ForCausalLM, MiniMaxForCausalLM)
    decoder_layer_types = (MiniMaxM2DecoderLayer, MiniMaxDecoderLayer)
    attention_types = (MiniMaxM2Attention, MiniMaxAttention)
    moe_block_types = (MiniMaxM2SparseMoeBlock, MiniMaxSparseMoeBlock)
    norm_types = (MiniMaxM2RMSNorm, MiniMaxRMSNorm)
    mlp_types = (MiniMaxM2MLP, MiniMaxMLP)

    assert isinstance(
        model, causal_lm_types
    ), f"Expected MiniMaxM2ForCausalLM/MiniMaxForCausalLM, received {type(model).__name__}"

    decoder = getattr(model, "model", None)
    assert decoder is not None, "Model is missing the `model` attribute with decoder layers."

    for layer_idx, layer in enumerate(decoder.layers):
        assert isinstance(
            layer, decoder_layer_types
        ), f"Layer {layer_idx}: expected MiniMax(M2)DecoderLayer, got {type(layer).__name__}"
        assert isinstance(
            layer.self_attn, attention_types
        ), f"Layer {layer_idx}: unexpected self_attn type {type(layer.self_attn).__name__}"
        assert isinstance(
            layer.block_sparse_moe, moe_block_types
        ), f"Layer {layer_idx}: unexpected MoE block type {type(layer.block_sparse_moe).__name__}"
        assert isinstance(
            layer.input_layernorm, norm_types
        ), f"Layer {layer_idx}: unexpected input_layernorm type {type(layer.input_layernorm).__name__}"
        assert isinstance(
            layer.post_attention_layernorm, norm_types
        ), f"Layer {layer_idx}: unexpected post_attention_layernorm type {type(layer.post_attention_layernorm).__name__}"

        moe_block = layer.block_sparse_moe
        assert isinstance(
            moe_block.experts, nn.ModuleList
        ), f"Layer {layer_idx}: expected experts to be a ModuleList, got {type(moe_block.experts).__name__}"
        for expert_idx, expert in enumerate(moe_block.experts):
            assert isinstance(
                expert, mlp_types
            ), f"Layer {layer_idx} expert {expert_idx}: expected MiniMax(M2)MLP, got {type(expert).__name__}"


def main() -> None:
    args = parse_args()
    model_path = Path(args.model_path).expanduser().resolve()

    print(f"Loading tokenizer from {model_path}...")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    print(f"Loading model from {model_path}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        dtype="bfloat16",
        device_map="auto",
        trust_remote_code=True,
    )

    # Uncomment to enforce module type checks.
    # print("Validating module types...")
    # assert_module_types(model)

    prompt = build_prompt(tokenizer, args.question)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    print("Running generation (streaming)...\n")
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=False)
    eos_ids = model.generation_config.eos_token_id
    if eos_ids is None:
        eos_ids = []
    elif isinstance(eos_ids, int):
        eos_ids = [eos_ids]
    think_end_id = tokenizer.convert_tokens_to_ids("</think>")
    if think_end_id is not None and think_end_id not in eos_ids:
        eos_ids = eos_ids + [think_end_id]

    generation_kwargs = dict(
        **inputs,
        max_new_tokens=args.max_new_tokens,
        streamer=streamer,
        eos_token_id=eos_ids if eos_ids else None,
    )

    generation_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    generation_thread.start()

    completion = []
    first_chunk = True
    seen_end_reasoning = False
    for text in streamer:
        if first_chunk:
            print("<think>", end="", flush=True)
            completion.append("<think>")
            first_chunk = False
        print(text, end="", flush=True)
        completion.append(text)
        if "</think>" in text:
            seen_end_reasoning = True

    generation_thread.join()
    print("\n\n=== Completed Response ===")
    final_text = "".join(completion).strip()
    print(final_text or "<empty response>")
    if not seen_end_reasoning:
        print("\n[warning] No </think> token detected in streamed output.", flush=True)


if __name__ == "__main__":
    main()
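The module-type validation is left commented out in `main`; if you want to run it as a separate check, here is a minimal sketch. It assumes the script is importable as `test_minimax_m2_hf` and that a checkpoint exists at the placeholder path; both are assumptions, not part of this commit.

# Sketch: standalone module-type validation against a freshly loaded checkpoint.
# Both the import path for the script and the checkpoint path are assumptions.
from transformers import AutoModelForCausalLM

from test_minimax_m2_hf import assert_module_types

model = AutoModelForCausalLM.from_pretrained(
    "/path/to/MiniMax-M2-bf16",  # placeholder checkpoint path
    dtype="bfloat16",
    device_map="auto",
    trust_remote_code=True,
)
assert_module_types(model)  # raises AssertionError naming the offending layer/expert on mismatch
print("Module types match the MiniMax-M2 classes.")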