# VCS metadata artifact (commented out so the module parses):
# aknapitsch user / refactoring / eb74057
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Helper functions for HuggingFace integration and model initialization.
"""
import json
import os
def load_hf_token():
    """Load a HuggingFace access token.

    Looks for the token in a list of well-known token files first, then falls
    back to the standard environment variables
    (see https://huggingface.co/docs/hub/spaces-overview#managing-secrets).

    Returns:
        str or None: The token string, or None if no token was found.
    """
    token_file_paths = [
        "/home/aknapitsch/hf_token.txt",
    ]
    for token_path in token_file_paths:
        # expanduser lets entries use "~" portably; absolute paths pass through
        token_path = os.path.expanduser(token_path)
        if os.path.exists(token_path):
            try:
                with open(token_path, "r") as f:
                    token = f.read().strip()
                print(f"Loaded HuggingFace token from: {token_path}")
                return token
            except Exception as e:
                print(f"Error reading token from {token_path}: {e}")
                continue
        else:
            print(f"Token file not found: {token_path}")
    # Also try environment variables, in order of decreasing specificity
    token = (
        os.getenv("HF_TOKEN")
        or os.getenv("HUGGING_FACE_HUB_TOKEN")
        or os.getenv("HUGGING_FACE_MODEL_TOKEN")
    )
    if token:
        print("Loaded HuggingFace token from environment variable")
        return token
    print(
        "Warning: No HuggingFace token found. Model loading may fail for private repositories."
    )
    return None
def init_hydra_config(config_path, overrides=None):
    """Compose a Hydra config from a config file path.

    Args:
        config_path (str): Path to the Hydra config file (e.g. ".../model.yaml").
        overrides (list[str], optional): Hydra override strings
            (e.g. ["model.foo=bar"]). Defaults to no overrides.

    Returns:
        The composed Hydra configuration (omegaconf DictConfig).
    """
    import hydra

    config_dir = os.path.dirname(config_path)
    # splitext keeps the full stem for dotted filenames ("my.config.yaml"
    # -> "my.config"); the previous split(".")[0] truncated at the first dot.
    config_name = os.path.splitext(os.path.basename(config_path))[0]
    # Hydra requires config_path to be relative to this source file's directory
    relative_path = os.path.relpath(config_dir, os.path.dirname(__file__))
    # Clear any previously initialized Hydra state so repeated calls don't raise
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    hydra.initialize(version_base=None, config_path=relative_path)
    # An empty override list is equivalent to composing without overrides
    return hydra.compose(
        config_name=config_name, overrides=overrides if overrides is not None else []
    )
def _create_model_from_config(high_level_config, cfg):
    """Build the MapAnything model via the local factory.

    Tries to download the model config from HuggingFace Hub first (tier 2);
    on any failure falls back to the purely local Hydra config (tier 3).

    Args:
        high_level_config (dict): Settings with "hf_model_name", "config_name",
            "model_str" and optionally "torch_hub_force_reload".
        cfg: Composed Hydra config providing cfg.model.{model_str,model_config}.

    Returns:
        torch.nn.Module: The constructed (un-weighted) model.
    """
    from huggingface_hub import hf_hub_download
    from mapanything.models import init_model

    force_reload = high_level_config.get("torch_hub_force_reload", False)
    try:
        print("Downloading model configuration from HuggingFace Hub...")
        config_path = hf_hub_download(
            repo_id=high_level_config["hf_model_name"],
            filename=high_level_config["config_name"],
            token=load_hf_token(),
        )
        # Load the config from the downloaded file
        with open(config_path, "r") as f:
            downloaded_config = json.load(f)
        print("Using downloaded configuration for model initialization")
        return init_model(
            # Fall back to the high-level/local values for any key the
            # downloaded config does not provide
            model_str=downloaded_config.get("model_str", high_level_config["model_str"]),
            model_config=downloaded_config.get("model_config", cfg.model.model_config),
            torch_hub_force_reload=force_reload,
        )
    except Exception as config_e:
        print(f"Failed to download/use HuggingFace config: {config_e}")
        print("Falling back to local configuration...")
        return init_model(
            model_str=cfg.model.model_str,
            model_config=cfg.model.model_config,
            torch_hub_force_reload=force_reload,
        )


def _load_pretrained_weights(model, high_level_config):
    """Download the checkpoint from HuggingFace Hub and load it into `model`.

    Best-effort: any failure is logged and the model is left with its random
    initialization (mirrors the original warn-and-continue behavior).

    Args:
        model (torch.nn.Module): Model to receive the weights (mutated in place).
        high_level_config (dict): Settings with "hf_model_name" and
            "checkpoint_name".
    """
    import torch
    from huggingface_hub import hf_hub_download

    try:
        checkpoint_filename = high_level_config["checkpoint_name"]
        # Download the model weights
        checkpoint_path = hf_hub_download(
            repo_id=high_level_config["hf_model_name"],
            filename=checkpoint_filename,
            token=load_hf_token(),
        )
        print("start loading checkpoint")
        if checkpoint_filename.endswith(".safetensors"):
            from safetensors.torch import load_file

            checkpoint = load_file(checkpoint_path)
        else:
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # acceptable only because the checkpoint comes from a trusted repo.
            checkpoint = torch.load(
                checkpoint_path, map_location="cpu", weights_only=False
            )
        print("start loading state_dict")
        # Checkpoints may nest the weights under "model" or "state_dict";
        # strict=False tolerates minor key mismatches.
        if "model" in checkpoint:
            model.load_state_dict(checkpoint["model"], strict=False)
        elif "state_dict" in checkpoint:
            model.load_state_dict(checkpoint["state_dict"], strict=False)
        else:
            model.load_state_dict(checkpoint, strict=False)
        print(
            f"Successfully loaded pretrained weights from HuggingFace Hub ({checkpoint_filename})"
        )
    except Exception as e:
        print(f"Warning: Could not load pretrained weights: {e}")
        print("Proceeding with randomly initialized model...")


def initialize_mapanything_model(high_level_config, device):
    """
    Initialize MapAnything model with three-tier fallback approach:
    1. Try HuggingFace from_pretrained()
    2. Download HF config + use local model factory + load HF weights
    3. Pure local configuration fallback

    Args:
        high_level_config (dict): Configuration dictionary containing model settings
        device (torch.device): Device to load the model on

    Returns:
        torch.nn.Module: Initialized MapAnything model
    """
    from mapanything.models import MapAnything

    print("Initializing MapAnything model...")
    # Compose the local Hydra config up front; it backs tiers 2 and 3.
    cfg = init_hydra_config(
        high_level_config["path"], overrides=high_level_config["config_overrides"]
    )
    # Tier 1: the simple path — let HF resolve config + weights in one call.
    try:
        print("Loading MapAnything model from_pretrained...")
        model = MapAnything.from_pretrained(high_level_config["hf_model_name"]).to(
            device
        )
        print("Loading MapAnything model from_pretrained succeeded...")
        return model
    except Exception as e:
        print(f"from_pretrained failed: {e}")
        print("Falling back to local configuration approach using hf_hub_download...")
    # Tiers 2/3: construct the model locally, then best-effort load weights.
    model = _create_model_from_config(high_level_config, cfg)
    _load_pretrained_weights(model, high_level_config)
    return model.to(device)