from __future__ import annotations

from typing import Dict

import polars as pl

from tools.base import BaseTool


class GenerateRowsTool(BaseTool):
    """Generate rows by iterating an init / condition / loop expression triple.

    When the ``Input`` frame is missing or empty, a single sequence of rows is
    generated from scratch. Otherwise one sequence is generated per seed row,
    and every generated row carries a copy of that seed row's columns.
    """

    # Per-sequence row cap used when ``RecordCount`` is not a positive integer;
    # guards against a condition expression that never evaluates to false.
    DEFAULT_MAX_ROWS = 100_000

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Run the row generator and return ``{"Output": frame}``.

        Config values read: ``UpdateField`` (attr ``value``), ``CreateField_Name``,
        ``CreateField_Type``, ``Expression_Init``, ``Expression_Cond``,
        ``Expression_Loop`` and ``RecordCount`` (attr ``value``).

        Args:
            inputs: Mapping of input names to frames; only ``"Input"`` is
                consulted, and it may be absent.

        Returns:
            ``{"Output": df}``. ``df`` is an empty frame when there is no config
            or when no rows were generated.
        """
        if self.config is None:
            return {"Output": pl.DataFrame()}

        update_existing = self._cfg_attr("UpdateField", "value", "False") == "True"
        field_name = self._cfg("CreateField_Name", "") or ""
        field_type = self._cfg("CreateField_Type", "Int32") or "Int32"
        expr_init = self._cfg("Expression_Init", "1") or "1"
        expr_cond = self._cfg("Expression_Cond", "False") or "False"
        expr_loop = self._cfg("Expression_Loop", "") or ""
        max_count = self._parse_record_count(
            self._cfg_attr("RecordCount", "value", "0") or "0"
        )
        dtype = self.ctx.type_mapper.map(field_type)
        # A non-positive RecordCount means "no explicit cap"; fall back to the
        # safety limit so a non-terminating condition cannot run forever.
        limit = max_count if max_count > 0 else self.DEFAULT_MAX_ROWS

        seed_frame = inputs.get("Input")
        has_seed = seed_frame is not None and not seed_frame.is_empty()
        # Without seeds, generate exactly one sequence starting from an empty row.
        seeds: list[dict] = seed_frame.to_dicts() if has_seed else [{}]

        rows: list[dict] = []
        for seed in seeds:
            current_row = dict(seed)
            # Updating an existing field on seeded input keeps the seed's value
            # as the starting point; in every other case the init expression
            # provides it.
            if not has_seed or not update_existing:
                current_row[field_name] = self._eval_row_value(expr_init, current_row)
            rows.extend(
                self._generate_sequence(current_row, field_name, expr_cond, expr_loop, limit)
            )

        if not rows:
            return {"Output": pl.DataFrame()}

        df = pl.DataFrame(rows)
        if field_name in df.columns:
            try:
                df = df.with_columns(pl.col(field_name).cast(dtype))
            except Exception:
                # Deliberate best effort: if the cast fails, keep the column
                # as inferred rather than failing the whole tool.
                pass
        return {"Output": df}

    @staticmethod
    def _parse_record_count(raw: object) -> int:
        """Parse the configured record count, returning 0 when it is unparseable."""
        try:
            return int(raw)
        except (TypeError, ValueError):
            pass
        try:
            # Accept float-formatted counts such as "5.0".
            return int(float(raw))
        except (TypeError, ValueError, OverflowError):
            return 0

    def _eval_row_value(self, expr: str, row: dict) -> object:
        """Evaluate ``expr`` against ``row``, returning ``None`` on blank input or any error.

        NOTE(review): with a non-empty row the value is requested as ``pl.String``.
        With an empty row (the very first init of an unseeded run) it comes
        from ``eval_scalar`` untyped, so a sequence may mix scalar and string
        values until the final cast. Confirm the transpiler handles
        string-typed operands in loop expressions.
        """
        if not expr.strip():
            return None
        if not row:
            try:
                return self.ctx.transpiler.eval_scalar(expr)
            except Exception:
                return None
        frame = pl.DataFrame([row])
        try:
            series = self.ctx.transpiler.eval_series(frame, expr, "__val__", pl.String)
            return series[0]
        except Exception:
            return None

    def _eval_row_condition(self, expr: str, row: dict) -> bool:
        """Evaluate ``expr`` against ``row`` as a boolean; errors and ``None`` count as False."""
        val = self._eval_row_value(expr, row)
        if isinstance(val, bool):
            return val
        if isinstance(val, str):
            return val.lower() in ("true", "1", "yes")
        return bool(val) if val is not None else False

    def _generate_sequence(
        self,
        row: dict,
        field_name: str,
        expr_cond: str,
        expr_loop: str,
        limit: int,
    ) -> list[dict]:
        """Emit snapshots of ``row`` while the condition holds, up to ``limit`` rows.

        ``row`` is mutated in place. After each snapshot, ``field_name`` is
        overwritten with the loop expression's result.
        """
        out: list[dict] = []
        while len(out) < limit and self._eval_row_condition(expr_cond, row):
            out.append(dict(row))
            row[field_name] = self._eval_row_value(expr_loop, row)
        return out