```
-- | Normal form of braids, take 1.
--
-- We implement the Adyan-Thurston-ElRifai-Morton solution to the word problem in braid groups.
--
--
-- Based on:
--
-- * [1] Joan S. Birman, Tara E. Brendle: BRAIDS - A SURVEY
--   <https://www.math.columbia.edu/~jb/Handbook-21.pdf> (chapter 5.1)
--
-- * [2] Elsayed A. Elrifai, Hugh R. Morton: Algorithms for positive braids
--

{-# LANGUAGE
CPP, BangPatterns,
ScopedTypeVariables, ExistentialQuantification,
DataKinds, KindSignatures, Rank2Types #-}

module Math.Combinat.Groups.Braid.NF
( -- * Normal form
BraidNF (..)
, nfReprWord
, braidNormalForm
, braidNormalForm'
, braidNormalFormNaive'
-- * Starting and finishing sets
, permWordStartingSet
, permWordFinishingSet
, permutationStartingSet
, permutationFinishingSet
)
where

--------------------------------------------------------------------------------

import Data.Proxy
import GHC.TypeLits

import Data.List ( mapAccumL , foldl' , (\\) )

import Data.Array.Unboxed
import Data.Array.ST
import Data.Array.IArray
import Data.Array.MArray
import Data.Array.Unsafe
import Data.Array.Base

import Math.Combinat.Helper
import Math.Combinat.Sign

import Math.Combinat.Permutations ( Permutation(..) , isIdentityPermutation , isReversePermutation )
import qualified Math.Combinat.Permutations as P

import Math.Combinat.Groups.Braid

--------------------------------------------------------------------------------

-- | A unique normal form for braids, called the /left-greedy normal form/.
-- It looks like @Delta^i*P@, where @Delta@ is the positive half-twist, @i@ is an integer,
-- and @P@ is a positive word, which can be further decomposed into non-@Delta@ /permutation words/;
-- these words themselves are not unique, but the permutations they realize /are/ unique.
--
-- This will solve the word problem relatively fast,
-- though it is not the fastest known algorithm.
--
-- The type-level natural @n@ does not occur in the fields; it only ties the
-- normal form to the matching @Braid n@ type (see 'nfReprWord').
--
data BraidNF (n :: Nat) = BraidNF
  { _nfDeltaExp :: !Int              -- ^ the exponent of @Delta@
  , _nfPerms    :: [Permutation]     -- ^ the permutations
  }
  deriving (Eq,Ord,Show)

-- | A braid word representing the given normal form.
--
-- The word consists of @|k|@ copies of the half-twist (or of its inverse, when
-- the exponent @k@ is negative), followed by one permutation braid per stored
-- permutation; the composed word is then freely reduced.
nfReprWord :: KnownNat n => BraidNF n -> Braid n
nfReprWord (BraidNF k perms) = freeReduceBraidWord $ composeMany (deltas ++ rest) where

  -- the @Delta^k@ prefix, as a list of (inverse) half-twists
  deltas
    | k > 0     = replicate   k           halfTwist
    | k < 0     = replicate (-k) (inverse halfTwist)
    | otherwise = []

  -- the positive part: one permutation braid per factor
  rest = map permutationBraid perms

--------------------------------------------------------------------------------

-- | Computes the normal form of a braid. The word is freely reduced first
-- (it should be faster that way) and then passed on to 'braidNormalForm''.
braidNormalForm :: KnownNat n => Braid n -> BraidNF n
braidNormalForm braid = braidNormalForm' (freeReduceBraidWord braid)

-- | This function does not apply free reduction before computing the normal form.
--
-- Pipeline: inverse generators are replaced using 'replaceInverses', the
-- resulting @Delta@ powers are collected on the left by 'moveDeltasLeft', the
-- remaining positive word is cut into factors by 'leftGreedyFactors', and the
-- permutations of those factors are normalized by 'normalizePermFactors'.
-- The two @Delta@ exponents obtained along the way are added up.
braidNormalForm' :: KnownNat n => Braid n -> BraidNF n
braidNormalForm' braid@(Braid gens) = BraidNF (dexp+pexp) perms where
  n               = numberOfStrands braid
  invless         = replaceInverses n gens
  (dexp,posxword) = moveDeltasLeft n invless
  factors         = leftGreedyFactors n $ expandPosXWord n posxword
  (pexp,perms)    = normalizePermFactors n $ map (_braidPermutation n) factors

-- | This one uses the naive inverse replacement method ('replaceInversesNaive').
-- Probably somewhat slower than 'braidNormalForm''.
--
-- Apart from the inverse replacement step, the pipeline is the same as in
-- 'braidNormalForm'', and no free reduction is applied here either.
braidNormalFormNaive' :: KnownNat n => Braid n -> BraidNF n
braidNormalFormNaive' braid@(Braid gens) = BraidNF (dexp+pexp) perms where
  n               = numberOfStrands braid
  invless         = replaceInversesNaive gens
  (dexp,posxword) = moveDeltasLeft n invless
  factors         = leftGreedyFactors n $ expandPosXWord n posxword
  (pexp,perms)    = normalizePermFactors n $ map (_braidPermutation n) factors

--------------------------------------------------------------------------------

-- | Replaces groups of @sigma_i^-1@ generators by @(Delta^-1 * P)@,
-- where @P@ is a positive word.
--
-- This should be more clever (resulting in shorter words) than the naive version below
--
replaceInverses :: Int -> [BrGen] -> [XGen]
replaceInverses n gens = worker gens where

  -- Consume a maximal run of negative generators, then a maximal run of
  -- positive ones, and repeat on what is left.
  worker [] = []
  worker xs = replaceNegs neg ++ map (XSigma . brGenIdx) pos ++ worker rest where
    (neg,tmp ) = span (isMinus . brGenSign) xs
    (pos,rest) = span (isPlus  . brGenSign) tmp

  -- The indices of a negative run are cut into left-greedy factors,
  -- and each factor is replaced separately.
  replaceNegs gs = concatMap replaceFac facs where
    facs = leftGreedyFactors n $ map brGenIdx gs

  -- One factor becomes a single @Delta^-1@ followed by the positive word of
  -- the permutation @reverse * (product of the adjacent transpositions)@.
  replaceFac idxs = XDelta (-1) : map XSigma (_permutationBraid perm) where
    perm = (P.reversePermutation n) `P.multiply` (P.adjacentTranspositions n idxs)

-- | Replaces each @sigma_i^-1@ generator by @(Delta^-1 * L_i)@;
-- positive generators are kept as they are.
replaceInversesNaive :: [BrGen] -> [XGen]
replaceInversesNaive = concatMap expand
  where
    expand gen = case gen of
      Sigma    i -> [ XSigma i ]
      SigmaInv i -> [ XDelta (-1) , XL i ]

--------------------------------------------------------------------------------

-- | Temporary data structure to be used during the normal form computation.
--
-- Lists of these (\"X-words\") are produced by 'replaceInverses' and
-- 'replaceInversesNaive', rearranged by 'moveDeltasLeft', and expanded into
-- plain generator indices by 'expandPosXWord'.
data XGen
  = XDelta !Int   -- ^ @Delta^k@
  | XSigma !Int   -- ^ @Sigma_j@
  | XL     !Int   -- ^ @L_j = Delta * sigma_j^-1@
  | XTauL  !Int   -- ^ @tau(L_j)@
  deriving (Eq,Show)

-- | Tests whether the generator is a power of @Delta@.
isXDelta :: XGen -> Bool
isXDelta (XDelta _) = True
isXDelta _          = False

-- | We move all the @Delta@'s to the left.
--
-- Returns the summed exponent of every @XDelta@ in the input, together with
-- the remaining word, which no longer contains any @XDelta@. Each non-@Delta@
-- generator that has an odd total @Delta@ exponent to its right in the input
-- is replaced by its image under 'xtau'.
moveDeltasLeft :: Int -> [XGen] -> (Int,[XGen])
moveDeltasLeft n input = (finalExp, finalPosWord) where

  -- 'worker' emits the collected @XDelta@ last, so after reversing it comes first
  (finalExp, finalPosWord) =
    case reverse $ worker 0 (reverse input) of
      (XDelta k : posWord) -> (k, posWord)
      _                    -> error "moveDeltasLeft: impossible, worker always emits a trailing XDelta"

  -- we start from the right end, and work towards the left end
  worker  dexp [] = [ XDelta dexp ]
  worker !dexp xs = this' ++ worker dexp' rest where
    (delta,notdelta) = span isXDelta xs
    (this ,rest    ) = span (not . isXDelta) notdelta
    dexp' = dexp + sumDeltas delta
    -- an odd number of Deltas passing through applies the tau involution
    this' = if even dexp'
      then this
      else map xtau this

  sumDeltas :: [XGen] -> Int
  sumDeltas xs = foldl' (+) 0 [ k | XDelta k <- xs ]

  -- | The @X -> Delta^-1 * X * Delta@ inner automorphism
  xtau :: XGen -> XGen
  xtau (XSigma j) = XSigma (n-j)
  xtau (XDelta k) = XDelta k
  xtau (XL     k) = XTauL  k
  xtau (XTauL  k) = XL     k

--------------------------------------------------------------------------------

-- | Expands a /positive/ \"X-word\" into a positive braid word, given as a
-- list of generator indices. A negative power of @Delta@ in the input is an error.
expandPosXWord :: Int -> [XGen] -> [Int]
expandPosXWord n = concatMap expand where

  -- the word of the positive half-twist
  deltaWord = _halfTwist n

  -- the index map @j -> n-j@
  mirror :: Int -> Int
  mirror j = n-j

  -- words for @L_i@ and @tau(L_i)@, indexed by @i@
  lWords    = listArray (1,n-1) [ _permutationBraid (posLPerm n i) | i<-[1..n-1] ] :: Array Int [Int]
  tauLWords = amap (map mirror) lWords

  expand (XSigma i) = [i]
  expand (XL     i) = lWords    ! i
  expand (XTauL  i) = tauLWords ! i
  expand (XDelta i)
    | i > 0     = concat (replicate i deltaWord)
    | i < 0     = error "expandPosXWord: negative delta power"
    | otherwise = []

-- word :: Braid n -> [Int]
-- word (Braid gens) = map brGenIdx gens

-- | Expands an \"X-word\" (negative @Delta@ powers are allowed here) into a braid.
-- Useful for debugging.
expandAnyXWord :: forall n. KnownNat n => [XGen] -> Braid n
expandAnyXWord xgens = result where

  result = composeMany (map expand xgens)

  -- the strand count is obtained from the result braid itself
  nStrands = numberOfStrands result

  deltaPos = halfTwist        :: Braid n
  deltaNeg = inverse deltaPos :: Braid n

  -- braids for @L_i@ and @tau(L_i)@, indexed by @i@
  lBraids    = listArray (1,nStrands-1) [ permutationBraid (posLPerm nStrands i) | i<-[1..nStrands-1] ] :: Array Int (Braid n)
  tauLBraids = amap tau lBraids

  expand :: XGen -> Braid n
  expand (XSigma i) = sigma i
  expand (XL     i) = lBraids    ! i
  expand (XTauL  i) = tauLBraids ! i
  expand (XDelta i)
    | i > 0     = composeMany (replicate   i  deltaPos)
    | i < 0     = composeMany (replicate (-i) deltaNeg)
    | otherwise = identity

--------------------------------------------------------------------------------

-- | @posL k@ (denoted as @L_k@) is a /positive word/ which
-- satisfies @Delta = L_k * sigma_k@, or:
--
-- > (inverse halfTwist) `compose` (posL k) ~=~ sigmaInv k
--
-- Thus we can replace any word with a positive word plus some @Delta^-1@\'s
--
posL :: KnownNat n => Int -> Braid n
posL k = result where
  result   = permutationBraid (posLPerm nStrands k)
  nStrands = numberOfStrands result

-- | @posR k@ (denoted as @R_k@) is a /permutation braid/ which
-- satisfies @Delta = sigma_k * R_k@, or:
--
-- > (posR k) `compose` (inverse halfTwist) ~=~ sigmaInv k
--
-- Thus we can replace any word with a positive word plus some @Delta^-1@\'s
--
posR :: KnownNat n => Int -> Braid n
posR k = result where
  result   = permutationBraid (posRPerm nStrands k)
  nStrands = numberOfStrands result

-- | The permutation realized by @posL k :: Braid n@: the reverse permutation
-- multiplied (on the right) by the @k@-th adjacent transposition.
-- The index must satisfy @0 < k < n@, otherwise an error is raised.
posLPerm :: Int -> Int -> Permutation
posLPerm n k =
  if k > 0 && k < n
    then P.reversePermutation n `P.multiply` P.adjacentTransposition n k
    else error "posLPerm: index out of range"

-- | The permutation realized by @posR k :: Braid n@: the @k@-th adjacent
-- transposition multiplied (on the right) by the reverse permutation.
-- The index must satisfy @0 < k < n@, otherwise an error is raised.
posRPerm :: Int -> Int -> Permutation
posRPerm n k =
  if k > 0 && k < n
    then P.adjacentTransposition n k `P.multiply` P.reversePermutation n
    else error "posRPerm: index out of range"

--------------------------------------------------------------------------------

-- | We recognize left-greedy factors which are @Delta@-s (easy, since they are the only ones
-- with length @(n choose 2)@), and move them to the left, returning their summed exponent
-- and the remaining factors. Empty factors are dropped as well (they should only occur
-- for the trivial braid).
--
-- A factor which has an odd number of @Delta@-s to its right gets its
-- indices mapped by @j -> n-j@.
--
filterDeltaFactors :: Int -> [[Int]] -> (Int, [[Int]])
filterDeltaFactors n = foldr step (0,[]) where

  -- number of generators in the word of Delta
  deltaLen = div (n*(n-1)) 2

  -- the accumulator holds the number of Deltas seen so far (to the right of
  -- the current factor) and the already processed factors
  step fac (cnt,done)
    | null fac               = (cnt   , done)
    | length fac == deltaLen = (cnt+1 , done)
    | even cnt               = (cnt   , fac : done)
    | otherwise              = (cnt   , map (\j -> n-j) fac : done)

--------------------------------------------------------------------------------

-- | The /starting set/ of a positive braid P is the subset of @[1..n-1]@ defined by
--
-- > S(P) = [ i | P = sigma_i * Q , Q is positive ] = [ i | (sigma_i^-1 * P) is positive ]
--
-- This function returns the starting set of a positive word, assuming it
-- is a /permutation braid/ (see Lemma 2.4 in [2]). It is computed as the
-- finishing set of the reversed word.
--
permWordStartingSet :: Int -> [Int] -> [Int]
permWordStartingSet n = permWordFinishingSet n . reverse

-- | The /finishing set/ of a positive braid P is the subset of @[1..n-1]@ defined by
--
-- > F(P) = [ i | P = Q * sigma_i , Q is positive ] = [ i | (P * sigma_i^-1) is positive ]
--
-- This function returns the finishing set, assuming the input is a /permutation braid/
--
-- The word is applied, generator by generator, to an array which starts out as
-- the identity arrangement of @[1..n]@; the result consists of those positions
-- @i@ where the final array has a descent (entry @i@ is bigger than entry @i+1@).
--
permWordFinishingSet :: Int -> [Int] -> [Int]
permWordFinishingSet n input = runST action where

  action :: forall s. ST s [Int]
  action = do
    perm <- newArray_ (1,n) :: ST s (STUArray s Int Int)
    forM_ [1..n] $ \i -> writeArray perm i i
    -- each generator swaps the entries at positions i and i+1
    forM_ input $ \i -> do
      a <- readArray perm  i
      b <- readArray perm (i+1)
      writeArray perm  i    b
      writeArray perm (i+1) a
    flip filterM [1..n-1] $ \i -> do
      a <- readArray perm  i
      b <- readArray perm (i+1)
      return (b<a)                    -- Lemma 2.4 in [2]

-- | The starting set, computed directly from the permutation (as the finishing
-- set of the inverse permutation). This satisfies
--
-- > permutationStartingSet p == permWordStartingSet n (_permutationBraid p)
--
permutationStartingSet :: Permutation -> [Int]
permutationStartingSet perm = permutationFinishingSet (P.inverse perm)

-- | The finishing set, computed directly from the permutation: the positions
-- @i@ where the entry at @i@ is bigger than the entry at @i+1@. This satisfies
--
-- > permutationFinishingSet p == permWordFinishingSet n (_permutationBraid p)
--
permutationFinishingSet :: Permutation -> [Int]
permutationFinishingSet (Permutation arr) = filter isDescent [1 .. n-1]
  where
    (1,n) = bounds arr
    isDescent i = arr ! i > arr ! (i+1)

-- | Returns the list of permutations failing Lemma 2.5 in [2]
-- (so an empty list means the implementation is correct).
--
-- For every permutation @p@ of @[1..n]@ with word @w@, and every generator
-- index @i@, it is checked that @i:w@ is a permutation braid exactly when
-- @i@ is not in the starting set of @w@.
fails_lemmma_2_5 :: Int -> [Permutation]
fails_lemmma_2_5 n = [ p | p <- P.permutations n , not (test p) ] where
  test p = and [ check i | i<-[1..n-1] ] where
    w = _permutationBraid p
    s = permWordStartingSet n w
    check i = _isPermutationBraid n (i:w) == notElem i s

--------------------------------------------------------------------------------

-- | Given factors defined as permutation braids, we normalize them
-- to /left-canonical form/ by ensuring that
--
-- * for each consecutive pair @(P,Q)@ the finishing set F(P) contains the starting set S(Q)
--
-- * all @Delta@-s (corresponding to the reverse permutation) are moved to the left
--
-- * all trivial factors are filtered out
--
-- A single sweep ('normalizePermFactors1') may not be enough, so sweeps are
-- repeated until one of them neither extracts a @Delta@ nor changes the list
-- of factors. The exponents extracted by the individual sweeps are summed.
--
normalizePermFactors :: Int -> [Permutation] -> (Int,[Permutation])
normalizePermFactors n = sweep 0 where
  sweep !total factors
    | moved == 0 && factors == factors' = (total, factors)
    | otherwise                         = sweep (total + moved) factors'
    where
      (moved, factors') = normalizePermFactors1 n factors

-- | Does 1 sweep of the normalization process of 'normalizePermFactors'.
-- Unfortunately, it seems that we may need to do this multiple times...
--
-- Returns the number of @Delta@ factors extracted during the sweep, and the
-- remaining factors.
--
normalizePermFactors1 :: Int -> [Permutation] -> (Int,[Permutation])
normalizePermFactors1 n input = (dexp, reverse output) where
  (dexp, output) = worker 0 (reverse input)

  -- Notes: We work in reverse order, from the right to the left.
  -- We maintain the number of Delta-s pushed through; the tau involutions
  -- are implicit in the parity of this number
  --
  worker :: Int -> [Permutation] -> (Int,[Permutation])
  worker = worker' 0 0

  -- We also maintain additional 0/1 flip flags for the first two permutations
  -- this is a little bit of hack but it should work nicely
  --
  worker' :: Int -> Int -> Int -> [Permutation] -> (Int,[Permutation])
  worker' !ep !eq !e (!p : rest@(!q : rest'))

    -- check if the very first element is identity or Delta
    -- (note: these are tau-invariants)

    | isIdentityPermutation p  = worker'  eq  0  e    rest
    | isReversePermutation  p  = worker'  eq  0 (e+1) rest

    -- check if the second element is identity or Delta
    -- this is necessary since we "fatten" the second element and it can possibly
    -- become Delta after a while (?)

    | isIdentityPermutation q  = worker'  ep    0  e    (p : rest')
    | isReversePermutation  q  = worker' (ep-1) 0 (e+1) (p : rest')

    -- ok so we have something like "... : Q : P"
    -- if F(Q) contains S(P) then we can move on;
    -- otherwise there is an element j in S(P) \\ F(Q), so we can
    -- replace it by "... : Qj : jP"

    | otherwise =
        case permutationStartingSet preal \\ permutationFinishingSet qreal of
          []    -> let (e',rs) = worker' eq 0 e rest in (e', preal : rs)
          (j:_) -> worker' (-e) (-e) e (p':q':rest') where
                     -- NOTE(review): the binding of @s@ was missing from the
                     -- source; it is reconstructed here as the transposition of
                     -- sigma_j, following the "Qj : jP" comment above -- confirm
                     s  = P.adjacentTransposition n j
                     p' = P.multiply s preal
                     q' = P.multiply qreal s
    where
      preal = oddTau (e+ep) p       -- the "real" p
      qreal = oddTau (e+eq) q       -- the "real" q

  worker' _   _  !e [ ] = (e,[])
  worker' !ep _  !e [p]
    | isIdentityPermutation p  = (e   , [])
    | isReversePermutation  p  = (e+1 , [])
    | otherwise                = (e   , [oddTau (e+ep) p] )

  -- applies the tau involution when the exponent is odd
  oddTau :: Int -> Permutation -> Permutation
  oddTau !e p = if even e then p else tauPerm p

{-
checkDelta :: Int -> Permutation -> [Permutation] -> (Int,[Permutation])
checkDelta !e !p !rest
| P.isIdentityPermutation p  = worker  e    rest
| isReversePermutation    p  = worker (e+1) rest
| otherwise                  = let (e',rs) = worker e rest in (e', oddTau e p : rs)
-}

--------------------------------------------------------------------------------

-- | Given a /positive/ word, we apply left-greedy factorization of
-- that word into subwords representing /permutation braids/.
--
-- Example 5.1 from the above handbook:
--
-- > leftGreedyFactors 7 [1,3,2,2,1,3,3,2,3,2] == [[1,3,2],[2,1,3],[3,2,3],[2]]
--
-- The word is scanned from the left while keeping track of which strand sits
-- at which position, and of how many times each pair of strands has crossed
-- within the current factor. As soon as a pair would cross a second time, the
-- current factor is closed and a new one is started with that generator.
-- Empty factors are removed from the result.
--
leftGreedyFactors :: Int -> [Int] -> [[Int]]
leftGreedyFactors n input = filter (not . null) $ runST (action input) where

  action :: forall s. [Int] -> ST s [[Int]]
  action word = do

    -- perm ! i is the strand currently at position i
    perm <- newArray_ (1,n) :: ST s (STUArray s Int Int)
    forM_ [1..n] $ \i -> writeArray perm i i
    let doSwap :: Int -> ST s ()
        doSwap i = do
          a <- readArray perm  i
          b <- readArray perm (i+1)
          writeArray perm  i    b
          writeArray perm (i+1) a

    -- symmetric matrix of crossing counts between pairs of strands
    mat <- newArray ((1,1),(n,n)) 0 :: ST s (STUArray s (Int,Int) Int)
    let clearMat = forM_ [1..n] $ \i ->
                     forM_ [1..n] $ \j -> writeArray mat (i,j) 0

    -- increments the crossing count of the pair, and returns the new count
    let doAdd1 :: Int -> Int -> ST s Int
        doAdd1 i j = do
          x <- readArray mat (i,j)
          let y = x+1
          writeArray mat (i,j) y
          writeArray mat (j,i) y
          return y

    let worker :: [Int] -> ST s [[Int]]
        worker []     = return [[]]
        worker (p:ps) = do
          u <- readArray perm  p
          v <- readArray perm (p+1)
          c <- doAdd1 u v
          doSwap p
          if c<=1
            then do
              -- a total case instead of a failable pattern, since the latter
              -- would need a MonadFail instance for ST
              rest <- worker ps
              case rest of
                (f:fs) -> return ((p:f):fs)
                []     -> return [[p]]       -- unreachable: worker never returns []
            else do
              -- second crossing of the same pair: close the factor and
              -- process the same generator again with a cleared matrix
              clearMat
              fs <- worker (p:ps)
              return ([]:fs)

    worker word

--------------------------------------------------------------------------------

{-

-- | Finds ternary braid relations, and returns them as a list of indices, decorated
-- with a flag specifying which side of the relation we found, a sign specifying
-- whether it is a relation between positive or negative generators.
--
findTernaryBraidRelations :: Braid n -> [(Int,Bool,Sign)]
findTernaryBraidRelations (Braid gens) = go 0 gens where
go !k (Sigma a : rest@(Sigma b : Sigma c : _))
| a==c && b==a+1 = (k,True ,Plus) : go (k+1) rest
| a==c && b==a-1 = (k,False,Plus) : go (k+1) rest
| otherwise      =                  go (k+1) rest
go !k (SigmaInv a : rest@(SigmaInv b : SigmaInv c : _))
| a==c && b==a+1 = (k,True ,Minus) : go (k+1) rest
| a==c && b==a-1 = (k,False,Minus) : go (k+1) rest
| otherwise      =                   go (k+1) rest
go !k (x:xs) = go (k+1) xs
go _  []     = []

-- | Finds subsequences like @(i,i+1,i)@ and @(i+1,i,i+1)@, and returns them
-- and a list of indices, plus a flag specifying which one we found (the first
-- one is 'True', second one is 'False')
--
_findTernaryBraidRelations :: [Int] -> [(Int,Bool)]
_findTernaryBraidRelations = go 0 where
go !k (a:rest@(b:c:_))
| a==c && b==a+1 = (k,True ) : go (k+1) rest
| a==c && b==a-1 = (k,False) : go (k+1) rest
| otherwise      =             go (k+1) rest
go !k (x:xs) = go (k+1) xs
go _  []     = []

-}

--------------------------------------------------------------------------------

```