Untitled
unknown
plain_text
2 years ago
6.9 kB
3
Indexable
"""Tests for the ``DictStringIndexer`` and ``DataFrameStringIndexer`` stages.

Both indexers are exercised through the same three scenarios: basic
fit/transform, refitting on additional data, and inspection of the fitted
state, plus the error raised when ``transform`` is called before ``fit``.
"""

import re

import pytest
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

from src.dad_model.spark_model.stage_1_string_indexer.string_indexer import (
    DataFrameStringIndexer,
    DictStringIndexer,
)

# Error message both indexers are expected to raise when ``transform`` is
# called before ``fit``.  It is passed through ``re.escape`` before being
# handed to ``pytest.raises(match=...)`` because ``match`` is a regular
# expression and the message contains ``.`` characters.
NOT_FITTED_MESSAGE = "The indexer has not been fitted yet. Use fit first."


@pytest.fixture(scope="session")
def spark():
    """Yield a single-core local SparkSession and stop it after the session."""
    session = (
        SparkSession.builder.master("local[1]")
        .appName("local-tests")
        .config("spark.executor.cores", "1")
        .config("spark.executor.instances", "1")
        .config("spark.sql.shuffle.partitions", "1")
        .config("spark.driver.bindAddress", "127.0.0.1")
        .getOrCreate()
    )
    yield session
    session.stop()


@pytest.fixture(scope="module")
def sample_df(spark):
    """Return the fitting DataFrame: one ``fruit`` column, three distinct values."""
    data = [("apple",), ("banana",), ("cherry",), ("apple",), ("banana",)]
    return spark.createDataFrame(data, ["fruit"])


@pytest.fixture(scope="module")
def test_df(spark):
    """Return the DataFrame to transform: two fruits from ``sample_df``, two new."""
    data = [("apple",), ("banana",), ("date",), ("elderberry",)]
    return spark.createDataFrame(data, ["fruit"])


def _fruit_index(df, fruit):
    """Return the ``fruit_index`` value of the first row of *df* matching *fruit*."""
    return df.filter(col("fruit") == fruit).select("fruit_index").first()[0]


def _check_known_indexed_and_unknown_dropped(indexer, sample_df, test_df):
    """Fit *indexer* on *sample_df*, transform *test_df*, and check the result.

    Fruits seen during fitting must receive the expected indices; fruits that
    were not seen during fitting must be absent from the transformed DataFrame.
    """
    indexer.fit(sample_df, "fruit")
    transformed_df = indexer.transform(test_df, "fruit", "fruit_index")

    # Check if the known fruits get indexed
    assert _fruit_index(transformed_df, "apple") == 0
    assert _fruit_index(transformed_df, "banana") == 1

    # Check if unknown fruits are removed from the DataFrame
    assert transformed_df.filter(col("fruit") == "date").count() == 0
    assert transformed_df.filter(col("fruit") == "elderberry").count() == 0


def _check_refit_keeps_existing_indices(indexer, spark, sample_df, test_df):
    """Fit *indexer*, refit it on extra data, and check the resulting indices.

    Indices assigned by the first fit must be unchanged after the second fit,
    and strings first seen in the second fit must receive the next indices.
    """
    indexer.fit(sample_df, "fruit")
    # The original tests transformed once before refitting; the call is kept so
    # the refit still happens on an indexer that has already been used.
    indexer.transform(test_df, "fruit", "fruit_index")

    # Simulate refitting with another DataFrame
    refit_data = [("date",), ("elderberry",), ("apple",)]
    refit_df = spark.createDataFrame(refit_data, ["fruit"])
    indexer.fit(refit_df, "fruit")
    refitted_transformed_df = indexer.transform(
        test_df, "fruit", "fruit_index"
    )

    # Check if indices for known fruits from the first fitting remain the same
    assert _fruit_index(refitted_transformed_df, "apple") == 0
    assert _fruit_index(refitted_transformed_df, "banana") == 1

    # Check if new strings from the refitting get new indices
    assert _fruit_index(refitted_transformed_df, "date") == 3
    assert _fruit_index(refitted_transformed_df, "elderberry") == 4


# Testing DictStringIndexer
def test_distributed_string_indexer(spark, sample_df, test_df):
    """DictStringIndexer indexes known fruits and drops unknown ones."""
    _check_known_indexed_and_unknown_dropped(
        DictStringIndexer(), sample_df, test_df
    )


# Testing DataFrameStringIndexer
def test_dataframe_string_indexer(spark, sample_df, test_df):
    """DataFrameStringIndexer indexes known fruits and drops unknown ones."""
    # This test previously instantiated DictStringIndexer, which made it a
    # duplicate of the test above and left DataFrameStringIndexer untested here.
    _check_known_indexed_and_unknown_dropped(
        DataFrameStringIndexer(), sample_df, test_df
    )


def test_distributed_string_indexer_refit(spark, sample_df, test_df):
    """Refitting a DictStringIndexer keeps old indices and appends new ones."""
    _check_refit_keeps_existing_indices(
        DictStringIndexer(), spark, sample_df, test_df
    )


def test_dataframe_string_indexer_refit(spark, sample_df, test_df):
    """Refitting a DataFrameStringIndexer keeps old indices and appends new ones."""
    _check_refit_keeps_existing_indices(
        DataFrameStringIndexer(), spark, sample_df, test_df
    )


def test_distributed_string_indexer_instance_variables(spark, sample_df):
    """After fitting, DictStringIndexer holds two mutually inverse mappings."""
    indexer = DictStringIndexer()
    indexer.fit(sample_df, "fruit")

    # Check instance variables
    assert len(indexer.string_to_index) == 3
    assert len(indexer.index_to_string) == 3
    assert set(indexer.string_to_index.keys()) == {"apple", "banana", "cherry"}
    assert set(indexer.string_to_index.values()) == {0, 1, 2}
    assert set(indexer.index_to_string.keys()) == {0, 1, 2}
    assert set(indexer.index_to_string.values()) == {
        "apple",
        "banana",
        "cherry",
    }


def test_distributed_string_indexer_transform_without_fit(spark, sample_df):
    """DictStringIndexer.transform raises ValueError when called before fit."""
    indexer = DictStringIndexer()

    # Check that calling transform before fit raises an error
    with pytest.raises(ValueError, match=re.escape(NOT_FITTED_MESSAGE)):
        indexer.transform(sample_df, "fruit", "fruit_index")


def test_dataframe_string_indexer_instance_variables(spark, sample_df):
    """After fitting, DataFrameStringIndexer exposes a 3-row mapping DataFrame."""
    indexer = DataFrameStringIndexer()
    indexer.fit(sample_df, "fruit")

    # Check instance variables
    mapping_df = indexer.mapping_df
    assert mapping_df.count() == 3

    strings = {row[0] for row in mapping_df.select("string").collect()}
    indices = {row[0] for row in mapping_df.select("index").collect()}
    assert strings == {"apple", "banana", "cherry"}
    assert indices == {0, 1, 2}
    assert indexer._is_fitted is True


def test_dataframe_string_indexer_transform_without_fit(spark, sample_df):
    """DataFrameStringIndexer.transform raises ValueError when called before fit."""
    indexer = DataFrameStringIndexer()

    # Check that calling transform before fit raises an error
    with pytest.raises(ValueError, match=re.escape(NOT_FITTED_MESSAGE)):
        indexer.transform(sample_df, "fruit", "fruit_index")
Editor is loading...