Pyteryx/alteryx_runner/tools/preparation/generate_rows.py

94 lines
3.5 KiB
Python

from __future__ import annotations

import logging
from typing import Dict

import polars as pl

from tools.base import BaseTool
class GenerateRowsTool(BaseTool):
    """Generate Rows tool for the Alteryx runner.

    Evaluates an initialization expression into one field. It then repeatedly
    emits the current row and advances that field with a loop expression, for
    as long as a condition expression evaluates true. Without seed rows a
    single sequence is produced. With seed rows (input anchor ``"Input"``), one
    sequence is produced per seed row, carrying that row's other columns along.
    """

    # Per-sequence iteration cap used when RecordCount is not a positive
    # number, so a condition that never becomes false cannot loop forever.
    _DEFAULT_MAX_ROWS = 100_000

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Run the tool.

        Configuration read: ``UpdateField`` (attribute ``value``),
        ``UpdateField_Name``, ``CreateField_Name``, ``CreateField_Type``,
        ``Expression_Init``, ``Expression_Cond``, ``Expression_Loop`` and
        ``RecordCount`` (attribute ``value``; a positive value caps the rows
        emitted per sequence).

        Args:
            inputs: Input anchor name -> frame. Only the optional ``"Input"``
                frame of seed rows is read.

        Returns:
            ``{"Output": frame}``. With no configuration the frame has no
            columns. When no rows are generated, the frame is empty but keeps
            the seed columns plus the generated field.
        """
        log = logging.getLogger(__name__)
        if self.config is None:
            return {"Output": pl.DataFrame()}

        update_existing = self._cfg_attr("UpdateField", "value", "False") == "True"
        field_name = self._cfg("CreateField_Name", "") or ""
        if update_existing:
            # The field to update is configured separately from the field to
            # create; fall back to CreateField_Name when it is not present.
            field_name = self._cfg("UpdateField_Name", "") or field_name
        field_type = self._cfg("CreateField_Type", "Int32") or "Int32"
        expr_init = self._cfg("Expression_Init", "1") or "1"
        expr_cond = self._cfg("Expression_Cond", "False") or "False"
        expr_loop = self._cfg("Expression_Loop", "") or ""

        record_count_str = self._cfg_attr("RecordCount", "value", "0") or "0"
        try:
            max_count = int(record_count_str)
        except ValueError:
            log.warning(
                "GenerateRows: ignoring non-integer RecordCount %r", record_count_str
            )
            max_count = 0
        limit = max_count if max_count > 0 else self._DEFAULT_MAX_ROWS

        seed_rows = inputs.get("Input")
        # NOTE(review): an empty "Input" frame is treated like no input (one
        # unseeded sequence is generated); confirm whether a connected but
        # empty input should instead produce no rows.
        has_seed = seed_rows is not None and not seed_rows.is_empty()

        dtype = self.ctx.type_mapper.map(field_type)
        if update_existing and has_seed and field_name in seed_rows.columns:
            # CreateField_Type describes a *new* field. An updated field keeps
            # its existing type; casting e.g. a float column to the default
            # Int32 would silently truncate it.
            dtype = seed_rows.schema[field_name]

        transpiler = self.ctx.transpiler
        failed_exprs: set[str] = set()

        def _evaluate(expr: str, row: dict, out_dtype: object) -> object:
            """Evaluate *expr* against *row* (raises on failure).

            A non-empty row is evaluated as a one-row frame with the result
            requested as *out_dtype*. An empty row uses scalar evaluation.
            """
            if row:
                series = transpiler.eval_series(
                    pl.DataFrame([row]), expr, "__val__", out_dtype
                )
                return series[0]
            return transpiler.eval_scalar(expr)

        def _eval_scalar(expr: str, row: dict) -> object:
            """Evaluate *expr* with a string result; ``None`` if blank or failing."""
            if not expr.strip():
                return None
            try:
                return _evaluate(expr, row, pl.String)
            except Exception:
                # Stay best-effort (null value / false condition), but report
                # each failing expression once instead of silently producing
                # no rows.
                if expr not in failed_exprs:
                    failed_exprs.add(expr)
                    log.warning(
                        "GenerateRows: expression %r failed to evaluate; "
                        "treating the result as null",
                        expr,
                        exc_info=True,
                    )
                return None

        def _eval_field(expr: str, row: dict) -> object:
            """Evaluate an init/loop expression, preferring the field's dtype.

            Keeping the field typed matters: a value that round-trips through
            a string would make the next condition/loop step operate on text.
            """
            if not expr.strip():
                return None
            try:
                return _evaluate(expr, row, dtype)
            except Exception:
                # Not representable as the target type: fall back to the
                # untyped string evaluation rather than losing the value.
                return _eval_scalar(expr, row)

        def _eval_bool(expr: str, row: dict) -> bool:
            """Interpret the result of *expr* as a boolean condition."""
            val = _eval_scalar(expr, row)
            if isinstance(val, bool):
                return val
            if isinstance(val, str):
                return val.lower() in ("true", "1", "yes")
            return bool(val) if val is not None else False

        rows: list[dict] = []

        def _generate(row: dict) -> bool:
            """Append one sequence starting at *row* (mutated in place).

            Returns True when the sequence stopped because it reached
            ``limit`` rather than because the condition became false.
            """
            for _ in range(limit):
                if not _eval_bool(expr_cond, row):
                    return False
                rows.append(dict(row))
                row[field_name] = _eval_field(expr_loop, row)
            return True

        capped = False
        if has_seed:
            for start in seed_rows.to_dicts():
                # NOTE(review): in update mode the init expression is skipped
                # and the field's existing value is the starting point; confirm
                # this matches Alteryx, which shows an init expression in both
                # modes.
                if not update_existing:
                    start[field_name] = _eval_field(expr_init, start)
                # Call _generate first so it always runs, then fold in capped.
                capped = _generate(start) or capped
        else:
            start_row: dict = {}
            # The right-hand side is evaluated against the still-empty row.
            start_row[field_name] = _eval_field(expr_init, start_row)
            capped = _generate(start_row)

        if capped and max_count <= 0:
            log.warning(
                "GenerateRows: a sequence hit the %d-row safety cap before its "
                "condition became false; output was truncated",
                limit,
            )

        if not rows:
            # Keep the column layout so downstream tools still see the schema.
            schema = dict(seed_rows.schema) if has_seed else {}
            if field_name:
                schema[field_name] = dtype
            return {"Output": pl.DataFrame(schema=schema)}

        # Seed columns pass through unchanged, so rebuild them with their
        # original dtypes instead of re-inferring from Python values (which
        # would e.g. turn Categorical into String). Scan all rows when
        # inferring the generated field's type.
        overrides = (
            {name: dt for name, dt in seed_rows.schema.items() if name != field_name}
            if has_seed
            else None
        )
        df = pl.DataFrame(rows, schema_overrides=overrides, infer_schema_length=None)

        if field_name in df.columns:
            try:
                df = df.with_columns(pl.col(field_name).cast(dtype))
            except Exception:
                # Best effort: keep the uncast column, but say why.
                log.warning(
                    "GenerateRows: could not cast field %r to %s; leaving it as %s",
                    field_name,
                    dtype,
                    df.schema[field_name],
                    exc_info=True,
                )
        return {"Output": df}