diff --git a/CHANGELOG.md b/CHANGELOG.md index 72c71d429..764b1c3ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,8 @@ Changes: - Made closing rows and connections error tolerant. This allows for example deferring rows.Close() without checking for errors - New Logo - Changed the copyright header to include all contributors - - Optimized the read buffer + - Optimized the buffer for reading + - Use the buffer also for writing. This results in zero allocations (by the driver) for most queries - Improved the LOAD INFILE documentation - The driver struct is now exported to make the driver directly accessible - Refactored the driver tests diff --git a/benchmark_test.go b/benchmark_test.go index 2fb3f6b2b..67ce9c547 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -69,23 +69,26 @@ func BenchmarkQuery(b *testing.B) { stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?")) defer stmt.Close() - b.StartTimer() remain := int64(b.N) var wg sync.WaitGroup wg.Add(concurrencyLevel) defer wg.Wait() + b.StartTimer() + for i := 0; i < concurrencyLevel; i++ { go func() { - defer wg.Done() for { if atomic.AddInt64(&remain, -1) < 0 { + wg.Done() return } + var got string tb.check(stmt.QueryRow(1).Scan(&got)) if got != "one" { b.Errorf("query = %q; want one", got) + wg.Done() return } } diff --git a/buffer.go b/buffer.go index d6c8b72e4..ed13fa283 100644 --- a/buffer.go +++ b/buffer.go @@ -12,7 +12,10 @@ import "io" const defaultBufSize = 4096 -// A read buffer similar to bufio.Reader but zero-copy-ish +// A buffer which is used for both reading and writing. +// This is possible since communication on each connection is synchronous. +// In other words, we can't write and read simultaneously on the same connection. +// The buffer is similar to bufio.Reader / Writer but zero-copy-ish // Also highly optimized for this particular use case. type buffer struct { buf []byte @@ -37,8 +40,11 @@ func (b *buffer) fill(need int) (err error) { } // grow buffer if necessary + // TODO: let the buffer shrink again at some point + // Maybe keep the org buf slice and swap back? if need > len(b.buf) { - newBuf := make([]byte, need) + // Round up to the next multiple of the default size + newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) copy(newBuf, b.buf) b.buf = newBuf } @@ -74,3 +80,44 @@ func (b *buffer) readNext(need int) (p []byte, err error) { b.length -= need return } + +// returns a buffer with the requested size. +// If possible, a slice from the existing buffer is returned. +// Otherwise a bigger buffer is made. +// Only one buffer (total) can be used at a time. +func (b *buffer) takeBuffer(length int) []byte { + if b.length > 0 { + return nil + } + + // test (cheap) general case first + if length <= defaultBufSize || length <= cap(b.buf) { + return b.buf[:length] + } + + if length < maxPacketSize { + b.buf = make([]byte, length) + return b.buf + } + return make([]byte, length) +} + +// shortcut which can be used if the requested buffer is guaranteed to be +// smaller than defaultBufSize +// Only one buffer (total) can be used at a time. +func (b *buffer) takeSmallBuffer(length int) []byte { + if b.length == 0 { + return b.buf[:length] + } + return nil +} + +// takeCompleteBuffer returns the complete existing buffer. +// This can be used if the necessary buffer size is unknown. +// Only one buffer (total) can be used at a time. +func (b *buffer) takeCompleteBuffer() []byte { + if b.length == 0 { + return b.buf + } + return nil +} diff --git a/connection.go b/connection.go index 36a27b572..f769f7869 100644 --- a/connection.go +++ b/connection.go @@ -136,14 +136,14 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { columnCount, err := stmt.readPrepareResultPacket() if err == nil { if stmt.paramCount > 0 { - stmt.params, err = stmt.mc.readColumns(stmt.paramCount) + stmt.params, err = mc.readColumns(stmt.paramCount) if err != nil { return nil, err } } if columnCount > 0 { - err = stmt.mc.readUntilEOF() + err = mc.readUntilEOF() } } @@ -171,26 +171,24 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err } // Internal function to execute commands -func (mc *mysqlConn) exec(query string) (err error) { +func (mc *mysqlConn) exec(query string) error { // Send command - err = mc.writeCommandPacketStr(comQuery, query) + err := mc.writeCommandPacketStr(comQuery, query) if err != nil { - return + return err } // Read Result - var resLen int - resLen, err = mc.readResultSetHeaderPacket() + resLen, err := mc.readResultSetHeaderPacket() if err == nil && resLen > 0 { - err = mc.readUntilEOF() - if err != nil { - return + if err = mc.readUntilEOF(); err != nil { + return err } err = mc.readUntilEOF() } - return + return err } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { @@ -211,7 +209,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro return rows, err } } - return nil, err } @@ -221,29 +218,29 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro // Gets the value of the given MySQL System Variable // The returned byte slice is only valid until the next read -func (mc *mysqlConn) getSystemVar(name string) (val []byte, err error) { +func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { // Send command - err = mc.writeCommandPacketStr(comQuery, "SELECT @@"+name) + if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { + return nil, err + } + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() if err == nil { - // Read Result - var resLen int - resLen, err = mc.readResultSetHeaderPacket() - if err == nil { - rows := &mysqlRows{mc, false, nil, false} + rows := &mysqlRows{mc, false, nil, false} - if resLen > 0 { - // Columns - rows.columns, err = mc.readColumns(resLen) + if resLen > 0 { + // Columns + rows.columns, err = mc.readColumns(resLen) + if err != nil { + return nil, err } + } - dest := make([]driver.Value, resLen) - err = rows.readRow(dest) - if err == nil { - val = dest[0].([]byte) - err = mc.readUntilEOF() - } + dest := make([]driver.Value, resLen) + if err = rows.readRow(dest); err == nil { + return dest[0].([]byte), mc.readUntilEOF() } } - - return + return nil, err } diff --git a/packets.go b/packets.go index 9b4c99e8e..731749cd2 100644 --- a/packets.go +++ b/packets.go @@ -23,9 +23,9 @@ import ( // http://dev.mysql.com/doc/internals/en/client-server-protocol.html // Read packet to buffer 'data' -func (mc *mysqlConn) readPacket() (data []byte, err error) { +func (mc *mysqlConn) readPacket() ([]byte, error) { // Read packet header - data, err = mc.buf.readNext(4) + data, err := mc.buf.readNext(4) if err != nil { errLog.Print(err.Error()) mc.Close() @@ -97,7 +97,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { return mc.splitPacket(data) } -func (mc *mysqlConn) splitPacket(data []byte) (err error) { +func (mc *mysqlConn) splitPacket(data []byte) error { pktLen := len(data) - 4 if pktLen > mc.maxPacketAllowed { @@ -140,11 +140,11 @@ func (mc *mysqlConn) splitPacket(data []byte) (err error) { ******************************************************************************/ // Handshake Initialization Packet -// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake -func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) { +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake +func (mc *mysqlConn) readInitPacket() ([]byte, error) { data, err := mc.readPacket() if err != nil { - return + return nil, err } if data[0] == iERR { @@ -153,11 +153,11 @@ func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) { // protocol version [1 byte] if data[0] < minProtocolVersion { - err = fmt.Errorf( + return nil, fmt.Errorf( "Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required", data[0], - minProtocolVersion) - return + minProtocolVersion, + ) } // server version [null terminated string] @@ -165,7 +165,7 @@ func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) { pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 // first part of the password cipher [8 bytes] - cipher = data[pos : pos+8] + cipher := data[pos : pos+8] // (filler) always 0x00 [1 byte] pos += 8 + 1 @@ -197,7 +197,6 @@ func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) { // TODO: Verify string termination // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) // \NUL otherwise - // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake // //if data[len(data)-1] == 0 { // return @@ -205,11 +204,11 @@ func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) { //return errMalformPkt } - return + return cipher, nil } // Client Authentication Packet -// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | @@ -239,8 +238,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { pktLen += n + 1 } - // Calculate packet length and make buffer with that size - data := make([]byte, pktLen+4) + // Calculate packet length and get buffer with that size + data := mc.buf.takeSmallBuffer(pktLen + 4) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print("Busy buffer") + return driver.ErrBadConn + } // ClientFlags [32 bit] data[4] = byte(clientFlags) @@ -249,16 +253,16 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { data[7] = byte(clientFlags >> 24) // MaxPacketSize [32 bit] (none) - //data[8] = 0x00 - //data[9] = 0x00 - //data[10] = 0x00 - //data[11] = 0x00 + data[8] = 0x00 + data[9] = 0x00 + data[10] = 0x00 + data[11] = 0x00 // Charset [1 byte] data[12] = collation_utf8_general_ci // SSL Connection Request Packet - // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::SSLRequest + // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest if mc.cfg.tls != nil { // Packet header [24bit length + 1 byte sequence] data[0] = byte((4 + 4 + 1 + 23)) @@ -293,7 +297,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { if len(mc.cfg.user) > 0 { pos += copy(data[pos:], mc.cfg.user) } - //data[pos] = 0x00 + data[pos] = 0x00 pos++ // ScrambleBuffer [length encoded integer] @@ -303,7 +307,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { // Databasename [null terminated string] if len(mc.cfg.dbname) > 0 { pos += copy(data[pos:], mc.cfg.dbname) - //data[pos] = 0x00 + data[pos] = 0x00 } // Send Auth packet @@ -311,14 +315,19 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { } // Client old authentication packet -// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::AuthSwitchResponse +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { // User password scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd)) // Calculate the packet lenght and add a tailing 0 pktLen := len(scrambleBuff) + 1 - data := make([]byte, pktLen+4) + data := mc.buf.takeSmallBuffer(pktLen + 4) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print("Busy buffer") + return driver.ErrBadConn + } // Add the packet header [24bit length + 1 byte sequence] data[0] = byte(pktLen) @@ -340,17 +349,24 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.sequence = 0 + data := mc.buf.takeSmallBuffer(4 + 1) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print("Busy buffer") + return driver.ErrBadConn + } + + // Add the packet header [24bit length + 1 byte sequence] + data[0] = 0x01 // 1 byte long + data[1] = 0x00 + data[2] = 0x00 + data[3] = 0x00 // new command, sequence id is always 0 + + // Add command byte + data[4] = command + // Send CMD packet - return mc.writePacket([]byte{ - // Add the packet header [24bit length + 1 byte sequence] - 0x01, // 1 byte long - 0x00, - 0x00, - 0x00, // mc.sequence - - // Add command byte - command, - }) + return mc.writePacket(data) } func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { @@ -358,13 +374,18 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { mc.sequence = 0 pktLen := 1 + len(arg) - data := make([]byte, pktLen+4) + data := mc.buf.takeBuffer(pktLen + 4) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print("Busy buffer") + return driver.ErrBadConn + } // Add the packet header [24bit length + 1 byte sequence] data[0] = byte(pktLen) data[1] = byte(pktLen >> 8) data[2] = byte(pktLen >> 16) - //data[3] = mc.sequence + data[3] = 0x00 // new command, sequence id is always 0 // Add command byte data[4] = command @@ -380,23 +401,30 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 + data := mc.buf.takeSmallBuffer(4 + 1 + 4) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print("Busy buffer") + return driver.ErrBadConn + } + + // Add the packet header [24bit length + 1 byte sequence] + data[0] = 0x05 // 5 bytes long + data[1] = 0x00 + data[2] = 0x00 + data[3] = 0x00 // new command, sequence id is always 0 + + // Add command byte + data[4] = command + + // Add arg [32 bit] + data[5] = byte(arg) + data[6] = byte(arg >> 8) + data[7] = byte(arg >> 16) + data[8] = byte(arg >> 24) + // Send CMD packet - return mc.writePacket([]byte{ - // Add the packet header [24bit length + 1 byte sequence] - 0x05, // 5 bytes long - 0x00, - 0x00, - 0x00, // mc.sequence - - // Add command byte - command, - - // Add arg [32 bit] - byte(arg), - byte(arg >> 8), - byte(arg >> 16), - byte(arg >> 24), - }) + return mc.writePacket(data) } /****************************************************************************** @@ -425,7 +453,7 @@ func (mc *mysqlConn) readResultOK() error { } // Result Set Header Packet -// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::Resultset +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { data, err := mc.readPacket() if err == nil { @@ -453,7 +481,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { } // Error Packet -// http://dev.mysql.com/doc/internals/en/overview.html#packet-ERR_Packet +// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet func (mc *mysqlConn) handleErrorPacket(data []byte) error { if data[0] != iERR { return errMalformPkt @@ -467,8 +495,8 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { pos := 3 // SQL State [optional: # + 5bytes string] - //sqlstate := string(data[pos : pos+6]) - if data[pos] == 0x23 { + if data[3] == 0x23 { + //sqlstate := string(data[4 : 4+5]) pos = 9 } @@ -480,8 +508,8 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { } // Ok Packet -// http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet -func (mc *mysqlConn) handleOkPacket(data []byte) (err error) { +// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet +func (mc *mysqlConn) handleOkPacket(data []byte) error { var n, m int // 0x00 [1 byte] @@ -496,72 +524,66 @@ func (mc *mysqlConn) handleOkPacket(data []byte) (err error) { // warning count [2 bytes] if !mc.strict { - return + return nil } else { pos := 1 + n + m + 2 if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { - err = mc.getWarnings() + return mc.getWarnings() } + return nil } - - // message [until end of packet] - return } // Read Packets as Field Packets until EOF-Packet or an Error appears -// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-Protocol::ColumnDefinition41 -func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) { - var data []byte - var i, pos, n int - var name []byte - - columns = make([]mysqlField, count) +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 +func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { + columns := make([]mysqlField, count) - for { - data, err = mc.readPacket() + for i := 0; ; i++ { + data, err := mc.readPacket() if err != nil { - return + return nil, err } // EOF Packet if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { - if i != count { - err = fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns)) + if i == count { + return columns, nil } - return + return nil, fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns)) } // Catalog - pos, err = skipLengthEnodedString(data) + pos, err := skipLengthEnodedString(data) if err != nil { - return + return nil, err } // Database [len coded string] - n, err = skipLengthEnodedString(data[pos:]) + n, err := skipLengthEnodedString(data[pos:]) if err != nil { - return + return nil, err } pos += n // Table [len coded string] n, err = skipLengthEnodedString(data[pos:]) if err != nil { - return + return nil, err } pos += n // Original table [len coded string] n, err = skipLengthEnodedString(data[pos:]) if err != nil { - return + return nil, err } pos += n // Name [len coded string] - name, _, n, err = readLengthEnodedString(data[pos:]) + name, _, n, err := readLengthEnodedString(data[pos:]) if err != nil { - return + return nil, err } columns[i].name = string(name) pos += n @@ -569,7 +591,7 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) { // Original name [len coded string] n, err = skipLengthEnodedString(data[pos:]) if err != nil { - return + return nil, err } // Filler [1 byte] @@ -592,17 +614,17 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) { //if pos < len(data) { // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) //} - - i++ } } // Read Packets as Field Packets until EOF-Packet or an Error appears -// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::ResultsetRow -func (rows *mysqlRows) readRow(dest []driver.Value) (err error) { - data, err := rows.mc.readPacket() +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow +func (rows *mysqlRows) readRow(dest []driver.Value) error { + mc := rows.mc + + data, err := mc.readPacket() if err != nil { - return + return err } // EOF Packet @@ -621,13 +643,16 @@ func (rows *mysqlRows) readRow(dest []driver.Value) (err error) { pos += n if err == nil { if !isNull { - if !rows.mc.parseTime { + if !mc.parseTime { continue } else { switch rows.columns[i].fieldType { case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeDate, fieldTypeNewDate: - dest[i], err = parseDateTime(string(dest[i].([]byte)), rows.mc.cfg.loc) + dest[i], err = parseDateTime( + string(dest[i].([]byte)), + mc.cfg.loc, + ) if err == nil { continue } @@ -641,24 +666,22 @@ func (rows *mysqlRows) readRow(dest []driver.Value) (err error) { continue } } - return // err + return err // err != nil } - return + return nil } // Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read -func (mc *mysqlConn) readUntilEOF() (err error) { - var data []byte - +func (mc *mysqlConn) readUntilEOF() error { for { - data, err = mc.readPacket() + data, err := mc.readPacket() // No Err and no EOF Packet if err == nil && data[0] != iEOF { continue } - return // Err or EOF + return err // Err or EOF } } @@ -667,59 +690,61 @@ func (mc *mysqlConn) readUntilEOF() (err error) { ******************************************************************************/ // Prepare Result Packets -// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-prepare-response -func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error) { +// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html +func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { data, err := stmt.mc.readPacket() if err == nil { - // Position - pos := 0 - // packet indicator [1 byte] - if data[pos] != iOK { - err = stmt.mc.handleErrorPacket(data) - return + if data[0] != iOK { + return 0, stmt.mc.handleErrorPacket(data) } - pos++ // statement id [4 bytes] - stmt.id = binary.LittleEndian.Uint32(data[pos : pos+4]) - pos += 4 + stmt.id = binary.LittleEndian.Uint32(data[1:5]) // Column count [16 bit uint] - columnCount = binary.LittleEndian.Uint16(data[pos : pos+2]) - pos += 2 + columnCount := binary.LittleEndian.Uint16(data[5:7]) // Param count [16 bit uint] - stmt.paramCount = int(binary.LittleEndian.Uint16(data[pos : pos+2])) - pos += 2 + stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9])) // Reserved [8 bit] - pos++ // Warning count [16 bit uint] if !stmt.mc.strict { - return + return columnCount, nil } else { // Check for warnings count > 0, only available in MySQL > 4.1 - if len(data) >= 12 && binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { - err = stmt.mc.getWarnings() + if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { + return columnCount, stmt.mc.getWarnings() } + return columnCount, nil } } - return + return 0, err } -// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-send-long-data -func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) (err error) { +// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html +func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { maxLen := stmt.mc.maxPacketAllowed - 1 pktLen := maxLen - argLen := len(arg) - data := make([]byte, 4+1+4+2+argLen) - copy(data[4+1+4+2:], arg) - for argLen > 0 { - if 1+4+2+argLen < maxLen { - pktLen = 1 + 4 + 2 + argLen + // After the header (bytes 0-3) follows before the data: + // 1 byte command + // 4 bytes stmtID + // 2 bytes paramID + const dataOffset = 1 + 4 + 2 + + // Can not use the write buffer since + // a) the buffer is too small + // b) it is in use + data := make([]byte, 4+1+4+2+len(arg)) + + copy(data[4+dataOffset:], arg) + + for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset { + if dataOffset+argLen < maxLen { + pktLen = dataOffset + argLen } // Add the packet header [24bit length + 1 byte sequence] @@ -742,10 +767,9 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) (err error) data[10] = byte(paramID >> 8) // Send CMD packet - err = stmt.mc.writePacket(data[:4+pktLen]) + err := stmt.mc.writePacket(data[:4+pktLen]) if err == nil { - argLen -= pktLen - (1 + 4 + 2) - data = data[pktLen-(1+4+2):] + data = data[pktLen-dataOffset:] continue } return err @@ -758,119 +782,47 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) (err error) } // Execute Prepared Statement -// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-execute +// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if len(args) != stmt.paramCount { return fmt.Errorf( "Arguments count mismatch (Got: %d Has: %d)", len(args), - stmt.paramCount) + stmt.paramCount, + ) } - // Reset packet-sequence - stmt.mc.sequence = 0 - - pktLen := 1 + 4 + 1 + 4 + ((stmt.paramCount + 7) >> 3) + 1 + (stmt.paramCount << 1) - paramValues := make([][]byte, stmt.paramCount) - paramTypes := make([]byte, (stmt.paramCount << 1)) - bitMask := uint64(0) - var i int - - for i = range args { - // build NULL-bitmap - if args[i] == nil { - bitMask += 1 << uint(i) - paramTypes[i<<1] = fieldTypeNULL - continue - } - - // cache types and values - switch v := args[i].(type) { - case int64: - paramTypes[i<<1] = fieldTypeLongLong - paramValues[i] = uint64ToBytes(uint64(v)) - pktLen += 8 - continue - - case float64: - paramTypes[i<<1] = fieldTypeDouble - paramValues[i] = uint64ToBytes(math.Float64bits(v)) - pktLen += 8 - continue - - case bool: - paramTypes[i<<1] = fieldTypeTiny - pktLen++ - if v { - paramValues[i] = []byte{0x01} - } else { - paramValues[i] = []byte{0x00} - } - continue - - case []byte: - paramTypes[i<<1] = fieldTypeString - if len(v) < stmt.mc.maxPacketAllowed-pktLen-(stmt.paramCount-(i+1))*64 { - paramValues[i] = append( - lengthEncodedIntegerToBytes(uint64(len(v))), - v..., - ) - pktLen += len(paramValues[i]) - continue - } else { - err := stmt.writeCommandLongData(i, v) - if err == nil { - continue - } - return err - } - - case string: - paramTypes[i<<1] = fieldTypeString - if len(v) < stmt.mc.maxPacketAllowed-pktLen-(stmt.paramCount-(i+1))*64 { - paramValues[i] = append( - lengthEncodedIntegerToBytes(uint64(len(v))), - []byte(v)..., - ) - pktLen += len(paramValues[i]) - continue - } else { - err := stmt.writeCommandLongData(i, []byte(v)) - if err == nil { - continue - } - return err - } - - case time.Time: - paramTypes[i<<1] = fieldTypeString + mc := stmt.mc - var val []byte - if v.IsZero() { - val = []byte("0000-00-00") - } else { - val = []byte(v.In(stmt.mc.cfg.loc).Format(timeFormat)) - } + // Reset packet-sequence + mc.sequence = 0 - paramValues[i] = append( - lengthEncodedIntegerToBytes(uint64(len(val))), - val..., - ) - pktLen += len(paramValues[i]) - continue + var data []byte - default: - return fmt.Errorf("Can't convert type: %T", args[i]) + if len(args) == 0 { + const pktLen = 1 + 4 + 1 + 4 + data = mc.buf.takeBuffer(4 + pktLen) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print("Busy buffer") + return driver.ErrBadConn } - } - data := make([]byte, pktLen+4) + // packet header [4 bytes] + data[0] = byte(pktLen) + data[1] = byte(pktLen >> 8) + data[2] = byte(pktLen >> 16) + data[3] = 0x00 // new command, sequence id is always 0 + } else { + data = mc.buf.takeCompleteBuffer() + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print("Busy buffer") + return driver.ErrBadConn + } - // packet header [4 bytes] - data[0] = byte(pktLen) - data[1] = byte(pktLen >> 8) - data[2] = byte(pktLen >> 16) - data[3] = stmt.mc.sequence + // header (bytes 0-3) is added after we know the packet size + } // command [1 byte] data[4] = comStmtExecute @@ -882,43 +834,168 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data[8] = byte(stmt.id >> 24) // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte] - //data[9] = 0x00 + data[9] = 0x00 // iteration_count (uint32(1)) [4 bytes] data[10] = 0x01 - //data[11] = 0x00 - //data[12] = 0x00 - //data[13] = 0x00 - - if stmt.paramCount > 0 { - // NULL-bitmap [(param_count+7)/8 bytes] - pos := 14 + ((stmt.paramCount + 7) >> 3) - // Convert bitMask to bytes - for i = 14; i < pos; i++ { - data[i] = byte(bitMask >> uint((i-14)<<3)) - } + data[11] = 0x00 + data[12] = 0x00 + data[13] = 0x00 + + if len(args) > 0 { + // NULL-bitmap [(len(args)+7)/8 bytes] + nullMask := uint64(0) + + pos := 4 + 1 + 4 + 1 + 4 + ((len(args) + 7) >> 3) // newParameterBoundFlag 1 [1 byte] data[pos] = 0x01 pos++ - // type of parameters [param_count*2 bytes] - pos += copy(data[pos:], paramTypes) + // type of each parameter [len(args)*2 bytes] + paramTypes := data[pos:] + pos += (len(args) << 1) + + // value of each parameter [n bytes] + paramValues := data[pos:pos] + valuesCap := cap(paramValues) + + for i := range args { + // build NULL-bitmap + if args[i] == nil { + nullMask |= 1 << uint(i) + paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i+1] = 0x00 + continue + } + + // cache types and values + switch v := args[i].(type) { + case int64: + paramTypes[i+i] = fieldTypeLongLong + paramTypes[i+i+1] = 0x00 + + if cap(paramValues)-len(paramValues)-8 >= 0 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + uint64(v), + ) + } else { + paramValues = append(paramValues, + uint64ToBytes(uint64(v))..., + ) + } + + case float64: + paramTypes[i+i] = fieldTypeDouble + paramTypes[i+i+1] = 0x00 + + if cap(paramValues)-len(paramValues)-8 >= 0 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + math.Float64bits(v), + ) + } else { + paramValues = append(paramValues, + uint64ToBytes(math.Float64bits(v))..., + ) + } + + case bool: + paramTypes[i+i] = fieldTypeTiny + paramTypes[i+i+1] = 0x00 + + if v { + paramValues = append(paramValues, 0x01) + } else { + paramValues = append(paramValues, 0x00) + } + + case []byte: + paramTypes[i+i] = fieldTypeString + paramTypes[i+i+1] = 0x00 + + if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { + paramValues = append(paramValues, + lengthEncodedIntegerToBytes(uint64(len(v)))..., + ) + paramValues = append(paramValues, v...) + } else { + if err := stmt.writeCommandLongData(i, v); err != nil { + return err + } + } + + case string: + paramTypes[i+i] = fieldTypeString + paramTypes[i+i+1] = 0x00 + + if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { + paramValues = append(paramValues, + lengthEncodedIntegerToBytes(uint64(len(v)))..., + ) + paramValues = append(paramValues, v...) + } else { + if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { + return err + } + } + + case time.Time: + paramTypes[i+i] = fieldTypeString + paramTypes[i+i+1] = 0x00 + + var val []byte + if v.IsZero() { + val = []byte("0000-00-00") + } else { + val = []byte(v.In(mc.cfg.loc).Format(timeFormat)) + } + + paramValues = append(paramValues, + lengthEncodedIntegerToBytes(uint64(len(val)))..., + ) + paramValues = append(paramValues, val...) + + default: + return fmt.Errorf("Can't convert type: %T", args[i]) + } + } - // values for the parameters [n bytes] - for i = range paramValues { - pos += copy(data[pos:], paramValues[i]) + // Check if param values exceeded the available buffer + // In that case we must build the data packet with the new values buffer + if valuesCap != cap(paramValues) { + data = append(data[:pos], paramValues...) + mc.buf.buf = data + } + + pos += len(paramValues) + data = data[:pos] + + pktLen := pos - 4 + + // packet header [4 bytes] + data[0] = byte(pktLen) + data[1] = byte(pktLen >> 8) + data[2] = byte(pktLen >> 16) + data[3] = mc.sequence + + // Convert nullMask to bytes + for i, max := 0, (stmt.paramCount+7)>>3; i < max; i++ { + data[i+14] = byte(nullMask >> uint(i<<3)) } } - return stmt.mc.writePacket(data) + return mc.writePacket(data) } -// http://dev.mysql.com/doc/internals/en/prepared-statements.html#packet-ProtocolBinary::ResultsetRow -func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { +// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html +func (rows *mysqlRows) readBinaryRow(dest []driver.Value) error { data, err := rows.mc.readPacket() if err != nil { - return + return err } // packet indicator [1 byte] @@ -926,30 +1003,24 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { // EOF Packet if data[0] == iEOF && len(data) == 5 { return io.EOF - } else { - // Error otherwise - return rows.mc.handleErrorPacket(data) } + + // Error otherwise + return rows.mc.handleErrorPacket(data) } // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] pos := 1 + (len(dest)+7+2)>>3 - nullBitMap := data[1:pos] - - // values [rest] - var n int - var unsigned bool + nullMask := data[1:pos] for i := range dest { // Field is NULL // (byte >> bit-pos) % 2 == 1 - if ((nullBitMap[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 { + if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 { dest[i] = nil continue } - unsigned = rows.columns[i].flags&flagUnsigned != 0 - // Convert to byte-coded string switch rows.columns[i].fieldType { case fieldTypeNULL: @@ -958,7 +1029,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { // Numeric Types case fieldTypeTiny: - if unsigned { + if rows.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(data[pos]) } else { dest[i] = int64(int8(data[pos])) @@ -967,7 +1038,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { continue case fieldTypeShort, fieldTypeYear: - if unsigned { + if rows.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) } else { dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) @@ -976,7 +1047,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { continue case fieldTypeInt24, fieldTypeLong: - if unsigned { + if rows.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) } else { dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) @@ -985,7 +1056,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { continue case fieldTypeLongLong: - if unsigned { + if rows.columns[i].flags&flagUnsigned != 0 { val := binary.LittleEndian.Uint64(data[pos : pos+8]) if val > math.MaxInt64 { dest[i] = uint64ToString(val) @@ -1014,6 +1085,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, fieldTypeVarString, fieldTypeString, fieldTypeGeometry: var isNull bool + var n int dest[i], isNull, n, err = readLengthEnodedString(data[pos:]) pos += n if err == nil { @@ -1024,14 +1096,11 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { continue } } - return // err + return err // Date YYYY-MM-DD case fieldTypeDate, fieldTypeNewDate: - var num uint64 - var isNull bool - num, isNull, n = readLengthEncodedInteger(data[pos:]) - + num, isNull, n := readLengthEncodedInteger(data[pos:]) pos += n if isNull { @@ -1054,10 +1123,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { // Time [-][H]HH:MM:SS[.fractal] case fieldTypeTime: - var num uint64 - var isNull bool - num, isNull, n = readLengthEncodedInteger(data[pos:]) - + num, isNull, n := readLengthEncodedInteger(data[pos:]) pos += n if num == 0 { @@ -1103,9 +1169,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { // Timestamp YYYY-MM-DD HH:MM:SS[.fractal] case fieldTypeTimestamp, fieldTypeDateTime: - var num uint64 - var isNull bool - num, isNull, n = readLengthEncodedInteger(data[pos:]) + num, isNull, n := readLengthEncodedInteger(data[pos:]) pos += n @@ -1133,5 +1197,5 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { } } - return + return nil } diff --git a/rows.go b/rows.go index 726f02ac0..71dcfb1a7 100644 --- a/rows.go +++ b/rows.go @@ -26,12 +26,12 @@ type mysqlRows struct { eof bool } -func (rows *mysqlRows) Columns() (columns []string) { - columns = make([]string, len(rows.columns)) +func (rows *mysqlRows) Columns() []string { + columns := make([]string, len(rows.columns)) for i := range columns { columns[i] = rows.columns[i].name } - return + return columns } func (rows *mysqlRows) Close() (err error) { diff --git a/statement.go b/statement.go index 56b12533f..025f2ecf5 100644 --- a/statement.go +++ b/statement.go @@ -19,14 +19,14 @@ type mysqlStmt struct { params []mysqlField } -func (stmt *mysqlStmt) Close() (err error) { +func (stmt *mysqlStmt) Close() error { if stmt.mc == nil || stmt.mc.netConn == nil { return errInvalidConn } - err = stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) + err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) stmt.mc = nil - return + return err } func (stmt *mysqlStmt) NumInput() int { @@ -34,33 +34,34 @@ func (stmt *mysqlStmt) NumInput() int { } func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { - stmt.mc.affectedRows = 0 - stmt.mc.insertId = 0 - // Send command err := stmt.writeExecutePacket(args) if err != nil { return nil, err } + mc := stmt.mc + + mc.affectedRows = 0 + mc.insertId = 0 + // Read Result - var resLen int - resLen, err = stmt.mc.readResultSetHeaderPacket() + resLen, err := mc.readResultSetHeaderPacket() if err == nil { if resLen > 0 { // Columns - err = stmt.mc.readUntilEOF() + err = mc.readUntilEOF() if err != nil { return nil, err } // Rows - err = stmt.mc.readUntilEOF() + err = mc.readUntilEOF() } if err == nil { return &mysqlResult{ - affectedRows: int64(stmt.mc.affectedRows), - insertId: int64(stmt.mc.insertId), + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), }, nil } } @@ -75,21 +76,19 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { return nil, err } + mc := stmt.mc + // Read Result - var resLen int - resLen, err = stmt.mc.readResultSetHeaderPacket() + resLen, err := mc.readResultSetHeaderPacket() if err != nil { return nil, err } - rows := &mysqlRows{stmt.mc, true, nil, false} + rows := &mysqlRows{mc, true, nil, false} if resLen > 0 { // Columns - rows.columns, err = stmt.mc.readColumns(resLen) - if err != nil { - return nil, err - } + rows.columns, err = mc.readColumns(resLen) } return rows, err