"""Deterministic loading and text-generation wrapper for Hugging Face causal language models."""

import os
import torch
import random
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed, TextStreamer
from typing import Optional, List
from dataclasses import dataclass, field

from .utils import dbg

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

@dataclass
class StableLLMConfig:
    """Architecture facts (hidden size, layer count, decoder layer modules) resolved once at load time."""
    hidden_dim: int
    num_layers: int
    layer_list: List[torch.nn.Module] = field(default_factory=list, repr=False)

class LLM:
    """Wrapper around a Hugging Face causal LM with deterministic seeding and architecture introspection."""

    def __init__(self, model_id: str, device: str = "auto", seed: int = 42):
        self.model_id = model_id
        self.seed = seed
        self.set_all_seeds(self.seed)

        token = os.environ.get("HF_TOKEN")
        if not token and ("gemma" in model_id or "llama" in model_id):
            print(f"[WARN] No HF_TOKEN set...", flush=True)

        kwargs = {"torch_dtype": torch.bfloat16} if torch.cuda.is_available() else {}

        dbg(f"Loading tokenizer for '{model_id}'...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, token=token)

        dbg(f"Loading model '{model_id}' with kwargs: {kwargs}")
        self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, token=token, **kwargs)

        try:
            self.model.set_attn_implementation('eager')
            dbg("Successfully set attention implementation to 'eager'.")
        except Exception as e:
            print(f"[WARN] Could not set 'eager' attention: {e}.", flush=True)

        self.model.eval()
        self.config = self.model.config

        self.stable_config = self._populate_stable_config()

        print(f"[INFO] Model '{model_id}' loaded on device: {self.model.device}", flush=True)

    def _populate_stable_config(self) -> StableLLMConfig:
        """Resolve hidden_dim, num_layers, and the decoder-layer list across differing model architectures."""
        hidden_dim = 0
        try:
            hidden_dim = self.model.get_input_embeddings().weight.shape[1]
        except AttributeError:
            hidden_dim = getattr(self.config, 'hidden_size', getattr(self.config, 'd_model', 0))

        num_layers = 0
        layer_list = []
        try:
            if hasattr(self.model, 'model') and hasattr(self.model.model, 'language_model') and hasattr(self.model.model.language_model, 'layers'):
                 layer_list = self.model.model.language_model.layers
            elif hasattr(self.model, 'model') and hasattr(self.model.model, 'layers'):
                 layer_list = self.model.model.layers
            elif hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'):
                 layer_list = self.model.transformer.h

            if layer_list:
                num_layers = len(layer_list)
        except (AttributeError, TypeError):
            pass

        if num_layers == 0:
            num_layers = getattr(self.config, 'num_hidden_layers', getattr(self.config, 'num_layers', 0))

        if hidden_dim <= 0 or num_layers <= 0 or not layer_list:
            dbg("--- CRITICAL: Failed to auto-determine model configuration. ---")
            dbg(f"Detected hidden_dim: {hidden_dim}, num_layers: {num_layers}, found_layer_list: {bool(layer_list)}")
            dbg("--- DUMPING MODEL ARCHITECTURE FOR DEBUGGING: ---")
            dbg(self.model)
            dbg("--- END ARCHITECTURE DUMP ---")

        assert hidden_dim > 0, "Could not determine hidden dimension."
        assert num_layers > 0, "Could not determine number of layers."
        assert layer_list, "Could not find the list of transformer layers."

        dbg(f"Populated stable config: hidden_dim={hidden_dim}, num_layers={num_layers}")
        return StableLLMConfig(hidden_dim=hidden_dim, num_layers=num_layers, layer_list=layer_list)

    def set_all_seeds(self, seed: int):
        """Seed Python, NumPy, and PyTorch (CPU and CUDA) and enable deterministic algorithms."""
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        set_seed(seed)
        torch.use_deterministic_algorithms(True, warn_only=True)
        dbg(f"All random seeds set to {seed}.")

    # --- NEW: Generic text-generation method ---
    @torch.no_grad()
    def generate_text(self, prompt: str, max_new_tokens: int, temperature: float) -> str:
        """Generate free-form text in response to a prompt."""
        self.set_all_seeds(self.seed)  # re-seed for reproducible sampling

        messages = [{"role": "user", "content": prompt}]
        inputs = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)

        outputs = self.model.generate(
            inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
        )

        # Decode only the newly generated tokens
        response_tokens = outputs[0, inputs.shape[-1]:]
        return self.tokenizer.decode(response_tokens, skip_special_tokens=True)

def get_or_load_model(model_id: str, seed: int) -> LLM:
    """Load a fresh LLM instance, clearing the CUDA cache first so runs stay isolated."""
    dbg(f"--- Force-reloading model '{model_id}' for total run isolation ---")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return LLM(model_id=model_id, seed=seed)
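

# --- Illustrative usage (a sketch, not part of the original module) ---
# Because of the relative import above, run this as a module from within its package,
# e.g. `python -m <package>.<this_module>`. The model id below is an assumption for
# illustration only; any causal LM whose tokenizer defines a chat template should work,
# and gated checkpoints (gemma, llama) additionally require HF_TOKEN to be set.
if __name__ == "__main__":
    llm = get_or_load_model("Qwen/Qwen2.5-0.5B-Instruct", seed=42)  # hypothetical example checkpoint
    print(f"hidden_dim={llm.stable_config.hidden_dim}, num_layers={llm.stable_config.num_layers}")
    reply = llm.generate_text(
        "Briefly explain what a causal language model is.",
        max_new_tokens=64,
        temperature=0.7,
    )
    print(reply)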