{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}

-- | Partition refinement that records its splits in a splitting tree. Every
-- inner node of the tree stores a word that separates its children.
module SplittingTree where

import Control.Monad.Trans.Class
import Control.Monad.Trans.State
import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.List (sortOn)
import Data.Map.Strict qualified as Map
import Data.Sequence qualified as Seq

-- | Identifier of a block of a 'Partition'.
newtype Block = Block Int
  deriving (Eq, Ord, Read, Show, Enum)

-- | A partition is represented by a map @s -> Block@. Two elements mapped to
-- the same block are equivalent. Note that a permutation of the blocks does
-- not change the partition, but it does change the underlying
-- representation. (That is why there is no 'Eq' instance yet.)
newtype Partition s = Partition {getPartition :: Map.Map s Block}
  deriving (Read, Show)

-- | Determines whether two elements are equivalent in the partition.
-- Both elements must be present; otherwise 'Map.!' calls 'error'.
sameBlock :: Ord s => Partition s -> s -> s -> Bool
sameBlock (Partition m) s t = m Map.! s == m Map.! t

-- | Identifier of an inner node of the splitting tree. The information
-- associated with a node is kept in 'SplittingTree'.
newtype InnerNode = Node Int
  deriving (Eq, Ord, Read, Show, Enum)

-- | In the splitting tree we record the splits made during partition
-- refinement. The leaves correspond to the blocks of the partition; all other
-- nodes are inner nodes.
--
-- The tree is given by its parent relation, so it is not inductively
-- defined. (Partition refinement does not need that.)
data SplittingTree s i o = SplittingTree
  { label :: Map.Map InnerNode [i]
  -- ^ The separating word of each inner node. We use a lazy list for this,
  -- because sharing is important here.
  , innerParent :: Map.Map InnerNode (InnerNode, o)
  -- ^ Parent of each inner node, with the output labelling the edge. The
  -- node without a parent is the root.
  , blockParent :: Map.Map Block (InnerNode, o)
  -- ^ Parent of each leaf (block), with the output labelling the edge.
  , size :: Map.Map Block Int
  -- ^ Size of each block of the partition.
  }
  deriving (Show)

-- | The children of a freshly split node, queued so that they can refine
-- other blocks.
--
-- TODO: add the size, and perhaps some info about the missing sub-block.
data Splitter s i o = Splitter
  { split :: Map.Map o [s]
  -- ^ The states of each listed child, keyed by the output labelling it.
  , leftOut :: o
  -- ^ Output of the child whose states are not listed in 'split'.
  , witness :: [i]
  -- ^ The separating word of the split node (also stored as its 'label').
  }
  deriving (Show)

-- | The state threaded through partition refinement.
data PRState s i o = PRState
  { partition :: Partition s
  , nextBlockId :: Block
  -- ^ Next fresh block identifier.
  , splittingTree :: SplittingTree s i o
  , nextNodeId :: InnerNode
  -- ^ Next fresh inner-node identifier.
  }
  deriving (Show)

-- | Moves element @s@ into block @b@.
updatePartition :: (Monad m, Ord s) => s -> Block -> StateT (PRState s i o) m ()
updatePartition s b = modify foo
 where
  foo prs = prs{partition = coerce (Map.insert s b) (partition prs)}

-- | Records @n@ as the size of block @b@ and returns @n@.
updateSize :: Monad m => Block -> Int -> StateT (PRState s i o) m Int
updateSize b n = do
  modify (\prs -> let tree = splittingTree prs in prs{splittingTree = tree{size = Map.insert b n (size tree)}})
  return n

-- | Returns a fresh block identifier.
genNextBlockId :: Monad m => StateT (PRState s i o) m Block
genNextBlockId = do
  idx <- gets nextBlockId
  modify (\prs -> prs{nextBlockId = succ (nextBlockId prs)})
  return idx

-- | Makes @target@ the parent of a block ('Left') or of an inner node
-- ('Right'); @output@ labels the edge to the parent.
updateParent :: Monad m => Either Block InnerNode -> InnerNode -> o -> StateT (PRState s i o) m ()
updateParent child target output = modify (\prs -> prs{splittingTree = setParent (splittingTree prs)})
 where
  setParent tree = case child of
    Left block -> tree{blockParent = Map.insert block (target, output) (blockParent tree)}
    Right node -> tree{innerParent = Map.insert node (target, output) (innerParent tree)}

-- | Sets the separating word of an inner node.
updateLabel :: Monad m => InnerNode -> [i] -> StateT (PRState s i o) m ()
updateLabel node word = modify (\prs -> let tree = splittingTree prs in prs{splittingTree = tree{label = Map.insert node word (label tree)}})

-- | Returns a fresh inner-node identifier.
genNextNodeId :: Monad m => StateT (PRState s i o) m InnerNode
genNextNodeId = do
  idx <- gets nextNodeId
  modify (\prs -> prs{nextNodeId = succ (nextNodeId prs)})
  return idx

-- | Refines the partition with one splitter and one input symbol @action@.
--
-- @rev s@ is expected to list the predecessors of @s@ under @action@. Each
-- block containing such predecessors is split according to which listed
-- child of the splitter its states lead to; its remaining states stay in the
-- old block. Every proper split creates a new inner node labelled
-- @action : witness@. Returns one new splitter per block that was split.
refineWithSplitter ::
  (Monad m, Ord o, Ord s) =>
  i ->
  (s -> [s]) ->
  Splitter s i o ->
  StateT (PRState s i o) m [Splitter s i o]
refineWithSplitter action rev Splitter{..} = do
  currentPartition <- getPartition <$> gets partition
  currentSplittingTree <- gets splittingTree
  let
    -- For each child in the splitter, we get its predecessors.
    predecessors = Map.map (concatMap rev) split
    -- For each block that we found, we create temporary children. Only when
    -- the splitter actually splits the block do we change the partition. We
    -- work on lists here, because the data is re-ordered anyway.
    tempChildsList =
      [ (b, [(o, [s])])
      | (o, ls) <- Map.toList predecessors
      , s <- ls
      , let b = currentPartition Map.! s
      ]
    -- Group on the block, then on the output.
    tempChildsMaps = Map.map (Map.fromListWith (++)) . Map.fromListWith (++) $ tempChildsList
    -- Now we check the 3-way split:
    -- - Blocks without any predecessor state do not appear at all.
    -- - Blocks whose states all move to a single child are not properly
    --   split and are dropped.
    -- - Blocks with several outputs, or with states that moved and states
    --   that did not, are properly split.
    properSplit b os = case Map.elems os of
      [] -> error "Should not happen"
      [onlyChild] -> length onlyChild /= size currentSplittingTree Map.! b
      _ -> True
    -- We keep the proper splits only.
    tempChildsMaps2 = Map.filterWithKey properSplit tempChildsMaps
    -- Moves the given states to a fresh block below inner node nNIdx.
    updateSubBlock nNIdx o ls = do
      nBIdx <- genNextBlockId
      mapM_ (`updatePartition` nBIdx) ls
      updateParent (Left nBIdx) nNIdx o
      n <- updateSize nBIdx (length ls)
      return (n, ls)
    -- Splits block b into the given children below a new inner node.
    updateBlock b children = do
      nNIdx <- genNextNodeId
      sizesAndSubblocks <- Map.traverseWithKey (updateSubBlock nNIdx) children
      -- There may be states remaining in b (because we process the smaller
      -- halves), so we update its size.
      -- TODO: Do we need to remove b if it is empty?
      let oldSize = size currentSplittingTree Map.! b
          missingSize = oldSize - sum (Map.map fst sizesAndSubblocks)
      _ <- updateSize b missingSize
      -- b becomes a child of nNIdx, and nNIdx takes b's old place in the
      -- tree. The witness is extended by prepending the action.
      let (currentParent, op) = blockParent currentSplittingTree Map.! b
          newWitness = action : witness
      updateParent (Right nNIdx) currentParent op
      updateParent (Left b) nNIdx leftOut
      updateLabel nNIdx newWitness
      -- Lastly, we make the new splitter. We cannot cheaply drop the largest
      -- sub-block in general, as we would have to look up its states, so we
      -- only do so when b became empty. The largest sub-block is found in a
      -- single pass: keys are visited in ascending order and only a strictly
      -- larger size replaces the pick, so ties go to the smallest output.
      let pickLargest best o (n, _) = case best of
            Just (_, m) | m >= n -> best
            _ -> Just (o, n)
          largest = Map.foldlWithKey pickLargest Nothing sizesAndSubblocks
      return $ case (missingSize, largest) of
        (0, Just (o1, _)) ->
          Splitter
            { split = Map.map snd (Map.delete o1 sizesAndSubblocks)
            , leftOut = o1
            , witness = newWitness
            }
        _ ->
          Splitter
            { split = Map.map snd sizesAndSubblocks
            , leftOut = leftOut
            , witness = newWitness
            }
  Map.elems <$> Map.traverseWithKey updateBlock tempChildsMaps2

-- | Splits every block on the output function @out@ of input @action@.
--
-- In each split block, the largest group of states (by output) keeps the old
-- block and the other groups get fresh blocks. Every split creates a new
-- inner node labelled @[action]@. Returns one splitter per block that was
-- split, leaving out the largest group.
refineWithOutput ::
  (Monad m, Ord o, Ord s) =>
  i ->
  (s -> o) ->
  StateT (PRState s i o) m [Splitter s i o]
refineWithOutput action out = do
  currentPartition <- getPartition <$> gets partition
  currentSplittingTree <- gets splittingTree
  let
    -- Compute all outputs, tagged with the existing blocks.
    tempChildsList = [(b, [(o, [s])]) | (s, b) <- Map.toList currentPartition, let o = out s]
    -- Then group them on blocks and outputs.
    tempChildsMaps = Map.map (Map.fromListWith (++)) . Map.fromListWith (++) $ tempChildsList
    -- Only consider actual splits.
    tempChildsMaps2 = Map.filter (\children -> Map.size children >= 2) tempChildsMaps
    -- Moves the given states to a fresh block below inner node nNIdx.
    updateStates nNIdx o ss = do
      nBIdx <- genNextBlockId
      mapM_ (`updatePartition` nBIdx) ss
      updateParent (Left nBIdx) nNIdx o
      _ <- updateSize nBIdx (length ss)
      return ()
    -- Splits block b by output. The sort order also fixes the order in which
    -- fresh block identifiers are handed out.
    updateBlock b children =
      case sortOn (\(_, ss) -> negate (length ss)) (Map.toList children) of
        [] -> error "refineWithOutput: impossible, a split block has at least two children"
        (o1, biggest) : smaller -> do
          let outputWitness = [action]
          -- The biggest part keeps block b; the remaining parts get new blocks.
          nNIdx <- genNextNodeId
          traverse_ (uncurry (updateStates nNIdx)) smaller
          -- During the very first split, b has no parent, so nNIdx becomes
          -- the root and there is nothing to update.
          case Map.lookup b (blockParent currentSplittingTree) of
            Nothing -> return ()
            Just (currentParent, op) -> updateParent (Right nNIdx) currentParent op
          updateLabel nNIdx outputWitness
          -- Remember to update the tree structure for the biggest (kept) block.
          updateParent (Left b) nNIdx o1
          _ <- updateSize b (length biggest)
          -- Return the splitter; note that the biggest part is already left out.
          return
            Splitter
              { split = Map.fromList smaller
              , leftOut = o1
              , witness = outputWitness
              }
  Map.elems <$> Map.traverseWithKey updateBlock tempChildsMaps2

-- | Initial state: every given element in the single block @Block 0@, an
-- empty splitting tree, and fresh identifiers starting at @Block 1@ and
-- @Node 0@.
initialPRState :: Ord s => [s] -> PRState s i o
initialPRState ls =
  PRState
    { partition = Partition initial
    , nextBlockId = Block 1
    , splittingTree =
        SplittingTree
          { label = Map.empty
          , innerParent = Map.empty
          , blockParent = Map.empty
          , -- Taken from the map, so duplicates in ls are not counted twice.
            size = Map.singleton (Block 0) (Map.size initial)
          }
    , nextNodeId = Node 0
    }
 where
  initial = Map.fromList [(s, Block 0) | s <- ls]

-- | Runs 'refineWithOutput' for every output function, in order, and
-- concatenates the resulting splitters.
refineWithAllOutputs :: (Monad m, Ord o, Ord s) => [(i, s -> o)] -> StateT (PRState s i o) m [Splitter s i o]
refineWithAllOutputs ls = concat <$> traverse (uncurry refineWithOutput) ls

-- | Runs 'refineWithSplitter' with the same splitter for every input symbol,
-- in order, and concatenates the resulting splitters.
refineWithSplitterAllInputs :: (Monad m, Ord o, Ord s) => [(i, s -> [s])] -> Splitter s i o -> StateT (PRState s i o) m [Splitter s i o]
refineWithSplitterAllInputs ls splitter = concat <$> traverse (\(i, rev) -> refineWithSplitter i rev splitter) ls

-- | Runs partition refinement until no splitters remain.
--
-- First splits on all output functions, then processes splitters in FIFO
-- order. Newly produced splitters are appended to the back of the queue.
-- @ping@ is called with each splitter's witness just before that splitter
-- is processed.
refine :: (Monad m, Ord o, Ord s) => ([i] -> m ()) -> [(i, s -> o)] -> [(i, s -> [s])] -> StateT (PRState s i o) m ()
refine ping outputs transitionsReverse = do
  initialQueue <- refineWithAllOutputs outputs
  let
    -- A Seq gives cheap appends at the back; a list would copy the whole
    -- remaining queue on every step.
    loop queue = case Seq.viewl queue of
      Seq.EmptyL -> return ()
      splitter Seq.:< splitters -> do
        lift (ping (witness splitter))
        newQueue <- refineWithSplitterAllInputs transitionsReverse splitter
        loop (splitters <> Seq.fromList newQueue)
  loop (Seq.fromList initialQueue)