131 lines
5.3 KiB
Python
131 lines
5.3 KiB
Python
from __future__ import annotations
|
|
import csv
|
|
import io
|
|
from pathlib import Path
|
|
from typing import Dict
|
|
import xml.etree.ElementTree as ET
|
|
import polars as pl
|
|
from tools.base import BaseTool
|
|
|
|
|
|
class OutputDataTool(BaseTool):
    """Write the incoming ``Input`` frame to disk as described by the tool's XML config.

    Settings read from ``self.config``: the ``File`` element (text = output
    path, attributes ``FileFormat`` and ``MaxRecords``), the optional
    ``FormatSpecificOptions`` element, and ``MultiFile`` / ``MultiFileField``
    / ``MultiFileType`` / ``KeepField``. The tool produces no downstream
    frames; ``execute`` always returns ``{}``.
    """

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Write ``inputs["Input"]`` to the configured output path.

        Nothing is written when there is no config, the input is missing or
        empty, or the ``File`` element has no non-blank path.

        Write modes, in order of precedence:

        * multi-file: ``MultiFile`` has ``value="True"`` (case-insensitive) and
          ``MultiFileField`` names an existing column -> one file per distinct
          value of that column (see ``_write_partitioned``);
        * chunked: ``MaxRecords`` is a positive integer -> files of at most
          that many rows (see ``_write_chunked``);
        * otherwise a single file at the resolved output path.

        Args:
            inputs: Incoming frames keyed by connection name; only ``"Input"``
                is read.

        Returns:
            Always an empty dict.

        Raises:
            ValueError: If ``FileFormat`` or ``MaxRecords`` is not an integer.
        """
        df = inputs.get("Input", pl.DataFrame())
        if self.config is None or df.is_empty():
            return {}

        file_el = self.config.find("File")
        raw_path = (file_el.text or "").strip() if file_el is not None else ""
        if not raw_path:
            # Missing, empty, or whitespace-only path: nowhere to write.
            return {}

        fmt = int(file_el.attrib.get("FileFormat", "0"))
        max_records_str = (file_el.attrib.get("MaxRecords") or "").strip()
        max_records = int(max_records_str) if max_records_str else None

        # Absent format options behave like an empty element: every
        # findtext() on it returns None, so the defaults in _write apply.
        opts = self.config.find("FormatSpecificOptions")
        if opts is None:
            opts = ET.Element("x")

        multi_el = self.config.find("MultiFile")
        multi_file = (
            multi_el is not None
            and multi_el.attrib.get("value", "False").strip().lower() == "true"
        )
        multi_field = (self.config.findtext("MultiFileField") or "").strip()
        multi_type = (self.config.findtext("MultiFileType") or "Suffix").strip()
        keep_field = (self.config.findtext("KeepField") or "True").strip().lower() == "true"

        out_path = self.ctx.resolve_output_path(raw_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)

        if multi_file and multi_field and multi_field in df.columns:
            self._write_partitioned(
                df, out_path, fmt, opts, multi_field, multi_type, keep_field
            )
        elif max_records is not None and max_records > 0:
            self._write_chunked(df, out_path, fmt, opts, max_records)
        else:
            # Also covers MaxRecords <= 0: a non-positive slice size would make
            # range() yield nothing (negative) or raise (zero), so such values
            # are treated as "no limit".
            self._write(df, out_path, fmt, opts)

        if self.ctx.verbose:
            print(f"[Output] Wrote {len(df)} rows → {out_path}")

        return {}

    def _write_partitioned(
        self,
        df: pl.DataFrame,
        out_path: Path,
        fmt: int,
        opts: ET.Element,
        field: str,
        mode: str,
        keep_field: bool,
    ) -> None:
        """Write one file per distinct value of column ``field``.

        Each partition's file name comes from ``_multi_path(out_path,
        str(value), mode)``. When ``keep_field`` is false the ``field`` column
        is dropped from each written partition.
        """
        # partition_by is used instead of filtering on ``pl.col(field) == value``:
        # equality against a null value yields null, so the null group would be
        # filtered out entirely. partition_by keeps nulls as their own group.
        for part in df.partition_by(field, maintain_order=True):
            value = part[field][0]
            if not keep_field:
                part = part.drop(field)
            # NOTE: _multi_path replaces every character that is not
            # alphanumeric, "-" or "_" with "_", so distinct values such as
            # "a b" and "a_b" map to the same file and the later one overwrites.
            part_path = self._multi_path(out_path, str(value), mode)
            self._write(part, part_path, fmt, opts)

    def _write_chunked(
        self,
        df: pl.DataFrame,
        out_path: Path,
        fmt: int,
        opts: ET.Element,
        max_records: int,
    ) -> None:
        """Write ``df`` in slices of at most ``max_records`` rows (must be > 0).

        The first slice goes to ``out_path``; slice ``k`` (k >= 1) goes to
        ``<stem>_<k><suffix>`` in the same directory.
        """
        for chunk_num, start in enumerate(range(0, len(df), max_records)):
            chunk_path = (
                out_path
                if chunk_num == 0
                else out_path.with_stem(f"{out_path.stem}_{chunk_num}")
            )
            self._write(df.slice(start, max_records), chunk_path, fmt, opts)

    def _multi_path(self, base: Path, value: str, mode: str) -> Path:
        """Return ``base`` with a sanitised ``value`` joined onto its stem.

        Every character of ``value`` that is not alphanumeric, ``-`` or ``_``
        becomes ``_``. ``mode == "Suffix"`` yields ``<stem>_<value><suffix>``;
        any other mode yields ``<value>_<stem><suffix>``.
        """
        safe = "".join(c if c.isalnum() or c in "-_" else "_" for c in value)
        if mode == "Suffix":
            return base.with_stem(f"{base.stem}_{safe}")
        return base.with_stem(f"{safe}_{base.stem}")

    def _write(self, df: pl.DataFrame, path: Path, fmt: int, opts: ET.Element) -> None:
        """Write ``df`` to ``path`` in the format selected by ``fmt``.

        Handled codes: 0 and 6 -> CSV via ``_write_csv_alteryx`` (options read
        from ``opts``); 25 -> Excel; 2 -> Parquet; 19 (YXDB) -> Parquet at
        ``path`` with its suffix replaced by ``.parquet``. Any other code falls
        back to polars' default CSV writer.
        """
        if fmt in (0, 6):  # CSV
            # The misspelled tag is checked first, then the correct spelling.
            delim = opts.findtext("Delimeter") or opts.findtext("Delimiter") or ","
            # Accept both HeaderRow and Headers element names.
            header_val = opts.findtext("HeaderRow") or opts.findtext("Headers") or "True"
            header = header_val.strip().lower() != "false"
            line_end = (opts.findtext("LineEndStyle") or "LF").strip().upper()
            eol = "\r\n" if line_end == "CRLF" else "\n"
            force_quotes = (opts.findtext("ForceQuotes") or "False").strip().lower() == "true"
            self._write_csv_alteryx(df, path, delim, header, eol, force_quotes)
        elif fmt == 25:  # Excel
            df.write_excel(str(path))
        elif fmt == 2:  # Parquet
            df.write_parquet(str(path))
        elif fmt == 19:  # YXDB — no writer available; fall back to Parquet
            fallback = path.with_suffix(".parquet")
            df.write_parquet(str(fallback))
            if self.ctx.verbose:
                print(f"[Output] YXDB write not supported; wrote Parquet to {fallback}")
        else:
            # Unrecognised format code: use polars' default CSV writer.
            df.write_csv(str(path))

    @staticmethod
    def _write_csv_alteryx(
        df: pl.DataFrame,
        path: Path,
        delim: str,
        header: bool,
        eol: str,
        force_quotes: bool,
    ) -> None:
        """Write ``df`` as delimited UTF-8 text using Alteryx-style quoting.

        A field is wrapped in double quotes when it contains the delimiter
        (any length), a double quote, a single quote (apostrophe), CR or LF;
        embedded double quotes are doubled. With ``force_quotes`` every
        non-null data field is quoted. Nulls are written as empty, unquoted
        fields, and other values are rendered with ``str()``.

        Header names are quoted only when their content requires it, because
        ``force_quotes`` applies to data fields only. Every record, including
        the header, ends with ``eol``. The file is opened with ``newline=""``
        so ``eol`` is written verbatim.

        Quoting is implemented by hand; the ``csv`` module is not used.
        """
        special_chars = frozenset({'"', "'", "\n", "\r"})

        def has_special(text: str) -> bool:
            # Substring test so multi-character delimiters are detected as well.
            return delim in text or any(c in special_chars for c in text)

        def quote(text: str) -> str:
            return '"' + text.replace('"', '""') + '"'

        def render(value: object) -> str:
            if value is None:
                return ""
            text = str(value)
            return quote(text) if force_quotes or has_special(text) else text

        with open(path, "w", newline="", encoding="utf-8") as f:
            if header:
                names = (quote(c) if has_special(c) else c for c in df.columns)
                f.write(delim.join(names) + eol)
            for row in df.iter_rows():
                f.write(delim.join(render(v) for v in row) + eol)
|