module Opaleye.Internal.Optimize where
import Prelude hiding (product)
import qualified Opaleye.Internal.PrimQuery as PQ
import Opaleye.Internal.Helpers ((.:))
import qualified Data.List.NonEmpty as NEL
import Control.Applicative ((<$>), (<*>), pure)
import qualified Data.Traversable as T
-- | Run the query simplification passes: 'removeUnit' is applied first
-- and 'mergeProduct' is applied to its result.
optimize :: PQ.PrimQuery' a -> PQ.PrimQuery' a
optimize query = mergeProduct (removeUnit query)
-- | Rebuild every 'PQ.Product' node without the children that
-- 'PQ.isUnit' accepts.  All other node kinds are handled by
-- 'PQ.primQueryFoldDefault'.
removeUnit :: PQ.PrimQuery' a -> PQ.PrimQuery' a
removeUnit = PQ.foldPrimQuery PQ.primQueryFoldDefault { PQ.product = dropUnits }
  where
    -- The conditions are passed through untouched; only the list of
    -- children is pruned.
    dropUnits children conditions = PQ.Product (nonUnits children) conditions

    -- A product needs at least one child, so when every child was a unit
    -- a single 'PQ.Unit' is kept in their place.
    nonUnits children = case NEL.filter (not . PQ.isUnit) children of
      []       -> PQ.Unit NEL.:| []
      (q : qs) -> q NEL.:| qs
-- | Flatten 'PQ.Product' nodes that appear directly as children of
-- another 'PQ.Product': the inner product's children are spliced into the
-- outer product's children, and the inner conditions are appended after
-- the outer ones.  All other node kinds are handled by
-- 'PQ.primQueryFoldDefault'.
mergeProduct :: PQ.PrimQuery' a -> PQ.PrimQuery' a
mergeProduct = PQ.foldPrimQuery PQ.primQueryFoldDefault { PQ.product = flatten }
  where
    flatten children conditions =
      PQ.Product (children >>= subQueries)
                 (conditions ++ concatMap subConditions (NEL.toList children))

    -- A nested product contributes its own children; any other query
    -- contributes just itself.
    subQueries (PQ.Product inner _) = inner
    subQueries other                = return other

    -- Only nested products carry conditions that must be hoisted.
    subConditions (PQ.Product _ innerConds) = innerConds
    subConditions _                         = []
-- | Eliminate empty sub-queries from a query tree.
--
-- The result is 'Nothing' when the query as a whole is empty, and
-- otherwise 'Just' a rebuilt tree.  A node handled by 'PQ.empty' always
-- yields 'Nothing'; the remaining handlers decide how emptiness of the
-- children propagates to the parent.
removeEmpty :: PQ.PrimQuery' a -> Maybe (PQ.PrimQuery' b)
removeEmpty = PQ.foldPrimQuery PQ.PrimQueryFold {
    PQ.unit      = return PQ.Unit
  , PQ.empty     = const Nothing
  , PQ.baseTable = return .: PQ.BaseTable
    -- The children arrive here already folded, so the product is empty
    -- as soon as any one of them is 'Nothing'.  Running 'removeEmpty' on
    -- the folded children again would not change them (every other
    -- handler rebuilds an empty-free tree as-is) but would repeat the
    -- whole traversal at each level of product nesting.
  , PQ.product   = \x y -> PQ.Product <$> T.sequence x <*> pure y
  , PQ.aggregate = fmap . PQ.Aggregate
  , PQ.order     = fmap . PQ.Order
  , PQ.limit     = fmap . PQ.Limit
    -- A join is empty when either side is empty.
  , PQ.join      = \jt pe pq1 pq2 -> PQ.Join jt pe <$> pq1 <*> pq2
  , PQ.values    = return .: PQ.Values
    -- For each operator: what to return when only the left side is
    -- non-empty, then what to return when only the right side is.
    -- NOTE(review): for 'PQ.Union' and 'PQ.Except' the surviving side is
    -- returned unchanged; SQL's non-ALL forms also remove duplicate rows,
    -- so confirm that this is acceptable to callers.
  , PQ.binary    = \case
      PQ.Except       -> binary Just            (const Nothing) PQ.Except
      PQ.Union        -> binary Just            Just            PQ.Union
      PQ.Intersect    -> binary (const Nothing) (const Nothing) PQ.Intersect
      PQ.ExceptAll    -> binary Just            (const Nothing) PQ.ExceptAll
      PQ.UnionAll     -> binary Just            Just            PQ.UnionAll
      PQ.IntersectAll -> binary (const Nothing) (const Nothing) PQ.IntersectAll
  , PQ.label     = fmap . PQ.Label
  }
  where
    -- 'onlyLeft' is applied when the right operand is empty, 'onlyRight'
    -- when the left operand is empty.  When both operands survive the
    -- binary node is rebuilt; when neither does the result is empty.
    binary onlyLeft onlyRight op exprs = \case
      (Nothing,  Nothing)  -> Nothing
      (Nothing,  Just pq2) -> onlyRight pq2
      (Just pq1, Nothing)  -> onlyLeft pq1
      (Just pq1, Just pq2) -> Just (PQ.Binary op exprs (pq1, pq2))