[MLLIB] [spark-2352] Implementation of an Artificial Neural Network (ANN) #1290
Status: Closed

Changes from all commits (143 commits)
5bba9dd
Create ParallelANN.scala
bgreeven 5874743
Create GeneralizedSteepestDescendAlgorithm
bgreeven 8c3ff4a
Create TestParallelANN.scala
bgreeven 96a0970
Create TestParallelANNgraphics.scala
bgreeven 69b0e59
Update TestParallelANN.scala
bgreeven 3f528b9
Update TestParallelANN.scala
bgreeven b1972b1
Update TestParallelANNgraphics.scala
bgreeven dd79615
Update GeneralizedSteepestDescendAlgorithm
bgreeven 1f6de6a
Update ParallelANN.scala
bgreeven 011c10b
Update GeneralizedSteepestDescendAlgorithm
bgreeven e7e29aa
Update TestParallelANN.scala
bgreeven 100ad4b
Update TestParallelANNgraphics.scala
bgreeven c9fc3f4
Rename GeneralizedSteepestDescendAlgorithm to GeneralizedSteepestDesc…
bgreeven 78f99dc
Update TestParallelANNgraphics.scala
bgreeven 43103f0
Update and rename GeneralizedSteepestDescendAlgorithm.scala to Genera…
bgreeven 2ecc7d5
Update ParallelANN.scala
bgreeven d80fe63
Update TestParallelANN.scala
bgreeven 149a726
Update TestParallelANNgraphics.scala
bgreeven ace988e
Create mllib-ann.md
bgreeven 9f75f59
Update mllib-ann.md
bgreeven c81de0c
Update mllib-ann.md
bgreeven 3c456b5
Update mllib-ann.md
bgreeven 3807e73
Update mllib-ann.md
bgreeven 5236a9d
Update mllib-ann.md
bgreeven 443ea7e
Update and rename GeneralizedSteepestDescentAlgorithm.scala to Genera…
bgreeven 3466f95
Update ParallelANN.scala
bgreeven aed39c6
Update TestParallelANN.scala
bgreeven 1972c69
ANN test suite: learning XOR function
avulanov d04c1d6
Removing dependency on GeneralizedModel and Algorithm
avulanov bd4508b
Addressing reviewers comments: interface refactoring
avulanov c032476
Apache header
avulanov 3e90c4d
Update ArtificialNeuralNetwork.scala
bgreeven 71ca727
Update and rename TestParallelANN.scala to TestANN.scala
bgreeven 293d013
Delete TestParallelANNgraphics.scala
bgreeven 18ac979
Update ArtificialNeuralNetwork.scala
bgreeven daf1375
Update ANNSuite.scala
bgreeven 5e3345c
minor style fixes
avulanov 6c657c3
Forward propagation code sharing
avulanov 5ab0263
Update ArtificialNeuralNetwork.scala
bgreeven 577a13a
Update ANNSuite.scala
bgreeven d048878
Delete TestANN.scala
bgreeven 90195fa
Create ANNDemo.scala
bgreeven 7c90249
Update mllib-ann.md
bgreeven 87f630b
Update mllib-ann.md
bgreeven 986f37a
Update ArtificialNeuralNetwork.scala
bgreeven 8e3e2d5
Update ArtificialNeuralNetwork.scala
bgreeven d2b80fe
Update ArtificialNeuralNetwork.scala
bgreeven 1a1c10b
Update ArtificialNeuralNetwork.scala
bgreeven 2a9554b
Update mllib-ann.md
bgreeven 40197ef
Update ANNDemo.scala
bgreeven 589205f
Update ArtificialNeuralNetwork.scala
bgreeven 6390947
Update ANNSuite.scala
bgreeven abfb0f5
Update ArtificialNeuralNetwork.scala
bgreeven 039df76
Update ArtificialNeuralNetwork.scala
bgreeven aff66ae
Update ArtificialNeuralNetwork.scala
bgreeven e78dcd6
Minor style fixes
avulanov ccbed58
Unit test parameter
avulanov e3dc003
Update ANNSuite.scala
bgreeven dd47d75
ANN classifier draft
avulanov 3e7eca1
Update ArtificialNeuralNetwork.scala
bgreeven f8d5a05
Update ANNSuite.scala
bgreeven 57b9147
XOR classification test with draft
avulanov c189bb2
ANN classifier refactoring in progress: need random weight function
avulanov c4baf79
Minor stylefix, add additional function for customized initial weights
avulanov d0836ed
Model as a parameters for classifier
avulanov 01bbca0
Scala style fix
avulanov c7e5323
Encoding of output with 0.1 and 0.9 by bgreeven suggestion
avulanov 90f5ae5
Addressing bgreeven comment regarding labels sort, annotations
avulanov 243e667
Create ParallelANN.scala
bgreeven 96ba82a
Create GeneralizedSteepestDescendAlgorithm
bgreeven 576ef79
Create TestParallelANN.scala
bgreeven c5cb54d
Create TestParallelANNgraphics.scala
bgreeven 1af7f25
Update TestParallelANN.scala
bgreeven 99f0581
Update TestParallelANN.scala
bgreeven b01fc3c
Update TestParallelANNgraphics.scala
bgreeven cae6dc2
Update GeneralizedSteepestDescendAlgorithm
bgreeven 9eee6f1
Update ParallelANN.scala
bgreeven fec8691
Update GeneralizedSteepestDescendAlgorithm
bgreeven 060ae3a
Update TestParallelANN.scala
bgreeven d1619c8
Update TestParallelANNgraphics.scala
bgreeven 7c3a5b3
Rename GeneralizedSteepestDescendAlgorithm to GeneralizedSteepestDesc…
bgreeven fef4776
Update TestParallelANNgraphics.scala
bgreeven c086751
Update and rename GeneralizedSteepestDescendAlgorithm.scala to Genera…
bgreeven 21d95d0
Update ParallelANN.scala
bgreeven d4764a4
Update TestParallelANN.scala
bgreeven 4623f25
Update TestParallelANNgraphics.scala
bgreeven 10242b7
Create mllib-ann.md
bgreeven 402ad79
Update mllib-ann.md
bgreeven 07218eb
Update mllib-ann.md
bgreeven f7cfa4e
Update mllib-ann.md
bgreeven d3211db
Update mllib-ann.md
bgreeven 51ca78b
Update mllib-ann.md
bgreeven ceaf2f7
Update and rename GeneralizedSteepestDescentAlgorithm.scala to Genera…
bgreeven 6f79c96
Update ParallelANN.scala
bgreeven 2972747
Update TestParallelANN.scala
bgreeven 6740981
ANN test suite: learning XOR function
avulanov c22c3dc
Removing dependency on GeneralizedModel and Algorithm
avulanov d320d76
Addressing reviewers comments: interface refactoring
avulanov 181c29b
Apache header
avulanov 7ac9a67
Update ArtificialNeuralNetwork.scala
bgreeven 8e0dc8b
Update and rename TestParallelANN.scala to TestANN.scala
bgreeven 0a3fca6
Delete TestParallelANNgraphics.scala
bgreeven 50ca819
Update ArtificialNeuralNetwork.scala
bgreeven c2da9b0
Update ANNSuite.scala
bgreeven 73ba0dc
minor style fixes
avulanov a024c6b
Forward propagation code sharing
avulanov 95e5299
Update ArtificialNeuralNetwork.scala
bgreeven 5a3531b
Update ANNSuite.scala
bgreeven 85050ba
Delete TestANN.scala
bgreeven 5f51305
Create ANNDemo.scala
bgreeven 95ed2a2
Update mllib-ann.md
bgreeven 4b83de4
Update mllib-ann.md
bgreeven 7828327
Update ArtificialNeuralNetwork.scala
bgreeven 84ac2e8
Update ArtificialNeuralNetwork.scala
bgreeven e2e94b2
Update ArtificialNeuralNetwork.scala
bgreeven a7fb749
Update ArtificialNeuralNetwork.scala
bgreeven 3995be8
Update mllib-ann.md
bgreeven b44aec3
Update ANNDemo.scala
bgreeven 6265bd6
Update ArtificialNeuralNetwork.scala
bgreeven 325ffab
Update ANNSuite.scala
bgreeven 099ff85
Update ArtificialNeuralNetwork.scala
bgreeven 1c0aab4
Update ArtificialNeuralNetwork.scala
bgreeven 5db2b60
Update ArtificialNeuralNetwork.scala
bgreeven b13019a
Minor style fixes
avulanov e2d4e92
Unit test parameter
avulanov fefe08e
Update ANNSuite.scala
bgreeven 2fbbe23
Update ArtificialNeuralNetwork.scala
bgreeven 57565ae
Update ANNSuite.scala
bgreeven bd74834
Minor stylefix, add additional function for customized initial weights
avulanov 12fb903
Update mllib-ann.md
bgreeven a0d1da0
Update ArtificialNeuralNetwork.scala
bgreeven 9b10666
Update mllib-ann.md
bgreeven 3cf5f9b
Fixes after rebase
avulanov 398e3dd
Matrix form of back-propagation based on avulanov/spark/tree/neuralne…
avulanov 62b1d91
Fix of broken gradient test
avulanov 3f93a2a
Roll/unroll ordering, weight by layer function
avulanov 2fb67f6
Roll and cumulative update optimizations
avulanov 799b277
Update ANNSuite.scala
loachli 6166ad9
Batch ANN
avulanov 5205fda
ANN Classifier batch
avulanov e660ee8
Divisor fix, train interfaces
avulanov d18e9b5
Test Context fix
avulanov 5de5bad
Bias averaging fix
avulanov
---
layout: global
title: Artificial Neural Networks - MLlib
displayTitle: <a href="mllib-guide.html">MLlib</a> - Artificial Neural Networks
---
# Introduction

This document describes MLlib's Artificial Neural Network (ANN) implementation.

The implementation currently consists of the following files:

* `ArtificialNeuralNetwork.scala`: implements the ANN
* `ANNSuite`: implements automated tests for the ANN and its gradient
* `ANNDemo`: a demo that approximates three functions and shows a graphical representation of the result
# Summary of usage

The `ArtificialNeuralNetwork` object is used as an interface to the neural network. It is
called as follows:

```
val annModel = ArtificialNeuralNetwork.train(rdd, hiddenLayersTopology, maxNumIterations)
```

where

* `rdd` is an RDD of type (Vector,Vector), the first element containing the input vector and
the second the associated output vector.
* `hiddenLayersTopology` is an array of integers (Array[Int]) containing the number of
nodes per hidden layer, starting with the layer that takes inputs from the input layer and
finishing with the layer that outputs to the output layer. The bias nodes are not counted.
* `maxNumIterations` is an upper bound on the number of iterations to be performed.
* `annModel` contains the trained ANN parameters, and can be used to calculate the ANN's
approximation for arbitrary input values.

The approximations can be calculated as follows:

```
val v_out = annModel.predict(v_in)
```

where `v_in` is either a Vector or an RDD of Vectors, and `v_out` is respectively a Vector or
an RDD of (Vector,Vector) pairs, corresponding to input and output values.

Further details and other calling options are elaborated upon below.
# Architecture and Notation

The file `ArtificialNeuralNetwork.scala` implements the ANN. The following picture shows the
architecture of a 3-layer ANN:

```
 +-------+
 |       |
 | N_0,0 |
 |       |
 +-------+       +-------+
                 |       |
 +-------+       | N_0,1 |       +-------+
 |       |       |       |       |       |
 | N_1,0 |-      +-------+     ->| N_0,2 |
 |       | \ Wij1              / |       |
 +-------+  --   +-------+   --  +-------+
            \    |       |  / Wjk2
    :        --->| N_1,1 |--     +-------+
    :            |       |       |       |
    :            +-------+       | N_1,2 |
    :                            |       |
    :               :            +-------+
    :               :
    :               :               :
    :               :
    :               :            +-------+
    :               :            |       |
    :               :            |N_K-1,2|
    :                            |       |
    :            +-------+       +-------+
    :            |       |
    :            |N_J-1,1|
                 |       |
 +-------+       +-------+
 |       |
 |N_I-1,0|
 |       |
 +-------+

 +-------+       +--------+
 |       |       |        |
 |  -1   |       |   -1   |
 |       |       |        |
 +-------+       +--------+

 INPUT LAYER     HIDDEN LAYER    OUTPUT LAYER
```
The i-th node in layer l is denoted by N_{i,l}, with both i and l starting from 0. The weight
between node i in layer l-1 and node j in layer l is denoted by W_{i,j,l}. Layer 0 is the input
layer, whereas layer L is the output layer.

The ANN also implements bias units. These are nodes that always output the value -1. The bias
units are present in all layers except the output layer. They act similarly to other nodes, but
have no input.

The value of node N_{j,l} is calculated as follows:

`$N_{j,l} = g( \sum_{i=0}^{topology_{l-1}} W_{i,j,l} N_{i,l-1} )$`

where g is the sigmoid function

`$g(t) = \frac{e^{\beta t}}{1+e^{\beta t}}$`
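The forward pass described above can be sketched in a few lines of plain Python. This is an illustrative sketch only, not the Spark API: the function names, the nested-list weight layout, and the default `beta` are assumptions made for the example. Each layer's value vector is extended with a bias unit of constant output -1, and every node applies the sigmoid g to a weighted sum of the previous layer's values:

```python
import math

def g(t, beta=1.0):
    """Sigmoid activation: g(t) = e^(beta*t) / (1 + e^(beta*t))."""
    return 1.0 / (1.0 + math.exp(-beta * t))

def feed_forward(weights, inputs, beta=1.0):
    """Forward pass through the network.

    weights[l][j][i] is the weight between node i in layer l and node j in
    layer l+1; the bias unit (constant output -1) is appended as the last
    entry of each non-output layer's value vector.
    """
    values = list(inputs)
    for layer in weights:
        values = values + [-1.0]  # append the bias unit
        values = [g(sum(w_ji[i] * values[i] for i in range(len(values))), beta)
                  for w_ji in layer]
    return values

# A tiny 2-input, 1-output network with a single weight row: all weights
# zero, so the output node computes g(0) = 0.5.
print(feed_forward([[[0.0, 0.0, 0.0]]], [0.0, 0.0]))  # [0.5]
```

The zero-weight example makes the bias handling visible: even the bias contribution is multiplied by a zero weight, so the weighted sum is 0 and the sigmoid returns 0.5.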
# LBFGS

MLlib's ANN implementation uses the LBFGS optimisation algorithm for training. It minimises the
following error function:

`$E = \sum_{k=0}^{K-1} (N_{k,L} - Y_k)^2$`

where Y_k is the target output given inputs N_{0,0} ... N_{I-1,0}.
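For a single training sample, this error is simply the sum of squared differences between the K output-node values and the targets. A minimal sketch in plain Python (illustrative only; the function name is an assumption, not part of the Spark API):

```python
def squared_error(outputs, targets):
    """E = sum_k (N_{k,L} - Y_k)^2 for one training sample."""
    return sum((n - y) ** 2 for n, y in zip(outputs, targets))

# Two output nodes, each off by 0.1 from its target.
print(squared_error([0.9, 0.1], [1.0, 0.0]))  # ~0.02
```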
# Implementation Details

## The `ArtificialNeuralNetwork` class

The `ArtificialNeuralNetwork` class has the following constructor:

```
class ArtificialNeuralNetwork private(topology: Array[Int], maxNumIterations: Int,
  convergenceTol: Double)
```

* `topology` is an array of integers indicating the number of nodes per layer. For example, if
`topology` holds (3, 5, 1), there are three input nodes, five nodes in a single hidden layer
and one output node.
* `maxNumIterations` is the maximum number of iterations, after which the LBFGS algorithm
stops.
* `convergenceTol` indicates the acceptable error; once it is reached, the LBFGS algorithm
stops. A lower value of `convergenceTol` gives a higher precision.
## The `ArtificialNeuralNetwork` object

The object `ArtificialNeuralNetwork` is the interface to the `ArtificialNeuralNetwork` class.
The object contains the training function. There are six different instances of the training
function, each for use with different parameters. All take as their first parameter the RDD
`input`, which contains pairs of input and output vectors.

In addition, there are three functions for generating random weights. Two take a fixed seed,
which is useful for testing if one wants to start with the same weights in every test.

* `def train(trainingRDD: RDD[(Vector, Vector)], hiddenLayersTopology: Array[Int],
maxNumIterations: Int): ArtificialNeuralNetworkModel`: starts training with random initial
weights and a default convergenceTol=1e-4.
* `def train(trainingRDD: RDD[(Vector, Vector)], model: ArtificialNeuralNetworkModel,
maxNumIterations: Int): ArtificialNeuralNetworkModel`: resumes training given an earlier
calculated model, with a default convergenceTol=1e-4.
* `def train(trainingRDD: RDD[(Vector,Vector)], hiddenLayersTopology: Array[Int],
initialWeights: Vector, maxNumIterations: Int): ArtificialNeuralNetworkModel`: trains an ANN
with given initial weights and a default convergenceTol=1e-4.
* `def train(trainingRDD: RDD[(Vector, Vector)], hiddenLayersTopology: Array[Int],
maxNumIterations: Int, convergenceTol: Double): ArtificialNeuralNetworkModel`: starts training
with random initial weights; allows setting a customised convergenceTol.
* `def train(trainingRDD: RDD[(Vector, Vector)], model: ArtificialNeuralNetworkModel,
maxNumIterations: Int, convergenceTol: Double): ArtificialNeuralNetworkModel`: resumes training
given an earlier calculated model; allows setting a customised convergenceTol.
* `def train(trainingRDD: RDD[(Vector,Vector)], hiddenLayersTopology: Array[Int],
initialWeights: Vector, maxNumIterations: Int, convergenceTol: Double):
ArtificialNeuralNetworkModel`: trains an ANN with given initial weights; allows setting a
customised convergenceTol.
* `def randomWeights(trainingRDD: RDD[(Vector,Vector)], hiddenLayersTopology: Array[Int]):
Vector`: generates a random weights vector.
* `def randomWeights(trainingRDD: RDD[(Vector,Vector)], hiddenLayersTopology: Array[Int],
seed: Int): Vector`: generates a random weights vector with a given seed.
* `def randomWeights(inputLayerSize: Int, outputLayerSize: Int, hiddenLayersTopology: Array[Int],
seed: Int): Vector`: generates a random weights vector from the given random seed, input layer
size, hidden layers topology and output layer size.

Notice that `hiddenLayersTopology` differs from the `topology` array: `hiddenLayersTopology`
does not include the number of nodes in the input and output layers. The number of nodes in the
input and output layers is instead derived from the first element of the training RDD. For
example, the `topology` array (3, 5, 7, 1) corresponds to the `hiddenLayersTopology` (5, 7);
the values 3 and 1 are deduced from the training data. The rationale for having these different
arrays is that future methods may have a different mapping between input values and input nodes,
or output values and output nodes.
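The relationship between `hiddenLayersTopology` and the full `topology` array can be sketched in plain Python. This is an illustrative sketch, not Spark code: the helper names are invented here, and the weight count assumes one bias unit per non-output layer, as described in the architecture section above.

```python
def full_topology(input_size, hidden_layers_topology, output_size):
    """Build the full topology array: the input and output sizes are
    deduced from the first training sample, the rest comes from
    hiddenLayersTopology."""
    return [input_size] + list(hidden_layers_topology) + [output_size]

def num_weights(topology):
    """Total number of weights: each node in layer l connects to every
    node of layer l-1 plus that layer's bias unit."""
    return sum((topology[l - 1] + 1) * topology[l]
               for l in range(1, len(topology)))

print(full_topology(3, [5, 7], 1))  # [3, 5, 7, 1]
print(num_weights([3, 5, 7, 1]))    # (3+1)*5 + (5+1)*7 + (7+1)*1 = 70
```

A weight-count helper like this shows why the initial weights vector handed to `train` must have a length that matches the topology exactly.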
## The `ArtificialNeuralNetworkModel` class

All training functions return the trained ANN as an instance of the class
`ArtificialNeuralNetworkModel`. This class has the following functions:

* `predict(testData: Vector): Vector`: calculates the output vector given the input vector
`testData`.
* `predict(testData: RDD[Vector]): RDD[(Vector,Vector)]`: returns (input, output) vector pairs,
using the input vectors in `testData`.

The weights used by `predict` come from the model.
## Training

We have chosen to implement the ANN with LBFGS as the optimisation function, and compared it
with Stochastic Gradient Descent. LBFGS was much faster, but correspondingly also starts to
overfit earlier.

Many different strategies for training an ANN exist. Hence it is important that the optimising
functions in MLlib's ANN are interchangeable. A new optimisation strategy can be implemented by
creating a new class descending from ArtificialNeuralNetwork, and replacing the optimiser,
updater and possibly the gradient as required.
# Demo and tests

Usage of MLlib's ANN is demonstrated through the `ANNDemo` demo program. The program generates
three functions:

* f2d: x -> y
* f3d: (x,y) -> z
* f4d: t -> (x,y,z)

It calculates approximations of the target functions, and shows a graphical representation of
the training set and of the results after applying the testing set.

In addition, there are the following automated tests:

* "ANN learns XOR function": tests that the ANN can properly approximate the XOR function.
* "Gradient of ANN": tests that the output of the ANN gradient is roughly equal to a
numerically approximated gradient.
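The idea behind the "Gradient of ANN" test — comparing an analytical gradient against a finite-difference approximation — can be sketched in plain Python. This is an illustrative sketch only: it uses a hypothetical one-weight model rather than the full ANN gradient code.

```python
def numerical_gradient(f, w, eps=1e-6):
    """Central finite-difference approximation of df/dw."""
    return (f(w + eps) - f(w - eps)) / (2.0 * eps)

# Squared error of a one-weight "network" on one sample: E(w) = (w*x - y)^2.
x, y = 2.0, 1.0
error = lambda w: (w * x - y) ** 2
analytic = lambda w: 2.0 * (w * x - y) * x  # dE/dw by the chain rule

w = 0.3
# The analytical gradient should match the numerical one up to eps effects.
assert abs(numerical_gradient(error, w) - analytic(w)) < 1e-6
print(analytic(w))  # -1.6
```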
# Conclusion

The `ArtificialNeuralNetwork` class implements an Artificial Neural Network (ANN), trained with
the LBFGS algorithm. It takes as input an RDD of input/output values of type (Vector,Vector),
and returns an object of type `ArtificialNeuralNetworkModel` containing the parameters of the
trained ANN. The `ArtificialNeuralNetworkModel` object can also be used to calculate results
after training.

The training of an ANN can be interrupted and later continued, allowing intermediate inspection
of the results.

A demo program and tests for the ANN are provided.
---
How common is it to have different numbers of nodes in the hidden layers? I am wondering whether there should be support for a simpler method `def train(rdd, numNodesHiddenLayers, maxNumIterations)`, and perhaps even a simpler `def train(rdd)` with good default settings, to help users get started.

@mengxr @jkbradley Would the upcoming MLlib API feature make this suggestion moot with support for default parameters?
---
@manishamde given the size of the PR, @mengxr suggested to split it into multiple PRs. There is an implementation of a classifier that is based on this artificial neural network: https://github.com/avulanov/spark/tree/annclassifier. It employs `RDD[LabeledPoint]` and implements MLlib's `Classifier`. Softmax output and cross-entropy error are usually used for better classification performance, and they are not yet implemented. We've discussed this issue with @bgreeven, and our thinking is to have an interface in this PR that allows setting a different error function and optimizer. Does it make sense?
---
@avulanov Thanks for the clarification. Sounds good to me.
---
@manishamde: In most cases, one hidden layer is enough. For some special functions two hidden layers are needed. This is a nice text about the choice of number of layers and number of nodes per layer:
http://www.heatonresearch.com/node/707
Especially the number of nodes per layer depends heavily on the particular problem.
---
@bgreeven Thanks for the clarification.
---
@manishamde The upcoming API may make it a bit easier, but the current API here could support default parameters via a builder pattern for parameters. I'll take a closer look at this PR!