From fcbfc49e761804e9288f01683b559be5df8acc4c Mon Sep 17 00:00:00 2001 From: Egor Zhdan Date: Thu, 26 Oct 2023 19:31:45 +0100 Subject: [PATCH] [cxx-interop] Allow removing elements from `std::vector` This adds `func remove(at index: Int)` to all instantiations of `std::vector` via an extension for `protocol CxxVector`. The original C++ method `std::vector::erase` is not visible in Swift because all of its overloads return unsafe iterators. rdar://113704853 --- .../ClangDerivedConformances.cpp | 16 ++++++- stdlib/public/Cxx/CxxVector.swift | 22 +++++++++- test/Interop/Cxx/stdlib/use-std-vector.swift | 42 +++++++++++++++++++ 3 files changed, 76 insertions(+), 4 deletions(-) diff --git a/lib/ClangImporter/ClangDerivedConformances.cpp b/lib/ClangImporter/ClangDerivedConformances.cpp index f3a99ca6c2947..37cd7cab260ce 100644 --- a/lib/ClangImporter/ClangDerivedConformances.cpp +++ b/lib/ClangImporter/ClangDerivedConformances.cpp @@ -1003,26 +1003,38 @@ void swift::conformToCxxVectorIfNeeded(ClangImporter::Implementation &impl, decl, ctx.getIdentifier("value_type")); auto iterType = lookupDirectSingleWithoutExtensions( decl, ctx.getIdentifier("const_iterator")); - if (!valueType || !iterType) + auto mutableIterType = lookupDirectSingleWithoutExtensions( + decl, ctx.getIdentifier("iterator")); + if (!valueType || !iterType || !mutableIterType) return; ProtocolDecl *cxxRandomAccessIteratorProto = ctx.getProtocol(KnownProtocolKind::UnsafeCxxRandomAccessIterator); - if (!cxxRandomAccessIteratorProto) + ProtocolDecl *cxxMutableRandomAccessIteratorProto = + ctx.getProtocol(KnownProtocolKind::UnsafeCxxMutableRandomAccessIterator); + if (!cxxRandomAccessIteratorProto || !cxxMutableRandomAccessIteratorProto) return; auto rawIteratorTy = iterType->getUnderlyingType(); + auto rawMutableIteratorTy = mutableIterType->getUnderlyingType(); // Check if RawIterator conforms to UnsafeCxxRandomAccessIterator. if (!checkConformance(rawIteratorTy, cxxRandomAccessIteratorProto)) return; + // Check if RawMutableIterator conforms to UnsafeCxxMutableInputIterator. + if (!checkConformance(rawMutableIteratorTy, + cxxMutableRandomAccessIteratorProto)) + return; + impl.addSynthesizedTypealias(decl, ctx.Id_Element, valueType->getUnderlyingType()); impl.addSynthesizedTypealias(decl, ctx.Id_ArrayLiteralElement, valueType->getUnderlyingType()); impl.addSynthesizedTypealias(decl, ctx.getIdentifier("RawIterator"), rawIteratorTy); + impl.addSynthesizedTypealias(decl, ctx.getIdentifier("RawMutableIterator"), + rawMutableIteratorTy); impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxVector}); } diff --git a/stdlib/public/Cxx/CxxVector.swift b/stdlib/public/Cxx/CxxVector.swift index 03ed6b60ec641..d171489e2acd2 100644 --- a/stdlib/public/Cxx/CxxVector.swift +++ b/stdlib/public/Cxx/CxxVector.swift @@ -17,10 +17,19 @@ public protocol CxxVector: ExpressibleByArrayLiteral { associatedtype Element associatedtype RawIterator: UnsafeCxxRandomAccessIterator where RawIterator.Pointee == Element + associatedtype RawMutableIterator: UnsafeCxxMutableRandomAccessIterator + where RawMutableIterator.Pointee == Element init() + /// Do not implement this function manually in Swift. + mutating func __beginUnsafe() -> RawIterator + mutating func push_back(_ element: Element) + + /// Do not implement this function manually in Swift. + @discardableResult + mutating func __eraseUnsafe(_ iterator: RawIterator) -> RawMutableIterator } extension CxxVector { @@ -37,11 +46,20 @@ extension CxxVector { self.push_back(item) } } -} -extension CxxVector { @inlinable public init(arrayLiteral elements: Element...) { self.init(elements) } + + @discardableResult + @inlinable + public mutating func remove(at index: Int) -> Element { + // Not using CxxIterator here to avoid making a copy of the collection. + var rawIterator = self.__beginUnsafe() + rawIterator += RawIterator.Distance(index) + let element = rawIterator.pointee + self.__eraseUnsafe(rawIterator) + return element + } } diff --git a/test/Interop/Cxx/stdlib/use-std-vector.swift b/test/Interop/Cxx/stdlib/use-std-vector.swift index c2fb18747c95c..8303de852826e 100644 --- a/test/Interop/Cxx/stdlib/use-std-vector.swift +++ b/test/Interop/Cxx/stdlib/use-std-vector.swift @@ -85,6 +85,48 @@ func fill(vector v: inout Vector) { v.push_back(CInt(3)) } +StdVectorTestSuite.test("VectorOfInt.remove(at:)") { + var v = Vector() + fill(vector: &v) + + let rm1 = v.remove(at: 1) + expectEqual(rm1, 2) + expectEqual(v.size(), 2) + expectEqual(v[0], 1) + expectEqual(v[1], 3) + + let rm2 = v.remove(at: 0) + expectEqual(rm2, 1) + expectEqual(v.size(), 1) + expectEqual(v[0], 3) +} + +StdVectorTestSuite.test("VectorOfString.remove(at:)") { + var v = VectorOfString() + v.push_back(std.string()) + v.push_back(std.string("123")) + v.push_back(std.string("abc")) + v.push_back(std.string("qwe")) + + let rm1 = v.remove(at: 3) + expectEqual(rm1, std.string("qwe")) + expectEqual(v.size(), 3) + expectEqual(v[0], std.string()) + expectEqual(v[1], std.string("123")) + expectEqual(v[2], std.string("abc")) + + let rm2 = v.remove(at: 1) + expectEqual(rm2, std.string("123")) + expectEqual(v.size(), 2) + + v.remove(at: 0) + expectEqual(v.size(), 1) + + v.remove(at: 0) + expectEqual(v.size(), 0) + expectTrue(v.empty()) +} + StdVectorTestSuite.test("VectorOfInt for loop") { var v = Vector() fill(vector: &v)