{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleContexts #-}
module JL.Inferer where
import Control.Monad.State.Strict
import qualified Data.HashMap.Strict as HM
import Data.Map (Map)
import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as S
import Data.Text (Text)
import qualified Data.Text as T
import JL.Printer
import JL.Types
-- | Infer the type of an expression.
--
-- Runs 'check' over the expression to obtain its type together with a
-- set of equality constraints, solves the constraints with 'unify',
-- and applies the resulting substitution to the type via 'replace'.
--
-- The third argument is the supply of fresh type variables consumed
-- by 'check'. Scope errors, unification failures and an exhausted
-- supply are reported by the callees through 'error'.
infer :: Map Variable Type -> Expression -> [TypeVariable] -> Type
infer ctx t stream = replace solution ty
  where
    (ty, constraints) = evalState (check ctx t) stream
    -- 'unify' only needs some 'Monad'; a stateless 'State' is used to
    -- run it.
    solution = evalState (unify (S.toList constraints)) ()
-- | Compute the type of an expression along with the equality
-- constraints that must hold for that type to be valid.
--
-- The first argument maps the variables in scope to their types.
-- Fresh type variables are drawn from the monadic state through
-- 'generateTypeVariable'. A variable that is missing from the
-- context aborts with 'error'.
--
-- Each constraint is a pair of types that 'unify' must make equal.
check
  :: MonadState [TypeVariable] m
  => Map Variable Type -> Expression -> m (Type, Set (Type, Type))
check ctx expr =
  case expr of
    VariableExpression name@(Variable text) ->
      case M.lookup name ctx of
        Nothing -> error ("Not in scope: `" <> T.unpack text <> "'")
        Just typ -> return (typ, mempty)
    LambdaExpression x body -> do
      -- The parameter gets a fresh type; the body is checked with the
      -- parameter bound in the context.
      xty <- VariableType <$> generateTypeVariable
      (rty, cs) <- check (M.insert x xty ctx) body
      return (FunctionType xty rty, cs)
    ConstantExpression {} -> return (JSONType, mempty)
    ApplicationExpression f x -> do
      (fty, cs1) <- check ctx f
      (xty, cs2) <- check ctx x
      rty <- VariableType <$> generateTypeVariable
      -- The function must accept the argument and yield the result.
      return (rty, S.insert (fty, FunctionType xty rty) (cs1 <> cs2))
    InfixExpression l f r -> do
      -- The operator is looked up like any other variable and treated
      -- as a curried two-argument function.
      (fty, cs1) <- check ctx (VariableExpression f)
      (ty1, cs2) <- check ctx l
      (ty2, cs3) <- check ctx r
      rty <- VariableType <$> generateTypeVariable
      return
        ( rty
        , S.insert
            (fty, FunctionType ty1 (FunctionType ty2 rty))
            (cs1 <> cs2 <> cs3))
    IfExpression cond a b -> do
      (condty, cs1) <- check ctx cond
      (aty, cs2) <- check ctx a
      (bty, cs3) <- check ctx b
      rty <- VariableType <$> generateTypeVariable
      let cs =
            S.fromList
              [ (condty, JSONType) -- the condition is a JSON value
              , (aty, bty) -- both branches agree
              , (rty, aty) -- the result has the branches' type
              ] <>
            cs1 <>
            cs2 <>
            cs3
      pure (rty, cs)
    RecordExpression pairs -> do
      cs <- foldM jsonConstrained mempty (map snd (HM.toList pairs))
      pure (JSONType, cs)
    SubscriptExpression e ks -> do
      (t1, c1) <-
        case e of
          ExpressionSubscripted es -> check ctx es
          WildcardSubscripted -> pure (wildcardType, mempty)
      cs <-
        foldM
          (\acc k ->
             case k of
               PropertySubscript {} -> pure acc
               ExpressionSubscript es -> jsonConstrained acc es)
          c1
          ks
      -- A wildcard subscript is a function over JSON; any other
      -- subscripted expression must itself be JSON.
      let rty =
            case e of
              WildcardSubscripted -> wildcardType
              _ -> JSONType
      pure (rty, S.insert (t1, rty) cs)
    ArrayExpression as -> do
      cs <- foldM jsonConstrained mempty as
      pure (JSONType, cs)
  where
    wildcardType = FunctionType JSONType JSONType
    -- Check a sub-expression that is required to be JSON, merging its
    -- constraints (plus the JSON requirement) into the accumulator.
    jsonConstrained
      :: MonadState [TypeVariable] n
      => Set (Type, Type) -> Expression -> n (Set (Type, Type))
    jsonConstrained acc e = do
      (ety, cs') <- check ctx e
      pure (S.insert (ety, JSONType) (acc <> cs'))
-- | Take the next type variable from the supply held in the state.
--
-- The remaining supply is written back to the state. An empty supply
-- aborts with 'error'.
generateTypeVariable
  :: MonadState [TypeVariable] m
  => m TypeVariable
generateTypeVariable =
  get >>= \case
    v:vs -> do
      put vs
      pure v
    _ -> error "Ran out of type variables"
-- | Solve a list of type equality constraints, producing a
-- substitution from type variables to types.
--
-- Constraints are processed front to back:
--
-- * equal types are discarded;
-- * a type variable on either side is delegated to 'unifyVariable';
-- * two function types are split into constraints on their argument
--   and result types;
-- * anything else aborts with 'error', naming both types.
unify
  :: Monad m
  => [(Type, Type)] -> m (Map TypeVariable Type)
unify [] = return mempty
unify ((a, b):cs)
  | a == b = unify cs
  | VariableType v <- a = unifyVariable v cs a b
  | VariableType v <- b = unifyVariable v cs b a
  | FunctionType a1 b1 <- a
  , FunctionType a2 b2 <- b = unify ([(a1, a2), (b1, b2)] <> cs)
  | otherwise =
    error
      (T.unpack
         ("Type " <> quote (prettyType a) <> " doesn't match " <>
          quote (prettyType b)))
-- | Bind a type variable to a type and solve the remaining
-- constraints under that binding.
--
-- Arguments: the variable @v@, the remaining constraints, the type
-- @a@ that @v@ came from (used only in the error message), and the
-- type @b@ to bind @v@ to.
--
-- If @v@ occurs inside @b@ the binding would be infinite, so the
-- function aborts with 'error' (the occurs check).
unifyVariable
  :: Monad m
  => TypeVariable -> [(Type, Type)] -> Type -> Type -> m (Map TypeVariable Type)
unifyVariable v cs a b
  | occurs v b =
    error
      (T.unpack ("Occurs check: " <> prettyType a <> " ~ " <> prettyType b))
  | otherwise = do
    rest <- unify (substitute (M.singleton v b) cs)
    -- Resolve @b@ against the solution of the remaining constraints
    -- before recording it. 'replace' applies the bindings of a
    -- substitution one at a time, so storing the unresolved @b@ could
    -- leave already-solved variables in the final type.
    return (rest <> M.singleton v (replace rest b))
-- | Does the type variable appear anywhere inside the type?
occurs :: TypeVariable -> Type -> Bool
occurs x (VariableType y) = x == y
occurs x (FunctionType a b) = occurs x a || occurs x b
occurs _ JSONType = False
-- | Apply a substitution to both sides of every constraint in the
-- list, using 'replace'.
substitute :: Map TypeVariable Type -> [(Type, Type)] -> [(Type, Type)]
substitute subs = map applyBoth
  where
    applyBoth (a, b) = (replace subs a, replace subs b)
-- | Apply a substitution to a type.
--
-- The bindings are applied one after another by folding over the map
-- with 'M.foldrWithKey'; each step rewrites every occurrence of one
-- variable in the type produced by the previous step. The bindings
-- are therefore not applied simultaneously.
replace :: Map TypeVariable Type -> Type -> Type
replace s' t' = M.foldrWithKey applyBinding t' s'
  where
    -- Rewrite occurrences of one variable with its bound type.
    applyBinding var bound ty =
      case ty of
        VariableType other
          | var == other -> bound
          | otherwise -> VariableType other
        FunctionType from to ->
          FunctionType
            (applyBinding var bound from)
            (applyBinding var bound to)
        JSONType -> JSONType
-- | Wrap a piece of text in typographic single quotation marks.
quote :: Text -> Text
quote t = "‘" <> t <> "’"