From ce7e7040877ff83f6170413bc54582623d5b861e Mon Sep 17 00:00:00 2001 From: leimao Date: Mon, 17 Nov 2025 20:53:33 -0800 Subject: [PATCH 1/3] Fix Bugs --- cpp/src/types.cpp | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/cpp/src/types.cpp b/cpp/src/types.cpp index 69b956a162..93904ff65d 100644 --- a/cpp/src/types.cpp +++ b/cpp/src/types.cpp @@ -126,27 +126,20 @@ DataType::DataType(c10::ScalarType t) { "Data type is unsupported (" << t << ")"); switch (t) { case at::kHalf: - value = DataType::kHalf; - break; + return DataType::kHalf; case at::kChar: - value = DataType::kChar; - break; + return DataType::kChar; case at::kInt: - value = DataType::kInt; - break; + return DataType::kInt; case at::kLong: - value = DataType::kLong; - break; + return DataType::kLong; case at::kDouble: - value = DataType::kDouble; - break; + return DataType::kDouble; case at::kBool: - value = DataType::kBool; - break; + return DataType::kBool; case at::kFloat: default: - value = DataType::kFloat; - break; + return DataType::kFloat; } } @@ -157,11 +150,10 @@ TensorFormat::TensorFormat(at::MemoryFormat t) { switch (t) { case at::MemoryFormat::ChannelsLast: - value = TensorFormat::kChannelsLast; + return TensorFormat::kChannelsLast; case at::MemoryFormat::Contiguous: default: - value = TensorFormat::kContiguous; - break; + return TensorFormat::kContiguous; } } From 84ac1026575b50d9257dbef0f26390f23f57cf41 Mon Sep 17 00:00:00 2001 From: leimao Date: Mon, 17 Nov 2025 20:57:55 -0800 Subject: [PATCH 2/3] Fix Bugs --- core/lowering/register_trt_placeholder_ops.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/core/lowering/register_trt_placeholder_ops.cpp b/core/lowering/register_trt_placeholder_ops.cpp index d083c71715..ccdeea0763 100644 --- a/core/lowering/register_trt_placeholder_ops.cpp +++ b/core/lowering/register_trt_placeholder_ops.cpp @@ -20,7 +20,6 @@ RegisterOperators trt_placeholder_ops_reg({ [](Stack& stack) { auto attn_mask = pop(stack).to(); if (attn_mask.scalar_type() == at::kBool) { - attn_mask = attn_mask; attn_mask.masked_fill_(attn_mask.logical_not(), -std::numeric_limits::infinity()); } return attn_mask; From 5f6c5081c71b5093a5e7c240f4b78d91e7eb7452 Mon Sep 17 00:00:00 2001 From: leimao Date: Mon, 17 Nov 2025 21:36:28 -0800 Subject: [PATCH 3/3] Fix Bugs --- cpp/src/types.cpp | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/cpp/src/types.cpp b/cpp/src/types.cpp index 93904ff65d..8b7569636c 100644 --- a/cpp/src/types.cpp +++ b/cpp/src/types.cpp @@ -126,20 +126,27 @@ DataType::DataType(c10::ScalarType t) { "Data type is unsupported (" << t << ")"); switch (t) { case at::kHalf: - return DataType::kHalf; + value = DataType::kHalf; + break; case at::kChar: - return DataType::kChar; + value = DataType::kChar; + break; case at::kInt: - return DataType::kInt; + value = DataType::kInt; + break; case at::kLong: - return DataType::kLong; + value = DataType::kLong; + break; case at::kDouble: - return DataType::kDouble; + value = DataType::kDouble; + break; case at::kBool: - return DataType::kBool; + value = DataType::kBool; + break; case at::kFloat: default: - return DataType::kFloat; + value = DataType::kFloat; + break; } } @@ -150,10 +157,12 @@ TensorFormat::TensorFormat(at::MemoryFormat t) { switch (t) { case at::MemoryFormat::ChannelsLast: - return TensorFormat::kChannelsLast; + value = TensorFormat::kChannelsLast; + break; case at::MemoryFormat::Contiguous: default: - return TensorFormat::kContiguous; + value = TensorFormat::kContiguous; + break; } }