From 5042da9d13e737eed8ba5140204d6bb0b486a1b3 Mon Sep 17 00:00:00 2001 From: KJ Tsanaktsidis Date: Sun, 15 Dec 2019 22:18:13 +1100 Subject: [PATCH 1/2] Allow registering custom CredentialProviders for per-conn passwords When using a temporary credential system for MySQL, for example IAM database authenticaiton on AWS or the Database secret backend for Hashicorp Vault, it may not be the case that the same username and password be used for opening every connection in a *sql.DB. This PR adds funcionality whereby the caller can, instead of specifying cfg.User and cfg.Passwd (in the DSN as user:pass@...), specify a CredentialProvider= arguemnt which refers to a callback registered with RegisterCredentialProvider. When a new connection is to be opened, if the CredentialProvider callback is specified, that is called to obtain a username/password pair rather than using the values from the DSN. --- AUTHORS | 1 + README.md | 10 ++++ auth.go | 67 ++++++++++++++++++++------ auth_test.go | 124 ++++++++++++++++++++++++------------------------- connector.go | 15 ++++-- driver_test.go | 72 ++++++++++++++++++++++++++++ dsn.go | 45 +++++++++++------- dsn_test.go | 6 ++- packets.go | 8 ++-- 9 files changed, 247 insertions(+), 101 deletions(-) diff --git a/AUTHORS b/AUTHORS index ad5989800..c144c46ab 100644 --- a/AUTHORS +++ b/AUTHORS @@ -103,3 +103,4 @@ Multiplay Ltd. Percona LLC Pivotal Inc. Stripe Inc. +Zendesk Inc. diff --git a/README.md b/README.md index 2d15ffda3..a90f7938e 100644 --- a/README.md +++ b/README.md @@ -209,6 +209,16 @@ SELECT u.id FROM users as u will return `u.id` instead of just `id` if `columnsWithAlias=true`. +#### `credentialProvider` + +``` +Type: string +Valid Values: +Default: "" +``` + +If set, this must refer to a credential provider name registered with `RegisterCredentialProvider`. When this is set, the username and password in the DSN will be ignored; instead, each time a conneciton is to be opened, the named credential provider function will be called to obtain a username/password to connect with. This is useful when using, for example, IAM database auth in Amazon AWS, where "passwords" are actually temporary tokens that expire. + ##### `interpolateParams` ``` diff --git a/auth.go b/auth.go index fec7040d4..8162cd342 100644 --- a/auth.go +++ b/auth.go @@ -15,13 +15,16 @@ import ( "crypto/sha256" "crypto/x509" "encoding/pem" + "fmt" "sync" ) // server pub keys registry var ( - serverPubKeyLock sync.RWMutex - serverPubKeyRegistry map[string]*rsa.PublicKey + serverPubKeyLock sync.RWMutex + serverPubKeyRegistry map[string]*rsa.PublicKey + credentialProviderLock sync.RWMutex + credentialProviderRegistry map[string]CredentialProviderFunc ) // RegisterServerPubKey registers a server RSA public key which can be used to @@ -81,6 +84,44 @@ func getServerPubKey(name string) (pubKey *rsa.PublicKey) { return } +// CredentialProviderFunc is a function which can be used to fetch a username/password +// pair for use when opening a new MySQL connection. The first return value is the username +// and the second the password. +type CredentialProviderFunc func() (user string, password string, error error) + +// RegisterCredentialProvider registers a function to be called on every connection open to +// get the username and password to call +func RegisterCredentialProvider(name string, providerFunc CredentialProviderFunc) { + credentialProviderLock.Lock() + if credentialProviderRegistry == nil { + credentialProviderRegistry = make(map[string]CredentialProviderFunc) + } + credentialProviderRegistry[name] = providerFunc + credentialProviderLock.Unlock() +} + +// DeregisterCredentialProvider removes a function registered with RegisterCredentialProvider +func DeregisterCredentialProvider(name string) { + credentialProviderLock.Lock() + if credentialProviderRegistry != nil { + delete(credentialProviderRegistry, name) + } + credentialProviderLock.Unlock() +} + +func getCredentialsFromConfig(cfg *Config) (user string, password string, error error) { + if cfg.CredentialProvider != "" { + credentialProviderLock.RLock() + defer credentialProviderLock.RUnlock() + cpFunc, ok := credentialProviderRegistry[cfg.CredentialProvider] + if !ok { + return "", "", fmt.Errorf("credential provider %s not registered", cfg.CredentialProvider) + } + return cpFunc() + } + return cfg.User, cfg.Passwd, nil +} + // Hash password using pre 4.1 (old password) method // https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c type myRnd struct { @@ -237,10 +278,10 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro return mc.writeAuthSwitchPacket(enc) } -func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { +func (mc *mysqlConn) auth(authData []byte, plugin string, password string) ([]byte, error) { switch plugin { case "caching_sha2_password": - authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) + authResp := scrambleSHA256Password(authData, password) return authResp, nil case "mysql_old_password": @@ -250,7 +291,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { // Note: there are edge cases where this should work but doesn't; // this is currently "wontfix": // https://github.com/go-sql-driver/mysql/issues/184 - authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0) + authResp := append(scrambleOldPassword(authData[:8], password), 0) return authResp, nil case "mysql_clear_password": @@ -259,7 +300,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { } // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - return append([]byte(mc.cfg.Passwd), 0), nil + return append([]byte(password), 0), nil case "mysql_native_password": if !mc.cfg.AllowNativePasswords { @@ -267,16 +308,16 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { } // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html // Native password authentication only need and will need 20-byte challenge. - authResp := scramblePassword(authData[:20], mc.cfg.Passwd) + authResp := scramblePassword(authData[:20], password) return authResp, nil case "sha256_password": - if len(mc.cfg.Passwd) == 0 { + if len(password) == 0 { return []byte{0}, nil } if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - return append([]byte(mc.cfg.Passwd), 0), nil + return append([]byte(password), 0), nil } pubKey := mc.cfg.pubKey @@ -286,7 +327,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { } // encrypted password - enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) + enc, err := encryptPassword(password, authData, pubKey) return enc, err default: @@ -295,7 +336,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { } } -func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { +func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string, password string) error { // Read Result Packet authData, newPlugin, err := mc.readAuthResult() if err != nil { @@ -315,7 +356,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { plugin = newPlugin - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, password) if err != nil { return err } @@ -352,7 +393,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { case cachingSha2PasswordPerformFullAuthentication: if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) + err = mc.writeAuthSwitchPacket(append([]byte(password), 0)) if err != nil { return err } diff --git a/auth_test.go b/auth_test.go index 1920ef39f..c1d08a454 100644 --- a/auth_test.go +++ b/auth_test.go @@ -85,11 +85,11 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -115,7 +115,7 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -130,11 +130,11 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -157,7 +157,7 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -172,11 +172,11 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -208,7 +208,7 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { conn.maxReads = 3 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -228,11 +228,11 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -261,7 +261,7 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { conn.maxReads = 2 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -280,11 +280,11 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -317,7 +317,7 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { conn.maxReads = 3 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -336,7 +336,7 @@ func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - _, err := mc.auth(authData, plugin) + _, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != ErrCleartextPassword { t.Errorf("expected ErrCleartextPassword, got %v", err) } @@ -353,11 +353,11 @@ func TestAuthFastCleartextPassword(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -380,7 +380,7 @@ func TestAuthFastCleartextPassword(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -396,11 +396,11 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -423,7 +423,7 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -439,7 +439,7 @@ func TestAuthFastNativePasswordNotAllowed(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - _, err := mc.auth(authData, plugin) + _, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != ErrNativePassword { t.Errorf("expected ErrNativePassword, got %v", err) } @@ -455,11 +455,11 @@ func TestAuthFastNativePassword(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -483,7 +483,7 @@ func TestAuthFastNativePassword(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -498,11 +498,11 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -525,7 +525,7 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -540,11 +540,11 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -569,7 +569,7 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { conn.maxReads = 2 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -588,11 +588,11 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -617,7 +617,7 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { conn.maxReads = 2 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -637,11 +637,11 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -651,7 +651,7 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -670,7 +670,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { plugin := "sha256_password" // send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } @@ -678,7 +678,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { // unset TLS config to prevent the actual establishment of a TLS wrapper mc.cfg.tls = nil - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -699,7 +699,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -728,7 +728,7 @@ func TestAuthSwitchCachingSHA256PasswordCached(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -761,7 +761,7 @@ func TestAuthSwitchCachingSHA256PasswordEmpty(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -797,7 +797,7 @@ func TestAuthSwitchCachingSHA256PasswordFullRSA(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -842,7 +842,7 @@ func TestAuthSwitchCachingSHA256PasswordFullRSAWithKey(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -885,7 +885,7 @@ func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -912,7 +912,7 @@ func TestAuthSwitchCleartextPasswordNotAllowed(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - err := mc.handleAuthResult(authData, plugin) + err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd) if err != ErrCleartextPassword { t.Errorf("expected ErrCleartextPassword, got %v", err) } @@ -935,7 +935,7 @@ func TestAuthSwitchCleartextPassword(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -962,7 +962,7 @@ func TestAuthSwitchCleartextPasswordEmpty(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -984,7 +984,7 @@ func TestAuthSwitchNativePasswordNotAllowed(t *testing.T) { authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, 31} plugin := "caching_sha2_password" - err := mc.handleAuthResult(authData, plugin) + err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd) if err != ErrNativePassword { t.Errorf("expected ErrNativePassword, got %v", err) } @@ -1009,7 +1009,7 @@ func TestAuthSwitchNativePassword(t *testing.T) { 48, 31, 89, 39, 55, 31} plugin := "caching_sha2_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1039,7 +1039,7 @@ func TestAuthSwitchNativePasswordEmpty(t *testing.T) { 48, 31, 89, 39, 55, 31} plugin := "caching_sha2_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1059,7 +1059,7 @@ func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) { authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - err := mc.handleAuthResult(authData, plugin) + err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd) if err != ErrOldPassword { t.Errorf("expected ErrOldPassword, got %v", err) } @@ -1075,7 +1075,7 @@ func TestOldAuthSwitchNotAllowed(t *testing.T) { authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - err := mc.handleAuthResult(authData, plugin) + err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd) if err != ErrOldPassword { t.Errorf("expected ErrOldPassword, got %v", err) } @@ -1099,7 +1099,7 @@ func TestAuthSwitchOldPassword(t *testing.T) { 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1126,7 +1126,7 @@ func TestOldAuthSwitch(t *testing.T) { 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1153,7 +1153,7 @@ func TestAuthSwitchOldPasswordEmpty(t *testing.T) { 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1180,7 +1180,7 @@ func TestOldAuthSwitchPasswordEmpty(t *testing.T) { 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1209,7 +1209,7 @@ func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1244,7 +1244,7 @@ func TestAuthSwitchSHA256PasswordRSA(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1280,7 +1280,7 @@ func TestAuthSwitchSHA256PasswordRSAWithKey(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1316,7 +1316,7 @@ func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } diff --git a/connector.go b/connector.go index d567b4e4f..38715e9a5 100644 --- a/connector.go +++ b/connector.go @@ -88,25 +88,32 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { plugin = defaultAuthPlugin } + user, password, err := getCredentialsFromConfig(c.cfg) + if err != nil { + mc.cleanup() + return nil, err + } + // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, password) if err != nil { // try the default auth plugin, if using the requested plugin failed errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) plugin = defaultAuthPlugin - authResp, err = mc.auth(authData, plugin) + authResp, err = mc.auth(authData, plugin, password) if err != nil { mc.cleanup() return nil, err } } - if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { + + if err = mc.writeHandshakeResponsePacket(authResp, plugin, user); err != nil { mc.cleanup() return nil, err } // Handle response to auth packet, switch methods if possible - if err = mc.handleAuthResult(authData, plugin); err != nil { + if err = mc.handleAuthResult(authData, plugin, password); err != nil { // Authentication failed and MySQL has already closed the connection // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). // Do not send COM_QUIT, just cleanup and return the error. diff --git a/driver_test.go b/driver_test.go index ace083dfc..63614c04b 100644 --- a/driver_test.go +++ b/driver_test.go @@ -3163,3 +3163,75 @@ func TestConnectorTimeoutsWatchCancel(t *testing.T) { t.Errorf("connection not closed") } } + +func TestCredentialProviderFunc(t *testing.T) { + // Our test provider func should return a valid password, then an invalid one, then a valid one + // to test that it really is having an effect. + shouldFailCreds := false + shouldFailError := false + RegisterCredentialProvider("TestCredentialProviderFunc", func() (string, string, error) { + if shouldFailCreds { + return "fail", "fail", nil + } + if shouldFailError { + return "", "", fmt.Errorf("credential_error") + } + return user, pass, nil + }) + defer DeregisterCredentialProvider("TestCredentialProviderFunc") + dsn := fmt.Sprintf("%s/%s?timeout=30s&credentialProvider=TestCredentialProviderFunc", netAddr, dbname) + runTests(t, dsn, func(dbt *DBTest) { + ctx := context.Background() + c1, err := dbt.db.Conn(ctx) + if err != nil { + dbt.Fatalf("error opening conn: %s", err) + } + defer c1.Close() + + rows, err := c1.QueryContext(ctx, "SELECT USER()") + if err != nil { + dbt.Fatalf("error running SELECT USER(): %s", err) + } + connUserAndHost := "" + for rows.Next() { + err := rows.Scan(&connUserAndHost) + if err != nil { + dbt.Fatalf("error running query: %s", err) + } + } + parts := strings.Split(connUserAndHost, "@") + connUser := strings.Join(parts[:len(parts)-1], "@") + if connUser != user { + dbt.Errorf("USER() and credentials don't match: %s != %s", connUser, user) + } + + // open one that should fail (wrong creds) + shouldFailCreds = true + _, err = dbt.db.Conn(ctx) + shouldFailCreds = false + if err == nil { + dbt.Errorf("expected second open to fail") + } + + // open one that should fail (with an error) + shouldFailError = true + _, err = dbt.db.Conn(ctx) + shouldFailError = false + if err == nil { + dbt.Errorf("expected third open to fail") + } + if !strings.Contains(err.Error(), "credential_error") { + dbt.Errorf("expected third open to fail with credential_error") + } + + c4, err := dbt.db.Conn(ctx) + if err != nil { + dbt.Fatalf("error opening conn: %s", err) + } + defer c4.Close() + err = c4.PingContext(ctx) + if err != nil { + dbt.Errorf("error running PingContext: %s", err) + } + }) +} diff --git a/dsn.go b/dsn.go index 1d9b4ab0a..0c3a58fe3 100644 --- a/dsn.go +++ b/dsn.go @@ -34,22 +34,23 @@ var ( // If a new Config is created instead of being parsed from a DSN string, // the NewConfig function should be used, which sets default values. type Config struct { - User string // Username - Passwd string // Password (requires User) - Net string // Network type - Addr string // Network address (requires Net) - DBName string // Database name - Params map[string]string // Connection parameters - Collation string // Connection collation - Loc *time.Location // Location for time.Time values - MaxAllowedPacket int // Max packet size allowed - ServerPubKey string // Server public key name - pubKey *rsa.PublicKey // Server public key - TLSConfig string // TLS configuration name - tls *tls.Config // TLS configuration - Timeout time.Duration // Dial timeout - ReadTimeout time.Duration // I/O read timeout - WriteTimeout time.Duration // I/O write timeout + User string // Username + Passwd string // Password (requires User) + CredentialProvider string // Credential provider name registered with RegisterCredentialProvider + Net string // Network type + Addr string // Network address (requires Net) + DBName string // Database name + Params map[string]string // Connection parameters + Collation string // Connection collation + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + pubKey *rsa.PublicKey // Server public key + TLSConfig string // TLS configuration name + tls *tls.Config // TLS configuration + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE AllowCleartextPasswords bool // Allows the cleartext client side plugin @@ -347,6 +348,16 @@ func (cfg *Config) FormatDSN() string { } + if cfg.CredentialProvider != "" { + if hasParam { + buf.WriteString("&credentialProvider=") + } else { + hasParam = true + buf.WriteString("?credentialProvider=") + } + buf.WriteString(cfg.CredentialProvider) + } + // other params if cfg.Params != nil { var params []string @@ -613,6 +624,8 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return } + case "credentialProvider": + cfg.CredentialProvider = value default: // lazy init if cfg.Params == nil { diff --git a/dsn_test.go b/dsn_test.go index 50dc2932c..82194b52e 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -71,8 +71,10 @@ var testDSNs = []struct { }, { "tcp(de:ad:be:ef::ca:fe)/dbname", &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, -}, -} +}, { + "tcp(localhost)/dbname?credentialProvider=foobar", + &Config{Net: "tcp", Addr: "localhost:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CredentialProvider: "foobar"}, +}} func TestDSNParser(t *testing.T) { for i, tst := range testDSNs { diff --git a/packets.go b/packets.go index 30b3352c2..18b0d3731 100644 --- a/packets.go +++ b/packets.go @@ -276,7 +276,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string, user string) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -310,7 +310,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string clientFlags |= clientPluginAuthLenEncClientData } - pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 + pktLen := 4 + 4 + 1 + 23 + len(user) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 // To specify a db name if n := len(mc.cfg.DBName); n > 0 { @@ -373,8 +373,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string } // User [null terminated string] - if len(mc.cfg.User) > 0 { - pos += copy(data[pos:], mc.cfg.User) + if len(user) > 0 { + pos += copy(data[pos:], user) } data[pos] = 0x00 pos++ From fc21a985cad1ffda243edbeaa7736d7b757260b4 Mon Sep 17 00:00:00 2001 From: KJ Tsanaktsidis Date: Tue, 17 Dec 2019 09:11:53 +1100 Subject: [PATCH 2/2] Remove registry for CredentialProvider Instead, it can be used by passing the config object directly to a Connector. --- README.md | 35 ++++++++++++++++++-------- auth.go | 40 ++---------------------------- connector.go | 2 +- driver_test.go | 66 +++++++++++++++++++++++++++++++------------------- dsn.go | 53 ++++++++++++++++++---------------------- dsn_test.go | 3 --- 6 files changed, 93 insertions(+), 106 deletions(-) diff --git a/README.md b/README.md index a90f7938e..5a3489031 100644 --- a/README.md +++ b/README.md @@ -209,16 +209,6 @@ SELECT u.id FROM users as u will return `u.id` instead of just `id` if `columnsWithAlias=true`. -#### `credentialProvider` - -``` -Type: string -Valid Values: -Default: "" -``` - -If set, this must refer to a credential provider name registered with `RegisterCredentialProvider`. When this is set, the username and password in the DSN will be ignored; instead, each time a conneciton is to be opened, the named credential provider function will be called to obtain a username/password to connect with. This is useful when using, for example, IAM database auth in Amazon AWS, where "passwords" are actually temporary tokens that expire. - ##### `interpolateParams` ``` @@ -377,6 +367,31 @@ Examples: * [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'` * [`tx_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation): `SET tx_isolation='REPEATABLE-READ'` +#### Non-DSN parameters + +Some parameters (those that have types too complex to fit into a string) are not supported as part of a DSN string, but can only be specified by using the Connector interface. To use these parameters, set your database client up like so: + +```go +dbConfig := mysql.Config { + Addr: "localhost:3306", + // ... other parameters ... +} +connector, err := mysql.NewConnector(dbConfig) +if err != nil { + panic(error) +} +db := sql.OpenDB(connector) +``` + +##### `CredentialProvider` + +``` +Type: CredentialProviderFunc +Default: nil +``` + +If set, this must refer to a credential provider function of type `CredentialProviderFunc`. When this is set, the `User` and `Passwd` fields in the config will be ignored; instead, each time a connection is to be opened, the credential provider function will be called to obtain a username/password to connect with. This is useful when using, for example, IAM database auth in Amazon AWS, where "passwords" are actually temporary tokens that expire. + #### Examples ``` diff --git a/auth.go b/auth.go index 8162cd342..37606d657 100644 --- a/auth.go +++ b/auth.go @@ -15,16 +15,13 @@ import ( "crypto/sha256" "crypto/x509" "encoding/pem" - "fmt" "sync" ) // server pub keys registry var ( - serverPubKeyLock sync.RWMutex - serverPubKeyRegistry map[string]*rsa.PublicKey - credentialProviderLock sync.RWMutex - credentialProviderRegistry map[string]CredentialProviderFunc + serverPubKeyLock sync.RWMutex + serverPubKeyRegistry map[string]*rsa.PublicKey ) // RegisterServerPubKey registers a server RSA public key which can be used to @@ -89,39 +86,6 @@ func getServerPubKey(name string) (pubKey *rsa.PublicKey) { // and the second the password. type CredentialProviderFunc func() (user string, password string, error error) -// RegisterCredentialProvider registers a function to be called on every connection open to -// get the username and password to call -func RegisterCredentialProvider(name string, providerFunc CredentialProviderFunc) { - credentialProviderLock.Lock() - if credentialProviderRegistry == nil { - credentialProviderRegistry = make(map[string]CredentialProviderFunc) - } - credentialProviderRegistry[name] = providerFunc - credentialProviderLock.Unlock() -} - -// DeregisterCredentialProvider removes a function registered with RegisterCredentialProvider -func DeregisterCredentialProvider(name string) { - credentialProviderLock.Lock() - if credentialProviderRegistry != nil { - delete(credentialProviderRegistry, name) - } - credentialProviderLock.Unlock() -} - -func getCredentialsFromConfig(cfg *Config) (user string, password string, error error) { - if cfg.CredentialProvider != "" { - credentialProviderLock.RLock() - defer credentialProviderLock.RUnlock() - cpFunc, ok := credentialProviderRegistry[cfg.CredentialProvider] - if !ok { - return "", "", fmt.Errorf("credential provider %s not registered", cfg.CredentialProvider) - } - return cpFunc() - } - return cfg.User, cfg.Passwd, nil -} - // Hash password using pre 4.1 (old password) method // https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c type myRnd struct { diff --git a/connector.go b/connector.go index 38715e9a5..9a5088157 100644 --- a/connector.go +++ b/connector.go @@ -88,7 +88,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { plugin = defaultAuthPlugin } - user, password, err := getCredentialsFromConfig(c.cfg) + user, password, err := c.cfg.getCredentials() if err != nil { mc.cleanup() return nil, err diff --git a/driver_test.go b/driver_test.go index 63614c04b..5b5fc7c2e 100644 --- a/driver_test.go +++ b/driver_test.go @@ -125,36 +125,47 @@ func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBT } func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { + cfg, err := ParseDSN(dsn) + if err != nil { + t.Fatalf("error formatting DSN") + } + runTestsWithConfig(t, cfg, tests...) +} + +func runTestsWithConfig(t *testing.T, cfg *Config, tests ...func(dbt *DBTest)) { if !available { t.Skipf("MySQL server not running on %s", netAddr) } - db, err := sql.Open("mysql", dsn) + connector, err := NewConnector(cfg) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } + db := sql.OpenDB(connector) defer db.Close() db.Exec("DROP TABLE IF EXISTS test") - dsn2 := dsn + "&interpolateParams=true" + cfg2 := cfg.Clone() + cfg2.InterpolateParams = true var db2 *sql.DB - if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { - db2, err = sql.Open("mysql", dsn2) - if err != nil { - t.Fatalf("error connecting: %s", err.Error()) - } + connector2, err := NewConnector(cfg2) + if err != errInvalidDSNUnsafeCollation { + db2 = sql.OpenDB(connector2) defer db2.Close() + } else if err != nil { + t.Fatalf("error connecting: %s", err.Error()) } - dsn3 := dsn + "&multiStatements=true" + cfg3 := cfg.Clone() + cfg3.MultiStatements = true var db3 *sql.DB - if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation { - db3, err = sql.Open("mysql", dsn3) - if err != nil { - t.Fatalf("error connecting: %s", err.Error()) - } + connector3, err := NewConnector(cfg2) + if err != errInvalidDSNUnsafeCollation { + db3 = sql.OpenDB(connector3) defer db3.Close() + } else if err != nil { + t.Fatalf("error connecting: %s", err.Error()) } dbt := &DBTest{t, db} @@ -3169,18 +3180,23 @@ func TestCredentialProviderFunc(t *testing.T) { // to test that it really is having an effect. shouldFailCreds := false shouldFailError := false - RegisterCredentialProvider("TestCredentialProviderFunc", func() (string, string, error) { - if shouldFailCreds { - return "fail", "fail", nil - } - if shouldFailError { - return "", "", fmt.Errorf("credential_error") - } - return user, pass, nil - }) - defer DeregisterCredentialProvider("TestCredentialProviderFunc") - dsn := fmt.Sprintf("%s/%s?timeout=30s&credentialProvider=TestCredentialProviderFunc", netAddr, dbname) - runTests(t, dsn, func(dbt *DBTest) { + cfg := &Config{ + Addr: addr, + Net: prot, + DBName: dbname, + Collation: defaultCollation, + AllowNativePasswords: true, + CredentialProvider: func() (string, string, error) { + if shouldFailCreds { + return "fail", "fail", nil + } + if shouldFailError { + return "", "", fmt.Errorf("credential_error") + } + return user, pass, nil + }, + } + runTestsWithConfig(t, cfg, func(dbt *DBTest) { ctx := context.Background() c1, err := dbt.db.Conn(ctx) if err != nil { diff --git a/dsn.go b/dsn.go index 0c3a58fe3..00fc8ca9a 100644 --- a/dsn.go +++ b/dsn.go @@ -34,23 +34,23 @@ var ( // If a new Config is created instead of being parsed from a DSN string, // the NewConfig function should be used, which sets default values. type Config struct { - User string // Username - Passwd string // Password (requires User) - CredentialProvider string // Credential provider name registered with RegisterCredentialProvider - Net string // Network type - Addr string // Network address (requires Net) - DBName string // Database name - Params map[string]string // Connection parameters - Collation string // Connection collation - Loc *time.Location // Location for time.Time values - MaxAllowedPacket int // Max packet size allowed - ServerPubKey string // Server public key name - pubKey *rsa.PublicKey // Server public key - TLSConfig string // TLS configuration name - tls *tls.Config // TLS configuration - Timeout time.Duration // Dial timeout - ReadTimeout time.Duration // I/O read timeout - WriteTimeout time.Duration // I/O write timeout + User string // Username + Passwd string // Password (requires User) + CredentialProvider CredentialProviderFunc // Credential provider function + Net string // Network type + Addr string // Network address (requires Net) + DBName string // Database name + Params map[string]string // Connection parameters + Collation string // Connection collation + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + pubKey *rsa.PublicKey // Server public key + TLSConfig string // TLS configuration name + tls *tls.Config // TLS configuration + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE AllowCleartextPasswords bool // Allows the cleartext client side plugin @@ -348,16 +348,6 @@ func (cfg *Config) FormatDSN() string { } - if cfg.CredentialProvider != "" { - if hasParam { - buf.WriteString("&credentialProvider=") - } else { - hasParam = true - buf.WriteString("?credentialProvider=") - } - buf.WriteString(cfg.CredentialProvider) - } - // other params if cfg.Params != nil { var params []string @@ -624,8 +614,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return } - case "credentialProvider": - cfg.CredentialProvider = value default: // lazy init if cfg.Params == nil { @@ -641,6 +629,13 @@ func parseDSNParams(cfg *Config, params string) (err error) { return } +func (cfg *Config) getCredentials() (user string, password string, err error) { + if cfg.CredentialProvider != nil { + return cfg.CredentialProvider() + } + return cfg.User, cfg.Passwd, nil +} + func ensureHavePort(addr string) string { if _, _, err := net.SplitHostPort(addr); err != nil { return net.JoinHostPort(addr, "3306") diff --git a/dsn_test.go b/dsn_test.go index 82194b52e..2f5ab658f 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -71,9 +71,6 @@ var testDSNs = []struct { }, { "tcp(de:ad:be:ef::ca:fe)/dbname", &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, -}, { - "tcp(localhost)/dbname?credentialProvider=foobar", - &Config{Net: "tcp", Addr: "localhost:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CredentialProvider: "foobar"}, }} func TestDSNParser(t *testing.T) {