Untitled
unknown
plain_text
a year ago
2.3 kB
14
Indexable
std::pair<int, float> DoSampling( std::vector<std::pair<float, int>>& logits_ids, std::mt19937* generator) { std::vector<float> probs; probs.reserve(logits_ids.size()); for (const auto& [logit, _] : logits_ids) { probs.push_back(logit); } // Probabilities are normalized by `discrete_distribution`. std::discrete_distribution<> dist(probs.begin(), probs.end()); int sample_idx = dist(*generator); return std::pair<int, float>(logits_ids[sample_idx].second, logits_ids[sample_idx].first); } void SelectTopK(std::vector<std::pair<float, int>>& logits_ids, int k) { if (k > logits_ids.size()) { LOG(FATAL) << "Top k value must be smaller than the number of logits."; } std::partial_sort( logits_ids.begin(), logits_ids.begin() + k, logits_ids.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) { // reverse order. return a.first > b.first; }); logits_ids.resize(k); } void ScaledSoftmax( std::vector<std::pair<float, int>>& logits_ids, bool normalize) { float scale = 1 / kTemperature; double sum = 0.0; float max_logit = logits_ids[0].first; for (int i = 0; i < logits_ids.size(); ++i) { const float logit = logits_ids[i].first; const float p = expf(scale * (logit - max_logit)); sum += p; logits_ids[i].first = p; } if (normalize) { for (int i = 0; i < logits_ids.size(); ++i) { logits_ids[i].first /= sum; } } } template <typename LogitsType> std::pair<int, LogitsType> TopK( const std::vector<float> logits, std::mt19937* generator) { const size_t batch_size = 1; const size_t vocab_size = logits.size(); std::vector<std::pair<int, LogitsType>> outputs; outputs.reserve(batch_size); for (int batch = 0; batch < batch_size; ++batch) { std::vector<std::pair<float, int>> logits_ids; logits_ids.reserve(vocab_size); for (int v = 0; v < vocab_size; ++v) { float logit = logits[batch * vocab_size + v]; logits_ids.push_back(std::make_pair(logit, v)); } SelectTopK(logits_ids, /*k=*/kTopK); // No need to normalize logits here, sampler takes care of that. ScaledSoftmax(logits_ids, /*normalize=*/false); outputs.push_back(DoSampling(logits_ids, generator)); } return outputs[0]; }
Editor is loading...
Leave a Comment