{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module TensorSafe.Shape where
import Data.Singletons
import GHC.TypeLits as N
import TensorSafe.Core
-- | The tensor shapes understood by TensorSafe.
--
-- The constructors are used promoted (via @DataKinds@) as kind-level
-- indices, e.g. of 'S' and of the 'Sing' data instance in this module.
data Shape where
  -- | One-dimensional shape with the given length.
  D1 :: Nat -> Shape
  -- | Two-dimensional shape: rows, then columns.
  D2 :: Nat -> Nat -> Shape
  -- | Three-dimensional shape: rows, columns, then depth.
  D3 :: Nat -> Nat -> Nat -> Shape
-- | A concrete value of a given 'Shape', indexed by that shape at the
-- type level.
--
-- Each constructor carries 'KnownNat' evidence for its dimensions, so
-- pattern matching on an @S n@ brings those constraints into scope.
--
-- NOTE: the implicit type-variable order of each constructor follows the
-- order in which the variables first appear in its constraint tuple, which
-- is what @TypeApplications@ callers see; keep that order stable.
data S (n :: Shape) where
  -- | One-dimensional value wrapping an @R len@ (type from
  -- "TensorSafe.Core").
  S1D :: ( KnownNat len )
      => R len
      -> S ('D1 len)
  -- | Two-dimensional value wrapping an @L rows columns@ (type from
  -- "TensorSafe.Core").
  S2D :: ( KnownNat rows, KnownNat columns )
      => L rows columns
      -> S ('D2 rows columns)
  -- | Three-dimensional value stored flattened in a single @L@ whose first
  -- index is @rows * depth@ and whose second index is @columns@. The
  -- extra @KnownNat (rows N.* depth)@ constraint supplies evidence for that
  -- computed first index.
  S3D :: ( KnownNat rows
         , KnownNat columns
         , KnownNat depth
         , KnownNat (rows N.* depth))
      => L (rows N.* depth) columns
      -> S ('D3 rows columns depth)

-- Standalone deriving is needed because 'S' is a GADT. It relies on 'Show'
-- instances for @R@ and @L@ being in scope (presumably via
-- "TensorSafe.Core"; confirm).
deriving instance Show (S n)
-- | Singletons for 'Shape'.
--
-- Each constructor holds the singletons of the shape's dimensions, together
-- with 'KnownNat' evidence for each dimension.
--
-- NOTE(review): unlike 'S3D', 'D3Sing' does not carry
-- @KnownNat (a N.* c)@ for the flattened row count; code that needs it must
-- obtain it separately.
data instance Sing (n :: Shape) where
  -- | Singleton for @'D1 a@.
  D1Sing :: KnownNat a => Sing a -> Sing ('D1 a)
  -- | Singleton for @'D2 a b@.
  D2Sing :: (KnownNat a, KnownNat b) => Sing a -> Sing b -> Sing ('D2 a b)
  -- | Singleton for @'D3 a b c@.
  D3Sing :: (KnownNat a, KnownNat b, KnownNat c) => Sing a -> Sing b -> Sing c -> Sing ('D3 a b c)
-- | Implicit singleton for a one-dimensional shape, built from the
-- singleton of its length.
instance KnownNat len
      => SingI ('D1 len) where
  sing = D1Sing sing
-- | Implicit singleton for a two-dimensional shape, built from the
-- singletons of its row and column counts.
instance ( KnownNat rows
         , KnownNat cols
         ) => SingI ('D2 rows cols) where
  sing = D2Sing sing sing
-- | Implicit singleton for a three-dimensional shape, built from the
-- singletons of its row, column and depth counts.
instance ( KnownNat rows
         , KnownNat cols
         , KnownNat depth
         ) => SingI ('D3 rows cols depth) where
  sing = D3Sing sing sing sing
-- | Type-level equality test on 'Shape's.
--
-- Reduces to @'True@ when both arguments are the same shape. Because this
-- is a closed family, the fallback equation yields @'False@ only once GHC
-- can show the two shapes are apart. Neither equation raises a type error.
type family ShapeEquals (sIn :: Shape) (sOut :: Shape) :: Bool where
  ShapeEquals same same = 'True
  ShapeEquals _lhs _rhs = 'False
-- | Strict variant of a shape-equality check.
--
-- Reduces to @'True@ when both shapes are identical. Otherwise it reports a
-- custom compile-time error naming both shapes instead of reducing to
-- @'False@.
type family ShapeEquals' (sIn :: Shape) (sOut :: Shape) :: Bool where
  ShapeEquals' same same = 'True
  ShapeEquals' lhs rhs =
    TypeError
      (     'Text "Couldn't match the Shape "
      ':<>: 'ShowType lhs
      ':<>: 'Text " with the Shape "
      ':<>: 'ShowType rhs
      )