36 lines
1.1 KiB
Python
36 lines
1.1 KiB
Python
from __future__ import annotations
|
|
from typing import Dict
|
|
import polars as pl
|
|
from tools.base import BaseTool
|
|
|
|
|
|
class AutoFieldTool(BaseTool):
    """Automatically shrink column types to the smallest that fits the data.

    Signed integer columns (Int16/Int32/Int64) are narrowed to the smallest
    signed integer type whose range covers every value in the column.
    Float64 columns are narrowed to Float32 only when that conversion is
    lossless for every value. Columns of any other dtype pass through
    unchanged.
    """

    # Candidate signed integer types, narrowest first, with their inclusive
    # value bounds. Checked in order, so the first match is the smallest fit.
    _SIGNED_INT_RANGES = (
        (pl.Int8, -(2**7), 2**7 - 1),
        (pl.Int16, -(2**15), 2**15 - 1),
        (pl.Int32, -(2**31), 2**31 - 1),
        (pl.Int64, -(2**63), 2**63 - 1),
    )

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Shrink every column of the ``"Input"`` frame.

        Args:
            inputs: Mapping of input names to frames. Only the ``"Input"``
                key is read; a missing key is treated as an empty frame.

        Returns:
            A single-entry mapping ``{"Output": frame}``. An empty input
            frame is returned as-is, without touching its dtypes.
        """
        df = inputs.get("Input", pl.DataFrame())
        if df.is_empty():
            return {"Output": df}

        shrunk_columns = [self._shrink(df[name]) for name in df.columns]
        return {"Output": pl.DataFrame(shrunk_columns)}

    def _shrink(self, s: pl.Series) -> pl.Series:
        """Return ``s`` cast to the narrowest dtype that holds its values.

        Args:
            s: The column to inspect.

        Returns:
            A narrowed copy of ``s`` when a smaller type represents every
            value exactly; otherwise ``s`` itself.
        """
        if s.dtype in (pl.Int16, pl.Int32, pl.Int64):
            lowest, highest = s.min(), s.max()
            if lowest is None or highest is None:
                # No non-null values: nothing constrains the width, so the
                # narrowest candidate is safe.
                return s.cast(pl.Int8)
            # Pick the type from the observed range rather than attempting
            # casts and swallowing whatever exception they raise.
            for dtype, lower, upper in self._SIGNED_INT_RANGES:
                if lower <= lowest and highest <= upper:
                    return s.cast(dtype)
            return s

        if s.dtype == pl.Float64:
            # strict=False so that a value Float32 cannot represent does not
            # abort the attempt; the round-trip comparison below rejects it.
            narrowed = s.cast(pl.Float32, strict=False)
            # Keep the narrower type only if widening it back reproduces the
            # original column exactly, i.e. no precision or magnitude lost.
            # NOTE(review): this relies on Series.equals treating NaN as
            # equal to NaN; if the installed polars does not, columns
            # containing NaN simply stay Float64 (safe, just not shrunk).
            if narrowed.cast(pl.Float64).equals(s):
                return narrowed
            return s

        # Float32 and every non-numeric dtype are already as small as this
        # tool makes them.
        return s
|