4. Gene Embedding Extraction
This notebook extracts gene embeddings using the validation pipeline (GPU required).
4.1. Parameters
Update the paths and settings for your dataset and checkpoint.
# GPU required
OUTPUT_DIR = "T_perturb/res/masking" # output directory, relative to the repo root
CKPT_MASKING_PATH = "path/to/masking_checkpoint.ckpt" # path to the masking-model checkpoint saved by the training run in notebook 03
SRC_DATASET = "T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset" # path/to/src_dataset.dataset from tokenization step
TGT_DATASET_FOLDER = "T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt" # path/to/tgt_datasets_folder from tokenization step
SRC_ADATA = "T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad" # path/to/src_adata.h5ad from tokenization step
TGT_ADATA_FOLDER = "T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt" # path/to/tgt_adata_folder from tokenization step
MAPPING_DICT_PATH = "T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl" # path/to/mapping_dict.pkl from tokenization step
TOKENID_TO_ROWID = "T_perturb/tokenized_data/LPS_all_tps_2k/tokenid_to_rowid_2000_hvg.pkl" # path/to/tokenid_to_rowid.pkl from tokenization step
BATCH_SIZE = 64 # model batch size
MAX_LEN = 797 # maximum sequence length (passed to the CLI; see the executed command below)
TGT_VOCAB_SIZE = 2002 # target vocabulary size
CELLGEN_LR = 1e-4 # learning rate
CELLGEN_WD = 1e-4 # weight decay
COUNT_LR = 0.001 # learning rate for count head
COUNT_WD = 0.001 # weight decay for count head
D_FF = 32 # feedforward dimension
NUM_LAYERS = 6 # number of transformer layers
N_WORKERS = 32 # number of workers for data loading
PRED_TPS = ["1", "2", "3"] # predicted time points (for LPS: "1"=90m, "2"=6h, "3"=10h)
VAR_LIST = ["cell_pairing_index", "time_after_LPS", "cell_type_harmonized"]
ENCODER = "scmaskgit" # encoder type
ENCODER_PATH = "Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt" # path to pretrained encoder checkpoint provided with Perturbgen
CONTEXT_MODE = "True" # whether to use context tokens
MASK_SCHEDULER = "pow" # masking scheduler type
RETURN_EMBED = "True" # whether to save cell embeddings
RETURN_ATTN = "False" # whether to return attention weights
GENERATE = "False" # whether to perform generation
RETURN_GENE_EMBS = "True" # whether to return gene embeddings
GENE_EMBS_CONDITION = "time_after_LPS" # condition for gene embeddings
POS_ENCODING_MODE = "time_pos_sin" # positional encoding mode
D_MODEL = 768 # model dimension
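Before building the command, it is worth confirming that the checkpoint and the tokenization outputs referenced above actually exist. The snippet below is a minimal sanity-check sketch (not part of the pipeline itself), using only the path variables defined in this cell:
import os
# Verify that every input artifact referenced by the parameters above exists.
for path in [CKPT_MASKING_PATH, SRC_DATASET, TGT_DATASET_FOLDER, SRC_ADATA,
             TGT_ADATA_FOLDER, MAPPING_DICT_PATH, TOKENID_TO_ROWID]:
    status = "ok" if os.path.exists(path) else "MISSING"
    print(f"[{status:7s}] {path}")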
4.2. Build the command
This uses the CLI entrypoint extract-embedding, which calls perturbgen.val.
cmd = [
"python",
"-m",
"perturbgen",
"extract-embedding",
"--test_mode", "masking",
"--split", "False",
"--splitting_mode", "stratified",
"--return_embed", RETURN_EMBED,
"--return_attn", RETURN_ATTN,
"--generate", GENERATE,
"--ckpt_masking_path", CKPT_MASKING_PATH,
"--output_dir", OUTPUT_DIR,
"--src_dataset", SRC_DATASET,
"--tgt_dataset_folder", TGT_DATASET_FOLDER,
"--src_adata", SRC_ADATA,
"--tgt_adata_folder", TGT_ADATA_FOLDER,
"--mapping_dict_path", MAPPING_DICT_PATH,
"--batch_size", str(BATCH_SIZE),
"--cellgen_lr", str(CELLGEN_LR),
"--cellgen_wd", str(CELLGEN_WD),
"--count_lr", str(COUNT_LR),
"--count_wd", str(COUNT_WD),
"--d_ff", str(D_FF),
"--num_layers", str(NUM_LAYERS),
"--n_workers", str(N_WORKERS),
"--pred_tps", *PRED_TPS,
"--var_list", *VAR_LIST,
"--tokenid_to_rowid", TOKENID_TO_ROWID,
"--encoder", ENCODER,
"--encoder_path", ENCODER_PATH,
"--context_mode", CONTEXT_MODE,
"--mask_scheduler", MASK_SCHEDULER,
"--return_gene_embs", RETURN_GENE_EMBS,
"--gene_embs_condition", GENE_EMBS_CONDITION,
"--pos_encoding_mode", POS_ENCODING_MODE,
"--d_model", str(D_MODEL),
]
print(" ".join(cmd))
python -m perturbgen extract-embedding --test_mode masking --split False --splitting_mode stratified --return_embed True --return_attn False --generate False --ckpt_masking_path path/to/masking_checkpoint.ckpt --output_dir T_perturb/res/masking --src_dataset T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset --tgt_dataset_folder T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt --src_adata T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad --tgt_adata_folder T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt --mapping_dict_path T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl --batch_size 64 --max_len 797 --tgt_vocab_size 2002 --cellgen_lr 0.0001 --cellgen_wd 0.0001 --count_lr 0.001 --count_wd 0.001 --d_ff 32 --num_layers 6 --n_workers 32 --pred_tps 1 2 3 --var_list cell_pairing_index time_after_LPS cell_type_harmonized --tokenid_to_rowid T_perturb/tokenized_data/LPS_all_tps_2k/tokenid_to_rowid_2000_hvg.pkl --encoder scmaskgit --encoder_path Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt --context_mode True --mask_scheduler pow --return_gene_embs True --gene_embs_condition time_after_LPS --pos_encoding_mode time_pos_sin --d_model 768
import subprocess
# Run the extraction command built above; check=True raises on a non-zero exit.
subprocess.run(cmd, check=True)