module Math.Combinatorics.LatinSquares where
import qualified Data.List as L
import qualified Data.Set as S
import qualified Data.Map as M
import Math.Combinatorics.Design
import Math.Algebra.Field.Base
import Math.Algebra.Field.Extension
import Math.Algebra.LinearAlgebra (fMatrix')
import Math.Combinatorics.Graph
import Math.Combinatorics.GraphAuts
import Math.Combinatorics.StronglyRegularGraph
import Math.Core.Utils (combinationsOf)
-- | Enumerate every Latin square whose first row is exactly @xs@.
--
-- Rows are added one at a time. Each new row is a permutation of @xs@ that
-- avoids, in every column, the symbols already placed in that column. The
-- elements of @xs@ are assumed to be distinct.
--
-- Order 0 is handled explicitly: the empty alphabet yields exactly one
-- (empty) square. Without that clause the row counter, which starts at 1,
-- could never reach the target order 0 and the search would not terminate.
findLatinSqs :: (Eq a) => [a] -> [[[a]]]
findLatinSqs [] = [[]]
findLatinSqs xs = findLatinSqs' 1 [xs] where
    n = length xs
    -- i rows have been placed so far; they are held in reverse order in rows.
    findLatinSqs' i rows
        | i == n    = [reverse rows]
        | otherwise = concat [findLatinSqs' (i+1) (row:rows)
                             | row <- findRows (L.transpose rows) [] xs]
    -- Build one new row column by column:
    --   col  = symbols already used in the current column,
    --   ls   = symbols chosen so far for this row (reversed),
    --   rs   = symbols still available for this row.
    findRows (col:cols) ls rs = concat [findRows cols (r:ls) (L.delete r rs)
                                    | r <- rs, r `notElem` col]
    findRows [] ls _ = [reverse ls]
-- | Test whether a list of rows forms a Latin square: an @n x n@ array in
-- which every row and every column is a permutation of the same @n@ symbols.
--
-- The symbol set is taken from the first row. The empty array (order 0)
-- counts as a Latin square. Non-square or ragged arrays, and arrays using
-- more than @n@ distinct symbols (e.g. @[[1,2],[3,4]]@), are rejected.
isLatinSq :: (Ord a) => [[a]] -> Bool
isLatinSq [] = True
isLatinSq rows@(firstRow:_) =
    S.size symbols == n && all isPerm rows && all isPerm cols
  where
    n = length rows
    symbols = S.fromList firstRow
    -- A line (row or column) must have exactly n entries drawn from exactly
    -- the square's symbols. Together with S.size symbols == n, this forces
    -- every symbol to occur exactly once in the line.
    isPerm line = length line == n && S.fromList line == symbols
    -- Only evaluated after every row is known to have length n, so the
    -- transpose yields exactly n columns of length n.
    cols = L.transpose rows
-- | True when no element of the list occurs more than once.
isOneOfEach :: (Ord a) => [a] -> Bool
isOneOfEach xs = S.size distinct == length xs
  where
    distinct = S.fromList xs
-- | Latin square graph of @l@, built with 'graph'.
--
-- Vertices are the cells @(row, col, symbol)@ with 1-based row and column
-- indices. Two cells are adjacent when they share a row, a column, or a
-- symbol. Indices run up to @length l@ in both directions, so @l@ is
-- assumed to be square.
incidenceGraphLS l = graph (cells, adjacencies)
  where
    order = length l
    entryAt r c = l !! (r - 1) !! (c - 1)
    cells = [ (r, c, entryAt r c) | r <- [1 .. order], c <- [1 .. order] ]
    adjacencies = filter related (combinationsOf 2 cells)
    related [(r, c, s), (r', c', s')] = r == r' || c == c' || s == s'
    related _ = False
-- | Variant of 'incidenceGraphLS' whose vertices are just the cells
-- @(row, col)@, with 1-based indices.
--
-- Two cells are adjacent when they share a row, a column, or the same
-- symbol of @l@. Symbols are looked up through a 'M.Map' keyed by cell.
incidenceGraphLS' l = graph (cells, adjacencies)
  where
    order = length l
    cells = [ (r, c) | r <- [1 .. order], c <- [1 .. order] ]
    symbolAt = M.fromList [ ((r, c), l !! (r - 1) !! (c - 1)) | (r, c) <- cells ]
    adjacencies = filter related (combinationsOf 2 cells)
    related [(r, c), (r', c')] =
        r == r' || c == c' || symbolAt M.! (r, c) == symbolAt M.! (r', c')
    related _ = False
-- | Test whether two arrays (intended to be Latin squares of the same order)
-- are orthogonal.
--
-- The squares are superimposed cell by cell, and no ordered pair of symbols
-- may occur twice. For two Latin squares of order @n@ there are @n^2@ cells
-- and @n^2@ possible pairs, so this means every pair occurs exactly once.
--
-- Arrays of different shapes are never orthogonal. The shape check is
-- needed because 'zip' would otherwise silently truncate the longer array.
isOrthogonal :: (Ord a, Ord b) => [[a]] -> [[b]] -> Bool
isOrthogonal greeks latins =
    map length greeks == map length latins && isOneOfEach pairs
  where
    pairs = zip (concat greeks) (concat latins)
-- | All ways of choosing @k@ pairwise-orthogonal squares from @lsqs@.
--
-- Each selection keeps the order in which the squares appear in @lsqs@.
-- Selections that include a given square are listed before those that skip
-- it. @k = 0@ yields the single empty selection.
findMOLS :: (Eq t, Num t, Ord a) => t -> [[[a]]] -> [[[[a]]]]
findMOLS k lsqs = go k [] lsqs
  where
    -- go still-needed chosen-so-far(reversed) remaining-candidates
    go 0 chosen _ = [reverse chosen]
    go _ _ [] = []
    go need chosen (sq:rest)
        | all (isOrthogonal sq) chosen =
            go (need - 1) (sq:chosen) rest ++ go need chosen rest
        | otherwise = go need chosen rest
-- | True when every pair of squares in the list is orthogonal (checked with
-- 'isOrthogonal', each square against every square after it). Empty and
-- singleton lists are trivially mutually orthogonal.
isMOLS :: (Ord a) => [[[a]]] -> Bool
isMOLS squares =
    and [ isOrthogonal sq later | (sq:rest) <- L.tails squares, later <- rest ]
-- | Extract Latin squares from a projective plane given as a 'Design' whose
-- points are coordinate lists of length three.
--
-- Each parallel class of lines after the first two becomes one @n x n@
-- square with symbols @1..@, where @n@ is the number of coordinate values
-- collected in @k@ below.
-- NOTE(review): presumably the design is PG(2,q) and the result is a
-- complete set of q-1 MOLS; confirm against how the design is constructed.
fromProjectivePlane :: (Ord k, Num k) => Design [k] -> [[[Int]]]
fromProjectivePlane (D xs bs) = map toLS parallelClasses where
    k = [x | [0,1,x] <- xs] -- third coordinates of points [0,1,x]; presumably the field elements (TODO confirm)
    n = length k            -- number of such values, used as the order of each square
    parallelClasses = drop 2 $ L.groupBy (\l1 l2 -> head l1 == head l2) bs
    -- groupBy only merges *adjacent* lines that share the same first point,
    -- so this assumes bs lists the lines through a common first point
    -- contiguously (TODO confirm). The first two groups are dropped;
    -- presumably they are the classes that index rows and columns.
    -- (head errors on an empty line.)
    --
    -- toLS: the lines of one class are numbered i = 1, 2, ... Line i, whose
    -- first point must match [0,1,mu] (mu is unused), writes symbol i into
    -- every cell (x,y) for which [1,x,y] is one of its remaining points.
    -- Lines or points of any other shape are skipped by the pattern matches.
    -- NOTE(review): `k !! i` suggests fMatrix' calls its function with
    -- 0-based indices; confirm. M.! errors if a cell was never assigned.
    toLS ls = let grid = M.fromList [ ((x,y),i) | (i, [0,1,mu]:ps) <- zip [1..] ls, [1,x,y] <- ps]
              in fMatrix' n (\i j -> grid M.! (k !! i, k !! j))
-- | Test whether @rows@ is an orthogonal array OA(k,n) of strength 2 and
-- index 1. It is stored as @k@ rows of @n^2@ entries each, one column per run.
--
-- Conditions checked:
--
-- * there are exactly @k@ rows, each of length @n^2@;
-- * each row uses exactly @n@ distinct symbols;
-- * for every pair of rows, the @n^2@ vertically aligned pairs are distinct.
--
-- With @n@ symbols per row, @n^2@ distinct pairs means every ordered pair
-- occurs exactly once. Without the symbol-count check, an array such as
-- @[[1,2,3,4],[1,1,1,1]]@ would be wrongly accepted as an OA(2,2).
isOA :: (Ord a) => (Int, Int) -> [[a]] -> Bool
isOA (k,n) rows =
    length rows == k &&
    all ( (== n^2) . length ) rows &&
    all ( (== n) . S.size . S.fromList ) rows &&
    all isOneOfEach [zip ri rj | [ri,rj] <- combinationsOf 2 rows ]
-- | Convert a Latin square of order @n@ into the three rows of an OA(3,n).
--
-- The rows hold the row index, the column index and the symbol of each
-- cell. Cells are taken in row-major order and indices are 1-based.
fromLS :: [[Int]] -> [[Int]]
fromLS l = [rowIndices, colIndices, concat l]
  where
    n = length l
    rowIndices = [ r | r <- [1 .. n], _ <- [1 .. n] ]
    colIndices = [ c | _ <- [1 .. n], c <- [1 .. n] ]
-- | Convert a list of mutually orthogonal Latin squares of order @n@ into an
-- orthogonal array.
--
-- The first row holds the row index of each cell and the second its column
-- index (cells in row-major order, 1-based). Then comes one row per square,
-- holding its symbols. @fromMOLS [l] == fromLS l@.
--
-- The order @n@ is read from the first square. An empty list has no order
-- to read, so it is rejected with an explicit error rather than a bare
-- 'head' failure.
fromMOLS :: [[[Int]]] -> [[Int]]
fromMOLS [] = error "fromMOLS: need at least one Latin square to determine the order"
fromMOLS mols@(firstSq:_) =
    concat [replicate n i | i <- [1..n]]
  : concat (replicate n [1..n])
  : map concat mols
  where
    n = length firstSq
-- | Graph of an orthogonal array stored as rows, built with 'graph'.
--
-- There is one vertex per column of @rows@ (a run, read down the rows).
-- Two runs are adjacent when they agree in at least one coordinate.
graphOA rows = graph (runs, adjacencies)
  where
    runs = L.transpose rows
    adjacencies = filter agreeSomewhere (combinationsOf 2 runs)
    agreeSomewhere [u, v] = any (uncurry (==)) (zip u v)
    agreeSomewhere _      = False
    
-- | Strongly-regular-graph parameters @(v, degree, lambda, mu)@ associated
-- with an OA(k,n). Presumably these are the parameters of the graph built by
-- 'graphOA'. The result is always 'Just'.
srgParamsOA :: (Num a) => (a, a) -> Maybe (a, a, a, a)
srgParamsOA (k, n) = Just (vertices, degree, lam, mu)
  where
    vertices = n^2
    degree   = (n-1)*k
    lam      = n-2+(k-1)*(k-2)
    mu       = k*(k-1)