diff --git a/Sources/StructuredQueriesCore/QueryFragmentBuilder.swift b/Sources/StructuredQueriesCore/QueryFragmentBuilder.swift index d029e44a..19e7e1ed 100644 --- a/Sources/StructuredQueriesCore/QueryFragmentBuilder.swift +++ b/Sources/StructuredQueriesCore/QueryFragmentBuilder.swift @@ -27,6 +27,12 @@ extension QueryFragmentBuilder { ) -> [QueryFragment] { [expression.queryFragment] } + + public static func buildExpression( + _ expression: some QueryExpression> + ) -> [QueryFragment] { + [expression.queryFragment] + } } extension QueryFragmentBuilder<()> { diff --git a/Sources/StructuredQueriesCore/ScalarFunctions.swift b/Sources/StructuredQueriesCore/ScalarFunctions.swift index f999dcb0..56cef290 100644 --- a/Sources/StructuredQueriesCore/ScalarFunctions.swift +++ b/Sources/StructuredQueriesCore/ScalarFunctions.swift @@ -120,7 +120,8 @@ extension QueryExpression where QueryValue: FloatingPoint { } } -extension QueryExpression where QueryValue: Numeric { +extension QueryExpression +where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric { /// Wraps this numeric query expression with the `abs` function. /// /// - Returns: An expression wrapped with the `abs` function. @@ -251,14 +252,18 @@ extension QueryExpression where QueryValue == String { public func instr(_ occurrence: some QueryExpression) -> some QueryExpression { QueryFunction("instr", self, occurrence) } +} +extension QueryExpression where QueryValue: _OptionalPromotable { /// Wraps this string expression with the `lower` function. /// /// - Returns: An expression wrapped with the `lower` function. public func lower() -> some QueryExpression { QueryFunction("lower", self) } +} +extension QueryExpression where QueryValue == String { /// Wraps this string expression with the `ltrim` function. /// /// - Parameter characters: Characters to trim. @@ -279,14 +284,18 @@ extension QueryExpression where QueryValue == String { public func octetLength() -> some QueryExpression { QueryFunction("octet_length", self) } +} +extension QueryExpression where QueryValue: _OptionalPromotable { /// Wraps this string expression with the `quote` function. /// /// - Returns: An expression wrapped with the `quote` function. public func quote() -> some QueryExpression { QueryFunction("quote", self) } +} +extension QueryExpression where QueryValue == String { /// Creates an expression invoking the `replace` function. /// /// - Parameters: @@ -346,13 +355,15 @@ extension QueryExpression where QueryValue == String { return QueryFunction("trim", self) } } +} +extension QueryExpression where QueryValue: _OptionalPromotable { /// Wraps this string query expression with the `unhex` function. /// /// - Parameter characters: Non-hexadecimal characters to skip. /// - Returns: An optional blob expression of the `unhex` function wrapping this expression. public func unhex( - _ characters: (some QueryExpression)? = QueryValue?.none + _ characters: (some QueryExpression)? = String?.none ) -> some QueryExpression<[UInt8]?> { if let characters { return QueryFunction("unhex", self, characters) diff --git a/Sources/StructuredQueriesCore/Statements/Delete.swift b/Sources/StructuredQueriesCore/Statements/Delete.swift index 3d2898de..4595dbe7 100644 --- a/Sources/StructuredQueriesCore/Statements/Delete.swift +++ b/Sources/StructuredQueriesCore/Statements/Delete.swift @@ -48,7 +48,9 @@ public struct Delete { /// /// - Parameter keyPath: A key path to a Boolean expression to filter by. /// - Returns: A statement with the added predicate. - public func `where`(_ keyPath: KeyPath>) -> Self { + public func `where`( + _ keyPath: KeyPath>> + ) -> Self { var update = self update.where.append(From.columns[keyPath: keyPath].queryFragment) return update @@ -64,7 +66,9 @@ public struct Delete { /// - Parameter predicate: A closure that returns a Boolean expression to filter by. /// - Returns: A statement with the added predicate. @_disfavoredOverload - public func `where`(_ predicate: (From.TableColumns) -> some QueryExpression) -> Self { + public func `where`( + _ predicate: (From.TableColumns) -> some QueryExpression> + ) -> Self { var update = self update.where.append(predicate(From.columns).queryFragment) return update diff --git a/Sources/StructuredQueriesCore/Statements/Select.swift b/Sources/StructuredQueriesCore/Statements/Select.swift index fcb11d0b..2a356537 100644 --- a/Sources/StructuredQueriesCore/Statements/Select.swift +++ b/Sources/StructuredQueriesCore/Statements/Select.swift @@ -236,7 +236,7 @@ extension Table { /// columns. /// - Returns: A select statement that is filtered by the given predicate. public static func having( - _ predicate: (TableColumns) -> some QueryExpression + _ predicate: (TableColumns) -> some QueryExpression> ) -> SelectOf { Where().having(predicate) } @@ -1110,7 +1110,7 @@ extension Select { /// - Parameter keyPath: A key path from this select's table to a Boolean expression to filter by. /// - Returns: A new select statement that appends the given predicate to its `WHERE` clause. public func `where`( - _ keyPath: KeyPath> + _ keyPath: KeyPath>> ) -> Self where Joins == () { var select = self @@ -1125,7 +1125,9 @@ extension Select { /// - Returns: A new select statement that appends the given predicate to its `WHERE` clause. @_disfavoredOverload public func `where`( - _ predicate: (From.TableColumns, repeat (each J).TableColumns) -> some QueryExpression + _ predicate: (From.TableColumns, repeat (each J).TableColumns) -> some QueryExpression< + some _OptionalPromotable + > ) -> Self where Joins == (repeat each J) { var select = self @@ -1218,7 +1220,9 @@ extension Select { /// - Returns: A new select statement that appends the given predicate to its `HAVING` clause. @_disfavoredOverload public func having( - _ predicate: (From.TableColumns, repeat (each J).TableColumns) -> some QueryExpression + _ predicate: (From.TableColumns, repeat (each J).TableColumns) -> some QueryExpression< + some _OptionalPromotable + > ) -> Self where Joins == (repeat each J) { var select = self diff --git a/Sources/StructuredQueriesCore/Statements/SelectStatement.swift b/Sources/StructuredQueriesCore/Statements/SelectStatement.swift index 6cda833b..ce7ef301 100644 --- a/Sources/StructuredQueriesCore/Statements/SelectStatement.swift +++ b/Sources/StructuredQueriesCore/Statements/SelectStatement.swift @@ -47,7 +47,7 @@ public typealias SelectStatementOf = extension SelectStatement { public static func `where`( - _ predicate: (From.TableColumns) -> some QueryExpression + _ predicate: (From.TableColumns) -> some QueryExpression> ) -> Self where Self == Where { Self(predicates: [predicate(From.columns).queryFragment]) diff --git a/Sources/StructuredQueriesCore/Statements/Update.swift b/Sources/StructuredQueriesCore/Statements/Update.swift index 685bdf77..6ab67078 100644 --- a/Sources/StructuredQueriesCore/Statements/Update.swift +++ b/Sources/StructuredQueriesCore/Statements/Update.swift @@ -105,7 +105,9 @@ public struct Update { /// /// - Parameter keyPath: A key path to a Boolean expression to filter by. /// - Returns: A statement with the added predicate. - public func `where`(_ keyPath: KeyPath>) -> Self { + public func `where`( + _ keyPath: KeyPath>> + ) -> Self { var update = self update.where.append(From.columns[keyPath: keyPath].queryFragment) return update @@ -117,7 +119,7 @@ public struct Update { /// - Returns: A statement with the added predicate. @_disfavoredOverload public func `where`( - _ predicate: (From.TableColumns) -> some QueryExpression + _ predicate: (From.TableColumns) -> some QueryExpression> ) -> Self { var update = self update.where.append(predicate(From.columns).queryFragment) diff --git a/Sources/StructuredQueriesCore/Statements/Where.swift b/Sources/StructuredQueriesCore/Statements/Where.swift index 5960d958..2e1704d0 100644 --- a/Sources/StructuredQueriesCore/Statements/Where.swift +++ b/Sources/StructuredQueriesCore/Statements/Where.swift @@ -20,7 +20,7 @@ extension Table { /// - Parameter keyPath: A key path to a Boolean expression to filter by. /// - Returns: A `WHERE` clause. public static func `where`( - _ keyPath: KeyPath> + _ keyPath: KeyPath>> ) -> Where { Where(predicates: [columns[keyPath: keyPath].queryFragment]) } @@ -33,7 +33,7 @@ extension Table { /// - Returns: A `WHERE` clause. @_disfavoredOverload public static func `where`( - _ predicate: (TableColumns) -> some QueryExpression + _ predicate: (TableColumns) -> some QueryExpression> ) -> Where { Where(predicates: [predicate(columns).queryFragment]) } @@ -292,7 +292,7 @@ extension Where: SelectStatement { /// - Parameter keyPath: A key path to a Boolean expression to filter by. /// - Returns: A where clause with the added predicate. public func `where`( - _ keyPath: KeyPath> + _ keyPath: KeyPath>> ) -> Self { var `where` = self `where`.predicates.append(From.columns[keyPath: keyPath].queryFragment) @@ -305,7 +305,7 @@ extension Where: SelectStatement { /// - Returns: A where clause with the added predicate. @_disfavoredOverload public func `where`( - _ predicate: (From.TableColumns) -> some QueryExpression + _ predicate: (From.TableColumns) -> some QueryExpression> ) -> Self { var `where` = self `where`.predicates.append(predicate(From.columns).queryFragment) @@ -409,7 +409,7 @@ extension Where: SelectStatement { /// A select statement for the filtered table with the given `HAVING` clause. public func having( - _ predicate: (From.TableColumns) -> some QueryExpression + _ predicate: (From.TableColumns) -> some QueryExpression> ) -> SelectOf { asSelect().having(predicate) } diff --git a/Tests/StructuredQueriesTests/WhereTests.swift b/Tests/StructuredQueriesTests/WhereTests.swift index e0b61e15..ca72bcf9 100644 --- a/Tests/StructuredQueriesTests/WhereTests.swift +++ b/Tests/StructuredQueriesTests/WhereTests.swift @@ -1,6 +1,8 @@ +import Dependencies import Foundation import InlineSnapshotTesting import StructuredQueries +import StructuredQueriesSQLite import Testing extension SnapshotTests { @@ -109,5 +111,30 @@ extension SnapshotTests { """ } } + + @Test func optionalBoolean() throws { + @Dependency(\.defaultDatabase) var db + let remindersListIDs = try db.execute( + RemindersList.insert { + RemindersList.Draft(title: "New list") + } + .returning(\.id) + ) + let remindersListID = try #require(remindersListIDs.first) + + assertQuery( + RemindersList + .find(remindersListID) + .leftJoin(Reminder.all) { $0.id.eq($1.remindersListID) } + .where { $1.isCompleted } + ) { + """ + SELECT "remindersLists"."id", "remindersLists"."color", "remindersLists"."title", "reminders"."id", "reminders"."assignedUserID", "reminders"."dueDate", "reminders"."isCompleted", "reminders"."isFlagged", "reminders"."notes", "reminders"."priority", "reminders"."remindersListID", "reminders"."title" + FROM "remindersLists" + LEFT JOIN "reminders" ON ("remindersLists"."id" = "reminders"."remindersListID") + WHERE ("remindersLists"."id" = 4) AND "reminders"."isCompleted" + """ + } + } } }