{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE NoImplicitPrelude #-} {-# LANGUAGE Rank2Types #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# OPTIONS_GHC -Wall #-} ----------------------------------------------------------------------------- -- | -- A class for semirings (types with two binary operations, one commutative and one associative, and two respective identites), with various general-purpose instances. -- ----------------------------------------------------------------------------- module Data.Semiring ( -- * Semiring typeclass Semiring(..) , (+) , (*) , (^) , foldMapP , foldMapT , sum , product , sum' , product' -- * Types , Add(..) , Mul(..) , WrappedNum(..) -- * Ring typeclass , Ring(..) , (-) , minus ) where import Control.Applicative (Applicative(..), Const(..), liftA2) import Data.Bool (Bool(..), (||), (&&), otherwise, not) import Data.Complex (Complex(..)) import Data.Eq (Eq(..)) import Data.Fixed (Fixed, HasResolution) import Data.Foldable (Foldable) import qualified Data.Foldable as Foldable import Data.Function ((.), const, flip, id) import Data.Functor (Functor(..)) #if MIN_VERSION_base(4,12,0) import Data.Functor.Contravariant (Predicate(..), Equivalence(..), Op(..)) #endif import Data.Functor.Identity (Identity(..)) #if defined(VERSION_unordered_containers) import Data.Hashable (Hashable) import Data.HashMap.Strict (HashMap) import qualified Data.HashMap.Strict as HashMap import Data.HashSet (HashSet) import qualified Data.HashSet as HashSet #endif import Data.Int (Int, Int8, Int16, Int32, Int64) import qualified Data.List as List import Data.Maybe (Maybe(..)) #if MIN_VERSION_base(4,12,0) import Data.Monoid (Ap(..)) #endif #if defined(VERSION_containers) --import Data.IntMap (IntMap) 
--import qualified Data.IntMap as IntMap --import Data.IntSet (IntSet) --import qualified Data.IntSet as IntSet import Data.Map (Map) import qualified Data.Map as Map #endif import Data.Monoid (Monoid(..),Dual(..), Product(..), Sum(..)) import Data.Ord (Ord(..), Ordering(..), compare) #if MIN_VERSION_base(4,6,0) import Data.Ord (Down(..)) #endif import Data.Proxy (Proxy(..)) import Data.Ratio (Ratio, Rational, (%)) import Data.Semigroup (Semigroup(..),Max(..), Min(..)) #if defined(VERSION_containers) import Data.Set (Set) import qualified Data.Set as Set #endif -- #if defined(VERSION_primitive) -- import Data.Primitive.Array (Array(..)) -- import qualified Data.Primitive.Array as Array -- #endif import Data.Traversable (Traversable) import Data.Typeable (Typeable) #if defined(VERSION_vector) import Data.Vector (Vector) import qualified Data.Vector as Vector import qualified Data.Vector.Storable as SV import qualified Data.Vector.Unboxed as UV #endif import Data.Word (Word, Word8, Word16, Word32, Word64) import Foreign.C.Types (CChar, CClock, CDouble, CFloat, CInt, CIntMax, CIntPtr, CLLong, CLong, CPtrdiff, CSChar, CSUSeconds, CShort, CSigAtomic, CSize, CTime, CUChar, CUInt, CUIntMax, CUIntPtr, CULLong, CULong, CUSeconds, CUShort, CWchar) import Foreign.Ptr (IntPtr, WordPtr) import Foreign.Storable (Storable) import GHC.Base (build) import GHC.Enum (Enum, Bounded) import GHC.Float (Float, Double) #if MIN_VERSION_base(4,6,1) import GHC.Generics (Generic,Generic1) #endif import GHC.IO (IO) import GHC.Integer (Integer) import qualified GHC.Num as Num import GHC.Read (Read) import GHC.Real (Integral, Fractional, Real, RealFrac, quot, even) import GHC.Show (Show) import Numeric.Natural (Natural) import System.Posix.Types (CCc, CDev, CGid, CIno, CMode, CNlink, COff, CPid, CRLim, CSpeed, CSsize, CTcflag, CUid, Fd) infixl 7 *, `times` infixl 6 +, `plus`, -, `minus` infixr 8 ^ {-------------------------------------------------------------------- Helpers 
--------------------------------------------------------------------}

-- | Operator form of 'plus'.
(+) :: Semiring a => a -> a -> a
(+) = plus
{-# INLINE (+) #-}

-- | Operator form of 'times'.
(*) :: Semiring a => a -> a -> a
(*) = times
{-# INLINE (*) #-}

-- | Operator form of 'minus'.
(-) :: Ring a => a -> a -> a
(-) = minus
{-# INLINE (-) #-}

-- | Raise a semiring element to a non-negative integral power, using
-- exponentiation by squaring.
--
-- @x '^' 0 = 'one'@. A negative exponent yields 'zero' (not an error).
(^) :: (Semiring a, Integral b) => a -> b -> a
base ^ expo
  | expo < 0  = zero
  | expo == 0 = one
  | otherwise = powNoAcc base expo
  where
    -- Square the base and halve the exponent; no odd factor set aside yet.
    powNoAcc b e
      | even e    = powNoAcc (b * b) (e `quot` 2)
      | e == 1    = b
      | otherwise = powAcc (b * b) (e `quot` 2) b
    -- Same, but @acc@ carries the product of the odd factors set aside.
    powAcc b e acc
      | even e    = powAcc (b * b) (e `quot` 2) acc
      | e == 1    = b * acc
      | otherwise = powAcc (b * b) (e `quot` 2) (b * acc)
{-# INLINE (^) #-}

-- | Send every element of a 'Foldable' structure into a semiring and add
-- the images together with 'plus', starting from 'zero' (a right fold).
foldMapP :: (Foldable t, Semiring s) => (a -> s) -> t a -> s
foldMapP f = Foldable.foldr (\x acc -> f x `plus` acc) zero
{-# INLINE foldMapP #-}

-- | Send every element of a 'Foldable' structure into a semiring and
-- multiply the images together with 'times', starting from 'one'
-- (a right fold).
foldMapT :: (Foldable t, Semiring s) => (a -> s) -> t a -> s
foldMapT f = Foldable.foldr (\x acc -> f x `times` acc) one
{-# INLINE foldMapT #-}

-- | Add up the elements of a structure with 'plus', starting from 'zero'.
-- This is a lazy right fold. For a strict version, see 'sum''.
sum :: (Foldable t, Semiring a) => t a -> a
sum = foldMapP id
{-# INLINE sum #-}

-- | Multiply the elements of a structure with 'times', starting from 'one'.
-- This is a lazy right fold. For a strict version, see 'product''.
product :: (Foldable t, Semiring a) => t a -> a
product = foldMapT id
{-# INLINE product #-}

-- | Add up the elements of a structure with 'plus', starting from 'zero'.
-- This is a strict left fold. For a lazy version, see 'sum'.
sum' :: (Foldable t, Semiring a) => t a -> a
sum' = Foldable.foldl' plus zero
{-# INLINE sum' #-}

-- | The 'product'' function computes the product of the elements in a
-- structure, combining them with 'times' starting from 'one'.
-- This function is strict. For a lazy version, see 'product'.
product' :: (Foldable t, Semiring a) => t a -> a
product' = Foldable.foldl' times one
{-# INLINE product' #-}

-- | Monoid under 'plus'. Analogous to 'Data.Monoid.Sum', but
-- uses the 'Semiring' constraint rather than 'Num.Num'.
newtype Add a = Add
  { getAdd :: a -- ^ The wrapped semiring value.
  }
  deriving
    ( Bounded
    , Enum
    , Eq
    , Foldable
    , Fractional
    , Functor
#if MIN_VERSION_base(4,6,1)
    , Generic
    , Generic1
#endif
    , Num.Num
    , Ord
    , Read
    , Real
    , RealFrac
    , Semiring
    , Show
    , Storable
    , Traversable
    , Typeable
    )

-- | '(<>)' is 'plus' of the wrapped values.
instance Semiring a => Semigroup (Add a) where
  (<>) = (+)
  {-# INLINE (<>) #-}

-- | 'mempty' is 'zero'.
instance Semiring a => Monoid (Add a) where
  mempty = Add zero
  mappend = (<>)
  {-# INLINE mempty #-}
  {-# INLINE mappend #-}

-- | Monoid under 'times'. Analogous to 'Data.Monoid.Product', but
-- uses the 'Semiring' constraint rather than 'Num.Num'.
newtype Mul a = Mul
  { getMul :: a -- ^ The wrapped semiring value.
  }
  deriving
    ( Bounded
    , Enum
    , Eq
    , Foldable
    , Fractional
    , Functor
#if MIN_VERSION_base(4,6,1)
    , Generic
    , Generic1
#endif
    , Num.Num
    , Ord
    , Read
    , Real
    , RealFrac
    , Semiring
    , Show
    , Storable
    , Traversable
    , Typeable
    )

-- | '(<>)' is 'times' of the wrapped values.
instance Semiring a => Semigroup (Mul a) where
  (<>) = (*)
  {-# INLINE (<>) #-}

-- | 'mempty' is 'one'.
instance Semiring a => Monoid (Mul a) where
  mempty = Mul one
  mappend = (<>)
  {-# INLINE mempty #-}
  {-# INLINE mappend #-}

-- | Provide 'Semiring' and 'Ring' instances for an arbitrary 'Num.Num' type.
-- It is useful with GHC 8.6+'s @DerivingVia@ extension.
newtype WrappedNum a = WrapNum { unwrapNum :: a }
  deriving
    ( Bounded
    , Enum
    , Eq
    , Foldable
    , Fractional
    , Functor
#if MIN_VERSION_base(4,6,1)
    , Generic
    , Generic1
#endif
    , Num.Num
    , Ord
    , Read
    , Real
    , RealFrac
    , Show
    , Storable
    , Traversable
    , Typeable
    )

-- | 'plus' is '(Num.+)', 'times' is '(Num.*)', and the units are the
-- numeric literals @0@ and @1@ of the wrapped type.
instance Num.Num a => Semiring (WrappedNum a) where
  plus  = (Num.+)
  zero  = 0
  times = (Num.*)
  one   = 1
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

-- | 'negate' is 'Num.negate' of the wrapped type.
instance Num.Num a => Ring (WrappedNum a) where
  negate = Num.negate
  {-# INLINE negate #-}

{--------------------------------------------------------------------
  Classes
--------------------------------------------------------------------}

-- | The class of semirings (types with two binary
-- operations and two respective identities). One
-- can think of a semiring as two monoids of the same
-- underlying type: a commutative monoid and an
-- associative monoid. For any type R with a 'Prelude.Num'
-- instance, the commutative monoid is (R, '(Prelude.+)', 0)
-- and the associative monoid is (R, '(Prelude.*)', 1).
--
-- Instances should satisfy the following laws:
--
-- [/additive identity/]
--
--     @x '+' 'zero' = 'zero' '+' x = x@
--
-- [/additive associativity/]
--
--     @x '+' (y '+' z) = (x '+' y) '+' z@
--
-- [/additive commutativity/]
--
--     @x '+' y = y '+' x@
--
-- [/multiplicative identity/]
--
--     @x '*' 'one' = 'one' '*' x = x@
--
-- [/multiplicative associativity/]
--
--     @x '*' (y '*' z) = (x '*' y) '*' z@
--
-- [/left- and right-distributivity of '*' over '+'/]
--
--     @x '*' (y '+' z) = (x '*' y) '+' (x '*' z)@
--
--     @(x '+' y) '*' z = (x '*' z) '+' (y '*' z)@
--
-- [/annihilation/]
--
--     @'zero' '*' x = x '*' 'zero' = 'zero'@
class Semiring a where
#if __GLASGOW_HASKELL__ >= 708
  {-# MINIMAL plus, zero, times, one #-}
#endif
  -- | The commutative operation (addition).
  plus  :: a -> a -> a
  -- | The identity of 'plus'.
  zero  :: a
  -- | The associative operation (multiplication).
  times :: a -> a -> a
  -- | The identity of 'times'.
  one   :: a

-- | The class of semirings with an additive inverse.
--
-- @'negate' a '+' a = 'zero'@
class Semiring a => Ring a where
#if __GLASGOW_HASKELL__ >= 708
  {-# MINIMAL negate #-}
#endif
  -- | The additive inverse.
  negate :: a -> a

-- | Subtract two 'Ring' values. For the instances in this module that
-- delegate to 'Prelude.Num', this agrees with '(Prelude.-)'.
--
-- @x `minus` y = x '+' 'negate' y@
minus :: Ring a => a -> a -> a
minus x y = x + negate y
{-# INLINE minus #-}

{--------------------------------------------------------------------
  Instances (base)
--------------------------------------------------------------------}

-- | Pointwise: both functions are applied to the same argument and the
-- results are combined.
instance Semiring b => Semiring (a -> b) where
  plus f g x  = f x `plus` g x
  zero        = const zero
  times f g x = f x `times` g x
  one         = const one
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

-- | Pointwise negation.
instance Ring b => Ring (a -> b) where
  negate f x = negate (f x)
  {-# INLINE negate #-}

instance Semiring () where
  plus _ _  = ()
  zero      = ()
  times _ _ = ()
  one       = ()
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance Ring () where
  negate _ = ()
  {-# INLINE negate #-}

instance Semiring (Proxy a) where
  plus _ _  = Proxy
  zero      = Proxy
  times _ _ = Proxy
  one       = Proxy
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

-- | Like '()', 'Proxy' has a single value, so it is trivially a ring.
instance Ring (Proxy a) where
  negate _ = Proxy
  {-# INLINE negate #-}

-- | The Boolean semiring: 'plus' is '(||)' and 'times' is '(&&)'.
instance Semiring Bool where
  plus  = (||)
  zero  = False
  times = (&&)
  one   = True
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

-- | NOTE: this instance does not satisfy the 'Ring' law. Since 'plus' is
-- '(||)' and 'zero' is 'False', @'negate' 'True' '+' 'True' = 'True'@;
-- no value can cancel 'True' under '(||)'. It is kept for backward
-- compatibility only.
instance Ring Bool where
  negate = not
  {-# INLINE negate #-}

-- | Lists behave like polynomial coefficient lists: 'plus' adds
-- element-wise (keeping the longer tail) and 'times' is the convolution.
-- See Section: List fusion.
instance Semiring a => Semiring [a] where
  zero  = []
  one   = [one]
  plus  = listAdd
  times = listTimes
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

-- | NOTE: the 'Ring' law holds only up to trailing zeros. For a non-empty
-- @xs@, @'negate' xs '+' xs@ is a list of @'negate' x '+' x@ values of the
-- same length as @xs@, not the empty list that 'zero' is.
instance Ring a => Ring [a] where
  negate = fmap negate
  {-# INLINE negate #-}

-- | 'Nothing' is 'zero': it is the identity for 'plus' and absorbing for
-- 'times'. Two 'Just' values are combined inside.
instance Semiring a => Semiring (Maybe a) where
  zero = Nothing
  one  = Just one
  plus Nothing y         = y
  plus x Nothing         = x
  plus (Just x) (Just y) = Just (plus x y)
  times Nothing _         = Nothing
  times _ Nothing         = Nothing
  times (Just x) (Just y) = Just (times x y)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

-- | NOTE: the 'Ring' law does not hold exactly:
-- @'negate' ('Just' x) '+' 'Just' x = 'Just' ('negate' x '+' x)@, which is
-- a 'Just' value, whereas 'zero' is 'Nothing'.
instance Ring a => Ring (Maybe a) where
  negate = fmap negate
  {-# INLINE negate #-}

-- | The result semiring is lifted over 'IO' with 'pure' and 'liftA2'.
instance Semiring a => Semiring (IO a) where
  zero  = pure zero
  one   = pure one
  plus  = liftA2 plus
  times = liftA2 times
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance Ring a => Ring (IO a) where
  negate = fmap negate
  {-# INLINE negate #-}

-- | Both operations are applied to the wrapped values with their
-- arguments swapped.
instance Semiring a => Semiring (Dual a) where
  zero = Dual zero
  Dual x `plus` Dual y = Dual (y `plus` x)
  one = Dual one
  Dual x `times` Dual y = Dual (y `times` x)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance Ring a => Ring (Dual a) where
  negate (Dual x) = Dual (negate x)
  {-# INLINE negate #-}

-- | Operations act on the stored constant; the phantom @b@ is ignored.
instance Semiring a => Semiring (Const a b) where
  zero = Const zero
  one  = Const one
  plus (Const x) (Const y)  = Const (x `plus` y)
  times (Const x) (Const y) = Const (x `times` y)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance Ring a => Ring (Const a b) where
  negate (Const x) = Const (negate x)
  {-# INLINE negate #-}

-- | Complex numbers over a 'Ring', with the usual complex multiplication.
-- This instance can suffer due to floating point arithmetic: for
-- floating-point components the laws hold only up to rounding.
instance Ring a => Semiring (Complex a) where
  zero = zero :+ zero
  one  = one :+ zero
  (xr :+ xi) `plus` (yr :+ yi) = (xr + yr) :+ (xi + yi)
  (xr :+ xi) `times` (yr :+ yi) = re :+ im
    where
      -- (xr + xi i)(yr + yi i) = (xr yr - xi yi) + (xr yi + xi yr) i
      re = xr * yr - xi * yi
      im = xr * yi + xi * yr
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

-- | Component-wise negation.
instance Ring a => Ring (Complex a) where
  negate (xr :+ xi) = negate xr :+ negate xi
  {-# INLINE negate #-}

#if MIN_VERSION_base(4,12,0)
-- | The inner semiring is lifted through the 'Applicative' @f@.
instance (Semiring a, Applicative f) => Semiring (Ap f a) where
  zero = Ap (pure zero)
  one  = Ap (pure one)
  plus (Ap x) (Ap y)  = Ap (liftA2 plus x y)
  times (Ap x) (Ap y) = Ap (liftA2 times x y)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance (Ring a, Applicative f) => Ring (Ap f a) where
  negate (Ap x) = Ap (fmap negate x)
  {-# INLINE negate #-}
#endif

#if MIN_VERSION_base(4,12,0)
-- These newtype wrappers reuse the instance of the type they wrap
-- (standalone deriving through GeneralizedNewtypeDeriving).
deriving instance Semiring (Predicate a)
deriving instance Ring (Predicate a)
deriving instance Semiring a => Semiring (Equivalence a)
deriving instance Ring a => Ring (Equivalence a)
deriving instance Semiring a => Semiring (Op a b)
deriving instance Ring a => Ring (Op a b)
#endif

-- The deriveSemiring macro defines a 'Semiring' instance for a 'Num.Num'
-- type: the units are the literals 0 and 1, and the operations are the
-- type's own numeric addition and multiplication.
#define deriveSemiring(ty)     \
instance Semiring (ty) where { \
  zero  = 0                    \
; one   = 1                    \
; plus  = (Num.+)              \
; times = (Num.*)              \
}

deriveSemiring(Int)
deriveSemiring(Int8)
deriveSemiring(Int16)
deriveSemiring(Int32)
deriveSemiring(Int64)
deriveSemiring(Integer)
deriveSemiring(Word)
deriveSemiring(Word8)
deriveSemiring(Word16)
deriveSemiring(Word32)
deriveSemiring(Word64)
deriveSemiring(Float)
deriveSemiring(Double)
deriveSemiring(CUIntMax)
deriveSemiring(CIntMax)
deriveSemiring(CUIntPtr)
deriveSemiring(CIntPtr)
deriveSemiring(CSUSeconds)
deriveSemiring(CUSeconds)
deriveSemiring(CTime)
deriveSemiring(CClock)
deriveSemiring(CSigAtomic)
deriveSemiring(CWchar)
deriveSemiring(CSize)
deriveSemiring(CPtrdiff)
deriveSemiring(CDouble)
deriveSemiring(CFloat)
deriveSemiring(CULLong)
deriveSemiring(CLLong)
deriveSemiring(CULong)
deriveSemiring(CLong)
deriveSemiring(CUInt)
deriveSemiring(CInt)
deriveSemiring(CUShort)
deriveSemiring(CShort)
deriveSemiring(CUChar)
deriveSemiring(CSChar)
deriveSemiring(CChar)
deriveSemiring(IntPtr)
deriveSemiring(WordPtr)
deriveSemiring(Fd)
deriveSemiring(CRLim)
deriveSemiring(CTcflag)
deriveSemiring(CSpeed)
deriveSemiring(CCc)
deriveSemiring(CUid)
deriveSemiring(CNlink)
deriveSemiring(CGid)
deriveSemiring(CSsize)
deriveSemiring(CPid)
deriveSemiring(COff)
deriveSemiring(CMode)
deriveSemiring(CIno)
deriveSemiring(CDev)
deriveSemiring(Natural)

-- | Ratios use the 'Num.Num' arithmetic of 'Ratio'; the units are
-- @0 % 1@ and @1 % 1@.
instance Integral a => Semiring (Ratio a) where
  {-# SPECIALIZE instance Semiring Rational #-}
  zero      = 0 % 1
  one       = 1 % 1
  plus x y  = x Num.+ y
  times x y = x Num.* y
  {-# INLINE zero #-}
  {-# INLINE one #-}
  {-# INLINE plus #-}
  {-# INLINE times #-}

-- These newtype wrappers reuse the instance of the type they wrap
-- (standalone deriving through GeneralizedNewtypeDeriving).
-- NOTE(review): for 'Max' and 'Min' this means 'plus' is the wrapped
-- type's 'plus', not 'max' / 'min'.
deriving instance Semiring a => Semiring (Identity a)
deriving instance Semiring a => Semiring (Sum a)
deriving instance Semiring a => Semiring (Product a)
deriving instance Semiring a => Semiring (Max a)
deriving instance Semiring a => Semiring (Min a)
#if MIN_VERSION_base(4,6,0)
deriving instance Semiring a => Semiring (Down a)
#endif

-- | Fixed-precision numbers use their own 'Num.Num' arithmetic.
instance HasResolution a => Semiring (Fixed a) where
  zero      = 0
  one       = 1
  plus x y  = x Num.+ y
  times x y = x Num.* y
  {-# INLINE zero #-}
  {-# INLINE one #-}
  {-# INLINE plus #-}
  {-# INLINE times #-}

-- The deriveRing macro defines a 'Ring' instance for a 'Num.Num' type whose
-- 'negate' is the type's own 'Num.negate'.
#define deriveRing(ty)     \
instance Ring (ty) where { \
  negate x = Num.negate x  \
}

deriveRing(Int)
deriveRing(Int8)
deriveRing(Int16)
deriveRing(Int32)
deriveRing(Int64)
deriveRing(Integer)
deriveRing(Word)
deriveRing(Word8)
deriveRing(Word16)
deriveRing(Word32)
deriveRing(Word64)
deriveRing(Float)
deriveRing(Double)
deriveRing(CUIntMax)
deriveRing(CIntMax)
deriveRing(CUIntPtr)
deriveRing(CIntPtr)
deriveRing(CSUSeconds)
deriveRing(CUSeconds)
deriveRing(CTime)
deriveRing(CClock)
deriveRing(CSigAtomic)
deriveRing(CWchar)
deriveRing(CSize)
deriveRing(CPtrdiff)
deriveRing(CDouble)
deriveRing(CFloat)
deriveRing(CULLong)
deriveRing(CLLong)
deriveRing(CULong)
deriveRing(CLong)
deriveRing(CUInt)
deriveRing(CInt)
deriveRing(CUShort)
deriveRing(CShort)
deriveRing(CUChar)
deriveRing(CSChar)
deriveRing(CChar)
deriveRing(IntPtr)
deriveRing(WordPtr)
deriveRing(Fd)
deriveRing(CRLim)
deriveRing(CTcflag)
deriveRing(CSpeed)
deriveRing(CCc)
deriveRing(CUid)
deriveRing(CNlink)
deriveRing(CGid)
deriveRing(CSsize)
deriveRing(CPid)
deriveRing(COff)
deriveRing(CMode)
deriveRing(CIno)
deriveRing(CDev)
-- NOTE: 'Natural' has no additive inverses. 'Num.negate' on a non-zero
-- 'Natural' throws an arithmetic Underflow exception, so this instance is
-- partial and does not satisfy the 'Ring' law. It is kept only for
-- backward compatibility.
deriveRing(Natural)

instance Integral a => Ring (Ratio a) where
  negate = Num.negate
  {-# INLINE negate #-}

#if MIN_VERSION_base(4,6,0)
deriving instance Ring a => Ring (Down a)
#endif
deriving instance Ring a => Ring (Product a)
deriving instance Ring a => Ring (Sum a)
deriving instance Ring a => Ring (Identity a)
deriving instance Ring a => Ring (Max a)
deriving instance Ring a => Ring (Min a)

instance HasResolution a => Ring (Fixed a) where
  negate = Num.negate
  {-# INLINE negate #-}

{--------------------------------------------------------------------
  Instances (containers)
--------------------------------------------------------------------}

#if defined(VERSION_containers)

-- | 'plus' is 'Set.union', and 'times' combines every pair of elements
-- with 'mappend'. The multiplication laws are satisfied for any underlying
-- 'Monoid', so this instance only requires a 'Monoid' constraint rather
-- than a 'Semiring' constraint. Only the monoid structure is needed by
-- 'times'.
instance (Ord a, Monoid a) => Semiring (Set a) where
  zero = Set.empty
  one = Set.singleton mempty
  plus = Set.union
  times xs ys = Foldable.foldMap (flip Set.map ys . mappend) xs
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

-- | Maps act like generalized polynomials. Keys are combined with 'mappend'
-- and values with the value 'Semiring'. Entries with equal keys are summed
-- with '(+)'. The multiplication laws are satisfied for any underlying
-- 'Monoid' key type, so keys only need a 'Monoid' constraint rather than a
-- 'Semiring' constraint.
instance (Ord k, Monoid k, Semiring v) => Semiring (Map k v) where
  zero = Map.empty
  one = Map.singleton mempty one
  plus = Map.unionWith (+)
  xs `times` ys
    = Map.fromListWith (+)
        [ (mappend k l, v * u)
        | (k,v) <- Map.toList xs
        , (l,u) <- Map.toList ys
        ]
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

--newtype IntSetP = IntSetP { intSetP :: IntSet }
--newtype IntSetT = IntSetT { intSetT :: IntSet }
--
--instance Semiring IntSetP where
--  zero = IntSetP (IntSet.empty)
--  one  = IntSetP (IntSet.singleton zero)
--  plus (IntSetP x) (IntSetP y) = IntSetP (IntSet.union x y)
--  times (IntSetP xs) (IntSetP ys) = IntSetP (foldMapIntSet (flip IntSet.map ys . plus) xs)
--
--instance Semiring IntSetT where
--  zero = IntSetT IntSet.empty
--  one  = IntSetT (IntSet.singleton one)
--  plus (IntSetT x) (IntSetT y) = IntSetT (IntSet.union x y)
--  times (IntSetT xs) (IntSetT ys) = IntSetT (foldMapIntSet (flip IntSet.map ys . times) xs)
--
--foldMapIntSet :: Monoid m => (Int -> m) -> IntSet -> m
--foldMapIntSet f = IntSet.foldl' (flip (mappend . f)) mempty
--{-# INLINE foldMapIntSet #-}

--instance (Semiring a) => Semiring (IntMap a) where
--  zero = IntMap.empty
--  one  = IntMap.singleton zero one
--  plus = IntMap.unionWith (+)
--  xs `times` ys
--    = IntMap.fromListWith (+)
--        [ (plus k l, v * u)
--        | (k,v) <- IntMap.toList xs
--        , (l,u) <- IntMap.toList ys
--        ]
#endif

{--------------------------------------------------------------------
  Instances (unordered-containers)
--------------------------------------------------------------------}

#if defined(VERSION_unordered_containers)

-- | 'plus' is 'HashSet.union', and 'times' combines every pair of elements
-- with 'mappend'. The multiplication laws are satisfied for any underlying
-- 'Monoid', so this instance only requires a 'Monoid' constraint rather
-- than a 'Semiring' constraint.
instance (Eq a, Hashable a, Monoid a) => Semiring (HashSet a) where
  zero = HashSet.empty
  one = HashSet.singleton mempty
  plus = HashSet.union
  times xs ys = Foldable.foldMap (flip HashSet.map ys . mappend) xs
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

-- | Hash maps act like generalized polynomials. Keys are combined with
-- 'mappend' and values with the value 'Semiring'. Entries with equal keys
-- are summed with '(+)'. The multiplication laws are satisfied for any
-- underlying 'Monoid' key type, so keys only need a 'Monoid' constraint
-- rather than a 'Semiring' constraint.
instance (Eq k, Hashable k, Monoid k, Semiring v) => Semiring (HashMap k v) where
  zero = HashMap.empty
  one = HashMap.singleton mempty one
  plus = HashMap.unionWith (+)
  xs `times` ys
    = HashMap.fromListWith (+)
        [ (mappend k l, v * u)
        | (k,v) <- HashMap.toList xs
        , (l,u) <- HashMap.toList ys
        ]
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
#endif

{--------------------------------------------------------------------
  Instances (primitive)
--------------------------------------------------------------------}

#if defined(VERSION_primitive)
-- Disabled: a Semiring instance for primitive arrays. The multiplication
-- laws would be satisfied for any underlying 'Monoid', so it would require
-- a 'Monoid' constraint instead of a 'Semiring' constraint.
--
-- instance (Monoid a) => Semiring (Array a) where
--   zero = mempty
--   one = runST e where
--     e :: forall s. Monoid a => ST s (Array a)
--     e = (Array.newArray 1 mempty) >>= Array.unsafeFreezeArray
--   plus _ _ = mempty
--   times _ _ = mempty
--   {-# INLINE plus #-}
--   {-# INLINE zero #-}
--   {-# INLINE times #-}
--   {-# INLINE one #-}
#endif

{--------------------------------------------------------------------
  Instances (vector)
--------------------------------------------------------------------}

#if defined(VERSION_vector)

-- | Vectors act like polynomial coefficient vectors. 'plus' adds
-- element-wise and keeps the longer tail. 'times' is the full discrete
-- convolution, whose length is @length xs + length ys - 1@; it is empty if
-- either input is empty.
instance Semiring a => Semiring (Vector a) where
  zero = Vector.empty
  one = Vector.singleton one
  plus xs ys =
    case compare (Vector.length xs) (Vector.length ys) of
      EQ -> Vector.zipWith (+) xs ys
      -- The shorter vector's indices are all in range of the longer one.
      LT -> Vector.unsafeAccumulate (+) ys (Vector.indexed xs)
      GT -> Vector.unsafeAccumulate (+) xs (Vector.indexed ys)
  times signal kernel
    | Vector.null signal = Vector.empty
    | Vector.null kernel = Vector.empty
    | otherwise = Vector.generate (slen + klen - 1) f
    where
      !slen = Vector.length signal
      !klen = Vector.length kernel
      -- Coefficient n sums signal[k] * kernel[n - k] over every k for
      -- which both indices are in range.
      f n = Foldable.foldl'
        (\a k ->
          a +
          Vector.unsafeIndex signal k *
          Vector.unsafeIndex kernel (n - k)
        )
        zero
        [kmin .. kmax]
        where
          !kmin = max 0 (n - (klen - 1))
          !kmax = min n (slen - 1)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance Ring a => Ring (Vector a) where
  negate = Vector.map negate
  {-# INLINE negate #-}

-- | Same semantics as the boxed 'Vector' instance, for unboxed vectors.
instance (UV.Unbox a, Semiring a) => Semiring (UV.Vector a) where
  zero = UV.empty
  one = UV.singleton one
  plus xs ys =
    case compare (UV.length xs) (UV.length ys) of
      EQ -> UV.zipWith (+) xs ys
      LT -> UV.unsafeAccumulate (+) ys (UV.indexed xs)
      GT -> UV.unsafeAccumulate (+) xs (UV.indexed ys)
  times signal kernel
    | UV.null signal = UV.empty
    | UV.null kernel = UV.empty
    | otherwise = UV.generate (slen + klen - 1) f
    where
      !slen = UV.length signal
      !klen = UV.length kernel
      f n = Foldable.foldl'
        (\a k ->
          a +
          UV.unsafeIndex signal k *
          UV.unsafeIndex kernel (n - k)
        )
        zero
        [kmin .. kmax]
        where
          !kmin = max 0 (n - (klen - 1))
          !kmax = min n (slen - 1)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance (UV.Unbox a, Ring a) => Ring (UV.Vector a) where
  negate = UV.map negate
  {-# INLINE negate #-}

-- | Same semantics as the boxed 'Vector' instance, for storable vectors.
instance (SV.Storable a, Semiring a) => Semiring (SV.Vector a) where
  zero = SV.empty
  one = SV.singleton one
  plus xs ys =
    case compare lxs lys of
      EQ -> SV.zipWith (+) xs ys
      LT -> SV.unsafeAccumulate_ (+) ys (SV.enumFromN 0 lxs) xs
      GT -> SV.unsafeAccumulate_ (+) xs (SV.enumFromN 0 lys) ys
    where
      lxs = SV.length xs
      lys = SV.length ys
  times signal kernel
    | SV.null signal = SV.empty
    | SV.null kernel = SV.empty
    | otherwise = SV.generate (slen + klen - 1) f
    where
      !slen = SV.length signal
      !klen = SV.length kernel
      f n = Foldable.foldl'
        (\a k ->
          a +
          SV.unsafeIndex signal k *
          SV.unsafeIndex kernel (n - k))
        zero
        [kmin .. kmax]
        where
          !kmin = max 0 (n - (klen - 1))
          !kmax = min n (slen - 1)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance (SV.Storable a, Ring a) => Ring (SV.Vector a) where
  negate = SV.map negate
  {-# INLINE negate #-}
#endif

-- [Section: List fusion]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--
-- 'listAdd' and 'listTimes' implement 'plus' and 'times' for lists.
-- 'listAdd' is kept out of inlining until phase 0, so that the RULES
-- below can rewrite it when either argument is produced by 'build'. The
-- rewritten forms consume that argument directly instead of
-- materializing it first.

listAdd, listTimes :: Semiring a => [a] -> [a] -> [a]

-- Element-wise addition; the tail of the longer list is kept unchanged.
listAdd [] ys = ys
listAdd xs [] = xs
listAdd (x:xs) (y:ys) = (x + y) : listAdd xs ys
{-# NOINLINE [0] listAdd #-}

-- Convolution (polynomial product). Each element x of the left list
-- contributes x * y for every y of the right list, shifted one position
-- further than the next element's contributions (hence 'zero' : zs).
listTimes _ [] = []
listTimes xs ys = List.foldr f [] xs
  where
    f x zs = List.foldr (g x) id ys (zero : zs)
    g x y a [] = x `times` y : a []
    g x y a (z:zs) = x `times` y `plus` z : a zs
{-# NOINLINE [0] listTimes #-}

-- | The argument of 'build': a list abstracted over its cons and nil.
type ListBuilder a = forall b. (a -> b -> b) -> b -> b

{-# RULES
"listAddFB/left" forall (g :: ListBuilder a).
  listAdd (build g) = listAddFBL g
"listAddFB/right" forall xs (g :: ListBuilder a).
  listAdd xs (build g) = listAddFBR xs g
  #-}

-- | A definition of 'listAdd' that can be fused on its left argument. The
-- builder is instantiated at @[a] -> [a]@: each element waits for the
-- remaining right-hand list and adds its head, if there is one.
listAddFBL :: Semiring a => ListBuilder a -> [a] -> [a]
listAddFBL xf = xf f id
  where
    f x xs (y:ys) = x + y : xs ys
    f x xs [] = x : xs []

-- | A definition of 'listAdd' that can be fused on its right argument.
-- The element order is preserved: the left element is the first argument
-- of '(+)'.
listAddFBR :: Semiring a => [a] -> ListBuilder a -> [a]
listAddFBR xs' yf = yf f id xs'
  where
    f y ys (x:xs) = x + y : ys xs
    f y ys [] = y : ys []