import os
from typing import Any, List
from omegaconf import DictConfig
import timm
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
from torchmetrics import Precision
import ivtmetrics
from hydra.utils import get_original_cwd
import numpy as np
from pprint import pprint
import json
import random
# assert timm.__version__ == "0.6.2.dev0", "Unsupported timm version"
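# LightningModule for surgical action-triplet recognition on CholecT45:
# an (optionally frozen) ViT/Swin backbone extracts per-frame token
# features, a tool head and a tool->target cross-attention operate on the
# last frame, and a recurrent temporal module over the clip drives the
# verb and triplet heads.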
class TripletAttentionModule(LightningModule):
def __init__(
self,
temporal_cfg: DictConfig,
optim: DictConfig,
loss_weight: DictConfig,
tool_component: DictConfig,
target_tool_attention: DictConfig,
use_pretrained: bool = True,
emb_dim: int = 256,
backbone_model: str = "",
backbone_trainable: bool = True,
triplet_map: str = "./data/CholecT45/dict/maps.txt",
):
super().__init__()
self.save_hyperparameters(logger=False)
self.train_recog_metric = ivtmetrics.Recognition(num_class=100)
self.valid_recog_metric = ivtmetrics.Recognition(num_class=100)
self.test_recog_metric = ivtmetrics.Recognition(num_class=100)
self.class_num = {
"tool": 6,
"verb": 10,
"target": 15,
"triplet": 100,
}
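        # ivtmetrics.Recognition tracks the official CholecT45 AP scores.
        # The torchmetrics Precision objects below are macro-averaged
        # precision proxies logged under "*_mAP" names; note that raw
        # logits are passed to them, while some torchmetrics versions
        # expect probabilities.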
self.train_tool_map = Precision(
num_classes=self.class_num["tool"], average="macro",
)
self.train_verb_map = Precision(
num_classes=self.class_num["verb"], average="macro",
)
self.train_target_map = Precision(
num_classes=self.class_num["target"], average="macro",
)
self.train_triplet_map = Precision(
num_classes=self.class_num["triplet"], average="macro",
)
self.valid_tool_map = Precision(
num_classes=self.class_num["tool"], average="macro",
)
self.valid_verb_map = Precision(
num_classes=self.class_num["verb"], average="macro",
)
self.valid_target_map = Precision(
num_classes=self.class_num["target"], average="macro",
)
self.valid_triplet_map = Precision(
num_classes=self.class_num["triplet"], average="macro",
)
        assert (
            "vit" in backbone_model or "swin" in backbone_model
        ), "Only vision-transformer-based backbones (ViT/Swin) are supported"
self.feature_extractor = timm.create_model(
backbone_model,
pretrained=use_pretrained,
in_chans=3,
num_classes=0,
)
        # Freeze the backbone unless it is configured to be trainable.
        if not backbone_trainable:
            for p in self.feature_extractor.parameters():
                p.requires_grad = False
self.tool_information = nn.Sequential(
nn.Linear(
self.feature_extractor.num_features,
emb_dim,
),
nn.Dropout(p=tool_component.dropout_ratio),
)
self.tool_head = nn.Sequential(
nn.Linear(
emb_dim,
self.class_num["tool"],
),
)
self.attention_pre_fc = nn.Linear(
self.feature_extractor.num_features,
emb_dim,
)
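        # Cross-attention: the projected frame tokens act as queries and
        # the tool embedding tokens as keys/values, yielding a
        # tool-conditioned target representation.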
self.target_tool_attention = nn.MultiheadAttention(
embed_dim=emb_dim,
batch_first=True,
**target_tool_attention,
)
self.target_head = nn.Sequential(
nn.Linear(
emb_dim,
self.class_num["target"],
),
)
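        # Temporal module: temporal_cfg.type names a torch.nn recurrent
        # class (e.g. "LSTM" or "GRU") that is resolved via getattr.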
self.ts = nn.Sequential(
getattr(nn, temporal_cfg.type)(
input_size=self.feature_extractor.num_features,
hidden_size=temporal_cfg.hidden_size,
num_layers=temporal_cfg.num_layers,
bidirectional=temporal_cfg.bidirectional,
batch_first=True,
),
)
self.ts_fc = nn.Linear(
temporal_cfg.hidden_size * self.temporal_direction(),
emb_dim,
)
self.verb_head = nn.Sequential(
nn.Linear(
emb_dim,
self.class_num["verb"],
),
)
self.triplet_head = nn.Sequential(
nn.Linear(
emb_dim,
self.class_num["triplet"],
),
)
self.criterion = torch.nn.BCEWithLogitsLoss()
        self.triplet_map = self.construct_triplet_map()
self.vit_dim = self.test_dim()
        # Output folders for the per-video prediction JSON files.
        os.makedirs("valid", exist_ok=True)
        os.makedirs("test", exist_ok=True)
def test_dim(self):
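        """Probe the backbone with a dummy frame to get its token count."""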
self.feature_extractor.eval()
x = torch.randn(1, 3, 224, 224)
return self.feature_extractor.forward_features(x).shape[1]
    def construct_triplet_map(self):
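        """Parse the CholecT45 maps file into rows of integer ids.

        Each row is assumed to hold the triplet id followed by its
        component ids (instrument, verb, target, ...); the header line and
        trailing lines of the file are skipped.
        """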
with open(os.path.join(get_original_cwd(), self.hparams.triplet_map), "r") as f:
triplet_map = f.read().split("\n")[1:-2]
ret = list()
for triplet in triplet_map:
ret.append(list(map(int, triplet.split(","))))
return ret
def temporal_direction(self):
if (
self.hparams.temporal_cfg.type is None
or not self.hparams.temporal_cfg.bidirectional
):
return 1
else:
return 2
def frames_feature_extractor(
self,
x: torch.Tensor,
output: torch.Tensor,
):
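        """Run the backbone on every frame of the clip.

        `x` has shape [batch, time, channels, height, width]; `output` is
        filled in place with per-frame token features of shape
        [batch, time, tokens, features] and returned.
        """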
        for i in range(x.shape[1]):
            output[:, i, :, :] = self.feature_extractor.forward_features(
                x[:, i, :, :, :]
            )
        return output
    def forward(self, x):
        # Allocate the feature buffer on the model's device so the
        # per-frame assignments do not mix CPU and GPU tensors.
        feature = self.frames_feature_extractor(
            x,
            torch.zeros(
                [
                    x.shape[0],
                    x.shape[1],
                    self.vit_dim,
                    self.feature_extractor.num_features,
                ],
                device=self.device,
            ),
        )
        # Tool branch: embed the last frame's tokens and pool them.
        tool_seq_info = self.tool_information(feature[:, -1, :, :])
        tool_info = tool_seq_info.mean(dim=1)
        tool_logit = self.tool_head(tool_info)
        # Target branch: projected frame tokens attend over the tool embedding.
        attn_feature = self.attention_pre_fc(feature[:, -1, :, :])
        attn_output, _ = self.target_tool_attention(
            attn_feature,
            tool_seq_info,
            tool_seq_info,
            need_weights=False,
        )
        target_logit = self.target_head(attn_output.mean(dim=1))
        # Verb/triplet branch: token-pooled frame features through the
        # recurrent temporal module, fused with the tool embedding.
        ts_feature, _ = self.ts(feature.mean(dim=2))
        ts_feature = self.ts_fc(ts_feature)
        verb_logit = self.verb_head((ts_feature[:, -1, :] + tool_info) / 2)
        triplet_logit = self.triplet_head((ts_feature[:, -1, :] + tool_info) / 2)
        return tool_logit, target_logit, verb_logit, triplet_logit
    def step(self, batch: Any):
        """Shared step for train/valid/test.

        `batch` is a dict with (at least) the following keys:
            *image*: the input frame sequence
            *tool*, *verb*, *target*, *triplet*: multi-hot label vectors
                with 6, 10, 15 and 100 classes respectively

        Returns the weighted sum of the four BCE losses together with the
        tool, target, verb and triplet logits.
        """
tool_logit, target_logit, verb_logit, triplet_logit = self.forward(
batch["image"]
)
tool_loss = self.criterion(tool_logit, batch["tool"])
target_loss = self.criterion(target_logit, batch["target"])
verb_loss = self.criterion(verb_logit, batch["verb"])
triplet_loss = self.criterion(triplet_logit, batch["triplet"])
return (
self.hparams.loss_weight.tool_weight * tool_loss
+ self.hparams.loss_weight.target_weight * target_loss
+ self.hparams.loss_weight.verb_weight * verb_loss
+ self.hparams.loss_weight.triplet_weight * triplet_loss,
tool_logit,
target_logit,
verb_logit,
triplet_logit,
)
def training_step(self, batch: Any, batch_idx: int):
loss, tool_logit, target_logit, verb_logit, triplet_logit = self.step(batch)
# self.train_recog_metric.update(
# batch["triplet"].cpu().numpy(),
# triplet_logit,
# )
self.train_tool_map(tool_logit, batch["tool"].to(torch.int))
self.train_target_map(target_logit, batch["target"].to(torch.int))
self.train_verb_map(verb_logit, batch["verb"].to(torch.int))
self.train_triplet_map(triplet_logit, batch["triplet"].to(torch.int))
self.log(
"train/loss",
loss,
on_step=True,
on_epoch=True,
prog_bar=False,
)
self.log(
"train/tool_mAP",
self.train_tool_map,
on_step=True,
on_epoch=True,
)
self.log(
"train/verb_mAP",
self.train_verb_map,
on_step=True,
on_epoch=True,
)
self.log(
"train/target_mAP",
self.train_target_map,
on_step=True,
on_epoch=True,
)
self.log(
"train/triplet_mAP",
self.train_triplet_map,
on_step=True,
on_epoch=True,
)
return loss
def training_epoch_end(self, outputs: List[Any]):
# self.train_tool_map.reset()
# self.train_target_map.reset()
# self.train_verb_map.reset()
# self.train_triplet_map.reset()
# ivt_result = self.train_recog_metric.compute_global_AP("ivt")
# pprint(ivt_result["AP"])
# self.log("train/ivt_mAP", ivt_result["mAP"])
# self.log("train/i_mAP", self.train_recog_metric.compute_global_AP("i")["mAP"])
# self.log("train/v_mAP", self.train_recog_metric.compute_global_AP("v")["mAP"])
# self.log("train/t_mAP", self.train_recog_metric.compute_global_AP("t")["mAP"])
pass
def validation_step(self, batch: Any, batch_idx: int):
loss, tool_logit, target_logit, verb_logit, triplet_logit = self.step(batch)
self.valid_tool_map(tool_logit, batch["tool"].to(torch.int))
self.valid_target_map(target_logit, batch["target"].to(torch.int))
self.valid_verb_map(verb_logit, batch["verb"].to(torch.int))
self.valid_triplet_map(triplet_logit, batch["triplet"].to(torch.int))
self.log(
"valid/tool_mAP",
self.valid_tool_map,
on_step=True,
on_epoch=True,
)
self.log(
"valid/verb_mAP",
self.valid_verb_map,
on_step=True,
on_epoch=True,
)
self.log(
"valid/target_mAP",
self.valid_target_map,
on_step=True,
on_epoch=True,
)
self.log(
"valid/triplet_mAP",
self.valid_triplet_map,
on_step=True,
on_epoch=True,
)
tool_logit, target_logit, verb_logit, triplet_logit = (
tool_logit.detach().cpu().numpy(),
target_logit.detach().cpu().numpy(),
verb_logit.detach().cpu().numpy(),
triplet_logit.detach().cpu().numpy(),
)
self.valid_recog_metric.update(
batch["triplet"].cpu().numpy(),
triplet_logit,
)
self.log("valid/loss", loss, on_step=True, on_epoch=True, prog_bar=False)
        # Append this batch's predictions to a per-video JSON file.
        for i in range(len(batch["frame"])):
            json_path = f'video_{batch["video"][i][3:]}.json'
            try:
                with open(json_path, "r") as f:
                    data = json.load(f)
            except (FileNotFoundError, json.JSONDecodeError):
                data = {}
            data[int(batch["frame"][i])] = {
                "recognition": triplet_logit[i].tolist(),
                "detection": [
                    {
                        "triplet": int(np.argmax(triplet_logit[i], axis=-1)),
                        # Placeholder detection: random instrument id and box.
                        "instrument": [
                            random.randint(0, 5),
                            random.randint(0, 256),
                            random.randint(0, 256),
                            random.randint(0, 256),
                            random.randint(0, 256),
                        ],
                    }
                ],
            }
            with open(json_path, "w") as f:
                json.dump(data, f, indent=4)
return loss
def validation_epoch_end(self, outputs: List[Any]):
# self.valid_tool_map.reset()
# self.valid_target_map.reset()
# self.valid_verb_map.reset()
# self.valid_triplet_map.reset()
ivt_result = self.valid_recog_metric.compute_global_AP("ivt")
pprint(ivt_result["AP"])
self.log(
"valid/ivt_mAP",
ivt_result["mAP"],
)
self.log("valid/i_mAP", self.valid_recog_metric.compute_global_AP("i")["mAP"])
self.log("valid/v_mAP", self.valid_recog_metric.compute_global_AP("v")["mAP"])
self.log("valid/t_mAP", self.valid_recog_metric.compute_global_AP("t")["mAP"])
    def test_step(self, batch: Any, batch_idx: int):
        loss, tool_logit, target_logit, verb_logit, triplet_logit = self.step(batch)
        # Update the torchmetrics while the logits are still tensors.
        self.valid_tool_map(tool_logit, batch["tool"].to(torch.int))
        self.valid_target_map(target_logit, batch["target"].to(torch.int))
        self.valid_verb_map(verb_logit, batch["verb"].to(torch.int))
        self.valid_triplet_map(triplet_logit, batch["triplet"].to(torch.int))
        tool_logit, target_logit, verb_logit, triplet_logit = (
            tool_logit.detach().cpu().numpy(),
            target_logit.detach().cpu().numpy(),
            verb_logit.detach().cpu().numpy(),
            triplet_logit.detach().cpu().numpy(),
        )
        # Broadcast each component logit to the 100 triplet classes through
        # the triplet map so it can be fused with the direct triplet logits.
        post_tool_logit, post_target_logit, post_verb_logit = (
            np.zeros([triplet_logit.shape[0], 100]),
            np.zeros([triplet_logit.shape[0], 100]),
            np.zeros([triplet_logit.shape[0], 100]),
        )
        for i in range(triplet_logit.shape[0]):
            for index, _triplet in enumerate(self.triplet_map):
                post_tool_logit[i][index] = tool_logit[i][_triplet[1]]
                post_verb_logit[i][index] = verb_logit[i][_triplet[2]]
                post_target_logit[i][index] = target_logit[i][_triplet[3]]
        # Late fusion of direct triplet logits with mapped component logits.
        self.test_recog_metric.update(
            batch["triplet"].cpu().numpy(),
            triplet_logit + 0.4 * post_target_logit + 0.2 * post_verb_logit,
        )
        self.log("test/loss", loss, on_step=True, on_epoch=True, prog_bar=False)
        # Append this batch's predictions to a per-video JSON file.
        for i in range(len(batch["frame"])):
            json_path = f'video_{batch["video"][i][3:]}.json'
            try:
                with open(json_path, "r") as f:
                    data = json.load(f)
            except (FileNotFoundError, json.JSONDecodeError):
                data = {}
            data[int(batch["frame"][i])] = {
                "recognition": triplet_logit[i].tolist(),
                "detection": [
                    {
                        "triplet": int(np.argmax(triplet_logit[i], axis=-1)),
                        # Placeholder detection: random instrument id and box.
                        "instrument": [
                            random.randint(0, 5),
                            random.randint(0, 256),
                            random.randint(0, 256),
                            random.randint(0, 256),
                            random.randint(0, 256),
                        ],
                    }
                ],
            }
            with open(json_path, "w") as f:
                json.dump(data, f, indent=4)
return loss
def test_epoch_end(self, outputs: List[Any]):
ivt_result = self.test_recog_metric.compute_global_AP("ivt")
pprint(ivt_result["AP"])
self.log(
"test/ivt_mAP",
ivt_result["mAP"],
)
self.log("test/i_mAP", self.test_recog_metric.compute_global_AP("i")["mAP"])
self.log("test/v_mAP", self.test_recog_metric.compute_global_AP("v")["mAP"])
self.log("test/t_mAP", self.test_recog_metric.compute_global_AP("t")["mAP"])
def on_epoch_end(self):
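        # The ivtmetrics accumulators are plain Python objects (not
        # torchmetrics), so they are reset manually at every epoch end.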
self.train_recog_metric.reset()
self.valid_recog_metric.reset()
self.test_recog_metric.reset()
def configure_optimizers(self):
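        # Resolve the optimizer and scheduler classes by name from the
        # Hydra config (e.g. "AdamW" / "StepLR").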
opt = getattr(torch.optim, self.hparams.optim.optim_name)(
params=self.parameters(),
lr=self.hparams.optim.lr,
weight_decay=self.hparams.optim.weight_decay,
)
lr_scheduler = getattr(
torch.optim.lr_scheduler, self.hparams.optim.scheduler_name
)(
opt,
**self.hparams.optim.scheduler,
)
return [opt], [lr_scheduler]
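

# ----------------------------------------------------------------------
# Usage sketch (illustrative only; every config value below is a
# hypothetical example and must match your actual Hydra config):
#
#   from omegaconf import OmegaConf
#
#   module = TripletAttentionModule(
#       temporal_cfg=OmegaConf.create(
#           {"type": "LSTM", "hidden_size": 512,
#            "num_layers": 2, "bidirectional": True}
#       ),
#       optim=OmegaConf.create(
#           {"optim_name": "AdamW", "lr": 1e-4, "weight_decay": 1e-2,
#            "scheduler_name": "StepLR", "scheduler": {"step_size": 10}}
#       ),
#       loss_weight=OmegaConf.create(
#           {"tool_weight": 1.0, "target_weight": 1.0,
#            "verb_weight": 1.0, "triplet_weight": 1.0}
#       ),
#       tool_component=OmegaConf.create({"dropout_ratio": 0.1}),
#       target_tool_attention=OmegaConf.create({"num_heads": 4}),
#       backbone_model="swin_base_patch4_window7_224",
#   )
#
# Note: construct_triplet_map() calls hydra.utils.get_original_cwd(), so
# the module must be instantiated inside a Hydra application run.
# ----------------------------------------------------------------------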