47 lines
1.8 KiB
Python
47 lines
1.8 KiB
Python
from __future__ import annotations
|
|
from typing import Dict
|
|
import polars as pl
|
|
from tools.base import BaseTool
|
|
|
|
|
|
class JoinMultipleTool(BaseTool):
    """Combine every connected input frame into a single ``Output`` frame.

    Each input's columns are renamed with an ``Input_#N_`` prefix (``N`` is the
    1-based position of the input after ordering), so column names never clash
    between inputs.

    Configuration values read through ``self._cfg_attr``:

    * ``JoinByRecPos`` (default ``"True"``): align the inputs row by row.
    * ``OutputJoinOnly`` (default ``"False"``): when joining by position, keep
      only the rows present in *every* input; otherwise keep all rows and fill
      the shorter inputs with nulls.

    NOTE(review): the key names look like the Alteryx "Join Multiple" tool
    settings, where an unchecked "output only records that join" means a full
    outer join -- confirm against the workflow format this converter targets.
    """

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Join all frames in *inputs* and return them under the ``"Output"`` key.

        Args:
            inputs: Mapping of input anchor name to frame. Names are ordered by
                the integer left after removing ``"Input"`` and leading ``#``
                characters; names that do not yield an integer are placed
                after the numbered ones, in alphabetical order.

        Returns:
            A one-entry dict ``{"Output": frame}``. The frame is empty when the
            tool has no configuration or no inputs.
        """
        if self.config is None:
            return {"Output": pl.DataFrame()}

        ordered = [inputs[name] for name in sorted(inputs, key=self._input_sort_key)]
        if not ordered:
            return {"Output": pl.DataFrame()}

        # NOTE(review): _cfg_attr is defined outside this block; presumably it
        # returns the "value" attribute of the named config element, or the
        # third argument when the element is missing -- verify in BaseTool.
        by_pos = self._cfg_attr("JoinByRecPos", "value", "True") == "True"
        output_join_only = self._cfg_attr("OutputJoinOnly", "value", "False") == "True"

        # Prefix all columns with Input_#N_ so names are unique across inputs.
        prefixed = [
            df.rename({col: f"Input_#{i}_{col}" for col in df.columns})
            for i, df in enumerate(ordered, start=1)
        ]

        if by_pos:
            result = self._join_by_position(prefixed, inner=output_join_only)
        else:
            # Join by named fields is not implemented; as a best effort the
            # inputs are stacked diagonally. The prefixed names never overlap,
            # so one concat over all frames equals concatenating pair by pair.
            result = pl.concat(prefixed, how="diagonal_relaxed")

        return {"Output": result}

    @staticmethod
    def _input_sort_key(name: str) -> tuple[int, int, str]:
        """Return a sort key that orders input names by their numeric suffix."""
        digits = name.replace("Input", "").lstrip("#") or "0"
        try:
            return (0, int(digits), "")
        except ValueError:
            # An unexpected anchor name must not abort the whole tool; sort it
            # after the numbered inputs instead.
            return (1, 0, name)

    @staticmethod
    def _join_by_position(frames: list[pl.DataFrame], inner: bool) -> pl.DataFrame:
        """Place *frames* side by side, matching rows by their position.

        With ``inner=True`` every frame is cut to the shortest height; with
        ``inner=False`` shorter frames are extended with null rows up to the
        tallest height. Rows keep their original order in both modes.
        """
        heights = [frame.height for frame in frames]
        target = min(heights) if inner else max(heights)

        aligned = []
        for frame in frames:
            if frame.width == 0:
                # A frame without columns adds nothing to the output; its
                # height has already been counted in ``target`` above.
                continue
            if frame.height > target:
                frame = frame.head(target)
            elif frame.height < target:
                # clear(n) yields n all-null rows with the same schema.
                padding = frame.clear(target - frame.height)
                frame = pl.concat([frame, padding], how="vertical")
            aligned.append(frame)

        if not aligned:
            return pl.DataFrame()
        return pl.concat(aligned, how="horizontal")
|