File size: 5,104 Bytes
9895dce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
from safetensors import safe_open
from safetensors.torch import save_file
import os
import json
from collections import OrderedDict
import glob # Import the glob library to find files

def _list_shard_filenames(input_dir):
    """Return the shard filenames to process from *input_dir*.

    Prefers 'model.safetensors.index.json' when present; otherwise falls
    back to globbing for '*.safetensors' files. Returns a sorted list of
    bare filenames (possibly empty).
    """
    index_file_path = os.path.join(input_dir, "model.safetensors.index.json")
    if os.path.exists(index_file_path):
        print("Found 'model.safetensors.index.json'. Processing based on index.")
        with open(index_file_path, 'r') as f:
            index_data = json.load(f)
        weight_map = index_data.get("weight_map", {})
        # Many tensors map to the same shard file; dedupe, then sort for a
        # stable processing order.
        return sorted(set(weight_map.values()))

    print("No index file found. Searching for '*.safetensors' files directly.")
    search_pattern = os.path.join(input_dir, '*.safetensors')
    # Keep only the filename component; paths are re-joined per directory later.
    return sorted(os.path.basename(p) for p in glob.glob(search_pattern))


def convert_multimodal_to_text_only(input_dir, output_dir):
    """
    Convert a sharded multimodal Mistral model into a text-only model.

    Drops every tensor under the vision prefixes ('vision_tower.',
    'multi_modal_projector.'), strips the 'language_model.' prefix from the
    remaining keys, and writes the filtered shards plus a fresh
    'model.safetensors.index.json' into *output_dir*. Handles source models
    with or without an index file.

    Args:
        input_dir: Directory containing the original '*.safetensors' shards.
        output_dir: Destination directory (created if missing).

    Raises:
        Exception: Any error during conversion is reported and re-raised so
            the process exits non-zero instead of silently "succeeding".
    """
    try:
        if not os.path.exists(output_dir):
            # exist_ok avoids a race between the existence check and creation.
            os.makedirs(output_dir, exist_ok=True)
            print(f"Created output directory: {output_dir}")

        # --- Define the prefixes to handle ---
        vision_prefixes_to_remove = ["vision_tower.", "multi_modal_projector."]
        lm_prefix_to_rename = "language_model."

        # --- Determine the list of shard files to process ---
        shard_filenames = _list_shard_filenames(input_dir)
        if not shard_filenames:
            print(f"Error: No '.safetensors' files found in {input_dir}")
            return

        print(f"Found {len(shard_filenames)} model shards to process.")

        # --- Process each shard ---
        new_weight_map = OrderedDict()
        total_original_size = 0
        total_new_size = 0

        for shard_filename in shard_filenames:
            input_shard_path = os.path.join(input_dir, shard_filename)
            output_shard_path = os.path.join(output_dir, shard_filename)

            print(f"\nProcessing shard: {shard_filename}")

            text_only_tensors = OrderedDict()

            with safe_open(input_shard_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    # Discard anything belonging to the vision stack.
                    if any(key.startswith(p) for p in vision_prefixes_to_remove):
                        continue

                    # Strip the language-model prefix so the output keys
                    # match a plain text-only checkpoint layout.
                    new_key = key
                    if key.startswith(lm_prefix_to_rename):
                        new_key = key[len(lm_prefix_to_rename):]

                    text_only_tensors[new_key] = f.get_tensor(key)
                    new_weight_map[new_key] = shard_filename

            if text_only_tensors:
                print(f"Saving {len(text_only_tensors)} text-only tensors to: {shard_filename}")
                save_file(text_only_tensors, output_shard_path)

                original_size = os.path.getsize(input_shard_path)
                new_size = os.path.getsize(output_shard_path)
                total_original_size += original_size
                total_new_size += new_size
                print(f"  - Original shard size: {original_size / (1024**2):.2f} MB")
                print(f"  - New shard size: {new_size / (1024**2):.2f} MB")
            else:
                print(f"Shard {shard_filename} contained only vision tensors and will be skipped.")

        # --- Create the new index file for the text-only model ---
        # It's good practice to create one, even if the original didn't have it.
        new_index_data = {
            "metadata": {
                "total_size": total_new_size
            },
            "weight_map": new_weight_map
        }
        new_index_path = os.path.join(output_dir, "model.safetensors.index.json")
        with open(new_index_path, 'w') as f:
            json.dump(new_index_data, f, indent=2)

        print(f"\nSuccessfully created new index file at: {new_index_path}")
        print("\n--- Conversion Summary ---")
        print(f"Total original model size: {total_original_size / (1024**3):.2f} GB")
        print(f"Total new text-only model size: {total_new_size / (1024**3):.2f} GB")
        print("Conversion complete!")

    except Exception as e:
        # Report, then re-raise: swallowing the error here previously made
        # failures look like success (partial output, exit code 0).
        print(f"An error occurred: {e}")
        raise

if __name__ == "__main__":
    # Source (multimodal) and destination (text-only) model locations.
    source_model_dir = r"A:\LLM\.cache\huggingface\hub\test"
    text_only_model_dir = r"A:\LLM\.cache\huggingface\hub\test\fix"

    convert_multimodal_to_text_only(source_model_dir, text_only_model_dir)