{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}

module Downhill.BVar.Prelude
  ( -- * Tuples

    -- | Pattern synonyms @T2@ and @T3@ convert between a 'BVar' of a tuple and
    -- a tuple of 'BVar's, in both directions (matching and construction):
    --
    -- @
    -- fstBVar :: (HasGrad a, HasGrad b) => BVar r (a, b) -> BVar r a
    -- fstBVar (T2 a _b) = a
    --
    -- tieBVar :: (HasGrad a, HasGrad b) => BVar r a -> BVar r b -> BVar r (a, b)
    -- tieBVar a b = T2 a b
    -- @
    pattern T2,
    pattern T3,
  )
where

import Downhill.BVar (BVar (BVar))
import Downhill.Grad (HasGrad)
import qualified Downhill.Linear.Prelude as Linear
import Prelude ()

-- | Split a variable holding a pair into a pair of variables.
--
-- The value @(x, y)@ is taken apart into its components, and the pair's
-- 'BackGrad' is deconstructed with the 'Linear.T2' pattern into one
-- 'BackGrad' per component. Each resulting 'BVar' pairs one component value
-- with its own gradient node. Used as the view function of the 'T2' pattern.
toPair :: (HasGrad a, HasGrad b) => BVar r (a, b) -> (BVar r a, BVar r b)
toPair (BVar (x, y) (Linear.T2 dx dy)) = (BVar x dx, BVar y dy)

{-# COMPLETE T2 #-}

-- | Bidirectional pattern relating a 'BVar' of a pair to a pair of 'BVar's.
--
-- * Matching: @T2 a b@ against a @BVar r (a, b)@ splits it with 'toPair'.
-- * Construction: @T2 a b@ pairs the two values and combines their gradient
--   nodes with 'Linear.T2'.
pattern T2 :: (HasGrad a, HasGrad b) => BVar r a -> BVar r b -> BVar r (a, b)
pattern T2 a b <-
  (toPair -> (a, b))
  where
    T2 (BVar x dx) (BVar y dy) = BVar (x, y) (Linear.T2 dx dy)

-- | Split a variable holding a triple into a triple of variables.
--
-- The value @(x, y, z)@ is taken apart into its components, and the triple's
-- 'BackGrad' is deconstructed with the 'Linear.T3' pattern into one
-- 'BackGrad' per component. Each resulting 'BVar' pairs one component value
-- with its own gradient node. Used as the view function of the 'T3' pattern.
toTriple ::
  (HasGrad a, HasGrad b, HasGrad c) =>
  BVar r (a, b, c) ->
  (BVar r a, BVar r b, BVar r c)
toTriple (BVar (x, y, z) (Linear.T3 dx dy dz)) =
  (BVar x dx, BVar y dy, BVar z dz)

{-# COMPLETE T3 #-}

-- | Bidirectional pattern relating a 'BVar' of a triple to a triple of
-- 'BVar's.
--
-- * Matching: @T3 a b c@ against a @BVar r (a, b, c)@ splits it with
--   'toTriple'.
-- * Construction: @T3 a b c@ packs the three values into a triple and
--   combines their gradient nodes with 'Linear.T3'.
pattern T3 ::
  (HasGrad a, HasGrad b, HasGrad c) =>
  BVar r a ->
  BVar r b ->
  BVar r c ->
  BVar r (a, b, c)
pattern T3 a b c <-
  (toTriple -> (a, b, c))
  where
    T3 (BVar x dx) (BVar y dy) (BVar z dz) =
      BVar (x, y, z) (Linear.T3 dx dy dz)