{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE Strict #-}

-- | Definition of a polymorphic (generic) pass that can work with
-- programs of any lore.
module Futhark.Pass
  ( PassM,
    runPassM,
    liftEither,
    liftEitherM,
    Pass (..),
    passLongOption,
    parPass,
    intraproceduralTransformation,
    intraproceduralTransformationWithConsts,
  )
where

import Control.Monad.State.Strict
import Control.Monad.Writer.Strict
import Control.Parallel.Strategies
import Data.Char
import Data.Either
import Futhark.Error
import Futhark.IR
import Futhark.MonadFreshNames
import Futhark.Util.Log
import Prelude hiding (log)

-- | The monad in which passes execute.
--
-- It is a writer of 'Log' messages layered over a state monad that
-- threads the 'VNameSource' used for generating fresh names.
newtype PassM a = PassM (WriterT Log (State VNameSource) a)
  deriving (Functor, Applicative, Monad)

-- Log messages are accumulated in the underlying 'WriterT' layer.
instance MonadLogger PassM where
  addLog = PassM . tell

-- The name source lives in the 'State' layer beneath the 'WriterT'.
instance MonadFreshNames PassM where
  putNameSource = PassM . put
  getNameSource = PassM get

-- | Execute a 'PassM' action inside any 'MonadFreshNames' monad.
--
-- The action is run via 'modifyNameSource', starting from the
-- enclosing monad's name source, which is then replaced by the
-- name source the action ends with.  Yields the action's result
-- together with the 'Log' it produced.
runPassM ::
  MonadFreshNames m =>
  PassM a ->
  m (a, Log)
runPassM (PassM m) = modifyNameSource $ runState (runWriterT m)

-- | Turn an 'Either' computation into a 'PassM'.  If the 'Either' is
-- 'Left', the error is rendered with 'show' and reported via
-- 'compilerBugS'; a 'Right' value is returned unchanged.
liftEither :: Show err => Either err a -> PassM a
liftEither (Left e) = compilerBugS $ show e
liftEither (Right v) = return v

-- | Turn an 'Either' monadic computation into a 'PassM'.  The action
-- is run and its result is handed to 'liftEither', so a 'Left' is
-- reported exactly as 'liftEither' reports it.
liftEitherM :: Show err => PassM (Either err a) -> PassM a
liftEitherM m = liftEither =<< m

-- | A compiler pass transforming a 'Prog' of a given lore to a 'Prog'
-- of another lore.
data Pass fromlore tolore = Pass
  { -- | Name of the pass.  Keep this short and simple.  It will
    -- be used to automatically generate a command-line option
    -- name via 'passLongOption'.
    passName :: String,
    -- | A slightly longer description, which will show up in the
    -- command-line help text.
    passDescription :: String,
    -- | The transformation performed by the pass.
    passFunction :: Prog fromlore -> PassM (Prog tolore)
  }

-- | Take the name of the pass, turn spaces into dashes, and make all
-- characters lowercase.  For example, a pass named
-- @"Simplify SOACs"@ yields @"simplify-soacs"@.
passLongOption :: Pass fromlore tolore -> String
passLongOption = map (spaceToDash . toLower) . passName
  where
    spaceToDash :: Char -> Char
    spaceToDash ' ' = '-'
    spaceToDash c = c

-- | Apply a 'PassM' operation in parallel (via 'parMap' 'rpar') to
-- multiple elements.
--
-- Every element is processed starting from the /same/ name source.
-- The name sources produced by the individual runs are combined with
-- 'mconcat' and installed as the new name source, and the logs are
-- concatenated in input order and emitted with 'addLog'.
--
-- NOTE(review): correctness relies on the 'Monoid' instance of
-- 'VNameSource' yielding a source that is fresh with respect to all
-- of the combined ones (presumably it keeps the largest counter) --
-- confirm in "Futhark.FreshNames".
parPass :: (a -> PassM b) -> [a] -> PassM [b]
parPass f as = do
  (x, log) <- modifyNameSource $ \src ->
    let (bs, logs, srcs) = unzip3 $ parMap rpar (f' src) as
     in ((bs, mconcat logs), mconcat srcs)

  addLog log
  return x
  where
    -- Run the pass on one element from the given name source, returning
    -- its result, its log, and the name source it finished with.
    f' src a =
      let ((x', log), src') = runState (runPassM (f a)) src
       in (x', log, src')

-- | Apply some operation to the top-level constants.  Then apply an
-- operation to all the function definitions, which are also given
-- the transformed constants so they can be brought into scope.
-- The function definition transformations are run in parallel (with
-- 'parPass'), since they cannot affect each other.
intraproceduralTransformationWithConsts ::
  (Stms fromlore -> PassM (Stms tolore)) ->
  (Stms tolore -> FunDef fromlore -> PassM (FunDef tolore)) ->
  Prog fromlore ->
  PassM (Prog tolore)
intraproceduralTransformationWithConsts ct ft (Prog consts funs) = do
  -- Constants first: the function transformation needs their
  -- transformed form.
  consts' <- ct consts
  funs' <- parPass (ft consts') funs
  return $ Prog consts' funs'

-- | Like 'intraproceduralTransformationWithConsts', but with a single
-- statement transformation for both constants and function bodies.
--
-- The top-level constants are transformed with an empty 'Scope'.  Each
-- function body is then transformed in a scope consisting of the
-- transformed constants plus the function's parameters, and the
-- function definition is otherwise left unchanged.
intraproceduralTransformation ::
  (Scope lore -> Stms lore -> PassM (Stms lore)) ->
  Prog lore ->
  PassM (Prog lore)
intraproceduralTransformation f =
  intraproceduralTransformationWithConsts (f mempty) f'
  where
    -- Transform the statements of one function body and put them back.
    f' consts fd = do
      stms <-
        f
          (scopeOf consts <> scopeOfFParams (funDefParams fd))
          (bodyStms $ funDefBody fd)
      return fd {funDefBody = (funDefBody fd) {bodyStms = stms}}