Untitled
unknown
plain_text
a year ago
2.1 kB
2
Indexable
Never
class DictStringIndexer(StringIndexerInterface):
    """Map the distinct values of a column to contiguous integer indices.

    The mapping is kept in plain Python dicts on the driver and shipped to
    executors as a Spark broadcast variable so ``transform`` can look values
    up inside a UDF.
    """

    def __init__(self):
        # value -> index and its inverse; both are extended together in ``fit``.
        self.string_to_index: Dict[str, int] = {}
        self.index_to_string: Dict[int, str] = {}
        # Broadcast of ``string_to_index``; ``None`` until ``fit``/``load`` runs.
        self.mapping_broadcast: Optional[Broadcast] = None

    def _broadcast_mapping(self) -> None:
        """(Re)broadcast ``string_to_index``, releasing any previous broadcast."""
        if self.mapping_broadcast is not None:
            # Drop executor-side copies of the stale mapping; if something still
            # uses it, Spark re-sends it to the executors on demand.
            self.mapping_broadcast.unpersist()
        spark_context = SparkSession.builder.getOrCreate().sparkContext
        self.mapping_broadcast = spark_context.broadcast(self.string_to_index)  # type: ignore

    def fit(self, df, input_col):
        """Extend the mapping with every distinct value of ``input_col`` in ``df``.

        Values already in the mapping keep their index. Unseen values get the
        next free indices, assigned in sorted order so that fitting the same
        data always yields the same indices. The updated mapping is then
        broadcast for use by ``transform``.

        Args:
            df: Spark DataFrame containing ``input_col``.
            input_col: Name of the column whose values are indexed.
        """
        distinct_values = [
            row[0] for row in df.select(input_col).distinct().collect()
        ]
        new_values = [v for v in distinct_values if v not in self.string_to_index]
        # ``distinct()`` gives no ordering guarantee; sort for reproducible
        # indices. The ``v is None`` key component puts a null (at most one,
        # after distinct) last without ever comparing it to a string.
        for value in sorted(new_values, key=lambda v: (v is None, v)):
            idx = len(self.string_to_index)
            self.string_to_index[value] = idx
            self.index_to_string[idx] = value
        self._broadcast_mapping()

    def transform(self, df, input_col, output_col):
        """Return ``df`` with ``output_col`` holding the index of ``input_col``.

        Rows whose ``input_col`` value is not in the mapping get a null index
        and are filtered out of the result.

        Args:
            df: Spark DataFrame containing ``input_col``.
            input_col: Name of the column to look up.
            output_col: Name of the integer column to add.

        Raises:
            ValueError: If neither ``fit`` nor ``load`` has been called.
        """
        if self.mapping_broadcast is None:
            raise ValueError(
                "The indexer has not been fitted yet. Use fit first."
            )
        # Capture only the broadcast handle: referencing ``self`` inside the
        # UDF would serialize the whole indexer (both dicts) into every task.
        mapping_broadcast = self.mapping_broadcast

        @udf(IntegerType())
        def map_string(s):
            return mapping_broadcast.value.get(s, None)

        logger.info(
            "Transforming dataframe. Input col: %s, output col: %s",
            input_col,
            output_col,
        )
        return df.withColumn(output_col, map_string(df[input_col])).filter(
            col(output_col).isNotNull()
        )

    def save(self):
        """Write ``string_to_index`` as JSON and log it with MLflow.

        The file is written to ``<cwd>/tmp/string_indexer.json`` and passed to
        ``mlflow.log_artifact`` under the ``string_indexer`` artifact path.
        """
        tmp_dir = Path.cwd() / "tmp"
        tmp_dir.mkdir(exist_ok=True)
        string_indexer_tmp_path = tmp_dir / "string_indexer.json"
        # NOTE(review): JSON object keys are always strings, so a ``None`` key
        # is written as "null" and non-string keys will not round-trip
        # unchanged through ``load``.
        with open(string_indexer_tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.string_to_index, f)
        mlflow.log_artifact(str(string_indexer_tmp_path), "string_indexer")

    @classmethod
    def load(cls, string_to_index: Dict):
        """Build an indexer from an existing value -> index mapping.

        The mapping is copied, so later ``fit`` calls on the instance do not
        mutate the caller's dict. The inverse mapping is rebuilt and the
        mapping is broadcast, so the instance is ready for ``transform``.

        Args:
            string_to_index: Mapping from value to integer index.

        Returns:
            A new indexer instance using that mapping.
        """
        instance = cls()
        instance.string_to_index = dict(string_to_index)
        instance.index_to_string = {
            v: k for k, v in instance.string_to_index.items()
        }
        instance._broadcast_mapping()
        return instance