# codeee
# mail@pastecode.io avatar
# unknown
# python
# a year ago
# 16 kB
# 5
# Indexable
# Never
# (paste-site header above commented out so this file is valid Python)
import json 
from os.path import join 
import glob
from sklearn import preprocessing

import pandas as pd
import numpy as np
import os
from sklearn.model_selection import StratifiedShuffleSplit
from tqdm import tqdm
os.environ['CUDA_VISIBLE_DEVICES']='0'

from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from torchvision import transforms
import torch 
from torch import nn
from PIL import Image
from collections.abc import Iterable
import logging
from tqdm import tqdm
import torch.nn.functional as F
import glob
import albumentations as A

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule, LightningModule
import random
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold, StratifiedGroupKFold

import sys 
sys.path.append("../lib/pytorch-image-models")
import timm


RANDOM_SEED = 0      # seed for the CV fold split
FOLD = 0             # fold index held out for validation
N_FOLDS = 2          # number of cross-validation folds
SIZE = 384           # input resolution (matches the convnext_*_384 backbone)
BATCH_SIZE = 32
NUM_WORKER = 4       # DataLoader worker processes
EPOCHS = 10
INIT_LR = 1e-4       # initial learning rate for AdamW
MIN_LR = 1e-6        # cosine-annealing floor (eta_min)
WEIGHT_DECAY = 1e-4


def seed_everything(seed: int) -> None:
    """Seed every RNG this script touches (python, hash, numpy, torch, CUDA).

    Parameters
    ----------
    seed : int
        Seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # BUGFIX: was True, which lets cuDNN autotune/pick algorithms
    # nondeterministically and defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False
seed_everything(42)


def get_class_name_list(json_folder_dir):
    """Scan a folder of annotation JSONs and return the unique chart types.

    Parameters
    ----------
    json_folder_dir : str
        Directory containing ``*.json`` annotation files, each holding a
        top-level ``chart-type`` key.

    Returns
    -------
    list[str]
        Sorted unique chart-type names. Sorted so the result — and any
        label encoding built from it — is deterministic across runs;
        the previous ``list(set(...))`` order was not.
    """
    label_list = []
    file_list = glob.glob(join(json_folder_dir, '*.json'))
    for file_path in tqdm(file_list):
        # json.load reads straight from the handle; no manual .read().
        with open(file_path, "r") as file:
            json_file = json.load(file)
        label_list.append(json_file['chart-type'])
    return sorted(set(label_list))


def create_label_encoder(label_name_list):
    """Return a sklearn LabelEncoder fitted on *label_name_list*."""
    # LabelEncoder.fit returns the encoder itself, so the call chains.
    return preprocessing.LabelEncoder().fit(label_name_list)


def get_label_num(Label_Name):
    """Map a chart-type name to its integer class id using the global
    ``encoder_label`` (built in ``__main__``)."""
    encoded = encoder_label.transform([Label_Name])
    return int(encoded[0])


def create_csv_data(json_folder_dir):
    """Build a DataFrame pairing each training image with its chart type.

    For every annotation JSON under *json_folder_dir*, reads the
    ``chart-type`` label and derives the matching image path by swapping
    the ``annotations`` path segment for ``images`` and the ``.json``
    suffix for ``.jpg``.

    Parameters
    ----------
    json_folder_dir : str
        Directory containing ``*.json`` annotation files.

    Returns
    -------
    pandas.DataFrame
        Columns ``Image_Path`` and ``Label_Name``, one row per annotation.
    """
    label_list = []
    img_list = []
    file_list = glob.glob(join(json_folder_dir, '*.json'))

    for file_path in tqdm(file_list):
        # json.load reads straight from the handle; no manual .read().
        with open(file_path, "r") as file:
            json_file = json.load(file)
        label_list.append(json_file['chart-type'])

        # .../annotations/xxx.json -> .../images/xxx.jpg
        # NOTE(review): str.replace would also hit 'annotations' appearing
        # elsewhere in the path — fine for this dataset layout.
        img_list.append(file_path.replace('annotations', 'images')
                                 .replace('.json', '.jpg'))

    return pd.DataFrame(list(zip(img_list, label_list)),
                        columns=['Image_Path', 'Label_Name'])



#############################
#         Dataloader        #
#############################
class Graph_Dataset(Dataset):
    """Chart-type classification dataset.

    Yields dicts ``{'image': tensor, 'labels': int}``. The albumentations
    pipeline (*train_aug*) is applied only when *train_bool* is True; the
    torchvision *transform* always runs afterwards.
    """

    def __init__(self, img_path, label_list, transform, train_aug, train_bool=True):
        self.img_path = img_path
        self.label_list = label_list
        self.transform = transform
        self.train_aug = train_aug
        self.train_bool = train_bool

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, index):
        # Load as RGB; albumentations operates on numpy arrays.
        array = np.array(Image.open(self.img_path[index]).convert('RGB'))
        target = int(self.label_list[index])

        if self.train_bool:
            array = self.train_aug(image=array)["image"]
        tensor = self.transform(array)
        return {'image': tensor, 'labels': target}

# Training-time tensor pipeline. No Resize here: the albumentations
# train_aug pipeline already resizes to SIZE x SIZE before this runs.
# NOTE(review): mean/std match the CLIP preprocessing statistics rather
# than the usual ImageNet ones — confirm this is intentional for this
# timm backbone.
transform_train = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], 
                        std=[0.26862954, 0.26130258, 0.27577711])
    ])

# Training augmentations (albumentations): resize to the model input
# size, then light flip / compression / geometric noise / cutout.
train_aug = A.Compose([
        A.Resize(SIZE, SIZE),
        A.HorizontalFlip(p=0.5),
        A.ImageCompression(quality_lower=99, quality_upper=100),
        A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=10, border_mode=0, p=0.7),
        # BUGFIX: hole size was hard-coded as int(224 * 0.4) although the
        # image is resized to SIZE (=384) above — scale the cutout with SIZE
        # so it covers the intended 40% of the image side.
        A.Cutout(max_h_size=int(SIZE * 0.4), max_w_size=int(SIZE * 0.4), num_holes=1, p=0.5),
    ])

# Validation pipeline: resize happens here (torchvision) because the
# albumentations train_aug pipeline is skipped for validation samples.
transform_val = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((SIZE,SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], 
                        std=[0.26862954, 0.26130258, 0.27577711])
    ])



#############################
#     Lightning Loader      #
#############################
class Graph_Dataloader(LightningDataModule):
    """LightningDataModule wrapping the train/val Graph_Dataset splits.

    Parameters are parallel lists: image paths and integer class ids for
    the training fold, then for the validation fold.
    """

    def __init__(self, imgs_train, label_train, imgs_val, label_val):
        super().__init__()
        self.train_transform = train_aug
        self.imgs_train = imgs_train
        self.label_train = label_train
        self.imgs_val = imgs_val
        self.label_val = label_val

    def prepare_data(self):
        # Nothing to download or pre-write; data is already on disk.
        pass

    def setup(self, *_, **__) -> None:
        self.train_dataset = Graph_Dataset(self.imgs_train, self.label_train, transform_train, train_aug, train_bool=True)
        logging.info(f"training dataset: {len(self.train_dataset)}")

        # train_aug is passed but inert: train_bool=False skips it.
        self.val_dataset = Graph_Dataset(self.imgs_val, self.label_val, transform_val, train_aug, train_bool=False)
        logging.info(f"val dataset: {len(self.val_dataset)}")

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        return torch.utils.data.DataLoader(self.train_dataset,
                                            batch_size=BATCH_SIZE,
                                            num_workers=NUM_WORKER,
                                            pin_memory=True,
                                            drop_last=False,
                                            shuffle=True)

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        # BUGFIX: was shuffle=True; the validation loader must iterate in
        # a fixed order so epoch metrics are deterministic and comparable.
        return torch.utils.data.DataLoader(self.val_dataset,
                                            batch_size=BATCH_SIZE,
                                            num_workers=NUM_WORKER,
                                            pin_memory=True,
                                            drop_last=False,
                                            shuffle=False)

#############################
#          Model            #
#############################
class ConvNext_Baseline(nn.Module):
    """timm ConvNeXt backbone with its classifier head swapped for a fresh
    Linear layer; forward returns both pooled features and class logits."""
    def __init__(self, arch='convnext_base_384_in22ft1k', num_classes=1, pretrained=False):
        super().__init__()
        self.arch = arch 
        self.num_classes = num_classes
        self.pretrained = pretrained 
        # drop_path_rate=0.2: stochastic-depth regularization inside timm.
        self.encoder = timm.create_model(self.arch, pretrained=self.pretrained, num_classes=self.num_classes, drop_path_rate=0.2)
        # Read the feature width BEFORE the head is replaced below.
        self.num_feature = self.encoder.head.fc.in_features
        self.fc = nn.Linear(self.num_feature, self.num_classes,bias=True)
        # Neutralize timm's own classifier so the encoder emits raw features.
        self.encoder.head.fc = nn.Identity()

    def forward(self, x):
        """Return {'feature': pooled backbone features, 'score': class logits}."""
        feature = self.encoder.forward(x)
        output = self.fc(feature)
        return {'feature': feature,'score': output}
    


#############################
#       Training Loop       #
#############################
from sklearn.metrics import f1_score

def get_score(y_truth, y_pred):
    """Return the macro-averaged F1 score, keyed for metric logging."""
    return {"F1-Score": f1_score(y_truth, y_pred, average='macro')}

class BaseTraining(LightningModule):
    """Cross-entropy classification LightningModule around *model*.

    Expects batches shaped like Graph_Dataset items:
    ``{'image': tensor, 'labels': tensor}``. *model*'s forward must
    return a dict with a ``'score'`` logits entry.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss = torch.nn.CrossEntropyLoss()

    def compute_loss(self, y_pred, y):
        # Cross-entropy on raw logits (CrossEntropyLoss applies log-softmax).
        return self.loss(y_pred, y)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        image, label = batch['image'], batch['labels']
        y_pred = self.forward(image)['score']
        loss = self.compute_loss(y_pred, label)

        self.log("loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        image, label = batch['image'], batch['labels']
        y_pred = self.forward(image)['score']
        loss = self.compute_loss(y_pred, label)

        # softmax is monotonic, so argmax over probabilities equals argmax
        # over logits; kept as-is for parity with the original.
        return {"preds": torch.argmax(y_pred.softmax(1), dim=1).detach().to('cpu').numpy(),
                "truth": label.clone().detach().tolist(),
                "loss": loss.to('cpu').item()}

    def validation_epoch_end(self, outputs):
        """Aggregate per-batch validation outputs into epoch F1 / loss."""
        predict_list = []
        target_list = []
        loss_list = []

        for out in outputs:
            predict_list.append(out['preds'])
            target_list.extend(out['truth'])
            loss_list.append(out['loss'])

        predict_list = np.concatenate(predict_list)
        target_list = np.array(target_list)
        score = get_score(list(target_list), list(predict_list))
        loss = np.array(loss_list).mean()

        self.log("F1", score['F1-Score'], sync_dist=True, on_epoch=True, prog_bar=True)
        self.log('Loss_Val', loss, sync_dist=True, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        # BUGFIX: the module-level WEIGHT_DECAY constant was defined but
        # never passed to the optimizer — wire it into AdamW as intended.
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=INIT_LR,
                                      weight_decay=WEIGHT_DECAY)
        # Cosine decay from INIT_LR down to eta_min=MIN_LR over max_epochs.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_epochs, eta_min=MIN_LR)
        return [optimizer], [scheduler]

if __name__ == '__main__':

    ## Labels: fixed class list (get_class_name_list over the annotations
    ## folder yields these same five chart types; hard-coding skips the scan).
    label_name = ['scatter', 'dot', 'line', 'vertical_bar', 'horizontal_bar']
    encoder_label = create_label_encoder(label_name)

    data = create_csv_data('./data/benetech-making-graphs-accessible/train/annotations')
    data['class_id'] = data['Label_Name'].apply(get_label_num)

    ## Fold splitting: stratified on class id so folds stay balanced.
    Fold = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RANDOM_SEED)
    for n, (train_index, val_index) in enumerate(Fold.split(data, data["class_id"])):
        data.loc[val_index, 'fold'] = int(n)
    data['fold'] = data['fold'].astype(int)

    train_csv = data[data["fold"] != FOLD].reset_index(drop=True)
    val_csv = data[data["fold"] == FOLD].reset_index(drop=True)

    ## Data setup
    DataModule = Graph_Dataloader(list(train_csv['Image_Path'].values), list(train_csv['class_id'].values),
                                  list(val_csv['Image_Path'].values), list(val_csv['class_id'].values))
    DataModule.setup()

    ## Model
    Backbone = ConvNext_Baseline(arch='convnext_base_384_in22ft1k',
                                 num_classes=len(label_name),
                                 pretrained=True)
    model = BaseTraining(model=Backbone)

    ## Logger
    name_logger = f'ConvNext_{SIZE}x{SIZE}_Fold{FOLD}'
    logger = pl.loggers.CSVLogger(save_dir='./results/', name=name_logger)
    print("---->", name_logger)

    ## Checkpointing: keep the two best epochs by validation F1, plus last.
    ckpt = pl.callbacks.ModelCheckpoint(monitor='F1',
                                        save_top_k=2,
                                        save_last=True,
                                        save_weights_only=True,
                                        filename='{epoch:02d}-{F1:.3f}-{Loss_Val:.3f}',
                                        mode='max')

    ## Trainer
    trainer = pl.Trainer(
            accelerator='gpu',
            move_metrics_to_cpu=False,
            gpus=[0],
            callbacks=[ckpt],
            logger=logger,
            # BUGFIX: was hard-coded to 20, silently ignoring the EPOCHS
            # constant declared at the top of the file.
            max_epochs=EPOCHS,
            precision=32,
            accumulate_grad_batches=2,
            check_val_every_n_epoch=1,
        )
    trainer.fit(model=model, datamodule=DataModule)