{-# LANGUAGE FlexibleContexts #-}

module Language.Haskell.Liquid.Transforms.InlineAux
  ( inlineAux
  )
where
import qualified Language.Haskell.Liquid.UX.Config  as UX
import           Liquid.GHC.API
import           Control.Arrow                  (second)
import qualified Language.Haskell.Liquid.GHC.Misc
                                               as GM
import qualified Data.HashMap.Strict           as M

-- | When 'UX.auxInline' is enabled, rewrite the right-hand side of every
-- binding whose binder is a known class-op auxiliary (an Id whose name
-- starts with @$c@ and that occurs free in some dictionary function of
-- this program) using 'inlineAuxExpr', then run 'occurAnalysePgm' over the
-- result.  When the flag is disabled the program is returned untouched.
inlineAux :: UX.Config -> Module -> CoreProgram -> CoreProgram
inlineAux cfg m cbs
  | not (UX.auxInline cfg) = cbs
  | otherwise =
      occurAnalysePgm m (const False) (const False) [] (map inlineInBind cbs)
 where
  -- Auxiliary Id |-> (dictionary function it was found in,
  --                   method selector |-> auxiliary table for that dfun).
  -- Built with 'mconcat', so on duplicate keys the earliest dfun wins.
  auxIndex :: M.HashMap Id (Id, M.HashMap Id Id)
  auxIndex = mconcat [ dfunIdSubst dfun rhs | (dfun, rhs) <- grepDFunIds cbs ]

  -- Both binding shapes are handled pair-by-pair.
  inlineInBind :: CoreBind -> CoreBind
  inlineInBind (NonRec x e) = uncurry NonRec (inlineInPair (x, e))
  inlineInBind (Rec pairs)  = Rec (map inlineInPair pairs)

  -- Only the right-hand side changes, and only for binders in 'auxIndex'.
  inlineInPair :: (Id, CoreExpr) -> (Id, CoreExpr)
  inlineInPair pair@(x, e) = case M.lookup x auxIndex of
    Just (dfunId, methodToAux) -> (x, inlineAuxExpr dfunId methodToAux e)
    Nothing                    -> pair


-- inlineDFun :: DynFlags -> CoreProgram -> IO CoreProgram
-- inlineDFun df cbs = mapM go cbs
--  where
--   go orig@(NonRec x e) | isDFunId x = do
--                            -- e''' <- simplifyExpr df e''
--                            let newBody = mkCoreApps (GM.tracePpr ("substituted type:" ++ GM.showPpr (exprType (mkCoreApps e' (Var <$> binders)))) e') (fmap Var binders)
--                                bind = NonRec (mkWildValBinder (exprType newBody)) newBody
--                            pure $ NonRec x (mkLet bind e)
--                        | otherwise  = pure orig
--    where
--     -- wcBinder = mkWildValBinder t
--     (binders, _) = GM.tracePpr "collectBinders"$ collectBinders e
--     e' = substExprAll empty subst e
--   go recs = pure recs
--   subst = buildDictSubst cbs

-- grab the dictionaries
grepDFunIds :: CoreProgram -> [(DFunId, CoreExpr)]
grepDFunIds :: CoreProgram -> [(Id, CoreExpr)]
grepDFunIds = forall a. (a -> Bool) -> [a] -> [a]
filter (Id -> Bool
isDFunId forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall b. [Bind b] -> [(b, Expr b)]
flattenBinds

-- | True exactly when the occurrence name begins with the two characters
-- @$c@, the prefix used here to recognise class-op auxiliary binders.
isClassOpAuxOccName :: OccName -> Bool
isClassOpAuxOccName occ = take 2 (occNameString occ) == "$c"

-- | @aux \`isClassOpAuxOf\` method@ holds when the occurrence name of @aux@
-- is exactly @$c@ followed by the occurrence name of @method@.
isClassOpAuxOf :: Id -> Id -> Bool
isClassOpAuxOf aux method = nameOf aux == '$' : 'c' : nameOf method
 where
  nameOf :: Id -> String
  nameOf = occNameString . getOccName

-- | For a dictionary function and its right-hand side, build a table keyed
-- by every free variable of that right-hand side whose name starts with
-- @$c@.  Every key maps to the same entry: the dictionary function paired
-- with a table from each of its class's selectors ('classAllSelIds') to the
-- auxiliary whose name is @$c@ followed by that selector's name.
dfunIdSubst :: DFunId -> CoreExpr -> M.HashMap Id (Id, M.HashMap Id Id)
dfunIdSubst dfunId rhs = M.fromList [ (aux, entry) | aux <- auxiliaries ]
 where
  entry :: (DFunId, M.HashMap Id Id)
  entry = (dfunId, selectorToAux)

  auxiliaries :: [Id]
  auxiliaries =
    [ v | v <- exprFreeVarsList rhs, isClassOpAuxOccName (getOccName v) ]

  -- Selectors drive the outer loop, matching the original pairing order.
  selectorToAux :: M.HashMap Id Id
  selectorToAux = M.fromList
    [ (sel, aux)
    | sel <- classAllSelIds cls
    , aux <- auxiliaries
    , aux `isClassOpAuxOf` sel
    ]

  -- The class is read off the dictionary function's type.
  cls = case tcSplitDFunTy (idType dfunId) of
    (_, _, c, _) -> c

-- | Redirect method calls on a particular dictionary to its auxiliaries.
--
-- An application whose head is a key @method@ of @methodToAux@ is rewritten
-- when the first non-type argument is an application headed by @dfunId@.
-- It becomes @aux@, the auxiliary for @method@, applied to that
-- application's own arguments followed by the remaining arguments of the
-- call, with each remaining argument itself rewritten.  The leading type
-- arguments of the original call are dropped.
--
-- A non-recursive @let@ whose binder satisfies 'isDictId' is removed by
-- substituting its right-hand side into the body, which is then rewritten.
-- The traversal descends into lambdas, lets, case scrutinees and
-- alternatives, casts, ticks and applications.  Anything else is returned
-- unchanged.
inlineAuxExpr :: DFunId -> M.HashMap Id Id -> CoreExpr -> CoreExpr
inlineAuxExpr dfunId methodToAux = walk
 where
  walk :: CoreExpr -> CoreExpr
  walk expr = case expr of
    Lam bndr body           -> Lam bndr (walk body)
    Let bind body           -> walkLet bind body
    Case scrut bndr ty alts -> Case (walk scrut) bndr ty (map (mapAlt walk) alts)
    Cast inner co           -> Cast (walk inner) co
    Tick tickish inner      -> Tick tickish (walk inner)
    -- Try the method-call rewrite before descending into a plain application.
    _ | Just inlined <- tryInline expr -> inlined
    App fun arg             -> App (walk fun) (walk arg)
    _                       -> expr

  walkLet :: CoreBind -> CoreExpr -> CoreExpr
  walkLet (NonRec x rhs) body
    -- NOTE(review): 'emptySubst' carries an empty in-scope set.  GHC's
    -- substitution invariant asks the in-scope set to cover the free
    -- variables of @rhs@ and @body@.  Confirm that binder capture cannot
    -- arise here.
    | isDictId x = walk (substExpr (extendIdSubst emptySubst x rhs) body)
  walkLet bind body = Let (mapBnd walk bind) (walk body)

  -- Nothing means the expression is not a call this pass should rewrite.
  tryInline :: CoreExpr -> Maybe CoreExpr
  tryInline expr = do
    (Var method, args)   <- Just (collectArgs expr)
    aux                  <- M.lookup method methodToAux
    dictArg : rest       <- Just (dropWhile isTypeArg args)
    (Var dict, dfunArgs) <- Just (collectArgs dictArg)
    if dict == dfunId
      then Just $ GM.notracePpr ("inlining in" ++ GM.showPpr expr)
                $ mkCoreApps (Var aux) (dfunArgs ++ map walk rest)
      else Nothing


-- | Apply a function to every right-hand side of a binding group, leaving
-- binders and the recursive/non-recursive shape untouched.
-- Modified from Rec.hs.
mapBnd :: (Expr b -> Expr b) -> Bind b -> Bind b
mapBnd f bind = case bind of
  NonRec x rhs -> NonRec x (f rhs)
  Rec pairs    -> Rec (second f <$> pairs)

-- | Apply a function to the right-hand side of a case alternative, leaving
-- its constructor and pattern binders untouched.
mapAlt :: (Expr b -> Expr b) -> Alt b -> Alt b
mapAlt f alt = case alt of
  Alt con bndrs rhs -> Alt con bndrs (f rhs)