Untitled

mail@pastecode.io avatar
unknown
plain_text
a year ago
1.7 kB
0
Indexable
Never
import pytest

@pytest.fixture(params=[DictStringIndexer, DataFrameStringIndexer])
def indexer_class(request):
    """Provide one string-indexer class per parametrization.

    Any test requesting ``indexer_class`` is collected once for each entry in
    ``params``; the class itself (not an instance) is handed to the test.
    """
    chosen_class = request.param
    return chosen_class

def test_dict_string_indexer_order(spark, sample_df, indexer_class):
    """Check that DictStringIndexer maps apple, banana, cherry to 0, 1, 2.

    ``indexer_class`` is parametrized over the indexer implementations; every
    parametrization other than ``DictStringIndexer`` is skipped here.
    ``spark`` is requested but not referenced directly in this body.
    """
    if indexer_class is not DictStringIndexer:
        pytest.skip("Test only for DictStringIndexer")

    indexer = indexer_class()
    indexer.fit(sample_df, "fruit")

    # Insertion order of this dict matches the original apple/banana/cherry order.
    expected_positions = {"apple": 0, "banana": 1, "cherry": 2}
    for name, position in expected_positions.items():
        assert indexer.string_to_index[name] == position

def test_dataframe_string_indexer_order(spark, sample_df, indexer_class):
    """Check that DataFrameStringIndexer maps apple, banana, cherry to 0, 1, 2.

    ``indexer_class`` is parametrized over the indexer implementations; every
    parametrization other than ``DataFrameStringIndexer`` is skipped here.
    The fitted indexer's ``mapping_df`` is read as ``string`` -> ``index``
    pairs and checked against the expected positions.
    """
    if indexer_class is not DataFrameStringIndexer:
        pytest.skip("Test only for DataFrameStringIndexer")

    indexer_instance = indexer_class()
    indexer_instance.fit(sample_df, "fruit")

    # Read Row fields by key, not by attribute: Row subclasses tuple, so
    # ``row.index`` resolves to the built-in ``tuple.index`` method rather
    # than the value of the "index" column.
    fruit_to_index = {
        row["string"]: row["index"]
        for row in indexer_instance.mapping_df.collect()
    }
    for idx, fruit in enumerate(["apple", "banana", "cherry"]):
        assert fruit_to_index[fruit] == idx

@pytest.mark.parametrize(
    "sample_data, test_data, expected_data",
    [
        # Test cases ...
        # TODO(review): the list currently holds no cases, so pytest collects
        # this test as skipped (empty parameter set); add cases to run it.
    ]
)
def test_string_indexer_parameterized(spark, sample_data, test_data, expected_data, indexer_class):
    """Fit on ``sample_data``, transform ``test_data``, compare to ``expected_data``.

    ``sample_data`` and ``test_data`` become single-column ("fruit") DataFrames;
    ``expected_data`` becomes a ("fruit", "fruit_index") DataFrame. The check
    passes only when the transformed and expected DataFrames contain the same
    rows with the same multiplicities (columns compared positionally).
    Runs once per ``indexer_class`` parametrization.
    """
    sample_df = spark.createDataFrame(sample_data, ["fruit"])
    test_df = spark.createDataFrame(test_data, ["fruit"])

    indexer_instance = indexer_class()
    indexer_instance.fit(sample_df, "fruit")
    transformed_df = indexer_instance.transform(test_df, "fruit", "fruit_index")

    expected_df = spark.createDataFrame(expected_data, ["fruit", "fruit_index"])
    # exceptAll is EXCEPT ALL and keeps duplicates; subtract is EXCEPT DISTINCT
    # and would accept results whose duplicate-row counts differ from expected.
    # (exceptAll requires Spark >= 2.4.)
    assert transformed_df.exceptAll(expected_df).count() == 0
    assert expected_df.exceptAll(transformed_df).count() == 0