Untitled
unknown
plain_text
2 years ago
6.1 kB
7
Indexable
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
# Assuming the DistributedStringIndexer and DataFrameStringIndexer classes are in a file named 'indexers.py'
from indexers import DistributedStringIndexer, DataFrameStringIndexer
# Set up one SparkSession shared by every test in this module.
@pytest.fixture(scope="module")
def spark():
    """Provide a SparkSession for this module's tests and stop it afterwards.

    Yields:
        The session returned by ``SparkSession.builder...getOrCreate()``.

    The code after ``yield`` runs as fixture teardown once the module's tests
    have finished. It stops the session so the driver and its resources are
    released instead of lingering until the interpreter exits.
    """
    session = SparkSession.builder.appName("testing").getOrCreate()
    yield session
    session.stop()
@pytest.fixture(scope="module")
def sample_df(spark):
    """Training data: a single ``fruit`` column with three distinct values.

    "apple" and "banana" appear twice, so fitting on this frame also checks
    that duplicates do not produce extra index entries.
    """
    fruits = ("apple", "banana", "cherry", "apple", "banana")
    rows = [(name,) for name in fruits]
    return spark.createDataFrame(rows, ["fruit"])
@pytest.fixture(scope="module")
def test_df(spark):
    """Data to transform: two fruits seen in ``sample_df`` and two unseen ones.

    "date" and "elderberry" do not occur in ``sample_df``.
    """
    fruits = ("apple", "banana", "date", "elderberry")
    rows = [(name,) for name in fruits]
    return spark.createDataFrame(rows, ["fruit"])
# Testing DistributedStringIndexer
def test_distributed_string_indexer(spark, sample_df, test_df):
    """Fitted strings get their indices; strings unseen during fit are dropped.

    The transformed frame is collected once into a ``{fruit: fruit_index}``
    dict. This needs a single Spark action, and a missing or extra row shows up
    as a readable assertion diff. The earlier form failed with
    ``TypeError: 'NoneType' object is not subscriptable`` when ``.first()``
    found no row.
    """
    indexer = DistributedStringIndexer()
    indexer.fit(sample_df, "fruit")
    transformed_df = indexer.transform(test_df, "fruit", "fruit_index")
    indices = {
        row["fruit"]: row["fruit_index"]
        for row in transformed_df.select("fruit", "fruit_index").collect()
    }
    # Known fruits get the indices from fitting on sample_df (apple=0, banana=1).
    # Unknown fruits ("date", "elderberry") must be removed, so they are absent
    # from the dict.
    assert indices == {"apple": 0, "banana": 1}
# Testing DataFrameStringIndexer
def test_dataframe_string_indexer(spark, sample_df, test_df):
    """Fitted strings get their indices; strings unseen during fit are dropped.

    The transformed frame is collected once into a ``{fruit: fruit_index}``
    dict. This needs a single Spark action, and a missing or extra row shows up
    as a readable assertion diff. The earlier form failed with
    ``TypeError: 'NoneType' object is not subscriptable`` when ``.first()``
    found no row.
    """
    indexer = DataFrameStringIndexer()
    indexer.fit(sample_df, "fruit")
    transformed_df = indexer.transform(test_df, "fruit", "fruit_index")
    indices = {
        row["fruit"]: row["fruit_index"]
        for row in transformed_df.select("fruit", "fruit_index").collect()
    }
    # Known fruits get the indices from fitting on sample_df (apple=0, banana=1).
    # Unknown fruits ("date", "elderberry") must be removed, so they are absent
    # from the dict.
    assert indices == {"apple": 0, "banana": 1}
# Refit tests: a second fit must keep existing indices and append new strings after them.
def test_distributed_string_indexer_refit(spark, sample_df, test_df):
    """Refitting keeps the indices of known strings and appends new ones after them."""

    def index_map(df):
        # Collect ``{fruit: fruit_index}`` in one Spark action.
        return {
            row["fruit"]: row["fruit_index"]
            for row in df.select("fruit", "fruit_index").collect()
        }

    indexer = DistributedStringIndexer()
    indexer.fit(sample_df, "fruit")
    # Materialise the pre-refit result now. DataFrames are lazy, so evaluating
    # it after the second fit may observe the refitted state instead.
    before = index_map(indexer.transform(test_df, "fruit", "fruit_index"))
    assert before == {"apple": 0, "banana": 1}

    # Simulate refitting with another DataFrame: "apple" is already known,
    # "date" and "elderberry" are new.
    refit_data = [("date",), ("elderberry",), ("apple",)]
    refit_df = spark.createDataFrame(refit_data, ["fruit"])
    indexer.fit(refit_df, "fruit")
    after = index_map(indexer.transform(test_df, "fruit", "fruit_index"))

    # Known fruits keep exactly the indices they had before the refit.
    assert {fruit: after.get(fruit) for fruit in before} == before
    # New strings get indices after the three (0-2) assigned by the first fit.
    assert after.get("date") == 3
    assert after.get("elderberry") == 4
def test_dataframe_string_indexer_refit(spark, sample_df, test_df):
    """Refitting keeps the indices of known strings and appends new ones after them."""

    def index_map(df):
        # Collect ``{fruit: fruit_index}`` in one Spark action.
        return {
            row["fruit"]: row["fruit_index"]
            for row in df.select("fruit", "fruit_index").collect()
        }

    indexer = DataFrameStringIndexer()
    indexer.fit(sample_df, "fruit")
    # Materialise the pre-refit result now. DataFrames are lazy, so evaluating
    # it after the second fit may observe the refitted state instead.
    before = index_map(indexer.transform(test_df, "fruit", "fruit_index"))
    assert before == {"apple": 0, "banana": 1}

    # Simulate refitting with another DataFrame: "apple" is already known,
    # "date" and "elderberry" are new.
    refit_data = [("date",), ("elderberry",), ("apple",)]
    refit_df = spark.createDataFrame(refit_data, ["fruit"])
    indexer.fit(refit_df, "fruit")
    after = index_map(indexer.transform(test_df, "fruit", "fruit_index"))

    # Known fruits keep exactly the indices they had before the refit.
    assert {fruit: after.get(fruit) for fruit in before} == before
    # New strings get indices after the three (0-2) assigned by the first fit.
    assert after.get("date") == 3
    assert after.get("elderberry") == 4
def test_distributed_string_indexer_instance_variables(spark, sample_df):
    """After fit, the forward and reverse lookup dicts form one consistent mapping."""
    indexer = DistributedStringIndexer()
    indexer.fit(sample_df, "fruit")
    string_to_index = indexer.string_to_index
    index_to_string = indexer.index_to_string

    # sample_df repeats "apple" and "banana"; duplicates must not add entries.
    assert len(string_to_index) == 3
    assert len(index_to_string) == 3
    assert set(string_to_index) == {"apple", "banana", "cherry"}
    assert set(string_to_index.values()) == {0, 1, 2}
    # Checking key/value sets separately would accept mismatched maps
    # (e.g. apple->0 but 0->banana), so require an exact inverse.
    assert index_to_string == {index: s for s, index in string_to_index.items()}
def test_distributed_string_indexer_transform_without_fit(spark, sample_df):
    """Calling transform() before fit() raises ValueError with the expected message."""
    indexer = DistributedStringIndexer()
    # ``match`` is a regular expression searched in str(exception). The periods
    # are escaped so they match literal dots rather than any character.
    expected = r"The indexer has not been fitted yet\. Use fit first\."
    with pytest.raises(ValueError, match=expected):
        indexer.transform(sample_df, "fruit", "fruit_index")
def test_dataframe_string_indexer_instance_variables(spark, sample_df):
    """After fit, ``mapping_df`` holds one row per distinct string with indices 0-2."""
    indexer = DataFrameStringIndexer()
    indexer.fit(sample_df, "fruit")
    # Collect once and read Row fields directly. This replaces three Spark jobs
    # and the ``.rdd.flatMap`` round-trip; the RDD API is not available under
    # Spark Connect.
    rows = indexer.mapping_df.select("string", "index").collect()
    strings = [row["string"] for row in rows]
    indices = [row["index"] for row in rows]

    # sample_df repeats "apple" and "banana"; duplicates must not add rows.
    assert len(rows) == 3
    assert set(strings) == {"apple", "banana", "cherry"}
    assert set(indices) == {0, 1, 2}
    assert indexer._is_fitted is True
def test_dataframe_string_indexer_transform_without_fit(spark, sample_df):
    """Calling transform() before fit() raises ValueError with the expected message."""
    indexer = DataFrameStringIndexer()
    # ``match`` is a regular expression searched in str(exception). The periods
    # are escaped so they match literal dots rather than any character.
    expected = r"The indexer has not been fitted yet\. Use fit first\."
    with pytest.raises(ValueError, match=expected):
        indexer.transform(sample_df, "fruit", "fruit_index")
Editor is loading...