Spaces:
Sleeping
Sleeping
| # bp_phi/llm_iface.py | |
| import os | |
| os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" | |
| import torch | |
| import random | |
| import numpy as np | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed | |
| from typing import List, Optional | |
# Module-wide switch for diagnostic output; any truthy value enables it.
DEBUG = 1


def dbg(*args):
    """Print a tagged debug line to stdout when ``DEBUG`` is truthy.

    All positional arguments are forwarded to ``print`` after the
    ``[DEBUG:llm_iface]`` tag, and the stream is flushed on every call.
    Nothing is printed when ``DEBUG`` is falsy.
    """
    if not DEBUG:
        return
    print("[DEBUG:llm_iface]", *args, flush=True)
class LLM:
    """Seeded loader for a Hugging Face causal language model and its tokenizer.

    After construction the instance exposes ``model_id``, ``seed``,
    ``tokenizer`` and ``model`` (already switched to eval mode).
    """

    # dtype strings accepted by ``__init__``; each is also the attribute name
    # of the matching dtype object on the ``torch`` module.
    _DTYPE_NAMES = ("bfloat16", "float16", "float32")

    def __init__(self, model_id: str, device: str = "auto", dtype: Optional[str] = None, seed: int = 42):
        """Seed the RNGs, then load the tokenizer and model for ``model_id``.

        Args:
            model_id: Identifier handed to ``from_pretrained`` for both the
                tokenizer and the model.
            device: Forwarded to the model loader as ``device_map``.
            dtype: One of ``"bfloat16"``, ``"float16"``, ``"float32"`` or
                ``None``. With ``None``, ``"bfloat16"`` is chosen when CUDA is
                available; otherwise no dtype is passed to the loader.
                Unrecognised values print a warning and are ignored.
            seed: Seed applied to ``random``, NumPy, torch and, if available,
                every CUDA device.
        """
        self.model_id = model_id
        self.seed = seed

        # Seed every RNG source this module touches so runs are repeatable.
        set_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
            if dtype is None:
                dtype = "bfloat16"  # Smart default for memory efficiency on CUDA
                dbg(f"CUDA detected. Defaulting to dtype={dtype} for memory efficiency.")

        # Best effort only: a failure here must not prevent the model from
        # loading, so it is reported through the debug channel and ignored.
        try:
            torch.use_deterministic_algorithms(True, warn_only=True)
        except Exception as e:
            dbg(f"Could not set deterministic algorithms: {e}")

        token = os.environ.get("HF_TOKEN")
        # Compare case-insensitively so ids spelled "Gemma"/"Llama" also
        # trigger the missing-token hint.
        lowered_id = model_id.lower()
        if not token and ("gemma" in lowered_id or "llama" in lowered_id):
            print(f"[WARN] No HF_TOKEN set. If the model '{model_id}' is gated, this will fail.")

        kwargs = {}
        torch_dtype = self._resolve_torch_dtype(dtype)
        if torch_dtype is not None:
            kwargs["torch_dtype"] = torch_dtype

        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, token=token)
        self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, token=token, **kwargs)
        self.model.eval()
        print(f"[INFO] Model '{model_id}' loaded successfully on device: {self.model.device}")

    @staticmethod
    def _resolve_torch_dtype(dtype):
        """Translate a dtype request into a ``torch.dtype``, or ``None``.

        Returns ``None`` when ``dtype`` is ``None`` (no preference). A
        ``torch.dtype`` instance is returned unchanged. A string listed in
        ``_DTYPE_NAMES`` is mapped to the torch dtype of the same name. Any
        other value prints a warning and yields ``None`` so the loader falls
        back to its own default instead of failing.
        """
        if dtype is None:
            return None
        if isinstance(dtype, torch.dtype):
            return dtype
        if dtype in LLM._DTYPE_NAMES:
            return getattr(torch, dtype)
        expected = ", ".join(LLM._DTYPE_NAMES)
        print(f"[WARN] Unknown dtype '{dtype}' (expected one of: {expected}). Using the loader's default dtype.")
        return None

    def generate_json(self, system_prompt: str, user_prompt: str, **kwargs) -> List[str]:
        """Placeholder that ignores its arguments and returns ``[""]``.

        This function remains for potential future use but is not used by the
        cogitation test. It is kept here for completeness; a real
        implementation can be added back if needed.
        """
        return [""]