from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql.session import SparkSession
def update_df(df, columns_dict):
    """Add or replace (possibly nested) columns in df.

    Keys of columns_dict are dot-separated paths ("car.transmission.gear_box"),
    values are Column expressions. Missing intermediate structs are created.
    """
    for path, value in columns_dict.items():
        splitted = path.split('.')
        if len(splitted) == 1:
            # Flat path: withColumn adds the column or replaces it in place.
            df = df.withColumn(splitted[0], value)
        else:
            if splitted[0] not in df.columns:
                # Placeholder empty struct so that `select('col.*')` below
                # works even when the top-level struct does not exist yet.
                df = df.withColumn(splitted[0], F.lit(None).cast(T.StructType()))
            # Recurse into the struct's fields with the remainder of the path.
            df_up = update_df(df.select(f'{splitted[0]}.*'),
                              {'.'.join(splitted[1:]): value})
            # Temporarily prefix the outer columns so they cannot clash with
            # the updated field names during the cross join.
            prefix = "renamed_"
            df = df.select([F.col(c).alias(prefix + c) for c in df.columns])
            # Reassemble the struct from the updated fields, then drop the
            # loose field columns and strip the temporary prefix back off.
            df = df.crossJoin(df_up) \
                .withColumn(prefix + splitted[0],
                            F.struct(*[F.col(c) for c in df_up.columns])) \
                .drop(*df_up.columns)
            df = df.select([F.col(c).alias(c[len(prefix):]) for c in df.columns])
    return df
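# A minimal usage sketch for the flat case (kept as a comment since the
# `spark` session is only created further below; the names `demo`, `status`
# and `extra` are illustrative, not part of the test below): for top-level
# paths update_df behaves exactly like a plain withColumn.
#
#     demo = spark.range(0, 1).select(F.lit("old").alias("status"))
#     demo = update_df(demo, {"status": F.lit("new"), "extra": F.lit(1)})
#     # demo now has columns `status` ("new") and `extra` (1)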
def check_results(df):
    assert set(df.columns) == set(["car", "owner"])
    assert set(df.select(F.col("car.*")).columns) == set(
        ["brand", "model", "transmission", "color"])
    assert set(df.select(F.col("owner.*")).columns) == set(
        ["first_name", "last_name"])
    assert df.filter(F.col("`car`.`transmission`.`wheel_drive`") == "all").count() == 1
    assert df.filter(F.col("`car`.`transmission`.`gear_box`") == "automatic").count() == 1
    print("All tests passed!")
# A local session makes the script runnable standalone; in spark-shell or a
# notebook a `spark` session usually already exists.
spark = SparkSession.builder.getOrCreate()

# One-row dataframe with a nested `car` struct to exercise update_df.
input_df = (
    spark
    .range(0, 1)
    .select(
        F.struct(
            F.lit("bmw").alias("brand"),
            F.lit("220i").alias("model"),
            F.struct(
                F.lit("rear").alias("wheel_drive"),
                F.lit("automatic").alias("gear_box")
            ).alias("transmission")
        ).alias("car")
    )
)
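# For reference, input_df's schema at this point is roughly:
# root
#  |-- car: struct
#  |    |-- brand: string
#  |    |-- model: string
#  |    |-- transmission: struct
#  |    |    |-- wheel_drive: string
#  |    |    |-- gear_box: string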
columns_dict = {
    "car.brand": F.lit("audi"),                    # replace an existing field
    "car.transmission.wheel_drive": F.lit("all"),  # replace a deeper field
    "car.color": F.lit("black"),                   # add a new field to `car`
    # Stress-test path that would create a chain of brand-new structs and add
    # an extra `sdf` field to `car`; check_results does not expect it, so it
    # is left disabled:
    # "car.sdf.sdf.sdf.sdf.d.color": F.lit("black"),
    "owner.first_name.first_name.first_name.first_name.first_name.first_name": F.lit("Ivan"),
    "owner.last_name": F.lit("Ivanov"),
}
res = update_df(input_df, columns_dict)
check_results(res)
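# Optional inspection of the result; printSchema() and show() are standard
# DataFrame methods. `car` should now carry the added `color` field and
# `owner` the nested chain of `first_name` structs plus `last_name`.
res.printSchema()
res.show(truncate=False)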