from __future__ import annotations

from typing import Dict

import polars as pl

from tools.base import BaseTool


class SummarizeTool(BaseTool):
    """Group and aggregate the ``"Input"`` frame per the tool's config.

    Each ``SummarizeFields/SummarizeField`` element in ``self.config`` names a
    source column (``field``), an ``action`` and an optional output column
    name (``rename``, defaulting to ``field``). Fields whose action is
    ``GroupBy`` -- also the default when ``action`` is absent -- become group
    keys; every other recognised action becomes an aggregation. Fields that
    reference a column not present in the input are skipped.
    """

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Run the summarize step.

        Args:
            inputs: Named input frames; only ``"Input"`` is read (an empty
                frame is used when it is missing).

        Returns:
            ``{"Output": frame}``. The input is passed through unchanged when
            there is no config or the input is empty. With group keys but no
            aggregations, the output is the distinct key combinations in
            first-seen order. With neither, the output is an empty frame.
        """
        df = inputs.get("Input", pl.DataFrame())
        if self.config is None or df.is_empty():
            return {"Output": df}

        fields = self.config.findall("SummarizeFields/SummarizeField")

        group_fields: list[str] = []
        agg_exprs: list[pl.Expr] = []
        for f in fields:
            field = f.attrib["field"]
            if field not in df.columns:
                continue
            # One default for both paths: a field without an ``action``
            # attribute is a group key, not silently dropped.
            action = f.attrib.get("action", "GroupBy")
            if action == "GroupBy":
                # De-duplicate so a repeated key doesn't yield duplicate
                # group columns; config order is preserved.
                if field not in group_fields:
                    group_fields.append(field)
                continue
            expr = self._build_agg(field, action, f.attrib)
            if expr is not None:
                agg_exprs.append(expr.alias(f.attrib.get("rename", field)))

        if not agg_exprs:
            if group_fields:
                return {"Output": df.select(group_fields).unique(maintain_order=True)}
            return {"Output": pl.DataFrame()}

        if group_fields:
            # GroupBy.agg already emits the key columns first (in key order)
            # followed by the aggregations in the order given, so no
            # re-selection is needed to fix column order.
            result = df.group_by(group_fields, maintain_order=True).agg(agg_exprs)
        else:
            result = df.select(agg_exprs)
        return {"Output": result}

    def _build_agg(self, field: str, action: str, attrs: dict) -> pl.Expr | None:
        """Translate one summarize action into a polars aggregation expression.

        Args:
            field: Source column name.
            action: Action name from the config (e.g. ``"Sum"``, ``"Mode"``).
            attrs: The field element's full attribute map; read for the
                action-specific options ``percentile``, ``separator`` and
                ``order``.

        Returns:
            The (un-aliased) aggregation expression, or ``None`` when the
            action is not recognised.

        Raises:
            ValueError: if a ``Percentile`` action's ``percentile`` attribute
                is not a number.
        """
        col = pl.col(field)
        action_map = {
            "Sum": col.sum(),
            # Expr.count() skips nulls; "Count" must count every row, which
            # is Expr.len(). "Count Non Null" keeps the null-excluding count.
            "Count": col.len(),
            "Count Non Null": col.drop_nulls().count(),
            # n_unique() counts null as one distinct value.
            "Count Distinct": col.n_unique(),
            "Count Distinct Non Null": col.drop_nulls().n_unique(),
            "Count Null": col.is_null().sum(),
            "Min": col.min(),
            "Max": col.max(),
            "Avg": col.mean(),
            "Average": col.mean(),
            "Median": col.median(),
            "Std Deviation": col.std(),
            "Variance": col.var(),
            "First": col.first(),
            "Last": col.last(),
        }
        if action in action_map:
            return action_map[action]

        if action == "Percentile":
            # ``percentile`` is given on a 0-100 scale (default 50);
            # quantile() expects a fraction.
            p = float(attrs.get("percentile", "50")) / 100.0
            return col.quantile(p, interpolation="linear")

        if action == "Concatenate":
            sep = attrs.get("separator", "")
            order = attrs.get("order", "")
            # Nulls are dropped before joining; values are joined as strings.
            base = col.cast(pl.String).drop_nulls()
            if order == "Ascending":
                base = base.sort(descending=False)
            elif order == "Descending":
                base = base.sort(descending=True)
            return base.str.join(sep)

        if action == "Mode":
            # Most frequent value. Nulls are counted like any other value.
            # NOTE(review): the winner among equally frequent values depends on
            # value_counts' ordering of ties -- confirm whether a
            # deterministic tie-break is required.
            return col.value_counts(sort=True).struct.field(field).first()

        return None