{-# LANGUAGE TemplateHaskell #-}

module Data.SortingNetwork.TH (
  gMkSortBy,
  mkSortListBy,
  mkSortTupBy,
  mkSortListByFns,
  mkSortTupByFns,
) where

import Control.Monad
import Control.Monad.IO.Class
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as VM
import Language.Haskell.TH

-- | Given the number of elements, produce the comparator network as a list
-- of position pairs @(i, j)@. Each pair becomes one compare-and-swap step;
-- steps are emitted in list order.
type MkPairs = Int -> [(Int, Int)]

-- | A partially built expression with a hole: given the inner body
-- expression, wrap it in the surrounding generated code.
type PartQ = Exp -> Q Exp

-- | Build a sorting function from a comparator network.
--
-- @gMkSortBy mkPairs n mkP mkE@ generates a lambda of the form
--
-- > \cmp <mkP [v0, ..., v(n-1)]> -> <mkE [final names]>
--
-- The body applies one compare-and-swap step per pair @(i, j)@ returned by
-- @mkPairs n@, in list order. For each step, the values currently at
-- positions @i@ and @j@ are compared with @cmp@ and are swapped only when
-- @cmp@ answers 'GT'. Position @i@ therefore ends up holding the element
-- that is not greater.
--
-- @mkP@ builds the input pattern (e.g. 'ListP', 'TupP') from the input
-- names. @mkE@ rebuilds the output container from the names that hold
-- each position's final value.
--
-- The splice fails with a compile-time error if any pair refers to a
-- position outside @[0, n)@.
gMkSortBy :: MkPairs -> Int -> ([Pat] -> Pat) -> ([Exp] -> Exp) -> Q Exp
gMkSortBy mkPairs n mkP mkE = do
  let pairs = mkPairs n
      inBounds k = k >= 0 && k < n
  -- Positions are read and written below with unchecked vector operations,
  -- so reject a malformed network up front with a readable splice error.
  forM_ pairs $ \(i, j) ->
    unless (inBounds i && inBounds j) $
      fail $
        "gMkSortBy: pair " <> show (i, j)
          <> " is out of bounds for a network of size " <> show n

  -- cmp :: a -> a -> Ordering, bound by the generated lambda.
  cmp <- newName "cmp"

  -- sw u v f: pass u and v to the continuation f, swapped when u > v.
  swapper <- newName "sw"
  swapperVal <- [|\u v f -> if $(varE cmp) u v == GT then f v u else f u v|]

  -- One fresh name per input position, bound by the input pattern.
  ns0 <- replicateM n $ newName "v"

  let -- Outermost wrapper: let sw = ... in <body>
      step0 :: PartQ
      step0 bd = [|let $(varP swapper) = $(pure swapperVal) in $(pure bd)|]

  (mkBody, ns) <- do
    -- Slot k holds the name currently bound to the value at position k.
    -- unsafeThaw is fine here: the vector is freshly built and never shared.
    nv <- liftIO $ V.unsafeThaw (V.fromList ns0)
    let addPair :: PartQ -> (Int, Int) -> Q PartQ
        addPair mk (i, j) = do
          -- Indices were validated above, so the unchecked accesses are safe.
          iOld <- liftIO $ VM.unsafeRead nv i
          jOld <- liftIO $ VM.unsafeRead nv j
          iNew <- newName "v"
          jNew <- newName "v"
          liftIO $ do
            VM.unsafeWrite nv i iNew
            VM.unsafeWrite nv j jNew
          -- Nest this compare-and-swap inside everything built so far.
          pure $ \hole ->
            mk
              =<< [|
                $(varE swapper)
                  $(varE iOld)
                  $(varE jOld)
                  (\ $(varP iNew) $(varP jNew) -> $(pure hole))
                |]
    e <- foldM addPair step0 pairs
    nvFin <- liftIO $ V.unsafeFreeze nv
    pure (e, V.toList nvFin)

  [|
    \ $(varP cmp)
      $(pure $ mkP $ VarP <$> ns0) ->
        $(mkBody $ mkE $ VarE <$> ns)
    |]

-- | Sorting function over an @n@-element list:
-- generates @\\cmp [v0, ..., v(n-1)] -> [...]@ via 'gMkSortBy'.
mkSortListBy :: MkPairs -> Int -> ExpQ
mkSortListBy pairsFor size = gMkSortBy pairsFor size ListP ListE

-- | Sorting function over an @n@-tuple:
-- generates @\\cmp (v0, ..., v(n-1)) -> (...)@ via 'gMkSortBy'.
mkSortTupBy :: MkPairs -> Int -> ExpQ
mkSortTupBy pairsFor size = gMkSortBy pairsFor size TupP asTuple
  where
    -- Every tuple component is present (no tuple sections).
    asTuple :: [Exp] -> Exp
    asTuple components = TupE (map Just components)

-- | Declare one list-sorting function per requested size.
--
-- For each @n@ in the list, this binds @sortList\<n\>By@ to the expression
-- built by 'mkSortListBy' for that size.
mkSortListByFns :: MkPairs -> [Int] -> Q [Dec]
mkSortListByFns = mkSortByFns "sortList" mkSortListBy

-- | Declare one tuple-sorting function per requested size.
--
-- For each @n@ in the list, this binds @sortTup\<n\>By@ to the expression
-- built by 'mkSortTupBy' for that size.
mkSortTupByFns :: MkPairs -> [Int] -> Q [Dec]
mkSortTupByFns = mkSortByFns "sortTup" mkSortTupBy

-- | Shared driver for the @mk*ByFns@ generators. For each size @n@, it binds
-- the name @\<prefix\>\<n\>By@ to the expression produced by
-- @mkSort mkPairs n@, and concatenates the resulting declarations in input
-- order.
mkSortByFns ::
  String ->
  (MkPairs -> Int -> Q Exp) ->
  MkPairs ->
  [Int] ->
  Q [Dec]
mkSortByFns prefix mkSort mkPairs ns = concat <$> mapM declare ns
  where
    declare :: Int -> Q [Dec]
    declare n = do
      let defN :: Name
          defN = mkName $ prefix <> show n <> "By"
      bd <- mkSort mkPairs n
      [d|$(varP defN) = $(pure bd)|]