From dc029498cb5a3efbe44e54dcb5cf080d451450fa Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Wed, 16 Oct 2013 17:30:31 +0200 Subject: [PATCH 1/4] New DSN parser + Set right default addr for net=unix Go 1.2RC1 BenchmarkParseDSN_new 200000 10545 ns/op 4039 B/op 42 allocs/op BenchmarkParseDSN_old 10000 233313 ns/op 7588 B/op 91 allocs/op Go 1.1 BenchmarkParseDSN_new 200000 7940 ns/op 4204 B/op 42 allocs/op BenchmarkParseDSN_old 10000 264115 ns/op 8083 B/op 91 allocs/op --- utils.go | 232 ++++++++++++++++++++++++++++++-------------------- utils_test.go | 60 +++++++++---- 2 files changed, 185 insertions(+), 107 deletions(-) diff --git a/utils.go b/utils.go index faada5864..7e4ff9dfe 100644 --- a/utils.go +++ b/utils.go @@ -13,30 +13,24 @@ import ( "crypto/tls" "database/sql/driver" "encoding/binary" + "errors" "fmt" "io" "log" "os" - "regexp" "strings" "time" ) var ( errLog *log.Logger // Error Logger - dsnPattern *regexp.Regexp // Data Source Name Parser tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs + + errInvalidDSN = errors.New("Invalid DSN") ) func init() { errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile) - - dsnPattern = regexp.MustCompile( - `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] - `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] - `\/(?P.*?)` + // /dbname - `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] - tlsConfigRegister = make(map[string]*tls.Config) } @@ -79,96 +73,69 @@ func DeregisterTLSConfig(key string) { func parseDSN(dsn string) (cfg *config, err error) { cfg = new(config) - cfg.params = make(map[string]string) - - matches := dsnPattern.FindStringSubmatch(dsn) - names := dsnPattern.SubexpNames() - - for i, match := range matches { - switch names[i] { - case "user": - cfg.user = match - case "passwd": - cfg.passwd = match - case "net": - cfg.net = match - case "addr": - cfg.addr = match - case "dbname": - cfg.dbname = match - case "params": - for _, v := range strings.Split(match, "&") { - param := strings.SplitN(v, "=", 2) - if len(param) != 2 { - continue - } - - // cfg params - switch value := param[1]; param[0] { - // Disable INFILE whitelist / enable all files - case "allowAllFiles": - var isBool bool - cfg.allowAllFiles, isBool = readBool(value) - if !isBool { - err = fmt.Errorf("Invalid Bool value: %s", value) - return - } + // TODO: use strings.IndexByte when we can depend on Go 1.2 + + // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] + // Find the last '/' + for i := len(dsn) - 1; i >= 0; i-- { + if dsn[i] == '/' { + var j int + + // left part is empty if i <= 0 + if i > 0 { + // [username[:password]@][protocol[(address)]] + // Find the last '@' in dsn[:i] + for j = i; j >= 0; j-- { + if dsn[j] == '@' { + // username[:password] + // Find the first ':' in dsn[:j] + var k int + for k = 0; k < j; k++ { + if dsn[k] == ':' { + cfg.passwd = dsn[k+1 : j] + break + } + } + cfg.user = dsn[:k] + + // [protocol[(address)]] + // Find the first '(' in dsn[j+1:i] + for k = j + 1; k < i; k++ { + if dsn[k] == '(' { + // dsn[i-1] must be == ')' if an adress is specified + if dsn[i-1] != ')' { + return nil, errInvalidDSN + } + cfg.addr = dsn[k+1 : i-1] + break + } + } + cfg.net = dsn[j+1 : k] - // Switch "rowsAffected" mode - case "clientFoundRows": - var isBool bool - cfg.clientFoundRows, isBool = readBool(value) - if !isBool { - err = fmt.Errorf("Invalid Bool value: %s", value) - return - } - - // Use old authentication mode (pre MySQL 4.1) - case "allowOldPasswords": - var isBool bool - cfg.allowOldPasswords, isBool = readBool(value) - if !isBool { - err = fmt.Errorf("Invalid Bool value: %s", value) - return + break } + } - // Time Location - case "loc": - cfg.loc, err = time.LoadLocation(value) - if err != nil { - return - } + // non-empty left part must contain an '@' + if j < 0 { + return nil, errInvalidDSN + } + } - // Dial Timeout - case "timeout": - cfg.timeout, err = time.ParseDuration(value) - if err != nil { + // dbname[?param1=value1&...¶mN=valueN] + // Find the first '?' in dsn[i+1:] + for j = i + 1; j < len(dsn); j++ { + if dsn[j] == '?' { + if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { return } - - // TLS-Encryption - case "tls": - boolValue, isBool := readBool(value) - if isBool { - if boolValue { - cfg.tls = &tls.Config{} - } - } else { - if strings.ToLower(value) == "skip-verify" { - cfg.tls = &tls.Config{InsecureSkipVerify: true} - } else if tlsConfig, ok := tlsConfigRegister[value]; ok { - cfg.tls = tlsConfig - } else { - err = fmt.Errorf("Invalid value / unknown config name: %s", value) - return - } - } - - default: - cfg.params[param[0]] = value + break } } + cfg.dbname = dsn[i+1 : j] + + break } } @@ -179,7 +146,15 @@ func parseDSN(dsn string) (cfg *config, err error) { // Set default adress if empty if cfg.addr == "" { - cfg.addr = "127.0.0.1:3306" + switch cfg.net { + case "tcp": + cfg.addr = "127.0.0.1:3306" + case "unix": + cfg.addr = "/tmp/mysql.sock" + default: + return nil, errors.New("Default addr for network '" + cfg.net + "' unknown") + } + } // Set default location if not set @@ -190,6 +165,81 @@ func parseDSN(dsn string) (cfg *config, err error) { return } +func parseDSNParams(cfg *config, params string) (err error) { + cfg.params = make(map[string]string) + + for _, v := range strings.Split(params, "&") { + param := strings.SplitN(v, "=", 2) + if len(param) != 2 { + continue + } + + // cfg params + switch value := param[1]; param[0] { + + // Disable INFILE whitelist / enable all files + case "allowAllFiles": + var isBool bool + cfg.allowAllFiles, isBool = readBool(value) + if !isBool { + return fmt.Errorf("Invalid Bool value: %s", value) + } + + // Switch "rowsAffected" mode + case "clientFoundRows": + var isBool bool + cfg.clientFoundRows, isBool = readBool(value) + if !isBool { + return fmt.Errorf("Invalid Bool value: %s", value) + } + + // Use old authentication mode (pre MySQL 4.1) + case "allowOldPasswords": + var isBool bool + cfg.allowOldPasswords, isBool = readBool(value) + if !isBool { + return fmt.Errorf("Invalid Bool value: %s", value) + } + + // Time Location + case "loc": + cfg.loc, err = time.LoadLocation(value) + if err != nil { + return + } + + // Dial Timeout + case "timeout": + cfg.timeout, err = time.ParseDuration(value) + if err != nil { + return + } + + // TLS-Encryption + case "tls": + boolValue, isBool := readBool(value) + if isBool { + if boolValue { + cfg.tls = &tls.Config{} + } + } else { + if strings.ToLower(value) == "skip-verify" { + cfg.tls = &tls.Config{InsecureSkipVerify: true} + } else if tlsConfig, ok := tlsConfigRegister[value]; ok { + cfg.tls = tlsConfig + } else { + return fmt.Errorf("Invalid value / unknown config name: %s", value) + } + } + + default: + cfg.params[param[0]] = value + } + } + + return +} + // Returns the bool value of the input. // The 2nd return value indicates if the input was a valid bool value func readBool(input string) (value bool, valid bool) { diff --git a/utils_test.go b/utils_test.go index 39ad6cd5d..dbba669bf 100644 --- a/utils_test.go +++ b/utils_test.go @@ -14,23 +14,26 @@ import ( "time" ) -func TestDSNParser(t *testing.T) { - var testDSNs = []struct { - in string - out string - loc *time.Location - }{ - {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls: allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC}, - {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local}, - {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - } +var testDSNs = []struct { + in string + out string + loc *time.Location +}{ + {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls: allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC}, + {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local}, + {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"@unix/", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, +} +func TestDSNParser(t *testing.T) { var cfg *config var err error var res string @@ -51,6 +54,31 @@ func TestDSNParser(t *testing.T) { } } +func TestDSNParserInvalid(t *testing.T) { + var invalidDSNs = []string{ + "asdf/dbname", + //"/dbname?arg=/some/unescaped/path", + } + + for i, tst := range invalidDSNs { + if _, err := parseDSN(tst); err == nil { + t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst) + } + } +} + +func BenchmarkParseDSN(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, tst := range testDSNs { + if _, err := parseDSN(tst.in); err != nil { + b.Error(err.Error()) + } + } + } +} + func TestScanNullTime(t *testing.T) { var scanTests = []struct { in interface{} From 7503ab8073eb37c533ac4a211e9a9f469e5cd267 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Wed, 16 Oct 2013 19:49:16 +0200 Subject: [PATCH 2/4] Escape DSN param values Not really a behavior change, since it was mostly broken before --- README.md | 6 ++++-- driver_test.go | 3 ++- utils.go | 32 ++++++++++++++++++++++++-------- utils_test.go | 3 ++- 4 files changed, 32 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 60f449ee5..7a95668e3 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ Possible Parameters are: * `allowOldPasswords`: `allowAllFiles=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords). * `charset`: Sets the charset used for client-server interaction ("SET NAMES `value`"). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`). * `clientFoundRows`: `clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed. - * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details. + * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details. Please keep in mind, that param values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `US%2FPacific`. * `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string` * `strict`: Enable strict mode. MySQL warnings are treated as errors. * `timeout`: **Driver** side connection timeout. The value must be a string of decimal numbers, each with optional fraction and a unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout). @@ -122,6 +122,8 @@ All other parameters are interpreted as system variables: * `tx_isolation`: *"SET [tx_isolation](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation)=`value`"* * `param`: *"SET `param`=`value`"* +***The values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed!*** + #### Examples ``` user@unix(/path/to/socket)/dbname @@ -132,7 +134,7 @@ user:password@tcp(localhost:5555)/dbname?autocommit=true ``` ``` -user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?tls=skip-verify&charset=utf8mb4,utf8 +user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?tls=skip-verify&charset=utf8mb4,utf8&sys_var=withSlash%2FandAt%40 ``` ``` diff --git a/driver_test.go b/driver_test.go index 078484c43..41fbf06a9 100644 --- a/driver_test.go +++ b/driver_test.go @@ -15,6 +15,7 @@ import ( "io" "io/ioutil" "net" + "net/url" "os" "strings" "testing" @@ -206,7 +207,7 @@ func TestTimezoneConversion(t *testing.T) { } for _, tz := range zones { - runTests(t, dsn+"&parseTime=true&loc="+tz, tzTest) + runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest) } } diff --git a/utils.go b/utils.go index 7e4ff9dfe..7f8af463a 100644 --- a/utils.go +++ b/utils.go @@ -17,6 +17,7 @@ import ( "fmt" "io" "log" + "net/url" "os" "strings" "time" @@ -26,7 +27,8 @@ var ( errLog *log.Logger // Error Logger tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs - errInvalidDSN = errors.New("Invalid DSN") + errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?") + errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)") ) func init() { @@ -71,13 +73,14 @@ func DeregisterTLSConfig(key string) { delete(tlsConfigRegister, key) } +// parseDSN parses the DSN string to a config func parseDSN(dsn string) (cfg *config, err error) { cfg = new(config) // TODO: use strings.IndexByte when we can depend on Go 1.2 // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] - // Find the last '/' + // Find the last '/' (since the password might contain a '/') for i := len(dsn) - 1; i >= 0; i-- { if dsn[i] == '/' { var j int @@ -105,7 +108,10 @@ func parseDSN(dsn string) (cfg *config, err error) { if dsn[k] == '(' { // dsn[i-1] must be == ')' if an adress is specified if dsn[i-1] != ')' { - return nil, errInvalidDSN + if strings.ContainsRune(dsn[k+1:i], ')') { + return nil, errInvalidDSNUnescaped + } + return nil, errInvalidDSNAddr } cfg.addr = dsn[k+1 : i-1] break @@ -119,7 +125,7 @@ func parseDSN(dsn string) (cfg *config, err error) { // non-empty left part must contain an '@' if j < 0 { - return nil, errInvalidDSN + return nil, errInvalidDSNUnescaped } } @@ -157,7 +163,7 @@ func parseDSN(dsn string) (cfg *config, err error) { } - // Set default location if not set + // Set default location if empty if cfg.loc == nil { cfg.loc = time.UTC } @@ -165,9 +171,9 @@ func parseDSN(dsn string) (cfg *config, err error) { return } +// parseDSNParams parses the DSN "query string" +// Values must be url.QueryEscape'ed func parseDSNParams(cfg *config, params string) (err error) { - cfg.params = make(map[string]string) - for _, v := range strings.Split(params, "&") { param := strings.SplitN(v, "=", 2) if len(param) != 2 { @@ -203,6 +209,9 @@ func parseDSNParams(cfg *config, params string) (err error) { // Time Location case "loc": + if value, err = url.QueryUnescape(value); err != nil { + return + } cfg.loc, err = time.LoadLocation(value) if err != nil { return @@ -233,7 +242,14 @@ func parseDSNParams(cfg *config, params string) (err error) { } default: - cfg.params[param[0]] = value + // lazy init + if cfg.params == nil { + cfg.params = make(map[string]string) + } + + if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil { + return + } } } diff --git a/utils_test.go b/utils_test.go index dbba669bf..3d439c909 100644 --- a/utils_test.go +++ b/utils_test.go @@ -30,7 +30,7 @@ var testDSNs = []struct { {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, {"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"@unix/", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"@unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, } func TestDSNParser(t *testing.T) { @@ -57,6 +57,7 @@ func TestDSNParser(t *testing.T) { func TestDSNParserInvalid(t *testing.T) { var invalidDSNs = []string{ "asdf/dbname", + "@net(addr/", //"/dbname?arg=/some/unescaped/path", } From 3f855aaafa5d7bf78876bc7bc858db37c40e42f7 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Wed, 16 Oct 2013 19:56:22 +0200 Subject: [PATCH 3/4] DSN param values must be escaped --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 40b9938b4..72c71d429 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Changes: - Refactored the driver tests - Added more benchmarks and moved all to a separate file - Other small refactoring + - DSN parameter values must now be url.QueryEscape'ed. This allows text values to contain special characters, such as '&'. New Features: From 6d1a06dd1fa2a68e2d3bc8acb718958bf6b65c76 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Thu, 17 Oct 2013 00:08:28 +0200 Subject: [PATCH 4/4] Fix protocol parsing --- README.md | 2 +- utils.go | 40 +++++++++++++++++----------------------- utils_test.go | 9 ++++++--- 3 files changed, 24 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 7a95668e3..0f04cfd92 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ A DSN in its fullest form: username:password@protocol(address)/dbname?param=value ``` -Except of the databasename, all values are optional. So the minimal DSN is: +Except for the databasename, all values are optional. So the minimal DSN is: ``` /dbname ``` diff --git a/utils.go b/utils.go index 7f8af463a..ca592086d 100644 --- a/utils.go +++ b/utils.go @@ -80,10 +80,10 @@ func parseDSN(dsn string) (cfg *config, err error) { // TODO: use strings.IndexByte when we can depend on Go 1.2 // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] - // Find the last '/' (since the password might contain a '/') + // Find the last '/' (since the password or the net addr might contain a '/') for i := len(dsn) - 1; i >= 0; i-- { if dsn[i] == '/' { - var j int + var j, k int // left part is empty if i <= 0 if i > 0 { @@ -93,7 +93,6 @@ func parseDSN(dsn string) (cfg *config, err error) { if dsn[j] == '@' { // username[:password] // Find the first ':' in dsn[:j] - var k int for k = 0; k < j; k++ { if dsn[k] == ':' { cfg.passwd = dsn[k+1 : j] @@ -102,31 +101,26 @@ func parseDSN(dsn string) (cfg *config, err error) { } cfg.user = dsn[:k] - // [protocol[(address)]] - // Find the first '(' in dsn[j+1:i] - for k = j + 1; k < i; k++ { - if dsn[k] == '(' { - // dsn[i-1] must be == ')' if an adress is specified - if dsn[i-1] != ')' { - if strings.ContainsRune(dsn[k+1:i], ')') { - return nil, errInvalidDSNUnescaped - } - return nil, errInvalidDSNAddr - } - cfg.addr = dsn[k+1 : i-1] - break - } - } - cfg.net = dsn[j+1 : k] - break } } - // non-empty left part must contain an '@' - if j < 0 { - return nil, errInvalidDSNUnescaped + // [protocol[(address)]] + // Find the first '(' in dsn[j+1:i] + for k = j + 1; k < i; k++ { + if dsn[k] == '(' { + // dsn[i-1] must be == ')' if an adress is specified + if dsn[i-1] != ')' { + if strings.ContainsRune(dsn[k+1:i], ')') { + return nil, errInvalidDSNUnescaped + } + return nil, errInvalidDSNAddr + } + cfg.addr = dsn[k+1 : i-1] + break + } } + cfg.net = dsn[j+1 : k] } // dbname[?param1=value1&...¶mN=valueN] diff --git a/utils_test.go b/utils_test.go index 3d439c909..9142a5e97 100644 --- a/utils_test.go +++ b/utils_test.go @@ -30,7 +30,7 @@ var testDSNs = []struct { {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, {"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"@unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, } func TestDSNParser(t *testing.T) { @@ -56,8 +56,11 @@ func TestDSNParser(t *testing.T) { func TestDSNParserInvalid(t *testing.T) { var invalidDSNs = []string{ - "asdf/dbname", - "@net(addr/", + "@net(addr/", // no closing brace + "@tcp(/", // no closing brace + "tcp(/", // no closing brace + "(/", // no closing brace + "net(addr)//", // unescaped //"/dbname?arg=/some/unescaped/path", }