Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Egor Smolyakov <egorsmkv at gmail.com>
Evan Shaw <evan at vendhq.com>
Frederick Mayle <frederickmayle at gmail.com>
Gustavo Kristic <gkristic at gmail.com>
Hajime Nakagami <nakagami at gmail.com>
Hanno Braun <mail at hannobraun.com>
Henri Yandell <flamefew at gmail.com>
Hirotaka Yamamoto <ymmt2005 at gmail.com>
Expand Down
6 changes: 6 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,9 @@ const (
statusInTransReadonly
statusSessionStateChanged
)

const (
cachingSha2PasswordRequestPublicKey = 2
cachingSha2PasswordFastAuthSuccess = 3
cachingSha2PasswordPerformFullAuthentication = 4
)
28 changes: 24 additions & 4 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,20 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
mc.writeTimeout = mc.cfg.WriteTimeout

// Reading Handshake Initialization Packet
cipher, err := mc.readInitPacket()
cipher, pluginName, err := mc.readInitPacket()
if err != nil {
mc.cleanup()
return nil, err
}

// Send Client Authentication Packet
if err = mc.writeAuthPacket(cipher); err != nil {
if err = mc.writeAuthPacket(cipher, pluginName); err != nil {
mc.cleanup()
return nil, err
}

// Handle response to auth packet, switch methods if possible
if err = handleAuthResult(mc, cipher); err != nil {
if err = handleAuthResult(mc, cipher, pluginName); err != nil {
// Authentication failed and MySQL has already closed the connection
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
// Do not send COM_QUIT, just cleanup and return the error.
Expand Down Expand Up @@ -153,7 +153,27 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
return mc, nil
}

func handleAuthResult(mc *mysqlConn, oldCipher []byte) error {
func handleAuthResult(mc *mysqlConn, oldCipher []byte, pluginName string) error {

// handle caching_sha2_password
if pluginName == "caching_sha2_password" {
auth, err := mc.readCachingSha2PasswordAuthResult()
if err != nil {
return err
}
if auth == cachingSha2PasswordPerformFullAuthentication {
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
if err = mc.writeClearAuthPacket(); err != nil {
return err
}
} else {
if err = mc.writePublicKeyAuthPacket(oldCipher); err != nil {
return err
}
}
}
}

// Read Result Packet
cipher, err := mc.readResultOK()
if err == nil {
Expand Down
4 changes: 2 additions & 2 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1842,7 +1842,7 @@ func TestSQLInjection(t *testing.T) {

dsns := []string{
dsn,
dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a required change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NO_AUTO_CREATE_USER SQL mode seems removed at MySQL 8.0
https://dev.mysql.com/doc/refman/8.0/en/mysql-nutshell.html

My linux box show that message

--- FAIL: TestSQLInjection (0.92s)
        driver_test.go:161: error on exec CREATE TABLE test (v INTEGER): Error 1231: Variable 'sql_mode' can't be set to the value of 'NO_AUTO_CREATE_USER'

dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'",
}
for _, testdsn := range dsns {
runTests(t, testdsn, createTest("1 OR 1=1"))
Expand Down Expand Up @@ -1872,7 +1872,7 @@ func TestInsertRetrieveEscapedData(t *testing.T) {

dsns := []string{
dsn,
dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'",
}
for _, testdsn := range dsns {
runTests(t, testdsn, testData)
Expand Down
85 changes: 73 additions & 12 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@ package mysql

import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/tls"
"crypto/x509"
"database/sql/driver"
"encoding/binary"
"encoding/pem"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -154,24 +159,24 @@ func (mc *mysqlConn) writePacket(data []byte) error {

// Handshake Initialization Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
func (mc *mysqlConn) readInitPacket() ([]byte, error) {
func (mc *mysqlConn) readInitPacket() ([]byte, string, error) {
data, err := mc.readPacket()
if err != nil {
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
// in connection initialization we don't risk retrying non-idempotent actions.
if err == ErrInvalidConn {
return nil, driver.ErrBadConn
return nil, "", driver.ErrBadConn
}
return nil, err
return nil, "", err
}

if data[0] == iERR {
return nil, mc.handleErrorPacket(data)
return nil, "", mc.handleErrorPacket(data)
}

// protocol version [1 byte]
if data[0] < minProtocolVersion {
return nil, fmt.Errorf(
return nil, "", fmt.Errorf(
"unsupported protocol version %d. Version %d or higher is required",
data[0],
minProtocolVersion,
Expand All @@ -191,13 +196,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
// capability flags (lower 2 bytes) [2 bytes]
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
if mc.flags&clientProtocol41 == 0 {
return nil, ErrOldProtocol
return nil, "", ErrOldProtocol
}
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
return nil, ErrNoTLS
return nil, "", ErrNoTLS
}
pos += 2

pluginName := ""
if len(data) > pos {
// character set [1 byte]
// status flags [2 bytes]
Expand All @@ -219,6 +225,8 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
// The official Python library uses the fixed length 12
// which seems to work but technically could have a hidden bug.
cipher = append(cipher, data[pos:pos+12]...)
pos += 13
pluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)])

// TODO: Verify string termination
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
Expand All @@ -232,18 +240,22 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
// make a memory safe copy of the cipher slice
var b [20]byte
copy(b[:], cipher)
return b[:], nil
return b[:], pluginName, nil
}

// make a memory safe copy of the cipher slice
var b [8]byte
copy(b[:], cipher)
return b[:], nil
return b[:], pluginName, nil
}

// Client Authentication Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
func (mc *mysqlConn) writeAuthPacket(cipher []byte, pluginName string) error {
if pluginName != "mysql_native_password" && pluginName != "caching_sha2_password" {
return fmt.Errorf("unknown authentication plugin name '%s'", pluginName)
}

// Adjust client flags based on server support
clientFlags := clientProtocol41 |
clientSecureConn |
Expand All @@ -268,7 +280,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
}

// User Password
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
var scrambleBuff []byte
switch pluginName {
case "mysql_native_password":
scrambleBuff = scramblePassword(cipher, []byte(mc.cfg.Passwd))
case "caching_sha2_password":
scrambleBuff = scrambleCachingSha2Password(cipher, []byte(mc.cfg.Passwd))
}

pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1

Expand Down Expand Up @@ -350,7 +368,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
}

// Assume native client during response
pos += copy(data[pos:], "mysql_native_password")
pos += copy(data[pos:], pluginName)
data[pos] = 0x00

// Send Auth packet
Expand Down Expand Up @@ -422,6 +440,39 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
return mc.writePacket(data)
}

// Caching sha2 authentication. Public key request and send encrypted password
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writePublicKeyAuthPacket(cipher []byte) error {
// request public key
data := mc.buf.takeSmallBuffer(4 + 1)
data[4] = cachingSha2PasswordRequestPublicKey
mc.writePacket(data)

data, err := mc.readPacket()
if err != nil {
return err
}

block, _ := pem.Decode(data[1:])
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return err
}

plain := make([]byte, 20)
for k, v := range []byte(mc.cfg.Passwd) {
plain[k] = byte(v)
}
Copy link
Member

@methane methane May 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When len(mc.cfg.Passwd) > 20, this code may overflow.
Could we just use copy(plain, mc.cfg.Passwd) here?

for i := range plain {
plain[i] ^= cipher[i]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

len(plain) == 20, but how about len(cipher)?
Shouldn't we do cipher[i % len(cipher)]?

Copy link
Contributor Author

@nakagami nakagami May 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Please fix only above. (len(mc.cfg.Passwd) > 20 case).

Copy link
Contributor Author

@nakagami nakagami May 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, There is no limitation about password length (< 20) when caching_sha2_password.
Now fix it 84f6018

}
sha1 := sha1.New()
enc, _ := rsa.EncryptOAEP(sha1, rand.Reader, pub.(*rsa.PublicKey), plain, nil)
data = mc.buf.takeSmallBuffer(4 + len(enc))
copy(data[4:], enc)
return mc.writePacket(data)
}

/******************************************************************************
* Command Packets *
******************************************************************************/
Expand Down Expand Up @@ -535,6 +586,16 @@ func (mc *mysqlConn) readResultOK() ([]byte, error) {
return nil, err
}

func (mc *mysqlConn) readCachingSha2PasswordAuthResult() (int, error) {
data, err := mc.readPacket()
if err == nil {
if data[0] != 1 {
return 0, ErrMalformPkt
}
}
return int(data[1]), err
}

// Result Set Header Packet
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
Expand Down
29 changes: 29 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package mysql

import (
"crypto/sha1"
"crypto/sha256"
"crypto/tls"
"database/sql/driver"
"encoding/binary"
Expand Down Expand Up @@ -211,6 +212,34 @@ func scrambleOldPassword(scramble, password []byte) []byte {
return out[:]
}

// Encrypt password using 8.0 default method
func scrambleCachingSha2Password(scramble, password []byte) []byte {
if len(password) == 0 {
return nil
}

// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))

crypt := sha256.New()
crypt.Write(password)
message1 := crypt.Sum(nil)

crypt.Reset()
crypt.Write(message1)
message1Hash := crypt.Sum(nil)

crypt.Reset()
crypt.Write(message1Hash)
crypt.Write(scramble)
message2 := crypt.Sum(nil)

for i := range message1 {
message1[i] ^= message2[i]
}

return message1
}

/******************************************************************************
* Time related utils *
******************************************************************************/
Expand Down
18 changes: 18 additions & 0 deletions utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,24 @@ func TestOldPass(t *testing.T) {
}
}

func TestCachingSha2Pass(t *testing.T) {
scramble := []byte{10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21}
vectors := []struct {
pass string
out string
}{
{"secret", "f490e76f66d9d86665ce54d98c78d0acfe2fb0b08b423da807144873d30b312c"},
{"secret2", "abc3934a012cf342e876071c8ee202de51785b430258a7a0138bc79c4d800bc6"},
}
for _, tuple := range vectors {
ours := scrambleCachingSha2Password(scramble, []byte(tuple.pass))
if tuple.out != fmt.Sprintf("%x", ours) {
t.Errorf("Failed caching sha2 password %q", tuple.pass)
}
}

}

func TestFormatBinaryDateTime(t *testing.T) {
rawDate := [11]byte{}
binary.LittleEndian.PutUint16(rawDate[:2], 1978) // years
Expand Down