80 lines
2.6 KiB
Python
80 lines
2.6 KiB
Python
from __future__ import annotations
|
|
import re
|
|
from typing import Dict
|
|
import polars as pl
|
|
from tools.base import BaseTool
|
|
|
|
|
|
# Matches multi-row references of the form ``[Row-1:Name]`` / ``[Row+2:Name]``.
# Group 1 captures the signed row offset (sign is mandatory); group 2 captures
# the referenced column name (any run of characters other than ``]``).
_ROW_REF = re.compile(r"\[Row([+-]\d+):([^\]]+)\]")
|
|
|
|
|
|
class MultiRowFormulaTool(BaseTool):
    """Evaluate a formula whose expression may reference neighbouring rows.

    Row references of the form ``[Row-1:Col]`` / ``[Row+1:Col]`` found in the
    configured expression are materialised as temporary shifted columns. The
    rewritten expression is then handed to ``self.ctx.transpiler.eval_series``
    and the result is written to the configured target field.
    """

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Apply the configured multi-row formula to the ``"Input"`` frame.

        The input is returned unchanged when there is no configuration, the
        frame is empty, or the configuration has no ``FormulaField`` element.

        Args:
            inputs: Mapping of anchor name to frame; only ``"Input"`` is read.

        Returns:
            A mapping with the single key ``"Output"``.

        Raises:
            RuntimeError: If evaluating the expression fails (see ``_apply``).
        """
        df = inputs.get("Input", pl.DataFrame())
        if self.config is None or df.is_empty():
            return {"Output": df}

        ff = self.config.find("FormulaField")
        if ff is None:
            return {"Output": df}

        field = ff.attrib.get("field", "")
        expr_text = ff.attrib.get("expression", "")
        alteryx_type = ff.attrib.get("type", "V_WString")
        size = ff.attrib.get("size")
        dtype = self.ctx.type_mapper.map(alteryx_type, size)

        # Only group by fields that actually exist in the input. Elements
        # without a ``name`` attribute are skipped instead of raising KeyError.
        group_fields = [
            f.attrib["name"]
            for f in self.config.findall("GroupByFields/Field")
            if f.attrib.get("name") in df.columns
        ]

        if not group_fields:
            df = self._apply(df, field, expr_text, dtype)
        else:
            # Evaluate each group independently so that row references never
            # reach across a group boundary.
            # NOTE(review): the result is ordered group by group (groups in
            # first-appearance order), not in the original interleaved row
            # order -- confirm this is what downstream tools expect.
            parts: list[pl.DataFrame] = [
                self._apply(group_df, field, expr_text, dtype)
                for _, group_df in df.group_by(group_fields, maintain_order=True)
            ]
            df = pl.concat(parts) if parts else df

        return {"Output": df}

    def _apply(
        self,
        df: pl.DataFrame,
        field: str,
        expr_text: str,
        dtype: pl.PolarsDataType,
    ) -> pl.DataFrame:
        """Evaluate ``expr_text`` on ``df`` and store the result in ``field``.

        Args:
            df: Frame (or single group) to evaluate against.
            field: Name of the column that receives the result.
            expr_text: Expression text, possibly containing row references.
            dtype: Target dtype passed through to the transpiler.

        Returns:
            ``df`` with ``field`` added or replaced and no temporary columns.

        Raises:
            RuntimeError: If the transpiler fails; the original exception is
                chained as the cause.
        """
        # temp column name -> (source column, signed row offset)
        shift_cols: dict[str, tuple[str, int]] = {}

        def _to_shift_ref(match: "re.Match[str]") -> str:
            """Register the shifted column for one reference and return its
            bracketed placeholder."""
            offset_str, col_name = match.groups()
            offset = int(offset_str)
            shift_col = f"__shift_{col_name}_{offset}__"
            shift_cols[shift_col] = (col_name, offset)
            return f"[{shift_col}]"

        # A callable replacement is inserted literally, so column names that
        # contain backslashes are not misread as regex replacement escapes.
        clean_expr = _ROW_REF.sub(_to_shift_ref, expr_text)

        if shift_cols:
            df = df.with_columns(
                [
                    # ``Row+N`` reads N rows ahead, i.e. values move up by N,
                    # which is shift(-N); ``Row-N`` is therefore shift(+N).
                    pl.col(src).shift(-offset).alias(tmp)
                    for tmp, (src, offset) in shift_cols.items()
                ]
            )

        try:
            series = self.ctx.transpiler.eval_series(df, clean_expr, field, dtype)
            df = df.with_columns(series.alias(field))
        except Exception as e:
            raise RuntimeError(
                f"MultiRowFormula field {field!r}: {e}"
            ) from e

        if shift_cols:
            # Remove the temporary helper columns before returning.
            df = df.drop(list(shift_cols))

        return df
|