module Language.Paraiso.OM.Graph
(
Setup(..), Kernel(..), Graph, nmap, imap, getA,
Node(..), Edge(..),
StaticIdx(..),
Inst(..),
)where
import Data.Dynamic
import Data.Tensor.TypeLevel
import qualified Data.Vector as V
import qualified Data.Graph.Inductive as FGL
import Language.Paraiso.Name
import Language.Paraiso.OM.Arithmetic as A
import Language.Paraiso.OM.Reduce as R
import Language.Paraiso.OM.DynValue
import NumericPrelude
-- | The static part of an Orthotope Machine description: a vector of named
-- dynamic values plus one annotation for the setup as a whole.
--
-- The @vector@ and @gauge@ parameters are phantom in this type; neither of
-- them occurs in a field.
data Setup (vector :: * -> *) gauge anot = Setup
  { -- | The named static values.
    --
    -- NOTE(review): presumably this is the table that a 'StaticIdx' indexes
    -- into -- confirm against the code that interprets 'Load' and 'Store'.
    staticValues :: V.Vector (Named DynValue)
  , -- | An annotation that belongs to the setup rather than to any node.
    globalAnnotation :: anot
  }
  deriving (Eq, Show)
-- | A named dataflow graph.
data Kernel vector gauge anot = Kernel
  { -- | The name of the kernel; this is also what the 'Nameable' instance
    -- of 'Kernel' returns.
    kernelName :: Name
  , -- | The graph of annotated nodes and edges that makes up the kernel.
    dataflow :: Graph vector gauge anot
  }
  deriving (Show)
-- | A kernel is named by its 'kernelName' field.
instance Nameable (Kernel v g a) where
  name kernel = kernelName kernel
-- | The dataflow graph type: an FGL 'FGL.Gr' whose node labels are 'Node's
-- (carrying an annotation of type @anot@) and whose edge labels are 'Edge's.
type Graph vector gauge anot =
  FGL.Gr (Node vector gauge anot) Edge
-- | Apply a function to the annotation of every node in the graph.
--
-- Only the annotation changes: the value or instruction held by each node,
-- the edges, and the node indices are all left as they were.
--
-- The per-node work is delegated to the 'Functor' instance of 'Node', so the
-- knowledge of how an annotation sits inside each constructor lives in one
-- place instead of being repeated here.
nmap :: (a -> b) -> Graph v g a -> Graph v g b
nmap f = FGL.nmap (fmap f)
-- | Like 'nmap', but the mapping function also receives the FGL index of the
-- node whose annotation it is rewriting.
--
-- The result is rebuilt with 'FGL.mkGraph' from the relabelled node list and
-- the original labelled edges, so node indices and edges are preserved.
--
-- The per-node rewrite goes through the 'Functor' instance of 'Node' rather
-- than a hand-written copy of it.
imap :: (FGL.Node -> a -> b) -> Graph v g a -> Graph v g b
imap f graph = FGL.mkGraph relabelled (FGL.labEdges graph)
  where
    -- Every labelled node, with its annotation passed through @f@ together
    -- with the node's own index.
    relabelled = [(i, fmap (f i) nd) | (i, nd) <- FGL.labNodes graph]
-- | The label attached to each node of a 'Graph'. Both constructors carry an
-- annotation of type @anot@ as their last field; 'getA' extracts it and the
-- 'Functor' instance maps over it.
data Node vector gauge anot
  = -- | A node holding a 'DynValue', together with its annotation.
    NValue DynValue anot
  | -- | A node holding an 'Inst', together with its annotation.
    NInst (Inst vector gauge) anot
  deriving (Show)
-- | The label attached to each edge of a 'Graph'.
--
-- The constructor order matters for the derived 'Ord' instance: every
-- 'EUnord' compares less than every 'EOrd', and 'EOrd's compare by their
-- 'Int'.
data Edge
  = -- | An edge that carries no index.
    --
    -- NOTE(review): presumably used where operand order is irrelevant --
    -- confirm against the graph builder.
    EUnord
  | -- | An edge tagged with an 'Int'.
    --
    -- NOTE(review): presumably the operand position for nodes whose inputs
    -- are order-sensitive -- confirm against the graph builder.
    EOrd Int
  deriving (Eq, Ord, Show)
-- | Extract the annotation from a node, whichever constructor it was built
-- with.
getA :: Node v g a -> a
getA (NValue _ anot) = anot
getA (NInst _ anot) = anot
-- | Mapping over a 'Node' rewrites its annotation and leaves the value or
-- instruction it holds unchanged.
instance Functor (Node v g) where
  fmap g node = case node of
    NValue val anot -> NValue val (g anot)
    NInst inst anot -> NInst inst (g anot)
-- | An 'Int' index wrapped in its own type. It is the payload of the 'Load'
-- and 'Store' instructions.
--
-- NOTE(review): presumably an index into 'staticValues' of the 'Setup' --
-- confirm against the code that interprets those instructions.
newtype StaticIdx = StaticIdx {fromStaticIdx :: Int}

-- | Rendered as @static[n]@, where @n@ is the wrapped 'Int' shown in the
-- usual way.
instance Show StaticIdx where
  show idx = "static[" ++ show (fromStaticIdx idx) ++ "]"
-- | The instructions that an 'NInst' node can hold. The 'Arity' instance of
-- this type gives the arity pair of each instruction.
--
-- NOTE(review): the one-line purposes below are inferred from the
-- constructor names and payload types; confirm them against the code that
-- interprets the graph.
data Inst vector gauge
  = -- | Carries a 'Dynamic' payload; presumably an immediate constant.
    Imm Dynamic
  | -- | Carries a 'StaticIdx'; presumably reads the static value at that
    -- index.
    Load StaticIdx
  | -- | Carries a 'StaticIdx'; presumably writes the static value at that
    -- index.
    Store StaticIdx
  | -- | A reduction, parameterised by a reduce operator from
    -- "Language.Paraiso.OM.Reduce".
    Reduce R.Operator
  | -- | Takes no payload; presumably the counterpart of 'Reduce'.
    Broadcast
  | -- | Carries a @vector gauge@; presumably the displacement to shift by.
    Shift (vector gauge)
  | -- | Carries an 'Axis'; presumably yields the index along that axis.
    LoadIndex (Axis vector)
  | -- | Carries an 'Axis'; presumably yields the size along that axis.
    LoadSize (Axis vector)
  | -- | An arithmetic operation, parameterised by an operator from
    -- "Language.Paraiso.OM.Arithmetic". Its arity is that of the operator.
    Arith A.Operator
  deriving (Show)
-- | The arity pair of each instruction. Every instruction except 'Arith' has
-- a fixed pair; 'Arith' takes the arity of the operator it wraps.
--
-- NOTE(review): the pair reads as (number of inputs, number of outputs) --
-- 'Store' is @(1, 0)@ and 'Load' is @(0, 1)@ -- but confirm against the
-- definition of the 'Arity' class.
instance Arity (Inst vector gauge) where
  arity (Imm _) = (0, 1)
  arity (Load _) = (0, 1)
  arity (Store _) = (1, 0)
  arity (Reduce _) = (1, 1)
  arity Broadcast = (1, 1)
  arity (Shift _) = (1, 1)
  arity (LoadIndex _) = (0, 1)
  arity (LoadSize _) = (0, 1)
  arity (Arith op) = arity op