From ddf24e642795181d818ff0e452aa72ae4aa13e0f Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Tue, 22 Oct 2013 10:54:55 +0200 Subject: [PATCH 1/9] use the connection buffer for writing --- benchmark_test.go | 7 +- buffer.go | 51 +++++- packets.go | 390 +++++++++++++++++++++++++++------------------- 3 files changed, 281 insertions(+), 167 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 2fb3f6b2b..67ce9c547 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -69,23 +69,26 @@ func BenchmarkQuery(b *testing.B) { stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?")) defer stmt.Close() - b.StartTimer() remain := int64(b.N) var wg sync.WaitGroup wg.Add(concurrencyLevel) defer wg.Wait() + b.StartTimer() + for i := 0; i < concurrencyLevel; i++ { go func() { - defer wg.Done() for { if atomic.AddInt64(&remain, -1) < 0 { + wg.Done() return } + var got string tb.check(stmt.QueryRow(1).Scan(&got)) if got != "one" { b.Errorf("query = %q; want one", got) + wg.Done() return } } diff --git a/buffer.go b/buffer.go index d6c8b72e4..e47cdf455 100644 --- a/buffer.go +++ b/buffer.go @@ -12,7 +12,10 @@ import "io" const defaultBufSize = 4096 -// A read buffer similar to bufio.Reader but zero-copy-ish +// A buffer which is used for both reading and writing. +// This is possible since communication on each connection is synchronous. +// In other words, we can't write and read simultaneously on the same connection. +// The buffer is similar to bufio.Reader / Writer but zero-copy-ish // Also highly optimized for this particular use case. type buffer struct { buf []byte @@ -37,8 +40,11 @@ func (b *buffer) fill(need int) (err error) { } // grow buffer if necessary + // TODO: let the buffer shrink again at some point + // Maybe keep the org buf slice and swap back? if need > len(b.buf) { - newBuf := make([]byte, need) + // Round up to the next multiple of the default size + newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) copy(newBuf, b.buf) b.buf = newBuf } @@ -74,3 +80,44 @@ func (b *buffer) readNext(need int) (p []byte, err error) { b.length -= need return } + +// returns a buffer with the requested size. +// If possible, a slice from the existing buffer is returned. +// Otherwise a bigger buffer is made. +// Only one buffer (total) can be used at a time. +func (b *buffer) writeBuffer(length int) []byte { + if b.length > 0 { + return nil + } + + // test (cheap) general case first + if length <= defaultBufSize || length <= cap(b.buf) { + return b.buf[:length] + } + + if length < maxPacketSize { + b.buf = make([]byte, length) + return b.buf + } + return make([]byte, length) +} + +// shortcut which can be used if the requested buffer is guaranteed to be +// smaller than defaultBufSize +// Only one buffer (total) can be used at a time. +func (b *buffer) smallWriteBuffer(length int) []byte { + if b.length == 0 { + return b.buf[:length] + } + return nil +} + +// takeCompleteBuffer returns the complete existing buffer. +// This can be used if the necessary buffer size is unknown. +// Only one buffer (total) can be used at a time. +func (b *buffer) takeCompleteBuffer() []byte { + if b.length == 0 { + return b.buf + } + return nil +} diff --git a/packets.go b/packets.go index 9b4c99e8e..d30f644e6 100644 --- a/packets.go +++ b/packets.go @@ -239,8 +239,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { pktLen += n + 1 } - // Calculate packet length and make buffer with that size - data := make([]byte, pktLen+4) + // Calculate packet length and get buffer with that size + data := mc.buf.smallWriteBuffer(pktLen + 4) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print("Busy buffer") + return driver.ErrBadConn + } // ClientFlags [32 bit] data[4] = byte(clientFlags) @@ -249,10 +254,10 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { data[7] = byte(clientFlags >> 24) // MaxPacketSize [32 bit] (none) - //data[8] = 0x00 - //data[9] = 0x00 - //data[10] = 0x00 - //data[11] = 0x00 + data[8] = 0x00 + data[9] = 0x00 + data[10] = 0x00 + data[11] = 0x00 // Charset [1 byte] data[12] = collation_utf8_general_ci @@ -293,7 +298,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { if len(mc.cfg.user) > 0 { pos += copy(data[pos:], mc.cfg.user) } - //data[pos] = 0x00 + data[pos] = 0x00 pos++ // ScrambleBuffer [length encoded integer] @@ -303,7 +308,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { // Databasename [null terminated string] if len(mc.cfg.dbname) > 0 { pos += copy(data[pos:], mc.cfg.dbname) - //data[pos] = 0x00 + data[pos] = 0x00 } // Send Auth packet @@ -318,7 +323,12 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { // Calculate the packet lenght and add a tailing 0 pktLen := len(scrambleBuff) + 1 - data := make([]byte, pktLen+4) + data := mc.buf.smallWriteBuffer(pktLen + 4) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print("Busy buffer") + return driver.ErrBadConn + } // Add the packet header [24bit length + 1 byte sequence] data[0] = byte(pktLen) @@ -340,17 +350,24 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.sequence = 0 + data := mc.buf.smallWriteBuffer(4 + 1) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print("Busy buffer") + return driver.ErrBadConn + } + + // Add the packet header [24bit length + 1 byte sequence] + data[0] = 0x01 // 1 byte long + data[1] = 0x00 + data[2] = 0x00 + data[3] = 0x00 // sequence is always 0 + + // Add command byte + data[4] = command + // Send CMD packet - return mc.writePacket([]byte{ - // Add the packet header [24bit length + 1 byte sequence] - 0x01, // 1 byte long - 0x00, - 0x00, - 0x00, // mc.sequence - - // Add command byte - command, - }) + return mc.writePacket(data) } func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { @@ -358,13 +375,18 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { mc.sequence = 0 pktLen := 1 + len(arg) - data := make([]byte, pktLen+4) + data := mc.buf.writeBuffer(pktLen + 4) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print("Busy buffer") + return driver.ErrBadConn + } // Add the packet header [24bit length + 1 byte sequence] data[0] = byte(pktLen) data[1] = byte(pktLen >> 8) data[2] = byte(pktLen >> 16) - //data[3] = mc.sequence + data[3] = 0x00 // sequence is always 0 // Add command byte data[4] = command @@ -380,23 +402,30 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 + data := mc.buf.smallWriteBuffer(4 + 1 + 4) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print("Busy buffer") + return driver.ErrBadConn + } + + // Add the packet header [24bit length + 1 byte sequence] + data[0] = 0x05 // 1 bytes long + data[1] = 0x00 + data[2] = 0x00 + data[3] = 0x00 // sequence is always 0 + + // Add command byte + data[4] = command + + // Add arg [32 bit] + data[5] = byte(arg) + data[6] = byte(arg >> 8) + data[7] = byte(arg >> 16) + data[8] = byte(arg >> 24) + // Send CMD packet - return mc.writePacket([]byte{ - // Add the packet header [24bit length + 1 byte sequence] - 0x05, // 5 bytes long - 0x00, - 0x00, - 0x00, // mc.sequence - - // Add command byte - command, - - // Add arg [32 bit] - byte(arg), - byte(arg >> 8), - byte(arg >> 16), - byte(arg >> 24), - }) + return mc.writePacket(data) } /****************************************************************************** @@ -599,10 +628,10 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) { // Read Packets as Field Packets until EOF-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::ResultsetRow -func (rows *mysqlRows) readRow(dest []driver.Value) (err error) { +func (rows *mysqlRows) readRow(dest []driver.Value) error { data, err := rows.mc.readPacket() if err != nil { - return + return err } // EOF Packet @@ -641,24 +670,22 @@ func (rows *mysqlRows) readRow(dest []driver.Value) (err error) { continue } } - return // err + return err // err != nil } - return + return nil } // Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read -func (mc *mysqlConn) readUntilEOF() (err error) { - var data []byte - +func (mc *mysqlConn) readUntilEOF() error { for { - data, err = mc.readPacket() + data, err := mc.readPacket() // No Err and no EOF Packet if err == nil && data[0] != iEOF { continue } - return // Err or EOF + return err // Err or EOF } } @@ -710,11 +737,16 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error) } // http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-send-long-data -func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) (err error) { +func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { maxLen := stmt.mc.maxPacketAllowed - 1 pktLen := maxLen argLen := len(arg) + + // Can not use the write buffer since + // a) the buffer is too small + // b) it is in use data := make([]byte, 4+1+4+2+argLen) + copy(data[4+1+4+2:], arg) for argLen > 0 { @@ -742,7 +774,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) (err error) data[10] = byte(paramID >> 8) // Send CMD packet - err = stmt.mc.writePacket(data[:4+pktLen]) + err := stmt.mc.writePacket(data[:4+pktLen]) if err == nil { argLen -= pktLen - (1 + 4 + 2) data = data[pktLen-(1+4+2):] @@ -758,7 +790,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) (err error) } // Execute Prepared Statement -// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-execute +// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if len(args) != stmt.paramCount { return fmt.Errorf( @@ -770,107 +802,32 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // Reset packet-sequence stmt.mc.sequence = 0 - pktLen := 1 + 4 + 1 + 4 + ((stmt.paramCount + 7) >> 3) + 1 + (stmt.paramCount << 1) - paramValues := make([][]byte, stmt.paramCount) - paramTypes := make([]byte, (stmt.paramCount << 1)) - bitMask := uint64(0) - var i int - - for i = range args { - // build NULL-bitmap - if args[i] == nil { - bitMask += 1 << uint(i) - paramTypes[i<<1] = fieldTypeNULL - continue - } - - // cache types and values - switch v := args[i].(type) { - case int64: - paramTypes[i<<1] = fieldTypeLongLong - paramValues[i] = uint64ToBytes(uint64(v)) - pktLen += 8 - continue - - case float64: - paramTypes[i<<1] = fieldTypeDouble - paramValues[i] = uint64ToBytes(math.Float64bits(v)) - pktLen += 8 - continue - - case bool: - paramTypes[i<<1] = fieldTypeTiny - pktLen++ - if v { - paramValues[i] = []byte{0x01} - } else { - paramValues[i] = []byte{0x00} - } - continue - - case []byte: - paramTypes[i<<1] = fieldTypeString - if len(v) < stmt.mc.maxPacketAllowed-pktLen-(stmt.paramCount-(i+1))*64 { - paramValues[i] = append( - lengthEncodedIntegerToBytes(uint64(len(v))), - v..., - ) - pktLen += len(paramValues[i]) - continue - } else { - err := stmt.writeCommandLongData(i, v) - if err == nil { - continue - } - return err - } - - case string: - paramTypes[i<<1] = fieldTypeString - if len(v) < stmt.mc.maxPacketAllowed-pktLen-(stmt.paramCount-(i+1))*64 { - paramValues[i] = append( - lengthEncodedIntegerToBytes(uint64(len(v))), - []byte(v)..., - ) - pktLen += len(paramValues[i]) - continue - } else { - err := stmt.writeCommandLongData(i, []byte(v)) - if err == nil { - continue - } - return err - } - - case time.Time: - paramTypes[i<<1] = fieldTypeString - - var val []byte - if v.IsZero() { - val = []byte("0000-00-00") - } else { - val = []byte(v.In(stmt.mc.cfg.loc).Format(timeFormat)) - } - - paramValues[i] = append( - lengthEncodedIntegerToBytes(uint64(len(val))), - val..., - ) - pktLen += len(paramValues[i]) - continue + var data []byte - default: - return fmt.Errorf("Can't convert type: %T", args[i]) + if len(args) == 0 { + const pktLen = 1 + 4 + 1 + 4 + data = stmt.mc.buf.writeBuffer(4 + pktLen) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print("Busy buffer") + return driver.ErrBadConn } - } - data := make([]byte, pktLen+4) + // packet header [4 bytes] + data[0] = byte(pktLen) + data[1] = byte(pktLen >> 8) + data[2] = byte(pktLen >> 16) + data[3] = 0x00 // sequence is always 0 + } else { + data = stmt.mc.buf.takeCompleteBuffer() + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print("Busy buffer") + return driver.ErrBadConn + } - // packet header [4 bytes] - data[0] = byte(pktLen) - data[1] = byte(pktLen >> 8) - data[2] = byte(pktLen >> 16) - data[3] = stmt.mc.sequence + // header (bytes 0-3) is added after we know the packet size + } // command [1 byte] data[4] = comStmtExecute @@ -882,32 +839,139 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data[8] = byte(stmt.id >> 24) // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte] - //data[9] = 0x00 + data[9] = 0x00 // iteration_count (uint32(1)) [4 bytes] data[10] = 0x01 - //data[11] = 0x00 - //data[12] = 0x00 - //data[13] = 0x00 - - if stmt.paramCount > 0 { - // NULL-bitmap [(param_count+7)/8 bytes] - pos := 14 + ((stmt.paramCount + 7) >> 3) - // Convert bitMask to bytes - for i = 14; i < pos; i++ { - data[i] = byte(bitMask >> uint((i-14)<<3)) - } + data[11] = 0x00 + data[12] = 0x00 + data[13] = 0x00 + + if len(args) > 0 { + // NULL-bitmap [(len(args)+7)/8 bytes] + nullMask := uint64(0) + + pos := 4 + 1 + 4 + 1 + 4 + ((len(args) + 7) >> 3) // newParameterBoundFlag 1 [1 byte] data[pos] = 0x01 pos++ - // type of parameters [param_count*2 bytes] - pos += copy(data[pos:], paramTypes) + // type of each parameter [len(args)*2 bytes] + paramTypes := data[pos:] + pos += (len(args) << 1) + + // value of each parameter [n bytes] + paramValues := data[pos:pos] + valuesCap := cap(paramValues) + + for i := range args { + // build NULL-bitmap + if args[i] == nil { + nullMask += 1 << uint(i) + paramTypes[i+i] = fieldTypeNULL + continue + } + + // cache types and values + switch v := args[i].(type) { + case int64: + paramTypes[i+i] = fieldTypeLongLong + if cap(paramValues) <= len(paramValues)+8 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64(paramValues, uint64(v)) + } else { + paramValues = append(paramValues, + uint64ToBytes(uint64(v))..., + ) + } + + case float64: + paramTypes[i+i] = fieldTypeDouble + if cap(paramValues) <= len(paramValues)+8 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64(paramValues, math.Float64bits(v)) + } else { + paramValues = append(paramValues, + uint64ToBytes(math.Float64bits(v))..., + ) + } + + case bool: + paramTypes[i+i] = fieldTypeTiny + if v { + paramValues = append(paramValues, 0x01) + } else { + paramValues = append(paramValues, 0x00) + } + + case []byte: + paramTypes[i+i] = fieldTypeString + if len(v) < stmt.mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { + paramValues = append(paramValues, + lengthEncodedIntegerToBytes(uint64(len(v)))..., + ) + paramValues = append(paramValues, v...) + } else { + if err := stmt.writeCommandLongData(i, v); err != nil { + return err + } + } + + case string: + paramTypes[i+i] = fieldTypeString + if len(v) < stmt.mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { + paramValues = append(paramValues, + lengthEncodedIntegerToBytes(uint64(len(v)))..., + ) + paramValues = append(paramValues, v...) + } else { + if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { + return err + } + } + + case time.Time: + paramTypes[i+i] = fieldTypeString + + var val []byte + if v.IsZero() { + val = []byte("0000-00-00") + } else { + val = []byte(v.In(stmt.mc.cfg.loc).Format(timeFormat)) + } + + paramValues = append(paramValues, + lengthEncodedIntegerToBytes(uint64(len(val)))..., + ) + paramValues = append(paramValues, val...) + + default: + return fmt.Errorf("Can't convert type: %T", args[i]) + } + } + + // Check if param values exceeded the available buffer + // In that case we must build the data packet with the new values buffer + if valuesCap != cap(paramValues) { + data = append(data[:pos], paramValues...) + stmt.mc.buf.buf = data + } + + pos += len(paramValues) + data = data[:pos] + + pktLen := pos - 4 + + // packet header [4 bytes] + data[0] = byte(pktLen) + data[1] = byte(pktLen >> 8) + data[2] = byte(pktLen >> 16) + data[3] = stmt.mc.sequence - // values for the parameters [n bytes] - for i = range paramValues { - pos += copy(data[pos:], paramValues[i]) + // Convert nullMask to bytes + for i, max := 14, 14+((stmt.paramCount+7)>>3); i < max; i++ { + data[i] = byte(nullMask >> uint((i-14)<<3)) } } From 656614d3b2d1e3a93885e6cee65dfdf532e34885 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Tue, 22 Oct 2013 11:43:53 +0200 Subject: [PATCH 2/9] writeExecutePacket: fix capacity check --- packets.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packets.go b/packets.go index d30f644e6..471ee9815 100644 --- a/packets.go +++ b/packets.go @@ -877,7 +877,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { switch v := args[i].(type) { case int64: paramTypes[i+i] = fieldTypeLongLong - if cap(paramValues) <= len(paramValues)+8 { + if cap(paramValues)-len(paramValues)-8 >= 0 { paramValues = paramValues[:len(paramValues)+8] binary.LittleEndian.PutUint64(paramValues, uint64(v)) } else { @@ -888,7 +888,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { case float64: paramTypes[i+i] = fieldTypeDouble - if cap(paramValues) <= len(paramValues)+8 { + if cap(paramValues)-len(paramValues)-8 >= 0 { paramValues = paramValues[:len(paramValues)+8] binary.LittleEndian.PutUint64(paramValues, math.Float64bits(v)) } else { From 3d95bd01faee84ee252a5ff19407cda98fb1beb3 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Tue, 22 Oct 2013 18:58:58 +0200 Subject: [PATCH 3/9] writeExecutePacket: fix packing --- packets.go | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/packets.go b/packets.go index 471ee9815..da6395da5 100644 --- a/packets.go +++ b/packets.go @@ -868,8 +868,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { for i := range args { // build NULL-bitmap if args[i] == nil { - nullMask += 1 << uint(i) + nullMask |= 1 << uint(i) paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i+1] = 0x00 continue } @@ -877,9 +878,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { switch v := args[i].(type) { case int64: paramTypes[i+i] = fieldTypeLongLong + paramTypes[i+i+1] = 0x00 + if cap(paramValues)-len(paramValues)-8 >= 0 { paramValues = paramValues[:len(paramValues)+8] - binary.LittleEndian.PutUint64(paramValues, uint64(v)) + binary.LittleEndian.PutUint64(paramValues[len(paramValues)-8:], uint64(v)) } else { paramValues = append(paramValues, uint64ToBytes(uint64(v))..., @@ -888,9 +891,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { case float64: paramTypes[i+i] = fieldTypeDouble + paramTypes[i+i+1] = 0x00 + if cap(paramValues)-len(paramValues)-8 >= 0 { paramValues = paramValues[:len(paramValues)+8] - binary.LittleEndian.PutUint64(paramValues, math.Float64bits(v)) + binary.LittleEndian.PutUint64(paramValues[len(paramValues)-8:], math.Float64bits(v)) } else { paramValues = append(paramValues, uint64ToBytes(math.Float64bits(v))..., @@ -899,6 +904,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { case bool: paramTypes[i+i] = fieldTypeTiny + paramTypes[i+i+1] = 0x00 + if v { paramValues = append(paramValues, 0x01) } else { @@ -907,6 +914,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { case []byte: paramTypes[i+i] = fieldTypeString + paramTypes[i+i+1] = 0x00 + if len(v) < stmt.mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { paramValues = append(paramValues, lengthEncodedIntegerToBytes(uint64(len(v)))..., @@ -920,6 +929,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { case string: paramTypes[i+i] = fieldTypeString + paramTypes[i+i+1] = 0x00 + if len(v) < stmt.mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { paramValues = append(paramValues, lengthEncodedIntegerToBytes(uint64(len(v)))..., @@ -933,6 +944,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { case time.Time: paramTypes[i+i] = fieldTypeString + paramTypes[i+i+1] = 0x00 var val []byte if v.IsZero() { From 33d6df2bf4e7bdb45cccbf05ddc650e78c2bffef Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Tue, 22 Oct 2013 23:58:37 +0200 Subject: [PATCH 4/9] various refactoring --- packets.go | 157 ++++++++++++++++++++++------------------------------- 1 file changed, 65 insertions(+), 92 deletions(-) diff --git a/packets.go b/packets.go index da6395da5..84913b6af 100644 --- a/packets.go +++ b/packets.go @@ -23,9 +23,9 @@ import ( // http://dev.mysql.com/doc/internals/en/client-server-protocol.html // Read packet to buffer 'data' -func (mc *mysqlConn) readPacket() (data []byte, err error) { +func (mc *mysqlConn) readPacket() ([]byte, error) { // Read packet header - data, err = mc.buf.readNext(4) + data, err := mc.buf.readNext(4) if err != nil { errLog.Print(err.Error()) mc.Close() @@ -97,7 +97,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { return mc.splitPacket(data) } -func (mc *mysqlConn) splitPacket(data []byte) (err error) { +func (mc *mysqlConn) splitPacket(data []byte) error { pktLen := len(data) - 4 if pktLen > mc.maxPacketAllowed { @@ -141,10 +141,10 @@ func (mc *mysqlConn) splitPacket(data []byte) (err error) { // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake -func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) { +func (mc *mysqlConn) readInitPacket() ([]byte, error) { data, err := mc.readPacket() if err != nil { - return + return nil, err } if data[0] == iERR { @@ -153,11 +153,11 @@ func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) { // protocol version [1 byte] if data[0] < minProtocolVersion { - err = fmt.Errorf( + return nil, fmt.Errorf( "Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required", data[0], - minProtocolVersion) - return + minProtocolVersion, + ) } // server version [null terminated string] @@ -165,7 +165,7 @@ func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) { pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 // first part of the password cipher [8 bytes] - cipher = data[pos : pos+8] + cipher := data[pos : pos+8] // (filler) always 0x00 [1 byte] pos += 8 + 1 @@ -205,7 +205,7 @@ func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) { //return errMalformPkt } - return + return cipher, nil } // Client Authentication Packet @@ -497,7 +497,7 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { // SQL State [optional: # + 5bytes string] //sqlstate := string(data[pos : pos+6]) - if data[pos] == 0x23 { + if data[3] == 0x23 { pos = 9 } @@ -510,7 +510,7 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { // Ok Packet // http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet -func (mc *mysqlConn) handleOkPacket(data []byte) (err error) { +func (mc *mysqlConn) handleOkPacket(data []byte) error { var n, m int // 0x00 [1 byte] @@ -525,72 +525,66 @@ func (mc *mysqlConn) handleOkPacket(data []byte) (err error) { // warning count [2 bytes] if !mc.strict { - return + return nil } else { pos := 1 + n + m + 2 if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { - err = mc.getWarnings() + return mc.getWarnings() } + return nil } - - // message [until end of packet] - return } // Read Packets as Field Packets until EOF-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-Protocol::ColumnDefinition41 -func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) { - var data []byte - var i, pos, n int - var name []byte - - columns = make([]mysqlField, count) +func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { + columns := make([]mysqlField, count) - for { - data, err = mc.readPacket() + for i := 0; ; i++ { + data, err := mc.readPacket() if err != nil { - return + return nil, err } // EOF Packet if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { - if i != count { - err = fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns)) + if i == count { + return columns, nil } - return + return nil, fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns)) } // Catalog - pos, err = skipLengthEnodedString(data) + pos, err := skipLengthEnodedString(data) if err != nil { - return + return nil, err } // Database [len coded string] - n, err = skipLengthEnodedString(data[pos:]) + n, err := skipLengthEnodedString(data[pos:]) if err != nil { - return + return nil, err } pos += n // Table [len coded string] n, err = skipLengthEnodedString(data[pos:]) if err != nil { - return + return nil, err } pos += n // Original table [len coded string] n, err = skipLengthEnodedString(data[pos:]) if err != nil { - return + return nil, err } pos += n // Name [len coded string] - name, _, n, err = readLengthEnodedString(data[pos:]) + name, _, n, err := readLengthEnodedString(data[pos:]) if err != nil { - return + return nil, err } columns[i].name = string(name) pos += n @@ -598,7 +592,7 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) { // Original name [len coded string] n, err = skipLengthEnodedString(data[pos:]) if err != nil { - return + return nil, err } // Filler [1 byte] @@ -621,8 +615,6 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) { //if pos < len(data) { // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) //} - - i++ } } @@ -656,7 +648,10 @@ func (rows *mysqlRows) readRow(dest []driver.Value) error { switch rows.columns[i].fieldType { case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeDate, fieldTypeNewDate: - dest[i], err = parseDateTime(string(dest[i].([]byte)), rows.mc.cfg.loc) + dest[i], err = parseDateTime( + string(dest[i].([]byte)), + rows.mc.cfg.loc, + ) if err == nil { continue } @@ -695,61 +690,52 @@ func (mc *mysqlConn) readUntilEOF() error { // Prepare Result Packets // http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-prepare-response -func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error) { +func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { data, err := stmt.mc.readPacket() if err == nil { - // Position - pos := 0 - // packet indicator [1 byte] - if data[pos] != iOK { - err = stmt.mc.handleErrorPacket(data) - return + if data[0] != iOK { + return 0, stmt.mc.handleErrorPacket(data) } - pos++ // statement id [4 bytes] - stmt.id = binary.LittleEndian.Uint32(data[pos : pos+4]) - pos += 4 + stmt.id = binary.LittleEndian.Uint32(data[1 : 1+4]) // Column count [16 bit uint] - columnCount = binary.LittleEndian.Uint16(data[pos : pos+2]) - pos += 2 + columnCount := binary.LittleEndian.Uint16(data[1+4 : 1+4+2]) // Param count [16 bit uint] - stmt.paramCount = int(binary.LittleEndian.Uint16(data[pos : pos+2])) - pos += 2 + stmt.paramCount = int(binary.LittleEndian.Uint16(data[1+4+2 : 1+4+2+2])) // Reserved [8 bit] - pos++ // Warning count [16 bit uint] if !stmt.mc.strict { - return + return columnCount, nil } else { // Check for warnings count > 0, only available in MySQL > 4.1 - if len(data) >= 12 && binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { - err = stmt.mc.getWarnings() + if len(data) >= 12 && binary.LittleEndian.Uint16(data[1+4+2+2+1:1+4+2+2+1+2]) > 0 { + return columnCount, stmt.mc.getWarnings() } + return columnCount, nil } } - return + return 0, err } // http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-send-long-data func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { maxLen := stmt.mc.maxPacketAllowed - 1 pktLen := maxLen - argLen := len(arg) // Can not use the write buffer since // a) the buffer is too small // b) it is in use - data := make([]byte, 4+1+4+2+argLen) + data := make([]byte, 4+1+4+2+len(arg)) copy(data[4+1+4+2:], arg) - for argLen > 0 { + for argLen := len(arg); argLen > 0; argLen -= pktLen - (1 + 4 + 2) { if 1+4+2+argLen < maxLen { pktLen = 1 + 4 + 2 + argLen } @@ -776,7 +762,6 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { // Send CMD packet err := stmt.mc.writePacket(data[:4+pktLen]) if err == nil { - argLen -= pktLen - (1 + 4 + 2) data = data[pktLen-(1+4+2):] continue } @@ -796,7 +781,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { return fmt.Errorf( "Arguments count mismatch (Got: %d Has: %d)", len(args), - stmt.paramCount) + stmt.paramCount, + ) } // Reset packet-sequence @@ -991,10 +977,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } // http://dev.mysql.com/doc/internals/en/prepared-statements.html#packet-ProtocolBinary::ResultsetRow -func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { +func (rows *mysqlRows) readBinaryRow(dest []driver.Value) error { data, err := rows.mc.readPacket() if err != nil { - return + return err } // packet indicator [1 byte] @@ -1010,22 +996,16 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] pos := 1 + (len(dest)+7+2)>>3 - nullBitMap := data[1:pos] - - // values [rest] - var n int - var unsigned bool + nullMask := data[1:pos] for i := range dest { // Field is NULL // (byte >> bit-pos) % 2 == 1 - if ((nullBitMap[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 { + if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 { dest[i] = nil continue } - unsigned = rows.columns[i].flags&flagUnsigned != 0 - // Convert to byte-coded string switch rows.columns[i].fieldType { case fieldTypeNULL: @@ -1034,7 +1014,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { // Numeric Types case fieldTypeTiny: - if unsigned { + if rows.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(data[pos]) } else { dest[i] = int64(int8(data[pos])) @@ -1043,7 +1023,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { continue case fieldTypeShort, fieldTypeYear: - if unsigned { + if rows.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) } else { dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) @@ -1052,7 +1032,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { continue case fieldTypeInt24, fieldTypeLong: - if unsigned { + if rows.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) } else { dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) @@ -1061,7 +1041,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { continue case fieldTypeLongLong: - if unsigned { + if rows.columns[i].flags&flagUnsigned != 0 { val := binary.LittleEndian.Uint64(data[pos : pos+8]) if val > math.MaxInt64 { dest[i] = uint64ToString(val) @@ -1090,6 +1070,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, fieldTypeVarString, fieldTypeString, fieldTypeGeometry: var isNull bool + var n int dest[i], isNull, n, err = readLengthEnodedString(data[pos:]) pos += n if err == nil { @@ -1100,14 +1081,11 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { continue } } - return // err + return err // Date YYYY-MM-DD case fieldTypeDate, fieldTypeNewDate: - var num uint64 - var isNull bool - num, isNull, n = readLengthEncodedInteger(data[pos:]) - + num, isNull, n := readLengthEncodedInteger(data[pos:]) pos += n if isNull { @@ -1130,10 +1108,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { // Time [-][H]HH:MM:SS[.fractal] case fieldTypeTime: - var num uint64 - var isNull bool - num, isNull, n = readLengthEncodedInteger(data[pos:]) - + num, isNull, n := readLengthEncodedInteger(data[pos:]) pos += n if num == 0 { @@ -1179,9 +1154,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { // Timestamp YYYY-MM-DD HH:MM:SS[.fractal] case fieldTypeTimestamp, fieldTypeDateTime: - var num uint64 - var isNull bool - num, isNull, n = readLengthEncodedInteger(data[pos:]) + num, isNull, n := readLengthEncodedInteger(data[pos:]) pos += n @@ -1209,5 +1182,5 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { } } - return + return nil } From 5975ca92129ebc024aa295759dc9e8ad19687a6f Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Wed, 23 Oct 2013 13:17:59 +0200 Subject: [PATCH 5/9] more refactoring Try to remove unnecessary indirections and initialisations with zero. Also update links to the MySQL doc --- connection.go | 57 ++++++++++++++++++++++++--------------------------- packets.go | 57 +++++++++++++++++++++++++++------------------------ rows.go | 6 +++--- statement.go | 37 ++++++++++++++++----------------- 4 files changed, 78 insertions(+), 79 deletions(-) diff --git a/connection.go b/connection.go index 36a27b572..f769f7869 100644 --- a/connection.go +++ b/connection.go @@ -136,14 +136,14 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { columnCount, err := stmt.readPrepareResultPacket() if err == nil { if stmt.paramCount > 0 { - stmt.params, err = stmt.mc.readColumns(stmt.paramCount) + stmt.params, err = mc.readColumns(stmt.paramCount) if err != nil { return nil, err } } if columnCount > 0 { - err = stmt.mc.readUntilEOF() + err = mc.readUntilEOF() } } @@ -171,26 +171,24 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err } // Internal function to execute commands -func (mc *mysqlConn) exec(query string) (err error) { +func (mc *mysqlConn) exec(query string) error { // Send command - err = mc.writeCommandPacketStr(comQuery, query) + err := mc.writeCommandPacketStr(comQuery, query) if err != nil { - return + return err } // Read Result - var resLen int - resLen, err = mc.readResultSetHeaderPacket() + resLen, err := mc.readResultSetHeaderPacket() if err == nil && resLen > 0 { - err = mc.readUntilEOF() - if err != nil { - return + if err = mc.readUntilEOF(); err != nil { + return err } err = mc.readUntilEOF() } - return + return err } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { @@ -211,7 +209,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro return rows, err } } - return nil, err } @@ -221,29 +218,29 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro // Gets the value of the given MySQL System Variable // The returned byte slice is only valid until the next read -func (mc *mysqlConn) getSystemVar(name string) (val []byte, err error) { +func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { // Send command - err = mc.writeCommandPacketStr(comQuery, "SELECT @@"+name) + if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { + return nil, err + } + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() if err == nil { - // Read Result - var resLen int - resLen, err = mc.readResultSetHeaderPacket() - if err == nil { - rows := &mysqlRows{mc, false, nil, false} + rows := &mysqlRows{mc, false, nil, false} - if resLen > 0 { - // Columns - rows.columns, err = mc.readColumns(resLen) + if resLen > 0 { + // Columns + rows.columns, err = mc.readColumns(resLen) + if err != nil { + return nil, err } + } - dest := make([]driver.Value, resLen) - err = rows.readRow(dest) - if err == nil { - val = dest[0].([]byte) - err = mc.readUntilEOF() - } + dest := make([]driver.Value, resLen) + if err = rows.readRow(dest); err == nil { + return dest[0].([]byte), mc.readUntilEOF() } } - - return + return nil, err } diff --git a/packets.go b/packets.go index 84913b6af..a14b7761e 100644 --- a/packets.go +++ b/packets.go @@ -140,7 +140,7 @@ func (mc *mysqlConn) splitPacket(data []byte) error { ******************************************************************************/ // Handshake Initialization Packet -// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake func (mc *mysqlConn) readInitPacket() ([]byte, error) { data, err := mc.readPacket() if err != nil { @@ -197,7 +197,6 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // TODO: Verify string termination // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) // \NUL otherwise - // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake // //if data[len(data)-1] == 0 { // return @@ -209,7 +208,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { } // Client Authentication Packet -// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | @@ -263,7 +262,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { data[12] = collation_utf8_general_ci // SSL Connection Request Packet - // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::SSLRequest + // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest if mc.cfg.tls != nil { // Packet header [24bit length + 1 byte sequence] data[0] = byte((4 + 4 + 1 + 23)) @@ -316,7 +315,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { } // Client old authentication packet -// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::AuthSwitchResponse +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { // User password scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd)) @@ -454,7 +453,7 @@ func (mc *mysqlConn) readResultOK() error { } // Result Set Header Packet -// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::Resultset +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { data, err := mc.readPacket() if err == nil { @@ -482,7 +481,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { } // Error Packet -// http://dev.mysql.com/doc/internals/en/overview.html#packet-ERR_Packet +// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet func (mc *mysqlConn) handleErrorPacket(data []byte) error { if data[0] != iERR { return errMalformPkt @@ -509,7 +508,7 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { } // Ok Packet -// http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet +// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet func (mc *mysqlConn) handleOkPacket(data []byte) error { var n, m int @@ -536,7 +535,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { } // Read Packets as Field Packets until EOF-Packet or an Error appears -// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-Protocol::ColumnDefinition41 +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { columns := make([]mysqlField, count) @@ -619,9 +618,11 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { } // Read Packets as Field Packets until EOF-Packet or an Error appears -// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::ResultsetRow +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow func (rows *mysqlRows) readRow(dest []driver.Value) error { - data, err := rows.mc.readPacket() + mc := rows.mc + + data, err := mc.readPacket() if err != nil { return err } @@ -642,7 +643,7 @@ func (rows *mysqlRows) readRow(dest []driver.Value) error { pos += n if err == nil { if !isNull { - if !rows.mc.parseTime { + if !mc.parseTime { continue } else { switch rows.columns[i].fieldType { @@ -650,7 +651,7 @@ func (rows *mysqlRows) readRow(dest []driver.Value) error { fieldTypeDate, fieldTypeNewDate: dest[i], err = parseDateTime( string(dest[i].([]byte)), - rows.mc.cfg.loc, + mc.cfg.loc, ) if err == nil { continue @@ -689,7 +690,7 @@ func (mc *mysqlConn) readUntilEOF() error { ******************************************************************************/ // Prepare Result Packets -// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-prepare-response +// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { data, err := stmt.mc.readPacket() if err == nil { @@ -723,7 +724,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { return 0, err } -// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-send-long-data +// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { maxLen := stmt.mc.maxPacketAllowed - 1 pktLen := maxLen @@ -785,14 +786,16 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ) } + mc := stmt.mc + // Reset packet-sequence - stmt.mc.sequence = 0 + mc.sequence = 0 var data []byte if len(args) == 0 { const pktLen = 1 + 4 + 1 + 4 - data = stmt.mc.buf.writeBuffer(4 + pktLen) + data = mc.buf.writeBuffer(4 + pktLen) if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print("Busy buffer") @@ -805,7 +808,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data[2] = byte(pktLen >> 16) data[3] = 0x00 // sequence is always 0 } else { - data = stmt.mc.buf.takeCompleteBuffer() + data = mc.buf.takeCompleteBuffer() if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print("Busy buffer") @@ -902,7 +905,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { paramTypes[i+i] = fieldTypeString paramTypes[i+i+1] = 0x00 - if len(v) < stmt.mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { paramValues = append(paramValues, lengthEncodedIntegerToBytes(uint64(len(v)))..., ) @@ -917,7 +920,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { paramTypes[i+i] = fieldTypeString paramTypes[i+i+1] = 0x00 - if len(v) < stmt.mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { paramValues = append(paramValues, lengthEncodedIntegerToBytes(uint64(len(v)))..., ) @@ -936,7 +939,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if v.IsZero() { val = []byte("0000-00-00") } else { - val = []byte(v.In(stmt.mc.cfg.loc).Format(timeFormat)) + val = []byte(v.In(mc.cfg.loc).Format(timeFormat)) } paramValues = append(paramValues, @@ -953,7 +956,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In that case we must build the data packet with the new values buffer if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) - stmt.mc.buf.buf = data + mc.buf.buf = data } pos += len(paramValues) @@ -965,18 +968,18 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data[0] = byte(pktLen) data[1] = byte(pktLen >> 8) data[2] = byte(pktLen >> 16) - data[3] = stmt.mc.sequence + data[3] = mc.sequence // Convert nullMask to bytes - for i, max := 14, 14+((stmt.paramCount+7)>>3); i < max; i++ { - data[i] = byte(nullMask >> uint((i-14)<<3)) + for i, max := 0, (stmt.paramCount+7)>>3; i < max; i++ { + data[i+14] = byte(nullMask >> uint(i<<3)) } } - return stmt.mc.writePacket(data) + return mc.writePacket(data) } -// http://dev.mysql.com/doc/internals/en/prepared-statements.html#packet-ProtocolBinary::ResultsetRow +// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html func (rows *mysqlRows) readBinaryRow(dest []driver.Value) error { data, err := rows.mc.readPacket() if err != nil { diff --git a/rows.go b/rows.go index 726f02ac0..71dcfb1a7 100644 --- a/rows.go +++ b/rows.go @@ -26,12 +26,12 @@ type mysqlRows struct { eof bool } -func (rows *mysqlRows) Columns() (columns []string) { - columns = make([]string, len(rows.columns)) +func (rows *mysqlRows) Columns() []string { + columns := make([]string, len(rows.columns)) for i := range columns { columns[i] = rows.columns[i].name } - return + return columns } func (rows *mysqlRows) Close() (err error) { diff --git a/statement.go b/statement.go index 56b12533f..025f2ecf5 100644 --- a/statement.go +++ b/statement.go @@ -19,14 +19,14 @@ type mysqlStmt struct { params []mysqlField } -func (stmt *mysqlStmt) Close() (err error) { +func (stmt *mysqlStmt) Close() error { if stmt.mc == nil || stmt.mc.netConn == nil { return errInvalidConn } - err = stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) + err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) stmt.mc = nil - return + return err } func (stmt *mysqlStmt) NumInput() int { @@ -34,33 +34,34 @@ func (stmt *mysqlStmt) NumInput() int { } func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { - stmt.mc.affectedRows = 0 - stmt.mc.insertId = 0 - // Send command err := stmt.writeExecutePacket(args) if err != nil { return nil, err } + mc := stmt.mc + + mc.affectedRows = 0 + mc.insertId = 0 + // Read Result - var resLen int - resLen, err = stmt.mc.readResultSetHeaderPacket() + resLen, err := mc.readResultSetHeaderPacket() if err == nil { if resLen > 0 { // Columns - err = stmt.mc.readUntilEOF() + err = mc.readUntilEOF() if err != nil { return nil, err } // Rows - err = stmt.mc.readUntilEOF() + err = mc.readUntilEOF() } if err == nil { return &mysqlResult{ - affectedRows: int64(stmt.mc.affectedRows), - insertId: int64(stmt.mc.insertId), + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), }, nil } } @@ -75,21 +76,19 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { return nil, err } + mc := stmt.mc + // Read Result - var resLen int - resLen, err = stmt.mc.readResultSetHeaderPacket() + resLen, err := mc.readResultSetHeaderPacket() if err != nil { return nil, err } - rows := &mysqlRows{stmt.mc, true, nil, false} + rows := &mysqlRows{mc, true, nil, false} if resLen > 0 { // Columns - rows.columns, err = stmt.mc.readColumns(resLen) - if err != nil { - return nil, err - } + rows.columns, err = mc.readColumns(resLen) } return rows, err From 8751b72867eb0d06af3c5f8a2a61e29d475760b6 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Wed, 23 Oct 2013 17:32:44 +0200 Subject: [PATCH 6/9] buffer: rename take buffer funcs --- buffer.go | 4 ++-- packets.go | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/buffer.go b/buffer.go index e47cdf455..ed13fa283 100644 --- a/buffer.go +++ b/buffer.go @@ -85,7 +85,7 @@ func (b *buffer) readNext(need int) (p []byte, err error) { // If possible, a slice from the existing buffer is returned. // Otherwise a bigger buffer is made. // Only one buffer (total) can be used at a time. -func (b *buffer) writeBuffer(length int) []byte { +func (b *buffer) takeBuffer(length int) []byte { if b.length > 0 { return nil } @@ -105,7 +105,7 @@ func (b *buffer) writeBuffer(length int) []byte { // shortcut which can be used if the requested buffer is guaranteed to be // smaller than defaultBufSize // Only one buffer (total) can be used at a time. -func (b *buffer) smallWriteBuffer(length int) []byte { +func (b *buffer) takeSmallBuffer(length int) []byte { if b.length == 0 { return b.buf[:length] } diff --git a/packets.go b/packets.go index a14b7761e..6db9e5d06 100644 --- a/packets.go +++ b/packets.go @@ -239,7 +239,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { } // Calculate packet length and get buffer with that size - data := mc.buf.smallWriteBuffer(pktLen + 4) + data := mc.buf.takeSmallBuffer(pktLen + 4) if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print("Busy buffer") @@ -322,7 +322,7 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { // Calculate the packet lenght and add a tailing 0 pktLen := len(scrambleBuff) + 1 - data := mc.buf.smallWriteBuffer(pktLen + 4) + data := mc.buf.takeSmallBuffer(pktLen + 4) if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print("Busy buffer") @@ -349,7 +349,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.smallWriteBuffer(4 + 1) + data := mc.buf.takeSmallBuffer(4 + 1) if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print("Busy buffer") @@ -374,7 +374,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { mc.sequence = 0 pktLen := 1 + len(arg) - data := mc.buf.writeBuffer(pktLen + 4) + data := mc.buf.takeBuffer(pktLen + 4) if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print("Busy buffer") @@ -401,7 +401,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.smallWriteBuffer(4 + 1 + 4) + data := mc.buf.takeSmallBuffer(4 + 1 + 4) if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print("Busy buffer") @@ -795,7 +795,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if len(args) == 0 { const pktLen = 1 + 4 + 1 + 4 - data = mc.buf.writeBuffer(4 + pktLen) + data = mc.buf.takeBuffer(4 + pktLen) if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print("Busy buffer") From 605647ed915515f124b755441a68205274ac67d3 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Wed, 23 Oct 2013 17:34:47 +0200 Subject: [PATCH 7/9] changelog: Add write buffer --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 72c71d429..764b1c3ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,8 @@ Changes: - Made closing rows and connections error tolerant. This allows for example deferring rows.Close() without checking for errors - New Logo - Changed the copyright header to include all contributors - - Optimized the read buffer + - Optimized the buffer for reading + - Use the buffer also for writing. This results in zero allocations (by the driver) for most queries - Improved the LOAD INFILE documentation - The driver struct is now exported to make the driver directly accessible - Refactored the driver tests From d8e6c384d4d4e681abd0a4e09bf86150b3a2345d Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Wed, 23 Oct 2013 18:42:11 +0200 Subject: [PATCH 8/9] writeCommandPacketUint32: fix packet length comment --- packets.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packets.go b/packets.go index 6db9e5d06..129e6ac33 100644 --- a/packets.go +++ b/packets.go @@ -409,7 +409,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { } // Add the packet header [24bit length + 1 byte sequence] - data[0] = 0x05 // 1 bytes long + data[0] = 0x05 // 5 bytes long data[1] = 0x00 data[2] = 0x00 data[3] = 0x00 // sequence is always 0 From 228ba3461b36875338ec7967f4a0f38cc4f130a1 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Thu, 24 Oct 2013 02:56:56 +0200 Subject: [PATCH 9/9] packets: YAR (yet another refactoring) --- packets.go | 50 +++++++++++++++++++++++++++++++------------------- 1 file changed, 31 insertions(+), 19 deletions(-) diff --git a/packets.go b/packets.go index 129e6ac33..731749cd2 100644 --- a/packets.go +++ b/packets.go @@ -360,7 +360,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { data[0] = 0x01 // 1 byte long data[1] = 0x00 data[2] = 0x00 - data[3] = 0x00 // sequence is always 0 + data[3] = 0x00 // new command, sequence id is always 0 // Add command byte data[4] = command @@ -385,7 +385,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { data[0] = byte(pktLen) data[1] = byte(pktLen >> 8) data[2] = byte(pktLen >> 16) - data[3] = 0x00 // sequence is always 0 + data[3] = 0x00 // new command, sequence id is always 0 // Add command byte data[4] = command @@ -412,7 +412,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { data[0] = 0x05 // 5 bytes long data[1] = 0x00 data[2] = 0x00 - data[3] = 0x00 // sequence is always 0 + data[3] = 0x00 // new command, sequence id is always 0 // Add command byte data[4] = command @@ -495,8 +495,8 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { pos := 3 // SQL State [optional: # + 5bytes string] - //sqlstate := string(data[pos : pos+6]) if data[3] == 0x23 { + //sqlstate := string(data[4 : 4+5]) pos = 9 } @@ -700,13 +700,13 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { } // statement id [4 bytes] - stmt.id = binary.LittleEndian.Uint32(data[1 : 1+4]) + stmt.id = binary.LittleEndian.Uint32(data[1:5]) // Column count [16 bit uint] - columnCount := binary.LittleEndian.Uint16(data[1+4 : 1+4+2]) + columnCount := binary.LittleEndian.Uint16(data[5:7]) // Param count [16 bit uint] - stmt.paramCount = int(binary.LittleEndian.Uint16(data[1+4+2 : 1+4+2+2])) + stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9])) // Reserved [8 bit] @@ -715,7 +715,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { return columnCount, nil } else { // Check for warnings count > 0, only available in MySQL > 4.1 - if len(data) >= 12 && binary.LittleEndian.Uint16(data[1+4+2+2+1:1+4+2+2+1+2]) > 0 { + if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { return columnCount, stmt.mc.getWarnings() } return columnCount, nil @@ -729,16 +729,22 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { maxLen := stmt.mc.maxPacketAllowed - 1 pktLen := maxLen + // After the header (bytes 0-3) follows before the data: + // 1 byte command + // 4 bytes stmtID + // 2 bytes paramID + const dataOffset = 1 + 4 + 2 + // Can not use the write buffer since // a) the buffer is too small // b) it is in use data := make([]byte, 4+1+4+2+len(arg)) - copy(data[4+1+4+2:], arg) + copy(data[4+dataOffset:], arg) - for argLen := len(arg); argLen > 0; argLen -= pktLen - (1 + 4 + 2) { - if 1+4+2+argLen < maxLen { - pktLen = 1 + 4 + 2 + argLen + for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset { + if dataOffset+argLen < maxLen { + pktLen = dataOffset + argLen } // Add the packet header [24bit length + 1 byte sequence] @@ -763,7 +769,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { // Send CMD packet err := stmt.mc.writePacket(data[:4+pktLen]) if err == nil { - data = data[pktLen-(1+4+2):] + data = data[pktLen-dataOffset:] continue } return err @@ -806,7 +812,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data[0] = byte(pktLen) data[1] = byte(pktLen >> 8) data[2] = byte(pktLen >> 16) - data[3] = 0x00 // sequence is always 0 + data[3] = 0x00 // new command, sequence id is always 0 } else { data = mc.buf.takeCompleteBuffer() if data == nil { @@ -871,7 +877,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if cap(paramValues)-len(paramValues)-8 >= 0 { paramValues = paramValues[:len(paramValues)+8] - binary.LittleEndian.PutUint64(paramValues[len(paramValues)-8:], uint64(v)) + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + uint64(v), + ) } else { paramValues = append(paramValues, uint64ToBytes(uint64(v))..., @@ -884,7 +893,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if cap(paramValues)-len(paramValues)-8 >= 0 { paramValues = paramValues[:len(paramValues)+8] - binary.LittleEndian.PutUint64(paramValues[len(paramValues)-8:], math.Float64bits(v)) + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + math.Float64bits(v), + ) } else { paramValues = append(paramValues, uint64ToBytes(math.Float64bits(v))..., @@ -991,10 +1003,10 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) error { // EOF Packet if data[0] == iEOF && len(data) == 5 { return io.EOF - } else { - // Error otherwise - return rows.mc.handleErrorPacket(data) } + + // Error otherwise + return rows.mc.handleErrorPacket(data) } // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]