Untitled

 avatar
unknown
plain_text
a year ago
2.3 kB
14
Indexable
// Randomly picks one entry of `logits_ids`, weighting each entry by its
// `.first` value, and returns it as an (id, weight) pair, i.e. with the two
// members swapped relative to how they are stored.
//
// The weights need not sum to one: `std::discrete_distribution` normalizes
// them internally. `logits_ids` is only read. `generator` is dereferenced
// unconditionally and is advanced by the draw.
// NOTE(review): an empty `logits_ids` is not guarded against here; callers
// appear to be expected to pass at least one entry.
std::pair<int, float> DoSampling(
    std::vector<std::pair<float, int>>& logits_ids, std::mt19937* generator) {
  // Copy the weights out into a flat array for the distribution constructor.
  std::vector<float> weights;
  weights.reserve(logits_ids.size());
  for (size_t i = 0; i < logits_ids.size(); ++i) {
    weights.push_back(logits_ids[i].first);
  }

  std::discrete_distribution<> picker(weights.begin(), weights.end());
  const auto& chosen = logits_ids[picker(*generator)];
  return {chosen.second, chosen.first};
}

// Shrinks `logits_ids` in place to its `k` entries with the largest `.first`
// values. On return the vector has exactly `k` elements, ordered by `.first`
// from largest to smallest. The relative order of entries with equal `.first`
// is unspecified because `std::partial_sort` is not stable.
//
// `k` must lie in [0, logits_ids.size()]; `k == logits_ids.size()` is allowed
// and simply sorts the whole vector. Any other value is reported through
// LOG(FATAL).
void SelectTopK(std::vector<std::pair<float, int>>& logits_ids, int k) {
  // Reject a negative k explicitly rather than relying on it wrapping to a
  // huge unsigned value in a signed/unsigned comparison.
  if (k < 0 || static_cast<size_t>(k) > logits_ids.size()) {
    LOG(FATAL) << "Top k value (" << k
               << ") must be non-negative and must not exceed the number of "
                  "logits ("
               << logits_ids.size() << ").";
  }
  std::partial_sort(
      logits_ids.begin(), logits_ids.begin() + k, logits_ids.end(),
      [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
        // Descending order: larger logits come first.
        return a.first > b.first;
      });
  // Everything past the first k entries is unsorted leftovers; drop it.
  logits_ids.resize(k);
}

// Replaces, in place, the `.first` of every entry of `logits_ids` with
//   exp((logit - max_logit) / kTemperature)
// where `max_logit` is the largest `.first` in the input. The `.second`
// members are left untouched.
//
// When `normalize` is true the results are additionally divided by their sum,
// so that they add up to one. When it is false the raw exponentials are left
// as they are (the largest one being exp(0) == 1).
//
// Subtracting the true maximum keeps every exponent at or below zero (for a
// positive temperature), so `expf` cannot overflow. For input that is already
// sorted in descending order the maximum is the first element, so the output
// is the same as if element 0 had been used directly. An empty vector is a
// no-op.
void ScaledSoftmax(
    std::vector<std::pair<float, int>>& logits_ids, bool normalize) {
  if (logits_ids.empty()) {
    return;
  }

  // Float literal so this can never be an integer division, whatever the
  // declared type of kTemperature is.
  const float scale = 1.0f / kTemperature;

  float max_logit = logits_ids[0].first;
  for (const auto& entry : logits_ids) {
    if (entry.first > max_logit) {
      max_logit = entry.first;
    }
  }

  // Accumulate in double to limit rounding error over many entries.
  double sum = 0.0;
  for (auto& entry : logits_ids) {
    const float p = expf(scale * (entry.first - max_logit));
    sum += p;
    entry.first = p;
  }

  if (normalize) {
    for (auto& entry : logits_ids) {
      entry.first /= sum;
    }
  }
}

// Draws one index from `logits` by top-k sampling.
//
// Each logit is paired with its position in `logits`; the `kTopK` largest are
// kept (SelectTopK), turned into un-normalized temperature-scaled exponentials
// (ScaledSoftmax with normalize=false), and one of them is drawn at random in
// proportion to those weights (DoSampling).
//
// Returns (position of the chosen logit in `logits`, its un-normalized weight
// converted to `LogitsType`). `generator` is advanced by the draw.
//
// NOTE(review): `logits` is taken by value, so every call copies the whole
// vector; consider `const std::vector<float>&` once all declarations of this
// template can be updated together.
template <typename LogitsType>
std::pair<int, LogitsType> TopK(
    const std::vector<float> logits, std::mt19937* generator) {
  const size_t vocab_size = logits.size();

  // Pair each logit with its index so the index survives the partial sort.
  std::vector<std::pair<float, int>> candidates;
  candidates.reserve(vocab_size);
  for (int vocab_idx = 0; static_cast<size_t>(vocab_idx) < vocab_size;
       ++vocab_idx) {
    candidates.emplace_back(logits[vocab_idx], vocab_idx);
  }

  SelectTopK(candidates, /*k=*/kTopK);
  // Normalization is skipped on purpose: the sampler normalizes the weights.
  ScaledSoftmax(candidates, /*normalize=*/false);
  return DoSampling(candidates, generator);
}
Editor is loading...
Leave a Comment