from pyspark.sql import SparkSession
from pyspark.sql.functions import col, row_number
from pyspark.sql.types import StructType, StructField, StringType, LongType
from pyspark.sql.window import Window
import mlflow
class DataFrameStringIndexer(IndexerInterface):
    # IndexerInterface is assumed to be defined elsewhere (it declares fit/transform/save/load)
    def __init__(self):
        spark = SparkSession.builder.getOrCreate()
        # An explicit schema is required here: Spark cannot infer column types from an empty list
        schema = StructType([StructField("string", StringType()), StructField("index", LongType())])
        self.mapping_df = spark.createDataFrame([], schema=schema)
        # Used to ensure the uniqueness of new indices
        self.current_max_index = -1
    def fit(self, df, inputCol):
        # Extract unique strings
        unique_strings_df = df.select(inputCol).distinct()
        # Assign an index to each new string; ordering by the value itself keeps
        # row_number() deterministic if the lazy plan is ever recomputed
        window_spec = Window.orderBy(inputCol)
        indexed_new_strings = (unique_strings_df
                               .join(self.mapping_df,
                                     unique_strings_df[inputCol] == self.mapping_df["string"],
                                     "left_anti")
                               .withColumn("index", row_number().over(window_spec) + self.current_max_index))
        # Update the current_max_index so future fits continue numbering from here
        self.current_max_index += indexed_new_strings.count()
        # Union with the existing mapping DataFrame (columns are matched by position)
        # and cache it so the join above is not recomputed on every downstream use
        self.mapping_df = self.mapping_df.union(indexed_new_strings).cache()
    def transform(self, df, inputCol, outputCol):
        # Attach the index for each string; rows whose string was never fitted are dropped
        return (df.join(self.mapping_df, df[inputCol] == self.mapping_df["string"], "left")
                .select(*df.columns, self.mapping_df["index"].alias(outputCol))
                .filter(col(outputCol).isNotNull()))
    def save(self, path):
        # Save the mapping DataFrame as a Parquet file (Spark writes it as a directory)
        self.mapping_df.write.parquet(path)
        # Log the saved mapping to MLflow under the "string_indexer" artifact path
        mlflow.log_artifact(path, "string_indexer")
    @classmethod
    def load(cls, path):
        spark = SparkSession.builder.getOrCreate()
        instance = cls()
        instance.mapping_df = spark.read.parquet(path)
        # Restore the running maximum; fall back to -1 if the saved mapping is empty
        max_index = instance.mapping_df.agg({"index": "max"}).collect()[0][0]
        instance.current_max_index = max_index if max_index is not None else -1
        return instance
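
# --- Usage sketch (illustrative, not part of the original snippet) ---
# Assumes IndexerInterface is importable, a local SparkSession is available, and MLflow
# tracking is configured; the column names "category"/"category_idx" and the /tmp save
# path are placeholders.
if __name__ == "__main__":
    spark = SparkSession.builder.getOrCreate()
    data = spark.createDataFrame(
        [("apple",), ("banana",), ("apple",), ("cherry",)],
        schema="category STRING",
    )

    indexer = DataFrameStringIndexer()
    indexer.fit(data, inputCol="category")
    indexer.transform(data, inputCol="category", outputCol="category_idx").show()

    # Persist the mapping inside an MLflow run, then restore it; load() rebuilds
    # current_max_index so later fit() calls keep assigning unique indices.
    with mlflow.start_run():
        indexer.save("/tmp/string_indexer_mapping")
    restored = DataFrameStringIndexer.load("/tmp/string_indexer_mapping")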