{-# LANGUAGE NoImplicitPrelude
        , NoMonomorphismRestriction, RebindableSyntax, ConstraintKinds, RankNTypes #-}
module Knots.Complex where

import Knots.Prelude hiding (Rational)
import Prelude (Rational)
import Control.DeepSeq
import Control.Parallel.Strategies
import Knots.Free
import Knots.Graded
import Knots.Morphism
import Knots.PD
import Knots.Util
import Knots.Khovanov
import qualified Data.Map as Map
import qualified Data.Set as Set

-- | A complex, given as the list of its differentials in order: the i-th
-- map is followed by the (i+1)-th. Domain and codomain of every map use the
-- same basis type @b@; coefficients live in @r@.
type Complex b r = [ Lin b b r ]
-- | A family of complexes indexed by an 'Int' grading.
type GradedComplex b r = Map Int (Complex b r)

-- | Ranks of the cohomology groups of a complex.
--
-- The complex is first reduced with 'gaussComplex'. That function appends a
-- terminal map, so the result has one more entry than the input has maps.
-- For each reduced map, the rank is its domain dimension minus its number
-- of pivots ('steps').
cohomology :: (NFData b, NFData r, Basis b, Field r, Eq r) => Complex b r -> [Int]
cohomology cx = [ Set.size (dom phi) - length (steps phi) | phi <- gaussComplex cx ]

-- | Apply 'cohomology' to every graded piece of a graded complex, keeping
-- the grading keys unchanged.
cohomologyGraded :: (NFData b, NFData r, Basis b, Field r, Eq r) => GradedComplex b r -> Map Int [Int]
cohomologyGraded gcx = Map.map cohomology gcx

-- | Build a matrix from a list of columns.
--
-- @mx m n cols@ is meant to receive @n@ lists of @m@ coefficients each:
-- - the outer list is numbered @1..n@ (the basis passed first to 'lin');
-- - each inner list is numbered @1..m@ (the basis passed second).
--
-- Zero coefficients are dropped by 'indexify'.
--
-- NOTE(review): assumes 'lin' takes the domain basis before the codomain
-- basis, like the 'Lin' constructor. Under that reading each inner list is
-- the image of one domain basis vector, i.e. a column.
mx m n vs = lin domainBasis codomainBasis entries
  where
    domainBasis   = [1 .. n]
    codomainBasis = [1 .. m]
    entries       = indexify [ indexify v | v <- vs ]

-- | Number the entries of a list from 1, combine them into a free-module
-- element with 'plus', and drop zero coefficients.
indexify xs = purify (plus (zip [1 ..] xs))
-- | Remove every basis vector whose coefficient equals 'zero'.
purify v = liftF (Map.filter (\coeff -> coeff /= zero)) v

-- | Remove zero entries from a morphism's matrix: first inside every column,
-- then columns that have become zero.
purifyLin phi = onMatrix cleanup phi
  where
    cleanup mat = purify (fmap purify mat)

-- | A small hand-written example complex over the rationals.
--
-- It has three differentials with dimensions 4 -> 6 -> 12 -> 8. Each inner
-- list passed to 'mx' is one column, i.e. the image of one domain basis
-- vector. Composing consecutive maps gives zero (checked by hand), so
-- @isComplex sampleComplex@ is expected to hold.
sampleComplex :: Complex Int Rational
sampleComplex =
    -- first map: 4 columns of length 6
    [ mx 6 4  [ [1,0,1,0,1,0]
              , [0,1,0,1,0,1]
              , [0,1,0,1,0,1]
              , [0,0,0,0,0,0]
              ]
    -- second map: 6 columns of length 12
    , mx 12 6  [ [0,1,1,0,0,0,0,0,0,1,1,0]
               , [0,0,0,1,0,0,0,0,0,0,0,1]
               , [0,-1,-1,0,0,1,1,0,0,0,0,0]
               , [0,0,0,-1,0,0,0,1,0,0,0,0]
               , [0,0,0,0,0,-1,-1,0,0,-1,-1,0]
               , [0,0,0,0,0,0,0,-1,0,0,0,-1]
               ]
    -- third map: 12 columns of length 8
    , mx 8 12 [ [0,0,1,0,1,0,0,0]
              , [0,0,0,1,0,1,0,0]
              , [0,0,0,0,0,0,1,0]
              , [0,0,0,0,0,0,0,1]
              , [0,1,0,0,1,0,0,0]
              , [0,0,0,0,0,1,0,0]
              , [0,0,0,1,0,0,1,0]
              , [0,0,0,0,0,0,0,1]
              , [0,-1,-1,0,0,0,0,0]
              , [0,0,0,-1,0,0,0,0]
              , [0,0,0,0,0,-1,-1,0]
              , [0,0,0,0,0,0,0,-1]
              ]

    ]

-- | Check that consecutive differentials compose to zero.
--
-- Each map is composed with its predecessor; the first map is composed with
-- 'zero'. The shifted list is one element longer, so 'oooo' also composes
-- the last map into 'zero'. Every composite must be a null matrix.
isComplex :: (Basis b, RingEq r) => Complex b r -> Bool
isComplex cx = all isNullMatrix composites
  where
    composites = cx `oooo` (zero : cx)
-- An alternative formulation checking only adjacent pairs:
--
-- >  isComplex cx = case cx of
-- >      []       -> True
-- >      [_]      -> True
-- >      (x:y:zs) -> isNullMatrix (y `o` x) && isComplex (y:zs)

-- | The pivot rows of a morphism: for every column that has at least one
-- entry, the smallest codomain basis vector occurring in it.
--
-- The result consists of /codomain/ basis vectors. 'gaussComplex' removes
-- them from the domain of the following map (whose domain is this map's
-- codomain) and from the codomain of the final map.
--
-- If the morphism is in column echelon form, distinct columns have distinct
-- pivots, so @length (steps phi)@ is the rank of @phi@. This is how
-- 'cohomology' uses it.
--
-- Columns without any entries have no pivot and are skipped, rather than
-- making the lookup of a minimum fail.
--
-- NOTE(review): assumes 'columns' maps each domain basis vector to its
-- column, itself a 'Map' keyed by codomain basis vectors; confirm in
-- Knots.Morphism.
steps = Map.elems . Map.mapMaybe pivotRow . columns
  where
    -- smallest key of a column, or Nothing for an empty column
    pivotRow col = fmap (fst . fst) (Map.minViewWithKey col)

-- | Drop the given basis vectors (together with their coefficients) from a
-- free-module element.
--
-- The exclusion list is converted to a 'Set' once per call. Each remaining
-- key is then checked in O(log k) instead of scanning the whole list with
-- 'notElem'. 'Basis' already provides the 'Ord' instance 'Set' needs.
omit :: (Basis b) => [b] -> Free b r -> Free b r
omit bs = liftF (Map.filterWithKey (\b _ -> Set.notMember b excluded))
  where
    excluded = Set.fromList bs

-- | Remove the given domain basis vectors from a morphism. They are taken
-- out of the domain set and their columns are deleted from the matrix; the
-- codomain is left untouched.
omitColumns :: (Basis b) => [b] -> Lin b c r -> Lin b c r
omitColumns bs (Lin from to mx) = Lin remaining to (omit bs mx)
  where
    remaining = from `Set.difference` Set.fromList bs

-- | Reduce a complex by elementary column operations, from the first
-- differential to the last.
--
-- Each map is processed in two steps:
-- 1. The columns named by the pivot rows ('steps') of the previously reduced
--    map are deleted from its domain.
-- 2. The result is brought into column echelon form with 'gauss'.
--
-- A terminal map with an empty codomain is appended. Its domain is the last
-- codomain minus that map's pivots, so the output has one more map than the
-- input. Reading off domain size minus pivot count per position gives the
-- cohomology ranks computed by 'cohomology'.
gaussComplex :: (NFData b, NFData r, Basis b, Field r, Eq r) => Complex b r -> Complex b r
gaussComplex cx = sweep zero cx
  where
    sweep prev rest = case rest of
        []       -> [ Lin (unreached prev) Set.empty zero ]
        (d : ds) -> let d' = gauss (omitColumns (steps prev) d)
                    in  d' : sweep d' ds
    -- codomain vectors of the last map that are not pivots
    unreached prev = Set.difference (cod prev) (Set.fromList (steps prev))

-- | A 'Graded' family of complexes whose basis vectors are lists of 'B'.
-- 'ooo' reads its 'grade' and its 'components' map (keyed like the grade)
-- and composes such families.
type C r = Graded (Complex [B] r)

-- | Lists as coefficient sequences, added position by position.
--
-- When one list is shorter, the tail of the longer one is kept as is.
-- 'zero' is the empty list. Trailing zeros are not trimmed, so @[0]@ and
-- @[]@ are different values.
--
-- NOTE(review): orphan instance (class imported, list type from base).
instance (AbelianGroup a) => AbelianGroup [a] where
    zero = []
    negate xs = [ negate x | x <- xs ]
    xs + ys = case (xs, ys) of
        (x : xs', y : ys') -> (x + y) : (xs' + ys')
        ([], _)            -> ys
        (_, [])            -> xs

-- | Lists as polynomial coefficient sequences (lowest degree first).
--
-- Multiplication is the Cauchy product: for @f:fs@ and @g:gs@ the head is
-- @f*g@ and the tail is @[f]*gs + fs*(g:gs)@. A product with an empty list
-- is empty. Integer constants become one-element lists.
instance (Ring a) => Ring [a] where
    fromInteger k = [fromInteger k]
    xs * ys = case (xs, ys) of
        (f : fs, g : gs) -> (f * g) : (([f] * gs) + (fs * ys))
        _                -> []

-- | Compose two graded families of complexes.
--
-- For every key @k@, the component of @x@ stored at @k + h@ (where @h@ is
-- the grade of @y@) is combined with the component of @y@ at @k@ via
-- @xComponent `oooo` yComponent@.
--
-- A key present on only one side is combined with the empty complex
-- ('zero', i.e. @[]@). The result therefore contains every key of @y@ and
-- every shifted key of @x@. The grades are added.
ooo :: (RingEq r) => C r -> C r -> C r
ooo (Graded g x) (Graded h y) = Graded { grade = g + h, components = merged }
  where
    -- re-key x so that its entry for k+h lines up with y's entry for k
    xAligned = Map.mapKeys (subtract h) x
    both     = Map.intersectionWith oooo xAligned y
    onlyY    = fmap (zero `oooo`) (Map.difference y xAligned)
    onlyX    = fmap (`oooo` zero) (Map.difference xAligned y)
    -- the three maps have pairwise disjoint keys
    merged   = Map.unions [both, onlyY, onlyX]

-- | Compose two lists of morphisms position by position (@g_i `o` f_i@).
-- When one list runs out, the remaining maps of the other are composed with
-- 'zero' on the missing side, so the result is as long as the longer input.
gs `oooo` fs = case (gs, fs) of
    (g : gs', f : fs') -> (g `o` f) : (gs' `oooo` fs')
    ([], _)            -> map (zero `o`) fs
    (_, [])            -> map (`o` zero) gs

-- | 'khCx' specialised to rational coefficients.
khCx_Q :: Ord a => Maybe a -> PD a -> GradedComplex IntPair Rational
khCx_Q mark pd = khCx mark pd

-- | 'khCx' specialised to integer coefficients.
khCx_Z :: Ord a => Maybe a -> PD a -> GradedComplex IntPair Integer
khCx_Z mark pd = khCx mark pd

-- | 'khCx' specialised to coefficients in 'F2'.
khCx_F2 :: Ord a => Maybe a -> PD a -> GradedComplex IntPair F2
khCx_F2 mark pd = khCx mark pd

-- | Split a map whose keys are sets by key size. For each @k@ in @range@
-- (in order, repeats allowed), return the submap of entries whose key has
-- exactly @k@ elements.
groupByKeySize range mp = map withSize range
  where
    withSize k = Map.filterWithKey (\key _ -> Set.size key == k) mp

-- | Example: compute rational Khovanov homology of a planar diagram and
-- return its ungraded ranks. Nothing is printed; the caller gets the list.
--
-- The graded complex from 'khCx_Q' (built without a mark) has its
-- 'cohomology' evaluated per grading in parallel: 'parTraversable' sparks
-- each grading, and every rank is forced with 'rseq'.
--
-- The per-grading rank lists are then summed position by position through
-- the list 'AbelianGroup' instance. The result is one total rank per
-- position of the complex.
example :: Ord a => PD a -> [Int]
example input =
    let perGrading = fmap cohomology (khCx_Q Nothing input)
                       `using` parTraversable (evalList rseq)
    in  foldr (+) [] perGrading

-- | Build the graded Khovanov-style complex of a planar diagram.
--
-- @mark@ is passed unchanged to 'khovanov'. @n@ is @length pd@ (presumably
-- the number of crossings — confirm against Knots.PD). The definition is a
-- composition applied to @pd@, so the stages below run bottom-up.
--
-- NOTE(review): the exact shapes produced by 'khovanov', 'convertMap2',
-- 'convert', 'join_lin', 'convertMap4' and 'choose''' are defined
-- elsewhere. The stage comments describe only what the visible code does.
khCx mark pd = let n = length pd in
      -- last stage: rename all basis vectors to 'IntPair' via 'hash'
      hashInComplex
      -- entries that are 'Nothing' become the zero map
    . map2 (maybe zero id)
    . convertMap4
    . map2 join_lin
      -- shift the keys of the i-th block by its position i
    . map (\(s,mp) -> Map.mapKeys (+ s) mp)
    . zip [0..]
    . map convert
      -- wrap the i-th block as a 'Lin' between the bases choose'' n i and
      -- choose'' n (i+1) (presumably the size-i and size-(i+1) subsets)
    . map (\((from,to),phi) -> Lin from to (map' F $ map2 components phi))
    . zip [ (choose'' n i, choose'' n (i+1)) | i <- [0..n-1] ]
      -- split entries by the size of their Set key: sizes 0 .. n-1
    . groupByKeySize [0..n-1]
    . convertMap2
      -- swap the two components of every key pair
    . Map.mapKeys swap
    . fmap (uncurry toMorphism)
    . khovanov mark
    $ pd

-- | Convert a list of inputs into one combined structure.
--
-- The i-th input (counting from 0) goes through three steps:
-- 1. It is turned into a morphism with @toMorphism x None@.
-- 2. Its 'components' map is taken.
-- 3. The keys of that map are shifted by @i@.
--
-- The shifted maps are passed to 'convertMap4', and any 'Nothing' entries of
-- the result are replaced by 'zero'.
conv xs = map2 (maybe zero id) (convertMap4 shifted)
  where
    blocks  = map (\x -> components (toMorphism x None)) xs
    shifted = [ Map.mapKeys (+ s) block | (s, block) <- zip [0..] blocks ]

-- | Encode a (set of indices, list of 'B') pair as an 'IntPair' of bit
-- fields.
--
-- - First field: the indices of the set.
-- - Second field: the positions (from 0) of the list elements satisfying
--   'isB1'.
--
-- NOTE(review): if 'IntPair' stores fixed-width Ints, any index at or above
-- the word size overflows in 'bitfield', and distinct inputs could collide.
hash :: (Set Int, [B]) -> IntPair
hash (s, l) = IntPair (bitfield s) (bitfield onePositions)
  where
    onePositions = [ i | (b, i) <- zip l [0 :: Int ..], isB1 b ]

-- | Sum of @2 ^ i@ over all elements @i@, accumulated strictly from the
-- left. For distinct non-negative elements this is the bit mask with
-- exactly those bits set. Duplicate elements are counted twice.
bitfield xs = foldl' addBit 0 xs
  where
    addBit acc i = acc + 2 ^ i

-- | Rename every basis vector of a graded complex via 'hash'. The rename
-- applies to domains, codomains and both levels of each matrix.
--
-- 'Set.map' merges elements that hash to the same value, so this relies on
-- 'hash' being injective on the basis vectors that actually occur.
hashInComplex :: GradedComplex (Set Int, [B]) r -> GradedComplex IntPair r
hashInComplex = Map.map (map rehash)
  where
    rehash (Lin from to mx) =
        Lin (Set.map hash from) (Set.map hash to) (fmap (mapBasis hash) (mapBasis hash mx))