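"""Convert a GGUF model file into Hugging Face-style sharded .safetensors files.

Tensors are dequantized to FP16 (or BF16 with --bf16) and renamed from the GGUF
naming scheme to the Hugging Face Llama/Mistral layout.
"""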
import os
import argparse
import torch
import numpy as np
from safetensors.torch import save_file
from safetensors import safe_open
from typing import Dict, Tuple, List
from gguf import GGUFReader, dequantize

def load_gguf_and_extract_metadata(gguf_path: str) -> Tuple[GGUFReader, List[Dict]]:
    """Load GGUF file and extract metadata for all tensors."""
    print(f"Loading GGUF file: {gguf_path}")
    reader = GGUFReader(gguf_path, 'r')
    tensors_metadata = []
    for tensor in reader.tensors:
        tensor_metadata = {
            'name': tensor.name,
            'shape': tuple(tensor.shape.tolist()),
            'n_elements': tensor.n_elements,
            'n_bytes': tensor.n_bytes,
            'type': tensor.tensor_type,
        }
        tensors_metadata.append(tensor_metadata)
    return reader, tensors_metadata

def get_dequantized_tensor_size_in_bytes(tensor_info: Dict, use_bf16: bool) -> int:
    """Calculates the size of a tensor after it has been dequantized to FP16 or BF16."""
    bytes_per_element = 2
    return tensor_info['n_elements'] * bytes_per_element

def get_hf_name(gguf_name: str) -> str:
    """Translates a GGUF tensor name to its Hugging Face equivalent for Llama/Mistral models."""
    name_map = {
        "token_embd.weight": "model.embed_tokens.weight",
        "output_norm.weight": "model.norm.weight",
        "output.weight": "lm_head.weight",
    }
    if gguf_name in name_map:
        return name_map[gguf_name]

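    # Per-layer tensors: "blk.<N>.<suffix>" becomes "model.layers.<N>.<hf_suffix>".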
    if gguf_name.startswith("blk."):
        parts = gguf_name.split('.')
        layer_num = parts[1]
        layer_part = ".".join(parts[2:])
        
        block_map = {
            "attn_norm.weight": "input_layernorm.weight",
            "ffn_norm.weight": "post_attention_layernorm.weight",
            "attn_q.weight": "self_attn.q_proj.weight",
            "attn_k.weight": "self_attn.k_proj.weight",
            "attn_v.weight": "self_attn.v_proj.weight",
            "attn_output.weight": "self_attn.o_proj.weight",
            "ffn_gate.weight": "mlp.gate_proj.weight",
            "ffn_up.weight": "mlp.up_proj.weight",
            "ffn_down.weight": "mlp.down_proj.weight",
        }
        
        if layer_part in block_map:
            return f"model.layers.{layer_num}.{block_map[layer_part]}"

    print(f"Warning: No mapping found for tensor '{gguf_name}'. Using original name.")
    return gguf_name

def convert_gguf_to_safetensors_by_size(gguf_path: str, output_path: str, use_bf16: bool, shard_size_gb: float):
    """Converts a GGUF file to .safetensors, sharding and renaming tensors for HF compatibility."""
    reader, tensors_metadata = load_gguf_and_extract_metadata(gguf_path)
    print(f"Extracted metadata for {len(tensors_metadata)} tensors from GGUF file.")

    shard_size_bytes = int(shard_size_gb * 1024**3)
    print(f"Target shard size set to ~{shard_size_gb} GB ({shard_size_bytes} bytes).")

    output_dir = os.path.dirname(output_path)
    if not output_dir:
        output_dir = "."
    base_name = os.path.basename(output_path).replace('.safetensors', '')
    
    tensors_in_current_chunk: Dict[str, torch.Tensor] = {}
    current_chunk_size_bytes = 0
    num_chunks = 0
    
    total_shards = 0
    temp_size = 0
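    # First pass: simulate the sharding to determine the total shard count up front,
    # so every filename can carry the correct "-of-NNNNN" suffix.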
    for tensor_info in tensors_metadata:
        dequantized_size = get_dequantized_tensor_size_in_bytes(tensor_info, use_bf16)
        if temp_size > 0 and (temp_size + dequantized_size) > shard_size_bytes:
            total_shards += 1
            temp_size = 0
        temp_size += dequantized_size
    if temp_size > 0:
        total_shards += 1
    print(f"Model will be split into {total_shards} shards.")

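    # Second pass: dequantize, rename, and accumulate tensors, flushing a shard to
    # disk whenever adding the next tensor would exceed the size limit.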
    for i, tensor_info in enumerate(tensors_metadata):
        gguf_tensor_name = tensor_info['name']
        dequantized_size = get_dequantized_tensor_size_in_bytes(tensor_info, use_bf16)

        if current_chunk_size_bytes > 0 and (current_chunk_size_bytes + dequantized_size) > shard_size_bytes:
            num_chunks += 1
            chunk_path = os.path.join(output_dir, f"{base_name}-{num_chunks:05d}-of-{total_shards:05d}.safetensors")
            
            print(f"\nCurrent chunk size ({current_chunk_size_bytes / 1024**3:.2f} GB) exceeds limit.")
            print(f"Saving chunk {num_chunks} with {len(tensors_in_current_chunk)} tensors to {chunk_path}...\n")
            save_file(tensors_in_current_chunk, chunk_path)
            
            tensors_in_current_chunk.clear()
            current_chunk_size_bytes = 0

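        # Dequantize the raw GGUF tensor data (e.g. Q4_K, Q8_0) to a floating-point
        # NumPy array, then cast to the requested 16-bit dtype.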
        tensor_data = reader.get_tensor(i)
        weights_np = dequantize(tensor_data.data, tensor_data.tensor_type).copy()
        target_dtype = torch.bfloat16 if use_bf16 else torch.float16
        
        try:
            weights_tensor = torch.from_numpy(weights_np).to(target_dtype)
        except Exception as e:
            print(f"Warning: Could not convert {gguf_tensor_name} directly. Error: {e}. Using float32 fallback.")
            weights_tensor = torch.from_numpy(weights_np.astype(np.float32)).to(target_dtype)

        # Rename the tensor from the GGUF naming scheme to the Hugging Face scheme.
        hf_tensor_name = get_hf_name(gguf_tensor_name)
        
        print(f"Processed tensor ({i+1}/{len(tensors_metadata)}): {gguf_tensor_name} -> {hf_tensor_name} | Size: {dequantized_size/1024**2:.2f} MB")
        
        tensors_in_current_chunk[hf_tensor_name] = weights_tensor
        current_chunk_size_bytes += dequantized_size
        
        del weights_np
        del tensor_data

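    # Flush whatever remains as the final shard.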
    if tensors_in_current_chunk:
        num_chunks += 1
        chunk_path = os.path.join(output_dir, f"{base_name}-{num_chunks:05d}-of-{total_shards:05d}.safetensors")
        print(f"\nSaving final chunk {num_chunks} with {len(tensors_in_current_chunk)} tensors to {chunk_path}...\n")
        save_file(tensors_in_current_chunk, chunk_path)

    print("All tensors have been dequantized, renamed, and saved into sharded safetensor files.")

def main():
    parser = argparse.ArgumentParser(
        description="Convert GGUF to HF-compatible sharded safetensors, renaming tensors correctly."
    )
    parser.add_argument("--input", required=True, help="Path to the input GGUF file.")
    parser.add_argument("--output", required=True, help="Base path for the final output sharded .safetensors files.")
    parser.add_argument("--bf16", action="store_true", help="Convert tensors to BF16 format instead of the default FP16.")
    parser.add_argument(
        "--shard-size", 
        type=float, 
        default=5.0, 
        help="Maximum size of each shard in Gigabytes (GB). Default: 5.0"
    )
    args = parser.parse_args()

    convert_gguf_to_safetensors_by_size(args.input, args.output, args.bf16, args.shard_size)

if __name__ == "__main__":
    main()