{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TupleSections         #-}
{-# LANGUAGE TypeSynonymInstances  #-}
{-# LANGUAGE UndecidableInstances  #-}
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE ConstraintKinds       #-}

module Language.Haskell.Liquid.Constraint.Fresh
  ( -- module Language.Haskell.Liquid.Types.Fresh
    -- , 
    refreshArgsTop
  , freshTy_type
  , freshTy_expr
  , trueTy
  , addKuts
  )
  where

-- import           Data.Maybe                    (catMaybes) -- , fromJust, isJust)
-- import           Data.Bifunctor
-- import qualified Data.List                      as L
import qualified Data.HashMap.Strict            as M
import qualified Data.HashSet                   as S
import           Data.Hashable
import           Control.Monad.State            (gets, get, put, modify)
import           Control.Monad                  (when, (>=>))
import           Prelude                        hiding (error)

import           CoreUtils  (exprType)
import           Type       (Type)
import           CoreSyn
import           Var        (varType, isTyVar, Var)

import           Language.Fixpoint.Misc  ((=>>))
import qualified Language.Fixpoint.Types as F
import           Language.Fixpoint.Types.Visitor (kvars)
import           Language.Haskell.Liquid.Types
-- import           Language.Haskell.Liquid.Types.RefType
-- import           Language.Haskell.Liquid.Types.Fresh
import           Language.Haskell.Liquid.Constraint.Types
import qualified Language.Haskell.Liquid.GHC.Misc as GM 

--------------------------------------------------------------------------------
-- | This is all hardwiring stuff to CG ----------------------------------------
--------------------------------------------------------------------------------
-- | Fresh integers in 'CG': hand out the current 'freshIndex' and bump the
--   counter stored in the 'CGInfo' state by one.
instance Freshable CG Integer where
  fresh = do
    cur <- gets freshIndex
    modify $ \st -> st { freshIndex = cur + 1 }
    return cur

--------------------------------------------------------------------------------
-- | Refresh the argument binders of a top-level signature for @x@.
--   The renaming substitution produced by 'refreshArgsSub' is also applied to
--   the termination expressions recorded for @x@ in 'termExprs' (a no-op when
--   @x@ has no entry), so that they mention the refreshed binder names.
refreshArgsTop :: (Var, SpecType) -> CG SpecType
--------------------------------------------------------------------------------
refreshArgsTop (x, t) = do
  (t', su) <- refreshArgsSub t
  modify $ \st -> st { termExprs = M.adjust (fmap (F.subst su)) x (termExprs st) }
  return t'

--------------------------------------------------------------------------------
-- | Generation: Freshness -----------------------------------------------------
--------------------------------------------------------------------------------

-- | Right now, we generate NO new pvars. Rather than cluttering the code
--   with `uRType` calls, we put that logic in one place where the above
--   invariant is /obviously/ enforced.
--   Constraint generation should ONLY use @freshTy_type@ and @freshTy_expr@.

freshTy_type :: KVKind -> CoreExpr -> Type -> CG SpecType
freshTy_type k e τ = do
  rt <- freshTy_reftype k (ofType τ)
  return (F.notracepp msg rt)
  where
    -- The expression @e@ only feeds this debug message.
    msg = "freshTy_type: " ++ F.showpp k ++ GM.showPpr e

-- | Fresh refinement type for an expression, built from the shape computed
--   by 'exprRefType' (the 'Type' argument is ignored).
freshTy_expr :: KVKind -> CoreExpr -> Type -> CG SpecType
freshTy_expr k e _ = freshTy_reftype k shape
  where
    shape = exprRefType e

-- | Attach tycon information to @t@ ('fixTy'), refresh it ('refresh'), and
--   then register the kvars of the refreshed type under kind @k@
--   ('addKVars'). The refreshed type is returned.
freshTy_reftype :: KVKind -> SpecType -> CG SpecType
freshTy_reftype k t = refreshed =>> addKVars k
  where
    -- To debug, wrap @t@ with: F.tracepp ("freshTy_reftype:" ++ show k)
    refreshed = refresh =<< fixTy t

-- | Record the kvars of @t@ under kind @k@. The kvars are always added to
--   the KVar profile ('kvProf'). When 'isKut' says so for @k@, they are
--   additionally registered as "cut" kvars for fixpoint via 'addKuts'.
--   Typically, the cut kvars are those of recursive definitions.
addKVars        :: KVKind -> SpecType -> CG ()
addKVars !k !t  = do
    cfg <- getConfig <$> gets ghcI
    -- Profile update is unconditional (previously written as `when True`).
    modify $ \s -> s { kvProf = updKVProf k ks (kvProf s) }
    when (isKut cfg k) $ addKuts k t
  where
    ks = F.KS $ S.fromList $ specTypeKVars t

-- | Should kvars of this kind be cut? Recursive-binder kvars always are.
--   Projection kvars are cut only when higher-order mode is off.
isKut :: Config -> KVKind -> Bool
isKut cfg kind = case kind of
  RecBindE _ -> True
  ProjectE   -> not (higherOrderFlag cfg) -- see ISSUE 1034, tests/pos/T1034.hs
  _          -> False

-- | Add every kvar occurring in @t@ to the set of cut kvars ('kuts').
--   The first argument is not used by the computation. It exists only to
--   label the (disabled) debug trace, and the 'PPrint' constraint is kept so
--   the exported signature is unchanged.
addKuts :: (PPrint a) => a -> SpecType -> CG ()
addKuts _x t = modify $ \s -> s { kuts = mappend (F.KS ks) (kuts s) }
  where
    -- Previously guarded by `S.null`, but both branches returned the same set.
    -- To debug, wrap with: F.tracepp ("addKuts: " ++ showpp _x)
    ks = S.fromList $ specTypeKVars t

-- | All kvars appearing in the refinements of a 'SpecType', gathered with
--   'foldReft' (duplicates are not removed).
specTypeKVars :: SpecType -> [F.KVar]
specTypeKVars = foldReft False step []
  where
    step _ r acc = kvars (ur_reft r) ++ acc

--------------------------------------------------------------------------------
-- | Build a 'SpecType' from a GHC 'Type' (with tycon info attached by
--   'ofType'') and then pass it through 'true'.
trueTy  :: Type -> CG SpecType
--------------------------------------------------------------------------------
trueTy = ofType' >=> true

-- | Convert a GHC 'Type' to an unrefined 'SpecType' and attach tycon info.
ofType' :: Type -> CG SpecType
ofType' τ = fixTy (ofType τ)

-- | Decorate a 'SpecType' with the tycon map ('tyConInfo') and tycon
--   embedding ('tyConEmbed') read from the current 'CGInfo' state.
fixTy :: SpecType -> CG SpecType
fixTy t = do
  tce <- gets tyConEmbed
  tyi <- gets tyConInfo
  return (addTyConInfo tce tyi t)

-- | Shape of a Core expression as a 'SpecType', starting from an empty
--   environment of let-bound types.
exprRefType :: CoreExpr -> SpecType
exprRefType e = exprRefType_ M.empty e

-- | Compute the 'SpecType' shape of an expression under an environment @γ@
--   of types for let-bound variables:
--
--   * @let@: extend @γ@ with the binding(s) and recurse into the body.
--   * type lambda: wrap the body's shape in 'RAllT'.
--   * value lambda: 'rFun' from the binder's (unrefined) 'varType'.
--   * @Tick@: skipped.
--   * variable: looked up in @γ@, falling back to its 'varType'.
--   * anything else: the unrefined 'exprType'.
exprRefType_ :: M.HashMap Var SpecType -> CoreExpr -> SpecType
exprRefType_ γ expr = case expr of
  Let b e             -> exprRefType_ (bindRefType_ γ b) e
  Lam α e | isTyVar α -> RAllT (makeRTVar (rTyVar α)) (exprRefType_ γ e) mempty
  Lam x e             -> rFun (F.symbol x) (ofType (varType x)) (exprRefType_ γ e)
  Tick _ e            -> exprRefType_ γ e
  Var x               -> M.lookupDefault (ofType (varType x)) x γ
  _                   -> ofType (exprType expr)

-- | Extend @γ@ with the shapes of the binders of a Core binding group.
--   Every right-hand side, including those of a recursive group, is
--   analysed in the original (un-extended) @γ@.
bindRefType_ :: M.HashMap Var SpecType -> Bind Var -> M.HashMap Var SpecType
bindRefType_ γ bind = extendγ γ (annotate <$> pairs)
  where
    pairs = case bind of
      Rec xes    -> xes
      NonRec x e -> [(x, e)]
    annotate (x, e) = (x, exprRefType_ γ e)

-- | Insert every @(key, value)@ pair into the map, overriding existing
--   entries. Because this is a right fold, when a key appears more than once
--   in @xts@ its earliest occurrence wins.
extendγ :: (Eq k, Foldable t, Hashable k)
        => M.HashMap k v
        -> t (k, v)
        -> M.HashMap k v
extendγ γ xts = foldr (uncurry M.insert) γ xts