# Pyteryx/alteryx_runner/tools/transform/summarize_tool.py (Python, 111 lines, 3.8 KiB)

from __future__ import annotations
from typing import Dict
import polars as pl
from tools.base import BaseTool
class SummarizeTool(BaseTool):
    """Alteryx Summarize tool: group rows and compute per-group aggregates.

    Configuration is read from ``SummarizeFields/SummarizeField`` elements of
    ``self.config``. Each element carries a ``field`` (source column), an
    ``action`` (``GroupBy`` or an aggregation name) and an optional ``rename``
    for the output column (defaults to the source field name).
    """

    # Aggregations that need no extra XML attributes: action name -> builder
    # taking the source column expression. Only the requested builder is called,
    # instead of materialising every expression for each field.
    _SIMPLE_AGGS = {
        "Sum": lambda c: c.sum(),
        # "Count" counts every record in the group, nulls included. polars'
        # Expr.count() skips nulls, which would make it identical to
        # "Count Non Null", so use Expr.len() here.
        "Count": lambda c: c.len(),
        "Count Non Null": lambda c: c.drop_nulls().count(),
        "Count Distinct": lambda c: c.n_unique(),
        "Count Distinct Non Null": lambda c: c.drop_nulls().n_unique(),
        "Count Null": lambda c: c.is_null().sum(),
        "Min": lambda c: c.min(),
        "Max": lambda c: c.max(),
        "Avg": lambda c: c.mean(),
        "Average": lambda c: c.mean(),
        "Median": lambda c: c.median(),
        "Std Deviation": lambda c: c.std(),
        "Variance": lambda c: c.var(),
        "First": lambda c: c.first(),
        "Last": lambda c: c.last(),
    }

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Summarize the ``"Input"`` frame according to the tool configuration.

        Args:
            inputs: Mapping of input anchor name to DataFrame; only ``"Input"``
                is read (a missing anchor is treated as an empty frame).

        Returns:
            ``{"Output": frame}``. The input is passed through unchanged when
            there is no config or the input is empty. If no usable aggregation
            is configured, the output is the distinct group-by values sorted by
            the group columns, or an empty frame when there are no group
            fields either.
        """
        df = inputs.get("Input", pl.DataFrame())
        if self.config is None or df.is_empty():
            return {"Output": df}

        group_fields, agg_exprs = self._parse_fields(df)

        if not agg_exprs:
            if group_fields:
                # Same group ordering as the aggregating path below.
                distinct = df.select(group_fields).unique(maintain_order=True)
                return {"Output": distinct.sort(group_fields)}
            return {"Output": pl.DataFrame()}

        df = self._upcast_float32(df)

        if not group_fields:
            return {"Output": df.select(agg_exprs)}

        result = df.group_by(group_fields, maintain_order=True).agg(agg_exprs)
        # Pin the column order: group keys first, then aggregates in config order.
        ordered = group_fields + [e.meta.output_name() for e in agg_exprs]
        result = result.select([c for c in ordered if c in result.columns])
        # Sort by group columns for deterministic output
        # (Alteryx Summarize sorts groups alphabetically).
        return {"Output": result.sort(group_fields)}

    def _parse_fields(self, df: pl.DataFrame) -> tuple[list[str], list[pl.Expr]]:
        """Split the configured SummarizeFields into group keys and aggregations.

        A field whose column is absent from ``df`` is skipped, and so is an
        unsupported action. A SummarizeField with no ``action`` attribute is
        treated as ``GroupBy`` (the same default for both result lists).

        Returns:
            ``(group_fields, agg_exprs)``. ``group_fields`` is de-duplicated in
            config order. Each expression in ``agg_exprs`` is aliased to its
            ``rename`` (or field) name.
        """
        group_fields: list[str] = []
        agg_exprs: list[pl.Expr] = []
        for f in self.config.findall("SummarizeFields/SummarizeField"):
            field = f.attrib["field"]
            if field not in df.columns:
                continue
            action = f.attrib.get("action", "GroupBy")
            if action == "GroupBy":
                # polars rejects duplicate group keys; keep the first occurrence.
                if field not in group_fields:
                    group_fields.append(field)
                continue
            expr = self._build_agg(field, action, f.attrib)
            if expr is not None:
                agg_exprs.append(expr.alias(f.attrib.get("rename", field)))
        return group_fields, agg_exprs

    @staticmethod
    def _upcast_float32(df: pl.DataFrame) -> pl.DataFrame:
        """Cast every Float32 column of ``df`` to Float64.

        Aggregating in Float64 avoids floating-point precision noise
        (matches Alteryx behaviour).
        """
        float32_cols = [
            name for name, dtype in zip(df.columns, df.dtypes) if dtype == pl.Float32
        ]
        if not float32_cols:
            return df
        return df.with_columns([pl.col(c).cast(pl.Float64) for c in float32_cols])

    def _build_agg(self, field: str, action: str, attrs: dict) -> pl.Expr | None:
        """Return the polars aggregation expression for one SummarizeField.

        Args:
            field: Source column name.
            action: Alteryx action name (e.g. ``"Sum"``, ``"Percentile"``).
            attrs: The SummarizeField XML attributes. ``percentile`` is read for
                Percentile; ``separator`` and ``order`` are read for Concatenate.

        Returns:
            The un-aliased expression, or ``None`` if the action is unsupported.
        """
        col = pl.col(field)

        builder = self._SIMPLE_AGGS.get(action)
        if builder is not None:
            return builder(col)

        if action == "Percentile":
            # Attribute is a percentage (default "50"); polars wants a fraction.
            p = float(attrs.get("percentile", "50")) / 100.0
            return col.quantile(p, interpolation="linear")

        if action == "Concatenate":
            sep = attrs.get("separator", "")
            order = attrs.get("order", "")
            # Nulls are dropped before joining; values are joined as strings.
            base = col.cast(pl.String).drop_nulls()
            if order == "Ascending":
                base = base.sort(descending=False)
            elif order == "Descending":
                base = base.sort(descending=True)
            return base.str.join(sep)

        if action == "Mode":
            # Most frequent value: value_counts sorted by count, take the value
            # field of the top entry (ties resolved by polars' sort order).
            return (
                col.value_counts(sort=True)
                .struct.field(field)
                .first()
            )

        return None