import os

# Must be set before any cuBLAS kernel runs to allow deterministic algorithms.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Optional

class LLM:
    """Thin wrapper around a Hugging Face causal LM for sampled text generation."""

    def __init__(self, model_id: str, device: str = "auto", dtype: Optional[str] = None):
        self.model_id = model_id
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

        # Only pass an explicit dtype when the caller asked for one;
        # otherwise let transformers pick the checkpoint default.
        kwargs = {}
        if dtype == "float16":
            kwargs["torch_dtype"] = torch.float16
        elif dtype == "bfloat16":
            kwargs["torch_dtype"] = torch.bfloat16

        self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, **kwargs)
        self.model.eval()

        # A tokenizer that ships a chat template signals an instruction-tuned model.
        self.is_instruction_tuned = bool(
            hasattr(self.tokenizer, "apply_chat_template")
            and getattr(self.tokenizer, "chat_template", None)
        )
        print(f"[BP-Φ] Loaded model: {model_id}")
        print(f"[BP-Φ] Chat-template detected: {self.is_instruction_tuned}")
    def generate_json(self, system_prompt: str, user_prompt: str,
                      max_new_tokens: int = 256, temperature: float = 0.7,
                      top_p: float = 0.9, num_return_sequences: int = 1) -> List[str]:
        if self.is_instruction_tuned:
            # Let the tokenizer render the model's own chat format.
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ]
            prompt = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            # Base models get a plain-text fallback prompt.
            prompt = f"{system_prompt}\n\nUser:\n{user_prompt}\n\nAssistant:\n"

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            out = self.model.generate(
                **inputs,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=max_new_tokens,
                num_return_sequences=num_return_sequences,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # generate() returns prompt + completion; decode only the newly
        # generated tokens so the prompt is not echoed back in the results.
        prompt_len = inputs["input_ids"].shape[1]
        texts = self.tokenizer.batch_decode(out[:, prompt_len:], skip_special_tokens=True)

        completions = []
        for t in texts:
            # Trim at any end-of-turn marker the tokenizer did not strip.
            for marker in ["<end_of_turn>", "<end_of_text>", "</s>"]:
                if marker in t:
                    t = t.split(marker)[0]
            # The plain-text fallback prompt can still echo the "Assistant:" tag.
            if "Assistant:" in t:
                t = t.split("Assistant:")[-1]
            completions.append(t.strip())
        return completions
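

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The checkpoint
    # name below is a placeholder assumption; substitute any causal LM whose
    # tokenizer ships a chat template to exercise the instruction-tuned path.
    llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", device="auto", dtype="bfloat16")
    outs = llm.generate_json(
        system_prompt='Reply with a JSON object {"answer": ...} and nothing else.',
        user_prompt="What is 2 + 2?",
        max_new_tokens=64,
        num_return_sequences=2,
    )
    for i, o in enumerate(outs):
        print(f"[sample {i}] {o}")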