Untitled

mail@pastecode.io avatar
unknown
plain_text
a year ago
6.9 kB
1
Indexable
Never
import re

import pytest
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

from src.dad_model.spark_model.stage_1_string_indexer.string_indexer import (
    DictStringIndexer,
    DataFrameStringIndexer,
)


@pytest.fixture(scope="session")
def spark():
    """Yield one local SparkSession shared by every test in the session.

    The session runs with master ``local[1]`` and a single executor core,
    instance and shuffle partition, and binds the driver to 127.0.0.1.
    It is stopped after the last test of the session finishes.
    """
    builder = SparkSession.builder.master("local[1]").appName("local-tests")
    # Applied in the same order as a chained .config(...) sequence.
    for key, value in (
        ("spark.executor.cores", "1"),
        ("spark.executor.instances", "1"),
        ("spark.sql.shuffle.partitions", "1"),
        ("spark.driver.bindAddress", "127.0.0.1"),
    ):
        builder = builder.config(key, value)
    session = builder.getOrCreate()

    yield session

    # Teardown: runs once, after all session-scoped users are done.
    session.stop()



@pytest.fixture(scope="module")
def sample_df(spark):
    """Training data: a single ``fruit`` column with five rows.

    Three distinct values appear: apple and banana twice each, cherry once.
    """
    fruits = ["apple", "banana", "cherry", "apple", "banana"]
    rows = [(name,) for name in fruits]
    return spark.createDataFrame(rows, ["fruit"])


@pytest.fixture(scope="module")
def test_df(spark):
    """Scoring data: a single ``fruit`` column with four distinct rows.

    apple and banana also occur in ``sample_df``; date and elderberry do not.
    """
    fruits = ["apple", "banana", "date", "elderberry"]
    rows = [(name,) for name in fruits]
    return spark.createDataFrame(rows, ["fruit"])


# Testing DictStringIndexer
def test_distributed_string_indexer(spark, sample_df, test_df):
    """DictStringIndexer indexes strings seen during fit and drops unseen ones.

    The indexer is fitted on ``sample_df`` (apple, banana, cherry) and applied
    to ``test_df`` (apple, banana, date, elderberry).
    """
    indexer = DictStringIndexer()
    indexer.fit(sample_df, "fruit")
    transformed_df = indexer.transform(test_df, "fruit", "fruit_index")

    # Collect once instead of running one Spark job per assertion. A missing
    # fruit then fails with a KeyError naming it, rather than `None[0]`.
    indices = {
        row["fruit"]: row["fruit_index"]
        for row in transformed_df.select("fruit", "fruit_index").collect()
    }

    # Check if the known fruits get indexed
    assert indices["apple"] == 0
    assert indices["banana"] == 1

    # Check if unknown fruits are removed from the DataFrame
    assert "date" not in indices
    assert "elderberry" not in indices


# Testing DataFrameStringIndexer
def test_dataframe_string_indexer(spark, sample_df, test_df):
    """DataFrameStringIndexer indexes strings seen during fit and drops unseen ones.

    The indexer is fitted on ``sample_df`` (apple, banana, cherry) and applied
    to ``test_df`` (apple, banana, date, elderberry).
    """
    # Must exercise DataFrameStringIndexer. Using DictStringIndexer here would
    # leave the DataFrame-backed implementation's fit/transform untested.
    indexer = DataFrameStringIndexer()
    indexer.fit(sample_df, "fruit")
    transformed_df = indexer.transform(test_df, "fruit", "fruit_index")

    # Collect once instead of running one Spark job per assertion. A missing
    # fruit then fails with a KeyError naming it, rather than `None[0]`.
    indices = {
        row["fruit"]: row["fruit_index"]
        for row in transformed_df.select("fruit", "fruit_index").collect()
    }

    # Check if the known fruits get indexed
    assert indices["apple"] == 0
    assert indices["banana"] == 1

    # Check if unknown fruits are removed from the DataFrame
    assert "date" not in indices
    assert "elderberry" not in indices


def test_distributed_string_indexer_refit(spark, sample_df, test_df):
    """Refitting DictStringIndexer keeps existing indices and appends new ones."""

    def indices_of(df):
        # Map each fruit present in ``df`` to its index with a single collect.
        return {
            row["fruit"]: row["fruit_index"]
            for row in df.select("fruit", "fruit_index").collect()
        }

    indexer = DictStringIndexer()
    indexer.fit(sample_df, "fruit")
    # Materialize before refitting so the comparison below is a real
    # before/after check of the same fruits.
    before = indices_of(indexer.transform(test_df, "fruit", "fruit_index"))

    # Simulate refitting with another DataFrame
    refit_data = [("date",), ("elderberry",), ("apple",)]
    refit_df = spark.createDataFrame(refit_data, ["fruit"])
    indexer.fit(refit_df, "fruit")

    after = indices_of(indexer.transform(test_df, "fruit", "fruit_index"))

    # Check if indices for known fruits from the first fitting remain the same
    assert before["apple"] == after["apple"] == 0
    assert before["banana"] == after["banana"] == 1

    # Check if new strings from the refitting get new indices
    # (0-2 were taken by apple, banana and cherry in the first fit).
    assert after["date"] == 3
    assert after["elderberry"] == 4


def test_dataframe_string_indexer_refit(spark, sample_df, test_df):
    """Refitting DataFrameStringIndexer keeps existing indices and appends new ones."""

    def indices_of(df):
        # Map each fruit present in ``df`` to its index with a single collect.
        return {
            row["fruit"]: row["fruit_index"]
            for row in df.select("fruit", "fruit_index").collect()
        }

    indexer = DataFrameStringIndexer()
    indexer.fit(sample_df, "fruit")
    # Materialize before refitting so the comparison below is a real
    # before/after check of the same fruits.
    before = indices_of(indexer.transform(test_df, "fruit", "fruit_index"))

    # Simulate refitting with another DataFrame
    refit_data = [("date",), ("elderberry",), ("apple",)]
    refit_df = spark.createDataFrame(refit_data, ["fruit"])
    indexer.fit(refit_df, "fruit")

    after = indices_of(indexer.transform(test_df, "fruit", "fruit_index"))

    # Check if indices for known fruits from the first fitting remain the same
    assert before["apple"] == after["apple"] == 0
    assert before["banana"] == after["banana"] == 1

    # Check if new strings from the refitting get new indices
    # (0-2 were taken by apple, banana and cherry in the first fit).
    assert after["date"] == 3
    assert after["elderberry"] == 4


def test_distributed_string_indexer_instance_variables(spark, sample_df):
    """After fit, DictStringIndexer exposes forward and reverse lookup mappings.

    ``string_to_index`` and ``index_to_string`` must each hold exactly the
    three training fruits and the indices 0, 1, 2.
    """
    indexer = DictStringIndexer()
    indexer.fit(sample_df, "fruit")

    expected_strings = {"apple", "banana", "cherry"}
    expected_indices = {0, 1, 2}
    forward = indexer.string_to_index
    backward = indexer.index_to_string

    # Sizes first, then key/value contents of each direction.
    assert len(forward) == 3
    assert len(backward) == 3
    assert set(forward.keys()) == expected_strings
    assert set(forward.values()) == expected_indices
    assert set(backward.keys()) == expected_indices
    assert set(backward.values()) == expected_strings


def test_distributed_string_indexer_transform_without_fit(spark, sample_df):
    """DictStringIndexer.transform raises ValueError when called before fit."""
    indexer = DictStringIndexer()

    # pytest.raises(match=...) applies re.search. Escape the literal message
    # so its "." characters match only periods, not any character.
    with pytest.raises(
        ValueError,
        match=re.escape("The indexer has not been fitted yet. Use fit first."),
    ):
        indexer.transform(sample_df, "fruit", "fruit_index")


def test_dataframe_string_indexer_instance_variables(spark, sample_df):
    """After fit, DataFrameStringIndexer stores its mapping in ``mapping_df``.

    The mapping must have one row per training fruit, with columns ``string``
    and ``index``, and the indexer must report itself as fitted.
    """
    indexer = DataFrameStringIndexer()
    indexer.fit(sample_df, "fruit")

    # Check instance variables
    assert indexer.mapping_df.count() == 3
    # A single DataFrame collect replaces two RDD flatMap round-trips; the RDD
    # API is unavailable under Spark Connect. Use item access: `row.index` is
    # the tuple method, not the column.
    rows = indexer.mapping_df.select("string", "index").collect()
    assert {row["string"] for row in rows} == {"apple", "banana", "cherry"}
    assert {row["index"] for row in rows} == {0, 1, 2}
    # Identity check instead of `== True` (flake8/ruff E712).
    assert indexer._is_fitted is True


def test_dataframe_string_indexer_transform_without_fit(spark, sample_df):
    """DataFrameStringIndexer.transform raises ValueError when called before fit."""
    indexer = DataFrameStringIndexer()

    # pytest.raises(match=...) applies re.search. Escape the literal message
    # so its "." characters match only periods, not any character.
    with pytest.raises(
        ValueError,
        match=re.escape("The indexer has not been fitted yet. Use fit first."),
    ):
        indexer.transform(sample_df, "fruit", "fruit_index")