Pyteryx/alteryx_runner/tools/preparation/sample_tool.py

51 lines
1.8 KiB
Python

from __future__ import annotations
import random
from typing import Dict
import polars as pl
from tools.base import BaseTool
class SampleTool(BaseTool):
    """Keep a subset of the rows of the ``Input`` frame.

    The tool reads two settings through ``self._cfg``: ``Mode`` (default
    ``"First"``) and ``N`` (default ``"1"``, converted with ``int``).  When
    the configuration lists ``GroupFields/Field`` elements whose ``name``
    matches a column of the input, sampling is applied to each group
    separately and the per-group results are concatenated.

    NOTE(review): ``self.config`` is used like an ElementTree element
    (``findall`` / ``attrib``) and ``self._cfg`` is presumably provided by
    ``BaseTool`` -- confirm against the base class.
    """

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Sample the ``"Input"`` frame and return it under ``"Output"``.

        Args:
            inputs: Mapping of anchor name to frame; only ``"Input"`` is read.
                A missing entry is treated as an empty frame.

        Returns:
            A dict with the single key ``"Output"``.  The input is passed
            through unchanged when there is no configuration or the input
            frame is empty.

        Raises:
            ValueError: If the configured ``N`` cannot be converted by ``int``.
        """
        df = inputs.get("Input", pl.DataFrame())
        if self.config is None or df.is_empty():
            return {"Output": df}

        mode = self._cfg("Mode", "First") or "First"
        n = int(self._cfg("N", "1") or "1")

        # Only group on fields that actually exist in the input.  A <Field>
        # element without a "name" attribute is skipped instead of raising
        # KeyError (None is never a column name).
        known_columns = set(df.columns)
        group_fields = [
            field_name
            for field_name in (
                f.attrib.get("name")
                for f in self.config.findall("GroupFields/Field")
            )
            if field_name in known_columns
        ]

        if not group_fields:
            return {"Output": self._sample_flat(df, mode, n)}

        # maintain_order=True keeps groups in order of first appearance.
        parts: list[pl.DataFrame] = [
            self._sample_flat(group_df, mode, n)
            for _, group_df in df.group_by(group_fields, maintain_order=True)
        ]
        return {"Output": pl.concat(parts) if parts else pl.DataFrame()}

    def _sample_flat(self, df: pl.DataFrame, mode: str, n: int) -> pl.DataFrame:
        """Apply one sampling mode to a single (ungrouped) frame.

        Args:
            df: Frame to sample.
            mode: One of ``"First"``, ``"Last"``, ``"Skip"``, ``"Sample"``,
                ``"Random"`` or ``"NPercent"``.  Any other value returns
                ``df`` unchanged.
            n: The mode's numeric parameter (row count, step, odds
                denominator or percentage, depending on ``mode``).

        Returns:
            The sampled frame.
        """
        total = len(df)

        if mode in ("First", "Last", "Skip"):
            # polars gives negative arguments a special meaning (head/tail:
            # "all rows except the last/first |n|"; slice: offset counted
            # from the end).  A negative row count is meaningless for these
            # modes, so treat it as zero.
            n = max(n, 0)

        if mode == "First":
            return df.head(n)

        if mode == "Last":
            return df.tail(n)

        if mode == "Skip":
            # Drop the first n rows; skipping everything yields an empty
            # frame that still carries the input schema.
            if n < total:
                return df.slice(n, total - n)
            return pl.DataFrame(schema=df.schema)

        if mode == "Sample":
            # Every n-th row starting with the first one; a non-positive
            # step leaves the frame untouched.
            if n <= 0:
                return df
            indices = list(range(0, total, n))
            return df[indices]

        if mode == "Random":
            # Each row is kept independently with probability 1/n, using one
            # draw from the module-level ``random`` generator per row.
            if n <= 0:
                return df
            mask = [random.random() < (1.0 / n) for _ in range(total)]
            return df.filter(pl.Series(mask))

        if mode == "NPercent":
            # First n percent of the rows, rounded down but never fewer than
            # one row.
            count = max(1, int(total * n / 100))
            return df.head(count)

        # Unknown mode: pass the data through unchanged.
        return df