{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-redundant-constraints -Wno-unused-matches #-}

module Calligraphy.Compat.Lib
  ( sourceInfo,
    showContextInfo,
    readHieFileCompat,
    isInstanceNode,
    isTypeSignatureNode,
    isInlineNode,
    isMinimalNode,
    isDerivingNode,
    showAnns,
    spanSpans,
  )
where

import Calligraphy.Util.Lens
import Data.IORef
import qualified Data.Set as Set

#if MIN_VERSION_ghc(9,0,0)
import GHC.Iface.Ext.Binary
import GHC.Iface.Ext.Types
import GHC.Types.Name.Cache
import GHC.Types.SrcLoc
import GHC.Utils.Outputable (ppr, showSDocUnsafe)
import qualified Data.Map as Map
#else
import HieBin
import HieTypes
import NameCache
import SrcLoc
#endif

-- | Focus on the 'NodeInfo' that a 'HieAST' node records for the source
-- program itself.
--
-- On GHC >= 9.0 a node holds a map of 'NodeInfo's keyed by origin; only the
-- 'SourceInfo' entry is visited, and nothing is visited when that entry is
-- absent. On older GHCs a node holds exactly one 'NodeInfo', which is always
-- visited.
sourceInfo :: Traversal' (HieAST a) (NodeInfo a)
{-# INLINE sourceInfo #-}

-- | Render a 'ContextInfo' as a 'String': through its 'ppr' instance on
-- GHC >= 9.0, through 'show' on older GHCs.
showContextInfo :: ContextInfo -> String

-- | Read a @.hie@ file, using and updating the 'NameCache' held in the given
-- 'IORef'.
readHieFileCompat :: IORef NameCache -> FilePath -> IO HieFileResult
#if MIN_VERSION_ghc(9,0,0)

sourceInfo f (Node (SourcedNodeInfo inf) sp children) =
  (\inf' -> Node (SourcedNodeInfo inf') sp children)
    <$> Map.alterF (maybe (pure Nothing) (fmap Just . f)) SourceInfo inf

showContextInfo = showSDocUnsafe . ppr

-- The strict atomicModifyIORef' forces each updated cache (and the value
-- handed back to readHieFile) to WHNF as it is stored, instead of leaving a
-- growing chain of update thunks inside the IORef.
readHieFileCompat ref = readHieFile (NCU (atomicModifyIORef' ref))

#else

sourceInfo f (Node inf sp children) = (\inf' -> Node inf' sp children) <$> f inf

showContextInfo = show

-- The NameCache is threaded through readHieFile explicitly. The read/write
-- pair below is not atomic with respect to other users of the same IORef.
readHieFileCompat ref fp = do
  cache <- readIORef ref
  (res, cache') <- readHieFile cache fp
  writeIORef ref cache'
  pure res

#endif

-- | Is this node annotated as a class instance declaration (@ClsInstD@ of
-- @InstDecl@) or a standalone deriving declaration (@DerivDecl@ of
-- @DerivDecl@)?
isInstanceNode :: NodeInfo a -> Bool

-- | Is this node annotated as a type signature (@TypeSig@ of @Sig@)?
isTypeSignatureNode :: NodeInfo a -> Bool

-- | Is this node annotated as an inline pragma signature (@InlineSig@ of
-- @Sig@)?
isInlineNode :: NodeInfo a -> Bool

-- | Is this node annotated as a @MINIMAL@ pragma signature (@MinimalSig@ of
-- @Sig@)?
isMinimalNode :: NodeInfo a -> Bool

-- | Is this node annotated as a deriving clause (@HsDerivingClause@ of
-- @HsDerivingClause@)?
isDerivingNode :: NodeInfo a -> Bool

-- | Render every annotation of a node as a shown @(constructor, type)@ pair,
-- separated by spaces. The format is the same on every supported GHC.
showAnns :: NodeInfo a -> String
#if MIN_VERSION_ghc(9,2,0)

-- GHC >= 9.2 stores node annotations as NodeAnnotation records.

isInstanceNode (NodeInfo anns _ _) =
  any
    (flip Set.member anns)
    [NodeAnnotation "ClsInstD" "InstDecl", NodeAnnotation "DerivDecl" "DerivDecl"]

isTypeSignatureNode (NodeInfo anns _ _) = Set.member (NodeAnnotation "TypeSig" "Sig") anns

isInlineNode (NodeInfo anns _ _) = Set.member (NodeAnnotation "InlineSig" "Sig") anns

isMinimalNode (NodeInfo anns _ _) = Set.member (NodeAnnotation "MinimalSig" "Sig") anns

isDerivingNode (NodeInfo anns _ _) = Set.member (NodeAnnotation "HsDerivingClause" "HsDerivingClause") anns

showAnns (NodeInfo anns _ _) = unwords (show . unNodeAnnotation <$> Set.toList anns)
  where
    -- Convert to a plain pair so the output matches the pre-9.2 format.
    unNodeAnnotation (NodeAnnotation a b) = (a, b)

#else

-- Older GHCs store node annotations as plain (constructor, type) pairs of
-- FastStrings; the literals below become FastStrings via OverloadedStrings.

isInstanceNode (NodeInfo anns _ _) =
  any
    (flip Set.member anns)
    [("ClsInstD", "InstDecl"), ("DerivDecl", "DerivDecl")]

isTypeSignatureNode (NodeInfo anns _ _) = Set.member ("TypeSig", "Sig") anns

isInlineNode (NodeInfo anns _ _) = Set.member ("InlineSig", "Sig") anns

isMinimalNode (NodeInfo anns _ _) = Set.member ("MinimalSig", "Sig") anns

isDerivingNode (NodeInfo anns _ _) = Set.member ("HsDerivingClause", "HsDerivingClause") anns

showAnns (NodeInfo anns _ _) = unwords (show <$> Set.toList anns)

#endif

-- | Join two spans into one covering both: it starts at the earlier of the
-- two start locations and ends at the later of the two end locations.
--
-- NOTE(review): the result presumably takes its file name from the chosen
-- start location, so both spans are expected to lie in the same file.
spanSpans :: Span -> Span -> Span
spanSpans sp1 sp2 = mkRealSrcSpan start end
  where
    start = min (realSrcSpanStart sp1) (realSrcSpanStart sp2)
    end = max (realSrcSpanEnd sp1) (realSrcSpanEnd sp2)