{-# LANGUAGE ConstraintKinds   #-}
{-# LANGUAGE FlexibleContexts  #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs             #-}
{-# LANGUAGE TypeOperators     #-}
-- |
-- Module      : Data.Array.Accelerate.Data.Fold
-- Copyright   : [2016..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- Combine folds in 'Applicative' style to generate multiple results with
-- a single pass over the array. Based on Max Rabkin's "Beautiful Folding" [1]
-- and talks by Gabriel Gonzalez [2].
--
--  1. <http://squing.blogspot.com/2008/11/beautiful-folding.html>
--  2. <https://www.youtube.com/watch?v=6a5Ti0r8Q2s>
--

module Data.Array.Accelerate.Data.Fold (

  Fold(..), runFold,

) where

import Data.Array.Accelerate.Classes.Floating                       as A
import Data.Array.Accelerate.Classes.Fractional                     as A
import Data.Array.Accelerate.Classes.Num                            as A
import Data.Array.Accelerate.Data.Monoid
import Data.Array.Accelerate.Language                               as A
import Data.Array.Accelerate.Lift
import Data.Array.Accelerate.Smart                                  ( Acc, Exp, constant )
import Data.Array.Accelerate.Sugar.Array
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Shape

import Prelude                                                      hiding ( sum, product, length )
import Control.Applicative                                          as P
import qualified Prelude                                            as P


-- | 'Fold' describes how to process data of some 'i'nput type into some
-- 'o'utput type, via a reduction using some intermediate Monoid 'w'. For
-- example, the following definitions of 'sum' and 'length' both use the 'Sum'
-- monoid:
--
-- > sum = Fold (lift . Sum) (getSum . unlift)
-- > length = Fold (\_ -> 1) (getSum . unlift)
--
-- The key is that 'Fold's can be combined using 'Applicative' in order to
-- produce multiple outputs from a /single/ reduction of the array. For example:
--
-- > average = (/) <$> sum <*> length
--
-- This computes both the sum of the array as well as its length in a single
-- traversal, then combines both results to compute the average.
--
-- Because 'Fold' has 'Num', 'Fractional' and 'Floating' instances, this can
-- also be defined more succinctly as:
--
-- > average = sum / length
--
-- A more complex example:
--
-- > sumOfSquares = Fold (lift . Sum . (^2)) (getSum . unlift)
-- > standardDeviation = sqrt ((sumOfSquares / length) - (sum / length) ^ 2)
--
-- These will all execute with a single reduction kernel and a single map to
-- summarise (combine) the results.
--
data Fold i o where
  -- The intermediate type 'w' appears only in the constructor's fields and
  -- not in the result type 'Fold i o', so it is hidden (existential): folds
  -- with different accumulator types still share one type and can be
  -- combined with the Functor/Applicative/numeric instances.
  -- 'Elt w' is needed to store accumulators in Accelerate arrays, and
  -- 'Monoid (Exp w)' provides the identity and combining operation used
  -- when the projected values are reduced.
  Fold :: (Elt w, Monoid (Exp w))
       => (i -> Exp w)              -- transform each input element into the internal monoid type
       -> (Exp w -> o)              -- summarise the reduced monoid value into the final result
       -> Fold i o

-- | Apply a 'Fold' to an array.
--
runFold
    :: (Shape sh, Elt i, Elt o)
    => Fold (Exp i) (Exp o)
    -> Acc (Array (sh:.Int) i)
    -> Acc (Array sh o)
runFold (Fold tally summarise) is
  -- Read bottom-up:
  --   1. project every input element into the fold's monoid,
  --   2. reduce along the innermost dimension, seeded with the monoid
  --      identity (so an empty innermost dimension yields 'mempty'),
  --   3. summarise each reduced value into the final output element.
  -- The result therefore has one dimension fewer than the input.
  = A.map summarise
  $ A.fold mappend mempty
  $ A.map tally is


-- sum :: A.Num e => Fold (Exp e) (Exp e)
-- sum = Fold (lift . Sum) (getSum . unlift)

-- product :: A.Num e => Fold (Exp e) (Exp e)
-- product = Fold (lift . Product) (getProduct . unlift)

-- length :: A.Num i => Fold (Exp e) (Exp i)
-- length = Fold (\_ -> 1) (getSum . unlift)


-- combine2 :: (Elt a, Elt b) => Exp a -> Exp b -> Exp (a,b)
-- combine2 a b = lift (a,b)

-- combine3 :: (Elt a, Elt b, Elt c) => Exp a -> Exp b -> Exp c -> Exp (a,b,c)
-- combine3 a b c = lift (a,b,c)

-- combine4 :: (Elt a, Elt b, Elt c, Elt d) => Exp a -> Exp b -> Exp c -> Exp d -> Exp (a,b,c,d)
-- combine4 a b c d = lift (a,b,c,d)

-- combine5 :: (Elt a, Elt b, Elt c, Elt d, Elt e) => Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp (a,b,c,d,e)
-- combine5 a b c d e = lift (a,b,c,d,e)


-- Instances for 'Fold'
-- --------------------

instance P.Functor (Fold i) where
  -- Mapping only post-processes the summary: the per-element projection and
  -- the intermediate monoid are left exactly as they were, so no extra
  -- reduction work is introduced.
  fmap f (Fold project extract) = Fold project (\w -> f (extract w))

instance P.Applicative (Fold i) where
  -- A constant fold: its accumulator is the unit monoid and the supplied
  -- value is returned regardless of the reduction result.
  pure o = Fold (const (constant ())) (const o)

  -- Run both folds in a single pass by pairing their accumulators: each
  -- element is projected into a tuple of both monoids, and the reduced tuple
  -- is split apart again before applying the summaries.
  Fold projectF extractF <*> Fold projectX extractX =
    Fold projectBoth extractBoth
    where
      projectBoth x = lift (projectF x, projectX x)
      extractBoth acc =
        let (accF, accX) = unlift acc
        in  extractF accF (extractX accX)

instance A.Num b => P.Num (Fold a (Exp b)) where
  -- Binary operators combine two folds applicatively, so both are computed
  -- in the same traversal; unary operators only transform the summary.
  x + y         = (+) <$> x <*> y
  x - y         = (-) <$> x <*> y
  x * y         = (*) <$> x <*> y
  negate x      = negate <$> x
  abs x         = abs <$> x
  signum x      = signum <$> x
  -- A numeric literal is a constant fold that ignores its input.
  fromInteger n = pure (A.fromInteger n)

instance A.Fractional b => P.Fractional (Fold a (Exp b)) where
  -- Division pairs both folds into one traversal; 'recip' only transforms
  -- the summary; fractional literals become constant folds.
  x / y          = (/) <$> x <*> y
  recip x        = recip <$> x
  fromRational r = pure (A.fromRational r)

instance A.Floating b => P.Floating (Fold a (Exp b)) where
  -- 'pi' is a constant fold; unary functions transform the summary of a
  -- single fold; binary functions combine two folds in one traversal.
  pi          = pure pi
  sin   x     = sin   <$> x
  cos   x     = cos   <$> x
  tan   x     = tan   <$> x
  asin  x     = asin  <$> x
  acos  x     = acos  <$> x
  atan  x     = atan  <$> x
  sinh  x     = sinh  <$> x
  cosh  x     = cosh  <$> x
  tanh  x     = tanh  <$> x
  asinh x     = asinh <$> x
  acosh x     = acosh <$> x
  atanh x     = atanh <$> x
  exp   x     = exp   <$> x
  sqrt  x     = sqrt  <$> x
  log   x     = log   <$> x
  x ** y      = (**) <$> x <*> y
  logBase x y = logBase <$> x <*> y