Untitled
unknown
plain_text
2 years ago
1.7 kB
8
Indexable
import pytest
# NOTE(review): DictStringIndexer and DataFrameStringIndexer are not imported
# in the visible part of this file -- confirm the import exists upstream.
@pytest.fixture(params=[DictStringIndexer, DataFrameStringIndexer])
def indexer_class(request):
    """Provide each string-indexer implementation in turn.

    Any test requesting this fixture runs once per class listed in
    ``params``; the class itself (not an instance) is returned.
    """
    return request.param


def test_dict_string_indexer_order(spark, sample_df, indexer_class):
    """Check the index order produced by ``DictStringIndexer.fit``.

    After fitting on the ``fruit`` column, ``string_to_index`` must map
    ``apple`` -> 0, ``banana`` -> 1 and ``cherry`` -> 2.

    The test is skipped for every other class supplied by ``indexer_class``.
    ``spark`` is requested but not used directly here -- presumably it makes
    sure the Spark session behind ``sample_df`` exists (confirm in conftest).
    """
    if indexer_class is not DictStringIndexer:
        pytest.skip("Test only for DictStringIndexer")

    indexer_instance = indexer_class()
    indexer_instance.fit(sample_df, "fruit")

    for idx, fruit in enumerate(["apple", "banana", "cherry"]):
        assert indexer_instance.string_to_index[fruit] == idx, (
            f"{fruit!r} should be mapped to index {idx}"
        )


def test_dataframe_string_indexer_order(spark, sample_df, indexer_class):
    """Check the index order produced by ``DataFrameStringIndexer.fit``.

    After fitting on the ``fruit`` column, the rows of ``mapping_df`` (read
    through their ``string`` and ``index`` fields) must map ``apple`` -> 0,
    ``banana`` -> 1 and ``cherry`` -> 2.

    The test is skipped for every other class supplied by ``indexer_class``.
    ``spark`` is requested but not used directly here -- presumably it makes
    sure the Spark session behind ``sample_df`` exists (confirm in conftest).
    """
    if indexer_class is not DataFrameStringIndexer:
        pytest.skip("Test only for DataFrameStringIndexer")

    indexer_instance = indexer_class()
    indexer_instance.fit(sample_df, "fruit")

    # Fields are read with item access on purpose: a PySpark Row is a tuple
    # subclass, so ``row.index`` resolves to the built-in ``tuple.index``
    # method instead of the ``index`` column and the comparison below could
    # never succeed.
    fruit_to_index = {
        row["string"]: row["index"]
        for row in indexer_instance.mapping_df.collect()
    }

    for idx, fruit in enumerate(["apple", "banana", "cherry"]):
        assert fruit_to_index[fruit] == idx, (
            f"{fruit!r} should be mapped to index {idx}"
        )


# NOTE: while the case list below is empty, pytest treats it as an empty
# parameter set and reports this test as skipped rather than running it.
@pytest.mark.parametrize(
    "sample_data, test_data, expected_data",
    [
        # Test cases ...
    ],
)
def test_string_indexer_parameterized(spark, sample_data, test_data, expected_data, indexer_class):
    """Fit an indexer on ``sample_data`` and verify ``transform`` on ``test_data``.

    Runs for every class supplied by ``indexer_class``.

    Args:
        spark: Session object used to build the DataFrames via
            ``createDataFrame``.
        sample_data: Rows for the single-column (``fruit``) fitting frame.
        test_data: Rows for the single-column (``fruit``) frame to transform.
        expected_data: Rows of ``(fruit, fruit_index)`` the transform must
            produce.
        indexer_class: Indexer implementation under test.
    """
    sample_df = spark.createDataFrame(sample_data, ["fruit"])
    test_df = spark.createDataFrame(test_data, ["fruit"])

    indexer_instance = indexer_class()
    indexer_instance.fit(sample_df, "fruit")
    transformed_df = indexer_instance.transform(test_df, "fruit", "fruit_index")

    expected_df = spark.createDataFrame(expected_data, ["fruit", "fruit_index"])

    # exceptAll keeps duplicates (multiset difference); subtract() is
    # EXCEPT DISTINCT and would let a transform that duplicates or drops
    # repeated rows pass unnoticed. Checking both directions gives equality
    # regardless of row order.
    assert transformed_df.exceptAll(expected_df).count() == 0
    assert expected_df.exceptAll(transformed_df).count() == 0
Editor is loading...