// Paste-site metadata (not code): Untitled · unknown · plain_text ·
// a year ago · 2.3 kB · 17 · Indexable
// Draws one entry from `logits_ids` at random, with probability proportional
// to its weight (the `.first` member of each (weight, token id) pair).
//
// Weights need not be normalized: `std::discrete_distribution` divides them by
// their sum. Per the standard they must be non-negative with a positive sum.
// Precondition (not checked here): `logits_ids` is non-empty. For an empty
// range, `std::discrete_distribution` behaves as a single-weight distribution
// and always yields index 0, which would be out of range here.
//
// `logits_ids` is only read, so it is taken by const reference (callers may
// pass const vectors or temporaries). `generator` must be non-null; it is
// dereferenced and advanced by the draw.
//
// Returns (token id, weight) of the sampled entry.
std::pair<int, float> DoSampling(
    const std::vector<std::pair<float, int>>& logits_ids,
    std::mt19937* generator) {
  std::vector<float> weights;
  weights.reserve(logits_ids.size());
  for (const auto& [weight, _] : logits_ids) {
    weights.push_back(weight);
  }
  // Probabilities are normalized by `discrete_distribution`.
  std::discrete_distribution<> dist(weights.begin(), weights.end());
  const int sample_idx = dist(*generator);
  const auto& [sampled_weight, token_id] = logits_ids[sample_idx];
  return {token_id, sampled_weight};
}
// Keeps only the `k` entries with the largest logits (the `.first` member of
// each (logit, token id) pair), ordered from largest to smallest logit.
//
// `logits_ids` is modified in place: on return it holds exactly `k` elements,
// with `logits_ids[0]` carrying the largest logit. The relative order of equal
// logits is unspecified (`std::partial_sort` is not stable).
//
// `k` must lie in [0, logits_ids.size()]; otherwise this aborts via
// LOG(FATAL).
void SelectTopK(std::vector<std::pair<float, int>>& logits_ids,
                int k) {
  // Check the sign explicitly instead of relying on a negative `k` wrapping
  // to a huge value when implicitly converted to `size_t`.
  if (k < 0 || static_cast<size_t>(k) > logits_ids.size()) {
    LOG(FATAL) << "Top k value must be in [0, " << logits_ids.size()
               << "], got " << k << ".";
  }
  const auto top_end = logits_ids.begin() + k;
  std::partial_sort(
      logits_ids.begin(), top_end, logits_ids.end(),
      [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
        // Descending order by logit.
        return a.first > b.first;
      });
  logits_ids.resize(k);
}
// Applies a temperature-scaled softmax in place to the logits stored in the
// `.first` member of each (logit, token id) pair.
//
// Each logit x becomes expf((x - max_x) / kTemperature), where max_x is the
// largest logit in `logits_ids`. With a positive kTemperature, subtracting the
// maximum keeps every exponent <= 0, so no term can overflow to +inf. If
// `normalize` is true, the values are then divided by their sum so they add up
// to 1. Otherwise they are left as unnormalized weights, which suffices for a
// consumer like `std::discrete_distribution` that normalizes on its own.
//
// Token ids (`.second`) and element order are not changed. An empty vector is
// left untouched.
void ScaledSoftmax(
    std::vector<std::pair<float, int>>& logits_ids, bool normalize) {
  if (logits_ids.empty()) {
    return;
  }
  // `1.0f` forces floating-point division even if kTemperature is integral.
  const float scale = 1.0f / kTemperature;
  // Find the true maximum rather than assuming the input is already sorted in
  // descending order (SelectTopK sorts it, but other callers might not).
  float max_logit = logits_ids[0].first;
  for (const auto& entry : logits_ids) {
    if (entry.first > max_logit) {
      max_logit = entry.first;
    }
  }
  // Accumulate in double to limit rounding error over many terms.
  double sum = 0.0;
  for (auto& entry : logits_ids) {
    entry.first = expf(scale * (entry.first - max_logit));
    sum += entry.first;
  }
  if (normalize) {
    for (auto& entry : logits_ids) {
      entry.first /= sum;
    }
  }
}
// Samples one token id from the `kTopK` largest entries of `logits`.
//
// `logits[v]` is treated as the logit of token id `v`. The steps are:
//   1. Pair each logit with its token id.
//   2. Keep the top `kTopK` entries (SelectTopK).
//   3. Turn them into temperature-scaled, unnormalized softmax weights
//      (ScaledSoftmax with normalize=false).
//   4. Draw one entry with `generator` (DoSampling).
//
// Returns (token id, unnormalized sampling weight of that token), with the
// weight converted to LogitsType.
//
// `logits` is taken by const reference to avoid copying the whole vocabulary
// on every call. `generator` must be non-null.
template <typename LogitsType>
std::pair<int, LogitsType> TopK(
    const std::vector<float>& logits, std::mt19937* generator) {
  // Keep each logit next to its token id so the id survives sorting and
  // truncation in SelectTopK.
  std::vector<std::pair<float, int>> logits_ids;
  logits_ids.reserve(logits.size());
  for (size_t v = 0; v < logits.size(); ++v) {
    logits_ids.emplace_back(logits[v], static_cast<int>(v));
  }
  SelectTopK(logits_ids, /*k=*/kTopK);
  // No need to normalize logits here, sampler takes care of that.
  ScaledSoftmax(logits_ids, /*normalize=*/false);
  return std::pair<int, LogitsType>(DoSampling(logits_ids, generator));
}
// Paste-site footer (not code): Leave a Comment