import os
import gc
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

DEFAULT_MODEL_SMALL = "vandijklab/C2S-Scale-Gemma-2-2B"
DEFAULT_MODEL_LARGE = "vandijklab/C2S-Scale-Gemma-2-27B"

MODEL_CACHE = {"id": None, "tokenizer": None, "model": None}

def vram_gb():
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        return props.total_memory / (1024**3)
    return 0.0

def build_prompt(gene_list, species="Homo sapiens"):
    if isinstance(gene_list, str):
        raw = [g.strip() for g in gene_list.replace("\n", ",").split(",") if g.strip()]
        genes = ", ".join(raw)
    else:
        genes = ", ".join(gene_list)
    return (
        f"The following is a list of gene names ordered by descending expression level "
        f"in a {species} cell. Your task is to give the cell type which this cell belongs "
        f"to based on its gene expression.\n"
        f"Cell sentence: {genes}.\n"
        f"The cell type corresponding to these genes is:"
    )
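
# Illustrative example of the prompt this builds (the three genes below are taken
# from the example cell sentence defined further down; any other list works the same way):
#
#   The following is a list of gene names ordered by descending expression level
#   in a Homo sapiens cell. Your task is to give the cell type which this cell belongs
#   to based on its gene expression.
#   Cell sentence: CD3D, CD3E, IL7R.
#   The cell type corresponding to these genes is: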

def unload():
    MODEL_CACHE["id"] = None
    MODEL_CACHE["tokenizer"] = None
    MODEL_CACHE["model"] = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def load_model(model_id, quantization):
    """
    Lazily load the model. An A100 80GB is recommended for the 27B variant.
    quantization: 'none' or '8bit' (requires bitsandbytes and a GPU).
    """
    if MODEL_CACHE["id"] == model_id and MODEL_CACHE["model"] is not None:
        return MODEL_CACHE["tokenizer"], MODEL_CACHE["model"]
    unload()
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    device_map = "auto" if torch.cuda.is_available() else {"": "cpu"}
    kwargs = dict(torch_dtype=dtype, device_map=device_map, low_cpu_mem_usage=True)
    if quantization == "8bit" and torch.cuda.is_available():
        try:
            import bitsandbytes as bnb  # noqa: F401
            from transformers import BitsAndBytesConfig

            # 8-bit weight loading via the bitsandbytes integration
            kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        except Exception:
            # bitsandbytes not available: fall back to unquantized weights
            pass
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    mdl = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).eval()
    MODEL_CACHE["id"] = model_id
    MODEL_CACHE["tokenizer"] = tok
    MODEL_CACHE["model"] = mdl
    return tok, mdl

def infer(model_id, species, species_custom, genes_text, prompt_manual,
          max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization):
    # effective species: the custom value wins when "Custom…" is selected
    species_eff = species_custom.strip() if (species == "Custom…" and species_custom.strip()) else species
    # simple VRAM check with guidance for the 27B model
    mem = vram_gb()
    warn = ""
    if "27B" in model_id:
        if mem < 60 and quantization != "8bit":
            warn = (
                f"⚠️ Detected ~{mem:.1f} GB of VRAM. For the 27B model an A100 80GB is "
                f"recommended, or try 8-bit (a T4 may not be enough)."
            )
    tok, mdl = load_model(model_id, quantization)
    # prompt: use the manually edited prompt if provided; otherwise build one
    if prompt_manual and str(prompt_manual).strip():
        prompt = str(prompt_manual).strip()
    else:
        prompt = build_prompt(genes_text, species=species_eff)
    inputs = tok(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to(mdl.device) for k, v in inputs.items()}
    # skip_prompt=True so the streamed text contains only the completion, not the prompt
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        # fall back to greedy decoding when temperature is 0 (sampling would error out)
        do_sample=float(temperature) > 0,
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        repetition_penalty=float(repetition_penalty),
        eos_token_id=tok.eos_token_id,
        streamer=streamer,
    )
    # streaming: run generate() in a background thread and yield partial text
    # as the streamer produces it, so the output box updates live
    import threading

    output_text = ""

    def _gen():
        with torch.no_grad():
            mdl.generate(**gen_kwargs)

    thread = threading.Thread(target=_gen)
    thread.start()
    for new_text in streamer:
        output_text += new_text
        yield (warn, output_text)
    thread.join()

with gr.Blocks(title="C2S-Scale (Gemma-2) — Single-cell Biology") as demo:
    gr.Markdown(
        """
# C2S-Scale (Gemma-2) for single-cell biology

Infers the **cell type** from a *cell sentence* (genes ordered by descending expression).

**Models**:
- `vandijklab/C2S-Scale-Gemma-2-2B` (lightweight; CPU or GPU)
- `vandijklab/C2S-Scale-Gemma-2-27B` (heavy; ideally an A100 80GB)

**Note:** The *Effective prompt* field is editable. If you leave it empty, the app builds one automatically.
"""
    )

    with gr.Row():
        model_id = gr.Dropdown(
            choices=[DEFAULT_MODEL_SMALL, DEFAULT_MODEL_LARGE],
            value=DEFAULT_MODEL_SMALL,
            label="Model"
        )
        quantization = gr.Radio(["none", "8bit"], value="none", label="Quantization (optional; GPU only)")
        species = gr.Dropdown(["Homo sapiens", "Mus musculus", "Danio rerio", "Custom…"], value="Homo sapiens", label="Species")
        species_custom = gr.Textbox(value="", label="Species (if you chose Custom…)", visible=False)

    def _toggle_species(choice):
        return gr.update(visible=(choice == "Custom…"))

    species.change(_toggle_species, species, species_custom)

    example_genes = "MALAT1, RPLP0, RPL13A, ACTB, RPS27A, PTPRC, CD3D, CD3E, CCR7, IL7R, LTB, TRAC, CD27, CD4, CCR6, CXCR5"
    genes_text = gr.Textbox(value=example_genes, lines=6, label="Cell sentence (gene list ordered by expression ↓)")

    with gr.Accordion("Generation parameters", open=False):
        max_new_tokens = gr.Slider(8, 256, value=64, step=1, label="max_new_tokens")
        temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="top_p")
        top_k = gr.Slider(1, 200, value=50, step=1, label="top_k")
        repetition_penalty = gr.Slider(0.8, 1.5, value=1.05, step=0.01, label="repetition_penalty")

    # EFFECTIVE PROMPT (editable by the user)
    prompt_box = gr.Textbox(label="Effective prompt (optional; leave empty to auto-generate)", lines=8, interactive=True)

    warn_box = gr.Markdown("")
    output_box = gr.Textbox(label="Model output (streamed)")
    run_btn = gr.Button("🚀 Infer cell type")

    run_btn.click(
        fn=infer,
        inputs=[model_id, species, species_custom, genes_text, prompt_box,
                max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization],
        outputs=[warn_box, output_box]
    )

if __name__ == "__main__":
    demo.launch()
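
# Rough local-run notes (a sketch; nothing below is pinned by this file, so treat
# the package list and file name as assumptions):
#   pip install gradio torch transformers accelerate
#   pip install bitsandbytes   # only needed for the optional 8-bit path
#   python app.py              # assuming the file is saved as app.py, as is typical for a Space
# `accelerate` is required by transformers when device_map="auto" is used.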