Pyteryx/alteryx_runner/tools/transform/cross_tab.py

54 lines
1.8 KiB
Python

from __future__ import annotations
from typing import Dict
import polars as pl
from tools.base import BaseTool
class CrossTabTool(BaseTool):
    """Cross-tabulate (pivot) the ``Input`` frame using DuckDB's ``PIVOT``.

    The tool reads its settings from ``self.config`` (accessed through
    ``find`` / ``findall`` / ``attrib``, i.e. an ElementTree-style element):

    * ``HeaderField[@field]`` - column whose values become the new headers.
    * ``DataField[@field]``   - column that is aggregated into the cells.
    * ``DataField[@method]``  - aggregate name (case-insensitive); anything
      not in ``_SUPPORTED_AGGS`` falls back to ``SUM``.
    * ``GroupFields/Field[@name]`` - grouping columns; names that are not
      present in the input frame are ignored.
    """

    # Aggregate names that may be interpolated into the PIVOT statement.
    _SUPPORTED_AGGS = frozenset(
        {"SUM", "COUNT", "AVG", "MIN", "MAX", "FIRST", "LAST"}
    )

    # Name under which the input frame is registered on the DuckDB connection.
    _VIEW_NAME = "__ct_input__"

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier.

        Embedded double quotes are doubled so that a column name containing
        ``"`` cannot terminate the identifier early and corrupt the query.
        """
        return '"' + name.replace('"', '""') + '"'

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Pivot ``inputs["Input"]`` and return it under the ``"Output"`` key.

        The input frame is returned unchanged when there is no configuration,
        the frame is empty, or the header/data field is not configured.

        Raises:
            RuntimeError: if DuckDB fails to execute the PIVOT statement.
        """
        df = inputs.get("Input", pl.DataFrame())
        if self.config is None or df.is_empty():
            return {"Output": df}

        header_el = self.config.find("HeaderField")
        data_el = self.config.find("DataField")
        if header_el is None or data_el is None:
            return {"Output": df}

        header_field = header_el.attrib.get("field", "")
        data_field = data_el.attrib.get("field", "")
        if not header_field or not data_field:
            return {"Output": df}

        method = data_el.attrib.get("method", "Sum").upper()
        agg_fn = method if method in self._SUPPORTED_AGGS else "SUM"

        # Keep only group fields that exist in the input. ``.get`` is used so
        # a <Field> element without a ``name`` attribute is skipped instead of
        # raising KeyError.
        columns = set(df.columns)
        configured = (
            f.attrib.get("name")
            for f in self.config.findall("GroupFields/Field")
        )
        groups = [name for name in configured if name in columns]

        # NOTE(review): with no usable group fields the clause becomes
        # ``GROUP BY 1`` (kept from the original) - confirm this yields the
        # intended single-group result in DuckDB.
        if groups:
            group_clause = ", ".join(self._quote_ident(g) for g in groups)
        else:
            group_clause = "1"

        sql = "\n".join(
            (
                f"PIVOT {self._VIEW_NAME}",
                f"ON {self._quote_ident(header_field)}",
                f"USING {agg_fn}({self._quote_ident(data_field)})",
                f"GROUP BY {group_clause}",
            )
        )

        con = self.ctx.duckdb_con
        con.register(self._VIEW_NAME, df.to_arrow())
        try:
            result = pl.from_arrow(con.execute(sql).arrow())
        except Exception as e:
            raise RuntimeError(f"CrossTab PIVOT failed: {e}") from e
        finally:
            # Always drop the registered view so it does not linger on the
            # shared connection after this tool finishes.
            con.execute(f"DROP VIEW IF EXISTS {self._VIEW_NAME}")
        return {"Output": result}