# demo_C2S_Scale/app.py
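"""Gradio demo for C2S-Scale (Gemma-2) single-cell models.

Infers a cell type from a "cell sentence" (gene names ordered by descending
expression) using the vandijklab C2S-Scale-Gemma-2 2B or 27B checkpoints,
with optional 8-bit quantization and streamed generation.
"""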
import os
import gc
import threading
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
DEFAULT_MODEL_SMALL = "vandijklab/C2S-Scale-Gemma-2-2B"
DEFAULT_MODEL_LARGE = "vandijklab/C2S-Scale-Gemma-2-27B"
MODEL_CACHE = {"id": None, "tokenizer": None, "model": None}
def vram_gb():
if torch.cuda.is_available():
props = torch.cuda.get_device_properties(0)
return props.total_memory / (1024**3)
return 0.0
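# Note: only GPU 0 is inspected, so on multi-GPU machines this underestimates the
# total memory that device_map="auto" can actually use.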
def build_prompt(gene_list, species="Homo sapiens"):
if isinstance(gene_list, str):
raw = [g.strip() for g in gene_list.replace("\n", ",").split(",") if g.strip()]
genes = ", ".join(raw)
else:
genes = ", ".join(gene_list)
return (
f"The following is a list of gene names ordered by descending expression level "
f"in a {species} cell. Your task is to give the cell type which this cell belongs "
f"to based on its gene expression.\n"
f"Cell sentence: {genes}.\n"
f"The cell type corresponding to these genes is:"
)
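# For illustration, build_prompt("CD3D, CD3E, IL7R") produces (line breaks at the
# explicit "\n"s above):
#   The following is a list of gene names ordered by descending expression level in a
#   Homo sapiens cell. Your task is to give the cell type which this cell belongs to
#   based on its gene expression.
#   Cell sentence: CD3D, CD3E, IL7R.
#   The cell type corresponding to these genes is: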
def unload():
MODEL_CACHE["id"] = None
MODEL_CACHE["tokenizer"] = None
MODEL_CACHE["model"] = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def load_model(model_id, quantization):
"""
    Lazily load the model. For the 27B model an A100 80GB is recommended.
    quantization: 'none' or '8bit' ('8bit' requires bitsandbytes and a GPU).
"""
if MODEL_CACHE["id"] == model_id and MODEL_CACHE["model"] is not None:
return MODEL_CACHE["tokenizer"], MODEL_CACHE["model"]
unload()
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
device_map = "auto" if torch.cuda.is_available() else {"": "cpu"}
kwargs = dict(torch_dtype=dtype, device_map=device_map, low_cpu_mem_usage=True)
    if quantization == "8bit" and torch.cuda.is_available():
        try:
            import bitsandbytes as bnb  # noqa: F401  # availability check only
            from transformers import BitsAndBytesConfig
            kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        except Exception:
            # bitsandbytes is not available: fall back to the unquantized model
            pass
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
mdl = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).eval()
MODEL_CACHE["id"] = model_id
MODEL_CACHE["tokenizer"] = tok
MODEL_CACHE["model"] = mdl
return tok, mdl
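# Minimal usage sketch outside the Gradio UI (assumes enough memory for the 2B model):
#   tok, mdl = load_model(DEFAULT_MODEL_SMALL, quantization="none")
#   batch = tok(build_prompt("CD3D, CD3E, IL7R"), return_tensors="pt").to(mdl.device)
#   print(tok.decode(mdl.generate(**batch, max_new_tokens=16)[0], skip_special_tokens=True))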
def infer(model_id, species, species_custom, genes_text, prompt_manual,
max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization):
    # effective species: use the custom value when "Custom…" is selected
species_eff = species_custom.strip() if (species == "Custom…" and species_custom.strip()) else species
    # simple VRAM check with guidance for the 27B model
mem = vram_gb()
warn = ""
if "27B" in model_id:
if mem < 60 and quantization != "8bit":
            warn = (
                f"⚠️ Detected ~{mem:.1f} GB of VRAM. For the 27B model an A100 80GB is "
                f"recommended, or try 8-bit quantization (a T4 may still not be enough)."
            )
tok, mdl = load_model(model_id, quantization)
    # prompt: use the manual prompt if provided; otherwise build one automatically
if prompt_manual and str(prompt_manual).strip():
prompt = str(prompt_manual).strip()
else:
prompt = build_prompt(genes_text, species=species_eff)
inputs = tok(prompt, return_tensors="pt")
if torch.cuda.is_available():
inputs = {k: v.to(mdl.device) for k, v in inputs.items()}
    # skip_prompt=True keeps the echoed prompt out of the streamed output
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = dict(
**inputs,
max_new_tokens=int(max_new_tokens),
do_sample=True,
temperature=float(temperature),
top_p=float(top_p),
top_k=int(top_k),
repetition_penalty=float(repetition_penalty),
eos_token_id=tok.eos_token_id,
streamer=streamer,
)
    # stream tokens from a background generation thread
    output_text = ""
    # surface the VRAM warning (if any) immediately, before the first token arrives
    yield (warn, output_text)
def _gen():
with torch.no_grad():
mdl.generate(**gen_kwargs)
thread = threading.Thread(target=_gen)
thread.start()
for new_text in streamer:
output_text += new_text
yield (warn, output_text)
thread.join()
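# infer() is a generator: Gradio streams each yielded (warning, text) pair straight
# into the two output components wired up in the UI below.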
with gr.Blocks(title="C2S-Scale (Gemma-2) — Single-cell Biology") as demo:
gr.Markdown(
"""
# C2S-Scale (Gemma-2) for single-cell biology
        Infers the **cell type** from a *cell sentence* (gene names ordered by descending expression).
        **Models**:
        - `vandijklab/C2S-Scale-Gemma-2-2B` (lightweight; CPU or GPU)
        - `vandijklab/C2S-Scale-Gemma-2-27B` (heavy; ideally an A100 80GB)
        **Note:** The *Effective prompt* field is editable. If you leave it empty, the app will generate one automatically.
"""
)
with gr.Row():
model_id = gr.Dropdown(
choices=[DEFAULT_MODEL_SMALL, DEFAULT_MODEL_LARGE],
value=DEFAULT_MODEL_SMALL,
label="Modelo"
)
        quantization = gr.Radio(["none", "8bit"], value="none", label="Quantization (optional; GPU only)")
        species = gr.Dropdown(["Homo sapiens", "Mus musculus", "Danio rerio", "Custom…"], value="Homo sapiens", label="Species")
        species_custom = gr.Textbox(value="", label="Species (if you chose Custom…)", visible=False)
def _toggle_species(choice):
return gr.update(visible=(choice == "Custom…"))
species.change(_toggle_species, species, species_custom)
example_genes = "MALAT1, RPLP0, RPL13A, ACTB, RPS27A, PTPRC, CD3D, CD3E, CCR7, IL7R, LTB, TRAC, CD27, CD4, CCR6, CXCR5"
    genes_text = gr.Textbox(value=example_genes, lines=6, label="Cell sentence (genes ordered by expression ↓)")
    with gr.Accordion("Generation parameters", open=False):
max_new_tokens = gr.Slider(8, 256, value=64, step=1, label="max_new_tokens")
temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature")
top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="top_p")
top_k = gr.Slider(1, 200, value=50, step=1, label="top_k")
repetition_penalty = gr.Slider(0.8, 1.5, value=1.05, step=0.01, label="repetition_penalty")
    # EFFECTIVE PROMPT (user-editable)
    prompt_box = gr.Textbox(label="Effective prompt (optional; leave empty to auto-generate)", lines=8, interactive=True)
warn_box = gr.Markdown("")
    output_box = gr.Textbox(label="Model output (streamed)")
    run_btn = gr.Button("🚀 Infer cell type")
run_btn.click(
fn=infer,
inputs=[model_id, species, species_custom, genes_text, prompt_box,
max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization],
outputs=[warn_box, output_box]
)
if __name__ == "__main__":
demo.launch()