{- Language/Haskell/TH/Desugar/FV.hs

(c) Ryan Scott 2018

Compute free variables of programs.
-}

{-# LANGUAGE CPP #-}
module Language.Haskell.TH.Desugar.FV
  ( fvDType
  , extractBoundNamesDPat
  ) where

#if __GLASGOW_HASKELL__ < 804
import Data.Monoid ((<>))
#endif
import Language.Haskell.TH.Syntax
import Language.Haskell.TH.Desugar.AST
import qualified Language.Haskell.TH.Desugar.OSet as OS
import Language.Haskell.TH.Desugar.OSet (OSet)

-- | Compute the free variables of a 'DType'.
--
-- Type variables bound by a @forall@ telescope ('DForallT') are removed from
-- the free variables of the body they scope over, while the kinds annotating
-- those binders are still traversed. Type constructors ('DConT'), the arrow
-- type constructor ('DArrowT'), type-level literals ('DLitT'), and wildcards
-- ('DWildCardT') contribute no free variables.
fvDType :: DType -> OSet Name
fvDType = go
  where
    go :: DType -> OSet Name
    go (DForallT tele ty)      = fv_dtele tele (go ty)
    -- The constraints of a 'DConstrainedT' are themselves types, so their
    -- free variables are collected exactly like those of the body.
    go (DConstrainedT ctxt ty) = foldMap go ctxt <> go ty
    go (DAppT t1 t2)           = go t1 <> go t2
    go (DAppKindT t k)         = go t <> go k
    go (DSigT ty ki)           = go ty <> go ki
    go (DVarT n)               = OS.singleton n
    go (DConT {})              = OS.empty
    go DArrowT                 = OS.empty
    go (DLitT {})              = OS.empty
    go DWildCardT              = OS.empty

-----
-- Extracting bound term names
-----

-- | Extract the term variables bound by a 'DPat'.
--
-- This does /not/ extract any type variables bound by pattern signatures:
-- for a 'DSigP', only the pattern underneath the signature is inspected.
-- Note, however, that the free type variables of any type arguments supplied
-- to a constructor pattern ('DConP') /are/ included in the result.
extractBoundNamesDPat :: DPat -> OSet Name
extractBoundNamesDPat = go
  where
    go :: DPat -> OSet Name
    go (DLitP _)          = OS.empty
    go (DVarP n)          = OS.singleton n
    go (DConP _ tys pats) = foldMap fvDType tys <> foldMap go pats
    go (DTildeP p)        = go p
    go (DBangP p)         = go p
    go (DSigP p _)        = go p
    go DWildP             = OS.empty

-----
-- Binding forms
-----

-- | Adjust the free variables of something following a 'DForallTelescope'.
--
-- Visible ('DForallVis') and invisible ('DForallInvis') telescopes are
-- handled identically here: each case defers to 'fv_dtvbs' with its binders.
fv_dtele :: DForallTelescope -> OSet Name -> OSet Name
fv_dtele (DForallVis   tvbs) = fv_dtvbs tvbs
fv_dtele (DForallInvis tvbs) = fv_dtvbs tvbs

-- | Adjust the free variables of something following 'DTyVarBndr's.
--
-- The binders are folded right-to-left with 'fv_dtvb', starting from the
-- free variables of the scoped-over thing. As a result, a variable bound
-- earlier in the list is also removed from the free variables contributed by
-- the kinds of later binders.
fv_dtvbs :: [DTyVarBndr flag] -> OSet Name -> OSet Name
fv_dtvbs tvbs fvs = foldr fv_dtvb fvs tvbs

-- | Adjust the free variables of something following a 'DTyVarBndr'.
--
-- The bound variable is deleted from the given free variables. For a kinded
-- binder, the free variables of its kind are added /after/ that deletion,
-- because the kind is not in the scope of the variable it annotates.
fv_dtvb :: DTyVarBndr flag -> OSet Name -> OSet Name
fv_dtvb (DPlainTV n _)    fvs = OS.delete n fvs
fv_dtvb (DKindedTV n _ k) fvs = OS.delete n fvs <> fvDType k