neuralworm's picture
initial commit
c8fa89c
raw
history blame
3.25 kB
import os
import torch
import random
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from typing import Optional
from .utils import dbg
# Pin the cuBLAS workspace configuration at import time so that cuBLAS
# operations are deterministic (reproducible) when running on GPU.
os.environ.update({"CUBLAS_WORKSPACE_CONFIG": ":4096:8"})
class LLM:
    """
    Robust interface for loading and interacting with a language model.

    Constructing an instance seeds the Python, NumPy and PyTorch random number
    generators and then loads a fresh tokenizer and causal language model. The
    class is intended to provide isolation and reproducibility for every load.

    Attributes:
        model_id: Identifier the tokenizer and model were loaded from.
        seed: Seed applied through ``set_all_seeds`` before loading.
        tokenizer: Object returned by ``AutoTokenizer.from_pretrained``.
        model: Object returned by ``AutoModelForCausalLM.from_pretrained``,
            switched to eval mode.
        config: Shortcut to ``model.config``.
    """

    def __init__(self, model_id: str, device: str = "auto", seed: int = 42):
        """
        Seed all RNGs, then load the tokenizer and model for ``model_id``.

        Args:
            model_id: Model identifier handed to ``from_pretrained``.
            device: Forwarded as ``device_map`` when loading the model. An
                explicit ``"cpu"`` also disables the bfloat16 default.
            seed: Seed passed to ``set_all_seeds``.
        """
        self.model_id = model_id
        self.seed = seed

        # Set all seeds for this instance to ensure deterministic behavior.
        self.set_all_seeds(self.seed)

        token = os.environ.get("HF_TOKEN")
        if not token and ("gemma" in model_id or "llama" in model_id):
            # Name-based heuristic only: loading is still attempted below.
            print(f"[WARN] No HF_TOKEN environment variable set. If '{model_id}' is a gated model, this will fail.", flush=True)

        # Use bfloat16 on CUDA for performance and memory efficiency if available.
        kwargs = self._dtype_kwargs(device)

        dbg(f"Loading tokenizer for '{model_id}'...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, token=token)

        dbg(f"Loading model '{model_id}' with kwargs: {kwargs}")
        self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, token=token, **kwargs)

        # Set attention implementation to 'eager' to ensure hooks work reliably.
        # This is critical for mechanistic interpretability.
        try:
            self.model.set_attn_implementation('eager')
            dbg("Successfully set attention implementation to 'eager'.")
        except Exception as e:
            # Deliberate best-effort: the model stays usable, only warn.
            print(f"[WARN] Could not set attention implementation to 'eager': {e}. Hook-based diagnostics might fail.", flush=True)

        self.model.eval()
        self.config = self.model.config

        print(f"[INFO] Model '{model_id}' loaded successfully on device: {self.model.device}", flush=True)

    @staticmethod
    def _dtype_kwargs(device: str) -> dict:
        """Return the dtype kwargs for ``from_pretrained`` given the target device."""
        # bfloat16 is meant for CUDA execution. A caller that explicitly asks
        # for "cpu" must not get it just because a GPU happens to be present.
        if torch.cuda.is_available() and device != "cpu":
            return {"torch_dtype": torch.bfloat16}
        return {}

    def set_all_seeds(self, seed: int) -> None:
        """
        Set all relevant random seeds for Python, NumPy, and PyTorch to ensure
        reproducibility of stochastic processes like sampling.

        Also asks PyTorch to use deterministic algorithms, in warn-only mode.

        Args:
            seed: Value applied to every random number generator.
        """
        # NOTE(review): the interpreter reads PYTHONHASHSEED at startup, so this
        # presumably only affects child processes started afterwards - confirm.
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        set_seed(seed)
        # Enforce deterministic algorithms in PyTorch; warn_only=True makes
        # operations without a deterministic variant warn instead of raising.
        torch.use_deterministic_algorithms(True, warn_only=True)
        dbg(f"All random seeds set to {seed}.")
def get_or_load_model(model_id: str, seed: int) -> LLM:
    """
    Load a fresh instance of the model on EVERY call.

    Despite the "get_or_load" name, nothing is cached or reused here. This
    avoids caching and state leaking between experiments, so that each run is
    isolated from the previous one.

    Args:
        model_id: Model identifier forwarded to ``LLM``.
        seed: Seed forwarded to ``LLM``.

    Returns:
        A newly constructed ``LLM``.
    """
    dbg(f"--- Force-reloading model '{model_id}' for total run isolation ---")

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        # Clear the CUDA cache before the next model is loaded.
        torch.cuda.empty_cache()
        dbg("Cleared CUDA cache before reloading.")

    fresh_instance = LLM(model_id=model_id, seed=seed)
    return fresh_instance