77 lines
2.6 KiB
Python
77 lines
2.6 KiB
Python
from __future__ import annotations
|
|
from typing import Dict
|
|
import polars as pl
|
|
from tools.base import BaseTool
|
|
|
|
|
|
class SelectTool(BaseTool):
    """Keep, drop, reorder, rename and re-type columns of the ``"Input"`` frame.

    The behaviour is driven by ``self.config``, an XML-element-like object
    (it is queried with ``findall`` and ``attrib``):

    * ``SelectFields/SelectField`` children each describe one column through
      the attributes ``field``, ``selected``, ``rename``, ``type`` and ``size``.
    * A child whose ``field`` is ``"*Unknown"`` controls whether columns that
      are not listed explicitly are passed through.
    * The root attribute ``OrderChanged`` (``"True"``/``"False"``) chooses
      between config order and incoming column order for listed columns.
    """

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Apply the configured column selection to ``inputs["Input"]``.

        Args:
            inputs: Mapping of input names to frames; only ``"Input"`` is read.
                A missing ``"Input"`` is treated as an empty frame.

        Returns:
            ``{"Output": frame}``. The input frame is returned unchanged when
            there is no config or the input has no columns. An empty
            ``pl.DataFrame()`` is returned when nothing ends up selected.
        """
        # Function-scope import so the block does not depend on a change to
        # the file's import section.
        import logging

        df = inputs.get("Input", pl.DataFrame())

        # Only short-circuit when there is nothing to select from.  A frame
        # with columns but zero rows must still be processed, otherwise the
        # output schema (dropped/renamed/cast columns) would depend on whether
        # any rows happened to arrive.
        if self.config is None or not df.columns:
            return {"Output": df}

        select_fields = self.config.findall("SelectFields/SelectField")
        order_changed = self.config.attrib.get("OrderChanged", "False") == "True"

        keep_unknown = True
        # Insertion-ordered: the keys double as the de-duplicated config order.
        # If a field is listed more than once, the last entry's settings win.
        explicit: dict[str, dict] = {}

        for sf in select_fields:
            name = sf.attrib.get("field", "")
            selected = sf.attrib.get("selected", "True") == "True"
            if name == "*Unknown":
                keep_unknown = selected
                continue
            explicit[name] = {
                "selected": selected,
                # An absent or empty ``rename`` keeps the original name.
                "rename": sf.attrib.get("rename") or name,
                "type": sf.attrib.get("type"),
                "size": sf.attrib.get("size"),
            }

        if order_changed:
            # Follow config order for explicitly listed columns.  Iterating
            # the dict keys (rather than the raw field list) avoids emitting
            # the same column twice when the config repeats a field.
            present = set(df.columns)
            ordered_names = [name for name in explicit if name in present]
        else:
            # Follow incoming column order.
            ordered_names = [c for c in df.columns if c in explicit]

        result_exprs: list[pl.Expr] = []
        # Every explicitly configured column is "processed", selected or not,
        # so that deselected columns are not re-added by the *Unknown pass.
        processed: set[str] = set(ordered_names)

        for name in ordered_names:
            meta = explicit[name]
            if not meta["selected"]:
                continue

            col = pl.col(name)
            if meta["type"]:
                try:
                    target_dtype = self.ctx.type_mapper.map(meta["type"], meta["size"])
                    col = col.cast(target_dtype)
                except Exception:
                    # Best effort: the mapper's failure modes are not visible
                    # here, so keep the column un-cast rather than failing the
                    # whole tool -- but leave a trace instead of hiding it.
                    logging.getLogger(__name__).warning(
                        "SelectTool: could not map type %r (size %r) for "
                        "column %r; leaving column uncast",
                        meta["type"],
                        meta["size"],
                        name,
                        exc_info=True,
                    )
            result_exprs.append(col.alias(meta["rename"]))

        if keep_unknown:
            # Columns not mentioned in the config are appended after the
            # explicit ones, in their incoming order.
            result_exprs.extend(pl.col(c) for c in df.columns if c not in processed)

        if not result_exprs:
            return {"Output": pl.DataFrame()}

        return {"Output": df.select(result_exprs)}
|