# bp_phi/llm_iface.py
import os

# Must be set before the first CUDA context is created so cuBLAS uses
# deterministic workspace algorithms (needed for reproducible runs).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from typing import Optional

DEBUG = 1

def dbg(*args):
    if DEBUG:
        print("[DEBUG:llm_iface]", *args, flush=True)
class LLM:
    def __init__(self, model_id: str, device: str = "auto", dtype: Optional[str] = None, seed: int = 42):
        self.model_id = model_id
        self.seed = seed
        set_seed(seed)
        token = os.environ.get("HF_TOKEN")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, token=token)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        kwargs = {}
        if dtype is not None:
            # Honor an explicit dtype string such as "bfloat16" or "float16".
            kwargs["torch_dtype"] = getattr(torch, dtype)
        elif torch.cuda.is_available():
            kwargs["torch_dtype"] = torch.bfloat16
        self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, token=token, **kwargs)
        self.model.eval()
        dbg(f"Loaded model: {model_id}")
    def generate_response(self, system_prompt: str, user_prompt: str, temperature: float = 0.1) -> str:
        # Re-seed before every call so repeated calls give reproducible output.
        set_seed(self.seed)
        # Fold the system prompt into the user turn: a user-only template is
        # robust for models such as Gemma whose chat template has no system role.
        messages = [
            {"role": "user", "content": f"{system_prompt}\n\n{user_prompt}"}
        ]
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        input_token_length = inputs.input_ids.shape[1]
        with torch.no_grad():
            # Always stop on EOS; also stop on a Llama-3-style <|eot_id|> token
            # when the tokenizer defines one.
            terminators = [self.tokenizer.eos_token_id]
            if "<|eot_id|>" in self.tokenizer.additional_special_tokens:
                terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))
            out = self.model.generate(
                **inputs,
                # Greedy decoding only at temperature 0; any positive
                # temperature samples (the original `< 1.0` cap silently made
                # temperatures >= 1.0 greedy, which was a bug).
                do_sample=(temperature > 0),
                temperature=max(temperature, 0.01),
                max_new_tokens=200,
                eos_token_id=terminators,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        completion = self.tokenizer.decode(out[0, input_token_length:], skip_special_tokens=True)
        dbg("Agent completion:", completion)
        return completion
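
# --- Minimal usage sketch (illustrative, not part of the module's API). ---
# The model id below is an example; any instruction-tuned causal-LM id on the
# Hugging Face Hub should work, provided HF_TOKEN grants access where needed.
if __name__ == "__main__":
    llm = LLM("google/gemma-2-2b-it", device="auto", seed=42)
    reply_a = llm.generate_response(
        system_prompt="You are a terse assistant.",
        user_prompt="Name one prime number greater than 10.",
        temperature=0.1,
    )
    # Because generate_response() re-seeds before each call, a second call
    # with identical inputs should reproduce the first completion exactly
    # (assuming the same hardware and library versions).
    reply_b = llm.generate_response(
        system_prompt="You are a terse assistant.",
        user_prompt="Name one prime number greater than 10.",
        temperature=0.1,
    )
    assert reply_a == reply_b, "expected deterministic, reproducible output"
    print(reply_a)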