Untitled

mail@pastecode.io avatar
unknown
plain_text
a year ago
2.1 kB
2
Indexable
Never
class DictStringIndexer(StringIndexerInterface):
    """Map the distinct values of a DataFrame column to integer indices.

    The string -> index mapping is held in plain dicts on the driver and is
    additionally published as a Spark broadcast variable, which ``transform``
    reads from inside a UDF.
    """

    def __init__(self):
        # Forward and reverse lookup tables, filled by fit() or load().
        self.string_to_index: Dict[str, int] = {}
        self.index_to_string: Dict[int, str] = {}
        # Broadcast of string_to_index; stays None until fit() or load() ran.
        self.mapping_broadcast: Optional[Broadcast] = None

    def _broadcast_mapping(self) -> None:
        """Broadcast the current ``string_to_index`` via the active session."""
        self.mapping_broadcast = SparkSession.builder.getOrCreate().sparkContext.broadcast(self.string_to_index)  # type: ignore

    def fit(self, df, input_col: str) -> None:
        """Assign an index to every distinct non-null value of ``input_col``.

        Values already present in the mapping keep their index, so calling
        ``fit`` again only appends new values. Indices are handed out in the
        order in which Spark returns the distinct values. The mapping is
        (re-)broadcast at the end.

        Args:
            df: DataFrame containing ``input_col``.
            input_col: Name of the column whose values are indexed.
        """
        unique_strings = (
            df.select(input_col).distinct().rdd.flatMap(lambda x: x).collect()
        )

        # Continue after the highest index already in use. For a contiguous
        # 0..n-1 mapping this equals len(mapping); for a loaded mapping with
        # gaps it avoids handing out an index that is already taken.
        next_idx = max(self.index_to_string, default=-1) + 1

        for s in unique_strings:
            # Nulls are skipped: json.dump in save() would write a None key as
            # the string "null", so a reloaded indexer would treat nulls
            # differently from the fitted one. Unindexed nulls are dropped by
            # transform() in both cases.
            if s is None or s in self.string_to_index:
                continue
            self.string_to_index[s] = next_idx
            self.index_to_string[next_idx] = s
            next_idx += 1

        self._broadcast_mapping()

    def transform(self, df, input_col: str, output_col: str):
        """Add ``output_col`` holding the index of each ``input_col`` value.

        Rows whose value has no entry in the mapping get a null index and are
        filtered out of the result.

        Args:
            df: DataFrame containing ``input_col``.
            input_col: Name of the column to look up.
            output_col: Name of the integer column to add.

        Returns:
            The DataFrame with ``output_col`` added, restricted to rows whose
            value was found in the mapping.

        Raises:
            ValueError: If neither ``fit`` nor ``load`` has produced a
                broadcast mapping yet.
        """
        if self.mapping_broadcast is None:
            raise ValueError(
                "The indexer has not been fitted yet. Use fit first."
            )

        # Bind the broadcast handle to a local so the UDF closure captures
        # only that handle. Referencing ``self`` inside the UDF would pickle
        # the whole indexer (both dicts included) along with the function,
        # which defeats the purpose of broadcasting the mapping.
        mapping = self.mapping_broadcast

        @udf(IntegerType())
        def map_string(s):
            return mapping.value.get(s)

        logger.info(f"Transforming dataframe. Input col: {input_col}, output col: {output_col}")
        return df.withColumn(output_col, map_string(df[input_col])).filter(
            col(output_col).isNotNull()
        )

    def save(self) -> None:
        """Dump ``string_to_index`` to JSON and log it as an MLflow artifact.

        The file is written to ``<cwd>/tmp/string_indexer.json`` (the ``tmp``
        directory is created if missing) and logged under the
        ``string_indexer`` artifact path. The local file is left in place.
        """
        string_indexer_tmp_dir = Path.cwd() / "tmp"
        string_indexer_tmp_dir.mkdir(exist_ok=True)
        string_indexer_tmp_path = string_indexer_tmp_dir / "string_indexer.json"

        with open(string_indexer_tmp_path, "w") as f:
            json.dump(self.string_to_index, f)

        mlflow.log_artifact(str(string_indexer_tmp_path), "string_indexer")

    @classmethod
    def load(cls, string_to_index: Dict):
        """Build a ready-to-use indexer from an existing mapping.

        Args:
            string_to_index: Mapping from value to index, e.g. the content of
                the JSON file written by ``save``.

        Returns:
            A new instance with both lookup tables populated and the mapping
            already broadcast, so ``transform`` can be called directly.
        """
        instance = cls()
        # Copy the mapping so a later fit() on this instance does not mutate
        # the dict owned by the caller.
        instance.string_to_index = dict(string_to_index)
        instance.index_to_string = {v: k for k, v in string_to_index.items()}
        instance._broadcast_mapping()

        return instance