Untitled

 avatar
unknown
plain_text
5 months ago
1.7 kB
3
Indexable
Batch<DType, LType> operator*() const {
      // Build the batch this iterator currently points at.
      //
      // Samples at positions [startIdx, endIdx) (half-open) of
      // k_load->k_indices are gathered from the dataset. When drop_last is
      // false, the last full batch (k_index == len / batch_size - 1) is
      // extended to the end of the dataset so leftover samples are kept.
      //
      // Returns a Batch whose data has shape {rows, data_shape[1:]...} and,
      // when the dataset has labels, whose label has shape
      // {rows, label_shape[1:]...}, where rows = endIdx - startIdx.
      const size_t per_batch = this->k_load->batch_size;
      const size_t data_len = this->k_load->ptr_dataset->len();
      const size_t startIdx = this->k_index * per_batch;
      size_t endIdx = std::min(startIdx + per_batch, data_len);

      // With drop_last == false the last full batch absorbs the remainder.
      // num_full_batches > 0 keeps "num_full_batches - 1" from wrapping
      // around (size_t) when the dataset is smaller than one batch.
      const size_t num_full_batches = data_len / per_batch;
      if (!this->k_load->drop_last && num_full_batches > 0 &&
          static_cast<size_t>(this->k_index) == num_full_batches - 1) {
        endIdx = data_len;
      }

      // NOTE(review): unchecked downcast — assumes ptr_dataset really points
      // at a TensorDataset; UB otherwise. Confirm, or switch to dynamic_cast.
      TensorDataset<DType, LType>* tensor_ds =
          static_cast<TensorDataset<DType, LType>*>(this->k_load->ptr_dataset);

      const size_t rows = endIdx - startIdx;

      // Copy the dataset shape and replace its leading (sample) dimension
      // with this batch's row count; allocating the full dataset shape would
      // give every batch data_len rows, most of them uninitialised.
      auto data_shape = this->k_load->ptr_dataset->get_data_shape();
      data_shape[0] = rows;
      xt::xarray<DType> batch_data = xt::empty<DType>(data_shape);

      // An empty label shape means the dataset carries no labels; in that
      // case the batch label is left default-constructed.
      auto label_shape = this->k_load->ptr_dataset->get_label_shape();
      const bool has_label = label_shape.size() != 0;
      xt::xarray<LType> batch_label;
      if (has_label) {
        label_shape[0] = rows;
        batch_label = xt::empty<LType>(label_shape);
      }

      for (size_t i = startIdx; i < endIdx; ++i) {
        const size_t row = i - startIdx;
        // k_indices maps iteration order to dataset indices (presumably the
        // shuffled order when shuffling is enabled — confirm in the loader).
        const int sample_idx = this->k_load->k_indices(i);
        DataLabel<DType, LType> item = tensor_ds->getitem(sample_idx);
        // view(arr, row) selects one entry along axis 0 with all remaining
        // axes kept, which also works when each sample is 0-D.
        xt::view(batch_data, row) = item.getData();
        if (has_label) {
          xt::view(batch_label, row) = item.getLabel();
        }
      }

      return Batch<DType, LType>(batch_data, batch_label);
    }
Editor is loading...
Leave a Comment