100 lines
3.4 KiB
Python
100 lines
3.4 KiB
Python
from __future__ import annotations
|
|
from typing import Dict, Optional
|
|
import xml.etree.ElementTree as ET
|
|
import polars as pl
|
|
from tools.base import BaseTool
|
|
|
|
|
|
class InputDataTool(BaseTool):
    """Source tool that reads a single file into a polars DataFrame.

    The XML config (``self.config``) supplies a ``File`` element whose text is
    the source path (optionally ``path|||sheet`` for spreadsheets) and whose
    attributes select the reader (``FileFormat``) and an optional row cap
    (``RecordLimit``). Reader options come from the optional
    ``FormatSpecificOptions`` element.
    """

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Load the configured file and return it under the ``"Output"`` key.

        Args:
            inputs: Upstream frames; not used by this tool.

        Returns:
            ``{"Output": df}``. ``df`` is empty when there is no config, no
            ``File`` element, or the configured path is blank.

        Raises:
            ValueError: If ``FileFormat``, ``RecordLimit`` or ``ImportLine``
                is not an integer.
            NotImplementedError: If a YXDB file is requested but the ``yxdb``
                package is not installed.
        """
        if self.config is None:
            return {"Output": pl.DataFrame()}

        file_el = self.config.find("File")
        # Strip before testing so a whitespace-only path counts as "no file"
        # instead of being resolved as an empty path.
        raw_path = (file_el.text or "").strip() if file_el is not None else ""
        if not raw_path:
            return {"Output": pl.DataFrame()}

        # An empty attribute value is treated the same as a missing one.
        fmt = int(file_el.attrib.get("FileFormat", "").strip() or "0")
        record_limit_str = file_el.attrib.get("RecordLimit", "").strip()
        limit = int(record_limit_str) if record_limit_str else None

        opts = self.config.find("FormatSpecificOptions")
        if opts is None:
            # Empty stand-in so _read can call findtext() unconditionally.
            opts = ET.Element("FormatSpecificOptions")

        path_str, sheet = self._parse_path(raw_path)
        resolved = self.ctx.resolve_path(path_str)
        df = self._read(str(resolved), fmt, sheet, opts)

        # Apply the cap first so rows about to be discarded are not trimmed.
        # A limit of 0 (like no limit) keeps every row.
        if limit:
            df = df.head(limit)

        # Trim whitespace from string columns (matches Alteryx behavior),
        # using a single projection over all String columns.
        string_cols = [name for name, dtype in df.schema.items() if dtype == pl.String]
        if string_cols:
            df = df.with_columns(pl.col(string_cols).str.strip_chars())

        return {"Output": df}

    def _parse_path(self, raw: str) -> tuple[str, Optional[str]]:
        """Split a configured path into ``(file_path, sheet_name)``.

        Text after the first ``|||`` is treated as the sheet name: it is
        whitespace-stripped, surrounding backticks are removed and trailing
        ``$`` characters are dropped (e.g. ``"`Sheet1$`"`` -> ``"Sheet1"``).
        Without ``|||`` the sheet is ``None``.
        """
        if "|||" in raw:
            path, sheet = raw.split("|||", 1)
            return path.strip(), sheet.strip().strip("`").rstrip("$")
        return raw.strip(), None

    def _read(
        self,
        path: str,
        fmt: int,
        sheet: Optional[str],
        opts: ET.Element,
    ) -> pl.DataFrame:
        """Read ``path`` with the reader selected by the numeric ``fmt`` code.

        Handled codes: 0/6 delimited text, 25 Excel, 2 Parquet, 19 YXDB,
        56 JSON. Any other code falls back to a default CSV read.

        Args:
            path: Resolved file path.
            fmt: Format code from the ``FileFormat`` attribute.
            sheet: Sheet name for spreadsheets, or ``None``/``""`` for the
                first sheet.
            opts: ``FormatSpecificOptions`` element (possibly empty).
        """
        if fmt in (0, 6):  # CSV / delimited
            delim = opts.findtext("Delimeter") or opts.findtext("Delimiter") or ","
            # The config may spell tab as the two characters "\t"; polars
            # requires the actual single-byte separator.
            if delim == "\\t":
                delim = "\t"
            header_text = opts.findtext("HeaderRow") or "True"
            has_header = header_text.strip().lower() in ("true", "1", "yes")
            # ImportLine is 1-based: line N means skip N-1 leading rows.
            import_line = int(opts.findtext("ImportLine") or "1")
            skip = max(0, import_line - 1)
            return pl.read_csv(
                path,
                separator=delim,
                has_header=has_header,
                skip_rows=skip,
                infer_schema_length=10000,
                ignore_errors=True,
            )

        if fmt == 25:  # Excel
            # FirstRowData == "true" means the first row is data, not a header.
            first_row_data = (opts.findtext("FirstRowData") or "False").strip().lower()
            has_header = first_row_data != "true"
            import_line = int(opts.findtext("ImportLine") or "1")
            skip = max(0, import_line - 1)
            # `sheet_name` takes a sheet *name* (positional selection is
            # `sheet_id`), so pass None rather than 0 when no sheet is given;
            # with neither set, polars reads the first sheet.
            # `has_header` is a top-level read_excel keyword; `read_options`
            # is forwarded to the underlying engine.
            # NOTE(review): the meaning of `skip_rows` differs between polars
            # Excel engines (relative to the header row or not) — confirm.
            return pl.read_excel(
                path,
                sheet_name=sheet or None,
                has_header=has_header,
                read_options={"skip_rows": skip},
            )

        if fmt == 2:  # Parquet
            return pl.read_parquet(path)

        if fmt == 19:  # YXDB
            # Only the import is guarded, so errors raised while reading the
            # file are not misreported as a missing package.
            try:
                import yxdb
            except ImportError as err:
                raise NotImplementedError(
                    "YXDB format requires the 'yxdb' package: pip install yxdb"
                ) from err
            # NOTE(review): assumes yxdb.open_file() yields row records that
            # pl.DataFrame accepts — confirm against the yxdb package API.
            rows = list(yxdb.open_file(path))
            return pl.DataFrame(rows) if rows else pl.DataFrame()

        if fmt == 56:  # JSON
            return pl.read_json(path)

        # Fallback: try CSV
        return pl.read_csv(path, infer_schema_length=10000, ignore_errors=True)
|