# bp_phi/llm_iface.py
import os
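# Must be set before any CUDA context is created: ":4096:8" makes cuBLAS use
# deterministic workspace algorithms, which is needed for reproducible runs.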
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import torch, random, numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from typing import List, Optional
DEBUG = 1

def dbg(*args):
    """Print a namespaced debug message when DEBUG is enabled."""
    if DEBUG:
        print("[DEBUG:llm_iface]", *args, flush=True)
class LLM:
def __init__(self, model_id: str, device: str = "auto", dtype: Optional[str] = None, seed: int = 42):
self.model_id = model_id
self.seed = seed
set_seed(seed)
token = os.environ.get("HF_TOKEN")
self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, token=token)
        if self.tokenizer.pad_token is None:
            # Some checkpoints ship without a pad token; fall back to EOS for padding.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        kwargs = {}
        if dtype is not None:
            # Honor an explicitly requested dtype string, e.g. "bfloat16" or "float16".
            kwargs["torch_dtype"] = getattr(torch, dtype)
        elif torch.cuda.is_available():
            kwargs["torch_dtype"] = torch.bfloat16
        self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, token=token, **kwargs)
self.model.eval()
dbg(f"Loaded model: {model_id}")
def generate_response(self, system_prompt: str, user_prompt: str, temperature: float = 0.1) -> str:
        # Re-seed before every call so repeated generations are reproducible.
        set_seed(self.seed)
        # Fold the system prompt into a single user turn: Gemma's chat template
        # rejects a separate "system" role, and a user-only message list is
        # robust across chat models.
        messages = [
            {"role": "user", "content": f"{system_prompt}\n\n{user_prompt}"}
        ]
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
input_token_length = inputs.input_ids.shape[1]
with torch.no_grad():
            # Stop on EOS; models like Llama 3 also use "<|eot_id|>" to end a turn.
            terminators = [self.tokenizer.eos_token_id]
            if "<|eot_id|>" in self.tokenizer.additional_special_tokens:
                terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))
            out = self.model.generate(
                **inputs,
                # Sample whenever temperature is positive; temperature 0 means greedy.
                do_sample=temperature > 0,
                temperature=max(temperature, 0.01),
                max_new_tokens=200,
                eos_token_id=terminators,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, skipping the echoed prompt.
        completion = self.tokenizer.decode(out[0, input_token_length:], skip_special_tokens=True)
dbg("Cleaned Agent Completion:", completion)
return completion
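
# Minimal usage sketch, kept out of the module's import path. The checkpoint
# name below is an assumption for illustration; any causal chat model works,
# and gated checkpoints such as Gemma additionally require HF_TOKEN to be set.
if __name__ == "__main__":
    llm = LLM("google/gemma-2-2b-it", device="auto", seed=42)  # hypothetical checkpoint
    reply = llm.generate_response(
        system_prompt="You are a concise assistant.",
        user_prompt="In one sentence, what does temperature do during decoding?",
        temperature=0.0,  # 0.0 selects greedy decoding in this wrapper
    )
    print(reply)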