# model_tools/textonly_ripper_v2.py
# Source: Hugging Face upload by Naphula ("Upload 2 files", commit 9895dce, verified)
import torch
from safetensors import safe_open
from safetensors.torch import save_file
import os
import json
from collections import OrderedDict
import glob # Import the glob library to find files
def _discover_shards(input_dir):
    """Return the sorted shard filenames for the checkpoint in *input_dir*.

    Prefers 'model.safetensors.index.json' when present; otherwise falls
    back to globbing for '*.safetensors' files.
    """
    index_path = os.path.join(input_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        print("Found 'model.safetensors.index.json'. Processing based on index.")
        with open(index_path, 'r') as f:
            weight_map = json.load(f).get("weight_map", {})
        # Unique, deterministic ordering of the shard files named in the map.
        return sorted(set(weight_map.values()))
    print("No index file found. Searching for '*.safetensors' files directly.")
    shard_paths = sorted(glob.glob(os.path.join(input_dir, '*.safetensors')))
    return [os.path.basename(p) for p in shard_paths]


def convert_multimodal_to_text_only(input_dir, output_dir):
    """
    Convert a sharded multimodal Mistral checkpoint into a text-only one.

    Tensors under the vision prefixes ('vision_tower.', 'multi_modal_projector.')
    are dropped; the 'language_model.' prefix is stripped from the remaining
    keys. A fresh 'model.safetensors.index.json' is always written, even when
    the source checkpoint had none.

    Parameters
    ----------
    input_dir : str
        Directory containing the source '*.safetensors' shards (and
        optionally an index file).
    output_dir : str
        Destination directory; created if missing.

    Raises
    ------
    Any I/O or safetensors error propagates to the caller — unlike a blanket
    ``except Exception``, this preserves the traceback for debugging.
    """
    os.makedirs(output_dir, exist_ok=True)  # exist_ok avoids a check/create race

    vision_prefixes_to_remove = ("vision_tower.", "multi_modal_projector.")
    lm_prefix_to_rename = "language_model."

    shard_filenames = _discover_shards(input_dir)
    if not shard_filenames:
        print(f"Error: No '.safetensors' files found in {input_dir}")
        return
    print(f"Found {len(shard_filenames)} model shards to process.")

    new_weight_map = OrderedDict()
    total_original_size = 0
    total_new_size = 0
    total_param_bytes = 0  # sum of tensor byte sizes (HF index convention)

    for shard_filename in shard_filenames:
        input_shard_path = os.path.join(input_dir, shard_filename)
        output_shard_path = os.path.join(output_dir, shard_filename)
        print(f"\nProcessing shard: {shard_filename}")

        text_only_tensors = OrderedDict()
        with safe_open(input_shard_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                # Drop vision-specific tensors entirely.
                if any(key.startswith(p) for p in vision_prefixes_to_remove):
                    continue
                # Strip the language-model prefix so keys match a plain LM.
                new_key = (key[len(lm_prefix_to_rename):]
                           if key.startswith(lm_prefix_to_rename) else key)
                tensor = f.get_tensor(key)
                text_only_tensors[new_key] = tensor
                new_weight_map[new_key] = shard_filename
                total_param_bytes += tensor.numel() * tensor.element_size()

        if text_only_tensors:
            print(f"Saving {len(text_only_tensors)} text-only tensors to: {shard_filename}")
            save_file(text_only_tensors, output_shard_path)
            original_size = os.path.getsize(input_shard_path)
            new_size = os.path.getsize(output_shard_path)
            total_original_size += original_size
            total_new_size += new_size
            print(f"  - Original shard size: {original_size / (1024**2):.2f} MB")
            print(f"  - New shard size:      {new_size / (1024**2):.2f} MB")
        else:
            print(f"Shard {shard_filename} contained only vision tensors and will be skipped.")

    # Always write an index for the converted model, even if the original
    # had none. Per Hugging Face convention, metadata.total_size is the sum
    # of the tensor byte sizes, not the on-disk file sizes (which include
    # safetensors header overhead).
    new_index_data = {
        "metadata": {"total_size": total_param_bytes},
        "weight_map": new_weight_map,
    }
    new_index_path = os.path.join(output_dir, "model.safetensors.index.json")
    with open(new_index_path, 'w') as f:
        json.dump(new_index_data, f, indent=2)
    print(f"\nSuccessfully created new index file at: {new_index_path}")

    print("\n--- Conversion Summary ---")
    print(f"Total original model size: {total_original_size / (1024**3):.2f} GB")
    print(f"Total new text-only model size: {total_new_size / (1024**3):.2f} GB")
    print("Conversion complete!")
if __name__ == "__main__":
    # Hard-coded source/destination checkpoint locations for this one-off tool.
    SOURCE_MODEL_DIR = r"A:\LLM\.cache\huggingface\hub\test"
    TARGET_MODEL_DIR = r"A:\LLM\.cache\huggingface\hub\test\fix"

    convert_multimodal_to_text_only(SOURCE_MODEL_DIR, TARGET_MODEL_DIR)