module Data.Array.Accelerate.Arithmetic.Interpolation (
bisect,
lookupInterval,
Interpolator13, sampleBasisFunctions13,
) where
import qualified Data.Array.Accelerate.LinearAlgebra.Matrix.Sparse as Sparse
import qualified Data.Array.Accelerate.LinearAlgebra as LinAlg
import qualified Data.Array.Accelerate.Utility.Arrange as Arrange
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import qualified Data.Array.Accelerate.Utility.Loop as Loop
import Data.Array.Accelerate.LinearAlgebra
(Scalar, Vector, numElems, extrudeVector, )
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate (Exp, Any(Any), Z(Z), (:.)((:.)), )
import Data.Ord.HT (limit, )
-- | Perform one bisection step on every search interval in @bounds@.
--
-- For each element the midpoint @div (lower+upper) 2@ of its current
-- @(lower, upper)@ pair is computed, and a node is fetched from @nodes@
-- using the element's own index combined with that midpoint
-- (via 'Exp.indexCons' and 'Arrange.gather').
-- If the corresponding value in @xs@ is strictly less than the fetched node,
-- the second component of the pair is replaced by the midpoint,
-- otherwise the first component is replaced by it.
bisect ::
   (A.Slice ix, A.Shape ix, A.Ord a, A.Elt a) =>
   Vector ix a ->
   Scalar ix a ->
   Scalar ix (Int, Int) ->
   Scalar ix (Int, Int)
bisect nodes xs bounds =
   A.zipWith3 narrow midpoints bounds goesLeft
   where
      -- midpoint of every current interval
      midpoints =
         A.map (A.uncurry (\lo hi -> (lo + hi) `div` 2)) bounds

      -- node value found at the midpoint of each interval
      nodesAtMidpoints =
         Arrange.gather
            (Arrange.mapWithIndex Exp.indexCons midpoints)
            nodes

      -- True where the searched value lies strictly below the midpoint node
      goesLeft = A.zipWith (A.<) xs nodesAtMidpoints

      -- keep the left half (new upper bound) or the right half (new lower bound)
      narrow mid interval left =
         A.cond left
            (Exp.mapSnd (const mid) interval)
            (Exp.mapFst (const mid) interval)
-- | For every value in @x@ search the matching position in @nodes@
-- by repeated bisection.
--
-- Every search starts with the bounds @(0, numElems nodes)@.
-- These bounds are refined by 'bisect', iterated through 'Loop.nestLog2'
-- (presumably about @log2 (numElems nodes)@ times - confirm against "Loop").
-- The result is the first component of the final bounds.
lookupInterval ::
   (A.Slice ix, A.Shape ix, A.Ord a, A.Elt a) =>
   Vector ix a ->
   Scalar ix a ->
   Scalar ix Int
lookupInterval nodes x =
   A.map A.fst finalBounds
   where
      nodeCount = numElems nodes

      -- one full-range interval per searched value
      initialBounds =
         A.fill (A.shape x) (A.lift (0 :: Exp Int, nodeCount))

      finalBounds =
         Loop.nestLog2 nodeCount (bisect nodes x) initialBounds
-- | Combine every element of @x@ with every element of the vector @y@
-- using @f@.
--
-- @x@ is replicated along a new innermost dimension of size @numElems y@,
-- @y@ is extruded to the shape of @x@ by 'extrudeVector',
-- and both arrays are then zipped element-wise with @f@.
outerVector ::
   (A.Shape ix, A.Slice ix, A.Elt a, A.Elt b, A.Elt c) =>
   (Exp a -> Exp b -> Exp c) ->
   Scalar ix a -> Vector Z b -> Vector ix c
outerVector f x y =
   A.zipWith f xAlongY yAcrossX
   where
      -- every element of x repeated once per element of y
      xAlongY = A.replicate (A.lift (Any :. numElems y)) x

      -- the whole of y laid out for the shape of x
      yAcrossX = extrudeVector (A.shape x) y
-- | Interpolation function that is based on four support points.
--
-- As applied in 'sampleBasisFunctions13', each of the four pairs is
-- @(node position, value at that node)@ and the final argument is the
-- position at which the interpolant is evaluated.
-- The result is the interpolated value.
type Interpolator13 a = (a,a) -> (a,a) -> (a,a) -> (a,a) -> a -> a
-- | Sample the basis functions of a four-point interpolation scheme
-- at the positions @zs@ and return them as rows of a sparse matrix.
--
-- For every position in @zs@ the interval index @n@ within @nodes@ is
-- determined by 'lookupInterval' and clamped to @[1, numElems nodes - 3]@,
-- so that the four nodes at offsets @-1, 0, 1, 2@ relative to the clamped
-- index @ln@ always exist.
-- Each row gets four entries @(ln+k, value)@ for @k@ in @-1, 0, 1, 2@.
-- The value is obtained by running @interpolate@ on the four gathered nodes
-- paired with the @k@-th unit vector, i.e. the @k@-th basis function
-- evaluated at the position.
--
-- Outside the clamping range no interpolation takes place:
-- For @n < 1@ the row is the unit vector selecting column @ln@,
-- for @n > numElems nodes - 3@ it selects column @ln+1@.
--
-- The first field of 'Sparse.Rows' is set to @numElems nodes@
-- (presumably the number of matrix columns - confirm against "Sparse").
sampleBasisFunctions13 ::
   (A.Slice ix, A.Shape ix, A.Ord a, Num a) =>
   Interpolator13 (Exp a) ->
   Vector Z a -> Vector ix a -> Sparse.Rows ix a
sampleBasisFunctions13 interpolate nodes zs =
   Sparse.Rows (numElems nodes) $
   let indices = lookupInterval (extrudeVector (A.shape zs) nodes) zs
       -- smallest index that still has a left neighbour (offset -1)
       minIx = 1
       -- largest index that still has two right neighbours (offset +2)
       maxIx = numElems nodes - 3
       limitIndices = A.map (limit (minIx, maxIx)) indices
       -- nodes at a fixed offset d from the clamped interval index
       gatherFromNodes d =
          LinAlg.gatherFromVector (A.map (d+) limitIndices) nodes
   in  outerVector
          (A.lift2 $
           \(n, ln, z, x) (k, y) ->
             case (Exp.unliftQuadruple x, Exp.unliftQuadruple y) of
                ((xm1,x0,x1,x2), (ym1,y0,y1,y2)) ->
                   (ln+k :: Exp Int,
                    -- beyond the clamping range: fall back to a unit row
                    A.cond (n A.< minIx) y0 $
                    A.cond (n A.> maxIx) y1 $
                    interpolate (xm1,ym1) (x0,y0) (x1,y1) (x2,y2) z))
          (A.zip4 indices limitIndices zs
             (A.zip4
                (gatherFromNodes (-1))
                (gatherFromNodes 0)
                (gatherFromNodes 1)
                (gatherFromNodes 2)))
          -- column offset together with the matching unit vector of
          -- sample values; offsets must agree with the gathers above
          (A.use $
           A.fromList (Z:.4)
              [(-1, (1,0,0,0)),
               ( 0, (0,1,0,0)),
               ( 1, (0,0,1,0)),
               ( 2, (0,0,0,1))])