module Data.Array.Knead.Simple.Private where
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp(Exp), )
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Monad as Monad
import qualified LLVM.Extra.Maybe as Maybe
import qualified LLVM.Extra.Control as C
import qualified LLVM.Core as LLVM
import qualified Control.Category as Cat
import Control.Monad ((<=<), )
import Prelude hiding (id, map, zipWith, replicate, )
-- | Shorthand for an LLVM multi-value; used below for both array indices
-- and array elements.
type Val = MultiValue.T
-- | A code-generation action in 'LLVM.CodeGenFunction' @r@ that yields a
-- @'Val' a@.
type Code r a = LLVM.CodeGenFunction r (Val a)
-- | A delayed array: a shape expression plus a code generator that computes
-- the element at a given index.  No elements are stored; each access
-- (see '!') runs the generator on the requested index.
--
-- The generator is quantified over @r@, so it can be used inside any
-- 'LLVM.CodeGenFunction'.
data Array sh a =
   Array (Exp sh) (forall r. Val (Shape.Index sh) -> Code r a)
-- | The shape expression of an array.
shape :: Array sh a -> Exp sh
shape arr =
   case arr of
      Array sh _ -> sh
-- | Element access: generate the index expression, then feed it to the
-- array's element generator.  No bounds check is performed here.
(!) ::
   (Shape.C sh, Shape.Index sh ~ ix) =>
   Array sh a -> Exp ix -> Exp a
(!) arr (Exp ixCode) =
   case arr of
      Array _ elemAt -> Exp (ixCode >>= elemAt)
-- | The single element of a scalar-shaped array, obtained by running the
-- element generator at 'Shape.zeroIndex'.
the :: (Shape.Scalar sh) => Array sh a -> Exp a
the arr =
   case arr of
      Array z elemAt -> Exp (elemAt (Shape.zeroIndex z))
-- | Wrap a single value as a scalar-shaped array ('fill' at 'Shape.scalar').
fromScalar :: (Shape.Scalar sh) => Exp a -> Array sh a
fromScalar x = fill Shape.scalar x
-- | An array of the given shape in which every position holds the value of
-- the given expression.  The index is ignored; the expression's code is
-- run on each access.
fill :: Exp sh -> Exp a -> Array sh a
fill sh (Exp valueCode) = Array sh (const valueCode)
-- | Array-like types whose operations can be expressed by lifting functions
-- on plain 'Array's.  The 'Array' instance in this module uses the identity
-- for every method.
class C array where
   -- | Lift a plain array into @array@.
   lift0 :: Array sh a -> array sh a
   -- | Lift a unary transformation of plain arrays.
   lift1 :: (Array sha a -> Array shb b) -> array sha a -> array shb b
   -- | Lift a binary transformation of plain arrays.
   lift2 ::
      (Array sha a -> Array shb b -> Array shc c) ->
      array sha a -> array shb b -> array shc c
-- | Plain arrays need no wrapping: each lift hands back its argument as is.
instance C Array where
   lift0 arr = arr
   lift1 f = f
   lift2 f = f
-- | Build an array with the shape of the first (index) array.  At each
-- position, the element is read from the second (source) array at the
-- index stored in the first array.  The source array's shape is ignored.
gather ::
   (C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    MultiValue.C a) =>
   array sh1 ix0 ->
   array sh0 a ->
   array sh1 a
gather =
   lift2 (\(Array targetShape lookupIx) (Array _ fetch) ->
      Array targetShape (\ix -> lookupIx ix >>= fetch))
-- | Combine two arrays into a new array of shape @sh@.  Position @ix@ reads
-- the first array at @projectIndex0 ix@ and the second at
-- @projectIndex1 ix@, then merges the two elements with @f@.  The shapes of
-- the two input arrays are not consulted.
backpermute2 ::
   (C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Shape.C sh, Shape.Index sh ~ ix) =>
   Exp sh ->
   (Exp ix -> Exp ix0) ->
   (Exp ix -> Exp ix1) ->
   (Exp a -> Exp b -> Exp c) ->
   array sh0 a -> array sh1 b -> array sh c
backpermute2 sh projectIndex0 projectIndex1 f =
   lift2 (\(Array _ codeA) (Array _ codeB) ->
      Array sh (\ix ->
         -- Both fetches are bound inside the per-index lambda so their
         -- monad stays the lambda's own CodeGenFunction.
         let elemA = Expr.unliftM1 projectIndex0 ix >>= codeA
             elemB = Expr.unliftM1 projectIndex1 ix >>= codeB
         in  Monad.liftR2 (Expr.unliftM2 f) elemA elemB))
-- | The identity array over a shape: the element at every index is that
-- index itself.
id ::
   (Shape.C sh, Shape.Index sh ~ ix) =>
   Exp sh -> Array sh ix
id sh = Array sh (\ix -> return ix)
-- | Apply an element function at every position; the shape is unchanged.
map ::
   (C array, Shape.C sh) =>
   (Exp a -> Exp b) ->
   array sh a -> array sh b
map f =
   lift1 (\(Array sh elemAt) ->
      Array sh (\ix -> elemAt ix >>= Expr.unliftM1 f))
-- | Like 'map', but the element function also receives the element's index.
mapWithIndex ::
   (C array, Shape.C sh, Shape.Index sh ~ ix) =>
   (Exp ix -> Exp a -> Exp b) ->
   array sh a -> array sh b
mapWithIndex f =
   lift1 (\(Array sh elemAt) ->
      Array sh (\ix -> do
         x <- elemAt ix
         Expr.unliftM2 f ix x))
-- | Code generator for a 'fold1' over an inner dimension.
--
-- For a fixed outer index @ix@, it loops over every inner index of the
-- shape @sh1@ with 'Shape.loop'.  Each element is combined as @f acc x@:
-- the first element seeds the accumulator, and every later element is
-- combined into it.
--
-- The accumulator is a 'Maybe.T' that starts as 'Maybe.nothing'.  At the
-- end it is unwrapped with 'Maybe.fromJust', without checking its flag.
-- NOTE(review): for an empty inner shape the result is therefore not a
-- meaningful value (likely an undefined LLVM value) — confirm; callers
-- presumably must guarantee a non-empty inner dimension.
fold1Code ::
   (Shape.C sh1, Shape.Index sh1 ~ ix1, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Exp sh1 ->
   (Val ix0 -> Val ix1 -> Code r a) ->
   (Val ix0 -> Code r a)
fold1Code f (Exp nc) code ix = do
   n <- nc
   let step i0 macc0 = do
          a <- code ix i0
          acc1 <- Maybe.run macc0 (return a) (\acc -> Expr.unliftM2 f acc a)
          return (Maybe.just acc1)
   final <- Shape.loop step n Maybe.nothing
   return (Maybe.fromJust final)
-- | Reduce the inner dimension of a two-level shape @(sh0, sh1)@ with @f@,
-- giving an array of the outer shape @sh0@.  Elements are combined as
-- @f acc x@ in loop order, seeded with the first element.
-- NOTE(review): nothing here checks that the inner dimension is non-empty;
-- the result for an empty inner dimension is presumably undefined.
fold1 ::
   (C array, Shape.C sh0, Shape.C sh1, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   array (sh0, sh1) a -> array sh0 a
fold1 f =
   lift1 (\(Array shs code) ->
      let (outer, inner) = Expr.unzip shs
      in  Array outer (fold1Code f inner (MultiValue.curry code)))
-- | Fold every element of an array into a scalar-shaped (@()@) array.  The
-- shape is viewed as @((), sh)@, so 'fold1' reduces all of @sh@ at once.
fold1All ::
   (Shape.C sh, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Array sh a -> Array () a
fold1All f (Array esh code) =
   let withUnitOuter = Expr.lift1 (MultiValue.zip (MultiValue.Cons ()))
   in  fold1 f (Array (withUnitOuter esh) (\ix -> code (MultiValue.snd ix)))
-- | Code generator behind 'findAll'.  It scans every index of shape @sh@
-- and yields the first element (in 'Shape.loop' order) satisfying @p@, or
-- 'MultiValue.nothing' if no element does.
--
-- The loop always visits every index.  Once a match has been recorded, the
-- generated branch skips the element generator and the predicate, and
-- simply carries the match forward.
findAllCode ::
   (Shape.C sh, Shape.Index sh ~ ix, MultiValue.C a) =>
   (Exp a -> Exp Bool) ->
   Exp sh ->
   (Val ix -> Code r a) ->
   Code r (Maybe a)
findAllCode p (Exp sh) code = do
   n <- sh
   finalFound <-
      Shape.loop
         (\i found ->
            -- Already found: pass the earlier match through unchanged.
            C.ifThenElse (Maybe.isJust found)
               (return found)
               (do
                  a <- code i
                  MultiValue.Cons b <- Expr.unliftM1 p a
                  -- Mark @a@ as found exactly when the predicate holds.
                  return $ Maybe.fromBool b a))
         n Maybe.nothing
   -- Convert the 'Maybe.T' loop state into the @Val (Maybe a)@ result.
   Maybe.run finalFound
      (return MultiValue.nothing)
      (return . MultiValue.just)
-- | The first element of the array (in 'Shape.loop' order) that satisfies
-- the predicate, or nothing if there is none.
findAll ::
   (Shape.C sh, MultiValue.C a) =>
   (Exp a -> Exp Bool) ->
   Array sh a -> Exp (Maybe a)
findAll p arr =
   case arr of
      Array sh elemAt -> Exp (findAllCode p sh elemAt)
-- | A class with no methods.  In this module it serves only as a constraint
-- on both operands of '$:.'.
class Process proc where
infixl 3 $:.

-- | Reverse application restricted to 'Process' instances: @x $:. f = f x@.
-- Because it is @infixl 3@, steps chain left to right:
-- @p $:. f $:. g = g (f p)@.
($:.) :: (Process proc0, Process proc1) => proc0 -> (proc0 -> proc1) -> proc1
x $:. step = step x