diff --git a/CHANGELOG.md b/CHANGELOG.md index 40b9938b4..72c71d429 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Changes: - Refactored the driver tests - Added more benchmarks and moved all to a separate file - Other small refactoring + - DSN parameter values must now be url.QueryEscape'ed. This allows text values to contain special characters, such as '&'. New Features: diff --git a/README.md b/README.md index 60f449ee5..0f04cfd92 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ A DSN in its fullest form: username:password@protocol(address)/dbname?param=value ``` -Except of the databasename, all values are optional. So the minimal DSN is: +Except for the databasename, all values are optional. So the minimal DSN is: ``` /dbname ``` @@ -110,7 +110,7 @@ Possible Parameters are: * `allowOldPasswords`: `allowAllFiles=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords). * `charset`: Sets the charset used for client-server interaction ("SET NAMES `value`"). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`). * `clientFoundRows`: `clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed. - * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details. + * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details. Please keep in mind, that param values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `US%2FPacific`. * `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string` * `strict`: Enable strict mode. MySQL warnings are treated as errors. * `timeout`: **Driver** side connection timeout. The value must be a string of decimal numbers, each with optional fraction and a unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout). @@ -122,6 +122,8 @@ All other parameters are interpreted as system variables: * `tx_isolation`: *"SET [tx_isolation](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation)=`value`"* * `param`: *"SET `param`=`value`"* +***The values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed!*** + #### Examples ``` user@unix(/path/to/socket)/dbname @@ -132,7 +134,7 @@ user:password@tcp(localhost:5555)/dbname?autocommit=true ``` ``` -user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?tls=skip-verify&charset=utf8mb4,utf8 +user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?tls=skip-verify&charset=utf8mb4,utf8&sys_var=withSlash%2FandAt%40 ``` ``` diff --git a/driver_test.go b/driver_test.go index 078484c43..41fbf06a9 100644 --- a/driver_test.go +++ b/driver_test.go @@ -15,6 +15,7 @@ import ( "io" "io/ioutil" "net" + "net/url" "os" "strings" "testing" @@ -206,7 +207,7 @@ func TestTimezoneConversion(t *testing.T) { } for _, tz := range zones { - runTests(t, dsn+"&parseTime=true&loc="+tz, tzTest) + runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest) } } diff --git a/utils.go b/utils.go index faada5864..ca592086d 100644 --- a/utils.go +++ b/utils.go @@ -13,30 +13,26 @@ import ( "crypto/tls" "database/sql/driver" "encoding/binary" + "errors" "fmt" "io" "log" + "net/url" "os" - "regexp" "strings" "time" ) var ( errLog *log.Logger // Error Logger - dsnPattern *regexp.Regexp // Data Source Name Parser tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs + + errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?") + errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)") ) func init() { errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile) - - dsnPattern = regexp.MustCompile( - `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] - `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] - `\/(?P.*?)` + // /dbname - `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] - tlsConfigRegister = make(map[string]*tls.Config) } @@ -77,98 +73,69 @@ func DeregisterTLSConfig(key string) { delete(tlsConfigRegister, key) } +// parseDSN parses the DSN string to a config func parseDSN(dsn string) (cfg *config, err error) { cfg = new(config) - cfg.params = make(map[string]string) - - matches := dsnPattern.FindStringSubmatch(dsn) - names := dsnPattern.SubexpNames() - - for i, match := range matches { - switch names[i] { - case "user": - cfg.user = match - case "passwd": - cfg.passwd = match - case "net": - cfg.net = match - case "addr": - cfg.addr = match - case "dbname": - cfg.dbname = match - case "params": - for _, v := range strings.Split(match, "&") { - param := strings.SplitN(v, "=", 2) - if len(param) != 2 { - continue - } - - // cfg params - switch value := param[1]; param[0] { - - // Disable INFILE whitelist / enable all files - case "allowAllFiles": - var isBool bool - cfg.allowAllFiles, isBool = readBool(value) - if !isBool { - err = fmt.Errorf("Invalid Bool value: %s", value) - return - } - // Switch "rowsAffected" mode - case "clientFoundRows": - var isBool bool - cfg.clientFoundRows, isBool = readBool(value) - if !isBool { - err = fmt.Errorf("Invalid Bool value: %s", value) - return - } + // TODO: use strings.IndexByte when we can depend on Go 1.2 + + // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] + // Find the last '/' (since the password or the net addr might contain a '/') + for i := len(dsn) - 1; i >= 0; i-- { + if dsn[i] == '/' { + var j, k int + + // left part is empty if i <= 0 + if i > 0 { + // [username[:password]@][protocol[(address)]] + // Find the last '@' in dsn[:i] + for j = i; j >= 0; j-- { + if dsn[j] == '@' { + // username[:password] + // Find the first ':' in dsn[:j] + for k = 0; k < j; k++ { + if dsn[k] == ':' { + cfg.passwd = dsn[k+1 : j] + break + } + } + cfg.user = dsn[:k] - // Use old authentication mode (pre MySQL 4.1) - case "allowOldPasswords": - var isBool bool - cfg.allowOldPasswords, isBool = readBool(value) - if !isBool { - err = fmt.Errorf("Invalid Bool value: %s", value) - return + break } + } - // Time Location - case "loc": - cfg.loc, err = time.LoadLocation(value) - if err != nil { - return + // [protocol[(address)]] + // Find the first '(' in dsn[j+1:i] + for k = j + 1; k < i; k++ { + if dsn[k] == '(' { + // dsn[i-1] must be == ')' if an adress is specified + if dsn[i-1] != ')' { + if strings.ContainsRune(dsn[k+1:i], ')') { + return nil, errInvalidDSNUnescaped + } + return nil, errInvalidDSNAddr + } + cfg.addr = dsn[k+1 : i-1] + break } + } + cfg.net = dsn[j+1 : k] + } - // Dial Timeout - case "timeout": - cfg.timeout, err = time.ParseDuration(value) - if err != nil { + // dbname[?param1=value1&...¶mN=valueN] + // Find the first '?' in dsn[i+1:] + for j = i + 1; j < len(dsn); j++ { + if dsn[j] == '?' { + if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { return } - - // TLS-Encryption - case "tls": - boolValue, isBool := readBool(value) - if isBool { - if boolValue { - cfg.tls = &tls.Config{} - } - } else { - if strings.ToLower(value) == "skip-verify" { - cfg.tls = &tls.Config{InsecureSkipVerify: true} - } else if tlsConfig, ok := tlsConfigRegister[value]; ok { - cfg.tls = tlsConfig - } else { - err = fmt.Errorf("Invalid value / unknown config name: %s", value) - return - } - } - - default: - cfg.params[param[0]] = value + break } } + cfg.dbname = dsn[i+1 : j] + + break } } @@ -179,10 +146,18 @@ func parseDSN(dsn string) (cfg *config, err error) { // Set default adress if empty if cfg.addr == "" { - cfg.addr = "127.0.0.1:3306" + switch cfg.net { + case "tcp": + cfg.addr = "127.0.0.1:3306" + case "unix": + cfg.addr = "/tmp/mysql.sock" + default: + return nil, errors.New("Default addr for network '" + cfg.net + "' unknown") + } + } - // Set default location if not set + // Set default location if empty if cfg.loc == nil { cfg.loc = time.UTC } @@ -190,6 +165,91 @@ func parseDSN(dsn string) (cfg *config, err error) { return } +// parseDSNParams parses the DSN "query string" +// Values must be url.QueryEscape'ed +func parseDSNParams(cfg *config, params string) (err error) { + for _, v := range strings.Split(params, "&") { + param := strings.SplitN(v, "=", 2) + if len(param) != 2 { + continue + } + + // cfg params + switch value := param[1]; param[0] { + + // Disable INFILE whitelist / enable all files + case "allowAllFiles": + var isBool bool + cfg.allowAllFiles, isBool = readBool(value) + if !isBool { + return fmt.Errorf("Invalid Bool value: %s", value) + } + + // Switch "rowsAffected" mode + case "clientFoundRows": + var isBool bool + cfg.clientFoundRows, isBool = readBool(value) + if !isBool { + return fmt.Errorf("Invalid Bool value: %s", value) + } + + // Use old authentication mode (pre MySQL 4.1) + case "allowOldPasswords": + var isBool bool + cfg.allowOldPasswords, isBool = readBool(value) + if !isBool { + return fmt.Errorf("Invalid Bool value: %s", value) + } + + // Time Location + case "loc": + if value, err = url.QueryUnescape(value); err != nil { + return + } + cfg.loc, err = time.LoadLocation(value) + if err != nil { + return + } + + // Dial Timeout + case "timeout": + cfg.timeout, err = time.ParseDuration(value) + if err != nil { + return + } + + // TLS-Encryption + case "tls": + boolValue, isBool := readBool(value) + if isBool { + if boolValue { + cfg.tls = &tls.Config{} + } + } else { + if strings.ToLower(value) == "skip-verify" { + cfg.tls = &tls.Config{InsecureSkipVerify: true} + } else if tlsConfig, ok := tlsConfigRegister[value]; ok { + cfg.tls = tlsConfig + } else { + return fmt.Errorf("Invalid value / unknown config name: %s", value) + } + } + + default: + // lazy init + if cfg.params == nil { + cfg.params = make(map[string]string) + } + + if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil { + return + } + } + } + + return +} + // Returns the bool value of the input. // The 2nd return value indicates if the input was a valid bool value func readBool(input string) (value bool, valid bool) { diff --git a/utils_test.go b/utils_test.go index 39ad6cd5d..9142a5e97 100644 --- a/utils_test.go +++ b/utils_test.go @@ -14,23 +14,26 @@ import ( "time" ) -func TestDSNParser(t *testing.T) { - var testDSNs = []struct { - in string - out string - loc *time.Location - }{ - {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls: allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC}, - {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local}, - {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - } +var testDSNs = []struct { + in string + out string + loc *time.Location +}{ + {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls: allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC}, + {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local}, + {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p timeout:0 tls: allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, +} +func TestDSNParser(t *testing.T) { var cfg *config var err error var res string @@ -51,6 +54,35 @@ func TestDSNParser(t *testing.T) { } } +func TestDSNParserInvalid(t *testing.T) { + var invalidDSNs = []string{ + "@net(addr/", // no closing brace + "@tcp(/", // no closing brace + "tcp(/", // no closing brace + "(/", // no closing brace + "net(addr)//", // unescaped + //"/dbname?arg=/some/unescaped/path", + } + + for i, tst := range invalidDSNs { + if _, err := parseDSN(tst); err == nil { + t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst) + } + } +} + +func BenchmarkParseDSN(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, tst := range testDSNs { + if _, err := parseDSN(tst.in); err != nil { + b.Error(err.Error()) + } + } + } +} + func TestScanNullTime(t *testing.T) { var scanTests = []struct { in interface{}