module Neet.Genome ( 
                     NodeId(..)
                   , NodeType(..)
                   , NodeGene(..)
                   , ConnGene(..)
                   , InnoId(..)
                   , ConnSig
                     
                   , Genome(..)
                     
                   , fullConn
                     
                   , mutate
                   , crossover
                   , breed
                     
                   , distance
                     
                   , renderGenome
                   ) where
import Control.Applicative
import Control.Monad
import Control.Monad.Random
import Data.Map.Strict (Map)
import qualified Data.Traversable as T
import qualified Data.Map.Strict as M
import Data.Set (Set)
import qualified Data.Set as S
import Data.Maybe
import Control.Monad.Fresh.Class
import Neet.Parameters
import Data.GraphViz
import Data.GraphViz.Attributes.Complete
-- | Identifier of a node within a genome. 'PrintDot' is derived so node ids
-- can be used directly as Graphviz node identifiers (see 'renderGenome').
newtype NodeId = NodeId Int
               deriving (Show, Eq, Ord, PrintDot)
-- | Role of a node in the network. 'Input' and 'Output' nodes are created by
-- 'fullConn'. 'Hidden' nodes are created by 'mutateNode' when it splits a
-- connection. 'mutateConn' never picks an 'Input' node as a connection target.
data NodeType = Input | Hidden | Output
              deriving (Show, Eq)
-- | A node gene: the node's role plus a layering hint.
data NodeGene = NodeGene { nodeType :: NodeType
                           -- ^ Input, hidden or output.
                         , yHint :: Rational
                           -- ^ Vertical position hint. In this module, inputs
                           -- get 0 and outputs 1 ('fullConn'). A hidden node
                           -- gets the midpoint of the endpoints of the
                           -- connection it splits ('mutateNode'). It is used
                           -- to flag backward links ('mutateConn') and to
                           -- group nodes into layers when rendering.
                         }
              deriving (Show)
-- | A connection gene. Genes are stored in 'connGenes' keyed by their
-- innovation id.
data ConnGene = ConnGene { connIn :: NodeId
                           -- ^ Source node.
                         , connOut :: NodeId
                           -- ^ Target node.
                         , connWeight :: Double
                           -- ^ Connection weight.
                         , connEnabled :: Bool
                           -- ^ Disabled genes stay in the genome (they still
                           -- count as "taken" in 'mutateConn' and take part in
                           -- crossover and 'distance') but are not drawn by
                           -- 'renderGenome'. Presumably they are also ignored
                           -- when the network is evaluated; confirm in the
                           -- evaluator.
                         , connRec :: Bool
                           -- ^ Set by 'mutateConn' when the target's 'yHint'
                           -- is not greater than the source's, i.e. the link
                           -- does not point "forward". Presumably this marks a
                           -- recurrent connection; confirm against the
                           -- evaluator.
                         }
              deriving (Show)
-- | Innovation number of a connection gene. Genes with equal innovation ids
-- in two genomes are treated as matching by crossover and 'distance'.
newtype InnoId = InnoId Int
               deriving (Show, Eq, Ord)
-- | A NEAT genome: node genes, connection genes, and the id to use for the
-- next hidden node.
data Genome =
  Genome { nodeGenes :: Map NodeId NodeGene
           -- ^ All nodes, keyed by id.
         , connGenes :: Map InnoId ConnGene
           -- ^ All connection genes (enabled or not), keyed by innovation id.
         , nextNode :: NodeId
           -- ^ Id given to the next hidden node added by 'mutateNode'.
         }
  deriving (Show)
-- | Build an initial genome in which every input node (plus one extra input,
-- presumably a bias node; confirm against the evaluator) is connected to
-- every output node. Weights are drawn uniformly from
-- @[-weightRange, weightRange]@.
--
-- Node ids are assigned as follows:
--
-- * inputs: @1 .. iSize + 1@
-- * outputs: @iSize + 2 .. iSize + 1 + oSize@
--
-- Connection innovation ids are numbered from 1 in input-major order. No
-- innovation map is produced here; callers presumably seed their
-- @Map ConnSig InnoId@ and fresh-id counter accordingly (verify).
fullConn :: MonadRandom m => Parameters -> Int -> Int -> m Genome
fullConn Parameters{..} iSize oSize = do
  let inCount = iSize + 1
      inIDs = map NodeId [1 .. inCount]
      outIDs = map NodeId [inCount + 1 .. inCount + oSize]
      inputGenes = zip inIDs $ repeat (NodeGene Input 0)
      outputGenes = zip outIDs $ repeat (NodeGene Output 1)
      nodeGenes = M.fromList $ inputGenes ++ outputGenes
      nextNode = NodeId $ inCount + oSize + 1
      nodePairs = (,) <$> inIDs <*> outIDs
      mkConn (inN, outN) w = ConnGene inN outN w True False
  -- Symmetric range: the previous (weightRange, weightRange) made every
  -- initial weight the same constant.
  weights <- getRandomRs (-weightRange, weightRange)
  let connGenes = M.fromList $ zip (map InnoId [1 ..]) (zipWith mkConn nodePairs weights)
  return Genome{..}
-- | With probability 'mutWeightRate', mutate every connection weight of the
-- genome; otherwise return it unchanged.
--
-- Each mutated weight is handled in one of two ways:
--
-- * with probability 'newWeightRate', replaced by a fresh value drawn
--   uniformly from @[-weightRange, weightRange]@;
-- * otherwise, perturbed by an amount drawn uniformly from
--   @[-pertAmount, pertAmount]@.
mutateWeights :: MonadRandom m => Parameters -> Genome -> m Genome
mutateWeights Parameters{..} g@Genome{..} = do
  roll <- getRandomR (0,1)
  if roll > mutWeightRate
    then return g
    else setConns `liftM` T.mapM mutOne connGenes
  where setConns cs = g { connGenes = cs }
        -- Both ranges are symmetric. The previous (x, x) ranges always produced
        -- the constant x, so perturbations only ever increased weights.
        mutOne conn = do
          roll <- getRandomR (0,1)
          w <- if roll <= newWeightRate
                 then getRandomR (-weightRange, weightRange)
                 else do
                   pert <- getRandomR (-pertAmount, pertAmount)
                   return $ connWeight conn + pert
          return conn { connWeight = w }
-- | Structural signature of a connection: (source node, target node). It is
-- the key of the shared innovation map, so the same structural mutation in
-- different genomes receives the same 'InnoId' (see 'addConn').
data ConnSig = ConnSig NodeId NodeId
             deriving (Show, Eq, Ord)
-- | The (source, target) signature of a connection gene.
toConnSig :: ConnGene -> ConnSig
toConnSig = ConnSig <$> connIn <*> connOut
-- | Insert a connection gene into a genome's connection map, assigning it an
-- innovation id.
--
-- If the shared innovation map already knows the gene's (source, target)
-- signature, that id is reused and the innovation map is returned unchanged.
-- Otherwise a new id is drawn with 'fresh' and recorded in the innovation map.
-- An existing gene under the chosen id is overwritten.
addConn :: MonadFresh InnoId m => ConnGene ->
           (Map ConnSig InnoId, Map InnoId ConnGene) ->
           m (Map ConnSig InnoId, Map InnoId ConnGene)
addConn conn (innos, conns) = maybe assignNew reuse (M.lookup sig innos)
  where sig = toConnSig conn
        reuse known = return (innos, M.insert known conn conns)
        assignNew = do
          inno <- fresh
          return (M.insert sig inno innos, M.insert inno conn conns)
-- | With probability 'addConnRate', add one new enabled connection between a
-- uniformly chosen pair of nodes that has no connection gene yet (enabled or
-- disabled).
--
-- Pair selection:
--
-- * The source may be any node; the target may be any non-input node
--   (self-loops included).
-- * The new gene's 'connRec' flag is set when the target's 'yHint' is not
--   greater than the source's.
--
-- The weight is drawn uniformly from @[-weightRange, weightRange]@. The
-- innovation id comes from 'addConn'.
--
-- Returns the possibly-extended innovation map and genome. Both are unchanged
-- when the roll fails or every candidate pair is already taken.
mutateConn :: (MonadFresh InnoId m, MonadRandom m) =>
              Parameters -> Map ConnSig InnoId -> Genome -> m (Map ConnSig InnoId, Genome)
mutateConn params innos g = do
  roll <- getRandomR (0,1)
  if roll > addConnRate params || null allowed
    then return (innos, g)
    else do
      (innos', conns') <- addRandConn
      return (innos', g { connGenes = conns' })
  where
        -- Every (source, target) pair that already has a gene, enabled or not.
        taken :: Map ConnSig Bool
        taken = M.fromList . map (\c -> (toConnSig c, True)) . M.elems . connGenes $ g

        notInput (NodeGene Input _) = False
        notInput _                  = True

        nodes = M.toList $ nodeGenes g
        nonInputs = filter (notInput . snd) nodes

        -- The Bool becomes the new gene's connRec flag: the target does not
        -- sit strictly "after" the source by yHint.
        makePair (n1, g1) (n2, g2) = (ConnSig n1 n2, yHint g2 <= yHint g1)

        candidates = M.fromList $ makePair <$> nodes <*> nonInputs
        allowed = M.toList $ M.difference candidates taken

        -- Only reached when 'allowed' is non-empty ('uniform' is partial).
        pickOne :: MonadRandom m => m (ConnSig, Bool)
        pickOne = uniform allowed

        -- Symmetric range; the previous (r, r) always yielded the constant r.
        pickWeight :: MonadRandom m => m Double
        pickWeight = let r = weightRange params in getRandomR (-r, r)

        addRandConn :: (MonadRandom m, MonadFresh InnoId m) =>
                       m (Map ConnSig InnoId, Map InnoId ConnGene)
        addRandConn = do
          (ConnSig inNode outNode, recc) <- pickOne
          w <- pickWeight
          addConn (ConnGene inNode outNode w True recc) (innos, connGenes g)
-- | With probability 'addNodeRate', split a uniformly chosen connection gene.
-- Nothing happens if the genome has no connections.
--
-- The split works like this:
--
-- * The chosen gene is disabled.
-- * A new 'Hidden' node with id 'nextNode' is added. Its 'yHint' is the
--   midpoint of the split connection's endpoints.
-- * Two enabled genes are added: source -> new node with weight 1, and
--   new node -> target with the old weight. Both inherit the old gene's
--   'connRec' flag, and their innovation ids come from 'addConn'.
--
-- NOTE(review): the chosen gene may already be disabled. The original NEAT
-- algorithm only splits enabled connections; confirm this is intended.
mutateNode :: (MonadRandom m, MonadFresh InnoId m) =>
              Parameters -> Map ConnSig InnoId ->
              Genome -> m (Map ConnSig InnoId, Genome)
mutateNode params innos g = do
  roll <- getRandomR (0,1)
  -- 'uniform' (in pickConn) fails on an empty list, so skip genomes with no
  -- connections instead of crashing.
  if roll <= addNodeRate params && not (M.null conns)
    then addRandNode
    else return (innos, g)
  where conns = connGenes g
        nodes = nodeGenes g

        pickConn :: MonadRandom m => m (InnoId, ConnGene)
        pickConn = uniform $ M.toList conns

        newId = nextNode g
        newNextNode = case newId of NodeId x -> NodeId (x + 1)

        addNode :: MonadFresh InnoId m =>
                   InnoId -> ConnGene -> m (Map ConnSig InnoId, Genome)
        addNode inno gene = do
          let ConnSig inId outId = toConnSig gene
              -- Both endpoints are assumed to exist in nodeGenes (M.! fails
              -- otherwise).
              inGene = nodes M.! inId
              outGene = nodes M.! outId
              newGene = NodeGene Hidden ((yHint inGene + yHint outGene) / 2)
              newNodes = M.insert newId newGene nodes
              disabledConn = gene { connEnabled = False }
              backGene = ConnGene inId newId 1 True (connRec gene)
              forwardGene = ConnGene newId outId (connWeight gene) True (connRec gene)
          (innos', newConns) <-
            addConn backGene >=> addConn forwardGene $ (innos, conns)
          return (innos', g { nodeGenes = newNodes
                            , connGenes = M.insert inno disabledConn newConns
                            , nextNode = newNextNode
                            })

        addRandNode :: (MonadRandom m, MonadFresh InnoId m) => m (Map ConnSig InnoId, Genome)
        addRandNode = pickConn >>= uncurry addNode
-- | One round of mutation, applied in this order:
--
-- 1. weight mutation ('mutateWeights');
-- 2. a possible node insertion ('mutateNode');
-- 3. a possible connection insertion ('mutateConn').
--
-- The innovation map is threaded through the two structural steps and
-- returned with the mutated genome.
mutate :: (MonadRandom m, MonadFresh InnoId m) => Parameters -> Map ConnSig InnoId ->
          Genome -> m (Map ConnSig InnoId, Genome)
mutate params innos g = do
  reweighted <- mutateWeights params g
  (innosAfterNode, withNode) <- mutateNode params innos reweighted
  mutateConn params innosAfterNode withNode
-- | Left-biased merge of two maps:
--
-- * keys in both maps are combined with @comb@;
-- * keys only in the first map are transformed with @mk@;
-- * keys only in the second map are dropped.
superLeft :: Ord k => (a -> b -> c) -> (a -> c) -> Map k a -> Map k b -> Map k c
superLeft comb mk = M.mergeWithKey both (M.map mk) dropRight
  where both _ x y = Just (comb x y)
        dropRight _ = M.empty
-- | Return one of the two arguments, chosen uniformly at random.
flipCoin :: MonadRandom m => a -> a -> m a
flipCoin heads tails = uniform (heads : tails : [])
-- | Cross two connection-gene maps, iterating over the innovation ids of the
-- first map:
--
-- * Matching id in the second map: one of the two genes is chosen with
--   'flipCoin'.
-- * If either matching gene is disabled, a uniform roll in [0, 1] is also
--   made. The chosen gene is disabled when the roll is @<= disableChance@ and
--   enabled otherwise.
-- * Ids only in the first map keep their genes unchanged.
-- * Ids only in the second map are dropped.
--
-- NOTE(review): the first map presumably belongs to the fitter parent, as in
-- NEAT; confirm at call sites.
crossConns :: MonadRandom m => Parameters -> Map InnoId ConnGene -> Map InnoId ConnGene ->
              m (Map InnoId ConnGene)
crossConns params m1 m2 = T.sequence (superLeft inherit return m1 m2)
  where inherit c1 c2
          | connEnabled c1 && connEnabled c2 = flipCoin c1 c2
          | otherwise = do
              chosen <- flipCoin c1 c2
              roll <- getRandomR (0,1)
              return chosen { connEnabled = not (roll <= disableChance params) }
-- | Cross two node-gene maps, iterating over the node ids of the first map:
--
-- * ids present in both maps take either parent's gene, chosen by 'flipCoin';
-- * ids only in the first map keep its gene;
-- * ids only in the second map are dropped.
crossNodes :: MonadRandom m => Map NodeId NodeGene -> Map NodeId NodeGene ->
              m (Map NodeId NodeGene)
crossNodes primary secondary = T.sequence choices
  where choices = superLeft flipCoin return primary secondary
-- | Produce a child genome from two parents. Node genes are crossed first,
-- then connection genes ('crossNodes', 'crossConns'). Only ids present in the
-- first parent survive. The child's 'nextNode' is the larger of the two
-- parents' values.
crossover :: MonadRandom m => Parameters -> Genome -> Genome -> m Genome
crossover params g1 g2 = do
  childNodes <- crossNodes (nodeGenes g1) (nodeGenes g2)
  childConns <- crossConns params (connGenes g1) (connGenes g2)
  return Genome { nodeGenes = childNodes
                , connGenes = childConns
                , nextNode  = max (nextNode g1) (nextNode g2)
                }
-- | Cross two parents ('crossover'), then mutate the child ('mutate'). The
-- updated innovation map is returned with the child.
breed :: (MonadRandom m, MonadFresh InnoId m) =>
         Parameters -> Map ConnSig InnoId -> Genome -> Genome ->
         m (Map ConnSig InnoId, Genome)
breed params innos g1 g2 = mutate params innos =<< crossover params g1 g2
-- | Absolute weight difference for every pair of matching genes, i.e. genes
-- whose innovation id appears in both maps. Unmatched genes are ignored.
differences :: Map InnoId ConnGene -> Map InnoId ConnGene -> Map InnoId Double
differences = M.intersectionWith weightGap
  -- The subtraction was missing before ("connWeight c1  connWeight c2"),
  -- which did not type-check.
  where weightGap c1 c2 = abs (connWeight c1 - connWeight c2)
-- | NEAT compatibility distance between two genomes:
--
-- > c1 * excess + c2 * disjoint + c3 * meanWeightDiff
--
-- where the coefficients come from 'distParams':
--
-- * excess: unmatched innovation ids greater than the smaller of the two
--   genomes' maximum ids;
-- * disjoint: the remaining unmatched ids;
-- * meanWeightDiff: the mean absolute weight difference over matching genes.
--
-- Unlike the original NEAT formula, the counts are not normalised by genome
-- size. The fourth 'DistParams' component is not used here.
distance :: Parameters -> Genome -> Genome -> Double
distance params g1 g2 = c1 * exFactor + c2 * disFactor + c3 * weightFactor
  where DistParams c1 c2 c3 _ = distParams params
        conns1 = connGenes g1
        conns2 = connGenes g2

        -- 0 when no genes match; previously this was 0/0 = NaN.
        weightDiffs = differences conns1 conns2
        weightFactor
          | M.null weightDiffs = 0
          | otherwise = M.foldl' (+) 0 weightDiffs / fromIntegral (M.size weightDiffs)

        ids1 = M.keysSet conns1
        ids2 = M.keysSet conns2
        unmatched = (ids1 `S.difference` ids2) `S.union` (ids2 `S.difference` ids1)

        -- An unmatched id within both genomes' id ranges is disjoint; one
        -- beyond the smaller maximum is excess. If either genome has no
        -- connections (previously a crash in S.findMax), every unmatched id
        -- is excess.
        withinBoth = case (S.maxView ids1, S.maxView ids2) of
          (Just (max1, _), Just (max2, _)) -> (<= min max1 max2)
          _                                -> const False
        -- S.partition returns (satisfying, rest). The ids <= edge are the
        -- disjoint ones; the old binding had the two names swapped.
        (disjoint, excess) = S.partition withinBoth unmatched

        exFactor = fromIntegral $ S.size excess
        disFactor = fromIntegral $ S.size disjoint
-- | Graphviz settings for drawing a genome left-to-right:
--
-- * Nodes are clustered by 'yHint'. Hint 0 is labelled "Input Layer"
--   (blue, minimum rank) and hint 1 "Output Layer" (red, maximum rank);
--   every other hint forms a green cluster labelled with its numeric value.
-- * Edges are labelled with their weight.
graphParams :: GraphvizParams NodeId NodeGene Double Rational Rational
graphParams =
  Params { isDirected       = True
         , globalAttributes = [ GraphAttrs [ RankDir FromLeft, Splines LineEdges ]
                              , NodeAttrs [ FixedSize SetNodeSize ]
                              ]
         , clusterBy        = \(nId, ng) -> let y = yHint ng in C y (N (nId, y))
         , isDotCluster     = const True
         , clusterID        = layerName
         , fmtCluster       = layerStyle
         , fmtNode          = const []
         , fmtEdge          = \(_, _, w) -> [ toLabel w ]
         }
  where
    layerName y = case y of
      0 -> Str "Input Layer"
      1 -> Str "Output Layer"
      _ -> Num (Dbl (fromRational y))

    layerStyle y =
      [ GraphAttrs (whiteAttr : layerRank y)
      , NodeAttrs [ solidAttr, layerColor y, Shape Circle ]
      ]

    layerRank 0 = [ rank MinRank ]
    layerRank 1 = [ rank MaxRank ]
    layerRank _ = []

    layerColor 0 = colorAttr Blue4
    layerColor 1 = colorAttr Red2
    layerColor _ = colorAttr SeaGreen

    whiteAttr = colorAttr White
    solidAttr = Style [ SItem Solid [] ]
    colorAttr c = Color [ WC (X11Color c) Nothing ]
-- | Show a genome in an X11 window (Graphviz 'Xlib' canvas) using the 'Dot'
-- layout and 'graphParams'. Only enabled connections are drawn, each labelled
-- with its weight.
renderGenome :: Genome -> IO ()
renderGenome g = runGraphvizCanvas Dot graph Xlib
  where nodes = M.toList (nodeGenes g)
        edges = mapMaybe mkEdge (M.elems (connGenes g))
        mkEdge ConnGene{..}
          | connEnabled = Just (connIn, connOut, connWeight)
          | otherwise   = Nothing
        graph = graphElemsToDot graphParams nodes edges