module Neet.Population (
                         Population(..)
                       , SpecId(..)
                         
                       , PopM
                       , PopContext
                       , runPopM
                         
                       , PopSettings(..)
                       , newPop
                         
                       , trainOnce
                       , TrainMethod(..)
                         
                       , pureTrain
                       , winTrain
                         
                       , trainN
                       , trainUntil
                       , trainPure
                         
                       , speciesCount
                         
                       , validatePopulation
                       ) where
import Neet.Species
import Neet.Genome
import Neet.Parameters
import Data.MultiMap (MultiMap)
import qualified Data.MultiMap as MM
import Data.Map (Map)
import qualified Data.Map as M
import Data.List (foldl', maximumBy, sortBy)
import Data.Maybe
import Data.Monoid
import Data.Functor.Identity
import Control.Monad.Random
import Control.Monad.Fresh.Class
import Control.Monad.Trans.State
import Control.Applicative
import Control.Monad
import Data.Traversable
import Control.Parallel.Strategies
import Data.Function
-- | Identifier of a species within a 'Population'.
newtype SpecId = SpecId Int deriving (Eq, Ord, Show)
-- | The state of an evolving population.
--
-- Every field is strict. 'popGen', 'popPhase' and 'popParams' are rebuilt
-- from the previous generation's population on each training step. If they
-- were lazy, their unevaluated thunks would keep every earlier population
-- (and all of its genomes) reachable until the field was finally forced.
data Population =
  Population { popSize   :: !Int                  -- ^ Nominal number of genomes
             , popSpecs  :: !(Map SpecId Species) -- ^ Live species, by id
             , popBScore :: !Double               -- ^ Best score seen so far
             , popBOrg   :: !Genome               -- ^ Genome that achieved 'popBScore'
             , popBSpec  :: !SpecId               -- ^ Species credited with 'popBScore'
             , popCont   :: !PopContext           -- ^ Innovation counter and random generator
             , nextSpec  :: !SpecId               -- ^ Id for the next newly created species
             , popParams :: !Parameters           -- ^ Current parameters (speciation threshold may drift)
             , popStrat  :: !SearchStrat          -- ^ Search strategy in use
             , popPhase  :: !PhaseState           -- ^ Current phase of the search strategy
             , popGen    :: !Int                  -- ^ Generation number, starting at 1
             }
  deriving (Show)
-- | A way to score a whole generation at once.
--
-- Given any traversable container of genomes, it produces one fitness per
-- genome in the same shape, inside the applicative @f@ ('Identity' for
-- pure scoring).
newtype TrainMethod f = TrainMethod
  { tmGen :: forall t. Traversable t => t Genome -> f (t Double)
  }
-- | Score genomes purely: each genome is scored with 'gScorer', and that
-- score is turned into a fitness with 'fitnessFunction'.
pureTrain :: GenScorer a -> TrainMethod Identity
pureTrain gs = TrainMethod (\genomes -> Identity (fmap scoreOne genomes))
  where scoreOne = fitnessFunction gs . gScorer gs
-- | Like 'pureTrain', but also reports a winning genome.
--
-- Any genome whose score satisfies 'winCriteria' is tagged. The 'First'
-- monoid keeps the first such genome in traversal order.
winTrain :: GenScorer a -> TrainMethod ((,) (First Genome))
winTrain gs = TrainMethod (traverse judge)
  where judge genome =
          let score   = gScorer gs genome
              fitness = fitnessFunction gs score
          in if winCriteria gs score
               then (First (Just genome), fitness)
               else (First Nothing, fitness)
-- | State threaded through 'PopM': the next innovation id to hand out and
-- the random generator.
data PopContext = PC
  { nextInno :: InnoId
  , randGen  :: StdGen
  } deriving (Show)
-- | Population monad: 'State' over a 'PopContext'. It provides random
-- numbers ('MonadRandom') and fresh innovation ids ('MonadFresh').
newtype PopM a = PopM (State PopContext a)
  deriving (Functor, Applicative, Monad)
-- | Apply one pure generator step to the context's 'StdGen' and store the
-- advanced generator back in the context.
popRandStep :: (StdGen -> (a, StdGen)) -> PopM a
popRandStep step = PopM . state $ \ctx ->
  let (value, gen') = step (randGen ctx)
  in (value, ctx { randGen = gen' })

-- | Split the context's generator. One half is handed to @use@, which can
-- consume it without bound (e.g. for infinite streams); the other half is
-- kept in the context.
popRandSplit :: (StdGen -> a) -> PopM a
popRandSplit use = popRandStep $ \gen ->
  let (forUse, kept) = split gen
  in (use forUse, kept)

-- | Random numbers drawn from the 'StdGen' stored in 'PopContext'.
instance MonadRandom PopM where
  getRandom         = popRandStep random
  getRandomR range  = popRandStep (randomR range)
  getRandoms        = popRandSplit randoms
  getRandomRs range = popRandSplit (randomRs range)
-- | Each call returns the context's next innovation id and advances the
-- counter by one.
instance MonadFresh InnoId PopM where
  fresh = PopM (state takeNext)
    where takeNext ctx =
            let current@(InnoId n) = nextInno ctx
            in (current, ctx { nextInno = InnoId (n + 1) })
-- | Run a 'PopM' computation from the given context. Returns the result
-- together with the final context.
runPopM :: PopM a -> PopContext -> (a, PopContext)
runPopM (PopM action) ctx = runState action ctx
-- | Settings for building an initial population with 'newPop'.
data PopSettings = PS
  { psSize     :: Int               -- ^ Number of genomes to generate
  , psInputs   :: Int               -- ^ Input count passed to the genome generator
  , psOutputs  :: Int               -- ^ Output count passed to the genome generator
  , psParams   :: Parameters        -- ^ Parameters for the population
  , sparse     :: Maybe Int         -- ^ 'Nothing' builds genomes with 'fullConn';
                                    --   @Just n@ uses 'sparseConn' with @n@
  , psStrategy :: Maybe PhaseParams -- ^ 'Nothing' to always complexify;
                                    --   @Just pp@ for a phased search
  } deriving (Show)
-- | Small state monad that hands out fresh 'SpecId's during speciation.
newtype SpecM a = SM (State SpecId a)
  deriving (Functor, Applicative, Monad)
-- | Each call returns the current species id and advances the counter by one.
instance MonadFresh SpecId SpecM where
  fresh = SM (state issue)
    where issue current@(SpecId n) = (current, SpecId (n + 1))
-- | Run a 'SpecM' computation from the given starting id. Returns the
-- result and the next unused id.
runSpecM :: SpecM a -> SpecId -> (a, SpecId)
runSpecM (SM action) start = runState action start
-- | A speciation bucket. It holds the species id, the representative genome
-- that newcomers are compared against, and the genomes assigned so far.
data SpecBucket = SB SpecId Genome [Genome]
-- | Assign genomes to buckets, one at a time and in order.
--
-- Each genome joins the first bucket whose representative is within the
-- speciation threshold ('delta_t'), measured with 'distance'. A genome that
-- fits no bucket opens a new bucket at the end of the list. The new bucket
-- gets a fresh id, and the genome becomes both its representative and its
-- first member.
shuttleOrgs :: MonadFresh SpecId m =>
               Parameters -> [SpecBucket] -> [Genome] -> m [SpecBucket]
shuttleOrgs p@Parameters{specParams = sParams} buckets = foldM place buckets
  where threshold = delta_t (distParams sParams)
        closeTo g (SB _ rep _) = distance p g rep <= threshold
        place :: MonadFresh SpecId m => [SpecBucket] -> Genome -> m [SpecBucket]
        place bs g = case break (closeTo g) bs of
          (before, SB sId rep members : after) ->
            return (before ++ SB sId rep (g : members) : after)
          (_, []) -> do
            newId <- fresh
            return (bs ++ [SB newId g [g]])
-- | Zip two lists with a combining function that may discard a pair.
--
-- When the shorter list runs out, the leftover elements of the longer list
-- go through the matching default function instead of being dropped.
-- Every 'Nothing' result is skipped.
zipWithDefaults :: (a -> b -> Maybe c) -> (a -> Maybe c) -> (b -> Maybe c) -> [a] -> [b] -> [c]
zipWithDefaults both onlyLeft onlyRight = go
  where go [] ys = mapMaybe onlyRight ys
        go xs [] = mapMaybe onlyLeft xs
        go (x:xs) (y:ys) = maybe id (:) (both x y) (go xs ys)
-- | Re-speciate a batch of genomes against the existing species.
--
-- 1. Each existing species contributes an empty bucket, keyed by its id and
--    represented by its first genome.
-- 2. The genomes are distributed over the buckets with 'shuttleOrgs'.
-- 3. Existing species keep their id and their score/improvement fields but
--    take the newly assigned members. Species that received no members are
--    dropped.
-- 4. Buckets created during shuttling become new species via 'newSpec'.
--
-- Calls 'error' if an existing species has no genomes.
speciate :: MonadFresh SpecId m =>
            Parameters -> Map SpecId Species -> [Genome] -> m (Map SpecId Species)
speciate params specs gens = do
  filled <- shuttleOrgs params (map seedBucket existing) gens
  let survivors = zipWithDefaults refill (const Nothing) promote (map snd existing) filled
  return (M.fromList survivors)
  where existing = M.toList specs
        seedBucket (sId, Species _ (rep:_) _ _) = SB sId rep []
        seedBucket _                            = error "(speciate) Empty species!"
        refill (Species _ _ scr imp) (SB sId _ members) = case members of
          [] -> Nothing
          _  -> Just (sId, Species (length members) members scr imp)
        promote (SB sId _ members) = case members of
          []     -> Nothing
          (g:gs) -> Just (sId, newSpec g gs)
-- | Build the first generation of a population.
--
-- Starting state:
--
-- * The random generator is seeded with @seed@, and the innovation counter
--   starts at @psInputs * psOutputs + 2@.
-- * @psSize@ genomes are generated, with 'fullConn' when 'sparse' is
--   'Nothing' and 'sparseConn' otherwise.
-- * The genomes are speciated from scratch, numbering species from
--   @SpecId 1@.
-- * The best score starts at 0, with the first generated genome as the
--   best organism. The generation counter starts at 1.
-- * With a phased strategy, the first complexity threshold is the initial
--   average complexity plus 'phaseAddAmount'.
--
-- NOTE(review): with @psSize <= 0@ no genomes are generated, and building
-- the population fails in 'head'.
newPop :: Int -> PopSettings -> Population
newPop seed PS{..} = fst (runPopM build startCtx)
  where startCtx = PC (InnoId (psInputs * psOutputs + 2)) (mkStdGen seed)
        mkGenome = case sparse of
                     Nothing       -> fullConn (mutParams psParams)
                     Just conCount -> sparseConn (mutParams psParams) conCount
        build = do
          genomes <- replicateM psSize (mkGenome psInputs psOutputs)
          let (specMap, specNext) =
                runSpecM (speciate psParams M.empty genomes) (SpecId 1)
              totalComp = foldl' (+) 0 (map genomeComplexity genomes)
              meanComp = fromIntegral totalComp / fromIntegral psSize
              (strat, phase) = case psStrategy of
                Nothing -> (Complexify, Complexifying 0)
                Just pp@PhaseParams{} ->
                  (Phased pp, Complexifying (phaseAddAmount pp + meanComp))
          ctx <- PopM get
          return Population { popSize   = psSize
                            , popSpecs  = specMap
                            , popBScore = 0
                            , popBOrg   = head genomes
                            , popBSpec  = SpecId 1
                            , popCont   = ctx
                            , nextSpec  = specNext
                            , popParams = psParams
                            , popStrat  = strat
                            , popPhase  = phase
                            , popGen    = 1
                            }
-- | Run one generation. Every species is scored with the given
-- 'TrainMethod' (in its applicative), then the next population is bred
-- from the results.
trainOnce :: Applicative f => TrainMethod f -> Population -> f Population
trainOnce method pop =
  (`trainOnceWFits` pop) <$> traverse (runFitTestWStrategy (tmGen method)) (popSpecs pop)
-- | Advance the population by one generation, given per-species test results.
--
-- 1. Choose the next search phase from the current phase, the search
--    strategy and the population's current average complexity.
-- 2. Pair each species with its 'TestResult', forcing each result to WHNF
--    in parallel, and fold the new scores into the species. Species whose
--    'lastImprovement' has reached 'dropTime' are removed, except the
--    species recorded in 'popBSpec'.
-- 3. Divide 'popSize' offspring slots among the surviving species in
--    proportion to their adjusted fitness. Slots lost to rounding go to the
--    best-scoring species first.
-- 4. Produce offspring. While pruning, only subtractive mutation is used.
--    Otherwise each child comes from additive mutation (probability
--    'noCrossover') or from crossover of two parents, with the fitter parent
--    passed first to 'breed'.
-- 5. Re-speciate the offspring against the surviving species. If a
--    species-count target is configured, adjust the speciation threshold.
--    Record a new best organism if this generation beat 'popBScore'.
trainOnceWFits :: Map SpecId TestResult -> Population -> Population
trainOnceWFits tResults pop = generated
  where params = popParams pop
        mParams = mutParams params
        mParamsS = mutParamsS params
        avgComp = avgComplexity pop

        -- Phase for the next generation.
        -- * 'Complexify' never leaves the complexifying phase.
        -- * 'Phased' complexifies until the average complexity reaches the
        --   threshold, then prunes.
        -- * Once pruning has gone 'phaseWaitTime' checks without the
        --   average complexity falling, it complexifies again with the
        --   threshold raised by 'phaseAddAmount'.
        newPhase = case (popPhase pop, popStrat pop) of
          (_, Complexify) -> Complexifying 0
          (Complexifying thresh, Phased PhaseParams{..})
            | avgComp < thresh -> Complexifying thresh
            | otherwise -> Pruning 0 avgComp
          (Pruning lastFall lastComp, Phased PhaseParams{..})
            | avgComp < lastComp -> Pruning 0 avgComp
            | lastFall >= phaseWaitTime -> Complexifying (avgComp + phaseAddAmount)
            | otherwise -> Pruning (lastFall + 1) avgComp

        isPruning = case newPhase of
          Pruning _ _ -> True
          _ -> False

        -- Species with at least 'largeSize' members mutate with 'mutParams';
        -- smaller ones use 'mutParamsS'.
        chooseParams :: Species -> MutParams
        chooseParams s = if specSize s >= largeSize params then mParams else mParamsS

        initSpecs = popSpecs pop

        -- Force each species' test result to WHNF (applied in parallel below).
        oneEval :: Strategy (Species, TestResult)
        oneEval = evalTuple2 r0 rseq

        -- Species paired with their test results. Ids missing from either
        -- map are dropped by the intersection.
        fits = M.intersectionWith (,) initSpecs tResults `using` parTraversable oneEval

        -- Fold the new scores into a species, or drop it once its
        -- 'lastImprovement' has reached 'dropTime'. The species recorded in
        -- 'popBSpec' is never dropped.
        eugenics :: SpecId -> (Species, TestResult) ->
                    Maybe (Species, MultiMap Double Genome, Double)
        eugenics sId (sp, TR{..})
          | maybe False (lastImprovement nSpec >=) (dropTime params)
            && sId /= popBSpec pop = Nothing
          | otherwise = Just (nSpec, trScores, trAdj)
          where nSpec = updateSpec trSS sp

        -- Surviving species with their per-genome scores and adjusted fitness.
        masterRace :: Map SpecId (Species, MultiMap Double Genome, Double)
        masterRace = M.mapMaybeWithKey eugenics fits

        masterList :: [(SpecId,(Species, MultiMap Double Genome, Double))]
        masterList = M.toList masterRace

        -- NOTE(review): 'maximumBy' fails if every species was dropped. That
        -- can happen when 'popBSpec' names a species that no longer exists.
        idVeryBest :: (SpecId, Species)
        idVeryBest = maximumBy (compare `on` (bestScore . specScore . snd)) $ map clean masterList
          where clean (sId,(sp, _, _)) = (sId,sp)
        veryBest = snd idVeryBest
        bestId = fst idVeryBest

        masterSpec :: Map SpecId Species
        masterSpec = M.map (\(s,_,_) -> s) masterRace
        totalFitness = M.foldl' (+) 0 . M.map (\(_,_,x) -> x) $ masterRace
        totalSize = popSize pop
        dubSize = fromIntegral totalSize

        -- One entry per surviving species, best-scoring species first. Each
        -- entry holds the mutation parameters, the number of offspring and a
        -- parent sampler.
        candSpecs :: MonadRandom m => [(MutParams, Int, m (Double,Genome))]
        candSpecs = zip3 ps realShares pickers
          where sortedMaster = sortBy revComp masterList

                revComp (_,(sp1,_,_)) (_,(sp2,_,_)) = (compare `on` (bestScore . specScore)) sp2 sp1
                -- Each species' share of the offspring is proportional to its
                -- adjusted fitness. The fraction left over after flooring is
                -- carried into the next species' share.
                -- NOTE(review): a zero 'totalFitness' makes these shares NaN.
                initShares = snd $ mapAccumL share 0 sortedMaster
                share skim (_,(_, _, adj)) = (newSkim, actualShare)
                  where everything = adj / totalFitness * dubSize + skim
                        actualShare = floor everything
                        newSkim = everything - fromIntegral actualShare
                remaining = totalSize - foldl' (+) 0 initShares
                -- Give the slots lost to rounding to the best species, one
                -- each. The count is checked before the list, so giving the
                -- last slot to the last species does not hit the empty-list
                -- error.
                distributeRem 0 l = l
                distributeRem n _
                  | n < 0 = error "Remainder should be positive"
                distributeRem _ [] = error "Should run out of numbers first"
                distributeRem n (x:xs) = x + 1 : distributeRem (n - 1) xs
                realShares = distributeRem remaining initShares
                -- Parents are drawn uniformly from the top fifth (plus one)
                -- of each species, ranked by score.
                pickers :: MonadRandom m => [m (Double, Genome)]
                pickers = map picker sortedMaster
                  where picker (_,(s, mmap, _)) =
                          let numToTake = specSize s `div` 5 + 1
                              desc = M.toDescList $ MM.toMap mmap
                              toPairs (k, vs) = map (\v -> (k,v)) vs
                              culled = take numToTake $ desc >>= toPairs
                          in uniform culled
                ps = map (\(_,(s,_,_)) -> chooseParams s) sortedMaster

        -- Apply a monadic step n times, forcing the accumulator each time.
        applyN :: Monad m => Int -> (a -> m a) -> a -> m a
        applyN 0 _  x = return x
        applyN n h !x = h x >>= applyN (n - 1) h

        -- Produce n offspring for one species. The innovation map is
        -- threaded through so that matching structural mutations share
        -- innovation ids.
        specGens :: (MonadFresh InnoId m, MonadRandom m) =>
                    Map ConnSig InnoId -> (MutParams, Int, m (Double, Genome)) ->
                    m (Map ConnSig InnoId, [Genome])
        specGens inns (p, n, gen) = applyN n genOne (inns, [])
          where genOne (innos, gs)
                  | isPruning = do
                      (_,parent) <- gen
                      g <- mutateSub p parent
                      return (innos, g:gs)
                  | otherwise = do
                      roll <- getRandomR (0,1)
                      if roll <= noCrossover p
                        then do
                          (_,parent) <- gen
                          (innos', g) <- mutateAdd p innos parent
                          return (innos', g:gs)
                        else do
                          (fit1, mom) <- gen
                          (fit2, dad) <- gen
                          (innos', g) <- if fit1 > fit2
                                         then breed p innos mom dad
                                         else breed p innos dad mom
                          return (innos', g:gs)

        -- All offspring of this generation. The innovation map starts empty
        -- and is shared across species.
        allGens :: (MonadRandom m, MonadFresh InnoId m) => m [Genome]
        allGens = liftM (concat . snd) $ foldM ag' (M.empty, []) candSpecs
          where ag' (innos, cands) cand = do
                  (innos', specGen) <- specGens innos cand
                  return $ (innos', specGen:cands)

        genNewSpecies :: (MonadRandom m, MonadFresh InnoId m) => m (Map SpecId Species, SpecId)
        genNewSpecies = do
          gens <- allGens
          return $ runSpecM (speciate params masterSpec gens) (nextSpec pop)

        generated :: Population
        generated = fst $ runPopM generate (popCont pop)

        generate :: PopM Population
        generate = do
          (specs, nextSpec') <- genNewSpecies
          let specCount = M.size specs
              -- With a species-count target, raise the distance threshold
              -- when there are too many species (more genomes fit existing
              -- species) and lower it when there are too few.
              newParams :: Parameters
              newParams =
                case specParams params of
                 Simple _ -> params
                 Target dp st@SpeciesTarget{..}
                   | specCount > snd targetCount ->
                       let newDP = dp { delta_t = delta_t dp + adjustAmount }
                       in params { specParams = Target newDP st }
                   | specCount < fst targetCount ->
                       let newDP = dp { delta_t = delta_t dp - adjustAmount }
                       in params { specParams = Target newDP st }
                   | otherwise -> params

              bScoreNow = (bestScore . specScore) veryBest
              bOrgNow = (bestGen . specScore) veryBest
              bSpecNow = bestId
              (bScore, bOrg, bSpec) =
                if bScoreNow > popBScore pop
                then (bScoreNow, bOrgNow, bSpecNow)
                else (popBScore pop, popBOrg pop, popBSpec pop)
          cont' <- PopM get
          return pop { popSpecs = specs
                     , popBScore = bScore
                     , popBOrg = bOrg
                     , popBSpec = bSpec
                     , popCont = cont'
                     , popParams = newParams
                     , nextSpec = nextSpec'
                     , popGen = popGen pop + 1
                     , popPhase = newPhase
                     }
-- | Run @n@ generations of 'trainOnce' with the given method, sequenced
-- in @f@. Returns the population unchanged when @n <= 0@.
trainN :: (Applicative f, Monad f) =>
          TrainMethod f -> Int -> Population -> f Population
trainN tm n p
  | n <= 0    = return p
  | otherwise = let start = return p
                in start `seq` (iterate (>>= trainOnce tm) start !! n)
-- | Run a single generation with pure scoring ('pureTrain').
trainPure :: GenScorer a -> Population -> Population
trainPure gs = runIdentity . trainOnce (pureTrain gs)
-- | Train with 'winTrain' until a genome meets the scorer's win criteria or
-- @n@ generations have run, whichever comes first.
--
-- Returns the latest population. On success it also returns the winning
-- genome and the number of generations completed before the one in which
-- the winner was found (0 if it won in the first evaluation). When
-- @n <= 0@, the input population is returned unchanged with 'Nothing'.
trainUntil :: Int -> GenScorer a -> Population -> (Population, Maybe (Genome, Int))
trainUntil n f p
  | n <= 0 = (p, Nothing)
  | otherwise = go n p
  where -- n' counts the generations still allowed; the bang keeps each
        -- intermediate population evaluated.
        go 0  !p' = (p', Nothing)
        go n' !p' = case trainOnce (winTrain f) p' of
                     (First Nothing, p'') -> go (n' - 1) p''
                     (First (Just g), p'') -> (p'', Just (g, n - n'))
-- | Number of live species in the population.
speciesCount :: Population -> Int
speciesCount = M.size . popSpecs
-- | Average genome complexity: total species complexity divided by the
-- nominal population size ('popSize').
avgComplexity :: Population -> Double
avgComplexity pop = total / count
  where total = fromIntegral (totalComplexityMap (popSpecs pop))
        count = fromIntegral (popSize pop)
-- | Sum of 'speciesComplexity' over every species in the map.
totalComplexityMap :: Map SpecId Species -> Int
totalComplexityMap = M.foldl' (\acc sp -> acc + speciesComplexity sp) 0
-- | Check a population's internal invariants.
--
-- Returns 'Nothing' when everything is consistent. Otherwise returns
-- 'Just' the list of problems found. The checks are:
--
-- * there is at least one species, and 'nextSpec' is greater than every
--   existing species id;
-- * the species sizes add up to 'popSize';
-- * every species passes 'validateSpecies'.
validatePopulation :: Population -> Maybe [String]
validatePopulation Population{..} = case errRes of
                                     [] -> Nothing
                                     xs -> Just xs
  where totalSSize = M.foldl' (\acc x -> specSize x + acc) 0 popSpecs
        goodSize
          | totalSSize == popSize = []
          | otherwise = ["Population size differs from actual size"]
        -- An empty species map is reported on its own, so the SpecId
        -- message only appears when an existing id is not below 'nextSpec'.
        goodSId = case M.maxViewWithKey popSpecs of
          Nothing -> ["Population has no species"]
          Just ((maxId, _), _)
            | maxId < nextSpec -> []
            | otherwise -> ["SpecId lower than extant species"]
        specErrs = concat . M.elems $ M.mapMaybe validateSpecies popSpecs
        errRes = goodSId  ++ goodSize ++ specErrs