Spaces:
Running
Running
| from functools import partial | |
| import os | |
| from carabiner.mpl import add_legend, grid, colorblind_palette | |
| import gradio as gr | |
| import numpy as np | |
| import matplotlib as mpl | |
| import matplotlib.pyplot as plt | |
| from scipy.integrate import solve_ivp | |
| from scipy.stats import multivariate_hypergeom, nbinom | |
# Default colour cycle: two greys for the first two series, followed by the
# colour-blind-safe palette with its first entry dropped.
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(
    color=("lightgrey", "dimgrey", *colorblind_palette()[1:]),
)

# Base seed passed as random_state to every random draw in this file.
SEED: int = 42
# Upper end of the simulated time axis (arbitrary units).
MAX_TIME: float = 5.0
# Directory holding the markdown snippets rendered into the page.
SOURCES_DIR: str = "sources"
def inject_markdown(filename):
    """Render a markdown file from SOURCES_DIR as a Gradio Markdown component.

    filename : name of the markdown file, relative to SOURCES_DIR

    Returns the gr.Markdown component, configured so that ``$$...$$`` is
    rendered as display math and ``$...$`` as inline math.
    """
    source_path = os.path.join(SOURCES_DIR, filename)
    # Decode explicitly as UTF-8 so the text (which may contain non-ASCII
    # characters and LaTeX) does not depend on the platform's locale encoding.
    with open(source_path, "r", encoding="utf-8") as f:
        md = f.read()
    return gr.Markdown(
        md,
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
        ],
    )
def lotka_volterra(t, y, w, K):
    """Growth rate of each strain under logistic competition for a shared capacity.

    dy_i/dt = w_i * y_i * (1 - sum_j(y_j) / K)

    t : time (unused; present because ODE solvers call fun(t, y, ...))
    y : population sizes, shape (n,) or (n, k) with one state per column
    w : growth rate per strain, length n
    K : carrying capacity shared by all strains

    Returns dy/dt with the same shape as y.
    """
    y = np.asarray(y, dtype=float)
    # Align w with the strain axis so it broadcasts across state columns.
    w = np.asarray(w, dtype=float).reshape((-1,) + (1,) * (y.ndim - 1))
    # Sum over strains only (axis 0): each column is an independent state, so
    # fun(t, y)[:, i] == fun(t, y[:, i]) as a vectorized ODE function requires.
    occupied_fraction = np.sum(y, axis=0, keepdims=True) / K
    return w * y * (1. - occupied_fraction)
def grow(t, n0, w, K=None):
    """Deterministic population size of every strain at the times in t.

    n0 : initial cells per strain (the population at t = 0)
    w  : growth rate per strain
    t  : 1-D array of times (arbitrary units); must be >= 0 when K is given
    K  : carrying capacity shared by all strains (None → pure exponential)

    Returns an array of shape (len(t), n_strains) when K is None and of shape
    (n_strains, len(t)) in the logistic case; the time axis follows the order
    of t in both cases.
    NOTE(review): the two branches return transposed layouts; kept as-is so
    existing callers of either branch are unaffected.

    Raises ValueError if K is given and t contains negative times, and
    RuntimeError if the ODE integration does not succeed.
    """
    t = np.asarray(t, dtype=float)
    n0 = np.asarray(n0, dtype=float)
    w = np.asarray(w, dtype=float)  # also lets plain lists be indexed below
    if K is None:
        return n0 * np.exp(w[None] * t[:, None])

    # logistic with shared K
    if np.any(t < 0.):
        raise ValueError("grow() needs non-negative times: n0 is the population at t = 0")
    # The solver wants strictly increasing evaluation times; `inverse` maps the
    # solution back onto the caller's (possibly unsorted / repeated) times.
    eval_times, inverse = np.unique(t, return_inverse=True)
    if eval_times.size == 0 or eval_times[-1] == 0.:
        # Nothing to integrate: every requested time is t = 0.
        return np.tile(n0[:, None], (1, t.size))
    ode_solution = solve_ivp(
        lotka_volterra,
        t_span=(0., eval_times[-1]),
        t_eval=eval_times,
        y0=n0,
        vectorized=True,
        args=(w, K),
    )
    if not ode_solution.success:
        # A failed solve returns a truncated solution; surface it instead of
        # handing back an array with fewer timepoints than requested.
        raise RuntimeError(f"ODE integration failed: {ode_solution.message}")
    return ode_solution.y[:, inverse]
def plotter_t(x, growths, scatter=False, **kwargs):
    """Plot one series per strain against time, with a log-scaled y axis.

    x       : x values shared by every series (flattened before plotting)
    growths : iterable with one array of y values per strain
    scatter : draw points (size 5) instead of lines
    kwargs  : forwarded to axes.set() (e.g. ylabel)

    The first two series get the label "_none"; the rest are labelled
    "Mutant 1", "Mutant 2", ...  Returns the figure.
    """
    fig, axes = grid(aspect_ratio=1.5)
    if scatter:
        draw = partial(axes.scatter, s=5.)
    else:
        draw = axes.plot
    x_values = x.flatten()
    for strain_idx, series in enumerate(growths):
        legend_label = "_none" if strain_idx <= 1 else f"Mutant {strain_idx-1}"
        draw(x_values, series.flatten(), label=legend_label)
    axes.set(xlabel="Time", yscale="log", **kwargs)
    add_legend(axes)
    return fig
def plotter_ref(x, growths, scatter=False, fitlines=None, text=None, **kwargs):
    """Plot one series per strain against a reference quantity on log-log axes.

    x        : x values shared by every series (flattened before plotting)
    growths  : iterable with one array of y values per strain
    scatter  : draw points (size 5) instead of lines
    fitlines : optional (fit_x, fit_y) pair; for every coefficient b in fit_y
               the curve exp(log(fit_x) @ b), i.e. fit_x ** b, is overlaid.
               fit_x is expected to be a column, shape (n, 1).
    text     : optional annotation placed just right of the axes
    kwargs   : forwarded to axes.set() (e.g. xlabel, ylabel)

    The first two series get the label "_none"; the rest are labelled
    "Mutant 1", "Mutant 2", ...  Returns the figure.
    """
    # A slightly wider figure leaves room for the annotation beside the axes.
    fig, axes = grid(aspect_ratio=1.5 if text is None else 1.7)
    plotter_f = partial(axes.scatter, s=5.) if scatter else axes.plot
    for i, y in enumerate(growths):
        plotter_f(
            x.flatten(), y.flatten(),
            label=f"Mutant {i-1}" if i > 1 else "_none",
        )
    if fitlines is not None:
        fit_x, fit_y = fitlines
        log_fit_x = np.log(fit_x)  # loop-invariant
        for b in fit_y.flatten():
            # Power law through the origin of the log-log plot: y = fit_x ** b.
            y = np.exp(log_fit_x @ b[None])
            axes.plot(
                fit_x.flatten(), y.flatten(),
                label="_none",
            )
    if text is not None:
        axes.text(
            1.05, .1,  # axes coordinates: x > 1 puts the text outside the plot area
            text,
            fontsize=10,
            transform=axes.transAxes,
        )
    axes.set(
        xscale="log",
        yscale="log",
        **kwargs,
    )
    add_legend(axes)
    return fig
def calculate_growth_curves(inoculum, inoculum_var, carrying_capacity, fitness, n_timepoints=100):
    """Draw random inocula and simulate the growth of every strain.

    inoculum          : mean initial cells per strain
    inoculum_var      : dispersion of the inoculum
                        (variance = mean + inoculum_var * mean**2)
    carrying_capacity : total carrying capacity, as a multiple of inoculum
    fitness           : growth rates of the mutant strains
    n_timepoints      : number of evenly spaced times in [0, MAX_TIME]

    Strains are ordered as: growth rate 1, growth rate 0, then the mutants.
    Returns (t, ref_expansion, growths), where ref_expansion is strain 0's
    population divided by its initial size.
    """
    rates = [1., 0.] + list(fitness)

    # Negative-binomial parameters giving the requested mean and variance.
    total_var = inoculum + inoculum_var * np.square(inoculum)
    nb_n = (inoculum ** 2.) / (total_var - inoculum)
    nb_p = inoculum / total_var
    n0 = nbinom.rvs(nb_n, nb_p, size=len(rates), random_state=SEED)

    t = np.linspace(0., MAX_TIME, num=int(n_timepoints))
    growths = grow(t, n0, rates, inoculum * carrying_capacity)
    return t, growths[0] / n0[0], growths
def growth_plotter(inoculum, inoculum_var, carrying_capacity, *fitness):
    """Simulate growth at 100 timepoints and return the two growth figures.

    Arguments are passed straight to calculate_growth_curves; the extra
    positional arguments are the mutant growth rates.

    Returns [cells vs time, cells vs fold-expansion of strain 0].
    """
    t, ref_expansion, growths = calculate_growth_curves(
        inoculum, inoculum_var, carrying_capacity, fitness,
        n_timepoints=100,
    )
    cells_label = "Number of cells per strain"
    versus_time = plotter_t(t, growths, ylabel=cells_label)
    versus_reference = plotter_ref(
        ref_expansion,
        growths,
        xlabel="Fold-expansion of wild-type",
        ylabel=cells_label,
    )
    return [versus_time, versus_reference]
def reads_sampler(population, sample_frac, seq_depth, reps, variance):
    """Simulate sequencing read counts per strain, timepoint and replicate.

    population  : cells per strain over time, shape (n_strains, n_timepoints)
    sample_frac : fraction of the total population drawn into each sample
    seq_depth   : average reads per strain per sample
    reps        : number of replicate draws per timepoint
    variance    : dispersion of the read counts
                  (read variance = mean + variance * mean**2)

    Cells are first sampled without replacement at every timepoint
    (multivariate hypergeometric), then converted to expected reads and
    redrawn from a negative binomial.

    Returns an integer array of shape (n_strains, n_timepoints, reps).
    """
    reps = int(reps)  # used as an array size and a range() bound below
    cells = np.asarray(population).astype(int)
    samples = []
    for i, timepoint_pop in enumerate(np.split(cells, cells.shape[-1], axis=-1)):
        sample_size = np.floor(timepoint_pop.sum() * sample_frac).astype(int)
        samples.append(
            multivariate_hypergeom.rvs(
                m=timepoint_pop.flatten(),
                n=sample_size,
                size=reps,
                random_state=SEED + i,
            ).T
        )
    samples = np.stack(samples, axis=-2)

    # Expected reads: each strain's share of the sample times the total depth.
    # An empty sample (total == 0) has no shares, so its means are set to 0
    # instead of 0/0 = NaN.
    scaled = seq_depth * samples.shape[0] * samples
    totals = samples.sum(axis=0, keepdims=True)
    read_means = np.floor(np.divide(
        scaled, totals,
        out=np.zeros(scaled.shape, dtype=float),
        where=totals > 0,
    ))

    # A mean of 0 would give NaN negative-binomial parameters, which
    # nbinom.rvs rejects.  Such cells get placeholder parameters (n = p = 1)
    # and are forced to 0 reads after the draw.
    has_reads = read_means > 0
    variance = read_means + variance * np.square(read_means)
    p = np.where(has_reads, read_means / np.where(has_reads, variance, 1.), 1.)
    n = np.where(
        has_reads,
        (read_means ** 2.) / np.where(has_reads, variance - read_means, 1.),
        1.,
    )
    return np.stack([
        np.where(
            has_reads[..., i],
            nbinom.rvs(n[..., i], p[..., i], random_state=SEED + i),
            0,
        )
        for i in range(reps)
    ], axis=-1)
def fitness_fitter(read_counts, ref_expansion):
    """Fit each strain's log count ratio against the log reference expansion.

    read_counts   : array indexed as (strain, timepoint, replicate)
    ref_expansion : 1-D array, one value per timepoint

    Every strain's counts are divided by its replicate-mean count at the first
    timepoint, then expressed relative to strain 0 in log space.  Each strain
    is regressed (least squares, no intercept) on log(ref_expansion), repeated
    once per replicate.

    Returns (log count ratios with the shape of read_counts,
             slopes with shape (n_strains, 1)).
    """
    n_reps = read_counts.shape[-1]
    baseline = np.mean(read_counts[:, :1], axis=-1, keepdims=True)
    read_count_expansion = read_counts / baseline
    log_read_count_correction = (
        np.log(read_count_expansion) - np.log(read_count_expansion[:1])
    )
    # Single-column design matrix: every timepoint repeated for each replicate.
    design = np.tile(np.log(ref_expansion)[:, None], (1, n_reps)).reshape((-1, 1))
    betas = [
        np.linalg.lstsq(a=design, b=strain_log_ratio.flatten())[0]
        for strain_log_ratio in log_read_count_correction
    ]
    return log_read_count_correction, np.asarray(betas)
def fitness_fitter_spike(log_read_count_corrected):
    """Fit each strain's log count ratio against that of strain 1.

    log_read_count_corrected : array indexed as (strain, timepoint, replicate)

    Strain 1's values, flattened into a single column, are the regressor;
    every strain (strain 1 included) is regressed on it by least squares
    without an intercept.

    Returns (regressor column with shape (n_points, 1),
             slopes with shape (n_strains, 1)).
    """
    log_spike_count_corrected = log_read_count_corrected[1, ...].flatten()[..., None]
    betas = np.asarray([
        np.linalg.lstsq(a=log_spike_count_corrected, b=strain_log_ratio.flatten())[0]
        for strain_log_ratio in log_read_count_corrected
    ])
    return log_spike_count_corrected, betas
def reads_plotter(
    sample_frac, seq_reps, seq_depth, read_var,
    inoculum, inoculum_var, carrying_capacity, *fitness
):
    """Simulate sequencing of the growth experiment and return all five figures.

    sample_frac, seq_reps, seq_depth, read_var : passed to reads_sampler
    inoculum, inoculum_var, carrying_capacity, *fitness : passed to
        calculate_growth_curves (10 timepoints) and to growth_plotter

    Returns the two figures from growth_plotter followed by: read counts vs
    time, corrected count ratios vs fold-expansion of strain 0 (with fits),
    and corrected count ratios vs the corrected ratio of strain 1 (with fits).
    """
    t, ref_expansion, growths = calculate_growth_curves(
        inoculum, inoculum_var, carrying_capacity, fitness,
        n_timepoints=10,
    )
    read_counts = reads_sampler(growths, sample_frac, seq_depth, seq_reps, read_var)

    log_read_count_correction, betas = fitness_fitter(read_counts, ref_expansion)
    _, spike_betas = fitness_fitter_spike(log_read_count_correction)

    # Annotations list the mutants only (index 2 onwards).  Relative fitness
    # is reported as 1 + slope for the expansion fit and 1 - slope for the
    # fit against strain 1.
    expansion_lines = [
        f"Mutant {i-1}: $w_{i-1}/w_{'{wt}'}={1. + b:.2f}$"
        for i, b in enumerate(betas.flatten()) if i > 1
    ]
    spike_lines = [
        f"Mutant {i-1}: $w_{i-1}/w_{'{wt}'}={1. - b:.2f}$"
        for i, b in enumerate(spike_betas.flatten()) if i > 1
    ]
    plot_text = "\n".join(expansion_lines)
    plot_text_spike = "\n".join(spike_lines)

    read_count_correction = np.exp(log_read_count_correction)
    spike_ratio = read_count_correction[1:2, ...]
    ratio_label = "$\\frac{c_1(t)}{c_{wt}(t)} / \\frac{c_1(0)}{c_{wt}(0)}$"

    figures = growth_plotter(inoculum, inoculum_var, carrying_capacity, *fitness)
    figures = figures + [
        plotter_t(
            np.tile(t[:, None], (1, seq_reps)),
            read_counts,
            scatter=True,
            ylabel="Read counts per strain",
        ),
        plotter_ref(
            np.tile(ref_expansion[:, None], (1, seq_reps)),
            read_count_correction,
            scatter=True,
            fitlines=(ref_expansion[:, None], betas),
            text=plot_text,
            xlabel="Fold-expansion of wild-type",
            ylabel=ratio_label,
        ),
        plotter_ref(
            spike_ratio,
            read_count_correction,
            scatter=True,
            fitlines=(spike_ratio.flatten()[..., None], spike_betas),
            text=plot_text_spike,
            xlabel="$\\frac{c_{spike}(t)}{c_{wt}(t)} / \\frac{c_{spike}(0)}{c_{wt}(0)}$",
            ylabel=ratio_label,
        ),
    ]
    return figures
# Page layout and event wiring for the Gradio app.
# NOTE(review): the nesting below (which components sit inside each gr.Row)
# is reconstructed from statement order; confirm against the rendered layout.
with gr.Blocks() as demo:
    inject_markdown("header.md")
    # Growth curves
    inject_markdown("growth-curve-intro.md")
    # Default relative fitness of each mutant; one slider is created per entry.
    mut_fitness_defaults = [.5, 2., .2]
    with gr.Row():
        relative_fitness = [
            gr.Slider(0., 3., step=.1, value=w, label=f"Relative fitness, mutant {i + 1}")
            for i, w in enumerate(mut_fitness_defaults)
        ]
    with gr.Row():
        # NOTE(review): n_mutants is not referenced again in this block.
        n_mutants = len(mut_fitness_defaults)
        inoculum = gr.Slider(
            10, 1_000_000,
            step=10,
            value=1000,
            label="Average inoculum per strain",
        )
        inoculum_var = gr.Slider(
            .001, 1.,
            step=.001,
            value=.001,
            label="Inoculum variance between strains",
        )
        # Lower bound is the number of mutant sliders plus one.
        carrying_capacity = gr.Slider(
            len(mut_fitness_defaults) + 1, 10_000,
            step=1, value=10,
            label="Total carrying capacity (x inoculum)",
        )
    plot_growth = gr.Button("Plot growth curves")
    growth_curves_t = gr.Plot(label="Growth vs time", format="png")
    inject_markdown("growth-curve-t-independent.md")
    growth_curves_ref = gr.Plot(label="Growth vs WT expansion", format="png")
    # Output order must match the list returned by growth_plotter.
    growth_curves = [growth_curves_t, growth_curves_ref]
    # Read counts
    inject_markdown("read-counts-intro.md")
    with gr.Row():
        sample_frac = gr.Slider(
            .001, 1., step=.001,
            value=.1,
            label="Fraction of population per sample",
        )
        seq_reps = gr.Slider(
            1, 10,
            step=1,
            value=3,
            label="Technical replicates",
        )
        seq_depth = gr.Slider(
            10, 10_000,
            step=10,
            value=10_000,
            label="Average reads per strain per sample",
        )
        read_var = gr.Slider(
            .001, 1.,
            step=.001,
            value=.001,
            label="Sequencing variance",
        )
    plot_reads = gr.Button("Plot read counts")
    read_curves_t = gr.Plot(label="Read counts vs time", format="png")
    inject_markdown("read-counts-expansion.md")
    read_curves_ref = gr.Plot(label="Read count diff vs WT expansion", format="png")
    inject_markdown("read-counts-spike.md")
    read_curves_t2 = gr.Plot(label="Read count diff vs spike count diff", format="png")
    # Output order must match the last three figures returned by reads_plotter.
    read_curves = [
        read_curves_t,
        read_curves_ref,
        read_curves_t2,
    ]
    # Events
    # Input order follows the positional parameters of growth_plotter.
    plot_growth.click(
        fn=growth_plotter,
        inputs=[inoculum, inoculum_var, carrying_capacity, *relative_fitness],
        outputs=growth_curves,
    )
    # reads_plotter returns the growth figures too, so this button refreshes
    # all five plots.
    plot_reads.click(
        fn=reads_plotter,
        inputs=[sample_frac, seq_reps, seq_depth, read_var] + [inoculum, inoculum_var, carrying_capacity, *relative_fitness],
        outputs=growth_curves + read_curves,
    )
# Start the web server; share=True asks Gradio for a public share link.
demo.launch(share=True)