diff --git a/acceptor.go b/acceptor.go index f58ef01f7..0228778c1 100644 --- a/acceptor.go +++ b/acceptor.go @@ -18,6 +18,7 @@ package quickfix import ( "bufio" "bytes" + "context" "crypto/tls" "io" "net" @@ -361,6 +362,7 @@ func (a *Acceptor) handleConnection(netConn net.Conn) { a.sessionAddr.Store(sessID, netConn.RemoteAddr()) msgIn := make(chan fixIn) msgOut := make(chan []byte) + ctx := context.Background() if err := session.connect(msgIn, msgOut); err != nil { a.globalLog.OnEventf("Unable to accept session %v connection: %v", sessID, err.Error()) @@ -369,10 +371,10 @@ func (a *Acceptor) handleConnection(netConn net.Conn) { go func() { msgIn <- fixIn{msgBytes, parser.lastRead} - readLoop(parser, msgIn, a.globalLog) + readLoop(ctx, parser, msgIn, a.globalLog) }() - writeLoop(netConn, msgOut, a.globalLog) + writeLoop(ctx, netConn, msgOut, a.globalLog) } func (a *Acceptor) dynamicSessionsLoop() { diff --git a/connection.go b/connection.go index 99a4c465e..95e77b239 100644 --- a/connection.go +++ b/connection.go @@ -15,10 +15,19 @@ package quickfix -import "io" +import ( + "context" + "io" +) -func writeLoop(connection io.Writer, messageOut chan []byte, log Log) { +func writeLoop(ctx context.Context, connection io.Writer, messageOut chan []byte, log Log) { for { + select { + case <-ctx.Done(): + return + default: + } + msg, ok := <-messageOut if !ok { return @@ -30,10 +39,16 @@ func writeLoop(connection io.Writer, messageOut chan []byte, log Log) { } } -func readLoop(parser *parser, msgIn chan fixIn, log Log) { +func readLoop(ctx context.Context, parser *parser, msgIn chan fixIn, log Log) { defer close(msgIn) for { + select { + case <-ctx.Done(): + return + default: + } + msg, err := parser.ReadMessage() if err != nil { log.OnEvent(err.Error()) diff --git a/connection_internal_test.go b/connection_internal_test.go index 081b3c110..3396c007e 100644 --- a/connection_internal_test.go +++ b/connection_internal_test.go @@ -17,11 +17,13 @@ package quickfix import ( "bytes" + "context" "strings" "testing" ) func TestWriteLoop(t *testing.T) { + ctx := context.Background() writer := bytes.NewBufferString("") msgOut := make(chan []byte) @@ -31,7 +33,7 @@ func TestWriteLoop(t *testing.T) { msgOut <- []byte("test msg 3") close(msgOut) }() - writeLoop(writer, msgOut, nullLog{}) + writeLoop(ctx, writer, msgOut, nullLog{}) expected := "test msg 1 test msg 2 test msg 3" @@ -40,12 +42,32 @@ func TestWriteLoop(t *testing.T) { } } +func TestWriteLoopCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + writer := bytes.NewBufferString("") + msgOut := make(chan []byte) + + go func() { + msgOut <- []byte("test msg 1") + cancel() + }() + writeLoop(ctx, writer, msgOut, nullLog{}) + + expected := "test msg 1" + + if writer.String() != expected { + t.Errorf("expected %v got %v", expected, writer.String()) + } +} + func TestReadLoop(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() msgIn := make(chan fixIn) stream := "hello8=FIX.4.09=5blah10=103garbage8=FIX.4.09=4foo10=103" parser := newParser(strings.NewReader(stream)) - go readLoop(parser, msgIn, nullLog{}) + go readLoop(ctx, parser, msgIn, nullLog{}) var tests = []struct { expectedMsg string @@ -71,3 +93,18 @@ func TestReadLoop(t *testing.T) { } } } + +func TestReadLoopCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + msgIn := make(chan fixIn) + stream := "hello8=FIX.4.09=5blah10=103garbage8=FIX.4.09=4foo10=103" + + parser := newParser(strings.NewReader(stream)) + + cancel() + go readLoop(ctx, parser, msgIn, nullLog{}) + _, ok := <-msgIn + if ok { + t.Error("Channel should be closed on context cancel") + } +} diff --git a/initiator.go b/initiator.go index 18451477e..f62b1f1d1 100644 --- a/initiator.go +++ b/initiator.go @@ -163,14 +163,17 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di return } - ctx, cancel := context.WithCancel(context.Background()) + ctx := context.Background() + dialCtx, dialCancel := context.WithCancel(ctx) + readWriteCtx, readWriteCancel := context.WithCancel(ctx) // We start a goroutine in order to be able to cancel the dialer mid-connection // on receiving a stop signal to stop the initiator. go func() { select { case <-i.stopChan: - cancel() + dialCancel() + readWriteCancel() case <-ctx.Done(): return } @@ -183,7 +186,7 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di address := session.SocketConnectAddress[connectionAttempt%len(session.SocketConnectAddress)] session.log.OnEventf("Connecting to: %v", address) - netConn, err := dialer.DialContext(ctx, "tcp", address) + netConn, err := dialer.DialContext(dialCtx, "tcp", address) if err != nil { session.log.OnEventf("Failed to connect: %v", err) goto reconnect @@ -207,24 +210,25 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di msgIn = make(chan fixIn) msgOut = make(chan []byte) - if err := session.connect(msgIn, msgOut); err != nil { - session.log.OnEventf("Failed to initiate: %v", err) - goto reconnect - } - go readLoop(newParser(bufio.NewReader(netConn)), msgIn, session.log) + go readLoop(readWriteCtx, newParser(bufio.NewReader(netConn)), msgIn, session.log) disconnected = make(chan interface{}) go func() { - writeLoop(netConn, msgOut, session.log) + writeLoop(readWriteCtx, netConn, msgOut, session.log) if err := netConn.Close(); err != nil { session.log.OnEvent(err.Error()) } close(disconnected) }() + if err := session.connect(msgIn, msgOut); err != nil { + session.log.OnEventf("Failed to initiate: %v", err) + goto reconnect + } + // This ensures we properly cleanup the goroutine and context used for // dial cancelation after successful connection. - cancel() + dialCancel() select { case <-disconnected: @@ -233,7 +237,7 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di } reconnect: - cancel() + dialCancel() connectionAttempt++ session.log.OnEventf("Reconnecting in %v", session.ReconnectInterval) diff --git a/initiator_test.go b/initiator_test.go new file mode 100644 index 000000000..1eb70dc34 --- /dev/null +++ b/initiator_test.go @@ -0,0 +1,177 @@ +package quickfix + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/quickfixgo/quickfix/config" +) + +func TestNewInitiatorKeepReconnectingAfterLogonError(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logonCount := 0 + app := &mockApplication{} + storeFactory := &mockMessageStoreFactory{saveMessageAndIncrError: errDBError} + logFactory := &mockLogFactory{ + onEvent: func(s string) { + if s == "Sending logon request" { + logonCount++ + if logonCount >= 5 { + cancel() + } + } + }, + } + + settings := NewSettings() + sessionSettings := newSession() + sessionID, err := settings.AddSession(sessionSettings) + if err != nil { + t.Fatalf("Expected no error adding session, got %v", err) + } + + initiator, err := NewInitiator(app, storeFactory, settings, logFactory) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + s, ok := initiator.sessions[sessionID] + if !ok { + t.Fatal("Expected session to be created") + } + + initiator.stopChan = make(chan interface{}) + go initiator.handleConnection(s, nil, &mockDialer{}) + + select { + case <-ctx.Done(): + initiator.Stop() + return + case <-time.After(10 * time.Second): + t.Error("retry stopped after logon error") + return + } +} + +func newSession() *SessionSettings { + sessionSettings := NewSessionSettings() + sessionSettings.Set(config.BeginString, "FIX.4.4") + sessionSettings.Set(config.SenderCompID, "X") + sessionSettings.Set(config.TargetCompID, "X") + sessionSettings.Set(config.HeartBtInt, "30") + sessionSettings.Set(config.SocketConnectHost, "localhost") + sessionSettings.Set(config.SocketConnectPort, "9878") + sessionSettings.Set(config.ReconnectInterval, "1") + return sessionSettings +} + +type mockApplication struct{} + +func (m *mockApplication) OnCreate(_ SessionID) {} +func (m *mockApplication) OnLogon(_ SessionID) {} +func (m *mockApplication) OnLogout(_ SessionID) {} +func (m *mockApplication) ToAdmin(_ *Message, _ SessionID) {} +func (m *mockApplication) ToApp(_ *Message, _ SessionID) error { return nil } +func (m *mockApplication) FromAdmin(_ *Message, _ SessionID) MessageRejectError { + return nil +} +func (m *mockApplication) FromApp(_ *Message, _ SessionID) MessageRejectError { + return nil +} + +type mockMessageStoreFactory struct { + saveMessageAndIncrError error +} + +func (m *mockMessageStoreFactory) Create(_ SessionID) (MessageStore, error) { + return &mockMessageStore{saveMessageAndIncrError: m.saveMessageAndIncrError}, nil +} + +var errDBError = errors.New("db error") + +type mockMessageStore struct { + saveMessageAndIncrError error +} + +func (m *mockMessageStore) NextSenderMsgSeqNum() int { return 1 } +func (m *mockMessageStore) NextTargetMsgSeqNum() int { return 1 } +func (m *mockMessageStore) IncrSenderMsgSeqNum() error { return nil } +func (m *mockMessageStore) IncrTargetMsgSeqNum() error { return nil } +func (m *mockMessageStore) SetNextSenderMsgSeqNum(_ int) error { return nil } +func (m *mockMessageStore) SetNextTargetMsgSeqNum(_ int) error { return nil } +func (m *mockMessageStore) CreationTime() time.Time { return time.Now() } +func (m *mockMessageStore) SaveMessage(_ int, _ []byte) error { return nil } +func (m *mockMessageStore) GetMessages(_, _ int) ([][]byte, error) { return nil, nil } +func (m *mockMessageStore) Refresh() error { return nil } +func (m *mockMessageStore) Reset() error { return nil } +func (m *mockMessageStore) Close() error { return nil } +func (m *mockMessageStore) IncrNextSenderMsgSeqNum() error { return nil } +func (m *mockMessageStore) IncrNextTargetMsgSeqNum() error { return nil } +func (m *mockMessageStore) IterateMessages(int, int, func([]byte) error) error { return nil } +func (m *mockMessageStore) SaveMessageAndIncrNextSenderMsgSeqNum(_ int, _ []byte) error { + return m.saveMessageAndIncrError +} +func (m *mockMessageStore) SetCreationTime(time.Time) {} + +type mockLogFactory struct { + shouldFail bool + onEvent func(string) +} + +func (m *mockLogFactory) Create() (Log, error) { + if m.shouldFail { + return nil, errors.New("log factory error") + } + return &mockLog{ + onEvent: m.onEvent, + }, nil +} + +func (m *mockLogFactory) CreateSessionLog(_ SessionID) (Log, error) { + return &mockLog{ + onEvent: m.onEvent, + }, nil +} + +type mockDialer struct{} + +type mockAddr struct { + network string + address string +} + +func (m *mockAddr) Network() string { return m.network } +func (m *mockAddr) String() string { return m.address } + +type mockConn struct{} + +func (m *mockConn) Read(_ []byte) (n int, err error) { return 0, nil } +func (m *mockConn) Write(_ []byte) (n int, err error) { return 0, nil } +func (m *mockConn) Close() error { return nil } +func (m *mockConn) LocalAddr() net.Addr { return &mockAddr{network: "tcp", address: "127.0.0.1:8080"} } +func (m *mockConn) RemoteAddr() net.Addr { return &mockAddr{network: "tcp", address: "127.0.0.1:9090"} } +func (m *mockConn) SetDeadline(_ time.Time) error { return nil } +func (m *mockConn) SetReadDeadline(_ time.Time) error { return nil } +func (m *mockConn) SetWriteDeadline(_ time.Time) error { return nil } + +func (m *mockDialer) DialContext(_ context.Context, _, _ string) (net.Conn, error) { + return &mockConn{}, nil +} + +type mockLog struct { + onEvent func(string) +} + +func (m *mockLog) OnIncoming(_ []byte) {} +func (m *mockLog) OnOutgoing(_ []byte) {} +func (m *mockLog) OnEvent(s string) { + if m.onEvent != nil { + m.onEvent(s) + } +} +func (m *mockLog) OnEventf(_ string, _ ...interface{}) {} diff --git a/session.go b/session.go index a6d296999..27a06b625 100644 --- a/session.go +++ b/session.go @@ -849,15 +849,15 @@ func (s *session) onAdmin(msg interface{}) { return } - if msg.err != nil { - close(msg.err) - } - s.messageIn = msg.messageIn s.messageOut = msg.messageOut s.sentReset = false - s.Connect(s) + err := s.Connect(s) + if msg.err != nil { + msg.err <- err + close(msg.err) + } case stopReq: s.Stop(s) diff --git a/session_state.go b/session_state.go index 6fe4dded7..8d7182489 100644 --- a/session_state.go +++ b/session_state.go @@ -36,36 +36,37 @@ func (sm *stateMachine) Start(s *session) { sm.CheckSessionTime(s, time.Now()) } -func (sm *stateMachine) Connect(session *session) { +func (sm *stateMachine) Connect(session *session) error { // No special logon logic needed for FIX Acceptors. if !session.InitiateLogon { sm.setState(session, logonState{}) - return + return nil } if session.RefreshOnLogon { if err := session.store.Refresh(); err != nil { session.logError(err) - return + return err } } if session.ResetOnLogon { if err := session.store.Reset(); err != nil { session.logError(err) - return + return err } } session.log.OnEvent("Sending logon request") if err := session.sendLogon(); err != nil { session.logError(err) - return + return err } sm.setState(session, logonState{}) // Fire logon timeout event after the pre-configured delay period. time.AfterFunc(session.LogonTimeout, func() { session.sessionEvent <- internal.LogonTimeout }) + return nil } func (sm *stateMachine) Stop(session *session) { diff --git a/session_test.go b/session_test.go index 2ea5c3508..07e06ecbe 100644 --- a/session_test.go +++ b/session_test.go @@ -17,6 +17,7 @@ package quickfix import ( "bytes" + "fmt" "testing" "time" @@ -743,6 +744,28 @@ func (s *SessionSuite) TestOnAdminConnectRefreshOnLogon() { } } +func (s *SessionSuite) TestOnAdminConnectError() { + dbError := fmt.Errorf("db error") + s.SetupTest() + errChannel := make(chan error, 1) + s.session.RefreshOnLogon = true + adminMsg := connect{ + messageOut: s.Receiver.sendChannel, + err: errChannel, + } + s.session.State = latentState{} + s.session.InitiateLogon = true + s.MockStore.On("Refresh").Return(dbError) + s.MockApp.On("ToAdmin") + s.session.onAdmin(adminMsg) + + err := <-errChannel + s.Assert().Equal(err, dbError) + + s.MockStore.AssertExpectations(s.T()) + +} + func (s *SessionSuite) TestOnAdminConnectAccept() { adminMsg := connect{ messageOut: s.Receiver.sendChannel,