/* * Created on Aug 5, 2005 * */ package aima.learning.statistics; import java.util.ArrayList; import java.util.Hashtable; import java.util.Iterator; import java.util.List; import aima.util.Util; public class StandardBackPropogation implements NeuralNetworkTrainingScheme { private Hashtable neuronDeltaMap,neuronBiasMap; private Hashtable linkWeightMap; private double learningRate; public StandardBackPropogation() { neuronDeltaMap = new Hashtable(); neuronBiasMap = new Hashtable(); linkWeightMap = new Hashtable(); learningRate = 0.1; } public void backPropogate(FeedForwardNetwork network, List input, List correctOutput) { network.propogateInput(input); calculateOutputLayerDelta(network, correctOutput); calculateHiddenLayersDelta(network); } private void calculateOutputLayerDelta(FeedForwardNetwork network, List correctOutput) { Layer outputLayer = network.getOutputLayer(); Iterator neuronIter = outputLayer.iterator(); Iterator errorIter = outputLayer.getError(correctOutput) .iterator(); while (neuronIter.hasNext() && errorIter.hasNext()) { // multiplied by -1 because the error calculationis inverted from the book neuronDeltaMap.put(neuronIter.next(), -1 * errorIter.next()); } } private void calculateHiddenLayersDelta(FeedForwardNetwork network) { List hiddenLayers = network.getHiddenLayers(); for (int i = hiddenLayers.size() - 1;i> -1;i-- ){ Layer layer = hiddenLayers.get(i); for (Neuron neuron :layer.getNeurons()){ double weightsum =0.0; for (Link l :neuron.outLinks()){ weightsum += (l.weight()*neuronDeltaMap.get(l.target())); } neuronDeltaMap.put(neuron,weightsum * neuron.getActivationFuncton().deriv(neuron.activation())); } } } public List delta(Layer l){ List list = new ArrayList(); for (Neuron n :l.getNeurons()){ list.add(neuronDeltaMap.get(n)); } return list; } public void updateWeightsAndBiases(FeedForwardNetwork network) { for (Neuron n :network.getOutputLayer().getNeurons()){ updateWeightsAndBiases(n); } for (Layer hiddenLayer:network.getHiddenLayers()){ for (Neuron n :hiddenLayer.getNeurons()){ updateWeightsAndBiases(n); } } } private void updateWeightsAndBiases(Neuron n) { double delta = neuronDeltaMap.get(n); neuronBiasMap.put(n,n.bias()); n.setBias(n.bias() - learningRate* delta) ; for (Link link : n.inLinks()){ linkWeightMap.put(link,link.weight()); double weightChange = (learningRate * delta * link.source().activation()); link.setWeight(link.weight() - weightChange); } } public double error(List expectedOutput,FeedForwardNetwork network) { return Util.sumOfSquares(network.error(expectedOutput)); } }