Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Sources/StructuredQueriesCore/QueryFragmentBuilder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ extension QueryFragmentBuilder<Bool> {
) -> [QueryFragment] {
[expression.queryFragment]
}

public static func buildExpression(
_ expression: some QueryExpression<some _OptionalPromotable<Bool?>>
) -> [QueryFragment] {
[expression.queryFragment]
}
}

extension QueryFragmentBuilder<()> {
Expand Down
15 changes: 13 additions & 2 deletions Sources/StructuredQueriesCore/ScalarFunctions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ extension QueryExpression where QueryValue: FloatingPoint {
}
}

extension QueryExpression where QueryValue: Numeric {
extension QueryExpression
where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric {
/// Wraps this numeric query expression with the `abs` function.
///
/// - Returns: An expression wrapped with the `abs` function.
Expand Down Expand Up @@ -251,14 +252,18 @@ extension QueryExpression where QueryValue == String {
public func instr(_ occurrence: some QueryExpression<QueryValue>) -> some QueryExpression<Int> {
QueryFunction("instr", self, occurrence)
}
}

extension QueryExpression where QueryValue: _OptionalPromotable<String?> {
/// Wraps this string expression with the `lower` function.
///
/// - Returns: An expression wrapped with the `lower` function.
public func lower() -> some QueryExpression<QueryValue> {
QueryFunction("lower", self)
}
}

extension QueryExpression where QueryValue == String {
/// Wraps this string expression with the `ltrim` function.
///
/// - Parameter characters: Characters to trim.
Expand All @@ -279,14 +284,18 @@ extension QueryExpression where QueryValue == String {
public func octetLength() -> some QueryExpression<Int> {
QueryFunction("octet_length", self)
}
}

extension QueryExpression where QueryValue: _OptionalPromotable<String?> {
/// Wraps this string expression with the `quote` function.
///
/// - Returns: An expression wrapped with the `quote` function.
public func quote() -> some QueryExpression<QueryValue> {
QueryFunction("quote", self)
}
}

extension QueryExpression where QueryValue == String {
/// Creates an expression invoking the `replace` function.
///
/// - Parameters:
Expand Down Expand Up @@ -346,13 +355,15 @@ extension QueryExpression where QueryValue == String {
return QueryFunction("trim", self)
}
}
}

extension QueryExpression where QueryValue: _OptionalPromotable<String?> {
/// Wraps this string query expression with the `unhex` function.
///
/// - Parameter characters: Non-hexadecimal characters to skip.
/// - Returns: An optional blob expression of the `unhex` function wrapping this expression.
public func unhex(
_ characters: (some QueryExpression<QueryValue>)? = QueryValue?.none
_ characters: (some QueryExpression<String>)? = String?.none
) -> some QueryExpression<[UInt8]?> {
if let characters {
return QueryFunction("unhex", self, characters)
Expand Down
8 changes: 6 additions & 2 deletions Sources/StructuredQueriesCore/Statements/Delete.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ public struct Delete<From: Table, Returning> {
///
/// - Parameter keyPath: A key path to a Boolean expression to filter by.
/// - Returns: A statement with the added predicate.
public func `where`(_ keyPath: KeyPath<From.TableColumns, some QueryExpression<Bool>>) -> Self {
public func `where`(
_ keyPath: KeyPath<From.TableColumns, some QueryExpression<some _OptionalPromotable<Bool?>>>
) -> Self {
var update = self
update.where.append(From.columns[keyPath: keyPath].queryFragment)
return update
Expand All @@ -64,7 +66,9 @@ public struct Delete<From: Table, Returning> {
/// - Parameter predicate: A closure that returns a Boolean expression to filter by.
/// - Returns: A statement with the added predicate.
@_disfavoredOverload
public func `where`(_ predicate: (From.TableColumns) -> some QueryExpression<Bool>) -> Self {
public func `where`(
_ predicate: (From.TableColumns) -> some QueryExpression<some _OptionalPromotable<Bool?>>
) -> Self {
var update = self
update.where.append(predicate(From.columns).queryFragment)
return update
Expand Down
12 changes: 8 additions & 4 deletions Sources/StructuredQueriesCore/Statements/Select.swift
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ extension Table {
/// columns.
/// - Returns: A select statement that is filtered by the given predicate.
public static func having(
_ predicate: (TableColumns) -> some QueryExpression<Bool>
_ predicate: (TableColumns) -> some QueryExpression<some _OptionalPromotable<Bool?>>
) -> SelectOf<Self> {
Where().having(predicate)
}
Expand Down Expand Up @@ -1110,7 +1110,7 @@ extension Select {
/// - Parameter keyPath: A key path from this select's table to a Boolean expression to filter by.
/// - Returns: A new select statement that appends the given predicate to its `WHERE` clause.
public func `where`(
_ keyPath: KeyPath<From.TableColumns, some QueryExpression<Bool>>
_ keyPath: KeyPath<From.TableColumns, some QueryExpression<some _OptionalPromotable<Bool?>>>
) -> Self
where Joins == () {
var select = self
Expand All @@ -1125,7 +1125,9 @@ extension Select {
/// - Returns: A new select statement that appends the given predicate to its `WHERE` clause.
@_disfavoredOverload
public func `where`<each J: Table>(
_ predicate: (From.TableColumns, repeat (each J).TableColumns) -> some QueryExpression<Bool>
_ predicate: (From.TableColumns, repeat (each J).TableColumns) -> some QueryExpression<
some _OptionalPromotable<Bool?>
>
) -> Self
where Joins == (repeat each J) {
var select = self
Expand Down Expand Up @@ -1218,7 +1220,9 @@ extension Select {
/// - Returns: A new select statement that appends the given predicate to its `HAVING` clause.
@_disfavoredOverload
public func having<each J: Table>(
_ predicate: (From.TableColumns, repeat (each J).TableColumns) -> some QueryExpression<Bool>
_ predicate: (From.TableColumns, repeat (each J).TableColumns) -> some QueryExpression<
some _OptionalPromotable<Bool?>
>
) -> Self
where Joins == (repeat each J) {
var select = self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public typealias SelectStatementOf<From: Table, each Join: Table> =

extension SelectStatement {
public static func `where`<From>(
_ predicate: (From.TableColumns) -> some QueryExpression<Bool>
_ predicate: (From.TableColumns) -> some QueryExpression<some _OptionalPromotable<Bool?>>
) -> Self
where Self == Where<From> {
Self(predicates: [predicate(From.columns).queryFragment])
Expand Down
6 changes: 4 additions & 2 deletions Sources/StructuredQueriesCore/Statements/Update.swift
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ public struct Update<From: Table, Returning> {
///
/// - Parameter keyPath: A key path to a Boolean expression to filter by.
/// - Returns: A statement with the added predicate.
public func `where`(_ keyPath: KeyPath<From.TableColumns, some QueryExpression<Bool>>) -> Self {
public func `where`(
_ keyPath: KeyPath<From.TableColumns, some QueryExpression<some _OptionalPromotable<Bool?>>>
) -> Self {
var update = self
update.where.append(From.columns[keyPath: keyPath].queryFragment)
return update
Expand All @@ -117,7 +119,7 @@ public struct Update<From: Table, Returning> {
/// - Returns: A statement with the added predicate.
@_disfavoredOverload
public func `where`(
_ predicate: (From.TableColumns) -> some QueryExpression<Bool>
_ predicate: (From.TableColumns) -> some QueryExpression<some _OptionalPromotable<Bool?>>
) -> Self {
var update = self
update.where.append(predicate(From.columns).queryFragment)
Expand Down
10 changes: 5 additions & 5 deletions Sources/StructuredQueriesCore/Statements/Where.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ extension Table {
/// - Parameter keyPath: A key path to a Boolean expression to filter by.
/// - Returns: A `WHERE` clause.
public static func `where`(
_ keyPath: KeyPath<TableColumns, some QueryExpression<Bool>>
_ keyPath: KeyPath<TableColumns, some QueryExpression<some _OptionalPromotable<Bool?>>>
) -> Where<Self> {
Where(predicates: [columns[keyPath: keyPath].queryFragment])
}
Expand All @@ -33,7 +33,7 @@ extension Table {
/// - Returns: A `WHERE` clause.
@_disfavoredOverload
public static func `where`(
_ predicate: (TableColumns) -> some QueryExpression<Bool>
_ predicate: (TableColumns) -> some QueryExpression<some _OptionalPromotable<Bool?>>
) -> Where<Self> {
Where(predicates: [predicate(columns).queryFragment])
}
Expand Down Expand Up @@ -292,7 +292,7 @@ extension Where: SelectStatement {
/// - Parameter keyPath: A key path to a Boolean expression to filter by.
/// - Returns: A where clause with the added predicate.
public func `where`(
_ keyPath: KeyPath<From.TableColumns, some QueryExpression<Bool>>
_ keyPath: KeyPath<From.TableColumns, some QueryExpression<some _OptionalPromotable<Bool?>>>
) -> Self {
var `where` = self
`where`.predicates.append(From.columns[keyPath: keyPath].queryFragment)
Expand All @@ -305,7 +305,7 @@ extension Where: SelectStatement {
/// - Returns: A where clause with the added predicate.
@_disfavoredOverload
public func `where`(
_ predicate: (From.TableColumns) -> some QueryExpression<Bool>
_ predicate: (From.TableColumns) -> some QueryExpression<some _OptionalPromotable<Bool?>>
) -> Self {
var `where` = self
`where`.predicates.append(predicate(From.columns).queryFragment)
Expand Down Expand Up @@ -409,7 +409,7 @@ extension Where: SelectStatement {

/// A select statement for the filtered table with the given `HAVING` clause.
public func having(
_ predicate: (From.TableColumns) -> some QueryExpression<Bool>
_ predicate: (From.TableColumns) -> some QueryExpression<some _OptionalPromotable<Bool?>>
) -> SelectOf<From> {
asSelect().having(predicate)
}
Expand Down
27 changes: 27 additions & 0 deletions Tests/StructuredQueriesTests/WhereTests.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import Dependencies
import Foundation
import InlineSnapshotTesting
import StructuredQueries
import StructuredQueriesSQLite
import Testing

extension SnapshotTests {
Expand Down Expand Up @@ -109,5 +111,30 @@ extension SnapshotTests {
"""
}
}

@Test func optionalBoolean() throws {
@Dependency(\.defaultDatabase) var db
let remindersListIDs = try db.execute(
RemindersList.insert {
RemindersList.Draft(title: "New list")
}
.returning(\.id)
)
let remindersListID = try #require(remindersListIDs.first)

assertQuery(
RemindersList
.find(remindersListID)
.leftJoin(Reminder.all) { $0.id.eq($1.remindersListID) }
.where { $1.isCompleted }
) {
"""
SELECT "remindersLists"."id", "remindersLists"."color", "remindersLists"."title", "reminders"."id", "reminders"."assignedUserID", "reminders"."dueDate", "reminders"."isCompleted", "reminders"."isFlagged", "reminders"."notes", "reminders"."priority", "reminders"."remindersListID", "reminders"."title"
FROM "remindersLists"
LEFT JOIN "reminders" ON ("remindersLists"."id" = "reminders"."remindersListID")
WHERE ("remindersLists"."id" = 4) AND "reminders"."isCompleted"
"""
}
}
}
}