51 lines
1.8 KiB
Python
51 lines
1.8 KiB
Python
from __future__ import annotations
|
|
import random
|
|
from typing import Dict
|
|
import polars as pl
|
|
from tools.base import BaseTool
|
|
|
|
|
|
class SampleTool(BaseTool):
    """Select a subset of the rows of the ``"Input"`` frame.

    Settings read by this class:

    * ``Mode`` -- ``First``, ``Last``, ``Skip``, ``Sample``, ``Random`` or
      ``NPercent``; defaults to ``First``. Any other value passes the rows
      through unchanged.
    * ``N`` -- integer argument of the mode; defaults to ``1``.
    * ``GroupFields/Field`` -- optional child elements of ``self.config``
      whose ``name`` attribute names a column. When at least one of them
      names a column present in the input, the mode is applied to every
      group separately and the per-group results are concatenated.

    NOTE(review): ``self.config`` and ``self._cfg`` come from ``BaseTool``;
    ``self.config`` is used here like an XML element (``findall`` /
    ``attrib``) and ``self._cfg(name, default)`` like a settings lookup --
    confirm against ``BaseTool``.
    """

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Sample the ``"Input"`` frame and return it under ``"Output"``.

        Args:
            inputs: Mapping of input names to frames. Only ``"Input"`` is
                read; a missing entry is treated as an empty frame.

        Returns:
            A one-entry dict ``{"Output": frame}``. The input frame is
            returned unchanged when there is no configuration or the input
            is empty.

        Raises:
            ValueError: If the ``N`` setting is not an integer.
        """
        df = inputs.get("Input", pl.DataFrame())
        if self.config is None or df.is_empty():
            return {"Output": df}

        mode = self._cfg("Mode", "First") or "First"

        raw_n = self._cfg("N", "1") or "1"
        try:
            n = int(raw_n)
        except ValueError as err:
            # Same exception type as before, but say which setting is wrong.
            raise ValueError(
                f"SampleTool: 'N' must be an integer, got {raw_n!r}"
            ) from err

        # Keep only group fields that actually exist in the input. Field
        # elements without a ``name`` attribute are skipped instead of
        # raising KeyError.
        columns = set(df.columns)
        group_fields = [
            name
            for name in (
                field.attrib.get("name")
                for field in self.config.findall("GroupFields/Field")
            )
            if name in columns
        ]

        if not group_fields:
            return {"Output": self._sample_flat(df, mode, n)}

        # maintain_order=True keeps groups in order of first appearance, so
        # the output order is deterministic.
        parts: list[pl.DataFrame] = [
            self._sample_flat(group_df, mode, n)
            for _, group_df in df.group_by(group_fields, maintain_order=True)
        ]
        if not parts:
            # Not expected for a non-empty input; keep the schema regardless.
            return {"Output": pl.DataFrame(schema=df.schema)}
        return {"Output": pl.concat(parts)}

    def _sample_flat(self, df: pl.DataFrame, mode: str, n: int) -> pl.DataFrame:
        """Apply one sampling mode to a single (ungrouped) frame.

        Args:
            df: Frame to sample.
            mode: Sampling mode name (see below).
            n: Integer argument of the mode.

        Returns:
            The sampled frame:

            * ``First`` -- the first ``n`` rows.
            * ``Last`` -- the last ``n`` rows.
            * ``Skip`` -- everything after the first ``n`` rows.
            * ``Sample`` -- rows at positions ``0, n, 2n, ...``.
            * ``Random`` -- each row kept independently with probability
              ``1 / n`` (uses the module-level ``random`` generator).
            * ``NPercent`` -- the first ``n`` percent of rows, but always
              at least one row.
            * anything else -- ``df`` unchanged.

            For ``Sample`` and ``Random`` a non-positive ``n`` returns
            ``df`` unchanged. For ``First``, ``Last`` and ``Skip`` a
            negative ``n`` is treated as ``0``.
        """
        if mode in ("First", "Last", "Skip"):
            # polars gives negative arguments their own meaning (head/tail:
            # "all but |n| rows"; slice: offset counted from the end), which
            # is not what a row count of N means here.
            n = max(n, 0)

        if mode == "First":
            return df.head(n)

        if mode == "Last":
            return df.tail(n)

        if mode == "Skip":
            total = len(df)
            if n >= total:
                # Nothing left after skipping; keep the schema.
                return pl.DataFrame(schema=df.schema)
            return df.slice(n, total - n)

        if mode == "Sample":
            if n <= 0:
                return df
            indices = list(range(0, len(df), n))
            return df[indices]

        if mode == "Random":
            if n <= 0:
                return df
            keep_probability = 1.0 / n
            mask = [random.random() < keep_probability for _ in range(len(df))]
            # Explicit Boolean dtype so that an empty mask is still a valid
            # filter predicate.
            return df.filter(pl.Series(values=mask, dtype=pl.Boolean))

        if mode == "NPercent":
            # Integer arithmetic; max(1, ...) keeps at least one row even
            # when the percentage rounds down to zero.
            count = max(1, len(df) * n // 100)
            return df.head(count)

        # Unknown mode: pass the rows through untouched.
        return df
|