"""
Wrapper for running Llama 2 (or other HF models) LOCALLY on the Space.

Uses transformers + pipeline for CPU/GPU inference.

Compatible with the Agent class (generate_simple method).
"""

import os
import time
import torch
from functools import lru_cache
from typing import List, Dict, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

try:
    from transformers import BitsAndBytesConfig
    BITSANDBYTES_AVAILABLE = True
except ImportError:
    BITSANDBYTES_AVAILABLE = False
    print("⚠️ bitsandbytes not available; 8-bit quantization cannot be used")

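
# Typical usage (a hedged sketch; the Agent class mentioned in the module
# docstring lives elsewhere in this project):
#
#   model = get_model("meta-llama/Llama-2-7b-chat-hf", load_in_8bit=True)
#   answer = model.generate_simple("What is 2+2?", system="You are a helpful math assistant.")
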
class LocalHFModel:
    """
    HuggingFace model loaded locally into memory.

    Advantages:
    - ⚡ Faster (no network latency)
    - 🔒 No rate limits
    - 💾 Full control over parameters

    Disadvantages:
    - 🧠 Uses Space RAM (~7-14 GB depending on the model)
    - ⏳ Slow initial load (30-60 s)
    """

    def __init__(
        self,
        model_id: str = "meta-llama/Llama-2-7b-chat-hf",
        max_new_tokens: int = 256,
        temperature: float = 0.0,
        device: str = "auto",
        load_in_8bit: bool = True,
    ):
        """
        Initialize the local model.

        Args:
            model_id: Model ID on the HuggingFace Hub
            max_new_tokens: Maximum number of tokens to generate
            temperature: 0.0 = deterministic, >0 = creative
            device: "auto", "cpu", "cuda"
            load_in_8bit: True = ~7 GB RAM, False = ~14 GB RAM
        """
        self.model_id = model_id
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.device = device
        self.load_in_8bit = load_in_8bit

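        # HF_TOKEN is required here because the Llama 2 checkpoints are gated on
        # the HuggingFace Hub (access must be granted on the model page first).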
        self.hf_token = os.getenv("HF_TOKEN")
        if not self.hf_token:
            raise ValueError(
                "❌ HF_TOKEN is not set.\n"
                "Required to download models from HuggingFace.\n"
                "Set it under Settings → Repository secrets"
            )

        print(f"🦙 Loading {model_id} locally...")
        print(f" 📍 Device: {device}")
        print(f" 💾 8-bit quantization: {load_in_8bit}")
        print(f" 🎯 Max tokens: {max_new_tokens}")
        print(f" 🌡️ Temperature: {temperature}")

        start_time = time.time()

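        # Rough memory math (an estimate, not measured here): a 7B-parameter model
        # takes ~2 bytes/param in float16 (~14 GB) vs ~1 byte/param in int8 (~7 GB),
        # which is where the RAM figures in the class docstring come from.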
        quantization_config = None
        if load_in_8bit and BITSANDBYTES_AVAILABLE:
            try:
                quantization_config = BitsAndBytesConfig(
                    load_in_8bit=True,
                    llm_int8_threshold=6.0,
                )
                print(" ✅ 8-bit quantization configured")
            except Exception as e:
                print(f" ⚠️ 8-bit setup failed, falling back to float16: {e}")
                load_in_8bit = False
        elif load_in_8bit and not BITSANDBYTES_AVAILABLE:
            print(" ⚠️ bitsandbytes not installed, using float16")
            load_in_8bit = False

        print(" 📦 Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            token=self.hf_token,
            trust_remote_code=True
        )

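        # Llama 2 tokenizers ship without a padding token; reusing EOS is the
        # usual workaround so the generation pipeline can pad inputs.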
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print(" 🧠 Loading model (30-60s)...")
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                token=self.hf_token,
                torch_dtype=torch.float16 if not load_in_8bit else torch.float32,
                device_map=device,
                quantization_config=quantization_config,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            )
        except Exception as e:
            print(f" ❌ Error loading model: {e}")
            print(" ℹ️ Check that HF_TOKEN has access to the model")
            raise

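        # With temperature == 0 we want greedy decoding, so do_sample is False;
        # the 0.01 placeholder only keeps the temperature argument in a valid
        # range and is ignored when sampling is disabled.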
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=max_new_tokens,
            temperature=temperature if temperature > 0 else 0.01,
            do_sample=temperature > 0,
            top_p=0.95 if temperature > 0 else 1.0,
            repetition_penalty=1.15,
            return_full_text=False,
        )

        load_time = time.time() - start_time
        print(f" ✅ Model loaded in {load_time:.1f}s")

        if torch.cuda.is_available():
            mem_allocated = torch.cuda.memory_allocated() / 1024**3
            print(f" 📊 GPU Memory: {mem_allocated:.2f} GB")
        else:
            print(" 📊 Running on CPU")

    def _format_llama_prompt(self, messages: List[Dict[str, str]]) -> str:
        """
        Formats messages into the correct prompt template for the model.

        Supported:
        - Llama 2: <s>[INST] <<SYS>>...
        - Zephyr: <|system|>...<|user|>...<|assistant|>
        """
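        # Example (Llama 2 template) for a system + user message pair:
        #   <s>[INST] <<SYS>>
        #   You are a helpful math assistant.
        #   <</SYS>>
        #
        #   What is 2+2? [/INST]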
        system_msg = ""
        user_msg = ""

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                system_msg = content
            elif role == "user":
                user_msg = content

        if "zephyr" in self.model_id.lower():
            # Zephyr chat template
            if system_msg:
                prompt = f"<|system|>\n{system_msg}</s>\n<|user|>\n{user_msg}</s>\n<|assistant|>\n"
            else:
                prompt = f"<|user|>\n{user_msg}</s>\n<|assistant|>\n"
        else:
            # Llama 2 chat template
            if system_msg:
                prompt = f"<s>[INST] <<SYS>>\n{system_msg}\n<</SYS>>\n\n{user_msg} [/INST]"
            else:
                prompt = f"<s>[INST] {user_msg} [/INST]"

        return prompt

    def __call__(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> str:
        """
        Generates a response.

        Args:
            messages: [{"role": "user", "content": "..."}]
            max_new_tokens: Per-call override for max tokens
            temperature: Per-call override for temperature

        Returns:
            Generated text
        """
        try:
            prompt = self._format_llama_prompt(messages)

            # Per-call overrides for the pipeline defaults
            gen_kwargs = {}
            if max_new_tokens:
                gen_kwargs["max_new_tokens"] = max_new_tokens
            if temperature is not None:
                gen_kwargs["temperature"] = temperature if temperature > 0 else 0.01
                gen_kwargs["do_sample"] = temperature > 0

            start_time = time.time()
            result = self.pipe(prompt, **gen_kwargs)
            gen_time = time.time() - start_time

            if isinstance(result, list) and len(result) > 0:
                generated_text = result[0].get("generated_text", "")
            else:
                generated_text = str(result)

            print(f" ⚡ Generated in {gen_time:.2f}s ({len(generated_text)} chars)")

            return generated_text.strip()

        except Exception as e:
            error_msg = f"ERROR: {str(e)}"
            print(f" ❌ {error_msg}")
            return error_msg

    def generate_simple(
        self,
        prompt: str,
        system: Optional[str] = None,
        **kwargs
    ) -> str:
        """
        Simplified interface compatible with the Agent class.

        Args:
            prompt: User text
            system: Optional system prompt

        Returns:
            Generated response
        """
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})

        return self(messages, **kwargs)

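
# Note: lru_cache keys on the argument values, so requesting a different
# model_id (or different settings) builds a new LocalHFModel and evicts the
# previous cache entry; the old model is only freed once nothing else holds
# a reference to it.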
@lru_cache(maxsize=1)
def get_local_model(
    model_id: str = "meta-llama/Llama-2-7b-chat-hf",
    load_in_8bit: bool = True,
    max_new_tokens: int = 256,
    temperature: float = 0.0,
) -> LocalHFModel:
    """
    Cached factory for the local model.

    IMPORTANT: maxsize=1 because each model uses ~7-14 GB of RAM.
    """
    return LocalHFModel(
        model_id=model_id,
        load_in_8bit=load_in_8bit,
        max_new_tokens=max_new_tokens,
        temperature=temperature
    )

def get_model(model_id: str = "meta-llama/Llama-2-7b-chat-hf", **kwargs) -> LocalHFModel:
    """
    Main factory for obtaining the local model.

    Args:
        model_id: HuggingFace model
        **kwargs: Extra parameters (load_in_8bit, max_new_tokens, etc.)

    Returns:
        A LocalHFModel ready to use
    """
    load_in_8bit = kwargs.pop("load_in_8bit", True)
    max_new_tokens = kwargs.pop("max_new_tokens", 256)
    temperature = kwargs.pop("temperature", 0.0)

    return get_local_model(
        model_id=model_id,
        load_in_8bit=load_in_8bit,
        max_new_tokens=max_new_tokens,
        temperature=temperature
    )

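
# Because get_model() delegates to the lru_cache'd factory above, repeated calls
# with the same arguments return the same already-loaded instance:
#
#   m1 = get_model()
#   m2 = get_model()
#   assert m1 is m2  # same object, no second multi-GB load
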
if __name__ == "__main__":
    print("=== Local Model Test ===")

    model = get_model(load_in_8bit=True)

    response = model.generate_simple(
        "What is 2+2?",
        system="You are a helpful math assistant."
    )

    print(f"\n📝 Response: {response}")