Untitled
unknown
plain_text
9 months ago
5.9 kB
11
Indexable
import torch
from torch.utils import data
import numpy as np
from os.path import join as pjoin
from smplx import SMPLH, SMPLX
from tokenization_hmr.utils.rotation_conversions import axis_angle_to_matrix
from tokenization_hmr.utils.skeleton import get_smplx_body_parts
def get_dataloader(hparams, split, shuffle=True, sequence_length=10):
    """Build a DataLoader for the requested split.

    Args:
        hparams: config object exposing DATA/EXP/ARCH namespaces (batch size,
            dataset lists, data root, rotation/SMPL type, debug flag).
        split: one of 'train', 'val', 'test'.
        shuffle: whether the DataLoader shuffles samples.
        sequence_length: number of consecutive poses per sample.

    Returns:
        For 'train', an endless batch generator (see `cycle`); otherwise a
        plain torch DataLoader.

    Raises:
        ValueError: if `split` is not one of the supported values.
    """
    batch_size = hparams.DATA.BATCH_SIZE
    debug = hparams.EXP.DEBUG
    data_root = hparams.DATA.DATA_ROOT
    mask_body_parts = hparams.DATA.MASK_BODY_PARTS
    rot_type = hparams.ARCH.ROT_TYPE
    smpl_type = hparams.ARCH.SMPL_TYPE
    num_workers = 1

    if split == 'train':
        ds_list = hparams.DATA.TRAINLIST.split('_')
        # A single dataset gets the trivial partition [1]; otherwise the
        # per-dataset sampling weights come from the config as strings
        # (MixedTrainDataset converts them to float).
        partition = [1] if len(ds_list) == 1 else hparams.DATA.TRAIN_PART.split('_')
        assert len(ds_list) == len(partition), "Number of datasets and partition does not match"
    elif split == 'val':
        ds_list = hparams.DATA.VALLIST.split('_')
        partition = [1 / len(ds_list)] * len(ds_list)
    elif split == 'test':
        ds_list = hparams.DATA.TESTLIST.split('_')
        partition = [1 / len(ds_list)] * len(ds_list)
    else:
        # Previously an unknown split fell through and crashed later with a
        # NameError on ds_list; fail fast with a clear message instead.
        raise ValueError(f"Unknown split '{split}', expected 'train', 'val' or 'test'")

    print(f'List of datasets for {split} --> {ds_list} with shuffle = {shuffle}')

    if len(ds_list) == 1:
        dataset = VQPoseDataset(ds_list[0], split, data_root, rot_type, smpl_type,
                                mask_body_parts, debug, sequence_length)
    elif split == 'train':
        dataset = MixedTrainDataset(ds_list, partition, split, data_root, rot_type,
                                    smpl_type, mask_body_parts, debug, sequence_length)
    else:
        dataset = ValDataset(ds_list, split, data_root, rot_type, smpl_type,
                             debug, sequence_length)

    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size,
                                         shuffle=shuffle,
                                         num_workers=num_workers,
                                         drop_last=True)
    # Training consumes batches indefinitely; wrap the loader so iteration
    # restarts automatically at epoch boundaries.
    return cycle(loader) if split == 'train' else loader
class VQPoseDataset(data.Dataset):
    """Sliding-window dataset of SMPL(-H/-X) axis-angle body-pose sequences.

    Loads poses from `<data_root>/<smpl_type>/<split>/<split>_<dt>.npz` and,
    per item, returns `sequence_length` consecutive frames together with the
    body-model vertices/joints and rotation-matrix ground truth.
    """

    # Explicit class lookup instead of eval() on config-supplied text —
    # safer, and raises a clear KeyError for unsupported model types.
    _SMPL_CLASSES = {'smplh': SMPLH, 'smplx': SMPLX}

    def __init__(self, dt, split='train', data_root='', rot_type='rotmat',
                 smpl_type='smplx', mask_body_parts=False, debug=False,
                 sequence_length=10):
        self.sequence_length = sequence_length
        self.data_root = pjoin(data_root, smpl_type, split)
        self.joints_num = 21
        self.smplx_body_parts = get_smplx_body_parts()
        self.mask_body_parts = mask_body_parts
        self.split = split
        self.smpl_type = smpl_type
        # NOTE(review): rot_type is accepted for interface compatibility but
        # is unused in this class.
        self.smpl_model = self._SMPL_CLASSES[smpl_type](
            f'./tokenization_hmr/data/body_models/{smpl_type}', num_betas=10, ext='pkl')
        # Renamed from `data` to avoid shadowing the torch.utils.data module.
        raw = np.load(pjoin(self.data_root, f'{split}_{dt}.npz'))
        total_samples = raw['pose_body'].shape[0]
        random_idx = None
        if debug:
            # Debug runs use a small random subset to keep iteration fast.
            debug_data_length = 600
            random_idx = np.random.choice(total_samples, size=debug_data_length, replace=False)
            print(f'In debug mode, processing with less data')
        self.pose_body = raw['pose_body'][random_idx] if random_idx is not None else raw['pose_body']
        self.dataset_name = f'_{dt}'
        print(f"Processing {dt} for {split} with {self.pose_body.shape[0]} samples...")

    def __len__(self):
        # Number of valid sliding-window start positions.
        return self.pose_body.shape[0] - self.sequence_length + 1

    def __getitem__(self, index):
        """Return a dict with the pose window, SMPL outputs, and rotmat GT."""
        item = {}
        pose_sequence = self.pose_body[index: index + self.sequence_length]
        pose_sequence = torch.Tensor(pose_sequence).float()
        body_vertices_sequence = []
        body_joints_sequence = []
        gt_body_joints_sequence = []
        for pose in pose_sequence:
            pose_body_aa = pose.view(-1).float()
            # Run the body model on a single frame (batch of 1).
            body_model = self.smpl_model(body_pose=pose_body_aa.view(-1, pose_body_aa.shape[0]))
            body_vertices_sequence.append(body_model.vertices[0].detach().float())
            body_joints_sequence.append(body_model.joints[0].detach().float())
            # Ground truth as per-joint 3x3 rotation matrices.
            gt_body_joints_sequence.append(axis_angle_to_matrix(pose.view(-1, 3)))
        item['pose_body_aa'] = pose_sequence.clone()
        item['body_vertices'] = torch.stack(body_vertices_sequence)
        item['body_joints'] = torch.stack(body_joints_sequence)
        item['gt_pose_body'] = torch.stack(gt_body_joints_sequence)
        item['dataset_name'] = self.dataset_name
        return item
class MixedTrainDataset(VQPoseDataset):
    """Mixture of several VQPoseDatasets sampled by configured proportions.

    Each __getitem__ draws a member dataset at random according to the
    cumulative `partition` weights, then indexes into it modulo its length.
    """

    def __init__(self, ds_list, partition, split, data_root, rot_type, smpl_type,
                 mask_body_parts, debug, sequence_length=10):
        # Initialise base attributes (SMPL model, body parts, ...) from the
        # first dataset; actual samples come from self.datasets below.
        super().__init__(ds_list[0], split, data_root, rot_type, smpl_type,
                         mask_body_parts, debug, sequence_length)
        self.ds_list = ds_list
        partition = [float(part) for part in partition]
        # Cumulative distribution used to pick a dataset per draw.
        self.partition = np.array(partition).cumsum()
        self.datasets = [VQPoseDataset(ds, split, data_root, rot_type, smpl_type,
                                       mask_body_parts, debug, sequence_length)
                         for ds in ds_list]
        # One "epoch" spans the largest member dataset.
        self.length = max(len(ds) for ds in self.datasets)

    def __getitem__(self, index):
        p = np.random.rand()
        # searchsorted finds the first cumulative weight >= p (same rule as
        # the original linear scan). Clamping guards against float rounding
        # leaving cumsum[-1] slightly below p, which in the original loop
        # silently returned None.
        i = min(int(np.searchsorted(self.partition, p)), len(self.datasets) - 1)
        return self.datasets[i][index % len(self.datasets[i])]

    def __len__(self):
        return self.length
class ValDataset(VQPoseDataset):
    """Validation/test dataset wrapper around VQPoseDataset.

    NOTE(review): only dataset_list[0] is actually loaded; the joined name is
    used just for reporting — confirm multi-dataset evaluation is intended.
    """

    def __init__(self, dataset_list, split='val', data_root='', rot_type='rotmat',
                 smpl_type='smplx', debug=False, sequence_length=10):
        # mask_body_parts is always disabled for evaluation.
        super().__init__(dataset_list[0], split, data_root, rot_type, smpl_type,
                         False, debug, sequence_length)
        self.dataset_name = '_'.join(dataset_list)

    # The redundant __getitem__ override (it only delegated to super()) was
    # removed; the inherited VQPoseDataset.__getitem__ behaves identically.
def cycle(iterable):
    """Yield items from *iterable* forever, restarting it once exhausted.

    Unlike itertools.cycle, this re-iterates the source on every pass, so a
    shuffling DataLoader produces a fresh order each epoch.
    """
    while True:
        yield from iterable
Editor is loading...
Leave a Comment