From b03abe27ce68d87993334bc9b895999120458390 Mon Sep 17 00:00:00 2001 From: Nicola Peduzzi Date: Sun, 16 Jun 2013 18:37:06 +0200 Subject: [PATCH 1/5] Added support for old_password authentication method --- packets.go | 32 ++++++++++++++++++++++++++-- utils.go | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 2 deletions(-) diff --git a/packets.go b/packets.go index 92be62158..805858bdc 100644 --- a/packets.go +++ b/packets.go @@ -307,6 +307,29 @@ func (mc *mysqlConn) writeAuthPacket() error { return mc.writePacket(data) } +// Client old authentication packet +// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::AuthSwitchResponse +func (mc *mysqlConn) writeOldAuthPacket() error { + // User password + scrambleBuff := scrambleOldPassword(mc.cipher, []byte(mc.cfg.passwd)) + mc.cipher = nil + + // Calculate the packet lenght and add a tailing 0 + pktLen := len(scrambleBuff) + 1 + data := make([]byte, pktLen+4) + + // Add the packet header [24bit length + 1 byte sequence] + data[0] = byte(pktLen) + data[1] = byte(pktLen >> 8) + data[2] = byte(pktLen >> 16) + data[3] = mc.sequence + + // Add the scrambled password (it will be terminated by 0) + copy(data[4:], scrambleBuff) + + return mc.writePacket(data) +} + /****************************************************************************** * Command Packets * ******************************************************************************/ @@ -388,8 +411,13 @@ func (mc *mysqlConn) readResultOK() error { case iOK: return mc.handleOkPacket(data) - case iEOF: // someone is using old_passwords - return errOldPassword + case iEOF: + // someone is using old_passwords + err = mc.writeOldAuthPacket() + if err != nil { + return err + } + return mc.readResultOK() default: // Error otherwise return mc.handleErrorPacket(data) diff --git a/utils.go b/utils.go index 6658ef6f3..088f1a895 100644 --- a/utils.go +++ b/utils.go @@ -17,6 +17,7 @@ import ( "fmt" "io" "log" + "math" "os" "regexp" "strings" @@ -213,6 +214,67 @@ func scramblePassword(scramble, password []byte) []byte { return scramble } +// Encrypt password using pre 4.1 (old password) method +// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c +type myRnd struct { + seed1, seed2 uint32 +} + +const myRndMaxVal = 0x3FFFFFFF + +func newMyRnd(seed1, seed2 uint32) *myRnd { + r := new(myRnd) + r.seed1 = seed1 % myRndMaxVal + r.seed2 = seed2 % myRndMaxVal + return r +} + +func (r *myRnd) Float64() float64 { + r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal + r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal + return float64(r.seed1) / myRndMaxVal +} + +// https://github.com/atcurtis/mariadb/blob/master/sql/password.c +func pwHash(password []byte) (result [2]uint32) { + var nr, add, nr2, tmp uint32 + nr, add, nr2 = 1345345333, 7, 0x12345671 + + for _, c := range password { + if c == ' ' || c == '\t' { + continue // skip space in password + } + + tmp = uint32(c) + nr ^= (((nr & 63) + add) * tmp) + (nr << 8) + nr2 += (nr2 << 8) ^ nr + add += tmp + } + + result[0] = nr & ((1 << 31) - 1) // Don't use sign bit (str2int) + result[1] = nr2 & ((1 << 31) - 1) + return +} + +func scrambleOldPassword(scramble, password []byte) []byte { + if len(password) == 0 { + return nil + } + scramble = scramble[:8] + hashPw := pwHash(password) + hashSc := pwHash(scramble) + r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) + var out [8]byte + for i := range out { + out[i] = byte(math.Floor(r.Float64()*31) + 64) + } + extra := byte(math.Floor(r.Float64() * 31)) + for i := range out { + out[i] ^= extra + } + return out[:] +} + // Returns the bool value of the input. // The 2nd return value indicates if the input was a valid bool value func readBool(input string) (value bool, valid bool) { From e78fff054d158f4f4814688c1e0c2e9e606812f9 Mon Sep 17 00:00:00 2001 From: Nicola Peduzzi Date: Sun, 16 Jun 2013 18:43:15 +0200 Subject: [PATCH 2/5] Fixed old_password authentication cipher removal --- packets.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packets.go b/packets.go index 805858bdc..9c2ad6a55 100644 --- a/packets.go +++ b/packets.go @@ -226,7 +226,6 @@ func (mc *mysqlConn) writeAuthPacket() error { // User Password scrambleBuff := scramblePassword(mc.cipher, []byte(mc.cfg.passwd)) - mc.cipher = nil pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff) @@ -409,6 +408,8 @@ func (mc *mysqlConn) readResultOK() error { switch data[0] { case iOK: + // Remove the chipher in case of successfull authentication + mc.cipher = nil return mc.handleOkPacket(data) case iEOF: From 83ed16b5063110dd80787235e502e4bdc53d09d8 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Sat, 14 Sep 2013 18:11:13 +0200 Subject: [PATCH 3/5] Require explicitly allowing old passwords + close connection if authentication fails --- README.md | 1 + connection.go | 24 ++++++++++++------------ driver.go | 27 ++++++++++++++++++++++----- errors.go | 2 +- packets.go | 32 +++++++++++++------------------- utils.go | 9 +++++++++ utils_test.go | 18 +++++++++--------- 7 files changed, 67 insertions(+), 46 deletions(-) diff --git a/README.md b/README.md index fec5a4431..60f449ee5 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,7 @@ For Unix domain sockets the address is the absolute path to the MySQL-Server-soc Possible Parameters are: * `allowAllFiles`: `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. [*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html) + * `allowOldPasswords`: `allowAllFiles=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords). * `charset`: Sets the charset used for client-server interaction ("SET NAMES `value`"). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`). * `clientFoundRows`: `clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed. * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details. diff --git a/connection.go b/connection.go index df07b955e..83c77ec58 100644 --- a/connection.go +++ b/connection.go @@ -21,7 +21,6 @@ import ( type mysqlConn struct { cfg *config flags clientFlag - cipher []byte netConn net.Conn buf *buffer protocol uint8 @@ -35,17 +34,18 @@ type mysqlConn struct { } type config struct { - user string - passwd string - net string - addr string - dbname string - params map[string]string - loc *time.Location - timeout time.Duration - tls *tls.Config - allowAllFiles bool - clientFoundRows bool + user string + passwd string + net string + addr string + dbname string + params map[string]string + loc *time.Location + timeout time.Duration + tls *tls.Config + allowAllFiles bool + allowOldPasswords bool + clientFoundRows bool } // Handles parameters set in DSN diff --git a/driver.go b/driver.go index 8f093c69c..53afaf2d4 100644 --- a/driver.go +++ b/driver.go @@ -52,26 +52,42 @@ func (d *MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.buf = newBuffer(mc.netConn) // Reading Handshake Initialization Packet - err = mc.readInitPacket() + cipher, err := mc.readInitPacket() if err != nil { + mc.Close() return nil, err } // Send Client Authentication Packet - err = mc.writeAuthPacket() - if err != nil { + if err = mc.writeAuthPacket(cipher); err != nil { + mc.Close() return nil, err } // Read Result Packet err = mc.readResultOK() if err != nil { - return nil, err + // Retry with old authentication method, if allowed + if mc.cfg.allowOldPasswords && err == errOldPassword { + if err = mc.writeOldAuthPacket(cipher); err != nil { + mc.Close() + return nil, err + } + if err = mc.readResultOK(); err != nil { + mc.Close() + return nil, err + } + } else { + mc.Close() + return nil, err + } + } // Get max allowed packet size maxap, err := mc.getSystemVar("max_allowed_packet") if err != nil { + mc.Close() return nil, err } mc.maxPacketAllowed = stringToInt(maxap) - 1 @@ -82,10 +98,11 @@ func (d *MySQLDriver) Open(dsn string) (driver.Conn, error) { // Handle DSN Params err = mc.handleParams() if err != nil { + mc.Close() return nil, err } - return mc, err + return mc, nil } func init() { diff --git a/errors.go b/errors.go index 09b3ef14e..d1f13df15 100644 --- a/errors.go +++ b/errors.go @@ -20,7 +20,7 @@ var ( errInvalidConn = errors.New("Invalid Connection") errMalformPkt = errors.New("Malformed Packet") errNoTLS = errors.New("TLS encryption requested but server does not support TLS") - errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords") + errOldPassword = errors.New("This server only supports the insecure old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") errOldProtocol = errors.New("MySQL-Server does not support required Protocol 41+") errPktSync = errors.New("Commands out of sync. You can't run this command now") errPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?") diff --git a/packets.go b/packets.go index 9c2ad6a55..e893d19b4 100644 --- a/packets.go +++ b/packets.go @@ -138,14 +138,14 @@ func (mc *mysqlConn) splitPacket(data []byte) (err error) { // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake -func (mc *mysqlConn) readInitPacket() (err error) { +func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) { data, err := mc.readPacket() if err != nil { return } if data[0] == iERR { - return mc.handleErrorPacket(data) + return nil, mc.handleErrorPacket(data) } // protocol version [1 byte] @@ -154,6 +154,7 @@ func (mc *mysqlConn) readInitPacket() (err error) { "Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required", data[0], minProtocolVersion) + return } // server version [null terminated string] @@ -161,7 +162,7 @@ func (mc *mysqlConn) readInitPacket() (err error) { pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 // first part of the password cipher [8 bytes] - mc.cipher = append(mc.cipher, data[pos:pos+8]...) + cipher = data[pos : pos+8] // (filler) always 0x00 [1 byte] pos += 8 + 1 @@ -169,10 +170,10 @@ func (mc *mysqlConn) readInitPacket() (err error) { // capability flags (lower 2 bytes) [2 bytes] mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) if mc.flags&clientProtocol41 == 0 { - err = errOldProtocol + return nil, errOldProtocol } if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { - return errNoTLS + return nil, errNoTLS } pos += 2 @@ -188,7 +189,7 @@ func (mc *mysqlConn) readInitPacket() (err error) { // The documentation is ambiguous about the length. // The official Python library uses the fixed length 12 // which is not documented but seems to work. - mc.cipher = append(mc.cipher, data[pos:pos+12]...) + cipher = append(cipher, data[pos:pos+12]...) // TODO: Verify string termination // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) @@ -206,7 +207,7 @@ func (mc *mysqlConn) readInitPacket() (err error) { // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeAuthPacket() error { +func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -225,7 +226,7 @@ func (mc *mysqlConn) writeAuthPacket() error { } // User Password - scrambleBuff := scramblePassword(mc.cipher, []byte(mc.cfg.passwd)) + scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.passwd)) pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff) @@ -308,10 +309,9 @@ func (mc *mysqlConn) writeAuthPacket() error { // Client old authentication packet // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeOldAuthPacket() error { +func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { // User password - scrambleBuff := scrambleOldPassword(mc.cipher, []byte(mc.cfg.passwd)) - mc.cipher = nil + scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd)) // Calculate the packet lenght and add a tailing 0 pktLen := len(scrambleBuff) + 1 @@ -323,7 +323,7 @@ func (mc *mysqlConn) writeOldAuthPacket() error { data[2] = byte(pktLen >> 16) data[3] = mc.sequence - // Add the scrambled password (it will be terminated by 0) + // Add the scrambled password [null terminated string] copy(data[4:], scrambleBuff) return mc.writePacket(data) @@ -408,17 +408,11 @@ func (mc *mysqlConn) readResultOK() error { switch data[0] { case iOK: - // Remove the chipher in case of successfull authentication - mc.cipher = nil return mc.handleOkPacket(data) case iEOF: // someone is using old_passwords - err = mc.writeOldAuthPacket() - if err != nil { - return err - } - return mc.readResultOK() + return errOldPassword default: // Error otherwise return mc.handleErrorPacket(data) diff --git a/utils.go b/utils.go index 088f1a895..e078380ac 100644 --- a/utils.go +++ b/utils.go @@ -126,6 +126,15 @@ func parseDSN(dsn string) (cfg *config, err error) { return } + // Use old authentication mode (pre MySQL 4.1) + case "allowOldPasswords": + var isBool bool + cfg.allowOldPasswords, isBool = readBool(value) + if !isBool { + err = fmt.Errorf("Invalid Bool value: %s", value) + return + } + // Time Location case "loc": cfg.loc, err = time.LoadLocation(value) diff --git a/utils_test.go b/utils_test.go index 836790061..46d46a492 100644 --- a/utils_test.go +++ b/utils_test.go @@ -21,15 +21,15 @@ func TestDSNParser(t *testing.T) { out string loc *time.Location }{ - {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls: allowAllFiles:false clientFoundRows:false}", time.UTC}, - {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls: allowAllFiles:false clientFoundRows:false}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls: allowAllFiles:false clientFoundRows:false}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls: allowAllFiles:false clientFoundRows:false}", time.UTC}, - {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls: allowAllFiles:true clientFoundRows:true}", time.UTC}, - {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls: allowAllFiles:false clientFoundRows:false}", time.Local}, - {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls: allowAllFiles:false clientFoundRows:false}", time.UTC}, - {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false clientFoundRows:false}", time.UTC}, - {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false clientFoundRows:false}", time.UTC}, + {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls: allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC}, + {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local}, + {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, } var cfg *config From 0bc514ddc49f60df60ecbeb42a96c6229f202005 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Sat, 14 Sep 2013 19:05:59 +0200 Subject: [PATCH 4/5] Small refactoring --- utils.go | 61 ++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/utils.go b/utils.go index e078380ac..e1267a795 100644 --- a/utils.go +++ b/utils.go @@ -192,6 +192,24 @@ func parseDSN(dsn string) (cfg *config, err error) { return } +// Returns the bool value of the input. +// The 2nd return value indicates if the input was a valid bool value +func readBool(input string) (value bool, valid bool) { + switch input { + case "1", "true", "TRUE", "True": + return true, true + case "0", "false", "FALSE", "False": + return false, true + } + + // Not a valid bool value + return +} + +/****************************************************************************** +* Authentication * +******************************************************************************/ + // Encrypt password using 4.1+ method // http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol#4.1_and_later func scramblePassword(scramble, password []byte) []byte { @@ -244,24 +262,27 @@ func (r *myRnd) Float64() float64 { return float64(r.seed1) / myRndMaxVal } -// https://github.com/atcurtis/mariadb/blob/master/sql/password.c func pwHash(password []byte) (result [2]uint32) { - var nr, add, nr2, tmp uint32 - nr, add, nr2 = 1345345333, 7, 0x12345671 + var add uint32 = 7 + var tmp uint32 + + result[0] = 1345345333 + result[1] = 0x12345671 for _, c := range password { + // skip spaces and tabs in password if c == ' ' || c == '\t' { - continue // skip space in password + continue } tmp = uint32(c) - nr ^= (((nr & 63) + add) * tmp) + (nr << 8) - nr2 += (nr2 << 8) ^ nr + result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) + result[1] += (result[1] << 8) ^ result[0] add += tmp } - result[0] = nr & ((1 << 31) - 1) // Don't use sign bit (str2int) - result[1] = nr2 & ((1 << 31) - 1) + result[0] &= (1 << 31) - 1 // Don't use sign bit (str2int) + result[1] &= (1 << 31) - 1 return } @@ -269,33 +290,25 @@ func scrambleOldPassword(scramble, password []byte) []byte { if len(password) == 0 { return nil } + scramble = scramble[:8] + hashPw := pwHash(password) hashSc := pwHash(scramble) + r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) + var out [8]byte for i := range out { out[i] = byte(math.Floor(r.Float64()*31) + 64) } - extra := byte(math.Floor(r.Float64() * 31)) - for i := range out { - out[i] ^= extra - } - return out[:] -} -// Returns the bool value of the input. -// The 2nd return value indicates if the input was a valid bool value -func readBool(input string) (value bool, valid bool) { - switch input { - case "1", "true", "TRUE", "True": - return true, true - case "0", "false", "FALSE", "False": - return false, true + mask := byte(math.Floor(r.Float64() * 31)) + for i := range out { + out[i] ^= mask } - // Not a valid bool value - return + return out[:] } /****************************************************************************** From ff8fee69b9e5822cb53f6b72b0b764a15c3a0013 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Sun, 15 Sep 2013 13:27:49 +0200 Subject: [PATCH 5/5] refactor again --- utils.go | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/utils.go b/utils.go index e1267a795..63df8d974 100644 --- a/utils.go +++ b/utils.go @@ -17,7 +17,6 @@ import ( "fmt" "io" "log" - "math" "os" "regexp" "strings" @@ -211,7 +210,6 @@ func readBool(input string) (value bool, valid bool) { ******************************************************************************/ // Encrypt password using 4.1+ method -// http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol#4.1_and_later func scramblePassword(scramble, password []byte) []byte { if len(password) == 0 { return nil @@ -249,19 +247,25 @@ type myRnd struct { const myRndMaxVal = 0x3FFFFFFF +// Pseudo random number generator func newMyRnd(seed1, seed2 uint32) *myRnd { - r := new(myRnd) - r.seed1 = seed1 % myRndMaxVal - r.seed2 = seed2 % myRndMaxVal - return r + return &myRnd{ + seed1: seed1 % myRndMaxVal, + seed2: seed2 % myRndMaxVal, + } } -func (r *myRnd) Float64() float64 { +// Tested to be equivalent to MariaDB's floating point variant +// http://play.golang.org/p/QHvhd4qved +// http://play.golang.org/p/RG0q4ElWDx +func (r *myRnd) NextByte() byte { r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal - return float64(r.seed1) / myRndMaxVal + + return byte(uint64(r.seed1) * 31 / myRndMaxVal) } +// Generate binary hash from byte string using insecure pre 4.1 method func pwHash(password []byte) (result [2]uint32) { var add uint32 = 7 var tmp uint32 @@ -281,11 +285,14 @@ func pwHash(password []byte) (result [2]uint32) { add += tmp } - result[0] &= (1 << 31) - 1 // Don't use sign bit (str2int) - result[1] &= (1 << 31) - 1 + // Remove sign bit (1<<31)-1) + result[0] &= 0x7FFFFFFF + result[1] &= 0x7FFFFFFF + return } +// Encrypt password using insecure pre 4.1 method func scrambleOldPassword(scramble, password []byte) []byte { if len(password) == 0 { return nil @@ -300,10 +307,10 @@ func scrambleOldPassword(scramble, password []byte) []byte { var out [8]byte for i := range out { - out[i] = byte(math.Floor(r.Float64()*31) + 64) + out[i] = r.NextByte() + 64 } - mask := byte(math.Floor(r.Float64() * 31)) + mask := r.NextByte() for i := range out { out[i] ^= mask }