Untitled
unknown
plain_text
a year ago
1.2 kB
4
Indexable
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Initialize Spark session
spark = SparkSession.builder \
    .appName("StratifiedSamplingExample") \
    .getOrCreate()

# Example DataFrame: (id, category) rows spread over five strata A-E.
data = [
    (1, 'A'), (2, 'A'),
    (3, 'B'), (4, 'B'),
    (5, 'C'), (6, 'C'),
    (7, 'D'), (8, 'D'),
    (9, 'E'), (10, 'E')
]
df = spark.createDataFrame(data, ["id", "category"])

# Stratum whose rows must ALL be kept in the sample.
category_to_include_all = 'A'

# Inputs for proportional allocation.
total_records = df.count()
desired_sample_size = 500000
strata_counts = df.groupBy("category").count().collect()

# Per-stratum sampling fractions passed to DataFrame.sampleBy().
fractions = {}

# BUG FIX: for proportional stratified sampling, every stratum uses the SAME
# fraction, desired_sample_size / total_records — each stratum's own size then
# yields its proportional share automatically. The original code multiplied by
# (row['count'] / total_records) a second time, which under-samples every
# stratum (and small strata most severely), so the realized sample would fall
# far short of desired_sample_size.
base_fraction = min(1.0, desired_sample_size / total_records)

for row in strata_counts:
    if row['category'] == category_to_include_all:
        # Include all records from this stratum.
        fractions[row['category']] = 1.0
    else:
        fractions[row['category']] = base_fraction

# Perform stratified sampling.
# NOTE: sampleBy() is approximate — the realized per-stratum sample size
# fluctuates around fraction * stratum_count rather than being exact.
sampled_df = df.sampleBy("category", fractions, seed=42)

# Show the result
sampled_df.show()
Editor is loading...
Leave a Comment