module Compiler.Typesystem.SystemF where
import Data.Bifunctor.TH
import Calculi.Lambda
import Calculi.Lambda.Cube.SimpleType
import Calculi.Lambda.Cube.Polymorphic
import Calculi.Lambda.Cube.Polymorphic.Unification
import Control.Typecheckable
import Control.Monad
import Control.Monad.State
import Control.Monad.Except
import Control.Lens hiding (elements)
import Data.Either.Combinators
import qualified Data.Set as Set
import qualified Data.Map as Map
import Data.Traversable
import Data.List
import Data.Random.Generics
import Data.Generics
import Data.Semigroup
import Data.Maybe
import Data.Graph.Inductive (Gr)
import Test.QuickCheck
import qualified Language.Haskell.TH.Lift as TH
-- | Type expressions of System F, parameterised over the representation of
-- mono types (@m@) and of type variables (@p@).
data SystemF m p
    = Mono m
      -- ^ A mono type, wrapping a value of @m@.
    | Poly p
      -- ^ A type variable, named by a value of @p@.
    | Function (SystemF m p) (SystemF m p)
      -- ^ A function type; @Function a b@ is rendered as @a -> b@.
    | Forall p (SystemF m p)
      -- ^ A quantifier binding the variable @p@ over the given type expression.
    deriving (Eq, Ord, Data)
-- | Renders a type in arrow notation. Consecutive quantifiers are merged into
-- a single @forall a b c. ...@ prefix, and the left-hand side of an arrow is
-- parenthesised when it is itself a function or a quantified type.
instance (Ord m, Ord p, Show m, Show p) => Show (SystemF m p) where
    show (Mono m) = show m
    show (Poly p) = show p
    show (Function from to) = lhs ++ " -> " ++ show to
      where
        -- The argument side needs brackets whenever it would otherwise
        -- swallow the arrow that follows it.
        lhs
            | isFunction from || isQuantified from = "(" ++ show from ++ ")"
            | otherwise = show from
        isQuantified Forall{} = True
        isQuantified _ = False
    show (Forall p body) =
        "forall " ++ unwords (map show (p : binders)) ++ ". " ++ show inner
      where
        (binders, inner) = peel body
        -- Strip the leading run of quantifiers, returning their variables in
        -- order together with the first non-quantified expression.
        peel (Forall q rest) =
            let (qs, core) = peel rest
            in (q : qs, core)
        peel other = ([], other)
-- Template Haskell splices deriving the two-parameter functor family for
-- 'SystemF' (over the mono type parameter @m@ and the type variable
-- parameter @p@); the derive functions come from "Data.Bifunctor.TH".
deriveBifunctor ''SystemF
deriveBifoldable ''SystemF
deriveBitraversable ''SystemF
-- Derives a Template Haskell @Lift@ instance via "Language.Haskell.TH.Lift"
-- (presumably so 'SystemF' values can be embedded in splices/quasi-quotes;
-- no use site is visible in this part of the file).
TH.deriveLift ''SystemF
-- | 'SystemF' as a simply-typed system: 'Function' is the arrow and 'Mono'
-- wraps the mono types.
instance (Ord m, Ord p) => SimpleType (SystemF m p) where
    type MonoType (SystemF m p) = m

    -- Build a function type from its argument and result types.
    abstract from to = Function from to

    -- Split a function type into (argument, result); anything else is not a
    -- function.
    reify = \case
        Function from to -> Just (from, to)
        _ -> Nothing

    -- The base types of an expression: every non-function leaf, plus the
    -- variable bound by each quantifier that is encountered.
    bases (Function from to) = Set.union (bases from) (bases to)
    bases (Forall p body) = Set.insert (Poly p) (bases body)
    bases leaf = Set.singleton leaf

    mono m = Mono m

    -- Type equivalence is alpha equivalence.
    equivalent a b = areAlphaEquivalent a b
-- | 'SystemF' as a polymorphic type system whose type variables are values
-- of @p@.
instance (Ord m, Ord p) => Polymorphic (SystemF m p) where
    type PolyType (SystemF m p) = p

    -- Collect the substitutions that relate two type expressions.
    --
    -- Quantifiers on either side are skipped. Two variables yield a 'Mutual'
    -- pairing, a variable against any other expression yields a
    -- 'Substitution', and two function types are compared component-wise.
    -- Any remaining pair must be equal; otherwise it is reported as a
    -- mismatch in 'Left'.
    substitutions =
        let
            -- Combine two results: when both sides failed, keep the
            -- mismatches of both; when only one failed, keep that failure;
            -- when both succeeded, append their substitutions.
            (<><>) :: (Semigroup s1, Semigroup s2) => Either s1 s2 -> Either s1 s2 -> Either s1 s2
            (<><>) = curry $ \case
                (Left a, Left b) -> Left (a <> b)
                (a, b)           -> (<>) <$> a <*> b
        in curry $ \case
            (Forall _ expr1    , expr2)              -> substitutions expr1 expr2
            (expr1             , Forall _ expr2)     -> substitutions expr1 expr2
            (Poly p1           , Poly p2)            -> Right [Mutual p1 p2]
            (expr              , Poly p)             -> Right [Substitution expr p]
            (Poly p            , expr)               -> Right [Substitution expr p]
            (Function arg1 ret1, Function arg2 ret2) -> substitutions arg1 arg2 <><> substitutions ret1 ret2
            (expr1             , expr2)
                | expr1 == expr2 -> Right []
                | otherwise      -> Left [(expr1, expr2)]

    -- Replace every occurrence of the variable @target@ with @sub@.
    applySubstitution sub target = applySubstitution' where
        applySubstitution' = \case
            m@Mono{} -> m
            p'@(Poly p)
                | p == target -> sub
                | otherwise   -> p'
            Forall p expr
                -- A quantifier that binds the target is dropped, and the
                -- substitution continues inside its body.
                | p == target -> applySubstitution' expr
                | otherwise   -> Forall p (applySubstitution' expr)
            Function from to -> Function (applySubstitution' from) (applySubstitution' to)

    quantify = Forall

    poly = Poly

    -- The variables bound by a quantifier anywhere in the expression.
    quantifiedOf = \case
        Mono _           -> Set.empty
        Poly _           -> Set.empty
        Function from to -> quantifiedOf from <> quantifiedOf to
        Forall p texpr   -> Set.insert (poly p) (quantifiedOf texpr)

    -- The type variables mentioned in the expression.
    --
    -- Every constructor is handled: previously only 'Mono' and 'Poly' were
    -- matched, so any 'Function' or 'Forall' type failed at runtime with a
    -- non-exhaustive pattern error.
    polytypesOf = \case
        Mono _           -> Set.empty
        p@Poly{}         -> Set.singleton p
        Function from to -> polytypesOf from <> polytypesOf to
        -- NOTE(review): the bound variable is included even when it does not
        -- occur in the body, mirroring how 'bases' and 'quantifiedOf' treat
        -- 'Forall' -- confirm this matches the 'Polymorphic' class contract.
        Forall p texpr   -> Set.insert (poly p) (polytypesOf texpr)
-- | The errors that typechecking against a polymorphic type @t@ can report,
-- for terms with constants @c@ and variables @v@.
data SystemFErr c v t
    = SFNotKnownErr (NotKnownErr c v t)
      -- ^ A variable, constant or type that is not present in the context.
    | SFSimpleTypeErr (SimpleTypeErr t)
      -- ^ A simple-typing failure, e.g. applying something that is not a
      -- function.
    | SFSubsErr (SubsErr Gr t (PolyType t))
      -- ^ A failure reported while unifying two types.

-- The instances are derived standalone because their contexts mention the
-- 'PolyType' type family and so have to be written out by hand.
deriving instance
    (Polymorphic t, Eq v, Eq c) =>
    Eq (SystemFErr c v t)

deriving instance
    (Polymorphic t, Show v, Show c, Show t, Show (PolyType t)) =>
    Show (SystemFErr c v t)
-- | The state threaded through typechecking: a substitution context paired
-- with the simply-typed context.
data SystemFContext c v t p = SystemFContext
    { _polyctx :: SubsContext t p
      -- ^ Substitution context (not consulted by the typechecker visible in
      -- this part of the file).
    , _stlcctx :: SimpleTypingContext c v t
      -- ^ Simply-typed context; the typechecker reads its @variables@,
      -- @constants@ and @allTypes@ through this field.
    }
    deriving (Eq, Ord, Show)

-- Generates the 'polyctx' and 'stlcctx' lenses for the fields above.
makeLenses ''SystemFContext
-- | Typechecking of lambda terms whose annotations are 'SystemF' types.
instance (Ord c, Ord v, Ord m, Ord p) => Typecheckable (LambdaTerm c v) (SystemF m p) where
    type TypingContext (LambdaTerm c v) (SystemF m p) =
        SystemFContext c v (SystemF m p) p

    type TypeError (LambdaTerm c v) (SystemF m p) =
        ErrorContext' (LambdaTerm c v) (SystemF m p) (SystemFErr c v (SystemF m p))

    -- Run the checker over the term, starting from the given context.
    typecheck env term = runTypecheck env (check term)
      where
        check :: LambdaTerm c v (SystemF m p) -> Typecheck (LambdaTerm c v) (SystemF m p) (SystemF m p)
        check current =
            -- Any error escaping this node is rethrown with the node itself
            -- prepended to the error's expression trail.
            attempt `catchError` (throwError . fmap (expression %~ (current :)))
          where
            attempt = case current of
                -- Variables and constants are looked up in the context.
                Variable name -> do
                    known <- use (stlcctx . variables)
                    case Map.lookup name known of
                        Just ty -> return ty
                        Nothing -> throwErrorContext [] (SFNotKnownErr (UnknownVariable name))

                Constant name -> do
                    known <- use (stlcctx . constants)
                    case Map.lookup name known of
                        Just ty -> return ty
                        Nothing -> throwErrorContext [] (SFNotKnownErr (UnknownConstant name))

                -- The function's parameter type is unified with the
                -- argument's type, and the resulting substitutions are
                -- applied to the function's result type.
                Apply fun arg -> do
                    funType <- check fun
                    argType <- check arg
                    (paramType, resultType) <-
                        maybe
                            (throwErrorContext [fun] (SFSimpleTypeErr (NotAFunction funType)))
                            return
                            (reify funType)
                    either
                        (throwErrorContexts . fmap ((,) [] . SFSubsErr))
                        (\subs -> return (unvalidatedApplyAllSubs subs resultType))
                        (unify paramType argType)

                -- The body is checked with the bound variable (and the type
                -- variables its annotation declares) in scope; afterwards
                -- the previous state is restored.
                Lambda (var, varType) body -> do
                    unknown <- fmap (calcUnknownTypes varType) (use (stlcctx . allTypes))
                    unless (null unknown) (throwErrorContexts (fmap ((,) []) unknown))
                    saved <- get
                    stlcctx . variables %= Map.insert var varType
                    stlcctx . allTypes %= Set.union (outmostDeclaredPolys varType)
                    bodyType <- check body
                    put saved
                    return (varType /-> bodyType)

        -- One error for every base type of the annotation that is missing
        -- from the set of known types.
        calcUnknownTypes ty known =
            map (SFNotKnownErr . UnknownType) (Set.toList (bases ty `Set.difference` known))

        -- The variables bound by the leading run of quantifiers of a type.
        outmostDeclaredPolys (Forall bound rest) = Set.insert (poly bound) (outmostDeclaredPolys rest)
        outmostDeclaredPolys _ = Set.empty
-- | Random 'SystemF' types: a generically generated expression has all of its
-- quantifiers stripped and is then passed through @generalise'@.
instance (Ord m, Ord p, Arbitrary m, Data m, Arbitrary p, Data p) => Arbitrary (SystemF m p) where
    arbitrary = fmap (generalise' . stripQuantifiers) (sized (generatorSRWith leafGenerators))
      where
        -- Generate the @m@ and @p@ leaves with their own 'Arbitrary'
        -- instances rather than generically.
        leafGenerators :: [AliasR Gen]
        leafGenerators =
            [ aliasR (\() -> arbitrary :: Gen m)
            , aliasR (\() -> arbitrary :: Gen p)
            ]

        -- Remove every quantifier, at any depth, keeping the quantified body.
        stripQuantifiers :: SystemF m p -> SystemF m p
        stripQuantifiers = \case
            Forall _ body -> stripQuantifiers body
            Function from to -> stripQuantifiers from /-> stripQuantifiers to
            leaf -> leaf