module Feldspar.Run.Compile where
import Control.Monad.Identity
import Control.Monad.Reader
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Constraint (Dict (..))
import Data.Default.Class
import Language.Syntactic hiding ((:+:) (..), (:<:) (..))
import Language.Syntactic.Functional hiding (Binding (..))
import Language.Syntactic.Functional.Tuple
import qualified Control.Monad.Operational.Higher as Oper
import Language.Embedded.Expression
import Language.Embedded.Imperative hiding ((:+:) (..), (:<:) (..))
import Language.Embedded.Concurrent
import qualified Language.Embedded.Imperative as Imp
import Language.Embedded.Backend.C (ExternalCompilerOpts (..))
import qualified Language.Embedded.Backend.C as Imp
import Data.TypedStruct
import Data.Selection
import Feldspar.Primitive.Representation
import Feldspar.Primitive.Backend.C ()
import Feldspar.Representation
import Feldspar.Run.Representation
import Feldspar.Optimize
-- | Translated (target-level) value: a struct whose leaves are primitive
-- 'Prim' expressions, one per primitive component of the source type.
type VExp = Struct PrimType' Prim
-- | 'VExp' with its type index hidden, so that values of different types can
-- be stored in the same alias map (see 'Env').
data VExp'
  where
    VExp' :: Struct PrimType' Prim a -> VExp'
-- | Create one fresh named reference (without an initial value) for every
-- primitive leaf of the given type representation. @base@ is the name prefix
-- passed to 'newNamedRef'.
newRefV :: Monad m => TypeRep a -> String -> TargetT m (Struct PrimType' Imp.Ref a)
newRefV t base = lift $ mapStructA (\_ -> newNamedRef base) t
-- | Create one named reference per primitive leaf of a translated value,
-- initialized with that leaf. @base@ is the name prefix passed to
-- 'initNamedRef'.
initRefV :: Monad m => String -> VExp a -> TargetT m (Struct PrimType' Imp.Ref a)
initRefV base v = lift $ mapStructA (initNamedRef base) v
-- | Read every reference in a struct of references with 'getRef'.
getRefV :: Monad m => Struct PrimType' Imp.Ref a -> TargetT m (VExp a)
getRefV refs = lift (mapStructA getRef refs)
-- | Write a translated value into a struct of references of the same shape,
-- leaf by leaf with 'setRef'.
setRefV :: Monad m => Struct PrimType' Imp.Ref a -> VExp a -> TargetT m ()
setRefV refs v = lift $ sequence_ $ zipListStruct setRef refs v
-- | Read every reference in a struct with 'unsafeFreezeRef'.
--
-- NOTE(review): per imperative-edsl, freezing is only safe as long as the
-- references are not updated while the resulting value is in use.
unsafeFreezeRefV :: Monad m => Struct PrimType' Imp.Ref a -> TargetT m (VExp a)
unsafeFreezeRefV refs = lift (mapStructA unsafeFreezeRef refs)
-- | Options affecting code generation.
data CompilerOpts = CompilerOpts
    { compilerAssertions :: Selection AssertionLabel
        -- ^ Assertion labels whose assertions are kept in the generated code;
        -- assertions with a label outside the selection are dropped
    }
-- | Default options: keep all assertions ('universal' selection).
instance Default CompilerOpts
  where
    def = CompilerOpts universal
-- | Environment used during translation.
data Env = Env
    { envAliases :: Map Name VExp'
        -- ^ Translated values of variables bound in the source expression
        -- (extended by 'localAlias', read by 'lookAlias')
    , envOptions :: CompilerOpts
        -- ^ Compiler options in effect
    }
-- | Initial environment: no aliases and default compiler options.
env0 :: Env
env0 = Env {envAliases = Map.empty, envOptions = def}
-- | Run a computation with a variable bound to the given translated value.
-- An existing binding for the same name is shadowed.
localAlias :: MonadReader Env m
    => Name
    -> VExp a
    -> m b
    -> m b
localAlias v e = local extend
  where
    extend env = env {envAliases = Map.insert v (VExp' e) (envAliases env)}
-- | Look up the translated value of a variable at the expected type.
--
-- Calls 'error' (lazily, when the result is demanded) if the variable is not
-- in scope, or if the stored value's type does not match the expected type
-- representation.
lookAlias :: MonadReader Env m => TypeRep a -> Name -> m (VExp a)
lookAlias t v = do
    env <- asks envAliases
    return $ case Map.lookup v env of
      Nothing -> error $ "lookAlias: variable " ++ show v ++ " not in scope"
      Just (VExp' e) -> case typeEq t (toTypeRep e) of
        Just Dict -> e
        -- Previously an unhandled case, failing with an anonymous
        -- non-exhaustive pattern error
        Nothing -> error $ "lookAlias: type mismatch for variable " ++ show v
-- | Instruction set of the target programs produced by the compiler.
type TargetCMD
    =       RefCMD
    Imp.:+: ArrCMD
    Imp.:+: ControlCMD
    Imp.:+: ThreadCMD
    Imp.:+: ChanCMD
    Imp.:+: PtrCMD
    Imp.:+: FileCMD
    Imp.:+: C_CMD
-- | Monad in which translation happens: a reader of 'Env' on top of a target
-- program over 'TargetCMD', using 'Prim' expressions and 'PrimType'' as the
-- type predicate.
type TargetT m = ReaderT Env (ProgramT TargetCMD (Param2 Prim PrimType') m)
-- | Target program (result of 'translate'); this is what is passed to the
-- backend functions ('Imp.compile', 'Imp.runIO', etc.).
type ProgC = Program TargetCMD (Param2 Prim PrimType')
-- | Translate a Feldspar expression into the target program, returning the
-- result as a struct of primitive expressions.
--
-- The expression is first passed to 'optimize' together with the assertion
-- selection from the compiler options. Constructs that need statements
-- (@let@ bindings, for-loops, conditionals whose branches are not plain
-- expressions, guards and embedded programs) emit commands into the
-- 'TargetT' program. All other constructs are translated to 'Prim'
-- expressions. Calls 'error' for symbols that have no translation.
translateExp :: forall m a . Monad m => Data a -> TargetT m (VExp a)
translateExp a = do
    cs <- asks (compilerAssertions . envOptions)
    goAST $ optimize cs $ unData a
  where
    -- Translate an AST by dispatching on its top-level symbol and result type
    goAST :: ASTF FeldDomain b -> TargetT m (VExp b)
    goAST = simpleMatch (\(s :&: ValT t) -> go t s)

    -- Translate an AST whose result is a single primitive value
    goSmallAST :: PrimType' b => ASTF FeldDomain b -> TargetT m (Prim b)
    goSmallAST = fmap extractSingle . goAST

    go :: TypeRep (DenResult sig)
       -> FeldConstructs sig
       -> Args (AST FeldDomain) sig
       -> TargetT m (VExp (DenResult sig))
    -- Literals become constant expressions, one per primitive leaf
    go t lit Nil
        | Just (Lit a) <- prj lit
        = return $ mapStruct (constExp . runIdentity) $ toStruct t a
    go t lit Nil
        | Just (Literal a) <- prj lit
        = return $ mapStruct (constExp . runIdentity) $ toStruct t a
    -- Variables are looked up among the aliases introduced by let and loops
    go t var Nil
        | Just (VarT v) <- prj var
        = lookAlias t v
    -- let: store the bound value in fresh references named after the tag
    -- (or "let"), then translate the body with the variable aliased to the
    -- frozen reference contents
    go t lt (a :* (lam :$ body) :* Nil)
        | Just (Let tag) <- prj lt
        , Just (LamT v) <- prj lam
        = do let base = if null tag then "let" else tag
             r  <- initRefV base =<< goAST a
             a' <- unsafeFreezeRefV r
             localAlias v a' $ goAST body
    go t tup (a :* b :* Nil)
        | Just Pair <- prj tup = Two <$> goAST a <*> goAST b
    -- Projections. An explicit case is used instead of a failable pattern
    -- bind in do-notation, which would require 'MonadFail' for 'TargetT'.
    go t sel (ab :* Nil)
        | Just Fst <- prj sel = do
            ab' <- goAST ab
            case ab' of
              Two a _ -> return a
              _       -> error "translateExp: Fst applied to a non-pair struct"
        | Just Snd <- prj sel = do
            ab' <- goAST ab
            case ab' of
              Two _ b -> return b
              _       -> error "translateExp: Snd applied to a non-pair struct"
    go _ c Nil
        | Just Pi <- prj c = return $ Single $ sugarSymPrim Pi
    -- Unary primitive operators map leaf-wise onto the same primitive symbol
    go _ op (a :* Nil)
        | Just Neg <- prj op = liftStruct (sugarSymPrim Neg) <$> goAST a
        | Just Abs <- prj op = liftStruct (sugarSymPrim Abs) <$> goAST a
        | Just Sign <- prj op = liftStruct (sugarSymPrim Sign) <$> goAST a
        | Just Exp <- prj op = liftStruct (sugarSymPrim Exp) <$> goAST a
        | Just Log <- prj op = liftStruct (sugarSymPrim Log) <$> goAST a
        | Just Sqrt <- prj op = liftStruct (sugarSymPrim Sqrt) <$> goAST a
        | Just Sin <- prj op = liftStruct (sugarSymPrim Sin) <$> goAST a
        | Just Cos <- prj op = liftStruct (sugarSymPrim Cos) <$> goAST a
        | Just Tan <- prj op = liftStruct (sugarSymPrim Tan) <$> goAST a
        | Just Asin <- prj op = liftStruct (sugarSymPrim Asin) <$> goAST a
        | Just Acos <- prj op = liftStruct (sugarSymPrim Acos) <$> goAST a
        | Just Atan <- prj op = liftStruct (sugarSymPrim Atan) <$> goAST a
        | Just Sinh <- prj op = liftStruct (sugarSymPrim Sinh) <$> goAST a
        | Just Cosh <- prj op = liftStruct (sugarSymPrim Cosh) <$> goAST a
        | Just Tanh <- prj op = liftStruct (sugarSymPrim Tanh) <$> goAST a
        | Just Asinh <- prj op = liftStruct (sugarSymPrim Asinh) <$> goAST a
        | Just Acosh <- prj op = liftStruct (sugarSymPrim Acosh) <$> goAST a
        | Just Atanh <- prj op = liftStruct (sugarSymPrim Atanh) <$> goAST a
        | Just Real <- prj op = liftStruct (sugarSymPrim Real) <$> goAST a
        | Just Imag <- prj op = liftStruct (sugarSymPrim Imag) <$> goAST a
        | Just Magnitude <- prj op = liftStruct (sugarSymPrim Magnitude) <$> goAST a
        | Just Phase <- prj op = liftStruct (sugarSymPrim Phase) <$> goAST a
        | Just Conjugate <- prj op = liftStruct (sugarSymPrim Conjugate) <$> goAST a
        | Just I2N <- prj op = liftStruct (sugarSymPrim I2N) <$> goAST a
        | Just I2B <- prj op = liftStruct (sugarSymPrim I2B) <$> goAST a
        | Just B2I <- prj op = liftStruct (sugarSymPrim B2I) <$> goAST a
        | Just Round <- prj op = liftStruct (sugarSymPrim Round) <$> goAST a
        | Just Not <- prj op = liftStruct (sugarSymPrim Not) <$> goAST a
        | Just BitCompl <- prj op = liftStruct (sugarSymPrim BitCompl) <$> goAST a
    -- Binary primitive operators
    go _ op (a :* b :* Nil)
        | Just Add <- prj op = liftStruct2 (sugarSymPrim Add) <$> goAST a <*> goAST b
        | Just Sub <- prj op = liftStruct2 (sugarSymPrim Sub) <$> goAST a <*> goAST b
        | Just Mul <- prj op = liftStruct2 (sugarSymPrim Mul) <$> goAST a <*> goAST b
        | Just FDiv <- prj op = liftStruct2 (sugarSymPrim FDiv) <$> goAST a <*> goAST b
        | Just Quot <- prj op = liftStruct2 (sugarSymPrim Quot) <$> goAST a <*> goAST b
        | Just Rem <- prj op = liftStruct2 (sugarSymPrim Rem) <$> goAST a <*> goAST b
        | Just Div <- prj op = liftStruct2 (sugarSymPrim Div) <$> goAST a <*> goAST b
        | Just Mod <- prj op = liftStruct2 (sugarSymPrim Mod) <$> goAST a <*> goAST b
        | Just Complex <- prj op = liftStruct2 (sugarSymPrim Complex) <$> goAST a <*> goAST b
        | Just Polar <- prj op = liftStruct2 (sugarSymPrim Polar) <$> goAST a <*> goAST b
        | Just Pow <- prj op = liftStruct2 (sugarSymPrim Pow) <$> goAST a <*> goAST b
        | Just Eq <- prj op = liftStruct2 (sugarSymPrim Eq) <$> goAST a <*> goAST b
        | Just And <- prj op = liftStruct2 (sugarSymPrim And) <$> goAST a <*> goAST b
        | Just Or <- prj op = liftStruct2 (sugarSymPrim Or) <$> goAST a <*> goAST b
        | Just Lt <- prj op = liftStruct2 (sugarSymPrim Lt) <$> goAST a <*> goAST b
        | Just Gt <- prj op = liftStruct2 (sugarSymPrim Gt) <$> goAST a <*> goAST b
        | Just Le <- prj op = liftStruct2 (sugarSymPrim Le) <$> goAST a <*> goAST b
        | Just Ge <- prj op = liftStruct2 (sugarSymPrim Ge) <$> goAST a <*> goAST b
        | Just BitAnd <- prj op = liftStruct2 (sugarSymPrim BitAnd) <$> goAST a <*> goAST b
        | Just BitOr <- prj op = liftStruct2 (sugarSymPrim BitOr) <$> goAST a <*> goAST b
        | Just BitXor <- prj op = liftStruct2 (sugarSymPrim BitXor) <$> goAST a <*> goAST b
        | Just ShiftL <- prj op = liftStruct2 (sugarSymPrim ShiftL) <$> goAST a <*> goAST b
        | Just ShiftR <- prj op = liftStruct2 (sugarSymPrim ShiftR) <$> goAST a <*> goAST b
    go _ arrIx (i :* Nil)
        | Just (ArrIx arr) <- prj arrIx = do
            i' <- goSmallAST i
            return $ Single $ sugarSymPrim (ArrIx arr) i'
    -- Conditionals. Both branches are first inspected with 'Oper.viewT'. If
    -- both are plain single expressions (no commands), a 'Cond' expression is
    -- produced. Otherwise the branches are emitted as statements inside 'iff',
    -- each writing its result to a shared result reference.
    go ty cond (c :* t :* f :* Nil)
        | Just Cond <- prj cond = do
            env <- ask
            let t' = flip runReaderT env $ goAST t
                f' = flip runReaderT env $ goAST f
            tView <- lift $ lift $ Oper.viewT t'
            fView <- lift $ lift $ Oper.viewT f'
            -- Both branches below start by translating the condition
            c' <- goSmallAST c
            case (tView, fView) of
              (Oper.Return (Single tExp), Oper.Return (Single fExp)) ->
                return $ Single $ sugarSymPrim Cond c' tExp fExp
              _ -> do
                res <- newRefV ty "v"
                ReaderT $ \env' -> iff c'
                  (flip runReaderT env' . setRefV res =<< t')
                  (flip runReaderT env' . setRefV res =<< f')
                unsafeFreezeRefV res
    -- NOTE(review): balanced division is compiled to truncating 'Quot';
    -- presumably the operands are known to divide evenly — confirm
    go t divBal (a :* b :* Nil)
        | Just DivBalanced <- prj divBal
        = liftStruct2 (sugarSymPrim Quot) <$> goAST a <*> goAST b
    -- Guards become a run-time assertion when their label is selected. The
    -- guarded value is translated either way.
    go t guard (cond :* a :* Nil)
        | Just (GuardVal c msg) <- prj guard
        = do cs <- asks (compilerAssertions . envOptions)
             when (cs `includes` c) $ do
                 cond' <- extractSingle <$> goAST cond
                 lift $ assert cond' msg
             goAST a
    -- For-loops keep the loop state in references named "state". In each
    -- iteration, the index and the current state are aliased for the body.
    go t loop (len :* init :* (lami :$ (lams :$ body)) :* Nil)
        | Just ForLoop <- prj loop
        , Just (LamT iv) <- prj lami
        , Just (LamT sv) <- prj lams
        = do len'  <- goSmallAST len
             state <- initRefV "state" =<< goAST init
             ReaderT $ \env -> for (0, 1, Excl len') $ \i -> flip runReaderT env $ do
                 s <- case t of
                   Single _ -> unsafeFreezeRefV state
                   _        -> getRefV state
                 s' <- localAlias iv (Single i) $
                       localAlias sv s $
                       goAST body
                 setRefV state s'
             unsafeFreezeRefV state
    go _ free Nil
        | Just (FreeVar v) <- prj free = return $ Single $ sugarSymPrim $ FreeVar v
    -- Embedded programs are spliced into the target program (with their
    -- expressions reexpressed), then their result expression is translated
    go t unsPerf Nil
        | Just (UnsafePerform prog) <- prj unsPerf
        = translateExp =<<
            Oper.reexpressEnv unsafeTransSmallExp (Oper.liftProgram $ unComp prog)
    go _ s _ = error $ "translateExp: no handling of symbol " ++ renderSym s
-- | Translate an expression that is expected to produce a single primitive
-- value (a 'Single' struct).
--
-- Used to reexpress the expressions of embedded programs. Calls 'error' if
-- the translated value is not a 'Single'. An explicit case is used rather
-- than a failable pattern bind in do-notation, so no 'MonadFail' instance
-- for 'TargetT' is required.
unsafeTransSmallExp :: Monad m => Data a -> TargetT m (Prim a)
unsafeTransSmallExp a = do
    s <- translateExp a
    case s of
      Single b -> return b
      _        -> error "unsafeTransSmallExp: expected a single primitive value"
-- | Translate a 'Run' program into a target program ('ProgC') in the given
-- environment.
--
-- The 'Run' instructions are interpreted one by one. The expressions of each
-- instruction are reexpressed with 'unsafeTransSmallExp' and the instruction
-- is run in @env@. The resulting layers are then reexpressed and flattened
-- into a single 'ProgC' program.
translate :: Env -> Run a -> ProgC a
translate env prog
    = Oper.interpretWithMonadT Oper.singleton id
    $ flip runReaderT env
    $ Oper.reexpressEnv unsafeTransSmallExp
    $ Oper.interpretWithMonadT Oper.singleton
        (lift . flip runReaderT env . Oper.reexpressEnv unsafeTransSmallExp)
    $ unRun prog
-- | Reexpress assertions as 'Imp.assert' commands. An assertion is dropped
-- when its label is not included in 'compilerAssertions'.
instance (Imp.ControlCMD Oper.:<: instr) =>
    Oper.Reexpressible AssertCMD instr Env
  where
    reexpressInstrEnv reexp (Assert c cond msg) = do
        selected <- asks (compilerAssertions . envOptions)
        when (selected `includes` c) $ do
            cond' <- reexp cond
            lift $ Imp.assert cond' msg
-- | Run a program in 'IO' by translating it with 'env0' (default options)
-- and interpreting the result with 'Imp.runIO'.
runIO :: MonadRun m => m a -> IO a
runIO prog = Imp.runIO (translate env0 (liftRun prog))
-- | Run a program in 'IO' directly, without going through 'translate'.
-- Expressions are evaluated with 'evalExp'.
runIO' :: MonadRun m => m a -> IO a
runIO' prog =
    Oper.interpretWithMonadBiT
      (return . evalExp)
      Oper.interpBi
      (Imp.interpretBi (return . evalExp))
      (unRun (liftRun prog))
-- | Translate a program with 'env0' and run it with 'Imp.captureIO' on the
-- given input string.
-- NOTE(review): presumably the string is fed as stdin and stdout is returned.
captureIO :: MonadRun m
    => m a
    -> String
    -> IO String
captureIO prog = Imp.captureIO (translate env0 (liftRun prog))
-- | Generate C code for a program using the given compiler options.
compile' :: MonadRun m => CompilerOpts -> m a -> String
compile' opts prog = Imp.compile $ translate (Env Map.empty opts) $ liftRun prog
-- | Generate C code for a program, keeping only user assertions.
compile :: MonadRun m => m a -> String
compile = compile' userOnly
  where
    userOnly = def {compilerAssertions = onlyUserAssertions}
-- | Generate C code with the given compiler options via 'Imp.compileAll'.
-- NOTE(review): the pairs are presumably (file name, file contents).
compileAll' :: MonadRun m => CompilerOpts -> m a -> [(String, String)]
compileAll' opts prog = Imp.compileAll $ translate (Env Map.empty opts) $ liftRun prog
-- | 'compileAll'' keeping only user assertions.
compileAll :: MonadRun m => m a -> [(String, String)]
compileAll = compileAll' userOnly
  where
    userOnly = def {compilerAssertions = onlyUserAssertions}
-- | Generate C code with the given compiler options and hand it to
-- 'Imp.icompile' (presumably printing it).
icompile' :: MonadRun m => CompilerOpts -> m a -> IO ()
icompile' opts prog = Imp.icompile $ translate (Env Map.empty opts) $ liftRun prog
-- | 'icompile'' keeping only user assertions.
icompile :: MonadRun m => m a -> IO ()
icompile = icompile' userOnly
  where
    userOnly = def {compilerAssertions = onlyUserAssertions}
-- | Translate with the given compiler options and pass the result to
-- 'Imp.compileAndCheck'' with the given external compiler options.
compileAndCheck' :: MonadRun m
    => CompilerOpts
    -> ExternalCompilerOpts
    -> m a
    -> IO ()
compileAndCheck' opts eopts prog =
    Imp.compileAndCheck' eopts $ translate (Env Map.empty opts) $ liftRun prog
-- | 'compileAndCheck'' with default compiler and external-compiler options
-- (all assertions kept).
compileAndCheck :: MonadRun m => m a -> IO ()
compileAndCheck prog = compileAndCheck' def def prog
-- | Translate with the given compiler options and pass the result to
-- 'Imp.runCompiled'' with the given external compiler options.
runCompiled' :: MonadRun m
    => CompilerOpts
    -> ExternalCompilerOpts
    -> m a
    -> IO ()
runCompiled' opts eopts prog =
    Imp.runCompiled' eopts $ translate (Env Map.empty opts) $ liftRun prog
-- | 'runCompiled'' with default compiler and external-compiler options.
runCompiled :: MonadRun m => m a -> IO ()
runCompiled prog = runCompiled' def def prog
-- | Translate with the given compiler options and pass the result to
-- 'Imp.withCompiled''. The continuation receives a function from an input
-- string to an output string.
-- NOTE(review): presumably that function runs the compiled executable.
withCompiled' :: MonadRun m
    => CompilerOpts
    -> ExternalCompilerOpts
    -> m a
    -> ((String -> IO String) -> IO b)
    -> IO b
withCompiled' opts eopts prog =
    Imp.withCompiled' eopts $ translate (Env Map.empty opts) $ liftRun prog
-- | 'withCompiled'' with default compiler options and default external
-- options, except that 'externalSilent' is set.
withCompiled :: MonadRun m
    => m a
    -> ((String -> IO String) -> IO b)
    -> IO b
withCompiled = withCompiled' def silent
  where
    silent = def {externalSilent = True}
-- | Translate with the given compiler options and pass the result and the
-- input string to 'Imp.captureCompiled''.
captureCompiled' :: MonadRun m
    => CompilerOpts
    -> ExternalCompilerOpts
    -> m a
    -> String
    -> IO String
captureCompiled' opts eopts prog =
    Imp.captureCompiled' eopts $ translate (Env Map.empty opts) $ liftRun prog
-- | 'captureCompiled'' with default compiler and external-compiler options.
captureCompiled :: MonadRun m
    => m a
    -> String
    -> IO String
captureCompiled prog = captureCompiled' def def prog
-- | Translate with the given compiler options and pass the result to
-- 'Imp.compareCompiled''. The reference 'IO' action and the input string
-- are passed through unchanged.
compareCompiled' :: MonadRun m
    => CompilerOpts
    -> ExternalCompilerOpts
    -> m a
    -> IO a
    -> String
    -> IO ()
compareCompiled' opts eopts prog =
    Imp.compareCompiled' eopts $ translate (Env Map.empty opts) $ liftRun prog
-- | 'compareCompiled'' with default compiler and external-compiler options.
compareCompiled :: MonadRun m
    => m a
    -> IO a
    -> String
    -> IO ()
compareCompiled prog = compareCompiled' def def prog