module Database.Sql.Util.Joins (HasJoins(..), JoinsResult) where
import Database.Sql.Type
import qualified Data.Map as M
import Data.Map (Map)
import qualified Data.Set as S
import Data.Set (Set)
import Data.Semigroup
import Data.Functor.Identity
import Data.Foldable
import Control.Monad (void, when)
import Control.Monad.Writer (Writer, execWriter, tell)
-- | Output accumulated by the 'Writer' while walking a statement.
data Result = Result
    { resultBindings :: Map ColumnAliasId (Map (RColumnRef ()) FieldChain)
      -- ^ for each column alias, the columns (and field chains) returned by
      -- the expression that the alias was bound to
    , resultColumns :: Set (Map (RColumnRef ()) FieldChain)
      -- ^ groups of columns that were emitted together as joined
    }

-- | Results combine field-wise. The 'Semigroup' instance is required because
-- 'Semigroup' is a superclass of 'Monoid' in current versions of base.
instance Semigroup Result where
    Result bindings columns <> Result bindings' columns' =
        Result (bindings <> bindings') (columns <> columns')

instance Monoid Result where
    mempty = Result mempty mempty
    mappend = (<>)
-- | A pair of joined columns. Each side is a fully-qualified column name
-- together with a (possibly empty) list of struct field names.
type Join = ((FullyQualifiedColumnName, [StructFieldName ()]), (FullyQualifiedColumnName, [StructFieldName ()]))
-- | The monad used by the traversal: a 'Writer' that accumulates a 'Result'.
type Scoped a = Writer Result a
-- | Alias for a set of 'Join's (the same type that 'getJoins' returns).
type JoinsResult = Set Join
-- | Types from which a set of column-to-column joins can be computed.
class HasJoins q where
    -- | Compute the set of joins for the given value.
    getJoins :: q -> JoinsResult
-- | Joins of a resolved statement.
--
-- The statement is walked with 'getJoinsStatement'; every emitted column
-- group then has its aliases replaced by the columns they were bound to, and
-- each group is expanded into pairs of columns whose tables differ.
instance HasJoins (Statement d ResolvedNames a) where
    getJoins stmt = S.fromList $ concatMap toPairs $ S.toList resolved
      where
        Result{..} = execWriter $ getJoinsStatement stmt

        -- emitted column groups, with aliases resolved away
        resolved = S.map unalias resultColumns

        -- Replace alias references by the (recursively unaliased) columns
        -- recorded for that alias; aliases with no binding are dropped.
        unalias :: Map (RColumnRef ()) FieldChain -> Map (FQColumnName ()) FieldChain
        unalias = M.fromList . concatMap resolve . M.toList
          where
            resolve (RColumnRef fqcn, chain) = [(fqcn, chain)]
            resolve (RColumnAlias (ColumnAlias _ _ aliasId), _) =
                maybe [] (M.toList . unalias) $ M.lookup aliasId resultBindings

        -- Pair the smallest column with every later column that belongs to a
        -- different table, then continue with the remaining columns.
        toPairs cols = case M.minViewWithKey cols of
            Nothing -> []
            Just ((c@(QColumnName _ (Identity table) _), chain), rest) ->
                [ ((fqcnToFQCN c, fields), (fqcnToFQCN c', fields'))
                | (c'@(QColumnName _ (Identity table') _), chain') <- M.toList rest
                , fields <- expandChain chain
                , fields' <- expandChain chain'
                , table /= table'
                ] ++ toPairs rest

        -- Every root-to-leaf path of field names in a chain; an empty chain
        -- yields the single empty path.
        expandChain (FieldChain m)
            | M.null m = [[]]
            | otherwise = concatMap (\ (k, v) -> map (k :) (expandChain v)) (M.toList m)
-- | Walk a statement. Statements that carry a query or expressions are
-- traversed; every other kind of statement contributes nothing.
getJoinsStatement :: Statement d ResolvedNames a -> Scoped ()
getJoinsStatement = \case
    -- statements that are traversed
    QueryStmt query -> getJoinsQuery query
    InsertStmt insert -> getJoinsInsert insert
    UpdateStmt update -> getJoinsUpdate update
    DeleteStmt delete -> getJoinsDelete delete
    CreateTableStmt create -> getJoinsCreateTable create
    CreateViewStmt create -> getJoinsQuery $ createViewQuery create
    -- statements that contribute nothing
    TruncateStmt _ -> pure ()
    AlterTableStmt _ -> pure ()
    DropTableStmt _ -> pure ()
    DropViewStmt _ -> pure ()
    CreateSchemaStmt _ -> pure ()
    GrantStmt _ -> pure ()
    RevokeStmt _ -> pure ()
    BeginStmt _ -> pure ()
    CommitStmt _ -> pure ()
    RollbackStmt _ -> pure ()
    ExplainStmt _ _ -> pure ()
    EmptyStmt _ -> pure ()
-- | The output columns of a query. For set operations this is the columns of
-- the left-hand query; wrappers (WITH, ORDER, LIMIT, OFFSET) defer to the
-- query they wrap.
queryColumns :: Query ResolvedNames a -> [RColumnRef a]
queryColumns = \case
    QuerySelect _ Select{selectCols = SelectColumns _ selections} ->
        concatMap selectionColumns selections
    QueryExcept _ _ lhs _ -> queryColumns lhs
    QueryUnion _ _ _ lhs _ -> queryColumns lhs
    QueryIntersect _ _ lhs _ -> queryColumns lhs
    QueryWith _ _ inner -> queryColumns inner
    QueryOrder _ _ inner -> queryColumns inner
    QueryLimit _ _ inner -> queryColumns inner
    QueryOffset _ _ inner -> queryColumns inner
  where
    -- an expression contributes its aliases, a star its expanded columns
    selectionColumns (SelectExpr _ aliases _) = map RColumnAlias aliases
    selectionColumns (SelectStar _ _ (StarColumnNames cols)) = cols
-- | Walk the table definition of a CREATE TABLE statement.
getJoinsCreateTable :: CreateTable d ResolvedNames a -> Scoped ()
getJoinsCreateTable CreateTable{createTableDefinition = definition} =
    getJoinsTableDefinition definition
-- | Walk a table definition. Only a definition given by a query (@TableAs@)
-- is traversed; the other forms contribute nothing.
getJoinsTableDefinition :: TableDefinition d ResolvedNames a -> Scoped ()
getJoinsTableDefinition = \case
    TableAs _ _ query -> getJoinsQuery query
    TableColumns _ _ -> pure ()
    TableLike _ _ -> pure ()
    TableNoColumnInfo _ -> pure ()
-- | Walk the values of an INSERT: explicit value rows and a source query are
-- traversed, default values and file sources contribute nothing.
getJoinsInsert :: Insert ResolvedNames a -> Scoped ()
getJoinsInsert Insert{insertValues = source} = case source of
    InsertSelectValues query -> getJoinsQuery query
    InsertExprValues _ rows -> forM_ rows $ mapM_ getJoinsDefaultExpr
    InsertDefaultValues _ -> pure ()
    InsertDataFromFile _ _ -> pure ()
-- | Walk a value that is either DEFAULT (nothing to do) or an expression.
getJoinsDefaultExpr :: DefaultExpr ResolvedNames a -> Scoped ()
getJoinsDefaultExpr = \case
    ExprValue expr -> void $ getJoinsExpr expr
    DefaultValue _ -> pure ()
-- | Walk an UPDATE: the assigned values, then the optional FROM, then the
-- optional WHERE condition.
getJoinsUpdate :: Update ResolvedNames a -> Scoped ()
getJoinsUpdate Update{updateSetExprs = assignments, updateFrom = from, updateWhere = condition} = do
    forM_ assignments $ \ assignment -> getJoinsDefaultExpr (snd assignment)
    forM_ from getJoinsTablish
    forM_ condition getJoinsExpr
-- | Walk the condition of a DELETE, if it has one.
getJoinsDelete :: Delete ResolvedNames a -> Scoped ()
getJoinsDelete (Delete _ _ condition) = mapM_ getJoinsExpr condition
-- | Emit each positionally corresponding pair of output columns of the two
-- queries as joined. Surplus columns on either side are ignored.
zipColumns :: Query ResolvedNames a -> Query ResolvedNames a -> Scoped ()
zipColumns lhs rhs = sequence_ $ zipWith joinPair (queryColumns lhs) (queryColumns rhs)
  where
    joinPair lcol rcol = emit $ M.fromList [(void lcol, noFields), (void rcol, noFields)]
    noFields = FieldChain M.empty
-- | Walk a query. For set operations both sides are walked and their output
-- columns are then joined position by position.
getJoinsQuery :: Query ResolvedNames a -> Scoped ()
getJoinsQuery = \case
    QuerySelect _ select -> getJoinsSelect select
    QueryExcept _ _ lhs rhs -> setOperation lhs rhs
    QueryUnion _ _ _ lhs rhs -> setOperation lhs rhs
    QueryIntersect _ _ lhs rhs -> setOperation lhs rhs
    QueryWith _ ctes query -> mapM_ getJoinsCTE ctes *> getJoinsQuery query
    QueryOrder _ orders query -> mapM_ getJoinsOrder orders *> getJoinsQuery query
    QueryLimit _ _ query -> getJoinsQuery query
    QueryOffset _ _ query -> getJoinsQuery query
  where
    -- shared handling of EXCEPT / UNION / INTERSECT
    setOperation lhs rhs = do
        getJoinsQuery lhs
        getJoinsQuery rhs
        zipColumns lhs rhs
-- | Walk every clause of a SELECT: the columns first, then each optional
-- clause that is present.
getJoinsSelect :: Select ResolvedNames a -> Scoped ()
getJoinsSelect select = do
    getJoinsSelectCols (selectCols select)
    forM_ (selectFrom select) getJoinsSelectFrom
    forM_ (selectWhere select) getJoinsSelectWhere
    forM_ (selectTimeseries select) getJoinsSelectTimeseries
    forM_ (selectGroup select) getJoinsSelectGroup
    forM_ (selectHaving select) getJoinsSelectHaving
    forM_ (selectNamedWindow select) getJoinsSelectNamedWindow
-- | Walk every table-like item of a FROM clause.
getJoinsSelectFrom :: SelectFrom ResolvedNames a -> Scoped ()
getJoinsSelectFrom (SelectFrom _ tablishes) = forM_ tablishes getJoinsTablish
-- | Walk every selection in the column list of a SELECT.
getJoinsSelectCols :: SelectColumns ResolvedNames a -> Scoped ()
getJoinsSelectCols (SelectColumns _ selections) = forM_ selections getJoinsSelection
-- | Walk the condition of a WHERE clause; the returned columns are discarded.
getJoinsSelectWhere :: SelectWhere ResolvedNames a -> Scoped ()
getJoinsSelectWhere (SelectWhere _ condition) = () <$ getJoinsExpr condition
-- | Walk a TIMESERIES clause: its optional partition, then its expression.
getJoinsSelectTimeseries :: SelectTimeseries ResolvedNames a -> Scoped ()
getJoinsSelectTimeseries (SelectTimeseries _ _ _ partition expr) =
    mapM_ getJoinsPartition partition *> void (getJoinsExpr expr)
-- | Walk the expression form; the positional form contributes nothing.
getJoinsPositionOrExpr :: PositionOrExpr ResolvedNames a -> Scoped ()
getJoinsPositionOrExpr = \case
    PositionOrExprExpr expr -> void $ getJoinsExpr expr
    PositionOrExprPosition{} -> pure ()
-- | Walk one GROUP BY element: either a single position-or-expression or a
-- set of expressions.
getJoinsGroupingElement :: GroupingElement ResolvedNames a -> Scoped ()
getJoinsGroupingElement = \case
    GroupingElementExpr _ posOrExpr -> getJoinsPositionOrExpr posOrExpr
    GroupingElementSet _ exprs -> forM_ exprs getJoinsExpr
-- | Walk every grouping element of a GROUP BY clause.
getJoinsSelectGroup :: SelectGroup ResolvedNames a -> Scoped ()
getJoinsSelectGroup (SelectGroup _ elements) = forM_ elements getJoinsGroupingElement
-- | Walk every condition of a HAVING clause.
getJoinsSelectHaving :: SelectHaving ResolvedNames a -> Scoped ()
getJoinsSelectHaving (SelectHaving _ conditions) = forM_ conditions getJoinsExpr
-- | Walk every window definition of a named-window clause, whether it is a
-- full or a partial window expression.
getJoinsSelectNamedWindow :: SelectNamedWindow ResolvedNames a -> Scoped ()
getJoinsSelectNamedWindow (SelectNamedWindow _ windows) = forM_ windows $ \case
    NamedWindowExpr _ _ windowExpr -> getJoinsWindowExpr windowExpr
    NamedPartialWindowExpr _ _ partial -> getJoinsPartialWindowExpr partial
-- | Record a group of columns as joined together.
emit :: Map (RColumnRef ()) FieldChain -> Scoped ()
emit cols = tell Result
    { resultBindings = M.empty
    , resultColumns = S.singleton cols
    }
-- | Record the columns that a column alias stands for.
bind :: ColumnAliasId -> Map (RColumnRef ()) FieldChain -> Scoped ()
bind alias cols = tell Result
    { resultBindings = M.singleton alias cols
    , resultColumns = S.empty
    }
-- | The column a subquery contributes when it is used as a single value.
--
-- A subquery exposing exactly one output column yields that column (with an
-- empty field chain). Any other column count yields no columns, so that
-- malformed or unexpectedly-shaped subqueries do not crash the analysis with
-- an irrefutable-pattern failure.
singleSubqueryColumn :: Query ResolvedNames a -> Map (RColumnRef ()) FieldChain
singleSubqueryColumn query = case queryColumns query of
    [column] -> M.singleton (void column) (FieldChain M.empty)
    _ -> M.empty

-- | Walk an expression.
--
-- Returns the columns the expression refers to, each mapped to the chain of
-- struct fields accessed on it. As a side effect, groups of columns are
-- emitted as joined for comparison operators, LIKE, IN-list and IN-subquery
-- expressions, and nested queries are walked.
getJoinsExpr :: Expr ResolvedNames a -> Scoped (Map (RColumnRef ()) FieldChain)
getJoinsExpr (BinOpExpr _ op lhs rhs) = do
    lcols <- getJoinsExpr lhs
    rcols <- getJoinsExpr rhs
    let allcols = M.unionWith (<>) lcols rcols
    -- only these comparison operators relate their operands as a join
    when (op `elem` ["=", "!=", "<>", "<=>", "==", "<", ">", "<=", ">="]) $ do
        emit allcols
    return allcols
getJoinsExpr (CaseExpr _ cases else_) = do
    -- conditions are walked for their effects; only the result branches
    -- contribute to the returned columns
    cols <- mapM (\ (when_, then_) -> getJoinsExpr when_ *> getJoinsExpr then_) cases
    col <- maybe (pure M.empty) getJoinsExpr else_
    return $ M.unionsWith (<>) $ col : cols
getJoinsExpr (LikeExpr _ _ escape pattern expr) = do
    void $ maybe (pure mempty) (getJoinsExpr . escapeExpr) escape
    lcols <- getJoinsExpr $ patternExpr pattern
    rcols <- getJoinsExpr expr
    let allcols = M.unionWith (<>) lcols rcols
    emit allcols
    return allcols
getJoinsExpr (UnOpExpr _ _ expr) = getJoinsExpr expr
getJoinsExpr (ConstantExpr _ _) = return M.empty
getJoinsExpr (ColumnExpr _ column) = return $ M.singleton (void column) $ FieldChain M.empty
getJoinsExpr (InListExpr _ exprs expr) = do
    cols <- M.unionsWith (<>) <$> mapM getJoinsExpr (expr:exprs)
    emit cols
    return cols
getJoinsExpr (InSubqueryExpr _ query expr) = do
    getJoinsQuery query
    columns <- getJoinsExpr expr
    -- left-biased union: the subquery's column entry (empty field chain)
    -- replaces any existing entry for the same column
    let columns' = M.union (singleSubqueryColumn query) columns
    emit columns'
    return columns'
getJoinsExpr (BetweenExpr _ expr start end) = M.unionsWith (<>) <$> mapM getJoinsExpr [expr, start, end]
getJoinsExpr (OverlapsExpr _ (r1start, r1end) (r2start, r2end)) = M.unionsWith (<>) <$> mapM getJoinsExpr [r1start, r1end, r2start, r2end]
getJoinsExpr (FunctionExpr _ _ _ args params mFilter mOver) = do
    cols <- M.unionsWith (<>) <$> mapM getJoinsExpr (args ++ map snd params)
    -- filter and window clauses are walked for their effects only
    maybe (pure mempty) getJoinsFilter mFilter
    maybe (pure mempty) getJoinsOverSubExpr mOver
    return cols
getJoinsExpr (AtTimeZoneExpr _ ts tz) = M.unionWith (<>) <$> getJoinsExpr ts <*> getJoinsExpr tz
getJoinsExpr (SubqueryExpr _ query) = do
    getJoinsQuery query
    pure $ singleSubqueryColumn query
getJoinsExpr (ExistsExpr _ query) = do
    _ <- getJoinsQuery query
    return M.empty
getJoinsExpr (ArrayExpr _ values) = M.unionsWith (<>) <$> mapM getJoinsExpr values
getJoinsExpr (FieldAccessExpr _ expr field) = go expr $ FieldChain $ M.singleton (void field) $ FieldChain M.empty
  where
    -- Build the field chain outside-in until a plain column reference is
    -- reached; for any other base expression the chain is dropped and the
    -- expression is walked normally.
    go (ColumnExpr _ ref@(RColumnRef _)) chain = return $ M.singleton (void ref) chain
    go (FieldAccessExpr _ expr' field') chain = go expr' $ FieldChain $ M.singleton (void field') chain
    go expr' _ = getJoinsExpr expr'
getJoinsExpr (ArrayAccessExpr _ expr index) = M.unionsWith (<>) <$> mapM getJoinsExpr [expr, index]
getJoinsExpr (TypeCastExpr _ _ expr _) = getJoinsExpr expr
getJoinsExpr (VariableSubstitutionExpr _) = return M.empty
-- | Walk the expression of a FILTER clause; the returned columns are discarded.
getJoinsFilter :: Filter ResolvedNames a -> Scoped ()
getJoinsFilter (Filter _ condition) = () <$ getJoinsExpr condition
-- | Walk an OVER clause. A reference to a window by name contributes nothing.
getJoinsOverSubExpr :: OverSubExpr ResolvedNames a -> Scoped ()
getJoinsOverSubExpr = \case
    OverWindowExpr _ windowExpr -> getJoinsWindowExpr windowExpr
    OverPartialWindowExpr _ partial -> getJoinsPartialWindowExpr partial
    OverWindowName _ _ -> pure ()
-- | Walk a window expression: its optional partition, then its orderings.
getJoinsWindowExpr :: WindowExpr ResolvedNames a -> Scoped ()
getJoinsWindowExpr (WindowExpr _ partition orders _) =
    mapM_ getJoinsPartition partition *> mapM_ getJoinsOrder orders
-- | Walk a partial window expression: its optional partition, then its
-- orderings.
getJoinsPartialWindowExpr :: PartialWindowExpr ResolvedNames a -> Scoped ()
getJoinsPartialWindowExpr (PartialWindowExpr _ _ partition orders _) =
    mapM_ getJoinsPartition partition *> mapM_ getJoinsOrder orders
-- | Walk the expressions of a PARTITION BY; the other partition forms
-- contribute nothing.
getJoinsPartition :: Partition ResolvedNames a -> Scoped ()
getJoinsPartition = \case
    PartitionBy _ exprs -> forM_ exprs getJoinsExpr
    PartitionBest _ -> pure ()
    PartitionNodes _ -> pure ()
-- | Walk the position-or-expression that an ordering sorts by.
getJoinsOrder :: Order ResolvedNames a -> Scoped ()
getJoinsOrder (Order _ target _ _) = getJoinsPositionOrExpr target
-- | Walk a table-like item of a FROM clause.
--
-- Plain tables contribute nothing. For joins, both sides are walked first and
-- then the join condition: an ON expression is walked as an expression, while
-- NATURAL and USING joins emit each of their column pairs as joined.
getJoinsTablish :: Tablish ResolvedNames a -> Scoped ()
getJoinsTablish = \case
    TablishTable{} -> pure ()
    TablishSubQuery _ _ query -> getJoinsQuery query
    TablishLateralView _ LateralView{lateralViewExprs = exprs} lhs -> do
        mapM_ getJoinsTablish lhs
        mapM_ getJoinsExpr exprs
    TablishJoin _ _ condition lhs rhs -> do
        getJoinsTablish lhs
        getJoinsTablish rhs
        case condition of
            JoinOn expr -> void $ getJoinsExpr expr
            JoinNatural _ (RNaturalColumns columns) -> mapM_ emitUsingColumn columns
            JoinUsing _ columns -> mapM_ emitUsingColumn columns
  where
    -- the two sides of a USING / NATURAL column form one joined group
    emitUsingColumn (RUsingColumn lcol rcol) =
        emit $ M.fromList [(void lcol, FieldChain M.empty), (void rcol, FieldChain M.empty)]
-- | Walk the query that defines a common table expression.
getJoinsCTE :: CTE ResolvedNames a -> Scoped ()
getJoinsCTE (CTE _ _ _ definition) = getJoinsQuery definition
-- | Walk one selection of a SELECT. A star contributes nothing; an expression
-- is walked and each of its aliases is bound to the columns it returned.
getJoinsSelection :: Selection ResolvedNames a -> Scoped ()
getJoinsSelection = \case
    SelectStar{} -> pure ()
    SelectExpr _ aliases expr -> do
        cols <- getJoinsExpr expr
        mapM_ (bindAlias cols) aliases
  where
    bindAlias cols (ColumnAlias _ _ aliasId) = bind aliasId cols