-- | Perform aggregation on 'S.Select's.  To aggregate a 'S.Select' you
-- should construct an 'Aggregator' encoding how you want the
-- aggregation to proceed, then call 'aggregate' on it.  The
-- 'Aggregator' should be constructed from the basic 'Aggregator's
-- below by using the combining operations from
-- "Data.Profunctor.Product".

module Opaleye.Aggregate
       (
       -- * Aggregation
         aggregate
       , aggregateOrdered
       , distinctAggregator
       , Aggregator
       -- * Basic 'Aggregator's
       , groupBy
       , Opaleye.Aggregate.sum
       , sumInt4
       , sumInt8
       , count
       , countStar
       , avg
       , Opaleye.Aggregate.max
       , Opaleye.Aggregate.min
       , boolOr
       , boolAnd
       , arrayAgg
       , jsonAgg
       , stringAgg
       -- * Counting rows
       , countRows
       ) where

import           Control.Applicative (pure)
import           Data.Profunctor     (lmap)
import qualified Data.Profunctor as P

import qualified Opaleye.Internal.Aggregate as A
import           Opaleye.Internal.Aggregate (Aggregator, orderAggregate)
import qualified Opaleye.Internal.Column as IC
import qualified Opaleye.Internal.QueryArr as Q
import qualified Opaleye.Internal.HaskellDB.PrimQuery as HPQ
import qualified Opaleye.Internal.PackMap as PM

import qualified Opaleye.Column    as C
import qualified Opaleye.Order     as Ord
import qualified Opaleye.Select    as S
import qualified Opaleye.SqlTypes   as T
import qualified Opaleye.Join      as J

-- This page of the Postgres documentation tells us which aggregate
-- functions are available:
--
--   http://www.postgresql.org/docs/9.3/static/functions-aggregate.html

{-|
Given a 'S.Select' producing rows of type @a@ and an 'Aggregator' accepting rows of
type @a@, apply the aggregator to the select.

If you simply want to count the number of rows in a query you might
find the 'countRows' function more convenient.

If you want to use 'aggregate' with 'S.SelectArr's then you should
compose it with 'Opaleye.Lateral.laterally':

@
'Opaleye.Lateral.laterally' . 'aggregate' :: 'Aggregator' a b -> 'S.SelectArr' i a -> 'S.SelectArr' i b
@

Please note that when aggregating an empty query with no @GROUP BY@
clause, Opaleye's behaviour differs from Postgres's behaviour.
Postgres returns a single row whereas Opaleye returns zero rows.
Opaleye's behaviour is consistent with the meaning of aggregating
over groups of rows and Postgres's behaviour is inconsistent.  When a
query has zero rows it has zero groups, and thus zero rows in the
result of an aggregation.

-}
-- See 'Opaleye.Internal.Sql.aggregate' for details of how aggregating
-- by an empty query with no group by is handled.
aggregate :: Aggregator a b -> S.Select a -> S.Select b
-- Run @q@ to obtain its columns, primitive query and tag, then hand
-- them to 'A.aggregateU', which produces the aggregated columns and
-- the primitive query that performs the aggregation.
aggregate agg q = Q.productQueryArr (A.aggregateU agg . Q.runSimpleQueryArr q)

-- | Order the values within each aggregation in 'Aggregator' using
-- the given ordering. This is only relevant for aggregations that
-- depend on the order they get their elements, like 'arrayAgg' and
-- 'stringAgg'.
--
-- Note that this orders all aggregations with the same ordering. If
-- you need different orderings for different aggregations, use
-- 'Opaleye.Internal.Aggregate.orderAggregate'.
aggregateOrdered :: Ord.Order a -> Aggregator a b -> S.Select a -> S.Select b
aggregateOrdered o agg = aggregate (orderAggregate o agg)

-- | Aggregate only distinct values.
--
-- Every aggregation operation produced by the wrapped 'Aggregator'
-- has its distinctness flag overwritten with 'HPQ.AggrDistinct'.
-- Entries that carry no aggregation operation ('Nothing') are left
-- unchanged.
distinctAggregator :: Aggregator a b -> Aggregator a b
distinctAggregator (A.Aggregator (PM.PackMap pm)) =
  A.Aggregator (PM.PackMap (\f c -> pm (f . P.first' (fmap markDistinct)) c))
  where
    -- Keep the operation and its ordering expressions; force DISTINCT.
    markDistinct (op, orders, _) = (op, orders, HPQ.AggrDistinct)

-- | Group the aggregation by equality on the input to 'groupBy'.
groupBy :: Aggregator (C.Column a) (C.Column a)
-- No aggregation operation is attached: the column acts as a grouping
-- key rather than being aggregated.
groupBy = A.makeAggr' Nothing

-- | Sum all rows in a group.
--
-- WARNING! The type of this operation is wrong and will crash at
-- runtime when the argument is 'T.SqlInt4' or 'T.SqlInt8'.  For those
-- use 'sumInt4' or 'sumInt8' instead.
sum :: Aggregator (C.Column a) (C.Column a)
sum = A.makeAggr HPQ.AggrSum

-- | Sum all 'T.SqlInt4' rows in a group.  The result is typed as
-- 'T.SqlInt8' because Postgres widens the sum of @int4@ values to
-- @bigint@ (this is why 'Opaleye.Aggregate.sum' is unsafe here).
sumInt4 :: Aggregator (C.Column T.SqlInt4) (C.Column T.SqlInt8)
sumInt4 = fmap C.unsafeCoerceColumn Opaleye.Aggregate.sum

-- | Sum all 'T.SqlInt8' rows in a group.  The result is typed as
-- 'T.SqlNumeric' because Postgres widens the sum of @int8@ values to
-- @numeric@ (this is why 'Opaleye.Aggregate.sum' is unsafe here).
sumInt8 :: Aggregator (C.Column T.SqlInt8) (C.Column T.SqlNumeric)
sumInt8 = fmap C.unsafeCoerceColumn Opaleye.Aggregate.sum

-- | Count the number of non-null rows in a group.
count :: Aggregator (C.Column a) (C.Column T.SqlInt8)
count = A.makeAggr HPQ.AggrCount

-- | Count the number of rows in a group.  This 'Aggregator' is named
-- @countStar@ after SQL's @COUNT(*)@ aggregation function.
countStar :: Aggregator a (C.Column T.SqlInt8)
-- Replace every row by the constant 0, which is never null, so that
-- 'count' (which skips nulls) counts every row.
countStar = lmap (const (0 :: C.Column T.SqlInt4)) count

-- | Average of a group.
avg :: Aggregator (C.Column T.SqlFloat8) (C.Column T.SqlFloat8)
avg = A.makeAggr HPQ.AggrAvg

-- | Maximum of a group.
max :: Ord.SqlOrd a => Aggregator (C.Column a) (C.Column a)
max = A.makeAggr HPQ.AggrMax

-- | Minimum of a group.
min :: Ord.SqlOrd a => Aggregator (C.Column a) (C.Column a)
min = A.makeAggr HPQ.AggrMin

-- | Logical OR over the boolean values of a group (SQL @bool_or@).
boolOr :: Aggregator (C.Column T.SqlBool) (C.Column T.SqlBool)
boolOr = A.makeAggr HPQ.AggrBoolOr

-- | Logical AND over the boolean values of a group (SQL @bool_and@).
boolAnd :: Aggregator (C.Column T.SqlBool) (C.Column T.SqlBool)
boolAnd = A.makeAggr HPQ.AggrBoolAnd

-- | Collect the values of a group into an SQL array (SQL @array_agg@).
-- Use 'aggregateOrdered' or 'Opaleye.Internal.Aggregate.orderAggregate'
-- if the order of the elements matters.
arrayAgg :: Aggregator (C.Column a) (C.Column (T.SqlArray a))
arrayAgg = A.makeAggr HPQ.AggrArr

-- | Collect the values of a group into a JSON array (SQL @json_agg@).
jsonAgg :: Aggregator (C.Column a) (C.Column T.SqlJson)
jsonAgg = A.makeAggr HPQ.JsonArr

-- | Concatenate the text values of a group.
--
-- The argument is handed to 'HPQ.AggrStringAggr'; it is expected to
-- be the separator of SQL's @string_agg(value, separator)@.  Use
-- 'aggregateOrdered' if the order of concatenation matters.
stringAgg :: C.Column T.SqlText
          -> Aggregator (C.Column T.SqlText) (C.Column T.SqlText)
stringAgg = A.makeAggr' . Just . HPQ.AggrStringAggr . IC.unColumn

-- | Count the number of rows in a query.  This is different from
-- 'aggregate' 'count' because it always returns exactly one row, even
-- when the input query is empty.

-- This is currently implemented in a cheeky way with a LEFT JOIN.  If
-- there are any performance issues it could be rewritten to use an
-- SQL COUNT aggregation which groups by nothing.  This would require
-- changing the AST though, so I'm not too keen.
--
-- Steps, read from the bottom of the pipeline up:
--
--  * The count aggregator requires an input of type 'Column a' rather
--    than 'a' (I'm not sure if there's a good reason for this).  To
--    deal with that restriction we just map a dummy integer value over
--    the input.
--
--  * The aggregated count is LEFT JOINed onto the single-row query
--    @pure ()@ with an always-true condition.  When the input is empty,
--    'aggregate' yields zero rows, so the join yields one row whose
--    right-hand side is NULL.
--
--  * That NULL is replaced by 0 with 'C.fromNullable'.
--
-- See https://github.com/tomjaguarpaw/haskell-opaleye/issues/162
countRows :: S.Select a -> S.Select (C.Column T.SqlInt8)
countRows = fmap (C.fromNullable 0)
            . fmap snd
            . (\q -> J.leftJoin (pure ())
                                (aggregate count q)
                                (const (T.sqlBool True)))
            . fmap (const (0 :: C.Column T.SqlInt4))