peihsin0715 committed on
Commit
b1dab38
·
1 Parent(s): d02be95

Fix model loading

Browse files
Files changed (1) hide show
  1. backend/utils/utils.py +44 -10
backend/utils/utils.py CHANGED
@@ -17,34 +17,68 @@ from transformers import (
17
  )
18
 
19
def load_model_and_tokenizer(model_name: str):
    """Load a causal-LM model and its tokenizer onto the best available device.

    GPT-2-family checkpoints are loaded with the dedicated GPT2 classes;
    anything else goes through the Auto* classes.

    Args:
        model_name: Hugging Face model id or local path.

    Returns:
        (tokenizer, model, device) with the model already moved to `device`.

    Raises:
        RuntimeError: if the model or tokenizer cannot be loaded.
    """
    # Device preference: CUDA > Apple-Silicon MPS > CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():  # macOS Apple Silicon
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Checkpoints known to be GPT-2 architecture.
    gpt2_aliases = {"gpt2", "openai-community/gpt2", "holistic-ai/gpt2-EMGSD"}

    try:
        if model_name in gpt2_aliases:
            tokenizer = GPT2Tokenizer.from_pretrained(model_name)
            model = GPT2LMHeadModel.from_pretrained(model_name)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(model_name)

        # Decoder-only tokenizers often ship without a pad token; reuse EOS
        # so padding during batching/generation works.
        if tokenizer.pad_token is None and tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        if getattr(model.config, "pad_token_id", None) is None and getattr(model.config, "eos_token_id", None) is not None:
            model.config.pad_token_id = model.config.eos_token_id

        model.to(device)
        return tokenizer, model, device
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(f"Failed to load model '{model_name}': {e}") from e
49
 
50
  def finetune(train_texts, tokenizer, model, num_epochs=20, output_dir='./data'):
 
17
  )
18
 
19
def load_model_and_tokenizer(model_name: str):
    """Load a causal-LM model and its tokenizer with a low-memory strategy.

    GPT-2-family checkpoints use the dedicated GPT2 classes; anything else
    goes through the Auto* classes with `low_cpu_mem_usage` and (on CUDA)
    `device_map="auto"` to keep peak memory down.

    Args:
        model_name: Hugging Face model id or local path.

    Returns:
        (tokenizer, model, device).

    Raises:
        RuntimeError: if the model or tokenizer cannot be loaded.
    """
    # Report available system memory to help diagnose OOM failures.
    # psutil is optional; skip the report rather than crash if it is missing.
    try:
        import psutil
        available_memory = psutil.virtual_memory().available / 1024**3
        print(f"Available memory: {available_memory:.2f} GB")
    except ImportError:
        pass

    # Device preference: CUDA > Apple-Silicon MPS > CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    gpt2_aliases = {"gpt2", "openai-community/gpt2"}

    # Half precision off-CPU halves the weight footprint; CPU inference
    # generally needs float32.
    dtype = torch.float16 if device.type != "cpu" else torch.float32
    cache_dir = "/tmp/hf_cache"  # temporary directory for the HF download cache

    try:
        if model_name in gpt2_aliases:
            tokenizer = GPT2Tokenizer.from_pretrained(model_name, cache_dir=cache_dir)
            model = GPT2LMHeadModel.from_pretrained(
                model_name,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,  # key: stream weights instead of a full in-RAM copy
                cache_dir=cache_dir,
            )
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,  # key memory optimization
                device_map="auto" if torch.cuda.is_available() else None,
                cache_dir=cache_dir,
            )

        # Decoder-only tokenizers often ship without a pad token; reuse EOS.
        if tokenizer.pad_token is None and tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        if getattr(model.config, "pad_token_id", None) is None and getattr(model.config, "eos_token_id", None) is not None:
            model.config.pad_token_id = model.config.eos_token_id

        # Bug fix: a model loaded with device_map="auto" is already dispatched
        # by accelerate, and calling .to() on it raises. Only move the model
        # ourselves when accelerate did not place it and the target is not CPU.
        if getattr(model, "hf_device_map", None) is None and device.type != "cpu":
            model.to(device)

        return tokenizer, model, device
    except Exception as e:
        # Surface full diagnostics before converting to the public error type;
        # chain the cause so the original traceback stays attached.
        import traceback
        print(f"Error loading model {model_name}: {str(e)}")
        print(f"Traceback: {traceback.format_exc()}")
        raise RuntimeError(f"Failed to load model '{model_name}': {e}") from e
83
 
84
  def finetune(train_texts, tokenizer, model, num_epochs=20, output_dir='./data'):