Untitled
unknown
plain_text
2 years ago
6.1 kB
3
Indexable
# Tests for DistributedStringIndexer and DataFrameStringIndexer.
#
# Both indexers are exercised through the same fit/transform calls, so the
# shared assertions live in private ``_check_*`` helpers and each public test
# only chooses which indexer class to instantiate.

import re

import pytest
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Assuming the DistributedStringIndexer and DataFrameStringIndexer classes are
# in a file named 'indexers.py'
from indexers import DistributedStringIndexer, DataFrameStringIndexer

# Message both indexers are expected to raise when transform() precedes fit().
_NOT_FITTED_MESSAGE = "The indexer has not been fitted yet. Use fit first."


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

@pytest.fixture(scope="module")
def spark():
    """Yield a SparkSession for this module and stop it on teardown."""
    session = SparkSession.builder.appName("testing").getOrCreate()
    yield session
    # Stop the session once every test in the module has run so it is not
    # left open for the rest of the test process.
    session.stop()


@pytest.fixture(scope="module")
def sample_df(spark):
    """Single-column 'fruit' DataFrame used for fitting (three distinct values)."""
    data = [("apple",), ("banana",), ("cherry",), ("apple",), ("banana",)]
    return spark.createDataFrame(data, ["fruit"])


@pytest.fixture(scope="module")
def test_df(spark):
    """Single-column 'fruit' DataFrame mixing fitted and unseen values."""
    data = [("apple",), ("banana",), ("date",), ("elderberry",)]
    return spark.createDataFrame(data, ["fruit"])


# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------

def _index_of(df, fruit):
    """Return the 'fruit_index' of the first row of ``df`` whose fruit matches."""
    row = df.filter(col("fruit") == fruit).select("fruit_index").first()
    # first() returns None when nothing matches; fail with a readable message
    # instead of the TypeError that subscripting None would produce.
    assert row is not None, f"no transformed row found for fruit {fruit!r}"
    return row[0]


def _row_count(df, fruit):
    """Return how many rows of ``df`` have the given fruit."""
    return df.filter(col("fruit") == fruit).count()


def _column_values(df, column):
    """Return the set of values found in ``column`` of ``df``."""
    return {row[0] for row in df.select(column).collect()}


def _check_fit_transform(indexer, sample_df, test_df):
    """Fit on ``sample_df``, transform ``test_df`` and verify the outcome."""
    indexer.fit(sample_df, "fruit")
    transformed_df = indexer.transform(test_df, "fruit", "fruit_index")

    # Check if the known fruits get indexed
    assert _index_of(transformed_df, "apple") == 0
    assert _index_of(transformed_df, "banana") == 1

    # Check if unknown fruits are removed from the DataFrame
    assert _row_count(transformed_df, "date") == 0
    assert _row_count(transformed_df, "elderberry") == 0


def _check_refit(indexer, spark, sample_df, test_df):
    """Fit, transform, fit again on new data and verify the resulting indices."""
    indexer.fit(sample_df, "fruit")
    # The original scenario calls transform between the two fits; the result
    # itself is not inspected.
    indexer.transform(test_df, "fruit", "fruit_index")

    # Simulate refitting with another DataFrame
    refit_data = [("date",), ("elderberry",), ("apple",)]
    refit_df = spark.createDataFrame(refit_data, ["fruit"])
    indexer.fit(refit_df, "fruit")
    refitted_transformed_df = indexer.transform(test_df, "fruit", "fruit_index")

    # Check if indices for known fruits from the first fitting remain the same
    assert _index_of(refitted_transformed_df, "apple") == 0
    assert _index_of(refitted_transformed_df, "banana") == 1

    # Check if new strings from the refitting get new indices
    assert _index_of(refitted_transformed_df, "date") == 3
    assert _index_of(refitted_transformed_df, "elderberry") == 4


def _check_transform_without_fit(indexer, sample_df):
    """Verify that transform() on an unfitted indexer raises the expected error."""
    # ``match`` is applied as a regular expression, so escape the literal
    # message to keep its '.' characters from matching arbitrary text.
    with pytest.raises(ValueError, match=re.escape(_NOT_FITTED_MESSAGE)):
        indexer.transform(sample_df, "fruit", "fruit_index")


# ---------------------------------------------------------------------------
# Fit / transform
# ---------------------------------------------------------------------------

# Testing DistributedStringIndexer
def test_distributed_string_indexer(spark, sample_df, test_df):
    """Fitted fruits are indexed and unseen fruits are absent after transform."""
    _check_fit_transform(DistributedStringIndexer(), sample_df, test_df)


# Testing DataFrameStringIndexer
def test_dataframe_string_indexer(spark, sample_df, test_df):
    """Fitted fruits are indexed and unseen fruits are absent after transform."""
    _check_fit_transform(DataFrameStringIndexer(), sample_df, test_df)


# ---------------------------------------------------------------------------
# Refitting
# ---------------------------------------------------------------------------

def test_distributed_string_indexer_refit(spark, sample_df, test_df):
    """A second fit keeps the earlier indices and assigns new ones to new fruits."""
    _check_refit(DistributedStringIndexer(), spark, sample_df, test_df)


def test_dataframe_string_indexer_refit(spark, sample_df, test_df):
    """A second fit keeps the earlier indices and assigns new ones to new fruits."""
    _check_refit(DataFrameStringIndexer(), spark, sample_df, test_df)


# ---------------------------------------------------------------------------
# Instance state and error handling
# ---------------------------------------------------------------------------

def test_distributed_string_indexer_instance_variables(spark, sample_df):
    """After fit, both lookup dicts hold the three fruits and indices 0-2."""
    indexer = DistributedStringIndexer()
    indexer.fit(sample_df, "fruit")

    # Check instance variables
    assert len(indexer.string_to_index) == 3
    assert len(indexer.index_to_string) == 3
    assert set(indexer.string_to_index.keys()) == {"apple", "banana", "cherry"}
    assert set(indexer.string_to_index.values()) == {0, 1, 2}
    assert set(indexer.index_to_string.keys()) == {0, 1, 2}
    assert set(indexer.index_to_string.values()) == {"apple", "banana", "cherry"}


def test_distributed_string_indexer_transform_without_fit(spark, sample_df):
    """Calling transform before fit raises ValueError."""
    _check_transform_without_fit(DistributedStringIndexer(), sample_df)


def test_dataframe_string_indexer_instance_variables(spark, sample_df):
    """After fit, mapping_df holds the three fruits with indices 0-2."""
    indexer = DataFrameStringIndexer()
    indexer.fit(sample_df, "fruit")

    # Check instance variables
    mapping_count = indexer.mapping_df.count()
    assert mapping_count == 3
    assert _column_values(indexer.mapping_df, "string") == {"apple", "banana", "cherry"}
    assert _column_values(indexer.mapping_df, "index") == {0, 1, 2}
    assert indexer._is_fitted


def test_dataframe_string_indexer_transform_without_fit(spark, sample_df):
    """Calling transform before fit raises ValueError."""
    _check_transform_without_fit(DataFrameStringIndexer(), sample_df)
Editor is loading...