{-# LANGUAGE
    MultiParamTypeClasses,
    RankNTypes,
    ScopedTypeVariables,
    FlexibleInstances,
    FlexibleContexts,
    UndecidableInstances,
    PolyKinds,
    LambdaCase,
    NoMonomorphismRestriction,
    TypeFamilies,
    LiberalTypeSynonyms,
    FunctionalDependencies,
    ExistentialQuantification,
    InstanceSigs,
    TupleSections,
    ConstraintKinds,
    DefaultSignatures,
    UndecidableSuperClasses,
    TypeOperators,
    TypeApplications,
    PartialTypeSignatures
#-}

-- | The 'Lang' class: the term language layered on top of the 'DBI' core,
-- together with its interpreters (evaluation, pretty names, differentiation,
-- implicit weights, pairing of interpreters) and the algebraic instances
-- ('Monoid', 'Group', 'Vector', 'Functor', ...) expressed inside the language.
module Lang where

import DBI
import qualified Prelude as P
import Prelude (($), (.), (+), (-), (++), show, (>>=), (*), (/), undefined)
import qualified Control.Monad.Writer as P
import qualified Data.Functor.Identity as P
import qualified GHC.Float as P
import qualified Data.Tuple as P
import Data.Void
import Data.Proxy
import Data.Constraint
import Data.Constraint.Forall

-- 'Diff v t' is the type that replaces 't' in a differentiated term.
-- Numeric leaves become a pair of the value and a component of type 'v';
-- every other type former is mapped structurally.
type instance Diff v (P.Writer w a) = P.Writer (Diff v w) (Diff v a)
type instance Diff v Void = Void
type instance Diff v P.Double = (P.Double, v)
type instance Diff v P.Float = (P.Float, v)
type instance Diff v (P.Either a b) = P.Either (Diff v a) (Diff v b)
type instance Diff v (P.Maybe a) = P.Maybe (Diff v a)
type instance Diff v (P.IO a) = P.IO (Diff v a)
type instance Diff v [a] = [Diff v a]

-- | Primitives of the object language. @repr h t@ is a term of type @t@ in
-- environment @h@.
--
-- Naming note: projections are zero-indexed, so 'zro' selects the first
-- component of a pair and 'fst' selects the second.
class DBI repr => Lang repr where
  -- Products.
  mkProd :: repr h (a -> b -> (a, b))
  zro :: repr h ((a, b) -> a)
  fst :: repr h ((a, b) -> b)

  -- Doubles.
  double :: P.Double -> repr h P.Double
  doubleZero :: repr h P.Double
  doubleZero = double 0
  doubleOne :: repr h P.Double
  doubleOne = double 1
  doublePlus :: repr h (P.Double -> P.Double -> P.Double)
  doubleMinus :: repr h (P.Double -> P.Double -> P.Double)
  doubleMult :: repr h (P.Double -> P.Double -> P.Double)
  doubleDivide :: repr h (P.Double -> P.Double -> P.Double)
  doubleExp :: repr h (P.Double -> P.Double)

  -- Floats.
  float :: P.Float -> repr h P.Float
  floatZero :: repr h P.Float
  floatZero = float 0
  floatOne :: repr h P.Float
  floatOne = float 1
  floatPlus :: repr h (P.Float -> P.Float -> P.Float)
  floatMinus :: repr h (P.Float -> P.Float -> P.Float)
  floatMult :: repr h (P.Float -> P.Float -> P.Float)
  floatDivide :: repr h (P.Float -> P.Float -> P.Float)
  floatExp :: repr h (P.Float -> P.Float)

  -- General recursion.
  fix :: repr h ((a -> a) -> a)

  -- Sums, unit and the empty type.
  left :: repr h (a -> P.Either a b)
  right :: repr h (b -> P.Either a b)
  sumMatch :: repr h ((a -> c) -> (b -> c) -> P.Either a b -> c)
  unit :: repr h ()
  exfalso :: repr h (Void -> a)

  -- Optional values.
  nothing :: repr h (P.Maybe a)
  just :: repr h (a -> P.Maybe a)
  optionMatch :: repr h (b -> (a -> b) -> P.Maybe a -> b)

  -- IO.
  ioRet :: repr h (a -> P.IO a)
  ioBind :: repr h (P.IO a -> (a -> P.IO b) -> P.IO b)
  ioMap :: repr h ((a -> b) -> P.IO a -> P.IO b)

  -- Lists.
  nil :: repr h [a]
  cons :: repr h (a -> [a] -> [a])
  listMatch :: repr h (b -> (a -> [a] -> b) -> [a] -> b)
  -- | Append two lists: recurse over the left list, ending in the right one.
  listAppend :: repr h ([a] -> [a] -> [a])
  listAppend = lam2 $ \l r ->
    fix2 (lam $ \self -> listMatch2 r (lam2 $ \a as -> cons2 a (app self as))) l

  -- Writer.
  writer :: repr h ((a, w) -> P.Writer w a)
  runWriter :: repr h (P.Writer w a -> (a, w))

  -- Derived product helpers.
  swap :: repr h ((l, r) -> (r, l))
  swap = lam $ \p -> mkProd2 (fst1 p) (zro1 p)
  curry :: repr h (((a, b) -> c) -> (a -> b -> c))
  uncurry :: repr h ((a -> b -> c) -> ((a, b) -> c))
  curry = lam3 $ \f a b -> app f (mkProd2 a b)
  uncurry = lam2 $ \f p -> app2 f (zro1 p) (fst1 p)

  -- Numeric conversions.
  float2Double :: repr h (P.Float -> P.Double)
  double2Float :: repr h (P.Double -> P.Float)

-- | Host values that can be embedded as object-language terms.
class Reify repr x where
  reify :: x -> repr h x

instance Lang repr => Reify repr () where
  reify _ = unit

instance Lang repr => Reify repr P.Double where
  reify = double

instance (Lang repr, Reify repr l, Reify repr r) => Reify repr (l, r) where
  reify (l, r) = mkProd2 (reify l) (reify r)

-- | Direct evaluation: every primitive is the corresponding Haskell value,
-- injected with 'comb'.
instance Lang Eval where
  zro = comb P.fst
  fst = comb P.snd
  mkProd = comb (,)
  double = comb
  doublePlus = comb (+)
  doubleMinus = comb (-)
  doubleMult = comb (*)
  doubleDivide = comb (/)
  fix = comb loop
    where
      -- Lazy fixed point: unfolds one step each time the result is demanded.
      loop x = x $ loop x
  left = comb P.Left
  right = comb P.Right
  sumMatch = comb $ \l r -> \case
    P.Left x -> l x
    P.Right x -> r x
  unit = comb ()
  exfalso = comb absurd
  nothing = comb P.Nothing
  just = comb P.Just
  ioRet = comb P.return
  ioBind = comb (>>=)
  nil = comb []
  cons = comb (:)
  listMatch = comb $ \l r -> \case
    [] -> l
    x:xs -> r x xs
  optionMatch = comb $ \l r -> \case
    P.Nothing -> l
    P.Just x -> r x
  ioMap = comb P.fmap
  writer = comb (P.WriterT . P.Identity)
  runWriter = comb P.runWriter
  doubleExp = comb P.exp
  float = comb
  floatPlus = comb (+)
  floatMinus = comb (-)
  floatMult = comb (*)
  floatDivide = comb (/)
  floatExp = comb P.exp
  float2Double = comb P.float2Double
  double2Float = comb P.double2Float

-- | Transparent wrapper around another interpreter; every operation is
-- forwarded unchanged to the wrapped @repr@.
newtype UnHOAS repr h x = UnHOAS {runUnHOAS :: repr h x}

instance DBI repr => DBI (UnHOAS repr) where
  z = UnHOAS z
  s (UnHOAS x) = UnHOAS $ s x
  abs (UnHOAS x) = UnHOAS $ abs x
  app (UnHOAS f) (UnHOAS x) = UnHOAS $ app f x

instance Lang repr => Lang (UnHOAS repr) where
  mkProd = UnHOAS mkProd
  zro = UnHOAS zro
  fst = UnHOAS fst
  double = UnHOAS . double
  doublePlus = UnHOAS doublePlus
  doubleMinus = UnHOAS doubleMinus
  doubleMult = UnHOAS doubleMult
  doubleDivide = UnHOAS doubleDivide
  doubleExp = UnHOAS doubleExp
  fix = UnHOAS fix
  left = UnHOAS left
  right = UnHOAS right
  sumMatch = UnHOAS sumMatch
  unit = UnHOAS unit
  exfalso = UnHOAS exfalso
  nothing = UnHOAS nothing
  just = UnHOAS just
  ioRet = UnHOAS ioRet
  ioBind = UnHOAS ioBind
  nil = UnHOAS nil
  cons = UnHOAS cons
  listMatch = UnHOAS listMatch
  optionMatch = UnHOAS optionMatch
  ioMap = UnHOAS ioMap
  writer = UnHOAS writer
  runWriter = UnHOAS runWriter
  float = UnHOAS . float
  floatPlus = UnHOAS floatPlus
  floatMinus = UnHOAS floatMinus
  floatMult = UnHOAS floatMult
  floatDivide = UnHOAS floatDivide
  floatExp = UnHOAS floatExp
  float2Double = UnHOAS float2Double
  double2Float = UnHOAS double2Float

-- | Textual interpreter: each primitive is rendered via 'name'. Note that the
-- Double and Float arithmetic primitives share the same names ("plus",
-- "minus", "mult", "divide", "exp").
instance Lang Show where
  mkProd = name "mkProd"
  zro = name "zro"
  fst = name "fst"
  double = name . show
  doublePlus = name "plus"
  doubleMinus = name "minus"
  doubleMult = name "mult"
  doubleDivide = name "divide"
  doubleExp = name "exp"
  fix = name "fix"
  left = name "left"
  right = name "right"
  sumMatch = name "sumMatch"
  unit = name "unit"
  exfalso = name "exfalso"
  nothing = name "nothing"
  just = name "just"
  ioRet = name "ioRet"
  ioBind = name "ioBind"
  nil = name "nil"
  cons = name "cons"
  listMatch = name "listMatch"
  optionMatch = name "optionMatch"
  ioMap = name "ioMap"
  writer = name "writer"
  runWriter = name "runWriter"
  float = name . show
  floatPlus = name "plus"
  floatMinus = name "minus"
  floatMult = name "mult"
  floatDivide = name "divide"
  floatExp = name "exp"
  float2Double = name "float2Double"
  double2Float = name "double2Float"

-- | Differentiating interpreter, generic in the derivative type @v@ (selected
-- by the 'Proxy' argument, which every method here ignores). Numbers are
-- pairs @(value, derivative)@: 'zro1' reads the value, 'fst1' the derivative.
-- Non-numeric primitives are passed through unchanged.
instance Lang repr => Lang (GWDiff repr) where
  mkProd = GWDiff (P.const mkProd)
  zro = GWDiff $ P.const $ zro
  fst = GWDiff $ P.const $ fst
  -- A literal has a zero derivative.
  double x = GWDiff $ P.const $ mkProd2 (double x) zero
  doublePlus = GWDiff $ P.const $ lam2 $ \l r ->
    mkProd2 (plus2 (zro1 l) (zro1 r)) (plus2 (fst1 l) (fst1 r))
  doubleMinus = GWDiff $ P.const $ lam2 $ \l r ->
    mkProd2 (minus2 (zro1 l) (zro1 r)) (minus2 (fst1 l) (fst1 r))
  -- Product rule: (l * r)' = l * r' + r * l'.
  doubleMult = GWDiff $ P.const $ lam2 $ \l r ->
    mkProd2
      (mult2 (zro1 l) (zro1 r))
      (plus2 (mult2 (zro1 l) (fst1 r)) (mult2 (zro1 r) (fst1 l)))
  -- Quotient rule: (l / r)' = (r * l' - l * r') / (r * r).
  doubleDivide = GWDiff $ P.const $ lam2 $ \l r ->
    mkProd2
      (divide2 (zro1 l) (zro1 r))
      (divide2
        (minus2 (mult2 (zro1 r) (fst1 l)) (mult2 (zro1 l) (fst1 r)))
        (mult2 (zro1 r) (zro1 r)))
  -- (exp x)' = exp x * x'.
  doubleExp = GWDiff $ P.const $ lam $ \x ->
    mkProd2 (doubleExp1 (zro1 x)) (mult2 (doubleExp1 (zro1 x)) (fst1 x))
  fix = GWDiff $ P.const fix
  left = GWDiff $ P.const left
  right = GWDiff $ P.const right
  sumMatch = GWDiff $ P.const sumMatch
  unit = GWDiff $ P.const unit
  exfalso = GWDiff $ P.const exfalso
  nothing = GWDiff $ P.const nothing
  just = GWDiff $ P.const just
  ioRet = GWDiff $ P.const ioRet
  ioBind = GWDiff $ P.const ioBind
  nil = GWDiff $ P.const nil
  cons = GWDiff $ P.const cons
  listMatch = GWDiff $ P.const listMatch
  optionMatch = GWDiff $ P.const optionMatch
  ioMap = GWDiff $ P.const ioMap
  writer = GWDiff $ P.const writer
  runWriter = GWDiff $ P.const runWriter
  -- Float rules mirror the Double ones. 'mult'/'divide' take a Double
  -- scalar, hence the 'float2Double1' conversions on the scalar side.
  float x = GWDiff $ P.const $ mkProd2 (float x) zero
  floatPlus = GWDiff $ P.const $ lam2 $ \l r ->
    mkProd2 (plus2 (zro1 l) (zro1 r)) (plus2 (fst1 l) (fst1 r))
  floatMinus = GWDiff $ P.const $ lam2 $ \l r ->
    mkProd2 (minus2 (zro1 l) (zro1 r)) (minus2 (fst1 l) (fst1 r))
  floatMult = GWDiff $ P.const $ lam2 $ \l r ->
    mkProd2
      (mult2 (float2Double1 (zro1 l)) (zro1 r))
      (plus2
        (mult2 (float2Double1 (zro1 l)) (fst1 r))
        (mult2 (float2Double1 (zro1 r)) (fst1 l)))
  floatDivide = GWDiff $ P.const $ lam2 $ \l r ->
    mkProd2
      (divide2 (zro1 l) (float2Double1 (zro1 r)))
      (divide2
        (minus2
          (mult2 (float2Double1 (zro1 r)) (fst1 l))
          (mult2 (float2Double1 (zro1 l)) (fst1 r)))
        (float2Double1 (mult2 (float2Double1 (zro1 r)) (zro1 r))))
  floatExp = GWDiff $ P.const $ lam $ \x ->
    mkProd2
      (floatExp1 (zro1 x))
      (mult2 (float2Double1 (floatExp1 (zro1 x))) (fst1 x))
  -- Conversions touch only the value; the derivative component is kept.
  float2Double = GWDiff $ P.const $ bimap2 float2Double id
  double2Float = GWDiff $ P.const $ bimap2 double2Float id

-- | Differentiating interpreter with the derivative type @v@ fixed by the
-- instance head. The rules are the same as for 'GWDiff', without the 'Proxy'.
instance (Vector repr v, Lang repr) => Lang (WDiff repr v) where
  mkProd = WDiff mkProd
  zro = WDiff zro
  fst = WDiff fst
  double x = WDiff $ mkProd2 (double x) zero
  doublePlus = WDiff $ lam2 $ \l r ->
    mkProd2 (plus2 (zro1 l) (zro1 r)) (plus2 (fst1 l) (fst1 r))
  doubleMinus = WDiff $ lam2 $ \l r ->
    mkProd2 (minus2 (zro1 l) (zro1 r)) (minus2 (fst1 l) (fst1 r))
  doubleMult = WDiff $ lam2 $ \l r ->
    mkProd2
      (mult2 (zro1 l) (zro1 r))
      (plus2 (mult2 (zro1 l) (fst1 r)) (mult2 (zro1 r) (fst1 l)))
  doubleDivide = WDiff $ lam2 $ \l r ->
    mkProd2
      (divide2 (zro1 l) (zro1 r))
      (divide2
        (minus2 (mult2 (zro1 r) (fst1 l)) (mult2 (zro1 l) (fst1 r)))
        (mult2 (zro1 r) (zro1 r)))
  doubleExp = WDiff $ lam $ \x ->
    mkProd2 (doubleExp1 (zro1 x)) (mult2 (doubleExp1 (zro1 x)) (fst1 x))
  fix = WDiff fix
  left = WDiff left
  right = WDiff right
  sumMatch = WDiff sumMatch
  unit = WDiff unit
  exfalso = WDiff exfalso
  nothing = WDiff nothing
  just = WDiff just
  ioRet = WDiff ioRet
  ioBind = WDiff ioBind
  nil = WDiff nil
  cons = WDiff cons
  listMatch = WDiff listMatch
  optionMatch = WDiff optionMatch
  ioMap = WDiff ioMap
  writer = WDiff writer
  runWriter = WDiff runWriter
  float x = WDiff $ mkProd2 (float x) zero
  floatPlus = WDiff $ lam2 $ \l r ->
    mkProd2 (plus2 (zro1 l) (zro1 r)) (plus2 (fst1 l) (fst1 r))
  floatMinus = WDiff $ lam2 $ \l r ->
    mkProd2 (minus2 (zro1 l) (zro1 r)) (minus2 (fst1 l) (fst1 r))
  floatMult = WDiff $ lam2 $ \l r ->
    mkProd2
      (mult2 (float2Double1 (zro1 l)) (zro1 r))
      (plus2
        (mult2 (float2Double1 (zro1 l)) (fst1 r))
        (mult2 (float2Double1 (zro1 r)) (fst1 l)))
  floatDivide = WDiff $ lam2 $ \l r ->
    mkProd2
      (divide2 (zro1 l) (float2Double1 (zro1 r)))
      (divide2
        (minus2
          (mult2 (float2Double1 (zro1 r)) (fst1 l))
          (mult2 (float2Double1 (zro1 l)) (fst1 r)))
        (float2Double1 (mult2 (float2Double1 (zro1 r)) (zro1 r))))
  floatExp = WDiff $ lam $ \x ->
    mkProd2
      (floatExp1 (zro1 x))
      (mult2 (float2Double1 (floatExp1 (zro1 x))) (fst1 x))
  float2Double = WDiff $ bimap2 float2Double id
  double2Float = WDiff $ bimap2 double2Float id

-- Evidence that each of these constraints lifts from the components of a pair
-- to the pair itself (backed by the pair instances defined in this module).
instance Lang repr => ProdCon (Monoid repr) l r where
  prodCon = Sub Dict

instance Lang repr => ProdCon (WithDiff repr) l r where
  prodCon = Sub Dict

instance Lang repr => ProdCon (Reify repr) l r where
  prodCon = Sub Dict

instance Lang repr => ProdCon (Vector repr) l r where
  prodCon = Sub Dict

-- | No primitive needs a weight, so all of them are wrapped in 'NoImpW'.
instance Lang repr => Lang (ImpW repr) where
  nil = NoImpW nil
  cons = NoImpW cons
  listMatch = NoImpW listMatch
  zro = NoImpW zro
  fst = NoImpW fst
  mkProd = NoImpW mkProd
  ioRet = NoImpW ioRet
  ioMap = NoImpW ioMap
  ioBind = NoImpW ioBind
  unit = NoImpW unit
  nothing = NoImpW nothing
  just = NoImpW just
  optionMatch = NoImpW optionMatch
  exfalso = NoImpW exfalso
  fix = NoImpW fix
  left = NoImpW left
  right = NoImpW right
  sumMatch = NoImpW sumMatch
  writer = NoImpW writer
  runWriter = NoImpW runWriter
  double = NoImpW . double
  doubleExp = NoImpW doubleExp
  doublePlus = NoImpW doublePlus
  doubleMinus = NoImpW doubleMinus
  doubleMult = NoImpW doubleMult
  doubleDivide = NoImpW doubleDivide
  float = NoImpW . float
  floatExp = NoImpW floatExp
  floatPlus = NoImpW floatPlus
  floatMinus = NoImpW floatMinus
  floatMult = NoImpW floatMult
  floatDivide = NoImpW floatDivide
  float2Double = NoImpW float2Double
  double2Float = NoImpW double2Float

-- | Runs two interpreters side by side; each primitive is built once in each.
instance (Lang l, Lang r) => Lang (Combine l r) where
  mkProd = Combine mkProd mkProd
  zro = Combine zro zro
  fst = Combine fst fst
  double x = Combine (double x) (double x)
  doublePlus = Combine doublePlus doublePlus
  doubleMinus = Combine doubleMinus doubleMinus
  doubleMult = Combine doubleMult doubleMult
  doubleDivide = Combine doubleDivide doubleDivide
  doubleExp = Combine doubleExp doubleExp
  float x = Combine (float x) (float x)
  floatPlus = Combine floatPlus floatPlus
  floatMinus = Combine floatMinus floatMinus
  floatMult = Combine floatMult floatMult
  floatDivide = Combine floatDivide floatDivide
  floatExp = Combine floatExp floatExp
  fix = Combine fix fix
  left = Combine left left
  right = Combine right right
  sumMatch = Combine sumMatch sumMatch
  unit = Combine unit unit
  exfalso = Combine exfalso exfalso
  nothing = Combine nothing nothing
  just = Combine just just
  optionMatch = Combine optionMatch optionMatch
  ioRet = Combine ioRet ioRet
  ioBind = Combine ioBind ioBind
  ioMap = Combine ioMap ioMap
  nil = Combine nil nil
  cons = Combine cons cons
  listMatch = Combine listMatch listMatch
  runWriter = Combine runWriter runWriter
  writer = Combine writer writer
  double2Float = Combine double2Float double2Float
  float2Double = Combine float2Double float2Double

instance Lang repr => WithDiff repr () where
  withDiff = const1 id

-- | Pair a Double with the derivative obtained by feeding 'doubleOne' to the
-- supplied conversion.
instance Lang repr => WithDiff repr P.Double where
  withDiff = lam2 $ \conv d -> mkProd2 d (app conv doubleOne)

-- | Each component is handled with a conversion that injects it into the
-- pair, with 'zero' in the other slot.
instance (Lang repr, WithDiff repr l, WithDiff repr r) => WithDiff repr (l, r) where
  withDiff = lam $ \conv ->
    bimap2
      (withDiff1 (lam $ \l -> app conv (mkProd2 l zero)))
      (withDiff1 (lam $ \r -> app conv (mkProd2 zero r)))

-- | A 'Monoid' with inverses. 'invert' and 'minus' are defined in terms of
-- each other, so an instance must supply at least one of them.
class Monoid r g => Group r g where
  invert :: r h (g -> g)
  minus :: r h (g -> g -> g)
  default invert :: Lang r => r h (g -> g)
  invert = minus1 zero
  default minus :: Lang r => r h (g -> g -> g)
  minus = lam2 $ \x y -> plus2 x (invert1 y)
  {-# MINIMAL (invert | minus) #-}

-- | A 'Group' that can be scaled by a Double. 'mult' and 'divide' are defined
-- in terms of each other (through 'recip'), so supply at least one.
class Group r v => Vector r v where
  mult :: r h (P.Double -> v -> v)
  divide :: r h (v -> P.Double -> v)
  default mult :: Lang r => r h (P.Double -> v -> v)
  mult = lam2 $ \x y -> divide2 y (recip1 x)
  default divide :: Lang r => r h (v -> P.Double -> v)
  divide = lam2 $ \x y -> mult2 (recip1 y) x
  {-# MINIMAL (mult | divide) #-}

instance Lang r => Monoid r () where
  zero = unit
  plus = const1 $ const1 unit

instance Lang r => Group r () where
  invert = const1 unit
  minus = const1 $ const1 unit

instance Lang r => Vector r () where
  mult = const1 $ const1 unit
  divide = const1 $ const1 unit

instance Lang r => Monoid r P.Double where
  zero = doubleZero
  plus = doublePlus

instance Lang r => Group r P.Double where
  minus = doubleMinus

instance Lang r => Vector r P.Double where
  mult = doubleMult
  divide = doubleDivide

instance Lang r => Monoid r P.Float where
  zero = floatZero
  plus = floatPlus

instance Lang r => Group r P.Float where
  minus = floatMinus

-- | The scalar is a Double, so it goes through 'double2Float' before the
-- Float primitive is used.
instance Lang r => Vector r P.Float where
  mult = com2 floatMult double2Float
  divide = com2 (flip2 com double2Float) floatDivide

-- | Component-wise monoid on pairs.
instance (Lang repr, Monoid repr l, Monoid repr r) => Monoid repr (l, r) where
  zero = mkProd2 zero zero
  plus = lam2 $ \l r ->
    mkProd2 (plus2 (zro1 l) (zro1 r)) (plus2 (fst1 l) (fst1 r))

instance (Lang repr, Group repr l, Group repr r) => Group repr (l, r) where
  invert = bimap2 invert invert

instance (Lang repr, Vector repr l, Vector repr r) => Vector repr (l, r) where
  mult = lam $ \x -> bimap2 (mult1 x) (mult1 x)

-- | Pointwise monoid on functions.
instance (Lang repr, Monoid repr l, Monoid repr r) => Monoid repr (l -> r) where
  zero = const1 zero
  plus = lam3 $ \l r x -> plus2 (app l x) (app r x)

-- NOTE(review): 'invert' below (and 'mult' in the 'Vector' instance) act on
-- the *argument* of the function, i.e. @f (invert x)@ and @f (k * x)@, whereas
-- 'plus' above is pointwise on the result. The two agree only when the
-- function is linear; presumably that is the intended use -- confirm.
instance (Lang repr, Group repr l, Group repr r) => Group repr (l -> r) where
  invert = lam2 $ \l x -> app l (invert1 x)

instance (Lang repr, Vector repr l, Vector repr r) => Vector repr (l -> r) where
  mult = lam3 $ \l r x -> app r (mult2 l x)

-- | Lists form a monoid under append.
instance Lang r => Monoid r [a] where
  zero = nil
  plus = listAppend

instance Lang r => Functor r [] where
  map = lam $ \f ->
    fix1 $ lam $ \self ->
      listMatch2 nil (lam2 $ \x xs -> cons2 (app f x) $ app self xs)

instance Lang r => BiFunctor r (,) where
  bimap = lam3 $ \l r p -> mkProd2 (app l (zro1 p)) (app r (fst1 p))

-- | Map over the result; the log is left untouched.
instance Lang r => Functor r (P.Writer w) where
  map = lam $ \f -> com2 writer (com2 (bimap2 f id) runWriter)

instance (Lang r, Monoid r w) => Applicative r (P.Writer w) where
  pure = com2 writer (flip2 mkProd zero)
  -- Logs are combined left-to-right: the function's log, then the argument's.
  ap = lam2 $ \f x ->
    writer1
      (mkProd2
        (app (zro1 (runWriter1 f)) (zro1 (runWriter1 x)))
        (plus2 (fst1 (runWriter1 f)) (fst1 (runWriter1 x))))

instance (Lang r, Monoid r w) => Monad r (P.Writer w) where
  -- The outer computation runs first, so its log must precede the inner
  -- one. This ordering matters for non-commutative monoids such as lists and
  -- keeps 'join' consistent with 'ap' above.
  join = lam $ \x ->
    writer1 $
      mkProd2
        (zro1 $ runWriter1 $ zro1 $ runWriter1 x)
        (plus2
          (fst1 $ runWriter1 x)
          (fst1 $ runWriter1 $ zro1 $ runWriter1 x))

instance Lang r => Functor r P.IO where
  map = ioMap

instance Lang r => Applicative r P.IO where
  pure = ioRet
  -- Run the function action, then map its result over the argument action.
  ap = lam2 $ \f x -> ioBind2 f (flip2 ioMap x)

instance Lang r => Monad r P.IO where
  bind = ioBind

instance Lang r => Functor r P.Maybe where
  map = lam $ \func -> optionMatch2 nothing (com2 just func)

instance Lang r => Applicative r P.Maybe where
  pure = just
  -- No function: ignore the argument and yield 'nothing'; otherwise 'map'.
  ap = optionMatch2 (const1 nothing) map

instance Lang r => Monad r P.Maybe where
  bind = lam2 $ \x func -> optionMatch3 nothing func x

-- | Normalise an 'ImpW' term to 'RunImpW'. A weight-free term is turned into
-- one that ignores a @()@ weight.
runImpW :: forall repr h x. Lang repr => ImpW repr h x -> RunImpW repr h x
runImpW (ImpW x) = RunImpW x
runImpW (NoImpW x) = RunImpW (const1 x :: repr h (() -> x))

-- | A term that can be differentiated with respect to any 'Vector' type @v@,
-- chosen by the caller through the 'Proxy'.
newtype GWDiff repr h x = GWDiff {runGWDiff :: forall v. Vector repr v => Proxy v -> repr (Diff v h) (Diff v x)}

instance DBI repr => DBI (GWDiff repr) where
  z = GWDiff (P.const z)
  s (GWDiff x) = GWDiff (\p -> s $ x p)
  app (GWDiff f) (GWDiff x) = GWDiff (\p -> app (f p) (x p))
  abs (GWDiff x) = GWDiff (\p -> abs $ x p)

-- Applied forms of the primitives: the numeric suffix is the number of
-- arguments already supplied with 'app'.
cons2 = app2 cons
listMatch2 = app2 listMatch
fix1 = app fix
fix2 = app2 fix
uncurry1 = app uncurry
optionMatch2 = app2 optionMatch
optionMatch3 = app3 optionMatch
zro1 = app zro
fst1 = app fst
mult1 = app mult
mult2 = app2 mult
divide2 = app2 divide
invert1 = app invert
mkProd1 = app mkProd
mkProd2 = app2 mkProd
minus1 = app minus
divide1 = app divide
-- | Reciprocal: 'doubleOne' divided by the argument.
recip = divide1 doubleOne
recip1 = app recip
writer1 = app writer
runWriter1 = app runWriter
ioBind2 = app2 ioBind
minus2 = app2 minus
float2Double1 = app float2Double
doubleExp1 = app doubleExp
floatExp1 = app floatExp

-- | 'ImpW' wraps a term that still expects a weight argument (see the type of
-- @work@ below); 'NoImpW' wraps a plain term. Application threads the weight
-- to whichever side needs it.
instance Lang repr => DBI (ImpW repr) where
  z = NoImpW z
  s :: forall a h b. ImpW repr h b -> ImpW repr (a, h) b
  s (ImpW x) = work x
    where
      -- Local signature brings the hidden weight type into scope.
      work :: Weight w => repr h (w -> b) -> ImpW repr (a, h) b
      work y = ImpW (s y)
  s (NoImpW x) = NoImpW (s x)
  -- Both sides weighted: the combined weight is a pair, split between them.
  app (ImpW f) (ImpW x) =
    ImpW (lam $ \p -> app (app (conv f) (zro1 p)) (app (conv x) (fst1 p)))
  app (NoImpW f) (NoImpW x) = NoImpW (app f x)
  -- Only the function is weighted.
  app (ImpW f) (NoImpW x) = ImpW (lam $ \w -> app2 (conv f) w (conv x))
  -- Only the argument is weighted.
  app (NoImpW f) (ImpW x) = ImpW (lam $ \w -> app (conv f) (app (conv x) w))
  -- Move the new lambda underneath the weight argument.
  abs (ImpW f) = ImpW (flip1 $ abs f)
  abs (NoImpW x) = NoImpW (abs x)