97 lines
3.2 KiB
Python
97 lines
3.2 KiB
Python
from __future__ import annotations
|
|
from typing import Dict
|
|
import polars as pl
|
|
from tools.base import BaseTool
|
|
|
|
|
|
class SummarizeTool(BaseTool):
    """Group the "Input" frame and aggregate it as described by the tool config.

    The config is searched for ``SummarizeFields/SummarizeField`` elements.
    Each element is read for these attributes:

    * ``field``  - name of the source column (required).
    * ``action`` - ``"GroupBy"`` or the name of an aggregation.
    * ``rename`` - output column name for an aggregation (optional;
      defaults to ``field``).

    Any further attributes (``percentile``, ``separator``, ``order``) are
    only consulted by the aggregation that needs them.
    """

    # Aggregations that need nothing but the column itself. Stored as
    # builders so that only the requested expression is constructed.
    _SIMPLE_AGGS = {
        "Sum": lambda c: c.sum(),
        # ``Expr.count()`` skips nulls, which would make "Count" a duplicate
        # of "Count Non Null"; ``len()`` counts every row of the group.
        "Count": lambda c: c.len(),
        "Count Non Null": lambda c: c.drop_nulls().count(),
        "Count Distinct": lambda c: c.n_unique(),
        "Count Distinct Non Null": lambda c: c.drop_nulls().n_unique(),
        "Count Null": lambda c: c.is_null().sum(),
        "Min": lambda c: c.min(),
        "Max": lambda c: c.max(),
        "Avg": lambda c: c.mean(),
        "Average": lambda c: c.mean(),
        "Median": lambda c: c.median(),
        "Std Deviation": lambda c: c.std(),
        "Variance": lambda c: c.var(),
        "First": lambda c: c.first(),
        "Last": lambda c: c.last(),
    }

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Summarize ``inputs["Input"]`` and return it under the "Output" key.

        Args:
            inputs: Mapping of anchor name to frame; only "Input" is read.

        Returns:
            A one-entry dict ``{"Output": frame}`` where ``frame`` is:

            * the input frame unchanged when there is no config or the
              input is empty;
            * the distinct combinations of the group columns when no usable
              aggregation is configured but group columns are;
            * an empty frame when neither aggregations nor group columns
              are usable;
            * otherwise the aggregated result, group columns first and then
              the aggregations in config order.

        Raises:
            KeyError: If a ``SummarizeField`` element has no ``field``
                attribute.
        """
        df = inputs.get("Input", pl.DataFrame())
        if self.config is None or df.is_empty():
            return {"Output": df}

        fields = self.config.findall("SummarizeFields/SummarizeField")
        columns = set(df.columns)

        # Only elements explicitly marked GroupBy become group keys.
        # dict.fromkeys drops repeated keys while keeping config order, so a
        # column configured twice does not produce a duplicate-column error.
        group_fields = list(
            dict.fromkeys(
                f.attrib["field"]
                for f in fields
                if f.attrib.get("action") == "GroupBy"
                and f.attrib["field"] in columns
            )
        )

        agg_exprs: list[pl.Expr] = []
        for f in fields:
            # An element without an action is neither grouped on (see above)
            # nor aggregated: it is skipped here as if it were GroupBy.
            action = f.attrib.get("action", "GroupBy")
            field = f.attrib["field"]
            if action == "GroupBy" or field not in columns:
                continue

            expr = self._build_agg(field, action, f.attrib)
            if expr is None:
                # Unknown action: ignored rather than treated as an error.
                continue

            # A missing *or empty* rename falls back to the source name so
            # the output never gets an empty-string column name.
            output_name = f.attrib.get("rename") or field
            agg_exprs.append(expr.alias(output_name))

        if not agg_exprs:
            if group_fields:
                distinct = df.select(group_fields).unique(maintain_order=True)
                return {"Output": distinct}
            return {"Output": pl.DataFrame()}

        if not group_fields:
            # No keys: aggregate over the whole frame.
            return {"Output": df.select(agg_exprs)}

        result = df.group_by(group_fields, maintain_order=True).agg(agg_exprs)
        # Restore group column order, followed by aggregations in config order.
        ordered = group_fields + [e.meta.output_name() for e in agg_exprs]
        present = [name for name in ordered if name in result.columns]
        return {"Output": result.select(present)}

    def _build_agg(self, field: str, action: str, attrs: dict) -> pl.Expr | None:
        """Build the (un-aliased) polars expression for one aggregation.

        Args:
            field: Name of the column to aggregate.
            action: Aggregation name, e.g. ``"Sum"``, ``"Percentile"``,
                ``"Concatenate"`` or ``"Mode"``.
            attrs: Attributes of the config element. Read keys:
                ``percentile`` (for "Percentile", given on a 0-100 scale,
                default ``"50"``), ``separator`` and ``order`` (for
                "Concatenate"; ``order`` may be ``"Ascending"`` or
                ``"Descending"``, anything else keeps the existing order).

        Returns:
            The aggregation expression, or ``None`` when ``action`` is not
            recognised.

        Raises:
            ValueError: If ``percentile`` cannot be parsed as a float.
        """
        col = pl.col(field)

        builder = self._SIMPLE_AGGS.get(action)
        if builder is not None:
            return builder(col)

        if action == "Percentile":
            # Config uses 0-100; polars expects a 0-1 quantile.
            quantile = float(attrs.get("percentile", "50")) / 100.0
            return col.quantile(quantile, interpolation="linear")

        if action == "Concatenate":
            sep = attrs.get("separator", "")
            order = attrs.get("order", "")
            # Nulls are dropped before joining so they leave no gaps.
            values = col.cast(pl.String).drop_nulls()
            if order == "Ascending":
                values = values.sort(descending=False)
            elif order == "Descending":
                values = values.sort(descending=True)
            return values.str.join(sep)

        if action == "Mode":
            # value_counts(sort=True) orders by descending count, so the
            # first struct entry holds the most frequent value.
            return col.value_counts(sort=True).struct.field(field).first()

        return None
|