From b6181eadead4d1a091d21be7375c157edffa4579 Mon Sep 17 00:00:00 2001 From: Charles Hu Date: Fri, 16 May 2025 19:47:06 -0700 Subject: [PATCH 01/10] Create platform specific AsyncIO - Darwin: based on DispatchIO - Linux: based on epoll - Windows (not included in this commit): based on IOCP with OVERLAPPED --- Sources/Subprocess/AsyncBufferSequence.swift | 37 +- Sources/Subprocess/Buffer.swift | 47 +- Sources/Subprocess/CMakeLists.txt | 1 + Sources/Subprocess/Configuration.swift | 2 - Sources/Subprocess/Error.swift | 17 +- Sources/Subprocess/IO/AsyncIO.swift | 808 ++++++++++++++++++ Sources/Subprocess/IO/Input.swift | 4 +- Sources/Subprocess/IO/Output.swift | 40 +- .../Platforms/Subprocess+Darwin.swift | 26 + .../Platforms/Subprocess+Linux.swift | 14 + .../Platforms/Subprocess+Unix.swift | 141 --- .../Platforms/Subprocess+Windows.swift | 165 +--- .../Input+Foundation.swift | 50 +- .../_SubprocessCShims/include/process_shims.h | 5 + .../SubprocessTests+Unix.swift | 7 +- 15 files changed, 956 insertions(+), 408 deletions(-) create mode 100644 Sources/Subprocess/IO/AsyncIO.swift diff --git a/Sources/Subprocess/AsyncBufferSequence.swift b/Sources/Subprocess/AsyncBufferSequence.swift index 39fb38b..147e09a 100644 --- a/Sources/Subprocess/AsyncBufferSequence.swift +++ b/Sources/Subprocess/AsyncBufferSequence.swift @@ -23,10 +23,10 @@ public struct AsyncBufferSequence: AsyncSequence, Sendable { public typealias Failure = any Swift.Error public typealias Element = Buffer - #if os(Windows) - internal typealias DiskIO = FileDescriptor - #else + #if canImport(Darwin) internal typealias DiskIO = DispatchIO + #else + internal typealias DiskIO = FileDescriptor #endif @_nonSendable @@ -47,15 +47,16 @@ public struct AsyncBufferSequence: AsyncSequence, Sendable { return self.buffer.removeFirst() } // Read more data - let data = try await self.diskIO.read( - upToLength: readBufferSize + let data = try await AsyncIO.shared.read( + from: self.diskIO, + upTo: readBufferSize ) guard let data else { // We finished reading. Close the file descriptor now - #if os(Windows) - try self.diskIO.close() - #else + #if canImport(Darwin) self.diskIO.close() + #else + try self.diskIO.close() #endif return nil } @@ -132,17 +133,7 @@ extension AsyncBufferSequence { self.eofReached = true return nil } - #if os(Windows) - // Cast data to CodeUnit type - let result = buffer.withUnsafeBytes { ptr in - return Array( - UnsafeBufferPointer( - start: ptr.bindMemory(to: Encoding.CodeUnit.self).baseAddress!, - count: ptr.count / MemoryLayout.size - ) - ) - } - #else + #if canImport(Darwin) // Unfortunately here we _have to_ copy the bytes out because // DispatchIO (rightfully) reuses buffer, which means `buffer.data` // has the same address on all iterations, therefore we can't directly @@ -157,7 +148,13 @@ extension AsyncBufferSequence { UnsafeBufferPointer(start: ptr.baseAddress?.assumingMemoryBound(to: Encoding.CodeUnit.self), count: elementCount) ) } - + #else + // Cast data to CodeUnitg type + let result = buffer.withUnsafeBytes { ptr in + return ptr.withMemoryRebound(to: Encoding.CodeUnit.self) { codeUnitPtr in + return Array(codeUnitPtr) + } + } #endif return result.isEmpty ? nil : result } diff --git a/Sources/Subprocess/Buffer.swift b/Sources/Subprocess/Buffer.swift index 94b8f52..292fac4 100644 --- a/Sources/Subprocess/Buffer.swift +++ b/Sources/Subprocess/Buffer.swift @@ -17,18 +17,8 @@ extension AsyncBufferSequence { /// A immutable collection of bytes public struct Buffer: Sendable { - #if os(Windows) - internal let data: [UInt8] - - internal init(data: [UInt8]) { - self.data = data - } - - internal static func createFrom(_ data: [UInt8]) -> [Buffer] { - return [.init(data: data)] - } - #else - // We need to keep the backingData alive while _ContiguousBufferView is alive + #if canImport(Darwin) + // We need to keep the backingData alive while Slice is alive internal let backingData: DispatchData internal let data: DispatchData._ContiguousBufferView @@ -45,7 +35,17 @@ extension AsyncBufferSequence { } return slices.map{ .init(data: $0, backingData: data) } } - #endif + #else + internal let data: [UInt8] + + internal init(data: [UInt8]) { + self.data = data + } + + internal static func createFrom(_ data: [UInt8]) -> [Buffer] { + return [.init(data: data)] + } + #endif // canImport(Darwin) } } @@ -92,26 +92,23 @@ extension AsyncBufferSequence.Buffer { // MARK: - Hashable, Equatable extension AsyncBufferSequence.Buffer: Equatable, Hashable { - #if os(Windows) - // Compiler generated conformances - #else + #if canImport(Darwin) public static func == (lhs: AsyncBufferSequence.Buffer, rhs: AsyncBufferSequence.Buffer) -> Bool { - return lhs.data.elementsEqual(rhs.data) + return lhs.data == rhs.data } public func hash(into hasher: inout Hasher) { - self.data.withUnsafeBytes { ptr in - hasher.combine(bytes: ptr) - } + hasher.combine(self.data) } #endif + // else Compiler generated conformances } // MARK: - DispatchData.Block #if canImport(Darwin) || canImport(Glibc) || canImport(Android) || canImport(Musl) extension DispatchData { /// Unfortunately `DispatchData.Region` is not available on Linux, hence our own wrapper - internal struct _ContiguousBufferView: @unchecked Sendable, RandomAccessCollection { + internal struct _ContiguousBufferView: @unchecked Sendable, RandomAccessCollection, Hashable { typealias Element = UInt8 internal let bytes: UnsafeBufferPointer @@ -127,6 +124,14 @@ extension DispatchData { return try body(UnsafeRawBufferPointer(self.bytes)) } + internal func hash(into hasher: inout Hasher) { + hasher.combine(bytes: UnsafeRawBufferPointer(self.bytes)) + } + + internal static func == (lhs: DispatchData._ContiguousBufferView, rhs: DispatchData._ContiguousBufferView) -> Bool { + return lhs.bytes.elementsEqual(rhs.bytes) + } + subscript(position: Int) -> UInt8 { _read { yield self.bytes[position] diff --git a/Sources/Subprocess/CMakeLists.txt b/Sources/Subprocess/CMakeLists.txt index 1142445..c136bfe 100644 --- a/Sources/Subprocess/CMakeLists.txt +++ b/Sources/Subprocess/CMakeLists.txt @@ -17,6 +17,7 @@ add_library(Subprocess Result.swift IO/Output.swift IO/Input.swift + IO/AsyncIO.swift Span+Subprocess.swift AsyncBufferSequence.swift API.swift diff --git a/Sources/Subprocess/Configuration.swift b/Sources/Subprocess/Configuration.swift index ef08ca9..9e50469 100644 --- a/Sources/Subprocess/Configuration.swift +++ b/Sources/Subprocess/Configuration.swift @@ -603,7 +603,6 @@ internal struct TrackedFileDescriptor: ~Copyable { self.closeWhenDone = closeWhenDone } - #if os(Windows) consuming func consumeDiskIO() -> FileDescriptor { let result = self.fileDescriptor // Transfer the ownership out and therefor @@ -611,7 +610,6 @@ internal struct TrackedFileDescriptor: ~Copyable { self.closeWhenDone = false return result } - #endif internal mutating func safelyClose() throws { guard self.closeWhenDone else { diff --git a/Sources/Subprocess/Error.swift b/Sources/Subprocess/Error.swift index dde4468..b7a6ca5 100644 --- a/Sources/Subprocess/Error.swift +++ b/Sources/Subprocess/Error.swift @@ -41,6 +41,7 @@ extension SubprocessError { case failedToWriteToSubprocess case failedToMonitorProcess case streamOutputExceedsLimit(Int) + case asyncIOFailed(String) // Signal case failedToSendSignal(Int32) // Windows Only @@ -67,18 +68,20 @@ extension SubprocessError { return 5 case .streamOutputExceedsLimit(_): return 6 - case .failedToSendSignal(_): + case .asyncIOFailed(_): return 7 - case .failedToTerminate: + case .failedToSendSignal(_): return 8 - case .failedToSuspend: + case .failedToTerminate: return 9 - case .failedToResume: + case .failedToSuspend: return 10 - case .failedToCreatePipe: + case .failedToResume: return 11 - case .invalidWindowsPath(_): + case .failedToCreatePipe: return 12 + case .invalidWindowsPath(_): + return 13 } } @@ -108,6 +111,8 @@ extension SubprocessError: CustomStringConvertible, CustomDebugStringConvertible return "Failed to monitor the state of child process with underlying error: \(self.underlyingError!)" case .streamOutputExceedsLimit(let limit): return "Failed to create output from current buffer because the output limit (\(limit)) was reached." + case .asyncIOFailed(let reason): + return "An error occurred within the AsyncIO subsystem: \(reason). Underlying error: \(self.underlyingError!)" case .failedToSendSignal(let signal): return "Failed to send signal \(signal) to the child process." case .failedToTerminate: diff --git a/Sources/Subprocess/IO/AsyncIO.swift b/Sources/Subprocess/IO/AsyncIO.swift new file mode 100644 index 0000000..3f075e5 --- /dev/null +++ b/Sources/Subprocess/IO/AsyncIO.swift @@ -0,0 +1,808 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2025 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +#if canImport(System) +@preconcurrency import System +#else +@preconcurrency import SystemPackage +#endif + +/// Platform specific asynchronous read/write implementation + +// MARK: - Linux (epoll) +#if canImport(Glibc) || canImport(Android) || canImport(Musl) + +#if canImport(Glibc) +import Glibc +#elseif canImport(Android) +import Android +#elseif canImport(Musl) +import Musl +#endif + +import _SubprocessCShims +import Synchronization + +private typealias SignalStream = AsyncThrowingStream +private let _epollEventSize = 256 +private let _registration: Mutex< + [PlatformFileDescriptor : SignalStream.Continuation] +> = Mutex([:]) + +final class AsyncIO: Sendable { + + typealias OutputStream = AsyncThrowingStream + + private final class MonitorThreadContext { + let epollFileDescriptor: CInt + let shutdownFileDescriptor: CInt + + init( + epollFileDescriptor: CInt, + shutdownFileDescriptor: CInt + ) { + self.epollFileDescriptor = epollFileDescriptor + self.shutdownFileDescriptor = shutdownFileDescriptor + } + } + + private enum Event { + case read + case write + } + + private struct State { + let epollFileDescriptor: CInt + let shutdownFileDescriptor: CInt + let monitorThread: pthread_t + } + + static let shared: AsyncIO = AsyncIO() + + private let state: Result + + private init() { + // Create main epoll fd + let epollFileDescriptor = epoll_create1(CInt(EPOLL_CLOEXEC)) + guard epollFileDescriptor >= 0 else { + let error = SubprocessError( + code: .init(.asyncIOFailed("epoll_create1 failed")), + underlyingError: .init(rawValue: errno) + ) + self.state = .failure(error) + return + } + // Create shutdownFileDescriptor + let shutdownFileDescriptor = eventfd(0, CInt(EFD_NONBLOCK | EFD_CLOEXEC)) + guard shutdownFileDescriptor >= 0 else { + let error = SubprocessError( + code: .init(.asyncIOFailed("eventfd failed")), + underlyingError: .init(rawValue: errno) + ) + self.state = .failure(error) + return + } + + // Register shutdownFileDescriptor with epoll + var event = epoll_event( + events: EPOLLIN.rawValue, + data: epoll_data(fd: shutdownFileDescriptor) + ) + var rc = epoll_ctl( + epollFileDescriptor, + EPOLL_CTL_ADD, + shutdownFileDescriptor, + &event + ) + guard rc == 0 else { + let error = SubprocessError( + code: .init(.asyncIOFailed( + "failed to add shutdown fd \(shutdownFileDescriptor) to epoll list") + ), + underlyingError: .init(rawValue: errno) + ) + self.state = .failure(error) + return + } + + // Create thread data + let context = MonitorThreadContext( + epollFileDescriptor: epollFileDescriptor, + shutdownFileDescriptor: shutdownFileDescriptor + ) + let threadContext = Unmanaged.passRetained(context) + #if os(FreeBSD) || os(OpenBSD) + var thread: pthread_t? = nil + #else + var thread: pthread_t = pthread_t() + #endif + rc = pthread_create(&thread, nil, { args in + func reportError(_ error: SubprocessError) { + _registration.withLock { store in + for continuation in store.values { + continuation.finish(throwing: error) + } + } + } + + let unmanaged = Unmanaged.fromOpaque(args!) + let context = unmanaged.takeRetainedValue() + + var events: [epoll_event] = Array( + repeating: epoll_event(events: 0, data: epoll_data(fd: 0)), + count: _epollEventSize + ) + + // Enter the monitor loop + monitorLoop: while true { + let eventCount = epoll_wait( + context.epollFileDescriptor, + &events, + CInt(events.count), + -1 + ) + if eventCount < 0 { + if errno == EINTR || errno == EAGAIN { + continue // interrupted by signal; try again + } + // Report other errors + let error = SubprocessError( + code: .init(.asyncIOFailed( + "epoll_wait failed") + ), + underlyingError: .init(rawValue: errno) + ) + reportError(error) + break monitorLoop + } + + for index in 0 ..< Int(eventCount) { + let event = events[index] + let targetFileDescriptor = event.data.fd + // Breakout the monitor loop if we received shutdown + // from the shutdownFD + if targetFileDescriptor == context.shutdownFileDescriptor { + var buf: UInt64 = 0 + _ = _SubprocessCShims.read(context.shutdownFileDescriptor, &buf, MemoryLayout.size) + break monitorLoop + } + + // Notify the continuation + _registration.withLock { store in + if let continuation = store[targetFileDescriptor] { + continuation.yield(true) + } + } + } + } + + return nil + }, threadContext.toOpaque()) + guard rc == 0 else { + let error = SubprocessError( + code: .init(.asyncIOFailed("Failed to create monitor thread")), + underlyingError: .init(rawValue: rc) + ) + self.state = .failure(error) + return + } + + #if os(FreeBSD) || os(OpenBSD) + let monitorThread = thread! + #else + let monitorThread = thread + #endif + + let state = State( + epollFileDescriptor: epollFileDescriptor, + shutdownFileDescriptor: shutdownFileDescriptor, + monitorThread: monitorThread + ) + self.state = .success(state) + + atexit { + AsyncIO.shared.shutdown() + } + } + + private func shutdown() { + guard case .success(let currentState) = self.state else { + return + } + + var one: UInt64 = 1 + // Wake up the thread for shutdown + _ = _SubprocessCShims.write(currentState.shutdownFileDescriptor, &one, MemoryLayout.stride) + // Cleanup the monitor thread + pthread_join(currentState.monitorThread, nil) + } + + + private func registerFileDescriptor( + _ fileDescriptor: FileDescriptor, + for event: Event + ) -> SignalStream { + return SignalStream { continuation in + // If setup failed, nothing much we can do + switch self.state { + case .success(let state): + // Set file descriptor to be non blocking + let flags = fcntl(fileDescriptor.rawValue, F_GETFD) + guard flags != -1 else { + let error = SubprocessError( + code: .init(.asyncIOFailed( + "failed to get flags for \(fileDescriptor.rawValue)") + ), + underlyingError: .init(rawValue: errno) + ) + continuation.finish(throwing: error) + return + } + guard fcntl(fileDescriptor.rawValue, F_SETFL, flags | O_NONBLOCK) != -1 else { + let error = SubprocessError( + code: .init(.asyncIOFailed( + "failed to set \(fileDescriptor.rawValue) to be non-blocking") + ), + underlyingError: .init(rawValue: errno) + ) + continuation.finish(throwing: error) + return + } + // Register event + let targetEvent: EPOLL_EVENTS + switch event { + case .read: + targetEvent = EPOLLIN + case .write: + targetEvent = EPOLLOUT + } + + var event = epoll_event( + events: targetEvent.rawValue, + data: epoll_data(fd: fileDescriptor.rawValue) + ) + let rc = epoll_ctl( + state.epollFileDescriptor, + EPOLL_CTL_ADD, + fileDescriptor.rawValue, + &event + ) + if rc != 0 { + let error = SubprocessError( + code: .init(.asyncIOFailed( + "failed to add \(fileDescriptor.rawValue) to epoll list") + ), + underlyingError: .init(rawValue: errno) + ) + continuation.finish(throwing: error) + return + } + // Now save the continuation + _registration.withLock { storage in + storage[fileDescriptor.rawValue] = continuation + } + case .failure(let setupError): + continuation.finish(throwing: setupError) + return + } + } + } + + private func removeRegistration(for fileDescriptor: FileDescriptor) throws { + switch self.state { + case .success(let state): + let rc = epoll_ctl( + state.epollFileDescriptor, + EPOLL_CTL_DEL, + fileDescriptor.rawValue, + nil + ) + guard rc == 0 else { + throw SubprocessError( + code: .init(.asyncIOFailed( + "failed to remove \(fileDescriptor.rawValue) to epoll list") + ), + underlyingError: .init(rawValue: errno) + ) + } + _registration.withLock { store in + _ = store.removeValue(forKey: fileDescriptor.rawValue) + } + case .failure(let setupFailure): + throw setupFailure + } + } +} + +extension AsyncIO { + + protocol _ContiguousBytes { + var count: Int { get } + + func withUnsafeBytes( + _ body: (UnsafeRawBufferPointer + ) throws -> ResultType) rethrows -> ResultType + } + + func read( + from diskIO: borrowing TrackedPlatformDiskIO, + upTo maxLength: Int + ) async throws -> [UInt8]? { + return try await self.read(from: diskIO.fileDescriptor, upTo: maxLength) + } + + func read( + from fileDescriptor: FileDescriptor, + upTo maxLength: Int + ) async throws -> [UInt8]? { + // If we are reading until EOF, start with readBufferSize + // and gradually increase buffer size + let bufferLength = maxLength == .max ? readBufferSize : maxLength + + var resultBuffer: [UInt8] = Array( + repeating: 0, count: bufferLength + ) + var readLength: Int = 0 + let signalStream = self.registerFileDescriptor(fileDescriptor, for: .read) + for try await _ in signalStream { + // Every iteration signals we are ready to read more data + while true { + let bytesRead = resultBuffer.withUnsafeMutableBufferPointer { bufferPointer in + // Get a pointer to the memory at the specified offset + let targetCount = bufferPointer.count - readLength + + let offsetAddress = bufferPointer.baseAddress!.advanced(by: readLength) + + // Read directly into the buffer at the offset + return _SubprocessCShims.read(fileDescriptor.rawValue, offsetAddress, targetCount) + } + if bytesRead > 0 { + // Read some data + readLength += bytesRead + if maxLength == .max { + // Grow resultBuffer if needed + guard Double(readLength) > 0.8 * Double(resultBuffer.count) else { + continue + } + resultBuffer.append( + contentsOf: Array(repeating: 0, count: resultBuffer.count) + ) + } else if readLength >= maxLength { + // When we reached maxLength, return! + try self.removeRegistration(for: fileDescriptor) + return resultBuffer + } + } else if bytesRead == 0 { + // We reached EOF. Return whatever's left + try self.removeRegistration(for: fileDescriptor) + guard readLength > 0 else { + return nil + } + resultBuffer.removeLast(resultBuffer.count - readLength) + return resultBuffer + } else { + if errno == EAGAIN || errno == EWOULDBLOCK { + // No more data for now wait for the next signal + break + } else { + // Throw all other errors + try self.removeRegistration(for: fileDescriptor) + throw SubprocessError.UnderlyingError(rawValue: errno) + } + } + } + } + return resultBuffer + } + + func write( + _ array: [UInt8], + to diskIO: borrowing TrackedPlatformDiskIO + ) async throws -> Int { + return try await self._write(array, to: diskIO) + } + + func _write( + _ bytes: Bytes, + to diskIO: borrowing TrackedPlatformDiskIO + ) async throws -> Int { + let fileDescriptor = diskIO.fileDescriptor + let signalStream = self.registerFileDescriptor(fileDescriptor, for: .write) + var writtenLength: Int = 0 + for try await _ in signalStream { + while true { + let written = bytes.withUnsafeBytes { ptr in + let remainingLength = ptr.count - writtenLength + let startPtr = ptr.baseAddress!.advanced(by: writtenLength) + return _SubprocessCShims.write(fileDescriptor.rawValue, startPtr, remainingLength) + } + if written > 0 { + writtenLength += written + if writtenLength >= bytes.count { + // Wrote all data + try self.removeRegistration(for: fileDescriptor) + return writtenLength + } + } else { + if errno == EAGAIN || errno == EWOULDBLOCK { + // No more data for now wait for the next signal + break + } else { + // Throw all other errors + try self.removeRegistration(for: fileDescriptor) + throw SubprocessError.UnderlyingError(rawValue: errno) + } + } + } + } + return 0 + } + + #if SubprocessSpan + func write( + _ span: borrowing RawSpan, + to diskIO: borrowing TrackedPlatformDiskIO + ) async throws -> Int { + let fileDescriptor = diskIO.fileDescriptor + let signalStream = self.registerFileDescriptor(fileDescriptor, for: .write) + var writtenLength: Int = 0 + for try await _ in signalStream { + while true { + let written = span.withUnsafeBytes { ptr in + let remainingLength = ptr.count - writtenLength + let startPtr = ptr.baseAddress!.advanced(by: writtenLength) + return _SubprocessCShims.write(fileDescriptor.rawValue, startPtr, remainingLength) + } + if written > 0 { + writtenLength += written + if writtenLength >= span.byteCount { + // Wrote all data + try self.removeRegistration(for: fileDescriptor) + return writtenLength + } + } else { + if errno == EAGAIN || errno == EWOULDBLOCK { + // No more data for now wait for the next signal + break + } else { + // Throw all other errors + try self.removeRegistration(for: fileDescriptor) + throw SubprocessError.UnderlyingError(rawValue: errno) + } + } + } + } + return 0 + } + #endif +} + +extension Array : AsyncIO._ContiguousBytes where Element == UInt8 {} + +#endif // canImport(Glibc) || canImport(Android) || canImport(Musl) + +// MARK: - macOS (DispatchIO) +#if canImport(Darwin) + +internal import Dispatch + + +final class AsyncIO: Sendable { + static let shared: AsyncIO = AsyncIO() + + private init() {} + + internal func read( + from diskIO: borrowing TrackedPlatformDiskIO, + upTo maxLength: Int + ) async throws -> DispatchData? { + return try await self.read( + from: diskIO.dispatchIO, + upTo: maxLength, + ) + } + + internal func read( + from dispatchIO: DispatchIO, + upTo maxLength: Int + ) async throws -> DispatchData? { + return try await withCheckedThrowingContinuation { continuation in + var buffer: DispatchData = .empty + dispatchIO.read( + offset: 0, + length: maxLength, + queue: .global() + ) { done, data, error in + if error != 0 { + continuation.resume( + throwing: SubprocessError( + code: .init(.failedToReadFromSubprocess), + underlyingError: .init(rawValue: error) + ) + ) + return + } + if let data = data { + if buffer.isEmpty { + buffer = data + } else { + buffer.append(data) + } + } + if done { + if !buffer.isEmpty { + continuation.resume(returning: buffer) + } else { + continuation.resume(returning: nil) + } + } + } + } + } + + #if SubprocessSpan + internal func write( + _ span: borrowing RawSpan, + to diskIO: borrowing TrackedPlatformDiskIO + ) async throws -> Int { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let dispatchData = span.withUnsafeBytes { + return DispatchData( + bytesNoCopy: $0, + deallocator: .custom( + nil, + { + // noop + } + ) + ) + } + self.write(dispatchData, to: diskIO) { writtenLength, error in + if let error = error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: writtenLength) + } + } + } + } + #endif // SubprocessSpan + + internal func write( + _ array: [UInt8], + to diskIO: borrowing TrackedPlatformDiskIO + ) async throws -> Int { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let dispatchData = array.withUnsafeBytes { + return DispatchData( + bytesNoCopy: $0, + deallocator: .custom( + nil, + { + // noop + } + ) + ) + } + self.write(dispatchData, to: diskIO) { writtenLength, error in + if let error = error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: writtenLength) + } + } + } + } + + internal func write( + _ dispatchData: DispatchData, + to diskIO: borrowing TrackedPlatformDiskIO, + queue: DispatchQueue = .global(), + completion: @escaping (Int, Error?) -> Void + ) { + diskIO.dispatchIO.write( + offset: 0, + data: dispatchData, + queue: queue + ) { done, unwritten, error in + guard done else { + // Wait until we are done writing or encountered some error + return + } + + let unwrittenLength = unwritten?.count ?? 0 + let writtenLength = dispatchData.count - unwrittenLength + guard error != 0 else { + completion(writtenLength, nil) + return + } + completion( + writtenLength, + SubprocessError( + code: .init(.failedToWriteToSubprocess), + underlyingError: .init(rawValue: error) + ) + ) + } + } +} + +#endif + +// MARK: - Windows (I/O Completion Ports) TODO +#if os(Windows) + +internal import Dispatch +import WinSDK + +final class AsyncIO: Sendable { + + protocol _ContiguousBytes: Sendable { + var count: Int { get } + + func withUnsafeBytes( + _ body: (UnsafeRawBufferPointer + ) throws -> ResultType) rethrows -> ResultType + } + + static let shared = AsyncIO() + + private init() {} + + func read( + from diskIO: borrowing TrackedPlatformDiskIO, + upTo maxLength: Int + ) async throws -> [UInt8]? { + return try await self.read(from: diskIO.fileDescriptor, upTo: maxLength) + } + + func read( + from fileDescriptor: FileDescriptor, + upTo maxLength: Int + ) async throws -> [UInt8]? { + return try await withCheckedThrowingContinuation { continuation in + DispatchQueue.global(qos: .userInitiated).async { + var totalBytesRead: Int = 0 + var lastError: DWORD? = nil + let values = [UInt8]( + unsafeUninitializedCapacity: maxLength + ) { buffer, initializedCount in + while true { + guard let baseAddress = buffer.baseAddress else { + initializedCount = 0 + break + } + let bufferPtr = baseAddress.advanced(by: totalBytesRead) + var bytesRead: DWORD = 0 + let readSucceed = ReadFile( + fileDescriptor.platformDescriptor, + UnsafeMutableRawPointer(mutating: bufferPtr), + DWORD(maxLength - totalBytesRead), + &bytesRead, + nil + ) + if !readSucceed { + // Windows throws ERROR_BROKEN_PIPE when the pipe is closed + let error = GetLastError() + if error == ERROR_BROKEN_PIPE { + // We are done reading + initializedCount = totalBytesRead + } else { + // We got some error + lastError = error + initializedCount = 0 + } + break + } else { + // We successfully read the current round + totalBytesRead += Int(bytesRead) + } + + if totalBytesRead >= maxLength { + initializedCount = min(maxLength, totalBytesRead) + break + } + } + } + if let lastError = lastError { + let windowsError = SubprocessError( + code: .init(.failedToReadFromSubprocess), + underlyingError: .init(rawValue: lastError) + ) + continuation.resume(throwing: windowsError) + } else { + continuation.resume(returning: values) + } + } + } + } + + func write( + _ array: [UInt8], + to diskIO: borrowing TrackedPlatformDiskIO + ) async throws -> Int { + return try await self._write(array, to: diskIO) + } + + #if SubprocessSpan + func write( + _ span: borrowing RawSpan, + to diskIO: borrowing TrackedPlatformDiskIO + ) async throws -> Int { + // TODO: Remove this hack with I/O Completion Ports rewrite + struct _Box: @unchecked Sendable { + let ptr: UnsafeRawBufferPointer + } + let fileDescriptor = diskIO.fileDescriptor + return try await withCheckedThrowingContinuation { continuation in + span.withUnsafeBytes { ptr in + let box = _Box(ptr: ptr) + DispatchQueue.global().async { + let handle = HANDLE(bitPattern: _get_osfhandle(fileDescriptor.rawValue))! + var writtenBytes: DWORD = 0 + let writeSucceed = WriteFile( + handle, + box.ptr.baseAddress, + DWORD(box.ptr.count), + &writtenBytes, + nil + ) + if !writeSucceed { + let error = SubprocessError( + code: .init(.failedToWriteToSubprocess), + underlyingError: .init(rawValue: GetLastError()) + ) + continuation.resume(throwing: error) + } else { + continuation.resume(returning: Int(writtenBytes)) + } + } + } + } + } + #endif // SubprocessSpan + + func _write( + _ bytes: Bytes, + to diskIO: borrowing TrackedPlatformDiskIO + ) async throws -> Int { + let fileDescriptor = diskIO.fileDescriptor + return try await withCheckedThrowingContinuation { continuation in + DispatchQueue.global().async { + let handle = HANDLE(bitPattern: _get_osfhandle(fileDescriptor.rawValue))! + var writtenBytes: DWORD = 0 + let writeSucceed = bytes.withUnsafeBytes { ptr in + return WriteFile( + handle, + ptr.baseAddress, + DWORD(ptr.count), + &writtenBytes, + nil + ) + } + if !writeSucceed { + let error = SubprocessError( + code: .init(.failedToWriteToSubprocess), + underlyingError: .init(rawValue: GetLastError()) + ) + continuation.resume(throwing: error) + } else { + continuation.resume(returning: Int(writtenBytes)) + } + } + } + } +} + +extension Array : AsyncIO._ContiguousBytes where Element == UInt8 {} + +#endif + diff --git a/Sources/Subprocess/IO/Input.swift b/Sources/Subprocess/IO/Input.swift index 58bfe4d..3f473e4 100644 --- a/Sources/Subprocess/IO/Input.swift +++ b/Sources/Subprocess/IO/Input.swift @@ -224,7 +224,7 @@ public final actor StandardInputWriter: Sendable { public func write( _ array: [UInt8] ) async throws -> Int { - return try await self.diskIO.write(array) + return try await AsyncIO.shared.write(array, to: self.diskIO) } /// Write a `RawSpan` to the standard input of the subprocess. @@ -232,7 +232,7 @@ public final actor StandardInputWriter: Sendable { /// - Returns number of bytes written #if SubprocessSpan public func write(_ span: borrowing RawSpan) async throws -> Int { - return try await self.diskIO.write(span) + return try await AsyncIO.shared.write(span, to: self.diskIO) } #endif diff --git a/Sources/Subprocess/IO/Output.swift b/Sources/Subprocess/IO/Output.swift index 454eca4..563a73a 100644 --- a/Sources/Subprocess/IO/Output.swift +++ b/Sources/Subprocess/IO/Output.swift @@ -142,14 +142,12 @@ public struct BytesOutput: OutputProtocol { internal func captureOutput( from diskIO: consuming TrackedPlatformDiskIO? ) async throws -> [UInt8] { - #if os(Windows) - let result = try await diskIO?.fileDescriptor.read(upToLength: self.maxSize) ?? [] - try diskIO?.safelyClose() - return result - #else - let result = try await diskIO!.dispatchIO.read(upToLength: self.maxSize) + let result = try await AsyncIO.shared.read(from: diskIO!, upTo: self.maxSize) try diskIO?.safelyClose() + #if canImport(Darwin) return result?.array() ?? [] + #else + return result ?? [] #endif } @@ -264,14 +262,14 @@ extension OutputProtocol { if OutputType.self == Void.self { return () as! OutputType } - #if os(Windows) - let result = try await diskIO?.fileDescriptor.read(upToLength: self.maxSize) - try diskIO?.safelyClose() - return try self.output(from: result ?? []) - #else - let result = try await diskIO!.dispatchIO.read(upToLength: self.maxSize) + // Force unwrap is safe here because only `OutputType.self == Void` would + // have nil `TrackedPlatformDiskIO` + let result = try await AsyncIO.shared.read(from: diskIO!, upTo: self.maxSize) try diskIO?.safelyClose() + #if canImport(Darwin) return try self.output(from: result ?? .empty) + #else + return try self.output(from: result ?? []) #endif } } @@ -293,34 +291,34 @@ extension OutputProtocol where OutputType == Void { #if SubprocessSpan extension OutputProtocol { - #if os(Windows) - internal func output(from data: [UInt8]) throws -> OutputType { + #if canImport(Darwin) + internal func output(from data: DispatchData) throws -> OutputType { guard !data.isEmpty else { let empty = UnsafeRawBufferPointer(start: nil, count: 0) let span = RawSpan(_unsafeBytes: empty) return try self.output(from: span) } - return try data.withUnsafeBufferPointer { ptr in - let span = RawSpan(_unsafeBytes: UnsafeRawBufferPointer(ptr)) + return try data.withUnsafeBytes { ptr in + let bufferPtr = UnsafeRawBufferPointer(start: ptr, count: data.count) + let span = RawSpan(_unsafeBytes: bufferPtr) return try self.output(from: span) } } #else - internal func output(from data: DispatchData) throws -> OutputType { + internal func output(from data: [UInt8]) throws -> OutputType { guard !data.isEmpty else { let empty = UnsafeRawBufferPointer(start: nil, count: 0) let span = RawSpan(_unsafeBytes: empty) return try self.output(from: span) } - return try data.withUnsafeBytes { ptr in - let bufferPtr = UnsafeRawBufferPointer(start: ptr, count: data.count) - let span = RawSpan(_unsafeBytes: bufferPtr) + return try data.withUnsafeBufferPointer { ptr in + let span = RawSpan(_unsafeBytes: UnsafeRawBufferPointer(ptr)) return try self.output(from: span) } } - #endif // os(Windows) + #endif // canImport(Darwin) } #endif diff --git a/Sources/Subprocess/Platforms/Subprocess+Darwin.swift b/Sources/Subprocess/Platforms/Subprocess+Darwin.swift index 7cd32c7..02640ab 100644 --- a/Sources/Subprocess/Platforms/Subprocess+Darwin.swift +++ b/Sources/Subprocess/Platforms/Subprocess+Darwin.swift @@ -524,4 +524,30 @@ internal func monitorProcessTermination( } } +internal typealias TrackedPlatformDiskIO = TrackedDispatchIO + +extension TrackedFileDescriptor { + internal consuming func createPlatformDiskIO() -> TrackedPlatformDiskIO { + // Transferring out the ownership of fileDescriptor means we don't have go close here + let shouldClose = self.closeWhenDone + let closeFd = self.fileDescriptor + let dispatchIO: DispatchIO = DispatchIO( + type: .stream, + fileDescriptor: self.platformDescriptor(), + queue: .global(), + cleanupHandler: { error in + // Close the file descriptor + if shouldClose { + try? closeFd.close() + } + } + ) + let result: TrackedPlatformDiskIO = .init( + dispatchIO, closeWhenDone: self.closeWhenDone + ) + self.closeWhenDone = false + return result + } +} + #endif // canImport(Darwin) diff --git a/Sources/Subprocess/Platforms/Subprocess+Linux.swift b/Sources/Subprocess/Platforms/Subprocess+Linux.swift index bf1e5b8..47fbe96 100644 --- a/Sources/Subprocess/Platforms/Subprocess+Linux.swift +++ b/Sources/Subprocess/Platforms/Subprocess+Linux.swift @@ -423,4 +423,18 @@ private func _setupMonitorSignalHandler() { setup } +internal typealias TrackedPlatformDiskIO = TrackedFileDescriptor + +extension TrackedFileDescriptor { + internal consuming func createPlatformDiskIO() -> TrackedPlatformDiskIO { + // Transferring out the ownership of fileDescriptor means we don't have go close here + let result: TrackedPlatformDiskIO = .init( + self.fileDescriptor, + closeWhenDone: self.closeWhenDone + ) + self.closeWhenDone = false + return result + } +} + #endif // canImport(Glibc) || canImport(Android) || canImport(Musl) diff --git a/Sources/Subprocess/Platforms/Subprocess+Unix.swift b/Sources/Subprocess/Platforms/Subprocess+Unix.swift index bb5da2f..122e5c5 100644 --- a/Sources/Subprocess/Platforms/Subprocess+Unix.swift +++ b/Sources/Subprocess/Platforms/Subprocess+Unix.swift @@ -375,146 +375,5 @@ extension FileDescriptor { } internal typealias PlatformFileDescriptor = CInt -internal typealias TrackedPlatformDiskIO = TrackedDispatchIO - -extension TrackedFileDescriptor { - internal consuming func createPlatformDiskIO() -> TrackedPlatformDiskIO { - let dispatchIO: DispatchIO = DispatchIO( - type: .stream, - fileDescriptor: self.platformDescriptor(), - queue: .global(), - cleanupHandler: { error in - // Close the file descriptor - if self.closeWhenDone { - try? self.safelyClose() - } - } - ) - return .init(dispatchIO, closeWhenDone: self.closeWhenDone) - } -} - -// MARK: - TrackedDispatchIO extensions -extension DispatchIO { - internal func read(upToLength maxLength: Int) async throws -> DispatchData? { - return try await withCheckedThrowingContinuation { continuation in - var buffer: DispatchData = .empty - self.read( - offset: 0, - length: maxLength, - queue: .global() - ) { done, data, error in - if error != 0 { - continuation.resume( - throwing: SubprocessError( - code: .init(.failedToReadFromSubprocess), - underlyingError: .init(rawValue: error) - ) - ) - return - } - if let data = data { - if buffer.isEmpty { - buffer = data - } else { - buffer.append(data) - } - } - if done { - if !buffer.isEmpty { - continuation.resume(returning: buffer) - } else { - continuation.resume(returning: nil) - } - } - } - } - } -} - -extension TrackedDispatchIO { - #if SubprocessSpan - internal func write( - _ span: borrowing RawSpan - ) async throws -> Int { - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - let dispatchData = span.withUnsafeBytes { - return DispatchData( - bytesNoCopy: $0, - deallocator: .custom( - nil, - { - // noop - } - ) - ) - } - self.write(dispatchData) { writtenLength, error in - if let error = error { - continuation.resume(throwing: error) - } else { - continuation.resume(returning: writtenLength) - } - } - } - } - #endif // SubprocessSpan - - internal func write( - _ array: [UInt8] - ) async throws -> Int { - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - let dispatchData = array.withUnsafeBytes { - return DispatchData( - bytesNoCopy: $0, - deallocator: .custom( - nil, - { - // noop - } - ) - ) - } - self.write(dispatchData) { writtenLength, error in - if let error = error { - continuation.resume(throwing: error) - } else { - continuation.resume(returning: writtenLength) - } - } - } - } - - internal func write( - _ dispatchData: DispatchData, - queue: DispatchQueue = .global(), - completion: @escaping (Int, Error?) -> Void - ) { - self.dispatchIO.write( - offset: 0, - data: dispatchData, - queue: queue - ) { done, unwritten, error in - guard done else { - // Wait until we are done writing or encountered some error - return - } - - let unwrittenLength = unwritten?.count ?? 0 - let writtenLength = dispatchData.count - unwrittenLength - guard error != 0 else { - completion(writtenLength, nil) - return - } - completion( - writtenLength, - SubprocessError( - code: .init(.failedToWriteToSubprocess), - underlyingError: .init(rawValue: error) - ) - ) - } - } -} #endif // canImport(Darwin) || canImport(Glibc) || canImport(Android) || canImport(Musl) diff --git a/Sources/Subprocess/Platforms/Subprocess+Windows.swift b/Sources/Subprocess/Platforms/Subprocess+Windows.swift index 3dfb00c..723a662 100644 --- a/Sources/Subprocess/Platforms/Subprocess+Windows.swift +++ b/Sources/Subprocess/Platforms/Subprocess+Windows.swift @@ -1004,168 +1004,15 @@ extension FileDescriptor { } } -extension FileDescriptor { - internal func read(upToLength maxLength: Int) async throws -> [UInt8]? { - return try await withCheckedThrowingContinuation { continuation in - self.readUntilEOF( - upToLength: maxLength - ) { result in - switch result { - case .failure(let error): - continuation.resume(throwing: error) - case .success(let bytes): - continuation.resume(returning: bytes.isEmpty ? nil : bytes) - } - } - } - } - - internal func readUntilEOF( - upToLength maxLength: Int, - resultHandler: @Sendable @escaping (Swift.Result<[UInt8], any (Error & Sendable)>) -> Void - ) { - DispatchQueue.global(qos: .userInitiated).async { - var totalBytesRead: Int = 0 - var lastError: DWORD? = nil - let values = [UInt8]( - unsafeUninitializedCapacity: maxLength - ) { buffer, initializedCount in - while true { - guard let baseAddress = buffer.baseAddress else { - initializedCount = 0 - break - } - let bufferPtr = baseAddress.advanced(by: totalBytesRead) - var bytesRead: DWORD = 0 - let readSucceed = ReadFile( - self.platformDescriptor, - UnsafeMutableRawPointer(mutating: bufferPtr), - DWORD(maxLength - totalBytesRead), - &bytesRead, - nil - ) - if !readSucceed { - // Windows throws ERROR_BROKEN_PIPE when the pipe is closed - let error = GetLastError() - if error == ERROR_BROKEN_PIPE { - // We are done reading - initializedCount = totalBytesRead - } else { - // We got some error - lastError = error - initializedCount = 0 - } - break - } else { - // We successfully read the current round - totalBytesRead += Int(bytesRead) - } - - if totalBytesRead >= maxLength { - initializedCount = min(maxLength, totalBytesRead) - break - } - } - } - if let lastError = lastError { - let windowsError = SubprocessError( - code: .init(.failedToReadFromSubprocess), - underlyingError: .init(rawValue: lastError) - ) - resultHandler(.failure(windowsError)) - } else { - resultHandler(.success(values)) - } - } - } -} - extension TrackedFileDescriptor { internal consuming func createPlatformDiskIO() -> TrackedPlatformDiskIO { - // TrackedPlatformDiskIO is a typealias of TrackedFileDescriptor on Windows (they're the same type) - // Just return the same object so we don't create a copy and try to double-close the fd. - return self - } - - internal func readUntilEOF( - upToLength maxLength: Int, - resultHandler: @Sendable @escaping (Swift.Result<[UInt8], any (Error & Sendable)>) -> Void - ) { - self.fileDescriptor.readUntilEOF( - upToLength: maxLength, - resultHandler: resultHandler + // Transferring out the ownership of fileDescriptor means we don't have go close here + let result: TrackedPlatformDiskIO = .init( + self.fileDescriptor, + closeWhenDone: self.closeWhenDone ) - } - -#if SubprocessSpan - internal func write( - _ span: borrowing RawSpan - ) async throws -> Int { - let fileDescriptor = self.fileDescriptor - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - span.withUnsafeBytes { ptr in - // TODO: Use WriteFileEx for asyc here - Self.write( - ptr, - to: fileDescriptor - ) { writtenLength, error in - if let error = error { - continuation.resume(throwing: error) - } else { - continuation.resume(returning: writtenLength) - } - } - } - } - } -#endif - - internal func write( - _ array: [UInt8] - ) async throws -> Int { - try await withCheckedThrowingContinuation { continuation in - // TODO: Figure out a better way to asynchronously write - let fd = self.fileDescriptor - DispatchQueue.global(qos: .userInitiated).async { - array.withUnsafeBytes { - Self.write( - $0, - to: fd - ) { writtenLength, error in - if let error = error { - continuation.resume(throwing: error) - } else { - continuation.resume(returning: writtenLength) - } - } - } - } - } - } - - internal static func write( - _ ptr: UnsafeRawBufferPointer, - to fileDescriptor: FileDescriptor, - completion: @escaping (Int, Swift.Error?) -> Void - ) { - let handle = HANDLE(bitPattern: _get_osfhandle(fileDescriptor.rawValue))! - var writtenBytes: DWORD = 0 - let writeSucceed = WriteFile( - handle, - ptr.baseAddress, - DWORD(ptr.count), - &writtenBytes, - nil - ) - if !writeSucceed { - let error = SubprocessError( - code: .init(.failedToWriteToSubprocess), - underlyingError: .init(rawValue: GetLastError()) - ) - completion(Int(writtenBytes), error) - } else { - completion(Int(writtenBytes), nil) - } + self.closeWhenDone = false + return result } } diff --git a/Sources/Subprocess/SubprocessFoundation/Input+Foundation.swift b/Sources/Subprocess/SubprocessFoundation/Input+Foundation.swift index a1db8ad..8312c56 100644 --- a/Sources/Subprocess/SubprocessFoundation/Input+Foundation.swift +++ b/Sources/Subprocess/SubprocessFoundation/Input+Foundation.swift @@ -111,7 +111,7 @@ extension StandardInputWriter { public func write( _ data: Data ) async throws -> Int { - return try await self.diskIO.write(data) + return try await AsyncIO.shared.write(data, to: self.diskIO) } /// Write a AsyncSequence of Data to the standard input of the subprocess. @@ -128,35 +128,12 @@ extension StandardInputWriter { } } -#if os(Windows) -extension TrackedFileDescriptor { - internal func write( - _ data: Data - ) async throws -> Int { - let fileDescriptor = self.fileDescriptor - return try await withCheckedThrowingContinuation { continuation in - // TODO: Figure out a better way to asynchronously write - DispatchQueue.global(qos: .userInitiated).async { - data.withUnsafeBytes { - Self.write( - $0, - to: fileDescriptor - ) { writtenLength, error in - if let error = error { - continuation.resume(throwing: error) - } else { - continuation.resume(returning: writtenLength) - } - } - } - } - } - } -} -#else -extension TrackedDispatchIO { + +#if canImport(Darwin) +extension AsyncIO { internal func write( - _ data: Data + _ data: Data, + to diskIO: borrowing TrackedPlatformDiskIO ) async throws -> Int { try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in let dispatchData = data.withUnsafeBytes { @@ -170,7 +147,7 @@ extension TrackedDispatchIO { ) ) } - self.write(dispatchData) { writtenLength, error in + self.write(dispatchData, to: diskIO) { writtenLength, error in if let error = error { continuation.resume(throwing: error) } else { @@ -180,6 +157,17 @@ extension TrackedDispatchIO { } } } -#endif // os(Windows) +#else +extension Data : AsyncIO._ContiguousBytes { } + +extension AsyncIO { + internal func write( + _ data: Data, + to diskIO: borrowing TrackedPlatformDiskIO + ) async throws -> Int { + return try await self._write(data, to: diskIO) + } +} +#endif // canImport(Darwin) #endif // SubprocessFoundation diff --git a/Sources/_SubprocessCShims/include/process_shims.h b/Sources/_SubprocessCShims/include/process_shims.h index 35cbd2f..0ae4a5a 100644 --- a/Sources/_SubprocessCShims/include/process_shims.h +++ b/Sources/_SubprocessCShims/include/process_shims.h @@ -21,6 +21,11 @@ #include #endif +#if TARGET_OS_LINUX +#include +#include +#endif // TARGET_OS_LINUX + #if __has_include() vm_size_t _subprocess_vm_size(void); #endif diff --git a/Tests/SubprocessTests/SubprocessTests+Unix.swift b/Tests/SubprocessTests/SubprocessTests+Unix.swift index a50c4e8..25a464e 100644 --- a/Tests/SubprocessTests/SubprocessTests+Unix.swift +++ b/Tests/SubprocessTests/SubprocessTests+Unix.swift @@ -668,14 +668,11 @@ extension SubprocessUnixTests { var platformOptions = PlatformOptions() platformOptions.supplementaryGroups = Array(expectedGroups) let idResult = try await Subprocess.run( - .name("swift"), + .path("/usr/bin/swift"), arguments: [getgroupsSwift.string], platformOptions: platformOptions, - output: .string, - error: .string, + output: .string ) - let error = try #require(idResult.standardError) - try #require(error == "") #expect(idResult.terminationStatus.isSuccess) let ids = try #require( idResult.standardOutput From e374669f6a97a214a6e5d8548d7263af381e55cc Mon Sep 17 00:00:00 2001 From: Charles Hu Date: Wed, 2 Jul 2025 22:09:00 -0700 Subject: [PATCH 02/10] Refactor captureOutput() to minimize force unwrap --- Sources/Subprocess/IO/AsyncIO.swift | 20 +++++++++++++++++++- Sources/Subprocess/IO/Output.swift | 27 ++++++++++++++++++--------- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/Sources/Subprocess/IO/AsyncIO.swift b/Sources/Subprocess/IO/AsyncIO.swift index 3f075e5..2e7b129 100644 --- a/Sources/Subprocess/IO/AsyncIO.swift +++ b/Sources/Subprocess/IO/AsyncIO.swift @@ -352,8 +352,14 @@ extension AsyncIO { ) var readLength: Int = 0 let signalStream = self.registerFileDescriptor(fileDescriptor, for: .read) + /// Outer loop: every iteration signals we are ready to read more data for try await _ in signalStream { - // Every iteration signals we are ready to read more data + /// Inner loop: repeatedly call `.read()` and read more data until: + /// 1. We reached EOF (read length is 0), in which case return the result + /// 2. We read `maxLength` bytes, in which case return the result + /// 3. `read()` returns -1 and sets `errno` to `EAGAIN` or `EWOULDBLOCK`. In + /// this case we `break` out of the inner loop and wait `.read()` to be + /// ready by `await`ing the next signal in the outer loop. while true { let bytesRead = resultBuffer.withUnsafeMutableBufferPointer { bufferPointer in // Get a pointer to the memory at the specified offset @@ -417,7 +423,13 @@ extension AsyncIO { let fileDescriptor = diskIO.fileDescriptor let signalStream = self.registerFileDescriptor(fileDescriptor, for: .write) var writtenLength: Int = 0 + /// Outer loop: every iteration signals we are ready to read more data for try await _ in signalStream { + /// Inner loop: repeatedly call `.write()` and write more data until: + /// 1. We've written bytes.count bytes. + /// 3. `.write()` returns -1 and sets `errno` to `EAGAIN` or `EWOULDBLOCK`. In + /// this case we `break` out of the inner loop and wait `.write()` to be + /// ready by `await`ing the next signal in the outer loop. while true { let written = bytes.withUnsafeBytes { ptr in let remainingLength = ptr.count - writtenLength @@ -454,7 +466,13 @@ extension AsyncIO { let fileDescriptor = diskIO.fileDescriptor let signalStream = self.registerFileDescriptor(fileDescriptor, for: .write) var writtenLength: Int = 0 + /// Outer loop: every iteration signals we are ready to read more data for try await _ in signalStream { + /// Inner loop: repeatedly call `.write()` and write more data until: + /// 1. We've written bytes.count bytes. + /// 3. `.write()` returns -1 and sets `errno` to `EAGAIN` or `EWOULDBLOCK`. In + /// this case we `break` out of the inner loop and wait `.write()` to be + /// ready by `await`ing the next signal in the outer loop. while true { let written = span.withUnsafeBytes { ptr in let remainingLength = ptr.count - writtenLength diff --git a/Sources/Subprocess/IO/Output.swift b/Sources/Subprocess/IO/Output.swift index 563a73a..983b473 100644 --- a/Sources/Subprocess/IO/Output.swift +++ b/Sources/Subprocess/IO/Output.swift @@ -140,10 +140,10 @@ public struct BytesOutput: OutputProtocol { public let maxSize: Int internal func captureOutput( - from diskIO: consuming TrackedPlatformDiskIO? + from diskIO: consuming TrackedPlatformDiskIO ) async throws -> [UInt8] { - let result = try await AsyncIO.shared.read(from: diskIO!, upTo: self.maxSize) - try diskIO?.safelyClose() + let result = try await AsyncIO.shared.read(from: diskIO, upTo: self.maxSize) + try diskIO.safelyClose() #if canImport(Darwin) return result?.array() ?? [] #else @@ -255,17 +255,26 @@ extension OutputProtocol { internal func captureOutput( from diskIO: consuming TrackedPlatformDiskIO? ) async throws -> OutputType { - if let bytesOutput = self as? BytesOutput { - return try await bytesOutput.captureOutput(from: diskIO) as! Self.OutputType - } - if OutputType.self == Void.self { return () as! OutputType } + // `diskIO` is only `nil` for any types that conform to `OutputProtocol` + // and have `Void` as ``OutputType` (i.e. `DiscardedOutput`). Since we + // made sure `OutputType` is not `Void` on the line above, `diskIO` + // must not be nil; otherwise, this is a programmer error. + guard var diskIO else { + fatalError( + "Internal Inconsistency Error: diskIO must not be nil when OutputType is not Void" + ) + } + + if let bytesOutput = self as? BytesOutput { + return try await bytesOutput.captureOutput(from: diskIO) as! Self.OutputType + } // Force unwrap is safe here because only `OutputType.self == Void` would // have nil `TrackedPlatformDiskIO` - let result = try await AsyncIO.shared.read(from: diskIO!, upTo: self.maxSize) - try diskIO?.safelyClose() + let result = try await AsyncIO.shared.read(from: diskIO, upTo: self.maxSize) + try diskIO.safelyClose() #if canImport(Darwin) return try self.output(from: result ?? .empty) #else From 61508ad71b32af5dbdaafcfbc930df4c18eaaf05 Mon Sep 17 00:00:00 2001 From: Charles Hu Date: Fri, 16 May 2025 19:47:06 -0700 Subject: [PATCH 03/10] Create platform specific AsyncIO - Darwin: based on DispatchIO - Linux: based on epoll - Windows (not included in this commit): based on IOCP with OVERLAPPED --- Sources/Subprocess/IO/AsyncIO.swift | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Sources/Subprocess/IO/AsyncIO.swift b/Sources/Subprocess/IO/AsyncIO.swift index 2e7b129..e4f0608 100644 --- a/Sources/Subprocess/IO/AsyncIO.swift +++ b/Sources/Subprocess/IO/AsyncIO.swift @@ -54,6 +54,8 @@ final class AsyncIO: Sendable { } } + static let shared: AsyncIO = AsyncIO() + private enum Event { case read case write From a72d35f0d321f0f6e9e23e6c6375cabd00b55c19 Mon Sep 17 00:00:00 2001 From: Charles Hu Date: Mon, 23 Jun 2025 10:01:02 -0700 Subject: [PATCH 04/10] Fix fd was not closed error on Windows --- .../Platforms/Subprocess+Windows.swift | 51 +++++++++---------- 1 file changed, 23 insertions(+), 28 deletions(-) diff --git a/Sources/Subprocess/Platforms/Subprocess+Windows.swift b/Sources/Subprocess/Platforms/Subprocess+Windows.swift index 723a662..db799ba 100644 --- a/Sources/Subprocess/Platforms/Subprocess+Windows.swift +++ b/Sources/Subprocess/Platforms/Subprocess+Windows.swift @@ -48,37 +48,32 @@ extension Configuration { outputPipe: consuming CreatedPipe, errorPipe: consuming CreatedPipe ) throws -> SpawnResult { - var inputPipeBox: CreatedPipe? = consume inputPipe - var outputPipeBox: CreatedPipe? = consume outputPipe - var errorPipeBox: CreatedPipe? = consume errorPipe - - var _inputPipe = inputPipeBox.take()! - var _outputPipe = outputPipeBox.take()! - var _errorPipe = errorPipeBox.take()! - - let inputReadFileDescriptor: TrackedFileDescriptor? = _inputPipe.readFileDescriptor() - let inputWriteFileDescriptor: TrackedFileDescriptor? = _inputPipe.writeFileDescriptor() - let outputReadFileDescriptor: TrackedFileDescriptor? = _outputPipe.readFileDescriptor() - let outputWriteFileDescriptor: TrackedFileDescriptor? = _outputPipe.writeFileDescriptor() - let errorReadFileDescriptor: TrackedFileDescriptor? = _errorPipe.readFileDescriptor() - let errorWriteFileDescriptor: TrackedFileDescriptor? = _errorPipe.writeFileDescriptor() - - let ( - applicationName, - commandAndArgs, - environment, - intendedWorkingDir - ): (String?, String, String, String?) + var inputReadFileDescriptor: TrackedFileDescriptor? = inputPipe.readFileDescriptor() + var inputWriteFileDescriptor: TrackedFileDescriptor? = inputPipe.writeFileDescriptor() + var outputReadFileDescriptor: TrackedFileDescriptor? = outputPipe.readFileDescriptor() + var outputWriteFileDescriptor: TrackedFileDescriptor? = outputPipe.writeFileDescriptor() + var errorReadFileDescriptor: TrackedFileDescriptor? = errorPipe.readFileDescriptor() + var errorWriteFileDescriptor: TrackedFileDescriptor? = errorPipe.writeFileDescriptor() + + let applicationName: String? + let commandAndArgs: String + let environment: String + let intendedWorkingDir: String? do { - (applicationName, commandAndArgs, environment, intendedWorkingDir) = try self.preSpawn() + ( + applicationName, + commandAndArgs, + environment, + intendedWorkingDir + ) = try self.preSpawn() } catch { try self.safelyCloseMultiple( - inputRead: inputReadFileDescriptor, - inputWrite: inputWriteFileDescriptor, - outputRead: outputReadFileDescriptor, - outputWrite: outputWriteFileDescriptor, - errorRead: errorReadFileDescriptor, - errorWrite: errorWriteFileDescriptor + inputRead: inputReadFileDescriptor.take(), + inputWrite: inputWriteFileDescriptor.take(), + outputRead: outputReadFileDescriptor.take(), + outputWrite: outputWriteFileDescriptor.take(), + errorRead: errorReadFileDescriptor.take(), + errorWrite: errorWriteFileDescriptor.take() ) throw error } From 74995e42d112014443dca30bb989d12970d3f9f1 Mon Sep 17 00:00:00 2001 From: Charles Hu Date: Tue, 24 Jun 2025 13:43:34 -0700 Subject: [PATCH 05/10] Fix Windows test errors --- Sources/Subprocess/IO/AsyncIO.swift | 11 ++++++++--- .../SubprocessTests/SubprocessTests+Windows.swift | 14 ++++++-------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/Sources/Subprocess/IO/AsyncIO.swift b/Sources/Subprocess/IO/AsyncIO.swift index e4f0608..d1464cc 100644 --- a/Sources/Subprocess/IO/AsyncIO.swift +++ b/Sources/Subprocess/IO/AsyncIO.swift @@ -330,8 +330,8 @@ extension AsyncIO { var count: Int { get } func withUnsafeBytes( - _ body: (UnsafeRawBufferPointer - ) throws -> ResultType) rethrows -> ResultType + _ body: (UnsafeRawBufferPointer) throws -> ResultType + ) rethrows -> ResultType } func read( @@ -739,7 +739,12 @@ final class AsyncIO: Sendable { ) continuation.resume(throwing: windowsError) } else { - continuation.resume(returning: values) + // If we didn't read anything, return nil + if values.isEmpty { + continuation.resume(returning: nil) + } else { + continuation.resume(returning: values) + } } } } diff --git a/Tests/SubprocessTests/SubprocessTests+Windows.swift b/Tests/SubprocessTests/SubprocessTests+Windows.swift index 3c615e9..c2981e3 100644 --- a/Tests/SubprocessTests/SubprocessTests+Windows.swift +++ b/Tests/SubprocessTests/SubprocessTests+Windows.swift @@ -27,7 +27,7 @@ import TestResources @Suite(.serialized) struct SubprocessWindowsTests { - private let cmdExe: Subprocess.Executable = .path("C:\\Windows\\System32\\cmd.exe") + private let cmdExe: Subprocess.Executable = .name("cmd.exe") } // MARK: - Executable Tests @@ -87,7 +87,7 @@ extension SubprocessWindowsTests { Issue.record("Expected to throw POSIXError") } catch { guard let subprocessError = error as? SubprocessError, - let underlying = subprocessError.underlyingError + let underlying = subprocessError.underlyingError else { Issue.record("Expected CocoaError, got \(error)") return @@ -128,7 +128,6 @@ extension SubprocessWindowsTests { environment: .inherit, output: .string ) - #expect(result.terminationStatus.isSuccess) // As a sanity check, make sure there's // `C:\Windows\system32` in PATH // since we inherited the environment variables @@ -249,7 +248,6 @@ extension SubprocessWindowsTests { output: .data(limit: 2048 * 1024) ) - #expect(catResult.terminationStatus.isSuccess) // Make sure we read all bytes #expect( catResult.standardOutput == expected @@ -304,7 +302,6 @@ extension SubprocessWindowsTests { input: .sequence(stream), output: .data(limit: 2048 * 1024) ) - #expect(catResult.terminationStatus.isSuccess) #expect( catResult.standardOutput == expected ) @@ -510,7 +507,7 @@ extension SubprocessWindowsTests { @Test func testPlatformOptionsCreateNewConsole() async throws { let parentConsole = GetConsoleWindow() let sameConsoleResult = try await Subprocess.run( - .path("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"), + .name("powershell.exe"), arguments: [ "-File", windowsTester.string, "-mode", "get-console-window", @@ -529,7 +526,7 @@ extension SubprocessWindowsTests { var platformOptions: Subprocess.PlatformOptions = .init() platformOptions.consoleBehavior = .createNew let differentConsoleResult = try await Subprocess.run( - .path("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"), + .name("powershell.exe"), arguments: [ "-File", windowsTester.string, "-mode", "get-console-window", @@ -700,12 +697,13 @@ extension SubprocessWindowsTests { 0 ) let pid = try Subprocess.runDetached( - .path("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"), + .name("powershell.exe"), arguments: [ "-Command", "Write-Host $PID", ], output: writeFd ) + try writeFd.close() // Wait for process to finish guard let processHandle = OpenProcess( From 44a9e978d7b6c5b4d67a72aa5865c00cbbb303c2 Mon Sep 17 00:00:00 2001 From: Charles Hu Date: Wed, 25 Jun 2025 13:58:37 -0700 Subject: [PATCH 06/10] Introduce Windows IOCP based AsyncIO implementation --- Sources/Subprocess/API.swift | 50 +- Sources/Subprocess/AsyncBufferSequence.swift | 12 +- Sources/Subprocess/Configuration.swift | 328 +++++++++--- Sources/Subprocess/Error.swift | 2 +- Sources/Subprocess/Execution.swift | 2 +- Sources/Subprocess/IO/AsyncIO.swift | 483 +++++++++++++----- Sources/Subprocess/IO/Input.swift | 17 +- Sources/Subprocess/IO/Output.swift | 22 +- .../Platforms/Subprocess+Darwin.swift | 44 +- .../Platforms/Subprocess+Linux.swift | 32 +- .../Platforms/Subprocess+Windows.swift | 64 +-- .../Input+Foundation.swift | 4 +- Sources/Subprocess/Teardown.swift | 2 +- .../SubprocessTests+Windows.swift | 16 +- Tests/TestResources/TestResources.swift | 2 +- 15 files changed, 732 insertions(+), 348 deletions(-) diff --git a/Sources/Subprocess/API.swift b/Sources/Subprocess/API.swift index b7788d6..6710d0a 100644 --- a/Sources/Subprocess/API.swift +++ b/Sources/Subprocess/API.swift @@ -105,9 +105,9 @@ public func run< output: try output.createPipe(), error: try error.createPipe() ) { execution, inputIO, outputIO, errorIO in - var inputIOBox: TrackedPlatformDiskIO? = consume inputIO - var outputIOBox: TrackedPlatformDiskIO? = consume outputIO - var errorIOBox: TrackedPlatformDiskIO? = consume errorIO + var inputIOBox: IOChannel? = consume inputIO + var outputIOBox: IOChannel? = consume outputIO + var errorIOBox: IOChannel? = consume errorIO // Write input, capture output and error in parallel async let stdout = try output.captureOutput(from: outputIOBox.take()) @@ -177,12 +177,12 @@ public func run( output: try output.createPipe(), error: try error.createPipe() ) { execution, inputIO, outputIO, errorIO in - var inputIOBox: TrackedPlatformDiskIO? = consume inputIO - var outputIOBox: TrackedPlatformDiskIO? = consume outputIO + var inputIOBox: IOChannel? = consume inputIO + var outputIOBox: IOChannel? = consume outputIO return try await withThrowingTaskGroup( of: Void.self, returning: Result.self ) { group in - var inputIOContainer: TrackedPlatformDiskIO? = inputIOBox.take() + var inputIOContainer: IOChannel? = inputIOBox.take() group.addTask { if let inputIO = inputIOContainer.take() { let writer = StandardInputWriter(diskIO: inputIO) @@ -253,7 +253,7 @@ public func run( } // Body runs in the same isolation - let outputSequence = AsyncBufferSequence(diskIO: outputIOBox.take()!.consumeDiskIO()) + let outputSequence = AsyncBufferSequence(diskIO: outputIOBox.take()!.consumeIOChannel()) let result = try await body(execution, outputSequence) try await group.waitForAll() return result @@ -299,13 +299,13 @@ public func run( output: try output.createPipe(), error: try error.createPipe() ) { execution, inputIO, outputIO, errorIO in - var inputIOBox: TrackedPlatformDiskIO? = consume inputIO - var errorIOBox: TrackedPlatformDiskIO? = consume errorIO + var inputIOBox: IOChannel? = consume inputIO + var errorIOBox: IOChannel? = consume errorIO return try await withThrowingTaskGroup( of: Void.self, returning: Result.self ) { group in - var inputIOContainer: TrackedPlatformDiskIO? = inputIOBox.take() + var inputIOContainer: IOChannel? = inputIOBox.take() group.addTask { if let inputIO = inputIOContainer.take() { let writer = StandardInputWriter(diskIO: inputIO) @@ -315,7 +315,7 @@ public func run( } // Body runs in the same isolation - let errorSequence = AsyncBufferSequence(diskIO: errorIOBox.take()!.consumeDiskIO()) + let errorSequence = AsyncBufferSequence(diskIO: errorIOBox.take()!.consumeIOChannel()) let result = try await body(execution, errorSequence) try await group.waitForAll() return result @@ -363,7 +363,7 @@ public func run( error: try error.createPipe() ) { execution, inputIO, outputIO, errorIO in let writer = StandardInputWriter(diskIO: inputIO!) - let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeDiskIO()) + let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeIOChannel()) return try await body(execution, writer, outputSequence) } } @@ -408,7 +408,7 @@ public func run( error: try error.createPipe() ) { execution, inputIO, outputIO, errorIO in let writer = StandardInputWriter(diskIO: inputIO!) - let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeDiskIO()) + let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeIOChannel()) return try await body(execution, writer, errorSequence) } } @@ -460,8 +460,8 @@ public func run( error: try error.createPipe() ) { execution, inputIO, outputIO, errorIO in let writer = StandardInputWriter(diskIO: inputIO!) - let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeDiskIO()) - let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeDiskIO()) + let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeIOChannel()) + let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeIOChannel()) return try await body(execution, writer, outputSequence, errorSequence) } } @@ -497,16 +497,16 @@ public func run< error: try error.createPipe() ) { (execution, inputIO, outputIO, errorIO) -> RunResult in // Write input, capture output and error in parallel - var inputIOBox: TrackedPlatformDiskIO? = consume inputIO - var outputIOBox: TrackedPlatformDiskIO? = consume outputIO - var errorIOBox: TrackedPlatformDiskIO? = consume errorIO + var inputIOBox: IOChannel? = consume inputIO + var outputIOBox: IOChannel? = consume outputIO + var errorIOBox: IOChannel? = consume errorIO return try await withThrowingTaskGroup( of: OutputCapturingState?.self, returning: RunResult.self ) { group in - var inputIOContainer: TrackedPlatformDiskIO? = inputIOBox.take() - var outputIOContainer: TrackedPlatformDiskIO? = outputIOBox.take() - var errorIOContainer: TrackedPlatformDiskIO? = errorIOBox.take() + var inputIOContainer: IOChannel? = inputIOBox.take() + var outputIOContainer: IOChannel? = outputIOBox.take() + var errorIOContainer: IOChannel? = errorIOBox.take() group.addTask { if let writeFd = inputIOContainer.take() { let writer = StandardInputWriter(diskIO: writeFd) @@ -580,8 +580,8 @@ public func run( error: try error.createPipe() ) { execution, inputIO, outputIO, errorIO in let writer = StandardInputWriter(diskIO: inputIO!) - let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeDiskIO()) - let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeDiskIO()) + let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeIOChannel()) + let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeIOChannel()) return try await body(execution, writer, outputSequence, errorSequence) } } diff --git a/Sources/Subprocess/AsyncBufferSequence.swift b/Sources/Subprocess/AsyncBufferSequence.swift index 147e09a..0ca7e6e 100644 --- a/Sources/Subprocess/AsyncBufferSequence.swift +++ b/Sources/Subprocess/AsyncBufferSequence.swift @@ -19,12 +19,14 @@ internal import Dispatch #endif -public struct AsyncBufferSequence: AsyncSequence, Sendable { +public struct AsyncBufferSequence: AsyncSequence, @unchecked Sendable { public typealias Failure = any Swift.Error public typealias Element = Buffer #if canImport(Darwin) internal typealias DiskIO = DispatchIO + #elseif canImport(WinSDK) + internal typealias DiskIO = HANDLE #else internal typealias DiskIO = FileDescriptor #endif @@ -54,9 +56,11 @@ public struct AsyncBufferSequence: AsyncSequence, Sendable { guard let data else { // We finished reading. Close the file descriptor now #if canImport(Darwin) - self.diskIO.close() + try _safelyClose(.dispatchIO(self.diskIO)) + #elseif canImport(WinSDK) + try _safelyClose(.handle(self.diskIO)) #else - try self.diskIO.close() + try _safelyClose(.fileDescriptor(self.diskIO)) #endif return nil } @@ -337,7 +341,7 @@ private let _pageSize: Int = { Int(_subprocess_vm_size()) }() #elseif canImport(WinSDK) -import WinSDK +@preconcurrency import WinSDK private let _pageSize: Int = { var sysInfo: SYSTEM_INFO = SYSTEM_INFO() GetSystemInfo(&sysInfo) diff --git a/Sources/Subprocess/Configuration.swift b/Sources/Subprocess/Configuration.swift index 9e50469..1e20557 100644 --- a/Sources/Subprocess/Configuration.swift +++ b/Sources/Subprocess/Configuration.swift @@ -24,7 +24,7 @@ import Glibc #elseif canImport(Musl) import Musl #elseif canImport(WinSDK) -import WinSDK +@preconcurrency import WinSDK #endif internal import Dispatch @@ -64,7 +64,7 @@ public struct Configuration: Sendable { output: consuming CreatedPipe, error: consuming CreatedPipe, isolation: isolated (any Actor)? = #isolation, - _ body: ((Execution, consuming TrackedPlatformDiskIO?, consuming TrackedPlatformDiskIO?, consuming TrackedPlatformDiskIO?) async throws -> Result) + _ body: ((Execution, consuming IOChannel?, consuming IOChannel?, consuming IOChannel?) async throws -> Result) ) async throws -> ExecutionResult { let spawnResults = try self.spawn( withInput: input, @@ -139,12 +139,12 @@ extension Configuration { /// Close each input individually, and throw the first error if there's multiple errors thrown @Sendable internal func safelyCloseMultiple( - inputRead: consuming TrackedFileDescriptor?, - inputWrite: consuming TrackedFileDescriptor?, - outputRead: consuming TrackedFileDescriptor?, - outputWrite: consuming TrackedFileDescriptor?, - errorRead: consuming TrackedFileDescriptor?, - errorWrite: consuming TrackedFileDescriptor? + inputRead: consuming IODescriptor?, + inputWrite: consuming IODescriptor?, + outputRead: consuming IODescriptor?, + outputWrite: consuming IODescriptor?, + errorRead: consuming IODescriptor?, + errorWrite: consuming IODescriptor? ) throws { var possibleError: (any Swift.Error)? = nil @@ -495,15 +495,15 @@ extension Configuration { /// via `SpawnResult` to perform actual reads internal struct SpawnResult: ~Copyable { let execution: Execution - var _inputWriteEnd: TrackedPlatformDiskIO? - var _outputReadEnd: TrackedPlatformDiskIO? - var _errorReadEnd: TrackedPlatformDiskIO? + var _inputWriteEnd: IOChannel? + var _outputReadEnd: IOChannel? + var _errorReadEnd: IOChannel? init( execution: Execution, - inputWriteEnd: consuming TrackedPlatformDiskIO?, - outputReadEnd: consuming TrackedPlatformDiskIO?, - errorReadEnd: consuming TrackedPlatformDiskIO? + inputWriteEnd: consuming IOChannel?, + outputReadEnd: consuming IOChannel?, + errorReadEnd: consuming IOChannel? ) { self.execution = execution self._inputWriteEnd = consume inputWriteEnd @@ -511,15 +511,15 @@ extension Configuration { self._errorReadEnd = consume errorReadEnd } - mutating func inputWriteEnd() -> TrackedPlatformDiskIO? { + mutating func inputWriteEnd() -> IOChannel? { return self._inputWriteEnd.take() } - mutating func outputReadEnd() -> TrackedPlatformDiskIO? { + mutating func outputReadEnd() -> IOChannel? { return self._outputReadEnd.take() } - mutating func errorReadEnd() -> TrackedPlatformDiskIO? { + mutating func errorReadEnd() -> IOChannel? { return self._errorReadEnd.take() } } @@ -589,34 +589,45 @@ internal enum StringOrRawBytes: Sendable, Hashable { } } -/// A wrapped `FileDescriptor` and whether it should be closed -/// automatically when done. -internal struct TrackedFileDescriptor: ~Copyable { - internal var closeWhenDone: Bool - internal let fileDescriptor: FileDescriptor - - internal init( - _ fileDescriptor: FileDescriptor, - closeWhenDone: Bool - ) { - self.fileDescriptor = fileDescriptor - self.closeWhenDone = closeWhenDone - } - - consuming func consumeDiskIO() -> FileDescriptor { - let result = self.fileDescriptor - // Transfer the ownership out and therefor - // don't perform close on deinit - self.closeWhenDone = false - return result - } +internal enum _CloseTarget { + #if canImport(WinSDK) + case handle(HANDLE) + #endif + case fileDescriptor(FileDescriptor) + case dispatchIO(DispatchIO) +} - internal mutating func safelyClose() throws { - guard self.closeWhenDone else { - return +internal func _safelyClose(_ target: _CloseTarget) throws { + switch target { + #if canImport(WinSDK) + case .handle(let handle): + /// Windows does not provide a “deregistration” API (the reverse of + /// `CreateIoCompletionPort`) for handles and it it reuses HANDLE + /// values once they are closed. Since we rely on the handle value + /// as the completion key for `CreateIoCompletionPort`, we should + /// remove the registration when the handle is closed to allow + /// new registration to proceed if the handle is reused. + AsyncIO.shared.removeRegistration(for: handle) + guard CloseHandle(handle) else { + let error = GetLastError() + // Getting `ERROR_INVALID_HANDLE` suggests that the file descriptor + // might have been closed unexpectedly. This can pose security risks + // if another part of the code inadvertently reuses the same HANDLE. + // We use `fatalError` upon receiving `ERROR_INVALID_HANDLE` + // to prevent accidentally closing a different HANDLE. + guard error != ERROR_INVALID_HANDLE else { + fatalError( + "HANDLE \(handle) is already closed" + ) + } + let subprocessError = SubprocessError( + code: .init(.asyncIOFailed("Failed to close HANDLE")), + underlyingError: .init(rawValue: error) + ) + throw subprocessError } - closeWhenDone = false - + #endif + case .fileDescriptor(let fileDescriptor): do { try fileDescriptor.close() } catch { @@ -638,6 +649,69 @@ internal struct TrackedFileDescriptor: ~Copyable { // Throw other kinds of errors to allow user to catch them throw error } + case .dispatchIO(let dispatchIO): + dispatchIO.close() + } +} + +/// `IODescriptor` wraps platform-specific `FileDescriptor`, +/// which is used to establish a connection to the standard input/output (IO) +/// system during the process of spawning a child process. Unlike `IODescriptor`, +/// the `IODescriptor` does not support data read/write operations; +/// its primary function is to facilitate the spawning of child processes +/// by providing a platform-specific file descriptor. +internal struct IODescriptor: ~Copyable { + #if canImport(WinSDK) + typealias Descriptor = HANDLE + #else + typealias Descriptor = FileDescriptor + #endif + + internal var closeWhenDone: Bool + internal let descriptor: Descriptor + + internal init( + _ descriptor: Descriptor, + closeWhenDone: Bool + ) { + self.descriptor = descriptor + self.closeWhenDone = closeWhenDone + } + + consuming func createIOChannel() -> IOChannel { + let shouldClose = self.closeWhenDone + self.closeWhenDone = false + #if canImport(Darwin) + // Transferring out the ownership of fileDescriptor means we don't have go close here + let closeFd = self.descriptor + let dispatchIO: DispatchIO = DispatchIO( + type: .stream, + fileDescriptor: self.platformDescriptor(), + queue: .global(), + cleanupHandler: { error in + // Close the file descriptor + if shouldClose { + try? closeFd.close() + } + } + ) + return IOChannel(dispatchIO, closeWhenDone: shouldClose) + #else + return IOChannel(self.descriptor, closeWhenDone: shouldClose) + #endif + } + + internal mutating func safelyClose() throws { + guard self.closeWhenDone else { + return + } + closeWhenDone = false + + #if canImport(WinSDK) + try _safelyClose(.handle(self.descriptor)) + #else + try _safelyClose(.fileDescriptor(self.descriptor)) + #endif } deinit { @@ -645,77 +719,178 @@ internal struct TrackedFileDescriptor: ~Copyable { return } - fatalError("FileDescriptor \(self.fileDescriptor.rawValue) was not closed") + fatalError("FileDescriptor \(self.descriptor) was not closed") } internal func platformDescriptor() -> PlatformFileDescriptor { - return self.fileDescriptor.platformDescriptor + #if canImport(WinSDK) + return self.descriptor + #else + return self.descriptor.platformDescriptor + #endif } } -#if !os(Windows) -/// A wrapped `DispatchIO` and whether it should be closed -/// automatically when done. -internal struct TrackedDispatchIO: ~Copyable { +internal struct IOChannel: ~Copyable, @unchecked Sendable { + #if canImport(WinSDK) + typealias Channel = HANDLE + #elseif canImport(Darwin) + typealias Channel = DispatchIO + #else + typealias Channel = FileDescriptor + #endif + internal var closeWhenDone: Bool - internal var dispatchIO: DispatchIO + internal let channel: Channel internal init( - _ dispatchIO: DispatchIO, + _ channel: Channel, closeWhenDone: Bool ) { - self.dispatchIO = dispatchIO + self.channel = channel self.closeWhenDone = closeWhenDone } - consuming func consumeDiskIO() -> DispatchIO { - let result = self.dispatchIO - // Transfer the ownership out and therefor - // don't perform close on deinit - self.closeWhenDone = false - return result - } - internal mutating func safelyClose() throws { guard self.closeWhenDone else { return } closeWhenDone = false - dispatchIO.close() - } - deinit { - guard self.closeWhenDone else { - return - } + #if canImport(WinSDK) + try _safelyClose(.handle(self.channel)) + #elseif canImport(Darwin) + try _safelyClose(.dispatchIO(self.channel)) + #else + try _safelyClose(.fileDescriptor(self.channel)) + #endif + } - fatalError("DispatchIO \(self.dispatchIO) was not closed") + internal consuming func consumeIOChannel() -> Channel { + let result = self.channel + // Transfer the ownership out and therefor + // don't perform close on deinit + self.closeWhenDone = false + return result } } -#endif internal struct CreatedPipe: ~Copyable { - internal var _readFileDescriptor: TrackedFileDescriptor? - internal var _writeFileDescriptor: TrackedFileDescriptor? + internal enum Purpose: CustomStringConvertible { + /// This pipe is used for standard input. This option maps to + /// `PIPE_ACCESS_OUTBOUND` on Windows where child only reads, + /// parent only writes. + case input + /// This pipe is used for standard output and standard error. + /// This option maps to `PIPE_ACCESS_INBOUND` on Windows where + /// child only writes, parent only reads. + case output + + var description: String { + switch self { + case .input: + return "input" + case .output: + return "output" + } + } + } + + internal var _readFileDescriptor: IODescriptor? + internal var _writeFileDescriptor: IODescriptor? internal init( - readFileDescriptor: consuming TrackedFileDescriptor?, - writeFileDescriptor: consuming TrackedFileDescriptor? + readFileDescriptor: consuming IODescriptor?, + writeFileDescriptor: consuming IODescriptor? ) { self._readFileDescriptor = readFileDescriptor self._writeFileDescriptor = writeFileDescriptor } - mutating func readFileDescriptor() -> TrackedFileDescriptor? { + mutating func readFileDescriptor() -> IODescriptor? { return self._readFileDescriptor.take() } - mutating func writeFileDescriptor() -> TrackedFileDescriptor? { + mutating func writeFileDescriptor() -> IODescriptor? { return self._writeFileDescriptor.take() } - internal init(closeWhenDone: Bool) throws { - let pipe = try FileDescriptor.ssp_pipe() + internal init(closeWhenDone: Bool, purpose: Purpose) throws { + #if canImport(WinSDK) + // On Windows, we need to create a named pipe + let pipeName = "\\\\.\\pipe\\subprocess-\(purpose)-\(Int.random(in: .min ..< .max))" + var saAttributes: SECURITY_ATTRIBUTES = SECURITY_ATTRIBUTES() + saAttributes.nLength = DWORD(MemoryLayout.size) + saAttributes.bInheritHandle = true + saAttributes.lpSecurityDescriptor = nil + + let parentEnd = pipeName.withCString( + encodedAs: UTF16.self + ) { pipeNameW in + // Use OVERLAPPED for async IO + var openMode: DWORD = DWORD(FILE_FLAG_OVERLAPPED) + switch purpose { + case .input: + openMode |= DWORD(PIPE_ACCESS_OUTBOUND) + case .output: + openMode |= DWORD(PIPE_ACCESS_INBOUND) + } + + return CreateNamedPipeW( + pipeNameW, + openMode, + DWORD(PIPE_TYPE_BYTE | PIPE_READMODE_BYTE | PIPE_WAIT), + 1, // Max instance, + DWORD(readBufferSize), + DWORD(readBufferSize), + 0, + &saAttributes + ) + } + guard let parentEnd, parentEnd != INVALID_HANDLE_VALUE else { + throw SubprocessError( + code: .init(.asyncIOFailed("CreateNamedPipeW failed")), + underlyingError: .init(rawValue: GetLastError()) + ) + } + + let childEnd = pipeName.withCString( + encodedAs: UTF16.self + ) { pipeNameW in + var targetAccess: DWORD = 0 + switch purpose { + case .input: + targetAccess = DWORD(GENERIC_READ) + case .output: + targetAccess = DWORD(GENERIC_WRITE) + } + + return CreateFileW( + pipeNameW, + targetAccess, + 0, + &saAttributes, + DWORD(OPEN_EXISTING), + DWORD(FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED), + nil + ) + } + guard let childEnd, childEnd != INVALID_HANDLE_VALUE else { + throw SubprocessError( + code: .init(.asyncIOFailed("CreateFileW failed")), + underlyingError: .init(rawValue: GetLastError()) + ) + } + switch purpose { + case .input: + self._readFileDescriptor = .init(childEnd, closeWhenDone: closeWhenDone) + self._writeFileDescriptor = .init(parentEnd, closeWhenDone: closeWhenDone) + case .output: + self._readFileDescriptor = .init(parentEnd, closeWhenDone: closeWhenDone) + self._writeFileDescriptor = .init(childEnd, closeWhenDone: closeWhenDone) + } + #else + let pipe = try FileDescriptor.pipe() self._readFileDescriptor = .init( pipe.readEnd, closeWhenDone: closeWhenDone @@ -724,6 +899,7 @@ internal struct CreatedPipe: ~Copyable { pipe.writeEnd, closeWhenDone: closeWhenDone ) + #endif } } diff --git a/Sources/Subprocess/Error.swift b/Sources/Subprocess/Error.swift index b7a6ca5..5e4bd80 100644 --- a/Sources/Subprocess/Error.swift +++ b/Sources/Subprocess/Error.swift @@ -18,7 +18,7 @@ import Glibc #elseif canImport(Musl) import Musl #elseif canImport(WinSDK) -import WinSDK +@preconcurrency import WinSDK #endif /// Error thrown from Subprocess diff --git a/Sources/Subprocess/Execution.swift b/Sources/Subprocess/Execution.swift index a21a170..66f8628 100644 --- a/Sources/Subprocess/Execution.swift +++ b/Sources/Subprocess/Execution.swift @@ -24,7 +24,7 @@ import Glibc #elseif canImport(Musl) import Musl #elseif canImport(WinSDK) -import WinSDK +@preconcurrency import WinSDK #endif /// An object that represents a subprocess that has been diff --git a/Sources/Subprocess/IO/AsyncIO.swift b/Sources/Subprocess/IO/AsyncIO.swift index d1464cc..35894a5 100644 --- a/Sources/Subprocess/IO/AsyncIO.swift +++ b/Sources/Subprocess/IO/AsyncIO.swift @@ -54,8 +54,6 @@ final class AsyncIO: Sendable { } } - static let shared: AsyncIO = AsyncIO() - private enum Event { case read case write @@ -335,10 +333,10 @@ extension AsyncIO { } func read( - from diskIO: borrowing TrackedPlatformDiskIO, + from diskIO: borrowing IOChannel, upTo maxLength: Int ) async throws -> [UInt8]? { - return try await self.read(from: diskIO.fileDescriptor, upTo: maxLength) + return try await self.read(from: diskIO.channel, upTo: maxLength) } func read( @@ -413,16 +411,16 @@ extension AsyncIO { func write( _ array: [UInt8], - to diskIO: borrowing TrackedPlatformDiskIO + to diskIO: borrowing IOChannel ) async throws -> Int { return try await self._write(array, to: diskIO) } func _write( _ bytes: Bytes, - to diskIO: borrowing TrackedPlatformDiskIO + to diskIO: borrowing IOChannel ) async throws -> Int { - let fileDescriptor = diskIO.fileDescriptor + let fileDescriptor = diskIO.channel let signalStream = self.registerFileDescriptor(fileDescriptor, for: .write) var writtenLength: Int = 0 /// Outer loop: every iteration signals we are ready to read more data @@ -463,9 +461,9 @@ extension AsyncIO { #if SubprocessSpan func write( _ span: borrowing RawSpan, - to diskIO: borrowing TrackedPlatformDiskIO + to diskIO: borrowing IOChannel ) async throws -> Int { - let fileDescriptor = diskIO.fileDescriptor + let fileDescriptor = diskIO.channel let signalStream = self.registerFileDescriptor(fileDescriptor, for: .write) var writtenLength: Int = 0 /// Outer loop: every iteration signals we are ready to read more data @@ -521,11 +519,11 @@ final class AsyncIO: Sendable { private init() {} internal func read( - from diskIO: borrowing TrackedPlatformDiskIO, + from diskIO: borrowing IOChannel, upTo maxLength: Int ) async throws -> DispatchData? { return try await self.read( - from: diskIO.dispatchIO, + from: diskIO.channel, upTo: maxLength, ) } @@ -571,7 +569,7 @@ final class AsyncIO: Sendable { #if SubprocessSpan internal func write( _ span: borrowing RawSpan, - to diskIO: borrowing TrackedPlatformDiskIO + to diskIO: borrowing IOChannel ) async throws -> Int { try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in let dispatchData = span.withUnsafeBytes { @@ -598,7 +596,7 @@ final class AsyncIO: Sendable { internal func write( _ array: [UInt8], - to diskIO: borrowing TrackedPlatformDiskIO + to diskIO: borrowing IOChannel ) async throws -> Int { try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in let dispatchData = array.withUnsafeBytes { @@ -624,11 +622,11 @@ final class AsyncIO: Sendable { internal func write( _ dispatchData: DispatchData, - to diskIO: borrowing TrackedPlatformDiskIO, + to diskIO: borrowing IOChannel, queue: DispatchQueue = .global(), completion: @escaping (Int, Error?) -> Void ) { - diskIO.dispatchIO.write( + diskIO.channel.write( offset: 0, data: dispatchData, queue: queue @@ -657,13 +655,20 @@ final class AsyncIO: Sendable { #endif -// MARK: - Windows (I/O Completion Ports) TODO +// MARK: - Windows (I/O Completion Ports) #if os(Windows) +import Synchronization internal import Dispatch -import WinSDK +@preconcurrency import WinSDK -final class AsyncIO: Sendable { +private typealias SignalStream = AsyncThrowingStream +private let shutdownPort: UInt64 = .max +private let _registration: Mutex< + [UInt64 : SignalStream.Continuation] +> = Mutex([:]) + +final class AsyncIO: @unchecked Sendable { protocol _ContiguousBytes: Sendable { var count: Int { get } @@ -673,78 +678,279 @@ final class AsyncIO: Sendable { ) throws -> ResultType) rethrows -> ResultType } + private final class MonitorThreadContext { + let ioCompletionPort: HANDLE + + init(ioCompletionPort: HANDLE) { + self.ioCompletionPort = ioCompletionPort + } + } + static let shared = AsyncIO() - private init() {} + private let ioCompletionPort: Result + + private let monitorThread: Result + + private init() { + var maybeSetupError: SubprocessError? = nil + // Create the the completion port + guard let port = CreateIoCompletionPort( + INVALID_HANDLE_VALUE, nil, 0, 0 + ), port != INVALID_HANDLE_VALUE else { + let error = SubprocessError( + code: .init(.asyncIOFailed("CreateIoCompletionPort failed")), + underlyingError: .init(rawValue: GetLastError()) + ) + self.ioCompletionPort = .failure(error) + self.monitorThread = .failure(error) + return + } + self.ioCompletionPort = .success(port) + // Create monitor thread + let threadContext = MonitorThreadContext(ioCompletionPort: port) + let threadContextPtr = Unmanaged.passRetained(threadContext) + let threadHandle = CreateThread(nil, 0, { args in + func reportError(_ error: SubprocessError) { + _registration.withLock { store in + for continuation in store.values { + continuation.finish(throwing: error) + } + } + } + + let unmanaged = Unmanaged.fromOpaque(args!) + let context = unmanaged.takeRetainedValue() + + // Monitor loop + while true { + var bytesTransferred: DWORD = 0 + var targetFileDescriptor: UInt64 = 0 + var overlapped: LPOVERLAPPED? = nil + + let monitorResult = GetQueuedCompletionStatus( + context.ioCompletionPort, + &bytesTransferred, + &targetFileDescriptor, + &overlapped, + INFINITE + ) + if !monitorResult { + let lastError = GetLastError() + if lastError == ERROR_BROKEN_PIPE { + // We finished reading the handle. Signal EOF by + // finishing the stream. + // NOTE: here we deliberately leave now unused continuation + // in the store. Windows does not offer an API to remove a + // HANDLE from an IOCP port, therefore we leave the registration + // to signify the HANDLE has already been resisted. + _registration.withLock { store in + if let continuation = store[targetFileDescriptor] { + continuation.finish() + } + } + continue + } else { + let error = SubprocessError( + code: .init(.asyncIOFailed("GetQueuedCompletionStatus failed")), + underlyingError: .init(rawValue: lastError) + ) + reportError(error) + break + } + } + + // Breakout the monitor loop if we received shutdown from the shutdownFD + if targetFileDescriptor == shutdownPort { + break + } + // Notify the continuations + _registration.withLock { store in + if let continuation = store[targetFileDescriptor] { + continuation.yield(bytesTransferred) + } + } + } + return 0 + }, threadContextPtr.toOpaque(), 0, nil) + guard let threadHandle = threadHandle else { + let error = SubprocessError( + code: .init(.asyncIOFailed("CreateThread failed")), + underlyingError: .init(rawValue: GetLastError()) + ) + self.monitorThread = .failure(error) + return + } + self.monitorThread = .success(threadHandle) + + atexit { + AsyncIO.shared.shutdown() + } + } + + private func shutdown() { + // Post status to shutdown HANDLE + guard case .success(let ioPort) = ioCompletionPort, + case .success(let monitorThreadHandle) = monitorThread else { + return + } + PostQueuedCompletionStatus( + ioPort, + 0, + shutdownPort, + nil + ) + // Wait for monitor thread to exit + WaitForSingleObject(monitorThreadHandle, INFINITE) + CloseHandle(ioPort) + CloseHandle(monitorThreadHandle) + } + + private func registerHandle(_ handle: HANDLE) -> SignalStream { + return SignalStream { continuation in + switch self.ioCompletionPort { + case .success(let ioPort): + // Make sure thread setup also succeed + if case .failure(let error) = monitorThread { + continuation.finish(throwing: error) + return + } + let completionKey = UInt64(UInt(bitPattern: handle)) + // Windows does not offer an API to remove a handle + // from given ioCompletionPort. If this handle has already + // been registered we simply need to update the continuation + let registrationFound = _registration.withLock { storage in + if storage[completionKey] != nil { + // Old registration found. This means this handle has + // already been registered. We simply need to update + // the continuation saved + storage[completionKey] = continuation + return true + } else { + return false + } + } + if registrationFound { + return + } + + // Windows Documentation: The function returns the handle + // of the existing I/O completion port if successful + guard CreateIoCompletionPort( + handle, ioPort, completionKey, 0 + ) == ioPort else { + let error = SubprocessError( + code: .init(.asyncIOFailed("CreateIoCompletionPort failed")), + underlyingError: .init(rawValue: GetLastError()) + ) + continuation.finish(throwing: error) + return + } + // Now save the continuation + _registration.withLock { storage in + storage[completionKey] = continuation + } + case .failure(let error): + continuation.finish(throwing: error) + } + } + } + + internal func removeRegistration(for handle: HANDLE) { + let completionKey = UInt64(UInt(bitPattern: handle)) + _registration.withLock { storage in + storage.removeValue(forKey: completionKey) + } + } func read( - from diskIO: borrowing TrackedPlatformDiskIO, + from diskIO: borrowing IOChannel, upTo maxLength: Int ) async throws -> [UInt8]? { - return try await self.read(from: diskIO.fileDescriptor, upTo: maxLength) + return try await self.read(from: diskIO.channel, upTo: maxLength) } func read( - from fileDescriptor: FileDescriptor, + from handle: HANDLE, upTo maxLength: Int ) async throws -> [UInt8]? { - return try await withCheckedThrowingContinuation { continuation in - DispatchQueue.global(qos: .userInitiated).async { - var totalBytesRead: Int = 0 - var lastError: DWORD? = nil - let values = [UInt8]( - unsafeUninitializedCapacity: maxLength - ) { buffer, initializedCount in - while true { - guard let baseAddress = buffer.baseAddress else { - initializedCount = 0 - break - } - let bufferPtr = baseAddress.advanced(by: totalBytesRead) - var bytesRead: DWORD = 0 - let readSucceed = ReadFile( - fileDescriptor.platformDescriptor, - UnsafeMutableRawPointer(mutating: bufferPtr), - DWORD(maxLength - totalBytesRead), - &bytesRead, - nil - ) - if !readSucceed { - // Windows throws ERROR_BROKEN_PIPE when the pipe is closed - let error = GetLastError() - if error == ERROR_BROKEN_PIPE { - // We are done reading - initializedCount = totalBytesRead - } else { - // We got some error - lastError = error - initializedCount = 0 - } - break - } else { - // We successfully read the current round - totalBytesRead += Int(bytesRead) - } + // If we are reading until EOF, start with readBufferSize + // and gradually increase buffer size + let bufferLength = maxLength == .max ? readBufferSize : maxLength - if totalBytesRead >= maxLength { - initializedCount = min(maxLength, totalBytesRead) - break - } - } + var resultBuffer: [UInt8] = Array( + repeating: 0, count: bufferLength + ) + var readLength: Int = 0 + var signalStream = self.registerHandle(handle).makeAsyncIterator() + + while true { + var overlapped = _OVERLAPPED() + let succeed = try resultBuffer.withUnsafeMutableBufferPointer { bufferPointer in + // Get a pointer to the memory at the specified offset + // Windows ReadFile uses DWORD for target count, which means we can only + // read up to DWORD (aka UInt32) max. + let targetCount: DWORD + if MemoryLayout.size == MemoryLayout.size { + // On 32 bit systems we don't have to worry about overflowing + targetCount = DWORD(truncatingIfNeeded: bufferPointer.count - readLength) + } else { + // On 64 bit systems we need to cap the count at DWORD max + targetCount = DWORD(truncatingIfNeeded: min(bufferPointer.count - readLength, Int(UInt32.max))) } - if let lastError = lastError { - let windowsError = SubprocessError( + + let offsetAddress = bufferPointer.baseAddress!.advanced(by: readLength) + // Read directly into the buffer at the offset + return ReadFile( + handle, + offsetAddress, + DWORD(truncatingIfNeeded: targetCount), + nil, + &overlapped + ) + } + + if !succeed { + // It is expected `ReadFile` to return `false` in async mode. + // Make sure we only get `ERROR_IO_PENDING` or `ERROR_BROKEN_PIPE` + let lastError = GetLastError() + if lastError == ERROR_BROKEN_PIPE { + // We reached EOF + return nil + } + guard lastError == ERROR_IO_PENDING else { + let error = SubprocessError( code: .init(.failedToReadFromSubprocess), underlyingError: .init(rawValue: lastError) ) - continuation.resume(throwing: windowsError) - } else { - // If we didn't read anything, return nil - if values.isEmpty { - continuation.resume(returning: nil) - } else { - continuation.resume(returning: values) + throw error + } + + } + // Now wait for read to finish + let bytesRead = try await signalStream.next() ?? 0 + + if bytesRead == 0 { + // We reached EOF. Return whatever's left + guard readLength > 0 else { + return nil + } + resultBuffer.removeLast(resultBuffer.count - readLength) + return resultBuffer + } else { + // Read some data + readLength += Int(bytesRead) + if maxLength == .max { + // Grow resultBuffer if needed + guard Double(readLength) > 0.8 * Double(resultBuffer.count) else { + continue } + resultBuffer.append( + contentsOf: Array(repeating: 0, count: resultBuffer.count) + ) + } else if readLength >= maxLength { + // When we reached maxLength, return! + return resultBuffer } } } @@ -752,7 +958,7 @@ final class AsyncIO: Sendable { func write( _ array: [UInt8], - to diskIO: borrowing TrackedPlatformDiskIO + to diskIO: borrowing IOChannel ) async throws -> Int { return try await self._write(array, to: diskIO) } @@ -760,36 +966,53 @@ final class AsyncIO: Sendable { #if SubprocessSpan func write( _ span: borrowing RawSpan, - to diskIO: borrowing TrackedPlatformDiskIO + to diskIO: borrowing IOChannel ) async throws -> Int { - // TODO: Remove this hack with I/O Completion Ports rewrite - struct _Box: @unchecked Sendable { - let ptr: UnsafeRawBufferPointer - } - let fileDescriptor = diskIO.fileDescriptor - return try await withCheckedThrowingContinuation { continuation in - span.withUnsafeBytes { ptr in - let box = _Box(ptr: ptr) - DispatchQueue.global().async { - let handle = HANDLE(bitPattern: _get_osfhandle(fileDescriptor.rawValue))! - var writtenBytes: DWORD = 0 - let writeSucceed = WriteFile( - handle, - box.ptr.baseAddress, - DWORD(box.ptr.count), - &writtenBytes, - nil + let handle = diskIO.channel + var signalStream = self.registerHandle(diskIO.channel).makeAsyncIterator() + var writtenLength: Int = 0 + while true { + var overlapped = _OVERLAPPED() + let succeed = try span.withUnsafeBytes { ptr in + // Windows WriteFile uses DWORD for target count + // which means we can only write up to DWORD max + let remainingLength: DWORD + if MemoryLayout.size == MemoryLayout.size { + // On 32 bit systems we don't have to worry about overflowing + remainingLength = DWORD(truncatingIfNeeded: ptr.count - writtenLength) + } else { + // On 64 bit systems we need to cap the count at DWORD max + remainingLength = DWORD(truncatingIfNeeded: min(ptr.count - writtenLength, Int(DWORD.max))) + } + + let startPtr = ptr.baseAddress!.advanced(by: writtenLength) + return WriteFile( + handle, + startPtr, + DWORD(truncatingIfNeeded: remainingLength), + nil, + &overlapped + ) + } + if !succeed { + // It is expected `WriteFile` to return `false` in async mode. + // Make sure we only get `ERROR_IO_PENDING` + let lastError = GetLastError() + guard lastError == ERROR_IO_PENDING else { + let error = SubprocessError( + code: .init(.failedToWriteToSubprocess), + underlyingError: .init(rawValue: lastError) ) - if !writeSucceed { - let error = SubprocessError( - code: .init(.failedToWriteToSubprocess), - underlyingError: .init(rawValue: GetLastError()) - ) - continuation.resume(throwing: error) - } else { - continuation.resume(returning: Int(writtenBytes)) - } + throw error } + + } + // Now wait for read to finish + let bytesWritten: DWORD = try await signalStream.next() ?? 0 + + writtenLength += Int(bytesWritten) + if writtenLength >= span.byteCount { + return writtenLength } } } @@ -797,32 +1020,52 @@ final class AsyncIO: Sendable { func _write( _ bytes: Bytes, - to diskIO: borrowing TrackedPlatformDiskIO + to diskIO: borrowing IOChannel ) async throws -> Int { - let fileDescriptor = diskIO.fileDescriptor - return try await withCheckedThrowingContinuation { continuation in - DispatchQueue.global().async { - let handle = HANDLE(bitPattern: _get_osfhandle(fileDescriptor.rawValue))! - var writtenBytes: DWORD = 0 - let writeSucceed = bytes.withUnsafeBytes { ptr in - return WriteFile( - handle, - ptr.baseAddress, - DWORD(ptr.count), - &writtenBytes, - nil - ) + let handle = diskIO.channel + var signalStream = self.registerHandle(diskIO.channel).makeAsyncIterator() + var writtenLength: Int = 0 + while true { + var overlapped = _OVERLAPPED() + let succeed = try bytes.withUnsafeBytes { ptr in + // Windows WriteFile uses DWORD for target count + // which means we can only write up to DWORD max + let remainingLength: DWORD + if MemoryLayout.size == MemoryLayout.size { + // On 32 bit systems we don't have to worry about overflowing + remainingLength = DWORD(truncatingIfNeeded: ptr.count - writtenLength) + } else { + // On 64 bit systems we need to cap the count at DWORD max + remainingLength = DWORD(truncatingIfNeeded: min(ptr.count - writtenLength, Int(DWORD.max))) } - if !writeSucceed { + let startPtr = ptr.baseAddress!.advanced(by: writtenLength) + return WriteFile( + handle, + startPtr, + DWORD(truncatingIfNeeded: remainingLength), + nil, + &overlapped + ) + } + + if !succeed { + // It is expected `WriteFile` to return `false` in async mode. + // Make sure we only get `ERROR_IO_PENDING` + let lastError = GetLastError() + guard lastError == ERROR_IO_PENDING else { let error = SubprocessError( code: .init(.failedToWriteToSubprocess), - underlyingError: .init(rawValue: GetLastError()) + underlyingError: .init(rawValue: lastError) ) - continuation.resume(throwing: error) - } else { - continuation.resume(returning: Int(writtenBytes)) + throw error } } + // Now wait for read to finish + let bytesWritten: DWORD = try await signalStream.next() ?? 0 + writtenLength += Int(bytesWritten) + if writtenLength >= bytes.count { + return writtenLength + } } } } diff --git a/Sources/Subprocess/IO/Input.swift b/Sources/Subprocess/IO/Input.swift index 3f473e4..715428e 100644 --- a/Sources/Subprocess/IO/Input.swift +++ b/Sources/Subprocess/IO/Input.swift @@ -15,6 +15,10 @@ @preconcurrency import SystemPackage #endif +#if canImport(WinSDK) +@preconcurrency import WinSDK +#endif + #if SubprocessFoundation #if canImport(Darwin) @@ -78,9 +82,14 @@ public struct FileDescriptorInput: InputProtocol { private let closeAfterSpawningProcess: Bool internal func createPipe() throws -> CreatedPipe { + #if canImport(WinSDK) + let readFd = HANDLE(bitPattern: _get_osfhandle(self.fileDescriptor.rawValue))! + #else + let readFd = self.fileDescriptor + #endif return CreatedPipe( readFileDescriptor: .init( - self.fileDescriptor, + readFd, closeWhenDone: self.closeAfterSpawningProcess ), writeFileDescriptor: nil @@ -203,7 +212,7 @@ extension InputProtocol { return try fdInput.createPipe() } // Base implementation - return try CreatedPipe(closeWhenDone: true) + return try CreatedPipe(closeWhenDone: true, purpose: .input) } } @@ -212,9 +221,9 @@ extension InputProtocol { /// A writer that writes to the standard input of the subprocess. public final actor StandardInputWriter: Sendable { - internal var diskIO: TrackedPlatformDiskIO + internal var diskIO: IOChannel - init(diskIO: consuming TrackedPlatformDiskIO) { + init(diskIO: consuming IOChannel) { self.diskIO = diskIO } diff --git a/Sources/Subprocess/IO/Output.swift b/Sources/Subprocess/IO/Output.swift index 983b473..223ccd6 100644 --- a/Sources/Subprocess/IO/Output.swift +++ b/Sources/Subprocess/IO/Output.swift @@ -14,6 +14,11 @@ #else @preconcurrency import SystemPackage #endif + +#if canImport(WinSDK) +@preconcurrency import WinSDK +#endif + internal import Dispatch // MARK: - Output @@ -85,10 +90,15 @@ public struct FileDescriptorOutput: OutputProtocol { private let fileDescriptor: FileDescriptor internal func createPipe() throws -> CreatedPipe { + #if canImport(WinSDK) + let writeFd = HANDLE(bitPattern: _get_osfhandle(self.fileDescriptor.rawValue))! + #else + let writeFd = self.fileDescriptor + #endif return CreatedPipe( readFileDescriptor: nil, writeFileDescriptor: .init( - self.fileDescriptor, + writeFd, closeWhenDone: self.closeAfterSpawningProcess ) ) @@ -140,7 +150,7 @@ public struct BytesOutput: OutputProtocol { public let maxSize: Int internal func captureOutput( - from diskIO: consuming TrackedPlatformDiskIO + from diskIO: consuming IOChannel ) async throws -> [UInt8] { let result = try await AsyncIO.shared.read(from: diskIO, upTo: self.maxSize) try diskIO.safelyClose() @@ -247,13 +257,13 @@ extension OutputProtocol { return try fdOutput.createPipe() } // Base pipe based implementation for everything else - return try CreatedPipe(closeWhenDone: true) + return try CreatedPipe(closeWhenDone: true, purpose: .output) } /// Capture the output from the subprocess up to maxSize @_disfavoredOverload internal func captureOutput( - from diskIO: consuming TrackedPlatformDiskIO? + from diskIO: consuming IOChannel? ) async throws -> OutputType { if OutputType.self == Void.self { return () as! OutputType @@ -272,7 +282,7 @@ extension OutputProtocol { return try await bytesOutput.captureOutput(from: diskIO) as! Self.OutputType } // Force unwrap is safe here because only `OutputType.self == Void` would - // have nil `TrackedPlatformDiskIO` + // have nil `IOChannel` let result = try await AsyncIO.shared.read(from: diskIO, upTo: self.maxSize) try diskIO.safelyClose() #if canImport(Darwin) @@ -284,7 +294,7 @@ extension OutputProtocol { } extension OutputProtocol where OutputType == Void { - internal func captureOutput(from fileDescriptor: consuming TrackedPlatformDiskIO?) async throws {} + internal func captureOutput(from fileDescriptor: consuming IOChannel?) async throws {} #if SubprocessSpan /// Convert the output from Data to expected output type diff --git a/Sources/Subprocess/Platforms/Subprocess+Darwin.swift b/Sources/Subprocess/Platforms/Subprocess+Darwin.swift index 02640ab..5309f1e 100644 --- a/Sources/Subprocess/Platforms/Subprocess+Darwin.swift +++ b/Sources/Subprocess/Platforms/Subprocess+Darwin.swift @@ -173,12 +173,12 @@ extension Configuration { var _outputPipe = outputPipeBox.take()! var _errorPipe = errorPipeBox.take()! - let inputReadFileDescriptor: TrackedFileDescriptor? = _inputPipe.readFileDescriptor() - let inputWriteFileDescriptor: TrackedFileDescriptor? = _inputPipe.writeFileDescriptor() - let outputReadFileDescriptor: TrackedFileDescriptor? = _outputPipe.readFileDescriptor() - let outputWriteFileDescriptor: TrackedFileDescriptor? = _outputPipe.writeFileDescriptor() - let errorReadFileDescriptor: TrackedFileDescriptor? = _errorPipe.readFileDescriptor() - let errorWriteFileDescriptor: TrackedFileDescriptor? = _errorPipe.writeFileDescriptor() + let inputReadFileDescriptor: IODescriptor? = _inputPipe.readFileDescriptor() + let inputWriteFileDescriptor: IODescriptor? = _inputPipe.writeFileDescriptor() + let outputReadFileDescriptor: IODescriptor? = _outputPipe.readFileDescriptor() + let outputWriteFileDescriptor: IODescriptor? = _outputPipe.writeFileDescriptor() + let errorReadFileDescriptor: IODescriptor? = _errorPipe.readFileDescriptor() + let errorWriteFileDescriptor: IODescriptor? = _errorPipe.writeFileDescriptor() for possibleExecutablePath in possiblePaths { var pid: pid_t = 0 @@ -442,9 +442,9 @@ extension Configuration { ) return SpawnResult( execution: execution, - inputWriteEnd: inputWriteFileDescriptor?.createPlatformDiskIO(), - outputReadEnd: outputReadFileDescriptor?.createPlatformDiskIO(), - errorReadEnd: errorReadFileDescriptor?.createPlatformDiskIO() + inputWriteEnd: inputWriteFileDescriptor?.createIOChannel(), + outputReadEnd: outputReadFileDescriptor?.createIOChannel(), + errorReadEnd: errorReadFileDescriptor?.createIOChannel() ) } @@ -524,30 +524,4 @@ internal func monitorProcessTermination( } } -internal typealias TrackedPlatformDiskIO = TrackedDispatchIO - -extension TrackedFileDescriptor { - internal consuming func createPlatformDiskIO() -> TrackedPlatformDiskIO { - // Transferring out the ownership of fileDescriptor means we don't have go close here - let shouldClose = self.closeWhenDone - let closeFd = self.fileDescriptor - let dispatchIO: DispatchIO = DispatchIO( - type: .stream, - fileDescriptor: self.platformDescriptor(), - queue: .global(), - cleanupHandler: { error in - // Close the file descriptor - if shouldClose { - try? closeFd.close() - } - } - ) - let result: TrackedPlatformDiskIO = .init( - dispatchIO, closeWhenDone: self.closeWhenDone - ) - self.closeWhenDone = false - return result - } -} - #endif // canImport(Darwin) diff --git a/Sources/Subprocess/Platforms/Subprocess+Linux.swift b/Sources/Subprocess/Platforms/Subprocess+Linux.swift index 47fbe96..3b449ac 100644 --- a/Sources/Subprocess/Platforms/Subprocess+Linux.swift +++ b/Sources/Subprocess/Platforms/Subprocess+Linux.swift @@ -56,12 +56,12 @@ extension Configuration { var _outputPipe = outputPipeBox.take()! var _errorPipe = errorPipeBox.take()! - let inputReadFileDescriptor: TrackedFileDescriptor? = _inputPipe.readFileDescriptor() - let inputWriteFileDescriptor: TrackedFileDescriptor? = _inputPipe.writeFileDescriptor() - let outputReadFileDescriptor: TrackedFileDescriptor? = _outputPipe.readFileDescriptor() - let outputWriteFileDescriptor: TrackedFileDescriptor? = _outputPipe.writeFileDescriptor() - let errorReadFileDescriptor: TrackedFileDescriptor? = _errorPipe.readFileDescriptor() - let errorWriteFileDescriptor: TrackedFileDescriptor? = _errorPipe.writeFileDescriptor() + let inputReadFileDescriptor: IODescriptor? = _inputPipe.readFileDescriptor() + let inputWriteFileDescriptor: IODescriptor? = _inputPipe.writeFileDescriptor() + let outputReadFileDescriptor: IODescriptor? = _outputPipe.readFileDescriptor() + let outputWriteFileDescriptor: IODescriptor? = _outputPipe.writeFileDescriptor() + let errorReadFileDescriptor: IODescriptor? = _errorPipe.readFileDescriptor() + let errorWriteFileDescriptor: IODescriptor? = _errorPipe.writeFileDescriptor() for possibleExecutablePath in possiblePaths { var processGroupIDPtr: UnsafeMutablePointer? = nil @@ -154,9 +154,9 @@ extension Configuration { ) return SpawnResult( execution: execution, - inputWriteEnd: inputWriteFileDescriptor?.createPlatformDiskIO(), - outputReadEnd: outputReadFileDescriptor?.createPlatformDiskIO(), - errorReadEnd: errorReadFileDescriptor?.createPlatformDiskIO() + inputWriteEnd: inputWriteFileDescriptor?.createIOChannel(), + outputReadEnd: outputReadFileDescriptor?.createIOChannel(), + errorReadEnd: errorReadFileDescriptor?.createIOChannel() ) } @@ -423,18 +423,4 @@ private func _setupMonitorSignalHandler() { setup } -internal typealias TrackedPlatformDiskIO = TrackedFileDescriptor - -extension TrackedFileDescriptor { - internal consuming func createPlatformDiskIO() -> TrackedPlatformDiskIO { - // Transferring out the ownership of fileDescriptor means we don't have go close here - let result: TrackedPlatformDiskIO = .init( - self.fileDescriptor, - closeWhenDone: self.closeWhenDone - ) - self.closeWhenDone = false - return result - } -} - #endif // canImport(Glibc) || canImport(Android) || canImport(Musl) diff --git a/Sources/Subprocess/Platforms/Subprocess+Windows.swift b/Sources/Subprocess/Platforms/Subprocess+Windows.swift index db799ba..63229aa 100644 --- a/Sources/Subprocess/Platforms/Subprocess+Windows.swift +++ b/Sources/Subprocess/Platforms/Subprocess+Windows.swift @@ -11,7 +11,7 @@ #if canImport(WinSDK) -import WinSDK +@preconcurrency import WinSDK internal import Dispatch #if canImport(System) @preconcurrency import System @@ -48,12 +48,12 @@ extension Configuration { outputPipe: consuming CreatedPipe, errorPipe: consuming CreatedPipe ) throws -> SpawnResult { - var inputReadFileDescriptor: TrackedFileDescriptor? = inputPipe.readFileDescriptor() - var inputWriteFileDescriptor: TrackedFileDescriptor? = inputPipe.writeFileDescriptor() - var outputReadFileDescriptor: TrackedFileDescriptor? = outputPipe.readFileDescriptor() - var outputWriteFileDescriptor: TrackedFileDescriptor? = outputPipe.writeFileDescriptor() - var errorReadFileDescriptor: TrackedFileDescriptor? = errorPipe.readFileDescriptor() - var errorWriteFileDescriptor: TrackedFileDescriptor? = errorPipe.writeFileDescriptor() + var inputReadFileDescriptor: IODescriptor? = inputPipe.readFileDescriptor() + var inputWriteFileDescriptor: IODescriptor? = inputPipe.writeFileDescriptor() + var outputReadFileDescriptor: IODescriptor? = outputPipe.readFileDescriptor() + var outputWriteFileDescriptor: IODescriptor? = outputPipe.writeFileDescriptor() + var errorReadFileDescriptor: IODescriptor? = errorPipe.readFileDescriptor() + var errorWriteFileDescriptor: IODescriptor? = errorPipe.writeFileDescriptor() let applicationName: String? let commandAndArgs: String @@ -162,9 +162,9 @@ extension Configuration { return SpawnResult( execution: execution, - inputWriteEnd: inputWriteFileDescriptor?.createPlatformDiskIO(), - outputReadEnd: outputReadFileDescriptor?.createPlatformDiskIO(), - errorReadEnd: errorReadFileDescriptor?.createPlatformDiskIO() + inputWriteEnd: inputWriteFileDescriptor?.createIOChannel(), + outputReadEnd: outputReadFileDescriptor?.createIOChannel(), + errorReadEnd: errorReadFileDescriptor?.createIOChannel() ) } @@ -182,12 +182,12 @@ extension Configuration { var _outputPipe = outputPipeBox.take()! var _errorPipe = errorPipeBox.take()! - let inputReadFileDescriptor: TrackedFileDescriptor? = _inputPipe.readFileDescriptor() - let inputWriteFileDescriptor: TrackedFileDescriptor? = _inputPipe.writeFileDescriptor() - let outputReadFileDescriptor: TrackedFileDescriptor? = _outputPipe.readFileDescriptor() - let outputWriteFileDescriptor: TrackedFileDescriptor? = _outputPipe.writeFileDescriptor() - let errorReadFileDescriptor: TrackedFileDescriptor? = _errorPipe.readFileDescriptor() - let errorWriteFileDescriptor: TrackedFileDescriptor? = _errorPipe.writeFileDescriptor() + let inputReadFileDescriptor: IODescriptor? = _inputPipe.readFileDescriptor() + let inputWriteFileDescriptor: IODescriptor? = _inputPipe.writeFileDescriptor() + let outputReadFileDescriptor: IODescriptor? = _outputPipe.readFileDescriptor() + let outputWriteFileDescriptor: IODescriptor? = _outputPipe.writeFileDescriptor() + let errorReadFileDescriptor: IODescriptor? = _errorPipe.readFileDescriptor() + let errorWriteFileDescriptor: IODescriptor? = _errorPipe.writeFileDescriptor() let ( applicationName, @@ -306,9 +306,9 @@ extension Configuration { return SpawnResult( execution: execution, - inputWriteEnd: inputWriteFileDescriptor?.createPlatformDiskIO(), - outputReadEnd: outputReadFileDescriptor?.createPlatformDiskIO(), - errorReadEnd: errorReadFileDescriptor?.createPlatformDiskIO() + inputWriteEnd: inputWriteFileDescriptor?.createIOChannel(), + outputReadEnd: outputReadFileDescriptor?.createIOChannel(), + errorReadEnd: errorReadFileDescriptor?.createIOChannel() ) } } @@ -773,12 +773,12 @@ extension Configuration { } private func generateStartupInfo( - withInputRead inputReadFileDescriptor: borrowing TrackedFileDescriptor?, - inputWrite inputWriteFileDescriptor: borrowing TrackedFileDescriptor?, - outputRead outputReadFileDescriptor: borrowing TrackedFileDescriptor?, - outputWrite outputWriteFileDescriptor: borrowing TrackedFileDescriptor?, - errorRead errorReadFileDescriptor: borrowing TrackedFileDescriptor?, - errorWrite errorWriteFileDescriptor: borrowing TrackedFileDescriptor?, + withInputRead inputReadFileDescriptor: borrowing IODescriptor?, + inputWrite inputWriteFileDescriptor: borrowing IODescriptor?, + outputRead outputReadFileDescriptor: borrowing IODescriptor?, + outputWrite outputWriteFileDescriptor: borrowing IODescriptor?, + errorRead errorReadFileDescriptor: borrowing IODescriptor?, + errorWrite errorWriteFileDescriptor: borrowing IODescriptor?, ) throws -> STARTUPINFOW { var info: STARTUPINFOW = STARTUPINFOW() info.cb = DWORD(MemoryLayout.size) @@ -950,8 +950,6 @@ extension Configuration { // MARK: - Type alias internal typealias PlatformFileDescriptor = HANDLE -internal typealias TrackedPlatformDiskIO = TrackedFileDescriptor - // MARK: - Pipe Support extension FileDescriptor { // NOTE: Not the same as SwiftSystem's FileDescriptor.pipe, which has different behavior, @@ -999,18 +997,6 @@ extension FileDescriptor { } } -extension TrackedFileDescriptor { - internal consuming func createPlatformDiskIO() -> TrackedPlatformDiskIO { - // Transferring out the ownership of fileDescriptor means we don't have go close here - let result: TrackedPlatformDiskIO = .init( - self.fileDescriptor, - closeWhenDone: self.closeWhenDone - ) - self.closeWhenDone = false - return result - } -} - extension Optional where Wrapped == String { fileprivate func withOptionalCString( encodedAs targetEncoding: Encoding.Type, diff --git a/Sources/Subprocess/SubprocessFoundation/Input+Foundation.swift b/Sources/Subprocess/SubprocessFoundation/Input+Foundation.swift index 8312c56..c82d2d3 100644 --- a/Sources/Subprocess/SubprocessFoundation/Input+Foundation.swift +++ b/Sources/Subprocess/SubprocessFoundation/Input+Foundation.swift @@ -133,7 +133,7 @@ extension StandardInputWriter { extension AsyncIO { internal func write( _ data: Data, - to diskIO: borrowing TrackedPlatformDiskIO + to diskIO: borrowing IOChannel ) async throws -> Int { try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in let dispatchData = data.withUnsafeBytes { @@ -163,7 +163,7 @@ extension Data : AsyncIO._ContiguousBytes { } extension AsyncIO { internal func write( _ data: Data, - to diskIO: borrowing TrackedPlatformDiskIO + to diskIO: borrowing IOChannel ) async throws -> Int { return try await self._write(data, to: diskIO) } diff --git a/Sources/Subprocess/Teardown.swift b/Sources/Subprocess/Teardown.swift index a4c0352..d5dd551 100644 --- a/Sources/Subprocess/Teardown.swift +++ b/Sources/Subprocess/Teardown.swift @@ -20,7 +20,7 @@ import Glibc #elseif canImport(Musl) import Musl #elseif canImport(WinSDK) -import WinSDK +@preconcurrency import WinSDK #endif /// A step in the graceful shutdown teardown sequence. diff --git a/Tests/SubprocessTests/SubprocessTests+Windows.swift b/Tests/SubprocessTests/SubprocessTests+Windows.swift index c2981e3..3b9bed7 100644 --- a/Tests/SubprocessTests/SubprocessTests+Windows.swift +++ b/Tests/SubprocessTests/SubprocessTests+Windows.swift @@ -11,7 +11,7 @@ #if canImport(WinSDK) -import WinSDK +@preconcurrency import WinSDK import FoundationEssentials import Testing import Dispatch @@ -269,7 +269,6 @@ extension SubprocessWindowsTests { output: .data(limit: 2048 * 1024), error: .discarded ) - #expect(catResult.terminationStatus.isSuccess) // Make sure we read all bytes #expect( catResult.standardOutput == expected @@ -324,7 +323,6 @@ extension SubprocessWindowsTests { } return buffer } - #expect(result.terminationStatus.isSuccess) #expect(result.value == expected) } @@ -361,7 +359,6 @@ extension SubprocessWindowsTests { } return buffer } - #expect(result.terminationStatus.isSuccess) #expect(result.value == expected) } } @@ -548,7 +545,7 @@ extension SubprocessWindowsTests { var platformOptions: Subprocess.PlatformOptions = .init() platformOptions.consoleBehavior = .detach let detachConsoleResult = try await Subprocess.run( - .path("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"), + .name("powershell.exe"), arguments: [ "-File", windowsTester.string, "-mode", "get-console-window", @@ -572,7 +569,7 @@ extension SubprocessWindowsTests { } let parentConsole = GetConsoleWindow() let newConsoleResult = try await Subprocess.run( - .path("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"), + .name("powershell.exe"), arguments: [ "-File", windowsTester.string, "-mode", "get-console-window", @@ -603,7 +600,7 @@ extension SubprocessWindowsTests { } } let changeTitleResult = try await Subprocess.run( - .path("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"), + .name("powershell.exe"), arguments: [ "-Command", "$consoleTitle = [console]::Title; Write-Host $consoleTitle", ], @@ -651,7 +648,7 @@ extension SubprocessWindowsTests { // Now check the to make sure the process is actually suspended // Why not spawn another process to do that? var checkResult = try await Subprocess.run( - .path("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"), + .name("powershell.exe"), arguments: [ "-File", windowsTester.string, "-mode", "is-process-suspended", @@ -668,7 +665,7 @@ extension SubprocessWindowsTests { // Now resume the process try subprocess.resume() checkResult = try await Subprocess.run( - .path("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"), + .name("powershell.exe"), arguments: [ "-File", windowsTester.string, "-mode", "is-process-suspended", @@ -720,7 +717,6 @@ extension SubprocessWindowsTests { WaitForSingleObject(processHandle, INFINITE) // Up to 10 characters because Windows process IDs are DWORDs (UInt32), whose max value is 10 digits. - try writeFd.close() let data = try await readFd.readUntilEOF(upToLength: 10) let resultPID = try #require( String(data: data, encoding: .utf8) diff --git a/Tests/TestResources/TestResources.swift b/Tests/TestResources/TestResources.swift index 6a09808..7030862 100644 --- a/Tests/TestResources/TestResources.swift +++ b/Tests/TestResources/TestResources.swift @@ -10,7 +10,7 @@ //===----------------------------------------------------------------------===// #if canImport(WinSDK) -import WinSDK +@preconcurrency import WinSDK #endif // Confitionally require Foundation due to `Bundle.module` From 22e4b383d1516359ad6c9ccee9a6df4adc59a524 Mon Sep 17 00:00:00 2001 From: Charles Hu Date: Wed, 23 Jul 2025 14:17:03 -0700 Subject: [PATCH 07/10] Move platform specific AsyncIO implementations to separate files --- Sources/Subprocess/CMakeLists.txt | 4 +- Sources/Subprocess/IO/AsyncIO+Darwin.swift | 165 +++ Sources/Subprocess/IO/AsyncIO+Linux.swift | 507 +++++++++ Sources/Subprocess/IO/AsyncIO+Windows.swift | 437 ++++++++ Sources/Subprocess/IO/AsyncIO.swift | 1076 ------------------- 5 files changed, 1112 insertions(+), 1077 deletions(-) create mode 100644 Sources/Subprocess/IO/AsyncIO+Darwin.swift create mode 100644 Sources/Subprocess/IO/AsyncIO+Linux.swift create mode 100644 Sources/Subprocess/IO/AsyncIO+Windows.swift delete mode 100644 Sources/Subprocess/IO/AsyncIO.swift diff --git a/Sources/Subprocess/CMakeLists.txt b/Sources/Subprocess/CMakeLists.txt index c136bfe..ea95c73 100644 --- a/Sources/Subprocess/CMakeLists.txt +++ b/Sources/Subprocess/CMakeLists.txt @@ -17,7 +17,9 @@ add_library(Subprocess Result.swift IO/Output.swift IO/Input.swift - IO/AsyncIO.swift + IO/AsyncIO+Darwin.swift + IO/AsyncIO+Linux.swift + IO/AsyncIO+Windows.swift Span+Subprocess.swift AsyncBufferSequence.swift API.swift diff --git a/Sources/Subprocess/IO/AsyncIO+Darwin.swift b/Sources/Subprocess/IO/AsyncIO+Darwin.swift new file mode 100644 index 0000000..218f75d --- /dev/null +++ b/Sources/Subprocess/IO/AsyncIO+Darwin.swift @@ -0,0 +1,165 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2025 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +/// Darwin AsyncIO implementation based on DispatchIO + +// MARK: - macOS (DispatchIO) +#if canImport(Darwin) + +#if canImport(System) +@preconcurrency import System +#else +@preconcurrency import SystemPackage +#endif + +internal import Dispatch + +final class AsyncIO: Sendable { + static let shared: AsyncIO = AsyncIO() + + private init() {} + + internal func read( + from diskIO: borrowing IOChannel, + upTo maxLength: Int + ) async throws -> DispatchData? { + return try await self.read( + from: diskIO.channel, + upTo: maxLength, + ) + } + + internal func read( + from dispatchIO: DispatchIO, + upTo maxLength: Int + ) async throws -> DispatchData? { + return try await withCheckedThrowingContinuation { continuation in + var buffer: DispatchData = .empty + dispatchIO.read( + offset: 0, + length: maxLength, + queue: .global() + ) { done, data, error in + if error != 0 { + continuation.resume( + throwing: SubprocessError( + code: .init(.failedToReadFromSubprocess), + underlyingError: .init(rawValue: error) + ) + ) + return + } + if let data = data { + if buffer.isEmpty { + buffer = data + } else { + buffer.append(data) + } + } + if done { + if !buffer.isEmpty { + continuation.resume(returning: buffer) + } else { + continuation.resume(returning: nil) + } + } + } + } + } + + #if SubprocessSpan + internal func write( + _ span: borrowing RawSpan, + to diskIO: borrowing IOChannel + ) async throws -> Int { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let dispatchData = span.withUnsafeBytes { + return DispatchData( + bytesNoCopy: $0, + deallocator: .custom( + nil, + { + // noop + } + ) + ) + } + self.write(dispatchData, to: diskIO) { writtenLength, error in + if let error = error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: writtenLength) + } + } + } + } + #endif // SubprocessSpan + + internal func write( + _ array: [UInt8], + to diskIO: borrowing IOChannel + ) async throws -> Int { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let dispatchData = array.withUnsafeBytes { + return DispatchData( + bytesNoCopy: $0, + deallocator: .custom( + nil, + { + // noop + } + ) + ) + } + self.write(dispatchData, to: diskIO) { writtenLength, error in + if let error = error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: writtenLength) + } + } + } + } + + internal func write( + _ dispatchData: DispatchData, + to diskIO: borrowing IOChannel, + queue: DispatchQueue = .global(), + completion: @escaping (Int, Error?) -> Void + ) { + diskIO.channel.write( + offset: 0, + data: dispatchData, + queue: queue + ) { done, unwritten, error in + guard done else { + // Wait until we are done writing or encountered some error + return + } + + let unwrittenLength = unwritten?.count ?? 0 + let writtenLength = dispatchData.count - unwrittenLength + guard error != 0 else { + completion(writtenLength, nil) + return + } + completion( + writtenLength, + SubprocessError( + code: .init(.failedToWriteToSubprocess), + underlyingError: .init(rawValue: error) + ) + ) + } + } +} + +#endif diff --git a/Sources/Subprocess/IO/AsyncIO+Linux.swift b/Sources/Subprocess/IO/AsyncIO+Linux.swift new file mode 100644 index 0000000..c571085 --- /dev/null +++ b/Sources/Subprocess/IO/AsyncIO+Linux.swift @@ -0,0 +1,507 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2025 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +/// Linux AsyncIO implementation based on epoll + +#if canImport(Glibc) || canImport(Android) || canImport(Musl) + +#if canImport(System) +@preconcurrency import System +#else +@preconcurrency import SystemPackage +#endif + +#if canImport(Glibc) +import Glibc +#elseif canImport(Android) +import Android +#elseif canImport(Musl) +import Musl +#endif + +import _SubprocessCShims +import Synchronization + +private typealias SignalStream = AsyncThrowingStream +private let _epollEventSize = 256 +private let _registration: Mutex< + [PlatformFileDescriptor : SignalStream.Continuation] +> = Mutex([:]) + +final class AsyncIO: Sendable { + + typealias OutputStream = AsyncThrowingStream + + private final class MonitorThreadContext { + let epollFileDescriptor: CInt + let shutdownFileDescriptor: CInt + + init( + epollFileDescriptor: CInt, + shutdownFileDescriptor: CInt + ) { + self.epollFileDescriptor = epollFileDescriptor + self.shutdownFileDescriptor = shutdownFileDescriptor + } + } + + private enum Event { + case read + case write + } + + private struct State { + let epollFileDescriptor: CInt + let shutdownFileDescriptor: CInt + let monitorThread: pthread_t + } + + static let shared: AsyncIO = AsyncIO() + + private let state: Result + + private init() { + // Create main epoll fd + let epollFileDescriptor = epoll_create1(CInt(EPOLL_CLOEXEC)) + guard epollFileDescriptor >= 0 else { + let error = SubprocessError( + code: .init(.asyncIOFailed("epoll_create1 failed")), + underlyingError: .init(rawValue: errno) + ) + self.state = .failure(error) + return + } + // Create shutdownFileDescriptor + let shutdownFileDescriptor = eventfd(0, CInt(EFD_NONBLOCK | EFD_CLOEXEC)) + guard shutdownFileDescriptor >= 0 else { + let error = SubprocessError( + code: .init(.asyncIOFailed("eventfd failed")), + underlyingError: .init(rawValue: errno) + ) + self.state = .failure(error) + return + } + + // Register shutdownFileDescriptor with epoll + var event = epoll_event( + events: EPOLLIN.rawValue, + data: epoll_data(fd: shutdownFileDescriptor) + ) + var rc = epoll_ctl( + epollFileDescriptor, + EPOLL_CTL_ADD, + shutdownFileDescriptor, + &event + ) + guard rc == 0 else { + let error = SubprocessError( + code: .init(.asyncIOFailed( + "failed to add shutdown fd \(shutdownFileDescriptor) to epoll list") + ), + underlyingError: .init(rawValue: errno) + ) + self.state = .failure(error) + return + } + + // Create thread data + let context = MonitorThreadContext( + epollFileDescriptor: epollFileDescriptor, + shutdownFileDescriptor: shutdownFileDescriptor + ) + let threadContext = Unmanaged.passRetained(context) +#if os(FreeBSD) || os(OpenBSD) + var thread: pthread_t? = nil +#else + var thread: pthread_t = pthread_t() +#endif + rc = pthread_create(&thread, nil, { args in + func reportError(_ error: SubprocessError) { + _registration.withLock { store in + for continuation in store.values { + continuation.finish(throwing: error) + } + } + } + + let unmanaged = Unmanaged.fromOpaque(args!) + let context = unmanaged.takeRetainedValue() + + var events: [epoll_event] = Array( + repeating: epoll_event(events: 0, data: epoll_data(fd: 0)), + count: _epollEventSize + ) + + // Enter the monitor loop + monitorLoop: while true { + let eventCount = epoll_wait( + context.epollFileDescriptor, + &events, + CInt(events.count), + -1 + ) + if eventCount < 0 { + if errno == EINTR || errno == EAGAIN { + continue // interrupted by signal; try again + } + // Report other errors + let error = SubprocessError( + code: .init(.asyncIOFailed( + "epoll_wait failed") + ), + underlyingError: .init(rawValue: errno) + ) + reportError(error) + break monitorLoop + } + + for index in 0 ..< Int(eventCount) { + let event = events[index] + let targetFileDescriptor = event.data.fd + // Breakout the monitor loop if we received shutdown + // from the shutdownFD + if targetFileDescriptor == context.shutdownFileDescriptor { + var buf: UInt64 = 0 + _ = _SubprocessCShims.read(context.shutdownFileDescriptor, &buf, MemoryLayout.size) + break monitorLoop + } + + // Notify the continuation + _registration.withLock { store in + if let continuation = store[targetFileDescriptor] { + continuation.yield(true) + } + } + } + } + + return nil + }, threadContext.toOpaque()) + guard rc == 0 else { + let error = SubprocessError( + code: .init(.asyncIOFailed("Failed to create monitor thread")), + underlyingError: .init(rawValue: rc) + ) + self.state = .failure(error) + return + } + +#if os(FreeBSD) || os(OpenBSD) + let monitorThread = thread! +#else + let monitorThread = thread +#endif + + let state = State( + epollFileDescriptor: epollFileDescriptor, + shutdownFileDescriptor: shutdownFileDescriptor, + monitorThread: monitorThread + ) + self.state = .success(state) + + atexit { + AsyncIO.shared.shutdown() + } + } + + private func shutdown() { + guard case .success(let currentState) = self.state else { + return + } + + var one: UInt64 = 1 + // Wake up the thread for shutdown + _ = _SubprocessCShims.write(currentState.shutdownFileDescriptor, &one, MemoryLayout.stride) + // Cleanup the monitor thread + pthread_join(currentState.monitorThread, nil) + } + + + private func registerFileDescriptor( + _ fileDescriptor: FileDescriptor, + for event: Event + ) -> SignalStream { + return SignalStream { continuation in + // If setup failed, nothing much we can do + switch self.state { + case .success(let state): + // Set file descriptor to be non blocking + let flags = fcntl(fileDescriptor.rawValue, F_GETFD) + guard flags != -1 else { + let error = SubprocessError( + code: .init(.asyncIOFailed( + "failed to get flags for \(fileDescriptor.rawValue)") + ), + underlyingError: .init(rawValue: errno) + ) + continuation.finish(throwing: error) + return + } + guard fcntl(fileDescriptor.rawValue, F_SETFL, flags | O_NONBLOCK) != -1 else { + let error = SubprocessError( + code: .init(.asyncIOFailed( + "failed to set \(fileDescriptor.rawValue) to be non-blocking") + ), + underlyingError: .init(rawValue: errno) + ) + continuation.finish(throwing: error) + return + } + // Register event + let targetEvent: EPOLL_EVENTS + switch event { + case .read: + targetEvent = EPOLLIN + case .write: + targetEvent = EPOLLOUT + } + + var event = epoll_event( + events: targetEvent.rawValue, + data: epoll_data(fd: fileDescriptor.rawValue) + ) + let rc = epoll_ctl( + state.epollFileDescriptor, + EPOLL_CTL_ADD, + fileDescriptor.rawValue, + &event + ) + if rc != 0 { + let error = SubprocessError( + code: .init(.asyncIOFailed( + "failed to add \(fileDescriptor.rawValue) to epoll list") + ), + underlyingError: .init(rawValue: errno) + ) + continuation.finish(throwing: error) + return + } + // Now save the continuation + _registration.withLock { storage in + storage[fileDescriptor.rawValue] = continuation + } + case .failure(let setupError): + continuation.finish(throwing: setupError) + return + } + } + } + + private func removeRegistration(for fileDescriptor: FileDescriptor) throws { + switch self.state { + case .success(let state): + let rc = epoll_ctl( + state.epollFileDescriptor, + EPOLL_CTL_DEL, + fileDescriptor.rawValue, + nil + ) + guard rc == 0 else { + throw SubprocessError( + code: .init(.asyncIOFailed( + "failed to remove \(fileDescriptor.rawValue) to epoll list") + ), + underlyingError: .init(rawValue: errno) + ) + } + _registration.withLock { store in + _ = store.removeValue(forKey: fileDescriptor.rawValue) + } + case .failure(let setupFailure): + throw setupFailure + } + } +} + +extension AsyncIO { + + protocol _ContiguousBytes { + var count: Int { get } + + func withUnsafeBytes( + _ body: (UnsafeRawBufferPointer) throws -> ResultType + ) rethrows -> ResultType + } + + func read( + from diskIO: borrowing IOChannel, + upTo maxLength: Int + ) async throws -> [UInt8]? { + return try await self.read(from: diskIO.channel, upTo: maxLength) + } + + func read( + from fileDescriptor: FileDescriptor, + upTo maxLength: Int + ) async throws -> [UInt8]? { + // If we are reading until EOF, start with readBufferSize + // and gradually increase buffer size + let bufferLength = maxLength == .max ? readBufferSize : maxLength + + var resultBuffer: [UInt8] = Array( + repeating: 0, count: bufferLength + ) + var readLength: Int = 0 + let signalStream = self.registerFileDescriptor(fileDescriptor, for: .read) + /// Outer loop: every iteration signals we are ready to read more data + for try await _ in signalStream { + /// Inner loop: repeatedly call `.read()` and read more data until: + /// 1. We reached EOF (read length is 0), in which case return the result + /// 2. We read `maxLength` bytes, in which case return the result + /// 3. `read()` returns -1 and sets `errno` to `EAGAIN` or `EWOULDBLOCK`. In + /// this case we `break` out of the inner loop and wait `.read()` to be + /// ready by `await`ing the next signal in the outer loop. + while true { + let bytesRead = resultBuffer.withUnsafeMutableBufferPointer { bufferPointer in + // Get a pointer to the memory at the specified offset + let targetCount = bufferPointer.count - readLength + + let offsetAddress = bufferPointer.baseAddress!.advanced(by: readLength) + + // Read directly into the buffer at the offset + return _SubprocessCShims.read(fileDescriptor.rawValue, offsetAddress, targetCount) + } + if bytesRead > 0 { + // Read some data + readLength += bytesRead + if maxLength == .max { + // Grow resultBuffer if needed + guard Double(readLength) > 0.8 * Double(resultBuffer.count) else { + continue + } + resultBuffer.append( + contentsOf: Array(repeating: 0, count: resultBuffer.count) + ) + } else if readLength >= maxLength { + // When we reached maxLength, return! + try self.removeRegistration(for: fileDescriptor) + return resultBuffer + } + } else if bytesRead == 0 { + // We reached EOF. Return whatever's left + try self.removeRegistration(for: fileDescriptor) + guard readLength > 0 else { + return nil + } + resultBuffer.removeLast(resultBuffer.count - readLength) + return resultBuffer + } else { + if errno == EAGAIN || errno == EWOULDBLOCK { + // No more data for now wait for the next signal + break + } else { + // Throw all other errors + try self.removeRegistration(for: fileDescriptor) + throw SubprocessError.UnderlyingError(rawValue: errno) + } + } + } + } + return resultBuffer + } + + func write( + _ array: [UInt8], + to diskIO: borrowing IOChannel + ) async throws -> Int { + return try await self._write(array, to: diskIO) + } + + func _write( + _ bytes: Bytes, + to diskIO: borrowing IOChannel + ) async throws -> Int { + let fileDescriptor = diskIO.channel + let signalStream = self.registerFileDescriptor(fileDescriptor, for: .write) + var writtenLength: Int = 0 + /// Outer loop: every iteration signals we are ready to read more data + for try await _ in signalStream { + /// Inner loop: repeatedly call `.write()` and write more data until: + /// 1. We've written bytes.count bytes. + /// 3. `.write()` returns -1 and sets `errno` to `EAGAIN` or `EWOULDBLOCK`. In + /// this case we `break` out of the inner loop and wait `.write()` to be + /// ready by `await`ing the next signal in the outer loop. + while true { + let written = bytes.withUnsafeBytes { ptr in + let remainingLength = ptr.count - writtenLength + let startPtr = ptr.baseAddress!.advanced(by: writtenLength) + return _SubprocessCShims.write(fileDescriptor.rawValue, startPtr, remainingLength) + } + if written > 0 { + writtenLength += written + if writtenLength >= bytes.count { + // Wrote all data + try self.removeRegistration(for: fileDescriptor) + return writtenLength + } + } else { + if errno == EAGAIN || errno == EWOULDBLOCK { + // No more data for now wait for the next signal + break + } else { + // Throw all other errors + try self.removeRegistration(for: fileDescriptor) + throw SubprocessError.UnderlyingError(rawValue: errno) + } + } + } + } + return 0 + } + +#if SubprocessSpan + func write( + _ span: borrowing RawSpan, + to diskIO: borrowing IOChannel + ) async throws -> Int { + let fileDescriptor = diskIO.channel + let signalStream = self.registerFileDescriptor(fileDescriptor, for: .write) + var writtenLength: Int = 0 + /// Outer loop: every iteration signals we are ready to read more data + for try await _ in signalStream { + /// Inner loop: repeatedly call `.write()` and write more data until: + /// 1. We've written bytes.count bytes. + /// 3. `.write()` returns -1 and sets `errno` to `EAGAIN` or `EWOULDBLOCK`. In + /// this case we `break` out of the inner loop and wait `.write()` to be + /// ready by `await`ing the next signal in the outer loop. + while true { + let written = span.withUnsafeBytes { ptr in + let remainingLength = ptr.count - writtenLength + let startPtr = ptr.baseAddress!.advanced(by: writtenLength) + return _SubprocessCShims.write(fileDescriptor.rawValue, startPtr, remainingLength) + } + if written > 0 { + writtenLength += written + if writtenLength >= span.byteCount { + // Wrote all data + try self.removeRegistration(for: fileDescriptor) + return writtenLength + } + } else { + if errno == EAGAIN || errno == EWOULDBLOCK { + // No more data for now wait for the next signal + break + } else { + // Throw all other errors + try self.removeRegistration(for: fileDescriptor) + throw SubprocessError.UnderlyingError(rawValue: errno) + } + } + } + } + return 0 + } +#endif +} + +extension Array : AsyncIO._ContiguousBytes where Element == UInt8 {} + +#endif // canImport(Glibc) || canImport(Android) || canImport(Musl) diff --git a/Sources/Subprocess/IO/AsyncIO+Windows.swift b/Sources/Subprocess/IO/AsyncIO+Windows.swift new file mode 100644 index 0000000..bf75e59 --- /dev/null +++ b/Sources/Subprocess/IO/AsyncIO+Windows.swift @@ -0,0 +1,437 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2025 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +/// Windows AsyncIO based on IO Completion Ports and Overlapped + +#if os(Windows) + +#if canImport(System) +@preconcurrency import System +#else +@preconcurrency import SystemPackage +#endif + +import Synchronization +internal import Dispatch +@preconcurrency import WinSDK + +private typealias SignalStream = AsyncThrowingStream +private let shutdownPort: UInt64 = .max +private let _registration: Mutex< + [UInt64 : SignalStream.Continuation] +> = Mutex([:]) + +final class AsyncIO: @unchecked Sendable { + + protocol _ContiguousBytes: Sendable { + var count: Int { get } + + func withUnsafeBytes( + _ body: (UnsafeRawBufferPointer + ) throws -> ResultType) rethrows -> ResultType + } + + private final class MonitorThreadContext { + let ioCompletionPort: HANDLE + + init(ioCompletionPort: HANDLE) { + self.ioCompletionPort = ioCompletionPort + } + } + + static let shared = AsyncIO() + + private let ioCompletionPort: Result + + private let monitorThread: Result + + private init() { + var maybeSetupError: SubprocessError? = nil + // Create the the completion port + guard let port = CreateIoCompletionPort( + INVALID_HANDLE_VALUE, nil, 0, 0 + ), port != INVALID_HANDLE_VALUE else { + let error = SubprocessError( + code: .init(.asyncIOFailed("CreateIoCompletionPort failed")), + underlyingError: .init(rawValue: GetLastError()) + ) + self.ioCompletionPort = .failure(error) + self.monitorThread = .failure(error) + return + } + self.ioCompletionPort = .success(port) + // Create monitor thread + let threadContext = MonitorThreadContext(ioCompletionPort: port) + let threadContextPtr = Unmanaged.passRetained(threadContext) + let threadHandle = CreateThread(nil, 0, { args in + func reportError(_ error: SubprocessError) { + _registration.withLock { store in + for continuation in store.values { + continuation.finish(throwing: error) + } + } + } + + let unmanaged = Unmanaged.fromOpaque(args!) + let context = unmanaged.takeRetainedValue() + + // Monitor loop + while true { + var bytesTransferred: DWORD = 0 + var targetFileDescriptor: UInt64 = 0 + var overlapped: LPOVERLAPPED? = nil + + let monitorResult = GetQueuedCompletionStatus( + context.ioCompletionPort, + &bytesTransferred, + &targetFileDescriptor, + &overlapped, + INFINITE + ) + if !monitorResult { + let lastError = GetLastError() + if lastError == ERROR_BROKEN_PIPE { + // We finished reading the handle. Signal EOF by + // finishing the stream. + // NOTE: here we deliberately leave now unused continuation + // in the store. Windows does not offer an API to remove a + // HANDLE from an IOCP port, therefore we leave the registration + // to signify the HANDLE has already been resisted. + _registration.withLock { store in + if let continuation = store[targetFileDescriptor] { + continuation.finish() + } + } + continue + } else { + let error = SubprocessError( + code: .init(.asyncIOFailed("GetQueuedCompletionStatus failed")), + underlyingError: .init(rawValue: lastError) + ) + reportError(error) + break + } + } + + // Breakout the monitor loop if we received shutdown from the shutdownFD + if targetFileDescriptor == shutdownPort { + break + } + // Notify the continuations + _registration.withLock { store in + if let continuation = store[targetFileDescriptor] { + continuation.yield(bytesTransferred) + } + } + } + return 0 + }, threadContextPtr.toOpaque(), 0, nil) + guard let threadHandle = threadHandle else { + let error = SubprocessError( + code: .init(.asyncIOFailed("CreateThread failed")), + underlyingError: .init(rawValue: GetLastError()) + ) + self.monitorThread = .failure(error) + return + } + self.monitorThread = .success(threadHandle) + + atexit { + AsyncIO.shared.shutdown() + } + } + + private func shutdown() { + // Post status to shutdown HANDLE + guard case .success(let ioPort) = ioCompletionPort, + case .success(let monitorThreadHandle) = monitorThread else { + return + } + PostQueuedCompletionStatus( + ioPort, + 0, + shutdownPort, + nil + ) + // Wait for monitor thread to exit + WaitForSingleObject(monitorThreadHandle, INFINITE) + CloseHandle(ioPort) + CloseHandle(monitorThreadHandle) + } + + private func registerHandle(_ handle: HANDLE) -> SignalStream { + return SignalStream { continuation in + switch self.ioCompletionPort { + case .success(let ioPort): + // Make sure thread setup also succeed + if case .failure(let error) = monitorThread { + continuation.finish(throwing: error) + return + } + let completionKey = UInt64(UInt(bitPattern: handle)) + // Windows does not offer an API to remove a handle + // from given ioCompletionPort. If this handle has already + // been registered we simply need to update the continuation + let registrationFound = _registration.withLock { storage in + if storage[completionKey] != nil { + // Old registration found. This means this handle has + // already been registered. We simply need to update + // the continuation saved + storage[completionKey] = continuation + return true + } else { + return false + } + } + if registrationFound { + return + } + + // Windows Documentation: The function returns the handle + // of the existing I/O completion port if successful + guard CreateIoCompletionPort( + handle, ioPort, completionKey, 0 + ) == ioPort else { + let error = SubprocessError( + code: .init(.asyncIOFailed("CreateIoCompletionPort failed")), + underlyingError: .init(rawValue: GetLastError()) + ) + continuation.finish(throwing: error) + return + } + // Now save the continuation + _registration.withLock { storage in + storage[completionKey] = continuation + } + case .failure(let error): + continuation.finish(throwing: error) + } + } + } + + internal func removeRegistration(for handle: HANDLE) { + let completionKey = UInt64(UInt(bitPattern: handle)) + _registration.withLock { storage in + storage.removeValue(forKey: completionKey) + } + } + + func read( + from diskIO: borrowing IOChannel, + upTo maxLength: Int + ) async throws -> [UInt8]? { + return try await self.read(from: diskIO.channel, upTo: maxLength) + } + + func read( + from handle: HANDLE, + upTo maxLength: Int + ) async throws -> [UInt8]? { + // If we are reading until EOF, start with readBufferSize + // and gradually increase buffer size + let bufferLength = maxLength == .max ? readBufferSize : maxLength + + var resultBuffer: [UInt8] = Array( + repeating: 0, count: bufferLength + ) + var readLength: Int = 0 + var signalStream = self.registerHandle(handle).makeAsyncIterator() + + while true { + var overlapped = _OVERLAPPED() + let succeed = try resultBuffer.withUnsafeMutableBufferPointer { bufferPointer in + // Get a pointer to the memory at the specified offset + // Windows ReadFile uses DWORD for target count, which means we can only + // read up to DWORD (aka UInt32) max. + let targetCount: DWORD + if MemoryLayout.size == MemoryLayout.size { + // On 32 bit systems we don't have to worry about overflowing + targetCount = DWORD(truncatingIfNeeded: bufferPointer.count - readLength) + } else { + // On 64 bit systems we need to cap the count at DWORD max + targetCount = DWORD(truncatingIfNeeded: min(bufferPointer.count - readLength, Int(UInt32.max))) + } + + let offsetAddress = bufferPointer.baseAddress!.advanced(by: readLength) + // Read directly into the buffer at the offset + return ReadFile( + handle, + offsetAddress, + DWORD(truncatingIfNeeded: targetCount), + nil, + &overlapped + ) + } + + if !succeed { + // It is expected `ReadFile` to return `false` in async mode. + // Make sure we only get `ERROR_IO_PENDING` or `ERROR_BROKEN_PIPE` + let lastError = GetLastError() + if lastError == ERROR_BROKEN_PIPE { + // We reached EOF + return nil + } + guard lastError == ERROR_IO_PENDING else { + let error = SubprocessError( + code: .init(.failedToReadFromSubprocess), + underlyingError: .init(rawValue: lastError) + ) + throw error + } + + } + // Now wait for read to finish + let bytesRead = try await signalStream.next() ?? 0 + + if bytesRead == 0 { + // We reached EOF. Return whatever's left + guard readLength > 0 else { + return nil + } + resultBuffer.removeLast(resultBuffer.count - readLength) + return resultBuffer + } else { + // Read some data + readLength += Int(bytesRead) + if maxLength == .max { + // Grow resultBuffer if needed + guard Double(readLength) > 0.8 * Double(resultBuffer.count) else { + continue + } + resultBuffer.append( + contentsOf: Array(repeating: 0, count: resultBuffer.count) + ) + } else if readLength >= maxLength { + // When we reached maxLength, return! + return resultBuffer + } + } + } + } + + func write( + _ array: [UInt8], + to diskIO: borrowing IOChannel + ) async throws -> Int { + return try await self._write(array, to: diskIO) + } + +#if SubprocessSpan + func write( + _ span: borrowing RawSpan, + to diskIO: borrowing IOChannel + ) async throws -> Int { + let handle = diskIO.channel + var signalStream = self.registerHandle(diskIO.channel).makeAsyncIterator() + var writtenLength: Int = 0 + while true { + var overlapped = _OVERLAPPED() + let succeed = try span.withUnsafeBytes { ptr in + // Windows WriteFile uses DWORD for target count + // which means we can only write up to DWORD max + let remainingLength: DWORD + if MemoryLayout.size == MemoryLayout.size { + // On 32 bit systems we don't have to worry about overflowing + remainingLength = DWORD(truncatingIfNeeded: ptr.count - writtenLength) + } else { + // On 64 bit systems we need to cap the count at DWORD max + remainingLength = DWORD(truncatingIfNeeded: min(ptr.count - writtenLength, Int(DWORD.max))) + } + + let startPtr = ptr.baseAddress!.advanced(by: writtenLength) + return WriteFile( + handle, + startPtr, + DWORD(truncatingIfNeeded: remainingLength), + nil, + &overlapped + ) + } + if !succeed { + // It is expected `WriteFile` to return `false` in async mode. + // Make sure we only get `ERROR_IO_PENDING` + let lastError = GetLastError() + guard lastError == ERROR_IO_PENDING else { + let error = SubprocessError( + code: .init(.failedToWriteToSubprocess), + underlyingError: .init(rawValue: lastError) + ) + throw error + } + + } + // Now wait for read to finish + let bytesWritten: DWORD = try await signalStream.next() ?? 0 + + writtenLength += Int(bytesWritten) + if writtenLength >= span.byteCount { + return writtenLength + } + } + } +#endif // SubprocessSpan + + func _write( + _ bytes: Bytes, + to diskIO: borrowing IOChannel + ) async throws -> Int { + let handle = diskIO.channel + var signalStream = self.registerHandle(diskIO.channel).makeAsyncIterator() + var writtenLength: Int = 0 + while true { + var overlapped = _OVERLAPPED() + let succeed = try bytes.withUnsafeBytes { ptr in + // Windows WriteFile uses DWORD for target count + // which means we can only write up to DWORD max + let remainingLength: DWORD + if MemoryLayout.size == MemoryLayout.size { + // On 32 bit systems we don't have to worry about overflowing + remainingLength = DWORD(truncatingIfNeeded: ptr.count - writtenLength) + } else { + // On 64 bit systems we need to cap the count at DWORD max + remainingLength = DWORD(truncatingIfNeeded: min(ptr.count - writtenLength, Int(DWORD.max))) + } + let startPtr = ptr.baseAddress!.advanced(by: writtenLength) + return WriteFile( + handle, + startPtr, + DWORD(truncatingIfNeeded: remainingLength), + nil, + &overlapped + ) + } + + if !succeed { + // It is expected `WriteFile` to return `false` in async mode. + // Make sure we only get `ERROR_IO_PENDING` + let lastError = GetLastError() + guard lastError == ERROR_IO_PENDING else { + let error = SubprocessError( + code: .init(.failedToWriteToSubprocess), + underlyingError: .init(rawValue: lastError) + ) + throw error + } + } + // Now wait for read to finish + let bytesWritten: DWORD = try await signalStream.next() ?? 0 + writtenLength += Int(bytesWritten) + if writtenLength >= bytes.count { + return writtenLength + } + } + } +} + +extension Array : AsyncIO._ContiguousBytes where Element == UInt8 {} + +#endif + diff --git a/Sources/Subprocess/IO/AsyncIO.swift b/Sources/Subprocess/IO/AsyncIO.swift deleted file mode 100644 index 35894a5..0000000 --- a/Sources/Subprocess/IO/AsyncIO.swift +++ /dev/null @@ -1,1076 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Swift.org open source project -// -// Copyright (c) 2025 Apple Inc. and the Swift project authors -// Licensed under Apache License v2.0 with Runtime Library Exception -// -// See https://swift.org/LICENSE.txt for license information -// -//===----------------------------------------------------------------------===// - -#if canImport(System) -@preconcurrency import System -#else -@preconcurrency import SystemPackage -#endif - -/// Platform specific asynchronous read/write implementation - -// MARK: - Linux (epoll) -#if canImport(Glibc) || canImport(Android) || canImport(Musl) - -#if canImport(Glibc) -import Glibc -#elseif canImport(Android) -import Android -#elseif canImport(Musl) -import Musl -#endif - -import _SubprocessCShims -import Synchronization - -private typealias SignalStream = AsyncThrowingStream -private let _epollEventSize = 256 -private let _registration: Mutex< - [PlatformFileDescriptor : SignalStream.Continuation] -> = Mutex([:]) - -final class AsyncIO: Sendable { - - typealias OutputStream = AsyncThrowingStream - - private final class MonitorThreadContext { - let epollFileDescriptor: CInt - let shutdownFileDescriptor: CInt - - init( - epollFileDescriptor: CInt, - shutdownFileDescriptor: CInt - ) { - self.epollFileDescriptor = epollFileDescriptor - self.shutdownFileDescriptor = shutdownFileDescriptor - } - } - - private enum Event { - case read - case write - } - - private struct State { - let epollFileDescriptor: CInt - let shutdownFileDescriptor: CInt - let monitorThread: pthread_t - } - - static let shared: AsyncIO = AsyncIO() - - private let state: Result - - private init() { - // Create main epoll fd - let epollFileDescriptor = epoll_create1(CInt(EPOLL_CLOEXEC)) - guard epollFileDescriptor >= 0 else { - let error = SubprocessError( - code: .init(.asyncIOFailed("epoll_create1 failed")), - underlyingError: .init(rawValue: errno) - ) - self.state = .failure(error) - return - } - // Create shutdownFileDescriptor - let shutdownFileDescriptor = eventfd(0, CInt(EFD_NONBLOCK | EFD_CLOEXEC)) - guard shutdownFileDescriptor >= 0 else { - let error = SubprocessError( - code: .init(.asyncIOFailed("eventfd failed")), - underlyingError: .init(rawValue: errno) - ) - self.state = .failure(error) - return - } - - // Register shutdownFileDescriptor with epoll - var event = epoll_event( - events: EPOLLIN.rawValue, - data: epoll_data(fd: shutdownFileDescriptor) - ) - var rc = epoll_ctl( - epollFileDescriptor, - EPOLL_CTL_ADD, - shutdownFileDescriptor, - &event - ) - guard rc == 0 else { - let error = SubprocessError( - code: .init(.asyncIOFailed( - "failed to add shutdown fd \(shutdownFileDescriptor) to epoll list") - ), - underlyingError: .init(rawValue: errno) - ) - self.state = .failure(error) - return - } - - // Create thread data - let context = MonitorThreadContext( - epollFileDescriptor: epollFileDescriptor, - shutdownFileDescriptor: shutdownFileDescriptor - ) - let threadContext = Unmanaged.passRetained(context) - #if os(FreeBSD) || os(OpenBSD) - var thread: pthread_t? = nil - #else - var thread: pthread_t = pthread_t() - #endif - rc = pthread_create(&thread, nil, { args in - func reportError(_ error: SubprocessError) { - _registration.withLock { store in - for continuation in store.values { - continuation.finish(throwing: error) - } - } - } - - let unmanaged = Unmanaged.fromOpaque(args!) - let context = unmanaged.takeRetainedValue() - - var events: [epoll_event] = Array( - repeating: epoll_event(events: 0, data: epoll_data(fd: 0)), - count: _epollEventSize - ) - - // Enter the monitor loop - monitorLoop: while true { - let eventCount = epoll_wait( - context.epollFileDescriptor, - &events, - CInt(events.count), - -1 - ) - if eventCount < 0 { - if errno == EINTR || errno == EAGAIN { - continue // interrupted by signal; try again - } - // Report other errors - let error = SubprocessError( - code: .init(.asyncIOFailed( - "epoll_wait failed") - ), - underlyingError: .init(rawValue: errno) - ) - reportError(error) - break monitorLoop - } - - for index in 0 ..< Int(eventCount) { - let event = events[index] - let targetFileDescriptor = event.data.fd - // Breakout the monitor loop if we received shutdown - // from the shutdownFD - if targetFileDescriptor == context.shutdownFileDescriptor { - var buf: UInt64 = 0 - _ = _SubprocessCShims.read(context.shutdownFileDescriptor, &buf, MemoryLayout.size) - break monitorLoop - } - - // Notify the continuation - _registration.withLock { store in - if let continuation = store[targetFileDescriptor] { - continuation.yield(true) - } - } - } - } - - return nil - }, threadContext.toOpaque()) - guard rc == 0 else { - let error = SubprocessError( - code: .init(.asyncIOFailed("Failed to create monitor thread")), - underlyingError: .init(rawValue: rc) - ) - self.state = .failure(error) - return - } - - #if os(FreeBSD) || os(OpenBSD) - let monitorThread = thread! - #else - let monitorThread = thread - #endif - - let state = State( - epollFileDescriptor: epollFileDescriptor, - shutdownFileDescriptor: shutdownFileDescriptor, - monitorThread: monitorThread - ) - self.state = .success(state) - - atexit { - AsyncIO.shared.shutdown() - } - } - - private func shutdown() { - guard case .success(let currentState) = self.state else { - return - } - - var one: UInt64 = 1 - // Wake up the thread for shutdown - _ = _SubprocessCShims.write(currentState.shutdownFileDescriptor, &one, MemoryLayout.stride) - // Cleanup the monitor thread - pthread_join(currentState.monitorThread, nil) - } - - - private func registerFileDescriptor( - _ fileDescriptor: FileDescriptor, - for event: Event - ) -> SignalStream { - return SignalStream { continuation in - // If setup failed, nothing much we can do - switch self.state { - case .success(let state): - // Set file descriptor to be non blocking - let flags = fcntl(fileDescriptor.rawValue, F_GETFD) - guard flags != -1 else { - let error = SubprocessError( - code: .init(.asyncIOFailed( - "failed to get flags for \(fileDescriptor.rawValue)") - ), - underlyingError: .init(rawValue: errno) - ) - continuation.finish(throwing: error) - return - } - guard fcntl(fileDescriptor.rawValue, F_SETFL, flags | O_NONBLOCK) != -1 else { - let error = SubprocessError( - code: .init(.asyncIOFailed( - "failed to set \(fileDescriptor.rawValue) to be non-blocking") - ), - underlyingError: .init(rawValue: errno) - ) - continuation.finish(throwing: error) - return - } - // Register event - let targetEvent: EPOLL_EVENTS - switch event { - case .read: - targetEvent = EPOLLIN - case .write: - targetEvent = EPOLLOUT - } - - var event = epoll_event( - events: targetEvent.rawValue, - data: epoll_data(fd: fileDescriptor.rawValue) - ) - let rc = epoll_ctl( - state.epollFileDescriptor, - EPOLL_CTL_ADD, - fileDescriptor.rawValue, - &event - ) - if rc != 0 { - let error = SubprocessError( - code: .init(.asyncIOFailed( - "failed to add \(fileDescriptor.rawValue) to epoll list") - ), - underlyingError: .init(rawValue: errno) - ) - continuation.finish(throwing: error) - return - } - // Now save the continuation - _registration.withLock { storage in - storage[fileDescriptor.rawValue] = continuation - } - case .failure(let setupError): - continuation.finish(throwing: setupError) - return - } - } - } - - private func removeRegistration(for fileDescriptor: FileDescriptor) throws { - switch self.state { - case .success(let state): - let rc = epoll_ctl( - state.epollFileDescriptor, - EPOLL_CTL_DEL, - fileDescriptor.rawValue, - nil - ) - guard rc == 0 else { - throw SubprocessError( - code: .init(.asyncIOFailed( - "failed to remove \(fileDescriptor.rawValue) to epoll list") - ), - underlyingError: .init(rawValue: errno) - ) - } - _registration.withLock { store in - _ = store.removeValue(forKey: fileDescriptor.rawValue) - } - case .failure(let setupFailure): - throw setupFailure - } - } -} - -extension AsyncIO { - - protocol _ContiguousBytes { - var count: Int { get } - - func withUnsafeBytes( - _ body: (UnsafeRawBufferPointer) throws -> ResultType - ) rethrows -> ResultType - } - - func read( - from diskIO: borrowing IOChannel, - upTo maxLength: Int - ) async throws -> [UInt8]? { - return try await self.read(from: diskIO.channel, upTo: maxLength) - } - - func read( - from fileDescriptor: FileDescriptor, - upTo maxLength: Int - ) async throws -> [UInt8]? { - // If we are reading until EOF, start with readBufferSize - // and gradually increase buffer size - let bufferLength = maxLength == .max ? readBufferSize : maxLength - - var resultBuffer: [UInt8] = Array( - repeating: 0, count: bufferLength - ) - var readLength: Int = 0 - let signalStream = self.registerFileDescriptor(fileDescriptor, for: .read) - /// Outer loop: every iteration signals we are ready to read more data - for try await _ in signalStream { - /// Inner loop: repeatedly call `.read()` and read more data until: - /// 1. We reached EOF (read length is 0), in which case return the result - /// 2. We read `maxLength` bytes, in which case return the result - /// 3. `read()` returns -1 and sets `errno` to `EAGAIN` or `EWOULDBLOCK`. In - /// this case we `break` out of the inner loop and wait `.read()` to be - /// ready by `await`ing the next signal in the outer loop. - while true { - let bytesRead = resultBuffer.withUnsafeMutableBufferPointer { bufferPointer in - // Get a pointer to the memory at the specified offset - let targetCount = bufferPointer.count - readLength - - let offsetAddress = bufferPointer.baseAddress!.advanced(by: readLength) - - // Read directly into the buffer at the offset - return _SubprocessCShims.read(fileDescriptor.rawValue, offsetAddress, targetCount) - } - if bytesRead > 0 { - // Read some data - readLength += bytesRead - if maxLength == .max { - // Grow resultBuffer if needed - guard Double(readLength) > 0.8 * Double(resultBuffer.count) else { - continue - } - resultBuffer.append( - contentsOf: Array(repeating: 0, count: resultBuffer.count) - ) - } else if readLength >= maxLength { - // When we reached maxLength, return! - try self.removeRegistration(for: fileDescriptor) - return resultBuffer - } - } else if bytesRead == 0 { - // We reached EOF. Return whatever's left - try self.removeRegistration(for: fileDescriptor) - guard readLength > 0 else { - return nil - } - resultBuffer.removeLast(resultBuffer.count - readLength) - return resultBuffer - } else { - if errno == EAGAIN || errno == EWOULDBLOCK { - // No more data for now wait for the next signal - break - } else { - // Throw all other errors - try self.removeRegistration(for: fileDescriptor) - throw SubprocessError.UnderlyingError(rawValue: errno) - } - } - } - } - return resultBuffer - } - - func write( - _ array: [UInt8], - to diskIO: borrowing IOChannel - ) async throws -> Int { - return try await self._write(array, to: diskIO) - } - - func _write( - _ bytes: Bytes, - to diskIO: borrowing IOChannel - ) async throws -> Int { - let fileDescriptor = diskIO.channel - let signalStream = self.registerFileDescriptor(fileDescriptor, for: .write) - var writtenLength: Int = 0 - /// Outer loop: every iteration signals we are ready to read more data - for try await _ in signalStream { - /// Inner loop: repeatedly call `.write()` and write more data until: - /// 1. We've written bytes.count bytes. - /// 3. `.write()` returns -1 and sets `errno` to `EAGAIN` or `EWOULDBLOCK`. In - /// this case we `break` out of the inner loop and wait `.write()` to be - /// ready by `await`ing the next signal in the outer loop. - while true { - let written = bytes.withUnsafeBytes { ptr in - let remainingLength = ptr.count - writtenLength - let startPtr = ptr.baseAddress!.advanced(by: writtenLength) - return _SubprocessCShims.write(fileDescriptor.rawValue, startPtr, remainingLength) - } - if written > 0 { - writtenLength += written - if writtenLength >= bytes.count { - // Wrote all data - try self.removeRegistration(for: fileDescriptor) - return writtenLength - } - } else { - if errno == EAGAIN || errno == EWOULDBLOCK { - // No more data for now wait for the next signal - break - } else { - // Throw all other errors - try self.removeRegistration(for: fileDescriptor) - throw SubprocessError.UnderlyingError(rawValue: errno) - } - } - } - } - return 0 - } - - #if SubprocessSpan - func write( - _ span: borrowing RawSpan, - to diskIO: borrowing IOChannel - ) async throws -> Int { - let fileDescriptor = diskIO.channel - let signalStream = self.registerFileDescriptor(fileDescriptor, for: .write) - var writtenLength: Int = 0 - /// Outer loop: every iteration signals we are ready to read more data - for try await _ in signalStream { - /// Inner loop: repeatedly call `.write()` and write more data until: - /// 1. We've written bytes.count bytes. - /// 3. `.write()` returns -1 and sets `errno` to `EAGAIN` or `EWOULDBLOCK`. In - /// this case we `break` out of the inner loop and wait `.write()` to be - /// ready by `await`ing the next signal in the outer loop. - while true { - let written = span.withUnsafeBytes { ptr in - let remainingLength = ptr.count - writtenLength - let startPtr = ptr.baseAddress!.advanced(by: writtenLength) - return _SubprocessCShims.write(fileDescriptor.rawValue, startPtr, remainingLength) - } - if written > 0 { - writtenLength += written - if writtenLength >= span.byteCount { - // Wrote all data - try self.removeRegistration(for: fileDescriptor) - return writtenLength - } - } else { - if errno == EAGAIN || errno == EWOULDBLOCK { - // No more data for now wait for the next signal - break - } else { - // Throw all other errors - try self.removeRegistration(for: fileDescriptor) - throw SubprocessError.UnderlyingError(rawValue: errno) - } - } - } - } - return 0 - } - #endif -} - -extension Array : AsyncIO._ContiguousBytes where Element == UInt8 {} - -#endif // canImport(Glibc) || canImport(Android) || canImport(Musl) - -// MARK: - macOS (DispatchIO) -#if canImport(Darwin) - -internal import Dispatch - - -final class AsyncIO: Sendable { - static let shared: AsyncIO = AsyncIO() - - private init() {} - - internal func read( - from diskIO: borrowing IOChannel, - upTo maxLength: Int - ) async throws -> DispatchData? { - return try await self.read( - from: diskIO.channel, - upTo: maxLength, - ) - } - - internal func read( - from dispatchIO: DispatchIO, - upTo maxLength: Int - ) async throws -> DispatchData? { - return try await withCheckedThrowingContinuation { continuation in - var buffer: DispatchData = .empty - dispatchIO.read( - offset: 0, - length: maxLength, - queue: .global() - ) { done, data, error in - if error != 0 { - continuation.resume( - throwing: SubprocessError( - code: .init(.failedToReadFromSubprocess), - underlyingError: .init(rawValue: error) - ) - ) - return - } - if let data = data { - if buffer.isEmpty { - buffer = data - } else { - buffer.append(data) - } - } - if done { - if !buffer.isEmpty { - continuation.resume(returning: buffer) - } else { - continuation.resume(returning: nil) - } - } - } - } - } - - #if SubprocessSpan - internal func write( - _ span: borrowing RawSpan, - to diskIO: borrowing IOChannel - ) async throws -> Int { - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - let dispatchData = span.withUnsafeBytes { - return DispatchData( - bytesNoCopy: $0, - deallocator: .custom( - nil, - { - // noop - } - ) - ) - } - self.write(dispatchData, to: diskIO) { writtenLength, error in - if let error = error { - continuation.resume(throwing: error) - } else { - continuation.resume(returning: writtenLength) - } - } - } - } - #endif // SubprocessSpan - - internal func write( - _ array: [UInt8], - to diskIO: borrowing IOChannel - ) async throws -> Int { - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - let dispatchData = array.withUnsafeBytes { - return DispatchData( - bytesNoCopy: $0, - deallocator: .custom( - nil, - { - // noop - } - ) - ) - } - self.write(dispatchData, to: diskIO) { writtenLength, error in - if let error = error { - continuation.resume(throwing: error) - } else { - continuation.resume(returning: writtenLength) - } - } - } - } - - internal func write( - _ dispatchData: DispatchData, - to diskIO: borrowing IOChannel, - queue: DispatchQueue = .global(), - completion: @escaping (Int, Error?) -> Void - ) { - diskIO.channel.write( - offset: 0, - data: dispatchData, - queue: queue - ) { done, unwritten, error in - guard done else { - // Wait until we are done writing or encountered some error - return - } - - let unwrittenLength = unwritten?.count ?? 0 - let writtenLength = dispatchData.count - unwrittenLength - guard error != 0 else { - completion(writtenLength, nil) - return - } - completion( - writtenLength, - SubprocessError( - code: .init(.failedToWriteToSubprocess), - underlyingError: .init(rawValue: error) - ) - ) - } - } -} - -#endif - -// MARK: - Windows (I/O Completion Ports) -#if os(Windows) - -import Synchronization -internal import Dispatch -@preconcurrency import WinSDK - -private typealias SignalStream = AsyncThrowingStream -private let shutdownPort: UInt64 = .max -private let _registration: Mutex< - [UInt64 : SignalStream.Continuation] -> = Mutex([:]) - -final class AsyncIO: @unchecked Sendable { - - protocol _ContiguousBytes: Sendable { - var count: Int { get } - - func withUnsafeBytes( - _ body: (UnsafeRawBufferPointer - ) throws -> ResultType) rethrows -> ResultType - } - - private final class MonitorThreadContext { - let ioCompletionPort: HANDLE - - init(ioCompletionPort: HANDLE) { - self.ioCompletionPort = ioCompletionPort - } - } - - static let shared = AsyncIO() - - private let ioCompletionPort: Result - - private let monitorThread: Result - - private init() { - var maybeSetupError: SubprocessError? = nil - // Create the the completion port - guard let port = CreateIoCompletionPort( - INVALID_HANDLE_VALUE, nil, 0, 0 - ), port != INVALID_HANDLE_VALUE else { - let error = SubprocessError( - code: .init(.asyncIOFailed("CreateIoCompletionPort failed")), - underlyingError: .init(rawValue: GetLastError()) - ) - self.ioCompletionPort = .failure(error) - self.monitorThread = .failure(error) - return - } - self.ioCompletionPort = .success(port) - // Create monitor thread - let threadContext = MonitorThreadContext(ioCompletionPort: port) - let threadContextPtr = Unmanaged.passRetained(threadContext) - let threadHandle = CreateThread(nil, 0, { args in - func reportError(_ error: SubprocessError) { - _registration.withLock { store in - for continuation in store.values { - continuation.finish(throwing: error) - } - } - } - - let unmanaged = Unmanaged.fromOpaque(args!) - let context = unmanaged.takeRetainedValue() - - // Monitor loop - while true { - var bytesTransferred: DWORD = 0 - var targetFileDescriptor: UInt64 = 0 - var overlapped: LPOVERLAPPED? = nil - - let monitorResult = GetQueuedCompletionStatus( - context.ioCompletionPort, - &bytesTransferred, - &targetFileDescriptor, - &overlapped, - INFINITE - ) - if !monitorResult { - let lastError = GetLastError() - if lastError == ERROR_BROKEN_PIPE { - // We finished reading the handle. Signal EOF by - // finishing the stream. - // NOTE: here we deliberately leave now unused continuation - // in the store. Windows does not offer an API to remove a - // HANDLE from an IOCP port, therefore we leave the registration - // to signify the HANDLE has already been resisted. - _registration.withLock { store in - if let continuation = store[targetFileDescriptor] { - continuation.finish() - } - } - continue - } else { - let error = SubprocessError( - code: .init(.asyncIOFailed("GetQueuedCompletionStatus failed")), - underlyingError: .init(rawValue: lastError) - ) - reportError(error) - break - } - } - - // Breakout the monitor loop if we received shutdown from the shutdownFD - if targetFileDescriptor == shutdownPort { - break - } - // Notify the continuations - _registration.withLock { store in - if let continuation = store[targetFileDescriptor] { - continuation.yield(bytesTransferred) - } - } - } - return 0 - }, threadContextPtr.toOpaque(), 0, nil) - guard let threadHandle = threadHandle else { - let error = SubprocessError( - code: .init(.asyncIOFailed("CreateThread failed")), - underlyingError: .init(rawValue: GetLastError()) - ) - self.monitorThread = .failure(error) - return - } - self.monitorThread = .success(threadHandle) - - atexit { - AsyncIO.shared.shutdown() - } - } - - private func shutdown() { - // Post status to shutdown HANDLE - guard case .success(let ioPort) = ioCompletionPort, - case .success(let monitorThreadHandle) = monitorThread else { - return - } - PostQueuedCompletionStatus( - ioPort, - 0, - shutdownPort, - nil - ) - // Wait for monitor thread to exit - WaitForSingleObject(monitorThreadHandle, INFINITE) - CloseHandle(ioPort) - CloseHandle(monitorThreadHandle) - } - - private func registerHandle(_ handle: HANDLE) -> SignalStream { - return SignalStream { continuation in - switch self.ioCompletionPort { - case .success(let ioPort): - // Make sure thread setup also succeed - if case .failure(let error) = monitorThread { - continuation.finish(throwing: error) - return - } - let completionKey = UInt64(UInt(bitPattern: handle)) - // Windows does not offer an API to remove a handle - // from given ioCompletionPort. If this handle has already - // been registered we simply need to update the continuation - let registrationFound = _registration.withLock { storage in - if storage[completionKey] != nil { - // Old registration found. This means this handle has - // already been registered. We simply need to update - // the continuation saved - storage[completionKey] = continuation - return true - } else { - return false - } - } - if registrationFound { - return - } - - // Windows Documentation: The function returns the handle - // of the existing I/O completion port if successful - guard CreateIoCompletionPort( - handle, ioPort, completionKey, 0 - ) == ioPort else { - let error = SubprocessError( - code: .init(.asyncIOFailed("CreateIoCompletionPort failed")), - underlyingError: .init(rawValue: GetLastError()) - ) - continuation.finish(throwing: error) - return - } - // Now save the continuation - _registration.withLock { storage in - storage[completionKey] = continuation - } - case .failure(let error): - continuation.finish(throwing: error) - } - } - } - - internal func removeRegistration(for handle: HANDLE) { - let completionKey = UInt64(UInt(bitPattern: handle)) - _registration.withLock { storage in - storage.removeValue(forKey: completionKey) - } - } - - func read( - from diskIO: borrowing IOChannel, - upTo maxLength: Int - ) async throws -> [UInt8]? { - return try await self.read(from: diskIO.channel, upTo: maxLength) - } - - func read( - from handle: HANDLE, - upTo maxLength: Int - ) async throws -> [UInt8]? { - // If we are reading until EOF, start with readBufferSize - // and gradually increase buffer size - let bufferLength = maxLength == .max ? readBufferSize : maxLength - - var resultBuffer: [UInt8] = Array( - repeating: 0, count: bufferLength - ) - var readLength: Int = 0 - var signalStream = self.registerHandle(handle).makeAsyncIterator() - - while true { - var overlapped = _OVERLAPPED() - let succeed = try resultBuffer.withUnsafeMutableBufferPointer { bufferPointer in - // Get a pointer to the memory at the specified offset - // Windows ReadFile uses DWORD for target count, which means we can only - // read up to DWORD (aka UInt32) max. - let targetCount: DWORD - if MemoryLayout.size == MemoryLayout.size { - // On 32 bit systems we don't have to worry about overflowing - targetCount = DWORD(truncatingIfNeeded: bufferPointer.count - readLength) - } else { - // On 64 bit systems we need to cap the count at DWORD max - targetCount = DWORD(truncatingIfNeeded: min(bufferPointer.count - readLength, Int(UInt32.max))) - } - - let offsetAddress = bufferPointer.baseAddress!.advanced(by: readLength) - // Read directly into the buffer at the offset - return ReadFile( - handle, - offsetAddress, - DWORD(truncatingIfNeeded: targetCount), - nil, - &overlapped - ) - } - - if !succeed { - // It is expected `ReadFile` to return `false` in async mode. - // Make sure we only get `ERROR_IO_PENDING` or `ERROR_BROKEN_PIPE` - let lastError = GetLastError() - if lastError == ERROR_BROKEN_PIPE { - // We reached EOF - return nil - } - guard lastError == ERROR_IO_PENDING else { - let error = SubprocessError( - code: .init(.failedToReadFromSubprocess), - underlyingError: .init(rawValue: lastError) - ) - throw error - } - - } - // Now wait for read to finish - let bytesRead = try await signalStream.next() ?? 0 - - if bytesRead == 0 { - // We reached EOF. Return whatever's left - guard readLength > 0 else { - return nil - } - resultBuffer.removeLast(resultBuffer.count - readLength) - return resultBuffer - } else { - // Read some data - readLength += Int(bytesRead) - if maxLength == .max { - // Grow resultBuffer if needed - guard Double(readLength) > 0.8 * Double(resultBuffer.count) else { - continue - } - resultBuffer.append( - contentsOf: Array(repeating: 0, count: resultBuffer.count) - ) - } else if readLength >= maxLength { - // When we reached maxLength, return! - return resultBuffer - } - } - } - } - - func write( - _ array: [UInt8], - to diskIO: borrowing IOChannel - ) async throws -> Int { - return try await self._write(array, to: diskIO) - } - - #if SubprocessSpan - func write( - _ span: borrowing RawSpan, - to diskIO: borrowing IOChannel - ) async throws -> Int { - let handle = diskIO.channel - var signalStream = self.registerHandle(diskIO.channel).makeAsyncIterator() - var writtenLength: Int = 0 - while true { - var overlapped = _OVERLAPPED() - let succeed = try span.withUnsafeBytes { ptr in - // Windows WriteFile uses DWORD for target count - // which means we can only write up to DWORD max - let remainingLength: DWORD - if MemoryLayout.size == MemoryLayout.size { - // On 32 bit systems we don't have to worry about overflowing - remainingLength = DWORD(truncatingIfNeeded: ptr.count - writtenLength) - } else { - // On 64 bit systems we need to cap the count at DWORD max - remainingLength = DWORD(truncatingIfNeeded: min(ptr.count - writtenLength, Int(DWORD.max))) - } - - let startPtr = ptr.baseAddress!.advanced(by: writtenLength) - return WriteFile( - handle, - startPtr, - DWORD(truncatingIfNeeded: remainingLength), - nil, - &overlapped - ) - } - if !succeed { - // It is expected `WriteFile` to return `false` in async mode. - // Make sure we only get `ERROR_IO_PENDING` - let lastError = GetLastError() - guard lastError == ERROR_IO_PENDING else { - let error = SubprocessError( - code: .init(.failedToWriteToSubprocess), - underlyingError: .init(rawValue: lastError) - ) - throw error - } - - } - // Now wait for read to finish - let bytesWritten: DWORD = try await signalStream.next() ?? 0 - - writtenLength += Int(bytesWritten) - if writtenLength >= span.byteCount { - return writtenLength - } - } - } - #endif // SubprocessSpan - - func _write( - _ bytes: Bytes, - to diskIO: borrowing IOChannel - ) async throws -> Int { - let handle = diskIO.channel - var signalStream = self.registerHandle(diskIO.channel).makeAsyncIterator() - var writtenLength: Int = 0 - while true { - var overlapped = _OVERLAPPED() - let succeed = try bytes.withUnsafeBytes { ptr in - // Windows WriteFile uses DWORD for target count - // which means we can only write up to DWORD max - let remainingLength: DWORD - if MemoryLayout.size == MemoryLayout.size { - // On 32 bit systems we don't have to worry about overflowing - remainingLength = DWORD(truncatingIfNeeded: ptr.count - writtenLength) - } else { - // On 64 bit systems we need to cap the count at DWORD max - remainingLength = DWORD(truncatingIfNeeded: min(ptr.count - writtenLength, Int(DWORD.max))) - } - let startPtr = ptr.baseAddress!.advanced(by: writtenLength) - return WriteFile( - handle, - startPtr, - DWORD(truncatingIfNeeded: remainingLength), - nil, - &overlapped - ) - } - - if !succeed { - // It is expected `WriteFile` to return `false` in async mode. - // Make sure we only get `ERROR_IO_PENDING` - let lastError = GetLastError() - guard lastError == ERROR_IO_PENDING else { - let error = SubprocessError( - code: .init(.failedToWriteToSubprocess), - underlyingError: .init(rawValue: lastError) - ) - throw error - } - } - // Now wait for read to finish - let bytesWritten: DWORD = try await signalStream.next() ?? 0 - writtenLength += Int(bytesWritten) - if writtenLength >= bytes.count { - return writtenLength - } - } - } -} - -extension Array : AsyncIO._ContiguousBytes where Element == UInt8 {} - -#endif - From ba8c6c8cd39e62a7a57735cbb6a8ba98868025a5 Mon Sep 17 00:00:00 2001 From: Charles Hu Date: Thu, 24 Jul 2025 14:06:01 -0700 Subject: [PATCH 08/10] Use _beginthreadex instead of CreatThread on Windows for AsyncIO --- Sources/Subprocess/AsyncBufferSequence.swift | 2 +- Sources/Subprocess/Configuration.swift | 4 +- Sources/Subprocess/IO/AsyncIO+Darwin.swift | 36 +++---- Sources/Subprocess/IO/AsyncIO+Linux.swift | 31 +++--- Sources/Subprocess/IO/AsyncIO+Windows.swift | 97 +++++++++++-------- Sources/Subprocess/IO/Output.swift | 29 +++++- .../Platforms/Subprocess+Linux.swift | 4 - .../Platforms/Subprocess+Windows.swift | 2 +- .../_SubprocessCShims/include/process_shims.h | 1 + Sources/_SubprocessCShims/process_shims.c | 4 + 10 files changed, 125 insertions(+), 85 deletions(-) diff --git a/Sources/Subprocess/AsyncBufferSequence.swift b/Sources/Subprocess/AsyncBufferSequence.swift index 0ca7e6e..3076b12 100644 --- a/Sources/Subprocess/AsyncBufferSequence.swift +++ b/Sources/Subprocess/AsyncBufferSequence.swift @@ -153,7 +153,7 @@ extension AsyncBufferSequence { ) } #else - // Cast data to CodeUnitg type + // Cast data to CodeUnit type let result = buffer.withUnsafeBytes { ptr in return ptr.withMemoryRebound(to: Encoding.CodeUnit.self) { codeUnitPtr in return Array(codeUnitPtr) diff --git a/Sources/Subprocess/Configuration.swift b/Sources/Subprocess/Configuration.swift index 1e20557..7c8be26 100644 --- a/Sources/Subprocess/Configuration.swift +++ b/Sources/Subprocess/Configuration.swift @@ -602,7 +602,7 @@ internal func _safelyClose(_ target: _CloseTarget) throws { #if canImport(WinSDK) case .handle(let handle): /// Windows does not provide a “deregistration” API (the reverse of - /// `CreateIoCompletionPort`) for handles and it it reuses HANDLE + /// `CreateIoCompletionPort`) for handles and it reuses HANDLE /// values once they are closed. Since we rely on the handle value /// as the completion key for `CreateIoCompletionPort`, we should /// remove the registration when the handle is closed to allow @@ -688,7 +688,7 @@ internal struct IODescriptor: ~Copyable { type: .stream, fileDescriptor: self.platformDescriptor(), queue: .global(), - cleanupHandler: { error in + cleanupHandler: { @Sendable error in // Close the file descriptor if shouldClose { try? closeFd.close() diff --git a/Sources/Subprocess/IO/AsyncIO+Darwin.swift b/Sources/Subprocess/IO/AsyncIO+Darwin.swift index 218f75d..6c5d258 100644 --- a/Sources/Subprocess/IO/AsyncIO+Darwin.swift +++ b/Sources/Subprocess/IO/AsyncIO+Darwin.swift @@ -57,7 +57,7 @@ final class AsyncIO: Sendable { ) return } - if let data = data { + if let data { if buffer.isEmpty { buffer = data } else { @@ -81,8 +81,8 @@ final class AsyncIO: Sendable { to diskIO: borrowing IOChannel ) async throws -> Int { try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - let dispatchData = span.withUnsafeBytes { - return DispatchData( + span.withUnsafeBytes { + let dispatchData = DispatchData( bytesNoCopy: $0, deallocator: .custom( nil, @@ -91,12 +91,13 @@ final class AsyncIO: Sendable { } ) ) - } - self.write(dispatchData, to: diskIO) { writtenLength, error in - if let error = error { - continuation.resume(throwing: error) - } else { - continuation.resume(returning: writtenLength) + + self.write(dispatchData, to: diskIO) { writtenLength, error in + if let error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: writtenLength) + } } } } @@ -108,8 +109,8 @@ final class AsyncIO: Sendable { to diskIO: borrowing IOChannel ) async throws -> Int { try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - let dispatchData = array.withUnsafeBytes { - return DispatchData( + array.withUnsafeBytes { + let dispatchData = DispatchData( bytesNoCopy: $0, deallocator: .custom( nil, @@ -118,12 +119,13 @@ final class AsyncIO: Sendable { } ) ) - } - self.write(dispatchData, to: diskIO) { writtenLength, error in - if let error = error { - continuation.resume(throwing: error) - } else { - continuation.resume(returning: writtenLength) + + self.write(dispatchData, to: diskIO) { writtenLength, error in + if let error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: writtenLength) + } } } } diff --git a/Sources/Subprocess/IO/AsyncIO+Linux.swift b/Sources/Subprocess/IO/AsyncIO+Linux.swift index c571085..57ba9c9 100644 --- a/Sources/Subprocess/IO/AsyncIO+Linux.swift +++ b/Sources/Subprocess/IO/AsyncIO+Linux.swift @@ -118,11 +118,7 @@ final class AsyncIO: Sendable { shutdownFileDescriptor: shutdownFileDescriptor ) let threadContext = Unmanaged.passRetained(context) -#if os(FreeBSD) || os(OpenBSD) - var thread: pthread_t? = nil -#else var thread: pthread_t = pthread_t() -#endif rc = pthread_create(&thread, nil, { args in func reportError(_ error: SubprocessError) { _registration.withLock { store in @@ -175,11 +171,13 @@ final class AsyncIO: Sendable { } // Notify the continuation - _registration.withLock { store in + let continuation = _registration.withLock { store -> SignalStream.Continuation? in if let continuation = store[targetFileDescriptor] { - continuation.yield(true) + return continuation } + return nil } + continuation?.yield(true) } } @@ -194,16 +192,10 @@ final class AsyncIO: Sendable { return } -#if os(FreeBSD) || os(OpenBSD) - let monitorThread = thread! -#else - let monitorThread = thread -#endif - let state = State( epollFileDescriptor: epollFileDescriptor, shutdownFileDescriptor: shutdownFileDescriptor, - monitorThread: monitorThread + monitorThread: thread ) self.state = .success(state) @@ -222,6 +214,8 @@ final class AsyncIO: Sendable { _ = _SubprocessCShims.write(currentState.shutdownFileDescriptor, &one, MemoryLayout.stride) // Cleanup the monitor thread pthread_join(currentState.monitorThread, nil) + close(currentState.epollFileDescriptor) + close(currentState.shutdownFileDescriptor) } @@ -394,7 +388,7 @@ extension AsyncIO { resultBuffer.removeLast(resultBuffer.count - readLength) return resultBuffer } else { - if errno == EAGAIN || errno == EWOULDBLOCK { + if self.shouldWaitForNextSignal(with: errno) { // No more data for now wait for the next signal break } else { @@ -443,7 +437,7 @@ extension AsyncIO { return writtenLength } } else { - if errno == EAGAIN || errno == EWOULDBLOCK { + if self.shouldWaitForNextSignal(with: errno) { // No more data for now wait for the next signal break } else { @@ -486,7 +480,7 @@ extension AsyncIO { return writtenLength } } else { - if errno == EAGAIN || errno == EWOULDBLOCK { + if self.shouldWaitForNextSignal(with: errno) { // No more data for now wait for the next signal break } else { @@ -500,6 +494,11 @@ extension AsyncIO { return 0 } #endif + + @inline(__always) + private func shouldWaitForNextSignal(with error: CInt) -> Bool { + return error == EAGAIN || error == EWOULDBLOCK || error == EINTR + } } extension Array : AsyncIO._ContiguousBytes where Element == UInt8 {} diff --git a/Sources/Subprocess/IO/AsyncIO+Windows.swift b/Sources/Subprocess/IO/AsyncIO+Windows.swift index bf75e59..72fd6a2 100644 --- a/Sources/Subprocess/IO/AsyncIO+Windows.swift +++ b/Sources/Subprocess/IO/AsyncIO+Windows.swift @@ -19,6 +19,7 @@ @preconcurrency import SystemPackage #endif +import _SubprocessCShims import Synchronization internal import Dispatch @preconcurrency import WinSDK @@ -71,7 +72,11 @@ final class AsyncIO: @unchecked Sendable { // Create monitor thread let threadContext = MonitorThreadContext(ioCompletionPort: port) let threadContextPtr = Unmanaged.passRetained(threadContext) - let threadHandle = CreateThread(nil, 0, { args in + /// Microsoft documentation for `CreateThread` states: + /// > A thread in an executable that calls the C run-time library (CRT) + /// > should use the _beginthreadex and _endthreadex functions for + /// > thread management rather than CreateThread and ExitThread + let threadHandleValue = _beginthreadex(nil, 0, { args in func reportError(_ error: SubprocessError) { _registration.withLock { store in for continuation in store.values { @@ -126,18 +131,23 @@ final class AsyncIO: @unchecked Sendable { break } // Notify the continuations - _registration.withLock { store in + let continuation = _registration.withLock { store -> SignalStream.Continuation? in if let continuation = store[targetFileDescriptor] { - continuation.yield(bytesTransferred) + return continuation } + return nil } + continuation?.yield(bytesTransferred) } return 0 }, threadContextPtr.toOpaque(), 0, nil) - guard let threadHandle = threadHandle else { + guard threadHandleValue > 0, + let threadHandle = HANDLE(bitPattern: threadHandleValue) else { + // _beginthreadex uses errno instead of GetLastError() + let capturedError = _subprocess_windows_get_errno() let error = SubprocessError( - code: .init(.asyncIOFailed("CreateThread failed")), - underlyingError: .init(rawValue: GetLastError()) + code: .init(.asyncIOFailed("_beginthreadex failed")), + underlyingError: .init(rawValue: capturedError) ) self.monitorThread = .failure(error) return @@ -156,10 +166,10 @@ final class AsyncIO: @unchecked Sendable { return } PostQueuedCompletionStatus( - ioPort, - 0, - shutdownPort, - nil + ioPort, // CompletionPort + 0, // Number of bytes transferred. + shutdownPort, // Completion key to post status + nil // Overlapped ) // Wait for monitor thread to exit WaitForSingleObject(monitorThreadHandle, INFINITE) @@ -246,26 +256,24 @@ final class AsyncIO: @unchecked Sendable { var signalStream = self.registerHandle(handle).makeAsyncIterator() while true { + // We use an empty `_OVERLAPPED()` here because `ReadFile` below + // only reads non-seekable files, aka pipes. var overlapped = _OVERLAPPED() let succeed = try resultBuffer.withUnsafeMutableBufferPointer { bufferPointer in // Get a pointer to the memory at the specified offset // Windows ReadFile uses DWORD for target count, which means we can only // read up to DWORD (aka UInt32) max. - let targetCount: DWORD - if MemoryLayout.size == MemoryLayout.size { - // On 32 bit systems we don't have to worry about overflowing - targetCount = DWORD(truncatingIfNeeded: bufferPointer.count - readLength) - } else { - // On 64 bit systems we need to cap the count at DWORD max - targetCount = DWORD(truncatingIfNeeded: min(bufferPointer.count - readLength, Int(UInt32.max))) - } + let targetCount: DWORD = self.calculateRemainingCount( + totalCount: bufferPointer.count, + readCount: readLength + ) let offsetAddress = bufferPointer.baseAddress!.advanced(by: readLength) // Read directly into the buffer at the offset return ReadFile( handle, offsetAddress, - DWORD(truncatingIfNeeded: targetCount), + targetCount, nil, &overlapped ) @@ -300,7 +308,7 @@ final class AsyncIO: @unchecked Sendable { return resultBuffer } else { // Read some data - readLength += Int(bytesRead) + readLength += Int(truncatingIfNeeded: bytesRead) if maxLength == .max { // Grow resultBuffer if needed guard Double(readLength) > 0.8 * Double(resultBuffer.count) else { @@ -333,24 +341,22 @@ final class AsyncIO: @unchecked Sendable { var signalStream = self.registerHandle(diskIO.channel).makeAsyncIterator() var writtenLength: Int = 0 while true { + // We use an empty `_OVERLAPPED()` here because `WriteFile` below + // only writes to non-seekable files, aka pipes. var overlapped = _OVERLAPPED() let succeed = try span.withUnsafeBytes { ptr in // Windows WriteFile uses DWORD for target count // which means we can only write up to DWORD max - let remainingLength: DWORD - if MemoryLayout.size == MemoryLayout.size { - // On 32 bit systems we don't have to worry about overflowing - remainingLength = DWORD(truncatingIfNeeded: ptr.count - writtenLength) - } else { - // On 64 bit systems we need to cap the count at DWORD max - remainingLength = DWORD(truncatingIfNeeded: min(ptr.count - writtenLength, Int(DWORD.max))) - } + let remainingLength: DWORD = self.calculateRemainingCount( + totalCount: ptr.count, + readCount: writtenLength + ) let startPtr = ptr.baseAddress!.advanced(by: writtenLength) return WriteFile( handle, startPtr, - DWORD(truncatingIfNeeded: remainingLength), + remainingLength, nil, &overlapped ) @@ -371,7 +377,7 @@ final class AsyncIO: @unchecked Sendable { // Now wait for read to finish let bytesWritten: DWORD = try await signalStream.next() ?? 0 - writtenLength += Int(bytesWritten) + writtenLength += Int(truncatingIfNeeded: bytesWritten) if writtenLength >= span.byteCount { return writtenLength } @@ -387,23 +393,21 @@ final class AsyncIO: @unchecked Sendable { var signalStream = self.registerHandle(diskIO.channel).makeAsyncIterator() var writtenLength: Int = 0 while true { + // We use an empty `_OVERLAPPED()` here because `WriteFile` below + // only writes to non-seekable files, aka pipes. var overlapped = _OVERLAPPED() let succeed = try bytes.withUnsafeBytes { ptr in // Windows WriteFile uses DWORD for target count // which means we can only write up to DWORD max - let remainingLength: DWORD - if MemoryLayout.size == MemoryLayout.size { - // On 32 bit systems we don't have to worry about overflowing - remainingLength = DWORD(truncatingIfNeeded: ptr.count - writtenLength) - } else { - // On 64 bit systems we need to cap the count at DWORD max - remainingLength = DWORD(truncatingIfNeeded: min(ptr.count - writtenLength, Int(DWORD.max))) - } + let remainingLength: DWORD = self.calculateRemainingCount( + totalCount: ptr.count, + readCount: writtenLength + ) let startPtr = ptr.baseAddress!.advanced(by: writtenLength) return WriteFile( handle, startPtr, - DWORD(truncatingIfNeeded: remainingLength), + remainingLength, nil, &overlapped ) @@ -423,12 +427,25 @@ final class AsyncIO: @unchecked Sendable { } // Now wait for read to finish let bytesWritten: DWORD = try await signalStream.next() ?? 0 - writtenLength += Int(bytesWritten) + writtenLength += Int(truncatingIfNeeded: bytesWritten) if writtenLength >= bytes.count { return writtenLength } } } + + // Windows ReadFile uses DWORD for target count, which means we can only + // read up to DWORD (aka UInt32) max. + private func calculateRemainingCount(totalCount: Int, readCount: Int) -> DWORD { + // We support both 32bit and 64bit systems for Windows + if MemoryLayout.size == MemoryLayout.size { + // On 32 bit systems we don't have to worry about overflowing + return DWORD(truncatingIfNeeded: totalCount - readCount) + } else { + // On 64 bit systems we need to cap the count at DWORD max + return DWORD(truncatingIfNeeded: min(totalCount - readCount, Int(DWORD.max))) + } + } } extension Array : AsyncIO._ContiguousBytes where Element == UInt8 {} diff --git a/Sources/Subprocess/IO/Output.swift b/Sources/Subprocess/IO/Output.swift index 223ccd6..c2744b9 100644 --- a/Sources/Subprocess/IO/Output.swift +++ b/Sources/Subprocess/IO/Output.swift @@ -152,7 +152,18 @@ public struct BytesOutput: OutputProtocol { internal func captureOutput( from diskIO: consuming IOChannel ) async throws -> [UInt8] { - let result = try await AsyncIO.shared.read(from: diskIO, upTo: self.maxSize) + #if canImport(Darwin) + var result: DispatchData? = nil + #else + var result: [UInt8]? = nil + #endif + do { + result = try await AsyncIO.shared.read(from: diskIO, upTo: self.maxSize) + } catch { + try diskIO.safelyClose() + throw error + } + try diskIO.safelyClose() #if canImport(Darwin) return result?.array() ?? [] @@ -281,9 +292,19 @@ extension OutputProtocol { if let bytesOutput = self as? BytesOutput { return try await bytesOutput.captureOutput(from: diskIO) as! Self.OutputType } - // Force unwrap is safe here because only `OutputType.self == Void` would - // have nil `IOChannel` - let result = try await AsyncIO.shared.read(from: diskIO, upTo: self.maxSize) + + #if canImport(Darwin) + var result: DispatchData? = nil + #else + var result: [UInt8]? = nil + #endif + do { + result = try await AsyncIO.shared.read(from: diskIO, upTo: self.maxSize) + } catch { + try diskIO.safelyClose() + throw error + } + try diskIO.safelyClose() #if canImport(Darwin) return try self.output(from: result ?? .empty) diff --git a/Sources/Subprocess/Platforms/Subprocess+Linux.swift b/Sources/Subprocess/Platforms/Subprocess+Linux.swift index 3b449ac..45db2a7 100644 --- a/Sources/Subprocess/Platforms/Subprocess+Linux.swift +++ b/Sources/Subprocess/Platforms/Subprocess+Linux.swift @@ -285,11 +285,7 @@ internal func monitorProcessTermination( // Small helper to provide thread-safe access to the child process to continuations map as well as a condition variable to suspend the calling thread when there are no subprocesses to wait for. Note that Mutex cannot be used here because we need the semantics of pthread_cond_wait, which requires passing the pthread_mutex_t instance as a parameter, something the Mutex API does not provide access to. private final class ChildProcessContinuations: Sendable { - #if os(FreeBSD) || os(OpenBSD) - typealias MutexType = pthread_mutex_t? - #else typealias MutexType = pthread_mutex_t - #endif private nonisolated(unsafe) var continuations = [pid_t: CheckedContinuation]() private nonisolated(unsafe) let mutex = UnsafeMutablePointer.allocate(capacity: 1) diff --git a/Sources/Subprocess/Platforms/Subprocess+Windows.swift b/Sources/Subprocess/Platforms/Subprocess+Windows.swift index 63229aa..fe3c54f 100644 --- a/Sources/Subprocess/Platforms/Subprocess+Windows.swift +++ b/Sources/Subprocess/Platforms/Subprocess+Windows.swift @@ -358,7 +358,7 @@ public struct PlatformOptions: Sendable { public static let inherit: Self = .init(.inherit) } - /// `ConsoleBehavior` defines how should the window appear + /// `WindowStyle` defines how should the window appear /// when spawning a new process public struct WindowStyle: Sendable, Hashable { internal enum Storage: Sendable, Hashable { diff --git a/Sources/_SubprocessCShims/include/process_shims.h b/Sources/_SubprocessCShims/include/process_shims.h index 0ae4a5a..6eb185e 100644 --- a/Sources/_SubprocessCShims/include/process_shims.h +++ b/Sources/_SubprocessCShims/include/process_shims.h @@ -90,6 +90,7 @@ typedef int BOOL; #endif BOOL _subprocess_windows_send_vm_close(DWORD pid); +unsigned int _subprocess_windows_get_errno(void); #endif diff --git a/Sources/_SubprocessCShims/process_shims.c b/Sources/_SubprocessCShims/process_shims.c index cf7ca66..c79ca54 100644 --- a/Sources/_SubprocessCShims/process_shims.c +++ b/Sources/_SubprocessCShims/process_shims.c @@ -779,5 +779,9 @@ BOOL _subprocess_windows_send_vm_close( return FALSE; } +unsigned int _subprocess_windows_get_errno(void) { + return errno; +} + #endif From 9c6703f2a004d561cdcaeb57cce76cfca536ee3f Mon Sep 17 00:00:00 2001 From: Charles Hu Date: Fri, 25 Jul 2025 15:27:38 -0700 Subject: [PATCH 09/10] Remove DispatchData._ContiguousBufferView since Linux no longer relies on DispatchIO and we can just use DispatchData.Region on Darwin --- Package.swift | 2 +- Sources/Subprocess/Buffer.swift | 55 ++----- Sources/Subprocess/Configuration.swift | 159 ++++++++++++--------- Sources/Subprocess/IO/AsyncIO+Darwin.swift | 2 +- 4 files changed, 111 insertions(+), 107 deletions(-) diff --git a/Package.swift b/Package.swift index 0e9c521..e1a8810 100644 --- a/Package.swift +++ b/Package.swift @@ -6,7 +6,7 @@ import PackageDescription var dep: [Package.Dependency] = [ .package( url: "https://github.com/apple/swift-system", - from: "1.5.0" + exact: "1.5.0" ) ] #if !os(Windows) diff --git a/Sources/Subprocess/Buffer.swift b/Sources/Subprocess/Buffer.swift index 292fac4..4cd59d3 100644 --- a/Sources/Subprocess/Buffer.swift +++ b/Sources/Subprocess/Buffer.swift @@ -20,15 +20,15 @@ extension AsyncBufferSequence { #if canImport(Darwin) // We need to keep the backingData alive while Slice is alive internal let backingData: DispatchData - internal let data: DispatchData._ContiguousBufferView + internal let data: DispatchData.Region - internal init(data: DispatchData._ContiguousBufferView, backingData: DispatchData) { + internal init(data: DispatchData.Region, backingData: DispatchData) { self.data = data self.backingData = backingData } internal static func createFrom(_ data: DispatchData) -> [Buffer] { - let slices = data.contiguousBufferViews + let slices = data.regions // In most (all?) cases data should only have one slice if _fastPath(slices.count == 1) { return [.init(data: slices[0], backingData: data)] @@ -98,54 +98,27 @@ extension AsyncBufferSequence.Buffer: Equatable, Hashable { } public func hash(into hasher: inout Hasher) { - hasher.combine(self.data) + return self.data.hash(into: &hasher) } #endif // else Compiler generated conformances } -// MARK: - DispatchData.Block -#if canImport(Darwin) || canImport(Glibc) || canImport(Android) || canImport(Musl) -extension DispatchData { - /// Unfortunately `DispatchData.Region` is not available on Linux, hence our own wrapper - internal struct _ContiguousBufferView: @unchecked Sendable, RandomAccessCollection, Hashable { - typealias Element = UInt8 - - internal let bytes: UnsafeBufferPointer - - internal var startIndex: Int { self.bytes.startIndex } - internal var endIndex: Int { self.bytes.endIndex } - - internal init(bytes: UnsafeBufferPointer) { - self.bytes = bytes - } - - internal func withUnsafeBytes(_ body: (UnsafeRawBufferPointer) throws -> ResultType) rethrows -> ResultType { - return try body(UnsafeRawBufferPointer(self.bytes)) - } - - internal func hash(into hasher: inout Hasher) { - hasher.combine(bytes: UnsafeRawBufferPointer(self.bytes)) - } - - internal static func == (lhs: DispatchData._ContiguousBufferView, rhs: DispatchData._ContiguousBufferView) -> Bool { - return lhs.bytes.elementsEqual(rhs.bytes) - } - - subscript(position: Int) -> UInt8 { - _read { - yield self.bytes[position] +#if canImport(Darwin) +extension DispatchData.Region { + static func == (lhs: DispatchData.Region, rhs: DispatchData.Region) -> Bool { + return lhs.withUnsafeBytes { lhsBytes in + return rhs.withUnsafeBytes { rhsBytes in + return lhsBytes.elementsEqual(rhsBytes) } } } - internal var contiguousBufferViews: [_ContiguousBufferView] { - var slices = [_ContiguousBufferView]() - enumerateBytes { (bytes, index, stop) in - slices.append(_ContiguousBufferView(bytes: bytes)) + internal func hash(into hasher: inout Hasher) { + return self.withUnsafeBytes { ptr in + return hasher.combine(bytes: ptr) } - return slices } } - #endif + diff --git a/Sources/Subprocess/Configuration.swift b/Sources/Subprocess/Configuration.swift index 7c8be26..1fcafdd 100644 --- a/Sources/Subprocess/Configuration.swift +++ b/Sources/Subprocess/Configuration.swift @@ -29,6 +29,8 @@ import Musl internal import Dispatch +import Synchronization + /// A collection of configurations parameters to use when /// spawning a subprocess. public struct Configuration: Sendable { @@ -775,6 +777,16 @@ internal struct IOChannel: ~Copyable, @unchecked Sendable { } } +#if canImport(WinSDK) +internal enum PipeNameCounter { + private static let value = Atomic(0) + + internal static func nextValue() -> UInt64 { + return self.value.add(1, ordering: .relaxed).newValue + } +} +#endif + internal struct CreatedPipe: ~Copyable { internal enum Purpose: CustomStringConvertible { /// This pipe is used for standard input. This option maps to @@ -817,77 +829,96 @@ internal struct CreatedPipe: ~Copyable { internal init(closeWhenDone: Bool, purpose: Purpose) throws { #if canImport(WinSDK) - // On Windows, we need to create a named pipe - let pipeName = "\\\\.\\pipe\\subprocess-\(purpose)-\(Int.random(in: .min ..< .max))" - var saAttributes: SECURITY_ATTRIBUTES = SECURITY_ATTRIBUTES() - saAttributes.nLength = DWORD(MemoryLayout.size) - saAttributes.bInheritHandle = true - saAttributes.lpSecurityDescriptor = nil - - let parentEnd = pipeName.withCString( - encodedAs: UTF16.self - ) { pipeNameW in - // Use OVERLAPPED for async IO - var openMode: DWORD = DWORD(FILE_FLAG_OVERLAPPED) - switch purpose { - case .input: - openMode |= DWORD(PIPE_ACCESS_OUTBOUND) - case .output: - openMode |= DWORD(PIPE_ACCESS_INBOUND) + /// On Windows, we need to create a named pipe. + /// According to Microsoft documentation: + /// > Asynchronous (overlapped) read and write operations are + /// > not supported by anonymous pipes. + /// See https://learn.microsoft.com/en-us/windows/win32/ipc/anonymous-pipe-operations + while true { + /// Windows named pipes are system wide. To avoid creating two pipes with the same + /// name, create the pipe with `FILE_FLAG_FIRST_PIPE_INSTANCE` such that it will + /// return error `ERROR_ACCESS_DENIED` if we try to create another pipe with the same name. + let pipeName = "\\\\.\\pipe\\LOCAL\\subprocess-\(purpose)-\(PipeNameCounter.nextValue())" + var saAttributes: SECURITY_ATTRIBUTES = SECURITY_ATTRIBUTES() + saAttributes.nLength = DWORD(MemoryLayout.size) + saAttributes.bInheritHandle = true + saAttributes.lpSecurityDescriptor = nil + + let parentEnd = pipeName.withCString( + encodedAs: UTF16.self + ) { pipeNameW in + // Use OVERLAPPED for async IO + var openMode: DWORD = DWORD(FILE_FLAG_OVERLAPPED | FILE_FLAG_FIRST_PIPE_INSTANCE) + switch purpose { + case .input: + openMode |= DWORD(PIPE_ACCESS_OUTBOUND) + case .output: + openMode |= DWORD(PIPE_ACCESS_INBOUND) + } + + return CreateNamedPipeW( + pipeNameW, + openMode, + DWORD(PIPE_TYPE_BYTE | PIPE_READMODE_BYTE | PIPE_WAIT), + 1, // Max instance, + DWORD(readBufferSize), + DWORD(readBufferSize), + 0, + &saAttributes + ) + } + guard let parentEnd, parentEnd != INVALID_HANDLE_VALUE else { + // Since we created the pipe with `FILE_FLAG_FIRST_PIPE_INSTANCE`, + // if there's already a pipe with the same name, GetLastError() + // will be set to FILE_FLAG_FIRST_PIPE_INSTANCE. In this case, + // try again with a different name. + let errorCode = GetLastError() + guard errorCode != FILE_FLAG_FIRST_PIPE_INSTANCE else { + continue + } + // Throw all other errors + throw SubprocessError( + code: .init(.asyncIOFailed("CreateNamedPipeW failed")), + underlyingError: .init(rawValue: GetLastError()) + ) } - return CreateNamedPipeW( - pipeNameW, - openMode, - DWORD(PIPE_TYPE_BYTE | PIPE_READMODE_BYTE | PIPE_WAIT), - 1, // Max instance, - DWORD(readBufferSize), - DWORD(readBufferSize), - 0, - &saAttributes - ) - } - guard let parentEnd, parentEnd != INVALID_HANDLE_VALUE else { - throw SubprocessError( - code: .init(.asyncIOFailed("CreateNamedPipeW failed")), - underlyingError: .init(rawValue: GetLastError()) - ) - } + let childEnd = pipeName.withCString( + encodedAs: UTF16.self + ) { pipeNameW in + var targetAccess: DWORD = 0 + switch purpose { + case .input: + targetAccess = DWORD(GENERIC_READ) + case .output: + targetAccess = DWORD(GENERIC_WRITE) + } - let childEnd = pipeName.withCString( - encodedAs: UTF16.self - ) { pipeNameW in - var targetAccess: DWORD = 0 + return CreateFileW( + pipeNameW, + targetAccess, + 0, + &saAttributes, + DWORD(OPEN_EXISTING), + DWORD(FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED), + nil + ) + } + guard let childEnd, childEnd != INVALID_HANDLE_VALUE else { + throw SubprocessError( + code: .init(.asyncIOFailed("CreateFileW failed")), + underlyingError: .init(rawValue: GetLastError()) + ) + } switch purpose { case .input: - targetAccess = DWORD(GENERIC_READ) + self._readFileDescriptor = .init(childEnd, closeWhenDone: closeWhenDone) + self._writeFileDescriptor = .init(parentEnd, closeWhenDone: closeWhenDone) case .output: - targetAccess = DWORD(GENERIC_WRITE) + self._readFileDescriptor = .init(parentEnd, closeWhenDone: closeWhenDone) + self._writeFileDescriptor = .init(childEnd, closeWhenDone: closeWhenDone) } - - return CreateFileW( - pipeNameW, - targetAccess, - 0, - &saAttributes, - DWORD(OPEN_EXISTING), - DWORD(FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED), - nil - ) - } - guard let childEnd, childEnd != INVALID_HANDLE_VALUE else { - throw SubprocessError( - code: .init(.asyncIOFailed("CreateFileW failed")), - underlyingError: .init(rawValue: GetLastError()) - ) - } - switch purpose { - case .input: - self._readFileDescriptor = .init(childEnd, closeWhenDone: closeWhenDone) - self._writeFileDescriptor = .init(parentEnd, closeWhenDone: closeWhenDone) - case .output: - self._readFileDescriptor = .init(parentEnd, closeWhenDone: closeWhenDone) - self._writeFileDescriptor = .init(childEnd, closeWhenDone: closeWhenDone) + return } #else let pipe = try FileDescriptor.pipe() diff --git a/Sources/Subprocess/IO/AsyncIO+Darwin.swift b/Sources/Subprocess/IO/AsyncIO+Darwin.swift index 6c5d258..4d355f3 100644 --- a/Sources/Subprocess/IO/AsyncIO+Darwin.swift +++ b/Sources/Subprocess/IO/AsyncIO+Darwin.swift @@ -46,7 +46,7 @@ final class AsyncIO: Sendable { dispatchIO.read( offset: 0, length: maxLength, - queue: .global() + queue: DispatchQueue(label: "SubprocessReadQueue") ) { done, data, error in if error != 0 { continuation.resume( From 6f030af3e52ca77a273df43570ef2b6990caa3b6 Mon Sep 17 00:00:00 2001 From: Charles Hu Date: Mon, 28 Jul 2025 13:48:10 -0700 Subject: [PATCH 10/10] Linux: reap child process if fork succeed but exec fails --- Package.swift | 2 ++ Sources/Subprocess/Configuration.swift | 4 ++-- Sources/Subprocess/IO/AsyncIO+Linux.swift | 12 ++++++++++-- Sources/Subprocess/Platforms/Subprocess+Linux.swift | 10 +--------- Sources/_SubprocessCShims/process_shims.c | 9 ++++++++- Tests/SubprocessTests/SubprocessTests+Unix.swift | 2 +- 6 files changed, 24 insertions(+), 15 deletions(-) diff --git a/Package.swift b/Package.swift index e1a8810..62cf4da 100644 --- a/Package.swift +++ b/Package.swift @@ -6,6 +6,8 @@ import PackageDescription var dep: [Package.Dependency] = [ .package( url: "https://github.com/apple/swift-system", + // Temporarily pin to 1.5.0 because 1.6.0 has a breaking change for Ubuntu Focal + // https://github.com/apple/swift-system/issues/237 exact: "1.5.0" ) ] diff --git a/Sources/Subprocess/Configuration.swift b/Sources/Subprocess/Configuration.swift index 1fcafdd..6e6cd38 100644 --- a/Sources/Subprocess/Configuration.swift +++ b/Sources/Subprocess/Configuration.swift @@ -870,10 +870,10 @@ internal struct CreatedPipe: ~Copyable { guard let parentEnd, parentEnd != INVALID_HANDLE_VALUE else { // Since we created the pipe with `FILE_FLAG_FIRST_PIPE_INSTANCE`, // if there's already a pipe with the same name, GetLastError() - // will be set to FILE_FLAG_FIRST_PIPE_INSTANCE. In this case, + // will be set to ERROR_ACCESS_DENIED. In this case, // try again with a different name. let errorCode = GetLastError() - guard errorCode != FILE_FLAG_FIRST_PIPE_INSTANCE else { + guard errorCode != ERROR_ACCESS_DENIED else { continue } // Throw all other errors diff --git a/Sources/Subprocess/IO/AsyncIO+Linux.swift b/Sources/Subprocess/IO/AsyncIO+Linux.swift index 57ba9c9..783b04c 100644 --- a/Sources/Subprocess/IO/AsyncIO+Linux.swift +++ b/Sources/Subprocess/IO/AsyncIO+Linux.swift @@ -214,8 +214,16 @@ final class AsyncIO: Sendable { _ = _SubprocessCShims.write(currentState.shutdownFileDescriptor, &one, MemoryLayout.stride) // Cleanup the monitor thread pthread_join(currentState.monitorThread, nil) - close(currentState.epollFileDescriptor) - close(currentState.shutdownFileDescriptor) + var closeError: CInt = 0 + if _SubprocessCShims.close(currentState.epollFileDescriptor) != 0 { + closeError = errno + } + if _SubprocessCShims.close(currentState.shutdownFileDescriptor) != 0 { + closeError = errno + } + if closeError != 0 { + fatalError("Failed to close epollfd: \(String(cString: strerror(closeError)))") + } } diff --git a/Sources/Subprocess/Platforms/Subprocess+Linux.swift b/Sources/Subprocess/Platforms/Subprocess+Linux.swift index 45db2a7..ae1b5d8 100644 --- a/Sources/Subprocess/Platforms/Subprocess+Linux.swift +++ b/Sources/Subprocess/Platforms/Subprocess+Linux.swift @@ -131,14 +131,6 @@ extension Configuration { underlyingError: .init(rawValue: spawnError) ) } - func captureError(_ work: () throws -> Void) -> (any Swift.Error)? { - do { - try work() - return nil - } catch { - return error - } - } // After spawn finishes, close all child side fds try self.safelyCloseMultiple( inputRead: inputReadFileDescriptor, @@ -267,7 +259,7 @@ extension String { internal func monitorProcessTermination( forExecution execution: Execution ) async throws -> TerminationStatus { - try await withCheckedThrowingContinuation { continuation in + return try await withCheckedThrowingContinuation { continuation in _childProcessContinuations.withLock { continuations in // We don't need to worry about a race condition here because waitid() // does not clear the wait/zombie state of the child process. If it sees diff --git a/Sources/_SubprocessCShims/process_shims.c b/Sources/_SubprocessCShims/process_shims.c index c79ca54..c6a70d1 100644 --- a/Sources/_SubprocessCShims/process_shims.c +++ b/Sources/_SubprocessCShims/process_shims.c @@ -689,7 +689,14 @@ int _subprocess_fork_exec( // exec worked! close(pipefd[0]); return 0; - } else if (read_rc > 0) { + } + // if we reach this point, exec failed. + // Since we already have the child pid (fork succeed), reap the child + // This mimic posix_spawn behavior + siginfo_t info; + waitid(P_PID, childPid, &info, WEXITED); + + if (read_rc > 0) { // Child exec failed and reported back close(pipefd[0]); return childError; diff --git a/Tests/SubprocessTests/SubprocessTests+Unix.swift b/Tests/SubprocessTests/SubprocessTests+Unix.swift index 25a464e..4dc1a14 100644 --- a/Tests/SubprocessTests/SubprocessTests+Unix.swift +++ b/Tests/SubprocessTests/SubprocessTests+Unix.swift @@ -668,7 +668,7 @@ extension SubprocessUnixTests { var platformOptions = PlatformOptions() platformOptions.supplementaryGroups = Array(expectedGroups) let idResult = try await Subprocess.run( - .path("/usr/bin/swift"), + .name("swift"), arguments: [getgroupsSwift.string], platformOptions: platformOptions, output: .string