import numpy as np
import torch
from torch.utils import data
from os.path import join as pjoin

from smplx import SMPLH, SMPLX

from tokenization_hmr.utils.rotation_conversions import axis_angle_to_matrix
from tokenization_hmr.utils.skeleton import get_smplx_body_parts


def get_dataloader(hparams, split, shuffle=True, sequence_length=10):
    batch_size = hparams.DATA.BATCH_SIZE
    debug = hparams.EXP.DEBUG
    data_root = hparams.DATA.DATA_ROOT
    mask_body_parts = hparams.DATA.MASK_BODY_PARTS
    rot_type = hparams.ARCH.ROT_TYPE
    smpl_type = hparams.ARCH.SMPL_TYPE
    num_workers = 1

    if split == 'train':
        ds_list = hparams.DATA.TRAINLIST.split('_')
        partition = [1] if len(ds_list) == 1 else hparams.DATA.TRAIN_PART.split('_')
        assert len(ds_list) == len(partition), "Number of datasets and partitions does not match"
    elif split == 'val':
        ds_list = hparams.DATA.VALLIST.split('_')
        partition = [1 / len(ds_list)] * len(ds_list)
    elif split == 'test':
        ds_list = hparams.DATA.TESTLIST.split('_')
        partition = [1 / len(ds_list)] * len(ds_list)

    print(f'List of datasets for {split} --> {ds_list} with shuffle = {shuffle}')

    if len(ds_list) == 1:
        dataset = VQPoseDataset(ds_list[0], split, data_root, rot_type, smpl_type,
                                mask_body_parts, debug, sequence_length)
    else:
        if split == 'train':
            dataset = MixedTrainDataset(ds_list, partition, split, data_root, rot_type,
                                        smpl_type, mask_body_parts, debug, sequence_length)
        else:
            dataset = ValDataset(ds_list, split, data_root, rot_type, smpl_type,
                                 debug, sequence_length)

    loader = torch.utils.data.DataLoader(dataset, batch_size, shuffle=shuffle,
                                         num_workers=num_workers, drop_last=True)

    # Training uses an infinite iterator; validation/test return the plain loader.
    if split == 'train':
        return cycle(loader)
    else:
        return loader


class VQPoseDataset(data.Dataset):
    def __init__(self, dt, split='train', data_root='', rot_type='rotmat', smpl_type='smplx',
                 mask_body_parts=False, debug=False, sequence_length=10):
        self.sequence_length = sequence_length
        self.data_root = pjoin(data_root, smpl_type, split)
        self.joints_num = 21
        self.smplx_body_parts = get_smplx_body_parts()
        self.mask_body_parts = mask_body_parts
        self.split = split
        self.smpl_type = smpl_type
        # Instantiate the SMPLH or SMPLX body model matching `smpl_type`.
        self.smpl_model = eval(f'{smpl_type.upper()}')(
            f'./tokenization_hmr/data/body_models/{smpl_type}', num_betas=10, ext='pkl')

        data = np.load(pjoin(self.data_root, f'{split}_{dt}.npz'))
        total_samples = data['pose_body'].shape[0]

        random_idx = None
        if debug:
            debug_data_length = 600
            random_idx = np.random.choice(total_samples, size=debug_data_length, replace=False)
            print('In debug mode, processing with less data')
        self.pose_body = data['pose_body'][random_idx] if random_idx is not None else data['pose_body']

        self.dataset_name = f'_{dt}'
        print(f"Processing {dt} for {split} with {self.pose_body.shape[0]} samples...")

    def __len__(self):
        return self.pose_body.shape[0] - self.sequence_length + 1

    def __getitem__(self, index):
        item = {}
        pose_sequence = self.pose_body[index: index + self.sequence_length]
        pose_sequence = torch.Tensor(pose_sequence).float()

        body_vertices_sequence = []
        body_joints_sequence = []
        gt_body_joints_sequence = []
        for pose in pose_sequence:
            pose_body_aa = pose.view(-1).float()
            # Run the body model on the axis-angle body pose to get vertices and joints.
            body_model = self.smpl_model(body_pose=pose_body_aa.view(-1, pose_body_aa.shape[0]))
            body_vertices_sequence.append(body_model.vertices[0].detach().float())
            body_joints_sequence.append(body_model.joints[0].detach().float())
            gt_body_joints_sequence.append(axis_angle_to_matrix(pose.view(-1, 3)))

        item['pose_body_aa'] = pose_sequence.clone()
        item['body_vertices'] = torch.stack(body_vertices_sequence)
        item['body_joints'] = torch.stack(body_joints_sequence)
        item['gt_pose_body'] = torch.stack(gt_body_joints_sequence)
        item['dataset_name'] = self.dataset_name
        return item


class MixedTrainDataset(VQPoseDataset):
    def __init__(self, ds_list, partition, split, data_root, rot_type, smpl_type,
                 mask_body_parts, debug, sequence_length=10):
        super().__init__(ds_list[0], split, data_root, rot_type, smpl_type,
                         mask_body_parts, debug, sequence_length)
        self.ds_list = ds_list
        partition = [float(part) for part in partition]
        self.partition = np.array(partition).cumsum()
        self.datasets = [VQPoseDataset(ds, split, data_root, rot_type, smpl_type,
                                       mask_body_parts, debug, sequence_length)
                         for ds in ds_list]
        self.length = max([len(ds) for ds in self.datasets])

    def __getitem__(self, index):
        # Sample a dataset according to the cumulative partition weights,
        # then index into it modulo its own length.
        p = np.random.rand()
        for i in range(len(self.ds_list)):
            if p <= self.partition[i]:
                return self.datasets[i][index % len(self.datasets[i])]

    def __len__(self):
        return self.length


class ValDataset(VQPoseDataset):
    def __init__(self, dataset_list, split='val', data_root='', rot_type='rotmat',
                 smpl_type='smplx', debug=False, sequence_length=10):
        super().__init__(dataset_list[0], split, data_root, rot_type, smpl_type,
                         False, debug, sequence_length)
        self.dataset_name = '_'.join(dataset_list)

    def __getitem__(self, index):
        return super().__getitem__(index)


def cycle(iterable):
    while True:
        for x in iterable:
            yield x
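

# --- Hypothetical usage sketch (not part of the original file) -------------------
# A minimal example of how get_dataloader above might be driven. The hparams
# object here is a stand-in built with SimpleNamespace; the field names mirror
# the attributes read by get_dataloader, but the dataset name ('amass'), data
# root, and batch size are placeholder assumptions, and the SMPL-X model files
# are expected under ./tokenization_hmr/data/body_models/smplx as in __init__.
if __name__ == '__main__':
    from types import SimpleNamespace as NS

    hparams = NS(
        DATA=NS(BATCH_SIZE=32, DATA_ROOT='./data/pose_data', MASK_BODY_PARTS=False,
                TRAINLIST='amass', TRAIN_PART='1', VALLIST='amass', TESTLIST='amass'),
        ARCH=NS(ROT_TYPE='rotmat', SMPL_TYPE='smplx'),
        EXP=NS(DEBUG=False),
    )

    # For 'train', get_dataloader returns an infinite generator via cycle().
    train_iter = get_dataloader(hparams, 'train', shuffle=True, sequence_length=10)
    batch = next(train_iter)
    # Each batch holds a sequence_length-long window of axis-angle body poses
    # plus the body vertices/joints produced by the body-model forward pass.
    print(batch['pose_body_aa'].shape, batch['body_vertices'].shape)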