Untitled

 avatar
unknown
plain_text
14 days ago
9.9 kB
9
Indexable
#include "../../include/pkg/election.hpp"
#include "../../include-shared/logger.hpp"
#include <crypto++/osrng.h>
#include <crypto++/nbtheory.h>

using namespace CryptoPP;

/*
Logger usage:
  CUSTOM_LOG(lg, debug) << "your message"
See logger.hpp for severity levels other than 'debug'.
*/
namespace {
// File-local logger instance (internal linkage via the anonymous namespace)
// that CUSTOM_LOG(lg, ...) writes to within this translation unit.
src::severity_logger<logging::trivial::severity_level> lg;
} // namespace

/**
 * Generate an encrypted vote together with a zero-knowledge proof that it
 * encrypts either 0 or 1.
 *
 * Ciphertext (exponential ElGamal): a = g^r mod p, b = pk^r * g^vote mod p,
 * with r drawn uniformly from [1, q-1].
 *
 * Proof: a disjunctive Chaum-Pedersen proof made non-interactive with
 * hash_vote_zkp. The branch matching the real vote is proven honestly; the
 * other branch is simulated by choosing its challenge and response first and
 * solving for its commitments. The challenges satisfy
 *   c0 + c1 == H(pk, a, b, a0, b0, a1, b1) (mod q).
 *
 * @param vote the plaintext vote; must be 0 or 1.
 * @param pk   the election public key.
 * @return the ciphertext and its proof.
 * @throws std::runtime_error if vote is neither 0 nor 1.
 */
std::pair<Vote_Ciphertext, VoteZKP_Struct>
ElectionClient::GenerateVote(CryptoPP::Integer vote, CryptoPP::Integer pk) {
  // Compare against Integer constants rather than int literals: a literal 0
  // is also a null-pointer constant, which matches Integer(const char*).
  if (vote != Integer::Zero() && vote != Integer::One()) {
    throw std::runtime_error("GenerateVote error: only votes of 0 or 1 are supported!");
  }

  AutoSeededRandomPool rng;

  // Uniform sample from [1, q-1]. Integer's range constructor uses rejection
  // sampling, so there is no modulo bias (unlike reducing random bytes mod q).
  auto random_in_mod_q = [&]() {
    return Integer(rng, Integer::One(), DL_Q - Integer::One());
  };
  auto modMul = [&](const Integer &x, const Integer &y) {
    return (x * y) % DL_P;
  };
  auto modExp = [&](const Integer &base, const Integer &exp) {
    return a_exp_b_mod_c(base, exp, DL_P);
  };
  auto modInv = [&](const Integer &x) {
    return EuclideanMultiplicativeInverse(x, DL_P);
  };
  // (x - y) mod q for x, y already reduced into [0, q).
  auto subModQ = [&](const Integer &x, const Integer &y) {
    Integer diff = x - y;
    if (diff.IsNegative()) {
      diff += DL_Q;
    }
    return diff;
  };

  // --- Encryption ---
  const Integer r = random_in_mod_q();
  Vote_Ciphertext ciphertext;
  ciphertext.a = modExp(DL_G, r);  // g^r mod p
  // g^vote is 1 for a 0-vote and g (mod p) for a 1-vote.
  const Integer g_to_v =
      (vote == Integer::One()) ? modExp(DL_G, Integer::One()) : Integer::One();
  ciphertext.b = modMul(modExp(pk, r), g_to_v);  // pk^r * g^vote mod p

  const Integer &a = ciphertext.a;
  const Integer &b = ciphertext.b;
  // b / g mod p: equals pk^r exactly when the ciphertext encrypts 1.
  const Integer b_over_g = modMul(b, modInv(DL_G));

  // --- Proof ---
  VoteZKP_Struct zkp;
  if (vote == Integer::Zero()) {
    // Simulated branch (claims vote == 1): pick c1 and r1, then solve for
    // a1 = g^r1 / a^c1 and b1 = pk^r1 / (b/g)^c1.
    zkp.c1 = random_in_mod_q();
    zkp.r1 = random_in_mod_q();
    zkp.a1 = modMul(modExp(DL_G, zkp.r1), modInv(modExp(a, zkp.c1)));
    zkp.b1 = modMul(modExp(pk, zkp.r1), modInv(modExp(b_over_g, zkp.c1)));

    // Honest branch (vote == 0): commit, derive the challenge, respond.
    const Integer nonce = random_in_mod_q();
    zkp.a0 = modExp(DL_G, nonce);  // g^nonce
    zkp.b0 = modExp(pk, nonce);    // pk^nonce
    const Integer big_c =
        hash_vote_zkp(pk, a, b, zkp.a0, zkp.b0, zkp.a1, zkp.b1) % DL_Q;
    zkp.c0 = subModQ(big_c, zkp.c1);
    zkp.r0 = (nonce + zkp.c0 * r) % DL_Q;
  } else {
    // Simulated branch (claims vote == 0): pick c0 and r0, then solve for
    // a0 = g^r0 / a^c0 and b0 = pk^r0 / b^c0.
    zkp.c0 = random_in_mod_q();
    zkp.r0 = random_in_mod_q();
    zkp.a0 = modMul(modExp(DL_G, zkp.r0), modInv(modExp(a, zkp.c0)));
    zkp.b0 = modMul(modExp(pk, zkp.r0), modInv(modExp(b, zkp.c0)));

    // Honest branch (vote == 1): commit, derive the challenge, respond.
    const Integer nonce = random_in_mod_q();
    zkp.a1 = modExp(DL_G, nonce);  // g^nonce
    zkp.b1 = modExp(pk, nonce);    // pk^nonce
    const Integer big_c =
        hash_vote_zkp(pk, a, b, zkp.a0, zkp.b0, zkp.a1, zkp.b1) % DL_Q;
    zkp.c1 = subModQ(big_c, zkp.c0);
    zkp.r1 = (nonce + zkp.c1 * r) % DL_Q;
  }

  return std::make_pair(ciphertext, zkp);
}

/**
 * Verify vote zkp.
 */
bool ElectionClient::VerifyVoteZKP(
    std::pair<Vote_Ciphertext, VoteZKP_Struct> vote, CryptoPP::Integer pk) {
  initLogger();
  // TODO: implement me!
    const Vote_Ciphertext &ciphertext = vote.first;
    const VoteZKP_Struct &zkp = vote.second;

    CryptoPP::Integer a = ciphertext.a; // g^r
    CryptoPP::Integer b = ciphertext.b; 
    CryptoPP::Integer a0 = zkp.a0, a1 = zkp.a1,
                      b0 = zkp.b0, b1 = zkp.b1,
                      c0 = zkp.c0, c1 = zkp.c1,
                      r0 = zkp.r0, r1 = zkp.r1;

    CryptoPP::Integer bigC = hash_vote_zkp(pk, a, b, a0, b0, a1, b1) % DL_Q;
    if(((c0 + c1) % DL_Q) != bigC) {
      return false;
    }


    // g^r0 == a0 * a^c0
    CryptoPP::Integer lhs = ModularExponentiation(DL_G, r0, DL_P);
    CryptoPP::Integer rhs = ( a0 * ModularExponentiation(a, c0, DL_P) ) % DL_P;
    if(lhs != rhs) return false;

    // pk^r0 == b0 * b^c0
    lhs = ModularExponentiation(pk, r0, DL_P);
    rhs = ( b0 * ModularExponentiation(b, c0, DL_P) ) % DL_P;
    if(lhs != rhs) return false;

    // g^r1 == a1 * a^c1
    lhs = ModularExponentiation(DL_G, r1, DL_P);
    rhs = ( a1 * ModularExponentiation(a, c1, DL_P) ) % DL_P;
    if(lhs != rhs) return false;

    CryptoPP::Integer bInvG = (b * EuclideanMultiplicativeInverse(DL_G, DL_P)) % DL_P;
    lhs = ModularExponentiation(pk, r1, DL_P);
    rhs = ( b1 * ModularExponentiation(bInvG, c1, DL_P) ) % DL_P;
    if(lhs != rhs) return false;

    return true; 




}

/**
 * Produce this arbiter's partial decryption of the combined vote and a
 * Chaum-Pedersen proof that it used the secret key matching pk.
 *
 * Partial decryption: d = a^sk mod p, where a is combined_vote.a.
 * Proof (made non-interactive with hash_dec_zkp): pick r uniformly from
 * [1, q-1]; u = a^r mod p, v = g^r mod p; c = H(pk, a, b, u, v) mod q;
 * s = r + c*sk mod q. VerifyPartialDecryptZKP accepts when
 * a^s == u*d^c and g^s == v*pk^c, which requires pk == g^sk mod p.
 *
 * @param combined_vote the homomorphically combined ciphertext.
 * @param pk this arbiter's public key.
 * @param sk this arbiter's secret key.
 * @return the partial decryption and its proof.
 */
std::pair<PartialDecryption_Struct, DecryptionZKP_Struct>
ElectionClient::PartialDecrypt(Vote_Ciphertext combined_vote,
                               CryptoPP::Integer pk, CryptoPP::Integer sk) {
  initLogger();

  const CryptoPP::Integer &a = combined_vote.a;

  // 1) Partial decryption.
  PartialDecryption_Struct dec;
  dec.aggregate_ciphertext = combined_vote;
  dec.d = a_exp_b_mod_c(a, sk, DL_P);

  // 2) Proof. The nonce is uniform in [1, q-1]: Integer's range constructor
  // uses rejection sampling, so there is no modulo bias.
  CryptoPP::AutoSeededRandomPool rng;
  const CryptoPP::Integer r(rng, CryptoPP::Integer::One(),
                            DL_Q - CryptoPP::Integer::One());

  DecryptionZKP_Struct zkp;
  zkp.u = a_exp_b_mod_c(a, r, DL_P);     // a^r mod p
  zkp.v = a_exp_b_mod_c(DL_G, r, DL_P);  // g^r mod p

  const CryptoPP::Integer c =
      hash_dec_zkp(pk, combined_vote.a, combined_vote.b, zkp.u, zkp.v) % DL_Q;
  zkp.s = (r + c * sk) % DL_Q;

  return std::make_pair(dec, zkp);
}

/**
 * Verify an arbiter's partial decryption and its Chaum-Pedersen proof.
 *
 * With (a, b) the aggregate ciphertext carried in the message, d the partial
 * decryption and (u, v, s) the proof, accepts iff:
 *  - a, b, d, u, v lie in [1, p-1] and s lies in [0, q-1];
 *  - a^s == u * d^c (mod p) and g^s == v * pki^c (mod p),
 *    where c = hash_dec_zkp(pki, a, b, u, v) mod q.
 *
 * @param a2w_dec_s the partial decryption message (untrusted).
 * @param pki the public key of the arbiter that produced it.
 * @return true if the proof verifies, false otherwise.
 */
bool ElectionClient::VerifyPartialDecryptZKP(
    ArbiterToWorld_PartialDecryption_Message a2w_dec_s, CryptoPP::Integer pki) {
  initLogger();

  const PartialDecryption_Struct &dec_s = a2w_dec_s.dec;
  const DecryptionZKP_Struct &zkp = a2w_dec_s.zkp;
  const CryptoPP::Integer &a = dec_s.aggregate_ciphertext.a;
  const CryptoPP::Integer &b = dec_s.aggregate_ciphertext.b;
  const CryptoPP::Integer &d = dec_s.d;

  // Reject out-of-range values first. Without this, e.g. a = d = u = 0 makes
  // a^s == u * d^c hold trivially. Honest values are always reduced mod p
  // (and nonzero, assuming p is prime) or mod q.
  // NOTE(review): this does not check membership in the order-q subgroup.
  auto isGroupElement = [&](const CryptoPP::Integer &x) {
    return x > CryptoPP::Integer::Zero() && x < DL_P;
  };
  if (!isGroupElement(a) || !isGroupElement(b) || !isGroupElement(d) ||
      !isGroupElement(zkp.u) || !isGroupElement(zkp.v)) {
    return false;
  }
  if (zkp.s.IsNegative() || zkp.s >= DL_Q) {
    return false;
  }

  // Fiat-Shamir challenge.
  const CryptoPP::Integer c = hash_dec_zkp(pki, a, b, zkp.u, zkp.v) % DL_Q;

  // a^s == u * d^c (mod p)
  if (a_exp_b_mod_c(a, zkp.s, DL_P) !=
      (zkp.u * a_exp_b_mod_c(d, c, DL_P)) % DL_P) {
    return false;
  }

  // g^s == v * pki^c (mod p)
  if (a_exp_b_mod_c(DL_G, zkp.s, DL_P) !=
      (zkp.v * a_exp_b_mod_c(pki, c, DL_P)) % DL_P) {
    return false;
  }

  return true;
}

/**
 * Homomorphically combine all submitted vote ciphertexts.
 *
 * Multiplies the ciphertexts component-wise modulo p. With the
 * exponential-ElGamal encoding produced by GenerateVote, the product
 * encrypts the sum of the individual votes. An empty list yields (1, 1).
 *
 * @param all_votes rows whose .vote field holds a Vote_Ciphertext.
 * @return the combined ciphertext.
 */
Vote_Ciphertext ElectionClient::CombineVotes(std::vector<VoteRow> all_votes) {
  initLogger();

  CryptoPP::Integer product_a = CryptoPP::Integer::One();
  CryptoPP::Integer product_b = CryptoPP::Integer::One();
  for (const VoteRow &row : all_votes) {
    product_a = (product_a * row.vote.a) % DL_P;
    product_b = (product_b * row.vote.b) % DL_P;
  }

  Vote_Ciphertext combined;
  combined.a = product_a;
  combined.b = product_b;
  return combined;
}

/**
 * Combine the arbiters' partial decryptions and recover the vote tally.
 *
 * D = product of all row.dec.d mod p (intended to equal a^sk for the joint
 * secret key). Then g^m = b / D mod p, and m is recovered by a linear search
 * over m in [0, kMaxVotes].
 *
 * @param combined_vote the combined ciphertext (a, b).
 * @param all_partial_decryptions one row per arbiter; only row.dec.d is read.
 * @return m (the number of 1-votes), or -1 if no m in [0, kMaxVotes] matches.
 */
CryptoPP::Integer ElectionClient::CombineResults(
    Vote_Ciphertext combined_vote,
    std::vector<PartialDecryptionRow> all_partial_decryptions) {
  initLogger();

  // D = prod(d_i) mod p.
  CryptoPP::Integer combined_d = CryptoPP::Integer::One();
  for (const auto &row : all_partial_decryptions) {
    combined_d = (combined_d * row.dec.d) % DL_P;
  }

  // g^m = b * D^{-1} mod p.
  const CryptoPP::Integer g_to_m =
      (combined_vote.b * EuclideanMultiplicativeInverse(combined_d, DL_P)) % DL_P;

  // Largest tally searched for.
  // NOTE(review): 10000 is an ad-hoc bound; raise it if more voters are expected.
  const int kMaxVotes = 10000;

  // Walk g^0, g^1, g^2, ... incrementally: one modular multiplication per
  // step instead of a full modular exponentiation for every candidate.
  CryptoPP::Integer g_to_i = CryptoPP::Integer::One();
  for (int i = 0; i <= kMaxVotes; ++i) {
    if (g_to_i == g_to_m) {
      return CryptoPP::Integer(i);
    }
    g_to_i = (g_to_i * DL_G) % DL_P;
  }

  // No tally in range matched (e.g. inconsistent partial decryptions).
  return CryptoPP::Integer(-1);
}
Editor is loading...
Leave a Comment