import os
import random

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

from .utils import dbg

# cuBLAS kernels are only deterministic with this workspace configuration.
# It must be set before the first CUDA context is created, hence at import time.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


class LLM:
    """
    A robust, cleaned-up interface for loading and interacting with a
    language model. Guarantees isolation and reproducibility.
    """

    def __init__(self, model_id: str, device: str = "auto", seed: int = 42):
        self.model_id = model_id
        self.seed = seed
        self.set_all_seeds(self.seed)

        # Gated repositories (e.g. Gemma or Llama checkpoints) require an access token.
        token = os.environ.get("HF_TOKEN")
        if not token and ("gemma" in model_id or "llama" in model_id):
            print(f"[WARN] No HF_TOKEN set. If '{model_id}' is gated, loading will fail.", flush=True)

        # bfloat16 halves memory on GPU; keep the default dtype on CPU.
        kwargs = {"torch_dtype": torch.bfloat16} if torch.cuda.is_available() else {}

        dbg(f"Loading tokenizer for '{model_id}'...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, token=token)

        dbg(f"Loading model '{model_id}' with kwargs: {kwargs}")
        self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, token=token, **kwargs)

        # Eager attention avoids fused kernels whose results can vary across
        # backends, which matters for reproducibility.
        try:
            self.model.set_attn_implementation('eager')
            dbg("Successfully set attention implementation to 'eager'.")
        except Exception as e:
            print(f"[WARN] Could not set 'eager' attention: {e}.", flush=True)

        self.model.eval()
        self.config = self.model.config
        print(f"[INFO] Model '{model_id}' loaded on device: {self.model.device}", flush=True)

    def set_all_seeds(self, seed: int):
        """Sets all relevant seeds for maximum reproducibility."""
        # Note: PYTHONHASHSEED only affects hash randomization in subprocesses;
        # the running interpreter's hash seed is fixed at startup.
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        set_seed(seed)
        # warn_only=True: ops without a deterministic implementation emit a
        # warning instead of raising, so models using them still run.
        torch.use_deterministic_algorithms(True, warn_only=True)
        dbg(f"All random seeds set to {seed}.")


def get_or_load_model(model_id: str, seed: int) -> LLM:
    """Loads a fresh, isolated model instance on every call."""
    dbg(f"--- Force-reloading model '{model_id}' for total run isolation ---")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return LLM(model_id=model_id, seed=seed)
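

# Minimal usage sketch (not part of the module's API): "gpt2" stands in for
# any ungated causal LM id. The relative `.utils` import means this file must
# run as part of its package, e.g. `python -m <package>.<module>`.
if __name__ == "__main__":
    llm = get_or_load_model("gpt2", seed=42)
    inputs = llm.tokenizer("Hello, world", return_tensors="pt").to(llm.model.device)
    with torch.no_grad():
        out = llm.model.generate(**inputs, max_new_tokens=20, do_sample=False)
    print(llm.tokenizer.decode(out[0], skip_special_tokens=True))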