3. Train Masking Model and Count Decoder

Both steps in this section require a GPU. Run them on a GPU node (e.g., via your scheduler) or on a machine with CUDA available.
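
A quick sanity check before launching either step, to confirm the environment actually sees a GPU (this assumes PyTorch is available, which the training pipeline already needs):

import torch

# Fail fast if no CUDA device is visible; avoids discovering this mid-run.
assert torch.cuda.is_available(), "No CUDA device visible - run this on a GPU node"
print(torch.cuda.get_device_name(0))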

3.1. Train Masking Model (GPU required)

# GPU required
OUTPUT_DIR = "T_perturb/res/masking"  # relative to repo root, here set the output directory
SRC_DATASET = "T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset" # path/to/src_dataset.dataset from tokenization step
TGT_DATASET_FOLDER = "T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt" # path/to/tgt_datasets_folder from tokenization step
SRC_ADATA = "T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad" # path/to/src_adata.h5ad" from tokenization step
TGT_ADATA_FOLDER = "T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt" # path/to/tgt_adata_folder from tokenization step
MAPPING_DICT_PATH = "T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl" # path/to/mapping_dict.pkl from tokenization step

ENCODER_PATH = "Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt" # path to pretrained encoder checkpoint provided with Perturbgen

BATCH_SIZE = 64 # Model training batch size
EPOCHS = 20 # number of training epochs
CELLGEN_LR = 1e-4 # learning rate
CELLGEN_WD = 1e-4 # weight decay
N_WORKERS = 4 # number of data loading workers
NUM_LAYERS = 6 # number of transformer layers
D_FF = 32 # feedforward dimension
PRED_TPS = ["1", "2", "3"] # time points to train on and predict (for LPS: "1"=90m, "2"=6h, "3"=10h)

VAR_LIST = ["cell_type_harmonized", "time_after_LPS"] # list of obs retained in adata.vars after preprocessing

SEED = 0 # random seed for reproducibility
CONTEXT_MODE = "True" # whether to use context tokens
POS_ENCODING_MODE = "time_pos_sin" # positional encoding mode
MASK_SCHEDULER = "pow" # masking scheduler type
NUM_NODE = 1 # number of nodes if using distributed training
D_MODEL = 768 # model dimension

CKPT_MASKING_PATH = "path/to/optional_resume.ckpt"  # optional path to checkpoint to resume training from
USE_WEIGHTED_SAMPLER = "False" # whether to use weighted sampler during training
cmd = [
    "python",
    "-m",
    "perturbgen",
    "train-mask",
    "--train_mode", "masking",
    "--split", "False",
    "--encoder", "scmaskgit",
    "--splitting_mode", "stratified",
    "--split_obs", "cell_type_harmonized",
    "--output_dir", OUTPUT_DIR,
    "--src_dataset", SRC_DATASET,
    "--tgt_dataset_folder", TGT_DATASET_FOLDER,
    "--src_adata", SRC_ADATA,
    "--tgt_adata_folder", TGT_ADATA_FOLDER,
    "--mapping_dict_path", MAPPING_DICT_PATH,
    "--batch_size", str(BATCH_SIZE),
    "--epochs", str(EPOCHS),
    "--cellgen_lr", str(CELLGEN_LR),
    "--cellgen_wd", str(CELLGEN_WD),
    "--n_workers", str(N_WORKERS),
    "--num_layers", str(NUM_LAYERS),
    "--d_ff", str(D_FF),
    "--pred_tps", *PRED_TPS,
    "--var_list", *VAR_LIST,
    "--encoder_path", ENCODER_PATH,
    "--seed", str(SEED),
    "--context_mode", CONTEXT_MODE,
    "--pos_encoding_mode", POS_ENCODING_MODE,
    "--mask_scheduler", MASK_SCHEDULER,
    "--num_node", str(NUM_NODE),
    "--d_model", str(D_MODEL),
    "--use_weighted_sampler", USE_WEIGHTED_SAMPLER,
]

if CKPT_MASKING_PATH:
    cmd += ["--ckpt_masking_path", CKPT_MASKING_PATH]

print(" ".join(cmd))
python -m perturbgen train-mask --train_mode masking --split False --encoder scmaskgit --splitting_mode stratified --split_obs cell_type_harmonized --output_dir T_perturb/res/masking --src_dataset T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset --tgt_dataset_folder T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt --src_adata T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad --tgt_adata_folder T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt --mapping_dict_path T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl --batch_size 64 --max_len 797 --epochs 20 --tgt_vocab_size 2002 --cellgen_lr 0.0001 --cellgen_wd 0.0001 --n_workers 4 --num_layers 6 --d_ff 32 --pred_tps 1 2 3 --var_list cell_type_harmonized time_after_LPS --cond_list cell_type_harmonized --encoder_path Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt --seed 0 --context_mode True --pos_encoding_mode time_pos_sin --mask_scheduler pow --num_node 1 --d_model 768 --use_weighted_sampler False --ckpt_masking_path path/to/optional_resume.ckpt
import subprocess

# Launch masking-model training; check=True raises CalledProcessError if the run fails.
subprocess.run(cmd, check=True)
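
After training completes, choose the masking checkpoint to carry into step 3.2 (it becomes CKPT_MASKING_PATH there). A minimal sketch for listing candidates, assuming checkpoints are written as .ckpt files under OUTPUT_DIR; the exact file naming and the criterion for "best" depend on your run and validation metrics:

import glob, os

# List masking checkpoints, newest first, to pick one for CKPT_MASKING_PATH in step 3.2.
ckpts = sorted(glob.glob(os.path.join(OUTPUT_DIR, "**", "*.ckpt"), recursive=True),
               key=os.path.getmtime, reverse=True)
print("\n".join(ckpts))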

3.2. Train count decoder (GPU required)

This step trains the count decoder on top of the masking checkpoint produced in step 3.1.

# GPU required
COUNT_OUTPUT_DIR = "T_perturb/res/count"  # output directory for count-decoder results (relative to the repo root)
CKPT_MASKING_PATH = "path/to/masking_checkpoint.ckpt" # should be selected based on best masking model from previous step

SRC_DATASET = "T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset" # path/to/src_dataset.dataset from tokenization step
TGT_DATASET_FOLDER = "T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt" # path/to/tgt_datasets_folder from tokenization step
SRC_ADATA = "T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad" # path/to/src_adata.h5ad from tokenization step
TGT_ADATA_FOLDER = "T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt"  # path/to/tgt_adata_folder from tokenization step
MAPPING_DICT_PATH = "T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl" # path/to/mapping_dict.pkl from tokenization step

ENCODER_PATH = "Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt" # path to pretrained encoder checkpoint provided with Perturbgen

BATCH_SIZE = 16 # Model training batch size
EPOCHS = 16 # number of training epochs
COUNT_LR = 0.001 # learning rate for count model
CELLGEN_LR = 0.0001 # learning rate for the masking model; not used in this step
CELLGEN_WD = 0.0001 # weight decay for the masking model; not used in this step
COUNT_WD = 0.001 # weight decay for count model
MLM_PROB = 0.30 # masking probability
N_WORKERS = 32 # number of data loading workers
NUM_LAYERS = 6 # number of transformer layers
D_FF = 32 # feedforward dimension
LOSS_MODE = "zinb" # loss for the count model: one of "mse", "nb", or "zinb"
PRED_TPS = ["1", "2", "3"] # time points to train on and predict (for LPS: "1"=90m, "2"=6h, "3"=10h)

VAR_LIST = ["cell_type_harmonized", "time_after_LPS"] # list of obs retained in adata.vars after preprocessing

COUNT_DROPOUT = 0.1 # dropout for count model
USE_POSITIONAL_ENCODING = "False" # whether to use positional encoding in count model
LAYER_NORM = "True" # whether to use layer normalization in count model
CONTEXT_MODE = "True" # whether to use context tokens
POS_ENCODING_MODE = "time_pos_sin" # positional encoding mode
MASK_SCHEDULER = "pow" # masking scheduler type
NUM_NODE = 1 # number of nodes if using distributed training
D_MODEL = 768 # model dimension
cmd = [
    "python",
    "-m",
    "perturbgen",
    "train-decoder",
    "--train_mode", "count",
    "--split", "False",
    "--splitting_mode", "stratified",
    "--output_dir", COUNT_OUTPUT_DIR,
    "--ckpt_masking_path", CKPT_MASKING_PATH,
    "--src_dataset", SRC_DATASET,
    "--tgt_dataset_folder", TGT_DATASET_FOLDER,
    "--src_adata", SRC_ADATA,
    "--tgt_adata_folder", TGT_ADATA_FOLDER,
    "--mapping_dict_path", MAPPING_DICT_PATH,
    "--batch_size", str(BATCH_SIZE),
    "--epochs", str(EPOCHS),
    "--count_lr", str(COUNT_LR),
    "--cellgen_lr", str(CELLGEN_LR),
    "--cellgen_wd", str(CELLGEN_WD),
    "--count_wd", str(COUNT_WD),
    "--mlm_prob", str(MLM_PROB),
    "--n_workers", str(N_WORKERS),
    "--num_layers", str(NUM_LAYERS),
    "--d_ff", str(D_FF),
    "--loss_mode", LOSS_MODE,
    "--pred_tps", *PRED_TPS,
    "--var_list", *VAR_LIST,
    "--encoder", "scmaskgit",
    "--count_dropout", str(COUNT_DROPOUT),
    "--use_positional_encoding", USE_POSITIONAL_ENCODING,
    "--layer_norm", LAYER_NORM,
    "--context_mode", CONTEXT_MODE,
    "--encoder_path", ENCODER_PATH,
    "--pos_encoding_mode", POS_ENCODING_MODE,
    "--mask_scheduler", MASK_SCHEDULER,
    "--num_node", str(NUM_NODE),
    "--d_model", str(D_MODEL),
    "--ckpt_every_n_epochs", "5",
]
print(" ".join(cmd))
python -m perturbgen train-decoder --train_mode count --split False --splitting_mode stratified --output_dir T_perturb/res/count --ckpt_masking_path path/to/masking_checkpoint.ckpt --src_dataset T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset --tgt_dataset_folder T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt --src_adata T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad --tgt_adata_folder T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt --mapping_dict_path T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl --batch_size 16 --max_len 797 --epochs 16 --tgt_vocab_size 2002 --count_lr 0.001 --cellgen_lr 0.0001 --cellgen_wd 0.0001 --count_wd 0.001 --mlm_prob 0.3 --n_workers 32 --num_layers 6 --d_ff 32 --loss_mode zinb --pred_tps 1 2 3 --var_list cell_type_harmonized time_after_LPS --cond_list cell_type_harmonized --encoder scmaskgit --add_cell_time False --d_condc 64 --d_condt 768 --count_dropout 0.1 --use_positional_encoding False --layer_norm True --context_mode True --encoder_path Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt --pos_encoding_mode time_pos_sin --mask_scheduler cosine --num_node 1 --d_model 768
import subprocess

# Launch count-decoder training; check=True raises CalledProcessError if the run fails.
subprocess.run(cmd, check=True)
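
For long runs it can help to capture the training log to a file rather than relying on the console alone. A minimal variant of the launch cell above (the log filename is only an example):

import subprocess

# Same launch, but write stdout/stderr to a log file for later inspection.
with open("train_count.log", "w") as log:
    subprocess.run(cmd, check=True, stdout=log, stderr=subprocess.STDOUT)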