Skip to content

Commit 0cd1880

Browse files
Will Fengfacebook-github-bot
authored andcommitted
Add std::variant backport as c10::variant (pytorch#26836)
Summary: Pull Request resolved: pytorch#26836 * **pytorch#26836 Add std::variant backport as c10::variant** Test Plan: Imported from OSS Differential Revision: D17649064 Pulled By: yf225 fbshipit-source-id: aa5ee26fe7078cc66d03663b9ff9e998e1d5839a
1 parent cca3a36 commit 0cd1880

File tree

3 files changed

+2865
-1
lines changed

3 files changed

+2865
-1
lines changed

aten/src/ATen/test/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ list(APPEND ATen_CPU_TEST_SRCS
2929
${CMAKE_CURRENT_SOURCE_DIR}/xla_tensor_test.cpp
3030
${CMAKE_CURRENT_SOURCE_DIR}/tensor_iterator_test.cpp
3131
${CMAKE_CURRENT_SOURCE_DIR}/cpu_generator_test.cpp
32-
${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp)
32+
${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp
33+
${CMAKE_CURRENT_SOURCE_DIR}/variant_test.cpp)
3334

3435
list(APPEND ATen_CUDA_TEST_SRCS
3536
${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#include <gtest/gtest.h>
2+
3+
#include <c10/util/variant.h>
4+
5+
namespace testns {
6+
7+
namespace enumtype {
8+
// NOTE: We need to provide the default constructor for each struct,
9+
// otherwise Clang 3.8 would complain:
10+
// ```
11+
// error: default initialization of an object of const type 'const enumtype::Enum1'
12+
// without a user-provided default constructor
13+
// ```
14+
struct Enum1 { Enum1() {} };
15+
struct Enum2 { Enum2() {} };
16+
struct Enum3 { Enum3() {} };
17+
} // namespace enumtype
18+
19+
const enumtype::Enum1 kEnum1;
20+
const enumtype::Enum2 kEnum2;
21+
const enumtype::Enum3 kEnum3;
22+
23+
} // namespace testns
24+
25+
std::string func(c10::variant<testns::enumtype::Enum1, testns::enumtype::Enum2, testns::enumtype::Enum3> v) {
26+
if (c10::get_if<testns::enumtype::Enum1>(&v)) {
27+
return "Enum1";
28+
} else if (c10::get_if<testns::enumtype::Enum2>(&v)) {
29+
return "Enum2";
30+
} else if (c10::get_if<testns::enumtype::Enum3>(&v)) {
31+
return "Enum3";
32+
} else {
33+
return "Unsupported enum";
34+
}
35+
}
36+
37+
TEST(VariantTest, Basic) {
38+
ASSERT_EQ(func(testns::kEnum1), "Enum1");
39+
ASSERT_EQ(func(testns::kEnum2), "Enum2");
40+
ASSERT_EQ(func(testns::kEnum3), "Enum3");
41+
}

0 commit comments

Comments
 (0)