Untitled
unknown
plain_text
8 days ago
20 kB
1
Indexable
from __future__ import annotations """DataFrameDiff – robust dataframe comparison for Spark / Databricks. Highlights ---------- * **Column‑name safe** – works with any legal identifier (spaces, dots, hyphens, back‑ticks …) * **Complex‑type aware** – compares structs, arrays and maps with vectorised pandas UDFs when needed. * **One‑stop API** – `DataFrameDiff.create_diff()` and `DataFrameDiff.get_diff_summary_dict()` * **No hidden string SQL** – all column/field access goes through helper functions that guarantee safety. Author: Prophecy / Databricks Solutions Engineering """ from functools import reduce import operator as op import uuid import enum from typing import List, Tuple, Dict, Any import pandas as pd from pyspark.sql import DataFrame, Column, functions as F from pyspark.sql.types import ( DataType, NullType, BooleanType, IntegerType, LongType, FloatType, DoubleType, DecimalType, DateType, TimestampType, StringType, MapType, ArrayType, StructType, ) # --------------------------------------------------------------------------- # Helpers for **safe** identifier / struct handling # --------------------------------------------------------------------------- def q(name: str) -> str: """Quote a Spark identifier so that any character becomes legal in string SQL.""" return f"`{name.replace('`', '``')}`" def sf(struct_col: str | Column, field_name: str) -> Column: """Safely extract *field_name* from *struct_col* without dotted strings.""" if isinstance(struct_col, str): struct_col = F.col(struct_col) return struct_col.getField(field_name) # --------------------------------------------------------------------------- # Complex‑type comparison helpers # --------------------------------------------------------------------------- def compare_structs_vectorized(left_series: pd.Series, right_series: pd.Series) -> pd.Series: """Vectorised recursive comparison for complex types (used in a pandas UDF).""" def _cmp(l: Any, r: Any) -> bool: if l is None and r is None: return True if l is None or r is None: return False if isinstance(l, dict) and isinstance(r, dict): if set(l.keys()) != set(r.keys()): return False return all(_cmp(l[k], r[k]) for k in l) if isinstance(l, list) and isinstance(r, list): if len(l) != len(r): return False return all(_cmp(li, ri) for li, ri in zip(l, r)) return l == r return left_series.combine(right_series, _cmp) try: # Spark ≥ 3.2 from pyspark.sql.pandas.functions import pandas_udf except ImportError: from pyspark.sql.functions import pandas_udf # type: ignore compare_structs_udf = pandas_udf(compare_structs_vectorized, "boolean") def needs_complex_comparison(dt: DataType) -> bool: """Return *True* if *dt* (possibly nested) contains MapType/ArrayType/StructType.""" if isinstance(dt, MapType): return True if isinstance(dt, ArrayType): return needs_complex_comparison(dt.elementType) if isinstance(dt, StructType): return any(needs_complex_comparison(f.dataType) for f in dt.fields) return False # --------------------------------------------------------------------------- # Public enum – keys used in COMPTUTED_DIFFS # --------------------------------------------------------------------------- class DiffKeys(enum.Enum): JOINED = "joined" SUMMARY = "summary" CLEANED = "cleaned" EXPECTED = "expected" GENERATED = "generated" KEY_COLUMNS = "keyCols" VALUE_COLUMNS = "valueCols" # --------------------------------------------------------------------------- # Main class # --------------------------------------------------------------------------- class DataFrameDiff: """Compare two Spark DataFrames safely, even with the strangest column names.""" COMPUTED_DIFFS: Dict[str, Dict[str, Any]] = {} # --------------------------------------------------------------------- # Schema alignment helpers # --------------------------------------------------------------------- @staticmethod def _precedence(dt: DataType) -> int: if isinstance(dt, NullType): return 0 if isinstance(dt, BooleanType): return 1 if isinstance(dt, IntegerType): return 2 if isinstance(dt, LongType): return 3 if isinstance(dt, FloatType): return 4 if isinstance(dt, DoubleType): return 5 if isinstance(dt, DecimalType): return 6 if isinstance(dt, DateType): return 7 if isinstance(dt, TimestampType): return 8 if isinstance(dt, StringType): return 9 return 99 @classmethod def _find_common_type(cls, dt1: DataType, dt2: DataType) -> DataType: if dt1 == dt2: return dt1 if isinstance(dt1, DecimalType) and isinstance(dt2, DecimalType): return DecimalType(max(dt1.precision, dt2.precision), max(dt1.scale, dt2.scale)) if isinstance(dt1, NullType): return dt2 if isinstance(dt2, NullType): return dt1 prec1, prec2 = cls._precedence(dt1), cls._precedence(dt2) numeric = (BooleanType, IntegerType, LongType, FloatType, DoubleType, DecimalType) if isinstance(dt1, numeric) and isinstance(dt2, numeric): return dt1 if prec1 >= prec2 else dt2 if (isinstance(dt1, DateType) and isinstance(dt2, TimestampType)) or ( isinstance(dt2, DateType) and isinstance(dt1, TimestampType) ): return TimestampType() return dt1 if prec1 > prec2 else dt2 # ------------------------------------------------------------------ # Public helpers # ------------------------------------------------------------------ @classmethod def align_schemas(cls, df1: DataFrame, df2: DataFrame) -> Tuple[DataFrame, DataFrame]: common = set(df1.columns).intersection(df2.columns) for c in common: dt = cls._find_common_type(df1.schema[c].dataType, df2.schema[c].dataType) if df1.schema[c].dataType != dt: df1 = df1.withColumn(c, F.col(q(c)).cast(dt)) if df2.schema[c].dataType != dt: df2 = df2.withColumn(c, F.col(q(c)).cast(dt)) return df1, df2 @staticmethod def _unique(df: DataFrame, base: str) -> str: name, i = base, 0 while name in df.columns: name = f"{base}_diff_{i}" i += 1 return name @classmethod def split_by_pk_uniqueness(cls, df: DataFrame, key_cols: List[str]) -> Tuple[DataFrame, DataFrame]: # Build null‑safe representation for each key for col in key_cols: ns = cls._unique(df, f"__ns_{col}") df = df.withColumn(ns, F.struct(F.col(q(col)).isNull().alias("is_null"), F.col(q(col)).alias("val"))) ns_cols = [c for c in df.columns if c.startswith("__ns_")] cnt = cls._unique(df, "__cnt__") pk_counts = df.groupBy(*ns_cols).agg(F.count("*").alias(cnt)) once = pk_counts.filter(F.col(cnt) == 1).select(*ns_cols) many = pk_counts.filter(F.col(cnt) > 1).select(*ns_cols) unique_df = df.join(once, on=ns_cols, how="inner").drop(*ns_cols) dup_df = df.join(many, on=ns_cols, how="inner").drop(*ns_cols) return unique_df, dup_df # ------------------------------------------------------------------ # Core diff computation # ------------------------------------------------------------------ @classmethod def _create_joined_df( cls, df_left: DataFrame, df_right: DataFrame, key_cols: List[str], value_cols: List[str], ) -> DataFrame: df_left, df_right = cls.align_schemas(df_left, df_right) joined = df_left.alias("left").join(df_right.alias("right"), on=key_cols, how="full_outer") # coalesced keys coalesced = [F.coalesce(F.col(f"left.{q(c)}"), F.col(f"right.{q(c)}")).alias(c) for c in key_cols] # presence flags via Column API cond_left = reduce( op.and_, [F.coalesce(F.col(f"left.{q(c)}"), F.col(f"right.{q(c)}")) == F.col(f"left.{q(c)}") for c in key_cols], ) cond_right = reduce( op.and_, [F.coalesce(F.col(f"left.{q(c)}"), F.col(f"right.{q(c)}")) == F.col(f"right.{q(c)}") for c in key_cols], ) presence_left = F.when(cond_left, 1).otherwise(0).alias("presence_in_left") presence_right = F.when(cond_right, 1).otherwise(0).alias("presence_in_right") # structs with aligned value columns left_struct = F.struct( *[F.col(f"left.{q(c)}").alias(c) if c in df_left.columns else F.lit(None).alias(c) for c in value_cols] ).alias("left_values") right_struct = F.struct( *[F.col(f"right.{q(c)}").alias(c) if c in df_right.columns else F.lit(None).alias(c) for c in value_cols] ).alias("right_values") return joined.select(*coalesced, presence_left, presence_right, left_struct, right_struct) # ------------------------------------------------------------------ # Row / column comparison helpers # ------------------------------------------------------------------ @classmethod def _add_row_matches(cls, df: DataFrame) -> DataFrame: if needs_complex_comparison(df.schema["left_values"].dataType): cmp_col = compare_structs_udf(F.col("left_values"), F.col("right_values")) else: cmp_col = F.col("left_values").eqNullSafe(F.col("right_values")) return df.withColumn("row_matches", cmp_col) @classmethod def _add_column_results(cls, df: DataFrame) -> DataFrame: left_struct, right_struct = "left_values", "right_values" fields = df.schema[left_struct].dataType.fields both_exist = (F.col("presence_in_left") == 1) & (F.col("presence_in_right") == 1) def _cmp(f): l, r = sf(left_struct, f.name), sf(right_struct, f.name) if needs_complex_comparison(f.dataType): col = compare_structs_udf(l, r) else: col = l.eqNullSafe(r) return F.when(both_exist, col).otherwise(F.lit(False)).alias(f.name) comp_struct = F.struct(*[_cmp(f) for f in fields]).alias("compared_values") return df.withColumn("compared_values", comp_struct) # ------------------------------------------------------------------ # Summary + cleaned view # ------------------------------------------------------------------ @classmethod def _mismatch_summary(cls, df: DataFrame) -> DataFrame: comp = "compared_values" cols = [f.name for f in df.schema[comp].dataType.fields] agg = [ F.coalesce(F.sum(F.when(sf(comp, c), 1).otherwise(0)), F.lit(0)).alias(f"{c}_match_count") for c in cols ] + [ F.coalesce(F.sum(F.when(~sf(comp, c), 1).otherwise(0)), F.lit(0)).alias(f"{c}_mismatch_count") for c in cols ] key_match = F.coalesce( F.sum(F.when((F.col("presence_in_left") == 1) & (F.col("presence_in_right") == 1), 1).otherwise(0)), F.lit(0), ).alias("key_columns_match_count") key_mismatch = F.coalesce( F.sum(F.when(~((F.col("presence_in_left") == 1) & (F.col("presence_in_right") == 1)), 1).otherwise(0)), F.lit(0), ).alias("key_columns_mismatch_count") rows_match = F.coalesce(F.sum(F.when(F.col("row_matches"), 1).otherwise(0)), F.lit(0)).alias("rows_matching") rows_nomatch = F.coalesce(F.sum(F.when(~F.col("row_matches"), 1).otherwise(0)), F.lit(0)).alias("rows_not_matching") return df.agg(rows_match, rows_nomatch, key_match, key_mismatch, *agg) # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ @staticmethod def _value_columns(df1: DataFrame, df2: DataFrame, key_cols: List[str]) -> List[str]: return list(set(df1.columns).union(df2.columns) - set(key_cols)) @staticmethod def _schema_json(df: DataFrame): return [{"name": f.name, "type": f.dataType.simpleString()} for f in df.schema.fields] @classmethod def create_diff( cls, expected_df: DataFrame, generated_df: DataFrame, key_columns: List[str], diff_key: str, ) -> None: value_cols = cls._value_columns(expected_df, generated_df, key_columns) left_unique, _ = cls.split_by_pk_uniqueness(generated_df, key_columns) right_unique, _ = cls.split_by_pk_uniqueness(expected_df, key_columns) joined = cls._create_joined_df(left_unique, right_unique, key_columns, value_cols) joined = cls._add_row_matches(joined) joined = cls._add_column_results(joined) summary = cls._mismatch_summary(joined) cleaned = cls._clean_joined(joined, key_columns, value_cols, generated_df, expected_df) cls.COMPUTED_DIFFS[diff_key] = { DiffKeys.JOINED.value: joined, DiffKeys.SUMMARY.value: summary, DiffKeys.CLEANED.value: cleaned, DiffKeys.EXPECTED.value: expected_df, DiffKeys.GENERATED.value: generated_df, DiffKeys.KEY_COLUMNS.value: key_columns, DiffKeys.VALUE_COLUMNS.value: value_cols, } # ------------------------------------------------------------------ # Clean view for UI / downstream consumption # ------------------------------------------------------------------ @classmethod def _clean_joined( cls, df: DataFrame, key_cols: List[str], value_cols: List[str], left_df: DataFrame, right_df: DataFrame, ) -> DataFrame: sel: List[Column] = [F.col(q(c)) for c in key_cols] for c in value_cols: if c in left_df.columns and c in right_df.columns: col = F.when( sf("compared_values", c), F.array(sf("left_values", c)), ).otherwise(F.array(sf("left_values", c), sf("right_values", c))).alias(c) elif c in left_df.columns: col = F.array(sf("left_values", c)).alias(c) else: col = F.array(sf("right_values", c)).alias(c) sel.append(col) sel += [F.col("row_matches"), F.col("presence_in_left"), F.col("presence_in_right")] return df.select(*sel) # ------------------------------------------------------------------ # Summary dictionary (for dashboards / JSON API) # ------------------------------------------------------------------ @classmethod def get_diff_summary_dict(cls, diff_key: str) -> Dict[str, Any]: diff = cls.COMPUTED_DIFFS[diff_key] summary_row = diff[DiffKeys.SUMMARY.value].collect()[0].asDict() total_rows = summary_row["rows_matching"] + summary_row["rows_not_matching"] value_cols = diff[DiffKeys.VALUE_COLUMNS.value] key_cols = diff[DiffKeys.KEY_COLUMNS.value] perfect_val_cols = sum( 1 for c in value_cols if summary_row[f"{c}_match_count"] == total_rows ) key_cols_match = summary_row["key_columns_match_count"] == total_rows def _stats(df: DataFrame): uniq, dup = cls.split_by_pk_uniqueness(df, key_cols) return { "columns": cls._schema_json(df), "rowsCount": df.count(), "uniquePkCount": uniq.count(), "duplicatePkCount": dup.count(), } pct_rows = int((summary_row["rows_matching"] / total_rows * 100) if total_rows else 0) pct_cols = int(((perfect_val_cols + (len(key_cols) if key_cols_match else 0)) / (len(value_cols) + len(key_cols)) * 100)) return { "label": diff_key, "data": { "summaryTiles": [ { "title": "Datasets matching status", "text": "Matching" if summary_row["rows_not_matching"] == 0 else "Not Matching", "badgeContent": f"{pct_rows}", "isPositive": summary_row["rows_not_matching"] == 0, "order": 0, "orderType": "MatchingStatus", "toolTip": "Percentage of rows that match between expected and generated datasets.", }, { "title": "Number of columns matching", "text": f"{perfect_val_cols + (len(key_cols) if key_cols_match else 0)}/{len(value_cols) + len(key_cols)}", "badgeContent": f"{pct_cols}", "isPositive": perfect_val_cols == len(value_cols) and key_cols_match, "order": 1, "orderType": "ColumnMatch", "toolTip": "Percentage of columns that match between expected and generated datasets.", }, { "title": "Number of rows matching", "text": f"{summary_row['rows_matching']:,}/{total_rows:,}", "badgeContent": f"{pct_rows}", "isPositive": summary_row["rows_matching"] == total_rows, "order": 2, "orderType": "RowMatch", "toolTip": "Number of rows that match between expected and generated datasets.", }, ], "expData": _stats(diff[DiffKeys.EXPECTED.value]), "genData": _stats(diff[DiffKeys.GENERATED.value]), "commonData": { "keyColumns": key_cols, "columnComparisons": { c: { "matches": summary_row[f"{c}_match_count"], "mismatches": summary_row[f"{c}_mismatch_count"], } for c in set(diff[DiffKeys.EXPECTED.value].columns) .intersection(diff[DiffKeys.GENERATED.value].columns) .intersection(value_cols) }, "rowsMatchingCount": summary_row["rows_matching"], "rowsMismatchingCount": summary_row["rows_not_matching"], "keyColumnsMatchCount": summary_row["key_columns_match_count"], "keyColumnsMismatchCount": summary_row["key_columns_mismatch_count"], }, }, } # ------------------------------------------------------------------ # Optional registration helper (kept for backwards compatibility) # ------------------------------------------------------------------ @classmethod def datasampleloader_register(cls, diff_key: str, _type: DiffKeys) -> str: from .datasampleloader import DataSampleLoaderLib # local import to avoid hard dep df = cls.COMPUTED_DIFFS[diff_key][_type.value] key = str(uuid.uuid4()) DataSampleLoaderLib.register(key=key, df=df, create_truncated_columns=False) return key
Editor is loading...
Leave a Comment