{ "cells": [ { "cell_type": "markdown", "id": "fc3a30f3", "metadata": {}, "source": [ "# 4. Gene Embedding Extraction\n", "\n", "This notebook extracts gene embeddings using the validation pipeline (GPU required).\n" ] }, { "cell_type": "markdown", "id": "a844ee1c", "metadata": {}, "source": [ "## 4.1. Parameters (GPU required)\n", "\n", "Update the paths and settings for your dataset and checkpoint.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "ae221db5", "metadata": {}, "outputs": [], "source": [ "# GPU required\n", "OUTPUT_DIR = \"T_perturb/res/masking\" # relative to repo root, here set the output directory\n", "CKPT_MASKING_PATH = \"path/to/masking_checkpoint.ckpt\" # path to load masking model checkpoint from masking training of notebook 03\n", "\n", "SRC_DATASET = \"T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset\" # path/to/src_dataset.dataset from tokenization step\n", "TGT_DATASET_FOLDER = \"T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt\" # path/to/tgt_datasets_folder from tokenization step\n", "SRC_ADATA = \"T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad\" # path/to/src_adata.h5ad from tokenization step\n", "TGT_ADATA_FOLDER = \"T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt\" # path/to/tgt_adata_folder from tokenization step\n", "MAPPING_DICT_PATH = \"T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl\" # path/to/mapping_dict.pkl from tokenization step\n", "TOKENID_TO_ROWID = \"T_perturb/tokenized_data/LPS_all_tps_2k/tokenid_to_rowid_2000_hvg.pkl\" # path/to/tokenid_to_rowid.pkl from tokenization step\n", "\n", "BATCH_SIZE = 64 # model batch size\n", "CELLGEN_LR = 1e-4 # learning rate\n", "CELLGEN_WD = 1e-4 # weight decay\n", "COUNT_LR = 0.001 # learning rate for count head\n", "COUNT_WD = 0.001 # weight decay for count head\n", "D_FF = 32 # feedforward dimension\n", "NUM_LAYERS = 6 # number of transformer layers\n", 
"N_WORKERS = 32 # number of workers for data loading\n", "PRED_TPS = [\"1\", \"2\", \"3\"] # predicted time points, (for LPS: \"1\"=90m, \"2\"=6h, \"3\"=10h)\n", "\n", "VAR_LIST = [\"cell_pairing_index\", \"time_after_LPS\", \"cell_type_harmonized\"]\n", "\n", "ENCODER = \"scmaskgit\" # encoder type\n", "ENCODER_PATH = \"Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt\" # path to pretrained encoder checkpoint provided with Perturbgen\n", "CONTEXT_MODE = \"True\" # whether to use context tokens\n", "MASK_SCHEDULER = \"pow\" # masking scheduler type\n", "RETURN_EMBED = \"True\" # whether to save cell embeddings\n", "RETURN_ATTN = \"False\" # whether to return attention weights\n", "GENERATE = \"False\" # whether to perform generation\n", "RETURN_GENE_EMBS = \"True\" # whether to return gene embeddings\n", "GENE_EMBS_CONDITION = \"time_after_LPS\" # condition for gene embeddings\n", "POS_ENCODING_MODE = \"time_pos_sin\" # positional encoding mode\n", "D_MODEL = 768 # model dimension\n" ] }, { "cell_type": "markdown", "id": "b6fc7ee4", "metadata": {}, "source": [ "## 4.2. 
Build the command\n", "\n", "This uses the CLI entrypoint `extract-embedding`, which calls `perturbgen.val`.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "f3b61192", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "python -m perturbgen extract-embedding --test_mode masking --split False --splitting_mode stratified --return_embed True --return_attn False --generate False --ckpt_masking_path path/to/masking_checkpoint.ckpt --output_dir T_perturb/res/masking --src_dataset T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_src/normal.dataset --tgt_dataset_folder T_perturb/tokenized_data/LPS_all_tps_2k/dataset_2000_hvg_tgt --src_adata T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_src/normal.h5ad --tgt_adata_folder T_perturb/tokenized_data/LPS_all_tps_2k/h5ad_pairing_2000_hvg_tgt --mapping_dict_path T_perturb/tokenized_data/LPS_all_tps_2k/token_id_to_genename_2000_hvg.pkl --batch_size 64 --cellgen_lr 0.0001 --cellgen_wd 0.0001 --count_lr 0.001 --count_wd 0.001 --d_ff 32 --num_layers 6 --n_workers 32 --pred_tps 1 2 3 --var_list cell_pairing_index time_after_LPS cell_type_harmonized --tokenid_to_rowid T_perturb/tokenized_data/LPS_all_tps_2k/tokenid_to_rowid_2000_hvg.pkl --encoder scmaskgit --encoder_path Perturbgen/pretraining_cohort/20250709_1223_cellgen_train_masking_lr_5e-05_wd_1e-06_batch_64_ptime_pos_sin_m_pow_tp_1-2-3_s_42-epoch=00.ckpt --context_mode True --mask_scheduler pow --return_gene_embs True --gene_embs_condition time_after_LPS --pos_encoding_mode time_pos_sin --d_model 768\n" ] } ], "source": [ "cmd = [\n", " \"python\",\n", " \"-m\",\n", " \"perturbgen\",\n", " \"extract-embedding\",\n", " \"--test_mode\", \"masking\",\n", " \"--split\", \"False\",\n", " \"--splitting_mode\", \"stratified\",\n", " \"--return_embed\", RETURN_EMBED,\n", " \"--return_attn\", RETURN_ATTN,\n", " \"--generate\", GENERATE,\n", " \"--ckpt_masking_path\", CKPT_MASKING_PATH,\n", " 
\"--output_dir\", OUTPUT_DIR,\n", " \"--src_dataset\", SRC_DATASET,\n", " \"--tgt_dataset_folder\", TGT_DATASET_FOLDER,\n", " \"--src_adata\", SRC_ADATA,\n", " \"--tgt_adata_folder\", TGT_ADATA_FOLDER,\n", " \"--mapping_dict_path\", MAPPING_DICT_PATH,\n", " \"--batch_size\", str(BATCH_SIZE),\n", " \"--cellgen_lr\", str(CELLGEN_LR),\n", " \"--cellgen_wd\", str(CELLGEN_WD),\n", " \"--count_lr\", str(COUNT_LR),\n", " \"--count_wd\", str(COUNT_WD),\n", " \"--d_ff\", str(D_FF),\n", " \"--num_layers\", str(NUM_LAYERS),\n", " \"--n_workers\", str(N_WORKERS),\n", " \"--pred_tps\", *PRED_TPS,\n", " \"--var_list\", *VAR_LIST,\n", " \"--tokenid_to_rowid\", TOKENID_TO_ROWID,\n", " \"--encoder\", ENCODER,\n", " \"--encoder_path\", ENCODER_PATH,\n", " \"--context_mode\", CONTEXT_MODE,\n", " \"--mask_scheduler\", MASK_SCHEDULER,\n", " \"--return_gene_embs\", RETURN_GENE_EMBS,\n", " \"--gene_embs_condition\", GENE_EMBS_CONDITION,\n", " \"--pos_encoding_mode\", POS_ENCODING_MODE,\n", " \"--d_model\", str(D_MODEL),\n", "]\n", "\n", "print(\" \".join(cmd))\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c3fac909", "metadata": {}, "outputs": [], "source": [ "import subprocess\n", "\n", "# Launch the extraction as a subprocess (GPU required).\n", "# check=True raises CalledProcessError if the CLI exits non-zero,\n", "# so a failed run is not silently ignored.\n", "subprocess.run(cmd, check=True)" ] } ], "metadata": { "kernelspec": { "display_name": "perturbgen-py3.11", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }