diff --git a/Package.resolved b/Package.resolved new file mode 100644 index 0000000..8b5ef50 --- /dev/null +++ b/Package.resolved @@ -0,0 +1,16 @@ +{ + "object": { + "pins": [ + { + "package": "SnapshotTesting", + "repositoryURL": "https://github.com/pointfreeco/swift-snapshot-testing.git", + "state": { + "branch": null, + "revision": "f8a9c997c3c1dab4e216a8ec9014e23144cbab37", + "version": "1.9.0" + } + } + ] + }, + "version": 1 +} diff --git a/Package.swift b/Package.swift index 67ce836..3bf4ba2 100644 --- a/Package.swift +++ b/Package.swift @@ -10,11 +10,13 @@ let package = Package( .library( name: "PostgREST", targets: ["PostgREST"] - ), + ) ], dependencies: [ // Dependencies declare other packages that this package depends on. - // .package(url: /* package url */, from: "1.0.0"), + .package( + name: "SnapshotTesting", + url: "https://github.com/pointfreeco/swift-snapshot-testing.git", from: "1.8.1") ], targets: [ // Targets are the basic building blocks of a package. A target can define a module or a test suite. @@ -25,7 +27,7 @@ let package = Package( ), .testTarget( name: "PostgRESTTests", - dependencies: ["PostgREST"] - ), + dependencies: ["PostgREST", "SnapshotTesting"] + ) ] ) diff --git a/Sources/PostgREST/PostgrestBuilder.swift b/Sources/PostgREST/PostgrestBuilder.swift index 45fba59..1e6b733 100644 --- a/Sources/PostgREST/PostgrestBuilder.swift +++ b/Sources/PostgREST/PostgrestBuilder.swift @@ -2,19 +2,18 @@ import Foundation public class PostgrestBuilder { var url: String + var queryParams: [(name: String, value: String)] var headers: [String: String] var schema: String? var method: String? var body: [String: Any]? - public init(url: String, headers: [String: String] = [:], schema: String?) { - self.url = url - self.headers = headers - self.schema = schema - } - - public init(url: String, method: String?, headers: [String: String] = [:], schema: String?, body: [String: Any]?) { + init( + url: String, queryParams: [(name: String, value: String)], headers: [String: String], + schema: String?, method: String?, body: [String: Any]? + ) { self.url = url + self.queryParams = queryParams self.headers = headers self.schema = schema self.method = method @@ -22,6 +21,79 @@ public class PostgrestBuilder { } public func execute(head: Bool = false, count: CountOption? = nil, completion: @escaping (Result) -> Void) { + let request: URLRequest + do { + request = try buildURLRequest(head: head, count: count) + } catch { + completion(.failure(error)) + return + } + + let session = URLSession.shared + let dataTask = session.dataTask(with: request, completionHandler: { [unowned self] (data, response, error) -> Void in + if let error = error { + completion(.failure(error)) + return + } + + guard let response = response as? HTTPURLResponse else { + completion(.failure(PostgrestError(message: "failed to get response"))) + return + } + + guard let data = data else { + completion(.failure(PostgrestError(message: "empty data"))) + return + } + + do { + try validate(data: data, response: response) + let response = try parse(data: data, response: response) + completion(.success(response)) + } catch { + completion(.failure(error)) + } + }) + + dataTask.resume() + } + + private func validate(data: Data, response: HTTPURLResponse) throws { + if 200 ..< 300 ~= response.statusCode { + return + } + + guard let json = try JSONSerialization.jsonObject(with: data, options: []) as? [String: Any] else { + throw PostgrestError(message: "failed to get error") + } + + throw PostgrestError(from: json) ?? PostgrestError(message: "failed to get error") + } + + private func parse(data: Data, response: HTTPURLResponse) throws -> PostgrestResponse { + var body: Any = data + var count: Int? + + if method == "HEAD" { + if let accept = response.allHeaderFields["Accept"] as? String, accept == "text/csv" { + body = data + } else { + try JSONSerialization.jsonObject(with: data, options: []) + } + } + + if let contentRange = response.allHeaderFields["content-range"] as? String, + let lastElement = contentRange.split(separator: "/").last { + count = lastElement == "*" ? nil : Int(lastElement) + } + + let postgrestResponse = PostgrestResponse(body: body) + postgrestResponse.status = response.statusCode + postgrestResponse.count = count + return postgrestResponse + } + + func buildURLRequest(head: Bool, count: CountOption?) throws -> URLRequest { if head { method = "HEAD" } @@ -34,100 +106,42 @@ public class PostgrestBuilder { } } - if method == nil { - completion(.failure(PostgrestError(message: "Missing table operation: select, insert, update or delete"))) - return + guard let method = method else { + throw PostgrestError(message: "Missing table operation: select, insert, update or delete") } - if let method = method, method == "GET" || method == "HEAD" { + if method == "GET" || method == "HEAD" { headers["Content-Type"] = "application/json" } if let schema = schema { - if let method = method, method == "GET" || method == "HEAD" { + if method == "GET" || method == "HEAD" { headers["Accept-Profile"] = schema } else { headers["Content-Profile"] = schema } } - guard let url = URL(string: url) else { - completion(.failure(PostgrestError(message: "badURL"))) - return + guard var components = URLComponents(string: url) else { + throw PostgrestError(message: "badURL") + } + + if !queryParams.isEmpty { + components.queryItems = components.queryItems ?? [] + components.queryItems!.append(contentsOf: queryParams.map(URLQueryItem.init)) + } + + guard let url = components.url else { + throw PostgrestError(message: "badURL") } var request = URLRequest(url: url) request.httpMethod = method request.allHTTPHeaderFields = headers - - let session = URLSession.shared - let dataTask = session.dataTask(with: request, completionHandler: { [unowned self] (data, response, error) -> Void in - if let error = error { - completion(.failure(error)) - return - } - - if let resp = response as? HTTPURLResponse { - if let data = data { - do { - completion(.success(try self.parse(data: data, response: resp))) - } catch { - completion(.failure(error)) - return - } - } - } else { - completion(.failure(PostgrestError(message: "failed to get response"))) - } - - }) - - dataTask.resume() - } - - private func parse(data: Data, response: HTTPURLResponse) throws -> PostgrestResponse { - if response.statusCode == 200 || 200 ..< 300 ~= response.statusCode { - var body: Any = data - var count: Int? - - if let method = method, method == "HEAD" { - if let accept = response.allHeaderFields["Accept"] as? String, accept == "text/csv" { - body = data - } else { - do { - let json = try JSONSerialization.jsonObject(with: data, options: []) - body = json - } catch { - throw error - } - } - } - - if let contentRange = response.allHeaderFields["content-range"] as? String, let lastElement = contentRange.split(separator: "/").last { - count = lastElement == "*" ? nil : Int(lastElement) - } - - let postgrestResponse = PostgrestResponse(body: body) - postgrestResponse.status = response.statusCode - postgrestResponse.count = count - return postgrestResponse - } else { - do { - let json = try JSONSerialization.jsonObject(with: data, options: []) - if let errorJson: [String: Any] = json as? [String: Any] { - throw PostgrestError(from: errorJson) ?? PostgrestError(message: "failed to get error") - } else { - throw PostgrestError(message: "failed to get error") - } - } catch { - throw error - } - } + return request } func appendSearchParams(name: String, value: String) { - var urlComponent = URLComponents(string: url) - urlComponent?.queryItems?.append(URLQueryItem(name: name, value: value)) - url = urlComponent?.url?.absoluteString ?? url + queryParams.append((name, value)) } } diff --git a/Sources/PostgREST/PostgrestClient.swift b/Sources/PostgREST/PostgrestClient.swift index 97bf440..0d7c94f 100644 --- a/Sources/PostgREST/PostgrestClient.swift +++ b/Sources/PostgREST/PostgrestClient.swift @@ -1,5 +1,3 @@ - - public class PostgrestClient { var url: String var headers: [String: String] @@ -12,10 +10,10 @@ public class PostgrestClient { } public func form(_ table: String) -> PostgrestQueryBuilder { - return PostgrestQueryBuilder(url: "\(url)/\(table)", headers: headers, schema: schema) + return PostgrestQueryBuilder(url: "\(url)/\(table)", queryParams: [], headers: headers, schema: schema, method: nil, body: nil) } public func rpc(fn: String, parameters: [String: Any]?) -> PostgrestTransformBuilder { - return PostgrestRpcBuilder(url: "\(url)/rpc/\(fn)", headers: headers, schema: schema).rpc(parameters: parameters) + return PostgrestRpcBuilder(url: "\(url)/rpc/\(fn)", queryParams: [], headers: headers, schema: schema, method: nil, body: nil).rpc(parameters: parameters) } } diff --git a/Sources/PostgREST/PostgrestFilterBuilder.swift b/Sources/PostgREST/PostgrestFilterBuilder.swift index 830787d..508772e 100644 --- a/Sources/PostgREST/PostgrestFilterBuilder.swift +++ b/Sources/PostgREST/PostgrestFilterBuilder.swift @@ -7,62 +7,62 @@ public class PostgrestFilterBuilder: PostgrestTransformBuilder { public func not(column: String, operator op: Operator, value: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "not.\(op.rawValue).\(value)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func or(filters: String) -> PostgrestFilterBuilder { appendSearchParams(name: "or", value: "(\(filters))") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func eq(column: String, value: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "eq.\(value)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func neq(column: String, value: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "neq.\(value)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func gt(column: String, value: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "gt.\(value)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func gte(column: String, value: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "gte.\(value)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func lt(column: String, value: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "lt.\(value)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func lte(column: String, value: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "lte.\(value)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func like(column: String, value: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "like.\(value)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func ilike(column: String, value: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "ilike.\(value)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func `is`(column: String, value: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "is.\(value)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func `in`(column: String, value: [String]) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "in.\(value.joined(separator: ","))") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func contains(column: String, value: Any) -> PostgrestFilterBuilder { @@ -73,32 +73,32 @@ public class PostgrestFilterBuilder: PostgrestTransformBuilder { } else if let data: Data = try? JSONSerialization.data(withJSONObject: value, options: []), let json = String(data: data, encoding: .utf8) { appendSearchParams(name: column, value: "cs.\(json)") } - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func rangeLt(column: String, range: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "sl.\(range)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func rangeGt(column: String, range: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "sr.\(range)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func rangeGte(column: String, range: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "nxl.\(range)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func rangeLte(column: String, range: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "nxr.\(range)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func rangeAdjacent(column: String, range: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "adj.\(range)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func overlaps(column: String, value: Any) -> PostgrestFilterBuilder { @@ -107,48 +107,51 @@ public class PostgrestFilterBuilder: PostgrestTransformBuilder { } else if let arr: [String] = value as? [String] { appendSearchParams(name: column, value: "ov.\(arr.joined(separator: ","))") } - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func textSearch(column: String, range: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "adj.\(range)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } - public func textSearch(column: String, query: String, config: String? = nil, type: TextSearchType? = nil) -> PostgrestFilterBuilder { - appendSearchParams(name: column, value: "\(type?.rawValue ?? "")fts\(config ?? "").\(query)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + public func textSearch( + column: String, query: String, config: String? = nil, type: TextSearchType? = nil + ) -> PostgrestFilterBuilder { + appendSearchParams( + name: column, value: "\(type?.rawValue ?? "")fts\(config ?? "").\(query)") + return self } public func fts(column: String, query: String, config: String? = nil) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "fts\(config ?? "").\(query)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func plfts(column: String, query: String, config: String? = nil) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "plfts\(config ?? "").\(query)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func phfts(column: String, query: String, config: String? = nil) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "phfts\(config ?? "").\(query)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func wfts(column: String, query: String, config: String? = nil) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "wfts\(config ?? "").\(query)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func filter(column: String, operator: String, value: String) -> PostgrestFilterBuilder { appendSearchParams(name: column, value: "\(`operator`).\(value)") - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } public func match(query: [String: String]) -> PostgrestFilterBuilder { query.forEach { key, value in appendSearchParams(name: key, value: "eq.\(value)") } - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return self } } diff --git a/Sources/PostgREST/PostgrestQueryBuilder.swift b/Sources/PostgREST/PostgrestQueryBuilder.swift index 891af8f..c688a54 100644 --- a/Sources/PostgREST/PostgrestQueryBuilder.swift +++ b/Sources/PostgREST/PostgrestQueryBuilder.swift @@ -1,9 +1,4 @@ - public class PostgrestQueryBuilder: PostgrestBuilder { - override public init(url: String, method: String? = nil, headers: [String: String] = [:], schema: String? = nil, body: [String: Any]? = nil) { - super.init(url: url, method: method, headers: headers, schema: schema, body: body) - } - public func select(columns: String = "*") -> PostgrestFilterBuilder { method = "GET" var quoted = false @@ -17,7 +12,9 @@ public class PostgrestQueryBuilder: PostgrestBuilder { return String(char) }.reduce("", +) appendSearchParams(name: "select", value: cleanedColumns) - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return PostgrestFilterBuilder( + url: url, queryParams: queryParams, headers: headers, schema: schema, method: method, + body: body) } public func insert(values: [String: Any], upsert: Bool = false, onConflict: String? = nil) -> PostgrestBuilder { @@ -46,12 +43,16 @@ public class PostgrestQueryBuilder: PostgrestBuilder { method = "PATCH" headers["Prefer"] = "return=representation" body = values - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return PostgrestFilterBuilder( + url: url, queryParams: queryParams, headers: headers, schema: schema, method: method, + body: body) } public func delete() -> PostgrestFilterBuilder { method = "DELETE" headers["Prefer"] = "return=representation" - return PostgrestFilterBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return PostgrestFilterBuilder( + url: url, queryParams: queryParams, headers: headers, schema: schema, method: method, + body: body) } } diff --git a/Sources/PostgREST/PostgrestRpcBuilder.swift b/Sources/PostgREST/PostgrestRpcBuilder.swift index 430a897..a7a409e 100644 --- a/Sources/PostgREST/PostgrestRpcBuilder.swift +++ b/Sources/PostgREST/PostgrestRpcBuilder.swift @@ -1,8 +1,9 @@ - public class PostgrestRpcBuilder: PostgrestBuilder { public func rpc(parameters: [String: Any]?) -> PostgrestTransformBuilder { method = "POST" body = parameters - return PostgrestTransformBuilder(url: url, method: method, headers: headers, schema: schema, body: body) + return PostgrestTransformBuilder( + url: url, queryParams: queryParams, headers: headers, schema: schema, method: schema, + body: body) } } diff --git a/Sources/PostgREST/PostgrestTransformBuilder.swift b/Sources/PostgREST/PostgrestTransformBuilder.swift index f576eb4..5e04bee 100644 --- a/Sources/PostgREST/PostgrestTransformBuilder.swift +++ b/Sources/PostgREST/PostgrestTransformBuilder.swift @@ -1,2 +1 @@ - public class PostgrestTransformBuilder: PostgrestBuilder {} diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index 351803f..d66b2e5 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -1,6 +1,5 @@ -import XCTest - import PostgRESTTests +import XCTest var tests = [XCTestCaseEntry]() tests += PostgRESTTests.allTests() diff --git a/Tests/PostgRESTTests/BuildURLRequestTests.swift b/Tests/PostgRESTTests/BuildURLRequestTests.swift new file mode 100644 index 0000000..8c3d801 --- /dev/null +++ b/Tests/PostgRESTTests/BuildURLRequestTests.swift @@ -0,0 +1,34 @@ +import Foundation +import SnapshotTesting +import XCTest + +@testable import PostgREST + +final class BuildURLRequestTests: XCTestCase { + let url = "https://example.supabase.co" + + struct TestCase { + let name: String + var record = false + let build: (PostgrestClient) throws -> URLRequest + } + + func testBuildURLRequest() throws { + let client = PostgrestClient(url: url, schema: nil) + + let testCases: [TestCase] = [ + TestCase(name: "select all users where email ends with '@supabase.co'") { client in + try client.form("users") + .select() + .like(column: "email", value: "%@supabase.co") + .buildURLRequest(head: false, count: nil) + } + ] + + for testCase in testCases { + let request = try testCase.build(client) + assertSnapshot( + matching: request, as: .curl, named: testCase.name, record: testCase.record) + } + } +} diff --git a/Tests/PostgRESTTests/PostgRESTTests.swift b/Tests/PostgRESTTests/PostgRESTTests.swift index 9ed089e..1c6509c 100644 --- a/Tests/PostgRESTTests/PostgRESTTests.swift +++ b/Tests/PostgRESTTests/PostgRESTTests.swift @@ -10,6 +10,6 @@ final class PostgRESTTests: XCTestCase { } static var allTests = [ - ("testExample", testExample), + ("testExample", testExample) ] } diff --git a/Tests/PostgRESTTests/XCTestManifests.swift b/Tests/PostgRESTTests/XCTestManifests.swift index 1c142e9..2ab32e2 100644 --- a/Tests/PostgRESTTests/XCTestManifests.swift +++ b/Tests/PostgRESTTests/XCTestManifests.swift @@ -3,7 +3,7 @@ import XCTest #if !canImport(ObjectiveC) public func allTests() -> [XCTestCaseEntry] { return [ - testCase(PostgRESTTests.allTests), + testCase(PostgRESTTests.allTests) ] } #endif diff --git a/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildURLRequest.select-all-users-where-email-ends-with-supabase-co.txt b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildURLRequest.select-all-users-where-email-ends-with-supabase-co.txt new file mode 100644 index 0000000..d62c6b8 --- /dev/null +++ b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildURLRequest.select-all-users-where-email-ends-with-supabase-co.txt @@ -0,0 +1,3 @@ +curl \ + --header "Content-Type: application/json" \ + "https://example.supabase.co/users?select=*&email=like.%25@supabase.co" \ No newline at end of file