{-# LANGUAGE
  NoImplicitPrelude,
  ExplicitForAll,
  InstanceSigs,
  ScopedTypeVariables,
  TypeApplications,
  FlexibleContexts,
  UndecidableInstances,
  TypeFamilies,
  MultiParamTypeClasses,
  TypeOperators,
  DataKinds
#-}

module DDF.Diff where

import DDF.Lang
import qualified Prelude as M
import qualified Data.Map as M
import qualified DDF.Map as Map
import qualified Data.Bimap as M
import qualified DDF.Meta.Dual as M
import qualified DDF.VectorTF as VTF
import qualified DDF.Meta.DiffWrapper as M.DW
import qualified Data.Functor.Foldable as M
import qualified DDF.Meta.FreeVector as M

-- | Differentiating a function type differentiates its argument type and its
-- result type.
type instance DiffType v (l -> r) = DiffType v l -> DiffType v r

-- | Every 'DBI' operation on 'Diff' terms unwraps its arguments, delegates to
-- the same operation of the underlying interpreter @r@, and wraps the result
-- in 'Diff' again.
instance DBI r => DBI (Diff r v) where
  z = Diff z
  s (Diff inner) = Diff (s inner)
  abs (Diff body) = Diff (abs body)
  app (Diff fun) (Diff arg) = Diff (app fun arg)
  -- The supplied function works on 'Diff' terms, so its argument is wrapped
  -- before the call and its result is unwrapped afterwards.
  hoas build = Diff (hoas (\var -> runDiff (build (Diff var))))

-- | Booleans are left unchanged by differentiation.
type instance DiffType v M.Bool = M.Bool

-- | Boolean literals and the conditional are forwarded to the underlying
-- interpreter @r@.
instance Bool r => Bool (Diff r v) where
  ite = Diff ite
  bool = Diff . bool

-- | Characters are left unchanged by differentiation.
type instance DiffType v M.Char = M.Char

-- | Character literals are forwarded to the underlying interpreter @r@.
instance Char r => Char (Diff r v) where
  char c = Diff (char c)

-- | Differentiating a pair differentiates both components.
type instance DiffType v (l, r) = (DiffType v l, DiffType v r)

-- | Pair construction and the two accessors are forwarded to the underlying
-- interpreter @r@.
instance Prod r => Prod (Diff r v) where
  zro = Diff zro
  fst = Diff fst
  mkProd = Diff mkProd

-- | Differentiating a 'M.Dual' differentiates both of its type parameters.
type instance DiffType v (M.Dual l r) = M.Dual (DiffType v l) (DiffType v r)

-- | Wrapping into and unwrapping from 'M.Dual' is forwarded to the underlying
-- interpreter @r@.
instance Dual r => Dual (Diff r v) where
  runDual = Diff runDual
  dual = Diff dual

-- | A differentiated 'M.Double' is a 'M.Dual' pairing a 'M.Double' with a value
-- of the vector type @v@.
type instance DiffType v M.Double = M.Dual M.Double v
-- | Arithmetic on differentiated doubles.  Each operation builds a new dual
-- with 'mkDual2': the first component is computed from the @dualOrig1@ parts
-- and the second from the @dualOrig1@ / @dualDiff1@ parts.
-- NOTE(review): @dualOrig1@ / @dualDiff1@ presumably project the original
-- value and the derivative of a dual — confirm against DDF.Lang.
instance (Vector r v, Lang r) => Double (Diff r v) where
  -- A literal is paired with 'zero' as its second component.
  double x = Diff $ mkDual2 (double x) zero
  -- Sum: originals are added, and the second components are added.
  doublePlus = Diff $ lam2 $ \l r ->
    mkDual2 (plus2 (dualOrig1 l) (dualOrig1 r)) (plus2 (dualDiff1 l) (dualDiff1 r))
  -- Difference: originals are subtracted, and the second components are subtracted.
  doubleMinus = Diff $ lam2 $ \l r ->
    mkDual2 (minus2 (dualOrig1 l) (dualOrig1 r)) (minus2 (dualDiff1 l) (dualDiff1 r))
  -- Product rule: second component is l * r' + r * l'.
  doubleMult = Diff $ lam2 $ \l r ->
    mkDual2 (mult2 (dualOrig1 l) (dualOrig1 r))
      (plus2 (mult2 (dualOrig1 l) (dualDiff1 r)) (mult2 (dualOrig1 r) (dualDiff1 l)))
  -- Quotient rule: second component is (r * l' - l * r') / (r * r).
  doubleDivide = Diff $ lam2 $ \l r ->
    mkDual2 (divide2 (dualOrig1 l) (dualOrig1 r))
      (divide2 (minus2 (mult2 (dualOrig1 r) (dualDiff1 l)) (mult2 (dualOrig1 l) (dualDiff1 r)))
        (mult2 (dualOrig1 r) (dualOrig1 r)))
  -- Exponential: the result @e@ is bound once with 'let_2' and used both as the
  -- first component and as the factor scaling the argument's second component.
  doubleExp = Diff $ lam $ \x -> let_2 (doubleExp1 (dualOrig1 x)) (lam $ \e -> mkDual2 e (mult2 e (dualDiff1 x)))
  -- Equality looks only at the @dualOrig1@ parts; the second components are not compared.
  doubleEq = Diff $ lam2 $ \l r -> doubleEq2 (dualOrig1 l) (dualOrig1 r)

-- | A differentiated 'M.Float' is a 'M.Dual' pairing a 'M.Float' with a value
-- of the vector type @v@.
type instance DiffType v M.Float = M.Dual M.Float v
-- | Arithmetic on differentiated floats, mirroring the 'Double' instance.
-- Scalars passed to 'mult2' / 'divide2' are first converted with
-- @float2Double1@.
-- NOTE(review): presumably because the vector scaling operations take a
-- 'M.Double' scalar — confirm against the Vector class.
instance (Vector r v, Lang r) => Float (Diff r v) where
  -- A literal is paired with 'zero' as its second component.
  float x = Diff $ mkDual2 (float x) zero
  -- Sum: originals are added, and the second components are added.
  floatPlus = Diff $ lam2 $ \l r ->
    mkDual2 (plus2 (dualOrig1 l) (dualOrig1 r)) (plus2 (dualDiff1 l) (dualDiff1 r))
  -- Difference: originals are subtracted, and the second components are subtracted.
  floatMinus = Diff $ lam2 $ \l r ->
    mkDual2 (minus2 (dualOrig1 l) (dualOrig1 r)) (minus2 (dualDiff1 l) (dualDiff1 r))
  -- Product rule: second component is l * r' + r * l'.
  floatMult = Diff $ lam2 $ \l r ->
    mkDual2 (mult2 (float2Double1 (dualOrig1 l)) (dualOrig1 r))
      (plus2 (mult2 (float2Double1 (dualOrig1 l)) (dualDiff1 r)) (mult2 (float2Double1 (dualOrig1 r)) (dualDiff1 l)))
  -- Quotient rule: second component is (r * l' - l * r') / (r * r).
  floatDivide = Diff $ lam2 $ \l r ->
    mkDual2 (divide2 (dualOrig1 l) (float2Double1 (dualOrig1 r)))
      (divide2 (minus2 (mult2 (float2Double1 (dualOrig1 r)) (dualDiff1 l)) (mult2 (float2Double1 (dualOrig1 l)) (dualDiff1 r)))
        (float2Double1 (mult2 (float2Double1 (dualOrig1 r)) (dualOrig1 r))))
  -- Exponential: the result @e@ is bound once with 'let_2' and used both as the
  -- first component and (after @float2Double1@) as the factor scaling the
  -- argument's second component.
  floatExp = Diff (lam $ \x -> let_2 (floatExp1 (dualOrig1 x)) (lam $ \e -> mkDual2 e (mult2 (float2Double1 e) (dualDiff1 x))))

-- | Differentiating an optional value differentiates its payload type.
type instance DiffType v (Maybe l) = Maybe (DiffType v l)

-- | The option constructors and eliminator are forwarded to the underlying
-- interpreter @r@.
instance Option r => Option (Diff r v) where
  just = Diff just
  nothing = Diff nothing
  optionMatch = Diff optionMatch

-- | Differentiating a map differentiates both its key type and its value type.
type instance DiffType v (M.Map k val) = M.Map (DiffType v k) (DiffType v val)

-- | Map operations are forwarded to the underlying interpreter @r@.  The
-- operations constrained by 'Map.Ord' wrap the forwarded term in 'withDict'
-- applied to 'Map.diffOrd' at the key type paired with @v@.
-- NOTE(review): presumably this supplies the ordering evidence for the
-- differentiated key type — confirm against DDF.Map.
instance Map.Map r => Map.Map (Diff r v) where
  -- Operations without a key-ordering constraint.
  empty = Diff Map.empty
  singleton = Diff Map.singleton
  mapMap = Diff Map.mapMap

  -- Operations with a key-ordering constraint.
  lookup :: forall h k a. Map.Ord k => Diff r v h (M.Map k a -> k -> Maybe a)
  lookup =
    withDict
      (Map.diffOrd (Proxy @(v, k)))
      (Diff Map.lookup)

  alter :: forall h k a. Map.Ord k => Diff r v h ((Maybe a -> Maybe a) -> k -> M.Map k a -> M.Map k a)
  alter =
    withDict
      (Map.diffOrd (Proxy @(v, k)))
      (Diff Map.alter)

  unionWith :: forall h k a. Map.Ord k => Diff r v h ((a -> a -> a) -> M.Map k a -> M.Map k a -> M.Map k a)
  unionWith =
    withDict
      (Map.diffOrd (Proxy @(v, k)))
      (Diff Map.unionWith)

-- | Differentiating a bimap differentiates both of its key types.
type instance DiffType v (M.Bimap a b) = M.Bimap (DiffType v a) (DiffType v b)

-- | Bimap operations are forwarded to the underlying interpreter @r@.  The
-- operations constrained by 'Map.Ord' on both sides wrap the forwarded term in
-- two nested 'withDict' calls, one per side, each fed by 'Map.diffOrd'.
-- NOTE(review): presumably this supplies the ordering evidence for both
-- differentiated key types — confirm against DDF.Map.
instance Bimap r => Bimap (Diff r v) where
  -- Operations without ordering constraints.
  size = Diff size
  empty = Diff empty
  singleton = Diff singleton
  toMapL = Diff toMapL
  toMapR = Diff toMapR

  -- Operations with ordering constraints on both sides.
  lookupL :: forall h a b. (Map.Ord a, Map.Ord b) => Diff r v h (M.Bimap a b -> a -> Maybe b)
  lookupL =
    withDict (Map.diffOrd (Proxy @(v, a)))
      (withDict (Map.diffOrd (Proxy @(v, b)))
        (Diff lookupL))

  lookupR :: forall h a b. (Map.Ord a, Map.Ord b) => Diff r v h (M.Bimap a b -> b -> Maybe a)
  lookupR =
    withDict (Map.diffOrd (Proxy @(v, a)))
      (withDict (Map.diffOrd (Proxy @(v, b)))
        (Diff lookupR))

  insert :: forall h a b. (Map.Ord a, Map.Ord b) => Diff r v h ((a, b) -> M.Bimap a b -> M.Bimap a b)
  insert =
    withDict (Map.diffOrd (Proxy @(v, a)))
      (withDict (Map.diffOrd (Proxy @(v, b)))
        (Diff insert))

  updateL :: forall h a b. (Map.Ord a, Map.Ord b) => Diff r v h ((b -> Maybe b) -> a -> M.Bimap a b -> M.Bimap a b)
  updateL =
    withDict (Map.diffOrd (Proxy @(v, a)))
      (withDict (Map.diffOrd (Proxy @(v, b)))
        (Diff updateL))

  updateR :: forall h a b. (Map.Ord a, Map.Ord b) => Diff r v h ((a -> Maybe a) -> b -> M.Bimap a b -> M.Bimap a b)
  updateR =
    withDict (Map.diffOrd (Proxy @(v, a)))
      (withDict (Map.diffOrd (Proxy @(v, b)))
        (Diff updateR))

-- | The unit type is left unchanged by differentiation.
type instance DiffType v () = ()

-- | The unit value is forwarded to the underlying interpreter @r@.
instance Unit r => Unit (Diff r v) where
  unit = Diff unit

-- | Differentiating a sum type differentiates both alternatives.
type instance DiffType v (M.Either l r) = M.Either (DiffType v l) (DiffType v r)

-- | The sum injections and eliminator are forwarded to the underlying
-- interpreter @r@.
instance Sum r => Sum (Diff r v) where
  sumMatch = Diff sumMatch
  left = Diff left
  right = Diff right

-- | Integer literals, predecessor and the zero test are forwarded to the
-- underlying interpreter @r@.
instance Int r => Int (Diff r v) where
  int n = Diff (int n)
  isZero = Diff isZero
  pred = Diff pred

-- | The 'y' combinator is forwarded to the underlying interpreter @r@.
instance Y r => Y (Diff r v) where
  y = Diff y

-- | Differentiating an 'M.IO' action differentiates its result type.
type instance DiffType v (M.IO l) = M.IO (DiffType v l)

-- | 'putStrLn' is forwarded to the underlying interpreter @r@.
instance IO r => IO (Diff r v) where
  putStrLn = Diff putStrLn

-- | Differentiating a list differentiates its element type.
type instance DiffType v [l] = [DiffType v l]

-- | The list constructors and eliminator are forwarded to the underlying
-- interpreter @r@.
instance List r => List (Diff r v) where
  listMatch = Diff listMatch
  nil = Diff nil
  cons = Diff cons

-- | Mapping over 'M.IO' is forwarded to the underlying interpreter @r@.
instance Functor r M.IO => Functor (Diff r v) M.IO where
  map = Diff map

-- | The applicative operations for 'M.IO' are forwarded to the underlying
-- interpreter @r@.
instance Applicative r M.IO => Applicative (Diff r v) M.IO where
  ap = Diff ap
  pure = Diff pure

-- | The monadic operations for 'M.IO' are forwarded to the underlying
-- interpreter @r@.
instance Monad r M.IO => Monad (Diff r v) M.IO where
  join = Diff join
  bind = Diff bind

-- | 'VTF.VectorTF' operations on 'Diff' terms.  'VTF.zero', 'VTF.basis' and
-- 'VTF.plus' are forwarded unchanged; 'VTF.mult' and 'VTF.vtfMatch' adapt the
-- scalar argument, which is a dual on the 'Diff' side.
instance (Vector r v, Lang r) => VTF.VectorTF (Diff r v) where
  zero = Diff VTF.zero
  basis = Diff VTF.basis
  plus = Diff VTF.plus
  -- Composes the underlying 'VTF.mult' with 'dualOrig' via 'com2'.
  -- NOTE(review): presumably this feeds only the original part of the scalar
  -- to 'VTF.mult', dropping its derivative — confirm 'com2' semantics.
  mult = Diff $ VTF.mult `com2` dualOrig
  -- Forwards the first three handlers as-is; the fourth handler @m@ is wrapped
  -- so that it is applied to @mkDual2 x zero@ instead of the bare @x@.
  vtfMatch = Diff $ lam4 $ \ze b p m -> VTF.vtfMatch4 ze b p $ lam $ \x -> app m (mkDual2 x zero)

-- | Differentiating a 'M.DW.DiffWrapper' prepends @v@ to its type-level list.
type instance DiffType v (M.DW.DiffWrapper a x) = M.DW.DiffWrapper (v ': a) x

-- | Wrapping and unwrapping are forwarded to the underlying interpreter @r@.
instance DiffWrapper r => DiffWrapper (Diff r v) where
  runDiffWrapper = Diff runDiffWrapper
  diffWrapper = Diff diffWrapper

-- | A differentiated 'M.Fix' is a 'M.DW.DiffWrapper' tagged with the singleton
-- list @'[v]@ around one unrolling of the fixpoint.
type instance DiffType v (M.Fix f) = M.DW.DiffWrapper '[v] (f (M.Fix f))

-- | Because a differentiated 'M.Fix' is a 'M.DW.DiffWrapper', 'fix' and
-- 'runFix' are implemented with the underlying interpreter's 'diffWrapper'
-- and 'runDiffWrapper'.
instance DiffWrapper r => Fix (Diff r v) where
  runFix = Diff runDiffWrapper
  fix = Diff diffWrapper

-- | Differentiating a 'M.FreeVector' differentiates both of its type parameters.
type instance DiffType v (M.FreeVector a b) = M.FreeVector (DiffType v a) (DiffType v b)

-- | Wrapping and unwrapping are forwarded to the underlying interpreter @r@.
instance FreeVector r => FreeVector (Diff r v) where
  runFreeVector = Diff runFreeVector
  freeVector = Diff freeVector

-- | 'Void' is left unchanged by differentiation.
type instance DiffType v Void = Void

-- | Differentiating a 'Writer' differentiates both of its type parameters.
type instance DiffType v (Writer l r) = Writer (DiffType v l) (DiffType v r)

-- | Differentiating a 'State' differentiates both of its type parameters.
type instance DiffType v (State l r) = State (DiffType v l) (DiffType v r)

-- | The remaining 'Lang' operations.  All are forwarded to the underlying
-- interpreter @r@, except the two numeric conversions, which go through
-- 'bimap2' with 'id' as the second function.
-- NOTE(review): presumably 'bimap2' converts the original part of the dual
-- and leaves the derivative part untouched — confirm against its definition.
instance (Vector r v, Lang r) => Lang (Diff r v) where
  exfalso = Diff exfalso

  -- Writer and State wrappers.
  writer = Diff writer
  runWriter = Diff runWriter
  state = Diff state
  runState = Diff runState

  -- Conversions between the two floating-point types.
  float2Double = Diff (bimap2 float2Double id)
  double2Float = Diff (bimap2 double2Float id)

-- | 'Map.diffOrd' for the unit type: the proxy is ignored and 'Dict' is
-- returned directly.
instance Map.Ord () where
  diffOrd _proxy = Dict

-- | 'Map.diffOrd' for lists: the element type's 'Map.diffOrd' result is
-- brought into scope with 'withDict' before 'Dict' is built for the list type.
instance Map.Ord a => Map.Ord [a] where
  diffOrd (_ :: Proxy (v, [a])) =
    withDict
      (Map.diffOrd (Proxy @(v, a)))
      Dict

-- | 'Map.diffOrd' for 'M.Dual': only the first type parameter @l@ is required
-- to have a 'Map.Ord' instance; its 'Map.diffOrd' result is brought into scope
-- with 'withDict' before 'Dict' is built.
instance Map.Ord l => Map.Ord (M.Dual l r) where
  diffOrd (_ :: Proxy (v, M.Dual l r)) =
    withDict
      (Map.diffOrd (Proxy @(v, l)))
      Dict

-- | 'Map.diffOrd' for 'M.Double': the proxy is ignored and 'Dict' is returned
-- directly.
instance Map.Ord M.Double where
  diffOrd _proxy = Dict

-- | 'Map.diffOrd' for 'M.Float': the proxy is ignored and 'Dict' is returned
-- directly.
instance Map.Ord M.Float where
  diffOrd _proxy = Dict