Pyteryx/alteryx_runner/tools/join/append_fields.py

54 lines
1.7 KiB
Python

from __future__ import annotations
from typing import Dict
import polars as pl
from tools.base import BaseTool
class AppendFieldsTool(BaseTool):
    """Cross-join Source rows onto every Target row.

    Inputs (keys of the ``inputs`` mapping; a missing key is treated as an
    empty DataFrame):
        ``"Target"``: rows to which Source fields are appended.
        ``"Source"``: rows whose fields are appended to each Target row.

    Output:
        ``{"Output": DataFrame}`` containing every Target column followed by
        every Source column. A Source column whose name also exists in Target
        is renamed ``Source_<name>``; if that name is itself already used, a
        numeric suffix (``Source_<name>_2``, ``_3``, ...) is added so the
        result never contains duplicate column names.

    Special cases:
        * Empty Target -> an empty DataFrame with no columns.
        * Empty Source -> Target is returned unchanged.
        * Source longer than ``MAX_SOURCE_ROWS`` -> a warning is issued and
          only the first ``MAX_SOURCE_ROWS`` Source rows are used.
    """

    # Upper bound on the number of Source rows used in the cross join (the
    # output can have up to len(target) * MAX_SOURCE_ROWS rows). Override on
    # a subclass or instance to change the cap.
    MAX_SOURCE_ROWS = 10_000

    # Names under which the inputs are registered on the DuckDB connection.
    _TARGET_VIEW = "__append_target__"
    _SOURCE_VIEW = "__append_source__"

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier with embedded quotes escaped."""
        return '"' + name.replace('"', '""') + '"'

    @classmethod
    def _source_select_list(
        cls, target_columns: list[str], source_columns: list[str]
    ) -> str:
        """Build the SELECT list for the Source columns.

        Columns that clash with a Target column are aliased to a unique
        ``Source_<name>`` variant; all others keep their own name.
        """
        target_names = set(target_columns)
        # Output names already claimed: every Target column plus every Source
        # column that keeps its own name (those never clash with Target).
        taken = target_names | {c for c in source_columns if c not in target_names}
        parts = []
        for col in source_columns:
            ref = f"s.{cls._quote_ident(col)}"
            if col not in target_names:
                parts.append(ref)
                continue
            alias = f"Source_{col}"
            suffix = 2
            # Avoid producing a duplicate column name, which Polars rejects.
            while alias in taken:
                alias = f"Source_{col}_{suffix}"
                suffix += 1
            taken.add(alias)
            parts.append(f"{ref} AS {cls._quote_ident(alias)}")
        return ", ".join(parts)

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Append every Source row's fields to every Target row.

        Args:
            inputs: Mapping with optional ``"Target"`` and ``"Source"``
                DataFrames.

        Returns:
            ``{"Output": DataFrame}``; see the class docstring for column
            naming and the empty-input / row-cap behaviour.
        """
        target = inputs.get("Target", pl.DataFrame())
        source = inputs.get("Source", pl.DataFrame())
        if target.is_empty():
            return {"Output": pl.DataFrame()}
        if source.is_empty():
            return {"Output": target}

        max_rows = self.MAX_SOURCE_ROWS
        if len(source) > max_rows:
            import warnings

            warnings.warn(
                f"AppendFields Source has {len(source)} rows; "
                f"capping at {max_rows} for safety.",
                stacklevel=2,
            )
            source = source.head(max_rows)

        s_select = self._source_select_list(target.columns, source.columns)
        # NOTE(review): SQL CROSS JOIN does not guarantee output row order;
        # confirm no downstream tool relies on target-major ordering.
        sql = (
            f"SELECT t.*, {s_select} "
            f"FROM {self._TARGET_VIEW} t "
            f"CROSS JOIN {self._SOURCE_VIEW} s"
        )

        con = self.ctx.duckdb_con
        try:
            # Registration happens inside the try so a failure while
            # registering the second frame still drops the first view.
            con.register(self._TARGET_VIEW, target.to_arrow())
            con.register(self._SOURCE_VIEW, source.to_arrow())
            result = pl.from_arrow(con.execute(sql).arrow())
        finally:
            con.execute(f"DROP VIEW IF EXISTS {self._TARGET_VIEW}")
            con.execute(f"DROP VIEW IF EXISTS {self._SOURCE_VIEW}")
        return {"Output": result}