from __future__ import annotations

import logging
from typing import Dict

import polars as pl

from tools.base import BaseTool

logger = logging.getLogger(__name__)

# Synthetic key column added to both inputs when joining by record position.
_POS_COL = "__pos__"

# Prefixes that `_execute_join` puts on non-key columns of each side.
_SIDE_PREFIXES = ("Left_", "Right_")


def _quote_ident(name: str) -> str:
    """Return *name* as a double-quoted SQL identifier, doubling embedded quotes."""
    return '"' + name.replace('"', '""') + '"'


def _strip_side_prefix(name: str) -> str:
    """Return *name* without a leading ``Left_``/``Right_`` prefix, if it has one."""
    for prefix in _SIDE_PREFIXES:
        if name.startswith(prefix):
            return name[len(prefix):]
    return name


def _duckdb_to_polars(con, sql: str) -> pl.DataFrame:
    """Execute *sql* on the DuckDB connection and return the result as a Polars frame."""
    return pl.from_arrow(con.execute(sql).arrow())


class JoinTool(BaseTool):
    """Join the "Left" and "Right" inputs and emit "Join", "Left" and "Right" outputs.

    "Join" holds the matched rows; "Left" and "Right" hold the rows of each
    input that found no match on the other side.
    """

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Run the join described by ``self.config``.

        Args:
            inputs: Frames keyed by input connection name; "Left" and "Right"
                are read, and a missing one is treated as an empty frame.

        Returns:
            A dict with the keys "Join", "Left" and "Right". Without a
            configuration, or without any join keys, "Join" is an empty frame
            and the inputs are passed through unchanged.
        """
        left = inputs.get("Left", pl.DataFrame())
        right = inputs.get("Right", pl.DataFrame())

        if self.config is None:
            # Same output names as every other return path of this method.
            return {"Join": pl.DataFrame(), "Left": left, "Right": right}

        by_pos = self._cfg_attr("JoinByRecordPos", "value", "False") == "True"

        if by_pos:
            left = left.with_row_index(_POS_COL)
            right = right.with_row_index(_POS_COL)
            join_keys = [(_POS_COL, _POS_COL)]
        else:
            # Join keys come from JoinInfo elements: one with
            # connection="Left" and one with connection="Right", each holding
            # Field children whose "field" attribute names a key column.
            left_fields: list[str] = []
            right_fields: list[str] = []
            for ji in self.config.findall("JoinInfo"):
                conn = ji.attrib.get("connection", "")
                fields = [f.attrib.get("field", "") for f in ji.findall("Field")]
                if conn == "Left":
                    left_fields = fields
                elif conn == "Right":
                    right_fields = fields
            # Pair up left and right fields by position.
            join_keys = list(zip(left_fields, right_fields))

        if not join_keys:
            return {"Join": pl.DataFrame(), "Left": left, "Right": right}

        # Cast join key columns to String on both sides to handle type
        # mismatches (e.g. Int64 vs String from different input sources).
        for l_key, r_key in join_keys:
            if l_key in left.columns:
                left = left.with_columns(pl.col(l_key).cast(pl.String))
            if r_key in right.columns:
                right = right.with_columns(pl.col(r_key).cast(pl.String))

        j_df, l_df, r_df = self._execute_join(left, right, join_keys)

        if by_pos:
            # The join also emits the right key aliased with a "Right_" prefix,
            # so both spellings of the synthetic position column must go.
            helper_cols = (_POS_COL, f"Right_{_POS_COL}")
            cleaned = []
            for frame in (j_df, l_df, r_df):
                present = [c for c in helper_cols if c in frame.columns]
                cleaned.append(frame.drop(present) if present else frame)
            j_df, l_df, r_df = cleaned

        # Apply SelectConfiguration for each output.
        j_df = self._apply_select_config(j_df, "Join")
        l_df = self._apply_select_config(l_df, "Left")
        r_df = self._apply_select_config(r_df, "Right")

        return {"Join": j_df, "Left": l_df, "Right": r_df}

    def _apply_select_config(self, df: pl.DataFrame, output_connection: str) -> pl.DataFrame:
        """Apply field selection and renaming from SelectConfiguration.

        Args:
            df: Frame produced for one output of the join.
            output_connection: Value matched against the ``outputConnection``
                attribute of the Configuration elements.

        Returns:
            ``df`` restricted to the selected columns and renamed, or ``df``
            unchanged when there is no applicable configuration or nothing is
            selected.
        """
        # Check for "no columns" rather than "no rows", so a zero-row result
        # gets the same schema as a populated one.
        if self.config is None or not df.columns:
            return df

        select_config = self.config.find("SelectConfiguration")
        if select_config is None:
            return df

        # Only the first Configuration for this output connection is used.
        cfg = next(
            (
                c
                for c in select_config.findall("Configuration")
                if c.attrib.get("outputConnection") == output_connection
            ),
            None,
        )
        if cfg is None:
            return df
        select_fields = cfg.find("SelectFields")
        if select_fields is None:
            return df

        available = set(df.columns)
        explicit_selections: list[tuple[str, str]] = []  # (src_col, output_name)
        listed: set[str] = set()  # every column named by a SelectField
        has_unknown = False

        for sf in select_fields.findall("SelectField"):
            field = sf.attrib.get("field", "")
            is_selected = sf.attrib.get("selected", "False") == "True"

            if field == "*Unknown":
                has_unknown = has_unknown or is_selected
                continue

            # Prefer the column carrying the "input" prefix, then the bare name.
            prefixed = f"{sf.attrib.get('input', '')}{field}"
            if prefixed in available:
                src_col = prefixed
            elif field in available:
                src_col = field
            else:
                continue

            if src_col in listed:
                continue
            # Record deselected fields too, so "*Unknown" does not re-add them.
            listed.add(src_col)
            if is_selected:
                output_name = sf.attrib.get("rename", "") or field
                explicit_selections.append((src_col, output_name))

        selected_cols = [src for src, _ in explicit_selections]
        rename_map = {src: dst for src, dst in explicit_selections if src != dst}

        # *Unknown: include every column no SelectField mentioned, stripping
        # the Left_/Right_ prefix when that does not collide with another
        # output name.
        if has_unknown:
            unknown_cols = [c for c in df.columns if c not in listed]
            taken = {dst for _, dst in explicit_selections}
            taken.update(c for c in unknown_cols if not c.startswith(_SIDE_PREFIXES))
            for col in unknown_cols:
                output_name = _strip_side_prefix(col)
                if output_name != col:
                    if not output_name or output_name in taken:
                        # Keep the prefixed name rather than create a duplicate.
                        output_name = col
                    else:
                        taken.add(output_name)
                selected_cols.append(col)
                if col != output_name:
                    rename_map[col] = output_name

        if selected_cols:
            df = df.select(selected_cols)
            if rename_map:
                df = df.rename(rename_map)
        return df

    @staticmethod
    def _query_or_empty(con, sql: str, output_name: str) -> pl.DataFrame:
        """Run *sql*; on failure log a warning and return an empty frame."""
        try:
            return _duckdb_to_polars(con, sql)
        except Exception:
            # Best effort by design: one failing output must not abort the
            # tool, but the failure should be visible in the logs.
            logger.warning(
                "JoinTool: query for the %s output failed; returning an empty frame",
                output_name,
                exc_info=True,
            )
            return pl.DataFrame()

    def _execute_join(
        self,
        left: pl.DataFrame,
        right: pl.DataFrame,
        join_keys: list[tuple[str, str]],
    ) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
        """Compute the matched and unmatched row sets with DuckDB.

        Non-key columns are prefixed with ``Left_``/``Right_`` before joining.
        Key columns keep their names, and the right-hand keys are additionally
        emitted in the Join output as ``Right_<key>``.

        Args:
            left: Left input frame.
            right: Right input frame.
            join_keys: Non-empty list of ``(left_column, right_column)`` pairs.

        Returns:
            ``(joined, left_only, right_only)``. A frame is empty if its query
            failed.
        """
        con = self.ctx.duckdb_con

        key_l = {l_key for l_key, _ in join_keys}
        key_r = {r_key for _, r_key in join_keys}

        # Prefix all non-key columns so SelectConfiguration can reference
        # them; join keys stay unprefixed.
        rename_l = {c: f"Left_{c}" for c in left.columns if c not in key_l}
        rename_r = {c: f"Right_{c}" for c in right.columns if c not in key_r}
        left_r = left.rename(rename_l) if rename_l else left
        right_r = right.rename(rename_r) if rename_r else right

        q = _quote_ident
        on_clause = " AND ".join(
            f"l.{q(l_key)} = r.{q(r_key)}" for l_key, r_key in join_keys
        )

        # Build the select list from parts so a right input consisting only of
        # key columns does not leave a dangling comma in the SQL.
        select_parts = ["l.*"]
        select_parts += [
            f"r.{q(r_key)} AS {q('Right_' + r_key)}" for _, r_key in join_keys
        ]
        select_parts += [f"r.{q(renamed)}" for renamed in rename_r.values()]
        select_sql = ", ".join(select_parts)

        l_key0, r_key0 = join_keys[0]
        j_sql = (
            f"SELECT {select_sql} FROM __join_left__ l "
            f"INNER JOIN __join_right__ r ON {on_clause}"
        )
        # Anti-joins: keep the rows whose partner key came back NULL.
        l_sql = (
            f"SELECT l.* FROM __join_left__ l "
            f"LEFT JOIN __join_right__ r ON {on_clause} WHERE r.{q(r_key0)} IS NULL"
        )
        r_sql = (
            f"SELECT r.* FROM __join_right__ r "
            f"LEFT JOIN __join_left__ l ON {on_clause} WHERE l.{q(l_key0)} IS NULL"
        )

        try:
            con.register("__join_left__", left_r.to_arrow())
            con.register("__join_right__", right_r.to_arrow())
            j_df = self._query_or_empty(con, j_sql, "Join")
            l_df = self._query_or_empty(con, l_sql, "Left")
            r_df = self._query_or_empty(con, r_sql, "Right")
        finally:
            # Always release the registered inputs, even if registration or a
            # query raised.
            con.execute("DROP VIEW IF EXISTS __join_left__")
            con.execute("DROP VIEW IF EXISTS __join_right__")

        return j_df, l_df, r_df