Chengyue Wu
		
	commited on
		
		
					Commit 
							
							·
						
						cd3af22
	
0
								Parent(s):
							
							
Initial commit
Browse files- .gitattributes +38 -0
 - README.md +151 -0
 - added_tokens.json +25 -0
 - assets/benchmark_results.png +3 -0
 - assets/throughput.png +3 -0
 - assets/training_recipe.png +3 -0
 - assets/visualization_animation.gif +3 -0
 - chat_template.jinja +54 -0
 - config.json +65 -0
 - configuration.py +98 -0
 - generation_config.json +14 -0
 - merges.txt +0 -0
 - model.safetensors +3 -0
 - modeling.py +681 -0
 - special_tokens_map.json +25 -0
 - tokenizer.json +3 -0
 - tokenizer_config.json +204 -0
 - vocab.json +0 -0
 
    	
        .gitattributes
    ADDED
    
    | 
         @@ -0,0 +1,38 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            *.7z filter=lfs diff=lfs merge=lfs -text
         
     | 
| 2 | 
         
            +
            *.arrow filter=lfs diff=lfs merge=lfs -text
         
     | 
| 3 | 
         
            +
            *.bin filter=lfs diff=lfs merge=lfs -text
         
     | 
| 4 | 
         
            +
            *.bz2 filter=lfs diff=lfs merge=lfs -text
         
     | 
| 5 | 
         
            +
            *.ckpt filter=lfs diff=lfs merge=lfs -text
         
     | 
| 6 | 
         
            +
            *.ftz filter=lfs diff=lfs merge=lfs -text
         
     | 
| 7 | 
         
            +
            *.gz filter=lfs diff=lfs merge=lfs -text
         
     | 
| 8 | 
         
            +
            *.h5 filter=lfs diff=lfs merge=lfs -text
         
     | 
| 9 | 
         
            +
            *.joblib filter=lfs diff=lfs merge=lfs -text
         
     | 
| 10 | 
         
            +
            *.lfs.* filter=lfs diff=lfs merge=lfs -text
         
     | 
| 11 | 
         
            +
            *.mlmodel filter=lfs diff=lfs merge=lfs -text
         
     | 
| 12 | 
         
            +
            *.model filter=lfs diff=lfs merge=lfs -text
         
     | 
| 13 | 
         
            +
            *.msgpack filter=lfs diff=lfs merge=lfs -text
         
     | 
| 14 | 
         
            +
            *.npy filter=lfs diff=lfs merge=lfs -text
         
     | 
| 15 | 
         
            +
            *.npz filter=lfs diff=lfs merge=lfs -text
         
     | 
| 16 | 
         
            +
            *.onnx filter=lfs diff=lfs merge=lfs -text
         
     | 
| 17 | 
         
            +
            *.ot filter=lfs diff=lfs merge=lfs -text
         
     | 
| 18 | 
         
            +
            *.parquet filter=lfs diff=lfs merge=lfs -text
         
     | 
| 19 | 
         
            +
            *.pb filter=lfs diff=lfs merge=lfs -text
         
     | 
| 20 | 
         
            +
            *.pickle filter=lfs diff=lfs merge=lfs -text
         
     | 
| 21 | 
         
            +
            *.pkl filter=lfs diff=lfs merge=lfs -text
         
     | 
| 22 | 
         
            +
            *.pt filter=lfs diff=lfs merge=lfs -text
         
     | 
| 23 | 
         
            +
            *.pth filter=lfs diff=lfs merge=lfs -text
         
     | 
| 24 | 
         
            +
            *.rar filter=lfs diff=lfs merge=lfs -text
         
     | 
| 25 | 
         
            +
            *.safetensors filter=lfs diff=lfs merge=lfs -text
         
     | 
| 26 | 
         
            +
            saved_model/**/* filter=lfs diff=lfs merge=lfs -text
         
     | 
| 27 | 
         
            +
            *.tar.* filter=lfs diff=lfs merge=lfs -text
         
     | 
| 28 | 
         
            +
            *.tar filter=lfs diff=lfs merge=lfs -text
         
     | 
| 29 | 
         
            +
            *.tflite filter=lfs diff=lfs merge=lfs -text
         
     | 
| 30 | 
         
            +
            *.tgz filter=lfs diff=lfs merge=lfs -text
         
     | 
| 31 | 
         
            +
            *.wasm filter=lfs diff=lfs merge=lfs -text
         
     | 
| 32 | 
         
            +
            *.xz filter=lfs diff=lfs merge=lfs -text
         
     | 
| 33 | 
         
            +
            *.zip filter=lfs diff=lfs merge=lfs -text
         
     | 
| 34 | 
         
            +
            *.zst filter=lfs diff=lfs merge=lfs -text
         
     | 
| 35 | 
         
            +
            *tfevents* filter=lfs diff=lfs merge=lfs -text
         
     | 
| 36 | 
         
            +
            tokenizer.json filter=lfs diff=lfs merge=lfs -text
         
     | 
| 37 | 
         
            +
            *.png filter=lfs diff=lfs merge=lfs -text
         
     | 
| 38 | 
         
            +
            *.gif filter=lfs diff=lfs merge=lfs -text
         
     | 
    	
        README.md
    ADDED
    
    | 
         @@ -0,0 +1,151 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            ---
         
     | 
| 2 | 
         
            +
            license: apache-2.0
         
     | 
| 3 | 
         
            +
            language:
         
     | 
| 4 | 
         
            +
            - en
         
     | 
| 5 | 
         
            +
            base_model:
         
     | 
| 6 | 
         
            +
            - Qwen/Qwen2.5-1.5B-Instruct
         
     | 
| 7 | 
         
            +
            ---
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            # Fast-dLLM v2 (1.5B) — Efficient Block-Diffusion LLM
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            ## 📖 Introduction
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            Autoregressive (AR) large language models (LLMs) have achieved remarkable performance across a wide range of natural language tasks, yet their **inherent sequential decoding limits inference efficiency**.
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            We present **Fast-dLLM v2** — a carefully designed **block diffusion language model (dLLM)** that efficiently adapts a pretrained AR model (**Qwen2.5-1.5B-Instruct**) into a diffusion-style decoder for **parallel text generation**.
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
            Our approach introduces a novel decoding recipe incorporating a complementary attention mask and block diffusion mechanism, which together enable blockwise bidirectional context modeling while preserving the original AR training objectives and performance. To further enhance inference speed, we design a hierarchical caching mechanism: a block-level cache that stores historical context representations and a sub-block level cache that supports efficient parallel decoding within partially generated blocks.
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            ### ✨ Key Innovations
         
     | 
| 20 | 
         
            +
            - **Block Diffusion Mechanism + Complementary Attention Mask**  
         
     | 
| 21 | 
         
            +
              Enables **blockwise bidirectional context modeling** without sacrificing AR objectives.
         
     | 
| 22 | 
         
            +
            - **Hierarchical Caching**  
         
     | 
| 23 | 
         
            +
              - **Block-level cache**: Stores historical context representations across blocks.
         
     | 
| 24 | 
         
            +
              - **Sub-block cache**: Parallel decoding within partially generated blocks.
         
     | 
| 25 | 
         
            +
            - **Token Shift Mechanism**  
         
     | 
| 26 | 
         
            +
              Retains autoregressive characteristics while supporting bidirectional context within blocks.
         
     | 
| 27 | 
         
            +
            - **Parallel Decoding Pipeline**  
         
     | 
| 28 | 
         
            +
              Achieves up to **2.5× speedup** over standard AR decoding **without compromising quality**.
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            > 🚀 Fast-dLLM v2 uses **only ~1B tokens** for fine-tuning — a **500× reduction** vs. full-attention diffusion LLMs (Dream: 580B tokens) — while **matching or surpassing AR baselines** in accuracy.
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
            
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
            ---
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
            ## 🛠 Model Overview
         
     | 
| 38 | 
         
            +
            - **Type**: Block Diffusion Language Model (dLLM)
         
     | 
| 39 | 
         
            +
            - **Base Model**: `Qwen/Qwen2.5-1.5B-Instruct`
         
     | 
| 40 | 
         
            +
            - **Architecture**: Transformer w/ RoPE, SwiGLU, RMSNorm, Attention QKV bias, tied embeddings
         
     | 
| 41 | 
         
            +
            - **Params**: 1.54B (non-embedding: 1.31B)
         
     | 
| 42 | 
         
            +
            - **Layers**: 28
         
     | 
| 43 | 
         
            +
            - **Attention Heads**: 12 (Q), 2 (KV, GQA)
         
     | 
| 44 | 
         
            +
            - **Context Window**: 32,768 tokens (generation length: 8,192)
         
     | 
| 45 | 
         
            +
            - **Key Feature**: Parallel **block-wise decoding** + **hierarchical caching**
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
            ---
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
            ## 📦 Installation
         
     | 
| 50 | 
         
            +
            You will need `transformers`, `torch`, and our **custom generation function**:
         
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
            ```bash
         
     | 
| 53 | 
         
            +
            pip install transformers torch numpy
         
     | 
| 54 | 
         
            +
            ```
         
     | 
| 55 | 
         
            +
             
     | 
| 56 | 
         
            +
            ---
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
            ## 🚀 Quickstart
         
     | 
| 59 | 
         
            +
             
     | 
| 60 | 
         
            +
            ```python
         
     | 
| 61 | 
         
            +
            from transformers import AutoModelForCausalLM, AutoTokenizer
         
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
            model_name = "Efficient-Large-Model/Fast_dLLM_1.5B"
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
            model = AutoModelForCausalLM.from_pretrained(
         
     | 
| 66 | 
         
            +
                model_name,
         
     | 
| 67 | 
         
            +
                torch_dtype="auto",
         
     | 
| 68 | 
         
            +
                device_map="auto",
         
     | 
| 69 | 
         
            +
                trust_remote_code=True
         
     | 
| 70 | 
         
            +
            )
         
     | 
| 71 | 
         
            +
             
     | 
| 72 | 
         
            +
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
         
     | 
| 73 | 
         
            +
             
     | 
| 74 | 
         
            +
            prompt = "Give me a short introduction to large language model."
         
     | 
| 75 | 
         
            +
            messages = [
         
     | 
| 76 | 
         
            +
                {"role": "system", "content": "You are a helpful assistant."},
         
     | 
| 77 | 
         
            +
                {"role": "user", "content": prompt}
         
     | 
| 78 | 
         
            +
            ]
         
     | 
| 79 | 
         
            +
             
     | 
| 80 | 
         
            +
            text = tokenizer.apply_chat_template(
         
     | 
| 81 | 
         
            +
                messages,
         
     | 
| 82 | 
         
            +
                tokenize=False,
         
     | 
| 83 | 
         
            +
                add_generation_prompt=True
         
     | 
| 84 | 
         
            +
            )
         
     | 
| 85 | 
         
            +
            inputs = tokenizer([text], return_tensors="pt").to(model.device)
         
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
            # Fast-dLLM v2 parallel decoding
         
     | 
| 88 | 
         
            +
            gen_ids = model.generate(
         
     | 
| 89 | 
         
            +
                inputs["input_ids"],
         
     | 
| 90 | 
         
            +
                tokenizer=tokenizer,
         
     | 
| 91 | 
         
            +
                max_new_tokens=512,
         
     | 
| 92 | 
         
            +
                small_block_size=8,
         
     | 
| 93 | 
         
            +
                threshold=0.9,
         
     | 
| 94 | 
         
            +
            )
         
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
            response = tokenizer.decode(
         
     | 
| 97 | 
         
            +
                gen_ids[0][inputs["input_ids"].shape[1]:], 
         
     | 
| 98 | 
         
            +
                skip_special_tokens=True
         
     | 
| 99 | 
         
            +
            )
         
     | 
| 100 | 
         
            +
            print(response)
         
     | 
| 101 | 
         
            +
            ```
         
     | 
| 102 | 
         
            +
             
     | 
| 103 | 
         
            +
            ---
         
     | 
| 104 | 
         
            +
             
     | 
| 105 | 
         
            +
            ## 📊 Performance & Benchmarks
         
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
            ### ▶ Real-time Throughput
         
     | 
| 108 | 
         
            +
            Fast-dLLM v2 offers **up to 2.54× higher throughput** than Qwen2.5-7B-Instruct, **without loss in quality**.
         
     | 
| 109 | 
         
            +
             
     | 
| 110 | 
         
            +
            
         
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
            ---
         
     | 
| 113 | 
         
            +
             
     | 
| 114 | 
         
            +
            ### 🏆 Benchmark Results
         
     | 
| 115 | 
         
            +
            We compare Fast-dLLM v2 against AR baselines and previous diffusion LLMs on diverse tasks:  
         
     | 
| 116 | 
         
            +
            HumanEval, MBPP (code), GSM8K, Math (reasoning), IFEval (instruction), MMLU, GPQA (knowledge QA).
         
     | 
| 117 | 
         
            +
             
     | 
| 118 | 
         
            +
            - **1B group**: Fast-dLLM v2 (1.5B) achieves **best average score: 45.0**.
         
     | 
| 119 | 
         
            +
            - **7B group**: Fast-dLLM v2 (7B) achieves **best average score: 60.3**, surpassing  LLaDA and Dream models.
         
     | 
| 120 | 
         
            +
             
     | 
| 121 | 
         
            +
            
         
     | 
| 122 | 
         
            +
             
     | 
| 123 | 
         
            +
            ---
         
     | 
| 124 | 
         
            +
             
     | 
| 125 | 
         
            +
            ## 📜 Citation
         
     | 
| 126 | 
         
            +
             
     | 
| 127 | 
         
            +
            If you use Fast-dLLM v2 in your research or products, please cite:
         
     | 
| 128 | 
         
            +
             
     | 
| 129 | 
         
            +
            ```bibtex
         
     | 
| 130 | 
         
            +
            @misc{wu2025fastdllmv2efficientblockdiffusion,
         
     | 
| 131 | 
         
            +
                  title={Fast-dLLM v2: Efficient Block-Diffusion LLM}, 
         
     | 
| 132 | 
         
            +
                  author={Chengyue Wu and Hao Zhang and Shuchen Xue and Shizhe Diao and Yonggan Fu and Zhijian Liu and Pavlo Molchanov and Ping Luo and Song Han and Enze Xie},
         
     | 
| 133 | 
         
            +
                  year={2025},
         
     | 
| 134 | 
         
            +
                  eprint={2509.26328},
         
     | 
| 135 | 
         
            +
                  archivePrefix={arXiv},
         
     | 
| 136 | 
         
            +
                  primaryClass={cs.CL},
         
     | 
| 137 | 
         
            +
                  url={https://arxiv.org/abs/2509.26328}, 
         
     | 
| 138 | 
         
            +
            }
         
     | 
| 139 | 
         
            +
            ```
         
     | 
| 140 | 
         
            +
             
     | 
| 141 | 
         
            +
            ---
         
     | 
| 142 | 
         
            +
             
     | 
| 143 | 
         
            +
            ## 📄 License
         
     | 
| 144 | 
         
            +
            Released under **Apache 2.0**, following the base Qwen2.5 license.
         
     | 
| 145 | 
         
            +
             
     | 
| 146 | 
         
            +
            ---
         
     | 
| 147 | 
         
            +
             
     | 
| 148 | 
         
            +
            ## 🔗 Resources
         
     | 
| 149 | 
         
            +
            - 📄 [Paper](https://arxiv.org/abs/2509.26328)  
         
     | 
| 150 | 
         
            +
            - 💻 [Code](https://github.com/NVlabs/Fast-dLLM)  
         
     | 
| 151 | 
         
            +
            - 🤗 [HuggingFace Model](https://huggingface.co/Efficient-Large-Model/Fast_dLLM_1.5B)
         
     | 
    	
        added_tokens.json
    ADDED
    
    | 
         @@ -0,0 +1,25 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {
         
     | 
| 2 | 
         
            +
              "</tool_call>": 151658,
         
     | 
| 3 | 
         
            +
              "<tool_call>": 151657,
         
     | 
| 4 | 
         
            +
              "<|box_end|>": 151649,
         
     | 
| 5 | 
         
            +
              "<|box_start|>": 151648,
         
     | 
| 6 | 
         
            +
              "<|endoftext|>": 151643,
         
     | 
| 7 | 
         
            +
              "<|file_sep|>": 151664,
         
     | 
| 8 | 
         
            +
              "<|fim_middle|>": 151660,
         
     | 
| 9 | 
         
            +
              "<|fim_pad|>": 151662,
         
     | 
| 10 | 
         
            +
              "<|fim_prefix|>": 151659,
         
     | 
| 11 | 
         
            +
              "<|fim_suffix|>": 151661,
         
     | 
| 12 | 
         
            +
              "<|im_end|>": 151645,
         
     | 
| 13 | 
         
            +
              "<|im_start|>": 151644,
         
     | 
| 14 | 
         
            +
              "<|image_pad|>": 151655,
         
     | 
| 15 | 
         
            +
              "<|object_ref_end|>": 151647,
         
     | 
| 16 | 
         
            +
              "<|object_ref_start|>": 151646,
         
     | 
| 17 | 
         
            +
              "<|quad_end|>": 151651,
         
     | 
| 18 | 
         
            +
              "<|quad_start|>": 151650,
         
     | 
| 19 | 
         
            +
              "<|repo_name|>": 151663,
         
     | 
| 20 | 
         
            +
              "<|video_pad|>": 151656,
         
     | 
| 21 | 
         
            +
              "<|vision_end|>": 151653,
         
     | 
| 22 | 
         
            +
              "<|vision_pad|>": 151654,
         
     | 
| 23 | 
         
            +
              "<|vision_start|>": 151652,
         
     | 
| 24 | 
         
            +
              "|<MASK>|": 151665
         
     | 
| 25 | 
         
            +
            }
         
     | 
    	
        assets/benchmark_results.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        assets/throughput.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        assets/training_recipe.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        assets/visualization_animation.gif
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        chat_template.jinja
    ADDED
    
    | 
         @@ -0,0 +1,54 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {%- if tools %}
         
     | 
| 2 | 
         
            +
                {{- '<|im_start|>system\n' }}
         
     | 
| 3 | 
         
            +
                {%- if messages[0]['role'] == 'system' %}
         
     | 
| 4 | 
         
            +
                    {{- messages[0]['content'] }}
         
     | 
| 5 | 
         
            +
                {%- else %}
         
     | 
| 6 | 
         
            +
                    {{- 'You are a helpful assistant.' }}
         
     | 
| 7 | 
         
            +
                {%- endif %}
         
     | 
| 8 | 
         
            +
                {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
         
     | 
| 9 | 
         
            +
                {%- for tool in tools %}
         
     | 
| 10 | 
         
            +
                    {{- "\n" }}
         
     | 
| 11 | 
         
            +
                    {{- tool | tojson }}
         
     | 
| 12 | 
         
            +
                {%- endfor %}
         
     | 
| 13 | 
         
            +
                {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
         
     | 
| 14 | 
         
            +
            {%- else %}
         
     | 
| 15 | 
         
            +
                {%- if messages[0]['role'] == 'system' %}
         
     | 
| 16 | 
         
            +
                    {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
         
     | 
| 17 | 
         
            +
                {%- else %}
         
     | 
| 18 | 
         
            +
                    {{- '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}
         
     | 
| 19 | 
         
            +
                {%- endif %}
         
     | 
| 20 | 
         
            +
            {%- endif %}
         
     | 
| 21 | 
         
            +
            {%- for message in messages %}
         
     | 
| 22 | 
         
            +
                {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
         
     | 
| 23 | 
         
            +
                    {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
         
     | 
| 24 | 
         
            +
                {%- elif message.role == "assistant" %}
         
     | 
| 25 | 
         
            +
                    {{- '<|im_start|>' + message.role }}
         
     | 
| 26 | 
         
            +
                    {%- if message.content %}
         
     | 
| 27 | 
         
            +
                        {{- '\n' + message.content }}
         
     | 
| 28 | 
         
            +
                    {%- endif %}
         
     | 
| 29 | 
         
            +
                    {%- for tool_call in message.tool_calls %}
         
     | 
| 30 | 
         
            +
                        {%- if tool_call.function is defined %}
         
     | 
| 31 | 
         
            +
                            {%- set tool_call = tool_call.function %}
         
     | 
| 32 | 
         
            +
                        {%- endif %}
         
     | 
| 33 | 
         
            +
                        {{- '\n<tool_call>\n{"name": "' }}
         
     | 
| 34 | 
         
            +
                        {{- tool_call.name }}
         
     | 
| 35 | 
         
            +
                        {{- '", "arguments": ' }}
         
     | 
| 36 | 
         
            +
                        {{- tool_call.arguments | tojson }}
         
     | 
| 37 | 
         
            +
                        {{- '}\n</tool_call>' }}
         
     | 
| 38 | 
         
            +
                    {%- endfor %}
         
     | 
| 39 | 
         
            +
                    {{- '<|im_end|>\n' }}
         
     | 
| 40 | 
         
            +
                {%- elif message.role == "tool" %}
         
     | 
| 41 | 
         
            +
                    {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
         
     | 
| 42 | 
         
            +
                        {{- '<|im_start|>user' }}
         
     | 
| 43 | 
         
            +
                    {%- endif %}
         
     | 
| 44 | 
         
            +
                    {{- '\n<tool_response>\n' }}
         
     | 
| 45 | 
         
            +
                    {{- message.content }}
         
     | 
| 46 | 
         
            +
                    {{- '\n</tool_response>' }}
         
     | 
| 47 | 
         
            +
                    {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
         
     | 
| 48 | 
         
            +
                        {{- '<|im_end|>\n' }}
         
     | 
| 49 | 
         
            +
                    {%- endif %}
         
     | 
| 50 | 
         
            +
                {%- endif %}
         
     | 
| 51 | 
         
            +
            {%- endfor %}
         
     | 
| 52 | 
         
            +
            {%- if add_generation_prompt %}
         
     | 
| 53 | 
         
            +
                {{- '<|im_start|>assistant\n' }}
         
     | 
| 54 | 
         
            +
            {%- endif %}
         
     | 
    	
        config.json
    ADDED
    
    | 
         @@ -0,0 +1,65 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {
         
     | 
| 2 | 
         
            +
              "architectures": [
         
     | 
| 3 | 
         
            +
                "Fast_dLLM_QwenForCausalLM"
         
     | 
| 4 | 
         
            +
              ],
         
     | 
| 5 | 
         
            +
              "attention_dropout": 0.0,
         
     | 
| 6 | 
         
            +
              "auto_map": {
         
     | 
| 7 | 
         
            +
                "AutoConfig": "configuration.Fast_dLLM_QwenConfig",
         
     | 
| 8 | 
         
            +
                "AutoModel": "modeling.Fast_dLLM_QwenModel",
         
     | 
| 9 | 
         
            +
                "AutoModelForCausalLM": "modeling.Fast_dLLM_QwenForCausalLM"
         
     | 
| 10 | 
         
            +
              },
         
     | 
| 11 | 
         
            +
              "bd_size": 32,
         
     | 
| 12 | 
         
            +
              "bos_token_id": 151643,
         
     | 
| 13 | 
         
            +
              "eos_token_id": 151645,
         
     | 
| 14 | 
         
            +
              "hidden_act": "silu",
         
     | 
| 15 | 
         
            +
              "hidden_size": 1536,
         
     | 
| 16 | 
         
            +
              "initializer_range": 0.02,
         
     | 
| 17 | 
         
            +
              "intermediate_size": 8960,
         
     | 
| 18 | 
         
            +
              "layer_types": [
         
     | 
| 19 | 
         
            +
                "full_attention",
         
     | 
| 20 | 
         
            +
                "full_attention",
         
     | 
| 21 | 
         
            +
                "full_attention",
         
     | 
| 22 | 
         
            +
                "full_attention",
         
     | 
| 23 | 
         
            +
                "full_attention",
         
     | 
| 24 | 
         
            +
                "full_attention",
         
     | 
| 25 | 
         
            +
                "full_attention",
         
     | 
| 26 | 
         
            +
                "full_attention",
         
     | 
| 27 | 
         
            +
                "full_attention",
         
     | 
| 28 | 
         
            +
                "full_attention",
         
     | 
| 29 | 
         
            +
                "full_attention",
         
     | 
| 30 | 
         
            +
                "full_attention",
         
     | 
| 31 | 
         
            +
                "full_attention",
         
     | 
| 32 | 
         
            +
                "full_attention",
         
     | 
| 33 | 
         
            +
                "full_attention",
         
     | 
| 34 | 
         
            +
                "full_attention",
         
     | 
| 35 | 
         
            +
                "full_attention",
         
     | 
| 36 | 
         
            +
                "full_attention",
         
     | 
| 37 | 
         
            +
                "full_attention",
         
     | 
| 38 | 
         
            +
                "full_attention",
         
     | 
| 39 | 
         
            +
                "full_attention",
         
     | 
| 40 | 
         
            +
                "full_attention",
         
     | 
| 41 | 
         
            +
                "full_attention",
         
     | 
| 42 | 
         
            +
                "full_attention",
         
     | 
| 43 | 
         
            +
                "full_attention",
         
     | 
| 44 | 
         
            +
                "full_attention",
         
     | 
| 45 | 
         
            +
                "full_attention",
         
     | 
| 46 | 
         
            +
                "full_attention"
         
     | 
| 47 | 
         
            +
              ],
         
     | 
| 48 | 
         
            +
              "max_position_embeddings": 32768,
         
     | 
| 49 | 
         
            +
              "max_window_layers": 21,
         
     | 
| 50 | 
         
            +
              "model_type": "Fast_dLLM_Qwen",
         
     | 
| 51 | 
         
            +
              "num_attention_heads": 12,
         
     | 
| 52 | 
         
            +
              "num_hidden_layers": 28,
         
     | 
| 53 | 
         
            +
              "num_key_value_heads": 2,
         
     | 
| 54 | 
         
            +
              "pad_token_id": 151645,
         
     | 
| 55 | 
         
            +
              "rms_norm_eps": 1e-06,
         
     | 
| 56 | 
         
            +
              "rope_scaling": null,
         
     | 
| 57 | 
         
            +
              "rope_theta": 1000000.0,
         
     | 
| 58 | 
         
            +
              "sliding_window": null,
         
     | 
| 59 | 
         
            +
              "tie_word_embeddings": true,
         
     | 
| 60 | 
         
            +
              "torch_dtype": "bfloat16",
         
     | 
| 61 | 
         
            +
              "transformers_version": "4.53.1",
         
     | 
| 62 | 
         
            +
              "use_cache": true,
         
     | 
| 63 | 
         
            +
              "use_sliding_window": false,
         
     | 
| 64 | 
         
            +
              "vocab_size": 151936
         
     | 
| 65 | 
         
            +
            }
         
     | 
    	
        configuration.py
    ADDED
    
    | 
         @@ -0,0 +1,98 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            """Fast_dLLM_Qwen model configuration"""
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            from transformers.configuration_utils import PretrainedConfig, layer_type_validation
         
     | 
| 4 | 
         
            +
            from transformers.modeling_rope_utils import rope_config_validation
         
     | 
| 5 | 
         
            +
            from transformers.utils import logging
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
            logger = logging.get_logger(__name__)
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            class Fast_dLLM_QwenConfig(PretrainedConfig):
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
                model_type = "Fast_dLLM_Qwen"
         
     | 
| 14 | 
         
            +
                keys_to_ignore_at_inference = ["past_key_values"]
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
                # Default tensor parallel plan for base model `Fast_dLLM_Qwen`
         
     | 
| 17 | 
         
            +
                base_model_tp_plan = {
         
     | 
| 18 | 
         
            +
                    "layers.*.self_attn.q_proj": "colwise",
         
     | 
| 19 | 
         
            +
                    "layers.*.self_attn.k_proj": "colwise",
         
     | 
| 20 | 
         
            +
                    "layers.*.self_attn.v_proj": "colwise",
         
     | 
| 21 | 
         
            +
                    "layers.*.self_attn.o_proj": "rowwise",
         
     | 
| 22 | 
         
            +
                    "layers.*.mlp.gate_proj": "colwise",
         
     | 
| 23 | 
         
            +
                    "layers.*.mlp.up_proj": "colwise",
         
     | 
| 24 | 
         
            +
                    "layers.*.mlp.down_proj": "rowwise",
         
     | 
| 25 | 
         
            +
                }
         
     | 
| 26 | 
         
            +
                base_model_pp_plan = {
         
     | 
| 27 | 
         
            +
                    "embed_tokens": (["input_ids"], ["inputs_embeds"]),
         
     | 
| 28 | 
         
            +
                    "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
         
     | 
| 29 | 
         
            +
                    "norm": (["hidden_states"], ["hidden_states"]),
         
     | 
| 30 | 
         
            +
                }
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
                def __init__(
         
     | 
| 33 | 
         
            +
                    self,
         
     | 
| 34 | 
         
            +
                    vocab_size=151936,
         
     | 
| 35 | 
         
            +
                    hidden_size=4096,
         
     | 
| 36 | 
         
            +
                    intermediate_size=22016,
         
     | 
| 37 | 
         
            +
                    num_hidden_layers=32,
         
     | 
| 38 | 
         
            +
                    num_attention_heads=32,
         
     | 
| 39 | 
         
            +
                    num_key_value_heads=32,
         
     | 
| 40 | 
         
            +
                    hidden_act="silu",
         
     | 
| 41 | 
         
            +
                    max_position_embeddings=32768,
         
     | 
| 42 | 
         
            +
                    initializer_range=0.02,
         
     | 
| 43 | 
         
            +
                    rms_norm_eps=1e-6,
         
     | 
| 44 | 
         
            +
                    use_cache=True,
         
     | 
| 45 | 
         
            +
                    tie_word_embeddings=False,
         
     | 
| 46 | 
         
            +
                    rope_theta=10000.0,
         
     | 
| 47 | 
         
            +
                    rope_scaling=None,
         
     | 
| 48 | 
         
            +
                    use_sliding_window=False,
         
     | 
| 49 | 
         
            +
                    sliding_window=4096,
         
     | 
| 50 | 
         
            +
                    max_window_layers=28,
         
     | 
| 51 | 
         
            +
                    layer_types=None,
         
     | 
| 52 | 
         
            +
                    attention_dropout=0.0,
         
     | 
| 53 | 
         
            +
                    bd_size=32,
         
     | 
| 54 | 
         
            +
                    **kwargs,
         
     | 
| 55 | 
         
            +
                ):
         
     | 
| 56 | 
         
            +
                    self.vocab_size = vocab_size
         
     | 
| 57 | 
         
            +
                    self.max_position_embeddings = max_position_embeddings
         
     | 
| 58 | 
         
            +
                    self.hidden_size = hidden_size
         
     | 
| 59 | 
         
            +
                    self.intermediate_size = intermediate_size
         
     | 
| 60 | 
         
            +
                    self.num_hidden_layers = num_hidden_layers
         
     | 
| 61 | 
         
            +
                    self.num_attention_heads = num_attention_heads
         
     | 
| 62 | 
         
            +
                    self.use_sliding_window = use_sliding_window
         
     | 
| 63 | 
         
            +
                    self.sliding_window = sliding_window if self.use_sliding_window else None
         
     | 
| 64 | 
         
            +
                    self.max_window_layers = max_window_layers
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
                    # for backward compatibility
         
     | 
| 67 | 
         
            +
                    if num_key_value_heads is None:
         
     | 
| 68 | 
         
            +
                        num_key_value_heads = num_attention_heads
         
     | 
| 69 | 
         
            +
             
     | 
| 70 | 
         
            +
                    self.num_key_value_heads = num_key_value_heads
         
     | 
| 71 | 
         
            +
                    self.hidden_act = hidden_act
         
     | 
| 72 | 
         
            +
                    self.initializer_range = initializer_range
         
     | 
| 73 | 
         
            +
                    self.rms_norm_eps = rms_norm_eps
         
     | 
| 74 | 
         
            +
                    self.use_cache = use_cache
         
     | 
| 75 | 
         
            +
                    self.rope_theta = rope_theta
         
     | 
| 76 | 
         
            +
                    self.rope_scaling = rope_scaling
         
     | 
| 77 | 
         
            +
                    self.attention_dropout = attention_dropout
         
     | 
| 78 | 
         
            +
                    self.bd_size = bd_size
         
     | 
| 79 | 
         
            +
                    # Validate the correctness of rotary position embeddings parameters
         
     | 
| 80 | 
         
            +
                    # BC: if there is a 'type' field, move it to 'rope_type'.
         
     | 
| 81 | 
         
            +
                    if self.rope_scaling is not None and "type" in self.rope_scaling:
         
     | 
| 82 | 
         
            +
                        self.rope_scaling["rope_type"] = self.rope_scaling["type"]
         
     | 
| 83 | 
         
            +
                    rope_config_validation(self)
         
     | 
| 84 | 
         
            +
             
     | 
| 85 | 
         
            +
                    self.layer_types = layer_types
         
     | 
| 86 | 
         
            +
                    if self.layer_types is None:
         
     | 
| 87 | 
         
            +
                        self.layer_types = [
         
     | 
| 88 | 
         
            +
                            "sliding_attention"
         
     | 
| 89 | 
         
            +
                            if self.sliding_window is not None and i >= self.max_window_layers
         
     | 
| 90 | 
         
            +
                            else "full_attention"
         
     | 
| 91 | 
         
            +
                            for i in range(self.num_hidden_layers)
         
     | 
| 92 | 
         
            +
                        ]
         
     | 
| 93 | 
         
            +
                    layer_type_validation(self.layer_types)
         
     | 
| 94 | 
         
            +
             
     | 
| 95 | 
         
            +
                    super().__init__(
         
     | 
| 96 | 
         
            +
                        tie_word_embeddings=tie_word_embeddings,
         
     | 
| 97 | 
         
            +
                        **kwargs,
         
     | 
| 98 | 
         
            +
                    )
         
     | 
    	
        generation_config.json
    ADDED
    
    | 
         @@ -0,0 +1,14 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {
         
     | 
| 2 | 
         
            +
              "bos_token_id": 151643,
         
     | 
| 3 | 
         
            +
              "do_sample": true,
         
     | 
| 4 | 
         
            +
              "eos_token_id": [
         
     | 
| 5 | 
         
            +
                151645,
         
     | 
| 6 | 
         
            +
                151643
         
     | 
| 7 | 
         
            +
              ],
         
     | 
| 8 | 
         
            +
              "pad_token_id": 151643,
         
     | 
| 9 | 
         
            +
              "repetition_penalty": 1.1,
         
     | 
| 10 | 
         
            +
              "temperature": 0.7,
         
     | 
| 11 | 
         
            +
              "top_k": 20,
         
     | 
| 12 | 
         
            +
              "top_p": 0.8,
         
     | 
| 13 | 
         
            +
              "transformers_version": "4.53.1"
         
     | 
| 14 | 
         
            +
            }
         
     | 
    	
        merges.txt
    ADDED
    
    | 
         The diff for this file is too large to render. 
		See raw diff 
     | 
| 
         | 
    	
        model.safetensors
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:8d267bb8b935f2e15148ba1175b67dba70261a696ec931dd0a3b0f27f9f3c434
         
     | 
| 3 | 
         
            +
            size 3087467144
         
     | 
    	
        modeling.py
    ADDED
    
    | 
         @@ -0,0 +1,681 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from typing import Callable, Optional, Union
         
     | 
| 2 | 
         
            +
            from dataclasses import dataclass
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            import torch
         
     | 
| 5 | 
         
            +
            from torch import nn
         
     | 
| 6 | 
         
            +
            import torch.nn.functional as F
         
     | 
| 7 | 
         
            +
            from functools import partial
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            from transformers.activations import ACT2FN
         
     | 
| 10 | 
         
            +
            from transformers.cache_utils import Cache, DynamicCache
         
     | 
| 11 | 
         
            +
            from transformers.generation import GenerationMixin
         
     | 
| 12 | 
         
            +
            from transformers.integrations import use_kernel_forward_from_hub
         
     | 
| 13 | 
         
            +
            from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
         
     | 
| 14 | 
         
            +
            from transformers.modeling_layers import GradientCheckpointingLayer
         
     | 
| 15 | 
         
            +
            from transformers.modeling_outputs import (
         
     | 
| 16 | 
         
            +
                BaseModelOutputWithPast,
         
     | 
| 17 | 
         
            +
                CausalLMOutputWithPast,
         
     | 
| 18 | 
         
            +
            )
         
     | 
| 19 | 
         
            +
            from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
         
     | 
| 20 | 
         
            +
            from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
         
     | 
| 21 | 
         
            +
            from transformers.processing_utils import Unpack
         
     | 
| 22 | 
         
            +
            from transformers.utils import auto_docstring, can_return_tuple, logging
         
     | 
| 23 | 
         
            +
            from .configuration import Fast_dLLM_QwenConfig
         
     | 
| 24 | 
         
            +
            from torch.nn.attention.flex_attention import flex_attention, create_block_mask
         
     | 
| 25 | 
         
            +
            from einops import rearrange, repeat
         
     | 
| 26 | 
         
            +
             
     | 
| 27 | 
         
            +
            logger = logging.get_logger(__name__)
         
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            @dataclass
         
     | 
| 31 | 
         
            +
            class CausalLMOutputWithPastAndBlockCache(CausalLMOutputWithPast):
         
     | 
| 32 | 
         
            +
                block_past_key_values: Optional[Cache] = None
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
            @dataclass
         
     | 
| 35 | 
         
            +
            class BaseModelOutputWithPastAndBlockCache(BaseModelOutputWithPast):
         
     | 
| 36 | 
         
            +
                block_past_key_values: Optional[Cache] = None
         
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
            def eval_block_diff_mask(q_idx, kv_idx, block_size=None):
         
     | 
| 40 | 
         
            +
                # Compute block indices
         
     | 
| 41 | 
         
            +
                block_q = q_idx // block_size
         
     | 
| 42 | 
         
            +
                block_kv = kv_idx // block_size
         
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
                return block_q >= block_kv
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
            class Fast_dLLM_QwenMLP(nn.Module):
         
     | 
| 47 | 
         
            +
                def __init__(self, config):
         
     | 
| 48 | 
         
            +
                    super().__init__()
         
     | 
| 49 | 
         
            +
                    self.config = config
         
     | 
| 50 | 
         
            +
                    self.hidden_size = config.hidden_size
         
     | 
| 51 | 
         
            +
                    self.intermediate_size = config.intermediate_size
         
     | 
| 52 | 
         
            +
                    self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         
     | 
| 53 | 
         
            +
                    self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         
     | 
| 54 | 
         
            +
                    self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         
     | 
| 55 | 
         
            +
                    self.act_fn = ACT2FN[config.hidden_act]
         
     | 
| 56 | 
         
            +
             
     | 
| 57 | 
         
            +
                def forward(self, x):
         
     | 
| 58 | 
         
            +
                    down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
         
     | 
| 59 | 
         
            +
                    return down_proj
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
            def rotate_half(x):
         
     | 
| 63 | 
         
            +
                """Rotates half the hidden dims of the input."""
         
     | 
| 64 | 
         
            +
                x1 = x[..., : x.shape[-1] // 2]
         
     | 
| 65 | 
         
            +
                x2 = x[..., x.shape[-1] // 2 :]
         
     | 
| 66 | 
         
            +
                return torch.cat((-x2, x1), dim=-1)
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
             
     | 
| 69 | 
         
            +
            def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
         
     | 
| 70 | 
         
            +
                """Applies Rotary Position Embedding to the query and key tensors.
         
     | 
| 71 | 
         
            +
             
     | 
| 72 | 
         
            +
                Args:
         
     | 
| 73 | 
         
            +
                    q (`torch.Tensor`): The query tensor.
         
     | 
| 74 | 
         
            +
                    k (`torch.Tensor`): The key tensor.
         
     | 
| 75 | 
         
            +
                    cos (`torch.Tensor`): The cosine part of the rotary embedding.
         
     | 
| 76 | 
         
            +
                    sin (`torch.Tensor`): The sine part of the rotary embedding.
         
     | 
| 77 | 
         
            +
                    position_ids (`torch.Tensor`, *optional*):
         
     | 
| 78 | 
         
            +
                        Deprecated and unused.
         
     | 
| 79 | 
         
            +
                    unsqueeze_dim (`int`, *optional*, defaults to 1):
         
     | 
| 80 | 
         
            +
                        The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
         
     | 
| 81 | 
         
            +
                        sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
         
     | 
| 82 | 
         
            +
                        that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
         
     | 
| 83 | 
         
            +
                        k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
         
     | 
| 84 | 
         
            +
                        cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
         
     | 
| 85 | 
         
            +
                        the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
         
     | 
| 86 | 
         
            +
                Returns:
         
     | 
| 87 | 
         
            +
                    `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
         
     | 
| 88 | 
         
            +
                """
         
     | 
| 89 | 
         
            +
                cos = cos.unsqueeze(unsqueeze_dim)
         
     | 
| 90 | 
         
            +
                sin = sin.unsqueeze(unsqueeze_dim)
         
     | 
| 91 | 
         
            +
                q_embed = (q * cos) + (rotate_half(q) * sin)
         
     | 
| 92 | 
         
            +
                k_embed = (k * cos) + (rotate_half(k) * sin)
         
     | 
| 93 | 
         
            +
                return q_embed, k_embed
         
     | 
| 94 | 
         
            +
             
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
            def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
         
     | 
| 97 | 
         
            +
                """
         
     | 
| 98 | 
         
            +
                This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
         
     | 
| 99 | 
         
            +
                num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
         
     | 
| 100 | 
         
            +
                """
         
     | 
| 101 | 
         
            +
                batch, num_key_value_heads, slen, head_dim = hidden_states.shape
         
     | 
| 102 | 
         
            +
                if n_rep == 1:
         
     | 
| 103 | 
         
            +
                    return hidden_states
         
     | 
| 104 | 
         
            +
                hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
         
     | 
| 105 | 
         
            +
                return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
         
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
             
     | 
| 108 | 
         
            +
            class Fast_dLLM_QwenAttention(nn.Module):
         
     | 
| 109 | 
         
            +
                """Multi-headed attention from 'Attention Is All You Need' paper"""
         
     | 
| 110 | 
         
            +
             
     | 
| 111 | 
         
            +
                def __init__(self, config: Fast_dLLM_QwenConfig, layer_idx: int):
         
     | 
| 112 | 
         
            +
                    super().__init__()
         
     | 
| 113 | 
         
            +
                    self.config = config
         
     | 
| 114 | 
         
            +
                    self.layer_idx = layer_idx
         
     | 
| 115 | 
         
            +
                    self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         
     | 
| 116 | 
         
            +
                    self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         
     | 
| 117 | 
         
            +
                    self.scaling = self.head_dim**-0.5
         
     | 
| 118 | 
         
            +
                    self.attention_dropout = config.attention_dropout
         
     | 
| 119 | 
         
            +
                    self.is_causal = True
         
     | 
| 120 | 
         
            +
                    self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
         
     | 
| 121 | 
         
            +
                    self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
         
     | 
| 122 | 
         
            +
                    self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
         
     | 
| 123 | 
         
            +
                    self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
         
     | 
| 124 | 
         
            +
                    self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
         
     | 
| 125 | 
         
            +
             
     | 
| 126 | 
         
            +
                def forward(
         
     | 
| 127 | 
         
            +
                    self,
         
     | 
| 128 | 
         
            +
                    hidden_states: torch.Tensor,
         
     | 
| 129 | 
         
            +
                    position_embeddings: tuple[torch.Tensor, torch.Tensor],
         
     | 
| 130 | 
         
            +
                    attention_mask: Optional[torch.Tensor],
         
     | 
| 131 | 
         
            +
                    past_key_value: Optional[Cache] = None,
         
     | 
| 132 | 
         
            +
                    cache_position: Optional[torch.LongTensor] = None,
         
     | 
| 133 | 
         
            +
                    update_past_key_values: Optional[bool] = False,
         
     | 
| 134 | 
         
            +
                    block_past_key_values: Optional[Cache] = None,
         
     | 
| 135 | 
         
            +
                    replace_position: Optional[int] = None,
         
     | 
| 136 | 
         
            +
                    **kwargs: Unpack[FlashAttentionKwargs],
         
     | 
| 137 | 
         
            +
                ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         
     | 
| 138 | 
         
            +
                    input_shape = hidden_states.shape[:-1]
         
     | 
| 139 | 
         
            +
                    hidden_shape = (*input_shape, -1, self.head_dim)
         
     | 
| 140 | 
         
            +
             
     | 
| 141 | 
         
            +
                    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         
     | 
| 142 | 
         
            +
                    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         
     | 
| 143 | 
         
            +
                    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         
     | 
| 144 | 
         
            +
             
     | 
| 145 | 
         
            +
                    cos, sin = position_embeddings
         
     | 
| 146 | 
         
            +
                    # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
         
     | 
| 147 | 
         
            +
                    if self.training:
         
     | 
| 148 | 
         
            +
                        #split q into two parts
         
     | 
| 149 | 
         
            +
                        q_1 = query_states[:,:,:query_states.shape[2]//2]
         
     | 
| 150 | 
         
            +
                        q_2 = query_states[:,:,query_states.shape[2]//2:]
         
     | 
| 151 | 
         
            +
                        #split k into two parts
         
     | 
| 152 | 
         
            +
                        k_1 = key_states[:,:,:key_states.shape[2]//2]
         
     | 
| 153 | 
         
            +
                        k_2 = key_states[:,:,key_states.shape[2]//2:]
         
     | 
| 154 | 
         
            +
                        q_1, k_1 = apply_rotary_pos_emb(q_1, k_1, cos, sin)
         
     | 
| 155 | 
         
            +
                        q_2, k_2 = apply_rotary_pos_emb(q_2, k_2, cos, sin)
         
     | 
| 156 | 
         
            +
                        query_states = torch.cat((q_1, q_2), dim=-2)
         
     | 
| 157 | 
         
            +
                        key_states = torch.cat((k_1, k_2), dim=-2)
         
     | 
| 158 | 
         
            +
                    else:
         
     | 
| 159 | 
         
            +
                        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
         
     | 
| 160 | 
         
            +
             
     | 
| 161 | 
         
            +
                    if block_past_key_values is not None:
         
     | 
| 162 | 
         
            +
                        if len(block_past_key_values) <= self.layer_idx:
         
     | 
| 163 | 
         
            +
                            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
         
     | 
| 164 | 
         
            +
                            key_states, value_states = block_past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
         
     | 
| 165 | 
         
            +
                        else:
         
     | 
| 166 | 
         
            +
                            block_cache_key_states = block_past_key_values[self.layer_idx][0]
         
     | 
| 167 | 
         
            +
                            block_cache_value_states = block_past_key_values[self.layer_idx][1]
         
     | 
| 168 | 
         
            +
                            
         
     | 
| 169 | 
         
            +
                            block_cache_key_states[:, :, replace_position:replace_position+key_states.shape[2]] = key_states
         
     | 
| 170 | 
         
            +
                            block_cache_value_states[:, :, replace_position:replace_position+value_states.shape[2]] = value_states
         
     | 
| 171 | 
         
            +
                            key_states = block_cache_key_states
         
     | 
| 172 | 
         
            +
                            value_states = block_cache_value_states
         
     | 
| 173 | 
         
            +
             
     | 
| 174 | 
         
            +
                    if past_key_value is not None:
         
     | 
| 175 | 
         
            +
                        # sin and cos are specific to RoPE models; cache_position needed for the static cache
         
     | 
| 176 | 
         
            +
                        if update_past_key_values:
         
     | 
| 177 | 
         
            +
                            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
         
     | 
| 178 | 
         
            +
                            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
         
     | 
| 179 | 
         
            +
                        elif len(past_key_value) > self.layer_idx:
         
     | 
| 180 | 
         
            +
                            key_states = torch.cat((past_key_value[self.layer_idx][0], key_states), dim=-2)
         
     | 
| 181 | 
         
            +
                            value_states = torch.cat((past_key_value[self.layer_idx][1], value_states), dim=-2)
         
     | 
| 182 | 
         
            +
             
     | 
| 183 | 
         
            +
                    attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
         
     | 
| 184 | 
         
            +
             
     | 
| 185 | 
         
            +
                    attn_output, attn_weights = attention_interface(
         
     | 
| 186 | 
         
            +
                        self,
         
     | 
| 187 | 
         
            +
                        query_states,
         
     | 
| 188 | 
         
            +
                        key_states,
         
     | 
| 189 | 
         
            +
                        value_states,
         
     | 
| 190 | 
         
            +
                        attention_mask,
         
     | 
| 191 | 
         
            +
                        is_causal=False,
         
     | 
| 192 | 
         
            +
                        dropout=0.0 if not self.training else self.attention_dropout,
         
     | 
| 193 | 
         
            +
                        scaling=self.scaling,
         
     | 
| 194 | 
         
            +
                        sliding_window=self.sliding_window,  # main diff with Llama
         
     | 
| 195 | 
         
            +
                        **kwargs,
         
     | 
| 196 | 
         
            +
                    )
         
     | 
| 197 | 
         
            +
             
     | 
| 198 | 
         
            +
                    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         
     | 
| 199 | 
         
            +
                    attn_output = self.o_proj(attn_output)
         
     | 
| 200 | 
         
            +
                    return attn_output
         
     | 
| 201 | 
         
            +
             
     | 
| 202 | 
         
            +
            @use_kernel_forward_from_hub("RMSNorm")
         
     | 
| 203 | 
         
            +
            class Fast_dLLM_QwenRMSNorm(nn.Module):
         
     | 
| 204 | 
         
            +
                def __init__(self, hidden_size, eps=1e-6):
         
     | 
| 205 | 
         
            +
                    """
         
     | 
| 206 | 
         
            +
                    Fast_dLLM_QwenRMSNorm is equivalent to T5LayerNorm
         
     | 
| 207 | 
         
            +
                    """
         
     | 
| 208 | 
         
            +
                    super().__init__()
         
     | 
| 209 | 
         
            +
                    self.weight = nn.Parameter(torch.ones(hidden_size))
         
     | 
| 210 | 
         
            +
                    self.variance_epsilon = eps
         
     | 
| 211 | 
         
            +
             
     | 
| 212 | 
         
            +
                def forward(self, hidden_states):
         
     | 
| 213 | 
         
            +
                    input_dtype = hidden_states.dtype
         
     | 
| 214 | 
         
            +
                    hidden_states = hidden_states.to(torch.float32)
         
     | 
| 215 | 
         
            +
                    variance = hidden_states.pow(2).mean(-1, keepdim=True)
         
     | 
| 216 | 
         
            +
                    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         
     | 
| 217 | 
         
            +
                    return self.weight * hidden_states.to(input_dtype)
         
     | 
| 218 | 
         
            +
             
     | 
| 219 | 
         
            +
                def extra_repr(self):
         
     | 
| 220 | 
         
            +
                    return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
         
     | 
| 221 | 
         
            +
             
     | 
| 222 | 
         
            +
             
     | 
| 223 | 
         
            +
            class Fast_dLLM_QwenDecoderLayer(GradientCheckpointingLayer):
         
     | 
| 224 | 
         
            +
                def __init__(self, config: Fast_dLLM_QwenConfig, layer_idx: int):
         
     | 
| 225 | 
         
            +
                    super().__init__()
         
     | 
| 226 | 
         
            +
                    self.hidden_size = config.hidden_size
         
     | 
| 227 | 
         
            +
             
     | 
| 228 | 
         
            +
                    self.self_attn = Fast_dLLM_QwenAttention(config=config, layer_idx=layer_idx)
         
     | 
| 229 | 
         
            +
             
     | 
| 230 | 
         
            +
                    self.mlp = Fast_dLLM_QwenMLP(config)
         
     | 
| 231 | 
         
            +
                    self.input_layernorm = Fast_dLLM_QwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         
     | 
| 232 | 
         
            +
                    self.post_attention_layernorm = Fast_dLLM_QwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         
     | 
| 233 | 
         
            +
                    self.attention_type = config.layer_types[layer_idx]
         
     | 
| 234 | 
         
            +
             
     | 
| 235 | 
         
            +
                def forward(
         
     | 
| 236 | 
         
            +
                    self,
         
     | 
| 237 | 
         
            +
                    hidden_states: torch.Tensor,
         
     | 
| 238 | 
         
            +
                    attention_mask: Optional[torch.Tensor] = None,
         
     | 
| 239 | 
         
            +
                    position_ids: Optional[torch.LongTensor] = None,
         
     | 
| 240 | 
         
            +
                    past_key_value: Optional[Cache] = None,
         
     | 
| 241 | 
         
            +
                    use_cache: Optional[bool] = False,
         
     | 
| 242 | 
         
            +
                    cache_position: Optional[torch.LongTensor] = None,
         
     | 
| 243 | 
         
            +
                    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         
     | 
| 244 | 
         
            +
                    update_past_key_values: Optional[bool] = False,
         
     | 
| 245 | 
         
            +
                    use_block_cache: Optional[bool] = False,
         
     | 
| 246 | 
         
            +
                    block_past_key_values: Optional[Cache] = None,
         
     | 
| 247 | 
         
            +
                    replace_position: Optional[int] = None,
         
     | 
| 248 | 
         
            +
                    **kwargs
         
     | 
| 249 | 
         
            +
                ) -> tuple[torch.Tensor]:
         
     | 
| 250 | 
         
            +
                    residual = hidden_states
         
     | 
| 251 | 
         
            +
                    hidden_states = self.input_layernorm(hidden_states)
         
     | 
| 252 | 
         
            +
                    # Self Attention
         
     | 
| 253 | 
         
            +
                    hidden_states = self.self_attn(
         
     | 
| 254 | 
         
            +
                        hidden_states=hidden_states,
         
     | 
| 255 | 
         
            +
                        attention_mask=attention_mask,
         
     | 
| 256 | 
         
            +
                        position_ids=position_ids,
         
     | 
| 257 | 
         
            +
                        past_key_value=past_key_value,
         
     | 
| 258 | 
         
            +
                        use_cache=use_cache,
         
     | 
| 259 | 
         
            +
                        cache_position=cache_position,
         
     | 
| 260 | 
         
            +
                        position_embeddings=position_embeddings,
         
     | 
| 261 | 
         
            +
                        update_past_key_values=update_past_key_values,
         
     | 
| 262 | 
         
            +
                        use_block_cache=use_block_cache,
         
     | 
| 263 | 
         
            +
                        block_past_key_values=block_past_key_values,
         
     | 
| 264 | 
         
            +
                        replace_position=replace_position,
         
     | 
| 265 | 
         
            +
                        **kwargs,
         
     | 
| 266 | 
         
            +
                    )
         
     | 
| 267 | 
         
            +
                    hidden_states = residual + hidden_states
         
     | 
| 268 | 
         
            +
             
     | 
| 269 | 
         
            +
                    # Fully Connected
         
     | 
| 270 | 
         
            +
                    residual = hidden_states
         
     | 
| 271 | 
         
            +
                    hidden_states = self.post_attention_layernorm(hidden_states)
         
     | 
| 272 | 
         
            +
                    hidden_states = self.mlp(hidden_states)
         
     | 
| 273 | 
         
            +
                    hidden_states = residual + hidden_states
         
     | 
| 274 | 
         
            +
                    return hidden_states
         
     | 
| 275 | 
         
            +
             
     | 
| 276 | 
         
            +
             
     | 
| 277 | 
         
            +
             
     | 
| 278 | 
         
            +
            class Fast_dLLM_QwenPreTrainedModel(PreTrainedModel):
         
     | 
| 279 | 
         
            +
                config_class = Fast_dLLM_QwenConfig
         
     | 
| 280 | 
         
            +
                base_model_prefix = "model"
         
     | 
| 281 | 
         
            +
                supports_gradient_checkpointing = True
         
     | 
| 282 | 
         
            +
                _no_split_modules = ["Fast_dLLM_QwenDecoderLayer"]
         
     | 
| 283 | 
         
            +
                _skip_keys_device_placement = ["past_key_values"]
         
     | 
| 284 | 
         
            +
                _supports_flash_attn_2 = True
         
     | 
| 285 | 
         
            +
                _supports_sdpa = True
         
     | 
| 286 | 
         
            +
                _supports_flex_attn = True
         
     | 
| 287 | 
         
            +
                _supports_cache_class = True
         
     | 
| 288 | 
         
            +
                _supports_quantized_cache = True
         
     | 
| 289 | 
         
            +
                _supports_static_cache = True
         
     | 
| 290 | 
         
            +
                _supports_attention_backend = True
         
     | 
| 291 | 
         
            +
                _can_record_outputs = {
         
     | 
| 292 | 
         
            +
                    "hidden_states": Fast_dLLM_QwenDecoderLayer,
         
     | 
| 293 | 
         
            +
                    "attentions": Fast_dLLM_QwenAttention,
         
     | 
| 294 | 
         
            +
                }
         
     | 
| 295 | 
         
            +
             
     | 
| 296 | 
         
            +
                def _init_weights(self, module):
         
     | 
| 297 | 
         
            +
                    std = self.config.initializer_range
         
     | 
| 298 | 
         
            +
                    if isinstance(module, nn.Linear):
         
     | 
| 299 | 
         
            +
                        module.weight.data.normal_(mean=0.0, std=std)
         
     | 
| 300 | 
         
            +
                        if module.bias is not None:
         
     | 
| 301 | 
         
            +
                            module.bias.data.zero_()
         
     | 
| 302 | 
         
            +
                    elif isinstance(module, nn.Embedding):
         
     | 
| 303 | 
         
            +
                        module.weight.data.normal_(mean=0.0, std=std)
         
     | 
| 304 | 
         
            +
                        if module.padding_idx is not None:
         
     | 
| 305 | 
         
            +
                            module.weight.data[module.padding_idx].zero_()
         
     | 
| 306 | 
         
            +
                    elif isinstance(module, Fast_dLLM_QwenRMSNorm):
         
     | 
| 307 | 
         
            +
                        module.weight.data.fill_(1.0)
         
     | 
| 308 | 
         
            +
             
     | 
| 309 | 
         
            +
             
     | 
| 310 | 
         
            +
            class Fast_dLLM_QwenRotaryEmbedding(nn.Module):
         
     | 
| 311 | 
         
            +
                def __init__(self, config: Fast_dLLM_QwenConfig, device=None):
         
     | 
| 312 | 
         
            +
                    super().__init__()
         
     | 
| 313 | 
         
            +
                    # BC: "rope_type" was originally "type"
         
     | 
| 314 | 
         
            +
                    if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
         
     | 
| 315 | 
         
            +
                        self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
         
     | 
| 316 | 
         
            +
                    else:
         
     | 
| 317 | 
         
            +
                        self.rope_type = "default"
         
     | 
| 318 | 
         
            +
                    self.max_seq_len_cached = config.max_position_embeddings
         
     | 
| 319 | 
         
            +
                    self.original_max_seq_len = config.max_position_embeddings
         
     | 
| 320 | 
         
            +
             
     | 
| 321 | 
         
            +
                    self.config = config
         
     | 
| 322 | 
         
            +
                    self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
         
     | 
| 323 | 
         
            +
             
     | 
| 324 | 
         
            +
                    inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         
     | 
| 325 | 
         
            +
                    self.register_buffer("inv_freq", inv_freq, persistent=False)
         
     | 
| 326 | 
         
            +
                    self.original_inv_freq = self.inv_freq
         
     | 
| 327 | 
         
            +
             
     | 
| 328 | 
         
            +
                @torch.no_grad()
         
     | 
| 329 | 
         
            +
                @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
         
     | 
| 330 | 
         
            +
                def forward(self, x, position_ids):
         
     | 
| 331 | 
         
            +
                    inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         
     | 
| 332 | 
         
            +
                    position_ids_expanded = position_ids[:, None, :].float()
         
     | 
| 333 | 
         
            +
             
     | 
| 334 | 
         
            +
                    device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         
     | 
| 335 | 
         
            +
                    with torch.autocast(device_type=device_type, enabled=False):  # Force float32
         
     | 
| 336 | 
         
            +
                        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
         
     | 
| 337 | 
         
            +
                        emb = torch.cat((freqs, freqs), dim=-1)
         
     | 
| 338 | 
         
            +
                        cos = emb.cos() * self.attention_scaling
         
     | 
| 339 | 
         
            +
                        sin = emb.sin() * self.attention_scaling
         
     | 
| 340 | 
         
            +
             
     | 
| 341 | 
         
            +
                    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
         
     | 
| 342 | 
         
            +
             
     | 
| 343 | 
         
            +
             
     | 
| 344 | 
         
            +
             
     | 
| 345 | 
         
            +
            class Fast_dLLM_QwenModel(Fast_dLLM_QwenPreTrainedModel):
         
     | 
| 346 | 
         
            +
                def __init__(self, config: Fast_dLLM_QwenConfig):
         
     | 
| 347 | 
         
            +
                    super().__init__(config)
         
     | 
| 348 | 
         
            +
                    self.padding_idx = config.pad_token_id
         
     | 
| 349 | 
         
            +
                    self.vocab_size = config.vocab_size
         
     | 
| 350 | 
         
            +
             
     | 
| 351 | 
         
            +
                    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         
     | 
| 352 | 
         
            +
                    self.layers = nn.ModuleList(
         
     | 
| 353 | 
         
            +
                        [Fast_dLLM_QwenDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         
     | 
| 354 | 
         
            +
                    )
         
     | 
| 355 | 
         
            +
                    self.norm = Fast_dLLM_QwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         
     | 
| 356 | 
         
            +
                    self.rotary_emb = Fast_dLLM_QwenRotaryEmbedding(config=config)
         
     | 
| 357 | 
         
            +
                    self.gradient_checkpointing = True
         
     | 
| 358 | 
         
            +
             
     | 
| 359 | 
         
            +
                    # Initialize weights and apply final processing
         
     | 
| 360 | 
         
            +
                    self.post_init()
         
     | 
| 361 | 
         
            +
             
     | 
| 362 | 
         
            +
                def get_input_embeddings(self):
         
     | 
| 363 | 
         
            +
                    return self.embed_tokens
         
     | 
| 364 | 
         
            +
             
     | 
| 365 | 
         
            +
                def set_input_embeddings(self, value):
         
     | 
| 366 | 
         
            +
                    self.embed_tokens = value
         
     | 
| 367 | 
         
            +
             
     | 
| 368 | 
         
            +
             
     | 
| 369 | 
         
            +
                def eval_mask(self, seqlen, block_size, cache_seq_len):
         
     | 
| 370 | 
         
            +
                    q_indices = torch.arange(seqlen) + cache_seq_len
         
     | 
| 371 | 
         
            +
                    k_indices = torch.arange(seqlen + cache_seq_len)
         
     | 
| 372 | 
         
            +
                    mask = eval_block_diff_mask(
         
     | 
| 373 | 
         
            +
                        q_idx=q_indices[:, None], 
         
     | 
| 374 | 
         
            +
                        kv_idx=k_indices[None, :], 
         
     | 
| 375 | 
         
            +
                        block_size=block_size
         
     | 
| 376 | 
         
            +
                    )
         
     | 
| 377 | 
         
            +
                    return mask
         
     | 
| 378 | 
         
            +
             
     | 
| 379 | 
         
            +
                def forward(
         
     | 
| 380 | 
         
            +
                    self,
         
     | 
| 381 | 
         
            +
                    input_ids: Optional[torch.LongTensor] = None,
         
     | 
| 382 | 
         
            +
                    labels: Optional[torch.LongTensor] = None,
         
     | 
| 383 | 
         
            +
                    attention_mask: Optional[torch.Tensor] = None,
         
     | 
| 384 | 
         
            +
                    position_ids: Optional[torch.LongTensor] = None,
         
     | 
| 385 | 
         
            +
                    past_key_values: Optional[Cache] = None,
         
     | 
| 386 | 
         
            +
                    inputs_embeds: Optional[torch.FloatTensor] = None,
         
     | 
| 387 | 
         
            +
                    use_cache: Optional[bool] = None,
         
     | 
| 388 | 
         
            +
                    cache_position: Optional[torch.LongTensor] = None,
         
     | 
| 389 | 
         
            +
                    update_past_key_values: Optional[bool] = False,
         
     | 
| 390 | 
         
            +
                    block_size: Optional[int] = 32,
         
     | 
| 391 | 
         
            +
                    use_block_cache: Optional[bool] = False,
         
     | 
| 392 | 
         
            +
                    block_past_key_values: Optional[Cache] = None,
         
     | 
| 393 | 
         
            +
                    replace_position: Optional[int] = None,
         
     | 
| 394 | 
         
            +
                    **kwargs
         
     | 
| 395 | 
         
            +
                ) -> BaseModelOutputWithPast:
         
     | 
| 396 | 
         
            +
                    if (input_ids is None) ^ (inputs_embeds is not None):
         
     | 
| 397 | 
         
            +
                        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
         
     | 
| 398 | 
         
            +
             
     | 
| 399 | 
         
            +
                    if inputs_embeds is None:
         
     | 
| 400 | 
         
            +
                        inputs_embeds = self.embed_tokens(input_ids)
         
     | 
| 401 | 
         
            +
             
     | 
| 402 | 
         
            +
                    if use_cache and past_key_values is None:
         
     | 
| 403 | 
         
            +
                        past_key_values = DynamicCache()
         
     | 
| 404 | 
         
            +
             
     | 
| 405 | 
         
            +
                    if use_block_cache and block_past_key_values is None:
         
     | 
| 406 | 
         
            +
                        block_past_key_values = DynamicCache()
         
     | 
| 407 | 
         
            +
             
     | 
| 408 | 
         
            +
                    if cache_position is None:
         
     | 
| 409 | 
         
            +
                        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         
     | 
| 410 | 
         
            +
                        if use_block_cache:
         
     | 
| 411 | 
         
            +
                            block_start_position = past_seen_tokens+replace_position if replace_position is not None else past_seen_tokens
         
     | 
| 412 | 
         
            +
                            cache_position = torch.arange(
         
     | 
| 413 | 
         
            +
                                block_start_position, block_start_position + inputs_embeds.shape[1], device=inputs_embeds.device
         
     | 
| 414 | 
         
            +
                            )
         
     | 
| 415 | 
         
            +
                        else:
         
     | 
| 416 | 
         
            +
                            cache_position = torch.arange(
         
     | 
| 417 | 
         
            +
                                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] if not self.training else inputs_embeds.shape[1]//2, device=inputs_embeds.device
         
     | 
| 418 | 
         
            +
                            )
         
     | 
| 419 | 
         
            +
             
     | 
| 420 | 
         
            +
                    if position_ids is None:
         
     | 
| 421 | 
         
            +
                        position_ids = cache_position.unsqueeze(0)
         
     | 
| 422 | 
         
            +
                    
         
     | 
| 423 | 
         
            +
                    if use_block_cache and block_past_key_values.get_seq_length() != 0:
         
     | 
| 424 | 
         
            +
                        attention_mask = None
         
     | 
| 425 | 
         
            +
                    else:
         
     | 
| 426 | 
         
            +
                        attention_mask = self.eval_mask(input_ids.shape[1], block_size, past_key_values.get_seq_length() if past_key_values is not None else 0).to(device=inputs_embeds.device)
         
     | 
| 427 | 
         
            +
             
     | 
| 428 | 
         
            +
                    hidden_states = inputs_embeds
         
     | 
| 429 | 
         
            +
             
     | 
| 430 | 
         
            +
                    # create position embeddings to be shared across the decoder layers
         
     | 
| 431 | 
         
            +
                    position_embeddings = self.rotary_emb(hidden_states, position_ids)
         
     | 
| 432 | 
         
            +
             
     | 
| 433 | 
         
            +
                    for decoder_layer in self.layers[: self.config.num_hidden_layers]:
         
     | 
| 434 | 
         
            +
                        hidden_states = decoder_layer(
         
     | 
| 435 | 
         
            +
                            hidden_states,
         
     | 
| 436 | 
         
            +
                            attention_mask=attention_mask,
         
     | 
| 437 | 
         
            +
                            position_ids=position_ids,
         
     | 
| 438 | 
         
            +
                            past_key_value=past_key_values,
         
     | 
| 439 | 
         
            +
                            use_cache=use_cache,
         
     | 
| 440 | 
         
            +
                            cache_position=cache_position,
         
     | 
| 441 | 
         
            +
                            position_embeddings=position_embeddings,
         
     | 
| 442 | 
         
            +
                            update_past_key_values=update_past_key_values,
         
     | 
| 443 | 
         
            +
                            use_block_cache=use_block_cache,
         
     | 
| 444 | 
         
            +
                            block_past_key_values=block_past_key_values,
         
     | 
| 445 | 
         
            +
                            replace_position=replace_position,
         
     | 
| 446 | 
         
            +
                            **kwargs,
         
     | 
| 447 | 
         
            +
                        )
         
     | 
| 448 | 
         
            +
             
     | 
| 449 | 
         
            +
                    hidden_states = self.norm(hidden_states)
         
     | 
| 450 | 
         
            +
                    return BaseModelOutputWithPastAndBlockCache(
         
     | 
| 451 | 
         
            +
                        last_hidden_state=hidden_states,
         
     | 
| 452 | 
         
            +
                        past_key_values=past_key_values if use_cache else None,
         
     | 
| 453 | 
         
            +
                        block_past_key_values=block_past_key_values if use_block_cache else None,
         
     | 
| 454 | 
         
            +
                    )
         
     | 
| 455 | 
         
            +
             
     | 
| 456 | 
         
            +
             
     | 
| 457 | 
         
            +
            class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
         
     | 
| 458 | 
         
            +
                _tied_weights_keys = ["lm_head.weight"]
         
     | 
| 459 | 
         
            +
                _tp_plan = {"lm_head": "colwise_rep"}
         
     | 
| 460 | 
         
            +
                _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
         
     | 
| 461 | 
         
            +
             
     | 
| 462 | 
         
            +
                def __init__(self, config):
         
     | 
| 463 | 
         
            +
                    super().__init__(config)
         
     | 
| 464 | 
         
            +
                    self.model = Fast_dLLM_QwenModel(config)
         
     | 
| 465 | 
         
            +
                    self.vocab_size = config.vocab_size
         
     | 
| 466 | 
         
            +
                    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         
     | 
| 467 | 
         
            +
             
     | 
| 468 | 
         
            +
                    # Initialize weights and apply final processing
         
     | 
| 469 | 
         
            +
                    self.post_init()
         
     | 
| 470 | 
         
            +
             
     | 
| 471 | 
         
            +
                def get_input_embeddings(self):
         
     | 
| 472 | 
         
            +
                    return self.model.embed_tokens
         
     | 
| 473 | 
         
            +
             
     | 
| 474 | 
         
            +
                def set_input_embeddings(self, value):
         
     | 
| 475 | 
         
            +
                    self.model.embed_tokens = value
         
     | 
| 476 | 
         
            +
             
     | 
| 477 | 
         
            +
                def get_output_embeddings(self):
         
     | 
| 478 | 
         
            +
                    return self.lm_head
         
     | 
| 479 | 
         
            +
             
     | 
| 480 | 
         
            +
                def set_output_embeddings(self, new_embeddings):
         
     | 
| 481 | 
         
            +
                    self.lm_head = new_embeddings
         
     | 
| 482 | 
         
            +
             
     | 
| 483 | 
         
            +
                def set_decoder(self, decoder):
         
     | 
| 484 | 
         
            +
                    self.model = decoder
         
     | 
| 485 | 
         
            +
             
     | 
| 486 | 
         
            +
                def get_decoder(self):
         
     | 
| 487 | 
         
            +
                    return self.model
         
     | 
| 488 | 
         
            +
             
     | 
| 489 | 
         
            +
                @can_return_tuple
         
     | 
| 490 | 
         
            +
                def forward(
         
     | 
| 491 | 
         
            +
                    self,
         
     | 
| 492 | 
         
            +
                    input_ids: Optional[torch.LongTensor] = None,
         
     | 
| 493 | 
         
            +
                    attention_mask: Optional[torch.Tensor] = None,
         
     | 
| 494 | 
         
            +
                    position_ids: Optional[torch.LongTensor] = None,
         
     | 
| 495 | 
         
            +
                    past_key_values: Optional[Cache] = None,
         
     | 
| 496 | 
         
            +
                    inputs_embeds: Optional[torch.FloatTensor] = None,
         
     | 
| 497 | 
         
            +
                    labels: Optional[torch.LongTensor] = None,
         
     | 
| 498 | 
         
            +
                    use_cache: Optional[bool] = None,
         
     | 
| 499 | 
         
            +
                    cache_position: Optional[torch.LongTensor] = None,
         
     | 
| 500 | 
         
            +
                    logits_to_keep: Union[int, torch.Tensor] = 0,
         
     | 
| 501 | 
         
            +
                    update_past_key_values: Optional[bool] = False,
         
     | 
| 502 | 
         
            +
                    block_size: Optional[int] = 32,
         
     | 
| 503 | 
         
            +
                    use_block_cache: Optional[bool] = False,
         
     | 
| 504 | 
         
            +
                    block_past_key_values: Optional[Cache] = None,
         
     | 
| 505 | 
         
            +
                    replace_position: Optional[int] = None,
         
     | 
| 506 | 
         
            +
                    **kwargs
         
     | 
| 507 | 
         
            +
                ) -> CausalLMOutputWithPastAndBlockCache:
         
     | 
| 508 | 
         
            +
             
     | 
| 509 | 
         
            +
                    outputs: BaseModelOutputWithPastAndBlockCache = self.model(
         
     | 
| 510 | 
         
            +
                        input_ids=input_ids,
         
     | 
| 511 | 
         
            +
                        labels=labels,
         
     | 
| 512 | 
         
            +
                        attention_mask=attention_mask,
         
     | 
| 513 | 
         
            +
                        position_ids=position_ids,
         
     | 
| 514 | 
         
            +
                        past_key_values=past_key_values,
         
     | 
| 515 | 
         
            +
                        inputs_embeds=inputs_embeds,
         
     | 
| 516 | 
         
            +
                        use_cache=use_cache,
         
     | 
| 517 | 
         
            +
                        cache_position=cache_position,
         
     | 
| 518 | 
         
            +
                        update_past_key_values=update_past_key_values,
         
     | 
| 519 | 
         
            +
                        block_size=block_size,
         
     | 
| 520 | 
         
            +
                        use_block_cache=use_block_cache,
         
     | 
| 521 | 
         
            +
                        block_past_key_values=block_past_key_values,
         
     | 
| 522 | 
         
            +
                        replace_position=replace_position,
         
     | 
| 523 | 
         
            +
                        **kwargs,
         
     | 
| 524 | 
         
            +
                    )
         
     | 
| 525 | 
         
            +
             
     | 
| 526 | 
         
            +
                    hidden_states = outputs.last_hidden_state
         
     | 
| 527 | 
         
            +
                    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
         
     | 
| 528 | 
         
            +
                    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         
     | 
| 529 | 
         
            +
                    logits = self.lm_head(hidden_states[:, slice_indices, :])
         
     | 
| 530 | 
         
            +
             
     | 
| 531 | 
         
            +
                    loss = None
         
     | 
| 532 | 
         
            +
                    if labels is not None:
         
     | 
| 533 | 
         
            +
                        loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
         
     | 
| 534 | 
         
            +
             
     | 
| 535 | 
         
            +
                    return CausalLMOutputWithPastAndBlockCache(
         
     | 
| 536 | 
         
            +
                        loss=loss,
         
     | 
| 537 | 
         
            +
                        logits=logits,
         
     | 
| 538 | 
         
            +
                        past_key_values=outputs.past_key_values,
         
     | 
| 539 | 
         
            +
                        hidden_states=outputs.hidden_states,
         
     | 
| 540 | 
         
            +
                        attentions=outputs.attentions,
         
     | 
| 541 | 
         
            +
                        block_past_key_values=outputs.block_past_key_values,
         
     | 
| 542 | 
         
            +
                    )
         
     | 
| 543 | 
         
            +
             
     | 
| 544 | 
         
            +
                @torch.no_grad()
         
     | 
| 545 | 
         
            +
                def generate(
         
     | 
| 546 | 
         
            +
                    self,
         
     | 
| 547 | 
         
            +
                    input_ids,
         
     | 
| 548 | 
         
            +
                    max_new_tokens, 
         
     | 
| 549 | 
         
            +
                    mask_id=151665,
         
     | 
| 550 | 
         
            +
                    threshold=1,
         
     | 
| 551 | 
         
            +
                    small_block_size=8,
         
     | 
| 552 | 
         
            +
                    block_size=32,
         
     | 
| 553 | 
         
            +
                    stop_token=151645,
         
     | 
| 554 | 
         
            +
                    stopping_criteria=None,
         
     | 
| 555 | 
         
            +
                    top_p=0.95,
         
     | 
| 556 | 
         
            +
                    temperature=0,
         
     | 
| 557 | 
         
            +
                    use_block_cache=False,
         
     | 
| 558 | 
         
            +
                    block_cache_refresh_interval=16,
         
     | 
| 559 | 
         
            +
                    **kwargs
         
     | 
| 560 | 
         
            +
                ):
         
     | 
| 561 | 
         
            +
                    num_blocks = max_new_tokens // block_size
         
     | 
| 562 | 
         
            +
                    original_input_length = input_ids.shape[1]
         
     | 
| 563 | 
         
            +
             
     | 
| 564 | 
         
            +
                    if input_ids.shape[1] > block_size:
         
     | 
| 565 | 
         
            +
                        output = self.forward(input_ids=input_ids[:, :(input_ids.shape[1] // block_size * block_size)], use_cache=True, update_past_key_values=True, block_size=block_size)
         
     | 
| 566 | 
         
            +
                        logits, past_key_values = output.logits, output.past_key_values
         
     | 
| 567 | 
         
            +
                        if input_ids.shape[1] % block_size == 0:
         
     | 
| 568 | 
         
            +
                            next_token = logits[:, -1:, :].argmax(dim=-1)
         
     | 
| 569 | 
         
            +
                            input_ids = torch.cat([input_ids, next_token], dim=1)
         
     | 
| 570 | 
         
            +
                    else:
         
     | 
| 571 | 
         
            +
                        past_key_values = None
         
     | 
| 572 | 
         
            +
             
     | 
| 573 | 
         
            +
                    num_small_blocks = block_size // small_block_size
         
     | 
| 574 | 
         
            +
             
     | 
| 575 | 
         
            +
                    for block_idx in range(num_blocks):
         
     | 
| 576 | 
         
            +
                        if stop_token in input_ids[:, original_input_length:]:
         
     | 
| 577 | 
         
            +
                            break
         
     | 
| 578 | 
         
            +
                        prompt_length = input_ids.shape[1]
         
     | 
| 579 | 
         
            +
                        # Initialize x_init with mask_id
         
     | 
| 580 | 
         
            +
                        x_init = mask_id * torch.ones((input_ids.shape[0], block_size-prompt_length%block_size), device=self.device, dtype=torch.long)
         
     | 
| 581 | 
         
            +
                        x_init = torch.cat([input_ids, x_init], dim=1)
         
     | 
| 582 | 
         
            +
             
     | 
| 583 | 
         
            +
                        x_t = x_init.clone()
         
     | 
| 584 | 
         
            +
                        step = 0
         
     | 
| 585 | 
         
            +
                        block_past_key_values = None
         
     | 
| 586 | 
         
            +
                        while True:
         
     | 
| 587 | 
         
            +
                            if stop_token in x_t[:, prompt_length:]:
         
     | 
| 588 | 
         
            +
                                stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
         
     | 
| 589 | 
         
            +
                                if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
         
     | 
| 590 | 
         
            +
                                    break
         
     | 
| 591 | 
         
            +
                            mask_idx = (x_t[:, -block_size:] == mask_id)
         
     | 
| 592 | 
         
            +
                            # Decode a complete block, update cache, and generate the next token
         
     | 
| 593 | 
         
            +
                            if mask_idx.sum() == 0:
         
     | 
| 594 | 
         
            +
                                output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=True, block_size=block_size)
         
     | 
| 595 | 
         
            +
                                logits, past_key_values = output.logits, output.past_key_values
         
     | 
| 596 | 
         
            +
                                next_token = logits[:, -1:, :].argmax(dim=-1)
         
     | 
| 597 | 
         
            +
                                x_t = torch.cat([x_t, next_token], dim=1)
         
     | 
| 598 | 
         
            +
                                break
         
     | 
| 599 | 
         
            +
                            for small_block_idx in range(num_small_blocks):
         
     | 
| 600 | 
         
            +
                                small_block_start_idx = small_block_idx * small_block_size
         
     | 
| 601 | 
         
            +
                                small_block_end_idx = small_block_start_idx + small_block_size
         
     | 
| 602 | 
         
            +
             
     | 
| 603 | 
         
            +
                                start = -block_size + small_block_start_idx
         
     | 
| 604 | 
         
            +
                                end = None if block_size == small_block_end_idx else -block_size + small_block_end_idx
         
     | 
| 605 | 
         
            +
                                while True:
         
     | 
| 606 | 
         
            +
                                    mask_idx = (x_t[:, -block_size:] == mask_id)
         
     | 
| 607 | 
         
            +
                                    if mask_idx[:, start:end].sum() == 0:
         
     | 
| 608 | 
         
            +
                                        break
         
     | 
| 609 | 
         
            +
                                    if stop_token in x_t[:, prompt_length:]:
         
     | 
| 610 | 
         
            +
                                        stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
         
     | 
| 611 | 
         
            +
                                        if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
         
     | 
| 612 | 
         
            +
                                            break
         
     | 
| 613 | 
         
            +
             
     | 
| 614 | 
         
            +
                                    if use_block_cache:
         
     | 
| 615 | 
         
            +
                                        if step % block_cache_refresh_interval == 0 or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
         
     | 
| 616 | 
         
            +
                                            output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True)
         
     | 
| 617 | 
         
            +
                                            logits, block_past_key_values = output.logits, output.block_past_key_values
         
     | 
| 618 | 
         
            +
                                            logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
         
     | 
| 619 | 
         
            +
                                            logits = logits[:, start:end]
         
     | 
| 620 | 
         
            +
                                        else:
         
     | 
| 621 | 
         
            +
                                            logits = self.forward(input_ids=x_t[:,start:end], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True, block_past_key_values=block_past_key_values, replace_position=small_block_start_idx).logits
         
     | 
| 622 | 
         
            +
                                            logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
         
     | 
| 623 | 
         
            +
                                    else:
         
     | 
| 624 | 
         
            +
                                        logits = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False).logits
         
     | 
| 625 | 
         
            +
                                        logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
         
     | 
| 626 | 
         
            +
                                        logits = logits[:, start:end]
         
     | 
| 627 | 
         
            +
             
     | 
| 628 | 
         
            +
             
     | 
| 629 | 
         
            +
                                    x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
         
     | 
| 630 | 
         
            +
                                    # Select tokens with probability greater than threshold from p_1t
         
     | 
| 631 | 
         
            +
                                    x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
         
     | 
| 632 | 
         
            +
                                    x1_p = torch.where(mask_idx[:, start:end], x1_p, -torch.inf)
         
     | 
| 633 | 
         
            +
             
     | 
| 634 | 
         
            +
                                    unmask_idx = (x1_p > threshold)
         
     | 
| 635 | 
         
            +
                                    max_prob_idx = x1_p.argmax(dim=-1)
         
     | 
| 636 | 
         
            +
                                    unmask_idx[torch.arange(x_1.shape[0]), max_prob_idx] = True
         
     | 
| 637 | 
         
            +
                                    unmask_idx = unmask_idx & mask_idx[:, start:end]
         
     | 
| 638 | 
         
            +
             
     | 
| 639 | 
         
            +
                                    x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
         
     | 
| 640 | 
         
            +
             
     | 
| 641 | 
         
            +
                                    step += 1
         
     | 
| 642 | 
         
            +
                        input_ids = x_t
         
     | 
| 643 | 
         
            +
                    # Truncate stop_token
         
     | 
| 644 | 
         
            +
                    if stop_token in input_ids[:, original_input_length:]:
         
     | 
| 645 | 
         
            +
                        stop_token_idx = (input_ids[:, original_input_length:] == stop_token).nonzero()[0][1]
         
     | 
| 646 | 
         
            +
                        input_ids = input_ids[:, :stop_token_idx+original_input_length+1]
         
     | 
| 647 | 
         
            +
                    return input_ids
         
     | 
| 648 | 
         
            +
             
     | 
| 649 | 
         
            +
                def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
         
     | 
| 650 | 
         
            +
                    # Calculate probabilities
         
     | 
| 651 | 
         
            +
                    if temperature > 0:
         
     | 
| 652 | 
         
            +
                        scaled_logits = logits / temperature
         
     | 
| 653 | 
         
            +
                    else:
         
     | 
| 654 | 
         
            +
                        p_1t = torch.softmax(logits, dim=-1)
         
     | 
| 655 | 
         
            +
                        x_1 = p_1t.argmax(dim=-1)
         
     | 
| 656 | 
         
            +
                        return x_1, p_1t
         
     | 
| 657 | 
         
            +
                                        
         
     | 
| 658 | 
         
            +
                    probs = F.softmax(scaled_logits, dim=-1)
         
     | 
| 659 | 
         
            +
             
     | 
| 660 | 
         
            +
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
         
     | 
| 661 | 
         
            +
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
         
     | 
| 662 | 
         
            +
             
     | 
| 663 | 
         
            +
                    sorted_indices_to_remove = cumulative_probs > top_p
         
     | 
| 664 | 
         
            +
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
         
     | 
| 665 | 
         
            +
                    sorted_indices_to_remove[..., 0] = 0
         
     | 
| 666 | 
         
            +
             
     | 
| 667 | 
         
            +
                    indices_to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(
         
     | 
| 668 | 
         
            +
                        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
         
     | 
| 669 | 
         
            +
                    )
         
     | 
| 670 | 
         
            +
                    
         
     | 
| 671 | 
         
            +
                    probs[indices_to_remove] = 0
         
     | 
| 672 | 
         
            +
             
     | 
| 673 | 
         
            +
                    # Renormalize so that the probabilities of remaining tokens sum to 1
         
     | 
| 674 | 
         
            +
                    # Add a small epsilon value to prevent division by zero
         
     | 
| 675 | 
         
            +
                    probs_sum = torch.sum(probs, dim=-1, keepdim=True)
         
     | 
| 676 | 
         
            +
                    normalized_probs = probs / probs_sum
         
     | 
| 677 | 
         
            +
             
     | 
| 678 | 
         
            +
                    p_1t = normalized_probs
         
     | 
| 679 | 
         
            +
                    x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)
         
     | 
| 680 | 
         
            +
             
     | 
| 681 | 
         
            +
                    return x_1, p_1t
         
     | 
    	
        special_tokens_map.json
    ADDED
    
    | 
         @@ -0,0 +1,25 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {
         
     | 
| 2 | 
         
            +
              "additional_special_tokens": [
         
     | 
| 3 | 
         
            +
                {
         
     | 
| 4 | 
         
            +
                  "content": "|<MASK>|",
         
     | 
| 5 | 
         
            +
                  "lstrip": false,
         
     | 
| 6 | 
         
            +
                  "normalized": false,
         
     | 
| 7 | 
         
            +
                  "rstrip": false,
         
     | 
| 8 | 
         
            +
                  "single_word": false
         
     | 
| 9 | 
         
            +
                }
         
     | 
| 10 | 
         
            +
              ],
         
     | 
| 11 | 
         
            +
              "eos_token": {
         
     | 
| 12 | 
         
            +
                "content": "<|im_end|>",
         
     | 
| 13 | 
         
            +
                "lstrip": false,
         
     | 
| 14 | 
         
            +
                "normalized": false,
         
     | 
| 15 | 
         
            +
                "rstrip": false,
         
     | 
| 16 | 
         
            +
                "single_word": false
         
     | 
| 17 | 
         
            +
              },
         
     | 
| 18 | 
         
            +
              "pad_token": {
         
     | 
| 19 | 
         
            +
                "content": "<|endoftext|>",
         
     | 
| 20 | 
         
            +
                "lstrip": false,
         
     | 
| 21 | 
         
            +
                "normalized": false,
         
     | 
| 22 | 
         
            +
                "rstrip": false,
         
     | 
| 23 | 
         
            +
                "single_word": false
         
     | 
| 24 | 
         
            +
              }
         
     | 
| 25 | 
         
            +
            }
         
     | 
    	
        tokenizer.json
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:cb2105b66192c5a532e2a098dc899df86eca233b4faa48461211e4312c8b3568
         
     | 
| 3 | 
         
            +
            size 11422081
         
     | 
    	
        tokenizer_config.json
    ADDED
    
    | 
         @@ -0,0 +1,204 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {
         
     | 
| 2 | 
         
            +
              "add_bos_token": false,
         
     | 
| 3 | 
         
            +
              "add_prefix_space": false,
         
     | 
| 4 | 
         
            +
              "added_tokens_decoder": {
         
     | 
| 5 | 
         
            +
                "151643": {
         
     | 
| 6 | 
         
            +
                  "content": "<|endoftext|>",
         
     | 
| 7 | 
         
            +
                  "lstrip": false,
         
     | 
| 8 | 
         
            +
                  "normalized": false,
         
     | 
| 9 | 
         
            +
                  "rstrip": false,
         
     | 
| 10 | 
         
            +
                  "single_word": false,
         
     | 
| 11 | 
         
            +
                  "special": true
         
     | 
| 12 | 
         
            +
                },
         
     | 
| 13 | 
         
            +
                "151644": {
         
     | 
| 14 | 
         
            +
                  "content": "<|im_start|>",
         
     | 
| 15 | 
         
            +
                  "lstrip": false,
         
     | 
| 16 | 
         
            +
                  "normalized": false,
         
     | 
| 17 | 
         
            +
                  "rstrip": false,
         
     | 
| 18 | 
         
            +
                  "single_word": false,
         
     | 
| 19 | 
         
            +
                  "special": true
         
     | 
| 20 | 
         
            +
                },
         
     | 
| 21 | 
         
            +
                "151645": {
         
     | 
| 22 | 
         
            +
                  "content": "<|im_end|>",
         
     | 
| 23 | 
         
            +
                  "lstrip": false,
         
     | 
| 24 | 
         
            +
                  "normalized": false,
         
     | 
| 25 | 
         
            +
                  "rstrip": false,
         
     | 
| 26 | 
         
            +
                  "single_word": false,
         
     | 
| 27 | 
         
            +
                  "special": true
         
     | 
| 28 | 
         
            +
                },
         
     | 
| 29 | 
         
            +
                "151646": {
         
     | 
| 30 | 
         
            +
                  "content": "<|object_ref_start|>",
         
     | 
| 31 | 
         
            +
                  "lstrip": false,
         
     | 
| 32 | 
         
            +
                  "normalized": false,
         
     | 
| 33 | 
         
            +
                  "rstrip": false,
         
     | 
| 34 | 
         
            +
                  "single_word": false,
         
     | 
| 35 | 
         
            +
                  "special": true
         
     | 
| 36 | 
         
            +
                },
         
     | 
| 37 | 
         
            +
                "151647": {
         
     | 
| 38 | 
         
            +
                  "content": "<|object_ref_end|>",
         
     | 
| 39 | 
         
            +
                  "lstrip": false,
         
     | 
| 40 | 
         
            +
                  "normalized": false,
         
     | 
| 41 | 
         
            +
                  "rstrip": false,
         
     | 
| 42 | 
         
            +
                  "single_word": false,
         
     | 
| 43 | 
         
            +
                  "special": true
         
     | 
| 44 | 
         
            +
                },
         
     | 
| 45 | 
         
            +
                "151648": {
         
     | 
| 46 | 
         
            +
                  "content": "<|box_start|>",
         
     | 
| 47 | 
         
            +
                  "lstrip": false,
         
     | 
| 48 | 
         
            +
                  "normalized": false,
         
     | 
| 49 | 
         
            +
                  "rstrip": false,
         
     | 
| 50 | 
         
            +
                  "single_word": false,
         
     | 
| 51 | 
         
            +
                  "special": true
         
     | 
| 52 | 
         
            +
                },
         
     | 
| 53 | 
         
            +
                "151649": {
         
     | 
| 54 | 
         
            +
                  "content": "<|box_end|>",
         
     | 
| 55 | 
         
            +
                  "lstrip": false,
         
     | 
| 56 | 
         
            +
                  "normalized": false,
         
     | 
| 57 | 
         
            +
                  "rstrip": false,
         
     | 
| 58 | 
         
            +
                  "single_word": false,
         
     | 
| 59 | 
         
            +
                  "special": true
         
     | 
| 60 | 
         
            +
                },
         
     | 
| 61 | 
         
            +
                "151650": {
         
     | 
| 62 | 
         
            +
                  "content": "<|quad_start|>",
         
     | 
| 63 | 
         
            +
                  "lstrip": false,
         
     | 
| 64 | 
         
            +
                  "normalized": false,
         
     | 
| 65 | 
         
            +
                  "rstrip": false,
         
     | 
| 66 | 
         
            +
                  "single_word": false,
         
     | 
| 67 | 
         
            +
                  "special": true
         
     | 
| 68 | 
         
            +
                },
         
     | 
| 69 | 
         
            +
                "151651": {
         
     | 
| 70 | 
         
            +
                  "content": "<|quad_end|>",
         
     | 
| 71 | 
         
            +
                  "lstrip": false,
         
     | 
| 72 | 
         
            +
                  "normalized": false,
         
     | 
| 73 | 
         
            +
                  "rstrip": false,
         
     | 
| 74 | 
         
            +
                  "single_word": false,
         
     | 
| 75 | 
         
            +
                  "special": true
         
     | 
| 76 | 
         
            +
                },
         
     | 
| 77 | 
         
            +
                "151652": {
         
     | 
| 78 | 
         
            +
                  "content": "<|vision_start|>",
         
     | 
| 79 | 
         
            +
                  "lstrip": false,
         
     | 
| 80 | 
         
            +
                  "normalized": false,
         
     | 
| 81 | 
         
            +
                  "rstrip": false,
         
     | 
| 82 | 
         
            +
                  "single_word": false,
         
     | 
| 83 | 
         
            +
                  "special": true
         
     | 
| 84 | 
         
            +
                },
         
     | 
| 85 | 
         
            +
                "151653": {
         
     | 
| 86 | 
         
            +
                  "content": "<|vision_end|>",
         
     | 
| 87 | 
         
            +
                  "lstrip": false,
         
     | 
| 88 | 
         
            +
                  "normalized": false,
         
     | 
| 89 | 
         
            +
                  "rstrip": false,
         
     | 
| 90 | 
         
            +
                  "single_word": false,
         
     | 
| 91 | 
         
            +
                  "special": true
         
     | 
| 92 | 
         
            +
                },
         
     | 
| 93 | 
         
            +
                "151654": {
         
     | 
| 94 | 
         
            +
                  "content": "<|vision_pad|>",
         
     | 
| 95 | 
         
            +
                  "lstrip": false,
         
     | 
| 96 | 
         
            +
                  "normalized": false,
         
     | 
| 97 | 
         
            +
                  "rstrip": false,
         
     | 
| 98 | 
         
            +
                  "single_word": false,
         
     | 
| 99 | 
         
            +
                  "special": true
         
     | 
| 100 | 
         
            +
                },
         
     | 
| 101 | 
         
            +
                "151655": {
         
     | 
| 102 | 
         
            +
                  "content": "<|image_pad|>",
         
     | 
| 103 | 
         
            +
                  "lstrip": false,
         
     | 
| 104 | 
         
            +
                  "normalized": false,
         
     | 
| 105 | 
         
            +
                  "rstrip": false,
         
     | 
| 106 | 
         
            +
                  "single_word": false,
         
     | 
| 107 | 
         
            +
                  "special": true
         
     | 
| 108 | 
         
            +
                },
         
     | 
| 109 | 
         
            +
                "151656": {
         
     | 
| 110 | 
         
            +
                  "content": "<|video_pad|>",
         
     | 
| 111 | 
         
            +
                  "lstrip": false,
         
     | 
| 112 | 
         
            +
                  "normalized": false,
         
     | 
| 113 | 
         
            +
                  "rstrip": false,
         
     | 
| 114 | 
         
            +
                  "single_word": false,
         
     | 
| 115 | 
         
            +
                  "special": true
         
     | 
| 116 | 
         
            +
                },
         
     | 
| 117 | 
         
            +
                "151657": {
         
     | 
| 118 | 
         
            +
                  "content": "<tool_call>",
         
     | 
| 119 | 
         
            +
                  "lstrip": false,
         
     | 
| 120 | 
         
            +
                  "normalized": false,
         
     | 
| 121 | 
         
            +
                  "rstrip": false,
         
     | 
| 122 | 
         
            +
                  "single_word": false,
         
     | 
| 123 | 
         
            +
                  "special": false
         
     | 
| 124 | 
         
            +
                },
         
     | 
| 125 | 
         
            +
                "151658": {
         
     | 
| 126 | 
         
            +
                  "content": "</tool_call>",
         
     | 
| 127 | 
         
            +
                  "lstrip": false,
         
     | 
| 128 | 
         
            +
                  "normalized": false,
         
     | 
| 129 | 
         
            +
                  "rstrip": false,
         
     | 
| 130 | 
         
            +
                  "single_word": false,
         
     | 
| 131 | 
         
            +
                  "special": false
         
     | 
| 132 | 
         
            +
                },
         
     | 
| 133 | 
         
            +
                "151659": {
         
     | 
| 134 | 
         
            +
                  "content": "<|fim_prefix|>",
         
     | 
| 135 | 
         
            +
                  "lstrip": false,
         
     | 
| 136 | 
         
            +
                  "normalized": false,
         
     | 
| 137 | 
         
            +
                  "rstrip": false,
         
     | 
| 138 | 
         
            +
                  "single_word": false,
         
     | 
| 139 | 
         
            +
                  "special": false
         
     | 
| 140 | 
         
            +
                },
         
     | 
| 141 | 
         
            +
                "151660": {
         
     | 
| 142 | 
         
            +
                  "content": "<|fim_middle|>",
         
     | 
| 143 | 
         
            +
                  "lstrip": false,
         
     | 
| 144 | 
         
            +
                  "normalized": false,
         
     | 
| 145 | 
         
            +
                  "rstrip": false,
         
     | 
| 146 | 
         
            +
                  "single_word": false,
         
     | 
| 147 | 
         
            +
                  "special": false
         
     | 
| 148 | 
         
            +
                },
         
     | 
| 149 | 
         
            +
                "151661": {
         
     | 
| 150 | 
         
            +
                  "content": "<|fim_suffix|>",
         
     | 
| 151 | 
         
            +
                  "lstrip": false,
         
     | 
| 152 | 
         
            +
                  "normalized": false,
         
     | 
| 153 | 
         
            +
                  "rstrip": false,
         
     | 
| 154 | 
         
            +
                  "single_word": false,
         
     | 
| 155 | 
         
            +
                  "special": false
         
     | 
| 156 | 
         
            +
                },
         
     | 
| 157 | 
         
            +
                "151662": {
         
     | 
| 158 | 
         
            +
                  "content": "<|fim_pad|>",
         
     | 
| 159 | 
         
            +
                  "lstrip": false,
         
     | 
| 160 | 
         
            +
                  "normalized": false,
         
     | 
| 161 | 
         
            +
                  "rstrip": false,
         
     | 
| 162 | 
         
            +
                  "single_word": false,
         
     | 
| 163 | 
         
            +
                  "special": false
         
     | 
| 164 | 
         
            +
                },
         
     | 
| 165 | 
         
            +
                "151663": {
         
     | 
| 166 | 
         
            +
                  "content": "<|repo_name|>",
         
     | 
| 167 | 
         
            +
                  "lstrip": false,
         
     | 
| 168 | 
         
            +
                  "normalized": false,
         
     | 
| 169 | 
         
            +
                  "rstrip": false,
         
     | 
| 170 | 
         
            +
                  "single_word": false,
         
     | 
| 171 | 
         
            +
                  "special": false
         
     | 
| 172 | 
         
            +
                },
         
     | 
| 173 | 
         
            +
                "151664": {
         
     | 
| 174 | 
         
            +
                  "content": "<|file_sep|>",
         
     | 
| 175 | 
         
            +
                  "lstrip": false,
         
     | 
| 176 | 
         
            +
                  "normalized": false,
         
     | 
| 177 | 
         
            +
                  "rstrip": false,
         
     | 
| 178 | 
         
            +
                  "single_word": false,
         
     | 
| 179 | 
         
            +
                  "special": false
         
     | 
| 180 | 
         
            +
                },
         
     | 
| 181 | 
         
            +
                "151665": {
         
     | 
| 182 | 
         
            +
                  "content": "|<MASK>|",
         
     | 
| 183 | 
         
            +
                  "lstrip": false,
         
     | 
| 184 | 
         
            +
                  "normalized": false,
         
     | 
| 185 | 
         
            +
                  "rstrip": false,
         
     | 
| 186 | 
         
            +
                  "single_word": false,
         
     | 
| 187 | 
         
            +
                  "special": true
         
     | 
| 188 | 
         
            +
                }
         
     | 
| 189 | 
         
            +
              },
         
     | 
| 190 | 
         
            +
              "additional_special_tokens": [
         
     | 
| 191 | 
         
            +
                "|<MASK>|"
         
     | 
| 192 | 
         
            +
              ],
         
     | 
| 193 | 
         
            +
              "bos_token": null,
         
     | 
| 194 | 
         
            +
              "clean_up_tokenization_spaces": false,
         
     | 
| 195 | 
         
            +
              "eos_token": "<|im_end|>",
         
     | 
| 196 | 
         
            +
              "errors": "replace",
         
     | 
| 197 | 
         
            +
              "extra_special_tokens": {},
         
     | 
| 198 | 
         
            +
              "model_max_length": 131072,
         
     | 
| 199 | 
         
            +
              "pad_token": "<|endoftext|>",
         
     | 
| 200 | 
         
            +
              "padding_side": "right",
         
     | 
| 201 | 
         
            +
              "split_special_tokens": false,
         
     | 
| 202 | 
         
            +
              "tokenizer_class": "Qwen2Tokenizer",
         
     | 
| 203 | 
         
            +
              "unk_token": null
         
     | 
| 204 | 
         
            +
            }
         
     | 
    	
        vocab.json
    ADDED
    
    | 
         The diff for this file is too large to render. 
		See raw diff 
     | 
| 
         |