{ "cells": [ { "cell_type": "markdown", "id": "e190beb1", "metadata": {}, "source": [ "# 3. Train Masking Model and Count Decoder\n", "\n" ] }, { "cell_type": "markdown", "id": "cd1afea7", "metadata": {}, "source": [ "This step requires a GPU. Run on a GPU node (e.g., via your scheduler) or a machine with CUDA." ] }, { "cell_type": "markdown", "id": "b6d04be7", "metadata": {}, "source": [ "## 3.1. Train Masking Model (GPU required)\n" ] }, { "cell_type": "markdown", "id": "58478008", "metadata": {}, "source": [ "This step requires a GPU. Run on a GPU node (e.g., via your scheduler) or a machine with CUDA.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c25fecd2", "metadata": {}, "outputs": [], "source": [ "# GPU required\n", "OUTPUT_DIR = \"T_perturb/res/masking\" # relative to repo root, here set the output directory\n", "SRC_DATASET = \"T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset\" # path/to/src_dataset.dataset from tokenization step\n", "TGT_DATASET_FOLDER = \"T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt\" # path/to/tgt_datasets_folder from tokenization step\n", "SRC_ADATA = \"T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad\" # path/to/src_adata.h5ad\" from tokenization step\n", "TGT_ADATA_FOLDER = \"T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt\" # path/to/tgt_adata_folder from tokenization step\n", "MAPPING_DICT_PATH = \"T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl\" # path/to/mapping_dict.pkl from tokenization step\n", "\n", "ENCODER_PATH = \"Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt\" # path to pretrained encoder checkpoint provided with Perturbgen\n", "\n", "BATCH_SIZE = 64 # Model training batch size\n", "EPOCHS = 20 # number of training epochs\n", "CELLGEN_LR = 1e-4 # learning rate\n", "CELLGEN_WD = 1e-4 # weight decay\n", "N_WORKERS = 4 # number of data loading workers\n", "NUM_LAYERS = 6 # number of transformer layers\n", "D_FF = 32 # feedforward dimension\n", "PRED_TPS = [\"1\", \"2\", \"3\"] # time points to train on and predict (for LPS: \"1\"=90m, \"2\"=6h, \"3\"=10h)\n", "\n", "VAR_LIST = [\"cell_type_harmonized\", \"time_after_LPS\"] # list of obs retained in adata.vars after preprocessing\n", "\n", "SEED = 0 # random seed for reproducibility\n", "CONTEXT_MODE = \"True\" # whether to use context tokens\n", "POS_ENCODING_MODE = \"time_pos_sin\" # positional encoding mode\n", "MASK_SCHEDULER = \"pow\" # masking scheduler type\n", "NUM_NODE = 1 # number of nodes if using distributed training\n", "D_MODEL = 768 # model dimension\n", "\n", "CKPT_MASKING_PATH = \"path/to/optional_resume.ckpt\" # optional path to checkpoint to resume training from\n", "USE_WEIGHTED_SAMPLER = \"False\" # whether to use weighted sampler during training\n" ] }, { "cell_type": "code", "execution_count": null, "id": "7bd04ad4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "python -m perturbgen train-mask --train_mode masking --split False --encoder scmaskgit --splitting_mode stratified --split_obs cell_type_harmonized --output_dir T_perturb/res/masking --src_dataset T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset --tgt_dataset_folder T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt --src_adata T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad 
{ "cell_type": "code", "execution_count": null, "id": "7bd04ad4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "python -m perturbgen train-mask --train_mode masking --split False --encoder scmaskgit --splitting_mode stratified --split_obs cell_type_harmonized --output_dir T_perturb/res/masking --src_dataset T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset --tgt_dataset_folder T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt --src_adata T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad --tgt_adata_folder T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt --mapping_dict_path T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl --batch_size 64 --max_len 797 --epochs 20 --tgt_vocab_size 2002 --cellgen_lr 0.0001 --cellgen_wd 0.0001 --n_workers 4 --num_layers 6 --d_ff 32 --pred_tps 1 2 3 --var_list cell_type_harmonized time_after_LPS --cond_list cell_type_harmonized --encoder_path Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt --seed 0 --context_mode True --pos_encoding_mode time_pos_sin --mask_scheduler pow --num_node 1 --d_model 768 --use_weighted_sampler False\n" ] } ], "source": [ "cmd = [\n", "    \"python\",\n", "    \"-m\",\n", "    \"perturbgen\",\n", "    \"train-mask\",\n", "    \"--train_mode\", \"masking\",\n", "    \"--split\", \"False\",\n", "    \"--encoder\", \"scmaskgit\",\n", "    \"--splitting_mode\", \"stratified\",\n", "    \"--split_obs\", \"cell_type_harmonized\",\n", "    \"--output_dir\", OUTPUT_DIR,\n", "    \"--src_dataset\", SRC_DATASET,\n", "    \"--tgt_dataset_folder\", TGT_DATASET_FOLDER,\n", "    \"--src_adata\", SRC_ADATA,\n", "    \"--tgt_adata_folder\", TGT_ADATA_FOLDER,\n", "    \"--mapping_dict_path\", MAPPING_DICT_PATH,\n", "    \"--batch_size\", str(BATCH_SIZE),\n", "    \"--max_len\", str(MAX_LEN),\n", "    \"--epochs\", str(EPOCHS),\n", "    \"--tgt_vocab_size\", str(TGT_VOCAB_SIZE),\n", "    \"--cellgen_lr\", str(CELLGEN_LR),\n", "    \"--cellgen_wd\", str(CELLGEN_WD),\n", "    \"--n_workers\", str(N_WORKERS),\n", "    \"--num_layers\", str(NUM_LAYERS),\n", "    \"--d_ff\", str(D_FF),\n", "    \"--pred_tps\", *PRED_TPS,\n", "    \"--var_list\", *VAR_LIST,\n", "    \"--cond_list\", *COND_LIST,\n", "    \"--encoder_path\", ENCODER_PATH,\n", "    \"--seed\", str(SEED),\n", "    \"--context_mode\", CONTEXT_MODE,\n", "    \"--pos_encoding_mode\", POS_ENCODING_MODE,\n", "    \"--mask_scheduler\", MASK_SCHEDULER,\n", "    \"--num_node\", str(NUM_NODE),\n", "    \"--d_model\", str(D_MODEL),\n", "    \"--use_weighted_sampler\", USE_WEIGHTED_SAMPLER,\n", "]\n", "\n", "if CKPT_MASKING_PATH:\n", "    cmd += [\"--ckpt_masking_path\", CKPT_MASKING_PATH]\n", "\n", "print(\" \".join(cmd))\n" ] }, { "cell_type": "code", "execution_count": null, "id": "37427f38", "metadata": {}, "outputs": [], "source": [ "import subprocess\n", "\n", "subprocess.run(cmd, check=True)\n" ] },
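{ "cell_type": "markdown", "id": "9c8d7e6f", "metadata": {}, "source": [ "Section 3.2 needs the path to the best masking checkpoint. The sketch below simply lists the `.ckpt` files written under `OUTPUT_DIR`, oldest first; the folder layout is an assumption, and which checkpoint is \"best\" (e.g., lowest validation loss) must be judged from the training logs, not just recency.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "0a1b2c3d", "metadata": {}, "outputs": [], "source": [ "# Minimal sketch: list candidate masking checkpoints under OUTPUT_DIR.\n", "# The folder layout is an assumption; pick the best checkpoint based on\n", "# your training logs (e.g., lowest validation loss), not just recency.\n", "from pathlib import Path\n", "\n", "ckpts = sorted(Path(OUTPUT_DIR).rglob(\"*.ckpt\"), key=lambda p: p.stat().st_mtime)\n", "for p in ckpts:\n", "    print(p)\n" ] },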
{ "cell_type": "markdown", "id": "b726c167", "metadata": {}, "source": [ "## 3.2. Train Count Decoder (GPU required)\n", "\n", "This step trains the count decoder on top of the masking checkpoint selected above.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "250a1b0f", "metadata": {}, "outputs": [], "source": [ "# GPU required\n", "COUNT_OUTPUT_DIR = \"T_perturb/res/count\" # output directory, relative to the repo root\n", "CKPT_MASKING_PATH = \"path/to/masking_checkpoint.ckpt\" # set to the best masking checkpoint from the previous step\n", "\n", "SRC_DATASET = \"T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset\" # path/to/src_dataset.dataset from tokenization step\n", "TGT_DATASET_FOLDER = \"T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt\" # path/to/tgt_datasets_folder from tokenization step\n", "SRC_ADATA = \"T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad\" # path/to/src_adata.h5ad from tokenization step\n", "TGT_ADATA_FOLDER = \"T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt\" # path/to/tgt_adata_folder from tokenization step\n", "MAPPING_DICT_PATH = \"T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl\" # path/to/mapping_dict.pkl from tokenization step\n", "\n", "ENCODER_PATH = \"Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt\" # path to the pretrained encoder checkpoint provided with Perturbgen\n", "\n", "BATCH_SIZE = 16 # training batch size\n", "MAX_LEN = 797 # maximum token sequence length\n", "EPOCHS = 16 # number of training epochs\n", "TGT_VOCAB_SIZE = 2002 # target vocabulary size\n", "COUNT_LR = 0.001 # learning rate for the count model\n", "CELLGEN_LR = 0.0001 # learning rate for the masking model (required by the CLI, not used in this step)\n", "CELLGEN_WD = 0.0001 # weight decay for the masking model (required by the CLI, not used in this step)\n", "COUNT_WD = 0.001 # weight decay for the count model\n", "MLM_PROB = 0.30 # masking probability\n", "N_WORKERS = 32 # number of data-loading workers\n", "NUM_LAYERS = 6 # number of transformer layers\n", "D_FF = 32 # feedforward dimension\n", "LOSS_MODE = \"zinb\" # loss for the count model: \"mse\", \"nb\", or \"zinb\"\n", "PRED_TPS = [\"1\", \"2\", \"3\"] # time points to train on and predict (for LPS: \"1\"=90m, \"2\"=6h, \"3\"=10h)\n", "\n", "VAR_LIST = [\"cell_type_harmonized\", \"time_after_LPS\"] # obs columns retained in adata.obs after preprocessing\n", "COND_LIST = [\"cell_type_harmonized\"] # obs columns used for conditioning\n", "\n", "ADD_CELL_TIME = \"False\" # whether to add cell-time conditioning\n", "D_CONDC = 64 # conditioning dimension d_condc\n", "D_CONDT = 768 # conditioning dimension d_condt\n", "COUNT_DROPOUT = 0.1 # dropout for the count model\n", "USE_POSITIONAL_ENCODING = \"False\" # whether to use positional encoding in the count model\n", "LAYER_NORM = \"True\" # whether to use layer normalization in the count model\n", "CONTEXT_MODE = \"True\" # whether to use context tokens\n", "POS_ENCODING_MODE = \"time_pos_sin\" # positional encoding mode\n", "MASK_SCHEDULER = \"pow\" # masking scheduler type\n", "NUM_NODE = 1 # number of nodes for distributed training\n", "D_MODEL = 768 # model dimension\n" ] },
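{ "cell_type": "markdown", "id": "4e5f6071", "metadata": {}, "source": [ "`CKPT_MASKING_PATH` above is a placeholder. The small defensive check below (not part of the CLI) fails fast if the checkpoint or the tokenized inputs are missing.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "82930a1b", "metadata": {}, "outputs": [], "source": [ "# Defensive check: make sure the inputs exist before launching training.\n", "# CKPT_MASKING_PATH above is a placeholder and must be replaced first.\n", "import os\n", "\n", "for path in [CKPT_MASKING_PATH, SRC_DATASET, SRC_ADATA, MAPPING_DICT_PATH]:\n", "    assert os.path.exists(path), f\"Missing input: {path}\"\n" ] },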
{ "cell_type": "code", "execution_count": null, "id": "a96c0216", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "python -m perturbgen train-decoder --train_mode count --split False --splitting_mode stratified --output_dir T_perturb/res/count --ckpt_masking_path path/to/masking_checkpoint.ckpt --src_dataset T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset --tgt_dataset_folder T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt --src_adata T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad --tgt_adata_folder T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt --mapping_dict_path T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl --batch_size 16 --max_len 797 --epochs 16 --tgt_vocab_size 2002 --count_lr 0.001 --cellgen_lr 0.0001 --cellgen_wd 0.0001 --count_wd 0.001 --mlm_prob 0.3 --n_workers 32 --num_layers 6 --d_ff 32 --loss_mode zinb --pred_tps 1 2 3 --var_list cell_type_harmonized time_after_LPS --cond_list cell_type_harmonized --encoder scmaskgit --add_cell_time False --d_condc 64 --d_condt 768 --count_dropout 0.1 --use_positional_encoding False --layer_norm True --context_mode True --encoder_path Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt --pos_encoding_mode time_pos_sin --mask_scheduler pow --num_node 1 --d_model 768 --ckpt_every_n_epochs 5\n" ] } ], "source": [ "cmd = [\n", "    \"python\",\n", "    \"-m\",\n", "    \"perturbgen\",\n", "    \"train-decoder\",\n", "    \"--train_mode\", \"count\",\n", "    \"--split\", \"False\",\n", "    \"--splitting_mode\", \"stratified\",\n", "    \"--output_dir\", COUNT_OUTPUT_DIR,\n", "    \"--ckpt_masking_path\", CKPT_MASKING_PATH,\n", "    \"--src_dataset\", SRC_DATASET,\n", "    \"--tgt_dataset_folder\", TGT_DATASET_FOLDER,\n", "    \"--src_adata\", SRC_ADATA,\n", "    \"--tgt_adata_folder\", TGT_ADATA_FOLDER,\n", "    \"--mapping_dict_path\", MAPPING_DICT_PATH,\n", "    \"--batch_size\", str(BATCH_SIZE),\n", "    \"--max_len\", str(MAX_LEN),\n", "    \"--epochs\", str(EPOCHS),\n", "    \"--tgt_vocab_size\", str(TGT_VOCAB_SIZE),\n", "    \"--count_lr\", str(COUNT_LR),\n", "    \"--cellgen_lr\", str(CELLGEN_LR),\n", "    \"--cellgen_wd\", str(CELLGEN_WD),\n", "    \"--count_wd\", str(COUNT_WD),\n", "    \"--mlm_prob\", str(MLM_PROB),\n", "    \"--n_workers\", str(N_WORKERS),\n", "    \"--num_layers\", str(NUM_LAYERS),\n", "    \"--d_ff\", str(D_FF),\n", "    \"--loss_mode\", LOSS_MODE,\n", "    \"--pred_tps\", *PRED_TPS,\n", "    \"--var_list\", *VAR_LIST,\n", "    \"--cond_list\", *COND_LIST,\n", "    \"--encoder\", \"scmaskgit\",\n", "    \"--add_cell_time\", ADD_CELL_TIME,\n", "    \"--d_condc\", str(D_CONDC),\n", "    \"--d_condt\", str(D_CONDT),\n", "    \"--count_dropout\", str(COUNT_DROPOUT),\n", "    \"--use_positional_encoding\", USE_POSITIONAL_ENCODING,\n", "    \"--layer_norm\", LAYER_NORM,\n", "    \"--context_mode\", CONTEXT_MODE,\n", "    \"--encoder_path\", ENCODER_PATH,\n", "    \"--pos_encoding_mode\", POS_ENCODING_MODE,\n", "    \"--mask_scheduler\", MASK_SCHEDULER,\n", "    \"--num_node\", str(NUM_NODE),\n", "    \"--d_model\", str(D_MODEL),\n", "    \"--ckpt_every_n_epochs\", \"5\",\n", "]\n", "print(\" \".join(cmd))\n" ] }, { "cell_type": "code", "execution_count": null, "id": "ac2cb591", "metadata": {}, "outputs": [], "source": [ "import subprocess\n", "\n", "subprocess.run(cmd, check=True)\n" ] } ], "metadata": { "kernelspec": { "display_name": "perturbgen-py3.11", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }