-- | An implementation of LJT proof search directly on Core terms.
module GHC.LJT where

import FastString
import Unique
import Type
import Id
import Var
import CoreSyn
import Outputable
import TyCoRep
import TyCon
import DataCon
import MkCore
import MkId
import CoreUtils
import TysWiredIn
import BasicTypes
import NameEnv
import NameSet
import Coercion

import Data.List
import Data.Hashable
import Control.Monad
import Data.Bifunctor

-- | All proof terms (Core expressions) of the given type that the LJT search
-- finds, starting from an empty set of antecedents.
ljt :: Type -> [CoreExpr]
ljt t = [] ==> t


-- | @ante ==> goal@ searches for Core terms of type @goal@ built from the
-- variables in @ante@, following the rules of the LJT sequent calculus.
--
-- Each clause below implements one rule. Clauses are tried top to bottom,
-- and the first clause whose patterns and guard match is committed to: if it
-- then yields no terms, later clauses are not tried. Results live in the
-- list monad: 'traverse' combines premises that must all be proven, 'msum'
-- collects alternatives of which any one suffices, and @[]@ means "no proof".
(==>) :: [Id] -> Type -> [CoreExpr]

-- Rule Axiom: some antecedent already has exactly the goal type.
-- (TODO: The official algorithm restricts this rule to atoms. Why?)
ante ==> goal
    | Just v <- find (\v -> idType v `eqType` goal) ante
    = pure $ Var v

-- Rule f⇒: an antecedent of an empty type ('isEmptyTy') proves anything,
-- via an alternative-less case on it.
ante ==> goal
    | Just v <- find (\v -> isEmptyTy (idType v)) ante
    = pure $ mkWildCase (Var v) (idType v) goal []

-- Rule →⇒2: replace an antecedent @v :: (A1, …, An) -> C@ by its curried
-- form @\x1 … xn -> v (x1, …, xn) :: A1 -> … -> An -> C@, let-bound to a
-- fresh variable.
ante ==> goal
    | Just ((v,((tys, build, _destruct),_r)),ante') <- anyA (funLeft isProdType) ante
    = let vs = map newVar tys
          expr = mkLams vs (App (Var v) (build (map Var vs)))
          v' = newVar (exprType expr)
      in mkLetNonRec v' expr <$> (v' : ante') ==> goal

-- Rule →⇒3: replace an antecedent @v :: (A1 | … | An) -> C@ by n let-bound
-- functions @\x -> v (inj_i x) :: Ai -> C@, one per summand.
ante ==> goal
    | Just ((v,((tys, injs, _destruct),_r)),ante') <- anyA (funLeft isSumType) ante
    = let es = [ lam ty (\vx -> App (Var v) (inj (Var vx))) | (ty,inj) <- zip tys injs ]
      in letsA es $ \vs -> (vs ++ ante') ==> goal

-- Rule ∧⇒: destructure an antecedent of product type into its fields.
ante ==> goal
    | Just ((v,(tys, _build, destruct)),ante') <- anyA isProdType ante
    = let pats = map newVar tys
      in destruct (Var v) pats <$> (pats ++ ante') ==> goal

-- Rule ⇒∧: prove every field of a product goal; all must succeed.
ante ==> goal
    | Just (tys, build, _destruct) <- isProdType goal
    = build <$> traverse (ante ==>) tys

-- Rule ∨⇒: case on an antecedent of sum type and prove the goal in every
-- branch, each time with that branch's summand added to the antecedents.
ante ==> goal
    | Just ((vAorB, (tys, _injs, destruct)),ante') <- anyA isSumType ante
    = let vs = map newVar tys
      in destruct (Var vAorB) vs <$> traverse (\v -> (v : ante') ==> goal) vs

-- Rule ⇒→: for a goal @t1 -> t2@, bind a fresh @t1@ and prove @t2@.
ante ==> FunTy t1 t2
    = Lam v <$> (v : ante) ==> t2
  where
    v = newVar t1

-- Rule →⇒1: an antecedent @vAB :: A -> B@ whose argument type @A@ is also
-- among the antecedents is replaced by the let-bound @vAB vA :: B@.
-- (TODO: The official algorithm restricts this rule to atoms. Why?)
ante ==> goal
    | let isInAnte a = find (\v -> idType v `eqType` a) ante
    , Just ((vAB, (vA,_)), ante') <- anyA (funLeft isInAnte) ante
    = letA (App (Var vAB) (Var vA)) $ \vB -> (vB : ante') ==> goal

-- Rule ⇒∨: for a sum goal, try to prove each summand and inject it;
-- every success is a result.
ante ==> goal
    | Just (tys, injs, _destruct) <- isSumType goal
    = msum [ inj <$> ante ==> ty | (ty,inj) <- zip tys injs ]

-- Rule →⇒4: for an antecedent @vABC :: (A -> B) -> C@, first prove @A -> B@
-- with @B -> C@ (built as @\b -> vABC (\_ -> b)@) in place of @vABC@; then
-- apply @vABC@ to each such proof and continue with the resulting @C@.
ante ==> goal
    | Just ((vABC, ((a,b),_)), ante') <- anyA (funLeft (funLeft Just)) ante
    = do
        let eBC = lam b $ \vB -> App (Var vABC) (lam a $ \_ -> Var vB)
        eAB <- letA eBC           $ \vBC -> (vBC : ante') ==> FunTy a b
        letA (App (Var vABC) eAB) $ \vC  -> (vC : ante') ==> goal

-- Nothing found :-(
_ante ==> _goal
    = -- pprTrace "go" (vcat [ ppr (idType v) | v <- ante] $$ text "------" $$ ppr goal) $
      mzero

-- Smart constructors

-- | A local variable named @x@ of the given type.
--
-- The unique is derived from a hash of the pretty-printed type, so all
-- variables of (printably) equal type share the same unique. We don’t mind
-- if variables with equal types shadow each other, so let’s just derive the
-- unique from the type instead of threading a unique supply through the
-- search.
--
-- NOTE(review): two distinct types that print identically, or whose
-- renderings collide under 'hash', would also share a unique; an inner
-- binder could then capture an outer variable of a different type.
newVar :: Type -> Id
newVar ty = mkSysLocal (mkFastString "x") (mkBuiltinUnique i) ty
  where i = hash (showSDocUnsafe (ppr ty))

-- | Build a lambda over a fresh variable of type @ty@ (see 'newVar'); the
-- body is produced by @gen@ from that variable.
lam :: Type -> (Id -> CoreExpr) -> CoreExpr
lam ty gen = Lam v $ gen v
  where v = newVar ty

-- | Build a lambda over a fresh variable of type @ty@, where the body is
-- produced by @gen@ inside an applicative context (e.g. the list of search
-- results); the lambda is wrapped around every produced body.
lamA :: Applicative f => Type -> (Id -> f CoreExpr) -> f CoreExpr
lamA ty gen = Lam v <$> gen v
  where v = newVar ty

-- | Non-recursively let-bind @e@ to a fresh variable of @e@'s type and build
-- the body from that variable with @gen@.
let_ :: CoreExpr -> (Id -> CoreExpr) -> CoreExpr
let_ e gen = mkLetNonRec v e $ gen v
  where v = newVar (exprType e)

-- | Non-recursively let-bind @e@ to a fresh variable of @e@'s type, where
-- the body is produced by @gen@ inside an applicative context; the binding
-- is wrapped around every produced body.
letA :: Applicative f => CoreExpr -> (Id -> f CoreExpr) -> f CoreExpr
letA e gen = mkLetNonRec v e <$> gen v
  where v = newVar (exprType e)

-- | Let-bind each expression in @es@ to a fresh variable of its own type, as
-- a sequence of non-recursive lets, and build the body from the list of
-- those variables (same order as @es@) inside an applicative context.
letsA :: Applicative f => [CoreExpr] -> ([Id] -> f CoreExpr) -> f CoreExpr
letsA es gen = mkLets (zipWith NonRec vs es) <$> gen vs
  where vs = map (newVar . exprType) es

-- Predicates on types

-- | Recognise a non-recursive (see 'isRecTyCon') product type: either a
-- single-constructor data type as reported by 'splitDataProductType_maybe',
-- or a newtype.
--
-- On success returns:
--
--   * the instantiated field types;
--   * a builder applying the constructor to one expression per field;
--   * an eliminator: given a scrutinee, one binder per field and a
--     right-hand side, it binds the fields and continues with the rhs.
isProdType :: Type -> Maybe ([Type], [CoreExpr] -> CoreExpr, CoreExpr -> [Id] -> CoreExpr -> CoreExpr)
isProdType ty
    | Just (tc, ty_args, dc, repargs) <- splitDataProductType_maybe ty
    , not (isRecTyCon tc)
    = Just ( repargs
             -- The constructor's type arguments are the tycon's type
             -- arguments, not the field types (they only coincide for
             -- tuples such as @(a, b)@).
             -- NOTE(review): 'mkConApp' applies the constructor worker;
             -- strict/UNPACKed fields may need different argument types.
           , \args -> mkConApp dc (map Type ty_args ++ args)
           , \scrut pats rhs -> mkWildCase scrut ty (exprType rhs) [(DataAlt dc, pats, rhs)]
           )
    | Just (tc, ty_args) <- splitTyConApp_maybe ty
    , Just dc <- newTyConDataCon_maybe tc
    , not (isRecTyCon tc)
    , let repargs = dataConInstArgTys dc ty_args
    = Just ( repargs
             -- A newtype has exactly one field, and the callers in this
             -- module always pass one argument / binder per field.
           , \[arg] -> wrapNewTypeBody tc ty_args arg
           , \scrut [pat] rhs ->
                mkLetNonRec pat (unwrapNewTypeBody tc ty_args scrut) rhs
           )
isProdType _ = Nothing

-- | Recognise a non-recursive (see 'isRecTyCon') algebraic sum type, i.e. a
-- type whose tycon 'isDataSumTyCon_maybe' accepts.
--
-- Haskell sum constructors can have multiple parameters. For our purposes, if
-- so, we wrap them in a product (a boxed tuple), so each summand is one type.
--
-- On success returns:
--
--   * the summand types, one boxed tuple type per constructor;
--   * one injection per constructor, taking an expression of the summand
--     (tuple) type, unpacking it with 'mkSmallTupleCase' and applying the
--     constructor to the components;
--   * an eliminator: given a scrutinee of the sum type, one binder per
--     summand and one right-hand side per summand, it builds a @case@ whose
--     alternatives bind each constructor's fields, let-bind them re-packed
--     as a tuple to the corresponding binder, and continue with the rhs.
isSumType :: Type -> Maybe ([Type], [CoreExpr -> CoreExpr], CoreExpr -> [Id] -> [CoreExpr] -> CoreExpr)
isSumType ty
    | Just (tc, ty_args) <- splitTyConApp_maybe ty
    , Just dcs <- isDataSumTyCon_maybe tc
    , not (isRecTyCon tc)
    = let -- Instantiated field types of each constructor, in constructor order.
          fieldTys = [ dataConInstArgTys dc ty_args | dc <- dcs ]
          tys = map (mkTupleTy Boxed) fieldTys
          injs = [ let vs = map newVar vtys
                   in \e -> mkSmallTupleCase vs
                                (mkConApp dc (map Type ty_args ++ map Var vs))
                                (mkWildValBinder (exprType e)) e
                 | (dc, vtys) <- zip dcs fieldTys ]
          destruct e vs alts =
            Case e (mkWildValBinder (exprType e)) (resultType alts)
              [ let pats = map newVar vtys
                in (DataAlt dc, pats, mkLetNonRec v (mkCoreVarTup pats) rhs)
              | (dc, vtys, v, rhs) <- zip4 dcs fieldTys vs alts ]
          -- All alternatives share one result type. Callers in this module
          -- pass one rhs per summand, so the list is never empty here.
          resultType (rhs : _) = exprType rhs
          resultType []        = pprPanic "GHC.LJT.isSumType" (text "no alternatives")
      in Just (tys, injs, destruct)
isSumType _ = Nothing

-- We don’t want to look into recursive type cons.
-- Which ones are recursive? Surely those that get mentioned in their
-- arguments. Or in type cons in their arguments.
-- But that is not enough, because of higher kinded arguments. So prohibit
-- those as well.

-- | Conservative recursion check used by 'isProdType' and 'isSumType', which
-- refuse to unfold a type constructor for which this returns 'True'.
--
-- Starting from @tc@, we follow the type constructors mentioned in the
-- declared (uninstantiated) field types of each visited tycon's data
-- constructors. The answer is 'True' if a tycon is reached again along the
-- current path, or if any visited tycon has a type parameter whose kind is
-- not 'liftedTypeKind'.
isRecTyCon :: TyCon -> Bool
isRecTyCon tc = go emptyNameSet tc
  where
    go seen tc | tyConName tc `elemNameSet` seen = True
               -- Prohibit higher-kinded parameters, as the comment above
               -- asks. Answering False here would also skip the
               -- self-reference check below, so e.g.
               -- @newtype Fix f = Fix (f (Fix f))@ would count as
               -- non-recursive and could be unfolded without end.
               | any isHigherKind paramKinds     = True
               | any (go seen') mentionedTyCons  = True
               | otherwise                       = False
      where mentionedTyCons = concatMap getTyCons $ concatMap dataConOrigArgTys $ tyConDataCons tc
            paramKinds = map varType (tyConTyVars tc)
            seen' = seen `extendNameSet` tyConName tc

    -- Any parameter kind other than @Type@ (e.g. @* -> *@ or a kind variable).
    isHigherKind :: Kind -> Bool
    isHigherKind k = not (k `eqType` liftedTypeKind)

    -- The type constructors applied in a type, without duplicates
    -- (binder kinds, casts and coercions are not inspected).
    getTyCons :: Type -> [TyCon]
    getTyCons = nameEnvElts . go
      where
        go (TyConApp tc tys) = unitNameEnv (tyConName tc) tc `plusNameEnv` go_s tys
        go (LitTy _)         = emptyNameEnv
        go (TyVarTy _)       = emptyNameEnv
        go (AppTy a b)       = go a `plusNameEnv` go b
        go (FunTy a b)       = go a `plusNameEnv` go b
        go (ForAllTy _ ty)   = go ty
        go (CastTy ty _)     = go ty
        go (CoercionTy _)    = emptyNameEnv
        go_s = foldr (plusNameEnv . go) emptyNameEnv



-- | A copy from MkId.hs, no longer exported there :-(
--
-- Turns an expression of the newtype's representation type into one of type
-- @tycon args@: it casts along the symmetric newtype axiom instantiated at
-- @args@, then passes the result through 'wrapFamInstBody' (presumably a
-- no-op unless @tycon@ is a data-family instance — confirm).
wrapNewTypeBody :: TyCon -> [Type] -> CoreExpr -> CoreExpr
wrapNewTypeBody tycon args result_expr
  = wrapFamInstBody tycon args $
    mkCast result_expr (mkSymCo co)
  where
    co = mkUnbranchedAxInstCo Representational (newTyConCo tycon) args []

-- Combinators to search for matching things

-- | For a function type @a -> r@, apply @p@ to the argument type @a@ and, on
-- success, pair its result with the result type @r@. Types that are not
-- syntactically a 'FunTy' yield 'Nothing'.
funLeft :: (Type -> Maybe a) -> Type -> Maybe (a,Type)
funLeft p (FunTy t1 t2) = (\x -> (x,t2)) <$> p t1
funLeft _ _ = Nothing

-- | Find the first variable whose type satisfies @p@. Returns that variable
-- together with @p@'s result, plus all other variables in their original
-- order (i.e. the found one removed). 'Nothing' if no variable matches.
anyA :: (Type -> Maybe a) -> [Id] -> Maybe ((Id, a), [Id])
anyA _ [] = Nothing
anyA p (v:vs) | Just x <- p (idType v) = Just ((v,x), vs)
              | otherwise              = second (v:) <$> anyA p vs