Tutorial 2: Application of AlignDG on real spatial transcriptomics data (DLPFC)
[ ]:
import warnings
warnings.filterwarnings('ignore')
import anndata as ad
import scanpy as sc
import pandas as pd
import os
import numpy as np
from scipy import sparse
import torch
import aligndg
# Seed AlignDG with 42 so repeated runs of this tutorial start from the same state.
aligndg.set_global_seed(42)

# Prefer the first CUDA device when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Scanpy figure defaults: 80 dpi on screen, 600 dpi vector-friendly PDF on save,
# transparent background, high-resolution PNG for inline display.
sc.set_figure_params(
    scanpy=True,
    dpi=80,
    dpi_save=600,
    frameon=True,
    vector_friendly=True,
    color_map=None,
    format='pdf',
    facecolor=None,
    transparent=True,
    ipython_format='png2x',
)
[2]:
# Print the installed CUDA toolkit version (our device has CUDA 12; newer devices may have CUDA 13).
! nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0
Check whether JAX is running on CUDA or on the CPU. If you have a CUDA device but the result is `cpu`, consider running: `pip install --upgrade "jax[cuda12]"` (pick the extra that matches your CUDA version, which can be 12 or 13).
[14]:
import jax
# Display the backend JAX selected together with the PyTorch device chosen above,
# so a CPU/GPU mismatch between the two frameworks is visible at a glance.
jax.default_backend(), device
[14]:
('gpu', device(type='cuda', index=0))
[4]:
from aligndg.utils import graph_construction

section_ids = ['151673', '151674']
# The data directory does not depend on the section, so it is defined once
# instead of being re-assigned on every loop iteration.
input_dir = '/root/data/DLPFC'

adata_list = []
adj_list = []
for section_id in section_ids:
    adata = sc.read_h5ad(os.path.join(input_dir, section_id + '_preprocessed.h5'))
    adata.var_names_make_unique()
    # Suffix every barcode with the section id so spot names stay unique after
    # the slices are concatenated.
    adata.obs_names = [x + '_' + section_id for x in adata.obs_names]

    # The 'seurat_v3' flavor is run before normalisation, i.e. on the values as
    # stored in the file; normalisation and log-transform follow.
    sc.pp.highly_variable_genes(adata, flavor='seurat_v3', n_top_genes=10000)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    # Subsetting an AnnData returns a view. The graph construction below stores
    # its result on the object (adata.uns['adj'] is read right after), so take an
    # explicit copy rather than relying on an implicit copy-on-write of the view,
    # whose warning is hidden by the global warnings filter.
    adata = adata[:, adata.var['highly_variable']].copy()

    # Build the spatial neighbour graph (KNN with k_cutoff=6) for this slice.
    graph_construction.cal_spatial_network(adata, model='KNN', k_cutoff=6)
    adj_list.append(adata.uns['adj'])
    adata_list.append(adata)
Calculating spatial neighbor graph ...
The graph contains 23691 edges, 3611 cells
6.5607864857380225 neighbors per cell on average
Calculating spatial neighbor graph ...
The graph contains 23721 edges, 3635 cells
6.525722145804677 neighbors per cell on average
[5]:
from scipy.sparse import block_diag

# Stack both slices into one AnnData; 'slice_name' records which section each spot came from.
adata_concat = ad.concat(adata_list, keys=section_ids, label='slice_name')

# The per-slice adjacency matrices go on the diagonal of one joint matrix, so no
# edge connects spots from different slices.
adj_concat = block_diag(adj_list)

# Store the (row, col) coordinates of every non-zero adjacency entry as the edge list.
edge_rows_cols = np.nonzero(adj_concat)
adata_concat.uns['edge_list'] = edge_rows_cols
[6]:
from aligndg.graph import AlignDG
# Pairs of slice positions to align; (0, 1) presumably refers to the order of
# `section_ids` used when concatenating — TODO confirm against AlignDG docs.
iter_comb = [(0, 1)]
# NOTE(review): this rebinds `adata_concat` to the object returned by training, so
# re-running the cell feeds the already-trained object back in. `pis` is indexed
# as pis[0] later in this notebook and used as the alignment for the (0, 1) pair.
# `lamb` and `beta` are AlignDG hyper-parameters; their meaning is not visible here.
adata_concat, pis = AlignDG.train_AlignDG(adata=adata_concat, lamb=2, beta=0.4, iter_comb=iter_comb)
100%|██████████| 500/500 [00:07<00:00, 71.16it/s]
0%| | 0/500 [00:00<?, ?it/s]
Update spot triplet at epoch 500
Processing datasets ('151673', '151674')
Solver may not converge.
# Epoch 599, loss: 0.183, gene_recon: 0.182, triplet loss: 0.000783044786658138: 19%|█▉ | 97/500 [00:16<00:08, 47.63it/s]
Update spot triplet at epoch 600
Processing datasets ('151673', '151674')
Solver may not converge.
# Epoch 699, loss: 0.182, gene_recon: 0.181, triplet loss: 0.0004607039736583829: 40%|███▉ | 199/500 [00:25<00:05, 53.44it/s]
Update spot triplet at epoch 700
Processing datasets ('151673', '151674')
Solver may not converge.
# Epoch 799, loss: 0.182, gene_recon: 0.181, triplet loss: 0.0004656599776353687: 59%|█████▉ | 295/500 [00:35<00:04, 48.65it/s]
Update spot triplet at epoch 800
Processing datasets ('151673', '151674')
Solver may not converge.
# Epoch 899, loss: 0.182, gene_recon: 0.181, triplet loss: 0.00047129325685091317: 79%|███████▉ | 397/500 [00:45<00:02, 50.02it/s]
Update spot triplet at epoch 900
Processing datasets ('151673', '151674')
Solver may not converge.
# Epoch 999, loss: 0.182, gene_recon: 0.181, triplet loss: 0.000311184034217149: 100%|██████████| 500/500 [00:56<00:00, 8.92it/s]
[7]:
def _harm_frame(adata):
    """Return one slice as a DataFrame with columns index, x, y and domain.

    'index' is the running spot position, 'x'/'y' are the first two columns of
    ``adata.obsm['spatial']`` and 'domain' is the 'layer_guess_reordered'
    annotation cast to a categorical.
    """
    coords = adata.obsm['spatial']
    return pd.DataFrame({
        'index': range(adata.shape[0]),
        'x': coords[:, 0],
        'y': coords[:, 1],
        'domain': adata.obs['layer_guess_reordered'].astype('category'),
    })


adata1_harm_df = _harm_frame(adata_list[0])
adata2_harm_df = _harm_frame(adata_list[1])
adata1_harm_df, adata2_harm_df
[7]:
( index x y domain
AAACAAGTATCTCCCA-1.8_151673 0 440.639079 381.098123 Layer3
AAACAATCTACTAGCA-1.3_151673 1 259.630972 126.327637 Layer1
AAACACCAATAACTGC-1.8_151673 2 183.078314 427.767792 WM
AAACAGAGCGACTCCT-1.7_151673 3 417.236738 186.813688 Layer3
AAACAGCTTTCAGAAG-1.7_151673 4 152.700275 341.269139 Layer5
... ... ... ... ...
TTGTTTCACATCCAGG-1.8_151673 3606 254.410450 422.862301 WM
TTGTTTCATTAGTCTA-1.8_151673 3607 217.146722 433.393354 WM
TTGTTTCCATACAACT-1.8_151673 3608 208.415849 352.430255 Layer6
TTGTTTGTATTACACG-1.4_151673 3609 250.720081 503.735391 WM
TTGTTTGTGTAAATTC-1.8_151673 3610 284.293439 148.109816 Layer2
[3611 rows x 4 columns],
index x y domain
AAACAAGTATCTCCCA-1.9_151674 0 438.028818 384.023416 Layer3
AAACAATCTACTAGCA-1.4_151674 1 257.110720 129.297934 Layer1
AAACACCAATAACTGC-1.9_151674 2 180.513058 430.648080 WM
AAACAGAGCGACTCCT-1.8_151674 3 414.671481 189.783985 Layer3
AAACAGCTTTCAGAAG-1.8_151674 4 150.180023 344.149427 Layer5
... ... ... ... ...
TTGTTTCACATCCAGG-1.9_151674 3630 251.800189 425.787593 WM
TTGTTTCATTAGTCTA-1.9_151674 3631 214.536461 436.273642 WM
TTGTTTCCATACAACT-1.9_151674 3632 205.850592 355.355548 Layer6
TTGTTTGTATTACACG-1.5_151674 3633 248.109820 506.615679 WM
TTGTTTGTGTAAATTC-1.9_151674 3634 281.773187 151.035109 Layer1
[3635 rows x 4 columns])
[8]:
# Fraction of spots in the first slice whose best-aligned spot in the second slice
# carries the same layer annotation.
#
# The annotations are pulled out once as plain arrays and indexed by position.
# Indexing the string-labelled obs column with an integer (obs[...][i]) relies on
# pandas' deprecated positional fallback, whose warning is hidden by the global
# warnings filter.
source_domains = adata_list[0].obs['layer_guess_reordered'].values
target_domains = adata_list[1].obs['layer_guess_reordered'].values

count = 0
ref_raw_list = []  # best-matching target position for every source spot
pi = pis[0]        # alignment for the first slice pair (rows are indexed by source spot below)
for i in range(pi.shape[0]):
    source_domain = source_domains[i]
    # Position of the largest entry in row i; argsort()[-1] is kept (rather than
    # argmax) so that tie-breaking, and therefore the reported value, is unchanged.
    max_align = pi[i].argsort()[-1]
    target_domain = target_domains[max_align]
    ref_raw_list.append(max_align)
    if source_domain == target_domain:
        count = count + 1
print(count / len(ref_raw_list))
0.8263638881196345
[9]:
from aligndg.utils.visualize import Build3D

# Row 0: running position of each source spot; row 1: its best-matching target position.
index_harm = np.array(ref_raw_list)
source_positions = range(index_harm.shape[0])
matching_harm = np.array([source_positions, index_harm])

multi_align = Build3D(
    adata2_harm_df,
    adata1_harm_df,
    matching_harm,
    meta='domain',
    scale_coordinate=True,
    subsample_size=300,
    exchange_xy=False,
)
multi_align.draw_3D(
    size=[8, 8],
    line_width=0.5,
    point_size=[2, 2],
    hide_axis=True,
    show_error=True,
    line_alpha=1,
)
dataset1: 7 cell types; dataset2: 7 cell types;
Total :7 celltypes; Overlap: 7 cell types
Not overlap :[[]]
Subsampled 300 pairs from 3611
[10]:
def get_ratio_unbalanced(alignment, labels):
    """Ratio of eligible source spots to distinct correctly matched target spots.

    Parameters
    ----------
    alignment : 2-D array, shape (n_source, n_target)
        Alignment scores; row ``i`` is searched for its largest entry to find the
        best target for source spot ``i``.
    labels : sequence of length n_source + n_target
        Annotations of all spots, source spots first, then target spots. May be
        a list, a NumPy array or a pandas Series; it is always accessed by position.

    Returns
    -------
    float
        ``len(eligible) / len(unique_correct_targets)`` where *eligible* are the
        source spots whose label also occurs among the target spots, and
        *unique_correct_targets* are the distinct target positions that were
        chosen by at least one eligible source spot with the same label.

    Raises
    ------
    ZeroDivisionError
        If no eligible source spot is matched to a target with the same label.
    """
    n_source = alignment.shape[0]

    # Work on a plain array so every lookup below is positional. A pandas Series
    # with a string index would otherwise depend on the deprecated integer
    # fallback of Series.__getitem__.
    labels = np.asarray(labels)

    # Only labels that occur in both slices are taken into account.
    common_labels = set(labels[:n_source]) & set(labels[n_source:])
    eligible_rows = [i for i in range(n_source) if labels[i] in common_labels]

    # Distinct target positions that received at least one label-consistent match.
    correct_targets = set()
    for i in eligible_rows:
        best_target = int(alignment[i].argmax())
        # Target spots are stored after the n_source source spots in `labels`.
        if labels[i] == labels[n_source + best_target]:
            correct_targets.add(best_target)

    return len(eligible_rows) / len(correct_targets)
[11]:
# AnnData.concatenate is deprecated in favour of anndata.concat; the row order
# (first slice, then second slice) is the same, which is all the ratio needs.
adata_cat = ad.concat([adata_list[0], adata_list[1]])
# Pass the annotations as a plain array so they are looked up purely by position.
get_ratio_unbalanced(np.array(pis[0]), labels=adata_cat.obs['layer_guess_reordered'].to_numpy())
[11]:
1.2764227642276422
[12]:
# Build the neighbour graph on adata_concat.obsm['latent'] instead of the default
# representation. Assumes training stored the learned embedding under that key — TODO confirm.
sc.pp.neighbors(adata_concat, use_rep='latent')
# Embed the neighbour graph in 2-D with UMAP.
sc.tl.umap(adata_concat)
# Colour by annotated layer and by slice of origin, the two views this cell
# places side by side.
sc.pl.umap(adata_concat, color=['layer_guess_reordered', 'slice_name'], size=20)
[ ]: