Untitled
unknown
plain_text
2 years ago
1.7 kB
5
Indexable
"""Tests for the DictStringIndexer and DataFrameStringIndexer string indexers.

NOTE(review): DictStringIndexer and DataFrameStringIndexer are used below but
are not imported in this file as shown; import them from the module that
defines them. The ``spark`` and ``sample_df`` fixtures are presumably supplied
by a conftest.py -- confirm.
"""
import pytest

# Index order both order tests expect after fitting on ``sample_df``'s
# "fruit" column: apple -> 0, banana -> 1, cherry -> 2.
_EXPECTED_FRUIT_ORDER = ["apple", "banana", "cherry"]


@pytest.fixture(params=[DictStringIndexer, DataFrameStringIndexer])
def indexer_class(request):
    """Provide each indexer implementation so shared tests run against both."""
    return request.param


def test_dict_string_indexer_order(spark, sample_df):
    """Check DictStringIndexer's ``string_to_index`` mapping after fitting.

    The class is instantiated directly rather than through the parametrized
    ``indexer_class`` fixture, which would otherwise generate a second,
    always-skipped case for DataFrameStringIndexer.
    """
    indexer_instance = DictStringIndexer()
    indexer_instance.fit(sample_df, "fruit")
    for idx, fruit in enumerate(_EXPECTED_FRUIT_ORDER):
        assert indexer_instance.string_to_index[fruit] == idx


def test_dataframe_string_indexer_order(spark, sample_df):
    """Check DataFrameStringIndexer's ``mapping_df`` after fitting.

    ``mapping_df`` is read as rows carrying "string" and "index" columns.
    The class is instantiated directly for the same reason as in the
    DictStringIndexer order test.
    """
    indexer_instance = DataFrameStringIndexer()
    indexer_instance.fit(sample_df, "fruit")
    # Use item access, not attribute access: pyspark Row subclasses tuple, so
    # ``row.index`` resolves to the bound ``tuple.index`` method instead of the
    # value of the "index" column.
    fruit_to_index = {
        row["string"]: row["index"]
        for row in indexer_instance.mapping_df.collect()
    }
    for idx, fruit in enumerate(_EXPECTED_FRUIT_ORDER):
        assert fruit_to_index[fruit] == idx


@pytest.mark.parametrize(
    "sample_data, test_data, expected_data",
    [
        # Test cases ...
        # TODO: fill in. While this list is empty, pytest reports the test as
        # skipped ("got empty parameter set"), so it verifies nothing.
    ],
)
def test_string_indexer_parameterized(
    spark, sample_data, test_data, expected_data, indexer_class
):
    """Fit on ``sample_data``, transform ``test_data``, compare to expected.

    Runs once per indexer implementation via the ``indexer_class`` fixture.
    ``sample_data`` and ``test_data`` are rows for a single "fruit" column;
    ``expected_data`` rows are (fruit, fruit_index).
    """
    sample_df = spark.createDataFrame(sample_data, ["fruit"])
    test_df = spark.createDataFrame(test_data, ["fruit"])
    indexer_instance = indexer_class()
    indexer_instance.fit(sample_df, "fruit")
    transformed_df = indexer_instance.transform(test_df, "fruit", "fruit_index")
    expected_df = spark.createDataFrame(expected_data, ["fruit", "fruit_index"])
    # exceptAll keeps duplicates (subtract is EXCEPT DISTINCT), so checking
    # both directions asserts equal row multisets rather than merely equal
    # sets of distinct rows.
    assert transformed_df.exceptAll(expected_df).count() == 0
    assert expected_df.exceptAll(transformed_df).count() == 0
Editor is loading...