@@ -6,18 +6,34 @@ import { ConfidenceWeighted, SoftConfidenceWeighted } from '../../../lib/model/c
66
77import { accuracy } from '../../../lib/evaluate/classification.js'
88
9- test ( 'ConfidenceWeighted' , ( ) => {
10- const model = new ConfidenceWeighted ( 0.9 )
11- const x = Matrix . concat ( Matrix . randn ( 50 , 2 , 0 , 0.2 ) , Matrix . randn ( 50 , 2 , 5 , 0.2 ) ) . toArray ( )
12- const t = [ ]
13- for ( let i = 0 ; i < x . length ; i ++ ) {
14- t [ i ] = Math . floor ( i / 50 ) * 2 - 1
15- }
16- model . init ( x , t )
17- model . fit ( )
18- const y = model . predict ( x )
19- const acc = accuracy ( y , t )
20- expect ( acc ) . toBeGreaterThan ( 0.95 )
9+ describe ( 'ConfidenceWeighted' , ( ) => {
10+ test ( 'normal eta' , ( ) => {
11+ const model = new ConfidenceWeighted ( 0.9 )
12+ const x = Matrix . concat ( Matrix . randn ( 50 , 2 , 0 , 0.2 ) , Matrix . randn ( 50 , 2 , 5 , 0.2 ) ) . toArray ( )
13+ const t = [ ]
14+ for ( let i = 0 ; i < x . length ; i ++ ) {
15+ t [ i ] = Math . floor ( i / 50 ) * 2 - 1
16+ }
17+ model . init ( x , t )
18+ model . fit ( )
19+ const y = model . predict ( x )
20+ const acc = accuracy ( y , t )
21+ expect ( acc ) . toBeGreaterThan ( 0.95 )
22+ } )
23+
24+ test . each ( [ 1 , 0.5 , 0 ] ) ( 'eta %p' , eta => {
25+ const model = new ConfidenceWeighted ( eta )
26+ const x = Matrix . concat ( Matrix . randn ( 50 , 2 , 0 , 0.2 ) , Matrix . randn ( 50 , 2 , 5 , 0.2 ) ) . toArray ( )
27+ const t = [ ]
28+ for ( let i = 0 ; i < x . length ; i ++ ) {
29+ t [ i ] = Math . floor ( i / 50 ) * 2 - 1
30+ }
31+ model . init ( x , t )
32+ model . fit ( )
33+ const y = model . predict ( x )
34+ const acc = accuracy ( y , t )
35+ expect ( acc ) . toBeCloseTo ( 0.5 )
36+ } )
2137} )
2238
2339test . each ( [ 1 , 2 ] ) ( 'SoftConfidenceWeighted %d' , version => {
0 commit comments