from __future__ import annotations

import re
from typing import Dict

import polars as pl

from tools.base import BaseTool

# A Perl/Alteryx-style back-reference in a replacement string: ``$N`` or
# ``${N}``. Group 1 holds the bare number, group 2 the braced number.
_DOLLAR_BACKREF = re.compile(r"\$(?:(\d+)|\{(\d+)\})")


def _to_python_template(replacement: str) -> str:
    r"""Rewrite ``$N`` / ``${N}`` back-references as Python's ``\g<N>`` form.

    ``\g<N>`` is used instead of ``\N`` so that ``$0`` refers to the whole
    match (``\0`` would be a NUL character). A ``$`` that is not followed by
    a group number is kept as a literal character rather than being turned
    into a dangling backslash, which ``re.sub`` would reject.
    """
    return _DOLLAR_BACKREF.sub(
        lambda m: "\\g<" + (m.group(1) or m.group(2)) + ">", replacement
    )


class RegExTool(BaseTool):
    """Apply a regular expression to one column using one of four methods.

    Configuration is read through ``self._cfg`` / ``self._cfg_attr``
    (presumably provided by ``BaseTool``):

    * ``Field`` - input column to operate on.
    * ``Expression`` - the regular expression.
    * ``Method`` - ``Match``, ``Replace``, ``Parse`` or ``Token``
      (default ``Match``; any other value passes the input through).
    * ``OutputField`` - target column for ``Match`` and column-name prefix
      for ``Parse`` (defaults to ``Field``).
    * ``FullMatch`` / ``CaseInsensitive`` - ``value`` attributes; enabled
      only when the text is ``"true"`` (any case).
    * ``Replace/ReplaceString`` - replacement text for ``Replace``; ``$N``
      and ``${N}`` denote capture groups.

    Missing config, an empty input frame, an empty field or expression, or a
    field that is not a column all return the input unchanged.
    """

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Run the configured regex method on ``inputs["Input"]``.

        Returns ``{"Output": frame}``. An invalid ``Expression`` raises
        ``re.error`` from ``re.compile``.
        """
        df = inputs.get("Input", pl.DataFrame())
        if self.config is None or df.is_empty():
            return {"Output": df}

        field = self._cfg("Field", "") or ""
        pattern = self._cfg("Expression", "") or ""
        method = self._cfg("Method", "Match") or "Match"
        output_field = self._cfg("OutputField", field) or field
        full_match = self._flag("FullMatch")
        case_insensitive = self._flag("CaseInsensitive")

        if not field or not pattern or field not in df.columns:
            return {"Output": df}

        compiled = re.compile(pattern, re.IGNORECASE if case_insensitive else 0)

        if method == "Match":
            df = self._match(df, field, output_field, compiled, full_match)
        elif method == "Replace":
            df = self._replace(df, field, compiled)
        elif method == "Parse":
            df = self._parse(df, field, output_field, compiled)
        elif method == "Token":
            df = self._tokenize(df, field, compiled)
        return {"Output": df}

    def _flag(self, name: str) -> bool:
        """Return True when config attribute ``name``/``value`` reads "true" (any case)."""
        return (self._cfg_attr(name, "value", "False") or "False").lower() == "true"

    @staticmethod
    def _match(
        df: pl.DataFrame,
        field: str,
        output_field: str,
        compiled: re.Pattern,
        full_match: bool,
    ) -> pl.DataFrame:
        """Write a Boolean column telling whether each value of ``field`` matches."""
        matcher = compiled.fullmatch if full_match else compiled.search

        def match_fn(val: str | None) -> bool | None:
            if val is None:
                return None
            return matcher(val) is not None

        series = (
            df[field]
            .cast(pl.String)
            .map_elements(match_fn, return_dtype=pl.Boolean)
        )
        return df.with_columns(series.alias(output_field))

    def _replace(
        self, df: pl.DataFrame, field: str, compiled: re.Pattern
    ) -> pl.DataFrame:
        """Substitute every match in ``field`` (in place) with the replacement text."""
        replacement = _to_python_template(
            self._cfg("Replace/ReplaceString", "") or ""
        )

        def replace_fn(val: str | None) -> str | None:
            if val is None:
                return None
            return compiled.sub(replacement, val)

        series = (
            df[field]
            .cast(pl.String)
            .map_elements(replace_fn, return_dtype=pl.String)
        )
        return df.with_columns(series.alias(field))

    @staticmethod
    def _parse(
        df: pl.DataFrame, field: str, root_name: str, compiled: re.Pattern
    ) -> pl.DataFrame:
        """Add String columns ``<root_name>1..N``, one per capture group.

        Rows without a match (or with a null value) get nulls in every group
        column; an expression with no groups leaves the frame unchanged.
        """
        group_count = compiled.groups
        if group_count == 0:
            return df

        columns: list[list[str | None]] = [[] for _ in range(group_count)]
        for val in df[field].cast(pl.String).to_list():
            # Compare with None explicitly: an empty string is still a value,
            # and patterns such as ``(.*)`` match it.
            m = compiled.search(val) if val is not None else None
            for index, column in enumerate(columns, start=1):
                column.append(m.group(index) if m else None)

        # Explicit dtype so an all-null group column is String, not Null.
        return df.with_columns(
            [
                pl.Series(f"{root_name}{index}", column, dtype=pl.String)
                for index, column in enumerate(columns, start=1)
            ]
        )

    @staticmethod
    def _tokenize(
        df: pl.DataFrame, field: str, compiled: re.Pattern
    ) -> pl.DataFrame:
        """Split ``field`` on the expression and emit one row per token.

        Null values are treated as an empty string. The other columns keep
        their original dtypes; ``field`` becomes a String column.
        """
        # `is None` rather than `or ""` so falsy values such as 0 / False are
        # tokenized as "0" / "False" instead of being blanked out.
        token_lists = [
            compiled.split("" if val is None else str(val))
            for val in df[field].to_list()
        ]
        return df.with_columns(
            pl.Series(field, token_lists, dtype=pl.List(pl.String))
        ).explode(field)