94 lines
3.5 KiB
Python
94 lines
3.5 KiB
Python
from __future__ import annotations
|
|
from typing import Dict
|
|
import polars as pl
|
|
from tools.base import BaseTool
|
|
|
|
|
|
class GenerateRowsTool(BaseTool):
    """Generate rows by iterating an init / condition / loop expression triple.

    Starting from ``Expression_Init``, a row is emitted while
    ``Expression_Cond`` evaluates truthy, and ``Expression_Loop`` computes the
    next value of the generated field.  Without a non-empty ``Input`` frame a
    single sequence is produced; otherwise one sequence is produced per input
    row, with that row's columns copied into every generated row.  When
    ``UpdateField`` is ``"True"`` the seed row's current value is used as the
    start and ``Expression_Init`` is not evaluated.
    """

    # Per-sequence row cap used when RecordCount is not positive, so a
    # condition that never becomes false cannot loop forever.
    _DEFAULT_MAX_ROWS = 100_000

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Run the generator and return its output frame.

        Args:
            inputs: Upstream frames keyed by anchor name.  Only the optional
                ``"Input"`` entry is read; each of its rows seeds one sequence.

        Returns:
            ``{"Output": frame}``.  The frame is empty when the tool has no
            configuration or when no row satisfied the condition.  The
            generated field is cast to the configured type on a best-effort
            basis (left uncast, with a warning, if the cast fails).

        Raises:
            ValueError: If ``RecordCount`` is not an integer.
        """
        import logging

        logger = logging.getLogger(__name__)

        if self.config is None:
            return {"Output": pl.DataFrame()}

        update_existing = self._cfg_attr("UpdateField", "value", "False") == "True"
        field_name = self._cfg("CreateField_Name", "") or ""
        field_type = self._cfg("CreateField_Type", "Int32") or "Int32"
        expr_init = self._cfg("Expression_Init", "1") or "1"
        expr_cond = self._cfg("Expression_Cond", "False") or "False"
        expr_loop = self._cfg("Expression_Loop", "") or ""
        record_count_str = self._cfg_attr("RecordCount", "value", "0") or "0"
        try:
            max_count = int(record_count_str)
        except ValueError as err:
            raise ValueError(
                f"GenerateRows: RecordCount must be an integer, got {record_count_str!r}"
            ) from err

        dtype = self.ctx.type_mapper.map(field_type)

        # A positive RecordCount caps each sequence; otherwise the safety cap
        # applies.
        row_limit = max_count if max_count > 0 else self._DEFAULT_MAX_ROWS

        transpiler = self.ctx.transpiler
        # Expressions already reported as failing, so each is logged only once.
        failed_exprs: set = set()

        def _eval_scalar(expr: str, row: dict) -> object:
            """Evaluate *expr* against *row*; ``None`` for a blank expr or on error."""
            if not expr.strip():
                return None
            try:
                if row:
                    # NOTE(review): the result is requested as pl.String, so
                    # later iterations see a string where the initial value may
                    # have been numeric -- confirm the transpiler coerces types
                    # as intended.
                    series = transpiler.eval_series(
                        pl.DataFrame([row]), expr, "__val__", pl.String
                    )
                    return series[0]
                return transpiler.eval_scalar(expr)
            except Exception:
                # Best-effort by design: a failing expression yields None, which
                # the condition treats as False.  Log it so configuration
                # mistakes are not completely silent.
                if expr not in failed_exprs:
                    failed_exprs.add(expr)
                    logger.warning(
                        "GenerateRows: failed to evaluate expression %r",
                        expr,
                        exc_info=True,
                    )
                return None

        def _eval_bool(expr: str, row: dict) -> bool:
            """Interpret the result of *expr* as a boolean (``None`` -> False)."""
            val = _eval_scalar(expr, row)
            if isinstance(val, bool):
                return val
            if isinstance(val, str):
                # String results use textual truth values.
                return val.lower() in ("true", "1", "yes")
            return bool(val)

        def _run_sequence(start: dict, apply_init: bool) -> list:
            """Return the rows generated from one starting row."""
            current = dict(start)
            if apply_init:
                current[field_name] = _eval_scalar(expr_init, current)
            produced = []
            for _ in range(row_limit):
                if not _eval_bool(expr_cond, current):
                    break
                # Snapshot: ``current`` is mutated for the next iteration.
                produced.append(dict(current))
                current[field_name] = _eval_scalar(expr_loop, current)
            return produced

        seed_rows = inputs.get("Input")
        rows: list = []
        if seed_rows is None or seed_rows.is_empty():
            # NOTE(review): an empty Input frame is treated like a missing one
            # and produces a standalone sequence -- confirm this is intended.
            rows.extend(_run_sequence({}, apply_init=True))
        else:
            for seed in seed_rows.to_dicts():
                rows.extend(_run_sequence(seed, apply_init=not update_existing))

        if not rows:
            return {"Output": pl.DataFrame()}

        # The initial value may come from ``eval_scalar`` while later values
        # come from ``eval_series`` (requested as pl.String), so the generated
        # field can mix str with other Python types.  Polars' strict row-wise
        # construction rejects such mixtures, so normalise them to str and let
        # the cast below restore the configured dtype.
        field_types = {
            type(r[field_name]) for r in rows if r.get(field_name) is not None
        }
        if str in field_types and len(field_types) > 1:
            for r in rows:
                value = r.get(field_name)
                if value is not None:
                    r[field_name] = str(value)

        df = pl.DataFrame(rows)
        if field_name in df.columns:
            try:
                df = df.with_columns(pl.col(field_name).cast(dtype))
            except Exception:
                # Best-effort: keep the inferred type rather than fail the tool.
                logger.warning(
                    "GenerateRows: could not cast field %r to %r; leaving it uncast",
                    field_name,
                    field_type,
                    exc_info=True,
                )
        return {"Output": df}