{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DefaultSignatures #-}
#if __GLASGOW_HASKELL__ >= 706
{-# LANGUAGE PolyKinds #-}
#endif
#if __GLASGOW_HASKELL__ >= 702 && __GLASGOW_HASKELL__ < 710
{-# LANGUAGE Trustworthy #-}
#endif
---------------------------------------------------------------------------
-- |
-- Copyright   :  (C) 2012-2015 Edward Kmett
-- License     :  BSD-style (see the file LICENSE)
--
-- Maintainer  :  Edward Kmett <ekmett@gmail.com>
-- Stability   :  experimental
-- Portability :  non-portable
--
-- Simple matrix operations for low-dimensional primitives.
---------------------------------------------------------------------------
module Linear.Trace
  ( Trace(..)
  , frobenius
  ) where

import Control.Monad as Monad
import Linear.V0
import Linear.V1
import Linear.V2
import Linear.V3
import Linear.V4
import Linear.Plucker
import Linear.Quaternion
import Linear.V
import Linear.Vector
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ > 704
import Data.Complex
#endif
import Data.Distributive
import Data.Foldable as Foldable
import Data.Functor.Bind as Bind
import Data.Functor.Compose
import Data.Functor.Product
import Data.Hashable
import Data.HashMap.Lazy
import Data.IntMap
import Data.Map

-- $setup
-- >>> import Data.Complex
-- >>> import Data.IntMap
-- >>> import Debug.SimpleReflect.Vars
-- >>> import Linear.V2

class Functor m => Trace m where
  -- | Compute the trace of a matrix, i.e. the sum of its diagonal entries.
  --
  -- The default implementation sums the result of 'diagonal', and is
  -- therefore only available when @m@ is 'Foldable'.
  --
  -- >>> trace (V2 (V2 a b) (V2 c d))
  -- a + d
  trace :: Num a => m (m a) -> a
#ifndef HLINT
  default trace :: (Foldable m, Num a) => m (m a) -> a
  trace = Foldable.sum . diagonal
  {-# INLINE trace #-}
#endif

  -- | Compute the diagonal of a matrix.
  --
  -- The default implementation is 'Monad.join', and is therefore only
  -- available when @m@ is a 'Monad'. For instances relying on this default,
  -- such as 'V2', it selects the diagonal entries (see the example below).
  --
  -- >>> diagonal (V2 (V2 a b) (V2 c d))
  -- V2 a d
  diagonal :: m (m a) -> m a
#ifndef HLINT
  default diagonal :: Monad m => m (m a) -> m a
  diagonal = Monad.join
  {-# INLINE diagonal #-}
#endif

-- 'diagonal' is defined with semigroupoids' 'Bind.join' rather than the
-- class default, whose signature demands a 'Monad' instance (presumably
-- unavailable for 'IntMap'). 'trace' uses the 'Foldable'-based default.
instance Trace IntMap where
  diagonal = Bind.join
  {-# INLINE diagonal #-}

-- 'diagonal' is defined with semigroupoids' 'Bind.join' rather than the
-- class default, whose signature demands a 'Monad' instance (presumably
-- unavailable for 'Map'). 'trace' uses the 'Foldable'-based default.
instance Ord k => Trace (Map k) where
  diagonal = Bind.join
  {-# INLINE diagonal #-}

-- 'diagonal' is defined with semigroupoids' 'Bind.join' rather than the
-- class default, whose signature demands a 'Monad' instance (presumably
-- unavailable for 'HashMap'). 'trace' uses the 'Foldable'-based default.
instance (Eq k, Hashable k) => Trace (HashMap k) where
  diagonal = Bind.join
  {-# INLINE diagonal #-}

-- Fixed-size vector types: every one of these instances relies entirely on
-- the class defaults ('diagonal' = 'Monad.join', 'trace' = sum of the
-- diagonal).
instance Trace V0
instance Trace V1
instance Trace V2
instance Trace V3
instance Trace V4
instance Dim n => Trace (V n)
instance Trace Quaternion
instance Trace Plucker

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ > 704
-- | A @Complex (Complex a)@ is viewed as the 2x2 matrix
-- @(a :+ x) :+ (y :+ b)@ with rows @[a, x]@ and @[y, b]@; the off-diagonal
-- entries @x@ and @y@ are ignored.
instance Trace Complex where
  trace ((a :+ _) :+ (_ :+ b)) = a + b
  {-# INLINE trace #-}
  diagonal ((a :+ _) :+ (_ :+ b)) = a :+ b
  {-# INLINE diagonal #-}
#endif

-- | A matrix over @Product f g@ is @Pair xx yy@ with
-- @xx :: f (Product f g a)@ and @yy :: g (Product f g a)@. Only the
-- @f@-components of the @f@-rows and the @g@-components of the @g@-rows
-- lie on the diagonal, so the matrix is treated as block-diagonal: the
-- off-diagonal blocks are discarded.
instance (Trace f, Trace g) => Trace (Product f g) where
  -- Sum of the traces of the two diagonal blocks.
  trace (Pair xx yy) = trace (pfst <$> xx) + trace (psnd <$> yy)
    where
      pfst (Pair x _) = x
      psnd (Pair _ y) = y
  {-# INLINE trace #-}
  -- Pair up the diagonals of the two diagonal blocks.
  diagonal (Pair xx yy) = diagonal (pfst <$> xx) `Pair` diagonal (psnd <$> yy)
    where
      pfst (Pair x _) = x
      psnd (Pair _ y) = y
  {-# INLINE diagonal #-}

-- | A matrix over @Compose g f@ unwraps to @g (f (g (f a)))@. 'distribute'
-- swaps the inner @f@ and @g@ layers so that both @g@ indices and both @f@
-- indices become adjacent, after which the trace/diagonal of each layer
-- can be taken separately.
instance (Distributive g, Trace g, Trace f) => Trace (Compose g f) where
  trace = trace                            -- g (g a)         -> a
        . fmap (fmap trace . distribute)   -- g (f (g (f a))) -> g (g a)
        . getCompose . fmap getCompose     -- strip both Compose layers
  {-# INLINE trace #-}
  diagonal = Compose                       -- g (f a)            -> Compose g f a
           . fmap diagonal                 -- g (f (f a))        -> g (f a)
           . diagonal                      -- g (g (f (f a)))    -> g (f (f a))
           . fmap distribute               -- g (f (g (f a)))    -> g (g (f (f a)))
           . getCompose . fmap getCompose  -- strip both Compose layers
  {-# INLINE diagonal #-}

-- | Compute the /square/ of the
-- <http://mathworld.wolfram.com/FrobeniusNorm.html Frobenius norm> of a
-- matrix: the sum of the squares of all of its entries, equivalently the
-- trace of @M^T M@.
--
-- Only a 'Num' constraint is required, so no square root is taken. Apply
-- 'sqrt' to the result to obtain the norm itself. Entries are multiplied
-- directly, with no complex conjugation.
--
-- >>> frobenius (V2 (V2 1 2) (V2 3 4))
-- 30
frobenius :: (Num a, Foldable f, Additive f, Additive g, Distributive g, Trace g) => f (g a) -> a
frobenius m = trace (fmap gramRow (distribute m))
  where
    -- For a column @c@ of @m@, scale each row of @m@ by the matching entry
    -- of @c@ and add the scaled rows together. This yields the
    -- corresponding row of @M^T M@, whose diagonal entry is the sum of
    -- squares of that column.
    gramRow c = Foldable.foldl' (^+^) zero (liftI2 (*^) c m)