{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.Smart (
Acc(..), SmartAcc(..), PreSmartAcc(..),
Level, Direction(..),
Exp(..), SmartExp(..), PreSmartExp(..),
Stencil(..),
Boundary(..), PreBoundary(..),
PrimBool,
PrimMaybe,
HasArraysR(..),
HasTypeR(..),
constant, undef,
indexHead, indexTail,
mkMinBound, mkMaxBound, mkPi,
mkSin, mkCos, mkTan,
mkAsin, mkAcos, mkAtan,
mkSinh, mkCosh, mkTanh,
mkAsinh, mkAcosh, mkAtanh,
mkExpFloating, mkSqrt, mkLog,
mkFPow, mkLogBase,
mkTruncate, mkRound, mkFloor, mkCeiling,
mkAtan2,
mkAdd, mkSub, mkMul, mkNeg, mkAbs, mkSig, mkQuot, mkRem, mkQuotRem, mkIDiv, mkMod, mkDivMod,
mkBAnd, mkBOr, mkBXor, mkBNot, mkBShiftL, mkBShiftR, mkBRotateL, mkBRotateR, mkPopCount, mkCountLeadingZeros, mkCountTrailingZeros,
mkFDiv, mkRecip, mkLt, mkGt, mkLtEq, mkGtEq, mkEq, mkNEq, mkMax, mkMin,
mkLAnd, mkLOr, mkLNot, mkIsNaN, mkIsInfinite,
mkFromIntegral, mkToFloating, mkBitcast, mkCoerce, Coerce(..),
($$), ($$$), ($$$$), ($$$$$),
ApplyAcc(..),
unAcc, unAccFunction, mkExp, unExp, unExpFunction, unExpBinaryFunction, unPair, mkPairToTuple,
showPreAccOp,
showPreExpOp,
) where
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Slice
import Data.Array.Accelerate.Representation.Stencil hiding ( StencilR, stencilR )
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Representation.Vec
import Data.Array.Accelerate.Sugar.Array ( Arrays )
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Foreign
import Data.Array.Accelerate.Sugar.Shape ( (:.)(..) )
import Data.Array.Accelerate.Type
import qualified Data.Array.Accelerate.Representation.Stencil as R
import qualified Data.Array.Accelerate.Sugar.Array as Sugar
import qualified Data.Array.Accelerate.Sugar.Shape as Sugar
import Data.Array.Accelerate.AST ( Direction(..)
, PrimBool, PrimMaybe
, PrimFun(..), primFunType
, PrimConst(..), primConstType )
import Data.Primitive.Vec
import Data.Kind
import Prelude
import GHC.TypeLits
-- | Surface-level array computation: a newtype around a 'SmartAcc' term
-- indexed by @Sugar.ArraysR a@, the representation type of the surface
-- type @a@.
newtype Acc a = Acc (SmartAcc (Sugar.ArraysR a))
-- | Ties the recursive knot of 'PreSmartAcc' by fixing its array sub-terms to
-- 'SmartAcc' and its scalar sub-terms to 'SmartExp'.
newtype SmartAcc a = SmartAcc (PreSmartAcc SmartAcc SmartExp a)
-- | Integer level stored in the 'Atag' and 'Tag' constructors.
-- NOTE(review): presumably identifies a binder while higher-order terms are
-- being converted -- confirm against the code that creates the tags.
type Level = Int
-- | Array-level terms of the embedded language. The type is parameterised
-- over the type of array sub-terms (@acc@) and scalar sub-terms (@exp@), and
-- is indexed by the representation type of the result (@as@). Function
-- arguments are host-language functions taking 'SmartAcc' / 'SmartExp' terms.
data PreSmartAcc acc exp as where
  -- | Leaf carrying only its representation and a 'Level'.
  Atag :: ArraysR as
    -> Level
    -> PreSmartAcc acc exp as
  -- | Two array functions (@as -> bs@ and @bs -> cs@) with the argument of
  -- the first, plus the representations of all three types.
  Pipe :: ArraysR as
    -> ArraysR bs
    -> ArraysR cs
    -> (SmartAcc as -> acc bs)
    -> (SmartAcc bs -> acc cs)
    -> acc as
    -> PreSmartAcc acc exp cs
  -- | A foreign implementation @asm (as -> bs)@ together with a
  -- 'SmartAcc'-level function of the same type (presumably the fallback
  -- implementation -- confirm) and the argument.
  Aforeign :: Foreign asm
    => ArraysR bs
    -> asm (as -> bs)
    -> (SmartAcc as -> SmartAcc bs)
    -> acc as
    -> PreSmartAcc acc exp bs
  -- | Conditional: a 'PrimBool' scalar and two array terms of the same type.
  Acond :: exp PrimBool
    -> acc as
    -> acc as
    -> PreSmartAcc acc exp as
  -- | Loop: a predicate yielding a @Scalar PrimBool@, a body, and an initial
  -- value, all over the same array type @arrs@.
  Awhile :: ArraysR arrs
    -> (SmartAcc arrs -> acc (Scalar PrimBool))
    -> (SmartAcc arrs -> acc arrs)
    -> acc arrs
    -> PreSmartAcc acc exp arrs
  -- | The empty tuple of arrays.
  Anil :: PreSmartAcc acc exp ()
  -- | Pairing of two array terms.
  Apair :: acc arrs1
    -> acc arrs2
    -> PreSmartAcc acc exp (arrs1, arrs2)
  -- | Projection of one component out of a pair of array terms.
  Aprj :: PairIdx (arrs1, arrs2) arrs
    -> acc (arrs1, arrs2)
    -> PreSmartAcc acc exp arrs
  -- | Embeds a host array value together with its representation.
  Use :: ArrayR (Array sh e)
    -> Array sh e
    -> PreSmartAcc acc exp (Array sh e)
  -- | Wraps a scalar expression as a 'Scalar' array.
  Unit :: TypeR e
    -> exp e
    -> PreSmartAcc acc exp (Scalar e)
  -- | Array built from a shape expression and a function from @sh@ to the
  -- element type.
  Generate :: ArrayR (Array sh e)
    -> exp sh
    -> (SmartExp sh -> exp e)
    -> PreSmartAcc acc exp (Array sh e)
  -- | Gives an array of shape type @sh'@ the new shape @sh@.
  Reshape :: ShapeR sh
    -> exp sh
    -> acc (Array sh' e)
    -> PreSmartAcc acc exp (Array sh e)
  -- | Turns an array of the slice shape @sl@ into one of the full shape @sh@,
  -- as described by the 'SliceIndex' and the @slix@ expression.
  Replicate :: SliceIndex slix sl co sh
    -> exp slix
    -> acc (Array sl e)
    -> PreSmartAcc acc exp (Array sh e)
  -- | Turns an array of the full shape @sh@ into one of the slice shape @sl@,
  -- as described by the 'SliceIndex' and the @slix@ expression.
  Slice :: SliceIndex slix sl co sh
    -> acc (Array sh e)
    -> exp slix
    -> PreSmartAcc acc exp (Array sl e)
  -- | Applies a scalar function to an array; the shape type is unchanged.
  Map :: TypeR e
    -> TypeR e'
    -> (SmartExp e -> exp e')
    -> acc (Array sh e)
    -> PreSmartAcc acc exp (Array sh e')
  -- | Combines two arrays of the same shape type with a binary function.
  ZipWith :: TypeR e1
    -> TypeR e2
    -> TypeR e3
    -> (SmartExp e1 -> SmartExp e2 -> exp e3)
    -> acc (Array sh e1)
    -> acc (Array sh e2)
    -> PreSmartAcc acc exp (Array sh e3)
  -- | Takes an array of shape @(sh, Int)@ to one of shape @sh@ using a binary
  -- function and an optional @exp e@ (presumably the initial element).
  Fold :: TypeR e
    -> (SmartExp e -> SmartExp e -> exp e)
    -> Maybe (exp e)
    -> acc (Array (sh, Int) e)
    -> PreSmartAcc acc exp (Array sh e)
  -- | Like 'Fold' but with an additional 'Segments' array; the result keeps
  -- the shape type @(sh, Int)@.
  FoldSeg :: IntegralType i
    -> TypeR e
    -> (SmartExp e -> SmartExp e -> exp e)
    -> Maybe (exp e)
    -> acc (Array (sh, Int) e)
    -> acc (Segments i)
    -> PreSmartAcc acc exp (Array (sh, Int) e)
  -- | Directional scan with an optional @exp e@; the result has the same
  -- shape type as the input.
  Scan :: Direction
    -> TypeR e
    -> (SmartExp e -> SmartExp e -> exp e)
    -> Maybe (exp e)
    -> acc (Array (sh, Int) e)
    -> PreSmartAcc acc exp (Array (sh, Int) e)
  -- | Directional scan with a mandatory @exp e@ that yields two arrays: one
  -- of shape @(sh, Int)@ and one of shape @sh@.
  Scan' :: Direction
    -> TypeR e
    -> (SmartExp e -> SmartExp e -> exp e)
    -> exp e
    -> acc (Array (sh, Int) e)
    -> PreSmartAcc acc exp (Array (sh, Int) e, Array sh e)
  -- | A combination function, an array of the result shape @sh'@, an index
  -- mapping from @sh@ to @PrimMaybe sh'@, and a source array of shape @sh@.
  Permute :: ArrayR (Array sh e)
    -> (SmartExp e -> SmartExp e -> exp e)
    -> acc (Array sh' e)
    -> (SmartExp sh -> exp (PrimMaybe sh'))
    -> acc (Array sh e)
    -> PreSmartAcc acc exp (Array sh' e)
  -- | A result shape @sh'@, an index mapping from @sh'@ to @sh@, and a source
  -- array of shape @sh@.
  Backpermute :: ShapeR sh'
    -> exp sh'
    -> (SmartExp sh' -> exp sh)
    -> acc (Array sh e)
    -> PreSmartAcc acc exp (Array sh' e)
  -- | Stencil function over one array, with its boundary specification.
  Stencil :: R.StencilR sh a stencil
    -> TypeR b
    -> (SmartExp stencil -> exp b)
    -> PreBoundary acc exp (Array sh a)
    -> acc (Array sh a)
    -> PreSmartAcc acc exp (Array sh b)
  -- | Binary stencil function over two arrays of the same shape type, each
  -- with its own boundary specification.
  Stencil2 :: R.StencilR sh a stencil1
    -> R.StencilR sh b stencil2
    -> TypeR c
    -> (SmartExp stencil1 -> SmartExp stencil2 -> exp c)
    -> PreBoundary acc exp (Array sh a)
    -> acc (Array sh a)
    -> PreBoundary acc exp (Array sh b)
    -> acc (Array sh b)
    -> PreSmartAcc acc exp (Array sh c)
-- | Surface-level scalar expression: a newtype around a 'SmartExp' term
-- indexed by @EltR t@, the representation type of the surface type @t@.
newtype Exp t = Exp (SmartExp (EltR t))
-- | Ties the recursive knot of 'PreSmartExp' by fixing its array sub-terms to
-- 'SmartAcc' and its scalar sub-terms to 'SmartExp'.
newtype SmartExp t = SmartExp (PreSmartExp SmartAcc SmartExp t)
-- | Scalar-level terms of the embedded language, parameterised over the type
-- of array sub-terms (@acc@) and scalar sub-terms (@exp@), and indexed by the
-- representation type of the result (@t@).
data PreSmartExp acc exp t where
  -- | Leaf carrying only its type representation and a 'Level'.
  Tag :: TypeR t
    -> Level
    -> PreSmartExp acc exp t
  -- | An expression annotated with a 'TagR' of the same type.
  Match :: TagR t
    -> exp t
    -> PreSmartExp acc exp t
  -- | A host scalar value together with its scalar type.
  Const :: ScalarType t
    -> t
    -> PreSmartExp acc exp t
  -- | The unit value.
  Nil :: PreSmartExp acc exp ()
  -- | Pairing of two expressions.
  Pair :: exp t1
    -> exp t2
    -> PreSmartExp acc exp (t1, t2)
  -- | Projection of one component out of a pair.
  Prj :: PairIdx (t1, t2) t
    -> exp (t1, t2)
    -> PreSmartExp acc exp t
  -- | Converts the tuple type @tup@ into @Vec n s@, as related by 'VecR'.
  VecPack :: KnownNat n
    => VecR n s tup
    -> exp tup
    -> PreSmartExp acc exp (Vec n s)
  -- | Converts @Vec n s@ into the tuple type @tup@, as related by 'VecR'.
  VecUnpack :: KnownNat n
    => VecR n s tup
    -> exp (Vec n s)
    -> PreSmartExp acc exp tup
  -- | Two expressions of shape type @sh@ to an 'Int'
  -- (presumably extent and index, in that order -- confirm).
  ToIndex :: ShapeR sh
    -> exp sh
    -> exp sh
    -> PreSmartExp acc exp Int
  -- | An expression of shape type @sh@ and an 'Int' to a value of type @sh@.
  FromIndex :: ShapeR sh
    -> exp sh
    -> exp Int
    -> PreSmartExp acc exp sh
  -- | A scrutinee with a list of alternatives, each keyed by a 'TagR'.
  Case :: exp a
    -> [(TagR a, exp b)]
    -> PreSmartExp acc exp b
  -- | Conditional: a 'PrimBool' and two expressions of the same type.
  Cond :: exp PrimBool
    -> exp t
    -> exp t
    -> PreSmartExp acc exp t
  -- | Loop: a predicate yielding 'PrimBool', a body, and an initial value.
  While :: TypeR t
    -> (SmartExp t -> exp PrimBool)
    -> (SmartExp t -> exp t)
    -> exp t
    -> PreSmartExp acc exp t
  -- | A primitive constant.
  PrimConst :: PrimConst t
    -> PreSmartExp acc exp t
  -- | Application of a primitive function to its (possibly tupled) argument.
  PrimApp :: PrimFun (a -> r)
    -> exp a
    -> PreSmartExp acc exp r
  -- | An array term together with an expression of its shape type.
  Index :: TypeR t
    -> acc (Array sh t)
    -> exp sh
    -> PreSmartExp acc exp t
  -- | An array term together with an 'Int' expression.
  LinearIndex :: TypeR t
    -> acc (Array sh t)
    -> exp Int
    -> PreSmartExp acc exp t
  -- | Yields a value of the shape type of the given array term.
  Shape :: ShapeR sh
    -> acc (Array sh e)
    -> PreSmartExp acc exp sh
  -- | Maps an expression of shape type @sh@ to an 'Int'.
  ShapeSize :: ShapeR sh
    -> exp sh
    -> PreSmartExp acc exp Int
  -- | A foreign implementation @asm (x -> y)@ together with a
  -- 'SmartExp'-level function of the same type (presumably the fallback
  -- implementation -- confirm) and the argument.
  Foreign :: Foreign asm
    => TypeR y
    -> asm (x -> y)
    -> (SmartExp x -> SmartExp y)
    -> exp x
    -> PreSmartExp acc exp y
  -- | A term of the given scalar type with no sub-terms.
  Undef :: ScalarType t
    -> PreSmartExp acc exp t
  -- | Conversion between two scalar types related by 'BitSizeEq'.
  Coerce :: BitSizeEq a b
    => ScalarType a
    -> ScalarType b
    -> exp a
    -> PreSmartExp acc exp b
-- | Surface-level stencil boundary: wraps a 'PreBoundary' over the
-- representation types of the surface array @Sugar.Array sh e@.
data Boundary t where
  Boundary :: PreBoundary SmartAcc SmartExp (Array (EltR sh) (EltR e))
    -> Boundary (Sugar.Array sh e)
-- | Boundary specification for stencil operations, indexed by the array type
-- it applies to. 'Clamp', 'Mirror' and 'Wrap' carry no data; their exact
-- out-of-bounds behaviour is defined by their consumers, not here.
data PreBoundary acc exp t where
  Clamp :: PreBoundary acc exp t
  Mirror :: PreBoundary acc exp t
  Wrap :: PreBoundary acc exp t
  -- | A fixed host value of the element type.
  Constant :: e
    -> PreBoundary acc exp (Array sh e)
  -- | A function from a value of the shape type to an element expression.
  Function :: (SmartExp sh -> exp e)
    -> PreBoundary acc exp (Array sh e)
-- | Relates a surface stencil pattern type @stencil@, over arrays with shape
-- @sh@ and element type @e@, to its representation.
class Stencil sh e stencil where
  -- | Representation type of the stencil pattern.
  type StencilR sh stencil :: Type
  -- | Witness describing the structure of the stencil pattern.
  stencilR :: R.StencilR (EltR sh) (EltR e) (StencilR sh stencil)
  -- | Unpacks a representation-level stencil term into the surface pattern.
  stencilPrj :: SmartExp (StencilR sh stencil) -> stencil
-- | Three-element stencil pattern over a one-dimensional array.
instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e) where
  type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e) = EltR (e, e, e)

  stencilR = StencilRunit3 @(EltR e) (eltR @e)

  -- Components are projected from the highest index down to index 0.
  stencilPrj pat =
    ( Exp (prj2 pat)
    , Exp (prj1 pat)
    , Exp (prj0 pat)
    )
-- | Five-element stencil pattern over a one-dimensional array.
instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e) where
  type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e) = EltR (e, e, e, e, e)

  stencilR = StencilRunit5 (eltR @e)

  -- Components are projected from the highest index down to index 0.
  stencilPrj pat =
    ( Exp (prj4 pat)
    , Exp (prj3 pat)
    , Exp (prj2 pat)
    , Exp (prj1 pat)
    , Exp (prj0 pat)
    )
-- | Seven-element stencil pattern over a one-dimensional array.
instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) where
  type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) =
    EltR (e, e, e, e, e, e, e)

  stencilR = StencilRunit7 (eltR @e)

  -- Components are projected from the highest index down to index 0.
  stencilPrj pat =
    ( Exp (prj6 pat)
    , Exp (prj5 pat)
    , Exp (prj4 pat)
    , Exp (prj3 pat)
    , Exp (prj2 pat)
    , Exp (prj1 pat)
    , Exp (prj0 pat)
    )
-- | Nine-element stencil pattern over a one-dimensional array.
instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) where
  type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) =
    EltR (e, e, e, e, e, e, e, e, e)

  stencilR = StencilRunit9 (eltR @e)

  -- Components are projected from the highest index down to index 0.
  stencilPrj pat =
    ( Exp (prj8 pat)
    , Exp (prj7 pat)
    , Exp (prj6 pat)
    , Exp (prj5 pat)
    , Exp (prj4 pat)
    , Exp (prj3 pat)
    , Exp (prj2 pat)
    , Exp (prj1 pat)
    , Exp (prj0 pat)
    )
-- | Stencil pattern made of three rows, each of which is itself a stencil
-- pattern over an array with one dimension fewer.
instance ( Stencil (sh:.Int) a row2
         , Stencil (sh:.Int) a row1
         , Stencil (sh:.Int) a row0
         ) => Stencil (sh:.Int:.Int) a (row2, row1, row0) where
  type StencilR (sh:.Int:.Int) (row2, row1, row0) =
    Tup3 (StencilR (sh:.Int) row2)
         (StencilR (sh:.Int) row1)
         (StencilR (sh:.Int) row0)

  stencilR =
    StencilRtup3
      (stencilR @(sh:.Int) @a @row2)
      (stencilR @(sh:.Int) @a @row1)
      (stencilR @(sh:.Int) @a @row0)

  -- Each row is projected out and unpacked by the row's own instance.
  stencilPrj pat =
    ( stencilPrj @(sh:.Int) @a (prj2 pat)
    , stencilPrj @(sh:.Int) @a (prj1 pat)
    , stencilPrj @(sh:.Int) @a (prj0 pat)
    )
-- | Stencil pattern made of five rows, each of which is itself a stencil
-- pattern over an array with one dimension fewer.
instance ( Stencil (sh:.Int) a row4
         , Stencil (sh:.Int) a row3
         , Stencil (sh:.Int) a row2
         , Stencil (sh:.Int) a row1
         , Stencil (sh:.Int) a row0
         ) => Stencil (sh:.Int:.Int) a (row4, row3, row2, row1, row0) where
  type StencilR (sh:.Int:.Int) (row4, row3, row2, row1, row0) =
    Tup5 (StencilR (sh:.Int) row4)
         (StencilR (sh:.Int) row3)
         (StencilR (sh:.Int) row2)
         (StencilR (sh:.Int) row1)
         (StencilR (sh:.Int) row0)

  stencilR =
    StencilRtup5
      (stencilR @(sh:.Int) @a @row4)
      (stencilR @(sh:.Int) @a @row3)
      (stencilR @(sh:.Int) @a @row2)
      (stencilR @(sh:.Int) @a @row1)
      (stencilR @(sh:.Int) @a @row0)

  -- Each row is projected out and unpacked by the row's own instance.
  stencilPrj pat =
    ( stencilPrj @(sh:.Int) @a (prj4 pat)
    , stencilPrj @(sh:.Int) @a (prj3 pat)
    , stencilPrj @(sh:.Int) @a (prj2 pat)
    , stencilPrj @(sh:.Int) @a (prj1 pat)
    , stencilPrj @(sh:.Int) @a (prj0 pat)
    )
-- | Stencil pattern made of seven rows, each of which is itself a stencil
-- pattern over an array with one dimension fewer.
instance ( Stencil (sh:.Int) a row6
         , Stencil (sh:.Int) a row5
         , Stencil (sh:.Int) a row4
         , Stencil (sh:.Int) a row3
         , Stencil (sh:.Int) a row2
         , Stencil (sh:.Int) a row1
         , Stencil (sh:.Int) a row0
         ) => Stencil (sh:.Int:.Int) a (row6, row5, row4, row3, row2, row1, row0) where
  type StencilR (sh:.Int:.Int) (row6, row5, row4, row3, row2, row1, row0) =
    Tup7 (StencilR (sh:.Int) row6)
         (StencilR (sh:.Int) row5)
         (StencilR (sh:.Int) row4)
         (StencilR (sh:.Int) row3)
         (StencilR (sh:.Int) row2)
         (StencilR (sh:.Int) row1)
         (StencilR (sh:.Int) row0)

  stencilR =
    StencilRtup7
      (stencilR @(sh:.Int) @a @row6)
      (stencilR @(sh:.Int) @a @row5)
      (stencilR @(sh:.Int) @a @row4)
      (stencilR @(sh:.Int) @a @row3)
      (stencilR @(sh:.Int) @a @row2)
      (stencilR @(sh:.Int) @a @row1)
      (stencilR @(sh:.Int) @a @row0)

  -- Each row is projected out and unpacked by the row's own instance.
  stencilPrj pat =
    ( stencilPrj @(sh:.Int) @a (prj6 pat)
    , stencilPrj @(sh:.Int) @a (prj5 pat)
    , stencilPrj @(sh:.Int) @a (prj4 pat)
    , stencilPrj @(sh:.Int) @a (prj3 pat)
    , stencilPrj @(sh:.Int) @a (prj2 pat)
    , stencilPrj @(sh:.Int) @a (prj1 pat)
    , stencilPrj @(sh:.Int) @a (prj0 pat)
    )
-- | Stencil pattern made of nine rows, each of which is itself a stencil
-- pattern over an array with one dimension fewer.
instance ( Stencil (sh:.Int) a row8
         , Stencil (sh:.Int) a row7
         , Stencil (sh:.Int) a row6
         , Stencil (sh:.Int) a row5
         , Stencil (sh:.Int) a row4
         , Stencil (sh:.Int) a row3
         , Stencil (sh:.Int) a row2
         , Stencil (sh:.Int) a row1
         , Stencil (sh:.Int) a row0
         ) => Stencil (sh:.Int:.Int) a (row8, row7, row6, row5, row4, row3, row2, row1, row0) where
  type StencilR (sh:.Int:.Int) (row8, row7, row6, row5, row4, row3, row2, row1, row0) =
    Tup9 (StencilR (sh:.Int) row8)
         (StencilR (sh:.Int) row7)
         (StencilR (sh:.Int) row6)
         (StencilR (sh:.Int) row5)
         (StencilR (sh:.Int) row4)
         (StencilR (sh:.Int) row3)
         (StencilR (sh:.Int) row2)
         (StencilR (sh:.Int) row1)
         (StencilR (sh:.Int) row0)

  stencilR =
    StencilRtup9
      (stencilR @(sh:.Int) @a @row8)
      (stencilR @(sh:.Int) @a @row7)
      (stencilR @(sh:.Int) @a @row6)
      (stencilR @(sh:.Int) @a @row5)
      (stencilR @(sh:.Int) @a @row4)
      (stencilR @(sh:.Int) @a @row3)
      (stencilR @(sh:.Int) @a @row2)
      (stencilR @(sh:.Int) @a @row1)
      (stencilR @(sh:.Int) @a @row0)

  -- Each row is projected out and unpacked by the row's own instance.
  stencilPrj pat =
    ( stencilPrj @(sh:.Int) @a (prj8 pat)
    , stencilPrj @(sh:.Int) @a (prj7 pat)
    , stencilPrj @(sh:.Int) @a (prj6 pat)
    , stencilPrj @(sh:.Int) @a (prj5 pat)
    , stencilPrj @(sh:.Int) @a (prj4 pat)
    , stencilPrj @(sh:.Int) @a (prj3 pat)
    , stencilPrj @(sh:.Int) @a (prj2 pat)
    , stencilPrj @(sh:.Int) @a (prj1 pat)
    , stencilPrj @(sh:.Int) @a (prj0 pat)
    )
-- | Left component of a pair: everything except the last element of a
-- left-nested tuple.
prjTail :: SmartExp (t, a) -> SmartExp t
prjTail e = SmartExp (Prj PairIdxLeft e)

-- | Right component of a pair: the last element of a left-nested tuple.
prj0 :: SmartExp (t, a) -> SmartExp a
prj0 e = SmartExp (Prj PairIdxRight e)

-- The remaining projections each drop the last element with 'prjTail' and
-- defer to the projection one index lower.

prj1 :: SmartExp ((t, a), s0) -> SmartExp a
prj1 e = prj0 (prjTail e)

prj2 :: SmartExp (((t, a), s1), s0) -> SmartExp a
prj2 e = prj1 (prjTail e)

prj3 :: SmartExp ((((t, a), s2), s1), s0) -> SmartExp a
prj3 e = prj2 (prjTail e)

prj4 :: SmartExp (((((t, a), s3), s2), s1), s0) -> SmartExp a
prj4 e = prj3 (prjTail e)

prj5 :: SmartExp ((((((t, a), s4), s3), s2), s1), s0) -> SmartExp a
prj5 e = prj4 (prjTail e)

prj6 :: SmartExp (((((((t, a), s5), s4), s3), s2), s1), s0) -> SmartExp a
prj6 e = prj5 (prjTail e)

prj7 :: SmartExp ((((((((t, a), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a
prj7 e = prj6 (prjTail e)

prj8 :: SmartExp (((((((((t, a), s7), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a
prj8 e = prj7 (prjTail e)
-- | Term types from which the array representation of the result can be
-- recovered.
class HasArraysR f where
  arraysR :: f a -> ArraysR a

-- | Unwraps the newtype and defers to the instance for the wrapped term.
instance HasArraysR SmartAcc where
  arraysR = \(SmartAcc pacc) -> arraysR pacc
-- | Representation of a term that denotes a single array, obtained by taking
-- the sole element out of the 'TupRsingle' that 'arraysR' returns for it.
arrayR :: HasArraysR f => f (Array sh e) -> ArrayR (Array sh e)
arrayR = unsingle . arraysR
  where
    unsingle :: ArraysR (Array sh e) -> ArrayR (Array sh e)
    unsingle (TupRsingle repr) = repr
-- | Computes the array representation of the result of a term. It is read off
-- a stored representation where the constructor carries one, and is otherwise
-- derived from the representation of a sub-term.
instance HasArraysR acc => HasArraysR (PreSmartAcc acc exp) where
  arraysR = \case
    Atag repr _               -> repr
    Pipe _ _ repr _ _ _       -> repr
    Aforeign repr _ _ _       -> repr
    Acond _ a _               -> arraysR a
    Awhile _ _ _ a            -> arraysR a
    Anil                      -> TupRunit
    Apair a1 a2               -> arraysR a1 `TupRpair` arraysR a2
    Aprj idx a | TupRpair t1 t2 <- arraysR a
                              -> case idx of
                                   PairIdxLeft  -> t1
                                   PairIdxRight -> t2
    -- Fallback for when the representation of the argument is not a
    -- 'TupRpair'. Presumably unreachable, since the argument of 'Aprj' is
    -- pair-typed; it is reported as an internal error with a diagnostic
    -- message, matching the style of the 'typeR' instance in this module.
    Aprj _ _                  -> internalError "Aprj applied to an array term whose representation is not a pair"
    Use repr _                -> TupRsingle repr
    Unit tp _                 -> TupRsingle $ ArrayR ShapeRz $ tp
    Generate repr _ _         -> TupRsingle repr
    -- The following cases keep the element type of the argument and replace
    -- its shape representation.
    Reshape shr _ a           -> let ArrayR _ tp = arrayR a
                                 in  TupRsingle $ ArrayR shr tp
    Replicate si _ a          -> let ArrayR _ tp = arrayR a
                                 in  TupRsingle $ ArrayR (sliceDomainR si) tp
    Slice si a _              -> let ArrayR _ tp = arrayR a
                                 in  TupRsingle $ ArrayR (sliceShapeR si) tp
    -- The following cases keep the shape representation of the (first) array
    -- argument and use the stored result element type.
    Map _ tp _ a              -> let ArrayR shr _ = arrayR a
                                 in  TupRsingle $ ArrayR shr tp
    ZipWith _ _ tp _ a _      -> let ArrayR shr _ = arrayR a
                                 in  TupRsingle $ ArrayR shr tp
    -- Drops the last dimension of the argument's shape representation.
    Fold _ _ _ a              -> let ArrayR (ShapeRsnoc shr) tp = arrayR a
                                 in  TupRsingle (ArrayR shr tp)
    FoldSeg _ _ _ _ a _       -> arraysR a
    Scan _ _ _ _ a            -> arraysR a
    -- Two results: one with the argument's representation and one with the
    -- last dimension dropped.
    Scan' _ _ _ _ a           -> let repr@(ArrayR (ShapeRsnoc shr) tp) = arrayR a
                                 in  TupRsingle repr `TupRpair` TupRsingle (ArrayR shr tp)
    -- The result has the representation of the third field, the array of
    -- type @Array sh' e@.
    Permute _ _ a _ _         -> arraysR a
    Backpermute shr _ _ a     -> let ArrayR _ tp = arrayR a
                                 in  TupRsingle (ArrayR shr tp)
    Stencil s tp _ _ _        -> TupRsingle $ ArrayR (stencilShapeR s) tp
    Stencil2 s _ tp _ _ _ _ _ -> TupRsingle $ ArrayR (stencilShapeR s) tp
-- | Term types from which the scalar type representation of the result can be
-- recovered.
class HasTypeR f where
  typeR :: HasCallStack => f t -> TypeR t

-- | Unwraps the newtype and defers to the instance for the wrapped term.
instance HasTypeR SmartExp where
  typeR = \(SmartExp pexp) -> typeR pexp
-- | Computes the type representation of the result of a scalar term. It is
-- read off a stored representation where the constructor carries one, and is
-- otherwise derived from the type of a sub-term.
instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where
  typeR = \case
    Tag tp _                        -> tp
    Match _ e                       -> typeR e
    Const tp _                      -> TupRsingle tp
    Nil                             -> TupRunit
    Pair e1 e2                      -> typeR e1 `TupRpair` typeR e2
    Prj idx e
      | TupRpair t1 t2 <- typeR e   -> case idx of
                                         PairIdxLeft  -> t1
                                         PairIdxRight -> t2
    -- Fallback for when the type of the argument is not a 'TupRpair'.
    -- Presumably unreachable, since the argument of 'Prj' is pair-typed; it
    -- is reported through 'internalError' like the empty 'Case' below.
    Prj _ _                         -> internalError "Prj applied to an expression whose type is not a pair"
    VecPack   vecR _                -> TupRsingle $ VectorScalarType $ vecRvector vecR
    VecUnpack vecR _                -> vecRtuple vecR
    ToIndex _ _ _                   -> TupRsingle scalarTypeInt
    FromIndex shr _ _               -> shapeType shr
    -- The type of a case expression is that of its first alternative.
    Case _ ((_,c):_)                -> typeR c
    Case{}                          -> internalError "encountered empty case"
    Cond _ e _                      -> typeR e
    While t _ _ _                   -> t
    PrimConst c                     -> TupRsingle $ SingleScalarType $ primConstType c
    PrimApp f _                     -> snd $ primFunType f
    Index tp _ _                    -> tp
    LinearIndex tp _ _              -> tp
    Shape shr _                     -> shapeType shr
    ShapeSize _ _                   -> TupRsingle scalarTypeInt
    Foreign tp _ _ _                -> tp
    Undef tp                        -> TupRsingle tp
    Coerce _ tp _                   -> TupRsingle tp
-- | Embeds a host value as an expression. The value is converted to its
-- representation with 'fromElt' and rebuilt as a tree of 'Nil', 'Const' and
-- 'Pair' nodes that follows the structure of @eltR \@e@.
constant :: forall e. (HasCallStack, Elt e) => e -> Exp e
constant value = Exp (build (eltR @e) (fromElt value))
  where
    build :: HasCallStack => TypeR t -> t -> SmartExp t
    build = \case
      TupRunit       -> \()       -> SmartExp Nil
      TupRsingle tp  -> \c        -> SmartExp (Const tp c)
      TupRpair t1 t2 -> \(c1, c2) -> SmartExp (Pair (build t1 c1) (build t2 c2))
-- | An expression of type @e@ made of 'Undef' leaves, built as a tree of
-- 'Nil', 'Undef' and 'Pair' nodes that follows the structure of @eltR \@e@.
undef :: forall e. Elt e => Exp e
undef = Exp (build (eltR @e))
  where
    build :: TypeR t -> SmartExp t
    build = \case
      TupRunit       -> SmartExp Nil
      TupRsingle t   -> SmartExp (Undef t)
      TupRpair t1 t2 -> SmartExp (Pair (build t1) (build t2))
-- | The right component of an index built with ':.'.
indexHead :: (Elt sh, Elt a) => Exp (sh :. a) -> Exp a
indexHead (Exp ix) = mkExp (Prj PairIdxRight ix)

-- | The left component of an index built with ':.'.
indexTail :: (Elt sh, Elt a) => Exp (sh :. a) -> Exp sh
indexTail (Exp ix) = mkExp (Prj PairIdxLeft ix)
-- | The 'PrimMinBound' constant at the bounded type of @t@.
mkMinBound :: (Elt t, IsBounded (EltR t)) => Exp t
mkMinBound = mkExp (PrimConst (PrimMinBound boundedType))

-- | The 'PrimMaxBound' constant at the bounded type of @t@.
mkMaxBound :: (Elt t, IsBounded (EltR t)) => Exp t
mkMaxBound = mkExp (PrimConst (PrimMaxBound boundedType))

-- | The 'PrimPi' constant at the floating-point type of @r@.
mkPi :: (Elt r, IsFloating (EltR r)) => Exp r
mkPi = mkExp (PrimConst (PrimPi floatingType))
-- Trigonometric primitives. Each one applies the primitive function of the
-- same name at the floating-point type of @t@.

mkSin :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkSin = mkPrimUnary (PrimSin floatingType)

mkCos :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkCos = mkPrimUnary (PrimCos floatingType)

mkTan :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkTan = mkPrimUnary (PrimTan floatingType)

mkAsin :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkAsin = mkPrimUnary (PrimAsin floatingType)

mkAcos :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkAcos = mkPrimUnary (PrimAcos floatingType)

mkAtan :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkAtan = mkPrimUnary (PrimAtan floatingType)
-- Hyperbolic primitives. Each one applies the primitive function of the same
-- name at the floating-point type of @t@.

mkSinh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkSinh = mkPrimUnary (PrimSinh floatingType)

mkCosh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkCosh = mkPrimUnary (PrimCosh floatingType)

mkTanh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkTanh = mkPrimUnary (PrimTanh floatingType)

mkAsinh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkAsinh = mkPrimUnary (PrimAsinh floatingType)

mkAcosh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkAcosh = mkPrimUnary (PrimAcosh floatingType)

mkAtanh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkAtanh = mkPrimUnary (PrimAtanh floatingType)
-- Exponential, root, logarithm and power primitives at the floating-point
-- type of @t@.

mkExpFloating :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkExpFloating = mkPrimUnary (PrimExpFloating floatingType)

mkSqrt :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkSqrt = mkPrimUnary (PrimSqrt floatingType)

mkLog :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkLog = mkPrimUnary (PrimLog floatingType)

mkFPow :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -> Exp t
mkFPow = mkPrimBinary (PrimFPow floatingType)

mkLogBase :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -> Exp t
mkLogBase = mkPrimBinary (PrimLogBase floatingType)
-- Arithmetic primitives at the numeric type of @t@.

mkAdd :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t -> Exp t
mkAdd = mkPrimBinary (PrimAdd numType)

mkSub :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t -> Exp t
mkSub = mkPrimBinary (PrimSub numType)

mkMul :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t -> Exp t
mkMul = mkPrimBinary (PrimMul numType)

mkNeg :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t
mkNeg = mkPrimUnary (PrimNeg numType)

mkAbs :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t
mkAbs = mkPrimUnary (PrimAbs numType)

mkSig :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t
mkSig = mkPrimUnary (PrimSig numType)
-- Integral division primitives at the integral type of @t@.

mkQuot :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t
mkQuot = mkPrimBinary (PrimQuot integralType)

mkRem :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t
mkRem = mkPrimBinary (PrimRem integralType)

-- | Applies 'PrimQuotRem' once and returns the two components of its result
-- as separate expressions that share the same application node.
mkQuotRem :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> (Exp t, Exp t)
mkQuotRem (Exp x) (Exp y) = (mkExp (Prj PairIdxLeft qr), mkExp (Prj PairIdxRight qr))
  where
    qr = SmartExp (PrimApp (PrimQuotRem integralType) (SmartExp (Pair x y)))

mkIDiv :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t
mkIDiv = mkPrimBinary (PrimIDiv integralType)

mkMod :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t
mkMod = mkPrimBinary (PrimMod integralType)

-- | Applies 'PrimDivMod' once and returns the two components of its result
-- as separate expressions that share the same application node.
mkDivMod :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> (Exp t, Exp t)
mkDivMod (Exp x) (Exp y) = (mkExp (Prj PairIdxLeft dm), mkExp (Prj PairIdxRight dm))
  where
    dm = SmartExp (PrimApp (PrimDivMod integralType) (SmartExp (Pair x y)))
-- Bitwise logic primitives at the integral type of @t@.

mkBAnd :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t
mkBAnd = mkPrimBinary (PrimBAnd integralType)

mkBOr :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t
mkBOr = mkPrimBinary (PrimBOr integralType)

mkBXor :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t
mkBXor = mkPrimBinary (PrimBXor integralType)

mkBNot :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t
mkBNot = mkPrimUnary (PrimBNot integralType)
-- Shift and rotate primitives at the integral type of @t@; the second
-- argument is an 'Int' expression.

mkBShiftL :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t
mkBShiftL = mkPrimBinary (PrimBShiftL integralType)

mkBShiftR :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t
mkBShiftR = mkPrimBinary (PrimBShiftR integralType)

mkBRotateL :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t
mkBRotateL = mkPrimBinary (PrimBRotateL integralType)

mkBRotateR :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t
mkBRotateR = mkPrimBinary (PrimBRotateR integralType)
-- Bit-counting primitives at the integral type of @t@; each yields an 'Int'
-- expression.

mkPopCount :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int
mkPopCount = mkPrimUnary (PrimPopCount integralType)

mkCountLeadingZeros :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int
mkCountLeadingZeros = mkPrimUnary (PrimCountLeadingZeros integralType)

mkCountTrailingZeros :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int
mkCountTrailingZeros = mkPrimUnary (PrimCountTrailingZeros integralType)
-- Fractional primitives at the floating-point type of @t@.

mkFDiv :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -> Exp t
mkFDiv = mkPrimBinary (PrimFDiv floatingType)

mkRecip :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkRecip = mkPrimUnary (PrimRecip floatingType)
-- Conversions from the floating-point type of @a@ to the integral type of
-- @b@.

mkTruncate :: (Elt a, Elt b, IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b
mkTruncate = mkPrimUnary (PrimTruncate floatingType integralType)

mkRound :: (Elt a, Elt b, IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b
mkRound = mkPrimUnary (PrimRound floatingType integralType)

mkFloor :: (Elt a, Elt b, IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b
mkFloor = mkPrimUnary (PrimFloor floatingType integralType)

mkCeiling :: (Elt a, Elt b, IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b
mkCeiling = mkPrimUnary (PrimCeiling floatingType integralType)
-- | Applies 'PrimAtan2' at the floating-point type of @t@.
mkAtan2 :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -> Exp t
mkAtan2 = mkPrimBinary (PrimAtan2 floatingType)

-- | Applies 'PrimIsNaN'; the result is converted to a surface 'Bool'
-- expression by 'mkPrimUnaryBool'.
mkIsNaN :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp Bool
mkIsNaN = mkPrimUnaryBool (PrimIsNaN floatingType)

-- | Applies 'PrimIsInfinite'; the result is converted to a surface 'Bool'
-- expression by 'mkPrimUnaryBool'.
mkIsInfinite :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp Bool
mkIsInfinite = mkPrimUnaryBool (PrimIsInfinite floatingType)
-- Comparison primitives at the single-value scalar type of @t@. The
-- relational ones go through 'mkPrimBinaryBool' to yield a surface 'Bool'.

mkLt :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool
mkLt = mkPrimBinaryBool (PrimLt singleType)

mkGt :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool
mkGt = mkPrimBinaryBool (PrimGt singleType)

mkLtEq :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool
mkLtEq = mkPrimBinaryBool (PrimLtEq singleType)

mkGtEq :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool
mkGtEq = mkPrimBinaryBool (PrimGtEq singleType)

mkEq :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool
mkEq = mkPrimBinaryBool (PrimEq singleType)

mkNEq :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool
mkNEq = mkPrimBinaryBool (PrimNEq singleType)

mkMax :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp t
mkMax = mkPrimBinary (PrimMax singleType)

mkMin :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp t
mkMin = mkPrimBinary (PrimMin singleType)
-- Logical primitives. Each argument is reduced to the left component of its
-- pair representation before the primitive is applied, and the primitive's
-- result is paired with 'Nil' again to form the resulting 'Bool' expression.

mkLAnd :: Exp Bool -> Exp Bool -> Exp Bool
mkLAnd (Exp a) (Exp b) = mkExp (Pair conj (SmartExp Nil))
  where
    conj = SmartExp (PrimApp PrimLAnd (SmartExp (Pair lhs rhs)))
    lhs  = SmartExp (Prj PairIdxLeft a)
    rhs  = SmartExp (Prj PairIdxLeft b)

mkLOr :: Exp Bool -> Exp Bool -> Exp Bool
mkLOr (Exp a) (Exp b) = mkExp (Pair disj (SmartExp Nil))
  where
    disj = SmartExp (PrimApp PrimLOr (SmartExp (Pair lhs rhs)))
    lhs  = SmartExp (Prj PairIdxLeft a)
    rhs  = SmartExp (Prj PairIdxLeft b)

mkLNot :: Exp Bool -> Exp Bool
mkLNot (Exp a) = mkExp (Pair neg (SmartExp Nil))
  where
    neg = SmartExp (PrimApp PrimLNot arg)
    arg = SmartExp (Prj PairIdxLeft a)
-- | Applies 'PrimFromIntegral' from the integral type of @a@ to the numeric
-- type of @b@.
mkFromIntegral :: (Elt a, Elt b, IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b
mkFromIntegral = mkPrimUnary (PrimFromIntegral integralType numType)

-- | Applies 'PrimToFloating' from the numeric type of @a@ to the
-- floating-point type of @b@.
mkToFloating :: (Elt a, Elt b, IsNum (EltR a), IsFloating (EltR b)) => Exp a -> Exp b
mkToFloating = mkPrimUnary (PrimToFloating numType floatingType)

-- | Builds a 'Coerce' node between two scalar representations related by
-- 'BitSizeEq'. The target type @b@ comes first in the @forall@, so it can be
-- chosen with a type application.
mkBitcast :: forall b a. (Elt a, Elt b, IsScalar (EltR a), IsScalar (EltR b), BitSizeEq (EltR a) (EltR b)) => Exp a -> Exp b
mkBitcast (Exp x) = mkExp (Coerce (scalarType @(EltR a)) (scalarType @(EltR b)) x)

-- | Converts between surface types whose representations are related by the
-- 'Coerce' class.
mkCoerce :: Coerce (EltR a) (EltR b) => Exp a -> Exp b
mkCoerce (Exp x) = Exp (mkCoerce' x)
-- | Conversions between representation types, at the level of 'SmartExp'
-- terms.
class Coerce a b where
  mkCoerce' :: SmartExp a -> SmartExp b

-- | Two scalar types related by 'BitSizeEq': a single 'Coerce' node.
instance {-# OVERLAPS #-} (IsScalar a, IsScalar b, BitSizeEq a b) => Coerce a b where
  mkCoerce' x = SmartExp (Coerce (scalarType @a) (scalarType @b) x)

-- | Pairs are converted component-wise.
instance (Coerce a1 b1, Coerce a2 b2) => Coerce (a1, a2) (b1, b2) where
  mkCoerce' p =
    SmartExp
      (Pair (mkCoerce' (SmartExp (Prj PairIdxLeft p)))
            (mkCoerce' (SmartExp (Prj PairIdxRight p))))

-- | Identical types: the term is returned unchanged.
instance Coerce a a where
  mkCoerce' x = x

-- | Drops a unit on the left.
instance Coerce ((), a) a where
  mkCoerce' p = SmartExp (Prj PairIdxRight p)

-- | Adds a unit on the left.
instance Coerce a ((), a) where
  mkCoerce' x = SmartExp (Pair (SmartExp Nil) x)

-- | Drops a unit on the right.
instance Coerce (a, ()) a where
  mkCoerce' p = SmartExp (Prj PairIdxLeft p)

-- | Adds a unit on the right.
instance Coerce a (a, ()) where
  mkCoerce' x = SmartExp (Pair x (SmartExp Nil))
-- Composition operators that post-compose a unary function onto a function
-- of two, three, four or five arguments respectively.

infixr 0 $$
($$) :: (b -> a) -> (c -> d -> b) -> c -> d -> a
f $$ g = \x y -> f (g x y)

infixr 0 $$$
($$$) :: (b -> a) -> (c -> d -> e -> b) -> c -> d -> e -> a
f $$$ g = \x y z -> f (g x y z)

infixr 0 $$$$
($$$$) :: (b -> a) -> (c -> d -> e -> f -> b) -> c -> d -> e -> f -> a
f $$$$ g = \x y z u -> f (g x y z u)

infixr 0 $$$$$
($$$$$) :: (b -> a) -> (c -> d -> e -> f -> g -> b) -> c -> d -> e -> f -> g-> a
f $$$$$ g = \x y z u v -> f (g x y z u v)
-- | Unwraps a surface array term.
unAcc :: Arrays a => Acc a -> SmartAcc (Sugar.ArraysR a)
unAcc = \(Acc sacc) -> sacc

-- | Turns a function on surface array terms into one on their unwrapped
-- counterparts.
unAccFunction :: (Arrays a, Arrays b) => (Acc a -> Acc b) -> SmartAcc (Sugar.ArraysR a) -> SmartAcc (Sugar.ArraysR b)
unAccFunction f sacc = unAcc (f (Acc sacc))

-- | Wraps a scalar term node as a surface expression.
mkExp :: PreSmartExp SmartAcc SmartExp (EltR t) -> Exp t
mkExp pexp = Exp (SmartExp pexp)

-- | Unwraps a surface expression.
unExp :: Exp e -> SmartExp (EltR e)
unExp = \(Exp sexp) -> sexp

-- | Turns a unary function on surface expressions into one on their
-- unwrapped counterparts.
unExpFunction :: (Elt a, Elt b) => (Exp a -> Exp b) -> SmartExp (EltR a) -> SmartExp (EltR b)
unExpFunction f sexp = unExp (f (Exp sexp))

-- | Turns a binary function on surface expressions into one on their
-- unwrapped counterparts.
unExpBinaryFunction :: (Elt a, Elt b, Elt c) => (Exp a -> Exp b -> Exp c) -> SmartExp (EltR a) -> SmartExp (EltR b) -> SmartExp (EltR c)
unExpBinaryFunction f l r = unExp (f (Exp l) (Exp r))
-- | Applies a unary primitive function to an expression.
mkPrimUnary :: (Elt a, Elt b) => PrimFun (EltR a -> EltR b) -> Exp a -> Exp b
mkPrimUnary prim (Exp x) = mkExp (PrimApp prim x)

-- | Applies a binary primitive function to two expressions by pairing them.
mkPrimBinary :: (Elt a, Elt b, Elt c) => PrimFun ((EltR a, EltR b) -> EltR c) -> Exp a -> Exp b -> Exp c
mkPrimBinary prim (Exp x) (Exp y) = mkExp (PrimApp prim (SmartExp (Pair x y)))

-- | As 'mkPrimUnary' for primitives that return 'PrimBool'; the result is
-- converted to a surface 'Bool' expression with 'mkCoerce'.
mkPrimUnaryBool :: Elt a => PrimFun (EltR a -> PrimBool) -> Exp a -> Exp Bool
mkPrimUnaryBool prim x = mkCoerce @PrimBool (mkPrimUnary prim x)

-- | As 'mkPrimBinary' for primitives that return 'PrimBool'; the result is
-- converted to a surface 'Bool' expression with 'mkCoerce'.
mkPrimBinaryBool :: (Elt a, Elt b) => PrimFun ((EltR a, EltR b) -> PrimBool) -> Exp a -> Exp b -> Exp Bool
mkPrimBinaryBool prim x y = mkCoerce @PrimBool (mkPrimBinary prim x y)
-- | Splits a pair expression into its left and right projections.
unPair :: SmartExp (a, b) -> (SmartExp a, SmartExp b)
unPair e = (left, right)
  where
    left  = SmartExp (Prj PairIdxLeft e)
    right = SmartExp (Prj PairIdxRight e)

-- | Re-associates a pair of arrays as the left-nested form that starts
-- with unit.
mkPairToTuple :: SmartAcc (a, b) -> SmartAcc (((), a), b)
mkPairToTuple e =
  SmartAcc (Apair (SmartAcc (Apair (SmartAcc Anil) left)) right)
  where
    left  = SmartAcc (Aprj PairIdxLeft e)
    right = SmartAcc (Aprj PairIdxRight e)
-- | Lifts a constructor function over unwrapped terms ('SmartAcc',
-- 'SmartExp', and functions between them) to the corresponding function over
-- surface terms ('Acc', 'Exp').
class ApplyAcc a where
  -- | The unwrapped counterpart of the surface function type @a@.
  type FromApplyAcc a
  applyAcc :: FromApplyAcc a -> a

-- | Base case: a finished term node is wrapped in 'SmartAcc'.
instance ApplyAcc (SmartAcc a) where
  type FromApplyAcc (SmartAcc a) = PreSmartAcc SmartAcc SmartExp a
  applyAcc = SmartAcc

-- | Array argument: unwrapped with 'unAcc'.
instance (Arrays a, ApplyAcc t) => ApplyAcc (Acc a -> t) where
  type FromApplyAcc (Acc a -> t) = SmartAcc (Sugar.ArraysR a) -> FromApplyAcc t
  applyAcc f = \arg -> applyAcc (f (unAcc arg))

-- | Scalar argument: unwrapped with 'unExp'.
instance (Elt a, ApplyAcc t) => ApplyAcc (Exp a -> t) where
  type FromApplyAcc (Exp a -> t) = SmartExp (EltR a) -> FromApplyAcc t
  applyAcc f = \arg -> applyAcc (f (unExp arg))

-- | Unary scalar function argument: unwrapped with 'unExpFunction'.
instance (Elt a, Elt b, ApplyAcc t) => ApplyAcc ((Exp a -> Exp b) -> t) where
  type FromApplyAcc ((Exp a -> Exp b) -> t) = (SmartExp (EltR a) -> SmartExp (EltR b)) -> FromApplyAcc t
  applyAcc f = \arg -> applyAcc (f (unExpFunction arg))

-- | Binary scalar function argument: unwrapped with 'unExpBinaryFunction'.
instance (Elt a, Elt b, Elt c, ApplyAcc t) => ApplyAcc ((Exp a -> Exp b -> Exp c) -> t) where
  type FromApplyAcc ((Exp a -> Exp b -> Exp c) -> t) = (SmartExp (EltR a) -> SmartExp (EltR b) -> SmartExp (EltR c)) -> FromApplyAcc t
  applyAcc f = \arg -> applyAcc (f (unExpBinaryFunction arg))

-- | Array function argument: unwrapped with 'unAccFunction'.
instance (Arrays a, Arrays b, ApplyAcc t) => ApplyAcc ((Acc a -> Acc b) -> t) where
  type FromApplyAcc ((Acc a -> Acc b) -> t) = (SmartAcc (Sugar.ArraysR a) -> SmartAcc (Sugar.ArraysR b)) -> FromApplyAcc t
  applyAcc f = \arg -> applyAcc (f (unAccFunction arg))
-- | Short textual name of the outermost constructor of an array term. 'Atag'
-- includes its level and 'Use' a shortened rendering of the array. The fold
-- and scan names get a @1@ suffix when no @exp e@ argument is given, and the
-- scans also include a direction letter.
showPreAccOp :: forall acc exp arrs. PreSmartAcc acc exp arrs -> String
showPreAccOp = \case
  Atag _ i            -> "Atag " ++ show i
  Use aR a            -> "Use " ++ showArrayShort 5 (showsElt (arrayRtype aR)) aR a
  Pipe{}              -> "Pipe"
  Acond{}             -> "Acond"
  Awhile{}            -> "Awhile"
  Apair{}             -> "Apair"
  Anil{}              -> "Anil"
  Aprj{}              -> "Aprj"
  Unit{}              -> "Unit"
  Generate{}          -> "Generate"
  Reshape{}           -> "Reshape"
  Replicate{}         -> "Replicate"
  Slice{}             -> "Slice"
  Map{}               -> "Map"
  ZipWith{}           -> "ZipWith"
  Fold _ _ z _        -> "Fold" ++ maybe "1" (const "") z
  FoldSeg _ _ _ z _ _ -> "Fold" ++ maybe "1" (const "") z ++ "Seg"
  Scan d _ _ z _      -> "Scan" ++ showsDirection d (maybe "1" (const "") z)
  Scan' d _ _ _ _     -> "Scan" ++ showsDirection d "'"
  Permute{}           -> "Permute"
  Backpermute{}       -> "Backpermute"
  Stencil{}           -> "Stencil"
  Stencil2{}          -> "Stencil2"
  Aforeign{}          -> "Aforeign"
-- | Prepends the one-letter code of a scan direction: @l@ for 'LeftToRight'
-- and @r@ for 'RightToLeft'.
showsDirection :: Direction -> ShowS
showsDirection = \case
  LeftToRight -> showChar 'l'
  RightToLeft -> showChar 'r'
-- | Short textual name of the outermost constructor of a scalar term. 'Tag'
-- is followed directly by its level and 'Const' by a rendering of its value.
showPreExpOp :: PreSmartExp acc exp t -> String
showPreExpOp = \case
  Tag _ i       -> "Tag" ++ show i
  Match{}       -> "Match"
  Const t c     -> "Const " ++ showElt (TupRsingle t) c
  Undef _       -> "Undef"
  Nil{}         -> "Nil"
  Pair{}        -> "Pair"
  Prj{}         -> "Prj"
  VecPack{}     -> "VecPack"
  VecUnpack{}   -> "VecUnpack"
  ToIndex{}     -> "ToIndex"
  FromIndex{}   -> "FromIndex"
  Case{}        -> "Case"
  Cond{}        -> "Cond"
  While{}       -> "While"
  PrimConst{}   -> "PrimConst"
  PrimApp{}     -> "PrimApp"
  Index{}       -> "Index"
  LinearIndex{} -> "LinearIndex"
  Shape{}       -> "Shape"
  ShapeSize{}   -> "ShapeSize"
  Foreign{}     -> "Foreign"
  Coerce{}      -> "Coerce"