import torch
from torch.utils import data
import numpy as np
from os.path import join as pjoin
from smplx import SMPLH, SMPLX
from tokenization_hmr.utils.rotation_conversions import axis_angle_to_matrix
from tokenization_hmr.utils.skeleton import get_smplx_body_parts

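# Build a DataLoader for the requested split ('train', 'val' or 'test').
# When training on several datasets, they are mixed with per-dataset sampling
# weights; the returned train loader is wrapped in cycle() so it can be
# iterated indefinitely with next().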
def get_dataloader(hparams, split, shuffle=True, sequence_length=10):
    batch_size = hparams.DATA.BATCH_SIZE
    debug = hparams.EXP.DEBUG
    data_root = hparams.DATA.DATA_ROOT
    mask_body_parts = hparams.DATA.MASK_BODY_PARTS
    rot_type = hparams.ARCH.ROT_TYPE
    smpl_type = hparams.ARCH.SMPL_TYPE
    num_workers = 1

    if split == 'train':
        ds_list = hparams.DATA.TRAINLIST.split('_')
        partition = [1] if len(ds_list) == 1 else hparams.DATA.TRAIN_PART.split('_')
        assert len(ds_list) == len(partition), "Number of datasets and partition weights does not match"
    elif split == 'val':
        ds_list = hparams.DATA.VALLIST.split('_')
        partition = [1 / len(ds_list)] * len(ds_list)
    elif split == 'test':
        ds_list = hparams.DATA.TESTLIST.split('_')
        partition = [1 / len(ds_list)] * len(ds_list)
    else:
        raise ValueError(f"Unknown split: {split}")

    print(f'List of datasets for {split} --> {ds_list} with shuffle = {shuffle}')

    if len(ds_list) == 1:
        dataset = VQPoseDataset(ds_list[0], split, data_root, rot_type, smpl_type, mask_body_parts, debug, sequence_length)
    else:
        if split == 'train':
            dataset = MixedTrainDataset(ds_list, partition, split, data_root, rot_type, smpl_type, mask_body_parts, debug, sequence_length)
        else:
            dataset = ValDataset(ds_list, split, data_root, rot_type, smpl_type, debug, sequence_length)

    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size,
                                         shuffle=shuffle,
                                         num_workers=num_workers,
                                         drop_last=True)
    if split == 'train':
        return cycle(loader)
    else:
        return loader

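# Dataset of body-pose sequences stored as axis-angle parameters in an .npz
# file ('{split}_{dt}.npz'). Samples are overlapping sliding windows of length
# sequence_length; for every frame the SMPL-H/SMPL-X body model is run to get
# posed vertices and joints, and the pose is also returned as rotation matrices.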
class VQPoseDataset(data.Dataset):
    def __init__(self, dt, split='train', data_root='', rot_type='rotmat', smpl_type='smplx', mask_body_parts=False, debug=False, sequence_length=10):
        self.sequence_length = sequence_length
        self.data_root = pjoin(data_root, smpl_type, split)
        self.joints_num = 21
        self.smplx_body_parts = get_smplx_body_parts()
        self.mask_body_parts = mask_body_parts
        self.split = split
        self.smpl_type = smpl_type

        smpl_model_class = {'smplh': SMPLH, 'smplx': SMPLX}[smpl_type]
        self.smpl_model = smpl_model_class(f'./tokenization_hmr/data/body_models/{smpl_type}', num_betas=10, ext='pkl')
        data = np.load(pjoin(self.data_root, f'{split}_{dt}.npz'))
        total_samples = data['pose_body'].shape[0]

        random_idx = None
        if debug:
            # Debug mode: work with a small random subset for faster iteration.
            debug_data_length = min(600, total_samples)
            random_idx = np.random.choice(total_samples, size=debug_data_length, replace=False)
            print(f'Debug mode: using {debug_data_length} of {total_samples} samples')

        self.pose_body = data['pose_body'][random_idx] if random_idx is not None else data['pose_body']
        self.dataset_name = f'_{dt}'
        print(f"Processing {dt} for {split} with {self.pose_body.shape[0]} samples...")

    def __len__(self):
        return self.pose_body.shape[0] - self.sequence_length + 1

    def __getitem__(self, index):
        item = {}
        pose_sequence = self.pose_body[index: index + self.sequence_length]
        pose_sequence = torch.Tensor(pose_sequence).float()
        
        body_vertices_sequence = []
        body_joints_sequence = []
        
        gt_body_joints_sequence = []

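        # Run the body model once per frame: collect posed vertices, 3D joints,
        # and the body pose converted from axis-angle to rotation matrices.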
        for pose in pose_sequence:
            pose_body_aa = pose.view(-1).float()
            body_model = self.smpl_model(body_pose=pose_body_aa.view(-1, pose_body_aa.shape[0]))
            body_vertices_sequence.append(body_model.vertices[0].detach().float())
            body_joints_sequence.append(body_model.joints[0].detach().float())
            gt_body_joints_sequence.append(axis_angle_to_matrix(pose.view(-1, 3)))
        
        item['pose_body_aa'] = pose_sequence.clone()
        item['body_vertices'] = torch.stack(body_vertices_sequence)
        item['body_joints'] = torch.stack(body_joints_sequence)
        item['gt_pose_body'] = torch.stack(gt_body_joints_sequence)
        item['dataset_name'] = self.dataset_name
        
        return item

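# Mixture over several training datasets: every __getitem__ call draws a
# source dataset according to the cumulative partition weights and indexes
# into it, so the effective epoch length is that of the largest dataset.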
class MixedTrainDataset(VQPoseDataset):
    def __init__(self, ds_list, partition, split, data_root, rot_type, smpl_type, mask_body_parts, debug, sequence_length=10):
        super().__init__(ds_list[0], split, data_root, rot_type, smpl_type, mask_body_parts, debug, sequence_length)
        self.ds_list = ds_list
        partition = [float(part) for part in partition]
        self.partition = np.array(partition).cumsum()
        self.datasets = [VQPoseDataset(ds, split, data_root, rot_type, smpl_type, mask_body_parts, debug, sequence_length) for ds in ds_list]
        self.length = max([len(ds) for ds in self.datasets])

    def __getitem__(self, index):
        # Pick a source dataset according to the cumulative partition weights,
        # then index into it (wrapping around shorter datasets).
        p = np.random.rand()
        for i in range(len(self.ds_list)):
            if p <= self.partition[i]:
                return self.datasets[i][index % len(self.datasets[i])]
        # Guard against partition weights that do not sum to exactly 1.
        return self.datasets[-1][index % len(self.datasets[-1])]

    def __len__(self):
        return self.length

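# Validation/test wrapper over a list of dataset names. Only the first dataset
# in the list is actually loaded; dataset_name records the full list.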
class ValDataset(VQPoseDataset):
    def __init__(self, dataset_list, split='val', data_root='', rot_type='rotmat', smpl_type='smplx', debug=False, sequence_length=10):
        super().__init__(dataset_list[0], split, data_root, rot_type, smpl_type, False, debug, sequence_length)
        self.dataset_name = '_'.join(dataset_list)

    def __getitem__(self, index):
        return super().__getitem__(index)

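# Infinite iterator over a finite DataLoader, used for the training loader.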
def cycle(iterable):
    while True:
        for x in iterable:
            yield x
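

# -----------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original code): builds a minimal
# hparams object exposing only the attributes read above and draws one batch
# from the training loader. The dataset name 'amass', the data root, and the
# SimpleNamespace-based config are illustrative assumptions; the real project
# presumably supplies hparams through its own config system.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    hparams = SimpleNamespace(
        DATA=SimpleNamespace(
            BATCH_SIZE=4,
            DATA_ROOT='./data/poses',   # assumed root containing <smpl_type>/<split>/
            MASK_BODY_PARTS=False,
            TRAINLIST='amass',          # assumed name -> loads train_amass.npz
            TRAIN_PART='1',
            VALLIST='amass',
            TESTLIST='amass',
        ),
        EXP=SimpleNamespace(DEBUG=True),
        ARCH=SimpleNamespace(ROT_TYPE='rotmat', SMPL_TYPE='smplx'),
    )

    train_loader = get_dataloader(hparams, 'train')  # infinite generator
    batch = next(train_loader)
    # Expected shapes, assuming 21 body joints stored as flattened axis-angle:
    # pose_body_aa: (4, 10, 63), gt_pose_body: (4, 10, 21, 3, 3)
    print({k: v.shape for k, v in batch.items() if torch.is_tensor(v)})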