-- | Bases in the cohomology of the spaces appearing in the computations.
--
-- We have three different spaces: 
--
-- * @Q^n = P^1 x P^1 x ... x P^1@ (@n@ times, @n = length lambda@)
--
-- * @Q^m = P^1 x P^1 x ... x P^1 x P^1@ (@m@ times, @m = sum lambda >= n@)
-- 
-- * @P^m = P(Sym^m C^2)@
--
-- Furthermore, we have @GL2@ acting naturally on these spaces.
--

{-# LANGUAGE 
      BangPatterns, TypeSynonymInstances, FlexibleInstances, DeriveFunctor, 
      ScopedTypeVariables, Rank2Types 
  #-}
module Math.RootLoci.Geometry.Cohomology where

--------------------------------------------------------------------------------

import Data.List
import Data.Monoid

import Math.Combinat.Numbers

import qualified Data.Map as Map
import qualified Data.Set as Set

import qualified Math.RootLoci.Algebra.FreeMod as ZMod
import Math.RootLoci.Algebra.FreeMod ( ZMod , FreeMod(..) , unFreeMod )

import Math.RootLoci.Algebra.SymmPoly 
import Math.RootLoci.Misc.Pretty

--------------------------------------------------------------------------------
-- * The non-equivariant case

-- | A ring generator @u_i@ of @H^*(Q^n)@, identified by its index @i@
-- (note that @u_i^2 = 0@).
newtype U
  = U Int
  deriving (Eq, Ord, Show)

-- | A ring generator @h_i@ of @H^*(Q^m)@, identified by its index @i@
-- (note that @h_i^2 = 0@).
newtype H
  = H Int
  deriving (Eq, Ord, Show)

-- | A power @g^e@ of the generator @g@ of @H^*(P^m)@, stored as the exponent
-- @e@. The truncation @g^(m+1) = 0@ is not enforced by this type.
newtype G
  = G Int
  deriving (Eq, Ord, Show)

-- | A monomial in the @u_i@. Because @u_i^2 = 0@, it is encoded as the list
-- of generators that appear, i.e. a subset of @[1..n]@.
newtype US
  = US [U]
  deriving (Eq, Ord, Show)

-- | A monomial in the @h_i@. Because @h_i^2 = 0@, it is encoded as the list
-- of generators that appear, i.e. a subset of @[1..m]@.
newtype HS
  = HS [H]
  deriving (Eq, Ord, Show)

--------------------------------------------------------------------------------

-- | Multiplication of square-free monomials in the @u_i@. The two index lists
-- are concatenated and sorted. Since @u_i^2 = 0@, a product with a repeated
-- index is not a basis monomial, and is reported with 'error'.
instance Monoid US where
  mempty = US []
  (US us1) `mappend` (US us2)
    | strictlyIncreasing us3 = US us3
    | otherwise              = error "[U]/monoid: duplicate indices"
    where
      us3 = sort (us1 ++ us2)
      -- A sorted list is duplicate-free iff consecutive elements strictly
      -- increase. This is O(n), unlike the O(n^2) test @nub us3 == us3@.
      strictlyIncreasing xs = and (zipWith (<) xs (drop 1 xs))

-- | Multiplication of square-free monomials in the @h_i@. The two index lists
-- are concatenated and sorted. Since @h_i^2 = 0@, a product with a repeated
-- index is not a basis monomial, and is reported with 'error'.
instance Monoid HS where
  mempty = HS []
  (HS hs1) `mappend` (HS hs2)
    | strictlyIncreasing hs3 = HS hs3
    | otherwise              = error "[H]/monoid: duplicate indices"
    where
      hs3 = sort (hs1 ++ hs2)
      -- A sorted list is duplicate-free iff consecutive elements strictly
      -- increase. This is O(n), unlike the O(n^2) test @nub hs3 == hs3@.
      strictlyIncreasing xs = and (zipWith (<) xs (drop 1 xs))

-- | Powers of @g@ multiply by adding their exponents. No truncation by the
-- top degree is applied here.
instance Monoid G where
  mempty = G 0
  mappend (G e1) (G e2) = G (e1 + e2)
 
--------------------------------------------------------------------------------

-- | Rendered as @g^e@.
instance Pretty G where
  pretty (G e) = concat ["g^", show e]

-- | Rendered as @h@ followed by the index.
instance Pretty H where
  pretty (H i) = 'h' : show i

-- | Rendered as @u@ followed by the index.
instance Pretty U where
  pretty (U i) = 'u' : show i

-- | The factors joined by @*@. The empty monomial renders as the empty string.
instance Pretty HS where
  pretty (HS hs) = intercalate "*" [ pretty h | h <- hs ]

-- | The factors joined by @*@. The empty monomial renders as the empty string.
instance Pretty US where
  pretty (US us) = intercalate "*" [ pretty u | u <- us ]

--------------------------------------------------------------------------------

-- Single generators have degree 1. A power of @g@ has degree equal to its
-- exponent. A square-free monomial has degree equal to its number of factors.
instance Graded U  where grade = const 1
instance Graded H  where grade = const 1
instance Graded G  where grade (G e)   = e
instance Graded HS where grade (HS hs) = length hs
instance Graded US where grade (US us) = length us

-- The degree of an equivariant monomial is the degree of its inner part plus
-- the number of @u@ / @eta@ factors, or plus the power of @gamma@.
instance Graded ab => Graded (Omega ab) where
  grade (Omega us ab) = grade ab + length us
instance Graded ab => Graded (Eta ab) where
  grade (Eta hs ab) = grade ab + length hs
instance Graded ab => Graded (Gam ab) where
  grade (Gam k ab) = grade ab + k

--------------------------------------------------------------------------------
-- * The equivariant case

-- | A monomial generator of @Z[alpha,beta;u1,u2,...,u_nd]/(...)@,
-- the cohomology ring of @Q^n@.
data Omega ab
  = Omega
      ![Int]  -- ^ the indices @i@ of the @u_i@ which appear
      !ab     -- ^ the factor in the remaining generators
  deriving (Eq, Ord, Show, Functor)

-- | A monomial generator of @Z[alpha,beta;eta1,eta2...eta_m]/(...)@,
-- the cohomology ring of @Q^m@.
data Eta ab
  = Eta
      ![Int]  -- ^ the indices @i@ of the @eta_i@ which appear
      !ab     -- ^ the factor in the remaining generators
  deriving (Eq, Ord, Show, Functor)

-- | A monomial generator of @Z[alpha,beta;gamma]/(...)@,
-- the cohomology ring of @P^m@.
data Gam ab
  = Gam
      !Int    -- ^ the exponent of @gamma@
      !ab     -- ^ the factor in the remaining generators
  deriving (Eq, Ord, Show, Functor)

--------------------------------------------------------------------------------

-- | Class of monomial bases which form modules over @H^*(BGL2)@.
--
-- Each basis element wraps an inner monomial. 'injectMonom' embeds an inner
-- monomial with the trivial outer part. 'projectMonom' extracts the inner
-- monomial again.
class Functor f => Equivariant f where
  injectMonom  :: x -> f x
  projectMonom :: f x -> x

-- | The trivial outer part is the empty list of @u@ indices.
instance Equivariant Omega where
  injectMonom x = Omega [] x
  projectMonom omega = case omega of Omega _ x -> x

-- | The trivial outer part is the empty list of @eta@ indices.
instance Equivariant Eta where
  injectMonom x = Eta [] x
  projectMonom eta = case eta of Eta _ x -> x

-- | The trivial outer part is @gamma^0@.
instance Equivariant Gam where
  injectMonom x = Gam 0 x
  projectMonom gam = case gam of Gam _ x -> x

-- | Embed a linear combination of inner monomials into the basis @f base@.
-- Every term is wrapped with 'injectMonom', i.e. given the trivial outer part.
injectZMod :: (Equivariant f, ChernBase base, Ord (f base)) => ZMod base -> ZMod (f base)
injectZMod zmod = ZMod.mapBase injectMonom zmod

-- | Keep only the terms whose @gamma@ exponent is zero, unwrapping them.
-- Terms with a nonzero power of @gamma@ map to 'Nothing'. Presumably
-- 'ZMod.filterBase' drops those terms.
forgetGamma :: Ord base => ZMod (Gam base) -> ZMod base 
forgetGamma = ZMod.filterBase gammaFree where
  gammaFree (Gam 0 ab) = Just ab
  gammaFree _          = Nothing

-- | Keep only the terms whose inner part is 'mempty', turning @Gam k mempty@
-- into the non-equivariant @G k@. All other terms map to 'Nothing'.
-- Presumably 'ZMod.filterBase' drops those terms.
forgetEquiv :: ChernBase base => ZMod (Gam base) -> ZMod G
forgetEquiv = ZMod.filterBase pureGamma where
  pureGamma (Gam k ab)
    | ab == mempty = Just (G k)
    | otherwise    = Nothing

--------------------------------------------------------------------------------
-- * Conversion between different bases

-- | Apply a conversion of the inner basis separately to each set of @u@
-- indices of an 'Omega' combination (see 'convertEach').
convertOmega
  :: (Ord ab, Ord cd)
  => (ZMod ab -> ZMod cd)
  -> ZMod (Omega ab) -> ZMod (Omega cd)
convertOmega = convertEach uIndices inner Omega where
  uIndices (Omega is _) = is
  inner    (Omega _  x) = x

-- | Apply a conversion of the inner basis separately to each set of @eta@
-- indices of an 'Eta' combination (see 'convertEach').
convertEta
  :: (Ord ab, Ord cd)
  => (ZMod ab -> ZMod cd)
  -> ZMod (Eta ab) -> ZMod (Eta cd)
convertEta = convertEach etaIndices inner Eta where
  etaIndices (Eta is _) = is
  inner      (Eta _  x) = x

-- | Apply a conversion of the inner basis separately to each power of
-- @gamma@ of a 'Gam' combination (see 'convertEach').
convertGam
  :: (Ord ab, Ord cd)
  => (ZMod ab -> ZMod cd)
  -> ZMod (Gam ab) -> ZMod (Gam cd)
convertGam = convertEach gammaPower inner Gam where
  gammaPower (Gam e _) = e
  inner      (Gam _ x) = x

-- | A generic function which can convert the @GL2@ representations.
--
-- Every basis element @k :: f ab@ is split into an outer \"layer\" @selx k@
-- and an inner part @sely k@. @build@ is expected to reassemble these two
-- parts. The terms of the source are grouped by layer. The inner parts of
-- each group are converted with @convert@, re-wrapped with @build layer@,
-- and all layers are summed.
--
-- The source is grouped in a single pass, rather than re-filtering the
-- whole source once per layer.
convertEach 
  :: forall f x ab cd. (Functor f, Ord ab, Ord cd, Ord (f ab), Ord (f cd), Ord x) 
  => (forall y. f y -> x)             -- ^ select the outer layer
  -> (forall y. f y -> y)             -- ^ select the inner part
  -> (forall y. x -> y -> f y)        -- ^ reassemble a basis element
  -> (ZMod    ab  -> ZMod    cd )     -- ^ conversion of the inner part
  ->  ZMod (f ab) -> ZMod (f cd)
convertEach selx sely build convert src = ZMod.sum (map convertLayer (Map.toList layers)) where
  -- layer -> (inner part -> coefficient), built in one traversal of src
  layers = Map.fromListWith Map.union
    [ (selx k, Map.singleton (sely k) c) | (k, c) <- Map.toList (unFreeMod src) ]
  -- convert one layer's inner combination and re-attach the layer to each key
  convertLayer (layer, inner)
    = FreeMod
    $ Map.mapKeys (build layer)
    $ unFreeMod
    $ convert
    $ FreeMod inner

--------------------------------------------------------------------------------

-- | This is a hack to reuse the same pushforward code: the @eta@ indices of
-- every term are reinterpreted as @u@ indices. Nothing is checked, hence
-- \"unsafe\".
unsafeEtaToOmega :: Ord ab => FreeMod coeff (Eta ab) -> FreeMod coeff (Omega ab)
unsafeEtaToOmega = ZMod.mapBase (\(Eta js ab) -> Omega js ab)

-- | The reverse relabelling: @u@ indices are reinterpreted as @eta@ indices.
unsafeOmegaToEta :: Ord ab => FreeMod coeff (Omega ab) -> FreeMod coeff (Eta ab)
unsafeOmegaToEta = ZMod.mapBase (\(Omega js ab) -> Eta js ab)

--------------------------------------------------------------------------------

-- | Multiplication of 'Omega' monomials. The @u@ index lists are concatenated
-- and sorted, and the inner parts are multiplied with '<>'. A repeated @u@
-- index (a @u_i^2@ factor) is reported with 'error'.
instance Monoid ab => Monoid (Omega ab) where
  mempty = Omega [] mempty
  (Omega as ab1) `mappend` (Omega bs ab2)
    | strictlyIncreasing cs = Omega cs (ab1 <> ab2)
    | otherwise             = error "Omega/monoid: duplicate indices"
    where
      cs = sort (as ++ bs)
      -- A sorted list is duplicate-free iff consecutive elements strictly
      -- increase. This is O(n), unlike the O(n^2) test @nub cs == cs@.
      strictlyIncreasing xs = and (zipWith (<) xs (drop 1 xs))

-- | Multiplication of 'Eta' monomials. The @eta@ index lists are concatenated
-- and sorted, and the inner parts are multiplied with '<>'. A repeated
-- @eta@ index is reported with 'error'.
instance Monoid ab => Monoid (Eta ab) where
  mempty = Eta [] mempty
  (Eta fs ab1) `mappend` (Eta gs ab2)
    | strictlyIncreasing hs = Eta hs (ab1 <> ab2)
    | otherwise             = error "Eta/monoid: duplicate indices"
    where
      hs = sort (fs ++ gs)
      -- A sorted list is duplicate-free iff consecutive elements strictly
      -- increase. This is O(n), unlike the O(n^2) test @nub hs == hs@.
      strictlyIncreasing xs = and (zipWith (<) xs (drop 1 xs))

-- | Multiplication of 'Gam' monomials: @gamma@ exponents add, and inner parts
-- multiply with '<>'.
instance Monoid ab => Monoid (Gam ab) where
  mempty = Gam 0 mempty
  mappend (Gam k1 x1) (Gam k2 x2) = Gam (k1 + k2) (x1 <> x2)

--------------------------------------------------------------------------------

-- | A zero power of @gamma@ is omitted. A trivial ('mempty') inner part is
-- omitted. Otherwise the two pieces are joined by @*@.
instance (Pretty ab, Monoid ab, Eq ab) => Pretty (Gam ab) where
  pretty (Gam g ab)
    | g == 0       = pretty ab
    | ab == mempty = power
    | otherwise    = power ++ "*" ++ pretty ab
    where
      power = "g^" ++ show g

-- | The @eta@ factors are printed as @h1*h2*...@. An empty factor list is
-- omitted, and so is a trivial ('mempty') inner part.
instance (Pretty ab, Monoid ab, Eq ab) => Pretty (Eta ab) where
  pretty (Eta [] ab) = pretty ab
  pretty (Eta is ab)
    | ab == mempty = etas
    | otherwise    = etas ++ "*" ++ pretty ab
    where
      etas = intercalate "*" (map (\i -> 'h' : show i) is)

-- | The @u@ factors are printed as @u1*u2*...@. An empty factor list is
-- omitted, and so is a trivial ('mempty') inner part.
instance (Pretty ab, Monoid ab, Eq ab) => Pretty (Omega ab) where
  pretty (Omega [] ab) = pretty ab
  pretty (Omega is ab)
    | ab == mempty = factors
    | otherwise    = factors ++ "*" ++ pretty ab
    where
      factors = intercalate "*" (map (\i -> 'u' : show i) is)

--------------------------------------------------------------------------------