module Spark.Core.Internal.Groups(
GroupData,
LogicalGroupData,
groupByKey,
mapGroup,
aggKey,
groupAsDS
) where
import qualified Data.Text as T
import qualified Data.Vector as V
import Formatting
import Debug.Trace(trace)
import Spark.Core.Internal.DatasetStructures
import Spark.Core.Internal.ColumnStructures
import Spark.Core.Internal.ColumnFunctions(untypedCol, colType, colOp, iUntypedColData, colOrigin, castTypeCol, dropColReference, genColOp)
import Spark.Core.Internal.DatasetFunctions
import Spark.Core.Internal.LocalDataFunctions()
import Spark.Core.Internal.FunctionsInternals
import Spark.Core.Internal.TypesFunctions(tupleType, structTypeFromFields)
import Spark.Core.Internal.OpStructures
import Spark.Core.Internal.Projections
import Spark.Core.Internal.TypesStructures
import Spark.Core.Internal.Utilities
import Spark.Core.Internal.RowStructures(Cell)
import Spark.Core.Try
import Spark.Core.StructuresInternal
import Spark.Core.Internal.CanRename
-- | A dataset split into a key column and a value column, as built by
-- 'groupByKey'.
--
-- The @key@ and @val@ type parameters are phantom (no field mentions them):
-- they carry the static types for the typed API, while the runtime types
-- are stored in the two 'GroupColumn' descriptors.
data GroupData key val = GroupData {
  -- | Dataset that both the key and the value columns are computed from.
  _gdRef :: !UntypedDataset,
  -- | Descriptor of the grouping key column.
  _gdKey :: !GroupColumn,
  -- | Descriptor of the value column (replaced by 'mapGroup' and by
  -- structured transforms replayed over the group).
  _gdValue :: !GroupColumn
  }
-- | An untyped group, or the error that prevented building it.
type LogicalGroupData = Try UntypedGroupData
-- | Runtime description of one column of a group (its key or its value).
--
-- The column is turned back into a regular column of the group's dataset
-- by '_keyCol' / '_valueCol'.
data GroupColumn = GroupColumn {
  -- | Runtime SQL type of the column.
  _gcType :: !DataType,
  -- | Operation computing the column from the group's dataset; obtained
  -- through 'colOpNoBroadcast' in '_groupCol'.
  _gcOp :: !ColOp,
  -- | Referring path restored when the column is rebuilt ('_groupCol'
  -- sets it to 'Nothing').
  _gcRefName :: !(Maybe FieldName)
  } deriving (Eq, Show)
-- | A group whose key and value types are not tracked statically: 'Cell'
-- stands in for both phantom parameters.
type UntypedGroupData = GroupData Cell Cell
-- | State threaded through '_unrollTransform' / '_unrollStep' while a
-- chain of operations is replayed on top of a group.
data PipedTrans =
    PipedError !T.Text
    -- ^ The replay failed with this message.
  | PipedDataset !UntypedDataset
    -- ^ The operations replayed so far produce a plain dataset.
  | PipedGroup !UntypedGroupData
    -- ^ The operations replayed so far are still expressed over the group.
  deriving (Show)
-- | Groups the column @vals@ by the column @keys@.
--
-- Both columns must come from the same dataset (same origin node id), and
-- their static types must match the runtime types recorded in the group.
-- Any failure aborts through 'forceRight'.
groupByKey :: (HasCallStack) => Column ref key -> Column ref val -> GroupData key val
groupByKey keys vals = forceRight $ do
  untypedGroup <- _groupByKey (iUntypedColData keys) (iUntypedColData vals)
  _castGroup (colType keys) (colType vals) untypedGroup
-- | Applies a column transformation to the values of a group; the keys and
-- the underlying dataset are kept.
--
-- The transformation must work for any column reference. It aborts through
-- 'forceRight' if the transformed column cannot be stored in a group (see
-- '_groupCol').
mapGroup :: GroupData key val -> (forall ref. Column ref val -> Column ref val') -> GroupData key val'
mapGroup g f = g { _gdValue = newValue }
  where
    transformed = f (_unsafeCastColData (_valueCol g))
    newValue = forceRight (_groupCol transformed)
-- | Aggregates the values of a group with @f@, producing a dataset of
-- @(key, aggregate)@ pairs.
--
-- @f@ must build its 'LocalData' from the value column with operations that
-- can be replayed over a group (see '_unrollStep'); otherwise this aborts
-- through 'forceRight'.
--
-- @f@ is evaluated twice: once on a reference-free view of the value column
-- ('fOut'), only to read the aggregate's type, and once inside '_aggKey' on
-- a placeholder column.
--
-- NOTE(review): the 'trace' call (and likely the 'traceHint' wrappers) look
-- like debugging leftovers; 'trace' writes "aggKey" to stderr when the
-- result is evaluated.
aggKey :: (HasCallStack) => GroupData key val -> (forall ref. Column ref val -> LocalData val') -> Dataset (key, val')
aggKey gd f = trace "aggKey" $
  let ugd = _untypedGroup gd
      -- Static types of the key and of the input values.
      keyt = traceHint "aggKey: keyt: " $ mapGroupKeys gd colType
      valt = traceHint "aggKey: valt: " $ mapGroupValues gd colType
      -- f is applied here only so that the type of its output can be read.
      fOut = traceHint "aggKey: fOut: " $ f (mapGroupValues gd dropColReference)
      valt' = traceHint "aggKey: valt': " $ nodeType fOut
      t = traceHint "aggKey: t: " $ tupleType keyt valt'
      -- Untyped wrapper around f: re-types the incoming column as the value
      -- type (castTypeCol returns a Try) before calling f.
      f' c = untypedLocalData . f <$> castTypeCol valt c
      tud = traceHint "aggKey: tud: " $ _aggKey ugd f'
      -- Re-attach the static (key, aggregate) type to the untyped result.
      res = castType' t tud
  in forceRight res
-- | Flattens a group back into a dataset of @(key, value)@ structs, built
-- from the group's key and value columns. No aggregation is performed.
groupAsDS :: forall key val. GroupData key val -> Dataset (key, val)
groupAsDS g =
  let keyColumn = _unsafeCastColData (_keyCol g) :: Column UnknownReference key
      valueColumn = _unsafeCastColData (_valueCol g) :: Column UnknownReference val
      pairs = struct (keyColumn, valueColumn) :: Column UnknownReference (key, val)
  in pack pairs
-- | Runs a reference-polymorphic function on the key column of a group,
-- viewed as a typed column.
mapGroupKeys :: GroupData key val -> (forall ref. Column ref key -> a) -> a
mapGroupKeys gd f = f . _unsafeCastColData . _keyCol $ gd
-- | Runs a reference-polymorphic function on the value column of a group,
-- viewed as a typed column.
mapGroupValues :: GroupData key val -> (forall ref. Column ref val -> a) -> a
mapGroupValues gd f = f . _unsafeCastColData . _valueCol $ gd
-- | Renders the key and value column descriptors of a group.
instance Show (GroupData key val) where
  show gd =
    let rendered = sformat ("GroupData[key="%sh%", val="%sh%"]") (_gdKey gd) (_gdValue gd)
    in T.unpack rendered
-- | Rebuilds the group's key column as a column of the grouped dataset.
_keyCol :: GroupData key val -> UntypedColumnData
_keyCol gd = _groupColumnData gd (_gdKey gd)

-- | Rebuilds the group's value column as a column of the grouped dataset.
_valueCol :: GroupData key val -> UntypedColumnData
_valueCol gd = _groupColumnData gd (_gdValue gd)

-- | Turns one 'GroupColumn' descriptor back into a column whose origin is
-- the group's dataset ('_gdRef'), restoring its type, operation and
-- referring path.
_groupColumnData :: GroupData key val -> GroupColumn -> UntypedColumnData
_groupColumnData gd gc = ColumnData {
    _cOrigin = _gdRef gd,
    _cType = _gcType gc,
    _cOp = genColOp (_gcOp gc),
    _cReferingPath = _gcRefName gc
  }
-- | Shorthand for building a failed replay state.
_pError :: T.Text -> PipedTrans
_pError msg = PipedError msg
-- | Replays, on top of @start@, the chain of single-parent operations that
-- leads from the node identified by @nid@ (in '_aggKey', the placeholder
-- standing for the group's values) down to @un@.
--
-- The walk climbs from @un@ through its parents until it reaches the node
-- with id @nid@, which is replaced by @start@. Each node on the way back
-- down is then applied with '_unrollStep'.
--
-- Two cases yield a 'PipedError':
--
--   * a node with several parents, which is not supported yet;
--   * a node with no parent: the chain ended without meeting @nid@, so the
--     result does not depend on the grouped values.
_unrollTransform :: PipedTrans -> NodeId -> UntypedNode -> PipedTrans
_unrollTransform start nid un
  | nodeId un == nid = start
  | otherwise =
      case nodeParents un of
        [p] -> _unrollStep (_unrollTransform start nid p) un
        [] ->
          _pError $ sformat (sh%": this operation does not depend on the grouped values (no parent left to unroll); it cannot be used in groups.") un
        _ ->
          _pError $ sformat (sh%": operations with multiple parents cannot be used in groups yet.") un
-- | Replays one operation node @un@ on top of the state @pt@ already
-- computed for its single parent.
--
-- Supported combinations:
--
--   * an error state is propagated unchanged;
--   * a structured transform over a dataset is re-parented onto that
--     dataset;
--   * a structured transform over a group updates the group's value column
--     ('_unrollGroupTrans');
--   * an aggregator reduction over a group becomes a grouped reduction
--     ('_applyAggOp'), unless its initial outer step is an opaque transform.
--
-- Any other combination, or a node whose parent count is not exactly one,
-- yields a 'PipedError'. The result is wrapped in 'traceHint' (presumably
-- debug output; see its definition).
_unrollStep :: PipedTrans -> UntypedNode -> PipedTrans
_unrollStep pt un = traceHint ("_unrollStep: pt=" <> show' pt <> " un=" <> show' un <> " res=") $
  let op = nodeOp un
      -- Output type of this node; used as the aggregate type below.
      dt = unSQLType (nodeType un)
  in case nodeParents un of
    [p] ->
      case (pt, op) of
        -- An earlier step already failed: keep that error.
        (PipedError e, _) -> PipedError e
        -- Plain dataset: rebuild this node with the replayed dataset as its
        -- only parent.
        (PipedDataset ds, NodeStructuredTransform _) ->
          let ds' = updateNode un (\un' -> un' { _cnParents = V.singleton (untyped ds)})
          in PipedDataset ds'
        -- Still a group: fold the transform into the group's value column.
        (PipedGroup g, NodeStructuredTransform co) ->
          _unrollGroupTrans g co
        -- Reduction over a group: turns the group back into a dataset.
        (PipedGroup g, NodeAggregatorReduction uao) ->
          case uaoInitialOuter uao of
            OpaqueAggTransform x -> _pError $ sformat ("Cannot apply opaque transform in the context of an aggregation: "%sh) x
            InnerAggOp ao ->
              PipedDataset $ _applyAggOp dt ao g
        _ -> _pError $ sformat (sh%": Operation not supported with trans="%sh%" and parents="%sh) op pt p
    l -> _pError $ sformat (sh%": expected one parent but got "%sh) un l
-- | Builds the dataset node that performs the grouped reduction @ao@.
--
-- The group's key and value columns are packed side by side into a
-- two-field dataset (fields renamed @_1@ and @_2@). That dataset becomes
-- the single parent of a 'NodeGroupedReduction' node. Its output rows are
-- structs with a @key@ field (the key column's type) and an @agg@ field of
-- type @dt@. Packing or schema failures abort through 'forceRight'.
_applyAggOp :: (HasCallStack) => DataType -> AggOp -> UntypedGroupData -> UntypedDataset
_applyAggOp dt ao ugd =
    traceHint ("_applyAggOp dt=" <> show' dt <> " ao=" <> show' ao <> " ugd=" <> show' ugd <> " res=") reduced
  where
    groupKeyCol = _keyCol ugd
    -- Rename both columns so that the packed struct has fixed field names.
    keyField = untypedCol groupKeyCol @@ T.unpack "_1"
    valueField = untypedCol (_valueCol ugd) @@ T.unpack "_2"
    pairedDs = forceRight (pack1 <$> struct' [keyField, valueField])
    -- Output schema: the key under "key", the aggregate under "agg".
    resultStruct = forceRight $ structTypeFromFields
      [(unsafeFieldName "key", unSQLType (colType groupKeyCol)), (unsafeFieldName "agg", dt)]
    resultDt = SQLType . StrictType . Struct $ resultStruct
    reduced = emptyDataset (NodeGroupedReduction ao) resultDt `parents` [untyped pairedDs]
-- | Applies the structured transform @co@ to the value column of a group.
--
-- @co@ is rewritten with '_combineColOp' against the current value
-- operation, so that it is expressed over the group's dataset. The result
-- becomes the group's new value column. Each of the three steps reports a
-- failure with its own labelled message.
--
-- NOTE(review): '_transformCol' only replaces the column operation, so the
-- new value column keeps the previous value type ('_cType'). Confirm that a
-- structured transform cannot change the value type here.
_unrollGroupTrans :: UntypedGroupData -> ColOp -> PipedTrans
_unrollGroupTrans ugd co = either _pError id $ do
    currentOp <- labelled "_unrollGroupTrans (1): using unimplemented feature:" (colOpNoBroadcast (colOp valueCol))
    combinedOp <- labelled "_unrollGroupTrans (2): failure with " (_combineColOp currentOp co)
    newValue <- labelled "_unrollGroupTrans (3): failure with " (_groupCol (_transformCol combinedOp valueCol))
    return (PipedGroup ugd { _gdValue = newValue })
  where
    valueCol = _valueCol ugd
    -- Turn a failure into the text of a PipedError, prefixed by its label.
    labelled label = either (\err -> Left (label <> show' err)) Right
-- | Replaces the operation of a column; every other field (origin, type,
-- referring path) is kept.
_transformCol :: ColOp -> UntypedColumnData -> UntypedColumnData
_transformCol co ucd = ucd { _cOp = newOp }
  where
    newOp = genColOp co
-- | Rewrites @op@ so that its field extractions are resolved inside @base@.
--
-- In '_unrollGroupTrans', @base@ is the operation currently producing the
-- group's value column:
--
--   * literals are kept as they are;
--   * function arguments and struct fields are rewritten recursively;
--   * an extraction is replaced by the sub-operation found at that path in
--     @base@ ('_extractColOp').
_combineColOp :: ColOp -> ColOp -> Try ColOp
_combineColOp base op =
    case op of
      ColLit _ _ -> pure op
      ColFunction fn args -> ColFunction fn <$> traverse (_combineColOp base) args
      ColExtraction fp -> _extractColOp base (V.toList (unFieldPath fp))
      ColStruct fields -> ColStruct <$> traverse rewriteField fields
  where
    rewriteField (TransformField n v) = TransformField n <$> _combineColOp base v
-- | Follows a field path through nested 'ColStruct' operations and returns
-- the operation stored at the end of the path.
--
-- An empty path returns the operation unchanged. Two cases are errors:
-- naming a field that the struct does not contain, and extracting from
-- anything other than a 'ColStruct'.
_extractColOp :: ColOp -> [FieldName] -> Try ColOp
_extractColOp op path =
  case (op, path) of
    (_, []) -> pure op
    (ColStruct fields, wanted : rest) ->
      case V.find ((== wanted) . tfName) fields of
        Just (TransformField _ inner) -> _extractColOp inner rest
        Nothing ->
          tryError $ sformat ("Expected to find field "%sh%" in structure "%sh) wanted fields
    _ ->
      tryError $ sformat ("Cannot perform extraction "%sh%" on column operation "%sh) path op
-- | Untyped core of 'aggKey'.
--
-- A placeholder dataset with the value column's type stands in for the
-- group's values, and @f@ is applied to it. The chain of operations between
-- the placeholder and @f@'s result is then replayed over the group with
-- '_unrollTransform'. The replay must end on a dataset; ending on a group
-- or on a replay error is reported as a failure.
_aggKey :: UntypedGroupData -> (UntypedColumnData -> Try UntypedLocalData) -> Try UntypedDataset
_aggKey ugd f = do
  let valueDt = unSQLType . colType . _valueCol $ ugd
      holder = placeholder valueDt :: UntypedDataset
  uld <- f (_unsafeCastColData (asCol holder))
  case _unrollTransform (PipedGroup ugd) (nodeId holder) (untyped uld) of
    PipedDataset ds -> pure ds
    PipedGroup g ->
      tryError $ sformat ("Expected a dataframe at the output but got a group: "%sh) g
    PipedError t -> tryError t
-- | Reinterprets a column's reference and value type parameters without
-- touching its contents.
--
-- The record update keeps every field as it is; it exists only so that the
-- result can be given new type parameters. Nothing is checked at runtime:
-- callers must make sure that the runtime type ('_cType') matches the type
-- they cast to.
_unsafeCastColData :: Column ref a -> Column ref' a'
_unsafeCastColData c = c { _cType = _cType c }
-- | Re-attaches static key/value types to an untyped group.
--
-- The requested SQL types must equal the runtime types of the group's key
-- and value columns; the key is checked first. On a mismatch, the error
-- names the offending column (key or value). On success the same group is
-- returned with only its phantom type parameters changed.
_castGroup ::
  SQLType key -> SQLType val -> UntypedGroupData -> Try (GroupData key val)
_castGroup (SQLType keyType) (SQLType valType) ugd
  | keyType /= keyType' =
      tryError $ sformat ("The key column (of type "%sh%") cannot be cast to type "%sh) keyType' keyType
  | valType /= valType' =
      tryError $ sformat ("The value column (of type "%sh%") cannot be cast to type "%sh) valType' valType
  | otherwise =
      -- The record update rebuilds the same group with new phantom types.
      pure ugd { _gdRef = _gdRef ugd }
  where
    keyType' = unSQLType . colType . _keyCol $ ugd
    valType' = unSQLType . colType . _valueCol $ ugd
-- | Forgets the static key/value types of a group; its runtime contents
-- are unchanged.
_untypedGroup :: GroupData key val -> UntypedGroupData
_untypedGroup (GroupData ref keyGc valueGc) = GroupData ref keyGc valueGc
-- | Untyped core of 'groupByKey'.
--
-- Fails unless both columns come from the same origin node. The two
-- columns are packed into a two-field dataset, and the group refers to that
-- dataset: the key is its first field and the value its second.
_groupByKey :: UntypedColumnData -> UntypedColumnData -> LogicalGroupData
_groupByKey keys vals
  | nodeId (colOrigin keys) == nodeId (colOrigin vals) = do
      gKeys <- _groupCol keyProj
      gVals <- _groupCol valProj
      return GroupData {
          _gdRef = colOrigin keyProj,
          _gdKey = gKeys,
          _gdValue = gVals
        }
  | otherwise =
      tryError $ sformat ("The columns have different origin: "%sh%" and "%sh) keys vals
  where
    paired = pack1 (struct (keys, vals) :: Column UnknownReference (Cell, Cell))
    keyProj = paired // _1
    valProj = paired // _2
-- | Captures the runtime type and operation of a column as a 'GroupColumn'.
--
-- Fails when 'colOpNoBroadcast' rejects the column's operation. The
-- referring path is always left empty.
_groupCol :: Column ref a -> Try GroupColumn
_groupCol c = describe <$> colOpNoBroadcast (colOp c)
  where
    describe op' = GroupColumn {
        _gcType = unSQLType (colType c),
        _gcOp = op',
        _gcRefName = Nothing
      }