Untitled

 avatar
unknown
plain_text
a month ago
3.7 kB
2
Indexable
import unittest
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

class SparkTransformationTests(unittest.TestCase):
    """Unit tests for basic PySpark DataFrame transformations.

    A single local SparkSession is created once for the whole class (in
    ``setUpClass``) and stopped once at the end (in ``tearDownClass``), rather
    than being rebuilt for every test.
    """

    @classmethod
    def setUpClass(cls):
        """Start the local SparkSession shared by every test in this class."""
        cls.spark = (
            SparkSession.builder
            .appName("PySpark Unit Test Example")
            .master("local[*]")
            .getOrCreate()
        )

    @classmethod
    def tearDownClass(cls):
        """Stop the shared SparkSession after all tests have run."""
        cls.spark.stop()

    def test_basic_transformation(self):
        """withColumn adds a derived column computed from an existing one."""
        data = [("Alice", 34), ("Bob", 45), ("Charlie", 23)]
        df = self.spark.createDataFrame(data, ["name", "age"])

        # Transformation: add 'age_plus_5' = age + 5.
        transformed_df = df.withColumn("age_plus_5", col("age") + 5)

        expected = [
            {"name": "Alice", "age": 34, "age_plus_5": 39},
            {"name": "Bob", "age": 45, "age_plus_5": 50},
            {"name": "Charlie", "age": 23, "age_plus_5": 28},
        ]
        result = [row.asDict() for row in transformed_df.collect()]
        self.assertEqual(result, expected)

    def test_filter(self):
        """filter keeps only the rows that satisfy the predicate (age > 30)."""
        data = [("Alice", 34), ("Bob", 45), ("Charlie", 23)]
        df = self.spark.createDataFrame(data, ["name", "age"])

        filtered_df = df.filter(col("age") > 30)

        expected = [{"name": "Alice", "age": 34}, {"name": "Bob", "age": 45}]
        result = [row.asDict() for row in filtered_df.collect()]
        self.assertEqual(result, expected)

    def test_group_by_aggregation(self):
        """groupBy + avg computes the mean salary per department."""
        data = [
            ("Alice", "Sales", 3000),
            ("Bob", "HR", 4000),
            ("Charlie", "Sales", 5000),
            ("David", "HR", 3500),
        ]
        df = self.spark.createDataFrame(data, ["name", "department", "salary"])

        # groupBy shuffles the data, so the output row order is not guaranteed;
        # sort explicitly so the ordered list comparison below is deterministic.
        avg_salary_df = (
            df.groupBy("department").avg("salary").orderBy("department")
        )

        # avg() produces floating-point values, so state them as floats.
        expected = [
            {"department": "HR", "avg(salary)": 3750.0},
            {"department": "Sales", "avg(salary)": 4000.0},
        ]
        result = [row.asDict() for row in avg_salary_df.collect()]
        self.assertEqual(result, expected)

    def test_join(self):
        """An inner join on 'id' pairs each name with its department."""
        data1 = [(1, "Alice"), (2, "Bob"), (3, "Charlie")]
        data2 = [(1, "Sales"), (2, "HR"), (3, "Marketing")]

        df1 = self.spark.createDataFrame(data1, ["id", "name"])
        df2 = self.spark.createDataFrame(data2, ["id", "department"])

        # Join output order depends on the physical join strategy; sort by the
        # join key so the ordered list comparison below is deterministic.
        joined_df = df1.join(df2, "id").orderBy("id")

        expected = [
            {"id": 1, "name": "Alice", "department": "Sales"},
            {"id": 2, "name": "Bob", "department": "HR"},
            {"id": 3, "name": "Charlie", "department": "Marketing"},
        ]
        result = [row.asDict() for row in joined_df.collect()]
        self.assertEqual(result, expected)

if __name__ == "__main__":
    # Hand control to the unittest command-line runner: it collects the
    # TestCase classes defined in this module, runs them, and exits with
    # the resulting status code.
    unittest.main(module=__name__)
Leave a Comment