module Data.Array.Knead.Shape (
C(..), Index,
Size,
value,
paramWith,
load,
intersect,
offset,
ZeroBased(ZeroBased), zeroBased, zeroBasedSize,
Range(Range), range, rangeFrom, rangeTo,
Shifted(Shifted), shifted, shiftedOffset, shiftedSize,
Enumeration(Enumeration), EnumBounded(..),
Scalar(..),
Sequence(..),
) where
import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Parameter as Param
import Data.Array.Knead.Shape.Orphan
(zeroBased, zeroBasedSize,
singletonRange, unzipRange, singletonShifted, unzipShifted)
import Data.Array.Knead.Expression (Exp, )
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape
(Index, ZeroBased, Range(Range), Shifted(Shifted),
Enumeration(Enumeration))
import Data.Ix (Ix)
import qualified LLVM.Extra.Multi.Value.Memory as MultiMem
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Iterator as IterMV
import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Arithmetic as A
import LLVM.Extra.Multi.Value (atom)
import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM
import Foreign.Storable (Storable)
import Foreign.Ptr (Ptr)
import qualified Data.Enum.Storable as Enum
import Data.Tagged (Tagged)
import Data.Tuple.HT (mapSnd)
import Data.Word (Word8, Word16, Word32, Word64)
import Data.Int (Int8, Int16, Int32, Int64)
import qualified Control.Monad.HT as Monad
import Control.Applicative ((<$>))
import Prelude2010
import Prelude ()
-- | Unsigned integer type in which element counts ('size') and linear
--   offsets ('offset', 'sizeOffset') are expressed in the generated code.
type Size = Word64
-- | Embed a constant (Haskell-side) shape as an expression.
value :: (C sh, Expr.Value val) => sh -> val sh
value sh = Expr.lift0 (MultiValue.cons sh)
-- | Continuation-passing access to a parameter.
--
--   This is a thin wrapper around 'Param.withMulti'.  The continuation
--   receives the parameter extractor unchanged.  The function that turns
--   the stored parameters into a value is lifted into @val@ via
--   'Expr.lift0'.
paramWith ::
   (Storable b, MultiMem.C b, Expr.Value val) =>
   Param.T p b ->
   (forall parameters.
      (Storable parameters, MultiMem.C parameters) =>
      (p -> parameters) ->
      (MultiValue.T parameters -> val b) ->
      a) ->
   a
paramWith p f =
   Param.withMulti p
      (\getParams valueFromParams ->
         f getParams (\params -> Expr.lift0 (valueFromParams params)))
-- | Load a shape through a pointer.
--   The first argument is ignored; it only fixes the shape type @sh@.
load ::
   (MultiMem.C sh) =>
   f sh -> LLVM.Value (Ptr (MultiMem.Struct sh)) ->
   LLVM.CodeGenFunction r (MultiValue.T sh)
load _proxy ptr = MultiMem.load ptr
-- | Intersection of two shape expressions, as defined per instance by
--   'intersectCode'.
intersect :: (C sh) => Exp sh -> Exp sh -> Exp sh
intersect shapeA shapeB = Expr.liftM2 intersectCode shapeA shapeB
-- | Linear offset of an index within a shape.
--   Uses the offset function returned by 'sizeOffset'; the size component
--   is discarded.
offset ::
   (C sh) =>
   MultiValue.T sh -> MultiValue.T (Index sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Size)
offset sh ix = do
   (_total, offsetOf) <- sizeOffset sh
   offsetOf ix
-- | Shapes for which size, offset and index-enumeration code can be
--   generated.
class
   (MultiValue.C sh, MultiValue.C (Index sh), Shape.Indexed sh) => C sh where
   -- | Intersection of two shapes (the exact meaning is defined per instance).
   intersectCode ::
      MultiValue.T sh -> MultiValue.T sh ->
      LLVM.CodeGenFunction r (MultiValue.T sh)
   -- | Number of elements of the shape.
   size :: MultiValue.T sh -> LLVM.CodeGenFunction r (LLVM.Value Size)
   -- | Number of elements together with a function that maps an index
   --   to its linear offset.
   sizeOffset ::
      MultiValue.T sh ->
      LLVM.CodeGenFunction r
         (LLVM.Value Size,
          MultiValue.T (Index sh) -> LLVM.CodeGenFunction r (LLVM.Value Size))
   -- | Enumerate the indices of the shape.
   iterator :: (Index sh ~ ix) => MultiValue.T sh -> Iter.T r (MultiValue.T ix)
   -- | Run a state-threading body for every index of the shape.
   --   The default drives the body with 'iterator'.  Instances may
   --   override it, e.g. with nested loops.
   loop ::
      (Index sh ~ ix, MultiValue.C ix, Loop.Phi state) =>
      (MultiValue.T ix -> state -> LLVM.CodeGenFunction r state) ->
      MultiValue.T sh -> state -> LLVM.CodeGenFunction r state
   loop body = Iter.mapState_ body . iterator
-- | The unit shape: exactly one element, located at offset zero.
instance C () where
   intersectCode _ _ = return (MultiValue.cons ())
   size = const (return A.one)
   sizeOffset _ = return (A.one, const (return A.zero))
   iterator unit = Iter.singleton unit
   -- The only index is the unit shape itself, so run the body once on it.
   loop body = body
-- | Shapes that provide a constant shape value ('scalar') and a constant
--   index ('zeroIndex').  The only instance in this module is for '()'.
class C sh => Scalar sh where
   -- | Constant shape value.
   scalar :: (Expr.Value val) => val sh
   -- | Constant index.  The first argument only fixes the shape type.
   zeroIndex :: (Expr.Value val) => f sh -> val (Index sh)
-- | Both the unit shape and its only index are the unit value.
instance Scalar () where
   zeroIndex _proxy = Expr.lift0 (MultiValue.Cons ())
   scalar = Expr.lift0 (MultiValue.Cons ())
-- | Shapes that can be rebuilt from a value of their index type.
--   For 'ZeroBased' the index value becomes the shape's size.
--   NOTE(review): presumably other instances follow the same
--   "index value = length" convention; confirm.
class
   (C sh,
    MultiValue.IntegerConstant (Index sh),
    MultiValue.Additive (Index sh)) =>
      Sequence sh where
   sequenceShapeFromIndex ::
      MultiValue.T (Index sh) -> LLVM.CodeGenFunction r (MultiValue.T sh)
-- | Integer index types that can be converted to the unsigned 'Size' type.
class
   (MultiValue.Additive n, MultiValue.Real n, MultiValue.IntegerConstant n) =>
      ToSize n where
   toSize :: MultiValue.T n -> LLVM.CodeGenFunction r (LLVM.Value Size)

-- Narrower unsigned types are widened with 'LLVM.ext'.
instance ToSize Word8 where
   toSize (MultiValue.Cons x) = LLVM.ext x
instance ToSize Word16 where
   toSize (MultiValue.Cons x) = LLVM.ext x
instance ToSize Word32 where
   toSize (MultiValue.Cons x) = LLVM.ext x
-- Already the representation of 'Size'.
instance ToSize Word64 where
   toSize (MultiValue.Cons x) = return x
-- Signed types are zero-extended ('Int64' is reinterpreted by a bitcast),
-- so a negative input would become a large 'Size'.
-- NOTE(review): assumes callers only convert non-negative values.
instance ToSize Int8 where
   toSize (MultiValue.Cons x) = LLVM.zext x
instance ToSize Int16 where
   toSize (MultiValue.Cons x) = LLVM.zext x
instance ToSize Int32 where
   toSize (MultiValue.Cons x) = LLVM.zext x
instance ToSize Int64 where
   toSize (MultiValue.Cons x) = LLVM.bitcast x
-- | Shapes @0 .. n-1@.  The index is its own linear offset.
instance
   (Integral n, ToSize n, MultiValue.Comparison n) => C (ZeroBased n) where
   intersectCode shapeA shapeB = do
      len <- MultiValue.min (zeroBasedSize shapeA) (zeroBasedSize shapeB)
      return (zeroBased len)
   size shape = toSize (zeroBasedSize shape)
   sizeOffset shape = do
      len <- toSize (zeroBasedSize shape)
      return (len, toSize)
   iterator shape =
      IterMV.take (zeroBasedSize shape)
         (Iter.iterate MultiValue.inc MultiValue.zero)
-- | The index value is taken as the size of the zero-based shape.
instance
   (Integral n, ToSize n, MultiValue.Comparison n) =>
      Sequence (ZeroBased n) where
   sequenceShapeFromIndex ix = return (zeroBased ix)
-- | Number of elements of the inclusive range @Range from to@,
--   i.e. @to - from + 1@.
--
--   An empty range (@to < from@) has size zero.  Such ranges arise,
--   for example, from intersecting two disjoint ranges.  Without the
--   explicit check the difference would be negative (or wrap around for
--   unsigned index types) and 'toSize' would report a huge element count.
--   This would be inconsistent with the iterator, which yields no index
--   for such a range.
rangeSize ::
   (ToSize n, MultiValue.Comparison n) =>
   Range (MultiValue.T n) -> LLVM.CodeGenFunction r (LLVM.Value Size)
rangeSize (Range from to) = do
   MultiValue.Cons isEmpty <- MultiValue.cmp LLVM.CmpLT to from
   count <- toSize =<< MultiValue.inc =<< MultiValue.sub to from
   LLVM.select isEmpty A.zero count
-- | Lower (inclusive) bound of a range expression.
rangeFrom :: (Expr.Value val) => val (Range n) -> val n
rangeFrom = Expr.lift1 (\rng -> Shape.rangeFrom (unzipRange rng))
-- | Upper (inclusive) bound of a range expression.
rangeTo :: (Expr.Value val) => val (Range n) -> val n
rangeTo = Expr.lift1 (\rng -> Shape.rangeTo (unzipRange rng))
-- | Build a range expression from its lower and upper bound.
range :: (Expr.Value val) => val n -> val n -> val (Range n)
range =
   Expr.lift2
      (\(MultiValue.Cons lower) (MultiValue.Cons upper) ->
         MultiValue.Cons (Range lower upper))
-- | Inclusive ranges @Range from to@.  Offsets are measured from @from@.
instance (Ix n, ToSize n, MultiValue.Comparison n) => C (Range n) where
   intersectCode =
      MultiValue.modifyF2 (singletonRange atom) (singletonRange atom) $
      \(Range lowerA upperA) (Range lowerB upperB) -> do
         lower <- MultiValue.max lowerA lowerB
         upper <- MultiValue.min upperA upperB
         return (Range lower upper)
   size rngValue = rangeSize (unzipRange rngValue)
   sizeOffset rngValue = do
      let bounds = unzipRange rngValue
      total <- rangeSize bounds
      return
         (total,
          \ix -> toSize =<< MultiValue.sub ix (Shape.rangeFrom bounds))
   iterator rngValue =
      case MultiValue.decompose (singletonRange atom) rngValue of
         Range lower upper ->
            -- Counts upwards from the lower bound while @upper >= index@.
            -- NOTE(review): if @upper@ is the largest value of @n@, the
            -- increment presumably wraps around and the condition never
            -- fails; confirm such ranges are excluded.
            IterMV.takeWhile (MultiValue.cmp LLVM.CmpGE upper) $
            Iter.iterate MultiValue.inc lower
-- | Start index of a shifted shape expression.
shiftedOffset :: (Expr.Value val) => val (Shifted n) -> val n
shiftedOffset = Expr.lift1 (\sh -> Shape.shiftedOffset (unzipShifted sh))
-- | Number of elements of a shifted shape expression.
shiftedSize :: (Expr.Value val) => val (Shifted n) -> val n
shiftedSize = Expr.lift1 (\sh -> Shape.shiftedSize (unzipShifted sh))
-- | Build a shifted shape expression from its start index and its length.
shifted :: (Expr.Value val) => val n -> val n -> val (Shifted n)
shifted =
   Expr.lift2
      (\(MultiValue.Cons start) (MultiValue.Cons len) ->
         MultiValue.Cons (Shifted start len))
-- | Shapes @Shifted start len@ covering the indices
--   @start, start+1, ...@ (@len@ of them).  Offsets are measured from
--   @start@.
instance (Integral n, ToSize n, MultiValue.Comparison n) => C (Shifted n) where
   intersectCode =
      MultiValue.modifyF2 (singletonShifted atom) (singletonShifted atom) $
      \(Shifted startN lenN) (Shifted startM lenM) -> do
         start <- MultiValue.max startN startM
         endN <- MultiValue.add startN lenN
         endM <- MultiValue.add startM lenM
         -- Clamp the end to at least the start.  Otherwise disjoint shapes
         -- would produce a negative (or, for unsigned types, wrapped-around)
         -- length instead of an empty intersection.
         end <- MultiValue.max start =<< MultiValue.min endN endM
         Shifted start <$> MultiValue.sub end start
   size = toSize . shiftedSize
   sizeOffset shapeValue =
      case unzipShifted shapeValue of
         Shifted start len ->
            Monad.lift2 (,) (toSize len)
               (return $ \i -> toSize =<< MultiValue.sub i start)
   iterator rngValue =
      case MultiValue.decompose (singletonShifted atom) rngValue of
         Shifted from len ->
            IterMV.take len $ Iter.iterate MultiValue.inc from
-- | Enumerable, bounded element types that can compute the linear offset
--   of a value as a 'Size'.  For 'Enum.T' this offset is the distance
--   from 'MultiValue.minBound'.
class (IterMV.Enum enum, MultiValue.Bounded enum) => EnumBounded enum where
   enumOffset :: MultiValue.T enum -> LLVM.CodeGenFunction r (LLVM.Value Size)
-- | The offset of an enumeration value is the difference between its
--   underlying representation and that of 'MultiValue.minBound'.
instance
   (ToSize w, MultiValue.Additive w,
    LLVM.IsInteger w, SoV.IntegerConstant w, Num w,
    MultiValue.Repr LLVM.Value w ~ LLVM.Value w,
    LLVM.CmpRet w, LLVM.CmpResult w ~ Bool,
    Enum e, Bounded e) =>
      EnumBounded (Enum.T w e) where
   enumOffset ix = do
      let lowest = MultiValue.minBound `asTypeOf` ix
      delta <-
         MultiValue.sub (MultiValue.fromEnum ix) (MultiValue.fromEnum lowest)
      toSize delta
-- | Shapes covering every value of a bounded enumeration type.
instance
   (Enum enum, Bounded enum, EnumBounded enum) => C (Enumeration enum) where
   -- 'Enumeration' carries no data, so the second argument is returned as is.
   intersectCode _ shapeB = return shapeB
   size shape =
      return (A.fromInteger' (toInteger (Shape.size (plainEnumeration shape))))
   sizeOffset shape = (\count -> (count, enumOffset)) <$> size shape
   iterator _ = IterMV.enumFromTo MultiValue.minBound MultiValue.maxBound
-- | Recover the plain 'Enumeration' shape.  The argument is ignored and
--   only fixes the type: 'Enumeration' has a single field-less constructor.
plainEnumeration :: val (Enumeration enum) -> Enumeration enum
plainEnumeration = const Enumeration
-- | Tagged shapes delegate everything to the untagged shape, re-tagging
--   results and untagging indices where necessary.
instance (C sh) => C (Tagged tag sh) where
   intersectCode shapeA shapeB =
      MultiValue.liftTaggedM2 intersectCode shapeA shapeB
   size tagged = size (MultiValue.untag tagged)
   sizeOffset tagged =
      mapSnd (. MultiValue.untag) <$> sizeOffset (MultiValue.untag tagged)
   iterator tagged = MultiValue.tag <$> iterator (MultiValue.untag tagged)
-- | Two-dimensional shapes built from a pair of shapes.
--   The linear offset of index @(i,j)@ is @offset i * size m + offset j@,
--   so the second component varies fastest.
instance (C n, C m) => C (n,m) where
   -- component-wise intersection
   intersectCode a b =
      case (MultiValue.unzip a, MultiValue.unzip b) of
         ((an,am), (bn,bm)) ->
            Monad.lift2 MultiValue.zip
               (intersectCode an bn)
               (intersectCode am bm)
   -- total size is the product of the component sizes
   size nm =
      case MultiValue.unzip nm of
         (n,m) -> Monad.liftJoin2 A.mul (size n) (size m)
   sizeOffset nm =
      case MultiValue.unzip nm of
         (n,m) -> do
            (ns, iOffset) <- sizeOffset n
            (ms, jOffset) <- sizeOffset m
            sz <- A.mul ns ms
            return
               (sz,
                \ij ->
                  case MultiValue.unzip ij of
                     (i,j) -> do
                        il <- iOffset i
                        jl <- jOffset j
                        -- offset = il * ms + jl
                        A.add jl =<< A.mul ms il)
   iterator nm =
      case MultiValue.unzip nm of
         (n,m) ->
            uncurry MultiValue.zip <$>
            Iter.cartesian (iterator n) (iterator m)
   -- nested loops: the first component is the outer loop
   loop code nm =
      case MultiValue.unzip nm of
         (n,m) -> loop (\i -> loop (\j -> code (MultiValue.zip i j)) m) n
-- | Three-dimensional shapes built from a triple of shapes.
--   The linear offset of index @(i,j,k)@ is
--   @(offset i * size m + offset j) * size l + offset k@,
--   so the last component varies fastest.
instance (C n, C m, C l) => C (n,m,l) where
   -- component-wise intersection
   intersectCode a b =
      case (MultiValue.unzip3 a, MultiValue.unzip3 b) of
         ((ai,aj,ak), (bi,bj,bk)) ->
            Monad.lift3 MultiValue.zip3
               (intersectCode ai bi)
               (intersectCode aj bj)
               (intersectCode ak bk)
   -- total size is the product of the component sizes
   size nml =
      case MultiValue.unzip3 nml of
         (n,m,l) ->
            Monad.liftJoin2 A.mul (size n) $
            Monad.liftJoin2 A.mul (size m) (size l)
   sizeOffset nml =
      case MultiValue.unzip3 nml of
         (n,m,l) -> do
            (ns, iOffset) <- sizeOffset n
            (ms, jOffset) <- sizeOffset m
            (ls, kOffset) <- sizeOffset l
            sz <- A.mul ns =<< A.mul ms ls
            return
               (sz,
                \ijk ->
                  case MultiValue.unzip3 ijk of
                     (i,j,k) -> do
                        il <- iOffset i
                        jl <- jOffset j
                        kl <- kOffset k
                        -- offset = (il * ms + jl) * ls + kl
                        A.add kl =<< A.mul ls =<< A.add jl =<< A.mul ms il)
   iterator nml =
      case MultiValue.unzip3 nml of
         (n,m,l) ->
            fmap (\(a,(b,c)) -> MultiValue.zip3 a b c) $
            Iter.cartesian (iterator n) $
            Iter.cartesian (iterator m) (iterator l)
   -- nested loops: first component outermost, last component innermost
   loop code nml =
      case MultiValue.unzip3 nml of
         (n,m,l) ->
            loop (\i -> loop (\j -> loop (\k ->
               code (MultiValue.zip3 i j k))
               l) m) n