96 lines
3.1 KiB
Python
96 lines
3.1 KiB
Python
from __future__ import annotations
|
|
from collections import defaultdict, deque
|
|
import polars as pl
|
|
|
|
from .graph import WorkflowGraph, ConnectionDef
|
|
from .context import RunContext
|
|
from tools import get_tool_class
|
|
|
|
|
|
def execute(graph: WorkflowGraph, ctx: RunContext) -> dict[tuple[int, str], pl.DataFrame]:
    """Execute a WorkflowGraph in topological (BFS) order.

    Scheduling follows Kahn's algorithm: a node is queued once every one of
    its incoming connections has had its origin processed. A node whose
    plugin has no registered tool class is not executed; it is handed to
    ``_passthrough`` instead.

    Args:
        graph: Workflow to run. ``graph.nodes`` is indexed by tool id and
            ``graph.connections`` is iterated for the edges between nodes.
        ctx: Run context passed to every tool constructor. ``ctx.verbose``
            enables progress messages on stdout.

    Returns:
        Mapping of ``(tool_id, anchor)`` to the DataFrame stored for that
        output anchor. Nodes that could never be scheduled contribute no
        entries; in verbose mode they are reported with a ``[WARN]`` line.
    """
    in_degree: dict[int, int] = defaultdict(int)
    successors: dict[int, list[ConnectionDef]] = defaultdict(list)
    predecessors: dict[int, list[ConnectionDef]] = defaultdict(list)

    for c in graph.connections:
        in_degree[c.dest_id] += 1
        successors[c.origin_id].append(c)
        predecessors[c.dest_id].append(c)

    # Nodes without incoming connections need an explicit zero entry so that
    # they seed the queue below.
    for tid in graph.nodes:
        in_degree.setdefault(tid, 0)

    # (tool_id, anchor) → DataFrame
    outputs: dict[tuple[int, str], pl.DataFrame] = {}
    queue: deque[int] = deque(tid for tid, deg in in_degree.items() if deg == 0)
    processed: set[int] = set()

    while queue:
        tid = queue.popleft()
        processed.add(tid)
        node = graph.nodes[tid]
        tool_cls = get_tool_class(node.plugin)

        if tool_cls is None:
            if ctx.verbose:
                print(f"[SKIP] ToolID={tid} plugin={node.plugin!r} (unsupported)")
            _passthrough(tid, predecessors, outputs, successors, in_degree, queue)
            continue

        tool = tool_cls(node, ctx)
        inputs = _collect_inputs(predecessors[tid], outputs)

        if ctx.verbose:
            print(f"[RUN ] ToolID={tid} plugin={node.plugin!r}")

        result = tool.execute(inputs)
        for anchor, df in result.items():
            outputs[(tid, anchor)] = df

        _release_successors(successors[tid], in_degree, queue)

    # A node is only dequeued after all of its incoming connections were
    # released, so anything left over sits on a cycle or waits on an origin
    # that was never processed. Surface that instead of dropping it silently.
    if ctx.verbose:
        stranded = [tid for tid in graph.nodes if tid not in processed]
        if stranded:
            print(
                f"[WARN] {len(stranded)} tool(s) never ran "
                f"(cycle or unmet upstream): {stranded}"
            )

    return outputs


def _collect_inputs(
    incoming: list[ConnectionDef],
    outputs: dict[tuple[int, str], pl.DataFrame],
) -> dict[str, pl.DataFrame]:
    """Build one tool's input mapping from its incoming connections."""
    # Count connections per destination anchor to spot multi-input anchors
    # (tools like Union receive several connections on the same anchor).
    anchor_counts: dict[str, int] = defaultdict(int)
    for c in incoming:
        anchor_counts[c.dest_anchor] += 1

    inputs: dict[str, pl.DataFrame] = {}
    for c in incoming:
        df = outputs.get((c.origin_id, c.origin_anchor))
        if df is None:
            # Nothing was stored for that origin anchor; leave the input out.
            continue
        # When several connections share a dest_anchor, key by the connection
        # name (e.g. '#1', '#2') so the frames do not overwrite each other.
        # NOTE(review): connections that share an anchor but have an empty
        # name still collide on dest_anchor and the last one wins.
        shared = anchor_counts[c.dest_anchor] > 1
        key = c.name if shared and c.name else c.dest_anchor
        inputs[key] = df
    return inputs


def _release_successors(
    outgoing: list[ConnectionDef],
    in_degree: dict[int, int],
    queue: deque[int],
) -> None:
    """Decrement each successor's in-degree and queue those that reach zero."""
    for c in outgoing:
        in_degree[c.dest_id] -= 1
        if in_degree[c.dest_id] == 0:
            queue.append(c.dest_id)
|
|
|
|
|
|
def _passthrough(
    tid: int,
    predecessors: dict[int, list[ConnectionDef]],
    outputs: dict[tuple[int, str], pl.DataFrame],
    successors: dict[int, list[ConnectionDef]],
    in_degree: dict[int, int],
    queue: deque[int],
) -> None:
    """Forward one upstream output through a node that is treated as a no-op.

    Only the first incoming connection is consulted. Whatever is stored for
    its origin anchor is re-published under ``(tid, "Output")``; if the node
    has no incoming connection, or nothing is stored for that origin anchor,
    an empty DataFrame is published instead. Afterwards every outgoing
    connection is released, and destinations whose in-degree reaches zero are
    appended to ``queue``.

    Args:
        tid: Id of the node being passed through.
        predecessors: Incoming connections keyed by destination node id.
        outputs: ``(tool_id, anchor)`` → DataFrame store; updated in place.
        successors: Outgoing connections keyed by origin node id.
        in_degree: Remaining unreleased incoming connections per node id;
            updated in place.
        queue: Queue of node ids ready to run; appended to in place.
    """
    incoming = predecessors.get(tid, [])
    source_key = (incoming[0].origin_id, incoming[0].origin_anchor) if incoming else None

    # Membership test (not a None check) so a stored value is forwarded as-is.
    if source_key is not None and source_key in outputs:
        forwarded = outputs[source_key]
    else:
        forwarded = pl.DataFrame()
    outputs[(tid, "Output")] = forwarded

    for edge in successors.get(tid, []):
        target = edge.dest_id
        remaining = in_degree[target] - 1
        in_degree[target] = remaining
        if remaining == 0:
            queue.append(target)
|