Untitled

 avatar
unknown
python
3 years ago
1.0 kB
6
Indexable
class SensorDataset(data.Dataset):
    """Map-style dataset yielding ``(sample, add_data_sample, target)`` triples.

    ``sample`` is ``data[idx]`` squeezed, optionally augmented, transposed
    with axes ``(1, 0)`` and cast to ``float32``.

    Args:
        data: Indexable collection of samples. After ``squeeze()`` each
            sample must be 2-D, because it is transposed with axes ``(1, 0)``.
        add_data: Optional indexable of per-sample auxiliary values, or
            ``None``. When ``None``, an empty list is returned in its place.
        labels: Indexable of targets aligned with ``data``. Its unique values
            are stored in ``self.classes``.
        dataaug: When truthy, ``DA_Rotation`` (a helper defined outside this
            class, presumably a rotation augmentation; confirm) is applied
            whenever ``np.random.rand() > 0.5``, i.e. to roughly half the
            accesses.
    """

    def __init__(self, data, add_data, labels, dataaug=False):
        self.data = data
        self.add_data = add_data
        self.labels = labels
        self.dataaug = dataaug
        # Pad width derived from the last dimension of the first sample, so
        # that a single-sample dataset works (assumes every sample shares the
        # same last dimension; confirm). `padding` is defined elsewhere.
        # An empty dataset has nothing to pad, so use a width of 0 instead of
        # indexing into it.
        if len(self.data) > 0:
            self.padding_size = padding(self.data[0].shape[-1])
        else:
            self.padding_size = 0
        self.classes = np.unique(labels)

    def __len__(self):
        """Return the number of samples in ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, add_data_sample, target)`` for position ``idx``.

        ``sample`` is a ``float32`` array whose two axes are swapped relative
        to the squeezed stored sample. ``add_data_sample`` is
        ``add_data[idx]``, or ``[]`` when no auxiliary data was given.
        ``target`` is ``labels[idx]``.
        """
        # Accept tensor indices by converting them to Python ints/lists.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = self.data[idx].squeeze()
        add_data_sample = self.add_data[idx] if self.add_data is not None else []

        # Random augmentation; the RNG is only consulted when enabled.
        if self.dataaug and np.random.rand() > 0.5:
            sample = DA_Rotation(sample)

        # Swap the two axes and cast to float32.
        sample = np.transpose(sample, (1, 0)).astype(np.float32)
        # NOTE: self.padding_size is currently not applied. Zero-padding of
        # the last axis via np.pad was disabled in the original code.
        target = self.labels[idx]

        return sample, add_data_sample, target
Editor is loading...