3. Train Masking Model and Count Decoder
This step requires a GPU. Run on a GPU node (e.g., via your scheduler) or a machine with CUDA.
3.1. Train Masking Model (GPU required)
As noted above, this step requires a CUDA-capable GPU; run it on a GPU node (e.g., via your scheduler).
# GPU required
# NOTE: boolean-like options (CONTEXT_MODE, USE_WEIGHTED_SAMPLER) are kept as the
# literal strings "True"/"False" because they are passed verbatim as CLI tokens.
OUTPUT_DIR = "T_perturb/res/masking" # relative to repo root, here set the output directory
SRC_DATASET = "T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset" # path/to/src_dataset.dataset from tokenization step
TGT_DATASET_FOLDER = "T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt" # path/to/tgt_datasets_folder from tokenization step
SRC_ADATA = "T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad" # path/to/src_adata.h5ad from tokenization step
TGT_ADATA_FOLDER = "T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt" # path/to/tgt_adata_folder from tokenization step
MAPPING_DICT_PATH = "T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl" # path/to/mapping_dict.pkl from tokenization step
ENCODER_PATH = "Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt" # path to pretrained encoder checkpoint provided with Perturbgen
BATCH_SIZE = 64 # Model training batch size
EPOCHS = 20 # number of training epochs
CELLGEN_LR = 1e-4 # learning rate
CELLGEN_WD = 1e-4 # weight decay
N_WORKERS = 4 # number of data loading workers
NUM_LAYERS = 6 # number of transformer layers
D_FF = 32 # feedforward dimension
PRED_TPS = ["1", "2", "3"] # time points to train on and predict (for LPS: "1"=90m, "2"=6h, "3"=10h)
VAR_LIST = ["cell_type_harmonized", "time_after_LPS"] # list of obs retained in adata.vars after preprocessing
SEED = 0 # random seed for reproducibility
CONTEXT_MODE = "True" # whether to use context tokens
POS_ENCODING_MODE = "time_pos_sin" # positional encoding mode
MASK_SCHEDULER = "pow" # masking scheduler type
NUM_NODE = 1 # number of nodes if using distributed training
D_MODEL = 768 # model dimension
# Leave empty to start training from scratch; set to a .ckpt path to resume.
# (A non-empty placeholder here would be passed to the trainer by the
# `if CKPT_MASKING_PATH:` guard below and crash on the nonexistent file.)
CKPT_MASKING_PATH = "" # optional path to checkpoint to resume training from
USE_WEIGHTED_SAMPLER = "False" # whether to use weighted sampler during training
# Assemble the masking-model training invocation as an argv list
# (list form lets subprocess run it without shell quoting issues).
entry = ["python", "-m", "perturbgen", "train-mask"]
mode_args = [
    "--train_mode", "masking",
    "--split", "False",
    "--encoder", "scmaskgit",
    "--splitting_mode", "stratified",
    "--split_obs", "cell_type_harmonized",
]
# Input/output locations produced by the tokenization step.
path_args = [
    "--output_dir", OUTPUT_DIR,
    "--src_dataset", SRC_DATASET,
    "--tgt_dataset_folder", TGT_DATASET_FOLDER,
    "--src_adata", SRC_ADATA,
    "--tgt_adata_folder", TGT_ADATA_FOLDER,
    "--mapping_dict_path", MAPPING_DICT_PATH,
]
# Numeric hyperparameters must be stringified for the CLI.
optim_args = [
    "--batch_size", str(BATCH_SIZE),
    "--epochs", str(EPOCHS),
    "--cellgen_lr", str(CELLGEN_LR),
    "--cellgen_wd", str(CELLGEN_WD),
    "--n_workers", str(N_WORKERS),
    "--num_layers", str(NUM_LAYERS),
    "--d_ff", str(D_FF),
]
# Multi-valued flags: each element becomes its own argv token.
multi_args = ["--pred_tps", *PRED_TPS, "--var_list", *VAR_LIST]
model_args = [
    "--encoder_path", ENCODER_PATH,
    "--seed", str(SEED),
    "--context_mode", CONTEXT_MODE,
    "--pos_encoding_mode", POS_ENCODING_MODE,
    "--mask_scheduler", MASK_SCHEDULER,
    "--num_node", str(NUM_NODE),
    "--d_model", str(D_MODEL),
    "--use_weighted_sampler", USE_WEIGHTED_SAMPLER,
]
cmd = entry + mode_args + path_args + optim_args + multi_args + model_args
# Only request a resume when a checkpoint path was actually provided.
if CKPT_MASKING_PATH:
    cmd = cmd + ["--ckpt_masking_path", CKPT_MASKING_PATH]
# Echo the full command for logging/reproducibility.
print(" ".join(cmd))
python -m perturbgen train-mask --train_mode masking --split False --encoder scmaskgit --splitting_mode stratified --split_obs cell_type_harmonized --output_dir T_perturb/res/masking --src_dataset T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset --tgt_dataset_folder T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt --src_adata T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad --tgt_adata_folder T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt --mapping_dict_path T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl --batch_size 64 --max_len 797 --epochs 20 --tgt_vocab_size 2002 --cellgen_lr 0.0001 --cellgen_wd 0.0001 --n_workers 4 --num_layers 6 --d_ff 32 --pred_tps 1 2 3 --var_list cell_type_harmonized time_after_LPS --cond_list cell_type_harmonized --encoder_path Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt --seed 0 --context_mode True --pos_encoding_mode time_pos_sin --mask_scheduler pow --num_node 1 --d_model 768 --use_weighted_sampler False --ckpt_masking_path path/to/optional_resume.ckpt
import subprocess
# Launch training as a child process; check=True raises CalledProcessError
# if the trainer exits with a non-zero status.
subprocess.run(cmd, check=True)
3.2. Train count decoder (GPU required)
This step trains the count decoder using the masking checkpoint.
# GPU required
# NOTE: boolean-like options (USE_POSITIONAL_ENCODING, LAYER_NORM, CONTEXT_MODE)
# are kept as the literal strings "True"/"False" because they are passed
# verbatim as CLI tokens to the perturbgen entry point.
COUNT_OUTPUT_DIR = "T_perturb/res/count" # relative to repo root, here set the output directory
CKPT_MASKING_PATH = "path/to/masking_checkpoint.ckpt" # should be selected based on best masking model from previous step
SRC_DATASET = "T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset" # path/to/src_dataset.dataset from tokenization step
TGT_DATASET_FOLDER = "T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt" # path/to/tgt_datasets_folder from tokenization step
SRC_ADATA = "T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad" # path/to/src_adata.h5ad from tokenization step
TGT_ADATA_FOLDER = "T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt" # path/to/tgt_adata_folder from tokenization step
MAPPING_DICT_PATH = "T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl" # path/to/mapping_dict.pkl from tokenization step
ENCODER_PATH = "Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt" # path to pretrained encoder checkpoint provided with Perturbgen
BATCH_SIZE = 16 # Model training batch size
EPOCHS = 16 # number of training epochs
COUNT_LR = 0.001 # learning rate for count model
CELLGEN_LR = 0.0001 # learning rate for masking part, not useful here
CELLGEN_WD = 0.0001 # weight decay for masking part, not useful here
COUNT_WD = 0.001 # weight decay for count model
MLM_PROB = 0.30 # masking probability
N_WORKERS = 32 # number of data loading workers
NUM_LAYERS = 6 # number of transformer layers
D_FF = 32 # feedforward dimension
LOSS_MODE = "zinb" # loss mode for count model, could be mse, nb, or zinb
PRED_TPS = ["1", "2", "3"] # time points to train on and predict (for LPS: "1"=90m, "2"=6h, "3"=10h)
VAR_LIST = ["cell_type_harmonized", "time_after_LPS"] # list of obs retained in adata.vars after preprocessing
COUNT_DROPOUT = 0.1 # dropout for count model
USE_POSITIONAL_ENCODING = "False" # whether to use positional encoding in count model
LAYER_NORM = "True" # whether to use layer normalization in count model
CONTEXT_MODE = "True" # whether to use context tokens
POS_ENCODING_MODE = "time_pos_sin" # positional encoding mode
MASK_SCHEDULER = "pow" # masking scheduler type
NUM_NODE = 1 # number of nodes if using distributed training
D_MODEL = 768 # model dimension
# Assemble the count-decoder training invocation as an argv list
# (list form lets subprocess run it without shell quoting issues).
entry = ["python", "-m", "perturbgen", "train-decoder"]
mode_args = [
    "--train_mode", "count",
    "--split", "False",
    "--splitting_mode", "stratified",
]
# Output dir, masking checkpoint, and tokenization-step inputs.
path_args = [
    "--output_dir", COUNT_OUTPUT_DIR,
    "--ckpt_masking_path", CKPT_MASKING_PATH,
    "--src_dataset", SRC_DATASET,
    "--tgt_dataset_folder", TGT_DATASET_FOLDER,
    "--src_adata", SRC_ADATA,
    "--tgt_adata_folder", TGT_ADATA_FOLDER,
    "--mapping_dict_path", MAPPING_DICT_PATH,
]
# Numeric hyperparameters must be stringified for the CLI.
optim_args = [
    "--batch_size", str(BATCH_SIZE),
    "--epochs", str(EPOCHS),
    "--count_lr", str(COUNT_LR),
    "--cellgen_lr", str(CELLGEN_LR),
    "--cellgen_wd", str(CELLGEN_WD),
    "--count_wd", str(COUNT_WD),
    "--mlm_prob", str(MLM_PROB),
    "--n_workers", str(N_WORKERS),
]
model_args = [
    "--num_layers", str(NUM_LAYERS),
    "--d_ff", str(D_FF),
    "--loss_mode", LOSS_MODE,
    # Multi-valued flags: each element becomes its own argv token.
    "--pred_tps", *PRED_TPS,
    "--var_list", *VAR_LIST,
    "--encoder", "scmaskgit",
    "--count_dropout", str(COUNT_DROPOUT),
    "--use_positional_encoding", USE_POSITIONAL_ENCODING,
    "--layer_norm", LAYER_NORM,
    "--context_mode", CONTEXT_MODE,
    "--encoder_path", ENCODER_PATH,
    "--pos_encoding_mode", POS_ENCODING_MODE,
    "--mask_scheduler", MASK_SCHEDULER,
    "--num_node", str(NUM_NODE),
    "--d_model", str(D_MODEL),
    "--ckpt_every_n_epochs", "5",
]
cmd = entry + mode_args + path_args + optim_args + model_args
# Echo the full command for logging/reproducibility.
print(" ".join(cmd))
python -m perturbgen train-decoder --train_mode count --split False --splitting_mode stratified --output_dir T_perturb/res/count --ckpt_masking_path path/to/masking_checkpoint.ckpt --src_dataset T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset --tgt_dataset_folder T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt --src_adata T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad --tgt_adata_folder T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt --mapping_dict_path T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl --batch_size 16 --max_len 797 --epochs 16 --tgt_vocab_size 2002 --count_lr 0.001 --cellgen_lr 0.0001 --cellgen_wd 0.0001 --count_wd 0.001 --mlm_prob 0.3 --n_workers 32 --num_layers 6 --d_ff 32 --loss_mode zinb --pred_tps 1 2 3 --var_list cell_type_harmonized time_after_LPS --cond_list cell_type_harmonized --encoder scmaskgit --add_cell_time False --d_condc 64 --d_condt 768 --count_dropout 0.1 --use_positional_encoding False --layer_norm True --context_mode True --encoder_path Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt --pos_encoding_mode time_pos_sin --mask_scheduler pow --num_node 1 --d_model 768
import subprocess
# Launch training as a child process; check=True raises CalledProcessError
# if the trainer exits with a non-zero status.
subprocess.run(cmd, check=True)