Multi-GPU inference is not working... "RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:5!"
I'm a researcher in an Industrial Engineering lab, and I've run into a serious problem.
I want to use "InternVL2_5-78B" without quantization, for accuracy.
Unfortunately, I only have 3 A100 GPUs available (device indices 5, 6, 7), so I split the model across GPUs 5, 6, and 7.
Looking at this image, I don't think the way the model is split is the problem.
Here is the code I use for the multi-GPU setup:
import math
import torch
from transformers import AutoTokenizer, AutoModel
def split_model(model_name):
    device_map = {}
    world_size = 3  # use only GPUs 5, 6, 7
    num_layers = {
        'InternVL2_5-1B': 24, 'InternVL2_5-2B': 24, 'InternVL2_5-4B': 36, 'InternVL2_5-8B': 32,
        'InternVL2_5-26B': 48, 'InternVL2_5-38B': 64, 'InternVL2_5-78B': 80}[model_name]
    
    # Distribute the decoder layers across the 3 GPUs (5, 6, 7)
    num_layers_per_gpu = math.ceil(num_layers / world_size)
    num_layers_per_gpu = [num_layers_per_gpu] * world_size
    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        for j in range(num_layer):
            if layer_cnt >= num_layers:  # 80 layers don't divide evenly by 3; skip non-existent layer indices
                break
            device_map[f'language_model.model.layers.{layer_cnt}'] = 5 + i  # GPUs 5, 6, 7
            layer_cnt += 1
    device_map['vision_model'] = 5
    device_map['mlp1'] = 5
    device_map['language_model.model.tok_embeddings'] = 5
    device_map['language_model.model.embed_tokens'] = 5
    device_map['language_model.output'] = 5
    device_map['language_model.model.norm'] = 5
    device_map['language_model.lm_head'] = 5
    device_map[f'language_model.model.layers.{num_layers - 1}'] = 5  # last layer
    return device_map
# Model path
path = "OpenGVLab/InternVL2_5-78B"
# Build the device_map
device_map = split_model('InternVL2_5-78B')
# model loading
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True,
    device_map=device_map  # use the custom device_map built above
).eval()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
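For reference, a quick way to check where each module actually ended up is to print the device map that from_pretrained records (I believe it is stored as model.hf_device_map when a device_map is passed; treat this as a diagnostic sketch):
# Print the resolved module-to-device placement recorded by from_pretrained.
# Anything left on "cpu" (or missing from the map) is a likely source of the
# "Expected all tensors to be on the same device" error.
for module_name, device in model.hf_device_map.items():
    print(f'{module_name} -> {device}')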
But the inference code is not working...
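For reference, the snippet that triggers the error (the same call shown at the top of the traceback below) is:
generation_config = dict(max_new_tokens=1024, do_sample=True)
question = 'Can you tell me a story?'
# pixel_values=None -> text-only chat, no image input
response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
print(f'User: {question}\nAssistant: {response}')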
Error message:
RuntimeError                              Traceback (most recent call last)
Input In [8], in <cell line: 4>()
      1 generation_config = dict(max_new_tokens=1024, do_sample=True)
      3 question = 'Can you tell me a story?'
----> 4 response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
      5 print(f'User: {question}\nAssistant: {response}')
File ~/.cache/huggingface/modules/transformers_modules/OpenGVLab/InternVL2_5-78B/ea891f50e952a1bdf9dd44df66a932bc5a4f40ec/modeling_internvl_chat.py:290, in InternVLChatModel.chat(self, tokenizer, pixel_values, question, generation_config, history, return_history, num_patches_list, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN, verbose)
    288 attention_mask = model_inputs['attention_mask'].to(self.device)
    289 generation_config['eos_token_id'] = eos_token_id
--> 290 generation_output = self.generate(
    291     pixel_values=pixel_values,
    292     input_ids=input_ids,
    293     attention_mask=attention_mask,
    294     **generation_config
    295 )
    296 response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
    297 response = response.split(template.sep.strip())[0].strip()
File /usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)
File ~/.cache/huggingface/modules/transformers_modules/OpenGVLab/InternVL2_5-78B/ea891f50e952a1bdf9dd44df66a932bc5a4f40ec/modeling_internvl_chat.py:339, in InternVLChatModel.generate(self, pixel_values, input_ids, attention_mask, visual_features, generation_config, output_hidden_states, **generate_kwargs)
    336 else:
    337     input_embeds = self.language_model.get_input_embeddings()(input_ids)
--> 339 outputs = self.language_model.generate(
    340     inputs_embeds=input_embeds,
    341     attention_mask=attention_mask,
    342     generation_config=generation_config,
    343     output_hidden_states=output_hidden_states,
    344     use_cache=True,
    345     **generate_kwargs,
    346 )
    348 return outputs
File /usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)
File /usr/local/lib/python3.8/dist-packages/transformers/generation/utils.py:2215, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   2207     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2208         input_ids=input_ids,
   2209         expand_size=generation_config.num_return_sequences,
   2210         is_encoder_decoder=self.config.is_encoder_decoder,
   2211         **model_kwargs,
   2212     )
   2214     # 12. run sample (it degenerates to greedy search when generation_config.do_sample=False)
-> 2215     result = self._sample(
   2216         input_ids,
   2217         logits_processor=prepared_logits_processor,
   2218         stopping_criteria=prepared_stopping_criteria,
   2219         generation_config=generation_config,
   2220         synced_gpus=synced_gpus,
   2221         streamer=streamer,
   2222         **model_kwargs,
   2223     )
   2225 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2226     # 11. prepare beam search scorer
   2227     beam_scorer = BeamSearchScorer(
   2228         batch_size=batch_size,
   2229         num_beams=generation_config.num_beams,
   (...)
   2234         max_length=generation_config.max_length,
   2235     )
File /usr/local/lib/python3.8/dist-packages/transformers/generation/utils.py:3206, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
   3203 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
   3205 # forward pass to get next token
-> 3206 outputs = self(**model_inputs, return_dict=True)
   3208 # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
   3209 model_kwargs = self._update_model_kwargs_for_generation(
   3210     outputs,
   3211     model_kwargs,
   3212     is_encoder_decoder=self.config.is_encoder_decoder,
   3213 )
File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []
File /usr/local/lib/python3.8/dist-packages/transformers/models/qwen2/modeling_qwen2.py:1164, in Qwen2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
   1161 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1163 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1164 outputs = self.model(
   1165     input_ids=input_ids,
   1166     attention_mask=attention_mask,
   1167     position_ids=position_ids,
   1168     past_key_values=past_key_values,
   1169     inputs_embeds=inputs_embeds,
   1170     use_cache=use_cache,
   1171     output_attentions=output_attentions,
   1172     output_hidden_states=output_hidden_states,
   1173     return_dict=return_dict,
   1174     cache_position=cache_position,
   1175 )
   1177 hidden_states = outputs[0]
   1178 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []
File /usr/local/lib/python3.8/dist-packages/transformers/models/qwen2/modeling_qwen2.py:871, in Qwen2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    868 hidden_states = inputs_embeds
    870 # create position embeddings to be shared across the decoder layers
--> 871 position_embeddings = self.rotary_emb(hidden_states, position_ids)
    873 # decoder layers
    874 all_hidden_states = () if output_hidden_states else None
File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []
File /usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)
File /usr/local/lib/python3.8/dist-packages/transformers/models/qwen2/modeling_qwen2.py:163, in Qwen2RotaryEmbedding.forward(self, x, position_ids)
    161 device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
    162 with torch.autocast(device_type=device_type, enabled=False):
--> 163     freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    164     emb = torch.cat((freqs, freqs), dim=-1)
    165     cos = emb.cos()
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:5! (when checking argument for argument mat2 in method wrapper_CUDA_bmm)
😭😭😭 How can I solve this "RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:5! (when checking argument for argument mat2 in method wrapper_CUDA_bmm)" error???
Please help me...
I've been trying to solve this problem for a week.
Hi, you can try adding this line in the split_model function:
device_map['language_model.model.rotary_emb'] = 0
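Since GPU 5 is the first device in your map (it already holds vision_model, the embeddings, and lm_head), it probably makes sense to pin rotary_emb to GPU 5 rather than 0 in your case. A minimal sketch of the end of split_model with that change (assuming the rest of the function stays as you posted it):
    device_map['language_model.model.norm'] = 5
    device_map['language_model.lm_head'] = 5
    device_map[f'language_model.model.layers.{num_layers - 1}'] = 5  # last layer
    # Pin the shared rotary-embedding module to the same first GPU so its
    # inv_freq buffer no longer stays on the CPU during generation.
    device_map['language_model.model.rotary_emb'] = 5
    return device_map
Alternatively, launching the process with CUDA_VISIBLE_DEVICES=5,6,7 makes those cards appear as cuda:0/1/2 inside the process, so a device map that starts at 0 (like the official example) would work without renumbering.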

