Skip to content

[flang][HLFIR] Relax verifiers of intrinsic operations #80132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 44 additions & 27 deletions flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include <iterator>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <optional>
#include <tuple>

static llvm::cl::opt<bool> useStrictIntrinsicVerifier(
"strict-intrinsic-verifier", llvm::cl::init(false),
llvm::cl::desc("use stricter verifier for HLFIR intrinsic operations"));

/// generic implementation of the memory side effects interface for hlfir
/// transformational intrinsic operations
static void
Expand Down Expand Up @@ -498,7 +503,7 @@ verifyLogicalReductionOp(LogicalReductionOp reductionOp) {
mlir::Type resultType = results[0];
if (mlir::isa<fir::LogicalType>(resultType)) {
// Result is of the same type as MASK
if (resultType != logicalTy)
if ((resultType != logicalTy) && useStrictIntrinsicVerifier)
return reductionOp->emitOpError(
"result must have the same element type as MASK argument");

Expand All @@ -509,7 +514,7 @@ verifyLogicalReductionOp(LogicalReductionOp reductionOp) {
if (!resultExpr.isArray())
return reductionOp->emitOpError("result must be an array");

if (resultExpr.getEleTy() != logicalTy)
if ((resultExpr.getEleTy() != logicalTy) && useStrictIntrinsicVerifier)
return reductionOp->emitOpError(
"result must have the same element type as MASK argument");

Expand Down Expand Up @@ -585,7 +590,7 @@ mlir::LogicalResult hlfir::CountOp::verify() {
if (resultShape.size() != (maskShape.size() - 1))
return emitOpError("result rank must be one less than MASK");
} else {
return emitOpError("result must be of numerical scalar type");
return emitOpError("result must be of numerical array type");
}
} else if (!hlfir::isFortranScalarNumericalType(resultType)) {
return emitOpError("result must be of numerical scalar type");
Expand Down Expand Up @@ -682,15 +687,18 @@ verifyArrayAndMaskForReductionOp(NumericalReductionOp reductionOp) {
if (!maskShape.empty()) {
if (maskShape.size() != arrayShape.size())
return reductionOp->emitWarning("MASK must be conformable to ARRAY");
static_assert(fir::SequenceType::getUnknownExtent() ==
hlfir::ExprType::getUnknownExtent());
constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
for (std::size_t i = 0; i < arrayShape.size(); ++i) {
int64_t arrayExtent = arrayShape[i];
int64_t maskExtent = maskShape[i];
if ((arrayExtent != maskExtent) && (arrayExtent != unknownExtent) &&
(maskExtent != unknownExtent))
return reductionOp->emitWarning("MASK must be conformable to ARRAY");
if (useStrictIntrinsicVerifier) {
static_assert(fir::SequenceType::getUnknownExtent() ==
hlfir::ExprType::getUnknownExtent());
constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
for (std::size_t i = 0; i < arrayShape.size(); ++i) {
int64_t arrayExtent = arrayShape[i];
int64_t maskExtent = maskShape[i];
if ((arrayExtent != maskExtent) && (arrayExtent != unknownExtent) &&
(maskExtent != unknownExtent))
return reductionOp->emitWarning(
"MASK must be conformable to ARRAY");
}
}
}
}
Expand Down Expand Up @@ -719,7 +727,7 @@ verifyNumericalReductionOp(NumericalReductionOp reductionOp) {
mlir::Type resultType = results[0];
if (hlfir::isFortranScalarNumericalType(resultType)) {
// Result is of the same type as ARRAY
if (resultType != numTy)
if ((resultType != numTy) && useStrictIntrinsicVerifier)
return reductionOp->emitOpError(
"result must have the same element type as ARRAY argument");

Expand All @@ -729,7 +737,7 @@ verifyNumericalReductionOp(NumericalReductionOp reductionOp) {
if (!resultExpr.isArray())
return reductionOp->emitOpError("result must be an array");

if (resultExpr.getEleTy() != numTy)
if ((resultExpr.getEleTy() != numTy) && useStrictIntrinsicVerifier)
return reductionOp->emitOpError(
"result must have the same element type as ARRAY argument");

Expand Down Expand Up @@ -792,7 +800,7 @@ verifyCharacterReductionOp(CharacterReductionOp reductionOp) {
"result must be character");

// Result is of the same type as ARRAY
if (resultType != numTy)
if ((resultType != numTy) && useStrictIntrinsicVerifier)
return reductionOp->emitOpError(
"result must have the same element type as ARRAY argument");

Expand Down Expand Up @@ -823,9 +831,8 @@ mlir::LogicalResult hlfir::MaxvalOp::verify() {
auto resultExpr = mlir::dyn_cast<hlfir::ExprType>(results[0]);
if (resultExpr && mlir::isa<fir::CharacterType>(resultExpr.getEleTy())) {
return verifyCharacterReductionOp<hlfir::MaxvalOp *>(this);
} else {
return verifyNumericalReductionOp<hlfir::MaxvalOp *>(this);
}
return verifyNumericalReductionOp<hlfir::MaxvalOp *>(this);
}

void hlfir::MaxvalOp::getEffects(
Expand All @@ -848,9 +855,8 @@ mlir::LogicalResult hlfir::MinvalOp::verify() {
auto resultExpr = mlir::dyn_cast<hlfir::ExprType>(results[0]);
if (resultExpr && mlir::isa<fir::CharacterType>(resultExpr.getEleTy())) {
return verifyCharacterReductionOp<hlfir::MinvalOp *>(this);
} else {
return verifyNumericalReductionOp<hlfir::MinvalOp *>(this);
}
return verifyNumericalReductionOp<hlfir::MinvalOp *>(this);
}

void hlfir::MinvalOp::getEffects(
Expand Down Expand Up @@ -1007,17 +1013,19 @@ mlir::LogicalResult hlfir::DotProductOp::verify() {

constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
if ((lhsSize != unknownExtent) && (rhsSize != unknownExtent) &&
(lhsSize != rhsSize))
(lhsSize != rhsSize) && useStrictIntrinsicVerifier)
return emitOpError("both arrays must have the same size");

if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
mlir::isa<fir::LogicalType>(rhsEleTy))
return emitOpError("if one array is logical, so should the other be");
if (useStrictIntrinsicVerifier) {
if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
mlir::isa<fir::LogicalType>(rhsEleTy))
return emitOpError("if one array is logical, so should the other be");

if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
mlir::isa<fir::LogicalType>(resultTy))
return emitOpError("the result type should be a logical only if the "
"argument types are logical");
if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
mlir::isa<fir::LogicalType>(resultTy))
return emitOpError("the result type should be a logical only if the "
"argument types are logical");
}

if (!hlfir::isFortranScalarNumericalType(resultTy) &&
!mlir::isa<fir::LogicalType>(resultTy))
Expand Down Expand Up @@ -1067,6 +1075,9 @@ mlir::LogicalResult hlfir::MatmulOp::verify() {
mlir::isa<fir::LogicalType>(rhsEleTy))
return emitOpError("if one array is logical, so should the other be");

if (!useStrictIntrinsicVerifier)
return mlir::success();

int64_t lastLhsDim = lhsShape[lhsRank - 1];
int64_t firstRhsDim = rhsShape[0];
constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
Expand Down Expand Up @@ -1179,6 +1190,9 @@ mlir::LogicalResult hlfir::TransposeOp::verify() {
if (rank != 2 || resultRank != 2)
return emitOpError("input and output arrays should have rank 2");

if (!useStrictIntrinsicVerifier)
return mlir::success();

constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
if ((inShape[0] != resultShape[1]) && (inShape[0] != unknownExtent))
return emitOpError("output shape does not match input array");
Expand Down Expand Up @@ -1226,6 +1240,9 @@ mlir::LogicalResult hlfir::MatmulTransposeOp::verify() {
if ((lhsRank != 2) || ((rhsRank != 1) && (rhsRank != 2)))
return emitOpError("array must have either rank 1 or rank 2");

if (!useStrictIntrinsicVerifier)
return mlir::success();

if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
mlir::isa<fir::LogicalType>(rhsEleTy))
return emitOpError("if one array is logical, so should the other be");
Expand Down
4 changes: 2 additions & 2 deletions flang/test/HLFIR/invalid.fir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// HLFIR ops diagnotic tests

// RUN: fir-opt -split-input-file -verify-diagnostics %s
// RUN: fir-opt -strict-intrinsic-verifier -split-input-file -verify-diagnostics %s

func.func @bad_declare(%arg0: !fir.ref<f32>) {
// expected-error@+1 {{'hlfir.declare' op first result type is inconsistent with variable properties: expected '!fir.ref<f32>'}}
Expand Down Expand Up @@ -382,7 +382,7 @@ func.func @bad_count2(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32){

// -----
func.func @bad_count3(%arg0: !hlfir.expr<?x!fir.logical<4>>, %arg1: i32) {
// expected-error@+1 {{'hlfir.count' op result must be of numerical scalar type}}
// expected-error@+1 {{'hlfir.count' op result must be of numerical array type}}
%0 = hlfir.count %arg0 dim %arg1 : (!hlfir.expr<?x!fir.logical<4>>, i32) -> !hlfir.expr<i32>
}

Expand Down
44 changes: 44 additions & 0 deletions flang/test/Lower/HLFIR/minval.f90
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,47 @@ end subroutine test_unknown_char_len_result
! CHECK-NEXT: hlfir.destroy %[[EXPR]]
! CHECK-NEXT: return
! CHECK-NEXT: }

! Test edge case with missmatch between argument type !fir.char<1,?> and result
! type !fir.char<1,4>
function test_type_mismatch
character(:), allocatable :: test_type_mismatch(:)
character(3) :: char(3,4)
test_type_mismatch = minval(char//' ', dim=1)
end function
! CHECK-LABEL: func.func @_QPtest_type_mismatch() -> !fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>> {
! CHECK: %[[VAL_0:.*]] = arith.constant 3 : index
! CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
! CHECK: %[[VAL_2:.*]] = arith.constant 4 : index
! CHECK: %[[VAL_3:.*]] = fir.alloca !fir.array<3x4x!fir.char<1,3>> {bindc_name = "char", uniq_name = "_QFtest_type_mismatchEchar"}
! CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_1]], %[[VAL_2]] : (index, index) -> !fir.shape<2>
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_3]](%[[VAL_4]]) typeparams %[[VAL_0]] {uniq_name = "_QFtest_type_mismatchEchar"} : (!fir.ref<!fir.array<3x4x!fir.char<1,3>>>, !fir.shape<2>, index) -> (!fir.ref<!fir.array<3x4x!fir.char<1,3>>>, !fir.ref<!fir.array<3x4x!fir.char<1,3>>>)
! CHECK: %[[VAL_6:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>> {bindc_name = "test_type_mismatch", uniq_name = "_QFtest_type_mismatchEtest_type_mismatch"}
! CHECK: %[[VAL_7:.*]] = fir.zero_bits !fir.heap<!fir.array<?x!fir.char<1,?>>>
! CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
! CHECK: %[[VAL_9:.*]] = fir.shape %[[VAL_8]] : (index) -> !fir.shape<1>
! CHECK: %[[VAL_10:.*]] = arith.constant 0 : index
! CHECK: %[[VAL_11:.*]] = fir.embox %[[VAL_7]](%[[VAL_9]]) typeparams %[[VAL_10]] : (!fir.heap<!fir.array<?x!fir.char<1,?>>>, !fir.shape<1>, index) -> !fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>
! CHECK: fir.store %[[VAL_11]] to %[[VAL_6]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>>
! CHECK: %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_6]] {fortran_attrs = #{{.*}}, uniq_name = "_QFtest_type_mismatchEtest_type_mismatch"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>>)
! CHECK: %[[VAL_13:.*]] = fir.address_of(@_QQclX20) : !fir.ref<!fir.char<1>>
! CHECK: %[[VAL_14:.*]] = arith.constant 1 : index
! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare %[[VAL_13]] typeparams %[[VAL_14]] {fortran_attrs = {{.*}}, uniq_name = "_QQclX20"} : (!fir.ref<!fir.char<1>>, index) -> (!fir.ref<!fir.char<1>>, !fir.ref<!fir.char<1>>)
! CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_0]], %[[VAL_14]] : index
! CHECK: %[[VAL_17:.*]] = hlfir.elemental %[[VAL_4]] typeparams %[[VAL_16]] unordered : (!fir.shape<2>, index) -> !hlfir.expr<3x4x!fir.char<1,?>> {
! CHECK: ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
! CHECK: %[[VAL_20:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_18]], %[[VAL_19]]) typeparams %[[VAL_0]] : (!fir.ref<!fir.array<3x4x!fir.char<1,3>>>, index, index, index) -> !fir.ref<!fir.char<1,3>>
! CHECK: %[[VAL_21:.*]] = hlfir.concat %[[VAL_20]], %[[VAL_15]]#0 len %[[VAL_16]] : (!fir.ref<!fir.char<1,3>>, !fir.ref<!fir.char<1>>, index) -> !hlfir.expr<!fir.char<1,4>>
! CHECK: hlfir.yield_element %[[VAL_21]] : !hlfir.expr<!fir.char<1,4>>
! CHECK: }
! CHECK: %[[VAL_22:.*]] = arith.constant 1 : i32
! CHECK: %[[VAL_23:.*]] = hlfir.minval %[[VAL_17]] dim %[[VAL_22]] {fastmath = {{.*}}} : (!hlfir.expr<3x4x!fir.char<1,?>>, i32) -> !hlfir.expr<4x!fir.char<1,4>>
! CHECK: hlfir.assign %[[VAL_23]] to %[[VAL_12]]#0 realloc : !hlfir.expr<4x!fir.char<1,4>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>>
! CHECK: hlfir.destroy %[[VAL_23]] : !hlfir.expr<4x!fir.char<1,4>>
! CHECK: hlfir.destroy %[[VAL_17]] : !hlfir.expr<3x4x!fir.char<1,?>>
! CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_12]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>>
! CHECK: %[[VAL_25:.*]] = arith.constant 1 : index
! CHECK: %[[VAL_26:.*]] = fir.shift %[[VAL_25]] : (index) -> !fir.shift<1>
! CHECK: %[[VAL_27:.*]] = fir.rebox %[[VAL_24]](%[[VAL_26]]) : (!fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>, !fir.shift<1>) -> !fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>
! CHECK: return %[[VAL_27]] : !fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>
! CHECK: }