{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}

module Test.Tasty.AutoCollect.GHC (
  module Test.Tasty.AutoCollect.GHC.Shim,

  -- * Output helpers
  showPpr,

  -- * Builders
  genFuncSig,
  genFuncDecl,
  lhsvar,
  mkHsAppTypes,
  mkHsTyVar,
  mkExprTypeSig,

  -- * Located utilities
  genLoc,
  firstLocatedWhere,
  getSpanLine,

  -- * Name utilities
  mkRdrName,
  mkLRdrName,
  mkRdrNameType,
  mkLRdrNameType,
  fromRdrName,
  thNameToGhcNameIO,
) where

import Data.Foldable (foldl')
import Data.IORef (IORef)
import Data.List (sortOn)
import Data.Maybe (fromMaybe, listToMaybe, mapMaybe)
import qualified Language.Haskell.TH as TH

import Test.Tasty.AutoCollect.GHC.Shim

{----- Output helpers -----}

-- | Pretty-print any 'Outputable' value to a 'String', by rendering its
-- 'ppr' document with 'showSDocUnsafe'.
showPpr :: Outputable a => a -> String
showPpr x = showSDocUnsafe (ppr x)

{----- Builders -----}

-- | Build the type-signature declaration @funcName :: funcType@.
--
-- The type is converted with 'hsTypeToHsSigWcType' before being placed in
-- a single-name 'TypeSig'.
genFuncSig :: LocatedN RdrName -> LHsType GhcPs -> HsDecl GhcPs
genFuncSig funcName funcType =
  SigD noExtField (TypeSig noAnn [funcName] (hsTypeToHsSigWcType funcType))

-- | Build a single-equation function declaration of the form
-- @funcName funcArgs = funcBody where funcWhere@.
--
-- The binding is created with origin 'Generated'. When no @where@
-- bindings are supplied ('Nothing'), the equation uses 'emptyLocalBinds'.
genFuncDecl ::
  -- | Name of the function being defined
  LocatedN RdrName ->
  -- | Argument patterns of the single equation
  [LPat GhcPs] ->
  -- | Right-hand side of the equation
  LHsExpr GhcPs ->
  -- | Optional @where@ bindings
  Maybe (HsLocalBinds GhcPs) ->
  HsDecl GhcPs
genFuncDecl funcName funcArgs funcBody mFuncWhere =
  ValD NoExtField (mkFunBind Generated funcName [equation])
  where
    -- The one and only match: prefix-style definition of funcName.
    equation = mkMatch (mkPrefixFunRhs funcName) funcArgs funcBody whereBinds

    whereBinds :: HsLocalBinds GhcPs
    whereBinds = fromMaybe emptyLocalBinds mFuncWhere

-- | Turn a located name into a located variable-reference expression
-- ('HsVar'), wrapped with 'genLoc'.
lhsvar :: LocatedN RdrName -> LHsExpr GhcPs
lhsvar varName = genLoc (HsVar NoExtField varName)

-- | Apply an expression to a list of visible type arguments, left to right:
-- @mkHsAppTypes e [t1, t2]@ builds @(e \@t1) \@t2@. An empty list returns
-- the expression unchanged.
mkHsAppTypes :: LHsExpr GhcPs -> [LHsType GhcPs] -> LHsExpr GhcPs
mkHsAppTypes fn tyArgs = foldl' mkHsAppType fn tyArgs

-- | Build the single visible type application @e \@t@ ('HsAppType').
-- The type argument is wrapped in 'HsWC' with 'noExtField'.
mkHsAppType :: LHsExpr GhcPs -> LHsType GhcPs -> LHsExpr GhcPs
mkHsAppType fn tyArg =
  genLoc (HsAppType xAppTypeE fn (HsWC noExtField tyArg))

-- | Build a located 'HsTyVar' type ('NotPromoted') referring to the given
-- 'Name', converted with 'getRdrName'.
mkHsTyVar :: Name -> LHsType GhcPs
mkHsTyVar tyName =
  genLoc (HsTyVar noAnn NotPromoted (genLoc (getRdrName tyName)))

-- | @mkExprTypeSig e t@ builds the annotated expression @(e :: t)@.
--
-- The type is converted with 'hsTypeToHsSigType' and wrapped in 'HsWC'.
mkExprTypeSig :: LHsExpr GhcPs -> LHsType GhcPs -> LHsExpr GhcPs
mkExprTypeSig expr ty =
  genLoc (ExprWithTySig noAnn expr (HsWC NoExtField (hsTypeToHsSigType ty)))

{----- Located utilities -----}

-- | Wrap a value in 'L' using the 'generatedSrcAnn' annotation. The
-- builders in this module use it to give synthesized nodes a location.
genLoc :: e -> GenLocated (SrcAnn ann) e
genLoc x = L generatedSrcAnn x

-- | Visit the located values in ascending order of their location and
-- return the first 'Just' result produced by @f@.
--
-- Elements with equal locations keep their input order ('sortOn' is
-- stable). @f@ is only evaluated up to the first match. Returns 'Nothing'
-- for an empty list or when @f@ rejects every element.
firstLocatedWhere :: Ord l => (GenLocated l e -> Maybe a) -> [GenLocated l e] -> Maybe a
firstLocatedWhere f located =
  listToMaybe [result | x <- sortOn getLoc located, Just result <- [f x]]

-- | Describe where a located value starts.
--
-- Yields @"line N"@ when 'srcSpanStart' gives a real source location;
-- otherwise yields the string carried in its 'Left' result.
getSpanLine :: GenLocated (SrcSpanAnn' a) e -> String
getSpanLine located = either id describeLine (srcSpanStart (getLocA located))
  where
    describeLine realLoc = "line " ++ show (srcLocLine realLoc)

{----- Name utilities -----}

-- | Build an unqualified 'RdrName' in the variable namespace ('mkOccNameVar').
mkRdrName :: String -> RdrName
mkRdrName str = mkRdrUnqual (mkOccNameVar str)

-- | Like 'mkRdrName', but wrapped with 'genLoc'.
mkLRdrName :: String -> LocatedN RdrName
mkLRdrName str = genLoc (mkRdrName str)

-- | Build an unqualified 'RdrName' in the type-constructor namespace
-- ('mkOccNameTC').
mkRdrNameType :: String -> RdrName
mkRdrNameType str = mkRdrUnqual (mkOccNameTC str)

-- | Like 'mkRdrNameType', but wrapped with 'genLoc'.
mkLRdrNameType :: String -> LocatedN RdrName
mkLRdrNameType str = genLoc (mkRdrNameType str)

-- | Extract the bare occurrence-name string from a located 'RdrName',
-- discarding its location and any qualifier.
fromRdrName :: LocatedN RdrName -> String
fromRdrName lname = occNameString (rdrNameOcc (unLoc lname))

-- | Resolve a Template Haskell 'TH.Name' to a GHC 'Name' in 'IO'.
--
-- Runs 'thNameToGhcName' via 'runCoreM', using @hscEnv@ with its name
-- cache ('hsc_NC') replaced by @cache@. Every other 'runCoreM' argument is
-- a placeholder:
--
-- * lazy fields get 'unused', which calls 'error' if the field is forced;
-- * fields that must not be bottom get an inert value via 'strict'.
--
-- The second component of the 'runCoreM' result is discarded.
--
-- Related upstream change:
-- https://gitlab.haskell.org/ghc/ghc/-/merge_requests/8492
thNameToGhcNameIO ::
  -- | Session whose name cache is overridden
  HscEnv ->
  -- | Name cache to resolve names through
  IORef NameCache ->
  -- | Template Haskell name to resolve
  TH.Name ->
  IO (Maybe Name)
thNameToGhcNameIO hscEnv cache name =
  fst
    <$> runCoreM
      sessionEnv
      crRuleBase
      crUniqMask
      crModule
      crOrphanModules
      crPrintUnqual
      crLoc
      lookupName
  where
    lookupName = thNameToGhcName name

    -- The caller's session, but with the supplied name cache.
    sessionEnv = hscEnv{hsc_NC = cache}

    -- Placeholder CoreM environment, passed in runCoreM's argument order.
    crRuleBase = unused "cr_rule_base"
    crUniqMask = strict '.'
    crModule = unused "cr_module"
    crOrphanModules = strict mempty
    crPrintUnqual = unused "cr_print_unqual"
    crLoc = unused "cr_loc"

    unused :: String -> a
    unused field = error ("unexpectedly used: " ++ field)

    -- marks fields that are strict, so we can't use `unused`
    strict :: a -> a
    strict = id