Update app.py
Browse files
app.py
CHANGED
|
@@ -26,22 +26,11 @@ def log(msg):
|
|
| 26 |
def get_model_size_in_gb(model_name):
|
| 27 |
"""估算模型大小(以GB为单位)"""
|
| 28 |
try:
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
# 手动计算参数量
|
| 34 |
-
if hasattr(config, 'num_hidden_layers') and hasattr(config, 'hidden_size'):
|
| 35 |
-
# 简单估算,可能不够准确
|
| 36 |
-
num_params = config.num_hidden_layers * config.hidden_size * config.hidden_size * 4
|
| 37 |
|
| 38 |
-
if num_params:
|
| 39 |
-
# 每个参数占用2字节(float16)
|
| 40 |
-
size_in_gb = (num_params * 2) / (1024 ** 3)
|
| 41 |
-
return size_in_gb
|
| 42 |
-
else:
|
| 43 |
-
# 如果无法计算,返回一个保守的估计
|
| 44 |
-
return 1 # bypass memory check
|
| 45 |
except Exception as e:
|
| 46 |
log(f"无法估算模型大小: {str(e)}")
|
| 47 |
return 1 # bypass memory check
|
|
@@ -98,6 +87,11 @@ def setup_environment(model_name, hf_token):
|
|
| 98 |
def create_hf_repo(repo_name, hf_token, private=True):
|
| 99 |
"""创建HuggingFace仓库"""
|
| 100 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
repo_url = create_repo(repo_name, private=private, token=hf_token)
|
| 102 |
log(f"创建仓库成功: {repo_url}")
|
| 103 |
return repo_url
|
|
|
|
| 26 |
def get_model_size_in_gb(model_name):
|
| 27 |
"""估算模型大小(以GB为单位)"""
|
| 28 |
try:
|
| 29 |
+
# get model size from huggingface
|
| 30 |
+
api = HfApi()
|
| 31 |
+
model_info = api.model_info(model_name)
|
| 32 |
+
return model_info.safetensors.total / (1024 ** 3)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
except Exception as e:
|
| 35 |
log(f"无法估算模型大小: {str(e)}")
|
| 36 |
return 1 # bypass memory check
|
|
|
|
| 87 |
def create_hf_repo(repo_name, hf_token, private=True):
|
| 88 |
"""创建HuggingFace仓库"""
|
| 89 |
try:
|
| 90 |
+
# check if repo already exists
|
| 91 |
+
api = HfApi()
|
| 92 |
+
if api.repo_exists(repo_name):
|
| 93 |
+
log(f"仓库已存在: {repo_name}")
|
| 94 |
+
return ValueError(f"仓库已存在: {repo_name}, 请使用其他名称或删除已存在的仓库")
|
| 95 |
repo_url = create_repo(repo_name, private=private, token=hf_token)
|
| 96 |
log(f"创建仓库成功: {repo_url}")
|
| 97 |
return repo_url
|