{-|
Module: Squeal.PostgreSQL.Expression.Aggregate
Description: Aggregate functions
Copyright: (c) Eitan Chatav, 2019
Maintainer: eitan@morphism.tech
Stability: experimental

Aggregate functions
-}

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Squeal.PostgreSQL.Expression.Aggregate
  ( Aggregate (..)
  , Distinction (..)
  , PGSum
  , PGAvg
  ) where

import Control.DeepSeq
import Data.ByteString (ByteString)
import Data.Kind
import GHC.TypeLits

import qualified GHC.Generics as GHC
import qualified Generics.SOP as SOP

import Squeal.PostgreSQL.Alias
import Squeal.PostgreSQL.Expression
import Squeal.PostgreSQL.List
import Squeal.PostgreSQL.Render
import Squeal.PostgreSQL.Schema

-- $setup
-- >>> import Squeal.PostgreSQL

{- |
`Aggregate` functions compute a single result from a set of input values.
`Aggregate` functions can be used as `GroupedBy` `Expression`s as well as
`Squeal.PostgreSQL.Expression.Window.WindowFunction`s.
-}
class Aggregate expr1 exprN aggr
  | aggr -> expr1
  , aggr -> exprN
  where

  -- | The one aggregation that needs no input expression.
  --
  -- >>> :{
  -- let
  --   expression :: Expression '[] commons ('Grouped bys) schemas params from ('NotNull 'PGint8)
  --   expression = countStar
  -- in printSQL expression
  -- :}
  -- count(*)
  countStar
    :: aggr ('NotNull 'PGint8)

  -- | The @count@ aggregate over the given input.
  --
  -- >>> :{
  -- let
  --   expression :: Expression '[] commons ('Grouped bys) schemas params '[tab ::: '["col" ::: null ty]] ('NotNull 'PGint8)
  --   expression = count (All #col)
  -- in printSQL expression
  -- :}
  -- count(ALL "col")
  count
    :: expr1 ty
    -- ^ what to count
    -> aggr ('NotNull 'PGint8)

  -- | The @sum@ aggregate; the result type is computed by `PGSum`.
  --
  -- >>> :{
  -- let
  --   expression :: Expression '[] commons ('Grouped bys) schemas params '[tab ::: '["col" ::: 'Null 'PGnumeric]] ('Null 'PGnumeric)
  --   expression = sum_ (Distinct #col)
  -- in printSQL expression
  -- :}
  -- sum(DISTINCT "col")
  sum_
    :: expr1 (null ty)
    -> aggr ('Null (PGSum ty))

  -- | Input values, nulls included, concatenated into an array.
  arrayAgg
    :: expr1 ty
    -> aggr ('Null ('PGvararray ty))

  -- | Aggregates values as a JSON array (result type `PGjson`).
  jsonAgg
    :: expr1 ty
    -> aggr ('Null 'PGjson)

  -- | Aggregates values as a JSON array (result type `PGjsonb`).
  jsonbAgg
    :: expr1 ty
    -> aggr ('Null 'PGjsonb)

  -- | Bitwise AND over all non-null input values, or null if there are none.
  --
  -- >>> :{
  -- let
  --   expression :: Expression '[] commons ('Grouped bys) schemas params '[tab ::: '["col" ::: null 'PGint4]] ('Null 'PGint4)
  --   expression = bitAnd (Distinct #col)
  -- in printSQL expression
  -- :}
  -- bit_and(DISTINCT "col")
  bitAnd
    :: int `In` PGIntegral
    => expr1 (null int)
    -- ^ what to aggregate
    -> aggr ('Null int)

  -- | Bitwise OR over all non-null input values, or null if there are none.
  --
  -- >>> :{
  -- let
  --   expression :: Expression '[] commons ('Grouped bys) schemas params '[tab ::: '["col" ::: null 'PGint4]] ('Null 'PGint4)
  --   expression = bitOr (All #col)
  -- in printSQL expression
  -- :}
  -- bit_or(ALL "col")
  bitOr
    :: int `In` PGIntegral
    => expr1 (null int)
    -- ^ what to aggregate
    -> aggr ('Null int)

  -- | True when every input value is true, false otherwise.
  --
  -- >>> :{
  -- let
  --   winFun :: WindowFunction '[] commons 'Ungrouped schemas params '[tab ::: '["col" ::: null 'PGbool]] ('Null 'PGbool)
  --   winFun = boolAnd #col
  -- in printSQL winFun
  -- :}
  -- bool_and("col")
  boolAnd
    :: expr1 (null 'PGbool)
    -- ^ what to aggregate
    -> aggr ('Null 'PGbool)

  -- | True when at least one input value is true, false otherwise.
  --
  -- >>> :{
  -- let
  --   expression :: Expression '[] commons ('Grouped bys) schemas params '[tab ::: '["col" ::: null 'PGbool]] ('Null 'PGbool)
  --   expression = boolOr (All #col)
  -- in printSQL expression
  -- :}
  -- bool_or(ALL "col")
  boolOr
    :: expr1 (null 'PGbool)
    -- ^ what to aggregate
    -> aggr ('Null 'PGbool)

  -- | Equivalent to `boolAnd`.
  --
  -- >>> :{
  -- let
  --   expression :: Expression '[] commons ('Grouped bys) schemas params '[tab ::: '["col" ::: null 'PGbool]] ('Null 'PGbool)
  --   expression = every (Distinct #col)
  -- in printSQL expression
  -- :}
  -- every(DISTINCT "col")
  every
    :: expr1 (null 'PGbool)
    -- ^ what to aggregate
    -> aggr ('Null 'PGbool)

  -- | Maximum value of the expression across all input values.
  max_
    :: expr1 (null ty)
    -- ^ what to maximize
    -> aggr ('Null ty)

  -- | Minimum value of the expression across all input values.
  min_
    :: expr1 (null ty)
    -- ^ what to minimize
    -> aggr ('Null ty)

  -- | The average (arithmetic mean) of all input values;
  -- the result type is computed by `PGAvg`.
  avg
    :: expr1 (null ty)
    -- ^ what to average
    -> aggr ('Null (PGAvg ty))

  -- | Correlation coefficient.
  --
  -- >>> :{
  -- let
  --   expression :: Expression '[] c ('Grouped g) s p '[t ::: '["x" ::: 'NotNull 'PGfloat8, "y" ::: 'NotNull 'PGfloat8]] ('Null 'PGfloat8)
  --   expression = corr (All (#y *: #x))
  -- in printSQL expression
  -- :}
  -- corr(ALL "y", "x")
  corr
    :: exprN '[null 'PGfloat8, null 'PGfloat8]
    -> aggr ('Null 'PGfloat8)

  -- | Population covariance.
  --
  -- >>> :{
  -- let
  --   expression :: Expression '[] c ('Grouped g) s p '[t ::: '["x" ::: 'NotNull 'PGfloat8, "y" ::: 'NotNull 'PGfloat8]] ('Null 'PGfloat8)
  --   expression = covarPop (All (#y *: #x))
  -- in printSQL expression
  -- :}
  -- covar_pop(ALL "y", "x")
  covarPop
    :: exprN '[null 'PGfloat8, null 'PGfloat8]
    -> aggr ('Null 'PGfloat8)

  -- | Sample covariance.
  --
  -- >>> :{
  -- let
  --   winFun :: WindowFunction '[] c 'Ungrouped s p '[t ::: '["x" ::: 'NotNull 'PGfloat8, "y" ::: 'NotNull 'PGfloat8]] ('Null 'PGfloat8)
  --   winFun = covarSamp (#y *: #x)
  -- in printSQL winFun
  -- :}
  -- covar_samp("y", "x")
  covarSamp
    :: exprN '[null 'PGfloat8, null 'PGfloat8]
    -> aggr ('Null 'PGfloat8)

  -- | Average of the independent variable (sum(X)/N).
  --
  -- >>> :{
  -- let
  --   expression :: Expression '[] c ('Grouped g) s p '[t ::: '["x" ::: 'NotNull 'PGfloat8, "y" ::: 'NotNull 'PGfloat8]] ('Null 'PGfloat8)
  --   expression = regrAvgX (All (#y *: #x))
  -- in printSQL expression
  -- :}
  -- regr_avgx(ALL "y", "x")
  regrAvgX
    :: exprN '[null 'PGfloat8, null 'PGfloat8]
    -> aggr ('Null 'PGfloat8)

  -- | Average of the dependent variable (sum(Y)/N).
  --
  -- >>> :{
  -- let
  --   winFun :: WindowFunction '[] c 'Ungrouped s p '[t ::: '["x" ::: 'NotNull 'PGfloat8, "y" ::: 'NotNull 'PGfloat8]] ('Null 'PGfloat8)
  --   winFun = regrAvgY (#y *: #x)
  -- in printSQL winFun
  -- :}
  -- regr_avgy("y", "x")
  regrAvgY
    :: exprN '[null 'PGfloat8, null 'PGfloat8]
    -> aggr ('Null 'PGfloat8)

  -- | Number of input rows in which both expressions are nonnull.
  --
  -- >>> :{
  -- let
  --   winFun :: WindowFunction '[] c 'Ungrouped s p '[t ::: '["x" ::: 'NotNull 'PGfloat8, "y" ::: 'NotNull 'PGfloat8]] ('Null 'PGint8)
  --   winFun = regrCount (#y *: #x)
  -- in printSQL winFun
  -- :}
  -- regr_count("y", "x")
  regrCount
    :: exprN '[null 'PGfloat8, null 'PGfloat8]
    -> aggr ('Null 'PGint8)

  -- | y-intercept of the least-squares-fit linear equation
  -- determined by the (X, Y) pairs.
  --
  -- >>> :{
  -- let
  --   expression :: Expression '[] c ('Grouped g) s p '[t ::: '["x" ::: 'NotNull 'PGfloat8, "y" ::: 'NotNull 'PGfloat8]] ('Null 'PGfloat8)
  --   expression = regrIntercept (All (#y *: #x))
  -- in printSQL expression
  -- :}
  -- regr_intercept(ALL "y", "x")
  regrIntercept
    :: exprN '[null 'PGfloat8, null 'PGfloat8]
    -> aggr ('Null 'PGfloat8)

  -- | @regr_r2(Y, X)@, square of the correlation coefficient
  regrR2
    :: exprN '[null 'PGfloat8, null 'PGfloat8]
    -> aggr ('Null 'PGfloat8)

  -- | @regr_slope(Y, X)@, slope of the least-squares-fit linear equation
  -- determined by the (X, Y) pairs
  regrSlope
    :: exprN '[null 'PGfloat8, null 'PGfloat8]
    -> aggr ('Null 'PGfloat8)

  -- | @regr_sxx(Y, X)@, sum(X^2) - sum(X)^2/N
  -- (“sum of squares” of the independent variable)
  regrSxx
    :: exprN '[null 'PGfloat8, null 'PGfloat8]
    -> aggr ('Null 'PGfloat8)

  -- | @regr_sxy(Y, X)@, sum(X*Y) - sum(X) * sum(Y)/N
  -- (“sum of products” of independent times dependent variable)
  regrSxy
    :: exprN '[null 'PGfloat8, null 'PGfloat8]
    -> aggr ('Null 'PGfloat8)

  -- | @regr_syy(Y, X)@, sum(Y^2) - sum(Y)^2/N
  -- (“sum of squares” of the dependent variable)
  regrSyy
    :: exprN '[null 'PGfloat8, null 'PGfloat8]
    -> aggr ('Null 'PGfloat8)

  -- | Historical alias for `stddevSamp`.
  stddev
    :: expr1 (null ty)
    -> aggr ('Null (PGAvg ty))

  -- | Population standard deviation of the input values.
  stddevPop
    :: expr1 (null ty)
    -> aggr ('Null (PGAvg ty))

  -- | Sample standard deviation of the input values.
  stddevSamp
    :: expr1 (null ty)
    -> aggr ('Null (PGAvg ty))

  -- | Historical alias for `varSamp`.
  variance
    :: expr1 (null ty)
    -> aggr ('Null (PGAvg ty))

  -- | Population variance of the input values
  -- (square of the population standard deviation).
  varPop
    :: expr1 (null ty)
    -> aggr ('Null (PGAvg ty))

  -- | Sample variance of the input values
  -- (square of the sample standard deviation).
  varSamp
    :: expr1 (null ty)
    -> aggr ('Null (PGAvg ty))

{- |
`Distinction`s are used for the input of `Aggregate` `Expression`s.
`All` invokes the aggregate once for each input row.
`Distinct` invokes the aggregate once for each distinct value of the expression
(or distinct set of values, for multiple expressions) found in the input
-}
data Distinction (expr :: kind -> Type) (ty :: kind)
  = All (expr ty)
  | Distinct (expr ty)
  deriving (GHC.Generic,Show,Eq,Ord)

-- | Declared without explicit methods, so the class defaults are used.
instance NFData
  (Distinction (Expression outer commons grp schemas params from) ty)

-- | Renders the @ALL@ or @DISTINCT@ keyword followed by the single
-- rendered expression.
instance RenderSQL
  (Distinction (Expression outer commons grp schemas params from) ty) where
    renderSQL (All expression) = "ALL" <+> renderSQL expression
    renderSQL (Distinct expression) = "DISTINCT" <+> renderSQL expression

-- | Renders the @ALL@ or @DISTINCT@ keyword followed by the
-- comma-separated rendering of every expression in the product.
instance SOP.SListI tys => RenderSQL
  (Distinction (NP (Expression outer commons grp schemas params from)) tys) where
    renderSQL (All expressions) =
      "ALL" <+> renderCommaSeparated renderSQL expressions
    renderSQL (Distinct expressions) =
      "DISTINCT" <+> renderCommaSeparated renderSQL expressions

-- | Aggregates of `Distinction`-wrapped `Ungrouped` expressions yield
-- `Grouped` expressions. Every method other than `countStar` delegates to
-- `unsafeAggregate1` (one argument) or `unsafeAggregateN` (several
-- arguments) with the SQL function name.
instance Aggregate
  (Distinction (Expression outer commons 'Ungrouped schemas params from))
  (Distinction (NP (Expression outer commons 'Ungrouped schemas params from)))
  (Expression outer commons ('Grouped bys) schemas params from) where
    -- no-argument aggregate
    countStar     = UnsafeExpression "count(*)"
    -- single-argument aggregates
    count         = unsafeAggregate1 "count"
    sum_          = unsafeAggregate1 "sum"
    arrayAgg      = unsafeAggregate1 "array_agg"
    jsonAgg       = unsafeAggregate1 "json_agg"
    jsonbAgg      = unsafeAggregate1 "jsonb_agg"
    bitAnd        = unsafeAggregate1 "bit_and"
    bitOr         = unsafeAggregate1 "bit_or"
    boolAnd       = unsafeAggregate1 "bool_and"
    boolOr        = unsafeAggregate1 "bool_or"
    every         = unsafeAggregate1 "every"
    max_          = unsafeAggregate1 "max"
    min_          = unsafeAggregate1 "min"
    avg           = unsafeAggregate1 "avg"
    -- two-argument statistical aggregates
    corr          = unsafeAggregateN "corr"
    covarPop      = unsafeAggregateN "covar_pop"
    covarSamp     = unsafeAggregateN "covar_samp"
    regrAvgX      = unsafeAggregateN "regr_avgx"
    regrAvgY      = unsafeAggregateN "regr_avgy"
    regrCount     = unsafeAggregateN "regr_count"
    regrIntercept = unsafeAggregateN "regr_intercept"
    regrR2        = unsafeAggregateN "regr_r2"
    regrSlope     = unsafeAggregateN "regr_slope"
    regrSxx       = unsafeAggregateN "regr_sxx"
    regrSxy       = unsafeAggregateN "regr_sxy"
    regrSyy       = unsafeAggregateN "regr_syy"
    -- single-argument statistical aggregates
    stddev        = unsafeAggregate1 "stddev"
    stddevPop     = unsafeAggregate1 "stddev_pop"
    stddevSamp    = unsafeAggregate1 "stddev_samp"
    variance      = unsafeAggregate1 "variance"
    varPop        = unsafeAggregate1 "var_pop"
    varSamp       = unsafeAggregate1 "var_samp"

-- | Escape hatch to define single-argument aggregate functions: the
-- function name is followed by the parenthesized rendering of the input.
-- The result type @y@ is not checked against the input type @x@.
unsafeAggregate1
  :: ByteString -- ^ aggregate function
  -> Distinction (Expression outer commons 'Ungrouped schemas params from) x
  -> Expression outer commons ('Grouped bys) schemas params from y
unsafeAggregate1 name input =
  UnsafeExpression (name <> parenthesized (renderSQL input))

-- | Escape hatch to define multi-argument aggregate functions: the
-- function name is followed by the parenthesized rendering of the inputs.
-- The result type @y@ is not checked against the input types @xs@.
unsafeAggregateN
  :: SOP.SListI xs
  => ByteString -- ^ function
  -> Distinction (NP (Expression outer commons 'Ungrouped schemas params from)) xs
  -> Expression outer commons ('Grouped bys) schemas params from y
unsafeAggregateN name inputs =
  UnsafeExpression (name <> parenthesized (renderSQL inputs))

-- | A type family that calculates the `PGSum` `PGType` of
-- a given argument `PGType`. Arguments without a matching equation
-- reduce to a custom `TypeError`.
type family PGSum ty where
  PGSum 'PGint2 = 'PGint8
  PGSum 'PGint4 = 'PGint8
  PGSum 'PGint8 = 'PGnumeric
  PGSum 'PGfloat4 = 'PGfloat4
  PGSum 'PGfloat8 = 'PGfloat8
  PGSum 'PGnumeric = 'PGnumeric
  PGSum 'PGinterval = 'PGinterval
  PGSum 'PGmoney = 'PGmoney
  PGSum pg = TypeError
    ( 'Text "Squeal type error: Cannot sum with argument type "
      ':<>: 'ShowType pg )

-- | A type family that calculates the `PGAvg` type of a `PGType`.
-- Arguments without a matching equation reduce to a custom `TypeError`.
type family PGAvg ty where
  PGAvg 'PGint2 = 'PGnumeric
  PGAvg 'PGint4 = 'PGnumeric
  PGAvg 'PGint8 = 'PGnumeric
  PGAvg 'PGnumeric = 'PGnumeric
  PGAvg 'PGfloat4 = 'PGfloat8
  PGAvg 'PGfloat8 = 'PGfloat8
  PGAvg 'PGinterval = 'PGinterval
  PGAvg pg = TypeError
    ('Text "Squeal type error: No average for " ':<>: 'ShowType pg)