Untitled
unknown
python
a year ago
5.7 kB
10
Indexable
import torch
import scanpy as sc
import json
import numpy as np
import anndata as ad
import pandas as pd
from typing import List, Dict, Any
from lightning import seed_everything
from cell_cell_interactions.modeling.model import MulticellularTransformerPLModel
import torch as t
import os.path as osp
from .base import PredictorBaseClass
class MCTPredictor(PredictorBaseClass):
    """Predictor built around ``MulticellularTransformerPLModel``.

    The class turns an ``AnnData`` object into one model input per sample
    (a genes x cell-types matrix of summed expression), runs the model on
    each sample and converts the reconstructed expression into data frames.
    """

    def __init__(self, checkpoint_path: str, resource_dir: str | None = None):
        """Initialise the base predictor and load the JSON lookup tables.

        Args:
            checkpoint_path: Forwarded to the base class together with the
                model class.
            resource_dir: Directory holding ``ensemble2idx.json`` and
                ``cell_type_input_order.json``. ``None`` selects
                ``"rsc/mct"`` (relative to the working directory).
        """
        seed_everything(32)
        super().__init__(
            model=MulticellularTransformerPLModel, checkpoint_path=checkpoint_path
        )
        self._rsc_dir = "rsc/mct" if resource_dir is None else resource_dir

        gene2idx_file_path = osp.join(self._rsc_dir, "ensemble2idx.json")
        with open(gene2idx_file_path, encoding="utf-8") as f:
            self.gene2idx = json.load(f)
        # Reverse lookup, used to label model output with gene names again.
        self.idx2gene = {idx: gene for gene, idx in self.gene2idx.items()}

        pretrained_label_input_order_path = osp.join(
            self._rsc_dir, "cell_type_input_order.json"
        )
        with open(pretrained_label_input_order_path, encoding="utf-8") as f:
            self.pretrained_label_input_order = json.load(f)

    @classmethod
    def pp(cls, adata: ad.AnnData) -> None:
        """Preprocess ``adata`` in place.

        Applies ``scanpy.pp.normalize_total`` with ``target_sum=10000``
        followed by ``scanpy.pp.log1p``. Nothing is returned.
        """
        sc.pp.normalize_total(adata, target_sum=10000)
        sc.pp.log1p(adata)

    def adata_to_model_input(
        self,
        adata: ad.AnnData,
        cell_type_col: str,
        sample_col: str,
        tissue_col: str | None,
    ) -> Dict[str, Dict[str, Any]]:
        """Build one dictionary of model keyword arguments per sample.

        For every unique value in ``adata.obs[sample_col]`` the expression of
        the sample's cells is summed per cell type and arranged as a
        genes x cell-types matrix whose columns follow
        ``self.pretrained_label_input_order``. Cell types absent from the
        sample become all-zero columns; cell types that are not part of the
        pretrained order are dropped.

        Args:
            adata: Annotated expression data; its variable names must be keys
                of ``self.gene2idx``.
            cell_type_col: ``adata.obs`` column with the cell type labels.
            sample_col: ``adata.obs`` column with the sample labels.
            tissue_col: ``adata.obs`` column with the tissue labels, or
                ``None`` when no tissue annotation is available.

        Returns:
            Mapping from sample label to a dict with the entries
            ``expression_matrix`` (float32, genes x cell types), ``gene_ids``
            (int64, shape ``(1, n_genes)``), ``is_masked_gene_expression``
            (zeros shaped like ``gene_ids``), ``padding_mask`` (all ``False``,
            shaped like ``gene_ids``) and ``tissue`` (one index).

        Raises:
            KeyError: If a gene in ``adata`` has no entry in ``self.gene2idx``.
        """
        label_order = list(self.pretrained_label_input_order)

        # The gene axis is the same for every sample, so resolve the ids once
        # and report all unknown genes together instead of only the first.
        gene_names = list(adata.var_names)
        unknown_genes = [gene for gene in gene_names if gene not in self.gene2idx]
        if unknown_genes:
            raise KeyError(
                f"{len(unknown_genes)} gene(s) are missing from gene2idx, "
                f"e.g. {unknown_genes[:5]}"
            )
        gene_id_row = [self.gene2idx[gene] for gene in gene_names]

        sample_labels = adata.obs[sample_col].to_numpy()
        uni_sample_labels, uni_sample_index = np.unique(
            sample_labels, return_index=True
        )

        # NOTE(review): tissue indices are derived from the tissues present in
        # this ``adata`` (sorted unique order), not read from a resource file;
        # confirm that this matches the vocabulary used during training.
        tissue2idx = {"[PAD]": 0, "[MASK]": 1}
        if tissue_col is not None:
            # Plain ndarray so the integer index below is positional rather
            # than a label lookup on the obs index.
            tissue_labels = adata.obs[tissue_col].to_numpy()
            tissue2idx.update(
                {v: idx + 2 for idx, v in enumerate(np.unique(tissue_labels))}
            )
            # Tissue of a sample = tissue of the first cell of that sample.
            sample2tissue = dict(
                zip(uni_sample_labels, tissue_labels[uni_sample_index])
            )
        else:
            # Without a tissue column the lookups used to fail with NameError.
            # NOTE(review): "[MASK]" is used as the stand-in for an unknown
            # tissue; confirm the model accepts a masked tissue token.
            sample2tissue = {sample: "[MASK]" for sample in uni_sample_labels}

        input_dict: Dict[str, Dict[str, Any]] = {}
        for sample in uni_sample_labels:
            adata_sample = adata[sample_labels == sample]
            # Group by an array rather than a temporary "label" column so a
            # gene that happens to be called "label" is not overwritten.
            cell_types = adata_sample.obs[cell_type_col].to_numpy()
            X_agg = adata_sample.to_df().groupby(cell_types).sum().T

            # Columns in pretrained order; unobserved cell types become zeros.
            X_inp = X_agg.reindex(columns=label_order, fill_value=0.0)
            expression_matrix = t.tensor(X_inp.to_numpy(dtype=np.float32))

            gene_ids = t.tensor([gene_id_row], dtype=t.int64)
            is_masked_gene_expression = t.zeros_like(gene_ids)
            padding_mask = t.zeros_like(is_masked_gene_expression, dtype=t.bool)
            tissue = t.tensor([tissue2idx[sample2tissue[sample]]])

            input_dict[sample] = dict(
                expression_matrix=expression_matrix,
                gene_ids=gene_ids,
                is_masked_gene_expression=is_masked_gene_expression,
                padding_mask=padding_mask,
                tissue=tissue,
            )
        return input_dict

    def predict(
        self,
        input_dict: Dict[str, Dict[str, Any]],
        only_return_gex: bool = True,
        use_gpu: bool = True,
    ):
        """Run the model on every entry of ``input_dict``.

        Args:
            input_dict: Mapping from sample key to the keyword arguments of
                ``self.model.model`` (as built by ``adata_to_model_input``).
                The mapping and its tensors are not modified.
            only_return_gex: If true, keep only ``gex_reconstruction`` on each
                output and set every other dataclass field to ``None``.
                Otherwise every tensor field is moved to the CPU.
            use_gpu: If true, ``self.model_to_gpu()`` is called before the
                forward passes; otherwise the model is placed on the CPU.

        Returns:
            Mapping from sample key to the model output object, with its
            retained tensors on the CPU.
        """
        if use_gpu:
            self.model_to_gpu()
        else:
            self.model_to_cpu()

        output_dict = dict()
        try:
            for key, model_input in input_dict.items():
                device = self.model.device
                with t.no_grad():
                    # Device copies live in a fresh dict so the caller's
                    # tensors are never swapped out underneath them.
                    device_input = {
                        name: value.to(device)
                        if isinstance(value, t.Tensor)
                        else value
                        for name, value in model_input.items()
                    }
                    output = self.model.model(**device_input)

                    if only_return_gex:
                        output.gex_reconstruction = output.gex_reconstruction.cpu()
                        for field in output.__dataclass_fields__:
                            if field != "gex_reconstruction":
                                setattr(output, field, None)
                    else:
                        for field in output.__dataclass_fields__:
                            value = getattr(output, field)
                            if isinstance(value, t.Tensor):
                                setattr(output, field, value.cpu())

                output_dict[key] = output
                # Drop the device copies before asking CUDA to release memory.
                del device_input
                t.cuda.empty_cache()
        finally:
            # Always hand the model back on the CPU, even if a forward pass
            # raised, so a failed call does not keep GPU memory occupied.
            self.model_to_cpu()
            t.cuda.empty_cache()
        return output_dict

    def output_to_dataframe(self, output_dict, input_dict):
        """Convert model outputs into genes x cell-types data frames.

        Args:
            output_dict: Mapping from sample key to an output object with a
                ``gex_reconstruction`` tensor (as returned by ``predict``).
            input_dict: The inputs the outputs were computed from; only
                ``gene_ids`` is read, to recover the gene names.

        Returns:
            Mapping from sample key to a ``pandas.DataFrame`` built from
            ``gex_reconstruction[0]``, indexed by gene name, with the
            pretrained cell type labels as columns.
        """
        label_order = list(self.pretrained_label_input_order)
        dfs = dict()
        for key, output in output_dict.items():
            # ``gene_ids`` has a leading axis of length one; take its only row.
            gene_idx = input_dict[key]["gene_ids"][0].tolist()
            gene_names = [self.idx2gene[idx] for idx in gene_idx]
            reconstruction = output.gex_reconstruction[0]
            dfs[key] = pd.DataFrame(
                reconstruction.detach().cpu().numpy(),
                index=gene_names,
                columns=label_order,
            )
        return dfs
Editor is loading...
Leave a Comment