Pyteryx/alteryx_runner/tools/join/join_tool.py

208 lines
8.4 KiB
Python

from __future__ import annotations

import logging
from typing import Dict

import polars as pl

from tools.base import BaseTool
def _duckdb_to_polars(con, sql: str) -> pl.DataFrame:
    """Execute *sql* on the DuckDB connection *con* and return a Polars frame.

    The query result is fetched through Arrow (``.arrow()``) and converted
    with ``pl.from_arrow``.
    """
    result = con.execute(sql)
    arrow_result = result.arrow()
    return pl.from_arrow(arrow_result)
class JoinTool(BaseTool):
    """Alteryx-style Join tool implemented with DuckDB and Polars.

    Joins the ``Left`` and ``Right`` inputs either by record position
    (``JoinByRecordPos``) or on the field pairs listed in the ``JoinInfo``
    elements, and produces three outputs named after the Alteryx anchors:

    * ``Join``  -- rows matched on every key pair (inner join);
    * ``Left``  -- left rows without a match;
    * ``Right`` -- right rows without a match.

    Each output is then filtered/renamed by its ``SelectConfiguration``.
    """

    # Helper column added to both inputs for a join by record position.
    _POS_COL = "__pos__"
    # Prefix for right-side columns in the Join output: always applied to the
    # right join keys, and to right non-key columns whose name is already
    # used by a left column.
    _RIGHT_PREFIX = "Right_"

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Join the ``Left`` and ``Right`` inputs.

        Args:
            inputs: Input frames keyed by connection name. A missing ``Left``
                or ``Right`` entry is treated as an empty frame.

        Returns:
            A dict with the keys ``"Join"``, ``"Left"`` and ``"Right"``. When
            the tool has no configuration, no join keys, or join keys that do
            not exist in the inputs, ``"Join"`` is empty and the inputs are
            passed through unchanged as ``"Left"`` / ``"Right"``.
        """
        left = inputs.get("Left", pl.DataFrame())
        right = inputs.get("Right", pl.DataFrame())
        if self.config is None:
            # Same anchor names as every other return path.
            return {"Join": pl.DataFrame(), "Left": left, "Right": right}

        by_pos = self._cfg_attr("JoinByRecordPos", "value", "False") == "True"
        if by_pos:
            left = left.with_row_index(self._POS_COL)
            right = right.with_row_index(self._POS_COL)
            join_keys = [(self._POS_COL, self._POS_COL)]
        else:
            join_keys = self._parse_join_keys()
            if not join_keys:
                return {"Join": pl.DataFrame(), "Left": left, "Right": right}
            missing_left = [k for k, _ in join_keys if k not in left.columns]
            missing_right = [k for _, k in join_keys if k not in right.columns]
            if missing_left or missing_right:
                # Without this guard every DuckDB query would fail on the
                # unknown column and all three outputs would silently be empty.
                logging.getLogger(__name__).warning(
                    "Join tool: join field(s) missing (left: %s, right: %s); "
                    "passing inputs through unmatched",
                    missing_left,
                    missing_right,
                )
                return {"Join": pl.DataFrame(), "Left": left, "Right": right}

        left, right = self._harmonize_key_types(left, right, join_keys)
        j_df, l_df, r_df = self._execute_join(left, right, join_keys)

        if by_pos:
            # Drop the helper index and its Right_-aliased copy from the Join.
            helper_cols = {self._POS_COL, self._RIGHT_PREFIX + self._POS_COL}

            def _strip_helpers(frame: pl.DataFrame) -> pl.DataFrame:
                return frame.drop([c for c in frame.columns if c in helper_cols])

            j_df, l_df, r_df = _strip_helpers(j_df), _strip_helpers(l_df), _strip_helpers(r_df)

        # Use anchor names that match Alteryx connection names.
        return {
            "Join": self._apply_select_config(j_df, "Join"),
            "Left": self._apply_select_config(l_df, "Left"),
            "Right": self._apply_select_config(r_df, "Right"),
        }

    def _parse_join_keys(self) -> list[tuple[str, str]]:
        """Read the ``(left_field, right_field)`` key pairs from ``JoinInfo``.

        Expected structure::

            <JoinInfo connection="Left"><Field field="col1" /><Field field="col2" /></JoinInfo>
            <JoinInfo connection="Right"><Field field="colA" /><Field field="colB" /></JoinInfo>

        Fields are paired by position. If the two lists differ in length the
        surplus fields are ignored (``zip`` semantics) and a warning is logged.
        If a connection has several ``JoinInfo`` elements, the last one wins.
        """
        fields_by_conn: dict[str, list[str]] = {}
        for join_info in self.config.findall("JoinInfo"):
            conn = join_info.attrib.get("connection", "")
            fields_by_conn[conn] = [
                f.attrib.get("field", "") for f in join_info.findall("Field")
            ]
        left_fields = fields_by_conn.get("Left", [])
        right_fields = fields_by_conn.get("Right", [])
        if len(left_fields) != len(right_fields):
            logging.getLogger(__name__).warning(
                "Join tool: %d left vs %d right join fields; extra fields ignored",
                len(left_fields),
                len(right_fields),
            )
        return list(zip(left_fields, right_fields))

    @staticmethod
    def _needs_string_cast(l_dtype: pl.DataType, r_dtype: pl.DataType) -> bool:
        """Return True when a key pair should be compared as String.

        Pairs with identical numeric, temporal, String or Boolean dtypes are
        joined as-is, so the outputs keep their original column types.
        Mismatched dtypes (e.g. Int64 vs String from different sources) and
        any other dtype (Categorical, Null, nested, ...) are compared as String.
        """
        if l_dtype != r_dtype:
            return True
        return not (
            l_dtype.is_numeric()
            or l_dtype.is_temporal()
            or l_dtype in (pl.String, pl.Boolean)
        )

    @classmethod
    def _harmonize_key_types(
        cls,
        left: pl.DataFrame,
        right: pl.DataFrame,
        join_keys: list[tuple[str, str]],
    ) -> tuple[pl.DataFrame, pl.DataFrame]:
        """Cast join keys to String on both sides where types are not directly comparable.

        Repeats until no pair needs a cast, so that a column used in several
        key pairs ends up compatible with every partner. This terminates
        because each cast turns at least one non-String key column into String,
        and a String/String pair never needs a cast.
        """
        changed = True
        while changed:
            changed = False
            for l_key, r_key in join_keys:
                if cls._needs_string_cast(left.schema[l_key], right.schema[r_key]):
                    left = left.with_columns(pl.col(l_key).cast(pl.String))
                    right = right.with_columns(pl.col(r_key).cast(pl.String))
                    changed = True
        return left, right

    def _apply_select_config(self, df: pl.DataFrame, output_connection: str) -> pl.DataFrame:
        """Apply field selection, ordering and renaming from ``SelectConfiguration``.

        Uses the first ``SelectConfiguration/Configuration`` whose
        ``outputConnection`` equals *output_connection*, and its ``SelectFields``:

        * a ``SelectField`` with ``selected="False"`` drops that column;
        * a non-empty ``rename`` different from ``field`` renames the column;
        * columns not named by any ``SelectField`` are kept unless a
          ``*Unknown`` entry is present with ``selected="False"``;
        * with ``OrderChanged`` set to ``"True"`` the selected fields come
          first, in configuration order, followed by the kept unnamed columns;
          otherwise the frame's own column order is preserved.

        ``SelectField`` entries for columns that are not in *df* are ignored.
        *df* is returned unchanged when it has no columns or when no matching
        configuration exists. If the rules would remove every column, no
        selection is applied. Zero-row frames are processed like any other,
        so an output has the same schema whether or not it contains rows.
        """
        if df.width == 0 or self.config is None:
            return df
        select_config = self.config.find("SelectConfiguration")
        if select_config is None:
            return df
        cfg = next(
            (
                c
                for c in select_config.findall("Configuration")
                if c.attrib.get("outputConnection") == output_connection
            ),
            None,
        )
        if cfg is None:
            return df
        select_fields = cfg.find("SelectFields")
        if select_fields is None:
            return df
        order_el = cfg.find("OrderChanged")
        order_changed = (
            order_el is not None and order_el.attrib.get("value", "False") == "True"
        )

        available = set(df.columns)
        rename_map: dict[str, str] = {}  # source column -> output name
        excluded: set[str] = set()  # columns explicitly deselected
        selected: list[str] = []  # explicitly selected columns, config order
        keep_unknown = True  # *Unknown setting; unnamed columns kept by default
        for sf in select_fields.findall("SelectField"):
            field = sf.attrib.get("field", "")
            is_selected = sf.attrib.get("selected", "True") == "True"
            if field == "*Unknown":
                keep_unknown = is_selected
                continue
            if field not in available:
                continue
            if not is_selected:
                excluded.add(field)
                continue
            selected.append(field)
            rename = sf.attrib.get("rename", "")
            if rename and rename != field:
                rename_map[field] = rename

        named = set(selected) | excluded
        if order_changed:
            # Explicit selections first (configured order), then unnamed columns.
            final_cols = selected + [
                c for c in df.columns if keep_unknown and c not in named
            ]
        else:
            # Preserve the frame's original column order.
            selected_set = set(selected)
            final_cols = [
                c
                for c in df.columns
                if c not in excluded and (c in selected_set or keep_unknown)
            ]

        # If the rules would drop every column, keep the frame as is rather
        # than returning a column-less output.
        # NOTE(review): confirm this safety net matches Alteryx behaviour.
        if final_cols:
            df = df.select(final_cols)
        if rename_map:
            df = df.rename({k: v for k, v in rename_map.items() if k in df.columns})
        return df

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Quote *name* as a DuckDB identifier, escaping embedded double quotes."""
        return '"' + name.replace('"', '""') + '"'

    @staticmethod
    def _query_or_empty(con, sql: str, label: str) -> pl.DataFrame:
        """Run *sql* on *con*; on failure log the error and return an empty frame."""
        try:
            return _duckdb_to_polars(con, sql)
        except Exception:
            # Best effort: a failing query yields an empty output instead of
            # aborting the run, but the cause is no longer hidden.
            logging.getLogger(__name__).warning(
                "Join tool: %s query failed; emitting an empty frame",
                label,
                exc_info=True,
            )
            return pl.DataFrame()

    def _execute_join(
        self,
        left: pl.DataFrame,
        right: pl.DataFrame,
        join_keys: list[tuple[str, str]],
    ) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
        """Run the inner join and both unmatched queries in DuckDB.

        Returns ``(joined, left_unmatched, right_unmatched)``. The joined frame
        contains all left columns unprefixed, then the right join keys as
        ``Right_<key>``, then the right non-key columns. A right non-key column
        gets the ``Right_`` prefix only when a left column has the same name.
        The unmatched frames keep their original column names. Any query that
        fails yields an empty frame (and a logged warning).
        """
        q = self._quote_ident
        con = self.ctx.duckdb_con
        prefix = self._RIGHT_PREFIX

        # De-duplicate so a right key paired with several left keys is
        # selected (and aliased) only once.
        right_keys = list(dict.fromkeys(r_key for _, r_key in join_keys))
        right_key_set = set(right_keys)
        r_non_key = [c for c in right.columns if c not in right_key_set]
        # Compare against ALL left columns (keys included); otherwise a right
        # column named like a left key would duplicate that name.
        conflicts = set(left.columns) & set(r_non_key)

        on_clause = " AND ".join(f"l.{q(lk)} = r.{q(rk)}" for lk, rk in join_keys)
        select_parts = [f"l.{q(c)}" for c in left.columns]
        select_parts += [f"r.{q(rk)} AS {q(prefix + rk)}" for rk in right_keys]
        select_parts += [
            f"r.{q(c)} AS {q(prefix + c)}" if c in conflicts else f"r.{q(c)}"
            for c in r_non_key
        ]

        l_key0, r_key0 = join_keys[0]
        j_sql = (
            f"SELECT {', '.join(select_parts)} "
            f"FROM __join_left__ l INNER JOIN __join_right__ r ON {on_clause}"
        )
        # A LEFT JOIN row with a NULL right key had no match: equality in the
        # ON clause never matches NULL, so matched rows have non-NULL keys.
        l_sql = (
            f"SELECT l.* FROM __join_left__ l LEFT JOIN __join_right__ r "
            f"ON {on_clause} WHERE r.{q(r_key0)} IS NULL"
        )
        r_sql = (
            f"SELECT r.* FROM __join_right__ r LEFT JOIN __join_left__ l "
            f"ON {on_clause} WHERE l.{q(l_key0)} IS NULL"
        )

        try:
            con.register("__join_left__", left.to_arrow())
            con.register("__join_right__", right.to_arrow())
            j_df = self._query_or_empty(con, j_sql, "Join")
            l_df = self._query_or_empty(con, l_sql, "Left")
            r_df = self._query_or_empty(con, r_sql, "Right")
        finally:
            # Always release the temporary views, even if registration fails.
            con.execute("DROP VIEW IF EXISTS __join_left__")
            con.execute("DROP VIEW IF EXISTS __join_right__")
        return j_df, l_df, r_df