diff --git a/.buildkite/download-xcframework.sh b/.buildkite/download-xcframework.sh index b8b7f2abb..e7fb211c3 100755 --- a/.buildkite/download-xcframework.sh +++ b/.buildkite/download-xcframework.sh @@ -4,7 +4,6 @@ set -euo pipefail echo "--- :arrow_down: Downloading xcframework" buildkite-agent artifact download target/libwordpressFFI.xcframework.zip . --step "xcframework" -buildkite-agent artifact download native/swift/Sources/wordpress-api-wrapper/wp_api.swift . --step "xcframework" -buildkite-agent artifact download native/swift/Sources/wordpress-api-wrapper/wp_localization.swift . --step "xcframework" +buildkite-agent artifact download 'native/swift/Sources/wordpress-api-wrapper/*.swift' . --step "xcframework" unzip target/libwordpressFFI.xcframework.zip -d . rm target/libwordpressFFI.xcframework.zip diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 01f7bcca0..927ad6018 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -74,10 +74,7 @@ steps: zip -r target/libwordpressFFI.xcframework.zip target/libwordpressFFI.xcframework artifact_paths: - target/libwordpressFFI.xcframework.zip - - native/swift/Sources/wordpress-api-wrapper/wp_api.swift - - native/swift/Sources/wordpress-api-wrapper/wp_localization.swift - - native/swift/Sources/wordpress-api-wrapper/wp_com.swift - - native/swift/Sources/wordpress-api-wrapper/jetpack.swift + - native/swift/Sources/wordpress-api-wrapper/*.swift agents: queue: mac - label: ":swift: Build Docs" diff --git a/.swiftlint.yml b/.swiftlint.yml index d0a3be016..a1bd52483 100644 --- a/.swiftlint.yml +++ b/.swiftlint.yml @@ -3,10 +3,8 @@ strict: true included: - native/swift excluded: # paths to ignore during linting. Takes precedence over `included`. - - native/swift/Sources/wordpress-api-wrapper/wp_api.swift # auto-generated code - - native/swift/Sources/wordpress-api-wrapper/wp_localization.swift # auto-generated code - - native/swift/Sources/wordpress-api-wrapper/wp_com.swift # auto-generated code - - native/swift/Sources/wordpress-api-wrapper/jetpack.swift # auto-generated code + - native/swift/Sources/wordpress-api-wrapper # auto-generated code + - native/swift/**/.build disabled_rules: # Don't think we should enable this rule. # See https://github.com/realm/SwiftLint/issues/5263 for context. @@ -19,4 +17,3 @@ disabled_rules: - file_length - function_body_length - type_body_length - diff --git a/Makefile b/Makefile index 729deb4fa..19035d941 100644 --- a/Makefile +++ b/Makefile @@ -102,7 +102,7 @@ _build-apple-%-tvos _build-apple-%-tvos-sim _build-apple-%-watchos _build-apple- # Build the library for a specific target _build-apple-%: - cargo $(CARGO_OPTS) $(cargo_config_library) build --target $* --package wp_api --profile $(CARGO_PROFILE) + cargo $(CARGO_OPTS) $(cargo_config_library) build --target $* --features export-uncancellable-endpoints --package wp_api --profile $(CARGO_PROFILE) ./scripts/swift-bindings.sh target/$*/$(CARGO_PROFILE_DIRNAME)/libwp_api.a # Build the library for one single platform, including real device and simulator. @@ -141,7 +141,7 @@ docker-image-web: docker build -t wordpress-rs-web -f wp_rs_web/Dockerfile . --progress=plain swift-linux-library: - cargo build --release --package wp_api + cargo build --release --features export-uncancellable-endpoints --package wp_api ./scripts/swift-bindings.sh target/release/libwp_api.a mkdir -p target/release/libwordpressFFI-linux cp target/release/swift-bindings/Headers/* target/release/libwordpressFFI-linux/ @@ -272,7 +272,7 @@ validate-localizations: @# Help: Validate localization files using `wp_localization_validation` crate $(rust_docker_run) /bin/bash -c "cargo run --bin wp_localization_validation -- --localization-folder ./wp_localization/localization/" -format-swift: +fmt-swift: @# Help: Format the Swift binding code xcrun swift format -i -r native/swift/Sources/wordpress-api-wrapper diff --git a/fastlane/Fastfile b/fastlane/Fastfile index 9f53aff15..ac7bd7377 100644 --- a/fastlane/Fastfile +++ b/fastlane/Fastfile @@ -412,8 +412,7 @@ end def xcframework_bindings_file_path dir = File.join(PROJECT_ROOT, 'native', 'swift', 'Sources', 'wordpress-api-wrapper') - %w[wp_api.swift wp_localization.swift] - .map { |file| File.join(dir, file) } + Dir.glob(File.join(dir, '*.swift')) end def remove_lane_context_values(names) diff --git a/native/kotlin/api/kotlin/src/integrationTest/kotlin/MockRequestExecutor.kt b/native/kotlin/api/kotlin/src/integrationTest/kotlin/MockRequestExecutor.kt index c1dabe837..2db1f170a 100644 --- a/native/kotlin/api/kotlin/src/integrationTest/kotlin/MockRequestExecutor.kt +++ b/native/kotlin/api/kotlin/src/integrationTest/kotlin/MockRequestExecutor.kt @@ -3,6 +3,7 @@ package rs.wordpress.api.kotlin import kotlinx.coroutines.delay import okio.FileNotFoundException import uniffi.wp_api.MediaUploadRequest +import uniffi.wp_api.RequestContext import uniffi.wp_api.RequestExecutor import uniffi.wp_api.WpNetworkHeaderMap import uniffi.wp_api.WpNetworkRequest @@ -47,6 +48,10 @@ class MockRequestExecutor(private var stubs: List = listOf()) : RequestExe override suspend fun sleep(millis: ULong) { delay(millis.toLong()) } + + override fun cancel(context: RequestContext) { + // No-op + } } val WpNetworkResponse.Companion.empty: WpNetworkResponse diff --git a/native/kotlin/api/kotlin/src/main/kotlin/rs/wordpress/api/kotlin/DebugMiddleware.kt b/native/kotlin/api/kotlin/src/main/kotlin/rs/wordpress/api/kotlin/DebugMiddleware.kt index 8498787f1..172f4d949 100644 --- a/native/kotlin/api/kotlin/src/main/kotlin/rs/wordpress/api/kotlin/DebugMiddleware.kt +++ b/native/kotlin/api/kotlin/src/main/kotlin/rs/wordpress/api/kotlin/DebugMiddleware.kt @@ -1,4 +1,5 @@ package rs.wordpress.api.kotlin +import uniffi.wp_api.RequestContext import uniffi.wp_api.RequestExecutor import uniffi.wp_api.WpApiMiddleware import uniffi.wp_api.WpNetworkRequest @@ -9,7 +10,8 @@ class DebugMiddleware : WpApiMiddleware { override suspend fun process( requestExecutor: RequestExecutor, response: WpNetworkResponse, - request: WpNetworkRequest + request: WpNetworkRequest, + context: RequestContext? ): WpNetworkResponse { println("Request: ${request.url()}") println("Response:") diff --git a/native/kotlin/api/kotlin/src/main/kotlin/rs/wordpress/api/kotlin/WpRequestExecutor.kt b/native/kotlin/api/kotlin/src/main/kotlin/rs/wordpress/api/kotlin/WpRequestExecutor.kt index 96b2cb7ed..c15826bfd 100644 --- a/native/kotlin/api/kotlin/src/main/kotlin/rs/wordpress/api/kotlin/WpRequestExecutor.kt +++ b/native/kotlin/api/kotlin/src/main/kotlin/rs/wordpress/api/kotlin/WpRequestExecutor.kt @@ -16,6 +16,7 @@ import okhttp3.RequestBody.Companion.toRequestBody import uniffi.wp_api.InvalidSslErrorReason import uniffi.wp_api.MediaUploadRequest import uniffi.wp_api.MediaUploadRequestExecutionException +import uniffi.wp_api.RequestContext import uniffi.wp_api.RequestExecutionErrorReason import uniffi.wp_api.RequestExecutionException import uniffi.wp_api.RequestExecutor @@ -150,6 +151,10 @@ class WpRequestExecutor( delay(millis.toLong()) } + override fun cancel(context: RequestContext) { + // No-op + } + private fun File.canBeUploaded() = exists() && isFile && canRead() /** diff --git a/native/swift/Example/Example.xcodeproj/xcshareddata/xcschemes/Example.xcscheme b/native/swift/Example/Example.xcodeproj/xcshareddata/xcschemes/Example.xcscheme new file mode 100644 index 000000000..81cf8a708 --- /dev/null +++ b/native/swift/Example/Example.xcodeproj/xcshareddata/xcschemes/Example.xcscheme @@ -0,0 +1,78 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/native/swift/Sources/wordpress-api/Middleware.swift b/native/swift/Sources/wordpress-api/Middleware.swift index 1e4524e23..6f4c0c7eb 100644 --- a/native/swift/Sources/wordpress-api/Middleware.swift +++ b/native/swift/Sources/wordpress-api/Middleware.swift @@ -4,7 +4,8 @@ public final class DebugMiddleware: WpApiMiddleware { public func process( requestExecutor: any WordPressAPIInternal.RequestExecutor, response: WordPressAPIInternal.WpNetworkResponse, - request: WordPressAPIInternal.WpNetworkRequest + request: WordPressAPIInternal.WpNetworkRequest, + context: RequestContext? ) async throws -> WordPressAPIInternal.WpNetworkResponse { debugPrint("Performed request: \(String(describing: try? request.buildURLRequest(additionalHeaders: [:])))") debugPrint("Received response: \(response)") diff --git a/native/swift/Sources/wordpress-api/SafeRequestExecutor.swift b/native/swift/Sources/wordpress-api/SafeRequestExecutor.swift index ee03b9a2f..0ab39589d 100644 --- a/native/swift/Sources/wordpress-api/SafeRequestExecutor.swift +++ b/native/swift/Sources/wordpress-api/SafeRequestExecutor.swift @@ -75,6 +75,14 @@ public final class WpRequestExecutor: SafeRequestExecutor { } } + public func cancel(context: RequestContext) { + for requestId in context.requestIds() { + Task { + await self.cancelRequest(withId: requestId) + } + } + } + func perform(_ request: NetworkRequestContent) async -> Result { do { let (data, response) = try await request.perform( @@ -131,6 +139,26 @@ public final class WpRequestExecutor: SafeRequestExecutor { } #endif + private func cancelRequest(withId requestId: String) async { +#if canImport(Combine) + var task = (await self.session.allTasks).first { + $0.originalRequest?.requestId == requestId + } + + if task == nil { + task = await NotificationCenter.default + .publisher(for: RequestExecutorDelegate.didCreateTaskNotification) + .compactMap { $0.object as? URLSessionTask } + .first { $0.originalRequest?.requestId == requestId } + .timeout(.seconds(1), scheduler: DispatchQueue.global()) + .values + .first { _ in true } + } + + task?.cancel() +#endif + } + private func handleHttpsError( _ error: Error, for request: NetworkRequestContent diff --git a/native/swift/Sources/wordpress-api/WordPressAPI.swift b/native/swift/Sources/wordpress-api/WordPressAPI.swift index e47477383..47294eea4 100644 --- a/native/swift/Sources/wordpress-api/WordPressAPI.swift +++ b/native/swift/Sources/wordpress-api/WordPressAPI.swift @@ -20,7 +20,7 @@ public actor WordPressAPI { } private let apiUrlResolver: ApiUrlResolver - private let requestExecutor: SafeRequestExecutor + let requestExecutor: SafeRequestExecutor private let apiClientDelegate: WpApiClientDelegate package let requestBuilder: UniffiWpApiClient diff --git a/native/swift/Tests/integration-tests/CancellationTests.swift b/native/swift/Tests/integration-tests/CancellationTests.swift new file mode 100644 index 000000000..71a540e01 --- /dev/null +++ b/native/swift/Tests/integration-tests/CancellationTests.swift @@ -0,0 +1,37 @@ +import Foundation +import Testing + +@testable import WordPressAPI +@testable import WordPressAPIInternal + +#if os(macOS) + +struct CancellationTests { + let api = WordPressAPI.admin() + + @Test + func cancelUploadingLongPost() async throws { + let file = try #require(Bundle.module.url(forResource: "test-data/test_media.jpg", withExtension: nil)) + let content = try String(data: Data(contentsOf: file).base64EncodedData(), encoding: .utf8)! + + let title = UUID().uuidString + await #expect( + throws: WpApiError.RequestExecutionFailed(statusCode: nil, redirects: nil, reason: .cancellationError), + performing: { + let task = Task { + _ = try await api.posts.create(params: .init(title: title, content: content, meta: nil)) + Issue.record("The creating post function should throw") + } + + try await Task.sleep(for: .milliseconds(10)) + task.cancel() + + try await task.value + } + ) + + try await restoreTestServer() + } +} + +#endif diff --git a/native/swift/Tests/integration-tests/MediaTests.swift b/native/swift/Tests/integration-tests/MediaTests.swift index a21de5a04..86e12a61c 100644 --- a/native/swift/Tests/integration-tests/MediaTests.swift +++ b/native/swift/Tests/integration-tests/MediaTests.swift @@ -1,5 +1,5 @@ import Foundation -import WordPressAPI +@testable import WordPressAPI import Testing @Suite @@ -55,6 +55,7 @@ struct MediaTests { fromLocalFileURL: file, fulfilling: progress ) + Issue.record("The creating post function should throw") } let cancellable = progress.publisher(for: \.fractionCompleted).first { $0 > 0 }.sink { _ in @@ -83,6 +84,7 @@ struct MediaTests { fromLocalFileURL: file, fulfilling: progress ) + Issue.record("The creating post function should throw") } let cancellable = progress.publisher(for: \.fractionCompleted).first { $0 > 0 }.sink { _ in diff --git a/native/swift/Tests/wordpress-api/Support/HTTPStubs.swift b/native/swift/Tests/wordpress-api/Support/HTTPStubs.swift index bac4e3697..33d0f5e98 100644 --- a/native/swift/Tests/wordpress-api/Support/HTTPStubs.swift +++ b/native/swift/Tests/wordpress-api/Support/HTTPStubs.swift @@ -25,7 +25,9 @@ final class HTTPStubs: SafeRequestExecutor { self } - public func execute(_ request: WpNetworkRequest) async -> Result { + public func execute( + _ request: WpNetworkRequest + ) async -> Result { if let response = stub(for: request) { return .success(response) } @@ -88,6 +90,10 @@ final class HTTPStubs: SafeRequestExecutor { // swiftlint:disable:next force_try try! await Task.sleep(nanoseconds: millis * 1000) } + + func cancel(context: RequestContext) { + // No-op + } } extension WpNetworkResponse { diff --git a/native/swift/Tests/wordpress-api/WordPressAPITests.swift b/native/swift/Tests/wordpress-api/WordPressAPITests.swift index 3143ad936..28f3f6db8 100644 --- a/native/swift/Tests/wordpress-api/WordPressAPITests.swift +++ b/native/swift/Tests/wordpress-api/WordPressAPITests.swift @@ -79,7 +79,8 @@ private actor CounterMiddleware: Middleware { func process( requestExecutor: RequestExecutor, response: WpNetworkResponse, - request: WpNetworkRequest + request: WpNetworkRequest, + context: RequestContext? ) async throws -> WpNetworkResponse { count += 1 return response diff --git a/native/swift/Tools/.gitignore b/native/swift/Tools/.gitignore new file mode 100644 index 000000000..0023a5340 --- /dev/null +++ b/native/swift/Tools/.gitignore @@ -0,0 +1,8 @@ +.DS_Store +/.build +/Packages +xcuserdata/ +DerivedData/ +.swiftpm/configuration/registries.json +.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata +.netrc diff --git a/native/swift/Tools/Package.resolved b/native/swift/Tools/Package.resolved new file mode 100644 index 000000000..52f8e4db7 --- /dev/null +++ b/native/swift/Tools/Package.resolved @@ -0,0 +1,33 @@ +{ + "originHash" : "8dd1ad804e0fca931c3e5e129e82648c4c38fe782cdd7db4b112f0f922bd5225", + "pins" : [ + { + "identity" : "swift-argument-parser", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-argument-parser.git", + "state" : { + "revision" : "309a47b2b1d9b5e991f36961c983ecec72275be3", + "version" : "1.6.1" + } + }, + { + "identity" : "swift-log", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-log.git", + "state" : { + "revision" : "ce592ae52f982c847a4efc0dd881cc9eb32d29f2", + "version" : "1.6.4" + } + }, + { + "identity" : "swift-syntax", + "kind" : "remoteSourceControl", + "location" : "https://github.com/swiftlang/swift-syntax.git", + "state" : { + "revision" : "0687f71944021d616d34d922343dcef086855920", + "version" : "600.0.1" + } + } + ], + "version" : 3 +} diff --git a/native/swift/Tools/Package.swift b/native/swift/Tools/Package.swift new file mode 100644 index 000000000..b6c283095 --- /dev/null +++ b/native/swift/Tools/Package.swift @@ -0,0 +1,30 @@ +// swift-tools-version: 6.1 + +import PackageDescription + +let package = Package( + name: "Tools", + platforms: [ + .macOS(.v12) + ], + products: [ + .executable(name: "generate-cancellable", targets: ["GenerateCancellable"]) + ], + dependencies: [ + .package(url: "https://github.com/swiftlang/swift-syntax.git", from: "600.0.0"), + .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.0.0"), + .package(url: "https://github.com/apple/swift-log.git", from: "1.0.0") + ], + targets: [ + .executableTarget( + name: "GenerateCancellable", + dependencies: [ + .product(name: "SwiftSyntax", package: "swift-syntax"), + .product(name: "SwiftParser", package: "swift-syntax"), + .product(name: "SwiftSyntaxBuilder", package: "swift-syntax"), + .product(name: "ArgumentParser", package: "swift-argument-parser"), + .product(name: "Logging", package: "swift-log") + ] + ) + ] +) diff --git a/native/swift/Tools/Sources/GenerateCancellable/main.swift b/native/swift/Tools/Sources/GenerateCancellable/main.swift new file mode 100644 index 000000000..1a10dd006 --- /dev/null +++ b/native/swift/Tools/Sources/GenerateCancellable/main.swift @@ -0,0 +1,365 @@ +import ArgumentParser +import Foundation +import Logging +import SwiftSyntax +import SwiftParser +import SwiftSyntaxBuilder + +@main +struct GenerateCancellable: ParsableCommand { + static let configuration = CommandConfiguration( + commandName: "generate-cancellable", + abstract: "Generate cancellable APIs" + ) + + @Argument(help: "The Swift source file to analyze") + var inputFile: String + + @Argument(help: "Output file path for generated extensions") + var output: String + + @Option(name: .long, help: "Set log level (debug, info, notice, warning, error, critical)") + var logLevel: String = "info" + + mutating func run() throws { + let logger = setupLogger() + + let inputURL = URL(fileURLWithPath: inputFile) + let outputURL = URL(fileURLWithPath: output) + + logger.info("Analyzing file: \(inputURL.path)") + logger.info("Generating extensions to: \(outputURL.path)") + + guard FileManager.default.fileExists(atPath: inputURL.path) else { + throw ValidationError("Input file does not exist: \(inputURL.path)") + } + + let sourceCode = try String(contentsOf: inputURL, encoding: .utf8) + let extensionsCode = try generateExtensions(sourceCode: sourceCode, logger: logger) + + try extensionsCode.write(to: outputURL, atomically: true, encoding: .utf8) + + logger.info("Successfully generated extensions to: \(outputURL.path)") + } + + private func setupLogger() -> Logger { + let logLevelValue = Logger.Level(rawValue: logLevel.lowercased()) ?? .info + + LoggingSystem.bootstrap { label in + var handler = StreamLogHandler.standardOutput(label: label) + handler.logLevel = logLevelValue + return handler + } + + return Logger(label: "generate-cancellable") + } + + private func generateExtensions(sourceCode: String, logger: Logger) throws -> String { + let syntaxTree = Parser.parse(source: sourceCode) + + logger.debug("Parsed syntax tree successfully") + + let analyzer = CancellationAnalyzer(logger: logger) + let analysis = analyzer.analyze(syntaxTree) + + logger.info("Found \(analysis.count) RequestExecutor classes") + + let generator = ExtensionGenerator(analysis: analysis, logger: logger) + return generator.generateExtensionsCode() + } +} + +struct FunctionInfo { + let declaration: FunctionDeclSyntax + let name: String + let parameters: FunctionParameterListSyntax + let returnType: TypeSyntax? + let modifiers: DeclModifierListSyntax? + let attributes: AttributeListSyntax? + let asyncKeyword: TokenSyntax? + let throwsClause: ThrowsClauseSyntax? +} + +struct ClassAnalysis { + let className: String + let cancellationFunctions: [FunctionInfo] + let existingFunctions: [String: FunctionInfo] +} + +class CancellationAnalyzer: SyntaxVisitor { + private let logger: Logger + private var analysis: [String: ClassAnalysis] = [:] + private var currentClass: String? + private var currentCancellationFunctions: [FunctionInfo] = [] + private var currentExistingFunctions: [String: FunctionInfo] = [:] + + init(logger: Logger) { + self.logger = logger + super.init(viewMode: .sourceAccurate) + } + + func analyze(_ tree: SourceFileSyntax) -> [String: ClassAnalysis] { + walk(tree) + return analysis + } + + override func visit(_ node: ClassDeclSyntax) -> SyntaxVisitorContinueKind { + let className = node.name.text + + if className.hasSuffix("RequestExecutor") { + logger.debug("Analyzing RequestExecutor class: \(className)") + + currentClass = className + currentCancellationFunctions = [] + currentExistingFunctions = [:] + + return .visitChildren + } + + return .skipChildren + } + + override func visitPost(_ node: ClassDeclSyntax) { + if let className = currentClass { + analysis[className] = ClassAnalysis( + className: className, + cancellationFunctions: currentCancellationFunctions, + existingFunctions: currentExistingFunctions + ) + + logger.debug( + "Found \(currentCancellationFunctions.count) cancellation functions in \(className)" + ) + logger.debug( + "Found \(currentExistingFunctions.count) existing functions in \(className)" + ) + + currentClass = nil + } + } + + override func visit(_ node: FunctionDeclSyntax) -> SyntaxVisitorContinueKind { + guard currentClass != nil else { return .skipChildren } + + let functionName = node.name.text + + let functionInfo = FunctionInfo( + declaration: node, + name: functionName, + parameters: node.signature.parameterClause.parameters, + returnType: node.signature.returnClause?.type, + modifiers: node.modifiers, + attributes: node.attributes, + asyncKeyword: node.signature.effectSpecifiers?.asyncSpecifier, + throwsClause: node.signature.effectSpecifiers?.throwsClause + ) + + if functionName.hasSuffix("Cancellation") && hasCancellationTokenParameter(node) { + currentCancellationFunctions.append(functionInfo) + logger.trace("Found cancellation function: \(functionName)") + } else { + currentExistingFunctions[functionName] = functionInfo + } + + return .skipChildren + } + + private func hasCancellationTokenParameter(_ function: FunctionDeclSyntax) -> Bool { + let parameters = function.signature.parameterClause.parameters + guard let lastParam = parameters.last else { return false } + + let paramName = lastParam.firstName.text + let paramType = lastParam.type.description.trimmingCharacters(in: .whitespacesAndNewlines) + + return paramName == "context" && paramType == "RequestContext?" + } +} + +class ExtensionGenerator { + private let analysis: [String: ClassAnalysis] + private let logger: Logger + + init(analysis: [String: ClassAnalysis], logger: Logger) { + self.analysis = analysis + self.logger = logger + } + + func generateExtensionsCode() -> String { + var extensions: [ExtensionDeclSyntax] = [] + + for (className, classAnalysis) in analysis { + logger.debug("Generating extension for class: \(className)") + + let extensionMembers = generateExtensionMembers(for: classAnalysis) + + if !extensionMembers.isEmpty { + let extensionDecl = ExtensionDeclSyntax( + extensionKeyword: .keyword(.extension, trailingTrivia: .space), + extendedType: IdentifierTypeSyntax(name: .identifier(className)), + memberBlock: MemberBlockSyntax( + leftBrace: .leftBraceToken(leadingTrivia: .space, trailingTrivia: .newlines(2)), + members: MemberBlockItemListSyntax(extensionMembers), + rightBrace: .rightBraceToken(leadingTrivia: .newline) + ), + trailingTrivia: .newlines(2) + ) + extensions.append(extensionDecl) + logger.info("Generated extension for \(className) with \(extensionMembers.count) functions") + } + } + + let sourceFile = SourceFileSyntax( + statements: CodeBlockItemListSyntax( + extensions.map { ext in + CodeBlockItemSyntax(item: .decl(DeclSyntax(ext))) + } + ) + ) + + return """ + // Do not modify. This file is automatically generated. + // swiftlint:disable all + + import Foundation + + \(sourceFile.description) + """ + } + + private func generateExtensionMembers(for classAnalysis: ClassAnalysis) -> [MemberBlockItemSyntax] { + var members: [MemberBlockItemSyntax] = [] + + for cancellationFunc in classAnalysis.cancellationFunctions { + let funcName = String(cancellationFunc.name.dropLast("Cancellation".count)) + + if classAnalysis.existingFunctions[funcName] != nil { + logger.warning( + "Skipping generation of '\(funcName)' - function already exists in \(classAnalysis.className)" + ) + continue + } + + logger.debug("Generating new function: \(funcName)") + + let newFunction = createNonCancellationFunction(from: cancellationFunc, name: funcName) + let memberItem = MemberBlockItemSyntax( + decl: DeclSyntax(newFunction), + trailingTrivia: .newline + ) + members.append(memberItem) + } + + return members + } + + private func createNonCancellationFunction( + from cancellationFunc: FunctionInfo, + name: String + ) -> FunctionDeclSyntax { + var parametersWithoutCancellation: FunctionParameterListSyntax + + if cancellationFunc.parameters.isEmpty || cancellationFunc.parameters.count == 1 { + parametersWithoutCancellation = FunctionParameterListSyntax([]) + } else { + var paramsArray = Array(cancellationFunc.parameters.dropLast()) + + if !paramsArray.isEmpty { + let lastIndex = paramsArray.count - 1 + paramsArray[lastIndex] = paramsArray[lastIndex].with(\.trailingComma, nil) + } + + parametersWithoutCancellation = FunctionParameterListSyntax(paramsArray) + } + + let effectSpecifiers = FunctionEffectSpecifiersSyntax( + asyncSpecifier: cancellationFunc.asyncKeyword?.with(\.leadingTrivia, .space).with(\.trailingTrivia, .space), + throwsClause: cancellationFunc.throwsClause?.with(\.leadingTrivia, .init()).with(\.trailingTrivia, .init()) + ) + + let returnClause = cancellationFunc.returnType.map { type in + let cleanType = type.with(\.leadingTrivia, .init()).with(\.trailingTrivia, .init()) + return ReturnClauseSyntax( + arrow: .arrowToken(leadingTrivia: .space, trailingTrivia: .space), + type: cleanType + ) + } + + let signature = FunctionSignatureSyntax( + parameterClause: FunctionParameterClauseSyntax( + leftParen: .leftParenToken(), + parameters: parametersWithoutCancellation, + rightParen: .rightParenToken() + ), + effectSpecifiers: effectSpecifiers, + returnClause: returnClause + ) + + let functionCallBody = createFunctionCallBody( + cancellationFunctionName: cancellationFunc.name, + parameters: parametersWithoutCancellation, + hasReturnValue: cancellationFunc.returnType != nil + ) + + let cleanModifiers: DeclModifierListSyntax + if let originalModifiers = cancellationFunc.modifiers, !originalModifiers.isEmpty { + var modifiersArray: [DeclModifierSyntax] = [] + + for (index, modifier) in originalModifiers.enumerated() { + var cleanModifier = modifier + + if modifier.name.text == "open" { + cleanModifier = cleanModifier.with(\.name, .keyword(.public)) + } + + cleanModifier = cleanModifier.with(\.leadingTrivia, index == 0 ? .spaces(4) : .space) + .with(\.trailingTrivia, .space) + + modifiersArray.append(cleanModifier) + } + + cleanModifiers = DeclModifierListSyntax(modifiersArray) + } else { + cleanModifiers = DeclModifierListSyntax([]) + } + + let funcKeywordLeadingTrivia: Trivia = cleanModifiers.isEmpty ? .spaces(4) : .init() + + return FunctionDeclSyntax( + attributes: cancellationFunc.attributes ?? AttributeListSyntax([]), + modifiers: cleanModifiers, + funcKeyword: .keyword(.func, leadingTrivia: funcKeywordLeadingTrivia, trailingTrivia: .space), + name: .identifier(name), + signature: signature, + body: functionCallBody + ) + } + + private func createFunctionCallBody( + cancellationFunctionName: String, + parameters: FunctionParameterListSyntax, + hasReturnValue: Bool + ) -> CodeBlockSyntax { + // Build parameter arguments for the function call + let parameterArguments = parameters.map { parameter in + let paramName = parameter.firstName.text + return "\(paramName): \(paramName)" + }.joined(separator: ", ") + + let functionCallArgs = parameterArguments.isEmpty ? + "context: context" : + "\(parameterArguments), context: context" + + return CodeBlockSyntax { + DeclSyntax( + """ + let context = RequestContext() + return try await withTaskCancellationHandler { + try await \(raw: cancellationFunctionName)(\(raw: functionCallArgs)) + } onCancel: { + self.cancel(context: context) + } + """ + ) + } + } +} diff --git a/scripts/swift-bindings.sh b/scripts/swift-bindings.sh index 79bfd9442..df9f91012 100755 --- a/scripts/swift-bindings.sh +++ b/scripts/swift-bindings.sh @@ -7,6 +7,8 @@ if [ $# -ne 1 ]; then exit 1 fi +echo "Generating Swift bindings... (This may take a while when running for the first time.)" + module_name="libwordpressFFI" library_path=$1 output_dir="$(dirname "$library_path")/swift-bindings" @@ -49,6 +51,15 @@ done rm -f native/swift/Sources/wordpress-api-wrapper/*.swift mv "$output_dir"/*.swift native/swift/Sources/wordpress-api-wrapper/ +for swift_file in native/swift/Sources/wordpress-api-wrapper/*.swift; do + basename=$(basename "$swift_file" .swift) + output_file="native/swift/Sources/wordpress-api-wrapper/${basename}_cancellable.swift" + + swift run -c release --quiet \ + --package-path native/swift/Tools \ + generate-cancellable --log-level warning "$swift_file" "$output_file" +done + header_dir="$output_dir/Headers" mkdir -p "$header_dir" diff --git a/wp_api/Cargo.toml b/wp_api/Cargo.toml index 8b67c2536..314ff5c81 100644 --- a/wp_api/Cargo.toml +++ b/wp_api/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2024" [features] +export-uncancellable-endpoints = ["wp_derive_request_builder/export-uncancellable-endpoints"] integration-tests = [] reqwest-request-executor = ["dep:reqwest", "dep:tokio", "dep:hyper-util", "dep:rustls", "dep:hickory-resolver", "dep:hyper", "dep:h2"] diff --git a/wp_api/src/login/login_client.rs b/wp_api/src/login/login_client.rs index c95f66999..97295f2a9 100644 --- a/wp_api/src/login/login_client.rs +++ b/wp_api/src/login/login_client.rs @@ -299,6 +299,7 @@ impl WpLoginClient { body: None, } .into(), + None, ) .await } @@ -326,6 +327,7 @@ impl WpLoginClient { body: None, } .into(), + None, ) .await } @@ -411,6 +413,7 @@ impl WpLoginClient { body: Some(Arc::new(WpNetworkRequestBody::new(r#"system.listMethods"#.as_bytes().to_vec()))), } .into(), + None, ) .await // It's very likely xml-rpc is blocked by the hosting provider (the request has not reached to WordPress), @@ -449,6 +452,7 @@ impl WpLoginClient { body: None, } .into(), + None, ) .await .map_err(|error| XmlrpcDiscoveryError::FetchHomepage { error })?; @@ -467,6 +471,7 @@ impl WpLoginClient { body: None, } .into(), + None, ) .await .map_err(|_| XmlrpcDiscoveryError::Disabled { diff --git a/wp_api/src/middleware.rs b/wp_api/src/middleware.rs index 82080bd60..e3f905d7a 100644 --- a/wp_api/src/middleware.rs +++ b/wp_api/src/middleware.rs @@ -1,6 +1,7 @@ use crate::{ api_client::IsWpApiClientDelegate, api_error::{RequestExecutionError, RequestExecutionErrorReason}, + request::RequestContext, request::{RequestExecutor, WpNetworkRequest, WpNetworkResponse}, }; use std::{fmt::Debug, sync::Arc, time::Duration}; @@ -22,12 +23,18 @@ impl WpApiMiddlewarePipeline { request_executor: Arc, response: WpNetworkResponse, request: Arc, + context: Option>, ) -> Result { let mut response = response; for middleware in &self.middlewares { response = middleware - .process(request_executor.clone(), response, request.clone()) + .process( + request_executor.clone(), + response, + request.clone(), + context.clone(), + ) .await?; } @@ -71,6 +78,7 @@ pub trait WpApiMiddleware: Send + Sync + Debug { request_executor: Arc, response: WpNetworkResponse, request: Arc, + context: Option>, ) -> Result; } @@ -84,7 +92,12 @@ pub trait PerformsRequests { async fn perform( &self, request: Arc, + context: Option>, ) -> Result { + if let Some(context) = &context { + context.add_request_id(request.uuid.clone()); + } + let pipeline = &self.get_middleware_pipeline(); let response = self.get_request_executor().execute(request.clone()).await?; @@ -93,6 +106,7 @@ pub trait PerformsRequests { self.get_request_executor().clone(), response, request.clone(), + context.clone(), ) .await?; @@ -153,6 +167,7 @@ impl WpApiMiddleware for RetryAfterMiddleware { request_executor: Arc, response: WpNetworkResponse, request: Arc, + context: Option>, ) -> Result { let mut response = response; @@ -176,8 +191,12 @@ impl WpApiMiddleware for RetryAfterMiddleware { ) .await; let new_request = Arc::new(request.clone_with_incremented_retry_count()); + if let Some(context) = &context { + context.add_request_id(new_request.uuid.clone()); + } response = request_executor.execute(new_request.clone()).await?; - self.process(request_executor, response, new_request).await + self.process(request_executor, response, new_request, context) + .await } else { // We have no idea how long to wait so we shouldn't try Ok(response) @@ -210,6 +229,7 @@ impl WpApiMiddleware for ApiDiscoveryAuthenticationMiddleware { request_executor: Arc, response: WpNetworkResponse, request: Arc, + context: Option>, ) -> Result { if response.request_header_map.has_http_authentication() { // Request was already authenticated @@ -220,6 +240,10 @@ impl WpApiMiddleware for ApiDiscoveryAuthenticationMiddleware { return Ok(response); } + if let Some(context) = &context { + context.add_request_id(request.uuid.clone()); + } + request_executor .execute( request @@ -276,6 +300,8 @@ mod tests { } async fn sleep(&self, _: u64) {} + + fn cancel(&self, _: Arc) {} } #[tokio::test] @@ -364,6 +390,7 @@ mod tests { request_header_map: Arc::new(map.into()), }, WpNetworkRequest::get(WpEndpointUrl("unused".to_string())).into(), + None, ) .await } @@ -424,6 +451,8 @@ mod tests { } async fn sleep(&self, _: u64) {} + + fn cancel(&self, _: Arc) {} } #[tokio::test] @@ -460,6 +489,7 @@ mod tests { Arc::new(foo_executor), rate_limit_exceeded_response(), WpNetworkRequest::get(WpEndpointUrl("unused".to_string())).into(), + None, ) .await } diff --git a/wp_api/src/request.rs b/wp_api/src/request.rs index 0a435db1f..dff87550a 100644 --- a/wp_api/src/request.rs +++ b/wp_api/src/request.rs @@ -24,7 +24,7 @@ use std::{ collections::HashMap, fmt::Debug, str::{FromStr, Utf8Error}, - sync::Arc, + sync::{Arc, Mutex}, }; use url::Url; use uuid::Uuid; @@ -151,6 +151,8 @@ pub trait RequestExecutor: Send + Sync { ) -> Result; async fn sleep(&self, millis: u64); + + fn cancel(&self, context: Arc); } #[derive(uniffi::Object)] @@ -820,6 +822,35 @@ impl AuthenticationState { } } +#[derive(Debug, Default, uniffi::Object)] +pub struct RequestContext { + request_ids: Mutex>, +} + +#[uniffi::export] +impl RequestContext { + #[uniffi::constructor] + pub fn new() -> Self { + Self { + request_ids: Mutex::new(Vec::new()), + } + } + + pub fn add_request_id(&self, request_id: String) { + if let Ok(mut ids) = self.request_ids.lock() { + ids.push(request_id); + } + } + + pub fn request_ids(&self) -> Vec { + if let Ok(ids) = self.request_ids.lock() { + return (*ids).clone(); + } + + vec![] + } +} + pub mod user_agent { #[uniffi::export] pub fn default_user_agent(client_specific_postfix: &str) -> String { diff --git a/wp_api/src/reqwest_request_executor.rs b/wp_api/src/reqwest_request_executor.rs index 00a23f4bb..29b49a242 100644 --- a/wp_api/src/reqwest_request_executor.rs +++ b/wp_api/src/reqwest_request_executor.rs @@ -3,6 +3,7 @@ use crate::{ InvalidSslErrorReason, MediaUploadRequestExecutionError, RequestExecutionError, RequestExecutionErrorReason, }, + request::RequestContext, request::{ NetworkRequestAccessor, RequestExecutor, RequestMethod, WpNetworkHeaderMap, WpNetworkRequest, WpNetworkResponse, endpoint::media_endpoint::MediaUploadRequest, @@ -167,6 +168,10 @@ impl RequestExecutor for ReqwestRequestExecutor { async fn sleep(&self, millis: u64) { tokio::time::sleep(std::time::Duration::from_millis(millis)).await; } + + fn cancel(&self, _context: Arc) { + // No-op for reqwest + } } impl From for RequestExecutionError { diff --git a/wp_api_integration_tests/src/mock.rs b/wp_api_integration_tests/src/mock.rs index 32456d6c6..212d969a0 100644 --- a/wp_api_integration_tests/src/mock.rs +++ b/wp_api_integration_tests/src/mock.rs @@ -1,6 +1,8 @@ use async_trait::async_trait; use std::sync::Arc; -use wp_api::{prelude::*, request::endpoint::media_endpoint::MediaUploadRequest}; +use wp_api::{ + prelude::*, request::RequestContext, request::endpoint::media_endpoint::MediaUploadRequest, +}; #[derive(Debug)] pub struct MockExecutor { @@ -39,6 +41,8 @@ impl RequestExecutor for MockExecutor { } async fn sleep(&self, _: u64) {} + + fn cancel(&self, _: Arc) {} } pub mod response_helpers { diff --git a/wp_api_integration_tests/tests/test_app_notifier_immut.rs b/wp_api_integration_tests/tests/test_app_notifier_immut.rs index c99a4244c..f55408904 100644 --- a/wp_api_integration_tests/tests/test_app_notifier_immut.rs +++ b/wp_api_integration_tests/tests/test_app_notifier_immut.rs @@ -2,7 +2,7 @@ use std::sync::{ Mutex, atomic::{AtomicBool, Ordering}, }; -use wp_api::users::UserListParams; +use wp_api::{request::RequestContext, users::UserListParams}; use wp_api_integration_tests::prelude::*; #[tokio::test] @@ -133,4 +133,6 @@ impl RequestExecutor for TrackedRequestExecutor { } async fn sleep(&self, _: u64) {} + + fn cancel(&self, _: Arc) {} } diff --git a/wp_api_integration_tests/tests/test_media_err.rs b/wp_api_integration_tests/tests/test_media_err.rs index 10d24c2ed..a51c5285d 100644 --- a/wp_api_integration_tests/tests/test_media_err.rs +++ b/wp_api_integration_tests/tests/test_media_err.rs @@ -3,6 +3,7 @@ use wp_api::{ media::{MediaCreateParams, MediaId, MediaListParams, MediaUpdateParams}, posts::WpApiParamPostsOrderBy, prelude::*, + request::RequestContext, request::endpoint::media_endpoint::MediaUploadRequest, users::UserId, }; @@ -259,4 +260,6 @@ impl RequestExecutor for MediaErrNetworking { async fn sleep(&self, millis: u64) { tokio::time::sleep(std::time::Duration::from_millis(millis)).await; } + + fn cancel(&self, _context: Arc) {} } diff --git a/wp_derive_request_builder/Cargo.toml b/wp_derive_request_builder/Cargo.toml index 564d333ad..3480f105a 100644 --- a/wp_derive_request_builder/Cargo.toml +++ b/wp_derive_request_builder/Cargo.toml @@ -6,6 +6,7 @@ autotests = false [features] generate_request_builder = [] +export-uncancellable-endpoints = [] [lib] proc-macro = true diff --git a/wp_derive_request_builder/src/generate.rs b/wp_derive_request_builder/src/generate.rs index 009fa0144..fdd0a0f70 100644 --- a/wp_derive_request_builder/src/generate.rs +++ b/wp_derive_request_builder/src/generate.rs @@ -39,17 +39,17 @@ fn generate_async_request_executor( let generated_request_builder_ident = &config.generated_idents.request_builder; let generated_request_executor_ident = &config.generated_idents.request_executor; - let functions = parsed_enum.variants.iter().map(|variant| { + let (exported_functions, unexported_functions) = parsed_enum.variants.iter().fold((vec![], vec![]), |(mut exported, mut unexported), variant| { let url_parts = variant.attr.url_parts.as_slice(); let params_type = &variant.attr.params; - ContextAndFilterHandler::from_request_type( + (exported, unexported) = ContextAndFilterHandler::from_request_type( variant.attr.request_type, variant.attr.filter_by.clone(), &variant.attr.available_contexts, ) .into_iter() - .map(|context_and_filter_handler| { + .fold((exported, unexported), |(mut exported, mut unexported), context_and_filter_handler| { let request_from_request_builder = fn_body_get_request_from_request_builder( &variant.variant_ident, url_parts, @@ -65,20 +65,28 @@ fn generate_async_request_executor( variant.attr.request_type, &context_and_filter_handler, ); + let fn_signature_cancellable = append_context_param(fn_signature.clone()); + let fn_signature_body = invoke_cancellation_variant(fn_signature.clone()); let response_type_ident = ident_response_type( &parsed_enum.enum_ident, &variant.variant_ident, &context_and_filter_handler, ); - quote! { + + let uncancellable = quote! { pub async #fn_signature -> Result<#response_type_ident, #error_type> { + #fn_signature_body + } + }; + let cancellable = quote! { + pub async #fn_signature_cancellable -> Result<#response_type_ident, #error_type> { use #crate_ident::api_error::MaybeWpError; use #crate_ident::middleware::PerformsRequests; use #crate_ident::request::NetworkRequestAccessor; let perform_request = async || { #request_from_request_builder let request_url: String = request.url().into(); - let response = self.perform(std::sync::Arc::new(request)).await?; + let response = self.perform(std::sync::Arc::new(request), context.clone()).await?; let response_status_code = response.status_code; let parsed_response = response.parse(); let unauthorized = parsed_response.is_unauthorized_error().unwrap_or_default() || (response_status_code == 401 && self.fetch_authentication_state().await.map(|auth_state| auth_state.is_unauthorized()).unwrap_or_default()); @@ -106,9 +114,20 @@ fn generate_async_request_executor( parsed_response } + }; + + if cfg!(feature = "export-uncancellable-endpoints") { + exported.push(cancellable); + unexported.push(uncancellable); + } else { + unexported.push(cancellable); + exported.push(uncancellable); } - }) - .collect::() + + (exported, unexported) + }); + + (exported, unexported) }); let generated_return_types = parsed_enum.variants.iter().map(|variant| { @@ -214,11 +233,19 @@ fn generate_async_request_executor( } #[uniffi::export] impl #generated_request_executor_ident { - #(#functions)* + #(#exported_functions)* pub async fn fetch_authentication_state(&self) -> Result<#crate_ident::request::AuthenticationState, #error_type> { #crate_ident::request::fetch_authentication_state(self.delegate.request_executor.clone(), self.api_url_resolver.clone(), self.delegate.auth_provider.clone()).await } + + pub fn cancel(&self, context: std::sync::Arc) { + self.delegate.request_executor.cancel(context); + } + } + + impl #generated_request_executor_ident { + #(#unexported_functions)* } } } diff --git a/wp_derive_request_builder/src/generate/helpers_to_generate_tokens.rs b/wp_derive_request_builder/src/generate/helpers_to_generate_tokens.rs index cce078704..e4d68e48b 100644 --- a/wp_derive_request_builder/src/generate/helpers_to_generate_tokens.rs +++ b/wp_derive_request_builder/src/generate/helpers_to_generate_tokens.rs @@ -1,7 +1,7 @@ use convert_case::{Case, Casing}; use proc_macro2::{TokenStream, TokenTree}; use quote::{format_ident, quote}; -use syn::Ident; +use syn::{FnArg, Ident, Pat, parse_quote}; use super::{ContextAndFilterHandler, PartOf, WpContext}; use crate::{ @@ -84,6 +84,46 @@ pub fn fn_signature( quote! { fn #fn_name(&self, #url_params #provided_param #fields_param) } } +pub fn append_context_param(input: TokenStream) -> TokenStream { + let mut signature: syn::Signature = syn::parse2(input).unwrap(); + + let original_name = signature.ident.to_string(); + signature.ident = format_ident!("{}_cancellation", original_name); + + let new_arg: FnArg = + parse_quote! { context: Option> }; + signature.inputs.push(new_arg); + + quote! { #signature } +} + +pub fn invoke_cancellation_variant(input: TokenStream) -> TokenStream { + let signature: syn::Signature = syn::parse2(input).unwrap(); + + let fn_name = &signature.ident; + let fn_name = format_ident!("{}_cancellation", fn_name); + + let param_names: Vec<_> = signature + .inputs + .iter() + .filter_map(|arg| { + match arg { + FnArg::Typed(pat_type) => match &*pat_type.pat { + Pat::Ident(ident) => Some(ident.ident.clone()), + _ => None, + }, + FnArg::Receiver(_) => None, // Skip 'self' parameter + } + }) + .collect(); + + if param_names.is_empty() { + quote! { self.#fn_name(None).await } + } else { + quote! { self.#fn_name(#(#param_names),*, None).await } + } +} + pub fn fn_url_params(url_parts: &[UrlPart]) -> TokenStream { let params = url_parts.iter().filter_map(|p| { if let UrlPart::Dynamic(p) = p { @@ -1243,4 +1283,32 @@ mod tests { tokens: quote! { crate::SparseUserField }, }) } + + #[rstest] + #[case( + quote! { fn list(&self, params: &UserListParams) }, + "fn list_cancellation (& self , params : & UserListParams , context : Option < std :: sync :: Arc < crate :: request :: RequestContext > >)" + )] + #[case( + quote! { fn list(&self) }, + "fn list_cancellation (& self , context : Option < std :: sync :: Arc < crate :: request :: RequestContext > >)" + )] + fn test_append_context_param(#[case] input: TokenStream, #[case] expected: &str) { + let result = append_context_param(input); + assert_eq!(result.to_string(), expected); + } + + #[rstest] + #[case( + quote! { fn list(&self, params: &UserListParams, fields: &[SparseUserField]) }, + "self . list_cancellation (params , fields , None) . await" + )] + #[case( + quote! { fn list(&self) }, + "self . list_cancellation (None) . await" + )] + fn test_invoke_cancellation_variant(#[case] input: TokenStream, #[case] expected: &str) { + let result = invoke_cancellation_variant(input); + assert_eq!(result.to_string(), expected); + } }