File size: 5,204 Bytes
c8fa89c bca8f87 1cf9e80 c8fa89c 7dac8c1 1cf9e80 7dac8c1 c8fa89c bca8f87 c8fa89c a345062 c8fa89c 2161eb0 7dac8c1 a345062 c8fa89c 7dac8c1 2169e97 7dac8c1 2161eb0 1cf9e80 2161eb0 2169e97 1cf9e80 2169e97 1cf9e80 2161eb0 1cf9e80 2161eb0 7dac8c1 1cf9e80 2169e97 1cf9e80 2169e97 bca8f87 7dac8c1 1cf9e80 7dac8c1 c8fa89c bca8f87 c8fa89c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import gc
import os
import random
from dataclasses import dataclass, field
from typing import Optional, List

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed, TextStreamer

from .utils import dbg
# Pin the cuBLAS workspace layout so CUDA matmuls can run deterministically once
# torch.use_deterministic_algorithms(True) is enabled (see LLM.set_all_seeds).
# Done at import time, presumably so it is in place before cuBLAS is first used.
os.environ.update(CUBLAS_WORKSPACE_CONFIG=":4096:8")
@dataclass
class StableLLMConfig:
    """Architecture-independent summary of a loaded causal language model.

    Instances are produced by ``LLM._populate_stable_config`` so callers can
    address hidden size and transformer blocks without knowing the concrete
    model class.
    """

    # Width of the model's hidden representation.
    hidden_dim: int
    # Count of transformer blocks in the model.
    num_layers: int
    # The transformer block modules themselves; kept out of repr() to keep it short.
    layer_list: List[torch.nn.Module] = field(repr=False, default_factory=list)
class LLM:
    """Wrapper around a Hugging Face causal LM with reproducible seeding.

    Loads the tokenizer and model for ``model_id``, puts the model in eval
    mode, and records a :class:`StableLLMConfig` (hidden size, layer count,
    transformer blocks) so callers can address layers without knowing the
    concrete architecture.
    """

    # Candidate attribute paths, starting from the top-level model, to the
    # container of transformer blocks. They are tried in order and the first
    # path that resolves to a non-None object wins.
    _LAYER_PATHS = (
        ("model", "language_model", "layers"),  # language model nested in a wrapper
        ("model", "layers"),                    # decoder stack directly under `.model`
        ("transformer", "h"),                   # GPT-2-style `.transformer.h`
        ("gpt_neox", "layers"),                 # GPT-NeoX-style `.gpt_neox.layers`
    )

    def __init__(self, model_id: str, device: str = "auto", seed: int = 42):
        """Load the tokenizer and model for ``model_id``.

        Args:
            model_id: Hub id or local path passed to ``from_pretrained``.
            device: Forwarded to ``from_pretrained`` as ``device_map``.
            seed: Seed applied to all RNGs now and before every generation.

        Raises:
            AssertionError: If the hidden size, layer count, or list of
                transformer blocks cannot be determined.
        """
        self.model_id = model_id
        self.seed = seed
        self.set_all_seeds(self.seed)

        token = os.environ.get("HF_TOKEN")
        # Warn early: these model families are presumably gated on the Hub.
        if not token and ("gemma" in model_id or "llama" in model_id):
            print("[WARN] No HF_TOKEN set...", flush=True)

        # Load in bfloat16 only when CUDA is available; otherwise keep the default dtype.
        kwargs = {"torch_dtype": torch.bfloat16} if torch.cuda.is_available() else {}

        dbg(f"Loading tokenizer for '{model_id}'...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, token=token)
        dbg(f"Loading model '{model_id}' with kwargs: {kwargs}")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, device_map=device, token=token, **kwargs
        )

        # Best-effort switch to the 'eager' attention implementation. A failure
        # (e.g. the method is unavailable) is reported, not raised.
        try:
            self.model.set_attn_implementation('eager')
            dbg("Successfully set attention implementation to 'eager'.")
        except Exception as e:
            print(f"[WARN] Could not set 'eager' attention: {e}.", flush=True)

        self.model.eval()
        self.config = self.model.config
        self.stable_config = self._populate_stable_config()
        print(f"[INFO] Model '{model_id}' loaded on device: {self.model.device}", flush=True)

    @staticmethod
    def _resolve_attr_path(root, path):
        """Follow ``path`` one attribute at a time starting from ``root``.

        Returns the final object, or ``None`` as soon as a hop is missing or ``None``.
        """
        node = root
        for name in path:
            node = getattr(node, name, None)
            if node is None:
                return None
        return node

    def _populate_stable_config(self) -> StableLLMConfig:
        """Derive hidden size, layer count and transformer blocks from the model.

        Returns:
            A :class:`StableLLMConfig` for ``self.model``.

        Raises:
            AssertionError: If any of the three values cannot be determined.
                The model architecture is dumped via ``dbg`` first.
        """
        # Prefer the embedding matrix width; fall back to config fields. Some
        # models raise NotImplementedError from get_input_embeddings().
        try:
            hidden_dim = self.model.get_input_embeddings().weight.shape[1]
        except (AttributeError, NotImplementedError):
            hidden_dim = getattr(self.config, 'hidden_size', getattr(self.config, 'd_model', 0))

        layer_list = []
        for path in self._LAYER_PATHS:
            found = self._resolve_attr_path(self.model, path)
            if found is not None:
                layer_list = found
                break

        num_layers = 0
        if layer_list:
            try:
                num_layers = len(layer_list)
            except TypeError:
                pass  # truthy but unsized container; use the config value below
        if num_layers == 0:
            num_layers = getattr(self.config, 'num_hidden_layers', getattr(self.config, 'num_layers', 0))

        if hidden_dim <= 0 or num_layers <= 0 or not layer_list:
            dbg("--- CRITICAL: Failed to auto-determine model configuration. ---")
            dbg(f"Detected hidden_dim: {hidden_dim}, num_layers: {num_layers}, found_layer_list: {bool(layer_list)}")
            dbg("--- DUMPING MODEL ARCHITECTURE FOR DEBUGGING: ---")
            dbg(self.model)
            dbg("--- END ARCHITECTURE DUMP ---")

        # Raised explicitly instead of via `assert` so the checks still run under
        # `python -O`. AssertionError is kept for callers that already catch it.
        if hidden_dim <= 0:
            raise AssertionError("Could not determine hidden dimension.")
        if num_layers <= 0:
            raise AssertionError("Could not determine number of layers.")
        if not layer_list:
            raise AssertionError("Could not find the list of transformer layers.")

        dbg(f"Populated stable config: hidden_dim={hidden_dim}, num_layers={num_layers}")
        return StableLLMConfig(hidden_dim=hidden_dim, num_layers=num_layers, layer_list=layer_list)

    def set_all_seeds(self, seed: int):
        """Seed every RNG used here and enable deterministic PyTorch kernels.

        Seeds Python's ``random``, NumPy, PyTorch on the CPU and (when present)
        on all CUDA devices, and calls transformers' ``set_seed``. It then asks
        PyTorch for deterministic algorithms, warning rather than raising for
        ops that have no deterministic implementation.
        """
        # Only reaches child processes: the running interpreter's str-hash seed
        # was fixed at startup.
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        set_seed(seed)
        torch.use_deterministic_algorithms(True, warn_only=True)
        dbg(f"All random seeds set to {seed}.")

    @torch.no_grad()
    def generate_text(self, prompt: str, max_new_tokens: int, temperature: float) -> str:
        """Generate free-form text in response to ``prompt``.

        The prompt is sent as a single user turn through the tokenizer's chat
        template when one is defined; otherwise it is tokenized as plain text.
        All seeds are reset first, so repeated calls start from the same RNG
        state.

        Args:
            prompt: The user message.
            max_new_tokens: Maximum number of tokens to generate after the prompt.
            temperature: Sampling temperature. Values ``<= 0`` select greedy
                decoding (``do_sample=False``).

        Returns:
            The decoded newly generated tokens, with special tokens removed.
        """
        self.set_all_seeds(self.seed)  # ensure reproducibility

        if getattr(self.tokenizer, "chat_template", None):
            messages = [{"role": "user", "content": prompt}]
            # return_dict=True pins the return type instead of depending on the
            # library version's default.
            encoded = self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
                return_dict=True,
            )
        else:
            # Tokenizers without a chat template (e.g. base models): use the raw prompt.
            encoded = self.tokenizer(prompt, return_tensors="pt")

        input_ids = encoded["input_ids"].to(self.model.device)
        # A single unpadded sequence: every position is a real token.
        attention_mask = torch.ones_like(input_ids)

        gen_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": temperature > 0}
        if temperature > 0:
            # temperature only affects sampling. Passing it for greedy decoding
            # just triggers a generation-config warning.
            gen_kwargs["temperature"] = temperature

        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **gen_kwargs,
        )

        # Decode only the newly generated tokens.
        response_tokens = outputs[0, input_ids.shape[-1]:]
        return self.tokenizer.decode(response_tokens, skip_special_tokens=True)
def get_or_load_model(model_id: str, seed: int, device: str = "auto") -> LLM:
    """Construct a fresh :class:`LLM` for ``model_id``.

    Despite the name, nothing is cached: every call loads a new model so each
    run is isolated. Before loading, unreachable Python objects are collected
    and PyTorch's CUDA cache is emptied, so memory held by a previously
    discarded model can be returned to the device.

    Args:
        model_id: Hub id or local path of the model to load.
        seed: Seed forwarded to :class:`LLM`.
        device: Forwarded to :class:`LLM` (used as ``device_map``); defaults to ``"auto"``.

    Returns:
        A newly constructed :class:`LLM`.
    """
    dbg(f"--- Force-reloading model '{model_id}' for total run isolation ---")
    # If a dropped model is only reachable through reference cycles, its CUDA
    # tensors stay allocated until the cycle collector runs, and empty_cache()
    # can only release blocks that no live tensor uses. So collect first.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return LLM(model_id=model_id, seed=seed, device=device)
|