hf-gaia-agents-course-MN / model_llama_local.py
Mahynlo
Switch to Zephyr 7B (no gating required)
c3c998f
"""
Wrapper para ejecutar Llama 2 (u otros modelos de HF) LOCALMENTE en el Space.
Usa transformers + pipeline para inferencia en CPU/GPU.
Compatible con la clase Agent (método generate_simple).
"""
import os
import time
import torch
from functools import lru_cache
from typing import List, Dict, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
try:
from transformers import BitsAndBytesConfig
BITSANDBYTES_AVAILABLE = True
except ImportError:
BITSANDBYTES_AVAILABLE = False
print("⚠️ bitsandbytes no disponible, no se puede usar quantización 8-bit")
class LocalHFModel:
"""
Modelo de HuggingFace cargado localmente en memoria.
Ventajas:
- ⚡ Más rápido (sin latencia de red)
- 🔒 Sin rate limits
- 💾 Control total sobre parámetros
Desventajas:
- 🧠 Usa RAM del Space (~7-14GB según modelo)
- ⏳ Carga inicial lenta (30-60s)
"""
def __init__(
self,
model_id: str = "meta-llama/Llama-2-7b-chat-hf",
max_new_tokens: int = 256,
temperature: float = 0.0,
device: str = "auto",
load_in_8bit: bool = True,
):
"""
Inicializa modelo local.
Args:
model_id: ID del modelo en HuggingFace Hub
max_new_tokens: Tokens máximos a generar
temperature: 0.0 = determinístico, >0 = creativo
device: "auto", "cpu", "cuda"
load_in_8bit: True = ~7GB RAM, False = ~14GB RAM
"""
self.model_id = model_id
self.max_new_tokens = max_new_tokens
self.temperature = temperature
self.device = device
self.load_in_8bit = load_in_8bit
# HF Token (necesario para modelos gated como Llama)
self.hf_token = os.getenv("HF_TOKEN")
if not self.hf_token:
raise ValueError(
"❌ HF_TOKEN no configurado.\n"
"Necesario para descargar modelos de HuggingFace.\n"
"Configúralo en Settings → Repository secrets"
)
print(f"🦙 Cargando {model_id} localmente...")
print(f" 📍 Device: {device}")
print(f" 💾 8-bit quantization: {load_in_8bit}")
print(f" 🎯 Max tokens: {max_new_tokens}")
print(f" 🌡️ Temperature: {temperature}")
start_time = time.time()
# Configurar quantización
quantization_config = None
if load_in_8bit and BITSANDBYTES_AVAILABLE:
try:
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
)
print(" ✅ Quantización 8-bit configurada")
except Exception as e:
print(f" ⚠️ Error en 8-bit, usando float16: {e}")
load_in_8bit = False
elif load_in_8bit and not BITSANDBYTES_AVAILABLE:
print(" ⚠️ bitsandbytes no instalado, usando float16")
load_in_8bit = False
# Cargar tokenizer
print(" 📦 Cargando tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained(
model_id,
token=self.hf_token,
trust_remote_code=True
)
# Configurar pad_token si no existe
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Cargar modelo
print(" 🧠 Cargando modelo (30-60s)...")
try:
self.model = AutoModelForCausalLM.from_pretrained(
model_id,
token=self.hf_token,
torch_dtype=torch.float16 if not load_in_8bit else torch.float32,
device_map=device,
quantization_config=quantization_config,
low_cpu_mem_usage=True, # Importante para 16GB RAM
trust_remote_code=True
)
except Exception as e:
print(f" ❌ Error cargando modelo: {e}")
print(" ℹ️ Verifica que HF_TOKEN tenga acceso al modelo")
raise
# Crear pipeline
self.pipe = pipeline(
"text-generation",
model=self.model,
tokenizer=self.tokenizer,
max_new_tokens=max_new_tokens,
temperature=temperature if temperature > 0 else 0.01, # 0.0 causa problemas
do_sample=temperature > 0,
top_p=0.95 if temperature > 0 else 1.0,
repetition_penalty=1.15,
return_full_text=False, # Solo nueva generación
)
load_time = time.time() - start_time
print(f" ✅ Modelo cargado en {load_time:.1f}s")
# Info de memoria
if torch.cuda.is_available():
mem_allocated = torch.cuda.memory_allocated() / 1024**3
print(f" 📊 GPU Memory: {mem_allocated:.2f} GB")
else:
print(" 📊 Running on CPU")
def _format_llama_prompt(self, messages: List[Dict[str, str]]) -> str:
"""
Formatea mensajes al formato correcto según el modelo.
Soporta:
- Llama 2: <s>[INST] <<SYS>>...
- Zephyr: <|system|>...<|user|>...<|assistant|>
"""
system_msg = ""
user_msg = ""
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
if role == "system":
system_msg = content
elif role == "user":
user_msg = content
# Detectar formato según model_id
if "zephyr" in self.model_id.lower():
# Formato Zephyr
if system_msg:
prompt = f"<|system|>\n{system_msg}</s>\n<|user|>\n{user_msg}</s>\n<|assistant|>\n"
else:
prompt = f"<|user|>\n{user_msg}</s>\n<|assistant|>\n"
else:
# Formato Llama 2 (default)
if system_msg:
prompt = f"<s>[INST] <<SYS>>\n{system_msg}\n<</SYS>>\n\n{user_msg} [/INST]"
else:
prompt = f"<s>[INST] {user_msg} [/INST]"
return prompt
def __call__(
self,
messages: List[Dict[str, str]],
max_new_tokens: Optional[int] = None,
temperature: Optional[float] = None,
**kwargs
) -> str:
"""
Genera respuesta.
Args:
messages: [{"role": "user", "content": "..."}]
max_new_tokens: Override de tokens
temperature: Override de temperatura
Returns:
Texto generado
"""
try:
# Formatear prompt
prompt = self._format_llama_prompt(messages)
# Override parámetros
gen_kwargs = {}
if max_new_tokens:
gen_kwargs["max_new_tokens"] = max_new_tokens
if temperature is not None:
gen_kwargs["temperature"] = temperature if temperature > 0 else 0.01
gen_kwargs["do_sample"] = temperature > 0
# Generar
start_time = time.time()
result = self.pipe(prompt, **gen_kwargs)
gen_time = time.time() - start_time
# Extraer texto
if isinstance(result, list) and len(result) > 0:
generated_text = result[0].get("generated_text", "")
else:
generated_text = str(result)
print(f" ⚡ Generado en {gen_time:.2f}s ({len(generated_text)} chars)")
return generated_text.strip()
except Exception as e:
error_msg = f"ERROR: {str(e)}"
print(f" ❌ {error_msg}")
return error_msg
def generate_simple(
self,
prompt: str,
system: Optional[str] = None,
**kwargs
) -> str:
"""
Interfaz simplificada compatible con Agent.
Args:
prompt: Texto del usuario
system: Prompt de sistema (opcional)
Returns:
Respuesta generada
"""
messages = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
return self(messages, **kwargs)
@lru_cache(maxsize=1) # Solo 1 modelo en cache (usa mucha RAM)
def get_local_model(
model_id: str = "meta-llama/Llama-2-7b-chat-hf",
load_in_8bit: bool = True,
max_new_tokens: int = 256,
temperature: float = 0.0,
) -> LocalHFModel:
"""
Factory con cache para modelo local.
IMPORTANTE: maxsize=1 porque cada modelo usa ~7-14GB RAM.
"""
return LocalHFModel(
model_id=model_id,
load_in_8bit=load_in_8bit,
max_new_tokens=max_new_tokens,
temperature=temperature
)
# Alias para compatibilidad con app.py
def get_model(model_id: str = "meta-llama/Llama-2-7b-chat-hf", **kwargs) -> LocalHFModel:
"""
Factory principal para obtener modelo local.
Args:
model_id: Modelo de HuggingFace
**kwargs: Parámetros adicionales (load_in_8bit, max_new_tokens, etc.)
Returns:
LocalHFModel listo para usar
"""
# Obtener parámetros con defaults
load_in_8bit = kwargs.pop("load_in_8bit", True)
max_new_tokens = kwargs.pop("max_new_tokens", 256)
temperature = kwargs.pop("temperature", 0.0)
return get_local_model(
model_id=model_id,
load_in_8bit=load_in_8bit,
max_new_tokens=max_new_tokens,
temperature=temperature
)
if __name__ == "__main__":
# Test
print("=== Test de Modelo Local ===")
model = get_model(load_in_8bit=True)
response = model.generate_simple(
"What is 2+2?",
system="You are a helpful math assistant."
)
print(f"\n📝 Respuesta: {response}")