Skip to content

Commit 92205da

Browse files
authored
Merge pull request #1 from rcarver/knn
Knn
2 parents 9926796 + 7bf7f2e commit 92205da

File tree

4 files changed

+98
-1
lines changed

4 files changed

+98
-1
lines changed

Sources/ElasticsearchQueryBuilder/Components.swift

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public protocol ArrayComponent {
1313
}
1414

1515
extension ArrayComponent {
16-
func makeCompactArray() -> [QueryDict] {
16+
public func makeCompactArray() -> [QueryDict] {
1717
var value = self.makeArray()
1818
value.removeAll(where: \.isEmpty)
1919
return value
@@ -30,6 +30,10 @@ public struct RootComponent<Component: DictComponent>: RootQueryable, DictCompon
3030
}
3131
}
3232

33+
public struct EmptyArrayComponent: ArrayComponent {
34+
public func makeArray() -> [QueryDict] { [] }
35+
}
36+
3337
/// Namespace for `@ElasticsearchQueryBuilder` components
3438
public enum esb {}
3539

@@ -184,6 +188,45 @@ extension esb {
184188
}
185189
}
186190

191+
/// Adds `knn` block to the query syntax.
192+
public struct kNearestNeighbor<Component: ArrayComponent>: DictComponent {
193+
let field: String
194+
let vector: [Double]
195+
let options: QueryDict
196+
var filter: Component
197+
public init(
198+
_ field: String,
199+
_ vector: [Double],
200+
options: () -> QueryDict = { [:] },
201+
@QueryArrayBuilder filter: () -> Component
202+
) {
203+
self.field = field
204+
self.vector = vector
205+
self.options = options()
206+
self.filter = filter()
207+
}
208+
public init(
209+
_ field: String,
210+
_ vector: [Double],
211+
options: () -> QueryDict = { [:] }
212+
) where Component == EmptyArrayComponent {
213+
self.field = field
214+
self.vector = vector
215+
self.options = options()
216+
self.filter = EmptyArrayComponent()
217+
}
218+
public func makeDict() -> QueryDict {
219+
var dict: QueryDict = self.options
220+
dict["field"] = .string(self.field)
221+
dict["query_vector"] = .array(self.vector)
222+
let filterValues = self.filter.makeCompactArray()
223+
if !filterValues.isEmpty {
224+
dict["filter"] = .array(filterValues.map(QueryValue.dict))
225+
}
226+
return [ "knn" : .dict(dict) ]
227+
}
228+
}
229+
187230
/// Adds `function_score` block to the query syntax.
188231
public struct FunctionScore<Component: DictComponent>: DictComponent {
189232
var component: Component

Sources/ElasticsearchQueryBuilder/Encoding.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ extension QueryValue: Encodable {
66
switch self {
77
case let .array(value):
88
try container.encode(value)
9+
case let .bool(value):
10+
try container.encode(value)
911
case let .date(value, format: format):
1012
switch format {
1113
case .secondsSince1970:

Sources/ElasticsearchQueryBuilder/QueryValue.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ public typealias QueryDict = [ String : QueryValue ]
44

55
public enum QueryValue: Equatable {
66
case array([QueryValue])
7+
case bool(Bool)
78
case date(Date, format: QueryDateFormat)
89
case dict(QueryDict)
910
case float(Float)
@@ -45,6 +46,12 @@ extension QueryValue: ExpressibleByArrayLiteral {
4546
}
4647
}
4748

49+
extension QueryValue: ExpressibleByBooleanLiteral {
50+
public init(booleanLiteral value: Bool) {
51+
self = .bool(value)
52+
}
53+
}
54+
4855
extension QueryValue: ExpressibleByDictionaryLiteral {
4956
public init(dictionaryLiteral elements: (String, QueryValue)...) {
5057
var dict = QueryDict()

Tests/ElasticsearchQueryBuilderTests/ComponentTests.swift

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,51 @@ final class BoolTests: XCTestCase {
175175
}
176176
}
177177

178+
final class KNearestNeighborTests: XCTestCase {
179+
func testBuildBasic() throws {
180+
@ElasticsearchQueryBuilder func build() -> some esb.QueryDSL {
181+
esb.kNearestNeighbor("vector_field", [1,2,3])
182+
}
183+
XCTAssertNoDifference(build().makeQuery(), [
184+
"knn": [
185+
"field": "vector_field",
186+
"query_vector": [1.0, 2.0, 3.0],
187+
]
188+
])
189+
}
190+
func testBuildWithOptionsAndFilter() throws {
191+
@ElasticsearchQueryBuilder func build() -> some esb.QueryDSL {
192+
esb.kNearestNeighbor("vector_field", [1,2,3]) {
193+
[
194+
"k": 5,
195+
"index": true
196+
]
197+
} filter: {
198+
esb.Key("match_bool_prefix") {
199+
[
200+
"message": "quick brown f"
201+
]
202+
}
203+
}
204+
}
205+
XCTAssertNoDifference(build().makeQuery(), [
206+
"knn": [
207+
"field": "vector_field",
208+
"query_vector": [1.0, 2.0, 3.0],
209+
"k": 5,
210+
"index": true,
211+
"filter": [
212+
[
213+
"match_bool_prefix": [
214+
"message": "quick brown f"
215+
]
216+
]
217+
]
218+
]
219+
])
220+
}
221+
}
222+
178223
final class FunctionScoreTests: XCTestCase {
179224
func testBuild() throws {
180225
@ElasticsearchQueryBuilder func build() -> some esb.QueryDSL {

0 commit comments

Comments
 (0)