{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Convert.TypeInfer
-- Description : does type inference. / 型推論を行います。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
module Jikka.Core.Convert.TypeInfer
  ( run,

    -- * internal types and functions
    Equation (..),
    formularizeProgram,
    sortEquations,
    mergeAssertions,
    Subst (..),
    subst,
    solveEquations,
    substProgram,
  )
where

import Control.Arrow (second)
import Control.Monad.State.Strict
import Control.Monad.Writer.Strict (MonadWriter, execWriterT, tell)
import qualified Data.Map.Strict as M
import Data.Monoid (Dual (..))
import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Format (formatType)
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.TypeCheck (literalToType, typecheckProgram)
import Jikka.Core.Language.Util

-- | A single constraint produced while walking the program.
--
-- * @'TypeEquation' t1 t2@ demands that @t1@ and @t2@ be unified.
-- * @'TypeAssertion' x t@ records that the variable @x@ is used at type @t@.
--   `mergeAssertions` later turns repeated assertions about the same variable
--   into type equations.
data Equation
  = TypeEquation Type Type
  | TypeAssertion VarName Type
  deriving (Eq, Ord, Show, Read)

-- | The writer output used while collecting constraints. Because of 'Dual',
-- the output of a later `tell` is placed in front of the output of an
-- earlier one.
type Eqns = Dual [Equation]

-- | `formularizeType` emits the equation @lhs = rhs@.
formularizeType :: MonadWriter Eqns m => Type -> Type -> m ()
formularizeType lhs rhs = tell (Dual [TypeEquation lhs rhs])

-- | `formularizeVarName` emits an assertion that the variable @name@ has type @ty@.
formularizeVarName :: MonadWriter Eqns m => VarName -> Type -> m ()
formularizeVarName name ty = tell (Dual [TypeAssertion name ty])

-- | `formularizeExpr` walks an expression, emits the constraints it implies,
-- and returns a type that stands for the expression.
--
-- * Each variable occurrence gets a fresh type variable and an assertion
--   tying that variable name to it.
-- * A literal's type is taken from `literalToType`.
-- * An application @f e@ gets a fresh result type @r@. The function @f@ is
--   required to have type @t -> r@, where @t@ is the type of @e@.
-- * A lambda asserts the annotated parameter type and returns a function type.
-- * A let asserts the annotated binder type, requires the bound expression to
--   have that type, and returns the type of the body.
formularizeExpr :: (MonadWriter Eqns m, MonadAlpha m) => Expr -> m Type
formularizeExpr = \case
  Var name -> do
    fresh <- genType
    formularizeVarName name fresh
    pure fresh
  Lit lit -> pure (literalToType lit)
  App fun arg -> do
    -- the result variable is drawn before visiting the argument
    result <- genType
    argTy <- formularizeExpr arg
    formularizeExpr' fun (FunTy argTy result)
    pure result
  Lam name paramTy body -> do
    formularizeVarName name paramTy
    FunTy paramTy <$> formularizeExpr body
  Let name boundTy bound body -> do
    formularizeVarName name boundTy
    formularizeExpr' bound boundTy
    formularizeExpr body

-- | `formularizeExpr'` formularizes an expression and then emits an equation
-- requiring its type to equal @expected@. The expected type is the left-hand
-- side of that equation.
formularizeExpr' :: (MonadWriter Eqns m, MonadAlpha m) => Expr -> Type -> m ()
formularizeExpr' e expected = formularizeType expected =<< formularizeExpr e

-- | `formularizeToplevelExpr` emits the constraints of a toplevel expression
-- and returns the type of its final result expression.
--
-- For a recursive toplevel function, the function name is asserted to have
-- the curried type built from the argument types and the return type. Each
-- argument is asserted at its annotated type, and the body is required to
-- have the return type.
formularizeToplevelExpr :: (MonadWriter Eqns m, MonadAlpha m) => ToplevelExpr -> m Type
formularizeToplevelExpr = \case
  ResultExpr e -> formularizeExpr e
  ToplevelLet name ty bound cont -> do
    formularizeVarName name ty
    formularizeExpr' bound ty
    formularizeToplevelExpr cont
  ToplevelLetRec fname args retTy body cont -> do
    let argTys = map snd args
    formularizeVarName fname (curryFunTy argTys retTy)
    mapM_ (uncurry formularizeVarName) args
    formularizeExpr' body retTy
    formularizeToplevelExpr cont

-- | `formularizeProgram` collects every constraint of a program.
--
-- The constraints are accumulated in a 'Dual'-wrapped writer, so the
-- returned list is in reverse order of emission.
formularizeProgram :: MonadAlpha m => Program -> m [Equation]
formularizeProgram = fmap getDual . execWriterT . formularizeToplevelExpr

-- | `sortEquations` separates type equations from variable assertions.
--
-- Each of the two returned lists is in reverse order relative to the input.
sortEquations :: [Equation] -> ([(Type, Type)], [(VarName, Type)])
sortEquations eqns =
  ( reverse [(lhs, rhs) | TypeEquation lhs rhs <- eqns],
    reverse [(name, ty) | TypeAssertion name ty <- eqns]
  )

-- | `mergeAssertions` turns variable assertions into type equations.
--
-- The first assertion seen for a variable fixes its reference type. Every
-- later assertion for the same variable yields the pair
-- @(later type, reference type)@. New pairs are prepended, so the result is
-- in reverse order of generation.
mergeAssertions :: [(VarName, Type)] -> [(Type, Type)]
mergeAssertions = loop M.empty []
  where
    loop :: M.Map VarName Type -> [(Type, Type)] -> [(VarName, Type)] -> [(Type, Type)]
    loop _ acc [] = acc
    loop env acc ((name, ty) : rest) =
      maybe
        (loop (M.insert name ty env) acc rest)
        (\seen -> loop env ((ty, seen) : acc) rest)
        (M.lookup name env)

-- | `Subst` is a type substitution: a mapping from type variables to the
-- types they stand for.
--
-- A bound type may itself mention other bound type variables; `subst`
-- follows such chains recursively.
newtype Subst = Subst
  { unSubst :: M.Map TypeName Type
  }

-- | `subst` applies a substitution to a type.
--
-- Bound type variables are replaced and the replacement is substituted again.
-- Unbound type variables are left as they are.
subst :: Subst -> Type -> Type
subst sigma = go
  where
    table = unSubst sigma
    go :: Type -> Type
    go = \case
      VarTy x -> maybe (VarTy x) go (M.lookup x table)
      IntTy -> IntTy
      BoolTy -> BoolTy
      ListTy t -> ListTy (go t)
      TupleTy ts -> TupleTy (map go ts)
      FunTy t ret -> FunTy (go t) (go ret)
      DataStructureTy ds -> DataStructureTy ds

-- | `unifyTyVar` binds the type variable @x@ to @t@ in the current substitution.
--
-- It fails with an internal error if @x@ occurs free in @t@ (occurs check).
unifyTyVar :: (MonadState Subst m, MonadError Error m) => TypeName -> Type -> m ()
unifyTyVar x t
  | x `elem` freeTyVars t =
    throwInternalError $ "looped type equation " ++ unTypeName x ++ " = " ++ formatType t
  | otherwise =
    -- Assuming t is already fully substituted (as unifyType ensures), the
    -- occurs check above means this binding cannot introduce a loop.
    modify' (\(Subst table) -> Subst (M.insert x t table))

-- | `unifyType` unifies two types, extending the current substitution.
--
-- Both types are first resolved with the current substitution.
--
-- It fails with an internal error in three cases:
--
-- * the type constructors differ;
-- * two tuples have different lengths;
-- * an occurs check fails.
--
-- The failure is wrapped in a message that shows the types as they were
-- passed in, before substitution.
unifyType :: (MonadState Subst m, MonadError Error m) => Type -> Type -> m ()
unifyType t1 t2 = wrapError' ("failed to unify " ++ formatType t1 ++ " and " ++ formatType t2) $ do
  sigma <- get
  let lhs = subst sigma t1
      rhs = subst sigma t2
      mismatch = throwInternalError $ "different type ctors " ++ formatType lhs ++ " and " ++ formatType rhs
  case (lhs, rhs) of
    _ | lhs == rhs -> return ()
    (VarTy x, _) -> unifyTyVar x rhs
    (_, VarTy y) -> unifyTyVar y lhs
    (ListTy a, ListTy b) -> unifyType a b
    (TupleTy xs, TupleTy ys)
      | length xs == length ys -> sequence_ (zipWith unifyType xs ys)
      | otherwise -> mismatch
    (FunTy a r, FunTy b s) -> unifyType a b >> unifyType r s
    _ -> mismatch

-- | `solveEquations` unifies every pair in order, starting from the empty
-- substitution, and returns the resulting substitution.
solveEquations :: MonadError Error m => [(Type, Type)] -> m Subst
solveEquations eqns = wrapError' "failed to solve type equations" $ do
  let start = Subst M.empty
  execStateT (mapM_ (uncurry unifyType) eqns) start

-- | `substUnit` replaces every type variable, at any depth, with the unit
-- type (the empty tuple).
substUnit :: Type -> Type
substUnit (VarTy _) = TupleTy []
substUnit IntTy = IntTy
substUnit BoolTy = BoolTy
substUnit (ListTy t) = ListTy (substUnit t)
substUnit (TupleTy ts) = TupleTy (map substUnit ts)
substUnit (FunTy t ret) = FunTy (substUnit t) (substUnit ret)
substUnit (DataStructureTy ds) = DataStructureTy ds

-- | `subst'` applies `subst` and then defaults every type variable that is
-- still undetermined to the unit type.
subst' :: Subst -> Type -> Type
subst' sigma t = substUnit (subst sigma t)

-- | `substBuiltin` applies `subst'` to every type mentioned by a builtin.
substBuiltin :: Subst -> Builtin -> Builtin
substBuiltin sigma builtin = mapTypeInBuiltin (subst' sigma) builtin

-- | `substLiteral` applies `subst'` to the types carried by a literal.
-- Integer and boolean literals are returned unchanged.
substLiteral :: Subst -> Literal -> Literal
substLiteral sigma (LitBuiltin builtin) = LitBuiltin (substBuiltin sigma builtin)
substLiteral _ (LitInt n) = LitInt n
substLiteral _ (LitBool p) = LitBool p
substLiteral sigma (LitNil t) = LitNil (subst' sigma t)
substLiteral sigma (LitBottom t err) = LitBottom (subst' sigma t) err

-- | `substExpr` rewrites every type annotation in an expression with `subst'`.
--
-- Solved type variables are replaced, and undetermined ones become the unit
-- type. Variable names and the expression structure are left unchanged.
substExpr :: Subst -> Expr -> Expr
substExpr sigma = go
  where
    go :: Expr -> Expr
    go = \case
      Var x -> Var x
      Lit lit -> Lit (substLiteral sigma lit)
      App f e -> App (go f) (go e)
      Lam x t body -> Lam x (subst' sigma t) (go body)
      -- Use subst' (not subst) here, as every other annotation does.
      -- Otherwise an undetermined type variable in a let annotation survives
      -- while the same variable inside the bound expression is defaulted to
      -- unit, and the two no longer agree.
      Let x t e1 e2 -> Let x (subst' sigma t) (go e1) (go e2)

-- | `substToplevelExpr` rewrites the toplevel type annotations with `subst'`
-- and the contained expressions with `substExpr`.
substToplevelExpr :: Subst -> ToplevelExpr -> ToplevelExpr
substToplevelExpr sigma = go
  where
    ty :: Type -> Type
    ty = subst' sigma
    ex :: Expr -> Expr
    ex = substExpr sigma
    go :: ToplevelExpr -> ToplevelExpr
    go = \case
      ResultExpr e -> ResultExpr (ex e)
      ToplevelLet x t e cont -> ToplevelLet x (ty t) (ex e) (go cont)
      ToplevelLetRec f args ret body cont ->
        ToplevelLetRec f (map (second ty) args) (ty ret) (ex body) (go cont)

-- | `substProgram` applies a substitution to a whole program
-- (an alias for `substToplevelExpr`).
substProgram :: Subst -> Program -> Program
substProgram sigma prog = substToplevelExpr sigma prog

-- | `run` does type inference.
--
-- It proceeds in four steps:
--
-- 1. Collect constraints from the program.
-- 2. Turn repeated variable assertions into equations.
-- 3. Solve all equations into a substitution.
-- 4. Rewrite the program's type annotations with that substitution.
--
-- The result is re-checked with `typecheckProgram` as a postcondition.
--
-- * This assumes that the program has no name conflicts.
--
-- Before:
--
-- > let f = fun y -> y
-- > in let x = 1
-- > in f(x + x)
--
-- After:
--
-- > let f: int -> int = fun y: int -> y
-- > in let x: int = 1
-- > in f(x + x)
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.TypeInfer" $ do
  equations <- formularizeProgram prog
  let (typeEqns, assertions) = sortEquations equations
      assertionEqns = mergeAssertions assertions
  sigma <- solveEquations (typeEqns ++ assertionEqns)
  let prog' = substProgram sigma prog
  postcondition $ typecheckProgram prog'
  pure prog'