# bp_phi/llm_iface.py
import os
# Must be set before torch initializes CUDA: makes cuBLAS operations deterministic.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from typing import Optional

DEBUG = 1

def dbg(*args):
    if DEBUG:
        print("[DEBUG:llm_iface]", *args, flush=True)

class LLM:
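    """Thin, reproducibility-focused wrapper around a Hugging Face causal chat LM."""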
    def __init__(self, model_id: str, device: str = "auto", dtype: Optional[str] = None, seed: int = 42):
        self.model_id = model_id
        self.seed = seed

        # transformers' set_seed seeds Python, NumPy, and torch in one call.
        set_seed(seed)
        token = os.environ.get("HF_TOKEN")

        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, token=token)
        # Many causal LMs ship without a pad token; fall back to EOS so padding works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Resolve the load dtype: honor an explicit `dtype` string (e.g. "bfloat16"),
        # otherwise default to bfloat16 on CUDA and the framework default elsewhere.
        kwargs = {}
        if dtype is not None:
            kwargs["torch_dtype"] = getattr(torch, dtype)
        elif torch.cuda.is_available():
            kwargs["torch_dtype"] = torch.bfloat16

        self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, token=token, **kwargs)
        self.model.eval()

        dbg(f"Loaded model: {model_id}")

    def generate_response(self, system_prompt: str, user_prompt: str, temperature: float = 0.1) -> str:
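        """Generate a single assistant reply for a system+user prompt pair."""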
        # Re-seed before each call so repeated prompts produce identical generations.
        set_seed(self.seed)

        messages = [
            {"role": "user", "content": f"{system_prompt}\n\n{user_prompt}"}
        ]

        # Gemma-style chat templates accept only user/assistant turns, so the system
        # prompt is folded into the single user message above.
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        input_token_length = inputs.input_ids.shape[1]

        with torch.no_grad():
            # Stop on EOS; also stop on "<|eot_id|>" when the tokenizer defines it (e.g. Llama 3).
            terminators = [self.tokenizer.eos_token_id]
            if "<|eot_id|>" in self.tokenizer.additional_special_tokens:
                terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

            # Sample whenever temperature > 0; use greedy decoding at temperature 0.
            gen_kwargs = dict(
                max_new_tokens=200,
                eos_token_id=terminators,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            if temperature > 0.0:
                gen_kwargs.update(do_sample=True, temperature=temperature)
            else:
                gen_kwargs["do_sample"] = False

            out = self.model.generate(**inputs, **gen_kwargs)

        # Decode only the newly generated tokens, dropping the prompt and special tokens.
        completion = self.tokenizer.decode(out[0, input_token_length:], skip_special_tokens=True).strip()

        dbg("Cleaned Agent Completion:", completion)
        return completion
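

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving the wrapper above. The model id below is an
# assumption; substitute any chat-tuned checkpoint you have access to, and
# export HF_TOKEN first if the weights are gated.
if __name__ == "__main__":
    llm = LLM("google/gemma-2-2b-it", device="auto", seed=42)
    reply = llm.generate_response(
        system_prompt="You are a concise assistant.",
        user_prompt="Summarize what this module does in one sentence.",
        temperature=0.1,
    )
    print(reply)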