#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
#define USE_GHC_GENERICS
#endif
module Linear.Vector
( Additive(..)
, negated
, (^*)
, (*^)
, (^/)
, sumV
, basis
, basisFor
, kronecker
, outer
) where
import Control.Applicative
import Data.Complex
import Data.Foldable as Foldable (Foldable, forM_, foldl')
import Data.Functor.Identity
import Data.HashMap.Lazy as HashMap
import Data.Hashable
import Data.IntMap as IntMap
import Data.Map as Map
import Data.Monoid (mempty)
import Data.Vector as Vector
import Data.Vector.Mutable as Mutable
import Data.Traversable (Traversable, traverse, mapAccumL)
#ifdef USE_GHC_GENERICS
import GHC.Generics
#endif
import Linear.Instances ()
-- Vector addition and subtraction share the precedence of (+) and (-).
infixl 6 ^+^
infixl 6 ^-^
-- Scalar multiplication and division share the precedence of (*) and (/).
infixl 7 *^
infixl 7 ^*
infixl 7 ^/
#ifdef USE_GHC_GENERICS
-- | Counterpart of 'Additive' for 'GHC.Generics' representation types.
-- The default 'zero' of 'Additive' is built from 'gzero' via 'to1'.
class GAdditive f where
  -- | The zero value of a representation.
  gzero :: Num a => f a
  -- | Component-wise combination of two representations (union flavour).
  gliftU2 :: (a -> a -> a) -> f a -> f a -> f a
  -- | Component-wise combination of two representations (intersection flavour).
  gliftI2 :: (a -> b -> c) -> f a -> f b -> f c
-- | Nullary representation: there is nothing to combine.
instance GAdditive U1 where
  gzero = U1
  gliftU2 _ l r = case (l, r) of (U1, U1) -> U1
  gliftI2 _ l r = case (l, r) of (U1, U1) -> U1
-- | Products: the left and right factors are combined independently.
instance (GAdditive f, GAdditive g) => GAdditive (f :*: g) where
  gzero = gzero :*: gzero
  gliftU2 op (l1 :*: r1) (l2 :*: r2) = gliftU2 op l1 l2 :*: gliftU2 op r1 r2
  gliftI2 op (l1 :*: r1) (l2 :*: r2) = gliftI2 op l1 l2 :*: gliftI2 op r1 r2
-- | Recursive occurrences of an 'Additive' functor delegate to its instance.
instance Additive f => GAdditive (Rec1 f) where
  gzero = Rec1 zero
  gliftU2 op x y = Rec1 (liftU2 op (unRec1 x) (unRec1 y))
  gliftI2 op x y = Rec1 (liftI2 op (unRec1 x) (unRec1 y))
-- | Metadata wrappers are transparent: unwrap, combine, re-wrap.
instance GAdditive f => GAdditive (M1 i c f) where
  gzero = M1 gzero
  gliftU2 op x y = M1 (gliftU2 op (unM1 x) (unM1 y))
  gliftI2 op x y = M1 (gliftI2 op (unM1 x) (unM1 y))
-- | A single parameter occurrence: apply the combining function directly.
instance GAdditive Par1 where
  gzero = Par1 0
  gliftU2 op x y = Par1 (op (unPar1 x) (unPar1 y))
  gliftI2 op x y = Par1 (op (unPar1 x) (unPar1 y))
#endif
-- | A vector is an additive group with additional structure.
--
-- 'liftU2' combines two vectors with /union/ semantics: for sparse
-- instances (e.g. 'Map', lists, 'Maybe') components present in only one
-- argument are kept. 'liftI2' uses /intersection/ semantics: only
-- components present in both arguments survive.
class Functor f => Additive f where
  -- | The zero vector.
  zero :: Num a => f a
#ifdef USE_GHC_GENERICS
#ifndef HLINT
  default zero :: (GAdditive (Rep1 f), Generic1 f, Num a) => f a
  zero = to1 gzero
#endif
#endif

  -- | Compute the sum of two vectors.
  (^+^) :: Num a => f a -> f a -> f a
  -- An ordinary default method: it only needs 'liftU2', so it requires
  -- neither a default signature nor GHC generics, and is therefore
  -- available to every instance regardless of how the module is built.
  (^+^) = liftU2 (+)

  -- | Compute the difference between two vectors.
  (^-^) :: Num a => f a -> f a -> f a
  x ^-^ y = x ^+^ negated y

  -- | Linearly interpolate between two vectors, weighting @u@ by @alpha@
  -- and @v@ by @(1 - alpha)@.
  lerp :: Num a => a -> f a -> f a -> f a
  lerp alpha u v = alpha *^ u ^+^ (1 - alpha) *^ v

  -- | Merge two vectors component-wise with union semantics.
  -- The generic default (when available) is 'liftA2'.
  liftU2 :: (a -> a -> a) -> f a -> f a -> f a
#ifdef USE_GHC_GENERICS
#ifndef HLINT
  default liftU2 :: Applicative f => (a -> a -> a) -> f a -> f a -> f a
  liftU2 = liftA2
#endif
#endif

  -- | Combine two vectors component-wise with intersection semantics.
  -- The generic default (when available) is 'liftA2'.
  liftI2 :: (a -> b -> c) -> f a -> f b -> f c
#ifdef USE_GHC_GENERICS
#ifndef HLINT
  default liftI2 :: Applicative f => (a -> b -> c) -> f a -> f b -> f c
  liftI2 = liftA2
#endif
#endif
-- | Zip lists: 'liftU2' defers to the list instance (which keeps the tail
-- of the longer list); 'liftI2' zips through the 'Applicative' instance.
instance Additive ZipList where
  zero = ZipList []
  liftU2 f xs ys = ZipList (liftU2 f (getZipList xs) (getZipList ys))
  liftI2 f xs ys = f <$> xs <*> ys
#ifndef USE_GHC_GENERICS
  u ^+^ v = liftU2 (+) u v
  u ^-^ v = u ^+^ negated v
#endif
-- | Boxed vectors as sparse vectors indexed from 0.
--
-- 'liftU2' combines the overlapping prefix and keeps the remaining elements
-- of the longer vector; 'liftI2' truncates to the shorter length.
instance Additive Vector where
  zero = mempty
  liftU2 f u v = case compare lu lv of
      -- u is shorter: copy v, then overwrite its first lu slots.
      LT | lu == 0   -> v
         | otherwise -> Vector.modify (\w -> Foldable.forM_ [0 .. lu - 1] $ \i ->
              Mutable.unsafeWrite w i $ f (Vector.unsafeIndex u i) (Vector.unsafeIndex v i)) v
      EQ -> Vector.zipWith f u v
      -- v is shorter: copy u, then overwrite its first lv slots.
      GT | lv == 0   -> u
         | otherwise -> Vector.modify (\w -> Foldable.forM_ [0 .. lv - 1] $ \i ->
              Mutable.unsafeWrite w i $ f (Vector.unsafeIndex u i) (Vector.unsafeIndex v i)) u
    where
      -- The indices 0 .. min lu lv - 1 are in bounds for both vectors,
      -- which is what makes the unchecked reads/writes above safe.
      lu = Vector.length u
      lv = Vector.length v
  liftI2 = Vector.zipWith
#ifndef USE_GHC_GENERICS
  (^+^) = liftU2 (+)
  x ^-^ y = x ^+^ negated y
#endif
-- | 'Maybe' as a sparse one-component vector: 'liftU2' keeps whichever side
-- is present, 'liftI2' needs both.
instance Additive Maybe where
  zero = Nothing
  liftU2 f mx my = case (mx, my) of
    (Just a, Just b) -> Just (f a b)
    (Nothing, _)     -> my
    (_, Nothing)     -> mx
  liftI2 f mx my = f <$> mx <*> my
#ifndef USE_GHC_GENERICS
  u ^+^ v = liftU2 (+) u v
  u ^-^ v = u ^+^ negated v
#endif
-- | Lists as sparse vectors: 'liftU2' combines the common prefix and keeps
-- the leftover tail of the longer list; 'liftI2' truncates to the shorter.
instance Additive [] where
  zero = []
  liftU2 f = zipLong
    where
      zipLong [] rest          = rest
      zipLong rest []          = rest
      zipLong (a:as') (b:bs')  = f a b : zipLong as' bs'
  liftI2 f xs ys = Prelude.zipWith f xs ys
#ifndef USE_GHC_GENERICS
  u ^+^ v = liftU2 (+) u v
  u ^-^ v = u ^+^ negated v
#endif
-- | Int-keyed sparse vectors: union of keys for 'liftU2', intersection for
-- 'liftI2'.
instance Additive IntMap where
  zero = IntMap.empty
  liftU2 f a b = IntMap.unionWith f a b
  liftI2 f a b = IntMap.intersectionWith f a b
#ifndef USE_GHC_GENERICS
  u ^+^ v = liftU2 (+) u v
  u ^-^ v = u ^+^ negated v
#endif
-- | Ordered-map sparse vectors: union of keys for 'liftU2', intersection
-- for 'liftI2'.
instance Ord k => Additive (Map k) where
  zero = Map.empty
  liftU2 f a b = Map.unionWith f a b
  liftI2 f a b = Map.intersectionWith f a b
#ifndef USE_GHC_GENERICS
  u ^+^ v = liftU2 (+) u v
  u ^-^ v = u ^+^ negated v
#endif
-- | Hash-map sparse vectors: union of keys for 'liftU2', intersection for
-- 'liftI2'.
instance (Eq k, Hashable k) => Additive (HashMap k) where
  zero = HashMap.empty
  liftU2 f a b = HashMap.unionWith f a b
  liftI2 f a b = HashMap.intersectionWith f a b
#ifndef USE_GHC_GENERICS
  u ^+^ v = liftU2 (+) u v
  u ^-^ v = u ^+^ negated v
#endif
-- | Functions as vectors indexed by their argument: operations are applied
-- pointwise.
instance Additive ((->) b) where
  zero _ = 0
  liftU2 f g h x = f (g x) (h x)
  liftI2 f g h x = f (g x) (h x)
#ifndef USE_GHC_GENERICS
  u ^+^ v = liftU2 (+) u v
  u ^-^ v = u ^+^ negated v
#endif
-- | Complex numbers as dense 2-vectors (real part, imaginary part); union
-- and intersection coincide.
instance Additive Complex where
  zero = 0 :+ 0
  liftU2 = liftI2
  liftI2 f (re1 :+ im1) (re2 :+ im2) = f re1 re2 :+ f im1 im2
#ifndef USE_GHC_GENERICS
  u ^+^ v = liftU2 (+) u v
  u ^-^ v = u ^+^ negated v
#endif
-- | 'Identity' as a dense one-component vector.
instance Additive Identity where
  zero = pure 0
  liftU2 f (Identity a) (Identity b) = Identity (f a b)
  liftI2 f (Identity a) (Identity b) = Identity (f a b)
#ifndef USE_GHC_GENERICS
  u ^+^ v = liftU2 (+) u v
  u ^-^ v = u ^+^ negated v
#endif
-- | Negate every component of a vector.
negated :: (Functor f, Num a) => f a -> f a
negated v = negate <$> v
-- | Sum a collection of vectors, starting from 'zero' and folding strictly
-- from the left with '^+^'.
sumV :: (Foldable f, Additive v, Num a) => f (v a) -> v a
sumV vs = Foldable.foldl' (\acc v -> acc ^+^ v) zero vs
-- | Scale a vector by a scalar on the left (each component becomes
-- @s * x@).
(*^) :: (Functor f, Num a) => a -> f a -> f a
s *^ v = (s *) <$> v
-- | Scale a vector by a scalar on the right (each component becomes
-- @x * s@).
(^*) :: (Functor f, Num a) => f a -> a -> f a
v ^* s = (* s) <$> v
-- | Divide every component of a vector by a scalar.
(^/) :: (Functor f, Fractional a) => f a -> a -> f a
v ^/ s = (/ s) <$> v
-- | Helper applicative behind 'basis', 'basisFor' and 'kronecker'.
--
-- When used with 'traverse', the filler is the result of using every
-- element's default. Each entry of 'choices' differs from it in exactly one
-- position, where that element's alternative is used instead.
data SetOne a = SetOne
  { _filler :: !a   -- ^ result when no alternative is selected
  , choices :: [a]  -- ^ results with exactly one alternative selected
  }

instance Functor SetOne where
  fmap g (SetOne x xs) = SetOne (g x) (Prelude.map g xs)

instance Applicative SetOne where
  pure x = SetOne x []
  -- Vary either the function side (keeping the argument's default) or the
  -- argument side (keeping the function's default), never both.
  SetOne f fs <*> SetOne a as' =
    SetOne (f a) (Prelude.map ($ a) fs Prelude.++ Prelude.map f as')
-- | The standard basis of a dense vector type: for each position (in
-- traversal order), the vector with @1@ there and @0@ everywhere else.
basis :: (Applicative t, Traversable t, Num a) => [t a]
basis = choices (traverse oneHot (pure 1))
  where
    oneHot x = SetOne 0 [x]
-- | The standard basis shaped like the given container. Its elements are
-- ignored; only its positions matter.
basisFor :: (Traversable t, Num a) => t b -> [t a]
basisFor shape = choices (traverse (const (SetOne 0 [1])) shape)
-- | Build the diagonal matrix whose diagonal is the given vector: row @i@
-- holds @v_i@ at position @i@ and @0@ elsewhere.
kronecker :: (Traversable t, Num a) => t a -> t (t a)
kronecker v = fillFromList rows v
  where
    rows = choices (traverse (\x -> SetOne 0 [x]) v)
-- | Replace the elements of a traversable container, left to right, with
-- successive elements of the list. Surplus list elements are ignored.
--
-- Partial: if the list is shorter than the container, the positions past
-- its end are 'error' values (raised when forced).
fillFromList :: Traversable t => [a] -> t b -> t a
fillFromList l = snd . mapAccumL aux l
  where
    aux (a:as) _ = (as, a)
    -- The message names this function so failures are traceable.
    aux []     _ = error "Linear.Vector.fillFromList: too few elements"
-- | Outer product: element @x@ of the first vector becomes the row of the
-- second vector's components each multiplied by @x@.
outer :: (Functor f, Functor g, Num a) => f a -> g a -> f (g a)
outer xs ys = fmap scaleRow xs
  where
    scaleRow x = fmap (* x) ys