module Neet.Examples.XOR (xorFit, andFit, orFit, xorExperiment) where
import Neet
import Neet.Species
import Data.Monoid
import qualified Data.Map.Strict as M
import System.Random
import Data.List (intercalate)
-- | All four two-input boolean combinations, encoded as 0/1 doubles, in
-- the order @[0,0]@, @[0,1]@, @[1,0]@, @[1,1]@.
boolQuestions :: [[Double]]
boolQuestions = [ [a, b] | a <- [0, 1], b <- [0, 1] ]
-- | Expected XOR outputs, one per row of the truth table
-- (first input varies slowest): inputs differ exactly when XOR is true.
xorAnswers :: [Bool]
xorAnswers = zipWith (/=) [False, False, True, True]
                          [False, True, False, True]
-- | Build a 'GenScorer' that judges a genome against a table of input
-- samples and their expected boolean answers.
--
-- * The intermediate result is the first output of the genome's phenotype
--   network for each question, in question order.
-- * The fitness is @(n - sumErr) ** 2@. Here @n@ is the number of answers and
--   @sumErr@ is the sum of absolute differences between those outputs and the
--   answers encoded as 0/1. A perfect network therefore scores @n ** 2@.
-- * A genome wins when every output lies on the correct side of 0.5
--   (above for 'True', below for 'False').
--
-- NOTE(review): assumes @questions@ and @answers@ have the same length.
-- 'zipWith' silently truncates otherwise, which would skew both the fitness
-- and the win criterion.
sampleFit :: [[Double]] -> [Bool] -> GenScorer [Double]
sampleFit questions answers = GS intermed ff criteria
  where
    -- Run every question through the genome's network, keeping output 0.
    intermed g = map try questions
      where
        net = mkPhenotype g
        try samp = case pushThrough net samp of
          (out : _) -> out
          [] -> error "Neet.Examples.XOR.sampleFit: network produced no outputs"

    -- Fitness grows as the total absolute error shrinks. The subtractions
    -- were missing here, which made the expressions ill-typed applications.
    ff ds = (fromIntegral (length answers) - sumDiffs) ** 2
      where
        sumDiffs = sum $ zipWith (\x y -> abs (x - y)) ds binarized

    -- Answers as numeric targets (True -> 1, False -> 0).
    binarized = map (\b -> if b then 1 else 0) answers

    -- Per-sample acceptance test: which side of 0.5 the output must lie on.
    bounds = map (\b -> if b then (> 0.5) else (< 0.5)) answers

    criteria ds = and $ zipWith ($) bounds ds
-- | Scorer for the XOR truth table: 'boolQuestions' paired with 'xorAnswers'.
xorFit :: GenScorer [Double]
xorFit = boolQuestions `sampleFit` xorAnswers
-- | Expected AND outputs, one per row of the truth table
-- (first input varies slowest).
andAnswers :: [Bool]
andAnswers = zipWith (&&) [False, False, True, True]
                          [False, True, False, True]
-- | Scorer for the AND truth table: 'boolQuestions' paired with 'andAnswers'.
andFit :: GenScorer [Double]
andFit = boolQuestions `sampleFit` andAnswers
-- | Expected OR outputs, one per row of the truth table
-- (first input varies slowest).
orAnswers :: [Bool]
orAnswers = zipWith (||) [False, False, True, True]
                         [False, True, False, True]
-- | Scorer for the OR truth table: 'boolQuestions' paired with 'orAnswers'.
orFit :: GenScorer [Double]
orFit = boolQuestions `sampleFit` orAnswers
-- | Interactive XOR demo.
--
-- 1. Prints the inputs and waits for Enter.
-- 2. Trains a population of 150 genomes (2 inputs, 1 output) with
--    'xorLoop' until a winner appears.
-- 3. Reports the winner's outputs, fitness and the final speciation
--    distance parameters.
-- 4. Waits for Enter again and renders the winning network.
xorExperiment :: IO ()
xorExperiment = do
  putStrLn $ "XOR Input list is: " ++ show boolQuestions
  putStrLn "Press Enter to start learning"
  _ <- getLine
  -- NOTE(review): this message says "default parameters", but
  -- 'xorParams' below overrides the mutation and speciation settings.
  putStrLn "Running XOR experiment with 150 population and default parameters"
  seed <- randomIO
  (finalPop, winner) <- xorLoop (newPop seed xorPopSettings)
  printInfo finalPop
  putStrLn $ "Solution found in generation " ++ show (popGen finalPop)
  let outputs = gScorer xorFit winner
  putStrLn $ "\nOutputs to XOR inputs are: " ++ show outputs
  putStrLn $ "Fitness (Out of 16): " ++ show (fitnessFunction xorFit outputs)
  putStrLn $ "Final distance threshold: " ++ show (distParams . specParams $ popParams finalPop)
  putStrLn "\nPress Enter to view network"
  _ <- getLine
  renderGenome winner
  where
    -- Population of 150 with 2 inputs and 1 output.
    xorPopSettings = PS 150 2 1 xorParams Nothing xorPhaseParams
    xorPhaseParams = Just (PhaseParams 10 10)
    xorParams = defParams { specParams = xorSpecParams
                          , mutParams = xorMutParams
                          , mutParamsS = xorMutParamsS
                          }
    xorMutParams = defMutParams { delConnChance = 0.3, delNodeChance = 0.03 }
    xorMutParamsS = defMutParamsS { addConnRate = 0.05, delConnChance = 0.05 }
    xorSpecParams = Target xorDistParams (SpeciesTarget (14,17) 0.1)
    xorDistParams = defDistParams { delta_t = 5 }
-- | One-line summary of the population's species: comma-separated
-- @S<id> P<size>@ entries in ascending species-id order.
mkSpecInfo :: Population -> String
mkSpecInfo pop = intercalate ", " (M.foldrWithKey describe [] (popSpecs pop))
  where
    describe (SpecId k) sp rest = ("S" ++ show k ++ " P" ++ show (specSize sp)) : rest
-- | Train the population on 'xorFit' one generation at a time, printing a
-- status report before each step. Returns the population after the
-- generation that produced a winner, together with that winning genome.
-- There is no generation cap, so this recurses until a winner appears.
xorLoop :: Population -> IO (Population, Genome)
xorLoop pop = do
  printInfo pop
  let (First found, next) = trainOnce (winTrain xorFit) pop
  maybe (xorLoop next) (\genome -> return (next, genome)) found
     
-- | Print a short status report for a population: its generation, a species
-- summary, the best score so far, and a trailing blank line.
printInfo :: Population -> IO ()
printInfo pop = mapM_ putStrLn
  [ "Generation " ++ show (popGen pop)
  , "Species: " ++ mkSpecInfo pop
  , "High Score: " ++ show (popBScore pop)
  , ""
  ]