Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ For Unix domain sockets the address is the absolute path to the MySQL-Server-soc

Possible Parameters are:
* `allowAllFiles`: `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. [*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)
* `allowOldPasswords`: `allowAllFiles=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords).
* `charset`: Sets the charset used for client-server interaction ("SET NAMES `value`"). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
* `clientFoundRows`: `clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed.
* `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
Expand Down
24 changes: 12 additions & 12 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
type mysqlConn struct {
cfg *config
flags clientFlag
cipher []byte
netConn net.Conn
buf *buffer
protocol uint8
Expand All @@ -35,17 +34,18 @@ type mysqlConn struct {
}

type config struct {
user string
passwd string
net string
addr string
dbname string
params map[string]string
loc *time.Location
timeout time.Duration
tls *tls.Config
allowAllFiles bool
clientFoundRows bool
user string
passwd string
net string
addr string
dbname string
params map[string]string
loc *time.Location
timeout time.Duration
tls *tls.Config
allowAllFiles bool
allowOldPasswords bool
clientFoundRows bool
}

// Handles parameters set in DSN
Expand Down
27 changes: 22 additions & 5 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,26 +52,42 @@ func (d *MySQLDriver) Open(dsn string) (driver.Conn, error) {
mc.buf = newBuffer(mc.netConn)

// Reading Handshake Initialization Packet
err = mc.readInitPacket()
cipher, err := mc.readInitPacket()
if err != nil {
mc.Close()
return nil, err
}

// Send Client Authentication Packet
err = mc.writeAuthPacket()
if err != nil {
if err = mc.writeAuthPacket(cipher); err != nil {
mc.Close()
return nil, err
}

// Read Result Packet
err = mc.readResultOK()
if err != nil {
return nil, err
// Retry with old authentication method, if allowed
if mc.cfg.allowOldPasswords && err == errOldPassword {
if err = mc.writeOldAuthPacket(cipher); err != nil {
mc.Close()
return nil, err
}
if err = mc.readResultOK(); err != nil {
mc.Close()
return nil, err
}
} else {
mc.Close()
return nil, err
}

}

// Get max allowed packet size
maxap, err := mc.getSystemVar("max_allowed_packet")
if err != nil {
mc.Close()
return nil, err
}
mc.maxPacketAllowed = stringToInt(maxap) - 1
Expand All @@ -82,10 +98,11 @@ func (d *MySQLDriver) Open(dsn string) (driver.Conn, error) {
// Handle DSN Params
err = mc.handleParams()
if err != nil {
mc.Close()
return nil, err
}

return mc, err
return mc, nil
}

func init() {
Expand Down
2 changes: 1 addition & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ var (
errInvalidConn = errors.New("Invalid Connection")
errMalformPkt = errors.New("Malformed Packet")
errNoTLS = errors.New("TLS encryption requested but server does not support TLS")
errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
errOldPassword = errors.New("This server only supports the insecure old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
errOldProtocol = errors.New("MySQL-Server does not support required Protocol 41+")
errPktSync = errors.New("Commands out of sync. You can't run this command now")
errPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?")
Expand Down
43 changes: 33 additions & 10 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,14 @@ func (mc *mysqlConn) splitPacket(data []byte) (err error) {

// Handshake Initialization Packet
// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake
func (mc *mysqlConn) readInitPacket() (err error) {
func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) {
data, err := mc.readPacket()
if err != nil {
return
}

if data[0] == iERR {
return mc.handleErrorPacket(data)
return nil, mc.handleErrorPacket(data)
}

// protocol version [1 byte]
Expand All @@ -154,25 +154,26 @@ func (mc *mysqlConn) readInitPacket() (err error) {
"Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required",
data[0],
minProtocolVersion)
return
}

// server version [null terminated string]
// connection id [4 bytes]
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4

// first part of the password cipher [8 bytes]
mc.cipher = append(mc.cipher, data[pos:pos+8]...)
cipher = data[pos : pos+8]

// (filler) always 0x00 [1 byte]
pos += 8 + 1

// capability flags (lower 2 bytes) [2 bytes]
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
if mc.flags&clientProtocol41 == 0 {
err = errOldProtocol
return nil, errOldProtocol
}
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
return errNoTLS
return nil, errNoTLS
}
pos += 2

Expand All @@ -188,7 +189,7 @@ func (mc *mysqlConn) readInitPacket() (err error) {
// The documentation is ambiguous about the length.
// The official Python library uses the fixed length 12
// which is not documented but seems to work.
mc.cipher = append(mc.cipher, data[pos:pos+12]...)
cipher = append(cipher, data[pos:pos+12]...)

// TODO: Verify string termination
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
Expand All @@ -206,7 +207,7 @@ func (mc *mysqlConn) readInitPacket() (err error) {

// Client Authentication Packet
// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse
func (mc *mysqlConn) writeAuthPacket() error {
func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
// Adjust client flags based on server support
clientFlags := clientProtocol41 |
clientSecureConn |
Expand All @@ -225,8 +226,7 @@ func (mc *mysqlConn) writeAuthPacket() error {
}

// User Password
scrambleBuff := scramblePassword(mc.cipher, []byte(mc.cfg.passwd))
mc.cipher = nil
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.passwd))

pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff)

Expand Down Expand Up @@ -307,6 +307,28 @@ func (mc *mysqlConn) writeAuthPacket() error {
return mc.writePacket(data)
}

// Client old authentication packet
// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
// User password
scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd))

// Calculate the packet lenght and add a tailing 0
pktLen := len(scrambleBuff) + 1
data := make([]byte, pktLen+4)

// Add the packet header [24bit length + 1 byte sequence]
data[0] = byte(pktLen)
data[1] = byte(pktLen >> 8)
data[2] = byte(pktLen >> 16)
data[3] = mc.sequence

// Add the scrambled password [null terminated string]
copy(data[4:], scrambleBuff)

return mc.writePacket(data)
}

/******************************************************************************
* Command Packets *
******************************************************************************/
Expand Down Expand Up @@ -388,7 +410,8 @@ func (mc *mysqlConn) readResultOK() error {
case iOK:
return mc.handleOkPacket(data)

case iEOF: // someone is using old_passwords
case iEOF:
// someone is using old_passwords
return errOldPassword

default: // Error otherwise
Expand Down
102 changes: 93 additions & 9 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"fmt"
"io"
"log"
"math"
"os"
"regexp"
"strings"
Expand Down Expand Up @@ -125,6 +126,15 @@ func parseDSN(dsn string) (cfg *config, err error) {
return
}

// Use old authentication mode (pre MySQL 4.1)
case "allowOldPasswords":
var isBool bool
cfg.allowOldPasswords, isBool = readBool(value)
if !isBool {
err = fmt.Errorf("Invalid Bool value: %s", value)
return
}

// Time Location
case "loc":
cfg.loc, err = time.LoadLocation(value)
Expand Down Expand Up @@ -182,6 +192,24 @@ func parseDSN(dsn string) (cfg *config, err error) {
return
}

// Returns the bool value of the input.
// The 2nd return value indicates if the input was a valid bool value
func readBool(input string) (value bool, valid bool) {
switch input {
case "1", "true", "TRUE", "True":
return true, true
case "0", "false", "FALSE", "False":
return false, true
}

// Not a valid bool value
return
}

/******************************************************************************
* Authentication *
******************************************************************************/

// Encrypt password using 4.1+ method
// http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol#4.1_and_later
func scramblePassword(scramble, password []byte) []byte {
Expand Down Expand Up @@ -213,20 +241,76 @@ func scramblePassword(scramble, password []byte) []byte {
return scramble
}

// Returns the bool value of the input.
// The 2nd return value indicates if the input was a valid bool value
func readBool(input string) (value bool, valid bool) {
switch input {
case "1", "true", "TRUE", "True":
return true, true
case "0", "false", "FALSE", "False":
return false, true
// Encrypt password using pre 4.1 (old password) method
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
type myRnd struct {
seed1, seed2 uint32
}

const myRndMaxVal = 0x3FFFFFFF

func newMyRnd(seed1, seed2 uint32) *myRnd {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not like this?

return &myRnd{
    seed1: seed1 % myRndMaxVal,
    seed2: seed2 % myRndMaxVal,
}

r := new(myRnd)
r.seed1 = seed1 % myRndMaxVal
r.seed2 = seed2 % myRndMaxVal
return r
}

func (r *myRnd) Float64() float64 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only used with byte(math.Floor(r.Float64() * 31)) and byte(math.Floor(r.Float64() * 31) + 64) below, right?
Why do a float conversion at all?
return byte(uint64(r.seed1) * 31 / myRndMaxVal) should do the trick, right? so there's only the +64 case and the regular case. Screams for a different function name, but this is not a Float64 representation anyway, it "iterates". This is from the top of my head, a test probably wouldn't hurt - but a one time test is enough.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal
return float64(r.seed1) / myRndMaxVal
}

func pwHash(password []byte) (result [2]uint32) {
var add uint32 = 7
var tmp uint32

result[0] = 1345345333
result[1] = 0x12345671

for _, c := range password {
// skip spaces and tabs in password
if c == ' ' || c == '\t' {
continue
}

tmp = uint32(c)
result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
result[1] += (result[1] << 8) ^ result[0]
add += tmp
}

// Not a valid bool value
result[0] &= (1 << 31) - 1 // Don't use sign bit (str2int)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this is not uint but uint32, I think 0x7fffffff is clearer than a bitshift. ^0x80000000 would probably be ok, too.

result[1] &= (1 << 31) - 1
return
}

func scrambleOldPassword(scramble, password []byte) []byte {
if len(password) == 0 {
return nil
}

scramble = scramble[:8]

hashPw := pwHash(password)
hashSc := pwHash(scramble)

r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])

var out [8]byte
for i := range out {
out[i] = byte(math.Floor(r.Float64()*31) + 64)
}

mask := byte(math.Floor(r.Float64() * 31))
for i := range out {
out[i] ^= mask
}

return out[:]
}

/******************************************************************************
* Time related utils *
******************************************************************************/
Expand Down
18 changes: 9 additions & 9 deletions utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ func TestDSNParser(t *testing.T) {
out string
loc *time.Location
}{
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true clientFoundRows:true}", time.UTC},
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.Local},
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC},
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local},
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
}

var cfg *config
Expand Down