{-# LANGUAGE TemplateHaskell #-}

module Data.SortingNetwork.TH (
  gMkSortBy,
  mkUnsafeSortListBy,
  mkSortTupBy,
  mkUnsafeSortListByFns,
  mkSortTupByFns,
) where

import Control.Monad
import Control.Monad.IO.Class
import Data.Semigroup
import Data.SortingNetwork.Types
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as VM
import Language.Haskell.TH

-- | A partially built expression with a hole: given the innermost body
-- expression, it produces the complete (wrapped) expression in 'Q'.
type PartQ = Exp -> Q Exp

{- TODO: we should probably have functions for unboxed tuples -}

{- |
  @gMkSortBy mkPairs n mkP mkE@ generates a function that sorts @n@ elements using a sorting network.

  The generated function first takes a comparator @cmp :: a -> a -> Ordering@, then the packed
  input value, which is matched by the pattern produced by @mkP@.

  @mkP :: [Pat] -> Pat@ and @mkE :: [Exp] -> Exp@ deal with unpacking the input value and packing
  the final results, respectively. This generalization allows us to deal with lists, tuples,
  and unboxed tuples all at once.

  Splicing fails (via 'fail') if @mkPairs n@ returns 'Nothing', or if any pair it returns has an
  index outside @[0, n)@.
-}
gMkSortBy :: MkPairs -> Int -> ([Pat] -> Pat) -> ([Exp] -> Exp) -> Q Exp
gMkSortBy mkPairs n mkP mkE = do
  -- cmp :: a -> a -> Ordering, bound by the outermost lambda of the result.
  cmp <- newName "cmp"

  -- sw u v f: calls f with its two arguments swapped only when cmp u v == GT.
  swapper <- newName "sw"
  swapperVal <- [|\u v f -> if $(varE cmp) u v == GT then f v u else f u v|]

  -- Names bound to the input elements by the pattern built with mkP.
  ns0 <- replicateM n $ newName "v"
  let -- let sw = ... in <???>
      step0 :: PartQ
      step0 bd = [|let $(varP swapper) = $(pure swapperVal) in $(pure bd)|]

  pairs <- case mkPairs n of
    Just ps -> pure ps
    Nothing -> fail $ "MkPairs returned Nothing on length " <> show n

  -- Reject out-of-range indices up front: the loop below indexes with
  -- VM.unsafeRead / VM.unsafeWrite, which perform no bounds checks.
  forM_ pairs \(i, j) ->
    unless (0 <= i && i < n && 0 <= j && j < n) $
      fail $
        "MkPairs returned out-of-range pair " <> show (i, j) <> " on length " <> show n

  -- Walk the comparator pairs while tracking which name currently holds each
  -- position. Each pair nests one more `sw` call (inside those of earlier
  -- pairs) that binds fresh names for positions i and j.
  (mkBody :: PartQ, ns :: [Name]) <- do
    -- unsafeThaw is fine here: the vector is freshly built and not shared.
    nv <- liftIO $ V.unsafeThaw (V.fromList ns0)
    e <-
      foldM
        ( \(mk :: PartQ) (i, j) -> do
            iOld <- liftIO $ VM.unsafeRead nv i
            jOld <- liftIO $ VM.unsafeRead nv j
            iNew <- newName "v"
            jNew <- newName "v"
            liftIO do
              VM.unsafeWrite nv i iNew
              VM.unsafeWrite nv j jNew
            pure \(hole :: Exp) ->
              mk
                =<< [|
                  $(varE swapper)
                    $(varE iOld)
                    $(varE jOld)
                    (\ $(varP iNew) $(varP jNew) -> $(pure hole))
                  |]
        )
        step0
        pairs
    nvFin <- liftIO $ V.unsafeFreeze nv
    pure (e, V.toList nvFin)

  [|
    \ $(varP cmp)
      $(pure $ mkP $ VarP <$> ns0) ->
        $(mkBody $ mkE $ VarE <$> ns)
    |]

{- |
  @mkUnsafeSortListBy mkPairs n@ generates an expression of type @(a -> a -> Ordering) -> [a] -> [a]@.

  The generated function is partial: its list argument is matched against a list pattern of
  exactly @n@ elements, so a list of any other length fails that pattern match.
-}
mkUnsafeSortListBy :: MkPairs -> Int -> ExpQ
mkUnsafeSortListBy mkPairs len = gMkSortBy mkPairs len unpackList packList
  where
    -- Unpack the input with a fixed-length list pattern ...
    unpackList :: [Pat] -> Pat
    unpackList = ListP

    -- ... and repack the sorted elements as a list literal.
    packList :: [Exp] -> Exp
    packList = ListE

{- |
  @mkSortTupBy mkPairs n@ generates an expression of type
  @(a -> a -> Ordering) -> (a, a, ...) -> (a, a, ...)@,
  where both the input and the output tuple @(a, a, ...)@ have @n@ components.
-}
mkSortTupBy :: MkPairs -> Int -> ExpQ
mkSortTupBy mkPairs len = gMkSortBy mkPairs len TupP packTuple
  where
    -- 'TupE' takes its components as 'Maybe Exp'; every component here is present.
    packTuple :: [Exp] -> Exp
    packTuple es = TupE [Just e | e <- es]

{-
  Note: I'm not sure whether there is a more convenient way to write type signatures with
  quasi-quotes, so the current approach just builds them from plain constructors.

  Might be related: https://stackoverflow.com/q/37478037/315302
 -}

{- |
  @mkUnsafeSortListByFns mkPairs ns@ generates, for every @n@ in @ns@, a top-level signature and
  definition for @unsafeSortList\<n\>By :: (a -> a -> Ordering) -> [a] -> [a]@, whose body is
  @mkUnsafeSortListBy mkPairs n@.
-}
mkUnsafeSortListByFns :: MkPairs -> [Int] -> Q [Dec]
mkUnsafeSortListByFns mkPairs = fmap concat . mapM declsFor
  where
    -- Signature followed by definition for a single length.
    declsFor :: Int -> Q [Dec]
    declsFor len = do
      body <- mkUnsafeSortListBy mkPairs len
      let fnName = mkName ("unsafeSortList" <> show len <> "By")
      sig <- sigD fnName [t|forall a. (a -> a -> Ordering) -> [a] -> [a]|]
      def <- funD fnName [clause [] (normalB (pure body)) []]
      pure [sig, def]
{- |
  @mkSortTupByFns mkPairs ns@ generates, for every @n@ in @ns@, a top-level signature and
  definition for @sortTup\<n\>By :: (a -> a -> Ordering) -> (a, ..., a) -> (a, ..., a)@
  (with @n@ components), whose body is @mkSortTupBy mkPairs n@.
-}
mkSortTupByFns :: MkPairs -> [Int] -> Q [Dec]
mkSortTupByFns mkPairs = fmap concat . mapM declsFor
  where
    -- Signature followed by definition for a single tuple size.
    declsFor :: Int -> Q [Dec]
    declsFor len = do
      body <- mkSortTupBy mkPairs len
      tyVar <- newName "a"
      -- The len-ary tuple type constructor applied to len copies of the type variable.
      tupTy <- appEndo (stimes len (Endo (`AppT` VarT tyVar))) <$> tupleT len
      let fnName = mkName ("sortTup" <> show len <> "By")
      sig <-
        sigD
          fnName
          [t|($(varT tyVar) -> $(varT tyVar) -> Ordering) -> $(pure tupTy) -> $(pure tupTy)|]
      def <- funD fnName [clause [] (normalB (pure body)) []]
      pure [sig, def]