Pyteryx/alteryx_runner/tools/join/join_multiple.py

47 lines
1.8 KiB
Python

from __future__ import annotations

import re
from typing import Dict

import polars as pl

from tools.base import BaseTool
class JoinMultipleTool(BaseTool):
    """Combine any number of input streams into a single "Output" frame.

    Inputs are ordered by the number embedded in their connection key, and
    every column of the N-th input (1-based, in that order) is renamed to
    ``Input_#N_<column>`` so columns coming from different inputs never
    collide.
    """

    @staticmethod
    def _input_number(key: str) -> int:
        """Return the connection number embedded in an input key.

        Accepts ``"#2"``, ``"Input2"``, ``"Input#2"`` as well as decorated
        forms such as ``"Input #2"`` / ``"Input_#2"`` (the last run of digits
        wins). Keys containing no digits map to 0.
        """
        digits = re.findall(r"\d+", key)
        return int(digits[-1]) if digits else 0

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Combine all connected inputs.

        Args:
            inputs: Mapping of connection key -> input frame. Keys are
                ordered by their embedded number (see ``_input_number``).

        Returns:
            ``{"Output": frame}``. The frame is empty when there is no
            configuration or no inputs.

        Modes (read via ``self._cfg_attr``, presumably the ``value``
        attribute of the named config element):
            * ``JoinByRecPos`` (default "True"): row k of every input is
              placed side by side. If ``OutputJoinOnly`` is "True", only
              positions present in *every* input are kept (truncate to
              the shortest input); otherwise all positions are output and
              shorter inputs are padded with nulls.
            * otherwise: NOTE(review) a key-based join is not implemented.
              The prefixed inputs are stacked vertically ("diagonal"
              concat). Because their column sets are disjoint, each input's
              rows carry nulls in every other input's columns.
        """
        if self.config is None:
            return {"Output": pl.DataFrame()}

        dfs = [inputs[k] for k in sorted(inputs, key=self._input_number)]
        if not dfs:
            return {"Output": pl.DataFrame()}

        by_pos = self._cfg_attr("JoinByRecPos", "value", "True") == "True"
        output_join_only = self._cfg_attr("OutputJoinOnly", "value", "False") == "True"

        # Prefix all columns with Input_#N_ so the inputs' column sets are disjoint.
        prefixed: list[pl.DataFrame] = [
            df.rename({c: f"Input_#{i}_{c}" for c in df.columns})
            for i, df in enumerate(dfs, start=1)
        ]

        if by_pos:
            if output_join_only:
                # Only positions that exist in every input "join".
                shortest = min(df.height for df in prefixed)
                prefixed = [df.head(shortest) for df in prefixed]
            # Horizontal concat aligns rows by position, keeps row order,
            # and null-pads inputs shorter than the longest one.
            result = pl.concat(prefixed, how="horizontal")
        else:
            # TODO: implement a real join on the configured key fields; for now
            # inputs are stacked (disjoint columns -> block-diagonal result).
            result = pl.concat(prefixed, how="diagonal_relaxed")
        return {"Output": result}