module Neet.Network (
modSig
, Network(..)
, Neuron(..)
, mkPhenotype
, stepNeuron
, stepNetwork
, snapshot
, getOutput
) where
import Data.Map (Map)
import Data.Set (Set)
import qualified Data.Set as S
import qualified Data.Map as M
import Data.List (foldl')
import Neet.Genome
-- | Steepened logistic sigmoid, @1 / (1 + e^(-4.9 * d))@.
--
-- Maps any input into the open interval (0, 1), is monotonically increasing
-- and equals 0.5 at 0. The slope 4.9 is the "modified sigmoid" used by NEAT.
-- The sign in the exponent matters: without the negation the curve is
-- mirrored and large weighted sums would produce activations near 0.
modSig :: Double -> Double
modSig d = 1 / (1 + exp (-4.9 * d))
-- | One node of a phenotype network.
data Neuron = Neuron
  { activation :: Double
    -- ^ Value produced by the most recent step ('mkPhenotype' starts every
    -- neuron at 0).
  , connections :: Map NodeId Double
    -- ^ Incoming edges: id of the source neuron mapped to the edge weight.
  , yHint :: Rational
    -- ^ Position hint copied from the node gene; 'stepNeuron' carries it
    -- through unchanged.
  } deriving (Show)
-- | A runnable network: all of its neurons plus the ids used for I/O.
data Network = Network
  { netInputs :: [NodeId]
    -- ^ Ids whose activations 'stepNetwork' overwrites with the supplied
    -- inputs, pairwise and in list order.
  , netOutputs :: [NodeId]
    -- ^ Ids whose activations 'getOutput' reports, in list order.
  , netState :: Map NodeId Neuron
    -- ^ Every neuron, keyed by its node id.
  , netDepth :: Int
    -- ^ For networks built by 'mkPhenotype', the number of distinct node
    -- @yHint@ values; 'snapshot' derives its step count from it.
  } deriving (Show)
-- | Recompute one neuron's activation from the given activation table.
--
-- The new activation is 'modSig' of the weighted sum of the neuron's incoming
-- edges, summed left to right in ascending source-id order. 'connections' and
-- 'yHint' are kept as they are.
--
-- Partial: every source id in 'connections' must be a key of the table
-- ('M.!' errors otherwise).
stepNeuron :: Map NodeId Double -> Neuron -> Neuron
stepNeuron acts neuron =
  neuron { activation = modSig (foldl' (+) 0 contributions) }
  where
    contributions =
      [ (acts M.! src) * weight | (src, weight) <- M.toList (connections neuron) ]
-- | Advance every neuron by one synchronous step.
--
-- The supplied inputs, followed by a constant 1, are written over the
-- activations of 'netInputs' pairwise (extra values on either side are
-- dropped). Every neuron is then recomputed with 'stepNeuron' against that
-- single activation table. Because all neurons read the same table, a signal
-- advances across one connection per call.
stepNetwork :: Network -> [Double] -> Network
stepNetwork net ins = net { netState = M.map (stepNeuron clamped) state }
  where
    state = netState net
    -- NOTE(review): the appended 1 looks like a constant bias input, which
    -- assumes the bias node is the input id right after the real inputs —
    -- confirm against Neet.Genome.
    fed = M.fromList (zip (netInputs net) (ins ++ [1]))
    -- Left-biased union: fed values replace the neurons' previous activations.
    clamped = fed `M.union` M.map activation state
-- | Apply 'stepNetwork' @netDepth - 1@ times, feeding the same inputs each time.
--
-- For networks built by 'mkPhenotype', 'netDepth' is the number of distinct
-- node @yHint@ values. This step count is presumably what a signal needs to
-- cross every layer of a feed-forward net.
-- NOTE(review): confirm that yHint encodes layer position.
--
-- Networks with a depth of 0 or 1 are returned unchanged.
snapshot :: Network -> [Double] -> Network
snapshot net ds = go (netDepth net - 1) net
  where
    -- Accumulator loop. A non-positive count stops immediately, so a depth-0
    -- network (e.g. from an empty genome) cannot recurse forever.
    go :: Int -> Network -> Network
    go n acc
      | n <= 0 = acc
      | otherwise = go (n - 1) $! stepNetwork acc ds
-- | Build the runnable network (phenotype) described by a genome.
--
-- Every node gene becomes a neuron with activation 0 and no connections. Each
-- enabled connection gene then adds an incoming edge (source id -> weight) to
-- its target neuron; later genes overwrite earlier ones for the same edge.
-- Disabled genes, and genes whose target id has no node, are ignored. Input
-- and output ids are listed in ascending key order. The depth is the number
-- of distinct @yHint@ values among the node genes.
mkPhenotype :: Genome -> Network
mkPhenotype g =
  Network
    { netInputs = idsOfType Input
    , netOutputs = idsOfType Output
    , netState = wiredState
    , netDepth = S.size depthSet
    }
  where
    nodes = nodeGenes g

    -- Node ids whose gene has the given type, in ascending key order.
    idsOfType t = M.keys (M.filter (\ng -> nodeType ng == t) nodes)

    depthSet :: Set Rational
    depthSet = S.fromList (map Neet.Genome.yHint (M.elems nodes))

    -- One unconnected, zero-activation neuron per node gene.
    blankState = M.map (\(NodeGene _ yh) -> Neuron 0 M.empty yh) nodes

    -- Connection genes are applied in key order.
    wiredState = M.foldl' wire blankState (connGenes g)

    -- M.adjust leaves the map untouched when the target id is absent.
    wire st cg
      | connEnabled cg =
          M.adjust (addIncoming (connIn cg) (connWeight cg)) (connOut cg) st
      | otherwise = st

    addIncoming src w neuron =
      neuron { connections = M.insert src w (connections neuron) }
-- | Current activation of every output neuron, in 'netOutputs' order.
--
-- Partial: an output id missing from 'netState' makes 'M.!' error. Networks
-- built by 'mkPhenotype' always contain every output id.
getOutput :: Network -> [Double]
getOutput net = [ activation (state M.! outId) | outId <- netOutputs net ]
  where
    state = netState net