{-# LANGUAGE MultiParamTypeClasses, TypeFamilies, FlexibleContexts,
RankNTypes, ExistentialQuantification, GeneralizedNewtypeDeriving,
FlexibleInstances, GADTs #-}
module Graphics.GPipe.Internal.Shader (
Shader(..),
ShaderM(..),
ShaderState(..),
CompiledShader,
Render(..),
getName,
tellDrawcall,
askUniformAlignment,
modifyRenderIO,
compileShader,
mapShader,
guard',
maybeShader,
chooseShader,
silenceShader
) where
import Graphics.GPipe.Internal.Compiler
import Graphics.GPipe.Internal.Context
import Graphics.GPipe.Internal.Buffer
import Control.Monad.Trans.State
import qualified Control.Monad.Trans.State.Strict as StrictState
import Control.Monad.IO.Class
import qualified Data.IntSet as Set
import Control.Monad.Trans.Writer.Lazy (tell, WriterT(..), execWriterT)
import Control.Monad.Exception (MonadException)
import Control.Applicative (Applicative, Alternative, (<|>))
import Control.Monad.Trans.Class (lift)
import Data.Maybe (fromJust, isJust, isNothing)
import Control.Monad (MonadPlus, when)
import Control.Monad.Trans.List (ListT(..))
import Data.Monoid (All(..), mempty)
import Data.Either
import Control.Monad.Trans.Reader
import Data.List (find)
data ShaderState s = ShaderState Int (RenderIOState s)
newShaderState :: ShaderState s
newShaderState = ShaderState 1 newRenderIOState
getName :: ShaderM s Int
getName = do ShaderState n r <- ShaderM $ lift $ lift $ lift get
ShaderM $ lift $ lift $ lift $ put $ ShaderState (n+1) r
return n
askUniformAlignment = ShaderM ask
modifyRenderIO :: (RenderIOState s -> RenderIOState s) -> ShaderM s ()
modifyRenderIO f = ShaderM $ lift $ lift $ lift $ modify (\(ShaderState a s) -> ShaderState a (f s))
tellDrawcall :: IO (Drawcall s) -> ShaderM s ()
tellDrawcall dc = ShaderM $ lift $ tell ([dc], mempty)
mapDrawcall :: (s -> s') -> Drawcall s' -> Drawcall s
mapDrawcall f (Drawcall a b c d e g h i j k m) = Drawcall (a . f) b c d e g h i j k m
newtype ShaderM s a = ShaderM (ReaderT UniformAlignment (WriterT ([IO (Drawcall s)], s -> All) (ListT (State (ShaderState s)))) a) deriving (MonadPlus, Monad, Alternative, Applicative, Functor)
newtype Shader os s a = Shader (ShaderM s a) deriving (MonadPlus, Monad, Alternative, Applicative, Functor)
mapShader :: (s -> s') -> Shader os s' a -> Shader os s a
mapShader f (Shader (ShaderM m)) = Shader $ ShaderM $ do
uniAl <- ask
lift $ WriterT $ ListT $ do
ShaderState x s <- get
let (adcs, ShaderState x' s') = runState (runListT (runWriterT (runReaderT m uniAl))) (ShaderState x newRenderIOState)
put $ ShaderState x' (mapRenderIOState f s' s)
return $ map (\(a,(dcs, disc)) -> (a, (map (>>= (return . mapDrawcall f)) dcs, disc . f))) adcs
maybeShader :: (s -> Maybe s') -> Shader os s' () -> Shader os s ()
maybeShader f m = (guard' (isJust . f) >> mapShader (fromJust . f) m) <|> guard' (isNothing . f)
guard' :: (s -> Bool) -> Shader os s ()
guard' f = Shader $ ShaderM $ lift $ tell (mempty, All . f)
chooseShader :: (s -> Either s' s'') -> Shader os s' a -> Shader os s'' a -> Shader os s a
chooseShader f a b = (guard' (isLeft . f) >> mapShader (fromLeft . f) a) <|> (guard' (isRight . f) >> mapShader (fromRight . f) b)
where fromLeft (Left x) = x
fromRight (Right x) = x
silenceShader :: Shader os s a -> Shader os s a
silenceShader (Shader (ShaderM m)) = Shader $ ShaderM $ do
uniAl <- ask
lift $ WriterT $ ListT $ do
s <- get
let (adcs, s') = runState (runListT (runWriterT (runReaderT m uniAl))) s
put s'
return $ map (\ (a, (_, disc)) -> (a, ([], disc))) adcs
type CompiledShader os s = s -> Render os ()
compileShader :: (ContextHandler ctx, MonadIO m, MonadException m) => Shader os x () -> ContextT ctx os m (CompiledShader os x)
compileShader (Shader (ShaderM m)) = do
uniAl <- liftNonWinContextIO getUniformAlignment
let (adcs, ShaderState _ st) = runState (runListT (execWriterT (runReaderT m uniAl))) newShaderState
xs <- mapM (\(dcs, disc) -> do rf <- compile dcs st
return (rf, disc)) adcs
return $ \ s -> case find (\(_,disc) -> getAll (disc s)) xs of
Nothing -> error "render: Shader evaluated to mzero"
Just (rf,_) -> rf s