"""Sample tool: select a subset of input rows, optionally per group."""

from __future__ import annotations

import random
from typing import Dict

import polars as pl

from tools.base import BaseTool


class SampleTool(BaseTool):
    """Return a subset of the ``Input`` rows according to the configured mode.

    Configuration read by this tool:

    * ``Mode`` -- one of ``First``, ``Last``, ``Sample``, ``Random`` or
      ``NPercent`` (default ``First``); any other value passes rows through.
    * ``N`` -- integer parameter for the mode (default ``1``).
    * ``GroupFields/Field`` elements with a ``name`` attribute -- optional
      grouping columns. Names that are missing or not present in the input
      are ignored. When at least one remains, the mode is applied to each
      group separately.
    """

    def execute(self, inputs: Dict[str, pl.DataFrame]) -> Dict[str, pl.DataFrame]:
        """Sample the ``"Input"`` frame and return it under ``"Output"``.

        Args:
            inputs: Mapping of input anchor name to frame. A missing
                ``"Input"`` is treated as an empty frame.

        Returns:
            ``{"Output": frame}``. The input is passed through unchanged when
            there is no configuration or the input is empty.

        Raises:
            ValueError: If the configured ``N`` is not an integer.
        """
        df = inputs.get("Input", pl.DataFrame())
        if self.config is None or df.is_empty():
            return {"Output": df}

        mode = self._cfg("Mode", "First") or "First"
        n = self._parse_n(self._cfg("N", "1") or "1")

        configured = (
            f.attrib.get("name") for f in self.config.findall("GroupFields/Field")
        )
        # dict.fromkeys drops repeated field names while keeping their order;
        # names that are absent (None) or not columns of df are skipped.
        group_fields = list(dict.fromkeys(c for c in configured if c in df.columns))

        if not group_fields:
            return {"Output": self._sample_flat(df, mode, n)}

        # maintain_order=True keeps groups in order of first appearance, so the
        # output is the per-group samples concatenated in that order.
        parts = [
            self._sample_flat(group_df, mode, n)
            for _, group_df in df.group_by(group_fields, maintain_order=True)
        ]
        # df.head(0) keeps the input schema if, unexpectedly, no groups come back.
        return {"Output": pl.concat(parts) if parts else df.head(0)}

    @staticmethod
    def _parse_n(raw: object) -> int:
        """Convert the configured ``N`` value to an ``int``.

        Integer text (``"5"``) and integral float text (``"5.0"``) are
        accepted. Anything else raises ``ValueError``.
        """
        text = str(raw).strip()
        try:
            return int(text)
        except ValueError:
            value = float(text)  # raises ValueError for non-numeric text
            if not value.is_integer():
                raise ValueError(f"N must be an integer, got {raw!r}") from None
            return int(value)

    def _sample_flat(self, df: pl.DataFrame, mode: str, n: int) -> pl.DataFrame:
        """Apply ``mode`` with parameter ``n`` to ``df`` as a single group.

        * ``First`` / ``Last``: first / last ``n`` rows (none when ``n <= 0``).
        * ``Sample``: every ``n``-th row starting with the first; all rows
          when ``n <= 0``.
        * ``Random``: each row is kept independently with probability
          ``1 / n``, using the module-level ``random`` generator (not seeded
          here); all rows when ``n <= 0``.
        * ``NPercent``: the first ``n`` percent of rows, rounded down but
          never fewer than one row.
        * Any other mode returns ``df`` unchanged.
        """
        if mode == "First":
            # Polars reads a negative head/tail length as "all but |n| rows",
            # so clamp to 0 instead of silently returning most of the frame.
            return df.head(max(n, 0))
        if mode == "Last":
            return df.tail(max(n, 0))
        if mode == "Sample":
            if n <= 0:
                return df
            return df[list(range(0, len(df), n))]
        if mode == "Random":
            if n <= 0:
                return df
            threshold = 1.0 / n
            mask = [random.random() < threshold for _ in range(len(df))]
            return df.filter(pl.Series(mask, dtype=pl.Boolean))
        if mode == "NPercent":
            count = max(1, int(len(df) * n / 100))
            return df.head(count)
        return df