module Data.Array.Knead.Index.Nested.Shape (
C(..),
value,
paramWith,
load,
intersect,
flattenIndex,
Range(..),
Shifted(..),
Scalar(..),
) where
import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Parameter as Param
import Data.Array.Knead.Expression (Exp, )
import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import LLVM.Extra.Multi.Value (atom)
import LLVM.Extra.Monad (liftR2)
import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM
import Foreign.Storable (Storable, )
import Foreign.Ptr (Ptr, )
import Data.Word (Word32, Word64)
import Data.Int (Int32, Int64)
import qualified Control.Monad.HT as Monad
import Control.Applicative ((<$>))
-- | Embed a shape that is known on the Haskell side into an expression,
--   converting it with 'MultiValue.cons'.
value :: (C sh, Expr.Value val) => sh -> val sh
value sh = Expr.lift0 (MultiValue.cons sh)
-- | Run a continuation on a shape parameter.
--   The continuation receives the function that extracts the stored
--   parameters from @p@ and a function that turns loaded parameters into a
--   value of the expression type @val@.
paramWith ::
   (Storable b, MultiValueMemory.C b, Expr.Value val) =>
   Param.T p b ->
   (forall parameters.
    (Storable parameters,
     MultiValueMemory.C parameters) =>
    (p -> parameters) ->
    (MultiValue.T parameters -> val b) ->
    a) ->
   a
paramWith p f =
   Param.withMulti p
      (\getParams toValue ->
         f getParams (\params -> Expr.lift0 (toValue params)))
-- | Load a shape from memory.
--   The first argument is only used to fix the shape type @sh@; its value is
--   never inspected.
load ::
   (MultiValueMemory.C sh) =>
   f sh -> LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
   LLVM.CodeGenFunction r (MultiValue.T sh)
load _proxy ptr = MultiValueMemory.load ptr
-- | Expression-level intersection of two shapes, built from 'intersectCode'.
intersect :: (C sh) => Exp sh -> Exp sh -> Exp sh
intersect a b = Expr.liftM2 intersectCode a b
-- | Generate code for the linear position of an index within a shape.
--   This is the second component of 'flattenIndexRec'; the element count
--   it also computes is discarded.
flattenIndex ::
   (C sh) =>
   MultiValue.T sh -> MultiValue.T (Index sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Word32)
flattenIndex sh ix = snd <$> flattenIndexRec sh ix
-- | Array shapes whose elements can be counted, addressed and enumerated by
--   generated LLVM code.
class (MultiValue.C sh) => C sh where
   -- | Type of the indices that address elements of a shape of type @sh@.
   type Index sh :: *
   -- | Generate code that intersects two shapes.
   --   The primitive integer instances in this module take the minimum;
   --   'Range' and 'Shifted' keep the overlapping part.
   intersectCode ::
      MultiValue.T sh -> MultiValue.T sh ->
      LLVM.CodeGenFunction r (MultiValue.T sh)
   -- | Generate code computing the number of elements of a shape.
   sizeCode ::
      MultiValue.T sh ->
      LLVM.CodeGenFunction r (LLVM.Value Word32)
   -- | Number of elements of a shape, computed on the Haskell side.
   size :: sh -> Int
   -- | Generate code returning the pair
   --   (number of elements, linear position of the index).
   --   The tuple instances combine these pairs in row-major order,
   --   with the last component varying fastest.
   flattenIndexRec ::
      MultiValue.T sh -> MultiValue.T (Index sh) ->
      LLVM.CodeGenFunction r (LLVM.Value Word32, LLVM.Value Word32)
   -- | Generate a loop that runs the body once for every index of the
   --   shape, threading a state value through the iterations.
   loop ::
      (Index sh ~ ix, Loop.Phi state) =>
      (MultiValue.T ix -> state -> LLVM.CodeGenFunction r state) ->
      MultiValue.T sh -> state -> LLVM.CodeGenFunction r state
-- The zero-dimensional shape: exactly one element, addressed by '()'.
instance C () where
   type Index () = ()
   size _ = 1
   sizeCode _ = return A.one
   intersectCode _ _ = return (MultiValue.cons ())
   flattenIndexRec _ _ = return (A.one, A.zero)
   -- A single iteration; the shape value '()' doubles as the only index.
   loop body = body
-- | Shapes that come with a fixed shape value and a fixed index into it.
class C sh => Scalar sh where
   -- | A shape value that needs no run-time input;
   --   for '()' it is the unit value.
   scalar :: (Expr.Value val) => val sh
   -- | An index into the shape. Presumably the first one; for '()' it is
   --   the only index, '()'. TODO confirm for other instances.
   --   The @f sh@ argument is fully polymorphic in @f@, so it can only
   --   serve to fix the shape type.
   zeroIndex :: (Expr.Value val) => f sh -> val (Index sh)
-- Both the shape and its only index are the unit value.
instance Scalar () where
   zeroIndex _proxy = Expr.lift0 (MultiValue.Cons ())
   scalar = Expr.lift0 (MultiValue.Cons ())
-- | Loop over the indices @0, 1, ..., n-1@ of a primitive integer extent @n@,
--   threading a state through 'loopStart'.
loopPrimitive ::
   (MultiValue.Repr LLVM.Value j ~ LLVM.Value j,
    Num j, LLVM.IsConst j, LLVM.IsInteger j,
    LLVM.CmpRet j, LLVM.CmpResult j ~ Bool,
    MultiValue.Additive i, MultiValue.IntegerConstant i,
    Loop.Phi state) =>
   (MultiValue.T i -> state -> LLVM.CodeGenFunction r state) ->
   MultiValue.T j -> state -> LLVM.CodeGenFunction r state
loopPrimitive code (MultiValue.Cons count) =
   loopStart code count MultiValue.zero
-- | Run @code@ exactly @n@ times.
--   The index starts at @start@ and is incremented by one after every
--   iteration. Only the final state is returned; the final index is dropped.
loopStart ::
   (Num j, LLVM.IsConst j, LLVM.IsInteger j,
    LLVM.CmpRet j, LLVM.CmpResult j ~ Bool,
    MultiValue.Additive i, MultiValue.IntegerConstant i,
    Loop.Phi state) =>
   (MultiValue.T i -> state -> LLVM.CodeGenFunction r state) ->
   LLVM.Value j ->
   MultiValue.T i -> state -> LLVM.CodeGenFunction r state
loopStart code n start ptrStart =
   fst <$> C.fixedLengthLoop n (ptrStart, start) step
   where
      -- One iteration: run the body, then advance the index.
      step (state, k) = do
         state' <- code k state
         k' <- MultiValue.add k (MultiValue.fromInteger' 1)
         return (state', k')
-- A plain 'Word32' extent @n@: indices run from 0 to @n-1@.
-- Sizes and positions need no conversion.
instance C Word32 where
   type Index Word32 = Word32
   size n = fromIntegral n
   intersectCode a b = MultiValue.min a b
   sizeCode (MultiValue.Cons n) = return n
   flattenIndexRec (MultiValue.Cons n) (MultiValue.Cons i) = return (n, i)
   loop code = loopPrimitive code
-- A 'Word64' extent; sizes and positions are truncated to 'Word32'.
instance C Word64 where
   type Index Word64 = Word64
   size n = fromIntegral n
   intersectCode a b = MultiValue.min a b
   sizeCode (MultiValue.Cons n) = LLVM.trunc n
   flattenIndexRec (MultiValue.Cons n) (MultiValue.Cons i) = do
      n32 <- LLVM.trunc n
      i32 <- LLVM.trunc i
      return (n32, i32)
   loop code = loopPrimitive code
-- An 'Int32' extent; sizes and positions are bit-cast to 'Word32'.
instance C Int32 where
   type Index Int32 = Int32
   size n = fromIntegral n
   intersectCode a b = MultiValue.min a b
   sizeCode (MultiValue.Cons n) = LLVM.bitcast n
   flattenIndexRec (MultiValue.Cons n) (MultiValue.Cons i) = do
      n32 <- LLVM.bitcast n
      i32 <- LLVM.bitcast i
      return (n32, i32)
   loop code = loopPrimitive code
-- An 'Int64' extent; sizes and positions are truncated to 'Word32'.
instance C Int64 where
   type Index Int64 = Int64
   size n = fromIntegral n
   intersectCode a b = MultiValue.min a b
   sizeCode (MultiValue.Cons n) = LLVM.trunc n
   flattenIndexRec (MultiValue.Cons n) (MultiValue.Cons i) = do
      n32 <- LLVM.trunc n
      i32 <- LLVM.trunc i
      return (n32, i32)
   loop code = loopPrimitive code
-- | Index range given by a lower bound and an upper bound, both inclusive
--   (the element count is @to - from + 1@, see 'rangeSize').
data Range n = Range n n
-- | A 'Range' whose two bounds are the same value.
--   Used below to build patterns, undefined values and zero values.
singletonRange :: n -> Range n
singletonRange x = Range x x
-- | Index element types whose generated values can be converted to the
--   'Word32' that this module uses for sizes and flattened positions.
class
   (MultiValue.Additive n, MultiValue.Real n, MultiValue.IntegerConstant n) =>
   ToSize n where
   -- | Generate code converting a value to 'Word32'.
   --   NOTE(review): the instances in this module convert without any range
   --   check, so out-of-range values are not detected.
   toSize :: MultiValue.T n -> LLVM.CodeGenFunction r (LLVM.Value Word32)
-- Unsigned types go through 'LLVM.adapt', 'Int32' is reinterpreted bit by
-- bit, and 'Int64' is truncated.
instance ToSize Word32 where
   toSize x = case x of MultiValue.Cons w -> LLVM.adapt w
instance ToSize Word64 where
   toSize x = case x of MultiValue.Cons w -> LLVM.adapt w
instance ToSize Int32 where
   toSize x = case x of MultiValue.Cons i -> LLVM.bitcast i
instance ToSize Int64 where
   toSize x = case x of MultiValue.Cons i -> LLVM.trunc i
-- | Generate code for the element count @to - from + 1@ of an inclusive
--   range, converted to 'Word32' by 'toSize'. No clamping is done when
--   @to < from@.
rangeSize ::
   (ToSize n) =>
   Range (MultiValue.T n) -> LLVM.CodeGenFunction r (LLVM.Value Word32)
rangeSize (Range from to) = do
   diff <- MultiValue.sub to from
   count <- MultiValue.add (MultiValue.fromInteger' 1) diff
   toSize count
-- A generated 'Range' is represented as a 'Range' of the bounds'
-- representations. Every operation is applied to both bounds,
-- lower bound first.
instance (MultiValue.C n) => MultiValue.C (Range n) where
   type Repr f (Range n) = Range (MultiValue.Repr f n)
   cons (Range lower upper) =
      MultiValue.compose (Range (MultiValue.cons lower) (MultiValue.cons upper))
   undef = MultiValue.compose (singletonRange MultiValue.undef)
   zero = MultiValue.compose (singletonRange MultiValue.zero)
   phis bb rng =
      case MultiValue.decompose (singletonRange atom) rng of
         Range lower upper -> do
            lowerPhi <- MultiValue.phis bb lower
            upperPhi <- MultiValue.phis bb upper
            return (MultiValue.compose (Range lowerPhi upperPhi))
   addPhis bb a b =
      case MultiValue.decompose (singletonRange atom) a of
         Range lowerA upperA ->
            case MultiValue.decompose (singletonRange atom) b of
               Range lowerB upperB -> do
                  MultiValue.addPhis bb lowerA lowerB
                  MultiValue.addPhis bb upperA upperB
-- Decomposing with a 'Range' pattern yields a 'Range' of the decomposed bounds.
type instance
   MultiValue.Decomposed f (Range pn) =
      Range (MultiValue.Decomposed f pn)
-- The value type matched by a 'Range' pattern is a 'Range' of the types
-- matched by the bound patterns.
type instance
   MultiValue.PatternTuple (Range pn) =
      Range (MultiValue.PatternTuple pn)
-- Compose each bound separately, then wrap both representations in one 'Range'.
instance (MultiValue.Compose n) => MultiValue.Compose (Range n) where
   type Composed (Range n) = Range (MultiValue.Composed n)
   compose (Range lower upper) =
      case MultiValue.compose lower of
         MultiValue.Cons lo ->
            case MultiValue.compose upper of
               MultiValue.Cons hi -> MultiValue.Cons (Range lo hi)
-- Split a generated 'Range' and decompose each bound with its own pattern.
instance (MultiValue.Decompose pn) => MultiValue.Decompose (Range pn) where
   decompose (Range pLower pUpper) (MultiValue.Cons (Range lower upper)) =
      Range
         (MultiValue.decompose pLower (MultiValue.Cons lower))
         (MultiValue.decompose pUpper (MultiValue.Cons upper))
-- An inclusive index range. Indices are the values @from .. to@ themselves;
-- positions are counted relative to @from@.
instance (Integral n, ToSize n) => C (Range n) where
   type Index (Range n) = n
   -- Overlap of two ranges: the larger lower bound and the smaller upper bound.
   -- NOTE(review): for disjoint ranges this yields from > to, and
   -- 'rangeSize' does not clamp that case.
   intersectCode =
      MultiValue.modifyF2 (singletonRange atom) (singletonRange atom) $
      \(Range fromN toN) (Range fromM toM) ->
         Monad.lift2 Range (MultiValue.max fromN fromM) (MultiValue.min toN toM)
   sizeCode = rangeSize . MultiValue.decompose (singletonRange atom)
   -- Both bounds are inclusive, hence the +1. This is the same formula that
   -- 'rangeSize' generates code for, so 'size' and 'sizeCode' agree.
   size (Range from to) = fromIntegral (to - from + 1)
   flattenIndexRec rngValue i =
      case MultiValue.decompose (singletonRange atom) rngValue of
         rng@(Range from _to) ->
            Monad.lift2 (,) (rangeSize rng) (toSize =<< MultiValue.sub i from)
   -- Iterate 'rangeSize' times, starting with the index @from@.
   loop code rngValue ptrStart =
      case MultiValue.decompose (singletonRange atom) rngValue of
         rng@(Range from _to) -> do
            n <- rangeSize rng
            loopStart code n from ptrStart
-- | Index range given by its first index and its number of elements.
--   The valid indices run from @shiftedOffset@ up to, but excluding,
--   @shiftedOffset + shiftedSize@.
data Shifted n = Shifted {shiftedOffset, shiftedSize :: n}
-- | A 'Shifted' whose offset and size are the same value.
--   Used below to build patterns, undefined values and zero values.
singletonShifted :: n -> Shifted n
singletonShifted x = Shifted x x
-- A generated 'Shifted' is represented as a 'Shifted' of the components'
-- representations. Every operation is applied to the offset first,
-- then to the size.
instance (MultiValue.C n) => MultiValue.C (Shifted n) where
   type Repr f (Shifted n) = Shifted (MultiValue.Repr f n)
   cons (Shifted offset len) =
      MultiValue.compose (Shifted (MultiValue.cons offset) (MultiValue.cons len))
   undef = MultiValue.compose (singletonShifted MultiValue.undef)
   zero = MultiValue.compose (singletonShifted MultiValue.zero)
   phis bb shape =
      case MultiValue.decompose (singletonShifted atom) shape of
         Shifted offset len -> do
            offsetPhi <- MultiValue.phis bb offset
            lenPhi <- MultiValue.phis bb len
            return (MultiValue.compose (Shifted offsetPhi lenPhi))
   addPhis bb a b =
      case MultiValue.decompose (singletonShifted atom) a of
         Shifted offsetA lenA ->
            case MultiValue.decompose (singletonShifted atom) b of
               Shifted offsetB lenB -> do
                  MultiValue.addPhis bb offsetA offsetB
                  MultiValue.addPhis bb lenA lenB
-- Decomposing with a 'Shifted' pattern yields a 'Shifted' of the decomposed
-- offset and size.
type instance
   MultiValue.Decomposed f (Shifted pn) =
      Shifted (MultiValue.Decomposed f pn)
-- The value type matched by a 'Shifted' pattern is a 'Shifted' of the types
-- matched by the component patterns.
type instance
   MultiValue.PatternTuple (Shifted pn) =
      Shifted (MultiValue.PatternTuple pn)
-- Compose offset and size separately, then wrap both representations in one
-- 'Shifted'.
instance (MultiValue.Compose n) => MultiValue.Compose (Shifted n) where
   type Composed (Shifted n) = Shifted (MultiValue.Composed n)
   compose (Shifted offset len) =
      case MultiValue.compose offset of
         MultiValue.Cons o ->
            case MultiValue.compose len of
               MultiValue.Cons l -> MultiValue.Cons (Shifted o l)
-- Split a generated 'Shifted' and decompose each component with its own
-- pattern.
instance (MultiValue.Decompose pn) => MultiValue.Decompose (Shifted pn) where
   decompose (Shifted pOffset pLen) (MultiValue.Cons (Shifted offset len)) =
      Shifted
         (MultiValue.decompose pOffset (MultiValue.Cons offset))
         (MultiValue.decompose pLen (MultiValue.Cons len))
-- A half-open index range given by offset and length. Positions are counted
-- relative to the offset.
instance (Integral n, ToSize n) => C (Shifted n) where
   type Index (Shifted n) = n
   -- Overlap: start at the larger offset, stop at the smaller end point,
   -- where end = offset + length.
   intersectCode =
      MultiValue.modifyF2 (singletonShifted atom) (singletonShifted atom) $
      \(Shifted offsetN lenN) (Shifted offsetM lenM) -> do
         start <- MultiValue.max offsetN offsetM
         stopN <- MultiValue.add offsetN lenN
         stopM <- MultiValue.add offsetM lenM
         stop <- MultiValue.min stopN stopM
         count <- MultiValue.sub stop start
         return (Shifted start count)
   sizeCode shape =
      toSize (shiftedSize (MultiValue.decompose (singletonShifted atom) shape))
   size = fromIntegral . shiftedSize
   flattenIndexRec shapeValue i =
      case MultiValue.decompose (singletonShifted atom) shapeValue of
         Shifted offset len -> do
            count <- toSize len
            relative <- MultiValue.sub i offset
            pos <- toSize relative
            return (count, pos)
   -- Iterate @len@ times, starting with the index @offset@.
   loop code shapeValue initState =
      case MultiValue.decompose (singletonShifted atom) shapeValue of
         Shifted start len -> do
            count <- toSize len
            loopStart code count start initState
-- Two-dimensional shapes in row-major order: the second component varies
-- fastest in both 'flattenIndexRec' and 'loop'.
instance (C n, C m) => C (n,m) where
   type Index (n,m) = (Index n, Index m)
   intersectCode a b =
      case (MultiValue.unzip a, MultiValue.unzip b) of
         ((an,am), (bn,bm)) -> do
            cn <- intersectCode an bn
            cm <- intersectCode am bm
            return (MultiValue.zip cn cm)
   sizeCode nm =
      case MultiValue.unzip nm of
         (n,m) -> do
            sizeN <- sizeCode n
            sizeM <- sizeCode m
            A.mul sizeN sizeM
   size (shN, shM) = size shN * size shM
   flattenIndexRec nm ij =
      case (MultiValue.unzip nm, MultiValue.unzip ij) of
         ((n,m), (i,j)) -> do
            (sizeN, posI) <- flattenIndexRec n i
            (sizeM, posJ) <- flattenIndexRec m j
            total <- A.mul sizeN sizeM
            rowStart <- A.mul sizeM posI
            pos <- A.add posJ rowStart
            return (total, pos)
   -- The outer loop runs over the first component, the inner loop over the
   -- second.
   loop code nm =
      case MultiValue.unzip nm of
         (n,m) ->
            let rowBody i = loop (\j -> code (MultiValue.zip i j)) m
            in  loop rowBody n
-- Three-dimensional shapes in row-major order: the last component varies
-- fastest in both 'flattenIndexRec' and 'loop'.
instance (C n, C m, C l) => C (n,m,l) where
   type Index (n,m,l) = (Index n, Index m, Index l)
   intersectCode a b =
      case (MultiValue.unzip3 a, MultiValue.unzip3 b) of
         ((ai,aj,ak), (bi,bj,bk)) -> do
            ci <- intersectCode ai bi
            cj <- intersectCode aj bj
            ck <- intersectCode ak bk
            return (MultiValue.zip3 ci cj ck)
   sizeCode nml =
      case MultiValue.unzip3 nml of
         (n,m,l) -> do
            sizeN <- sizeCode n
            sizeM <- sizeCode m
            sizeL <- sizeCode l
            sizeML <- A.mul sizeM sizeL
            A.mul sizeN sizeML
   size (shN, shM, shL) = size shN * size shM * size shL
   flattenIndexRec nml ijk =
      case (MultiValue.unzip3 nml, MultiValue.unzip3 ijk) of
         ((n,m,l), (i,j,k)) -> do
            (sizeI, posI) <- flattenIndexRec n i
            (sizeJ, posJ) <- flattenIndexRec m j
            stepJ <- A.mul sizeJ posI
            posIJ <- A.add posJ stepJ
            (sizeK, posK) <- flattenIndexRec l k
            stepK <- A.mul sizeK posIJ
            posIJK <- A.add posK stepK
            sizeJK <- A.mul sizeJ sizeK
            total <- A.mul sizeI sizeJK
            return (total, posIJK)
   -- Three nested loops: first component outermost, last component innermost.
   loop code nml =
      case MultiValue.unzip3 nml of
         (n,m,l) ->
            let planeBody i = loop (rowBody i) m
                rowBody i j = loop (\k -> code (MultiValue.zip3 i j k)) l
            in  loop planeBody n