86 lines
3.2 KiB
Python
86 lines
3.2 KiB
Python
from __future__ import annotations
|
|
import re
|
|
from typing import Dict
|
|
import polars as pl
|
|
from tools.base import BaseTool
|
|
|
|
|
|
class RegExTool(BaseTool):
    """Apply a regular expression to a single column of the ``"Input"`` frame.

    Configuration read through ``self._cfg`` / ``self._cfg_attr``:

    * ``Field`` -- input column to operate on.
    * ``Expression`` -- the regular expression (compiled with Python ``re``).
    * ``Method`` -- ``"Match"`` (default), ``"Replace"``, ``"Parse"`` or
      ``"Token"``; any other value passes the data through unchanged.
    * ``OutputField`` -- target column (Match) or column-name root (Parse);
      defaults to ``Field``.
    * ``FullMatch`` / ``CaseInsensitive`` -- their ``value`` attribute is
      treated as true when it equals ``"true"`` case-insensitively.
    * ``Replace/ReplaceString`` -- replacement template used by ``"Replace"``.
    """

    # Perl-style references in a replacement template: $N, ${N}, $& and $$.
    _PERL_REF_RE = re.compile(r"\$(?:(\d+)|\{(\d+)\}|(&)|(\$))")

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Run the configured regex method and return ``{"Output": frame}``.

        The input frame is returned untouched when there is no config, the frame
        is empty, ``Field``/``Expression`` is missing, or ``Field`` is not a
        column. An invalid ``Expression`` raises ``re.error`` from ``re.compile``.
        """
        df = inputs.get("Input", pl.DataFrame())
        if self.config is None or df.is_empty():
            return {"Output": df}

        field = self._cfg("Field", "") or ""
        pattern = self._cfg("Expression", "") or ""
        method = self._cfg("Method", "Match") or "Match"
        output_field = self._cfg("OutputField", field) or field
        full_match = (
            self._cfg_attr("FullMatch", "value", "False") or "False"
        ).lower() == "true"
        case_insensitive = (
            self._cfg_attr("CaseInsensitive", "value", "False") or "False"
        ).lower() == "true"

        if not field or not pattern or field not in df.columns:
            return {"Output": df}

        compiled = re.compile(pattern, re.IGNORECASE if case_insensitive else 0)

        if method == "Match":
            df = self._match(df, field, output_field, compiled, full_match)
        elif method == "Replace":
            replace_string = self._cfg("Replace/ReplaceString", "") or ""
            df = self._replace(df, field, compiled, replace_string)
        elif method == "Parse":
            df = self._parse(df, field, output_field or field, compiled)
        elif method == "Token":
            df = self._token(df, field, compiled)

        return {"Output": df}

    @staticmethod
    def _match(
        df: pl.DataFrame,
        field: str,
        output_field: str,
        compiled: re.Pattern[str],
        full_match: bool,
    ) -> pl.DataFrame:
        """Write a Boolean column telling whether each value matches ``compiled``."""
        # FullMatch requires the whole string to match; otherwise any substring hit counts.
        test = compiled.fullmatch if full_match else compiled.search

        def match_fn(val: str | None) -> bool | None:
            if val is None:
                return None
            return test(val) is not None

        series = (
            df[field]
            .cast(pl.String)
            .map_elements(match_fn, return_dtype=pl.Boolean)
        )
        return df.with_columns(series.alias(output_field))

    @classmethod
    def _perl_to_python_template(cls, template: str) -> str:
        """Translate Perl-style ``$N``/``${N}``/``$&``/``$$`` into Python ``re`` syntax.

        Other ``$`` characters are left alone; Python treats them literally.
        Backslashes in ``template`` are passed through to ``re`` unchanged.
        """

        def convert(m: re.Match[str]) -> str:
            number = m.group(1) or m.group(2)
            if number is not None:
                # \g<N> is unambiguous (e.g. $10 vs $1 followed by "0"), and
                # \g<0> is the whole match, whereas "\0" would be an octal NUL.
                return f"\\g<{number}>"
            if m.group(3):
                return "\\g<0>"
            return "$"  # "$$" is an escaped literal dollar sign

        # A callable replacement's return value is inserted literally.
        return cls._PERL_REF_RE.sub(convert, template)

    @classmethod
    def _replace(
        cls,
        df: pl.DataFrame,
        field: str,
        compiled: re.Pattern[str],
        replace_string: str,
    ) -> pl.DataFrame:
        """Replace every match in ``field`` (in place) using ``replace_string``."""
        template = cls._perl_to_python_template(replace_string)

        def replace_fn(val: str | None) -> str | None:
            if val is None:
                return None
            return compiled.sub(template, val)

        series = (
            df[field]
            .cast(pl.String)
            .map_elements(replace_fn, return_dtype=pl.String)
        )
        # NOTE(review): Replace overwrites the source column and ignores
        # OutputField -- preserved from the original behavior; confirm intended.
        return df.with_columns(series.alias(field))

    @staticmethod
    def _parse(
        df: pl.DataFrame,
        field: str,
        root_name: str,
        compiled: re.Pattern[str],
    ) -> pl.DataFrame:
        """Add one String column per capture group, named ``<root_name>1..N``."""
        group_count = compiled.groups
        if group_count == 0:
            return df

        empty_groups = (None,) * group_count
        columns: list[list[str | None]] = [[] for _ in range(group_count)]
        for val in df[field].cast(pl.String).to_list():
            # Null and empty strings are not searched; all their groups are None.
            m = compiled.search(val) if val else None
            for column, group in zip(columns, m.groups() if m else empty_groups):
                column.append(group)

        # Explicit dtype so a group that never matches is String, not Null.
        return df.with_columns(
            [
                pl.Series(f"{root_name}{index}", values, dtype=pl.String)
                for index, values in enumerate(columns, start=1)
            ]
        )

    @staticmethod
    def _token(
        df: pl.DataFrame,
        field: str,
        compiled: re.Pattern[str],
    ) -> pl.DataFrame:
        """Split ``field`` on ``compiled`` and emit one output row per piece.

        Other columns are repeated for each piece and keep their dtypes; the
        split column becomes String. Nulls are split as the empty string.
        """
        pieces = [
            # Only None maps to ""; falsy values such as 0 or False keep their text.
            compiled.split("" if val is None else str(val))
            for val in df[field].to_list()
        ]
        exploded = pl.Series(field, pieces, dtype=pl.List(pl.String))
        return df.with_columns(exploded).explode(field)
|