module Neet.Population (
                         Population(..)
                       , SpecId(..)
                         
                       , PopM
                       , PopContext
                       , runPopM
                         
                       , PopSettings(..)
                       , newPop
                         
                       , trainOnce
                       , trainN
                       , trainUntil
                         
                       , speciesCount
                         
                       , validatePopulation
                       ) where
import Neet.Species
import Neet.Genome
import Neet.Parameters
import Data.MultiMap (MultiMap)
import qualified Data.MultiMap as MM
import Data.Map (Map)
import qualified Data.Map as M
import Data.List (foldl', maximumBy, sortBy)
import Data.Maybe
import Control.Monad.Random
import Control.Monad.Fresh.Class
import Control.Monad.Trans.State
import Control.Applicative
import Control.Monad
import Control.Parallel.Strategies
import Data.Function
-- | Identifier of a species inside a 'Population'.
newtype SpecId = SpecId Int deriving (Show, Eq, Ord)
-- | A population of genomes grouped into species, plus everything needed to
-- train the next generation.
data Population =
  Population
    { popSize   :: Int                   -- ^ Target number of genomes; 'trainOnce' hands out this many offspring
    , popSpecs  :: !(Map SpecId Species) -- ^ Current species, keyed by id
    , popBScore :: !Double               -- ^ Best score recorded so far (starts at 0)
    , popBOrg   :: !Genome               -- ^ Genome recorded with 'popBScore' (initially the first generated genome)
    , popBSpec  :: !SpecId               -- ^ Species credited with 'popBOrg'; never dropped for stagnation
    , popCont   :: !PopContext           -- ^ Innovation counter and random generator
    , nextSpec  :: !SpecId               -- ^ Id from which new species ids are allocated
    , popParams :: Parameters            -- ^ Evolution parameters
    , popGen    :: Int                   -- ^ Generation number (1 for a fresh population)
    }
  deriving (Show)
-- | State threaded through 'PopM'.
data PopContext =
  PC
    { nextInno :: InnoId  -- ^ Next innovation id handed out by 'fresh'
    , randGen  :: StdGen  -- ^ Generator consumed by the 'MonadRandom' instance
    }
  deriving (Show)
-- | Population-level monad: a 'State' over 'PopContext' that supplies both
-- randomness and fresh innovation ids.
newtype PopM a = PopM (State PopContext a)
  deriving (Functor, Applicative, Monad)
-- | Randomness comes from the 'randGen' field of the context. The stream
-- variants split the generator: one half feeds the infinite list and the
-- other half is stored back.
instance MonadRandom PopM where
  getRandom = PopM $ do
    gen <- gets randGen
    let (value, gen') = random gen
    modify (\ctx -> ctx { randGen = gen' })
    return value
  getRandoms = PopM $ do
    gen <- gets randGen
    let (forStream, kept) = split gen
    modify (\ctx -> ctx { randGen = kept })
    return (randoms forStream)
  getRandomR range = PopM $ do
    gen <- gets randGen
    let (value, gen') = randomR range gen
    modify (\ctx -> ctx { randGen = gen' })
    return value
  getRandomRs range = PopM $ do
    gen <- gets randGen
    let (forStream, kept) = split gen
    modify (\ctx -> ctx { randGen = kept })
    return (randomRs range forStream)
-- | Hand out the stored innovation id and advance the counter by one.
instance MonadFresh InnoId PopM where
  fresh = PopM $ do
    current <- gets nextInno
    let InnoId n = current
    modify (\ctx -> ctx { nextInno = InnoId (n + 1) })
    return current
-- | Run a 'PopM' action from a starting context, returning the result and
-- the final context.
runPopM :: PopM a -> PopContext -> (a, PopContext)
runPopM (PopM action) start = runState action start
-- | Settings used by 'newPop' to build the initial population.
data PopSettings =
  PS
    { psSize    :: Int        -- ^ Number of genomes to generate
    , psInputs  :: Int        -- ^ Input count passed to the genome generators
    , psOutputs :: Int        -- ^ Output count passed to the genome generators
    , psParams  :: Parameters -- ^ Parameters stored in the population
    , sparse    :: Maybe Int  -- ^ 'Nothing': genomes come from 'fullConn';
                              --   @Just k@: from 'sparseConn' with @k@
                              --   (presumably a connection count)
    }
  deriving (Show)
-- | Counter monad used during speciation to hand out fresh 'SpecId's.
newtype SpecM a = SM (State SpecId a)
  deriving (Functor, Applicative, Monad)
-- | Return the current id and bump the counter by one.
instance MonadFresh SpecId SpecM where
  fresh = SM $ do
    current <- get
    modify (\(SpecId n) -> SpecId (n + 1))
    return current
-- | Run a 'SpecM' action starting from the given id, returning the result
-- and the next unused id.
runSpecM :: SpecM a -> SpecId -> (a, SpecId)
runSpecM (SM action) start = runState action start
-- | A species under construction during speciation. It holds the species id,
-- the representative genome that candidates are compared against, and the
-- members collected so far.
data SpecBucket = SB SpecId Genome [Genome]
-- | Sort genomes into buckets, left to right.
--
-- Each genome joins the first bucket whose representative is within
-- 'delta_t' of it (as measured by 'distance'). A genome that matches no
-- bucket opens a new bucket at the end of the list. The new bucket gets a
-- 'fresh' id and uses the genome as both its representative and its only
-- member. Bucket order is otherwise preserved.
shuttleOrgs :: MonadFresh SpecId m =>
               Parameters -> [SpecBucket] -> [Genome] -> m [SpecBucket]
shuttleOrgs p@Parameters{..} buckets = foldM place buckets
  where DistParams{..} = distParams
        closeTo g (SB _ rep _) = distance p g rep <= delta_t
        place :: MonadFresh SpecId m => [SpecBucket] -> Genome -> m [SpecBucket]
        place bs g =
          case break (closeTo g) bs of
            (before, SB sId rep members : after) ->
              return (before ++ SB sId rep (g : members) : after)
            (before, []) -> do
              newId <- fresh
              return (before ++ [SB newId g [g]])
-- | Zip two lists with a partial combiner, keeping only the 'Just' results.
-- Once either list is exhausted, the rest of the other list goes through the
-- matching one-sided function, again keeping only 'Just' results.
zipWithDefaults :: (a -> b -> Maybe c) -> (a -> Maybe c) -> (b -> Maybe c) -> [a] -> [b] -> [c]
zipWithDefaults both onlyLeft onlyRight = go
  where go [] ys = mapMaybe onlyRight ys
        go xs [] = mapMaybe onlyLeft xs
        go (x:xs) (y:ys) = maybe id (:) (both x y) (go xs ys)
-- | Regroup a generation's genomes into species.
--
-- Each existing species seeds a bucket that keeps its id and uses its first
-- genome as the representative. 'shuttleOrgs' then distributes the genomes,
-- opening new buckets with 'fresh' ids as needed.
--
-- * Existing species that receive no genomes disappear.
-- * The other existing species take the new members and keep their last two
--   fields unchanged.
-- * Each new bucket becomes a species via 'newSpec'.
--
-- Calls 'error' if an existing species has no genomes.
speciate :: MonadFresh SpecId m =>
            Parameters -> Map SpecId Species -> [Genome] -> m (Map SpecId Species)
speciate params specs gens =
  liftM (M.fromList . zipWithDefaults carryOver (const Nothing) spawn oldSpecies)
        (shuttleOrgs params seedBuckets gens)
  where seedBuckets = map toBucket (M.toList specs)
        oldSpecies = M.elems specs
        toBucket (sId, Species _ (rep:_) _ _) = SB sId rep []
        toBucket _ = error "(speciate) Empty species!"
        -- Old species are paired with their buckets by position: 'shuttleOrgs'
        -- keeps the seed buckets first and in their original order.
        carryOver (Species _ _ scr imp) (SB sId _ members)
          | null members = Nothing
          | otherwise = Just (sId, Species (length members) members scr imp)
        spawn (SB _ _ []) = Nothing
        spawn (SB sId _ (g:gs)) = Just (sId, newSpec g gs)
-- | Build the initial population from a random seed and settings.
--
-- Generates 'psSize' genomes, using 'fullConn' or, when 'sparse' is set,
-- 'sparseConn'. It then speciates them with species ids starting at
-- @SpecId 1@ and stores the resulting 'PopContext'.
--
-- The initial state is:
--
-- * best score 0;
-- * best genome: the first genome generated;
-- * best species: @SpecId 1@;
-- * generation counter: 1.
--
-- NOTE(review): 'head' fails if 'psSize' is 0 or negative.
newPop :: Int -> PopSettings -> Population
newPop seed PS{..} = fst (runPopM build startCtx)
  where
    mp = mutParams psParams
    -- Innovation numbering starts at inputs * outputs + 2.
    startCtx = PC (InnoId (psInputs * psOutputs + 2)) (mkStdGen seed)
    makeGenome = maybe (fullConn mp) (sparseConn mp) sparse
    build = do
      genomes <- replicateM psSize (makeGenome psInputs psOutputs)
      let (specs, nextId) = runSpecM (speciate psParams M.empty genomes) (SpecId 1)
      ctx <- PopM get
      return Population { popSize   = psSize
                        , popSpecs  = specs
                        , popBScore = 0
                        , popBOrg   = head genomes
                        , popBSpec  = SpecId 1
                        , popCont   = ctx
                        , nextSpec  = nextId
                        , popParams = psParams
                        , popGen    = 1
                        }
-- | Advance the population by one generation.
--
-- 1. Every species is scored with 'runFitTest', in parallel via
--    'parTraversable'. The first solution reported ('trSol'), in ascending
--    'SpecId' order, is returned alongside the new population.
-- 2. Each species is updated with its result ('updateSpec'). A species whose
--    'lastImprovement' has reached 'dropTime' is dropped, unless it is the
--    species recorded in 'popBSpec'.
-- 3. The 'popSize' offspring slots are divided among the surviving species.
--    Each gets a share proportional to its adjusted fitness ('trAdj'),
--    rounded down. The slots lost to rounding go one each to the species
--    with the highest best scores.
-- 4. Each offspring is produced in one of two ways:
--    * with probability 'noCrossover', by mutating one parent;
--    * otherwise, by breeding two parents, with the fitter one passed first.
--    Parents are drawn uniformly from the top (size / 5 + 1) scorers of their
--    species. A single @Map ConnSig InnoId@, starting empty, is threaded
--    through every 'mutate'/'breed' call of the generation.
-- 5. The offspring are re-speciated against the surviving species. The best
--    score, genome and species are replaced if this generation beat them.
--
-- NOTE(review): if every species is dropped, 'maximumBy' fails. If the total
-- adjusted fitness is 0, the share computation divides by zero.
trainOnce :: GenScorer a -> Population -> (Population, Maybe Genome)
trainOnce scorer pop = (generated, msolution)
  where params = popParams pop
        mParams = mutParams params
        mParamsS = mutParamsS params

        -- Species with at least 'largeSize' members mutate with 'mutParams';
        -- smaller ones use 'mutParamsS'.
        chooseParams :: Species -> MutParams
        chooseParams s = if specSize s >= largeSize params then mParams else mParamsS

        initSpecs = popSpecs pop

        -- Leave the species alone (r0); force each test result to WHNF (rseq).
        oneEval :: Strategy (Species, TestResult)
        oneEval = evalTuple2 r0 rseq

        fits = M.map (\sp -> (sp, runFitTest scorer sp)) initSpecs `using` parTraversable oneEval

        -- First reported solution, in ascending SpecId order.
        msolution = msum $ map (trSol . snd) $ M.elems fits

        -- Update a species with its test result, or drop it ('Nothing') once
        -- it has gone 'dropTime' generations without improvement. The species
        -- holding the best-ever genome is never dropped.
        eugenics :: SpecId -> (Species, TestResult) ->
                    Maybe (Species, MultiMap Double Genome, Double)
        eugenics sId (sp, TR{..})
          | maybe False (lastImprovement nSpec >=) (dropTime params)
            && sId /= popBSpec pop = Nothing
          | otherwise = Just (nSpec, trScores, trAdj)
          where nSpec = updateSpec trSS sp

        masterRace :: Map SpecId (Species, MultiMap Double Genome, Double)
        masterRace = M.mapMaybeWithKey eugenics fits

        masterList :: [(SpecId,(Species, MultiMap Double Genome, Double))]
        masterList = M.toList masterRace

        -- Surviving species with the highest best score.
        idVeryBest :: (SpecId, Species)
        idVeryBest = maximumBy (compare `on` (bestScore . specScore . snd)) $ map clean masterList
          where clean (sId,(sp, _, _)) = (sId,sp)
        veryBest = snd idVeryBest
        bestId = fst idVeryBest

        masterSpec :: Map SpecId Species
        masterSpec = M.map (\(s,_,_) -> s) masterRace
        totalFitness = M.foldl' (+) 0 . M.map (\(_,_,x) -> x) $ masterRace
        totalSize = popSize pop
        dubSize = fromIntegral totalSize

        -- One entry per surviving species, best first: mutation parameters,
        -- offspring count, and an action that draws a (score, parent) pair.
        candSpecs :: MonadRandom m => [(MutParams, Int, m (Double,Genome))]
        candSpecs = zip3 ps realShares pickers
          where sortedMaster = sortBy revComp masterList
                -- Descending by best score.
                revComp (_,(sp1,_,_)) (_,(sp2,_,_)) = (compare `on` (bestScore . specScore)) sp2 sp1
                initShares = map share sortedMaster
                share (_,(_, _, adj)) = floor $ adj / totalFitness * dubSize
                -- Offspring slots lost to rounding each share down.
                remaining = totalSize - foldl' (+) 0 initShares
                -- Hand out the leftover slots one each, best species first.
                -- Check for zero before the empty list: when the remainder
                -- equals the number of species, the recursion ends on
                -- @0 []@, which is a valid finish.
                distributeRem n l
                  | n < 0 = error "Remainder should be positive"
                  | n == 0 = l
                distributeRem _ [] = error "Should run out of numbers first"
                distributeRem n (x:xs) = x + 1 : distributeRem (n - 1) xs
                realShares = distributeRem remaining initShares
                pickers :: MonadRandom m => [m (Double, Genome)]
                pickers = map picker sortedMaster
                  where picker (_,(s, mmap, _)) =
                          let numToTake = specSize s `div` 5 + 1
                              desc = M.toDescList $ MM.toMap mmap
                              toPairs (k, vs) = map (\v -> (k,v)) vs
                              culled = take numToTake $ desc >>= toPairs
                          in uniform culled
                ps = map (\(_,(s,_,_)) -> chooseParams s) sortedMaster

        -- Apply a monadic step n times. A non-positive count performs no steps
        -- instead of recursing forever.
        applyN :: Monad m => Int -> (a -> m a) -> a -> m a
        applyN n _ x | n <= 0 = return x
        applyN n h !x = h x >>= applyN (n - 1) h

        -- Produce one species' offspring, threading the innovation map.
        specGens :: (MonadFresh InnoId m, MonadRandom m) =>
                    Map ConnSig InnoId -> (MutParams, Int, m (Double, Genome)) ->
                    m (Map ConnSig InnoId, [Genome])
        specGens inns (p, n, gen) = applyN n genOne (inns, [])
          where genOne (innos, gs) = do
                  roll <- getRandomR (0,1)
                  if roll <= noCrossover p
                    then do
                      (_, parent) <- gen
                      (innos', g) <- mutate p innos parent
                      return (innos', g:gs)
                    else do
                      (fit1, mom) <- gen
                      (fit2, dad) <- gen
                      -- The fitter parent is passed to 'breed' first.
                      (innos', g) <- if fit1 > fit2
                                       then breed p innos mom dad
                                       else breed p innos dad mom
                      return (innos', g:gs)

        -- All offspring of this generation. One innovation map, starting
        -- empty, is threaded through every species.
        allGens :: (MonadRandom m, MonadFresh InnoId m) => m [Genome]
        allGens = liftM (concat . snd) $ foldM ag' (M.empty, []) candSpecs
          where ag' (innos, cands) cand = do
                  (innos', specGen) <- specGens innos cand
                  return (innos', specGen:cands)

        genNewSpecies :: (MonadRandom m, MonadFresh InnoId m) => m (Map SpecId Species, SpecId)
        genNewSpecies = do
          gens <- allGens
          return $ runSpecM (speciate params masterSpec gens) (nextSpec pop)

        generated :: Population
        generated = fst $ runPopM generate (popCont pop)

        generate :: PopM Population
        generate = do
          (specs, nextSpec') <- genNewSpecies
          let bScoreNow = (bestScore . specScore) veryBest
              bOrgNow = (bestGen . specScore) veryBest
              bSpecNow = bestId
              (bScore, bOrg, bSpec) =
                if bScoreNow > popBScore pop
                then (bScoreNow, bOrgNow, bSpecNow)
                else (popBScore pop, popBOrg pop, popBSpec pop)
          cont' <- PopM get
          return pop { popSpecs = specs
                     , popBScore = bScore
                     , popBOrg = bOrg
                     , popBSpec = bSpec
                     , popCont = cont'
                     , nextSpec = nextSpec'
                     , popGen = popGen pop + 1
                     }
-- | Run 'trainOnce' @n@ times and return the final population. Any solutions
-- reported along the way are discarded. A non-positive @n@ returns the input
-- unchanged. Each intermediate population is forced to WHNF so that no chain
-- of thunks builds up.
trainN :: Int -> GenScorer a -> Population -> Population
trainN n scorer p
  | n <= 0 = p
  | otherwise = go n p
  where go 0 !pop = pop
        go k !pop = go (k - 1) (fst $ trainOnce scorer pop)
-- | Train for at most @n@ generations, stopping as soon as an evaluation
-- reports a solution.
--
-- Returns the population after the last training step. If a solution was
-- found, it is returned together with the number of steps completed before
-- the step that found it (0 if the first evaluation found it). A
-- non-positive @n@ returns the input unchanged with 'Nothing'.
trainUntil :: Int -> GenScorer a -> Population -> (Population, Maybe (Genome, Int))
trainUntil n f p
  | n <= 0 = (p, Nothing)
  | otherwise = go n p
  where -- @left@ counts the training steps still allowed, including this one.
        go 0 !p' = (p', Nothing)
        go left !p' = case trainOnce f p' of
                        (p'', Nothing) -> go (left - 1) p''
                        (p'', Just g) -> (p'', Just (g, n - left))
-- | Number of species currently in the population.
speciesCount :: Population -> Int
speciesCount = M.size . popSpecs
-- | Check a population's structural invariants.
--
-- Returns 'Nothing' if every check passes; otherwise returns 'Just' the
-- error messages. The checks are:
--
-- * there is at least one species;
-- * every existing 'SpecId' is below 'nextSpec', since new species ids are
--   allocated starting from 'nextSpec';
-- * the species sizes add up to 'popSize';
-- * every species passes 'validateSpecies'.
validatePopulation :: Population -> Maybe [String]
validatePopulation Population{..} = case errRes of
                                     [] -> Nothing
                                     xs -> Just xs
  where totalSSize = M.foldl' (\acc x -> specSize x + acc) 0 popSpecs
        goodSize
          | totalSSize == popSize = []
          | otherwise = ["Population size differs from actual size"]
        -- Emptiness gets its own message, so it is not blamed on a SpecId.
        -- This check also keeps 'M.findMax' away from an empty map.
        goodSId
          | M.null popSpecs = ["Population has no species"]
          | fst (M.findMax popSpecs) < nextSpec = []
          | otherwise = ["SpecId lower than extant species"]
        specErrs = concat . M.elems $ M.mapMaybe validateSpecies popSpecs
        errRes = goodSId ++ goodSize ++ specErrs