Untitled
unknown
plain_text
2 years ago
1.2 kB
9
Indexable
def build_fractions(strata_counts, total_records, desired_sample_size,
                    category_to_include_all):
    """Build the per-category fraction dict for ``DataFrame.sampleBy``.

    Uses proportional allocation: every stratum keeps the same fraction
    ``desired_sample_size / total_records``, so the *expected* overall
    sample size is ``desired_sample_size``. The category named by
    ``category_to_include_all`` is always kept in full (fraction 1.0).

    Args:
        strata_counts: mapping of category -> row count for that category.
        total_records: total number of rows in the DataFrame.
        desired_sample_size: target number of sampled rows overall.
        category_to_include_all: category whose rows are all retained.

    Returns:
        dict mapping each category to a sampling fraction in [0.0, 1.0].
    """
    base_fraction = min(1.0, desired_sample_size / total_records)
    fractions = {}
    for category in strata_counts:
        if category == category_to_include_all:
            fractions[category] = 1.0
        else:
            # BUG FIX: the original multiplied by an extra
            # (count / total_records) factor, which shrank every stratum's
            # fraction far below the proportional target. For proportional
            # allocation the fraction is simply desired / total, capped at 1.
            fractions[category] = base_fraction
    return fractions


def main():
    """Run the stratified-sampling example end to end on a local Spark session."""
    # Imported here so build_fractions() stays usable without PySpark installed.
    from pyspark.sql import SparkSession

    # Initialize Spark session
    spark = SparkSession.builder \
        .appName("StratifiedSamplingExample") \
        .getOrCreate()

    # Example DataFrame
    data = [
        (1, 'A'), (2, 'A'), (3, 'B'), (4, 'B'), (5, 'C'),
        (6, 'C'), (7, 'D'), (8, 'D'), (9, 'E'), (10, 'E')
    ]
    df = spark.createDataFrame(data, ["id", "category"])

    # Specify the category to include all samples from
    category_to_include_all = 'A'

    # Count the whole table and each stratum (two Spark jobs, unavoidable
    # since sampleBy needs exact fractions up front).
    total_records = df.count()
    desired_sample_size = 500000
    strata_counts = {
        row['category']: row['count']
        for row in df.groupBy("category").count().collect()
    }

    fractions = build_fractions(strata_counts, total_records,
                                desired_sample_size, category_to_include_all)

    # Perform stratified sampling; seed fixed for reproducibility.
    sampled_df = df.sampleBy("category", fractions, seed=42)
    sampled_df.show()


if __name__ == "__main__":
    main()
Editor is loading...
Leave a Comment