{-# LANGUAGE FlexibleInstances #-}
module Intensional.InferM
( InferM,
Context,
InferEnv (..),
Stats (..),
runInferM,
Intensional.InferM.saturate,
branchAny,
emitDD,
emitDK,
emitKD,
fresh,
putVar,
putVars,
setLoc,
getExternalName,
isTrivial,
isIneligible,
noteD,
noteK,
incrN,
noteErrs,
Intensional.InferM.cexs
)
where
import Intensional.Constraints as Constraints
import Intensional.Constructors
import Control.Monad.RWS.Strict hiding (guard)
import qualified Data.IntSet as IntSet
import qualified Data.Map as M
import GhcPlugins hiding ((<>), singleton)
import Intensional.Scheme
import Intensional.Types
import Intensional.Ubiq
import Intensional.Guard
-- | The inference monad: reads an 'InferEnv', accumulates emitted
-- constraints as a 'ConstraintSet' (writer), and threads an 'InferState'.
type InferM = RWS InferEnv ConstraintSet InferState
-- | Maps in-scope names to their type schemes.
type Context = M.Map Name Scheme
-- | Read-only environment for inference.
data InferEnv = InferEnv
  { -- | Module passed to 'runInferM'; used to tag emitted constraints and
    -- to build external names.
    modName :: Module,
    -- | Schemes of the variables currently in scope.
    varEnv :: Context,
    -- | Source location attached to constraints emitted at this point
    -- (changed with 'setLoc').
    inferLoc :: SrcSpan
  }
-- | Bookkeeping threaded through inference.  The counter fields are strict
-- so that the updates performed on every step do not accumulate thunks.
data InferState = InferState
  { -- | Largest value passed to 'noteK'.
    maxK :: !Int,
    -- | Largest value passed to 'noteD'.
    maxD :: !Int,
    -- | Largest value passed to 'noteI' (interface sizes seen by 'saturate').
    maxI :: !Int,
    -- | Number of calls to 'incrN'.
    cntN :: !Int,
    -- | Next fresh variable handed out by 'fresh'.
    rVar :: !Int,
    -- | Constraints accumulated with 'noteErrs'.
    errs :: ConstraintSet
  }
-- | Starting state: every counter at zero and no recorded constraints.
initState :: InferState
initState =
  InferState
    { maxK = 0,
      maxD = 0,
      maxI = 0,
      cntN = 0,
      rVar = 0,
      errs = mempty
    }
-- | Raise 'maxK' to @x@ if @x@ exceeds the current value.
-- 'modify'' forces the new state so repeated calls do not build thunks.
noteK :: Int -> InferM ()
noteK x = modify' (\s -> s {maxK = max x (maxK s)})
-- | Raise 'maxD' to @x@ if @x@ exceeds the current value.
-- 'modify'' forces the new state so repeated calls do not build thunks.
noteD :: Int -> InferM ()
noteD x = modify' (\s -> s {maxD = max x (maxD s)})
-- | Raise 'maxI' to @x@ if @x@ exceeds the current value.
-- 'modify'' forces the new state so repeated calls do not build thunks.
noteI :: Int -> InferM ()
noteI x = modify' (\s -> s {maxI = max x (maxI s)})
-- | Advance the fresh-variable counter 'rVar' by one.
-- 'modify'' forces the new state so repeated calls do not build thunks.
incrV :: InferM ()
incrV = modify' (\s -> s {rVar = rVar s + 1})
-- | Increment the 'cntN' counter by one.
-- 'modify'' forces the new state so repeated calls do not build thunks.
incrN :: InferM ()
incrN = modify' (\s -> s {cntN = cntN s + 1})
-- | Prepend @es@ (via '<>') to the constraints accumulated in 'errs'.
-- These are returned by 'runInferM' as its second component.
noteErrs :: ConstraintSet -> InferM ()
noteErrs es = modify' (\s -> s {errs = es <> errs s})
-- | Summary counters returned by 'runInferM', copied from the final state.
data Stats = Stats
  { -- | Final 'maxK'.
    getK :: Int,
    -- | Final 'maxD'.
    getD :: Int,
    -- | Final 'rVar', i.e. the number of variables allocated by 'fresh'.
    getV :: Int,
    -- | Final 'maxI'.
    getI :: Int,
    -- | Final 'cntN'.
    getN :: Int
  }
-- | Run an inference computation for module @mod_name@ with initial
-- variable context @init_env@.
--
-- Returns the result, the constraints recorded with 'noteErrs', and the
-- final counters as 'Stats'.  Constraints emitted through the writer are
-- discarded.  The initial location is an unhelpful span ("Nowhere").
runInferM ::
  InferM a ->
  Module ->
  Context ->
  (a, [Atomic], Stats)
runInferM run mod_name init_env =
  let env0 = InferEnv mod_name init_env (UnhelpfulSpan (mkFastString "Nowhere"))
      (a, s, _) = runRWS run env0 initState
      stats = Stats (maxK s) (maxD s) (rVar s) (maxI s) (cntN s)
   in (a, Constraints.toList (errs s), stats)
-- | Run @ma@, then replace the constraints it emitted with their saturation
-- ('Constraints.saturate') relative to an interface.  The interface is the
-- domain of the result joined with the domain of the current variable
-- environment.
--
-- The interface size is recorded with 'noteI'.  When 'debugging' is set,
-- the saturation is bracketed by BEGIN/END trace messages.
saturate :: Refined a => InferM a -> InferM a
saturate ma = pass $ do
  a <- ma
  env <- asks varEnv
  m <- asks modName
  src <- asks inferLoc
  let interface = domain a <> domain env
  noteI (IntSet.size interface)
  let fn cs =
        let ds = Constraints.saturate (CInfo m src) interface cs
         in if debugging then debugBracket a env src cs ds else ds
  return (a, fn)
  where
    -- Emit trace messages around forcing the saturated set @ds@ (computed
    -- from @cs@) and return @ds@ unchanged.
    debugBracket ::
      (Refined b, Refined c) =>
      b ->
      c ->
      SrcSpan ->
      ConstraintSet ->
      ConstraintSet ->
      ConstraintSet
    debugBracket a env src cs ds =
      let asz = "type: " ++ show (IntSet.size $ domain a)
          esz = "env: " ++ show (IntSet.size $ domain env)
          csz = show (size cs)
          spn = traceSpan src
          tmsg = "#interface = (" ++ asz ++ " + " ++ esz ++ "), #constraints = " ++ csz
          -- Forcing ds' prints BEGIN and then evaluates ds to WHNF.
          ds' = trace ("[TRACE] BEGIN saturate at " ++ spn ++ ": " ++ tmsg) ds
       in ds'
            `seq` trace
              ("[TRACE] END saturate at " ++ spn ++ " saturated size: " ++ (show $ size ds))
              ds
-- | Apply 'Constraints.cexs' to @cs@, using a 'CInfo' built from the
-- current module and source location.
cexs :: ConstraintSet -> InferM ConstraintSet
cexs cs = do
  m <- asks modName
  src <- asks inferLoc
  return $ Constraints.cexs (CInfo m src) cs
-- | Returns 'True' when @tc@ is not from the home package of the current
-- module, or when it has no data constructors.
isIneligible :: TyCon -> InferM Bool
isIneligible tc = do
  m <- asks modName
  return (not (homeOrBase m (getName tc)) || null (tyConDataCons tc))
  where
    -- NOTE(review): despite its name, this only checks 'nameIsHomePackage';
    -- names from @base@ get no special treatment here. Confirm intent.
    homeOrBase :: Module -> Name -> Bool
    homeOrBase m n = nameIsHomePackage m n
-- | A type constructor is trivial when it has exactly one data constructor.
-- Pattern matching avoids walking the whole constructor list with 'length'.
isTrivial :: TyCon -> Bool
isTrivial tc = case tyConDataCons tc of
  [_] -> True
  _ -> False
-- | Run @m@ inside a branch on constructors @ks@ of the datatype @Inj x d@.
--
-- Every constraint @m@ emits is copied once per constructor @k@ in @ks@.
-- Each copy is guarded with @singleton [k] x d@.  @m@ runs unchanged for
-- 'Base' types and for trivial (single-constructor) type constructors.
-- Note that an empty @ks@ therefore drops every constraint @m@ emits.
branchAny :: [DataCon] -> DataType TyCon -> InferM a -> InferM a
branchAny _ (Base _) m = m
branchAny ks (Inj x d) m
  | isTrivial d = m
  | otherwise = censor guardWithAll m
  where
    dn = getName d
    guardWithAll cs =
      foldMap (\k -> Constraints.guardWith (singleton [getName k] x dn) cs) ks
-- | Build the constraint @l@ / @r@ with the empty ('mempty') guard.
-- The constraint is tagged with the current module and source location.
mkConFromCtx :: ConL -> ConR -> InferM Atomic
mkConFromCtx l r = do
  m <- asks modName
  s <- asks inferLoc
  return (Constraint l r mempty (CInfo m s))
-- | Emit a constraint from @Dom (Inj x d)@ to @Dom (Inj y d)@, unless @d@ is
-- trivial.  Does nothing if either argument is not an 'Inj'.
-- NOTE(review): the second argument's type constructor is ignored and
-- @d@'s name is used on both sides; presumably both are the same type.
emitDD :: DataType TyCon -> DataType TyCon -> InferM ()
emitDD (Inj x d) (Inj y _) =
  unless (isTrivial d) $ do
    a <- mkConFromCtx (Dom (Inj x dn)) (Dom (Inj y dn))
    tell (Constraints.fromList [a])
  where
    dn = getName d
emitDD _ _ = return ()
-- | Emit a constraint from constructor @k@ (at span @s@) to @Dom (Inj x d)@,
-- unless @d@ is trivial.  Does nothing for non-'Inj' datatypes.
emitKD :: DataCon -> SrcSpan -> DataType TyCon -> InferM ()
emitKD k s (Inj x d) =
  unless (isTrivial d) $ do
    a <- mkConFromCtx (Con kn s) (Dom (Inj x dn))
    tell (Constraints.fromList [a])
  where
    dn = getName d
    kn = getName k
emitKD _ _ _ = return ()
-- | Emit a constraint from @Dom (Inj x d)@ to the constructor set @ks@
-- (at span @s@).
--
-- Nothing is emitted when @d@ is trivial, or when @ks@ names as many distinct
-- constructors as @d@ has.  The count uses the deduplicated set, so repeated
-- entries in @ks@ cannot make it look complete.  Does nothing for non-'Inj'
-- datatypes.
emitDK :: DataType TyCon -> [DataCon] -> SrcSpan -> InferM ()
emitDK (Inj x d) ks s =
  unless (isTrivial d || sizeUniqSet ksn == length (tyConDataCons d)) $ do
    a <- mkConFromCtx (Dom (Inj x dn)) (Set ksn s)
    tell (Constraints.fromList [a])
  where
    dn = getName d
    ksn = mkUniqSet (map getName ks)
emitDK _ _ _ = return ()
-- | Allocate a fresh variable.  Returns the current 'rVar' counter and
-- advances it by one.
fresh :: InferM RVar
fresh = do
  i <- gets rVar
  incrV
  return i
-- | Run a computation with @n@ bound to scheme @s@ in 'varEnv'.
-- Any existing binding for @n@ is replaced.
putVar :: Name -> Scheme -> InferM a -> InferM a
putVar n s = local (\env -> env {varEnv = M.insert n s (varEnv env)})
-- | Run a computation with all bindings of @ctx@ added to 'varEnv'.
-- 'M.union' is left-biased, so @ctx@ wins where names overlap.
putVars :: Context -> InferM a -> InferM a
putVars ctx = local (\env -> env {varEnv = M.union ctx (varEnv env)})
-- | Run a computation with 'inferLoc' set to @l@, so constraints it emits
-- are tagged with that location.
setLoc :: SrcSpan -> InferM a -> InferM a
setLoc l = local (\env -> env {inferLoc = l})
-- | Build an external name in the current module for @a@.  The new name
-- keeps @a@'s unique, occurrence name and source span.
getExternalName :: NamedThing a => a -> InferM Name
getExternalName a = do
  let n = getName a
  mn <- asks modName
  return $ mkExternalName (nameUnique n) mn (nameOccName n) (nameSrcSpan n)