Pyteryx/alteryx_runner/tools/preparation/auto_field.py

36 lines
1.1 KiB
Python

from __future__ import annotations
from typing import Dict
import polars as pl
from tools.base import BaseTool
class AutoFieldTool(BaseTool):
    """Shrink each column to the narrowest dtype that holds its values exactly.

    Signed integer columns (Int16/Int32/Int64) are narrowed to the smallest
    signed integer type whose range covers the column's min and max.
    Float64 columns are narrowed to Float32 only when every value survives
    a Float32 round trip unchanged. All other dtypes pass through untouched.
    """

    # Signed integer candidates, narrowest first, paired with their bit width.
    _INT_LADDER = ((pl.Int8, 8), (pl.Int16, 16), (pl.Int32, 32), (pl.Int64, 64))

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Return the "Input" frame with every column passed through `_shrink`.

        Args:
            inputs: Mapping of anchor name to frame; only "Input" is read.
                A missing "Input" is treated as an empty frame.

        Returns:
            A dict with a single "Output" frame. An empty input frame is
            returned as-is, without any dtype changes.
        """
        df = inputs.get("Input", pl.DataFrame())
        if df.is_empty():
            return {"Output": df}
        shrunk = [self._shrink(df[name]) for name in df.columns]
        return {"Output": pl.DataFrame(shrunk)}

    def _shrink(self, s: pl.Series) -> pl.Series:
        """Return `s` cast to a narrower dtype when that is lossless, else `s`."""
        if s.dtype in (pl.Int64, pl.Int32, pl.Int16):
            return self._shrink_int(s)
        if s.dtype == pl.Float64:
            return self._shrink_float(s)
        # Float32, Int8 and every non-numeric dtype have nothing to narrow to.
        return s

    def _shrink_int(self, s: pl.Series) -> pl.Series:
        """Cast a signed integer series to the narrowest signed type in range."""
        lo, hi = s.min(), s.max()
        if lo is None or hi is None:
            # No non-null values at all, so the narrowest type trivially fits.
            return s.cast(pl.Int8)
        for dtype, bits in self._INT_LADDER:
            limit = 1 << (bits - 1)
            if -limit <= lo and hi < limit:
                return s.cast(dtype)
        return s

    def _shrink_float(self, s: pl.Series) -> pl.Series:
        """Cast a Float64 series to Float32 only if no value changes."""
        # strict=False: a value the cast cannot represent becomes null
        # instead of raising; the null-count check below catches that case.
        narrowed = s.cast(pl.Float32, strict=False)
        if narrowed.null_count() != s.null_count():
            return s
        restored = narrowed.cast(pl.Float64)
        # NaN is compared explicitly so the result does not depend on
        # whether the installed polars treats NaN == NaN as true.
        unchanged = (restored == s) | (restored.is_nan() & s.is_nan())
        # Positions that are null on both sides compare as null; count them
        # as unchanged.
        if unchanged.fill_null(True).all():
            return narrowed
        # Precision would be lost (or a value overflowed to inf): keep Float64.
        return s