Untitled

 avatar
unknown
python
5 months ago
2.3 kB
2
Indexable
import os
import sys
import typing


sys.path.append(os.path.join(os.path.dirname(__file__), "../noma"))
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

import unittest
from unittest.mock import Mock

from utils import *

import numpy as np
import jaxtyping
from itertools import product

from noma.cpp_wrapper import ApsmNomaDetectorWrapper, test_func

class LazyTest(unittest.TestCase):
    """Smoke test for the C++ APSM NOMA detector exposed via ``noma.cpp_wrapper``.

    For every combination of test-data file and (train, detect) implementation
    pair, the test builds a detector, trains it, reads back its training state
    and runs detection. It only checks that these calls complete without
    raising; the returned values are not compared against a reference.
    """

    def test_cuda_lazy(self):
        """Train and run detection for every parameter combination.

        Each combination runs inside its own ``subTest`` so a failure in one
        implementation pair does not hide the results of the others.
        """
        test_func()

        # Parameter grid: every element of the Cartesian product of the value
        # lists is tested. Each "impl_id" entry is a
        # (trainVersion, detectVersion) pair forwarded into the config.
        parameters = {
            "filenames": [
                "test_data_Gaussian_weight_0.5_antennas_2_users_3_4QAM_unittest.bin"
            ],
            "impl_id": [
                ("original", "original"),
                ("shmem", "shmem"),
                ("split", "oldfast"),
                ("gramcache", "balanced"),
            ],
        }

        keys, values = zip(*parameters.items())
        pars = [dict(zip(keys, p)) for p in product(*values)]

        # Loop-invariant: all test files live in the same directory.
        data_dir = os.path.join(os.path.dirname(__file__), "../../cpp/data/tests")

        for par in pars:
            # Keyword params (not a positional msg) make the runner report the
            # failing combination as filenames=..., impl_id=(...).
            with self.subTest(**par):
                full_filename = os.path.join(data_dir, par["filenames"])
                data = data_parse(full_filename)

                # Keep axis 0 and flatten the remaining axes into one
                # (presumably one row per sample -- confirm against data_parse).
                X_train_cpp = data.rxSigTraining.reshape(data.rxSigTraining.shape[0], -1)
                X_test_cpp = data.rxSigData.reshape(data.rxSigData.shape[0], -1)
                y_train = data.txSigTraining

                train_version, detect_version = par["impl_id"]
                config = {
                    "otype": "APSM",
                    "llsInitialization": False,
                    "trainingShuffle": False,
                    "gaussianKernelWeight": data.gaussian_weight,
                    # NOTE(review): variance is divided by num_antennas**2;
                    # presumably to match the detector's input scaling -- confirm.
                    "gaussianKernelVariance": data.gaussian_variance / (data.num_antennas * data.num_antennas),
                    "windowSize": data.window_size,
                    "startWithFullWindow": False,
                    "eB": data.eB,
                    "trainVersion": train_version,
                    "detectVersion": detect_version,
                }
                apsm_instance = ApsmNomaDetectorWrapper.build(config, data.num_antennas)
                apsm_instance.train(X_train_cpp, y_train)
                # The unpacking itself fails if train_state() does not yield
                # exactly three items, or detect() exactly two.
                basis_cpp, gaussian_weights_cpp, linear_coeffs_cpp = apsm_instance.train_state()
                result_cpp, detect_time = apsm_instance.detect(X_test_cpp)
                # NOTE(review): the trained state and detection result are not
                # checked against any reference; add assertions here if a
                # reference implementation becomes available.




if __name__ == "__main__":
    # Executed as a script: load and run the TestCase classes of this module.
    unittest.main(module="__main__")
Editor is loading...
Leave a Comment