#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
#define USE_GHC_GENERICS
#endif
module Linear.Vector
  ( Additive(..)
  , E(..)
  , negated
  , (^*)
  , (*^)
  , (^/)
  , sumV
  , basis
  , basisFor
  , scaled
  , outer
  , unit
  ) where
import Control.Applicative
import Control.Lens
import Data.Complex
#if __GLASGOW_HASKELL__ < 710
import Data.Foldable as Foldable (Foldable, forM_, foldl')
#else
import Data.Foldable as Foldable (forM_, foldl')
#endif
import Data.HashMap.Lazy as HashMap
import Data.Hashable
import Data.IntMap as IntMap
import Data.Map as Map
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid (mempty)
#endif
import Data.Vector as Vector
import Data.Vector.Mutable as Mutable
#ifdef USE_GHC_GENERICS
import GHC.Generics
#endif
import Linear.Instances ()
-- | A basis element: a lens that focuses on one fixed coordinate of a
-- container @t x@, for every element type @x@.
newtype E t = E
  { el :: forall x. Lens' (t x) x
    -- ^ The lens selecting this coordinate.
  }
-- Vector addition and subtraction share the precedence of (+) and (-);
-- scalar scaling and division share the precedence of (*) and (/).
infixl 6 ^+^
infixl 6 ^-^
infixl 7 ^*
infixl 7 *^
infixl 7 ^/
#ifdef USE_GHC_GENERICS
-- | Generic counterpart of @Additive@, defined over the representation
-- functors of @GHC.Generics@. @gzero@ backs the generic default of @zero@.
class GAdditive f where
  -- | Intersection-style combination of two representations.
  gliftI2 :: (a -> b -> c) -> f a -> f b -> f c
  -- | Union-style combination of two representations.
  gliftU2 :: (a -> a -> a) -> f a -> f a -> f a
  -- | The representation with zero in every parameter position.
  gzero :: Num a => f a
-- | Nullary constructors carry no components; both combinators only check
-- that their arguments are @U1@.
instance GAdditive U1 where
  gzero = U1
  gliftI2 _ U1 U1 = U1
  -- With no fields there is nothing to union or intersect.
  gliftU2 = gliftI2
  
-- | Products combine their two halves independently.
instance (GAdditive f, GAdditive g) => GAdditive (f :*: g) where
  gzero = gzero :*: gzero
  gliftU2 op l r = case (l, r) of
    (a :*: b, c :*: d) -> gliftU2 op a c :*: gliftU2 op b d
  gliftI2 op l r = case (l, r) of
    (a :*: b, c :*: d) -> gliftI2 op a c :*: gliftI2 op b d
  
-- | A field whose type is itself an @Additive@ functor delegates to that
-- functor's own methods.
instance Additive f => GAdditive (Rec1 f) where
  gzero = Rec1 zero
  gliftU2 op x y = Rec1 (liftU2 op (unRec1 x) (unRec1 y))
  gliftI2 op x y = Rec1 (liftI2 op (unRec1 x) (unRec1 y))
  
-- | Metadata wrappers are transparent: unwrap, recurse, rewrap.
instance GAdditive f => GAdditive (M1 i c f) where
  gzero = M1 gzero
  gliftU2 op x y = M1 (gliftU2 op (unM1 x) (unM1 y))
  gliftI2 op x y = M1 (gliftI2 op (unM1 x) (unM1 y))
  
-- | A bare parameter occurrence is a single component; union and
-- intersection coincide and simply apply the function.
instance GAdditive Par1 where
  gzero = Par1 0
  gliftI2 op x y = Par1 (op (unPar1 x) (unPar1 y))
  gliftU2 = gliftI2
  
#endif
-- | A vector is an additive group with additional structure.
--
-- @liftU2@ combines two vectors position by position. For the container-like
-- instances in this module (lists, Maybe, maps, Vector), positions present in
-- only one operand are kept unchanged. @liftI2@ keeps only the positions
-- present in both operands.
class Functor f => Additive f where
  -- | The zero vector.
  zero :: Num a => f a
#ifdef USE_GHC_GENERICS
#ifndef HLINT
  default zero :: (GAdditive (Rep1 f), Generic1 f, Num a) => f a
  zero = to1 gzero
#endif
#endif

  -- | Compute the sum of two vectors (by default @liftU2 (+)@).
  (^+^) :: Num a => f a -> f a -> f a
  (^+^) = liftU2 (+)

  -- | Compute the difference of two vectors (by default @x ^+^ negated y@).
  (^-^) :: Num a => f a -> f a -> f a
  x ^-^ y = x ^+^ negated y

  -- | Linearly interpolate between two vectors:
  -- @lerp alpha u v = alpha *^ u ^+^ (1 - alpha) *^ v@.
  -- @alpha@ is the weight on @u@ and @1 - alpha@ is the weight on @v@.
  lerp :: Num a => a -> f a -> f a -> f a
  lerp alpha u v = alpha *^ u ^+^ (1 - alpha) *^ v

  -- | Combine two vectors position by position with a binary function
  -- (union-style: see the class description).
  liftU2 :: (a -> a -> a) -> f a -> f a -> f a
#ifdef USE_GHC_GENERICS
#ifndef HLINT
  default liftU2 :: Applicative f => (a -> a -> a) -> f a -> f a -> f a
  liftU2 = liftA2
#endif
#endif

  -- | Combine two vectors position by position with a binary function,
  -- keeping only positions present in both (intersection-style).
  liftI2 :: (a -> b -> c) -> f a -> f b -> f c
#ifdef USE_GHC_GENERICS
#ifndef HLINT
  default liftI2 :: Applicative f => (a -> b -> c) -> f a -> f b -> f c
  liftI2 = liftA2
#endif
#endif
-- | ZipLists reuse the list instance for union; intersection zips.
instance Additive ZipList where
  zero = ZipList []
  liftU2 f xs ys = ZipList (liftU2 f (getZipList xs) (getZipList ys))
  liftI2 = liftA2
  
-- | Boxed vectors: union keeps the tail of the longer operand, intersection
-- truncates to the shorter one.
instance Additive Vector where
  zero = mempty

  -- The shorter operand is merged into a copy of the longer one, so elements
  -- beyond the shorter length are kept unchanged.
  liftU2 f u v = case compare lu lv of
    LT | lu == 0   -> v
       | otherwise -> overlay lu v
    EQ -> Vector.zipWith f u v
    GT | lv == 0   -> u
       | otherwise -> overlay lv u
    where
      lu = Vector.length u
      lv = Vector.length v
      -- Copy base, then overwrite its first n slots with f applied to the
      -- matching elements of u and v. Callers pass the shorter length as n,
      -- which keeps the unchecked reads and writes (indices 0 .. n - 1) in bounds.
      overlay n base = Vector.modify (\w -> Foldable.forM_ [0 .. n - 1] $ \i ->
        unsafeWrite w i (f (unsafeIndex u i) (unsafeIndex v i))) base

  liftI2 = Vector.zipWith
  
-- | Maybe is a zero- or one-component vector. For union, an absent operand
-- yields the other operand unchanged.
instance Additive Maybe where
  zero = Nothing
  liftU2 f ma mb = case (ma, mb) of
    (Just a, Just b) -> Just (f a b)
    (Nothing, _)     -> mb
    (_, Nothing)     -> ma
  liftI2 = liftA2
  
-- | Lists: union zips and keeps the leftover tail of the longer list;
-- intersection zips and truncates to the shorter list.
instance Additive [] where
  zero = []
  liftU2 f = merge
    where
      merge [] r = r
      merge l [] = l
      merge (x:l) (y:r) = f x y : merge l r
  liftI2 = Prelude.zipWith
  
-- | IntMaps act as sparse vectors indexed by Int; missing keys are absent
-- components.
instance Additive IntMap where
  zero = IntMap.empty
  -- keys in either map survive; shared keys are combined with f
  liftU2 f xs ys = IntMap.unionWith f xs ys
  -- only keys present in both maps survive
  liftI2 f xs ys = IntMap.intersectionWith f xs ys
  
-- | Ordered maps act as sparse vectors indexed by k; missing keys are absent
-- components.
instance Ord k => Additive (Map k) where
  zero = Map.empty
  -- keys in either map survive; shared keys are combined with f
  liftU2 f xs ys = Map.unionWith f xs ys
  -- only keys present in both maps survive
  liftI2 f xs ys = Map.intersectionWith f xs ys
  
-- | Hash maps act as sparse vectors indexed by k; missing keys are absent
-- components.
instance (Eq k, Hashable k) => Additive (HashMap k) where
  zero = HashMap.empty
  -- keys in either map survive; shared keys are combined with f
  liftU2 f xs ys = HashMap.unionWith f xs ys
  -- only keys present in both maps survive
  liftI2 f xs ys = HashMap.intersectionWith f xs ys
  
-- | Functions are vectors indexed by their argument; every index is always
-- present, so union and intersection coincide (pointwise application).
instance Additive ((->) b) where
  zero _ = 0
  liftU2 f g h x = f (g x) (h x)
  liftI2 f g h x = f (g x) (h x)
  
-- | Complex numbers are two-component vectors (real part, imaginary part).
instance Additive Complex where
  zero = 0 :+ 0
  liftI2 f (a :+ b) (c :+ d) = f a c :+ f b d
  -- Both components always exist, so union is the same as intersection.
  liftU2 = liftI2
  
-- | Identity is a one-component vector.
instance Additive Identity where
  zero = pure 0
  liftU2 f (Identity a) (Identity b) = Identity (f a b)
  liftI2 f (Identity a) (Identity b) = Identity (f a b)
  
-- | Negate every component of a vector.
negated :: (Functor f, Num a) => f a -> f a
negated v = negate <$> v
-- | Sum a container of vectors with a strict left fold, starting from
-- @zero@ and adding each vector with @^+^@.
sumV :: (Foldable f, Additive v, Num a) => f (v a) -> v a
sumV vs = Foldable.foldl' (\acc v -> acc ^+^ v) zero vs
-- | Scale a vector by a scalar on the left: each component x becomes a * x.
(*^) :: (Functor f, Num a) => a -> f a -> f a
a *^ v = (a *) <$> v
-- | Scale a vector by a scalar on the right: each component x becomes x * a.
(^*) :: (Functor f, Num a) => f a -> a -> f a
v ^* a = (* a) <$> v
-- | Divide every component of a vector by a scalar.
(^/) :: (Functor f, Fractional a) => f a -> a -> f a
v ^/ a = (/ a) <$> v
-- | The standard basis for an @Additive@ and @Traversable@ shape: one vector
-- per position of @zero@, with 1 at that position and 0 elsewhere.
basis :: (Additive t, Traversable t, Num a) => [t a]
basis = basisFor template
  where
    -- Only the shape of the zero vector matters; its Int contents are unused.
    template :: Additive v => v Int
    template = zero
-- | The standard basis for the shape of the argument: one vector per
-- traversal position, with 1 at that position and 0 elsewhere. The argument
-- contents are ignored; only its shape is used.
basisFor :: (Traversable t, Num a) => t b -> [t a]
basisFor t = ifoldMapOf traversed (\i _ -> [indicator i]) t
  where
    -- The vector with 1 at traversal index i and 0 at every other index.
    indicator i = iover traversed (\j _ -> if i == j then 1 else 0) t
-- | Build the diagonal matrix of a vector: row i keeps component i of the
-- input at position i and has 0 everywhere else.
scaled :: (Traversable t, Num a) => t a -> t (t a)
scaled t = iover traversed row t
  where
    row i x = iover traversed (\j _ -> if i == j then x else 0) t
-- | Start from @zero@ and set every position targeted by the setter to 1.
unit :: (Additive t, Num a) => ASetter' (t a) a -> t a
unit l = zero & set' l 1
-- | Outer product: for each component x of the first vector, the second
-- vector with each component multiplied by x.
outer :: (Functor f, Functor g, Num a) => f a -> g a -> f (g a)
outer a b = scaleB <$> a
  where
    scaleB x = (* x) <$> b