Untitled
unknown
python
3 years ago
1.0 kB
6
Indexable
class SensorDataset(data.Dataset):
    """Map-style dataset that serves sensor samples with optional side data.

    Each item is a tuple ``(sample, add_data_sample, target)``:

    * ``sample`` -- ``data[idx]`` squeezed, optionally rotated by
      ``DA_Rotation``, transposed with axes ``(1, 0)`` and cast to
      ``float32``. The transpose requires the squeezed sample to be 2-D.
    * ``add_data_sample`` -- ``add_data[idx]``, or an empty list when
      ``add_data`` is ``None``.
    * ``target`` -- ``labels[idx]``.

    Args:
        data: Indexable collection of samples; each sample must expose
            ``.shape`` and ``.squeeze()`` (presumably a NumPy array --
            confirm against the caller).
        add_data: Optional indexable collection of per-sample auxiliary
            data, or ``None``.
        labels: Indexable collection of targets aligned with ``data``.
        dataaug: When true, each fetched sample is passed through
            ``DA_Rotation`` with probability 0.5.
    """

    def __init__(self, data, add_data, labels, dataaug=False):
        self.data = data
        self.add_data = add_data
        self.labels = labels
        self.dataaug = dataaug
        # Probe the first sample (index 0) for the trailing dimension so a
        # dataset holding a single sample can still be constructed; this
        # assumes all samples share the same trailing dimension.
        # padding_size is not used inside this class but is kept as a
        # public attribute in case external code reads it.
        self.padding_size = padding(self.data[0].shape[-1])
        self.classes = np.unique(labels)

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, add_data_sample, target)`` for index ``idx``.

        ``idx`` may be a plain index or a tensor; tensors are converted
        with ``tolist()`` before indexing.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = self.data[idx].squeeze()

        # Fall back to an empty list when no auxiliary data was supplied,
        # so the returned tuple always has three elements.
        add_data_sample = [] if self.add_data is None else self.add_data[idx]

        # Short-circuit keeps the random draw from being consumed when
        # augmentation is disabled.
        if self.dataaug and np.random.rand() > 0.5:
            sample = DA_Rotation(sample)

        # Swap the two axes and cast to float32. No padding is applied
        # here; padding_size is only stored on the instance.
        sample = np.transpose(sample, (1, 0)).astype(np.float32)

        target = self.labels[idx]
        return sample, add_data_sample, target
Editor is loading...