Pyteryx/alteryx_runner/engine/executor.py

96 lines
3.1 KiB
Python

from __future__ import annotations
from collections import defaultdict, deque
import polars as pl
from .graph import WorkflowGraph, ConnectionDef
from .context import RunContext
from tools import get_tool_class
def execute(graph: WorkflowGraph, ctx: RunContext) -> dict[tuple, pl.DataFrame]:
    """Execute a WorkflowGraph in topological (BFS) order.

    Nodes are scheduled by in-degree: every node without incoming
    connections is queued first, and a node is queued once all of its
    incoming connections have been released by their origin nodes.

    Args:
        graph: Workflow to run. ``graph.nodes`` is indexed by tool id and
            ``graph.connections`` lists the edges between tool anchors.
        ctx: Run context passed to every tool instance. When
            ``ctx.verbose`` is true, one progress line per node is printed.

    Returns:
        Mapping of ``(tool_id, output_anchor)`` to the DataFrame produced
        on that anchor.

    Warns:
        RuntimeWarning: If some nodes in ``graph.nodes`` never became
            runnable (for example because of a cycle, or a connection whose
            origin is not a known node). The partial result is still
            returned.
    """
    # Local import: only needed for the "never executed" diagnostic below.
    import warnings

    in_degree: dict[int, int] = defaultdict(int)
    successors: dict[int, list[ConnectionDef]] = defaultdict(list)
    predecessors: dict[int, list[ConnectionDef]] = defaultdict(list)

    for c in graph.connections:
        in_degree[c.dest_id] += 1
        successors[c.origin_id].append(c)
        predecessors[c.dest_id].append(c)

    # Make sure source nodes (no incoming connection) are present with 0.
    for tid in graph.nodes:
        if tid not in in_degree:
            in_degree[tid] = 0

    # (tool_id, anchor) → DataFrame
    outputs: dict[tuple[int, str], pl.DataFrame] = {}

    queue: deque[int] = deque(
        tid for tid, deg in in_degree.items() if deg == 0
    )
    executed: set[int] = set()

    while queue:
        tid = queue.popleft()
        executed.add(tid)
        node = graph.nodes[tid]
        tool_cls = get_tool_class(node.plugin)

        if tool_cls is None:
            # Unsupported plugin: forward upstream data instead of running it.
            if ctx.verbose:
                print(f"[SKIP] ToolID={tid} plugin={node.plugin!r} (unsupported)")
            _passthrough(tid, predecessors, outputs, successors, in_degree, queue)
            continue

        tool = tool_cls(node, ctx)
        incoming = predecessors[tid]

        # Track duplicate dest_anchors to handle multi-input tools like Union
        anchor_counts: dict[str, int] = defaultdict(int)
        for c in incoming:
            anchor_counts[c.dest_anchor] += 1

        inputs: dict[str, pl.DataFrame] = {}
        for c in incoming:
            df = outputs.get((c.origin_id, c.origin_anchor))
            if df is None:
                # Upstream produced nothing on that anchor; leave the key out.
                continue
            key = c.dest_anchor
            # If multiple connections share the same dest_anchor,
            # use the connection name (e.g., '#1', '#2') as the key.
            # NOTE(review): unnamed connections sharing an anchor still
            # overwrite each other here (last one wins).
            if anchor_counts[c.dest_anchor] > 1 and c.name:
                key = c.name
            inputs[key] = df

        if ctx.verbose:
            print(f"[RUN ] ToolID={tid} plugin={node.plugin!r}")

        result = tool.execute(inputs)
        for anchor, df in result.items():
            outputs[(tid, anchor)] = df

        # Release downstream nodes; queue those whose inputs are all ready.
        for c in successors[tid]:
            in_degree[c.dest_id] -= 1
            if in_degree[c.dest_id] == 0:
                queue.append(c.dest_id)

    # Any node that was never dequeued was waiting on an input that never
    # arrived. Surface that instead of silently returning partial results.
    stalled = [tid for tid in graph.nodes if tid not in executed]
    if stalled:
        warnings.warn(
            f"{len(stalled)} node(s) were never executed because their inputs "
            f"never became available (cycle or dangling connection): "
            f"ToolIDs {stalled}",
            RuntimeWarning,
            stacklevel=2,
        )

    return outputs
def _passthrough(
    tid: int,
    predecessors: dict[int, list[ConnectionDef]],
    outputs: dict[tuple[int, str], pl.DataFrame],
    successors: dict[int, list[ConnectionDef]],
    in_degree: dict[int, int],
    queue: deque[int],
) -> None:
    """Propagate a single upstream output through a no-op node.

    The frame delivered by the node's first incoming connection (or an
    empty DataFrame when there is none, or when that upstream anchor has no
    recorded output) is published as this node's output. Downstream nodes
    are then released exactly as if the node had executed.

    Args:
        tid: Id of the node being skipped.
        predecessors: Incoming connections, keyed by destination tool id.
        outputs: ``(tool_id, anchor)`` → DataFrame store; updated in place.
        successors: Outgoing connections, keyed by origin tool id.
        in_degree: Count of unreleased incoming connections per tool id;
            decremented in place for every outgoing connection.
        queue: Scheduler queue; receives each downstream tool id whose
            in-degree drops to zero.
    """
    preds = predecessors.get(tid, [])
    downstream = successors.get(tid, [])

    df = None
    if preds:
        first = preds[0]
        df = outputs.get((first.origin_id, first.origin_anchor))
    if df is None:
        # Nothing upstream to forward: fall back to an empty frame.
        df = pl.DataFrame()

    outputs[(tid, "Output")] = df

    for c in downstream:
        # Consumers look inputs up by (origin_id, origin_anchor). Publish the
        # frame under the anchor this connection actually reads from, so
        # successors wired to an anchor other than "Output" still get data.
        outputs[(tid, c.origin_anchor)] = df
        in_degree[c.dest_id] -= 1
        if in_degree[c.dest_id] == 0:
            queue.append(c.dest_id)