{-# LANGUAGE TemplateHaskell #-}

module Loopbreaker.InlineRecCalls (action) where

import           Control.Arrow hiding ((<+>))
import           Data.Generics
import           Data.Kind
import           Data.Map (Map)
import qualified Data.Map as M
import           Data.Maybe
import           Data.Set (Set)
import qualified Data.Set as S

import Bag
import ErrUtils
import GhcPlugins hiding ((<>), debugTraceMsg)
import HsSyn
import MonadUtils


------------------------------------------------------------------------------
-- | Capabilities required by the loopbreaker passes: fresh uniques (used to
-- name loopbreakers), 'IO' (used for compiler trace output) and access to
-- 'DynFlags'.
type MonadInline m = ((HasDynFlags m, MonadIO m, MonadUnique m) :: Constraint)

------------------------------------------------------------------------------
-- | Plugin entry point for a renamed group. It makes inlining of recursive
-- functions possible. For every self-recursive binding that already carries
-- an @INLINE@ pragma, it creates a loopbreaker marked @NOINLINE@ and routes
-- the recursive calls through it.
--
-- The pass is skipped when the @disable@ plugin option is given or when the
-- optimisation level is 0.
action :: MonadInline m
       => [CommandLineOption] -> HsGroup GhcRn -> m (HsGroup GhcRn)
action opts group@HsGroup{ hs_valds } = do
  dflags <- getDynFlags
  if "disable" `notElem` opts && optLevel dflags > 0
    then breakLoops dflags
    else pure group
  where
    -- Announces the pass, then rewrites the group's value bindings.
    breakLoops dflags = do
      liftIO $ showPass dflags "Break loops"
      valds' <- inlineRecCalls hs_valds
      pure group{ hs_valds = valds' }

action _ _ = error "Loopbreaker.InlineRecCalls.action: expected renamed group"

------------------------------------------------------------------------------
-- | Runs 'inlineRecCall' over every binding group of renamed value bindings.
-- The signatures it generates are appended to the existing ones.
inlineRecCalls :: MonadInline m => HsValBinds GhcRn -> m (HsValBinds GhcRn)
inlineRecCalls (XValBindsLR (NValBinds binds sigs)) = do
  rewritten <- traverse (inlineRecCall types inlined) binds
  let binds'     = map fst rewritten
      extra_sigs = concatMap snd rewritten
  pure . XValBindsLR . NValBinds binds' $ sigs ++ extra_sigs
  where
    plain_sigs = unLoc <$> sigs
    types      = typesFromSigs plain_sigs
    inlined    = inlinedFromSigs plain_sigs

-- Only the renamed form ('XValBindsLR') is rewritten; anything else is
-- returned untouched.
-- TODO: should we throw an error here instead?
inlineRecCalls val_binds = pure val_binds

------------------------------------------------------------------------------
-- | Maps every name declared in a 'TypeSig' to that signature's type. Other
-- kinds of signatures are skipped. If a name appears in several type
-- signatures, the last one wins.
typesFromSigs :: Ord (IdP p) => [Sig p] -> Map (IdP p) (LHsSigWcType p)
typesFromSigs sigs = M.fromList
  [ (unLoc name, type_)
  | TypeSig _ names type_ <- sigs
  , name <- names
  ]

------------------------------------------------------------------------------
-- | Collects the names whose 'InlineSig' pragma satisfies 'isInlinePragma'.
-- All other signatures are skipped.
inlinedFromSigs :: Ord (IdP p) => [Sig p] -> Set (IdP p)
inlinedFromSigs sigs = S.fromList
  [ name
  | InlineSig _ (L _ name) pragma <- sigs
  , isInlinePragma pragma
  ]

------------------------------------------------------------------------------
-- | Breaks the loop of a recursive binding group that consists of a single
-- function carrying an @INLINE@ pragma. References to the function inside its
-- own body are redirected to a fresh @NOINLINE@ loopbreaker, and that
-- loopbreaker is added to the same group. Nested local binding groups in the
-- body are processed the same way. Returns the rewritten group together with
-- the signatures to emit for the loopbreaker.
--
-- Any other group is returned unchanged with no extra signatures. This
-- includes mutually recursive groups, non-recursive groups and functions
-- without an @INLINE@ pragma.
--
-- NOTE(review): the renamer's free-variable set ('fun_ext') of the rewritten
-- binding is left as-is even though its recursive calls now go through the
-- loopbreaker. Confirm that later passes do not rely on it.
inlineRecCall
  :: MonadInline m
  => Map Name (LHsSigWcType GhcRn)  -- ^ type signatures of the bindings
  -> Set Name                       -- ^ names carrying an @INLINE@ pragma
  -> (RecFlag, LHsBinds GhcRn)      -- ^ binding group to process
  -> m ((RecFlag, LHsBinds GhcRn), [LSig GhcRn])
inlineRecCall types inlined (Recursive, binds)
  | [L fun_loc fun_bind] <- bagToList binds
  , FunBind{ fun_id = L _ fun_name, fun_matches } <- fun_bind
  , fun_name `S.member` inlined
  = do
      dflags <- getDynFlags
      liftIO . debugTraceMsg dflags 2 $ text "Loopbreaker:" <+> ppr fun_name

      (loopb_name, loopb_decl) <- loopbreaker fun_name

      -- Bottom-up over the body. At every node, nested local binding groups
      -- first get their own loopbreakers. Then references to this function
      -- are redirected to its loopbreaker.
      new_matches <- everywhereM
        (fmap (replaceVarNamesT fun_name loopb_name) . inlineLocalRecCallsM)
        fun_matches

      let new_fun_bind = L fun_loc fun_bind{ fun_matches = new_matches }
          noinline_sig = inlineSig noInlinePragma loopb_name
          -- Emit a type signature for the loopbreaker only if the original
          -- function has one.
          type_sigs    = maybeToList
                       $ loopbreakerSig loopb_name <$> M.lookup fun_name types

      pure ( (Recursive, listToBag [new_fun_bind, loopb_decl])
           , noinline_sig : type_sigs
           )
inlineRecCall _ _ untouched = pure (untouched, [])

------------------------------------------------------------------------------
-- | Creates a loopbreaker for the given function. Returns the loopbreaker's
-- fresh name together with its binding.
loopbreaker :: MonadUnique m => Name -> m (Name, LHsBind GhcRn)
loopbreaker fun_name = do
  loopb_name <- loopbreakerName fun_name
  pure (loopb_name, loopbreakerDecl fun_name loopb_name)

------------------------------------------------------------------------------
-- | Generates a fresh system name for a function's loopbreaker. The name is
-- the function's occurrence name with @__Loopbreaker@ appended.
loopbreakerName :: MonadUnique m => Name -> m Name
loopbreakerName fun_name = do
  loopb_unique <- getUniqueM
  let loopb_fs = occNameFS (occName fun_name) <> "__Loopbreaker"
  pure $ mkSystemVarName loopb_unique loopb_fs

------------------------------------------------------------------------------
-- | Builds the loopbreaker's binding. It is a compiler-generated function
-- with no arguments whose body is just a reference to the original function.
loopbreakerDecl :: Name -> Name -> LHsBind GhcRn
loopbreakerDecl fun_name loopb_name =
  noLoc $ mkTopFunBind Generated loopb_id [only_match]
  where
    loopb_id :: Located Name
    loopb_id = noLoc loopb_name

    only_match :: LMatch GhcRn (LHsExpr GhcRn)
    only_match = mkSimpleMatch (mkPrefixFunRhs loopb_id) [] (nlHsVar fun_name)

------------------------------------------------------------------------------
-- | Builds a type signature giving the loopbreaker the original function's
-- type.
loopbreakerSig :: Name -> LHsSigWcType GhcRn -> LSig GhcRn
loopbreakerSig loopb_name = noLoc . TypeSig NoExt [noLoc loopb_name]

------------------------------------------------------------------------------
-- | Builds an 'InlineSig' with the given pragma for a single name. Neither
-- the signature nor the name carries a source location.
inlineSig :: (XInlineSig p ~ NoExt) => InlinePragma -> IdP p -> LSig p
inlineSig how name = noLoc (InlineSig NoExt located_name how)
  where
    located_name = noLoc name

------------------------------------------------------------------------------
-- | An inline pragma that behaves like a source @NOINLINE@ pragma, contrary
-- to 'neverInlinePragma'. It is 'defaultInlinePragma' with activation
-- 'NeverActive' and inline specification 'NoInline'.
noInlinePragma :: InlinePragma
noInlinePragma = defaultInlinePragma
  { inl_act    = NeverActive
  , inl_inline = NoInline
  }

------------------------------------------------------------------------------
-- | Generic monadic step that runs 'inlineRecCalls' on the value bindings of
-- a local binding group ('HsLocalBinds'). Every other node is returned as is.
inlineLocalRecCallsM :: (MonadInline m, Typeable a) => a -> m a
inlineLocalRecCallsM = mkM rewriteLocal
  where
    rewriteLocal
      :: MonadInline n => HsLocalBinds GhcRn -> n (HsLocalBinds GhcRn)
    rewriteLocal (HsValBinds NoExt binds) =
      HsValBinds NoExt <$> inlineRecCalls binds
    rewriteLocal other = pure other

------------------------------------------------------------------------------
-- | Generic transformation that redirects every variable expression ('HsVar')
-- referring to the first name to the second name, keeping its source
-- location. Every other node is returned as is.
replaceVarNamesT :: Typeable a => Name -> Name -> a -> a
replaceVarNamesT from to = mkT retarget
  where
    retarget :: HsExpr GhcRn -> HsExpr GhcRn
    retarget expr@(HsVar NoExt (L loc name))
      | name /= from = expr
      | otherwise    = HsVar NoExt (L loc to)
    retarget expr = expr