@@ -11,6 +11,7 @@ package mysql
1111import (
1212 "bytes"
1313 "errors"
14+ "fmt"
1415 "net"
1516 "testing"
1617 "time"
@@ -132,31 +133,57 @@ func TestReadPacketSingleByte(t *testing.T) {
132133 }
133134}
134135
136+ type mockLogger struct {
137+ bytes.Buffer
138+ }
139+
140+ func (ml * mockLogger ) Print (v ... any ) {
141+ ml .WriteString (fmt .Sprint (v ... ) + "\n " )
142+ }
143+
135144func TestReadPacketWrongSequenceID (t * testing.T ) {
136145 conn := new (mockConn )
137146 mc := & mysqlConn {
138147 buf : newBuffer (conn ),
148+ cfg : NewConfig (),
139149 }
150+ logger := & mockLogger {}
151+ mc .cfg .Logger = Logger (logger )
140152
141153 // too low sequence id
142154 conn .data = []byte {0x01 , 0x00 , 0x00 , 0x00 , 0xff }
143155 conn .maxReads = 1
144156 mc .sequence = 1
145- _ , err := mc .readPacket ()
146- if err != ErrPktSync {
147- t .Errorf ("expected ErrPktSync, got %v" , err )
157+ data , err := mc .readPacket ()
158+ if err != nil {
159+ t .Errorf ("expected nil, got %v" , err )
160+ }
161+ if len (data ) != 1 || data [0 ] != 0xff {
162+ t .Errorf ("expected [0xff], got % x" , data )
163+ }
164+ logMsg := logger .String ()
165+ if logMsg != ErrPktSync .Error ()+ "\n " {
166+ t .Errorf ("expected ErrPktSync.Error(), got %q" , logMsg )
148167 }
149168
150169 // reset
151170 conn .reads = 0
152171 mc .sequence = 0
153172 mc .buf = newBuffer (conn )
173+ logger .Reset ()
154174
155175 // too high sequence id
156176 conn .data = []byte {0x01 , 0x00 , 0x00 , 0x42 , 0xff }
157- _ , err = mc .readPacket ()
158- if err != ErrPktSyncMul {
159- t .Errorf ("expected ErrPktSyncMul, got %v" , err )
177+ data , err = mc .readPacket ()
178+ if err != nil {
179+ t .Errorf ("expected nil, got %v" , err )
180+ }
181+ if len (data ) != 1 || data [0 ] != 0xff {
182+ t .Errorf ("expected [0xff], got % x" , data )
183+ }
184+ logMsg = logger .String ()
185+ if logMsg != ErrPktSyncMul .Error ()+ "\n " {
186+ t .Errorf ("expected ErrPktSync.Error(), got %q" , logMsg )
160187 }
161188}
162189
0 commit comments