{-|
  Functions that generate sorting networks given an input size.
 -}

module Data.SortingNetwork.Compares (
  oddEvenMerge,
  optimal,
) where

import Control.Monad
import Data.Bits
import Data.SortingNetwork.Types

{-
  TODO: might be useful: https://metacpan.org/dist/Algorithm-Networksort/source/lib/Algorithm/Networksort.pm
 -}

{-|
  Batcher's odd-even mergesort network.

  Adapted from the Pseudocode section of
  <https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort>.

  A negative size yields 'Nothing'. Any other size yields the index pairs
  in the order in which the pseudocode's nested loops visit them
  (sizes 0 and 1 produce an empty list).
 -}
oddEvenMerge :: MkPairs
oddEvenMerge n
  | n < 0 = Nothing
  | otherwise =
      Just
        [ (lo, hi)
        | -- INVARIANT: blk == shiftL 1 (lvl - 1), so shifting right by lvl
          -- is the same as an integer division by (2 * blk).
          (blk, lvl) <- zip (takeWhile (< n) (iterate (* 2) 1)) [1 ..]
        , gap <- takeWhile (>= 1) (iterate (`shiftR` 1) blk)
        , base <- takeWhile (<= n - 1 - gap) (iterate (+ 2 * gap) (gap `rem` blk))
        , off <- [0 .. gap - 1]
        , let lo = off + base
        , let hi = lo + gap
        , -- Both ends must fall into the same block of size (2 * blk).
          shiftR lo lvl == shiftR hi lvl
        , {-
            Without this check the upper index could run out of bounds
            when n is not a power of 2.

            The original author was not sure about the correctness when
            n > 16, and notes that the tests verify n = [2 .. 16] based on
            the 0-1 principle.
           -}
          hi < n
        ]

{-|
  Sorting networks that are optimal by size.

  The requested input size is looked up in a fixed table, so the result is
  'Nothing' for any size without an entry (the table covers sizes 2 to 16).

  Networks are taken from <https://bertdobbelaere.github.io/sorting_networks.html>.
 -}
optimal :: MkPairs
optimal = flip lookup optimalNetworks

{-
  Association list from input size to a network for that size, given as a
  list of index pairs. It has one entry for every size from 2 to 16 and is
  queried with 'lookup' by 'optimal'.

  When several networks are available for a size, one with minimal size
  is picked (this choice is arbitrary and may be reconsidered later).
 -}
optimalNetworks :: [(Int, [] (Int, Int))]
optimalNetworks =
  [ (2, [(0, 1)])
  , (3, [(0, 2), (0, 1), (1, 2)])
  , (4, [(0, 2), (1, 3), (0, 1), (2, 3), (1, 2)])
  , (5, [(0, 3), (1, 4), (0, 2), (1, 3), (0, 1), (2, 4), (1, 2), (3, 4), (2, 3)])
  , (6, [(0, 5), (1, 3), (2, 4), (1, 2), (3, 4), (0, 3), (2, 5), (0, 1), (2, 3), (4, 5), (1, 2), (3, 4)])
  , (7, [(0, 6), (2, 3), (4, 5), (0, 2), (1, 4), (3, 6), (0, 1), (2, 5), (3, 4), (1, 2), (4, 6), (2, 3), (4, 5), (1, 2), (3, 4), (5, 6)])
  , (8, [(0, 2), (1, 3), (4, 6), (5, 7), (0, 4), (1, 5), (2, 6), (3, 7), (0, 1), (2, 3), (4, 5), (6, 7), (2, 4), (3, 5), (1, 4), (3, 6), (1, 2), (3, 4), (5, 6)])
  , (9, [(0, 3), (1, 7), (2, 5), (4, 8), (0, 7), (2, 4), (3, 8), (5, 6), (0, 2), (1, 3), (4, 5), (7, 8), (1, 4), (3, 6), (5, 7), (0, 1), (2, 4), (3, 5), (6, 8), (2, 3), (4, 5), (6, 7), (1, 2), (3, 4), (5, 6)])
  , (10, [(0, 8), (1, 9), (2, 7), (3, 5), (4, 6), (0, 2), (1, 4), (5, 8), (7, 9), (0, 3), (2, 4), (5, 7), (6, 9), (0, 1), (3, 6), (8, 9), (1, 5), (2, 3), (4, 8), (6, 7), (1, 2), (3, 5), (4, 6), (7, 8), (2, 3), (4, 5), (6, 7), (3, 4), (5, 6)])
  , (11, [(0, 9), (1, 6), (2, 4), (3, 7), (5, 8), (0, 1), (3, 5), (4, 10), (6, 9), (7, 8), (1, 3), (2, 5), (4, 7), (8, 10), (0, 4), (1, 2), (3, 7), (5, 9), (6, 8), (0, 1), (2, 6), (4, 5), (7, 8), (9, 10), (2, 4), (3, 6), (5, 7), (8, 9), (1, 2), (3, 4), (5, 6), (7, 8), (2, 3), (4, 5), (6, 7)])
  , (12, [(0, 8), (1, 7), (2, 6), (3, 11), (4, 10), (5, 9), (0, 1), (2, 5), (3, 4), (6, 9), (7, 8), (10, 11), (0, 2), (1, 6), (5, 10), (9, 11), (0, 3), (1, 2), (4, 6), (5, 7), (8, 11), (9, 10), (1, 4), (3, 5), (6, 8), (7, 10), (1, 3), (2, 5), (6, 9), (8, 10), (2, 3), (4, 5), (6, 7), (8, 9), (4, 6), (5, 7), (3, 4), (5, 6), (7, 8)])
  , (13, [(0, 12), (1, 10), (2, 9), (3, 7), (5, 11), (6, 8), (1, 6), (2, 3), (4, 11), (7, 9), (8, 10), (0, 4), (1, 2), (3, 6), (7, 8), (9, 10), (11, 12), (4, 6), (5, 9), (8, 11), (10, 12), (0, 5), (3, 8), (4, 7), (6, 11), (9, 10), (0, 1), (2, 5), (6, 9), (7, 8), (10, 11), (1, 3), (2, 4), (5, 6), (9, 10), (1, 2), (3, 4), (5, 7), (6, 8), (2, 3), (4, 5), (6, 7), (8, 9), (3, 4), (5, 6)])
  , (14, [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (0, 2), (1, 3), (4, 8), (5, 9), (10, 12), (11, 13), (0, 4), (1, 2), (3, 7), (5, 8), (6, 10), (9, 13), (11, 12), (0, 6), (1, 5), (3, 9), (4, 10), (7, 13), (8, 12), (2, 10), (3, 11), (4, 6), (7, 9), (1, 3), (2, 8), (5, 11), (6, 7), (10, 12), (1, 4), (2, 6), (3, 5), (7, 11), (8, 10), (9, 12), (2, 4), (3, 6), (5, 8), (7, 10), (9, 11), (3, 4), (5, 6), (7, 8), (9, 10), (6, 7)])
  , (15, [(1, 2), (3, 10), (4, 14), (5, 8), (6, 13), (7, 12), (9, 11), (0, 14), (1, 5), (2, 8), (3, 7), (6, 9), (10, 12), (11, 13), (0, 7), (1, 6), (2, 9), (4, 10), (5, 11), (8, 13), (12, 14), (0, 6), (2, 4), (3, 5), (7, 11), (8, 10), (9, 12), (13, 14), (0, 3), (1, 2), (4, 7), (5, 9), (6, 8), (10, 11), (12, 13), (0, 1), (2, 3), (4, 6), (7, 9), (10, 12), (11, 13), (1, 2), (3, 5), (8, 10), (11, 12), (3, 4), (5, 6), (7, 8), (9, 10), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (5, 6), (7, 8)])
  , (16, [(0, 13), (1, 12), (2, 15), (3, 14), (4, 8), (5, 6), (7, 11), (9, 10), (0, 5), (1, 7), (2, 9), (3, 4), (6, 13), (8, 14), (10, 15), (11, 12), (0, 1), (2, 3), (4, 5), (6, 8), (7, 9), (10, 11), (12, 13), (14, 15), (0, 2), (1, 3), (4, 10), (5, 11), (6, 7), (8, 9), (12, 14), (13, 15), (1, 2), (3, 12), (4, 6), (5, 7), (8, 10), (9, 11), (13, 14), (1, 4), (2, 6), (5, 8), (7, 10), (9, 13), (11, 14), (2, 4), (3, 6), (9, 12), (11, 13), (3, 5), (6, 8), (7, 9), (10, 12), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (6, 7), (8, 9)])
  ]