{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.RestrictedPython.Convert.TypeInfer
-- Description : does type inference. / 型推論を行います。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
module Jikka.RestrictedPython.Convert.TypeInfer
  ( run,

    -- * internal types and functions
    Equation (..),
    formularizeProgram,
    sortEquations,
    mergeAssertions,
    Subst (..),
    subst,
    solveEquations,
    mapTypeProgram,
  )
where

import Control.Arrow (second)
import Control.Monad.Reader
import Control.Monad.State.Strict
import Control.Monad.Writer.Strict
import qualified Data.Map.Strict as M
import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.RestrictedPython.Format (formatType)
import Jikka.RestrictedPython.Language.Builtin
import Jikka.RestrictedPython.Language.Expr
import Jikka.RestrictedPython.Language.Util

-- | A single constraint produced while walking a program.
data Equation
  = -- | @TypeEquation t1 t2 loc@ demands that @t1@ and @t2@ are the same type;
    -- @loc@ is the source location the constraint came from, when known.
    TypeEquation Type Type (Maybe Loc)
  | -- | @TypeAssertion x t@ records that the variable @x@ has type @t@ at one
    -- of its occurrences.
    TypeAssertion VarName' Type
  deriving (Eq, Ord, Show, Read)

-- | Writer output used while collecting equations. 'Dual' flips '<>', so
-- equations told later end up in front of the ones told earlier.
type Eqns = Dual [Equation]

-- | Emits one 'TypeEquation' stating that the two types must be equal,
-- tagged with the (optional) location it originates from.
formularizeType :: MonadWriter Eqns m => Type -> Type -> Maybe Loc -> m ()
formularizeType t1 t2 = tell . Dual . pure . TypeEquation t1 t2

-- | Emits one 'TypeAssertion' recording that the variable has the given type
-- at this occurrence.
formularizeVarName :: MonadWriter Eqns m => VarName' -> Type -> m ()
formularizeVarName x = tell . Dual . pure . TypeAssertion x

-- | Emits the equations for an assignment target and returns the type that a
-- value assigned to it must have.
--
-- * @xs[i]@: @xs@ must be a list of a fresh element type and @i@ an int; the
--   element type is returned.
-- * @x@: a fresh type is asserted for @x@ and returned.
-- * @(a, b, ...)@: the tuple of the component target types.
formularizeTarget :: (MonadWriter Eqns m, MonadAlpha m) => Target' -> m Type
formularizeTarget x0 = case value' x0 of
  SubscriptTrg container index -> do
    elemTy <- genType
    containerTy <- formularizeTarget container
    formularizeType containerTy (ListTy elemTy) here
    indexTy <- formularizeExpr index
    formularizeType indexTy IntTy here
    pure elemTy
  NameTrg x -> do
    ty <- genType
    ty <$ formularizeVarName x ty
  TupleTrg xs -> TupleTy <$> traverse formularizeTarget xs
  where
    here = loc' x0

-- | Like 'formularizeTarget', but additionally requires the target's type to
-- equal the given type (the equation carries the target's location).
formularizeTarget' :: (MonadWriter Eqns m, MonadAlpha m) => Target' -> Type -> m ()
formularizeTarget' x0 expected =
  formularizeTarget x0 >>= \actual -> formularizeType expected actual (loc' x0)

-- | Emits the equations that constrain an expression and returns a type
-- standing for the expression's value. Fresh type variables come from
-- 'genType'.
formularizeExpr :: (MonadWriter Eqns m, MonadAlpha m) => Expr' -> m Type
formularizeExpr e0 = case value' e0 of
  BoolOp lhs _ rhs -> constrainAll BoolTy [lhs, rhs]
  BinOp lhs _ rhs -> constrainAll IntTy [lhs, rhs]
  UnaryOp op operand ->
    -- @not@ works on bools; every other unary operator here is on ints
    constrainAll (if op == Not then BoolTy else IntTy) [operand]
  Lambda params body -> do
    mapM_ (\(x, t) -> formularizeVarName x t) params
    retTy <- genType
    formularizeExpr' body retTy
    pure (CallableTy [t | (_, t) <- params] retTy)
  IfExp cond thenE elseE -> do
    formularizeExpr' cond BoolTy
    ty <- formularizeExpr thenE
    ty <$ formularizeExpr' elseE ty
  ListComp body comp -> do
    let Comprehension target iter cond = comp
    bodyTy <- formularizeExpr body
    targetTy <- formularizeTarget target
    formularizeExpr' iter (ListTy targetTy)
    mapM_ (`formularizeExpr'` BoolTy) cond
    pure (ListTy bodyTy)
  Compare lhs (CmpOp' op ty) rhs -> do
    formularizeExpr' lhs ty
    -- for membership tests the right-hand side is a list of the compared type
    let rhsTy
          | op `elem` [In, NotIn] = ListTy ty
          | otherwise = ty
    formularizeExpr' rhs rhsTy
    pure BoolTy
  Call callee args -> do
    argTys <- traverse formularizeExpr args
    retTy <- genType
    formularizeExpr' callee (CallableTy argTys retTy)
    pure retTy
  Constant c -> pure (constantType c)
  Attribute receiver attr -> do
    let (receiverTy, resultTy) = typeAttribute (value' attr)
    resultTy <$ formularizeExpr' receiver receiverTy
  Subscript container index -> do
    elemTy <- genType
    formularizeExpr' container (ListTy elemTy)
    formularizeExpr' index IntTy
    pure elemTy
  Starred inner -> do
    elemTy <- genType
    formularizeExpr' inner (ListTy elemTy)
    pure elemTy -- because @*xs@ and @y@ has the same type in @[*xs, y]@
  Name x -> do
    ty <- genType
    ty <$ formularizeVarName x ty
  List elemTy es -> ListTy <$> constrainAll elemTy es
  Tuple es -> TupleTy <$> traverse formularizeExpr es
  SubscriptSlice container from to step -> do
    elemTy <- genType
    formularizeExpr' container (ListTy elemTy)
    -- each present slice bound (start, stop, step) must be an int, in this order
    mapM_ (mapM_ (`formularizeExpr'` IntTy)) [from, to, step]
    pure (ListTy elemTy)
  where
    -- Constrain each expression, left to right, to @ty@, then yield @ty@.
    constrainAll ty es = ty <$ mapM_ (`formularizeExpr'` ty) es
    -- The fixed type of a literal constant.
    constantType = \case
      ConstNone -> NoneTy
      ConstInt _ -> IntTy
      ConstBool _ -> BoolTy
      ConstBuiltin b -> typeBuiltin b

-- | Like 'formularizeExpr', but additionally requires the expression's type to
-- equal the given type (the equation carries the expression's location).
formularizeExpr' :: (MonadWriter Eqns m, MonadAlpha m) => Expr' -> Type -> m ()
formularizeExpr' e0 expected =
  formularizeExpr e0 >>= \actual -> formularizeType expected actual (loc' e0)

-- | Emits the equations for one statement of a function body.
--
-- @ret@ is the type that every @return@ in the statement, including those in
-- nested @for@ / @if@ bodies, must agree with.
formularizeStatement :: (MonadWriter Eqns m, MonadAlpha m) => Type -> Statement -> m ()
formularizeStatement ret = \case
  Return e -> formularizeExpr e >>= \t -> formularizeType t ret (loc' e)
  AugAssign x _ e -> do
    -- augmented assignments are typed as int-only here
    formularizeTarget' x IntTy
    formularizeExpr' e IntTy
  AnnAssign x t e -> formularizeTarget' x t >> formularizeExpr' e t
  For x e body -> do
    t <- formularizeTarget x
    formularizeExpr' e (ListTy t)
    nested body
  If e thenBody elseBody -> do
    formularizeExpr' e BoolTy
    nested thenBody
    nested elseBody
  Assert e -> formularizeExpr' e BoolTy
  Expr' e -> formularizeExpr' e SideEffectTy
  where
    -- nested blocks share the same expected return type
    nested stmts = mapM_ (formularizeStatement ret) stmts

-- | Emits the equations for one top-level statement.
--
-- For a function definition this asserts the parameter types, asserts the
-- function name at its callable type, and walks the body with the declared
-- return type.
formularizeToplevelStatement :: (MonadWriter Eqns m, MonadAlpha m) => ToplevelStatement -> m ()
formularizeToplevelStatement = \case
  ToplevelAnnAssign x t e -> formularizeVarName x t >> formularizeExpr' e t
  ToplevelFunctionDef f params ret body -> do
    mapM_ (\(x, t) -> formularizeVarName x t) params
    formularizeVarName f (CallableTy [t | (_, t) <- params] ret)
    mapM_ (formularizeStatement ret) body
  ToplevelAssert e -> formularizeExpr' e BoolTy

-- | Collects every equation of a program by running the statement walkers in
-- a writer and unwrapping the accumulated 'Dual' list.
formularizeProgram :: MonadAlpha m => Program -> m [Equation]
formularizeProgram prog = do
  collected <- execWriterT (mapM_ formularizeToplevelStatement prog)
  return (getDual collected)

-- | Splits equations into plain type equations and variable assertions.
--
-- Each output list is in the reverse of its input order.
sortEquations :: [Equation] -> ([(Type, Type, Maybe Loc)], [(VarName', Type)])
sortEquations eqns =
  ( reverse [(t1, t2, mloc) | TypeEquation t1 t2 mloc <- eqns],
    reverse [(x, t) | TypeAssertion x t <- eqns]
  )

-- | Turns per-occurrence variable assertions into type equations.
--
-- The first assertion seen for a variable name fixes its reference type; every
-- later assertion for the same name yields an equation between its type and
-- that reference type, located at the later occurrence. Locations on the
-- variable are ignored when comparing names.
mergeAssertions :: [(VarName', Type)] -> [(Type, Type, Maybe Loc)]
mergeAssertions = loop M.empty []
  where
    loop _ acc [] = acc
    loop seen acc ((x, t) : rest) =
      let key = value' x
       in case M.lookup key seen of
            Just t' -> loop seen ((t, t', loc' x) : acc) rest
            Nothing -> loop (M.insert key t seen) acc rest

-- | 'Subst' is a type substitution: a mapping from type variables to the types
-- they have been resolved to.
newtype Subst = Subst
  { unSubst :: M.Map TypeName Type
  }

-- | Applies a substitution to a type.
--
-- A bound type variable is replaced by its binding, which is itself
-- substituted again, so chains of bindings are followed to the end. Unbound
-- variables are left untouched.
subst :: Subst -> Type -> Type
subst sigma = go
  where
    go = \case
      VarTy x -> maybe (VarTy x) go (M.lookup x (unSubst sigma))
      ListTy t -> ListTy (go t)
      TupleTy ts -> TupleTy (map go ts)
      CallableTy ts ret -> CallableTy (map go ts) (go ret)
      IntTy -> IntTy
      BoolTy -> BoolTy
      StringTy -> StringTy
      SideEffectTy -> SideEffectTy

-- | Binds the type variable @x@ to @t@ in the current substitution.
--
-- The trivial equation @x = x@ is accepted without changing anything. If @x@
-- occurs in @t@ (the occurs check), the equation has no finite solution and a
-- type error is thrown. Callers are expected to pass an @x@ that is not yet
-- bound and a @t@ that has already been substituted, as 'unifyType' does.
unifyTyVar :: (MonadState Subst m, MonadError Error m) => TypeName -> Type -> m ()
unifyTyVar x t
  -- previously this fell into the occurs check and was reported as a loop
  | t == VarTy x = return ()
  | x `elem` freeTyVars t =
    throwTypeError $ "type equation loops: " ++ formatType (VarTy x) ++ " = " ++ formatType t
  | otherwise =
    modify' (Subst . M.insert x t . unSubst) -- This doesn't introduce the loop.

-- | Unifies two types under the current substitution, extending it as needed.
--
-- Both sides are first resolved through the substitution. Equal types succeed
-- immediately, and a bare type variable on either side is bound to the other
-- side. Lists, tuples and callables of matching shape are unified component by
-- component (callable arguments before the result). Anything else is a type
-- error.
unifyType :: (MonadState Subst m, MonadError Error m) => Type -> Type -> m ()
unifyType t1 t2 =
  gets (\sigma -> (subst sigma t1, subst sigma t2)) >>= uncurry unifyResolved
  where
    mismatch a b =
      throwTypeError $ "type " ++ formatType a ++ " is not type " ++ formatType b
    unifyResolved a b
      | a == b = return ()
    unifyResolved (VarTy x) b = unifyTyVar x b
    unifyResolved a (VarTy y) = unifyTyVar y a
    unifyResolved (ListTy a) (ListTy b) = unifyType a b
    unifyResolved a@(TupleTy xs) b@(TupleTy ys)
      | length xs == length ys = mapM_ (uncurry unifyType) (zip xs ys)
      | otherwise = mismatch a b
    unifyResolved a@(CallableTy xs r1) b@(CallableTy ys r2)
      | length xs == length ys = do
        mapM_ (uncurry unifyType) (zip xs ys)
        unifyType r1 r2
      | otherwise = mismatch a b
    unifyResolved a b = mismatch a b

-- | `solveEquations` unifies the given equations one by one, left to right,
-- and returns the accumulated substitution.
--
-- A failing equation does not stop the loop. Its error is wrapped with the
-- two involved types and, when available, the equation's location. The
-- collected results are then handed to `reportErrors`, which presumably
-- raises if any of them is a 'Left' (TODO confirm).
--
-- Inside the handler, 'get' observes the substitution from before the failing
-- equation, because 'catchError' on 'StateT' restarts the handler from the
-- state the failed action started with.
solveEquations :: MonadError Error m => [(Type, Type, Maybe Loc)] -> m Subst
solveEquations eqns = wrapError' "failed to solve type equations" $
  flip execStateT (Subst M.empty) $ do
    errs <- forM eqns $ \(t1, t2, loc) ->
      (Right <$> unifyType t1 t2) `catchError` \err -> do
        sigma <- get
        -- Show the types with everything known so far substituted in, so the
        -- message is as concrete as possible.
        let t1' = subst sigma t1
            t2' = subst sigma t2
            msg = "failed to unify type " ++ formatType t1' ++ " and type " ++ formatType t2'
        return $ Left (maybe id WithLocation loc (WithWrapped msg err))
    reportErrors errs

-- | `mapTypeConstant` applies @f@ to the types held by a constant.
-- Builtin constants are delegated to `mapTypeBuiltin`. None, integer and
-- boolean constants hold no types, so they are rebuilt unchanged.
mapTypeConstant :: (Type -> Type) -> Constant -> Constant
mapTypeConstant f constant = case constant of
  ConstBuiltin builtin -> ConstBuiltin (mapTypeBuiltin f builtin)
  ConstNone -> ConstNone
  ConstInt n -> ConstInt n
  ConstBool p -> ConstBool p

-- | `mapTypeTarget` applies @f@ to every type annotation inside an
-- assignment target.
--
-- It recurses through subscript and tuple targets, and maps the index
-- expressions of subscripts with `mapTypeExpr`. Plain names are kept as they are.
mapTypeTarget :: (Type -> Type) -> Target' -> Target'
mapTypeTarget f = fmap onTarget
  where
    recur = mapTypeTarget f
    onTarget target = case target of
      SubscriptTrg x index -> SubscriptTrg (recur x) (mapTypeExpr f index)
      NameTrg x -> NameTrg x
      TupleTrg xs -> TupleTrg (recur <$> xs)

-- | `mapTypeExpr` applies @f@ to the type annotations stored in an expression:
--
-- * lambda parameter types,
-- * the operand type of comparisons,
-- * the element type of list literals,
-- * list-comprehension targets, through `mapTypeTarget`,
-- * builtin constants, through `mapTypeConstant`,
-- * attributes, through `mapTypeAttribute`.
--
-- All other expression nodes are returned unchanged by the local rewrite.
--
-- NOTE(review): @mapSubExpr@ is defined elsewhere. If it already visits every
-- sub-expression, as its name suggests, the explicit recursion in @rewrite@
-- re-applies @f@ to some nodes more than once. That is harmless only when @f@
-- is idempotent, which is presumably true of @subst' sigma@ as used by `run`.
-- Confirm this before passing a non-idempotent function.
mapTypeExpr :: (Type -> Type) -> Expr' -> Expr'
mapTypeExpr f = mapSubExpr rewrite
  where
    rewrite :: Expr' -> Expr'
    rewrite = fmap rewriteNode

    rewriteNode :: Expr -> Expr
    rewriteNode expr = case expr of
      Lambda args body ->
        Lambda [(x, f t) | (x, t) <- args] (rewrite body)
      ListComp e (Comprehension x iter cond) ->
        ListComp (rewrite e) (Comprehension (mapTypeTarget f x) (rewrite iter) (rewrite <$> cond))
      Compare e1 (CmpOp' op t) e2 ->
        Compare (rewrite e1) (CmpOp' op (f t)) (rewrite e2)
      Constant c ->
        Constant (mapTypeConstant f c)
      Attribute e a ->
        Attribute (rewrite e) (fmap (mapTypeAttribute f) a)
      List t es ->
        List (f t) (map rewrite es)
      _ -> expr

-- | `mapTypeStatement` applies @f@ to every type annotation in a statement.
--
-- This covers the annotation of @AnnAssign@, all embedded expressions and
-- targets, and, recursively, the nested bodies of @For@ and @If@.
mapTypeStatement :: (Type -> Type) -> Statement -> Statement
mapTypeStatement f stmt = case stmt of
  Return e -> Return (onExpr e)
  AugAssign x op e -> AugAssign (onTarget x) op (onExpr e)
  AnnAssign x t e -> AnnAssign (onTarget x) (f t) (onExpr e)
  For x iter body -> For (onTarget x) (onExpr iter) (onBody body)
  If cond thenBody elseBody -> If (onExpr cond) (onBody thenBody) (onBody elseBody)
  Assert e -> Assert (onExpr e)
  Expr' e -> Expr' (onExpr e)
  where
    onExpr = mapTypeExpr f
    onTarget = mapTypeTarget f
    onBody = map (mapTypeStatement f)

-- | `mapTypeToplevelStatement` applies @f@ to every type annotation in a
-- toplevel statement.
--
-- For a function definition this means the parameter types, the return type,
-- and every statement of the body.
mapTypeToplevelStatement :: (Type -> Type) -> ToplevelStatement -> ToplevelStatement
mapTypeToplevelStatement f stmt = case stmt of
  ToplevelAnnAssign x t e ->
    ToplevelAnnAssign x (f t) (onExpr e)
  ToplevelFunctionDef g args ret body ->
    ToplevelFunctionDef g [(arg, f t) | (arg, t) <- args] (f ret) (map (mapTypeStatement f) body)
  ToplevelAssert e ->
    ToplevelAssert (onExpr e)
  where
    onExpr = mapTypeExpr f

-- | `mapTypeProgram` applies @f@ to every type annotation in every toplevel
-- statement of a program.
mapTypeProgram :: (Type -> Type) -> Program -> Program
mapTypeProgram f prog = [mapTypeToplevelStatement f stmt | stmt <- prog]

-- | `substUnit` replaces every type variable that is still present with the
-- unit type `NoneTy`.
--
-- It recurses into list, tuple and callable types. Integer, boolean, string
-- and side-effect types are returned unchanged.
substUnit :: Type -> Type
substUnit ty = case ty of
  VarTy _ -> NoneTy
  ListTy elemTy -> ListTy (substUnit elemTy)
  TupleTy elemTys -> TupleTy (map substUnit elemTys)
  CallableTy argTys retTy -> CallableTy (map substUnit argTys) (substUnit retTy)
  IntTy -> IntTy
  BoolTy -> BoolTy
  StringTy -> StringTy
  SideEffectTy -> SideEffectTy

-- | `subst'` applies the substitution @sigma@ with `subst`, then uses
-- `substUnit` to default any type variable that is still unresolved to the
-- unit type.
subst' :: Subst -> Type -> Type
subst' sigma t = substUnit (subst sigma t)

-- | `run` infers the types of a program and annotates the program with them.
--
-- Viewed as an interface, it:
--
-- 1. finds a type environment \(\Gamma\) such that \(\Gamma \vdash \mathrm{stmt}\) holds for every statement \(\mathrm{stmt}\) of the given program, and
-- 2. rewrites every type annotation in the program according to \(\Gamma\).
--
-- Internally this is essentially Hindley-Milner style inference, in five steps:
--
-- 1. collect equations (`formularizeProgram`),
-- 2. separate the type assertions (`sortEquations`),
-- 3. turn the assertions into further equations (`mergeAssertions`),
-- 4. unify everything (`solveEquations`),
-- 5. substitute the solution back, defaulting leftover type variables to the
--    unit type (`subst'`).
--
-- == Requirements
--
-- * There must be no name conflicts in given programs. They must be alpha-converted. (`Jikka.RestrictedPython.Convert.Alpha`)
-- * All names must be resolved. (`Jikka.RestrictedPython.Convert.ResolveBuiltin`)
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.RestrictedPython.Convert.TypeInfer" $ do
  collected <- formularizeProgram prog
  let (equations, assertions) = sortEquations collected
      fromAssertions = mergeAssertions assertions
  sigma <- solveEquations (equations ++ fromAssertions)
  pure (mapTypeProgram (subst' sigma) prog)