{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE Trustworthy #-}
---------------------------------------------------------------------------
-- |
-- Copyright   :  (C) 2012-2015 Edward Kmett
-- License     :  BSD-style (see the file LICENSE)
--
-- Maintainer  :  Edward Kmett <ekmett@gmail.com>
-- Stability   :  experimental
-- Portability :  non-portable
--
-- Simple matrix operations for low-dimensional primitives.
---------------------------------------------------------------------------

module Linear.Trace
  ( Trace(..)
  , frobenius
  ) where

import Control.Monad as Monad
import Linear.V0
import Linear.V1
import Linear.V2
import Linear.V3
import Linear.V4
import Linear.Plucker
import Linear.Quaternion
import Linear.V
import Linear.Vector
import Data.Complex
import Data.Distributive
import Data.Foldable as Foldable
import Data.Functor.Bind as Bind
import Data.Functor.Compose
import Data.Functor.Product
import Data.Hashable
import Data.HashMap.Lazy
import Data.IntMap (IntMap)
import Data.Map (Map)

-- $setup
-- >>> import Data.Complex
-- >>> import Debug.SimpleReflect.Vars
-- >>> import Linear.V2


-- | Square matrices represented as values of type @m (m a)@, i.e. an @m@ of
-- @m@-shaped vectors.
--
-- Unless compiled with @HLINT@ defined, both methods have defaults:
--
-- * 'diagonal' defaults to 'Monad.join', which needs a 'Monad' instance;
--
-- * 'trace' defaults to @'Foldable.sum' . 'diagonal'@, which needs a
--   'Foldable' instance.
--
-- A type lacking one of those instances must implement the corresponding
-- method itself.
class Functor m => Trace m where
  -- | Compute the trace of a matrix
  --
  -- >>> trace (V2 (V2 a b) (V2 c d))
  -- a + d
  trace :: Num a => m (m a) -> a
#ifndef HLINT
  default trace :: (Foldable m, Num a) => m (m a) -> a
  trace = Foldable.sum . diagonal
  {-# INLINE trace #-}
#endif

  -- | Compute the diagonal of a matrix
  --
  -- >>> diagonal (V2 (V2 a b) (V2 c d))
  -- V2 a d
  diagonal :: m (m a) -> m a
#ifndef HLINT
  default diagonal :: Monad m => m (m a) -> m a
  diagonal = Monad.join
  {-# INLINE diagonal #-}
#endif

-- The default 'diagonal' needs a 'Monad' instance, so the diagonal is taken
-- with 'Bind.join' (from "Data.Functor.Bind") instead. 'trace' keeps the
-- class default, @'Foldable.sum' . 'diagonal'@.
instance Trace IntMap where
  diagonal = Bind.join
  {-# INLINE diagonal #-}

-- The default 'diagonal' needs a 'Monad' instance, so the diagonal is taken
-- with 'Bind.join' (from "Data.Functor.Bind") instead. 'trace' keeps the
-- class default, @'Foldable.sum' . 'diagonal'@.
instance Ord k => Trace (Map k) where
  diagonal = Bind.join
  {-# INLINE diagonal #-}

-- The default 'diagonal' needs a 'Monad' instance, so the diagonal is taken
-- with 'Bind.join' (from "Data.Functor.Bind") instead. 'trace' keeps the
-- class default, @'Foldable.sum' . 'diagonal'@.
instance (Eq k, Hashable k) => Trace (HashMap k) where
  diagonal = Bind.join
  {-# INLINE diagonal #-}

-- The following instances rely entirely on the class defaults:
-- 'diagonal' is 'Monad.join' and 'trace' is @'Foldable.sum' . 'diagonal'@.
-- Each of these types therefore needs 'Monad' and 'Foldable' instances.
instance Dim n => Trace (V n)
instance Trace V0
instance Trace V1
instance Trace V2
instance Trace V3
instance Trace V4
instance Trace Plucker
instance Trace Quaternion

-- 'Complex' is treated here as a two-component container, not as complex
-- arithmetic. A @Complex (Complex a)@ is the 2x2 matrix
--
-- > (a :+ x) :+ (y :+ b)
--
-- whose diagonal entries are @a@ and @b@. The off-diagonal entries @x@ and @y@
-- are ignored.
instance Trace Complex where
  trace ((a :+ _) :+ (_ :+ b)) = a + b
  {-# INLINE trace #-}
  diagonal ((a :+ _) :+ (_ :+ b)) = a :+ b
  {-# INLINE diagonal #-}

-- A matrix over @Product f g@ is read in block form: each index is either an
-- @f@ index or a @g@ index. Only the two blocks on the main diagonal
-- contribute:
--
-- * the @f@-by-@f@ block, reached via the @f@ half of the outer 'Pair' and
--   then the @f@ half of each inner 'Pair';
--
-- * the @g@-by-@g@ block, handled the same way with the @g@ halves.
--
-- The off-diagonal @f@/@g@ blocks are discarded.
instance (Trace f, Trace g) => Trace (Product f g) where
  trace (Pair xx yy) = trace (pfst <$> xx) + trace (psnd <$> yy)
    where
      pfst (Pair x _) = x
      psnd (Pair _ y) = y
  {-# INLINE trace #-}

  diagonal (Pair xx yy) = diagonal (pfst <$> xx) `Pair` diagonal (psnd <$> yy)
    where
      pfst (Pair x _) = x
      psnd (Pair _ y) = y
  {-# INLINE diagonal #-}

-- A matrix over @Compose g f@ unwraps (via 'getCompose' on both layers) to
-- @g (f (g (f a)))@. The first @g@/@f@ pair indexes one side of the matrix
-- and the second pair indexes the other side.
--
-- 'distribute' swaps the middle @f@ with the inner @g@, giving
-- @g (g (f (f a)))@. The outer @g@ 'Trace' can then work on the two @g@
-- indices and the inner @f@ 'Trace' on the two @f@ indices.
instance (Distributive g, Trace g, Trace f) => Trace (Compose g f) where
  trace = trace . fmap (fmap trace . distribute) . getCompose . fmap getCompose
  {-# INLINE trace #-}

  diagonal =
    Compose
      . fmap diagonal   -- g (f (f a)) -> g (f a): diagonal over the f indices
      . diagonal        -- g (g (f (f a))) -> g (f (f a)): over the g indices
      . fmap distribute -- g (f (g (f a))) -> g (g (f (f a)))
      . getCompose
      . fmap getCompose
  {-# INLINE diagonal #-}

-- | Compute the squared
-- <http://mathworld.wolfram.com/FrobeniusNorm.html Frobenius norm> of a
-- matrix @m@, as the 'trace' of @mᵀ m@.
--
-- Only a 'Num' constraint is available, so no square root is taken; apply
-- 'sqrt' to the result to obtain the norm itself. Entries are multiplied
-- plainly (@x * x@), with no complex conjugation.
frobenius :: (Num a, Foldable f, Additive f, Additive g, Distributive g, Trace g) => f (g a) -> a
frobenius m = trace (fmap gramRow (distribute m))
  where
    -- 'distribute' @m@ yields, for every @g@ index @j@, the @f@-vector @c@ of
    -- entries @m_ij@. Scaling each row @m_i@ by @c_i@ and summing gives
    -- @sum_i m_ij *^ m_i@, i.e. row @j@ of @mᵀ m@.
    gramRow c = Foldable.foldl' (^+^) zero (liftI2 (*^) c m)