"""
Statistics computation with minimal DataFrame scans.
"""
from typing import Any
from py4j.protocol import Py4JError, Py4JJavaError
from pyspark.sql import DataFrame
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql import types as t
from pyspark.sql.utils import AnalysisException
from .constants import (
APPROX_DISTINCT_RSD,
DEFAULT_TOP_VALUES_LIMIT,
ID_COLUMN_UNIQUENESS_THRESHOLD,
OUTLIER_IQR_MULTIPLIER,
PATTERNS,
QUALITY_OUTLIER_PENALTY_MAX,
)
from .exceptions import SparkOperationError, StatisticsError
from .logging import get_logger
from .utils import escape_column_name
# Check if median function is available (PySpark 3.4.0+)
try:
_median_check = F.median
HAS_MEDIAN = True
except AttributeError:
HAS_MEDIAN = False
logger = get_logger(__name__)
[docs]
class StatisticsComputer:
"""Computes statistics for DataFrame columns using type-specific calculators."""
[docs]
def __init__(
self,
dataframe: DataFrame,
total_rows: int | None = None,
cache_manager: Any = None,
):
"""
Initialize with a PySpark DataFrame.
Args:
dataframe: PySpark DataFrame to compute statistics for
total_rows: Cached row count to avoid recomputation
cache_manager: Optional CacheManager for performance optimization
"""
self.df = dataframe
self._total_rows = total_rows
self._cache_manager = cache_manager
self._column_types = {
field.name: field.dataType for field in self.df.schema.fields
}
logger.debug(
f"StatisticsComputer initialized with {'cached' if total_rows else 'lazy'} row count"
+ (", cache manager enabled" if cache_manager else "")
)
def _get_total_rows(self) -> int:
"""Get total row count, computing if not cached."""
if self._total_rows is None:
logger.debug("Computing DataFrame row count")
self._total_rows = self.df.count()
logger.debug(f"Row count computed: {self._total_rows:,}")
return self._total_rows
[docs]
def compute_all_columns_batch(
self,
columns: list[str] | None = None,
include_advanced: bool = True,
include_quality: bool = True,
progress_tracker: Any = None,
) -> dict[str, dict[str, Any]]:
"""
Compute statistics for multiple columns with minimal DataFrame scans.
Args:
columns: List of columns to profile. If None, profiles all columns.
include_advanced: Include advanced statistics (always True)
include_quality: Include data quality metrics
progress_tracker: Optional progress tracker for reporting progress
Returns:
Dictionary mapping column names to their statistics
"""
if columns is None:
columns = self.df.columns
logger.debug("Computing basic statistics")
logger.info(f"Starting optimized computation for {len(columns)} columns")
# Filter out non-existent columns
valid_columns = [col for col in columns if col in self._column_types]
if len(valid_columns) < len(columns):
invalid = set(columns) - set(valid_columns)
logger.warning(f"Columns not found in DataFrame: {invalid}")
total_rows = self._get_total_rows()
# Update progress tracker with actual column count
if progress_tracker:
progress_tracker.total_items = len(valid_columns)
progress_tracker.start()
# Build aggregation expressions for all columns
agg_exprs = self._build_aggregation_expressions(
valid_columns, include_quality, include_advanced
)
# Handle empty DataFrame or no valid columns
if not valid_columns:
logger.warning("No valid columns to process")
return {}
if not agg_exprs:
logger.warning("No aggregation expressions generated")
return {}
try:
# Execute single aggregation for all statistics
logger.debug(f"Executing aggregation with {len(agg_exprs)} expressions")
if self._cache_manager:
logger.debug("Using cached DataFrame for statistics computation")
if progress_tracker:
progress_tracker.update("Executing aggregations")
result_row = self.df.agg(*agg_exprs).collect()[0]
# Unpack results into column dictionaries
results = self._unpack_results(
result_row,
valid_columns,
total_rows,
include_quality,
include_advanced,
progress_tracker,
)
# Compute special cases that require separate scans
special_stats = self._compute_special_cases(
valid_columns, results, include_advanced
)
# Merge special statistics
for col_name, special in special_stats.items():
if col_name in results:
results[col_name].update(special)
logger.info(f"Computation completed for {len(results)} columns")
return results
except (AnalysisException, Py4JError, Py4JJavaError) as e:
logger.error(f"Spark error during batch computation: {e!s}")
raise SparkOperationError(
f"Failed to compute statistics in batch: {e!s}", e
) from e
except Exception as e:
logger.error(f"Unexpected error during batch computation: {e!s}")
raise StatisticsError(
f"Failed to compute statistics in batch: {e!s}"
) from e
def _build_aggregation_expressions(
self, columns: list[str], include_quality: bool, include_advanced: bool
) -> list:
"""Build aggregation expressions for all columns in a single pass."""
agg_exprs = []
for col_name in columns:
col_type = self._column_types[col_name]
escaped = escape_column_name(col_name)
# Basic statistics for all column types
agg_exprs.extend(
[
F.count(F.col(escaped)).alias(f"{col_name}__non_null_count"),
F.count(F.when(F.col(escaped).isNull(), 1)).alias(
f"{col_name}__null_count"
),
F.approx_count_distinct(
F.col(escaped), rsd=APPROX_DISTINCT_RSD
).alias(f"{col_name}__distinct_count"),
]
)
# Numeric column statistics
if isinstance(col_type, t.NumericType):
agg_exprs.extend(
self._build_numeric_expressions(col_name, escaped, include_advanced)
)
if include_quality:
agg_exprs.extend(
self._build_numeric_quality_expressions(col_name, escaped)
)
# String column statistics
elif isinstance(col_type, t.StringType):
agg_exprs.extend(
self._build_string_expressions(col_name, escaped, include_advanced)
)
if include_quality:
agg_exprs.extend(
self._build_string_quality_expressions(col_name, escaped)
)
# Temporal column statistics
elif isinstance(col_type, t.TimestampType | t.DateType):
agg_exprs.extend(self._build_temporal_expressions(col_name, escaped))
return agg_exprs
def _build_numeric_expressions(
self, col_name: str, escaped: str, include_advanced: bool
) -> list:
"""Build numeric-specific aggregation expressions."""
exprs = [
F.min(F.col(escaped)).alias(f"{col_name}__min"),
F.max(F.col(escaped)).alias(f"{col_name}__max"),
F.mean(F.col(escaped)).alias(f"{col_name}__mean"),
F.stddev(F.col(escaped)).alias(f"{col_name}__std"),
F.sum(F.col(escaped)).alias(f"{col_name}__sum"),
F.count(F.when(F.col(escaped) == 0, 1)).alias(f"{col_name}__zero_count"),
F.count(F.when(F.col(escaped) < 0, 1)).alias(f"{col_name}__negative_count"),
]
if include_advanced:
# Add advanced statistics
exprs.extend(
[
F.skewness(F.col(escaped)).alias(f"{col_name}__skewness"),
F.kurtosis(F.col(escaped)).alias(f"{col_name}__kurtosis"),
F.variance(F.col(escaped)).alias(f"{col_name}__variance"),
]
)
# Add median
if HAS_MEDIAN:
exprs.append(F.median(F.col(escaped)).alias(f"{col_name}__median"))
else:
exprs.append(
F.expr(f"percentile_approx({escaped}, 0.5)").alias(
f"{col_name}__median"
)
)
# Add percentiles
percentiles = [(0.25, "q1"), (0.75, "q3"), (0.05, "p5"), (0.95, "p95")]
for p_val, p_name in percentiles:
exprs.append(
F.expr(f"percentile_approx({escaped}, {p_val})").alias(
f"{col_name}__{p_name}"
)
)
return exprs
def _build_numeric_quality_expressions(self, col_name: str, escaped: str) -> list:
"""Build numeric quality-specific expressions."""
return [
F.count(F.when(F.isnan(F.col(escaped)), 1)).alias(f"{col_name}__nan_count"),
F.count(F.when(F.col(escaped) == float("inf"), 1)).alias(
f"{col_name}__inf_count"
),
F.count(F.when(F.col(escaped) == float("-inf"), 1)).alias(
f"{col_name}__neg_inf_count"
),
]
def _build_string_expressions(
self, col_name: str, escaped: str, include_advanced: bool
) -> list:
"""Build string-specific aggregation expressions."""
exprs = [
F.min(F.length(F.col(escaped))).alias(f"{col_name}__min_length"),
F.max(F.length(F.col(escaped))).alias(f"{col_name}__max_length"),
F.mean(F.length(F.col(escaped))).alias(f"{col_name}__avg_length"),
F.count(F.when(F.col(escaped) == "", 1)).alias(f"{col_name}__empty_count"),
]
if include_advanced:
# Add advanced string statistics
exprs.append(
F.count(F.when(F.trim(F.col(escaped)) != F.col(escaped), 1)).alias(
f"{col_name}__has_whitespace_count"
)
)
# Pattern detection
exprs.extend(
[
F.count(F.when(F.col(escaped).rlike(PATTERNS["email"]), 1)).alias(
f"{col_name}__email_count"
),
F.count(F.when(F.col(escaped).rlike(PATTERNS["url"]), 1)).alias(
f"{col_name}__url_count"
),
F.count(F.when(F.col(escaped).rlike(PATTERNS["phone"]), 1)).alias(
f"{col_name}__phone_like_count"
),
F.count(
F.when(F.col(escaped).rlike(PATTERNS["numeric_string"]), 1)
).alias(f"{col_name}__numeric_string_count"),
F.count(
F.when(
(F.col(escaped).isNotNull())
& (F.col(escaped) == F.upper(F.col(escaped))),
1,
)
).alias(f"{col_name}__uppercase_count"),
F.count(
F.when(
(F.col(escaped).isNotNull())
& (F.col(escaped) == F.lower(F.col(escaped))),
1,
)
).alias(f"{col_name}__lowercase_count"),
]
)
return exprs
def _build_string_quality_expressions(self, col_name: str, escaped: str) -> list:
"""Build string quality-specific expressions."""
return [
F.count(F.when(F.trim(F.col(escaped)) == "", 1)).alias(
f"{col_name}__blank_count"
),
F.count(F.when(F.col(escaped).rlike(r"[^\x00-\x7F]"), 1)).alias(
f"{col_name}__non_ascii_count"
),
F.count(F.when(F.length(F.col(escaped)) == 1, 1)).alias(
f"{col_name}__single_char_count"
),
]
def _build_temporal_expressions(self, col_name: str, escaped: str) -> list:
"""Build temporal-specific aggregation expressions."""
return [
F.min(F.col(escaped)).alias(f"{col_name}__min_date"),
F.max(F.col(escaped)).alias(f"{col_name}__max_date"),
]
def _unpack_results(
self,
result_row: Any,
columns: list[str],
total_rows: int,
include_quality: bool,
include_advanced: bool,
progress_tracker: Any = None,
) -> dict[str, dict[str, Any]]:
"""Unpack flat aggregation results into column-specific dictionaries."""
results = {}
for col_name in columns:
# Update progress for each column
if progress_tracker:
progress_tracker.update(f"Processing {col_name}")
col_type = self._column_types[col_name]
# Basic statistics - common to all types
non_null_count = result_row[f"{col_name}__non_null_count"]
null_count = result_row[f"{col_name}__null_count"]
distinct_count = result_row[f"{col_name}__distinct_count"]
stats: dict[str, Any] = {
"data_type": str(col_type),
"total_count": int(total_rows),
"non_null_count": non_null_count,
"null_count": null_count,
"null_percentage": (
(null_count / total_rows * 100) if total_rows > 0 else 0.0
),
"distinct_count": distinct_count,
"distinct_percentage": (
(distinct_count / non_null_count * 100)
if non_null_count > 0
else 0.0
),
}
# Type-specific statistics
if isinstance(col_type, t.NumericType):
self._unpack_numeric(
stats,
result_row,
col_name,
total_rows,
include_advanced,
include_quality,
)
elif isinstance(col_type, t.StringType):
self._unpack_string(
stats, result_row, col_name, include_advanced, include_quality
)
elif isinstance(col_type, t.TimestampType | t.DateType):
self._unpack_temporal(stats, result_row, col_name)
# Add quality score if requested
if include_quality:
quality_metrics = self._calculate_quality_metrics(
stats, col_type, col_name
)
stats["quality"] = quality_metrics
results[col_name] = stats
return results
def _unpack_numeric(
self,
stats: dict[str, Any],
result_row: Any,
col_name: str,
total_rows: int,
include_advanced: bool,
include_quality: bool,
) -> None:
"""Unpack all numeric statistics and quality metrics from result row."""
# Basic numeric stats
stats.update(
{
"min": result_row[f"{col_name}__min"],
"max": result_row[f"{col_name}__max"],
"mean": result_row[f"{col_name}__mean"],
"std": (
result_row[f"{col_name}__std"]
if result_row[f"{col_name}__std"] is not None
else 0.0
),
"sum": result_row[f"{col_name}__sum"],
"zero_count": result_row[f"{col_name}__zero_count"],
"negative_count": result_row[f"{col_name}__negative_count"],
}
)
# Advanced statistics
if include_advanced:
stats.update(
{
"median": result_row[f"{col_name}__median"],
"q1": result_row[f"{col_name}__q1"],
"q3": result_row[f"{col_name}__q3"],
"p5": result_row[f"{col_name}__p5"],
"p95": result_row[f"{col_name}__p95"],
"skewness": result_row[f"{col_name}__skewness"],
"kurtosis": result_row[f"{col_name}__kurtosis"],
"variance": result_row[f"{col_name}__variance"],
}
)
# Quality metrics
if include_quality:
nan_count = result_row[f"{col_name}__nan_count"]
inf_count = result_row[f"{col_name}__inf_count"]
neg_inf_count = result_row[f"{col_name}__neg_inf_count"]
if "quality" not in stats:
stats["quality"] = {}
stats["quality"].update(
{
"nan_count": nan_count,
"infinity_count": inf_count + neg_inf_count,
}
)
# Derived statistics
if stats["min"] is not None and stats["max"] is not None:
stats["range"] = stats["max"] - stats["min"]
if (
include_advanced
and stats.get("q1") is not None
and stats.get("q3") is not None
):
stats["iqr"] = stats["q3"] - stats["q1"]
# Calculate outlier bounds (counts computed separately)
iqr = stats["iqr"]
stats["outliers"] = {
"method": "iqr",
"lower_bound": stats["q1"] - OUTLIER_IQR_MULTIPLIER * iqr,
"upper_bound": stats["q3"] + OUTLIER_IQR_MULTIPLIER * iqr,
"outlier_count": 0, # Updated in _compute_special_cases
"outlier_percentage": 0.0,
"lower_outlier_count": 0,
"upper_outlier_count": 0,
}
# Coefficient of variation
if stats["mean"] and stats["mean"] != 0 and stats["std"]:
stats["cv"] = abs(stats["std"] / stats["mean"])
def _unpack_string(
self,
stats: dict[str, Any],
result_row: Any,
col_name: str,
include_advanced: bool,
include_quality: bool,
) -> None:
"""Unpack all string statistics and quality metrics from result row."""
# Basic string stats
stats.update(
{
"min_length": result_row[f"{col_name}__min_length"],
"max_length": result_row[f"{col_name}__max_length"],
"avg_length": result_row[f"{col_name}__avg_length"],
"empty_count": result_row[f"{col_name}__empty_count"],
}
)
# Advanced statistics
if include_advanced:
stats["has_whitespace_count"] = result_row[
f"{col_name}__has_whitespace_count"
]
# Pattern detection results
stats["patterns"] = {
"email_count": result_row[f"{col_name}__email_count"],
"url_count": result_row[f"{col_name}__url_count"],
"phone_like_count": result_row[f"{col_name}__phone_like_count"],
"numeric_string_count": result_row[f"{col_name}__numeric_string_count"],
"uppercase_count": result_row[f"{col_name}__uppercase_count"],
"lowercase_count": result_row[f"{col_name}__lowercase_count"],
}
# Quality metrics
if include_quality:
if "quality" not in stats:
stats["quality"] = {}
stats["quality"].update(
{
"blank_count": result_row[f"{col_name}__blank_count"],
"non_ascii_count": result_row[f"{col_name}__non_ascii_count"],
"single_char_count": result_row[f"{col_name}__single_char_count"],
}
)
def _unpack_temporal(
self, stats: dict[str, Any], result_row: Any, col_name: str
) -> None:
"""Unpack temporal statistics from result row."""
min_date = result_row[f"{col_name}__min_date"]
max_date = result_row[f"{col_name}__max_date"]
stats.update(
{
"min_date": min_date,
"max_date": max_date,
}
)
# Calculate date range in days
if min_date and max_date:
try:
date_range_days = (max_date - min_date).days
stats["date_range_days"] = date_range_days
except (AttributeError, TypeError):
stats["date_range_days"] = None
else:
stats["date_range_days"] = None
def _calculate_quality_metrics(
self, stats: dict[str, Any], column_type: Any, col_name: str
) -> dict[str, Any]:
"""Calculate overall quality metrics for a column."""
null_percentage = stats.get("null_percentage", 0.0)
distinct_percentage = stats.get("distinct_percentage", 0.0)
non_null_count = stats.get("non_null_count", 0)
quality_metrics = stats.get("quality", {})
quality_metrics.update(
{
"completeness": 1.0 - (null_percentage / 100.0),
"uniqueness": (
distinct_percentage / 100.0 if non_null_count > 0 else 0.0
),
"null_count": stats.get("null_count", 0),
"column_type": self._get_type_name(column_type),
}
)
# Calculate quality score
quality_score = quality_metrics["completeness"]
# Penalize for outliers in numeric columns
if isinstance(column_type, t.NumericType) and "outliers" in stats:
outlier_percentage = stats["outliers"].get("outlier_percentage", 0.0)
outlier_penalty = min(
outlier_percentage / 100.0 * QUALITY_OUTLIER_PENALTY_MAX,
QUALITY_OUTLIER_PENALTY_MAX,
)
quality_score *= 1 - outlier_penalty
quality_metrics["outlier_percentage"] = outlier_percentage
# Penalize for low uniqueness in ID-like columns
if (
"id" in col_name.lower()
and quality_metrics["uniqueness"] < ID_COLUMN_UNIQUENESS_THRESHOLD
):
quality_score *= quality_metrics["uniqueness"]
quality_metrics["quality_score"] = round(quality_score, 3)
return dict(quality_metrics)
def _get_type_name(self, column_type: Any) -> str:
"""Get simplified type name for reporting."""
if isinstance(column_type, t.NumericType):
return "numeric"
if isinstance(column_type, t.StringType):
return "string"
if isinstance(column_type, t.TimestampType | t.DateType):
return "temporal"
return "other"
def _compute_special_cases(
self,
columns: list[str],
results: dict[str, dict[str, Any]],
include_advanced: bool,
) -> dict[str, dict[str, Any]]:
"""
Compute statistics that require separate scans.
Optimized to minimize the number of additional scans.
"""
special_stats: dict[str, dict[str, Any]] = {}
# Group columns by what special processing they need
numeric_cols_needing_outliers = []
string_cols_needing_top_values = []
for col_name in columns:
col_type = self._column_types[col_name]
# Numeric columns need outlier counts (only if advanced stats are requested)
if (
isinstance(col_type, t.NumericType)
and col_name in results
and include_advanced
):
if "outliers" in results[col_name]:
numeric_cols_needing_outliers.append(col_name)
# String columns need top values (only if advanced stats are requested)
elif isinstance(col_type, t.StringType) and include_advanced:
string_cols_needing_top_values.append(col_name)
# Compute outlier counts for numeric columns in one scan
if numeric_cols_needing_outliers:
outlier_stats = self._compute_outlier_counts_batch(
numeric_cols_needing_outliers, results
)
for col_name, outlier_info in outlier_stats.items():
if col_name not in special_stats:
special_stats[col_name] = {}
special_stats[col_name]["outliers"] = outlier_info
# Compute top values for string columns (requires separate groupBy per column)
if string_cols_needing_top_values:
# Process in batches to avoid too many concurrent operations
batch_size = 10
for i in range(0, len(string_cols_needing_top_values), batch_size):
batch = string_cols_needing_top_values[i : i + batch_size]
for col_name in batch:
top_values = self._get_top_values(col_name)
if col_name not in special_stats:
special_stats[col_name] = {}
special_stats[col_name]["top_values"] = top_values
return special_stats
def _compute_outlier_counts_batch(
self, columns: list[str], results: dict[str, dict[str, Any]]
) -> dict[str, dict[str, Any]]:
"""Compute outlier counts for multiple numeric columns in one scan."""
agg_exprs = []
bounds_map = {}
for col_name in columns:
if col_name in results and "outliers" in results[col_name]:
outlier_info = results[col_name]["outliers"]
lower_bound = outlier_info["lower_bound"]
upper_bound = outlier_info["upper_bound"]
bounds_map[col_name] = (lower_bound, upper_bound)
escaped = escape_column_name(col_name)
agg_exprs.extend(
[
F.count(F.when(F.col(escaped) < lower_bound, 1)).alias(
f"{col_name}__lower_outliers"
),
F.count(F.when(F.col(escaped) > upper_bound, 1)).alias(
f"{col_name}__upper_outliers"
),
]
)
if not agg_exprs:
return {}
# Execute aggregation
result_row = self.df.agg(*agg_exprs).collect()[0]
# Unpack results
outlier_results = {}
total_rows = self._get_total_rows()
for col_name in columns:
if col_name in bounds_map:
lower_count = result_row[f"{col_name}__lower_outliers"]
upper_count = result_row[f"{col_name}__upper_outliers"]
total_outliers = lower_count + upper_count
lower_bound, upper_bound = bounds_map[col_name]
outlier_results[col_name] = {
"method": "iqr",
"lower_bound": lower_bound,
"upper_bound": upper_bound,
"outlier_count": total_outliers,
"outlier_percentage": (
(total_outliers / total_rows * 100) if total_rows > 0 else 0.0
),
"lower_outlier_count": lower_count,
"upper_outlier_count": upper_count,
}
return outlier_results
def _get_top_values(
self, column_name: str, limit: int = DEFAULT_TOP_VALUES_LIMIT
) -> list[dict[str, Any]]:
"""Get top frequent values for a column."""
escaped = escape_column_name(column_name)
try:
top_values = (
self.df.filter(F.col(escaped).isNotNull())
.groupBy(column_name)
.count()
.orderBy(F.desc("count"))
.limit(limit)
.collect()
)
return [
{"value": row[column_name], "count": row["count"]} for row in top_values
]
except Exception as e:
logger.warning(f"Failed to compute top values for {column_name}: {e!s}")
return []