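-- | A simple inliner. Inlines all non-recursive functions.
--
-- Functions that are recursive (directly or through mutual recursion) and
-- external (foreign) declarations are never inlined; calls to them are kept.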
module Kempe.Inline ( inline
                    ) where

import           Data.Graph         (Graph, Vertex, graphFromEdges, path)
import qualified Data.IntMap        as IM
import qualified Data.List.NonEmpty as NE
import           Data.Maybe         (fromMaybe, mapMaybe)
import           Data.Tuple.Extra   (third3)
import           Kempe.AST
import           Kempe.Name
import           Kempe.Unique

-- | A 'FnModuleMap' maps the unique of a top-level function's 'Name' to the
-- 'Atom's defining that function. External functions map to 'Nothing', since
-- they have no Kempe body that could be inlined.
type FnModuleMap c b = IM.IntMap (Maybe [Atom c b])

-- | Inline calls to every top-level function that is neither recursive
-- (directly or mutually) nor an external declaration.
--
-- Inlining is transitive: the body substituted for a call is itself
-- processed, so non-recursive callees of the callee are inlined too. Calls
-- nested inside @if@, @case@ and @dip@ bodies are all visited.
inline :: Module a c b -> Module a c b
inline m = fmap inlineDecl m
    where inlineDecl (FunDecl l n ty ty' as) = FunDecl l n ty ty' (inlineAtoms n as)
          inlineDecl d                       = d
          -- @declName@ is the function whose body is being rewritten; it is
          -- threaded through so nothing that can call back into it is inlined.
          inlineAtoms declName = concatMap (inlineAtom declName)
          inlineAtom declName a@(AtName _ n) =
            -- Keep the call when the callee can reach the function being
            -- rewritten (this includes @n == declName@, as 'path' counts a
            -- vertex as reachable from itself) or when it is flagged as
            -- recursive/external by 'graphRecursiveMap'.
            if path graph (nLookup n) (nLookup declName) || don'tInline n
                then [a] -- no inline
                else foldMap (inlineAtom declName) $ findDecl a n
          inlineAtom declName (If l as as') =
            [If l (inlineAtoms declName as) (inlineAtoms declName as')]
          -- Recurse into dip bodies as well; 'namesInAtom' walks into them
          -- too, so calls here must be treated like any other call.
          inlineAtom declName (Dip l as) =
            [Dip l (inlineAtoms declName as)]
          inlineAtom declName (Case l ls) =
            let (ps, ass) = NE.unzip ls
                in [Case l (NE.zip ps $ fmap (inlineAtoms declName) ass)]
          inlineAtom _ a = [a]
          fnMap = mkFnModuleMap m
          (graph, _, nLookup) = kempeGraph m
          -- Body of the function named by the second argument; an external
          -- function has no body, so the original atom is returned unchanged.
          findDecl at (Name _ (Unique k) _) =
            case findPreDecl k fnMap of
                Just as -> as
                Nothing -> pure at -- tried to inline an extern function
          findPreDecl = IM.findWithDefault (error "Internal error: FnModuleMap does not contain name/declaration!")
          recMap = graphRecursiveMap m (graph, nLookup)
          don'tInline (Name _ (Unique i) _) = IM.findWithDefault (error "Internal error! recursive map missing key!") i recMap

-- | Given a module, build a map, keyed by the unique of each top-level
-- function name, telling whether that function must be kept out of line.
--
-- A 'FunDecl' is flagged 'True' when it mentions itself directly, or when any
-- name it mentions can reach it again in the call graph (mutual recursion).
-- Every 'ExtFnDecl' is flagged 'True': it is not recursive, but it has no body
-- to inline. Other declarations do not appear in the map.
graphRecursiveMap :: Module a c b -> (Graph, Name b -> Vertex) -> IM.IntMap Bool
graphRecursiveMap m (graph, nLookup) = IM.fromList (mapMaybe classify m)
    where classify (FunDecl _ n@(Name _ (Unique i) _) _ _ as) =
              let callees     = namesInAtoms as
                  callsSelf   = n `elem` callees -- if it calls itself
                  -- TODO: lift let-binding (nLookup?)
                  reachesSelf = any (\callee -> path graph (nLookup callee) (nLookup n)) callees
              in Just (i, callsSelf || reachesSelf)
          classify (ExtFnDecl _ (Name _ (Unique i) _) _ _ _) =
              Just (i, True) -- not recursive but don't try to inline this
          classify _ = Nothing


-- | Build the call graph of a module with 'graphFromEdges'.
--
-- Returns the graph, the vertex-to-node lookup, and a name-to-vertex lookup.
-- The last one fails with an internal error for names that are not nodes
-- (i.e. not a function or external function declared in the module).
kempeGraph :: Module a c b -> (Graph, Vertex -> (KempeDecl a c b, Name b, [Name b]), Name b -> Vertex)
kempeGraph m = third3 mustFind (graphFromEdges (kempePreGraph m))
    where -- 'graphFromEdges' only offers a partial lookup; every name we ask
          -- about should be a node, so a miss indicates a compiler bug.
          mustFind lookupVertex name =
              fromMaybe (error "Internal error: bad name lookup!") (lookupVertex name)

-- | Turn a module into the @(node, key, [key])@ triples consumed by
-- 'graphFromEdges'. Each function declaration becomes a node keyed by its
-- name, with an edge to every name its body mentions. External functions
-- become nodes without outgoing edges. All other declarations are dropped.
kempePreGraph :: Module a c b -> [(KempeDecl a c b, Name b, [Name b])]
kempePreGraph = foldr addNode []
    where addNode d@(FunDecl _ n _ _ as) acc  = (d, n, namesInAtoms as) : acc
          addNode d@(ExtFnDecl _ n _ _ _) acc = (d, n, []) : acc
          addNode _ acc                       = acc

-- | Index the body of every top-level function by the unique of its name.
-- Ordinary functions map to @Just@ their atoms; external functions map to
-- 'Nothing' because they have no Kempe body. Other declarations are skipped.
mkFnModuleMap :: Module a c b -> FnModuleMap c b
mkFnModuleMap decls = IM.fromList (concatMap entry decls)
    where entry (FunDecl _ (Name _ (Unique i) _) _ _ as)  = [(i, Just as)]
          entry (ExtFnDecl _ (Name _ (Unique i) _) _ _ _) = [(i, Nothing)]
          entry _                                         = []

-- | Every name referenced by a sequence of atoms, in order of appearance
-- (duplicates included).
namesInAtoms :: [Atom c a] -> [Name a]
namesInAtoms = concatMap namesInAtom

-- | Every name referenced by a single atom, descending into the bodies of
-- @if@, @dip@ and every @case@ branch. Builtins, constructors and literals
-- reference no names. All constructors are listed explicitly (no wildcard) so
-- that adding a new 'Atom' form triggers an incomplete-pattern warning here.
namesInAtom :: Atom c a -> [Name a]
namesInAtom atom = case atom of
    AtName _ n  -> [n]
    Dip _ as    -> namesInAtoms as
    If _ as as' -> namesInAtoms as ++ namesInAtoms as'
    Case _ ls   -> foldMap (namesInAtoms . snd) ls
    AtBuiltin{} -> []
    AtCons{}    -> []
    IntLit{}    -> []
    BoolLit{}   -> []
    Int8Lit{}   -> []
    WordLit{}   -> []