module Grenade.Core.Shape (
Shape (..)
, S (..)
, Sing (..)
, randomOfShape
, fromStorable
) where
import Control.DeepSeq (NFData (..))
import Control.Monad.Random ( MonadRandom, getRandom )
import Data.Proxy
import Data.Singletons
import Data.Singletons.TypeLits
import Data.Vector.Storable ( Vector )
import qualified Data.Vector.Storable as V
import GHC.TypeLits
import qualified Numeric.LinearAlgebra.Static as H
import Numeric.LinearAlgebra.Static
import qualified Numeric.LinearAlgebra as NLA
-- | Type-level description of a value's dimensions. Its constructors are
-- used promoted (e.g. @'D1 len@) as the index of 'S'.
data Shape
  = D1 Nat
    -- ^ One dimension: the vector length.
  | D2 Nat Nat
    -- ^ Two dimensions: rows, then columns.
  | D3 Nat Nat Nat
    -- ^ Three dimensions: rows, columns, then depth.
-- | A concrete value whose dimensions are fixed by a promoted 'Shape'.
data S (n :: Shape) where
  -- | A vector of length @len@.
  S1D :: KnownNat len => R len -> S ('D1 len)

  -- | A @rows@ by @columns@ matrix.
  S2D :: (KnownNat rows, KnownNat columns)
      => L rows columns
      -> S ('D2 rows columns)

  -- | A three-dimensional value held as one matrix with @rows * depth@ rows
  -- and @columns@ columns. NOTE(review): presumably the depth slices are
  -- stacked along the row axis; confirm against the layers that consume it.
  S3D :: (KnownNat rows, KnownNat columns, KnownNat depth, KnownNat (rows * depth))
      => L (rows * depth) columns
      -> S ('D3 rows columns depth)

deriving instance Show (S n)
-- | Singletons for 'Shape', carrying a singleton for every dimension.
-- The three-dimensional case also packs the 'KnownNat' evidence for the
-- product of its first and last dimensions, which 'S3D' requires.
data instance Sing (n :: Shape) where
  D1Sing :: Sing len -> Sing ('D1 len)
  D2Sing :: Sing r -> Sing c -> Sing ('D2 r c)
  D3Sing :: KnownNat (r * d)
         => Sing r
         -> Sing c
         -> Sing d
         -> Sing ('D3 r c d)
-- | Build the 'Shape' singleton from the 'KnownNat' evidence of each
-- dimension. This lets callers demand @SingI x@ and recover the shape.
instance KnownNat a => SingI ('D1 a) where
  sing = D1Sing (sing :: Sing a)

instance ( KnownNat a
         , KnownNat b
         ) => SingI ('D2 a b) where
  sing = D2Sing (sing :: Sing a) (sing :: Sing b)

-- The extra @KnownNat (a * c)@ context is needed by the 'D3Sing' constructor.
instance ( KnownNat a
         , KnownNat b
         , KnownNat c
         , KnownNat (a * c)
         ) => SingI ('D3 a b c) where
  sing = D3Sing (sing :: Sing a) (sing :: Sing b) (sing :: Sing c)
-- | Arithmetic on shaped values. Each method lifts the corresponding
-- operation of the wrapped hmatrix 'R' / 'L' value (see 'n1' / 'n2').
-- Literals become constant-filled values of the required shape (see 'nk').
instance SingI x => Num (S x) where
  (+) = n2 (+)
  -- Previously written as @() = n2 ()@, which does not parse; this is
  -- subtraction.
  (-) = n2 (-)
  (*) = n2 (*)
  -- Defined directly so the default @0 - x@ does not build a constant
  -- zero value first.
  negate = n1 negate
  abs = n1 abs
  signum = n1 signum
  fromInteger i = nk (fromInteger i)
-- | Division and reciprocal lifted onto the wrapped hmatrix value.
-- Rational literals become constant-filled values of the required shape.
instance SingI x => Fractional (S x) where
  fromRational = nk . fromRational
  recip = n1 recip
  (/) = n2 (/)
-- | Transcendental functions lifted onto the wrapped hmatrix value.
-- 'pi' becomes a constant-filled value of the required shape.
instance SingI x => Floating (S x) where
  -- constant
  pi = nk pi

  -- binary operations
  (**) = n2 (**)
  logBase = n2 logBase

  -- exponentials and roots
  exp = n1 exp
  log = n1 log
  sqrt = n1 sqrt

  -- circular functions and their inverses
  sin = n1 sin
  cos = n1 cos
  tan = n1 tan
  asin = n1 asin
  acos = n1 acos
  atan = n1 atan

  -- hyperbolic functions and their inverses
  sinh = n1 sinh
  cosh = n1 cosh
  tanh = n1 tanh
  asinh = n1 asinh
  acosh = n1 acosh
  atanh = n1 atanh
-- | Full evaluation of a shaped value delegates to the wrapped hmatrix value.
instance NFData (S x) where
  rnf s = case s of
    S1D vec -> rnf vec
    S2D mat -> rnf mat
    S3D mat -> rnf mat
-- | Draw a random value of shape @x@ with entries in the range [-1, 1].
--
-- A single 'Int' seed is taken from the 'MonadRandom' context and handed to
-- hmatrix's seeded samplers.
randomOfShape :: forall x m. ( MonadRandom m, SingI x ) => m (S x)
randomOfShape = do
  seed :: Int <- getRandom
  return $ case (sing :: Sing x) of
    D1Sing l ->
      withKnownNat l $
        -- 'Uniform' draws from [0, 1]; rescale to [-1, 1]. The earlier text
        -- @* 2 1@ lost its minus sign and did not type-check.
        S1D (randomVector seed Uniform * 2 - 1)
    D2Sing r c ->
      withKnownNat r $ withKnownNat c $
        -- The lower bound must be -1. With @(1) 1@ every entry would be 1.
        S2D (uniformSample seed (-1) 1)
    D3Sing r c d ->
      withKnownNat r $ withKnownNat c $ withKnownNat d $
        S3D (uniformSample seed (-1) 1)
-- | Interpret a flat storable vector as a value of shape @x@.
--
-- Returns 'Nothing' when the vector's length does not match the shape.
-- Matrix-backed shapes read the vector in row-major order, split into rows
-- of @columns@ elements. The 3-D case reads it as a @(rows * depth)@ by
-- @columns@ matrix.
fromStorable :: forall x. SingI x => Vector Double -> Maybe (S x)
fromStorable xs =
  case sing :: Sing x of
    D1Sing len ->
      withKnownNat len $ fmap S1D (H.create xs)
    D2Sing rs cs ->
      withKnownNat rs $ withKnownNat cs $ fmap S2D (toMatrix xs)
    D3Sing rs cs ds ->
      withKnownNat rs $ withKnownNat cs $ withKnownNat ds $ fmap S3D (toMatrix xs)
  where
    -- Reshape into a statically sized matrix after checking the element count.
    toMatrix :: forall rows columns. (KnownNat rows, KnownNat columns)
             => Vector Double -> Maybe (L rows columns)
    toMatrix v
      | nRows * nCols /= V.length v = Nothing
      | otherwise                   = H.create (NLA.reshape nCols v)
      where
        nRows = fromIntegral (natVal (Proxy :: Proxy rows))
        nCols = fromIntegral (natVal (Proxy :: Proxy columns))
-- | Apply a shape-preserving unary function to whatever 'R' / 'L' value the
-- 'S' wraps, and rewrap it in the same constructor.
n1 :: ( forall a. Floating a => a -> a ) -> S x -> S x
n1 f s = case s of
  S1D v -> S1D (f v)
  S2D m -> S2D (f m)
  S3D m -> S3D (f m)
-- | Apply a binary function to two values of the same shape. Because both
-- arguments share the index @x@, they always use the same constructor.
n2 :: ( forall a. Floating a => a -> a -> a ) -> S x -> S x -> S x
n2 f l r = case (l, r) of
  (S1D a, S1D b) -> S1D (f a b)
  (S2D a, S2D b) -> S2D (f a b)
  (S3D a, S3D b) -> S3D (f a b)
-- | A value of shape @x@ in which every entry equals the given 'Double'.
-- The shape comes from the 'SingI' singleton.
nk :: forall x. SingI x => Double -> S x
nk v =
  case sing :: Sing x of
    D1Sing len ->
      withKnownNat len (S1D (konst v))
    D2Sing rs cs ->
      withKnownNat rs (withKnownNat cs (S2D (konst v)))
    D3Sing rs cs ds ->
      withKnownNat rs (withKnownNat cs (withKnownNat ds (S3D (konst v))))