import os
import torch
import random
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from dataclasses import dataclass

from .utils import dbg

# Ensure deterministic cuBLAS operations for reproducibility on GPU. This must
# be set before CUDA is initialized (required by
# torch.use_deterministic_algorithms on CUDA >= 10.2).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

@dataclass
class StableLLMConfig:
    """
    A stable, internal abstraction layer for model configurations.
    Keeps our code independent of the changing attribute names in `transformers`.
    """
    hidden_dim: int
    num_layers: int
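
# Why this abstraction exists: `transformers` config classes disagree on
# attribute names. For example, GPT-2-style configs expose `n_embd`/`n_layer`,
# while Llama-style configs expose `hidden_size`/`num_hidden_layers`.
# StableLLMConfig normalizes both into one shape (illustrative values):
#
#   StableLLMConfig(hidden_dim=768,  num_layers=12)  # e.g. gpt2
#   StableLLMConfig(hidden_dim=4096, num_layers=32)  # e.g. a 7B-class Llama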

class LLM:
    """
    A robust, cleaned-up interface for loading and interacting with a language model.
    Guarantees isolation and reproducibility.
    """
    def __init__(self, model_id: str, device: str = "auto", seed: int = 42):
        self.model_id = model_id
        self.seed = seed
        self.set_all_seeds(self.seed)

        token = os.environ.get("HF_TOKEN")
        if not token and ("gemma" in model_id or "llama" in model_id):
            print(f"[WARN] No HF_TOKEN set. If '{model_id}' is gated, loading will fail.", flush=True)

        # Use bfloat16 on GPU to halve memory; default to fp32 on CPU.
        kwargs = {"torch_dtype": torch.bfloat16} if torch.cuda.is_available() else {}

        dbg(f"Loading tokenizer for '{model_id}'...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, token=token)

        dbg(f"Loading model '{model_id}' with kwargs: {kwargs}")
        self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, token=token, **kwargs)

        # Prefer 'eager' attention: fused SDPA/FlashAttention kernels typically
        # do not expose per-head attention weights.
        try:
            self.model.set_attn_implementation('eager')
            dbg("Successfully set attention implementation to 'eager'.")
        except Exception as e:
            print(f"[WARN] Could not set 'eager' attention: {e}.", flush=True)

        self.model.eval()
        self.config = self.model.config  # Keep access to the original configuration

        # --- NEW: populate the stable configuration abstraction ---
        self.stable_config = self._populate_stable_config()

        print(f"[INFO] Model '{model_id}' loaded on device: {self.model.device}", flush=True)

    def _populate_stable_config(self) -> StableLLMConfig:
        """
        Reads the volatile `transformers` configuration and populates our stable dataclass.
        """
        # Robust method for hidden_dim: read it off the input embedding matrix,
        # falling back to common config attribute names.
        try:
            hidden_dim = self.model.get_input_embeddings().weight.shape[1]
        except AttributeError:
            hidden_dim = getattr(self.config, 'hidden_size', getattr(self.config, 'd_model', 0))

        # Robust method for num_layers (GPT-2-style configs use `n_layer`)
        num_layers = getattr(self.config, 'num_hidden_layers',
                             getattr(self.config, 'num_layers',
                                     getattr(self.config, 'n_layer', 0)))

        # Assertions to ensure scientific validity
        assert hidden_dim > 0, "Could not determine hidden dimension from model config."
        assert num_layers > 0, "Could not determine number of layers from model config."

        dbg(f"Populated stable config: hidden_dim={hidden_dim}, num_layers={num_layers}")
        return StableLLMConfig(hidden_dim=hidden_dim, num_layers=num_layers)

    def set_all_seeds(self, seed: int):
        """Sets all relevant seeds for maximum reproducibility."""
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        set_seed(seed)
        # Warn instead of raising when an op has no deterministic implementation.
        torch.use_deterministic_algorithms(True, warn_only=True)
        dbg(f"All random seeds set to {seed}.")
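
        # Note: transformers' `set_seed` re-seeds Python, NumPy and torch as
        # well; the explicit calls above are kept so the intent stays visible.
        # For stricter determinism, cuDNN can optionally be pinned too:
        #   torch.backends.cudnn.deterministic = True
        #   torch.backends.cudnn.benchmark = False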

def get_or_load_model(model_id: str, seed: int) -> LLM:
    """Loads a fresh, isolated instance of the model on every call."""
    dbg(f"--- Force-reloading model '{model_id}' for total run isolation ---")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return LLM(model_id=model_id, seed=seed)
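
# Minimal usage sketch (assumes a small public model such as "gpt2" is
# available locally or from the Hub; kept as comments so importing this
# module has no side effects):
#
#   llm = get_or_load_model("gpt2", seed=42)
#   print(llm.stable_config.hidden_dim, llm.stable_config.num_layers)
#   inputs = llm.tokenizer("Hello", return_tensors="pt").to(llm.model.device)
#   with torch.no_grad():
#       out = llm.model(**inputs)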