← Back to task

Commit accf0313

commit accf03131807fc3a25741ba4932719b1441578df
Author: Coder Agent <coder@agents.omni>
Date:   Thu Feb 19 12:16:37 2026

    Add graph IR build/compile foundation with applicative fork detection
    
    Task-Id: t-531

diff --git a/Omni/Agent/Prompt/IR.hs b/Omni/Agent/Prompt/IR.hs
index 94c686ff..77c2551b 100644
--- a/Omni/Agent/Prompt/IR.hs
+++ b/Omni/Agent/Prompt/IR.hs
@@ -1,4 +1,6 @@
 {-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
 {-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE NoImplicitPrelude #-}
 
@@ -61,6 +63,26 @@ module Omni.Agent.Prompt.IR
     sectionsByPriority,
     sectionsByComposition,
 
+    -- * Graph IR (t-531)
+    GraphExpr,
+    InferSpec (..),
+    NodeId (..),
+    GraphNode (..),
+    GraphNodeKind (..),
+    ParallelHint (..),
+    PromptGraph (..),
+    OptimizationPass (..),
+    ExecutionPlan (..),
+    CompiledGraph (..),
+    graphInfer,
+    fork,
+    buildGraph,
+    compileGraph,
+    buildAndCompile,
+    eliminateDeadBranches,
+    fuseAdjacentInferences,
+    scheduleForParallelism,
+
     -- * Testing
     main,
     test,
@@ -467,6 +489,391 @@ sectionsByComposition :: [Section] -> [(CompositionMode, [Section])]
 sectionsByComposition sections =
   [(m, filter ((== m) <. secCompositionMode) sections) | m <- [Hierarchical, Constraint, Additive, Contextual]]
 
+-- * Graph IR (t-531)
+
+-- | Primitive inference step used by graph expressions.
+--
+-- Serialized by its 'Aeson.ToJSON' instance as an object with @name@ and
+-- @prompt@ keys.
+data InferSpec = InferSpec
+  { -- | Label of the step; 'applyFusion' joins the labels of fused steps
+    -- with @+@.
+    isName :: Text,
+    -- | Prompt text of the step; 'applyFusion' joins the prompts of fused
+    -- steps with a blank line.
+    isPrompt :: Text
+  }
+  deriving (Show, Eq, Generic)
+
+-- | Encode as an object with @name@ and @prompt@ keys.
+instance Aeson.ToJSON InferSpec where
+  toJSON (InferSpec name prompt) =
+    Aeson.object ["name" .= name, "prompt" .= prompt]
+
+-- | Typed expression language for inference graphs.
+--
+-- The Applicative instance preserves structure, which allows us to detect
+-- independent branches (especially when marked with 'fork').
+data GraphExpr a where
+  -- | Constant value; 'buildExpr' emits no graph node for it.
+  GraphPure :: a -> GraphExpr a
+  -- | Inference leaf; 'buildExpr' turns it into an infer node without
+  -- dependencies.
+  GraphInfer :: InferSpec -> GraphExpr Text
+  -- | Marks the wrapped expression as a candidate member of a parallel group.
+  GraphFork :: GraphExpr a -> GraphExpr a
+  -- | Application of a wrapped function to a wrapped argument; this is what
+  -- '<*>' constructs.
+  GraphAp :: GraphExpr (a -> b) -> GraphExpr a -> GraphExpr b
+
+-- | Mapping is expressed as application of a pure function, so the mapped
+-- expression stays visible to 'buildExpr' as a 'GraphAp'.
+instance Functor GraphExpr where
+  fmap f = GraphAp (GraphPure f)
+
+-- | Purely structural: 'pure' and '<*>' only wrap their arguments in
+-- constructors, so the shape of the expression is kept for 'buildExpr'.
+instance Applicative GraphExpr where
+  pure value = GraphPure value
+  fn <*> arg = GraphAp fn arg
+
+-- | Build an inference leaf from a step name and its prompt text.
+graphInfer :: Text -> Text -> GraphExpr Text
+graphInfer name prompt =
+  GraphInfer InferSpec {isName = name, isPrompt = prompt}
+
+-- | Wrap an expression in 'GraphFork', flagging it as explicitly
+-- forkable/parallel for the build phase.
+fork :: GraphExpr a -> GraphExpr a
+fork expr = GraphFork expr
+
+-- | Stable node identifier in the compiled graph.
+--
+-- 'buildExpr' hands these out sequentially starting from 0 (see
+-- 'initialBuildContext' and 'freshNodeId'). The derived 'Ord' instance is
+-- what 'buildGraph' and 'scheduleForParallelism' sort nodes by.
+newtype NodeId = NodeId {unNodeId :: Int}
+  deriving (Show, Eq, Ord, Generic)
+
+-- | Encode as the bare integer, without a wrapping object.
+instance Aeson.ToJSON NodeId where
+  toJSON = Aeson.toJSON <. unNodeId
+
+-- | Node operation kind in the graph IR.
+data GraphNodeKind
+  = -- | Inference step carrying its spec; serialized with @type@ set to
+    -- @infer@.
+    GraphNodeInfer InferSpec
+  | -- | Combination point that 'buildExpr' creates for a 'GraphAp';
+    -- serialized with @type@ set to @apply@.
+    GraphNodeApply
+  deriving (Show, Eq, Generic)
+
+-- | Encode as a tagged object: the @type@ key names the kind, and infer
+-- nodes additionally carry their spec under @infer@.
+instance Aeson.ToJSON GraphNodeKind where
+  toJSON = \case
+    GraphNodeInfer spec ->
+      Aeson.object
+        [ "type" .= ("infer" :: Text),
+          "infer" .= spec
+        ]
+    GraphNodeApply ->
+      Aeson.object ["type" .= ("apply" :: Text)]
+
+-- | Single node in the graph IR.
+data GraphNode = GraphNode
+  { -- | Identifier of this node.
+    gnId :: NodeId,
+    -- | What the node does (inference or apply).
+    gnKind :: GraphNodeKind,
+    -- | Ids of the nodes this one waits for; 'scheduleForParallelism' only
+    -- places a node in a stage once all of these are completed.
+    gnDependsOn :: [NodeId]
+  }
+  deriving (Show, Eq, Generic)
+
+-- | Encode as an object with @id@, @kind@ and @depends_on@ keys.
+instance Aeson.ToJSON GraphNode where
+  toJSON (GraphNode nodeId kind deps) =
+    Aeson.object
+      [ "id" .= nodeId,
+        "kind" .= kind,
+        "depends_on" .= deps
+      ]
+
+-- | Hint describing a detected parallel fan-out/fan-in region.
+data ParallelHint = ParallelHint
+  { -- | Group number, allocated sequentially by 'freshGroupId'.
+    phGroupId :: Int,
+    -- | Roots of the forked branches that make up the group. The passes in
+    -- this module drop hints that end up with fewer than two members.
+    phMembers :: [NodeId],
+    -- | Apply node at which the forked branches are combined.
+    phJoinNode :: NodeId
+  }
+  deriving (Show, Eq, Generic)
+
+-- | Encode as an object with @group_id@, @members@ and @join_node@ keys.
+instance Aeson.ToJSON ParallelHint where
+  toJSON (ParallelHint groupId members joinNode) =
+    Aeson.object
+      [ "group_id" .= groupId,
+        "members" .= members,
+        "join_node" .= joinNode
+      ]
+
+-- | Graph IR produced by the build phase.
+data PromptGraph = PromptGraph
+  { -- | All nodes of the graph; 'buildGraph' returns them sorted by id.
+    pgNodes :: [GraphNode],
+    -- | Output nodes. 'eliminateDeadBranches' keeps only what is reachable
+    -- from these through 'gnDependsOn'.
+    pgRoots :: [NodeId],
+    -- | Parallel groups detected while building.
+    pgParallelHints :: [ParallelHint]
+  }
+  deriving (Show, Eq, Generic)
+
+-- | Encode as an object with @nodes@, @roots@ and @parallel_hints@ keys.
+instance Aeson.ToJSON PromptGraph where
+  toJSON (PromptGraph nodes roots hints) =
+    Aeson.object
+      [ "nodes" .= nodes,
+        "roots" .= roots,
+        "parallel_hints" .= hints
+      ]
+
+-- | Optimization/compile passes.
+--
+-- 'compileGraph' records all three, in the order listed here, in 'cgPasses'.
+data OptimizationPass
+  = -- | Corresponds to 'eliminateDeadBranches'.
+    PassEliminateDeadBranches
+  | -- | Corresponds to 'fuseAdjacentInferences'.
+    PassFuseAdjacentInferences
+  | -- | Corresponds to 'scheduleForParallelism'.
+    PassScheduleParallelism
+  deriving (Show, Eq, Generic)
+
+-- | Encode each pass as its snake_case name in a JSON string.
+instance Aeson.ToJSON OptimizationPass where
+  toJSON = Aeson.String <. passName
+    where
+      passName = \case
+        PassEliminateDeadBranches -> "eliminate_dead_branches"
+        PassFuseAdjacentInferences -> "fuse_adjacent_inferences"
+        PassScheduleParallelism -> "schedule_parallelism"
+
+-- | Stage-based execution plan. Nodes in each stage are independent.
+--
+-- Exception: when 'scheduleForParallelism' finds no remaining node whose
+-- dependencies are all completed, it emits every remaining node as one final
+-- stage, so that last stage carries no independence guarantee.
+newtype ExecutionPlan = ExecutionPlan
+  { -- | Stages in execution order; each inner list is one stage.
+    epStages :: [[NodeId]]
+  }
+  deriving (Show, Eq, Generic)
+
+-- | Encode as an object with a single @stages@ key.
+instance Aeson.ToJSON ExecutionPlan where
+  toJSON (ExecutionPlan stages) =
+    Aeson.object ["stages" .= stages]
+
+-- | Output of compile phase: optimized graph + execution schedule.
+data CompiledGraph = CompiledGraph
+  { -- | Graph after dead-branch elimination and inference fusion.
+    cgGraph :: PromptGraph,
+    -- | Stage schedule; 'compileGraph' computes it from 'cgGraph'.
+    cgExecutionPlan :: ExecutionPlan,
+    -- | The passes recorded by 'compileGraph', in the order it lists them.
+    cgPasses :: [OptimizationPass]
+  }
+  deriving (Show, Eq, Generic)
+
+-- | Encode as an object with @graph@, @execution_plan@ and @passes@ keys.
+instance Aeson.ToJSON CompiledGraph where
+  toJSON (CompiledGraph graph plan passes) =
+    Aeson.object
+      [ "graph" .= graph,
+        "execution_plan" .= plan,
+        "passes" .= passes
+      ]
+
+-- | Counters threaded through 'buildExpr' so that node ids and parallel
+-- group ids are unique within one build.
+data BuildContext = BuildContext
+  { -- | Next unused node number (see 'freshNodeId').
+    bcNextNodeId :: Int,
+    -- | Next unused parallel-group number (see 'freshGroupId').
+    bcNextGroupId :: Int
+  }
+
+-- | What 'buildExpr' accumulates for one sub-expression.
+data BuildResult = BuildResult
+  { -- | Nodes created so far, in creation order.
+    brNodes :: [GraphNode],
+    -- | Nodes producing the value of the sub-expression; empty when the
+    -- sub-expression created no node.
+    brRoots :: [NodeId],
+    -- | Roots of 'fork'ed branches that have not yet been placed in a
+    -- 'ParallelHint'.
+    brPendingForks :: [NodeId],
+    -- | Parallel groups detected so far.
+    brHints :: [ParallelHint]
+  }
+
+-- | Starting counters for a build: node ids and group ids both begin at 0.
+initialBuildContext :: BuildContext
+initialBuildContext = BuildContext {bcNextNodeId = 0, bcNextGroupId = 0}
+
+-- | Build result with no nodes, roots, pending forks, or hints.
+emptyBuildResult :: BuildResult
+emptyBuildResult =
+  BuildResult {brNodes = [], brRoots = [], brPendingForks = [], brHints = []}
+
+-- | Take the next node id and return the context with its counter advanced.
+freshNodeId :: BuildContext -> (BuildContext, NodeId)
+freshNodeId ctx = (ctx {bcNextNodeId = next + 1}, NodeId next)
+  where
+    next = bcNextNodeId ctx
+
+-- | Take the next parallel-group id and return the context with its counter
+-- advanced.
+freshGroupId :: BuildContext -> (BuildContext, Int)
+freshGroupId ctx = (ctx {bcNextGroupId = next + 1}, next)
+  where
+    next = bcNextGroupId ctx
+
+-- | Build phase: lower an applicative expression into graph IR.
+--
+-- Only the structure is constructed here; nothing is scheduled or executed.
+-- Nodes come back sorted by id and the roots are de-duplicated.
+buildGraph :: GraphExpr a -> PromptGraph
+buildGraph expr =
+  PromptGraph
+    { pgNodes = sortOn gnId (brNodes built),
+      pgRoots = dedupe (brRoots built),
+      pgParallelHints = brHints built
+    }
+  where
+    -- The final counters are not needed once the traversal is done.
+    (_, built) = buildExpr initialBuildContext expr
+
+-- | Recursively lower a 'GraphExpr' into graph nodes, threading the id
+-- counters through the traversal.
+--
+-- Returns the updated 'BuildContext' together with the 'BuildResult' of the
+-- sub-expression. Node ids are allocated in traversal order: left operand,
+-- then right operand, then the apply node that joins them.
+buildExpr :: BuildContext -> GraphExpr a -> (BuildContext, BuildResult)
+buildExpr ctx = \case
+  -- A pure value contributes no nodes, roots, or forks.
+  GraphPure _ -> (ctx, emptyBuildResult)
+  -- An inference leaf becomes a single dependency-free node that is its own
+  -- root.
+  GraphInfer spec ->
+    let (ctx', nodeId) = freshNodeId ctx
+        node = GraphNode nodeId (GraphNodeInfer spec) []
+     in (ctx', emptyBuildResult {brNodes = [node], brRoots = [nodeId]})
+  -- A fork adds no node of its own; it records the roots of the wrapped
+  -- expression as pending members of a future parallel group.
+  GraphFork expr ->
+    let (ctx', result) = buildExpr ctx expr
+        pendingForks = dedupe (brRoots result <> brPendingForks result)
+     in (ctx', result {brPendingForks = pendingForks})
+  GraphAp left right ->
+    let (ctx1, leftResult) = buildExpr ctx left
+        (ctx2, rightResult) = buildExpr ctx1 right
+        -- The apply node depends on the roots of both operands.
+        deps = dedupe (brRoots leftResult <> brRoots rightResult)
+        mergedNodes = brNodes leftResult <> brNodes rightResult
+        pendingForks = dedupe (brPendingForks leftResult <> brPendingForks rightResult)
+        mergedHints = brHints leftResult <> brHints rightResult
+     in -- Neither operand produced a root (e.g. a pure function applied to
+        -- a pure value): create no apply node and pass the accumulated
+        -- state through unchanged.
+        if null deps
+          then
+            ( ctx2,
+              BuildResult
+                { brNodes = mergedNodes,
+                  brRoots = [],
+                  brPendingForks = pendingForks,
+                  brHints = mergedHints
+                }
+            )
+          else
+            let (ctx3, applyNodeId) = freshNodeId ctx2
+                applyNode = GraphNode applyNodeId GraphNodeApply deps
+                -- Two or more pending forked roots meeting here form a
+                -- parallel group joined at this apply node.
+                hasParallelGroup = length pendingForks >= 2
+                (ctx4, newHints, remainingForks) =
+                  if hasParallelGroup
+                    then
+                      let (ctx', groupId) = freshGroupId ctx3
+                          hint =
+                            ParallelHint
+                              { phGroupId = groupId,
+                                phMembers = pendingForks,
+                                phJoinNode = applyNodeId
+                              }
+                       in -- Emitting the hint consumes the pending forks, so
+                          -- an enclosing apply starts with an empty set.
+                          (ctx', [hint], [])
+                    else -- Fewer than two forks so far: keep them pending
+                         -- for an enclosing apply.
+                         (ctx3, [], pendingForks)
+             in ( ctx4,
+                  BuildResult
+                    { brNodes = mergedNodes <> [applyNode],
+                      brRoots = [applyNodeId],
+                      brPendingForks = remainingForks,
+                      brHints = mergedHints <> newHints
+                    }
+                )
+
+-- | Compile phase: optimize a built graph and derive its execution schedule.
+--
+-- Dead branches are removed first, then adjacent inferences are fused, and
+-- the schedule is computed from the resulting graph.
+compileGraph :: PromptGraph -> CompiledGraph
+compileGraph graph =
+  CompiledGraph
+    { cgGraph = optimized,
+      cgExecutionPlan = scheduleForParallelism optimized,
+      cgPasses =
+        [ PassEliminateDeadBranches,
+          PassFuseAdjacentInferences,
+          PassScheduleParallelism
+        ]
+    }
+  where
+    optimized = fuseAdjacentInferences (eliminateDeadBranches graph)
+
+-- | Run the build phase and then the compile phase on an expression.
+buildAndCompile :: GraphExpr a -> CompiledGraph
+buildAndCompile expr = compileGraph (buildGraph expr)
+
+-- | Optimization pass: drop every node that cannot be reached from the
+-- graph roots.
+--
+-- Parallel hints are trimmed to their surviving members; a hint is dropped
+-- when its join node is gone or fewer than two members survive.
+eliminateDeadBranches :: PromptGraph -> PromptGraph
+eliminateDeadBranches graph =
+  graph
+    { pgNodes = [node | node <- pgNodes graph, alive (gnId node)],
+      pgParallelHints = mapMaybe pruneHint (pgParallelHints graph)
+    }
+  where
+    liveIds = collectReachable (pgRoots graph) (pgNodes graph)
+    alive nodeId = nodeId `elem` liveIds
+
+    pruneHint hint
+      | alive (phJoinNode hint),
+        (_ : _ : _) <- survivors =
+          Just hint {phMembers = survivors}
+      | otherwise = Nothing
+      where
+        survivors = filter alive (phMembers hint)
+
+-- | Ids reachable from the given roots by following 'gnDependsOn' edges.
+--
+-- An id with no matching node is still reported but contributes no further
+-- dependencies. The result is in reverse order of discovery.
+collectReachable :: [NodeId] -> [GraphNode] -> [NodeId]
+collectReachable roots nodes = visit roots []
+  where
+    dependenciesOf nodeId =
+      case List.find ((== nodeId) <. gnId) nodes of
+        Just node -> gnDependsOn node
+        Nothing -> []
+
+    visit [] visited = visited
+    visit (current : queue) visited
+      | current `elem` visited = visit queue visited
+      | otherwise = visit (dependenciesOf current <> queue) (current : visited)
+
+-- | Optimization pass: fuse linear infer→infer chains.
+--
+-- Applies one fusion at a time and repeats until 'findFuseCandidate' finds
+-- nothing more. Nodes that belong to a detected parallel group are never
+-- fused, because the candidate search skips them.
+fuseAdjacentInferences :: PromptGraph -> PromptGraph
+fuseAdjacentInferences graph =
+  maybe graph (fuseAdjacentInferences <. applyFusion graph) (findFuseCandidate graph)
+
+-- | Find the first fusable parent/child pair, scanning nodes in list order.
+--
+-- A pair qualifies when the parent is an infer node with exactly one
+-- dependent, that dependent is an infer node whose only dependency is the
+-- parent, and neither of the two is a member of a parallel hint. The result
+-- is @(parent id, child id, parent spec, child spec)@.
+findFuseCandidate :: PromptGraph -> Maybe (NodeId, NodeId, InferSpec, InferSpec)
+findFuseCandidate graph =
+  case candidates of
+    (candidate : _) -> Just candidate
+    [] -> Nothing
+  where
+    allNodes = pgNodes graph
+    grouped nodeId = any ((nodeId `elem`) <. phMembers) (pgParallelHints graph)
+
+    -- Lazily enumerated; only the first element is ever demanded.
+    candidates =
+      [ (parentId, gnId child, parentSpec, childSpec)
+        | parent <- allNodes,
+          GraphNodeInfer parentSpec <- [gnKind parent],
+          let parentId = gnId parent,
+          not (grouped parentId),
+          [child] <- [filter ((parentId `elem`) <. gnDependsOn) allNodes],
+          GraphNodeInfer childSpec <- [gnKind child],
+          gnDependsOn child == [parentId],
+          not (grouped (gnId child))
+      ]
+
+-- | Merge the left (parent) infer node into the right (child) one.
+--
+-- The surviving node keeps the right id, takes over the left node's
+-- dependencies, and gets a spec whose name is joined with @+@ and whose
+-- prompt is joined with a blank line. The left node is removed, and every
+-- remaining reference to its id (dependencies, roots, hint members, join
+-- nodes) is redirected to the right id. Hints left with fewer than two
+-- members are dropped.
+applyFusion :: PromptGraph -> (NodeId, NodeId, InferSpec, InferSpec) -> PromptGraph
+applyFusion graph (leftId, rightId, leftSpec, rightSpec) =
+  graph
+    { pgNodes = mapMaybe rewriteNode (pgNodes graph),
+      pgRoots = dedupe (map redirect (pgRoots graph)),
+      pgParallelHints = mapMaybe rewriteHint (pgParallelHints graph)
+    }
+  where
+    redirect nodeId
+      | nodeId == leftId = rightId
+      | otherwise = nodeId
+
+    inheritedDeps =
+      maybe [] gnDependsOn (List.find ((== leftId) <. gnId) (pgNodes graph))
+
+    mergedSpec =
+      InferSpec
+        { isName = isName leftSpec <> "+" <> isName rightSpec,
+          isPrompt = isPrompt leftSpec <> "\n\n" <> isPrompt rightSpec
+        }
+
+    -- The left node is checked first so that it is always removed, even in
+    -- the degenerate case where both ids coincide.
+    rewriteNode node
+      | gnId node == leftId = Nothing
+      | gnId node == rightId =
+          Just node {gnKind = GraphNodeInfer mergedSpec, gnDependsOn = inheritedDeps}
+      | otherwise =
+          Just node {gnDependsOn = dedupe (map redirect (gnDependsOn node))}
+
+    rewriteHint hint =
+      case dedupe (map redirect (phMembers hint)) of
+        members@(_ : _ : _) ->
+          Just hint {phMembers = members, phJoinNode = redirect (phJoinNode hint)}
+        _ -> Nothing
+
+-- | Optimization pass: derive a stage-wise parallel schedule from the
+-- dependency edges.
+--
+-- Each stage holds every not-yet-scheduled node whose dependencies are all in
+-- earlier stages, in ascending id order. If no remaining node is runnable,
+-- the remaining nodes are emitted together as one final stage.
+scheduleForParallelism :: PromptGraph -> ExecutionPlan
+scheduleForParallelism graph =
+  ExecutionPlan (stagesFrom [] (sortOn gnId (pgNodes graph)))
+  where
+    stagesFrom _ [] = []
+    stagesFrom done pending =
+      case filter (all (`elem` done) <. gnDependsOn) pending of
+        -- Nothing can make progress: flush the rest instead of looping.
+        [] -> [map gnId pending]
+        runnable ->
+          let stage = map gnId runnable
+              stillPending = filter ((`notElem` stage) <. gnId) pending
+           in stage : stagesFrom (done <> stage) stillPending
+
+-- | Remove duplicates, keeping the first occurrence of each element and the
+-- original relative order.
+dedupe :: (Eq a) => [a] -> [a]
+dedupe = go []
+  where
+    go _ [] = []
+    go seen (item : rest)
+      | item `elem` seen = go seen rest
+      | otherwise = item : go (item : seen) rest
+
 -- * Tests
 
 main :: IO ()
@@ -552,5 +959,57 @@ test =
         let s = defaultContextStrategy
         (csTemporalWindow s > 0) Test.@=? True
         (csSemanticThreshold s >= 0 && csSemanticThreshold s <= 1) Test.@=? True
-        (csRecencyDecay s > 0 && csRecencyDecay s <= 1) Test.@=? True
+        (csRecencyDecay s > 0 && csRecencyDecay s <= 1) Test.@=? True,
+      Test.unit "Applicative + fork detects a parallel fanout" <| do
+        let x = fork (graphInfer "x" "analyze X")
+            y = fork (graphInfer "y" "analyze Y")
+            graph = buildGraph ((,) <$> x <*> y)
+        length (pgParallelHints graph) Test.@=? 1
+        case pgParallelHints graph of
+          [hint] -> do
+            length (phMembers hint) Test.@=? 2
+            (phJoinNode hint `elem` pgRoots graph) Test.@=? True
+          _ -> Test.assertFailure "Expected exactly one parallel hint",
+      Test.unit "eliminateDeadBranches prunes unreachable nodes" <| do
+        let reachable = GraphNode (NodeId 1) (GraphNodeInfer (InferSpec "a" "prompt a")) []
+            dead = GraphNode (NodeId 2) (GraphNodeInfer (InferSpec "dead" "prompt dead")) []
+            graph =
+              PromptGraph
+                { pgNodes = [reachable, dead],
+                  pgRoots = [NodeId 1],
+                  pgParallelHints = []
+                }
+            pruned = eliminateDeadBranches graph
+        map gnId (pgNodes pruned) Test.@=? [NodeId 1],
+      Test.unit "fuseAdjacentInferences merges linear infer chain" <| do
+        let nodeA = GraphNode (NodeId 1) (GraphNodeInfer (InferSpec "a" "prompt a")) []
+            nodeB = GraphNode (NodeId 2) (GraphNodeInfer (InferSpec "b" "prompt b")) [NodeId 1]
+            graph =
+              PromptGraph
+                { pgNodes = [nodeA, nodeB],
+                  pgRoots = [NodeId 2],
+                  pgParallelHints = []
+                }
+            fused = fuseAdjacentInferences graph
+        length (pgNodes fused) Test.@=? 1
+        case pgNodes fused of
+          [node] -> do
+            gnId node Test.@=? NodeId 2
+            gnDependsOn node Test.@=? []
+            case gnKind node of
+              GraphNodeInfer spec -> do
+                isName spec Test.@=? "a+b"
+                Text.isInfixOf "prompt a" (isPrompt spec) Test.@=? True
+                Text.isInfixOf "prompt b" (isPrompt spec) Test.@=? True
+              _ -> Test.assertFailure "Expected fused infer node"
+          _ -> Test.assertFailure "Expected one fused node",
+      Test.unit "compileGraph produces stage schedule with parallel first stage" <| do
+        let x = fork (graphInfer "x" "analyze X")
+            y = fork (graphInfer "y" "analyze Y")
+            compiled = buildAndCompile ((,) <$> x <*> y)
+            stages = epStages (cgExecutionPlan compiled)
+        length stages Test.@=? 3
+        case stages of
+          (firstStage : _) -> length firstStage Test.@=? 2
+          _ -> Test.assertFailure "Expected at least one stage"
     ]