# Untitled
#
#  avatar
# unknown
# python
# 19 days ago
# 5.7 kB
# 4
# Indexable
import torch
import scanpy as sc
import json
import numpy as np
import anndata as ad
import pandas as pd

from typing import List, Dict, Any

from lightning import seed_everything

from cell_cell_interactions.modeling.model import MulticellularTransformerPLModel

import torch as t

import os.path as osp

from .base import PredictorBaseClass


class MCTPredictor(PredictorBaseClass):
    """Predictor around a pretrained ``MulticellularTransformerPLModel``.

    Workflow: ``pp`` (normalize/log-transform) -> ``adata_to_model_input``
    (per-sample genes x cell-types pseudo-bulk tensors) -> ``predict``
    (model forward pass) -> ``output_to_dataframe`` (reconstruction back to
    pandas DataFrames).
    """

    def __init__(self, checkpoint_path: str, resource_dir: str | None = None):
        """Load the checkpoint and the gene / cell-type lookup tables.

        Args:
            checkpoint_path: Path to the model checkpoint, forwarded to the
                base class together with the model class.
            resource_dir: Directory holding ``ensemble2idx.json`` and
                ``cell_type_input_order.json``. Defaults to ``"rsc/mct"``,
                which is resolved relative to the current working directory.
        """
        # Fixed global seed so repeated runs are reproducible.
        seed_everything(32)
        super().__init__(
            model=MulticellularTransformerPLModel, checkpoint_path=checkpoint_path
        )

        self._rsc_dir = "rsc/mct" if resource_dir is None else resource_dir

        gene2idx_file_path = osp.join(self._rsc_dir, "ensemble2idx.json")
        with open(gene2idx_file_path, encoding="utf-8") as f:
            self.gene2idx = json.load(f)

        # Reverse lookup (model gene id -> gene name) used by output_to_dataframe.
        self.idx2gene = {v: k for k, v in self.gene2idx.items()}

        pretrained_label_input_order_path = osp.join(
            self._rsc_dir, "cell_type_input_order.json"
        )
        with open(pretrained_label_input_order_path, encoding="utf-8") as f:
            # Only the key order is used below: it fixes the column (cell type)
            # order of the expression matrix fed to / returned by the model.
            self.pretrained_label_input_order = json.load(f)

    @classmethod
    def pp(cls, adata: ad.AnnData):
        """Preprocess ``adata`` in place: scale each cell to 10,000 total counts, then log1p."""
        sc.pp.normalize_total(adata, target_sum=10000)
        sc.pp.log1p(adata)

    def adata_to_model_input(
        self,
        adata: ad.AnnData,
        cell_type_col: str,
        sample_col: str,
        tissue_col: str | None,
    ) -> Dict[str, Dict[str, Any]]:
        """Build one model input per sample by summing expression per cell type.

        For every unique value of ``adata.obs[sample_col]``, the sample's cells
        are summed per ``cell_type_col`` label, giving a genes x cell-types
        matrix. Columns follow the key order of
        ``self.pretrained_label_input_order``. Cell types absent from the
        sample are zero-filled, and labels not in that order are dropped.

        Args:
            adata: Cells x genes data; every ``var_name`` must be a key of
                ``self.gene2idx``.
            cell_type_col: ``obs`` column holding each cell's cell-type label.
            sample_col: ``obs`` column identifying each cell's sample.
            tissue_col: ``obs`` column holding each cell's tissue, or ``None``.
                When ``None``, every sample gets the ``[MASK]`` tissue index (1).

        Returns:
            Mapping from sample label to a dict with ``expression_matrix``
            (float32, genes x cell types), ``gene_ids`` (int64, shape
            ``(1, n_genes)``), ``is_masked_gene_expression`` (zeros shaped like
            ``gene_ids``), ``padding_mask`` (all-False bool, same shape) and
            ``tissue`` (1-element tensor).

        Raises:
            KeyError: if a gene is not present in ``self.gene2idx``.
        """
        mask_tissue_idx = 1  # index of "[MASK]" in the tissue vocabulary below

        # Plain ndarrays so all indexing below is positional, never label-based
        # (indexing the obs Series with integer positions is ambiguous).
        sample_labels = adata.obs[sample_col].to_numpy()
        uni_sample_labels, first_cell_index = np.unique(
            sample_labels, return_index=True
        )

        sample2tissue_idx: Dict[Any, int] = {}
        if tissue_col is not None:
            tissue_labels = adata.obs[tissue_col].to_numpy()
            # NOTE(review): this vocabulary is built from the tissues present
            # in *this* adata, so indices only match the pretrained tissue
            # embedding if the same sorted tissue set was used in training --
            # confirm against the training pipeline.
            tissue2idx = {"[PAD]": 0, "[MASK]": mask_tissue_idx}
            tissue2idx.update(
                {v: idx + 2 for idx, v in enumerate(np.unique(tissue_labels))}
            )
            # A sample's tissue is taken from its first cell.
            sample2tissue_idx = {
                s: tissue2idx[tis]
                for s, tis in zip(uni_sample_labels, tissue_labels[first_cell_index])
            }

        label_order = list(self.pretrained_label_input_order.keys())
        input_dict = dict()

        for sample in uni_sample_labels:
            sample_adata = adata[sample_labels == sample]

            X_df = sample_adata.to_df()
            X_df["label"] = sample_adata.obs[cell_type_col]

            # Pseudo-bulk: sum cells per label, then transpose -> genes x labels.
            # observed=True only matters for categorical labels; unobserved
            # categories are zero-filled below either way.
            X_agg = X_df.groupby("label", observed=True).sum().T

            missing_labels = [x for x in label_order if x not in X_agg.columns]
            zero_cols = pd.DataFrame(
                np.zeros((X_agg.shape[0], len(missing_labels))),
                index=X_agg.index,
                columns=missing_labels,
            )
            X_inp = pd.concat((X_agg, zero_cols), axis=1).loc[:, label_order]
            expression_matrix = t.tensor(X_inp.to_numpy(dtype=np.float32))

            gene_ids = t.tensor(
                [[self.gene2idx[g] for g in X_agg.index]], dtype=t.int64
            )
            is_masked_gene_expression = t.zeros_like(gene_ids)
            padding_mask = t.zeros_like(gene_ids, dtype=t.bool)

            tissue = t.tensor([sample2tissue_idx.get(sample, mask_tissue_idx)])

            input_dict[sample] = dict(
                expression_matrix=expression_matrix,
                gene_ids=gene_ids,
                is_masked_gene_expression=is_masked_gene_expression,
                padding_mask=padding_mask,
                tissue=tissue,
            )

        return input_dict

    def predict(
        self,
        input_dict: Dict[str, Dict[str, Any]],
        only_return_gex: bool = True,
        use_gpu: bool = True,
    ):
        """Run the model on every sample of ``input_dict``.

        Args:
            input_dict: Output of ``adata_to_model_input``. It is not modified:
                tensors are copied to the model's device per sample.
            only_return_gex: If True, keep only ``gex_reconstruction`` on each
                output and set every other dataclass field to ``None``.
            use_gpu: If True, call ``model_to_gpu`` before inference; otherwise
                run on whatever device the model is currently on.

        Returns:
            Mapping from sample key to the model output, with tensor fields
            moved to CPU.
        """
        if use_gpu:
            self.model_to_gpu()

        output_dict = dict()
        try:
            device = self.model.device
            for key, model_input in input_dict.items():
                # Device copy so the caller's (CPU) tensors are left untouched.
                device_input = {
                    name: val.to(device) if isinstance(val, t.Tensor) else val
                    for name, val in model_input.items()
                }
                with t.no_grad():
                    output = self.model.model(**device_input)
                del device_input

                if only_return_gex:
                    output.gex_reconstruction = output.gex_reconstruction.cpu()
                    # Drop every other field so device memory can be released.
                    for field in output.__dataclass_fields__:
                        if field != "gex_reconstruction":
                            setattr(output, field, None)
                else:
                    for field in output.__dataclass_fields__:
                        value = getattr(output, field)
                        if isinstance(value, t.Tensor):
                            setattr(output, field, value.cpu())

                output_dict[key] = output
                t.cuda.empty_cache()
        finally:
            # Always return the model to CPU, even if a forward pass failed.
            self.model_to_cpu()
            t.cuda.empty_cache()

        return output_dict

    def output_to_dataframe(self, output_dict, input_dict):
        """Convert each sample's ``gex_reconstruction`` into a DataFrame.

        Only the first batch entry (index 0) of each reconstruction is used.
        Rows are named by mapping the sample's ``gene_ids`` back through
        ``self.idx2gene``. Columns follow ``self.pretrained_label_input_order``.

        Args:
            output_dict: Output of ``predict``.
            input_dict: The ``adata_to_model_input`` result that was passed to
                ``predict``; provides the ``gene_ids`` per sample.

        Returns:
            Mapping from sample key to a genes x cell-types DataFrame.
        """
        label_order = list(self.pretrained_label_input_order)
        dfs = dict()

        for key, output in output_dict.items():
            gene_idx = input_dict[key]["gene_ids"][0].cpu().numpy()
            gene_names = [self.idx2gene[int(i)] for i in gene_idx]

            recon = output.gex_reconstruction[0].detach().cpu().numpy()
            dfs[key] = pd.DataFrame(recon, index=gene_names, columns=label_order)

        return dfs




# Leave a Comment