Untitled
import json
import os.path as osp
from typing import Any, Dict, List

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import torch
import torch as t
from lightning import seed_everything

from cell_cell_interactions.modeling.model import MulticellularTransformerPLModel

from .base import PredictorBaseClass


class MCTPredictor(PredictorBaseClass):
    """Predictor around a ``MulticellularTransformerPLModel`` checkpoint.

    Typical flow: :meth:`pp` -> :meth:`adata_to_model_input` ->
    :meth:`predict` -> :meth:`output_to_dataframe`.
    """

    def __init__(self, checkpoint_path: str, resource_dir: str | None = None):
        """Load the checkpoint and the JSON lookup tables the model needs.

        Args:
            checkpoint_path: Path handed to the base class to restore the model.
            resource_dir: Directory containing ``ensemble2idx.json`` and
                ``cell_type_input_order.json``. Defaults to ``"rsc/mct"``.
        """
        seed_everything(32)
        super().__init__(
            model=MulticellularTransformerPLModel, checkpoint_path=checkpoint_path
        )

        self._rsc_dir = "rsc/mct" if resource_dir is None else resource_dir

        # Gene identifier -> integer token id, plus the inverse mapping used
        # to label model output.
        gene2idx_file_path = osp.join(self._rsc_dir, "ensemble2idx.json")
        with open(gene2idx_file_path, encoding="utf-8") as f:
            self.gene2idx = json.load(f)
        self.idx2gene = {v: k for k, v in self.gene2idx.items()}

        # The key order of this mapping defines the column order of the
        # expression matrix fed to the model.
        pretrained_label_input_order_path = osp.join(
            self._rsc_dir, "cell_type_input_order.json"
        )
        with open(pretrained_label_input_order_path, encoding="utf-8") as f:
            self.pretrained_label_input_order = json.load(f)

    @classmethod
    def pp(cls, adata: ad.AnnData) -> None:
        """Normalize each cell to a total of 10,000 and apply ``log1p``.

        The scanpy calls are made with their defaults and their return values
        are discarded, so ``adata`` is modified rather than copied.
        """
        sc.pp.normalize_total(adata, target_sum=10000)
        sc.pp.log1p(adata)

    def adata_to_model_input(
        self,
        adata: ad.AnnData,
        cell_type_col: str,
        sample_col: str,
        tissue_col: str | None,
    ) -> Dict[str, Dict[str, Any]]:
        """Build one model-input dict per sample found in ``adata``.

        For every sample the expression values are summed per cell-type label
        and arranged as a genes x cell-types matrix whose columns follow
        ``self.pretrained_label_input_order``; cell types absent from the
        sample are filled with zeros.

        Args:
            adata: Annotated data; ``var_names`` must all be keys of
                ``self.gene2idx``.
            cell_type_col: ``adata.obs`` column holding cell-type labels.
            sample_col: ``adata.obs`` column holding sample identifiers.
            tissue_col: ``adata.obs`` column holding tissue labels, or ``None``
                when no tissue annotation is available.

        Returns:
            Mapping ``sample -> {expression_matrix, gene_ids,
            is_masked_gene_expression, padding_mask, tissue}``.

        Raises:
            KeyError: If a gene in ``adata.var_names`` is not in the vocabulary.
        """
        # Work on plain NumPy arrays so that integer indexing below is
        # positional; `Series[int_array]` on a string index is label-based in
        # recent pandas versions.
        sample_labels = adata.obs[sample_col].to_numpy()
        uni_sample_labels, uni_sample_index = np.unique(
            sample_labels, return_index=True
        )

        tissue2idx = {"[PAD]": 0, "[MASK]": 1}
        if tissue_col is not None:
            tissue_labels = adata.obs[tissue_col].to_numpy()
            # NOTE(review): tissue indices are derived from the tissues present
            # in this `adata` (sorted unique order), not from a stored
            # vocabulary -- confirm this matches the indices used in training.
            tissue2idx.update(
                {v: idx + 2 for idx, v in enumerate(np.unique(tissue_labels))}
            )
            # Each sample takes the tissue of its first occurring cell.
            sample2tissue = dict(
                zip(uni_sample_labels, tissue_labels[uni_sample_index])
            )
        else:
            # Previously this path raised NameError. NOTE(review): falling back
            # to the "[MASK]" tissue token is an assumption -- confirm the
            # model treats it as "unknown tissue".
            sample2tissue = {sample: "[MASK]" for sample in uni_sample_labels}

        # Gene ids are identical for every sample, so resolve them once.
        gene_names = list(adata.var_names)
        unknown_genes = [g for g in gene_names if g not in self.gene2idx]
        if unknown_genes:
            raise KeyError(
                f"{len(unknown_genes)} gene(s) in adata.var_names are missing "
                f"from the gene vocabulary, e.g. {unknown_genes[:5]}"
            )
        gene_id_row = [self.gene2idx[g] for g in gene_names]

        label_order = list(self.pretrained_label_input_order)
        input_dict: Dict[str, Dict[str, Any]] = {}

        for sample in uni_sample_labels:
            sample_adata = adata[sample_labels == sample]
            cell_labels = sample_adata.obs[cell_type_col].to_numpy()

            # Group by an external label array rather than a helper column so
            # a gene that happens to be called "label" cannot be clobbered.
            X_agg = sample_adata.to_df().groupby(cell_labels).sum().T

            # Put columns in the pretrained order; unseen cell types become 0
            # and labels unknown to the model are dropped.
            X_inp = X_agg.reindex(columns=label_order, fill_value=0.0)
            expression_matrix = t.tensor(X_inp.values.astype(np.float32))

            gene_ids = t.tensor([gene_id_row], dtype=t.int64)
            is_masked_gene_expression = t.zeros_like(gene_ids)
            padding_mask = t.zeros_like(gene_ids, dtype=t.bool)
            tissue = t.tensor([tissue2idx[sample2tissue[sample]]])

            input_dict[sample] = dict(
                expression_matrix=expression_matrix,
                gene_ids=gene_ids,
                is_masked_gene_expression=is_masked_gene_expression,
                padding_mask=padding_mask,
                tissue=tissue,
            )

        return input_dict

    def predict(
        self,
        input_dict: Dict[str, Dict[str, Any]],
        only_return_gex: bool = True,
        use_gpu: bool = True,
    ):
        """Run the model on every entry of ``input_dict``.

        Args:
            input_dict: Output of :meth:`adata_to_model_input`. It is not
                modified; tensors are copied to the model device per call.
            only_return_gex: If true, keep only ``gex_reconstruction`` on each
                output object and set every other field to ``None``.
            use_gpu: If true, move the model to the GPU before predicting;
                otherwise run on the CPU.

        Returns:
            Mapping ``sample -> model output`` with tensors moved to the CPU.
        """
        if use_gpu:
            self.model_to_gpu()
        else:
            self.model_to_cpu()

        output_dict = dict()
        try:
            with torch.no_grad():
                for key, model_input in input_dict.items():
                    device = self.model.device
                    # Build a device-local copy instead of writing back into
                    # the caller's dict.
                    device_input = {
                        name: value.to(device)
                        if isinstance(value, t.Tensor)
                        else value
                        for name, value in model_input.items()
                    }

                    output = self.model.model(**device_input)

                    if only_return_gex:
                        setattr(
                            output,
                            "gex_reconstruction",
                            output.gex_reconstruction.cpu(),
                        )
                        for field in output.__dataclass_fields__:
                            if field != "gex_reconstruction":
                                setattr(output, field, None)
                    else:
                        for field in output.__dataclass_fields__:
                            value = getattr(output, field)
                            if isinstance(value, t.Tensor):
                                setattr(output, field, value.cpu())

                    output_dict[key] = output

                    # Drop device copies before the next sample to limit peak
                    # GPU memory.
                    del device_input
                    t.cuda.empty_cache()
        finally:
            # Always hand the model back on the CPU, even if a forward pass
            # raised.
            self.model_to_cpu()
            t.cuda.empty_cache()

        return output_dict

    def output_to_dataframe(
        self,
        output_dict: Dict[str, Any],
        input_dict: Dict[str, Dict[str, Any]],
    ) -> Dict[str, pd.DataFrame]:
        """Convert reconstructed expression into labelled DataFrames.

        Args:
            output_dict: Result of :meth:`predict`.
            input_dict: The inputs passed to :meth:`predict`; used to recover
                gene names from ``gene_ids``.

        Returns:
            Mapping ``sample -> DataFrame`` indexed by gene name with one
            column per entry of ``self.pretrained_label_input_order``.
        """
        label_order = list(self.pretrained_label_input_order)
        dfs = dict()
        for key, output in output_dict.items():
            # `.detach().cpu()` keeps this working if tensors are not on CPU.
            gene_idx = input_dict[key]["gene_ids"].detach().cpu().numpy()
            gene_names = [self.idx2gene[x] for x in gene_idx[0].tolist()]

            # First element along the leading axis, mirroring the single row
            # of `gene_ids`.
            X = output.gex_reconstruction[0]
            dfs[key] = pd.DataFrame(
                X.detach().cpu().numpy(),
                index=gene_names,
                columns=label_order,
            )
        return dfs
Leave a Comment