Untitled

mail@pastecode.io avatar
unknown
c_cpp
5 months ago
13 kB
1
Indexable
#ifndef GPU_BOOTSTRAP_FFT_CUH
#define GPU_BOOTSTRAP_FFT_CUH

#include "polynomial/functions.cuh"
#include "polynomial/parameters.cuh"
#include "twiddles.cuh"
#include "types/complex/operations.cuh"

/*
 * DUFFS_DEVICE(count, loop_body)
 *
 * Executes `loop_body` exactly `count` times using Duff's device: the switch
 * jumps into the middle of an 8-way unrolled do/while so the remainder
 * (count % 8) is consumed on the first pass and every later pass runs all
 * eight copies.
 *
 *   count     - integer expression, evaluated once.
 *   loop_body - statement (without trailing ';') repeated `count` times.
 *
 * A non-positive count executes nothing. Without the explicit guard a count
 * of 0 would enter at `case 0` and run the body eight times.
 */
#define DUFFS_DEVICE(count, loop_body)   \
  do {                                   \
    const int _count = (count);          \
    if (_count > 0) {                    \
      int _n = (_count + 7) / 8;         \
      switch (_count % 8) {              \
      case 0: do { loop_body;            \
      case 7:      loop_body;            \
      case 6:      loop_body;            \
      case 5:      loop_body;            \
      case 4:      loop_body;            \
      case 3:      loop_body;            \
      case 2:      loop_body;            \
      case 1:      loop_body;            \
              } while (--_n > 0);        \
      }                                  \
    }                                    \
  } while (false)

/*
 * One level-1 butterfly of the direct FFT.
 *
 * Expects `A`, `tid`, `i1`, `i2`, `u`, `v` and `params` to be in scope.
 * Pairs element `tid` with element `tid + params::degree / 2`, multiplies the
 * upper element by the single level-1 twiddle (real and imaginary parts both
 * approximately sqrt(2)/2), writes sum/difference back, and advances `tid`
 * by params::degree / params::opt.
 *
 * The twiddle is built with make_double2 rather than a C99 compound literal,
 * which is not standard C++.
 */
#define NSMFFT_DIRECT_LEVEL_1_LOOP_BODY                            \
  do {                                                             \
    i1 = tid;                                                      \
    i2 = tid + params::degree / 2;                                 \
                                                                   \
    u = A[i1];                                                     \
    v = A[i2] * make_double2(0.707106781186547461715008466854,     \
                             0.707106781186547461715008466854);    \
                                                                   \
    A[i1] += v;                                                    \
    A[i2] = u - v;                                                 \
                                                                   \
    tid += params::degree / params::opt;                           \
  } while (false)

/*
 * Runs the complete level-1 stage of the direct FFT for the calling thread:
 * params::opt / 2 level-1 butterflies followed by a block-wide barrier.
 *
 * `tid` is reset to threadIdx.x first, matching the other level macros, so
 * the stage does not depend on whatever value `tid` held beforehand.
 * Because it ends in __syncthreads(), it must be reached by every thread of
 * the block.
 */
#define NSMFFT_DIRECT_LEVEL_1()                                     \
  do {                                                              \
    tid = threadIdx.x;                                              \
    DUFFS_DEVICE(params::opt / 2, NSMFFT_DIRECT_LEVEL_1_LOOP_BODY); \
    __syncthreads();                                                \
  } while (false)

/*
 * One butterfly of the direct FFT at the given level (level >= 2 in this
 * file).
 *
 * Expects `A`, `negtwiddles`, `tid`, `twid_id`, `i1`, `i2`, `u`, `v`, `w`
 * and `params` to be in scope.
 *
 * With span = params::degree / 2^level:
 *   twid_id = tid / span                 (which chunk this pair belongs to)
 *   i1      = 2 * span * twid_id + (tid & (span - 1))
 *             (the mask acts as tid % span when span is a power of two)
 *   i2      = i1 + span
 * The twiddle is negtwiddles[twid_id + 2^level / 2]. After the butterfly
 * `tid` advances by params::degree / params::opt.
 */
#define NSMFFT_DIRECT_LEVEL_LOOP_BODY(level)            \
  do {                                                  \
    const int _pow2 = 1 << (level);                     \
    const auto _span = params::degree / _pow2;          \
                                                        \
    twid_id = tid / _span;                              \
    i1 = 2 * _span * twid_id + (tid & (_span - 1));     \
    i2 = i1 + _span;                                    \
                                                        \
    w = negtwiddles[twid_id + _pow2 / 2];               \
    u = A[i1];                                          \
    v = A[i2] * w;                                      \
                                                        \
    A[i1] += v;                                         \
    A[i2] = u - v;                                      \
                                                        \
    tid += params::degree / params::opt;                \
  } while (false)


/*
 * Runs one full stage of the direct FFT at `level` for the calling thread.
 *
 * Resets `tid` to threadIdx.x, executes NSMFFT_DIRECT_LEVEL_LOOP_BODY(level)
 * params::opt / 2 times through DUFFS_DEVICE, then issues __syncthreads().
 * Because of the barrier, every thread of the block must reach this macro.
 * The enclosing scope must provide the variables used by the loop body
 * (`A`, `tid`, `twid_id`, `i1`, `i2`, `u`, `v`, `w`).
 */
#define NSMFFT_DIRECT_LEVEL(level)                                       \
  do {                                                                   \
    tid = threadIdx.x;                                                   \
    DUFFS_DEVICE(params::opt / 2, NSMFFT_DIRECT_LEVEL_LOOP_BODY(level)); \
    __syncthreads();                                                     \
  } while (false)

/*
 * Direct (forward) negacyclic FFT, computed in place on A.
 *
 * Conventions:
 *   - Before the FFT, the N real coefficients are packed into N/2 complex
 *     values: even coefficients in the real part, odd coefficients in the
 *     imaginary part. This is referred to as the half-size FFT.
 *   - When used for the forward negacyclic FFT of PBS, opt is divided by 2
 *     because the butterfly pattern always works on pairs of coefficients.
 *   - Instead of twisting each coefficient A_j before the FFT by multiplying
 *     by the roots of unity w^j (twiddles, w = exp(-i pi / N)), the FFT
 *     itself is modified: at each level k the twiddle
 *         w_j,k = exp(-i pi j / 2^k)
 *     is replaced with
 *         \zeta_j,k = exp(-i pi (2j-1) / 2^k)
 *   - No bit-reversal pass is done here; the twiddles are already stored in
 *     reversed order.
 *
 * Work split: at every level each thread handles params::opt / 2 coefficient
 * pairs, advancing its index by params::degree / params::opt between pairs.
 * Every level ends with __syncthreads(), so all threads of the block must
 * call this function.
 */
template <class params> __device__ void NSMFFT_direct(double2 *A) {
  // Scratch state; the level macros below refer to these names directly.
  double2 u, v, w;
  size_t i1, i2;
  size_t twid_id;
  size_t tid = threadIdx.x; // this thread's starting coefficient index

  // Level 1 uses a single twiddle whose real and imaginary parts are equal;
  // it is hard-coded in the level-1 macro rather than read from negtwiddles.
  NSMFFT_DIRECT_LEVEL_1();

  // Levels 2..7 are unconditional: the smallest supported polynomial size is
  // 256, i.e. a compressed size of 128, which needs exactly seven levels.
  // At each level, params::degree / 2^level is the number of coefficients
  // in one half of a chunk, and twiddles are looked up in negtwiddles.
  NSMFFT_DIRECT_LEVEL(2);
  NSMFFT_DIRECT_LEVEL(3);
  NSMFFT_DIRECT_LEVEL(4);
  NSMFFT_DIRECT_LEVEL(5);
  NSMFFT_DIRECT_LEVEL(6);
  NSMFFT_DIRECT_LEVEL(7);

  // Levels 8..13 are hard-coded too, but only instantiated when the
  // compressed size is large enough to need them (level k <=> degree >= 2^k).
  if constexpr (params::degree >= (1 << 8)) {
    NSMFFT_DIRECT_LEVEL(8);
  }
  if constexpr (params::degree >= (1 << 9)) {
    NSMFFT_DIRECT_LEVEL(9);
  }
  if constexpr (params::degree >= (1 << 10)) {
    NSMFFT_DIRECT_LEVEL(10);
  }
  if constexpr (params::degree >= (1 << 11)) {
    NSMFFT_DIRECT_LEVEL(11);
  }
  if constexpr (params::degree >= (1 << 12)) {
    NSMFFT_DIRECT_LEVEL(12);
  }
  if constexpr (params::degree >= (1 << 13)) {
    NSMFFT_DIRECT_LEVEL(13);
  }
}

/*
 * One butterfly of the inverse FFT at the given level.
 *
 * Expects `A`, `negtwiddles`, `conjugate`, `tid`, `twid_id`, `i1`, `i2`,
 * `u`, `w` and `params` to be in scope.
 *
 * Index computation mirrors the direct butterfly: with
 * span = params::degree / 2^level, the pair is (i1, i1 + span) and the
 * twiddle is negtwiddles[twid_id + 2^level / 2]. The butterfly stores the
 * sum in A[i1] and (difference * conjugate(w)) in A[i2], then advances
 * `tid` by params::degree / params::opt.
 *
 * `level` is parenthesised so that an argument containing operators of
 * lower precedence than `<<` (e.g. `a | b`) yields the intended shift.
 */
#define NSMFFT_INVERSE_LEVEL_LOOP_BODY(level)      \
  do {                                             \
    int _shift = 1 << (level);                     \
    twid_id = tid / (params::degree / _shift);     \
    i1 = 2 * (params::degree / _shift) * twid_id + \
        (tid & (params::degree / _shift - 1));     \
    i2 = i1 + params::degree / _shift;             \
                                                   \
    w = negtwiddles[twid_id + _shift / 2];         \
    u = A[i1] - A[i2];                             \
                                                   \
    A[i1] += A[i2];                                \
    A[i2] = u * conjugate(w);                      \
                                                   \
    tid += params::degree / params::opt;           \
  } while (false)

/*
 * Runs one full stage of the inverse FFT at `level` for the calling thread.
 *
 * Resets `tid` to threadIdx.x, executes NSMFFT_INVERSE_LEVEL_LOOP_BODY(level)
 * params::opt / 2 times through DUFFS_DEVICE, then issues __syncthreads().
 * Because of the barrier, every thread of the block must reach this macro.
 * The enclosing scope must provide the variables used by the loop body
 * (`A`, `tid`, `twid_id`, `i1`, `i2`, `u`, `w`).
 */
#define NSMFFT_INVERSE_LEVEL(level)                                       \
  do {                                                                    \
    tid = threadIdx.x;                                                    \
    DUFFS_DEVICE(params::opt / 2, NSMFFT_INVERSE_LEVEL_LOOP_BODY(level)); \
    __syncthreads();                                                      \
  } while (false)

/*
 * Inverse negacyclic FFT, computed in place on A.
 *
 * Steps:
 *   1. Every element of A is divided by params::degree (the compressed
 *      polynomial size); each thread scales params::opt elements, stepping
 *      by params::degree / params::opt.
 *   2. The butterfly levels are applied in the reverse order of the direct
 *      transform, from the highest level required by params::degree down to
 *      level 1, using the conjugated twiddles.
 *
 * No bit-reversal pass is done here; the twiddles are already stored in
 * reversed order. At every level each thread handles params::opt / 2
 * coefficient pairs. The function contains __syncthreads() barriers, so all
 * threads of the block must call it.
 */
template <class params> __device__ void NSMFFT_inverse(double2 *A) {
  // Scratch state; the level macros below refer to these names directly.
  double2 u, w;
  size_t i1, i2;
  size_t twid_id;
  size_t tid = threadIdx.x;

  // Normalisation: divide the input by the compressed polynomial size.
  for (size_t pass = 0; pass < params::opt;
       ++pass, tid += params::degree / params::opt) {
    A[tid] /= params::degree;
  }
  __syncthreads();

  // None of the twiddles used below has equal real and imaginary parts, so
  // every level performs a complete complex multiplication. The mapping of
  // the backward FFT is reversed: butterflies start from the last level.
  // Levels 13..8 are only instantiated when the compressed size needs them
  // (level k <=> degree >= 2^k).
  if constexpr (params::degree >= (1 << 13)) {
    NSMFFT_INVERSE_LEVEL(13);
  }
  if constexpr (params::degree >= (1 << 12)) {
    NSMFFT_INVERSE_LEVEL(12);
  }
  if constexpr (params::degree >= (1 << 11)) {
    NSMFFT_INVERSE_LEVEL(11);
  }
  if constexpr (params::degree >= (1 << 10)) {
    NSMFFT_INVERSE_LEVEL(10);
  }
  if constexpr (params::degree >= (1 << 9)) {
    NSMFFT_INVERSE_LEVEL(9);
  }
  if constexpr (params::degree >= (1 << 8)) {
    NSMFFT_INVERSE_LEVEL(8);
  }

  // Levels 7..1 are unconditional: the smallest supported polynomial size is
  // 256, i.e. a compressed size of 128, so these seven levels are always
  // required.
  NSMFFT_INVERSE_LEVEL(7);
  NSMFFT_INVERSE_LEVEL(6);
  NSMFFT_INVERSE_LEVEL(5);
  NSMFFT_INVERSE_LEVEL(4);
  NSMFFT_INVERSE_LEVEL(3);
  NSMFFT_INVERSE_LEVEL(2);
  NSMFFT_INVERSE_LEVEL(1);
}

/*
 * One iteration of a per-thread strided pass: evaluates `expr` (normally a
 * statement indexed by `tid`) and then advances `tid` by
 * params::degree / params::opt. Expects `tid` and `params` to be in scope.
 */
#define BATCH_NSMFFT_LOOP_BODY(expr)          \
  do {                                        \
    expr;                                     \
    tid = tid + params::degree / params::opt; \
  } while (false)

/*
 * Applies `expr` to every element owned by the calling thread: resets `tid`
 * to threadIdx.x and runs BATCH_NSMFFT_LOOP_BODY(expr) params::opt / 2 times
 * through DUFFS_DEVICE. No barrier is issued here; callers add
 * __syncthreads() where they need one.
 * NOTE(review): `expr` is a single macro argument, so it must not contain a
 * top-level comma.
 */
#define BATCH_NSMFFT(expr)                                       \
  do {                                                           \
    tid = threadIdx.x;                                           \
    DUFFS_DEVICE(params::opt / 2, BATCH_NSMFFT_LOOP_BODY(expr)); \
  } while (false)

/*
 * Global batch FFT kernel: runs the half-size direct negacyclic FFT on one
 * polynomial per block.
 *
 * Notes carried over from the original header:
 *   - the FFT is done in half size;
 *   - unrolling the half-size FFT results in half size + 1 elements;
 *   - this kernel must be instantiated with the actual degree;
 *   - d_input is expected to hold already-compressed input.
 *
 * Layout:
 *   - Block b reads params::degree / 2 elements starting at
 *     d_input[b * params::degree / 2] and writes the same range of d_output.
 *   - Each thread moves params::opt / 2 elements, stepping by
 *     params::degree / params::opt (presumably blockDim.x must equal that
 *     stride - confirm against the launch site).
 *   - Working storage: when SMD == NOSM the block uses its slice of
 *     `buffer`; otherwise it uses dynamic shared memory, in which
 *     params::degree / 2 double2 elements are accessed.
 *
 * The per-block element offset is computed once in size_t so that
 * blockIdx.x * degree cannot wrap in 32-bit arithmetic on very large
 * batches. This assumes params::degree is even, which the degree / 2
 * indexing already relies on.
 */
template <class params, sharedMemDegree SMD>
__global__ void batch_NSMFFT(double2 *d_input, double2 *d_output,
                             double2 *buffer) {
  extern __shared__ double2 sharedMemoryFFT[];

  // First element of this block's polynomial, in 64-bit arithmetic.
  const size_t block_offset =
      static_cast<size_t>(blockIdx.x) * (params::degree / 2);
  double2 *fft = (SMD == NOSM) ? &buffer[block_offset] : sharedMemoryFFT;

  // Assigned by BATCH_NSMFFT before each pass.
  int tid;

  // Stage the input polynomial into the working buffer.
  BATCH_NSMFFT(fft[tid] = d_input[block_offset + tid]);
  __syncthreads();
  NSMFFT_direct<HalfDegree<params>>(fft);
  __syncthreads();

  // Write the transformed polynomial back out.
  BATCH_NSMFFT(d_output[block_offset + tid] = fft[tid]);
}

/*
 * Global batch polynomial multiplication kernel (only used for FFT tests).
 *
 * For each block b, multiplies the b-th polynomial of d_input1 by the b-th
 * polynomial of d_input2 via forward FFT, pointwise product and inverse FFT,
 * and stores the result in d_output.
 *
 *   - d_input1 and d_output must not be the same pointer.
 *   - d_input1 is overwritten with the forward FFT of the first polynomial.
 *   - Each block handles params::degree / 2 elements; each thread handles
 *     params::opt / 2 of them, stepping by params::degree / params::opt
 *     (presumably blockDim.x must equal that stride - confirm against the
 *     launch site).
 *   - Working storage: when SMD == NOSM the block uses its slice of
 *     `buffer`; otherwise it uses dynamic shared memory.
 *
 * The per-block element offset is computed once in size_t so that
 * blockIdx.x * degree cannot wrap in 32-bit arithmetic on very large
 * batches. This assumes params::degree is even, which the degree / 2
 * indexing already relies on.
 */
template <class params, sharedMemDegree SMD>
__global__ void batch_polynomial_mul(double2 *d_input1, double2 *d_input2,
                                     double2 *d_output, double2 *buffer) {
  extern __shared__ double2 sharedMemoryFFT[];

  // First element of this block's polynomial, in 64-bit arithmetic.
  const size_t block_offset =
      static_cast<size_t>(blockIdx.x) * (params::degree / 2);
  double2 *fft = (SMD == NOSM) ? &buffer[block_offset] : sharedMemoryFFT;

  // Assigned by BATCH_NSMFFT before each pass.
  int tid;

  // Move the first polynomial into shared memory (if possible, otherwise
  // into the device buffer).
  BATCH_NSMFFT(fft[tid] = d_input1[block_offset + tid]);

  // Perform the direct negacyclic Fourier transform on the first polynomial.
  __syncthreads();
  NSMFFT_direct<HalfDegree<params>>(fft);
  __syncthreads();

  // Store the result of the direct FFT back into input1.
  BATCH_NSMFFT(d_input1[block_offset + tid] = fft[tid]);
  __syncthreads();

  // Move the second polynomial into shared memory (if possible, otherwise
  // into the device buffer).
  BATCH_NSMFFT(fft[tid] = d_input2[block_offset + tid]);

  // Perform the direct negacyclic Fourier transform on the second
  // polynomial.
  __syncthreads();
  NSMFFT_direct<HalfDegree<params>>(fft);
  __syncthreads();

  // Pointwise multiplication, accumulated inside the fft buffer.
  BATCH_NSMFFT(fft[tid] *= d_input1[block_offset + tid]);

  // Perform the backward negacyclic Fourier transform.
  __syncthreads();
  NSMFFT_inverse<HalfDegree<params>>(fft);
  __syncthreads();

  // Copy the result to the output buffer.
  BATCH_NSMFFT(d_output[block_offset + tid] = fft[tid]);
}

#endif // GPU_BOOTSTRAP_FFT_CUH
Leave a Comment