#include <abd/math/make_pd.h>
#include <cstdio>
namespace abd {
namespace {
// Scans row k of the row-major n-by-n matrix S, strictly to the right of the
// diagonal, and returns the column holding the entry of largest magnitude.
// On ties the lowest such column wins. The scan starts from column k + 1, so
// for the last row (k == n - 1) the result is n and no entry of S is read.
template <class T>
INLINE_HOST_DEVICE int maxind(int k, T* S, int n)
{
    const int row_start = k * n;
    int best = k + 1;
    int col = k + 2;
    while (col < n) {
        // Strict comparison: an equal-magnitude entry never displaces `best`.
        if (math::abs(S[row_start + col]) > math::abs(S[row_start + best]))
            best = col;
        ++col;
    }
    return best;
}
// Adds t to e[k] and keeps the change bookkeeping in sync.
//
//   y        receives the value e[k] had before the addition
//   state    count of entries currently flagged in `changed`
//   changed  per-entry flag; cleared when the addition left e[k] unchanged,
//            set again when a previously settled entry changes
template <class T>
INLINE_HOST_DEVICE void update(int k, T t, T& y, int& state, T* e, int* changed)
{
    y = e[k];
    e[k] = y + t;

    // True when adding t actually altered the stored value.
    const bool moved = (y != e[k]);

    if (changed[k]) {
        if (!moved) {
            changed[k] = 0;
            state -= 1;
        }
    } else if (moved) {
        changed[k] = 1;
        state += 1;
    }
}
// Applies a plane rotation with sine s and cosine c to the pair of entries
// S(k, l) and S(i, j) of the row-major n-by-n matrix S:
//   S(k, l) <- c * S(k, l) - s * S(i, j)
//   S(i, j) <- s * S(k, l) + c * S(i, j)
// Both right-hand sides use the values from before the call.
template <class T>
INLINE_HOST_DEVICE void rotate(int k, int l, int i, int j, T s, T c, T* S, int n)
{
    T& first = S[k * n + l];
    T& second = S[i * n + j];

    // Snapshot both inputs before either slot is overwritten.
    const T old_first = first;
    const T old_second = second;

    first = c * old_first - s * old_second;
    second = s * old_first + c * old_second;
}
// Replaces the n-by-n row-major matrix S, in place, by the matrix obtained
// from its eigen-decomposition with every negative eigenvalue clamped to zero.
//
// The decomposition is computed with a Jacobi eigenvalue iteration: plane
// rotations repeatedly annihilate the largest off-diagonal entry. The
// iteration only reads and writes entries on and above the diagonal, so S is
// treated as symmetric.
//
// Parameters:
//   S  in/out, n*n entries. On return holds E * diag(max(e, 0)) * E^T.
//   E  out, n*n entries. On return its columns hold the accumulated rotations
//      (the eigenvectors).
//
// Returns the largest clamped eigenvalue. If every entry of S is smaller than
// 1e-12 in magnitude the function returns 0 immediately and neither S nor E
// is written.
template <class T, int n>
INLINE_HOST_DEVICE T make_pd_gpu(T* S, T* E)
{
    // Largest absolute entry: sets the scale of the convergence test below.
    T max_S = 0;
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            max_S = math::max(max_S, math::abs(S[i * n + j]));
    if (max_S < 1e-12)
        return 0;

    T e[n];          // running eigenvalue estimates (starts as the diagonal of S)
    int ind[n];      // ind[k]: column of the largest off-diagonal entry in row k
    int changed[n];  // changed[k] != 0 while e[k] was altered by its last update
    int k, l, m, state;
    T s, c, t, p, y, d, r;

    // E starts as the identity and accumulates every rotation.
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            E[i * n + j] = (i == j ? T(1) : T(0));

    state = n;
    for (k = 0; k < n; ++k) {
        ind[k] = maxind(k, S, n);
        e[k] = S[k * n + k];
        changed[k] = true;
    }

    // A 1-by-1 matrix has no off-diagonal pivot; without the n > 1 guard the
    // pivot lookup below would read S[1], which is out of bounds.
    while (n > 1 && state != 0) {
        // Pivot: largest row maximum over the rows that have an off-diagonal
        // part (rows 0 .. n-2).
        m = 0;
        for (k = 1; k < n - 1; ++k) {
            if (math::abs(S[k * n + ind[k]]) > math::abs(S[m * n + ind[m]]))
                m = k;
        }
        k = m;
        l = ind[m];
        p = S[k * n + l];
        // Converged: the largest off-diagonal entry is negligible relative to
        // the original matrix scale. This also guarantees p != 0 below.
        if (math::abs(p) < max_S * 1e-9)
            break;

        // Rotation angle (c = cos, s = sin) and eigenvalue shift t.
        y = (e[l] - e[k]) / T(2);
        d = math::abs(y) + math::sqrt(p * p + y * y);
        r = math::sqrt(p * p + d * d);
        c = d / r;
        s = p / r;
        t = p * p / d;
        if (y < 0) {
            s = -s;
            t = -t;
        }

        S[k * n + l] = T(0);
        update(k, -t, y, state, e, changed);
        update(l, t, y, state, e, changed);

        // Rotate rows/columns k and l, touching the upper triangle only.
        for (int i = 0; i <= k - 1; ++i)
            rotate(i, k, i, l, s, c, S, n);
        for (int i = k + 1; i <= l - 1; ++i)
            rotate(k, i, i, l, s, c, S, n);
        for (int i = l + 1; i < n; ++i)
            rotate(k, i, l, i, s, c, S, n);

        // Rotate the eigenvector columns k and l.
        for (int i = 0; i < n; ++i) {
            T Eik = E[i * n + k], Eil = E[i * n + l];
            E[i * n + k] = c * Eik - s * Eil;
            E[i * n + l] = s * Eik + c * Eil;
        }

        // The rotation rewrote entries (i, k) and (i, l) of every row i < l as
        // well as rows k and l themselves, so the cached row maxima of all
        // rows 0 .. l may be stale. Refreshing only ind[k] and ind[l] lets the
        // pivot search miss a large off-diagonal entry and stop too early.
        // Rows beyond l are untouched and keep their cached value.
        for (int i = 0; i <= l; ++i)
            ind[i] = maxind(i, S, n);
    }

    // Clamp negative eigenvalues to zero.
    for (int i = 0; i < n; ++i)
        e[i] = math::max(e[i], T(0));

    // Rebuild S = E * diag(e) * E^T.
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            S[i * n + j] = T(0);
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j) {
            const T weighted = E[i * n + j] * e[j];
            for (int col = 0; col < n; ++col)
                S[i * n + col] += weighted * E[col * n + j];
        }

    T max_e = e[0];
    for (int i = 1; i < n; ++i)
        max_e = math::max(max_e, e[i]);
    return max_e;
}
}
// Public entry point: forwards to make_pd_gpu for a dim-by-dim matrix.
//
//   sym_A  in/out, dim*dim entries; passed as the matrix to be clamped
//   local  out, dim*dim entries; passed as the eigenvector buffer
//
// Returns whatever make_pd_gpu returns (the largest clamped eigenvalue, or 0
// when the input is numerically zero).
template <class T, int dim>
T make_pd(T* sym_A, T* local)
{
    const T largest_eigenvalue = make_pd_gpu<T, dim>(sym_A, local);
    return largest_eigenvalue;
}
// Explicit instantiations of make_pd for double precision at the matrix
// dimensions 3, 9, 12 and 24. make_pd is defined in this file, so these are
// the specializations this translation unit emits.
template double make_pd<double, 3>(double* sym_A, double* local);
template double make_pd<double, 9>(double* sym_A, double* local);
template double make_pd<double, 12>(double* sym_A, double* local);
template double make_pd<double, 24>(double* sym_A, double* local);
}