import os
import sys

# Make the repository's Python sources (including the noma package) importable
# when this test file is run directly.
sys.path.append(os.path.join(os.path.dirname(__file__), "../noma"))
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

import unittest
from itertools import product

import numpy as np
import torch

from apsm import LinearDictionary, GaussianDictionary, SumKernelDictionary, NomaApsm, RkhsDictionary
from noma.cpp_wrapper import ApsmNomaDetectorWrapper, test_func
from utils import *  # expected to provide data_parse and process_complex_data_daniyal

DTYPE = torch.float32
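

# This test drives the pure-Python APSM NOMA detector and the pybind-wrapped C++
# implementation with the same recorded data sets and checks that the learned
# dictionaries, coefficients and detection results agree to about 5 decimal places.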
class PybindTest(unittest.TestCase):
    def test_equivalence_between_cpp_and_python(self):
        parameters = {
            "filenames": [
                "test_data_Gaussian_weight_0.5_antennas_2_users_3_4QAM_unittest.bin",
                "test_data_Gaussian_weight_0.5_antennas_16_users_6_4QAM_unittest.bin",
                "test_data_Gaussian_weight_1_antennas_2_users_3_4QAM_unittest.bin",
                "test_data_Gaussian_weight_1_antennas_16_users_6_4QAM_unittest.bin",
                "test_data_Gaussian_weight_1e-10_antennas_2_users_3_4QAM_unittest.bin",
                "test_data_Gaussian_weight_1e-10_antennas_16_users_6_4QAM_unittest.bin",
                "test_data_Gaussian_weight_0.5_antennas_16_users_6_16QAM_unittest.bin",
                "test_data_Gaussian_weight_1_antennas_16_users_6_16QAM_unittest.bin",
                "test_data_Gaussian_weight_1e-15_antennas_16_users_6_16QAM_unittest.bin",
                "test_data_Gaussian_weight_0.5_antennas_16_users_6_64QAM_unittest.bin",
                "test_data_Gaussian_weight_1_antennas_16_users_6_64QAM_unittest.bin",
                "test_data_Gaussian_weight_1e-15_antennas_16_users_6_64QAM_unittest.bin",
            ],
            "impl_id": [("original", "original"), ("shmem", "shmem"), ("split", "oldfast"), ("gramcache", "balanced")],
        }
        keys, values = zip(*parameters.items())
        pars = [dict(zip(keys, p)) for p in product(*values)]
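        # One subtest per (data file, C++ train/detect implementation) combination.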
        for par in pars:
            with self.subTest(par):
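                # Load the recorded received/transmitted signals for this scenario.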
                full_filename = os.path.join(os.path.dirname(__file__), "../../cpp/data/tests", par["filenames"])
                data = data_parse(full_filename)
                X_train = process_complex_data_daniyal(data.rxSigTraining)
                X_test = process_complex_data_daniyal(data.rxSigData)
                y_train = data.txSigTraining
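                # Pure-Python reference detector. Note that window_size here is
                # 2 * data.window_size, whereas the C++ config below uses
                # data.window_size directly.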
                apsm = NomaApsm(
                    window_size=data.window_size * 2,
                    hyperslab_width=data.eB,
                    gaussian_weight=data.gaussian_weight,
                    gaussian_variance=data.gaussian_variance,
                    dtype=DTYPE,
                )
                apsm.train_numpy(X_train, y_train.reshape(-1))
                result = apsm.detect_numpy(X_test)
                # Run the same scenario through the pybind-wrapped C++ implementation.
                X_train_cpp = data.rxSigTraining.reshape(data.rxSigTraining.shape[0], -1)
                X_test_cpp = data.rxSigData.reshape(data.rxSigData.shape[0], -1)
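                # C++ detector configuration. gaussianKernelVariance is divided by
                # num_antennas**2 relative to the value passed to NomaApsm above.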
                config = {
                    "otype": "APSM",
                    "llsInitialization": False,
                    "trainingShuffle": False,
                    "gaussianKernelWeight": data.gaussian_weight,
                    "gaussianKernelVariance": data.gaussian_variance / (data.num_antennas * data.num_antennas),
                    "windowSize": data.window_size,
                    "startWithFullWindow": False,
                    "eB": data.eB,
                    "trainVersion": par["impl_id"][0],
                    "detectVersion": par["impl_id"][1],
                }
                apsm_instance = ApsmNomaDetectorWrapper.build(config, data.num_antennas)
                apsm_instance.train(X_train_cpp, y_train)
                basis_cpp, gaussian_weights_cpp, linear_coeffs_cpp = apsm_instance.train_state()
                result_cpp, detect_time = apsm_instance.detect(X_test_cpp)
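                # Pick the matching dictionaries from the Python model: a pure linear
                # or pure Gaussian kernel uses a single dictionary, any mixed weighting
                # uses a SumKernelDictionary holding one of each.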
                linear_dict, gaussian_dict = None, None
                if data.linear_weight == 1:
                    linear_dict = apsm.rkhs_dictionary
                elif data.gaussian_weight == 1:
                    gaussian_dict = apsm.rkhs_dictionary
                else:
                    sum_dict = apsm.rkhs_dictionary
                    self.assertIsInstance(sum_dict, SumKernelDictionary)
                    linear_dict = sum_dict.get_dictionaries_of_type(LinearDictionary)[0]
                    gaussian_dict = sum_dict.get_dictionaries_of_type(GaussianDictionary)[0]
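                # Learned state and detection output must match between C++ and Python.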
                if linear_dict:
                    np.testing.assert_almost_equal(
                        linear_coeffs_cpp.reshape(-1), linear_dict.weights.numpy(), decimal=5
                    )
                if gaussian_dict:
                    gaussian_basis_python = gaussian_dict.dictionary_elements.numpy()
                    np.testing.assert_equal(basis_cpp, gaussian_basis_python[:, : data.num_antennas])
                    np.testing.assert_almost_equal(
                        gaussian_weights_cpp.reshape(-1), gaussian_dict.weights.numpy(), decimal=5
                    )
                np.testing.assert_almost_equal(result, result_cpp.reshape(-1), decimal=5)


if __name__ == "__main__":
    unittest.main()

# Alternative entry point via pytest:
# import pytest, sys
# if __name__ == '__main__':
#     pytest.main([sys.argv[0], '-vvv'])
#     exit(0)
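
# To run this file directly (the .bin test vectors must be present under
# ../../cpp/data/tests relative to this file):
#     python <path/to/this/file>.py -v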