@@ -433,7 +433,59 @@ describe('Tensor', () => {
433433 } )
434434 } )
435435
436- test . each ( [ - 1 , 1 ] ) ( 'axis %i' , i => {
436+ describe ( 'axis 1' , ( ) => {
437+ test . each ( [ 0 , 1 , 2 ] ) ( 'scalar %p' , k => {
438+ const ten = Tensor . randn ( [ 3 , 4 , 5 ] )
439+ const slice = ten . select ( k , 1 )
440+ expect ( slice . sizes ) . toEqual ( [ 3 , 1 , 5 ] )
441+ for ( let i = 0 ; i < 3 ; i ++ ) {
442+ for ( let j = 0 ; j < 5 ; j ++ ) {
443+ expect ( slice . at ( i , 0 , j ) ) . toBe ( ten . at ( i , k , j ) )
444+ }
445+ }
446+ } )
447+
448+ test . each ( [ [ [ 0 ] ] , [ [ 1 ] ] , [ [ 2 ] ] , [ [ 0 , 0 ] ] , [ [ 1 , 2 ] ] , [ [ 2 , 0 ] ] ] ) ( 'array %p' , k => {
449+ const ten = Tensor . randn ( [ 3 , 4 , 5 ] )
450+ const slice = ten . select ( k , 1 )
451+ expect ( slice . sizes ) . toEqual ( [ 3 , k . length , 5 ] )
452+ for ( let t = 0 ; t < k . length ; t ++ ) {
453+ for ( let i = 0 ; i < 3 ; i ++ ) {
454+ for ( let j = 0 ; j < 5 ; j ++ ) {
455+ expect ( slice . at ( i , t , j ) ) . toBe ( ten . at ( i , k [ t ] , j ) )
456+ }
457+ }
458+ }
459+ } )
460+ } )
461+
462+ describe ( 'axis 2' , ( ) => {
463+ test . each ( [ 0 , 1 , 2 ] ) ( 'scalar %p' , k => {
464+ const ten = Tensor . randn ( [ 3 , 4 , 5 ] )
465+ const slice = ten . select ( k , 2 )
466+ expect ( slice . sizes ) . toEqual ( [ 3 , 4 , 1 ] )
467+ for ( let i = 0 ; i < 3 ; i ++ ) {
468+ for ( let j = 0 ; j < 4 ; j ++ ) {
469+ expect ( slice . at ( i , j , 0 ) ) . toBe ( ten . at ( i , j , k ) )
470+ }
471+ }
472+ } )
473+
474+ test . each ( [ [ [ 0 ] ] , [ [ 1 ] ] , [ [ 2 ] ] , [ [ 0 , 0 ] ] , [ [ 1 , 2 ] ] , [ [ 2 , 0 ] ] ] ) ( 'array %p' , k => {
475+ const ten = Tensor . randn ( [ 3 , 4 , 5 ] )
476+ const slice = ten . select ( k , 2 )
477+ expect ( slice . sizes ) . toEqual ( [ 3 , 4 , k . length ] )
478+ for ( let t = 0 ; t < k . length ; t ++ ) {
479+ for ( let i = 0 ; i < 3 ; i ++ ) {
480+ for ( let j = 0 ; j < 4 ; j ++ ) {
481+ expect ( slice . at ( i , j , t ) ) . toBe ( ten . at ( i , j , k [ t ] ) )
482+ }
483+ }
484+ }
485+ } )
486+ } )
487+
488+ test . each ( [ - 1 , 3 ] ) ( 'axis %i' , i => {
437489 const ten = new Tensor ( [ 2 , 3 , 4 ] )
438490 expect ( ( ) => ten . select ( 0 , i ) ) . toThrow ( 'Invalid axis.' )
439491 } )
@@ -618,7 +670,61 @@ describe('Tensor', () => {
618670 }
619671 } )
620672
621- test . each ( [ - 1 , 1 ] ) ( 'fail invalid axis %p' , axis => {
673+ test ( 'axis 1' , ( ) => {
674+ const org = Tensor . randn ( [ 3 , 4 , 5 ] )
675+ const ten = org . copy ( )
676+ ten . shuffle ( 1 )
677+
678+ const expidx = [ ]
679+ for ( let t = 0 ; t < org . sizes [ 1 ] ; t ++ ) {
680+ for ( let i = 0 ; i < org . sizes [ 1 ] ; i ++ ) {
681+ let flg = true
682+ for ( let j = 0 ; j < org . sizes [ 0 ] ; j ++ ) {
683+ for ( let k = 0 ; k < org . sizes [ 2 ] ; k ++ ) {
684+ flg &= ten . at ( j , t , k ) === org . at ( j , i , k )
685+ }
686+ }
687+ if ( flg ) {
688+ expidx . push ( i )
689+ break
690+ }
691+ }
692+ }
693+ expidx . sort ( ( a , b ) => a - b )
694+ expect ( expidx ) . toHaveLength ( org . sizes [ 1 ] )
695+ for ( let i = 0 ; i < org . sizes [ 1 ] ; i ++ ) {
696+ expect ( expidx [ i ] ) . toBe ( i )
697+ }
698+ } )
699+
700+ test ( 'axis 2' , ( ) => {
701+ const org = Tensor . randn ( [ 3 , 4 , 5 ] )
702+ const ten = org . copy ( )
703+ ten . shuffle ( 2 )
704+
705+ const expidx = [ ]
706+ for ( let t = 0 ; t < org . sizes [ 2 ] ; t ++ ) {
707+ for ( let i = 0 ; i < org . sizes [ 2 ] ; i ++ ) {
708+ let flg = true
709+ for ( let j = 0 ; j < org . sizes [ 0 ] ; j ++ ) {
710+ for ( let k = 0 ; k < org . sizes [ 1 ] ; k ++ ) {
711+ flg &= ten . at ( j , k , t ) === org . at ( j , k , i )
712+ }
713+ }
714+ if ( flg ) {
715+ expidx . push ( i )
716+ break
717+ }
718+ }
719+ }
720+ expidx . sort ( ( a , b ) => a - b )
721+ expect ( expidx ) . toHaveLength ( org . sizes [ 2 ] )
722+ for ( let i = 0 ; i < org . sizes [ 2 ] ; i ++ ) {
723+ expect ( expidx [ i ] ) . toBe ( i )
724+ }
725+ } )
726+
727+ test . each ( [ - 1 , 3 ] ) ( 'fail invalid axis %p' , axis => {
622728 const mat = Tensor . randn ( [ 2 , 3 , 4 ] )
623729 expect ( ( ) => mat . shuffle ( axis ) ) . toThrow ( 'Invalid axis.' )
624730 } )
0 commit comments