Untitled
unknown
plain_text
a year ago
1.7 kB
7
Indexable
/**
 * Materialize the batch this iterator currently points at.
 *
 * Samples are gathered in the order given by k_load->k_indices (possibly a
 * shuffled permutation), covering positions [k_index * batch_size,
 * k_index * batch_size + batch_size). When drop_last is false, the batch at
 * index (data_len / batch_size) - 1 also absorbs the trailing
 * data_len % batch_size samples so that no sample is dropped.
 *
 * The leading axis of the dataset's data/label shapes is treated as the
 * sample axis (rows are written with xt::view(..., row)).
 * NOTE(review): assumes k_load->ptr_dataset actually points to a
 * TensorDataset<DType, LType>; the downcast below is unchecked.
 *
 * @return A Batch whose data has the dataset's data shape with axis 0 resized
 *         to the number of samples n in this batch. If the dataset has labels,
 *         the label has the dataset's label shape with axis 0 resized to n.
 *         Otherwise the label is a zero-filled 0-d placeholder.
 */
Batch<DType, LType> operator*() const {
    const size_t batch_size_cfg = this->k_load->batch_size;
    const size_t data_len = this->k_load->ptr_dataset->len();

    // Half-open range [startIdx, endIdx) of positions in k_indices for this batch.
    // Widen k_index before multiplying so the product is computed in size_t.
    const size_t startIdx = static_cast<size_t>(this->k_index) * batch_size_cfg;
    size_t endIdx = std::min(startIdx + batch_size_cfg, data_len);

    // With drop_last == false, the last full batch also takes the remainder.
    // The guards avoid division by zero and the unsigned wrap of (0 - 1) that
    // would occur when data_len < batch_size.
    if (!this->k_load->drop_last && batch_size_cfg > 0 && data_len >= batch_size_cfg &&
        static_cast<size_t>(this->k_index) == data_len / batch_size_cfg - 1) {
        endIdx = data_len;
    }

    // Number of samples in this batch, clamped so an out-of-range k_index
    // cannot make (endIdx - startIdx) wrap around.
    const size_t n = (endIdx > startIdx) ? endIdx - startIdx : 0;

    // Unchecked downcast to reach TensorDataset::getitem.
    TensorDataset<DType, LType>* tensor_ds =
        static_cast<TensorDataset<DType, LType>*>(this->k_load->ptr_dataset);

    // Batch shapes are the dataset shapes with the sample axis (axis 0)
    // resized to n. The shape is resized, not the array's first element.
    auto data_shape = this->k_load->ptr_dataset->get_data_shape();
    if (!data_shape.empty()) {
        data_shape[0] = n;
    }
    xt::xarray<DType> batch_data = xt::empty<DType>(data_shape);

    auto label_shape = this->k_load->ptr_dataset->get_label_shape();
    const bool has_label = !label_shape.empty();
    xt::xarray<LType> batch_label;
    if (has_label) {
        label_shape[0] = n;
        batch_label = xt::empty<LType>(label_shape);
    } else {
        // No labels: keep a 0-d placeholder, zero-filled so it is never read uninitialized.
        batch_label = xt::zeros<LType>(label_shape);
    }

    for (size_t pos = startIdx; pos < endIdx; ++pos) {
        const size_t row = pos - startIdx;
        // k_indices maps a loader position to the dataset index to fetch.
        const int sample_idx = this->k_load->k_indices(pos);
        DataLabel<DType, LType> item = tensor_ds->getitem(sample_idx);
        // A single index selects along axis 0; trailing axes are implicitly
        // xt::all(), which also works when the batch array is 1-D.
        xt::view(batch_data, row) = item.getData();
        if (has_label) {
            xt::view(batch_label, row) = item.getLabel();
        }
    }

    return Batch<DType, LType>(batch_data, batch_label);
}
Leave a Comment