@@ -2724,6 +2724,14 @@ public static Tensor<T> PermuteDimensions<T>(this Tensor<T> tensor, params ReadO
27242724 /// <param name="lengths"><see cref="ReadOnlySpan{T}"/> with the new dimensions.</param>
27252725 public static Tensor < T > Reshape < T > ( this Tensor < T > tensor , params ReadOnlySpan < nint > lengths )
27262726 {
2727+ if ( tensor . Lengths . SequenceEqual ( lengths ) )
2728+ return tensor ;
2729+
2730+ if ( ! TensorHelpers . IsContiguousAndDense < T > ( tensor ) & & ! tensor . Strides . Contains ( 0 ) )
2731+ {
2732+ ThrowHelper . ThrowArgument_CannotReshapeNonContiguousOrDense ( ) ;
2733+ }
2734+
27272735 nint [ ] arrLengths = lengths . ToArray ( ) ;
27282736 // Calculate wildcard info.
27292737 if ( lengths . Contains ( - 1 ) )
@@ -2745,7 +2753,33 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, params ReadOnlySpan<ni
27452753 nint tempLinear = TensorSpanHelpers . CalculateTotalLength ( arrLengths ) ;
27462754 if ( tempLinear != tensor . FlattenedLength )
27472755 ThrowHelper . ThrowArgument_InvalidReshapeDimensions ( ) ;
2748- nint [ ] strides = TensorSpanHelpers . CalculateStrides ( arrLengths ) ;
2756+
2757+ nint [ ] strides ;
2758+
2759+ // If we contain a 0 stride we can only add dimensions of length 1.
2760+ if ( tensor . Strides . Contains ( 0 ) )
2761+ {
2762+ List < nint > origStrides = new List < nint > ( tensor . Strides . ToArray ( ) ) ;
2763+ int lengthOffset = 0 ;
2764+ for ( int i = 0 ; i < arrLengths . Length ; i ++ )
2765+ {
2766+ if ( lengthOffset < tensor . Rank && arrLengths [ i ] == tensor . Lengths [ lengthOffset ] )
2767+ lengthOffset ++ ;
2768+ else if ( arrLengths [ i ] == 1 )
2769+ {
2770+ if ( lengthOffset == tensor . Rank )
2771+ origStrides . Add ( tensor . Strides [ lengthOffset - 1 ] ) ;
2772+ else
2773+ origStrides . Insert ( i , tensor . Strides [ i ] * tensor . Lengths [ i ] ) ;
2774+ }
2775+ else
2776+ ThrowHelper . ThrowArgument_InvalidReshapeDimensions ( ) ;
2777+ }
2778+ strides = origStrides . ToArray ( ) ;
2779+ }
2780+ else
2781+ strides = TensorSpanHelpers . CalculateStrides ( arrLengths ) ;
2782+
27492783 return new Tensor < T > ( tensor . _values , arrLengths , strides ) ;
27502784 }
27512785
@@ -2758,6 +2792,14 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, params ReadOnlySpan<ni
27582792 /// <param name="lengths"><see cref="ReadOnlySpan{T}"/> with the new dimensions.</param>
27592793 public static TensorSpan < T > Reshape < T > ( in this TensorSpan < T > tensor , params scoped ReadOnlySpan < nint > lengths )
27602794 {
2795+ if ( tensor . Lengths . SequenceEqual ( lengths ) )
2796+ return tensor ;
2797+
2798+ if ( ! TensorHelpers . IsContiguousAndDense < T > ( tensor ) & & ! tensor . Strides . Contains ( 0 ) )
2799+ {
2800+ ThrowHelper . ThrowArgument_CannotReshapeNonContiguousOrDense ( ) ;
2801+ }
2802+
27612803 nint [ ] arrLengths = lengths . ToArray ( ) ;
27622804 // Calculate wildcard info.
27632805 if ( lengths . Contains ( - 1 ) )
@@ -2779,7 +2821,35 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, params scop
27792821 nint tempLinear = TensorSpanHelpers . CalculateTotalLength ( arrLengths ) ;
27802822 if ( tempLinear != tensor . FlattenedLength )
27812823 ThrowHelper . ThrowArgument_InvalidReshapeDimensions ( ) ;
2782- nint [ ] strides = TensorSpanHelpers . CalculateStrides ( arrLengths ) ;
2824+
2825+ nint [ ] strides ;
2826+
2827+ // If we contain a 0 stride we can only add dimensions of length 1.
2828+ if ( tensor . Strides . Contains ( 0 ) )
2829+ {
2830+ List < nint > origStrides = new List < nint > ( tensor . Strides . ToArray ( ) ) ;
2831+ int lengthOffset = 0 ;
2832+ for ( int i = 0 ; i < arrLengths . Length ; i ++ )
2833+ {
2834+ if ( lengthOffset < tensor . Rank && arrLengths [ i ] == tensor . Lengths [ lengthOffset ] )
2835+ {
2836+ lengthOffset ++ ;
2837+ }
2838+ else if ( arrLengths [ i ] == 1 )
2839+ {
2840+ if ( lengthOffset == tensor . Rank )
2841+ origStrides . Add ( tensor . Strides [ lengthOffset - 1 ] ) ;
2842+ else
2843+ origStrides . Insert ( i , tensor . Strides [ i ] * tensor . Lengths [ i ] ) ;
2844+ }
2845+ else
2846+ ThrowHelper . ThrowArgument_InvalidReshapeDimensions ( ) ;
2847+ }
2848+ strides = origStrides . ToArray ( ) ;
2849+ }
2850+ else
2851+ strides = TensorSpanHelpers . CalculateStrides ( arrLengths ) ;
2852+
27832853 TensorSpan < T > output = new TensorSpan < T > ( ref tensor . _reference , arrLengths , strides , tensor . _shape . _memoryLength ) ;
27842854 return output ;
27852855 }
@@ -2793,6 +2863,14 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, params scop
27932863 /// <param name="lengths"><see cref="ReadOnlySpan{T}"/> with the new dimensions.</param>
27942864 public static ReadOnlyTensorSpan < T > Reshape < T > ( in this ReadOnlyTensorSpan < T > tensor , params scoped ReadOnlySpan < nint > lengths )
27952865 {
2866+ if ( tensor . Lengths . SequenceEqual ( lengths ) )
2867+ return tensor ;
2868+
2869+ if ( ! TensorHelpers . IsContiguousAndDense < T > ( tensor ) & & ! tensor . Strides . Contains ( 0 ) )
2870+ {
2871+ ThrowHelper . ThrowArgument_CannotReshapeNonContiguousOrDense ( ) ;
2872+ }
2873+
27962874 nint [ ] arrLengths = lengths . ToArray ( ) ;
27972875 // Calculate wildcard info.
27982876 if ( lengths . Contains ( - 1 ) )
@@ -2814,7 +2892,33 @@ public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> ten
28142892 nint tempLinear = TensorSpanHelpers . CalculateTotalLength ( arrLengths ) ;
28152893 if ( tempLinear != tensor . FlattenedLength )
28162894 ThrowHelper . ThrowArgument_InvalidReshapeDimensions ( ) ;
2817- nint [ ] strides = TensorSpanHelpers . CalculateStrides ( arrLengths ) ;
2895+
2896+ nint [ ] strides ;
2897+
2898+ // If we contain a 0 stride we can only add dimensions of length 1.
2899+ if ( tensor . Strides . Contains ( 0 ) )
2900+ {
2901+ List < nint > origStrides = new List < nint > ( tensor . Strides . ToArray ( ) ) ;
2902+ int lengthOffset = 0 ;
2903+ for ( int i = 0 ; i < arrLengths . Length ; i ++ )
2904+ {
2905+ if ( lengthOffset < tensor . Rank && arrLengths [ i ] == tensor . Lengths [ lengthOffset ] )
2906+ lengthOffset ++ ;
2907+ else if ( arrLengths [ i ] == 1 )
2908+ {
2909+ if ( lengthOffset == tensor . Rank )
2910+ origStrides . Add ( tensor . Strides [ lengthOffset - 1 ] ) ;
2911+ else
2912+ origStrides . Insert ( i , tensor . Strides [ i ] * tensor . Lengths [ i ] ) ;
2913+ }
2914+ else
2915+ ThrowHelper . ThrowArgument_InvalidReshapeDimensions ( ) ;
2916+ }
2917+ strides = origStrides . ToArray ( ) ;
2918+ }
2919+ else
2920+ strides = TensorSpanHelpers . CalculateStrides ( arrLengths ) ;
2921+
28182922 ReadOnlyTensorSpan < T > output = new ReadOnlyTensorSpan < T > ( ref tensor . _reference , arrLengths , strides , tensor . _shape . _memoryLength ) ;
28192923 return output ;
28202924 }
@@ -3053,14 +3157,17 @@ public static ref readonly TensorSpan<T> SetSlice<T>(this in TensorSpan<T> tenso
30533157 TensorSpan < T > srcSpan ;
30543158 if ( ranges == ReadOnlySpan < NRange > . Empty )
30553159 {
3056- if ( ! tensor . Lengths . SequenceEqual ( values . Lengths ) )
3160+ if ( ! TensorHelpers . IsBroadcastableTo ( values . Lengths , tensor . Lengths ) )
30573161 ThrowHelper . ThrowArgument_SetSliceNoRange ( nameof ( values ) ) ;
3058- srcSpan = tensor . Slice ( tensor . Lengths ) ;
3162+ srcSpan = tensor ;
30593163 }
30603164 else
30613165 srcSpan = tensor . Slice ( ranges ) ;
30623166
3063- if ( ! srcSpan . Lengths . SequenceEqual ( values . Lengths ) )
3167+ if ( ! TensorHelpers . IsContiguousAndDense < T > ( srcSpan ) )
3168+ ThrowHelper . ThrowArgument_SetSliceInvalidShapes ( nameof ( values ) ) ;
3169+
3170+ if ( ! TensorHelpers . IsBroadcastableTo ( values . Lengths , srcSpan . Lengths ) )
30643171 ThrowHelper . ThrowArgument_SetSliceInvalidShapes ( nameof ( values ) ) ;
30653172
30663173 values . CopyTo ( srcSpan ) ;
@@ -3555,8 +3662,13 @@ public static Tensor<T> Unsqueeze<T>(this Tensor<T> tensor, int dimension)
35553662
35563663 List < nint > tempLengths = tensor . _lengths . ToList ( ) ;
35573664 tempLengths . Insert ( dimension , 1 ) ;
3558- nint [ ] lengths = tempLengths . ToArray ( ) ;
3559- nint [ ] strides = TensorSpanHelpers . CalculateStrides ( lengths ) ;
3665+ nint [ ] lengths = [ .. tempLengths ] ;
3666+ List < nint > tempStrides = tensor . Strides . ToArray ( ) . ToList ( ) ;
3667+ if ( dimension == tensor . Rank )
3668+ tempStrides . Add ( tensor . Strides [ dimension - 1 ] ) ;
3669+ else
3670+ tempStrides . Insert ( dimension , tensor . Strides [ dimension ] * tensor . Lengths [ dimension ] ) ;
3671+ nint [ ] strides = [ .. tempStrides ] ;
35603672 return new Tensor < T > ( tensor . _values , lengths , strides ) ;
35613673 }
35623674
@@ -3574,8 +3686,13 @@ public static TensorSpan<T> Unsqueeze<T>(in this TensorSpan<T> tensor, int dimen
35743686
35753687 List < nint > tempLengths = tensor . Lengths . ToArray ( ) . ToList ( ) ;
35763688 tempLengths . Insert ( dimension , 1 ) ;
3577- nint [ ] lengths = tempLengths . ToArray ( ) ;
3578- nint [ ] strides = TensorSpanHelpers . CalculateStrides ( lengths ) ;
3689+ nint [ ] lengths = [ .. tempLengths ] ;
3690+ List < nint > tempStrides = tensor . Strides . ToArray ( ) . ToList ( ) ;
3691+ if ( dimension == tensor . Rank )
3692+ tempStrides . Add ( tensor . Strides [ dimension - 1 ] ) ;
3693+ else
3694+ tempStrides . Insert ( dimension , tensor . Strides [ dimension ] * tensor . Lengths [ dimension ] ) ;
3695+ nint [ ] strides = [ .. tempStrides ] ;
35793696 return new TensorSpan < T > ( ref tensor . _reference , lengths , strides , tensor . _shape . _memoryLength ) ;
35803697 }
35813698
@@ -3593,8 +3710,13 @@ public static ReadOnlyTensorSpan<T> Unsqueeze<T>(in this ReadOnlyTensorSpan<T> t
35933710
35943711 List < nint > tempLengths = tensor . Lengths . ToArray ( ) . ToList ( ) ;
35953712 tempLengths . Insert ( dimension , 1 ) ;
3596- nint [ ] lengths = tempLengths . ToArray ( ) ;
3597- nint [ ] strides = TensorSpanHelpers . CalculateStrides ( lengths ) ;
3713+ nint [ ] lengths = [ .. tempLengths ] ;
3714+ List < nint > tempStrides = tensor . Strides . ToArray ( ) . ToList ( ) ;
3715+ if ( dimension == tensor . Rank )
3716+ tempStrides . Add ( tensor . Strides [ dimension - 1 ] ) ;
3717+ else
3718+ tempStrides . Insert ( dimension , tensor . Strides [ dimension ] * tensor . Lengths [ dimension ] ) ;
3719+ nint [ ] strides = [ .. tempStrides ] ;
35983720 return new ReadOnlyTensorSpan < T > ( ref tensor . _reference , lengths , strides , tensor . _shape . _memoryLength ) ;
35993721 }
36003722 #endregion
0 commit comments