{-# LANGUAGE
  NoImplicitPrelude,
  ExplicitForAll,
  InstanceSigs,
  ScopedTypeVariables
#-}

module DDF.WDiff where

import DDF.Lang
import DDF.Diff
import qualified Data.Map as M
import qualified DDF.Map as Map

newtype WDiff r v h x = WDiff {runWDiff :: r (Diff v h) (Diff v x)}

instance DBI r => DBI (WDiff r v) where
  z = WDiff z
  s (WDiff x) = WDiff $ s x
  abs (WDiff f) = WDiff $ abs f
  app (WDiff f) (WDiff x) = WDiff $ app f x
  hoas f = WDiff $ hoas (\x -> runWDiff $ f $ WDiff x)

instance Bool r => Bool (WDiff r v) where
  bool x = WDiff $ bool x
  ite = WDiff ite

instance Char r => Char (WDiff r v) where
  char = WDiff . char

instance Prod r => Prod (WDiff r v) where
  mkProd = WDiff mkProd
  zro = WDiff zro
  fst = WDiff fst

instance Dual r => Dual (WDiff r v) where
  dual = WDiff $ dual
  runDual = WDiff $ runDual

instance (Vector r v, Double r, Dual r) => Double (WDiff r v) where
  double x = WDiff $ mkDual2 (double x) zero
  doublePlus = WDiff $ lam2 $ \l r ->
    mkDual2 (plus2 (dualOrig1 l) (dualOrig1 r)) (plus2 (dualDiff1 l) (dualDiff1 r))
  doubleMinus = WDiff $ lam2 $ \l r ->
    mkDual2 (minus2 (dualOrig1 l) (dualOrig1 r)) (minus2 (dualDiff1 l) (dualDiff1 r))
  doubleMult = WDiff $ lam2 $ \l r ->
    mkDual2 (mult2 (dualOrig1 l) (dualOrig1 r))
      (plus2 (mult2 (dualOrig1 l) (dualDiff1 r)) (mult2 (dualOrig1 r) (dualDiff1 l)))
  doubleDivide = WDiff $ lam2 $ \l r ->
    mkDual2 (divide2 (dualOrig1 l) (dualOrig1 r))
      (divide2 (minus2 (mult2 (dualOrig1 r) (dualDiff1 l)) (mult2 (dualOrig1 l) (dualDiff1 r)))
        (mult2 (dualOrig1 r) (dualOrig1 r)))
  doubleExp = WDiff $ lam $ \x -> let_2 (doubleExp1 (dualOrig1 x)) (lam $ \e -> mkDual2 e (mult2 e (dualDiff1 x)))

instance (Vector r v, Lang r) => Float (WDiff r v) where
  float x = WDiff $ mkDual2 (float x) zero
  floatPlus = WDiff $ lam2 $ \l r ->
    mkDual2 (plus2 (dualOrig1 l) (dualOrig1 r)) (plus2 (dualDiff1 l) (dualDiff1 r))
  floatMinus = WDiff $ lam2 $ \l r ->
    mkDual2 (minus2 (dualOrig1 l) (dualOrig1 r)) (minus2 (dualDiff1 l) (dualDiff1 r))
  floatMult = WDiff $ lam2 $ \l r ->
    mkDual2 (mult2 (float2Double1 (dualOrig1 l)) (dualOrig1 r))
      (plus2 (mult2 (float2Double1 (dualOrig1 l)) (dualDiff1 r)) (mult2 (float2Double1 (dualOrig1 r)) (dualDiff1 l)))
  floatDivide = WDiff $ lam2 $ \l r ->
    mkDual2 (divide2 (dualOrig1 l) (float2Double1 (dualOrig1 r)))
      (divide2 (minus2 (mult2 (float2Double1 (dualOrig1 r)) (dualDiff1 l)) (mult2 (float2Double1 (dualOrig1 l)) (dualDiff1 r)))
        (float2Double1 (mult2 (float2Double1 (dualOrig1 r)) (dualOrig1 r))))
  floatExp = WDiff (lam $ \x -> let_2 (floatExp1 (dualOrig1 x)) (lam $ \e -> mkDual2 e (mult2 (float2Double1 e) (dualDiff1 x))))

instance Option r => Option (WDiff r v) where
  nothing = WDiff nothing
  just = WDiff just
  optionMatch = WDiff optionMatch

instance Map.Map r => Map.Map (WDiff r v) where
  empty = WDiff Map.empty
  singleton = WDiff Map.singleton
  lookup :: forall h k a. Map.Ord k => WDiff r v h (k -> M.Map k a -> Maybe a)
  lookup = withDict (Map.diffOrd (Proxy :: Proxy (v, k))) (WDiff Map.lookup)
  alter :: forall h k a. Map.Ord k => WDiff r v h ((Maybe a -> Maybe a) -> k -> M.Map k a -> M.Map k a)
  alter = withDict (Map.diffOrd (Proxy :: Proxy (v, k))) (WDiff Map.alter)
  mapMap = WDiff Map.mapMap

instance Bimap r => Bimap (WDiff r v) where

instance (Vector r v, Lang r) => Lang (WDiff r v) where
  fix = WDiff fix
  left = WDiff left
  right = WDiff right
  sumMatch = WDiff sumMatch
  unit = WDiff unit
  exfalso = WDiff exfalso
  ioRet = WDiff ioRet
  ioBind = WDiff ioBind
  nil = WDiff nil
  cons = WDiff cons
  listMatch = WDiff listMatch
  ioMap = WDiff ioMap
  writer = WDiff writer
  runWriter = WDiff runWriter
  float2Double = WDiff $ bimap2 float2Double id
  double2Float = WDiff $ bimap2 double2Float id
  state = WDiff state
  runState = WDiff runState
  putStrLn = WDiff putStrLn