Untitled
unknown
c_cpp
8 months ago
5.4 kB
3
Indexable
#include <algorithm>
#include <iostream>
#include <map>
#include <vector>

using namespace std;

// Modulus for the final answer.
const long long MOD = 998244353;

// Larger[x]  = number of input elements strictly greater than x.
// Smaller[x] = number of input elements strictly less than x.
// Both are filled by buildCycleInversions() from the value-sorted array.
map<long long, long long> Larger;
map<long long, long long> Smaller;

// An input element together with its original (0-based) position.
class node {
public:
    long long index;  // position in the input sequence
    long long val;    // element value
    node() : index(0), val(0) {}
    node(long long idx, long long v) : index(idx), val(v) {}
};

// Merges the sorted runs indices[left..mid] and indices[mid+1..right] in place.
// Returns the number of pairs (p, q), p in the left run and q in the right run,
// with indices[p] > indices[q].
long long mergeAndCount(vector<long long>& indices, long long left, long long mid, long long right) {
    vector<long long> leftArr(indices.begin() + left, indices.begin() + mid + 1);
    vector<long long> rightArr(indices.begin() + mid + 1, indices.begin() + right + 1);
    size_t i = 0, j = 0;
    long long k = left;
    long long invCount = 0;
    while (i < leftArr.size() && j < rightArr.size()) {
        if (leftArr[i] <= rightArr[j]) {
            indices[k++] = leftArr[i++];
        } else {
            indices[k++] = rightArr[j++];
            // Every element still pending in leftArr is greater than the one just taken.
            invCount += static_cast<long long>(leftArr.size() - i);
        }
    }
    while (i < leftArr.size()) indices[k++] = leftArr[i++];
    while (j < rightArr.size()) indices[k++] = rightArr[j++];
    return invCount;
}

// Merge-sorts indices[left..right] in place and returns its inversion count.
long long mergeSortAndCount(vector<long long>& indices, long long left, long long right) {
    long long invCount = 0;
    if (left < right) {
        long long mid = left + (right - left) / 2;
        invCount += mergeSortAndCount(indices, left, mid);
        invCount += mergeSortAndCount(indices, mid + 1, right);
        invCount += mergeAndCount(indices, left, mid, right);
    }
    return invCount;
}

// Counts inversions (i < j with A[i] > A[j]) of the original sequence.
// Precondition: `v` is sorted by value with ties broken by ascending index.
// Under that ordering an inversion of the original sequence corresponds exactly
// to an inversion of the index sequence read off `v`; equal values stay in
// index order, so they never count.
long long countInversions(const vector<node>& v) {
    if (v.empty()) return 0;
    vector<long long> indices(v.size());
    for (size_t i = 0; i < v.size(); i++) {
        indices[i] = v[i].index;
    }
    return mergeSortAndCount(indices, 0, static_cast<long long>(indices.size()) - 1);
}

// Fills Smaller[] and Larger[] from `v`, which must be sorted ascending by value.
// The first occurrence of a value at position i has exactly i smaller elements
// before it. The last occurrence at position i has n-1-i larger elements after it.
void buildCycleInversions(const vector<node>& v) {
    if (v.empty()) return;
    const long long n = static_cast<long long>(v.size());
    for (long long i = 0; i < n; i++) {
        if (i == 0 || v[i].val != v[i - 1].val) Smaller[v[i].val] = i;
    }
    for (long long i = n - 1; i >= 0; i--) {
        if (i == n - 1 || v[i].val != v[i + 1].val) Larger[v[i].val] = n - 1 - i;
    }
}

// Fills Inv[i] (0 <= i < N) with the inversion count of v_copy rotated left by i,
// starting from Inv[0] = init, the inversion count of v_copy itself.
// Moving the front element x to the back removes the Smaller[x] inversions it
// headed and adds Larger[x] new ones.
// Returns the plain (not reduced modulo MOD) sum of Inv[0..N-1].
// Requires buildCycleInversions() to have run on the same multiset of values.
long long buildInv(const vector<node>& v_copy, vector<long long>& Inv, long long init, const long long& N) {
    if (N <= 0) return 0;
    long long sum = init;
    Inv[0] = init;
    for (long long i = 1; i < N; i++) {
        const long long x = v_copy[i - 1].val;
        Inv[i] = Inv[i - 1] + Larger[x] - Smaller[x];
        sum += Inv[i];
    }
    return sum;
}

// Reads N, K and the sequence A, then prints
//   (K / N) * sum(Inv) + sum(Inv[0 .. K%N - 1]) + K*(K-1)/2 * groupInversion   (mod MOD).
// NOTE(review): this presumably is the inversion count of K concatenated blocks,
// block k being A rotated left by (k mod N). That reading is inferred from the
// formula; confirm against the problem statement.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    long long N, K;
    if (!(cin >> N >> K)) return 0;
    if (N <= 0) {  // empty sequence: no inversions (also avoids K / N below)
        cout << 0 << '\n';
        return 0;
    }

    vector<node> v(N);
    for (long long i = 0; i < N; i++) {
        long long value;
        cin >> value;
        v[i] = node(i, value);
    }
    const vector<node> v_copy = v;  // original order, needed by buildInv

    // Ties MUST be broken by index: std::sort is not stable, and equal values
    // out of index order would be miscounted as inversions by countInversions.
    sort(v.begin(), v.end(), [](const node& a, const node& b) {
        return a.val != b.val ? a.val < b.val : a.index < b.index;
    });

    const long long Inversions = countInversions(v);
    buildCycleInversions(v);
    vector<long long> Inv(N, 0);
    const long long cycleInversion = buildInv(v_copy, Inv, Inversions, N);

    // Rotation part: a full cycles of N rotations, plus the first b rotations.
    const long long a = K / N;
    const long long b = K % N;
    long long indeSum = (a % MOD) * (cycleInversion % MOD) % MOD;
    for (long long i = 0; i < b; i++) {
        indeSum = (indeSum + Inv[i] % MOD) % MOD;
    }

    // Number of ordered pairs of elements (x, y) with x > y.
    long long groupInversion = 0;
    for (const node& e : v_copy) {
        groupInversion += Smaller[e.val];
    }
    groupInversion %= MOD;

    // K*(K-1)/2 mod MOD. Halve the even factor BEFORE reducing: dividing an
    // already-reduced product by 2 is wrong once K*(K-1) >= MOD.
    const long long pairCount = (K % 2 == 0)
        ? (K / 2) % MOD * ((K - 1) % MOD) % MOD
        : K % MOD * (((K - 1) / 2) % MOD) % MOD;
    const long long groupSum = pairCount * groupInversion % MOD;

    cout << (indeSum + groupSum) % MOD << '\n';
    return 0;
}
Editor is loading...
Leave a Comment