{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RebindableSyntax #-}

-- | Forward-mode automatic differentiation.
--
-- A 'D' carries a value together with its partial derivatives (one entry
-- per variable), and arithmetic on 'D' propagates both at once.  An 'E'
-- is an expression that is evaluated by supplying an environment mapping
-- each variable to a 'D'.
module Algebra.AD (D(..), E(..), subst, dVar, var, sqrtE) where

import Algebra.Classes hiding ((:+))
import Data.Map (Map)
import qualified Data.Map.Strict as M
import Prelude hiding (Num(..), (/), fromRational, recip)
-- import Data.Vector as V
import Data.Function (on)

-- | A small symbolic expression tree over variables @v@ and constants @c@.
-- Only a 'Show' instance is defined for it in this module.
data AST v c
  = V v
  | AST v c :* AST v c
  | AST v c :+ AST v c
  | AST v c :- AST v c
  | K c

-- Renders an 'AST' as infix arithmetic.  Sums and differences are shown at
-- precedence 2 and products at precedence 3; a subexpression is wrapped in
-- parentheses when it appears in a tighter context than its own operator.
instance (Show v, Show c) => Show (AST v c) where
  showsPrec _ (V v) = shows v
  showsPrec _ (K c) = shows c
  showsPrec p (x :+ y) =
    parens (p > 2) (showsPrec 2 x . showString " + " . showsPrec 2 y)
  showsPrec p (x :- y) =
    -- The right operand is shown one level tighter so that a nested
    -- difference or sum on the right (e.g. a - (b - c)) keeps its parentheses.
    parens (p > 2) (showsPrec 2 x . showString " - " . showsPrec 3 y)
  showsPrec p (x :* y) =
    parens (p > 3) (showsPrec 3 x . showString " * " . showsPrec 3 y)

-- | Wrap the rendered text in parentheses when the flag is 'True'.
parens :: Bool -> ShowS -> ShowS
parens True x = showString "(" . x . showString ")"
parens False x = x

-- | A value paired with its partial derivatives, keyed by variable.
data D v c = D
  { dValue  :: !c          -- ^ the value itself
  , dDerivs :: !(Map v c)  -- ^ partial derivative with respect to each variable
  } deriving Show

-- | @dVar v c@ is the variable @v@ taken at the point @c@: its value is @c@
-- and its only recorded derivative is 1, with respect to @v@ itself.
dVar :: forall v c. Ring c => v -> c -> D v c
dVar v c = D c (M.singleton v 1)

-- | The expression consisting of the single variable @v@; evaluating it
-- looks @v@ up in the supplied environment.
var :: (Multiplicative c, Additive c, Ord v) => v -> E v c
var v = E $ \env -> env v

-- Addition acts componentwise on the value and on the derivative map.
instance (Ord v, Additive c) => Additive (D v c) where
  zero = D zero zero
  D v1 d1 + D v2 d2 = D (v1 + v2) (d1 + d2)

instance (Ord v, Group c) => Group (D v c) where
  negate (D x d) = D (negate x) (negate d)
  D v1 d1 - D v2 d2 = D (v1 - v2) (d1 - d2)

-- Ordering and equality look only at the value; derivatives are ignored.
instance Ord c => Ord (D v c) where
  compare = compare `on` dValue

instance Eq c => Eq (D v c) where
  (==) = (==) `on` dValue

-- Multiplication applies the product rule: (uv)' = v u' + u v'.
instance (Ord v, Ring c) => Multiplicative (D v c) where
  one = D one zero
  D v1 d1 * D v2 d2 = D (v1 * v2) (v2 *^ d1 + v1 *^ d2)

instance (AbelianAdditive c, Ord v) => AbelianAdditive (D v c)

-- Scaling by a constant scales both the value and every derivative.
instance (Ord v, Ring c) => Module c (D v c) where
  k *^ D v d = D (k * v) (k *^ d)

instance (Ord v, Ring c) => Module (D v c) (D v c) where
  (*^) = (*)

-- Integer constants have no derivatives.
instance (Ord v, Ring c) => Ring (D v c) where
  fromInteger k = D (fromInteger k) zero

-- | An expression: given an environment assigning a 'D' to each variable,
-- it produces a 'D'.
newtype E v c = E {fromE :: (v -> D v c) -> D v c}

-- The arithmetic instances for 'E' evaluate their operands in the same
-- environment and combine the resulting 'D' values.
instance (Ord v, Additive c) => Additive (E v c) where
  zero = E (const zero)
  (+) = liftE2 (+)

instance (Ord v, Group c) => Group (E v c) where
  negate (E x) = E (negate . x)
  (-) = liftE2 (-)

instance (Ord v, Ring c) => Multiplicative (E v c) where
  one = E (const one)
  (*) = liftE2 (*)

instance (Ord v, Ring c) => AbelianAdditive (E v c)

instance (Ord v, Ring c) => Module c (E v c) where
  k *^ E x = E ((k *^) . x)

instance (Ord v, Ring c) => Module (E v c) (E v c) where
  (*^) = (*)

instance (Ord v, Ring c) => Ring (E v c) where
  fromInteger k = E (\ _ -> fromInteger k)

-- | Lift a binary operation on 'D' to 'E' by evaluating both operands in the
-- same environment.
liftE2 :: forall t t1. (D t t1 -> D t t1 -> D t t1) -> E t t1 -> E t t1 -> E t t1
liftE2 f (E x) (E y) = E (\e -> f (x e) (y e))

-- | Lift a unary operation on 'D' to 'E'.
liftE :: forall t t1. (D t t1 -> D t t1) -> E t t1 -> E t t1
liftE f (E x) = E (\e -> f (x e))

-- | @subst p f@ replaces every variable @a@ in @p@ by the expression @f a@.
-- The substituted expressions are evaluated in the same outer environment.
subst :: E v c -> (v -> E v c) -> E v c
subst (E p) f = E $ \k -> p (\a -> fromE (f a) k)

-- | Square root of a dual number, using d(sqrt x) = (0.5 / sqrt x) dx.
sqrtD :: (Ord v, Floating c, Field c) => D v c -> D v c
sqrtD (D v d) = D (sqrtv) ((0.5/sqrtv) *^ d)
  where
    sqrtv = sqrt v

-- | Square root lifted to expressions.
sqrtE :: forall t t1. (Floating t1, Ord t, Field t1) => E t t1 -> E t t1
sqrtE = liftE sqrtD

-- Reciprocal uses d(1/x) = -(1/x)^2 dx.
instance (Field c, Ord v) => Division (D v c) where
  recip (D v d) = D (recip v) (negate (square iv) *^ d)
    where
      square x = x * x
      iv = recip v

instance (Field c, Ord v) => Division (E v c) where
  recip = liftE recip
  (/) = liftE2 (/)