Unit Testing PySpark DataFrame Transformations with unittest
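The example below uses Python's built-in unittest framework to test four common PySpark DataFrame operations: adding a derived column with withColumn, filtering rows with filter, aggregating with groupBy/avg, and joining two DataFrames. A single local SparkSession is created once in setUpClass and stopped in tearDownClass, so all tests share one session instead of paying the startup cost per test.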
import unittest

from pyspark.sql import SparkSession
from pyspark.sql.functions import col


class SparkTransformationTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        """Set up the SparkSession before any tests run."""
        cls.spark = (
            SparkSession.builder
            .appName("PySpark Unit Test Example")
            .master("local[*]")
            .getOrCreate()
        )

    @classmethod
    def tearDownClass(cls):
        """Stop the SparkSession after all tests have run."""
        cls.spark.stop()

    def test_basic_transformation(self):
        """Test a basic column transformation."""
        # Sample data
        data = [("Alice", 34), ("Bob", 45), ("Charlie", 23)]
        df = self.spark.createDataFrame(data, ["name", "age"])

        # Transformation: add a new column 'age_plus_5'
        transformed_df = df.withColumn("age_plus_5", col("age") + 5)

        # Collect results and assert
        result = [row.asDict() for row in transformed_df.collect()]
        expected = [
            {"name": "Alice", "age": 34, "age_plus_5": 39},
            {"name": "Bob", "age": 45, "age_plus_5": 50},
            {"name": "Charlie", "age": 23, "age_plus_5": 28},
        ]
        self.assertEqual(result, expected)

    def test_filter(self):
        """Test filtering rows based on a condition."""
        # Sample data
        data = [("Alice", 34), ("Bob", 45), ("Charlie", 23)]
        df = self.spark.createDataFrame(data, ["name", "age"])

        # Filter: keep only rows where age > 30
        filtered_df = df.filter(col("age") > 30)

        # Collect results and assert
        result = [row.asDict() for row in filtered_df.collect()]
        expected = [
            {"name": "Alice", "age": 34},
            {"name": "Bob", "age": 45},
        ]
        self.assertEqual(result, expected)

    def test_group_by_aggregation(self):
        """Test groupBy and aggregation functions."""
        # Sample data
        data = [
            ("Alice", "Sales", 3000),
            ("Bob", "HR", 4000),
            ("Charlie", "Sales", 5000),
            ("David", "HR", 3500),
        ]
        df = self.spark.createDataFrame(data, ["name", "department", "salary"])

        # Group by 'department' and calculate the average salary.
        # Row order after a groupBy is not guaranteed, so sort before asserting.
        avg_salary_df = df.groupBy("department").avg("salary")
        result = [row.asDict() for row in avg_salary_df.orderBy("department").collect()]
        expected = [
            {"department": "HR", "avg(salary)": 3750.0},
            {"department": "Sales", "avg(salary)": 4000.0},
        ]
        self.assertEqual(result, expected)

    def test_join(self):
        """Test joining two DataFrames."""
        # Sample data
        data1 = [(1, "Alice"), (2, "Bob"), (3, "Charlie")]
        data2 = [(1, "Sales"), (2, "HR"), (3, "Marketing")]
        df1 = self.spark.createDataFrame(data1, ["id", "name"])
        df2 = self.spark.createDataFrame(data2, ["id", "department"])

        # Perform an inner join on 'id'. Join output order is not
        # guaranteed either, so sort by the key before asserting.
        joined_df = df1.join(df2, "id")
        result = [row.asDict() for row in joined_df.orderBy("id").collect()]
        expected = [
            {"id": 1, "name": "Alice", "department": "Sales"},
            {"id": 2, "name": "Bob", "department": "HR"},
            {"id": 3, "name": "Charlie", "department": "Marketing"},
        ]
        self.assertEqual(result, expected)


if __name__ == "__main__":
    unittest.main()
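To try the suite, save it to a file (test_spark_transformations.py is a hypothetical name) and run it directly with python test_spark_transformations.py, since the module calls unittest.main(); this assumes pyspark is installed (pip install pyspark) and a compatible JVM is on the machine. One design note: the aggregation and join tests sort their output with orderBy before asserting, because Spark does not guarantee row order after a shuffle, and comparing unsorted collect() results makes such tests flaky.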