{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}

-- | JS codegen state monad
module GHC.StgToJS.Monad
  ( runG
  , emitGlobal
  , addDependency
  , emitToplevel
  , emitStatic
  , emitClosureInfo
  , emitForeign
  , assertRtsStat
  , getSettings
  , globalOccs
  , setGlobalIdCache
  , getGlobalIdCache
  , GlobalOcc(..)
  -- * Group
  , modifyGroup
  , resetGroup
  )
where

import GHC.Prelude

import GHC.JS.Syntax
import GHC.JS.Transform

import GHC.StgToJS.Types

import GHC.Unit.Module
import GHC.Stg.Syntax

import GHC.Types.SrcLoc
import GHC.Types.Id
import GHC.Types.Unique.FM
import GHC.Types.ForeignCall

import qualified Control.Monad.Trans.State.Strict as State
import GHC.Data.FastString
import GHC.Data.FastMutInt

import qualified Data.Map  as M
import qualified Data.Set  as S
import qualified Data.List as L

-- | Run a 'G' action in 'IO'.
--
-- A fresh 'GenState' is built with 'initState' from the configuration, the
-- module and the map of unfloated expressions; the action is then evaluated
-- against it and the final state is discarded.
runG :: StgToJSConfig -> Module -> UniqFM Id CgStgExpr -> G a -> IO a
runG config m unfloat action =
  State.evalStateT action =<< initState config m unfloat

-- | Build the initial generator state.
--
-- Allocates a new 'FastMutInt' counter starting at 1 for 'gsId', and starts
-- with an empty identifier cache, the default (empty) group state and no
-- global statements.
initState :: StgToJSConfig -> Module -> UniqFM Id CgStgExpr -> IO GenState
initState config m unfloat = do
  id_gen <- newFastMutInt 1
  pure $ GenState
    { gsSettings  = config
    , gsModule    = m
    , gsId        = id_gen
    , gsIdents    = emptyIdCache
    , gsUnfloated = unfloat
    , gsGroup     = defaultGenGroupState
    , gsGlobal    = []
    }


-- | Apply a function to the state of the current binding group ('gsGroup'),
-- leaving the rest of the generator state untouched.
modifyGroup :: (GenGroupState -> GenGroupState) -> G ()
modifyGroup f = State.modify mod_state
  where
    mod_state :: GenState -> GenState
    mod_state s = s { gsGroup = f (gsGroup s) }

-- | Emit a global (for the current module) toplevel statement.
--
-- The statement is consed onto 'gsGlobal', so that list holds the statements
-- in reverse order of emission.
emitGlobal :: JStat -> G ()
emitGlobal stat = State.modify (\s -> s { gsGlobal = stat : gsGlobal s })

-- | Add a dependency on a particular symbol to the current group.
--
-- The symbol is inserted into the 'ggsExtraDeps' set, so adding the same
-- symbol more than once has no further effect.
addDependency :: OtherSymb -> G ()
addDependency symbol = modifyGroup mod_group
  where
    mod_group :: GenGroupState -> GenGroupState
    mod_group g = g { ggsExtraDeps = S.insert symbol (ggsExtraDeps g) }

-- | Emit a top-level statement for the current binding group.
--
-- The statement is consed onto 'ggsToplevelStats', so that list holds the
-- statements in reverse order of emission.
emitToplevel :: JStat -> G ()
emitToplevel s = modifyGroup mod_group
  where
    mod_group :: GenGroupState -> GenGroupState
    mod_group g = g { ggsToplevelStats = s : ggsToplevelStats g }

-- | Emit static data for the binding group.
--
-- The three arguments are wrapped in a 'StaticInfo', which is consed onto
-- 'ggsStatic' of the current group.
emitStatic :: FastString -> StaticVal -> Maybe Ident -> G ()
emitStatic ident val cc = modifyGroup mod_group
  where
    mod_group :: GenGroupState -> GenGroupState
    mod_group g = g { ggsStatic = mod_static (ggsStatic g) }

    mod_static :: [StaticInfo] -> [StaticInfo]
    mod_static s = StaticInfo ident val cc : s

-- | Add closure info in our binding group. All heap objects must have closure
-- info.
--
-- The info is consed onto 'ggsClosureInfo' of the current group.
emitClosureInfo :: ClosureInfo -> G ()
emitClosureInfo ci = modifyGroup mod_group
  where
    mod_group :: GenGroupState -> GenGroupState
    mod_group g = g { ggsClosureInfo = ci : ggsClosureInfo g }

-- | Record a foreign JS reference in the current binding group.
--
-- The arguments are stored unchanged in a 'ForeignJSRef', preceded by a
-- textual rendering of the source span:
-- @\<file\> (startLine,startCol)-(endLine,endCol)@, or @\<unknown\>@ when no
-- span is given. The new reference is consed onto 'ggsForeignRefs'.
emitForeign :: Maybe RealSrcSpan
            -> FastString
            -> Safety
            -> CCallConv
            -> [FastString]
            -> FastString
            -> G ()
emitForeign mbSpan pat safety cconv arg_tys res_ty = modifyGroup mod_group
  where
    mod_group :: GenGroupState -> GenGroupState
    mod_group g = g { ggsForeignRefs = new_ref : ggsForeignRefs g }

    new_ref :: ForeignJSRef
    new_ref = ForeignJSRef spanTxt pat safety cconv arg_tys res_ty

    -- The span is rendered through 'String' and packed back into a
    -- 'FastString'.
    spanTxt :: FastString
    spanTxt = case mbSpan of
                -- TODO: Is there a better way to concatenate FastStrings?
                Just sp -> mkFastString $
                  unpackFS (srcSpanFile sp) ++
                  " " ++
                  show (srcSpanStartLine sp, srcSpanStartCol sp) ++
                  "-" ++
                  show (srcSpanEndLine sp, srcSpanEndCol sp)
                Nothing -> "<unknown>"






-- | Start with a new binding group: replace the current group state with
-- 'defaultGenGroupState', discarding whatever was accumulated in it.
resetGroup :: G ()
resetGroup = State.modify (\s -> s { gsGroup = defaultGenGroupState })

-- | The initial state of a binding group: every list field is empty, the
-- 'Int' field is 0, the set of extra dependencies is empty and the global id
-- cache is empty.
--
-- The constructor is applied positionally; the argument order must match the
-- field order of 'GenGroupState'.
defaultGenGroupState :: GenGroupState
defaultGenGroupState = GenGroupState [] [] [] [] 0 S.empty emptyGlobalIdCache []

-- | A 'GlobalIdCache' with no entries.
emptyGlobalIdCache :: GlobalIdCache
emptyGlobalIdCache = GlobalIdCache emptyUFM

-- | An 'IdCache' with no entries.
emptyIdCache :: IdCache
emptyIdCache = IdCache M.empty



-- | Run the given statement generator only when the 'csAssertRts' setting is
-- enabled; otherwise the generator is not run at all and the empty statement
-- ('mempty') is returned.
assertRtsStat :: G JStat -> G JStat
assertRtsStat stat = do
  s <- State.gets gsSettings
  if csAssertRts s then stat else pure mempty

-- | Read the code generator configuration ('gsSettings') from the state.
getSettings :: G StgToJSConfig
getSettings = State.gets gsSettings

-- | Read the global id cache of the current binding group.
getGlobalIdCache :: G GlobalIdCache
getGlobalIdCache = State.gets (ggsGlobalIdCache . gsGroup)

-- | Replace the global id cache of the current binding group, leaving the
-- other fields of the group state unchanged.
setGlobalIdCache :: GlobalIdCache -> G ()
setGlobalIdCache v =
  State.modify (\s -> s { gsGroup = (gsGroup s) { ggsGlobalIdCache = v } })


-- | A global id together with the number of times it occurs, as computed by
-- 'globalOccs'. All fields are strict.
data GlobalOcc = GlobalOcc
  { global_ident :: !Ident -- ^ the JS identifier that was looked up in the global id cache
  , global_id    :: !Id    -- ^ the 'Id' the cache associates with that identifier
  , global_count :: !Word  -- ^ number of occurrences counted
  }

-- | Return number of occurrences of every global id used in the given JStat.
-- Sort by increasing occurrence count.
--
-- Identifiers are collected with 'identsS' and only those present in the
-- current group's global id cache are counted; all others are ignored.
globalOccs :: JStat -> G [GlobalOcc]
globalOccs jst = do
  GlobalIdCache gidc <- getGlobalIdCache
  let
    cmp_cnt :: GlobalOcc -> GlobalOcc -> Ordering
    cmp_cnt g1 g2 = compare (global_count g1) (global_count g2)

    -- combine an existing entry with a new occurrence by summing the counts
    inc :: GlobalOcc -> GlobalOcc -> GlobalOcc
    inc g1 g2 = g1 { global_count = global_count g1 + global_count g2 }

    -- record one occurrence of an identifier in the map from Ident to
    -- GlobalOcc, provided the identifier is a known global
    count_ident :: UniqFM Ident GlobalOcc -> Ident -> UniqFM Ident GlobalOcc
    count_ident gids i =
      -- check if the Id is global
      case lookupUFM gidc i of
        Nothing        -> gids
        Just (_k, gid) ->
          -- add it to the map of already found global ids, increasing the
          -- count by 1
          addToUFM_C inc gids i (GlobalOcc i gid 1)

    -- strict left fold so the map is updated eagerly instead of building a
    -- chain of thunks
    occs :: UniqFM Ident GlobalOcc
    occs = L.foldl' count_ident emptyUFM (identsS jst)

  -- return global Ids used locally sorted by increased use
  -- NOTE(review): elements come from 'nonDetEltsUFM', so the relative order
  -- of entries with equal counts is presumably not deterministic -- confirm
  -- that callers do not depend on it.
  pure $ L.sortBy cmp_cnt $ nonDetEltsUFM occs