module Foreign.Storable.Generic.Plugin.Internal
( groupTypes
, gstorableSubstitution)
where
import Prelude hiding ((<>))
import CoreSyn (Bind(..),Expr(..), CoreExpr, CoreBind, CoreProgram, Alt)
import Literal (Literal(..))
import Id (isLocalId, isGlobalId,Id, modifyInlinePragma, setInlinePragma, idInfo)
import IdInfo
import Var (Var(..))
import Name (getOccName,mkOccName)
import OccName (OccName(..), occNameString)
import qualified Name as N (varName)
import SrcLoc (noSrcSpan)
import Unique (getUnique)
import HscMain (hscCompileCoreExpr)
import HscTypes (HscEnv,ModGuts(..))
import CoreMonad
(CoreM, CoreToDo(..),
getHscEnv, getDynFlags, putMsg, putMsgS)
import BasicTypes (CompilerPhase(..))
import Type (isAlgType, splitTyConApp_maybe)
import TyCon (tyConKind, algTyConRhs, visibleDataCons)
import TyCoRep (Type(..), TyBinder(..))
import TysWiredIn (intDataCon)
import DataCon (dataConWorkId,dataConOrigArgTys)
import MkCore (mkWildValBinder)
import Outputable
(cat, ppr, SDoc, showSDocUnsafe, showSDoc,
($$), ($+$), hsep, vcat, empty,text,
(<>), (<+>), nest, int, colon,hcat, comma,
punctuate, fsep)
import Data.List
import Data.Maybe
import Data.Either
import Data.IORef
import Debug.Trace
import Control.Monad.IO.Class
import Control.Monad
import Foreign.Storable.Generic.Plugin.Internal.Error
import Foreign.Storable.Generic.Plugin.Internal.Compile
import Foreign.Storable.Generic.Plugin.Internal.GroupTypes
import Foreign.Storable.Generic.Plugin.Internal.Helpers
import Foreign.Storable.Generic.Plugin.Internal.Predicates
import Foreign.Storable.Generic.Plugin.Internal.Types
-- | Report the errors collected while grouping 'GStorable' types.
--
-- An empty error list prints nothing and never crashes. A non-empty list is
-- printed under a common header; at verbosity 'None' the printed document is
-- 'empty'. After printing, compilation is aborted with 'error' if the crash
-- flag in 'Flags' is set.
groupTypes_errors :: Flags -> [Error] -> CoreM ()
groupTypes_errors (Flags verb to_crash) errors
    | null errors = return ()
    | otherwise   = do
        putMsg $ withHeader $ vcat $ map describe errors
        when to_crash $ error "Crashing..."
  where
    -- True when the user asked for no diagnostic output at all.
    quiet = case verb of
        None -> True
        _    -> False
    withHeader body
        | quiet     = empty
        | otherwise = text "Errors while grouping types - types not found for: "
                      $$ nest 4 body
    -- Missing types get a compact "name :: type" rendering; every other
    -- error is delegated to the generic error printer.
    describe (TypeNotFound ident)
        | quiet     = empty
        | otherwise = ppr ident $$ nest 13 (text "::") <+> ppr (varType ident)
    describe other = pprError verb other
-- | Print the order in which 'GStorable' instances are going to be
-- optimised: one numbered line (starting at 1) per layer of types.
--
-- Nothing is printed when there are no layers. At verbosity 'None' the
-- printed document is 'empty'.
groupTypes_info :: Flags -> [[Type]] -> CoreM ()
groupTypes_info flags types =
    unless (null types) $ putMsg $ case verb of
        None -> empty
        _    -> text "GStorable instances will be optimised in the following order"
                $+$ nest 4 (vcat (zipWith layerLine [1 ..] types))
                $+$ text ""
  where
    Flags verb _ = flags
    -- "n: T1, T2, ..." for the n-th layer.
    layerLine :: Int -> [Type] -> SDoc
    layerLine n layer = int n <> text ":" <+> fsep (punctuate comma (map ppr layer))
-- | Plugin pass that computes the order in which 'GStorable' types should be
-- optimised, and stores it in the given 'IORef' for a later pass.
--
-- Every top-level id that passes the 'GStorable' type check and has no
-- unresolved 'GStorable' constraints in its type is collected. Ids whose
-- 'GStorable' type cannot be extracted are reported as 'TypeNotFound'. The
-- remaining types are ordered with 'calcGroupOrder'. The module itself is
-- returned unchanged.
groupTypes :: Flags -> IORef [[Type]] -> ModGuts -> CoreM ModGuts
groupTypes flags type_order_ref guts = do
    let binds = mg_binds guts
        all_ids = concatMap getIdsBind binds
        with_typecheck = withTypeCheck getGStorableType isGStorableId
        predicate ident = with_typecheck ident
                       && not (hasGStorableConstraints $ varType ident)
        gstorable_ids = filter predicate all_ids
        -- Pair every candidate id with its (possibly missing) GStorable type.
        typed_ids = [ (ident, getGStorableType (varType ident)) | ident <- gstorable_ids ]
        bad_types = [ TypeNotFound ident | (ident, Nothing) <- typed_ids ]
        type_list = [ t | (_, Just t) <- typed_ids ]
        (type_order, m_error) = calcGroupOrder type_list
        -- The ordering error used to be discarded, which silently dropped
        -- the types that could not be ordered. Report it together with the
        -- missing-type errors so that it is printed and honours the crash flag.
        all_errors = bad_types ++ maybeToList m_error
    groupTypes_info flags type_order
    groupTypes_errors flags all_errors
    liftIO $ writeIORef type_order_ref type_order
    return guts
-- | Report a failure of the binding-grouping step and return the bindings
-- that could not be grouped.
--
-- The error is always printed at verbosity 'Some'. The verbosity stored in
-- 'Flags' is ignored here; only its crash flag is used. If the crash flag is
-- set and an error is present, compilation is aborted with 'error'.
-- Otherwise the bindings carried by an 'OrderingFailedBinds' error are
-- returned to the caller. Any other outcome, including no error, yields no
-- bindings.
grouping_errors :: Flags
                -> Maybe Error
                -> CoreM [CoreBind]
grouping_errors flags m_err = do
    let Flags _ to_crash = flags
    forM_ m_err $ \err ->
        putMsg $ text "Errors while grouping bindings: "
                 $$ nest 4 (pprError Some err)
    when (to_crash && isJust m_err) $ error "Crashing..."
    return $ case m_err of
        Just (OrderingFailedBinds _ leftovers) -> leftovers
        _                                      -> []
-- | Print the bindings that are about to be optimised, grouped by the
-- 'GStorable' type they belong to.
--
-- Groups are listed in order of first appearance, and the bindings inside a
-- group are sorted by name. A binding whose 'GStorable' type cannot be
-- determined is reported on its own with its full type. Nothing is printed
-- for an empty list of ids. At verbosity 'None' the printed document is
-- 'empty'.
foundBinds_info :: Flags
                -> [Id]
                -> CoreM ()
foundBinds_info flags ids = do
    let (Flags verb _) = flags
        print_header txt = case verb of
            None -> empty
            _    -> text "The following bindings are to be optimised:"
                    $+$ nest 4 txt
        gstorableType = getGStorableType . varType
        eqType_maybe (Just t1) (Just t2) = t1 `eqType` t2
        eqType_maybe _         _         = False
        sameType i1 i2 = gstorableType i1 `eqType_maybe` gstorableType i2
        -- 'groupBy' only merges adjacent elements, so ids of one type that
        -- are not next to each other were printed as several groups.
        -- Collect all ids of the same type into a single group instead.
        groupAll _  []     = []
        groupAll eq (x:xs) = let (same, others) = partition (eq x) xs
                             in (x : same) : groupAll eq others
        sorted = map (sortBy (\i1 i2 -> varName i1 `compare` varName i2))
                     (groupAll sameType ids)
        print_group the_group = case the_group of
            []    -> empty
            (h:_) -> case gstorableType h of
                Just gtype ->
                    ppr gtype
                    $+$ fsep (punctuate comma (map ppr the_group))
                -- 'text', not 'ppr': 'ppr' on a String goes through the list
                -- instance and prints the message character by character.
                Nothing ->
                    text "Could not get the type of a binding:"
                    $+$ nest 4 (ppr h <+> text "::" <+> ppr (varType h))
    case sorted of
        [] -> return ()
        _  -> putMsg $ print_header $ vcat (map print_group sorted)
-- | Plugin pass that replaces the module's 'GStorable' method bindings with
-- compiled versions.
--
-- The type hierarchy is read from the given 'IORef'. Presumably this is the
-- ref filled in by 'groupTypes' in an earlier pass; confirm against the
-- plugin setup. 'GStorable' method bindings are split into recursive and
-- non-recursive ones. Only the non-recursive ones are grouped by the type
-- hierarchy and passed to 'compileGroups'. Recursive bindings, bindings that
-- could not be grouped, and all unrelated bindings are put back into the
-- module as they are.
gstorableSubstitution :: Flags
                      -> IORef [[Type]]
                      -> ModGuts
                      -> CoreM ModGuts
gstorableSubstitution flags type_order_ref guts = do
    type_hierarchy <- liftIO $ readIORef type_order_ref
    let binds = mg_binds guts
        -- Separate GStorable method bindings from everything else.
        (gstorable_binds, rest) = partition (toIsBind isGStorableMethodId) binds
        -- Recursive bindings are not grouped or compiled; they pass through.
        (nonrecs, recs) = partition isNonRecBind gstorable_binds
        (grouped_binds, m_err_group) = groupBinds type_hierarchy nonrecs
    foundBinds_info flags $ concatMap getIdsBind $ concat grouped_binds
    not_grouped <- grouping_errors flags m_err_group
    new_gstorables <- compileGroups flags grouped_binds rest
    return $ guts { mg_binds = concat [new_gstorables, not_grouped, recs, rest] }