{-# LANGUAGE CPP, MultiWayIf, TupleSections #-}
{-# LANGUAGE FlexibleContexts #-}
module Inst (
       deeplySkolemise,
       topInstantiate, topInstantiateInferred, deeplyInstantiate,
       instCall, instDFunType, instStupidTheta, instTyVarsWith,
       newWanted, newWanteds,
       tcInstTyBinders, tcInstTyBinder,
       newOverloadedLit, mkOverLit,
       newClsInst,
       tcGetInsts, tcGetInstEnvs, getOverlapFlag,
       tcExtendLocalInstEnv,
       instCallConstraints, newMethodFromName,
       tcSyntaxName,
       
       tyCoVarsOfWC,
       tyCoVarsOfCt, tyCoVarsOfCts,
    ) where
#include "HsVersions.h"
import GhcPrelude
import {-# SOURCE #-}   TcExpr( tcPolyExpr, tcSyntaxOp )
import {-# SOURCE #-}   TcUnify( unifyType, unifyKind )
import BasicTypes ( IntegralLit(..), SourceText(..) )
import FastString
import HsSyn
import TcHsSyn
import TcRnMonad
import TcEnv
import TcEvidence
import InstEnv
import TysWiredIn  ( heqDataCon, eqDataCon )
import CoreSyn     ( isOrphan )
import FunDeps
import TcMType
import Type
import TyCoRep
import TcType
import HscTypes
import Class( Class )
import MkId( mkDictFunId )
import CoreSyn( Expr(..) )  
import Id
import Name
import Var      ( EvVar, mkTyVar, tyVarName, VarBndr(..) )
import DataCon
import VarEnv
import PrelNames
import SrcLoc
import DynFlags
import Util
import Outputable
import qualified GHC.LanguageExtensions as LangExt
import Control.Monad( unless )
-- | Build an expression for the identifier called @name@, instantiated at
-- @inst_ty@.
--
-- The identifier is looked up with 'tcLookupId', its type is applied to
-- @inst_ty@, and the resulting constraints are discharged via 'instCall'.
-- The returned expression is the identifier wrapped in the instantiation
-- wrapper.
newMethodFromName :: CtOrigin -> Name -> TcRhoType -> TcM (HsExpr GhcTcId)
newMethodFromName origin name inst_ty
  = do { meth_id <- tcLookupId name
       ; let meth_ty    = piResultTy (idType meth_id) inst_ty
             (preds, _) = tcSplitPhiTy meth_ty
         -- The assertion records the expectation that, once applied to
         -- inst_ty, the type has no further foralls and exactly one
         -- constraint.
       ; wrapper <- ASSERT( not (isForAllTy meth_ty) && isSingleton preds )
                    instCall origin [inst_ty] preds
       ; return (mkHsWrap wrapper (HsVar noExt (noLoc meth_id))) }
-- | Skolemise a sigma type "deeply": as long as 'tcDeepSplitSigmaTy_maybe'
-- can split off argument types, binders and constraints, replace the
-- binders by fresh skolems, create evidence variables for the constraints,
-- and recurse into the remaining body.
deeplySkolemise :: TcSigmaType
                -> TcM ( HsWrapper
                       , [(Name,TyVar)]     -- original binder names paired with their skolems
                       , [EvVar]            -- evidence variables created for the constraints
                       , TcRhoType )
deeplySkolemise sigma
  = skolemise (mkEmptyTCvSubst (mkInScopeSet (tyCoVarsOfType sigma))) sigma
  where
    skolemise subst ty
      = case tcDeepSplitSigmaTy_maybe ty of
          -- Nothing left to split: just apply the accumulated substitution.
          Nothing -> return (idHsWrapper, [], [], substTy subst ty)

          Just (arg_tys, tvs, theta, body) ->
            do { let arg_tys' = substTys subst arg_tys
               ; arg_ids            <- newSysLocalIds (fsLit "dk") arg_tys'
               ; (subst', skol_tvs) <- tcInstSkolTyVarsX subst tvs
               ; givens             <- newEvVars (substTheta subst' theta)
               ; (inner_wrap, inner_prs, inner_givens, rho) <- skolemise subst' body
               ; let outer_prs = zip (map tyVarName tvs) skol_tvs
                     wrapper   = mkWpLams arg_ids
                                 <.> mkWpTyLams skol_tvs
                                 <.> mkWpLams givens
                                 <.> inner_wrap
                                 <.> mkWpEvVarApps arg_ids
               ; return ( wrapper
                        , outer_prs ++ inner_prs
                        , givens    ++ inner_givens
                        , mkFunTys arg_tys' rho ) }
        
-- | Instantiate the outer binders and constraints of a type, treating every
-- binder as instantiable (the flag handed to 'top_instantiate' is 'True').
topInstantiate :: CtOrigin -> TcSigmaType -> TcM (HsWrapper, TcRhoType)
topInstantiate orig sigma = top_instantiate True orig sigma
-- | Like 'topInstantiate', but only the leading binders whose argument flag
-- is 'Inferred' are instantiated (the flag handed to 'top_instantiate' is
-- 'False'), so the result may still be a sigma type.
topInstantiateInferred :: CtOrigin -> TcSigmaType
                       -> TcM (HsWrapper, TcSigmaType)
topInstantiateInferred orig sigma = top_instantiate False orig sigma
-- | Worker for 'topInstantiate' and 'topInstantiateInferred'.
top_instantiate :: Bool   -- True  <=> every outer binder is instantiated
                          -- False <=> only the leading Inferred binders are
                -> CtOrigin -> TcSigmaType -> TcM (HsWrapper, TcRhoType)
top_instantiate inst_all orig ty
  | null binders && null theta
  = return (idHsWrapper, ty)     -- no foralls and no constraints: nothing to do

  | otherwise
  = do { let (inst_bndrs, leave_bndrs) = span should_inst binders
             all_instantiated          = null leave_bndrs
             -- The constraints are instantiated only when no binder is
             -- left behind; otherwise they stay inside the residual type.
             (inst_theta, leave_theta)
               | all_instantiated = (theta, [])
               | otherwise        = ([], theta)
             empty_subst = mkEmptyTCvSubst (mkInScopeSet (tyCoVarsOfType ty))

       ; (subst, inst_tvs') <- mapAccumLM newMetaTyVarX empty_subst
                                          (binderVars inst_bndrs)
       ; let inst_theta'  = substTheta subst inst_theta
             residual     = mkForAllTys leave_bndrs (mkFunTys leave_theta rho)
             sigma'       = substTy subst residual
             inst_tv_tys' = mkTyVarTys inst_tvs'

       ; wrap1 <- instCall orig inst_tv_tys' inst_theta'
       ; traceTc "Instantiating"
                 (vcat [ text "all tyvars?" <+> ppr inst_all
                       , text "origin" <+> pprCtOrigin orig
                       , text "type" <+> debugPprType ty
                       , text "theta" <+> ppr theta
                       , text "leave_bndrs" <+> ppr leave_bndrs
                       , text "with" <+> vcat (map debugPprType inst_tv_tys')
                       , text "theta:" <+>  ppr inst_theta' ])

         -- Go round again only if everything was instantiated this time;
         -- if binders were left in place we stop here.
       ; (wrap2, rho2) <- if all_instantiated
                          then top_instantiate inst_all orig sigma'
                          else return (idHsWrapper, sigma')

       ; return (wrap2 <.> wrap1, rho2) }
  where
    (binders, phi) = tcSplitForAllVarBndrs ty
    (theta, rho)   = tcSplitPhiTy phi

    should_inst bndr = inst_all || binderArgFlag bndr == Inferred
-- | Deeply instantiate a sigma type, starting from an empty substitution
-- whose in-scope set contains the free variables of the type.
deeplyInstantiate :: CtOrigin -> TcSigmaType -> TcM (HsWrapper, TcRhoType)
deeplyInstantiate orig ty = deeply_instantiate orig init_subst ty
  where
    init_subst = mkEmptyTCvSubst (mkInScopeSet (tyCoVarsOfType ty))
-- | Worker for 'deeplyInstantiate'.  While 'tcDeepSplitSigmaTy_maybe'
-- succeeds, the binders are replaced by fresh meta type variables, the
-- constraints are instantiated with 'instCall', and the body is processed
-- recursively under the extended substitution.
deeply_instantiate :: CtOrigin
                   -> TCvSubst
                   -> TcSigmaType -> TcM (HsWrapper, TcRhoType)
deeply_instantiate orig subst ty
  = case tcDeepSplitSigmaTy_maybe ty of
      Just (arg_tys, tvs, theta, rho) ->
        do { (subst', meta_tvs) <- newMetaTyVarsX subst tvs
           ; let arg_tys' = substTys   subst' arg_tys
                 theta'   = substTheta subst' theta
           ; arg_ids    <- newSysLocalIds (fsLit "di") arg_tys'
           ; outer_wrap <- instCall orig (mkTyVarTys meta_tvs) theta'
           ; traceTc "Instantiating (deeply)"
                     (vcat [ text "origin" <+> pprCtOrigin orig
                           , text "type" <+> ppr ty
                           , text "with" <+> ppr meta_tvs
                           , text "args:" <+> ppr arg_ids
                           , text "theta:" <+>  ppr theta'
                           , text "subst:" <+> ppr subst' ])
           ; (inner_wrap, rho') <- deeply_instantiate orig subst' rho
           ; return ( mkWpLams arg_ids
                      <.> inner_wrap
                      <.> outer_wrap
                      <.> mkWpEvVarApps arg_ids
                    , mkFunTys arg_tys' rho' ) }

      -- Nothing more to split: apply the substitution and stop.
      Nothing ->
        do { let ty' = substTy subst ty
           ; traceTc "deeply_instantiate final subst"
                     (vcat [ text "origin:"   <+> pprCtOrigin orig
                           , text "type:"     <+> ppr ty
                           , text "new type:" <+> ppr ty'
                           , text "subst:"    <+> ppr subst ])
           ; return (idHsWrapper, ty') }
-- | Build a substitution mapping each type variable to the corresponding
-- type.  When the (substituted) kind of the variable differs from the kind
-- of the type, a wanted kind-level nominal equality is emitted and the
-- type is cast by the resulting coercion before being added.
--
-- Panics if the two lists have different lengths.
instTyVarsWith :: CtOrigin -> [TyVar] -> [TcType] -> TcM TCvSubst
instTyVarsWith orig tvs tys
  = loop (mkEmptyTCvSubst (mkInScopeSet (tyCoVarsOfTypes tys))) tvs tys
  where
    loop subst [] []
      = return subst
    loop subst (tv : rest_tvs) (ty : rest_tys)
      = do { arg <- if tv_kind `tcEqType` ty_kind
                    then return ty
                    else do { co <- emitWantedEq orig KindLevel Nominal
                                                 ty_kind tv_kind
                            ; return (ty `mkCastTy` co) }
           ; loop (extendTCvSubst subst tv arg) rest_tvs rest_tys }
      where
        tv_kind = substTy subst (tyVarKind tv)
        ty_kind = tcTypeKind ty
    -- Length mismatch: report the complete argument lists.
    loop _ _ _ = pprPanic "instTysWith" (ppr tvs $$ ppr tys)
-- | Build the wrapper for a call site: the type applications for @tys@
-- composed with the evidence applications produced by
-- 'instCallConstraints' for @theta@.
instCall :: CtOrigin -> [TcType] -> TcThetaType -> TcM HsWrapper
instCall orig tys theta
  = fmap (<.> mkWpTyApps tys) (instCallConstraints orig theta)
-- | Produce evidence for each predicate and return a wrapper that applies
-- all of it.  An empty list of predicates yields the identity wrapper.
instCallConstraints :: CtOrigin -> TcThetaType -> TcM HsWrapper
instCallConstraints _ [] = return idHsWrapper
instCallConstraints orig preds
  = do { ev_terms <- mapM solve preds
       ; traceTc "instCallConstraints" (ppr ev_terms)
       ; return (mkWpEvApps ev_terms) }
  where
    solve :: TcPredType -> TcM EvTerm
    solve pred
        -- A nominal equality predicate: unify the two sides directly.
      | Just (Nominal, ty1, ty2) <- getEqPredTys_maybe pred
      = fmap evCoercion (unifyType Nothing ty1 ty2)

        -- A four-argument application of the tycon with heqTyConKey:
        -- unify, then box the coercion with the heqDataCon wrapper.
      | Just (tc, args@[_, _, ty1, ty2]) <- splitTyConApp_maybe pred
      , tc `hasKey` heqTyConKey
      = do { co <- unifyType Nothing ty1 ty2
           ; return (evDFunApp (dataConWrapId heqDataCon) args [Coercion co]) }

        -- Anything else becomes a wanted constraint.
      | otherwise
      = emitWanted orig pred
-- | Instantiate the quantified type variables of a dictionary function.
-- For each variable, a 'Just' entry supplies the type to use and a
-- 'Nothing' entry asks for a fresh meta type variable.
--
-- Panics if the list of entries and the list of quantified variables have
-- different lengths.
instDFunType :: DFunId -> [DFunInstType]
             -> TcM ( [TcType]      -- the types the variables were instantiated to
                    , TcThetaType ) -- the dfun's context under that substitution
instDFunType dfun_id dfun_inst_tys
  = do { (subst, inst_tys) <- go init_subst dfun_tvs dfun_inst_tys
       ; return (inst_tys, substTheta subst dfun_theta) }
  where
    dfun_ty                   = idType dfun_id
    (dfun_tvs, dfun_theta, _) = tcSplitSigmaTy dfun_ty
    init_subst = mkEmptyTCvSubst (mkInScopeSet (tyCoVarsOfType dfun_ty))

    go :: TCvSubst -> [TyVar] -> [DFunInstType] -> TcM (TCvSubst, [TcType])
    go subst [] [] = return (subst, [])
    go subst (tv : tvs) (mb_ty : mb_tys)
      = do { (subst1, ty) <- case mb_ty of
                 Just given -> return (extendTvSubstAndInScope subst tv given, given)
                 Nothing    -> do { (s, tv') <- newMetaTyVarX subst tv
                                  ; return (s, mkTyVarTy tv') }
           ; (subst2, tys) <- go subst1 tvs mb_tys
           ; return (subst2, ty : tys) }
    go _ _ _ = pprPanic "instDFunTypes" (ppr dfun_id $$ ppr dfun_inst_tys)
-- | Run 'instCallConstraints' on the given context purely for its effects;
-- the resulting wrapper is thrown away.
instStupidTheta :: CtOrigin -> TcThetaType -> TcM ()
instStupidTheta orig theta
  = instCallConstraints orig theta >> return ()
-- | Instantiate a list of binders, threading a substitution through
-- 'tcInstTyBinder', and apply that substitution to the accompanying kind.
tcInstTyBinders :: HasDebugCallStack
              => ([TyCoBinder], TcKind)   -- binders, and the kind they scope over
              -> TcM ([TcType], TcKind)   -- one argument per binder, and the substituted kind
tcInstTyBinders ([], kind)
  = return ([], kind)       -- no binders: the kind is returned untouched
tcInstTyBinders (bndrs, kind)
  = do { (subst, args) <- mapAccumLM (tcInstTyBinder Nothing) init_subst bndrs
         -- The substituted kind is zonked before it is returned.
       ; kind' <- zonkTcType (substTy subst kind)
       ; return (args, kind') }
  where
     init_subst = mkEmptyTCvSubst (mkInScopeSet (tyCoVarsOfType kind))
-- | Instantiate a single binder, returning the (possibly extended)
-- substitution together with the argument chosen for the binder.
--
-- The first argument optionally maps type variables to the kinds they
-- should be instantiated with; it is consulted only for named binders.
tcInstTyBinder :: Maybe (VarEnv Kind)
               -> TCvSubst -> TyBinder -> TcM (TCvSubst, TcType)
tcInstTyBinder mb_kind_info subst (Named (Bndr tv _))
  = case lookup_tv tv of
      -- The environment supplies an instantiation: use it as is.
      Just ki -> return (extendTvSubstAndInScope subst tv ki, ki)
      -- Otherwise make a fresh meta type variable.
      Nothing -> do { (subst', tv') <- newMetaTyVarX subst tv
                    ; return (subst', mkTyVarTy tv') }
  where
    -- Runs in the Maybe monad: Nothing if there is no environment at all,
    -- or if the environment has no entry for the variable.
    lookup_tv tv = do { env <- mb_kind_info
                      ; lookupVarEnv env tv }
tcInstTyBinder _ subst (Anon ty)
     -- An equality recognised by get_eq_tys_maybe: unify its two sides and
     -- build the argument from the resulting coercion.
  | Just (mk, k1, k2) <- get_eq_tys_maybe substed_ty
  = do { co <- unifyKind Nothing k1 k2
       ; arg' <- mk co
       ; return (subst, arg') }
     -- Any other predicate type is reported as an error.
  | isPredTy substed_ty
  = do { let (env, tidy_ty) = tidyOpenType emptyTidyEnv substed_ty
       ; addErrTcM (env, text "Illegal constraint in a kind:" <+> ppr tidy_ty)
         -- After reporting the error, make up a type variable of the
         -- offending type so that a result can still be returned.
       ; u <- newUnique
       ; let name = mkSysTvName u (fsLit "dict")
       ; return (subst, mkTyVarTy $ mkTyVar name substed_ty) }
     -- An ordinary anonymous binder: a fresh flexible type variable whose
     -- kind is the substituted binder type.
  | otherwise
  = do { tv_ty <- newFlexiTyVarTy substed_ty
       ; return (subst, tv_ty) }
  where
    substed_ty = substTy subst ty
      -- Recognise the equality forms handled above.  On success returns a
      -- function that builds the argument type from a coercion, together
      -- with the two sides that have to be unified.
    get_eq_tys_maybe :: Type
                     -> Maybe ( Coercion -> TcM Type
                                 -- given a coercion between the two
                                 -- sides, produce the argument type
                              , Type  -- first side (k1)
                              , Type  -- second side (k2)
                              )
    get_eq_tys_maybe ty
        -- A nominal equality predicate: the argument is the coercion itself.
      | Just (Nominal, k1, k2) <- getEqPredTys_maybe ty
      = Just (\co -> return $ mkCoercionTy co, k1, k2)
        -- A four-argument tycon application: accepted only for the tycon
        -- with heqTyConKey; any other four-argument application is rejected
        -- here without trying the remaining alternatives.
      | Just (tc, [_, _, k1, k2]) <- splitTyConApp_maybe ty
      = if | tc `hasKey` heqTyConKey
             -> Just (\co -> mkHEqBoxTy co k1 k2, k1, k2)
           | otherwise
             -> Nothing
        -- A three-argument tycon application: accepted only for the tycon
        -- with eqTyConKey.
      | Just (tc, [_, k1, k2]) <- splitTyConApp_maybe ty
      = if | tc `hasKey` eqTyConKey
             -> Just (\co -> mkEqBoxTy co k1 k2, k1, k2)
           | otherwise
             -> Nothing
      | otherwise
      = Nothing
-- | Apply the promoted 'heqDataCon' to the kinds of the two types, the two
-- types themselves, and the coercion (as a type).
mkHEqBoxTy :: TcCoercion -> Type -> Type -> TcM Type
mkHEqBoxTy co ty1 ty2
  = return (mkTyConApp (promoteDataCon heqDataCon) args)
  where
    args = [tcTypeKind ty1, tcTypeKind ty2, ty1, ty2, mkCoercionTy co]
-- | Apply the promoted 'eqDataCon' to the kind of the first type, the two
-- types themselves, and the coercion (as a type).
mkEqBoxTy :: TcCoercion -> Type -> Type -> TcM Type
mkEqBoxTy co ty1 ty2
  = return (mkTyConApp (promoteDataCon eqDataCon) args)
  where
    args = [tcTypeKind ty1, ty1, ty2, mkCoercionTy co]
-- | Type-check an overloaded literal against the expected type.
--
-- When the literal's extension field (bound to @rebindable@) is 'False',
-- 'shortCutLit' is tried first and, if it succeeds, its result becomes the
-- witness directly.  In every other case the work is delegated to
-- 'newNonTrivialOverloadedLit'.
newOverloadedLit :: HsOverLit GhcRn
                 -> ExpRhoType
                 -> TcM (HsOverLit GhcTcId)
newOverloadedLit
  lit@(OverLit { ol_val = val, ol_ext = rebindable }) res_ty
  | rebindable
  = newNonTrivialOverloadedLit orig lit res_ty

  | otherwise
  = do { tau    <- expTypeToType res_ty
       ; dflags <- getDynFlags
       ; case shortCutLit dflags val tau of
           Nothing      -> newNonTrivialOverloadedLit orig lit
                                                      (mkCheckExpType tau)
           Just witness -> return (lit { ol_witness = witness
                                       , ol_ext = OverLitTc False tau }) }
  where
    orig = LiteralOrigin lit
newOverloadedLit XOverLit{} _ = panic "newOverloadedLit"
-- | Type-check an overloaded literal whose witness is a variable naming the
-- conversion function: the function is checked with 'tcSyntaxOp' against
-- the type of the underlying literal and the expected result type, and the
-- new witness is that function applied to the literal.
--
-- Panics if the literal does not have that shape.
newNonTrivialOverloadedLit :: CtOrigin
                           -> HsOverLit GhcRn
                           -> ExpRhoType
                           -> TcM (HsOverLit GhcTcId)
newNonTrivialOverloadedLit orig
  lit@(OverLit { ol_val = val, ol_witness = HsVar _ (L _ meth_name)
               , ol_ext = rebindable }) res_ty
  = do  { raw_lit <- mkOverLit val
        ; let raw_ty = hsLitType raw_lit
        ; (_, conv) <- tcSyntaxOp orig (mkRnSyntaxExpr meth_name)
                                  [synKnownType raw_ty] res_ty
                                  (\_ -> return ())
        ; let witness = unLoc (nlHsSyntaxApps conv [nlHsLit raw_lit])
        ; final_ty <- readExpType res_ty
        ; return (lit { ol_witness = witness
                      , ol_ext = OverLitTc rebindable final_ty }) }
newNonTrivialOverloadedLit _ lit _
  = pprPanic "newNonTrivialOverloadedLit" (ppr lit)
-- | Turn the value of an overloaded literal into a plain literal.  Integral
-- and fractional values are tagged with the types named by
-- 'integerTyConName' and 'rationalTyConName' respectively; string values
-- need no type lookup.
mkOverLit :: OverLitVal -> TcM (HsLit GhcTc)
mkOverLit lit_val
  = case lit_val of
      HsIntegral i    -> fmap (HsInteger (il_text i) (il_value i))
                              (tcMetaTy integerTyConName)
      HsFractional r  -> fmap (HsRat noExt r)
                              (tcMetaTy rationalTyConName)
      HsIsString src s -> return (HsString src s)
-- | Type-check the expression that stands in for a standard name.
--
-- If the user's expression is a variable with the standard name itself,
-- the result comes from 'newMethodFromName'.  Otherwise the expression is
-- checked against the type of the standard identifier with its single
-- quantified variable replaced by the supplied type.
tcSyntaxName :: CtOrigin
             -> TcType                 -- type to instantiate the standard identifier at
             -> (Name, HsExpr GhcRn)   -- (standard name, user-supplied expression)
             -> TcM (Name, HsExpr GhcTcId)
                                       -- (standard name, type-checked expression)
tcSyntaxName orig ty (std_nm, HsVar _ (L _ user_nm))
  | std_nm == user_nm
  = (std_nm,) <$> newMethodFromName orig std_nm ty
tcSyntaxName orig ty (std_nm, user_nm_expr)
  = do { std_id <- tcLookupId std_nm
       ; let -- The standard identifier's type is expected to quantify over
             -- exactly one type variable; this lazy pattern fails otherwise.
             ([tv], _, tau) = tcSplitSigmaTy (idType std_id)
             sigma1         = substTyWith [tv] [ty] tau
       ; addErrCtxtM (syntaxNameCtxt user_nm_expr orig sigma1) $
         do { span <- getSrcSpanM
            ; expr <- tcPolyExpr (L span user_nm_expr) sigma1
            ; return (std_nm, unLoc expr) } }
-- | Error-context message used while checking a user-supplied expression
-- for a syntactic construct.  The tidy environment is returned unchanged.
syntaxNameCtxt :: HsExpr GhcRn -> CtOrigin -> Type -> TidyEnv
               -> TcRn (TidyEnv, SDoc)
syntaxNameCtxt name orig ty tidy_env
  = do { inst_loc <- getCtLocM orig (Just TypeLevel)
       ; return (tidy_env, mk_msg inst_loc) }
  where
    mk_msg loc
      = vcat [ text "When checking that" <+> quotes (ppr name)
               <+> text "(needed by a syntactic construct)"
             , nest 2 (text "has the required type:"
                       <+> ppr (tidyType tidy_env ty))
             , nest 2 (pprCtLoc loc) ]
-- | Compute the overlap flag for an instance.  The default mode comes from
-- the language extensions that are switched on (IncoherentInstances takes
-- precedence over OverlappingInstances); the optional explicit mode is then
-- combined with that default by 'setOverlapModeMaybe'.
getOverlapFlag :: Maybe OverlapMode -> TcM OverlapFlag
getOverlapFlag overlap_mode
  = do  { dflags <- getDynFlags
        ; let default_mode
                | xopt LangExt.IncoherentInstances  dflags = Incoherent NoSourceText
                | xopt LangExt.OverlappingInstances dflags = Overlaps NoSourceText
                | otherwise                                = NoOverlap NoSourceText
              default_oflag = OverlapFlag { isSafeOverlap = safeLanguageOn dflags
                                          , overlapMode   = default_mode }
        ; return (setOverlapModeMaybe default_oflag overlap_mode) }
-- | The class instances recorded in the @tcg_insts@ field of the current
-- global environment.
tcGetInsts :: TcM [ClsInst]
tcGetInsts
  = do { gbl_env <- getGblEnv
       ; return (tcg_insts gbl_env) }
-- | Construct a local class instance.  The instance head is built from
-- freshened copies of the type variables, while the dictionary function is
-- built from the variables and types as given.  Emits the orphan warning
-- (subject to 'Opt_WarnOrphans') when the new instance is an orphan.
newClsInst :: Maybe OverlapMode -> Name -> [TyVar] -> ThetaType
           -> Class -> [Type] -> TcM ClsInst
newClsInst overlap_mode dfun_name tvs theta clas tys
  = do { (subst, fresh_tvs) <- freshenTyVarBndrs tvs
       ; oflag <- getOverlapFlag overlap_mode
       ; let -- The dfun deliberately uses the unfreshened tvs and tys.
             dfun      = mkDictFunId dfun_name tvs theta clas tys
             fresh_tys = substTys subst tys
             inst      = mkLocalInstance dfun oflag fresh_tvs clas fresh_tys
       ; warnIfFlag Opt_WarnOrphans
                    (isOrphan (is_orphan inst))
                    (instOrphWarn inst)
       ; return inst }
-- | The warning text for an orphan instance: the instance header followed
-- by two suggestions for avoiding the orphan.
instOrphWarn :: ClsInst -> SDoc
instOrphWarn inst
  = hang (text "Orphan instance:") 2 (pprInstanceHdr inst)
    $$ text "To avoid this"
    $$ nest 4 (vcat suggestions)
  where
    suggestions =
      [ text "move the instance declaration to the module of the class or of the type, or"
      , text "wrap the type with a newtype and declare the instance on the new type."
      ]
-- | Run an action in a global environment extended with the given
-- instances.  Each instance is added (and checked) by 'addLocalInst',
-- which updates both the instance environment and the instance list.
tcExtendLocalInstEnv :: [ClsInst] -> TcM a -> TcM a
tcExtendLocalInstEnv dfuns thing_inside
 = do { traceDFuns dfuns
      ; gbl_env <- getGblEnv
      ; let start = (tcg_inst_env gbl_env, tcg_insts gbl_env)
      ; (inst_env', cls_insts') <- foldlM addLocalInst start dfuns
      ; setGblEnv (gbl_env { tcg_insts    = cls_insts'
                           , tcg_inst_env = inst_env' })
                  thing_inside }
-- | Add one instance to the home instance environment and to the list of
-- this module's instances, reporting (but not aborting on) two problems:
--
--   * functional-dependency conflicts found by 'checkFunDeps', and
--   * an existing instance with an identical head.
--
-- The extended environment and list are returned even when an error has
-- been reported.
addLocalInst :: (InstEnv, [ClsInst]) -> ClsInst -> TcM (InstEnv, [ClsInst])
addLocalInst (home_ie, my_insts) ispec
   = do { isGHCi  <- getIsGHCi
        ; eps     <- getEps
        ; tcg_env <- getGblEnv

          -- In GHCi the instance is first deleted from the home
          -- environment, so whatever 'deleteFromInstEnv' removes for it
          -- cannot trigger the checks below (presumably this lets an
          -- interactive redefinition replace an earlier one -- TODO confirm).
        ; let home_ie'
                | isGHCi    = deleteFromInstEnv home_ie ispec
                | otherwise = home_ie
              inst_envs = InstEnvs { ie_global  = eps_inst_env eps
                                   , ie_local   = home_ie'
                                   , ie_visible = tcVisibleOrphanMods tcg_env }

          -- Functional-dependency check.
        ; let inconsistent_ispecs = checkFunDeps inst_envs ispec
        ; unless (null inconsistent_ispecs) $
          funDepErr ispec inconsistent_ispecs

          -- Duplicate check: among the instances matching this head, look
          -- for one whose head is identical.
        ; let (_tvs, cls, tys) = instanceHead ispec
              (matches, _, _)  = lookupInstEnv False inst_envs cls tys
              dups             = filter (identicalClsInstHead ispec) (map fst matches)
          -- Only the first duplicate is reported.  Matching on the list
          -- avoids calling the partial 'head' behind a separate null test.
        ; case dups of
            dup : _ -> dupInstErr ispec dup
            []      -> return ()

        ; return (extendInstEnv home_ie' ispec, ispec : my_insts) }
-- | Trace the instances about to be added, one entry per instance: its
-- dictionary function id, a colon, and the instance itself.
traceDFuns :: [ClsInst] -> TcRn ()
traceDFuns ispecs
  = traceTc "Adding instances:" (vcat entries)
  where
    entries = [ hang (ppr (instanceDFunId ispec) <+> colon) 2 (ppr ispec)
              | ispec <- ispecs ]
        
-- | Report a functional-dependency conflict between the given instance and
-- the instances it is inconsistent with.
funDepErr :: ClsInst -> [ClsInst] -> TcRn ()
funDepErr ispec ispecs
  = addClsInstsErr herald (ispec : ispecs)
  where
    herald = text "Functional dependencies conflict between instance declarations:"
-- | Report that two instance declarations are duplicates of each other.
dupInstErr :: ClsInst -> ClsInst -> TcRn ()
dupInstErr ispec dup_ispec
  = addClsInstsErr herald [ispec, dup_ispec]
  where
    herald = text "Duplicate instance declarations:"
-- | Report an error about a group of instances.  The instances are sorted
-- by source location; the error is attached to the location of the first
-- one and lists all of them under the given herald.
--
-- The list must be non-empty.  An empty list is an internal error and
-- raises a named panic rather than an anonymous failure of 'head'.
addClsInstsErr :: SDoc -> [ClsInst] -> TcRn ()
addClsInstsErr herald ispecs
  = case sortWith getSrcLoc ispecs of
      []                    -> pprPanic "addClsInstsErr: no instances" herald
      sorted@(earliest : _) -> setSrcSpan (getSrcSpan earliest) $
                               addErr (hang herald 2 (pprInstances sorted))