module NumHask.Array.Constraints
( IsValidConcat
, Squeeze
, Concatenate
, IsValidTranspose
, DimShuffle
, dimShuffle
, Fold
, FoldAlong
, TailModule
, HeadModule
, Transpose
) where
import Data.Singletons.Prelude hiding (Max)
import Data.Singletons.Prelude.List
((:!!$), Drop, Filter, Head, Insert,
Length, Minimum, SplitAt, Sum, Take, ZipWith,)
import Data.Singletons.Prelude.Tuple (Fst, Snd)
import Data.Singletons.TH (promote)
import Data.Singletons.TypeLits (Nat)
import qualified Protolude as P
import GHC.Err (error)
-- Orphan Eq and Ord instances for Nat, compiled only when the compiler
-- version macro tested below is less than 801.
-- NOTE(review): presumably these exist only so that value-level code which
-- compares Nat values type-checks before it is promoted -- confirm.
#if ( __GLASGOW_HASKELL__ < 801 )
-- (==) and (/=) are each written as the negation of the other, so neither
-- bottoms out in a primitive comparison.
instance P.Eq Nat where
x == y = P.not (x P./= y)
x /= y = P.not (x P.== y)
-- Every method here is expressed through other methods of these same two
-- instances; compare is not given explicitly.
-- NOTE(review): (>) and (<) both begin with "not (x /= y)", which demands
-- equality and looks inverted for a strict ordering -- confirm whether these
-- bodies are ever meant to be evaluated.
instance P.Ord Nat where
x > y = P.not (x P./= y) P.&& P.not (x P.< y)
x < y = P.not (x P./= y) P.&& P.not (x P.> y)
x <= y = (x P.== y) P.|| P.not (x P.> y)
x >= y = (x P.== y) P.|| P.not (x P.< y)
#endif
-- | Zero-based list indexing with a 'Nat' position.
--
-- Steps down the list, decrementing the index once per cell, and returns the
-- element reached when the index hits zero. Calls 'error' if the list runs
-- out first.
(!!) :: [a] -> Nat -> a
[] !! _ = error "Data.Singletons.List.!!: index too large"
(y : ys) !! k =
  case k P.== 0 of
    P.True -> y
    P.False -> ys !! (k P.- 1)
-- | Remove the element at zero-based position @d@ from a type-level list.
--
-- The result is the first @d@ elements followed by everything after
-- position @d@. This is the same zero-based convention that 'Fold' and
-- 'Concatenate' use for their axis argument, so 'IsValidConcat' compares
-- exactly the dimensions that 'Concatenate' leaves untouched.
--
-- The previous definition removed position @d - 1@ for every @d >= 1@, so
-- both @DropDim 0@ and @DropDim 1@ dropped the head of the list.
--
-- If @d@ is not a valid position the list is returned unchanged.
type family DropDim d a :: [b] where
  DropDim d xs = Take d xs :++ Drop (d :+ 1) xs
-- | Decide whether two shapes may be concatenated along dimension @i@.
--
-- Evaluates to @False@ when either shape is the empty list. Otherwise both
-- shapes must have the same number of dimensions, and every dimension that
-- remains after applying @DropDim i@ to each shape must match pairwise.
--
-- The explicit length comparison is required because 'ZipWith' stops at the
-- end of the shorter list: without it, shapes of different rank would be
-- accepted whenever their common prefix happened to agree.
type family IsValidConcat i (a :: [Nat]) (b :: [Nat]) :: P.Bool where
  IsValidConcat _ '[] _ = 'P.False
  IsValidConcat _ _ '[] = 'P.False
  IsValidConcat i a b =
    (Length a :== Length b)
      :&& And (ZipWith (:==$) (DropDim i a) (DropDim i b))
-- | Discard every dimension equal to @1@ from a shape.
--
-- The empty shape maps to itself; any other shape is filtered so that only
-- the entries different from @1@ are kept, in their original order.
type family Squeeze (a :: [Nat]) where
  Squeeze '[] = '[]
  Squeeze shape = Filter ((:/=$$) 1) shape
-- | Type-level predicate relating a list @p@ to a shape @a@.
--
-- It is the conjunction of four checks: the minimum of @p@ is at least zero,
-- the minimum of @a@ is at least zero, the two lists have equal sums, and
-- the two lists have equal lengths.
--
-- NOTE(review): comparing the sum of the shape with the sum of @p@ only
-- makes sense if @p@ is itself a rearranged shape rather than a list of axis
-- indices -- confirm against the call sites.
type family IsValidTranspose (p :: [Nat]) (a :: [Nat]) :: P.Bool where
  IsValidTranspose perm shape =
    (Minimum perm :>= 0)
      :&& (Minimum shape :>= 0)
      :&& (Sum shape :== Sum perm)
      :&& (Length perm :== Length shape)
-- | Reverse the order of the entries of a type-level list.
type family Transpose a where
  Transpose dims = Reverse dims
-- | Add the dimension @d@ to the list @t@ by delegating to 'Insert'.
--
-- NOTE(review): if 'Insert' is the promoted form of the ordered list insert,
-- @d@ lands at its sorted position rather than at a chosen axis -- confirm
-- that this is the intended placement.
type family AddDimension (d :: Nat) t :: [Nat] where
  AddDimension dim shape = Insert dim shape
-- | Shape obtained by joining shapes @a@ and @b@ along zero-based
-- dimension @i@.
--
-- The result is built from three pieces:
--
-- * the first @i@ dimensions, taken from @a@;
-- * one dimension equal to the sum of the entries at position @i@ in
--   @a@ and @b@;
-- * the dimensions after position @i@, taken from @b@.
type family Concatenate i (a :: [Nat]) (b :: [Nat]) :: [Nat] where
  Concatenate axis lhs rhs =
    Take axis lhs
      :++ '[ Head (Drop axis lhs) :+ Head (Drop axis rhs)]
      :++ Drop (axis :+ 1) rhs
-- | Replace the entry at zero-based position @d@ of a shape with @1@.
--
-- The entries before and after that position are kept as they are. The
-- empty shape maps to itself.
type family FoldAlong i (s :: [Nat]) where
  FoldAlong _ '[] = '[]
  FoldAlong axis dims = Take axis dims :++ '[ 1] :++ Drop (axis :+ 1) dims
-- | Remove the entry at zero-based position @d@ from a shape.
--
-- The entries before and after that position are concatenated. The empty
-- shape maps to itself.
type family Fold i (s :: [Nat]) where
  Fold _ '[] = '[]
  Fold axis dims = Take axis dims :++ Drop (axis :+ 1) dims
-- | The part of a shape that remains after skipping its first @d@ entries.
--
-- The empty shape maps to itself.
type family TailModule i (s :: [Nat]) where
  TailModule _ '[] = '[]
  TailModule count dims = Drop count dims
-- | The first @d@ entries of a shape.
--
-- The empty shape maps to itself.
type family HeadModule i (s :: [Nat]) where
  HeadModule _ '[] = '[]
  HeadModule count dims = Take count dims
-- dimShuffle rearranges the entries of a list according to a list of
-- zero-based positions: entry k of the result is the element of the input
-- found at the k-th position given. The splice keeps the value-level
-- function and also generates its promoted, type-level counterpart.
--
-- An empty input list or an empty position list gives an empty result. A
-- position past the end of the input falls through to the error raised by
-- the indexing operator.
--
-- Every position is looked up in the full, unchanged input list. The earlier
-- version indexed into the tail that was left after each recursive step, so
-- the positions [1, 0] applied to [a, b] produced [b, b], and the identity
-- positions [0, 1] ran off the end of the list.
--
-- The Eq constraint is no longer needed by the body; it is kept so that the
-- signature seen by existing callers does not change.
$(promote
    [d|
      dimShuffle :: P.Eq a => [a] -> [Nat] -> [a]
      dimShuffle _ [] = []
      dimShuffle [] _ = []
      dimShuffle (x : xs) (b : bs) = ((x : xs) !! b) : dimShuffle (x : xs) bs
      |])