diff --git a/lib/SILOptimizer/Mandatory/AddressLowering.cpp b/lib/SILOptimizer/Mandatory/AddressLowering.cpp index 231a6941f1ef9..30b6972e1afce 100644 --- a/lib/SILOptimizer/Mandatory/AddressLowering.cpp +++ b/lib/SILOptimizer/Mandatory/AddressLowering.cpp @@ -419,6 +419,11 @@ struct AddressLoweringState { // parameters are rewritten. SmallBlotSetVector indirectApplies; + // checked_cast_br instructions with loadable source type and opaque target + // type need to be rewritten in a post-pass, once all the uses of the opaque + // target value are rewritten to their address forms. + SmallVector opaqueResultCCBs; + // All function-exiting terminators (return or throw instructions). SmallVector exitingInsts; @@ -606,6 +611,15 @@ void OpaqueValueVisitor::mapValueStorage() { if (auto apply = FullApplySite::isa(&inst)) checkForIndirectApply(apply); + // Collect all checked_cast_br instructions that have a loadable source + // type and opaque target type + if (auto *ccb = dyn_cast(&inst)) { + if (!ccb->getSourceLoweredType().isAddressOnly(*ccb->getFunction()) && + ccb->getTargetLoweredType().isAddressOnly(*ccb->getFunction())) { + pass.opaqueResultCCBs.push_back(ccb); + } + } + for (auto result : inst.getResults()) { if (isPseudoCallResult(result) || isPseudoReturnValue(result)) continue; @@ -2252,6 +2266,99 @@ void ApplyRewriter::replaceDirectResults(DestructureTupleInst *oldDestructure) { } } +//===----------------------------------------------------------------------===// +// CheckedCastBrRewriter +// +// Utilities for rewriting checked_cast_br with opaque source/target type +// ===---------------------------------------------------------------------===// +class CheckedCastBrRewriter { + CheckedCastBranchInst *ccb; + AddressLoweringState &pass; + SILLocation castLoc; + SILFunction *func; + SILBasicBlock *successBB; + SILBasicBlock *failureBB; + SILArgument *origSuccessVal; + SILArgument *origFailureVal; + SILBuilder termBuilder; + SILBuilder successBuilder; + SILBuilder failureBuilder; + +public: + CheckedCastBrRewriter(CheckedCastBranchInst *ccb, AddressLoweringState &pass) + : ccb(ccb), pass(pass), castLoc(ccb->getLoc()), func(ccb->getFunction()), + successBB(ccb->getSuccessBB()), failureBB(ccb->getFailureBB()), + origSuccessVal(successBB->getArgument(0)), + origFailureVal(failureBB->getArgument(0)), + termBuilder(pass.getTermBuilder(ccb)), + successBuilder(pass.getBuilder(successBB->begin())), + failureBuilder(pass.getBuilder(failureBB->begin())) {} + + /// Rewrite checked_cast_br with opaque source/target operands to + /// checked_cast_addr_br + void rewrite() { + auto srcAddr = + getAddressForCastEntity(ccb->getOperand(), /* needsInit */ true); + auto destAddr = + getAddressForCastEntity(origSuccessVal, /* needsInit */ false); + + // getReusedStorageOperand() ensured we do not allocate a separate address + // for failure block arg. Set the storage address of the failure block arg + // to be source address here. + if (origFailureVal->getType().isAddressOnly(*func)) { + pass.valueStorageMap.setStorageAddress(origFailureVal, srcAddr); + } + + termBuilder.createCheckedCastAddrBranch( + castLoc, CastConsumptionKind::TakeOnSuccess, srcAddr, + ccb->getSourceFormalType(), destAddr, ccb->getTargetFormalType(), + successBB, failureBB, ccb->getTrueBBCount(), ccb->getFalseBBCount()); + + replaceBlockArg(origSuccessVal, destAddr); + replaceBlockArg(origFailureVal, srcAddr); + + pass.deleter.forceDelete(ccb); + } + +private: + /// Return the storageAddress if \p value is opaque, otherwise create and + /// return a stack temporary. + SILValue getAddressForCastEntity(SILValue value, bool needsInit) { + if (value->getType().isAddressOnly(*func)) + return pass.valueStorageMap.getStorage(value).storageAddress; + + // Create a stack temporary for a loadable value + auto *addr = termBuilder.createAllocStack(castLoc, value->getType()); + if (needsInit) { + termBuilder.createStore(castLoc, value, addr, + value->getType().isTrivial(*func) + ? StoreOwnershipQualifier::Trivial + : StoreOwnershipQualifier::Init); + } + successBuilder.createDeallocStack(castLoc, addr); + failureBuilder.createDeallocStack(castLoc, addr); + return addr; + } + + void replaceBlockArg(SILArgument *blockArg, SILValue addr) { + // Replace all uses of the opaque block arg with a load from its + // storage address. + auto load = + pass.getBuilder(blockArg->getParent()->begin()) + .createTrivialLoadOr(castLoc, addr, LoadOwnershipQualifier::Take); + blockArg->replaceAllUsesWith(load); + + blockArg->getParent()->eraseArgument(blockArg->getIndex()); + + if (blockArg->getType().isAddressOnly(*func)) { + // In case of opaque block arg, replace the block arg with the dummy load + // in the valueStorageMap. DefRewriter::visitLoadInst will then rewrite + // the dummy load to copy_addr. + pass.valueStorageMap.replaceValue(blockArg, load); + } + } +}; + //===----------------------------------------------------------------------===// // ReturnRewriter // @@ -2811,87 +2918,8 @@ void UseRewriter::visitSwitchEnumInst(SwitchEnumInst * switchEnum) { defaultCounter); } -void UseRewriter::visitCheckedCastBranchInst( - CheckedCastBranchInst *checkedCastBranch) { - auto loc = checkedCastBranch->getLoc(); - auto *func = checkedCastBranch->getFunction(); - auto *successBB = checkedCastBranch->getSuccessBB(); - auto *failureBB = checkedCastBranch->getFailureBB(); - auto *oldSuccessVal = successBB->getArgument(0); - auto *oldFailureVal = failureBB->getArgument(0); - auto termBuilder = pass.getTermBuilder(checkedCastBranch); - auto successBuilder = pass.getBuilder(successBB->begin()); - auto failureBuilder = pass.getBuilder(failureBB->begin()); - bool isAddressOnlyTarget = oldSuccessVal->getType().isAddressOnly(*func); - - auto srcAddr = pass.valueStorageMap.getStorage(use->get()).storageAddress; - - if (isAddressOnlyTarget) { - // If target is opaque, use the storage address mapped to success - // block's argument as the destination for checked_cast_addr_br. - SILValue destAddr = - pass.valueStorageMap.getStorage(oldSuccessVal).storageAddress; - - termBuilder.createCheckedCastAddrBranch( - loc, CastConsumptionKind::TakeOnSuccess, srcAddr, - checkedCastBranch->getSourceFormalType(), destAddr, - checkedCastBranch->getTargetFormalType(), successBB, failureBB, - checkedCastBranch->getTrueBBCount(), - checkedCastBranch->getFalseBBCount()); - - // In this case, since both success and failure block's args are opaque, - // create dummy loads from their storage addresses that will later be - // rewritten to copy_addr in DefRewriter::visitLoadInst - auto newSuccessVal = successBuilder.createTrivialLoadOr( - loc, destAddr, LoadOwnershipQualifier::Take); - oldSuccessVal->replaceAllUsesWith(newSuccessVal); - successBB->eraseArgument(0); - - pass.valueStorageMap.replaceValue(oldSuccessVal, newSuccessVal); - - auto newFailureVal = failureBuilder.createTrivialLoadOr( - loc, srcAddr, LoadOwnershipQualifier::Take); - oldFailureVal->replaceAllUsesWith(newFailureVal); - failureBB->eraseArgument(0); - - pass.valueStorageMap.replaceValue(oldFailureVal, newFailureVal); - markRewritten(newFailureVal, srcAddr); - } else { - // If the target is loadable, create a stack temporary to be used as the - // destination for checked_cast_addr_br. - SILValue destAddr = termBuilder.createAllocStack( - loc, checkedCastBranch->getTargetLoweredType()); - - termBuilder.createCheckedCastAddrBranch( - loc, CastConsumptionKind::TakeOnSuccess, srcAddr, - checkedCastBranch->getSourceFormalType(), destAddr, - checkedCastBranch->getTargetFormalType(), successBB, failureBB, - checkedCastBranch->getTrueBBCount(), - checkedCastBranch->getFalseBBCount()); - - // Replace the success block arg with loaded value from destAddr, and delete - // the success block arg. - auto newSuccessVal = successBuilder.createTrivialLoadOr( - loc, destAddr, LoadOwnershipQualifier::Take); - oldSuccessVal->replaceAllUsesWith(newSuccessVal); - successBB->eraseArgument(0); - - successBuilder.createDeallocStack(loc, destAddr); - failureBuilder.createDeallocStack(loc, destAddr); - - // Since failure block arg is opaque, create dummy load from its storage - // address. This will be replaced later with copy_addr in - // DefRewriter::visitLoadInst. - auto newFailureVal = failureBuilder.createTrivialLoadOr( - loc, srcAddr, LoadOwnershipQualifier::Take); - oldFailureVal->replaceAllUsesWith(newFailureVal); - failureBB->eraseArgument(0); - - pass.valueStorageMap.replaceValue(oldFailureVal, newFailureVal); - markRewritten(newFailureVal, srcAddr); - } - - pass.deleter.forceDelete(checkedCastBranch); +void UseRewriter::visitCheckedCastBranchInst(CheckedCastBranchInst *ccb) { + CheckedCastBrRewriter(ccb, pass).rewrite(); } void UseRewriter::visitUncheckedEnumDataInst( @@ -2989,18 +3017,9 @@ class DefRewriter : SILInstructionVisitor { LLVM_DEBUG(llvm::dbgs() << "REWRITE ARG "; arg->dump()); if (storage.storageAddress) LLVM_DEBUG(llvm::dbgs() << " STORAGE "; storage.storageAddress->dump()); - storage.storageAddress = addrMat.materializeAddress(arg); } - void setStorageAddress(SILValue oldValue, SILValue addr) { - auto &storage = pass.valueStorageMap.getStorage(oldValue); - // getReusedStorageOperand() ensures that oldValue does not already have - // separate storage. So there's no need to delete its alloc_stack. - assert(!storage.storageAddress || storage.storageAddress == addr); - storage.storageAddress = addr; - } - void beforeVisit(SILInstruction *inst) { LLVM_DEBUG(llvm::dbgs() << "REWRITE DEF "; inst->dump()); if (storage.storageAddress) @@ -3068,7 +3087,7 @@ class DefRewriter : SILInstructionVisitor { openExistentialBoxValue->getType().getAddressType()); openExistentialBoxValue->replaceAllTypeDependentUsesWith(openAddr); - setStorageAddress(openExistentialBoxValue, openAddr); + pass.valueStorageMap.setStorageAddress(openExistentialBoxValue, openAddr); } // Load an opaque value. @@ -3133,7 +3152,7 @@ class DefRewriter : SILInstructionVisitor { // Rewrite Opaque Values //===----------------------------------------------------------------------===// -// Rewrite applies with indirect paramters or results of loadable types which +// Rewrite applies with indirect parameters or results of loadable types which // were not visited during opaque value rewritting. static void rewriteIndirectApply(FullApplySite apply, AddressLoweringState &pass) { @@ -3186,6 +3205,13 @@ static void rewriteFunction(AddressLoweringState &pass) { rewriteIndirectApply(optionalApply.getValue(), pass); } } + + // Rewrite all checked_cast_br instructions with loadable source type and + // opaque target type now + for (auto *ccb : pass.opaqueResultCCBs) { + CheckedCastBrRewriter(ccb, pass).rewrite(); + } + // Rewrite this function's return value now that all opaque values within the // function are rewritten. This still depends on a valid ValueStorage // projection operands. diff --git a/lib/SILOptimizer/Mandatory/AddressLowering.h b/lib/SILOptimizer/Mandatory/AddressLowering.h index c75f893385ae0..17ecf558d43e0 100644 --- a/lib/SILOptimizer/Mandatory/AddressLowering.h +++ b/lib/SILOptimizer/Mandatory/AddressLowering.h @@ -256,6 +256,12 @@ class ValueStorageMap { return getNonEnumBaseStorage(getStorage(value)); } + void setStorageAddress(SILValue value, SILValue addr) { + auto &storage = getStorage(value); + assert(!storage.storageAddress || storage.storageAddress == addr); + storage.storageAddress = addr; + } + /// Insert a value in the map, creating a ValueStorage object for it. This /// must be called in RPO order. void insertValue(SILValue value, SILValue storageAddress); diff --git a/test/SILOptimizer/address_lowering.sil b/test/SILOptimizer/address_lowering.sil index 3e3fd1c2c1059..09249c3ef2901 100644 --- a/test/SILOptimizer/address_lowering.sil +++ b/test/SILOptimizer/address_lowering.sil @@ -1245,6 +1245,47 @@ bb3: return %31 : $() } +sil @use_Any : $@convention(thin) (@in Any) -> () + +// CHECK-LABEL: sil [ossa] @test_checked_cast_br3 : $@convention(method) (@owned C) -> () { +// CHECK: bb0(%0 : @owned $C): +// CHECK: [[DST:%.*]] = alloc_stack $Any +// CHECK: [[SRC_TMP:%.*]] = alloc_stack $C +// CHECK: store %0 to [init] [[SRC_TMP]] : $*C +// CHECK: checked_cast_addr_br take_on_success C in [[SRC_TMP]] : $*C to Any in [[DST]] : $*Any, bb2, bb1 +// CHECK: bb1: +// CHECK: [[LD:%.*]] = load [take] [[SRC_TMP]] : $*C +// CHECK: dealloc_stack [[SRC_TMP]] : $*C +// CHECK: destroy_value [[LD]] : $C +// CHECK: br bb3 +// CHECK: bb2: +// CHECK: dealloc_stack [[SRC_TMP]] : $*C +// CHECK: [[FUNC:%.*]] = function_ref @use_Any : $@convention(thin) (@in Any) -> () +// CHECK: apply [[FUNC]]([[DST]]) : $@convention(thin) (@in Any) -> () +// CHECK: br bb3 +// CHECK: bb3: +// CHECK: [[RES:%.*]] = tuple () +// CHECK: dealloc_stack [[DST]] : $*Any +// CHECK: return [[RES]] : $() +// CHECK: } +sil [ossa] @test_checked_cast_br3 : $@convention(method) (@owned C) -> () { +bb0(%0 : @owned $C): + checked_cast_br %0 : $C to Any, bb1, bb2 + +bb1(%3 : @owned $Any): + %f = function_ref @use_Any : $@convention(thin) (@in Any) -> () + %call = apply %f(%3) : $@convention(thin) (@in Any) -> () + br bb3 + +bb2(%4 : @owned $C): + destroy_value %4 : $C + br bb3 + +bb3: + %31 = tuple () + return %31 : $() +} + // CHECK-LABEL: sil hidden [ossa] @test_unchecked_bitwise_cast : // CHECK: bb0(%0 : $*U, %1 : $*T, %2 : $@thick U.Type): // CHECK: [[STK:%.*]] = alloc_stack $T