Untitled

mail@pastecode.io avatar
unknown
python
2 years ago
17 kB
2
Indexable
Never
import os
from tkinter import W
from typing import Any, List
from omegaconf import DictConfig
import timm
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
from torchmetrics import Precision
import ivtmetrics
from hydra.utils import get_original_cwd
import numpy as np
from pprint import pprint
import json
import random
import subprocess

# assert timm.__version__ == "0.6.2.dev0", "Unsupport timm version"


class TripletAttentionModule(LightningModule):
    def __init__(
        self,
        temporal_cfg: DictConfig,
        optim: DictConfig,
        loss_weight: DictConfig,
        tool_component: DictConfig,
        target_tool_attention: DictConfig,
        use_pretrained: bool = True,
        emb_dim: int = 256,
        backbone_model: str = "",
        backbone_trainable: bool = True,
        triplet_map: str = "./data/CholecT45/dict/maps.txt",
    ):
        super().__init__()

        self.save_hyperparameters(logger=False)

        self.train_recog_metric = ivtmetrics.Recognition(num_class=100)
        self.valid_recog_metric = ivtmetrics.Recognition(num_class=100)
        self.test_recog_metric = ivtmetrics.Recognition(num_class=100)

        self.class_num = {
            "tool": 6,
            "verb": 10,
            "target": 15,
            "triplet": 100,
        }
        self.train_tool_map = Precision(
            num_classes=self.class_num["tool"], average="macro",
        )
        self.train_verb_map = Precision(
            num_classes=self.class_num["verb"], average="macro",
        )
        self.train_target_map = Precision(
            num_classes=self.class_num["target"], average="macro",
        )
        self.train_triplet_map = Precision(
            num_classes=self.class_num["triplet"], average="macro",
        )
        self.valid_tool_map = Precision(
            num_classes=self.class_num["tool"], average="macro",
        )
        self.valid_verb_map = Precision(
            num_classes=self.class_num["verb"], average="macro",
        )
        self.valid_target_map = Precision(
            num_classes=self.class_num["target"], average="macro",
        )
        self.valid_triplet_map = Precision(
            num_classes=self.class_num["triplet"], average="macro",
        )

        assert (
            "vit" in backbone_model or "swin" in backbone_model
        ), "Only support using vision transformer based model"

        self.feature_extractor = timm.create_model(
            backbone_model,
            pretrained=use_pretrained,
            in_chans=3,
            num_classes=0,
        )

        for p in self.feature_extractor.parameters():
            p.requires_grad = False

        self.tool_information = nn.Sequential(
            nn.Linear(
                self.feature_extractor.num_features,
                emb_dim,
            ),
            nn.Dropout(p=tool_component.dropout_ratio),
        )

        self.tool_head = nn.Sequential(
            nn.Linear(
                emb_dim,
                self.class_num["tool"],
            ),
        )

        self.attention_pre_fc = nn.Linear(
            self.feature_extractor.num_features,
            emb_dim,
        )

        self.target_tool_attention = nn.MultiheadAttention(
            embed_dim=emb_dim,
            batch_first=True,
            **target_tool_attention,
        )

        self.target_head = nn.Sequential(
            nn.Linear(
                emb_dim,
                self.class_num["target"],
            ),
        )

        self.ts = nn.Sequential(
            getattr(nn, temporal_cfg.type)(
                input_size=self.feature_extractor.num_features,
                hidden_size=temporal_cfg.hidden_size,
                num_layers=temporal_cfg.num_layers,
                bidirectional=temporal_cfg.bidirectional,
                batch_first=True,
            ),
        )

        self.ts_fc = nn.Linear(
            temporal_cfg.hidden_size * self.temporal_direction(),
            emb_dim,
        )

        self.verb_head = nn.Sequential(
            nn.Linear(
                emb_dim,
                self.class_num["verb"],
            ),
        )

        self.triplet_head = nn.Sequential(
            nn.Linear(
                emb_dim,
                self.class_num["triplet"],
            ),
        )

        self.criterion = torch.nn.BCEWithLogitsLoss()

        self.triplet_map = self.contstruct_triplet_map()

        self.vit_dim = self.test_dim()

        subprocess.run(["mkdir", "valid"])
        subprocess.run(["mkdir", "test"])

    def test_dim(self):
        self.feature_extractor.eval()
        x = torch.randn(1, 3, 224, 224)
        return self.feature_extractor.forward_features(x).shape[1]

    def contstruct_triplet_map(self):
        with open(os.path.join(get_original_cwd(), self.hparams.triplet_map), "r") as f:
            triplet_map = f.read().split("\n")[1:-2]

        ret = list()
        for triplet in triplet_map:
            ret.append(list(map(int, triplet.split(","))))

        return ret

    def temporal_direction(self):
        if (
            self.hparams.temporal_cfg.type is None
            or not self.hparams.temporal_cfg.bidirectional
        ):
            return 1
        else:
            return 2

    def frames_feature_extractor(
        self,
        x: torch.Tensor,
        output: torch.Tensor,
    ):
        for i in range(0, x.shape[1]):
            output[:, i, :, :] = self.feature_extractor.forward_features(
                x[:, i, :, :, :]
            )
        return output.to(self.device)

    def forward(self, x):
        output_tensor = torch.zeros(
            [x.shape[0], x.shape[1], self.vit_dim, self.feature_extractor.num_features]
        )
        feature = self.frames_feature_extractor(x, output_tensor)

        tool_seq_info = self.tool_information(feature[:, -1, :, :])

        tool_info = tool_seq_info.mean(dim=1)
        tool_logit = self.tool_head(tool_info)

        attn_feature = self.attention_pre_fc(feature[:, -1, :, :])

        attn_output, _ = self.target_tool_attention(
            attn_feature,
            tool_seq_info,
            tool_seq_info,
            need_weights=False,
        )

        target_logit = self.target_head(attn_output.mean(dim=1))

        ts_feature, _ = self.ts(feature.mean(dim=2))
        ts_feature = self.ts_fc(ts_feature)
        verb_logit = self.verb_head((ts_feature[:, -1, :] + tool_info) / 2)
        triplet_logit = self.triplet_head((ts_feature[:, -1, :] + tool_info) / 2)
        print(triplet_logit)
        print(triplet_logit.shape)

        return tool_logit, target_logit, verb_logit, triplet_logit

    def step(self, batch: Any):
        """
        batch would a be a dict might contains the following things
        *image*: the frame image
        *action*: the action [Action type 0, Action type 1, Action type 3, Action type 4]
        *tool*: the tool [Tool 0, Tool 1, ..., Tool 6]
        *phase*: the phase [phase 0, ..., phase 6]

        ex:
        image = batch["image"]
        self.forward(image)

        return

        loss: the loss by the loss_fn
        preds: the pred by our model (i guess it would be sth like preds = torch.argmax(logits, dim=-1))
        y: correspond to the task it should be action or tool
        """
        tool_logit, target_logit, verb_logit, triplet_logit = self.forward(
            batch["image"]
        )
        tool_loss = self.criterion(tool_logit, batch["tool"])
        target_loss = self.criterion(target_logit, batch["target"])
        verb_loss = self.criterion(verb_logit, batch["verb"])
        triplet_loss = self.criterion(triplet_logit, batch["triplet"])
        return (
            self.hparams.loss_weight.tool_weight * tool_loss
            + self.hparams.loss_weight.target_weight * target_loss
            + self.hparams.loss_weight.verb_weight * verb_loss
            + self.hparams.loss_weight.triplet_weight * triplet_loss,
            tool_logit,
            target_logit,
            verb_logit,
            triplet_logit,
        )

    def training_step(self, batch: Any, batch_idx: int):
        loss, tool_logit, target_logit, verb_logit, triplet_logit = self.step(batch)

        # self.train_recog_metric.update(
        #     batch["triplet"].cpu().numpy(),
        #     triplet_logit,
        # )
        self.train_tool_map(tool_logit, batch["tool"].to(torch.int))
        self.train_target_map(target_logit, batch["target"].to(torch.int))
        self.train_verb_map(verb_logit, batch["verb"].to(torch.int))
        self.train_triplet_map(triplet_logit, batch["triplet"].to(torch.int))
        self.log(
            "train/loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=False,
        )
        self.log(
            "train/tool_mAP",
            self.train_tool_map,
            on_step=True,
            on_epoch=True,
        )
        self.log(
            "train/verb_mAP",
            self.train_verb_map,
            on_step=True,
            on_epoch=True,
        )
        self.log(
            "train/target_mAP",
            self.train_target_map,
            on_step=True,
            on_epoch=True,
        )
        self.log(
            "train/triplet_mAP",
            self.train_triplet_map,
            on_step=True,
            on_epoch=True,
        )

        return loss

    def training_epoch_end(self, outputs: List[Any]):
        # self.train_tool_map.reset()
        # self.train_target_map.reset()
        # self.train_verb_map.reset()
        # self.train_triplet_map.reset()
        # ivt_result = self.train_recog_metric.compute_global_AP("ivt")
        # pprint(ivt_result["AP"])
        # self.log("train/ivt_mAP", ivt_result["mAP"])
        # self.log("train/i_mAP", self.train_recog_metric.compute_global_AP("i")["mAP"])
        # self.log("train/v_mAP", self.train_recog_metric.compute_global_AP("v")["mAP"])
        # self.log("train/t_mAP", self.train_recog_metric.compute_global_AP("t")["mAP"])
        pass

    def validation_step(self, batch: Any, batch_idx: int):
        loss, tool_logit, target_logit, verb_logit, triplet_logit = self.step(batch)

        self.valid_tool_map(tool_logit, batch["tool"].to(torch.int))
        self.valid_target_map(target_logit, batch["target"].to(torch.int))
        self.valid_verb_map(verb_logit, batch["verb"].to(torch.int))
        self.valid_triplet_map(triplet_logit, batch["triplet"].to(torch.int))

        self.log(
            "valid/tool_mAP",
            self.valid_tool_map,
            on_step=True,
            on_epoch=True,
        )
        self.log(
            "valid/verb_mAP",
            self.valid_verb_map,
            on_step=True,
            on_epoch=True,
        )
        self.log(
            "valid/target_mAP",
            self.valid_target_map,
            on_step=True,
            on_epoch=True,
        )
        self.log(
            "valid/triplet_mAP",
            self.valid_triplet_map,
            on_step=True,
            on_epoch=True,
        )

        tool_logit, target_logit, verb_logit, triplet_logit = (
            tool_logit.detach().cpu().numpy(),
            target_logit.detach().cpu().numpy(),
            verb_logit.detach().cpu().numpy(),
            triplet_logit.detach().cpu().numpy(),
        )

        self.valid_recog_metric.update(
            batch["triplet"].cpu().numpy(),
            triplet_logit,
        )
        self.log("valid/loss", loss, on_step=True, on_epoch=True, prog_bar=False)


        for i in range(len(batch['frame'])):

            subprocess.run(["touch", f'video_{batch["video"][i][3:]}.json'])
            with open(f'video_{batch["video"][i][3:]}.json', 'r+') as f:
                try:
                    data = json.load(f)
                    
                except:
                    data = {}

                data[int(batch["frame"][i])] = {
                    "recognition": triplet_logit.tolist()[i],
                    "detection": [
                        {
                            "triplet": int(np.argmax(triplet_logit[i], axis=-1)), 
                            "instrument": [
                                random.randint(0, 5), 
                                random.randint(0, 256),
                                random.randint(0, 256),
                                random.randint(0, 256),
                                random.randint(0, 256),
                            ]
                        }
                    ]
                }
                f.seek(0)
                json.dump(data, f, indent=4)
                f.truncate()           

        return loss

    def validation_epoch_end(self, outputs: List[Any]):
        # self.valid_tool_map.reset()
        # self.valid_target_map.reset()
        # self.valid_verb_map.reset()
        # self.valid_triplet_map.reset()
        ivt_result = self.valid_recog_metric.compute_global_AP("ivt")
        pprint(ivt_result["AP"])
        self.log(
            "valid/ivt_mAP",
            ivt_result["mAP"],
        )
        self.log("valid/i_mAP", self.valid_recog_metric.compute_global_AP("i")["mAP"])
        self.log("valid/v_mAP", self.valid_recog_metric.compute_global_AP("v")["mAP"])
        self.log("valid/t_mAP", self.valid_recog_metric.compute_global_AP("t")["mAP"])

    def test_step(self, batch: Any, batch_idx: int):
        loss, tool_logit, target_logit, verb_logit, triplet_logit = self.step(batch)

        tool_logit, target_logit, verb_logit, triplet_logit = (
            tool_logit.detach().cpu().numpy(),
            target_logit.detach().cpu().numpy(),
            verb_logit.detach().cpu().numpy(),
            triplet_logit.detach().cpu().numpy(),
        )

        post_tool_logit, post_target_logit, post_verb_logit = (
            np.zeros([triplet_logit.shape[0], 100]),
            np.zeros([triplet_logit.shape[0], 100]),
            np.zeros([triplet_logit.shape[0], 100]),
        )

        for i in range(triplet_logit.shape[0]):
            for index, _triplet in enumerate(self.triplet_map):
                post_tool_logit[i][index] = tool_logit[i][_triplet[1]]
                post_verb_logit[i][index] = verb_logit[i][_triplet[2]]
                post_target_logit[i][index] = target_logit[i][_triplet[3]]

        self.test_recog_metric.update(
            batch["triplet"].cpu().numpy(),
            # triplet_logit,
            triplet_logit + 0.4 * post_target_logit + 0.2 * post_verb_logit,
        )
        self.log("test/loss", loss, on_step=True, on_epoch=True, prog_bar=False)
        self.log("test/i_mAP", self.test_recog_metric.compute_global_AP("i")["mAP"])
        self.log("test/v_mAP", self.test_recog_metric.compute_global_AP("v")["mAP"])
        self.log("test/t_mAP", self.test_recog_metric.compute_global_AP("t")["mAP"])

        self.valid_tool_map(tool_logit, batch["tool"].to(torch.int))
        self.valid_target_map(target_logit, batch["target"].to(torch.int))
        self.valid_verb_map(verb_logit, batch["verb"].to(torch.int))
        self.valid_triplet_map(triplet_logit, batch["triplet"].to(torch.int))

        for i in range(len(batch['frame'])):

            subprocess.run(["touch", f'video_{batch["video"][i][3:]}.json'])
            with open(f'video_{batch["video"][i][3:]}.json', 'r+') as f:
                try:
                    data = json.load(f)
                    
                except:
                    data = {}

                data[int(batch["frame"][i])] = {
                    "recognition": triplet_logit.tolist()[i],
                    "detection": [
                        {
                            "triplet": int(np.argmax(triplet_logit[i], axis=-1)), 
                            "instrument": [
                                random.randint(0, 5), 
                                random.randint(0, 256),
                                random.randint(0, 256),
                                random.randint(0, 256),
                                random.randint(0, 256),
                            ]
                        }
                    ]
                }
                f.seek(0)
                json.dump(data, f, indent=4)
                f.truncate() 

        return loss

    def test_epoch_end(self, outputs: List[Any]):
        ivt_result = self.test_recog_metric.compute_global_AP("ivt")
        pprint(ivt_result["AP"])
        self.log(
            "test/ivt_mAP",
            ivt_result["mAP"],
        )
        self.log("test/i_mAP", self.test_recog_metric.compute_global_AP("i")["mAP"])
        self.log("test/v_mAP", self.test_recog_metric.compute_global_AP("v")["mAP"])
        self.log("test/t_mAP", self.test_recog_metric.compute_global_AP("t")["mAP"])

    def on_epoch_end(self):
        self.train_recog_metric.reset()
        self.valid_recog_metric.reset()
        self.test_recog_metric.reset()

    def configure_optimizers(self):
        opt = getattr(torch.optim, self.hparams.optim.optim_name)(
            params=self.parameters(),
            lr=self.hparams.optim.lr,
            weight_decay=self.hparams.optim.weight_decay,
        )
        lr_scheduler = getattr(
            torch.optim.lr_scheduler, self.hparams.optim.scheduler_name
        )(
            opt,
            **self.hparams.optim.scheduler,
        )
        return [opt], [lr_scheduler]