{-# LANGUAGE CPP
, BangPatterns
, DataKinds
, EmptyCase
, ExistentialQuantification
, FlexibleContexts
, FlexibleInstances
, GADTs
, GeneralizedNewtypeDeriving
, KindSignatures
, MultiParamTypeClasses
, OverloadedStrings
, PolyKinds
, ScopedTypeVariables
, StandaloneDeriving
, TupleSections
, TypeFamilies
, TypeOperators
, UndecidableInstances
#-}
{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
module Language.Hakaru.Syntax.Hoist (hoist) where
import Control.Applicative (liftA2)
import Control.Monad.RWS
import qualified Data.Foldable as F
import qualified Data.Graph as G
import qualified Data.IntMap.Strict as IM
import qualified Data.List as L
import Data.Maybe (mapMaybe)
import Data.Number.Nat
import Data.Proxy (KProxy (..))
import qualified Data.Vector as V
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.ANF (isValue)
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.AST.Eq (alphaEq)
import Language.Hakaru.Syntax.Gensym
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing (Sing)
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative
#endif
-- | A candidate for hoisting: one expression together with every variable
-- that has been found to be bound to it. The Hakaru type of the expression is
-- hidden existentially, so the singleton is kept alongside to allow two
-- entries to be compared at run time.
data Entry (abt :: Hakaru -> *) = forall (a :: Hakaru) . Entry
  { varDependencies :: !(VarSet (KindOf a))
    -- ^ Variables the expression depends on ('recordEntry' stores the
    -- expression's free variables here).
  , expression      :: !(abt a)
    -- ^ The expression to be let-bound.
  , sing            :: !(Sing a)
    -- ^ Singleton witnessing the expression's type.
  , bindings        :: ![Variable a]
    -- ^ All variables known to be bound to this expression.
  }
-- | Debugging aid: shows only the dependency set and the bound variables; the
-- expression and its type singleton are omitted.
instance Show (Entry abt) where
  show (Entry deps _ _ vars) =
    concat ["Entry (", show deps, ") (", show vars, ")"]
-- | The promoted 'KProxy' at kind 'Hakaru'.
type HakaruProxy = ('KProxy :: KProxy Hakaru)
-- | A set of Hakaru variables; this is the Reader environment of 'HoistM'.
type LiveSet = VarSet HakaruProxy
-- | A Hakaru variable whose type index is hidden behind 'SomeVariable'.
type HakaruVar = SomeVariable HakaruProxy
-- | The monad in which hoisting runs. It is an 'RWS' stack, so information can
-- flow both down towards the leaves and up towards the root of the AST:
--
-- * Reader ('LiveSet'): the variables bound so far on the path from the root,
--   used to decide where new bindings have to be introduced.
--
-- * Writer ('ExpressionSet'): the expressions that are still waiting to be
--   hoisted, passed from each subexpression up to its parent.
--
-- * State ('Nat'): a counter for generating fresh variables. Original
--   variable identifiers are preserved; new variables are allocated outside
--   the range of the existing ones.
newtype HoistM (abt :: [Hakaru] -> Hakaru -> *) a = HoistM
  { runHoistM :: RWS LiveSet (ExpressionSet abt) Nat a
  }

deriving instance Functor (HoistM abt)
deriving instance (ABT Term abt) => Applicative (HoistM abt)
deriving instance (ABT Term abt) => Monad (HoistM abt)
deriving instance (ABT Term abt) => MonadState Nat (HoistM abt)
deriving instance (ABT Term abt) => MonadWriter (ExpressionSet abt) (HoistM abt)
deriving instance (ABT Term abt) => MonadReader LiveSet (HoistM abt)
-- | A collection of closed-term entries awaiting introduction. This is the
-- Writer output of 'HoistM'.
newtype ExpressionSet (abt :: [Hakaru] -> Hakaru -> *) =
  ExpressionSet [Entry (abt '[])]
-- | Combine two entries that describe the same expression. The dependencies,
-- expression and type of the first entry are kept; the bound variables of both
-- are concatenated and de-duplicated. Calls 'error' if the two entries have
-- different types.
mergeEntry :: (ABT Term abt) => Entry (abt '[]) -> Entry (abt '[]) -> Entry (abt '[])
mergeEntry (Entry deps expr ty vs) (Entry _ _ ty' ws)
  | Just Refl <- jmEq1 ty ty' = Entry deps expr ty (L.nub (vs ++ ws))
  | otherwise                 = error "cannot union mismatched entries"
-- | Two entries are equal when they have the same dependency set, the same
-- type, and α-equivalent expressions. The bound variables are not compared.
entryEqual :: (ABT Term abt) => Entry (abt '[]) -> Entry (abt '[]) -> Bool
entryEqual (Entry deps1 expr1 ty1 _) (Entry deps2 expr2 ty2 _)
  | deps1 == deps2
  , Just Refl <- jmEq1 ty1 ty2 = alphaEq expr1 expr2
  | otherwise                  = False
-- | Union of two entry sets, merging entries that are equal according to
-- 'entryEqual' (same dependencies, same type, α-equivalent expression) into a
-- single entry that carries the bound variables of all of them.
--
-- Every entry is compared against each class found so far, so equal entries
-- are merged wherever they occur in the two inputs. (Grouping with
-- 'L.groupBy' would only merge entries that happen to be adjacent.) Classes
-- appear in the result in order of first occurrence, and the first entry of
-- each class supplies its representative expression.
unionEntrySet
  :: forall abt
  .  (ABT Term abt)
  => ExpressionSet abt
  -> ExpressionSet abt
  -> ExpressionSet abt
unionEntrySet (ExpressionSet xs) (ExpressionSet ys) =
  ExpressionSet $ L.foldl' (flip absorb) [] (xs ++ ys)
  where
    -- Merge the entry into the first class it is equal to, or start a new
    -- class at the end of the list if there is none.
    absorb :: Entry (abt '[]) -> [Entry (abt '[])] -> [Entry (abt '[])]
    absorb e [] = [e]
    absorb e (c : cs)
      | entryEqual c e = mergeEntry c e : cs
      | otherwise      = c : absorb e cs
-- | Intersection of two entry sets: for every pair drawn from the two inputs
-- that is equal according to 'entryEqual', the merged entry is kept.
intersectEntrySet
  :: forall abt
  .  (ABT Term abt)
  => ExpressionSet abt
  -> ExpressionSet abt
  -> ExpressionSet abt
intersectEntrySet (ExpressionSet xs) (ExpressionSet ys) =
  ExpressionSet
    [ mergeEntry x y
    | (x, y) <- liftA2 (,) xs ys
    , entryEqual x y
    ]
-- The general case for generating the entry set for a term is to simply union
-- the sets for all the subterms, so we choose union as our monoidal operation
-- for the Writer monad.
--
-- From GHC 8.4 (base-4.11) onwards 'Semigroup' is a superclass of 'Monoid', so
-- a 'Semigroup' instance must be provided there; older compilers neither need
-- nor (without an extra import) know about it.
#if __GLASGOW_HASKELL__ >= 804
instance (ABT Term abt) => Semigroup (ExpressionSet abt) where
  (<>) = unionEntrySet
#endif

instance (ABT Term abt) => Monoid (ExpressionSet abt) where
  mempty  = ExpressionSet []
#if __GLASGOW_HASKELL__ >= 804
  mappend = (<>)
#else
  mappend = unionEntrySet
#endif
-- Given a list of entries to introduce, order them so that their data
-- dependencies are satisfied: an entry comes after every entry in the list
-- that binds a variable it depends on.
topSortEntries
  :: forall abt
  .  [Entry (abt '[])]
  -> [Entry (abt '[])]
topSortEntries entryList = [ nodes V.! i | i <- G.topSort depGraph ]
  where
    -- Entries indexed by position; the position is the graph vertex.
    nodes :: V.Vector (Entry (abt '[]))
    !nodes = V.fromList entryList

    -- IDs of the variables an entry binds.
    boundIDs :: Entry (abt '[]) -> [Int]
    boundIDs (Entry _ _ _ vs) = [ fromNat (varID v) | v <- vs ]

    -- Record the entry at index @idx@ as the owner of each variable it binds.
    -- One entry may own several variables, since it stores every variable
    -- whose definition is α-equivalent.
    register :: IM.IntMap Int -> Int -> Entry (abt '[]) -> IM.IntMap Int
    register table idx entry =
      L.foldl' (\acc vid -> IM.insert vid idx acc) table (boundIDs entry)

    -- Mapping from variable ID to the index of the entry that binds it.
    ownerOf :: IM.IntMap Int
    !ownerOf = V.ifoldl' register IM.empty nodes

    -- One edge from the owner of each dependency to the entry itself.
    -- Dependencies that no entry in the list binds contribute no edge.
    edgesInto :: Int -> Entry (abt '[]) -> [G.Edge]
    edgesInto idx (Entry deps _ _ _) =
      [ (src, idx) | src <- mapMaybe (`IM.lookup` ownerOf) (varSetKeys deps) ]

    -- All edges of the dependency graph.
    allEdges :: [G.Edge]
    !allEdges = V.ifoldr (\idx entry rest -> edgesInto idx entry ++ rest) [] nodes

    -- The graph to be topologically sorted, with one vertex per entry.
    depGraph :: G.Graph
    !depGraph = G.buildG (0, V.length nodes - 1) allEdges
-- | Emit, through the Writer layer, a single entry stating that the variable
-- is bound to the expression. The entry's dependencies are the expression's
-- free variables and its type is taken from the variable.
recordEntry
  :: (ABT Term abt)
  => Variable a
  -> abt '[] a
  -> HoistM abt ()
recordEntry v e = tell (ExpressionSet [entry])
  where
    entry = Entry (freeVars e) e (varType v) [v]
-- | Run a hoisting computation with an empty set of bound variables and the
-- given initial fresh-variable counter. Only the result is returned; the final
-- counter and the Writer output are discarded.
execHoistM :: Nat -> HoistM abt a -> a
execHoistM counter act =
  case runRWS (runHoistM act) emptyVarSet counter of
    (result, _, _) -> result
-- | An entry is "toplevel" when its expression has no data dependencies at
-- all, which means it can be hoisted outside every binder.
toplevelEntry
  :: Entry abt
  -> Bool
toplevelEntry (Entry deps _ _ _) = 0 == sizeVarSet deps
-- | Run an action and return the entries it wrote alongside its result. The
-- entries are removed from the Writer output, so the caller decides what
-- becomes of them.
captureEntries
  :: (ABT Term abt)
  => HoistM abt a
  -> HoistM abt (a, ExpressionSet abt)
captureEntries act = censor (const mempty) (listen act)
-- | Entry point of the pass. The term is transformed with @hoist'@, and every
-- entry that is still pending at the root is then introduced around the
-- result. Fresh variables are numbered from @nextFreeOrBind@ of the input
-- term.
hoist
  :: (ABT Term abt)
  => abt '[] a
  -> abt '[] a
hoist term = execHoistM (nextFreeOrBind term) $ do
  -- The lazy pattern keeps the laziness of the @uncurry@ formulation.
  ~(term', pending) <- captureEntries (hoist' term)
  introduceToplevel emptyVarSet term' pending
-- | Split an entry set by a predicate: entries that satisfy it come first,
-- the others second.
partitionEntrySet
  :: (Entry (abt '[]) -> Bool)
  -> ExpressionSet abt
  -> (ExpressionSet abt, ExpressionSet abt)
partitionEntrySet keep (ExpressionSet es) =
  let (kept, dropped) = L.partition keep es
  in  (ExpressionSet kept, ExpressionSet dropped)
-- | Wrap an already transformed term in bindings for the given entries.
--
-- The entries are split into those with no data dependencies ("toplevel") and
-- the rest. The term is first wrapped in the remaining entries whose
-- dependencies are met by the toplevel bindings or by the @avail@ variables.
-- The result is then wrapped in the toplevel bindings themselves, most of
-- which should later be eliminated by constant propagation.
introduceToplevel
  :: (ABT Term abt)
  => LiveSet
  -> abt '[] a
  -> ExpressionSet abt
  -> HoistM abt (abt '[] a)
introduceToplevel avail abt entries = do
  inner <- introduceBindings seeds abt dependent
  wrapExpr inner closed
  where
    -- Entries without dependencies, and everything else.
    (ExpressionSet closed, dependent) = partitionEntrySet toplevelEntry entries
    -- Variables treated as newly bound at this point.
    seeds = concatMap getBoundVars closed ++ fromVarSet avail
-- | Run an action with the variable added to the set of bound variables held
-- in the Reader layer.
bindVar
  :: (ABT Term abt)
  => Variable (a :: Hakaru)
  -> HoistM abt b
  -> HoistM abt b
bindVar v act = local (insertVarSet v) act
-- | Run an action with the variable bound, and capture the entries it writes
-- instead of passing them on to the enclosing computation.
isolateBinder
  :: (ABT Term abt)
  => Variable (a :: Hakaru)
  -> HoistM abt b
  -> HoistM abt (b, ExpressionSet abt)
isolateBinder v act = captureEntries (bindVar v act)
-- | Transform a term that may be under binders. The chain of @Bind@s at the
-- head of the term is walked first. Once the underlying @Syn@ node is reached,
-- it is transformed with 'hoistTerm', and any pending entries that depend on
-- the variables just bound are introduced immediately inside those binders.
hoist'
  :: forall abt xs a . (ABT Term abt)
  => abt xs a
  -> HoistM abt (abt xs a)
hoist' = descend
  where
    descend :: forall ys b . abt ys b -> HoistM abt (abt ys b)
    descend term = go [] (viewABT term)

    -- Add a list of variables to a live set.
    extendLive :: [HakaruVar] -> LiveSet -> LiveSet
    extendLive vs live =
      L.foldl' (\ acc (SomeVariable v) -> insertVarSet v acc) live vs

    -- Run an action with the given variables bound, capturing the entries
    -- it writes.
    underBinders :: [HakaruVar] -> HoistM abt c -> HoistM abt (c, ExpressionSet abt)
    underBinders vs act = captureEntries (local (extendLive vs) act)

    -- @go@ takes 2 parameters.
    --
    -- 1. The variables bound so far, most recent first
    -- 2. The view of the term currently being traversed
    --
    -- Each @Bind@ adds its variable to the first parameter. When a @Syn@ is
    -- finally reached, the hoisted values whose data dependencies are
    -- satisfied by these new variables are introduced.
    go :: forall ys b
       .  [HakaruVar]
       -> View (Term abt) ys b
       -> HoistM abt (abt ys b)
    go _     (Var v)    = return (var v)
    -- Not strictly needed: with no new binders we skip the comparatively
    -- expensive call to introduceBindings.
    go []    (Syn t)    = hoistTerm t
    go bound (Syn t)    = do
      (t', pending) <- underBinders bound (hoistTerm t)
      introduceBindings bound t' pending
    go bound (Bind v t) = fmap (bind v) (go (SomeVariable v : bound) t)
-- | The variables bound by an entry, each wrapped in 'SomeVariable' to hide
-- its type.
getBoundVars :: Entry x -> [HakaruVar]
getBoundVars (Entry _ _ _ vs) = [ SomeVariable v | v <- vs ]
-- | Wrap a term in one let-binding per entry. The entries are folded from the
-- right, so the first entry of the list becomes the outermost binding.
wrapExpr
  :: forall abt b . (ABT Term abt)
  => abt '[] b
  -> [Entry (abt '[])]
  -> HoistM abt (abt '[] b)
wrapExpr body entries = F.foldrM wrapOne body entries
  where
    -- Build @let v = rhs in inner@, except when @inner@ is exactly the
    -- variable @v@: in that case the let is redundant and @rhs@ itself is
    -- returned.
    letIn :: abt '[] a -> Variable a -> abt '[] b -> abt '[] b
    letIn rhs v inner
      | Var v' <- viewABT inner
      , Just Refl <- varEq v v' = rhs
      | otherwise               = syn (Let_ :$ rhs :* bind v inner :* End)

    -- Bind the entry's expression to the first of its variables and rebind
    -- each remaining variable to that first one. An entry with no variables
    -- is bound to a fresh variable instead.
    wrapOne :: Entry (abt '[]) -> abt '[] b -> HoistM abt (abt '[] b)
    wrapOne (Entry _ e _ vars) inner =
      case vars of
        [] -> do
          tmp <- varForExpr e
          return (letIn e tmp inner)
        v : aliases ->
          return (letIn e v (foldr (letIn (var v)) inner aliases))
-- Introduce, around @body@, every entry that has to be introduced once the
-- variables in @newVars@ are bound: entries that depend on one of those
-- variables and, transitively, entries that depend on a variable bound by an
-- entry introduced here. The introduced entries are topologically sorted
-- before wrapping. As a side effect, the remaining entries are written back
-- into the Writer layer of the stack.
introduceBindings
  :: forall (a :: Hakaru) abt
  .  (ABT Term abt)
  => [HakaruVar]
  -> abt '[] a
  -> ExpressionSet abt
  -> HoistM abt (abt '[] a)
introduceBindings newVars body (ExpressionSet entries) = do
  tell (ExpressionSet deferred)
  wrapExpr body (topSortEntries ready)
  where
    ready, deferred :: [Entry (abt '[])]
    (ready, deferred) = drain entries newVars

    -- Does the entry depend on the given variable?
    usesVar
      :: forall (b :: Hakaru)
      .  Variable b
      -> Entry (abt '[])
      -> Bool
    usesVar v (Entry deps _ _ _) = memberVarSet v deps

    -- Worklist loop over the newly bound variables. Each variable claims the
    -- pending entries that depend on it, and the variables bound by the
    -- claimed entries are appended to the worklist. Returns the claimed
    -- entries and the entries left unclaimed.
    drain
      :: [Entry (abt '[])]
      -> [HakaruVar]
      -> ([Entry (abt '[])], [Entry (abt '[])])
    drain pending []                      = ([], pending)
    drain pending (SomeVariable v : work) =
      let (hit, miss)    = L.partition (usesVar v) pending
          unlocked       = concatMap getBoundVars hit
          (later, stuck) = drain miss (work ++ unlocked)
      in  (hit ++ later, stuck)
-- Transform a single term node.
--
-- Unlike the other binding forms, let expressions are killed by the hoisting
-- pass. Their right-hand sides are recorded as entries, floated upward in the
-- AST and re-introduced where their data dependencies are fulfilled, so the
-- result of hoisting a let expression is just the hoisted body.
--
-- For a lambda, the entries produced by the body are captured and handed to
-- 'introduceToplevel' with the lambda's variable added to the live set, before
-- the lambda is rebuilt around the result.
--
-- Any other term has its subterms hoisted and is rebuilt. If the result is not
-- a value according to 'isValue', it is recorded as an entry under a fresh
-- variable and replaced by that variable.
hoistTerm
  :: forall (a :: Hakaru) (abt :: [Hakaru] -> Hakaru -> *)
  .  (ABT Term abt)
  => Term abt a
  -> HoistM abt (abt '[] a)
hoistTerm (Let_ :$ rhs :* body :* End) =
  caseBind body $ \ v inner -> do
    rhs' <- hoist' rhs
    recordEntry v rhs'
    bindVar v (hoist' inner)
hoistTerm (Lam_ :$ body :* End) =
  caseBind body $ \ v inner -> do
    live              <- insertVarSet v <$> ask
    (inner', pending) <- isolateBinder v (hoist' inner)
    lamBody           <- introduceToplevel live inner' pending
    return (syn (Lam_ :$ bind v lamBody :* End))
hoistTerm term = do
  rebuilt <- fmap syn (traverse21 hoist' term)
  if not (isValue rebuilt)
    then do
      name <- varForExpr rebuilt
      recordEntry name rebuilt
      return (var name)
    else return rebuilt