"""Convert a sharded multimodal Mistral checkpoint into a text-only checkpoint.

Drops all vision-stack tensors (``vision_tower.*``, ``multi_modal_projector.*``),
strips the ``language_model.`` prefix from the remaining keys, and writes the
resulting shards plus a fresh ``model.safetensors.index.json`` to the output
directory.
"""

import glob
import json
import os
import traceback
from collections import OrderedDict

import torch  # noqa: F401 -- safetensors' "pt" backend requires torch to be importable
from safetensors import safe_open
from safetensors.torch import save_file

# Tensors whose keys start with these prefixes belong to the vision stack and
# are dropped entirely.
_VISION_PREFIXES = ("vision_tower.", "multi_modal_projector.")
# This prefix is stripped so the text weights load as a plain LM checkpoint.
_LM_PREFIX = "language_model."


def _list_shard_filenames(input_dir):
    """Return the shard filenames to process, preferring the index file.

    Uses ``model.safetensors.index.json`` when present; otherwise falls back
    to globbing ``*.safetensors`` in *input_dir*. Returns an empty list when
    nothing is found.
    """
    index_file_path = os.path.join(input_dir, "model.safetensors.index.json")
    if os.path.exists(index_file_path):
        print("Found 'model.safetensors.index.json'. Processing based on index.")
        with open(index_file_path, "r") as f:
            index_data = json.load(f)
        # Many tensors map to the same shard file; de-duplicate and keep a
        # stable (sorted) processing order.
        return sorted(set(index_data.get("weight_map", {}).values()))

    print("No index file found. Searching for '*.safetensors' files directly.")
    shard_paths = sorted(glob.glob(os.path.join(input_dir, "*.safetensors")))
    return [os.path.basename(p) for p in shard_paths]


def _convert_shard(input_shard_path, output_shard_path, new_weight_map):
    """Copy the text-only tensors of one shard, renaming their keys.

    Vision tensors are skipped and ``language_model.`` is stripped from the
    remaining keys. *new_weight_map* is updated in place with
    ``new_key -> shard filename`` entries. Returns ``(original_size,
    new_size)`` in bytes; ``new_size`` is 0 (and no output file is written)
    when the shard held only vision tensors.
    """
    shard_filename = os.path.basename(input_shard_path)
    text_only_tensors = OrderedDict()

    with safe_open(input_shard_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            if any(key.startswith(p) for p in _VISION_PREFIXES):
                continue  # vision tensor -- drop it
            new_key = key
            if key.startswith(_LM_PREFIX):
                new_key = key[len(_LM_PREFIX):]
            text_only_tensors[new_key] = f.get_tensor(key)
            new_weight_map[new_key] = shard_filename

    original_size = os.path.getsize(input_shard_path)
    if not text_only_tensors:
        print(f"Shard {shard_filename} contained only vision tensors and will be skipped.")
        return original_size, 0

    print(f"Saving {len(text_only_tensors)} text-only tensors to: {shard_filename}")
    save_file(text_only_tensors, output_shard_path)
    new_size = os.path.getsize(output_shard_path)
    print(f"  - Original shard size: {original_size / (1024**2):.2f} MB")
    print(f"  - New shard size: {new_size / (1024**2):.2f} MB")
    return original_size, new_size


def convert_multimodal_to_text_only(input_dir, output_dir):
    """Convert a sharded multimodal Mistral model to a text-only model.

    Handles models with or without a ``model.safetensors.index.json`` file.
    A new index file is always written for the converted model. Errors are
    reported (with traceback) rather than raised, preserving the script's
    best-effort behaviour.
    """
    try:
        if not os.path.isdir(output_dir):
            # exist_ok guards against a race between the check and the create.
            os.makedirs(output_dir, exist_ok=True)
            print(f"Created output directory: {output_dir}")

        shard_filenames = _list_shard_filenames(input_dir)
        if not shard_filenames:
            print(f"Error: No '.safetensors' files found in {input_dir}")
            return
        print(f"Found {len(shard_filenames)} model shards to process.")

        new_weight_map = OrderedDict()
        total_original_size = 0
        total_new_size = 0
        for shard_filename in shard_filenames:
            print(f"\nProcessing shard: {shard_filename}")
            original_size, new_size = _convert_shard(
                os.path.join(input_dir, shard_filename),
                os.path.join(output_dir, shard_filename),
                new_weight_map,
            )
            # Count vision-only (skipped) shards in the original total too,
            # so the summary reflects the real on-disk sizes.
            total_original_size += original_size
            total_new_size += new_size

        # Always create an index file, even if the original model had none.
        new_index_data = {
            "metadata": {"total_size": total_new_size},
            "weight_map": new_weight_map,
        }
        new_index_path = os.path.join(output_dir, "model.safetensors.index.json")
        with open(new_index_path, "w") as f:
            json.dump(new_index_data, f, indent=2)
        print(f"\nSuccessfully created new index file at: {new_index_path}")

        print("\n--- Conversion Summary ---")
        print(f"Total original model size: {total_original_size / (1024**3):.2f} GB")
        print(f"Total new text-only model size: {total_new_size / (1024**3):.2f} GB")
        print("Conversion complete!")
    except Exception as e:
        # Best-effort script: report and return instead of crashing, but
        # include the traceback so failures are debuggable.
        traceback.print_exc()
        print(f"An error occurred: {e}")


if __name__ == "__main__":
    # --- Configuration ---
    input_model_directory = r"A:\LLM\.cache\huggingface\hub\test"
    output_model_directory = r"A:\LLM\.cache\huggingface\hub\test\fix"

    # --- Run the script ---
    convert_multimodal_to_text_only(input_model_directory, output_model_directory)