diff --git a/Sources/Subprocess/Configuration.swift b/Sources/Subprocess/Configuration.swift index 5396506..2d3a446 100644 --- a/Sources/Subprocess/Configuration.swift +++ b/Sources/Subprocess/Configuration.swift @@ -585,7 +585,7 @@ internal enum StringOrRawBytes: Sendable, Hashable { /// automatically when done. internal struct TrackedFileDescriptor: ~Copyable { internal var closeWhenDone: Bool - internal let fileDescriptor: FileDescriptor + internal var fileDescriptor: FileDescriptor internal init( _ fileDescriptor: FileDescriptor, @@ -675,7 +675,7 @@ internal struct TrackedDispatchIO: ~Copyable { return } closeWhenDone = false - dispatchIO.close() + dispatchIO.close(flags: [.stop]) } deinit { diff --git a/Sources/Subprocess/IO/Output.swift b/Sources/Subprocess/IO/Output.swift index 454eca4..d2730c6 100644 --- a/Sources/Subprocess/IO/Output.swift +++ b/Sources/Subprocess/IO/Output.swift @@ -142,15 +142,9 @@ public struct BytesOutput: OutputProtocol { internal func captureOutput( from diskIO: consuming TrackedPlatformDiskIO? ) async throws -> [UInt8] { - #if os(Windows) - let result = try await diskIO?.fileDescriptor.read(upToLength: self.maxSize) ?? [] - try diskIO?.safelyClose() - return result - #else - let result = try await diskIO!.dispatchIO.read(upToLength: self.maxSize) - try diskIO?.safelyClose() - return result?.array() ?? [] - #endif + try await diskIO!.readCancellable { diskIO in + try await diskIO.read(upToLength: self.maxSize)?.array() ?? [] + } } #if SubprocessSpan @@ -264,15 +258,11 @@ extension OutputProtocol { if OutputType.self == Void.self { return () as! OutputType } - #if os(Windows) - let result = try await diskIO?.fileDescriptor.read(upToLength: self.maxSize) - try diskIO?.safelyClose() - return try self.output(from: result ?? []) - #else - let result = try await diskIO!.dispatchIO.read(upToLength: self.maxSize) - try diskIO?.safelyClose() - return try self.output(from: result ?? .empty) - #endif + + return try await diskIO!.readCancellable { diskIO in + let result = try await diskIO.read(upToLength: self.maxSize) + return try self.output(from: result ?? .empty) + } } } @@ -338,3 +328,65 @@ extension DispatchData { return result ?? [] } } + +#if os(Windows) +typealias PlatformIO = FileDescriptor +extension Array { + fileprivate var empty: Self { + [] + } + + fileprivate func array() -> Self { + self + } +} +#else +typealias PlatformIO = DispatchIO +#endif + +/// Runs `block` while _immediately_ responding to cancellation by throwing a `CancellationError` if the parent task is cancelled, regardless of whether `block` reacts to cancellation. +fileprivate func withImmediateCancellation(_ block: @escaping @Sendable () async throws -> T) async throws -> T { + // (ab)use an AsyncStream to return the buffer or immediately react to cancellation + let stream = AsyncThrowingStream { + return try await block() + } + var it = stream.makeAsyncIterator() + guard let next = try await it.next() else { + throw CancellationError() + } + return next +} + +extension TrackedPlatformDiskIO { + mutating func readCancellable(_ block: @escaping @Sendable (PlatformIO) async throws -> OutputType) async throws -> OutputType { + try await tryFinally { + let io: PlatformIO + #if os(Windows) + io = fileDescriptor + #else + io = dispatchIO + #endif + return try await withImmediateCancellation { + try await block(io) + } + } finally: { _ in + try safelyClose() + } + } +} + +fileprivate func tryFinally(_ work: () async throws -> T, finally: (Error?) async throws -> ()) async throws -> T { + let result: Result + do { + result = try await .success(work()) + } catch let e { + result = .failure(e) + } + switch result { + case .success: + try await finally(nil) + case let .failure(error): + try await finally(error) + } + return try result.get() +}