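"""Reconstruct Hugging Face-style metadata for a sharded safetensors model.

Reads the header of the original GGUF file and writes config.json,
tokenizer_config.json, tokenizer.json, special_tokens_map.json, and
generation_config.json into the shards directory, then builds
model.safetensors.index.json from the shard files themselves.

Example invocation (script and path names are placeholders):

    python rip_metadata.py --gguf-file model.gguf --shards-dir ./shards
"""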
import os
import argparse
import json
import glob
from safetensors import safe_open
from gguf import GGUFReader
from gguf.constants import Keys, GGUFValueType
from typing import Any, Dict
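
# GGUFReader has no typed accessors such as get_int_value(); get_field()
# returns a raw ReaderField whose payload sits in .parts, indexed by .data
# and typed by .types. The helper below is a minimal sketch against that
# gguf-py layout, decoding only the field shapes this script reads:
# strings, numeric scalars, and string arrays.
def _field_value(reader: GGUFReader, key: str) -> Any:
    field = reader.get_field(key)
    if field is None:
        raise KeyError(f"GGUF metadata key not found: {key}")
    if field.types[0] == GGUFValueType.STRING:
        return bytes(field.parts[field.data[0]]).decode("utf-8")
    if field.types[0] == GGUFValueType.ARRAY:
        if field.types[1] == GGUFValueType.STRING:
            return [bytes(field.parts[i]).decode("utf-8") for i in field.data]
        raise TypeError(f"Unhandled GGUF array type for key: {key}")
    # Numeric scalars (ints, floats, bools) arrive as one-element arrays.
    return field.parts[field.data[0]][0].item()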

def create_safetensors_index(shards_dir: str, output_dir: str) -> None:
    """Creates the model.safetensors.index.json file by scanning shard files."""
    shard_pattern = os.path.join(shards_dir, '*.safetensors')
    shard_files = sorted(glob.glob(shard_pattern))

    if not shard_files:
        print(f"Error: No .safetensors files found in directory: {shards_dir}")
        return

    print(f"Found {len(shard_files)} shard files to index.")

    index_data: Dict[str, Any] = {"metadata": {}, "weight_map": {}}
    total_size = 0

    for shard_file in shard_files:
        shard_basename = os.path.basename(shard_file)
        try:
            # framework="np" avoids pulling in torch; only key names are read.
            with safe_open(shard_file, framework="np", device="cpu") as f:
                for tensor_name in f.keys():
                    index_data["weight_map"][tensor_name] = shard_basename
            
            shard_size = os.path.getsize(shard_file)
            total_size += shard_size
        except Exception as e:
            print(f"Warning: Could not process shard {shard_basename}. Error: {e}")
            continue

    index_data["metadata"]["total_size"] = total_size
    
    index_filepath = os.path.join(output_dir, "model.safetensors.index.json")
    try:
        with open(index_filepath, 'w', encoding='utf-8') as f:
            json.dump(index_data, f, indent=2)
        print(f"Successfully created safetensors index file: {index_filepath}")
    except Exception as e:
        print(f"Error: Failed to write index file. Error: {e}")

def extract_and_save_gguf_configs(reader: GGUFReader, output_dir: str) -> None:
    """Extracts metadata from GGUF and saves config, tokenizer, and generation files."""
    
    config: Dict[str, Any] = {}
    # Decode the vocab once up front; tokenizer.json and
    # special_tokens_map.json both need it even if config.json fails.
    try:
        vocab = _field_value(reader, Keys.Tokenizer.LIST)
    except Exception as e:
        vocab = []
        print(f"Warning: Could not read tokenizer vocab. Error: {e}")

    # --- config.json ---
    try:
        # get_field(...).name is the key string, not the value, so decode it.
        arch = _field_value(reader, Keys.General.ARCHITECTURE).lower()
        # GGUF metadata keys are prefixed with the architecture name
        # (e.g. "llama.embedding_length"); it doubles as the HF model_type.
        model_type = arch

        config = {
            # HF convention, e.g. "LlamaForCausalLM".
            "architectures": [f"{arch.capitalize()}ForCausalLM"],
            "model_type": model_type,
            "hidden_size": _field_value(reader, f"{arch}.embedding_length"),
            "intermediate_size": _field_value(reader, f"{arch}.feed_forward_length"),
            "num_attention_heads": _field_value(reader, f"{arch}.attention.head_count"),
            "num_hidden_layers": _field_value(reader, f"{arch}.block_count"),
            "num_key_value_heads": _field_value(reader, f"{arch}.attention.head_count_kv"),
            "rms_norm_eps": _field_value(reader, f"{arch}.attention.layer_norm_rms_epsilon"),
            "vocab_size": len(vocab),
            "rope_theta": _field_value(reader, f"{arch}.rope.freq_base"),
            "max_position_embeddings": _field_value(reader, f"{arch}.context_length"),
        }
        with open(os.path.join(output_dir, "config.json"), 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)
        print("Created config.json")
    except Exception as e:
        print(f"Warning: Could not create config.json. Some values may be missing. Error: {e}")

    # --- tokenizer_config.json ---
    try:
        tokenizer_config = {
            "model_max_length": config.get("max_position_embeddings", 4096),
            "padding_side": "left",
            # Assumption: a Llama-family tokenizer; adjust for other models.
            "tokenizer_class": "LlamaTokenizer",
        }
        # Add the chat template if the GGUF carries one.
        try:
            tokenizer_config["chat_template"] = _field_value(reader, "tokenizer.chat_template")
        except KeyError:
            pass  # Field does not exist
        
        with open(os.path.join(output_dir, "tokenizer_config.json"), 'w', encoding='utf-8') as f:
            json.dump(tokenizer_config, f, indent=2)
        print("Created tokenizer_config.json")
    except Exception as e:
        print(f"Warning: Could not create tokenizer_config.json. Error: {e}")

    # --- tokenizer.json ---
    try:
        # vocab was decoded once at the top of this function; merges is a flat
        # list of merge-rule strings. This assumes a BPE tokenizer; a
        # SentencePiece-based GGUF has no merges and only prints the warning below.
        merges = _field_value(reader, Keys.Tokenizer.MERGES)
        
        tokenizer_data = {
            "version": "1.0",
            "model": {
                "type": "BPE",
                "vocab": {token: i for i, token in enumerate(vocab)},
                "merges": merges,
            },
            "added_tokens": [],
        }
        with open(os.path.join(output_dir, "tokenizer.json"), 'w', encoding='utf-8') as f:
            json.dump(tokenizer_data, f, indent=None, separators=(',', ':'))
        print("Created tokenizer.json")
    except Exception as e:
        print(f"Warning: Could not create tokenizer.json. Error: {e}")

    # --- special_tokens_map.json ---
    try:
        special_map = {}
        # Helper so a missing key or out-of-range id doesn't abort the map.
        def add_special_token(key_name, gguf_id_key):
            try:
                token_id = _field_value(reader, gguf_id_key)
                special_map[key_name] = vocab[token_id]
            except (KeyError, ValueError, IndexError):
                pass
        
        add_special_token("bos_token", "tokenizer.ggml.bos_token_id")
        add_special_token("eos_token", "tokenizer.ggml.eos_token_id")
        add_special_token("unk_token", "tokenizer.ggml.unknown_token_id")
        
        with open(os.path.join(output_dir, "special_tokens_map.json"), 'w', encoding='utf-8') as f:
            json.dump(special_map, f, indent=2)
        print("Created special_tokens_map.json")
    except Exception as e:
        print(f"Warning: Could not create special_tokens_map.json. Error: {e}")

    # --- generation_config.json ---
    try:
        gen_config = {"_from_model_config": True}
        try:
            gen_config["bos_token_id"] = _field_value(reader, "tokenizer.ggml.bos_token_id")
            gen_config["eos_token_id"] = _field_value(reader, "tokenizer.ggml.eos_token_id")
        except KeyError:
            pass
        
        with open(os.path.join(output_dir, "generation_config.json"), 'w', encoding='utf-8') as f:
            json.dump(gen_config, f, indent=2)
        print("Created generation_config.json")
    except Exception as e:
        print(f"Warning: Could not create generation_config.json. Error: {e}")

def main():
    parser = argparse.ArgumentParser(
        description="Generate safetensors index and config files for a sharded model directory."
    )
    parser.add_argument(
        "--gguf-file", 
        required=True, 
        help="Path to the original GGUF file to read metadata from."
    )
    parser.add_argument(
        "--shards-dir", 
        required=True, 
        help="Path to the directory containing the sharded .safetensors files."
    )
    args = parser.parse_args()

    if not os.path.isfile(args.gguf_file):
        print(f"Error: GGUF file not found at {args.gguf_file}")
        return
    if not os.path.isdir(args.shards_dir):
        print(f"Error: Shards directory not found at {args.shards_dir}")
        return

    print(f"Loading GGUF metadata from: {args.gguf_file}")
    reader = GGUFReader(args.gguf_file, 'r')

    # Generate config files from GGUF header and save them to the shards directory
    extract_and_save_gguf_configs(reader, args.shards_dir)

    # Generate the safetensors index from the actual shard files
    create_safetensors_index(args.shards_dir, args.shards_dir)
    
    print("\nMetadata ripping complete.")

if __name__ == "__main__":
    main()