Untitled

mail@pastecode.io avatar
unknown
plain_text
a year ago
2.0 kB
2
Indexable
Never
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, row_number, lit
from pyspark.sql.window import Window
import mlflow

class DataFrameStringIndexer(IndexerInterface):
    """Map the distinct values of a column to dense integer indices.

    The mapping lives in a Spark DataFrame with columns ``string`` and
    ``index``. Indices start at 0. Each call to :meth:`fit` appends only values
    not already in the mapping, numbering them after the current maximum, so
    previously assigned indices are never changed.
    """

    # Spark cannot infer a schema from an empty dataset when only column names
    # are given, so the types are spelled out. Names are backtick-quoted so
    # `string` is never parsed as the STRING type keyword.
    _MAPPING_SCHEMA = "`string` STRING, `index` BIGINT"

    def __init__(self):
        spark = SparkSession.builder.getOrCreate()
        self.mapping_df = spark.createDataFrame([], schema=self._MAPPING_SCHEMA)
        # Highest index assigned so far; -1 means the mapping is empty, so the
        # first fitted value receives index 0.
        self.current_max_index = -1

    def fit(self, df, inputCol):
        """Add every previously unseen, non-null value of ``inputCol`` to the mapping.

        Args:
            df: DataFrame containing the column to index.
            inputCol: Name of the column whose values are indexed. Values are
                cast to string to match the mapping schema.
        """
        # Nulls are skipped: a null key never matches in the anti-join below,
        # so it would otherwise be re-added with a fresh index on every fit,
        # and transform() can never match it anyway.
        candidates = (df
                      .select(col(inputCol).cast("string").alias("string"))
                      .where(col("string").isNotNull())
                      .distinct())
        new_strings = candidates.join(self.mapping_df, on="string", how="left_anti")

        # Number by the value itself so the assignment is deterministic: the
        # mapping is a lazy plan that is recomputed on every action, and an
        # order over a constant (lit(1)) could hand out different indices on
        # each evaluation. A window without partitionBy moves all new values
        # into a single partition.
        window_spec = Window.orderBy(col("string"))
        offset = self.current_max_index
        indexed_new_strings = new_strings.withColumn(
            "index", (row_number().over(window_spec) + lit(offset)).cast("long")
        )

        # row_number starts at 1, so the new maximum is offset + count.
        self.current_max_index += new_strings.count()
        self.mapping_df = self.mapping_df.unionByName(indexed_new_strings)

    def transform(self, df, inputCol, outputCol):
        """Return ``df`` with an extra ``outputCol`` holding each row's index.

        Rows whose ``inputCol`` is null or was never fitted are dropped.

        Args:
            df: DataFrame to transform.
            inputCol: Column whose values are looked up in the mapping.
            outputCol: Name of the index column added to the result.
        """
        # Give the mapping columns private names so a `string` or `index`
        # column already present in `df` cannot make the select ambiguous.
        lookup = self.mapping_df.select(
            col("string").alias("__indexer_key"),
            col("index").alias("__indexer_index"),
        )
        joined = df.join(lookup, df[inputCol] == lookup["__indexer_key"], "inner")
        return joined.select(*df.columns, col("__indexer_index").alias(outputCol))

    def save(self, path):
        """Write the mapping to ``path`` as Parquet and log it to MLflow.

        The same ``path`` is handed to Spark and to ``mlflow.log_artifact``.
        That presumably requires a driver-local directory; confirm before using
        remote URIs. With Spark's default save mode the write fails if ``path``
        already exists.
        """
        self.mapping_df.write.parquet(path)
        mlflow.log_artifact(path, "string_indexer")

    @classmethod
    def load(cls, path):
        """Restore an indexer from a Parquet mapping previously written by :meth:`save`."""
        spark = SparkSession.builder.getOrCreate()
        instance = cls()
        instance.mapping_df = spark.read.parquet(path)
        max_index = instance.mapping_df.agg({"index": "max"}).collect()[0][0]
        # max() over an empty mapping is None; fall back to the empty-mapping
        # sentinel so a later fit() still starts numbering at 0.
        instance.current_max_index = -1 if max_index is None else max_index
        return instance