{-# LANGUAGE CPP #-}
{-# LANGUAGE DefaultSignatures #-}
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
{-# LANGUAGE Trustworthy #-}
#endif
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (C) 2012-2015 Edward Kmett
-- License     :  BSD-style (see the file LICENSE)
--
-- Maintainer  :  Edward Kmett <ekmett@gmail.com>
-- Stability   :  experimental
-- Portability :  non-portable
--
-- Free metric spaces
----------------------------------------------------------------------------
module Linear.Metric
  ( Metric(..), normalize, project
  ) where

import Control.Applicative
import Data.Foldable as Foldable
import Data.Functor.Compose
import Data.Functor.Identity
import Data.Functor.Product
import Data.Vector (Vector)
import Data.IntMap (IntMap)
import Data.Map (Map)
import Data.HashMap.Strict (HashMap)
import Data.Hashable (Hashable)
import Linear.Epsilon
import Linear.Vector

-- $setup
-- >>> import Linear
--

-- | Free and sparse inner product/metric spaces.
--
-- Only 'dot' has to be supplied, and when @f@ is 'Foldable' even that has a
-- default. Every other method is ultimately derived from 'dot' (through
-- 'quadrance'); instances may override them, e.g. to avoid needing
-- 'Foldable' or to combine components differently.
class Additive f => Metric f where
  -- | Compute the inner product of two vectors or (equivalently)
  -- convert a vector @f a@ into a covector @f a -> a@.
  --
  -- >>> V2 1 2 `dot` V2 3 4
  -- 11
  dot :: Num a => f a -> f a -> a
#ifndef HLINT
  -- Default: multiply component-wise with 'liftI2', then sum everything.
  default dot :: (Foldable f, Num a) => f a -> f a -> a
  dot x y = Foldable.sum (liftI2 (*) x y)
#endif

  -- | Compute the squared norm. The name quadrance arises from
  -- Norman J. Wildberger's rational trigonometry.
  --
  -- >>> quadrance (V2 3 4)
  -- 25
  quadrance :: Num a => f a -> a
  quadrance v = dot v v

  -- | Compute the quadrance of the difference of two vectors.
  qd :: Num a => f a -> f a -> a
  qd f g = quadrance (f ^-^ g)

  -- | Compute the distance between two vectors in a metric space.
  distance :: Floating a => f a -> f a -> a
  distance f g = norm (f ^-^ g)

  -- | Compute the norm of a vector in a metric space.
  --
  -- >>> norm (V2 3 4)
  -- 5.0
  norm :: Floating a => f a -> a
  norm v = sqrt (quadrance v)

  -- | Convert a non-zero vector to unit vector.
  --
  -- The zero vector is not special-cased: every component is divided by a
  -- zero norm (yielding NaNs for IEEE types such as 'Double'). Use
  -- 'normalize' when the input may be (nearly) zero.
  signorm :: Floating a => f a -> f a
  signorm v = fmap (/ m) v
    where
      m = norm v

-- | The metric on a product sums the metrics of its two halves, treating
-- them as orthogonal blocks of one larger vector.
instance (Metric f, Metric g) => Metric (Product f g) where
  dot (Pair a b) (Pair c d) = dot a c + dot b d
  quadrance (Pair a b) = quadrance a + quadrance b
  qd (Pair a b) (Pair c d) = qd a c + qd b d
  -- Take one square root over the combined squared distance; summing the
  -- component 'distance's would give a different (larger) value.
  distance p q = sqrt (qd p q)

-- | Metric on a composition of functors: the inner functor @g@ is measured
-- first, and the outer functor @f@ is collapsed with 'quadrance'/'norm'.
--
-- NOTE(review): because the outer collapse goes through 'quadrance', each
-- per-row result is combined as @quadrance@ does (for the usual
-- component-wise 'dot' on @f@, that means squared and summed). So 'dot'
-- computes @sum_i (dot a_i b_i)^2@ rather than the flattened inner product
-- @sum_i (dot a_i b_i)@. It is therefore not bilinear: scaling one argument
-- by @k@ scales the result by @k^2@. 'qd' likewise combines squared inner
-- 'qd's. With only 'Metric'/'Additive' on @f@ there is no plain sum
-- available; confirm this is intended before treating it as a Euclidean
-- metric.
instance (Metric f, Metric g) => Metric (Compose f g) where
  dot (Compose a) (Compose b) = quadrance (liftI2 dot a b)
  quadrance = quadrance . fmap quadrance . getCompose
  qd (Compose a) (Compose b) = quadrance (liftI2 qd a b)
  distance (Compose a) (Compose b) = norm (liftI2 qd a b)

-- | A one-dimensional space: the inner product is plain multiplication.
instance Metric Identity where
  dot (Identity x) (Identity y) = x * y

-- Lists and 'Maybe' use the class's 'Foldable'-based default for 'dot'
-- (component-wise 'liftI2' product, then 'Foldable.sum').
instance Metric []

instance Metric Maybe

-- | 'ZipList' reuses the list metric on the wrapped list.
instance Metric ZipList where
  -- Defined explicitly instead of using the 'Foldable' default, because
  -- older versions of base provide no 'Foldable' instance for 'ZipList'.
  dot (ZipList x) (ZipList y) = dot x y

-- These containers rely on the class's 'Foldable'-based default 'dot',
-- which combines the two arguments with 'liftI2' before summing.
-- NOTE(review): for the keyed maps this presumably pairs entries by key,
-- so keys present in only one argument contribute nothing. Confirm this
-- against their 'Additive' instances.
instance Metric IntMap

instance Ord k => Metric (Map k)

instance (Hashable k, Eq k) => Metric (HashMap k)

instance Metric Vector

-- | Normalize a 'Metric' functor to have unit 'norm'. The vector is
-- returned unchanged if its 'norm' is (nearly) 0 or 1.
--
-- Both checks apply 'nearZero' to the 'quadrance' @l@ (and to @1 - l@).
-- This avoids dividing by a (nearly) zero norm, and it skips the @sqrt@
-- and division when the vector already has (nearly) unit length.
normalize :: (Floating a, Metric f, Epsilon a) => f a -> f a
normalize v
  | nearZero l || nearZero (1 - l) = v
  | otherwise                      = fmap (/ sqrt l) v
  where
    l = quadrance v

-- | @project u v@ computes the projection of @v@ onto @u@, i.e.
-- @((v `dot` u) / quadrance u) *^ u@.
--
-- @u@ must be non-zero. For a zero @u@ the scalar factor is @0 / 0@, which
-- is NaN for IEEE types such as 'Double' and a division-by-zero exception
-- for exact types such as 'Rational'.
project :: (Metric v, Fractional a) => v a -> v a -> v a
project u v = ((v `dot` u) / quadrance u) *^ u