Source code for destvi_utils._destvi_utils

import base64
import logging
from argparse import ArgumentError
from io import BytesIO

import cmap2d
import gseapy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from adjustText import adjust_text
from rich import print
from scipy.sparse import issparse
from scipy.stats import ks_2samp
from statsmodels.stats.multitest import multipletests

from . import _utils


def _plot_thresholded_proportions(
    ax, name_ct, locations, proportions, threshold, vmax
):
    """
    Scatter one cell type's proportions in space, zeroing values not above ``threshold``.

    Drawing goes through the ``pyplot`` state machine, so ``ax`` must be the
    current axes when this helper is called.
    """
    _utils._prettify_axis(ax, spatial=True)
    plt.scatter(
        locations[:, 0],
        locations[:, 1],
        # multiplying by the boolean mask sets sub-threshold spots to 0
        c=proportions * (proportions.values > threshold),
        s=14,
        vmax=vmax,
        cmap="Reds",
    )
    plt.colorbar()
    plt.title(f"{name_ct}, threshold: t={threshold:0.3f}")
    plt.tight_layout(rect=[0, 0.03, 1, 0.9])
    return ax


def automatic_proportion_threshold(
    st_adata,
    kind_threshold="primary",
    output_file=None,
    ct_list=None,
    key_proportions="proportions",
    key_spatial="spatial",
):
    """
    Function to compute automatic threshold on cell type proportion values.

    For further reference check [Lopez22]_.

    For every cell type a three-panel figure is produced: the unthresholded
    spatial proportions, a histogram of the proportions with the candidate
    thresholds, and the spatial proportions after thresholding.

    Parameters
    ----------
    st_adata
        Spatial sequencing dataset with proportions in obsm[key_proportions]
        and spatial location in obsm[key_spatial].
    kind_threshold
        Which threshold value to use. Supported are 'primary', 'secondary' and
        'min_value'. 'primary' uses the point where the derivative returned by
        the smoothing helper is minimal, 'secondary' uses the first point where
        the sign of the second derivative changes, and 'min_value' uses the
        minimum of both for each cell type (falling back to the primary
        threshold if no secondary threshold exists).
    output_file
        File where html output is stored. None means displaying the results
        and not storing them. Defaults to None.
    ct_list
        Celltypes to use. Defaults to all celltypes.
    key_proportions
        Obsm key pointing to cell-type proportions.
    key_spatial
        Obsm key pointing to location of cells.

    Returns
    -------
    ct_thresholds
        Dictionary containing all threshold values.

    Raises
    ------
    argparse.ArgumentError
        If ``kind_threshold`` is not one of the supported values.
    ValueError
        If ``key_proportions`` or ``key_spatial`` is missing from ``st_adata.obsm``.
    IndexError
        If ``kind_threshold='secondary'`` and no secondary threshold is found
        for a cell type.
    """
    # Validate before doing any work so that a typo fails fast and no figure
    # is left open. ArgumentError needs (argument, message); there is no
    # argparse action here, hence None.
    if kind_threshold not in ("primary", "secondary", "min_value"):
        raise ArgumentError(
            None,
            'Kind threshold {} is not defined. Use "primary", "secondary" '
            'or "min_value"'.format(kind_threshold),
        )
    if key_proportions not in st_adata.obsm:
        raise ValueError(
            f"Please provide cell type proportions in st_adata.obsm[{key_proportions}] and rerun.",
        )
    if key_spatial not in st_adata.obsm:
        raise ValueError(
            f"Please provide cell type locations in st_adata.obsm[{key_spatial}] and rerun."
        )

    if ct_list is None:
        ct_list = list(st_adata.obsm[key_proportions].columns)

    locations = st_adata.obsm[key_spatial]
    ct_thresholds = {}
    html_parts = ["<h2>Automatic thresholding</h2>"]

    for name_ct in ct_list:
        fig = plt.figure(figsize=(15, 5))
        fig.suptitle(
            name_ct + ": critical points", fontsize="x-large", fontweight="semibold"
        )
        array = st_adata.obsm[key_proportions][name_ct]
        vmax = np.quantile(array.values, 0.99)

        # get characteristic values
        quantiles, stack = _utils._form_stacked_quantiles(array.values)
        index, z_values = _utils._get_autocorrelations(
            st_adata, stack, quantiles, key_spatial
        )
        _, _, derivative, sign_2nd, _ = _utils._smooth_get_critical_points(
            index, z_values, s=0.1
        )
        # positions where the sign of the second derivative flips
        ipoints = index[np.where(sign_2nd[:-1] != sign_2nd[1:])[0]]
        nom_map = index[np.argmin(derivative)]

        # add thresholds to dict
        if kind_threshold == "primary":
            ct_thresholds[name_ct] = nom_map
        elif kind_threshold == "secondary":
            if len(ipoints) == 0:
                plt.close(fig)
                raise IndexError(
                    f"No secondary threshold found for cell type {name_ct}."
                )
            ct_thresholds[name_ct] = ipoints[0]
        else:  # "min_value"
            ct_thresholds[name_ct] = (
                min(nom_map, ipoints[0]) if len(ipoints) > 0 else nom_map
            )

        # PLOT 1 shows proportions in spatial dimensions without thresholding
        ax1 = plt.subplot(131)
        _utils._prettify_axis(ax1)
        _plot_thresholded_proportions(ax1, name_ct, locations, array, 0, vmax)

        # PLOT 2 shows the candidate thresholds on top of the histogram
        ax2 = plt.subplot(132)
        _utils._prettify_axis(ax2)
        n, _, _ = plt.hist(array.values)
        plt.vlines(
            ipoints,
            ymin=0,
            ymax=np.max(n),
            color="red",
            linestyle="--",
            label="secondary thresholds",
        )
        # nominal mapping
        plt.axvline(nom_map, c="red", label="main threshold")
        plt.xlabel("proportions value")
        plt.title("Cell type frequency histogram")
        plt.legend()

        # PLOT 3 shows proportions in spatial dimensions after thresholding
        ax3 = plt.subplot(133)
        _plot_thresholded_proportions(
            ax3, name_ct, locations, array, ct_thresholds[name_ct], vmax
        )

        if output_file is not None:
            tmpfile = BytesIO()
            plt.savefig(tmpfile, format="png")
            encoded = base64.b64encode(tmpfile.getvalue()).decode("utf-8")
            html_parts.append("<img src='data:image/png;base64,{}'>".format(encoded))
            plt.close()
        else:
            plt.show()

    # dump+write to HTML
    if output_file is not None:
        logging.warning(
            "Saving output to {}. Set output_file=None to display results.".format(
                output_file
            )
        )
        with open(output_file, "w", encoding="utf-8") as f:
            f.write("".join(html_parts))
    return ct_thresholds
def explore_gamma_space(
    st_model,
    sc_model,
    st_adata=None,
    ct_thresholds=None,
    output_file=None,
    ct_list=None,
    key_proportions="proportions",
    key_spatial="spatial",
):
    """
    Function to explore the cell-type-specific gamma space with spatial components.

    For further reference check [Lopez22]_.

    For every cell type, the gamma values of the spots above the proportion
    threshold are projected onto two spatial components, colored with a
    ternary colormap and shown (a) in space, (b) in the projected space and
    (c) for the single-cell latent representation projected onto the same
    components. For each of the two components the 50 most positively and
    most negatively correlated genes of the single-cell data are reported
    together with their Enrichr results on the "BioPlanet_2019" gene sets
    (terms with adjusted p-value below 0.01 are marked with "*").

    Parameters
    ----------
    st_model
        Trained destVI model
    sc_model
        Trained CondSCVI model
    st_adata
        Spatial sequencing dataset with proportions in obsm[key_proportions].
        Otherwise uses data in st_model (and writes the model's proportions
        to its obsm[key_proportions]).
    ct_thresholds
        Dictionary mapping cell type to threshold value for cell type
        proportions. Cell types missing from the dictionary use 0.
        Defaults to 0 for all cell types.
    output_file
        File where html output is stored. None means displaying the results
        and not storing them. Defaults to None.
    ct_list
        Celltypes to use. Defaults to all celltypes.
    key_proportions
        Obsm key pointing to cell-type proportions.
    key_spatial
        Obsm key pointing to location of cells.

    Raises
    ------
    ValueError
        If ``key_proportions`` or ``key_spatial`` is missing from ``st_adata.obsm``.
    """
    save_html = output_file is not None
    html_parts = ["<h1>sPCA analysis</h1>"]

    if st_adata is None:
        st_adata = st_model.adata
        st_adata.obsm[key_proportions] = st_model.get_proportions()
    elif key_proportions not in st_adata.obsm:
        raise ValueError(
            f"Please provide cell type proportions in st_adata.obsm[{key_proportions}] and rerun.",
        )
    # Locations are needed no matter where st_adata came from.
    if key_spatial not in st_adata.obsm:
        raise ValueError(
            f"Please provide cell type locations in st_adata.obsm[{key_spatial}] and rerun."
        )

    sc_adata = sc_model.adata
    labels_key = sc_model.registry_["setup_args"]["labels_key"]

    if ct_list is None:
        ct_list = list(st_adata.obsm[key_proportions].columns)
    if ct_thresholds is None:
        ct_thresholds = {ct: 0 for ct in ct_list}

    tri_coords = [[-1, -1], [-1, 1], [1, 0]]
    tri_colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
    rule = "---------------------------------------------------------------------------------------"

    gamma = st_model.get_gamma(return_numpy=True)

    for name_ct in ct_list:
        ct_proportions = st_adata.obsm[key_proportions][name_ct].values
        requested_threshold = ct_thresholds.get(name_ct, 0)
        if requested_threshold > np.max(ct_proportions):
            logging.warning(
                f"Defined threshold {requested_threshold} higher than highest proportion value. "
                + f"Falling back to no threshold for cell-type {name_ct}"
            )
            ct_threshold = 0
        else:
            ct_threshold = requested_threshold

        filter_ = ct_proportions > ct_threshold
        locations = st_adata.obsm[key_spatial][filter_]
        proportions = ct_proportions[filter_]
        ct_index = np.where(name_ct == st_model.cell_type_mapping)[0][0]
        data = gamma[:, :, ct_index][filter_]
        vec = _utils._get_spatial_components(locations, proportions, data)
        # project data onto them
        projection = np.dot(data - np.mean(data, 0), vec)

        # create the colormap
        cmap = cmap2d.TernaryColorMap(tri_coords, tri_colors)

        # apply colormap to spatial data
        color = np.vstack([cmap(point) for point in projection])

        fig = plt.figure(figsize=(15, 5))
        fig.suptitle(name_ct, fontsize="x-large", fontweight="semibold")

        ax1 = plt.subplot(132)
        _utils._prettify_axis(ax1)
        plt.scatter(projection[:, 0], projection[:, 1], c=color, marker="X")
        # variance and explained variance
        total_var = np.sum(np.diag(np.cov(data.T)))
        explained_var = 100 * np.diag(np.cov(projection.T)) / total_var
        plt.xlabel("SpatialPC1 ({:.1f}% explained var)".format(explained_var[0]))
        plt.ylabel("SpatialPC2 ({:.1f}% explained var)".format(explained_var[1]))
        plt.title("Projection of the spatial data")

        ax3 = plt.subplot(131)
        _utils._prettify_axis(ax3, spatial=True)
        # all spots as a faint background, selected spots on top in color
        plt.scatter(
            st_adata.obsm[key_spatial][:, 0],
            st_adata.obsm[key_spatial][:, 1],
            alpha=0.1,
            s=7,
            c="blue",
        )
        plt.scatter(
            st_adata.obsm[key_spatial][filter_, 0],
            st_adata.obsm[key_spatial][filter_, 1],
            c=color,
            s=7,
        )
        plt.title("Spatial transcriptome coloring")

        # go back to the single-cell data and find gene correlated with the axis
        is_ct = sc_adata.obs[labels_key] == name_ct
        sc_adata_slice = sc_adata[is_ct].copy()
        # `.toarray()` instead of the `.A` shorthand, which recent SciPy
        # releases deprecate/remove for sparse containers.
        if issparse(sc_adata_slice.X):
            normalized_counts = sc_adata_slice.X.toarray()
        else:
            normalized_counts = sc_adata_slice.X
        # The correlations below index columns of the single-cell matrix, so
        # the gene names must come from the single-cell data as well.
        gene_names = sc_adata_slice.var.index

        indices_ct = np.where(is_ct)[0]
        sc_latent = sc_model.get_latent_representation(indices=indices_ct)
        sc_projection = np.dot(sc_latent - np.mean(sc_latent, 0), vec)

        # show the colormap for single-cell data
        color = np.vstack([cmap(point) for point in sc_projection])
        ax2 = plt.subplot(133)
        _utils._prettify_axis(ax2)
        plt.scatter(sc_projection[:, 0], sc_projection[:, 1], c=color)
        # variance and explained variance
        total_var = np.sum(np.diag(np.cov(sc_latent.T)))
        explained_var = 100 * np.diag(np.cov(sc_projection.T)) / total_var
        plt.xlabel("SpatialPC1 ({:.1f}% explained var)".format(explained_var[0]))
        plt.ylabel("SpatialPC2 ({:.1f}% explained var)".format(explained_var[1]))
        plt.title("Projection of the scRNA-seq data")
        plt.tight_layout(rect=[0, 0.03, 1, 0.9])

        # DUMP TO HTML (the PNG is only rendered when it is actually stored)
        if save_html:
            tmpfile = BytesIO()
            plt.savefig(tmpfile, dpi="figure", format="png")
            encoded = base64.b64encode(tmpfile.getvalue()).decode("utf-8")
            html_parts.append("<img src='data:image/png;base64,{}'>".format(encoded))
        else:
            plt.show()

        # calculate correlations, and for each axis:
        # (A) display top 50 genes + AND - (B) for each gene set, get GSEA
        for d in [0, 1]:
            if save_html:
                html_parts.append(f"<h4>Genes associated with SpatialPC{d+1}</h4>")
            else:
                print("[bold]Genes associated with SpatialPC{}[/bold]".format(d + 1))

            r = _utils._vcorrcoef(normalized_counts.T, sc_projection[:, d])
            ascending = np.argsort(r)
            for mode in ["Positively", "Negatively"]:
                ranking = ascending[::-1] if mode == "Positively" else ascending
                gl = list(gene_names[ranking[:50]])
                # NOTE(review): outdir="test" is passed through to gseapy
                # unchanged; presumably it writes its reports there.
                enr = gseapy.enrichr(
                    gene_list=gl,
                    description="pathway",
                    gene_sets="BioPlanet_2019",
                    outdir="test",
                    no_plot=True,
                )
                top_terms = enr.results.head(10)
                # mark significantly enriched terms with a trailing star
                text_signatures = [
                    term + "*" if adj_pval < 0.01 else term
                    for term, adj_pval in zip(
                        top_terms["Term"], top_terms["Adjusted P-value"]
                    )
                ]

                if save_html:
                    html_parts.append(f"<h5> {mode} </h5>")
                    html_parts.append("<p>" + ", ".join(gl) + "</p>")
                    html_parts.append("<p>" + ", ".join(text_signatures) + "</p>")
                else:
                    print("\n")
                    print("[italic]{}[/italic]".format(mode))
                    print(rule)
                    print(", ".join(gl))
                    print(rule)
                    print(", ".join(text_signatures))
                    print("\n \n \n")
        plt.close(fig)

    # write HTML
    if save_html:
        logging.warning(
            "Saving output to {}. Set output_file=None to display results.".format(
                output_file
            )
        )
        with open(output_file, "w", encoding="utf-8") as f:
            f.write("".join(html_parts))
def de_genes(
    st_model,
    mask,
    ct,
    threshold=0.0,
    st_adata=None,
    mask2=None,
    key=None,
    N_sample=10,
    pseudocount=0.01,
    key_proportions="proportions",
):
    """
    Function to compute differential expressed genes from generative model.

    For further reference check [Lopez22]_.

    Expression values for cell type ``ct`` are simulated from a Gamma
    distribution for the spots of both conditions, log-transformed, and
    compared gene by gene with a two-sided two-sample Kolmogorov-Smirnov test.

    Parameters
    ----------
    st_model
        Trained destVI model
    mask
        Mask for subsetting the spots to condition 1 in differential expression.
    ct
        Cell type for which differential expression is computed
    threshold
        Proportion threshold to subset to spots with this amount of cell type proportion
    st_adata
        Spatial sequencing dataset with proportions in obsm[key_proportions].
        If not provided uses data in st_model.
    mask2
        Mask for subsetting the spots to condition 2 in differential expression (reference).
        If none, inverse of mask.
    key
        Key to store values in st_adata.uns[key]. If None returns pandas dataframe with DE results.
        Defaults to None
    N_sample
        N_samples drawn from generative model to simulate expression values.
    pseudocount
        Pseudocount added at computation of logFC. Increasing leads to lower logFC of lowly expressed genes.
    key_proportions
        Obsm key pointing to cell-type proportions.

    Returns
    -------
    res
        If key is None. Pandas dataframe containing results of differential expression.
        Rows are genes, columns are "log2FC", "score" (KS statistic) and "pval".
        If key is provided, ``st_adata`` is returned and mask, mask2 and the results
        (sorted by descending score) are stored in st_adata.uns[key].
        Dictionary keys are "mask_active", "mask_rest", "de_results".

    Raises
    ------
    ValueError
        If ``key_proportions`` is missing from ``st_adata.obsm`` or if one of the
        two conditions contains no spot above ``threshold``.
    """
    if mask2 is None:
        mask2 = ~mask

    if st_adata is None:
        st_adata = st_model.adata
        st_adata.obsm[key_proportions] = st_model.get_proportions()
    elif key_proportions not in st_adata.obsm:
        raise ValueError(
            f"Please provide cell type proportions in st_adata.obsm[{key_proportions}] and rerun."
        )

    layer = st_model.registry_["setup_args"]["layer"]
    expression = st_adata.layers[layer] if layer else st_adata.X

    # restrict both conditions to spots that contain enough of the cell type
    above_threshold = st_adata.obsm[key_proportions][ct] > threshold
    mask = np.logical_and(mask, above_threshold)
    mask2 = np.logical_and(mask2, above_threshold)

    # Plain boolean arrays for positional indexing: `mask`/`mask2` may be
    # pandas objects, which must not be used to index a torch tensor.
    case_rows = np.array(mask, dtype=bool)
    control_rows = np.array(mask2, dtype=bool)
    if not case_rows.any() or not control_rows.any():
        raise ValueError(
            f"No spots left for differential expression of cell type {ct}: "
            f"{int(case_rows.sum())} spots in mask and {int(control_rows.sum())} spots "
            f"in mask2 have a proportion above {threshold}."
        )

    # get statistics
    avg_library_size = np.mean(np.sum(expression, axis=1).flatten())
    exp_px_o = st_model.module.px_o.detach().exp().cpu().numpy()
    imputations = st_model.get_scale_for_ct(ct).values
    mean = avg_library_size * imputations

    concentration = torch.tensor(avg_library_size * imputations / exp_px_o)
    rate = torch.tensor(1.0 / exp_px_o)

    def _simulate(rows, n_draws):
        """Draw ``n_draws`` samples per selected spot; returns (draws * spots, genes)."""
        selection = torch.from_numpy(rows)
        draws = (
            torch.distributions.Gamma(
                concentration=concentration[selection], rate=rate
            )
            .sample((n_draws,))
            .cpu()
            .numpy()
        )
        draws = np.log1p(draws)
        return draws.reshape((-1, draws.shape[-1]))

    simulated_case = _simulate(case_rows, N_sample)
    simulated_control = _simulate(control_rows, N_sample)

    # alternative="two-sided" and the asymptotic p-value computation are passed
    # positionally because SciPy renamed the latter keyword (`mode` -> `method`).
    ks_results = [
        ks_2samp(
            simulated_case[:, gene],
            simulated_control[:, gene],
            "two-sided",
            "asymp",
        )
        for gene in range(simulated_control.shape[1])
    ]
    scores = np.array([result.statistic for result in ks_results])
    pvals = np.array([result.pvalue for result in ks_results])

    lfc = np.log2(pseudocount + mean[case_rows].mean(0)) - np.log2(
        pseudocount + mean[control_rows].mean(0)
    )
    res = pd.DataFrame(
        data=np.vstack([lfc, scores, pvals]),
        columns=st_adata.var.index,
        index=["log2FC", "score", "pval"],
    ).T

    if key is None:
        return res

    # Store results in st_adata
    st_adata.uns[key] = {
        "de_results": res.sort_values(by="score", ascending=False),
        "mask_active": mask,
        "mask_rest": mask2,
    }
    return st_adata
def plot_de_genes(
    st_adata, key, output_file=None, interesting_genes=None, key_spatial="spatial"
):
    """
    Function to plot results of differential expressed genes in a Volcano plot.

    For further reference check [Lopez22]_.

    Panel A shows the spatial location of the two compared groups of spots,
    panel B shows log2 fold-change against the score of every gene. Genes
    whose score exceeds the smallest score among the genes significant after
    Benjamini-Hochberg correction and whose absolute log2 fold-change exceeds
    the cutoff returned by ``_utils._get_delta`` are drawn in red.

    Parameters
    ----------
    st_adata
        Spatial sequencing dataset with precomputed de genes using de_genes function.
    key
        Key under which results of DE comparison are stored
    output_file
        File where picture is stored. None means displaying the results and
        not storing them. Defaults to None.
    interesting_genes
        Label dots in scatter plots with corresponding gene name.
        No gene is labelled if None or empty.
    key_spatial
        Obsm key pointing to location of cells.

    Raises
    ------
    ValueError
        If ``key_spatial`` is missing from ``st_adata.obsm``, ``key`` is missing
        from ``st_adata.uns`` or a gene in ``interesting_genes`` is not in
        ``st_adata.var_names``.
    """
    if key_spatial not in st_adata.obsm:
        raise ValueError(
            f"Please provide locations in st_adata.obsm[{key_spatial}] and rerun."
        )
    if key not in st_adata.uns:
        raise ValueError(
            "DE results are not stored with given key. Please run de_genes function with given key."
        )

    # `interesting_genes` defaults to None, so only validate when genes are given.
    label_genes = interesting_genes is not None and len(interesting_genes) > 0
    if label_genes:
        matching_genes = np.array(
            [gene in st_adata.var_names for gene in interesting_genes]
        )
        if not matching_genes.all():
            missing_genes = np.array(interesting_genes)[~matching_genes]
            raise ValueError(
                "{} are not in st_adata.var_names. Remove these genes from interesting_genes.".format(
                    missing_genes
                )
            )

    locations = st_adata.obsm[key_spatial]
    res = st_adata.uns[key]["de_results"]
    mask_active = st_adata.uns[key]["mask_active"]
    mask_rest = st_adata.uns[key]["mask_rest"]

    corr_p_vals = multipletests(res["pval"], method="fdr_bh")
    significant = corr_p_vals[0]
    # Without any significant gene there is no score cutoff to derive;
    # +inf marks every gene as not differentially expressed.
    if np.any(significant):
        min_score = np.min(res["score"][significant])
    else:
        min_score = np.inf

    plt.figure(figsize=(10, 5))

    # plot DE genes
    ax1 = plt.subplot(122)
    ax1.text(
        -0.1,
        1.05,
        "B",
        transform=ax1.transAxes,
        fontsize=16,
        fontweight="bold",
        va="top",
        ha="right",
    )
    mask_de = (res["score"] > min_score) & (
        np.abs(res["log2FC"]) > _utils._get_delta(res["log2FC"])
    )
    plt.scatter(res["log2FC"][mask_de], res["score"][mask_de], s=10, c="r")
    plt.scatter(res["log2FC"][~mask_de], res["score"][~mask_de], s=10, c="black")
    plt.xlabel("log2 fold-change")
    plt.ylabel("score")
    plt.grid(False)

    if label_genes:
        texts = []
        for gene in interesting_genes:
            x_coord, y_coord = res.loc[gene, "log2FC"], res.loc[gene, "score"]
            plt.scatter(x_coord, y_coord, c="r", s=10)
            texts.append(plt.text(x_coord, y_coord, gene, fontsize=12))
        # move the labels so they overlap neither each other nor the dots
        adjust_text(
            texts,
            res["log2FC"].values,
            res["score"].values,
            arrowprops=dict(arrowstyle="-", color="blue"),
        )

    # plot the spatial location of both groups
    ax2 = plt.subplot(121)
    ax2.text(
        -0.1,
        1.05,
        "A",
        transform=ax2.transAxes,
        fontsize=16,
        fontweight="bold",
        va="top",
        ha="right",
    )
    ax2.scatter(
        locations[mask_active][:, 0], locations[mask_active][:, 1], s=5, label="active"
    )
    ax2.scatter(
        locations[mask_rest][:, 0], locations[mask_rest][:, 1], s=5, label="rest"
    )
    plt.legend()
    _utils._prettify_axis(ax2, spatial=True)
    plt.tight_layout()

    if output_file is not None:
        logging.warning(
            "Saving output to {}. Set output_file=None to display results.".format(
                output_file
            )
        )
        plt.savefig(output_file, dpi=300)
        plt.close()
    else:
        plt.show()