{-# LANGUAGE CPP #-}
{-# LANGUAGE Unsafe #-}

-- | A plugin that identifies and reports on uses of recursion. The name evokes
--   a language pragma – implying a @Recursion@ pragma that is enabled by
--   default.
module NoRecursion (plugin) where

-- NB: These unqualified modules come from semigroups in GHC <8, and base
--     otherwise.
import safe Data.List.NonEmpty (NonEmpty, nonEmpty)
import safe Data.Semigroup (Semigroup ((<>)))
import safe "base" Control.Applicative (Applicative (pure))
import safe "base" Control.Category (Category ((.)))
import safe "base" Control.Exception (ErrorCall (ErrorCall), throwIO)
import safe "base" Control.Monad ((=<<))
import safe "base" Data.Bool (Bool (True), not, (&&), (||))
import safe "base" Data.Data (Data)
import safe "base" Data.Either (Either (Left), either)
import safe "base" Data.Foldable
  ( Foldable (foldMap, toList),
    all,
    elem,
    notElem,
    traverse_,
  )
import safe "base" Data.Function (($))
import safe "base" Data.Functor (Functor (fmap), (<$>))
import safe "base" Data.List (filter, intercalate, isPrefixOf, null)
import safe "base" Data.Maybe (maybe)
import safe "base" Data.String (String)
import safe "base" Data.Tuple (fst, uncurry)
#if MIN_VERSION_ghc(9, 0, 0)
import safe "base" Data.Bifunctor (Bifunctor (first))
import qualified "ghc" GHC.Plugins as Plugins
#else
import qualified "ghc" GhcPlugins as Plugins
#endif

-- | The base plugin record that `plugin` extends.
--
--   On GHC 8.6.1 and later this is `Plugins.defaultPlugin` with
--   `Plugins.pluginRecompile` set to `Plugins.purePlugin`; on older compilers
--   (which have no such field) it is `Plugins.defaultPlugin` unchanged.
defaultPurePlugin :: Plugins.Plugin
#if MIN_VERSION_ghc(8, 6, 1)
defaultPurePlugin =
  Plugins.defaultPlugin {Plugins.pluginRecompile = Plugins.purePlugin}
#else
defaultPurePlugin = Plugins.defaultPlugin
#endif

-- | The entrypoint for the `NoRecursion` plugin.
--
--   Any command-line options passed to the plugin are ignored; the only
--   effect is to extend the Core pipeline via `install`.
plugin :: Plugins.Plugin
plugin = defaultPurePlugin {Plugins.installCoreToDos = \_opts -> pure . install}

-- | Prepends the recursion check (`noRecursionPass`) to the given list of
--   Core passes, leaving the existing passes in their original order.
install :: [Plugins.CoreToDo] -> [Plugins.CoreToDo]
install = (Plugins.CoreDoPluginPass "add NoRecursion rule" noRecursionPass :)

-- | Annotations of type @a@ for a module. The first component (`fst`) holds
--   the annotations attached to the module itself and the second component
--   (`Data.Tuple.snd`) is an environment of the annotations attached to each
--   name in the module.
type Annotations a = (a, Plugins.NameEnv a)

-- | Reads every annotation of type @a@ from the module being compiled and
--   splits the result into the annotations on the module itself (an empty
--   list when there are none) and the per-name annotation environment.
--
--   The two implementations differ only in how the module-level annotations
--   are looked up, which depends on the shape of the value returned by
--   `Plugins.getAnnotations` in the GHC version being built against.
getAnnotations :: (Data a) => Plugins.ModGuts -> Plugins.CoreM (Annotations [a])
#if MIN_VERSION_ghc(9, 0, 1)
getAnnotations guts =
  first
    ( \modAnns ->
        Plugins.lookupWithDefaultModuleEnv modAnns [] $
          Plugins.mg_module guts
    )
    <$> Plugins.getAnnotations Plugins.deserializeWithData guts
#else
getAnnotations guts =
  ( \anns ->
      ( Plugins.lookupWithDefaultUFM
          anns
          []
          ( Plugins.ModuleTarget $ Plugins.mg_module guts ::
              Plugins.CoreAnnTarget
          ),
        anns
      )
  )
    <$> Plugins.getAnnotations Plugins.deserializeWithData guts
#endif

-- | The Core pass installed by the plugin.
--
--   It checks the bindings of the module with `failOnRecursion`. When nothing
--   is reported the module is returned unchanged; otherwise compilation is
--   aborted by throwing an `ErrorCall` (in `IO`) whose message lists one
--   formatted `RecursionRecord` per line.
noRecursionPass :: Plugins.ModGuts -> Plugins.CoreM Plugins.ModGuts
noRecursionPass guts = do
  dflags <- Plugins.getDynFlags
  anns <- getAnnotations guts
  either
    ( \recs ->
        Plugins.liftIO . throwIO . ErrorCall $
          "something recursive:\n"
            <> intercalate "\n" (toList $ formatRecursionRecord dflags <$> recs)
    )
    (\() -> pure guts)
    . failOnRecursion dflags anns
    $ Plugins.mg_binds guts

-- | A report of one recursive binding group.
--
--   The first field is the chain of binders enclosing the group, outermost
--   first (`addBindingReference` prepends each enclosing binder as the
--   traversal unwinds); it is empty for a group found directly in the list of
--   bindings being checked. The second field holds the binders that make up
--   the recursive group itself.
data RecursionRecord b = RecursionRecord [b] (NonEmpty b)

-- | Renders a `RecursionRecord` as a single line for the error message.
--
--   The line starts with the location – either @at the top level@ when the
--   context is empty, or @in a >> b >> …@ listing the enclosing binders – and
--   is followed by the comma-separated binders of the recursive group.
formatRecursionRecord ::
  (Plugins.Outputable b) => Plugins.DynFlags -> RecursionRecord b -> String
formatRecursionRecord dflags (RecursionRecord context recs) =
  maybe
    "at the top level"
    (\v -> "in " <> intercalate " >> " (showBndr <$> toList v))
    (nonEmpty context)
    <> ", the following bindings were recursive: "
    <> intercalate ", " (showBndr <$> toList recs)
  where
    -- Pretty-prints a single binder using the session flags.
    showBndr = Plugins.showSDoc dflags . Plugins.ppr

-- | The annotation payload that permits recursion, either for a whole module
--   or for an individual name (see `recursionAllowed`).
recursionAnnotation :: String
recursionAnnotation = "Recursion"

-- | The annotation payload that forbids recursion. Wherever both this and
--   `recursionAnnotation` are present, this one takes precedence.
noRecursionAnnotation :: String
noRecursionAnnotation = "NoRecursion"

-- | Checks a list of bindings for recursion that has not been permitted.
--
--   Returns `Left` with every offending `RecursionRecord` when at least one
--   is found, and @Right ()@ otherwise.
--
--   The module as a whole permits recursion when it carries
--   `recursionAnnotation` and does not carry `noRecursionAnnotation`. Binding
--   groups accepted by `allowBind` are dropped before any traversal, so
--   nothing nested inside them is inspected either.
failOnRecursion ::
  Plugins.DynFlags ->
  Annotations [String] ->
  [Plugins.CoreBind] ->
  Either (NonEmpty (RecursionRecord Plugins.CoreBndr)) ()
failOnRecursion dflags (modAnns, nameAnns) original =
  let moduleAllowsRecursion =
        elem recursionAnnotation modAnns
          && notElem noRecursionAnnotation modAnns
   in traverse_ Left
        . nonEmpty
        -- __TODO__: Default method implementations seem to cause mutual
        --           recursion with the instance, so here we filter them out,
        --           but this probably lets some real mutual recursion slip
        --           through.
        . filter
          ( \(RecursionRecord context recs) ->
              -- Drop a record only when it has no enclosing context and
              -- every binder in it prints with a @$c@ or @$f@ prefix.
              not $
                null context
                  && all
                    ( \var ->
                        let v = Plugins.showSDoc dflags $ Plugins.ppr var
                         in "$c" `isPrefixOf` v || "$f" `isPrefixOf` v
                    )
                    recs
          )
        $ recursiveCallsForBind
          =<< filter (not . allowBind moduleAllowsRecursion nameAnns) original

-- | Records that every given `RecursionRecord` was found inside the binding
--   of @var@, by prepending @var@ to the context of each record. The
--   recursive binders themselves are left untouched.
addBindingReference :: b -> [RecursionRecord b] -> [RecursionRecord b]
addBindingReference var =
  fmap (\(RecursionRecord context recs) -> RecursionRecord (var : context) recs)

-- | Decides whether a binding group may be skipped by the check.
--
--   A non-recursive binding is always allowed. A recursive group is allowed
--   only when `recursionAllowed` holds for every binder in the group.
--
--   NOTE(review): because `Plugins.NonRec` is always allowed, a caller that
--   filters on this predicate before traversing (as `failOnRecursion` does)
--   never looks inside non-recursive bindings – confirm that is intended.
allowBind :: Bool -> Plugins.NameEnv [String] -> Plugins.CoreBind -> Bool
allowBind moduleAllowsRecursion anns = \case
  Plugins.NonRec {} -> True
  Plugins.Rec bs -> all (recursionAllowed moduleAllowsRecursion anns . fst) bs

-- | Decides whether a single binder is permitted to be recursive.
--
--   Recursion is permitted when either the module permits it or the binder
--   is annotated with `recursionAnnotation`, and in both cases only if the
--   binder is not also annotated with `noRecursionAnnotation`. A binder with
--   no entry in the environment is treated as having no annotations.
recursionAllowed :: Bool -> Plugins.NameEnv [String] -> Plugins.Var -> Bool
recursionAllowed moduleAllowsRecursion anns var =
  let strAnns =
        Plugins.lookupWithDefaultUFM_Directly anns [] $ Plugins.getUnique var
   in (moduleAllowsRecursion || elem recursionAnnotation strAnns)
        && notElem noRecursionAnnotation strAnns

-- | Collects the recursion found in, and nested within, a binding group.
--
--   For a non-recursive binding this is whatever is found in its right-hand
--   side, with the binder added to the context. For a recursive group a
--   record (with empty context) naming every binder of the group comes first,
--   followed by whatever is found nested in the right-hand sides.
recursiveCallsForBind :: Plugins.Bind b -> [RecursionRecord b]
recursiveCallsForBind =
  -- Everything found in the right-hand side of @v@, tagged as being inside @v@.
  let collectCalls v = addBindingReference v . collectRecursiveCalls
   in \case
        Plugins.NonRec v rhs -> collectCalls v rhs
        Plugins.Rec binds ->
          let nestedRecursion = foldMap (uncurry collectCalls) binds
           in -- An empty `Plugins.Rec` has no binders to report, so only a
              -- non-empty group contributes a record of its own.
              maybe
                nestedRecursion
                (\bnds -> RecursionRecord [] (fst <$> bnds) : nestedRecursion)
                $ nonEmpty binds

-- | This collects all identifiable recursion points in an expression.
--
--   The expression is walked structurally. Records originate only from `let`
--   bindings (via `recursiveCallsForBind`); variables, literals, types, and
--   coercions contribute nothing, and every other form simply combines what
--   is found in its sub-expressions.
collectRecursiveCalls :: Plugins.Expr b -> [RecursionRecord b]
collectRecursiveCalls = \case
  Plugins.App f a -> collectRecursiveCalls f <> collectRecursiveCalls a
  Plugins.Case scrut _ _ alts ->
    collectRecursiveCalls scrut <> foldMap recursiveCallsForAlt alts
  Plugins.Cast e _ -> collectRecursiveCalls e
  Plugins.Coercion _ -> []
  Plugins.Lam _ body -> collectRecursiveCalls body
  Plugins.Let bind e -> recursiveCallsForBind bind <> collectRecursiveCalls e
  Plugins.Lit _ -> []
  Plugins.Tick _ body -> collectRecursiveCalls body
  Plugins.Type _ -> []
  Plugins.Var _ -> []

-- | Collects the recursion found in the right-hand side of one @case@
--   alternative; the constructor and the bound variables are ignored.
--
--   The two equations differ only in how an alternative is represented in
--   the GHC version being built against (a dedicated constructor from 9.2.0
--   on, a triple before that).
recursiveCallsForAlt :: Plugins.Alt b -> [RecursionRecord b]
#if MIN_VERSION_ghc(9, 2, 0)
recursiveCallsForAlt (Plugins.Alt _ _ rhs) = collectRecursiveCalls rhs
#else
recursiveCallsForAlt (_, _, rhs) = collectRecursiveCalls rhs
#endif