module Math.Algebra.Group.StringRewriting where
import Data.List as L
import Data.Maybe (catMaybes)
-- | Rewrite a word as far as the rules allow.
--
-- The rules are tried in list order.  The first rule that applies rewrites
-- the leftmost occurrence of its lhs (see 'rewrite1'), and the search then
-- restarts from the first rule.  The result is returned once no rule
-- applies.  May loop forever if the rules can rewrite indefinitely (e.g.
-- rules that are not shortlex-decreasing).
rewrite :: (Eq a) => [([a], [a])] -> [a] -> [a]
rewrite rules = go rules
  where
    go [] w = w
    go (rule : rest) w = maybe (go rest w) (go rules) (rewrite1 rule w)
-- | Apply the rule @(l, r)@ once: replace the leftmost occurrence of @l@ in
-- the word by @r@.
--
-- Returns 'Nothing' when @l@ does not occur.  Note that 'splitSubstring'
-- never matches inside the empty word, even for an empty @l@.
rewrite1 :: (Eq a) => ([a], [a]) -> [a] -> Maybe [a]
rewrite1 (lhs, rhs) word = do
    (before, after) <- splitSubstring word lhs
    return (before ++ rhs ++ after)
-- | @splitSubstring xs ys@ finds the first (leftmost) occurrence of @ys@ in
-- @xs@.  It returns the part of @xs@ before the occurrence and the part
-- after it.
--
-- Only non-empty suffixes of @xs@ are tried, so the result is always
-- 'Nothing' when @xs@ is empty (even if @ys@ is empty too).
splitSubstring :: (Eq a) => [a] -> [a] -> Maybe ([a], [a])
splitSubstring xs ys =
    case matches of
        (hit : _) -> Just hit
        []        -> Nothing
  where
    n = length ys
    matches = [ (before, drop n rest)
              | (before, rest@(_:_)) <- zip (L.inits xs) (L.tails xs)
              , ys `L.isPrefixOf` rest ]
-- | @findOverlap xs ys@ looks for @a@, @b@, @c@ with @xs = a ++ b@,
-- @ys = b ++ c@ and @b@ non-empty.  It returns @(a, b, c)@.
--
-- Suffixes of @xs@ are tried longest first (starting with @xs@ itself), so
-- only the longest such overlap is reported.  For example,
-- @findOverlap "abab" "abab"@ gives the full overlap, not the partial
-- overlap on @"ab"@.
findOverlap :: (Eq a) => [a] -> [a] -> Maybe ([a], [a], [a])
findOverlap xs ys =
    case overlaps of
        (hit : _) -> Just hit
        []        -> Nothing
  where
    overlaps = [ (front, back, drop (length back) ys)
               | (front, back@(_:_)) <- zip (L.inits xs) (L.tails xs)
               , back `L.isPrefixOf` ys ]
-- | First version of Knuth-Bendix completion for string rewriting.
--
-- Takes rules @(l, r)@ and returns the completed list of rules.  The rules
-- are presumably expected to be oriented already, with @l@ greater than @r@
-- in shortlex order (as 'ordpair' produces); this is not checked here.
-- Pending pairs of rules are kept in a plain FIFO list: new pairs are
-- appended at the end.  'knuthBendix' uses 'knuthBendix3', not this function.
--
-- NOTE(review): this version has three gaps:
--   * a rule is never paired with itself (@lri /= lrj@);
--   * 'findOverlap' reports only the longest overlap of two left-hand sides;
--   * rules removed by @reduce@ are dropped rather than re-normalised and
--     re-added as equations.
-- Standard completion handles all three, so confirm the output is confluent
-- and equivalent to the input before relying on it.
knuthBendix1 rules = knuthBendix' rules pairs where
    -- every ordered pair of distinct rules is a candidate critical pair
    pairs = [(lri,lrj) | lri <- rules, lrj <- rules, lri /= lrj]
    knuthBendix' rules [] = rules 
    knuthBendix' rules ( ((li,ri),(lj,rj)) : ps) =
        -- an overlap li = a++b, lj = b++c means the word a++b++c rewrites
        -- both to ri++c and to a++rj
        case findOverlap li lj of
        Nothing -> knuthBendix' rules ps
        Just (a,b,c) -> case ordpair (rewrite rules (ri++c)) (rewrite rules (a++rj)) of
                        -- both sides reach the same word: nothing new
                        Nothing -> knuthBendix' rules ps 
                        -- otherwise the two normal forms become a new rule; rules
                        -- whose lhs contains its lhs are removed, and the new rule
                        -- is queued against every remaining rule (both orders)
                        Just rule' -> let rules' = reduce rule' rules
                                          ps' = ps ++ [(rule',rule) | rule <- rules'] ++ [(rule,rule') | rule <- rules']
                                      in knuthBendix' (rule':rules') ps'



    -- keep only the rules whose lhs does not contain l
    reduce rule@(l,r) rules = filter (\(l',r') -> not (L.isInfixOf l l')) rules
        
-- | Orient an equation between two words into a rewrite rule
-- @(larger, smaller)@ under 'shortlex'.  Returns 'Nothing' when the two
-- words are equal.
ordpair :: (Ord a) => [a] -> [a] -> Maybe ([a], [a])
ordpair x y
    | cmp == GT = Just (x, y)
    | cmp == LT = Just (y, x)
    | otherwise = Nothing
  where
    cmp = shortlex x y
-- | Shortlex order on words: shorter words come first, and words of equal
-- length are compared lexicographically.
shortlex :: (Ord a) => [a] -> [a] -> Ordering
shortlex x y =
    case compare (length x) (length y) of
        EQ    -> compare x y
        other -> other
-- | Second version of Knuth-Bendix completion for string rewriting.
--
-- Same idea as 'knuthBendix1', with two changes:
--   * each rule is tagged with the length of its lhs, and the rule list is
--     kept sorted;
--   * pending pairs are tagged with the combined lhs length of their two
--     rules and kept in a sorted queue (via 'merge'), so pairs of short
--     rules are examined first.
-- The size tags are stripped from the result.
--
-- NOTE(review): this version has several gaps:
--   * a rule is never paired with itself (@sri /= srj@, and new pairs only
--     combine the new rule with the rules in @rules'@);
--   * 'findOverlap' reports only the longest overlap of two left-hand sides;
--   * rules removed by @reduce@ are dropped rather than re-normalised and
--     re-added as equations;
--   * queued pairs that mention a removed rule stay in the queue.
-- Confirm the output is confluent and equivalent to the input before
-- relying on it.
knuthBendix2 rules = map snd $ knuthBendix' rules' pairs where
    rules' = L.sort $ map sizedRule rules
    pairs = L.sort [sizedPair sri srj | sri <- rules', srj <- rules', sri /= srj]
    knuthBendix' rules [] = rules
    knuthBendix' rules ( (s,(li,ri),(lj,rj)) : ps) =
        -- an overlap li = a++b, lj = b++c means the word a++b++c rewrites
        -- both to ri++c and to a++rj
        case findOverlap li lj of
        Nothing -> knuthBendix' rules ps
        Just (a,b,c) -> case ordpair (rewrite (map snd rules) (ri++c)) (rewrite (map snd rules) (a++rj)) of
                        -- both sides reach the same word: nothing new
                        Nothing -> knuthBendix' rules ps 
                        -- new rule: remove rules whose lhs contains its lhs, and
                        -- merge its pairs with each remaining rule into the queue
                        Just rule' -> let rules' = reduce (snd rule') rules

                                          ps' = merge ps $ merge [sizedPair rule' rule | rule <- rules'] [sizedPair rule rule' | rule <- rules']
                                     in knuthBendix' (L.insert rule' rules') ps'
    -- keep only the sized rules whose lhs does not contain l
    reduce rule@(l,r) rules = filter (\(s',(l',r')) -> not (L.isInfixOf l l')) rules

    -- orient two words into a sized rule (lhs length, (larger, smaller)),
    -- comparing by (length, word) as 'shortlex' does; Nothing if equal
    ordpair x y =
        let lx = length x; ly = length y in
            case compare (lx,x) (ly,y) of
            LT -> Just (ly,(y,x)); EQ -> Nothing; GT -> Just (lx,(x,y))
    -- tag a rule with the length of its lhs
    sizedRule (rule@(l,r)) = (length l, rule)
    -- a queue entry: combined lhs length, then the two rules
    sizedPair (s1,r1) (s2,r2) = (s1+s2,r1,r2)
-- | Merge two lists, each sorted in ascending order, into one sorted list.
--
-- Used for the sorted critical-pair queues of 'knuthBendix2' and
-- 'knuthBendix3'.  An element that occurs in both lists is treated as a
-- logic error (presumably the callers never build equal queue entries),
-- and fails with a descriptive message instead of an empty one.
merge :: (Ord a) => [a] -> [a] -> [a]
merge (x:xs) (y:ys) =
    case compare x y of
    LT -> x : merge xs (y:ys)
    GT -> y : merge (x:xs) ys
    EQ -> error "Math.Algebra.Group.StringRewriting.merge: element occurs in both lists"
merge xs ys = xs ++ ys
-- | Knuth-Bendix completion for string rewriting systems.  This is the
-- version 'knuthBendix' uses.
--
-- Takes equations @(u, v)@ between words (in practice rules already
-- oriented by 'ordpair').  Returns rules @(l, r)@ with @l@ greater than @r@
-- in shortlex order.  The loop does not terminate for presentations that
-- have no finite complete system.
--
-- Internally:
--   * a rule is stored as @(length l, index, (l, r))@, every new rule gets a
--     fresh index, and the rule list is kept sorted;
--   * a candidate critical pair is stored as
--     @(combined lhs length, (i, j), (rule i, rule j))@ in a queue sorted by
--     that key, so pairs of short rules are examined first.
--
-- Invariant: no rule's lhs contains another rule's lhs.
--   * A new rule is built from two normal forms, so its lhs is irreducible.
--   * Adding a new rule removes every rule whose lhs contains the new lhs.
-- Removed rules are not forgotten.  Their equations are re-normalised and
-- added back.  Dropping them can change the presented monoid: from
-- @aab = 1, ba = b@ one derives @a = 1@, which removes both rules and would
-- otherwise lose the consequence @b = 1@.
--
-- Because of the invariant, every critical pair comes from a proper overlap,
-- where a suffix of one lhs is a prefix of another.  All such overlaps are
-- examined, including those of a rule with itself.
knuthBendix3 rules = complete rules0 pairs0 k0 where
    -- the input equations are fed through the same path as critical pairs,
    -- which also interreduces them
    (rules0, pairs0, k0) = addEquations rules [] [] (1 :: Int)

    -- Main loop: take the smallest pending pair, turn every overlap of the
    -- two left-hand sides into an equation, and add those equations.
    complete rs [] _ = map third rs
    complete rs ((_, _, ((li,ri),(lj,rj))) : ps) k =
        let criticalPairs = [(ri ++ c, a ++ rj) | (a, c) <- overlaps li lj]
            (rs', ps', k') = addEquations criticalPairs rs ps k
        in complete rs' ps' k'

    -- Add equations one at a time.  Each side is rewritten with the current
    -- rules; if the results differ they become a new rule with index k.
    addEquations [] rs ps k = (rs, ps, k)
    addEquations ((u,v) : eqs) rs ps k =
        case mkRule k (rewrite (map third rs) u) (rewrite (map third rs) v) of
        Nothing -> addEquations eqs rs ps k
        Just rule'@(_,_,(l,_)) ->
            let (outrules, inrules) = L.partition (\(_,_,(l',_)) -> l `L.isInfixOf` l') rs
                removedIndices = map second outrules
                live (_,(i,j),_) = i `notElem` removedIndices && j `notElem` removedIndices
                -- pairs of the new rule with itself and with every surviving
                -- rule; each list is sorted because inrules is, and the
                -- fresh index k keeps all entries distinct
                newPairs = merge [sizedPair rule' rule']
                         $ merge [sizedPair rule' rule | rule <- inrules]
                                 [sizedPair rule rule' | rule <- inrules]
                ps' = merge (filter live ps) newPairs
                -- equations of removed rules are re-normalised later, not lost
                eqs' = eqs ++ map third outrules
            in addEquations eqs' (L.insert rule' inrules) ps' (k+1)

    -- All ways of writing li = a ++ b and lj = b ++ c with b non-empty,
    -- returned as (a, c).
    overlaps li lj =
        [ (a, drop (length b) lj)
        | (a, b@(_:_)) <- zip (L.inits li) (L.tails li)
        , b `L.isPrefixOf` lj ]

    -- Orient two words into an indexed rule, shortlex-larger word first.
    -- Returns Nothing if the words are equal.
    mkRule k x y =
        let lx = length x; ly = length y in
            case compare (lx,x) (ly,y) of
            LT -> Just (ly,k,(y,x)); EQ -> Nothing; GT -> Just (lx,k,(x,y))
    second (_,i,_) = i
    third (_,_,r) = r
    sizedPair (si,i,ri) (sj,j,rj) = (si+sj,(i,j),(ri,rj))
-- | Complete a list of relations @(u, v)@, each meaning @u = v@, into a
-- rewriting system.
--
-- The steps are:
--   1. Orient each relation with 'ordpair' (shortlex-larger side on the
--      left), dropping relations whose two sides are equal.
--   2. Interreduce the resulting rules against each other.
--   3. Pass the result to 'knuthBendix3'.
knuthBendix :: (Ord a) => [([a], [a])] -> [([a], [a])]
knuthBendix relations = knuthBendix3 (interreduce [] oriented)
  where
    oriented = catMaybes (map (uncurry ordpair) relations)
    -- Take the pending rules one at a time.  Each one rewrites every other
    -- rule (already processed or still pending) using itself alone, and the
    -- rewritten pairs are re-oriented; trivial ones disappear.
    interreduce done [] = done
    interreduce done (rule : pending) =
        interreduce (rule : reduceBy rule done) (reduceBy rule pending)
    reduceBy rule eqs =
        catMaybes [ordpair (rewrite [rule] lhs) (rewrite [rule] rhs) | (lhs, rhs) <- eqs]
-- | List the words over the generators @gs@ that are irreducible with
-- respect to the rules @rs@, i.e. that contain no lhs as a subword.
--
-- Words are produced level by level in order of length, starting with the
-- empty word.  A word of length n+1 is generated as @g : w@, where @w@ is an
-- irreducible word of length n.  Since @w@ is already irreducible, @g : w@
-- is reducible exactly when some lhs is a prefix of it.
--
-- The list is infinite when there are infinitely many irreducible words,
-- but it is produced lazily.
nfs :: (Ord a) => ([a], [([a], [a])]) -> [[a]]
nfs (gs, rs) = concat (takeWhile (not . null) (iterate nextLevel [[]]))
  where
    lhss = map fst rs
    irreducible w = not (any (`L.isPrefixOf` w) lhss)
    nextLevel ws = [w' | g <- gs, w <- ws, let w' = g : w, irreducible w']
-- | List the elements of the monoid presented by generators @gs@ and
-- relations @rs@.  Each element is represented by its normal form: the
-- relations are completed with 'knuthBendix', and the irreducible words are
-- then listed with 'nfs'.
elts :: (Ord a) => ([a], [([a], [a])]) -> [[a]]
elts (gens, relations) = nfs (gens, completed)
  where
    completed = knuthBendix relations
-- | Generator @S i@ of the symmetric group as used by '_S' and '_S''
-- (presumably the adjacent transposition (i i+1)).  It is shown as
-- @s1@, @s2@, and so on.
newtype SGen = S Int
    deriving (Eq, Ord)

instance Show SGen where
    show (S idx) = concat ["s", show idx]
-- | @s_ i@ is the generator @S i@.
s_ :: Int -> SGen
s_ = S

-- | Shorthand names for the first three generators.
s1, s2, s3 :: SGen
s1 = s_ 1
s2 = s_ 2
s3 = s_ 3
-- | Coxeter presentation of the symmetric group S_n on generators
-- s_1 .. s_(n-1).
--
-- The relations are, in this order:
--   * s_i s_i = 1;
--   * (s_i s_(i+1))^3 = 1;
--   * (s_i s_j)^2 = 1 whenever j >= i+2.
-- Every relation has an empty right-hand side.
_S :: Int -> ([SGen], [([SGen], [SGen])])
_S n = (gens, [(w, []) | w <- squares ++ braids ++ farCommutes])
  where
    gens = [s_ i | i <- [1 .. n-1]]
    squares = [[s_ i, s_ i] | i <- [1 .. n-1]]
    braids = [concat (replicate 3 [s_ i, s_ (i+1)]) | i <- [1 .. n-2]]
    farCommutes = [[s_ i, s_ j, s_ i, s_ j] | i <- [1 .. n-1], j <- [i+2 .. n-1]]
-- | Alternative presentation of S_n on generators s_1 .. s_(n-1).
--
-- The relations are, in this order:
--   * s_i s_i = 1;
--   * the braid relation s_(i+1) s_i s_(i+1) = s_i s_(i+1) s_i;
--   * s_i s_j s_i s_j = 1 whenever j >= i+2.
_S' :: Int -> ([SGen], [([SGen], [SGen])])
_S' n = (gens, squares ++ braids ++ farCommutes)
  where
    gens = map s_ [1 .. n-1]
    squares = [([g, g], []) | g <- gens]
    -- consecutive generators (s_i, s_(i+1))
    braids = [([b, a, b], [a, b, a]) | (a, b) <- zip gens (drop 1 gens)]
    -- s_i paired with each s_j, j >= i+2, in increasing i then j
    farCommutes = [([a, b, a, b], []) | (a : rest) <- L.tails gens, b <- drop 1 rest]
-- | Presentation of the triangle group on involutions a, b, c.
--
-- The relations are a^2 = b^2 = c^2 = (ab)^l = (bc)^n = (ca)^m = 1, given
-- as relators with empty right-hand sides.  Note the pairing: n goes with
-- bc and m with ca.
tri :: Int -> Int -> Int -> (String, [(String, String)])
tri l m n = ("abc", map (\w -> (w, "")) relators)
  where
    relators = ["aa", "bb", "cc", power "ab" l, power "bc" n, power "ca" m]
    -- the word w repeated k times
    power w k = concat (replicate k w)
-- | Presentation of the von Dyck group on generators x, y.  The relations
-- are x^l = y^m = (xy)^n = 1, given as relators with empty right-hand sides.
_D :: Int -> Int -> Int -> (String, [(String, String)])
_D l m n = ("xy", [(w, "") | w <- [power "x" l, power "y" m, power "xy" n]])
  where
    -- the word w repeated k times
    power w k = concat (replicate k w)