{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
module Futhark.Pass
( PassM
, runPassM
, liftEither
, liftEitherM
, Pass (..)
, passLongOption
, intraproceduralTransformation
) where
import Control.Monad.Writer.Strict
import Control.Monad.Except hiding (liftEither)
import Control.Monad.State.Strict
import Control.Parallel.Strategies
import Data.Char
import Data.Either
import Prelude hiding (log)
import Futhark.Error
import Futhark.Representation.AST
import Futhark.Util.Log
import Futhark.MonadFreshNames
-- | The monad in which compiler passes run.  From the outside in, it
-- layers: failure with an 'InternalError', an accumulated 'Log', and a
-- threaded 'VNameSource'.  Because 'ExceptT' is outermost, both the log
-- and the name source are still available after an error is thrown.
newtype PassM a =
  PassM (ExceptT InternalError
           (WriterT Log
              (State VNameSource)) a)
  deriving ( Functor
           , Applicative
           , Monad
           , MonadError InternalError
           )
-- | Log entries are appended to the 'WriterT' layer of 'PassM'.
instance MonadLogger PassM where
  addLog msgs = PassM (tell msgs)
-- | The name source is kept in the innermost 'State' layer of 'PassM'.
instance MonadFreshNames PassM where
  getNameSource = PassM get
  putNameSource src = PassM (put src)
-- | Run a 'PassM' computation in any monad that has a name source.
-- The caller's name source is threaded through the computation via
-- 'modifyNameSource'.  A thrown 'InternalError' is returned as 'Left'
-- rather than propagated, and the accumulated 'Log' is returned in
-- both the success and the failure case.
runPassM :: MonadFreshNames m =>
            PassM a -> m (Either InternalError a, Log)
runPassM (PassM m) =
  modifyNameSource (runState (runWriterT (runExceptT m)))
-- | Embed an 'Either' into 'PassM'.  A 'Right' value is returned as is.
-- A 'Left' is treated as a compiler bug: its 'show'n form is passed to
-- 'compilerBugS'.
liftEither :: Show err => Either err a -> PassM a
liftEither = either (compilerBugS . show) return
-- | Run a 'PassM' action that yields an 'Either', then handle the result
-- as 'liftEither' does.
liftEitherM :: Show err => PassM (Either err a) -> PassM a
liftEitherM action = do
  result <- action
  liftEither result
-- | A compiler pass that turns a program of one lore into a program of
-- (possibly) another lore.
data Pass fromlore tolore =
  Pass { passName :: String
         -- ^ Name of the pass; 'passLongOption' derives an option
         -- name from it.
       , passDescription :: String
         -- ^ Free-form description of the pass.
       , passFunction :: Prog fromlore -> PassM (Prog tolore)
         -- ^ The program transformation performed by the pass.
       }
-- | Derive a long-option name from a pass's 'passName'.  Each character
-- is lowercased, and spaces are then replaced by dashes.
passLongOption :: Pass fromlore tolore -> String
passLongOption pass =
  [ if lower == ' ' then '-' else lower
  | c <- passName pass
  , let lower = toLower c
  ]
-- | Lift a per-function transformation to a whole-program pass.
--
-- Every function in the program is transformed independently, sparked
-- with 'parMap' 'rpar', and each transformation starts from the same
-- incoming name source.
--
-- If any function fails, the first failure (in program order) is
-- rethrown together with that function's log.  The logs of the other
-- functions are discarded.  Otherwise the logs of all functions are
-- added in order, the resulting name sources are combined with their
-- 'Monoid' instance, and the transformed program is returned.
intraproceduralTransformation :: (FunDef fromlore -> PassM (FunDef tolore))
                              -> Prog fromlore -> PassM (Prog tolore)
intraproceduralTransformation ft prog =
  either onError onSuccess <=< modifyNameSource $ \src ->
    case partitionEithers $ parMap rpar (onFunction src) (progFunctions prog) of
      ([], rs) ->
        let (funs, logs, srcs) = unzip3 rs
        -- Include the incoming source: with no functions, 'mconcat srcs'
        -- would be 'mempty' and the caller's name source would be lost.
        -- NOTE(review): every element of srcs was derived from src, so
        -- this assumes combining is monotone (e.g. max of counters) and
        -- is therefore a no-op when there are functions.
        in (Right (Prog funs, mconcat logs), mconcat (src : srcs))
      ((err, log, src') : _, _) -> (Left (err, log), src')
  where onFunction src f = case runState (runPassM (ft f)) src of
          ((Left x, log), src') -> Left (x, log, src')
          ((Right x, log), src') -> Right (x, log, src')
        onError (err, log) = addLog log >> throwError err
        onSuccess (x, log) = addLog log >> return x