"""DataFrameDiff – robust dataframe comparison for Spark / Databricks.

Highlights
----------
* **Column‑name safe** – works with any legal identifier (spaces, dots, hyphens, back‑ticks …)
* **Complex‑type aware** – compares structs, arrays and maps with vectorised pandas UDFs when needed.
* **One‑stop API** – `DataFrameDiff.create_diff()` and `DataFrameDiff.get_diff_summary_dict()`
* **No hidden string SQL** – all column/field access goes through helper functions that guarantee safety.

Author: Prophecy / Databricks Solutions Engineering
"""

from __future__ import annotations

from functools import reduce
import operator as op
import uuid
import enum
from typing import List, Tuple, Dict, Any

import pandas as pd
from pyspark.sql import DataFrame, Column, functions as F
from pyspark.sql.types import (
    DataType,
    NullType,
    BooleanType,
    IntegerType,
    LongType,
    FloatType,
    DoubleType,
    DecimalType,
    DateType,
    TimestampType,
    StringType,
    MapType,
    ArrayType,
    StructType,
)

# ---------------------------------------------------------------------------
# Helpers for **safe** identifier / struct handling
# ---------------------------------------------------------------------------

def q(name: str) -> str:
    """Quote a Spark identifier so that any character becomes legal in string SQL."""
    return f"`{name.replace('`', '``')}`"
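
# Illustrative only: q("order id.v2") -> "`order id.v2`" and q("a`b") -> "`a``b`",
# so the name remains a legal identifier inside any string-SQL fragment.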


def sf(struct_col: str | Column, field_name: str) -> Column:
    """Safely extract *field_name* from *struct_col* without dotted strings."""
    if isinstance(struct_col, str):
        struct_col = F.col(struct_col)
    return struct_col.getField(field_name)


# ---------------------------------------------------------------------------
# Complex‑type comparison helpers
# ---------------------------------------------------------------------------

def compare_structs_vectorized(left_series: pd.Series, right_series: pd.Series) -> pd.Series:
    """Vectorised recursive comparison for complex types (used in a pandas UDF)."""

    def _cmp(l: Any, r: Any) -> bool:
        if l is None and r is None:
            return True
        if l is None or r is None:
            return False
        if isinstance(l, dict) and isinstance(r, dict):
            if set(l.keys()) != set(r.keys()):
                return False
            return all(_cmp(l[k], r[k]) for k in l)
        if isinstance(l, list) and isinstance(r, list):
            if len(l) != len(r):
                return False
            return all(_cmp(li, ri) for li, ri in zip(l, r))
        return l == r

    return left_series.combine(right_series, _cmp)


try:
    # Newer Spark releases expose pandas_udf under pyspark.sql.pandas.functions
    from pyspark.sql.pandas.functions import pandas_udf
except ImportError:  # fall back to the classic location on older versions
    from pyspark.sql.functions import pandas_udf  # type: ignore

compare_structs_udf = pandas_udf(compare_structs_vectorized, "boolean")


def needs_complex_comparison(dt: DataType) -> bool:
    """Return *True* if *dt* (possibly nested) contains MapType/ArrayType/StructType."""
    if isinstance(dt, MapType):
        return True
    if isinstance(dt, ArrayType):
        return needs_complex_comparison(dt.elementType)
    if isinstance(dt, StructType):
        return any(needs_complex_comparison(f.dataType) for f in dt.fields)
    return False


# ---------------------------------------------------------------------------
# Public enum – keys used in COMPUTED_DIFFS
# ---------------------------------------------------------------------------

class DiffKeys(enum.Enum):
    JOINED = "joined"
    SUMMARY = "summary"
    CLEANED = "cleaned"
    EXPECTED = "expected"
    GENERATED = "generated"
    KEY_COLUMNS = "keyCols"
    VALUE_COLUMNS = "valueCols"


# ---------------------------------------------------------------------------
# Main class
# ---------------------------------------------------------------------------

class DataFrameDiff:
    """Compare two Spark DataFrames safely, even with the strangest column names."""

    COMPUTED_DIFFS: Dict[str, Dict[str, Any]] = {}

    # ---------------------------------------------------------------------
    #  Schema alignment helpers
    # ---------------------------------------------------------------------

    @staticmethod
    def _precedence(dt: DataType) -> int:
        if isinstance(dt, NullType):
            return 0
        if isinstance(dt, BooleanType):
            return 1
        if isinstance(dt, IntegerType):
            return 2
        if isinstance(dt, LongType):
            return 3
        if isinstance(dt, FloatType):
            return 4
        if isinstance(dt, DoubleType):
            return 5
        if isinstance(dt, DecimalType):
            return 6
        if isinstance(dt, DateType):
            return 7
        if isinstance(dt, TimestampType):
            return 8
        if isinstance(dt, StringType):
            return 9
        return 99

    @classmethod
    def _find_common_type(cls, dt1: DataType, dt2: DataType) -> DataType:
        if dt1 == dt2:
            return dt1
        if isinstance(dt1, DecimalType) and isinstance(dt2, DecimalType):
            return DecimalType(max(dt1.precision, dt2.precision), max(dt1.scale, dt2.scale))
        if isinstance(dt1, NullType):
            return dt2
        if isinstance(dt2, NullType):
            return dt1
        prec1, prec2 = cls._precedence(dt1), cls._precedence(dt2)
        numeric = (BooleanType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
        if isinstance(dt1, numeric) and isinstance(dt2, numeric):
            return dt1 if prec1 >= prec2 else dt2
        if (isinstance(dt1, DateType) and isinstance(dt2, TimestampType)) or (
            isinstance(dt2, DateType) and isinstance(dt1, TimestampType)
        ):
            return TimestampType()
        return dt1 if prec1 > prec2 else dt2
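
    # Illustrative outcomes (not exhaustive): IntegerType vs DoubleType -> DoubleType;
    # DateType vs TimestampType -> TimestampType; DecimalType(10, 2) vs DecimalType(12, 0)
    # -> DecimalType(12, 2). Any other mix falls back to the higher-precedence type.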

    # ------------------------------------------------------------------
    #  Public helpers
    # ------------------------------------------------------------------

    @classmethod
    def align_schemas(cls, df1: DataFrame, df2: DataFrame) -> Tuple[DataFrame, DataFrame]:
        common = set(df1.columns).intersection(df2.columns)
        for c in common:
            dt = cls._find_common_type(df1.schema[c].dataType, df2.schema[c].dataType)
            if df1.schema[c].dataType != dt:
                df1 = df1.withColumn(c, F.col(q(c)).cast(dt))
            if df2.schema[c].dataType != dt:
                df2 = df2.withColumn(c, F.col(q(c)).cast(dt))
        return df1, df2

    @staticmethod
    def _unique(df: DataFrame, base: str) -> str:
        name, i = base, 0
        while name in df.columns:
            name = f"{base}_diff_{i}"
            i += 1
        return name

    @classmethod
    def split_by_pk_uniqueness(cls, df: DataFrame, key_cols: List[str]) -> Tuple[DataFrame, DataFrame]:
        """Split *df* into (rows whose key is unique, rows whose key occurs more than once)."""
        # Build a null‑safe representation of each key so NULL key values group together.
        # Helper columns are named by position (never by the original column name) and are
        # tracked explicitly rather than rediscovered by prefix, so exotic key names stay safe.
        ns_cols: List[str] = []
        for i, col in enumerate(key_cols):
            ns = cls._unique(df, f"__ns_{i}")
            ns_cols.append(ns)
            df = df.withColumn(ns, F.struct(F.col(q(col)).isNull().alias("is_null"), F.col(q(col)).alias("val")))
        cnt = cls._unique(df, "__cnt__")
        pk_counts = df.groupBy(*ns_cols).agg(F.count("*").alias(cnt))
        once = pk_counts.filter(F.col(cnt) == 1).select(*ns_cols)
        many = pk_counts.filter(F.col(cnt) > 1).select(*ns_cols)
        unique_df = df.join(once, on=ns_cols, how="inner").drop(*ns_cols)
        dup_df = df.join(many, on=ns_cols, how="inner").drop(*ns_cols)
        return unique_df, dup_df

    # ------------------------------------------------------------------
    #  Core diff computation
    # ------------------------------------------------------------------

    @classmethod
    def _create_joined_df(
        cls,
        df_left: DataFrame,
        df_right: DataFrame,
        key_cols: List[str],
        value_cols: List[str],
    ) -> DataFrame:
        df_left, df_right = cls.align_schemas(df_left, df_right)
        joined = df_left.alias("left").join(df_right.alias("right"), on=key_cols, how="full_outer")

        # coalesced keys
        coalesced = [F.coalesce(F.col(f"left.{q(c)}"), F.col(f"right.{q(c)}")).alias(c) for c in key_cols]

        # presence flags via Column API
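        # A key is counted as present on a side when that side's copy of the key equals the
        # coalesced key, i.e. it survived the full outer join. Note this assumes key values
        # are non-null: a row whose key columns are all NULL on one side is reported absent.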
        cond_left = reduce(
            op.and_,
            [F.coalesce(F.col(f"left.{q(c)}"), F.col(f"right.{q(c)}")) == F.col(f"left.{q(c)}") for c in key_cols],
        )
        cond_right = reduce(
            op.and_,
            [F.coalesce(F.col(f"left.{q(c)}"), F.col(f"right.{q(c)}")) == F.col(f"right.{q(c)}") for c in key_cols],
        )
        presence_left = F.when(cond_left, 1).otherwise(0).alias("presence_in_left")
        presence_right = F.when(cond_right, 1).otherwise(0).alias("presence_in_right")

        # structs with aligned value columns
        left_struct = F.struct(
            *[F.col(f"left.{q(c)}").alias(c) if c in df_left.columns else F.lit(None).alias(c) for c in value_cols]
        ).alias("left_values")
        right_struct = F.struct(
            *[F.col(f"right.{q(c)}").alias(c) if c in df_right.columns else F.lit(None).alias(c) for c in value_cols]
        ).alias("right_values")

        return joined.select(*coalesced, presence_left, presence_right, left_struct, right_struct)

    # ------------------------------------------------------------------
    #  Row / column comparison helpers
    # ------------------------------------------------------------------

    @classmethod
    def _add_row_matches(cls, df: DataFrame) -> DataFrame:
        fields = df.schema["left_values"].dataType.fields
        # `eqNullSafe` cannot compare MapType fields, so map-bearing fields are routed
        # through the pandas UDF; the per-field results are AND-ed into one row-level flag.
        per_field = [
            compare_structs_udf(sf("left_values", f.name), sf("right_values", f.name))
            if needs_complex_comparison(f.dataType)
            else sf("left_values", f.name).eqNullSafe(sf("right_values", f.name))
            for f in fields
        ]
        return df.withColumn("row_matches", reduce(op.and_, per_field, F.lit(True)))

    @classmethod
    def _add_column_results(cls, df: DataFrame) -> DataFrame:
        left_struct, right_struct = "left_values", "right_values"
        fields = df.schema[left_struct].dataType.fields
        both_exist = (F.col("presence_in_left") == 1) & (F.col("presence_in_right") == 1)
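        # Rows that exist on only one side get *False* for every value column below, so a
        # missing row shows up as a mismatch in each per-column tally.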

        def _cmp(f):
            l, r = sf(left_struct, f.name), sf(right_struct, f.name)
            if needs_complex_comparison(f.dataType):
                col = compare_structs_udf(l, r)
            else:
                col = l.eqNullSafe(r)
            return F.when(both_exist, col).otherwise(F.lit(False)).alias(f.name)

        comp_struct = F.struct(*[_cmp(f) for f in fields]).alias("compared_values")
        return df.withColumn("compared_values", comp_struct)

    # ------------------------------------------------------------------
    #  Summary + cleaned view
    # ------------------------------------------------------------------

    @classmethod
    def _mismatch_summary(cls, df: DataFrame) -> DataFrame:
        comp = "compared_values"
        cols = [f.name for f in df.schema[comp].dataType.fields]
        agg = [
            F.coalesce(F.sum(F.when(sf(comp, c), 1).otherwise(0)), F.lit(0)).alias(f"{c}_match_count")
            for c in cols
        ] + [
            F.coalesce(F.sum(F.when(~sf(comp, c), 1).otherwise(0)), F.lit(0)).alias(f"{c}_mismatch_count")
            for c in cols
        ]
        key_match = F.coalesce(
            F.sum(F.when((F.col("presence_in_left") == 1) & (F.col("presence_in_right") == 1), 1).otherwise(0)),
            F.lit(0),
        ).alias("key_columns_match_count")
        key_mismatch = F.coalesce(
            F.sum(F.when(~((F.col("presence_in_left") == 1) & (F.col("presence_in_right") == 1)), 1).otherwise(0)),
            F.lit(0),
        ).alias("key_columns_mismatch_count")
        rows_match = F.coalesce(F.sum(F.when(F.col("row_matches"), 1).otherwise(0)), F.lit(0)).alias("rows_matching")
        rows_nomatch = F.coalesce(F.sum(F.when(~F.col("row_matches"), 1).otherwise(0)), F.lit(0)).alias("rows_not_matching")
        return df.agg(rows_match, rows_nomatch, key_match, key_mismatch, *agg)
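
    # The aggregation above yields a single-row DataFrame with rows_matching,
    # rows_not_matching, key_columns_match_count, key_columns_mismatch_count and one
    # <column>_match_count / <column>_mismatch_count pair per value column.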

    # ------------------------------------------------------------------
    #  Public API
    # ------------------------------------------------------------------

    @staticmethod
    def _value_columns(df1: DataFrame, df2: DataFrame, key_cols: List[str]) -> List[str]:
        # Sorted so the value-column order (and hence the diff output) is deterministic.
        return sorted(set(df1.columns).union(df2.columns) - set(key_cols))

    @staticmethod
    def _schema_json(df: DataFrame):
        return [{"name": f.name, "type": f.dataType.simpleString()} for f in df.schema.fields]

    @classmethod
    def create_diff(
        cls,
        expected_df: DataFrame,
        generated_df: DataFrame,
        key_columns: List[str],
        diff_key: str,
    ) -> None:
        value_cols = cls._value_columns(expected_df, generated_df, key_columns)
        left_unique, _ = cls.split_by_pk_uniqueness(generated_df, key_columns)
        right_unique, _ = cls.split_by_pk_uniqueness(expected_df, key_columns)

        joined = cls._create_joined_df(left_unique, right_unique, key_columns, value_cols)
        joined = cls._add_row_matches(joined)
        joined = cls._add_column_results(joined)
        summary = cls._mismatch_summary(joined)
        cleaned = cls._clean_joined(joined, key_columns, value_cols, generated_df, expected_df)

        cls.COMPUTED_DIFFS[diff_key] = {
            DiffKeys.JOINED.value: joined,
            DiffKeys.SUMMARY.value: summary,
            DiffKeys.CLEANED.value: cleaned,
            DiffKeys.EXPECTED.value: expected_df,
            DiffKeys.GENERATED.value: generated_df,
            DiffKeys.KEY_COLUMNS.value: key_columns,
            DiffKeys.VALUE_COLUMNS.value: value_cols,
        }

    # ------------------------------------------------------------------
    #  Clean view for UI / downstream consumption
    # ------------------------------------------------------------------

    @classmethod
    def _clean_joined(
        cls,
        df: DataFrame,
        key_cols: List[str],
        value_cols: List[str],
        left_df: DataFrame,
        right_df: DataFrame,
    ) -> DataFrame:
        sel: List[Column] = [F.col(q(c)) for c in key_cols]
        for c in value_cols:
            if c in left_df.columns and c in right_df.columns:
                col = F.when(
                    sf("compared_values", c),
                    F.array(sf("left_values", c)),
                ).otherwise(F.array(sf("left_values", c), sf("right_values", c))).alias(c)
            elif c in left_df.columns:
                col = F.array(sf("left_values", c)).alias(c)
            else:
                col = F.array(sf("right_values", c)).alias(c)
            sel.append(col)
        sel += [F.col("row_matches"), F.col("presence_in_left"), F.col("presence_in_right")]
        return df.select(*sel)

    # ------------------------------------------------------------------
    #  Summary dictionary (for dashboards / JSON API)
    # ------------------------------------------------------------------

    @classmethod
    def get_diff_summary_dict(cls, diff_key: str) -> Dict[str, Any]:
        diff = cls.COMPUTED_DIFFS[diff_key]
        summary_row = diff[DiffKeys.SUMMARY.value].collect()[0].asDict()
        total_rows = summary_row["rows_matching"] + summary_row["rows_not_matching"]
        value_cols = diff[DiffKeys.VALUE_COLUMNS.value]
        key_cols = diff[DiffKeys.KEY_COLUMNS.value]

        perfect_val_cols = sum(
            1 for c in value_cols if summary_row[f"{c}_match_count"] == total_rows
        )
        key_cols_match = summary_row["key_columns_match_count"] == total_rows

        def _stats(df: DataFrame):
            uniq, dup = cls.split_by_pk_uniqueness(df, key_cols)
            return {
                "columns": cls._schema_json(df),
                "rowsCount": df.count(),
                "uniquePkCount": uniq.count(),
                "duplicatePkCount": dup.count(),
            }

        pct_rows = int((summary_row["rows_matching"] / total_rows * 100) if total_rows else 0)
        matching_cols = perfect_val_cols + (len(key_cols) if key_cols_match else 0)
        total_cols = len(value_cols) + len(key_cols)
        pct_cols = int((matching_cols / total_cols * 100) if total_cols else 0)

        return {
            "label": diff_key,
            "data": {
                "summaryTiles": [
                    {
                        "title": "Datasets matching status",
                        "text": "Matching" if summary_row["rows_not_matching"] == 0 else "Not Matching",
                        "badgeContent": f"{pct_rows}",
                        "isPositive": summary_row["rows_not_matching"] == 0,
                        "order": 0,
                        "orderType": "MatchingStatus",
                        "toolTip": "Percentage of rows that match between expected and generated datasets.",
                    },
                    {
                        "title": "Number of columns matching",
                        "text": f"{perfect_val_cols + (len(key_cols) if key_cols_match else 0)}/{len(value_cols) + len(key_cols)}",
                        "badgeContent": f"{pct_cols}",
                        "isPositive": perfect_val_cols == len(value_cols) and key_cols_match,
                        "order": 1,
                        "orderType": "ColumnMatch",
                        "toolTip": "Percentage of columns that match between expected and generated datasets.",
                    },
                    {
                        "title": "Number of rows matching",
                        "text": f"{summary_row['rows_matching']:,}/{total_rows:,}",
                        "badgeContent": f"{pct_rows}",
                        "isPositive": summary_row["rows_matching"] == total_rows,
                        "order": 2,
                        "orderType": "RowMatch",
                        "toolTip": "Number of rows that match between expected and generated datasets.",
                    },
                ],
                "expData": _stats(diff[DiffKeys.EXPECTED.value]),
                "genData": _stats(diff[DiffKeys.GENERATED.value]),
                "commonData": {
                    "keyColumns": key_cols,
                    "columnComparisons": {
                        c: {
                            "matches": summary_row[f"{c}_match_count"],
                            "mismatches": summary_row[f"{c}_mismatch_count"],
                        }
                        for c in set(diff[DiffKeys.EXPECTED.value].columns)
                        .intersection(diff[DiffKeys.GENERATED.value].columns)
                        .intersection(value_cols)
                    },
                    "rowsMatchingCount": summary_row["rows_matching"],
                    "rowsMismatchingCount": summary_row["rows_not_matching"],
                    "keyColumnsMatchCount": summary_row["key_columns_match_count"],
                    "keyColumnsMismatchCount": summary_row["key_columns_mismatch_count"],
                },
            },
        }

    # ------------------------------------------------------------------
    #  Optional registration helper (kept for backwards compatibility)
    # ------------------------------------------------------------------

    @classmethod
    def datasampleloader_register(cls, diff_key: str, _type: DiffKeys) -> str:
        from .datasampleloader import DataSampleLoaderLib  # local import to avoid hard dep

        df = cls.COMPUTED_DIFFS[diff_key][_type.value]
        key = str(uuid.uuid4())
        DataSampleLoaderLib.register(key=key, df=df, create_truncated_columns=False)
        return key
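

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only)
# ---------------------------------------------------------------------------
# Not part of the module: this assumes an active SparkSession and two small
# DataFrames, and simply shows the intended call sequence.
#
#   from pyspark.sql import SparkSession
#
#   spark = SparkSession.builder.getOrCreate()
#   expected = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "val"])
#   generated = spark.createDataFrame([(1, "a"), (2, "c")], ["id", "val"])
#
#   DataFrameDiff.create_diff(expected, generated, key_columns=["id"], diff_key="demo")
#   summary = DataFrameDiff.get_diff_summary_dict("demo")
#   DataFrameDiff.COMPUTED_DIFFS["demo"][DiffKeys.CLEANED.value].show(truncate=False)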