{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ParallelListComp      #-}
{-# LANGUAGE PatternSynonyms       #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TemplateHaskell       #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
{-# LANGUAGE ViewPatterns          #-}
-- |
-- Module      : Data.Array.Accelerate.Pattern
-- Copyright   : [2018..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Pattern (

  pattern Pattern,
  pattern T2,  pattern T3,  pattern T4,  pattern T5,  pattern T6,
  pattern T7,  pattern T8,  pattern T9,  pattern T10, pattern T11,
  pattern T12, pattern T13, pattern T14, pattern T15, pattern T16,

  pattern Z_, pattern Ix, pattern (::.),
  pattern I0, pattern I1, pattern I2, pattern I3, pattern I4,
  pattern I5, pattern I6, pattern I7, pattern I8, pattern I9,

  pattern V2, pattern V3, pattern V4, pattern V8, pattern V16,

) where

import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Representation.Vec
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Array
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Shape
import Data.Array.Accelerate.Sugar.Vec
import Data.Array.Accelerate.Type
import Data.Primitive.Vec

import Language.Haskell.TH                                          hiding ( Exp, Match, tupP, tupE )
import Language.Haskell.TH.Extra


-- | A pattern synonym for working with (product) data types. You can declare
-- your own pattern synonyms based off of this.
--
-- The synonym is bidirectional: matching on @Pattern vars@ applies 'destruct'
-- to the scrutinee and binds the result to @vars@, while using @Pattern@ as
-- an expression applies 'construct'. In both directions the class method is
-- instantiated at the @context@ type variable of the signature.
--
pattern Pattern :: forall b a context. IsPattern context a b => b -> context a
pattern Pattern vars <- (destruct @context -> vars)
  where Pattern = construct @context

-- | Relates a value of type @con a@ to a view of type @b@. The 'Pattern'
-- synonym is defined in terms of the two methods of this class.
--
class IsPattern con a b where
  -- | Build a @con a@ from its view of type @b@.
  construct :: b -> con a
  -- | Take a @con a@ apart into its view of type @b@.
  destruct  :: con a -> b


-- | A bidirectional pattern synonym analogous to 'Pattern', but going through
-- the 'IsVector' class: matching on @Vector vars@ applies 'vunpack' to the
-- scrutinee, and using @Vector@ as an expression applies 'vpack'. Both methods
-- are instantiated at the @context@ type variable of the signature.
--
pattern Vector :: forall b a context. IsVector context a b => b -> context a
pattern Vector vars <- (vunpack @context -> vars)
  where Vector = vpack @context

-- | Relates a value of type @context a@ to a view of type @b@. The 'Vector'
-- synonym is defined in terms of the two methods of this class.
--
class IsVector context a b where
  -- | Build a @context a@ from its view of type @b@.
  vpack   :: b -> context a
  -- | Take a @context a@ apart into its view of type @b@.
  vunpack :: context a -> b

-- | Pattern synonyms for indices, which may be more convenient to use than
-- 'Data.Array.Accelerate.Lift.lift' and
-- 'Data.Array.Accelerate.Lift.unlift'.
--
-- 'Z_' is the index of type 'DIM0', defined as 'Pattern' applied to 'Z'. It is
-- declared to be a complete match on its own.
--
pattern Z_ :: Exp DIM0
pattern Z_ = Pattern Z
{-# COMPLETE Z_ #-}

-- | Infix pattern synonym for the embedded counterpart of the ':.'
-- constructor: @a ::. b@ is 'Pattern' applied to @a :. b@, so it both builds
-- and takes apart an @Exp (a :. b)@. Left associative at precedence 3, and
-- declared to be a complete match on its own.
--
infixl 3 ::.
pattern (::.) :: (Elt a, Elt b) => Exp a -> Exp b -> Exp (a :. b)
pattern a ::. b = Pattern (a :. b)
{-# COMPLETE (::.) #-}

-- | Alphanumeric alias of '::.', with the same type and the same fixity
-- (left associative, precedence 3) when written in backticks. Declared to be a
-- complete match on its own.
--
infixl 3 `Ix`
pattern Ix :: (Elt a, Elt b) => Exp a -> Exp b -> Exp (a :. b)
pattern a `Ix` b = a ::. b
{-# COMPLETE Ix #-}

-- IsPattern instances for Shape nil and cons
--
instance IsPattern Exp Z Z where
  -- The argument carries no information: the result is always the constant Z.
  construct _ = constant Z
  -- Likewise the embedded term is never inspected.
  destruct _  = Z

instance (Elt a, Elt b) => IsPattern Exp (a :. b) (Exp a :. Exp b) where
  -- Unwrap both components and combine them with a single Pair node.
  construct (Exp a :. Exp b) = Exp $ SmartExp $ Pair a b
  -- Project the left and right components out of the same term `t`.
  destruct (Exp t)           = Exp (SmartExp $ Prj PairIdxLeft t) :. Exp (SmartExp $ Prj PairIdxRight t)


-- IsPattern instances for up to 16-tuples (Acc and Exp), and IsVector
-- instances (Exp only) for the sizes listed at the bottom of the splice. TH
-- takes care of the (unremarkable) boilerplate for us.
--
runQ $ do
    let
        -- Generate instance declarations for IsPattern of the form:
        -- instance (Arrays x, ArraysR x ~ (((), ArraysR a), ArraysR b), Arrays a, Arrays b) => IsPattern Acc x (Acc a, Acc b)
        mkAccPattern :: Int -> Q [Dec]
        mkAccPattern n = do
          a <- newName "a"
          let
              -- Type variables for the elements
              xs       = [ mkName ('x' : show i) | i <- [0 .. n-1] ]
              -- Last argument to `IsPattern`, eg (Acc a, Acc b) in the example
              b        = tupT (map (\t -> [t| Acc $(varT t)|]) xs)
              -- Representation as snoc-list of pairs, eg (((), ArraysR a), ArraysR b)
              snoc     = foldl (\sn t -> [t| ($sn, ArraysR $(varT t)) |]) [t| () |] xs
              -- Constraints for the type class, consisting of Arrays constraints on all type variables,
              -- and an equality constraint on the representation type of `a` (written `x` in the
              -- example above) and the snoc representation `snoc`.
              context  = tupT
                       $ [t| Arrays $(varT a) |]
                       : [t| ArraysR $(varT a) ~ $snoc |]
                       : map (\t -> [t| Arrays $(varT t)|]) xs
              -- `get x i` builds the projection of one component of the snoc-list `x`:
              -- step into the left of the pair `i` times, then take the right element.
              -- Index 0 is therefore the last tuple component and n-1 the first.
              get x 0 = [| Acc (SmartAcc (Aprj PairIdxRight $x)) |]
              get x i = get  [| SmartAcc (Aprj PairIdxLeft $x) |] (i-1)
          -- `construct` snocs every component onto Anil with Apair; `destruct`
          -- projects the components back out, first component first.
          _x <- newName "_x"
          [d| instance $context => IsPattern Acc $(varT a) $b where
                construct $(tupP (map (\x -> [p| Acc $(varP x)|]) xs)) =
                  Acc $(foldl (\vs v -> [| SmartAcc ($vs `Apair` $(varE v)) |]) [| SmartAcc Anil |] xs)
                destruct (Acc $(varP _x)) =
                  $(tupE (map (get (varE _x)) [(n-1), (n-2) .. 0]))
            |]

        -- Generate instance declarations for IsPattern of the form:
        -- instance (Elt x, EltR x ~ (((), EltR a), EltR b), Elt a, Elt b) => IsPattern Exp x (Exp a, Exp b)
        mkExpPattern :: Int -> Q [Dec]
        mkExpPattern n = do
          a <- newName "a"
          let
              -- Type variables for the elements
              xs       = [ mkName ('x' : show i) | i <- [0 .. n-1] ]
              -- Variables for sub-pattern matches
              ms       = [ mkName ('m' : show i) | i <- [0 .. n-1] ]
              -- Pattern on a tag: TagRunit with one TagRpair snoc'd on per component,
              -- binding the tag of component i to the variable `ms !! i`
              tags     = foldl (\ts t -> [p| $ts `TagRpair` $(varP t) |]) [p| TagRunit |] ms
              -- Last argument to `IsPattern`, eg (Exp a, Exp b) in the example
              b        = tupT (map (\t -> [t| Exp $(varT t)|]) xs)
              -- Representation as snoc-list of pairs, eg (((), EltR a), EltR b)
              snoc     = foldl (\sn t -> [t| ($sn, EltR $(varT t)) |]) [t| () |] xs
              -- Constraints for the type class, consisting of Elt constraints on all type variables,
              -- and an equality constraint on the representation type of `a` (written `x` in the
              -- example above) and the snoc representation `snoc`.
              context  = tupT
                       $ [t| Elt $(varT a) |]
                       : [t| EltR $(varT a) ~ $snoc |]
                       : map (\t -> [t| Elt $(varT t)|]) xs
              -- `get x i` builds the projection of one component of the snoc-list `x`:
              -- step into the left of the pair `i` times, then take the right element.
              -- Index 0 is therefore the last tuple component and n-1 the first.
              get x 0 =     [| SmartExp (Prj PairIdxRight $x) |]
              get x i = get [| SmartExp (Prj PairIdxLeft $x)  |] (i-1)
          -- In the generated instance, `construct` strips one outer `Match`
          -- wrapper (if present) from each component before snoc-ing it onto
          -- Nil with Pair. `destruct` checks whether the term is a `Match`
          -- whose tag has the shape described by `tags`; if so, every projected
          -- component is wrapped in a `Match` carrying its own sub-tag,
          -- otherwise the plain projections are returned.
          -- NOTE(review): `_y` is bound by the `Match` pattern in `destruct`
          -- but is not used there; the projections are taken from `_x` (the
          -- whole term) -- confirm this is intended.
          _x <- newName "_x"
          _y <- newName "_y"
          [d| instance $context => IsPattern Exp $(varT a) $b where
                construct $(tupP (map (\x -> [p| Exp $(varP x)|]) xs)) =
                  let _unmatch :: SmartExp a -> SmartExp a
                      _unmatch (SmartExp (Match _ $(varP _y))) = $(varE _y)
                      _unmatch x = x
                  in
                  Exp $(foldl (\vs v -> [| SmartExp ($vs `Pair` _unmatch $(varE v)) |]) [| SmartExp Nil |] xs)
                destruct (Exp $(varP _x)) =
                  case $(varE _x) of
                    SmartExp (Match $tags $(varP _y))
                      -> $(tupE [[| Exp (SmartExp (Match $(varE m) $(get (varE _x) i))) |] | m <- ms | i <- [(n-1), (n-2) .. 0]])
                    _ -> $(tupE [[| Exp $(get (varE _x) i) |] | i <- [(n-1), (n-2) .. 0]])
            |]

        -- Generate instance declarations for IsVector of the form:
        -- instance (Elt v, VecElt a, EltR v ~ Vec 2 a) => IsVector Exp v (Exp a, Exp a)
        mkVecPattern :: Int -> Q [Dec]
        mkVecPattern n = do
          a <- newName "a"
          v <- newName "v"
          let
              -- Last argument to `IsVector`, eg (Exp a, Exp a) in the example
              tup      = tupT (replicate n ([t| Exp $(varT a)|]))
              -- Representation as a vector, eg (Vec 2 a)
              vec      = [t| Vec $(litT (numTyLit (fromIntegral n))) $(varT a) |]
              -- Constraints for the type class: Elt on the vector type `v`, VecElt on the element
              -- type `a`, and an equality constraint on the representation type of `v` and the
              -- vector representation `vec`.
              context  = [t| (Elt $(varT v), VecElt $(varT a), EltR $(varT v) ~ $vec) |]
              -- `n` applications of VecRsucc around VecRnil at the element's single type
              vecR     = foldr appE [| VecRnil (singleType @ $(varT a)) |] (replicate n [| VecRsucc |])
              -- The n-tuple (a, .., a); vpack and vunpack go through the IsPattern
              -- instance at `Exp` of this tuple type
              tR       = tupT (replicate n (varT a))
          --
          [d| instance $context => IsVector Exp $(varT v) $tup where
                vpack x = case construct x :: Exp $tR of
                            Exp x' -> Exp (SmartExp (VecPack $vecR x'))
                vunpack (Exp x) = destruct (Exp (SmartExp (VecUnpack $vecR x)) :: Exp $tR)
            |]
    -- Tuple instances are generated for every size from 0 to 16; vector
    -- instances only for the listed sizes.
    es <- mapM mkExpPattern [0..16]
    as <- mapM mkAccPattern [0..16]
    vs <- mapM mkVecPattern [2,3,4,8,16]
    return $ concat (es ++ as ++ vs)


-- | Specialised pattern synonyms for tuples, which may be more convenient to
-- use than 'Data.Array.Accelerate.Lift.lift' and
-- 'Data.Array.Accelerate.Lift.unlift'. For example, to construct a pair:
--
-- > let a = 4        :: Exp Int
-- > let b = 2        :: Exp Float
-- > let c = T2 a b   -- :: Exp (Int, Float); equivalent to 'lift (a,b)'
--
-- Similarly they can be used to destruct values:
--
-- > let T2 x y = c   -- x :: Exp Int, y :: Exp Float; equivalent to 'let (x,y) = unlift c'
--
-- These pattern synonyms can be used for both 'Exp' and 'Acc' terms.
--
-- Similarly, we have patterns for constructing and destructing indices of
-- a given dimensionality:
--
-- > let ix = I2 2 3    -- :: Exp DIM2
-- > let I2 y x = ix    -- y :: Exp Int, x :: Exp Int
--
runQ $ do
    let
        -- The names x0 .. x(n-1). Every generator below uses them both as the
        -- synonym's argument variables and as the type variables of its
        -- signature.
        argNames :: Int -> [Name]
        argNames n = map (\i -> mkName ('x' : show i)) [0 .. n-1]

        -- Tn: an n-ary prefix synonym defined as `Pattern` applied to the
        -- n-tuple of its arguments, with a COMPLETE pragma for each of Acc
        -- and Exp.
        mkT :: Int -> Q [Dec]
        mkT n = sequence
            [ patSynSigD synonym [t| IsPattern $con $whole $parts => $sig |]
            , patSynD synonym (prefixPatSyn vars) implBidir [p| Pattern $(tupP (map varP vars)) |]
            , pragCompleteD [synonym] (Just ''Acc)
            , pragCompleteD [synonym] (Just ''Exp)
            ]
          where
            vars    = argNames n
            tyVars  = map varT vars
            synonym = mkName ('T' : show n)
            con     = varT (mkName "con")
            -- (x0, .., xn) and (con x0, .., con xn) respectively
            whole   = tupT tyVars
            parts   = tupT (map (appT con) tyVars)
            -- con x0 -> .. -> con xn -> con (x0, .., xn)
            sig     = foldr (\arg res -> [t| $con $arg -> $res |]) (appT con whole) tyVars

        -- In: an n-ary prefix synonym defined as Z_ `Ix` x0 `Ix` .. `Ix` xn,
        -- with a COMPLETE pragma that names no result type.
        mkI :: Int -> Q [Dec]
        mkI n = sequence
            [ patSynSigD synonym [t| $elts => $sig |]
            , patSynD synonym (prefixPatSyn vars) implBidir snocPat
            , pragCompleteD [synonym] Nothing
            ]
          where
            vars    = argNames n
            tyVars  = map varT vars
            synonym = mkName ('I' : show n)
            -- One Elt constraint per argument type
            elts    = tupT [ [t| Elt $ty |] | ty <- tyVars ]
            -- Z :. x0 :. .. :. xn
            shape   = foldl (\sh ty -> [t| $sh :. $ty |]) [t| Z |] tyVars
            -- Exp x0 -> .. -> Exp xn -> Exp (Z :. x0 :. .. :. xn)
            sig     = foldr (\arg res -> [t| Exp $arg -> $res |]) [t| Exp $shape |] tyVars
            snocPat = foldl (\acc v -> infixP acc (mkName "Ix") (varP v)) [p| Z_ |] vars

        -- Vn: an n-ary prefix synonym defined as `Vector` applied to the
        -- n-tuple of its arguments, with a COMPLETE pragma for Exp.
        mkV :: Int -> Q [Dec]
        mkV n = sequence
            [ patSynSigD synonym [t| IsVector $con $vec $parts => $sig |]
            , patSynD synonym (prefixPatSyn vars) implBidir [p| Vector $(tupP (map varP vars)) |]
            , pragCompleteD [synonym] (Just ''Exp)
            ]
          where
            vars    = argNames n
            tyVars  = map varT vars
            synonym = mkName ('V' : show n)
            con     = varT (mkName "con")
            vec     = varT (mkName "vec")
            parts   = tupT (map (appT con) tyVars)
            -- con x0 -> .. -> con xn -> con vec
            sig     = foldr (\arg res -> [t| $con $arg -> $res |]) (appT con vec) tyVars

    -- T2..T16, I0..I9 and V2/V3/V4/V8/V16, emitted in that order.
    tuples  <- traverse mkT [2..16]
    indices <- traverse mkI [0..9]
    vectors <- traverse mkV [2,3,4,8,16]
    pure (concat (tuples ++ indices ++ vectors))