Tutorial 1: Basic usage of AlignDG on simulated data

import necessary packages

[ ]:
import warnings
warnings.filterwarnings('ignore')

import anndata as ad
import scanpy as sc
import pandas as pd
import os
import numpy as np
from scipy import sparse
import torch
import aligndg

# set seed as 42
aligndg.set_global_seed(42)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
sc.set_figure_params(scanpy=True, dpi=80, dpi_save=600, frameon=True, vector_friendly=True, color_map=None, format='pdf', facecolor=None, transparent=True, ipython_format='png2x')
[2]:
# get cuda version (oue device is 12, some new device may be 13)
! nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0

Check jax device on cuda or cpu. If you hava cuda device and result is cpu, please consider: pip install –upgrade “jax[cuda12]”. (cuda version can be 12 or 13)

[3]:
import jax

jax.default_backend(), device
[3]:
('gpu', device(type='cuda', index=0))

Load Data

[10]:
from aligndg.utils import graph_construction

adata = sc.read_h5ad('data/Simulation/simulate1.h5ad')
sc.pl.spatial(adata, color=['celltype'], spot_size=0.5)
adata_sub1 = adata[adata.obs['batch'] == 'Batch1']
adata_sub2 = adata[adata.obs['batch'] == 'Batch2']

adj_list = []
section_ids = ['Batch1', 'Batch2']
adata_list = [adata_sub1, adata_sub2]
for i in range(len(adata_list)):
    adata_list[i].X = sparse.csc_matrix(adata_list[i].X)
    adata_list[i].var_names_make_unique(join="++")
    adata_list[i].obs_names = [x + '_' + section_ids[i] for x in adata_list[i].obs_names]
    graph_construction.cal_spatial_network(adata_list[i], model='KNN', k_cutoff=4)
    adj_list.append(adata_list[i].uns['adj'])
_images/Tutorial_Simulation_8_0.png
Calculating spatial neighbor graph ...
The graph contains 2368 edges, 500 cells
4.736 neighbors per cell on average
Calculating spatial neighbor graph ...
The graph contains 2396 edges, 500 cells
4.792 neighbors per cell on average
[11]:
adata1_harm_df = pd.DataFrame({
    'index': range(adata_sub1.shape[0]),
    'x': adata_sub1.obsm['spatial'][:, 0],
    'y': adata_sub1.obsm['spatial'][:, 1],
    'domain': adata_sub1.obs['celltype'].astype('category'),
})

adata2_harm_df = pd.DataFrame({
    'index': range(adata_sub2.shape[0]),
    'x': adata_sub2.obsm['spatial'][:, 0],
    'y': adata_sub2.obsm['spatial'][:, 1],
    'domain': adata_sub2.obs['celltype'].astype('category'),
})

adata1_harm_df, adata2_harm_df
[11]:
(                index         x         y  domain
 Cell1_Batch1        0  1.037015  0.727646  Group3
 Cell2_Batch1        1  7.301914  9.436571  Group2
 Cell3_Batch1        2  5.493741  5.021805  Group2
 Cell4_Batch1        3  6.041019  9.825238  Group2
 Cell5_Batch1        4  6.947731  4.347541  Group5
 ...               ...       ...       ...     ...
 Cell496_Batch1    495  0.048206  9.434689  Group4
 Cell497_Batch1    496  7.969817  7.419835  Group2
 Cell498_Batch1    497  3.112033  2.076640  Group3
 Cell499_Batch1    498  7.234705  9.352265  Group2
 Cell500_Batch1    499  5.126478  8.129682  Group2

 [500 rows x 4 columns],
                  index         x         y  domain
 Cell501_Batch2       0  7.181161  5.700499  Group2
 Cell502_Batch2       1  7.951904  0.702857  Group1
 Cell503_Batch2       2  5.376372  0.877961  Group1
 Cell504_Batch2       3  0.054656  1.326532  Group3
 Cell505_Batch2       4  6.388294  8.878522  Group2
 ...                ...       ...       ...     ...
 Cell996_Batch2     495  0.426510  4.551456  Group3
 Cell997_Batch2     496  6.570264  2.578510  Group1
 Cell998_Batch2     497  3.063378  5.954876  Group4
 Cell999_Batch2     498  6.138880  3.644416  Group5
 Cell1000_Batch2    499  0.808772  0.958538  Group3

 [500 rows x 4 columns])
[12]:
from scipy.sparse import block_diag

adata_concat = ad.concat(adata_list, label='slice_name', keys=section_ids)
adj_concat = block_diag(adj_list)
adata_concat.uns['edge_list'] = np.nonzero(adj_concat)
adata_concat
[12]:
AnnData object with n_obs × n_vars = 1000 × 500
    obs: 'celltype', 'batch', 'slice_name', 'batch_name'
    uns: 'edge_list'
    obsm: 'spatial'

Perform AlignDG

[16]:
from aligndg.graph import AlignDG

adata_concat, pis = AlignDG.train_AlignDG(adata=adata_concat, beta=0.5, iter_comb=[(0, 1)])
100%|██████████| 500/500 [00:03<00:00, 144.47it/s]
  0%|          | 0/500 [00:00<?, ?it/s]
Update spot triplet at epoch 500
Processing datasets ('Batch1', 'Batch2')
Solver may not converge.
# Epoch 599, loss: 0.632, gene_recon: 0.586, triplet loss: 0.04584432393312454:  20%|█▉        | 98/500 [00:08<00:07, 55.81it/s]
Update spot triplet at epoch 600
Processing datasets ('Batch1', 'Batch2')
# Epoch 608, loss: 0.795, gene_recon: 0.640, triplet loss: 0.15412934124469757:  21%|██        | 106/500 [00:13<01:06,  5.90it/s]
Solver may not converge.
# Epoch 699, loss: 0.734, gene_recon: 0.712, triplet loss: 0.02233678475022316:  40%|████      | 200/500 [00:14<00:04, 70.11it/s]
Update spot triplet at epoch 700
Processing datasets ('Batch1', 'Batch2')
# Epoch 709, loss: 0.708, gene_recon: 0.600, triplet loss: 0.10783737897872925:  42%|████▏     | 209/500 [00:18<00:43,  6.70it/s]
Solver may not converge.
# Epoch 799, loss: 0.518, gene_recon: 0.482, triplet loss: 0.036520037800073624:  60%|█████▉    | 299/500 [00:19<00:03, 60.66it/s]
Update spot triplet at epoch 800
Processing datasets ('Batch1', 'Batch2')
# Epoch 807, loss: 0.606, gene_recon: 0.498, triplet loss: 0.10765905678272247:  61%|██████    | 306/500 [00:23<00:35,  5.51it/s]
Solver may not converge.
# Epoch 899, loss: 0.480, gene_recon: 0.455, triplet loss: 0.024372799322009087:  79%|███████▉  | 395/500 [00:25<00:01, 59.87it/s]
Update spot triplet at epoch 900
Processing datasets ('Batch1', 'Batch2')
# Epoch 908, loss: 0.644, gene_recon: 0.522, triplet loss: 0.12167825549840927:  80%|████████  | 402/500 [00:29<00:20,  4.87it/s]
Solver may not converge.
# Epoch 999, loss: 0.446, gene_recon: 0.426, triplet loss: 0.020836396142840385: 100%|██████████| 500/500 [00:31<00:00, 15.99it/s]
[17]:
# calculate mapping acc

mapping_pi = pis[0]
count = 0
ref_raw_list = []
index = []
for i in range(mapping_pi.shape[0]):
    if np.sum(mapping_pi[i] > 0):
        index.append(i)
        source_domain = adata_sub1.obs['celltype'][i]
        max_align = mapping_pi[i].argsort()[-1]
        target_domain = adata_sub2.obs['celltype'].values[max_align]
        ref_raw_list.append(max_align)
        if source_domain == target_domain:
            count = count + 1

print(count / len(ref_raw_list), count, len(ref_raw_list))
0.962 481 500

Visualization mapping result

[18]:
from aligndg.utils.visualize import Build3D

index_harm = np.array(ref_raw_list)
matching_harm = np.array([index, index_harm])

multi_align = Build3D(adata2_harm_df, adata1_harm_df, matching_harm, meta='domain', scale_coordinate=True,subsample_size=150, exchange_xy=False)
multi_align.draw_3D(size=[8, 10], line_width=0.5, point_size=[2, 2], line_alpha=1, hide_axis=True, show_error=True)
dataset1: 6 cell types; dataset2: 6 cell types;
                    Total :6 celltypes; Overlap: 6 cell types
                    Not overlap :[[]]
Subsampled 150 pairs from 500
_images/Tutorial_Simulation_15_1.png
[ ]: