54 lines
1.8 KiB
Python
54 lines
1.8 KiB
Python
from __future__ import annotations
|
|
from typing import Dict
|
|
import polars as pl
|
|
from tools.base import BaseTool
|
|
|
|
|
|
class CrossTabTool(BaseTool):
    """Cross-tabulate (pivot) the "Input" frame with a DuckDB ``PIVOT`` query.

    The tool reads its settings from ``self.config`` via ``find``/``findall``
    (presumably an ElementTree-style element -- confirm against ``BaseTool``):

    * ``HeaderField/@field`` -- column whose distinct values become the
      output column headers.
    * ``DataField/@field`` -- column that is aggregated into each cell.
    * ``DataField/@method`` -- aggregate name, case-insensitive; defaults to
      ``Sum`` and falls back to ``SUM`` when the name is not recognised.
    * ``GroupFields/Field/@name`` -- columns that identify an output row.
      Names that are not columns of the input are ignored.
    """

    # Name under which the input frame is exposed to the SQL engine.
    _VIEW_NAME = "__ct_input__"

    # Aggregates accepted from the config (compared after upper-casing).
    _AGG_FUNCTIONS = frozenset(
        {"SUM", "COUNT", "AVG", "MIN", "MAX", "FIRST", "LAST"}
    )

    # Aggregate used when the configured method is not in _AGG_FUNCTIONS.
    _DEFAULT_AGG = "SUM"

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier.

        Embedded double quotes are doubled so that a column name taken from
        the config or the data cannot terminate the identifier early and
        change the meaning of the generated statement.
        """
        return '"' + name.replace('"', '""') + '"'

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Pivot ``inputs["Input"]`` and return it under the ``"Output"`` key.

        The input frame is returned unchanged when there is nothing to do:
        no config, an empty (or missing) input frame, a missing
        ``HeaderField``/``DataField`` element, or an empty ``field``
        attribute on either of them.

        Args:
            inputs: Mapping of anchor name to frame; only ``"Input"`` is read.

        Returns:
            A one-entry dict ``{"Output": frame}``.

        Raises:
            RuntimeError: If the PIVOT query fails; the original exception is
                chained as ``__cause__``.
        """
        df = inputs.get("Input", pl.DataFrame())
        if self.config is None or df.is_empty():
            return {"Output": df}

        header_el = self.config.find("HeaderField")
        data_el = self.config.find("DataField")
        if header_el is None or data_el is None:
            return {"Output": df}

        header_field = header_el.attrib.get("field", "")
        data_field = data_el.attrib.get("field", "")
        if not header_field or not data_field:
            return {"Output": df}

        method = data_el.attrib.get("method", "Sum").upper()
        agg_fn = method if method in self._AGG_FUNCTIONS else self._DEFAULT_AGG

        # Keep only group fields that exist in the input. ``.get`` makes a
        # <Field> element without a ``name`` attribute simply not match
        # instead of raising KeyError.
        known_columns = set(df.columns)
        configured = (
            field_el.attrib.get("name")
            for field_el in self.config.findall("GroupFields/Field")
        )
        groups = [name for name in configured if name in known_columns]

        quote = self._quote_ident
        if groups:
            source = self._VIEW_NAME
            group_clause = "GROUP BY " + ", ".join(quote(g) for g in groups)
        else:
            # PIVOT's GROUP BY lists column names, so there is no constant
            # "group everything" placeholder. Without GROUP BY, PIVOT groups
            # by every column not named in ON/USING; projecting down to the
            # header and data columns therefore yields one aggregated row.
            projected = ", ".join(
                quote(c) for c in dict.fromkeys((header_field, data_field))
            )
            source = f"(SELECT {projected} FROM {self._VIEW_NAME})"
            group_clause = ""

        sql = f"""
            PIVOT {source}
            ON {quote(header_field)}
            USING {agg_fn}({quote(data_field)})
            {group_clause}
        """

        # NOTE(review): ``self.ctx.duckdb_con`` is presumably a DuckDB Python
        # connection (register / execute(...).arrow() match that API).
        con = self.ctx.duckdb_con
        con.register(self._VIEW_NAME, df.to_arrow())
        try:
            result = pl.from_arrow(con.execute(sql).arrow())
        except Exception as e:
            raise RuntimeError(f"CrossTab PIVOT failed: {e}") from e
        finally:
            # Always release the temporary view, on success and on failure;
            # ``unregister`` is the counterpart of ``register``.
            con.unregister(self._VIEW_NAME)

        return {"Output": result}
|