1
Fork 0
mirror of https://git.cs.ou.nl/joshua.moerman/mealy-decompose.git synced 2025-04-30 02:07:44 +02:00
mealy-decompose/src/SplittingTree.hs
2024-06-14 14:43:32 +02:00

279 lines
11 KiB
Haskell

{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
module SplittingTree where
import Control.Monad.Trans.Class
import Control.Monad.Trans.State
import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.List (sortOn)
import Data.Map.Strict qualified as Map
-- | Identifier of a block of the partition. Fresh identifiers are produced
-- by taking the 'succ' of the last one handed out.
newtype Block = Block Int
  deriving (Eq, Ord, Read, Show, Enum)
-- | A partition, stored as a map from every element to its 'Block'.
-- Two elements are equivalent exactly when they map to the same block.
--
-- Renumbering the blocks gives the same partition but a different
-- representation. That is why no 'Eq' instance is provided (yet).
newtype Partition s = Partition
  { getPartition :: Map.Map s Block
  }
  deriving (Read, Show)
-- | Whether @s@ and @t@ lie in the same block of the partition.
-- Both elements must be present in the partition; 'Map.!' errors otherwise.
sameBlock :: Ord s => Partition s -> s -> s -> Bool
sameBlock (Partition blocks) s t = blockOf s == blockOf t
  where
    blockOf = (blocks Map.!)
-- | Identifier of an inner node of the splitting tree.
--
-- The splitting tree records the splits made during partition refinement.
-- Its leaves are the blocks of the partition; every other node is an inner
-- node. An inner node is only a number: the information attached to it is
-- kept in 'SplittingTree', keyed by this identifier.
newtype InnerNode = Node Int
  deriving (Eq, Ord, Read, Show, Enum)
-- | The splitting tree, given by its parent relation instead of as an
-- inductive structure; an inductive definition is not needed for partition
-- refinement. Each parent edge carries a value of the output type @o@.
-- (The parameter @s@ is not used by any field.)
data SplittingTree s i o = SplittingTree
  { label :: Map.Map InnerNode [i]
  -- ^ Separating word of each inner node. This is a lazy list on purpose:
  -- sharing between the words is important here.
  , innerParent :: Map.Map InnerNode (InnerNode, o)
  -- ^ Parent edge of each inner node. The inner node without a parent is
  -- the root.
  , blockParent :: Map.Map Block (InnerNode, o)
  -- ^ Parent edge of each leaf, i.e. of each block.
  , size :: Map.Map Block Int
  -- ^ Number of elements in each block of the partition.
  }
  deriving (Show)
-- | A pending refinement step, describing an inner node of the splitting
-- tree through the states below (most of) its children.
--
-- TODO: add the size, and perhaps some information about the sub-block
-- that is left out.
data Splitter s i o = Splitter
  { split :: Map.Map o [s]
  -- ^ The states of each listed sub-block, keyed by the output of its branch.
  , leftOut :: o
  -- ^ Output of the branch whose sub-block is not listed in 'split'.
  , witness :: [i]
  -- ^ Separating word of the node.
  }
  deriving (Show)
-- | State threaded through partition refinement.
data PRState s i o = PRState
  { partition :: Partition s
  -- ^ The current partition.
  , nextBlockId :: Block
  -- ^ Next unused block identifier.
  , splittingTree :: SplittingTree s i o
  -- ^ Record of the splits made so far.
  , nextNodeId :: InnerNode
  -- ^ Next unused inner-node identifier.
  }
  deriving (Show)
-- | Puts element @s@ into block @b@, inserting or overwriting its entry in
-- the partition. Block sizes in the splitting tree are not touched.
updatePartition :: (Monad m, Ord s) => s -> Block -> StateT (PRState s i o) m ()
updatePartition s b = modify $ \prs ->
  let oldMap = coerce (partition prs)
  in prs{partition = Partition (Map.insert s b oldMap)}
-- | Stores @n@ as the size of block @b@ in the splitting tree. It returns
-- @n@ again, so callers can keep using the size they just stored.
updateSize :: Monad m => Block -> Int -> StateT (PRState s i o) m Int
updateSize b n = do
  modify $ \prs ->
    let tree = splittingTree prs
    in prs{splittingTree = tree{size = Map.insert b n (size tree)}}
  return n
-- | Returns the next unused block identifier and advances the counter.
genNextBlockId :: Monad m => StateT (PRState s i o) m Block
genNextBlockId = state $ \prs ->
  let fresh = nextBlockId prs
  in (fresh, prs{nextBlockId = succ fresh})
-- | Makes @child@ a child of @target@, reached via @output@. The child is a
-- leaf block ('Left') or an inner node ('Right'). Any previous parent edge
-- of the child is replaced.
updateParent :: Monad m => Either Block InnerNode -> InnerNode -> o -> StateT (PRState s i o) m ()
updateParent child target output = modify $ \prs -> prs{splittingTree = attach (splittingTree prs)}
  where
    edge = (target, output)
    attach tree = case child of
      Left block -> tree{blockParent = Map.insert block edge (blockParent tree)}
      Right node -> tree{innerParent = Map.insert node edge (innerParent tree)}
-- | Sets (or overwrites) the separating word of inner node @node@.
updateLabel :: Monad m => InnerNode -> [i] -> StateT (PRState s i o) m ()
updateLabel node word = modify setLabel
  where
    setLabel prs =
      let tree = splittingTree prs
      in prs{splittingTree = tree{label = Map.insert node word (label tree)}}
-- | Returns the next unused inner-node identifier and advances the counter.
genNextNodeId :: Monad m => StateT (PRState s i o) m InnerNode
genNextNodeId = state $ \prs ->
  let fresh = nextNodeId prs
  in (fresh, prs{nextNodeId = succ fresh})
-- | Refines the current partition with a splitter, looking one step back
-- along @action@.
--
-- @rev s@ supplies the predecessors of @s@ (presumably the states that reach
-- @s@ under @action@; this is not checked here). The predecessors of the
-- states in each sub-block of 'split' are collected. They are grouped first
-- by their current block, then by the splitter output they lead to. A block
-- is split only when this grouping is a proper split of it (see
-- @properSplit@).
--
-- For every such block a new inner node labelled @action : witness@ is
-- created:
--
-- * every group moves into a fresh block below the new node;
-- * the states left in the old block hang below it with label 'leftOut';
-- * the new node takes the old block's place in the tree. If the old block
--   has no parent (only the initial, never-split block lacks one), the new
--   node becomes the root.
--
-- Returns one new splitter per block that was split.
refineWithSplitter :: (Monad m, Ord o, Ord s) => i -> (s -> [s]) -> Splitter s i o -> StateT (PRState s i o) m [Splitter s i o]
refineWithSplitter action rev Splitter{..} = do
  currentPartition <- getPartition <$> gets partition
  currentSplittingTree <- gets splittingTree
  let
    -- For each sub-block of the splitter, collect the predecessors of its
    -- states.
    predecessors = Map.map (concatMap rev) split
    -- Tag every predecessor with its current block. The partition is not
    -- changed yet: that only happens for blocks that are really split. A
    -- list suffices here, because the data is regrouped anyway.
    tempChildsList = [(b, [(o, [s])]) | (o, ls) <- Map.toList predecessors, s <- ls, let b = currentPartition Map.! s]
    -- Group by block, and within a block by output.
    tempChildsMaps = Map.map (Map.fromListWith (++)) . Map.fromListWith (++) $ tempChildsList
    -- Three cases:
    -- * Blocks none of whose states occurred do not appear at all.
    -- * Blocks whose states all moved into a single group are not properly
    --   split, and are dropped here.
    -- * Blocks with several groups, or with some states that moved and some
    --   that did not, are properly split.
    properSplit b os = case Map.elems os of
      [] -> error "Should not happen"
      [onlyGroup] -> length onlyGroup /= size currentSplittingTree Map.! b
      _ -> True
    -- Only the proper splits are kept.
    tempChildsMaps2 = Map.filterWithKey properSplit tempChildsMaps
    -- Moves the states @ls@ into a fresh block below @nNIdx@ with label @o@.
    -- Returns the new block's size together with its states.
    updateSubBlock nNIdx o ls = do
      nBIdx <- genNextBlockId
      mapM_ (`updatePartition` nBIdx) ls
      updateParent (Left nBIdx) nNIdx o
      n <- updateSize nBIdx (length ls)
      return (n, ls)
    -- Splits block @b@ into the given groups under a new inner node, and
    -- returns the corresponding splitter.
    updateBlock b children = do
      nNIdx <- genNextNodeId
      sizesAndSubblocks <- Map.traverseWithKey (updateSubBlock nNIdx) children
      -- States may remain in b, because only the listed (smaller) sub-blocks
      -- are processed. So its size has to be updated.
      -- TODO: Do we need to remove b if it is empty?
      let oldSize = size currentSplittingTree Map.! b
          missingSize = oldSize - sum (Map.map fst sizesAndSubblocks)
          newWitness = action : witness
      _ <- updateSize b missingSize
      -- nNIdx takes b's place in the tree, and b (with whatever is left of
      -- it) becomes the child of nNIdx labelled leftOut. A block without a
      -- parent is the initial, never-split block: then nNIdx becomes the root.
      case Map.lookup b (blockParent currentSplittingTree) of
        Nothing -> return ()
        Just (currentParent, op) -> updateParent (Right nNIdx) currentParent op
      updateParent (Left b) nNIdx leftOut
      updateLabel nNIdx newWitness
      -- Build the new splitter. Leaving out the largest sub-block in general
      -- would require looking up the states remaining in b. So this is only
      -- done when b became empty: then the largest new sub-block is left out.
      if missingSize == 0
        then
          -- TODO: sorting is more than needed; only the biggest matters.
          let ((o1, _) : smallerBlocks) = sortOn (\(_, (n, _)) -> negate n) (Map.toList sizesAndSubblocks)
          in return
               Splitter
                 { split = Map.fromList [(o, lss) | (o, (_, lss)) <- smallerBlocks]
                 , leftOut = o1
                 , witness = newWitness
                 }
        else
          return
            Splitter
              { split = Map.map snd sizesAndSubblocks
              , leftOut = leftOut
              , witness = newWitness
              }
  Map.elems <$> Map.traverseWithKey updateBlock tempChildsMaps2
-- | Splits every block of the current partition by the output of @action@.
--
-- States are grouped by their current block and then by @out s@. Blocks
-- whose states all give the same output are left alone.
--
-- For each block that does split, a fresh inner node labelled @[action]@
-- is created:
--
-- * the largest output group stays in the old block (on equal sizes, the
--   group with the smallest output wins);
-- * every other group moves into a new block below the node;
-- * if the old block had no parent (the very first split), the new node
--   becomes the root of the tree.
--
-- Returns one splitter per split block. The group that stayed behind is
-- recorded as 'leftOut' and is not listed in 'split'.
refineWithOutput :: (Monad m, Ord o, Ord s) => i -> (s -> o) -> StateT (PRState s i o) m [Splitter s i o]
refineWithOutput action out = do
  blockOf <- getPartition <$> gets partition
  tree <- gets splittingTree
  let
    -- block -> output -> states. Within each group the states keep the
    -- order in which Map.toList visits them.
    grouped =
      Map.map (Map.fromListWith (++))
        . Map.fromListWith (++)
        $ [(blk, [(out st, [st])]) | (st, blk) <- Map.toList blockOf]
    -- A block with a single output group is not split.
    splitting = Map.filter ((>= 2) . Map.size) grouped

    -- Moves @members@ into a fresh block hanging below @node@ with label @o@.
    moveToNewBlock node o members = do
      fresh <- genNextBlockId
      mapM_ (`updatePartition` fresh) members
      updateParent (Left fresh) node o
      updateSize fresh (length members) >> return ()

    splitBlock blk groups = do
      let ((keptOut, keptStates) : movedGroups) = sortOn (negate . length . snd) (Map.toList groups)
          singleWitness = [action]
      node <- genNextNodeId
      -- Only the smaller groups are moved; the biggest stays in blk.
      traverse_ (\(o, members) -> moveToNewBlock node o members) movedGroups
      -- Hang the new node where blk used to be. On the very first split blk
      -- has no parent, and the new node becomes the root.
      maybe
        (return ())
        (\(parentNode, parentOut) -> updateParent (Right node) parentNode parentOut)
        (Map.lookup blk (blockParent tree))
      updateLabel node singleWitness
      -- The kept (biggest) group is now the child of node with its output.
      updateParent (Left blk) node keptOut
      _ <- updateSize blk (length keptStates)
      -- The kept group was skipped above, so it becomes leftOut.
      return
        Splitter
          { split = Map.fromList movedGroups
          , leftOut = keptOut
          , witness = singleWitness
          }
  Map.elems <$> Map.traverseWithKey splitBlock splitting
-- | Starting state for refinement:
--
-- * every element of @ls@ is in @Block 0@, whose recorded size is
--   @length ls@;
-- * the splitting tree has no labels or parent edges yet;
-- * fresh identifiers continue from @Block 1@ and @Node 0@.
initialPRState :: Ord s => [s] -> PRState s i o
initialPRState ls =
  PRState
    { partition = Partition (Map.fromList (map (\s -> (s, firstBlock)) ls))
    , nextBlockId = Block 1
    , splittingTree = emptyTree
    , nextNodeId = Node 0
    }
  where
    firstBlock = Block 0
    emptyTree =
      SplittingTree
        { label = Map.empty
        , innerParent = Map.empty
        , blockParent = Map.empty
        , size = Map.singleton firstBlock (length ls)
        }
-- | Runs 'refineWithOutput' for each (action, output function) pair in turn.
-- It returns all resulting splitters, in the order they were produced.
refineWithAllOutputs :: (Monad m, Ord o, Ord s) => [(i, s -> o)] -> StateT (PRState s i o) m [Splitter s i o]
refineWithAllOutputs ls = fmap concat (mapM (\(i, out) -> refineWithOutput i out) ls)
-- | Applies the same splitter once for every (action, reverse transition)
-- pair in turn. It returns all resulting splitters, in the order they were
-- produced.
refineWithSplitterAllInputs :: (Monad m, Ord o, Ord s) => [(i, s -> [s])] -> Splitter s i o -> StateT (PRState s i o) m [Splitter s i o]
refineWithSplitterAllInputs ls splitter = fmap concat (traverse (uncurry refineAlong) ls)
  where
    refineAlong i rev = refineWithSplitter i rev splitter
-- | Runs partition refinement until no splitters are left.
--
-- 1. Every block is split on the output of each pair in @outputs@. The
--    resulting splitters seed a first-in-first-out work queue.
-- 2. Splitters are then taken from the front of the queue one by one. Each
--    is reported to @ping@ (which receives its witness) and used to refine
--    the partition with every pair in @transitionsReverse@.
-- 3. Newly produced splitters go to the back of the queue, and the loop
--    ends when the queue is empty.
refine :: (Monad m, Ord o, Ord s) => ([i] -> m ()) -> [(i, s -> o)] -> [(i, s -> [s])] -> StateT (PRState s i o) m ()
refine ping outputs transitionsReverse = do
  initialQueue <- refineWithAllOutputs outputs
  let
    -- The queue is @front ++ reverse back@. A batch of new splitters is
    -- pushed onto @back@ in reverse, and @back@ is only reversed when
    -- @front@ runs dry. This keeps the exact FIFO order at amortised O(1)
    -- per splitter. Appending each batch to a plain list with (<>) instead
    -- builds left-nested appends, so every dequeue pays for all pending
    -- appends, which is quadratic overall.
    loop [] [] = return ()
    loop [] back = loop (reverse back) []
    loop (splitter : front) back = do
      _ <- lift (ping (witness splitter))
      newSplitters <- refineWithSplitterAllInputs transitionsReverse splitter
      loop front (reverse newSplitters ++ back)
  loop initialQueue []