module Shader (
GPU(..),
rasterizeVertex,
inputVertex,
fragmentFrontFacing,
Vertex(),
Fragment(),
ShaderInfo,
getShaders,
Real'(..),
Convert(..),
dFdx,
dFdy,
fwidth,
vSampleBinFunc,
fSampleBinFunc,
vSampleTernFunc,
fSampleTernFunc,
module Data.Boolean
) where
import System.IO.Unsafe
import Data.Vec ((:.)(..), Vec2, Vec3, Vec4, norm, normalize, dot, cross)
import qualified Data.Vec as Vec
import Data.Unique
import Data.List
import Data.Maybe
import Data.Boolean
import Control.Monad.State
import Data.Map (Map)
import qualified Data.Map as Map hiding (Map)
import qualified Data.HashTable as HT
import Control.Exception (evaluate)
import System.Mem.StableName
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet hiding (IntSet)
import Control.Arrow (first, second)
import Resources
import Formats
-- Fixity of the Real' method mod': left-associative at precedence 7,
-- the same level the Prelude gives `mod` and (*).
infixl 7 `mod'`
-- | Types that have a corresponding host-side representation from which
-- they can be constructed.
class GPU a where
    -- | The host-side type that corresponds to @a@.
    type CPU a
    -- | Build a value of type @a@ from its host-side counterpart.
    toGPU :: CPU a -> a
-- | Untyped expression tree from which shader source code is generated.
--
-- ShaderUniform: a uniform value supplied from the host.
data ShaderTree = ShaderUniform !Uniform
-- ShaderConstant: a literal constant.
                | ShaderConstant !Const
-- ShaderInput: a numbered scalar input of the current shader stage.
                | ShaderInput !Int
-- ShaderInputTree: wraps a tree of the previous stage (see 'rasterizeVertex');
-- 'splitShaders' replaces these nodes with 'ShaderInput' nodes.
                | ShaderInputTree ShaderTree
-- ShaderOp: an operation name, a code generator that is given the result
-- variable name and the argument expressions, and the operand trees.
                | ShaderOp !Op (String -> [String] -> String) [ShaderTree]
-- | Flattened shader graph as produced by 'createDAG': the node indices of the
-- requested roots, and the list of nodes, each paired with the indices of its
-- operand nodes. A node's index is its position in the node list.
type ShaderDAG = ([Int],[(ShaderTree, [Int])])
-- | A vertex-stage shader expression. The type parameter @a@ is phantom: it is
-- not stored, it only tags the expression with the type of value it denotes.
newtype Vertex a = Vertex { fromVertex :: ShaderTree }
-- | A fragment-stage shader expression. The type parameter @a@ is phantom: it
-- is not stored, it only tags the expression with the type of value it denotes.
newtype Fragment a = Fragment { fromFragment :: ShaderTree }
-- | Carry a vertex-stage float over to the fragment stage by wrapping its
-- expression tree in a 'ShaderInputTree' node.
rasterizeVertex :: Vertex Float -> Fragment Float
rasterizeVertex (Vertex tree) = Fragment (ShaderInputTree tree)
-- | The vertex-stage float read from the input with the given number.
inputVertex :: Int -> Vertex Float
inputVertex ix = Vertex (ShaderInput ix)
-- | Fragment-stage boolean that reads the GLSL built-in @gl_FrontFacing@.
fragmentFrontFacing :: Fragment Bool
fragmentFrontFacing = Fragment (ShaderOp "gl_ff" emit [])
  where
    -- The operation has no operands; it just assigns the built-in.
    emit = assign "bool" (const "gl_FrontFacing")
-- | Generate the vertex and fragment shader for a set of output expressions.
--
-- * The first argument is the vertex position, assigned to @gl_Position@.
-- * The second argument decides whether the fragment is kept: the generated
--   code discards the fragment when it evaluates to false.
-- * The third argument is the colour, assigned to @gl_FragColor@.
-- * The fourth argument, when present, is assigned to @gl_FragDepth@.
--
-- Each returned 'ShaderInfo' is built as (shader key, GLSL source, uniforms).
-- The final list holds the input numbers used by the vertex shader, in
-- ascending order.
getShaders :: Vec4 (Vertex Float) -> Fragment Bool -> Vec4 (Fragment Float) -> Maybe (Fragment Float) -> (ShaderInfo, ShaderInfo, [Int])
getShaders pos (Fragment ndisc) color mdepth = (vertexInfo, fragmentInfo, inputs)
  where
    vertexInfo   = (createShaderKey vdag, vstr, vuns)
    fragmentInfo = (createShaderKey fdag, fstr, funs)

    -- Fragment stage: flatten colour, discard test and optional depth, then
    -- cut out the vertex-stage subtrees, which become the varyings.
    fcolor    = fromFragment (fFromVec "vec4" color)
    fragRoots = fcolor : ndisc : map fromFragment (maybeToList mdepth)
    (varyings, fdag@(fcolor' : ndisc' : mdepth', _)) = splitShaders (createDAG fragRoots)

    -- Vertex stage: the position plus every varying the fragment stage needs.
    vpos = fromVertex (vFromVec "vec4" pos)
    vdag@(vpos' : varyings', _) = createDAG (vpos : varyings)
    inputs = extractInputs vdag

    -- Vertex inputs are renumbered by their position in the sorted input list.
    inputSlot i  = fromJust (elemIndex i inputs)
    vcodeAssigns = getCodeAssignments inputSlot (length inputs) "v" vdag
    vcodeFinish  = setVaryings varyings' ++ "gl_Position = t" ++ show vpos' ++ ";\n"

    fcodeAssigns = getCodeAssignments id (length varyings') "f" fdag
    depthAssign  = case mdepth' of
      [d] -> "gl_FragDepth = t" ++ show d ++ ";\n"
      []  -> ""
    fcodeFinish  = concat
      [ "if (!t", show ndisc', ") discard;\n"
      , depthAssign
      , "gl_FragColor = t", show fcolor', ";\n"
      ]

    vuns = extractUniforms vdag
    funs = extractUniforms fdag

    attributeDecl = inoutDecls "attribute" "v" (length inputs)
    varyingDecl   = inoutDecls "varying" "f" (length varyings')

    vstr = makeShader (attributeDecl ++ varyingDecl ++ uniformDecls "v" vuns) (vcodeAssigns ++ vcodeFinish)
    fstr = makeShader (varyingDecl ++ uniformDecls "f" funs) (fcodeAssigns ++ fcodeFinish)
-- | Vertex-stage texture lookup through the two-argument GLSL function @f@.
--
-- The call receives a sampler uniform built from @t@, @s@ and @tex@ and the
-- coordinate vector @c@; its @vec4@ result is split into four float
-- components that are handed to 'toColor'.
vSampleBinFunc f t s tex c = toColor (vToVec "float" 4 texel)
  where
    sampler = Vertex (ShaderUniform (UniformSampler t s tex))
    coord   = vFromVec (tName c) c
    texel   = vBinaryFunc "vec4" f sampler coord
-- | Fragment-stage texture lookup through the two-argument GLSL function @f@.
--
-- The call receives a sampler uniform built from @t@, @s@ and @tex@ and the
-- coordinate vector @c@; its @vec4@ result is split into four float
-- components that are handed to 'toColor'.
fSampleBinFunc f t s tex c = toColor (fToVec "float" 4 texel)
  where
    sampler = Fragment (ShaderUniform (UniformSampler t s tex))
    coord   = fFromVec (tName c) c
    texel   = fBinaryFunc "vec4" f sampler coord
-- | Vertex-stage texture lookup through the three-argument GLSL function @f@.
--
-- Like 'vSampleBinFunc', but the extra expression @x@ is passed as the third
-- argument of the call.
vSampleTernFunc f t s tex c x = toColor (vToVec "float" 4 texel)
  where
    sampler = Vertex (ShaderUniform (UniformSampler t s tex))
    coord   = vFromVec (tName c) c
    texel   = vTernaryFunc "vec4" f sampler coord x
-- | Fragment-stage texture lookup through the three-argument GLSL function @f@.
--
-- Like 'fSampleBinFunc', but the extra expression @x@ is passed as the third
-- argument of the call.
fSampleTernFunc f t s tex c x = toColor (fToVec "float" 4 texel)
  where
    sampler = Fragment (ShaderUniform (UniformSampler t s tex))
    coord   = fFromVec (tName c) c
    texel   = fTernaryFunc "vec4" f sampler coord x
-- Scalar host values become uniform nodes of the matching kind, in either
-- shader stage.

-- | A host 'Float' becomes a float uniform of the vertex shader.
instance GPU (Vertex Float) where
  type CPU (Vertex Float) = Float
  toGPU x = Vertex (ShaderUniform (UniformFloat x))

-- | A host 'Int' becomes an int uniform of the vertex shader.
instance GPU (Vertex Int) where
  type CPU (Vertex Int) = Int
  toGPU x = Vertex (ShaderUniform (UniformInt x))

-- | A host 'Bool' becomes a bool uniform of the vertex shader.
instance GPU (Vertex Bool) where
  type CPU (Vertex Bool) = Bool
  toGPU x = Vertex (ShaderUniform (UniformBool x))

-- | A host 'Float' becomes a float uniform of the fragment shader.
instance GPU (Fragment Float) where
  type CPU (Fragment Float) = Float
  toGPU x = Fragment (ShaderUniform (UniformFloat x))

-- | A host 'Int' becomes an int uniform of the fragment shader.
instance GPU (Fragment Int) where
  type CPU (Fragment Int) = Int
  toGPU x = Fragment (ShaderUniform (UniformInt x))

-- | A host 'Bool' becomes a bool uniform of the fragment shader.
instance GPU (Fragment Bool) where
  type CPU (Fragment Bool) = Bool
  toGPU x = Fragment (ShaderUniform (UniformBool x))
-- | The unit value is its own GPU representation.
instance GPU () where
  type CPU () = ()
  toGPU unit = unit
-- Tuples are converted component by component.

-- | Pairs of GPU values.
instance (GPU a, GPU b) => GPU (a,b) where
  type CPU (a,b) = (CPU a, CPU b)
  toGPU (x, y) = (toGPU x, toGPU y)

-- | Triples of GPU values.
instance (GPU a, GPU b, GPU c) => GPU (a,b,c) where
  type CPU (a,b,c) = (CPU a, CPU b, CPU c)
  toGPU (x, y, z) = (toGPU x, toGPU y, toGPU z)

-- | Quadruples of GPU values.
instance (GPU a, GPU b, GPU c, GPU d) => GPU (a,b,c,d) where
  type CPU (a,b,c,d) = (CPU a, CPU b, CPU c, CPU d)
  toGPU (x, y, z, w) = (toGPU x, toGPU y, toGPU z, toGPU w)
-- | A @(:.)@ cell is converted by converting its head and its tail.
instance (GPU a, GPU b) => GPU (a:.b) where
  type CPU (a:.b) = CPU a :. CPU b
  toGPU (hd :. tl) = toGPU hd :. toGPU tl
-- | Shader expressions cannot be compared on the host: both methods raise the
-- 'noFun' error. The 'EqB' instance provides '==*' and '/=*' instead.
instance Eq (Vertex a) where
    (==) = noFun "(==)"
    (/=) = noFun "(/=)"
-- | Shader expressions cannot be compared on the host: both methods raise the
-- 'noFun' error. The 'EqB' instance provides '==*' and '/=*' instead.
instance Eq (Fragment a) where
    (==) = noFun "(==)"
    (/=) = noFun "(/=)"
-- | Shader expressions cannot be shown: 'show' raises the 'noFun' error.
instance Show (Vertex a) where
    show = noFun "show"
-- | Shader expressions cannot be shown: 'show' raises the 'noFun' error.
instance Show (Fragment a) where
    show = noFun "show"
-- | Only 'min' and 'max' are usable; they map to the GLSL functions of the
-- same name. '<=' raises the 'noFun' error.
instance Ord (Vertex Float) where
  (<=) = noFun "(<=)"
  min a b = vBinaryFunc "float" "min" a b
  max a b = vBinaryFunc "float" "max" a b

-- | Only 'min' and 'max' are usable; they map to the GLSL functions of the
-- same name. '<=' raises the 'noFun' error.
instance Ord (Fragment Float) where
  (<=) = noFun "(<=)"
  min a b = fBinaryFunc "float" "min" a b
  max a b = fBinaryFunc "float" "max" a b
-- | Float arithmetic in the vertex shader. Integer literals become float
-- constants; 'signum' maps to the GLSL function @sign@.
instance Num (Vertex Float) where
  negate x      = vUnaryPreOp "float" "-" x
  x + y         = vBinaryOp "float" "+" x y
  x * y         = vBinaryOp "float" "*" x y
  fromInteger n = Vertex (ShaderConstant (ConstFloat (fromInteger n)))
  abs x         = vUnaryFunc "float" "abs" x
  signum x      = vUnaryFunc "float" "sign" x

-- | Float arithmetic in the fragment shader. Integer literals become float
-- constants; 'signum' maps to the GLSL function @sign@.
instance Num (Fragment Float) where
  negate x      = fUnaryPreOp "float" "-" x
  x + y         = fBinaryOp "float" "+" x y
  x * y         = fBinaryOp "float" "*" x y
  fromInteger n = Fragment (ShaderConstant (ConstFloat (fromInteger n)))
  abs x         = fUnaryFunc "float" "abs" x
  signum x      = fUnaryFunc "float" "sign" x
-- | No ordering operation is available for integer shader expressions through
-- 'Ord': every method listed here raises the 'noFun' error.
instance Ord (Vertex Int) where
    (<=) = noFun "(<=)"
    min = noFun "min"
    max = noFun "max"
-- | No ordering operation is available for integer shader expressions through
-- 'Ord': every method listed here raises the 'noFun' error.
instance Ord (Fragment Int) where
    (<=) = noFun "(<=)"
    min = noFun "min"
    max = noFun "max"
-- | Integer arithmetic in the vertex shader. 'abs' and 'signum' are not
-- available and raise the 'noFun' error.
instance Num (Vertex Int) where
  negate x      = vUnaryPreOp "int" "-" x
  x + y         = vBinaryOp "int" "+" x y
  x * y         = vBinaryOp "int" "*" x y
  fromInteger n = Vertex (ShaderConstant (ConstInt (fromInteger n)))
  abs           = noFun "abs"
  signum        = noFun "sign"

-- | Integer arithmetic in the fragment shader. 'abs' and 'signum' are not
-- available and raise the 'noFun' error.
instance Num (Fragment Int) where
  negate x      = fUnaryPreOp "int" "-" x
  x + y         = fBinaryOp "int" "+" x y
  x * y         = fBinaryOp "int" "*" x y
  fromInteger n = Fragment (ShaderConstant (ConstInt (fromInteger n)))
  abs           = noFun "abs"
  signum        = noFun "sign"
-- | Float division in the vertex shader; rational literals become float
-- constants.
instance Fractional (Vertex Float) where
  x / y          = vBinaryOp "float" "/" x y
  fromRational r = Vertex (ShaderConstant (ConstFloat (fromRational r)))

-- | Float division in the fragment shader; rational literals become float
-- constants.
instance Fractional (Fragment Float) where
  x / y          = fBinaryOp "float" "/" x y
  fromRational r = Fragment (ShaderConstant (ConstFloat (fromRational r)))
-- | Transcendental functions in the vertex shader, mapped to the GLSL
-- functions of the same name ('**' maps to @pow@).
--
-- The hyperbolic functions listed below are not available and raise the
-- 'noFun' error, naming the offending method exactly as the
-- 'Fragment' instance does.
instance Floating (Vertex Float) where
  pi      = Vertex $ ShaderConstant $ ConstFloat pi
  sqrt    = vUnaryFunc "float" "sqrt"
  exp     = vUnaryFunc "float" "exp"
  log     = vUnaryFunc "float" "log"
  (**)    = vBinaryFunc "float" "pow"
  sin     = vUnaryFunc "float" "sin"
  cos     = vUnaryFunc "float" "cos"
  tan     = vUnaryFunc "float" "tan"
  asin    = vUnaryFunc "float" "asin"
  acos    = vUnaryFunc "float" "acos"
  atan    = vUnaryFunc "float" "atan"
  -- 'noFun' takes only the method name; passing an extra "float" argument
  -- made the error message report "float" instead of the method.
  sinh    = noFun "sinh"
  cosh    = noFun "cosh"
  asinh   = noFun "asinh"
  atanh   = noFun "atanh"
  acosh   = noFun "acosh"
-- | Transcendental functions in the fragment shader, mapped to the GLSL
-- functions of the same name ('**' maps to @pow@). The hyperbolic functions
-- listed at the end are not available and raise the 'noFun' error.
instance Floating (Fragment Float) where
  pi       = Fragment (ShaderConstant (ConstFloat pi))
  sqrt x   = fUnaryFunc "float" "sqrt" x
  exp x    = fUnaryFunc "float" "exp" x
  log x    = fUnaryFunc "float" "log" x
  x ** y   = fBinaryFunc "float" "pow" x y
  sin x    = fUnaryFunc "float" "sin" x
  cos x    = fUnaryFunc "float" "cos" x
  tan x    = fUnaryFunc "float" "tan" x
  asin x   = fUnaryFunc "float" "asin" x
  acos x   = fUnaryFunc "float" "acos" x
  atan x   = fUnaryFunc "float" "atan" x
  sinh     = noFun "sinh"
  cosh     = noFun "cosh"
  asinh    = noFun "asinh"
  atanh    = noFun "atanh"
  acosh    = noFun "acosh"
-- | Real-valued helper functions named after their GLSL counterparts.
--
-- Only 'floor'' and 'ceiling'' must be defined by an instance; every other
-- method has a default expressed through 'Ord', 'Floating' and those two.
class (Ord a, Floating a) => Real' a where
  -- | Reciprocal square root.
  rsqrt :: a -> a
  -- | Two raised to the given power.
  exp2 :: a -> a
  -- | Base-2 logarithm.
  log2 :: a -> a
  -- | Largest whole number not above the argument.
  floor' :: a -> a
  -- | Smallest whole number not below the argument.
  ceiling' :: a -> a
  -- | Fractional part: @x - floor' x@.
  fract' :: a -> a
  -- | Floored modulus: @x - y * floor' (x / y)@.
  mod' :: a -> a -> a
  -- | @clamp x a b@ limits @x@ to the range from @a@ to @b@.
  clamp :: a -> a -> a -> a
  -- | Limit the argument to the range from 0 to 1.
  saturate :: a -> a
  -- | @mix x y a@ interpolates linearly from @x@ (at @a = 0@) to @y@ (at @a = 1@).
  mix :: a -> a -> a-> a
  -- | @step a x@ is 0 when @x < a@ and 1 otherwise.
  step :: a -> a -> a
  -- | @smoothstep a b x@ is the cubic Hermite interpolation @t*t*(3 - 2*t)@
  -- of @t = saturate ((x - a) / (b - a))@.
  smoothstep :: a -> a -> a -> a

  rsqrt = (1/) . sqrt
  exp2 = (2**)
  log2 = logBase 2
  clamp x a b = min (max x a) b
  saturate x = clamp x 0 1
  -- The subtraction operators in the four defaults below were missing, which
  -- left the expressions ill-formed.
  mix x y a = x * (1 - a) + y * a
  step a x
    | x < a     = 0
    | otherwise = 1
  smoothstep a b x =
    let t = saturate ((x - a) / (b - a))
    in  t * t * (3 - 2 * t)
  fract' x = x - floor' x
  mod' x y = x - y * floor' (x / y)

-- | Host-side 'Float's use the class defaults.
instance Real' Float where
  floor' x   = fromIntegral (floor x :: Integer)
  ceiling' x = fromIntegral (ceiling x :: Integer)

-- | Host-side 'Double's use the class defaults.
instance Real' Double where
  floor' x   = fromIntegral (floor x :: Integer)
  ceiling' x = fromIntegral (ceiling x :: Integer)
-- | In the vertex shader every 'Real'' method listed here is emitted as a
-- call to a GLSL function ('rsqrt' as @inversesqrt@, 'ceiling'' as @ceil@,
-- the rest under their own names). 'saturate' keeps its class default.
instance Real' (Vertex Float) where
  rsqrt x          = vUnaryFunc "float" "inversesqrt" x
  exp2 x           = vUnaryFunc "float" "exp2" x
  log2 x           = vUnaryFunc "float" "log2" x
  floor' x         = vUnaryFunc "float" "floor" x
  ceiling' x       = vUnaryFunc "float" "ceil" x
  fract' x         = vUnaryFunc "float" "fract" x
  mod' x y         = vBinaryFunc "float" "mod" x y
  clamp x lo hi    = vTernaryFunc "float" "clamp" x lo hi
  mix x y a        = vTernaryFunc "float" "mix" x y a
  step edge x      = vBinaryFunc "float" "step" edge x
  smoothstep a b x = vTernaryFunc "float" "smoothstep" a b x
-- | In the fragment shader every 'Real'' method listed here is emitted as a
-- call to a GLSL function ('rsqrt' as @inversesqrt@, 'ceiling'' as @ceil@,
-- the rest under their own names). 'saturate' keeps its class default.
instance Real' (Fragment Float) where
  rsqrt x          = fUnaryFunc "float" "inversesqrt" x
  exp2 x           = fUnaryFunc "float" "exp2" x
  log2 x           = fUnaryFunc "float" "log2" x
  floor' x         = fUnaryFunc "float" "floor" x
  ceiling' x       = fUnaryFunc "float" "ceil" x
  fract' x         = fUnaryFunc "float" "fract" x
  mod' x y         = fBinaryFunc "float" "mod" x y
  clamp x lo hi    = fTernaryFunc "float" "clamp" x lo hi
  mix x y a        = fTernaryFunc "float" "mix" x y a
  step edge x      = fBinaryFunc "float" "step" edge x
  smoothstep a b x = fTernaryFunc "float" "smoothstep" a b x
-- | Boolean logic in the vertex shader, using the GLSL operators
-- @!@, @&&@ and @||@.
instance Boolean (Vertex Bool) where
  true    = Vertex (ShaderConstant (ConstBool True))
  false   = Vertex (ShaderConstant (ConstBool False))
  notB x  = vUnaryPreOp "bool" "!" x
  a &&* b = vBinaryOp "bool" "&&" a b
  a ||* b = vBinaryOp "bool" "||" a b

-- | Boolean logic in the fragment shader, using the GLSL operators
-- @!@, @&&@ and @||@.
instance Boolean (Fragment Bool) where
  true    = Fragment (ShaderConstant (ConstBool True))
  false   = Fragment (ShaderConstant (ConstBool False))
  notB x  = fUnaryPreOp "bool" "!" x
  a &&* b = fBinaryOp "bool" "&&" a b
  a ||* b = fBinaryOp "bool" "||" a b
-- | Equality tests evaluated in the vertex shader (@==@ and @!=@).
instance Eq a => EqB (Vertex Bool) (Vertex a) where
  a ==* b = vBinaryOp "bool" "==" a b
  a /=* b = vBinaryOp "bool" "!=" a b

-- | Equality tests evaluated in the fragment shader (@==@ and @!=@).
instance Eq a => EqB (Fragment Bool) (Fragment a) where
  a ==* b = fBinaryOp "bool" "==" a b
  a /=* b = fBinaryOp "bool" "!=" a b
-- | Ordering tests evaluated in the vertex shader.
instance Ord a => OrdB (Vertex Bool) (Vertex a) where
  a <* b  = vBinaryOp "bool" "<" a b
  a >=* b = vBinaryOp "bool" ">=" a b
  a >* b  = vBinaryOp "bool" ">" a b
  a <=* b = vBinaryOp "bool" "<=" a b

-- | Ordering tests evaluated in the fragment shader.
instance Ord a => OrdB (Fragment Bool) (Fragment a) where
  a <* b  = fBinaryOp "bool" "<" a b
  a >=* b = fBinaryOp "bool" ">=" a b
  a >* b  = fBinaryOp "bool" ">" a b
  a <=* b = fBinaryOp "bool" "<=" a b
-- Conditional selection in the vertex shader. Each instance emits the GLSL
-- ternary expression @cond?yes:no@ with the result type named in 'assign'.
-- The operand order in the node is condition, then-branch, else-branch.

-- | Select between two ints.
instance IfB (Vertex Bool) (Vertex Int) where
  ifB cond yes no =
    Vertex (ShaderOp "if"
                     (assign "int" (\[c, t, e] -> c ++ "?" ++ t ++ ":" ++ e))
                     [fromVertex cond, fromVertex yes, fromVertex no])

-- | Select between two floats.
instance IfB (Vertex Bool) (Vertex Float) where
  ifB cond yes no =
    Vertex (ShaderOp "if"
                     (assign "float" (\[c, t, e] -> c ++ "?" ++ t ++ ":" ++ e))
                     [fromVertex cond, fromVertex yes, fromVertex no])

-- | Select between two bools.
instance IfB (Vertex Bool) (Vertex Bool) where
  ifB cond yes no =
    Vertex (ShaderOp "if"
                     (assign "bool" (\[c, t, e] -> c ++ "?" ++ t ++ ":" ++ e))
                     [fromVertex cond, fromVertex yes, fromVertex no])
-- Conditional selection in the fragment shader. Each instance emits the GLSL
-- ternary expression @cond?yes:no@ with the result type named in 'assign'.
-- The operand order in the node is condition, then-branch, else-branch.

-- | Select between two ints.
instance IfB (Fragment Bool) (Fragment Int) where
  ifB cond yes no =
    Fragment (ShaderOp "if"
                       (assign "int" (\[c, t, e] -> c ++ "?" ++ t ++ ":" ++ e))
                       [fromFragment cond, fromFragment yes, fromFragment no])

-- | Select between two floats.
instance IfB (Fragment Bool) (Fragment Float) where
  ifB cond yes no =
    Fragment (ShaderOp "if"
                       (assign "float" (\[c, t, e] -> c ++ "?" ++ t ++ ":" ++ e))
                       [fromFragment cond, fromFragment yes, fromFragment no])

-- | Select between two bools.
instance IfB (Fragment Bool) (Fragment Bool) where
  ifB cond yes no =
    Fragment (ShaderOp "if"
                       (assign "bool" (\[c, t, e] -> c ++ "?" ++ t ++ ":" ++ e))
                       [fromFragment cond, fromFragment yes, fromFragment no])
-- | Conversion of a numeric type to its float and int counterparts.
class Convert a where
    -- | The float type that corresponds to @a@.
    type ConvertFloat a
    -- | The int type that corresponds to @a@.
    type ConvertInt a
    -- | Convert to the float counterpart.
    toFloat :: a -> ConvertFloat a
    -- | Convert to the int counterpart.
    toInt :: a -> ConvertInt a
-- | Host-side floats: 'toInt' truncates towards zero.
instance Convert Float where
  type ConvertFloat Float = Float
  type ConvertInt Float = Int
  toFloat x = x
  toInt x   = truncate x

-- | Host-side ints: 'toFloat' is 'fromIntegral'.
instance Convert Int where
  type ConvertFloat Int = Float
  type ConvertInt Int = Int
  toFloat n = fromIntegral n
  toInt n   = n
-- | Vertex-stage floats: 'toInt' emits the GLSL conversion @int(..)@.
instance Convert (Vertex Float) where
  type ConvertFloat (Vertex Float) = Vertex Float
  type ConvertInt (Vertex Float) = Vertex Int
  toFloat x = x
  toInt x   = vUnaryFunc "int" "int" x

-- | Vertex-stage ints: 'toFloat' emits the GLSL conversion @float(..)@.
instance Convert (Vertex Int) where
  type ConvertFloat (Vertex Int) = Vertex Float
  type ConvertInt (Vertex Int) = Vertex Int
  toFloat n = vUnaryFunc "float" "float" n
  toInt n   = n
-- | Fragment-stage floats: 'toInt' emits the GLSL conversion @int(..)@.
instance Convert (Fragment Float) where
  type ConvertFloat (Fragment Float) = Fragment Float
  type ConvertInt (Fragment Float) = Fragment Int
  toFloat x = x
  toInt x   = fUnaryFunc "int" "int" x

-- | Fragment-stage ints: 'toFloat' emits the GLSL conversion @float(..)@.
instance Convert (Fragment Int) where
  type ConvertFloat (Fragment Int) = Fragment Float
  type ConvertInt (Fragment Int) = Fragment Int
  toFloat n = fUnaryFunc "float" "float" n
  toInt n   = n
-- | Emit a call to the GLSL function @dFdx@.
dFdx :: Fragment Float -> Fragment Float
dFdx x = fUnaryFunc "float" "dFdx" x

-- | Emit a call to the GLSL function @dFdy@.
dFdy :: Fragment Float -> Fragment Float
dFdy x = fUnaryFunc "float" "dFdy" x

-- | Emit a call to the GLSL function @fwidth@.
fwidth :: Fragment Float -> Fragment Float
fwidth x = fUnaryFunc "float" "fwidth" x
-- | Length of a 4-vector in the fragment shader, via GLSL @length@.
normF4 :: Vec4 (Fragment Float) -> Fragment Float
normF4 v = fUnaryFunc "float" "length" (fFromVec "vec4" v)

-- | Length of a 3-vector in the fragment shader, via GLSL @length@.
normF3 :: Vec3 (Fragment Float) -> Fragment Float
normF3 v = fUnaryFunc "float" "length" (fFromVec "vec3" v)

-- | Length of a 2-vector in the fragment shader, via GLSL @length@.
normF2 :: Vec2 (Fragment Float) -> Fragment Float
normF2 v = fUnaryFunc "float" "length" (fFromVec "vec2" v)
-- | Length of a 4-vector in the vertex shader, via GLSL @length@.
normV4 :: Vec4 (Vertex Float) -> Vertex Float
normV4 v = vUnaryFunc "float" "length" (vFromVec "vec4" v)

-- | Length of a 3-vector in the vertex shader, via GLSL @length@.
normV3 :: Vec3 (Vertex Float) -> Vertex Float
normV3 v = vUnaryFunc "float" "length" (vFromVec "vec3" v)
-- | Length of a 2-vector in the vertex shader, via GLSL @length@.
--
-- The two components are packed with the @vec2@ constructor. The previous
-- code named @vec3@ here, which would emit a three-component constructor
-- with only two arguments.
normV2 :: Vec2 (Vertex Float) -> Vertex Float
normV2 = vUnaryFunc "float" "length" . vFromVec "vec2"
-- | Normalize a 4-vector in the fragment shader via GLSL @normalize@ and
-- split the result back into its components.
normalizeF4 :: Vec4 (Fragment Float) -> Vec4 (Fragment Float)
normalizeF4 v = fToVec "float" 4 (fUnaryFunc "vec4" "normalize" (fFromVec "vec4" v))

-- | Normalize a 3-vector in the fragment shader via GLSL @normalize@ and
-- split the result back into its components.
normalizeF3 :: Vec3 (Fragment Float) -> Vec3 (Fragment Float)
normalizeF3 v = fToVec "float" 3 (fUnaryFunc "vec3" "normalize" (fFromVec "vec3" v))

-- | Normalize a 2-vector in the fragment shader via GLSL @normalize@ and
-- split the result back into its components.
normalizeF2 :: Vec2 (Fragment Float) -> Vec2 (Fragment Float)
normalizeF2 v = fToVec "float" 2 (fUnaryFunc "vec2" "normalize" (fFromVec "vec2" v))
-- | Normalize a 4-vector in the vertex shader via GLSL @normalize@ and split
-- the result back into its components.
normalizeV4 :: Vec4 (Vertex Float) -> Vec4 (Vertex Float)
normalizeV4 v = vToVec "float" 4 (vUnaryFunc "vec4" "normalize" (vFromVec "vec4" v))

-- | Normalize a 3-vector in the vertex shader via GLSL @normalize@ and split
-- the result back into its components.
normalizeV3 :: Vec3 (Vertex Float) -> Vec3 (Vertex Float)
normalizeV3 v = vToVec "float" 3 (vUnaryFunc "vec3" "normalize" (vFromVec "vec3" v))

-- | Normalize a 2-vector in the vertex shader via GLSL @normalize@ and split
-- the result back into its components.
normalizeV2 :: Vec2 (Vertex Float) -> Vec2 (Vertex Float)
normalizeV2 v = vToVec "float" 2 (vUnaryFunc "vec2" "normalize" (vFromVec "vec2" v))
-- | Dot product of two 4-vectors in the fragment shader, via GLSL @dot@.
dotF4 :: Vec4 (Fragment Float) -> Vec4 (Fragment Float) -> Fragment Float
dotF4 a b = fBinaryFunc "float" "dot" lhs rhs
  where
    lhs = fFromVec "vec4" a
    rhs = fFromVec "vec4" b

-- | Dot product of two 3-vectors in the fragment shader, via GLSL @dot@.
dotF3 :: Vec3 (Fragment Float) -> Vec3 (Fragment Float) -> Fragment Float
dotF3 a b = fBinaryFunc "float" "dot" lhs rhs
  where
    lhs = fFromVec "vec3" a
    rhs = fFromVec "vec3" b

-- | Dot product of two 2-vectors in the fragment shader, via GLSL @dot@.
dotF2 :: Vec2 (Fragment Float) -> Vec2 (Fragment Float) -> Fragment Float
dotF2 a b = fBinaryFunc "float" "dot" lhs rhs
  where
    lhs = fFromVec "vec2" a
    rhs = fFromVec "vec2" b
-- | Dot product of two 4-vectors in the vertex shader, via GLSL @dot@.
dotV4 :: Vec4 (Vertex Float) -> Vec4 (Vertex Float) -> Vertex Float
dotV4 a b = vBinaryFunc "float" "dot" lhs rhs
  where
    lhs = vFromVec "vec4" a
    rhs = vFromVec "vec4" b

-- | Dot product of two 3-vectors in the vertex shader, via GLSL @dot@.
dotV3 :: Vec3 (Vertex Float) -> Vec3 (Vertex Float) -> Vertex Float
dotV3 a b = vBinaryFunc "float" "dot" lhs rhs
  where
    lhs = vFromVec "vec3" a
    rhs = vFromVec "vec3" b

-- | Dot product of two 2-vectors in the vertex shader, via GLSL @dot@.
dotV2 :: Vec2 (Vertex Float) -> Vec2 (Vertex Float) -> Vertex Float
dotV2 a b = vBinaryFunc "float" "dot" lhs rhs
  where
    lhs = vFromVec "vec2" a
    rhs = vFromVec "vec2" b
-- | Cross product of two 3-vectors in the fragment shader, via GLSL @cross@;
-- the @vec3@ result is split back into its components.
crossF3 :: Vec3 (Fragment Float) -> Vec3 (Fragment Float) -> Vec3 (Fragment Float)
crossF3 a b = fToVec "float" 3 product3
  where
    product3 = fBinaryFunc "vec3" "cross" (fFromVec "vec3" a) (fFromVec "vec3" b)

-- | Cross product of two 3-vectors in the vertex shader, via GLSL @cross@;
-- the @vec3@ result is split back into its components.
crossV3 :: Vec3 (Vertex Float) -> Vec3 (Vertex Float) ->Vec3 (Vertex Float)
crossV3 a b = vToVec "float" 3 product3
  where
    product3 = vBinaryFunc "vec3" "cross" (vFromVec "vec3" a) (vFromVec "vec3" b)
-- | Placeholder for class methods that cannot be provided for shader
-- expressions: raises an error whose message starts with the given name.
noFun :: String -> a
noFun what = error (what ++ ": No overloading for Vertex/Fragment")
-- | Emit the assignments that pack node variables into the varying slots.
--
-- Each element @x@ of the argument names the variable @t<x>@. The variables
-- are taken four at a time; group number @n@ is assigned to @f<n>@ using the
-- constructor returned by 'tName'' for the size of the group.
setVaryings xs = concat (zipWith emit [0 :: Integer ..] (groupsOfFour vars))
  where
    vars = map (\x -> 't' : show x) xs
    groupsOfFour [] = []
    groupsOfFour vs = let (grp, rest) = splitAt 4 vs in grp : groupsOfFour rest
    emit n grp =
      "f" ++ show n ++ " = " ++ tName' (length grp) ++ "(" ++ intercalate "," grp ++ ");\n"
-- | Declarations for @i@ scalar slots, packed four to a variable.
--
-- @t@ is the storage qualifier placed in front of each declaration and @n@
-- the variable name prefix; variables are numbered from 0. Every full group
-- of four is declared as @vec4@. A final partial group is declared with the
-- type 'tName'' gives for its size.
inoutDecls t n i = inoutDecls' i 0
  where
    inoutDecls' remaining slot
      -- The recursive call must consume the four slots just declared; the
      -- subtraction operator was missing here.
      | remaining >= 4 = t ++ " vec4 " ++ n ++ show slot ++ ";\n" ++ inoutDecls' (remaining - 4) (slot + 1)
      | remaining == 0 = ""
      | otherwise      = t ++ " " ++ tName' remaining ++ " " ++ n ++ show slot ++ ";\n"
-- | Uniform array declarations for one shader.
--
-- @p@ is the name prefix of the shader stage. One array is declared for the
-- float, int and bool uniforms and one per sampler type, each sized by the
-- number of uniforms of that kind; kinds with no uniforms are omitted.
uniformDecls :: String -> UniformSet -> String
uniformDecls p (f,i,b,s) = concat (scalarDecls ++ samplerDecls)
  where
    scalarDecls =
      [ decl "float" "f" (length f)
      , decl "int" "i" (length i)
      , decl "bool" "b" (length b)
      ]
    samplerDecls =
      [ decl (sampName t) ('s' : show (fromEnum t)) (length xs)
      | (t, xs) <- Map.toList s
      ]
    decl _ _ 0 = ""
    decl ty tag count = "uniform " ++ ty ++ " " ++ p ++ "u" ++ tag ++ "[" ++ show count ++ "];\n"
-- | Assemble a complete shader source: the @#version 120@ line, the given
-- declarations, and a @main@ function whose body is the given statements.
makeShader decls body =
  concat
    [ "#version 120\n"
    , decls
    , "void main(){\n"
    , body
    , "}\n"
    ]
-- | Reduce a shader graph to its key: the same graph with every node replaced
-- by a key node. Uniform values and the code generators of operations are
-- dropped; input numbers, constants and operation names are kept.
--
-- The graph must not contain 'ShaderInputTree' nodes; run 'splitShaders'
-- first.
createShaderKey :: ShaderDAG -> ShaderKey
createShaderKey (roots, nodes) = (roots, map (first keyOf) nodes)
  where
    keyOf node = case node of
      ShaderUniform _   -> ShaderKeyUniform
      ShaderInput ix    -> ShaderKeyInput ix
      ShaderConstant c  -> ShaderKeyConstant c
      ShaderOp op _ _   -> ShaderKeyOp op
      ShaderInputTree _ -> error "Use splitShaders first"
-- | Separate a graph from the previous-stage trees it refers to.
--
-- Every 'ShaderInputTree' node is replaced by a 'ShaderInput' node numbered in
-- order of appearance. The wrapped trees are returned in that same order, so
-- input @k@ of the resulting graph corresponds to the @k@-th returned tree.
splitShaders :: ShaderDAG -> ([ShaderTree], ShaderDAG)
splitShaders (roots, nodes) =
  case mapAccumL step [] nodes of
    (found, nodes') -> (reverse found, (roots, nodes'))
  where
    -- @found@ holds the trees seen so far, most recent first.
    step found entry@(node, kids) = case node of
      ShaderInputTree sub -> (sub : found, (ShaderInput (length found), kids))
      _                   -> (found, entry)
-- | Flatten a list of expression trees into a 'ShaderDAG'.
--
-- Nodes are identified by their 'StableName', so a node object that is
-- reachable along several paths is emitted only once. The result holds the
-- node index of each requested root and the nodes in index order.
--
-- NOTE(review): this relies on 'unsafePerformIO'; the indices depend on which
-- nodes are shared as heap objects at the time of the call.
createDAG :: [ShaderTree] -> ShaderDAG
-- The state list is built newest-first, hence the final 'reverse'.
createDAG = second reverse . unsafePerformIO . startDAG
    where startDAG xs = do ht <- HT.new (==) (fromIntegral . hashStableName)
                           runStateT (mapM (createDAG' ht) xs) []
          -- Returns the index of the node, adding it (and its operands) to the
          -- state if it has not been seen before.
          createDAG' :: HT.HashTable (StableName ShaderTree) Int -> ShaderTree -> StateT [(ShaderTree, [Int])] IO Int
          -- The node is evaluated before its stable name is taken.
          createDAG' ht n = do n' <- liftIO $ evaluate n
                               k <- liftIO $ makeStableName n'
                               m <- liftIO $ HT.lookup ht k
                               case m of
                                   -- Already emitted: reuse its index.
                                   Just i -> return i
                                   -- Only 'ShaderOp' nodes have operands to recurse into; the
                                   -- tree inside a 'ShaderInputTree' is not traversed here.
                                   Nothing -> do xs' <- case n' of
                                                            ShaderOp _ _ xs -> mapM (createDAG' ht) xs
                                                            _ -> return []
                                                 -- Operands were added first, so the new node's index is
                                                 -- the number of nodes emitted so far.
                                                 ys <- get
                                                 let y = length ys
                                                 liftIO $ HT.insert ht k y
                                                 put $ (n',xs'):ys
                                                 return y
-- | Collect the uniform values of a graph: the float, int and bool uniforms as
-- lists and the sampler uniforms grouped by sampler type. Within each list the
-- values appear in the order of their nodes in the graph.
extractUniforms :: ShaderDAG -> UniformSet
extractUniforms (_, nodes) = foldl' collect ([], [], [], Map.empty) (reverse (map fst nodes))
  where
    -- Values are prepended while walking the nodes back to front, which
    -- leaves each list in front-to-back order.
    collect acc@(fs, is, bs, ss) node = case node of
      ShaderUniform (UniformFloat x)         -> (x : fs, is, bs, ss)
      ShaderUniform (UniformInt x)           -> (fs, x : is, bs, ss)
      ShaderUniform (UniformBool x)          -> (fs, is, x : bs, ss)
      ShaderUniform (UniformSampler t s tex) -> (fs, is, bs, Map.insertWith' (++) t [(s,tex)] ss)
      _                                      -> acc
-- | The distinct input numbers used by a graph, in ascending order.
extractInputs :: ShaderDAG -> [Int]
extractInputs (_, nodes) =
  IntSet.toAscList (IntSet.fromList [ix | (ShaderInput ix, _) <- nodes])
-- | Generate the statements that compute every node of a graph.
--
-- Node @n@ is assigned to the variable @t<n>@. @inF@ maps an input number to
-- its slot, @numIns@ is the total number of input slots and @inName@ the name
-- prefix used for inputs and uniforms of this shader stage.
getCodeAssignments :: (Int -> Int) -> Int -> String -> ShaderDAG -> String
-- The accumulator holds the number of float, int and bool uniforms and the
-- per-type number of samplers seen so far, plus a map from node index to the
-- expression that is inlined wherever that node is used.
getCodeAssignments inF numIns inName (_,xs) = concat $ snd $ mapAccumL getCode ((0,0,0,Map.empty),Map.empty) $ zip [0..] xs
    -- Scalar uniforms read the next element of the matching uniform array.
    where getCode ((f,i,b,s),inlns) (n, (ShaderUniform (UniformFloat _), _)) = (((f+1,i,b,s),inlns), assign "float" (const $ inName ++ "uf[" ++ show f ++ "]") (var n) [])
          getCode ((f,i,b,s),inlns) (n, (ShaderUniform (UniformInt _), _)) = (((f,i+1,b,s),inlns), assign "int" (const $ inName ++ "ui[" ++ show i ++ "]") (var n) [])
          getCode ((f,i,b,s),inlns) (n, (ShaderUniform (UniformBool _), _)) = (((f,i,b+1,s),inlns), assign "bool" (const $ inName ++ "ub[" ++ show b ++ "]") (var n) [])
          -- Samplers emit no statement: the array element is recorded and
          -- inlined at each use. The first sampler of a type gets index 0.
          getCode ((f,i,b,s),inlns) (n, (ShaderUniform (UniformSampler t _ _), _)) =
              case first (fromMaybe 0) $ Map.insertLookupWithKey (const $ const (+1)) t 1 s of
                  (x, s') -> (((f,i,b,s'),Map.insert n (inName ++ "us" ++ show (fromEnum t) ++ "[" ++ show x ++ "]") inlns), "")
          getCode x (n, (ShaderConstant (ConstFloat f), _)) = (x, assign "float" (const $ show f) (var n) [])
          getCode x (n, (ShaderConstant (ConstInt i), _)) = (x, assign "int" (const $ show i) (var n) [])
          getCode x (n, (ShaderConstant (ConstBool b), _)) = (x, assign "bool" (const $ if b then "true" else "false") (var n) [])
          -- Inputs are always read as floats from their packed slot.
          getCode x (n, (ShaderInput i, _)) = (x, assign "float" (const $ inName ++ inoutAccessor (inF i) numIns) (var n) [])
          -- Operations run their own code generator on the operand names.
          getCode x@(_,inlns) (n, (ShaderOp _ f _, xs)) = (x, f (var n) (map (varMaybeInline inlns) xs))
          getCode _ (_, (ShaderInputTree _, _)) = error "Shader.getCodeAssignments: Use splitShaders first!"
          -- Name of the variable that holds node n.
          var n = 't' : show n
          -- Expression for node n: the inlined text if one was recorded,
          -- otherwise the node's variable.
          varMaybeInline inlns n = case Map.lookup n inlns of Just str -> str
                                                              Nothing -> var n
          -- Slot i lives in variable number (i div 4), component (i mod 4).
          -- When it is the last slot and the first of its group, the
          -- variable is used as a whole, without a component selector.
          inoutAccessor i tot = case divMod i 4 of (d,m) -> if i+1 == tot && m == 0 then show d else show d ++ "." ++ (["x","y","z","w"]!!m)
-- | GLSL type name of a sampler type.
sampName sampler = case sampler of
  Sampler3D   -> "sampler3D"
  Sampler2D   -> "sampler2D"
  Sampler1D   -> "sampler1D"
  SamplerCube -> "samplerCube"
-- | GLSL type name for a vector value, chosen by its number of components.
tName vec = tName' (Vec.length vec)
-- | GLSL type name for a given number of float components: @float@ for one
-- component, otherwise @vec@ followed by the number.
tName' size
  | size == 1 = "float"
  | otherwise = "vec" ++ show size
-- | Build a declaration statement @<type> <variable>=<expression>;@.
--
-- The arguments are the type name, a function that renders the right-hand
-- side from the operand expressions, the variable name and the operands.
assign :: String -> ([String] -> String) -> String -> [String] -> String
assign ty render target operands = concat [ty, " ", target, "=", render operands, ";\n"]
-- | Render a call to the two-argument function @s@ over a list of operands.
--
-- Two operands give @s(a,b)@ and a single operand is returned unchanged.
-- Longer lists are combined pairwise with nested calls. The operand list must
-- not be empty.
binFunc :: String -> [String] -> String
binFunc s operands = head (collapse operands)
  where
    call l r = s ++ "(" ++ l ++ "," ++ r ++ ")"
    -- Combine the first two operands, reduce the remainder, then reduce again
    -- until at most one expression is left.
    collapse (l : r : rest) = collapse (call l r : collapse rest)
    collapse short          = short
-- Builders for vertex-stage operation nodes. In each of them @t@ is the GLSL
-- type of the result and @s@ the operator or function name, which also serves
-- as the operation name of the node.

-- | Infix operator: @a<s>b@.
vBinaryOp t s (Vertex l) (Vertex r) = Vertex (ShaderOp s emit [l, r])
  where emit = assign t (intercalate s)

-- | Prefix operator: @<s>a@.
vUnaryPreOp t s (Vertex x) = Vertex (ShaderOp s emit [x])
  where emit = assign t (\args -> s ++ head args)

-- | Postfix operator: @a<s>@.
vUnaryPostOp t s (Vertex x) = Vertex (ShaderOp s emit [x])
  where emit = assign t (\args -> head args ++ s)

-- | One-argument function call: @<s>(a)@.
vUnaryFunc t s (Vertex x) = Vertex (ShaderOp s emit [x])
  where emit = assign t (\args -> s ++ "(" ++ head args ++ ")")

-- | Two-argument function call, rendered by 'binFunc'.
vBinaryFunc t s (Vertex l) (Vertex r) = Vertex (ShaderOp s emit [l, r])
  where emit = assign t (binFunc s)

-- | Three-argument function call: @<s>(a,b,c)@.
vTernaryFunc t s (Vertex x) (Vertex y) (Vertex z) = Vertex (ShaderOp s emit [x, y, z])
  where emit = assign t (\[p, q, r] -> s ++ "(" ++ p ++ "," ++ q ++ "," ++ r ++ ")")

-- | Pack the components of a vector into one value with the constructor
-- named @t@. The node's operation name is the empty string.
vFromVec t vec = Vertex (ShaderOp "" emit (map fromVertex (Vec.toList vec)))
  where emit = assign t (\args -> t ++ "(" ++ intercalate "," args ++ ")")
-- | Split a vertex-stage vector value into its @n@ components.
--
-- Component @k@ (for @k@ from 0 to @n - 1@) is a node of type @t@ that
-- indexes the vector as @a[k]@; the index text also serves as the operation
-- name of the node. The upper bound of the range had lost its subtraction
-- operator and is restored here.
vToVec t n a = Vec.fromList (map component indices)
  where
    indices = [show k | k <- [0 .. n - 1]]
    component ix =
      Vertex $ ShaderOp ix (assign t (\[x] -> x ++ "[" ++ ix ++ "]")) [fromVertex a]
-- Builders for fragment-stage operation nodes. In each of them @t@ is the
-- GLSL type of the result and @s@ the operator or function name, which also
-- serves as the operation name of the node.

-- | Infix operator: @a<s>b@.
fBinaryOp t s (Fragment l) (Fragment r) = Fragment (ShaderOp s emit [l, r])
  where emit = assign t (intercalate s)

-- | Prefix operator: @<s>a@.
fUnaryPreOp t s (Fragment x) = Fragment (ShaderOp s emit [x])
  where emit = assign t (\args -> s ++ head args)

-- | Postfix operator: @a<s>@.
fUnaryPostOp t s (Fragment x) = Fragment (ShaderOp s emit [x])
  where emit = assign t (\args -> head args ++ s)

-- | One-argument function call: @<s>(a)@.
fUnaryFunc t s (Fragment x) = Fragment (ShaderOp s emit [x])
  where emit = assign t (\args -> s ++ "(" ++ head args ++ ")")

-- | Two-argument function call, rendered by 'binFunc'.
fBinaryFunc t s (Fragment l) (Fragment r) = Fragment (ShaderOp s emit [l, r])
  where emit = assign t (binFunc s)

-- | Three-argument function call: @<s>(a,b,c)@.
fTernaryFunc t s (Fragment x) (Fragment y) (Fragment z) = Fragment (ShaderOp s emit [x, y, z])
  where emit = assign t (\[p, q, r] -> s ++ "(" ++ p ++ "," ++ q ++ "," ++ r ++ ")")

-- | Pack the components of a vector into one value with the constructor
-- named @t@. The node's operation name is the empty string.
fFromVec t vec = Fragment (ShaderOp "" emit (map fromFragment (Vec.toList vec)))
  where emit = assign t (\args -> t ++ "(" ++ intercalate "," args ++ ")")
-- | Split a fragment-stage vector value into its @n@ components.
--
-- Component @k@ (for @k@ from 0 to @n - 1@) is a node of type @t@ that
-- indexes the vector as @a[k]@; the index text also serves as the operation
-- name of the node. The upper bound of the range had lost its subtraction
-- operator and is restored here.
fToVec t n a = Vec.fromList (map component indices)
  where
    indices = [show k | k <- [0 .. n - 1]]
    component ix =
      Fragment $ ShaderOp ix (assign t (\[x] -> x ++ "[" ++ ix ++ "]")) [fromFragment a]