Untitled

mail@pastecode.io avatar
unknown
plain_text
2 years ago
3.0 kB
3
Indexable
Never
import pyspark.sql.functions as F
from pyspark.sql.session import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types as T


def update_df(df, columns_dict):
    """Return *df* with the columns described by *columns_dict* set.

    Keys of *columns_dict* are dot-separated paths into (possibly nested)
    struct columns, e.g. ``"car.transmission.gear_box"``; values are Column
    expressions. Fields are replaced in place when they exist and appended
    when they do not; missing intermediate structs are created on the fly.

    The previous implementation rebuilt structs by cross-joining the frame
    with a projection of itself, which (a) produces n**2 rows for an n-row
    frame (it only appeared to work on the 1-row demo input) and (b) is a
    self-join Spark handles badly. This version rebuilds the struct as a
    single nested ``F.struct`` expression via ``withColumn``, so the row
    count is never touched and no join is needed.

    :param df: input DataFrame (not mutated; Spark DataFrames are immutable).
    :param columns_dict: mapping of dotted column path -> Column expression.
    :return: a new DataFrame with all updates applied.
    """
    for path, value in columns_dict.items():
        parts = path.split('.')
        if len(parts) == 1:
            # Plain top-level column: withColumn already does replace-or-add.
            df = df.withColumn(parts[0], value)
        else:
            top = parts[0]
            if top in df.columns and isinstance(df.schema[top].dataType, T.StructType):
                col, schema = F.col(top), df.schema[top].dataType
            else:
                # Column missing (or not a struct): start from an empty struct.
                col, schema = None, None
            df = df.withColumn(top, _with_nested_field(col, schema, parts[1:], value))
    return df


def _with_nested_field(col, schema, parts, value):
    """Build a struct Column equal to *col* with the field at *parts* set to *value*.

    :param col: existing struct Column, or None when the struct does not exist yet.
    :param schema: the StructType of *col*, or None when *col* is None.
    :param parts: non-empty list of remaining path segments inside this struct.
    :param value: Column expression to place at the leaf.
    :return: an ``F.struct`` Column with the field replaced/added in order.
    """
    head = parts[0]
    fields = list(schema.fields) if schema is not None else []
    exprs, replaced = [], False
    for f in fields:
        if f.name == head:
            if len(parts) == 1:
                exprs.append(value.alias(head))
            else:
                # Descend only through genuine structs; anything else is rebuilt.
                sub_schema = f.dataType if isinstance(f.dataType, T.StructType) else None
                sub_col = col[head] if sub_schema is not None else None
                exprs.append(_with_nested_field(sub_col, sub_schema, parts[1:], value).alias(head))
            replaced = True
        else:
            # Keep untouched sibling fields in their original order.
            exprs.append(col[f.name].alias(f.name))
    if not replaced:
        if len(parts) == 1:
            exprs.append(value.alias(head))
        else:
            # Create the whole missing intermediate chain as fresh structs.
            exprs.append(_with_nested_field(None, None, parts[1:], value).alias(head))
    return F.struct(*exprs)


---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-39-a1f7fca388ae> in <module>
     66 # print(res.printSchema())
     67 # print(res.collect())
---> 68 check_results(res)

<ipython-input-39-a1f7fca388ae> in check_results(df)
     31     assert set(df.columns) == set(["car", "owner"])
     32     assert set(df.select(F.col("car.*")).columns) == set(["brand",
---> 33     "model", "transmission", "color"])
     34     assert set(df.select(F.col("owner.*")).columns) == set(["first_name",
     35     "last_name"])

AssertionError: 

def check_results(df):
    """Verify that *df* matches the expected post-update schema and row values.

    Raises AssertionError on the first mismatch; prints a confirmation when
    every check passes.
    """
    expected_top = {"car", "owner"}
    assert set(df.columns) == expected_top
    car_fields = set(df.select(F.col("car.*")).columns)
    assert car_fields == {"brand", "model", "transmission", "color"}
    owner_fields = set(df.select(F.col("owner.*")).columns)
    assert owner_fields == {"first_name", "last_name"}
    wheel_drive = F.col("`car`.`transmission`.`wheel_drive`")
    assert df.filter(wheel_drive == "all").count() == 1
    gear_box = F.col("`car`.`transmission`.`gear_box`")
    assert df.filter(gear_box == "automatic").count() == 1
    print("All tests passed!")


# Obtain a SparkSession explicitly: the original code referenced a `spark`
# name that only exists implicitly inside a notebook/pyspark shell, so this
# file raised NameError when run as a plain script. SparkSession is already
# imported at the top of the file.
spark = SparkSession.builder.getOrCreate()

# One-row frame with a nested `car` struct to exercise update_df.
input_df = (
    spark
    .range(0, 1)
    .select(
        F.struct(
            F.lit("bmw").alias("brand"),
            F.lit("220i").alias("model"),
            F.struct(
                F.lit("rear").alias("wheel_drive"),
                F.lit("automatic").alias("gear_box")
            ).alias("transmission")
        ).alias("car")
    )
)

# Dotted paths: replace existing fields, add new ones, and create arbitrarily
# deep missing struct chains (the `sdf...` and `owner...` keys).
columns_dict = {
    "car.brand": F.lit("audi"),
    "car.transmission.wheel_drive": F.lit("all"),
    "car.color": F.lit("black"),
    "car.sdf.sdf.sdf.sdf.d.color": F.lit("black"),
    "owner.first_name.first_name.first_name.first_name.first_name.first_name": F.lit("Ivan"),
    "owner.last_name": F.lit("Ivanov"),
}

res = update_df(input_df, columns_dict)
check_results(res)