-- | Float let-bindings with a single use forward into their use-sites,
--   and substitute away bindings whose right-hand side is atomic.
module DDC.Core.Transform.Forward
        ( ForwardInfo   (..)
        , forwardModule
        , forwardX)
where
import DDC.Base.Pretty
import DDC.Core.Analysis.Usage
import DDC.Core.Exp
import DDC.Core.Module
import DDC.Core.Simplifier.Base
import DDC.Core.Fragment
import DDC.Core.Predicates
import Data.Map                 (Map)
import Control.Monad
import Control.Monad.Writer	(Writer, runWriter, tell)
import Data.Monoid		(Monoid, mempty, mappend)
import Data.Typeable
import qualified Data.Map                               as Map
import qualified DDC.Core.Transform.SubstituteXX	as S

-------------------------------------------------------------------------------
-- | Counts of the work done by one forwarding pass.
data ForwardInfo
        = ForwardInfo
        { -- | Number of use-sites where a variable was replaced by the
          --   right-hand side of a binding that was floated forward.
          infoSubsts   :: Int

          -- | Number of let-bindings removed, either because they were
          --   floated forward to their use-site, or because their
          --   right-hand side was atomic and was substituted into the body.
        , infoBindings :: Int }
        deriving Typeable


instance Pretty ForwardInfo where
 -- Render both counters as an indented table under a heading.
 ppr (ForwardInfo substs bindings)
  = text "Forward:" <$> indent 4 (vcat rows)
  where rows    = [ row "Substitutions:  " substs
                  , row "Bindings:       " bindings ]
        row label count
                = text label <> int count


instance Monoid ForwardInfo where
 -- Both counters start at zero and are combined by pointwise addition.
 mempty
  = ForwardInfo
        { infoSubsts    = 0
        , infoBindings  = 0 }

 mappend (ForwardInfo substsL bindingsL) (ForwardInfo substsR bindingsR)
  = ForwardInfo
        { infoSubsts    = substsL   + substsR
        , infoBindings  = bindingsL + bindingsR }


-------------------------------------------------------------------------------
-- | Float let-bindings in a module with a single use forward into
--   their use sites. Usage information is computed first, the body is
--   forwarded starting with no floated bindings, and the 'ForwardInfo'
--   summary produced by the pass is discarded.
forwardModule
        :: Ord n
        => Profile n -> Module a n -> Module a n
forwardModule profile
        = fst . runWriter . forwardWith profile Map.empty . usageModule


-- | Float let-bindings in an expression with a single use forward into
--   their use-sites.
--
--   The transform reports progress whenever a binding was removed or a
--   floated binding was substituted. It never requests to be run again.
forwardX :: Ord n
         => Profile n -> Exp a n -> TransformResult (Exp a n)
forwardX profile xx
 = let  (x', info) = runWriter
                   $ forwardWith profile Map.empty
                   $ usageX xx

        -- Atomic bindings are substituted straight into the body and only
        -- bump 'infoBindings', so looking at 'infoSubsts' alone would miss
        -- passes that changed the expression that way.
        progress (ForwardInfo substs bindings)
                = substs > 0 || bindings > 0

   in  TransformResult
        { result         = x'
        , resultProgress = progress info
        , resultAgain    = False
        , resultInfo     = TransformInfo info }


-------------------------------------------------------------------------------
class Forward (c :: * -> * -> *) where
 -- | Carry floated let-bindings forward and downward into their use-sites.
 --   The result has the same shape as the input, but the usage information
 --   paired into each annotation has been stripped. The 'Writer' output
 --   counts the substitutions made and the bindings moved.
 forwardWith
        :: Ord n
        => Profile n                    -- ^ Language profile, threaded
                                        --   through to recursive calls.
        -> Map n (Exp a n)              -- ^ Floated bindings, keyed by
                                        --   binder name, to substitute at
                                        --   their use-sites.
        -> c (UsedMap n, a) n           -- ^ Thing to transform, annotated
                                        --   with usage information.
        -> Writer ForwardInfo (c a n)

instance Forward Module where
 -- Only the module body is traversed; every other field is copied
 -- through unchanged into the rebuilt module.
 forwardWith profile bindings
        (ModuleCore
                { moduleName            = name
                , moduleExportKinds     = exportKinds
                , moduleExportTypes     = exportTypes
                , moduleImportKinds     = importKinds
                , moduleImportTypes     = importTypes
                , moduleBody            = body })
  = liftM rebuild (forwardWith profile bindings body)
  where
        rebuild body'
         = ModuleCore
                { moduleName            = name
                , moduleExportKinds     = exportKinds
                , moduleExportTypes     = exportTypes
                , moduleImportKinds     = importKinds
                , moduleImportTypes     = importTypes
                , moduleBody            = body' }


instance Forward Exp where
 -- Walk an expression, removing let-bindings that can be moved and
 -- substituting floated bindings at their use-sites. The usage
 -- information in the first component of each annotation is dropped
 -- from the result (hence the 'snd a' when rebuilding nodes).
 forwardWith profile bindings xx
  = {-# SCC forwardWith #-}
    let down    = forwardWith profile bindings 
    in case xx of
        -- A named variable bound by a floated binding is replaced by that
        -- binding's already-forwarded right-hand side, and the
        -- substitution is counted.
        XVar a u@(UName n)
         -> case Map.lookup n bindings of
                Just xx'        -> do
		    tell mempty { infoSubsts = 1 }
		    return xx'
                Nothing         ->
		    return $ XVar (snd a) u

        XVar a u        -> return $ XVar (snd a) u
        XCon a u        -> return $ XCon (snd a) u
        XLAM a b x      -> liftM    (XLAM (snd a) b) (down x)
        XLam a b x      -> liftM    (XLam (snd a) b) (down x)
        XApp a x1 x2    -> liftM2   (XApp (snd a))   (down x1) (down x2)

        -- Float a named binding whose right-hand side is a lambda
        -- ('isXLam' or 'isXLAM') and whose usage list, once 'UsedInCast'
        -- entries are filtered out, is exactly [UsedFunction]. The let is
        -- dropped; its forwarded right-hand side is added to 'bindings' so
        -- the XVar case above substitutes it at the use-site. If none of
        -- the guards hold, matching falls through to the next alternatives.
        --
        -- NOTE(review): substitution is keyed on the bare name 'n'. Nothing
        -- here removes 'n' from 'bindings' when descending under a binder
        -- in x2 that rebinds the same name. Nothing checks that free
        -- variables of x1' are not captured by binders in x2 either. The
        -- atomic case below goes through 'S.substituteXX' instead. Whether
        -- this is safe presumably relies on invariants of 'usageX' or on
        -- binder names being unique -- confirm.
        XLet (UsedMap um, _) (LLet _mode (BName n _) x1) x2
         | isXLam x1 || isXLAM x1
         , Just usage     <- Map.lookup n um
         , [UsedFunction] <- filterUsedInCasts usage
	 -> do
                -- Record that we've moved this binding.
                tell mempty { infoBindings = 1 }
                x1'           <- down x1
                forwardWith profile (Map.insert n x1' bindings) x2

        -- Always substitute bindings whose right-hand side is atomic
        -- (according to 'isAtomX'), however many times they are used.
        -- The substituted body is then forwarded as a whole. Only
        -- 'infoBindings' is bumped here, not 'infoSubsts'.
        XLet _ (LLet _mode b x1) x2
	 | isAtomX x1
	 -> do 
                -- Record that we've moved this binding.
                tell mempty { infoBindings = 1 }

                -- Slow, but handles anonymous binders and shadowing
                down $ S.substituteXX b x1 x2

        -- Any other let is kept: forward into its bindings and its body.
        XLet (_, a') lts x     
         -> liftM2 (XLet a') (down lts) (down x)

        XCase a x alts  -> liftM2   (XCase (snd a)) (down x) (mapM down alts)
        XCast a c x     -> liftM2   (XCast (snd a)) (down c) (down x)
        XType t         -> return $ XType t
        XWitness w      -> return $ XWitness w


-- | Remove every 'UsedInCast' entry from a usage list, keeping the other
--   entries in their original order.
filterUsedInCasts :: [Used] -> [Used]
filterUsedInCasts usages
        = [ used | used <- usages, not (isUsedInCast used) ]
 where  isUsedInCast UsedInCast = True
        isUsedInCast _          = False


instance Forward Cast where
 -- Only the expressions inside a closure weakening are forwarded; the
 -- other casts are rebuilt unchanged at the new annotation type.
 forwardWith profile bindings (CastWeakenClosure xs)
  = liftM CastWeakenClosure (mapM (forwardWith profile bindings) xs)

 forwardWith _ _ (CastWeakenEffect eff) = return $ CastWeakenEffect eff
 forwardWith _ _ (CastPurify w)         = return $ CastPurify w
 forwardWith _ _ (CastForget w)         = return $ CastForget w


instance Forward Lets where
 -- Forward into the right-hand side of every value binding. The region
 -- forms carry no expressions here and are rebuilt unchanged.
 forwardWith profile bindings lts
  = case lts of
        LLet mode b x
         -> liftM (LLet mode b) (go x)

        LRec bxs
         -> liftM LRec (mapM goBind bxs)

        LLetRegions b bs
         -> return (LLetRegions b bs)

        LWithRegion b
         -> return (LWithRegion b)
  where
        go              = forwardWith profile bindings
        goBind (b, x)   = liftM ((,) b) (go x)


instance Forward Alt where
 -- The pattern is kept as-is; only the right-hand side is forwarded.
 -- NOTE(review): names bound by the pattern are not removed from
 -- 'bindings' before forwarding the body -- confirm that a pattern binder
 -- can never shadow a floated binding.
 forwardWith profile bindings (AAlt p x)
  = do  x' <- forwardWith profile bindings x
        return $ AAlt p x'