{-|
  Copyright   :  (C) 2021, QBayLogic
  License     :  BSD2 (see the file LICENSE)
  Maintainer  :  QBayLogic B.V. <devops@qbaylogic.com>
-}

{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE QuasiQuotes #-}

module Clash.GHCi.Common
  ( checkImportDirs
  , checkMonoLocalBinds
  , checkMonoLocalBindsMod
  , getMainTopEntity
  ) where

-- Clash
import           Clash.Driver.Types     (ClashOpts (..), BindingMap)
import           Clash.Netlist.Types    (TopEntityT(..))

-- The GHC interface
#if MIN_VERSION_ghc(9,0,0)
import qualified GHC.Data.EnumSet       as GHC (member)
import           GHC.Utils.Panic        (GhcException (..), throwGhcException)
import qualified GHC
  (DynFlags, ModSummary (..), extensionFlags, moduleName, moduleNameString)
#else
import qualified EnumSet                as GHC (member)
import           Panic                  (GhcException (..), throwGhcException)
import qualified GHC                    (DynFlags, ModSummary (..), Module (..),
                                         extensionFlags, moduleNameString)
#endif
import           Clash.Core.Name        (nameOcc)
import           Clash.Core.Var         (varName)
import           Clash.Normalize.Util   (collectCallGraphUniques, callGraph)
import qualified Clash.Util.Interpolate as I
import           Clash.Util             (ClashException(..), HasCallStack, noSrcSpan)
import           Clash.Unique           (getUnique)
import           Control.Exception      (throw)
import           Data.List              (isSuffixOf)
import qualified Data.Text              as Text
import qualified Data.HashSet           as HashSet
import qualified GHC.LanguageExtensions as LangExt (Extension (..))

import           Control.Monad          (forM_, unless, when)
import           System.Directory       (doesDirectoryExist)
import           System.IO              (hPutStrLn, stderr)

-- | Select the top entity named by @-main-is@ from the loaded top entities.
getMainTopEntity
  :: HasCallStack
  => String
  -- ^ Module name
  -> BindingMap
  -- ^ Map of global binders
  -> [TopEntityT]
  -- ^ List of top entities loaded by LoadModules
  -> String
  -- ^ String passed with -main-is
  -> IO (TopEntityT, [TopEntityT])
  -- ^ Throws exception if -main-is was set, but no such top entity was found.
  -- Otherwise, returns main top entity and all top entities (transitively) used
  -- in the main top entity.
getMainTopEntity modName bindingMap topEnts nm =
  case filter isNm topEnts of
    [] -> throw $ ClashException noSrcSpan [I.i|
      Could not find top entity called #{show nm} in #{show modName}
    |] Nothing
    [t] ->
      let
        mainId = topId t
        -- Uniques collected from the call graph rooted at the main top
        -- entity, with the main top entity itself removed so it is not
        -- reported as one of its own dependencies.
        reachable =
          HashSet.delete
            (getUnique mainId)
            (collectCallGraphUniques (callGraph bindingMap mainId))
        isUsed = (`HashSet.member` reachable) . getUnique . topId
      in
        pure (t, filter isUsed topEnts)
    ts ->
      -- More than one candidate matched the requested name.
      error $ [I.i|
        Internal error: multiple top entities called #{nm} (#{map topId ts})
        found in #{modName}.
      |]
 where
  -- A top entity matches when its occurrence name is exactly @nm@, or when it
  -- ends in @'.' : nm@ (i.e. @nm@ is the last dot-separated component).
  isNm :: TopEntityT -> Bool
  isNm (TopEntityT{topId}) =
    let topIdNm = Text.unpack (nameOcc (varName topId)) in
    topIdNm == nm || ('.':nm) `isSuffixOf` topIdNm

-- | Checks whether the MonoLocalBinds and MonomorphismRestriction language
-- extensions are enabled in the given module. For every one of them that is
-- disabled, a warning naming the module is printed to 'stderr'.
checkMonoLocalBindsMod :: GHC.ModSummary -> IO ()
checkMonoLocalBindsMod x =
  -- Extensions are checked (and reported) in list order.
  forM_ [LangExt.MonoLocalBinds, LangExt.MonomorphismRestriction] $ \ext ->
    unless (active ext (GHC.ms_hspp_opts x))
           (hPutStrLn stderr (messageWith ext modNm))
  where
    -- Name of the module described by the summary, used in the warning text.
    modNm :: String
    modNm = GHC.moduleNameString (GHC.moduleName (GHC.ms_mod x))

-- | Checks whether the MonoLocalBinds and MonomorphismRestriction language
-- extensions are enabled when generating HDL directly, e.g. from GHCi. For
-- every one of them that is disabled in the given flags, a warning (without a
-- module name) is printed to 'stderr'.
checkMonoLocalBinds :: GHC.DynFlags -> IO ()
checkMonoLocalBinds dflags =
  -- Extensions are checked (and reported) in list order.
  forM_ [LangExt.MonoLocalBinds, LangExt.MonomorphismRestriction] $ \ext ->
    unless (active ext dflags)
           (hPutStrLn stderr (messageWith ext ""))

-- | Build the warning text for a disabled language extension.
--
-- When the module name is empty the message ends with a full stop; otherwise
-- @" in module: "@ followed by the module name is appended.
messageWith
  :: LangExt.Extension
  -- ^ The extension that is disabled
  -> String
  -- ^ Name of the offending module, or @""@ when there is none to report
  -> String
messageWith ext srcModule
  | null srcModule = msgStem ++ "."
  | otherwise      = msgStem ++ " in module: " ++ srcModule
  where
    -- Common prefix shared by both message variants.
    msgStem :: String
    msgStem =
      "Warning: Extension " ++ show ext ++
      " is disabled. This might lead to unexpected logic duplication"

-- | 'True' when the extension is a member of the extension flags stored in
-- the given 'GHC.DynFlags'.
active :: LangExt.Extension -> GHC.DynFlags -> Bool
active ext dflags = GHC.member ext (GHC.extensionFlags dflags)

-- | Verify that every given import directory exists.
--
-- Does nothing unless 'opt_checkIDir' is set in the options. Otherwise the
-- directories are visited in order and a 'CmdLineError' is thrown (via
-- 'throwGhcException') for the first one that does not exist.
checkImportDirs :: Foldable t => ClashOpts -> t FilePath -> IO ()
checkImportDirs opts idirs =
  when (opt_checkIDir opts) $
    forM_ idirs $ \dir -> do
      exists <- doesDirectoryExist dir
      unless exists $
        throwGhcException (CmdLineError ("Missing directory: " ++ dir))