Skip to content

Commit a9daeeb

Browse files
committed
feat: add initial search async function with channel #341
1 parent 7279710 commit a9daeeb

File tree

4 files changed

+215
-0
lines changed

4 files changed

+215
-0
lines changed

v3/client.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ldap
22

33
import (
4+
"context"
45
"crypto/tls"
56
"time"
67
)
@@ -32,6 +33,8 @@ type Client interface {
3233
PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error)
3334

3435
Search(*SearchRequest) (*SearchResult, error)
36+
SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response
37+
SearchWithChannel(ctx context.Context, searchRequest *SearchRequest, channelSize int) <-chan *SearchSingleResult
3538
SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error)
3639
DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error)
3740
}

v3/examples_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,35 @@ func ExampleConn_Search() {
5151
}
5252
}
5353

54+
// This example demonstrates how to search with channel
55+
func ExampleConn_SearchAsync() {
56+
l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389))
57+
if err != nil {
58+
log.Fatal(err)
59+
}
60+
defer l.Close()
61+
62+
searchRequest := NewSearchRequest(
63+
"dc=example,dc=com", // The base dn to search
64+
ScopeWholeSubtree, NeverDerefAliases, 0, 0, false,
65+
"(&(objectClass=organizationalPerson))", // The filter to apply
66+
[]string{"dn", "cn"}, // A list attributes to retrieve
67+
nil,
68+
)
69+
70+
ctx, cancel := context.WithCancel(context.Background())
71+
defer cancel()
72+
73+
r := l.SearchAsync(ctx, searchRequest, 64)
74+
for r.Next() {
75+
entry := r.Entry()
76+
fmt.Printf("%s has DN %s\n", entry.GetAttributeValue("cn"), entry.DN)
77+
}
78+
if err := r.Err(); err != nil {
79+
log.Fatal(err)
80+
}
81+
}
82+
5483
// This example demonstrates how to search with channel
5584
func ExampleConn_SearchWithChannel() {
5685
l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389))

v3/response.go

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
package ldap
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
8+
ber "github.com/go-asn1-ber/asn1-ber"
9+
)
10+
11+
// Response defines an interface to get data from an LDAP server
12+
type Response interface {
13+
Entry() *Entry
14+
Referral() string
15+
Controls() []Control
16+
Err() error
17+
Next() bool
18+
}
19+
20+
type searchResponse struct {
21+
conn *Conn
22+
ch chan *SearchSingleResult
23+
24+
entry *Entry
25+
referral string
26+
controls []Control
27+
err error
28+
}
29+
30+
// Entry returns an entry from the given search request
31+
func (r *searchResponse) Entry() *Entry {
32+
return r.entry
33+
}
34+
35+
// Referral returns a referral from the given search request
36+
func (r *searchResponse) Referral() string {
37+
return r.referral
38+
}
39+
40+
// Controls returns controls from the given search request
41+
func (r *searchResponse) Controls() []Control {
42+
return r.controls
43+
}
44+
45+
// Err returns an error when the given search request was failed
46+
func (r *searchResponse) Err() error {
47+
return r.err
48+
}
49+
50+
// Next returns whether next data exist or not
51+
func (r *searchResponse) Next() bool {
52+
res := <-r.ch
53+
if res == nil {
54+
return false
55+
}
56+
r.err = res.Error
57+
if r.err != nil {
58+
return false
59+
}
60+
r.err = r.conn.GetLastError()
61+
if r.err != nil {
62+
return false
63+
}
64+
r.entry = res.Entry
65+
r.referral = res.Referral
66+
r.controls = res.Controls
67+
return true
68+
}
69+
70+
func (r *searchResponse) searchAsync(
71+
ctx context.Context, searchRequest *SearchRequest, bufferSize int) {
72+
if bufferSize > 0 {
73+
r.ch = make(chan *SearchSingleResult, bufferSize)
74+
} else {
75+
r.ch = make(chan *SearchSingleResult)
76+
}
77+
go func() {
78+
defer func() {
79+
close(r.ch)
80+
if err := recover(); err != nil {
81+
r.conn.err = fmt.Errorf("ldap: recovered panic in searchAsync: %v", err)
82+
}
83+
}()
84+
85+
if r.conn.IsClosing() {
86+
return
87+
}
88+
89+
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
90+
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, r.conn.nextMessageID(), "MessageID"))
91+
// encode search request
92+
err := searchRequest.appendTo(packet)
93+
if err != nil {
94+
r.ch <- &SearchSingleResult{Error: err}
95+
return
96+
}
97+
r.conn.Debug.PrintPacket(packet)
98+
99+
msgCtx, err := r.conn.sendMessage(packet)
100+
if err != nil {
101+
r.ch <- &SearchSingleResult{Error: err}
102+
return
103+
}
104+
defer r.conn.finishMessage(msgCtx)
105+
106+
foundSearchSingleResultDone := false
107+
for !foundSearchSingleResultDone {
108+
select {
109+
case <-ctx.Done():
110+
r.conn.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error())
111+
return
112+
default:
113+
r.conn.Debug.Printf("%d: waiting for response", msgCtx.id)
114+
packetResponse, ok := <-msgCtx.responses
115+
if !ok {
116+
err := NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
117+
r.ch <- &SearchSingleResult{Error: err}
118+
return
119+
}
120+
packet, err = packetResponse.ReadPacket()
121+
r.conn.Debug.Printf("%d: got response %p", msgCtx.id, packet)
122+
if err != nil {
123+
r.ch <- &SearchSingleResult{Error: err}
124+
return
125+
}
126+
127+
if r.conn.Debug {
128+
if err := addLDAPDescriptions(packet); err != nil {
129+
r.ch <- &SearchSingleResult{Error: err}
130+
return
131+
}
132+
ber.PrintPacket(packet)
133+
}
134+
135+
switch packet.Children[1].Tag {
136+
case ApplicationSearchResultEntry:
137+
r.ch <- &SearchSingleResult{
138+
Entry: &Entry{
139+
DN: packet.Children[1].Children[0].Value.(string),
140+
Attributes: unpackAttributes(packet.Children[1].Children[1].Children),
141+
},
142+
}
143+
144+
case ApplicationSearchResultDone:
145+
if err := GetLDAPError(packet); err != nil {
146+
r.ch <- &SearchSingleResult{Error: err}
147+
return
148+
}
149+
if len(packet.Children) == 3 {
150+
result := &SearchSingleResult{}
151+
for _, child := range packet.Children[2].Children {
152+
decodedChild, err := DecodeControl(child)
153+
if err != nil {
154+
werr := fmt.Errorf("failed to decode child control: %w", err)
155+
r.ch <- &SearchSingleResult{Error: werr}
156+
return
157+
}
158+
result.Controls = append(result.Controls, decodedChild)
159+
}
160+
r.ch <- result
161+
}
162+
foundSearchSingleResultDone = true
163+
164+
case ApplicationSearchResultReference:
165+
ref := packet.Children[1].Children[0].Value.(string)
166+
r.ch <- &SearchSingleResult{Referral: ref}
167+
}
168+
}
169+
}
170+
r.conn.Debug.Printf("%d: returning", msgCtx.id)
171+
}()
172+
}

v3/search.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,17 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
584584
}
585585
}
586586

587+
// SearchAsync performs a search request and returns all search results asynchronously.
588+
// This means you get all results until an error happens (or the search successfully finished),
589+
// e.g. for size / time limited requests all are recieved until the limit is reached.
590+
// To stop the search, call cancel function returned context.
591+
func (l *Conn) SearchAsync(
592+
ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response {
593+
r := &searchResponse{conn: l}
594+
r.searchAsync(ctx, searchRequest, bufferSize)
595+
return r
596+
}
597+
587598
// SearchWithChannel performs a search request and returns all search results
588599
// via the returned channel as soon as they are received. This means you get
589600
// all results until an error happens (or the search successfully finished),

0 commit comments

Comments
 (0)