Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Changes:
- Refactored the driver tests
- Added more benchmarks and moved all to a separate file
- Other small refactoring
- DSN parameter values must now be url.QueryEscape'ed. This allows text values to contain special characters, such as '&'.

New Features:

Expand Down
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ A DSN in its fullest form:
username:password@protocol(address)/dbname?param=value
```

Except of the databasename, all values are optional. So the minimal DSN is:
Except for the databasename, all values are optional. So the minimal DSN is:
```
/dbname
```
Expand Down Expand Up @@ -110,7 +110,7 @@ Possible Parameters are:
* `allowOldPasswords`: `allowAllFiles=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords).
* `charset`: Sets the charset used for client-server interaction ("SET NAMES `value`"). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
* `clientFoundRows`: `clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed.
* `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
* `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details. Please keep in mind, that param values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `US%2FPacific`.
* `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
* `strict`: Enable strict mode. MySQL warnings are treated as errors.
* `timeout`: **Driver** side connection timeout. The value must be a string of decimal numbers, each with optional fraction and a unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout).
Expand All @@ -122,6 +122,8 @@ All other parameters are interpreted as system variables:
* `tx_isolation`: *"SET [tx_isolation](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation)=`value`"*
* `param`: *"SET `param`=`value`"*

***The values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed!***

#### Examples
```
user@unix(/path/to/socket)/dbname
Expand All @@ -132,7 +134,7 @@ user:password@tcp(localhost:5555)/dbname?autocommit=true
```

```
user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?tls=skip-verify&charset=utf8mb4,utf8
user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?tls=skip-verify&charset=utf8mb4,utf8&sys_var=withSlash%2FandAt%40
```

```
Expand Down
3 changes: 2 additions & 1 deletion driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"io"
"io/ioutil"
"net"
"net/url"
"os"
"strings"
"testing"
Expand Down Expand Up @@ -206,7 +207,7 @@ func TestTimezoneConversion(t *testing.T) {
}

for _, tz := range zones {
runTests(t, dsn+"&parseTime=true&loc="+tz, tzTest)
runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest)
}
}

Expand Down
242 changes: 151 additions & 91 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,26 @@ import (
"crypto/tls"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"net/url"
"os"
"regexp"
"strings"
"time"
)

var (
errLog *log.Logger // Error Logger
dsnPattern *regexp.Regexp // Data Source Name Parser
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs

errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
)

func init() {
errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)

dsnPattern = regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]

tlsConfigRegister = make(map[string]*tls.Config)
}

Expand Down Expand Up @@ -77,98 +73,69 @@ func DeregisterTLSConfig(key string) {
delete(tlsConfigRegister, key)
}

// parseDSN parses the DSN string to a config
func parseDSN(dsn string) (cfg *config, err error) {
cfg = new(config)
cfg.params = make(map[string]string)

matches := dsnPattern.FindStringSubmatch(dsn)
names := dsnPattern.SubexpNames()

for i, match := range matches {
switch names[i] {
case "user":
cfg.user = match
case "passwd":
cfg.passwd = match
case "net":
cfg.net = match
case "addr":
cfg.addr = match
case "dbname":
cfg.dbname = match
case "params":
for _, v := range strings.Split(match, "&") {
param := strings.SplitN(v, "=", 2)
if len(param) != 2 {
continue
}

// cfg params
switch value := param[1]; param[0] {

// Disable INFILE whitelist / enable all files
case "allowAllFiles":
var isBool bool
cfg.allowAllFiles, isBool = readBool(value)
if !isBool {
err = fmt.Errorf("Invalid Bool value: %s", value)
return
}

// Switch "rowsAffected" mode
case "clientFoundRows":
var isBool bool
cfg.clientFoundRows, isBool = readBool(value)
if !isBool {
err = fmt.Errorf("Invalid Bool value: %s", value)
return
}
// TODO: use strings.IndexByte when we can depend on Go 1.2

// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
// Find the last '/' (since the password or the net addr might contain a '/')
for i := len(dsn) - 1; i >= 0; i-- {
if dsn[i] == '/' {
var j, k int

// left part is empty if i <= 0
if i > 0 {
// [username[:password]@][protocol[(address)]]
// Find the last '@' in dsn[:i]
for j = i; j >= 0; j-- {
if dsn[j] == '@' {
// username[:password]
// Find the first ':' in dsn[:j]
for k = 0; k < j; k++ {
if dsn[k] == ':' {
cfg.passwd = dsn[k+1 : j]
break
}
}
cfg.user = dsn[:k]

// Use old authentication mode (pre MySQL 4.1)
case "allowOldPasswords":
var isBool bool
cfg.allowOldPasswords, isBool = readBool(value)
if !isBool {
err = fmt.Errorf("Invalid Bool value: %s", value)
return
break
}
}

// Time Location
case "loc":
cfg.loc, err = time.LoadLocation(value)
if err != nil {
return
// [protocol[(address)]]
// Find the first '(' in dsn[j+1:i]
for k = j + 1; k < i; k++ {
if dsn[k] == '(' {
// dsn[i-1] must be == ')' if an adress is specified
if dsn[i-1] != ')' {
if strings.ContainsRune(dsn[k+1:i], ')') {
return nil, errInvalidDSNUnescaped
}
return nil, errInvalidDSNAddr
}
cfg.addr = dsn[k+1 : i-1]
break
}
}
cfg.net = dsn[j+1 : k]
}

// Dial Timeout
case "timeout":
cfg.timeout, err = time.ParseDuration(value)
if err != nil {
// dbname[?param1=value1&...&paramN=valueN]
// Find the first '?' in dsn[i+1:]
for j = i + 1; j < len(dsn); j++ {
if dsn[j] == '?' {
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
return
}

// TLS-Encryption
case "tls":
boolValue, isBool := readBool(value)
if isBool {
if boolValue {
cfg.tls = &tls.Config{}
}
} else {
if strings.ToLower(value) == "skip-verify" {
cfg.tls = &tls.Config{InsecureSkipVerify: true}
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
cfg.tls = tlsConfig
} else {
err = fmt.Errorf("Invalid value / unknown config name: %s", value)
return
}
}

default:
cfg.params[param[0]] = value
break
}
}
cfg.dbname = dsn[i+1 : j]

break
}
}

Expand All @@ -179,17 +146,110 @@ func parseDSN(dsn string) (cfg *config, err error) {

// Set default adress if empty
if cfg.addr == "" {
cfg.addr = "127.0.0.1:3306"
switch cfg.net {
case "tcp":
cfg.addr = "127.0.0.1:3306"
case "unix":
cfg.addr = "/tmp/mysql.sock"
default:
return nil, errors.New("Default addr for network '" + cfg.net + "' unknown")
}

}

// Set default location if not set
// Set default location if empty
if cfg.loc == nil {
cfg.loc = time.UTC
}

return
}

// parseDSNParams parses the DSN "query string"
// Values must be url.QueryEscape'ed
func parseDSNParams(cfg *config, params string) (err error) {
for _, v := range strings.Split(params, "&") {
param := strings.SplitN(v, "=", 2)
if len(param) != 2 {
continue
}

// cfg params
switch value := param[1]; param[0] {

// Disable INFILE whitelist / enable all files
case "allowAllFiles":
var isBool bool
cfg.allowAllFiles, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}

// Switch "rowsAffected" mode
case "clientFoundRows":
var isBool bool
cfg.clientFoundRows, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}

// Use old authentication mode (pre MySQL 4.1)
case "allowOldPasswords":
var isBool bool
cfg.allowOldPasswords, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}

// Time Location
case "loc":
if value, err = url.QueryUnescape(value); err != nil {
return
}
cfg.loc, err = time.LoadLocation(value)
if err != nil {
return
}

// Dial Timeout
case "timeout":
cfg.timeout, err = time.ParseDuration(value)
if err != nil {
return
}

// TLS-Encryption
case "tls":
boolValue, isBool := readBool(value)
if isBool {
if boolValue {
cfg.tls = &tls.Config{}
}
} else {
if strings.ToLower(value) == "skip-verify" {
cfg.tls = &tls.Config{InsecureSkipVerify: true}
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
cfg.tls = tlsConfig
} else {
return fmt.Errorf("Invalid value / unknown config name: %s", value)
}
}

default:
// lazy init
if cfg.params == nil {
cfg.params = make(map[string]string)
}

if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil {
return
}
}
}

return
}

// Returns the bool value of the input.
// The 2nd return value indicates if the input was a valid bool value
func readBool(input string) (value bool, valid bool) {
Expand Down
Loading