import gc
import os
import random
from typing import Optional

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

from .utils import dbg

# Ensure deterministic cuBLAS operations for reproducibility on GPU.
# cuBLAS reads this variable when the CUDA context is created, so it has to be
# set before any CUDA work happens in this process. That is why it is done at
# module import time rather than inside LLM.__init__.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# Substrings of model ids (compared case-insensitively) for model families that
# are commonly gated on the Hugging Face Hub and therefore usually need a token.
_GATED_MODEL_HINTS = ("gemma", "llama")


class LLM:
    """A robust, clean interface for loading and interacting with a causal language model.

    Each instance re-seeds every RNG and loads its own tokenizer and model,
    aiming for isolation and reproducibility between runs.

    Attributes:
        model_id: Hub id (or local path) the model was loaded from.
        seed: Seed applied to all RNGs during construction.
        tokenizer: The loaded fast tokenizer.
        model: The loaded causal LM, switched to eval mode.
        config: Shortcut to ``self.model.config``.
    """

    def __init__(self, model_id: str, device: str = "auto", seed: int = 42):
        """Seed all RNGs, then load the tokenizer and model for ``model_id``.

        Args:
            model_id: Hugging Face Hub id or local path of the model.
            device: Forwarded to ``from_pretrained`` as ``device_map``.
            seed: Seed applied to all RNGs before loading.
        """
        self.model_id = model_id
        self.seed = seed
        # Seed before loading, so any randomly initialised weights are reproducible.
        self.set_all_seeds(self.seed)

        token = os.environ.get("HF_TOKEN")
        lowered_id = model_id.lower()
        if not token and any(hint in lowered_id for hint in _GATED_MODEL_HINTS):
            print(f"[WARN] No HF_TOKEN set. If '{model_id}' is gated, loading will fail.", flush=True)

        # bfloat16 is used only when CUDA is available. On CPU no dtype is
        # passed, so from_pretrained falls back to its default.
        kwargs = {"torch_dtype": torch.bfloat16} if torch.cuda.is_available() else {}

        dbg(f"Loading tokenizer for '{model_id}'...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, token=token)

        dbg(f"Loading model '{model_id}' with kwargs: {kwargs}")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, device_map=device, token=token, **kwargs
        )

        # Best effort: switch to the eager attention implementation. Failures
        # (e.g. a transformers version without set_attn_implementation) only
        # warn; loading continues with the default implementation.
        # NOTE(review): presumably eager is wanted for determinism or to expose
        # attention weights; confirm.
        try:
            self.model.set_attn_implementation('eager')
            dbg("Successfully set attention implementation to 'eager'.")
        except Exception as e:
            print(f"[WARN] Could not set 'eager' attention: {e}.", flush=True)

        # Inference mode: disables dropout and other train-only behaviour.
        self.model.eval()
        self.config = self.model.config
        print(f"[INFO] Model '{model_id}' loaded on device: {self.model.device}", flush=True)

    def set_all_seeds(self, seed: int):
        """Set every relevant seed for maximum reproducibility.

        Seeds Python's ``random``, NumPy, PyTorch (CPU and all CUDA devices) and
        transformers. It also enables PyTorch deterministic algorithms, in
        warn-only mode, for the whole process.

        Note:
            Setting ``PYTHONHASHSEED`` at runtime does not change hash
            randomisation of the running interpreter, which is fixed at startup.
            It only affects child processes launched afterwards.
        """
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        set_seed(seed)
        # warn_only=True: ops without a deterministic implementation emit a
        # warning instead of raising.
        torch.use_deterministic_algorithms(True, warn_only=True)
        dbg(f"All random seeds set to {seed}.")


def get_or_load_model(model_id: str, seed: int, device: str = "auto") -> LLM:
    """Load a fresh, isolated model instance on every call.

    Despite the name, nothing is cached: every call constructs a new ``LLM``.

    Args:
        model_id: Hugging Face Hub id or local path of the model.
        seed: Seed applied to all RNGs before loading.
        device: Forwarded to ``LLM`` (and from there used as ``device_map``).

    Returns:
        A newly constructed ``LLM``.
    """
    dbg(f"--- Force-reloading model '{model_id}' for total run isolation ---")
    # Run the cyclic garbage collector first. A previously dropped model may
    # still be kept alive by reference cycles, and empty_cache() can only
    # release CUDA blocks that no live tensor references. Memory of a model
    # the caller still holds a reference to is not freed here.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return LLM(model_id=model_id, seed=seed, device=device)