54 lines
1.7 KiB
Python
54 lines
1.7 KiB
Python
from __future__ import annotations
|
|
from typing import Dict
|
|
import polars as pl
|
|
from tools.base import BaseTool
|
|
|
|
|
|
class AppendFieldsTool(BaseTool):
    """Cross-join Source rows onto every Target row.

    Reads the ``"Target"`` and ``"Source"`` frames from the inputs mapping and
    returns a single ``"Output"`` frame holding every Target row paired with
    every (possibly capped) Source row. Source columns whose names also exist
    in Target are emitted under a ``Source_`` prefix.
    """

    # Largest Source row count joined as-is; larger inputs are truncated (with
    # a warning) because the output has len(target) * len(source) rows.
    MAX_SOURCE_ROWS = 10_000

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier, doubling embedded quotes."""
        return '"' + name.replace('"', '""') + '"'

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Append the Source columns to each Target row via a cross join.

        Args:
            inputs: Mapping that may contain ``"Target"`` and ``"Source"``
                frames; a missing key is treated as an empty frame.

        Returns:
            ``{"Output": frame}`` where *frame* is:

            * an empty, schema-less frame when Target is empty;
            * Target unchanged when Source is empty;
            * otherwise all Target columns followed by the Source columns,
              one row per (Target row, Source row) pair.

        Warns:
            UserWarning: when Source has more than ``MAX_SOURCE_ROWS`` rows
                and is truncated to its first ``MAX_SOURCE_ROWS`` rows.
        """
        target = inputs.get("Target", pl.DataFrame())
        source = inputs.get("Source", pl.DataFrame())

        if target.is_empty():
            return {"Output": pl.DataFrame()}
        if source.is_empty():
            # Nothing to append: pass Target through rather than producing
            # the zero-row result a literal cross join would give.
            return {"Output": target}

        max_rows = self.MAX_SOURCE_ROWS
        if len(source) > max_rows:
            import warnings

            warnings.warn(
                f"AppendFields Source has {len(source)} rows; "
                f"capping at {max_rows} for safety.",
                stacklevel=2,
            )
            source = source.head(max_rows)

        # Disambiguate conflicting column names: a Source column that shares
        # its name with a Target column is aliased to "Source_<name>".
        # NOTE(review): the alias itself is not checked against existing
        # column names (e.g. Target already holding "Source_<name>").
        conflicts = set(target.columns) & set(source.columns)
        quote = self._quote_ident
        select_parts = []
        for col in source.columns:
            expr = "s." + quote(col)
            if col in conflicts:
                expr += " AS " + quote("Source_" + col)
            select_parts.append(expr)
        s_select = ", ".join(select_parts)

        sql = f"""
            SELECT t.*, {s_select}
            FROM __append_target__ t
            CROSS JOIN __append_source__ s
        """

        # Presumably a DuckDB connection supplied by the tool context; only
        # register/execute are used here.
        con = self.ctx.duckdb_con
        try:
            # Register inside the try so that a failure on the second
            # registration (or in the query) still drops whatever was created.
            con.register("__append_target__", target.to_arrow())
            con.register("__append_source__", source.to_arrow())
            result = pl.from_arrow(con.execute(sql).arrow())
        finally:
            con.execute("DROP VIEW IF EXISTS __append_target__")
            con.execute("DROP VIEW IF EXISTS __append_source__")

        return {"Output": result}
|