import os

# Install the bundled tahoex wheel up front; --no-deps leaves the
# environment's pinned dependencies untouched.
os.system("pip install --no-deps ./tahoex-0.1.2-py3-none-any.whl")

import gc
import json
import tempfile
import time
import uuid
from pathlib import Path

import anndata as ad
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import scanpy as sc
import torch
from scipy import sparse

# NOTE: tahoex must be imported after the wheel install above.
from composer import Trainer, Callback
from tahoex.data import CountDataset, DataCollator
from tahoex.model.model import ComposerTX

EMB_KEY = "X_tx1-70m"
APP_TITLE = "Tx1-70M Embeddings"
APP_DESC = """
Upload an AnnData, compute Tx1-70M embeddings,
preview a UMAP, and download the results. **Limits:**
files up to 5GB. If an AnnData contains more
than 100K cells, embeddings will be computed **only
for the first 100K cells**.
"""

OUTPUT_DIR = Path("./outputs")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

with open("./symbol-to-ensembl.json", "r") as f:
    SYMBOL_TO_ENSEMBL = json.load(f)
SYMBOL_TO_ENSEMBL_UCASE = {str(k).upper(): v for k, v in SYMBOL_TO_ENSEMBL.items()}
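
# The mapping file is assumed to look like {"TP53": "ENSG00000141510", ...};
# a value may also be a list when one symbol maps to several Ensembl IDs
# (handled in _embed below). The upper-cased copy provides a case-insensitive
# fallback for symbols that miss on exact match.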

PARQUET_INDEX_COL = "index"
PARQUET_EMB_COL = "tx1-70m"

OBS_NONE_OPTION = "(none)"
MAX_CATEGORIES = 50
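
# recolor_umap falls back to an uncolored plot above MAX_CATEGORIES distinct
# values, where a per-category legend would no longer be readable.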

VAR_PREVIEW_MAX = 5


def _pick_layer(adata, layer_name):
    """Return the chosen layer (or .X) as float32; sparse matrices become CSR."""
    X = adata.layers[layer_name] if layer_name else adata.X
    if sparse.issparse(X):
        # CSR keeps per-cell (row) slicing cheap for the dataset below.
        X = X.tocsr()
    elif not isinstance(X, np.ndarray):
        X = np.asarray(X)
    if hasattr(X, "dtype") and X.dtype != np.float32:
        X = X.astype(np.float32, copy=False)
    return X


def _summarize_columns(df, preview_max=VAR_PREVIEW_MAX):
    """Build (label, value) dropdown choices like 'col · dtype · v1, v2, ...'."""
    choices = []
    for col in df.columns:
        s = df[col]
        dtype = str(s.dtype)
        ex = pd.Series(s.astype(object)).dropna().astype(str).head(preview_max).tolist()
        preview = ", ".join(ex) if ex else "(no values)"
        if len(preview) > 47:
            preview = preview[:47] + "..."
        elif preview != "(no values)":
            preview = preview + "..."
        lbl = f"{col} · {dtype} · {preview}"
        choices.append((lbl, col))
    return choices


def _compute_umap_from_emb(emb):
    ad_umap = ad.AnnData(X=emb)
    sc.pp.neighbors(ad_umap)
    sc.tl.umap(ad_umap)
    coords = np.asarray(ad_umap.obsm["X_umap"])
    del ad_umap
    return coords
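
# Note: with scanpy's defaults, sc.pp.neighbors builds the kNN graph on a PCA
# of X whenever X has more than 50 columns, i.e. on a PCA-reduced view of the
# Tx1 embedding; pass use_rep="X" to neighbors to use the raw vectors instead.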


def _unique_output(name):
    # Timestamp + short uuid so concurrent sessions never overwrite each other.
    stem, ext = name.rsplit(".", 1)
    return OUTPUT_DIR / f"{stem}_{int(time.time())}_{uuid.uuid4().hex[:6]}.{ext}"


def _save_outputs(adata, emb, chunk=20000):
    # Stream the embeddings to Parquet in chunks so we never materialize a
    # second full copy: one string index column plus a fixed-size float32
    # list column of width d_model.
    d_model = int(emb.shape[1])
    schema = pa.schema([
        pa.field(PARQUET_INDEX_COL, pa.string()),
        pa.field(PARQUET_EMB_COL, pa.list_(pa.float32(), d_model)),
    ])

    parquet_path = _unique_output("embs.parquet")
    writer = None
    try:
        for i in range(0, emb.shape[0], chunk):
            sl = slice(i, min(i + chunk, emb.shape[0]))
            idx_arr = pa.array(adata.obs_names[sl].astype(str).tolist(), type=pa.string())
            flat = pa.array(emb[sl].reshape(-1), type=pa.float32())
            emb_arr = pa.FixedSizeListArray.from_arrays(flat, d_model)
            batch = pa.record_batch([idx_arr, emb_arr], schema=schema)
            if writer is None:
                writer = pq.ParquetWriter(parquet_path, schema, compression="zstd", use_dictionary=True)
            writer.write_table(pa.Table.from_batches([batch]))
    finally:
        if writer is not None:
            writer.close()

    out_h5ad = _unique_output("adata_with_embs.h5ad")
    adata.write(out_h5ad)

    return parquet_path, out_h5ad
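
# Read-back sketch for the Parquet output (path is illustrative; the embedding
# column is a fixed-size list of float32):
#   tbl = pq.read_table("outputs/embs_<timestamp>_<id>.parquet")
#   ids = tbl.column("index").to_pylist()
#   emb = np.asarray(tbl.column("tx1-70m").to_pylist(), dtype=np.float32)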


def ensure_dropdowns(fileobj):
    """Refresh the layer/var dropdowns from an uploaded .h5ad (read backed, then released)."""
    if fileobj is None:
        return (
            gr.Dropdown(choices=["<use .X>"], value="<use .X>"),
            gr.Dropdown(choices=[], value=None),
        )
    try:
        # gr.File(type="filepath") passes a plain path string; fall back to
        # .name for file-like objects from older Gradio versions.
        path = fileobj if isinstance(fileobj, str) else fileobj.name
        adata = sc.read_h5ad(path, backed="r")
        adata.var = adata.var.reset_index(drop=False, names="index")
        layers = list(adata.layers.keys())
        var_choices = _summarize_columns(adata.var)
        del adata
        gc.collect()
        default_var = var_choices[0][1] if var_choices else None
        return (
            gr.Dropdown(choices=["<use .X>"] + layers, value="<use .X>"),
            gr.Dropdown(choices=var_choices, value=default_var),
        )
    except Exception:
        return (
            gr.Dropdown(choices=["<use .X>"], value="<use .X>"),
            gr.Dropdown(choices=[], value=None),
        )


def draw_uncolored(coords, title_suffix=None):
    fig = plt.figure(figsize=(5.5, 5.0))
    ax = fig.add_subplot(111)
    ax.scatter(coords[:, 0], coords[:, 1], s=3, alpha=0.75)
    ttl = "Tx1-70M embeddings"
    if title_suffix:
        ttl += f" ({title_suffix})"
    ax.set_title(ttl)
    ax.set_xlabel("UMAP1")
    ax.set_ylabel("UMAP2")
    fig.tight_layout()
    out_png = _unique_output("umap.png")
    fig.savefig(out_png, dpi=160)
    plt.close(fig)
    return out_png
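
# draw_uncolored returns a Path; recolor_umap converts it with
# str(path.resolve()) before handing it to the gr.Image output.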


def recolor_umap(obs_col, coords, h5ad_path):
    if coords is None or h5ad_path is None:
        raise gr.Error("Run embeddings first to compute UMAP.")
    coords = np.asarray(coords)
    if coords.ndim != 2 or coords.shape[1] != 2:
        raise gr.Error(f"UMAP coordinates look wrong, shape = {coords.shape}. Please recompute.")

    if obs_col == OBS_NONE_OPTION:
        out_png = draw_uncolored(coords)
        return str(out_png.resolve())

    try:
        adata = sc.read_h5ad(h5ad_path, backed="r")
        series = adata.obs[obs_col]
        n = series.shape[0]
        if n != coords.shape[0]:
            gr.Warning(f"Length mismatch: obs has {n} rows, UMAP has {coords.shape[0]}. Using minimum length.")
            m = min(n, coords.shape[0])
            series = series.iloc[:m]
            coords = coords[:m]
    except Exception as e:
        raise gr.Error(f"Failed to read .obs column '{obs_col}': {e}")

    # Heuristic: treat the column as numeric if at least half of it (and at
    # least 5 values) parses as numbers; otherwise color categorically.
    s = series.copy()
    numeric_candidate = pd.to_numeric(s, errors="coerce")
    n_numeric_valid = int(np.isfinite(numeric_candidate.astype(float)).sum())
    n_total = int(len(s))

    if n_numeric_valid >= max(5, 0.5 * n_total):
        # Continuous coloring with a colorbar; non-finite cells are drawn faint.
        vals = pd.to_numeric(s, errors="coerce").astype(float).values
        mask = np.isfinite(vals)
        if mask.sum() < max(10, 0.1 * len(vals)):
            gr.Warning(f"Too few finite numeric values in '{obs_col}'. Showing uncolored UMAP.")
            out_png = draw_uncolored(coords, f"{obs_col}: insufficient numeric values")
            return str(out_png.resolve())
        if np.nanmax(vals[mask]) == np.nanmin(vals[mask]):
            gr.Info(f"'{obs_col}' is constant. Showing uncolored UMAP.")
            out_png = draw_uncolored(coords, f"{obs_col}: constant")
            return str(out_png.resolve())

        fig = plt.figure(figsize=(5.5, 5.0))
        ax = fig.add_subplot(111)
        scatt = ax.scatter(coords[mask, 0], coords[mask, 1], s=3, alpha=0.85, c=vals[mask])
        fig.colorbar(scatt, ax=ax, shrink=0.7, label=obs_col)
        if (~mask).any():
            ax.scatter(coords[~mask, 0], coords[~mask, 1], s=3, alpha=0.25)
        ax.set_title(f"Tx1-70M embeddings colored by {obs_col}")
        ax.set_xlabel("UMAP1")
        ax.set_ylabel("UMAP2")
        fig.tight_layout()
        out_png = _unique_output("umap.png")
        fig.savefig(out_png, dpi=160)
        plt.close(fig)
        return str(out_png.resolve())

    else:
        # Categorical coloring: one scatter per category plus a legend.
        # fillna must run before astype(str), otherwise missing values become
        # the literal string "nan" instead of "NA".
        cats = s.astype(object).fillna("NA").astype(str).values
        uniq = pd.unique(cats)
        n_cat = len(uniq)
        if n_cat > MAX_CATEGORIES:
            gr.Warning(f"'{obs_col}' has too many categories. Showing uncolored UMAP.")
            out_png = draw_uncolored(coords, f"{obs_col}: {n_cat} categories")
            return str(out_png.resolve())
        if n_cat <= 1:
            gr.Info(f"'{obs_col}' has a single category. Showing uncolored UMAP.")
            out_png = draw_uncolored(coords, f"{obs_col}: 1 category")
            return str(out_png.resolve())

        fig = plt.figure(figsize=(5.5, 5.0))
        ax = fig.add_subplot(111)
        for cat in sorted(map(str, uniq)):
            mask = (cats == cat)
            ax.scatter(coords[mask, 0], coords[mask, 1], s=3, alpha=0.85, label=cat)
        ax.legend(markerscale=3, fontsize=8, loc="best", frameon=True, ncol=1)
        ax.set_title(f"Tx1-70M embeddings colored by {obs_col}")
        ax.set_xlabel("UMAP1")
        ax.set_ylabel("UMAP2")
        fig.tight_layout()
        out_png = _unique_output("umap.png")
        fig.savefig(out_png, dpi=160)
        plt.close(fig)
        return str(out_png.resolve())


class GradioProgressCallback(Callback):
    """Composer callback that maps predict-batch progress into the
    [start, end] slice of the Gradio progress bar."""

    def __init__(self, progress, total_batches, start=0.25, end=0.75):
        self.progress = progress
        self.total = max(1, int(total_batches))
        self.seen = 0
        self.start = start
        self.end = end

    def predict_batch_end(self, state, logger):
        self.seen += 1
        frac = self.start + (self.end - self.start) * (self.seen / self.total)
        self.progress(frac, desc=f"computing Tx1 embeddings ({self.seen} / {self.total} batches)")
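
# Composer fires predict_batch_end once per batch during Trainer.predict, so
# the progress bar advances linearly from `start` to `end` over the run.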


def _embed(adata_bytes, layer_name, feature_col, use_symbols, progress):
    progress(0.05, desc="loading AnnData")
    with tempfile.TemporaryDirectory() as td:
        # Materialize the uploaded bytes so scanpy can open the file backed.
        fpath = Path(td) / "input.h5ad"
        with open(fpath, "wb") as f:
            f.write(adata_bytes)

        adata_backed = sc.read_h5ad(str(fpath), backed="r")
        limit = min(adata_backed.n_obs, 100000)
        if adata_backed.n_obs > 100000:
            gr.Warning("AnnData has >100K cells. Loading only the first 100K cells.")

        # Load only the first `limit` cells into memory, then release the file.
        adata = adata_backed[:limit, :].to_memory()
        adata.var = adata.var.reset_index(drop=False, names="index")
        try:
            adata_backed.file.close()
        except Exception:
            pass

        del adata_backed
        gc.collect()

    if layer_name and layer_name not in adata.layers:
        raise gr.Error(f"Layer '{layer_name}' not found. Available: {list(adata.layers.keys())}")

    if feature_col not in adata.var.columns:
        raise gr.Error(f"Feature column '{feature_col}' not found. Available: {list(adata.var.columns)}")

    if use_symbols:
        # Map gene symbols to Ensembl IDs: exact match first, then a
        # case-insensitive fallback for the remainder.
        col = adata.var[feature_col].astype(str).str.strip()
        direct = col.map(SYMBOL_TO_ENSEMBL)

        need_fallback = direct.isna()
        if need_fallback.any():
            upper_mapped = col[need_fallback].str.upper().map(SYMBOL_TO_ENSEMBL_UCASE)
            direct.loc[need_fallback] = upper_mapped

        # Some symbols map to several Ensembl IDs; keep the first mapping.
        # (NaN is never a list, so no separate missing-value check is needed.)
        ambiguous_mask = direct.apply(lambda x: isinstance(x, (list, tuple)) and len(x) > 1)
        n_ambiguous = int(ambiguous_mask.sum())
        if n_ambiguous > 0:
            gr.Warning(f"{n_ambiguous} symbol(s) mapped to multiple Ensembl IDs; selecting first mappings.")
            direct.loc[ambiguous_mask] = direct.loc[ambiguous_mask].apply(lambda x: x[0])

        # Unwrap any remaining single-element lists.
        direct = direct.apply(lambda x: x[0] if isinstance(x, (list, tuple)) and len(x) >= 1 else x)

        n_mapped = int(direct.notna().sum())
        n_total = int(len(direct))
        gr.Info(f"Gene symbol conversion: mapped {n_mapped} / {n_total} ({n_mapped / max(1, n_total):.1%}).")
        if n_mapped == 0:
            raise gr.Error("Could not map any gene symbols to Ensembl IDs. Please check the column or turn off the symbol checkbox.")

        # Keep NaN for unmapped genes here; casting to str too early would turn
        # them into the literal string "nan" and defeat the notna() filter below.
        adata.var["ensembl_from_symbol"] = direct

        before = adata.n_vars
        adata = adata[:, adata.var["ensembl_from_symbol"].notna()].copy()
        after = adata.n_vars
        if after < before:
            gr.Warning(f"Dropped {before - after} genes that did not map.")
        adata.var["ensembl_from_symbol"] = adata.var["ensembl_from_symbol"].astype(str)

        # Deduplicate genes that collapsed onto the same Ensembl ID.
        dup_mask = adata.var.duplicated(subset=["ensembl_from_symbol"], keep="first")
        n_dups = int(dup_mask.sum())
        if n_dups > 0:
            gr.Warning(f"Found {n_dups} duplicate Ensembl IDs after mapping; keeping the first occurrence.")
            adata = adata[:, ~dup_mask].copy()

        feature_col = "ensembl_from_symbol"

    if not adata.var[feature_col].astype(str).str.startswith("ENSG").any():
        raise gr.Error(f"Feature column '{feature_col}' does not appear to contain human Ensembl gene IDs. If the column contains gene symbols, use the checkbox.")

    progress(0.15, desc="loading model")
    model, vocab, model_config, collator_config = ComposerTX.from_hf(
        "tahoebio/TahoeX1",
        "70m",
        return_gene_embeddings=False,
    )

    progress(0.20, desc="preparing AnnData")
    gene_id_key = feature_col
    adata.var["id_in_vocab"] = [vocab[gene] if gene in vocab else -1 for gene in adata.var[gene_id_key]]
    gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
    num_matches = np.sum(gene_ids_in_vocab >= 0)
    frac_matches = num_matches / len(gene_ids_in_vocab)
    gr.Info(f"Matched {num_matches} / {len(gene_ids_in_vocab)} genes to vocabulary.")
    if frac_matches < 0.5:
        gr.Warning(f"Only {frac_matches:.1%} of genes matched to vocabulary. Embeddings may be poor.")
    adata = adata[:, adata.var["id_in_vocab"] >= 0]
    genes = adata.var[gene_id_key].tolist()
    gene_ids = np.array([vocab[gene] for gene in genes], dtype=int)

    progress(0.22, desc="creating data loader")
    count_matrix = _pick_layer(adata, layer_name)
    dataset = CountDataset(
        count_matrix,
        gene_ids,
        cls_token_id=vocab["<cls>"],
        pad_value=collator_config["pad_value"],
    )
    collate_fn = DataCollator(
        vocab=vocab,
        drug_to_id_path=collator_config.get("drug_to_id_path", None),
        do_padding=collator_config.get("do_padding", True),
        unexp_padding=False,
        pad_token_id=collator_config.pad_token_id,
        pad_value=collator_config.pad_value,
        do_mlm=False,
        do_binning=collator_config.get("do_binning", True),
        log_transform=collator_config.get("log_transform", False),
        target_sum=collator_config.get("target_sum"),
        mlm_probability=collator_config.mlm_probability,
        mask_value=collator_config.mask_value,
        max_length=2048,
        sampling=collator_config.sampling,
        num_bins=collator_config.get("num_bins", 51),
        right_binning=collator_config.get("right_binning", False),
        keep_first_n_tokens=collator_config.get("keep_first_n_tokens", 1),
        use_chem_token=collator_config.get("use_chem_token", False),
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=128,
        collate_fn=collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=0,
        pin_memory=False,
    )

    # Run inference through Composer's Trainer so microbatching is handled
    # automatically; the callback forwards batch progress to the Gradio bar.
    cb = GradioProgressCallback(progress, total_batches=len(loader))
    trainer = Trainer(
        model=model,
        device="gpu",
        device_train_microbatch_size="auto",
        callbacks=[cb],
    )

    predictions = trainer.predict(loader, return_outputs=True)

    progress(0.78, desc="aggregating embeddings")
    n_cells = len(dataset)
    d_model = model_config.d_model
    cell_array = np.empty((n_cells, d_model), dtype=np.float32)
    write_ptr = 0
    for out in predictions:
        batch_emb = out["cell_emb"].detach().to("cpu").float().numpy()
        bsz = batch_emb.shape[0]
        cell_array[write_ptr:write_ptr + bsz] = batch_emb
        write_ptr += bsz

    # L2-normalize each cell embedding (clipping avoids division by zero),
    # so downstream dot products behave like cosine similarities.
    norms = np.linalg.norm(cell_array, axis=1, keepdims=True)
    np.divide(cell_array, np.clip(norms, 1e-8, None), out=cell_array)

    adata.obsm[EMB_KEY] = cell_array

    layers = list(adata.layers.keys())
    var_choices = _summarize_columns(adata.var)
    obs_choices = _summarize_columns(adata.obs)

    # Serialize the annotated AnnData back to bytes so the caller owns a
    # self-contained copy after the temp files are cleaned up.
    with tempfile.TemporaryDirectory() as td2:
        outp = Path(td2) / "tmp.h5ad"
        adata.write(outp, compression="gzip")
        with open(outp, "rb") as f:
            adata_persisted = f.read()

    del adata
    torch.cuda.empty_cache()
    gc.collect()

    return cell_array, layers, var_choices, obs_choices, adata_persisted
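
# _embed returns the (n_cells, d_model) float32 embedding matrix, the layer
# names, (label, value) choices for the .var and .obs dropdowns, and the bytes
# of a gzip-compressed .h5ad with the embedding stored in obsm["X_tx1-70m"].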


def run_pipeline(fileobj, layer_choice, var_choice, use_symbols, progress=gr.Progress(track_tqdm=False)):
    if fileobj is None:
        raise gr.Error("Please upload an .h5ad file.")
    if var_choice is None:
        raise gr.Error("Please select a .var column.")

    progress(0.02, desc="reading AnnData")
    in_path = fileobj if isinstance(fileobj, str) else fileobj.name
    with open(in_path, "rb") as f:
        adata_bytes = f.read()

    E, layers, var_choices, obs_choices, adata_with_emb_bytes = _embed(
        adata_bytes=adata_bytes,
        layer_name=(None if layer_choice in [None, "", "<use .X>"] else layer_choice),
        feature_col=(None if var_choice in [None, ""] else var_choice),
        use_symbols=use_symbols,
        progress=progress,
    )

    # Reload the annotated AnnData fully into memory for saving outputs.
    with tempfile.TemporaryDirectory() as td:
        tmp_in = Path(td) / "with_emb.h5ad"
        with open(tmp_in, "wb") as f:
            f.write(adata_with_emb_bytes)
        adata = sc.read_h5ad(tmp_in, backed=None)

    progress(0.80, desc="computing UMAP")
    coords = _compute_umap_from_emb(E)
    adata.obsm["X_umap"] = coords

    progress(0.90, desc="plotting UMAP")
    fig = plt.figure(figsize=(5.5, 5.0))
    ax = fig.add_subplot(111)
    ax.scatter(coords[:, 0], coords[:, 1], s=3, alpha=0.75)
    ax.set_title("Tx1-70M embeddings")
    ax.set_xlabel("UMAP1")
    ax.set_ylabel("UMAP2")
    fig.tight_layout()
    umap_png = _unique_output("umap.png")
    fig.savefig(umap_png, dpi=160)
    plt.close(fig)

    # Re-enable the obs coloring dropdown now that a UMAP exists.
    update_obs_dd = gr.Dropdown(choices=[OBS_NONE_OPTION] + obs_choices, value=OBS_NONE_OPTION, interactive=True)

    progress(0.95, desc="saving outputs")
    parquet_path, h5ad_path = _save_outputs(adata, E)
    progress(1.00, desc="finished!")
    return (
        str(umap_png.resolve()),
        str(parquet_path.resolve()),
        str(h5ad_path.resolve()),
        gr.Dropdown(choices=["<use .X>"] + layers, value=(layer_choice or "<use .X>")),
        gr.Dropdown(choices=var_choices, value=var_choice),
        update_obs_dd,
        coords,
        str(h5ad_path.resolve()),
    )
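
# The tuple above must stay aligned with the `outputs=` list wired to
# run_btn.click below: image, two download buttons, three dropdowns, then the
# two session states.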


css = """
div#tahoe-logo {
    margin-top: 10px;
    margin-bottom: 10px;
}
#logo-light {display: none;}
@media (prefers-color-scheme: dark) {
    #logo-dark {display: none;}
    #logo-light {display: block;}
}
"""

with gr.Blocks(title=APP_TITLE, css=css) as demo:
    # Session state: UMAP coordinates and the path of the saved .h5ad, shared
    # between the run and recolor handlers.
    coords_state = gr.State()
    h5ad_state = gr.State()

    with gr.Row(elem_id="tahoe-logo", equal_height=True):
        logo_light = gr.Image(
            value="tahoe-white-logo.png",
            height=50,
            show_label=False,
            container=False,
            interactive=False,
            elem_id="logo-light",
            show_share_button=False,
            show_fullscreen_button=False,
            show_download_button=False,
            scale=0,
        )
        logo_dark = gr.Image(
            value="tahoe-navy-logo.png",
            height=50,
            show_label=False,
            container=False,
            interactive=False,
            elem_id="logo-dark",
            show_share_button=False,
            show_fullscreen_button=False,
            show_download_button=False,
            scale=0,
        )
    gr.Markdown(f"# {APP_TITLE}\n{APP_DESC}")

    f_in = gr.File(label="Upload .h5ad", file_types=[".h5ad"], type="filepath")

    with gr.Row(equal_height=True):
        layer_dd = gr.Dropdown(choices=["<use .X>"], value="<use .X>", label="Layer to use (default: .X)", scale=1)
        with gr.Column(scale=1):
            var_dd = gr.Dropdown(choices=[], value=None, label="Name of .var column with Ensembl gene IDs (or gene symbols)")
            use_symbols_chk = gr.Checkbox(label="Selected .var column contains gene symbols (attempt conversion to Ensembl IDs)", value=False)

    run_btn = gr.Button("Compute Embeddings + UMAP", variant="primary")

    with gr.Row():
        umap_img = gr.Image(label="UMAP preview", interactive=False)

    with gr.Row():
        obs_dd = gr.Dropdown(choices=[], value=None, label="Name of .obs column to color UMAP by", interactive=False)

    with gr.Row():
        emb_parquet = gr.DownloadButton(label="Download embeddings (.parquet)")
        adata_with_emb = gr.DownloadButton(label="Download AnnData with embeddings in .obsm (.h5ad)")

    # Repopulate the layer/var dropdowns whenever a new file is uploaded.
    f_in.change(
        ensure_dropdowns,
        inputs=[f_in],
        outputs=[layer_dd, var_dd],
        queue=False,
    )

    evt = run_btn.click(
        run_pipeline,
        inputs=[f_in, layer_dd, var_dd, use_symbols_chk],
        outputs=[umap_img, emb_parquet, adata_with_emb, layer_dd, var_dd, obs_dd, coords_state, h5ad_state],
        queue=True,
    )

    # After a run, refresh the dropdowns from the uploaded file again.
    evt.then(
        ensure_dropdowns,
        inputs=[f_in],
        outputs=[layer_dd, var_dd],
        queue=False,
    )

    # Recoloring reuses the cached UMAP coordinates; no recomputation needed.
    obs_dd.change(recolor_umap, inputs=[obs_dd, coords_state, h5ad_state], outputs=[umap_img], queue=False)
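
# allowed_paths below lets Gradio serve files saved under OUTPUT_DIR (the UMAP
# PNG and the download artifacts); max_file_size enforces the 5GB upload cap
# promised in APP_DESC.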

if __name__ == "__main__":
    demo.launch(allowed_paths=[str(OUTPUT_DIR.resolve())], max_file_size="5gb")