{-# LANGUAGE CPP, GADTs, TypeOperators, TypeFamilies, ScopedTypeVariables, RankNTypes #-}
{-# LANGUAGE FlexibleContexts, FlexibleInstances, MultiParamTypeClasses, TypeSynonymInstances #-}
{-# LANGUAGE DeriveDataTypeable, StandaloneDeriving, PatternGuards #-}
-- |
-- Module      : Data.Array.Accelerate.Smart
-- Copyright   : [2008..2011] Manuel M T Chakravarty, Gabriele Keller, Sean Lee
-- License     : BSD3
--
-- Maintainer  : Manuel M T Chakravarty <chak@cse.unsw.edu.au>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- This module defines the AST of the user-visible embedded language using more
-- convenient higher-order abstract syntax (instead of de Bruijn indices).
-- Moreover, it defines smart constructors to construct programs.
--

module Data.Array.Accelerate.Smart (

  -- * HOAS AST
  Acc(..), PreAcc(..), Exp, PreExp(..), Boundary(..), Stencil(..),

  -- * HOAS -> de Bruijn conversion
  convertAcc, convertAccFun1,

  -- * Smart constructors for pairing and unpairing
  pair, unpair,

  -- * Smart constructors for literals
  constant,

  -- * Smart constructors and destructors for tuples
  tup2, tup3, tup4, tup5, tup6, tup7, tup8, tup9,
  untup2, untup3, untup4, untup5, untup6, untup7, untup8, untup9,

  -- * Smart constructors for constants
  mkMinBound, mkMaxBound, mkPi,
  mkSin, mkCos, mkTan,
  mkAsin, mkAcos, mkAtan,
  mkAsinh, mkAcosh, mkAtanh,
  mkExpFloating, mkSqrt, mkLog,
  mkFPow, mkLogBase,
  mkTruncate, mkRound, mkFloor, mkCeiling,
  mkAtan2,

  -- * Smart constructors for primitive functions
  mkAdd, mkSub, mkMul, mkNeg, mkAbs, mkSig, mkQuot, mkRem, mkIDiv, mkMod,
  mkBAnd, mkBOr, mkBXor, mkBNot, mkBShiftL, mkBShiftR, mkBRotateL, mkBRotateR,
  mkFDiv, mkRecip, mkLt, mkGt, mkLtEq, mkGtEq, mkEq, mkNEq, mkMax, mkMin,
  mkLAnd, mkLOr, mkLNot,

  -- * Smart constructors for type coercion functions
  mkBoolToInt, mkFromIntegral,

  -- * Auxiliary functions
  ($$), ($$$), ($$$$), ($$$$$)

) where

-- standard library
import Control.Applicative                      hiding (Const)
import Control.Monad
import Data.HashTable                           as Hash
import Data.List
import Data.Maybe
import qualified Data.IntMap                    as IntMap
import Data.Typeable
import System.Mem.StableName
import System.IO.Unsafe                         (unsafePerformIO)
import Prelude                                  hiding (exp)

-- friends
import Data.Array.Accelerate.Debug
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Tuple              hiding (Tuple)
import Data.Array.Accelerate.AST                hiding (
  PreOpenAcc(..), OpenAcc(..), Acc, Stencil(..), PreOpenExp(..), OpenExp, PreExp, Exp)
import qualified Data.Array.Accelerate.Tuple    as Tuple
import qualified Data.Array.Accelerate.AST      as AST
import Data.Array.Accelerate.Pretty ()

#include "accelerate.h"


-- Configuration
-- -------------

-- Should array computations embedded in scalar expressions be floated out of those expressions
-- regardless of whether or not they are shared?  'True' means they are always floated out.
-- 'convertOpenAcc' passes this flag to 'recoverSharing'.
--
floatOutAccFromExp :: Bool
floatOutAccFromExp = True


-- Layouts
-- -------

-- Layout of an environment.  The second type index is the environment being laid out, built up
-- from '()' one entry at a time; for each of its entries the layout stores the de Bruijn index
-- (an 'Idx' into the first type index) that refers to the corresponding environment entry.
--
data Layout env lyt where
  -- A layout without entries.
  EmptyLayout :: Layout env ()
  -- One more entry.  The 'Typeable' evidence is what lets 'prjIdx' cast the stored index back
  -- to the type requested by its caller.
  PushLayout  :: Typeable t => Layout env lyt -> Idx env t -> Layout env (lyt, t)

-- Project the index stored at position @n@ of a layout, where position 0 is the most recently
-- pushed entry.  The stored index is cast to the requested type with 'gcast'.  An internal error
-- is raised if that cast fails or if the layout runs out of entries before position @n@.
--
prjIdx :: Typeable t => Int -> Layout env env' -> Idx env t
prjIdx n lyt
  = case lyt of
      EmptyLayout -> INTERNAL_ERROR(error) "prjIdx" "inconsistent valuation"
      PushLayout rest ix
        | n == 0    -> case gcast ix of
                         Just ix' -> ix'
                         Nothing  -> INTERNAL_ERROR(error) "prjIdx" "type mismatch"
        | otherwise -> prjIdx (n - 1) rest

-- Shift a layout underneath one additional environment entry.  The layout keeps its shape, but
-- every index it stores is wrapped in 'SuccIdx', so it still refers to the same entry of the
-- enlarged environment.
--
incLayout :: Layout env env' -> Layout (env, t) env'
incLayout lyt
  = case lyt of
      EmptyLayout        -> EmptyLayout
      PushLayout rest ix -> incLayout rest `PushLayout` SuccIdx ix


-- Array computations
-- ------------------

-- |Array-valued collective computations without a recursive knot
--
-- The parameter 'acc' is the representation of nested array computations.  Tying the knot with
-- 'Acc' yields the user-visible terms; sharing recovery instantiates it with 'SharingAcc'.  Scalar
-- expressions embedded in a node are 'PreExp acc' terms, whereas scalar functions receive their
-- arguments as plain 'Exp' terms.
--
-- Note [Pipe and sharing recovery]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-- The 'Pipe' constructor is special.  It is the only form that contains functions over array
-- computations and these functions are fixed to be over vanilla 'Acc' types.  This enables us to
-- perform sharing recovery independently from the context for them.
--
data PreAcc acc a where
  -- Needed for conversion to de Bruijn form: a placeholder for the argument of an array
  -- function.  'convertAccFun1' supplies 'Atag 0', and 'convertSharingAcc' resolves 'Atag i'
  -- with 'prjIdx i' in its current layout.
  Atag        :: Arrays as
              => Int                        -- environment size at defining occurrence
              -> PreAcc acc as

  -- 'Pipe f g acc': 'convertSharingAcc' let-binds the result of applying 'f' to 'acc' and
  -- applies 'g' to that binding.  See Note [Pipe and sharing recovery].
  Pipe        :: (Arrays as, Arrays bs, Arrays cs)
              => (Acc as -> Acc bs)         -- see comment above on why 'Acc' and not 'acc'
              -> (Acc bs -> Acc cs)
              -> acc as
              -> PreAcc acc cs
  -- Array-level conditional on a scalar 'Bool' (converted to 'AST.Acond').
  Acond       :: (Arrays as)
              => PreExp acc Bool
              -> acc as
              -> acc as
              -> PreAcc acc as
  -- Projections out of a pair of arrays (both converted to an 'AST.Let2' binding the two
  -- components) and construction of such a pair.
  FstArray    :: (Shape sh1, Shape sh2, Elt e1, Elt e2)
              => acc (Array sh1 e1, Array sh2 e2)
              -> PreAcc acc (Array sh1 e1)
  SndArray    :: (Shape sh1, Shape sh2, Elt e1, Elt e2)
              => acc (Array sh1 e1, Array sh2 e2)
              -> PreAcc acc (Array sh2 e2)
  PairArrays  :: (Shape sh1, Shape sh2, Elt e1, Elt e2)
              => acc (Array sh1 e1)
              -> acc (Array sh2 e2)
              -> PreAcc acc (Array sh1 e1, Array sh2 e2)

  -- Array introduction: embed an existing array, lift a scalar expression into a 'Scalar'
  -- array, or build an array from an extent and a function from indices to elements.
  Use         :: (Shape sh, Elt e)
              => Array sh e -> PreAcc acc (Array sh e)
  Unit        :: Elt e
              => PreExp acc e
              -> PreAcc acc (Scalar e)
  Generate    :: (Shape sh, Elt e)
              => PreExp acc sh
              -> (Exp sh -> PreExp acc e)
              -> PreAcc acc (Array sh e)
  -- Shape manipulation and slicing.
  Reshape     :: (Shape sh, Shape sh', Elt e)
              => PreExp acc sh
              -> acc (Array sh' e)
              -> PreAcc acc (Array sh e)
  Replicate   :: (Slice slix, Elt e,
                  Typeable (SliceShape slix), Typeable (FullShape slix))
                  -- the Typeable constraints shouldn't be necessary as they are implied by
                  -- 'Slice slix' — unfortunately, the (old) type checker doesn't grok that
              => PreExp acc slix
              -> acc (Array (SliceShape slix)    e)
              -> PreAcc acc (Array (FullShape slix) e)
  Index       :: (Slice slix, Elt e,
                  Typeable (SliceShape slix), Typeable (FullShape slix))
                  -- the Typeable constraints shouldn't be necessary as they are implied by
                  -- 'Slice slix' — unfortunately, the (old) type checker doesn't grok that
              => acc (Array (FullShape slix) e)
              -> PreExp acc slix
              -> PreAcc acc (Array (SliceShape slix) e)
  -- Element-wise operations over arrays of the same shape.
  Map         :: (Shape sh, Elt e, Elt e')
              => (Exp e -> PreExp acc e')
              -> acc (Array sh e)
              -> PreAcc acc (Array sh e')
  ZipWith     :: (Shape sh, Elt e1, Elt e2, Elt e3)
              => (Exp e1 -> Exp e2 -> PreExp acc e3)
              -> acc (Array sh e1)
              -> acc (Array sh e2)
              -> PreAcc acc (Array sh e3)
  -- Reductions over the innermost dimension.  'Fold' and 'FoldSeg' take an initial element,
  -- the '1' variants do not; the 'Seg' variants additionally take a 'Segments' array and keep
  -- the innermost dimension in their result type.
  Fold        :: (Shape sh, Elt e)
              => (Exp e -> Exp e -> PreExp acc e)
              -> PreExp acc e
              -> acc (Array (sh:.Int) e)
              -> PreAcc acc (Array sh e)
  Fold1       :: (Shape sh, Elt e)
              => (Exp e -> Exp e -> PreExp acc e)
              -> acc (Array (sh:.Int) e)
              -> PreAcc acc (Array sh e)
  FoldSeg     :: (Shape sh, Elt e)
              => (Exp e -> Exp e -> PreExp acc e)
              -> PreExp acc e
              -> acc (Array (sh:.Int) e)
              -> acc Segments
              -> PreAcc acc (Array (sh:.Int) e)
  Fold1Seg    :: (Shape sh, Elt e)
              => (Exp e -> Exp e -> PreExp acc e)
              -> acc (Array (sh:.Int) e)
              -> acc Segments
              -> PreAcc acc (Array (sh:.Int) e)
  -- Scans over vectors.  The plain and primed variants take an initial element, the '1'
  -- variants do not; the primed variants additionally return a 'Scalar'.
  Scanl       :: Elt e
              => (Exp e -> Exp e -> PreExp acc e)
              -> PreExp acc e
              -> acc (Vector e)
              -> PreAcc acc (Vector e)
  Scanl'      :: Elt e
              => (Exp e -> Exp e -> PreExp acc e)
              -> PreExp acc e
              -> acc (Vector e)
              -> PreAcc acc (Vector e, Scalar e)
  Scanl1      :: Elt e
              => (Exp e -> Exp e -> PreExp acc e)
              -> acc (Vector e)
              -> PreAcc acc (Vector e)
  Scanr       :: Elt e
              => (Exp e -> Exp e -> PreExp acc e)
              -> PreExp acc e
              -> acc (Vector e)
              -> PreAcc acc (Vector e)
  Scanr'      :: Elt e
              => (Exp e -> Exp e -> PreExp acc e)
              -> PreExp acc e
              -> acc (Vector e)
              -> PreAcc acc (Vector e, Scalar e)
  Scanr1      :: Elt e
              => (Exp e -> Exp e -> PreExp acc e)
              -> acc (Vector e)
              -> PreAcc acc (Vector e)
  -- 'Permute': a binary function on elements, an array of the result shape, an index function
  -- from the source shape to the result shape, and the source array.
  Permute     :: (Shape sh, Shape sh', Elt e)
              => (Exp e -> Exp e -> PreExp acc e)
              -> acc (Array sh' e)
              -> (Exp sh -> PreExp acc sh')
              -> acc (Array sh e)
              -> PreAcc acc (Array sh' e)
  -- 'Backpermute': the result extent, an index function from the result shape to the source
  -- shape, and the source array.
  Backpermute :: (Shape sh, Shape sh', Elt e)
              => PreExp acc sh'
              -> (Exp sh' -> PreExp acc sh)
              -> acc (Array sh e)
              -> PreAcc acc (Array sh' e)
  -- Stencil computations; each source array is accompanied by its 'Boundary' condition.
  Stencil     :: (Shape sh, Elt a, Elt b, Stencil sh a stencil)
              => (stencil -> PreExp acc b)
              -> Boundary a
              -> acc (Array sh a)
              -> PreAcc acc (Array sh b)
  Stencil2    :: (Shape sh, Elt a, Elt b, Elt c,
                 Stencil sh a stencil1, Stencil sh b stencil2)
              => (stencil1 -> stencil2 -> PreExp acc c)
              -> Boundary a
              -> acc (Array sh a)
              -> Boundary b
              -> acc (Array sh b)
              -> PreAcc acc (Array sh c)

-- |Array-valued collective computations: 'PreAcc' with the recursive knot tied, so that every
-- nested array computation is again an 'Acc'.
--
newtype Acc arrs = Acc (PreAcc Acc arrs)

deriving instance Typeable1 Acc

-- Conversion from HOAS to de Bruijn computation AST
-- --------------------------------------------------

-- |Convert a closed array computation, one without free array variables, to de Bruijn form,
-- recovering sharing along the way.  Closedness is reflected by starting from the empty layout.
--
convertAcc :: Arrays arrs => Acc arrs -> AST.Acc arrs
convertAcc acc = convertOpenAcc EmptyLayout acc

-- |Convert an array computation, whose free array variables are described by the given layout,
-- to de Bruijn form.  Sharing is first recovered with 'recoverSharing' (configured by
-- 'floatOutAccFromExp').  The sharing-annotated term is then converted starting from an empty
-- sharing environment.
--
convertOpenAcc :: Arrays arrs => Layout aenv aenv -> Acc arrs -> AST.OpenAcc aenv arrs
convertOpenAcc alyt acc
  = let sharingAcc = recoverSharing floatOutAccFromExp acc
    in  convertSharingAcc alyt [] sharingAcc

-- |Convert a unary function over array computations into a de Bruijn array function.  The
-- function is applied to the placeholder 'Atag 0'.  The resulting body is converted in a layout
-- whose only entry is the lambda-bound argument ('ZeroIdx').
--
convertAccFun1 :: forall a b. (Arrays a, Arrays b)
               => (Acc a -> Acc b)
               -> AST.Afun (a -> b)
convertAccFun1 f = Alam (Abody body)
  where
    argLyt = PushLayout EmptyLayout (ZeroIdx :: Idx ((), a) a)
    body   = convertOpenAcc argLyt (f (Acc (Atag 0)))

-- |Convert an array expression with given array environment layout and sharing information into
-- de Bruijn form while recovering sharing at the same time (by introducing appropriate let
-- bindings).  The latter implements the third phase of sharing recovery, after the occurrence map
-- has been built and the scopes of the bindings have been determined.
--
-- The sharing environment 'env' keeps track of all currently bound sharing variables, keeping them
-- in reverse chronological order (outermost variable is at the end of the list).  Every
-- 'LetSharing' pushes its variable onto both 'env' and the layout.  A 'VarSharing' found at
-- position @i@ of 'env' is therefore resolved with 'prjIdx i' in the layout.
--
-- NOTE(review): 'Atag i' is likewise resolved as 'prjIdx i alyt', i.e. counted from the
-- innermost layout entry, even though 'LetSharing' pushes new entries onto 'alyt'.  An 'Atag'
-- occurring underneath a 'LetSharing' in the body of an array function would thus resolve to the
-- let-bound variable rather than to the function argument.  Confirm that sharing recovery never
-- places a binding between an array function's binder and its 'Atag' uses.  (The 'Atag' field is
-- documented as an environment size, which suggests a level-to-index conversion was intended.)
--
convertSharingAcc :: forall a aenv. Arrays a
                  => Layout aenv aenv
                  -> [StableSharingAcc]
                  -> SharingAcc a
                  -> AST.OpenAcc aenv a
-- A shared subtree bound by an enclosing 'LetSharing': its position in 'env' is its index.
convertSharingAcc alyt env (VarSharing sa)
  | Just i <- findIndex (matchStableAcc sa) env
  = AST.OpenAcc $ AST.Avar (prjIdx i alyt)
  | otherwise
  = INTERNAL_ERROR(error) "convertSharingAcc (prjIdx)" err
  where
    err = "inconsistent valuation; sa = " ++ show (hashStableName sa) ++ "; env = " ++ show env
-- A let binding for a shared subtree.  The bound computation is converted in the current
-- environment, the body in the environment extended with the new variable.
convertSharingAcc alyt env (LetSharing sa@(StableSharingAcc _ boundAcc) bodyAcc)
  = AST.OpenAcc
  $ let alyt' = incLayout alyt `PushLayout` ZeroIdx
    in
    AST.Let (convertSharingAcc alyt env boundAcc) (convertSharingAcc alyt' (sa:env) bodyAcc)
-- An ordinary node: convert it constructor by constructor.
convertSharingAcc alyt env (AccSharing _ preAcc)
  = AST.OpenAcc
  $ (case preAcc of
      -- Placeholder for the argument of an array function (see 'convertAccFun1').
      Atag i
        -> AST.Avar (prjIdx i alyt)
      -- 'Pipe f g acc' becomes a let binding of 'f' applied to 'acc', whose body applies 'g' to
      -- the innermost variable ('ZeroIdx'), i.e. to the let-bound result.  Both functions are
      -- converted on their own with 'convertAccFun1'.
      Pipe afun1 afun2 acc
        -> let boundAcc = convertAccFun1 afun1 `AST.Apply` convertSharingAcc alyt env acc
               bodyAcc  = convertAccFun1 afun2 `AST.Apply` AST.OpenAcc (AST.Avar AST.ZeroIdx)
           in
           AST.Let (AST.OpenAcc boundAcc) (AST.OpenAcc bodyAcc)
      Acond b acc1 acc2
        -> AST.Acond (convertExp alyt env b) (convertSharingAcc alyt env acc1)
                     (convertSharingAcc alyt env acc2)
      -- 'AST.Let2' binds both components of the pair.  As used here, the first component is
      -- 'SuccIdx ZeroIdx' in its body and the second is 'ZeroIdx'.
      FstArray acc
        -> AST.Let2 (convertSharingAcc alyt env acc)
                    (AST.OpenAcc $ AST.Avar (AST.SuccIdx AST.ZeroIdx))
      SndArray acc
        -> AST.Let2 (convertSharingAcc alyt env acc)
                    (AST.OpenAcc $ AST.Avar AST.ZeroIdx)
      PairArrays acc1 acc2
        -> AST.PairArrays (convertSharingAcc alyt env acc1)
                          (convertSharingAcc alyt env acc2)
      Use array
        -> AST.Use array
      Unit e
        -> AST.Unit (convertExp alyt env e)
      Generate sh f
        -> AST.Generate (convertExp alyt env sh) (convertFun1 alyt env f)
      Reshape e acc
        -> AST.Reshape (convertExp alyt env e) (convertSharingAcc alyt env acc)
      Replicate ix acc
        -> mkReplicate (convertExp alyt env ix) (convertSharingAcc alyt env acc)
      Index acc ix
        -> mkIndex (convertSharingAcc alyt env acc) (convertExp alyt env ix)
      Map f acc
        -> AST.Map (convertFun1 alyt env f) (convertSharingAcc alyt env acc)
      ZipWith f acc1 acc2
        -> AST.ZipWith (convertFun2 alyt env f)
                       (convertSharingAcc alyt env acc1)
                       (convertSharingAcc alyt env acc2)
      Fold f e acc
        -> AST.Fold (convertFun2 alyt env f) (convertExp alyt env e)
                    (convertSharingAcc alyt env acc)
      Fold1 f acc
        -> AST.Fold1 (convertFun2 alyt env f) (convertSharingAcc alyt env acc)
      FoldSeg f e acc1 acc2
        -> AST.FoldSeg (convertFun2 alyt env f) (convertExp alyt env e)
                       (convertSharingAcc alyt env acc1) (convertSharingAcc alyt env acc2)
      Fold1Seg f acc1 acc2
        -> AST.Fold1Seg (convertFun2 alyt env f)
                        (convertSharingAcc alyt env acc1)
                        (convertSharingAcc alyt env acc2)
      Scanl f e acc
        -> AST.Scanl (convertFun2 alyt env f) (convertExp alyt env e)
                     (convertSharingAcc alyt env acc)
      Scanl' f e acc
        -> AST.Scanl' (convertFun2 alyt env f)
                      (convertExp alyt env e)
                      (convertSharingAcc alyt env acc)
      Scanl1 f acc
        -> AST.Scanl1 (convertFun2 alyt env f) (convertSharingAcc alyt env acc)
      Scanr f e acc
        -> AST.Scanr (convertFun2 alyt env f) (convertExp alyt env e)
                     (convertSharingAcc alyt env acc)
      Scanr' f e acc
        -> AST.Scanr' (convertFun2 alyt env f)
                      (convertExp alyt env e)
                      (convertSharingAcc alyt env acc)
      Scanr1 f acc
        -> AST.Scanr1 (convertFun2 alyt env f) (convertSharingAcc alyt env acc)
      Permute f dftAcc perm acc
        -> AST.Permute (convertFun2 alyt env f)
                       (convertSharingAcc alyt env dftAcc)
                       (convertFun1 alyt env perm)
                       (convertSharingAcc alyt env acc)
      Backpermute newDim perm acc
        -> AST.Backpermute (convertExp alyt env newDim)
                           (convertFun1 alyt env perm)
                           (convertSharingAcc alyt env acc)
      -- NOTE(review): the source array(s) are also passed to 'convertStencilFun' /
      -- 'convertStencilFun2', presumably only to fix the stencil's types; confirm there.
      Stencil stencil boundary acc
        -> AST.Stencil (convertStencilFun acc alyt env stencil)
                       (convertBoundary boundary)
                       (convertSharingAcc alyt env acc)
      Stencil2 stencil bndy1 acc1 bndy2 acc2
        -> AST.Stencil2 (convertStencilFun2 acc1 acc2 alyt env stencil)
                        (convertBoundary bndy1)
                        (convertSharingAcc alyt env acc1)
                        (convertBoundary bndy2)
                        (convertSharingAcc alyt env acc2)
      -- Result type of the 'case', stated with the scoped type variables of the signature.
    :: AST.PreOpenAcc AST.OpenAcc aenv a)

-- |Convert a boundary condition to the representation type of its elements.  Only 'Constant'
-- carries a value, which is converted with 'fromElt'; all other conditions are rebuilt as they are.
--
convertBoundary :: Elt e => Boundary e -> Boundary (EltRepr e)
convertBoundary bndy
  = case bndy of
      Constant e -> Constant (fromElt e)
      Clamp      -> Clamp
      Mirror     -> Mirror
      Wrap       -> Wrap


-- Sharing recovery
-- ----------------

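-- Sharing recovery proceeds in two phases (the subsequent conversion to de Bruijn form, which
-- introduces the actual let bindings, is implemented by 'convertSharingAcc'):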
--
-- /Phase One: build the occurrence map/
--
-- This is a top-down traversal of the AST that computes a map from AST nodes to the number of
-- occurrences of that AST node in the overall Accelerate program.  An occurrence count of two or
-- more indicates sharing.
--
-- IMPORTANT: To avoid unfolding the sharing, we do not descend into subtrees that we have
--   previously encountered.  Hence, the complexity is proportional to the number of nodes in the
--   tree /with/ sharing.  Consequently, the occurrence count is that in the tree with sharing
--   as well.
--
-- During computation of the occurrences, the tree is annotated with stable names on every node
-- using 'AccSharing' constructors and all but the first occurrence of shared subtrees are pruned
-- using 'VarSharing' constructors (see 'SharingAcc' below).  This phase is impure as it is based
-- on stable names.
--
-- We use a hash table (instead of 'Data.Map') as computing stable names forces us to live in IO
-- anyway.  Once the computation of occurrence counts is complete, we freeze the hash table into
-- an immutable 'IntMap' keyed on stable-name hashes (see 'OccMap' and 'freezeOccMap').
--
-- (Implemented by 'makeOccMap'.)
--
-- /Phase Two: determine scopes and inject sharing information/
--
-- This is a bottom-up traversal that determines the scope for every binding to be introduced
-- to share a subterm.  It uses the occurrence map to determine, for every shared subtree, the
-- lowest AST node at which the binding for that shared subtree can be placed (using a 'LetSharing'
-- constructor) — it's the meet of all the shared subtree occurrences.
--
-- The second phase also replaces the first occurrence of each shared subtree with a
-- 'VarSharing' node and floats the shared subtree up to its binding point.
--
-- (Implemented by 'determineScopes'.)

-- Opaque stable name for an array computation, used to key the occurrence map.  The array type
-- is hidden existentially; the 'Typeable' evidence allows names of different types to be compared.
--
data StableAccName where
  StableAccName :: Typeable arrs => StableName (Acc arrs) -> StableAccName

-- A name is shown as the hash of its stable name.
instance Show StableAccName where
  show (StableAccName sn) = show (hashStableName sn)

-- Names are equal only if their array types agree (checked with 'gcast') and the stable names
-- themselves are equal.
instance Eq StableAccName where
  StableAccName sn1 == StableAccName sn2 = maybe False (== sn2) (gcast sn1)

-- Obtain the stable name of an array computation.  The computation is first forced to weak head
-- normal form, so the name is taken of the evaluated node rather than of a pending thunk.
makeStableAcc :: Acc arrs -> IO (StableName (Acc arrs))
makeStableAcc acc = makeStableName $! acc

-- An array computation interleaved with sharing annotations:
--
--  * 'VarSharing' stands for a shared subtree that is bound elsewhere, named by its stable name;
--  * 'LetSharing' prefixes a computation with a let binding for a shared subtree;
--  * 'AccSharing' is an ordinary node, tagged with its stable name, whose nested array
--    computations are again sharing-annotated.
--
data SharingAcc arrs where
  VarSharing :: Arrays arrs
             => StableName (Acc arrs)
             -> SharingAcc arrs
  LetSharing :: StableSharingAcc
             -> SharingAcc arrs
             -> SharingAcc arrs
  AccSharing :: Arrays arrs
             => StableName (Acc arrs)
             -> PreAcc SharingAcc arrs
             -> SharingAcc arrs

-- A stable name paired with the sharing-annotated array computation it names.  This is what a
-- 'LetSharing' binds.
--
data StableSharingAcc where
  StableSharingAcc :: Arrays arrs => StableName (Acc arrs) -> SharingAcc arrs -> StableSharingAcc

-- Shown as the hash of the stable name; the computation itself is not shown.
instance Show StableSharingAcc where
  show (StableSharingAcc sn _) = show (hashStableName sn)

-- Equality looks at the stable names only.  They must have the same array type (checked with
-- 'gcast') and be equal.
instance Eq StableSharingAcc where
  StableSharingAcc sn1 _ == StableSharingAcc sn2 _ = maybe False (== sn2) (gcast sn1)

-- Does the given stable name name the array computation of this sharing binding?  A type
-- mismatch (detected by 'gcast') counts as no match.
--
matchStableAcc :: Typeable arrs => StableName (Acc arrs) -> StableSharingAcc -> Bool
matchStableAcc sn1 (StableSharingAcc sn2 _) = maybe False (== sn2) (gcast sn1)

-- Hash table keyed on the stable names of array computations.
--
type AccHashTable = Hash.HashTable StableAccName

-- Mutable version of the occurrence map, which associates each AST node with its occurrence count.
--
type OccMapHash = AccHashTable Int

-- Create a new hash table keyed by array computations.  Keys are compared with the 'Eq' instance
-- of 'StableAccName' and hashed via 'hashStableName'.
--
newAccHashTable :: IO (AccHashTable v)
newAccHashTable = Hash.new (==) hashAccName
  where
    hashAccName name = case name of
      StableAccName sn -> fromIntegral (hashStableName sn)

-- Immutable version of the occurrence map.  We use the 'StableName' hash to index an 'IntMap' and
-- disambiguate 'StableName's with identical hashes explicitly, storing them in a list in the
-- 'IntMap'.
--
type OccMap = IntMap.IntMap [(StableAccName, Int)]

-- Turn a mutable into an immutable occurrence map.
--
-- All entries whose stable names share a hash end up in the same 'IntMap' bucket.  The entries are
-- merged with 'IntMap.fromListWith', so the result does not depend on the order in which
-- 'Hash.toList' returns them.  Grouping only /adjacent/ entries (as 'groupBy' does) and then
-- building the map with 'IntMap.fromList' would instead silently keep only the last run of a hash
-- whenever entries with that hash are not contiguous, losing occurrence counts.
--
freezeOccMap :: OccMapHash -> IO OccMap
freezeOccMap oc
  = do
      kvs <- Hash.toList oc
      return $ IntMap.fromListWith (++) [(key kv, [kv]) | kv <- kvs]
  where
    key (StableAccName sn, _) = hashStableName sn

-- Look up the occurrence count of an array computation by its stable name.  A name that does not
-- occur in the map has an occurrence count of '1'.
--
lookupWithAccName :: OccMap -> StableAccName -> Int
lookupWithAccName oc sa@(StableAccName sn)
  = case IntMap.lookup (hashStableName sn) oc of
      Nothing     -> 1
      Just bucket -> fromMaybe 1 (Prelude.lookup sa bucket)
    
-- Look up the occurrence count of a sharing binding via the stable name of its computation.  A
-- name that does not occur in the map has an occurrence count of '1'.
--
lookupWithSharingAcc :: OccMap -> StableSharingAcc -> Int
lookupWithSharingAcc oc ssa
  = case ssa of
      StableSharingAcc sn _ -> lookupWithAccName oc (StableAccName sn)

-- Compute the occurrence map, mark all nodes with stable names, and drop repeated occurrences
-- of shared subtrees (Phase One).
--
-- Note [Traversing functions and side effects]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-- We need to descend into function bodies to build the 'OccMap' with all occurrences in the
-- function bodies.  Due to the side effects in the construction of the occurrence map and, more
-- importantly, the dependence of the second phase on /global/ occurrence information, we may not
-- delay the body traversals by putting them under a lambda.  Hence, we apply each function to
-- placeholder 'Tag' arguments, traverse the resulting body, and use a /dummy abstraction/ of the
-- result.
--
-- For example, given a function 'f', we traverse 'f (Tag 0)', which yields a transformed body 'e'.
-- As the result of the traversal of the overall function, we use 'const e'.  Hence, it is crucial
-- that the 'Tag' supplied during the initial traversal is already the one required by the HOAS to
-- de Bruijn conversion in 'convertSharingAcc' — any subsequent application of 'const e' will only
-- yield 'e' with the embedded 'Tag 0' of the original application.
--
makeOccMap :: Typeable arrs => Acc arrs -> IO (SharingAcc arrs, OccMapHash)
makeOccMap rootAcc
  = do
      occMap <- newAccHashTable
      rootAcc' <- traverseAcc True (enterOcc occMap) rootAcc
      return (rootAcc', occMap)
  where
    -- Enter one AST node occurrence into an occurrence map.  Returns 'True' if this is a repeated
    -- occurence.
    --
    -- The first argument determines whether the 'OccMap' will be modified - see Note [Traversing
    -- functions and side effects].
    --
    enterOcc :: OccMapHash -> Bool -> StableAccName -> IO Bool
    enterOcc occMap updateMap sa 
      = do
          entry <- Hash.lookup occMap sa
          case entry of
            Nothing -> when updateMap (       Hash.insert occMap sa 1      ) >> return False
            Just n  -> when updateMap (void $ Hash.update occMap sa (n + 1)) >> return True
              where
                void = (>> return ())

    traverseAcc :: forall arrs. Typeable arrs
                => Bool -> (Bool -> StableAccName -> IO Bool) -> Acc arrs -> IO (SharingAcc arrs)
    traverseAcc updateMap enter acc'@(Acc pacc)
      = do
            -- Compute stable name and enter it into the occurence map
          sn <- makeStableAcc acc'
          isRepeatedOccurence <- enter updateMap $ StableAccName sn
          
          traceLine (showPreAccOp pacc) $
            if isRepeatedOccurence 
              then "REPEATED occurence"
              else "first occurence (" ++ show (hashStableName sn) ++ ")"

            -- Reconstruct the computation in shared form
            --
            -- NB: This function can only be used in the case alternatives below; outside of the
            --     case we cannot discharge the 'Arrays arrs' constraint.
          let reconstruct :: Arrays arrs 
                          => IO (PreAcc SharingAcc arrs)
                          -> IO (SharingAcc arrs)
              reconstruct newAcc | isRepeatedOccurence = pure $ VarSharing sn
                                 | otherwise           = AccSharing sn <$> newAcc

          case pacc of
            Atag i                   -> reconstruct $ return (Atag i)
            Pipe afun1 afun2 acc     -> reconstruct $ travA (Pipe afun1 afun2) acc
            Acond e acc1 acc2        -> reconstruct $ do
                                          e'    <- traverseExp updateMap enter e
                                          acc1' <- traverseAcc updateMap enter acc1
                                          acc2' <- traverseAcc updateMap enter acc2
                                          return (Acond e' acc1' acc2')
            FstArray acc             -> reconstruct $ travA FstArray acc
            SndArray acc             -> reconstruct $ travA SndArray acc
            PairArrays acc1 acc2     -> reconstruct $ do
                                          acc1' <- traverseAcc updateMap enter acc1
                                          acc2' <- traverseAcc updateMap enter acc2
                                          return (PairArrays acc1' acc2')
            Use arr                  -> reconstruct $ return (Use arr)
            Unit e                   -> reconstruct $ do
                                          e' <- traverseExp updateMap enter e
                                          return (Unit e')
            Generate e f             -> reconstruct $ do
                                          e' <- traverseExp  updateMap enter e
                                          f' <- traverseFun1 updateMap enter f
                                          return (Generate e' f')
            Reshape e acc            -> reconstruct $ travEA Reshape e acc
            Replicate e acc          -> reconstruct $ travEA Replicate e acc
            Index acc e              -> reconstruct $ travEA (flip Index) e acc
            Map f acc                -> reconstruct $ do
                                          f'   <- traverseFun1 updateMap enter f
                                          acc' <- traverseAcc  updateMap enter acc
                                          return (Map f' acc')
            ZipWith f acc1 acc2      -> reconstruct $ travF2A2 ZipWith f acc1 acc2
            Fold f e acc             -> reconstruct $ travF2EA Fold f e acc
            Fold1 f acc              -> reconstruct $ travF2A Fold1 f acc
            FoldSeg f e acc1 acc2    -> reconstruct $ do
                                          f'    <- traverseFun2 updateMap enter f
                                          e'    <- traverseExp  updateMap enter e
                                          acc1' <- traverseAcc  updateMap enter acc1
                                          acc2' <- traverseAcc  updateMap enter acc2
                                          return (FoldSeg f' e' acc1' acc2')
            Fold1Seg f acc1 acc2     -> reconstruct $ travF2A2 Fold1Seg f acc1 acc2
            Scanl f e acc            -> reconstruct $ travF2EA Scanl f e acc
            Scanl' f e acc           -> reconstruct $ travF2EA Scanl' f e acc
            Scanl1 f acc             -> reconstruct $ travF2A Scanl1 f acc
            Scanr f e acc            -> reconstruct $ travF2EA Scanr f e acc
            Scanr' f e acc           -> reconstruct $ travF2EA Scanr' f e acc
            Scanr1 f acc             -> reconstruct $ travF2A Scanr1 f acc
            Permute c acc1 p acc2    -> reconstruct $ do
                                          c'    <- traverseFun2 updateMap enter c
                                          p'    <- traverseFun1 updateMap enter p
                                          acc1' <- traverseAcc  updateMap enter acc1
                                          acc2' <- traverseAcc  updateMap enter acc2
                                          return (Permute c' acc1' p' acc2')
            Backpermute e p acc      -> reconstruct $ do
                                          e'   <- traverseExp  updateMap enter e
                                          p'   <- traverseFun1 updateMap enter p
                                          acc' <- traverseAcc  updateMap enter acc
                                          return (Backpermute e' p' acc')
            Stencil s bnd acc        -> reconstruct $ do
                                          s'   <- traverseStencil1 acc updateMap enter s
                                          acc' <- traverseAcc  updateMap enter acc
                                          return (Stencil s' bnd acc')
            Stencil2 s bnd1 acc1 
                       bnd2 acc2     -> reconstruct $ do
                                          s'    <- traverseStencil2 acc1 acc2 updateMap enter s
                                          acc1' <- traverseAcc  updateMap enter acc1
                                          acc2' <- traverseAcc  updateMap enter acc2
                                          return (Stencil2 s' bnd1 acc1' bnd2 acc2')
      where
        travA :: Arrays arrs'
              => (SharingAcc arrs' -> PreAcc SharingAcc arrs) 
              -> Acc arrs' -> IO (PreAcc SharingAcc arrs)
        travA c acc
          = do
              acc' <- traverseAcc updateMap enter acc
              return $ c acc'

        travEA :: (Typeable b, Arrays arrs')
               => (SharingExp b -> SharingAcc arrs' -> PreAcc SharingAcc arrs) 
               -> Exp b -> Acc arrs' -> IO (PreAcc SharingAcc arrs)
        travEA c exp acc
          = do
              exp' <- traverseExp updateMap enter exp
              acc' <- traverseAcc updateMap enter acc
              return $ c exp' acc'

        travF2A :: (Elt b, Elt c, Typeable d, Arrays arrs')
                => ((Exp b -> Exp c -> SharingExp d) -> SharingAcc arrs' -> PreAcc SharingAcc arrs) 
                -> (Exp b -> Exp c -> Exp d) -> Acc arrs' -> IO (PreAcc SharingAcc arrs)
        travF2A c fun acc
          = do
              fun' <- traverseFun2 updateMap enter fun
              acc' <- traverseAcc updateMap enter acc
              return $ c fun' acc'

        travF2EA :: (Elt b, Elt c, Typeable d, Typeable e, Arrays arrs')
                 => ((Exp b -> Exp c -> SharingExp d) -> SharingExp e
                       -> SharingAcc arrs' -> PreAcc SharingAcc arrs) 
                 -> (Exp b -> Exp c -> Exp d) -> Exp e -> Acc arrs' -> IO (PreAcc SharingAcc arrs)
        travF2EA c fun exp acc
          = do
              fun' <- traverseFun2 updateMap enter fun
              exp' <- traverseExp updateMap enter exp
              acc' <- traverseAcc updateMap enter acc
              return $ c fun' exp' acc'

        travF2A2 :: (Elt b, Elt c, Typeable d, Arrays arrs1, Arrays arrs2)
                 => ((Exp b -> Exp c -> SharingExp d) -> SharingAcc arrs1
                       -> SharingAcc arrs2 -> PreAcc SharingAcc arrs) 
                 -> (Exp b -> Exp c -> Exp d) -> Acc arrs1 -> Acc arrs2 
                 -> IO (PreAcc SharingAcc arrs)
        travF2A2 c fun acc1 acc2
          = do
              fun' <- traverseFun2 updateMap enter fun
              acc1' <- traverseAcc updateMap enter acc1
              acc2' <- traverseAcc updateMap enter acc2
              return $ c fun' acc1' acc2'

    -- Traverse the body of a unary scalar function.  The body is obtained by applying the
    -- function to 'Tag 0'; the resulting function ignores its argument and always yields the
    -- traversed body (see Note [Traversing functions and side effects]).
    traverseFun1 :: (Elt b, Typeable c) 
                  => Bool -> (Bool -> StableAccName -> IO Bool) -> (Exp b -> Exp c) 
                  -> IO (Exp b -> SharingExp c)
    traverseFun1 updateMap enter f
      = const <$> traverseExp updateMap enter (f (Tag 0))

    -- Binary counterpart of 'traverseFun1'.  The body is obtained by applying the function to
    -- 'Tag 1' and 'Tag 0'; the resulting function ignores both arguments
    -- (see Note [Traversing functions and side effects]).
    traverseFun2 :: (Elt b, Elt c, Typeable d) 
                  => Bool -> (Bool -> StableAccName -> IO Bool) -> (Exp b -> Exp c -> Exp d) 
                  -> IO (Exp b -> Exp c -> SharingExp d)
    traverseFun2 updateMap enter f
      = (\body _ _ -> body) <$> traverseExp updateMap enter (f (Tag 1) (Tag 0))

    -- Traverse the body of a unary stencil function.  The function is applied to the stencil
    -- projection of 'Tag 0'; the array argument only fixes the types 'sh' and 'b'.  The
    -- resulting function ignores its stencil argument
    -- (see Note [Traversing functions and side effects]).
    traverseStencil1 :: forall sh b c stencil. (Stencil sh b stencil, Typeable c) 
                     => Acc (Array sh b){-dummy-}
                     -> Bool -> (Bool -> StableAccName -> IO Bool) -> (stencil -> Exp c) 
                     -> IO (stencil -> SharingExp c)
    traverseStencil1 _ updateMap enter stencilFun 
      = const <$> traverseExp updateMap enter
          (stencilFun (stencilPrj (undefined::sh) (undefined::b) (Tag 0)))
        
    -- Binary counterpart of 'traverseStencil1'.  The stencil function is applied to the stencil
    -- projections of 'Tag 1' (first array) and 'Tag 0' (second array); the array arguments only
    -- fix the types.  The resulting function ignores both stencil arguments
    -- (see Note [Traversing functions and side effects]).
    traverseStencil2 :: forall sh b c d stencil1 stencil2. 
                        (Stencil sh b stencil1, Stencil sh c stencil2, Typeable d) 
                     => Acc (Array sh b){-dummy-}
                     -> Acc (Array sh c){-dummy-}
                     -> Bool -> (Bool -> StableAccName -> IO Bool) 
                     -> (stencil1 -> stencil2 -> Exp d) 
                     -> IO (stencil1 -> stencil2 -> SharingExp d)
    traverseStencil2 _ _ updateMap enter stencilFun 
      = (\body _ _ -> body) <$> traverseExp updateMap enter
          (stencilFun (stencilPrj (undefined::sh) (undefined::b) (Tag 1))
                      (stencilPrj (undefined::sh) (undefined::c) (Tag 0)))
        
    -- Traverse a scalar expression, rebuilding it over 'SharingAcc'.  Array computations
    -- embedded via 'IndexScalar', 'Shape' and 'Size' are handed to 'traverseAcc'; the scalar
    -- nodes themselves are only copied.  Sub-terms are visited left to right.
    traverseExp :: Typeable a 
                => Bool -> (Bool -> StableAccName -> IO Bool) -> Exp a -> IO (SharingExp a)
    traverseExp updateMap enter expr
      = -- FIXME: recover sharing of scalar expressions as well
          case expr of
            Tag i           -> pure (Tag i)
            Const c         -> pure (Const c)
            Tuple tup       -> Tuple <$> travTup tup
            Prj i e         -> Prj i <$> recur e
            IndexNil        -> pure IndexNil
            IndexCons ix i  -> IndexCons <$> recur ix <*> recur i
            IndexHead i     -> IndexHead <$> recur i
            IndexTail ix    -> IndexTail <$> recur ix
            IndexAny        -> pure IndexAny
            Cond e1 e2 e3   -> Cond <$> recur e1 <*> recur e2 <*> recur e3
            PrimConst c     -> pure (PrimConst c)
            PrimApp p e     -> PrimApp p <$> recur e
            IndexScalar a e -> IndexScalar <$> travAcc a <*> recur e
            Shape a         -> Shape <$> travAcc a
            Size a          -> Size <$> travAcc a
      where
        -- Recursive call on a scalar sub-expression, with the same flags and callback.
        recur :: Typeable b => Exp b -> IO (SharingExp b)
        recur = traverseExp updateMap enter

        -- Descend into an embedded array computation.
        travAcc :: Typeable b => Acc b -> IO (SharingAcc b)
        travAcc = traverseAcc updateMap enter

        -- Traverse the components of a tuple from the first to the last.
        travTup :: Tuple.Tuple (PreExp Acc) tup -> IO (Tuple.Tuple (PreExp SharingAcc) tup)
        travTup NilTup          = pure NilTup
        travTup (SnocTup tup e) = SnocTup <$> travTup tup <*> recur e

-- Node counts: for every shared subterm encountered so far, the number of times it occurred.
--
--   Invariant: If one shared term 's' is itself a subterm of another shared term 't', then 's'
--              must occur *after* 't' in the 'NodeCounts'.  Moreover, no shared term occurs
--              more than once.
--
-- Node counts of sibling subterms must be merged with '(+++)', which preserves the invariant.
--
newtype NodeCounts = NodeCounts [(StableSharingAcc, Int)] deriving (Show)

-- No shared subterms at all; this is the unit of '(+++)'.
--
noNodeCounts :: NodeCounts
noNodeCounts = NodeCounts []

-- Node counts mentioning exactly one shared subterm together with its count.
--
nodeCount :: (StableSharingAcc, Int) -> NodeCounts
nodeCount = NodeCounts . (: [])

-- Merge the node counts of sibling subterms.
--
-- * Entries for the same shared term are combined by adding their counts; when one of the two
--   entries is a variable occurrence ('VarSharing'), the other one is kept, as it carries the
--   actual term.
--
-- * Assuming the node counts invariant (subterms follow their parents) holds for both arguments,
--   it also holds for the result.  If neither head can be emitted without breaking the order,
--   the precondition was violated and an internal error is raised.
--
-- * Complexity is quadratic, as each step scans the tail of the other list.  Labelling nodes with
--   their nesting depth could avoid this, but the lists are expected to be short.  Revisit if
--   profiling shows this to be a bottleneck.
--
(+++) :: NodeCounts -> NodeCounts -> NodeCounts
NodeCounts lefts +++ NodeCounts rights = NodeCounts (go lefts rights)
  where
    go [] rs = rs
    go ls [] = ls
    go ls@(l@(saL, nL) : restL) rs@(r@(saR, nR) : restR)
      | saL == saR                      = (preferNonVar saL saR, nL + nR) : go restL restR
      | not (saL `elem` map fst restR)  = l : go restL rs
      | not (saR `elem` map fst restL)  = r : go ls restR
      | otherwise                       = INTERNAL_ERROR(error) "(+++)" "Precondition violated"

    preferNonVar (StableSharingAcc _ (VarSharing _)) other = other
    preferNonVar this                                _     = this

-- Determine the scopes of all variables representing shared subterms (Phase Two) in a bottom-up
-- sweep.  A shared subterm is let-bound ('LetSharing') at the node where its accumulated node
-- count reaches its total number of occurrences as recorded in the 'OccMap'.
--
-- The first argument determines whether array computations are floated out of expressions
-- irrespective of whether they are shared or not — 'True' implies floating them out.
--
-- Precondition: there are only 'VarSharing' and 'AccSharing' nodes in the argument.
--
determineScopes :: Typeable a => Bool -> OccMap -> SharingAcc a -> SharingAcc a
determineScopes floatOutAcc occMap rootAcc = scopedRoot
  where
    -- Only the annotated term is returned; the node counts left at the root are discarded.
    (scopedRoot, _) = scopesAcc rootAcc
    -- Traverse an array computation bottom-up.  The result pairs the term, with 'LetSharing'
    -- bindings inserted at the binding points found here, with the node counts of those shared
    -- subterms whose count is not yet complete.  Those are bound further up.
    --
    scopesAcc :: forall arrs. SharingAcc arrs -> (SharingAcc arrs, NodeCounts)
    scopesAcc (LetSharing _ _)
        -- by the precondition of 'determineScopes', no 'LetSharing' nodes exist yet
      = INTERNAL_ERROR(error) "determineScopes: scopes" "unexpected 'LetSharing'"
      -- a variable occurrence counts as one occurrence of the shared node it refers to
    scopesAcc sharingAcc@(VarSharing sn)
      = (VarSharing sn, nodeCount (StableSharingAcc sn sharingAcc, 1))
      -- a proper node: traverse its components, merge their node counts with '(+++)', and
      -- rebuild the node with 'reconstruct'
    scopesAcc (AccSharing sn pacc)
      = case pacc of
          Atag i                  -> reconstruct (Atag i) noNodeCounts
          Pipe afun1 afun2 acc    -> travA (Pipe afun1 afun2) acc
            -- we are not traversing 'afun1' & 'afun2' — see Note [Pipe and sharing recovery]
          Acond e acc1 acc2       -> let
                                       (e'   , accCount1) = scopesExp e
                                       (acc1', accCount2) = scopesAcc acc1
                                       (acc2', accCount3) = scopesAcc acc2
                                     in
                                     reconstruct (Acond e' acc1' acc2')
                                                 (accCount1 +++ accCount2 +++ accCount3)
          FstArray acc            -> travA FstArray acc
          SndArray acc            -> travA SndArray acc
          PairArrays acc1 acc2    -> let
                                       (acc1', accCount1) = scopesAcc acc1
                                       (acc2', accCount2) = scopesAcc acc2
                                     in
                                     reconstruct (PairArrays acc1' acc2') (accCount1 +++ accCount2)
          Use arr                 -> reconstruct (Use arr) noNodeCounts
          Unit e                  -> let
                                       (e', accCount) = scopesExp e
                                     in
                                     reconstruct (Unit e') accCount
          Generate sh f           -> let
                                       (sh', accCount1) = scopesExp sh
                                       (f' , accCount2) = scopesFun1 f
                                     in
                                     reconstruct (Generate sh' f') (accCount1 +++ accCount2)
          Reshape sh acc          -> travEA Reshape sh acc
          Replicate n acc         -> travEA Replicate n acc
          Index acc i             -> travEA (flip Index) i acc
          Map f acc               -> let
                                       (f'  , accCount1) = scopesFun1 f
                                       (acc', accCount2) = scopesAcc  acc
                                     in
                                     reconstruct (Map f' acc') (accCount1 +++ accCount2)
          ZipWith f acc1 acc2     -> travF2A2 ZipWith f acc1 acc2
          Fold f z acc            -> travF2EA Fold f z acc
          Fold1 f acc             -> travF2A Fold1 f acc
          FoldSeg f z acc1 acc2   -> let
                                       (f'   , accCount1)  = scopesFun2 f
                                       (z'   , accCount2)  = scopesExp  z
                                       (acc1', accCount3)  = scopesAcc  acc1
                                       (acc2', accCount4)  = scopesAcc  acc2
                                     in
                                     reconstruct (FoldSeg f' z' acc1' acc2') 
                                       (accCount1 +++ accCount2 +++ accCount3 +++ accCount4)
          Fold1Seg f acc1 acc2    -> travF2A2 Fold1Seg f acc1 acc2
          Scanl f z acc           -> travF2EA Scanl f z acc
          Scanl' f z acc          -> travF2EA Scanl' f z acc
          Scanl1 f acc            -> travF2A Scanl1 f acc
          Scanr f z acc           -> travF2EA Scanr f z acc
          Scanr' f z acc          -> travF2EA Scanr' f z acc
          Scanr1 f acc            -> travF2A Scanr1 f acc
          Permute fc acc1 fp acc2 -> let
                                       (fc'  , accCount1) = scopesFun2 fc
                                       (acc1', accCount2) = scopesAcc  acc1
                                       (fp'  , accCount3) = scopesFun1 fp
                                       (acc2', accCount4) = scopesAcc  acc2
                                     in
                                     reconstruct (Permute fc' acc1' fp' acc2')
                                       (accCount1 +++ accCount2 +++ accCount3 +++ accCount4)
          Backpermute sh fp acc   -> let
                                       (sh' , accCount1) = scopesExp  sh
                                       (fp' , accCount2) = scopesFun1 fp
                                       (acc', accCount3) = scopesAcc  acc
                                     in
                                     reconstruct (Backpermute sh' fp' acc')
                                       (accCount1 +++ accCount2 +++ accCount3)
          Stencil st bnd acc      -> let
                                       (st' , accCount1) = scopesStencil1 acc st
                                       (acc', accCount2) = scopesAcc      acc
                                     in
                                     reconstruct (Stencil st' bnd acc') (accCount1 +++ accCount2)
          Stencil2 st bnd1 acc1 bnd2 acc2 
                                  -> let
                                       (st'  , accCount1) = scopesStencil2 acc1 acc2 st
                                       (acc1', accCount2) = scopesAcc acc1
                                       (acc2', accCount3) = scopesAcc acc2
                                     in
                                     reconstruct (Stencil2 st' bnd1 acc1' bnd2 acc2')
                                       (accCount1 +++ accCount2 +++ accCount3)
      where
        -- Helpers for recurring constructor shapes.  Each traverses the components in argument
        -- order, merges the resulting node counts with '(+++)', and rebuilds the current node via
        -- 'reconstruct'.
        --
        -- expression + array argument
        travEA :: Arrays arrs 
               => (SharingExp e -> SharingAcc arrs' -> PreAcc SharingAcc arrs) 
               -> SharingExp e
               -> SharingAcc arrs' 
               -> (SharingAcc arrs, NodeCounts)
        travEA c e acc = reconstruct (c e' acc') (accCount1 +++ accCount2)
          where
            (e'  , accCount1) = scopesExp e
            (acc', accCount2) = scopesAcc acc

        -- binary function + array argument
        travF2A :: (Elt a, Elt b, Arrays arrs)
                => ((Exp a -> Exp b -> SharingExp c) -> SharingAcc arrs' -> PreAcc SharingAcc arrs) 
                -> (Exp a -> Exp b -> SharingExp c)
                -> SharingAcc arrs'
                -> (SharingAcc arrs, NodeCounts)
        travF2A c f acc = reconstruct (c f' acc') (accCount1 +++ accCount2)
          where
            (f'  , accCount1) = scopesFun2 f
            (acc', accCount2) = scopesAcc  acc              

        -- binary function + expression + array argument
        travF2EA :: (Elt a, Elt b, Arrays arrs)
                 => ((Exp a -> Exp b -> SharingExp c) -> SharingExp e 
                     -> SharingAcc arrs' -> PreAcc SharingAcc arrs) 
                 -> (Exp a -> Exp b -> SharingExp c)
                 -> SharingExp e 
                 -> SharingAcc arrs'
                 -> (SharingAcc arrs, NodeCounts)
        travF2EA c f e acc = reconstruct (c f' e' acc') (accCount1 +++ accCount2 +++ accCount3)
          where
            (f'  , accCount1) = scopesFun2 f
            (e'  , accCount2) = scopesExp  e
            (acc', accCount3) = scopesAcc  acc

        -- binary function + two array arguments
        travF2A2 :: (Elt a, Elt b, Arrays arrs)
                 => ((Exp a -> Exp b -> SharingExp c) -> SharingAcc arrs1 
                     -> SharingAcc arrs2 -> PreAcc SharingAcc arrs) 
                 -> (Exp a -> Exp b -> SharingExp c)
                 -> SharingAcc arrs1 
                 -> SharingAcc arrs2 
                 -> (SharingAcc arrs, NodeCounts)
        travF2A2 c f acc1 acc2 = reconstruct (c f' acc1' acc2') 
                                             (accCount1 +++ accCount2 +++ accCount3)
          where
            (f'   , accCount1) = scopesFun2 f
            (acc1', accCount2) = scopesAcc  acc1
            (acc2', accCount3) = scopesAcc  acc2

        -- single array argument
        travA :: Arrays arrs 
              => (SharingAcc arrs' -> PreAcc SharingAcc arrs) 
              -> SharingAcc arrs' 
              -> (SharingAcc arrs, NodeCounts)
        travA c acc = reconstruct (c acc') accCount
          where
            (acc', accCount) = scopesAcc acc

          -- Occurrence count of the currently processed node
        occCount = lookupWithAccName occMap (StableAccName sn)

        -- Reconstruct the current tree node.
        --
        -- * If the current node is being shared ('occCount > 1'), replace it by a 'VarSharing'
        --   node and float the shared subtree out wrapped in a 'NodeCounts' value.
        -- * If the current node is not shared, reconstruct it in place.
        --
        -- In either case, any completed 'NodeCounts' are injected as bindings using 'LetSharing'
        -- node.
        -- 
        reconstruct :: Arrays arrs 
                    => PreAcc SharingAcc arrs -> NodeCounts -> (SharingAcc arrs, NodeCounts)
        reconstruct newAcc subCount
          | occCount > 1 = ( VarSharing sn
                           , nodeCount (StableSharingAcc sn sharingAcc, 1) +++ newCount)
          | otherwise    = (sharingAcc, newCount)
          where
              -- Determine the bindings that need to be attached to the current node...
            (newCount, bindHere) = filterCompleted subCount

              -- ...and wrap them in 'LetSharing' constructors.  The fold makes the last element
              -- of 'bindHere' the outermost binding; as node counts list parents before their
              -- subterms, a shared subterm is bound outside of the shared terms containing it.
            lets       = foldl (flip (.)) id . map LetSharing $ bindHere
            sharingAcc = lets $ AccSharing sn newAcc

        -- Extract nodes that have a complete node count (i.e., their node count is equal to the
        -- number of occurrences of that node in the overall expression) => nodes with a completed
        -- node count should be let bound at the currently processed node.  The relative order
        -- of both the remaining and the extracted entries is preserved.
        --
        filterCompleted :: NodeCounts -> (NodeCounts, [StableSharingAcc])
        filterCompleted (NodeCounts counts) 
          = let (counts', completed) = fc counts
            in (NodeCounts counts', completed)
          where
            fc []                 = ([], [])
            fc (sub@(sa, n):subs)
                -- current node is the binding point for the shared node 'sa'
              | occCount == n     = (subs', sa:bindHere)
                -- not a binding point
              | otherwise         = (sub:subs', bindHere)
              where
                -- total occurrences of 'sa' (NB: shadows the 'occCount' of the current node)
                occCount          = lookupWithSharingAcc occMap sa
                (subs', bindHere) = fc subs

    -- Traverse a scalar expression, collecting the node counts of the shared array computations
    -- it embeds ('IndexScalar', 'Shape', 'Size').  Scalar nodes are rebuilt unchanged.
    --
    -- If 'floatOutAcc' is 'True', every embedded array computation that is not already a
    -- variable is replaced by a 'VarSharing' node and reported with a node count of one, so
    -- that it is let-bound outside of the expression.
    --
    scopesExp :: forall arrs. SharingExp arrs -> (SharingExp arrs, NodeCounts)
    scopesExp pacc
      = case pacc of
          Tag i           -> (Tag i, noNodeCounts)
          Const c         -> (Const c, noNodeCounts)
          Tuple tup       -> let (tup', accCount) = travTup tup in (Tuple tup', accCount)
          Prj i e         -> travE1 (Prj i) e
          IndexNil        -> (IndexNil, noNodeCounts)
          IndexCons ix i  -> travE2 IndexCons ix i
          IndexHead i     -> travE1 IndexHead i
          IndexTail ix    -> travE1 IndexTail ix
          IndexAny        -> (IndexAny, noNodeCounts)
          Cond e1 e2 e3   -> travE3 Cond e1 e2 e3
          PrimConst c     -> (PrimConst c, noNodeCounts)
          PrimApp p e     -> travE1 (PrimApp p) e
          IndexScalar a e -> travAE IndexScalar a e
          Shape a         -> travA Shape a
          Size a          -> travA Size a
      where
        -- tuple components, merged in order
        travTup :: Tuple.Tuple (PreExp SharingAcc) tup 
                -> (Tuple.Tuple (PreExp SharingAcc) tup, NodeCounts)
        travTup NilTup          = (NilTup, noNodeCounts)
        travTup (SnocTup tup e) = let
                                    (tup', accCountT) = travTup tup
                                    (e'  , accCountE) = scopesExp e
                                  in
                                  (SnocTup tup' e', accCountT +++ accCountE)

        travE1 :: (SharingExp a -> SharingExp b) -> SharingExp a -> (SharingExp b, NodeCounts)
        travE1 c e = (c e', accCount)
          where
            (e', accCount) = scopesExp e

        travE2 :: (SharingExp a -> SharingExp b -> SharingExp c) -> SharingExp a -> SharingExp b 
               -> (SharingExp c, NodeCounts)
        travE2 c e1 e2 = (c e1' e2', accCount1 +++ accCount2)
          where
            (e1', accCount1) = scopesExp e1
            (e2', accCount2) = scopesExp e2

        travE3 :: (SharingExp a -> SharingExp b -> SharingExp c -> SharingExp d) 
               -> SharingExp a -> SharingExp b -> SharingExp c 
               -> (SharingExp d, NodeCounts)
        travE3 c e1 e2 e3 = (c e1' e2' e3', accCount1 +++ accCount2 +++ accCount3)
          where
            (e1', accCount1) = scopesExp e1
            (e2', accCount2) = scopesExp e2
            (e3', accCount3) = scopesExp e3

        -- embedded array computation
        travA :: (SharingAcc a -> SharingExp b) -> SharingAcc a -> (SharingExp b, NodeCounts)
        travA c acc = maybeFloatOutAcc c acc' accCount
          where
            (acc', accCount)  = scopesAcc acc
        
        -- embedded array computation indexed by a scalar expression
        travAE :: (SharingAcc a -> SharingExp b -> SharingExp c) -> SharingAcc a -> SharingExp b 
               -> (SharingExp c, NodeCounts)
        travAE c acc e = maybeFloatOutAcc (flip c e') acc' (accCountA +++ accCountE)
          where
            (acc', accCountA) = scopesAcc acc
            (e'  , accCountE) = scopesExp e
        
        -- Apply the node constructor 'c' to the embedded array computation, or, when floating
        -- is enabled and the computation is not already a variable, to a 'VarSharing' standing
        -- for it; the floated computation is added to the node counts with a count of one.
        maybeFloatOutAcc :: (SharingAcc a -> SharingExp b) -> SharingAcc a -> NodeCounts
                         -> (SharingExp b, NodeCounts)
        maybeFloatOutAcc c acc@(VarSharing _) accCount = (c acc, accCount)  -- nothing to float out
        maybeFloatOutAcc c acc                accCount
          | floatOutAcc = (c var, nodeCount (stableAcc, 1) +++ accCount)
          | otherwise   = (c acc, accCount)
          where
             (var, stableAcc) = abstract acc id

        -- Split an array computation into a variable and the term it stands for.  Surrounding
        -- 'LetSharing' bindings are accumulated in 'lets' and kept with the term.  A bare
        -- 'VarSharing' is an internal error: 'maybeFloatOutAcc' handles variables before
        -- calling this function.
        abstract :: SharingAcc a -> (SharingAcc a -> SharingAcc a) 
                 -> (SharingAcc a, StableSharingAcc)
        abstract (VarSharing _)        _
          = INTERNAL_ERROR(error) "determineScopes: abstract" "unexpected 'VarSharing'"
        abstract (LetSharing sa acc)   lets = abstract acc (lets . LetSharing sa)
        abstract acc@(AccSharing sn _) lets = (VarSharing sn, StableSharingAcc sn (lets acc))

    -- At this point the lambda-bound variable is irrelevant (see Note [Traversing functions and
    -- side effects]), so the function is applied to 'undefined' and the resulting body is
    -- returned under a function that ignores its argument.
    --
    scopesFun1 :: Elt e1 => (Exp e1 -> SharingExp e2) -> (Exp e1 -> SharingExp e2, NodeCounts)
    scopesFun1 f
      = let (body, counts) = scopesExp (f undefined)
        in  (const body, counts)

    -- At this point both lambda-bound variables are irrelevant (see Note [Traversing functions
    -- and side effects]), so the function is applied to 'undefined' twice and the resulting body
    -- is returned under a function that ignores both arguments.
    --
    scopesFun2 :: (Elt e1, Elt e2) 
               => (Exp e1 -> Exp e2 -> SharingExp e3) 
               -> (Exp e1 -> Exp e2 -> SharingExp e3, NodeCounts)
    scopesFun2 f
      = let (body, counts) = scopesExp (f undefined undefined)
        in  (\_ _ -> body, counts)

    -- At this point the stencil argument is irrelevant (see Note [Traversing functions and side
    -- effects]), so the stencil function is applied to 'undefined'.  The array argument only
    -- fixes the types.
    --
    scopesStencil1 :: forall sh e1 e2 stencil. Stencil sh e1 stencil
                   => SharingAcc (Array sh e1){-dummy-}
                   -> (stencil -> SharingExp e2) 
                   -> (stencil -> SharingExp e2, NodeCounts)
    scopesStencil1 _ stencilFun
      = let (body, counts) = scopesExp (stencilFun undefined)
        in  (const body, counts)
          
    -- At this point both stencil arguments are irrelevant (see Note [Traversing functions and
    -- side effects]), so the stencil function is applied to 'undefined' twice.  The array
    -- arguments only fix the types.
    --
    scopesStencil2 :: forall sh e1 e2 e3 stencil1 stencil2. 
                      (Stencil sh e1 stencil1, Stencil sh e2 stencil2)
                   => SharingAcc (Array sh e1){-dummy-}
                   -> SharingAcc (Array sh e2){-dummy-}
                   -> (stencil1 -> stencil2 -> SharingExp e3) 
                   -> (stencil1 -> stencil2 -> SharingExp e3, NodeCounts)
    scopesStencil2 _ _ stencilFun
      = let (body, counts) = scopesExp (stencilFun undefined undefined)
        in  (\_ _ -> body, counts)
                  
-- |Recover sharing information and annotate the HOAS AST with variable and let binding
--  annotations.  Phase one ('makeOccMap') builds the occurrence map, which is traced and
--  frozen; phase two ('determineScopes') places the bindings.  The first argument determines
--  whether array computations are floated out of expressions irrespective of whether they are
--  shared or not — 'True' implies floating them out.
--
-- NB: Strictly speaking, this function is not deterministic, as it uses stable pointers to
--     determine the sharing of subterms.  The stable pointer API does not guarantee its
--     completeness; i.e., it may miss some equalities, which implies that we may fail to discover
--     some sharing.  However, sharing does not affect the denotational meaning of an array
--     computation; hence, we do not compromise denotational correctness.
--
recoverSharing :: Typeable a => Bool -> Acc a -> SharingAcc a
{-# NOINLINE recoverSharing #-}
recoverSharing floatOutAcc acc 
  = determineScopes floatOutAcc frozenOccMap annotatedAcc
  where
    -- 'unsafePerformIO' is needed as we use stable pointers; it is safe as explained above.
    (annotatedAcc, frozenOccMap) = unsafePerformIO $ do
      (acc', occMap) <- makeOccMap acc
      occMapList     <- Hash.toList occMap
      traceChunk "OccMap" $ show occMapList
      frozen         <- freezeOccMap occMap
      return (acc', frozen)


-- Embedded expressions of the surface language
-- --------------------------------------------

-- HOAS expressions mirror the constructors of `AST.OpenExp', but with the
-- `Tag' constructor instead of variables in the form of de Bruijn indices.
-- Moreover, HOAS expressions use n-tuples and the type class 'Elt' to
-- constrain element types, whereas `AST.OpenExp' uses nested pairs and the 
-- GADT 'TupleType'.
--

-- |Scalar expressions to parametrise collective array operations, themselves parameterised over
-- the type of collective array operations.
--
-- The parameter 'acc' is the representation of embedded array computations: 'Exp' instantiates
-- it with 'Acc' and 'SharingExp' with 'SharingAcc'.
--
data PreExp acc t where
    -- Needed for conversion to de Bruijn form ('convertOpenExp' maps it to a variable via
    -- 'prjIdx')
  Tag         :: Elt t
              => Int                           -> PreExp acc t
                 -- environment size at defining occurrence

    -- All the same constructors as 'AST.Exp'
    -- constant of an element type
  Const       :: Elt t 
              => t                                               -> PreExp acc t

    -- tuple construction and projection of a tuple component
  Tuple       :: (Elt t, IsTuple t)
              => Tuple.Tuple (PreExp acc) (TupleRepr t)          -> PreExp acc t
  Prj         :: (Elt t, IsTuple t)
              => TupleIdx (TupleRepr t) e     
              -> PreExp acc t                                    -> PreExp acc e
    -- construction and deconstruction of shape/index values
  IndexNil    ::                                                    PreExp acc Z
  IndexCons   :: (Slice sl, Elt a)
              => PreExp acc sl -> PreExp acc a                   -> PreExp acc (sl:.a)
  IndexHead   :: (Slice sl, Elt a)
              => PreExp acc (sl:.a)                              -> PreExp acc a
  IndexTail   :: (Slice sl, Elt a)
              => PreExp acc (sl:.a)                              -> PreExp acc sl
  IndexAny    :: Shape sh
              =>                                                    PreExp acc (Any sh)
    -- conditional expression
  Cond        :: PreExp acc Bool -> PreExp acc t -> PreExp acc t -> PreExp acc t
    -- primitive constants and applications of primitive functions
  PrimConst   :: Elt t                       
              => PrimConst t                                     -> PreExp acc t
  PrimApp     :: (Elt a, Elt r)             
              => PrimFun (a -> r) -> PreExp acc a                -> PreExp acc r
    -- nodes embedding an array computation of type 'acc': indexing, shape, and size
  IndexScalar :: (Shape sh, Elt t)
              => acc (Array sh t) -> PreExp acc sh               -> PreExp acc t
  Shape       :: (Shape sh, Elt e)
              => acc (Array sh e)                                -> PreExp acc sh
  Size        :: (Shape sh, Elt e)
              => acc (Array sh e)                                -> PreExp acc Int

-- |Scalar expressions for plain array computations.
--
-- Embedded array computations are plain 'Acc' terms without sharing annotations.
--
type Exp t = PreExp Acc t

-- |Scalar expressions for array computations with sharing annotations.
--
-- Produced by sharing recovery (e.g. 'traverseExp') and consumed by 'convertOpenExp'.
--
type SharingExp t = PreExp SharingAcc t

-- |Conversion from HOAS to de Bruijn expression AST
-- -------------------------------------------------

-- |Convert an open scalar expression into its de Bruijn form.
--
-- Scalar variables ('Tag') are resolved through the scalar layout; embedded
-- array computations are converted with 'convertSharingAcc', using the array
-- layout and the sharing variables currently in scope.
--
convertOpenExp :: forall t env aenv.
                  Layout env  env       -- scalar environment
               -> Layout aenv aenv      -- array environment
               -> [StableSharingAcc]    -- currently bound sharing variables
               -> SharingExp t          -- expression to be converted
               -> AST.OpenExp env aenv t
convertOpenExp lyt alyt env = go
  where
    -- The worker is polymorphic in the expression type, so that it can
    -- descend into sub-expressions of any type.
    go :: SharingExp s -> AST.OpenExp env aenv s
    go expr = case expr of
      Tag i            -> AST.Var (prjIdx i lyt)
      Const v          -> AST.Const (fromElt v)
      Tuple tup        -> AST.Tuple (convertTuple lyt alyt env tup)
      Prj idx x        -> AST.Prj idx (go x)
      IndexNil         -> AST.IndexNil
      IndexCons ix i   -> AST.IndexCons (go ix) (go i)
      IndexHead ix     -> AST.IndexHead (go ix)
      IndexTail ix     -> AST.IndexTail (go ix)
      IndexAny         -> AST.IndexAny
      Cond c x y       -> AST.Cond (go c) (go x) (go y)
      PrimConst c      -> AST.PrimConst c
      PrimApp f x      -> AST.PrimApp f (go x)
      IndexScalar a ix -> AST.IndexScalar (convertSharingAcc alyt env a) (go ix)
      Shape a          -> AST.Shape (convertSharingAcc alyt env a)
      Size a           -> AST.Size (convertSharingAcc alyt env a)

-- |Convert a tuple expression component by component, keeping the snoc-list
-- structure of the tuple intact.
--
convertTuple :: Layout env env
             -> Layout aenv aenv
             -> [StableSharingAcc]                 -- currently bound sharing variables
             -> Tuple.Tuple (PreExp SharingAcc) t
             -> Tuple.Tuple (AST.OpenExp env aenv) t
convertTuple lyt alyt env tup = case tup of
  NilTup       -> NilTup
  SnocTup xs x -> SnocTup (convertTuple lyt alyt env xs)
                          (convertOpenExp lyt alyt env x)

-- |Convert an expression that is closed with respect to scalar variables:
-- the scalar environment is empty, so only the array layout is supplied.
--
convertExp :: Layout aenv aenv      -- array environment
           -> [StableSharingAcc]    -- currently bound sharing variables
           -> SharingExp t          -- expression to be converted
           -> AST.Exp aenv t
convertExp = convertOpenExp EmptyLayout

-- |Convert a unary scalar function.  The HOAS function is applied to the
-- variable with de Bruijn index 0, and the resulting body is converted under
-- a one-entry scalar layout.
--
convertFun1 :: forall a b aenv. Elt a
            => Layout aenv aenv
            -> [StableSharingAcc]               -- currently bound sharing variables
            -> (Exp a -> SharingExp b)
            -> AST.Fun aenv (a -> b)
convertFun1 alyt env f = Lam (Body body)
  where
    body = convertOpenExp lyt alyt env (f (Tag 0))

    idxA :: Idx ((), EltRepr a) (EltRepr a)
    idxA = ZeroIdx

    lyt = PushLayout EmptyLayout idxA

-- |Convert a binary scalar function.  The first argument is bound to de
-- Bruijn index 1 and the second to index 0, matching the nesting of the two
-- 'Lam' binders.
--
convertFun2 :: forall a b c aenv. (Elt a, Elt b)
            => Layout aenv aenv
            -> [StableSharingAcc]               -- currently bound sharing variables
            -> (Exp a -> Exp b -> SharingExp c)
            -> AST.Fun aenv (a -> b -> c)
convertFun2 alyt env f = Lam (Lam (Body body))
  where
    body = convertOpenExp lyt alyt env (f (Tag 1) (Tag 0))

    idxA :: Idx (((), EltRepr a), EltRepr b) (EltRepr a)
    idxA = SuccIdx ZeroIdx

    idxB :: Idx (((), EltRepr a), EltRepr b) (EltRepr b)
    idxB = ZeroIdx

    lyt = PushLayout (PushLayout EmptyLayout idxA) idxB

-- |Convert a unary stencil function.  The whole stencil is bound to de Bruijn
-- index 0 as one tuple-typed expression; the user's function receives the
-- individual positions obtained from it with 'stencilPrj'.
--
convertStencilFun :: forall sh a stencil b aenv. (Elt a, Stencil sh a stencil)
                  => SharingAcc (Array sh a)            -- just passed to fix the type variables
                  -> Layout aenv aenv
                  -> [StableSharingAcc]                 -- currently bound sharing variables
                  -> (stencil -> SharingExp b)
                  -> AST.Fun aenv (StencilRepr sh stencil -> b)
convertStencilFun _ alyt env stencilFun = Lam (Body body)
  where
    body = convertOpenExp lyt alyt env (stencilFun positions)

    var :: Exp (StencilRepr sh stencil)
    var = Tag 0

    positions :: stencil
    positions = stencilPrj (undefined :: sh) (undefined :: a) var

    idx :: Idx ((), EltRepr (StencilRepr sh stencil))
               (EltRepr (StencilRepr sh stencil))
    idx = ZeroIdx

    lyt = PushLayout EmptyLayout idx

-- |Convert a binary stencil function.  The stencil over the first array is
-- bound to de Bruijn index 1 and the stencil over the second array to index 0;
-- each is decomposed into its positions with 'stencilPrj'.
--
convertStencilFun2 :: forall sh a b stencil1 stencil2 c aenv.
                      (Elt a, Stencil sh a stencil1,
                       Elt b, Stencil sh b stencil2)
                   => SharingAcc (Array sh a)           -- just passed to fix the type variables
                   -> SharingAcc (Array sh b)           -- just passed to fix the type variables
                   -> Layout aenv aenv
                   -> [StableSharingAcc]                 -- currently bound sharing variables
                   -> (stencil1 -> stencil2 -> SharingExp c)
                   -> AST.Fun aenv (StencilRepr sh stencil1 ->
                                    StencilRepr sh stencil2 -> c)
convertStencilFun2 _ _ alyt env stencilFun = Lam (Lam (Body body))
  where
    body = convertOpenExp lyt alyt env (stencilFun positions1 positions2)

    var1 :: Exp (StencilRepr sh stencil1)
    var1 = Tag 1

    var2 :: Exp (StencilRepr sh stencil2)
    var2 = Tag 0

    positions1 :: stencil1
    positions1 = stencilPrj (undefined :: sh) (undefined :: a) var1

    positions2 :: stencil2
    positions2 = stencilPrj (undefined :: sh) (undefined :: b) var2

    idx1 :: Idx (((), EltRepr (StencilRepr sh stencil1)),
                      EltRepr (StencilRepr sh stencil2))
                (EltRepr (StencilRepr sh stencil1))
    idx1 = SuccIdx ZeroIdx

    idx2 :: Idx (((), EltRepr (StencilRepr sh stencil1)),
                      EltRepr (StencilRepr sh stencil2))
                (EltRepr (StencilRepr sh stencil2))
    idx2 = ZeroIdx

    lyt = PushLayout (PushLayout EmptyLayout idx1) idx2


-- Pretty printing
--

-- |Shows the de Bruijn form produced by 'convertAcc'.
instance Arrays arrs => Show (Acc arrs) where
  show acc = show (convertAcc acc)
  
-- |Shows a closed scalar expression by converting it to de Bruijn form with an
-- empty array layout and no bound sharing variables.
instance Show (Exp a) where
  show expr = show (convertExp EmptyLayout [] (share expr))
    where
      -- Rebuild the expression over 'SharingAcc'.  Every embedded array
      -- computation goes through 'recoverSharing' (with its flag set to
      -- 'False'); scalar nodes are copied unchanged.
      share :: Exp b -> SharingExp b
      share (Tag i)           = Tag i
      share (Const v)         = Const v
      share (Tuple tup)       = Tuple (shareTuple tup)
      share (Prj idx x)       = Prj idx (share x)
      share IndexNil          = IndexNil
      share (IndexCons ix i)  = IndexCons (share ix) (share i)
      share (IndexHead ix)    = IndexHead (share ix)
      share (IndexTail ix)    = IndexTail (share ix)
      share IndexAny          = IndexAny
      share (Cond c x y)      = Cond (share c) (share x) (share y)
      share (PrimConst c)     = PrimConst c
      share (PrimApp f x)     = PrimApp f (share x)
      share (IndexScalar a x) = IndexScalar (recoverSharing False a) (share x)
      share (Shape a)         = Shape (recoverSharing False a)
      share (Size a)          = Size (recoverSharing False a)

      shareTuple :: Tuple.Tuple (PreExp Acc) tup -> Tuple.Tuple (PreExp SharingAcc) tup
      shareTuple NilTup         = NilTup
      shareTuple (SnocTup xs x) = SnocTup (shareTuple xs) (share x)

-- |Name of the outermost 'PreAcc' constructor (debugging aid).
--
showPreAccOp :: PreAcc acc arrs -> String
showPreAccOp op = case op of
  Atag _             -> "Atag"
  Pipe _ _ _         -> "Pipe"
  Acond _ _ _        -> "Acond"
  FstArray _         -> "FstArray"
  SndArray _         -> "SndArray"
  PairArrays _ _     -> "PairArrays"
  Use _              -> "Use"
  Unit _             -> "Unit"
  Generate _ _       -> "Generate"
  Reshape _ _        -> "Reshape"
  Replicate _ _      -> "Replicate"
  Index _ _          -> "Index"
  Map _ _            -> "Map"
  ZipWith _ _ _      -> "ZipWith"
  Fold _ _ _         -> "Fold"
  Fold1 _ _          -> "Fold1"
  FoldSeg _ _ _ _    -> "FoldSeg"
  Fold1Seg _ _ _     -> "Fold1Seg"
  Scanl _ _ _        -> "Scanl"
  Scanl' _ _ _       -> "Scanl'"
  Scanl1 _ _         -> "Scanl1"
  Scanr _ _ _        -> "Scanr"
  Scanr' _ _ _       -> "Scanr'"
  Scanr1 _ _         -> "Scanr1"
  Permute _ _ _ _    -> "Permute"
  Backpermute _ _ _  -> "Backpermute"
  Stencil _ _ _      -> "Stencil"
  Stencil2 _ _ _ _ _ -> "Stencil2"

-- |Debugging description of a 'SharingAcc' node.  Variables show the hash of
-- their stable name, let-bindings are prefixed with "LET", and plain nodes
-- show their 'PreAcc' constructor name.
_showSharingAccOp :: SharingAcc arrs -> String
_showSharingAccOp sacc = case sacc of
  VarSharing sn     -> "VAR " ++ show (hashStableName sn)
  LetSharing _ body -> "LET " ++ _showSharingAccOp body
  AccSharing _ pacc -> showPreAccOp pacc


-- |Smart constructors to construct representation AST forms
-- ---------------------------------------------------------

-- |Build an 'AST.Index' node.  The slice-index witness comes from
-- 'sliceIndex', applied to an undefined proxy of the slice type @slix@.
mkIndex :: forall slix e aenv. (Slice slix, Elt e)
        => AST.OpenAcc                aenv (Array (FullShape slix) e)
        -> AST.Exp                    aenv slix
        -> AST.PreOpenAcc AST.OpenAcc aenv (Array (SliceShape slix) e)
mkIndex arr ix = AST.Index (sliceIndex (undefined :: slix)) arr ix

-- |Build an 'AST.Replicate' node.  The slice-index witness comes from
-- 'sliceIndex', applied to an undefined proxy of the slice type @slix@.
mkReplicate :: forall slix e aenv. (Slice slix, Elt e)
        => AST.Exp                    aenv slix
        -> AST.OpenAcc                aenv (Array (SliceShape slix) e)
        -> AST.PreOpenAcc AST.OpenAcc aenv (Array (FullShape slix) e)
mkReplicate ix arr = AST.Replicate (sliceIndex (undefined :: slix)) ix arr


-- |Smart constructors for stencil reification
-- -------------------------------------------

-- Stencil reification
--
-- In the AST representation, we turn the stencil type from nested tuples of Accelerate expressions
-- into an Accelerate expression whose type is a tuple nested in the same manner.  This enables us
-- to represent the stencil function as a unary function (which also only needs one de Bruijn
-- index). The various positions in the stencil are accessed via tuple indices (i.e., projections).

-- |Stencil patterns, given as (nested) tuples of 'Exp's, together with their
-- reified representation as one expression of (nested) tuple type.
class (Elt (StencilRepr sh stencil), AST.Stencil sh a (StencilRepr sh stencil))
      => Stencil sh a stencil where
  -- |Element-level tuple type that represents the stencil pattern.
  type StencilRepr sh stencil :: *

  -- |Recover the stencil pattern from the single expression representing it.
  stencilPrj :: sh                              -- ^ dummy; only fixes the shape type
             -> a                               -- ^ dummy; only fixes the element type
             -> Exp (StencilRepr sh stencil)    -- ^ the whole stencil as one expression
             -> stencil

-- DIM1
--
-- One-dimensional stencils are flat tuples of expressions.  Each position is
-- projected out of the stencil expression by a tuple index, and the left-most
-- position takes the highest index.

instance Elt e => Stencil DIM1 e (Exp e, Exp e, Exp e) where
  type StencilRepr DIM1 (Exp e, Exp e, Exp e) = (e, e, e)
  stencilPrj _ _ s = (Prj tix2 s, Prj tix1 s, Prj tix0 s)

instance Elt e => Stencil DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e) where
  type StencilRepr DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e) = (e, e, e, e, e)
  stencilPrj _ _ s
    = (Prj tix4 s, Prj tix3 s, Prj tix2 s, Prj tix1 s, Prj tix0 s)

instance Elt e
      => Stencil DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) where
  type StencilRepr DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e)
    = (e, e, e, e, e, e, e)
  stencilPrj _ _ s
    = ( Prj tix6 s, Prj tix5 s, Prj tix4 s, Prj tix3 s
      , Prj tix2 s, Prj tix1 s, Prj tix0 s )

instance Elt e
      => Stencil DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) where
  type StencilRepr DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e)
    = (e, e, e, e, e, e, e, e, e)
  stencilPrj _ _ s
    = ( Prj tix8 s, Prj tix7 s, Prj tix6 s, Prj tix5 s, Prj tix4 s
      , Prj tix3 s, Prj tix2 s, Prj tix1 s, Prj tix0 s )

-- DIM(n+1)
--
-- A stencil over @sh:.Int:.Int@ is a tuple of rows, each of which is itself a
-- stencil over @sh:.Int@.  Every row is projected out of the stencil
-- expression by a tuple index and then decomposed recursively.

instance (Stencil (sh:.Int) a row2,
          Stencil (sh:.Int) a row1,
          Stencil (sh:.Int) a row0)
  => Stencil (sh:.Int:.Int) a (row2, row1, row0) where
  type StencilRepr (sh:.Int:.Int) (row2, row1, row0)
    = (StencilRepr (sh:.Int) row2,
       StencilRepr (sh:.Int) row1,
       StencilRepr (sh:.Int) row0)
  stencilPrj _ elt s
    = ( stencilPrj (undefined :: (sh:.Int)) elt (Prj tix2 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix1 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix0 s) )

instance (Stencil (sh:.Int) a row1,
          Stencil (sh:.Int) a row2,
          Stencil (sh:.Int) a row3,
          Stencil (sh:.Int) a row4,
          Stencil (sh:.Int) a row5)
  => Stencil (sh:.Int:.Int) a (row1, row2, row3, row4, row5) where
  type StencilRepr (sh:.Int:.Int) (row1, row2, row3, row4, row5)
    = (StencilRepr (sh:.Int) row1,
       StencilRepr (sh:.Int) row2,
       StencilRepr (sh:.Int) row3,
       StencilRepr (sh:.Int) row4,
       StencilRepr (sh:.Int) row5)
  stencilPrj _ elt s
    = ( stencilPrj (undefined :: (sh:.Int)) elt (Prj tix4 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix3 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix2 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix1 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix0 s) )

instance (Stencil (sh:.Int) a row1,
          Stencil (sh:.Int) a row2,
          Stencil (sh:.Int) a row3,
          Stencil (sh:.Int) a row4,
          Stencil (sh:.Int) a row5,
          Stencil (sh:.Int) a row6,
          Stencil (sh:.Int) a row7)
  => Stencil (sh:.Int:.Int) a (row1, row2, row3, row4, row5, row6, row7) where
  type StencilRepr (sh:.Int:.Int) (row1, row2, row3, row4, row5, row6, row7)
    = (StencilRepr (sh:.Int) row1,
       StencilRepr (sh:.Int) row2,
       StencilRepr (sh:.Int) row3,
       StencilRepr (sh:.Int) row4,
       StencilRepr (sh:.Int) row5,
       StencilRepr (sh:.Int) row6,
       StencilRepr (sh:.Int) row7)
  stencilPrj _ elt s
    = ( stencilPrj (undefined :: (sh:.Int)) elt (Prj tix6 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix5 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix4 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix3 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix2 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix1 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix0 s) )

instance (Stencil (sh:.Int) a row1,
          Stencil (sh:.Int) a row2,
          Stencil (sh:.Int) a row3,
          Stencil (sh:.Int) a row4,
          Stencil (sh:.Int) a row5,
          Stencil (sh:.Int) a row6,
          Stencil (sh:.Int) a row7,
          Stencil (sh:.Int) a row8,
          Stencil (sh:.Int) a row9)
  => Stencil (sh:.Int:.Int) a (row1, row2, row3, row4, row5, row6, row7, row8, row9) where
  type StencilRepr (sh:.Int:.Int) (row1, row2, row3, row4, row5, row6, row7, row8, row9)
    = (StencilRepr (sh:.Int) row1,
       StencilRepr (sh:.Int) row2,
       StencilRepr (sh:.Int) row3,
       StencilRepr (sh:.Int) row4,
       StencilRepr (sh:.Int) row5,
       StencilRepr (sh:.Int) row6,
       StencilRepr (sh:.Int) row7,
       StencilRepr (sh:.Int) row8,
       StencilRepr (sh:.Int) row9)
  stencilPrj _ elt s
    = ( stencilPrj (undefined :: (sh:.Int)) elt (Prj tix8 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix7 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix6 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix5 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix4 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix3 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix2 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix1 s)
      , stencilPrj (undefined :: (sh:.Int)) elt (Prj tix0 s) )
  
-- Auxiliary tuple index constants
--
-- In the nested-pair (snoc-list) representation of a tuple, 'tixN' selects
-- the component N places before the last one.  So 'tix0' selects the last
-- component, 'tix1' the one before it, and so on.  Each index wraps the
-- previous one in 'SuccTupIdx'.
--
tix0 :: Elt s => TupleIdx (t, s) s
tix0 = ZeroTupIdx
tix1 :: Elt s => TupleIdx ((t, s), s1) s
tix1 = SuccTupIdx tix0
tix2 :: Elt s => TupleIdx (((t, s), s1), s2) s
tix2 = SuccTupIdx tix1
tix3 :: Elt s => TupleIdx ((((t, s), s1), s2), s3) s
tix3 = SuccTupIdx tix2
tix4 :: Elt s => TupleIdx (((((t, s), s1), s2), s3), s4) s
tix4 = SuccTupIdx tix3
tix5 :: Elt s => TupleIdx ((((((t, s), s1), s2), s3), s4), s5) s
tix5 = SuccTupIdx tix4
tix6 :: Elt s => TupleIdx (((((((t, s), s1), s2), s3), s4), s5), s6) s
tix6 = SuccTupIdx tix5
tix7 :: Elt s => TupleIdx ((((((((t, s), s1), s2), s3), s4), s5), s6), s7) s
tix7 = SuccTupIdx tix6
tix8 :: Elt s => TupleIdx (((((((((t, s), s1), s2), s3), s4), s5), s6), s7), s8) s
tix8 = SuccTupIdx tix7

-- |Split a computation that produces a pair of arrays into one computation per
-- component, built with 'FstArray' and 'SndArray'.
--
unpair :: (Shape sh1, Shape sh2, Elt e1, Elt e2)
       => Acc (Array sh1 e1, Array sh2 e2)
       -> (Acc (Array sh1 e1), Acc (Array sh2 e2))
unpair acc = (firstA, secondA)
  where
    firstA  = Acc (FstArray acc)
    secondA = Acc (SndArray acc)

-- |Combine two array computations into one computation producing the pair of
-- their results (a 'PairArrays' node).
--
pair :: (Shape sh1, Shape sh2, Elt e1, Elt e2)
     => Acc (Array sh1 e1)
     -> Acc (Array sh2 e2)
     -> Acc (Array sh1 e1, Array sh2 e2)
pair lhs rhs = Acc (PairArrays lhs rhs)


-- Smart constructor for literals
-- 

-- |Lift a plain Haskell value into a constant scalar expression.
--
constant :: Elt t => t -> Exp t
constant v = Const v

-- Smart constructors and destructors for tuples
--

-- Each 'tupN' turns an N-tuple of expressions into one tuple-valued
-- expression.  The components are appended left to right onto 'NilTup' with
-- 'SnocTup'.
-- (Continuation lines are indented with spaces only; the previous version
-- used TAB characters, which GHC flags with -Wtabs.)

-- |Pair two expressions.
tup2 :: (Elt a, Elt b) => (Exp a, Exp b) -> Exp (a, b)
tup2 (x1, x2) = Tuple (NilTup `SnocTup` x1 `SnocTup` x2)

-- |Combine three expressions into a triple.
tup3 :: (Elt a, Elt b, Elt c) => (Exp a, Exp b, Exp c) -> Exp (a, b, c)
tup3 (x1, x2, x3) = Tuple (NilTup `SnocTup` x1 `SnocTup` x2 `SnocTup` x3)

-- |Combine four expressions into a 4-tuple.
tup4 :: (Elt a, Elt b, Elt c, Elt d)
     => (Exp a, Exp b, Exp c, Exp d) -> Exp (a, b, c, d)
tup4 (x1, x2, x3, x4)
  = Tuple (NilTup `SnocTup` x1 `SnocTup` x2 `SnocTup` x3 `SnocTup` x4)

-- |Combine five expressions into a 5-tuple.
tup5 :: (Elt a, Elt b, Elt c, Elt d, Elt e)
     => (Exp a, Exp b, Exp c, Exp d, Exp e) -> Exp (a, b, c, d, e)
tup5 (x1, x2, x3, x4, x5)
  = Tuple $
      NilTup `SnocTup` x1 `SnocTup` x2 `SnocTup` x3 `SnocTup` x4 `SnocTup` x5

-- |Combine six expressions into a 6-tuple.
tup6 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f)
     => (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f) -> Exp (a, b, c, d, e, f)
tup6 (x1, x2, x3, x4, x5, x6)
  = Tuple $
      NilTup `SnocTup` x1 `SnocTup` x2 `SnocTup` x3 `SnocTup` x4 `SnocTup` x5 `SnocTup` x6

-- |Combine seven expressions into a 7-tuple.
tup7 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g)
     => (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g)
     -> Exp (a, b, c, d, e, f, g)
tup7 (x1, x2, x3, x4, x5, x6, x7)
  = Tuple $
      NilTup `SnocTup` x1 `SnocTup` x2 `SnocTup` x3
             `SnocTup` x4 `SnocTup` x5 `SnocTup` x6 `SnocTup` x7

-- |Combine eight expressions into an 8-tuple.
tup8 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h)
     => (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h)
     -> Exp (a, b, c, d, e, f, g, h)
tup8 (x1, x2, x3, x4, x5, x6, x7, x8)
  = Tuple $
      NilTup `SnocTup` x1 `SnocTup` x2 `SnocTup` x3 `SnocTup` x4
             `SnocTup` x5 `SnocTup` x6 `SnocTup` x7 `SnocTup` x8

-- |Combine nine expressions into a 9-tuple.
tup9 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i)
     => (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i)
     -> Exp (a, b, c, d, e, f, g, h, i)
tup9 (x1, x2, x3, x4, x5, x6, x7, x8, x9)
  = Tuple $
      NilTup `SnocTup` x1 `SnocTup` x2 `SnocTup` x3 `SnocTup` x4
             `SnocTup` x5 `SnocTup` x6 `SnocTup` x7 `SnocTup` x8 `SnocTup` x9

-- Each 'untupN' splits a tuple-valued expression into its N component
-- expressions by projection.  It uses the shared tuple index constants
-- 'tix0' .. 'tix8', as the 'Stencil' instances do, instead of spelling out
-- nested 'SuccTupIdx' chains.  'tixN' selects the component N places before
-- the last one, so the first component of an N-tuple is projected with
-- tix(N-1) and the last with 'tix0'.

-- |Split a pair expression into its components.
untup2 :: (Elt a, Elt b) => Exp (a, b) -> (Exp a, Exp b)
untup2 e = (Prj tix1 e, Prj tix0 e)

-- |Split a triple expression into its components.
untup3 :: (Elt a, Elt b, Elt c) => Exp (a, b, c) -> (Exp a, Exp b, Exp c)
untup3 e = (Prj tix2 e, Prj tix1 e, Prj tix0 e)

-- |Split a 4-tuple expression into its components.
untup4 :: (Elt a, Elt b, Elt c, Elt d)
       => Exp (a, b, c, d) -> (Exp a, Exp b, Exp c, Exp d)
untup4 e = (Prj tix3 e, Prj tix2 e, Prj tix1 e, Prj tix0 e)

-- |Split a 5-tuple expression into its components.
untup5 :: (Elt a, Elt b, Elt c, Elt d, Elt e)
       => Exp (a, b, c, d, e) -> (Exp a, Exp b, Exp c, Exp d, Exp e)
untup5 e = (Prj tix4 e, Prj tix3 e, Prj tix2 e, Prj tix1 e, Prj tix0 e)

-- |Split a 6-tuple expression into its components.
untup6 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f)
       => Exp (a, b, c, d, e, f) -> (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f)
untup6 e = ( Prj tix5 e, Prj tix4 e, Prj tix3 e
           , Prj tix2 e, Prj tix1 e, Prj tix0 e )

-- |Split a 7-tuple expression into its components.
untup7 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g)
       => Exp (a, b, c, d, e, f, g) -> (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g)
untup7 e = ( Prj tix6 e, Prj tix5 e, Prj tix4 e, Prj tix3 e
           , Prj tix2 e, Prj tix1 e, Prj tix0 e )

-- |Split an 8-tuple expression into its components.
untup8 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h)
       => Exp (a, b, c, d, e, f, g, h) -> (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h)
untup8 e = ( Prj tix7 e, Prj tix6 e, Prj tix5 e, Prj tix4 e
           , Prj tix3 e, Prj tix2 e, Prj tix1 e, Prj tix0 e )

-- |Split a 9-tuple expression into its components.
untup9 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i)
       => Exp (a, b, c, d, e, f, g, h, i) -> (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i)
untup9 e = ( Prj tix8 e, Prj tix7 e, Prj tix6 e, Prj tix5 e, Prj tix4 e
           , Prj tix3 e, Prj tix2 e, Prj tix1 e, Prj tix0 e )

-- Smart constructors for constants
-- 

-- |The primitive constant 'PrimMinBound' at the bounded type @t@.
mkMinBound :: (Elt t, IsBounded t) => Exp t
mkMinBound = PrimConst $ PrimMinBound boundedType

-- |The primitive constant 'PrimMaxBound' at the bounded type @t@.
mkMaxBound :: (Elt t, IsBounded t) => Exp t
mkMaxBound = PrimConst $ PrimMaxBound boundedType

-- |The primitive constant 'PrimPi' at the floating type @r@.
mkPi :: (Elt r, IsFloating r) => Exp r
mkPi = PrimConst $ PrimPi floatingType


-- Smart constructors for primitive applications
--

-- Operators from Floating

-- Every smart constructor below wraps its operand(s) in a 'PrimApp' of the
-- correspondingly named primitive.  Two-argument primitives take their
-- operands as a single pair built with 'tup2'.

mkSin :: (Elt t, IsFloating t) => Exp t -> Exp t
mkSin = PrimApp (PrimSin floatingType)

mkCos :: (Elt t, IsFloating t) => Exp t -> Exp t
mkCos = PrimApp (PrimCos floatingType)

mkTan :: (Elt t, IsFloating t) => Exp t -> Exp t
mkTan = PrimApp (PrimTan floatingType)

mkAsin :: (Elt t, IsFloating t) => Exp t -> Exp t
mkAsin = PrimApp (PrimAsin floatingType)

mkAcos :: (Elt t, IsFloating t) => Exp t -> Exp t
mkAcos = PrimApp (PrimAcos floatingType)

mkAtan :: (Elt t, IsFloating t) => Exp t -> Exp t
mkAtan = PrimApp (PrimAtan floatingType)

mkAsinh :: (Elt t, IsFloating t) => Exp t -> Exp t
mkAsinh = PrimApp (PrimAsinh floatingType)

mkAcosh :: (Elt t, IsFloating t) => Exp t -> Exp t
mkAcosh = PrimApp (PrimAcosh floatingType)

mkAtanh :: (Elt t, IsFloating t) => Exp t -> Exp t
mkAtanh = PrimApp (PrimAtanh floatingType)

mkExpFloating :: (Elt t, IsFloating t) => Exp t -> Exp t
mkExpFloating = PrimApp (PrimExpFloating floatingType)

mkSqrt :: (Elt t, IsFloating t) => Exp t -> Exp t
mkSqrt = PrimApp (PrimSqrt floatingType)

mkLog :: (Elt t, IsFloating t) => Exp t -> Exp t
mkLog = PrimApp (PrimLog floatingType)

mkFPow :: (Elt t, IsFloating t) => Exp t -> Exp t -> Exp t
mkFPow x y = PrimApp (PrimFPow floatingType) (tup2 (x, y))

mkLogBase :: (Elt t, IsFloating t) => Exp t -> Exp t -> Exp t
mkLogBase x y = PrimApp (PrimLogBase floatingType) (tup2 (x, y))

-- Operators from Num

mkAdd :: (Elt t, IsNum t) => Exp t -> Exp t -> Exp t
mkAdd x y = PrimApp (PrimAdd numType) (tup2 (x, y))

mkSub :: (Elt t, IsNum t) => Exp t -> Exp t -> Exp t
mkSub x y = PrimApp (PrimSub numType) (tup2 (x, y))

mkMul :: (Elt t, IsNum t) => Exp t -> Exp t -> Exp t
mkMul x y = PrimApp (PrimMul numType) (tup2 (x, y))

mkNeg :: (Elt t, IsNum t) => Exp t -> Exp t
mkNeg = PrimApp (PrimNeg numType)

mkAbs :: (Elt t, IsNum t) => Exp t -> Exp t
mkAbs = PrimApp (PrimAbs numType)

mkSig :: (Elt t, IsNum t) => Exp t -> Exp t
mkSig = PrimApp (PrimSig numType)

-- Operators from Integral & Bits

mkQuot :: (Elt t, IsIntegral t) => Exp t -> Exp t -> Exp t
mkQuot x y = PrimApp (PrimQuot integralType) (tup2 (x, y))

mkRem :: (Elt t, IsIntegral t) => Exp t -> Exp t -> Exp t
mkRem x y = PrimApp (PrimRem integralType) (tup2 (x, y))

mkIDiv :: (Elt t, IsIntegral t) => Exp t -> Exp t -> Exp t
mkIDiv x y = PrimApp (PrimIDiv integralType) (tup2 (x, y))

mkMod :: (Elt t, IsIntegral t) => Exp t -> Exp t -> Exp t
mkMod x y = PrimApp (PrimMod integralType) (tup2 (x, y))

mkBAnd :: (Elt t, IsIntegral t) => Exp t -> Exp t -> Exp t
mkBAnd x y = PrimApp (PrimBAnd integralType) (tup2 (x, y))

mkBOr :: (Elt t, IsIntegral t) => Exp t -> Exp t -> Exp t
mkBOr x y = PrimApp (PrimBOr integralType) (tup2 (x, y))

mkBXor :: (Elt t, IsIntegral t) => Exp t -> Exp t -> Exp t
mkBXor x y = PrimApp (PrimBXor integralType) (tup2 (x, y))

mkBNot :: (Elt t, IsIntegral t) => Exp t -> Exp t
mkBNot = PrimApp (PrimBNot integralType)

-- The shift and rotate amounts are 'Int' expressions.

mkBShiftL :: (Elt t, IsIntegral t) => Exp t -> Exp Int -> Exp t
mkBShiftL x n = PrimApp (PrimBShiftL integralType) (tup2 (x, n))

mkBShiftR :: (Elt t, IsIntegral t) => Exp t -> Exp Int -> Exp t
mkBShiftR x n = PrimApp (PrimBShiftR integralType) (tup2 (x, n))

mkBRotateL :: (Elt t, IsIntegral t) => Exp t -> Exp Int -> Exp t
mkBRotateL x n = PrimApp (PrimBRotateL integralType) (tup2 (x, n))

mkBRotateR :: (Elt t, IsIntegral t) => Exp t -> Exp Int -> Exp t
mkBRotateR x n = PrimApp (PrimBRotateR integralType) (tup2 (x, n))

-- Operators from Fractional

mkFDiv :: (Elt t, IsFloating t) => Exp t -> Exp t -> Exp t
mkFDiv x y = PrimApp (PrimFDiv floatingType) (tup2 (x, y))

mkRecip :: (Elt t, IsFloating t) => Exp t -> Exp t
mkRecip = PrimApp (PrimRecip floatingType)

-- Operators from RealFrac

mkTruncate :: (Elt a, Elt b, IsFloating a, IsIntegral b) => Exp a -> Exp b
mkTruncate = PrimApp (PrimTruncate floatingType integralType)

mkRound :: (Elt a, Elt b, IsFloating a, IsIntegral b) => Exp a -> Exp b
mkRound = PrimApp (PrimRound floatingType integralType)

mkFloor :: (Elt a, Elt b, IsFloating a, IsIntegral b) => Exp a -> Exp b
mkFloor = PrimApp (PrimFloor floatingType integralType)

mkCeiling :: (Elt a, Elt b, IsFloating a, IsIntegral b) => Exp a -> Exp b
mkCeiling = PrimApp (PrimCeiling floatingType integralType)

-- Operators from RealFloat

mkAtan2 :: (Elt t, IsFloating t) => Exp t -> Exp t -> Exp t
mkAtan2 x y = PrimApp (PrimAtan2 floatingType) (tup2 (x, y))

-- FIXME: add missing operations from Floating, RealFrac & RealFloat

-- Relational and equality operators

-- Binary comparisons wrap their operands in a pair with 'tup2' and apply the
-- correspondingly named primitive via 'PrimApp'.

mkLt :: (Elt t, IsScalar t) => Exp t -> Exp t -> Exp Bool
mkLt x y = PrimApp (PrimLt scalarType) (tup2 (x, y))

mkGt :: (Elt t, IsScalar t) => Exp t -> Exp t -> Exp Bool
mkGt x y = PrimApp (PrimGt scalarType) (tup2 (x, y))

mkLtEq :: (Elt t, IsScalar t) => Exp t -> Exp t -> Exp Bool
mkLtEq x y = PrimApp (PrimLtEq scalarType) (tup2 (x, y))

mkGtEq :: (Elt t, IsScalar t) => Exp t -> Exp t -> Exp Bool
mkGtEq x y = PrimApp (PrimGtEq scalarType) (tup2 (x, y))

mkEq :: (Elt t, IsScalar t) => Exp t -> Exp t -> Exp Bool
mkEq x y = PrimApp (PrimEq scalarType) (tup2 (x, y))

mkNEq :: (Elt t, IsScalar t) => Exp t -> Exp t -> Exp Bool
mkNEq x y = PrimApp (PrimNEq scalarType) (tup2 (x, y))

mkMax :: (Elt t, IsScalar t) => Exp t -> Exp t -> Exp t
mkMax x y = PrimApp (PrimMax scalarType) (tup2 (x, y))

mkMin :: (Elt t, IsScalar t) => Exp t -> Exp t -> Exp t
mkMin x y = PrimApp (PrimMin scalarType) (tup2 (x, y))

-- Logical operators

mkLAnd :: Exp Bool -> Exp Bool -> Exp Bool
mkLAnd x y = PrimApp PrimLAnd (tup2 (x, y))

mkLOr :: Exp Bool -> Exp Bool -> Exp Bool
mkLOr x y = PrimApp PrimLOr (tup2 (x, y))

mkLNot :: Exp Bool -> Exp Bool
mkLNot = PrimApp PrimLNot

-- FIXME: Character conversions

-- FIXME: Numeric conversions

mkFromIntegral :: (Elt a, Elt b, IsIntegral a, IsNum b) => Exp a -> Exp b
mkFromIntegral = PrimApp (PrimFromIntegral integralType numType)

-- FIXME: Other conversions

mkBoolToInt :: Exp Bool -> Exp Int
mkBoolToInt = PrimApp PrimBoolToInt


-- Auxiliary functions
-- --------------------

-- The operators below post-compose a unary function @f@ onto a function @g@
-- of two to five arguments: @(f $$ g) x y = f (g x y)@, and so on.  Each one
-- is defined by supplying the first argument to @g@ and deferring to the
-- operator one arity lower.

infixr 0 $$
-- |Post-compose onto a binary function.
($$) :: (b -> a) -> (c -> d -> b) -> c -> d -> a
(f $$ g) x = f . g x

infixr 0 $$$
-- |Post-compose onto a ternary function.
($$$) :: (b -> a) -> (c -> d -> e -> b) -> c -> d -> e -> a
(f $$$ g) x = f $$ g x

infixr 0 $$$$
-- |Post-compose onto a four-argument function.
($$$$) :: (b -> a) -> (c -> d -> e -> f -> b) -> c -> d -> e -> f -> a
(f $$$$ g) x = f $$$ g x

infixr 0 $$$$$
-- |Post-compose onto a five-argument function.
($$$$$) :: (b -> a) -> (c -> d -> e -> f -> g -> b) -> c -> d -> e -> f -> g-> a
(f $$$$$ g) x = f $$$$ g x