Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	created micro hf space
Browse files- .gitignore +4 -0
- Instructions.md +20 -0
- README.md +2 -6
- app.py +166 -0
- configs/micro_llama_1b.yml +14 -0
- models/micro_llama.py +588 -0
- models/micro_moe_llama.py +725 -0
- models/micro_olmo.py +528 -0
- models/modules.py +42 -0
- requirements.txt +3 -0
- router_backend.py +223 -0
    	
        .gitignore
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            .env
         | 
| 2 | 
            +
            *.pyc
         | 
| 3 | 
            +
            .DS_Store
         | 
| 4 | 
            +
            __pycache__/
         | 
    	
        Instructions.md
    ADDED
    
    | @@ -0,0 +1,20 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # MiCRo Expert Routing Visualizer (Gradio)
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            This demo visualizes how modular language models allocate computation across specialized experts—Language, Logic, Social, and World—when processing a given prompt.
         | 
| 4 | 
            +
            Each expert corresponds to a cognitive domain inspired by human brain networks. Enter a prompt to see how tokens are dynamically routed across modules, revealing the model’s internal reasoning structure.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            ## How it works
         | 
| 7 | 
            +
            - Choose a model (dropdown) or type a custom model id.
         | 
| 8 | 
            +
            - Enter a *User prompt*. Optionally add an *Assistant prompt*; if provided, the app concatenates them as:
         | 
| 9 | 
            +
             | 
| 10 | 
            +
              ```
         | 
| 11 | 
            +
              User: <user text>
         | 
| 12 | 
            +
              Assistant: <assistant text>
         | 
| 13 | 
            +
              ```
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            - When the prompt fails, the demo falls back to "mock data", which generates deterministic, pseudo-random percentages from the prompt.  
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            ### Backend contract
         | 
| 18 | 
            +
            `get_expert_routing(model_id: str, prompt: str)` must return 4 values (percentages) for the experts in this fixed order:
         | 
| 19 | 
            +
            `["Language", "Logic", "Social", "World"]`
         | 
| 20 | 
            +
            or a dict with those exact keys.
         | 
    	
        README.md
    CHANGED
    
    | @@ -1,6 +1,5 @@ | |
| 1 | 
            -
            ---
         | 
| 2 | 
             
            title: MiCRo Routing Visualizer
         | 
| 3 | 
            -
            emoji:  | 
| 4 | 
             
            colorFrom: purple
         | 
| 5 | 
             
            colorTo: red
         | 
| 6 | 
             
            sdk: gradio
         | 
| @@ -8,7 +7,4 @@ sdk_version: 5.49.1 | |
| 8 | 
             
            app_file: app.py
         | 
| 9 | 
             
            pinned: false
         | 
| 10 | 
             
            license: mit
         | 
| 11 | 
            -
            short_description: Mixture of Cognitive Reasoners Computation Allocation
         | 
| 12 | 
            -
            ---
         | 
| 13 | 
            -
             | 
| 14 | 
            -
            Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
         | 
|  | |
|  | |
| 1 | 
             
            title: MiCRo Routing Visualizer
         | 
| 2 | 
            +
            emoji: 🧠
         | 
| 3 | 
             
            colorFrom: purple
         | 
| 4 | 
             
            colorTo: red
         | 
| 5 | 
             
            sdk: gradio
         | 
|  | |
| 7 | 
             
            app_file: app.py
         | 
| 8 | 
             
            pinned: false
         | 
| 9 | 
             
            license: mit
         | 
| 10 | 
            +
            short_description: Mixture of Cognitive Reasoners Computation Allocation
         | 
|  | |
|  | |
|  | 
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,166 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # app.py
         | 
| 2 | 
            +
            """
         | 
| 3 | 
            +
            Hugging Face Space: MoE Expert Routing Visualizer (Gradio)
         | 
| 4 | 
            +
            ----------------------------------------------------------
         | 
| 5 | 
            +
            This Space lets a user:
         | 
| 6 | 
            +
            - Choose a model (from a dropdown or a free-text box)
         | 
| 7 | 
            +
            - Enter a user prompt, and optionally an assistant prompt
         | 
| 8 | 
            +
            - Call a backend function that returns 4 routing percentages (Language, Logic, Social, World)
         | 
| 9 | 
            +
            - See a bar plot + table of the percentages
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            🧩 Plug your real routing function in router_backend.py -> get_expert_routing().
         | 
| 12 | 
            +
            By default, a deterministic "mock mode" produces stable pseudo-random percentages from the prompt.
         | 
| 13 | 
            +
            """
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            import hashlib
         | 
| 16 | 
            +
            from typing import Dict, List, Tuple, Union
         | 
| 17 | 
            +
            import gradio as gr
         | 
| 18 | 
            +
            import plotly
         | 
| 19 | 
            +
            import plotly.express as px
         | 
| 20 | 
            +
            import pandas as pd
         | 
| 21 | 
            +
            from router_backend import get_expert_routing 
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            # ---- Expected backend adapter ------------------------------------------------
         | 
| 24 | 
            +
            # Implement your real function in router_backend.py with the following signature:
         | 
| 25 | 
            +
            #   def get_expert_routing(model_id: str, prompt: str) -> Union[List[float], Dict[str, float], Tuple[float, float, float, float]]
         | 
| 26 | 
            +
            # It MUST return 4 values that sum to ~100 (percentages) in the fixed order:
         | 
| 27 | 
            +
            #   ["Language", "Logic", "Social", "World"]
         | 
| 28 | 
            +
            # or a mapping with those keys.
         | 
| 29 | 
            +
            # try:
         | 
| 30 | 
            +
            #     from router_backend import get_expert_routing  # your real backend
         | 
| 31 | 
            +
            #     BACKEND_AVAILABLE = True
         | 
| 32 | 
            +
            # except Exception as e:  # keep error for display if needed
         | 
| 33 | 
            +
            #     BACKEND_AVAILABLE = False
         | 
| 34 | 
            +
            #     _backend_import_error = e
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            EXPERTS = ["Language", "Logic", "Social", "World"]
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            DEFAULT_MODELS = [
         | 
| 39 | 
            +
                "micro-llama-1b", 
         | 
| 40 | 
            +
                "micro-llama-3b",
         | 
| 41 | 
            +
                "micro-llama-1b-dpo",
         | 
| 42 | 
            +
                "micro-moe-llama-1b",
         | 
| 43 | 
            +
                "micro-smollm2-135m",
         | 
| 44 | 
            +
                "micro-smollm2-360m",
         | 
| 45 | 
            +
                "micro-moe-smollm2-135m",
         | 
| 46 | 
            +
                "micro-moe-smollm2-360m",
         | 
| 47 | 
            +
            ]
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            def _mock_routing(model_id: str, prompt: str, seed: int = 0) -> List[float]:
         | 
| 50 | 
            +
                """
         | 
| 51 | 
            +
                Deterministic mock routing percentages based on model_id + prompt + seed.
         | 
| 52 | 
            +
                Returns a list of 4 percentages summing to 100.0
         | 
| 53 | 
            +
                """
         | 
| 54 | 
            +
                h = hashlib.sha256(f"{model_id}||{prompt}||{seed}".encode()).digest()
         | 
| 55 | 
            +
                # split into 4 positive numbers
         | 
| 56 | 
            +
                vals = [int.from_bytes(h[i*8:(i+1)*8], "little") % 10_000 + 1 for i in range(4)]
         | 
| 57 | 
            +
                s = sum(vals)
         | 
| 58 | 
            +
                return [100.0 * v / s for v in vals]
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            def _normalize_output(r: Union[List[float], Tuple[float, float, float, float], Dict[str, float]]) -> List[float]:
         | 
| 61 | 
            +
                """
         | 
| 62 | 
            +
                Normalize different return types into a 4-length list ordered as EXPERTS.
         | 
| 63 | 
            +
                """
         | 
| 64 | 
            +
                if isinstance(r, dict):
         | 
| 65 | 
            +
                    vals = [float(r.get(k, 0.0)) for k in EXPERTS]
         | 
| 66 | 
            +
                else:
         | 
| 67 | 
            +
                    vals = [float(x) for x in list(r)]
         | 
| 68 | 
            +
                    if len(vals) != 4:
         | 
| 69 | 
            +
                        raise ValueError(f"Expected 4 values, got {len(vals)}.")
         | 
| 70 | 
            +
                # renormalize to 100 if needed
         | 
| 71 | 
            +
                s = sum(vals)
         | 
| 72 | 
            +
                if s <= 0:
         | 
| 73 | 
            +
                    raise ValueError("Sum of routing percentages is non-positive.")
         | 
| 74 | 
            +
                vals = [100.0 * v / s for v in vals]
         | 
| 75 | 
            +
                return vals
         | 
| 76 | 
            +
             | 
| 77 | 
            +
            def _compose_prompt(user_prompt: str, assistant_prompt: str) -> str:
         | 
| 78 | 
            +
                user_prompt = (user_prompt or "").strip()
         | 
| 79 | 
            +
                assistant_prompt = (assistant_prompt or "").strip()
         | 
| 80 | 
            +
                if assistant_prompt:
         | 
| 81 | 
            +
                    return [{"role": "user", "content": user_prompt}, {"role": "assistant", "content": assistant_prompt}]
         | 
| 82 | 
            +
                return user_prompt
         | 
| 83 | 
            +
             | 
| 84 | 
            +
            def route_and_plot(model_choice: str, hf_token: str, user_prompt: str, assistant_prompt: str) -> Tuple[pd.DataFrame, "plotly.graph_objs._figure.Figure", str]:
         | 
| 85 | 
            +
                """
         | 
| 86 | 
            +
                Main pipeline:
         | 
| 87 | 
            +
                - Compose prompt (user + optional assistant)
         | 
| 88 | 
            +
                - Call backend (real or mock)
         | 
| 89 | 
            +
                - Return a table + bar plot + status message
         | 
| 90 | 
            +
                """
         | 
| 91 | 
            +
                model_id = model_choice.strip()
         | 
| 92 | 
            +
                if not model_id:
         | 
| 93 | 
            +
                    raise gr.Error("Please select a model or enter a custom model id.")
         | 
| 94 | 
            +
                prompt = _compose_prompt(user_prompt, assistant_prompt)
         | 
| 95 | 
            +
                if not prompt:
         | 
| 96 | 
            +
                    raise gr.Error("Please enter a prompt.")
         | 
| 97 | 
            +
                
         | 
| 98 | 
            +
                seed = 42
         | 
| 99 | 
            +
                use_mock = False
         | 
| 100 | 
            +
                if use_mock:
         | 
| 101 | 
            +
                    msg = "Using mock data."
         | 
| 102 | 
            +
                    vals = _mock_routing(model_id, prompt, seed=seed)
         | 
| 103 | 
            +
                    generation = None
         | 
| 104 | 
            +
                else:
         | 
| 105 | 
            +
                    try:
         | 
| 106 | 
            +
                        raw, generation = get_expert_routing(model_id, hf_token, prompt)  # <-- your real function
         | 
| 107 | 
            +
                        vals = _normalize_output(raw)
         | 
| 108 | 
            +
                        msg = "Routed with real backend."
         | 
| 109 | 
            +
                    except Exception as e:
         | 
| 110 | 
            +
                        # fallback to mock on error, but surface message
         | 
| 111 | 
            +
                        msg = f"Backend error: {e}\nFalling back to mock data."
         | 
| 112 | 
            +
                        vals = _mock_routing(model_id, prompt, seed=seed)
         | 
| 113 | 
            +
                        generation = None
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                df = pd.DataFrame({"Expert": EXPERTS, "Percent": vals})
         | 
| 116 | 
            +
                fig = px.bar(df, x="Expert", y="Percent", title="Token Routing by Expert (%)", text="Percent")
         | 
| 117 | 
            +
                fig.update_traces(texttemplate="%{text:.2f}%", textposition="outside")
         | 
| 118 | 
            +
                fig.update_layout(yaxis_range=[0, max(100, max(vals) * 1.25)], bargap=0.35)
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                status = f"Model: {model_id}<br>{msg}"
         | 
| 121 | 
            +
                if generation is None:
         | 
| 122 | 
            +
                    generation = assistant_prompt
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                return generation, df, fig, status
         | 
| 125 | 
            +
             | 
| 126 | 
            +
            with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
         | 
| 127 | 
            +
                gr.Markdown(
         | 
| 128 | 
            +
                    """
         | 
| 129 | 
            +
                    # 🧠 Mixture of Cognitive Reasoner (MiCRo) Expert Routing Visualizer
         | 
| 130 | 
            +
                    ## Enter a prompt (and optionally an assistant reply), pick a model, and visualize how tokens were routed across experts.
         | 
| 131 | 
            +
                    Paper: [Mixture of Cognitive Reasoners: Modular Reasoning with Brain-Like Specialization](https://arxiv.org/abs/2506.13331)
         | 
| 132 | 
            +
                    ----
         | 
| 133 | 
            +
                    This demo visualizes how modular language models allocate computation across specialized experts—Language, Logic, Social, and World—when processing a given prompt.
         | 
| 134 | 
            +
                    Each expert corresponds to a cognitive domain inspired by human brain networks. Enter a prompt to see how tokens are dynamically routed across modules, revealing the model's internal reasoning structure.
         | 
| 135 | 
            +
                    """.strip()
         | 
| 136 | 
            +
                )
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                with gr.Row():
         | 
| 139 | 
            +
                    model_choice = gr.Dropdown(choices=DEFAULT_MODELS, label="Select a model", value=DEFAULT_MODELS[0])
         | 
| 140 | 
            +
                    hf_token = gr.Textbox(label="Huggingface token for authentication", placeholder="hf token", lines=1)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                with gr.Row():
         | 
| 143 | 
            +
                    user_prompt = gr.Textbox(lines=6, label="User prompt", placeholder="Type the user message here...")
         | 
| 144 | 
            +
                    assistant_prompt = gr.Textbox(lines=6, label="Assistant prompt (optional)", placeholder="Type the assistant message here (optional)...")
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                # with gr.Row():
         | 
| 147 | 
            +
                #     use_mock = gr.Checkbox(value=True, label="Use mock data (uncheck to call your backend)")
         | 
| 148 | 
            +
                #     seed = gr.Slider(value=0, minimum=0, maximum=10_000, step=1, label="Mock seed")
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                run = gr.Button("Run Routing", variant="primary")
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                generation_output = gr.Textbox(lines=4, label="Generated continuation", placeholder="Generated text will appear here...", interactive=False)
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                with gr.Row():
         | 
| 155 | 
            +
                    table = gr.Dataframe(label="Routing Percentages", interactive=False)
         | 
| 156 | 
            +
                plot = gr.Plot(label="Bar Plot")
         | 
| 157 | 
            +
                status = gr.Markdown("")
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                run.click(
         | 
| 160 | 
            +
                    route_and_plot,
         | 
| 161 | 
            +
                    inputs=[model_choice, hf_token, user_prompt, assistant_prompt],
         | 
| 162 | 
            +
                    outputs=[generation_output, table, plot, status],
         | 
| 163 | 
            +
                )
         | 
| 164 | 
            +
             | 
| 165 | 
            +
            if __name__ == "__main__":
         | 
| 166 | 
            +
                demo.launch()
         | 
    	
        configs/micro_llama_1b.yml
    ADDED
    
    | @@ -0,0 +1,14 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            run-title: micro-llama-1b
         | 
| 2 | 
            +
            model: micro-llama-1b
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            base-model: meta-llama/Llama-3.2-1B
         | 
| 5 | 
            +
            tokenizer: meta-llama/Llama-3.2-1B-Instruct
         | 
| 6 | 
            +
            num-experts: 4
         | 
| 7 | 
            +
            top-k-experts: 1
         | 
| 8 | 
            +
            jitter-noise: 0
         | 
| 9 | 
            +
            use-router: True
         | 
| 10 | 
            +
            mask-input: True
         | 
| 11 | 
            +
            max-length: 8192
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            trainable:
         | 
| 14 | 
            +
              - model
         | 
    	
        models/micro_llama.py
    ADDED
    
    | @@ -0,0 +1,588 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Optional, Tuple, Union, List, Callable
         | 
| 2 | 
            +
            import logging 
         | 
| 3 | 
            +
            import yaml
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            import torch.nn as nn
         | 
| 7 | 
            +
            import torch.nn.functional as F
         | 
| 8 | 
            +
            import torch.distributed as dist
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            # from transformers.utils import TransformerKwargs
         | 
| 11 | 
            +
            from transformers import LlamaConfig, AutoConfig, AutoTokenizer, AutoModelForCausalLM
         | 
| 12 | 
            +
            from transformers.modeling_attn_mask_utils import AttentionMaskConverter
         | 
| 13 | 
            +
            from transformers.models.llama.modeling_llama import (
         | 
| 14 | 
            +
                LlamaRotaryEmbedding, 
         | 
| 15 | 
            +
                LlamaRMSNorm, 
         | 
| 16 | 
            +
                LlamaMLP,
         | 
| 17 | 
            +
                LlamaDecoderLayer,
         | 
| 18 | 
            +
                LlamaPreTrainedModel, 
         | 
| 19 | 
            +
                GenerationMixin,
         | 
| 20 | 
            +
                apply_rotary_pos_emb,
         | 
| 21 | 
            +
                eager_attention_forward,
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            )
         | 
| 24 | 
            +
            from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
         | 
| 25 | 
            +
            from transformers.cache_utils import Cache, StaticCache, DynamicCache
         | 
| 26 | 
            +
            from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
         | 
| 27 | 
            +
            from transformers.processing_utils import Unpack
         | 
| 28 | 
            +
            from transformers.utils import is_torchdynamo_compiling
         | 
| 29 | 
            +
            from models.modules import CausalLMOutputWithPast
         | 
| 30 | 
            +
            from transformers.modeling_layers import GradientCheckpointingLayer
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            def _prepare_4d_causal_attention_mask_with_cache_position(
         | 
| 35 | 
            +
                attention_mask: torch.Tensor,
         | 
| 36 | 
            +
                sequence_length: int,
         | 
| 37 | 
            +
                target_length: int,
         | 
| 38 | 
            +
                dtype: torch.dtype,
         | 
| 39 | 
            +
                device: torch.device,
         | 
| 40 | 
            +
                min_dtype: float,
         | 
| 41 | 
            +
                cache_position: torch.Tensor,
         | 
| 42 | 
            +
                batch_size: int,
         | 
| 43 | 
            +
            ):
         | 
| 44 | 
            +
                """
         | 
| 45 | 
            +
                Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
         | 
| 46 | 
            +
                `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                Args:
         | 
| 49 | 
            +
                    attention_mask (`torch.Tensor`):
         | 
| 50 | 
            +
                        A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
         | 
| 51 | 
            +
                    sequence_length (`int`):
         | 
| 52 | 
            +
                        The sequence length being processed.
         | 
| 53 | 
            +
                    target_length (`int`):
         | 
| 54 | 
            +
                        The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
         | 
| 55 | 
            +
                    dtype (`torch.dtype`):
         | 
| 56 | 
            +
                        The dtype to use for the 4D attention mask.
         | 
| 57 | 
            +
                    device (`torch.device`):
         | 
| 58 | 
            +
                        The device to plcae the 4D attention mask on.
         | 
| 59 | 
            +
                    min_dtype (`float`):
         | 
| 60 | 
            +
                        The minimum value representable with the dtype `dtype`.
         | 
| 61 | 
            +
                    cache_position (`torch.Tensor`):
         | 
| 62 | 
            +
                        Indices depicting the position of the input sequence tokens in the sequence.
         | 
| 63 | 
            +
                    batch_size (`torch.Tensor`):
         | 
| 64 | 
            +
                        Batch size.
         | 
| 65 | 
            +
                """
         | 
| 66 | 
            +
                if attention_mask is not None and attention_mask.dim() == 4:
         | 
| 67 | 
            +
                    # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
         | 
| 68 | 
            +
                    causal_mask = attention_mask
         | 
| 69 | 
            +
                else:
         | 
| 70 | 
            +
                    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
         | 
| 71 | 
            +
                    if sequence_length != 1:
         | 
| 72 | 
            +
                        causal_mask = torch.triu(causal_mask, diagonal=1)
         | 
| 73 | 
            +
                    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
         | 
| 74 | 
            +
                    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         | 
| 75 | 
            +
                    if attention_mask is not None:
         | 
| 76 | 
            +
                        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
         | 
| 77 | 
            +
                        mask_length = attention_mask.shape[-1]
         | 
| 78 | 
            +
                        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
         | 
| 79 | 
            +
                        padding_mask = padding_mask == 0
         | 
| 80 | 
            +
                        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
         | 
| 81 | 
            +
                            padding_mask, min_dtype
         | 
| 82 | 
            +
                        )
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                return causal_mask
         | 
| 85 | 
            +
             | 
| 86 | 
            +
            class MiCRoLlamaConfig(LlamaConfig):
         | 
| 87 | 
            +
                model_type = "micro_llama"
         | 
| 88 | 
            +
                def __init__(self, *args, **kwargs):
         | 
| 89 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 90 | 
            +
                    self.num_experts = kwargs.get("num_experts", 4)
         | 
| 91 | 
            +
                    self.use_router = kwargs.get("use_router", True)
         | 
| 92 | 
            +
                    self.num_experts_per_tok = kwargs.get("num_experts_per_tok", 2)
         | 
| 93 | 
            +
                    self.jitter_noise = kwargs.get("jitter_noise", 0.0)
         | 
| 94 | 
            +
                    self.loss_method = kwargs.get("loss_method", "all")
         | 
| 95 | 
            +
                    self.config_path = kwargs.get("config_path", None)
         | 
| 96 | 
            +
                    
         | 
| 97 | 
            +
            class MiCRoLlamaDecoderLayer(nn.Module):
         | 
| 98 | 
            +
                def __init__(self, config: MiCRoLlamaConfig, layer_idx: int):
         | 
| 99 | 
            +
                    super().__init__()
         | 
| 100 | 
            +
                    self.hidden_dim = config.hidden_size
         | 
| 101 | 
            +
                    self.ffn_dim = config.intermediate_size
         | 
| 102 | 
            +
                    self.num_experts = config.num_experts
         | 
| 103 | 
            +
                    self.top_k = config.num_experts_per_tok
         | 
| 104 | 
            +
                    self.use_router = config.use_router
         | 
| 105 | 
            +
                    self.ablate = config.ablate
         | 
| 106 | 
            +
                    self.num_key_value_heads = config.num_key_value_heads
         | 
| 107 | 
            +
                    self.head_dim = self.hidden_dim // config.num_attention_heads
         | 
| 108 | 
            +
                    self.gradient_checkpointing = config.gradient_checkpointing
         | 
| 109 | 
            +
                    if isinstance(self.ablate, str):
         | 
| 110 | 
            +
                        self.ablate = [self.ablate]
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    self.gate = nn.Sequential(
         | 
| 113 | 
            +
                        nn.Linear(self.hidden_dim, self.hidden_dim, bias=False),
         | 
| 114 | 
            +
                        nn.Linear(self.hidden_dim, self.num_experts, bias=False)
         | 
| 115 | 
            +
                    )
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                    self.num_layers = config.backbone_num_layers
         | 
| 118 | 
            +
                    self.layer_idx = layer_idx
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    self.experts = nn.ModuleList([LlamaDecoderLayer(config, layer_idx * self.num_experts + expert_idx) for expert_idx in range(self.num_experts)])
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    self.jitter_noise = config.jitter_noise
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                def forward(
         | 
| 125 | 
            +
                    self,
         | 
| 126 | 
            +
                    hidden_states: torch.Tensor,
         | 
| 127 | 
            +
                    routing_weights: Optional[torch.Tensor] = None,
         | 
| 128 | 
            +
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 129 | 
            +
                    position_ids: Optional[torch.LongTensor] = None,
         | 
| 130 | 
            +
                    ablate: Optional[List[str]] = None,
         | 
| 131 | 
            +
                    past_key_value: Optional[Cache] = None,
         | 
| 132 | 
            +
                    output_attentions: Optional[bool] = False,
         | 
| 133 | 
            +
                    use_cache: Optional[bool] = False,
         | 
| 134 | 
            +
                    cache_position: Optional[torch.LongTensor] = None,
         | 
| 135 | 
            +
                    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
         | 
| 136 | 
            +
                    **kwargs,
         | 
| 137 | 
            +
                ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         | 
| 138 | 
            +
                
         | 
| 139 | 
            +
                    batch_size, sequence_length, hidden_dim = hidden_states.shape
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                    if ablate is not None:
         | 
| 142 | 
            +
                        self.ablate = ablate
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    if self.training and self.jitter_noise > 0:
         | 
| 145 | 
            +
                        hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
         | 
| 146 | 
            +
                    
         | 
| 147 | 
            +
                    if self.use_router:
         | 
| 148 | 
            +
                        router_logits = self.gate(hidden_states)
         | 
| 149 | 
            +
                        if "logic" in self.ablate:
         | 
| 150 | 
            +
                            router_logits[..., 0] = -torch.inf
         | 
| 151 | 
            +
                        if "social" in self.ablate:
         | 
| 152 | 
            +
                            router_logits[..., 1] = -torch.inf
         | 
| 153 | 
            +
                        if "world" in self.ablate:
         | 
| 154 | 
            +
                            router_logits[..., 2] = -torch.inf
         | 
| 155 | 
            +
                        if "language" in self.ablate:
         | 
| 156 | 
            +
                            router_logits[..., 3] = -torch.inf
         | 
| 157 | 
            +
                        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
         | 
| 158 | 
            +
                    else:
         | 
| 159 | 
            +
                        if len(routing_weights.shape) == 2:
         | 
| 160 | 
            +
                            routing_weights = routing_weights.unsqueeze(1).tile((1,sequence_length,1)).float()
         | 
| 161 | 
            +
                        else:
         | 
| 162 | 
            +
                            routing_weights = routing_weights.float()
         | 
| 163 | 
            +
                        router_logits = routing_weights
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
         | 
| 166 | 
            +
                    routing_weights /= (routing_weights.sum(dim=-1, keepdim=True) + 1e-9)
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    # we cast back to the input dtype
         | 
| 169 | 
            +
                    routing_weights = routing_weights.to(hidden_states.dtype)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                    # We'll accumulate outputs here
         | 
| 172 | 
            +
                    final_hidden_states = torch.zeros_like(hidden_states)
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    # Flatten final_hidden_states to [batch_size * seq_len, hidden_dim]
         | 
| 175 | 
            +
                    # so we can do a 2D "index_add_" at the end of each loop.
         | 
| 176 | 
            +
                    final_hidden_states_2d = final_hidden_states.view(-1, hidden_dim)
         | 
| 177 | 
            +
                
         | 
| 178 | 
            +
                    # One hot encode the selected experts to create an expert mask
         | 
| 179 | 
            +
                    # this will be used to easily index which expert is going to be sollicitated
         | 
| 180 | 
            +
                    expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
         | 
| 181 | 
            +
                    #^ [batch_size, seq_len, top_k, num_experts]
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                    # Loop over all available experts in the model and perform the computation on each expert
         | 
| 184 | 
            +
                    for expert_idx in range(self.num_experts):
         | 
| 185 | 
            +
                        expert_layer: LlamaDecoderLayer = self.experts[expert_idx]
         | 
| 186 | 
            +
                        batch_indices, seq_indices, top_k_indices = torch.where(expert_mask[..., expert_idx])
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                        if not self.training and sequence_length == 1 and batch_indices.numel() == 0:
         | 
| 189 | 
            +
                            if past_key_value is not None:
         | 
| 190 | 
            +
                                
         | 
| 191 | 
            +
                                hidden_state_ln_norm = expert_layer.input_layernorm(hidden_states)
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                                input_shape = hidden_state_ln_norm.shape[:-1]
         | 
| 194 | 
            +
                                hidden_shape = (*input_shape, -1, self.head_dim)
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                                # query_states = expert_layer.self_attn.q_proj(hidden_state_ln_norm).view(hidden_shape).transpose(1, 2)
         | 
| 197 | 
            +
                                key_states = expert_layer.self_attn.k_proj(hidden_state_ln_norm).view(hidden_shape).transpose(1, 2)
         | 
| 198 | 
            +
                                value_states = expert_layer.self_attn.v_proj(hidden_state_ln_norm).view(hidden_shape).transpose(1, 2)
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                                cos, sin = position_embeddings
         | 
| 201 | 
            +
                                _, key_states = apply_rotary_pos_emb(key_states, key_states, cos, sin)
         | 
| 202 | 
            +
                                # sin and cos are specific to RoPE models; cache_position needed for the static cache
         | 
| 203 | 
            +
                                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
         | 
| 204 | 
            +
                                past_key_value.update(key_states, value_states, self.layer_idx * self.num_experts + expert_idx, cache_kwargs)
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                            continue
         | 
| 207 | 
            +
                        
         | 
| 208 | 
            +
                        if self.gradient_checkpointing and self.training:
         | 
| 209 | 
            +
                            current_hidden_states = self._gradient_checkpointing_func(
         | 
| 210 | 
            +
                                expert_layer.__call__,
         | 
| 211 | 
            +
                                hidden_states,
         | 
| 212 | 
            +
                                attention_mask,
         | 
| 213 | 
            +
                                position_ids,
         | 
| 214 | 
            +
                                past_key_value,
         | 
| 215 | 
            +
                                output_attentions,
         | 
| 216 | 
            +
                                use_cache,
         | 
| 217 | 
            +
                                cache_position,
         | 
| 218 | 
            +
                                position_embeddings,
         | 
| 219 | 
            +
                            )[0]
         | 
| 220 | 
            +
                        else:
         | 
| 221 | 
            +
                            current_hidden_states = expert_layer(
         | 
| 222 | 
            +
                                hidden_states=hidden_states,
         | 
| 223 | 
            +
                                attention_mask=attention_mask,
         | 
| 224 | 
            +
                                position_ids=position_ids,
         | 
| 225 | 
            +
                                past_key_value=past_key_value,
         | 
| 226 | 
            +
                                output_attentions=output_attentions,
         | 
| 227 | 
            +
                                use_cache=use_cache,
         | 
| 228 | 
            +
                                cache_position=cache_position,
         | 
| 229 | 
            +
                                position_embeddings=position_embeddings,
         | 
| 230 | 
            +
                                **kwargs,
         | 
| 231 | 
            +
                            )[0]
         | 
| 232 | 
            +
                        
         | 
| 233 | 
            +
                        
         | 
| 234 | 
            +
                        flat_idx = batch_indices * sequence_length + seq_indices
         | 
| 235 | 
            +
                        expert_weights = routing_weights[batch_indices, seq_indices, top_k_indices].unsqueeze(-1)
         | 
| 236 | 
            +
                        current_hidden_states = current_hidden_states[batch_indices, seq_indices] * expert_weights
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                        final_hidden_states_2d.index_add_(0, flat_idx, current_hidden_states.to(hidden_states.dtype))
         | 
| 239 | 
            +
                   
         | 
| 240 | 
            +
                    final_hidden_states = final_hidden_states_2d.view(batch_size, sequence_length, hidden_dim)
         | 
| 241 | 
            +
                    return final_hidden_states, router_logits
         | 
| 242 | 
            +
                
         | 
| 243 | 
            +
            class MiCRoLlama(LlamaPreTrainedModel, GenerationMixin):
         | 
| 244 | 
            +
                config_class = MiCRoLlamaConfig
         | 
| 245 | 
            +
                def __init__(self, config: MiCRoLlamaConfig):
         | 
| 246 | 
            +
                    with open(config.config_path, 'r', encoding="utf-8") as file:
         | 
| 247 | 
            +
                        run_config = yaml.load(file.read(), Loader=yaml.FullLoader)
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                    self.config: MiCRoLlamaConfig = config
         | 
| 250 | 
            +
                    self.config.torch_dtype = torch.bfloat16
         | 
| 251 | 
            +
                    self.config.use_bfloat16 = True
         | 
| 252 | 
            +
                    self.config._attn_implementation = "flash_attention_2" # {sdpa, flash_attention_2, eager}
         | 
| 253 | 
            +
                    self.config.backbone_num_layers = self.config.num_hidden_layers
         | 
| 254 | 
            +
                    self.config.num_hidden_layers = self.config.num_hidden_layers * run_config["num-experts"]
         | 
| 255 | 
            +
                    self.config.loss_type = "ForCausalLMLoss"
         | 
| 256 | 
            +
                    
         | 
| 257 | 
            +
                    super(MiCRoLlama, self).__init__(self.config)
         | 
| 258 | 
            +
                    self.build_model(run_config)
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                def build_model(self, run_config):
         | 
| 261 | 
            +
                
         | 
| 262 | 
            +
                    self.gradient_checkpointing = False
         | 
| 263 | 
            +
                    self.config.num_experts = run_config["num-experts"]
         | 
| 264 | 
            +
                    self.config.use_router = run_config["use-router"]
         | 
| 265 | 
            +
                    self.config.num_experts_per_tok = run_config["top-k-experts"]
         | 
| 266 | 
            +
                    print(f">> Number of Experts per Token: {self.config.num_experts_per_tok}")
         | 
| 267 | 
            +
                    self.config.jitter_noise = run_config["jitter-noise"]
         | 
| 268 | 
            +
                    self.config.loss_method = run_config.get("loss", "all")
         | 
| 269 | 
            +
                    self.config.gradient_checkpointing = run_config.get("gradient-checkpointing", False)
         | 
| 270 | 
            +
                    print(f">> Gradient Checkpointing: {self.config.gradient_checkpointing}")
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                    self.run_config = run_config
         | 
| 273 | 
            +
                    self.padding_idx = 2 if "smollm2" in run_config["model"] else 128004
         | 
| 274 | 
            +
                    
         | 
| 275 | 
            +
                    # MiCRoLlama model
         | 
| 276 | 
            +
                    self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size, self.padding_idx)
         | 
| 277 | 
            +
                    self.layers = nn.ModuleList([MiCRoLlamaDecoderLayer(self.config, layer_idx) for layer_idx in range(self.config.backbone_num_layers)])
         | 
| 278 | 
            +
                    self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
         | 
| 279 | 
            +
                    self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
         | 
| 280 | 
            +
                    self.final_norm = LlamaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                    if "model" not in run_config["trainable"]:
         | 
| 283 | 
            +
                        print(">> Freezing Model Except Routing Gates")
         | 
| 284 | 
            +
                        for param in self.parameters():
         | 
| 285 | 
            +
                            param.requires_grad = False
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                        for layer in self.layers:
         | 
| 288 | 
            +
                            layer: MiCRoLlamaDecoderLayer
         | 
| 289 | 
            +
                            for param in layer.gate.parameters():
         | 
| 290 | 
            +
                                param.requires_grad = True
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                    if "experts-router" not in run_config["trainable"]:
         | 
| 293 | 
            +
                        print(">> Freezing Routing Gates")
         | 
| 294 | 
            +
                        for layer in self.layers:
         | 
| 295 | 
            +
                            layer: MiCRoLlamaDecoderLayer
         | 
| 296 | 
            +
                            for param in layer.gate.parameters():
         | 
| 297 | 
            +
                                param.requires_grad = False
         | 
| 298 | 
            +
             | 
| 299 | 
            +
             | 
| 300 | 
            +
             | 
| 301 | 
            +
                def forward(self,
         | 
| 302 | 
            +
                    input_ids: torch.LongTensor = None,
         | 
| 303 | 
            +
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 304 | 
            +
                    position_ids: Optional[torch.LongTensor] = None,
         | 
| 305 | 
            +
                    experts_ablate: Optional[List[str]] = None,
         | 
| 306 | 
            +
                    routing_weights: Optional[torch.LongTensor] = None,
         | 
| 307 | 
            +
                    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         | 
| 308 | 
            +
                    inputs_embeds: Optional[torch.FloatTensor] = None,
         | 
| 309 | 
            +
                    labels: Optional[torch.LongTensor] = None,
         | 
| 310 | 
            +
                    use_cache: Optional[bool] = None,
         | 
| 311 | 
            +
                    output_attentions: Optional[bool] = None,
         | 
| 312 | 
            +
                    output_hidden_states: Optional[bool] = None,
         | 
| 313 | 
            +
                    return_dict: Optional[bool] = None,
         | 
| 314 | 
            +
                    cache_position: Optional[torch.LongTensor] = None,
         | 
| 315 | 
            +
                    logits_to_keep: Union[int, torch.Tensor] = 0,
         | 
| 316 | 
            +
                    **kwargs: Unpack[FlashAttentionKwargs],
         | 
| 317 | 
            +
                ):
         | 
| 318 | 
            +
             | 
| 319 | 
            +
                    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         | 
| 320 | 
            +
                    output_hidden_states = (
         | 
| 321 | 
            +
                        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         | 
| 322 | 
            +
                    )
         | 
| 323 | 
            +
                    use_cache = use_cache if use_cache is not None else self.config.use_cache
         | 
| 324 | 
            +
                    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                    if (input_ids is None) ^ (inputs_embeds is not None):
         | 
| 327 | 
            +
                        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                    if self.gradient_checkpointing and self.training and use_cache:
         | 
| 330 | 
            +
                        logger.warning_once(
         | 
| 331 | 
            +
                            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
         | 
| 332 | 
            +
                        )
         | 
| 333 | 
            +
                        use_cache = False
         | 
| 334 | 
            +
             | 
| 335 | 
            +
                    if inputs_embeds is None:
         | 
| 336 | 
            +
                        inputs_embeds = self.embed_tokens(input_ids)
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                    if use_cache and past_key_values is None:
         | 
| 339 | 
            +
                        past_key_values = DynamicCache()
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                    if cache_position is None:
         | 
| 342 | 
            +
                        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         | 
| 343 | 
            +
                        cache_position = torch.arange(
         | 
| 344 | 
            +
                            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
         | 
| 345 | 
            +
                        )
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                    if position_ids is None:
         | 
| 348 | 
            +
                        position_ids = cache_position.unsqueeze(0)
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                    causal_mask = self._update_causal_mask(
         | 
| 351 | 
            +
                        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         | 
| 352 | 
            +
                    )
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                    hidden_states = inputs_embeds
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                    # create position embeddings to be shared across the decoder layers
         | 
| 357 | 
            +
                    position_embeddings = self.rotary_emb(hidden_states, position_ids)
         | 
| 358 | 
            +
             | 
| 359 | 
            +
                    # decoder layers
         | 
| 360 | 
            +
                    all_hidden_states = () if output_hidden_states else None
         | 
| 361 | 
            +
                    all_self_attns = () if output_attentions else None
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                    all_routing_weights = ()
         | 
| 364 | 
            +
             
         | 
| 365 | 
            +
                    for decoder_layer in self.layers:
         | 
| 366 | 
            +
                        if output_hidden_states:
         | 
| 367 | 
            +
                            all_hidden_states += (hidden_states,)
         | 
| 368 | 
            +
             | 
| 369 | 
            +
                        if self.gradient_checkpointing and self.training and False:
         | 
| 370 | 
            +
                            layer_outputs, router_logits = self._gradient_checkpointing_func(
         | 
| 371 | 
            +
                                decoder_layer.__call__,
         | 
| 372 | 
            +
                                hidden_states,
         | 
| 373 | 
            +
                                routing_weights,
         | 
| 374 | 
            +
                                causal_mask,
         | 
| 375 | 
            +
                                position_ids,
         | 
| 376 | 
            +
                                experts_ablate,
         | 
| 377 | 
            +
                                past_key_values,
         | 
| 378 | 
            +
                                output_attentions,
         | 
| 379 | 
            +
                                use_cache,
         | 
| 380 | 
            +
                                cache_position,
         | 
| 381 | 
            +
                                position_embeddings,
         | 
| 382 | 
            +
                            )
         | 
| 383 | 
            +
                        else:
         | 
| 384 | 
            +
                            layer_outputs, router_logits = decoder_layer(
         | 
| 385 | 
            +
                                hidden_states,
         | 
| 386 | 
            +
                                routing_weights=routing_weights,
         | 
| 387 | 
            +
                                attention_mask=causal_mask,
         | 
| 388 | 
            +
                                position_ids=position_ids,
         | 
| 389 | 
            +
                                ablate=experts_ablate,
         | 
| 390 | 
            +
                                past_key_value=past_key_values,
         | 
| 391 | 
            +
                                output_attentions=output_attentions,
         | 
| 392 | 
            +
                                use_cache=use_cache,
         | 
| 393 | 
            +
                                cache_position=cache_position,
         | 
| 394 | 
            +
                                position_embeddings=position_embeddings,
         | 
| 395 | 
            +
                                **kwargs,
         | 
| 396 | 
            +
                            )
         | 
| 397 | 
            +
             | 
| 398 | 
            +
                        hidden_states = layer_outputs
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                        if output_attentions:
         | 
| 401 | 
            +
                            all_self_attns += (layer_outputs[1],)
         | 
| 402 | 
            +
             | 
| 403 | 
            +
                        all_routing_weights += (router_logits,)
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                    hidden_states = self.final_norm(hidden_states)
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                    # add hidden states from the last decoder layer
         | 
| 408 | 
            +
                    if output_hidden_states:
         | 
| 409 | 
            +
                        all_hidden_states += (hidden_states,)
         | 
| 410 | 
            +
                
         | 
| 411 | 
            +
                    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
         | 
| 412 | 
            +
                    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         | 
| 413 | 
            +
                    logits = self.lm_head(hidden_states[:, slice_indices, :])
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                    loss = None
         | 
| 416 | 
            +
                    if labels is not None:
         | 
| 417 | 
            +
                        loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
         | 
| 418 | 
            +
             | 
| 419 | 
            +
                    if not return_dict:
         | 
| 420 | 
            +
                        output = (logits,) + (past_key_values, all_hidden_states, all_self_attns, all_routing_weights) if use_cache else (logits, all_hidden_states, all_self_attns, all_routing_weights) 
         | 
| 421 | 
            +
                        return (loss,) + output if loss is not None else output
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                    return CausalLMOutputWithPast(
         | 
| 424 | 
            +
                        loss=loss,
         | 
| 425 | 
            +
                        logits=logits,
         | 
| 426 | 
            +
                        past_key_values=past_key_values if use_cache else None,
         | 
| 427 | 
            +
                        hidden_states=all_hidden_states,
         | 
| 428 | 
            +
                        attentions=all_self_attns,
         | 
| 429 | 
            +
                        routing_weights=all_routing_weights,
         | 
| 430 | 
            +
                    )
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                def _update_causal_mask(
         | 
| 433 | 
            +
                    self,
         | 
| 434 | 
            +
                    attention_mask: torch.Tensor,
         | 
| 435 | 
            +
                    input_tensor: torch.Tensor,
         | 
| 436 | 
            +
                    cache_position: torch.Tensor,
         | 
| 437 | 
            +
                    past_key_values: Cache,
         | 
| 438 | 
            +
                    output_attentions: bool,
         | 
| 439 | 
            +
                ):
         | 
| 440 | 
            +
                    if self.config._attn_implementation == "flash_attention_2":
         | 
| 441 | 
            +
                        if attention_mask is not None and 0.0 in attention_mask:
         | 
| 442 | 
            +
                            return attention_mask
         | 
| 443 | 
            +
                        return None
         | 
| 444 | 
            +
             | 
| 445 | 
            +
                    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         | 
| 446 | 
            +
                    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
         | 
| 447 | 
            +
                    # to infer the attention mask.
         | 
| 448 | 
            +
                    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         | 
| 449 | 
            +
                    using_static_cache = isinstance(past_key_values, StaticCache)
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
         | 
| 452 | 
            +
                    if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
         | 
| 453 | 
            +
                        if AttentionMaskConverter._ignore_causal_mask_sdpa(
         | 
| 454 | 
            +
                            attention_mask,
         | 
| 455 | 
            +
                            inputs_embeds=input_tensor,
         | 
| 456 | 
            +
                            past_key_values_length=past_seen_tokens,
         | 
| 457 | 
            +
                            is_training=self.training,
         | 
| 458 | 
            +
                        ):
         | 
| 459 | 
            +
                            return None
         | 
| 460 | 
            +
             | 
| 461 | 
            +
                    dtype, device = input_tensor.dtype, input_tensor.device
         | 
| 462 | 
            +
                    min_dtype = torch.finfo(dtype).min
         | 
| 463 | 
            +
                    sequence_length = input_tensor.shape[1]
         | 
| 464 | 
            +
                    if using_static_cache:
         | 
| 465 | 
            +
                        target_length = past_key_values.get_max_length()
         | 
| 466 | 
            +
                    else:
         | 
| 467 | 
            +
                        target_length = (
         | 
| 468 | 
            +
                            attention_mask.shape[-1]
         | 
| 469 | 
            +
                            if isinstance(attention_mask, torch.Tensor)
         | 
| 470 | 
            +
                            else past_seen_tokens + sequence_length + 1
         | 
| 471 | 
            +
                        )
         | 
| 472 | 
            +
             | 
| 473 | 
            +
                    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
         | 
| 474 | 
            +
                    causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
         | 
| 475 | 
            +
                        attention_mask,
         | 
| 476 | 
            +
                        sequence_length=sequence_length,
         | 
| 477 | 
            +
                        target_length=target_length,
         | 
| 478 | 
            +
                        dtype=dtype,
         | 
| 479 | 
            +
                        device=device,
         | 
| 480 | 
            +
                        min_dtype=min_dtype,
         | 
| 481 | 
            +
                        cache_position=cache_position,
         | 
| 482 | 
            +
                        batch_size=input_tensor.shape[0],
         | 
| 483 | 
            +
                    )
         | 
| 484 | 
            +
             | 
| 485 | 
            +
                    if (
         | 
| 486 | 
            +
                        self.config._attn_implementation == "sdpa"
         | 
| 487 | 
            +
                        and attention_mask is not None
         | 
| 488 | 
            +
                        and attention_mask.device.type == "cuda"
         | 
| 489 | 
            +
                        and not output_attentions
         | 
| 490 | 
            +
                    ):
         | 
| 491 | 
            +
                        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
         | 
| 492 | 
            +
                        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
         | 
| 493 | 
            +
                        # Details: https://github.com/pytorch/pytorch/issues/110213
         | 
| 494 | 
            +
                        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                    return causal_mask
         | 
| 497 | 
            +
             | 
| 498 | 
            +
                def load_pretrained(self, model_name):
         | 
| 499 | 
            +
                    base_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
         | 
| 500 | 
            +
                    self.lm_head.load_state_dict(base_model.lm_head.state_dict())
         | 
| 501 | 
            +
                    self.embed_tokens.load_state_dict(base_model.get_input_embeddings().state_dict())
         | 
| 502 | 
            +
                    self.rotary_emb.load_state_dict(base_model.model.rotary_emb.state_dict())
         | 
| 503 | 
            +
                    self.final_norm.load_state_dict(base_model.model.norm.state_dict())
         | 
| 504 | 
            +
                    for layer_idx, layer in enumerate(self.layers):
         | 
| 505 | 
            +
                        base_model_layer = base_model.model.layers[layer_idx].state_dict()
         | 
| 506 | 
            +
                        for expert in layer.experts:
         | 
| 507 | 
            +
                            expert.load_state_dict(base_model_layer)
         | 
| 508 | 
            +
             | 
| 509 | 
            +
                def prepare_inputs_for_generation(
         | 
| 510 | 
            +
                    self,
         | 
| 511 | 
            +
                    input_ids,
         | 
| 512 | 
            +
                    past_key_values=None,
         | 
| 513 | 
            +
                    attention_mask=None,
         | 
| 514 | 
            +
                    inputs_embeds=None,
         | 
| 515 | 
            +
                    cache_position=None,
         | 
| 516 | 
            +
                    position_ids=None,
         | 
| 517 | 
            +
                    experts_ablate=None,
         | 
| 518 | 
            +
                    use_cache=True,
         | 
| 519 | 
            +
                    num_logits_to_keep=None,
         | 
| 520 | 
            +
                    **kwargs,
         | 
| 521 | 
            +
                ):
         | 
| 522 | 
            +
                    
         | 
| 523 | 
            +
                    # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         | 
| 524 | 
            +
                    # Exception 1: when passing input_embeds, input_ids may be missing entries
         | 
| 525 | 
            +
                    # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
         | 
| 526 | 
            +
                    if past_key_values is not None:
         | 
| 527 | 
            +
                        if inputs_embeds is not None:  # Exception 1
         | 
| 528 | 
            +
                            input_ids = input_ids[:, -cache_position.shape[0] :]
         | 
| 529 | 
            +
                        elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
         | 
| 530 | 
            +
                            input_ids = input_ids[:, cache_position]
         | 
| 531 | 
            +
             | 
| 532 | 
            +
                    if attention_mask is not None and position_ids is None:
         | 
| 533 | 
            +
                        # create position_ids on the fly for batch generation
         | 
| 534 | 
            +
                        position_ids = attention_mask.long().cumsum(-1) - 1
         | 
| 535 | 
            +
                        position_ids.masked_fill_(attention_mask == 0, 1)
         | 
| 536 | 
            +
                        if past_key_values:
         | 
| 537 | 
            +
                            position_ids = position_ids[:, -input_ids.shape[1] :]
         | 
| 538 | 
            +
             | 
| 539 | 
            +
                            # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s  `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
         | 
| 540 | 
            +
                            position_ids = position_ids.clone(memory_format=torch.contiguous_format)
         | 
| 541 | 
            +
             | 
| 542 | 
            +
                    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         | 
| 543 | 
            +
                    if inputs_embeds is not None and cache_position[0] == 0:
         | 
| 544 | 
            +
                        model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         | 
| 545 | 
            +
                    else:
         | 
| 546 | 
            +
                        # The clone here is for the same reason as for `position_ids`.
         | 
| 547 | 
            +
                        model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
         | 
| 548 | 
            +
             | 
| 549 | 
            +
                    if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
         | 
| 550 | 
            +
                        if model_inputs["inputs_embeds"] is not None:
         | 
| 551 | 
            +
                            batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
         | 
| 552 | 
            +
                            device = model_inputs["inputs_embeds"].device
         | 
| 553 | 
            +
                        else:
         | 
| 554 | 
            +
                            batch_size, sequence_length = model_inputs["input_ids"].shape
         | 
| 555 | 
            +
                            device = model_inputs["input_ids"].device
         | 
| 556 | 
            +
             | 
| 557 | 
            +
                        dtype = self.lm_head.weight.dtype
         | 
| 558 | 
            +
                        min_dtype = torch.finfo(dtype).min
         | 
| 559 | 
            +
             | 
| 560 | 
            +
                        attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
         | 
| 561 | 
            +
                            attention_mask,
         | 
| 562 | 
            +
                            sequence_length=sequence_length,
         | 
| 563 | 
            +
                            target_length=past_key_values.get_max_length(),
         | 
| 564 | 
            +
                            dtype=dtype,
         | 
| 565 | 
            +
                            device=device,
         | 
| 566 | 
            +
                            min_dtype=min_dtype,
         | 
| 567 | 
            +
                            cache_position=cache_position,
         | 
| 568 | 
            +
                            batch_size=batch_size,
         | 
| 569 | 
            +
                        )
         | 
| 570 | 
            +
             | 
| 571 | 
            +
                    if num_logits_to_keep is not None:
         | 
| 572 | 
            +
                        model_inputs["num_logits_to_keep"] = num_logits_to_keep
         | 
| 573 | 
            +
             | 
| 574 | 
            +
                    model_inputs.update(
         | 
| 575 | 
            +
                        {
         | 
| 576 | 
            +
                            "experts_ablate": experts_ablate,
         | 
| 577 | 
            +
                            "position_ids": position_ids,
         | 
| 578 | 
            +
                            "cache_position": cache_position,
         | 
| 579 | 
            +
                            "past_key_values": past_key_values,
         | 
| 580 | 
            +
                            "use_cache": use_cache,
         | 
| 581 | 
            +
                            "attention_mask": attention_mask,
         | 
| 582 | 
            +
                        }
         | 
| 583 | 
            +
                    )
         | 
| 584 | 
            +
                    return model_inputs
         | 
| 585 | 
            +
             | 
| 586 | 
            +
             | 
| 587 | 
            +
            AutoConfig.register("micro_llama", MiCRoLlamaConfig)
         | 
| 588 | 
            +
            AutoModelForCausalLM.register(MiCRoLlamaConfig, MiCRoLlama)
         | 
    	
        models/micro_moe_llama.py
    ADDED
    
    | @@ -0,0 +1,725 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Optional, Tuple, Union, List, Callable
         | 
| 2 | 
            +
            import logging 
         | 
| 3 | 
            +
            import yaml
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            import torch.nn as nn
         | 
| 7 | 
            +
            import torch.nn.functional as F
         | 
| 8 | 
            +
            import torch.distributed as dist
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from transformers import LlamaConfig, AutoModelForCausalLM, AutoConfig
         | 
| 11 | 
            +
            from transformers.modeling_attn_mask_utils import AttentionMaskConverter
         | 
| 12 | 
            +
            from transformers.models.llama.modeling_llama import (
         | 
| 13 | 
            +
                LlamaRotaryEmbedding, 
         | 
| 14 | 
            +
                LlamaRMSNorm, 
         | 
| 15 | 
            +
                LlamaMLP,
         | 
| 16 | 
            +
                LlamaAttention,
         | 
| 17 | 
            +
                LlamaForCausalLM,
         | 
| 18 | 
            +
                LlamaPreTrainedModel, 
         | 
| 19 | 
            +
                GenerationMixin,
         | 
| 20 | 
            +
                apply_rotary_pos_emb,
         | 
| 21 | 
            +
                eager_attention_forward,
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            )
         | 
| 24 | 
            +
            from transformers.modeling_layers import GradientCheckpointingLayer
         | 
| 25 | 
            +
            from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
         | 
| 26 | 
            +
            from transformers.cache_utils import Cache, StaticCache, DynamicCache
         | 
| 27 | 
            +
            from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
         | 
| 28 | 
            +
            from transformers.processing_utils import Unpack
         | 
| 29 | 
            +
            from transformers.utils import is_torchdynamo_compiling
         | 
| 30 | 
            +
            from transformers.activations import ACT2FN
         | 
| 31 | 
            +
            from models.modules import CausalLMOutputWithPast
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            def keep_alive_zero(model):
         | 
| 36 | 
            +
                z = 0.0
         | 
| 37 | 
            +
                for p in model.parameters():
         | 
| 38 | 
            +
                    if p.requires_grad:
         | 
| 39 | 
            +
                        # one scalar per param to avoid heavy sums
         | 
| 40 | 
            +
                        z = z + (p.view(-1)[0] * 0.0)
         | 
| 41 | 
            +
                return z
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            class MiCRoLlamaMoEConfig(LlamaConfig):
         | 
| 44 | 
            +
                model_type = "micro_llama_moe"
         | 
| 45 | 
            +
                def __init__(self, *args, **kwargs):
         | 
| 46 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 47 | 
            +
                    self.num_experts = kwargs.get("num_experts", 4)
         | 
| 48 | 
            +
                    self.use_router = kwargs.get("use_router", True)
         | 
| 49 | 
            +
                    self.num_experts_per_tok = kwargs.get("num_experts_per_tok", 2)
         | 
| 50 | 
            +
                    self.jitter_noise = kwargs.get("jitter_noise", 0.0)
         | 
| 51 | 
            +
                    self.loss_method = kwargs.get("loss_method", "all")
         | 
| 52 | 
            +
                    self.config_path = kwargs.get("config_path", None)
         | 
| 53 | 
            +
             | 
| 54 | 
            +
            def _prepare_4d_causal_attention_mask_with_cache_position(
         | 
| 55 | 
            +
                attention_mask: torch.Tensor,
         | 
| 56 | 
            +
                sequence_length: int,
         | 
| 57 | 
            +
                target_length: int,
         | 
| 58 | 
            +
                dtype: torch.dtype,
         | 
| 59 | 
            +
                device: torch.device,
         | 
| 60 | 
            +
                min_dtype: float,
         | 
| 61 | 
            +
                cache_position: torch.Tensor,
         | 
| 62 | 
            +
                batch_size: int,
         | 
| 63 | 
            +
            ):
         | 
| 64 | 
            +
                """
         | 
| 65 | 
            +
                Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
         | 
| 66 | 
            +
                `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                Args:
         | 
| 69 | 
            +
                    attention_mask (`torch.Tensor`):
         | 
| 70 | 
            +
                        A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
         | 
| 71 | 
            +
                    sequence_length (`int`):
         | 
| 72 | 
            +
                        The sequence length being processed.
         | 
| 73 | 
            +
                    target_length (`int`):
         | 
| 74 | 
            +
                        The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
         | 
| 75 | 
            +
                    dtype (`torch.dtype`):
         | 
| 76 | 
            +
                        The dtype to use for the 4D attention mask.
         | 
| 77 | 
            +
                    device (`torch.device`):
         | 
| 78 | 
            +
                        The device to plcae the 4D attention mask on.
         | 
| 79 | 
            +
                    min_dtype (`float`):
         | 
| 80 | 
            +
                        The minimum value representable with the dtype `dtype`.
         | 
| 81 | 
            +
                    cache_position (`torch.Tensor`):
         | 
| 82 | 
            +
                        Indices depicting the position of the input sequence tokens in the sequence.
         | 
| 83 | 
            +
                    batch_size (`torch.Tensor`):
         | 
| 84 | 
            +
                        Batch size.
         | 
| 85 | 
            +
                """
         | 
| 86 | 
            +
                if attention_mask is not None and attention_mask.dim() == 4:
         | 
| 87 | 
            +
                    # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
         | 
| 88 | 
            +
                    causal_mask = attention_mask
         | 
| 89 | 
            +
                else:
         | 
| 90 | 
            +
                    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
         | 
| 91 | 
            +
                    if sequence_length != 1:
         | 
| 92 | 
            +
                        causal_mask = torch.triu(causal_mask, diagonal=1)
         | 
| 93 | 
            +
                    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
         | 
| 94 | 
            +
                    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         | 
| 95 | 
            +
                    if attention_mask is not None:
         | 
| 96 | 
            +
                        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
         | 
| 97 | 
            +
                        mask_length = attention_mask.shape[-1]
         | 
| 98 | 
            +
                        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
         | 
| 99 | 
            +
                        padding_mask = padding_mask == 0
         | 
| 100 | 
            +
                        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
         | 
| 101 | 
            +
                            padding_mask, min_dtype
         | 
| 102 | 
            +
                        )
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                return causal_mask
         | 
| 105 | 
            +
             | 
| 106 | 
            +
            class DummyModule(nn.Module):
         | 
| 107 | 
            +
                def __init__(self):
         | 
| 108 | 
            +
                    super().__init__()
         | 
| 109 | 
            +
                def forward(self, x):
         | 
| 110 | 
            +
                    return x
         | 
| 111 | 
            +
                
         | 
| 112 | 
            +
            class LlamaSparseMiCRoMoEBlock(nn.Module):
         | 
| 113 | 
            +
                """
         | 
| 114 | 
            +
                This implementation is
         | 
| 115 | 
            +
                strictly equivalent to standard MoE with full capacity (no
         | 
| 116 | 
            +
                dropped tokens). It's faster since it formulates MoE operations
         | 
| 117 | 
            +
                in terms of block-sparse operations to accommodate imbalanced
         | 
| 118 | 
            +
                assignments of tokens to experts, whereas standard MoE either
         | 
| 119 | 
            +
                (1) drop tokens at the cost of reduced performance or (2) set
         | 
| 120 | 
            +
                capacity factor to number of experts and thus waste computation
         | 
| 121 | 
            +
                and memory on padding.
         | 
| 122 | 
            +
                """
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                def __init__(self, config):
         | 
| 125 | 
            +
                    super().__init__()
         | 
| 126 | 
            +
                    self.hidden_dim = config.hidden_size
         | 
| 127 | 
            +
                    self.ffn_dim = config.intermediate_size
         | 
| 128 | 
            +
                    self.num_experts = config.num_experts
         | 
| 129 | 
            +
                    self.top_k = config.num_experts_per_tok
         | 
| 130 | 
            +
                    self.use_router = config.use_router
         | 
| 131 | 
            +
                    self.ablate = config.ablate
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    # gating
         | 
| 134 | 
            +
                    self.gate = nn.Sequential(
         | 
| 135 | 
            +
                        nn.Linear(self.hidden_dim, self.hidden_dim, bias=False),
         | 
| 136 | 
            +
                        nn.Linear(self.hidden_dim, self.num_experts, bias=False)
         | 
| 137 | 
            +
                    )
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    self.experts = nn.ModuleList([LlamaMLP(config) for _ in range(self.num_experts)])
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                    self.dummy = DummyModule()
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    # Jitter parameters
         | 
| 144 | 
            +
                    self.jitter_noise = config.jitter_noise
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                def forward(self, hidden_states: torch.Tensor, routing_weights: Optional[torch.Tensor] = None) -> torch.Tensor:
         | 
| 147 | 
            +
                    """ """
         | 
| 148 | 
            +
                    batch_size, sequence_length, hidden_dim = hidden_states.shape
         | 
| 149 | 
            +
                    if self.training and self.jitter_noise > 0:
         | 
| 150 | 
            +
                        hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
         | 
| 151 | 
            +
                    hidden_states = hidden_states.view(-1, hidden_dim)
         | 
| 152 | 
            +
                    
         | 
| 153 | 
            +
                    if self.use_router:
         | 
| 154 | 
            +
                        router_logits = self.gate(hidden_states)
         | 
| 155 | 
            +
                        if "logic" in self.ablate:
         | 
| 156 | 
            +
                            router_logits[..., 0] = -torch.inf
         | 
| 157 | 
            +
                        if "social" in self.ablate:
         | 
| 158 | 
            +
                            router_logits[..., 1] = -torch.inf
         | 
| 159 | 
            +
                        if "world" in self.ablate:
         | 
| 160 | 
            +
                            router_logits[..., 2] = -torch.inf
         | 
| 161 | 
            +
                        if "language" in self.ablate:
         | 
| 162 | 
            +
                            router_logits[..., 3] = -torch.inf
         | 
| 163 | 
            +
                        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
         | 
| 164 | 
            +
                    else:
         | 
| 165 | 
            +
                        routing_weights = routing_weights.reshape(-1, 4).float()
         | 
| 166 | 
            +
                        router_logits = routing_weights
         | 
| 167 | 
            +
                    # router_logits: (batch * sequence_length, n_experts)
         | 
| 168 | 
            +
                    
         | 
| 169 | 
            +
                    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
         | 
| 170 | 
            +
                    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         | 
| 171 | 
            +
                    # we cast back to the input dtype
         | 
| 172 | 
            +
                    routing_weights = routing_weights.to(hidden_states.dtype)
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    final_hidden_states = torch.zeros(
         | 
| 175 | 
            +
                        (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
         | 
| 176 | 
            +
                    )
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                    H_up = self.experts[0].up_proj.out_features
         | 
| 179 | 
            +
                    Y_up = hidden_states.new_zeros((batch_size, sequence_length, self.num_experts, H_up))
         | 
| 180 | 
            +
             | 
| 181 | 
            +
             | 
| 182 | 
            +
                    # One hot encode the selected experts to create an expert mask
         | 
| 183 | 
            +
                    # this will be used to easily index which expert is going to be sollicitated
         | 
| 184 | 
            +
                    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
         | 
| 187 | 
            +
                    for expert_idx in expert_hitted:
         | 
| 188 | 
            +
                        expert_layer = self.experts[expert_idx]
         | 
| 189 | 
            +
                        idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
         | 
| 190 | 
            +
                        # Index the correct hidden states and compute the expert hidden state for
         | 
| 191 | 
            +
                        # the current expert. We need to make sure to multiply the output hidden
         | 
| 192 | 
            +
                        # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
         | 
| 193 | 
            +
                        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                        # --- Hook to capture up-proj output BEFORE nonlinearity ---
         | 
| 196 | 
            +
                        captured_up = []
         | 
| 197 | 
            +
                        def _up_hook(m, inp, out):
         | 
| 198 | 
            +
                            # out shape: [N_e, H_up]
         | 
| 199 | 
            +
                            captured_up.append(out.detach())
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                        h = expert_layer.up_proj.register_forward_hook(_up_hook)
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                        current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
         | 
| 204 | 
            +
                        h.remove()
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                        # Scatter captured up-proj per-token into Y_up[b, t, expert, :]
         | 
| 207 | 
            +
                        if captured_up:
         | 
| 208 | 
            +
                            up = captured_up[-1]  # [N_e, H_up]
         | 
| 209 | 
            +
                            b_idx = top_x // sequence_length
         | 
| 210 | 
            +
                            t_idx = top_x %  sequence_length
         | 
| 211 | 
            +
                            # Y_up[b,t,e,:] = up[n,:]
         | 
| 212 | 
            +
                            Y_up[b_idx, t_idx, expert_idx, :] = up
         | 
| 213 | 
            +
                        
         | 
| 214 | 
            +
                        # However `index_add_` only support torch tensors for indexing so we'll use
         | 
| 215 | 
            +
                        # the `top_x` tensor here.
         | 
| 216 | 
            +
                        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
         | 
| 217 | 
            +
                    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                    self.dummy(Y_up)
         | 
| 220 | 
            +
                    return final_hidden_states, router_logits
         | 
| 221 | 
            +
             | 
| 222 | 
            +
            class LlamaMiCRoMoEDecoderLayer(GradientCheckpointingLayer):
         | 
| 223 | 
            +
                def __init__(self, config: MiCRoLlamaMoEConfig, layer_idx: int):
         | 
| 224 | 
            +
                    super().__init__()
         | 
| 225 | 
            +
                    self.hidden_size = config.hidden_size
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    self.block_sparse_moe = LlamaSparseMiCRoMoEBlock(config)
         | 
| 230 | 
            +
                    self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         | 
| 231 | 
            +
                    self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                def forward(
         | 
| 234 | 
            +
                    self,
         | 
| 235 | 
            +
                    hidden_states: torch.Tensor,
         | 
| 236 | 
            +
                    position_embeddings: tuple[torch.Tensor, torch.Tensor],
         | 
| 237 | 
            +
                    routing_weights: Optional[torch.Tensor] = None,
         | 
| 238 | 
            +
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 239 | 
            +
                    position_ids: Optional[torch.LongTensor] = None,
         | 
| 240 | 
            +
                    past_key_value: Optional[tuple[torch.Tensor]] = None,
         | 
| 241 | 
            +
                    cache_position: Optional[torch.LongTensor] = None,
         | 
| 242 | 
            +
                    **kwargs: Unpack[FlashAttentionKwargs],
         | 
| 243 | 
            +
                ) -> torch.FloatTensor:
         | 
| 244 | 
            +
                    residual = hidden_states
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                    hidden_states = self.input_layernorm(hidden_states)
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                    # Self Attention
         | 
| 249 | 
            +
                    hidden_states, _ = self.self_attn(
         | 
| 250 | 
            +
                        hidden_states=hidden_states,
         | 
| 251 | 
            +
                        position_embeddings=position_embeddings,
         | 
| 252 | 
            +
                        attention_mask=attention_mask,
         | 
| 253 | 
            +
                        position_ids=position_ids,
         | 
| 254 | 
            +
                        past_key_value=past_key_value,
         | 
| 255 | 
            +
                        cache_position=cache_position,
         | 
| 256 | 
            +
                        **kwargs,
         | 
| 257 | 
            +
                    )
         | 
| 258 | 
            +
                    hidden_states = residual + hidden_states
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                    # Fully Connected
         | 
| 261 | 
            +
                    residual = hidden_states
         | 
| 262 | 
            +
                    hidden_states = self.post_attention_layernorm(hidden_states)
         | 
| 263 | 
            +
                    hidden_states, router_logits = self.block_sparse_moe(hidden_states, routing_weights)
         | 
| 264 | 
            +
                    hidden_states = residual + hidden_states
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                    return hidden_states, router_logits
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                
         | 
| 269 | 
            +
            class MiCRoLlamaMoE(LlamaPreTrainedModel, GenerationMixin):
         | 
| 270 | 
            +
                config_class = MiCRoLlamaMoEConfig
         | 
| 271 | 
            +
                def __init__(self, config):
         | 
| 272 | 
            +
                    with open(config.config_path, 'r', encoding="utf-8") as file:
         | 
| 273 | 
            +
                        run_config = yaml.load(file.read(), Loader=yaml.FullLoader)
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                    self.config: MiCRoLlamaMoEConfig = config
         | 
| 276 | 
            +
                    self.config.torch_dtype = torch.bfloat16
         | 
| 277 | 
            +
                    self.config.use_bfloat16 = True
         | 
| 278 | 
            +
                    self.config._attn_implementation = "flash_attention_2" # {sdpa, flash_attention_2, eager}
         | 
| 279 | 
            +
                    self.config.use_cache = True
         | 
| 280 | 
            +
                    self.config.backbone_num_layers = self.config.num_hidden_layers
         | 
| 281 | 
            +
                    self.config.num_hidden_layers = self.config.num_hidden_layers
         | 
| 282 | 
            +
                    self.config.loss_type = "ForCausalLMLoss"
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                    super(MiCRoLlamaMoE, self).__init__(self.config)
         | 
| 285 | 
            +
                    self.build_model(run_config)
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                def build_model(self, run_config):
         | 
| 288 | 
            +
                
         | 
| 289 | 
            +
                    self.config.num_experts = run_config["num-experts"]
         | 
| 290 | 
            +
                    self.config.use_router = run_config["use-router"]
         | 
| 291 | 
            +
                    self.config.num_experts_per_tok = run_config["top-k-experts"]
         | 
| 292 | 
            +
                    print(f">> Top-K Experts Per Token: {self.config.num_experts_per_tok}")
         | 
| 293 | 
            +
                    self.config.jitter_noise = run_config["jitter-noise"]
         | 
| 294 | 
            +
                    self.config.loss_method = run_config.get("loss", "all")
         | 
| 295 | 
            +
                    self.router_aux_loss_coef = run_config["router-aux-loss-coef"]
         | 
| 296 | 
            +
                    self.use_load_balancing = run_config.get("use-load-balancing", False)
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                    self.config.gradient_checkpointing = run_config.get("gradient-checkpointing", False)
         | 
| 299 | 
            +
                    self.gradient_checkpointing = self.config.gradient_checkpointing
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                    print(f">> Gradient Checkpointing: {self.config.gradient_checkpointing}")
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                    self.run_config = run_config
         | 
| 304 | 
            +
                    self.padding_idx = 2 if "smollm2" in run_config["model"] else 128004
         | 
| 305 | 
            +
                    
         | 
| 306 | 
            +
                    # LlamaMoE model
         | 
| 307 | 
            +
                    self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size, self.padding_idx)
         | 
| 308 | 
            +
                    self.layers = nn.ModuleList([LlamaMiCRoMoEDecoderLayer(self.config, layer_idx) for layer_idx in range(self.config.backbone_num_layers)])
         | 
| 309 | 
            +
                    self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
         | 
| 310 | 
            +
                    self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
         | 
| 311 | 
            +
                    self.final_norm = LlamaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
         | 
| 312 | 
            +
                    
         | 
| 313 | 
            +
                    if "model" not in run_config["trainable"]:
         | 
| 314 | 
            +
                        print(">> Freezing Model Except Experts + Routing Gates")
         | 
| 315 | 
            +
                        for param in self.parameters():
         | 
| 316 | 
            +
                            param.requires_grad = False
         | 
| 317 | 
            +
             | 
| 318 | 
            +
                        for layer in self.layers:
         | 
| 319 | 
            +
                            layer: LlamaMiCRoMoEDecoderLayer
         | 
| 320 | 
            +
                            for param in layer.block_sparse_moe.parameters():
         | 
| 321 | 
            +
                                param.requires_grad = True
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                    if "experts" not in run_config["trainable"]:
         | 
| 324 | 
            +
                        print(">> Freezing Experts")
         | 
| 325 | 
            +
                        for layer in self.layers:
         | 
| 326 | 
            +
                            layer: LlamaMiCRoMoEDecoderLayer
         | 
| 327 | 
            +
                            for param in layer.block_sparse_moe.experts.parameters():
         | 
| 328 | 
            +
                                param.requires_grad = False
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                    if "experts-router" not in run_config["trainable"]:
         | 
| 331 | 
            +
                        print(">> Freezing Routing Gates")
         | 
| 332 | 
            +
                        for layer in self.layers:
         | 
| 333 | 
            +
                            layer: LlamaMiCRoMoEDecoderLayer
         | 
| 334 | 
            +
                            for param in layer.block_sparse_moe.gate.parameters():
         | 
| 335 | 
            +
                                param.requires_grad = False
         | 
| 336 | 
            +
             | 
| 337 | 
            +
             | 
| 338 | 
            +
                def forward(self,
         | 
| 339 | 
            +
                    input_ids: torch.LongTensor = None,
         | 
| 340 | 
            +
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 341 | 
            +
                    position_ids: Optional[torch.LongTensor] = None,
         | 
| 342 | 
            +
                    experts_ablate: Optional[List[str]] = None,
         | 
| 343 | 
            +
                    routing_weights: Optional[torch.LongTensor] = None,
         | 
| 344 | 
            +
                    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         | 
| 345 | 
            +
                    inputs_embeds: Optional[torch.FloatTensor] = None,
         | 
| 346 | 
            +
                    labels: Optional[torch.LongTensor] = None,
         | 
| 347 | 
            +
                    use_cache: Optional[bool] = None,
         | 
| 348 | 
            +
                    output_attentions: Optional[bool] = None,
         | 
| 349 | 
            +
                    output_hidden_states: Optional[bool] = None,
         | 
| 350 | 
            +
                    return_dict: Optional[bool] = None,
         | 
| 351 | 
            +
                    cache_position: Optional[torch.LongTensor] = None,
         | 
| 352 | 
            +
                    logits_to_keep: Union[int, torch.Tensor] = 0,
         | 
| 353 | 
            +
                    **kwargs: Unpack[FlashAttentionKwargs],
         | 
| 354 | 
            +
                ):
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         | 
| 357 | 
            +
                    output_hidden_states = (
         | 
| 358 | 
            +
                        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         | 
| 359 | 
            +
                    )
         | 
| 360 | 
            +
                    use_cache = use_cache if use_cache is not None else self.config.use_cache
         | 
| 361 | 
            +
                    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                    if (input_ids is None) ^ (inputs_embeds is not None):
         | 
| 364 | 
            +
                        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                    if self.gradient_checkpointing and self.training and use_cache:
         | 
| 367 | 
            +
                        logger.warning_once(
         | 
| 368 | 
            +
                            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
         | 
| 369 | 
            +
                        )
         | 
| 370 | 
            +
                        use_cache = False
         | 
| 371 | 
            +
             | 
| 372 | 
            +
                    if inputs_embeds is None:
         | 
| 373 | 
            +
                        inputs_embeds = self.embed_tokens(input_ids)
         | 
| 374 | 
            +
             | 
| 375 | 
            +
                    if use_cache and past_key_values is None:
         | 
| 376 | 
            +
                        past_key_values = DynamicCache()
         | 
| 377 | 
            +
             | 
| 378 | 
            +
                    if cache_position is None:
         | 
| 379 | 
            +
                        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         | 
| 380 | 
            +
                        cache_position = torch.arange(
         | 
| 381 | 
            +
                            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
         | 
| 382 | 
            +
                        )
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                    if position_ids is None:
         | 
| 385 | 
            +
                        position_ids = cache_position.unsqueeze(0)
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                    causal_mask = self._update_causal_mask(
         | 
| 388 | 
            +
                        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         | 
| 389 | 
            +
                    )
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                    hidden_states = inputs_embeds
         | 
| 392 | 
            +
             | 
| 393 | 
            +
                    # create position embeddings to be shared across the decoder layers
         | 
| 394 | 
            +
                    position_embeddings = self.rotary_emb(hidden_states, position_ids)
         | 
| 395 | 
            +
             | 
| 396 | 
            +
                    # decoder layers
         | 
| 397 | 
            +
                    all_hidden_states = () if output_hidden_states else None
         | 
| 398 | 
            +
                    all_self_attns = () if output_attentions else None
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                    all_routing_weights = ()
         | 
| 401 | 
            +
             | 
| 402 | 
            +
                    for decoder_layer in self.layers:
         | 
| 403 | 
            +
                        if output_hidden_states:
         | 
| 404 | 
            +
                            all_hidden_states += (hidden_states,)
         | 
| 405 | 
            +
             | 
| 406 | 
            +
                        if self.gradient_checkpointing and self.training:
         | 
| 407 | 
            +
                            layer_outputs, router_logits = self._gradient_checkpointing_func(
         | 
| 408 | 
            +
                                decoder_layer.__call__,
         | 
| 409 | 
            +
                                hidden_states,
         | 
| 410 | 
            +
                                position_embeddings,
         | 
| 411 | 
            +
                                routing_weights,
         | 
| 412 | 
            +
                                causal_mask,
         | 
| 413 | 
            +
                                position_ids,
         | 
| 414 | 
            +
                                past_key_values,
         | 
| 415 | 
            +
                                cache_position,
         | 
| 416 | 
            +
                            )
         | 
| 417 | 
            +
                        else:
         | 
| 418 | 
            +
                            layer_outputs, router_logits = decoder_layer(
         | 
| 419 | 
            +
                                hidden_states,
         | 
| 420 | 
            +
                                position_embeddings=position_embeddings,
         | 
| 421 | 
            +
                                routing_weights=routing_weights,
         | 
| 422 | 
            +
                                attention_mask=causal_mask,
         | 
| 423 | 
            +
                                position_ids=position_ids,
         | 
| 424 | 
            +
                                past_key_value=past_key_values,
         | 
| 425 | 
            +
                                output_attentions=output_attentions,
         | 
| 426 | 
            +
                                use_cache=use_cache,
         | 
| 427 | 
            +
                                cache_position=cache_position,
         | 
| 428 | 
            +
                                **kwargs,
         | 
| 429 | 
            +
                            )
         | 
| 430 | 
            +
             | 
| 431 | 
            +
                        hidden_states = layer_outputs
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                        if output_attentions:
         | 
| 434 | 
            +
                            all_self_attns += (layer_outputs[1],)
         | 
| 435 | 
            +
             | 
| 436 | 
            +
                        all_routing_weights += (router_logits,)
         | 
| 437 | 
            +
             | 
| 438 | 
            +
                    hidden_states = self.final_norm(hidden_states)
         | 
| 439 | 
            +
             | 
| 440 | 
            +
                    # add hidden states from the last decoder layer
         | 
| 441 | 
            +
                    if output_hidden_states:
         | 
| 442 | 
            +
                        all_hidden_states += (hidden_states,)
         | 
| 443 | 
            +
                
         | 
| 444 | 
            +
                    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
         | 
| 445 | 
            +
                    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         | 
| 446 | 
            +
                    logits = self.lm_head(hidden_states[:, slice_indices, :])
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                    loss = None
         | 
| 449 | 
            +
                    if labels is not None:
         | 
| 450 | 
            +
                        loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
         | 
| 451 | 
            +
             | 
| 452 | 
            +
                        loss += keep_alive_zero(self)
         | 
| 453 | 
            +
                        
         | 
| 454 | 
            +
                        aux_loss = None
         | 
| 455 | 
            +
                        if self.use_load_balancing:
         | 
| 456 | 
            +
                            aux_loss = load_balancing_loss_func(
         | 
| 457 | 
            +
                                all_routing_weights,
         | 
| 458 | 
            +
                                self.config.num_experts,
         | 
| 459 | 
            +
                                self.config.num_experts_per_tok,
         | 
| 460 | 
            +
                                attention_mask,
         | 
| 461 | 
            +
                            )
         | 
| 462 | 
            +
                            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
         | 
| 463 | 
            +
             | 
| 464 | 
            +
                    if not return_dict:
         | 
| 465 | 
            +
                        output = (logits,) + (past_key_values, all_hidden_states, all_self_attns, all_routing_weights) if use_cache else (logits, all_hidden_states, all_self_attns, all_routing_weights) 
         | 
| 466 | 
            +
                        return (loss,) + output if loss is not None else output
         | 
| 467 | 
            +
             | 
| 468 | 
            +
                    return CausalLMOutputWithPast(
         | 
| 469 | 
            +
                        loss=loss,
         | 
| 470 | 
            +
                        logits=logits,
         | 
| 471 | 
            +
                        past_key_values=past_key_values if use_cache else None,
         | 
| 472 | 
            +
                        hidden_states=all_hidden_states,
         | 
| 473 | 
            +
                        attentions=all_self_attns,
         | 
| 474 | 
            +
                        routing_weights=all_routing_weights,
         | 
| 475 | 
            +
                    )
         | 
| 476 | 
            +
             | 
| 477 | 
            +
                def _update_causal_mask(
         | 
| 478 | 
            +
                    self,
         | 
| 479 | 
            +
                    attention_mask: torch.Tensor,
         | 
| 480 | 
            +
                    input_tensor: torch.Tensor,
         | 
| 481 | 
            +
                    cache_position: torch.Tensor,
         | 
| 482 | 
            +
                    past_key_values: Cache,
         | 
| 483 | 
            +
                    output_attentions: bool,
         | 
| 484 | 
            +
                ):
         | 
| 485 | 
            +
                    if self.config._attn_implementation == "flash_attention_2":
         | 
| 486 | 
            +
                        if attention_mask is not None and 0.0 in attention_mask:
         | 
| 487 | 
            +
                            return attention_mask
         | 
| 488 | 
            +
                        return None
         | 
| 489 | 
            +
             | 
| 490 | 
            +
                    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         | 
| 491 | 
            +
                    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
         | 
| 492 | 
            +
                    # to infer the attention mask.
         | 
| 493 | 
            +
                    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         | 
| 494 | 
            +
                    using_static_cache = isinstance(past_key_values, StaticCache)
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
         | 
| 497 | 
            +
                    if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
         | 
| 498 | 
            +
                        if AttentionMaskConverter._ignore_causal_mask_sdpa(
         | 
| 499 | 
            +
                            attention_mask,
         | 
| 500 | 
            +
                            inputs_embeds=input_tensor,
         | 
| 501 | 
            +
                            past_key_values_length=past_seen_tokens,
         | 
| 502 | 
            +
                            is_training=self.training,
         | 
| 503 | 
            +
                        ):
         | 
| 504 | 
            +
                            return None
         | 
| 505 | 
            +
             | 
| 506 | 
            +
                    dtype, device = input_tensor.dtype, input_tensor.device
         | 
| 507 | 
            +
                    min_dtype = torch.finfo(dtype).min
         | 
| 508 | 
            +
                    sequence_length = input_tensor.shape[1]
         | 
| 509 | 
            +
                    if using_static_cache:
         | 
| 510 | 
            +
                        target_length = past_key_values.get_max_length()
         | 
| 511 | 
            +
                    else:
         | 
| 512 | 
            +
                        target_length = (
         | 
| 513 | 
            +
                            attention_mask.shape[-1]
         | 
| 514 | 
            +
                            if isinstance(attention_mask, torch.Tensor)
         | 
| 515 | 
            +
                            else past_seen_tokens + sequence_length + 1
         | 
| 516 | 
            +
                        )
         | 
| 517 | 
            +
             | 
| 518 | 
            +
                    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
         | 
| 519 | 
            +
                    causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
         | 
| 520 | 
            +
                        attention_mask,
         | 
| 521 | 
            +
                        sequence_length=sequence_length,
         | 
| 522 | 
            +
                        target_length=target_length,
         | 
| 523 | 
            +
                        dtype=dtype,
         | 
| 524 | 
            +
                        device=device,
         | 
| 525 | 
            +
                        min_dtype=min_dtype,
         | 
| 526 | 
            +
                        cache_position=cache_position,
         | 
| 527 | 
            +
                        batch_size=input_tensor.shape[0],
         | 
| 528 | 
            +
                    )
         | 
| 529 | 
            +
             | 
| 530 | 
            +
                    if (
         | 
| 531 | 
            +
                        self.config._attn_implementation == "sdpa"
         | 
| 532 | 
            +
                        and attention_mask is not None
         | 
| 533 | 
            +
                        and attention_mask.device.type == "cuda"
         | 
| 534 | 
            +
                        and not output_attentions
         | 
| 535 | 
            +
                    ):
         | 
| 536 | 
            +
                        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
         | 
| 537 | 
            +
                        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
         | 
| 538 | 
            +
                        # Details: https://github.com/pytorch/pytorch/issues/110213
         | 
| 539 | 
            +
                        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
         | 
| 540 | 
            +
             | 
| 541 | 
            +
                    return causal_mask
         | 
| 542 | 
            +
             | 
| 543 | 
            +
                def load_pretrained(self, model_name):
         | 
| 544 | 
            +
                    base_model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
         | 
| 545 | 
            +
                    self.lm_head.load_state_dict(base_model.lm_head.state_dict())
         | 
| 546 | 
            +
                    self.embed_tokens.load_state_dict(base_model.get_input_embeddings().state_dict())
         | 
| 547 | 
            +
                    self.rotary_emb.load_state_dict(base_model.model.rotary_emb.state_dict())
         | 
| 548 | 
            +
                    self.final_norm.load_state_dict(base_model.model.norm.state_dict())
         | 
| 549 | 
            +
             | 
| 550 | 
            +
                    for layer_idx, layer in enumerate(self.layers):
         | 
| 551 | 
            +
                        
         | 
| 552 | 
            +
                        attn_layer = base_model.model.layers[layer_idx].self_attn.state_dict()
         | 
| 553 | 
            +
                        layer.self_attn.load_state_dict(attn_layer)
         | 
| 554 | 
            +
                        
         | 
| 555 | 
            +
                        input_layernorm = base_model.model.layers[layer_idx].input_layernorm.state_dict()
         | 
| 556 | 
            +
                        layer.input_layernorm.load_state_dict(input_layernorm)
         | 
| 557 | 
            +
             | 
| 558 | 
            +
                        post_attention_layernorm = base_model.model.layers[layer_idx].post_attention_layernorm.state_dict()
         | 
| 559 | 
            +
                        layer.post_attention_layernorm.load_state_dict(post_attention_layernorm)
         | 
| 560 | 
            +
                        
         | 
| 561 | 
            +
                        mlp_model_layer = base_model.model.layers[layer_idx].mlp.state_dict()
         | 
| 562 | 
            +
                        for expert in layer.block_sparse_moe.experts:
         | 
| 563 | 
            +
                            expert.load_state_dict(mlp_model_layer)
         | 
| 564 | 
            +
             | 
| 565 | 
            +
                def prepare_inputs_for_generation(
         | 
| 566 | 
            +
                    self,
         | 
| 567 | 
            +
                    input_ids,
         | 
| 568 | 
            +
                    past_key_values=None,
         | 
| 569 | 
            +
                    attention_mask=None,
         | 
| 570 | 
            +
                    inputs_embeds=None,
         | 
| 571 | 
            +
                    cache_position=None,
         | 
| 572 | 
            +
                    position_ids=None,
         | 
| 573 | 
            +
                    experts_ablate=None,
         | 
| 574 | 
            +
                    use_cache=True,
         | 
| 575 | 
            +
                    num_logits_to_keep=None,
         | 
| 576 | 
            +
                    **kwargs,
         | 
| 577 | 
            +
                ):
         | 
| 578 | 
            +
                    
         | 
| 579 | 
            +
                    # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         | 
| 580 | 
            +
                    # Exception 1: when passing input_embeds, input_ids may be missing entries
         | 
| 581 | 
            +
                    # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
         | 
| 582 | 
            +
                    if past_key_values is not None:
         | 
| 583 | 
            +
                        if inputs_embeds is not None:  # Exception 1
         | 
| 584 | 
            +
                            input_ids = input_ids[:, -cache_position.shape[0] :]
         | 
| 585 | 
            +
                        elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
         | 
| 586 | 
            +
                            input_ids = input_ids[:, cache_position]
         | 
| 587 | 
            +
             | 
| 588 | 
            +
                    if attention_mask is not None and position_ids is None:
         | 
| 589 | 
            +
                        # create position_ids on the fly for batch generation
         | 
| 590 | 
            +
                        position_ids = attention_mask.long().cumsum(-1) - 1
         | 
| 591 | 
            +
                        position_ids.masked_fill_(attention_mask == 0, 1)
         | 
| 592 | 
            +
                        if past_key_values:
         | 
| 593 | 
            +
                            position_ids = position_ids[:, -input_ids.shape[1] :]
         | 
| 594 | 
            +
             | 
| 595 | 
            +
                            # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s  `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
         | 
| 596 | 
            +
                            position_ids = position_ids.clone(memory_format=torch.contiguous_format)
         | 
| 597 | 
            +
             | 
| 598 | 
            +
                    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         | 
| 599 | 
            +
                    if inputs_embeds is not None and cache_position[0] == 0:
         | 
| 600 | 
            +
                        model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         | 
| 601 | 
            +
                    else:
         | 
| 602 | 
            +
                        # The clone here is for the same reason as for `position_ids`.
         | 
| 603 | 
            +
                        model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
         | 
| 604 | 
            +
             | 
| 605 | 
            +
                    if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
         | 
| 606 | 
            +
                        if model_inputs["inputs_embeds"] is not None:
         | 
| 607 | 
            +
                            batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
         | 
| 608 | 
            +
                            device = model_inputs["inputs_embeds"].device
         | 
| 609 | 
            +
                        else:
         | 
| 610 | 
            +
                            batch_size, sequence_length = model_inputs["input_ids"].shape
         | 
| 611 | 
            +
                            device = model_inputs["input_ids"].device
         | 
| 612 | 
            +
             | 
| 613 | 
            +
                        dtype = self.lm_head.weight.dtype
         | 
| 614 | 
            +
                        min_dtype = torch.finfo(dtype).min
         | 
| 615 | 
            +
             | 
| 616 | 
            +
                        attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
         | 
| 617 | 
            +
                            attention_mask,
         | 
| 618 | 
            +
                            sequence_length=sequence_length,
         | 
| 619 | 
            +
                            target_length=past_key_values.get_max_length(),
         | 
| 620 | 
            +
                            dtype=dtype,
         | 
| 621 | 
            +
                            device=device,
         | 
| 622 | 
            +
                            min_dtype=min_dtype,
         | 
| 623 | 
            +
                            cache_position=cache_position,
         | 
| 624 | 
            +
                            batch_size=batch_size,
         | 
| 625 | 
            +
                        )
         | 
| 626 | 
            +
             | 
| 627 | 
            +
                    if num_logits_to_keep is not None:
         | 
| 628 | 
            +
                        model_inputs["num_logits_to_keep"] = num_logits_to_keep
         | 
| 629 | 
            +
             | 
| 630 | 
            +
                    model_inputs.update(
         | 
| 631 | 
            +
                        {
         | 
| 632 | 
            +
                            "experts_ablate": experts_ablate,
         | 
| 633 | 
            +
                            "position_ids": position_ids,
         | 
| 634 | 
            +
                            "cache_position": cache_position,
         | 
| 635 | 
            +
                            "past_key_values": past_key_values,
         | 
| 636 | 
            +
                            "use_cache": use_cache,
         | 
| 637 | 
            +
                            "attention_mask": attention_mask,
         | 
| 638 | 
            +
                        }
         | 
| 639 | 
            +
                    )
         | 
| 640 | 
            +
                    return model_inputs
         | 
| 641 | 
            +
                
         | 
| 642 | 
            +
            def load_balancing_loss_func(
         | 
| 643 | 
            +
                gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
         | 
| 644 | 
            +
                num_experts: Optional[int] = None,
         | 
| 645 | 
            +
                top_k=2,
         | 
| 646 | 
            +
                attention_mask: Optional[torch.Tensor] = None,
         | 
| 647 | 
            +
            ) -> Union[torch.Tensor, int]:
         | 
| 648 | 
            +
                r"""
         | 
| 649 | 
            +
                Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
         | 
| 650 | 
            +
             | 
| 651 | 
            +
                See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
         | 
| 652 | 
            +
                function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
         | 
| 653 | 
            +
                experts is too unbalanced.
         | 
| 654 | 
            +
             | 
| 655 | 
            +
                Args:
         | 
| 656 | 
            +
                    gate_logits:
         | 
| 657 | 
            +
                        Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
         | 
| 658 | 
            +
                        shape [batch_size X sequence_length, num_experts].
         | 
| 659 | 
            +
                    num_experts:
         | 
| 660 | 
            +
                        Number of experts
         | 
| 661 | 
            +
                    top_k:
         | 
| 662 | 
            +
                        The number of experts to route per-token, can be also interpreted as the `top-k` routing
         | 
| 663 | 
            +
                        parameter.
         | 
| 664 | 
            +
                    attention_mask (`torch.Tensor`, *optional*):
         | 
| 665 | 
            +
                        The attention_mask used in forward function
         | 
| 666 | 
            +
                        shape [batch_size X sequence_length] if not None.
         | 
| 667 | 
            +
             | 
| 668 | 
            +
                Returns:
         | 
| 669 | 
            +
                    The auxiliary loss.
         | 
| 670 | 
            +
                """
         | 
| 671 | 
            +
                if gate_logits is None or not isinstance(gate_logits, tuple):
         | 
| 672 | 
            +
                    return 0
         | 
| 673 | 
            +
             | 
| 674 | 
            +
                if isinstance(gate_logits, tuple):
         | 
| 675 | 
            +
                    compute_device = gate_logits[0].device
         | 
| 676 | 
            +
                    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
         | 
| 677 | 
            +
             | 
| 678 | 
            +
                routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
         | 
| 679 | 
            +
             | 
| 680 | 
            +
                _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
         | 
| 681 | 
            +
             | 
| 682 | 
            +
                expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
         | 
| 683 | 
            +
             | 
| 684 | 
            +
                if attention_mask is None:
         | 
| 685 | 
            +
                    # Compute the percentage of tokens routed to each experts
         | 
| 686 | 
            +
                    tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
         | 
| 687 | 
            +
             | 
| 688 | 
            +
                    # Compute the average probability of routing to these experts
         | 
| 689 | 
            +
                    router_prob_per_expert = torch.mean(routing_weights, dim=0)
         | 
| 690 | 
            +
                else:
         | 
| 691 | 
            +
                    batch_size, sequence_length = attention_mask.shape
         | 
| 692 | 
            +
                    num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
         | 
| 693 | 
            +
             | 
| 694 | 
            +
                    # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
         | 
| 695 | 
            +
                    expert_attention_mask = (
         | 
| 696 | 
            +
                        attention_mask[None, :, :, None, None]
         | 
| 697 | 
            +
                        .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
         | 
| 698 | 
            +
                        .reshape(-1, top_k, num_experts)
         | 
| 699 | 
            +
                        .to(compute_device)
         | 
| 700 | 
            +
                    )
         | 
| 701 | 
            +
             | 
| 702 | 
            +
                    # Compute the percentage of tokens routed to each experts
         | 
| 703 | 
            +
                    tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
         | 
| 704 | 
            +
                        expert_attention_mask, dim=0
         | 
| 705 | 
            +
                    )
         | 
| 706 | 
            +
             | 
| 707 | 
            +
                    # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
         | 
| 708 | 
            +
                    router_per_expert_attention_mask = (
         | 
| 709 | 
            +
                        attention_mask[None, :, :, None]
         | 
| 710 | 
            +
                        .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
         | 
| 711 | 
            +
                        .reshape(-1, num_experts)
         | 
| 712 | 
            +
                        .to(compute_device)
         | 
| 713 | 
            +
                    )
         | 
| 714 | 
            +
             | 
| 715 | 
            +
                    # Compute the average probability of routing to these experts
         | 
| 716 | 
            +
                    router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
         | 
| 717 | 
            +
                        router_per_expert_attention_mask, dim=0
         | 
| 718 | 
            +
                    )
         | 
| 719 | 
            +
             | 
| 720 | 
            +
                overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
         | 
| 721 | 
            +
                return overall_loss * num_experts
         | 
| 722 | 
            +
             | 
| 723 | 
            +
             | 
| 724 | 
            +
            AutoConfig.register("micro_llama_moe", MiCRoLlamaMoEConfig)
         | 
| 725 | 
            +
            AutoModelForCausalLM.register(MiCRoLlamaMoEConfig, MiCRoLlamaMoE)
         | 
    	
        models/micro_olmo.py
    ADDED
    
    | @@ -0,0 +1,528 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Callable, Optional, Tuple, Union
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import yaml
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import torch.nn as nn
         | 
| 6 | 
            +
            from torch.nn import functional as F
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from transformers import AutoModelForCausalLM
         | 
| 9 | 
            +
            from transformers.activations import ACT2FN
         | 
| 10 | 
            +
            from transformers.cache_utils import Cache, DynamicCache
         | 
| 11 | 
            +
            from transformers.generation import GenerationMixin
         | 
| 12 | 
            +
            from transformers.modeling_attn_mask_utils import AttentionMaskConverter
         | 
| 13 | 
            +
            from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
         | 
| 14 | 
            +
            # from transformers.modeling_layers import GradientCheckpointingLayer
         | 
| 15 | 
            +
            from transformers.modeling_outputs import BaseModelOutputWithPast
         | 
| 16 | 
            +
            from transformers.processing_utils import Unpack
         | 
| 17 | 
            +
            from transformers.utils import (
         | 
| 18 | 
            +
                add_start_docstrings,
         | 
| 19 | 
            +
                add_start_docstrings_to_model_forward,
         | 
| 20 | 
            +
                is_torch_flex_attn_available,
         | 
| 21 | 
            +
                logging,
         | 
| 22 | 
            +
                replace_return_docstrings,
         | 
| 23 | 
            +
            )
         | 
| 24 | 
            +
            from transformers.models.olmo2.configuration_olmo2 import Olmo2Config
         | 
| 25 | 
            +
            from transformers.models.olmo2.modeling_olmo2 import (
         | 
| 26 | 
            +
                Olmo2RMSNorm,
         | 
| 27 | 
            +
                Olmo2Attention,
         | 
| 28 | 
            +
                Olmo2MLP,
         | 
| 29 | 
            +
                Olmo2DecoderLayer,
         | 
| 30 | 
            +
                Olmo2RotaryEmbedding,
         | 
| 31 | 
            +
                Olmo2PreTrainedModel,
         | 
| 32 | 
            +
                rotate_half,
         | 
| 33 | 
            +
                apply_rotary_pos_emb,
         | 
| 34 | 
            +
                repeat_kv,
         | 
| 35 | 
            +
                eager_attention_forward,
         | 
| 36 | 
            +
            )
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            if is_torch_flex_attn_available():
         | 
| 40 | 
            +
                from torch.nn.attention.flex_attention import BlockMask
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            from models.modules import CausalLMOutputWithPast
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            logger = logging.get_logger(__name__)
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            class MiCRoOLMo2DecoderLayer(nn.Module):
         | 
| 47 | 
            +
                def __init__(self, config: Olmo2Config, layer_idx: int):
         | 
| 48 | 
            +
                    super().__init__()
         | 
| 49 | 
            +
                    self.hidden_size = config.hidden_size
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                    self.num_experts  = config.num_experts
         | 
| 52 | 
            +
                    self.top_k        = config.num_experts_per_tok
         | 
| 53 | 
            +
                    self.use_router   = config.use_router
         | 
| 54 | 
            +
                    self.ablate       = config.ablate or []
         | 
| 55 | 
            +
                    self.num_layers   = config.backbone_num_layers
         | 
| 56 | 
            +
                    self.layer_idx    = layer_idx
         | 
| 57 | 
            +
                    self.jitter_noise = config.jitter_noise
         | 
| 58 | 
            +
                    self.config = config
         | 
| 59 | 
            +
                    self.head_dim = config.hidden_size // config.num_attention_heads
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    if isinstance(self.ablate, str):
         | 
| 62 | 
            +
                        self.ablate = [self.ablate]
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    # gating head
         | 
| 65 | 
            +
                    self.gate = nn.Sequential(
         | 
| 66 | 
            +
                        nn.Linear(self.hidden_size, self.hidden_size, bias=False),
         | 
| 67 | 
            +
                        nn.Linear(self.hidden_size, self.num_experts, bias=False),
         | 
| 68 | 
            +
                    )
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    self.experts = nn.ModuleList([
         | 
| 71 | 
            +
                        Olmo2DecoderLayer(config, layer_idx * self.num_experts + expert_idx)
         | 
| 72 | 
            +
                        for expert_idx in range(self.num_experts)
         | 
| 73 | 
            +
                    ])
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                def forward(
         | 
| 76 | 
            +
                    self,
         | 
| 77 | 
            +
                    hidden_states: torch.Tensor,
         | 
| 78 | 
            +
                    routing_weights: Optional[torch.Tensor] = None,
         | 
| 79 | 
            +
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 80 | 
            +
                    position_ids: Optional[torch.LongTensor] = None,
         | 
| 81 | 
            +
                    past_key_value: Optional[Cache] = None,
         | 
| 82 | 
            +
                    output_attentions: Optional[bool] = False,
         | 
| 83 | 
            +
                    use_cache: Optional[bool] = False,
         | 
| 84 | 
            +
                    cache_position: Optional[torch.LongTensor] = None,
         | 
| 85 | 
            +
                    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
         | 
| 86 | 
            +
                    **kwargs,
         | 
| 87 | 
            +
                ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         | 
| 88 | 
            +
                
         | 
| 89 | 
            +
                    batch_size, sequence_length, hidden_dim = hidden_states.shape
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    if self.training and self.jitter_noise > 0:
         | 
| 92 | 
            +
                        hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
         | 
| 93 | 
            +
                    
         | 
| 94 | 
            +
                    if self.use_router:
         | 
| 95 | 
            +
                        router_logits = self.gate(hidden_states)
         | 
| 96 | 
            +
                        if "logic" in self.ablate:
         | 
| 97 | 
            +
                            router_logits[..., 0] = -torch.inf
         | 
| 98 | 
            +
                        if "social" in self.ablate:
         | 
| 99 | 
            +
                            router_logits[..., 1] = -torch.inf
         | 
| 100 | 
            +
                        if "world" in self.ablate:
         | 
| 101 | 
            +
                            router_logits[..., 2] = -torch.inf
         | 
| 102 | 
            +
                        if "language" in self.ablate:
         | 
| 103 | 
            +
                            router_logits[..., 3] = -torch.inf
         | 
| 104 | 
            +
                        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
         | 
| 105 | 
            +
                    else:
         | 
| 106 | 
            +
                        if len(routing_weights.shape) == 2:
         | 
| 107 | 
            +
                            routing_weights = routing_weights.unsqueeze(1).tile((1,sequence_length,1)).float()
         | 
| 108 | 
            +
                        else:
         | 
| 109 | 
            +
                            routing_weights = routing_weights.float()
         | 
| 110 | 
            +
                        router_logits = routing_weights
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
         | 
| 113 | 
            +
                    routing_weights /= (routing_weights.sum(dim=-1, keepdim=True) + 1e-9)
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    # we cast back to the input dtype
         | 
| 116 | 
            +
                    routing_weights = routing_weights.to(hidden_states.dtype)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    # We'll accumulate outputs here
         | 
| 119 | 
            +
                    final_hidden_states = torch.zeros_like(hidden_states)
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                    # Flatten final_hidden_states to [batch_size * seq_len, hidden_dim]
         | 
| 122 | 
            +
                    # so we can do a 2D "index_add_" at the end of each loop.
         | 
| 123 | 
            +
                    final_hidden_states_2d = final_hidden_states.view(-1, hidden_dim)
         | 
| 124 | 
            +
                
         | 
| 125 | 
            +
                    # One hot encode the selected experts to create an expert mask
         | 
| 126 | 
            +
                    # this will be used to easily index which expert is going to be sollicitated
         | 
| 127 | 
            +
                    expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
         | 
| 128 | 
            +
                    #^ [batch_size, seq_len, top_k, num_experts]
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    # Loop over all available experts in the model and perform the computation on each expert
         | 
| 131 | 
            +
                    for expert_idx in range(self.num_experts):
         | 
| 132 | 
            +
                        expert_layer: Olmo2DecoderLayer = self.experts[expert_idx]
         | 
| 133 | 
            +
                        batch_indices, seq_indices, top_k_indices = torch.where(expert_mask[..., expert_idx])
         | 
| 134 | 
            +
                    
         | 
| 135 | 
            +
                        if not self.training and sequence_length == 1 and batch_indices.numel() == 0:
         | 
| 136 | 
            +
                            if past_key_value is not None:
         | 
| 137 | 
            +
                                
         | 
| 138 | 
            +
                                input_shape = hidden_states.shape[:-1]
         | 
| 139 | 
            +
                                hidden_shape = (*input_shape, -1, self.head_dim)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                                key_states = expert_layer.self_attn.k_proj(hidden_states)
         | 
| 142 | 
            +
                                key_states = expert_layer.self_attn.k_norm(key_states).view(hidden_shape).transpose(1, 2)
         | 
| 143 | 
            +
                                value_states = expert_layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
             | 
| 146 | 
            +
                                cos, sin = position_embeddings
         | 
| 147 | 
            +
                                _, key_states = apply_rotary_pos_emb(key_states, key_states, cos, sin)
         | 
| 148 | 
            +
                                # sin and cos are specific to RoPE models; cache_position needed for the static cache
         | 
| 149 | 
            +
                                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
         | 
| 150 | 
            +
                                past_key_value.update(key_states, value_states, self.layer_idx * self.num_experts + expert_idx, cache_kwargs)
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                            continue
         | 
| 153 | 
            +
                    
         | 
| 154 | 
            +
                        current_hidden_states = expert_layer(
         | 
| 155 | 
            +
                            hidden_states=hidden_states,
         | 
| 156 | 
            +
                            attention_mask=attention_mask,
         | 
| 157 | 
            +
                            position_ids=position_ids,
         | 
| 158 | 
            +
                            past_key_value=past_key_value,
         | 
| 159 | 
            +
                            output_attentions=output_attentions,
         | 
| 160 | 
            +
                            use_cache=use_cache,
         | 
| 161 | 
            +
                            cache_position=cache_position,
         | 
| 162 | 
            +
                            position_embeddings=position_embeddings,
         | 
| 163 | 
            +
                            **kwargs,
         | 
| 164 | 
            +
                        )[0]
         | 
| 165 | 
            +
                        
         | 
| 166 | 
            +
                        flat_idx = batch_indices * sequence_length + seq_indices
         | 
| 167 | 
            +
                        expert_weights = routing_weights[batch_indices, seq_indices, top_k_indices].unsqueeze(-1)
         | 
| 168 | 
            +
                        current_hidden_states = current_hidden_states[batch_indices, seq_indices] * expert_weights
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                        final_hidden_states_2d.index_add_(0, flat_idx, current_hidden_states.to(hidden_states.dtype))
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    final_hidden_states = final_hidden_states_2d.view(batch_size, sequence_length, hidden_dim)
         | 
| 173 | 
            +
                    return final_hidden_states, router_logits
         | 
| 174 | 
            +
                
         | 
| 175 | 
            +
            class MiCRoOLMo(Olmo2PreTrainedModel, GenerationMixin):
         | 
| 176 | 
            +
                """
         | 
| 177 | 
            +
                Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Olmo2DecoderLayer`]
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                Args:
         | 
| 180 | 
            +
                    config: Olmo2Config
         | 
| 181 | 
            +
                """
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                _tied_weights_keys = ["lm_head.weight"]
         | 
| 184 | 
            +
                _tp_plan = {"lm_head": "colwise_rep"}
         | 
| 185 | 
            +
                _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                def __init__(self, config: Olmo2Config):
         | 
| 188 | 
            +
                    with open(config.config_path, 'r', encoding="utf-8") as file:
         | 
| 189 | 
            +
                        run_config = yaml.load(file.read(), Loader=yaml.FullLoader)
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    self.config: Olmo2Config = config
         | 
| 192 | 
            +
                    self.config.torch_dtype = torch.bfloat16
         | 
| 193 | 
            +
                    self.config.use_bfloat16 = True
         | 
| 194 | 
            +
                    self.config._attn_implementation = "flash_attention_2" # {sdpa, flash_attention_2, eager}
         | 
| 195 | 
            +
                    self.config.use_cache = True
         | 
| 196 | 
            +
                    self.config.backbone_num_layers = self.config.num_hidden_layers
         | 
| 197 | 
            +
                    self.config.num_hidden_layers = self.config.num_hidden_layers * run_config["num-experts"]
         | 
| 198 | 
            +
                    self.config.loss_type = "ForCausalLMLoss"
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                    self.padding_idx = config.pad_token_id
         | 
| 201 | 
            +
                    self.vocab_size = config.vocab_size
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                    self.gradient_checkpointing = False
         | 
| 204 | 
            +
                    super().__init__(config)
         | 
| 205 | 
            +
                    self.padding_idx = config.pad_token_id
         | 
| 206 | 
            +
                    self.vocab_size = config.vocab_size
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                    self.build_model(run_config)
         | 
| 209 | 
            +
                
         | 
| 210 | 
            +
                    # Initialize weights and apply final processing
         | 
| 211 | 
            +
                    self.post_init()
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                def get_input_embeddings(self):
         | 
| 214 | 
            +
                    return self.embed_tokens
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                def set_input_embeddings(self, value):
         | 
| 217 | 
            +
                    self.embed_tokens = value
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                def get_output_embeddings(self):
         | 
| 220 | 
            +
                    return self.lm_head
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                def set_output_embeddings(self, value):
         | 
| 223 | 
            +
                    self.lm_head = value
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                def build_model(self, run_config):
         | 
| 226 | 
            +
                    self.gradient_checkpointing = False
         | 
| 227 | 
            +
                    self.config.num_experts = run_config["num-experts"]
         | 
| 228 | 
            +
                    self.config.use_router = run_config["use-router"]
         | 
| 229 | 
            +
                    self.config.num_experts_per_tok = run_config["top-k-experts"]
         | 
| 230 | 
            +
                    self.config.jitter_noise = run_config["jitter-noise"]
         | 
| 231 | 
            +
                    self.config.loss_method = run_config.get("loss", "all")
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    self.run_config = run_config        
         | 
| 234 | 
            +
                    # Qwen2 model
         | 
| 235 | 
            +
                    self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size, self.padding_idx)
         | 
| 236 | 
            +
                    self.layers = nn.ModuleList([MiCRoOLMo2DecoderLayer(self.config, layer_idx) for layer_idx in range(self.config.backbone_num_layers)])
         | 
| 237 | 
            +
                    self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
         | 
| 238 | 
            +
                    self.rotary_emb = Olmo2RotaryEmbedding(config=self.config)
         | 
| 239 | 
            +
                    self.norm = Olmo2RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                    # Freeze Model
         | 
| 242 | 
            +
                    for param in self.parameters():
         | 
| 243 | 
            +
                        param.requires_grad = False
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                    # Unfreeze Modules
         | 
| 246 | 
            +
                    if "reasoners" in run_config["trainable"]:
         | 
| 247 | 
            +
                        print(">> Unfreezing Reasoning Modules")
         | 
| 248 | 
            +
                        for layer in self.layers:
         | 
| 249 | 
            +
                            layer: MiCRoOLMo2DecoderLayer
         | 
| 250 | 
            +
                            for param in layer.experts.parameters():
         | 
| 251 | 
            +
                                param.requires_grad = True
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                    if "model" in run_config["trainable"]:
         | 
| 254 | 
            +
                        print(">> Unfreezing Model")
         | 
| 255 | 
            +
                        for param in self.layers.parameters():
         | 
| 256 | 
            +
                            param.requires_grad = True
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                        for param in self.lm_head.parameters():
         | 
| 259 | 
            +
                            param.requires_grad = True
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                        for param in self.rotary_emb.parameters():
         | 
| 262 | 
            +
                            param.requires_grad = True
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                        for param in self.norm.parameters():
         | 
| 265 | 
            +
                            param.requires_grad = True
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                        for param in self.embed_tokens.parameters():
         | 
| 268 | 
            +
                            param.requires_grad = True
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                        for layer in self.layers:
         | 
| 271 | 
            +
                            for param in layer.gate.parameters():
         | 
| 272 | 
            +
                                param.requires_grad = False
         | 
| 273 | 
            +
             | 
| 274 | 
            +
             | 
| 275 | 
            +
                    if "experts-router" in run_config["trainable"]:
         | 
| 276 | 
            +
                        print(">> Unfreezing Experts Router")
         | 
| 277 | 
            +
                        for layer in self.layers:
         | 
| 278 | 
            +
                            for param in layer.gate.parameters():
         | 
| 279 | 
            +
                                param.requires_grad = True
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                def forward(
         | 
| 282 | 
            +
                    self,
         | 
| 283 | 
            +
                    input_ids: Optional[torch.LongTensor] = None,
         | 
| 284 | 
            +
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 285 | 
            +
                    position_ids: Optional[torch.LongTensor] = None,
         | 
| 286 | 
            +
                    routing_weights: Optional[torch.LongTensor] = None,
         | 
| 287 | 
            +
                    past_key_values: Optional[Cache] = None,
         | 
| 288 | 
            +
                    inputs_embeds: Optional[torch.FloatTensor] = None,
         | 
| 289 | 
            +
                    labels: Optional[torch.LongTensor] = None,
         | 
| 290 | 
            +
                    use_cache: Optional[bool] = None,
         | 
| 291 | 
            +
                    output_attentions: Optional[bool] = None,
         | 
| 292 | 
            +
                    output_hidden_states: Optional[bool] = None,
         | 
| 293 | 
            +
                    return_dict: Optional[bool] = None,
         | 
| 294 | 
            +
                    cache_position: Optional[torch.LongTensor] = None,
         | 
| 295 | 
            +
                    logits_to_keep: Union[int, torch.Tensor] = 0,
         | 
| 296 | 
            +
                    **kwargs: Unpack[FlashAttentionKwargs],
         | 
| 297 | 
            +
                ) -> BaseModelOutputWithPast:
         | 
| 298 | 
            +
                    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         | 
| 299 | 
            +
                    output_hidden_states = (
         | 
| 300 | 
            +
                        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         | 
| 301 | 
            +
                    )
         | 
| 302 | 
            +
                    use_cache = use_cache if use_cache is not None else self.config.use_cache
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                    if (input_ids is None) ^ (inputs_embeds is not None):
         | 
| 305 | 
            +
                        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                    if self.gradient_checkpointing and self.training and use_cache:
         | 
| 308 | 
            +
                        logger.warning_once(
         | 
| 309 | 
            +
                            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
         | 
| 310 | 
            +
                        )
         | 
| 311 | 
            +
                        use_cache = False
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                    # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
         | 
| 314 | 
            +
                    if not isinstance(past_key_values, (type(None), Cache)):
         | 
| 315 | 
            +
                        raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                    if inputs_embeds is None:
         | 
| 318 | 
            +
                        inputs_embeds = self.embed_tokens(input_ids)
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                    if use_cache and past_key_values is None:
         | 
| 321 | 
            +
                        past_key_values = DynamicCache()
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                    if cache_position is None:
         | 
| 324 | 
            +
                        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         | 
| 325 | 
            +
                        cache_position = torch.arange(
         | 
| 326 | 
            +
                            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
         | 
| 327 | 
            +
                        )
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                    if position_ids is None:
         | 
| 330 | 
            +
                        position_ids = cache_position.unsqueeze(0)
         | 
| 331 | 
            +
             | 
| 332 | 
            +
                    causal_mask = self._update_causal_mask(
         | 
| 333 | 
            +
                        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         | 
| 334 | 
            +
                    )
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                    hidden_states = inputs_embeds
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                    # create position embeddings to be shared across the decoder layers
         | 
| 339 | 
            +
                    position_embeddings = self.rotary_emb(hidden_states, position_ids)
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                    # decoder layers
         | 
| 342 | 
            +
                    all_hidden_states = () if output_hidden_states else None
         | 
| 343 | 
            +
                    all_self_attns = () if output_attentions else None
         | 
| 344 | 
            +
                    all_routing_weights = ()
         | 
| 345 | 
            +
             | 
| 346 | 
            +
                    for decoder_layer in self.layers:
         | 
| 347 | 
            +
                        if output_hidden_states:
         | 
| 348 | 
            +
                            all_hidden_states += (hidden_states,)
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                        layer_outputs, router_logits = decoder_layer(
         | 
| 351 | 
            +
                            hidden_states,
         | 
| 352 | 
            +
                            routing_weights=routing_weights,
         | 
| 353 | 
            +
                            attention_mask=causal_mask,
         | 
| 354 | 
            +
                            position_ids=position_ids,
         | 
| 355 | 
            +
                            past_key_value=past_key_values,
         | 
| 356 | 
            +
                            output_attentions=output_attentions,
         | 
| 357 | 
            +
                            use_cache=use_cache,
         | 
| 358 | 
            +
                            cache_position=cache_position,
         | 
| 359 | 
            +
                            position_embeddings=position_embeddings,
         | 
| 360 | 
            +
                            **kwargs,
         | 
| 361 | 
            +
                            # **flash_attn_kwargs,
         | 
| 362 | 
            +
                        )
         | 
| 363 | 
            +
             | 
| 364 | 
            +
                        hidden_states = layer_outputs
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                        # if output_attentions:
         | 
| 367 | 
            +
                        #     all_self_attns += (layer_outputs[1],)
         | 
| 368 | 
            +
                            
         | 
| 369 | 
            +
                        all_routing_weights += (router_logits,)
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                    hidden_states = self.norm(hidden_states)
         | 
| 372 | 
            +
             | 
| 373 | 
            +
                    # add hidden states from the last decoder layer
         | 
| 374 | 
            +
                    if output_hidden_states:
         | 
| 375 | 
            +
                        all_hidden_states += (hidden_states,)
         | 
| 376 | 
            +
                
         | 
| 377 | 
            +
                    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         | 
| 378 | 
            +
                    logits = self.lm_head(hidden_states[:, slice_indices, :])
         | 
| 379 | 
            +
             | 
| 380 | 
            +
                    loss = None
         | 
| 381 | 
            +
                    if labels is not None:
         | 
| 382 | 
            +
                        loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                    return CausalLMOutputWithPast(
         | 
| 385 | 
            +
                        loss=loss,
         | 
| 386 | 
            +
                        logits=logits,
         | 
| 387 | 
            +
                        past_key_values=past_key_values if use_cache else None,
         | 
| 388 | 
            +
                        hidden_states=all_hidden_states,
         | 
| 389 | 
            +
                        attentions=all_self_attns,
         | 
| 390 | 
            +
                        routing_weights=all_routing_weights,
         | 
| 391 | 
            +
                    )
         | 
| 392 | 
            +
             | 
| 393 | 
            +
                def load_pretrained(self, model_name):
         | 
| 394 | 
            +
                    base_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
         | 
| 395 | 
            +
                    self.lm_head.load_state_dict(base_model.lm_head.state_dict())
         | 
| 396 | 
            +
                    self.embed_tokens.load_state_dict(base_model.get_input_embeddings().state_dict())
         | 
| 397 | 
            +
                    self.rotary_emb.load_state_dict(base_model.model.rotary_emb.state_dict())
         | 
| 398 | 
            +
                    self.norm.load_state_dict(base_model.model.norm.state_dict())
         | 
| 399 | 
            +
                    for layer_idx, layer in enumerate(self.layers):
         | 
| 400 | 
            +
                        base_model_layer = base_model.model.layers[layer_idx].state_dict()
         | 
| 401 | 
            +
                        for expert in layer.experts:
         | 
| 402 | 
            +
                            expert.load_state_dict(base_model_layer)
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                def _update_causal_mask(
         | 
| 405 | 
            +
                    self,
         | 
| 406 | 
            +
                    attention_mask: Union[torch.Tensor, "BlockMask"],
         | 
| 407 | 
            +
                    input_tensor: torch.Tensor,
         | 
| 408 | 
            +
                    cache_position: torch.Tensor,
         | 
| 409 | 
            +
                    past_key_values: Cache,
         | 
| 410 | 
            +
                    output_attentions: bool = False,
         | 
| 411 | 
            +
                ):
         | 
| 412 | 
            +
                    if self.config._attn_implementation == "flash_attention_2":
         | 
| 413 | 
            +
                        if attention_mask is not None and (attention_mask == 0.0).any():
         | 
| 414 | 
            +
                            return attention_mask
         | 
| 415 | 
            +
                        return None
         | 
| 416 | 
            +
                    if self.config._attn_implementation == "flex_attention":
         | 
| 417 | 
            +
                        if isinstance(attention_mask, torch.Tensor):
         | 
| 418 | 
            +
                            attention_mask = make_flex_block_causal_mask(attention_mask)
         | 
| 419 | 
            +
                        return attention_mask
         | 
| 420 | 
            +
             | 
| 421 | 
            +
                    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         | 
| 422 | 
            +
                    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
         | 
| 423 | 
            +
                    # to infer the attention mask.
         | 
| 424 | 
            +
                    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         | 
| 425 | 
            +
                    using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
         | 
| 426 | 
            +
             | 
| 427 | 
            +
                    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
         | 
| 428 | 
            +
                    if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
         | 
| 429 | 
            +
                        if AttentionMaskConverter._ignore_causal_mask_sdpa(
         | 
| 430 | 
            +
                            attention_mask,
         | 
| 431 | 
            +
                            inputs_embeds=input_tensor,
         | 
| 432 | 
            +
                            past_key_values_length=past_seen_tokens,
         | 
| 433 | 
            +
                            is_training=self.training,
         | 
| 434 | 
            +
                        ):
         | 
| 435 | 
            +
                            return None
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                    dtype = input_tensor.dtype
         | 
| 438 | 
            +
                    sequence_length = input_tensor.shape[1]
         | 
| 439 | 
            +
                    if using_compilable_cache:
         | 
| 440 | 
            +
                        target_length = past_key_values.get_max_cache_shape()
         | 
| 441 | 
            +
                    else:
         | 
| 442 | 
            +
                        target_length = (
         | 
| 443 | 
            +
                            attention_mask.shape[-1]
         | 
| 444 | 
            +
                            if isinstance(attention_mask, torch.Tensor)
         | 
| 445 | 
            +
                            else past_seen_tokens + sequence_length + 1
         | 
| 446 | 
            +
                        )
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
         | 
| 449 | 
            +
                    causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
         | 
| 450 | 
            +
                        attention_mask,
         | 
| 451 | 
            +
                        sequence_length=sequence_length,
         | 
| 452 | 
            +
                        target_length=target_length,
         | 
| 453 | 
            +
                        dtype=dtype,
         | 
| 454 | 
            +
                        cache_position=cache_position,
         | 
| 455 | 
            +
                        batch_size=input_tensor.shape[0],
         | 
| 456 | 
            +
                    )
         | 
| 457 | 
            +
             | 
| 458 | 
            +
                    if (
         | 
| 459 | 
            +
                        self.config._attn_implementation == "sdpa"
         | 
| 460 | 
            +
                        and attention_mask is not None
         | 
| 461 | 
            +
                        and attention_mask.device.type in ["cuda", "xpu", "npu"]
         | 
| 462 | 
            +
                        and not output_attentions
         | 
| 463 | 
            +
                    ):
         | 
| 464 | 
            +
                        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
         | 
| 465 | 
            +
                        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
         | 
| 466 | 
            +
                        # Details: https://github.com/pytorch/pytorch/issues/110213
         | 
| 467 | 
            +
                        min_dtype = torch.finfo(dtype).min
         | 
| 468 | 
            +
                        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
         | 
| 469 | 
            +
             | 
| 470 | 
            +
                    return causal_mask
         | 
| 471 | 
            +
             | 
| 472 | 
            +
                @staticmethod
         | 
| 473 | 
            +
                def _prepare_4d_causal_attention_mask_with_cache_position(
         | 
| 474 | 
            +
                    attention_mask: torch.Tensor,
         | 
| 475 | 
            +
                    sequence_length: int,
         | 
| 476 | 
            +
                    target_length: int,
         | 
| 477 | 
            +
                    dtype: torch.dtype,
         | 
| 478 | 
            +
                    cache_position: torch.Tensor,
         | 
| 479 | 
            +
                    batch_size: int,
         | 
| 480 | 
            +
                    **kwargs,
         | 
| 481 | 
            +
                ):
         | 
| 482 | 
            +
                    """
         | 
| 483 | 
            +
                    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
         | 
| 484 | 
            +
                    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
         | 
| 485 | 
            +
             | 
| 486 | 
            +
                    Args:
         | 
| 487 | 
            +
                        attention_mask (`torch.Tensor`):
         | 
| 488 | 
            +
                            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
         | 
| 489 | 
            +
                            `(batch_size, 1, query_length, key_value_length)`.
         | 
| 490 | 
            +
                        sequence_length (`int`):
         | 
| 491 | 
            +
                            The sequence length being processed.
         | 
| 492 | 
            +
                        target_length (`int`):
         | 
| 493 | 
            +
                            The target length: when generating with static cache, the mask should be as long as the static cache,
         | 
| 494 | 
            +
                            to account for the 0 padding, the part of the cache that is not filled yet.
         | 
| 495 | 
            +
                        dtype (`torch.dtype`):
         | 
| 496 | 
            +
                            The dtype to use for the 4D attention mask.
         | 
| 497 | 
            +
                        cache_position (`torch.Tensor`):
         | 
| 498 | 
            +
                            Indices depicting the position of the input sequence tokens in the sequence.
         | 
| 499 | 
            +
                        batch_size (`torch.Tensor`):
         | 
| 500 | 
            +
                            Batch size.
         | 
| 501 | 
            +
                    """
         | 
| 502 | 
            +
                    if attention_mask is not None and attention_mask.dim() == 4:
         | 
| 503 | 
            +
                        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
         | 
| 504 | 
            +
                        causal_mask = attention_mask
         | 
| 505 | 
            +
                    else:
         | 
| 506 | 
            +
                        min_dtype = torch.finfo(dtype).min
         | 
| 507 | 
            +
                        causal_mask = torch.full(
         | 
| 508 | 
            +
                            (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
         | 
| 509 | 
            +
                        )
         | 
| 510 | 
            +
                        if sequence_length != 1:
         | 
| 511 | 
            +
                            causal_mask = torch.triu(causal_mask, diagonal=1)
         | 
| 512 | 
            +
                        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
         | 
| 513 | 
            +
                        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         | 
| 514 | 
            +
                        if attention_mask is not None:
         | 
| 515 | 
            +
                            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
         | 
| 516 | 
            +
                            mask_length = attention_mask.shape[-1]
         | 
| 517 | 
            +
                            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
         | 
| 518 | 
            +
                                causal_mask.device
         | 
| 519 | 
            +
                            )
         | 
| 520 | 
            +
                            padding_mask = padding_mask == 0
         | 
| 521 | 
            +
                            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
         | 
| 522 | 
            +
                                padding_mask, min_dtype
         | 
| 523 | 
            +
                            )
         | 
| 524 | 
            +
             | 
| 525 | 
            +
                    return causal_mask
         | 
| 526 | 
            +
             | 
| 527 | 
            +
             | 
| 528 | 
            +
            __all__ = ["MiCRoOLMo"]
         | 
    	
        models/modules.py
    ADDED
    
    | @@ -0,0 +1,42 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from dataclasses import dataclass
         | 
| 2 | 
            +
            from typing import Optional, Tuple, List, Union
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            from transformers.modeling_outputs import ModelOutput
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            @dataclass
         | 
| 9 | 
            +
            class CausalLMOutputWithPast(ModelOutput):
         | 
| 10 | 
            +
                """
         | 
| 11 | 
            +
                Base class for causal language model (or autoregressive) outputs.
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                Args:
         | 
| 14 | 
            +
                    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
         | 
| 15 | 
            +
                        Language modeling loss (for next-token prediction).
         | 
| 16 | 
            +
                    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
         | 
| 17 | 
            +
                        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
         | 
| 18 | 
            +
                    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
         | 
| 19 | 
            +
                        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
         | 
| 20 | 
            +
                        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
         | 
| 23 | 
            +
                        `past_key_values` input) to speed up sequential decoding.
         | 
| 24 | 
            +
                    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
         | 
| 25 | 
            +
                        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
         | 
| 26 | 
            +
                        one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
         | 
| 29 | 
            +
                    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
         | 
| 30 | 
            +
                        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
         | 
| 31 | 
            +
                        sequence_length)`.
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                        Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
         | 
| 34 | 
            +
                        heads.
         | 
| 35 | 
            +
                """
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                loss: Optional[torch.FloatTensor] = None
         | 
| 38 | 
            +
                logits: torch.FloatTensor = None
         | 
| 39 | 
            +
                past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
         | 
| 40 | 
            +
                hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
         | 
| 41 | 
            +
                attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
         | 
| 42 | 
            +
                routing_weights: Optional[Tuple[torch.FloatTensor, ...]] = None
         | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            gradio>=4.44.0
         | 
| 2 | 
            +
            plotly>=5.22.0
         | 
| 3 | 
            +
            pandas>=2.2.0
         | 
    	
        router_backend.py
    ADDED
    
    | @@ -0,0 +1,223 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # router_backend.py
         | 
| 2 | 
            +
            """
         | 
| 3 | 
            +
            Plug your real model routing function here.
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            Implement the function:
         | 
| 6 | 
            +
                get_expert_routing(model_id: str, prompt: str) -> list[float] | dict[str, float] | tuple[float, float, float, float]
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            It must return 4 values (percentages) corresponding to the experts:
         | 
| 9 | 
            +
            ["Language", "Logic", "Social", "World"]
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            Example return formats:
         | 
| 12 | 
            +
            - [12.5, 45.0, 22.5, 20.0]
         | 
| 13 | 
            +
            - {"Language": 12.5, "Logic": 45.0, "Social": 22.5, "World": 20.0}
         | 
| 14 | 
            +
            - (12.5, 45.0, 22.5, 20.0)
         | 
| 15 | 
            +
            """
         | 
| 16 | 
            +
            import torch
         | 
| 17 | 
            +
            import numpy as np
         | 
| 18 | 
            +
            import torch.nn.functional as F
         | 
| 19 | 
            +
            from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
         | 
| 20 | 
            +
            from typing import Union, Dict, List, Tuple
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            from models.micro_olmo import MiCRoOLMo
         | 
| 23 | 
            +
            from models.micro_llama import MiCRoLlama
         | 
| 24 | 
            +
            from models.micro_moe_llama import MiCRoLlamaMoE
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            def get_expert_routing(model_id: str, hf_token: str, prompt: Union[str, List[Dict[str, str]]]) -> Union[List[float], Dict[str, float], Tuple[float, float, float, float]]:
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                model, tokenizer = build_model(model_id, hf_token)
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                if isinstance(prompt, str):
         | 
| 33 | 
            +
                    generation, routing_weights = generate_continuation(model, tokenizer, prompt)
         | 
| 34 | 
            +
                elif isinstance(prompt, dict):
         | 
| 35 | 
            +
                    generation = None
         | 
| 36 | 
            +
                    routing_weights = get_routing_weights(model, tokenizer, [prompt])
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                model_routing_percentages = aggregate_routing_weights(routing_weights)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                if generation is not None:
         | 
| 41 | 
            +
                    print(f"Generation:\n{generation}")
         | 
| 42 | 
            +
                
         | 
| 43 | 
            +
                return {
         | 
| 44 | 
            +
                    "Language": float(model_routing_percentages[3]),
         | 
| 45 | 
            +
                    "Logic": float(model_routing_percentages[0]),
         | 
| 46 | 
            +
                    "Social": float(model_routing_percentages[1]),
         | 
| 47 | 
            +
                    "World": float(model_routing_percentages[2]),
         | 
| 48 | 
            +
                }, generation
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            def get_model_path(model_name: str) -> Tuple[str, str, AutoModelForCausalLM]:
         | 
| 51 | 
            +
                return {
         | 
| 52 | 
            +
                    # MiCRo-Llama
         | 
| 53 | 
            +
                    "micro-llama-1b": ("bkhmsi/micro-llama-1b", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlama),
         | 
| 54 | 
            +
                    "micro-llama-3b": ("bkhmsi/micro-llama-3b", "meta-llama/Llama-3.2-3B-Instruct", MiCRoLlama),
         | 
| 55 | 
            +
                    "micro-llama-1b-dpo": ("bkhmsi/micro-llama-1b-dpo", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlama),
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                    # MiCRo-MoE-Llama
         | 
| 58 | 
            +
                    "micro-moe-llama-1b": ("bkhmsi/micro-moe-llama-1b", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlamaMoE),
         | 
| 59 | 
            +
                    
         | 
| 60 | 
            +
                    # MiCRo-OLMo
         | 
| 61 | 
            +
                    "micro-olmo": ("bkhmsi/micro-olmo-1b", "allenai/OLMo-2-0425-1B-Instruct", MiCRoOLMo),
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    # MiCRo-SmolLM2
         | 
| 64 | 
            +
                    "micro-smollm2-135m": ("bkhmsi/micro-smollm2-135m", "HuggingFaceTB/SmolLM2-135M-Instruct", MiCRoLlama),
         | 
| 65 | 
            +
                    "micro-smollm2-360m": ("bkhmsi/micro-smollm2-360m", "HuggingFaceTB/SmolLM2-360M-Instruct", MiCRoLlama),
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                    # MiCRo-MoE-SmolLM2
         | 
| 68 | 
            +
                    "micro-moe-smollm2-135m": ("bkhmsi/micro-moe-smollm2-135m", "HuggingFaceTB/SmolLM2-135M-Instruct", MiCRoLlamaMoE),
         | 
| 69 | 
            +
                    "micro-moe-smollm2-360m": ("bkhmsi/micro-moe-smollm2-360m", "HuggingFaceTB/SmolLM2-360M-Instruct", MiCRoLlamaMoE),
         | 
| 70 | 
            +
                }.get(model_name, (model_name, model_name, AutoModelForCausalLM))
         | 
| 71 | 
            +
             | 
| 72 | 
            +
            def aggregate_routing_weights(routing_weights):
         | 
| 73 | 
            +
                experts = ["Logic", "Social", "World", "Language"]
         | 
| 74 | 
            +
                expert_token_model = np.zeros((len(experts)), dtype=int)
         | 
| 75 | 
            +
                expert_layer_token = np.zeros((routing_weights.shape[0], len(experts)), dtype=int)
         | 
| 76 | 
            +
                num_layers = routing_weights.shape[0]
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                for layer_idx in range(num_layers):
         | 
| 79 | 
            +
                    for token_idx in range(len(routing_weights[layer_idx])):
         | 
| 80 | 
            +
                        expert_idx = routing_weights[layer_idx][token_idx].argmax()
         | 
| 81 | 
            +
                        if layer_idx >= 2 and layer_idx < num_layers - 2:
         | 
| 82 | 
            +
                            expert_token_model[expert_idx] += 1
         | 
| 83 | 
            +
                        expert_layer_token[layer_idx][expert_idx] += 1
         | 
| 84 | 
            +
                return expert_token_model, expert_layer_token
         | 
| 85 | 
            +
             | 
| 86 | 
            +
            def generate_continuation(model, 
         | 
| 87 | 
            +
                tokenizer, 
         | 
| 88 | 
            +
                prompts, 
         | 
| 89 | 
            +
                max_tokens=1024,
         | 
| 90 | 
            +
                use_cache=True, 
         | 
| 91 | 
            +
                return_routing_weights=True
         | 
| 92 | 
            +
            ):
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                if isinstance(prompts, str):
         | 
| 95 | 
            +
                    prompts = [{"role": "user", "content": prompts}]
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                tokenizer.padding_side = "left"
         | 
| 98 | 
            +
                inputs = tokenizer.apply_chat_template([
         | 
| 99 | 
            +
                    prompt for prompt in prompts
         | 
| 100 | 
            +
                ], return_tensors="pt", padding=True, add_generation_prompt=True).to(DEVICE)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                attention_mask = torch.ones_like(inputs)
         | 
| 103 | 
            +
                attention_mask[inputs == tokenizer.pad_token_id] = 0
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                outputs = model.generate(
         | 
| 106 | 
            +
                    input_ids=inputs,
         | 
| 107 | 
            +
                    attention_mask=attention_mask, 
         | 
| 108 | 
            +
                    max_new_tokens=max_tokens,
         | 
| 109 | 
            +
                    use_cache=use_cache,
         | 
| 110 | 
            +
                    stop_strings=["</s>","<|eot_id|>", "<|im_start|>user"],
         | 
| 111 | 
            +
                    tokenizer=tokenizer,
         | 
| 112 | 
            +
                    pad_token_id=tokenizer.pad_token_id,
         | 
| 113 | 
            +
                    temperature=0,
         | 
| 114 | 
            +
                    top_p=1.0,
         | 
| 115 | 
            +
                    do_sample=False,
         | 
| 116 | 
            +
                )
         | 
| 117 | 
            +
                
         | 
| 118 | 
            +
                if return_routing_weights:
         | 
| 119 | 
            +
                    attention_mask = torch.ones_like(outputs)
         | 
| 120 | 
            +
                    attention_mask[outputs == tokenizer.pad_token_id] = 0
         | 
| 121 | 
            +
                    model_output = model(input_ids=outputs, attention_mask=attention_mask)
         | 
| 122 | 
            +
                    torch.cuda.empty_cache()
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                    routing_weights = model_output.routing_weights        
         | 
| 125 | 
            +
                    routing_weights = np.concatenate([
         | 
| 126 | 
            +
                        F.softmax(rw, dim=-1)[:, inputs.shape[1]:].detach().float().cpu().numpy() 
         | 
| 127 | 
            +
                        for rw in routing_weights
         | 
| 128 | 
            +
                    ])
         | 
| 129 | 
            +
                    
         | 
| 130 | 
            +
                else:
         | 
| 131 | 
            +
                    routing_weights = None
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                inputs_text = tokenizer.batch_decode(inputs, skip_special_tokens=False)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                generations = []
         | 
| 136 | 
            +
                for i, output in enumerate(outputs):
         | 
| 137 | 
            +
                    decoded_output = tokenizer.decode(output, skip_special_tokens=False)
         | 
| 138 | 
            +
                    decoded_output = decoded_output.replace(inputs_text[i], "")
         | 
| 139 | 
            +
                    decoded_output = decoded_output.replace(tokenizer.pad_token, "").strip()
         | 
| 140 | 
            +
                    decoded_output = decoded_output.replace("<|end_of_text|>", "").strip()
         | 
| 141 | 
            +
                    decoded_output = decoded_output.replace("<|endoftext|>", "").strip()
         | 
| 142 | 
            +
                    decoded_output = decoded_output.replace("<|eot_id|>", "").strip()
         | 
| 143 | 
            +
                    decoded_output = decoded_output.replace("\n<|im_start|>user", "").strip()
         | 
| 144 | 
            +
                    generations.append(decoded_output)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                return (generations, routing_weights) if return_routing_weights else generations
         | 
| 147 | 
            +
             | 
| 148 | 
            +
            def get_routing_weights(model, tokenizer, prompts, apply_chat_template=True):
         | 
| 149 | 
            +
                """
         | 
| 150 | 
            +
                Get routing weights for the given prompts using the model.
         | 
| 151 | 
            +
                Args:
         | 
| 152 | 
            +
                    model: The MiCRoLlama or MiCRoOLMo model.
         | 
| 153 | 
            +
                    tokenizer: The tokenizer for the model.
         | 
| 154 | 
            +
                    prompts: A string or list of dictionaries containing the prompts.
         | 
| 155 | 
            +
                Returns:
         | 
| 156 | 
            +
                    routing_weights: A list of routing weights for each layer.
         | 
| 157 | 
            +
                """
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                tokenizer.padding_side = "left"
         | 
| 160 | 
            +
                if apply_chat_template:
         | 
| 161 | 
            +
                    if isinstance(prompts, str):
         | 
| 162 | 
            +
                        prompts = [{"role": "user", "content": prompts}]
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    inputs = tokenizer.apply_chat_template([
         | 
| 165 | 
            +
                        prompt for prompt in prompts
         | 
| 166 | 
            +
                    ], return_tensors="pt", padding=True).to(DEVICE)
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    input_without_response = tokenizer.apply_chat_template([
         | 
| 169 | 
            +
                            prompt[:-1] for prompt in prompts
         | 
| 170 | 
            +
                        ], return_tensors="pt", padding=True,
         | 
| 171 | 
            +
                    ).to(DEVICE)
         | 
| 172 | 
            +
                else:
         | 
| 173 | 
            +
                    inputs = tokenizer(prompts[0] + prompts[1], return_tensors="pt", padding=True).input_ids.to(DEVICE)
         | 
| 174 | 
            +
                    input_without_response = tokenizer(prompts[0], return_tensors="pt", padding=True).input_ids.to(DEVICE)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                attention_mask = torch.ones_like(inputs)
         | 
| 177 | 
            +
                attention_mask[inputs == tokenizer.pad_token_id] = 0
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                model_output = model(input_ids=inputs, attention_mask=attention_mask)
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                routing_weights = model_output.routing_weights   
         | 
| 182 | 
            +
                routing_weights = np.stack([F.softmax(rw, dim=-1).detach().float().cpu().numpy() for rw in routing_weights], axis=0).squeeze()
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                offset = len(input_without_response[0])-1
         | 
| 185 | 
            +
                routing_weights = routing_weights[:, offset:-1]
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                return routing_weights
         | 
| 188 | 
            +
             | 
| 189 | 
            +
            def build_model(model_id: str, hf_token: str, use_cache: bool = True):
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                model_path, base_model, model_class = get_model_path(model_id)
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                model_config = AutoConfig.from_pretrained(base_model, use_auth_token=hf_token)
         | 
| 194 | 
            +
                model_config.config_path = f"configs/{model_id}.yml"
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                model_config.torch_dtype = torch.bfloat16
         | 
| 197 | 
            +
                model_config.use_bfloat16 = True
         | 
| 198 | 
            +
                model_config._attn_implementation = "flash_attention_2"
         | 
| 199 | 
            +
                model_config.use_cache = use_cache
         | 
| 200 | 
            +
                model_config.ablate = []
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                tokenizer = AutoTokenizer.from_pretrained(base_model, use_auth_token=hf_token)
         | 
| 203 | 
            +
                tokenizer.padding_side = "left"
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                if "llama" in model_id:
         | 
| 206 | 
            +
                    tokenizer.pad_token_id = 128004
         | 
| 207 | 
            +
                if "olmo" in model_id:
         | 
| 208 | 
            +
                    tokenizer.pad_token_id = 100277
         | 
| 209 | 
            +
                    tokenizer.add_special_tokens({'additional_special_tokens': ['<|assistant|>']})
         | 
| 210 | 
            +
                elif "smollm2" in model_id:
         | 
| 211 | 
            +
                    tokenizer.pad_token_id = 2
         | 
| 212 | 
            +
                else:
         | 
| 213 | 
            +
                    tokenizer.pad_token_id = 128004
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                if "olmo" in model_id:
         | 
| 216 | 
            +
                    model_config.vocab_size = len(tokenizer)
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                model = model_class.from_pretrained(model_path, config=model_config, low_cpu_mem_usage=True)
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                model.to(DEVICE)
         | 
| 221 | 
            +
                model = model.bfloat16()
         | 
| 222 | 
            +
                model.eval()
         | 
| 223 | 
            +
                return model, tokenizer
         |