module Data.SortingNetwork.Compares (
  batcher,
) where

import Control.Monad
import Data.Bits

-- | Comparator pairs of Batcher's odd–even merge-sort network for @n@
-- inputs, indexed @0 .. n - 1@.
--
-- Every pair @(a, b)@ satisfies @0 <= a < b < n@. Applying a
-- compare-and-swap to each pair in list order (smaller value to index
-- @a@, larger value to index @b@) sorts the inputs ascending.
-- Returns @[]@ when @n <= 1@.
--
-- The nested binds mirror the loops of the reference pseudocode:
-- <https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort>
--
-- >>> batcher 4
-- [(0,1),(2,3),(0,2),(1,3),(1,2)]
batcher :: Int -> [(Int, Int)]
batcher n = do
  -- Merge stage: p = 1, 2, 4, ... while p < n.
  -- INVARIANT: p == shiftL 1 (pw - 1), so (`shiftR` pw) divides by 2 * p.
  (p, pw) <- zip (takeWhile (< n) $ iterate (* 2) 1) [1 ..]
  -- Comparator distance within the stage: k = p, p / 2, ..., 1.
  k <- takeWhile (>= 1) $ iterate (`shiftR` 1) p
  -- Start of each run of k comparators; the pseudocode's "j < n - k".
  j <- takeWhile (<= n - 1 - k) $ iterate (+ 2 * k) (k `rem` p)
  i <- [0 .. k - 1]
  -- Only compare elements lying in the same block of size 2 * p.
  guard $ shiftR (i + j) pw == shiftR (i + j + k) pw
  {-
    When n is not a power of 2, i + j + k can run past the last index,
    so such comparators are dropped here.

    The result is still a correct sorting network for every n. The pairs
    produced are exactly those of the network for N = the next power of
    two (the p and j ranges only differ on comparators whose upper index
    is >= n), minus every comparator whose upper index is >= n.

    Pad inputs n .. N - 1 with +infinity. A comparator always moves the
    larger value to the higher index, so those wires hold +infinity
    forever. Every comparator touching them is therefore a no-op and can
    be removed without affecting the output on wires 0 .. n - 1.
   -}
  guard $ i + j + k < n
  pure (i + j, i + j + k)