{-|
  Functions that generate sorting networks for a given input size.
 -}

module Data.SortingNetwork.Compares (
  oddEvenMerge,
  optimal,
) where

import Control.Monad
import Data.Bits
import Data.SortingNetwork.Types

{-
  TODO: might be useful: https://metacpan.org/dist/Algorithm-Networksort/source/lib/Algorithm/Networksort.pm
 -}

{-|
  Batcher's odd-even mergesort.

  Adapted from the pseudocode section of
  <https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort>.

  A negative size yields 'Nothing'. Any other size @n@ yields the comparator
  list for @n@ wires, which is empty when @n < 2@. Every emitted pair
  @(lo, hi)@ satisfies @0 <= lo < hi < n@.

  When @n@ is not a power of two, the result is the network for the next
  power of two @N@, minus every comparator that touches a wire @>= n@.
  The loop bounds @p < n@ and @j <= n - 1 - k@ only cut away such comparators.
  Dropping them is sound. Picture wires @n .. N-1@ holding +infinity: a
  comparator @(lo, hi)@ with @hi >= n@ then never swaps, so those wires keep
  +infinity and the removed comparators were no-ops. The test suite also
  checks n = 2 .. 16 via the 0-1 principle.
 -}
oddEvenMerge :: MkPairs
oddEvenMerge n = comparators <$ guard (n >= 0)
  where
    comparators =
      [ (lo, hi)
      -- Invariant: p == shiftL 1 (pw - 1), so shifting right by pw divides by 2 * p.
      | (p, pw) <- zip (takeWhile (< n) (iterate (* 2) 1)) [1 ..]
      , k <- takeWhile (>= 1) (iterate (`shiftR` 1) p)
      , j <- takeWhile (<= n - 1 - k) (iterate (+ 2 * k) (k `rem` p))
      , i <- [0 .. k - 1]
      , let lo = i + j
      , let hi = lo + k
      -- Only pair wires that lie in the same block of 2 * p wires.
      , lo `shiftR` pw == hi `shiftR` pw
      -- Keep the upper wire in range when n is not a power of two.
      , hi < n
      ]

{-|
  Sorting networks that are optimal by size, i.e. they use the fewest
  comparators.

  Taken from <https://bertdobbelaere.github.io/sorting_networks.html>.

  Only input sizes 2 through 16 are tabulated. Every other size, including
  the trivial sizes 0 and 1, yields 'Nothing'.
 -}
optimal :: MkPairs
optimal = (`lookup` optimalNetworks)

{-
  Size-optimal sorting networks keyed by input count (2 through 16).
  Each value is the comparator list, in application order.

  Where the source lists several networks for the same input count, one
  with the fewest comparators is used. Among those, the pick is arbitrary
  and may be revisited.
 -}
optimalNetworks :: [(Int, [(Int, Int)])]
optimalNetworks =
  [ (2, [(0, 1)])
  , (3, [(0, 2), (0, 1), (1, 2)])
  , (4, [(0, 2), (1, 3), (0, 1), (2, 3), (1, 2)])
  , ( 5
    , [ (0, 3), (1, 4), (0, 2), (1, 3), (0, 1), (2, 4), (1, 2), (3, 4)
      , (2, 3)
      ]
    )
  , ( 6
    , [ (0, 5), (1, 3), (2, 4), (1, 2), (3, 4), (0, 3), (2, 5), (0, 1)
      , (2, 3), (4, 5), (1, 2), (3, 4)
      ]
    )
  , ( 7
    , [ (0, 6), (2, 3), (4, 5), (0, 2), (1, 4), (3, 6), (0, 1), (2, 5)
      , (3, 4), (1, 2), (4, 6), (2, 3), (4, 5), (1, 2), (3, 4), (5, 6)
      ]
    )
  , ( 8
    , [ (0, 2), (1, 3), (4, 6), (5, 7), (0, 4), (1, 5), (2, 6), (3, 7)
      , (0, 1), (2, 3), (4, 5), (6, 7), (2, 4), (3, 5), (1, 4), (3, 6)
      , (1, 2), (3, 4), (5, 6)
      ]
    )
  , ( 9
    , [ (0, 3), (1, 7), (2, 5), (4, 8), (0, 7), (2, 4), (3, 8), (5, 6)
      , (0, 2), (1, 3), (4, 5), (7, 8), (1, 4), (3, 6), (5, 7), (0, 1)
      , (2, 4), (3, 5), (6, 8), (2, 3), (4, 5), (6, 7), (1, 2), (3, 4)
      , (5, 6)
      ]
    )
  , ( 10
    , [ (0, 8), (1, 9), (2, 7), (3, 5), (4, 6), (0, 2), (1, 4), (5, 8)
      , (7, 9), (0, 3), (2, 4), (5, 7), (6, 9), (0, 1), (3, 6), (8, 9)
      , (1, 5), (2, 3), (4, 8), (6, 7), (1, 2), (3, 5), (4, 6), (7, 8)
      , (2, 3), (4, 5), (6, 7), (3, 4), (5, 6)
      ]
    )
  , ( 11
    , [ (0, 9), (1, 6), (2, 4), (3, 7), (5, 8), (0, 1), (3, 5), (4, 10)
      , (6, 9), (7, 8), (1, 3), (2, 5), (4, 7), (8, 10), (0, 4), (1, 2)
      , (3, 7), (5, 9), (6, 8), (0, 1), (2, 6), (4, 5), (7, 8), (9, 10)
      , (2, 4), (3, 6), (5, 7), (8, 9), (1, 2), (3, 4), (5, 6), (7, 8)
      , (2, 3), (4, 5), (6, 7)
      ]
    )
  , ( 12
    , [ (0, 8), (1, 7), (2, 6), (3, 11), (4, 10), (5, 9), (0, 1), (2, 5)
      , (3, 4), (6, 9), (7, 8), (10, 11), (0, 2), (1, 6), (5, 10), (9, 11)
      , (0, 3), (1, 2), (4, 6), (5, 7), (8, 11), (9, 10), (1, 4), (3, 5)
      , (6, 8), (7, 10), (1, 3), (2, 5), (6, 9), (8, 10), (2, 3), (4, 5)
      , (6, 7), (8, 9), (4, 6), (5, 7), (3, 4), (5, 6), (7, 8)
      ]
    )
  , ( 13
    , [ (0, 12), (1, 10), (2, 9), (3, 7), (5, 11), (6, 8), (1, 6), (2, 3)
      , (4, 11), (7, 9), (8, 10), (0, 4), (1, 2), (3, 6), (7, 8), (9, 10)
      , (11, 12), (4, 6), (5, 9), (8, 11), (10, 12), (0, 5), (3, 8), (4, 7)
      , (6, 11), (9, 10), (0, 1), (2, 5), (6, 9), (7, 8), (10, 11), (1, 3)
      , (2, 4), (5, 6), (9, 10), (1, 2), (3, 4), (5, 7), (6, 8), (2, 3)
      , (4, 5), (6, 7), (8, 9), (3, 4), (5, 6)
      ]
    )
  , ( 14
    , [ (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (0, 2)
      , (1, 3), (4, 8), (5, 9), (10, 12), (11, 13), (0, 4), (1, 2), (3, 7)
      , (5, 8), (6, 10), (9, 13), (11, 12), (0, 6), (1, 5), (3, 9), (4, 10)
      , (7, 13), (8, 12), (2, 10), (3, 11), (4, 6), (7, 9), (1, 3), (2, 8)
      , (5, 11), (6, 7), (10, 12), (1, 4), (2, 6), (3, 5), (7, 11), (8, 10)
      , (9, 12), (2, 4), (3, 6), (5, 8), (7, 10), (9, 11), (3, 4), (5, 6)
      , (7, 8), (9, 10), (6, 7)
      ]
    )
  , ( 15
    , [ (1, 2), (3, 10), (4, 14), (5, 8), (6, 13), (7, 12), (9, 11), (0, 14)
      , (1, 5), (2, 8), (3, 7), (6, 9), (10, 12), (11, 13), (0, 7), (1, 6)
      , (2, 9), (4, 10), (5, 11), (8, 13), (12, 14), (0, 6), (2, 4), (3, 5)
      , (7, 11), (8, 10), (9, 12), (13, 14), (0, 3), (1, 2), (4, 7), (5, 9)
      , (6, 8), (10, 11), (12, 13), (0, 1), (2, 3), (4, 6), (7, 9), (10, 12)
      , (11, 13), (1, 2), (3, 5), (8, 10), (11, 12), (3, 4), (5, 6), (7, 8)
      , (9, 10), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (5, 6), (7, 8)
      ]
    )
  , ( 16
    , [ (0, 13), (1, 12), (2, 15), (3, 14), (4, 8), (5, 6), (7, 11), (9, 10)
      , (0, 5), (1, 7), (2, 9), (3, 4), (6, 13), (8, 14), (10, 15), (11, 12)
      , (0, 1), (2, 3), (4, 5), (6, 8), (7, 9), (10, 11), (12, 13), (14, 15)
      , (0, 2), (1, 3), (4, 10), (5, 11), (6, 7), (8, 9), (12, 14), (13, 15)
      , (1, 2), (3, 12), (4, 6), (5, 7), (8, 10), (9, 11), (13, 14), (1, 4)
      , (2, 6), (5, 8), (7, 10), (9, 13), (11, 14), (2, 4), (3, 6), (9, 12)
      , (11, 13), (3, 5), (6, 8), (7, 9), (10, 12), (3, 4), (5, 6), (7, 8)
      , (9, 10), (11, 12), (6, 7), (8, 9)
      ]
    )
  ]