Untitled
unknown
plain_text
2 years ago
2.1 kB
10
Indexable
class DictStringIndexer(StringIndexerInterface):
    """Map the string values of a Spark column to integer indices.

    The mapping is held on the driver in two dicts (``string_to_index`` and
    its inverse ``index_to_string``) and shipped to executors as a Spark
    broadcast variable, which :meth:`transform` reads inside a Python UDF.
    """

    def __init__(self):
        self.string_to_index: Dict[str, int] = {}
        self.index_to_string: Dict[int, str] = {}
        # Set by fit()/load(); transform() refuses to run while this is None.
        self.mapping_broadcast: Optional[Broadcast] = None

    def _broadcast_mapping(self) -> None:
        """(Re-)broadcast ``string_to_index`` to the executors."""
        if self.mapping_broadcast is not None:
            # Drop executor-side copies of the stale mapping. unpersist() (not
            # destroy()) keeps the old variable usable by any DataFrame that
            # was transformed earlier but has not been evaluated yet.
            self.mapping_broadcast.unpersist()
        spark_context = SparkSession.builder.getOrCreate().sparkContext
        self.mapping_broadcast = spark_context.broadcast(self.string_to_index)  # type: ignore

    def fit(self, df, input_col):
        """Assign an index to every distinct non-null value of ``input_col``.

        Values already present in the mapping keep their index. New values
        get indices after the current maximum, in sorted order, so fitting
        the same data always yields the same indices. Nulls are skipped:
        ``json.dump`` would save a ``None`` key as the string ``"null"``, and
        an indexer restored with :meth:`load` would then disagree with the
        fitted one.

        Args:
            df: Spark DataFrame containing ``input_col``.
            input_col: Name of the column whose values are indexed.
        """
        # distinct() already de-duplicates; collect() returns single-field Rows.
        new_values = sorted(
            value
            for value in (
                row[0] for row in df.select(input_col).distinct().collect()
            )
            if value is not None and value not in self.string_to_index
        )
        # max + 1 (not len) so a non-dense mapping from load() cannot collide.
        next_idx = max(self.string_to_index.values(), default=-1) + 1
        for value in new_values:
            self.string_to_index[value] = next_idx
            self.index_to_string[next_idx] = value
            next_idx += 1
        self._broadcast_mapping()

    def transform(self, df, input_col, output_col):
        """Add ``output_col`` holding the index of each ``input_col`` value.

        Rows whose value has no index (null, or not seen during fitting) are
        dropped from the result.

        Args:
            df: Spark DataFrame containing ``input_col``.
            input_col: Name of the column to look up.
            output_col: Name of the integer column to add.

        Returns:
            The DataFrame with ``output_col`` added and unmapped rows removed.

        Raises:
            ValueError: If neither fit() nor load() has been called.
        """
        if self.mapping_broadcast is None:
            raise ValueError(
                "The indexer has not been fitted yet. Use fit first."
            )
        # Bind the broadcast handle to a local so the UDF closure captures only
        # it. Closing over ``self`` would make cloudpickle serialize the whole
        # instance -- including both full mapping dicts -- into every task.
        mapping = self.mapping_broadcast

        @udf(IntegerType())
        def map_string(s):
            return mapping.value.get(s, None)

        logger.info(
            "Transforming dataframe. Input col: %s, output col: %s",
            input_col,
            output_col,
        )
        return df.withColumn(output_col, map_string(df[input_col])).filter(
            col(output_col).isNotNull()
        )

    def save(self):
        """Write ``string_to_index`` as JSON and log it as an MLflow artifact.

        The file is written to ``<cwd>/tmp/string_indexer.json`` and logged
        under the ``string_indexer`` artifact path. ``index_to_string`` is not
        saved because load() rebuilds it from the forward mapping.
        """
        tmp_dir = Path.cwd() / "tmp"
        tmp_dir.mkdir(exist_ok=True)
        artifact_file = tmp_dir / "string_indexer.json"
        with open(artifact_file, "w", encoding="utf-8") as f:
            json.dump(self.string_to_index, f)
        mlflow.log_artifact(str(artifact_file), "string_indexer")

    @classmethod
    def load(cls, string_to_index: Dict):
        """Rebuild an indexer from a ``string_to_index`` mapping.

        Args:
            string_to_index: Mapping of value -> integer index, e.g. the dict
                read back from the JSON file written by save().

        Returns:
            A new indexer whose mapping is already broadcast and ready for
            transform().
        """
        instance = cls()
        # Copy so a later fit() on the instance cannot mutate the caller's dict.
        instance.string_to_index = dict(string_to_index)
        instance.index_to_string = {
            idx: value for value, idx in instance.string_to_index.items()
        }
        instance._broadcast_mapping()
        return instance