{-# LANGUAGE CPP, ScopedTypeVariables, FlexibleContexts #-}

-- | Note: this module should NOT export externals. It is for common
--   transformations needed by the other primitive modules.
module Language.HERMIT.Primitive.Common
    ( -- * Utility Transformations
      applyInContextT
      -- ** Finding function calls.
    , callT
    , callPredT
    , callNameT
    , callSaturatedT
    , callNameG
    , callDataConT
    , callDataConNameT
    , callsR
    , callsT
      -- ** Collecting variables bound at a Node
    , progIdsT
    , consIdsT
    , consRecIdsT
    , consNonRecIdT
    , bindVarsT
    , nonRecVarT
    , recIdsT
    , defIdT
    , lamVarT
    , letVarsT
    , letRecIdsT
    , letNonRecVarT
    , caseVarsT
    , caseWildIdT
    , caseAltVarsT
    , altVarsT
      -- ** Finding variables bound in the Context
    , boundVarsT
    , findBoundVarT
    , findIdT
    , findId
      -- ** Miscellaneous
    , wrongExprForm
    , nodups
    , mapAlts
    )

where

import GhcPlugins

import Data.List
import Data.Monoid
import qualified Data.Set as S

import Control.Monad(liftM)

import Language.HERMIT.Kure
import Language.HERMIT.Core
import Language.HERMIT.Context
import Language.HERMIT.GHC

import Language.HERMIT.Primitive.GHC

import qualified Language.Haskell.TH as TH
import Language.Haskell.TH.Syntax (showName)

------------------------------------------------------------------------------

-- | Run a transformation on the supplied value, using the context of the
--   enclosing transformation. Only the context is consulted; the value the
--   resulting transformation is applied to plays no part.
applyInContextT :: Translate c m a b -> a -> Translate c m x b
applyInContextT t a = contextonlyT runHere
  where
    -- Apply @t@ to the fixed value @a@ in the given context.
    runHere ctx = apply t ctx a

-- Note: this is the same as: return a >>> t

------------------------------------------------------------------------------

-- | Lifted version of GHC's 'collectArgs'.
--   A bare variable is returned with an empty argument list, an application
--   is split by 'collectArgs', and any other expression is a failure.
callT :: Monad m => Translate c m CoreExpr (CoreExpr, [CoreExpr])
callT = contextfreeT splitCall
  where
    splitCall e@(Var _)   = return (e, [])
    splitCall e@(App _ _) = return (collectArgs e)
    splitCall _           = fail "not an application or variable occurence."

-- | Succeeds if we are looking at a call (as recognised by 'callT') whose head
--   is a variable, and the given predicate holds of that variable and the
--   arguments it is applied to. Returns the head and the arguments.
--
--   Fails if 'callT' fails, if the head of the application is not a variable,
--   or if the predicate returns 'False'.
callPredT :: Monad m => (Id -> [CoreExpr] -> Bool) -> Translate c m CoreExpr (CoreExpr, [CoreExpr])
callPredT p = do
    (fn, args) <- callT
    case fn of
      Var i -> do guardMsg (p i args) "predicate failed."
                  return (fn, args)
      -- 'collectArgs' can produce a head that is not a variable; report that
      -- explicitly rather than via an implicit pattern-match failure.
      _     -> fail "head of application is not a variable."

-- | Succeeds if the expression is a call to the named function, returning
--   the function together with the (possibly empty) list of arguments it is
--   applied to. Any failure is reported with a single fixed message.
callNameT :: MonadCatch m => TH.Name -> Translate c m CoreExpr (CoreExpr, [CoreExpr])
callNameT nm = setFailMsg ("callNameT: not a call to " ++ show nm) (callPredT matchesName)
  where
    -- Only the head of the call matters; the arguments are ignored.
    matchesName fn _ = cmpTHName2Var nm fn

-- | Succeeds if the expression is a call in which the number of arguments
--   equals the 'idArity' of the function being called.
callSaturatedT :: Monad m => Translate c m CoreExpr (CoreExpr, [CoreExpr])
callSaturatedT = callPredT isSaturated
  where
    isSaturated fn args = idArity fn == length args
-- TODO: probably better to calculate arity based on Id's type, as
--       idArity is conservatively set to zero by default.

-- | Guard version of 'callNameT': succeeds with @()@ exactly when the
--   expression is a call to the named function.
callNameG :: MonadCatch m => TH.Name -> Translate c m CoreExpr ()
callNameG nm = prefixFailMsg "callNameG failed: " $ do
    _ <- callNameT nm
    constT (return ())

-- | Succeeds if we are looking at an application of a data constructor,
--   as recognised by GHC's 'exprIsConApp_maybe'. On success the result of
--   'exprIsConApp_maybe' (the data constructor with its type and expression
--   arguments) is returned; otherwise the transformation fails.
--
--   The call to 'exprIsConApp_maybe' differs between GHC versions: newer
--   versions additionally require an in-scope set, which is built here from
--   the free variables of the expression.
callDataConT :: MonadCatch m => Translate c m CoreExpr (DataCon, [Type], [CoreExpr])
callDataConT = prefixFailMsg "callDataConT failed: " $
#if __GLASGOW_HASKELL__ > 706
    do mb <- contextfreeT $ \ e -> let in_scope = mkInScopeSet (mkVarEnv [ (v,v) | v <- S.toList (coreExprFreeVars e) ])
                                   in return $ exprIsConApp_maybe (in_scope, idUnfolding) e
       maybe (fail "not a datacon application.") return mb
#else
    contextfreeT (return . exprIsConApp_maybe idUnfolding)
        >>= maybe (fail "not a datacon application.") return
#endif

-- | Like 'callDataConT', but additionally requires the data constructor's
--   name to match the given Template Haskell name.
callDataConNameT :: MonadCatch m => TH.Name -> Translate c m CoreExpr (DataCon, [Type], [CoreExpr])
callDataConNameT nm =
    callDataConT >>= \ found@(con, _, _) ->
        guardMsg (cmpTHName2Name nm (dataConName con)) "wrong datacon." >> return found

-- | Rewrite every call to the named function, traversing top-down and not
--   descending below a node at which the rewrite has succeeded.
callsR :: (ExtendPath c Crumb, AddBindings c, MonadCatch m) => TH.Name -> Rewrite c m CoreExpr -> Rewrite c m Core
callsR nm rr = prunetdR . promoteExprR $ callNameG nm >> rr

-- | Apply a translate at every call to the named function, traversing
--   top-down and pruning on success, and gather the results into a list.
callsT :: (ExtendPath c Crumb, AddBindings c, MonadCatch m) => TH.Name -> Translate c m CoreExpr b -> Translate c m Core [b]
callsT nm t = collectPruneT . promoteExprT $ callNameG nm >> t

------------------------------------------------------------------------------

-- | Collect the identifiers bound by every top-level binding group of a
--   program, in program order.
progIdsT :: (ExtendPath c Crumb, AddBindings c, MonadCatch m) => Translate c m CoreProg [Id]
progIdsT = emptyProg <+ nonEmptyProg
  where
    -- An empty program binds nothing.
    emptyProg    = progNilT []
    -- Binders of the head group, followed by those of the rest of the program.
    nonEmptyProg = progConsT bindVarsT progIdsT (++)

-- | The identifiers bound by the first top-level binding group of a
--   non-empty program. The remainder of the program is not inspected.
consIdsT :: (ExtendPath c Crumb, AddBindings c, MonadCatch m) => Translate c m CoreProg [Id]
consIdsT = progConsT bindVarsT mempty headOnly
  where
    headOnly ids () = ids

-- | The identifiers bound by the first top-level binding group of a program,
--   which must be a recursive group.
consRecIdsT :: (ExtendPath c Crumb, AddBindings c, Monad m) => Translate c m CoreProg [Id]
consRecIdsT = progConsT recIdsT mempty headOnly
  where
    headOnly ids () = ids

-- | The identifier bound by the first top-level binding of a program,
--   which must be a non-recursive binding.
consNonRecIdT :: (ExtendPath c Crumb, AddBindings c, Monad m) => Translate c m CoreProg Id
consNonRecIdT = progConsT nonRecVarT mempty headOnly
  where
    headOnly ident () = ident

-- | The variables bound by a binding group: a singleton list for a
--   non-recursive binding, or every binder of a recursive group.
bindVarsT :: (ExtendPath c Crumb, AddBindings c, MonadCatch m) => Translate c m CoreBind [Var]
bindVarsT = liftM (\ v -> [v]) nonRecVarT <+ recIdsT

-- | Return the variable bound by a non-recursive binding.
--   The bound expression is not inspected.
nonRecVarT :: (ExtendPath c Crumb, Monad m) => Translate c m CoreBind Var
nonRecVarT = nonRecT idR mempty binderOnly
  where
    binderOnly var () = var

-- | The identifiers bound by a recursive binding group, one per definition.
recIdsT :: (ExtendPath c Crumb, AddBindings c, Monad m) => Translate c m CoreBind [Id]
recIdsT = recT (const defIdT) id

-- | The identifier bound by a single definition of a recursive group.
--   The right-hand side is not inspected.
defIdT :: (ExtendPath c Crumb, Monad m) => Translate c m CoreDef Id
defIdT = defT idR mempty binderOnly
  where
    binderOnly ident () = ident

-- | The variable bound by a lambda expression. The body is not inspected.
lamVarT :: (ExtendPath c Crumb, AddBindings c, Monad m) => Translate c m CoreExpr Var
lamVarT = lamT idR mempty binderOnly
  where
    binderOnly var () = var

-- | The variables bound by the binding group of a let expression.
--   The body of the let is not inspected.
letVarsT :: (ExtendPath c Crumb, AddBindings c, MonadCatch m) => Translate c m CoreExpr [Var]
letVarsT = letT bindVarsT mempty bindersOnly
  where
    bindersOnly vars () = vars

-- | The identifiers bound by a let expression whose binding group is
--   recursive. The body of the let is not inspected.
letRecIdsT :: (ExtendPath c Crumb, AddBindings c, Monad m) => Translate c m CoreExpr [Id]
letRecIdsT = letT recIdsT mempty bindersOnly
  where
    bindersOnly ids () = ids

-- | The variable bound by a let expression whose binding is non-recursive.
--   The body of the let is not inspected.
letNonRecVarT :: (ExtendPath c Crumb, AddBindings c, Monad m) => Translate c m CoreExpr Var
letNonRecVarT = letT nonRecVarT mempty binderOnly
  where
    binderOnly var () = var

-- | Every variable bound by a case expression: the wildcard binder first,
--   followed by the variables bound in the alternatives with duplicates
--   removed.
caseVarsT :: (ExtendPath c Crumb, AddBindings c, Monad m) => Translate c m CoreExpr [Var]
caseVarsT = caseT mempty idR mempty (const altVarsT) gather
  where
    gather () wild () perAlt = wild : nub (concat perAlt)

-- | The wildcard binder of a case expression.
caseWildIdT :: (ExtendPath c Crumb, AddBindings c, Monad m) => Translate c m CoreExpr Id
caseWildIdT = caseT mempty idR mempty (const idR) wildOnly
  where
    wildOnly () wild () _ = wild

-- | The variables bound by the alternatives of a case expression, as one
--   list per alternative.
caseAltVarsT :: (ExtendPath c Crumb, AddBindings c, Monad m) => Translate c m CoreExpr [[Var]]
caseAltVarsT = caseT mempty mempty mempty (const altVarsT) altsOnly
  where
    altsOnly () () () perAlt = perAlt

-- | The variables bound by the pattern of a single case alternative.
altVarsT :: (ExtendPath c Crumb, AddBindings c, Monad m) => Translate c m CoreAlt [Var]
altVarsT = altT mempty (const idR) mempty bindersOnly
  where
    bindersOnly () vars () = vars

------------------------------------------------------------------------------

-- Need a better error type so that we can factor out the repetition.

-- | Lifted version of 'boundVars': returns the set of variables that
--   'boundVars' reports for the current context.
boundVarsT :: (BoundVars c, Monad m) => Translate c m a (S.Set Var)
boundVarsT = contextonlyT $ \ ctx -> return (boundVars ctx)

-- | Look up the given name among the variables bound in the context.
--   Succeeds only when exactly one bound variable matches; fails when there
--   is no match or more than one.
findBoundVarT :: (BoundVars c, MonadCatch m) => TH.Name -> Translate c m a Var
findBoundVarT nm = prefixFailMsg ("Cannot resolve name " ++ showName nm ++ ", ") $
    contextT >>= \ ctx ->
        case findBoundVars nm ctx of
          [v] -> return v
          []  -> fail "no matching variables in scope."
          _   -> fail "multiple matching variables in scope."

-- | Lifted version of 'findId': resolve the name against the current
--   context, prefixing any failure with the name being resolved.
findIdT :: (BoundVars c, HasGlobalRdrEnv c, HasDynFlags m, MonadThings m, MonadCatch m) => TH.Name -> Translate c m a Id
findIdT nm = prefixFailMsg ("Cannot resolve name " ++ showName nm ++ ", ") $
             contextonlyT $ \ ctx -> findId nm ctx

-- | Resolve a name to an identifier. Variables bound in the context take
--   precedence: a unique match is returned, an ambiguous match is a failure,
--   and only when nothing matches is the lookup delegated to 'findIdMG'.
findId :: (BoundVars c, HasGlobalRdrEnv c, HasDynFlags m, MonadThings m) => TH.Name -> c -> m Id
findId nm c = case findBoundVars nm c of
                [v] -> return v
                []  -> findIdMG nm c
                _   -> fail "multiple matching variables in scope."

-- | Resolve a name using the global reader environment held in the context.
--   Only value names are considered. A unique match is looked up with
--   'lookupId', no match falls back to 'findIdBuiltIn', and several matches
--   are a failure that lists the candidates.
findIdMG :: (BoundVars c, HasGlobalRdrEnv c, HasDynFlags m, MonadThings m) => TH.Name -> c -> m Id
findIdMG nm c =
    case valueNames of
      [n] -> lookupId n
      []  -> findIdBuiltIn nm
      ns  -> getDynFlags >>= \ dynFlags ->
                 fail ("multiple matches found:\n" ++ intercalate ", " (map (showPpr dynFlags) ns))
  where
    valueNames = filter isValName (findNamesFromTH (hermitGlobalRdrEnv c) nm)

-- | Resolve a name to one of a small, hard-coded set of GHC wired-in
--   data constructor identifiers, matching on the string produced by
--   'showName'. Fails for any name not listed here.
findIdBuiltIn :: forall m. Monad m => TH.Name -> m Id
findIdBuiltIn = go . showName
    where go ":"     = dataConId consDataCon
          go "[]"    = dataConId nilDataCon

          go "True"  = return trueDataConId
          go "False" = return falseDataConId

          -- The constructors of Ordering, under their own names.
          go "LT"    = return ltDataConId
          go "EQ"    = return eqDataConId
          go "GT"    = return gtDataConId

          -- Operator-style spellings for the same three identifiers,
          -- retained so that existing uses keep resolving.
          go "<"     = return ltDataConId
          go "=="    = return eqDataConId
          go ">"     = return gtDataConId

          go "I#"    = dataConId intDataCon

          go "()"    = return unitDataConId
          -- TODO: add more as needed
          --       http://www.haskell.org/ghc/docs/latest/html/libraries/ghc/TysWiredIn.html
          go _   = fail "variable not in scope."

          -- The identifier used for a data constructor is its worker Id.
          dataConId :: DataCon -> m Id
          dataConId = return . dataConWorkId

------------------------------------------------------------------------------

-- | Build the standard \"wrong form\" error message.
--   The argument describes the form the expression was expected to have
--   and is appended verbatim to a fixed prefix.
wrongExprForm :: String -> String
wrongExprForm form = concat ["Expression does not have the form: ", form]

------------------------------------------------------------------------------

-- | Determine if a list contains no duplicated elements.
--   The empty list has no duplicates. Checking stops at the first element
--   that reappears later in the list, so the result is 'False' as soon as a
--   duplicate is found, without traversing the rest of the list.
nodups :: Eq a => [a] -> Bool
nodups []       = True
nodups (a : as) = a `notElem` as && nodups as

------------------------------------------------------------------------------

-- | Apply a function to the right-hand side expression of every case
--   alternative, leaving the constructor and bound variables untouched.
mapAlts :: (CoreExpr -> CoreExpr) -> [CoreAlt] -> [CoreAlt]
mapAlts f = map onRhs
  where
    onRhs (con, binders, rhs) = (con, binders, f rhs)

------------------------------------------------------------------------------