commit accf03131807fc3a25741ba4932719b1441578df
Author: Coder Agent <coder@agents.omni>
Date: Thu Feb 19 12:16:37 2026
Add graph IR build/compile foundation with applicative fork detection
Task-Id: t-531
diff --git a/Omni/Agent/Prompt/IR.hs b/Omni/Agent/Prompt/IR.hs
index 94c686ff..77c2551b 100644
--- a/Omni/Agent/Prompt/IR.hs
+++ b/Omni/Agent/Prompt/IR.hs
@@ -1,4 +1,6 @@
{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NoImplicitPrelude #-}
@@ -61,6 +63,26 @@ module Omni.Agent.Prompt.IR
sectionsByPriority,
sectionsByComposition,
+ -- * Graph IR (t-531)
+ GraphExpr,
+ InferSpec (..),
+ NodeId (..),
+ GraphNode (..),
+ GraphNodeKind (..),
+ ParallelHint (..),
+ PromptGraph (..),
+ OptimizationPass (..),
+ ExecutionPlan (..),
+ CompiledGraph (..),
+ graphInfer,
+ fork,
+ buildGraph,
+ compileGraph,
+ buildAndCompile,
+ eliminateDeadBranches,
+ fuseAdjacentInferences,
+ scheduleForParallelism,
+
-- * Testing
main,
test,
@@ -467,6 +489,391 @@ sectionsByComposition :: [Section] -> [(CompositionMode, [Section])]
sectionsByComposition sections =
[(m, filter ((== m) <. secCompositionMode) sections) | m <- [Hierarchical, Constraint, Additive, Contextual]]
+-- * Graph IR (t-531)
+
+-- | Primitive inference step used by graph expressions.
+data InferSpec = InferSpec
+ { isName :: Text,
+ isPrompt :: Text
+ }
+ deriving (Show, Eq, Generic)
+
+instance Aeson.ToJSON InferSpec where
+ toJSON spec =
+ Aeson.object
+ [ "name" .= isName spec,
+ "prompt" .= isPrompt spec
+ ]
+
+-- | Typed expression language for inference graphs.
+--
+-- The Applicative instance preserves structure, which allows us to detect
+-- independent branches (especially when marked with 'fork').
+data GraphExpr a where
+ GraphPure :: a -> GraphExpr a
+ GraphInfer :: InferSpec -> GraphExpr Text
+ GraphFork :: GraphExpr a -> GraphExpr a
+ GraphAp :: GraphExpr (a -> b) -> GraphExpr a -> GraphExpr b
+
+instance Functor GraphExpr where
+ fmap f expr = GraphPure f <*> expr
+
+instance Applicative GraphExpr where
+ pure = GraphPure
+ (<*>) = GraphAp
+
+-- | Smart constructor for inference leaves.
+graphInfer :: Text -> Text -> GraphExpr Text
+graphInfer name prompt = GraphInfer (InferSpec name prompt)
+
+-- | Mark a branch as explicitly forkable/parallel.
+fork :: GraphExpr a -> GraphExpr a
+fork = GraphFork
+
+-- | Stable node identifier in the compiled graph.
+newtype NodeId = NodeId {unNodeId :: Int}
+ deriving (Show, Eq, Ord, Generic)
+
+instance Aeson.ToJSON NodeId where
+ toJSON (NodeId n) = Aeson.toJSON n
+
+-- | Node operation kind in the graph IR.
+data GraphNodeKind
+ = GraphNodeInfer InferSpec
+ | GraphNodeApply
+ deriving (Show, Eq, Generic)
+
+instance Aeson.ToJSON GraphNodeKind where
+ toJSON (GraphNodeInfer spec) = Aeson.object ["type" .= ("infer" :: Text), "infer" .= spec]
+ toJSON GraphNodeApply = Aeson.object ["type" .= ("apply" :: Text)]
+
+-- | Single node in the graph IR.
+data GraphNode = GraphNode
+ { gnId :: NodeId,
+ gnKind :: GraphNodeKind,
+ gnDependsOn :: [NodeId]
+ }
+ deriving (Show, Eq, Generic)
+
+instance Aeson.ToJSON GraphNode where
+ toJSON node =
+ Aeson.object
+ [ "id" .= gnId node,
+ "kind" .= gnKind node,
+ "depends_on" .= gnDependsOn node
+ ]
+
+-- | Hint describing a detected parallel fan-out/fan-in region.
+data ParallelHint = ParallelHint
+ { phGroupId :: Int,
+ phMembers :: [NodeId],
+ phJoinNode :: NodeId
+ }
+ deriving (Show, Eq, Generic)
+
+instance Aeson.ToJSON ParallelHint where
+ toJSON hint =
+ Aeson.object
+ [ "group_id" .= phGroupId hint,
+ "members" .= phMembers hint,
+ "join_node" .= phJoinNode hint
+ ]
+
+-- | Graph IR produced by the build phase.
+data PromptGraph = PromptGraph
+ { pgNodes :: [GraphNode],
+ pgRoots :: [NodeId],
+ pgParallelHints :: [ParallelHint]
+ }
+ deriving (Show, Eq, Generic)
+
+instance Aeson.ToJSON PromptGraph where
+ toJSON graph =
+ Aeson.object
+ [ "nodes" .= pgNodes graph,
+ "roots" .= pgRoots graph,
+ "parallel_hints" .= pgParallelHints graph
+ ]
+
+-- | Optimization/compile passes.
+data OptimizationPass
+ = PassEliminateDeadBranches
+ | PassFuseAdjacentInferences
+ | PassScheduleParallelism
+ deriving (Show, Eq, Generic)
+
+instance Aeson.ToJSON OptimizationPass where
+ toJSON PassEliminateDeadBranches = Aeson.String "eliminate_dead_branches"
+ toJSON PassFuseAdjacentInferences = Aeson.String "fuse_adjacent_inferences"
+ toJSON PassScheduleParallelism = Aeson.String "schedule_parallelism"
+
+-- | Stage-based execution plan. Nodes in each stage are independent.
+newtype ExecutionPlan = ExecutionPlan
+ { epStages :: [[NodeId]]
+ }
+ deriving (Show, Eq, Generic)
+
+instance Aeson.ToJSON ExecutionPlan where
+ toJSON plan = Aeson.object ["stages" .= epStages plan]
+
+-- | Output of compile phase: optimized graph + execution schedule.
+data CompiledGraph = CompiledGraph
+ { cgGraph :: PromptGraph,
+ cgExecutionPlan :: ExecutionPlan,
+ cgPasses :: [OptimizationPass]
+ }
+ deriving (Show, Eq, Generic)
+
+instance Aeson.ToJSON CompiledGraph where
+ toJSON compiled =
+ Aeson.object
+ [ "graph" .= cgGraph compiled,
+ "execution_plan" .= cgExecutionPlan compiled,
+ "passes" .= cgPasses compiled
+ ]
+
+data BuildContext = BuildContext
+ { bcNextNodeId :: Int,
+ bcNextGroupId :: Int
+ }
+
+data BuildResult = BuildResult
+ { brNodes :: [GraphNode],
+ brRoots :: [NodeId],
+ brPendingForks :: [NodeId],
+ brHints :: [ParallelHint]
+ }
+
+initialBuildContext :: BuildContext
+initialBuildContext =
+ BuildContext
+ { bcNextNodeId = 0,
+ bcNextGroupId = 0
+ }
+
+emptyBuildResult :: BuildResult
+emptyBuildResult =
+ BuildResult
+ { brNodes = [],
+ brRoots = [],
+ brPendingForks = [],
+ brHints = []
+ }
+
+freshNodeId :: BuildContext -> (BuildContext, NodeId)
+freshNodeId ctx =
+ let n = bcNextNodeId ctx
+ in (ctx {bcNextNodeId = n + 1}, NodeId n)
+
+freshGroupId :: BuildContext -> (BuildContext, Int)
+freshGroupId ctx =
+ let n = bcNextGroupId ctx
+ in (ctx {bcNextGroupId = n + 1}, n)
+
+-- | Build phase: convert an applicative expression into graph IR.
+--
+-- This phase is pure structure construction only (no scheduling/execution).
+buildGraph :: GraphExpr a -> PromptGraph
+buildGraph expr =
+ let (_, result) = buildExpr initialBuildContext expr
+ in PromptGraph
+ { pgNodes = sortOn gnId (brNodes result),
+ pgRoots = dedupe (brRoots result),
+ pgParallelHints = brHints result
+ }
+
+buildExpr :: BuildContext -> GraphExpr a -> (BuildContext, BuildResult)
+buildExpr ctx = \case
+ GraphPure _ -> (ctx, emptyBuildResult)
+ GraphInfer spec ->
+ let (ctx', nodeId) = freshNodeId ctx
+ node = GraphNode nodeId (GraphNodeInfer spec) []
+ in (ctx', emptyBuildResult {brNodes = [node], brRoots = [nodeId]})
+ GraphFork expr ->
+ let (ctx', result) = buildExpr ctx expr
+ pendingForks = dedupe (brRoots result <> brPendingForks result)
+ in (ctx', result {brPendingForks = pendingForks})
+ GraphAp left right ->
+ let (ctx1, leftResult) = buildExpr ctx left
+ (ctx2, rightResult) = buildExpr ctx1 right
+ deps = dedupe (brRoots leftResult <> brRoots rightResult)
+ mergedNodes = brNodes leftResult <> brNodes rightResult
+ pendingForks = dedupe (brPendingForks leftResult <> brPendingForks rightResult)
+ mergedHints = brHints leftResult <> brHints rightResult
+ in if null deps
+ then
+ ( ctx2,
+ BuildResult
+ { brNodes = mergedNodes,
+ brRoots = [],
+ brPendingForks = pendingForks,
+ brHints = mergedHints
+ }
+ )
+ else
+ let (ctx3, applyNodeId) = freshNodeId ctx2
+ applyNode = GraphNode applyNodeId GraphNodeApply deps
+ hasParallelGroup = length pendingForks >= 2
+ (ctx4, newHints, remainingForks) =
+ if hasParallelGroup
+ then
+ let (ctx', groupId) = freshGroupId ctx3
+ hint =
+ ParallelHint
+ { phGroupId = groupId,
+ phMembers = pendingForks,
+ phJoinNode = applyNodeId
+ }
+ in (ctx', [hint], [])
+ else (ctx3, [], pendingForks)
+ in ( ctx4,
+ BuildResult
+ { brNodes = mergedNodes <> [applyNode],
+ brRoots = [applyNodeId],
+ brPendingForks = remainingForks,
+ brHints = mergedHints <> newHints
+ }
+ )
+
+-- | Compile phase: optimize built graph and produce execution schedule.
+compileGraph :: PromptGraph -> CompiledGraph
+compileGraph graph =
+ let withoutDead = eliminateDeadBranches graph
+ fused = fuseAdjacentInferences withoutDead
+ schedule = scheduleForParallelism fused
+ in CompiledGraph
+ { cgGraph = fused,
+ cgExecutionPlan = schedule,
+ cgPasses =
+ [ PassEliminateDeadBranches,
+ PassFuseAdjacentInferences,
+ PassScheduleParallelism
+ ]
+ }
+
+-- | Convenience entry point for build + compile.
+buildAndCompile :: GraphExpr a -> CompiledGraph
+buildAndCompile = compileGraph <. buildGraph
+
+-- | Optimization pass: remove nodes not reachable from the graph roots.
+eliminateDeadBranches :: PromptGraph -> PromptGraph
+eliminateDeadBranches graph =
+ let reachable = collectReachable (pgRoots graph) (pgNodes graph)
+ isReachable nodeId = nodeId `elem` reachable
+ liveNodes = filter (isReachable <. gnId) (pgNodes graph)
+ liveHints =
+ mapMaybe
+ (\hint ->
+ let members = filter isReachable (phMembers hint)
+ in if isReachable (phJoinNode hint) && length members >= 2
+ then Just hint {phMembers = members}
+ else Nothing
+ )
+ (pgParallelHints graph)
+ in graph
+ { pgNodes = liveNodes,
+ pgParallelHints = liveHints
+ }
+
+collectReachable :: [NodeId] -> [GraphNode] -> [NodeId]
+collectReachable roots nodes = go [] roots
+ where
+ go seen [] = seen
+ go seen (nodeId : rest)
+ | nodeId `elem` seen = go seen rest
+ | otherwise =
+ let deps = maybe [] gnDependsOn (List.find ((== nodeId) <. gnId) nodes)
+ in go (nodeId : seen) (deps <> rest)
+
+-- | Optimization pass: fuse linear infer→infer chains.
+--
+-- This intentionally skips nodes participating in detected parallel groups.
+fuseAdjacentInferences :: PromptGraph -> PromptGraph
+fuseAdjacentInferences graph =
+ case findFuseCandidate graph of
+ Nothing -> graph
+ Just candidate -> fuseAdjacentInferences (applyFusion graph candidate)
+
+findFuseCandidate :: PromptGraph -> Maybe (NodeId, NodeId, InferSpec, InferSpec)
+findFuseCandidate graph = go (pgNodes graph)
+ where
+ nodes = pgNodes graph
+ inParallel nodeId = any (((nodeId `elem`) <. phMembers)) (pgParallelHints graph)
+
+ go [] = Nothing
+ go (node : rest) =
+ case gnKind node of
+ GraphNodeInfer leftSpec
+ | not (inParallel (gnId node)) ->
+ let children = filter ((gnId node `elem`) <. gnDependsOn) nodes
+ in case children of
+ [child] ->
+ case gnKind child of
+ GraphNodeInfer rightSpec
+ | gnDependsOn child == [gnId node] && not (inParallel (gnId child)) ->
+ Just (gnId node, gnId child, leftSpec, rightSpec)
+ _ -> go rest
+ _ -> go rest
+ _ -> go rest
+
+applyFusion :: PromptGraph -> (NodeId, NodeId, InferSpec, InferSpec) -> PromptGraph
+applyFusion graph (leftId, rightId, leftSpec, rightSpec) =
+ let leftDeps =
+ fromMaybe
+ []
+ (gnDependsOn </ List.find ((== leftId) <. gnId) (pgNodes graph))
+ fusedSpec =
+ InferSpec
+ { isName = isName leftSpec <> "+" <> isName rightSpec,
+ isPrompt = isPrompt leftSpec <> "\n\n" <> isPrompt rightSpec
+ }
+ replaceId nodeId = if nodeId == leftId then rightId else nodeId
+ rewriteNode node
+ | gnId node == rightId =
+ node
+ { gnKind = GraphNodeInfer fusedSpec,
+ gnDependsOn = leftDeps
+ }
+ | gnId node == leftId = node
+ | otherwise = node {gnDependsOn = dedupe (map replaceId (gnDependsOn node))}
+ rewrittenNodes = map rewriteNode (pgNodes graph)
+ prunedNodes = filter ((/= leftId) <. gnId) rewrittenNodes
+ rewrittenRoots = dedupe (map replaceId (pgRoots graph))
+ rewrittenHints =
+ mapMaybe
+ (\hint ->
+ let members = dedupe (map replaceId (phMembers hint))
+ joinNode = replaceId (phJoinNode hint)
+ in if length members >= 2
+ then Just hint {phMembers = members, phJoinNode = joinNode}
+ else Nothing
+ )
+ (pgParallelHints graph)
+ in graph
+ { pgNodes = prunedNodes,
+ pgRoots = rewrittenRoots,
+ pgParallelHints = rewrittenHints
+ }
+
+-- | Optimization pass: produce stage-wise parallel schedule from DAG deps.
+scheduleForParallelism :: PromptGraph -> ExecutionPlan
+scheduleForParallelism graph =
+ ExecutionPlan (go [] (sortOn gnId (pgNodes graph)))
+ where
+ go _ [] = []
+ go completed remaining =
+ let readyNodes = filter (all (`elem` completed) <. gnDependsOn) remaining
+ readyIds = map gnId readyNodes
+ remaining' = filter ((`notElem` readyIds) <. gnId) remaining
+ in if null readyIds
+ then [map gnId remaining]
+ else readyIds : go (completed <> readyIds) remaining'
+
+dedupe :: (Eq a) => [a] -> [a]
+dedupe = reverse <. foldl' step []
+ where
+ step acc item = if item `elem` acc then acc else item : acc
+
-- * Tests
main :: IO ()
@@ -552,5 +959,57 @@ test =
let s = defaultContextStrategy
(csTemporalWindow s > 0) Test.@=? True
(csSemanticThreshold s >= 0 && csSemanticThreshold s <= 1) Test.@=? True
- (csRecencyDecay s > 0 && csRecencyDecay s <= 1) Test.@=? True
+ (csRecencyDecay s > 0 && csRecencyDecay s <= 1) Test.@=? True,
+ Test.unit "Applicative + fork detects a parallel fanout" <| do
+ let x = fork (graphInfer "x" "analyze X")
+ y = fork (graphInfer "y" "analyze Y")
+ graph = buildGraph ((,) <$> x <*> y)
+ length (pgParallelHints graph) Test.@=? 1
+ case pgParallelHints graph of
+ [hint] -> do
+ length (phMembers hint) Test.@=? 2
+ (phJoinNode hint `elem` pgRoots graph) Test.@=? True
+ _ -> Test.assertFailure "Expected exactly one parallel hint",
+ Test.unit "eliminateDeadBranches prunes unreachable nodes" <| do
+ let reachable = GraphNode (NodeId 1) (GraphNodeInfer (InferSpec "a" "prompt a")) []
+ dead = GraphNode (NodeId 2) (GraphNodeInfer (InferSpec "dead" "prompt dead")) []
+ graph =
+ PromptGraph
+ { pgNodes = [reachable, dead],
+ pgRoots = [NodeId 1],
+ pgParallelHints = []
+ }
+ pruned = eliminateDeadBranches graph
+ map gnId (pgNodes pruned) Test.@=? [NodeId 1],
+ Test.unit "fuseAdjacentInferences merges linear infer chain" <| do
+ let nodeA = GraphNode (NodeId 1) (GraphNodeInfer (InferSpec "a" "prompt a")) []
+ nodeB = GraphNode (NodeId 2) (GraphNodeInfer (InferSpec "b" "prompt b")) [NodeId 1]
+ graph =
+ PromptGraph
+ { pgNodes = [nodeA, nodeB],
+ pgRoots = [NodeId 2],
+ pgParallelHints = []
+ }
+ fused = fuseAdjacentInferences graph
+ length (pgNodes fused) Test.@=? 1
+ case pgNodes fused of
+ [node] -> do
+ gnId node Test.@=? NodeId 2
+ gnDependsOn node Test.@=? []
+ case gnKind node of
+ GraphNodeInfer spec -> do
+ isName spec Test.@=? "a+b"
+ Text.isInfixOf "prompt a" (isPrompt spec) Test.@=? True
+ Text.isInfixOf "prompt b" (isPrompt spec) Test.@=? True
+ _ -> Test.assertFailure "Expected fused infer node"
+ _ -> Test.assertFailure "Expected one fused node",
+ Test.unit "compileGraph produces stage schedule with parallel first stage" <| do
+ let x = fork (graphInfer "x" "analyze X")
+ y = fork (graphInfer "y" "analyze Y")
+ compiled = buildAndCompile ((,) <$> x <*> y)
+ stages = epStages (cgExecutionPlan compiled)
+ length stages Test.@=? 3
+ case stages of
+ (firstStage : _) -> length firstStage Test.@=? 2
+ _ -> Test.assertFailure "Expected at least one stage"
]