module Language.Lambda.SystemF.TypeCheck where

import Language.Lambda.Shared.Errors (LambdaException(..))
import Language.Lambda.SystemF.Expression
import Language.Lambda.SystemF.State

import Control.Monad.Except (MonadError(..))
import Prettyprinter
import RIO
import qualified RIO.List as List
import qualified RIO.Map as Map

-- | A supply of names; written eta-reduced so @UniqueSupply n@ is still @[n]@.
type UniqueSupply = []

-- | A context keyed by names; @Context' n t@ is still @Map n t@.
type Context' = Map

-- TODO: name/ty different types
-- | Infer the type of a System F expression.
--
-- Dispatches on the expression's constructor to the matching
-- @typecheck*@ helper, which does the actual work.
typecheck
  :: (Ord name, Pretty name)
  => SystemFExpr name name
  -> Typecheck name (Ty name)
typecheck expr = case expr of
  Var v          -> typecheckVar v
  Abs n t body   -> typecheckAbs n t body
  App e1 e2      -> typecheckApp e1 e2
  TyAbs t body   -> typecheckTyAbs t body
  TyApp e ty     -> typecheckTyApp e ty

-- | Look up the type of a variable in the current context.
--
-- A variable with no entry in the context is given a fresh type
-- variable built from 'unique'; the context itself is not modified.
typecheckVar :: Ord name => name -> Typecheck name (Ty name)
typecheckVar var = do
  ctx <- getContext
  case Map.lookup var ctx of
    Just knownTy -> return knownTy
    Nothing      -> TyVar <$> unique

-- | Typecheck a term abstraction (@Abs name ty body@).
--
-- While @body@ is checked, the context maps @name@ to @ty@. The outer
-- context is restored afterwards, so the binder does not leak into
-- sibling expressions (e.g. the argument of an enclosing application).
--
-- Returns @TyArrow ty bodyTy@, where @bodyTy@ is the type of @body@.
typecheckAbs
  :: (Ord name, Pretty name)
  => name
  -> Ty name
  -> SystemFExpr name name
  -> Typecheck name (Ty name)
typecheckAbs name ty body = do
  outerContext <- getContext
  modifyContext (Map.insert name ty)
  bodyTy <- typecheck body
  -- The binder is only in scope inside the abstraction; drop it again.
  modifyContext (const outerContext)
  return (TyArrow ty bodyTy)

-- | Typecheck an application (@App e1 e2@).
--
-- @e1@ must have an arrow type whose input type equals (via '==') the
-- type of @e2@. The result is the arrow's output type.
--
-- Throws a mismatch error (built by 'tyMismatchError') when:
--
-- * the type of @e1@ is not an arrow, or
-- * its input type differs from the type of @e2@.
typecheckApp
  :: (Ord name, Pretty name)
  => SystemFExpr name name
  -> SystemFExpr name name
  -> Typecheck name (Ty name)
typecheckApp e1 e2 = do
  -- Typecheck the function and its argument
  t1 <- typecheck e1
  t2 <- typecheck e2

  -- Verify the type of e1 is an arrow
  (t1AppInput, t1AppOutput) <- case t1 of
    TyArrow appInput appOutput -> return (appInput, appOutput)
    t1' -> do
      -- e1 was expected to be a function accepting t2. Its result type is
      -- unknown, so a fresh type variable stands in for it in the message.
      resultTy <- TyVar <$> unique
      throwError $ tyMismatchError (TyArrow t2 resultTy) t1'

  -- Verify the input type of e1 matches the type of e2. On failure, report
  -- the arrow e1 would need next to the arrow it actually has (t1).
  if t1AppInput == t2
    then return t1AppOutput
    else throwError $ tyMismatchError (TyArrow t2 t1AppOutput) t1

-- | Typecheck a type abstraction (@TyAbs ty body@).
--
-- While @body@ is checked, the context maps @ty@ to @TyVar ty@. The outer
-- context is restored afterwards, so this entry does not leak into
-- sibling expressions.
--
-- Returns @TyForAll ty bodyTy@, where @bodyTy@ is the type of @body@.
typecheckTyAbs
  :: (Ord name, Pretty name)
  => name
  -> SystemFExpr name name
  -> Typecheck name (Ty name)
typecheckTyAbs ty body = do
  outerContext <- getContext
  modifyContext (Map.insert ty (TyVar ty))
  bodyTy <- typecheck body
  -- The type binder is only in scope inside the abstraction; drop it again.
  modifyContext (const outerContext)
  return (TyForAll ty bodyTy)

-- | Typecheck a type application (@TyApp expr ty@).
--
-- There are two cases:
--
-- * When @expr@ is syntactically a 'TyAbs', @ty@ is substituted into its
--   body and the result is typechecked.
-- * Otherwise @expr@ is typechecked first. If its type is
--   @TyForAll name body@, that quantifier is instantiated with @ty@.
--
-- Any other type is returned unchanged. This keeps the previous lenient
-- behaviour for non-polymorphic types.
typecheckTyApp
  :: (Ord name, Pretty name)
  => SystemFExpr name name
  -> Ty name
  -> Typecheck name (Ty name)
typecheckTyApp (TyAbs t expr) ty = typecheck $ substitute ty t expr
typecheckTyApp expr ty = instantiate <$> typecheck expr
  where
    -- Apply the type argument to a polymorphic type: (forall a. U) [T] = U[T/a]
    instantiate (TyForAll name body) = substituteTy ty name body
    instantiate other = other

-- | Return the first name from the unique supply ('getUniques').
--
-- Throws 'ImpossibleError' if the supply is empty.
--
-- NOTE(review): this block does not itself remove the name from the
-- supply. Whether repeated calls yield distinct names depends on
-- 'getUniques' and its callers; confirm.
unique :: Typecheck name name
unique = do
  supply <- getUniques
  case List.headMaybe supply of
    Just fresh -> return fresh
    Nothing    -> throwError ImpossibleError

-- | @substitute ty name expr@ replaces the type variable @name@ with @ty@
-- in every type annotation inside @expr@.
--
-- A 'TyAbs' that rebinds @name@ shadows it, so substitution stops there.
-- Variables are returned unchanged.
--
-- NOTE(review): substitution is not capture-avoiding. If @ty@ mentions a
-- type variable bound by an inner 'TyAbs', it will be captured.
substitute
  :: Eq n
  => Ty n
  -> n
  -> SystemFExpr n n
  -> SystemFExpr n n
substitute ty name (App e1 e2)
  = App (substitute ty name e1) (substitute ty name e2)
substitute ty name (Abs n ty' e)
  = Abs n (substituteTy ty name ty') (substitute ty name e)
substitute ty name expr@(TyAbs ty' e)
  -- The inner binder rebinds `name`; occurrences below refer to it.
  | ty' == name = expr
  | otherwise   = TyAbs ty' (substitute ty name e)
substitute ty name (TyApp e ty')
  = TyApp (substitute ty name e) (substituteTy ty name ty')
substitute _ _ expr = expr

-- | @substituteTy ty name t@ replaces free occurrences of the type variable
-- @name@ in @t@ with @ty@.
--
-- A 'TyForAll' that binds @name@ shadows it and is returned unchanged.
--
-- NOTE(review): substitution is not capture-avoiding. If @ty@ mentions a
-- variable bound by an inner 'TyForAll', it will be captured.
substituteTy
  :: Eq name
  => Ty name
  -> name
  -> Ty name
  -> Ty name
substituteTy ty name (TyArrow t1 t2)
  = TyArrow (substituteTy ty name t1) (substituteTy ty name t2)
substituteTy ty name ty'@(TyVar name')
  | name == name' = ty
  | otherwise     = ty'
substituteTy ty name t2@(TyForAll name' t2')
  | name == name' = t2
  -- Substitute the replacement `ty` (not the forall type itself) in the body.
  | otherwise     = TyForAll name' (substituteTy ty name t2')


-- | Build a 'TyMismatchError' whose message names the expected and the
-- actual type, both rendered with 'prettyPrint'.
tyMismatchError
  :: (Pretty t1, Pretty t2)
  => t1
  -> t2
  -> LambdaException
tyMismatchError expected actual = TyMismatchError message
  where
    message = mconcat
      [ "Couldn't match expected type "
      , prettyPrint expected
      , " with actual type "
      , prettyPrint actual
      ]