208 lines
8.4 KiB
Python
208 lines
8.4 KiB
Python
from __future__ import annotations
|
|
from typing import Dict
|
|
import polars as pl
|
|
from tools.base import BaseTool
|
|
|
|
|
|
def _duckdb_to_polars(con, sql: str) -> pl.DataFrame:
    """Run *sql* on the DuckDB connection *con* and return the result in Polars.

    The executed statement's result is fetched through ``.arrow()`` and then
    converted with ``pl.from_arrow``.
    """
    result = con.execute(sql)
    arrow_result = result.arrow()
    return pl.from_arrow(arrow_result)
|
|
|
|
|
|
class JoinTool(BaseTool):
    """Alteryx-style Join tool executed through DuckDB.

    Emits three outputs under the Alteryx anchor names:

    * ``"Join"``  - inner-join rows (keys matched on both sides),
    * ``"Left"``  - left rows with no match on the right,
    * ``"Right"`` - right rows with no match on the left.

    Rows are paired either on the key fields listed in the ``JoinInfo``
    elements of the tool configuration or, when ``JoinByRecordPos`` is
    ``"True"``, by record position.
    """

    # Helper column used to pair rows when joining by record position; it is
    # removed from every output before returning.
    _POS_COL = "__pos__"

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Join the ``"Left"`` and ``"Right"`` inputs.

        Args:
            inputs: Input frames keyed by connection name. A missing
                ``"Left"`` or ``"Right"`` entry is treated as an empty frame.

        Returns:
            ``{"Join": ..., "Left": ..., "Right": ...}``. Without a tool
            configuration or without any join keys, ``"Join"`` is empty and
            the inputs are passed through unchanged as ``"Left"``/``"Right"``.
        """
        left = inputs.get("Left", pl.DataFrame())
        right = inputs.get("Right", pl.DataFrame())

        if self.config is None:
            # Same anchor names as every other return path of this method.
            return {"Join": pl.DataFrame(), "Left": left, "Right": right}

        # NOTE(review): _cfg_attr comes from BaseTool (not visible here);
        # presumably it reads the "value" attribute of <JoinByRecordPos>.
        by_pos = self._cfg_attr("JoinByRecordPos", "value", "False") == "True"

        if by_pos:
            left = left.with_row_index(self._POS_COL)
            right = right.with_row_index(self._POS_COL)
            join_keys = [(self._POS_COL, self._POS_COL)]
        else:
            join_keys = self._parse_join_keys()

        if not join_keys:
            return {"Join": pl.DataFrame(), "Left": left, "Right": right}

        # Cast join key columns to String on both sides to handle type
        # mismatches (e.g. Int64 vs String from different input sources).
        # This also makes the key columns String in every output.
        for l_key, r_key in join_keys:
            if l_key in left.columns:
                left = left.with_columns(pl.col(l_key).cast(pl.String))
            if r_key in right.columns:
                right = right.with_columns(pl.col(r_key).cast(pl.String))

        j_df, l_df, r_df = self._execute_join(left, right, join_keys)

        if by_pos:
            # Strip the helper position column, including the Right_-prefixed
            # copy of the right-side key that the inner join emits.
            helper_cols = {self._POS_COL, f"Right_{self._POS_COL}"}
            j_df = self._drop_present(j_df, helper_cols)
            l_df = self._drop_present(l_df, helper_cols)
            r_df = self._drop_present(r_df, helper_cols)

        # Apply SelectConfiguration for each output.
        j_df = self._apply_select_config(j_df, "Join")
        l_df = self._apply_select_config(l_df, "Left")
        r_df = self._apply_select_config(r_df, "Right")

        # Use anchor names that match Alteryx connection names.
        return {"Join": j_df, "Left": l_df, "Right": r_df}

    def _parse_join_keys(self) -> list[tuple[str, str]]:
        """Pair Left/Right key fields from the ``JoinInfo`` elements by position.

        Expected structure::

            <JoinInfo connection="Left"><Field field="col1" /><Field field="col2" /></JoinInfo>
            <JoinInfo connection="Right"><Field field="colA" /><Field field="colB" /></JoinInfo>

        If a connection appears more than once, the last element wins. If the
        two sides list different numbers of fields, the surplus fields are
        ignored (``zip`` truncates to the shorter list).
        """
        left_fields: list[str] = []
        right_fields: list[str] = []
        for ji in self.config.findall("JoinInfo"):
            conn = ji.attrib.get("connection", "")
            fields = [f.attrib.get("field", "") for f in ji.findall("Field")]
            if conn == "Left":
                left_fields = fields
            elif conn == "Right":
                right_fields = fields
        return list(zip(left_fields, right_fields))

    @staticmethod
    def _drop_present(df: pl.DataFrame, names: set[str]) -> pl.DataFrame:
        """Drop the columns of *df* whose names are in *names*; absent names are ignored."""
        return df.drop([c for c in df.columns if c in names])

    def _apply_select_config(self, df: pl.DataFrame, output_connection: str) -> pl.DataFrame:
        """Apply field selection, ordering and renaming from ``SelectConfiguration``.

        Uses the first ``<Configuration>`` whose ``outputConnection`` equals
        *output_connection*. *df* is returned unchanged when there is no such
        configuration, or when it has no ``<SelectFields>``.

        Rules:
            * ``selected="False"`` drops a column; ``rename`` renames a
              selected column.
            * ``*Unknown`` decides whether columns that are not listed are
              kept. With no ``*Unknown`` entry they are kept.
            * With ``<OrderChanged value="True"/>``, the explicitly selected
              columns come first in listed order, followed by kept unlisted
              columns in frame order. Otherwise frame order is preserved.

        Args:
            df: Frame to transform.
            output_connection: Output anchor name (``"Join"``, ``"Left"`` or
                ``"Right"``).

        Returns:
            The transformed frame.
        """
        # A frame without columns has nothing to select or rename. Frames that
        # have columns but zero rows are still processed, so the schema is the
        # same whether or not any rows were produced.
        if df.width == 0 or self.config is None:
            return df

        select_config = self.config.find("SelectConfiguration")
        if select_config is None:
            return df

        cfg = next(
            (
                c
                for c in select_config.findall("Configuration")
                if c.attrib.get("outputConnection") == output_connection
            ),
            None,
        )
        if cfg is None:
            return df

        select_fields = cfg.find("SelectFields")
        if select_fields is None:
            return df

        order_changed_el = cfg.find("OrderChanged")
        order_changed = (
            order_changed_el is not None
            and order_changed_el.attrib.get("value", "False") == "True"
        )

        present = set(df.columns)
        rename_map: dict[str, str] = {}  # src_col -> output_name
        exclude_set: set[str] = set()  # columns explicitly deselected
        # Selected columns in listed order; dict keys deduplicate repeated entries
        # (a duplicate name would make df.select raise).
        explicit_order: dict[str, None] = {}
        # Whether unlisted columns are kept; a *Unknown entry overrides it.
        include_unknown = True

        for sf in select_fields.findall("SelectField"):
            field = sf.attrib.get("field", "")
            selected = sf.attrib.get("selected", "True") == "True"
            rename = sf.attrib.get("rename", "")

            if field == "*Unknown":
                include_unknown = selected
                continue

            # NOTE(review): fields are matched against df column names
            # verbatim; entries naming an absent column are ignored. If the
            # workflow XML names fields with an input prefix (e.g. "Left_x"),
            # they will not match. Confirm against real workflows.
            if field not in present:
                continue

            if not selected:
                exclude_set.add(field)
            else:
                explicit_order[field] = None
                if rename and rename != field:
                    rename_map[field] = rename

        mentioned = exclude_set.union(explicit_order)

        if order_changed:
            # Explicit selections first (in listed order), then kept unlisted
            # columns. A deselection wins over a selection of the same field.
            final_cols = [c for c in explicit_order if c not in exclude_set]
            if include_unknown:
                final_cols += [c for c in df.columns if c not in mentioned]
        else:
            # Preserve the original DataFrame column order.
            final_cols = [
                c
                for c in df.columns
                if c not in exclude_set and (c in mentioned or include_unknown)
            ]

        # Select even when final_cols is empty: deselecting every column must
        # yield a frame with no columns, not the untouched input.
        df = df.select(final_cols)
        if rename_map:
            df = df.rename({k: v for k, v in rename_map.items() if k in df.columns})
        return df

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Quote *name* as a SQL identifier, doubling any embedded double quotes."""
        return '"' + name.replace('"', '""') + '"'

    @staticmethod
    def _query_or_empty(con, sql: str, output: str) -> pl.DataFrame:
        """Run *sql* through ``_duckdb_to_polars``, or log and return an empty frame on failure.

        Failures stay best-effort so the remaining outputs are still produced,
        but each one is logged together with the offending SQL.
        """
        import logging

        try:
            return _duckdb_to_polars(con, sql)
        except Exception:
            logging.getLogger(__name__).exception(
                "JoinTool: query for the %r output failed: %s", output, sql
            )
            return pl.DataFrame()

    def _execute_join(
        self,
        left: pl.DataFrame,
        right: pl.DataFrame,
        join_keys: list[tuple[str, str]],
    ) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
        """Compute the inner join and both unmatched sides with DuckDB.

        Args:
            left: Left input frame.
            right: Right input frame.
            join_keys: Non-empty list of ``(left_column, right_column)`` pairs;
                all pairs must be equal for two rows to join.

        Returns:
            ``(joined, left_unmatched, right_unmatched)``. Any of the three
            whose query fails is returned as an empty frame, and the failure
            is logged.
        """
        # NOTE(review): self.ctx comes from BaseTool (not visible here).
        con = self.ctx.duckdb_con
        q = self._quote_ident

        key_r = {r_key for _, r_key in join_keys}
        r_non_key = [c for c in right.columns if c not in key_r]
        # A right non-key column needs the Right_ prefix when its name clashes
        # with any left column, key or non-key. Otherwise the inner-join
        # result would contain two columns with the same name.
        left_names = set(left.columns)
        conflicts = {c for c in r_non_key if c in left_names}

        on_clause = " AND ".join(
            f"l.{q(l_key)} = r.{q(r_key)}" for l_key, r_key in join_keys
        )

        # --- Inner join SELECT ------------------------------------------------
        # Left columns first (no prefix), then right join keys with a Right_
        # prefix, then right non-key columns (Right_ prefix only on conflicts).
        select_items = [f"l.{q(c)}" for c in left.columns]
        select_items += [
            f"r.{q(r_key)} AS {q('Right_' + r_key)}" for _, r_key in join_keys
        ]
        select_items += [
            f"r.{q(c)} AS {q('Right_' + c)}" if c in conflicts else f"r.{q(c)}"
            for c in r_non_key
        ]
        j_select = ", ".join(select_items)

        l_key0, r_key0 = join_keys[0]

        j_sql = (
            f"SELECT {j_select} FROM __join_left__ l "
            f"INNER JOIN __join_right__ r ON {on_clause}"
        )
        # Unmatched rows keep their original column names (no prefixes). A
        # matched row always has a non-NULL first key on the other side
        # (NULL never satisfies "="), so IS NULL after the outer join selects
        # exactly the unmatched rows.
        l_sql = (
            f"SELECT l.* FROM __join_left__ l LEFT JOIN __join_right__ r "
            f"ON {on_clause} WHERE r.{q(r_key0)} IS NULL"
        )
        r_sql = (
            f"SELECT r.* FROM __join_right__ r LEFT JOIN __join_left__ l "
            f"ON {on_clause} WHERE l.{q(l_key0)} IS NULL"
        )

        try:
            con.register("__join_left__", left.to_arrow())
            con.register("__join_right__", right.to_arrow())
            j_df = self._query_or_empty(con, j_sql, "Join")
            l_df = self._query_or_empty(con, l_sql, "Left")
            r_df = self._query_or_empty(con, r_sql, "Right")
        finally:
            # Always release the temporary registrations, even on failure.
            con.execute("DROP VIEW IF EXISTS __join_left__")
            con.execute("DROP VIEW IF EXISTS __join_right__")

        return j_df, l_df, r_df
|