# Untitled

# mail@pastecode.io avatar
# unknown
# plain_text
# a year ago
# 6.1 kB
# 1
# Indexable
# Never
import re

import pytest
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Assuming the DistributedStringIndexer and DataFrameStringIndexer classes are in a file named 'indexers.py'
from indexers import DistributedStringIndexer, DataFrameStringIndexer

# Set up a SparkSession fixture for the tests
@pytest.fixture(scope="module")
def spark():
    """Provide one SparkSession shared by all tests in this module.

    The session is obtained with ``getOrCreate`` under the application name
    "testing". It is handed to the tests via ``yield`` so that it can be
    stopped once the module's tests are done instead of being left running.

    Yields:
        The SparkSession used by the other fixtures and tests.
    """
    session = SparkSession.builder.appName("testing").getOrCreate()
    yield session
    # Teardown: executed after the last test of the module that requested
    # this fixture has finished.
    session.stop()

@pytest.fixture(scope="module")
def sample_df(spark):
    """Build the single-column ("fruit") DataFrame the tests fit indexers on.

    It holds five rows with three distinct values: "apple" and "banana"
    appear twice each, "cherry" once.
    """
    fruits = ("apple", "banana", "cherry", "apple", "banana")
    rows = [(fruit,) for fruit in fruits]
    schema = ["fruit"]
    return spark.createDataFrame(rows, schema)

@pytest.fixture(scope="module")
def test_df(spark):
    """Build the single-column ("fruit") DataFrame the tests pass to ``transform``.

    It holds four rows: "apple", "banana", "date" and "elderberry".
    """
    column_names = ["fruit"]
    records = [
        ("apple",),
        ("banana",),
        ("date",),
        ("elderberry",),
    ]
    return spark.createDataFrame(records, column_names)

# Testing DistributedStringIndexer
def test_distributed_string_indexer(spark, sample_df, test_df):
    """Fit DistributedStringIndexer on ``sample_df`` and transform ``test_df``.

    Asserts that "apple" and "banana" receive indices 0 and 1, and that no
    rows for "date" or "elderberry" are present in the transformed output.
    """
    indexer = DistributedStringIndexer()
    indexer.fit(sample_df, "fruit")
    result = indexer.transform(test_df, "fruit", "fruit_index")

    def rows_for(fruit):
        # Rows of the transformed DataFrame whose "fruit" equals the given value.
        return result.filter(col("fruit") == fruit)

    # Fruits seen during fit must carry the expected index.
    assert rows_for("apple").select("fruit_index").first()[0] == 0
    assert rows_for("banana").select("fruit_index").first()[0] == 1

    # Fruits not seen during fit must not appear in the output at all.
    assert rows_for("date").count() == 0
    assert rows_for("elderberry").count() == 0

# Testing DataFrameStringIndexer
def test_dataframe_string_indexer(spark, sample_df, test_df):
    """Fit DataFrameStringIndexer on ``sample_df`` and transform ``test_df``.

    Asserts that "apple" and "banana" receive indices 0 and 1, and that no
    rows for "date" or "elderberry" are present in the transformed output.
    """
    indexer = DataFrameStringIndexer()
    indexer.fit(sample_df, "fruit")
    output = indexer.transform(test_df, "fruit", "fruit_index")

    # Known fruits: the first matching row must hold the expected index.
    for fruit, expected_index in (("apple", 0), ("banana", 1)):
        matching = output.filter(col("fruit") == fruit)
        index_row = matching.select("fruit_index").first()
        assert index_row[0] == expected_index

    # Unknown fruits: no row may survive the transform.
    for fruit in ("date", "elderberry"):
        matching = output.filter(col("fruit") == fruit)
        assert matching.count() == 0

# Refit, instance-variable and not-fitted tests; these reuse the imports and fixtures defined above.

def test_distributed_string_indexer_refit(spark, sample_df, test_df):
    """Fit DistributedStringIndexer twice and check the indices afterwards.

    The indexer is fitted on ``sample_df``, then fitted again on a DataFrame
    holding "date", "elderberry" and "apple". After the second fit, the test
    asserts that "apple" and "banana" still map to 0 and 1 and that "date"
    and "elderberry" map to 3 and 4.
    """
    indexer = DistributedStringIndexer()
    indexer.fit(sample_df, "fruit")
    # Transform once before the second fit; the result is not inspected.
    indexer.transform(test_df, "fruit", "fruit_index")

    # Second fit on data that introduces "date" and "elderberry".
    second_fit_df = spark.createDataFrame(
        [("date",), ("elderberry",), ("apple",)],
        ["fruit"],
    )
    indexer.fit(second_fit_df, "fruit")

    after_refit = indexer.transform(test_df, "fruit", "fruit_index")

    def index_of(fruit):
        # Index stored in the first transformed row for the given fruit.
        matching = after_refit.filter(col("fruit") == fruit)
        return matching.select("fruit_index").first()[0]

    # Strings from the first fit keep their indices.
    assert index_of("apple") == 0
    assert index_of("banana") == 1

    # Strings first seen in the second fit receive the next indices.
    assert index_of("date") == 3
    assert index_of("elderberry") == 4

def test_dataframe_string_indexer_refit(spark, sample_df, test_df):
    """Fit DataFrameStringIndexer twice and check the indices afterwards.

    The indexer is fitted on ``sample_df``, then fitted again on a DataFrame
    holding "date", "elderberry" and "apple". After the second fit, the test
    asserts that "apple" and "banana" still map to 0 and 1 and that "date"
    and "elderberry" map to 3 and 4.
    """
    indexer = DataFrameStringIndexer()
    indexer.fit(sample_df, "fruit")
    # Transform once before the second fit; the result is not inspected.
    indexer.transform(test_df, "fruit", "fruit_index")

    # Second fit on data that introduces "date" and "elderberry".
    extra_rows = [("date",), ("elderberry",), ("apple",)]
    extra_df = spark.createDataFrame(extra_rows, ["fruit"])
    indexer.fit(extra_df, "fruit")

    refitted = indexer.transform(test_df, "fruit", "fruit_index")

    # First two entries: strings from the first fit keep their indices.
    # Last two entries: strings first seen in the second fit get new indices.
    expected_indices = (
        ("apple", 0),
        ("banana", 1),
        ("date", 3),
        ("elderberry", 4),
    )
    for fruit, expected in expected_indices:
        matching = refitted.filter(col("fruit") == fruit)
        assert matching.select("fruit_index").first()[0] == expected

def test_distributed_string_indexer_instance_variables(spark, sample_df):
    """Check the lookup tables DistributedStringIndexer exposes after ``fit``.

    After fitting on ``sample_df``, ``string_to_index`` and
    ``index_to_string`` must each hold three entries relating
    {"apple", "banana", "cherry"} to {0, 1, 2}.
    """
    indexer = DistributedStringIndexer()
    indexer.fit(sample_df, "fruit")

    expected_strings = {"apple", "banana", "cherry"}
    expected_indices = {0, 1, 2}

    forward = indexer.string_to_index
    backward = indexer.index_to_string

    # Both tables have one entry per distinct fruit.
    assert len(forward) == 3
    assert len(backward) == 3

    # string -> index table
    assert set(forward.keys()) == expected_strings
    assert set(forward.values()) == expected_indices

    # index -> string table
    assert set(backward.keys()) == expected_indices
    assert set(backward.values()) == expected_strings

def test_distributed_string_indexer_transform_without_fit(spark, sample_df):
    """Check that DistributedStringIndexer.transform raises before any fit.

    A ValueError carrying the exact "not fitted" message is expected.
    """
    indexer = DistributedStringIndexer()

    # ``match`` is applied as a regular expression search, so the message is
    # escaped to make its "." characters match literal periods only.
    expected_message = re.escape("The indexer has not been fitted yet. Use fit first.")
    with pytest.raises(ValueError, match=expected_message):
        indexer.transform(sample_df, "fruit", "fruit_index")

def test_dataframe_string_indexer_instance_variables(spark, sample_df):
    """Check the state DataFrameStringIndexer exposes after ``fit``.

    After fitting on ``sample_df``, ``mapping_df`` must hold three rows whose
    "string" values are {"apple", "banana", "cherry"} and whose "index"
    values are {0, 1, 2}, and ``_is_fitted`` must be truthy.
    """
    indexer = DataFrameStringIndexer()
    indexer.fit(sample_df, "fruit")

    # Collect the mapping once and derive every check from the same rows,
    # instead of running a separate Spark action per assertion.
    mapping_rows = indexer.mapping_df.select("string", "index").collect()

    assert len(mapping_rows) == 3
    # Columns are read with item access: ``row.index`` would resolve to the
    # tuple method of the same name rather than the column.
    assert {row["string"] for row in mapping_rows} == {"apple", "banana", "cherry"}
    assert {row["index"] for row in mapping_rows} == {0, 1, 2}
    assert indexer._is_fitted

def test_dataframe_string_indexer_transform_without_fit(spark, sample_df):
    """Check that DataFrameStringIndexer.transform raises before any fit.

    A ValueError carrying the exact "not fitted" message is expected.
    """
    indexer = DataFrameStringIndexer()

    # ``match`` is applied as a regular expression search, so the message is
    # escaped to make its "." characters match literal periods only.
    expected_message = re.escape("The indexer has not been fitted yet. Use fit first.")
    with pytest.raises(ValueError, match=expected_message):
        indexer.transform(sample_df, "fruit", "fruit_index")