Untitled

 avatar
unknown
c_cpp
8 months ago
5.4 kB
3
Indexable
#include<iostream>
#include<vector>
#include<algorithm>
#include<map>
using namespace std;

// Modulus for the final answer.
constexpr long long MOD = 998244353;

// Filled by buildCycleInversions from the value-sorted array:
//   Larger[x]  = number of elements whose value is strictly greater than x
//   Smaller[x] = number of elements whose value is strictly less than x
std::map<long long, long long> Larger;
std::map<long long, long long> Smaller;

// One input element: its original position in the array and its value.
struct node {
    long long index = 0;  // position in the input array
    long long val = 0;    // element value

    node() = default;
    node(long long idx, long long v) : index(idx), val(v) {}
};

// Merges the sorted halves indices[left..mid] and indices[mid+1..right] in place and
// returns how many pairs (x from the left half, y from the right half) have x > y.
long long mergeAndCount(vector<long long>& indices, long long left, long long mid, long long right) {
    // Snapshot both halves so the merged result can be written straight back.
    const vector<long long> lo(indices.begin() + left, indices.begin() + mid + 1);
    const vector<long long> hi(indices.begin() + mid + 1, indices.begin() + right + 1);

    size_t p = 0;
    size_t q = 0;
    long long out = left;
    long long crossing = 0;

    while (p < lo.size() || q < hi.size()) {
        // Prefer the left element on ties so equal keys are not counted as inversions.
        const bool takeLow = (q == hi.size()) || (p < lo.size() && lo[p] <= hi[q]);
        if (takeLow) {
            indices[out++] = lo[p++];
        } else {
            // hi[q] is smaller than every element still waiting in lo.
            crossing += static_cast<long long>(lo.size() - p);
            indices[out++] = hi[q++];
        }
    }
    return crossing;
}
// Sorts indices[left..right] ascending (merge sort) and returns the number of
// inversions (pairs p < q with indices[p] > indices[q]) that the range contained.
long long mergeSortAndCount(vector<long long>& indices, long long left, long long right) {
    // Zero or one element: already sorted, nothing inverted.
    if (left >= right) {
        return 0;
    }
    const long long mid = left + (right - left) / 2;
    const long long inLeft = mergeSortAndCount(indices, left, mid);
    const long long inRight = mergeSortAndCount(indices, mid + 1, right);
    const long long across = mergeAndCount(indices, left, mid, right);
    return inLeft + inRight + across;
}
// Counts inversions of the original array, i.e. pairs of positions i < j with a[i] > a[j].
//
// `v` must be the input elements sorted in non-decreasing order of `val`, each carrying
// its original position in `index`. An inversion of the original array is then exactly
// an inversion of the sequence of `index` values read in sorted order.
long long countInversions(const vector<node>& v) {
    const long long n = static_cast<long long>(v.size());
    if (n == 0) {
        return 0;
    }

    vector<long long> indices(n);
    for (long long i = 0; i < n; i++) {
        indices[i] = v[i].index;
    }

    // Equal values never form an inversion. std::sort is not stable, though, so a run of
    // equal values may list its original positions out of order. Ordering each run's
    // positions ascending keeps those pairs from being counted.
    for (long long runStart = 0; runStart < n;) {
        long long runEnd = runStart + 1;
        while (runEnd < n && v[runEnd].val == v[runStart].val) {
            runEnd++;
        }
        sort(indices.begin() + runStart, indices.begin() + runEnd);
        runStart = runEnd;
    }

    return mergeSortAndCount(indices, 0, n - 1);
}

// Rebuilds the global maps from `v`, which must be sorted in non-decreasing order of `val`:
//   Smaller[x] = number of elements with value strictly less than x
//   Larger[x]  = number of elements with value strictly greater than x
// An empty `v` leaves both maps empty. The previous version read v[0] and was UB here.
void buildCycleInversions(const vector<node>& v){
    Smaller.clear();
    Larger.clear();

    const long long n = static_cast<long long>(v.size());
    long long before = 0;  // elements in earlier (strictly smaller-valued) runs

    // Walk each maximal run of equal values once.
    for (long long runStart = 0; runStart < n;) {
        long long runEnd = runStart + 1;
        while (runEnd < n && v[runEnd].val == v[runStart].val) {
            runEnd++;
        }
        const long long runLength = runEnd - runStart;
        const long long value = v[runStart].val;

        Smaller[value] = before;
        // Everything not in this run and not before it lies in later, larger-valued runs.
        Larger[value] = n - before - runLength;

        before += runLength;
        runStart = runEnd;
    }
}
// Fills Inv[i] (0 <= i < N) with the inversion count of v_copy rotated left by i, and
// returns the sum of all N entries.
//
// `init` must be the inversion count of v_copy itself (rotation 0). Larger/Smaller must
// already hold, per value, the number of strictly greater / strictly smaller elements
// (as filled by buildCycleInversions). `Inv` must have at least N slots. The rotation
// counts are also dumped to stderr for debugging.
long long buildInv(const vector<node>& v_copy, vector<long long>& Inv, long long init, long long& N){
    // No elements, no rotations. Bail out rather than writing Inv[0] out of bounds.
    if (N <= 0) {
        return 0;
    }

    Inv[0] = init;
    long long sum = init;
    for (long long i = 1; i < N; i++) {
        // Rotating left moves v_copy[i-1] from the front to the back. It no longer precedes
        // the Smaller[...] smaller elements, and it now follows the Larger[...] larger ones.
        const long long moved = v_copy[i - 1].val;
        Inv[i] = Inv[i - 1] + Larger[moved] - Smaller[moved];
        sum += Inv[i];
    }

    cerr << "cycleInversions: ";
    for (long long i = 0; i < N; i++) {
        cerr << Inv[i] << " ";
    }
    cerr << endl;
    return sum;
}

int main(){
    long long N, K;
    cin >> N >> K;

    vector<node> v(N);
    for(long long i=0;i<N;i++){
        long long temp;
        cin >> temp;
        node temp_Node(i, temp);
        v[i] = temp_Node;
    }
    vector<node> v_copy = v;

    sort(v.begin(), v.end(), [](const node &a, const node &b) {
        return a.val < b.val;
    });

    long long Inversions = countInversions(v);
    buildCycleInversions(v);

    vector<long long> Inv(N, 0);
    long long cycleInversion = buildInv(v_copy, Inv, Inversions, N);

    long long a, b;
    a = K / N;
    b = K % N;

    long long indeSum = 0;
    indeSum += ((a%998244353) * (cycleInversion%998244353))%998244353;
    indeSum %= 998244353;
    for(long long i=0;i<b;i++) indeSum += (Inv[i] % 998244353);
    indeSum %= 998244353;
    cerr << "indeSum:" << indeSum << endl;

    long long groupInversion = 0;
    // cout<<"Larger: ";
    // for(long long i=0;i<v_copy.size();i++){
    //     cout<<Larger[v_copy[i].val]<<" ";
    // }
    // cout<<endl;
    // cout<<"Smaller: ";
    for(long long i=0;i<v_copy.size();i++){
        // cout<<Smaller[v_copy[i].val]<<" ";
        groupInversion += Smaller[v_copy[i].val];
    }
    // cout<<endl;
    groupInversion%=998244353;

    long long groupSum = (long long)((((K%998244353)*((K-1)%998244353))%998244353)/2) * (groupInversion%998244353);
    groupSum %= 998244353;
    // cerr << "groupInversions: " << groupInversion << endl;
    cerr << "calculated inversion: " << groupSum << endl;
    cout << (indeSum + groupSum) % 998244353 << endl;
}
Editor is loading...
Leave a Comment