{-# LANGUAGE TemplateHaskell #-}

module MonadLab.Cont (
   contT
 ) where

import Prelude hiding (Monad)
import Language.Haskell.TH
import MonadLab.CommonTypes


-- | The continuation monad transformer, packaged as a 'MonadTransformer'.
--
-- Given a 'ContT' layer with answer type @t@, returns a function of the
-- underlying monad @m@ producing:
--
--   * the transformed type constructor ('contTransTypeCon'),
--   * the continuation monad's return and bind expressions,
--   * this layer's non-proper morphisms (callCC) followed by every
--     non-proper morphism of @m@ lifted through the new layer,
--   * the new base lift: @m@'s base lift followed by 'contTransLift'.
--
-- Only 'ContT' layers are meaningful here; any other layer is rejected with
-- a descriptive error instead of an anonymous pattern-match failure.
contT :: Layer -> MonadTransformer
contT l@(ContT t) = \m -> ( contTransTypeCon t m
                          , contTransReturn
                          , contTransBind
                          , contTransLayerNPM l : contTransLiftLayerNPMs m
                          , [| $(contTransLift m) . $(getBaseLift m) |]
                          )
contT _ = error "contT: expected a ContT layer"

--------------------------------

-- | Type constructor of the continuation transformer with answer type @ans@
-- over monad @m@: a value type @t@ becomes
-- @arrow (arrow t (m ans)) (m ans)@ — presumably the CPS type
-- @(t -> m ans) -> m ans@, assuming 'arrow' builds a function type.
contTransTypeCon :: TypeQ -> Monad -> MonadTypeCon
contTransTypeCon ans m = \t -> arrow (arrow t answer) answer
  where
    -- The inner monad applied to the answer type, used on both sides.
    answer = getTypeCon m ans

-- | @return@ for the continuation transformer: pass the value straight to
-- the continuation.
contTransReturn :: ReturnExpQ
contTransReturn = [| \x k -> k x |]

-- | @>>=@ for the continuation transformer: run the first computation with a
-- continuation that feeds its result to @f@ and then resumes with @k@.
contTransBind :: BindExpQ
contTransBind = [| \comp next k -> comp (\v -> next v k) |]

-- | @callCC@ for the continuation transformer. @f@ receives an escape
-- function that ignores its own continuation and resumes the captured
-- continuation @k@ instead; @f@'s result is also run with @k@.
contTransCallCC :: NonProperMorphismExpQ
contTransCallCC = [| \f k -> f (\x _ -> k x) k |]

-- | @lift@ for the continuation transformer.
--
-- Reuses the inner monad's bind expression ('getBind') unchanged: bind
-- partially applied to an inner computation @x@ is @\\k -> bind x k@, which is
-- already in continuation-passing form.
contTransLift :: Monad -> LiftExpQ
contTransLift m = getBind m

--------------------------------

-- | Pair a continuation layer with the single non-proper morphism it
-- introduces: callCC.
contTransLayerNPM :: Layer -> LayerNPM
contTransLayerNPM = \layer -> (layer, [contTransCallCC])


-- | Lift every non-proper morphism of the underlying monad @m@ through the
-- continuation layer, one 'LayerNPM' at a time.
--
-- Each recognised layer has its morphisms rewritten by the matching helper in
-- this module. The helpers used for @catch@ and for the Writer morphisms
-- (@tell@, @listen@, @pass@) call 'error', so those layers fail only once the
-- corresponding lifted expression is demanded. Any other layer or morphism
-- list shape (e.g. a nested 'ContT') is rejected with a descriptive 'error'
-- rather than an anonymous pattern-match failure.
contTransLiftLayerNPMs :: Monad -> [LayerNPM]
contTransLiftLayerNPMs m = map liftNPM (getLayerNPMs m)
  where
    liftNPM :: LayerNPM -> LayerNPM
    liftNPM npm = case npm of
      (Io, [liftIO]) ->
        (Io, [contTransLiftLiftIO m liftIO])
      (List, [_merge, _halt]) ->
        -- The list morphisms are rebuilt from m's base lift rather than
        -- lifting the originals.
        (List, [contTransLiftMerge m, contTransLiftHalt m])
      (StateT n t, [get, put]) ->
        (StateT n t, [contTransLiftGet m get, contTransLiftPut m put])
      (EnvT n t, [rdEnv, inEnv]) ->
        (EnvT n t, [contTransLiftRdEnv m rdEnv, contTransLiftInEnv m rdEnv inEnv])
      (ErrorT n t, [throw, catch]) ->
        (ErrorT n t, [contTransLiftThrow m throw, contTransLiftCatch catch])
      (WriterT n t, [tell, listen, pass]) ->
        (WriterT n t, [ stateTransLiftTell m tell
                      , stateTransLiftListen m listen
                      , stateTransLiftPass m pass
                      ])
      _ ->
        error "contTransLiftLayerNPMs: cannot lift this layer's non-proper morphisms through cont transformer"

-- | Lift the inner monad's @get@ into the continuation layer: @lift get@.
contTransLiftGet :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
contTransLiftGet m getE = appE (contTransLift m) getE

-- | Lift the inner monad's @put@: the spliced @composition@ expression
-- applied to 'contTransLift' and @put@ (presumably @lift . put@, assuming
-- @composition@ is the quoted @(.)@ — confirm in CommonTypes).
contTransLiftPut :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
contTransLiftPut m putE = appsE [composition, contTransLift m, putE]

-- | Lift the inner monad's @rdEnv@ into the continuation layer: @lift rdEnv@.
contTransLiftRdEnv :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
contTransLiftRdEnv m rdEnvE = appE (contTransLift m) rdEnvE

-- | Lift the inner monad's @inEnv@ through the continuation layer.
--
-- The generated function takes an environment @r@, a CPS computation @c@ and
-- a continuation @k@. It works in three steps:
--
--   1. It first reads the current environment with @rdEnv@.
--   2. It then runs @c@ under @inEnv r@.
--   3. It wraps @k@ in @inEnv@ of the value read, so the rest of the
--      computation runs under that saved environment rather than @r@.
--
-- This relies on @inEnv e x@ presumably running @x@ with environment @e@;
-- confirm against the EnvT layer.
contTransLiftInEnv :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ -> NonProperMorphismExpQ
contTransLiftInEnv m rdEnv inEnv =
  [| \env comp k ->
       $bind $rdEnv (\saved -> $inEnv env (comp ($composition ($inEnv saved) k))) |]
  where
    bind = getBind m

-- | Lift the inner monad's @throw@: the spliced @composition@ expression
-- applied to 'contTransLift' and @throw@ (presumably @lift . throw@).
contTransLiftThrow :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
contTransLiftThrow m throwE = appsE [composition, contTransLift m, throwE]

-- | @catch@ is not supported through the continuation transformer; this
-- ignores its argument and calls 'error' when the result is demanded.
contTransLiftCatch :: NonProperMorphismExpQ -> NonProperMorphismExpQ
contTransLiftCatch _ =
  error "contTransLiftCatch:  Cannot lift 'catch' through cont transformer"

-- | Rebuild the list layer's @merge@ for the continuation transformer.
--
-- The result is the spliced @composition@ of two pieces:
--
--   * @joinCont@, join for the continuation monad (bind with the identity);
--   * @baseLift@, @m@'s base lift followed by 'contTransLift'.
contTransLiftMerge :: Monad -> NonProperMorphismExpQ
contTransLiftMerge m = [| $composition $joinCont $baseLift |]
  where
    joinCont = [| \c -> $contTransBind c (\v -> v) |]
    baseLift = [| $(contTransLift m) . $(getBaseLift m) |]

-- | Rebuild the list layer's @halt@ for the continuation transformer: the
-- empty list pushed through @m@'s base lift and then 'contTransLift'.
contTransLiftHalt :: Monad -> NonProperMorphismExpQ
contTransLiftHalt m = [| $liftBase [] |]
  where
    liftBase = [| $(contTransLift m) . $(getBaseLift m) |]

-- | Lift the inner monad's @liftIO@ into the continuation layer:
-- @lift . liftIO@.
contTransLiftLiftIO :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
contTransLiftLiftIO m liftIOE = [| $lifter . $liftIOE |]
  where
    lifter = contTransLift m

-- | @tell@ is not supported through the continuation transformer; this calls
-- 'error' when the result is demanded.
--
-- NOTE(review): the @stateTrans@ prefix looks copied from the StateT module;
-- the error message uses the @contTrans@ name.
stateTransLiftTell :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
stateTransLiftTell _ _ =
  error "contTransLiftTell: Cannot lift 'tell' through cont transformer"

-- | @listen@ is not supported through the continuation transformer; this
-- calls 'error' when the result is demanded.
--
-- NOTE(review): the @stateTrans@ prefix looks copied from the StateT module;
-- the error message uses the @contTrans@ name.
stateTransLiftListen :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
stateTransLiftListen _ _ =
  error "contTransLiftListen: Cannot lift 'listen' through cont transformer"

-- | @pass@ is not supported through the continuation transformer; this calls
-- 'error' when the result is demanded.
--
-- NOTE(review): the @stateTrans@ prefix looks copied from the StateT module;
-- the error message uses the @contTrans@ name.
stateTransLiftPass :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
stateTransLiftPass _ _ =
  error "contTransLiftPass: Cannot lift 'pass' through cont transformer"