module Shady.Language.Exp
(
Id, V(..), var, genVar
, Pat, patT, pat
, TPath, emptyP, fstP, sndP, namePath
, E(..), (:=>), (:=>*)
, op1, op2, op3, op4
, pureE, fmapE, liftE2, liftE3, liftE4
, notE
, (==^), (/=^)
, truncateE, roundE, ceilingE, floorE
, allV, anyV
, SamplerE, texture
, lit
, BoolE, FloatE, R1E, R2E, R3E, R4E, VecE
, vec2, vec3, vec4
, un2, un3, un4
, getX, getY, getZ, getW, get, (<+>)
, unitE, pairE, fstE, sndE, unPairE, uniform, uniformV
, ComplexE
, ToE(..), toE, FromE(..), toFromE, patE
, module Shady.Language.Type
, letE
)
where
import Data.Monoid (Monoid(..),First(..))
import Data.Maybe (fromMaybe)
import Control.Applicative (Applicative(pure),(<$>))
import Control.Monad (liftM2)
import Control.Arrow ((&&&),second)
import Text.PrettyPrint.Leijen hiding ((<$>),(<+>))
import Text.PrettyPrint.Leijen.PrettyPrec
import Text.PrettyPrint.Leijen.DocExpr hiding (var,apply)
import qualified Text.PrettyPrint.Leijen.DocExpr as X
import Control.Compose (result,(~>))
import Data.Boolean
import Data.VectorSpace
import Data.NameM
import Shady.Language.Type hiding ((<+>),vec2,vec3,vec4,un2,un3,un4,get)
import Shady.Language.Glom
import qualified Shady.Vec as V
import Shady.Language.Operator
import Shady.Misc
import Shady.Complex
-- Orphan instances so that First can be used with pure, >>= and do-notation
-- in the rewrite rules below (for example in catFix and the simpleN rules).
-- NOTE(review): newer releases of base ship Functor/Applicative/Monad
-- instances for Data.Monoid.First, which would make these standalone
-- derivings duplicates -- confirm against the base version targeted here.
deriving instance Functor First
deriving instance Applicative First
deriving instance Monad First
-- | Extract the value held by a First, or return the supplied default when
-- the First is empty.
fromFirst :: a -> First a -> a
fromFirst dflt (First m) = fromMaybe dflt m
-- | Variable names.
type Id = String

-- | A typed variable: a name paired with the reified type of its value.
data V a = V { varName :: Id, varType :: Type a } deriving Show

-- | Variables are syntactically equal exactly when their names agree; the
-- type component is not compared.
instance SynEq V where
  V x _ =-= V y _ = x == y

-- NOTE(review): this renders a variable through the derived Show instance
-- (full record syntax), while expr on an E Var uses only the name --
-- confirm this difference is intended.
instance HasExprU V where
  exprU v = X.var (show v)

instance HasExpr a => HasExpr (V a) where
  expr = exprU

instance HasExpr a => PrettyPrec (V a) where
  prettyPrec = prettyExpr

instance HasExpr a => Pretty (V a) where
  pretty v = prettyPrec 0 v

-- | Build a variable with the given name at the type demanded by context.
var :: HasType a => Id -> V a
var name = V name typeT

-- | Build a variable whose name is freshly generated in NameM.
genVar :: HasType a => NameM (V a)
genVar = fmap var genName
-- | A path from the root of a nested-pair type to one of its leaves.
-- Steps are stored most-recent-first: each step prepends a letter.
newtype TPath = TPath String

-- | The path that designates the root itself.
emptyP :: TPath
emptyP = TPath ""

-- | Step into the first (letter F) or second (letter S) half of a pair.
fstP, sndP :: TPath -> TPath
fstP (TPath steps) = TPath ('F' : steps)
sndP (TPath steps) = TPath ('S' : steps)

-- | Derive a leaf name from a base name and a path.  The root keeps the
-- base name unchanged; any other path is appended, root-first, after an
-- underscore.
namePath :: String -> TPath -> String
namePath vname (TPath steps)
  | null steps = vname
  | otherwise  = vname ++ '_' : reverse steps
-- | Patterns: a Glom (tree of units and pairs) whose leaves are variables.
type Pat = Glom V

-- | Recover the type that a pattern describes.
patT :: Pat a -> Type a
patT (BaseG (V _ ty)) = ty
patT UnitG            = UnitT
patT (l :* r)         = patT l :*: patT r

-- | Build a pattern for the contextual type, splitting pairs recursively and
-- naming each leaf variable by its path from the root (see namePath).
-- Function types are rejected with an error.
pat :: HasType a => String -> Pat a
pat vname = go emptyP typeT
  where
    go :: TPath -> Type s -> Pat s
    go _    UnitT      = UnitG
    go path (l :*: r)  = go (fstP path) l :* go (sndP path) r
    go _    (_ :->: _) = error "pat: function type not handled"
    go path leafTy     = BaseG (V (namePath vname path) leafTy)
infixl 9 :^

-- | Typed expressions.  The type index records the type of the denoted value.
data E :: * -> * where
  -- primitive operator or literal
  Op   :: Op a -> E a
  -- variable reference
  Var  :: V a -> E a
  -- application of a function expression to an argument
  (:^) :: HasType a =>
          E (a -> b) -> E a -> E b
  -- lambda abstraction over a variable
  Lam  :: HasType a =>
          V a -> E b -> E (a -> b)

-- | Structural equality on expressions.  Lambdas never compare equal.
instance SynEq E where
  Op o     =-= Op o'    = o =-= o'
  Var v    =-= Var v'   = v =-= v'
  (f :^ x) =-= (g :^ y) = f === g && x === y
  _        =-= _        = False

-- | Bind a variable to a value within a body, encoded as an immediately
-- applied lambda.
letE :: (HasType a, HasType b) =>
        V a -> E a -> E b -> E b
letE v rhs body = Lam v body :^ rhs
-- | Convert an expression to a pretty-printable Expr.  Applications are
-- rendered by appExpr, which also applies some cosmetic rewrites.
instance HasExpr (E a) where
  expr e = case e of
    Op o          -> X.var (show o)
    Var (V n _)   -> X.var n
    _ :^ _        -> appExpr e []
    Lam (V n _) b -> lambdaX n (expr b)
-- | Render an application spine.  The list holds the already rendered
-- arguments that were peeled off the spine, in application order.
-- Some shapes are first rewritten into more readable equivalents.  Clause
-- order matters: each rewrite must be tried before the generic peeling
-- clause, and a failed pattern guard falls through to the next clause.
appExpr :: forall a. E a -> [Expr] -> Expr
appExpr (Op o) args = opExpr o args
-- not (a < b)  is shown as  b <= a
appExpr (Op Not :^ (Op (Lt n) :^ a :^ b)) args =
  appExpr (Op (Le n) :^ b :^ a) args
-- a * recip b  is shown as  a / b
appExpr (Op Mul :^ a :^ (Op Recip :^ b)) args =
  appExpr (Op Divide :^ a :^ b) args
-- a + negate b  is shown as  a - b
appExpr (Op Add :^ a :^ (Op Negate :^ b)) args =
  appExpr (Op Sub :^ a :^ b) args
-- chains of one-element concatenations become vecN constructors
appExpr e@(Op (Cat _ _ _) :^ _ :^ _) args
  | First (Just rewritten) <- catFix e = appExpr rewritten args
-- identity swizzles are dropped
appExpr (Op (Swizzle ixs) :^ v) args
  | Just rewritten <- swizzleOpt ixs v = appExpr rewritten args
-- a beta-redex is shown as a let
appExpr (Lam v b :^ a) args = foldl ($$) (letExpr v a b) args
-- peel one argument off the spine
appExpr (f :^ e) args = appExpr f (expr e : args)
appExpr f args = foldl ($$) (expr f) args
-- | Rewrite a chain of concatenations whose left operands are single
-- elements into a vec2, vec3 or vec4 constructor application.
catFix :: a :=>? a
catFix (Op (Cat (Succ Zero) (Succ Zero) _) :^ x :^ y) =
  pure (Op VVec2 :^ x :^ y)
catFix (Op (Cat (Succ Zero) _ _) :^ x :^ rest) = do
  tailVec <- catFix rest
  consV x tailVec
catFix _ = mempty

-- | Prepend one element to a vec2 or vec3 constructor application.
consV :: One a :=> Vec n a :=>? (Vec (S n) a)
consV x (Op VVec2 :^ y :^ z)      = pure (Op VVec3 :^ x :^ y :^ z)
consV x (Op VVec3 :^ y :^ z :^ w) = pure (Op VVec4 :^ x :^ y :^ z :^ w)
consV _ _                         = mempty

-- | A swizzle is redundant when it keeps the vector length and selects
-- every index in order; in that case return the vector itself.
swizzleOpt :: forall n m a. (IsNat m, IsNat n) =>
              Vec n (Index m) -> E (Vec m a) -> Maybe (E (Vec n a))
swizzleOpt ixs v
  | Just Refl <- srcLen `natEq` dstLen
  , ixs == indices dstLen
  = Just v
  | otherwise
  = Nothing
  where
    srcLen = nat :: Nat m
    dstLen = nat :: Nat n
-- | Render a let binding of a variable to a value within a body.
letExpr :: HasType a => V a -> E a -> E b -> Expr
letExpr (V name _) rhs body = letX name (expr rhs) (expr body)

instance PrettyPrec (E a) where
  prettyPrec = prettyExpr

instance Pretty (E a) where
  pretty e = prettyPrec 0 e

instance Show (E a) where
  show e = show (pretty e)

infixr 7 :=>, :=>*
-- | A function consuming an expression.
type a :=> b = E a -> b
-- | A function from an expression to an expression.
type a :=>* b = a :=> E b

infixr 7 :=>?
-- | A rewrite attempt: may or may not produce a resulting expression.
type a :=>? b = a :=> First (E b)
-- | Apply an operator to one, two, three or four arguments, without any
-- simplification.
op1 :: (HasType a, HasType b) =>
       Op (a -> b) -> a :=>* b
op1 o x = Op o :^ x

op2 :: (HasType a, HasType b, HasType c) =>
       Op (a -> b -> c) -> a :=> b :=>* c
op2 o x y = Op o :^ x :^ y

op3 :: (HasType a, HasType b, HasType c, HasType d) =>
       Op (a -> b -> c -> d) -> a :=> b :=> c :=>* d
op3 o x y z = Op o :^ x :^ y :^ z

op4 :: (HasType a, HasType b, HasType c, HasType d, HasType e) =>
       Op (a -> b -> c -> d -> e) -> a :=> b :=> c :=> d :=>* e
op4 o w x y z = Op o :^ w :^ x :^ y :^ z

infix 0 @>
-- | Use the value held by the First, falling back to the right operand.
(@>) :: First a -> a -> a
found @> dflt = fromFirst dflt found
-- | Left-identity rule: a literal equal to the identity on the left vanishes.
identityL :: Eq a => a -> a :=> b :=>? b
identityL ident (Op (Lit x)) rhs | x == ident = pure rhs
identityL _     _            _               = mempty

-- | Right-identity rule: a literal equal to the identity on the right vanishes.
identityR :: Eq b => b -> a :=> b :=>? a
identityR ident lhs (Op (Lit y)) | y == ident = pure lhs
identityR _     _   _                         = mempty

-- | Two-sided identity rule; the left rule is preferred when both apply.
identity :: Eq a => a -> a :=> a :=>? a
identity ident x y = identityL ident x y `mappend` identityR ident x y

-- | Annihilator rule: a literal equal to the annihilator on either side makes
-- the whole result that literal.
annihilator :: Eq a => a -> a :=> a :=>? a
annihilator z (Op (Lit x)) _ | x == z = pure (pureE z)
annihilator z _ (Op (Lit y)) | y == z = pure (pureE z)
annihilator _ _ _                     = mempty
-- | Inverse-element rules for a binary operator (operands in either order):
--
-- * a + negate a  =  0
-- * a * recip a   =  1
-- * a * (-1)      =  negate a
--
-- The last rule matches the literal minus one; matching the literal one
-- here would wrongly rewrite a * 1 to negate a.
inverse :: Op (a -> a -> a) -> a :=> a :=>? a
inverse Add a (Op Negate :^ b) | a =-= b = pure 0
inverse Add (Op Negate :^ b) a | a =-= b = pure 0
inverse Mul a (Op Recip :^ b)  | a =-= b = pure 1
inverse Mul (Op Recip :^ b) a  | a =-= b = pure 1
inverse Mul a (Op (Lit (-1)))            = pure (negate a)
inverse Mul (Op (Lit (-1))) a            = pure (negate a)
inverse _ _ _                            = mempty

-- | Placeholder for commutativity-based rewrites; currently never fires.
commute :: a :=> a :=>? a
commute _ _ = mempty
#define SIMPLIFY

-- | Algebraic simplifications for unary operators (enabled by SIMPLIFY).
-- Negation distributes over sums and is absorbed into the left factor of a
-- product.  Double negation and double reciprocal cancel.  Projections of
-- a pair select the component.  cos is even and sin is odd.
simple1 :: Op (a -> b) -> a :=>? b
#ifdef SIMPLIFY
simple1 Negate (Op Negate :^ a)    = pure a
simple1 Negate (Op Mul :^ a :^ b)  = pure (negate a * b)
simple1 Negate (Op Add :^ a :^ b)  = pure (negate a + negate b)
simple1 Recip  (Op Recip :^ a)     = pure a
simple1 Fst    (Op Pair :^ a :^ _) = pure a
simple1 Snd    (Op Pair :^ _ :^ b) = pure b
simple1 Cos    (Op Negate :^ a)    = pure (cos a)
simple1 Sin    (Op Negate :^ a)    = pure (negate (sin a))
#endif
simple1 _ _ = mempty
-- | Simplifications for binary operators.  The candidate rules are combined
-- with the First monoid, so the earliest rule that fires wins.
simple2 :: Op (a -> b -> c) -> a :=> b :=>? c
#ifdef SIMPLIFY
simple2 Add = mconcat [identity 0, addMul, inverse Add, commute]
simple2 Mul = mconcat [ annihilator 0, identity 1, inverse Mul
                      , commute, mulNegNeg ]
simple2 (Cat _ _ _) = (<+?>)
#endif
simple2 _ = mempty

-- | Simplifications for ternary operators: a conditional with a literal
-- condition selects a branch, and one with identical branches is that branch.
simple3 :: Op (a -> b -> c -> d) -> a :=> b :=> c :=>? d
#ifdef SIMPLIFY
simple3 If (Op (Lit c)) a b = pure (if un1 c then a else b)
simple3 If _ a b | a =-= b  = pure a
#endif
simple3 _ _ _ _ = mempty

-- | Simplifications for quaternary operators (none at present).
simple4 :: Op (a -> b -> c -> d -> e) -> a :=> b :=> c :=> d :=>? e
#ifdef SIMPLIFY
#endif
simple4 _ = mempty
infix 1 <+?>

-- | Try to simplify a vector concatenation.
--
-- * Concatenating an expression with itself (length above one) becomes a
--   swizzle that repeats its indices.
-- * Concatenations involving swizzles of the same vector become a single
--   swizzle.
-- * A concatenation of two results of the same elementwise operator is
--   pushed inside: (f a a2) <+> (f b b2) becomes f (a <+> b) (a2 <+> b2).
--   The same holds for unary operators.
--
-- Yields nothing when no rule applies.
(<+?>) :: forall n m a.
          ( IsNat n, IsNat m, IsScalar a
          , IsNat (m :+: n), Show a ) =>
          Vec m a :=> Vec n a :=>? Vec (m :+: n) a
a <+?> b
  | n' > 1
  , Just Refl <- a =:= b
  = pure (Op (Swizzle (is V.<+> is)) :^ a)
  where
    n :: Nat n
    n  = nat
    n' = natToZ n
    is = indices n
a <+?> Op (Swizzle js) :^ b
  | Just Refl <- a =:= b
  = pure (Op (Swizzle (indices nat V.<+> js)) :^ a)
Op (Swizzle is) :^ a <+?> b
  | Just Refl <- a =:= b
  = pure (Op (Swizzle (is V.<+> indices nat)) :^ a)
Op (Swizzle is) :^ a <+?> Op (Swizzle js) :^ b
  | Just Refl <- a =:= b
  = pure (Op (Swizzle (is V.<+> js)) :^ a)
-- binary elementwise operators
Op Min    :^ a :^ a' <+?> Op Min    :^ b :^ b' = pure ((a <+> b) `min` (a' <+> b'))
Op Max    :^ a :^ a' <+?> Op Max    :^ b :^ b' = pure ((a <+> b) `max` (a' <+> b'))
Op Add    :^ a :^ a' <+?> Op Add    :^ b :^ b' = pure ((a <+> b) + (a' <+> b'))
Op Sub    :^ a :^ a' <+?> Op Sub    :^ b :^ b' = pure ((a <+> b) - (a' <+> b'))
Op Mul    :^ a :^ a' <+?> Op Mul    :^ b :^ b' = pure ((a <+> b) * (a' <+> b'))
Op Quot   :^ a :^ a' <+?> Op Quot   :^ b :^ b' = pure ((a <+> b) `quot` (a' <+> b'))
Op Rem    :^ a :^ a' <+?> Op Rem    :^ b :^ b' = pure ((a <+> b) `rem` (a' <+> b'))
Op Div    :^ a :^ a' <+?> Op Div    :^ b :^ b' = pure ((a <+> b) `div` (a' <+> b'))
Op Mod    :^ a :^ a' <+?> Op Mod    :^ b :^ b' = pure ((a <+> b) `mod` (a' <+> b'))
Op FMod   :^ a :^ a' <+?> Op FMod   :^ b :^ b' = pure ((a <+> b) `fmod` (a' <+> b'))
Op Divide :^ a :^ a' <+?> Op Divide :^ b :^ b' = pure ((a <+> b) / (a' <+> b'))
-- unary elementwise operators
Op Negate   :^ a <+?> Op Negate   :^ b = pure (negate (a <+> b))
Op Recip    :^ a <+?> Op Recip    :^ b = pure (recip (a <+> b))
Op Abs      :^ a <+?> Op Abs      :^ b = pure (abs (a <+> b))
Op Signum   :^ a <+?> Op Signum   :^ b = pure (signum (a <+> b))
Op Sqrt     :^ a <+?> Op Sqrt     :^ b = pure (sqrt (a <+> b))
Op Exp      :^ a <+?> Op Exp      :^ b = pure (exp (a <+> b))
Op Log      :^ a <+?> Op Log      :^ b = pure (log (a <+> b))
Op Sin      :^ a <+?> Op Sin      :^ b = pure (sin (a <+> b))
Op Cos      :^ a <+?> Op Cos      :^ b = pure (cos (a <+> b))
Op Asin     :^ a <+?> Op Asin     :^ b = pure (asin (a <+> b))
Op Acos     :^ a <+?> Op Acos     :^ b = pure (acos (a <+> b))
Op Sinh     :^ a <+?> Op Sinh     :^ b = pure (sinh (a <+> b))
Op Asinh    :^ a <+?> Op Asinh    :^ b = pure (asinh (a <+> b))
Op Atanh    :^ a <+?> Op Atanh    :^ b = pure (atanh (a <+> b))
Op Acosh    :^ a <+?> Op Acosh    :^ b = pure (acosh (a <+> b))
Op Truncate :^ a <+?> Op Truncate :^ b = pure (truncateE (a <+> b))
Op Round    :^ a <+?> Op Round    :^ b = pure (roundE (a <+> b))
Op Ceiling  :^ a <+?> Op Ceiling  :^ b = pure (ceilingE (a <+> b))
Op Floor    :^ a <+?> Op Floor    :^ b = pure (floorE (a <+> b))
Op Not      :^ a <+?> Op Not      :^ b = pure (notE (a <+> b))
-- elementwise comparisons (operand element types must be compatible)
Op (EqualV _) :^ a :^ a' <+?> Op (EqualV _) :^ b :^ b'
  | Just Refl <- a `compatible1` b
  = pure ((a <+> b) ==* (a' <+> b'))
Op (Lt _) :^ a :^ a' <+?> Op (Lt _) :^ b :^ b'
  | Just Refl <- a `compatible1` b
  = pure ((a <+> b) <* (a' <+> b'))
Op (Le _) :^ a :^ a' <+?> Op (Le _) :^ b :^ b'
  | Just Refl <- a `compatible1` b
  = pure ((a <+> b) <=* (a' <+> b'))
_ <+?> _ = mempty
-- | For scalar (R1) sums, fuse products into a dot product:
-- x*y + x2*y2 becomes (x <+> x2) dot (y <+> y2).  A dot product plus a
-- product is extended the same way when the vectors can grow by one element.
addMul :: forall n a.
          (IsNat n, IsScalar a, Num a) =>
          Vec n a :=> Vec n a :=>? Vec n a
addMul (Op Mul :^ x :^ y) (Op Mul :^ x' :^ y')
  | Just Refl <- (typeT :: Type (Vec n a)) `tyEq` (typeT :: Type R1)
  = pure ((x <+> x') <.> (y <+> y'))
addMul (Op Dot :^ x :^ y) (Op Mul :^ x' :^ y')
  | Just Refl <- (typeT :: Type (Vec n a)) `tyEq` (typeT :: Type R1)
  , Just CanExtend <- canExtendE x
  = pure ((x <+> x') <.> (y <+> y'))
addMul _ _ = mempty

-- | Evidence that a vector length can be extended by one element.
data CanExtend :: * -> * where
  CanExtend :: IsNat (n :+: OneT) => CanExtend n

-- | Evidence of extendability for lengths zero through three.
canExtend :: forall n. IsNat n => Maybe (CanExtend n)
canExtend =
    case (nat :: Nat n) of
      Zero                    -> yes
      Succ Zero               -> yes
      Succ (Succ Zero)        -> yes
      Succ (Succ (Succ Zero)) -> yes
      _                       -> Nothing
  where
    yes :: IsNat (m :+: OneT) => Maybe (CanExtend m)
    yes = Just CanExtend

-- | canExtend at the length of the given vector-typed value.
canExtendE :: IsNat n => f (Vec n a) -> Maybe (CanExtend n)
canExtendE _ = canExtend

-- | The product of two negations is the product of the operands.
mulNegNeg :: (IsNat n, IsScalar a, Num a) =>
             Vec n a :=> Vec n a :=>? Vec n a
mulNegNeg (Op Negate :^ x) (Op Negate :^ y) = pure (x * y)
mulNegNeg _ _                               = mempty
-- | Embed a value as a literal expression.
pureE :: Show a => a -> E a
pureE x = Op (Lit x)

-- | Apply a unary operator.  Under SIMPLIFY a literal argument is folded
-- with opVal.  Otherwise the simple1 rules are tried before building a plain
-- application.
fmapE :: (HasType a, HasType b ) =>
         Op (a -> b) -> a :=>* b
#ifdef SIMPLIFY
fmapE o (Op (Lit x)) = Op (Lit (opVal o x))
#endif
fmapE o a = fromFirst (op1 o a) (simple1 o a)

-- | Binary version of fmapE, using the simple2 rules.
liftE2 :: (HasType a, HasType b, HasType c ) =>
          Op (a -> b -> c) -> a :=> b :=>* c
#ifdef SIMPLIFY
liftE2 o (Op (Lit x)) (Op (Lit y)) = Op (Lit (opVal o x y))
#endif
liftE2 o a b = fromFirst (op2 o a b) (simple2 o a b)

-- | Ternary version of fmapE, using the simple3 rules.
liftE3 :: (HasType a, HasType b, HasType c, HasType d ) =>
          Op (a -> b -> c -> d) -> a :=> b :=> c :=>* d
#ifdef SIMPLIFY
liftE3 o (Op (Lit x)) (Op (Lit y)) (Op (Lit z)) = Op (Lit (opVal o x y z))
#endif
liftE3 o a b c = fromFirst (op3 o a b c) (simple3 o a b c)

-- | Quaternary version of fmapE, using the simple4 rules.
liftE4 :: (HasType a, HasType b, HasType c, HasType d, HasType e ) =>
          Op (a -> b -> c -> d -> e) -> a :=> b :=> c :=> d :=>* e
#ifdef SIMPLIFY
liftE4 o (Op (Lit w)) (Op (Lit x)) (Op (Lit y)) (Op (Lit z)) =
  Op (Lit (opVal o w x y z))
#endif
liftE4 o a b c d = fromFirst (op4 o a b c d) (simple4 o a b c d)
-- | Raise an error for a class method that has no meaningful expression
-- implementation.
noOv :: String -> a
noOv method = error (method ++ ": No overloading for E")

-- | Equality cannot be decided on expressions; both methods raise an error.
instance Eq (E a) where
  (==) = noOv "(==)"
  (/=) = noOv "(/=)"

-- | Only min and max build expressions; (<) raises an error.
instance (IsNat n, IsScalar a, Ord a, Show a) => Ord (E (Vec n a)) where
  min = liftE2 Min
  max = liftE2 Max
  (<) = noOv "(<)"
-- | Boolean operations on vectors of booleans, built from operators.
instance IsNat n => Boolean (VecE n Bool) where
  false = pureU False
  true  = pureU True
  notB  = fmapE Not
  (&&*) = liftE2 And
  (||*) = liftE2 Or

-- | A vector with every component equal to the given scalar.
pureU :: (IsNat n, IsScalar a) => a -> VecE n a
pureU = uniformV' . pureE . vec1

-- | Replicate a one-element vector expression to any length.
uniformV' :: (IsNat n, IsScalar a, Show a) =>
             One a :=>* Vec n a
uniformV' = fmapE (UniformV vectorT)

-- | Elementwise negation of a boolean vector.
notE :: IsNat n => Vec n Bool :=>* Vec n Bool
notE = notB

instance (IsNat n, IsScalar a, Show a) => IfB BoolE (VecE n a) where
  ifB = liftE3 If

instance (IsNat n, IsScalar a, Eq a, Show a) =>
         EqB (VecE n Bool) (VecE n a) where
  (==*) = liftE2 (EqualV nat)

instance (IsNat n, IsScalar a, Ord a, Show a) =>
         OrdB (VecE n Bool) (VecE n a) where
  (<*) = liftE2 (Lt nat)
infix 4 ==^, /=^

-- | Whole-vector equality, producing a single boolean.
(==^) :: (IsNat n, IsScalar a, Eq a, Show a) =>
         Vec n a :=> Vec n a :=>* B1
(==^) = liftE2 Equal

-- | Whole-vector inequality: the negation of (==^).
(/=^) :: (IsNat n, IsScalar a, Eq a, Show a) =>
         Vec n a :=> Vec n a :=>* B1
u /=^ v = notE (u ==^ v)
-- | Enumeration has no expression form; every method raises an error.
instance Enum a => Enum (E a) where
  succ           = noOv "succ"
  pred           = noOv "pred"
  toEnum         = noOv "toEnum"
  fromEnum       = noOv "fromEnum"
  enumFrom       = noOv "enumFrom"
  enumFromThen   = noOv "enumFromThen"
  enumFromTo     = noOv "enumFromTo"
  enumFromThenTo = noOv "enumFromThenTo"

-- | Arithmetic on vector expressions, built from simplifying operators.
-- Subtraction uses the class default (addition of a negation).
instance (IsNat n, IsScalar a, Num a) =>
         Num (E (Vec n a)) where
  fromInteger = pureE . fromInteger
  negate      = fmapE Negate
  (+)         = liftE2 Add
  (*)         = liftE2 Mul
  abs         = fmapE Abs
  signum      = fmapE Signum

instance (IsNat n, IsScalar a, Ord a, Num a) =>
         Real (E (Vec n a)) where
  toRational = noOv "toRational"

-- | Integer division operators; the paired forms apply both halves.
instance (IsNat n, IsScalar b, Integral b) =>
         Integral (E (Vec n b)) where
  quot      = liftE2 Quot
  rem       = liftE2 Rem
  div       = liftE2 Div
  mod       = liftE2 Mod
  quotRem   = both quot rem
  divMod    = both div mod
  toInteger = noOv "toInteger"
-- | Apply two binary functions to the same arguments and pair the results.
both :: (a -> b -> c) -> (a -> b -> c') -> (a -> b -> (c,c'))
both f g = \x y -> (f x y, g x y)
-- | Division via reciprocal (the class default x / y = x * recip y).
instance (IsNat n, IsScalar b, Fractional b) => Fractional (E (Vec n b)) where
  recip        = fmapE Recip
  fromRational = pureE . fromRational

-- | Transcendental functions.  The hyperbolic functions and their inverses
-- are expressed through exp, log and sqrt using the standard identities:
--
-- * sinh x  = (e^x - e^(-x)) / 2
-- * cosh x  = (e^x + e^(-x)) / 2
-- * asinh x = log (x + sqrt (x^2 + 1))
-- * acosh x = log (x + sqrt (x^2 - 1))
-- * atanh x = (log (1 + x) - log (1 - x)) / 2
instance (IsNat n, IsScalar b, Floating b) => Floating (E (Vec n b)) where
  pi      = pureE pi
  sqrt    = fmapE Sqrt
  exp     = fmapE Exp
  log     = fmapE Log
  sin     = fmapE Sin
  cos     = fmapE Cos
  asin    = fmapE Asin
  atan    = fmapE Atan
  acos    = fmapE Acos
  sinh x  = (exp x - exp (negate x)) / 2
  cosh x  = (exp x + exp (negate x)) / 2
  asinh x = log (x + sqrt (x*x + 1))
  acosh x = log (x + sqrt (x*x - 1))
  atanh x = (log (1 + x) - log (1 - x)) / 2
-- | Rounding to integers has no expression form; use truncateE and friends.
instance (IsNat n, IsScalar b, RealFrac b) => RealFrac (E (Vec n b)) where
  properFraction = noOv "properFraction"
  truncate       = noOv "truncate"
  round          = noOv "round"
  ceiling        = noOv "ceiling"
  floor          = noOv "floor"

-- | Elementwise rounding of real vectors, staying within real vectors.
truncateE, roundE, ceilingE, floorE :: IsNat n => Vec n R :=>* Vec n R
truncateE = fmapE Truncate
roundE    = fmapE Round
ceilingE  = fmapE Ceiling
floorE    = fmapE Floor

instance (IsNat n, IsScalar a, FMod a) => FMod (E (Vec n a)) where
  fmod = liftE2 FMod

instance (IsNat n, IsScalar a, FMod a, RealFrac a) => Frac (E (Vec n a)) where
  frac = fracViaFmod

-- | Whether all / any components of a boolean vector are true.
allV :: IsNat n => Vec n Bool :=>* B1
allV = fmapE AllV

anyV :: IsNat n => Vec n Bool :=>* B1
anyV = fmapE AnyV

-- | Sampler expressions.
type SamplerE n = E (Sampler n)

-- | Sample a texture at a coordinate vector, giving an R4.
texture :: IsNat n => Sampler n :=> Vec n R :=>* R4
texture = liftE2 (Texture nat)

-- | Literal expression (same as pureE).
lit :: Show a => a -> E a
lit = pureE
-- Convenient expression type aliases.
type BoolE    = E B1
type FloatE   = E R1
type R1E      = E R1
type R2E      = E R2
type R3E      = E R3
type R4E      = E R4
type VecE n a = E (Vec n a)

-- | Build vectors from one-element components by concatenation.
vec2 :: (IsScalar a, Show a) => One a :=> One a :=>* Two a
vec2 a b = a <+> b

vec3 :: (IsScalar a, Show a) => One a :=> One a :=> One a :=>* Three a
vec3 a b c = a <+> (b <+> c)

vec4 :: (IsScalar a, Show a) => One a :=> One a :=> One a :=> One a :=>* Four a
vec4 a b c d = a <+> (b <+> (c <+> d))

-- | Split small vectors into their components.
un2 :: IsScalar a => Two a :=> (E (One a), E (One a))
un2 v = (getX v, getY v)

un3 :: IsScalar a => Three a :=> (E (One a), E (One a), E (One a))
un3 v = (getX v, getY v, getZ v)

un4 :: IsScalar a => Four a :=> (E (One a), E (One a), E (One a), E (One a))
un4 v = (getX v, getY v, getZ v, getW v)

-- | Component selectors, statically requiring a long enough vector.
getX :: (IsNat n, IsScalar a, Show a) =>
        Vec (S n) a :=>* One a
getX = get index0

getY :: (IsNat n, IsScalar a, Show a) =>
        Vec (S (S n)) a :=>* One a
getY = get index1

getZ :: (IsNat n, IsScalar a, Show a) =>
        Vec (S (S (S n))) a :=>* One a
getZ = get index2

getW :: (IsNat n, IsScalar a, Show a) =>
        Vec (S (S (S (S n)))) a :=>* One a
getW = get index3

-- | Select one component by index, as a one-element swizzle.
get :: (IsNat n, IsScalar a, Show a) =>
       Index n -> (Vec n a) :=>* One a
get ix = fmapE (Swizzle (vec1 ix))

infixl 1 <+>
-- | Vector concatenation.
(<+>) :: (IsNat m, IsNat n, IsNat (m :+: n), IsScalar a, Show a) =>
         Vec m a :=> Vec n a :=>* Vec (m :+: n) a
(<+>) = liftE2 (Cat nat nat vectorT)
-- | The unit expression.
unitE :: E ()
unitE = pureE ()

-- | Pair construction and projection.
pairE :: (HasType a, HasType b) =>
         E a -> E b -> E (a,b)
pairE = liftE2 Pair

fstE :: (HasType a, HasType b ) =>
        Show a => E (a,b) -> E a
fstE = fmapE Fst

sndE :: (HasType a, HasType b ) =>
        Show b => E (a,b) -> E b
sndE = fmapE Snd

-- | Split a pair expression into its two projections.
unPairE :: (HasType a, HasType b) =>
           E (a,b) -> (E a, E b)
unPairE e = (fstE e, sndE e)

instance UnitF E where
  unit = unitE

instance PairF E where
  (#) = pairE

-- | Adapt a vector-consuming function to take a one-element vector that is
-- replicated first.
uniform :: (IsNat n, IsScalar a, Show a) =>
           (E (Vec n a) -> b) -> (E (One a) -> b)
uniform f = f . uniformV

-- | Replicate a one-element vector expression to any length.
uniformV :: (IsNat n, IsScalar a, Show a) =>
            One a :=>* Vec n a
uniformV = fmapE (UniformV vectorT)
instance (IsNat n, IsScalar a, Num a) =>
         AdditiveGroup (E (Vec n a)) where
  zeroV   = pureE 0
  (^+^)   = liftE2 Add
  negateV = fmapE Negate

-- | Scaling multiplies by the scalar replicated to the vector length.
instance (IsNat n, IsScalar a, Num a) =>
         VectorSpace (E (Vec n a)) where
  type Scalar (E (Vec n a)) = E (One a)
  s *^ u = uniformV s * u

-- | Inner product: plain multiplication for one-element vectors, otherwise
-- the Dot operator.
instance IsNat n => InnerSpace (E (Vec n R)) where
  (<.>) = case (nat :: Nat n) of
            Succ Zero -> liftE2 Mul
            _         -> liftE2 Dot

-- | Turn a pattern into the expression it denotes.
patE :: Pat a -> E a
patE (BaseG v) = Var v
patE UnitG     = unitE
patE (l :* r)  = patE l # patE r
-- | Values convertible to an expression, possibly needing fresh names.
class ToE w where
  type ExpT w
  toEN :: w -> NameM (E (ExpT w))

-- | Convert to an expression, running the name supply.
toE :: ToE w => w -> E (ExpT w)
toE w = runNameM (toEN w)

-- | Values reconstructible from an expression.
class ToE w => FromE w where
  fromE :: E (ExpT w) -> w

instance ToE (E a) where
  type ExpT (E a) = a
  toEN = return

instance FromE (E a) where
  fromE = id

instance ToE () where
  type ExpT () = ()
  toEN () = return unit

instance FromE () where
  fromE _ = ()

infixr 1 ##
-- | Pair two name-generating computations.
(##) :: (PairF f, HasType a, HasType b ) =>
        NameM (f a) -> NameM (f b) -> NameM (f (a,b))
(##) = liftM2 (#)
-- | Pairs become pair expressions.
instance ( ToE u, Show (ExpT u), HasType (ExpT u)
         , ToE v, Show (ExpT v), HasType (ExpT v)
         ) => ToE (u,v) where
  type ExpT (u,v) = (ExpT u, ExpT v)
  toEN (u,v) = toEN u ## toEN v

instance ( FromE u , HasType (ExpT u)
         , FromE v , HasType (ExpT v)
         ) => FromE (u,v) where
  fromE e = let (eu, ev) = unPairE e in (fromE eu, fromE ev)

-- | Triples become right-nested pair expressions.
instance ( ToE u , HasType (ExpT u)
         , ToE v , HasType (ExpT v)
         , ToE w , HasType (ExpT w)
         ) => ToE (u,v,w) where
  type ExpT (u,v,w) = ExpT u :# ExpT v :# ExpT w
  toEN (u,v,w) = toEN u ## toEN v ## toEN w

instance ( FromE u , HasType (ExpT u)
         , FromE v , HasType (ExpT v)
         , FromE w , HasType (ExpT w)
         ) => FromE (u,v,w) where
  fromE e = (fromE eu, fromE ev, fromE ew)
    where
      (eu, rest) = unPairE e
      (ev, ew)   = unPairE rest

-- | Functions become lambdas over a freshly named variable.
instance (FromE u, ToE v, HasType (ExpT u)) => ToE (u -> v) where
  type ExpT (u -> v) = ExpT u -> ExpT v
  toEN f = do
    x <- genVar
    body <- toEN (f (fromE (Var x)))
    return (Lam x body)

instance ToE (Pat a) where
  type ExpT (Pat a) = a
  toEN p = return (patE p)

-- | Lift a function between FromE types to one between expressions.
toFromE :: (FromE v, FromE w) => (v -> w) -> (E (ExpT v) -> E (ExpT w))
toFromE g = toE . g . fromE
-- | Complex numbers whose parts are scalar expressions.
type ComplexE a = Complex (E (One a))

-- | A complex number is represented as a two-element vector (real, imag).
instance (Show a, IsScalar a) => ToE (ComplexE a) where
  type ExpT (ComplexE a) = Two a
  toEN (re :+ im) = return (re <+> im)

instance (Show a, IsScalar a) => FromE (ComplexE a) where
  fromE v = getX v :+ getY v