207 lines
8.6 KiB
Python
207 lines
8.6 KiB
Python
from __future__ import annotations

import logging
from typing import Dict

import polars as pl

from tools.base import BaseTool
|
|
|
|
|
|
def _duckdb_to_polars(con, sql: str) -> pl.DataFrame:
    """Execute *sql* on *con* and return the result as a polars frame.

    Args:
        con: Connection object exposing ``execute(sql).arrow()``; in this file
            it is called with ``self.ctx.duckdb_con``.
        sql: The SQL text to run.

    Returns:
        Whatever ``pl.from_arrow`` builds from the Arrow result of the query.
    """
    arrow_result = con.execute(sql).arrow()
    return pl.from_arrow(arrow_result)
|
|
|
|
|
|
class JoinTool(BaseTool):
    """Join the "Left" and "Right" inputs on configured keys or by record position.

    ``execute`` returns three frames keyed ``"Join"`` (matched rows),
    ``"Left"`` (left rows without a match) and ``"Right"`` (right rows
    without a match). The join itself is run as SQL on ``self.ctx.duckdb_con``.
    """

    # Helper column added to both inputs when joining by record position.
    _POS_COL = "__pos__"
    # Prefixes put on non-key columns before the join so that
    # SelectConfiguration entries can address each side unambiguously.
    _LEFT_PREFIX = "Left_"
    _RIGHT_PREFIX = "Right_"

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Run the join and return the ``Join`` / ``Left`` / ``Right`` outputs.

        Args:
            inputs: Frames keyed by input anchor; ``"Left"`` and ``"Right"``
                are read, and a missing anchor is treated as an empty frame.

        Returns:
            A dict with the keys ``"Join"``, ``"Left"`` and ``"Right"``. When
            there is no configuration or no join key, the Join frame is empty
            and both inputs are passed through unchanged.
        """
        left = inputs.get("Left", pl.DataFrame())
        right = inputs.get("Right", pl.DataFrame())

        if self.config is None:
            # Use the same anchor names as every other return path of this
            # method (previously this one path returned "J"/"L"/"R").
            return {"Join": pl.DataFrame(), "Left": left, "Right": right}

        by_pos = self._cfg_attr("JoinByRecordPos", "value", "False") == "True"

        if by_pos:
            left = left.with_row_index(self._POS_COL)
            right = right.with_row_index(self._POS_COL)
            join_keys = [(self._POS_COL, self._POS_COL)]
        else:
            join_keys = self._parse_join_keys()

        if not join_keys:
            return {"Join": pl.DataFrame(), "Left": left, "Right": right}

        # Cast join key columns to String on both sides to handle type
        # mismatches (e.g., Int64 vs String from different input sources).
        # NOTE(review): the key columns stay String in the outputs, and values
        # are compared by their string form -- confirm this is intended.
        for l_key, r_key in join_keys:
            if l_key in left.columns:
                left = left.with_columns(pl.col(l_key).cast(pl.String))
            if r_key in right.columns:
                right = right.with_columns(pl.col(r_key).cast(pl.String))

        j_df, l_df, r_df = self._execute_join(left, right, join_keys)

        if by_pos:
            # The join output carries the position helper twice: once from the
            # left side and once as the prefixed right-side key. Drop both so
            # neither leaks into the results.
            helper_cols = (self._POS_COL, f"{self._RIGHT_PREFIX}{self._POS_COL}")
            j_df = self._drop_present(j_df, helper_cols)
            l_df = self._drop_present(l_df, helper_cols)
            r_df = self._drop_present(r_df, helper_cols)

        # Apply SelectConfiguration for each output.
        # NOTE(review): when no Configuration targets "Left"/"Right", those
        # outputs keep the Left_/Right_ prefixed column names produced by
        # _execute_join -- confirm downstream tools expect that.
        j_df = self._apply_select_config(j_df, "Join")
        l_df = self._apply_select_config(l_df, "Left")
        r_df = self._apply_select_config(r_df, "Right")

        # Use anchor names that match Alteryx connection names
        return {"Join": j_df, "Left": l_df, "Right": r_df}

    def _parse_join_keys(self) -> list[tuple[str, str]]:
        """Read ``(left_field, right_field)`` key pairs from the JoinInfo elements.

        Expected structure:
        ``<JoinInfo connection="Left"><Field field="col1" /><Field field="col2" /></JoinInfo>``

        Fields are paired by position; surplus fields on the longer side are
        ignored. If a connection appears more than once, the last one wins.
        """
        left_fields: list[str] = []
        right_fields: list[str] = []
        for join_info in self.config.findall("JoinInfo"):
            connection = join_info.attrib.get("connection", "")
            fields = [f.attrib.get("field", "") for f in join_info.findall("Field")]
            if connection == "Left":
                left_fields = fields
            elif connection == "Right":
                right_fields = fields
        return list(zip(left_fields, right_fields))

    @staticmethod
    def _drop_present(df: pl.DataFrame, names) -> pl.DataFrame:
        """Drop those of *names* that exist in *df*; return *df* unchanged otherwise."""
        present = [name for name in names if name in df.columns]
        return df.drop(present) if present else df

    @staticmethod
    def _strip_side_prefix(name: str) -> str:
        """Return *name* without a leading ``Left_`` or ``Right_`` prefix."""
        for prefix in ("Left_", "Right_"):
            if name.startswith(prefix):
                return name[len(prefix):]
        return name

    def _apply_select_config(self, df: pl.DataFrame, output_connection: str) -> pl.DataFrame:
        """Apply field selection and renaming from SelectConfiguration.

        The first ``<Configuration>`` whose ``outputConnection`` equals
        *output_connection* is used. Each ``<SelectField>`` is read for its
        ``field``, ``selected``, ``rename`` and ``input`` attributes; the
        special field ``*Unknown`` (when selected) pulls in every column that
        no SelectField mentions.

        Args:
            df: Frame to reshape.
            output_connection: Anchor name to look up (``"Join"``, ``"Left"``
                or ``"Right"`` at the call sites in this class).

        Returns:
            The reshaped frame, or *df* unchanged when there is nothing to
            apply (no columns, no configuration, no matching Configuration,
            no SelectFields, or no column ends up selected).
        """
        # Only a frame without columns is skipped. A frame with columns but
        # zero rows is still reshaped so the output schema does not depend on
        # whether the join produced any rows.
        if not df.columns or self.config is None:
            return df

        select_config = self.config.find("SelectConfiguration")
        if select_config is None:
            return df

        # Find the Configuration for this output connection (first match wins).
        select_fields = None
        for cfg in select_config.findall("Configuration"):
            if cfg.attrib.get("outputConnection") == output_connection:
                select_fields = cfg.find("SelectFields")
                break
        if select_fields is None:
            return df

        available = set(df.columns)
        explicit_selections: list[tuple[str, str]] = []  # (src_col, output_name)
        selected_srcs: set[str] = set()
        deselected_srcs: set[str] = set()
        has_unknown = False

        for sf in select_fields.findall("SelectField"):
            field = sf.attrib.get("field", "")
            selected = sf.attrib.get("selected", "False") == "True"
            rename = sf.attrib.get("rename", "")
            input_prefix = sf.attrib.get("input", "")

            if field == "*Unknown":
                has_unknown = has_unknown or selected
                continue

            # Resolve the source column: try the prefixed name first, then the
            # bare field name.
            prefixed = f"{input_prefix}{field}" if input_prefix else field
            if prefixed in available:
                src_col = prefixed
            elif field in available:
                src_col = field
            else:
                continue

            if not selected:
                # Remember explicitly deselected columns so that *Unknown does
                # not bring them back in.
                deselected_srcs.add(src_col)
                continue

            if src_col in selected_srcs:
                # Selecting the same source twice would yield duplicate columns.
                continue
            selected_srcs.add(src_col)
            explicit_selections.append((src_col, rename if rename else field))

        # Explicitly selected columns come first, in configuration order.
        selected_cols = [src for src, _ in explicit_selections]
        rename_map = {src: dst for src, dst in explicit_selections if src != dst}

        # Handle *Unknown: include all columns no SelectField mentioned,
        # stripping the Left_/Right_ prefix where that does not clash.
        if has_unknown:
            mentioned = selected_srcs | deselected_srcs
            unknown_cols = [col for col in df.columns if col not in mentioned]

            taken = {dst for _, dst in explicit_selections}
            # Columns without a side prefix keep their name, so reserve those
            # names before assigning stripped names to the prefixed columns.
            taken.update(
                col for col in unknown_cols if self._strip_side_prefix(col) == col
            )

            for col in unknown_cols:
                stripped = self._strip_side_prefix(col)
                if stripped != col and stripped not in taken:
                    rename_map[col] = stripped
                    taken.add(stripped)
                else:
                    # Keep the prefixed name when the bare name is already
                    # used (e.g. the same column name exists on both sides).
                    taken.add(col)
                selected_cols.append(col)

        # Apply selection and renaming
        if selected_cols:
            df = df.select(selected_cols)
            if rename_map:
                df = df.rename(rename_map)
        return df

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Quote *name* as a SQL identifier, doubling embedded double quotes."""
        return '"' + name.replace('"', '""') + '"'

    @staticmethod
    def _run_query(con, sql: str, label: str) -> pl.DataFrame:
        """Run *sql* and return its result; on failure log it and return an empty frame."""
        try:
            return _duckdb_to_polars(con, sql)
        except Exception:
            # Best effort, as before: one failing output must not abort the
            # tool, but the failure is now recorded instead of being dropped.
            logging.getLogger(__name__).exception(
                "JoinTool: %s query failed; returning an empty frame", label
            )
            return pl.DataFrame()

    def _execute_join(
        self,
        left: pl.DataFrame,
        right: pl.DataFrame,
        join_keys: list[tuple[str, str]],
    ) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
        """Compute matched, left-only and right-only rows with DuckDB.

        Non-key columns are renamed with a ``Left_`` / ``Right_`` prefix before
        the join; left key columns keep their names and right key columns are
        added to the join output as ``Right_<key>``.

        Args:
            left: Left input frame.
            right: Right input frame.
            join_keys: Non-empty list of ``(left_column, right_column)`` pairs.

        Returns:
            ``(join, left_only, right_only)``. If a key column is missing on
            either side nothing can match: the join frame is empty and both
            (prefixed) inputs are returned as unmatched. A query that fails
            yields an empty frame for that output.
        """
        con = self.ctx.duckdb_con
        first_l_key, first_r_key = join_keys[0]

        left_key_names = {l_key for l_key, _ in join_keys}
        right_key_names = {r_key for _, r_key in join_keys}
        l_non_key = [c for c in left.columns if c not in left_key_names]
        r_non_key = [c for c in right.columns if c not in right_key_names]

        # Prefix all non-key columns: Left_ for left, Right_ for right.
        # This matches Alteryx behavior where SelectConfiguration references
        # fields with these prefixes. Join keys keep their original names.
        rename_l = {c: f"{self._LEFT_PREFIX}{c}" for c in l_non_key}
        rename_r = {c: f"{self._RIGHT_PREFIX}{c}" for c in r_non_key}
        left_r = left.rename(rename_l) if rename_l else left
        right_r = right.rename(rename_r) if rename_r else right

        missing_left = sorted(k for k in left_key_names if k not in left.columns)
        missing_right = sorted(k for k in right_key_names if k not in right.columns)
        if missing_left or missing_right:
            # Without the key columns the SQL below cannot run (this also
            # covers an input that was never connected). No row can match, so
            # everything is unmatched.
            logging.getLogger(__name__).warning(
                "JoinTool: join key(s) missing (left=%s, right=%s); no rows joined",
                missing_left,
                missing_right,
            )
            return pl.DataFrame(), left_r, right_r

        quote = self._quote_ident
        on_clause = " AND ".join(
            f"l.{quote(l_key)} = r.{quote(r_key)}" for l_key, r_key in join_keys
        )

        # Right join keys are exposed with the Right_ prefix so that
        # SelectConfiguration can refer to them; the remaining right columns
        # already carry the prefix. Building one list avoids a dangling comma
        # when the right side has no non-key columns.
        right_select = [
            f"r.{quote(r_key)} AS {quote(self._RIGHT_PREFIX + r_key)}"
            for _, r_key in join_keys
        ]
        right_select.extend(f"r.{quote(rename_r[c])}" for c in r_non_key)
        r_cols_sql = ", ".join(right_select)

        j_sql = (
            f"SELECT l.*, {r_cols_sql} FROM __join_left__ l "
            f"INNER JOIN __join_right__ r ON {on_clause}"
        )
        # Anti-joins: keep the rows whose partner side came back NULL.
        l_sql = (
            f"SELECT l.* FROM __join_left__ l LEFT JOIN __join_right__ r "
            f"ON {on_clause} WHERE r.{quote(first_r_key)} IS NULL"
        )
        r_sql = (
            f"SELECT r.* FROM __join_right__ r LEFT JOIN __join_left__ l "
            f"ON {on_clause} WHERE l.{quote(first_l_key)} IS NULL"
        )

        try:
            con.register("__join_left__", left_r.to_arrow())
            con.register("__join_right__", right_r.to_arrow())
            j_df = self._run_query(con, j_sql, "Join")
            l_df = self._run_query(con, l_sql, "Left")
            r_df = self._run_query(con, r_sql, "Right")
        finally:
            # Always remove the temporary views, even if registration failed.
            con.execute("DROP VIEW IF EXISTS __join_left__")
            con.execute("DROP VIEW IF EXISTS __join_right__")

        return j_df, l_df, r_df
|