"""DataFrameDiff – robust dataframe comparison for Spark / Databricks.

Highlights
----------
* **Column‑name safe** – works with any legal identifier (spaces, dots, hyphens, back‑ticks …)
* **Complex‑type aware** – compares structs, arrays and maps with vectorised pandas UDFs when needed.
* **One‑stop API** – `DataFrameDiff.create_diff()` and `DataFrameDiff.get_diff_summary_dict()`
* **No hidden string SQL** – all column/field access goes through helper functions that guarantee safety.

Author: Prophecy / Databricks Solutions Engineering
"""
from __future__ import annotations
from functools import reduce
import operator as op
import uuid
import enum
from typing import List, Tuple, Dict, Any
import numpy as np
import pandas as pd
from pyspark.sql import DataFrame, Column, functions as F
from pyspark.sql.types import (
DataType,
NullType,
BooleanType,
IntegerType,
LongType,
FloatType,
DoubleType,
DecimalType,
DateType,
TimestampType,
StringType,
MapType,
ArrayType,
StructType,
)
# ---------------------------------------------------------------------------
# Helpers for **safe** identifier / struct handling
# ---------------------------------------------------------------------------
def q(name: str) -> str:
"""Quote a Spark identifier so that any character becomes legal in string SQL."""
return f"`{name.replace('`', '``')}`"
def sf(struct_col: str | Column, field_name: str) -> Column:
"""Safely extract *field_name* from *struct_col* without dotted strings."""
if isinstance(struct_col, str):
struct_col = F.col(struct_col)
return struct_col.getField(field_name)
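# e.g. sf("left_values", "first name") is equivalent to
# F.col("left_values").getField("first name") – no dotted-string parsing that
# breaks on exotic field names.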
# ---------------------------------------------------------------------------
# Complex‑type comparison helpers
# ---------------------------------------------------------------------------
def compare_structs_vectorized(left_series: pd.Series, right_series: pd.Series) -> pd.Series:
    """Vectorised recursive comparison for complex types (used in a pandas UDF).

    Spark hands top-level StructType columns to pandas UDFs as pandas
    DataFrames and ArrayType values as numpy arrays, so both shapes are
    normalised before the recursive element-wise comparison.
    """
    def _normalize(s):
        if isinstance(s, pd.DataFrame):  # StructType column → one dict per row
            return pd.Series(s.to_dict("records"), index=s.index)
        return s

    def _cmp(l: Any, r: Any) -> bool:
        if l is None and r is None:
            return True
        if l is None or r is None:
            return False
        if isinstance(l, dict) and isinstance(r, dict):
            if set(l.keys()) != set(r.keys()):
                return False
            return all(_cmp(l[k], r[k]) for k in l)
        if isinstance(l, (list, tuple, np.ndarray)) and isinstance(r, (list, tuple, np.ndarray)):
            if len(l) != len(r):
                return False
            return all(_cmp(li, ri) for li, ri in zip(l, r))
        return bool(l == r)

    return _normalize(left_series).combine(_normalize(right_series), _cmp)
try:
    # Spark ≥ 3.0 exposes pandas_udf under pyspark.sql.pandas.functions
from pyspark.sql.pandas.functions import pandas_udf
except ImportError:
from pyspark.sql.functions import pandas_udf # type: ignore
compare_structs_udf = pandas_udf(compare_structs_vectorized, "boolean")
def needs_complex_comparison(dt: DataType) -> bool:
"""Return *True* if *dt* (possibly nested) contains MapType/ArrayType/StructType."""
if isinstance(dt, MapType):
return True
if isinstance(dt, ArrayType):
return needs_complex_comparison(dt.elementType)
if isinstance(dt, StructType):
return any(needs_complex_comparison(f.dataType) for f in dt.fields)
return False
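# e.g. needs_complex_comparison(ArrayType(IntegerType())) is False – arrays of
# scalars compare fine via eqNullSafe – whereas MapType columns (which Spark
# cannot test for equality natively) always take the pandas-UDF path.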
# ---------------------------------------------------------------------------
# Public enum – keys used in COMPUTED_DIFFS
# ---------------------------------------------------------------------------
class DiffKeys(enum.Enum):
JOINED = "joined"
SUMMARY = "summary"
CLEANED = "cleaned"
EXPECTED = "expected"
GENERATED = "generated"
KEY_COLUMNS = "keyCols"
VALUE_COLUMNS = "valueCols"
# ---------------------------------------------------------------------------
# Main class
# ---------------------------------------------------------------------------
class DataFrameDiff:
"""Compare two Spark DataFrames safely, even with the strangest column names."""
COMPUTED_DIFFS: Dict[str, Dict[str, Any]] = {}
# ---------------------------------------------------------------------
# Schema alignment helpers
# ---------------------------------------------------------------------
@staticmethod
def _precedence(dt: DataType) -> int:
if isinstance(dt, NullType):
return 0
if isinstance(dt, BooleanType):
return 1
if isinstance(dt, IntegerType):
return 2
if isinstance(dt, LongType):
return 3
if isinstance(dt, FloatType):
return 4
if isinstance(dt, DoubleType):
return 5
if isinstance(dt, DecimalType):
return 6
if isinstance(dt, DateType):
return 7
if isinstance(dt, TimestampType):
return 8
if isinstance(dt, StringType):
return 9
return 99
@classmethod
def _find_common_type(cls, dt1: DataType, dt2: DataType) -> DataType:
if dt1 == dt2:
return dt1
if isinstance(dt1, DecimalType) and isinstance(dt2, DecimalType):
return DecimalType(max(dt1.precision, dt2.precision), max(dt1.scale, dt2.scale))
if isinstance(dt1, NullType):
return dt2
if isinstance(dt2, NullType):
return dt1
prec1, prec2 = cls._precedence(dt1), cls._precedence(dt2)
numeric = (BooleanType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
if isinstance(dt1, numeric) and isinstance(dt2, numeric):
return dt1 if prec1 >= prec2 else dt2
if (isinstance(dt1, DateType) and isinstance(dt2, TimestampType)) or (
isinstance(dt2, DateType) and isinstance(dt1, TimestampType)
):
return TimestampType()
return dt1 if prec1 > prec2 else dt2
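    # e.g. _find_common_type(IntegerType(), DoubleType()) -> DoubleType();
    # _find_common_type(DateType(), TimestampType()) -> TimestampType();
    # two DecimalTypes widen to max(precision) / max(scale).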
# ------------------------------------------------------------------
# Public helpers
# ------------------------------------------------------------------
@classmethod
def align_schemas(cls, df1: DataFrame, df2: DataFrame) -> Tuple[DataFrame, DataFrame]:
common = set(df1.columns).intersection(df2.columns)
for c in common:
dt = cls._find_common_type(df1.schema[c].dataType, df2.schema[c].dataType)
if df1.schema[c].dataType != dt:
df1 = df1.withColumn(c, F.col(q(c)).cast(dt))
if df2.schema[c].dataType != dt:
df2 = df2.withColumn(c, F.col(q(c)).cast(dt))
return df1, df2
@staticmethod
def _unique(df: DataFrame, base: str) -> str:
name, i = base, 0
while name in df.columns:
name = f"{base}_diff_{i}"
i += 1
return name
@classmethod
def split_by_pk_uniqueness(cls, df: DataFrame, key_cols: List[str]) -> Tuple[DataFrame, DataFrame]:
        # Build a null-safe representation for each key; track the generated
        # helper names explicitly instead of prefix-matching, so pre-existing
        # "__ns_*" columns in the input can never leak into the group-by keys.
        ns_cols: List[str] = []
        for col in key_cols:
            ns = cls._unique(df, f"__ns_{col}")
            ns_cols.append(ns)
            df = df.withColumn(ns, F.struct(F.col(q(col)).isNull().alias("is_null"), F.col(q(col)).alias("val")))
cnt = cls._unique(df, "__cnt__")
pk_counts = df.groupBy(*ns_cols).agg(F.count("*").alias(cnt))
once = pk_counts.filter(F.col(cnt) == 1).select(*ns_cols)
many = pk_counts.filter(F.col(cnt) > 1).select(*ns_cols)
unique_df = df.join(once, on=ns_cols, how="inner").drop(*ns_cols)
dup_df = df.join(many, on=ns_cols, how="inner").drop(*ns_cols)
return unique_df, dup_df
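    # Illustration (hypothetical rows): with key_cols=["id"] and ids
    # [1, 1, 2, None], `unique_df` keeps the id-2 and id-None rows while
    # `dup_df` keeps both id-1 rows; the struct wrapper lets the null key
    # group and join like any ordinary value.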
# ------------------------------------------------------------------
# Core diff computation
# ------------------------------------------------------------------
@classmethod
def _create_joined_df(
cls,
df_left: DataFrame,
df_right: DataFrame,
key_cols: List[str],
value_cols: List[str],
) -> DataFrame:
df_left, df_right = cls.align_schemas(df_left, df_right)
joined = df_left.alias("left").join(df_right.alias("right"), on=key_cols, how="full_outer")
# coalesced keys
coalesced = [F.coalesce(F.col(f"left.{q(c)}"), F.col(f"right.{q(c)}")).alias(c) for c in key_cols]
# presence flags via Column API
cond_left = reduce(
op.and_,
[F.coalesce(F.col(f"left.{q(c)}"), F.col(f"right.{q(c)}")) == F.col(f"left.{q(c)}") for c in key_cols],
)
cond_right = reduce(
op.and_,
[F.coalesce(F.col(f"left.{q(c)}"), F.col(f"right.{q(c)}")) == F.col(f"right.{q(c)}") for c in key_cols],
)
presence_left = F.when(cond_left, 1).otherwise(0).alias("presence_in_left")
presence_right = F.when(cond_right, 1).otherwise(0).alias("presence_in_right")
# structs with aligned value columns
left_struct = F.struct(
*[F.col(f"left.{q(c)}").alias(c) if c in df_left.columns else F.lit(None).alias(c) for c in value_cols]
).alias("left_values")
right_struct = F.struct(
*[F.col(f"right.{q(c)}").alias(c) if c in df_right.columns else F.lit(None).alias(c) for c in value_cols]
).alias("right_values")
return joined.select(*coalesced, presence_left, presence_right, left_struct, right_struct)
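    # Resulting shape: <coalesced key columns>, presence_in_left,
    # presence_in_right, left_values / right_values (structs holding every
    # value column, with lit(NULL) stand-ins for columns a side lacks).
    # Caveat: rows with a NULL key never join, and both presence flags
    # evaluate to 0 for them, so they always count as unmatched.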
# ------------------------------------------------------------------
# Row / column comparison helpers
# ------------------------------------------------------------------
@classmethod
def _add_row_matches(cls, df: DataFrame) -> DataFrame:
if needs_complex_comparison(df.schema["left_values"].dataType):
cmp_col = compare_structs_udf(F.col("left_values"), F.col("right_values"))
else:
cmp_col = F.col("left_values").eqNullSafe(F.col("right_values"))
return df.withColumn("row_matches", cmp_col)
@classmethod
def _add_column_results(cls, df: DataFrame) -> DataFrame:
left_struct, right_struct = "left_values", "right_values"
fields = df.schema[left_struct].dataType.fields
both_exist = (F.col("presence_in_left") == 1) & (F.col("presence_in_right") == 1)
def _cmp(f):
l, r = sf(left_struct, f.name), sf(right_struct, f.name)
if needs_complex_comparison(f.dataType):
col = compare_structs_udf(l, r)
else:
col = l.eqNullSafe(r)
return F.when(both_exist, col).otherwise(F.lit(False)).alias(f.name)
comp_struct = F.struct(*[_cmp(f) for f in fields]).alias("compared_values")
return df.withColumn("compared_values", comp_struct)
# ------------------------------------------------------------------
# Summary + cleaned view
# ------------------------------------------------------------------
@classmethod
def _mismatch_summary(cls, df: DataFrame) -> DataFrame:
comp = "compared_values"
cols = [f.name for f in df.schema[comp].dataType.fields]
agg = [
F.coalesce(F.sum(F.when(sf(comp, c), 1).otherwise(0)), F.lit(0)).alias(f"{c}_match_count")
for c in cols
] + [
F.coalesce(F.sum(F.when(~sf(comp, c), 1).otherwise(0)), F.lit(0)).alias(f"{c}_mismatch_count")
for c in cols
]
key_match = F.coalesce(
F.sum(F.when((F.col("presence_in_left") == 1) & (F.col("presence_in_right") == 1), 1).otherwise(0)),
F.lit(0),
).alias("key_columns_match_count")
key_mismatch = F.coalesce(
F.sum(F.when(~((F.col("presence_in_left") == 1) & (F.col("presence_in_right") == 1)), 1).otherwise(0)),
F.lit(0),
).alias("key_columns_mismatch_count")
rows_match = F.coalesce(F.sum(F.when(F.col("row_matches"), 1).otherwise(0)), F.lit(0)).alias("rows_matching")
rows_nomatch = F.coalesce(F.sum(F.when(~F.col("row_matches"), 1).otherwise(0)), F.lit(0)).alias("rows_not_matching")
return df.agg(rows_match, rows_nomatch, key_match, key_mismatch, *agg)
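    # Produces a single-row frame: rows_matching, rows_not_matching,
    # key_columns_match_count, key_columns_mismatch_count, plus one
    # "<col>_match_count" / "<col>_mismatch_count" pair per value column.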
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
@staticmethod
def _value_columns(df1: DataFrame, df2: DataFrame, key_cols: List[str]) -> List[str]:
        # Preserve a deterministic order (df1's columns first, then df2-only
        # columns); raw set arithmetic would shuffle the result between runs.
        keys = set(key_cols)
        return [c for c in dict.fromkeys([*df1.columns, *df2.columns]) if c not in keys]
@staticmethod
def _schema_json(df: DataFrame):
return [{"name": f.name, "type": f.dataType.simpleString()} for f in df.schema.fields]
@classmethod
def create_diff(
cls,
expected_df: DataFrame,
generated_df: DataFrame,
key_columns: List[str],
diff_key: str,
) -> None:
value_cols = cls._value_columns(expected_df, generated_df, key_columns)
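        # Duplicate-key rows are set aside on each side; only key-unique rows
        # take part in the join ("left" = generated, "right" = expected).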
left_unique, _ = cls.split_by_pk_uniqueness(generated_df, key_columns)
right_unique, _ = cls.split_by_pk_uniqueness(expected_df, key_columns)
joined = cls._create_joined_df(left_unique, right_unique, key_columns, value_cols)
joined = cls._add_row_matches(joined)
joined = cls._add_column_results(joined)
summary = cls._mismatch_summary(joined)
cleaned = cls._clean_joined(joined, key_columns, value_cols, generated_df, expected_df)
cls.COMPUTED_DIFFS[diff_key] = {
DiffKeys.JOINED.value: joined,
DiffKeys.SUMMARY.value: summary,
DiffKeys.CLEANED.value: cleaned,
DiffKeys.EXPECTED.value: expected_df,
DiffKeys.GENERATED.value: generated_df,
DiffKeys.KEY_COLUMNS.value: key_columns,
DiffKeys.VALUE_COLUMNS.value: value_cols,
}
# ------------------------------------------------------------------
# Clean view for UI / downstream consumption
# ------------------------------------------------------------------
@classmethod
def _clean_joined(
cls,
df: DataFrame,
key_cols: List[str],
value_cols: List[str],
left_df: DataFrame,
right_df: DataFrame,
) -> DataFrame:
sel: List[Column] = [F.col(q(c)) for c in key_cols]
for c in value_cols:
if c in left_df.columns and c in right_df.columns:
col = F.when(
sf("compared_values", c),
F.array(sf("left_values", c)),
).otherwise(F.array(sf("left_values", c), sf("right_values", c))).alias(c)
elif c in left_df.columns:
col = F.array(sf("left_values", c)).alias(c)
else:
col = F.array(sf("right_values", c)).alias(c)
sel.append(col)
sel += [F.col("row_matches"), F.col("presence_in_left"), F.col("presence_in_right")]
return df.select(*sel)
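    # Each value cell in the cleaned view is an array: [value] where the two
    # sides agree (or only one side has the column), [left, right] where they
    # differ – convenient for rendering side-by-side diffs.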
# ------------------------------------------------------------------
# Summary dictionary (for dashboards / JSON API)
# ------------------------------------------------------------------
@classmethod
def get_diff_summary_dict(cls, diff_key: str) -> Dict[str, Any]:
diff = cls.COMPUTED_DIFFS[diff_key]
summary_row = diff[DiffKeys.SUMMARY.value].collect()[0].asDict()
total_rows = summary_row["rows_matching"] + summary_row["rows_not_matching"]
value_cols = diff[DiffKeys.VALUE_COLUMNS.value]
key_cols = diff[DiffKeys.KEY_COLUMNS.value]
perfect_val_cols = sum(
1 for c in value_cols if summary_row[f"{c}_match_count"] == total_rows
)
key_cols_match = summary_row["key_columns_match_count"] == total_rows
def _stats(df: DataFrame):
uniq, dup = cls.split_by_pk_uniqueness(df, key_cols)
return {
"columns": cls._schema_json(df),
"rowsCount": df.count(),
"uniquePkCount": uniq.count(),
"duplicatePkCount": dup.count(),
}
        pct_rows = int((summary_row["rows_matching"] / total_rows * 100) if total_rows else 0)
        matched_cols = perfect_val_cols + (len(key_cols) if key_cols_match else 0)
        total_cols = len(value_cols) + len(key_cols)
        # Guard against an empty schema so the ratio never divides by zero.
        pct_cols = int(matched_cols / total_cols * 100) if total_cols else 0
return {
"label": diff_key,
"data": {
"summaryTiles": [
{
"title": "Datasets matching status",
"text": "Matching" if summary_row["rows_not_matching"] == 0 else "Not Matching",
"badgeContent": f"{pct_rows}",
"isPositive": summary_row["rows_not_matching"] == 0,
"order": 0,
"orderType": "MatchingStatus",
"toolTip": "Percentage of rows that match between expected and generated datasets.",
},
{
"title": "Number of columns matching",
"text": f"{perfect_val_cols + (len(key_cols) if key_cols_match else 0)}/{len(value_cols) + len(key_cols)}",
"badgeContent": f"{pct_cols}",
"isPositive": perfect_val_cols == len(value_cols) and key_cols_match,
"order": 1,
"orderType": "ColumnMatch",
"toolTip": "Percentage of columns that match between expected and generated datasets.",
},
{
"title": "Number of rows matching",
"text": f"{summary_row['rows_matching']:,}/{total_rows:,}",
"badgeContent": f"{pct_rows}",
"isPositive": summary_row["rows_matching"] == total_rows,
"order": 2,
"orderType": "RowMatch",
"toolTip": "Number of rows that match between expected and generated datasets.",
},
],
"expData": _stats(diff[DiffKeys.EXPECTED.value]),
"genData": _stats(diff[DiffKeys.GENERATED.value]),
"commonData": {
"keyColumns": key_cols,
"columnComparisons": {
c: {
"matches": summary_row[f"{c}_match_count"],
"mismatches": summary_row[f"{c}_mismatch_count"],
}
for c in set(diff[DiffKeys.EXPECTED.value].columns)
.intersection(diff[DiffKeys.GENERATED.value].columns)
.intersection(value_cols)
},
"rowsMatchingCount": summary_row["rows_matching"],
"rowsMismatchingCount": summary_row["rows_not_matching"],
"keyColumnsMatchCount": summary_row["key_columns_match_count"],
"keyColumnsMismatchCount": summary_row["key_columns_mismatch_count"],
},
},
}
# ------------------------------------------------------------------
# Optional registration helper (kept for backwards compatibility)
# ------------------------------------------------------------------
@classmethod
def datasampleloader_register(cls, diff_key: str, _type: DiffKeys) -> str:
from .datasampleloader import DataSampleLoaderLib # local import to avoid hard dep
df = cls.COMPUTED_DIFFS[diff_key][_type.value]
key = str(uuid.uuid4())
DataSampleLoaderLib.register(key=key, df=df, create_truncated_columns=False)
return key