{-# LANGUAGE DeriveDataTypeable        #-}
{-# LANGUAGE FlexibleContexts          #-}
{-# LANGUAGE FlexibleInstances         #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE ScopedTypeVariables       #-}
{-# LANGUAGE TupleSections             #-}
{-# LANGUAGE TypeSynonymInstances      #-}

module Language.Haskell.Liquid.Transforms.Rec (
     transformRecExpr, transformScope
     , outerScTr , innerScTr
     , isIdTRecBound, setIdTRecBound
     ) where

-- import           Bag
-- import           ErrUtils
import           Coercion
import           Control.Arrow                        (second)
import           Control.Monad.State
import           CoreSyn
import           CoreUtils
import qualified Data.HashMap.Strict                  as M
import           Data.Hashable
import           Id
import           IdInfo
import           Language.Haskell.Liquid.GHC.API      hiding (exprType)
import           Language.Haskell.Liquid.GHC.Misc
import           Language.Haskell.Liquid.GHC.Play
import           Language.Haskell.Liquid.Misc         (mapSndM)
import           Language.Fixpoint.Misc               (mapSnd) -- , traceShow)
import           Language.Haskell.Liquid.Types.Errors
import           MkCore                               (mkCoreLams)
import           Prelude                              hiding (error)

import qualified Data.List                            as L


-- | Entry point of the recursion transformation.
--
-- Runs three passes over the program, in this order:
--
--   1. 'inlineLoopBreaker' on every top-level binding;
--   2. 'transPg', evaluated from 'initEnv';
--   3. 'inlineFailCases'.
transformRecExpr :: CoreProgram -> CoreProgram
transformRecExpr cbs = pg
  -- TODO-REBARE weird GHC crash on Data/Text/Array.hs | isEmptyBag $ filterBag isTypeError e
  -- TODO-REBARE weird GHC crash on Data/Text/Array.hs = pg
  -- TODO-REBARE weird GHC crash on Data/Text/Array.hs | otherwise
  -- TODO-REBARE weird GHC crash on Data/Text/Array.hs = panic Nothing ("Type-check" ++ showSDoc (pprMessageBag e))
  where
    pg  = inlineFailCases pg0
    pg0 = evalState (transPg (inlineLoopBreaker <$> cbs)) initEnv
    -- (_, e) = lintCoreBindings [] pg




-- | Make a loop-breaker-wrapped binding explicitly recursive.
--
-- A non-recursive binding of the shape
--
-- > x = \tvs vs -> letrec lb = rhs in lb
--
-- where @lb@ is a strong loop breaker (see 'isStrongLoopBreaker') is turned
-- into
--
-- > rec x = \tvs vs -> rhs[lb := x tvs vs]
--
-- Every other binding is returned unchanged.
inlineLoopBreaker :: Bind Id -> Bind Id
inlineLoopBreaker (NonRec x e)
  | Just (lbx, lbe) <- hasLoopBreaker body
  = Rec [(x, mkCoreLams (tvs ++ vs) (sub (M.singleton lbx selfCall) lbe))]
  where
    (tvs, vs, body) = collectTyAndValBinders e

    -- x re-applied to its own type and value binders; this replaces every
    -- occurrence of the loop breaker inside its right-hand side.
    selfCall = appTysAndIds tvs vs x

    -- Recognise exactly @letrec lb = rhs in lb@ with @lb@ a loop breaker.
    hasLoopBreaker (Let (Rec [(x1, e1)]) (Var x2))
      | isLoopBreaker x1 && x1 == x2 = Just (x1, e1)
    hasLoopBreaker _                 = Nothing

    isLoopBreaker = isStrongLoopBreaker . occInfo . idInfo

inlineLoopBreaker bs
  = bs

-- | Inline @fail@ bindings at their call sites.
--
-- A binder counts as a @fail@ binder when it is a local Id with a system
-- name whose 'show' begins with @"fail"@.  A non-recursive @let@ of such a
-- binder is dropped.  Its right-hand side must be a lambda; the lambda's
-- body (with the binder ignored) replaces every application of the @fail@
-- binder to an argument inside the @let@ body, and that argument is
-- discarded.  A non-lambda right-hand side triggers 'impossible'.
inlineFailCases :: CoreProgram -> CoreProgram
inlineFailCases = fmap (goBind [])
  where
    -- su: association list from in-scope fail binders to the lambda bodies
    -- that should replace their applications.
    goBind su (Rec xes)    = Rec (mapSnd (goExpr su) <$> xes)
    goBind su (NonRec x e) = NonRec x (goExpr su e)

    goExpr su (App (Var x) _)
      | isFailId x, Just body <- L.lookup x su = body
    goExpr su (Let (NonRec x rhs) body)
      | isFailId x = goExpr (addFailExpr x (goExpr su rhs) su) body
    goExpr su (App f a)         = App (goExpr su f) (goExpr su a)
    goExpr su (Lam x body)      = Lam x (goExpr su body)
    goExpr su (Let b body)      = Let (goBind su b) (goExpr su body)
    goExpr su (Case e x t alts) = Case (goExpr su e) x t (goAlt su <$> alts)
    goExpr su (Cast e c)        = Cast (goExpr su e) c
    goExpr su (Tick t e)        = Tick t (goExpr su e)
    goExpr _  e                 = e

    goAlt su (c, xs, e) = (c, xs, goExpr su e)

    isFailId x = isLocalId x
              && isSystemName (varName x)
              && "fail" `L.isPrefixOf` show x

    addFailExpr x (Lam _ body) su = (x, body) : su
    addFailExpr _ _            _  = impossible Nothing "internal error" -- this cannot happen

-- isTypeError :: SDoc -> Bool
-- isTypeError s | isInfixOf "Non term variable" (showSDoc s) = False
-- isTypeError _ = True

-- | Apply 'innerScTr' to every binding, then 'outerScTr' to the resulting
-- list of top-level bindings.
--
-- No need for this transformation after ghc-8!!!
transformScope :: [Bind Id] -> [Bind Id]
transformScope = outerScTr . innerScTr

-- | Top-level counterpart of 'scTrans'.  For each non-recursive binder @x@
-- handed over by 'mapNonRec' (defined elsewhere), take the maximal run of
-- the following bindings that 'isCaseArg' recognises as @case x of ...@ and
-- put that run back in front of the remaining bindings.
--
-- NOTE(review): the run is re-emitted in reverse order (accumulated with
-- @(:)@ and never reversed).  If one binding of the run refers to an
-- earlier one, this could break scoping; confirm this is intended.
outerScTr :: [Bind Id] -> [Bind Id]
outerScTr = mapNonRec (collect [])
  where
    collect acc x (b : rest)
      | isCaseArg x b = collect (b : acc) x rest
    collect acc _ rest = acc ++ rest

-- | @isCaseArg x b@ holds exactly when @b@ is a non-recursive binding whose
-- right-hand side is a @case@ expression scrutinising the variable @x@.
isCaseArg :: Id -> Bind t -> Bool
isCaseArg x (NonRec _ (Case (Var z) _ _ _)) = z == x
isCaseArg _ _                               = False

-- | Map @mapBnd scTrans@ over every binding in the container.  'mapBnd' is
-- defined elsewhere; presumably it applies 'scTrans' to each
-- (binder, right-hand side) pair of the binding.
innerScTr :: Functor f => f (Bind Id) -> f (Bind Id)
innerScTr = fmap (mapBnd scTrans)

-- | @scTrans x e@ hoists the chain of @let@s at the head of @e@ whose
-- bindings scrutinise @x@ (see 'isCaseArg').  Hoisting looks through
-- 'Tick's, so the hoisted bindings end up outside those ticks.  The chain
-- stops at the first @let@ that does not match.  The rebuilt expression is
-- then passed to @mapExpr scTrans@ (defined elsewhere), which presumably
-- repeats the rewrite at nested binders.
--
-- NOTE(review): the hoisted bindings come out in reverse order (accumulated
-- with @(:)@ and rebuilt with 'foldr').  Confirm that none of them refers
-- to another.
scTrans :: Id -> Expr Id -> Expr Id
scTrans x e = mapExpr scTrans $ foldr Let inner hoisted
  where
    (hoisted, inner) = peel [] e

    peel acc (Let b body)
      | isCaseArg x b      = peel (b : acc) body
    peel acc (Tick t body) = second (Tick t) $ peel acc body
    peel acc other         = (acc, other)

-- | The monad the polymorphic-recursion rewrite runs in.
type TE = State TrEnv

-- | State threaded through the rewrite.
data TrEnv = Tr { freshIndex  :: !Int      -- ^ counter; presumably consumed by 'fresh' (defined elsewhere) — confirm
                , _loc        :: SrcSpan   -- ^ source location; not read by the code visible here
                }

-- | Initial state: counter at 0 and no source location.
initEnv :: TrEnv
initEnv = Tr 0 noSrcSpan

-- | Run 'transBd' on every binding of a program, threading the 'TrEnv'
-- state through the bindings in traversal order.
transPg :: Traversable t
        => t (Bind CoreBndr)
        -> State TrEnv (t (Bind CoreBndr))
transPg = traverse transBd

-- | Transform a single binding.
--
-- For a non-recursive binding, the right-hand side is first passed through
-- @mapBdM transBd@ and the result through 'transExpr'.  For a recursive
-- group, each right-hand side only goes through @mapBdM transBd@;
-- 'transExpr' is not applied to it.
--
-- NOTE(review): 'mapBdM' is defined elsewhere, and its inferred type is
-- @Monad m => t -> a -> m a@.  A function of that type cannot actually
-- invoke 'transBd' on nested bindings; confirm this is intended.
transBd :: Bind CoreBndr
        -> State TrEnv (Bind CoreBndr)
transBd (NonRec x e) = NonRec x <$> (transExpr =<< mapBdM transBd e)
transBd (Rec xes)    = Rec <$> mapM (mapSndM (mapBdM transBd)) xes

-- | Rewrite an expression of the shape
--
-- > \tvs ids -> let bs in letrec xes in body
--
-- where @bs@ is the leading chain of non-recursive lets.  The rewrite goes
-- through 'trans' when there is at least one type binder and some
-- right-hand side of @xes@ is not polymorphic (see 'isNonPolyRec').  Every
-- other expression is returned unchanged.
transExpr :: CoreExpr -> TE CoreExpr
transExpr e
  | isNonPolyRec recE && not (null tvs) = trans tvs ids bs recE
  | otherwise                           = return e
  where
    (tvs, ids, body) = collectTyAndValBinders e
    (bs, recE)       = collectNonRecLets body

-- | Is this a @letrec@ in which at least one right-hand side has a type with
-- no outermost @forall@ (see 'nonPoly')?
isNonPolyRec :: Expr CoreBndr -> Bool
isNonPolyRec (Let (Rec xes) _) = any (nonPoly . snd) xes
isNonPolyRec _                 = False

-- | True when 'splitForAllTys' finds no leading @forall@ binders in the
-- expression's type.
nonPoly :: CoreExpr -> Bool
nonPoly expr = null tyBinders
  where
    (tyBinders, _) = splitForAllTys (exprType expr)

-- | Split off the leading chain of non-recursive @let@s.  Returns those
-- bindings outermost-first, together with the expression beneath them.
collectNonRecLets :: Expr t -> ([Bind t], Expr t)
collectNonRecLets = go []
  where
    go acc (Let b@(NonRec _ _) rest) = go (b : acc) rest
    go acc rest                      = (reverse acc, rest)

-- | @appTysAndIds tvs ids x@ is @x@ applied first to the type arguments
-- @tvs@ and then to the value arguments @ids@.
appTysAndIds :: [Var] -> [Id] -> Id -> Expr b
appTysAndIds tvs ids x = mkApps (mkTyApps (Var x) (TyVarTy <$> tvs)) (Var <$> ids)

-- | Prepare a @letrec@ for 'makeTrans' and re-abstract the result.
--
-- The inputs are the type binders @vs@ and value binders @ids@ peeled off
-- the expression, the non-recursive lets @bs@ found under them, and the
-- remaining @letrec xes in e@.  The function:
--
--   * copies the lets @bs@ into every right-hand side of @xes@;
--   * runs 'makeTrans' with @vs@ and the 'mkAlive' versions of @ids@;
--   * wraps the result in @bs@ again, then in lambdas over @vs@ followed by
--     the 'mkAlive' versions of @ids@.
--
-- Panics (via 'panic') on any other input shape.
trans :: Foldable t
      => [TyVar]
      -> [Var]
      -> t (Bind Id)
      -> Expr Var
      -> State TrEnv (Expr Id)
trans vs ids bs (Let (Rec xes) e)
  = wrapLams . wrapLets <$> makeTrans vs liveIds recE
  where
    liveIds    = mkAlive <$> ids
    wrapLets b = foldr Let b bs
    wrapLams b = mkCoreLams (vs ++ liveIds) b
    recE       = Let (Rec (second wrapLets <$> xes)) e

trans _ _ _ _ = panic Nothing "TransformRec.trans called with invalid input"

-- | Rewrite @letrec xs = es in e@, which sits under type binders @vs@ and
-- value binders @ids@:
--
--   * every @x@ gets a new binder @y@ from 'mkFreshIds', which also
--     supplies a private copy @ids'@ of @ids@;
--   * the group becomes @rec y = \\vs ids' -> rhs@.  Each @rhs@ is the
--     original right-hand side with every @x@ replaced by @y vs ids'@ and
--     every @ids@ replaced by @ids'@ (substitution built by 'mkSubs');
--   * in @e@ every @x@ is replaced by a fresh @y'@ (from 'fresh').  Each
--     @y'@ is bound non-recursively to @y vs ids@ just inside the new group
--     (see 'mkRecBinds').
--
-- Panics (via 'panic') on any other input shape.
makeTrans :: [TyVar]
          -> [Var]
          -> Expr Var
          -> State TrEnv (Expr Var)
makeTrans vs ids (Let (Rec xes) e)
  = do fids <- mapM (mkFreshIds vs ids) xs
       let (ids', ys) = unzip fids
       -- Keep this after mkFreshIds: the order of fresh-name generation
       -- matters for the names produced.
       ys' <- mapM fresh xs
       let calls    = appTysAndIds vs ids <$> ys
           outer    = zip ys' calls
           bodySu   = M.fromList (zip xs (Var <$> ys'))
           recGroup = Rec (zip ys (zipWith (mkE ys) ids' es))
       return (mkRecBinds outer recGroup (sub bodySu e))
  where
    (xs, es) = unzip xes
    -- Re-abstract one right-hand side over vs and its own fresh ids'.
    mkE ys ids' rhs =
      mkCoreLams (vs ++ ids') (sub (mkSubs ids vs ids' (zip xs ys)) rhs)

makeTrans _ _ _ = panic Nothing "TransformRec.makeTrans called with invalid input"

-- | @mkRecBinds [(y1, e1), .., (yn, en)] rs e@ builds
--
-- > let rs in let yn = en in ... let y1 = e1 in e
--
-- The binding @rs@ is outermost.  Below it come the non-recursive
-- bindings, with the first pair innermost.
mkRecBinds :: [(b, Expr b)] -> Bind b -> Expr b -> Expr b
mkRecBinds xes rs e = Let rs (L.foldl' wrap e xes)
  where
    wrap body (x, xe) = Let (NonRec x xe) body

-- | Substitution that 'makeTrans' applies to each right-hand side.
--
-- It contains two kinds of entry:
--
--   * every pair @(k, y)@ of @ys@ maps @k@ to @y@ applied to @tvs@ and
--     then @xs@ (see 'appTysAndIds');
--   * the i-th element of @ids@ maps to @Var@ of the i-th element of @xs@
--     ('zip' truncates to the shorter list).
--
-- 'M.fromList' keeps the last value for a repeated key, so on a clash the
-- @ids@ entries take precedence.
mkSubs :: (Eq k, Hashable k)
       => [k] -> [Var] -> [Id] -> [(k, Id)] -> M.HashMap k (Expr b)
mkSubs ids tvs xs ys = M.fromList (s1 ++ s2)
  where
    s1 = second (appTysAndIds tvs xs) <$> ys
    s2 = zip ids (Var <$> xs)

-- | @mkFreshIds tvs ids x@ makes fresh copies of the binders @ids@ and
-- re-types @x@ so that it also abstracts over them.
--
-- * Each binder in @ids@ gets a new 'Unique' (via 'fresh'). It is then marked
--   with 'setIdTRecBound', so the copies can later be recognised as introduced
--   by this transformation.
-- * @x@ is given the type @t1 -> ... -> tn -> tx@, quantified over @tvs@ with
--   'Required' binders. Here @ti@ is the type of the i-th fresh binder and
--   @tx@ is the original type of @x@.
--
-- Returns the fresh binders (in the same order as @ids@) and the re-typed @x@.
mkFreshIds :: [TyVar]
           -> [Var]
           -> Var
           -> State TrEnv ([Var], Id)
mkFreshIds tvs ids x
  = do ids'  <- mapM fresh ids
       let ids'' = map setIdTRecBound ids'
       let t     = mkForAllTys ((`Bndr` Required) <$> tvs) $ mkType ids'' $ varType x
       let x'    = setVarType x t
       return (ids'', x')
  where
    -- Build the right-nested arrow type @t1 -> ... -> tn -> ty@ for binders
    -- [v1 .. vn] with @vi :: ti@. A single right fold gives this directly;
    -- the previous version reversed the list and used a lazy left fold.
    -- FIXME(adinapoli): Is 'VisArg' OK here?
    mkType vs ty = foldr (\v res -> FunTy VisArg (varType v) res) ty vs

-- NOTE [Don't choose transform-rec binders as decreasing params]
-- --------------------------------------------------------------
--
-- We don't want to select a binder created by TransformRec as the
-- decreasing parameter, since the user didn't write it. Furthermore,
-- consider T1065. There we have an inner loop that decreases on the
-- sole list parameter. But TransformRec prepends the parameters of the
-- outer `groupByFB` to those of the inner `groupByFBCore`, and now the first
-- decreasing parameter is the constant `xs0`. Disaster!
--
-- So we need a way to signal to L.H.L.Constraint.Generate that we
-- should ignore these copied Vars. The easiest way to do that is to set
-- a flag on the Var that we know won't be set, and it just so happens that
-- GHC has a bunch of optional flags that can be set by various Core
-- analyses that we don't run...
-- | Tag an 'Id' as one introduced by TransformRec.
--
-- This is an ugly hack. TransformRec does not own a dedicated flag, so it
-- borrows the 'CafInfo' field of the 'IdInfo' and sets it to 'NoCafRefs'.
-- 'isIdTRecBound' tests for exactly this. See NOTE [Don't choose
-- transform-rec binders as decreasing params] for why the tag is needed.
setIdTRecBound :: Id -> Id
setIdTRecBound = modifyIdInfo (\info -> setCafInfo info NoCafRefs)

-- | Does this 'Id' carry the tag set by 'setIdTRecBound'?
--
-- It checks that the 'CafInfo' in the 'Id''s 'IdInfo' says the binder cannot
-- have CAF references.
isIdTRecBound :: Id -> Bool
isIdTRecBound v = not (mayHaveCafRefs (cafInfo (idInfo v)))

-- | Values for which a fresh copy can be generated inside the 'TE' monad
-- (a state monad over 'TrEnv', per the instances below).
class Freshable a where
  -- | Produce a fresh value. Whether the argument is used depends on the
  -- instance: the 'Int' and 'Unique' instances ignore it, while the 'Var'
  -- instance keeps the variable and only replaces its unique.
  fresh :: a -> TE a

-- | A fresh 'Int' is the next value of the 'TrEnv' counter; the argument is
-- ignored.
instance Freshable Int where
  fresh = const freshInt

-- | A fresh 'Unique' is minted from the 'TrEnv' counter; the argument is
-- ignored.
instance Freshable Unique where
  fresh = const freshUnique

-- | A fresh 'Var' is the given variable with its 'Unique' replaced by a newly
-- generated one.
instance Freshable Var where
  fresh v = setVarUnique v <$> freshUnique

-- | Return the current value of the 'freshIndex' counter in the 'TrEnv' state
-- and advance the counter by one.
freshInt :: MonadState TrEnv m => m Int
freshInt = do
  n <- gets freshIndex
  modify (\s -> s { freshIndex = n + 1 })
  return n

-- | Mint a new 'Unique' with tag character @'X'@ from the next counter value
-- (see 'freshInt').
freshUnique :: MonadState TrEnv m => m Unique
freshUnique = mkUnique 'X' <$> freshInt


-- | Walk a list of bindings from the right.
--
-- After each 'NonRec' binding of @x@, the already-processed remainder of the
-- list is handed to @f x@, which may rewrite the bindings that follow. 'Rec'
-- groups are kept unchanged and do not trigger @f@.
mapNonRec :: (b -> [Bind b] -> [Bind b]) -> [Bind b] -> [Bind b]
mapNonRec f = foldr step []
  where
    step bnd@(NonRec x _) rest = bnd : f x rest
    step bnd              rest = bnd : rest

-- | Apply 'mapExpr' @f@ to the right-hand side of every binder in a binding
-- group, keeping the binders themselves unchanged.
mapBnd :: (b -> Expr b -> Expr b) -> Bind b -> Bind b
mapBnd f bind = case bind of
  NonRec x rhs -> NonRec x (onRhs rhs)
  Rec pairs    -> Rec [ (x, onRhs rhs) | (x, rhs) <- pairs ]
  where
    onRhs = mapExpr f

-- | Push the binder-indexed rewrite @f@ through an expression.
--
-- * At @let x = rhs in body@ (non-recursive), @f x@ is applied to both @rhs@
--   and @body@. Neither is traversed further here, so @f@ is responsible for
--   any deeper rewriting.
-- * At other lets, the binding group goes through 'mapBnd' and the body is
--   traversed.
-- * Applications, lambda bodies and ticked expressions are traversed.
-- * For @case@, only the alternatives are traversed (via 'mapAlt'); the
--   scrutinee is left untouched.
-- * Every other expression is returned unchanged.
mapExpr :: (b -> Expr b -> Expr b) -> Expr b -> Expr b
mapExpr f = go
  where
    go expr = case expr of
      -- must stay before the general 'Let' case
      Let (NonRec x rhs) body -> Let (NonRec x (f x rhs)) (f x body)
      Let bnd body            -> Let (mapBnd f bnd) (go body)
      App fun arg             -> App (go fun) (go arg)
      Lam x body              -> Lam x (go body)
      Case scrut x ty alts    -> Case scrut x ty (map (mapAlt f) alts)
      Tick tick body          -> Tick tick (go body)
      _                       -> expr

-- | Apply 'mapExpr' @f@ to the right-hand side of a case alternative, keeping
-- its constructor and pattern binders unchanged.
mapAlt :: (b -> Expr b -> Expr b) -> (t, t1, Expr b) -> (t, t1, Expr b)
mapAlt f alt = case alt of
  (con, binders, rhs) -> (con, binders, mapExpr f rhs)

-- | Deliberately a no-op: transformations are not applied to inner code.
-- The transformation argument is ignored, and the expression is returned
-- unchanged in the monad @m@.
mapBdM :: Monad m => t -> a -> m a
mapBdM = const return

-- mapBdM f (Let b e)        = liftM2 Let (f b) (mapBdM f e)
-- mapBdM f (App e1 e2)      = liftM2 App (mapBdM f e1) (mapBdM f e2)
-- mapBdM f (Lam b e)        = liftM (Lam b) (mapBdM f e)
-- mapBdM f (Case e b t alt) = liftM (Case e b t) (mapM (mapBdAltM f) alt)
-- mapBdM f (Tick t e)       = liftM (Tick t) (mapBdM f e)
-- mapBdM _  e               = return  e
--
-- mapBdAltM f (d, bs, e) = liftM ((,,) d bs) (mapBdM f e)