module Spark.Core.Internal.AggregationFunctions where
import Data.Aeson(Value(Null))
import qualified Data.Text as T
import Formatting
import Spark.Core.Internal.DatasetStructures
import Spark.Core.Internal.ColumnStructures
import Spark.Core.Internal.DatasetFunctions
import Spark.Core.Internal.RowGenerics(ToSQL)
import Spark.Core.Internal.LocalDataFunctions()
import Spark.Core.Internal.FunctionsInternals
import Spark.Core.Internal.OpStructures
import Spark.Core.Internal.Utilities(failure, HasCallStack)
import Spark.Core.Internal.TypesStructures
import Spark.Core.Types
-- | Aggregates a column with the sum aggregator ('_sumAgg') and returns the
-- result as a local value of the same element type.
colSum
  :: forall ref a. (Num a, SQLTypeable a, ToSQL a)
  => Column ref a
  -> LocalData a
colSum col = applyUniAgg summer col
  where
    -- Pin the aggregator to the element type of the column.
    summer :: UniversalAggregator a a
    summer = _sumAgg
-- | Applies the count aggregator ('_countAgg2') to a dataset.
--
-- The dataset is first viewed as a column through 'asCol', then handed to
-- 'applyUniAgg'.
count :: forall a. (SQLTypeable a) => Dataset a -> LocalData Int
count = applyUniAgg counter . asCol
  where
    -- Pin the aggregator's input type to the dataset's element type.
    counter :: UniversalAggregator a Int
    counter = _countAgg2
-- | Applies the collect aggregator ('_collectAgg') to a column, yielding a
-- local list of the column's element type.
--
-- NOTE(review): the ordering of the resulting list is decided by the backend
-- operators, not by this module -- confirm before relying on it.
collect :: forall ref a. (SQLTypeable a) => Column ref a -> LocalData [a]
collect col = applyUniAgg gatherer col
  where
    gatherer :: UniversalAggregator a [a]
    gatherer = _collectAgg
-- | A typed description of an aggregation from elements of type @a@ to a
-- buffer of type @buff@.
--
-- It is a pair of functions over the graph types: one that produces a buffer
-- from a dataset, and one that combines two buffers.
data UniversalAggregator a buff = UniversalAggregator
  { uaInitialOuter :: Dataset a -> LocalData buff
    -- ^ Builds a buffer out of a dataset.
  , uaMergeBuffer :: LocalData buff -> LocalData buff -> LocalData buff
    -- ^ Combines two buffers into a single one.
  }
-- | Converts a typed aggregator into a 'UniversalAggregatorOp', deriving the
-- SQL types of the input and of the buffer from their 'SQLTypeable'
-- instances.
--
-- See 'univAggToOpTyped' for the variant taking the types explicitly.
univAggToOp
  :: forall a buff. (SQLTypeable a, SQLTypeable buff)
  => UniversalAggregator a buff
  -> UniversalAggregatorOp
univAggToOp agg = univAggToOpTyped inputType bufferType agg
  where
    inputType :: SQLType a
    inputType = buildType

    bufferType :: SQLType buff
    bufferType = buildType
-- | Converts a typed aggregator into a 'UniversalAggregatorOp', given the SQL
-- types of the input elements and of the buffer.
--
-- Both functions of the aggregator are turned into node operations
-- ('fun1ToOpTyped' / 'fun2ToOpTyped') from which the 'StandardOperator' is
-- extracted with '_unsafeExtractOp'. That extraction fails if the resulting
-- node operation is not one of the variants it accepts.
univAggToOpTyped
  :: forall a buff.
     SQLType a                   -- ^ SQL type of the input elements
  -> SQLType buff                -- ^ SQL type of the buffer
  -> UniversalAggregator a buff  -- ^ the aggregator to convert
  -> UniversalAggregatorOp
univAggToOpTyped inType bufType agg =
  UniversalAggregatorOp
    { uaoMergeType = unSQLType bufType
    , uaoInitialOuter = initialOp
    , uaoMergeBuffer = mergeOp
    }
  where
    -- Operator behind the dataset -> buffer step.
    initialOp =
      _unsafeExtractOp (fun1ToOpTyped inType (uaInitialOuter agg))
    -- Operator behind the buffer -> buffer -> buffer step; both arguments
    -- have the buffer type.
    mergeOp =
      _unsafeExtractOp (fun2ToOpTyped bufType bufType (uaMergeBuffer agg))
-- | Runs an aggregator over a column.
--
-- The column is packed into a dataset with 'pack1' and passed to the
-- aggregator's 'uaInitialOuter'. The 'uaMergeBuffer' function is not used by
-- this function.
applyUniAgg :: UniversalAggregator a b -> Column ref a -> LocalData b
applyUniAgg agg = uaInitialOuter agg . pack1
-- | Builds a one-argument node function backed by a 'StandardOperator' with
-- the given name and output type.
--
-- The operator carries no extra payload ('Null') and is wrapped as a
-- 'NodeLocalOp' before being handed to 'nodeOpToFun1Typed'.
simpleOp1Typed
  :: (IsLocality loca, IsLocality locb)
  => SQLType b           -- ^ SQL type of the result
  -> T.Text              -- ^ name of the operator
  -> ComputeNode loca a
  -> ComputeNode locb b
simpleOp1Typed outType opName =
  nodeOpToFun1Typed outType (NodeLocalOp stdOp)
  where
    stdOp = StandardOperator
      { soName = opName
      , soOutputType = unSQLType outType
      , soExtra = Null
      }
-- | Like 'simpleOp1Typed', with the output SQL type derived from the
-- 'SQLTypeable' instance of the result type.
--
-- The @SQLTypeable a@ constraint is not needed by the body; it is part of the
-- existing signature and is kept as is.
simpleOp1
  :: forall a b loca locb.
     (IsLocality loca, IsLocality locb, SQLTypeable a, SQLTypeable b)
  => T.Text              -- ^ name of the operator
  -> ComputeNode loca a
  -> ComputeNode locb b
simpleOp1 opName = simpleOp1Typed outType opName
  where
    outType :: SQLType b
    outType = buildType
-- | Like 'simpleOp2Typed', with the output SQL type derived from the
-- 'SQLTypeable' instance of the result type.
simpleOp2
  :: forall a1 a2 b loc1 loc2 locb.
     (SQLTypeable b, IsLocality loc1, IsLocality loc2, IsLocality locb)
  => T.Text               -- ^ name of the operator
  -> ComputeNode loc1 a1
  -> ComputeNode loc2 a2
  -> ComputeNode locb b
simpleOp2 opName = simpleOp2Typed outType opName
  where
    outType :: SQLType b
    outType = buildType
-- | Builds a two-argument node function backed by a 'StandardOperator' with
-- the given name and output type.
--
-- The operator carries no extra payload ('Null') and is wrapped as a
-- 'NodeLocalOp' before being handed to 'nodeOpToFun2Typed'.
simpleOp2Typed
  :: (IsLocality loc1, IsLocality loc2, IsLocality locb)
  => SQLType b            -- ^ SQL type of the result
  -> T.Text               -- ^ name of the operator
  -> ComputeNode loc1 a1
  -> ComputeNode loc2 a2
  -> ComputeNode locb b
simpleOp2Typed outType opName =
  nodeOpToFun2Typed outType (NodeLocalOp stdOp)
  where
    stdOp = StandardOperator
      { soName = opName
      , soOutputType = unSQLType outType
      , soExtra = Null
      }
-- | Pulls the 'StandardOperator' out of a node operation.
--
-- Accepts the 'NodeLocalOp', 'NodeOpaqueAggregator' and 'NodeDistributedOp'
-- variants. Any other variant is reported through 'failure', with the
-- offending value rendered via 'shown'.
_unsafeExtractOp :: (HasCallStack) => NodeOp -> StandardOperator
_unsafeExtractOp nodeOp = case nodeOp of
  NodeLocalOp so -> so
  NodeOpaqueAggregator so -> so
  NodeDistributedOp so -> so
  other -> failure (sformat ("Expected standard op, found " % shown) other)
-- | Count aggregator: the initial step is the @org.spark.Count@ operator and
-- two partial counts are merged with '+' on @LocalData Int@.
_countAgg2 :: (SQLTypeable a) => UniversalAggregator a Int
_countAgg2 = UniversalAggregator
  { uaInitialOuter = countOp
  , uaMergeBuffer = mergeCounts
  }
  where
    countOp = simpleOp1 "org.spark.Count"
    mergeCounts = (+)
-- | Sum aggregator: the initial step is the @org.spark.Sum@ operator and two
-- partial sums are merged with '+' on @LocalData a@.
_sumAgg :: forall a. (SQLTypeable a, Num a, ToSQL a) => UniversalAggregator a a
_sumAgg = UniversalAggregator
  { uaInitialOuter = sumOp
  , uaMergeBuffer = mergeSums
  }
  where
    sumOp = simpleOp1 "org.spark.Sum"
    mergeSums = (+)
-- | Collect aggregator: the initial step is the @org.spark.Collect@ operator
-- and two partial lists are merged with the @org.spark.CatSorted@ operator.
--
-- NOTE(review): the exact semantics of both operators live in the backend;
-- nothing about ordering is established in this module.
_collectAgg :: forall a. SQLTypeable a => UniversalAggregator a [a]
_collectAgg = UniversalAggregator
  { uaInitialOuter = collectOp
  , uaMergeBuffer = mergeOp
  }
  where
    collectOp = simpleOp1 "org.spark.Collect"
    mergeOp = simpleOp2 "org.spark.CatSorted"