-- | Shapes and indices of statically known rank.
-- They are built as snoc-lists from 'Z' (rank zero)
-- and ':.' (one more dimension at the end).
-- The module also provides the code generators
-- for size, offset, iteration and memory access of these structures.
module Data.Array.Knead.Shape.Cubic (
C(switch),
switchInt,
intersect,
value,
constant,
paramWith,
tunnel,
offsetCode,
peek,
poke,
computeSize,
Struct,
T(..),
Z(Z), z,
(:.)((:.)),
Shape, shape,
Index, index,
cons, (#:.),
head,
tail,
switchR,
loadMultiValue,
storeMultiValue,
) where
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Shape.Cubic.Int as Index
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )
import qualified Data.Array.Comfort.Shape as ComfortShape
import Data.Array.Comfort.Shape (ZeroBased(ZeroBased))
import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Iterator as IterMV
import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import LLVM.Extra.Multi.Value (Atom, )
import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM
import qualified Foreign.Storable as St
import Foreign.Storable.FixedArray (sizeOfArray, )
import Foreign.Marshal.Array (advancePtr, )
import Foreign.Ptr (Ptr, castPtr, )
import Control.Monad (liftM2, )
import Prelude hiding (min, head, tail, )
{- |
Class of structures that are built from 'Z' and ':.'.

'switch' is a case analysis on the type @ix@:
The first argument is the result for 'Z',
the second argument is the result for any @sh0 :. k@
where @sh0@ is again an instance of 'C'
and @k@ is an instance of 'Index.Single'.
-}
class C ix where
   switch ::
      f Z ->
      (forall sh0 k. (C sh0, Index.Single k) => f (sh0 :. k)) ->
      f ix

-- | Rank zero: select the first alternative.
instance C Z where
   switch onZ _onCons = onZ

-- | One more dimension: select the second alternative.
instance (C ix0, Index.Single i) => C (ix0 :. i) where
   switch _onZ onCons = onCons
-- | Wrapper that makes the type @i@ of the last dimension
-- the final type parameter, as needed for 'Index.switchSingle'.
newtype SwitchInt f ix i = SwitchInt {runSwitchInt :: f (ix :. i)}

{- |
Variant of 'switch' where the alternative for @ix0 :. i@
needs to be given only for @i ~ Index.Int@.
'Index.switchSingle' is used to obtain the alternative
for the 'Index.Single' type that 'switch' requests.
-}
switchInt ::
   (C ix) =>
   f Z ->
   (forall ix0. (C ix0) => f (ix0 :. Index.Int)) ->
   f ix
switchInt onZ onCons =
   switch onZ (runSwitchInt (Index.switchSingle (SwitchInt onCons)))
-- | Binary operation on expressions of tagged structures,
-- wrapped such that it can be defined via 'switchInt'.
newtype Op2 tag sh =
   Op2 {runOp2 :: Exp (T tag sh) -> Exp (T tag sh) -> Exp (T tag sh)}

{- |
Intersection of two shapes:
Every dimension of the result is the 'Expr.min'
of the corresponding dimensions of both arguments.
-}
intersect :: C sh => Exp (Shape sh) -> Exp (Shape sh) -> Exp (Shape sh)
intersect =
   runOp2
      (switchInt
         (Op2 (\unit _ -> unit))
         (Op2 (\shA shB ->
            intersect (tail shA) (tail shB)
            #:.
            Expr.min (head shA) (head shB))))
-- | Turn a Haskell value into a constant expression
-- using 'MultiValue.cons'.
-- The leading underscore suppresses the unused-binding warning.
_value :: (C sh, MultiValue.C sh) => sh -> Exp sh
_value sh = Expr.lift0 (MultiValue.cons sh)
-- | Conversion from a Haskell structure to a code value,
-- wrapped such that it can be defined via 'switchInt'.
newtype MakeValue val tag sh =
   MakeValue {runMakeValue :: T tag sh -> val (T tag sh)}

{- |
Convert a Haskell shape or index to a value of an 'Expr.Value' type.
The structure is traversed dimension by dimension
and every dimension is converted with 'MultiValue.cons'.
-}
value :: (C sh, Expr.Value val) => T tag sh -> val (T tag sh)
value =
   runMakeValue
      (switchInt
         (MakeValue (\(Cons Z) -> z))
         (MakeValue (\(Cons (rest :. dim)) ->
            cons (value (Cons rest)) (Expr.lift0 (MultiValue.cons dim)))))
{- |
Open the 'Param.Tunnel' for a shape parameter
and hand its two components to the continuation:
the function that extracts the storable parameter data from @p@
and the function that turns the 'MultiValue.T' of that data
into a shape value of the requested 'Expr.Value' type.
-}
paramWith ::
   (C sh, Expr.Value val) =>
   Param.T p (T tag sh) ->
   (forall parameters.
    (St.Storable parameters,
     MultiValueMemory.C parameters) =>
    (p -> parameters) ->
    (MultiValue.T parameters -> val (T tag sh)) ->
    a) ->
   a
paramWith param continue =
   case tunnel param of
      Param.Tunnel extract build ->
         continue extract (\stored -> Expr.lift0 (build stored))
{- |
Build a 'Param.Tunnel' for a shape parameter using 'Param.tunnel'.
Matching on 'StructFieldsProp' brings
the @LLVM.StructFields (Struct sh)@ instance into scope,
that is required by the 'MultiValueMemory.C' instance of @T tag sh@.
-}
tunnel :: (C sh) => Param.T p (T tag sh) -> Param.Tunnel p (T tag sh)
tunnel param =
   case structFieldsPropF param of
      StructFieldsProp -> Param.tunnel value param
-- | Evidence that @Struct sh@ is an instance of 'LLVM.StructFields'.
-- Pattern matching on the constructor makes the instance available.
data StructFieldsProp sh = LLVM.StructFields (Struct sh) => StructFieldsProp

-- | Evidence for the @sh@ found below one type constructor.
-- The argument serves only as type proxy.
_structFieldsProp :: (C sh) => f sh -> StructFieldsProp sh
_structFieldsProp = const structFieldsRec

-- | Evidence for the @sh@ found below two type constructors.
-- The argument serves only as type proxy.
structFieldsPropF :: (C sh) => f (g sh) -> StructFieldsProp sh
structFieldsPropF = const structFieldsRec

-- | Pass the evidence to a function
-- whose result type mentions @sh@ below three type constructors.
withStructFieldsPropFF ::
   (C sh) => (StructFieldsProp sh -> f (g (h sh))) -> f (g (h sh))
withStructFieldsPropFF continue = continue structFieldsRec
-- | Construct the 'StructFieldsProp' evidence
-- by recursion over the structure of @sh@.
structFieldsRec :: (C sh) => StructFieldsProp sh
structFieldsRec =
   switchInt StructFieldsProp (succStructFieldsProp structFieldsRec)

-- | Extend the evidence by one 'Index.Int' dimension.
succStructFieldsProp ::
   StructFieldsProp sh -> StructFieldsProp (sh:.Index.Int)
succStructFieldsProp prop =
   case prop of
      StructFieldsProp -> StructFieldsProp
-- | Structure of rank zero. It terminates every shape and index.
data Z = Z
deriving (Eq, Ord, Read, Show)
-- Both the constructor ':.' and the value-level '#:.' associate to the left,
-- such that @Z :. a :. b@ means @(Z :. a) :. b@.
infixl 3 :., #:.
-- | Append one dimension @head@ to the structure @tail@.
-- Both fields are strict.
data tail :. head = !tail :. !head
deriving (Eq, Ord, Read, Show)
-- | A structure @sh@ together with a phantom @tag@.
-- The tag does not occur in the representation.
newtype T tag sh = Cons {decons :: sh}
-- | Tag used by the 'Shape' synonym.
data ShapeTag
-- | Tag used by the 'Index' synonym.
data IndexTag
-- | Structure tagged as shape.
type Shape = T ShapeTag
-- | Structure tagged as index.
type Index = T IndexTag
-- | Tag a structure as 'Shape'.
shape :: sh -> Shape sh
shape sh = Cons sh
-- | Tag a structure as 'Index'.
index :: ix -> Index ix
index ix = Cons ix
-- | Infix synonym of 'cons'.
(#:.) :: (Expr.Value val) => val (T tag sh) -> val i -> val (T tag (sh:.i))
sh #:. i = cons sh i
-- | Append one dimension to a structure at the value level.
-- The representations of both parts are combined to a pair.
cons :: (Expr.Value val) => val (T tag sh) -> val i -> val (T tag (sh:.i))
cons =
   Expr.lift2
      (\(MultiValue.Cons rest) (MultiValue.Cons dim) ->
         MultiValue.Cons (rest, dim))
-- | The structure of rank zero at the value level.
-- Its representation is the unit value.
z :: (Expr.Value val) => val (T tag Z)
z = Expr.lift0 (MultiValue.Cons ())
-- | Extract the last dimension,
-- that is the second component of the represented pair.
head :: (Expr.Value val) => val (T tag (sh:.i)) -> val i
head = Expr.lift1 (\(MultiValue.Cons (_, dim)) -> MultiValue.Cons dim)
-- | Remove the last dimension,
-- that is keep the first component of the represented pair.
tail :: (Expr.Value val) => val (T tag (sh:.i)) -> val (T tag sh)
tail = Expr.lift1 (\(MultiValue.Cons (rest, _)) -> MultiValue.Cons rest)
-- | Split a non-empty structure into its 'tail' and its 'head'
-- and pass both to the continuation, in this order.
switchR ::
   Expr.Value val =>
   (val (T tag sh) -> val i -> a) -> val (T tag (sh :. i)) -> a
switchR continue whole = continue (tail whole) (head whole)
-- | The shape of rank zero is the scalar shape.
-- Both the shape and its only index are represented by 'z'.
instance (tag ~ ShapeTag, sh ~ Z) => Shape.Scalar (T tag sh) where
   scalar = z
   zeroIndex _ = z
-- | Structure type that is described by a decomposition pattern.
type family PatternTuple pattern
-- | Result type of decomposing a value according to a pattern.
type family Decomposed (f :: * -> *) tag pattern
-- A pattern @sh:.s@ is handled componentwise:
-- the leading part by the families of this module,
-- the last dimension by the families from "LLVM.Extra.Multi.Value".
type instance PatternTuple (sh:.s) =
PatternTuple sh :. MultiValue.PatternTuple s
type instance Decomposed f tag (sh:.s) =
Decomposed f tag sh :. MultiValue.Decomposed f s
-- An 'Atom' pattern stands for a structure that is not decomposed further.
type instance PatternTuple (Atom sh) = sh
type instance Decomposed f tag (Atom sh) = f (T tag sh)
{- |
Decomposition of a structure expression according to a pattern.
The superclass constraint asserts
that composing the decomposed parts yields the original type.
-}
class
   (Expr.Composed (Decomposed Exp tag pattern) ~ T tag (PatternTuple pattern)) =>
      Decompose tag pattern where
   decompose ::
      T tag pattern -> Exp (T tag (PatternTuple pattern)) ->
      Decomposed Exp tag pattern

-- | An atom keeps the expression as it is.
instance Decompose tag (Atom sh) where
   decompose (Cons _) expr = expr

-- | Decompose the leading dimensions with 'decompose'
-- and the last dimension with 'Expr.decompose'.
instance (Decompose tag sh, Expr.Decompose s) => Decompose tag (sh :. s) where
   decompose (Cons (shPat :. dimPat)) expr =
      decompose (Cons shPat) (tail expr) :. Expr.decompose dimPat (head expr)
-- For tagged structures the pattern families of "LLVM.Extra.Multi.Value"
-- are defined in terms of the families of this module.
type instance MultiValue.PatternTuple (T tag sh) = T tag (PatternTuple sh)
type instance MultiValue.Decomposed f (T tag sh) = Decomposed f tag sh
-- | Extract the structure type from a tagged structure type.
type family Unwrap sh
type instance Unwrap (T tag sh) = sh
-- | Extract the tag from a tagged structure type.
type family Tag sh
type instance Tag (T tag sh) = tag
{- |
Compose a structure expression
from an already composable leading part and a last dimension.
The equality constraint asserts
that the leading part composes to a tagged structure.
The tag of the result is the one of the leading part.
-}
instance
   (Expr.Compose sh,
    Expr.Composed sh ~ T (Tag (Expr.Composed sh)) (Unwrap (Expr.Composed sh)),
    Expr.Compose s) =>
      Expr.Compose (sh :. s) where
   type Composed (sh :. s) =
      T (Tag (Expr.Composed sh))
        (Unwrap (Expr.Composed sh) :. Expr.Composed s)
   compose (shPart :. dimPart) =
      Expr.compose shPart #:. Expr.compose dimPart
-- | Decomposition of tagged structures is delegated
-- to the 'Decompose' class of this module.
instance (Decompose tag sh) => Expr.Decompose (T tag sh) where
   decompose pat expr = decompose pat expr
{- |
A structure is stored as an array of 'Shape.Size' values,
one for each dimension.
'sizeOf' and 'alignment' do not inspect their arguments:
'Cons' is a newtype constructor and 'rank' is lazy.
-}
instance (C sh) => St.Storable (T tag sh) where
   sizeOf wrapped = sizeOfArray (rank (decons wrapped)) (0 :: Shape.Size)
   alignment _ = St.alignment (0 :: Shape.Size)
   poke ptr wrapped = poke (castPtr ptr) (decons wrapped)
   peek ptr = fmap Cons (peek (castPtr ptr))
-- | Representation of a structure as nested pairs:
-- unit for 'Z' and for @tail :. head@ a pair
-- of the representation of @tail@ and the one of @head@.
type family Repr (f :: * -> *) sh
type instance Repr f Z = ()
type instance Repr f (tail :. head) = (Repr f tail, MultiValue.Repr f head)
-- | All methods are delegated to the recursive functions of this module.
-- 'undef' and 'zero' fill every dimension with the respective value
-- by means of 'constant'.
instance (C sh) => MultiValue.C (T tag sh) where
   type Repr f (T tag sh) = Repr f sh
   cons sh = value sh
   undef = constant MultiValue.undef
   zero = constant MultiValue.zero
   addPhis bb a b = addPhis bb a b
   phis bb a = phis bb a
-- | The size is computed by 'size' and converted with 'fromIntegral'.
instance (tag ~ ShapeTag, C sh) => ComfortShape.C (T tag sh) where
   size (Cons sh) = fromIntegral (size sh)
-- | Indices of a shape have the same structure as the shape
-- but carry the index tag.
-- The methods unwrap their arguments
-- and call the functions of the same name from this module.
instance (tag ~ ShapeTag, C sh) => ComfortShape.Indexed (T tag sh) where
   type Index (T tag sh) = Index sh
   indices sh = [index ix | ix <- indices (decons sh)]
   inBounds sh ix = inBounds (decons sh) (decons ix)
   offset sh ix = offset (decons sh) (decons ix)
-- | Index enumeration, wrapped such that it can be defined via 'switchInt'.
newtype Indices sh = Indices {runIndices :: sh -> [sh]}

{- |
List all indices of a shape.
'Z' has the single index 'Z'.
For more dimensions every index of the leading dimensions
is combined with every index of @ZeroBased n@,
where the last dimension varies fastest.
-}
indices :: (C sh) => sh -> [sh]
indices =
   runIndices
      (switchInt
         (Indices (\Z -> [Z]))
         (Indices (\(rest :. Index.Int n) ->
            [ ix :. Index.Int k
            | ix <- indices rest
            , k <- ComfortShape.indices (ZeroBased n)
            ])))
-- | Bounds check, wrapped such that it can be defined via 'switchInt'.
newtype InBounds sh = InBounds {runInBounds :: sh -> sh -> Bool}

{- |
Check whether an index (second argument) lies within a shape (first argument).
The leading dimensions are checked first,
the last dimension is checked against @ZeroBased size@.
-}
inBounds :: (C sh) => sh -> sh -> Bool
inBounds =
   runInBounds
      (switchInt
         (InBounds (\Z Z -> True))
         (InBounds (\(shRest :. Index.Int dimSize) (ixRest :. Index.Int dimIx) ->
            inBounds shRest ixRest
            &&
            ComfortShape.inBounds (ZeroBased dimSize) dimIx)))
-- | Offset computation, wrapped such that it can be defined via 'switchInt'.
newtype Offset sh = Offset {runOffset :: sh -> sh -> Int}

{- |
Linear offset of an index (second argument)
within a shape (first argument).
The offset of the leading dimensions is scaled by the size
of the last dimension and the last index component is added,
i.e. the last dimension varies fastest.
-}
offset :: (C sh) => sh -> sh -> Int
offset =
   runOffset
      (switchInt
         (Offset (\Z Z -> 0))
         (Offset (\(shRest :. Index.Int dimSize) (ixRest :. Index.Int dimIx) ->
            fromIntegral dimIx + fromIntegral dimSize * offset shRest ixRest)))
-- | Code generation for shapes.
-- 'sizeOffset' computes the size once
-- and returns it together with the offset function for that shape.
instance (tag ~ ShapeTag, C sh) => Shape.C (T tag sh) where
   size sh = computeSize sh
   intersectCode = Expr.unliftM2 intersect
   sizeOffset sh = do
      n <- computeSize sh
      return (n, offsetCode sh)
   iterator = iterator
   loop = loop
-- | Field list of the LLVM struct that stores a structure:
-- one 'Shape.Size' per 'Index.Int' dimension,
-- where the field of the last dimension comes first.
type family Struct sh
type instance Struct Z = ()
type instance Struct (sh :. Index.Int) = (Shape.Size, Struct sh)
-- | A structure is stored in memory as an LLVM struct
-- with the fields given by 'Struct'.
instance
   (C sh, LLVM.StructFields (Struct sh)) =>
      MultiValueMemory.C (T tag sh) where
   type Struct (T tag sh) = LLVM.Struct (Struct sh)
   load ptr = loadMultiValue ptr
   store sh ptr = storeMultiValue sh ptr
{- |
Load a structure from a pointer to its LLVM struct.
The pointer is cast to a pointer to 'Shape.Size'
and the dimensions are read one after another by 'load'.
-}
loadMultiValue ::
   (C sh) =>
   LLVM.Value (Ptr (LLVM.Struct (Struct sh))) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
loadMultiValue structPtr =
   withStructFieldsPropFF $ \StructFieldsProp ->
      castPtrValue structPtr >>= load
{- |
Store a structure at a pointer to its LLVM struct.
The pointer is cast to a pointer to 'Shape.Size'
and the dimensions are written one after another by 'store'.
-}
storeMultiValue ::
   (C sh) =>
   MultiValue.T (T tag sh) ->
   LLVM.Value (Ptr (LLVM.Struct (Struct sh))) -> LLVM.CodeGenFunction r ()
storeMultiValue sh structPtr =
   case structFieldsPropF sh of
      StructFieldsProp -> do
         sizePtr <- castPtrValue structPtr
         store sh sizePtr
-- | Offset code generator,
-- wrapped such that it can be defined via 'switchInt'.
newtype OffsetCode r sh =
   OffsetCode {
      runOffsetCode ::
         MultiValue.T (Shape sh) -> MultiValue.T (Index sh) ->
         LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
   }

{- |
Generate code for the linear offset of an index within a shape.
The offset of the leading dimensions is multiplied by the size
of the last dimension and the last index component is added.
For rank zero the offset is zero.
-}
offsetCode ::
   (C sh) =>
   MultiValue.T (Shape sh) -> MultiValue.T (Index sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
offsetCode =
   runOffsetCode
      (switchInt
         (OffsetCode (\_sh _ix -> return A.zero))
         (OffsetCode
            (switchR $ \shRest (MultiValue.Cons dimSize) ->
             switchR $ \ixRest (MultiValue.Cons dimIx) -> do
               outer <- offsetCode shRest ixRest
               scaled <- A.mul dimSize outer
               A.add dimIx scaled)))
-- | Rank computation, wrapped such that it can be defined via 'switch'.
newtype Rank sh = Rank {runRank :: sh -> Int}

{- |
Number of dimensions of a structure.
The result depends only on the type.
The irrefutable pattern makes sure that the argument is never evaluated,
such that 'rank' can be applied to an undefined value.
-}
rank :: (C sh) => sh -> Int
rank =
   runRank
      (switch
         (Rank (\_ -> 0))
         (Rank (\ ~(sh :. _) -> succ (rank sh))))
-- | Reading from memory, wrapped such that it can be defined via 'switchInt'.
newtype Peek sh = Peek {runPeek :: Ptr Shape.Size -> IO sh}

{- |
Read a structure from an array of 'Shape.Size' values.
The element at the pointer is the last dimension,
the leading dimensions follow at the next element.
-}
peek :: (C sh) => Ptr Shape.Size -> IO sh
peek =
   runPeek
      (switchInt
         (Peek (\_ -> return Z))
         (Peek (\ptr ->
            St.peek ptr >>= \dim ->
            peek (advancePtr ptr 1) >>= \rest ->
            return (rest :. Index.Int dim))))
-- | Writing to memory, wrapped such that it can be defined via 'switchInt'.
newtype Poke sh = Poke {runPoke :: Ptr Shape.Size -> sh -> IO ()}

{- |
Write a structure to an array of 'Shape.Size' values.
The last dimension is written to the element at the pointer,
the leading dimensions follow at the next element.
-}
poke :: (C sh) => Ptr Shape.Size -> sh -> IO ()
poke =
   runPoke
      (switchInt
         (Poke (\_ _ -> return ()))
         (Poke (\ptr (rest :. Index.Int dim) ->
            St.poke ptr dim >> poke (advancePtr ptr 1) rest)))
-- | Reinterpret a pointer to an LLVM struct
-- as pointer to a 'Shape.Size' by means of 'LLVM.bitcast'.
castPtrValue ::
   (LLVM.StructFields sh) =>
   LLVM.Value (Ptr (LLVM.Struct sh)) ->
   LLVM.CodeGenFunction r (LLVM.Value (Ptr Shape.Size))
castPtrValue structPtr = LLVM.bitcast structPtr
-- | Load code generator, wrapped such that it can be defined via 'switchInt'.
newtype Load r tag sh =
   Load {
      runLoad ::
         LLVM.Value (Ptr Shape.Size) ->
         LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
   }

{- |
Generate code that loads a structure from consecutive 'Shape.Size' elements.
The element at the pointer is the last dimension,
the leading dimensions are loaded from the advanced pointer.
-}
load ::
   (C sh) =>
   LLVM.Value (Ptr Shape.Size) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
load =
   runLoad
      (switchInt
         (Load (\_ -> return z))
         (Load (\ptr -> do
            dim <- LLVM.load ptr
            next <- A.advanceArrayElementPtr ptr
            rest <- load next
            return (cons rest (MultiValue.Cons dim)))))
-- | Store code generator, wrapped such that it can be defined via 'switchInt'.
newtype Store r tag sh =
   Store {
      runStore ::
         MultiValue.T (T tag sh) ->
         LLVM.Value (Ptr Shape.Size) ->
         LLVM.CodeGenFunction r ()
   }

{- |
Generate code that stores a structure in consecutive 'Shape.Size' elements.
The last dimension is written to the element at the pointer,
the leading dimensions are stored at the advanced pointer.
-}
store ::
   (C sh) =>
   MultiValue.T (T tag sh) ->
   LLVM.Value (Ptr Shape.Size) ->
   LLVM.CodeGenFunction r ()
store =
   runStore
      (switchInt
         (Store (\_sh _ptr -> return ()))
         (Store (switchR (\rest (MultiValue.Cons dim) ptr -> do
            LLVM.store dim ptr
            next <- A.advanceArrayElementPtr ptr
            store rest next))))
-- | Size computation, wrapped such that it can be defined via 'switchInt'.
newtype Size sh = Size {runSize :: sh -> Shape.Size}

-- | Number of elements of a shape,
-- that is the product of all dimensions.
-- The size of 'Z' is one.
size :: (C sh) => sh -> Shape.Size
size =
   runSize
      (switchInt
         (Size (\_ -> 1))
         (Size (\(rest :. Index.Int dim) -> dim * size rest)))
-- | Size code generator,
-- wrapped such that it can be defined via 'switchInt'.
newtype ComputeSize r sh =
   ComputeSize {
      runComputeSize ::
         MultiValue.T (Shape sh) ->
         LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
   }

-- | Generate code for the number of elements of a shape,
-- that is the product of all dimensions.
-- For rank zero the result is one.
computeSize ::
   (C sh) =>
   MultiValue.T (Shape sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
computeSize =
   runComputeSize
      (switchInt
         (ComputeSize (\_ -> return A.one))
         (ComputeSize (switchR (\rest (MultiValue.Cons dim) -> do
            restSize <- computeSize rest
            A.mul dim restSize))))
-- | Replication of a value,
-- wrapped such that it can be defined via 'switchInt'.
newtype
   Constant val tag sh =
      Constant {getConstant :: val Index.Int -> val (T tag sh)}

-- | Structure where every dimension is set to the same given value.
constant :: (C sh, Expr.Value val) => val Index.Int -> val (T tag sh)
constant =
   getConstant
      (switchInt
         (Constant (\_ -> z))
         (Constant (\x -> cons (constant x) x)))
-- | Wrapper for 'addPhis' such that it can be defined via 'switchInt'.
newtype AddPhis r tag sh =
   AddPhis {
      runAddPhis ::
         LLVM.BasicBlock ->
         MultiValue.T (T tag sh) ->
         MultiValue.T (T tag sh) ->
         LLVM.CodeGenFunction r ()
   }

{- |
Apply 'MultiValue.addPhis' to every dimension of a pair of structures.
The last dimension is processed first, then the leading dimensions.
-}
addPhis ::
   (C sh) =>
   LLVM.BasicBlock ->
   MultiValue.T (T tag sh) ->
   MultiValue.T (T tag sh) ->
   LLVM.CodeGenFunction r ()
addPhis =
   runAddPhis
      (switchInt
         (AddPhis (\_bb _a _b -> return ()))
         (AddPhis (\bb ->
            switchR $ \restA dimA ->
            switchR $ \restB dimB -> do
               MultiValue.addPhis bb dimA dimB
               addPhis bb restA restB)))
-- | Wrapper for 'phis' such that it can be defined via 'switchInt'.
newtype Phis r tag sh =
   Phis {
      runPhis ::
         LLVM.BasicBlock ->
         MultiValue.T (T tag sh) ->
         LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
   }

{- |
Apply 'MultiValue.phis' to every dimension of a structure
and assemble the results to a structure again.
The leading dimensions are processed first, then the last dimension.
-}
phis ::
   (C sh) =>
   LLVM.BasicBlock ->
   MultiValue.T (T tag sh) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
phis =
   runPhis
      (switchInt
         (Phis (\_bb sh -> return sh))
         (Phis (\bb ->
            switchR $ \rest dim -> do
               restPhi <- phis bb rest
               dimPhi <- MultiValue.phis bb dim
               return (restPhi #:. dimPhi))))
-- | Index iterator, wrapped such that it can be defined via 'switchInt'.
newtype Iterator r sh =
   Iterator {
      runIterator ::
         MultiValue.T (Shape sh) -> Iter.T r (MultiValue.T (Index sh))
   }

{- |
Iterator over all indices of a shape.

A shape of rank zero has exactly one index, namely 'z'.
This agrees with 'loop', that runs its body once for rank zero,
and with 'indices', that returns @[Z]@.
The base case must not be an empty iterator,
because the 'Iter.cartesian' product with an empty iterator
would be empty for every rank.

For more dimensions the iterator is the cartesian product
of the iterator over the leading dimensions
and a counter for the last dimension.
The counter starts at zero, is advanced with 'MultiValue.inc'
and is cut off by 'IterMV.takeWhile'
with the comparison @MultiValue.cmp LLVM.CmpGT n@
against the size @n@ of the last dimension.
-}
iterator ::
   (C sh) =>
   MultiValue.T (Shape sh) -> Iter.T r (MultiValue.T (Index sh))
iterator =
   runIterator $
   switchInt
      (Iterator $ \ _z -> Iter.singleton z)
      (Iterator $ switchR $ \sh n ->
         fmap (\(ix,i) -> ix#:.i) $
         Iter.cartesian
            (iterator sh)
            (IterMV.takeWhile (MultiValue.cmp LLVM.CmpGT n) $
             Iter.iterate MultiValue.inc MultiValue.zero))
-- | Loop code generator, wrapped such that it can be defined via 'switchInt'.
newtype Loop r state sh =
   Loop {
      runLoop ::
         (MultiValue.T (Index sh) ->
          state ->
          LLVM.CodeGenFunction r state) ->
         MultiValue.T (Shape sh) ->
         state ->
         LLVM.CodeGenFunction r state
   }

{- |
Generate nested loops over all indices of a shape
while threading a state through the loop body.

For rank zero the body is run once with the index 'z'.
Otherwise the recursive call handles the leading dimensions
and its body is a 'C.fixedLengthLoop' over the last dimension.
That inner loop carries the state together with a counter
that starts at 'A.zero' and is advanced with 'A.inc'.
The counter is dropped when the inner loop is finished.
-}
loop ::
   (C sh, Loop.Phi state) =>
   (MultiValue.T (Index sh) ->
    state ->
    LLVM.CodeGenFunction r state) ->
   MultiValue.T (Shape sh) ->
   state ->
   LLVM.CodeGenFunction r state
loop =
   runLoop
      (switchInt
         (Loop (\body _sh -> body z))
         (Loop (\body ->
            switchR $ \shRest (MultiValue.Cons n) ->
               loop
                  (\ixRest start ->
                     fmap fst $
                     C.fixedLengthLoop n (start, A.zero) $ \(state, k) -> do
                        nextState <- body (ixRest #:. MultiValue.Cons k) state
                        nextK <- A.inc k
                        return (nextState, nextK))
                  shRest)))