-- | Internal building blocks of the builder monad: the builder state, the
-- primitives that add nodes to the dataflow graph being built, and the
-- combinators that wrap them.
module Language.Paraiso.OM.Builder.Internal
  ( -- * Builder monad and its state
    Builder
  , BuilderState(..)
  , B
  , BuilderOf
  , makeKernel
  , initState
    -- * Graph manipulation primitives
  , modifyG
  , getG
  , freeNode
  , addNode
  , valueToNode
  , lookUpStatic
    -- * Instruction builders
  , load
  , store
  , reduce
  , broadcast
  , shift
  , loadIndex
  , imm
  , mkOp1
  , mkOp2
  ) where
import qualified Algebra.Absolute as Absolute
import qualified Algebra.Additive as Additive
import qualified Algebra.Algebraic as Algebraic
import qualified Algebra.Field as Field
import qualified Algebra.Lattice as Lattice
import qualified Algebra.Ring as Ring
import qualified Algebra.Transcendental as Transcendental
import qualified Algebra.ZeroTestable as ZeroTestable
import Control.Monad
import qualified Control.Monad.State as State
import qualified Data.Graph.Inductive as FGL
import Data.Dynamic (Typeable)
import qualified Data.Dynamic as Dynamic
import qualified Language.Paraiso.OM.Arithmetic as A
import Language.Paraiso.OM.DynValue as DVal
import Language.Paraiso.OM.Graph
import Language.Paraiso.OM.Realm as Realm
import Language.Paraiso.OM.Reduce as Reduce
import Language.Paraiso.OM.Value as Val
import Language.Paraiso.Prelude
import Language.Paraiso.Tensor
import qualified Prelude (Num(..), Fractional(..))
-- | The state threaded through a 'Builder' computation.
data BuilderState vector gauge = BuilderState
  { setup  :: Setup vector gauge
    -- ^ the setup consulted by the builder (e.g. for its static values)
  , target :: Graph vector gauge ()
    -- ^ the graph constructed so far
  }
  deriving (Show)
-- | Run a builder starting from the initial state for the given setup and
-- wrap the graph it leaves behind into a named 'Kernel'.
--
-- The value returned by the builder is discarded; only the 'target' graph of
-- the final state is kept.
makeKernel :: (Vector v, Ring.C g)
           => Setup v g       -- ^ setup stored in the initial builder state
           -> Name            -- ^ name of the resulting kernel
           -> Builder v g ()  -- ^ builder whose effects construct the graph
           -> Kernel v g ()
makeKernel setup0 name0 builder0 =
  Kernel { kernelName = name0, dataflow = finalGraph }
  where
    -- The final state of the builder, of which only the graph is used.
    finalGraph = target (State.execState builder0 (initState setup0))
-- | The state a builder starts from: the given setup together with an empty
-- graph.
initState :: Setup v g -> BuilderState v g
initState s0 = BuilderState { target = FGL.empty, setup = s0 }
-- | A 'Builder' is a state monad over 'BuilderState'; running it accumulates
-- the graph stored in the state's 'target' field.
type Builder vector gauge val = State.State (BuilderState vector gauge) val
-- | Builders are state-passing functions and cannot be compared.
--
-- NOTE(review): this instance presumably exists only to satisfy a superclass
-- requirement (e.g. of 'Prelude.Num' on older compilers) — confirm.  Calling
-- '==' always fails, now with a message that names the culprit instead of a
-- bare @undefined@.
instance Eq (Builder v g v2) where
  _ == _ =
    error "Language.Paraiso.OM.Builder.Internal: (==) is not defined for Builder"
-- | Builders have no meaningful textual form; a fixed placeholder string is
-- shown and the builder itself is never inspected.
instance Show (Builder v g v2) where
  show _builder = "<<REDACTED>>"
-- | Shorthand for a 'Builder' that returns @ret@, with the usual
-- @('Vector' v, 'Ring.C' g)@ constraints on the vector and gauge types.
type B ret = (Vector v, Ring.C g) => Builder v g ret
-- | Shorthand for a 'Builder' that returns a 'Value' of realm @realm@ and
-- content type @content@.
type BuilderOf realm content =
  (Vector v, Ring.C g) => Builder v g (Value realm content)
-- | Apply a function to the graph held in the builder state, leaving the
-- rest of the state untouched.
modifyG :: (Vector v, Ring.C g)
        => (Graph v g () -> Graph v g ())  -- ^ how to update the graph
        -> Builder v g ()
modifyG f = State.modify $ \st -> st { target = f (target st) }
-- | Read the graph currently held in the builder state.
getG :: (Vector v, Ring.C g) => Builder v g (Graph v g ())
getG = State.gets target
-- | The index to use for the next node: the number of nodes currently in the
-- graph.
--
-- NOTE(review): this is only a free index as long as existing nodes are
-- numbered contiguously from 0 — confirm nodes are never removed.
freeNode :: B FGL.Node
freeNode = fmap FGL.noNodes getG
-- | Add a node to the graph, with incoming edges from the given nodes, and
-- return the index of the new node.
--
-- The i-th element of the source list (counting from 0) is connected with an
-- edge labelled @'EOrd' i@, so the order of the list is recorded on the
-- edges.  The new node gets no outgoing edges.
addNode :: (Vector v, Ring.C g)
        => [FGL.Node]   -- ^ source nodes, in argument order
        -> Node v g ()  -- ^ label of the node to add
        -> Builder v g FGL.Node
addNode froms new = do
  n <- freeNode
  let -- Pair each source with its position; zipping avoids both the
      -- explicit upper bound and repeated list indexing.
      inEdges = [(EOrd i, from) | (i, from) <- zip [0 ..] froms]
  modifyG ((inEdges, n, new, []) FGL.&)
  return n
-- | Obtain the graph node that carries the given value.
--
-- A value that already comes from a node simply yields that node.  An
-- immediate value is materialised first: an 'Imm' instruction node holding
-- the content is added, followed by a value node typed after the value, and
-- the index of the latter is returned.
valueToNode :: (TRealm r, Typeable c) => Value r c -> B FGL.Node
valueToNode (FromNode _ _ n) = return n
valueToNode val@(FromImm _ _) = do
  immNode <- addNode [] (NInst (Imm (Dynamic.toDyn (Val.content val))) ())
  addNode [immNode] (NValue (toDyn val) ())
-- | Check that a named, typed value is declared among the static values of
-- the setup.
--
-- Fails (via 'fail') when no static value has the given name, when more than
-- one has it, or when the single match has a different type.
--
-- The three outcomes are selected by one @case@ on the list of matches, so
-- the type comparison is only ever made against an existing match; no
-- partial @head@ is needed and the diagnosis does not depend on the order in
-- which independent checks happen to be evaluated.
lookUpStatic :: Named DynValue -> B ()
lookUpStatic (Named name0 type0) = do
  st <- State.get
  let candidates :: [Named DynValue]
      candidates = staticValues $ setup st
  case filter ((== name0) . name) candidates of
    [] ->
      fail ("no name found: " ++ nameStr name0)
    [Named _ type1] ->
      when (type0 /= type1) $
        fail ("type mismatch; expected: " ++ show type1 ++ "; " ++
              " actual: " ++ nameStr name0 ++ "::" ++ show type0)
    _ ->
      fail ("multiple match found:" ++ nameStr name0)
-- | Build a value loaded from the static variable with the given name.
--
-- The name and type are first checked with 'lookUpStatic'; then a 'Load'
-- instruction node and a value node fed by it are added to the graph.
load :: (TRealm r, Typeable c)
     => r     -- ^ realm of the value
     -> c     -- ^ content of the value, used to build its dynamic type
     -> Name  -- ^ name of the static variable
     -> B (Value r c)
load r0 c0 name0 = do
  lookUpStatic (Named name0 dynType)
  loadNode  <- addNode [] (NInst (Load name0) ())
  valueNode <- addNode [loadNode] (NValue dynType ())
  return $ FromNode r0 c0 valueNode
  where
    dynType = mkDyn r0 c0
-- | Store the value produced by a builder into the static variable with the
-- given name.
--
-- The name and the value's type are checked with 'lookUpStatic' before the
-- value is turned into a node and a 'Store' instruction node is attached.
store :: (Vector v, Ring.C g, TRealm r, Typeable c)
      => Name                     -- ^ name of the static variable
      -> Builder v g (Value r c)  -- ^ builder of the value to store
      -> Builder v g ()
store name0 builder0 = do
  stored <- builder0
  lookUpStatic (Named name0 (toDyn stored))
  srcNode <- valueToNode stored
  addNode [srcNode] (NInst (Store name0) ()) >> return ()
-- | Turn a local value into a global one with the given reduction operator.
--
-- Adds a 'Reduce' instruction node fed by the input value, followed by a
-- value node of the global realm carrying the same content.
reduce :: (Vector v, Ring.C g, Typeable c)
       => Reduce.Operator              -- ^ reduction operator
       -> Builder v g (Value TLocal c)  -- ^ builder of the local value
       -> Builder v g (Value TGlobal c)
reduce op builder1 = do
  localVal <- builder1
  let payload = Val.content localVal
  srcNode    <- valueToNode localVal
  reduceNode <- addNode [srcNode] (NInst (Reduce op) ())
  outNode    <- addNode [reduceNode] (NValue (mkDyn TGlobal payload) ())
  return $ FromNode TGlobal payload outNode
-- | Turn a global value into a local one.
--
-- Adds a 'Broadcast' instruction node fed by the input value, followed by a
-- value node of the local realm carrying the same content.
broadcast :: (Vector v, Ring.C g, Typeable c)
          => Builder v g (Value TGlobal c)  -- ^ builder of the global value
          -> Builder v g (Value TLocal c)
broadcast builder1 = do
  globalVal <- builder1
  let payload = Val.content globalVal
  srcNode   <- valueToNode globalVal
  bcastNode <- addNode [srcNode] (NInst Broadcast ())
  outNode   <- addNode [bcastNode] (NValue (mkDyn TLocal payload) ())
  return $ FromNode TLocal payload outNode
-- | Shift a local value by the given vector.
--
-- Note that the 'Shift' instruction recorded in the graph carries the
-- /negated/ vector.  The result is a new local value with the same type and
-- content as the input.
shift :: (Vector v, Ring.C g, Typeable c, Additive.C (v g))
      => v g                          -- ^ shift vector
      -> Builder v g (Value TLocal c)  -- ^ builder of the value to shift
      -> Builder v g (Value TLocal c)
shift vec builder1 = do
  input <- builder1
  srcNode   <- valueToNode input
  shiftNode <- addNode [srcNode] (NInst (Shift (Additive.negate vec)) ())
  outNode   <- addNode [shiftNode] (NValue (toDyn input) ())
  return $ FromNode TLocal (Val.content input) outNode
-- | Build a local value from a 'LoadIndex' instruction on the given axis.
--
-- The first argument supplies the content of the resulting value, and with
-- it the dynamic type of the value node.
loadIndex :: (Vector v, Ring.C g, Typeable c)
          => c       -- ^ content of the resulting value
          -> Axis v  -- ^ axis passed to the 'LoadIndex' instruction
          -> Builder v g (Value TLocal c)
loadIndex c0 axis = do
  indexNode <- addNode [] (NInst (LoadIndex axis) ())
  valueNode <- addNode [indexNode] (NValue (mkDyn TLocal c0) ())
  return $ FromNode TLocal c0 valueNode
-- | Wrap a constant as an immediate value.  Nothing is added to the graph
-- here; the value is only tagged with 'FromImm' and the unit realm.
imm :: (TRealm r, Typeable c)
    => c  -- ^ the constant
    -> B (Value r c)
imm c0 = return $ FromImm unitTRealm c0
-- | Lift a unary arithmetic operator to builders.
--
-- Runs the operand builder, adds an 'Arith' instruction node fed by the
-- operand, and returns a new value with the operand's realm, content and
-- type that comes from the value node attached to the instruction.
mkOp1 :: (Vector v, Ring.C g, TRealm r, Typeable c)
      => A.Operator               -- ^ the operator
      -> Builder v g (Value r c)  -- ^ builder of the operand
      -> Builder v g (Value r c)
mkOp1 op builder1 = do
  operand <- builder1
  argNode <- valueToNode operand
  opNode  <- addNode [argNode] (NInst (Arith op) ())
  outNode <- addNode [opNode] (NValue (toDyn operand) ())
  return $ FromNode (Val.realm operand) (Val.content operand) outNode
-- | Lift a binary arithmetic operator to builders.
--
-- Both operand builders are run, first then second, before either operand
-- is turned into a node.  An 'Arith' instruction node takes the two operand
-- nodes in order, and the result takes its realm, content and type from the
-- first operand.
mkOp2 :: (Vector v, Ring.C g, TRealm r, Typeable c)
      => A.Operator               -- ^ the operator
      -> Builder v g (Value r c)  -- ^ builder of the first operand
      -> Builder v g (Value r c)  -- ^ builder of the second operand
      -> Builder v g (Value r c)
mkOp2 op builder1 builder2 = do
  lhs <- builder1
  rhs <- builder2
  lhsNode <- valueToNode lhs
  rhsNode <- valueToNode rhs
  opNode  <- addNode [lhsNode, rhsNode] (NInst (Arith op) ())
  outNode <- addNode [opNode] (NValue (toDyn lhs) ())
  return $ FromNode (Val.realm lhs) (Val.content lhs) outNode
-- | Additive structure on value builders: 'Additive.zero' is an immediate
-- value, while addition, subtraction and negation add 'A.Add', 'A.Sub' and
-- 'A.Neg' instructions to the graph.
instance (Vector v, Ring.C g, TRealm r, Typeable c, Additive.C c)
    => Additive.C (Builder v g (Value r c)) where
  zero   = return $ FromImm unitTRealm Additive.zero
  (+)    = mkOp2 A.Add
  -- The method name had been lost here, leaving an invalid @() = ...@
  -- binding; subtraction is the method that maps to 'A.Sub'.
  (-)    = mkOp2 A.Sub
  negate = mkOp1 A.Neg
-- | Ring structure on value builders: 'Ring.one' and integer literals are
-- immediate values, multiplication adds an 'A.Mul' instruction.
instance (Vector v, Ring.C g, TRealm r, Typeable c, Ring.C c)
    => Ring.C (Builder v g (Value r c)) where
  one           = imm Ring.one
  x * y         = mkOp2 A.Mul x y
  fromInteger n = imm (fromInteger n)
-- | Standard-prelude numeric interface for value builders, delegating to
-- the 'Additive.C' and 'Ring.C' instances.
--
-- 'abs' and 'signum' add 'A.Abs' and 'A.Signum' instructions instead of
-- being left undefined.
instance (Vector v, Ring.C g, TRealm r, Typeable c, Ring.C c)
    => Prelude.Num (Builder v g (Value r c)) where
  (+)         = (Additive.+)
  (*)         = (Ring.*)
  -- The method name had been lost here, leaving an invalid @() = ...@
  -- binding.
  (-)         = (Additive.-)
  negate      = Additive.negate
  abs         = mkOp1 A.Abs
  signum      = mkOp1 A.Signum
  fromInteger = Ring.fromInteger
-- | Field structure on value builders: division and reciprocal add 'A.Div'
-- and 'A.Inv' instructions, rational literals are immediate values.
instance (Vector v, Ring.C g, TRealm r, Typeable c, Field.C c)
    => Field.C (Builder v g (Value r c)) where
  x / y           = mkOp2 A.Div x y
  recip x         = mkOp1 A.Inv x
  fromRational' q = imm (fromRational' q)
-- | Standard-prelude fractional interface for value builders, delegating to
-- the 'Field.C' instance; rational literals become immediate values via the
-- content type's own 'Prelude.fromRational'.
instance ( Vector v, Ring.C g, TRealm r, Typeable c
         , Field.C c, Prelude.Fractional c )
    => Prelude.Fractional (Builder v g (Value r c)) where
  (/)            = (Field./)
  recip          = Field.recip
  fromRational q = imm (Prelude.fromRational q)
-- | Boolean algebra on builders of 'Bool' values: the constants are
-- immediate values, the connectives add 'A.Not', 'A.And' and 'A.Or'
-- instructions.
instance (Vector v, Ring.C g, TRealm r)
    => Boolean (Builder v g (Value r Bool)) where
  true   = imm True
  false  = imm False
  not x  = mkOp1 A.Not x
  x && y = mkOp2 A.And x y
  x || y = mkOp2 A.Or x y
-- | Algebraic operations on value builders: square root adds an 'A.Sqrt'
-- instruction; a rational power adds an 'A.Pow' instruction whose exponent
-- is the rational converted to a builder with @fromRational'@.
instance (Vector v, Ring.C g, TRealm r, Typeable c, Algebraic.C c)
    => Algebraic.C (Builder v g (Value r c)) where
  sqrt x           = mkOp1 A.Sqrt x
  base ^/ exponent = mkOp2 A.Pow base (fromRational' exponent)
-- | Lattice operations on value builders: 'Lattice.up' adds an 'A.Max'
-- instruction and 'Lattice.dn' an 'A.Min' instruction.
instance (Vector v, Ring.C g, TRealm r, Typeable c)
    => Lattice.C (Builder v g (Value r c)) where
  up x y = mkOp2 A.Max x y
  dn x y = mkOp2 A.Min x y
-- | A builder cannot be tested for zero; calling 'ZeroTestable.isZero' on
-- one always raises an error and never inspects the builder.
instance (Vector v, Ring.C g, TRealm r, Typeable c)
    => ZeroTestable.C (Builder v g (Value r c)) where
  isZero _builder = error "isZero undefined for builder."
-- | Absolute value and sign on value builders, added to the graph as
-- 'A.Abs' and 'A.Signum' instructions.
instance (Vector v, Ring.C g, TRealm r, Typeable c, Ring.C c)
    => Absolute.C (Builder v g (Value r c)) where
  abs x    = mkOp1 A.Abs x
  signum x = mkOp1 A.Signum x
-- | Transcendental functions on value builders.  The constant @pi@ is an
-- immediate value; every function adds the arithmetic instruction of the
-- same name to the graph.
instance (Vector v, Ring.C g, TRealm r, Typeable c, Transcendental.C c)
    => Transcendental.C (Builder v g (Value r c)) where
  pi     = imm pi
  exp x  = mkOp1 A.Exp x
  log x  = mkOp1 A.Log x
  sin x  = mkOp1 A.Sin x
  cos x  = mkOp1 A.Cos x
  tan x  = mkOp1 A.Tan x
  asin x = mkOp1 A.Asin x
  acos x = mkOp1 A.Acos x
  atan x = mkOp1 A.Atan x