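"""Reconstruct Hugging Face-style metadata for a sharded safetensors export of a GGUF model.

Reads the GGUF header and writes config.json, tokenizer_config.json, tokenizer.json,
special_tokens_map.json, and generation_config.json into the shards directory, then
scans the .safetensors shards to build model.safetensors.index.json.

Example invocation (the script filename here is illustrative):

    python rip_gguf_metadata.py --gguf-file model.gguf --shards-dir ./shards
"""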
import argparse
import glob
import json
import os
from typing import Any, Dict

from gguf import GGUFReader
from gguf.constants import Keys
from safetensors import safe_open

def create_safetensors_index(shards_dir: str, output_dir: str) -> None:
    """Creates the model.safetensors.index.json file by scanning shard files."""
    shard_pattern = os.path.join(shards_dir, '*.safetensors')
    shard_files = sorted(glob.glob(shard_pattern))
    if not shard_files:
        print(f"Error: No .safetensors files found in directory: {shards_dir}")
        return

    print(f"Found {len(shard_files)} shard files to index.")
    # Hugging Face index schema: "weight_map" maps each tensor name to the
    # shard file that holds it; "metadata.total_size" is the summed byte size.
    index_data: Dict[str, Any] = {"metadata": {}, "weight_map": {}}
    total_size = 0
    for shard_file in shard_files:
        shard_basename = os.path.basename(shard_file)
        try:
            # Open the shard header only; tensor data is not loaded.
            with safe_open(shard_file, framework="pt", device="cpu") as f:
                for tensor_name in f.keys():
                    index_data["weight_map"][tensor_name] = shard_basename
            shard_size = os.path.getsize(shard_file)
            total_size += shard_size
        except Exception as e:
            print(f"Warning: Could not process shard {shard_basename}. Error: {e}")
            continue

    index_data["metadata"]["total_size"] = total_size
    index_filepath = os.path.join(output_dir, "model.safetensors.index.json")
    try:
        with open(index_filepath, 'w', encoding='utf-8') as f:
            json.dump(index_data, f, indent=2)
        print(f"Successfully created safetensors index file: {index_filepath}")
    except Exception as e:
        print(f"Error: Failed to write index file. Error: {e}")

def extract_and_save_gguf_configs(reader: GGUFReader, output_dir: str) -> None:
    """Extracts metadata from GGUF and saves config, tokenizer, and generation files."""
    config = {}
    vocab = []  # Filled in the tokenizer.json step; reused for special_tokens_map.json.

    # --- config.json ---
    try:
        arch = reader.get_field(Keys.General.ARCHITECTURE).name.lower()
        model_type_map = {"llama": "llama", "mistral": "mistral", "gemma": "gemma"}
        model_type = model_type_map.get(arch, arch)
        config = {
            # NOTE: Transformers generally expects a full class name here, e.g.
            # "LlamaForCausalLM"; the capitalized architecture is a best effort.
            "architectures": [arch.capitalize()],
            "model_type": model_type,
            "hidden_size": reader.get_int_value(f"{model_type}.embedding_length"),
            "intermediate_size": reader.get_int_value(f"{model_type}.feed_forward_length"),
            "num_attention_heads": reader.get_int_value(f"{model_type}.attention.head_count"),
            "num_hidden_layers": reader.get_int_value(f"{model_type}.block_count"),
            "num_key_value_heads": reader.get_int_value(f"{model_type}.attention.head_count_kv"),
            "rms_norm_eps": reader.get_float_value(f"{model_type}.attention.layer_norm_rms_epsilon"),
            "vocab_size": len(reader.get_field(Keys.Tokenizer.VOCAB)),
            "rope_theta": reader.get_float_value(f"{model_type}.rope.freq_base"),
            "max_position_embeddings": reader.get_int_value(f"{model_type}.context_length"),
        }
        with open(os.path.join(output_dir, "config.json"), 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)
        print("Created config.json")
    except Exception as e:
        print(f"Warning: Could not create config.json. Some values may be missing. Error: {e}")

    # --- tokenizer_config.json ---
    try:
        tokenizer_config = {
            "model_max_length": config.get("max_position_embeddings", 4096),
            "padding_side": "left",
            # Assumes a Llama-style tokenizer; change for other model families.
            "tokenizer_class": "LlamaTokenizer",
        }
        # Add chat template if it exists
        try:
            chat_template = reader.get_str_value("tokenizer.chat_template")
            tokenizer_config["chat_template"] = chat_template
        except (KeyError, ValueError):
            pass  # Field does not exist
        with open(os.path.join(output_dir, "tokenizer_config.json"), 'w', encoding='utf-8') as f:
            json.dump(tokenizer_config, f, indent=2)
        print("Created tokenizer_config.json")
    except Exception as e:
        print(f"Warning: Could not create tokenizer_config.json. Error: {e}")

    # --- tokenizer.json ---
    try:
        vocab = [item.piece for item in reader.get_field(Keys.Tokenizer.VOCAB)]
        merges = reader.get_field(Keys.Tokenizer.MERGES)
        tokenizer_data = {
            "version": "1.0",
            "model": {
                "type": "BPE",
                "vocab": {token: i for i, token in enumerate(vocab)},
                "merges": merges,
            },
            "added_tokens": [],
        }
        # Compact output: tokenizer.json files are large, so skip pretty-printing.
        with open(os.path.join(output_dir, "tokenizer.json"), 'w', encoding='utf-8') as f:
            json.dump(tokenizer_data, f, indent=None, separators=(',', ':'))
        print("Created tokenizer.json")
    except Exception as e:
        print(f"Warning: Could not create tokenizer.json. Error: {e}")

    # --- special_tokens_map.json ---
    try:
        special_map = {}

        # Use a helper to avoid crashing on missing keys
        def add_special_token(key_name, gguf_id_key):
            try:
                token_id = reader.get_int_value(gguf_id_key)
                token_str = vocab[token_id]  # IndexError if vocab was not recovered above
                special_map[key_name] = token_str
            except (KeyError, ValueError, IndexError):
                pass

        add_special_token("bos_token", "tokenizer.ggml.bos_token_id")
        add_special_token("eos_token", "tokenizer.ggml.eos_token_id")
        add_special_token("unk_token", "tokenizer.ggml.unknown_token_id")
        with open(os.path.join(output_dir, "special_tokens_map.json"), 'w', encoding='utf-8') as f:
            json.dump(special_map, f, indent=2)
        print("Created special_tokens_map.json")
    except Exception as e:
        print(f"Warning: Could not create special_tokens_map.json. Error: {e}")

    # --- generation_config.json ---
    try:
        gen_config = {"_from_model_config": True}
        try:
            gen_config["bos_token_id"] = reader.get_int_value("tokenizer.ggml.bos_token_id")
            gen_config["eos_token_id"] = reader.get_int_value("tokenizer.ggml.eos_token_id")
        except (KeyError, ValueError):
            pass
        with open(os.path.join(output_dir, "generation_config.json"), 'w', encoding='utf-8') as f:
            json.dump(gen_config, f, indent=2)
        print("Created generation_config.json")
    except Exception as e:
        print(f"Warning: Could not create generation_config.json. Error: {e}")

def main():
    parser = argparse.ArgumentParser(
        description="Generate safetensors index and config files for a sharded model directory."
    )
    parser.add_argument(
        "--gguf-file",
        required=True,
        help="Path to the original GGUF file to read metadata from."
    )
    parser.add_argument(
        "--shards-dir",
        required=True,
        help="Path to the directory containing the sharded .safetensors files."
    )
    args = parser.parse_args()

    if not os.path.isfile(args.gguf_file):
        print(f"Error: GGUF file not found at {args.gguf_file}")
        return
    if not os.path.isdir(args.shards_dir):
        print(f"Error: Shards directory not found at {args.shards_dir}")
        return

    print(f"Loading GGUF metadata from: {args.gguf_file}")
    reader = GGUFReader(args.gguf_file, 'r')

    # Generate config files from GGUF header and save them to the shards directory
    extract_and_save_gguf_configs(reader, args.shards_dir)
    # Generate the safetensors index from the actual shard files
    create_safetensors_index(args.shards_dir, args.shards_dir)
    print("\nMetadata ripping complete.")


if __name__ == "__main__":
    main()
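
# Optional sanity check (not part of this script): once the files above exist,
# the shards directory should load as a regular Transformers checkpoint, e.g.:
#
#     from transformers import AutoConfig
#     cfg = AutoConfig.from_pretrained("./shards")
#     print(cfg.model_type)
#
# "./shards" is a placeholder for whatever --shards-dir pointed at.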