|
20 | 20 | #include "llvm/Support/ErrorHandling.h" |
21 | 21 |
|
22 | 22 | #include <cstdint> |
| 23 | +#include <optional> |
23 | 24 |
|
24 | 25 | using namespace mlir; |
25 | 26 | using namespace mlir::spirv; |
@@ -172,14 +173,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; } |
172 | 173 |
|
173 | 174 | unsigned ArrayType::getArrayStride() const { return getImpl()->stride; } |
174 | 175 |
|
175 | | -std::optional<int64_t> ArrayType::getSizeInBytes() { |
176 | | - auto elementType = llvm::cast<SPIRVType>(getElementType()); |
177 | | - std::optional<int64_t> size = elementType.getSizeInBytes(); |
178 | | - if (!size) |
179 | | - return std::nullopt; |
180 | | - return (*size + getArrayStride()) * getNumElements(); |
181 | | -} |
182 | | - |
183 | 176 | //===----------------------------------------------------------------------===// |
184 | 177 | // CompositeType |
185 | 178 | //===----------------------------------------------------------------------===// |
@@ -245,28 +238,6 @@ void TypeCapabilityVisitor::addConcrete(VectorType type) { |
245 | 238 | } |
246 | 239 | } |
247 | 240 |
|
248 | | -std::optional<int64_t> CompositeType::getSizeInBytes() { |
249 | | - if (auto arrayType = llvm::dyn_cast<ArrayType>(*this)) |
250 | | - return arrayType.getSizeInBytes(); |
251 | | - if (auto structType = llvm::dyn_cast<StructType>(*this)) |
252 | | - return structType.getSizeInBytes(); |
253 | | - if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) { |
254 | | - std::optional<int64_t> elementSize = |
255 | | - llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes(); |
256 | | - if (!elementSize) |
257 | | - return std::nullopt; |
258 | | - return *elementSize * vectorType.getNumElements(); |
259 | | - } |
260 | | - if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) { |
261 | | - std::optional<int64_t> elementSize = |
262 | | - llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes(); |
263 | | - if (!elementSize) |
264 | | - return std::nullopt; |
265 | | - return *elementSize * tensorArmType.getNumElements(); |
266 | | - } |
267 | | - return std::nullopt; |
268 | | -} |
269 | | - |
270 | 241 | //===----------------------------------------------------------------------===// |
271 | 242 | // CooperativeMatrixType |
272 | 243 | //===----------------------------------------------------------------------===// |
@@ -714,19 +685,6 @@ void TypeCapabilityVisitor::addConcrete(ScalarType type) { |
714 | 685 | #undef WIDTH_CASE |
715 | 686 | } |
716 | 687 |
|
717 | | -std::optional<int64_t> ScalarType::getSizeInBytes() { |
718 | | - auto bitWidth = getIntOrFloatBitWidth(); |
719 | | - // According to the SPIR-V spec: |
720 | | - // "There is no physical size or bit pattern defined for values with boolean |
721 | | - // type. If they are stored (in conjunction with OpVariable), they can only |
722 | | - // be used with logical addressing operations, not physical, and only with |
723 | | - // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, |
724 | | - // Private, Function, Input, and Output." |
725 | | - if (bitWidth == 1) |
726 | | - return std::nullopt; |
727 | | - return bitWidth / 8; |
728 | | -} |
729 | | - |
730 | 688 | //===----------------------------------------------------------------------===// |
731 | 689 | // SPIRVType |
732 | 690 | //===----------------------------------------------------------------------===// |
@@ -760,11 +718,35 @@ void SPIRVType::getCapabilities( |
760 | 718 | } |
761 | 719 |
|
762 | 720 | std::optional<int64_t> SPIRVType::getSizeInBytes() { |
763 | | - if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) |
764 | | - return scalarType.getSizeInBytes(); |
765 | | - if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) |
766 | | - return compositeType.getSizeInBytes(); |
767 | | - return std::nullopt; |
| 721 | + return TypeSwitch<SPIRVType, std::optional<int64_t>>(*this) |
| 722 | + .Case<ScalarType>([](ScalarType type) -> std::optional<int64_t> { |
| 723 | + // According to the SPIR-V spec: |
| 724 | + // "There is no physical size or bit pattern defined for values with |
| 725 | + // boolean type. If they are stored (in conjunction with OpVariable), |
| 726 | + // they can only be used with logical addressing operations, not |
| 727 | + // physical, and only with non-externally visible shader Storage |
| 728 | + // Classes: Workgroup, CrossWorkgroup, Private, Function, Input, and |
| 729 | + // Output." |
| 730 | + int64_t bitWidth = type.getIntOrFloatBitWidth(); |
| 731 | + if (bitWidth == 1) |
| 732 | + return std::nullopt; |
| 733 | + return bitWidth / 8; |
| 734 | + }) |
| 735 | + .Case<ArrayType>([](ArrayType type) -> std::optional<int64_t> { |
| 736 | + // Since array type may have an explicit stride declaration (in bytes), |
| 737 | + // we also include it in the calculation. |
| 738 | + auto elementType = cast<SPIRVType>(type.getElementType()); |
| 739 | + if (std::optional<int64_t> size = elementType.getSizeInBytes()) |
| 740 | + return (*size + type.getArrayStride()) * type.getNumElements(); |
| 741 | + return std::nullopt; |
| 742 | + }) |
| 743 | + .Case<VectorType, TensorArmType>([](auto type) -> std::optional<int64_t> { |
| 744 | + if (std::optional<int64_t> elementSize = |
| 745 | + cast<ScalarType>(type.getElementType()).getSizeInBytes()) |
| 746 | + return *elementSize * type.getNumElements(); |
| 747 | + return std::nullopt; |
| 748 | + }) |
| 749 | + .Default(std::optional<int64_t>()); |
768 | 750 | } |
769 | 751 |
|
770 | 752 | //===----------------------------------------------------------------------===// |
|
0 commit comments