Untitled
unknown
plain_text
2 years ago
6.9 kB
6
Indexable
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from src.dad_model.spark_model.stage_1_string_indexer.string_indexer import (
DictStringIndexer,
DataFrameStringIndexer,
)
from pyspark import SparkConf
@pytest.fixture(scope="session")
def spark():
    """Provide one local, single-core SparkSession for the test session.

    The session is yielded to the requesting tests and stopped after the
    ``yield`` when the fixture is torn down.
    """
    settings = (
        ("spark.executor.cores", "1"),
        ("spark.executor.instances", "1"),
        ("spark.sql.shuffle.partitions", "1"),
        ("spark.driver.bindAddress", "127.0.0.1"),
    )
    builder = SparkSession.builder.master("local[1]").appName("local-tests")
    for key, value in settings:
        builder = builder.config(key, value)

    session = builder.getOrCreate()
    yield session
    session.stop()
@pytest.fixture(scope="module")
def sample_df(spark):
    """Build the single-column ``fruit`` DataFrame the indexers are fitted on.

    Contains five rows with three distinct values: apple, banana, cherry.
    """
    fruits = ["apple", "banana", "cherry", "apple", "banana"]
    rows = [(name,) for name in fruits]
    return spark.createDataFrame(rows, ["fruit"])
@pytest.fixture(scope="module")
def test_df(spark):
    """Build the single-column ``fruit`` DataFrame passed to ``transform``.

    Contains two values also present in ``sample_df`` (apple, banana) and
    two that are not (date, elderberry).
    """
    fruits = ["apple", "banana", "date", "elderberry"]
    rows = [(name,) for name in fruits]
    return spark.createDataFrame(rows, ["fruit"])
# Testing DictStringIndexer
def test_distributed_string_indexer(spark, sample_df, test_df):
    """Fit a DictStringIndexer on ``sample_df`` and transform ``test_df``.

    Asserts that apple and banana are indexed 0 and 1, and that the fruits
    absent from the fitting data (date, elderberry) have no rows in the
    transformed DataFrame.
    """
    indexer = DictStringIndexer()
    indexer.fit(sample_df, "fruit")
    result = indexer.transform(test_df, "fruit", "fruit_index")

    def rows_for(name):
        # Rows of the transformed frame whose fruit equals ``name``.
        return result.filter(col("fruit") == name)

    # Fruits seen during fit must carry the expected index.
    for name, expected in (("apple", 0), ("banana", 1)):
        assert rows_for(name).select("fruit_index").first()[0] == expected

    # Fruits never seen during fit must be missing from the output.
    for name in ("date", "elderberry"):
        assert rows_for(name).count() == 0
# Testing DataFrameStringIndexer
def test_dataframe_string_indexer(spark, sample_df, test_df):
    """Fit a DataFrameStringIndexer on ``sample_df`` and transform ``test_df``.

    Asserts that apple and banana are indexed 0 and 1, and that the fruits
    absent from the fitting data (date, elderberry) have no rows in the
    transformed DataFrame.
    """
    # This test previously built a DictStringIndexer (copy-paste slip), so
    # the DataFrame-backed implementation was never exercised here.
    indexer = DataFrameStringIndexer()
    indexer.fit(sample_df, "fruit")
    transformed_df = indexer.transform(test_df, "fruit", "fruit_index")

    # Check if the known fruits get indexed
    assert (
        transformed_df.filter(col("fruit") == "apple")
        .select("fruit_index")
        .first()[0]
        == 0
    )
    assert (
        transformed_df.filter(col("fruit") == "banana")
        .select("fruit_index")
        .first()[0]
        == 1
    )

    # Check if unknown fruits are removed from the DataFrame
    assert transformed_df.filter(col("fruit") == "date").count() == 0
    assert transformed_df.filter(col("fruit") == "elderberry").count() == 0
def test_distributed_string_indexer_refit(spark, sample_df, test_df):
    """Fit a DictStringIndexer twice and check the resulting indices.

    After fitting on ``sample_df`` and then on a frame holding date,
    elderberry and apple, asserts that apple and banana are still indexed
    0 and 1 and that date and elderberry are indexed 3 and 4.
    """
    indexer = DictStringIndexer()
    indexer.fit(sample_df, "fruit")
    # Transform once between the two fits; the result is not inspected.
    indexer.transform(test_df, "fruit", "fruit_index")

    # Simulate refitting with another DataFrame.
    second_batch = spark.createDataFrame(
        [("date",), ("elderberry",), ("apple",)], ["fruit"]
    )
    indexer.fit(second_batch, "fruit")
    result = indexer.transform(test_df, "fruit", "fruit_index")

    def index_of(name):
        # Index assigned to ``name`` in the post-refit transform output.
        matching = result.filter(col("fruit") == name)
        return matching.select("fruit_index").first()[0]

    # Fruits from the first fit keep their indices.
    assert index_of("apple") == 0
    assert index_of("banana") == 1

    # Fruits introduced by the refit receive new indices.
    assert index_of("date") == 3
    assert index_of("elderberry") == 4
def test_dataframe_string_indexer_refit(spark, sample_df, test_df):
    """Fit a DataFrameStringIndexer twice and check the resulting indices.

    After fitting on ``sample_df`` and then on a frame holding date,
    elderberry and apple, asserts that apple and banana are still indexed
    0 and 1 and that date and elderberry are indexed 3 and 4.
    """
    indexer = DataFrameStringIndexer()
    indexer.fit(sample_df, "fruit")
    # Transform once between the two fits; the result is not inspected.
    indexer.transform(test_df, "fruit", "fruit_index")

    # Simulate refitting with another DataFrame.
    second_batch = spark.createDataFrame(
        [("date",), ("elderberry",), ("apple",)], ["fruit"]
    )
    indexer.fit(second_batch, "fruit")
    result = indexer.transform(test_df, "fruit", "fruit_index")

    def index_of(name):
        # Index assigned to ``name`` in the post-refit transform output.
        matching = result.filter(col("fruit") == name)
        return matching.select("fruit_index").first()[0]

    # Fruits from the first fit keep their indices.
    assert index_of("apple") == 0
    assert index_of("banana") == 1

    # Fruits introduced by the refit receive new indices.
    assert index_of("date") == 3
    assert index_of("elderberry") == 4
def test_distributed_string_indexer_instance_variables(spark, sample_df):
    """Check the two lookup mappings a DictStringIndexer holds after ``fit``.

    Asserts that ``string_to_index`` and ``index_to_string`` each have three
    entries pairing apple, banana and cherry with the indices 0, 1 and 2.
    """
    indexer = DictStringIndexer()
    indexer.fit(sample_df, "fruit")

    expected_fruits = {"apple", "banana", "cherry"}
    expected_indices = {0, 1, 2}

    # Both mappings must have one entry per distinct fruit.
    assert len(indexer.string_to_index) == 3
    assert len(indexer.index_to_string) == 3

    # Forward mapping: fruit -> index.
    assert set(indexer.string_to_index.keys()) == expected_fruits
    assert set(indexer.string_to_index.values()) == expected_indices

    # Reverse mapping: index -> fruit.
    assert set(indexer.index_to_string.keys()) == expected_indices
    assert set(indexer.index_to_string.values()) == expected_fruits
def test_distributed_string_indexer_transform_without_fit(spark, sample_df):
    """Check that DictStringIndexer.transform raises ValueError before fit.

    The expected message is matched literally: ``pytest.raises`` treats
    ``match`` as a regular expression, so the periods are escaped to stop
    them from matching arbitrary characters.
    """
    indexer = DictStringIndexer()

    # Check that calling transform before fit raises an error
    with pytest.raises(
        ValueError,
        match=r"The indexer has not been fitted yet\. Use fit first\.",
    ):
        indexer.transform(sample_df, "fruit", "fruit_index")
def test_dataframe_string_indexer_instance_variables(spark, sample_df):
    """Check the state a DataFrameStringIndexer holds after ``fit``.

    Asserts that ``mapping_df`` has three rows whose ``string`` column holds
    apple, banana and cherry and whose ``index`` column holds 0, 1 and 2,
    and that the ``_is_fitted`` flag is set.
    """
    indexer = DataFrameStringIndexer()
    indexer.fit(sample_df, "fruit")

    def column_values(name):
        # Collect one column of the mapping DataFrame into a Python set.
        column = indexer.mapping_df.select(name)
        return set(column.rdd.flatMap(lambda row: row).collect())

    # Check instance variables
    assert indexer.mapping_df.count() == 3
    assert column_values("string") == {"apple", "banana", "cherry"}
    assert column_values("index") == {0, 1, 2}
    # Plain truthiness instead of the non-idiomatic ``== True`` comparison.
    assert indexer._is_fitted
def test_dataframe_string_indexer_transform_without_fit(spark, sample_df):
    """Check that DataFrameStringIndexer.transform raises ValueError before fit.

    The expected message is matched literally: ``pytest.raises`` treats
    ``match`` as a regular expression, so the periods are escaped to stop
    them from matching arbitrary characters.
    """
    indexer = DataFrameStringIndexer()

    # Check that calling transform before fit raises an error
    with pytest.raises(
        ValueError,
        match=r"The indexer has not been fitted yet\. Use fit first\.",
    ):
        indexer.transform(sample_df, "fruit", "fruit_index")
Editor is loading...