{-# language CPP #-}
module AutoApply
( autoapply
, autoapplyDecs
) where
import Control.Applicative
import Control.Arrow ( (>>>) )
import Control.Monad
#if __GLASGOW_HASKELL__ < 808
import Control.Monad.Fail ( MonadFail
)
#endif
import Control.Monad.Logic ( LogicT
, observeManyT
)
import Control.Monad.Trans as T
import Control.Monad.Trans.Except
import Control.Unification
import Control.Unification.IntVar
import Control.Unification.Types
import Data.Foldable
import Data.Functor
import Data.Functor.Fixedpoint
import Data.Maybe
import Data.Traversable
import Language.Haskell.TH
import Language.Haskell.TH.Desugar
import Prelude hiding ( pred )
-- | Build an expression that applies @fun@, filling in as many of its
-- arguments as possible from the values named in @givens@.
--
-- Every name is looked up with 'reifyVal' (which fails in 'Q' when the name
-- does not refer to a value); the givens are reified first, in list order,
-- and the function last. The actual expression is produced by 'autoapply1'.
autoapply :: [Name] -> Name -> Q Exp
autoapply givens fun = do
  candidates <- traverse reifyGiven givens
  (targetName, targetType) <- reifyVal "Function" fun
  autoapply1 candidates (Function targetName targetType)
 where
  -- Reify one given and package its name and type.
  reifyGiven given = do
    (givenName, givenType) <- reifyVal "Argument" given
    pure (Given givenName givenType)
-- | Generate one top-level function declaration per name in @funs@.
--
-- Each declaration is named by applying @getNewName@ to the base name of the
-- original function, and its body is the expression 'autoapply1' produces
-- for that function with the values named in @givens@ as candidates.
--
-- All names are looked up with 'reifyVal', givens first, so a name that is
-- not a value makes the whole splice fail.
autoapplyDecs :: (String -> String) -> [Name] -> [Name] -> Q [Dec]
autoapplyDecs getNewName givens funs = do
  givenInfos <- traverse (fmap (uncurry Given) . reifyVal "Argument") givens
  funInfos   <- traverse (fmap (uncurry Function) . reifyVal "Function") funs
  for funInfos $ \fun -> do
    body <- autoapply1 givenInfos fun
    let wrapperName = mkName (getNewName (nameBase (fName fun)))
    -- A single clause with no patterns: the lambda built by 'autoapply1'
    -- carries the remaining parameters.
    pure (FunD wrapperName [Clause [] (NormalB body) []])
-- | A value that 'autoapply1' may use to fill in an argument of the wrapped
-- function. Built from the result of 'reifyVal' in 'autoapply' and
-- 'autoapplyDecs'.
data Given = Given
  { gName :: Name
    -- ^ The name returned by reification; this is what the generated code
    -- refers to.
  , gType :: DType
    -- ^ The desugared type returned by reification.
  }
  deriving (Show)
-- | The function whose arguments are to be filled in. Built from the result
-- of 'reifyVal' in 'autoapply' and 'autoapplyDecs'.
data Function = Function
  { fName :: Name
    -- ^ The name returned by reification; the generated expression applies
    -- this name.
  , fType :: DType
    -- ^ The desugared type returned by reification; 'autoapply1' splits it
    -- with 'unravel'.
  }
  deriving (Show)
-- | Build the wrapper expression for one function.
--
-- Outline:
--
-- 1. Split the function type with 'unravel' into binder names, predicates,
--    argument types and return type.
-- 2. In @genProvs@, give every binder a fresh unification variable and, for
--    each argument, try the givens in list order (each given as a plain
--    value, and additionally as a bound action when its type is an
--    application @m a@ with an 'Applicative' instance for @m@). Arguments
--    no given unifies with become lambda parameters.
-- 3. Check that the return type is compatible with any bound action and that
--    the function's predicates have instances; otherwise the branch is
--    abandoned with 'empty'.
-- 4. Take the first solution and emit a lambda over the unfilled arguments
--    whose body is a desugared @do@ block binding the actions and applying
--    the function.
--
-- Fails in 'Q' if no solution is found, or via 'typeDtoF' for types it does
-- not support.
autoapply1 :: [Given] -> Function -> Q Exp
autoapply1 givens fun = do
  let
    (fmap varBndrName -> cmdVarNames, preds, args, ret) = unravel (fType fun)
    -- Run @m@; if it produces no result, yield 'Nothing' instead of failing.
    defaultMaybe m = (Just <$> m) <|> pure Nothing
    liftQ :: Q a -> IntBindingT TypeF (LogicT Q) a
    liftQ = T.lift . T.lift
    -- Turn a unification failure ('Left') into 'empty' in the underlying
    -- monad.
    errorToLogic go = runExceptT go >>= \case
      Left (_ :: UFailure TypeF IntVar) -> empty
      Right x -> pure x
    -- Bind every variable still free in @t@ to a 'VarF' carrying a freshly
    -- generated name, so that the term can be frozen afterwards.
    quant t = do
      vs <- getFreeVars t
      for_ vs $ \v -> bindVar v . (UTerm . VarF) =<< liftQ (newName "a")
    genProvs :: LogicT Q [ArgProvenance]
    genProvs = evalIntBindingT $ do
      -- One unification variable per type variable bound by the function.
      cmdVars <- sequence [ (n, ) <$> freeVar | n <- cmdVarNames ]
      instArgs <- traverse
        (fmap (instWithVars cmdVars . snd) . liftQ . typeDtoF)
        args
      -- Stands for the type constructor shared by all bound givens and the
      -- return type.
      cmdM <- UVar <$> freeVar
      retInst <- fmap (instWithVars cmdVars . snd) . liftQ . typeDtoF $ ret
      -- Candidate list of (type to unify with the argument, extra predicate
      -- to run before unifying, provenance to record on success).
      instGivens <- fmap concat . for givens $ \g@Given {..} -> do
        -- The given used directly as a value; no extra predicate.
        nonApp <- do
          instTy <- uncurry inst <=< liftQ . typeDtoF $ gType
          v <- liftQ $ newName "g"
          pure (instTy, pure (), BoundPure v g)
        -- The given used as an action whose result is bound first. Only
        -- offered when its type (foralls stripped) is an application @m a@
        -- and @m@ has an 'Applicative' instance.
        -- NOTE(review): @m@ and @a@ are instantiated by separate calls to
        -- 'inst', so a variable shared between them gets two distinct
        -- unification variables; confirm this is intended.
        app <- case stripForall gType of
          (vars, DAppT m a) ->
            liftQ (isInstance ''Applicative [sweeten m]) >>= \case
              False -> pure Nothing
              True -> do
                m' <- inst vars . snd <=< liftQ . typeDtoF $ m
                a' <- inst vars . snd <=< liftQ . typeDtoF $ a
                v <- liftQ $ newName "g"
                -- Using this candidate requires its @m@ to unify with cmdM.
                let predicate = do
                      _ <- unify m' cmdM
                      pure ()
                pure $ Just (a', predicate, Bound v g)
          _ -> pure Nothing
        pure ([nonApp] <> toList app)
      -- For every argument, try the candidates in order; 'Nothing' means no
      -- candidate could be unified with the argument type.
      as <- for instArgs $ \argTy ->
        defaultMaybe . asum $ instGivens <&> \(givenTy, predicate, g) -> do
          _ <- errorToLogic $ do
            predicate
            -- Freshen so the same given can be used at several arguments
            -- without its variables being shared between uses.
            freshGivenTy <- freshen givenTy
            unify freshGivenTy argTy
          pure g
      -- If any argument comes from a bound action, the return type must have
      -- the form @cmdM a@ and its head must have an 'Applicative' instance.
      -- NOTE(review): the generated code sequences the bindings through
      -- 'dsDoStmts'; confirm that 'Applicative' (rather than 'Monad') is the
      -- intended class to check here.
      when (any isMonadicBind (catMaybes as)) $ do
        a <- UVar <$> freeVar
        ret' <- errorToLogic $ unify retInst (UTerm (AppF cmdM a))
        quant ret'
        retFrozen <- freeze <$> errorToLogic (applyBindings ret')
        case retFrozen of
          Just (Fix (AppF m _)) -> do
            let typeD = typeFtoD m
            liftQ (isInstance ''Applicative [sweeten typeD]) >>= \case
              False -> empty
              True -> pure ()
          -- 'freeze' returned 'Nothing' although 'quant' was run first;
          -- reported as a hard failure rather than an abandoned branch.
          Nothing ->
            liftQ
              $ fail
                  "\"impossible\", return type didn't freeze while checking monadic bindings"
          _ -> empty
      -- Every predicate of the function, after substitution, must have an
      -- instance; otherwise this branch is abandoned.
      for_ preds $ \pred -> do
        instPred <- fmap (instWithVars cmdVars . snd) . liftQ . typeDtoF $ pred
        quant instPred
        instFrozen <- freeze <$> errorToLogic (applyBindings instPred)
        case instFrozen of
          Just f -> do
            -- Split the predicate into class and arguments; only normal
            -- (non-kind) arguments are passed to 'isInstance'.
            let (class', predArgs) = unfoldDType (typeFtoD f)
                typeArgs = [ a | DTANormal a <- predArgs ]
            className <- case class' of
              DConT n -> pure n
              _ -> liftQ $ fail "unfolded predicate didn't begin with a ConT"
            liftQ (isInstance className (sweeten <$> typeArgs)) >>= \case
              False -> empty
              True -> pure ()
          Nothing ->
            liftQ
              $ fail
                  "\"impossible\": predicate didn't freeze while checking predicates"
      -- Arguments without a matching given become explicit parameters with
      -- a fresh name and the argument's original type.
      for (zip args as) $ \case
        (_, Just p ) -> pure p
        (t, Nothing) -> (`Argument` t) <$> liftQ (newName "a")
  -- Only the first solution is requested.
  argProvenances <-
    note "\"Impossible\" Finding argument provenances failed"
    . listToMaybe
    =<< observeManyT 1 genProvs
  unless (length argProvenances == length args) $ fail
    "\"Impossible\", incorrect number of argument provenances were found"
  -- Only 'Bound' provenances need a bind statement: @n <- given@.
  let bindGiven = \case
        BoundPure _ _ -> Nothing
        Bound n g -> Just $ BindS (VarP n) (VarE (gName g))
        Argument _ _ -> Nothing
      bs = catMaybes (bindGiven <$> argProvenances)
      -- The function applied to: the bound variable, the given's own name
      -- (the fresh name stored in 'BoundPure' is not used), or the lambda
      -- parameter.
      ret' = applyDExp
        (DVarE (fName fun))
        (argProvenances <&> \case
          Bound n _ -> DVarE n
          BoundPure _ (Given n _) -> DVarE n
          Argument n _ -> DVarE n
        )
  exp' <- dsDoStmts (bs <> [NoBindS (sweeten ret')])
  -- Lambda over the unfilled arguments, each annotated with its type.
  pure $ LamE [ SigP (VarP n) (sweeten t) | Argument n t <- argProvenances ]
              (sweeten exp')
-- | Where the value for one argument of the wrapped function comes from, as
-- decided by 'autoapply1'.
data ArgProvenance
  -- The given is run as an action and its result bound to the 'Name', which
  -- is then passed to the function.
  = Bound Name Given
  -- The given is passed to the function directly by its own name; the
  -- 'Name' stored here is not used when the expression is generated.
  | BoundPure Name Given
  -- No given matched: the argument becomes a lambda parameter with this
  -- name and type.
  | Argument Name DType
  deriving (Show)
-- | 'True' exactly for 'Bound' provenances, i.e. those that need a bind
-- statement in the generated code; 'False' for every other constructor.
isMonadicBind :: ArgProvenance -> Bool
isMonadicBind (Bound _ _) = True
isMonadicBind _           = False
-- | One layer of the type representation used for unification; @a@ marks the
-- positions of sub-terms. 'typeDtoF' and 'typeFtoD' convert between
-- @Fix TypeF@ and 'DType'.
data TypeF a
  -- Application of one type to another (corresponds to 'DAppT').
  = AppF a a
  -- A named type variable (corresponds to 'DVarT').
  | VarF Name
  -- A named type constructor (corresponds to 'DConT').
  | ConF Name
  -- The function arrow (corresponds to 'DArrowT').
  | ArrowF
  -- A type-level literal (corresponds to 'DLitT').
  | LitF TyLit
  deriving (Show, Functor, Foldable, Traversable)
-- | Structural matching of a single layer: two layers match only when they
-- use the same constructor and, for leaves, carry equal payloads.
-- Applications match unconditionally at this level, pairing up their
-- function and argument sub-terms for further unification.
instance Unifiable TypeF where
  zipMatch lhs rhs = case (lhs, rhs) of
    (AppF l1 r1, AppF l2 r2) -> Just (AppF (Right (l1, l2)) (Right (r1, r2)))
    (VarF n1, VarF n2) | n1 == n2 -> Just (VarF n1)
    (ConF n1, ConF n2) | n1 == n2 -> Just (ConF n1)
    (ArrowF, ArrowF)             -> Just ArrowF
    (LitF l1, LitF l2) | l1 == l2 -> Just (LitF l1)
    -- Different constructors, or same constructor with unequal payloads.
    _                            -> Nothing
-- | Convert a 'DType' into the 'TypeF' representation.
--
-- Foralls and contexts on the spine are removed first with 'stripForall';
-- the names of the stripped binders are returned alongside the converted
-- body. Kind applications and kind signatures are dropped.
--
-- Fails (via 'fail') on a nested forall or context, and on wildcards.
typeDtoF :: MonadFail m => DType -> m ([Name], Fix TypeF)
typeDtoF ty = do
  converted <- convert body
  pure (binders, converted)
 where
  (binders, body) = stripForall ty

  convert = \case
    DVarT n         -> wrap (VarF n)
    DConT n         -> wrap (ConF n)
    DArrowT         -> wrap ArrowF
    DLitT l         -> wrap (LitF l)
    DAppT l r       -> do
      l' <- convert l
      r' <- convert r
      wrap (AppF l' r')
    -- Kind information is not represented in 'TypeF'.
    DAppKindT t _   -> convert t
    DSigT t _       -> convert t
    DForallT{}      -> fail "TODO: Higher ranked types"
    DConstrainedT{} -> fail "TODO: Higher ranked types"
    DWildCardT      -> fail "TODO: Wildcards"

  wrap layer = pure (Fix layer)
-- | Convert the 'TypeF' representation back into a 'DType', mapping each
-- constructor to its 'DType' counterpart.
typeFtoD :: Fix TypeF -> DType
typeFtoD (Fix layer) = case layer of
  AppF l r -> typeFtoD l `DAppT` typeFtoD r
  VarF n   -> DVarT n
  ConF n   -> DConT n
  ArrowF   -> DArrowT
  LitF l   -> DLitT l
-- | The name bound by a type variable binder, discarding any kind
-- annotation.
varBndrName :: DTyVarBndr -> Name
varBndrName (DPlainTV n)    = n
varBndrName (DKindedTV n _) = n
-- | Hoist every @forall@ and context found along the result spine of a type
-- to the top.
--
-- The traversal looks through 'DForallT', 'DConstrainedT' and the result
-- side of ':~>'; argument types are left untouched. Binders and constraints
-- keep their outer-to-inner order. The result always has the shape
-- @DForallT _ vs (DConstrainedT ctx t)@, even when @vs@ or @ctx@ is empty.
--
-- The rebuilt quantifier is tagged 'ForallInvis': the binders collected here
-- come from ordinary (invisible) quantification, so re-wrapping them as a
-- visible @forall a ->@ would misdescribe the type. The flag of each
-- original 'DForallT' is discarded during collection.
raiseForalls :: DType -> DType
raiseForalls = go >>> \case
  (vs, ctx, t) -> DForallT ForallInvis vs . DConstrainedT ctx $ t
 where
  -- Collect (binders, context, remaining type) from the spine.
  go = \case
    DForallT _ vs t -> let (vs', ctx', t') = go t in (vs <> vs', ctx', t')
    DConstrainedT ctx t ->
      let (vs', ctx', t') = go t in (vs', ctx <> ctx', t')
    -- Only the result side of an arrow is searched.
    l :~> r -> let (vs, ctx, r') = go r in (vs, ctx, l :~> r')
    t       -> ([], [], t)
-- | Bidirectional pattern for a function type: @l :~> r@ is 'DArrowT'
-- applied first to @l@ and then to @r@.
pattern (:~>) :: DType -> DType -> DType
pattern l :~> r = DArrowT `DAppT` l `DAppT` r
-- | Instantiate a type: allocate one fresh unification variable for each of
-- the given names (in list order) and substitute them into the type with
-- 'instWithVars'.
inst
  :: BindingMonad TypeF IntVar m
  => [Name]
  -> Fix TypeF
  -> m (UTerm TypeF IntVar)
inst ns t = do
  fresh <- traverse (\n -> (,) n <$> freeVar) ns
  pure (instWithVars fresh t)
-- | Turn a closed 'TypeF' tree into a unification term, replacing every
-- type variable whose name appears in the association list by the
-- corresponding unification variable (first match wins). Variables that are
-- not listed stay as rigid 'VarF' leaves.
instWithVars :: [(Name, IntVar)] -> Fix TypeF -> UTerm TypeF IntVar
instWithVars vs = subst
 where
  subst (Fix layer) = case layer of
    AppF l r -> UTerm (AppF (subst l) (subst r))
    VarF n   -> maybe (UTerm (VarF n)) UVar (lookup n vs)
    ConF n   -> UTerm (ConF n)
    ArrowF   -> UTerm ArrowF
    LitF l   -> UTerm (LitF l)
-- | Reify a name that is expected to be a value, returning the reified name
-- together with its desugared type.
--
-- The first argument describes the role of the name and only appears as a
-- prefix of the error message. Fails in 'Q' when reification yields
-- anything other than a 'DVarI'.
reifyVal :: String -> Name -> Q (Name, DType)
reifyVal d n = do
  info <- dsReify n
  case info of
    Just (DVarI name ty _) -> pure (name, ty)
    _                      -> fail (d <> " " <> show n <> " isn't a value")
-- | Remove the quantifiers and context that 'raiseForalls' brings to the top
-- of a type, returning the names of the removed binders and the remaining
-- type. The context itself is discarded.
stripForall :: DType -> ([Name], DType)
stripForall ty = case raiseForalls ty of
  DForallT _ bndrs (DConstrainedT _ body) -> (names bndrs, body)
  DForallT _ bndrs body                   -> (names bndrs, body)
  DConstrainedT _ body                    -> ([], body)
  body                                    -> ([], body)
 where
  names bndrs = map varBndrName bndrs
-- | Split a function type into its quantified binders, its constraints, its
-- argument types and its return type.
--
-- The type is decomposed with 'unravelDType'; binders, constraints and
-- arguments are each gathered in the order they occur, regardless of how
-- they are interleaved in the original type.
unravel :: DType -> ([DTyVarBndr], [DPred], [DType], DType)
unravel t = (vs, preds, args, ret)
 where
  (argList, ret)    = unravelDType t
  (vs, preds, args) = collect argList

  -- Sort each entry of the argument list into the matching component.
  collect = \case
    DFANil -> ([], [], [])
    DFAForalls _ bndrs rest ->
      let (v, p, a) = collect rest in (bndrs <> v, p, a)
    DFACxt ctx rest ->
      let (v, p, a) = collect rest in (v, ctx <> p, a)
    DFAAnon arg rest ->
      let (v, p, a) = collect rest in (v, p, arg : a)
-- | Extract the value from a 'Maybe', calling 'fail' with the given message
-- when it is 'Nothing'.
note :: MonadFail m => String -> Maybe a -> m a
note s Nothing  = fail s
note _ (Just x) = pure x