Untitled
unknown
plain_text
5 months ago
1.7 kB
3
Indexable
// Dereference: build the batch at position k_index.
//
// Positions [startIdx, endIdx) of k_load->k_indices (presumably the shuffled
// sample order - confirm with the DataLoader) are fetched from the dataset one
// sample at a time and stacked along axis 0.
//
// When drop_last is false, the leftover samples (data_len % batch_size) are
// merged into the last full batch rather than forming a short extra batch.
//
// Returns a Batch whose data tensor has the dataset's data shape with axis 0
// resized to the number of samples in this batch. The label tensor is shaped
// the same way, unless the dataset's label shape is empty; in that case the
// label tensor keeps the empty shape and is left unfilled.
Batch<DType, LType> operator*() const {
    const size_t batch_cap = static_cast<size_t>(this->k_load->batch_size);
    const size_t data_len = this->k_load->ptr_dataset->len();

    // Half-open range [startIdx, endIdx) of positions into k_indices.
    const size_t startIdx = static_cast<size_t>(this->k_index) * batch_cap;
    size_t endIdx = std::min(startIdx + batch_cap, data_len);

    // With drop_last == false the remainder is folded into the last full batch.
    // Testing full_batches > 0 first avoids the size_t wrap-around of
    // (data_len / batch_size) - 1 when data_len < batch_size.
    const size_t full_batches = data_len / batch_cap;
    if (!this->k_load->drop_last && full_batches > 0 &&
        static_cast<size_t>(this->k_index) + 1 == full_batches) {
        endIdx = data_len;
    }
    const size_t cur_batch_size = endIdx - startIdx;

    // NOTE(review): static_cast is unchecked - it assumes ptr_dataset really is a
    // TensorDataset. If getitem() is virtual on the base Dataset, call it through
    // ptr_dataset directly and drop this cast.
    TensorDataset<DType, LType>* dataset =
        static_cast<TensorDataset<DType, LType>*>(this->k_load->ptr_dataset);

    // Batch tensors take the dataset's shapes with axis 0 resized to this batch,
    // so they hold exactly cur_batch_size rows (no uninitialized tail rows).
    auto data_shape = this->k_load->ptr_dataset->get_data_shape();
    if (data_shape.size() != 0) {
        data_shape[0] = cur_batch_size;
    }
    auto label_shape = this->k_load->ptr_dataset->get_label_shape();
    const bool has_labels = label_shape.size() != 0;
    if (has_labels) {
        label_shape[0] = cur_batch_size;
    }
    xt::xarray<DType> batch_data = xt::empty<DType>(data_shape);
    xt::xarray<LType> batch_label = xt::empty<LType>(label_shape);

    for (size_t pos = startIdx; pos < endIdx; ++pos) {
        const size_t row = pos - startIdx;
        const int sample_idx = this->k_load->k_indices(pos);
        DataLabel<DType, LType> item = dataset->getitem(sample_idx);
        // xt::view(t, row) fixes axis 0 and keeps every remaining axis, so it is
        // valid for tensors of any rank >= 1 (an explicit xt::all() would need
        // rank >= 2).
        xt::view(batch_data, row) = item.getData();
        if (has_labels) {
            xt::view(batch_label, row) = item.getLabel();
        }
    }

    return Batch<DType, LType>(batch_data, batch_label);
}
Editor is loading...
Leave a Comment