Untitled
unknown
python
3 years ago
1.0 kB
7
Indexable
class SensorDataset(data.Dataset):
    """Map-style dataset of sensor windows with optional side features.

    Each item is a ``(sample, add_data_sample, target)`` triple, where
    ``sample`` is ``data[idx]`` squeezed, optionally augmented with
    ``DA_Rotation``, transposed with axes ``(1, 0)`` and cast to ``float32``.

    Args:
        data: Indexable collection of per-sample arrays supporting ``len()``,
            ``.squeeze()`` and ``.shape``. It must be non-empty, because
            ``padding_size`` is derived from the first sample. After
            ``squeeze()`` each sample must be 2-D for the ``(1, 0)`` transpose.
        add_data: Optional per-sample extra features, indexed like ``data``.
            When ``None``, ``__getitem__`` returns an empty list in its place.
        labels: Per-sample targets, indexed like ``data``.
        dataaug: If true, ``DA_Rotation`` is applied to a sample with
            probability 0.5 each time it is fetched.
    """

    def __init__(self, data, add_data, labels, dataaug=False):
        self.data = data
        self.add_data = add_data
        self.labels = labels
        self.dataaug = dataaug
        # Use the first sample's last-axis length so single-sample datasets
        # work (indexing sample 1 raised IndexError when len(data) == 1).
        self.padding_size = padding(self.data[0].shape[-1])
        # Sorted unique label values (np.unique sorts its output).
        self.classes = np.unique(labels)

    def __len__(self):
        """Return the number of samples in ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, add_data_sample, target)`` for ``idx``.

        ``idx`` may be a tensor; it is converted with ``tolist()`` first.
        ``add_data_sample`` is ``add_data[idx]``, or ``[]`` when ``add_data``
        is None.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = self.data[idx].squeeze()
        add_data_sample = [] if self.add_data is None else self.add_data[idx]

        # Augment before the transpose, i.e. in the stored (squeezed) layout.
        # NOTE(review): np.random state may be duplicated across DataLoader
        # worker processes unless reseeded (e.g. via worker_init_fn) — confirm.
        if self.dataaug and np.random.rand() > 0.5:
            sample = DA_Rotation(sample)

        sample = np.transpose(sample, (1, 0)).astype(np.float32)
        # NOTE: padding_size is not applied here; constant np.pad of the last
        # axis by padding_size on both sides is currently disabled.
        target = self.labels[idx]
        return sample, add_data_sample, target