import torch  # torch must be installed for the "pt" framework used below
from safetensors import safe_open
from safetensors.torch import save_file
import os
import json
from collections import OrderedDict
import glob  # used to find shard files when no index is present


def convert_multimodal_to_text_only(input_dir, output_dir):
    """
    Convert a sharded multimodal Mistral model to a text-only model.

    Handles models with or without a 'model.safetensors.index.json' file:
    vision tensors are dropped and 'language_model.*' keys are renamed so
    the result loads as a plain text model.
    """
    try:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
            print(f"Created output directory: {output_dir}")

        # --- Define the prefixes to handle ---
        vision_prefixes_to_remove = ["vision_tower.", "multi_modal_projector."]
        lm_prefix_to_rename = "language_model."
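        # For reference, a multimodal checkpoint of this kind typically stores
        # its tensors under three roots (illustrative layout, inferred from the
        # prefixes above rather than read from any specific file):
        #   vision_tower.*          -> vision encoder (dropped)
        #   multi_modal_projector.* -> vision-to-text projector (dropped)
        #   language_model.*        -> text decoder (kept, prefix stripped)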

        # --- Determine the list of shard files to process ---
        index_file_path = os.path.join(input_dir, "model.safetensors.index.json")
        shard_filenames = []
        if os.path.exists(index_file_path):
            print("Found 'model.safetensors.index.json'. Processing based on index.")
            with open(index_file_path, "r") as f:
                index_data = json.load(f)
            weight_map = index_data.get("weight_map", {})
            # Get a unique, ordered list of filenames from the weight map
            shard_filenames = sorted(set(weight_map.values()))
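            # weight_map maps tensor names to the shard file that holds them,
            # e.g. (hypothetical entry, shown only to illustrate the structure):
            #   "language_model.model.embed_tokens.weight": "model-00001-of-00002.safetensors"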
        else:
            print("No index file found. Searching for '*.safetensors' files directly.")
            # Use glob to find all files ending with .safetensors
            search_pattern = os.path.join(input_dir, "*.safetensors")
            shard_paths = sorted(glob.glob(search_pattern))
            if not shard_paths:
                print(f"Error: No '.safetensors' files found in {input_dir}")
                return
            # Extract just the filenames from the full paths
            shard_filenames = [os.path.basename(p) for p in shard_paths]

        print(f"Found {len(shard_filenames)} model shards to process.")

        # --- Process each shard ---
        new_weight_map = OrderedDict()
        total_original_size = 0
        total_new_size = 0
        for shard_filename in shard_filenames:
            input_shard_path = os.path.join(input_dir, shard_filename)
            output_shard_path = os.path.join(output_dir, shard_filename)
            print(f"\nProcessing shard: {shard_filename}")
            text_only_tensors = OrderedDict()
            has_text_tensors = False
            with safe_open(input_shard_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    # Skip vision-specific tensors entirely
                    if any(key.startswith(p) for p in vision_prefixes_to_remove):
                        continue
                    # Strip the 'language_model.' prefix so the key matches a
                    # text-only checkpoint, e.g.
                    #   language_model.model.layers.0.self_attn.q_proj.weight
                    #   -> model.layers.0.self_attn.q_proj.weight
                    new_key = key
                    if key.startswith(lm_prefix_to_rename):
                        new_key = key[len(lm_prefix_to_rename):]
                    text_only_tensors[new_key] = f.get_tensor(key)
                    new_weight_map[new_key] = shard_filename
                    has_text_tensors = True

            if has_text_tensors:
                print(f"Saving {len(text_only_tensors)} text-only tensors to: {shard_filename}")
                # transformers expects a 'format' entry in the safetensors
                # metadata; without it the converted files may be rejected
                # at load time.
                save_file(text_only_tensors, output_shard_path, metadata={"format": "pt"})
                original_size = os.path.getsize(input_shard_path)
                new_size = os.path.getsize(output_shard_path)
                total_original_size += original_size
                total_new_size += new_size
                print(f"  - Original shard size: {original_size / (1024**2):.2f} MB")
                print(f"  - New shard size: {new_size / (1024**2):.2f} MB")
            else:
                print(f"Shard {shard_filename} contained only vision tensors and was skipped.")

        # --- Create the new index file for the text-only model ---
        # It's good practice to write one even if the original didn't have it.
        # Note: 'total_size' here is the on-disk file size, which includes the
        # small safetensors headers, so it slightly overstates the raw tensor bytes.
        new_index_data = {
            "metadata": {
                "total_size": total_new_size
            },
            "weight_map": new_weight_map,
        }
        new_index_path = os.path.join(output_dir, "model.safetensors.index.json")
        with open(new_index_path, "w") as f:
            json.dump(new_index_data, f, indent=2)
        print(f"\nSuccessfully created new index file at: {new_index_path}")

        print("\n--- Conversion Summary ---")
        print(f"Total original model size: {total_original_size / (1024**3):.2f} GB")
        print(f"Total new text-only model size: {total_new_size / (1024**3):.2f} GB")
        print("Conversion complete!")

    except Exception as e:
        print(f"An error occurred: {e}")


if __name__ == "__main__":
    # --- Configuration ---
    input_model_directory = r"A:\LLM\.cache\huggingface\hub\test"
    output_model_directory = r"A:\LLM\.cache\huggingface\hub\test\fix"

    # --- Run the script ---
    convert_multimodal_to_text_only(input_model_directory, output_model_directory)
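
    # Note: this script converts only the weight shards. To load the result as
    # a text-only model you will also need the remaining files (config.json,
    # tokenizer files, generation_config.json, ...) copied over, with the config
    # adjusted to a text-only architecture such as 'MistralForCausalLM'; the
    # exact edits depend on the source checkpoint, so treat this as a sketch.
    #
    # Optional sanity check once those files are in place (assumes the
    # transformers library is installed):
    #
    #   from transformers import AutoModelForCausalLM
    #   model = AutoModelForCausalLM.from_pretrained(output_model_directory)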