#!/usr/bin/env python

########################################################
# Copyright (c) 2019, ArrayFire
# All rights reserved.
#
# This file is distributed under 3-clause BSD license.
# The complete license agreement can be obtained at:
# http://arrayfire.com/licenses/BSD-3-Clause
########################################################

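"""
Logistic regression on MNIST with ArrayFire.

Trains one-vs-all logistic regression classifiers with batch gradient
descent, reports accuracy on the training and test sets, and benchmarks
training and prediction times.
"""
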
from mnist_common import display_results, setup_mnist

import sys
import time

import arrayfire as af

def accuracy(predicted, target):
    # Percentage of samples whose most likely predicted class matches the label
    _, tlabels = af.imax(target, axis=1)
    _, plabels = af.imax(predicted, axis=1)
    return 100 * af.count(plabels == tlabels) / tlabels.size


def abserr(predicted, target):
    # Mean absolute difference between predictions and targets, as a percentage
    return 100 * af.sum(af.abs(predicted - target)) / predicted.size


# Predict (probability) based on given parameters
def predict_prob(X, Weights):
    Z = af.matmul(X, Weights)
    return af.sigmoid(Z)

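# The model evaluated above: for a feature row x and a per-class weight column
# w, the predicted probability is sigmoid(x . w) = 1 / (1 + exp(-(x . w))).
# predict_prob computes this for every sample and every class at once through
# the matrix product X @ Weights.
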

# Predict (log probability) based on given parameters
def predict_log_prob(X, Weights):
    return af.log(predict_prob(X, Weights))


# Give most likely class based on given parameters
def predict_class(X, Weights):
    probs = predict_prob(X, Weights)
    _, classes = af.imax(probs, axis=1)
    return classes


def cost(Weights, X, Y, lambda_param=1.0):
    # Number of samples
    m = Y.shape[0]

    dim0 = Weights.shape[0]
    dim1 = Weights.shape[1] if len(Weights.shape) > 1 else 1
    dim2 = Weights.shape[2] if len(Weights.shape) > 2 else 1
    dim3 = Weights.shape[3] if len(Weights.shape) > 3 else 1
    # Make the lambda corresponding to Weights(0) == 0
    lambdat = af.constant(lambda_param, (dim0, dim1, dim2, dim3))

    # No regularization for bias weights
    lambdat[0, :] = 0

    # Get the prediction
    H = predict_prob(X, Weights)

    # Cost of misprediction
    Jerr = -1 * af.sum(Y * af.log(H) + (1 - Y) * af.log(1 - H), axis=0)

    # Regularization cost
    Jreg = 0.5 * af.sum(lambdat * Weights * Weights, axis=0)

    # Total cost
    J = (Jerr + Jreg) / m

    # Find the gradient of cost
    D = (H - Y)
    dJ = (af.matmul(X, D, af.MatProp.TRANS) + lambdat * Weights) / m

    return J, dJ

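# Written out, cost() above returns, for each one-vs-all classifier (each
# column of Weights), the regularized logistic cost and its gradient:
#
#   J  = (1/m) * ( -sum_i[ y_i*log(h_i) + (1 - y_i)*log(1 - h_i) ]
#                  + (lambda/2) * sum_j[ W_j^2 ] )
#   dJ = (1/m) * ( X^T (H - Y) + lambda * W )
#
# where H = sigmoid(X W), m is the number of samples, and the bias weight
# (row 0 of Weights) is excluded from the regularization term.
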

def train(X, Y, alpha=0.1, lambda_param=1.0, maxerr=0.01, maxiter=1000, verbose=False):
    # Initialize parameters to 0
    Weights = af.constant(0, (X.shape[1], Y.shape[1]))

    for i in range(maxiter):
        # Get the cost and gradient
        J, dJ = cost(Weights, X, Y, lambda_param)

        err = af.max(af.abs(J))
        if err < maxerr:
            print('Iteration {0:4d} Err: {1:4f}'.format(i + 1, err))
            print('Training converged')
            return Weights

        if verbose and ((i + 1) % 10 == 0):
            print('Iteration {0:4d} Err: {1:4f}'.format(i + 1, err))

        # Update the parameters via gradient descent
        Weights = Weights - alpha * dJ

    if verbose:
        print('Training stopped after {0:d} iterations'.format(maxiter))

    return Weights

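# Note on the training loop: every column of Weights is an independent
# one-vs-all binary classifier, and all of them are updated together by the
# batch gradient-descent step Weights <- Weights - alpha * dJ. The loop stops
# early once the largest per-class cost falls below maxerr.
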

def benchmark_logistic_regression(train_feats, train_targets, test_feats):
    t0 = time.time()
    Weights = train(train_feats, train_targets, 0.1, 1.0, 0.01, 1000)
    af.eval(Weights)
    af.sync(-1)
    t1 = time.time()
    dt = t1 - t0
    print('Training time: {0:4.4f} s'.format(dt))

    t0 = time.time()
    iters = 100
    for i in range(iters):
        test_outputs = predict_prob(test_feats, Weights)
        af.eval(test_outputs)
    af.sync(-1)
    t1 = time.time()
    dt = t1 - t0
    print('Prediction time: {0:4.4f} s'.format(dt / iters))


# Demo of one vs all logistic regression
def logit_demo(console, perc):
    # Load mnist data
    frac = float(perc) / 100.0
    mnist_data = setup_mnist(frac, True)
    num_classes = mnist_data[0]
    num_train = mnist_data[1]
    num_test = mnist_data[2]
    train_images = mnist_data[3]
    test_images = mnist_data[4]
    train_targets = mnist_data[5]
    test_targets = mnist_data[6]

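    # setup_mnist hands back the images as multi-dimensional arrays; below,
    # each image is flattened into a single feature vector (28 * 28 = 784
    # values for the standard MNIST images) and the result is transposed so
    # that every row is one sample.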
    # Reshape images into feature vectors
    feature_length = int(train_images.size / num_train)
    train_feats = af.transpose(af.moddims(train_images, (feature_length, num_train)))
    test_feats = af.transpose(af.moddims(test_images, (feature_length, num_test)))

    train_targets = af.transpose(train_targets)
    test_targets = af.transpose(test_targets)

    num_train = train_feats.shape[0]
    num_test = test_feats.shape[0]

    # Add a bias that is always 1
    train_bias = af.constant(1, (num_train, 1))
    test_bias = af.constant(1, (num_test, 1))
    train_feats = af.join(1, train_bias, train_feats)
    test_feats = af.join(1, test_bias, test_feats)

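    # With the bias column joined in, every sample now has feature_length + 1
    # entries, so train() below returns a Weights matrix of shape
    # (feature_length + 1, num_classes), one column per digit class.
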
    # Train logistic regression parameters
    Weights = train(train_feats, train_targets,
                    0.1,    # learning rate
                    1.0,    # regularization constant
                    0.01,   # max error
                    1000,   # max iters
                    True)   # verbose mode
    af.eval(Weights)
    af.sync(-1)

    # Predict the results
    train_outputs = predict_prob(train_feats, Weights)
    test_outputs = predict_prob(test_feats, Weights)

    print('Accuracy on training data: {0:2.2f}'.format(accuracy(train_outputs, train_targets)))
    print('Accuracy on testing data: {0:2.2f}'.format(accuracy(test_outputs, test_targets)))
    print('Mean absolute error on testing data: {0:2.2f}'.format(abserr(test_outputs, test_targets)))

    benchmark_logistic_regression(train_feats, train_targets, test_feats)

    if not console:
        test_outputs = af.transpose(test_outputs)
        # Get 20 random test images
        display_results(test_images, test_outputs, af.transpose(test_targets), 20, True)

def main():
    argc = len(sys.argv)

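    # Optional command-line arguments, as parsed below:
    #   argv[1]: ArrayFire device id to use (default 0)
    #   argv[2]: any token starting with '-' selects console mode, which skips
    #            the graphical display of results
    #   argv[3]: percentage of the MNIST data set to load (default 60)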
    device = int(sys.argv[1]) if argc > 1 else 0
    console = sys.argv[2][0] == '-' if argc > 2 else False
    perc = int(sys.argv[3]) if argc > 3 else 60

    try:
        af.set_device(device)
        af.info()
        logit_demo(console, perc)
    except Exception as e:
        print('Error: ', str(e))


if __name__ == '__main__':
    main()