{-# Language FlexibleInstances, PatternGuards #-}
-- | Assumes that local names do not shadow top level names.
module Cryptol.ModuleSystem.InstantiateModule
  ( instantiateModule
  ) where

import           Data.Set (Set)
import qualified Data.Set as Set
import           Data.Map (Map)
import qualified Data.Map as Map
import           MonadLib(ReaderT,runReaderT,ask)

import Cryptol.Parser.Position(Located(..))
import Cryptol.ModuleSystem.Name
import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Subst(listParamSubst, apSubst)
import Cryptol.Utils.Ident(ModName,modParamIdent)

{-
XXX: Should we simplify constraints in the instantiated modules?

If so, we also need to adjust the constraint parameters on terms appropriately,
especially when working with dictionaries.
-}


-- | Convert a module instantiation into a partial module.
-- The resulting module is incomplete because it is missing the definitions
-- from the instantiation.
instantiateModule :: FreshM m =>
                     Module           {- ^ Parametrized module -} ->
                     ModName          {- ^ Name of the new module -} ->
                     Map TParam Type  {- ^ Type params -} ->
                     Map Name Expr    {- ^ Value parameters -} ->
                     m ([Located Prop], Module)
                     -- ^ Instantiated constraints, fresh module, new supply
instantiateModule func newName tpMap vpMap =
  runReaderT newName $
    do let oldVpNames = Map.keys vpMap
       newVpNames <- mapM freshParamName (Map.keys vpMap)
       let vpNames = Map.fromList (zip oldVpNames newVpNames)

       env <- computeEnv func tpMap vpNames

       let rnMp :: Inst a => (a -> Name) -> Map Name a -> Map Name a
           rnMp f m = Map.fromList [ (f x, x) | a <- Map.elems m
                                              , let x = inst env a ]

           renamedExports  = inst env (mExports func)
           renamedTySyns   = rnMp tsName (mTySyns func)
           renamedNewtypes = rnMp ntName (mNewtypes func)
           renamedPrimTys  = rnMp atName (mPrimTypes func)

           su = listParamSubst (Map.toList (tyParamMap env))

           goals = map (fmap (apSubst su)) (mParamConstraints func)
           -- Constraints to discharge about the type instances

       let renamedDecls = inst env (mDecls func)
           paramDecls = map (mkParamDecl su vpNames) (Map.toList vpMap)

       return ( goals
              , Module
                 { mName              = newName
                 , mExports           = renamedExports
                 , mImports           = mImports func
                 , mTySyns            = renamedTySyns
                 , mNewtypes          = renamedNewtypes
                 , mPrimTypes         = renamedPrimTys
                 , mParamTypes        = Map.empty
                 , mParamConstraints  = []
                 , mParamFuns         = Map.empty
                 , mDecls             = paramDecls ++ renamedDecls
                 } )

  where
  mkParamDecl su vpNames (x,e) =
      NonRecursive Decl
        { dName        = Map.findWithDefault (error "OOPS") x vpNames
        , dSignature   = apSubst su
                       $ mvpType
                       $ Map.findWithDefault (error "UUPS") x (mParamFuns func)
        , dDefinition  = DExpr e
        , dPragmas     = []      -- XXX: which if any pragmas?
        , dInfix       = False   -- XXX: get from parameter?
        , dFixity      = Nothing -- XXX: get from parameter
        , dDoc         = Nothing -- XXX: get from parametr(or instance?)
        }


--------------------------------------------------------------------------------
-- Things that need to be renamed

class Defines t where
  defines :: t -> Set Name

instance Defines t => Defines [t] where
  defines = Set.unions . map defines

instance Defines Decl where
  defines = Set.singleton . dName

instance Defines DeclGroup where
  defines d =
    case d of
      NonRecursive x -> defines x
      Recursive x    -> defines x


--------------------------------------------------------------------------------

type InstM = ReaderT ModName

-- | Generate a new instance of a declared name.
freshenName :: FreshM m => Name -> InstM m Name
freshenName x =
  do m <- ask
     let sys = case nameInfo x of
                 Declared _ s -> s
                 _            -> UserName
     liftSupply (mkDeclared m sys (nameIdent x) (nameFixity x) (nameLoc x))

freshParamName :: FreshM m => Name -> InstM m Name
freshParamName x =
  do m <- ask
     let newName = modParamIdent (nameIdent x)
     liftSupply (mkDeclared m UserName newName (nameFixity x) (nameLoc x))




-- | Compute renaming environment from a module instantiation.
-- computeEnv :: ModInst -> InstM Env
computeEnv :: FreshM m =>
              Module {- ^ Functor being instantiated -} ->
              Map TParam Type {- replace type params by type -} ->
              Map Name Name {- replace value parameters by other names -} ->
              InstM m Env
computeEnv m tpMap vpMap =
  do tss <- mapM freshTy (Map.toList (mTySyns m))
     nts <- mapM freshTy (Map.toList (mNewtypes m))
     let tnMap = Map.fromList (tss ++ nts)

     defHere <- mapM mkVParam (Set.toList (defines (mDecls m)))
     let fnMap = Map.union vpMap (Map.fromList defHere)

     return Env { funNameMap  = fnMap
                , tyNameMap   = tnMap
                , tyParamMap  = tpMap
                }
  where
  freshTy (x,_) = do y <- freshenName x
                     return (x,y)

  mkVParam x = do y <- freshenName x
                  return (x,y)



--------------------------------------------------------------------------------
-- Do the renaming

data Env = Env
  { funNameMap  :: Map Name Name
  , tyNameMap   :: Map Name Name
  , tyParamMap  :: Map TParam Type
  } deriving Show


class Inst t where
  inst :: Env -> t -> t

instance Inst a => Inst [a] where
  inst env = map (inst env)

instance Inst Expr where
  inst env = go
    where
    go expr =
      case expr of
        EVar x -> case Map.lookup x (funNameMap env) of
                    Just y -> EVar y
                    _      -> expr

        EList xs t                -> EList (inst env xs) (inst env t)
        ETuple es                 -> ETuple (inst env es)
        ERec xs                   -> ERec (fmap go xs)
        ESel e s                  -> ESel (go e) s
        ESet e x v                -> ESet (go e) x (go v)
        EIf e1 e2 e3              -> EIf (go e1) (go e2) (go e3)
        EComp t1 t2 e mss         -> EComp (inst env t1) (inst env t2)
                                           (go e)
                                           (inst env mss)
        ETAbs t e                 -> ETAbs t (go e)
        ETApp e t                 -> ETApp (go e) (inst env t)
        EApp e1 e2                -> EApp (go e1) (go e2)
        EAbs x t e                -> EAbs x (inst env t) (go e)
        EProofAbs p e             -> EProofAbs (inst env p) (go e)
        EProofApp e               -> EProofApp (go e)
        EWhere e ds               -> EWhere (go e) (inst env ds)


instance Inst DeclGroup where
  inst env dg =
    case dg of
      NonRecursive d -> NonRecursive (inst env d)
      Recursive ds   -> Recursive (inst env ds)

instance Inst DeclDef where
  inst env d =
    case d of
      DPrim   -> DPrim
      DExpr e -> DExpr (inst env e)

instance Inst Decl where
  inst env d = d { dSignature = inst env (dSignature d)
                 , dDefinition = inst env (dDefinition d)
                 , dName = Map.findWithDefault (dName d) (dName d)
                                                         (funNameMap env)
                 }

instance Inst Match where
  inst env m =
    case m of
      From x t1 t2 e -> From x (inst env t1) (inst env t2) (inst env e)
      Let d          -> Let (inst env d)

instance Inst Schema where
  inst env s = s { sProps = inst env (sProps s)
                 , sType  = inst env (sType s)
                 }

instance Inst Type where
  inst env ty =
    case ty of
      TCon tc ts    -> TCon (inst env tc) (inst env ts)
      TVar tv       ->
        case tv of
          TVBound tp | Just t <- Map.lookup tp (tyParamMap env) -> t
          _ -> ty
      TUser x ts t  -> TUser y (inst env ts) (inst env t)
        where y = Map.findWithDefault x x (tyNameMap env)
      TRec fs       -> TRec (fmap (inst env) fs)

instance Inst TCon where
  inst env tc =
    case tc of
      TC x -> TC (inst env x)
      _    -> tc

instance Inst TC where
  inst env tc =
    case tc of
      TCNewtype x  -> TCNewtype (inst env x)
      TCAbstract x -> TCAbstract (inst env x)
      _            -> tc

instance Inst UserTC where
  inst env (UserTC x t) = UserTC y t
    where y = Map.findWithDefault x x (tyNameMap env)

instance Inst (ExportSpec Name) where
  inst env es = ExportSpec { eTypes = Set.map instT (eTypes es)
                           , eBinds = Set.map instV (eBinds es)
                           }
    where instT x = Map.findWithDefault x x (tyNameMap env)
          instV x = Map.findWithDefault x x (funNameMap env)

instance Inst TySyn where
  inst env ts = TySyn { tsName = instTyName env x
                      , tsParams = tsParams ts
                      , tsConstraints = inst env (tsConstraints ts)
                      , tsDef = inst env (tsDef ts)
                      , tsDoc = tsDoc ts
                      }
    where x = tsName ts

instance Inst Newtype where
  inst env nt = Newtype { ntName = instTyName env x
                        , ntParams = ntParams nt
                        , ntConstraints = inst env (ntConstraints nt)
                        , ntFields = [ (f,inst env t) | (f,t) <- ntFields nt ]
                        , ntDoc = ntDoc nt
                        }
    where x = ntName nt

instance Inst AbstractType where
  inst env a = AbstractType { atName    = instTyName env (atName a)
                            , atKind    = atKind a
                            , atCtrs    = case atCtrs a of
                                            (xs,ps) -> (xs, inst env ps)
                            , atFixitiy = atFixitiy a
                            , atDoc     = atDoc a
                            }

instTyName :: Env -> Name -> Name
instTyName env x = Map.findWithDefault x x (tyNameMap env)