# Paste-site header (not part of the program): "Untitled", author unknown,
# language python, posted 5 months ago, 5.3 kB.
import os
import sys
import typing

# Extend the module search path with the sibling "../noma" directory and the
# parent directory of this file (presumably so the project-local imports
# further down resolve when the test is run from here -- confirm).
sys.path.append(os.path.join(os.path.dirname(__file__), "../noma"))
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

import unittest
from unittest.mock import Mock

from utils import *
import numpy as np
import jaxtyping
from itertools import product

from apsm import LinearDictionary, GaussianDictionary, SumKernelDictionary, NomaApsm, RkhsDictionary

from noma.cpp_wrapper import ApsmNomaDetectorWrapper, test_func
import torch

# Torch dtype passed as ``dtype=`` to the Python NomaApsm instance in the test below.
DTYPE = torch.float32


class PybindTest(unittest.TestCase):
    """Compare the pybind-wrapped C++ APSM detector with the Python NomaApsm."""

    def test_equivalence_between_cpp_and_python(self):
        """Train and detect with both implementations and compare the outcome.

        Every listed data file is combined with every (train, detect) version
        pair of the C++ wrapper. Each combination runs in its own sub-test, so
        a failure in one combination is reported without stopping the others.
        The learned dictionary state and the detection output of both
        implementations must agree to 5 decimals.
        """
        parameters = {
            "filenames": [
                "test_data_Gaussian_weight_0.5_antennas_2_users_3_4QAM_unittest.bin",
                "test_data_Gaussian_weight_0.5_antennas_16_users_6_4QAM_unittest.bin",
                "test_data_Gaussian_weight_1_antennas_2_users_3_4QAM_unittest.bin",
                "test_data_Gaussian_weight_1_antennas_16_users_6_4QAM_unittest.bin",
                "test_data_Gaussian_weight_1e-10_antennas_2_users_3_4QAM_unittest.bin",
                "test_data_Gaussian_weight_1e-10_antennas_16_users_6_4QAM_unittest.bin",
                "test_data_Gaussian_weight_0.5_antennas_16_users_6_16QAM_unittest.bin",
                "test_data_Gaussian_weight_1_antennas_16_users_6_16QAM_unittest.bin",
                "test_data_Gaussian_weight_1e-15_antennas_16_users_6_16QAM_unittest.bin",
                "test_data_Gaussian_weight_0.5_antennas_16_users_6_64QAM_unittest.bin",
                "test_data_Gaussian_weight_1_antennas_16_users_6_64QAM_unittest.bin",
                "test_data_Gaussian_weight_1e-15_antennas_16_users_6_64QAM_unittest.bin",
            ],
            # Each entry is a (trainVersion, detectVersion) pair for the C++ config.
            "impl_id": [("original", "original"), ("shmem", "shmem"), ("split", "oldfast"), ("gramcache", "balanced")],
        }

        # Cartesian product of all parameter values -> one dict per combination.
        keys, values = zip(*parameters.items())
        pars = [dict(zip(keys, p)) for p in product(*values)]

        for par in pars:
            with self.subTest(par):
                train_version, detect_version = par["impl_id"]

                full_filename = os.path.join(os.path.dirname(__file__), "../../cpp/data/tests", par["filenames"])
                data = data_parse(full_filename)

                # Python reference implementation.
                X_train = process_complex_data_daniyal(data.rxSigTraining)
                X_test = process_complex_data_daniyal(data.rxSigData)
                y_train = data.txSigTraining
                # NOTE(review): the Python side uses twice the window size given
                # to the C++ config below -- presumably tied to how the complex
                # input is preprocessed; confirm.
                apsm = NomaApsm(
                    window_size=data.window_size * 2,
                    hyperslab_width=data.eB,
                    gaussian_weight=data.gaussian_weight,
                    gaussian_variance=data.gaussian_variance,
                    dtype=DTYPE,
                )
                apsm.train_numpy(X_train, y_train.reshape(-1))
                result = apsm.detect_numpy(X_test)

                # Using pybind: keep the first axis and flatten the rest.
                X_train_cpp = data.rxSigTraining.reshape(data.rxSigTraining.shape[0], -1)
                X_test_cpp = data.rxSigData.reshape(data.rxSigData.shape[0], -1)
                config = {
                    "otype": "APSM",
                    "llsInitialization": False,
                    "trainingShuffle": False,
                    "gaussianKernelWeight": data.gaussian_weight,
                    # NOTE(review): variance is divided by num_antennas**2 for
                    # the C++ side only -- presumably a different normalisation
                    # convention; confirm.
                    "gaussianKernelVariance": data.gaussian_variance / (data.num_antennas * data.num_antennas),
                    "windowSize": data.window_size,
                    "startWithFullWindow": False,
                    "eB": data.eB,
                    "trainVersion": train_version,
                    "detectVersion": detect_version,
                }
                apsm_instance = ApsmNomaDetectorWrapper.build(config, data.num_antennas)
                apsm_instance.train(X_train_cpp, y_train)
                basis_cpp, gaussian_weights_cpp, linear_coeffs_cpp = apsm_instance.train_state()
                # The detection time is returned by the wrapper but not checked here.
                result_cpp, _detect_time = apsm_instance.detect(X_test_cpp)

                # Pick out which Python dictionaries exist for this kernel mix:
                # purely linear, purely Gaussian, or a sum of both.
                linear_dict, gaussian_dict = None, None
                if data.linear_weight == 1:
                    linear_dict = apsm.rkhs_dictionary
                elif data.gaussian_weight == 1:
                    gaussian_dict = apsm.rkhs_dictionary
                else:
                    sum_dict = apsm.rkhs_dictionary
                    self.assertIsInstance(sum_dict, SumKernelDictionary)
                    linear_dict = sum_dict.get_dictionaries_of_type(LinearDictionary)[0]
                    gaussian_dict = sum_dict.get_dictionaries_of_type(GaussianDictionary)[0]

                # Compare against None explicitly: relying on the truthiness of
                # a dictionary object could silently skip these checks if the
                # class defines __len__/__bool__ and evaluates as falsy.
                if linear_dict is not None:
                    np.testing.assert_almost_equal(
                        linear_coeffs_cpp.reshape(-1), linear_dict.weights.numpy(), decimal=5
                    )
                if gaussian_dict is not None:
                    gaussian_basis_python = gaussian_dict.dictionary_elements.numpy()
                    # Only the first num_antennas columns of the Python basis
                    # are compared with the C++ basis, and they must match exactly.
                    np.testing.assert_equal(basis_cpp, gaussian_basis_python[:, : data.num_antennas])
                    np.testing.assert_almost_equal(
                        gaussian_weights_cpp.reshape(-1), gaussian_dict.weights.numpy(), decimal=5
                    )

                # Finally the detection outputs themselves must agree.
                np.testing.assert_almost_equal(result, result_cpp.reshape(-1), decimal=5)


# Run the test case with the unittest runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()

# Disabled alternative entry point, kept for reference: runs this same file
# through pytest with very verbose output instead of the unittest runner.
# import pytest,sys
# if __name__ == '__main__':
#     pytest.main([sys.argv[0], '-vvv',])
#     exit(0)
# Paste-site footer (not part of the program): "Editor is loading..." / "Leave a Comment"