@@ -10670,39 +10670,45 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
1067010670#endif // TARGET_ARM64
1067110671 case NI_Vector128_Create:
1067210672 {
10673- // The `Dot` API returns a scalar. However, many common usages require it to
10674- // be then immediately broadcast back to a vector so that it can be used in
10675- // a subsequent operation. One of the most common is normalizing a vector
10673+ // The managed `Dot` API returns a scalar. However, many common usages require
10674+ // it to be then immediately broadcast back to a vector so that it can be used
10675+ // in a subsequent operation. One of the most common is normalizing a vector
1067610676 // which is effectively `value / value.Length` where `Length` is
10677- // `Sqrt(Dot(value, value))`
10677+ // `Sqrt(Dot(value, value))`. Because of this, and because of how a lot of
10678+ // hardware works, we treat `NI_Vector_Dot` as returning a SIMD type and then
10679+ // also wrap it in `ToScalar` where required.
1067810680 //
1067910681 // In order to ensure that developers can still utilize this efficiently, we
10680- // will look for two common patterns:
10682+ // then look for four common patterns:
1068110683 // * Create(Dot(..., ...))
1068210684 // * Create(Sqrt(Dot(..., ...)))
10685+ // * Create(ToScalar(Dot(..., ...)))
10686+ // * Create(ToScalar(Sqrt(Dot(..., ...))))
1068310687 //
10684- // When these exist, we'll avoid converting to a scalar at all and just
10685- // keep everything as a vector. However, we only do this for Vector64/Vector128
10686- // and only for float/double.
10688+ // When these exist, we'll avoid converting to a scalar and hence, avoid broadcasting
10689+ // the value back into a vector. Instead we'll just keep everything as a vector.
1068710690 //
10688- // We don't do this for Vector256 since that is xarch only and doesn't trivially
10689- // support operations which cross the upper and lower 128-bit lanes
10691+ // We only do this for Vector64/Vector128 today. We could expand this more in
10692+ // the future but it would require additional hand handling for Vector256
10693+ // (since a 256-bit result requires more work). We do some integer handling
10694+ // when the value is trivially replicated to all elements without extra work.
1069010695
1069110696 if (node->GetOperandCount() != 1)
1069210697 {
1069310698 break;
1069410699 }
1069510700
10696- if (!varTypeIsFloating(node->GetSimdBaseType()))
10697- {
10698- break;
10699- }
10700-
10701- GenTree* op1 = node->Op(1);
10702- GenTree* sqrt = nullptr;
10701+ GenTree* op1 = node->Op(1);
10702+ GenTree* sqrt = nullptr;
10703+ GenTree* toScalar = nullptr;
1070310704
1070410705 if (op1->OperIs(GT_INTRINSIC))
1070510706 {
10707+ if (!varTypeIsFloating(node->GetSimdBaseType()))
10708+ {
10709+ break;
10710+ }
10711+
1070610712 if (op1->AsIntrinsic()->gtIntrinsicName != NI_System_Math_Sqrt)
1070710713 {
1070810714 break;
@@ -10719,6 +10725,24 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
1071910725
1072010726 GenTreeHWIntrinsic* hwop1 = op1->AsHWIntrinsic();
1072110727
10728+ #if defined(TARGET_ARM64)
10729+ if ((hwop1->GetHWIntrinsicId() == NI_Vector64_ToScalar) ||
10730+ (hwop1->GetHWIntrinsicId() == NI_Vector128_ToScalar))
10731+ #else
10732+ if (hwop1->GetHWIntrinsicId() == NI_Vector128_ToScalar)
10733+ #endif
10734+ {
10735+ op1 = hwop1->Op(1);
10736+
10737+ if (!op1->OperIs(GT_HWINTRINSIC))
10738+ {
10739+ break;
10740+ }
10741+
10742+ toScalar = hwop1;
10743+ hwop1 = op1->AsHWIntrinsic();
10744+ }
10745+
1072210746#if defined(TARGET_ARM64)
1072310747 if ((hwop1->GetHWIntrinsicId() != NI_Vector64_Dot) && (hwop1->GetHWIntrinsicId() != NI_Vector128_Dot))
1072410748#else
@@ -10728,13 +10752,16 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
1072810752 break;
1072910753 }
1073010754
10731- unsigned simdSize = node->GetSimdSize();
10732- var_types simdType = getSIMDTypeForSize(simdSize);
10733-
10734- hwop1->gtType = simdType;
10755+ if (toScalar != nullptr)
10756+ {
10757+ DEBUG_DESTROY_NODE(toScalar);
10758+ }
1073510759
1073610760 if (sqrt != nullptr)
1073710761 {
10762+ unsigned simdSize = node->GetSimdSize();
10763+ var_types simdType = getSIMDTypeForSize(simdSize);
10764+
1073810765 node = gtNewSimdSqrtNode(simdType, hwop1, node->GetSimdBaseJitType(), simdSize)->AsHWIntrinsic();
1073910766 DEBUG_DESTROY_NODE(sqrt);
1074010767 }
0 commit comments