{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
#if MIN_VERSION_base(4,7,0) && !MIN_VERSION_base(4,8,0)
{-# LANGUAGE UndecidableInstances #-}
#endif
module Data.Semiring
(
Semiring(..)
, (+)
, (*)
, (^)
, foldMapP
, foldMapT
, sum
, product
, sum'
, product'
, Add(..)
, Mul(..)
, WrappedNum(..)
#if MIN_VERSION_base(4,7,0)
, IntSetOf(..)
, IntMapOf(..)
#endif
, Ring(..)
, (-)
, minus
) where
import Control.Applicative (Applicative(..), Const(..), liftA2)
import Data.Bool (Bool(..), (||), (&&), not)
#if MIN_VERSION_base(4,7,0)
import Data.Coerce (Coercible, coerce)
#endif
import Data.Complex (Complex(..))
import Data.Eq (Eq(..))
import Data.Fixed (Fixed, HasResolution)
import Data.Foldable (Foldable(foldMap))
import qualified Data.Foldable as Foldable
import Data.Function ((.), const, id)
#if defined(VERSION_unordered_containers) || defined(VERSION_containers)
import Data.Function (flip)
#endif
import Data.Functor (Functor(..))
#if MIN_VERSION_base(4,12,0)
import Data.Functor.Contravariant (Predicate(..), Equivalence(..), Op(..))
#endif
import Data.Functor.Identity (Identity(..))
#if defined(VERSION_unordered_containers)
import Data.Hashable (Hashable)
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HashMap
import Data.HashSet (HashSet)
import qualified Data.HashSet as HashSet
#endif
import Data.Int (Int, Int8, Int16, Int32, Int64)
import qualified Data.List as List
import Data.Maybe (Maybe(..))
#if MIN_VERSION_base(4,12,0)
import Data.Monoid (Ap(..))
#endif
#if defined(VERSION_containers)
#if MIN_VERSION_base(4,7,0)
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
#endif
import Data.Map (Map)
import qualified Data.Map as Map
#endif
import Data.Monoid (Monoid(..), Dual(..))
import Data.Ord (Ord)
#if MIN_VERSION_base(4,6,0)
import Data.Ord (Down(..))
#endif
import Data.Proxy (Proxy(..))
import Data.Ratio (Ratio, Rational, (%))
import Data.Semigroup (Semigroup(..))
#if defined(VERSION_containers)
import Data.Set (Set)
import qualified Data.Set as Set
#endif
import Data.Traversable (Traversable)
import Data.Typeable (Typeable)
#if defined(VERSION_vector)
import Data.Bool (otherwise)
import Data.Ord (Ordering(..), compare, min, max)
import Data.Vector (Vector)
import qualified Data.Vector as Vector
import qualified Data.Vector.Storable as SV
import qualified Data.Vector.Unboxed as UV
#endif
import Data.Word (Word, Word8, Word16, Word32, Word64)
import Foreign.C.Types
(CChar, CClock, CDouble, CFloat, CInt,
CIntMax, CIntPtr, CLLong, CLong,
CPtrdiff, CSChar, CSUSeconds, CShort,
CSigAtomic, CSize, CTime, CUChar, CUInt,
CUIntMax, CUIntPtr, CULLong, CULong,
CUSeconds, CUShort, CWchar)
import Foreign.Ptr (IntPtr, WordPtr)
import Foreign.Storable (Storable)
import GHC.Base (build)
import GHC.Enum (Enum, Bounded)
import GHC.Float (Float, Double)
#if MIN_VERSION_base(4,6,1)
import GHC.Generics (Generic,Generic1)
#endif
import GHC.IO (IO)
import GHC.Integer (Integer)
import qualified GHC.Num as Num
import GHC.Read (Read)
import GHC.Real (Integral, Fractional, Real, RealFrac)
import GHC.Show (Show)
import Numeric.Natural (Natural)
import System.Posix.Types
(CCc, CDev, CGid, CIno, CMode, CNlink,
COff, CPid, CRLim, CSpeed, CSsize,
CTcflag, CUid, Fd)
-- Operator fixities follow the usual numeric conventions: multiplication
-- binds tighter (7) than addition and subtraction (6), and exponentiation
-- binds tightest (8) and associates to the right. The backticked method
-- names share the fixity of their operator aliases.
infixl 7 *, `times`
infixl 6 +, `plus`, -, `minus`
infixr 8 ^
{-# SPECIALISE [1] (^) ::
      Integer -> Integer -> Integer,
      Integer -> Int -> Integer,
      Int -> Int -> Int #-}
{-# INLINABLE [1] (^) #-}
-- | Raise a value to a non-negative integral power by repeated 'times',
-- computed with 'stimes' on the 'Mul' wrapper.
--
-- @x ^ 0@ is 'one' for every @x@. This agrees with the @^0@ rewrite
-- rules and no longer depends on those rules firing: the default
-- 'stimes' rejects a multiplier of 0. Negative exponents are still
-- passed to 'stimes'. With the default 'Semigroup' implementation,
-- 'stimes' raises an error for them.
(^) :: (Semiring a, Integral b) => a -> b -> a
x ^ y
  | y == 0 = one
  | otherwise = getMul (stimes y (Mul x))
-- Rewrite small constant exponents (0 to 5, for Int and Integer
-- exponents) into explicit products so they skip the general path.
-- The let binding shares the base, so it is evaluated at most once.
{-# RULES
"^0/Int" forall x. x ^ (0 :: Int) = one
"^1/Int" forall x. x ^ (1 :: Int) = let u = x in u
"^2/Int" forall x. x ^ (2 :: Int) = let u = x in u*u
"^3/Int" forall x. x ^ (3 :: Int) = let u = x in u*u*u
"^4/Int" forall x. x ^ (4 :: Int) = let u = x in u*u*u*u
"^5/Int" forall x. x ^ (5 :: Int) = let u = x in u*u*u*u*u
"^0/Integer" forall x. x ^ (0 :: Integer) = one
"^1/Integer" forall x. x ^ (1 :: Integer) = let u = x in u
"^2/Integer" forall x. x ^ (2 :: Integer) = let u = x in u*u
"^3/Integer" forall x. x ^ (3 :: Integer) = let u = x in u*u*u
"^4/Integer" forall x. x ^ (4 :: Integer) = let u = x in u*u*u*u
"^5/Integer" forall x. x ^ (5 :: Integer) = let u = x in u*u*u*u*u
#-}
-- | Infix synonym for 'plus'.
(+) :: Semiring a => a -> a -> a
x + y = plus x y
{-# INLINE (+) #-}

-- | Infix synonym for 'times'.
(*) :: Semiring a => a -> a -> a
x * y = times x y
{-# INLINE (*) #-}

-- | Infix synonym for 'minus'.
(-) :: Ring a => a -> a -> a
x - y = minus x y
{-# INLINE (-) #-}
-- | Map every element into a semiring and add the results with 'plus',
-- using a right fold that starts from 'zero'.
foldMapP :: (Foldable t, Semiring s) => (a -> s) -> t a -> s
foldMapP f = Foldable.foldr (\x acc -> f x `plus` acc) zero
{-# INLINE foldMapP #-}

-- | Map every element into a semiring and multiply the results with
-- 'times', using a right fold that starts from 'one'.
foldMapT :: (Foldable t, Semiring s) => (a -> s) -> t a -> s
foldMapT f = Foldable.foldr (\x acc -> f x `times` acc) one
{-# INLINE foldMapT #-}
-- The (#.) helper composes a newtype unwrapper with another function.
-- When Data.Coerce is available it is a zero-cost coerce; otherwise it
-- is ordinary function composition.
#if MIN_VERSION_base(4,7,0)
infixr 9 #.
(#.) :: Coercible b c => (b -> c) -> (a -> b) -> a -> c
(#.) _ = coerce
#else
infixr 9 #.
(#.) :: (b -> c) -> (a -> b) -> a -> c
(#.) = (.)
#endif

-- | Add up every element with 'plus'; an empty container gives 'zero'.
sum :: (Foldable t, Semiring a) => t a -> a
sum = getAdd #. foldMap Add
{-# INLINE sum #-}

-- | Multiply every element with 'times'; an empty container gives 'one'.
product :: (Foldable t, Semiring a) => t a -> a
product = getMul #. foldMap Mul
{-# INLINE product #-}
-- | Strict left-fold version of 'sum': accumulates with 'plus',
-- starting from 'zero'.
sum' :: (Foldable t, Semiring a) => t a -> a
sum' = Foldable.foldl' (\acc x -> acc `plus` x) zero
{-# INLINE sum' #-}

-- | Strict left-fold version of 'product': accumulates with 'times',
-- starting from 'one'.
product' :: (Foldable t, Semiring a) => t a -> a
product' = Foldable.foldl' (\acc x -> acc `times` x) one
{-# INLINE product' #-}
-- | Monoid under addition: '<>' is 'plus' and 'mempty' is 'zero'.
newtype Add a = Add { getAdd :: a }
  deriving
    ( Bounded
    , Enum
    , Eq
    , Foldable
    , Fractional
    , Functor
#if MIN_VERSION_base(4,6,1)
    , Generic
    , Generic1
#endif
    , Num.Num
    , Ord
    , Read
    , Real
    , RealFrac
    , Show
    , Storable
    , Traversable
    , Typeable
    )

instance Semiring a => Semigroup (Add a) where
  (<>) (Add x) (Add y) = Add (x `plus` y)
  {-# INLINE (<>) #-}

instance Semiring a => Monoid (Add a) where
  mempty = Add zero
  mappend x y = x <> y
  {-# INLINE mempty #-}
  {-# INLINE mappend #-}
-- | Monoid under multiplication: '<>' is 'times' and 'mempty' is 'one'.
newtype Mul a = Mul { getMul :: a }
  deriving
    ( Bounded
    , Enum
    , Eq
    , Foldable
    , Fractional
    , Functor
#if MIN_VERSION_base(4,6,1)
    , Generic
    , Generic1
#endif
    , Num.Num
    , Ord
    , Read
    , Real
    , RealFrac
    , Show
    , Storable
    , Traversable
    , Typeable
    )

instance Semiring a => Semigroup (Mul a) where
  (<>) (Mul x) (Mul y) = Mul (x `times` y)
  {-# INLINE (<>) #-}

instance Semiring a => Monoid (Mul a) where
  mempty = Mul one
  mappend x y = x <> y
  {-# INLINE mempty #-}
  {-# INLINE mappend #-}
-- | Wrapper that turns any 'Num.Num' type into a 'Semiring' and a 'Ring'.
-- The methods reuse the numeric operations: 'plus' is @+@, 'times' is
-- @*@, 'zero' and 'one' are the literals 0 and 1, and 'negate' is the
-- numeric negation.
newtype WrappedNum a = WrapNum { unwrapNum :: a }
  deriving
    ( Bounded
    , Enum
    , Eq
    , Foldable
    , Fractional
    , Functor
#if MIN_VERSION_base(4,6,1)
    , Generic
    , Generic1
#endif
    , Num.Num
    , Ord
    , Read
    , Real
    , RealFrac
    , Show
    , Storable
    , Traversable
    , Typeable
    )

-- The methods are marked INLINE like every other instance in this module.
instance Num.Num a => Semiring (WrappedNum a) where
  plus = (Num.+)
  zero = 0
  times = (Num.*)
  one = 1
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance Num.Num a => Ring (WrappedNum a) where
  negate = Num.negate
  {-# INLINE negate #-}
-- | A type with an addition 'plus', whose identity is 'zero', and a
-- multiplication 'times', whose identity is 'one'.
--
-- NOTE(review): no laws are written down in this module. The instances
-- presumably assume the usual semiring laws, for example 'times'
-- distributing over 'plus' and 'zero' annihilating under 'times'.
-- Confirm and document them.
class Semiring a where
#if __GLASGOW_HASKELL__ >= 708
  {-# MINIMAL plus, zero, times, one #-}
#endif
  -- | Addition and multiplication.
  plus, times :: a -> a -> a
  -- | Identities of 'plus' and 'times' respectively.
  zero, one :: a

-- | A 'Semiring' extended with 'negate', which is intended to give
-- additive inverses so that subtraction ('minus') is available.
class Semiring a => Ring a where
#if __GLASGOW_HASKELL__ >= 708
  {-# MINIMAL negate #-}
#endif
  -- | Additive inverse.
  negate :: a -> a

-- | Subtract the second argument from the first by adding its 'negate'.
minus :: Ring a => a -> a -> a
minus x y = plus x (negate y)
{-# INLINE minus #-}
-- | Pointwise semiring on functions: each operation is applied to the
-- results the functions give for the same argument.
instance Semiring b => Semiring (a -> b) where
  zero = \_ -> zero
  one = \_ -> one
  plus f g = \x -> plus (f x) (g x)
  times f g = \x -> times (f x) (g x)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance Ring b => Ring (a -> b) where
  negate f = negate . f
  {-# INLINE negate #-}
-- | The trivial semiring: every operation returns () without inspecting
-- its arguments.
instance Semiring () where
  zero = ()
  one = ()
  plus = \_ _ -> ()
  times = \_ _ -> ()
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance Ring () where
  negate = const ()
  {-# INLINE negate #-}
-- | The trivial semiring on 'Proxy': every operation returns 'Proxy'
-- without inspecting its arguments.
instance Semiring (Proxy a) where
  zero = Proxy
  one = Proxy
  plus _ _ = Proxy
  times _ _ = Proxy
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

-- | 'Proxy' has a single value, so it is also a trivial ring, exactly
-- like the unit type.
instance Ring (Proxy a) where
  negate _ = Proxy
  {-# INLINE negate #-}
-- | Boolean semiring: 'plus' is disjunction with identity 'False', and
-- 'times' is conjunction with identity 'True'.
instance Semiring Bool where
  zero = False
  one = True
  plus x y = x || y
  times x y = x && y
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

-- NOTE(review): with plus = (||), negate = not does not produce additive
-- inverses (x `plus` not x is True, not zero), so this instance breaks the
-- usual ring law. It is kept unchanged for compatibility.
instance Ring Bool where
  negate x = not x
  {-# INLINE negate #-}
-- | Lists as polynomial coefficients. 'plus' adds element-wise and keeps
-- the tail of the longer list. 'times' is the convolution product, and
-- 'one' is the constant polynomial [one].
instance Semiring a => Semiring [a] where
  zero = []
  one = [one]
  plus xs ys = listAdd xs ys
  times xs ys = listTimes xs ys
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance Ring a => Ring [a] where
  negate xs = List.map negate xs
  {-# INLINE negate #-}
-- | 'Nothing' is the additive identity and annihilates under 'times'.
-- Two 'Just' values combine their contents.
instance Semiring a => Semiring (Maybe a) where
  zero = Nothing
  one = Just one
  plus mx my = case mx of
    Nothing -> my
    Just x -> case my of
      Nothing -> mx
      Just y -> Just (plus x y)
  times mx my = case mx of
    Nothing -> Nothing
    Just x -> case my of
      Nothing -> Nothing
      Just y -> Just (times x y)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance Ring a => Ring (Maybe a) where
  negate m = case m of
    Nothing -> Nothing
    Just x -> Just (negate x)
  {-# INLINE negate #-}
-- | Run both actions in order and combine their results; the identities
-- are returned with 'pure'.
instance Semiring a => Semiring (IO a) where
  zero = pure zero
  one = pure one
  plus x y = liftA2 plus x y
  times x y = liftA2 times x y
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance Ring a => Ring (IO a) where
  negate act = fmap negate act
  {-# INLINE negate #-}
-- | 'Dual' swaps the argument order of both 'plus' and 'times'.
instance Semiring a => Semiring (Dual a) where
  zero = Dual zero
  one = Dual one
  plus (Dual x) (Dual y) = Dual (plus y x)
  times (Dual x) (Dual y) = Dual (times y x)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance Ring a => Ring (Dual a) where
  negate = Dual . negate . getDual
  {-# INLINE negate #-}
-- | 'Const' reuses the semiring of the value it carries; the phantom
-- parameter is ignored.
instance Semiring a => Semiring (Const a b) where
  zero = Const zero
  one = Const one
  plus x y = Const (getConst x `plus` getConst y)
  times x y = Const (getConst x `times` getConst y)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance Ring a => Ring (Const a b) where
  negate = Const . negate . getConst
  {-# INLINE negate #-}
-- | Complex numbers over a ring. Addition is component-wise, and the
-- product is (a + bi)(c + di) = (ac - bd) + (ad + bc)i. A 'Ring' is
-- required because the real part subtracts.
instance Ring a => Semiring (Complex a) where
  zero = zero :+ zero
  one = one :+ zero
  plus (a :+ b) (c :+ d) = plus a c :+ plus b d
  times (a :+ b) (c :+ d) = re :+ im
    where
      re = a * c - b * d
      im = a * d + b * c
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance Ring a => Ring (Complex a) where
  negate (a :+ b) = negate a :+ negate b
  {-# INLINE negate #-}
#if MIN_VERSION_base(4,12,0)
-- | Lift a semiring through an 'Applicative': results are combined with
-- 'liftA2', and 'zero' and 'one' are injected with 'pure'.
instance (Semiring a, Applicative f) => Semiring (Ap f a) where
  zero = pure zero
  one = pure one
  plus x y = liftA2 plus x y
  times x y = liftA2 times x y
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance (Ring a, Applicative f) => Ring (Ap f a) where
  negate x = fmap negate x
  {-# INLINE negate #-}

-- The Data.Functor.Contravariant newtypes reuse the pointwise instances
-- of the function types they wrap, through newtype deriving.
-- Predicate wraps a -> Bool and Equivalence wraps a -> a -> Bool. Both
-- only need the Bool instances, so no constraint is placed on their
-- argument type. Previously Equivalence demanded Semiring a, which
-- excluded element types without one, such as String.
-- Op a b wraps b -> a, so its result type a must itself be a semiring.
deriving instance Semiring (Predicate a)
deriving instance Ring (Predicate a)
deriving instance Semiring (Equivalence a)
deriving instance Ring (Equivalence a)
deriving instance Semiring a => Semiring (Op a b)
deriving instance Ring a => Ring (Op a b)
#endif
-- deriveSemiring(ty) defines a Semiring instance for a type that has a
-- Num instance: zero and one are the literals 0 and 1, and plus and
-- times are Num.+ and Num.* respectively. The expansion uses explicit
-- braces and semicolons because a CPP macro expands onto a single line.
-- It is applied below to the fixed-width integral and floating types,
-- Integer, Natural, and the C, POSIX and pointer-sized numeric types.
#define deriveSemiring(ty) \
instance Semiring (ty) where { \
zero = 0 \
; one = 1 \
; plus x y = (Num.+) x y \
; times x y = (Num.*) x y \
; {-# INLINE zero #-} \
; {-# INLINE one #-} \
; {-# INLINE plus #-} \
; {-# INLINE times #-} \
}
deriveSemiring(Int)
deriveSemiring(Int8)
deriveSemiring(Int16)
deriveSemiring(Int32)
deriveSemiring(Int64)
deriveSemiring(Integer)
deriveSemiring(Word)
deriveSemiring(Word8)
deriveSemiring(Word16)
deriveSemiring(Word32)
deriveSemiring(Word64)
deriveSemiring(Float)
deriveSemiring(Double)
deriveSemiring(CUIntMax)
deriveSemiring(CIntMax)
deriveSemiring(CUIntPtr)
deriveSemiring(CIntPtr)
deriveSemiring(CSUSeconds)
deriveSemiring(CUSeconds)
deriveSemiring(CTime)
deriveSemiring(CClock)
deriveSemiring(CSigAtomic)
deriveSemiring(CWchar)
deriveSemiring(CSize)
deriveSemiring(CPtrdiff)
deriveSemiring(CDouble)
deriveSemiring(CFloat)
deriveSemiring(CULLong)
deriveSemiring(CLLong)
deriveSemiring(CULong)
deriveSemiring(CLong)
deriveSemiring(CUInt)
deriveSemiring(CInt)
deriveSemiring(CUShort)
deriveSemiring(CShort)
deriveSemiring(CUChar)
deriveSemiring(CSChar)
deriveSemiring(CChar)
deriveSemiring(IntPtr)
deriveSemiring(WordPtr)
deriveSemiring(Fd)
deriveSemiring(CRLim)
deriveSemiring(CTcflag)
deriveSemiring(CSpeed)
deriveSemiring(CCc)
deriveSemiring(CUid)
deriveSemiring(CNlink)
deriveSemiring(CGid)
deriveSemiring(CSsize)
deriveSemiring(CPid)
deriveSemiring(COff)
deriveSemiring(CMode)
deriveSemiring(CIno)
deriveSemiring(CDev)
deriveSemiring(Natural)
-- | Ratios form a semiring through their 'Num.Num' instance; the
-- identities are the normalised ratios 0 % 1 and 1 % 1.
instance Integral a => Semiring (Ratio a) where
  {-# SPECIALIZE instance Semiring Rational #-}
  plus x y = x Num.+ y
  times x y = x Num.* y
  zero = 0 % 1
  one = 1 % 1
  {-# INLINE plus #-}
  {-# INLINE times #-}
  {-# INLINE zero #-}
  {-# INLINE one #-}
-- Identity and Down are newtypes, so they reuse the semiring of the
-- wrapped type through newtype deriving.
#if MIN_VERSION_base(4,6,0)
deriving instance Semiring a => Semiring (Down a)
#endif
deriving instance Semiring a => Semiring (Identity a)

-- | Fixed-precision numbers form a semiring through their 'Num.Num'
-- instance.
instance HasResolution a => Semiring (Fixed a) where
  zero = 0
  one = 1
  plus x y = x Num.+ y
  times x y = x Num.* y
  {-# INLINE zero #-}
  {-# INLINE one #-}
  {-# INLINE plus #-}
  {-# INLINE times #-}
-- deriveRing(ty) defines a Ring instance whose negate is Num.negate.
-- For the unsigned types below (Word, Word8 to Word64, the unsigned C
-- types and WordPtr) this presumably relies on wrap-around negation
-- modulo the word size. Confirm this for each type.
#define deriveRing(ty) \
instance Ring (ty) where { \
negate = Num.negate \
; {-# INLINE negate #-} \
}
deriveRing(Int)
deriveRing(Int8)
deriveRing(Int16)
deriveRing(Int32)
deriveRing(Int64)
deriveRing(Integer)
deriveRing(Word)
deriveRing(Word8)
deriveRing(Word16)
deriveRing(Word32)
deriveRing(Word64)
deriveRing(Float)
deriveRing(Double)
deriveRing(CUIntMax)
deriveRing(CIntMax)
deriveRing(CUIntPtr)
deriveRing(CIntPtr)
deriveRing(CSUSeconds)
deriveRing(CUSeconds)
deriveRing(CTime)
deriveRing(CClock)
deriveRing(CSigAtomic)
deriveRing(CWchar)
deriveRing(CSize)
deriveRing(CPtrdiff)
deriveRing(CDouble)
deriveRing(CFloat)
deriveRing(CULLong)
deriveRing(CLLong)
deriveRing(CULong)
deriveRing(CLong)
deriveRing(CUInt)
deriveRing(CInt)
deriveRing(CUShort)
deriveRing(CShort)
deriveRing(CUChar)
deriveRing(CSChar)
deriveRing(CChar)
deriveRing(IntPtr)
deriveRing(WordPtr)
deriveRing(Fd)
deriveRing(CRLim)
deriveRing(CTcflag)
deriveRing(CSpeed)
deriveRing(CCc)
deriveRing(CUid)
deriveRing(CNlink)
deriveRing(CGid)
deriveRing(CSsize)
deriveRing(CPid)
deriveRing(COff)
deriveRing(CMode)
deriveRing(CIno)
deriveRing(CDev)
-- NOTE(review): Num.negate on Natural is expected to throw an arithmetic
-- underflow for every nonzero argument. So negate, and minus (which is
-- defined through negate), would fail here except on 0, even when the
-- difference is representable. Removing this instance would be a
-- breaking change; confirm before relying on it.
deriveRing(Natural)
-- | Ratios negate through their 'Num.Num' instance.
instance Integral a => Ring (Ratio a) where
  negate x = Num.negate x
  {-# INLINE negate #-}

-- | Fixed-precision numbers negate through their 'Num.Num' instance.
instance HasResolution a => Ring (Fixed a) where
  negate x = Num.negate x
  {-# INLINE negate #-}

-- Identity and Down reuse the ring of the wrapped type.
deriving instance Ring a => Ring (Identity a)
#if MIN_VERSION_base(4,6,0)
deriving instance Ring a => Ring (Down a)
#endif
#if defined(VERSION_containers)
-- | Sets as a semiring.
--
-- * 'plus' is union and 'zero' is the empty set.
-- * 'one' contains only 'mempty'.
-- * 'times' combines every element of the left set with every element
--   of the right set using 'mappend'.
instance (Ord a, Monoid a) => Semiring (Set a) where
  zero = Set.empty
  one = Set.singleton mempty
  plus = Set.union
  times xs ys = Foldable.foldMap shiftBy xs
    where
      -- All elements of ys, each combined with x on the left via mappend.
      shiftBy = flip Set.map ys . mappend
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
#if MIN_VERSION_base(4,7,0)
-- | A set of values represented as 'Int', backed by an 'IntSet'. The type
-- parameter @a@ is phantom. The 'Semiring' instance requires it to be
-- 'Coercible' to 'Int', and mirrors the 'Set' instance.
newtype IntSetOf a = IntSetOf { getIntSet :: IntSet }
  deriving
    ( Eq
#if MIN_VERSION_base(4,6,1)
    , Generic
    , Generic1
#endif
    , Ord
    , Read
    , Show
    , Typeable
    , Semigroup
    , Monoid
    )

-- | 'plus' is union and 'zero' is the empty set. 'one' contains only
-- 'mempty', and 'times' combines every pair of elements with 'mappend'.
instance (Coercible Int a, Monoid a) => Semiring (IntSetOf a) where
  zero = coerce IntSet.empty
  one = coerce IntSet.singleton (mempty :: a)
  plus = coerce IntSet.union
  xs `times` ys =
    coerce IntSet.fromList
      [ mappend x y
      | x :: a <- coerce IntSet.toList xs
      , y :: a <- coerce IntSet.toList ys
      ]
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
#endif
-- | Maps as formal sums of keys weighted by values.
--
-- * 'plus' merges two maps and adds the values of shared keys.
-- * 'times' pairs every entry of the left map with every entry of the
--   right map. It combines keys with 'mappend' and values with 'times',
--   then adds the values of combined keys that collide.
-- * 'one' maps 'mempty' to 'one'.
instance (Ord k, Monoid k, Semiring v) => Semiring (Map k v) where
  zero = Map.empty
  one = Map.singleton mempty one
  plus = Map.unionWith plus
  xs `times` ys = Map.fromListWith plus
    [ (k1 `mappend` k2, v1 `times` v2)
    | (k1, v1) <- Map.toList xs
    , (k2, v2) <- Map.toList ys
    ]
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
#if MIN_VERSION_base(4,7,0)
-- | A map whose keys are represented as 'Int', backed by an 'IntMap'.
-- The key type @k@ is phantom. The 'Semiring' instance requires it to be
-- 'Coercible' to 'Int', and mirrors the 'Map' instance.
newtype IntMapOf k v = IntMapOf { getIntMap :: IntMap v }
  deriving
    ( Eq
#if MIN_VERSION_base(4,6,1)
    , Generic
    , Generic1
#endif
    , Ord
    , Read
    , Show
    , Typeable
    , Semigroup
    , Monoid
    )

-- | 'plus' merges maps and adds the values of shared keys. 'times' pairs
-- every entry of the left map with every entry of the right map. It
-- combines keys with 'mappend' and values with 'times', and adds the
-- values of combined keys that collide.
instance (Coercible Int k, Monoid k, Semiring v) => Semiring (IntMapOf k v) where
  zero = coerce (IntMap.empty :: IntMap v)
  one = coerce (IntMap.singleton :: Int -> v -> IntMap v) (mempty :: k) (one :: v)
  plus = coerce (IntMap.unionWith (+) :: IntMap v -> IntMap v -> IntMap v)
  xs `times` ys =
    coerce (IntMap.fromListWith (+) :: [(Int, v)] -> IntMap v)
      [ (mappend i j, x * y)
      | (i :: k, x :: v) <- coerce (IntMap.toList :: IntMap v -> [(Int, v)]) xs
      , (j :: k, y :: v) <- coerce (IntMap.toList :: IntMap v -> [(Int, v)]) ys
      ]
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
#endif
#endif
#if defined(VERSION_unordered_containers)
-- | Hash sets as a semiring.
--
-- * 'plus' is union and 'zero' is the empty set.
-- * 'one' contains only 'mempty'.
-- * 'times' combines every element of the left set with every element
--   of the right set using 'mappend'.
instance (Eq a, Hashable a, Monoid a) => Semiring (HashSet a) where
  zero = HashSet.empty
  one = HashSet.singleton mempty
  plus = HashSet.union
  times xs ys = Foldable.foldMap shiftBy xs
    where
      -- All elements of ys, each combined with x on the left via mappend.
      shiftBy = flip HashSet.map ys . mappend
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

-- | Hash maps as formal sums of keys weighted by values.
--
-- * 'plus' merges maps and adds the values of shared keys.
-- * 'times' pairs every entry of the left map with every entry of the
--   right map. It combines keys with 'mappend' and values with 'times',
--   then adds the values of combined keys that collide.
instance (Eq k, Hashable k, Monoid k, Semiring v) => Semiring (HashMap k v) where
  zero = HashMap.empty
  one = HashMap.singleton mempty one
  plus = HashMap.unionWith plus
  xs `times` ys = HashMap.fromListWith plus
    [ (k1 `mappend` k2, v1 `times` v2)
    | (k1, v1) <- HashMap.toList xs
    , (k2, v2) <- HashMap.toList ys
    ]
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
#endif
#if defined(VERSION_primitive)
#endif
#if defined(VERSION_vector)
-- | Boxed vectors as polynomials, where element @i@ is the coefficient of
-- degree @i@. 'plus' adds coefficient-wise, and the result has the
-- length of the longer operand. 'times' is the discrete convolution, and
-- an empty operand gives an empty result.
instance Semiring a => Semiring (Vector a) where
  zero = Vector.empty
  one = Vector.singleton one
  plus xs ys = case compare (Vector.length xs) (Vector.length ys) of
    EQ -> Vector.zipWith (+) xs ys
    -- Shorter left operand: fold each element of xs into ys at its index.
    LT -> Vector.unsafeAccumulate (+) ys (Vector.indexed xs)
    -- Shorter right operand: fold each element of ys into xs at its index.
    GT -> Vector.unsafeAccumulate (+) xs (Vector.indexed ys)
  times p q
    | Vector.null p || Vector.null q = Vector.empty
    | otherwise = Vector.generate (np + nq - 1) coeff
    where
      !np = Vector.length p
      !nq = Vector.length q
      -- Coefficient n sums p[k] * q[n - k] over the k for which both
      -- indices are in range, which keeps the unchecked indexing safe.
      coeff n = Foldable.foldl' step zero [lo .. hi]
        where
          step acc k = acc + Vector.unsafeIndex p k * Vector.unsafeIndex q (n - k)
          !lo = max 0 (n - (nq - 1))
          !hi = min n (np - 1)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance Ring a => Ring (Vector a) where
  negate v = Vector.map negate v
  {-# INLINE negate #-}
-- | Unboxed vectors as polynomials, where element @i@ is the coefficient
-- of degree @i@. 'plus' adds coefficient-wise, and the result has the
-- length of the longer operand. 'times' is the discrete convolution, and
-- an empty operand gives an empty result.
instance (UV.Unbox a, Semiring a) => Semiring (UV.Vector a) where
  zero = UV.empty
  one = UV.singleton one
  plus xs ys = case compare (UV.length xs) (UV.length ys) of
    EQ -> UV.zipWith (+) xs ys
    -- Shorter left operand: fold each element of xs into ys at its index.
    LT -> UV.unsafeAccumulate (+) ys (UV.indexed xs)
    -- Shorter right operand: fold each element of ys into xs at its index.
    GT -> UV.unsafeAccumulate (+) xs (UV.indexed ys)
  times p q
    | UV.null p || UV.null q = UV.empty
    | otherwise = UV.generate (np + nq - 1) coeff
    where
      !np = UV.length p
      !nq = UV.length q
      -- Coefficient n sums p[k] * q[n - k] over the k for which both
      -- indices are in range, which keeps the unchecked indexing safe.
      coeff n = Foldable.foldl' step zero [lo .. hi]
        where
          step acc k = acc + UV.unsafeIndex p k * UV.unsafeIndex q (n - k)
          !lo = max 0 (n - (nq - 1))
          !hi = min n (np - 1)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance (UV.Unbox a, Ring a) => Ring (UV.Vector a) where
  negate v = UV.map negate v
  {-# INLINE negate #-}
-- | Storable vectors as polynomials, where element @i@ is the coefficient
-- of degree @i@. 'plus' adds coefficient-wise, and the result has the
-- length of the longer operand. 'times' is the discrete convolution, and
-- an empty operand gives an empty result.
instance (SV.Storable a, Semiring a) => Semiring (SV.Vector a) where
  zero = SV.empty
  one = SV.singleton one
  plus xs ys = case compare nx ny of
      EQ -> SV.zipWith (+) xs ys
      -- Shorter left operand: add xs into the first nx slots of ys.
      LT -> SV.unsafeAccumulate_ (+) ys (SV.enumFromN 0 nx) xs
      -- Shorter right operand: add ys into the first ny slots of xs.
      GT -> SV.unsafeAccumulate_ (+) xs (SV.enumFromN 0 ny) ys
    where
      nx = SV.length xs
      ny = SV.length ys
  times p q
    | SV.null p || SV.null q = SV.empty
    | otherwise = SV.generate (np + nq - 1) coeff
    where
      !np = SV.length p
      !nq = SV.length q
      -- Coefficient n sums p[k] * q[n - k] over the k for which both
      -- indices are in range, which keeps the unchecked indexing safe.
      coeff n = Foldable.foldl' step zero [lo .. hi]
        where
          step acc k = acc + SV.unsafeIndex p k * SV.unsafeIndex q (n - k)
          !lo = max 0 (n - (nq - 1))
          !hi = min n (np - 1)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}

instance (SV.Storable a, Ring a) => Ring (SV.Vector a) where
  negate v = SV.map negate v
  {-# INLINE negate #-}
#endif
-- listAdd: coefficient-wise sum of two lists. When one list runs out,
-- the remainder of the other list is returned unchanged.
-- listTimes: convolution product of two coefficient lists. Coefficient n
-- of the result is the sum of x_i `times` y_j over all i + j = n.
listAdd, listTimes :: Semiring a => [a] -> [a] -> [a]
listAdd left right = case left of
  [] -> right
  (x:xs) -> case right of
    [] -> left
    (y:ys) -> (x + y) : listAdd xs ys
{-# NOINLINE [0] listAdd #-}

listTimes _ [] = []
listTimes xs ys = List.foldr addRow [] xs
  where
    -- Multiply ys by the single coefficient x and add the product of the
    -- remaining coefficients (acc), shifted one position by the leading zero.
    addRow x acc = List.foldr (cell x) id ys (zero : acc)
    -- Walk ys alongside the running total, passing the rest to k.
    cell x y k rest = case rest of
      [] -> x `times` y : k []
      (z:zs) -> x `times` y `plus` z : k zs
{-# NOINLINE [0] listTimes #-}
-- | A list in build form: given a cons and a nil, it produces the right
-- fold of the list.
type ListBuilder a = forall b. (a -> b -> b) -> b -> b
-- When either argument of listAdd is produced by build, consume the
-- producer directly through listAddFBL or listAddFBR instead of first
-- materialising the list. listAdd is NOINLINE [0], presumably so that
-- these rules can match before it is inlined.
{-# RULES
"listAddFB/left" forall (g :: ListBuilder a). listAdd (build g) = listAddFBL g
"listAddFB/right" forall xs (g :: ListBuilder a). listAdd xs (build g) = listAddFBR xs g
#-}
-- | 'listAdd' with its left operand given in build form. The producer
-- runs with a cons that consumes the right list in step.
listAddFBL :: Semiring a => ListBuilder a -> [a] -> [a]
listAddFBL produce = produce cons id
  where
    -- The nil is id: once the producer is exhausted, the remaining right
    -- list is returned unchanged.
    cons x k rest = case rest of
      (y:ys) -> x + y : k ys
      [] -> x : k []
-- | 'listAdd' with its right operand given in build form. The producer
-- runs with a cons that consumes the left list in step, keeping the
-- left operand first in each sum.
listAddFBR :: Semiring a => [a] -> ListBuilder a -> [a]
listAddFBR left produce = produce cons id left
  where
    -- The nil is id: once the producer is exhausted, the remaining left
    -- list is returned unchanged.
    cons y k rest = case rest of
      (x:xs) -> x + y : k xs
      [] -> y : k []