{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_GHC -fno-warn-unused-top-binds #-}

-- | Functions for manipulating shape. The module tends to supply equivalent functionality at type-level and value-level with functions of the same name (except for capitalization).
module NumHask.Array.Shape
  ( Shape (..),
    HasShape (..),
    type (++),
    type (!!),
    Take,
    Drop,
    Reverse,
    ReverseGo,
    Filter,
    rank,
    Rank,
    ranks,
    Ranks,
    size,
    Size,
    dimension,
    Dimension,
    flatten,
    shapen,
    minimum,
    Minimum,
    checkIndex,
    CheckIndex,
    checkIndexes,
    CheckIndexes,
    addIndex,
    AddIndex,
    dropIndex,
    DropIndex,
    posRelative,
    PosRelative,
    PosRelativeGo,
    addIndexes,
    AddIndexes,
    AddIndexesGo,
    dropIndexes,
    DropIndexes,
    takeIndexes,
    TakeIndexes,
    exclude,
    Exclude,
    concatenate',
    Concatenate,
    CheckConcatenate,
    Insert,
    CheckInsert,
    reorder',
    Reorder,
    CheckReorder,
    squeeze',
    Squeeze,
    incAt,
    decAt,
    KnownNats (..),
    KnownNatss (..),
  )
where

import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
import GHC.TypeLits as L
import NumHask.Prelude as P hiding (Last, minimum)

-- $setup
-- >>> :set -XDataKinds
-- >>> :set -XOverloadedLists
-- >>> :set -XTypeFamilies
-- >>> :set -XFlexibleContexts
-- >>> :set -XRebindableSyntax
-- >>> import NumHask.Prelude

-- | A shape indexed at the type level by a promoted @[Nat]@ (the phantom /s/),
-- wrapping the corresponding value-level @[Int]@ in 'shapeVal'.
--
-- Using [Int] as the index for an array nicely represents the practical interests and constraints downstream of this high-level API: densely-packed numbers (reals or integrals), indexed and layered.
newtype Shape (s :: [Nat]) = Shape {shapeVal :: [Int]}
  deriving (Show)

-- | Shapes known at the type level that can be reflected into a value-level 'Shape'.
class HasShape s where
  toShape :: Shape s

-- | The empty (scalar) shape reflects to @[]@.
instance HasShape '[] where
  toShape = Shape []

-- | A non-empty shape reflects its outermost dimension with 'natVal' and
-- recurses on the remaining dimensions.
instance (KnownNat n, HasShape s) => HasShape (n ': s) where
  toShape = Shape (outer : inner)
    where
      outer = fromInteger (natVal (Proxy :: Proxy n))
      inner = shapeVal (toShape :: Shape s)

-- | Number of dimensions: the length of a shape (or of any list).
rank :: [a] -> Int
rank dims = length dims
{-# INLINE rank #-}

-- | Type-level 'rank': the length of a promoted list.
type family Rank (s :: [a]) :: Nat where
  Rank '[] = 0
  Rank (_ ': rest) = Rank rest + 1

-- | The rank (length) of each list in a list of element indexes.
ranks :: [[a]] -> [Int]
ranks = map length
{-# INLINE ranks #-}

-- | Type-level 'ranks': the 'Rank' of each inner promoted list.
type family Ranks (s :: [[a]]) :: [Nat] where
  Ranks '[] = '[]
  Ranks (inner ': outer) = Rank inner ': Ranks outer

-- | Number of elements held by an array of the given shape: the product of
-- its dimensions, with the scalar shape @[]@ holding one element.
--
-- >>> size [2,3,4]
-- 24
size :: [Int] -> Int
size dims = case dims of
  [] -> 1
  [d] -> d
  _ -> product dims
{-# INLINE size #-}

-- | Type-level 'size': the product of the dimensions, 1 for the scalar shape.
type family Size (s :: [Nat]) :: Nat where
  Size '[] = 1
  Size (d ': ds) = d L.* Size ds

-- | convert from n-dim shape index to a flat (row-major) index
--
-- >>> flatten [2,3,4] [1,1,1]
-- 17
--
-- >>> flatten [] [1,1,1]
-- 0
flatten :: [Int] -> [Int] -> Int
flatten [] _ = 0
flatten _ [i] = i
flatten dims idx = sum (zipWith (*) idx strides)
  where
    -- the stride of each dimension is the product of all dimensions to its right
    strides = drop 1 (scanr (*) 1 dims)
{-# INLINE flatten #-}

-- | convert from a flat (row-major) index to a shape index
--
-- >>> shapen [2,3,4] 17
-- [1,1,1]
shapen :: [Int] -> Int -> [Int]
shapen [] _ = []
shapen [_] flat = [flat]
shapen [_, inner] flat = [q, r]
  where
    (q, r) = flat `divMod` inner
shapen dims flat = fst (foldr peel ([], flat) dims)
  where
    -- Working from the innermost dimension outwards: the remainder is this
    -- dimension's index and the quotient is carried to the next one out.
    peel d (idxs, rest) =
      let (rest', i) = rest `divMod` d
       in (i : idxs, rest')
{-# INLINE shapen #-}

-- | /checkIndex i n/ checks if /i/ is a valid index of a list of length /n/,
-- i.e. /i/ is non-negative and @i + 1 <= n@.
--
-- >>> checkIndex 3 3
-- False
checkIndex :: Int -> Int -> Bool
checkIndex i n
  | i < 0 = False
  | otherwise = i + 1 <= n

-- | Type-level 'checkIndex': reduces to 'True for a valid index and to a
-- custom type error otherwise.
type family CheckIndex (i :: Nat) (n :: Nat) :: Bool where
  CheckIndex idx len =
    If ((0 <=? idx) && (idx + 1 <=? len)) 'True (L.TypeError ('Text "index outside range"))

-- | /checkIndexes is n/ checks that every index in /is/ is valid for a list of length /n/
checkIndexes :: [Int] -> Int -> Bool
checkIndexes is n = all valid is
  where
    -- same test as 'checkIndex'
    valid i = 0 <= i && i + 1 <= n

-- | Type-level 'checkIndexes': 'CheckIndex' applied to every position.
type family CheckIndexes (i :: [Nat]) (n :: Nat) :: Bool where
  CheckIndexes '[] _ = 'True
  CheckIndexes (p ': ps) len = CheckIndex p len && CheckIndexes ps len

-- | /dimension s i/ is the /i/'th (zero-based) dimension of shape /s/.
--
-- Throws a 'NumHaskException' ("dimension overflow") when /i/ does not
-- index an element of /s/.
--
-- >>> dimension [2,3,4] 1
-- 3
dimension :: [Int] -> Int -> Int
dimension (d : ds) i
  | i == 0 = d
  | otherwise = dimension ds (i - 1)
dimension [] _ = throw (NumHaskException "dimension overflow")

-- | Type-level 'dimension'; a type error when the index runs past the shape.
type family Dimension (s :: [Nat]) (i :: Nat) :: Nat where
  Dimension (d ': _) 0 = d
  Dimension (_ ': ds) i = Dimension ds (i - 1)
  Dimension _ _ = L.TypeError ('Text "dimension overflow")

-- | minimum value in a list
--
-- Throws a 'NumHaskException' ("dimension underflow") on an empty list.
minimum :: [Int] -> Int
minimum [] = throw (NumHaskException "dimension underflow")
minimum (x : xs) = foldr min x xs

-- | Type-level 'minimum'; a type error on the empty list.
type family Minimum (s :: [Nat]) :: Nat where
  Minimum '[] = L.TypeError ('Text "zero dimension")
  Minimum '[d] = d
  Minimum (d ': ds) = If (d <=? Minimum ds) d (Minimum ds)

-- | Type-level 'take': the first /n/ elements of a promoted list.
-- Does not reduce when /n/ exceeds the list's length.
type family Take (n :: Nat) (a :: [k]) :: [k] where
  Take 0 _ = '[]
  Take n (y ': ys) = y ': Take (n - 1) ys

-- | Type-level 'drop': the list without its first /n/ elements.
-- Does not reduce when /n/ exceeds the list's length.
type family Drop (n :: Nat) (a :: [k]) :: [k] where
  Drop 0 ys = ys
  Drop n (_ ': ys) = Drop (n - 1) ys

-- | Type-level tail; a type error on the empty list.
type family Tail (a :: [k]) :: [k] where
  Tail '[] = L.TypeError ('Text "No tail")
  Tail (_ ': ys) = ys

-- | Type-level init; a type error on the empty list.
type family Init (a :: [k]) :: [k] where
  Init '[] = L.TypeError ('Text "No init")
  Init '[_] = '[]
  Init (y ': ys) = y ': Init ys

-- | Type-level head; a type error on the empty list.
type family Head (a :: [k]) :: k where
  Head '[] = L.TypeError ('Text "No head")
  Head (y ': _) = y

-- | Type-level last; a type error on the empty list.
type family Last (a :: [k]) :: k where
  Last '[] = L.TypeError ('Text "No last")
  Last '[y] = y
  Last (_ ': ys) = Last ys

-- | Type-level list append.
type family (a :: [k]) ++ (b :: [k]) :: [k] where
  '[] ++ ys = ys
  (x ': xs) ++ ys = x ': (xs ++ ys)

-- | drop the i'th dimension from a shape; an out-of-range /i/ leaves the shape unchanged
--
-- >>> dropIndex [2, 3, 4] 1
-- [2,4]
dropIndex :: [Int] -> Int -> [Int]
dropIndex s i = [d | (j, d) <- zip [0 ..] s, j /= i]

-- | Type-level 'dropIndex'.
type DropIndex shape pos = Take pos shape ++ Drop (pos + 1) shape

-- | /addIndex s i d/ adds a new dimension /d/ to shape /s/ at position /i/
--
-- >>> addIndex [2,4] 1 3
-- [2,3,4]
addIndex :: [Int] -> Int -> Int -> [Int]
addIndex s i d = front ++ d : back
  where
    (front, back) = splitAt i s

-- | Type-level 'addIndex'.
type AddIndex shape pos dim = Take pos shape ++ (dim ': Drop pos shape)

-- | Type-level reverse of a promoted list.
type Reverse (xs :: [k]) = ReverseGo xs '[]

-- | Worker for 'Reverse': moves each element of the first list onto the front
-- of the accumulator.
type family ReverseGo (a :: [k]) (b :: [k]) :: [k] where
  ReverseGo '[] acc = acc
  ReverseGo (x ': xs) acc = ReverseGo xs (x ': acc)

-- | convert a list of positions that reference a final shape into positions relative to an accumulator.  Deletions are from the left and additions are from the right.
--
-- Each position is kept as-is, and every later position that is not below it
-- is shifted down by one.
--
-- deletions
--
-- >>> posRelative [0,1]
-- [0,0]
--
-- additions
--
-- >>> reverse (posRelative (reverse [1,0]))
-- [0,0]
posRelative :: [Int] -> [Int]
posRelative = reverse . go []
  where
    go :: [Int] -> [Int] -> [Int]
    go acc [] = acc
    go acc (x : xs) = go (x : acc) (map (shift x) xs)

    shift :: Int -> Int -> Int
    shift x y
      | y < x = y
      | otherwise = y - 1

-- | Type-level 'posRelative'.
type family PosRelative (s :: [Nat]) where
  PosRelative ps = PosRelativeGo ps '[]

-- | Worker for 'PosRelative': moves each position onto an accumulator, adjusting
-- the remaining positions with 'DecMap'; the accumulator is reversed at the end.
type family PosRelativeGo (ps :: [Nat]) (acc :: [Nat]) where
  PosRelativeGo '[] acc = Reverse acc
  PosRelativeGo (p ': ps) acc = PosRelativeGo (DecMap p ps) (p ': acc)

-- | /DecMap x ys/ keeps each element of /ys/ below /x/ and decrements the rest.
type family DecMap (x :: Nat) (ys :: [Nat]) :: [Nat] where
  DecMap _ '[] = '[]
  DecMap p (q ': qs) = If (q + 1 <=? p) q (q - 1) ': DecMap p qs

-- | drop dimensions of a shape according to a list of positions (where position refers to the initial shape)
--
-- >>> dropIndexes [2, 3, 4] [1, 0]
-- [4]
dropIndexes :: [Int] -> [Int] -> [Int]
dropIndexes s = foldl' dropIndex s . posRelative

-- | Type-level 'dropIndexes'.
type family DropIndexes (s :: [Nat]) (i :: [Nat]) where
  DropIndexes shape ps = DropIndexesGo shape (PosRelative ps)

-- | Worker for 'DropIndexes': drops the already-relativized positions one at a
-- time, left to right.
type family DropIndexesGo (s :: [Nat]) (i :: [Nat]) where
  DropIndexesGo shape '[] = shape
  DropIndexesGo shape (p ': ps) = DropIndexesGo (DropIndex shape p) ps

-- | insert a list of dimensions according to position and dimension lists.  Note that the list of positions references the final shape and not the initial shape.
--
-- Extra dimensions beyond the positions are ignored; running out of dimensions
-- before positions throws a 'NumHaskException' ("mismatched ranks").
--
-- >>> addIndexes [4] [1,0] [3,2]
-- [2,3,4]
addIndexes :: [Int] -> [Int] -> [Int] -> [Int]
addIndexes shape positions = insertAll shape relative
  where
    relative = reverse (posRelative (reverse positions))

    insertAll acc [] _ = acc
    insertAll acc (p : ps) (d : ds) = insertAll (addIndex acc p d) ps ds
    insertAll _ _ _ = throw (NumHaskException "mismatched ranks")

-- | Type-level 'addIndexes'.
type family AddIndexes (as :: [Nat]) (xs :: [Nat]) (ys :: [Nat]) where
  AddIndexes shape ps ds = AddIndexesGo shape (Reverse (PosRelative (Reverse ps))) ds

-- | Worker for 'AddIndexes': inserts each dimension at its relativized position.
type family AddIndexesGo (as :: [Nat]) (xs :: [Nat]) (ys :: [Nat]) where
  AddIndexesGo acc '[] _ = acc
  AddIndexesGo acc (p ': ps) (d ': ds) = AddIndexesGo (AddIndex acc p d) ps ds
  AddIndexesGo _ _ _ = L.TypeError ('Text "mismatched ranks")

-- | take the dimensions at the given positions, in the order the positions are listed.
--
-- >>> takeIndexes [2,3,4] [2,0]
-- [4,2]
takeIndexes :: [Int] -> [Int] -> [Int]
takeIndexes s i = [s !! p | p <- i]

-- | Type-level 'takeIndexes'; reduces to @'[]@ if either list is empty.
type family TakeIndexes (s :: [Nat]) (i :: [Nat]) where
  TakeIndexes '[] _ = '[]
  TakeIndexes _ '[] = '[]
  TakeIndexes shape (p ': ps) = (shape !! p) ': TakeIndexes shape ps

-- | Type-level list indexing; a type error when the index runs past the end.
type family (a :: [k]) !! (b :: Nat) :: k where
  '[] !! _ = L.TypeError ('Text "Index Underflow")
  (x ': _) !! 0 = x
  (_ ': xs) !! n = xs !! (n - 1)

-- | The positions @[0 .. n - 1]@ in ascending order.
type family Enumerate (n :: Nat) where
  Enumerate n = Reverse (EnumerateGo n)

-- | The positions @n - 1@ down to @0@, in descending order.
type family EnumerateGo (n :: Nat) where
  EnumerateGo 0 = '[]
  EnumerateGo n = (n - 1) ': EnumerateGo (n - 1)

-- | turn a list of included positions for a given rank into a list of excluded positions
--
-- >>> exclude 3 [1,2]
-- [0]
exclude :: Int -> [Int] -> [Int]
exclude r included = dropIndexes allPositions included
  where
    allPositions = [0 .. r - 1]

-- | Type-level 'exclude': given a rank /r/ and a list of included positions /i/,
-- the remaining (excluded) positions in ascending order. For example,
-- @Exclude 3 '[1, 2]@ reduces to @'[0]@, which matches @exclude 3 [1,2]@.
--
-- The candidate positions must be the ascending @[0 .. r - 1]@ ('Enumerate'),
-- because 'DropIndexes' removes entries by position. The descending
-- 'EnumerateGo' list drops the wrong entries: @Exclude 3 '[0]@ would give
-- @'[1, 0]@ instead of @'[1, 2]@.
type family Exclude (r :: Nat) (i :: [Nat]) where
  Exclude r i = DropIndexes (Enumerate r) i

-- | concatenate two shapes along dimension /i/: the /i/'th dimensions are summed
-- and all other dimensions are taken from the first shape.
--
-- >>> concatenate' 1 [2,3,4] [2,3,4]
-- [2,6,4]
concatenate' :: Int -> [Int] -> [Int] -> [Int]
concatenate' i s0 s1 = take i s0 ++ (joined : drop (i + 1) s0)
  where
    joined = dimension s0 i + dimension s1 i

-- | Type-level 'concatenate'': joins /s0/ and /s1/ along dimension /i/.
type Concatenate i s0 s1 = Take i s0 ++ (Dimension s0 i + Dimension s1 i ': Drop (i + 1) s0)

-- | Constraint that /s0/ and /s1/ can be concatenated along dimension /i/:
-- /i/ is a valid dimension of /s0/, both shapes agree once dimension /i/ is
-- dropped, and they have the same rank. The last parameter /s/ is unused.
type CheckConcatenate i s0 s1 s =
  ( CheckIndex i (Rank s0)
      && DropIndex s0 i == DropIndex s1 i
      && Rank s0 == Rank s1
  )
    ~ 'True

-- | Constraint that /dim/ is a valid dimension of /shape/ and /idx/ is a valid
-- index along that dimension.
type CheckInsert dim idx shape =
  (CheckIndex dim (Rank shape) && CheckIndex idx (Dimension shape dim)) ~ 'True

-- | /shape/ with its /dim/'th dimension increased by one.
type Insert dim shape = Take dim shape ++ (Dimension shape dim + 1 ': Drop (dim + 1) shape)

-- | /incAt d s/ increments the value at position /d/ of shape /s/ by one.
--
-- >>> incAt 1 [2,3,4]
-- [2,4,4]
incAt :: Int -> [Int] -> [Int]
incAt d s = before ++ (bumped : after)
  where
    before = take d s
    bumped = dimension s d + 1
    after = drop (d + 1) s

-- | /decAt d s/ decrements the value at position /d/ of shape /s/ by one.
--
-- >>> decAt 1 [2,3,4]
-- [2,2,4]
decAt :: Int -> [Int] -> [Int]
decAt d s = before ++ (lowered : after)
  where
    before = take d s
    lowered = dimension s d - 1
    after = drop (d + 1) s

-- | /reorder' s i/ reorders the dimensions of shape /s/ according to a list of positions /i/
--
-- >>> reorder' [2,3,4] [2,0,1]
-- [4,2,3]
reorder' :: [Int] -> [Int] -> [Int]
reorder' [] _ = []
reorder' s ds = map (dimension s) ds

-- | Type-level 'reorder''.
type family Reorder (s :: [Nat]) (ds :: [Nat]) :: [Nat] where
  Reorder '[] _ = '[]
  Reorder _ '[] = '[]
  Reorder shape (p ': ps) = Dimension shape p ': Reorder shape ps

-- | Constraint that /ds/ is a valid reordering for /s/: it has the same rank as
-- /s/ and every entry is a valid dimension index of /s/.
type family CheckReorder (ds :: [Nat]) (s :: [Nat]) where
  CheckReorder ps shape =
    If
      ( Rank ps == Rank shape
          && CheckIndexes ps (Rank shape)
      )
      'True
      (L.TypeError ('Text "bad dimensions"))
      ~ 'True

-- | remove 1's from a list (every element equal to 'one')
squeeze' :: (Eq a, Multiplicative a) => [a] -> [a]
squeeze' xs = [x | x <- xs, x /= one]

-- | Type-level 'squeeze'': removes every 1 from a shape.
type family Squeeze (a :: [Nat]) where
  Squeeze '[] = '[]
  Squeeze ds = Filter '[] ds 1

-- | /Filter acc xs v/ removes every element equal to /v/ from /xs/. Survivors
-- are pushed onto /acc/, which is reversed at the end.
type family Filter (r :: [Nat]) (xs :: [Nat]) (i :: Nat) where
  Filter acc '[] _ = Reverse acc
  Filter acc (y ': ys) v = Filter (If (y == v) acc (y ': acc)) ys v

-- unused but useful type-level functions

-- | Type-level quicksort ordered by 'Cmp'. Elements comparing 'LT' to the
-- pivot go to its left; 'GT' or 'EQ' go to its right.
type family Sort (xs :: [k]) :: [k] where
  Sort '[] = '[]
  Sort (pivot ': rest) = (Sort (SFilter 'FMin pivot rest) ++ '[pivot]) ++ Sort (SFilter 'FMax pivot rest)

-- | Which side of the pivot 'SFilter' keeps.
data Flag = FMin | FMax

-- | Open comparison family used by 'Sort'; no instances are declared in this module.
type family Cmp (a :: k) (b :: k) :: Ordering

-- | Partition helper for 'Sort': with 'FMin' keep elements below the pivot,
-- with 'FMax' keep elements at or above it.
type family SFilter (f :: Flag) (p :: k) (xs :: [k]) :: [k] where
  SFilter _ _ '[] = '[]
  SFilter 'FMin pivot (y ': ys) = If (Cmp y pivot == 'LT) (y ': SFilter 'FMin pivot ys) (SFilter 'FMin pivot ys)
  SFilter 'FMax pivot (y ': ys) = If (Cmp y pivot == 'GT || Cmp y pivot == 'EQ) (y ': SFilter 'FMax pivot ys) (SFilter 'FMax pivot ys)

-- | Type-level zip into promoted pairs.
type family Zip lst lst' where
  Zip xs ys = ZipWith '(,) xs ys -- Implemented as TF because #11375

-- | Type-level zipWith; stops at the shorter list.
type family ZipWith f lst lst' where
  ZipWith _ '[] _ = '[]
  ZipWith _ _ '[] = '[]
  ZipWith f (x ': xs) (y ': ys) = f x y ': ZipWith f xs ys

-- | First component of a promoted pair.
type family Fst a where
  Fst '(x, _) = x

-- | Second component of a promoted pair.
type family Snd a where
  Snd '(_, y) = y

-- | Type-level map.
type family FMap f lst where
  FMap _ '[] = '[]
  FMap f (x ': xs) = f x ': FMap f xs

-- | Reflect a type-level list of 'Nat's to a value-level @[Int]@.
class KnownNats (ns :: [Nat]) where
  natVals :: Proxy ns -> [Int]

-- | The empty list reflects to @[]@.
instance KnownNats '[] where
  natVals _ = []

-- | A cons reflects its head with 'natVal' (converted via 'fromInteger') and
-- recurses on its tail.
instance (KnownNat n, KnownNats ns) => KnownNats (n ': ns) where
  natVals _ = headVal : natVals (Proxy :: Proxy ns)
    where
      headVal = fromInteger (natVal (Proxy :: Proxy n))

-- | Reflect a type-level list of lists of 'Nat's to a value-level @[[Int]]@.
class KnownNatss (ns :: [[Nat]]) where
  natValss :: Proxy ns -> [[Int]]

-- | The empty list reflects to @[]@.
instance KnownNatss '[] where
  natValss _ = []

-- | A cons reflects its head with 'natVals' and recurses on its tail.
instance (KnownNats n, KnownNatss ns) => KnownNatss (n ': ns) where
  natValss _ = headVals : natValss (Proxy :: Proxy ns)
    where
      headVals = natVals (Proxy :: Proxy n)