module TypedFlow.TF (
parameter',
parameter,
parameterDefault,
ParamWithDefault(..),
getParameters,
persistent,
modifyPersistent,
placeholder,
peekAt,
peekAtMany,
zeros,
ones,
constant,
round, sigmoid, tanh, log, relu, floor, negate,
add, (+), (⊕), (⊝), (⊙), (⊘), equal,
(∙), (·), matmul,
reduceMeanAll, reduceSumAll,
reduceSum, reduceMean,
argmax, argmax0, argmax1,
softmax0, softmax1,
grad,
clipByGlobalNorm,
clipByValue,
last0, nth0, nth0', gather,
split0, slice, slice1,
stack0, unstack0, stackN,
stack1,
concatT, concat0, concat1,
expandDim,
expandDim0, squeeze0,
expandDim1, squeeze1,
flatten2, inflate2, flattenN2,
flatten3, inflate3,
reshape, flattenAll, inflateAll,
transpose, transposeN, transposeN', transpose01, transposeN01,
reverseSequences, sequenceMask,
cast,
convolution,
oneHot, oneHot0, oneHot1,
if_, where_,
mapT, mapTN, zipWithT, zipWithTN,
sigmoidCrossEntropyWithLogits,
softmaxCrossEntropyWithLogits,
sparseSoftmaxCrossEntropyWithLogits,
truncatedNormal, randomUniform, randomOrthogonal, varianceScaling, glorotUniform,
repeatT, flattenHTV, inflateHTV, KnownTensors(..), LastEqual
) where
import Prelude hiding (tanh,Num(..),Floating(..),round,floor)
import qualified Prelude
import Prelude ((-))
import Text.PrettyPrint.Compact hiding (Last, All,Product,Sum)
import GHC.TypeLits
import Data.Proxy
import TypedFlow.Types
import Control.Monad (when)
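-- | Fill a heterogeneous vector of tensors (one per shape in @ss@) by
-- repeating a single shape-polymorphic tensor expression.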
repeatT :: forall (ss :: [Shape]) t. All KnownShape ss => KnownLen ss =>
           (forall s. KnownShape s => T s t) -> HTV t ss
repeatT f = zs (shapeSList @ss)
  where zs :: forall (s :: [Shape]). All KnownShape s => SList s -> HTV t s
        zs LZ = Unit
        zs (LS _ n) = F f :* zs n
zeros :: ∀ t (shape :: Shape). KnownShape shape => KnownTyp t => (T shape t)
zeros = T (funcall "tf.zeros" [showShape @shape, named "dtype" (showTyp @t)])
ones :: ∀ t (shape :: Shape). KnownShape shape => KnownTyp t => (T shape t)
ones = T (funcall "tf.ones" [showShape @shape, named "dtype" (showTyp @t)])
constant :: forall s w. KnownShape s => KnownBits w => Float -> T s ('Typ 'Float w)
constant c = T (funcall "tf.constant" [float c, named "shape" (showShape @s), named "dtype" (showTyp @(Flt w))])
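-- | Declare a variable which persists between calls to @session.run@.
-- When the first argument is 'True', the variable is also registered as
-- a trainable parameter.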
persistent :: ∀ (shape :: Shape) t. (KnownTyp t,KnownShape shape) => Bool -> String -> T shape t -> Gen (T shape t)
persistent trainable name (T initial) = do
  v <- newVar
  when trainable (newParameter (ParamInfo name (shapeToList @shape) (typVal @t) (T v)))
  v <-- funcall "tf.Variable" [initial, named "name" (string (show (name))), named "trainable" (bool trainable)]
  return (T v)
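-- | Declare a parameter to optimize: a trainable persistent variable.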
parameter' :: ∀ (shape :: Shape) t. (KnownTyp t,KnownShape shape) => String -> T shape t -> Gen (T shape t)
parameter' = persistent True
peekAt :: String -> Tensor s t -> Gen ()
peekAt p (T v) = peekAtAny p v
peekAtMany :: String -> HTV t xs -> Gen ()
peekAtMany p htv = peekAtAny p (list $ htoList $ hmap (\(F (T x)) -> K x) htv)
modifyPersistent :: T s t -> T s t -> T s t
modifyPersistent (T ref) (T value) = T (funcall "tf.assign" [ref,value])
getParameters :: Gen UntypedExpression
getParameters = do
  v <- newVar
  v <-- text "tf.trainable_variables()"
  return v
grad :: T s Float32 -> UntypedExpression -> UntypedExpression
grad (T y) vars = funcall "tf.gradients" [y, vars]
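-- | Clip by the global norm: the trailing @[0]@ selects the list of
-- clipped tensors from the pair returned by @tf.clip_by_global_norm@
-- (the second component is the global norm itself).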
clipByGlobalNorm :: Float -> UntypedExpression -> UntypedExpression
clipByGlobalNorm maxNorm x = funcall "tf.clip_by_global_norm" [x,float maxNorm] <> brackets (int 0)
clipByValue :: Float -> Float -> T s (Flt t) -> T s (Flt t)
clipByValue lo hi (T x) = T (funcall "tf.clip_by_value" [x, float lo, float hi])
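-- | Declare a placeholder: an input to be fed at run time. The
-- placeholder is also registered (via 'peekAt') under its name.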
placeholder :: ∀t s. (KnownShape s, KnownTyp t) => String -> Gen (T s t)
placeholder n = do
  let name = text n
  name <-- funcall "tf.placeholder" [showTyp @t, named "shape" (showShape @s), named "name" (text (show n))]
  peekAt n (T name)
  return (T name)
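-- A minimal usage sketch combining 'parameter'', 'placeholder' and ('∙').
-- The sizes (784 inputs, 10 outputs, batch size 100) and the names "w"
-- and "x" are illustrative only.
--
-- > logits :: Gen (T '[10,100] Float32)
-- > logits = do
-- >   w <- parameter' "w" (glorotUniform @784 @10)
-- >   x <- placeholder "x" :: Gen (T '[784,100] Float32)
-- >   return (w ∙ x)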
reduceAll :: String -> Tensor s t -> Tensor '[] t
reduceAll op = unOp ("tf.reduce_" ++ op)
reduceMeanAll, reduceSumAll :: ∀ (s :: Shape) t. Tensor s t -> Tensor '[] t
reduceMeanAll = reduceAll "mean"
reduceSumAll = reduceAll "sum"
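-- Note: the first element of a type-level shape denotes the innermost
-- (last) TensorFlow axis, hence axis arguments of the form
-- "listLen @s - peanoInt @n - 1" below.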
reduce :: ∀ n s t. (KnownLen s,KnownPeano n) => String -> T s t -> T (Take n s ++ Drop ('Succ n) s) t
reduce op (T x) = T (funcall ("tf.reduce_" ++ op) [x, text "axis=" <> integer (listLen @s - peanoInt @n - 1)])
reduceSum, reduceMean :: ∀n s t. (KnownLen s,KnownPeano n) => T s t -> T (Take n s ++ Drop ('Succ n) s) t
reduceSum = reduce @n "sum"
reduceMean = reduce @n "mean"
reduceSum0 :: ∀ s' n t. KnownLen s' => Tensor (n ': s') t -> Tensor s' t
reduceSum0 = reduceSum @Dim0
add :: ∀ s d t. Tensor (d++s) t -> Tensor d t -> Tensor (d++s) t
add = binOp "tf.add"
(+) :: ∀ (d :: Shape) (s :: Shape) t. Tensor (d ++ s) t -> Tensor d t -> Tensor (d ++ s) t
(+) = add @s @d
infixl 6 +
equal :: Tensor d t -> Tensor d t -> Tensor d TFBool
equal = binOp "tf.equal"
(⊕), (⊝), (⊙), (⊘) :: ∀ (s :: Shape) t. Tensor s t -> Tensor s t -> Tensor s t
(⊝) = binOp "tf.subtract"
(⊙) = binOp "tf.multiply"
(⊘) = binOp "tf.divide"
(⊕) = binOp "tf.add"
infixl 7 ⊙,⊘
infixl 6 ⊕,⊝
matmul :: Tensor (o ': n ': s) t -> Tensor (m ': o ': s) t -> Tensor (m ': n ': s) t
matmul = binOp "tf.matmul"
round, sigmoid, tanh, log, relu, floor
  :: ∀ s t. Tensor s ('Typ 'Float t) -> Tensor s ('Typ 'Float t)
sigmoid = unOp "tf.sigmoid"
tanh = unOp "tf.tanh"
log = unOp "tf.log"
relu = unOp "tf.nn.relu"
round = unOp "tf.round"
floor = unOp "tf.floor"
negate :: ∀ s t. T s t -> T s t
negate = unOp "-"
split0 :: ∀ n m batchShape t. (KnownNat n, KnownNat m, KnownLen batchShape) =>
          Tensor ((n + m) ': batchShape) t -> Gen (Tensor (n ': batchShape) t, Tensor (m ': batchShape) t)
split0 (T x) = do
  v1 <- newVar
  v2 <- newVar
  gen (v1 <> text "," <> v2 <> text " = " <> funcall "tf.split" [x, list [showDim @ n, showDim @ m], text "axis=" <> showShapeLen @batchShape])
  return (T v1, T v2)
concatT :: ∀ n d1 d2 s t. (KnownPeano n, KnownLen s, (d1+d2) ~ At n s) =>
           T (Take n s ++ (d1 ': Drop ('Succ n) s)) t -> T (Take n s ++ (d2 ': Drop ('Succ n) s)) t -> T s t
concatT (T x) (T y) = T (funcall "tf.concat" [list [x,y], named "axis" (integer (listLen @s - peanoInt @n - 1))])
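-- | Concatenate two tensors along their first (type-level) dimension.
-- A shape-level sketch, with illustrative sizes:
--
-- > ex :: T '[3,7] Float32 -> T '[5,7] Float32 -> T '[8,7] Float32
-- > ex = concat0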
concat0 :: ∀ ys d1 d2 t. (KnownLen ys) => T (d1 ': ys) t -> T (d2 ': ys) t -> T ((d1 + d2) ': ys) t
concat0 = concatT @Dim0
concat1 :: ∀ n ys d1 d2 t. (KnownLen ys) => T (n ': d1 ': ys) t -> T (n ': d2 ': ys) t -> T (n ': (d1 + d2) ': ys) t
concat1 = concatT @Dim1
expandDim :: forall n s t. (KnownLen s, KnownPeano n) => Tensor s t -> Tensor (Take n s ++ (1 ': Drop n s)) t
expandDim (T x) = (T (funcall "tf.expand_dims" [x, named "axis" (integer (listLen @s - peanoInt @n))]))
expandDim0 :: ∀ s t. KnownLen s => Tensor s t -> Tensor (1 ': s) t
expandDim0 = expandDim @Dim0
expandDim1 :: ∀ n s t. KnownShape s => Tensor (n ': s) t -> Tensor (n ': 1 ': s) t
expandDim1 = expandDim @Dim1
squeeze :: ∀ s0 s1 t. KnownLen s1 => Tensor (s0 ++ (1 ': s1)) t -> Tensor (s0 ++ s1) t
squeeze (T x) = T (funcall "tf.squeeze" [x, text "axis=" <> integer (listLen @ s1)])
squeeze0 :: ∀ s t. KnownLen s => Tensor (1 ': s) t -> Tensor s t
squeeze0 = squeeze @ '[]
squeeze1 :: ∀ n s t. KnownLen s => Tensor (n ': 1 ': s) t -> Tensor (n ': s) t
squeeze1 = squeeze @ '[n]
reshape :: ∀ s2 s1 t. KnownShape s2 => Product s1 ~ Product s2 => Tensor s1 t -> Tensor s2 t
reshape = unsafeReshape
unsafeReshape :: ∀ s2 s1 t. KnownShape s2 => Tensor s1 t -> Tensor s2 t
unsafeReshape (T t) = T (funcall "tf.reshape" [t, showShapeMinus @s2])
flatten2 :: ∀ m n s t. (KnownNat m, KnownNat n, KnownShape s) => Tensor (m ': n ': s) t -> Tensor (m*n ': s) t
flatten2 = prodAssoc @m @n @(Product s) reshape
flattenN2 :: ∀ s m n t. (KnownNat m, KnownNat n, KnownShape s) => Tensor (s ++ '[m,n]) t -> Tensor (s ++ '[m*n]) t
flattenN2 = prodHomo @s @'[m,n] $
            prodHomo @s @'[m*n] $
            knownAppend @s @'[m*n] $
            reshape
flatten3 :: ∀ m n o s t. (KnownNat m, KnownNat n, KnownNat o, KnownShape s) => Tensor (m ': n ': o ': s) t -> Tensor (m*n*o ': s) t
flatten3 =
  prodAssoc @m @n @(o * Product s) $
  prodAssoc @(m * n) @o @(Product s) $
  reshape
inflate2 :: ∀ m n s t. (KnownNat m, KnownNat n, KnownShape s) => Tensor (m*n ': s) t -> Tensor (m ': n ': s) t
inflate2 = prodAssoc @m @n @(Product s) reshape
inflate3 :: ∀ m n o s t. (KnownNat m, KnownNat n, KnownNat o, KnownShape s) => Tensor (m*n*o ': s) t -> Tensor (m ': n ': o ': s) t
inflate3 =
  prodAssoc @m @n @(o * Product s) $
  prodAssoc @(m * n) @o @(Product s) $
  reshape
last0 :: ∀ n s t. KnownNat n => KnownLen s => T (n ': s) t -> Tensor s t
last0 = nth0 (natVal (Proxy @n) - 1)
nth0 :: ∀ n s t. KnownLen s => Integer -> T (n ': s) t -> Tensor s t
nth0 i (T x) = T (x <> list (replicate (fromIntegral (listLen @s)) (text ":") ++ [integer i]))
nth0' :: ∀ n m s t. KnownNat n => KnownLen s => n < m => T (m ': s) t -> Tensor s t
nth0' (T x) = T (x <> list (replicate (fromIntegral (listLen @s)) (text ":") ++ [integer (natVal (Proxy @n))]))
slice :: forall n i j s t. KnownNat j => KnownNat i => (i < j, j <= At n s, KnownPeano n, KnownLen s) =>
         Tensor s t -> Tensor (Take n s ++ ((j-i) ': Drop ('Succ n) s)) t
slice (T x) = T (x <> list (replicate (fromIntegral (listLen @s - peanoInt @n - 1)) (text ":") ++ [integer (natVal (Proxy @i)) <> text ":" <> integer (natVal (Proxy @j))]))
slice1 :: forall i j m n s t. KnownNat j => KnownNat i => (i < j, j <= m, KnownLen s) =>
          Tensor (n ': m ': s) t -> Tensor (n ': (j-i) ': s) t
slice1 = slice @Dim1 @i @j
unstack0 :: ∀ s (n::Nat) t. (KnownLen s, KnownNat n) => Tensor (n ': s) t -> Gen (V n (T s t))
unstack0 (T x) = do
  v <- newVar
  v <-- funcall "tf.unstack" [x, text "axis=" <> integer (listLen @ s)]
  return $ V $ [ T $ v <> brackets (integer i)| i <- [0..n Prelude.- 1] ]
 where n = natVal (Proxy @ n)
stack0 :: ∀ s (n::Nat) t. (KnownLen s) => V n (T s t) -> Tensor (n ': s) t
stack0 (V xs) = T (funcall "tf.stack" [list [x | T x <- xs], text "axis=" <> integer (listLen @ s)])
stack1 :: ∀ s (n::Nat) m t. (KnownLen s) => V n (T (m ': s) t) -> Tensor (m ': n ': s) t
stack1 (V xs) = T (funcall "tf.stack" [list [x | T x <- xs], text "axis=" <> integer (listLen @ s)])
stackN :: ∀ s (n::Nat) t. V n (T s t) -> Tensor (s ++ '[n]) t
stackN (V xs) = T (funcall "tf.stack" [list [x | T x <- xs], text "axis=0"])
transpose :: ∀ s t. T (Reverse s) t -> T s t
transpose = unOp "tf.transpose"
transposeN :: ∀ s n t. KnownLen s => T (n ': s) t -> T (s ++ '[n]) t
transposeN (T x) = T (funcall "tf.transpose" [x, named "perm" (list (map integer (listLen @s:[0.. listLen @s - 1])))])
transposeN' :: ∀ s n t. KnownLen s => T (s ++ '[n]) t -> T (n ': s) t
transposeN' (T x) = T (funcall "tf.transpose" [x, named "perm" (list (map integer ([1.. listLen @s]++[0])))])
transpose01 :: ∀ s m n t. KnownLen s => T (m ': n ': s) t -> T (n ': m ': s) t
transpose01 (T x) = T (funcall "tf.transpose" [x, named "perm" (list (map integer ([0..l - 1] ++ [l Prelude.+ 1,l])))])
  where l = listLen @s
transposeN01 :: ∀ s m n t. T (s ++ [m,n]) t -> T (s ++ [n,m]) t
transposeN01 (T x) = T (funcall "tf.transpose" [x, named "perm" (list (map integer [1,0]))])
class LastEqual x xs
instance LastEqual x (x ': '[])
instance LastEqual x (y2 ': xs) => LastEqual x (y ': (y2 ': xs))
reverseSequences :: forall bs n x t. KnownLen x => LastEqual bs x => T '[bs] Int32 -> T (n ': x) t -> T (n ': x) t
reverseSequences (T seqLengths) (T input) =
T (funcall "tf.reverse_sequence" [input, seqLengths, named "seq_axis" (showShapeLen @x),named "batch_axis" (int 0)])
sequenceMask :: forall maxlen bs. KnownNat maxlen => Tensor '[bs] Int32 -> Tensor '[maxlen,bs] TFBool
sequenceMask (T x) = T (funcall "tf.sequence_mask" [x, named "maxlen" (showDim @maxlen)])
gather :: ∀s n indexShape t. T (s ++ '[n]) t -> T indexShape Int32 -> T (s ++ indexShape) t
gather = binOp "tf.gather"
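-- | Convolution with SAME padding (the output has the same spatial size
-- as the input). The filter may have at most three spatial dimensions;
-- the data format is chosen from its rank.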
convolution :: forall outputChannels filterSpatialShape inChannels s t.
               KnownLen filterSpatialShape
            => Length filterSpatialShape <= 3
            => ((1 + Length filterSpatialShape) ~ Length s)
            => T ('[inChannels] ++ s) t
            -> T ('[outputChannels,inChannels] ++ filterSpatialShape) t
            -> T ('[outputChannels] ++ s) t
convolution (T input) (T filters) = T (funcall "tf.nn.convolution" [input,filters
                                                                   ,named "padding" (text (show "SAME"))
                                                                   ,named "data_format" (text (show dataFormat))])
  where dataFormat = case listLen @ filterSpatialShape of
          1 -> "NWC"
          2 -> "NHWC"
          3 -> "NDHWC"
          _ -> error "convolution: more than 3 spatial dimensions are not supported!"
softmax0 :: T (n ': s) ('Typ 'Float w) -> T (n ': s) ('Typ 'Float w)
softmax0 = unOp "tf.nn.softmax"
softmax1 :: forall n m s w. KnownLen s => T (m ': n ': s) ('Typ 'Float w) -> T (m ': n ': s) ('Typ 'Float w)
softmax1 (T x) = T (funcall "tf.nn.softmax" [x, named "dim" (showShapeLen @s)])
argmax :: forall n u m s t. (KnownLen s, KnownPeano n,KnownBits u) => Tensor (Take n s ++ (m ': Drop n s)) t -> Tensor s ('Typ 'Int u)
argmax (T t) = T (funcall "tf.argmax" [t, named "axis" (integer ((listLen @ s) - peanoInt @n)) , named "output_type" (showTyp @('Typ 'Int u))])
argmax0 :: forall u n s t. (KnownLen s, KnownBits u) => T (n ': s) t -> T s ('Typ 'Int u)
argmax0 = argmax @Dim0
argmax1 :: forall u m n s t. (KnownLen s, KnownBits u) => T (m ': n ': s) t -> T (m ': s) ('Typ 'Int u)
argmax1 = argmax @Dim1
cast :: forall u s t. KnownTyp u => T s t -> T s u
cast (T t) = T (funcall "tf.cast" [t, showTyp @ u])
softmaxCrossEntropyWithLogits :: Tensor '[numClasses,batchSize] Float32
                              -> Tensor '[numClasses,batchSize] Float32
                              -> Tensor '[batchSize] Float32
softmaxCrossEntropyWithLogits (T labels) (T logits) =
  T (funcall "tf.nn.softmax_cross_entropy_with_logits" [named "labels" labels,named "logits" logits])
sigmoidCrossEntropyWithLogits :: Tensor s (Flt w)
                              -> Tensor s (Flt w)
                              -> Tensor s (Flt w)
sigmoidCrossEntropyWithLogits (T labels) (T logits) =
  T (funcall "tf.nn.sigmoid_cross_entropy_with_logits" [named "labels" labels,named "logits" logits])
sparseSoftmaxCrossEntropyWithLogits :: Tensor s Int32
                                    -> Tensor (numClasses ': s) (Flt t)
                                    -> Tensor s (Flt t)
sparseSoftmaxCrossEntropyWithLogits (T labels) (T logits) =
  T (funcall "tf.nn.sparse_softmax_cross_entropy_with_logits" [named "labels" labels,named "logits" logits])
oneHot :: forall n numClasses s w t. KnownNat numClasses => KnownBits t =>
          (KnownLen s, KnownPeano n) => Tensor s ('Typ 'Int w) -> Tensor (Take n s ++ (numClasses ': Drop n s)) (Flt t)
oneHot (T x) = T (funcall "tf.one_hot" [x, named "depth" (showDim @numClasses), named "axis" (integer (listLen @s - peanoInt @n)), named "dtype" (showTyp @(Flt t))])
oneHot0 :: forall numClasses w batchSize t. KnownNat numClasses => KnownBits t => Tensor '[batchSize] ('Typ 'Int w) -> Tensor '[numClasses,batchSize] (Flt t)
oneHot0 = oneHot @Dim0
oneHot1 :: forall numClasses w batchSize m t. KnownNat numClasses => KnownBits t => Tensor '[m,batchSize] ('Typ 'Int w) -> Tensor '[m,numClasses,batchSize] (Flt t)
oneHot1 = oneHot @Dim1
truncatedNormal :: forall s w. KnownShape s => KnownBits w => Float -> T s ('Typ 'Float w)
truncatedNormal stddev = T (funcall "tf.truncated_normal" [showShape @s, named "stddev" (float stddev), named "dtype" (showTyp @(Flt w))])
randomUniform :: forall s t. (KnownShape s, KnownTyp t) => Float -> Float -> T s t
randomUniform low high = T (funcall "tf.random_uniform" [showShape @s
                                                        ,named "minval" (float low)
                                                        ,named "maxval" (float high)
                                                        ,named "dtype" (showTyp @t)])
randomOrthogonal :: forall n s t. (KnownBits t, KnownNat n, KnownShape s) => T (n ':s) ('Typ 'Float t)
randomOrthogonal = T (funcall' (funcall "tf.orthogonal_initializer" [named "dtype" (showTyp @('Typ 'Float t))])
                               [named "shape" (showShape @(n ': s))])
data VarianceScaleMode = VSFanIn | VSFanOut | VSAvg
data Distrib = NormalDistr | UniformDistr
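-- | Variance-scaling random initializer: values are drawn from a uniform
-- or truncated-normal distribution whose spread is derived from the
-- fan-in, the fan-out, or their average (see @limit@ below).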
varianceScaling :: forall inDim outDim t. KnownNat inDim => (KnownNat outDim, KnownBits t) =>
                   Float -> VarianceScaleMode -> Distrib -> Tensor '[inDim,outDim] ('Typ 'Float t)
varianceScaling factor mode distr = case distr of
    UniformDistr -> randomUniform (-limit) limit
    NormalDistr -> truncatedNormal limit
  where
    fan_in = fromIntegral (natVal (Proxy @inDim))
    fan_out = fromIntegral (natVal (Proxy @outDim))
    n = max 1 $ case mode of
          VSFanIn -> fan_in
          VSFanOut -> fan_out
          VSAvg -> (fan_in Prelude.+ fan_out) / 2
    limit = Prelude.sqrt ((case distr of NormalDistr -> 1.3; UniformDistr -> 3) Prelude.* factor / n)
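-- | Glorot (Xavier) uniform initializer: variance scaling with factor 1,
-- uniform distribution, averaging fan-in and fan-out.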
glorotUniform :: forall inDim outDim t. KnownNat inDim => (KnownNat outDim, KnownBits t) => Tensor '[inDim,outDim] ('Typ 'Float t)
glorotUniform = varianceScaling 1 VSAvg UniformDistr
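-- | Product of a weight matrix with a batched vector.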
(∙) :: Tensor '[cols, rows] t -> Tensor '[cols,batchSize] t -> Tensor '[rows,batchSize] t
m ∙ v = matmul v (transpose m)
infixl 7 ∙
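-- | Dot product along the @cols@ dimension of two batched vectors.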
(·) :: ∀ cols batchSize t. Tensor '[cols,batchSize] t -> Tensor '[cols,batchSize] t -> Tensor '[batchSize] t
x · y = reduceSum0 (x ⊙ y)
infixl 7 ·
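-- | Map a tensor function over the first dimension (implemented by
-- transposing, applying 'mapTN', and transposing back). A sketch with
-- illustrative sizes, mapping 'sigmoid' over a leading dimension of 5:
--
-- > exMap :: T '[5,20] Float32 -> Gen (T '[5,20] Float32)
-- > exMap = mapT sigmoid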
mapT :: forall s t r u n. KnownTyp u => KnownLen r => KnownLen s => (T s t -> T r u) -> T (n ': s) t -> Gen (T (n ': r) u)
mapT f x = do
  x' <- mapTN @n f (transposeN @s @n x)
  return (transposeN' @r x')
mapTN :: forall n s t r u. KnownTyp u => (T s t -> T r u) -> T (s ++ '[n]) t -> Gen(T (r ++ '[n]) u)
mapTN f t = do
  fn <- lambda f
  return (T (funcall "tf.map_fn" [fn, fromTensor t, named "dtype" (showTyp @u)]))
zipWithT :: forall (s :: [Nat]) (t :: Typ) (s1 :: [Nat]) (t1 :: Typ) (s2 :: Shape) (n :: Nat) (t2 :: Typ).
            KnownNat n => (KnownLen s, KnownLen s2, KnownLen s1) => KnownTyp t2 =>
            (T s t -> T s1 t1 -> T s2 t2)
            -> Tensor (n ': s) t
            -> Tensor (n ': s1) t1
            -> Gen (Tensor (n ': s2) t2)
zipWithT f x y = do
  x' <- zipWithTN @n f (transposeN @s @n x) (transposeN @s1 @n y)
  return (transposeN' @s2 x')
zipWithTN :: forall (n :: Nat) (s :: [Nat]) (t :: Typ) (s1 :: [Nat]) (t1 :: Typ) (s2 :: Shape) (t2 :: Typ).
             KnownTyp t2 =>
             (T s t -> T s1 t1 -> T s2 t2)
             -> Tensor (s ++ '[n]) t
             -> Tensor (s1 ++ '[n]) t1
             -> Gen (Tensor (s2 ++ '[n]) t2)
zipWithTN f (T t) (T u) = do
  fn <- lambda2 f
  return (T (funcall "tf.map_fn" [fn, tuple [t,u], named "dtype" (showTyp @t2)]))
lambda2 :: (T s t -> T s1 t1 -> T s' t') -> Gen UntypedExpression
lambda2 f = do
  v <- newVar
  let T body = f (T (v <> brackets (int 0))) (T (v <> brackets (int 1)))
  return (text "lambda " <> v <> text ": " <> body)
if_ :: Scalar TFBool -> T s t -> T s t -> T s t
if_ (T c) (T x) (T y) = T (funcall "tf.cond" [
    c,
    (lambda0 x),
    (lambda0 y),
    named "strict" (bool True)])
  where lambda0 z = text "lambda: " <> z
where_ :: T s TFBool -> T s t -> T s t -> T s t
where_ (T c) (T x) (T y) = T (funcall "tf.where" [c, x, y])
parameterDefault :: forall p. ParamWithDefault p => String -> Gen p
parameterDefault name = parameter name defaultInitializer
parameter :: forall p. KnownTensors p => String -> p -> Gen p
parameter = travTensor parameter'
class KnownTensors p where
  travTensor :: (forall s t. (KnownTyp t, KnownShape s) => String -> T s t -> Gen (T s t)) -> String -> p -> Gen p
instance (KnownTyp t, KnownShape shape) => KnownTensors (T shape t) where
  travTensor f = f
instance (KnownTyp t, All KnownShape ys) => KnownTensors (HTV t ys) where
  travTensor f s = ttr 0
    where ttr :: forall xs. All KnownShape xs => Int -> HTV t xs -> Gen (HTV t xs)
          ttr _ Unit = return Unit
          ttr n (F x :* xs) = do
            x' <- f (s <> "_" <> show n) x
            xs' <- ttr (n Prelude.+ 1) xs
            return (F x' :* xs')
instance (KnownTensors p, KnownTensors q) => KnownTensors (p,q) where
  travTensor f s (x,y) = (,) <$> travTensor f (s<>"_fst") x <*> travTensor f (s<>"_snd") y
instance (KnownTensors p1, KnownTensors p2, KnownTensors p3) => KnownTensors (p1,p2,p3) where
  travTensor f s (x,y,z) = (,,) <$> travTensor f (s<>"_1") x <*> travTensor f (s<>"_2") y <*> travTensor f (s<>"_3") z
instance (KnownTensors p1, KnownTensors p2, KnownTensors p3, KnownTensors p4) => KnownTensors (p1,p2,p3,p4) where
  travTensor f s (x,y,z,w) = (,,,) <$> travTensor f (s<>"_1") x <*> travTensor f (s<>"_2") y <*> travTensor f (s<>"_3") z <*> travTensor f (s<>"_4") w
class KnownTensors p => ParamWithDefault p where
  defaultInitializer :: p
flattenAll :: forall s t. KnownShape s => Tensor s t -> Tensor '[Product s] t
flattenAll = knownProduct @s reshape
flattenHTV :: KnownTyp t => All KnownShape xs => HTV t xs -> Tensor '[Sum (Ap (FMap CProduct) xs)] t
flattenHTV Unit = zeros
flattenHTV (F x :* xs) = concat0 (flattenAll x) (flattenHTV xs)
inflateAll :: forall s t. KnownShape s => Tensor '[Product s] t -> Tensor s t
inflateAll = knownProduct @s reshape
class CProduct (xs :: [Nat])
instance Fun CProduct where type Ap CProduct xs = Product xs
inflateHTV :: ∀ xs s t. (All KnownShape xs, KnownLen s, KnownLen xs) =>
              Tensor '[Sum (Ap (FMap CProduct) xs)] t -> Gen (HTV t xs)
inflateHTV (T x) = do
  v <- newVar
  gen (v <> text " = " <> funcall "tf.split" [x, showShape' (prodshape @xs shapeSList), text "axis=0"])
  return (mkArr @xs 0 shapeSList v)
 where mkArr :: forall zs. All KnownShape zs => Int -> SList zs -> DOC -> HTV t zs
       mkArr _ LZ _ = Unit
       mkArr i (LS _ n) v = F (unsafeReshape (T (v <> brackets (int i)) )):* mkArr (succ i) n v
       prodshape :: forall zs. All KnownShape zs => SList zs -> [Integer]
       prodshape LZ = []
       prodshape (LS xx xs) = product (shapeToList' (shapeSListProxy xx)) : prodshape xs