From db2b00273c234528d97c2d3287a2f26d3fa51087 Mon Sep 17 00:00:00 2001 From: Joshua Moerman Date: Thu, 3 Jan 2019 13:52:35 +0100 Subject: [PATCH] Adds fold functions for EquivariantSet --- src/EquivariantSet.hs | 38 +++++++++++++++++++++++++++++++++----- src/Orbit.hs | 1 + src/Orbit/Class.hs | 4 ++++ 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/src/EquivariantSet.hs b/src/EquivariantSet.hs index 8899aec..bb84155 100644 --- a/src/EquivariantSet.hs +++ b/src/EquivariantSet.hs @@ -13,13 +13,11 @@ import Data.Proxy import Data.Set (Set) import qualified Data.Set as Set -import Data.Semigroup (Semigroup) +import Prelude hiding (map, product) import Orbit import Support --- TODO: think about folds (the monoids should be nominal?) --- TODO: partition / fromList / ... -- Given a nominal type, we can construct equivariant sets. These simply use a -- standard set data structure. This works well because orbits are uniquely @@ -76,9 +74,11 @@ empty = EqSet Set.empty singleOrbit :: Orbit a => a -> EquivariantSet a singleOrbit = EqSet . Set.singleton . toOrbit +-- Insert whole orbit of a insert :: (Orbit a, Ord (Orb a)) => a -> EquivariantSet a -> EquivariantSet a insert a = EqSet . Set.insert (toOrbit a) . unEqSet +-- Deletes whole orbit of a delete :: (Orbit a, Ord (Orb a)) => a -> EquivariantSet a -> EquivariantSet a delete a = EqSet . Set.delete (toOrbit a) . unEqSet @@ -95,11 +95,16 @@ difference a b = EqSet $ Set.difference (unEqSet a) (unEqSet b) intersection :: Ord (Orb a) => EquivariantSet a -> EquivariantSet a -> EquivariantSet a intersection a b = EqSet $ Set.intersection (unEqSet a) (unEqSet b) --- This is the meat of the file! Relies on the ordering of Orbit.product +-- Cartesian product. This is a non trivial thing and relies on the +-- ordering of Orbit.product. product :: forall a b. (Orbit a, Orbit b) => EquivariantSet a -> EquivariantSet b -> EquivariantSet (a, b) product (EqSet sa) (EqSet sb) = EqSet . Set.fromDistinctAscList . concat $ Orbit.product (Proxy @a) (Proxy @b) <$> Set.toAscList sa <*> Set.toAscList sb +-- Cartesian product followed by a function (f should be equivariant) +productWith :: (Orbit a, Orbit b, Orbit c, Ord (Orb c)) => (a -> b -> c) -> EquivariantSet a -> EquivariantSet b -> EquivariantSet c +productWith f as bs = map (uncurry f) $ EquivariantSet.product as bs + -- Filter @@ -107,6 +112,11 @@ product (EqSet sa) (EqSet sb) = EqSet . Set.fromDistinctAscList . concat filter :: Orbit a => (a -> Bool) -> EquivariantSet a -> EquivariantSet a filter f (EqSet s) = EqSet . Set.filter (f . getElementE) $ s +-- f should be equivariant +partition :: Orbit a => (a -> Bool) -> EquivariantSet a -> (EquivariantSet a, EquivariantSet a) +partition f (EqSet s) = both EqSet . Set.partition (f . getElementE) $ s + where both f (a, b) = (f a, f b) + -- Map @@ -115,14 +125,32 @@ filter f (EqSet s) = EqSet . Set.filter (f . getElementE) $ s map :: (Orbit a, Orbit b, Ord (Orb b)) => (a -> b) -> EquivariantSet a -> EquivariantSet b map f = EqSet . Set.map (omap f) . unEqSet --- f should also preserve order on the orbit types! +-- precondition: f quivariant and preserves order on the orbits. -- This means you should know the representation to use it well mapMonotonic :: (Orbit a, Orbit b) => (a -> b) -> EquivariantSet a -> EquivariantSet b mapMonotonic f = EqSet . Set.mapMonotonic (omap f) . unEqSet +-- Folds + +-- I am not sure about the preconditions for folds +foldr :: Orbit a => (a -> b -> b) -> b -> EquivariantSet a -> b +foldr f b = Set.foldr (f . getElementE) b . unEqSet + +foldl :: Orbit a => (b -> a -> b) -> b -> EquivariantSet a -> b +foldl f b = Set.foldl (\b -> f b . getElementE) b . unEqSet + + -- Conversion toList :: Orbit a => EquivariantSet a -> [a] toList = fmap getElementE . Set.toList . unEqSet +fromList :: (Orbit a, Ord (Orb a)) => [a] -> EquivariantSet a +fromList = EqSet . Set.fromList . fmap toOrbit + +toOrbitList :: EquivariantSet a -> [Orb a] +toOrbitList = Set.toList . unEqSet + +fromOrbitList :: Ord (Orb a) => [Orb a] -> EquivariantSet a +fromOrbitList = EqSet . Set.fromList diff --git a/src/Orbit.hs b/src/Orbit.hs index 6c1b092..25963e7 100644 --- a/src/Orbit.hs +++ b/src/Orbit.hs @@ -48,6 +48,7 @@ instance Orbit Support where index _ n = n +-- Some instances we can derive via generics deriving instance (Orbit a, Orbit b) => Orbit (Either a b) deriving instance Orbit () diff --git a/src/Orbit/Class.hs b/src/Orbit/Class.hs index 79deca3..dfc6abf 100644 --- a/src/Orbit/Class.hs +++ b/src/Orbit/Class.hs @@ -34,6 +34,8 @@ class Orbit a where getElement :: Orb a -> Support -> a index :: Proxy a -> Orb a -> Int + -- We provide default implementations for generic types + -- This enables us to derive Orbit instances by the Haskell compiler -- default Orb a :: (Generic a, GOrbit (Rep a)) => * type Orb a = GOrb (Rep a) @@ -137,6 +139,8 @@ selectOrd f x ~(ls, rs) = case f x of instance Orbit a => GOrbit (K1 c a) where + -- Cannot use (Orb a) here, that may lead to a recursive type + -- So we use the type OrbRec a instead (which uses Orb a one step later). type GOrb (K1 c a) = OrbRec a gtoOrbit (K1 x) = OrbRec (toOrbit x) gsupport (K1 x) = support x