-- | Grouped data: a dataset split into a key column and a value column.
--
-- Exposes the 'GroupData' type together with the operations to build a group
-- ('groupByKey'), transform its values ('mapGroup'), aggregate the values of
-- each key ('aggKey') and view the group as a dataset of pairs ('groupAsDS').
module Spark.Core.Internal.Groups(
GroupData,
LogicalGroupData,
groupByKey,
mapGroup,
aggKey,
groupAsDS
) where
import qualified Data.Text as T
import qualified Data.Vector as V
import Formatting
import Debug.Trace(trace)
import Spark.Core.Internal.DatasetStructures
import Spark.Core.Internal.ColumnStructures
import Spark.Core.Internal.ColumnFunctions(untypedCol, colType, colOp, iUntypedColData, colOrigin, castTypeCol, dropColReference)
import Spark.Core.Internal.DatasetFunctions
import Spark.Core.Internal.LocalDataFunctions()
import Spark.Core.Internal.FunctionsInternals
import Spark.Core.Internal.TypesFunctions(tupleType, structTypeFromFields)
import Spark.Core.Internal.OpStructures
import Spark.Core.Internal.TypesStructures
import Spark.Core.Internal.Utilities
import Spark.Core.Internal.RowStructures(Cell)
import Spark.Core.Try
import Spark.Core.StructuresInternal
-- | A pair of columns (keys and values) that belong to the same dataset.
--
-- The @key@ and @val@ type parameters do not appear in any field: they only
-- tag the group at the type level.
data GroupData key val = GroupData
  { _gdRef :: !UntypedDataset
    -- ^ Dataset the group refers to ('_groupByKey' stores the origin of the
    -- key column here).
  , _gdKey :: !UntypedColumnData
    -- ^ Untyped column holding the keys.
  , _gdValue :: !UntypedColumnData
    -- ^ Untyped column holding the values.
  }
-- | An untyped group wrapped in 'Try', i.e. a group whose construction may
-- have failed.
type LogicalGroupData = Try UntypedGroupData
-- | A group whose key and value type parameters are both 'Cell'.
type UntypedGroupData = GroupData Cell Cell
-- | Either a 'T.Text' (presumably an error message -- confirm with callers)
-- or a value of type @a@.
type GroupTry a = Either T.Text a
-- | State threaded through the unrolling of a chain of nodes over a group
-- (see '_unrollTransform' and '_unrollStep').
data PipedTrans
  = PipedError !T.Text
    -- ^ The unrolling failed; carries the error message.
  | PipedDataset !UntypedDataset
    -- ^ The chain currently yields a dataset.
  | PipedGroup !UntypedGroupData
    -- ^ The chain currently yields grouped data.
  deriving (Show)
-- | Builds a group from a column of keys and a column of values.
--
-- The untyped group is built with '_groupByKey' (which fails when the two
-- columns have different origins), checked against the expected key and value
-- types with '_castGroup', and the outcome is unwrapped with 'forceRight'.
groupByKey :: (HasCallStack) => Column ref key -> Column ref val -> GroupData key val
groupByKey keys vals = forceRight typedGroup
  where
    untypedGroup = _groupByKey (iUntypedColData keys) (iUntypedColData vals)
    typedGroup = untypedGroup >>= _castGroup (colType keys) (colType vals)
-- | Applies a column transformation to the values of a group; the reference
-- dataset and the keys are left as they are.
mapGroup :: GroupData key val -> (forall ref. Column ref val -> Column ref val') -> GroupData key val'
mapGroup g f = g { _gdValue = iUntypedColData (f typedValues) }
  where
    -- The stored value column is untyped: re-tag it so that 'f' accepts it.
    typedValues = _unsafeCastColData (_gdValue g)
-- | Reduces the values of a group with the given aggregation and returns a
-- dataset of (key, aggregated value) pairs.
--
-- The aggregation is first applied to the (reference-less) value column to
-- discover its output type, then run in untyped form through '_aggKey'. The
-- untyped result is cast to the expected tuple type and unwrapped with
-- 'forceRight'. Intermediate values are traced for debugging.
aggKey :: (HasCallStack) => GroupData key val -> (forall ref. Column ref val -> LocalData val') -> Dataset (key, val')
aggKey gd f = trace "aggKey" (forceRight typedResult)
  where
    untypedGrp = _untypedGroup gd
    keyType = traceHint "aggKey: keyt: " $ mapGroupKeys gd colType
    inValueType = traceHint "aggKey: valt: " $ mapGroupValues gd colType
    -- Sample application of the aggregation, only used to read its type.
    reduced = traceHint "aggKey: fOut: " $ f (mapGroupValues gd dropColReference)
    outValueType = traceHint "aggKey: valt': " $ nodeType reduced
    resultType = traceHint "aggKey: t: " $ tupleType keyType outValueType
    -- Untyped wrapper around the aggregation: cast the input, then erase
    -- the output type.
    untypedAgg col = untypedLocalData . f <$> castTypeCol inValueType col
    untypedResult = traceHint "aggKey: tud: " $ _aggKey untypedGrp untypedAgg
    typedResult = castType' resultType untypedResult
-- | Views a group as a dataset of (key, value) pairs by packing the key and
-- value columns into a single struct column.
groupAsDS :: forall key val. GroupData key val -> Dataset (key, val)
groupAsDS g = pack pairs
  where
    keyCol = _unsafeCastColData (_gdKey g) :: Column UnknownReference key
    valCol = _unsafeCastColData (_gdValue g) :: Column UnknownReference val
    pairs = struct (keyCol, valCol) :: Column UnknownReference (key, val)
-- | Runs a function over the key column of a group, re-tagged with the
-- group's key type.
mapGroupKeys :: GroupData key val -> (forall ref. Column ref key -> a) -> a
mapGroupKeys gd f = f typedKeys
  where
    typedKeys = _unsafeCastColData (_gdKey gd)
-- | Runs a function over the value column of a group, re-tagged with the
-- group's value type.
mapGroupValues :: GroupData key val -> (forall ref. Column ref val -> a) -> a
mapGroupValues gd f = f typedValues
  where
    typedValues = _unsafeCastColData (_gdValue gd)
-- | Renders the key and value columns of the group; the reference dataset is
-- not shown.
instance Show (GroupData key val) where
  show gd =
    T.unpack $
      sformat ("GroupData[key="%sh%", val="%sh%"]") (_gdKey gd) (_gdValue gd)
-- | Applies a structured (column) transform to a group.
--
-- Not implemented yet: always returns a 'Left' describing the situation
-- instead of crashing with an uninformative @Prelude.undefined@.
_mapStructuredTransform :: ColOp -> LogicalGroupData -> GroupTry LogicalGroupData
_mapStructuredTransform _ _ =
  Left (T.pack "_mapStructuredTransform: not implemented yet")
-- | Applies an aggregation transform to a group.
--
-- Not implemented yet: always returns a 'Left' describing the situation
-- instead of crashing with an uninformative @Prelude.undefined@.
_mapAggTransform :: AggTransform -> LogicalGroupData -> GroupTry LogicalGroupData
_mapAggTransform _ _ =
  Left (T.pack "_mapAggTransform: not implemented yet")
-- | Wraps an error message into a failed 'PipedTrans'.
_pError :: T.Text -> PipedTrans
_pError msg = PipedError msg
-- | Replays the chain of nodes leading to @un@ on top of the starting state.
--
-- The walk goes up through single parents until it reaches the node whose id
-- is @nid@, which stands for the starting state; every node on the way back
-- is then applied with '_unrollStep'. A node that does not have exactly one
-- parent yields an error.
_unrollTransform :: PipedTrans -> NodeId -> UntypedNode -> PipedTrans
_unrollTransform start nid un
  | nodeId un == nid = start
  | otherwise =
      case nodeParents un of
        [p] -> _unrollStep (_unrollTransform start nid p) un
        _ ->
          _pError $ sformat (sh%": operations with multiple parents cannot be used in groups yet.") un
-- | Applies the operation of a single node to the current piped state.
--
-- Errors are propagated unchanged. A structured transform is re-parented onto
-- a piped dataset, or folded into the value column of a piped group. An
-- aggregator reduction with an inner aggregation turns a piped group into a
-- dataset. Every other combination is reported as an error. The inputs and
-- the result are traced for debugging.
_unrollStep :: PipedTrans -> UntypedNode -> PipedTrans
_unrollStep pt un = traceHint hint result
  where
    hint = "_unrollStep: pt=" <> show' pt <> " un=" <> show' un <> " res="
    op = nodeOp un
    dt = unSQLType (nodeType un)

    result = case nodeParents un of
      [p] -> applyWithParent p
      l -> _pError $ sformat (sh%": expected one parent but got "%sh) un l

    -- The parent itself is only needed to build the fallback error message.
    applyWithParent p = case (pt, op) of
      (PipedError e, _) -> PipedError e
      (PipedDataset ds, NodeStructuredTransform _) ->
        PipedDataset $
          updateNode un (\un' -> un' { _cnParents = V.singleton (untyped ds) })
      (PipedGroup g, NodeStructuredTransform co) -> _unrollGroupTrans g co
      (PipedGroup g, NodeAggregatorReduction uao) ->
        case uaoInitialOuter uao of
          OpaqueAggTransform x ->
            _pError $ sformat ("Cannot apply opaque transform in the context of an aggregation: "%sh) x
          InnerAggOp ao -> PipedDataset $ _applyAggOp dt ao g
      _ ->
        _pError $ sformat (sh%": Operation not supported with trans="%sh%" and parents="%sh) op pt p
-- | Builds the dataset that results from reducing a group with an
-- aggregation operation.
--
-- The key and value columns are packed into a two-field struct dataset
-- (fields @_1@ and @_2@), which becomes the single parent of a new
-- 'NodeGroupedReduction' node. The type of that node is a struct with fields
-- @key@ (the type of the key column) and @agg@ (the given data type). The
-- inputs and the result are traced for debugging.
_applyAggOp :: (HasCallStack) => DataType -> AggOp -> UntypedGroupData -> UntypedDataset
_applyAggOp dt ao ugd = traceHint hint reduced
  where
    hint = "_applyAggOp dt=" <> show' dt <> " ao=" <> show' ao <> " ugd=" <> show' ugd <> " res="
    keyCol = untypedCol (_gdKey ugd) @@ T.unpack "_1"
    valCol = untypedCol (_gdValue ugd) @@ T.unpack "_2"
    packedTry = pack1 <$> struct' [keyCol, valCol]
    packed = forceRight packedTry
    keyDt = unSQLType (colType (_gdKey ugd))
    outFields = [(unsafeFieldName "key", keyDt), (unsafeFieldName "agg", dt)]
    outStructTry = structTypeFromFields outFields
    outStruct = forceRight outStructTry
    outType = SQLType . StrictType . Struct $ outStruct
    reduced = emptyDataset (NodeGroupedReduction ao) outType `parents` [untyped packed]
-- | Folds a column operation into the value column of a group.
--
-- The operation is combined with the current operation of the value column
-- using '_combineColOp'; a failure of that combination becomes a piped error.
_unrollGroupTrans :: UntypedGroupData -> ColOp -> PipedTrans
_unrollGroupTrans ugd co = either failed combined (_combineColOp (colOp values) co)
  where
    values = _gdValue ugd
    failed err = _pError $ "_unrollGroupTrans: failure with " <> show' err
    combined co' = PipedGroup $ ugd { _gdValue = _transformCol co' values }
-- | Replaces the operation of a column, keeping every other field.
_transformCol :: ColOp -> UntypedColumnData -> UntypedColumnData
_transformCol newOp column = column { _cOp = newOp }
-- | Rewrites the second column operation so that it applies on top of the
-- first one.
--
-- * A literal does not depend on its input and is returned unchanged.
-- * The arguments of a function are combined recursively.
-- * An extraction is resolved against the first operation with
--   '_extractColOp'.
-- * The fields of a struct are combined recursively, keeping their names.
--
-- The first failure encountered is returned.
_combineColOp :: ColOp -> ColOp -> Try ColOp
-- The as-pattern is written without spaces around @: GHC >= 9.0 parses a
-- spaced `x @ p` as an operator and rejects it in pattern position.
_combineColOp _ x@(ColLit _ _) = pure x
_combineColOp x (ColFunction fn v) =
  ColFunction fn <$> traverse (_combineColOp x) v
_combineColOp x (ColExtraction fp) = _extractColOp x (V.toList (unFieldPath fp))
_combineColOp x (ColStruct v) =
  ColStruct <$> traverse combineField v
  where
    combineField (TransformField n val) = TransformField n <$> _combineColOp x val
-- | Follows a path of field names inside a column operation.
--
-- An empty path returns the operation itself. On a struct, the first name of
-- the path selects the field to descend into; a missing field is an error.
-- A non-empty path on any other kind of operation is an error.
_extractColOp :: ColOp -> [FieldName] -> Try ColOp
_extractColOp x [] = pure x
_extractColOp (ColStruct s) (fn : t) =
  maybe missing descend (V.find ((== fn) . tfName) s)
  where
    descend (TransformField _ co) = _extractColOp co t
    missing =
      tryError $ sformat ("Expected to find field "%sh%" in structure "%sh) fn s
_extractColOp x y =
  tryError $ sformat ("Cannot perform extraction "%sh%" on column operation "%sh) y x
-- | Untyped aggregation of the values of a group.
--
-- The aggregation function is applied to the column of a placeholder dataset
-- that has the type of the value column. The nodes this produces are then
-- replayed on the group with '_unrollTransform', the placeholder standing for
-- the group itself. The replay must end on a dataset: an error or a remaining
-- group is reported as a failure.
_aggKey :: UntypedGroupData -> (UntypedColumnData -> Try UntypedLocalData) -> Try UntypedDataset
_aggKey ugd f = do
  uld <- f (_unsafeCastColData (asCol start))
  finish (_unrollTransform (PipedGroup ugd) startNid (untyped uld))
  where
    valueDt = unSQLType . colType . _gdValue $ ugd
    start = placeholder valueDt :: UntypedDataset
    startNid = nodeId start
    finish (PipedDataset ds) = pure ds
    finish (PipedGroup g) =
      tryError $ sformat ("Expected a dataframe at the output but got a group: "%sh) g
    finish (PipedError t) = tryError t
-- | Re-tags a column with arbitrary type parameters, without any check.
--
-- The record update writes back the value '_cType' already had, so the
-- content of the column is unchanged; only its type-level tags differ.
_unsafeCastColData :: Column ref a -> Column ref' a'
_unsafeCastColData column = column { _cType = sameType }
  where
    sameType = _cType column
-- | Gives an untyped group the requested key and value types, after checking
-- that they match the types recorded in its key and value columns.
--
-- The key type is checked first, then the value type; the first mismatch is
-- returned as an error naming the offending column.
_castGroup ::
  SQLType key -> SQLType val -> UntypedGroupData -> Try (GroupData key val)
_castGroup (SQLType keyType) (SQLType valType) ugd
  | keyType /= keyType' =
      -- This message used to say "value column" for a key mismatch.
      tryError $ sformat ("The key column (of type "%sh%") cannot be cast to type "%sh) keyType' keyType
  | valType /= valType' =
      tryError $ sformat ("The value column (of type "%sh%") cannot be cast to type "%sh) valType' valType
  | otherwise =
      -- The record update keeps the same data and only changes the
      -- type-level key/value tags of the group.
      pure ugd { _gdRef = _gdRef ugd }
  where
    keyType' = unSQLType . colType . _gdKey $ ugd
    valType' = unSQLType . colType . _gdValue $ ugd
-- | Forgets the key and value types of a group.
--
-- The record update writes back the reference the group already had: the
-- data is unchanged, only the type-level tags become 'Cell'.
_untypedGroup :: GroupData key val -> UntypedGroupData
_untypedGroup gd = gd { _gdRef = sameRef }
  where
    sameRef = _gdRef gd
-- | Builds an untyped group from a key column and a value column.
--
-- Both columns must have the same origin (their origin nodes are compared by
-- node id); otherwise an error is returned. The origin of the key column is
-- stored as the reference dataset of the group.
_groupByKey :: UntypedColumnData -> UntypedColumnData -> LogicalGroupData
_groupByKey keys vals
  | sameOrigin =
      pure GroupData
        { _gdRef = colOrigin keys
        , _gdKey = keys
        , _gdValue = vals
        }
  | otherwise =
      tryError $ sformat ("The columns have different origin: "%sh%" and "%sh) keys vals
  where
    sameOrigin = nodeId (colOrigin keys) == nodeId (colOrigin vals)