From a20e5e931718b7e6e72138b83bb1b9f96277d3c7 Mon Sep 17 00:00:00 2001 From: Joshua Moerman Date: Fri, 1 Dec 2023 16:59:54 +0100 Subject: [PATCH] Made isRefinementOf much faster --- app/Main.hs | 25 ++++++++++--------------- src/Partition.hs | 37 ++++++++++++++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/app/Main.hs b/app/Main.hs index 00c18ef..151d5b0 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -10,9 +10,7 @@ import Control.Monad.Trans.State.Strict import Control.Monad (forM_, when, forever) import Data.Map.Strict qualified as Map import Data.Maybe (mapMaybe) --- import Data.Semigroup (Arg(..)) --- import Data.Set qualified as Set --- import Data.List.Ordered (nubSort) +import Data.List.Ordered (nubSort) import Data.List (minimumBy) import Data.Function (on) import System.Environment @@ -61,8 +59,8 @@ main = do -- Then compute each projection -- I did some manual preprocessing, these are the only interesting bits - let outs = ["10", "10-O9", "2.2", "3.0", "3.1", "3.10", "3.12", "3.13", "3.14", "3.16", "3.17", "3.18", "3.19", "3.2", "3.20", "3.21", "3.3", "3.4", "3.6", "3.7", "3.8", "3.9", "5.0", "5.1", "5.12", "5.13", "5.17", "5.2", "5.21", "5.23", "5.6", "5.7", "5.8", "5.9", "quiescence"] - -- outs = outputs machine + let -- outs = ["10", "10-O9", "2.2", "3.0", "3.1", "3.10", "3.12", "3.13", "3.14", "3.16", "3.17", "3.18", "3.19", "3.2", "3.20", "3.21", "3.3", "3.4", "3.6", "3.7", "3.8", "3.9", "5.0", "5.1", "5.12", "5.13", "5.17", "5.2", "5.21", "5.23", "5.6", "5.7", "5.8", "5.9", "quiescence"] + outs = outputs machine projections0 = allProjections machine outs projections = zip outs $ fmap refineMealy projections0 @@ -72,6 +70,7 @@ main = do printPartition partition ) + {- let totalSize = sum (fmap (numBlocks . snd) projections) putStrLn $ "total size = " <> show totalSize @@ -93,10 +92,10 @@ main = do ) print "done" + -} {- - -- Check refinement relations for all pairs -- This is a bit messy, it skips machines which are equivalent -- to earlier checked machines, so we thread some state through this @@ -112,20 +111,17 @@ main = do forM_ projections (\(o2, b2) -> do (repr0, _) <- get when (o1 < o2 && o2 `Map.notMember` repr0) $ do - case (isRefinementOf b1 b2, isRefinementOf b2 b1) of - (True, True) -> do + case comparePartitions b1 b2 of + Equivalent -> do (repr, ls) <- get put (Map.insert o2 o1 repr, ls) - (True, False) -> do + Refinement -> do (repr, ls) <- get put (repr, (o1, o2):ls) - (False, True) -> do + Coarsening -> do (repr, ls) <- get put (repr, (o2, o1):ls) - (False, False) -> return () - - -- liftIO $ putStr " vs. " - -- liftIO $ print o2 + Incomparable -> return () ) ) @@ -143,5 +139,4 @@ main = do ) return () - -} diff --git a/src/Partition.hs b/src/Partition.hs index 58ead5d..0982322 100644 --- a/src/Partition.hs +++ b/src/Partition.hs @@ -4,7 +4,7 @@ module Partition ) where import Control.Monad.Trans.State.Strict (runState, get, put) -import Data.Partition (Partition(..), isRefinementOf, numStates) +import Data.Partition (Partition(..), numStates) import Data.Vector qualified as V import Data.Map.Strict qualified as Map import Unsafe.Coerce (unsafeCoerce) @@ -27,3 +27,38 @@ commonRefinement p1 p2 = (vect, (_, nextBlock)) = runState (V.generateM n blockAtIdx) (Map.empty, 0) in Partition { numBlocks = unsafeCoerce nextBlock, stateAssignment = vect } +-- Could be made faster by doing what commonRefinement is doing but +-- stopping early. This is already much faster than what is in +-- the CoPaR library, so I won't bother. +isRefinementOf2 :: Partition -> Partition -> Bool +isRefinementOf2 refined original = + numBlocks refined == numBlocks (commonRefinement refined original) + +-- See comment at isRefinementOf2 +isEquivalent :: Partition -> Partition -> Bool +isEquivalent p1 p2 = + p1 == p2 || (numBlocks p1 == numBlocks p2 && numBlocks p1 == numBlocks (commonRefinement p1 p2)) + +-- Instead of checking whether one partition is a refinement of another AND +-- also checking vice versa. We can check the direction at once, computing the +-- common refinement only once. It saves some time. +data Comparison + = Equivalent + | Refinement + | Coarsening + | Incomparable + deriving (Eq, Ord, Read, Show, Enum, Bounded) + +-- See comment at isRefinementOf2 +comparePartitions :: Partition -> Partition -> Comparison +comparePartitions p1 p2 + | p1 == p2 = Equivalent + | otherwise = let glb = commonRefinement p1 p2 + n1 = numBlocks p1 + n2 = numBlocks p2 + n3 = numBlocks glb + in case (n1 == n3, n2 == n3) of + (True, True) -> Equivalent + (True, False) -> Refinement + (False, True) -> Coarsening + (False, False) -> Incomparable