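# The class below inherits from IndexerInterface, which is not defined in this paste.
# A minimal sketch of that interface follows, assuming it is a plain fit/transform/save/load
# contract; the method set here is an assumption, not part of the original snippet.
from abc import ABC, abstractmethod


class IndexerInterface(ABC):
    """Assumed base contract for string indexers (hypothetical sketch)."""

    @abstractmethod
    def fit(self, df, inputCol): ...

    @abstractmethod
    def transform(self, df, inputCol, outputCol): ...

    @abstractmethod
    def save(self, path): ...

    @classmethod
    @abstractmethod
    def load(cls, path): ...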
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, row_number, lit
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, StringType, LongType
import mlflow


class DataFrameStringIndexer(IndexerInterface):
    def __init__(self):
        spark = SparkSession.builder.getOrCreate()
        # Start with an empty mapping; an explicit schema is required because Spark
        # cannot infer column types from an empty collection.
        mapping_schema = StructType([
            StructField("string", StringType(), True),
            StructField("index", LongType(), True),
        ])
        self.mapping_df = spark.createDataFrame([], schema=mapping_schema)
        # Highest index assigned so far; used to keep new indices unique across fits
        self.current_max_index = -1

    def fit(self, df, inputCol):
        # Extract the unique strings in the input column
        unique_strings_df = df.select(inputCol).distinct()

        # Assign a sequential index to each string not already in the mapping.
        # Ordering by a constant pulls all rows into one partition for row_number;
        # acceptable here since the set of distinct strings is assumed to be small.
        window_spec = Window.orderBy(lit(1))
        indexed_new_strings = (
            unique_strings_df
            .join(self.mapping_df,
                  unique_strings_df[inputCol] == self.mapping_df["string"],
                  "left_anti")
            .withColumnRenamed(inputCol, "string")
            .withColumn("index",
                        (row_number().over(window_spec) + self.current_max_index).cast("long"))
            .cache()  # freeze the nondeterministic row_number assignment
        )

        # Advance the running maximum so later fits continue from here
        self.current_max_index += indexed_new_strings.count()

        # Merge the new entries into the existing mapping DataFrame
        self.mapping_df = self.mapping_df.unionByName(indexed_new_strings)

    def transform(self, df, inputCol, outputCol):
        # Attach the learned index; rows whose string has no mapping are dropped
        return (
            df.join(self.mapping_df, df[inputCol] == self.mapping_df["string"], "left")
            .select(*[df[c] for c in df.columns],
                    self.mapping_df["index"].alias(outputCol))
            .filter(col(outputCol).isNotNull())
        )

    def save(self, path):
        # Persist the mapping DataFrame as Parquet (path must be visible to both Spark and MLflow)
        self.mapping_df.write.mode("overwrite").parquet(path)
        # Log the saved mapping as an MLflow artifact
        mlflow.log_artifact(path, "string_indexer")

    @classmethod
    def load(cls, path):
        spark = SparkSession.builder.getOrCreate()
        instance = cls()
        instance.mapping_df = spark.read.parquet(path)
        # Restore the running maximum from the persisted mapping (fall back to -1 if empty)
        max_index = instance.mapping_df.agg({"index": "max"}).collect()[0][0]
        instance.current_max_index = max_index if max_index is not None else -1
        return instance
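# Usage sketch, not part of the original snippet. Assumptions: a local SparkSession,
# a toy DataFrame with a string column "country", and a writable local path
# "/tmp/string_indexer_mapping" -- these names are hypothetical.
if __name__ == "__main__":
    spark = SparkSession.builder.getOrCreate()
    events = spark.createDataFrame(
        [("US",), ("DE",), ("US",), ("FR",)], ["country"]
    )

    indexer = DataFrameStringIndexer()
    indexer.fit(events, inputCol="country")                        # builds string -> index mapping
    indexed = indexer.transform(events, "country", "country_idx")
    indexed.show()

    # Persist the mapping inside an MLflow run, then reload it; incremental fits
    # keep already-assigned indices stable and only index unseen strings.
    with mlflow.start_run():
        indexer.save("/tmp/string_indexer_mapping")
    restored = DataFrameStringIndexer.load("/tmp/string_indexer_mapping")

    more_events = spark.createDataFrame([("US",), ("JP",)], ["country"])
    restored.fit(more_events, inputCol="country")                  # only "JP" receives a new index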