module SubHask.Category.Trans.Derivative
where
import SubHask.Algebra
import SubHask.Algebra.Vector
import SubHask.Category
import SubHask.SubType
import SubHask.Internal.Prelude
import qualified Prelude as P
-- | A forward-mode dual number: a value paired with its tangent.
--
-- The arithmetic instances for 'Forward' carry the tangent alongside the
-- value (for example, the 'Rg' instance applies the product rule), and
-- 'proveC1' / 'proveC2' read the tangent back out as the derivative.
data Forward a = Forward
    { val  :: !a  -- ^ value component (strict field)
    , val' :: a   -- ^ tangent / derivative component (lazy field)
    }
    deriving (Show, Typeable)

-- Template Haskell: apply SubHask's 'mkMutable' helper to 'Forward'
-- (presumably generates mutable-container support for it — verify).
mkMutable [t| forall a. Forward a |]
-- | Componentwise addition: values add and tangents add (sum rule).
instance Semigroup a => Semigroup (Forward a) where
    Forward x dx + Forward y dy = Forward (x + y) (dx + dy)
-- | Componentwise subtraction: values subtract and tangents subtract
-- (the derivative of a difference is the difference of derivatives).
instance Cancellative a => Cancellative (Forward a) where
    (Forward a1 a1') - (Forward a2 a2') = Forward (a1 - a2) (a1' - a2')
-- | Additive identity: zero value with zero tangent.
instance Monoid a => Monoid (Forward a) where
    zero = Forward { val = zero, val' = zero }
-- | Negation flips the sign of both the value and the tangent.
instance Group a => Group (Forward a) where
    negate (Forward x dx) = Forward (negate x) (negate dx)
-- | Marker instance: @Forward a@ is 'Abelian' whenever @a@ is
-- (no methods are defined here).
instance Abelian a => Abelian (Forward a) where
-- | Multiplication applies the product rule to the tangent:
-- @(x, dx) * (y, dy) = (x*y, x*dy + y*dx)@.
instance Rg a => Rg (Forward a) where
    Forward x dx * Forward y dy = Forward (x * y) (x * dy + y * dx)
-- | The multiplicative identity is a constant: value 'one', tangent 'zero'.
instance Rig a => Rig (Forward a) where
    one = Forward { val = one, val' = zero }
-- | Integer literals embed as constants (zero tangent).
instance Ring a => Ring (Forward a) where
    fromInteger i = Forward { val = fromInteger i, val' = zero }
-- | Reciprocal and division propagate the tangent by the quotient rule:
--
-- * @d(1/x) = -dx / x^2@
-- * @d(x/y) = (dx*y - x*dy) / y^2@
--
-- Rational literals embed as constants (zero tangent), like 'fromInteger'.
instance Field a => Field (Forward a) where
    -- Note the minus sign: the derivative of 1/a is -a'/a^2.
    reciprocal (Forward a a') = Forward (reciprocal a) (negate a' / (a * a))

    -- Quotient rule: the numerator is a difference, and the denominator is
    -- the square of the divisor's value (not of its tangent).
    (Forward a1 a1') / (Forward a2 a2') =
        Forward (a1 / a2) ((a1' * a2 - a1 * a2') / (a2 * a2))

    fromRational r = Forward (fromRational r) zero
-- | Build a once-differentiable arrow from a function over dual numbers.
--
-- Each input @x@ is seeded as @Forward x one@. The 'val' of the result is
-- used as the function value and the 'val'' of the result as its derivative.
proveC1 :: (a ~ (a><a), Rig a) => (Forward a -> Forward a) -> C1 (a -> a)
proveC1 f = Diffn (val . run) (Diff0 (val' . run))
    where
        -- evaluate f at x with a unit tangent
        run x = f (Forward x one)
-- | Build a twice-differentiable arrow using nested dual numbers.
--
-- The input @x@ is seeded as @Forward (Forward x one) one@:
--
-- * @val . val@ gives the function value,
-- * @val' . val@ gives the first derivative,
-- * @val' . val'@ gives the second derivative.
proveC2 :: (a ~ (a><a), Rig a) => (Forward (Forward a) -> Forward (Forward a)) -> C2 (a -> a)
proveC2 f = Diffn value (Diffn firstDeriv (Diff0 secondDeriv))
    where
        run x       = f (Forward (Forward x one) one)
        value       = val  . val  . run
        firstDeriv  = val' . val  . run
        secondDeriv = val' . val' . run
-- | Arrow types from which a derivative arrow can be extracted.
--
-- The derivative of an @arr x y@ lives in the associated category @D arr@
-- and maps @x@ to @x >< y@.
class C (arr :: * -> * -> *) where
    -- | Category holding the derivatives of @arr@ arrows.
    type D arr :: * -> * -> *

    -- | Extract the derivative of an arrow.
    derivative :: arr x y -> D arr x (x >< y)
-- | A function @a -> b@ bundled with derivatives up to order @n@.
--
-- 'Diff0' carries only the function. 'Diffn' carries the function together
-- with its derivative, which is itself a @Diff (n - 1)@ arrow from @a@ to
-- @a >< b@.
data Diff (n :: Nat) a b where
    Diff0 :: (a -> b) -> Diff 0 a b
    -- The derivative is one order lower: its index is @n - 1@.
    Diffn :: (a -> b) -> Diff (n - 1) a (a >< b) -> Diff n a b
-- 'Sup' relations between plain functions and 'Diff' arrows, in both
-- argument orders (presumably the third argument names the common
-- supertype, here @(->)@ — verify against SubHask.SubType).
instance Sup (->) (Diff n) (->) where
instance Sup (Diff n) (->) (->) where
-- | Embed an order-0 'Diff' arrow into a plain function by keeping its
-- function component.
instance Diff 0 <: (->) where
    embedType_ = Embed2 unDiff0
        where
            unDiff0 :: Diff 0 a b -> a -> b
            unDiff0 (Diff0 f)   = f
            -- 'Diffn' places no constraint on its result index, so it can
            -- also produce a @Diff 0@ value; match it to keep this total.
            unDiff0 (Diffn f _) = f
-- | Embed any 'Diff' arrow into a plain function by keeping its function
-- component and discarding any derivatives.
instance Diff n <: (->) where
    embedType_ = Embed2 unDiffn
        where
            unDiffn :: Diff n a b -> a -> b
            -- Handle 'Diff0' too, so the projection is total when n is 0.
            unDiffn (Diff0 f)   = f
            unDiffn (Diffn f _) = f
-- Embedding an order-1 arrow into order 0 keeps the function and discards
-- its derivative.
instance Sup (Diff 0) (Diff 1) (Diff 0)
instance Sup (Diff 1) (Diff 0) (Diff 0)
instance Diff 1 <: Diff 0 where
    embedType_ = Embed2 (\(Diffn f _) -> Diff0 f)
-- Embedding an order-2 arrow into order 0 keeps the function and discards
-- all of its derivatives.
instance Sup (Diff 0) (Diff 2) (Diff 0)
instance Sup (Diff 2) (Diff 0) (Diff 0)
instance Diff 2 <: Diff 0 where
    embedType_ = Embed2 (\(Diffn f _) -> Diff0 f)
-- Embedding an order-2 arrow into order 1 keeps the function and lowers its
-- (order-1) derivative to order 0 via the 'Diff 1 <: Diff 0' embedding.
instance Sup (Diff 1) (Diff 2) (Diff 1)
instance Sup (Diff 2) (Diff 1) (Diff 1)
instance Diff 2 <: Diff 1 where
    embedType_ = Embed2 (\(Diffn f deriv) -> Diffn f (embedType2 deriv))
-- | Arrows with at least one derivative (@1 <= n@) expose it through
-- 'derivative'. The derivative of a @Diff n@ arrow is a @Diff (n - 1)@ arrow.
instance (1 <= n) => C (Diff n) where
    type D (Diff n) = Diff (n - 1)
    derivative (Diffn _ f') = f'
-- | Wrap a plain function as an order-0 'Diff' arrow.
unsafeProveC0 :: (a -> b) -> Diff 0 a b
unsafeProveC0 = Diff0
-- | Pair a function with a caller-supplied derivative.
--
-- Nothing checks that the second argument really is the derivative of the
-- first; that is the caller's responsibility (hence "unsafe").
unsafeProveC1
    :: (a -> b)       -- ^ the function
    -> (a -> a><b)    -- ^ its claimed derivative
    -> C1 (a -> b)
unsafeProveC1 f = Diffn f . unsafeProveC0
-- | Pair a function with caller-supplied first and second derivatives.
--
-- Nothing checks that the supplied derivatives are correct; that is the
-- caller's responsibility (hence "unsafe").
unsafeProveC2
    :: (a -> b)          -- ^ the function
    -> (a -> a><b)       -- ^ its claimed first derivative
    -> (a -> a><a><b)    -- ^ its claimed second derivative
    -> C2 (a -> b)
unsafeProveC2 f f' = Diffn f . unsafeProveC1 f'
-- | @C0 (a -> b)@ is @Diff 0 a b@: an arrow carrying no derivatives.
type C0 fn = C0_ fn
type family C0_ (t :: *) :: * where
    C0_ (x -> y) = Diff 0 x y
-- | @C1 (a -> b)@ is @Diff 1 a b@: an arrow carrying its first derivative.
type C1 fn = C1_ fn
type family C1_ (t :: *) :: * where
    C1_ (x -> y) = Diff 1 x y
-- | @C2 (a -> b)@ is @Diff 2 a b@: an arrow carrying first and second
-- derivatives.
type C2 fn = C2_ fn
type family C2_ (t :: *) :: * where
    C2_ (x -> y) = Diff 2 x y
-- Template Haskell: apply SubHask's 'mkMutable' helper to 'Diff'
-- (presumably generates mutable-container support for it — verify).
mkMutable [t| forall n a b. Diff n a b |]
-- | Order-0 arrows add by adding their functions.
instance Semigroup b => Semigroup (Diff 0 a b) where
    Diff0 g + Diff0 h = Diff0 (g + h)
-- | Order-1 arrows add their functions and their derivatives separately.
instance (Semigroup b, Semigroup (a><b)) => Semigroup (Diff 1 a b) where
    Diffn g dg + Diffn h dh = Diffn (g + h) (dg + dh)
-- | Order-2 arrows add their functions and their (order-1) derivative
-- arrows separately.
instance (Semigroup b, Semigroup (a><b), Semigroup (a><a><b)) => Semigroup (Diff 2 a b) where
    Diffn g dg + Diffn h dh = Diffn (g + h) (dg + dh)
-- | The zero order-0 arrow wraps the zero function.
instance Monoid b => Monoid (Diff 0 a b) where
    zero = unsafeProveC0 zero
-- | The zero order-1 arrow: zero function with a zero derivative.
instance (Monoid b, Monoid (a><b)) => Monoid (Diff 1 a b) where
    zero = unsafeProveC1 zero zero
-- | The zero order-2 arrow: zero function with zero first and second
-- derivatives.
instance (Monoid b, Monoid (a><b), Monoid (a><a><b)) => Monoid (Diff 2 a b) where
    zero = unsafeProveC2 zero zero zero