module Neet.Genome (
NodeId(..)
, NodeType(..)
, NodeGene(..)
, ConnGene(..)
, InnoId(..)
, ConnSig
, Genome(..)
, fullConn
, sparseConn
, genomeComplexity
, mutateAdd
, mutateSub
, crossover
, breed
, distance
, GenScorer(..)
, renderGenome
, printGenome
, validateGenome
) where
import Control.Applicative
import Control.Monad
import Control.Monad.Random
import Control.Arrow (first)
import Data.Map.Strict (Map)
import qualified Data.Traversable as T
import qualified Data.Foldable as F
import qualified Data.Map.Strict as M
import qualified Data.IntSet as IS
import qualified Data.Set as S
import qualified Data.IntMap as IM
import Data.IntMap (IntMap)
import Data.Maybe
import Control.Monad.Fresh.Class
import Neet.Parameters
import Data.GraphViz
import Data.GraphViz.Attributes.Complete
import GHC.Generics (Generic)
import Data.Serialize (Serialize)
import Text.Printf
-- | Identifier of a node inside a single genome; also the key space of
-- 'nodeGenes' (after unwrapping with 'getNodeId').
newtype NodeId = NodeId
  { getNodeId :: Int
  }
  deriving (Show, Eq, Ord, PrintDot, Serialize)
-- | Role a node plays in the network.
data NodeType
  = Input   -- ^ input node, created by 'fullConn' / 'sparseConn'
  | Hidden  -- ^ node inserted by splitting a connection during mutation
  | Output  -- ^ output node, created by 'fullConn' / 'sparseConn'
  deriving (Show, Eq, Generic)

instance Serialize NodeType
-- | A node gene.
data NodeGene = NodeGene
  { nodeType :: NodeType
    -- ^ role of the node
  , yHint    :: Rational
    -- ^ layer position hint: inputs start at 0, outputs at 1, and a hidden
    --   node sits halfway between the endpoints of the connection it split.
    --   Used to flag recurrent connections and to cluster nodes when drawing.
  }
  deriving (Show, Generic)

instance Serialize NodeGene
-- | A connection gene.
data ConnGene = ConnGene
  { connIn      :: NodeId  -- ^ source node
  , connOut     :: NodeId  -- ^ destination node
  , connWeight  :: Double  -- ^ connection weight
  , connEnabled :: Bool
    -- ^ disabled genes stay in the genome (e.g. after a node split) but are
    --   not drawn by 'renderGenome'
  , connRec     :: Bool
    -- ^ recurrent flag: set when the destination's 'yHint' is not greater
    --   than the source's
  }
  deriving (Show, Generic)

instance Serialize ConnGene
-- | Innovation number of a connection gene; unwrapped with 'getInnoId' it is
-- the key of 'connGenes'.
newtype InnoId = InnoId
  { getInnoId :: Int
  }
  deriving (Show, Eq, Ord)
-- | A complete genome.
data Genome = Genome
  { ioCount   :: Int
    -- ^ number of input plus output nodes (inputs include the extra input
    --   node that 'fullConn' / 'sparseConn' add)
  , nodeGenes :: IntMap NodeGene
    -- ^ node genes keyed by node id
  , connGenes :: IntMap ConnGene
    -- ^ connection genes keyed by innovation number
  , nextNode  :: NodeId
    -- ^ id given to the next node created by mutation
  }
  deriving (Show, Generic)

instance Serialize Genome
-- | Build a fully connected genome with no hidden nodes.
--
-- @fullConn params iSize oSize@ creates @iSize + 1@ input nodes (the extra
-- one presumably acts as a bias input — confirm against the evaluator) with
-- ids @1..iSize+1@, followed by @oSize@ output nodes. Every input is
-- connected to every output by an enabled, non-recurrent gene. Innovation
-- numbers are assigned @1, 2, ..@ in input-major order.
--
-- Weights are drawn uniformly from @[-weightRange, weightRange]@.
fullConn :: MonadRandom m => MutParams -> Int -> Int -> m Genome
fullConn MutParams{..} iSize oSize = do
  let inCount     = iSize + 1
      inIDs       = [1 .. inCount]
      outIDs      = [inCount + 1 .. oSize + inCount]
      inputGenes  = zip inIDs $ repeat (NodeGene Input 0)
      outputGenes = zip outIDs $ repeat (NodeGene Output 1)
      nodeGenes   = IM.fromList $ inputGenes ++ outputGenes
      nextNode    = NodeId $ inCount + oSize + 1
      nodePairs   = (,) <$> inIDs <*> outIDs
      ioCount     = inCount + oSize
      mkConn (inN, outN) w = ConnGene (NodeId inN) (NodeId outN) w True False
  -- The interval must be symmetric: (weightRange, weightRange) is degenerate
  -- and would give every connection exactly the same weight.
  conns <- zipWith mkConn nodePairs `liftM` getRandomRs (-weightRange, weightRange)
  let connGenes = IM.fromList $ zip [1 ..] conns
  return Genome{..}
-- | Build a sparsely connected genome with no hidden nodes.
--
-- The node layout is the same as in 'fullConn'. @cons@ input/output pairs are
-- sampled uniformly, with replacement, from all possible pairs. Each pair keeps
-- the innovation number it would get in 'fullConn'. Because sampling is with
-- replacement, duplicate picks collapse into one gene, so the result may hold
-- fewer than @cons@ connections.
--
-- Weights are drawn uniformly from @[-weightRange, weightRange]@.
sparseConn :: MonadRandom m => MutParams -> Int -> Int -> Int -> m Genome
sparseConn MutParams{..} cons iSize oSize = do
  let inCount     = iSize + 1
      inIDs       = [1 .. inCount]
      outIDs      = [inCount + 1 .. oSize + inCount]
      inputGenes  = zip inIDs $ repeat (NodeGene Input 0)
      outputGenes = zip outIDs $ repeat (NodeGene Output 1)
      nodeGenes   = IM.fromList $ inputGenes ++ outputGenes
      nextNode    = NodeId $ inCount + oSize + 1
      nodePairs   = (,) <$> inIDs <*> outIDs
      idNodePairs = zip [1 ..] nodePairs
      ioCount     = inCount + oSize
      mkConn (inno, (inN, outN)) w =
        (inno, ConnGene (NodeId inN) (NodeId outN) w True False)
  conPairs <- replicateM cons (uniform idNodePairs)
  -- Symmetric interval; (weightRange, weightRange) would be degenerate.
  conns <- zipWith mkConn conPairs `liftM` getRandomRs (-weightRange, weightRange)
  let connGenes = IM.fromList conns
  return Genome{..}
-- | Possibly mutate every connection weight of a genome.
--
-- A single roll decides whether the genome is touched at all (probability
-- 'mutWeightRate'). If it is, each connection gene independently either:
--
-- * gets a fresh weight from @[-weightRange, weightRange]@ (probability
--   'newWeightRate'), or
-- * is perturbed by a value drawn from @[-pertAmount, pertAmount]@.
mutateWeights :: MonadRandom m => MutParams -> Genome -> m Genome
mutateWeights MutParams{..} gen@Genome{..} = do
  roll <- getRandomR (0,1)
  if roll > mutWeightRate
    then return gen
    else setConns gen `liftM` T.mapM mutOne connGenes
  where setConns g cs = g { connGenes = cs }
        mutOne conn = do
          roll <- getRandomR (0,1)
          -- Both intervals must be symmetric; (x, x) is degenerate and would
          -- always yield x (a constant weight / a constant positive nudge).
          let newWeight
                | roll <= newWeightRate = getRandomR (-weightRange, weightRange)
                | otherwise = do
                    pert <- getRandomR (-pertAmount, pertAmount)
                    return $ connWeight conn + pert
          w <- newWeight
          return $ conn { connWeight = w }
-- | Structural signature of a connection: its (source, destination) node pair.
--
-- It is the key of the shared innovation map, so structurally identical
-- connections receive the same innovation number (see 'addConn').
data ConnSig
  = ConnSig NodeId NodeId
  deriving (Show, Eq, Ord)
-- | The (source, destination) signature of a connection gene.
toConnSig :: ConnGene -> ConnSig
toConnSig = ConnSig <$> connIn <*> connOut
-- | Insert a connection gene, numbering it through the shared innovation map.
--
-- If the gene's (source, destination) signature already has an innovation
-- number, that number is reused and the map is returned unchanged. Otherwise
-- a fresh number is drawn and recorded. The gene is stored under its number
-- in the connection map, replacing any gene already stored there.
addConn :: MonadFresh InnoId m => ConnGene ->
           (Map ConnSig InnoId, IntMap ConnGene) ->
           m (Map ConnSig InnoId, IntMap ConnGene)
addConn conn (innos, conns) =
  maybe assignFresh (return . withInno innos) (M.lookup sig innos)
  where
    sig = toConnSig conn
    withInno table inno = (table, IM.insert (getInnoId inno) conn conns)
    assignFresh = do
      inno <- fresh
      return (withInno (M.insert sig inno innos) inno)
-- | With probability 'addConnRate', add one new connection gene between a
-- pair of nodes that is not yet connected.
--
-- Candidate pairs run from any node to any non-input node (self-connections
-- included). A pair is flagged recurrent when the destination's 'yHint' is
-- not greater than the source's. Recurrent pairs are only allowed when
-- 'recurrencies' is set. Pairs that already have a gene in the genome, whether
-- enabled or disabled, are excluded. If nothing is allowed the genome is
-- returned unchanged.
--
-- The new gene is enabled, has a weight drawn from
-- @[-weightRange, weightRange]@, and is numbered via 'addConn'.
mutateConn :: (MonadFresh InnoId m, MonadRandom m) =>
              MutParams -> Map ConnSig InnoId -> Genome -> m (Map ConnSig InnoId, Genome)
mutateConn params innos g = do
  roll <- getRandomR (0,1)
  if roll > addConnRate params
    then return (innos, g)
    else case allowed of
      [] -> return (innos, g)
      _ -> do
        (innos', conns') <- addRandConn innos (connGenes g)
        return (innos', g { connGenes = conns' })
  where
    taken :: Map ConnSig Bool
    taken = M.fromList . map (\c -> (toConnSig c, True)) . IM.elems . connGenes $ g
    notInput (NodeGene Input _) = False
    notInput _ = True
    nodes = IM.toList $ nodeGenes g
    nonInputs = filter (notInput . snd) nodes
    -- Second component: does the pair point "backwards" (or sideways) in yHint order?
    makePair (n1,g1) (n2,g2) = (ConnSig (NodeId n1) (NodeId n2), yHint g2 <= yHint g1)
    candidates = M.fromList $
      if recurrencies params
        then makePair <$> nodes <*> nonInputs
        else filter nonRec $ makePair <$> nodes <*> nonInputs
    nonRec (_,reccy) = not reccy
    allowed = M.toList $ M.difference candidates taken
    pickOne :: MonadRandom m => m (ConnSig, Bool)
    pickOne = uniform allowed
    -- Symmetric interval; (r, r) is degenerate and would always yield r.
    pickWeight :: MonadRandom m => m Double
    pickWeight = let r = weightRange params in getRandomR (-r, r)
    addRandConn :: (MonadRandom m, MonadFresh InnoId m) =>
                   Map ConnSig InnoId -> IntMap ConnGene ->
                   m (Map ConnSig InnoId, IntMap ConnGene)
    addRandConn innos' conns = do
      (ConnSig inNode outNode, recc) <- pickOne
      w <- pickWeight
      let newConn = ConnGene inNode outNode w True recc
      addConn newConn (innos',conns)
-- | With probability 'addNodeRate', split a randomly chosen connection gene
-- by inserting a new hidden node.
--
-- The chosen gene (enabled or not) is kept but disabled. Two enabled genes
-- replace it:
--
-- * @in -> new@ with weight 1;
-- * @new -> out@ carrying the old weight.
--
-- Both inherit the old gene's recurrent flag. The new node takes 'nextNode'
-- as its id, and its 'yHint' is the mean of its neighbours'. The two new genes
-- are numbered through the shared innovation map.
--
-- A genome with no connection genes has nothing to split and is returned
-- unchanged. Such genomes can arise after 'mutDelNode' removes every
-- connection, or from 'sparseConn' with zero connections.
mutateNode :: (MonadRandom m, MonadFresh InnoId m) =>
              MutParams -> Map ConnSig InnoId ->
              Genome -> m (Map ConnSig InnoId, Genome)
mutateNode params innos g = do
  roll <- getRandomR (0,1)
  -- 'uniform' raises an error on an empty list, so require a gene to split.
  if roll <= addNodeRate params && not (IM.null conns)
    then addRandNode
    else return (innos, g)
  where conns = connGenes g
        nodes = nodeGenes g
        pickConn :: MonadRandom m => m (Int, ConnGene)
        pickConn = uniform $ IM.toList conns
        newId = nextNode g
        newNextNode = case newId of NodeId x -> NodeId (x + 1)
        addNode :: MonadFresh InnoId m =>
                   InnoId -> ConnGene -> m (Map ConnSig InnoId, Genome)
        addNode inno gene = do
          let ConnSig (NodeId inId) (NodeId outId) = toConnSig gene
              inGene = nodes IM.! inId
              outGene = nodes IM.! outId
              newGene = NodeGene Hidden ((yHint inGene + yHint outGene) / 2)
              newNodes = IM.insert (getNodeId newId) newGene nodes
              disabledConn = gene { connEnabled = False }
              backGene = ConnGene (NodeId inId) newId 1 True (connRec gene)
              forwardGene = ConnGene newId (NodeId outId) (connWeight gene) True (connRec gene)
          (innos', newConns) <-
            addConn backGene >=> addConn forwardGene $ (innos, conns)
          return (innos', g { nodeGenes = newNodes
                            , connGenes = IM.insert (getInnoId inno) disabledConn newConns
                            , nextNode = newNextNode
                            })
        addRandNode :: (MonadRandom m, MonadFresh InnoId m) => m (Map ConnSig InnoId, Genome)
        addRandNode =
          pickConn >>= uncurry (addNode . InnoId)
-- | True when no connection gene in the map starts or ends at the given node.
isOrphanNode :: NodeId -> IntMap ConnGene -> Bool
isOrphanNode nId = not . F.any touchesNode
  where touchesNode ConnGene{..} = connIn == nId || connOut == nId
-- | With probability 'delConnChance', delete one randomly chosen connection gene.
--
-- Genomes with at most one connection gene are left alone. After the
-- deletion, any endpoint of the removed gene that is a hidden node and no
-- longer touched by any connection is removed too.
mutDelConn :: MonadRandom m => MutParams -> Genome -> m Genome
mutDelConn MutParams{..} genome@Genome{..} = do
  roll <- getRandomR (0,1)
  let skip = IM.size connGenes <= 1 || roll > delConnChance
  if skip
    then return genome
    else do
      (victimId, victim) <- uniform (IM.toList connGenes)
      let remaining = IM.delete victimId connGenes
          isHidden nid = nodeType (nodeGenes IM.! getNodeId nid) == Hidden
          -- Drop a hidden endpoint that the deletion left without connections.
          pruneIfOrphan nid nodeMap
            | isHidden nid && isOrphanNode nid remaining = IM.delete (getNodeId nid) nodeMap
            | otherwise = nodeMap
          prunedNodes = pruneIfOrphan (connOut victim) (pruneIfOrphan (connIn victim) nodeGenes)
      return genome { nodeGenes = prunedNodes, connGenes = remaining }
-- | True when the connection gene neither starts nor ends at the given node.
doesntContainNode :: NodeId -> ConnGene -> Bool
doesntContainNode nId cg = nId `notElem` [connIn cg, connOut cg]
-- | With probability 'delNodeChance', delete one randomly chosen non-I/O node
-- together with every connection gene that touches it.
--
-- Candidates are the node ids after the first 'ioCount' keys; input and
-- output nodes are created with the lowest ids. A random number is always
-- drawn, even when there is no hidden node to delete.
mutDelNode :: MonadRandom m => MutParams -> Genome -> m Genome
mutDelNode MutParams{..} genome@Genome{..} = do
  roll <- getRandomR (0,1)
  let onlyIoNodes = IM.size nodeGenes == ioCount
  if onlyIoNodes || roll > delNodeChance
    then return genome
    else do
      victim <- uniform (drop ioCount (IM.keys nodeGenes))
      return genome
        { nodeGenes = IM.delete victim nodeGenes
        , connGenes = IM.filter (doesntContainNode (NodeId victim)) connGenes
        }
-- | Apply the destructive mutations in order: node deletion, connection
-- deletion, then weight mutation.
mutateSub :: MonadRandom m => MutParams -> Genome -> m Genome
mutateSub params genome = do
  afterNodes <- mutDelNode params genome
  afterConns <- mutDelConn params afterNodes
  mutateWeights params afterConns
-- | Apply the additive mutations in order: weight mutation, node insertion,
-- then connection insertion. The innovation map is threaded through the
-- structural mutations.
mutateAdd :: (MonadRandom m, MonadFresh InnoId m) => MutParams -> Map ConnSig InnoId ->
             Genome -> m (Map ConnSig InnoId, Genome)
mutateAdd params innos g = do
  reweighted <- mutateWeights params g
  (innos', withNode) <- mutateNode params innos reweighted
  mutateConn params innos' withNode
-- | Left-biased merge of two IntMaps.
--
-- Keys present in both maps are combined with @comb@. Keys only in the first
-- map are converted with @mk@. Keys only in the second map are dropped.
superLeft :: (a -> b -> c) -> (a -> c) -> IntMap a -> IntMap b -> IntMap c
superLeft comb mk left right = IM.union shared onlyLeft
  where shared   = IM.intersectionWith comb left right
        onlyLeft = IM.map mk (IM.difference left right)
-- | Return one of the two arguments, chosen by a random 'Bool'.
flipCoin :: MonadRandom m => a -> a -> m a
flipCoin a1 a2 = liftM choose getRandom
  where choose heads = if heads then a1 else a2
-- | Cross two parents' connection genes.
--
-- Every gene of the first parent is inherited. Genes that only the second
-- parent has are dropped. Where both parents share an innovation number, one
-- of the two genes is picked with 'flipCoin'. If that gene is disabled in
-- either parent, a second roll follows the coin flip: the inherited copy is
-- disabled when the roll is @<= disableChance@ and enabled otherwise. Effects
-- run in ascending innovation order.
crossConns :: MonadRandom m => MutParams -> IntMap ConnGene -> IntMap ConnGene ->
              m (IntMap ConnGene)
crossConns params m1 m2 = T.sequence (superLeft pickGene return m1 m2)
  where
    pickGene c1 c2
      | connEnabled c1 && connEnabled c2 = flipCoin c1 c2
      | otherwise = do
          chosen <- flipCoin c1 c2
          roll <- getRandomR (0,1)
          return chosen { connEnabled = not (roll <= disableChance params) }
-- | Cross two parents' node genes. The result has exactly the keys and genes
-- of the first parent; the second parent's node genes are ignored.
crossNodes :: IntMap NodeGene -> IntMap NodeGene -> IntMap NodeGene
crossNodes = superLeft const id
-- | Produce a child genome from two parents.
--
-- The child keeps the first parent's 'ioCount' and node genes and takes its
-- connection genes from 'crossConns'. It uses the larger of the two parents'
-- 'nextNode', presumably so that new node ids cannot collide with either
-- parent's.
crossover :: MonadRandom m => MutParams -> Genome -> Genome -> m Genome
crossover params g1 g2 = do
  childConns <- crossConns params (connGenes g1) (connGenes g2)
  return Genome { ioCount   = ioCount g1
                , nodeGenes = crossNodes (nodeGenes g1) (nodeGenes g2)
                , connGenes = childConns
                , nextNode  = max (nextNode g1) (nextNode g2)
                }
-- | Cross two parents with 'crossover', then apply 'mutateAdd' to the child,
-- threading the shared innovation map through.
breed :: (MonadRandom m, MonadFresh InnoId m) =>
         MutParams -> Map ConnSig InnoId -> Genome -> Genome ->
         m (Map ConnSig InnoId, Genome)
breed params innos g1 g2 = do
  child <- crossover params g1 g2
  mutateAdd params innos child
-- | Absolute weight difference for every innovation number present in both
-- maps. Genes present in only one of the maps are ignored.
differences :: IntMap ConnGene -> IntMap ConnGene -> IntMap Double
differences = IM.intersectionWith oneDiff
  -- The subtraction was missing ("connWeight c1 connWeight c2"), which does
  -- not type-check.
  where oneDiff c1 c2 = abs (connWeight c1 - connWeight c2)
-- | Compatibility distance between two genomes.
--
-- Computed as @c1 * E + c2 * D + c3 * W@, with @c1@, @c2@, @c3@ taken from the
-- species distance parameters, where:
--
-- * @E@ (excess) counts non-matching genes whose innovation number is beyond
--   the other genome's highest innovation number;
-- * @D@ (disjoint) counts the remaining non-matching genes;
-- * @W@ is the mean absolute weight difference of matching genes (0 if none).
--
-- If either genome has no connection genes, every gene of the other one
-- counts as excess.
distance :: Parameters -> Genome -> Genome -> Double
distance params g1 g2 = c1 * exFactor + c2 * disFactor + c3 * weightFactor
  where DistParams c1 c2 c3 _ = distParams . specParams $ params
        conns1 = connGenes g1
        conns2 = connGenes g2
        weightDiffs = differences conns1 conns2
        weightsSize = IM.size weightDiffs
        weightFactor
          | weightsSize > 0 = IM.foldl' (+) 0 weightDiffs / fromIntegral weightsSize
          | otherwise = 0
        ids1 = IM.keysSet conns1
        ids2 = IM.keysSet conns2
        exJoints = (ids1 `IS.difference` ids2) `IS.union` (ids2 `IS.difference` ids1)
        -- 'IS.findMax' is partial, so 'edge' is only forced when both sets are non-empty.
        (excess, disjoint)
          | IS.null ids1 || IS.null ids2 = (exJoints, IS.empty)
          | otherwise = IS.partition (> edge) exJoints
        -- Highest innovation number inside both genomes' ranges. A non-matching
        -- gene numbered exactly 'edge' lies within both ranges, so it is
        -- disjoint, not excess; hence the strict comparison above.
        edge = min (IS.findMax ids1) (IS.findMax ids2)
        exFactor = fromIntegral $ IS.size excess
        disFactor = fromIntegral $ IS.size disjoint
-- | Graphviz parameters for drawing a genome, laid out left to right.
--
-- Nodes are clustered by 'yHint'. The cluster for 0 is labelled
-- "Input Layer", uses minimum rank and has blue nodes. The cluster for 1 is
-- labelled "Output Layer", uses maximum rank and has red nodes. Every other
-- value gets a numerically labelled cluster with green nodes. Edges are
-- labelled with their weight to two decimal places.
graphParams :: GraphvizParams NodeId NodeGene Double Rational Rational
graphParams =
  Params { isDirected       = True
         , globalAttributes = [ GraphAttrs [ RankDir FromLeft, Splines LineEdges ]
                              , NodeAttrs [ FixedSize SetNodeSize ]
                              ]
         , clusterBy        = \(nId, ng) -> C (yHint ng) (N (nId, yHint ng))
         , isDotCluster     = const True
         , clusterID        = layerId
         , fmtCluster       = layerStyle
         , fmtNode          = const []
         , fmtEdge          = \(_, _, w) -> [ toLabel (printf "%.2f" w :: String) ]
         }
  where
    layerId r = case r of
      0 -> Str "Input Layer"
      1 -> Str "Output Layer"
      _ -> Num (Dbl (fromRational r))

    layerStyle r = case r of
      0 -> [ GraphAttrs [ whiteFill, rank MinRank ], NodeAttrs (nodeLook Blue4) ]
      1 -> [ GraphAttrs [ whiteFill, rank MaxRank ], NodeAttrs (nodeLook Red2) ]
      _ -> [ GraphAttrs [ whiteFill ], NodeAttrs (nodeLook SeaGreen) ]

    whiteFill = Color [WC (X11Color White) Nothing]

    -- Solid circular nodes in the given colour.
    nodeLook colour =
      [ Style [ SItem Solid [] ]
      , Color [WC (X11Color colour) Nothing]
      , Shape Circle
      ]
-- | Draw a genome with Graphviz's @dot@ layout on the Xlib canvas.
--
-- Every node gene is drawn. Only enabled connection genes become edges, and
-- each edge is labelled with its weight.
renderGenome :: Genome -> IO ()
renderGenome g = runGraphvizCanvas Dot dotGraph Xlib
  where
    dotGraph = graphElemsToDot graphParams graphNodes graphEdges
    graphNodes = [ (NodeId k, ng) | (k, ng) <- IM.toList (nodeGenes g) ]
    graphEdges = [ (connIn c, connOut c, connWeight c)
                 | c <- IM.elems (connGenes g)
                 , connEnabled c
                 ]
-- | Print a plain-text dump of a genome: one line per node gene, then a
-- multi-line entry per connection gene (innovation number, endpoints, whether
-- it is disabled, and its weight).
printGenome :: Genome -> IO ()
printGenome g = putStrLn . unlines $
     ["Genetic Info:", "Nodes:"]
  ++ [ show k ++ "(" ++ show (nodeType ng) ++ ")" | (k, ng) <- IM.toList (nodeGenes g) ]
  ++ ["\n\nConnections:"]
  ++ map describeConn (IM.toList (connGenes g))
  where
    describeConn (inno, c) = concat
      [ "\nInnovation ", show inno
      , "\nConnection from ", show (getNodeId (connIn c))
      , " to ", show (getNodeId (connOut c))
      , " ", if connEnabled c then "" else "(Disabled)"
      , " with weight ", show (connWeight c)
      ]
-- | How genomes are scored.
data GenScorer score = GS
  { gScorer         :: Genome -> score
    -- ^ compute a raw score for a genome
  , fitnessFunction :: score -> Double
    -- ^ turn a score into a numeric fitness
  , winCriteria     :: score -> Bool
    -- ^ presumably signals that a score solves the task — confirm with callers
  }
-- | True when no element occurs twice. Stops at the first repeated element,
-- so it also terminates on an infinite list that contains a duplicate.
uniq :: Ord a => [a] -> Bool
uniq xs = foldr step (const True) xs S.empty
  where step x continue seen = not (x `S.member` seen) && continue (S.insert x seen)
-- | Check a genome's internal invariants.
--
-- Returns 'Nothing' when every check passes. Otherwise it returns 'Just' the
-- messages of the failing checks, in this order:
--
-- * "NodeId too low": 'nextNode' is not greater than the largest node id;
-- * "Connection gene between nonexistent nodes": some connection gene refers
--   to a node id missing from 'nodeGenes' (reported once);
-- * "Non unique connection signatures": two connection genes share a
--   (source, destination) pair;
-- * "ioCount bad": the number of input and output nodes differs from 'ioCount'.
validateGenome :: Genome -> Maybe [String]
validateGenome Genome{..} = case errRes of
  [] -> Nothing
  xs -> Just xs
  where
    nodeOk = case IM.maxViewWithKey nodeGenes of
      Nothing -> Nothing
      Just ((nid,_), _)
        | nid < getNodeId nextNode -> Nothing
        | otherwise -> Just "NodeId too low"
    connOk (ConnSig (NodeId n1) (NodeId n2))
      | IM.member n1 nodeGenes && IM.member n2 nodeGenes = Nothing
      | otherwise = Just "Connection gene between nonexistent nodes"
    -- First failing connection, if any. The old `join . listToMaybe . map`
    -- form only ever looked at the first connection gene.
    connsOk = listToMaybe $ mapMaybe connOk sigList
    sigList = map toConnSig . IM.elems $ connGenes
    nonDup
      | uniq sigList = Nothing
      | otherwise = Just "Non unique connection signatures"
    isIo n = nodeType n == Input || nodeType n == Output
    ioCountGood
      | IM.size (IM.filter isIo nodeGenes) == ioCount = Nothing
      | otherwise = Just "ioCount bad"
    errRes = catMaybes [nodeOk, connsOk, nonDup, ioCountGood]
-- | Size of a genome: the number of node genes plus the number of connection
-- genes (disabled connection genes included).
genomeComplexity :: Genome -> Int
-- Previously the node count was added twice; the second term must count
-- connection genes.
genomeComplexity gen = IM.size (nodeGenes gen) + IM.size (connGenes gen)