module Graphics.HaGL.CodeGen (
    GLProgram(..),
    UniformVar(..), InpVar(..),
    genProgram
) where

import Prelude hiding (id)
import Control.Monad.State.Lazy (State, evalState, gets, modify, unless)
import Control.Exception (throw)
import qualified Data.List as List
import qualified Data.Map as Map
import qualified Data.Set as Set

import Graphics.HaGL.ExprID
import Graphics.HaGL.GLType
import Graphics.HaGL.GLExpr
import Graphics.HaGL.GLAst
import Graphics.HaGL.GLObj
import Graphics.HaGL.Shader


-- GLProgram = output of code gen for a GLObj

-- | Output of code generation for a 'GLObj': the two generated shaders
-- together with the host-side data needed to drive them.
data GLProgram = GLProgram
    { primitiveMode  :: PrimitiveMode
      -- ^ Copied unchanged from the source 'GLObj'.
    , indices        :: Maybe [ConstExpr UInt]
      -- ^ Copied unchanged from the source 'GLObj'.
    , uniformVars    :: Set.Set UniformVar
      -- ^ Uniforms encountered while traversing the object's expressions.
    , inputVars      :: Set.Set InpVar
      -- ^ Input variables encountered while traversing the object's expressions.
    , numElts        :: Int
      -- ^ Common length of all input variables (filled in by 'genProgram').
    , vertexShader   :: Shader
    , fragmentShader :: Shader
    }

-- | A host-side uniform referenced by the generated shaders, identified by
-- the 'ExprID' of the 'Uniform' atom it was created from.
data UniformVar where
    UniformVar :: GLType t => ExprID -> GLExpr HostDomain t -> UniformVar

instance HasExprID UniformVar where
    getID (UniformVar uid _) = uid

-- Equality and ordering look only at the ID, so a 'Set.Set' of uniforms
-- holds at most one entry per ID.
instance Eq UniformVar where
    a == b = getID a == getID b

instance Ord UniformVar where
    compare a b = getID a `compare` getID b

-- | A per-vertex input variable together with its data, identified by the
-- 'ExprID' of the 'Inp' atom it was created from.
data InpVar where
    InpVar :: GLInputType t =>
        ExprID -> [GLExpr ConstDomain t] -> InpVar

instance HasExprID InpVar where
    getID (InpVar iid _) = iid

-- Equality and ordering look only at the ID, so a 'Set.Set' of inputs
-- holds at most one entry per ID.
instance Eq InpVar where
    a == b = getID a == getID b

instance Ord InpVar where
    compare a b = getID a `compare` getID b

-- | Renders the vertex shader, a blank line, then the fragment shader.
-- Input and uniform variables are not included in the output.
instance Show GLProgram where
    show glProg =
        show (vertexShader glProg) ++ "\n\n" ++ show (fragmentShader glProg)


-- Intermediate code gen state

-- | State threaded through code generation.
data CGDat = CGDat
    { globalDefs :: Set.Set ExprID
      -- ^ NOTE(review): initialised empty and apparently never read or
      -- updated by the code generator; possibly vestigial.
    , scopes     :: Map.Map ScopeID Scope
      -- ^ Emitted statements and defined expression IDs, per scope.
    , funcStack  :: [(ExprID, [ExprID])]
      -- ^ Functions currently being generated (innermost first), each
      -- paired with the IDs of its parameters.
    , program    :: GLProgram
      -- ^ The program being built up.
    }

-- | Initial code-gen state for a 'GLObj'.
--
-- Every scope starts out empty and no function is being defined. The
-- program starts with empty shaders and no variables; only the object's
-- primitive mode and indices are carried over.
initCGDat :: GLObj -> CGDat
initCGDat glObj = CGDat
    { globalDefs = Set.empty
    , scopes     = Map.fromList [(sid, emptyScope) | sid <- scopeIDs]
    , funcStack  = []
    , program    = GLProgram
        { Graphics.HaGL.CodeGen.primitiveMode =
            Graphics.HaGL.GLObj.primitiveMode glObj
        , Graphics.HaGL.CodeGen.indices =
            Graphics.HaGL.GLObj.indices glObj
        , uniformVars    = Set.empty
        , inputVars      = Set.empty
        , numElts        = 0
        , vertexShader   = noShader
        , fragmentShader = noShader
        }
    }
  where
    -- One main scope per shader domain, plus the global and local scopes.
    scopeIDs = map MainScope shaderDomains ++ [GlobalScope, LocalScope]
    noShader = Shader [] [] []

-- | Identifies a scope that generated statements and definitions belong to.
-- (Constructor order is significant: it determines the derived 'Ord'.)
data ScopeID
    = MainScope GLDomain
      -- ^ The main body of the shader for the given domain.
    | GlobalScope
      -- ^ Used only to deduplicate shader-level definitions: uniforms,
      -- inputs, interpolated values and functions.
    | LocalScope
      -- ^ The body of the function currently being generated.
    deriving (Eq, Ord)

-- | Statements emitted into a scope so far, plus the IDs of the expressions
-- already defined in it (consulted by 'ifUndef').
data Scope = Scope
    { scopeExprs :: Set.Set ExprID
    , scopeStmts :: [ShaderStmt]
    }

-- | A scope with no statements and no defined expressions.
emptyScope :: Scope
emptyScope = Scope { scopeExprs = Set.empty, scopeStmts = [] }

-- | The code generation monad.
type CGState = State CGDat


-- genProgram

-- | Generates the vertex and fragment shaders for a 'GLObj'.
--
-- The resulting program throws (when forced) 'NoInputVars' if the object has
-- no input variables, 'EmptyInputVar' if any input variable is empty, and
-- 'MismatchedInputVars' if the input variables differ in length.
genProgram :: GLObj -> GLProgram
genProgram glObj = evalState gen (initCGDat glObj)
  where
    gen :: CGState GLProgram
    gen = do
        -- Traverse all three root expressions first, so that every
        -- main-scope statement they need is recorded before emission.
        posRef     <- traverseGLExpr (position glObj)
        colorRef   <- traverseGLExpr (color glObj)
        discardRef <- traverseGLExpr (discardWhen glObj)

        flushMainScope VertexDomain
        emitStmt VertexDomain (VarAsmt "gl_Position" posRef)

        flushMainScope FragmentDomain
        modifyShader FragmentDomain (addDecl (OutDecl "" "fColor" "vec4"))
        emitStmt FragmentDomain (VarAsmt "fColor" colorRef)
        emitStmt FragmentDomain (DiscardStmt discardRef)

        prog <- gets program
        return (verifyProg prog)

    -- Append one statement to the shader of the given domain.
    emitStmt :: GLDomain -> ShaderStmt -> CGState ()
    emitStmt dom stmt = modifyShader dom (addStmt stmt)

    -- Copy the statements accumulated in a domain's main scope, in order,
    -- into that domain's shader.
    flushMainScope :: GLDomain -> CGState ()
    flushMainScope dom = do
        stmts <- scopeStmts <$> getScope (MainScope dom)
        mapM_ (emitStmt dom) stmts

    -- Every input variable must be non-empty and all must share one
    -- length, which becomes the program's element count.
    verifyProg :: GLProgram -> GLProgram
    verifyProg prog =
        case map inputLength (Set.toList (inputVars prog)) of
            [] -> throw NoInputVars
            lengths@(n : rest)
                | 0 `elem` lengths -> throw EmptyInputVar
                | all (== n) rest  -> prog { numElts = n }
                | otherwise        -> throw MismatchedInputVars

    inputLength :: InpVar -> Int
    inputLength (InpVar _ xs) = length xs


-- Traversal

-- | Converts a typed expression to its 'GLAst' and generates code for it in
-- the main scope of the shader domain the expression belongs to.
traverseGLExpr :: IsGLDomain d => GLExpr d t -> CGState ShaderExpr
traverseGLExpr glExpr =
    traverseGLAst (MainScope (getShaderType glExpr)) (toGLAst glExpr)

-- | Generates shader code for a 'GLAst' node in the given scope and returns
-- the shader expression that refers to the node's value.
--
-- * Compound expressions and function applications are emitted as variable
--   declarations in @scopeID@, at most once per expression ID ('ifUndef').
-- * Uniforms, inputs and interpolated ('Frag') values become shader-level
--   declarations, deduplicated through 'GlobalScope'.
-- * Functions are emitted via 'defFn'. A function whose body is
--   @cond ? ret : f(args)@, where @f@ is the function itself, is compiled
--   to a loop ('ShaderLoopFn').
--
-- Throws 'UnsupportedNameCapture' when a function parameter is referenced
-- outside the innermost function being generated.
traverseGLAst :: ScopeID -> GLAst -> CGState ShaderExpr
traverseGLAst _ (GLAstAtom _ _ (Const x)) =
    return $ ShaderConst x
traverseGLAst _ (GLAstAtom exprID _ GenVar) = do
    -- A parameter may only be used inside the innermost function being
    -- generated. Being outside every function (empty stack) is the same
    -- error, rather than a crash from 'head' on an empty list.
    stack <- gets funcStack
    case stack of
        (_, boundParamIds) : _
            | exprID `List.elem` boundParamIds ->
                return $ ShaderVarRef $ idLabel exprID
        _ -> throw UnsupportedNameCapture
traverseGLAst _ (GLAstAtom exprID ti (Uniform x)) =
    ifUndef GlobalScope exprID $ do
        addUniformVar $ UniformVar exprID x
        modifyShader (shaderType ti) $ addDecl $
            UniformDecl (idLabel exprID) (exprType ti)
traverseGLAst _ (GLAstAtom _ ti (GenericUniform label)) = do
    -- NOTE(review): unlike Uniform/Inp this is not guarded by 'ifUndef', so
    -- each occurrence adds a UniformDecl; presumably duplicates are handled
    -- by the Shader module -- confirm.
    let safeLabel = "u_" ++ label
    modifyShader (shaderType ti) $ addDecl $
        UniformDecl safeLabel (exprType ti)
    return $ ShaderVarRef safeLabel
traverseGLAst _ (GLAstAtom exprID ti (Inp xs)) =
    ifUndef GlobalScope exprID $ do
        addInputVar $ InpVar exprID xs
        modifyShader (shaderType ti) $ addDecl $
            InpDecl "" (idLabel exprID) (exprType ti)
traverseGLAst _ (GLAstAtom exprID ti (Frag interpType x)) =
    -- A vertex value passed to the fragment shader: assign it to an output
    -- in the vertex main scope and declare the matching input in the
    -- fragment shader.
    ifUndef GlobalScope exprID $ do
        vertExpr <- traverseGLAst (MainScope VertexDomain) $ toGLAst x
        scopedStmt (MainScope VertexDomain) $
            VarAsmt (idLabel exprID) vertExpr
        modifyShader VertexDomain $ addDecl $
            OutDecl (show interpType) (idLabel exprID) (exprType ti)
        modifyShader FragmentDomain $ addDecl $
            InpDecl (show interpType) (idLabel exprID) (exprType ti)
traverseGLAst _ (GLAstAtom _ _ _) =
    error "GLAst contains disallowed atomic variable"
traverseGLAst _ (GLAstFunc fnID ti
                    (GLAstExpr _ _ "?:"
                        [cond, ret, GLAstFuncApp _ _ (GLAstFunc fnID' _ _ _) recArgs])
                    params)
    | fnID == fnID' =
        -- Tail-recursive function: compile to a loop.
        defFn fnID (map getID params) $ do
            let paramExprs = map glastToParamExpr params
            ((condExpr, updateStmts, retExpr, retStmts), condStmts) <-
                localScope $ do
                    -- The condition is generated directly in the fresh local
                    -- scope; its statements come back as condStmts.
                    condExpr <- traverseGLAst LocalScope cond
                    -- Loop update: evaluate the recursive call's arguments
                    -- and assign them to the parameters.
                    (_, updateStmts) <- innerScope $ do
                        argExprs <- mapM (traverseGLAst LocalScope) recArgs
                        mapM_ (\(ShaderParam paramName _, argExpr) ->
                                  scopedStmt LocalScope (VarAsmt paramName argExpr))
                              (zip paramExprs argExprs)
                    (retExpr, retStmts) <- innerScope $ traverseGLAst LocalScope ret
                    return (condExpr, updateStmts, retExpr, retStmts)
            modifyShader (shaderType ti) $ addFn $
                ShaderLoopFn (idLabel fnID) (exprType ti)
                    paramExprs
                    condExpr
                    retExpr
                    condStmts
                    retStmts
                    updateStmts
traverseGLAst _ (GLAstFunc fnID ti body params) =
    defFn fnID (map getID params) $ do
        let paramExprs = map glastToParamExpr params
        (bodyExpr, bodyStmts) <- localScope $ traverseGLAst LocalScope body
        modifyShader (shaderType ti) $ addFn $
            ShaderFn (idLabel fnID) (exprType ti) paramExprs bodyStmts bodyExpr
traverseGLAst scopeID (GLAstFuncApp callID ti fn args) =
    ifUndef scopeID callID $ do
        argExprs <- mapM (traverseGLAst scopeID) args
        -- Generate the callee's definition if needed; the call below refers
        -- to it by name, so the returned reference is not used.
        _ <- traverseGLAst LocalScope fn
        scopedStmt scopeID $ VarDeclAsmt (idLabel callID) (exprType ti) $
            ShaderExpr (idLabel (getID fn)) argExprs
traverseGLAst scopeID (GLAstExpr exprID ti exprName subnodes) =
    ifUndef scopeID exprID $ do
        subexprs <- mapM (traverseGLAst scopeID) subnodes
        scopedStmt scopeID $ VarDeclAsmt (idLabel exprID) (exprType ti) $
            ShaderExpr exprName subexprs


-- Scope management

-- | Runs an action in a fresh, empty local scope and returns its result
-- together with the statements it emitted there. The local scope that was
-- in effect before the call is restored afterwards (via 'innerScope').
localScope :: CGState a -> CGState (a, [ShaderStmt])
localScope action = innerScope (resetLocal >> action)
  where
    resetLocal = modifyScope LocalScope (const emptyScope)

-- | Runs an action in the local scope and collects only the statements it
-- emits.
--
-- The action still sees the expressions already defined in the local scope.
-- Afterwards the local scope is restored exactly to its prior state, which
-- discards both the new statements and any new definitions.
innerScope :: CGState a -> CGState (a, [ShaderStmt])
innerScope action = do
    saved <- getScope LocalScope
    modifyScope LocalScope (\sc -> sc { scopeStmts = [] })
    result <- action
    emitted <- scopeStmts <$> getScope LocalScope
    modifyScope LocalScope (const saved)
    return (result, emitted)

-- | Looks up a scope, yielding 'emptyScope' if it is not in the map.
getScope :: ScopeID -> CGState Scope
getScope scopeID = gets (Map.findWithDefault emptyScope scopeID . scopes)

-- | Applies a function to one scope. Does nothing if the scope is not in
-- the map.
modifyScope :: ScopeID -> (Scope -> Scope) -> CGState ()
modifyScope scopeID f =
    modify $ \st -> st { scopes = Map.adjust f scopeID (scopes st) }

-- | Runs @initFn@ only the first time the given expression ID is seen in
-- the given scope. The ID is recorded as defined before @initFn@ runs. In
-- every case, returns a reference to the variable named after the ID.
ifUndef :: ScopeID -> ExprID -> CGState () -> CGState ShaderExpr
ifUndef scopeID exprID initFn = do
    defined <- Set.member exprID . scopeExprs <$> getScope scopeID
    unless defined $ do
        modifyScope scopeID $ \sc ->
            sc { scopeExprs = Set.insert exprID (scopeExprs sc) }
        initFn
    return (ShaderVarRef (idLabel exprID))

-- | Appends a statement to the end of the given scope.
scopedStmt :: ScopeID -> ShaderStmt -> CGState ()
scopedStmt scopeID stmt = modifyScope scopeID appendStmt
  where
    appendStmt sc = sc { scopeStmts = scopeStmts sc ++ [stmt] }


-- Function construction helpers

-- | Generates a function definition at most once and returns a reference to
-- it. Deduplication is by ID in 'GlobalScope'.
--
-- While @initFn@ runs, the function and its parameter IDs are on top of
-- 'funcStack'. Throws 'UnsupportedRecCall' if the function is already on
-- the stack, i.e. it is reached again while its own body is being generated.
defFn :: ExprID -> [ExprID] -> CGState () -> CGState ShaderExpr
defFn fnID paramIds initFn = do
    active <- map fst <$> gets funcStack
    if fnID `List.elem` active
        then throw UnsupportedRecCall
        else do
            pushFn
            res <- ifUndef GlobalScope fnID initFn
            popFn
            return res
  where
    pushFn = modify $ \st -> st { funcStack = (fnID, paramIds) : funcStack st }
    -- The entry pushed above is still on top here, so the stack is non-empty.
    popFn  = modify $ \st -> st { funcStack = drop 1 (funcStack st) }

-- | Converts a function parameter into a shader parameter named after its
-- expression ID.
--
-- Parameters are expected to be 'GenVar' atoms. Anything else is reported
-- as an explicit internal error instead of an anonymous pattern-match
-- failure.
glastToParamExpr :: GLAst -> ShaderParam
glastToParamExpr (GLAstAtom exprID ti GenVar) =
    ShaderParam (idLabel exprID) (exprType ti)
glastToParamExpr _ =
    error "glastToParamExpr: function parameter is not a GenVar atom"


-- Shader modification

-- | Applies a function to the vertex or fragment shader of the program being
-- built.
--
-- Only 'VertexDomain' and 'FragmentDomain' have shaders. Any other domain is
-- reported as an explicit internal error rather than a non-exhaustive
-- pattern failure.
modifyShader :: GLDomain -> (Shader -> Shader) -> CGState ()
modifyShader VertexDomain f = modify $ \st ->
    let prog = program st
    in st { program = prog { vertexShader = f (vertexShader prog) } }
modifyShader FragmentDomain f = modify $ \st ->
    let prog = program st
    in st { program = prog { fragmentShader = f (fragmentShader prog) } }
modifyShader _ _ =
    error "modifyShader: only VertexDomain and FragmentDomain have shaders"

-- | Inserts a uniform into the program's uniform set. Entries are compared by
-- ID, so an existing entry with the same ID is replaced.
addUniformVar :: UniformVar -> CGState ()
addUniformVar unif = modify $ \st ->
    let prog = program st
    in st { program = prog { uniformVars = Set.insert unif (uniformVars prog) } }

-- | Inserts an input variable into the program's input set. Entries are
-- compared by ID, so an existing entry with the same ID is replaced.
addInputVar :: InpVar -> CGState ()
addInputVar inp = modify $ \st ->
    let prog = program st
    in st { program = prog { inputVars = Set.insert inp (inputVars prog) } }