from __future__ import annotations

import logging
from typing import Dict

import polars as pl

from tools.base import BaseTool

logger = logging.getLogger(__name__)

# Temporary row-number column used when joining by record position.
_POS_COL = "__pos__"
# Prefix given to right-hand columns in the inner-join output.
_RIGHT_PREFIX = "Right_"


def _duckdb_to_polars(con, sql: str) -> pl.DataFrame:
    """Run *sql* on the DuckDB connection *con* and return the result as a polars frame."""
    return pl.from_arrow(con.execute(sql).arrow())


def _quote_ident(name: str) -> str:
    """Return *name* as a double-quoted SQL identifier, doubling embedded quotes."""
    return '"' + name.replace('"', '""') + '"'


def _drop_if_present(df: pl.DataFrame, names: tuple[str, ...]) -> pl.DataFrame:
    """Drop whichever of *names* exist in *df*; return *df* unchanged if none do."""
    present = [c for c in names if c in df.columns]
    return df.drop(present) if present else df


def _query_or_empty(con, sql: str, label: str) -> pl.DataFrame:
    """Run *sql*; on any failure log a warning and return an empty frame.

    The empty-frame fallback is a deliberate best-effort (e.g. a configured
    join key that is missing from an input), but the failure is logged so it
    is no longer invisible.
    """
    try:
        return _duckdb_to_polars(con, sql)
    except Exception:
        logger.warning(
            "JoinTool: %s query failed; returning an empty frame",
            label,
            exc_info=True,
        )
        return pl.DataFrame()


class JoinTool(BaseTool):
    """Join the "Left" and "Right" inputs into "Join", "Left" and "Right" outputs.

    "Join" holds the inner-join rows; "Left" and "Right" hold the rows of each
    input that found no match on the other side.
    """

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Run the join described by ``self.config``.

        Args:
            inputs: Frames keyed by input connection name; "Left" and "Right"
                are read, each defaulting to an empty frame when absent.

        Returns:
            A dict with the keys "Join", "Left" and "Right" on every path.
        """
        left = inputs.get("Left", pl.DataFrame())
        right = inputs.get("Right", pl.DataFrame())

        if self.config is None:
            # Same keys as every other return path, so callers can rely on them.
            return {"Join": pl.DataFrame(), "Left": left, "Right": right}

        by_pos = self._cfg_attr("JoinByRecordPos", "value", "False") == "True"

        if by_pos:
            left = left.with_row_index(_POS_COL)
            right = right.with_row_index(_POS_COL)
            join_keys = [(_POS_COL, _POS_COL)]
        else:
            # Join keys come from JoinInfo/Field elements:
            #   <JoinInfo connection="Left"><Field field="..."/>...</JoinInfo>
            #   <JoinInfo connection="Right"><Field field="..."/>...</JoinInfo>
            left_fields: list[str] = []
            right_fields: list[str] = []
            for ji in self.config.findall("JoinInfo"):
                conn = ji.attrib.get("connection", "")
                fields = [f.attrib.get("field", "") for f in ji.findall("Field")]
                if conn == "Left":
                    left_fields = fields
                elif conn == "Right":
                    right_fields = fields
            # Pair up left and right fields by position.
            join_keys = list(zip(left_fields, right_fields))

        if not join_keys:
            return {"Join": pl.DataFrame(), "Left": left, "Right": right}

        # Cast join key columns to String on both sides to handle type
        # mismatches (e.g. Int64 vs String from different input sources).
        for l_key, r_key in join_keys:
            if l_key in left.columns:
                left = left.with_columns(pl.col(l_key).cast(pl.String))
            if r_key in right.columns:
                right = right.with_columns(pl.col(r_key).cast(pl.String))

        j_df, l_df, r_df = self._execute_join(left, right, join_keys)

        if by_pos:
            # The inner join also carries the right-hand position column under
            # its prefixed alias; remove both helper columns from the output.
            j_df = _drop_if_present(j_df, (_POS_COL, _RIGHT_PREFIX + _POS_COL))
            l_df = _drop_if_present(l_df, (_POS_COL,))
            r_df = _drop_if_present(r_df, (_POS_COL,))

        # Apply SelectConfiguration for each output.
        j_df = self._apply_select_config(j_df, "Join")
        l_df = self._apply_select_config(l_df, "Left")
        r_df = self._apply_select_config(r_df, "Right")

        # Use anchor names that match Alteryx connection names.
        return {"Join": j_df, "Left": l_df, "Right": r_df}

    def _apply_select_config(
        self, df: pl.DataFrame, output_connection: str
    ) -> pl.DataFrame:
        """Apply field selection and renaming from SelectConfiguration.

        Uses the first ``Configuration`` element whose ``outputConnection``
        equals *output_connection*. *df* is returned unchanged when it is
        empty, when there is no config, or when no matching
        ``Configuration``/``SelectFields`` element exists.
        """
        if df.is_empty() or self.config is None:
            return df

        select_config = self.config.find("SelectConfiguration")
        if select_config is None:
            return df

        cfg = next(
            (
                c
                for c in select_config.findall("Configuration")
                if c.attrib.get("outputConnection") == output_connection
            ),
            None,
        )
        if cfg is None:
            return df

        select_fields = cfg.find("SelectFields")
        if select_fields is None:
            return df

        order_changed_el = cfg.find("OrderChanged")
        order_changed = (
            order_changed_el is not None
            and order_changed_el.attrib.get("value", "False") == "True"
        )

        # Parse field rules.
        rename_map: dict[str, str] = {}  # source column -> output name
        exclude_set: set[str] = set()  # columns explicitly deselected
        explicit_order: list[str] = []  # selected columns, in config order
        has_unknown = False
        unknown_selected = True

        for sf in select_fields.findall("SelectField"):
            field = sf.attrib.get("field", "")
            selected = sf.attrib.get("selected", "True") == "True"
            rename = sf.attrib.get("rename", "")

            if field == "*Unknown":
                # Wildcard rule covering every column not listed by name.
                has_unknown = True
                unknown_selected = selected
                continue

            # Rules for columns that are not in the frame are ignored.
            if field not in df.columns:
                continue

            if not selected:
                exclude_set.add(field)
                continue

            explicit_order.append(field)
            if rename and rename != field:
                rename_map[field] = rename

        # Build the final column list.
        mentioned = set(explicit_order) | exclude_set

        if order_changed:
            # Explicit selections first (in the configured order), then the
            # unlisted columns, but only when a selected *Unknown rule exists.
            final_cols = list(explicit_order)
            if has_unknown and unknown_selected:
                final_cols.extend(c for c in df.columns if c not in mentioned)
        else:
            # Preserve the frame's own column order. Unlisted columns are kept
            # unless an *Unknown rule is present and deselected.
            keep_unlisted = (not has_unknown) or unknown_selected
            final_cols = [
                c
                for c in df.columns
                if c not in exclude_set and (c in mentioned or keep_unlisted)
            ]

        # An empty selection leaves the frame untouched rather than producing
        # a zero-column frame.
        if final_cols:
            df = df.select(final_cols)

        if rename_map:
            df = df.rename({k: v for k, v in rename_map.items() if k in df.columns})

        return df

    def _execute_join(
        self,
        left: pl.DataFrame,
        right: pl.DataFrame,
        join_keys: list[tuple[str, str]],
    ) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
        """Run the inner join and both anti-joins through DuckDB.

        Args:
            left: Left input frame.
            right: Right input frame.
            join_keys: Non-empty list of ``(left_column, right_column)`` pairs
                that are compared for equality.

        Returns:
            ``(joined, left_unmatched, right_unmatched)``. A query that fails
            yields an empty frame for that output and a logged warning.
        """
        con = self.ctx.duckdb_con

        key_r = {k[1] for k in join_keys}
        r_non_key = [c for c in right.columns if c not in key_r]

        # Right non-key columns need the prefix when they clash with ANY left
        # column, key columns included, otherwise the output would contain two
        # columns with the same name.
        conflicts = set(left.columns) & set(r_non_key)

        on_clause = " AND ".join(
            f"l.{_quote_ident(lk)} = r.{_quote_ident(rk)}" for lk, rk in join_keys
        )

        # --- Inner join SELECT ------------------------------------------------
        # Left columns first (no prefix), then right join keys with the Right_
        # prefix, then right non-key columns (Right_ prefix only on conflicts).
        l_cols_sql = ", ".join(f"l.{_quote_ident(c)}" for c in left.columns)
        r_key_cols_sql = ", ".join(
            f"r.{_quote_ident(rk)} AS {_quote_ident(_RIGHT_PREFIX + rk)}"
            for _, rk in join_keys
        )
        r_non_key_sql = ", ".join(
            f"r.{_quote_ident(c)} AS {_quote_ident(_RIGHT_PREFIX + c)}"
            if c in conflicts
            else f"r.{_quote_ident(c)}"
            for c in r_non_key
        )
        j_select = ", ".join(
            p for p in (l_cols_sql, r_key_cols_sql, r_non_key_sql) if p
        )

        l_key0, r_key0 = join_keys[0]

        j_sql = (
            f"SELECT {j_select} FROM __join_left__ l "
            f"INNER JOIN __join_right__ r ON {on_clause}"
        )
        # Unmatched rows keep their original column names (no prefixes). A NULL
        # key on the other side after a LEFT JOIN means no row matched.
        l_sql = (
            f"SELECT l.* FROM __join_left__ l "
            f"LEFT JOIN __join_right__ r ON {on_clause} "
            f"WHERE r.{_quote_ident(r_key0)} IS NULL"
        )
        r_sql = (
            f"SELECT r.* FROM __join_right__ r "
            f"LEFT JOIN __join_left__ l ON {on_clause} "
            f"WHERE l.{_quote_ident(l_key0)} IS NULL"
        )

        try:
            # Register the original (un-prefixed) tables.
            con.register("__join_left__", left.to_arrow())
            con.register("__join_right__", right.to_arrow())

            j_df = _query_or_empty(con, j_sql, "inner-join")
            l_df = _query_or_empty(con, l_sql, "left-unmatched")
            r_df = _query_or_empty(con, r_sql, "right-unmatched")
        finally:
            # Always release the registered inputs, even if registration or a
            # query raised.
            con.execute("DROP VIEW IF EXISTS __join_left__")
            con.execute("DROP VIEW IF EXISTS __join_right__")

        return j_df, l_df, r_df