{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -Wno-unticked-promoted-constructors #-}

-- |
-- Copyright: © 2022–2023 Jonathan Knowles
-- License: Apache-2.0
--
module Data.MultiSet
    ( MultiSet
    , MultiSetType (..)
    , MultiSetN
    , MultiSetZ
    , Multiplicity
    , cardinality
    , multiplicity
    , maximum
    , minimum
    , invert
    , intersection
    , union
    , emptyN
    , emptyZ
    , fromListN
    , fromListZ
    , toList
    , toMultiSetZ
    , toMultiSetN
    )
    where

import Prelude hiding
    ( gcd, maximum, minimum )

import Data.Coerce
    ( coerce )
import Data.Group
    ( Group )
import Data.Monoid
    ( Sum (..) )
import Data.MonoidMap
    ( MonoidMap )
import Numeric.Natural
    ( Natural )

import qualified Data.Foldable as F
import qualified Data.Group as Group
import qualified Data.MonoidMap as MonoidMap

-- | A multiset of elements of type @a@, whose multiplicity type is selected
--   by the index @t@ (see 'Multiplicity').
--
-- The constructor captures 'MultiplicityConstraints' for the multiplicity
-- type, so pattern matching on 'MultiSet' brings 'Num', 'Ord', 'Integral',
-- etc. for @'Multiplicity' t@ into scope even when @t@ is polymorphic.
data MultiSet (t :: MultiSetType) a
    = MultiplicityConstraints (Multiplicity t)
    => MultiSet
        { unwrap :: MonoidMap a (Sum (Multiplicity t))
        }

-- | The index of a 'MultiSet' (used as a kind via @DataKinds@), which
--   selects the type of its multiplicities.
data MultiSetType
    = N
    -- ^ Indicates a multiset with 'Natural' (ℕ) multiplicity.
    | Z
    -- ^ Indicates a multiset with 'Integer' (ℤ) multiplicity.

-- | A multiset with 'Natural' (ℕ) multiplicity: no element can occur a
--   negative number of times.
type MultiSetN = MultiSet 'N

-- | A multiset with 'Integer' (ℤ) multiplicity: elements may occur a
--   negative number of times.
type MultiSetZ = MultiSet 'Z

-- | The multiplicity type associated with each multiset index:
--   'Natural' for @N@ and 'Integer' for @Z@.
type family Multiplicity (t :: MultiSetType) where
    Multiplicity 'N = Natural
    Multiplicity 'Z = Integer

-- | The class constraints required of every multiplicity type. They are
--   stored in the 'MultiSet' constructor.
type MultiplicityConstraints t =
    ( Eq t
    , Integral t
    , Num t
    , Ord t
    , Show t
    )

-- Standalone deriving is needed because of the constructor context: the
-- 'Show' and 'Eq' instances for the multiplicity type come from the
-- 'MultiplicityConstraints' captured by the 'MultiSet' constructor.
deriving instance Show a => Show (MultiSet t a)
deriving instance Eq a => Eq (MultiSet t a)

-- | Combines two multisets by adding the multiplicities of each element
--   (pointwise '<>' on the underlying 'Sum' values).
instance Ord a => Semigroup (MultiSet t a) where
    m1 <> m2 = case (m1, m2) of
        (MultiSet s1, MultiSet s2) -> MultiSet (s1 <> s2)

-- | The identity element is the empty multiset.
--
-- Declared per index rather than for all @t@: building a 'MultiSet'
-- requires 'MultiplicityConstraints' for a concrete multiplicity type.
instance Ord a => Monoid (MultiSetN a) where
    mempty = MultiSet MonoidMap.empty

-- | The identity element is the empty multiset.
--
-- Declared per index rather than for all @t@: building a 'MultiSet'
-- requires 'MultiplicityConstraints' for a concrete multiplicity type.
instance Ord a => Monoid (MultiSetZ a) where
    mempty = MultiSet MonoidMap.empty

-- | Inversion negates every multiplicity. Only integer-valued multisets
--   form a group, since natural multiplicities cannot be negated.
instance Ord a => Group (MultiSetZ a) where
    invert m = case m of
        MultiSet s -> MultiSet (MonoidMap.invert s)

-- | The empty 'MultiSetN'.
--
-- Unlike 'mempty', this places no 'Ord' constraint on the element type.
emptyN :: MultiSetN a
emptyN = MultiSet MonoidMap.empty

-- | The empty 'MultiSetZ'.
--
-- Unlike 'mempty', this places no 'Ord' constraint on the element type.
emptyZ :: MultiSetZ a
emptyZ = MultiSet MonoidMap.empty

-- | Builds a 'MultiSetN' from @(element, multiplicity)@ pairs.
--
-- The pairs are re-typed to 'Sum' with a zero-cost 'coerce' and handed to
-- 'MonoidMap.fromList'. NOTE(review): repeated elements are presumably
-- combined with '<>' (i.e. added) by 'MonoidMap.fromList' — confirm.
fromListN :: Ord a => [(a, Natural)] -> MultiSetN a
fromListN pairs = MultiSet (MonoidMap.fromList (coerce pairs))

-- | Builds a 'MultiSetZ' from @(element, multiplicity)@ pairs.
--
-- The pairs are re-typed to 'Sum' with a zero-cost 'coerce' and handed to
-- 'MonoidMap.fromList'. NOTE(review): repeated elements are presumably
-- combined with '<>' (i.e. added) by 'MonoidMap.fromList' — confirm.
fromListZ :: Ord a => [(a, Integer)] -> MultiSetZ a
fromListZ pairs = MultiSet (MonoidMap.fromList (coerce pairs))

-- | Lists the @(element, multiplicity)@ pairs stored in the multiset, in the
--   order produced by 'MonoidMap.toList'.
--
-- 'coerce' strips the 'Sum' wrapper from every multiplicity at no cost.
toList :: MultiSet t a -> [(a, Multiplicity t)]
toList (MultiSet s) = coerce (MonoidMap.toList s)

-- | Merges a pair of 'MultiSetN' values into one 'MultiSetZ'.
--
-- The first component supplies the negative multiplicities and the second
-- the positive ones. The result is their pointwise sum, so
-- @toMultiSetZ (toMultiSetN s) == s@.
toMultiSetZ :: Ord a => (MultiSetN a, MultiSetN a) -> MultiSetZ a
toMultiSetZ (MultiSet negatives, MultiSet positives) =
    MultiSet (negated <> widened)
  where
    -- Each natural multiplicity becomes a negative integer.
    negated = MonoidMap.map
        (\(Sum n) -> Sum (negate (naturalToInteger n))) negatives
    -- Each natural multiplicity becomes a non-negative integer.
    widened = MonoidMap.map
        (\(Sum n) -> Sum (naturalToInteger n)) positives

-- | Splits a 'MultiSetZ' into @(negative part, positive part)@, with both
--   parts expressed as 'MultiSetN'.
--
-- The first component holds the magnitudes of the negative multiplicities.
-- The second holds the positive multiplicities.
toMultiSetN :: MultiSetZ a -> (MultiSetN a, MultiSetN a)
toMultiSetN (MultiSet s) =
    ( MultiSet (MonoidMap.map (fmap integerNegativePartToNatural) s)
    , MultiSet (MonoidMap.map (fmap integerPositivePartToNatural) s)
    )

-- | The sum of all multiplicities.
--
-- For a 'MultiSetZ', positive and negative multiplicities may cancel.
cardinality :: MultiSet t a -> Multiplicity t
cardinality m = case m of
    MultiSet s -> getSum (F.fold s)

-- | The multiplicity of a single element. 'MonoidMap.get' yields 'mempty'
--   for an absent element, so the result is 0 in that case.
multiplicity :: Ord a => a -> MultiSet t a -> Multiplicity t
multiplicity a m = case m of
    MultiSet s -> getSum (MonoidMap.get a s)

-- | The largest multiplicity stored in the multiset, or 0 if it is empty.
--
-- The emptiness guard keeps 'F.maximum' from being applied to an empty
-- structure.
--
-- NOTE(review): for a 'MultiSetZ' whose stored multiplicities are all
-- negative, the result is negative, even though absent elements have
-- multiplicity 0. Confirm this is intended.
maximum :: MultiSet t a -> Multiplicity t
maximum (MultiSet s)
    | MonoidMap.null s = 0
    | otherwise = getSum (F.maximum s)

-- | The smallest multiplicity stored in the multiset, or 0 if it is empty.
--
-- The emptiness guard keeps 'F.minimum' from being applied to an empty
-- structure.
--
-- NOTE(review): only stored multiplicities are considered. For a
-- non-empty 'MultiSetN' the result is therefore the smallest nonzero
-- multiplicity, not 0. Confirm this is intended.
minimum :: MultiSet t a -> Multiplicity t
minimum (MultiSet s)
    | MonoidMap.null s = 0
    | otherwise = getSum (F.minimum s)

-- | Negates every multiplicity.
--
-- The result is always a 'MultiSetZ', because negated natural
-- multiplicities need integers to be represented.
invert :: MultiSet t a -> MultiSetZ a
invert (MultiSet s) =
    MultiSet (MonoidMap.map (\(Sum n) -> Sum (negate (toInteger n))) s)

-- | The intersection of two multisets. Each element's multiplicity in the
--   result is the 'min' of its multiplicities in the two inputs.
--
-- An element absent from a multiset has multiplicity 0. The maps are
-- therefore combined with 'MonoidMap.unionWith', which supplies 'mempty'
-- (0) for a missing key, rather than 'MonoidMap.intersectionWith', which
-- drops any key not present in both maps.
--
-- For a 'MultiSetN' both give the same result, since @min n 0 == 0@. For a
-- 'MultiSetZ' only the former is correct: e.g. @min (-2) 0 == -2@ must be
-- kept. This matches how 'union' handles missing keys, so the lattice
-- absorption laws hold for both indices.
intersection :: Ord a => MultiSet t a -> MultiSet t a -> MultiSet t a
intersection (MultiSet s1) (MultiSet s2) =
    MultiSet (MonoidMap.unionWith min s1 s2)

-- | The union of two multisets. Each element's multiplicity in the result
--   is the 'max' of its multiplicities in the two inputs. A key missing
--   from one map takes part as 'mempty' (0).
union :: Ord a => MultiSet t a -> MultiSet t a -> MultiSet t a
union m1 m2 = case (m1, m2) of
    (MultiSet s1, MultiSet s2) -> MultiSet (MonoidMap.unionWith max s1 s2)

--------------------------------------------------------------------------------
-- Utilities
--------------------------------------------------------------------------------

-- | Widens a 'Natural' to an 'Integer'. This conversion is total.
naturalToInteger :: Natural -> Integer
naturalToInteger = toInteger

-- | The magnitude of the negative part of an 'Integer': @|n|@ when
--   @n < 0@, and 0 otherwise.
--
-- Clamping with 'max' 0 ensures 'fromInteger' never receives a negative
-- value.
integerNegativePartToNatural :: Integer -> Natural
integerNegativePartToNatural n = fromInteger (max 0 (negate n))

-- | The positive part of an 'Integer': @n@ when @n > 0@, and 0 otherwise.
--
-- Clamping with 'max' 0 ensures 'fromInteger' never receives a negative
-- value.
integerPositivePartToNatural :: Integer -> Natural
integerPositivePartToNatural n = fromInteger (max 0 n)