Pyteryx/alteryx_runner/tools/parse/regex_tool.py

86 lines
3.2 KiB
Python

from __future__ import annotations
import re
from typing import Dict
import polars as pl
from tools.base import BaseTool
class RegExTool(BaseTool):
    """Apply a regular expression to one column of the ``"Input"`` frame.

    The behaviour is selected by the ``Method`` config value:

    * ``"Match"``   - add a boolean column telling whether each value matches.
    * ``"Replace"`` - substitute every match in the column, in place.
    * ``"Parse"``   - add one new column per capture group.
    * ``"Token"``   - split each value on the pattern, one output row per piece.

    Any other method value leaves the frame unchanged.
    """

    # A numbered group reference written with a dollar sign, e.g. ``$1``.
    _GROUP_REF = re.compile(r"\$(\d+)")

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Run the configured regex operation.

        Args:
            inputs: Mapping of anchor name to frame; only ``"Input"`` is read.

        Returns:
            ``{"Output": frame}``. The input frame is returned untouched when
            there is no config, the frame is empty, the field or expression
            is blank, or the field is not a column of the frame.

        Raises:
            re.error: If the configured expression (or a replacement
                template) is not valid for Python's ``re`` module.
        """
        df = inputs.get("Input", pl.DataFrame())
        if self.config is None or df.is_empty():
            return {"Output": df}

        field = self._cfg("Field", "") or ""
        pattern = self._cfg("Expression", "") or ""
        method = self._cfg("Method", "Match") or "Match"
        output_field = self._cfg("OutputField", field) or field
        full_match = self._flag("FullMatch")
        case_insensitive = self._flag("CaseInsensitive")

        if not field or not pattern or field not in df.columns:
            return {"Output": df}

        compiled = re.compile(pattern, re.IGNORECASE if case_insensitive else 0)

        if method == "Match":
            df = self._match(df, field, output_field, compiled, full_match)
        elif method == "Replace":
            template = self._cfg("Replace/ReplaceString", "") or ""
            df = self._replace(df, field, compiled, self._to_python_template(template))
        elif method == "Parse":
            df = self._parse(df, field, output_field or field, compiled)
        elif method == "Token":
            df = self._tokenize(df, field, compiled)

        return {"Output": df}

    def _flag(self, name: str) -> bool:
        """Read a boolean config attribute stored as the text ``"True"``/``"False"``."""
        raw = self._cfg_attr(name, "value", "False") or "False"
        return raw.lower() == "true"

    @classmethod
    def _to_python_template(cls, template: str) -> str:
        """Convert ``$N`` group references into Python's ``\\g<N>`` form.

        Only a ``$`` followed by digits is rewritten. Blindly turning every
        ``$`` into a backslash would make ``$0`` an octal NUL escape instead
        of the whole match, and would corrupt (or fail on) any literal ``$``
        in the replacement text. ``\\g<N>`` is unambiguous for every N,
        including 0.
        """
        return cls._GROUP_REF.sub(lambda m: f"\\g<{m.group(1)}>", template)

    @staticmethod
    def _match(
        df: pl.DataFrame,
        field: str,
        output_field: str,
        compiled: re.Pattern[str],
        full_match: bool,
    ) -> pl.DataFrame:
        """Write a boolean match column named *output_field*; nulls stay null."""
        test = compiled.fullmatch if full_match else compiled.search

        def match_fn(val: str | None) -> bool | None:
            if val is None:
                return None
            return test(val) is not None

        series = df[field].cast(pl.String).map_elements(match_fn, return_dtype=pl.Boolean)
        return df.with_columns(series.alias(output_field))

    @staticmethod
    def _replace(
        df: pl.DataFrame,
        field: str,
        compiled: re.Pattern[str],
        template: str,
    ) -> pl.DataFrame:
        """Replace every match in *field* with *template*, overwriting the column."""

        def replace_fn(val: str | None) -> str | None:
            if val is None:
                return None
            return compiled.sub(template, val)

        series = df[field].cast(pl.String).map_elements(replace_fn, return_dtype=pl.String)
        return df.with_columns(series.alias(field))

    @staticmethod
    def _parse(
        df: pl.DataFrame,
        field: str,
        root_name: str,
        compiled: re.Pattern[str],
    ) -> pl.DataFrame:
        """Add columns ``<root_name>1..N`` holding the N capture groups.

        Rows whose value is null/empty or does not match get null in every
        new column. A pattern without capture groups adds no columns.
        """
        values = df[field].cast(pl.String).to_list()
        matches = [compiled.search(val) if val else None for val in values]
        new_columns = [
            pl.Series(
                f"{root_name}{group}",
                [m.group(group) if m else None for m in matches],
                # Keep a string dtype even when no row matched at all.
                dtype=pl.String,
            )
            for group in range(1, compiled.groups + 1)
        ]
        return df.with_columns(new_columns) if new_columns else df

    @staticmethod
    def _tokenize(
        df: pl.DataFrame,
        field: str,
        compiled: re.Pattern[str],
    ) -> pl.DataFrame:
        """Emit one row per piece produced by splitting *field* on the pattern.

        All other columns are copied onto every emitted row. Null or falsy
        values are treated as the empty string.
        """
        # NOTE(review): tokens are the text *between* matches (re.split), not
        # the matches themselves - confirm this is the intended semantics.
        rows = []
        for row in df.to_dicts():
            text = str(row.get(field) or "")
            rows.extend({**row, field: token} for token in compiled.split(text))
        return pl.DataFrame(rows) if rows else df