module Foreign.Storable.Generic.Plugin.Internal
    ( groupTypes
    , gstorableSubstitution)
where
import Prelude hiding ((<>))
import CoreSyn (Bind(..),Expr(..), CoreExpr, CoreBind, CoreProgram, Alt)
import Literal (Literal(..))
import Id  (isLocalId, isGlobalId,Id, modifyInlinePragma, setInlinePragma, idInfo)
import IdInfo
import Var (Var(..))
import Name (getOccName,mkOccName)
import OccName (OccName(..), occNameString)
import qualified Name as N (varName)
import SrcLoc (noSrcSpan)
import Unique (getUnique)
import HscMain (hscCompileCoreExpr)
import HscTypes (HscEnv,ModGuts(..))
import CoreMonad
    (CoreM, CoreToDo(..),
     getHscEnv, getDynFlags, putMsg, putMsgS)
import BasicTypes (CompilerPhase(..))
import Type (isAlgType, splitTyConApp_maybe)
import TyCon (tyConKind, algTyConRhs, visibleDataCons)
import TyCoRep (Type(..), TyBinder(..))
import TysWiredIn (intDataCon)
import DataCon    (dataConWorkId,dataConOrigArgTys)
import MkCore (mkWildValBinder)
import Outputable
    (cat, ppr, SDoc, showSDocUnsafe, showSDoc,
     ($$), ($+$), hsep, vcat, empty,text,
     (<>), (<+>), nest, int, colon,hcat, comma,
     punctuate, fsep)
import Data.List
import Data.Maybe
import Data.Either
import Data.IORef
import Debug.Trace
import Control.Monad.IO.Class
import Control.Monad
import Foreign.Storable.Generic.Plugin.Internal.Error
import Foreign.Storable.Generic.Plugin.Internal.Compile
import Foreign.Storable.Generic.Plugin.Internal.GroupTypes
import Foreign.Storable.Generic.Plugin.Internal.Helpers
import Foreign.Storable.Generic.Plugin.Internal.Predicates
import Foreign.Storable.Generic.Plugin.Internal.Types
-- | Report the errors collected while looking up the 'GStorable' types of
-- bindings, then abort with 'error' when the crash flag in 'Flags' is set and
-- at least one error was collected.
--
-- 'TypeNotFound' errors are rendered as the offending id followed by its type;
-- every other error is rendered with 'pprError'. Under 'None' verbosity the
-- rendered document is empty, but 'putMsg' is still invoked whenever the error
-- list is non-empty.
groupTypes_errors :: Flags -> [Error] -> CoreM ()
groupTypes_errors flags errors = do
    unless (null errors) $
        putMsg $ withHeader (vcat (map describe errors))
    when (to_crash && not (null errors)) $
        error "Crashing..."
  where
    Flags verb to_crash = flags

    -- Prefix the per-error lines with a header unless output is silenced.
    withHeader body = case verb of
        None -> empty
        _    -> text "Errors while grouping types - types not found for: "
                $$ nest 4 body

    -- One document per error; ids with no GStorable type are shown with
    -- their full type so the user can see what was rejected.
    describe (TypeNotFound ident) = case verb of
        None -> empty
        _    -> ppr ident $$ (nest 13 (text "::") <+> ppr (varType ident))
    describe err = pprError verb err
-- | Print the layers of types in the order in which their 'GStorable'
-- instances will be optimised. Layers are numbered from 1. Nothing is printed
-- when there are no layers; under 'None' verbosity an empty document is
-- passed to 'putMsg'.
groupTypes_info :: Flags -> [[Type]] -> CoreM ()
groupTypes_info flags types =
    unless (null types) $
        putMsg $ withHeader $ vcat $ zipWith layerDoc [1 ..] types
  where
    Flags verb _ = flags

    withHeader body = case verb of
        None -> empty
        _    -> text "GStorable instances will be optimised in the following order"
                $+$ nest 4 body
                $+$ text ""

    -- "<n>: T1, T2, ..." for a single layer.
    layerDoc :: Int -> [Type] -> SDoc
    layerDoc ix layer =
        int ix <> text ":" <+> fsep (punctuate comma (map ppr layer))
-- | First plugin pass: compute the order in which 'GStorable' types should be
-- optimised and store it in the given 'IORef'. The 'ModGuts' are returned
-- unchanged.
--
-- Ids are collected from the module's bindings with 'getIdsBind' and kept
-- when they pass @'withTypeCheck' 'getGStorableType' 'isGStorableId'@ and
-- their type carries no 'GStorable' constraints. Kept ids whose 'GStorable'
-- type cannot be extracted are reported as 'TypeNotFound'. The extracted
-- types are layered by 'calcGroupOrder', whose error (if any) is reported
-- together with the 'TypeNotFound' errors.
groupTypes :: Flags -> IORef [[Type]] -> ModGuts -> CoreM ModGuts
groupTypes flags type_order_ref guts = do
    let binds = mg_binds guts

        all_ids = concatMap getIdsBind binds
        with_typecheck = withTypeCheck getGStorableType isGStorableId
        predicate id =  with_typecheck id
                     && not (hasGStorableConstraints $ varType id)
        gstorable_ids = filter predicate all_ids

        -- The GStorable type of each kept id, if it can be determined.
        m_gstorable_types = map (getGStorableType . varType) gstorable_ids

        bad_types_zip id m_t = case m_t of
            Nothing -> Just $ TypeNotFound id
            Just _  -> Nothing
        bad_types = catMaybes $ zipWith bad_types_zip gstorable_ids m_gstorable_types

        type_list = catMaybes m_gstorable_types

        (type_order, m_error) = calcGroupOrder type_list
    groupTypes_info flags type_order
    -- The ordering error used to be discarded, silently hiding types that
    -- could not be placed in any layer. Report it (and honour the crash
    -- flag) alongside the lookup failures.
    groupTypes_errors flags (bad_types ++ maybeToList m_error)
    liftIO $ writeIORef type_order_ref type_order
    return guts
-- | Report an error produced while grouping the GStorable bindings, and
-- return the bindings that could not be grouped: the list carried by
-- 'OrderingFailedBinds', or @[]@ for any other outcome. Aborts with 'error'
-- when the crash flag is set and an error is present.
--
-- NOTE(review): the verbosity in @flags@ is ignored and fixed to 'Some', so
-- a grouping error is always printed — confirm this is intended.
grouping_errors :: Flags            -- ^ Crash flag (verbosity is ignored)
                -> Maybe Error      -- ^ Error from grouping, if any
                -> CoreM [CoreBind] -- ^ Bindings that could not be grouped
grouping_errors flags m_err = do
    let Flags _ to_crash = flags
    forM_ m_err $ \err ->
        putMsg $ text "Errors while grouping bindings: "
                 $$ nest 4 (pprError Some err)
    when (to_crash && isJust m_err) $
        error "Crashing..."
    return $ case m_err of
        Just (OrderingFailedBinds _ leftover) -> leftover
        _                                     -> []
-- | Print the bindings that are about to be optimised, grouped by the type
-- returned by 'getGStorableType' for each binding's type. Within a group,
-- bindings are sorted by 'varName'. A binding whose type cannot be
-- determined is reported on its own. Nothing is printed for an empty list.
foundBinds_info :: Flags    -- ^ Only the verbosity is used
                -> [Id]     -- ^ Ids of the bindings to be optimised
                -> CoreM ()
foundBinds_info flags ids = do
    let (Flags verb _) = flags

        print_header txt = case verb of
            None  -> empty
            _     ->    text "The following bindings are to be optimised:"
                    $+$ nest 4 txt
        print_binding id = ppr id

        printer the_groups = case the_groups of
            [] -> return ()
            _  -> putMsg $ print_header $ vcat (map print_group the_groups)

        eqType_maybe (Just t1) (Just t2) = t1 `eqType` t2
        eqType_maybe _         _         = False

        same_type i1 i2 = (getGStorableType $ varType i1)
                          `eqType_maybe` (getGStorableType $ varType i2)

        -- 'groupBy' only merges *adjacent* equal elements, and the ids are not
        -- sorted by type, so one type could be printed as several groups.
        -- Collect every id equal to the head instead (Type has no Ord).
        group_all _  []     = []
        group_all eq (x:xs) = let (same, others) = partition (eq x) xs
                              in  (x : same) : group_all eq others

        grouped = group_all same_type ids
        sorting = sortBy (\i1 i2 -> varName i1 `compare` varName i2)
        sorted  = map sorting grouped

        print_group the_group = case the_group of
            [] -> empty
            (h:_) -> case getGStorableType $ varType h of
                Just gtype ->     ppr gtype
                              $+$ fsep (punctuate comma (map print_binding the_group))
                -- 'text', not 'ppr': 'ppr' on a String renders it as a list
                -- of characters ("[C, o, u, ...]").
                Nothing    ->     text "Could not get the type of a binding:"
                              $+$ nest 4 (ppr h <+> text "::" <+> ppr (varType h))

    printer sorted
-- | Second plugin pass: rewrite the module's GStorable method bindings.
--
-- Reads the type order from the 'IORef' (presumably the one filled by
-- 'groupTypes' — verify against the plugin setup). It then selects bindings
-- accepted by 'isGStorableMethodId' whose type has no 'GStorable'
-- constraints. Non-recursive ones are grouped with 'groupBinds' and handed to
-- 'compileGroups'; recursive ones are kept as they are. The resulting binding
-- list is: compiled groups, bindings that failed to group, the untouched
-- recursive GStorable bindings, then all other bindings.
gstorableSubstitution :: Flags          -- ^ Verbosity and crash flags
                      -> IORef [[Type]] -- ^ Layered type order
                      -> ModGuts        -- ^ Module to transform
                      -> CoreM ModGuts  -- ^ Module with substituted bindings
gstorableSubstitution flags type_order_ref guts = do
    type_hierarchy <- liftIO $ readIORef type_order_ref
    let -- Bindings still constrained on GStorable cannot be specialised.
        method_type ty
            | hasGStorableConstraints ty = Nothing
            | otherwise                  = getGStorableMethodType ty
        is_gstorable_bind = toIsBind (withTypeCheck method_type isGStorableMethodId)

        (gstorable_binds, others)    = partition is_gstorable_bind (mg_binds guts)
        (nonrecs, recs)              = partition isNonRecBind gstorable_binds
        (grouped_binds, m_err_group) = groupBinds type_hierarchy nonrecs

    foundBinds_info flags (concatMap getIdsBind (concat grouped_binds))
    not_grouped    <- grouping_errors flags m_err_group
    new_gstorables <- compileGroups flags grouped_binds others
    return (guts { mg_binds = new_gstorables ++ not_grouped ++ recs ++ others })