Pyteryx/alteryx_runner/tools/join/join_tool.py

207 lines
8.6 KiB
Python

from __future__ import annotations

import logging
from typing import Dict

import polars as pl

from tools.base import BaseTool
def _duckdb_to_polars(con, sql: str) -> pl.DataFrame:
    """Execute *sql* on the DuckDB connection *con* and return a Polars frame.

    The result set is fetched as Arrow via ``.arrow()`` and converted with
    ``pl.from_arrow``.
    """
    result = con.execute(sql)
    arrow_table = result.arrow()
    return pl.from_arrow(arrow_table)
class JoinTool(BaseTool):
    """Alteryx Join tool: split two inputs into Join / Left / Right outputs.

    Outputs use the Alteryx anchor names ``"Join"`` (matched rows), ``"Left"``
    (left rows without a match) and ``"Right"`` (right rows without a match).
    Rows are matched on the field pairs from the ``JoinInfo`` elements or, when
    ``JoinByRecordPos`` is ``"True"``, on record position. The join itself
    runs in DuckDB on ``self.ctx.duckdb_con``.
    """

    # Temporary column used to join by record position; removed from outputs.
    _POS_COL = "__pos__"
    # Prefixes given to non-key columns so both sides can coexist in one frame.
    _SIDE_PREFIXES = ("Left_", "Right_")
    # Names under which the two inputs are registered on the DuckDB connection.
    _LEFT_VIEW = "__join_left__"
    _RIGHT_VIEW = "__join_right__"

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Join the ``"Left"`` and ``"Right"`` inputs.

        Args:
            inputs: Input frames by anchor name; a missing anchor is treated
                as an empty frame.

        Returns:
            ``{"Join": ..., "Left": ..., "Right": ...}``. Without a tool
            configuration, or when no join fields are configured, ``"Join"``
            is empty and both inputs are passed through unchanged.
        """
        left = inputs.get("Left", pl.DataFrame())
        right = inputs.get("Right", pl.DataFrame())
        if self.config is None:
            # Same anchor names as every other return path of this tool.
            return {"Join": pl.DataFrame(), "Left": left, "Right": right}

        by_pos = self._cfg_attr("JoinByRecordPos", "value", "False") == "True"
        if by_pos:
            pos = self._POS_COL
            left = left.with_row_index(pos)
            right = right.with_row_index(pos)
            join_keys = [(pos, pos)]
        else:
            join_keys = self._parse_join_keys()
            if not join_keys:
                return {"Join": pl.DataFrame(), "Left": left, "Right": right}
            # Compare keys as strings so keys of different types coming from
            # different input sources (e.g. Int64 vs String) can still match.
            left = self._cast_to_string(left, [lk for lk, _ in join_keys])
            right = self._cast_to_string(right, [rk for _, rk in join_keys])

        j_df, l_df, r_df = self._execute_join(left, right, join_keys)
        if by_pos:
            j_df = self._strip_position(j_df)
            l_df = self._strip_position(l_df)
            r_df = self._strip_position(r_df)

        # Apply SelectConfiguration for each output.
        j_df = self._apply_select_config(j_df, "Join")
        l_df = self._apply_select_config(l_df, "Left")
        r_df = self._apply_select_config(r_df, "Right")
        # Use anchor names that match Alteryx connection names.
        return {"Join": j_df, "Left": l_df, "Right": r_df}

    def _parse_join_keys(self) -> list[tuple[str, str]]:
        """Pair the ``Field`` names of the Left and Right ``JoinInfo`` elements.

        Expected structure:
        ``<JoinInfo connection="Left"><Field field="col1" /></JoinInfo>``.
        If several ``JoinInfo`` elements share a connection, the last one wins.
        Fields are paired by position; surplus fields on the longer side are
        ignored (and a warning is logged).
        """
        fields_by_side: dict[str, list[str]] = {}
        for join_info in self.config.findall("JoinInfo"):
            side = join_info.attrib.get("connection", "")
            fields_by_side[side] = [
                f.attrib.get("field", "") for f in join_info.findall("Field")
            ]
        left_fields = fields_by_side.get("Left", [])
        right_fields = fields_by_side.get("Right", [])
        if len(left_fields) != len(right_fields):
            logging.getLogger(__name__).warning(
                "Join tool: %d left join field(s) vs %d right join field(s); "
                "unpaired fields are ignored",
                len(left_fields),
                len(right_fields),
            )
        return list(zip(left_fields, right_fields))

    @staticmethod
    def _cast_to_string(df: pl.DataFrame, columns: list[str]) -> pl.DataFrame:
        """Cast each of *columns* that exists in *df* to ``pl.String``."""
        present = [c for c in dict.fromkeys(columns) if c in df.columns]
        if not present:
            return df
        return df.with_columns([pl.col(c).cast(pl.String) for c in present])

    def _strip_position(self, df: pl.DataFrame) -> pl.DataFrame:
        """Restore record order and drop the positional-join helper columns.

        The DuckDB join does not promise to keep input order, so rows are
        sorted by the numeric position first. Both the left/unprefixed helper
        (``__pos__``) and the joined right-key copy (``Right___pos__``) are
        removed.
        """
        pos = self._POS_COL
        if pos in df.columns:
            df = df.sort(pos)
        helpers = [c for c in (pos, f"Right_{pos}") if c in df.columns]
        return df.drop(helpers) if helpers else df

    def _apply_select_config(self, df: pl.DataFrame, output_connection: str) -> pl.DataFrame:
        """Apply field selection and renaming from SelectConfiguration.

        Uses the first ``SelectConfiguration/Configuration`` whose
        ``outputConnection`` equals *output_connection*. *df* is returned
        unchanged when there is no tool config, no matching Configuration,
        or the Configuration has no ``SelectFields``.
        """
        # Zero-row frames are still selected/renamed so an empty output has
        # the same schema as a populated one; only column-less frames skip.
        if df.width == 0 or self.config is None:
            return df
        select_config = self.config.find("SelectConfiguration")
        if select_config is None:
            return df
        for cfg in select_config.findall("Configuration"):
            if cfg.attrib.get("outputConnection") == output_connection:
                select_fields = cfg.find("SelectFields")
                if select_fields is None:
                    return df
                return self._select_fields(df, select_fields)
        return df

    def _select_fields(self, df: pl.DataFrame, select_fields) -> pl.DataFrame:
        """Select and rename *df*'s columns according to ``SelectField`` entries.

        Each entry resolves to a column of *df*: first ``<input><field>``,
        then the bare ``field``; entries that match no column are ignored.
        Selected entries are output in config order, named ``rename`` (or
        ``field`` when ``rename`` is empty). If ``*Unknown`` is selected,
        every column not named by any entry -- selected or deselected -- is
        appended with its ``Left_``/``Right_`` prefix stripped, unless the
        stripped name is already taken, in which case the prefixed name is
        kept. If nothing ends up selected, *df* is returned unchanged.
        """
        columns = set(df.columns)
        explicit: list[tuple[str, str]] = []  # (source column, output name)
        listed: set[str] = set()  # every column named by an entry
        include_unknown = False
        for sf in select_fields.findall("SelectField"):
            field = sf.attrib.get("field", "")
            selected = sf.attrib.get("selected", "False") == "True"
            if field == "*Unknown":
                include_unknown = include_unknown or selected
                continue
            input_prefix = sf.attrib.get("input", "")
            prefixed = f"{input_prefix}{field}" if input_prefix else field
            if prefixed in columns:
                src = prefixed
            elif field in columns:
                src = field
            else:
                continue
            # Deselected entries are recorded too so *Unknown cannot re-add them.
            listed.add(src)
            if selected:
                explicit.append((src, sf.attrib.get("rename", "") or field))

        selected_cols = [src for src, _ in explicit]
        rename_map = {src: dst for src, dst in explicit if src != dst}

        if include_unknown:
            unknown = [c for c in df.columns if c not in listed]
            taken = {dst for _, dst in explicit}
            # Unprefixed columns keep their names, so reserve those first.
            taken.update(c for c in unknown if not c.startswith(self._SIDE_PREFIXES))
            for col in unknown:
                out = self._strip_side_prefix(col)
                if out != col:
                    if out in taken:
                        # e.g. "Right_id" while the left key "id" is also
                        # output: keep the prefix rather than duplicate a name.
                        out = col
                    else:
                        rename_map[col] = out
                taken.add(out)
                selected_cols.append(col)

        if selected_cols:
            df = df.select(selected_cols)
        if rename_map:
            df = df.rename(rename_map)
        return df

    @classmethod
    def _strip_side_prefix(cls, col: str) -> str:
        """Return *col* without a leading ``Left_``/``Right_`` prefix."""
        for prefix in cls._SIDE_PREFIXES:
            if col.startswith(prefix):
                return col[len(prefix):]
        return col

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Quote *name* as a SQL identifier, doubling embedded double quotes."""
        return '"' + name.replace('"', '""') + '"'

    @staticmethod
    def _query_or_empty(con, sql: str, label: str) -> pl.DataFrame:
        """Run *sql*; on failure log the error and return an empty frame.

        The tool deliberately degrades to an empty output instead of raising,
        but the failure is now logged instead of silently discarded.
        """
        try:
            return _duckdb_to_polars(con, sql)
        except Exception:
            logging.getLogger(__name__).exception(
                "Join tool: %s query failed; emitting an empty output", label
            )
            return pl.DataFrame()

    def _execute_join(
        self,
        left: pl.DataFrame,
        right: pl.DataFrame,
        join_keys: list[tuple[str, str]],
    ) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
        """Run the join in DuckDB; return ``(joined, left_only, right_only)``.

        Non-key columns are prefixed ``Left_`` / ``Right_`` (the names the
        SelectConfiguration entries refer to). Left key columns keep their
        names; the joined output additionally carries each distinct right key
        once as ``Right_<key>``. *join_keys* must be non-empty.
        """
        con = self.ctx.duckdb_con
        q = self._quote_ident

        key_l = {lk for lk, _ in join_keys}
        key_r = {rk for _, rk in join_keys}
        rename_l = {c: f"Left_{c}" for c in left.columns if c not in key_l}
        rename_r = {c: f"Right_{c}" for c in right.columns if c not in key_r}
        left_r = left.rename(rename_l) if rename_l else left
        right_r = right.rename(rename_r) if rename_r else right

        lv, rv = self._LEFT_VIEW, self._RIGHT_VIEW
        on_clause = " AND ".join(f"l.{q(lk)} = r.{q(rk)}" for lk, rk in join_keys)
        # Each distinct right key once (a repeated key would duplicate the
        # alias), followed by the prefixed right non-key columns.
        right_keys = list(dict.fromkeys(rk for _, rk in join_keys))
        right_cols = [f"r.{q(rk)} AS {q('Right_' + rk)}" for rk in right_keys]
        right_cols.extend(f"r.{q(c)}" for c in rename_r.values())

        l_key0, r_key0 = join_keys[0]
        j_sql = (
            f"SELECT l.*, {', '.join(right_cols)} "
            f"FROM {lv} l INNER JOIN {rv} r ON {on_clause}"
        )
        # After an outer join the other side's key is NULL exactly for the
        # unmatched rows, since a matched row satisfied "=" on that key.
        l_sql = (
            f"SELECT l.* FROM {lv} l LEFT JOIN {rv} r ON {on_clause} "
            f"WHERE r.{q(r_key0)} IS NULL"
        )
        r_sql = (
            f"SELECT r.* FROM {rv} r LEFT JOIN {lv} l ON {on_clause} "
            f"WHERE l.{q(l_key0)} IS NULL"
        )

        try:
            con.register(lv, left_r.to_arrow())
            con.register(rv, right_r.to_arrow())
            j_df = self._query_or_empty(con, j_sql, "Join")
            l_df = self._query_or_empty(con, l_sql, "Left")
            r_df = self._query_or_empty(con, r_sql, "Right")
        finally:
            # Release the registrations even if registering or a query raised.
            for view in (lv, rv):
                con.execute(f"DROP VIEW IF EXISTS {view}")
        return j_df, l_df, r_df