4. Gene Embedding Extraction

This notebook extracts gene embeddings using the validation pipeline (GPU required).

4.1. Parameters (GPU required)

Update the paths and settings for your dataset and checkpoint.

# GPU required
# --- Output and checkpoint paths ------------------------------------------
OUTPUT_DIR = "T_perturb/res/masking"  # output directory, relative to the repo root
CKPT_MASKING_PATH = "path/to/masking_checkpoint.ckpt"  # masking-model checkpoint produced by the masking training in notebook 03

# --- Inputs produced by the tokenization step -----------------------------
SRC_DATASET = "T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset"  # source (t=0) tokenized dataset
TGT_DATASET_FOLDER = "T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt"  # folder of target (later time point) tokenized datasets
SRC_ADATA = "T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad"  # source AnnData used for cell pairing
TGT_ADATA_FOLDER = "T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt"  # folder of target AnnData files used for cell pairing
MAPPING_DICT_PATH = "T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl"  # token id -> gene name mapping (pickle)
TOKENID_TO_ROWID = "T_perturb/tokenized_data/LPS_all_tps_2k/tokenid_to_rowid_2000_hvg.pkl"  # token id -> expression-matrix row id mapping (pickle)

# --- Training / model hyperparameters -------------------------------------
BATCH_SIZE = 64 # model batch size
CELLGEN_LR = 1e-4 # learning rate for the cell-generation model
CELLGEN_WD = 1e-4 # weight decay for the cell-generation model
COUNT_LR = 0.001 # learning rate for count head
COUNT_WD = 0.001 # weight decay for count head
D_FF = 32 # feedforward dimension
NUM_LAYERS = 6 # number of transformer layers
N_WORKERS = 32 # number of workers for data loading
PRED_TPS = ["1", "2", "3"] # predicted time points, (for LPS: "1"=90m, "2"=6h, "3"=10h)

# obs columns carried through for pairing/stratification
VAR_LIST = ["cell_pairing_index", "time_after_LPS", "cell_type_harmonized"]

# --- Encoder configuration -------------------------------------------------
# NOTE: the flag-like settings below ("True"/"False") are deliberately strings,
# not booleans — they are passed verbatim as CLI argv tokens in the next cell.
ENCODER = "scmaskgit" # encoder type
ENCODER_PATH = "Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt" # path to pretrained encoder checkpoint provided with Perturbgen
CONTEXT_MODE = "True" # whether to use context tokens
MASK_SCHEDULER = "pow" # masking scheduler type
RETURN_EMBED = "True" # whether to save cell embeddings
RETURN_ATTN = "False" # whether to return attention weights
GENERATE = "False" # whether to perform generation
RETURN_GENE_EMBS = "True" # whether to return gene embeddings (the point of this notebook)
GENE_EMBS_CONDITION = "time_after_LPS" # obs column used to condition/group gene embeddings
POS_ENCODING_MODE = "time_pos_sin" # positional encoding mode
D_MODEL = 768 # model dimension

4.2. Build the command

This uses the CLI entrypoint `extract-embedding`, which calls `perturbgen.val`.

# Assemble the argv for the `extract-embedding` CLI entrypoint.
# Fix: use sys.executable rather than a bare "python" so the subprocess runs
# under the exact interpreter backing this notebook kernel — inside a
# virtualenv/conda env, "python" on PATH may be a different installation
# without perturbgen installed.
import sys

cmd = [
    sys.executable,
    "-m",
    "perturbgen",
    "extract-embedding",
    "--test_mode", "masking",
    "--split", "False",
    "--splitting_mode", "stratified",
    "--return_embed", RETURN_EMBED,
    "--return_attn", RETURN_ATTN,
    "--generate", GENERATE,
    "--ckpt_masking_path", CKPT_MASKING_PATH,
    "--output_dir", OUTPUT_DIR,
    "--src_dataset", SRC_DATASET,
    "--tgt_dataset_folder", TGT_DATASET_FOLDER,
    "--src_adata", SRC_ADATA,
    "--tgt_adata_folder", TGT_ADATA_FOLDER,
    "--mapping_dict_path", MAPPING_DICT_PATH,
    # Numeric settings must be stringified: subprocess argv entries are str.
    "--batch_size", str(BATCH_SIZE),
    "--cellgen_lr", str(CELLGEN_LR),
    "--cellgen_wd", str(CELLGEN_WD),
    "--count_lr", str(COUNT_LR),
    "--count_wd", str(COUNT_WD),
    "--d_ff", str(D_FF),
    "--num_layers", str(NUM_LAYERS),
    "--n_workers", str(N_WORKERS),
    # Multi-value flags: splat each list so every element is its own token.
    "--pred_tps", *PRED_TPS,
    "--var_list", *VAR_LIST,
    "--tokenid_to_rowid", TOKENID_TO_ROWID,
    "--encoder", ENCODER,
    "--encoder_path", ENCODER_PATH,
    "--context_mode", CONTEXT_MODE,
    "--mask_scheduler", MASK_SCHEDULER,
    "--return_gene_embs", RETURN_GENE_EMBS,
    "--gene_embs_condition", GENE_EMBS_CONDITION,
    "--pos_encoding_mode", POS_ENCODING_MODE,
    "--d_model", str(D_MODEL),
]

# Echo the full command so it can be copy-pasted into a shell if preferred.
print(" ".join(cmd))
python -m perturbgen extract-embedding --test_mode masking --split False --splitting_mode stratified --return_embed True --return_attn False --generate False --ckpt_masking_path path/to/masking_checkpoint.ckpt --output_dir T_perturb/res/masking --src_dataset T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset --tgt_dataset_folder T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt --src_adata T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad --tgt_adata_folder T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt --mapping_dict_path T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl --batch_size 64 --cellgen_lr 0.0001 --cellgen_wd 0.0001 --count_lr 0.001 --count_wd 0.001 --d_ff 32 --num_layers 6 --n_workers 32 --pred_tps 1 2 3 --var_list cell_pairing_index time_after_LPS cell_type_harmonized --tokenid_to_rowid T_perturb/tokenized_data/LPS_all_tps_2k/tokenid_to_rowid_2000_hvg.pkl --encoder scmaskgit --encoder_path Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt --context_mode True --mask_scheduler pow --return_gene_embs True --gene_embs_condition time_after_LPS --pos_encoding_mode time_pos_sin --d_model 768
# NOTE(review): subprocess is imported but never used in this cell —
# presumably intended for subprocess.run(cmd, check=True) here or in a
# later cell; confirm before relying on this notebook to launch the job.
import subprocess

# Re-echo the assembled command (duplicate of the echo in the previous cell).
# Run it in a GPU-enabled environment.
print(" ".join(cmd))