Umair Khan commited on
Commit
0811027
·
1 Parent(s): 237ece6

update UI and reformat parquet output

Browse files
Files changed (2) hide show
  1. app.py +38 -15
  2. requirements.txt +2 -1
app.py CHANGED
@@ -14,6 +14,8 @@ import anndata as ad
14
  import pandas as pd
15
  import numpy as np
16
  import scanpy as sc
 
 
17
  from pathlib import Path
18
  from composer import Trainer, Callback
19
  from tahoex.model.model import ComposerTX
@@ -24,10 +26,11 @@ EMB_KEY = "X_tx1-70m"
24
  APP_TITLE = "Tx1-70M Embeddings"
25
  APP_DESC = """
26
  Upload an AnnData, compute Tx1-70M embeddings,
27
- preview a UMAP, and download the results. Files are
28
- limited to 5GB / 50K cells. If a file is less than 5GB but
29
- contains more than 50K cells, embeddings will be
30
- computed only for the first 50K cells.
 
31
  """
32
 
33
  # set up directories
@@ -39,10 +42,14 @@ with open("./symbol-to-ensembl.json", "r") as f:
39
  SYMBOL_TO_ENSEMBL = json.load(f)
40
  SYMBOL_TO_ENSEMBL_UCASE = {str(k).upper(): v for k, v in SYMBOL_TO_ENSEMBL.items()}
41
 
 
 
 
 
42
  # helper to read AnnData header
43
  def read_anndata_header(fileobj):
44
  adata = sc.read_h5ad(fileobj.name, backed="r")
45
- layers = ["<use .X>"] + list(adata.layers.keys())
46
  var_cols = list(adata.var.columns)
47
  obs_cols = list(adata.obs.columns)
48
  del adata
@@ -72,11 +79,27 @@ def _unique_output(name):
72
 
73
  # helper to save outputs
74
  def _save_outputs(adata, emb):
75
- emb_df = pd.DataFrame(emb, index=adata.obs_names)
 
 
 
 
 
 
 
 
 
 
 
 
76
  parquet_path = _unique_output("embs.parquet")
77
- emb_df.to_parquet(parquet_path)
 
 
78
  out_h5ad = _unique_output("adata_with_embs.h5ad")
79
  adata.write(out_h5ad)
 
 
80
  return parquet_path, out_h5ad
81
 
82
  # refresh dropdowns given a file object
@@ -103,7 +126,7 @@ def ensure_dropdowns(fileobj):
103
 
104
  # custom callback to report progress to Gradio
105
  class GradioProgressCallback(Callback):
106
- def __init__(self, progress, total_batches, start=0.1, end=0.6):
107
  self.progress = progress
108
  self.total = max(1, int(total_batches))
109
  self.seen = 0
@@ -119,6 +142,7 @@ class GradioProgressCallback(Callback):
119
  def _embed(adata_bytes, layer_name, feature_col, use_symbols, progress):
120
 
121
  # retrieve AnnData from bytes
 
122
  with tempfile.TemporaryDirectory() as td:
123
 
124
  # persist to a temporary file
@@ -206,7 +230,7 @@ def _embed(adata_bytes, layer_name, feature_col, use_symbols, progress):
206
  raise gr.Error(f"Feature column '{feature_col}' does not appear to contain Ensembl gene IDs. If the column contains gene symbols, use the checkbox.")
207
 
208
  # load model
209
- print("loading model")
210
  model, vocab, _, collator_config = ComposerTX.from_hf(
211
  "tahoebio/TahoeX1",
212
  "70m",
@@ -214,7 +238,7 @@ def _embed(adata_bytes, layer_name, feature_col, use_symbols, progress):
214
  )
215
 
216
  # prepare AnnData
217
- print("preparing AnnData")
218
  gene_id_key = feature_col
219
  adata.var["id_in_vocab"] = [vocab[gene] if gene in vocab else -1 for gene in adata.var[gene_id_key]]
220
  gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
@@ -228,7 +252,7 @@ def _embed(adata_bytes, layer_name, feature_col, use_symbols, progress):
228
  gene_ids = np.array([vocab[gene] for gene in genes], dtype=int)
229
 
230
  # create data loader
231
- print("creating data loader")
232
  count_matrix = _pick_layer(adata, layer_name)
233
  dataset = CountDataset(
234
  count_matrix,
@@ -279,7 +303,7 @@ def _embed(adata_bytes, layer_name, feature_col, use_symbols, progress):
279
  predictions = trainer.predict(loader, return_outputs=True)
280
 
281
  # aggregate embeddings
282
- print("aggregating embeddings")
283
  cell_embs = []
284
  for out in predictions:
285
  cell_embs.append(out["cell_emb"].cpu())
@@ -330,7 +354,6 @@ def run_pipeline(fileobj, layer_choice, var_choice, obs_choice, use_symbols, pro
330
  adata_bytes = f.read()
331
 
332
  # compute embeddings on GPU
333
- progress(0.10, desc="computing Tx1 embeddings")
334
  E, layers, var_cols, obs_cols, adata_with_emb_bytes = _embed(
335
  adata_bytes=adata_bytes,
336
  layer_name=(None if layer_choice in [None, "", "<use .X>"] else layer_choice),
@@ -347,13 +370,13 @@ def run_pipeline(fileobj, layer_choice, var_choice, obs_choice, use_symbols, pro
347
  adata = sc.read_h5ad(tmp_in, backed=None)
348
 
349
  # compute UMAP
350
- progress(0.60, desc="computing UMAP")
351
  color_series = adata.obs[obs_choice] if (obs_choice and obs_choice in adata.obs) else None
352
  coords = _compute_umap_from_emb(E)
353
  adata.obsm["X_umap"] = coords
354
 
355
  # plot UMAP
356
- progress(0.80, desc="plotting UMAP")
357
  import matplotlib.pyplot as plt
358
  fig = plt.figure(figsize=(5.5, 5.0))
359
  ax = fig.add_subplot(111)
 
14
  import pandas as pd
15
  import numpy as np
16
  import scanpy as sc
17
+ import pyarrow as pa
18
+ import pyarrow.parquet as pq
19
  from pathlib import Path
20
  from composer import Trainer, Callback
21
  from tahoex.model.model import ComposerTX
 
26
  APP_TITLE = "Tx1-70M Embeddings"
27
  APP_DESC = """
28
  Upload an AnnData, compute Tx1-70M embeddings,
29
+ preview a UMAP, and download the results.
30
+
31
+ **Limits:** Files up to 5GB. If an AnnData contains more
32
+ than 50K cells, embeddings will be computed **only
33
+ for the first 50K cells**.
34
  """
35
 
36
  # set up directories
 
42
  SYMBOL_TO_ENSEMBL = json.load(f)
43
  SYMBOL_TO_ENSEMBL_UCASE = {str(k).upper(): v for k, v in SYMBOL_TO_ENSEMBL.items()}
44
 
45
+ # set up parquet outputs
46
+ PARQUET_INDEX_COL = "index"
47
+ PARQUET_EMB_COL = "tx1-70m"
48
+
49
  # helper to read AnnData header
50
  def read_anndata_header(fileobj):
51
  adata = sc.read_h5ad(fileobj.name, backed="r")
52
+ layers = list(adata.layers.keys())
53
  var_cols = list(adata.var.columns)
54
  obs_cols = list(adata.obs.columns)
55
  del adata
 
79
 
80
  # helper to save outputs
81
  def _save_outputs(adata, emb):
82
+
83
+ # save parquet
84
+ d_model = int(emb.shape[1])
85
+ index_arr = pa.array(adata.obs_names.astype(str).tolist(), type=pa.string())
86
+ emb_arr = pa.array(emb.tolist(), type=pa.list_(pa.float32(), d_model))
87
+ table = pa.Table.from_arrays(
88
+ [index_arr, emb_arr],
89
+ names=[PARQUET_INDEX_COL, PARQUET_EMB_COL],
90
+ schema=pa.schema([
91
+ pa.field(PARQUET_INDEX_COL, pa.string()),
92
+ pa.field(PARQUET_EMB_COL, pa.list_(pa.float32(), d_model)),
93
+ ]),
94
+ )
95
  parquet_path = _unique_output("embs.parquet")
96
+ pq.write_table(table, parquet_path, compression="zstd", use_dictionary=True)
97
+
98
+ # save AnnData
99
  out_h5ad = _unique_output("adata_with_embs.h5ad")
100
  adata.write(out_h5ad)
101
+
102
+ # return paths
103
  return parquet_path, out_h5ad
104
 
105
  # refresh dropdowns given a file object
 
126
 
127
  # custom callback to report progress to Gradio
128
  class GradioProgressCallback(Callback):
129
+ def __init__(self, progress, total_batches, start=0.35, end=0.75):
130
  self.progress = progress
131
  self.total = max(1, int(total_batches))
132
  self.seen = 0
 
142
  def _embed(adata_bytes, layer_name, feature_col, use_symbols, progress):
143
 
144
  # retrieve AnnData from bytes
145
+ progress(0.12, desc="loading AnnData")
146
  with tempfile.TemporaryDirectory() as td:
147
 
148
  # persist to a temporary file
 
230
  raise gr.Error(f"Feature column '{feature_col}' does not appear to contain Ensembl gene IDs. If the column contains gene symbols, use the checkbox.")
231
 
232
  # load model
233
+ progress(0.22, desc="loading model")
234
  model, vocab, _, collator_config = ComposerTX.from_hf(
235
  "tahoebio/TahoeX1",
236
  "70m",
 
238
  )
239
 
240
  # prepare AnnData
241
+ progress(0.30, desc="preparing AnnData")
242
  gene_id_key = feature_col
243
  adata.var["id_in_vocab"] = [vocab[gene] if gene in vocab else -1 for gene in adata.var[gene_id_key]]
244
  gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
 
252
  gene_ids = np.array([vocab[gene] for gene in genes], dtype=int)
253
 
254
  # create data loader
255
+ progress(0.35, desc="creating data loader")
256
  count_matrix = _pick_layer(adata, layer_name)
257
  dataset = CountDataset(
258
  count_matrix,
 
303
  predictions = trainer.predict(loader, return_outputs=True)
304
 
305
  # aggregate embeddings
306
+ progress(0.78, desc="aggregating embeddings")
307
  cell_embs = []
308
  for out in predictions:
309
  cell_embs.append(out["cell_emb"].cpu())
 
354
  adata_bytes = f.read()
355
 
356
  # compute embeddings on GPU
 
357
  E, layers, var_cols, obs_cols, adata_with_emb_bytes = _embed(
358
  adata_bytes=adata_bytes,
359
  layer_name=(None if layer_choice in [None, "", "<use .X>"] else layer_choice),
 
370
  adata = sc.read_h5ad(tmp_in, backed=None)
371
 
372
  # compute UMAP
373
+ progress(0.85, desc="computing UMAP")
374
  color_series = adata.obs[obs_choice] if (obs_choice and obs_choice in adata.obs) else None
375
  coords = _compute_umap_from_emb(E)
376
  adata.obsm["X_umap"] = coords
377
 
378
  # plot UMAP
379
+ progress(0.90, desc="plotting UMAP")
380
  import matplotlib.pyplot as plt
381
  fig = plt.figure(figsize=(5.5, 5.0))
382
  ax = fig.add_subplot(111)
requirements.txt CHANGED
@@ -17,4 +17,5 @@ scanpy
17
  pynndescent
18
  umap-learn
19
  anndata
20
- h5py
 
 
17
  pynndescent
18
  umap-learn
19
  anndata
20
+ h5py
21
+ pyarrow