{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-deprecations #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module NumHask.Array.Constraints
  ( IsValidConcat
  , Squeeze
  , Concatenate
  , IsValidTranspose
  , DimShuffle
  , dimShuffle
  , Fold
  , FoldAlong
  , TailModule
  , HeadModule
  , Transpose
  ) where

import Data.Singletons.Prelude hiding (Max)
import Data.Singletons.Prelude.List
       ((:!!$), Drop, Filter, Head, Insert,
        Length, Minimum, SplitAt, Sum, Take, ZipWith,)
import Data.Singletons.Prelude.Tuple (Fst, Snd)
import Data.Singletons.TH (promote)
import Data.Singletons.TypeLits (Nat)
import qualified Protolude as P
import GHC.Err (error)

#if ( __GLASGOW_HASKELL__ < 801 )
-- Orphan Eq and Ord instances for Nat, compiled only when
-- __GLASGOW_HASKELL__ is below 801.
--
-- NOTE(review): every method here is written in terms of other comparison
-- methods on the same type, so none of them bottoms out in a primitive
-- comparison.  Presumably these instances exist only to satisfy the type
-- checker for term-level code over Nat (such as the indexing helper below)
-- and are never evaluated -- confirm.
instance P.Eq Nat where
  -- Equality and inequality are defined as the negation of each other.
  x == y = P.not (x P./= y)
  x /= y = P.not (x P.== y)

-- NOTE(review): as written, > and < both require not (x /= y), i.e. that the
-- two arguments are equal, which looks inverted for a strict ordering --
-- confirm whether that matters for what appears to be a placeholder instance.
instance P.Ord Nat where
  x > y = P.not (x P./= y) P.&& P.not (x P.< y)
  x < y = P.not (x P./= y) P.&& P.not (x P.> y)
  -- The non-strict comparisons are expressed through == and the strict ones.
  x <= y = (x P.== y) P.|| P.not (x P.> y)
  x >= y = (x P.== y) P.|| P.not (x P.< y)
#endif

-- | Zero-based list indexing with a 'Nat' index.
--
-- Walks the list one cell at a time, decrementing the index until it reaches
-- zero.  Running off the end of the list calls 'error'.
(!!) :: [a] -> Nat -> a
list !! idx =
  case list of
    [] -> error "Data.Singletons.List.!!: index too large"
    (hd:tl) ->
      if idx P.== 0
        then hd
        else tl !! (idx P.- 1)

-- | Removes the element at zero-based index @d@ from the list @xs@.
--
-- Index @0@ is handled by its own equation (@Drop 1@).  For any other index
-- the result is the first @d@ elements followed by everything after
-- position @d@.
--
-- The second equation used to be built from @Take (d :- 1)@ over the first
-- @d@ elements, which removed index @d - 1@ instead of @d@ (so index 1
-- behaved exactly like index 0).
type family DropDim d a :: [b] where
  DropDim 0 xs = Drop 1 xs
  DropDim d xs = Take d xs :++ Drop (d :+ 1) xs

-- | Type-level check that shapes @a@ and @b@ may be concatenated along
-- axis @i@.
--
-- * An empty shape on either side is rejected.
-- * Both shapes must have the same length.  Without this check the
--   element-wise comparison below, which goes through @ZipWith@, would only
--   look at the common prefix and could accept shapes of different rank.
-- * After @DropDim i@ has been applied to both shapes, the remaining entries
--   must be pairwise equal.
type family IsValidConcat i (a :: [Nat]) (b :: [Nat]) :: P.Bool where
  IsValidConcat _ '[] _ = 'P.False
  IsValidConcat _ _ '[] = 'P.False
  IsValidConcat i a b =
    (Length a :== Length b) :&&
    And (ZipWith (:==$) (DropDim i a) (DropDim i b))

-- | Removes every entry equal to @1@ from a shape.
--
-- The empty shape is returned unchanged; any other shape is filtered so that
-- only the entries different from @1@ remain.
type family Squeeze (a :: [Nat]) where
  Squeeze '[] = '[]
  Squeeze dims = Filter ((:/=$$) 1) dims

-- | Type-level validity check relating two lists of naturals @p@ and @a@.
--
-- The result is the conjunction of four conditions:
--
-- * the minimum of @p@ is at least 0;
-- * the minimum of @a@ is at least 0;
-- * @a@ and @p@ have the same sum;
-- * @p@ and @a@ have the same length.
--
-- NOTE(review): presumably @p@ is an axis permutation and @a@ a shape;
-- comparing their sums looks unusual for that reading -- confirm intent.
type family IsValidTranspose (p :: [Nat]) (a :: [Nat]) :: P.Bool where
  IsValidTranspose ps ds =
    (Minimum ps :>= 0)
      :&& (Minimum ds :>= 0)
      :&& (Sum ds :== Sum ps)
      :&& (Length ps :== Length ds)

-- | Shape of a transposed array: the input list with its elements in
-- reverse order.
type family Transpose a where
  Transpose dims = Reverse dims
-- | Adds the dimension @d@ to the shape @t@ by delegating to @Insert@.
--
-- NOTE(review): @Insert@ is presumably the promoted @Data.List.insert@,
-- which positions the new element by ordering rather than at a chosen
-- axis -- confirm this is the intended behaviour.
type family AddDimension (d :: Nat) t :: [Nat] where
  AddDimension dim dims = Insert dim dims

-- | Shape obtained by concatenating shapes @a@ and @b@ along the zero-based
-- axis @i@.
--
-- The result consists of three pieces:
--
-- * the first @i@ entries of @a@;
-- * a single entry holding the sum of entry @i@ of @a@ and entry @i@ of @b@;
-- * the entries of @b@ that follow position @i@.
type family Concatenate i (a :: [Nat]) (b :: [Nat]) :: [Nat] where
  Concatenate axis lhs rhs =
    Take axis lhs :++
    '[ Head (Drop axis lhs) :+ Head (Drop axis rhs)] :++
    Drop (axis :+ 1) rhs

-- | Reduces axis @i@ of shape @s@ while keeping it as a singleton dimension:
-- the entry at zero-based position @i@ is replaced by @1@.
--
-- The empty shape is returned unchanged.
type family FoldAlong i (s :: [Nat]) where
  FoldAlong _ '[] = '[]
  FoldAlong axis dims =
    Take axis dims :++ '[ 1] :++ Drop (axis :+ 1) dims

-- | Reduces axis @i@ of shape @s@ without keeping a singleton dimension:
-- the entry at zero-based position @i@ is removed.
--
-- The empty shape is returned unchanged.
type family Fold i (s :: [Nat]) where
  Fold _ '[] = '[]
  Fold axis dims = Take axis dims :++ Drop (axis :+ 1) dims

-- | The part of shape @s@ that remains after its first @i@ entries have
-- been dropped.
--
-- The empty shape is returned unchanged.
type family TailModule i (s :: [Nat]) where
  TailModule _ '[] = '[]
  TailModule n dims = Drop n dims

-- | The first @i@ entries of shape @s@.
--
-- The empty shape is returned unchanged.
type family HeadModule i (s :: [Nat]) where
  HeadModule _ '[] = '[]
  HeadModule n dims = Take n dims

-- Term-level @dimShuffle@; the splice passes the declaration to @promote@
-- from "Data.Singletons.TH".
--
-- @dimShuffle xs perm@ rearranges @xs@ according to @perm@: the k-th element
-- of the result is the element of @xs@ found at the zero-based index given
-- by the k-th entry of @perm@.  An empty @xs@ or an empty @perm@ gives an
-- empty result.
--
-- Every index is looked up in the complete input list.  Previously the
-- lookup was performed on the tail left over after discarding one element
-- per step, so for example the permutation [1, 0] applied to [a, b] produced
-- [b, b], and the identity permutation [0, 1] ran past the end of the list.
$(promote
    [d|
  dimShuffle :: P.Eq a => [a] -> [Nat] -> [a]
  dimShuffle _ [] = []
  dimShuffle [] _ = []
  dimShuffle (x : xs) (b : bs)
    = ((x : xs) !! b) : dimShuffle (x : xs) bs
  |])