import unittest

from pyspark.sql import SparkSession
from pyspark.sql.functions import col
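
# Prerequisite note (assumption: purely local use): these tests only need a
# local PySpark install (e.g. `pip install pyspark`); the session below uses
# master("local[*]"), so no external cluster is required.
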
class SparkTransformationTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        """Set up the SparkSession before any tests run."""
        cls.spark = SparkSession.builder \
            .appName("PySpark Unit Test Example") \
            .master("local[*]") \
            .getOrCreate()

    @classmethod
    def tearDownClass(cls):
        """Stop the SparkSession after all tests."""
        cls.spark.stop()
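
    # Design note: sharing a single SparkSession via setUpClass/tearDownClass
    # pays the JVM and session startup cost once for the whole class rather
    # than once per test.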

    def test_basic_transformation(self):
        """Test a basic column transformation."""
        # Sample data
        data = [("Alice", 34), ("Bob", 45), ("Charlie", 23)]
        df = self.spark.createDataFrame(data, ["name", "age"])

        # Transformation: add a new column 'age_plus_5'
        transformed_df = df.withColumn("age_plus_5", col("age") + 5)

        # Collect results and compare against the expected rows
        result = [row.asDict() for row in transformed_df.collect()]
        expected = [
            {"name": "Alice", "age": 34, "age_plus_5": 39},
            {"name": "Bob", "age": 45, "age_plus_5": 50},
            {"name": "Charlie", "age": 23, "age_plus_5": 28},
        ]
        self.assertEqual(result, expected)

    def test_filter(self):
        """Test filtering rows based on a condition."""
        # Sample data
        data = [("Alice", 34), ("Bob", 45), ("Charlie", 23)]
        df = self.spark.createDataFrame(data, ["name", "age"])

        # Filter: keep rows where age > 30
        filtered_df = df.filter(col("age") > 30)

        # Collect results and compare against the expected rows
        result = [row.asDict() for row in filtered_df.collect()]
        expected = [{"name": "Alice", "age": 34}, {"name": "Bob", "age": 45}]
        self.assertEqual(result, expected)

    def test_group_by_aggregation(self):
        """Test groupBy and aggregation functions."""
        # Sample data
        data = [("Alice", "Sales", 3000), ("Bob", "HR", 4000),
                ("Charlie", "Sales", 5000), ("David", "HR", 3500)]
        df = self.spark.createDataFrame(data, ["name", "department", "salary"])

        # Group by 'department' and calculate the average salary
        avg_salary_df = df.groupBy("department").avg("salary")

        # groupBy output order is not deterministic, so sort before comparing.
        # Note that avg() produces a double, so the expected values are floats.
        result = sorted((row.asDict() for row in avg_salary_df.collect()),
                        key=lambda r: r["department"])
        expected = [{"department": "HR", "avg(salary)": 3750.0},
                    {"department": "Sales", "avg(salary)": 4000.0}]
        self.assertEqual(result, expected)

    def test_join(self):
        """Test joining two DataFrames."""
        # Sample data
        data1 = [(1, "Alice"), (2, "Bob"), (3, "Charlie")]
        data2 = [(1, "Sales"), (2, "HR"), (3, "Marketing")]
        df1 = self.spark.createDataFrame(data1, ["id", "name"])
        df2 = self.spark.createDataFrame(data2, ["id", "department"])

        # Perform an inner join on the shared 'id' column
        joined_df = df1.join(df2, "id")

        # Join output order is not deterministic, so sort before comparing
        result = sorted((row.asDict() for row in joined_df.collect()),
                        key=lambda r: r["id"])
        expected = [{"id": 1, "name": "Alice", "department": "Sales"},
                    {"id": 2, "name": "Bob", "department": "HR"},
                    {"id": 3, "name": "Charlie", "department": "Marketing"}]
        self.assertEqual(result, expected)

if __name__ == "__main__":
    unittest.main()
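
# To run the suite (the file name below is illustrative):
#   python -m unittest -v test_spark_transformations.py
# or, thanks to the __main__ guard above:
#   python test_spark_transformations.py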