From 09fc70b6de09fd63b2b47d4df0574675c9e83493 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 17 Jan 2018 10:17:22 -0800 Subject: [PATCH 01/27] Renamed some tests and test parameters for clarity, and adhere to Go conventions (#74) --- firebase_test.go | 95 +++++++++++++---------- testdata/firebase_config.json | 4 +- testdata/firebase_config_invalid_key.json | 4 +- testdata/firebase_config_partial.json | 2 +- 4 files changed, 61 insertions(+), 44 deletions(-) diff --git a/firebase_test.go b/firebase_test.go index 686d6af5..df41c56d 100644 --- a/firebase_test.go +++ b/firebase_test.go @@ -15,6 +15,7 @@ package firebase import ( + "fmt" "io/ioutil" "log" "net/http" @@ -355,41 +356,46 @@ func TestAutoInit(t *testing.T) { wantOptions *Config }{ { - "No environment variable, no explicit options", + "", "", nil, &Config{ProjectID: "mock-project-id"}, // from default creds here and below. - }, { - "Environment variable set to file, no explicit options", + }, + { + "", "testdata/firebase_config.json", nil, &Config{ - ProjectID: "hipster-chat-mock", - StorageBucket: "hipster-chat.appspot.mock", + ProjectID: "auto-init-project-id", + StorageBucket: "auto-init.storage.bucket", }, - }, { - "Environment variable set to string, no explicit options", + }, + { + "", `{ - "projectId": "hipster-chat-mock", - "storageBucket": "hipster-chat.appspot.mock" + "projectId": "auto-init-project-id", + "storageBucket": "auto-init.storage.bucket" }`, nil, &Config{ - ProjectID: "hipster-chat-mock", - StorageBucket: "hipster-chat.appspot.mock", + ProjectID: "auto-init-project-id", + StorageBucket: "auto-init.storage.bucket", }, - }, { - "Environment variable set to file with some values missing, no explicit options", + }, + { + "", "testdata/firebase_config_partial.json", nil, - &Config{ProjectID: "hipster-chat-mock"}, - }, { - "Environment variable set to string with some values missing, no explicit options", - `{"projectId": "hipster-chat-mock"}`, + &Config{ProjectID: "auto-init-project-id"}, + }, + { + "", + `{"projectId": "auto-init-project-id"}`, nil, - &Config{ProjectID: "hipster-chat-mock"}, - }, { - "Environment variable set to file which is ignored as some explicit options are passed", + &Config{ProjectID: "auto-init-project-id"}, + }, + { + "", "testdata/firebase_config_partial.json", &Config{StorageBucket: "sb1-mock"}, &Config{ @@ -397,36 +403,45 @@ func TestAutoInit(t *testing.T) { StorageBucket: "sb1-mock", }, }, { - "Environment variable set to string which is ignored as some explicit options are passed", - `{"projectId": "hipster-chat-mock"}`, + "", + `{"projectId": "auto-init-project-id"}`, &Config{StorageBucket: "sb1-mock"}, &Config{ - ProjectID: "mock-project-id", + ProjectID: "mock-project-id", // from default creds StorageBucket: "sb1-mock", }, - }, { - "Environment variable set to file which is ignored as options are explicitly empty", + }, + { + "", "testdata/firebase_config_partial.json", &Config{}, &Config{ProjectID: "mock-project-id"}, - }, { - "Environment variable set to file with an unknown key which is ignored, no explicit options", + }, + { + "", + `{"projectId": "auto-init-project-id"}`, + &Config{}, + &Config{ProjectID: "mock-project-id"}, + }, + { + "", "testdata/firebase_config_invalid_key.json", nil, &Config{ ProjectID: "mock-project-id", // from default creds - StorageBucket: "hipster-chat.appspot.mock", + StorageBucket: "auto-init.storage.bucket", }, - }, { - "Environment variable set to string with an unknown key which is ignored, no explicit options", + }, + { + "", `{ - "obviously_bad_key": "hipster-chat-mock", - "storageBucket": "hipster-chat.appspot.mock" + "obviously_bad_key": "mock-project-id", + "storageBucket": "auto-init.storage.bucket" }`, nil, &Config{ ProjectID: "mock-project-id", - StorageBucket: "hipster-chat.appspot.mock", + StorageBucket: "auto-init.storage.bucket", }, }, } @@ -435,7 +450,7 @@ func TestAutoInit(t *testing.T) { defer reinstateEnv(credEnvVar, credOld) for _, test := range tests { - t.Run(test.name, func(t *testing.T) { + t.Run(fmt.Sprintf("NewApp(%s)", test.name), func(t *testing.T) { overwriteEnv(firebaseEnvName, test.optionsConfig) app, err := NewApp(context.Background(), test.initOptions) if err != nil { @@ -454,15 +469,17 @@ func TestAutoInitInvalidFiles(t *testing.T) { wantError string }{ { - "nonexistant file", + "NonexistingFile", "testdata/no_such_file.json", "open testdata/no_such_file.json: no such file or directory", - }, { - "invalid JSON", + }, + { + "InvalidJSON", "testdata/firebase_config_invalid.json", "invalid character 'b' looking for beginning of value", - }, { - "empty file", + }, + { + "EmptyFile", "testdata/firebase_config_empty.json", "unexpected end of JSON input", }, diff --git a/testdata/firebase_config.json b/testdata/firebase_config.json index d249fe76..e9a3b5bc 100644 --- a/testdata/firebase_config.json +++ b/testdata/firebase_config.json @@ -1,4 +1,4 @@ { - "projectId": "hipster-chat-mock", - "storageBucket": "hipster-chat.appspot.mock" + "projectId": "auto-init-project-id", + "storageBucket": "auto-init.storage.bucket" } diff --git a/testdata/firebase_config_invalid_key.json b/testdata/firebase_config_invalid_key.json index 8fad82c8..6cbc52f4 100644 --- a/testdata/firebase_config_invalid_key.json +++ b/testdata/firebase_config_invalid_key.json @@ -1,4 +1,4 @@ { - "project1d_bad_key": "hipster-chat-mock", - "storageBucket": "hipster-chat.appspot.mock" + "project1d_bad_key": "auto-init-project-id", + "storageBucket": "auto-init.storage.bucket" } diff --git a/testdata/firebase_config_partial.json b/testdata/firebase_config_partial.json index 1775043e..8515413f 100644 --- a/testdata/firebase_config_partial.json +++ b/testdata/firebase_config_partial.json @@ -1,3 +1,3 @@ { - "projectId": "hipster-chat-mock" + "projectId": "auto-init-project-id" } From c24cb17191866a3a11a365b9b36ad9af3adf70ba Mon Sep 17 00:00:00 2001 From: avishalom Date: Fri, 19 Jan 2018 15:29:18 -0500 Subject: [PATCH 02/27] clean unused types (#76) --- auth/user_mgt.go | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/auth/user_mgt.go b/auth/user_mgt.go index 9d0098dc..eb92d9a5 100644 --- a/auth/user_mgt.go +++ b/auth/user_mgt.go @@ -498,37 +498,6 @@ func (u *UserToUpdate) preparePayload(user *identitytoolkit.IdentitytoolkitRelyi // End of validators -// Response Types ------------------------------- - -type getUserResponse struct { - RequestType string - Users []responseUserRecord -} - -type responseUserRecord struct { - UID string - DisplayName string - Email string - PhoneNumber string - PhotoURL string - CreationTimestamp int64 - LastLogInTimestamp int64 - ProviderID string - CustomClaims string - Disabled bool - EmailVerified bool - ProviderUserInfo []*UserInfo - PasswordHash string - PasswordSalt string - ValidSince int64 -} - -type listUsersResponse struct { - RequestType string - Users []responseUserRecord - NextPage string -} - // Helper functions for retrieval and HTTP calls. func (c *Client) createUser(ctx context.Context, user *UserToCreate) (string, error) { From fb6fa29bdc92e08deeee70b1f2728177984d0aa1 Mon Sep 17 00:00:00 2001 From: avishalom Date: Mon, 29 Jan 2018 14:52:09 -0500 Subject: [PATCH 03/27] Create CHANGELOG.md (#75) (#79) * Create CHANGELOG.md Initial changelog based on https://firebase.google.com/support/release-notes/admin/go --- CHANGELOG.md | 101 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..3ee0af8b --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,101 @@ +# Unreleased +- + +# v2.4.0 + +### Initialization + +- [added] The [`firebase.NewApp()`](https://godoc.org/firebase.google.com/go#NewApp) + method can now be invoked without any arguments. This initializes an app + using Google Application Default Credentials, and + [`firebase.Config`](https://godoc.org/firebase.google.com/go#Config) loaded + from the `FIREBASE_CONFIG` environment variable. + +### Authentication + +- [changed] The user management operations in the `auth` package now uses the + [`identitytoolkit/v3`](https://google.golang.org/api/identitytoolkit/v3) library. +- [changed] The `ProviderID` field on the + [`auth.UserRecord`](https://godoc.org/firebase.google.com/go/auth#UserRecord) + type is now set to the constant value `firebase`. + +# v2.3.0 + +- [added] A new [`InstanceID`](https://godoc.org/firebase.google.com/go#App.InstanceID) + API that facilitates deleting instance IDs and associated user data from + Firebase projects. + +# v2.2.1 + +### Authentication + +- [changed] Adding the `X-Client-Version` to the headers in the API calls for + tracking API usage. + +# v2.2.0 + +### Authentication + +- [added] A new user management API that supports querying and updating + user accounts associated with a Firebase project. This adds `GetUser()`, + `GetUserByEmail()`, `GetUserByPhoneNumber()`, `CreateUser()`, `UpdateUser()`, + `DeleteUser()`, `Users()` and `SetCustomUserClaims()` functions to the + [`auth.Client`](https://godoc.org/firebase.google.com/go/auth#Client) API. + +# v2.1.0 + +- [added] A new [`Firestore` API](https://godoc.org/firebase.google.com/go#App.Firestore) + that enables access to [Cloud Firestore](/docs/firestore) databases. + +# v2.0.0 + +- [added] A new [Cloud Storage API](https://godoc.org/firebase.google.com/go/storage) + that facilitates accessing Google Cloud Storage buckets using the + [`cloud.google.com/go/storage`](https://cloud.google.com/go/storage) + package. + +### Authentication + +- [changed] The [`Auth()`](https://godoc.org/firebase.google.com/go#App.Auth) + API now accepts a `Context` argument. This breaking + change enables passing different contexts to different services, instead + of using a single context per [`App`](https://godoc.org/firebase.google.com/go#App). + +# v1.0.2 + +### Authentication + +- [changed] When deployed in the Google App Engine environment, the SDK can + now leverage the utilities provided by the + [App Engine SDK](https://cloud.google.com/appengine/docs/standard/go/reference) + to sign JWT tokens. As a result, it is now possible to initialize the Admin + SDK in App Engine without a service account JSON file, and still be able to + call [`CustomToken()`](https://godoc.org/firebase.google.com/go/auth#Client.CustomToken) + and [`CustomTokenWithClaims()`](https://godoc.org/firebase.google.com/go/auth#Client.CustomTokenWithClaims). + +# v1.0.1 + +### Authentication + +- [changed] Now uses the client options provided during + [SDK initialization](https://godoc.org/firebase.google.com/go#NewApp) to + create the [`http.Client`](https://godoc.org/net/http#Client) that is used + to fetch public key certificates. This enables developers to use the ID token + verification feature in environments like Google App Engine by providing a + platform-specific `http.Client` using + [`option.WithHTTPClient()`](https://godoc.org/google.golang.org/api/option#WithHTTPClient). + +# v1.0.0 + +- [added] Initial release of the Admin Go SDK. See + [Add the Firebase Admin SDK to your Server](/docs/admin/setup/) to get + started. +- [added] You can configure the SDK to use service account credentials, user + credentials (refresh tokens), or Google Cloud application default credentials + to access your Firebase project. + +### Authentication + +- [added] The initial release includes the `CustomToken()`, + `CustomTokenWithClaims()`, and `VerifyIDToken()` functions for minting custom + authentication tokens and verifying Firebase ID tokens. From aae4f9321d4f1e6ce2864d1fefc96e2160afbac0 Mon Sep 17 00:00:00 2001 From: avishalom Date: Thu, 1 Feb 2018 14:16:03 -0500 Subject: [PATCH 04/27] change instance ID format (#82) Changing the format of the "non-existing" instance ID in the integration tests to comply with the expected iid format. --- integration/iid/iid_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/integration/iid/iid_test.go b/integration/iid/iid_test.go index 8b7b316c..d3303dc4 100644 --- a/integration/iid/iid_test.go +++ b/integration/iid/iid_test.go @@ -50,12 +50,12 @@ func TestMain(m *testing.M) { } func TestNonExisting(t *testing.T) { - err := client.DeleteInstanceID(context.Background(), "non-existing") + err := client.DeleteInstanceID(context.Background(), "dnon-existY") if err == nil { - t.Errorf("DeleteInstanceID(non-existing) = nil; want error") + t.Errorf("DeleteInstanceID(\"dnon-existY\") = nil; want error") } - want := `instance id "non-existing": failed to find the instance id` + want := `instance id "dnon-existY": failed to find the instance id` if err.Error() != want { - t.Errorf("DeleteInstanceID(non-existing) = %v; want = %v", err, want) + t.Errorf("DeleteInstanceID(\"dnon-existY\") = %v; want = %v", err, want) } } From 56a731253688f584f110829f705d9cd91b9ec7e1 Mon Sep 17 00:00:00 2001 From: avishalom Date: Thu, 8 Feb 2018 12:39:52 -0500 Subject: [PATCH 05/27] Import context from golang.org/x/net/ for 1.6 compatibility (#87) * import golang.org/x/net/context instead of context for 1.6 compatibility --- CHANGELOG.md | 2 +- auth/auth_std.go | 2 +- integration/firestore/firestore_test.go | 3 ++- integration/iid/iid_test.go | 3 ++- storage/storage.go | 3 ++- 5 files changed, 8 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ee0af8b..754e2f73 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,5 @@ # Unreleased -- +- Import context from golang.org/x/net/ for 1.6 compatibility # v2.4.0 diff --git a/auth/auth_std.go b/auth/auth_std.go index f593a7cc..2055af38 100644 --- a/auth/auth_std.go +++ b/auth/auth_std.go @@ -16,7 +16,7 @@ package auth -import "context" +import "golang.org/x/net/context" func newSigner(ctx context.Context) (signer, error) { return serviceAcctSigner{}, nil diff --git a/integration/firestore/firestore_test.go b/integration/firestore/firestore_test.go index 6c367205..6e7b4e28 100644 --- a/integration/firestore/firestore_test.go +++ b/integration/firestore/firestore_test.go @@ -15,12 +15,13 @@ package firestore import ( - "context" "log" "reflect" "testing" "firebase.google.com/go/integration/internal" + + "golang.org/x/net/context" ) func TestFirestore(t *testing.T) { diff --git a/integration/iid/iid_test.go b/integration/iid/iid_test.go index d3303dc4..02132925 100644 --- a/integration/iid/iid_test.go +++ b/integration/iid/iid_test.go @@ -16,7 +16,6 @@ package iid import ( - "context" "flag" "log" "os" @@ -24,6 +23,8 @@ import ( "firebase.google.com/go/iid" "firebase.google.com/go/integration/internal" + + "golang.org/x/net/context" ) var client *iid.Client diff --git a/storage/storage.go b/storage/storage.go index 878e2175..985b6eb7 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -16,11 +16,12 @@ package storage import ( - "context" "errors" "cloud.google.com/go/storage" "firebase.google.com/go/internal" + + "golang.org/x/net/context" ) // Client is the interface for the Firebase Storage service. From c764f496cf24721a775b727893a662821f59dc0b Mon Sep 17 00:00:00 2001 From: avishalom Date: Thu, 8 Feb 2018 12:44:37 -0500 Subject: [PATCH 06/27] Document non existing name in integration tests for iid (#85) --- integration/iid/iid_test.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/integration/iid/iid_test.go b/integration/iid/iid_test.go index 02132925..9be5dce0 100644 --- a/integration/iid/iid_test.go +++ b/integration/iid/iid_test.go @@ -51,12 +51,13 @@ func TestMain(m *testing.M) { } func TestNonExisting(t *testing.T) { - err := client.DeleteInstanceID(context.Background(), "dnon-existY") + // legal instance IDs are /[cdef][A-Za-z0-9_-]{9}[AEIMQUYcgkosw048]/ + err := client.DeleteInstanceID(context.Background(), "fictive-ID0") if err == nil { - t.Errorf("DeleteInstanceID(\"dnon-existY\") = nil; want error") + t.Errorf("DeleteInstanceID(non-existing) = nil; want error") } - want := `instance id "dnon-existY": failed to find the instance id` + want := `instance id "fictive-ID0": failed to find the instance id` if err.Error() != want { - t.Errorf("DeleteInstanceID(\"dnon-existY\") = %v; want = %v", err, want) + t.Errorf("DeleteInstanceID(non-existing) = %v; want = %v", err, want) } } From 06eb0e061a430e2889763709ed4fa698308b98a9 Mon Sep 17 00:00:00 2001 From: avishalom Date: Mon, 12 Feb 2018 20:33:52 -0500 Subject: [PATCH 07/27] Revoke Tokens (#77) Adding TokensValidAfterMillis property, RevokeRefreshTokens(), and VerifyIDTokenAndCheckRevoked(). --- CHANGELOG.md | 10 +++++ auth/auth.go | 34 +++++++++++++++ auth/auth_test.go | 49 +++++++++++++++++++++- auth/user_mgt.go | 38 ++++++++++++----- auth/user_mgt_test.go | 63 ++++++++++++++++++++++------ integration/auth/auth_test.go | 70 +++++++++++++++++++++++++++++-- integration/auth/user_mgt_test.go | 60 ++++++++++++++++++-------- testdata/get_user.json | 4 +- testdata/list_users.json | 12 +++--- 9 files changed, 285 insertions(+), 55 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 754e2f73..51a8c4af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,14 @@ # Unreleased + +### Token revocation +- [added] A New ['VerifyIDTokenAndCheckRevoked(ctx, token)'](https://godoc.org/firebase.google.com/go/auth#Client.VerifyIDToken) + method has been added to check for revoked ID tokens. +- [added] A new method ['RevokeRefreshTokens(uid)'](https://godoc.org/firebase.google.com/go/auth#Client.RevokeRefreshTokens) + has been added to invalidate all refresh tokens issued to a user. +- [added] A new property + `TokensValidAfterMillis` has been added to the ['UserRecord'](https://godoc.org/firebase.google.com/go/auth#UserRecord). + This property stores the time of the revocation truncated to 1 second accuracy. + - Import context from golang.org/x/net/ for 1.6 compatibility # v2.4.0 diff --git a/auth/auth.go b/auth/auth.go index cc798182..f6605c7b 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -177,6 +177,18 @@ func (c *Client) CustomTokenWithClaims(uid string, devClaims map[string]interfac return encodeToken(c.snr, defaultHeader(), payload) } +// RevokeRefreshTokens revokes all refresh tokens issued to a user. +// +// RevokeRefreshTokens updates the user's TokensValidAfterMillis to the current UTC second. +// It is important that the server on which this is called has its clock set correctly and synchronized. +// +// While this revokes all sessions for a specified user and disables any new ID tokens for existing sessions +// from getting minted, existing ID tokens may remain active until their natural expiration (one hour). +// To verify that ID tokens are revoked, use `verifyIdTokenAndCheckRevoked(ctx, idToken)`. +func (c *Client) RevokeRefreshTokens(ctx context.Context, uid string) error { + return c.updateUser(ctx, uid, (&UserToUpdate{}).revokeRefreshTokens()) +} + // VerifyIDToken verifies the signature and payload of the provided ID token. // // VerifyIDToken accepts a signed JWT token string, and verifies that it is current, issued for the @@ -184,6 +196,7 @@ func (c *Client) CustomTokenWithClaims(uid string, devClaims map[string]interfac // a Token containing the decoded claims in the input JWT. See // https://firebase.google.com/docs/auth/admin/verify-id-tokens#retrieve_id_tokens_on_clients for // more details on how to obtain an ID token in a client app. +// This does not check whether or not the token has been revoked. See `VerifyIDTokenAndCheckRevoked` below. func (c *Client) VerifyIDToken(idToken string) (*Token, error) { if c.projectID == "" { return nil, errors.New("project id not available") @@ -237,6 +250,27 @@ func (c *Client) VerifyIDToken(idToken string) (*Token, error) { return p, nil } +// VerifyIDTokenAndCheckRevoked verifies the provided ID token and checks it has not been revoked. +// +// VerifyIDTokenAndCheckRevoked verifies the signature and payload of the provided ID token and +// checks that it wasn't revoked. Uses VerifyIDToken() internally to verify the ID token JWT. +func (c *Client) VerifyIDTokenAndCheckRevoked(ctx context.Context, idToken string) (*Token, error) { + p, err := c.VerifyIDToken(idToken) + if err != nil { + return nil, err + } + + user, err := c.GetUser(ctx, p.UID) + if err != nil { + return nil, err + } + + if p.IssuedAt*1000 < user.TokensValidAfterMillis { + return nil, fmt.Errorf("ID token has been revoked") + } + return p, nil +} + func parseKey(key string) (*rsa.PrivateKey, error) { block, _ := pem.Decode([]byte(key)) if block == nil { diff --git a/auth/auth_test.go b/auth/auth_test.go index 79e957d8..690b5d6f 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -37,10 +37,10 @@ import ( ) var client *Client +var ctx context.Context var testIDToken string var testGetUserResponse []byte var testListUsersResponse []byte - var defaultTestOpts = []option.ClientOption{ option.WithCredentialsFile("../testdata/service_account.json"), } @@ -49,7 +49,6 @@ func TestMain(m *testing.M) { var ( err error ks keySource - ctx context.Context creds *google.DefaultCredentials opts []option.ClientOption ) @@ -193,6 +192,52 @@ func TestCustomTokenInvalidCredential(t *testing.T) { } } +func TestVerifyIDTokenAndCheckRevokedValid(t *testing.T) { + s := echoServer(testGetUserResponse, t) + defer s.Close() + + ft, err := s.Client.VerifyIDTokenAndCheckRevoked(ctx, testIDToken) + if err != nil { + t.Error(err) + } + if ft.Claims["admin"] != true { + t.Errorf("Claims['admin'] = %v; want = true", ft.Claims["admin"]) + } + if ft.UID != ft.Subject { + t.Errorf("UID = %q; Sub = %q; want UID = Sub", ft.UID, ft.Subject) + } +} + +func TestVerifyIDTokenAndCheckRevokedDoNotCheck(t *testing.T) { + s := echoServer(testGetUserResponse, t) + defer s.Close() + tok := getIDToken(mockIDTokenPayload{"uid": "uid", "iat": 1970}) // old token + + ft, err := s.Client.VerifyIDToken(tok) + if err != nil { + t.Fatal(err) + } + if ft.Claims["admin"] != true { + t.Errorf("Claims['admin'] = %v; want = true", ft.Claims["admin"]) + } + if ft.UID != ft.Subject { + t.Errorf("UID = %q; Sub = %q; want UID = Sub", ft.UID, ft.Subject) + } +} + +func TestVerifyIDTokenAndCheckRevokedInvalidated(t *testing.T) { + s := echoServer(testGetUserResponse, t) + defer s.Close() + tok := getIDToken(mockIDTokenPayload{"uid": "uid", "iat": 1970}) // old token + + p, err := s.Client.VerifyIDTokenAndCheckRevoked(ctx, tok) + we := "ID token has been revoked" + if p != nil || err == nil || err.Error() != we { + t.Errorf("VerifyIDTokenAndCheckRevoked(ctx, token) =(%v, %v); want = (%v, %v)", + p, err, nil, we) + } +} + func TestVerifyIDToken(t *testing.T) { ft, err := client.VerifyIDToken(testIDToken) if err != nil { diff --git a/auth/user_mgt.go b/auth/user_mgt.go index eb92d9a5..423ceece 100644 --- a/auth/user_mgt.go +++ b/auth/user_mgt.go @@ -21,6 +21,7 @@ import ( "reflect" "regexp" "strings" + "time" "golang.org/x/net/context" "google.golang.org/api/identitytoolkit/v3" @@ -39,6 +40,7 @@ var commonValidators = map[string]func(interface{}) error{ "password": validatePassword, "photoUrl": validatePhotoURL, "localId": validateUID, + "validSince": func(interface{}) error { return nil }, // Needed for preparePayload. } // Create a new interface @@ -65,6 +67,7 @@ type UserInfo struct { } // UserMetadata contains additional metadata associated with a user account. +// Timestamps are in milliseconds since epoch. type UserMetadata struct { CreationTimestamp int64 LastLogInTimestamp int64 @@ -73,11 +76,12 @@ type UserMetadata struct { // UserRecord contains metadata associated with a Firebase user account. type UserRecord struct { *UserInfo - CustomClaims map[string]interface{} - Disabled bool - EmailVerified bool - ProviderUserInfo []*UserInfo - UserMetadata *UserMetadata + CustomClaims map[string]interface{} + Disabled bool + EmailVerified bool + ProviderUserInfo []*UserInfo + TokensValidAfterMillis int64 // milliseconds since epoch. + UserMetadata *UserMetadata } // ExportedUserRecord is the returned user value used when listing all the users. @@ -173,6 +177,13 @@ func (u *UserToUpdate) PhoneNumber(phone string) *UserToUpdate { u.set("phoneNum // PhotoURL setter. func (u *UserToUpdate) PhotoURL(url string) *UserToUpdate { u.set("photoUrl", url); return u } +// revokeRefreshTokens revokes all refresh tokens for a user by setting the validSince property +// to the present in epoch seconds. +func (u *UserToUpdate) revokeRefreshTokens() *UserToUpdate { + u.set("validSince", time.Now().Unix()) + return u +} + // CreateUser creates a new user with the specified properties. func (c *Client) CreateUser(ctx context.Context, user *UserToCreate) (*UserRecord, error) { uid, err := c.createUser(ctx, user) @@ -471,7 +482,12 @@ func (u *UserToUpdate) preparePayload(user *identitytoolkit.IdentitytoolkitRelyi if err := validate(v); err != nil { return err } - reflect.ValueOf(user).Elem().FieldByName(strings.Title(key)).SetString(params[key].(string)) + f := reflect.ValueOf(user).Elem().FieldByName(strings.Title(key)) + if f.Kind() == reflect.String { + f.SetString(params[key].(string)) + } else if f.Kind() == reflect.Int64 { + f.SetInt(params[key].(int64)) + } } } if params["disableUser"] != nil { @@ -528,7 +544,6 @@ func (c *Client) updateUser(ctx context.Context, uid string, user *UserToUpdate) if user == nil || user.params == nil { return fmt.Errorf("update parameters must not be nil or empty") } - request := &identitytoolkit.IdentitytoolkitRelyingpartySetAccountInfoRequest{ LocalId: uid, } @@ -597,10 +612,11 @@ func makeExportedUser(r *identitytoolkit.UserInfo) (*ExportedUserRecord, error) ProviderID: defaultProviderID, UID: r.LocalId, }, - CustomClaims: cc, - Disabled: r.Disabled, - EmailVerified: r.EmailVerified, - ProviderUserInfo: providerUserInfo, + CustomClaims: cc, + Disabled: r.Disabled, + EmailVerified: r.EmailVerified, + ProviderUserInfo: providerUserInfo, + TokensValidAfterMillis: r.ValidSince * 1000, UserMetadata: &UserMetadata{ LastLogInTimestamp: r.LastLoginAt, CreationTimestamp: r.CreatedAt, diff --git a/auth/user_mgt_test.go b/auth/user_mgt_test.go index 5c75fdf9..db7b7300 100644 --- a/auth/user_mgt_test.go +++ b/auth/user_mgt_test.go @@ -24,6 +24,7 @@ import ( "reflect" "strings" "testing" + "time" "firebase.google.com/go/internal" @@ -59,9 +60,10 @@ var testUser = &UserRecord{ UID: "testuid", }, }, + TokensValidAfterMillis: 1494364393000, UserMetadata: &UserMetadata{ - CreationTimestamp: 1234567890, - LastLogInTimestamp: 1233211232, + CreationTimestamp: 1234567890000, + LastLogInTimestamp: 1233211232000, }, CustomClaims: map[string]interface{}{"admin": true, "package": "gold"}, } @@ -496,6 +498,41 @@ func TestUpdateUser(t *testing.T) { } } } +func TestRevokeRefreshTokens(t *testing.T) { + resp := `{ + "kind": "identitytoolkit#SetAccountInfoResponse", + "localId": "expectedUserID" + }` + s := echoServer([]byte(resp), t) + defer s.Close() + before := time.Now().Unix() + if err := s.Client.RevokeRefreshTokens(context.Background(), "some_uid"); err != nil { + t.Error(err) + } + after := time.Now().Unix() + + req := &identitytoolkit.IdentitytoolkitRelyingpartySetAccountInfoRequest{} + if err := json.Unmarshal(s.Rbody, &req); err != nil { + t.Error(err) + } + if req.ValidSince > after || req.ValidSince < before { + t.Errorf("validSince = %d, expecting time between %d and %d", req.ValidSince, before, after) + } +} + +func TestRevokeRefreshTokensInvalidUID(t *testing.T) { + resp := `{ + "kind": "identitytoolkit#SetAccountInfoResponse", + "localId": "expectedUserID" + }` + s := echoServer([]byte(resp), t) + defer s.Close() + + we := "uid must not be empty" + if err := s.Client.RevokeRefreshTokens(context.Background(), ""); err == nil || err.Error() != we { + t.Errorf("RevokeRefreshTokens(); err = %s; want err = %s", err.Error(), we) + } +} func TestInvalidSetCustomClaims(t *testing.T) { cases := []struct { @@ -609,8 +646,8 @@ func TestMakeExportedUser(t *testing.T) { PasswordHash: "passwordhash", ValidSince: 1494364393, Disabled: false, - CreatedAt: 1234567890, - LastLoginAt: 1233211232, + CreatedAt: 1234567890000, + LastLoginAt: 1233211232000, CustomAttributes: `{"admin": true, "package": "gold"}`, ProviderUserInfo: []*identitytoolkit.UserInfoProviderUserInfo{ { @@ -637,7 +674,8 @@ func TestMakeExportedUser(t *testing.T) { } if !reflect.DeepEqual(exported.UserRecord, want.UserRecord) { // zero in - t.Errorf("makeExportedUser() = %#v; want: %#v", exported.UserRecord, want.UserRecord) + t.Errorf("makeExportedUser() = %#v; want: %#v \n(%#v)\n(%#v)", exported.UserRecord, want.UserRecord, + exported.UserMetadata, want.UserMetadata) } if exported.PasswordHash != want.PasswordHash { t.Errorf("PasswordHash = %q; want = %q", exported.PasswordHash, want.PasswordHash) @@ -690,8 +728,7 @@ func echoServer(resp interface{}, t *testing.T) *mockAuthServer { case []byte: b = v default: - b, err = json.Marshal(resp) - if err != nil { + if b, err = json.Marshal(resp); err != nil { t.Fatal("marshaling error") } } @@ -729,17 +766,17 @@ func echoServer(resp interface{}, t *testing.T) *mockAuthServer { } w.Header().Set("Content-Type", "application/json") w.Write(s.Resp) - }) s.Srv = httptest.NewServer(handler) - conf := &internal.AuthConfig{ Opts: []option.ClientOption{ - option.WithTokenSource(&mockTokenSource{testToken}), - }, - Version: testVersion, + option.WithTokenSource(&mockTokenSource{testToken})}, + ProjectID: "mock-project-id", + Version: testVersion, } - authClient, err := NewClient(context.Background(), conf) + + authClient, err := NewClient(ctx, conf) + authClient.ks = &fileKeySource{FilePath: "../testdata/public_certs.json"} if err != nil { t.Fatal(err) } diff --git a/integration/auth/auth_test.go b/integration/auth/auth_test.go index ee0dc0b3..5e027d44 100644 --- a/integration/auth/auth_test.go +++ b/integration/auth/auth_test.go @@ -25,6 +25,7 @@ import ( "net/http" "os" "testing" + "time" "firebase.google.com/go/auth" "firebase.google.com/go/integration/internal" @@ -62,7 +63,6 @@ func TestCustomToken(t *testing.T) { if err != nil { t.Fatal(err) } - idt, err := signInWithCustomToken(ct) if err != nil { t.Fatal(err) @@ -75,10 +75,69 @@ func TestCustomToken(t *testing.T) { if vt.UID != "user1" { t.Errorf("UID = %q; want UID = %q", vt.UID, "user1") } + if err = client.DeleteUser(context.Background(), "user1"); err != nil { + t.Error(err) + } +} + +func TestVerifyIDTokenAndCheckRevoked(t *testing.T) { + uid := "user_revoked" + ct, err := client.CustomToken(uid) + + if err != nil { + t.Fatal(err) + } + idt, err := signInWithCustomToken(ct) + if err != nil { + t.Fatal(err) + } + ctx := context.Background() + vt, err := client.VerifyIDTokenAndCheckRevoked(ctx, idt) + if err != nil { + t.Fatal(err) + } + if vt.UID != uid { + t.Errorf("UID = %q; want UID = %q", vt.UID, uid) + } + // The backend stores the validSince property in seconds since the epoch. + // The issuedAt property of the token is also in seconds. If a token was + // issued, and then in the same second tokens were revoked, the token will + // have the same timestamp as the tokensValidAfterMillis, and will therefore + // not be considered revoked. Hence we wait one second before revoking. + time.Sleep(time.Second) + if err = client.RevokeRefreshTokens(ctx, uid); err != nil { + t.Fatal(err) + } + + vt, err = client.VerifyIDTokenAndCheckRevoked(ctx, idt) + we := "ID token has been revoked" + if vt != nil || err == nil || err.Error() != we { + t.Errorf("tok, err := VerifyIDTokenAndCheckRevoked(); got (%v, %s) ; want (%v, %v)", + vt, err, nil, we) + } + + // Does not return error for revoked token. + if _, err = client.VerifyIDToken(idt); err != nil { + t.Errorf("VerifyIDToken(); err = %s; want err = ", err) + } + + // Sign in after revocation. + if idt, err = signInWithCustomToken(ct); err != nil { + t.Fatal(err) + } + + if _, err = client.VerifyIDTokenAndCheckRevoked(ctx, idt); err != nil { + t.Errorf("VerifyIDTokenAndCheckRevoked(); err = %s; want err = ", err) + } + + err = client.DeleteUser(ctx, uid) + if err != nil { + t.Error(err) + } } func TestCustomTokenWithClaims(t *testing.T) { - ct, err := client.CustomTokenWithClaims("user1", map[string]interface{}{ + ct, err := client.CustomTokenWithClaims("user2", map[string]interface{}{ "premium": true, "package": "gold", }) @@ -95,8 +154,8 @@ func TestCustomTokenWithClaims(t *testing.T) { if err != nil { t.Fatal(err) } - if vt.UID != "user1" { - t.Errorf("UID = %q; want UID = %q", vt.UID, "user1") + if vt.UID != "user2" { + t.Errorf("UID = %q; want UID = %q", vt.UID, "user2") } if premium, ok := vt.Claims["premium"].(bool); !ok || !premium { t.Errorf("Claims['premium'] = %v; want Claims['premium'] = true", vt.Claims["premium"]) @@ -104,6 +163,9 @@ func TestCustomTokenWithClaims(t *testing.T) { if pkg, ok := vt.Claims["package"].(string); !ok || pkg != "gold" { t.Errorf("Claims['package'] = %v; want Claims['package'] = \"gold\"", vt.Claims["package"]) } + if err = client.DeleteUser(context.Background(), "user2"); err != nil { + t.Error(err) + } } func signInWithCustomToken(token string) (string, error) { diff --git a/integration/auth/user_mgt_test.go b/integration/auth/user_mgt_test.go index 8227599f..3013fbe0 100644 --- a/integration/auth/user_mgt_test.go +++ b/integration/auth/user_mgt_test.go @@ -19,6 +19,7 @@ import ( "fmt" "reflect" "testing" + "time" "google.golang.org/api/iterator" @@ -34,16 +35,30 @@ var testFixtures = struct { }{} func TestUserManagement(t *testing.T) { - t.Run("Create test users", testCreateUsers) - t.Run("Get user", testGetUser) - t.Run("Iterate users", testUserIterator) - t.Run("Paged iteration", testPager) - t.Run("Disable user account", testDisableUser) - t.Run("Update user", testUpdateUser) - t.Run("Remove user attributes", testRemovePhonePhotoName) - t.Run("Remove custom claims", testRemoveCustomClaims) - t.Run("Add custom claims", testAddCustomClaims) - t.Run("Delete test users", testDeleteUsers) + orderedRuns := []struct { + name string + testFunc func(*testing.T) + }{ + {"Create test users", testCreateUsers}, + {"Get user", testGetUser}, + {"Iterate users", testUserIterator}, + {"Paged iteration", testPager}, + {"Disable user account", testDisableUser}, + {"Update user", testUpdateUser}, + {"Remove user attributes", testRemovePhonePhotoName}, + {"Remove custom claims", testRemoveCustomClaims}, + {"Add custom claims", testAddCustomClaims}, + {"Delete test users", testDeleteUsers}, + } + // The tests are meant to be run in sequence. A failure in creating the users + // should be fatal so non of the other tests run. However calling Fatal from a + // subtest does not prevent the other subtests from running, hence we check the + // success of each subtest before proceeding. + for _, run := range orderedRuns { + if ok := t.Run(run.name, run.testFunc); !ok { + t.Fatalf("Failed run %v", run.name) + } + } } // N.B if the tests are failing due to inability to create existing users, manual @@ -52,14 +67,22 @@ func TestUserManagement(t *testing.T) { func testCreateUsers(t *testing.T) { // Create users with uid for i := 0; i < 3; i++ { - params := (&auth.UserToCreate{}).UID(fmt.Sprintf("tempTestUserID-%d", i)) + uid := fmt.Sprintf("tempTestUserID-%d", i) + params := (&auth.UserToCreate{}).UID(uid) u, err := client.CreateUser(context.Background(), params) if err != nil { - t.Fatal("failed to create user", i, err) + t.Fatal(err) } testFixtures.uidList = append(testFixtures.uidList, u.UID) - } + // make sure that the user.TokensValidAfterMillis is not in the future or stale. + if u.TokensValidAfterMillis > time.Now().Unix()*1000 { + t.Errorf("timestamp cannot be in the future") + } + if time.Now().Sub(time.Unix(u.TokensValidAfterMillis, 0)) > time.Hour { + t.Errorf("timestamp should be recent") + } + } // Create user with no parameters (zero-value) u, err := client.CreateUser(context.Background(), (&auth.UserToCreate{})) if err != nil { @@ -75,8 +98,8 @@ func testCreateUsers(t *testing.T) { Email(uid + "email@test.com"). DisplayName("display_name"). Password("password") - u, err = client.CreateUser(context.Background(), params) - if err != nil { + + if u, err = client.CreateUser(context.Background(), params); err != nil { t.Fatal(err) } testFixtures.sampleUserWithData = u @@ -85,6 +108,7 @@ func testCreateUsers(t *testing.T) { func testGetUser(t *testing.T) { want := testFixtures.sampleUserWithData + u, err := client.GetUser(context.Background(), want.UID) if err != nil { t.Fatalf("error getting user %s", err) @@ -216,12 +240,13 @@ func testUpdateUser(t *testing.T) { UID: testFixtures.sampleUserBlank.UID, ProviderID: "firebase", }, + TokensValidAfterMillis: u.TokensValidAfterMillis, UserMetadata: &auth.UserMetadata{ CreationTimestamp: testFixtures.sampleUserBlank.UserMetadata.CreationTimestamp, }, } if !reflect.DeepEqual(u, want) { - t.Errorf("GetUser() = %v; want = %v", u, want) + t.Errorf("GetUser() = %#v; want = %#v", u, want) } params := (&auth.UserToUpdate{}). @@ -247,6 +272,7 @@ func testUpdateUser(t *testing.T) { ProviderID: "firebase", Email: "abc@ab.ab", }, + TokensValidAfterMillis: u.TokensValidAfterMillis, UserMetadata: &auth.UserMetadata{ CreationTimestamp: testFixtures.sampleUserBlank.UserMetadata.CreationTimestamp, }, @@ -289,7 +315,7 @@ func testUpdateUser(t *testing.T) { // now compare the rest of the record, without the ProviderInfo u.ProviderUserInfo = nil if !reflect.DeepEqual(u, want) { - t.Errorf("UpdateUser() = %v; want = %v", u, want) + t.Errorf("UpdateUser() = %#v; want = %#v", u, want) } } diff --git a/testdata/get_user.json b/testdata/get_user.json index a56ef9f3..a62102e0 100644 --- a/testdata/get_user.json +++ b/testdata/get_user.json @@ -28,8 +28,8 @@ "passwordUpdatedAt": 1.494364393E+12, "validSince": "1494364393", "disabled": false, - "createdAt": "1234567890", - "lastLoginAt": "1233211232", + "createdAt": "1234567890000", + "lastLoginAt": "1233211232000", "customAttributes": "{\"admin\": true, \"package\": \"gold\"}" } ] diff --git a/testdata/list_users.json b/testdata/list_users.json index 21d152fc..a0c625ef 100644 --- a/testdata/list_users.json +++ b/testdata/list_users.json @@ -28,8 +28,8 @@ "passwordUpdatedAt": 1.494364393E+12, "validSince": "1494364393", "disabled": false, - "createdAt": "1234567890", - "lastLoginAt": "1233211232", + "createdAt": "1234567890000", + "lastLoginAt": "1233211232000", "customAttributes": "{\"admin\": true, \"package\": \"gold\"}" }, { @@ -59,8 +59,8 @@ "passwordUpdatedAt": 1.494364393E+12, "validSince": "1494364393", "disabled": false, - "createdAt": "1234567890", - "lastLoginAt": "1233211232", + "createdAt": "1234567890000", + "lastLoginAt": "1233211232000", "customAttributes": "{\"admin\": true, \"package\": \"gold\"}" }, { @@ -90,8 +90,8 @@ "passwordUpdatedAt": 1.494364393E+12, "validSince": "1494364393", "disabled": false, - "createdAt": "1234567890", - "lastLoginAt": "1233211232", + "createdAt": "1234567890000", + "lastLoginAt": "1233211232000", "customAttributes": "{\"admin\": true, \"package\": \"gold\"}" } ], From 3f7b4ba80820fff3d03d8a4418d2beb03fff9720 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 12 Feb 2018 18:17:54 -0800 Subject: [PATCH 08/27] Firebase Cloud Messaging API (#81) * Adding Firebase Cloud Messaging (#62) * initial commit for adding Firebase Cloud Messaging * add validator * use http const in messaging test * add client version header for stats * init integration test * add integration test (validated on IOS today) * add comment with URL to enable Firebase Cloud Messaging API * fix broken test * add integration tests * accept a Message instead of RequestMessage + and rename method + send / sendDryRun * update fcm url * rollback url endpoint * fix http constants, change responseMessage visibility, change map[string]interface{} as map[string]string * fix http constants * fix integration tests * fix APNS naming * add validators * Added APNS types; Updated tests * Added more tests; Fixed APNS serialization * Updated documentation * Improved error handling inFCM * Added utils file * Updated integration tests * Implemented topic management operations * Added integration tests * Updated CHANGELOG * Addressing code review comments * Supporting 0 valued Aps.Badge * Addressing some review comments * Removed some unused vars * Accepting prefixed topic names (#84) * Accepting prefixed topic named * Added a comment * Using new FCM error codes (#89) --- CHANGELOG.md | 24 +- firebase.go | 12 + firebase_test.go | 12 + integration/messaging/messaging_test.go | 125 ++++ internal/internal.go | 7 + messaging/messaging.go | 475 +++++++++++++ messaging/messaging_test.go | 870 ++++++++++++++++++++++++ messaging/messaging_utils.go | 131 ++++ 8 files changed, 1647 insertions(+), 9 deletions(-) create mode 100644 integration/messaging/messaging_test.go create mode 100644 messaging/messaging.go create mode 100644 messaging/messaging_test.go create mode 100644 messaging/messaging_utils.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 51a8c4af..ea15a5f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,16 +1,22 @@ # Unreleased -### Token revocation -- [added] A New ['VerifyIDTokenAndCheckRevoked(ctx, token)'](https://godoc.org/firebase.google.com/go/auth#Client.VerifyIDToken) - method has been added to check for revoked ID tokens. -- [added] A new method ['RevokeRefreshTokens(uid)'](https://godoc.org/firebase.google.com/go/auth#Client.RevokeRefreshTokens) - has been added to invalidate all refresh tokens issued to a user. -- [added] A new property - `TokensValidAfterMillis` has been added to the ['UserRecord'](https://godoc.org/firebase.google.com/go/auth#UserRecord). - This property stores the time of the revocation truncated to 1 second accuracy. - - Import context from golang.org/x/net/ for 1.6 compatibility +### Cloud Messaging + +- [feature] Added the `messaging` package for sending Firebase notifications + and managing topic subscriptions. + +### Authentication + +- [added] A new [`VerifyIDTokenAndCheckRevoked()`](https://godoc.org/firebase.google.com/go/auth#Client.VerifyIDToken) + function has been added to check for revoked ID tokens. +- [added] A new [`RevokeRefreshTokens()`](https://godoc.org/firebase.google.com/go/auth#Client.RevokeRefreshTokens) + function has been added to invalidate all refresh tokens issued to a user. +- [added] A new property `TokensValidAfterMillis` has been added to the + ['UserRecord'](https://godoc.org/firebase.google.com/go/auth#UserRecord) + type, which stores the time of the revocation truncated to 1 second accuracy. + # v2.4.0 ### Initialization diff --git a/firebase.go b/firebase.go index a0e5085c..9ed07125 100644 --- a/firebase.go +++ b/firebase.go @@ -28,6 +28,7 @@ import ( "firebase.google.com/go/auth" "firebase.google.com/go/iid" "firebase.google.com/go/internal" + "firebase.google.com/go/messaging" "firebase.google.com/go/storage" "golang.org/x/net/context" @@ -43,6 +44,7 @@ var firebaseScopes = []string{ "https://www.googleapis.com/auth/firebase", "https://www.googleapis.com/auth/identitytoolkit", "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/firebase.messaging", } // Version of the Firebase Go Admin SDK. @@ -103,6 +105,16 @@ func (a *App) InstanceID(ctx context.Context) (*iid.Client, error) { return iid.NewClient(ctx, conf) } +// Messaging returns an instance of messaging.Client. +func (a *App) Messaging(ctx context.Context) (*messaging.Client, error) { + conf := &internal.MessagingConfig{ + ProjectID: a.projectID, + Opts: a.opts, + Version: Version, + } + return messaging.NewClient(ctx, conf) +} + // NewApp creates a new App from the provided config and client options. // // If the client options contain a valid credential (a service account file, a refresh token diff --git a/firebase_test.go b/firebase_test.go index df41c56d..fc33ba20 100644 --- a/firebase_test.go +++ b/firebase_test.go @@ -304,6 +304,18 @@ func TestInstanceID(t *testing.T) { } } +func TestMessaging(t *testing.T) { + ctx := context.Background() + app, err := NewApp(ctx, nil, option.WithCredentialsFile("testdata/service_account.json")) + if err != nil { + t.Fatal(err) + } + + if c, err := app.Messaging(ctx); c == nil || err != nil { + t.Errorf("Messaging() = (%v, %v); want (iid, nil)", c, err) + } +} + func TestCustomTokenSource(t *testing.T) { ctx := context.Background() ts := &testTokenSource{AccessToken: "mock-token-from-custom"} diff --git a/integration/messaging/messaging_test.go b/integration/messaging/messaging_test.go new file mode 100644 index 00000000..d7bb0693 --- /dev/null +++ b/integration/messaging/messaging_test.go @@ -0,0 +1,125 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messaging + +import ( + "flag" + "log" + "os" + "regexp" + "testing" + + "golang.org/x/net/context" + + "firebase.google.com/go/integration/internal" + "firebase.google.com/go/messaging" +) + +// The registration token has the proper format, but is not valid (i.e. expired). The intention of +// these integration tests is to verify that the endpoints return the proper payload, but it is +// hard to ensure this token remains valid. The tests below should still pass regardless. +const testRegistrationToken = "fGw0qy4TGgk:APA91bGtWGjuhp4WRhHXgbabIYp1jxEKI08ofj_v1bKhWAGJQ4e3a" + + "rRCWzeTfHaLz83mBnDh0aPWB1AykXAVUUGl2h1wT4XI6XazWpvY7RBUSYfoxtqSWGIm2nvWh2BOP1YG501SsRoE" + +var client *messaging.Client + +// Enable API before testing +// https://console.developers.google.com/apis/library/fcm.googleapis.com +func TestMain(m *testing.M) { + flag.Parse() + if testing.Short() { + log.Println("Skipping messaging integration tests in short mode.") + return + } + + ctx := context.Background() + app, err := internal.NewTestApp(ctx) + if err != nil { + log.Fatalln(err) + } + + client, err = app.Messaging(ctx) + if err != nil { + log.Fatalln(err) + } + os.Exit(m.Run()) +} + +func TestSend(t *testing.T) { + msg := &messaging.Message{ + Topic: "foo-bar", + Notification: &messaging.Notification{ + Title: "Title", + Body: "Body", + }, + Android: &messaging.AndroidConfig{ + Notification: &messaging.AndroidNotification{ + Title: "Android Title", + Body: "Android Body", + }, + }, + APNS: &messaging.APNSConfig{ + Payload: &messaging.APNSPayload{ + Aps: &messaging.Aps{ + Alert: &messaging.ApsAlert{ + Title: "APNS Title", + Body: "APNS Body", + }, + }, + }, + }, + Webpush: &messaging.WebpushConfig{ + Notification: &messaging.WebpushNotification{ + Title: "Webpush Title", + Body: "Webpush Body", + }, + }, + } + name, err := client.SendDryRun(context.Background(), msg) + if err != nil { + log.Fatalln(err) + } + const pattern = "^projects/.*/messages/.*$" + if !regexp.MustCompile(pattern).MatchString(name) { + t.Errorf("Send() = %q; want = %q", name, pattern) + } +} + +func TestSendInvalidToken(t *testing.T) { + msg := &messaging.Message{Token: "INVALID_TOKEN"} + if _, err := client.Send(context.Background(), msg); err == nil { + t.Errorf("Send() = nil; want error") + } +} + +func TestSubscribe(t *testing.T) { + tmr, err := client.SubscribeToTopic(context.Background(), []string{testRegistrationToken}, "mock-topic") + if err != nil { + t.Fatal(err) + } + if tmr.SuccessCount+tmr.FailureCount != 1 { + t.Errorf("SubscribeToTopic() = %v; want total 1", tmr) + } +} + +func TestUnsubscribe(t *testing.T) { + tmr, err := client.UnsubscribeFromTopic(context.Background(), []string{testRegistrationToken}, "mock-topic") + if err != nil { + t.Fatal(err) + } + if tmr.SuccessCount+tmr.FailureCount != 1 { + t.Errorf("UnsubscribeFromTopic() = %v; want total 1", tmr) + } +} diff --git a/internal/internal.go b/internal/internal.go index 34c4f32d..225edc9e 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -46,6 +46,13 @@ type MockTokenSource struct { AccessToken string } +// MessagingConfig represents the configuration of Firebase Cloud Messaging service. +type MessagingConfig struct { + Opts []option.ClientOption + ProjectID string + Version string +} + // Token returns the test token associated with the TokenSource. func (ts *MockTokenSource) Token() (*oauth2.Token, error) { return &oauth2.Token{AccessToken: ts.AccessToken}, nil diff --git a/messaging/messaging.go b/messaging/messaging.go new file mode 100644 index 00000000..97b77d64 --- /dev/null +++ b/messaging/messaging.go @@ -0,0 +1,475 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package messaging contains functions for sending messages and managing +// device subscriptions with Firebase Cloud Messaging (FCM). +package messaging + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "regexp" + "strings" + "time" + + "golang.org/x/net/context" + + "firebase.google.com/go/internal" + "google.golang.org/api/transport" +) + +const ( + messagingEndpoint = "https://fcm.googleapis.com/v1" + iidEndpoint = "https://iid.googleapis.com" + iidSubscribe = "iid/v1:batchAdd" + iidUnsubscribe = "iid/v1:batchRemove" +) + +var ( + topicNamePattern = regexp.MustCompile("^(/topics/)?(private/)?[a-zA-Z0-9-_.~%]+$") + + fcmErrorCodes = map[string]string{ + "INVALID_ARGUMENT": "request contains an invalid argument; code: invalid-argument", + "UNREGISTERED": "app instance has been unregistered; code: registration-token-not-registered", + "SENDER_ID_MISMATCH": "sender id does not match regisration token; code: authentication-error", + "QUOTA_EXCEEDED": "messaging service quota exceeded; code: message-rate-exceeded", + "APNS_AUTH_ERROR": "apns certificate or auth key was invalid; code: authentication-error", + "UNAVAILABLE": "backend servers are temporarily unavailable; code: server-unavailable", + "INTERNAL": "back servers encountered an unknown internl error; code: internal-error", + } + + iidErrorCodes = map[string]string{ + "INVALID_ARGUMENT": "request contains an invalid argument; code: invalid-argument", + "NOT_FOUND": "request contains an invalid argument; code: registration-token-not-registered", + "INTERNAL": "server encountered an internal error; code: internal-error", + "TOO_MANY_TOPICS": "client exceeded the number of allowed topics; code: too-many-topics", + } +) + +// Client is the interface for the Firebase Cloud Messaging (FCM) service. +type Client struct { + fcmEndpoint string // to enable testing against arbitrary endpoints + iidEndpoint string // to enable testing against arbitrary endpoints + client *internal.HTTPClient + project string + version string +} + +// Message to be sent via Firebase Cloud Messaging. +// +// Message contains payload data, recipient information and platform-specific configuration +// options. A Message must specify exactly one of Token, Topic or Condition fields. Apart from +// that a Message may specify any combination of Data, Notification, Android, Webpush and APNS +// fields. See https://firebase.google.com/docs/reference/fcm/rest/v1/projects.messages for more +// details on how the backend FCM servers handle different message parameters. +type Message struct { + Data map[string]string `json:"data,omitempty"` + Notification *Notification `json:"notification,omitempty"` + Android *AndroidConfig `json:"android,omitempty"` + Webpush *WebpushConfig `json:"webpush,omitempty"` + APNS *APNSConfig `json:"apns,omitempty"` + Token string `json:"token,omitempty"` + Topic string `json:"-"` + Condition string `json:"condition,omitempty"` +} + +// MarshalJSON marshals a Message into JSON (for internal use only). +func (m *Message) MarshalJSON() ([]byte, error) { + // Create a new type to prevent infinite recursion. + type messageInternal Message + s := &struct { + BareTopic string `json:"topic,omitempty"` + *messageInternal + }{ + BareTopic: strings.TrimPrefix(m.Topic, "/topics/"), + messageInternal: (*messageInternal)(m), + } + return json.Marshal(s) +} + +// Notification is the basic notification template to use across all platforms. +type Notification struct { + Title string `json:"title,omitempty"` + Body string `json:"body,omitempty"` +} + +// AndroidConfig contains messaging options specific to the Android platform. +type AndroidConfig struct { + CollapseKey string `json:"collapse_key,omitempty"` + Priority string `json:"priority,omitempty"` // one of "normal" or "high" + TTL *time.Duration `json:"-"` + RestrictedPackageName string `json:"restricted_package_name,omitempty"` + Data map[string]string `json:"data,omitempty"` // if specified, overrides the Data field on Message type + Notification *AndroidNotification `json:"notification,omitempty"` +} + +// MarshalJSON marshals an AndroidConfig into JSON (for internal use only). +func (a *AndroidConfig) MarshalJSON() ([]byte, error) { + var ttl string + if a.TTL != nil { + seconds := int64(*a.TTL / time.Second) + nanos := int64((*a.TTL - time.Duration(seconds)*time.Second) / time.Nanosecond) + if nanos > 0 { + ttl = fmt.Sprintf("%d.%09ds", seconds, nanos) + } else { + ttl = fmt.Sprintf("%ds", seconds) + } + } + + type androidInternal AndroidConfig + s := &struct { + TTL string `json:"ttl,omitempty"` + *androidInternal + }{ + TTL: ttl, + androidInternal: (*androidInternal)(a), + } + return json.Marshal(s) +} + +// AndroidNotification is a notification to send to Android devices. +type AndroidNotification struct { + Title string `json:"title,omitempty"` // if specified, overrides the Title field of the Notification type + Body string `json:"body,omitempty"` // if specified, overrides the Body field of the Notification type + Icon string `json:"icon,omitempty"` + Color string `json:"color,omitempty"` // notification color in #RRGGBB format + Sound string `json:"sound,omitempty"` + Tag string `json:"tag,omitempty"` + ClickAction string `json:"click_action,omitempty"` + BodyLocKey string `json:"body_loc_key,omitempty"` + BodyLocArgs []string `json:"body_loc_args,omitempty"` + TitleLocKey string `json:"title_loc_key,omitempty"` + TitleLocArgs []string `json:"title_loc_args,omitempty"` +} + +// WebpushConfig contains messaging options specific to the WebPush protocol. +// +// See https://tools.ietf.org/html/rfc8030#section-5 for additional details, and supported +// headers. +type WebpushConfig struct { + Headers map[string]string `json:"headers,omitempty"` + Data map[string]string `json:"data,omitempty"` + Notification *WebpushNotification `json:"notification,omitempty"` +} + +// WebpushNotification is a notification to send via WebPush protocol. +type WebpushNotification struct { + Title string `json:"title,omitempty"` // if specified, overrides the Title field of the Notification type + Body string `json:"body,omitempty"` // if specified, overrides the Body field of the Notification type + Icon string `json:"icon,omitempty"` +} + +// APNSConfig contains messaging options specific to the Apple Push Notification Service (APNS). +// +// See https://developer.apple.com/library/content/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/CommunicatingwithAPNs.html +// for more details on supported headers and payload keys. +type APNSConfig struct { + Headers map[string]string `json:"headers,omitempty"` + Payload *APNSPayload `json:"payload,omitempty"` +} + +// APNSPayload is the payload that can be included in an APNS message. +// +// The payload mainly consists of the aps dictionary. Additionally it may contain arbitrary +// key-values pairs as custom data fields. +// +// See https://developer.apple.com/library/content/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/PayloadKeyReference.html +// for a full list of supported payload fields. +type APNSPayload struct { + Aps *Aps + CustomData map[string]interface{} +} + +// MarshalJSON marshals an APNSPayload into JSON (for internal use only). +func (p *APNSPayload) MarshalJSON() ([]byte, error) { + m := map[string]interface{}{"aps": p.Aps} + for k, v := range p.CustomData { + m[k] = v + } + return json.Marshal(m) +} + +// Aps represents the aps dictionary that may be included in an APNSPayload. +// +// Alert may be specified as a string (via the AlertString field), or as a struct (via the Alert +// field). +type Aps struct { + AlertString string `json:"-"` + Alert *ApsAlert `json:"-"` + Badge *int `json:"badge,omitempty"` + Sound string `json:"sound,omitempty"` + ContentAvailable bool `json:"-"` + Category string `json:"category,omitempty"` + ThreadID string `json:"thread-id,omitempty"` +} + +// MarshalJSON marshals an Aps into JSON (for internal use only). +func (a *Aps) MarshalJSON() ([]byte, error) { + type apsAlias Aps + s := &struct { + Alert interface{} `json:"alert,omitempty"` + ContentAvailable *int `json:"content-available,omitempty"` + *apsAlias + }{ + apsAlias: (*apsAlias)(a), + } + + if a.Alert != nil { + s.Alert = a.Alert + } else if a.AlertString != "" { + s.Alert = a.AlertString + } + if a.ContentAvailable { + one := 1 + s.ContentAvailable = &one + } + return json.Marshal(s) +} + +// ApsAlert is the alert payload that can be included in an Aps. +// +// See https://developer.apple.com/library/content/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/PayloadKeyReference.html +// for supported fields. +type ApsAlert struct { + Title string `json:"title,omitempty"` // if specified, overrides the Title field of the Notification type + Body string `json:"body,omitempty"` // if specified, overrides the Body field of the Notification type + LocKey string `json:"loc-key,omitempty"` + LocArgs []string `json:"loc-args,omitempty"` + TitleLocKey string `json:"title-loc-key,omitempty"` + TitleLocArgs []string `json:"title-loc-args,omitempty"` + ActionLocKey string `json:"action-loc-key,omitempty"` + LaunchImage string `json:"launch-image,omitempty"` +} + +// ErrorInfo is a topic management error. +type ErrorInfo struct { + Index int + Reason string +} + +// TopicManagementResponse is the result produced by topic management operations. +// +// TopicManagementResponse provides an overview of how many input tokens were successfully handled, +// and how many failed. In case of failures, the Errors list provides specific details concerning +// each error. +type TopicManagementResponse struct { + SuccessCount int + FailureCount int + Errors []*ErrorInfo +} + +func newTopicManagementResponse(resp *iidResponse) *TopicManagementResponse { + tmr := &TopicManagementResponse{} + for idx, res := range resp.Results { + if len(res) == 0 { + tmr.SuccessCount++ + } else { + tmr.FailureCount++ + code := res["error"].(string) + reason := iidErrorCodes[code] + if reason == "" { + reason = "unknown-error" + } + tmr.Errors = append(tmr.Errors, &ErrorInfo{ + Index: idx, + Reason: reason, + }) + } + } + return tmr +} + +// NewClient creates a new instance of the Firebase Cloud Messaging Client. +// +// This function can only be invoked from within the SDK. Client applications should access the +// the messaging service through firebase.App. +func NewClient(ctx context.Context, c *internal.MessagingConfig) (*Client, error) { + if c.ProjectID == "" { + return nil, errors.New("project ID is required to access Firebase Cloud Messaging client") + } + + hc, _, err := transport.NewHTTPClient(ctx, c.Opts...) + if err != nil { + return nil, err + } + + return &Client{ + fcmEndpoint: messagingEndpoint, + iidEndpoint: iidEndpoint, + client: &internal.HTTPClient{Client: hc}, + project: c.ProjectID, + version: "Go/Admin/" + c.Version, + }, nil +} + +// Send sends a Message to Firebase Cloud Messaging. +// +// The Message must specify exactly one of Token, Topic and Condition fields. FCM will +// customize the message for each target platform based on the arguments specified in the +// Message. +func (c *Client) Send(ctx context.Context, message *Message) (string, error) { + payload := &fcmRequest{ + Message: message, + } + return c.makeSendRequest(ctx, payload) +} + +// SendDryRun sends a Message to Firebase Cloud Messaging in the dry run (validation only) mode. +// +// This function does not actually deliver the message to target devices. Instead, it performs all +// the SDK-level and backend validations on the message, and emulates the send operation. +func (c *Client) SendDryRun(ctx context.Context, message *Message) (string, error) { + payload := &fcmRequest{ + ValidateOnly: true, + Message: message, + } + return c.makeSendRequest(ctx, payload) +} + +// SubscribeToTopic subscribes a list of registration tokens to a topic. +// +// The tokens list must not be empty, and have at most 1000 tokens. +func (c *Client) SubscribeToTopic(ctx context.Context, tokens []string, topic string) (*TopicManagementResponse, error) { + req := &iidRequest{ + Topic: topic, + Tokens: tokens, + op: iidSubscribe, + } + return c.makeTopicManagementRequest(ctx, req) +} + +// UnsubscribeFromTopic unsubscribes a list of registration tokens from a topic. +// +// The tokens list must not be empty, and have at most 1000 tokens. +func (c *Client) UnsubscribeFromTopic(ctx context.Context, tokens []string, topic string) (*TopicManagementResponse, error) { + req := &iidRequest{ + Topic: topic, + Tokens: tokens, + op: iidSubscribe, + } + return c.makeTopicManagementRequest(ctx, req) +} + +type fcmRequest struct { + ValidateOnly bool `json:"validate_only,omitempty"` + Message *Message `json:"message,omitempty"` +} + +type fcmResponse struct { + Name string `json:"name"` +} + +type fcmError struct { + Error struct { + Status string `json:"status"` + } `json:"error"` +} + +type iidRequest struct { + Topic string `json:"to"` + Tokens []string `json:"registration_tokens"` + op string +} + +type iidResponse struct { + Results []map[string]interface{} `json:"results"` +} + +type iidError struct { + Error string `json:"error"` +} + +func (c *Client) makeSendRequest(ctx context.Context, req *fcmRequest) (string, error) { + if err := validateMessage(req.Message); err != nil { + return "", err + } + + request := &internal.Request{ + Method: http.MethodPost, + URL: fmt.Sprintf("%s/projects/%s/messages:send", c.fcmEndpoint, c.project), + Body: internal.NewJSONEntity(req), + } + resp, err := c.client.Do(ctx, request) + if err != nil { + return "", err + } + + if resp.Status == http.StatusOK { + var result fcmResponse + err := json.Unmarshal(resp.Body, &result) + return result.Name, err + } + + var fe fcmError + json.Unmarshal(resp.Body, &fe) // ignore any json parse errors at this level + msg := fcmErrorCodes[fe.Error.Status] + if msg == "" { + msg = fmt.Sprintf("server responded with an unknown error; response: %s", string(resp.Body)) + } + return "", fmt.Errorf("http error status: %d; reason: %s", resp.Status, msg) +} + +func (c *Client) makeTopicManagementRequest(ctx context.Context, req *iidRequest) (*TopicManagementResponse, error) { + if len(req.Tokens) == 0 { + return nil, fmt.Errorf("no tokens specified") + } + if len(req.Tokens) > 1000 { + return nil, fmt.Errorf("tokens list must not contain more than 1000 items") + } + for _, token := range req.Tokens { + if token == "" { + return nil, fmt.Errorf("tokens list must not contain empty strings") + } + } + + if req.Topic == "" { + return nil, fmt.Errorf("topic name not specified") + } + if !topicNamePattern.MatchString(req.Topic) { + return nil, fmt.Errorf("invalid topic name: %q", req.Topic) + } + + if !strings.HasPrefix(req.Topic, "/topics/") { + req.Topic = "/topics/" + req.Topic + } + + request := &internal.Request{ + Method: http.MethodPost, + URL: fmt.Sprintf("%s/%s", c.iidEndpoint, req.op), + Body: internal.NewJSONEntity(req), + Opts: []internal.HTTPOption{internal.WithHeader("access_token_auth", "true")}, + } + resp, err := c.client.Do(ctx, request) + if err != nil { + return nil, err + } + + if resp.Status == http.StatusOK { + var result iidResponse + if err := json.Unmarshal(resp.Body, &result); err != nil { + return nil, err + } + return newTopicManagementResponse(&result), nil + } + + var ie iidError + json.Unmarshal(resp.Body, &ie) // ignore any json parse errors at this level + msg := iidErrorCodes[ie.Error] + if msg == "" { + msg = fmt.Sprintf("client encountered an unknown error; response: %s", string(resp.Body)) + } + return nil, fmt.Errorf("http error status: %d; reason: %s", resp.Status, msg) +} diff --git a/messaging/messaging_test.go b/messaging/messaging_test.go new file mode 100644 index 00000000..27808e84 --- /dev/null +++ b/messaging/messaging_test.go @@ -0,0 +1,870 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messaging + +import ( + "context" + "encoding/json" + "io/ioutil" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + "time" + + "google.golang.org/api/option" + + "firebase.google.com/go/internal" +) + +const testMessageID = "projects/test-project/messages/msg_id" + +var ( + testMessagingConfig = &internal.MessagingConfig{ + ProjectID: "test-project", + Opts: []option.ClientOption{ + option.WithTokenSource(&internal.MockTokenSource{AccessToken: "test-token"}), + }, + } + + ttlWithNanos = time.Duration(1500) * time.Millisecond + ttl = time.Duration(10) * time.Second + invalidTTL = time.Duration(-10) * time.Second + + badge = 42 + badgeZero = 0 +) + +var validMessages = []struct { + name string + req *Message + want map[string]interface{} +}{ + { + name: "TokenOnly", + req: &Message{Token: "test-token"}, + want: map[string]interface{}{"token": "test-token"}, + }, + { + name: "TopicOnly", + req: &Message{Topic: "test-topic"}, + want: map[string]interface{}{"topic": "test-topic"}, + }, + { + name: "PrefixedTopicOnly", + req: &Message{Topic: "/topics/test-topic"}, + want: map[string]interface{}{"topic": "test-topic"}, + }, + { + name: "ConditionOnly", + req: &Message{Condition: "test-condition"}, + want: map[string]interface{}{"condition": "test-condition"}, + }, + { + name: "DataMessage", + req: &Message{ + Data: map[string]string{ + "k1": "v1", + "k2": "v2", + }, + Topic: "test-topic", + }, + want: map[string]interface{}{ + "data": map[string]interface{}{ + "k1": "v1", + "k2": "v2", + }, + "topic": "test-topic", + }, + }, + { + name: "NotificationMessage", + req: &Message{ + Notification: &Notification{ + Title: "t", + Body: "b", + }, + Topic: "test-topic", + }, + want: map[string]interface{}{ + "notification": map[string]interface{}{ + "title": "t", + "body": "b", + }, + "topic": "test-topic", + }, + }, + { + name: "AndroidDataMessage", + req: &Message{ + Android: &AndroidConfig{ + CollapseKey: "ck", + Data: map[string]string{ + "k1": "v1", + "k2": "v2", + }, + Priority: "normal", + TTL: &ttl, + }, + Topic: "test-topic", + }, + want: map[string]interface{}{ + "android": map[string]interface{}{ + "collapse_key": "ck", + "data": map[string]interface{}{ + "k1": "v1", + "k2": "v2", + }, + "priority": "normal", + "ttl": "10s", + }, + "topic": "test-topic", + }, + }, + { + name: "AndroidNotificationMessage", + req: &Message{ + Android: &AndroidConfig{ + RestrictedPackageName: "rpn", + Notification: &AndroidNotification{ + Title: "t", + Body: "b", + Color: "#112233", + Sound: "s", + TitleLocKey: "tlk", + TitleLocArgs: []string{"t1", "t2"}, + BodyLocKey: "blk", + BodyLocArgs: []string{"b1", "b2"}, + }, + TTL: &ttlWithNanos, + }, + Topic: "test-topic", + }, + want: map[string]interface{}{ + "android": map[string]interface{}{ + "restricted_package_name": "rpn", + "notification": map[string]interface{}{ + "title": "t", + "body": "b", + "color": "#112233", + "sound": "s", + "title_loc_key": "tlk", + "title_loc_args": []interface{}{"t1", "t2"}, + "body_loc_key": "blk", + "body_loc_args": []interface{}{"b1", "b2"}, + }, + "ttl": "1.500000000s", + }, + "topic": "test-topic", + }, + }, + { + name: "AndroidNoTTL", + req: &Message{ + Android: &AndroidConfig{ + Priority: "high", + }, + Topic: "test-topic", + }, + want: map[string]interface{}{ + "android": map[string]interface{}{ + "priority": "high", + }, + "topic": "test-topic", + }, + }, + { + name: "WebpushMessage", + req: &Message{ + Webpush: &WebpushConfig{ + Headers: map[string]string{ + "h1": "v1", + "h2": "v2", + }, + Data: map[string]string{ + "k1": "v1", + "k2": "v2", + }, + Notification: &WebpushNotification{ + Title: "t", + Body: "b", + Icon: "i", + }, + }, + Topic: "test-topic", + }, + want: map[string]interface{}{ + "webpush": map[string]interface{}{ + "headers": map[string]interface{}{"h1": "v1", "h2": "v2"}, + "data": map[string]interface{}{"k1": "v1", "k2": "v2"}, + "notification": map[string]interface{}{"title": "t", "body": "b", "icon": "i"}, + }, + "topic": "test-topic", + }, + }, + { + name: "APNSHeadersOnly", + req: &Message{ + APNS: &APNSConfig{ + Headers: map[string]string{ + "h1": "v1", + "h2": "v2", + }, + }, + Topic: "test-topic", + }, + want: map[string]interface{}{ + "apns": map[string]interface{}{ + "headers": map[string]interface{}{"h1": "v1", "h2": "v2"}, + }, + "topic": "test-topic", + }, + }, + { + name: "APNSAlertString", + req: &Message{ + APNS: &APNSConfig{ + Headers: map[string]string{ + "h1": "v1", + "h2": "v2", + }, + Payload: &APNSPayload{ + Aps: &Aps{ + AlertString: "a", + Badge: &badge, + Category: "c", + Sound: "s", + ThreadID: "t", + ContentAvailable: true, + }, + CustomData: map[string]interface{}{ + "k1": "v1", + "k2": true, + }, + }, + }, + Topic: "test-topic", + }, + want: map[string]interface{}{ + "apns": map[string]interface{}{ + "headers": map[string]interface{}{"h1": "v1", "h2": "v2"}, + "payload": map[string]interface{}{ + "aps": map[string]interface{}{ + "alert": "a", + "badge": float64(badge), + "category": "c", + "sound": "s", + "thread-id": "t", + "content-available": float64(1), + }, + "k1": "v1", + "k2": true, + }, + }, + "topic": "test-topic", + }, + }, + { + name: "APNSBadgeZero", + req: &Message{ + APNS: &APNSConfig{ + Payload: &APNSPayload{ + Aps: &Aps{ + Badge: &badgeZero, + Category: "c", + Sound: "s", + ThreadID: "t", + ContentAvailable: true, + }, + }, + }, + Topic: "test-topic", + }, + want: map[string]interface{}{ + "apns": map[string]interface{}{ + "payload": map[string]interface{}{ + "aps": map[string]interface{}{ + "badge": float64(badgeZero), + "category": "c", + "sound": "s", + "thread-id": "t", + "content-available": float64(1), + }, + }, + }, + "topic": "test-topic", + }, + }, + { + name: "APNSAlertObject", + req: &Message{ + APNS: &APNSConfig{ + Payload: &APNSPayload{ + Aps: &Aps{ + Alert: &ApsAlert{ + Title: "t", + Body: "b", + TitleLocKey: "tlk", + TitleLocArgs: []string{"t1", "t2"}, + LocKey: "blk", + LocArgs: []string{"b1", "b2"}, + ActionLocKey: "alk", + LaunchImage: "li", + }, + }, + }, + }, + Topic: "test-topic", + }, + want: map[string]interface{}{ + "apns": map[string]interface{}{ + "payload": map[string]interface{}{ + "aps": map[string]interface{}{ + "alert": map[string]interface{}{ + "title": "t", + "body": "b", + "title-loc-key": "tlk", + "title-loc-args": []interface{}{"t1", "t2"}, + "loc-key": "blk", + "loc-args": []interface{}{"b1", "b2"}, + "action-loc-key": "alk", + "launch-image": "li", + }, + }, + }, + }, + "topic": "test-topic", + }, + }, +} + +var invalidMessages = []struct { + name string + req *Message + want string +}{ + { + name: "NilMessage", + req: nil, + want: "message must not be nil", + }, + { + name: "NoTargets", + req: &Message{}, + want: "exactly one of token, topic or condition must be specified", + }, + { + name: "MultipleTargets", + req: &Message{ + Token: "token", + Topic: "topic", + }, + want: "exactly one of token, topic or condition must be specified", + }, + { + name: "InvalidPrefixedTopicName", + req: &Message{ + Topic: "/topics/", + }, + want: "malformed topic name", + }, + { + name: "InvalidTopicName", + req: &Message{ + Topic: "foo*bar", + }, + want: "malformed topic name", + }, + { + name: "InvalidAndroidTTL", + req: &Message{ + Android: &AndroidConfig{ + TTL: &invalidTTL, + }, + Topic: "topic", + }, + want: "ttl duration must not be negative", + }, + { + name: "InvalidAndroidPriority", + req: &Message{ + Android: &AndroidConfig{ + Priority: "not normal", + }, + Topic: "topic", + }, + want: "priority must be 'normal' or 'high'", + }, + { + name: "InvalidAndroidColor1", + req: &Message{ + Android: &AndroidConfig{ + Notification: &AndroidNotification{ + Color: "112233", + }, + }, + Topic: "topic", + }, + want: "color must be in the #RRGGBB form", + }, + { + name: "InvalidAndroidColor2", + req: &Message{ + Android: &AndroidConfig{ + Notification: &AndroidNotification{ + Color: "#112233X", + }, + }, + Topic: "topic", + }, + want: "color must be in the #RRGGBB form", + }, + { + name: "InvalidAndroidTitleLocArgs", + req: &Message{ + Android: &AndroidConfig{ + Notification: &AndroidNotification{ + TitleLocArgs: []string{"a1"}, + }, + }, + Topic: "topic", + }, + want: "titleLocKey is required when specifying titleLocArgs", + }, + { + name: "InvalidAndroidBodyLocArgs", + req: &Message{ + Android: &AndroidConfig{ + Notification: &AndroidNotification{ + BodyLocArgs: []string{"a1"}, + }, + }, + Topic: "topic", + }, + want: "bodyLocKey is required when specifying bodyLocArgs", + }, + { + name: "APNSMultipleAlerts", + req: &Message{ + APNS: &APNSConfig{ + Payload: &APNSPayload{ + Aps: &Aps{ + Alert: &ApsAlert{}, + AlertString: "alert", + }, + }, + }, + Topic: "topic", + }, + want: "multiple alert specifications", + }, + { + name: "InvalidAPNSTitleLocArgs", + req: &Message{ + APNS: &APNSConfig{ + Payload: &APNSPayload{ + Aps: &Aps{ + Alert: &ApsAlert{ + TitleLocArgs: []string{"a1"}, + }, + }, + }, + }, + Topic: "topic", + }, + want: "titleLocKey is required when specifying titleLocArgs", + }, + { + name: "InvalidAPNSLocArgs", + req: &Message{ + APNS: &APNSConfig{ + Payload: &APNSPayload{ + Aps: &Aps{ + Alert: &ApsAlert{ + LocArgs: []string{"a1"}, + }, + }, + }, + }, + Topic: "topic", + }, + want: "locKey is required when specifying locArgs", + }, +} + +var invalidTopicMgtArgs = []struct { + name string + tokens []string + topic string + want string +}{ + { + name: "NoTokensAndTopic", + want: "no tokens specified", + }, + { + name: "NoTopic", + tokens: []string{"token1"}, + want: "topic name not specified", + }, + { + name: "InvalidTopicName", + tokens: []string{"token1"}, + topic: "foo*bar", + want: "invalid topic name: \"foo*bar\"", + }, + { + name: "TooManyTokens", + tokens: strings.Split("a"+strings.Repeat(",a", 1000), ","), + topic: "topic", + want: "tokens list must not contain more than 1000 items", + }, + { + name: "EmptyToken", + tokens: []string{"foo", ""}, + topic: "topic", + want: "tokens list must not contain empty strings", + }, +} + +func TestNoProjectID(t *testing.T) { + client, err := NewClient(context.Background(), &internal.MessagingConfig{}) + if client != nil || err == nil { + t.Errorf("NewClient() = (%v, %v); want = (nil, error)", client, err) + } +} + +func TestSend(t *testing.T) { + var tr *http.Request + var b []byte + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tr = r + b, _ = ioutil.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte("{ \"name\":\"" + testMessageID + "\" }")) + })) + defer ts.Close() + + ctx := context.Background() + client, err := NewClient(ctx, testMessagingConfig) + if err != nil { + t.Fatal(err) + } + client.fcmEndpoint = ts.URL + + for _, tc := range validMessages { + t.Run(tc.name, func(t *testing.T) { + name, err := client.Send(ctx, tc.req) + if name != testMessageID || err != nil { + t.Errorf("Send() = (%q, %v); want = (%q, nil)", name, err, testMessageID) + } + checkFCMRequest(t, b, tr, tc.want, false) + }) + } +} + +func TestSendDryRun(t *testing.T) { + var tr *http.Request + var b []byte + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tr = r + b, _ = ioutil.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte("{ \"name\":\"" + testMessageID + "\" }")) + })) + defer ts.Close() + + ctx := context.Background() + client, err := NewClient(ctx, testMessagingConfig) + if err != nil { + t.Fatal(err) + } + client.fcmEndpoint = ts.URL + + for _, tc := range validMessages { + t.Run(tc.name, func(t *testing.T) { + name, err := client.SendDryRun(ctx, tc.req) + if name != testMessageID || err != nil { + t.Errorf("SendDryRun() = (%q, %v); want = (%q, nil)", name, err, testMessageID) + } + checkFCMRequest(t, b, tr, tc.want, true) + }) + } +} + +func TestSendError(t *testing.T) { + var resp string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(resp)) + })) + defer ts.Close() + + ctx := context.Background() + client, err := NewClient(ctx, testMessagingConfig) + if err != nil { + t.Fatal(err) + } + client.fcmEndpoint = ts.URL + + cases := []struct { + resp string + want string + }{ + { + resp: "{}", + want: "http error status: 500; reason: server responded with an unknown error; response: {}", + }, + { + resp: "{\"error\": {\"status\": \"INVALID_ARGUMENT\", \"message\": \"test error\"}}", + want: "http error status: 500; reason: request contains an invalid argument; code: invalid-argument", + }, + { + resp: "not json", + want: "http error status: 500; reason: server responded with an unknown error; response: not json", + }, + } + for _, tc := range cases { + resp = tc.resp + name, err := client.Send(ctx, &Message{Topic: "topic"}) + if err == nil || err.Error() != tc.want { + t.Errorf("Send() = (%q, %v); want = (%q, %q)", name, err, "", tc.want) + } + } +} + +func TestInvalidMessage(t *testing.T) { + ctx := context.Background() + client, err := NewClient(ctx, testMessagingConfig) + if err != nil { + t.Fatal(err) + } + for _, tc := range invalidMessages { + t.Run(tc.name, func(t *testing.T) { + name, err := client.Send(ctx, tc.req) + if err == nil || err.Error() != tc.want { + t.Errorf("Send() = (%q, %v); want = (%q, %q)", name, err, "", tc.want) + } + }) + } +} + +func TestSubscribe(t *testing.T) { + var tr *http.Request + var b []byte + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tr = r + b, _ = ioutil.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte("{\"results\": [{}, {\"error\": \"error_reason\"}]}")) + })) + defer ts.Close() + + ctx := context.Background() + client, err := NewClient(ctx, testMessagingConfig) + if err != nil { + t.Fatal(err) + } + client.iidEndpoint = ts.URL + + resp, err := client.SubscribeToTopic(ctx, []string{"id1", "id2"}, "test-topic") + if err != nil { + t.Fatal(err) + } + checkIIDRequest(t, b, tr, iidSubscribe) + checkTopicMgtResponse(t, resp) +} + +func TestInvalidSubscribe(t *testing.T) { + ctx := context.Background() + client, err := NewClient(ctx, testMessagingConfig) + if err != nil { + t.Fatal(err) + } + for _, tc := range invalidTopicMgtArgs { + t.Run(tc.name, func(t *testing.T) { + name, err := client.SubscribeToTopic(ctx, tc.tokens, tc.topic) + if err == nil || err.Error() != tc.want { + t.Errorf("SubscribeToTopic() = (%q, %v); want = (%q, %q)", name, err, "", tc.want) + } + }) + } +} + +func TestUnsubscribe(t *testing.T) { + var tr *http.Request + var b []byte + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tr = r + b, _ = ioutil.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte("{\"results\": [{}, {\"error\": \"error_reason\"}]}")) + })) + defer ts.Close() + + ctx := context.Background() + client, err := NewClient(ctx, testMessagingConfig) + if err != nil { + t.Fatal(err) + } + client.iidEndpoint = ts.URL + + resp, err := client.UnsubscribeFromTopic(ctx, []string{"id1", "id2"}, "test-topic") + if err != nil { + t.Fatal(err) + } + checkIIDRequest(t, b, tr, iidSubscribe) + checkTopicMgtResponse(t, resp) +} + +func TestInvalidUnsubscribe(t *testing.T) { + ctx := context.Background() + client, err := NewClient(ctx, testMessagingConfig) + if err != nil { + t.Fatal(err) + } + for _, tc := range invalidTopicMgtArgs { + t.Run(tc.name, func(t *testing.T) { + name, err := client.UnsubscribeFromTopic(ctx, tc.tokens, tc.topic) + if err == nil || err.Error() != tc.want { + t.Errorf("UnsubscribeFromTopic() = (%q, %v); want = (%q, %q)", name, err, "", tc.want) + } + }) + } +} + +func TestTopicManagementError(t *testing.T) { + var resp string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(resp)) + })) + defer ts.Close() + + ctx := context.Background() + client, err := NewClient(ctx, testMessagingConfig) + if err != nil { + t.Fatal(err) + } + client.iidEndpoint = ts.URL + + cases := []struct { + resp string + want string + }{ + { + resp: "{}", + want: "http error status: 500; reason: client encountered an unknown error; response: {}", + }, + { + resp: "{\"error\": \"INVALID_ARGUMENT\"}", + want: "http error status: 500; reason: request contains an invalid argument; code: invalid-argument", + }, + { + resp: "not json", + want: "http error status: 500; reason: client encountered an unknown error; response: not json", + }, + } + for _, tc := range cases { + resp = tc.resp + tmr, err := client.SubscribeToTopic(ctx, []string{"id1"}, "topic") + if err == nil || err.Error() != tc.want { + t.Errorf("SubscribeToTopic() = (%q, %v); want = (%q, %q)", tmr, err, "", tc.want) + } + } + for _, tc := range cases { + resp = tc.resp + tmr, err := client.UnsubscribeFromTopic(ctx, []string{"id1"}, "topic") + if err == nil || err.Error() != tc.want { + t.Errorf("UnsubscribeFromTopic() = (%q, %v); want = (%q, %q)", tmr, err, "", tc.want) + } + } +} + +func checkFCMRequest(t *testing.T, b []byte, tr *http.Request, want map[string]interface{}, dryRun bool) { + var parsed map[string]interface{} + if err := json.Unmarshal(b, &parsed); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(parsed["message"], want) { + t.Errorf("Body = %#v; want = %#v", parsed["message"], want) + } + + validate, ok := parsed["validate_only"] + if dryRun { + if !ok || validate != true { + t.Errorf("ValidateOnly = %v; want = true", validate) + } + } else if ok { + t.Errorf("ValidateOnly = %v; want none", validate) + } + + if tr.Method != http.MethodPost { + t.Errorf("Method = %q; want = %q", tr.Method, http.MethodPost) + } + if tr.URL.Path != "/projects/test-project/messages:send" { + t.Errorf("Path = %q; want = %q", tr.URL.Path, "/projects/test-project/messages:send") + } + if h := tr.Header.Get("Authorization"); h != "Bearer test-token" { + t.Errorf("Authorization = %q; want = %q", h, "Bearer test-token") + } +} + +func checkIIDRequest(t *testing.T, b []byte, tr *http.Request, op string) { + var parsed map[string]interface{} + if err := json.Unmarshal(b, &parsed); err != nil { + t.Fatal(err) + } + want := map[string]interface{}{ + "to": "/topics/test-topic", + "registration_tokens": []interface{}{"id1", "id2"}, + } + if !reflect.DeepEqual(parsed, want) { + t.Errorf("Body = %#v; want = %#v", parsed, want) + } + + if tr.Method != http.MethodPost { + t.Errorf("Method = %q; want = %q", tr.Method, http.MethodPost) + } + wantOp := "/" + op + if tr.URL.Path != wantOp { + t.Errorf("Path = %q; want = %q", tr.URL.Path, wantOp) + } + if h := tr.Header.Get("Authorization"); h != "Bearer test-token" { + t.Errorf("Authorization = %q; want = %q", h, "Bearer test-token") + } +} + +func checkTopicMgtResponse(t *testing.T, resp *TopicManagementResponse) { + if resp.SuccessCount != 1 { + t.Errorf("SuccessCount = %d; want = %d", resp.SuccessCount, 1) + } + if resp.FailureCount != 1 { + t.Errorf("FailureCount = %d; want = %d", resp.FailureCount, 1) + } + if len(resp.Errors) != 1 { + t.Fatalf("Errors = %d; want = %d", len(resp.Errors), 1) + } + e := resp.Errors[0] + if e.Index != 1 { + t.Errorf("ErrorInfo.Index = %d; want = %d", e.Index, 1) + } + if e.Reason != "unknown-error" { + t.Errorf("ErrorInfo.Reason = %s; want = %s", e.Reason, "unknown-error") + } +} diff --git a/messaging/messaging_utils.go b/messaging/messaging_utils.go new file mode 100644 index 00000000..ffd4df95 --- /dev/null +++ b/messaging/messaging_utils.go @@ -0,0 +1,131 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messaging + +import ( + "fmt" + "regexp" + "strings" +) + +var ( + bareTopicNamePattern = regexp.MustCompile("^[a-zA-Z0-9-_.~%]+$") + colorPattern = regexp.MustCompile("^#[0-9a-fA-F]{6}$") +) + +func validateMessage(message *Message) error { + if message == nil { + return fmt.Errorf("message must not be nil") + } + + targets := countNonEmpty(message.Token, message.Condition, message.Topic) + if targets != 1 { + return fmt.Errorf("exactly one of token, topic or condition must be specified") + } + + // validate topic + if message.Topic != "" { + bt := strings.TrimPrefix(message.Topic, "/topics/") + if !bareTopicNamePattern.MatchString(bt) { + return fmt.Errorf("malformed topic name") + } + } + + // validate AndroidConfig + if err := validateAndroidConfig(message.Android); err != nil { + return err + } + + // validate APNSConfig + return validateAPNSConfig(message.APNS) +} + +func validateAndroidConfig(config *AndroidConfig) error { + if config == nil { + return nil + } + + if config.TTL != nil && config.TTL.Seconds() < 0 { + return fmt.Errorf("ttl duration must not be negative") + } + if config.Priority != "" && config.Priority != "normal" && config.Priority != "high" { + return fmt.Errorf("priority must be 'normal' or 'high'") + } + // validate AndroidNotification + return validateAndroidNotification(config.Notification) +} + +func validateAndroidNotification(notification *AndroidNotification) error { + if notification == nil { + return nil + } + if notification.Color != "" && !colorPattern.MatchString(notification.Color) { + return fmt.Errorf("color must be in the #RRGGBB form") + } + if len(notification.TitleLocArgs) > 0 && notification.TitleLocKey == "" { + return fmt.Errorf("titleLocKey is required when specifying titleLocArgs") + } + if len(notification.BodyLocArgs) > 0 && notification.BodyLocKey == "" { + return fmt.Errorf("bodyLocKey is required when specifying bodyLocArgs") + } + return nil +} + +func validateAPNSConfig(config *APNSConfig) error { + if config != nil { + return validateAPNSPayload(config.Payload) + } + return nil +} + +func validateAPNSPayload(payload *APNSPayload) error { + if payload != nil { + return validateAps(payload.Aps) + } + return nil +} + +func validateAps(aps *Aps) error { + if aps != nil { + if aps.Alert != nil && aps.AlertString != "" { + return fmt.Errorf("multiple alert specifications") + } + return validateApsAlert(aps.Alert) + } + return nil +} + +func validateApsAlert(alert *ApsAlert) error { + if alert == nil { + return nil + } + if len(alert.TitleLocArgs) > 0 && alert.TitleLocKey == "" { + return fmt.Errorf("titleLocKey is required when specifying titleLocArgs") + } + if len(alert.LocArgs) > 0 && alert.LocKey == "" { + return fmt.Errorf("locKey is required when specifying locArgs") + } + return nil +} + +func countNonEmpty(strings ...string) int { + count := 0 + for _, s := range strings { + if s != "" { + count++ + } + } + return count +} From a053b99dd1dbc4165080ec00211d20aea6161f9e Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 14 Feb 2018 13:24:17 -0800 Subject: [PATCH 09/27] Bumped version to 2.5.0 (#90) --- CHANGELOG.md | 6 +++++- firebase.go | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 48d813f0..3060d302 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,14 @@ # Unreleased +- + +# v2.5.0 + - [changed] Import context from `golang.org/x/net` for 1.6 compatibility ### Cloud Messaging -- [feature] Added the `messaging` package for sending Firebase notifications +- [added] Added the `messaging` package for sending Firebase notifications and managing topic subscriptions. ### Authentication diff --git a/firebase.go b/firebase.go index 9ed07125..ed09ac6d 100644 --- a/firebase.go +++ b/firebase.go @@ -48,7 +48,7 @@ var firebaseScopes = []string{ } // Version of the Firebase Go Admin SDK. -const Version = "2.4.0" +const Version = "2.5.0" // firebaseEnvName is the name of the environment variable with the Config. const firebaseEnvName = "FIREBASE_CONFIG" From fd78f9dd24950ed9333ebf2299128c3faedb2e15 Mon Sep 17 00:00:00 2001 From: Cyrille Hemidy Date: Sun, 18 Feb 2018 03:36:37 +0100 Subject: [PATCH 10/27] Lint (#96) * fix misspelling * add check error * missing copyright --- auth/jwt_test.go | 14 ++++++++++++++ firebase_test.go | 2 +- integration/auth/user_mgt_test.go | 3 +++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/auth/jwt_test.go b/auth/jwt_test.go index 79264ee0..4b0858af 100644 --- a/auth/jwt_test.go +++ b/auth/jwt_test.go @@ -1,3 +1,17 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package auth import ( diff --git a/firebase_test.go b/firebase_test.go index fc33ba20..7b17755b 100644 --- a/firebase_test.go +++ b/firebase_test.go @@ -574,7 +574,7 @@ func overwriteEnv(varName, newVal string) string { return oldVal } -// reinstateEnv restores the enviornment variable, will usually be used deferred with overwriteEnv. +// reinstateEnv restores the environment variable, will usually be used deferred with overwriteEnv. func reinstateEnv(varName, oldVal string) { if len(varName) > 0 { os.Setenv(varName, oldVal) diff --git a/integration/auth/user_mgt_test.go b/integration/auth/user_mgt_test.go index 3013fbe0..4c19f6e5 100644 --- a/integration/auth/user_mgt_test.go +++ b/integration/auth/user_mgt_test.go @@ -371,6 +371,9 @@ func testRemoveCustomClaims(t *testing.T) { t.Fatal(err) } u, err = client.GetUser(context.Background(), testFixtures.sampleUserBlank.UID) + if err != nil { + t.Fatal(err) + } if u.CustomClaims != nil { t.Errorf("CustomClaims() = %#v; want = nil", u.CustomClaims) } From 3a386a48687a5c48dabf698c057ddcd30fcc9853 Mon Sep 17 00:00:00 2001 From: Cyrille Hemidy Date: Sun, 18 Feb 2018 03:38:55 +0100 Subject: [PATCH 11/27] Doc (#97) * update readme with Authentication Guide & Release Notes * fix a misspelling : separately * fix missing newline before package * add Go Report Card + update doc --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 441b5fad..3ea293a2 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ [![Build Status](https://travis-ci.org/firebase/firebase-admin-go.svg?branch=master)](https://travis-ci.org/firebase/firebase-admin-go) [![GoDoc](https://godoc.org/firebase.google.com/go?status.svg)](https://godoc.org/firebase.google.com/go) +[![Go Report Card](https://goreportcard.com/badge/github.com/firebase/firebase-admin-go)](https://goreportcard.com/report/github.com/firebase/firebase-admin-go) # Firebase Admin Go SDK @@ -43,6 +44,9 @@ requests, code review feedback, and also pull requests. * [Setup Guide](https://firebase.google.com/docs/admin/setup/) * [Authentication Guide](https://firebase.google.com/docs/auth/admin/) +* [Cloud Firestore](https://firebase.google.com/docs/firestore/) +* [Cloud Messaging Guide](https://firebase.google.com/docs/cloud-messaging/admin/) +* [Storage Guide](https://firebase.google.com/docs/storage/admin/start) * [API Reference](https://godoc.org/firebase.google.com/go) * [Release Notes](https://firebase.google.com/support/release-notes/admin/go) From 387fa390d63ddfef38821e279cfad4bb7e95762e Mon Sep 17 00:00:00 2001 From: Cyrille Hemidy Date: Mon, 26 Feb 2018 22:36:37 +0100 Subject: [PATCH 12/27] add travis build for go versions 1.7.x -> 1.10.x (#98) * add build for go version 1.6.x -> 1.10.x * fix 1.10 version * fix context to golang.org/x/net/context for go 1.6 compatibility * add race detector + go vet on build + build without failure on go unstable * add go16 et go17 file due to req.withcontext which is only go 1.7 * fix context package * update go16.go to remove WithContext * update bad import * remove unused func * finally use ctxhttp.Do with multiple build version * ignore integration package for install * fix go get command * put go 1.6.X in allow_failures dur to test failure * fix inversion of code * remove go 1.6 support * revert initial version with req.WithContext * fix travis to support go 1.10.x * nits --- .travis.yml | 28 +++++++++++++++++++++++++--- messaging/messaging_test.go | 5 ++--- snippets/auth.go | 3 +-- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/.travis.yml b/.travis.yml index b1d02a82..53f411b4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,30 @@ language: go + +go: + - 1.7.x + - 1.8.x + - 1.9.x + - "1.10.x" + - master + +matrix: + # Build OK if fails on unstable development versions of Go. + allow_failures: + - go: master + # Don't wait for tests to finish on allow_failures. + # Mark the build finished if tests pass on other versions of Go. + fast_finish: true + go_import_path: firebase.google.com/go + before_install: - - go get github.com/golang/lint/golint + - go get github.com/golang/lint/golint # Golint requires Go 1.6 or later. + +install: + # Prior to golang 1.8, this can trigger an error for packages containing only tests. + - go get -t -v $(go list ./... | grep -v integration) + script: - golint -set_exit_status $(go list ./...) - - go test -v -test.short ./... - + - go test -v -race -test.short ./... # Run tests with the race detector. + - go vet -v ./... # Run Go static analyzer. diff --git a/messaging/messaging_test.go b/messaging/messaging_test.go index 27808e84..6706f217 100644 --- a/messaging/messaging_test.go +++ b/messaging/messaging_test.go @@ -15,7 +15,6 @@ package messaging import ( - "context" "encoding/json" "io/ioutil" "net/http" @@ -25,9 +24,9 @@ import ( "testing" "time" - "google.golang.org/api/option" - "firebase.google.com/go/internal" + "golang.org/x/net/context" + "google.golang.org/api/option" ) const testMessageID = "projects/test-project/messages/msg_id" diff --git a/snippets/auth.go b/snippets/auth.go index 31d24837..68847869 100644 --- a/snippets/auth.go +++ b/snippets/auth.go @@ -15,12 +15,11 @@ package snippets import ( - "context" "log" firebase "firebase.google.com/go" "firebase.google.com/go/auth" - + "golang.org/x/net/context" "google.golang.org/api/iterator" ) From 9d2e4a86a9b5e18a4122bc719be84b6b62179a7d Mon Sep 17 00:00:00 2001 From: avishalom Date: Tue, 27 Feb 2018 14:50:16 -0500 Subject: [PATCH 13/27] Import context from standard package (#101) * Import context from standard package. --- auth/auth.go | 2 +- auth/auth_appengine.go | 2 +- auth/auth_std.go | 2 +- auth/auth_test.go | 2 +- auth/user_mgt.go | 2 +- auth/user_mgt_test.go | 2 +- firebase.go | 2 +- firebase_test.go | 2 +- iid/iid.go | 3 +-- iid/iid_test.go | 3 +-- integration/auth/auth_test.go | 3 +-- integration/auth/user_mgt_test.go | 3 +-- integration/firestore/firestore_test.go | 3 +-- integration/iid/iid_test.go | 3 +-- integration/internal/internal.go | 3 +-- integration/messaging/messaging_test.go | 3 +-- integration/storage/storage_test.go | 2 +- internal/http_client.go | 3 +-- internal/http_client_test.go | 3 +-- messaging/messaging.go | 3 +-- messaging/messaging_test.go | 2 +- snippets/auth.go | 2 +- snippets/messaging.go | 2 +- storage/storage.go | 3 +-- storage/storage_test.go | 2 +- 25 files changed, 25 insertions(+), 37 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index f6605c7b..a0b206a8 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -16,6 +16,7 @@ package auth import ( + "context" "crypto/rsa" "crypto/x509" "encoding/json" @@ -25,7 +26,6 @@ import ( "strings" "firebase.google.com/go/internal" - "golang.org/x/net/context" "google.golang.org/api/identitytoolkit/v3" "google.golang.org/api/transport" ) diff --git a/auth/auth_appengine.go b/auth/auth_appengine.go index 351f61c1..5e05cdb1 100644 --- a/auth/auth_appengine.go +++ b/auth/auth_appengine.go @@ -17,7 +17,7 @@ package auth import ( - "golang.org/x/net/context" + "context" "google.golang.org/appengine" ) diff --git a/auth/auth_std.go b/auth/auth_std.go index 2055af38..f593a7cc 100644 --- a/auth/auth_std.go +++ b/auth/auth_std.go @@ -16,7 +16,7 @@ package auth -import "golang.org/x/net/context" +import "context" func newSigner(ctx context.Context) (signer, error) { return serviceAcctSigner{}, nil diff --git a/auth/auth_test.go b/auth/auth_test.go index 690b5d6f..2676f6c4 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -15,6 +15,7 @@ package auth import ( + "context" "encoding/json" "errors" "fmt" @@ -25,7 +26,6 @@ import ( "testing" "time" - "golang.org/x/net/context" "golang.org/x/oauth2/google" "google.golang.org/api/option" diff --git a/auth/user_mgt.go b/auth/user_mgt.go index 423ceece..551753ea 100644 --- a/auth/user_mgt.go +++ b/auth/user_mgt.go @@ -15,6 +15,7 @@ package auth import ( + "context" "encoding/json" "fmt" "net/http" @@ -23,7 +24,6 @@ import ( "strings" "time" - "golang.org/x/net/context" "google.golang.org/api/identitytoolkit/v3" "google.golang.org/api/iterator" ) diff --git a/auth/user_mgt_test.go b/auth/user_mgt_test.go index db7b7300..6a95312a 100644 --- a/auth/user_mgt_test.go +++ b/auth/user_mgt_test.go @@ -16,6 +16,7 @@ package auth import ( "bytes" + "context" "encoding/json" "fmt" "io/ioutil" @@ -28,7 +29,6 @@ import ( "firebase.google.com/go/internal" - "golang.org/x/net/context" "golang.org/x/oauth2" "google.golang.org/api/identitytoolkit/v3" "google.golang.org/api/iterator" diff --git a/firebase.go b/firebase.go index ed09ac6d..b533c12d 100644 --- a/firebase.go +++ b/firebase.go @@ -18,6 +18,7 @@ package firebase import ( + "context" "encoding/json" "errors" "io/ioutil" @@ -31,7 +32,6 @@ import ( "firebase.google.com/go/messaging" "firebase.google.com/go/storage" - "golang.org/x/net/context" "golang.org/x/oauth2/google" "google.golang.org/api/option" "google.golang.org/api/transport" diff --git a/firebase_test.go b/firebase_test.go index 7b17755b..fc3d50d3 100644 --- a/firebase_test.go +++ b/firebase_test.go @@ -15,6 +15,7 @@ package firebase import ( + "context" "fmt" "io/ioutil" "log" @@ -32,7 +33,6 @@ import ( "encoding/json" - "golang.org/x/net/context" "golang.org/x/oauth2" "google.golang.org/api/option" ) diff --git a/iid/iid.go b/iid/iid.go index 980a7bed..b282db40 100644 --- a/iid/iid.go +++ b/iid/iid.go @@ -16,6 +16,7 @@ package iid import ( + "context" "errors" "fmt" "net/http" @@ -23,8 +24,6 @@ import ( "google.golang.org/api/transport" "firebase.google.com/go/internal" - - "golang.org/x/net/context" ) const iidEndpoint = "https://console.firebase.google.com/v1" diff --git a/iid/iid_test.go b/iid/iid_test.go index 6d154650..b3e69638 100644 --- a/iid/iid_test.go +++ b/iid/iid_test.go @@ -15,6 +15,7 @@ package iid import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -23,8 +24,6 @@ import ( "google.golang.org/api/option" "firebase.google.com/go/internal" - - "golang.org/x/net/context" ) var testIIDConfig = &internal.InstanceIDConfig{ diff --git a/integration/auth/auth_test.go b/integration/auth/auth_test.go index 5e027d44..33d8cf2c 100644 --- a/integration/auth/auth_test.go +++ b/integration/auth/auth_test.go @@ -17,6 +17,7 @@ package auth import ( "bytes" + "context" "encoding/json" "flag" "fmt" @@ -29,8 +30,6 @@ import ( "firebase.google.com/go/auth" "firebase.google.com/go/integration/internal" - - "golang.org/x/net/context" ) const apiURL = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken?key=%s" diff --git a/integration/auth/user_mgt_test.go b/integration/auth/user_mgt_test.go index 4c19f6e5..4121d62c 100644 --- a/integration/auth/user_mgt_test.go +++ b/integration/auth/user_mgt_test.go @@ -16,6 +16,7 @@ package auth import ( + "context" "fmt" "reflect" "testing" @@ -24,8 +25,6 @@ import ( "google.golang.org/api/iterator" "firebase.google.com/go/auth" - - "golang.org/x/net/context" ) var testFixtures = struct { diff --git a/integration/firestore/firestore_test.go b/integration/firestore/firestore_test.go index 6e7b4e28..6c367205 100644 --- a/integration/firestore/firestore_test.go +++ b/integration/firestore/firestore_test.go @@ -15,13 +15,12 @@ package firestore import ( + "context" "log" "reflect" "testing" "firebase.google.com/go/integration/internal" - - "golang.org/x/net/context" ) func TestFirestore(t *testing.T) { diff --git a/integration/iid/iid_test.go b/integration/iid/iid_test.go index 9be5dce0..55cf5620 100644 --- a/integration/iid/iid_test.go +++ b/integration/iid/iid_test.go @@ -16,6 +16,7 @@ package iid import ( + "context" "flag" "log" "os" @@ -23,8 +24,6 @@ import ( "firebase.google.com/go/iid" "firebase.google.com/go/integration/internal" - - "golang.org/x/net/context" ) var client *iid.Client diff --git a/integration/internal/internal.go b/integration/internal/internal.go index bc52a16a..1fe112a3 100644 --- a/integration/internal/internal.go +++ b/integration/internal/internal.go @@ -16,14 +16,13 @@ package internal import ( + "context" "encoding/json" "go/build" "io/ioutil" "path/filepath" "strings" - "golang.org/x/net/context" - firebase "firebase.google.com/go" "google.golang.org/api/option" ) diff --git a/integration/messaging/messaging_test.go b/integration/messaging/messaging_test.go index d7bb0693..231aab34 100644 --- a/integration/messaging/messaging_test.go +++ b/integration/messaging/messaging_test.go @@ -15,14 +15,13 @@ package messaging import ( + "context" "flag" "log" "os" "regexp" "testing" - "golang.org/x/net/context" - "firebase.google.com/go/integration/internal" "firebase.google.com/go/messaging" ) diff --git a/integration/storage/storage_test.go b/integration/storage/storage_test.go index 5efe92d2..b6fc2301 100644 --- a/integration/storage/storage_test.go +++ b/integration/storage/storage_test.go @@ -15,6 +15,7 @@ package storage import ( + "context" "flag" "fmt" "io/ioutil" @@ -25,7 +26,6 @@ import ( gcs "cloud.google.com/go/storage" "firebase.google.com/go/integration/internal" "firebase.google.com/go/storage" - "golang.org/x/net/context" ) var ctx context.Context diff --git a/internal/http_client.go b/internal/http_client.go index bd40c366..984e8a1d 100644 --- a/internal/http_client.go +++ b/internal/http_client.go @@ -16,13 +16,12 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "io" "io/ioutil" "net/http" - - "golang.org/x/net/context" ) // HTTPClient is a convenient API to make HTTP calls. diff --git a/internal/http_client_test.go b/internal/http_client_test.go index bdac7474..14729d17 100644 --- a/internal/http_client_test.go +++ b/internal/http_client_test.go @@ -14,14 +14,13 @@ package internal import ( + "context" "encoding/json" "io/ioutil" "net/http" "net/http/httptest" "reflect" "testing" - - "golang.org/x/net/context" ) var cases = []struct { diff --git a/messaging/messaging.go b/messaging/messaging.go index 97b77d64..a3d71e97 100644 --- a/messaging/messaging.go +++ b/messaging/messaging.go @@ -17,6 +17,7 @@ package messaging import ( + "context" "encoding/json" "errors" "fmt" @@ -25,8 +26,6 @@ import ( "strings" "time" - "golang.org/x/net/context" - "firebase.google.com/go/internal" "google.golang.org/api/transport" ) diff --git a/messaging/messaging_test.go b/messaging/messaging_test.go index 6706f217..53b0650a 100644 --- a/messaging/messaging_test.go +++ b/messaging/messaging_test.go @@ -15,6 +15,7 @@ package messaging import ( + "context" "encoding/json" "io/ioutil" "net/http" @@ -25,7 +26,6 @@ import ( "time" "firebase.google.com/go/internal" - "golang.org/x/net/context" "google.golang.org/api/option" ) diff --git a/snippets/auth.go b/snippets/auth.go index 68847869..d9548e94 100644 --- a/snippets/auth.go +++ b/snippets/auth.go @@ -15,11 +15,11 @@ package snippets import ( + "context" "log" firebase "firebase.google.com/go" "firebase.google.com/go/auth" - "golang.org/x/net/context" "google.golang.org/api/iterator" ) diff --git a/snippets/messaging.go b/snippets/messaging.go index f558b40c..18f6e462 100644 --- a/snippets/messaging.go +++ b/snippets/messaging.go @@ -15,13 +15,13 @@ package snippets import ( + "context" "fmt" "log" "time" "firebase.google.com/go" "firebase.google.com/go/messaging" - "golang.org/x/net/context" ) func sendToToken(app *firebase.App) { diff --git a/storage/storage.go b/storage/storage.go index 985b6eb7..878e2175 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -16,12 +16,11 @@ package storage import ( + "context" "errors" "cloud.google.com/go/storage" "firebase.google.com/go/internal" - - "golang.org/x/net/context" ) // Client is the interface for the Firebase Storage service. diff --git a/storage/storage_test.go b/storage/storage_test.go index 833aedf3..7a77e60c 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -15,12 +15,12 @@ package storage import ( + "context" "testing" "google.golang.org/api/option" "firebase.google.com/go/internal" - "golang.org/x/net/context" ) var opts = []option.ClientOption{ From c9be1e93083aa43d2aaa825d790991ae1634d064 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 27 Feb 2018 17:12:50 -0800 Subject: [PATCH 14/27] Firebase Database API (#92) * Experimental RTDB code * Added ref.Set() * Added Push(), Update(), Remove() and tests * Adding Transaction() support * Fixed Transaction() API * Code cleanup * Implemented Query() API * Added GetIfChanged() and integration tests * More integration tests * Updated unit test * More integration tests * Integration tests for queries * Auth override support and more tests * More test cases; AuthOverride support in App * Implemented AuthOverride support; Added tests * Implementing the new API * More code cleanup * Code clean up * Refactored the http client code * More tests * Boosted test coverage to 97% * Better error messages in tests; Added license headers * Added documentatioon and cleaned up tests * Fixing a build break * Finishing up documentation * More test cases * Implemented a reusable HTTP client API * Added test cases * Comment clean up * Using the shared http client API * Simplified the usage by adding HTTPClient * using the new client API * Using the old ctx import * Using the old context import * Refactored db code * More refactoring * Support for arbitrary entity types in the request * Renamed fields; Added documentation * Removing a redundant else case * Code readability improvements * Cleaned up the RTDB HTTP client code * Added shallow reads support; Added the new txn API * Implementing GetOrdered() for queries * Adding more sorting tests * Added Query ordering tests * Fixing some lint errors and compilation errors * Removing unused function * Cleaned up unit tests for db * Updated query impl and tests * Added integration tests for ordered queries * Removed With*() from query functions * Updated change log; Added more tests * Support for database url in auto init * Support for loading auth overrides from env * Removed db.AuthOverride type * Renamed ao to authOverride everywhere; Other code review nits * Introducing the QueryNode interface to handle ordered query results (#100) * Database Sample Snippets (#102) * Adding database snippets * Adding query snippets * Added complex query samples * Updated variable name * Fixing a typo * Fixing query example * Updated DB snippets to use GetOrdered() * Removing unnecessary placeholders in Fatalln() calls * Removing unnecessary placeholders in Fatalln() calls --- CHANGELOG.md | 2 +- auth/auth.go | 2 +- db/auth_override_test.go | 107 ++++ db/db.go | 134 ++++ db/db_test.go | 404 +++++++++++++ db/query.go | 423 +++++++++++++ db/query_test.go | 774 ++++++++++++++++++++++++ db/ref.go | 262 ++++++++ db/ref_test.go | 729 ++++++++++++++++++++++ firebase.go | 56 +- firebase_test.go | 87 +++ integration/auth/auth_test.go | 2 +- integration/db/db_test.go | 709 ++++++++++++++++++++++ integration/db/query_test.go | 266 ++++++++ integration/firestore/firestore_test.go | 2 +- integration/iid/iid_test.go | 2 +- integration/internal/internal.go | 25 +- integration/messaging/messaging_test.go | 2 +- integration/storage/storage_test.go | 11 +- internal/internal.go | 18 + snippets/auth.go | 25 +- snippets/db.go | 528 ++++++++++++++++ testdata/dinosaurs.json | 78 +++ testdata/dinosaurs_index.json | 29 + testdata/firebase_config.json | 1 + 25 files changed, 4634 insertions(+), 44 deletions(-) create mode 100644 db/auth_override_test.go create mode 100644 db/db.go create mode 100644 db/db_test.go create mode 100644 db/query.go create mode 100644 db/query_test.go create mode 100644 db/ref.go create mode 100644 db/ref_test.go create mode 100644 integration/db/db_test.go create mode 100644 integration/db/query_test.go create mode 100644 snippets/db.go create mode 100644 testdata/dinosaurs.json create mode 100644 testdata/dinosaurs_index.json diff --git a/CHANGELOG.md b/CHANGELOG.md index 3060d302..9b2583f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Unreleased -- +- [added] Added the `db` package for interacting with the Firebase database. # v2.5.0 diff --git a/auth/auth.go b/auth/auth.go index a0b206a8..98822fef 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -78,7 +78,7 @@ type signer interface { // NewClient creates a new instance of the Firebase Auth Client. // // This function can only be invoked from within the SDK. Client applications should access the -// the Auth service through firebase.App. +// Auth service through firebase.App. func NewClient(ctx context.Context, c *internal.AuthConfig) (*Client, error) { var ( err error diff --git a/db/auth_override_test.go b/db/auth_override_test.go new file mode 100644 index 00000000..86cbeef2 --- /dev/null +++ b/db/auth_override_test.go @@ -0,0 +1,107 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "testing" + + "golang.org/x/net/context" +) + +func TestAuthOverrideGet(t *testing.T) { + mock := &mockServer{Resp: "data"} + srv := mock.Start(aoClient) + defer srv.Close() + + ref := aoClient.NewRef("peter") + var got string + if err := ref.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if got != "data" { + t.Errorf("Ref(AuthOverride).Get() = %q; want = %q", got, "data") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"auth_variable_override": testAuthOverrides}, + }) +} + +func TestAuthOverrideSet(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(aoClient) + defer srv.Close() + + ref := aoClient.NewRef("peter") + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + if err := ref.Set(context.Background(), want); err != nil { + t.Fatal(err) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PUT", + Body: serialize(want), + Path: "/peter.json", + Query: map[string]string{"auth_variable_override": testAuthOverrides, "print": "silent"}, + }) +} + +func TestAuthOverrideQuery(t *testing.T) { + mock := &mockServer{Resp: "data"} + srv := mock.Start(aoClient) + defer srv.Close() + + ref := aoClient.NewRef("peter") + var got string + if err := ref.OrderByChild("foo").Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if got != "data" { + t.Errorf("Ref(AuthOverride).OrderByChild() = %q; want = %q", got, "data") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{ + "auth_variable_override": testAuthOverrides, + "orderBy": "\"foo\"", + }, + }) +} + +func TestAuthOverrideRangeQuery(t *testing.T) { + mock := &mockServer{Resp: "data"} + srv := mock.Start(aoClient) + defer srv.Close() + + ref := aoClient.NewRef("peter") + var got string + if err := ref.OrderByChild("foo").StartAt(1).EndAt(10).Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if got != "data" { + t.Errorf("Ref(AuthOverride).OrderByChild() = %q; want = %q", got, "data") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{ + "auth_variable_override": testAuthOverrides, + "orderBy": "\"foo\"", + "startAt": "1", + "endAt": "10", + }, + }) +} diff --git a/db/db.go b/db/db.go new file mode 100644 index 00000000..6bed3922 --- /dev/null +++ b/db/db.go @@ -0,0 +1,134 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package db contains functions for accessing the Firebase Realtime Database. +package db + +import ( + "encoding/json" + "fmt" + "runtime" + "strings" + + "firebase.google.com/go/internal" + + "net/url" + + "golang.org/x/net/context" + "google.golang.org/api/option" + "google.golang.org/api/transport" +) + +const userAgentFormat = "Firebase/HTTP/%s/%s/AdminGo" +const invalidChars = "[].#$" +const authVarOverride = "auth_variable_override" + +// Client is the interface for the Firebase Realtime Database service. +type Client struct { + hc *internal.HTTPClient + url string + authOverride string +} + +// NewClient creates a new instance of the Firebase Database Client. +// +// This function can only be invoked from within the SDK. Client applications should access the +// Database service through firebase.App. +func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) { + opts := append([]option.ClientOption{}, c.Opts...) + ua := fmt.Sprintf(userAgentFormat, c.Version, runtime.Version()) + opts = append(opts, option.WithUserAgent(ua)) + hc, _, err := transport.NewHTTPClient(ctx, opts...) + if err != nil { + return nil, err + } + + p, err := url.ParseRequestURI(c.URL) + if err != nil { + return nil, err + } else if p.Scheme != "https" { + return nil, fmt.Errorf("invalid database URL: %q; want scheme: %q", c.URL, "https") + } else if !strings.HasSuffix(p.Host, ".firebaseio.com") { + return nil, fmt.Errorf("invalid database URL: %q; want host: %q", c.URL, "firebaseio.com") + } + + var ao []byte + if c.AuthOverride == nil || len(c.AuthOverride) > 0 { + ao, err = json.Marshal(c.AuthOverride) + if err != nil { + return nil, err + } + } + + ep := func(b []byte) string { + var p struct { + Error string `json:"error"` + } + if err := json.Unmarshal(b, &p); err != nil { + return "" + } + return p.Error + } + return &Client{ + hc: &internal.HTTPClient{Client: hc, ErrParser: ep}, + url: fmt.Sprintf("https://%s", p.Host), + authOverride: string(ao), + }, nil +} + +// NewRef returns a new database reference representing the node at the specified path. +func (c *Client) NewRef(path string) *Ref { + segs := parsePath(path) + key := "" + if len(segs) > 0 { + key = segs[len(segs)-1] + } + + return &Ref{ + Key: key, + Path: "/" + strings.Join(segs, "/"), + client: c, + segs: segs, + } +} + +func (c *Client) send( + ctx context.Context, + method, path string, + body internal.HTTPEntity, + opts ...internal.HTTPOption) (*internal.Response, error) { + + if strings.ContainsAny(path, invalidChars) { + return nil, fmt.Errorf("invalid path with illegal characters: %q", path) + } + if c.authOverride != "" { + opts = append(opts, internal.WithQueryParam(authVarOverride, c.authOverride)) + } + return c.hc.Do(ctx, &internal.Request{ + Method: method, + URL: fmt.Sprintf("%s%s.json", c.url, path), + Body: body, + Opts: opts, + }) +} + +func parsePath(path string) []string { + var segs []string + for _, s := range strings.Split(path, "/") { + if s != "" { + segs = append(segs, s) + } + } + return segs +} diff --git a/db/db_test.go b/db/db_test.go new file mode 100644 index 00000000..01234504 --- /dev/null +++ b/db/db_test.go @@ -0,0 +1,404 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "fmt" + "log" + "net/http" + "net/http/httptest" + "os" + "runtime" + "testing" + + "golang.org/x/net/context" + "golang.org/x/oauth2" + + "encoding/json" + + "reflect" + + "io/ioutil" + + "net/url" + + "firebase.google.com/go/internal" + "google.golang.org/api/option" +) + +const testURL = "https://test-db.firebaseio.com" + +var testUserAgent string +var testAuthOverrides string +var testOpts = []option.ClientOption{ + option.WithTokenSource(&mockTokenSource{"mock-token"}), +} + +var client *Client +var aoClient *Client +var testref *Ref + +func TestMain(m *testing.M) { + var err error + client, err = NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + URL: testURL, + Version: "1.2.3", + AuthOverride: map[string]interface{}{}, + }) + if err != nil { + log.Fatalln(err) + } + + ao := map[string]interface{}{"uid": "user1"} + aoClient, err = NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + URL: testURL, + Version: "1.2.3", + AuthOverride: ao, + }) + if err != nil { + log.Fatalln(err) + } + + b, err := json.Marshal(ao) + if err != nil { + log.Fatalln(err) + } + testAuthOverrides = string(b) + + testref = client.NewRef("peter") + testUserAgent = fmt.Sprintf(userAgentFormat, "1.2.3", runtime.Version()) + os.Exit(m.Run()) +} + +func TestNewClient(t *testing.T) { + c, err := NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + URL: testURL, + AuthOverride: make(map[string]interface{}), + }) + if err != nil { + t.Fatal(err) + } + if c.url != testURL { + t.Errorf("NewClient().url = %q; want = %q", c.url, testURL) + } + if c.hc == nil { + t.Errorf("NewClient().hc = nil; want non-nil") + } + if c.authOverride != "" { + t.Errorf("NewClient().ao = %q; want = %q", c.authOverride, "") + } +} + +func TestNewClientAuthOverrides(t *testing.T) { + cases := []map[string]interface{}{ + nil, + map[string]interface{}{"uid": "user1"}, + } + for _, tc := range cases { + c, err := NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + URL: testURL, + AuthOverride: tc, + }) + if err != nil { + t.Fatal(err) + } + if c.url != testURL { + t.Errorf("NewClient(%v).url = %q; want = %q", tc, c.url, testURL) + } + if c.hc == nil { + t.Errorf("NewClient(%v).hc = nil; want non-nil", tc) + } + b, err := json.Marshal(tc) + if err != nil { + t.Fatal(err) + } + if c.authOverride != string(b) { + t.Errorf("NewClient(%v).ao = %q; want = %q", tc, c.authOverride, string(b)) + } + } +} + +func TestInvalidURL(t *testing.T) { + cases := []string{ + "", + "foo", + "http://db.firebaseio.com", + "https://firebase.google.com", + } + for _, tc := range cases { + c, err := NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + URL: tc, + }) + if c != nil || err == nil { + t.Errorf("NewClient(%q) = (%v, %v); want = (nil, error)", tc, c, err) + } + } +} + +func TestInvalidAuthOverride(t *testing.T) { + c, err := NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + URL: testURL, + AuthOverride: map[string]interface{}{"uid": func() {}}, + }) + if c != nil || err == nil { + t.Errorf("NewClient() = (%v, %v); want = (nil, error)", c, err) + } +} + +func TestNewRef(t *testing.T) { + cases := []struct { + Path string + WantPath string + WantKey string + }{ + {"", "/", ""}, + {"/", "/", ""}, + {"foo", "/foo", "foo"}, + {"/foo", "/foo", "foo"}, + {"foo/bar", "/foo/bar", "bar"}, + {"/foo/bar", "/foo/bar", "bar"}, + {"/foo/bar/", "/foo/bar", "bar"}, + } + for _, tc := range cases { + r := client.NewRef(tc.Path) + if r.client == nil { + t.Errorf("NewRef(%q).client = nil; want = %v", tc.Path, r.client) + } + if r.Path != tc.WantPath { + t.Errorf("NewRef(%q).Path = %q; want = %q", tc.Path, r.Path, tc.WantPath) + } + if r.Key != tc.WantKey { + t.Errorf("NewRef(%q).Key = %q; want = %q", tc.Path, r.Key, tc.WantKey) + } + } +} + +func TestParent(t *testing.T) { + cases := []struct { + Path string + HasParent bool + Want string + }{ + {"", false, ""}, + {"/", false, ""}, + {"foo", true, ""}, + {"/foo", true, ""}, + {"foo/bar", true, "foo"}, + {"/foo/bar", true, "foo"}, + {"/foo/bar/", true, "foo"}, + } + for _, tc := range cases { + r := client.NewRef(tc.Path).Parent() + if tc.HasParent { + if r == nil { + t.Fatalf("Parent(%q) = nil; want = Ref(%q)", tc.Path, tc.Want) + } + if r.client == nil { + t.Errorf("Parent(%q).client = nil; want = %v", tc.Path, client) + } + if r.Key != tc.Want { + t.Errorf("Parent(%q).Key = %q; want = %q", tc.Path, r.Key, tc.Want) + } + } else if r != nil { + t.Fatalf("Parent(%q) = %v; want = nil", tc.Path, r) + } + } +} + +func TestChild(t *testing.T) { + r := client.NewRef("/test") + cases := []struct { + Path string + Want string + Parent string + }{ + {"", "/test", "/"}, + {"foo", "/test/foo", "/test"}, + {"/foo", "/test/foo", "/test"}, + {"foo/", "/test/foo", "/test"}, + {"/foo/", "/test/foo", "/test"}, + {"//foo//", "/test/foo", "/test"}, + {"foo/bar", "/test/foo/bar", "/test/foo"}, + {"/foo/bar", "/test/foo/bar", "/test/foo"}, + {"foo/bar/", "/test/foo/bar", "/test/foo"}, + {"/foo/bar/", "/test/foo/bar", "/test/foo"}, + {"//foo/bar", "/test/foo/bar", "/test/foo"}, + {"foo//bar/", "/test/foo/bar", "/test/foo"}, + {"foo/bar//", "/test/foo/bar", "/test/foo"}, + } + for _, tc := range cases { + c := r.Child(tc.Path) + if c.Path != tc.Want { + t.Errorf("Child(%q) = %q; want = %q", tc.Path, c.Path, tc.Want) + } + if c.Parent().Path != tc.Parent { + t.Errorf("Child(%q).Parent() = %q; want = %q", tc.Path, c.Parent().Path, tc.Parent) + } + } +} + +func checkOnlyRequest(t *testing.T, got []*testReq, want *testReq) { + checkAllRequests(t, got, []*testReq{want}) +} + +func checkAllRequests(t *testing.T, got []*testReq, want []*testReq) { + if len(got) != len(want) { + t.Errorf("Request Count = %d; want = %d", len(got), len(want)) + } else { + for i, r := range got { + checkRequest(t, r, want[i]) + } + } +} + +func checkRequest(t *testing.T, got, want *testReq) { + if h := got.Header.Get("Authorization"); h != "Bearer mock-token" { + t.Errorf("Authorization = %q; want = %q", h, "Bearer mock-token") + } + if h := got.Header.Get("User-Agent"); h != testUserAgent { + t.Errorf("User-Agent = %q; want = %q", h, testUserAgent) + } + + if got.Method != want.Method { + t.Errorf("Method = %q; want = %q", got.Method, want.Method) + } + + if got.Path != want.Path { + t.Errorf("Path = %q; want = %q", got.Path, want.Path) + } + if len(want.Query) != len(got.Query) { + t.Errorf("QueryParam = %v; want = %v", got.Query, want.Query) + } + for k, v := range want.Query { + if got.Query[k] != v { + t.Errorf("QueryParam(%v) = %v; want = %v", k, got.Query[k], v) + } + } + for k, v := range want.Header { + if got.Header.Get(k) != v[0] { + t.Errorf("Header(%q) = %q; want = %q", k, got.Header.Get(k), v[0]) + } + } + if want.Body != nil { + if h := got.Header.Get("Content-Type"); h != "application/json" { + t.Errorf("User-Agent = %q; want = %q", h, "application/json") + } + var wi, gi interface{} + if err := json.Unmarshal(want.Body, &wi); err != nil { + t.Fatal(err) + } + if err := json.Unmarshal(got.Body, &gi); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(gi, wi) { + t.Errorf("Body = %v; want = %v", gi, wi) + } + } else if len(got.Body) != 0 { + t.Errorf("Body = %v; want empty", got.Body) + } +} + +type testReq struct { + Method string + Path string + Header http.Header + Body []byte + Query map[string]string +} + +func newTestReq(r *http.Request) (*testReq, error) { + defer r.Body.Close() + b, err := ioutil.ReadAll(r.Body) + if err != nil { + return nil, err + } + + u, err := url.Parse(r.RequestURI) + if err != nil { + return nil, err + } + + query := make(map[string]string) + for k, v := range u.Query() { + query[k] = v[0] + } + return &testReq{ + Method: r.Method, + Path: u.Path, + Header: r.Header, + Body: b, + Query: query, + }, nil +} + +type mockServer struct { + Resp interface{} + Header map[string]string + Status int + Reqs []*testReq + srv *httptest.Server +} + +func (s *mockServer) Start(c *Client) *httptest.Server { + if s.srv != nil { + return s.srv + } + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tr, _ := newTestReq(r) + s.Reqs = append(s.Reqs, tr) + + for k, v := range s.Header { + w.Header().Set(k, v) + } + + print := r.URL.Query().Get("print") + if s.Status != 0 { + w.WriteHeader(s.Status) + } else if print == "silent" { + w.WriteHeader(http.StatusNoContent) + return + } + b, _ := json.Marshal(s.Resp) + w.Header().Set("Content-Type", "application/json") + w.Write(b) + }) + s.srv = httptest.NewServer(handler) + c.url = s.srv.URL + return s.srv +} + +type mockTokenSource struct { + AccessToken string +} + +func (ts *mockTokenSource) Token() (*oauth2.Token, error) { + return &oauth2.Token{AccessToken: ts.AccessToken}, nil +} + +type person struct { + Name string `json:"name"` + Age int32 `json:"age"` +} + +func serialize(v interface{}) []byte { + b, _ := json.Marshal(v) + return b +} diff --git a/db/query.go b/db/query.go new file mode 100644 index 00000000..c6013483 --- /dev/null +++ b/db/query.go @@ -0,0 +1,423 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "encoding/json" + "fmt" + "net/http" + "sort" + "strconv" + "strings" + + "firebase.google.com/go/internal" + + "golang.org/x/net/context" +) + +// QueryNode represents a data node retrieved from an ordered query. +type QueryNode interface { + Key() string + Unmarshal(v interface{}) error +} + +// Query represents a complex query that can be executed on a Ref. +// +// Complex queries can consist of up to 2 components: a required ordering constraint, and an +// optional filtering constraint. At the server, data is first sorted according to the given +// ordering constraint (e.g. order by child). Then the filtering constraint (e.g. limit, range) is +// applied on the sorted data to produce the final result. Despite the ordering constraint, the +// final result is returned by the server as an unordered collection. Therefore the values read +// from a Query instance are not ordered. +type Query struct { + client *Client + path string + order orderBy + limFirst, limLast int + start, end, equalTo interface{} +} + +// StartAt returns a shallow copy of the Query with v set as a lower bound of a range query. +// +// The resulting Query will only return child nodes with a value greater than or equal to v. +func (q *Query) StartAt(v interface{}) *Query { + q2 := &Query{} + *q2 = *q + q2.start = v + return q2 +} + +// EndAt returns a shallow copy of the Query with v set as a upper bound of a range query. +// +// The resulting Query will only return child nodes with a value less than or equal to v. +func (q *Query) EndAt(v interface{}) *Query { + q2 := &Query{} + *q2 = *q + q2.end = v + return q2 +} + +// EqualTo returns a shallow copy of the Query with v set as an equals constraint. +// +// The resulting Query will only return child nodes whose values equal to v. +func (q *Query) EqualTo(v interface{}) *Query { + q2 := &Query{} + *q2 = *q + q2.equalTo = v + return q2 +} + +// LimitToFirst returns a shallow copy of the Query, which is anchored to the first n +// elements of the window. +func (q *Query) LimitToFirst(n int) *Query { + q2 := &Query{} + *q2 = *q + q2.limFirst = n + return q2 +} + +// LimitToLast returns a shallow copy of the Query, which is anchored to the last n +// elements of the window. +func (q *Query) LimitToLast(n int) *Query { + q2 := &Query{} + *q2 = *q + q2.limLast = n + return q2 +} + +// Get executes the Query and populates v with the results. +// +// Data deserialization is performed using https://golang.org/pkg/encoding/json/#Unmarshal, and +// therefore v has the same requirements as the json package. Specifically, it must be a pointer, +// and must not be nil. +// +// Despite the ordering constraint of the Query, results are not stored in any particular order +// in v. Use GetOrdered() to obtain ordered results. +func (q *Query) Get(ctx context.Context, v interface{}) error { + qp := make(map[string]string) + if err := initQueryParams(q, qp); err != nil { + return err + } + resp, err := q.client.send(ctx, "GET", q.path, nil, internal.WithQueryParams(qp)) + if err != nil { + return err + } + return resp.Unmarshal(http.StatusOK, v) +} + +// GetOrdered executes the Query and returns the results as an ordered slice. +func (q *Query) GetOrdered(ctx context.Context) ([]QueryNode, error) { + var temp interface{} + if err := q.Get(ctx, &temp); err != nil { + return nil, err + } + if temp == nil { + return nil, nil + } + + sn := newSortableNodes(temp, q.order) + sort.Sort(sn) + result := make([]QueryNode, len(sn)) + for i, v := range sn { + result[i] = v + } + return result, nil +} + +// OrderByChild returns a Query that orders data by child values before applying filters. +// +// Returned Query can be used to set additional parameters, and execute complex database queries +// (e.g. limit queries, range queries). If r has a context associated with it, the resulting Query +// will inherit it. +func (r *Ref) OrderByChild(child string) *Query { + return newQuery(r, orderByChild(child)) +} + +// OrderByKey returns a Query that orders data by key before applying filters. +// +// Returned Query can be used to set additional parameters, and execute complex database queries +// (e.g. limit queries, range queries). If r has a context associated with it, the resulting Query +// will inherit it. +func (r *Ref) OrderByKey() *Query { + return newQuery(r, orderByProperty("$key")) +} + +// OrderByValue returns a Query that orders data by value before applying filters. +// +// Returned Query can be used to set additional parameters, and execute complex database queries +// (e.g. limit queries, range queries). If r has a context associated with it, the resulting Query +// will inherit it. +func (r *Ref) OrderByValue() *Query { + return newQuery(r, orderByProperty("$value")) +} + +func newQuery(r *Ref, ob orderBy) *Query { + return &Query{ + client: r.client, + path: r.Path, + order: ob, + } +} + +func initQueryParams(q *Query, qp map[string]string) error { + ob, err := q.order.encode() + if err != nil { + return err + } + qp["orderBy"] = ob + + if q.limFirst > 0 && q.limLast > 0 { + return fmt.Errorf("cannot set both limit parameter: first = %d, last = %d", q.limFirst, q.limLast) + } else if q.limFirst < 0 { + return fmt.Errorf("limit first cannot be negative: %d", q.limFirst) + } else if q.limLast < 0 { + return fmt.Errorf("limit last cannot be negative: %d", q.limLast) + } + + if q.limFirst > 0 { + qp["limitToFirst"] = strconv.Itoa(q.limFirst) + } else if q.limLast > 0 { + qp["limitToLast"] = strconv.Itoa(q.limLast) + } + + if err := encodeFilter("startAt", q.start, qp); err != nil { + return err + } + if err := encodeFilter("endAt", q.end, qp); err != nil { + return err + } + return encodeFilter("equalTo", q.equalTo, qp) +} + +func encodeFilter(key string, val interface{}, m map[string]string) error { + if val == nil { + return nil + } + b, err := json.Marshal(val) + if err != nil { + return err + } + m[key] = string(b) + return nil +} + +type orderBy interface { + encode() (string, error) +} + +type orderByChild string + +func (p orderByChild) encode() (string, error) { + if p == "" { + return "", fmt.Errorf("empty child path") + } else if strings.ContainsAny(string(p), invalidChars) { + return "", fmt.Errorf("invalid child path with illegal characters: %q", p) + } + segs := parsePath(string(p)) + if len(segs) == 0 { + return "", fmt.Errorf("invalid child path: %q", p) + } + b, err := json.Marshal(strings.Join(segs, "/")) + if err != nil { + return "", nil + } + return string(b), nil +} + +type orderByProperty string + +func (p orderByProperty) encode() (string, error) { + b, err := json.Marshal(p) + if err != nil { + return "", err + } + return string(b), nil +} + +// Firebase type ordering: https://firebase.google.com/docs/database/rest/retrieve-data#section-rest-ordered-data +const ( + typeNull = 0 + typeBoolFalse = 1 + typeBoolTrue = 2 + typeNumeric = 3 + typeString = 4 + typeObject = 5 +) + +// comparableKey is a union type of numeric values and strings. +type comparableKey struct { + Num *float64 + Str *string +} + +func (k *comparableKey) Compare(o *comparableKey) int { + if k.Str != nil && o.Str != nil { + return strings.Compare(*k.Str, *o.Str) + } else if k.Num != nil && o.Num != nil { + if *k.Num < *o.Num { + return -1 + } else if *k.Num == *o.Num { + return 0 + } + return 1 + } else if k.Num != nil { + // numeric keys appear before string keys + return -1 + } + return 1 +} + +func newComparableKey(v interface{}) *comparableKey { + if s, ok := v.(string); ok { + return &comparableKey{Str: &s} + } + + // Numeric values could be int (in the case of array indices and type constants), or float64 (if + // the value was received as json). + if i, ok := v.(int); ok { + f := float64(i) + return &comparableKey{Num: &f} + } + + f := v.(float64) + return &comparableKey{Num: &f} +} + +type queryNodeImpl struct { + CompKey *comparableKey + Value interface{} + Index interface{} + IndexType int +} + +func (q *queryNodeImpl) Key() string { + if q.CompKey.Str != nil { + return *q.CompKey.Str + } + // Numeric keys in queryNodeImpl are always array indices, and can be safely coverted into int. + return strconv.Itoa(int(*q.CompKey.Num)) +} + +func (q *queryNodeImpl) Unmarshal(v interface{}) error { + b, err := json.Marshal(q.Value) + if err != nil { + return err + } + return json.Unmarshal(b, v) +} + +func newQueryNode(key, val interface{}, order orderBy) *queryNodeImpl { + var index interface{} + if prop, ok := order.(orderByProperty); ok { + if prop == "$value" { + index = val + } else { + index = key + } + } else { + path := order.(orderByChild) + index = extractChildValue(val, string(path)) + } + return &queryNodeImpl{ + CompKey: newComparableKey(key), + Value: val, + Index: index, + IndexType: getIndexType(index), + } +} + +type sortableNodes []*queryNodeImpl + +func (s sortableNodes) Len() int { + return len(s) +} + +func (s sortableNodes) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +func (s sortableNodes) Less(i, j int) bool { + a, b := s[i], s[j] + var aKey, bKey *comparableKey + if a.IndexType == b.IndexType { + // If the indices have the same type and are comparable (i.e. numeric or string), compare + // them directly. Otherwise, compare the keys. + if (a.IndexType == typeNumeric || a.IndexType == typeString) && a.Index != b.Index { + aKey, bKey = newComparableKey(a.Index), newComparableKey(b.Index) + } else { + aKey, bKey = a.CompKey, b.CompKey + } + } else { + // If the indices are of different types, use the type ordering of Firebase. + aKey, bKey = newComparableKey(a.IndexType), newComparableKey(b.IndexType) + } + + return aKey.Compare(bKey) < 0 +} + +func newSortableNodes(values interface{}, order orderBy) sortableNodes { + var entries sortableNodes + if m, ok := values.(map[string]interface{}); ok { + for key, val := range m { + entries = append(entries, newQueryNode(key, val, order)) + } + } else if l, ok := values.([]interface{}); ok { + for key, val := range l { + entries = append(entries, newQueryNode(key, val, order)) + } + } else { + entries = append(entries, newQueryNode(0, values, order)) + } + return entries +} + +// extractChildValue retrieves the value at path from val. +// +// If the given path does not exist in val, or val does not support child path traversal, +// extractChildValue returns nil. +func extractChildValue(val interface{}, path string) interface{} { + segments := parsePath(path) + curr := val + for _, s := range segments { + if curr == nil { + return nil + } + + currMap, ok := curr.(map[string]interface{}) + if !ok { + return nil + } + if curr, ok = currMap[s]; !ok { + return nil + } + } + return curr +} + +func getIndexType(index interface{}) int { + if index == nil { + return typeNull + } else if b, ok := index.(bool); ok { + if b { + return typeBoolTrue + } + return typeBoolFalse + } else if _, ok := index.(float64); ok { + return typeNumeric + } else if _, ok := index.(string); ok { + return typeString + } + return typeObject +} diff --git a/db/query_test.go b/db/query_test.go new file mode 100644 index 00000000..4473daff --- /dev/null +++ b/db/query_test.go @@ -0,0 +1,774 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package db + +import ( + "fmt" + "reflect" + "testing" + + "golang.org/x/net/context" +) + +var sortableKeysResp = map[string]interface{}{ + "bob": person{Name: "bob", Age: 20}, + "alice": person{Name: "alice", Age: 30}, + "charlie": person{Name: "charlie", Age: 15}, + "dave": person{Name: "dave", Age: 25}, + "ernie": person{Name: "ernie"}, +} + +var sortableValuesResp = []struct { + resp map[string]interface{} + want []interface{} + wantKeys []string +}{ + { + resp: map[string]interface{}{"k1": 1, "k2": 2, "k3": 3}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"k1", "k2", "k3"}, + }, + { + resp: map[string]interface{}{"k1": 3, "k2": 2, "k3": 1}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"k3", "k2", "k1"}, + }, + { + resp: map[string]interface{}{"k1": 3, "k2": 1, "k3": 2}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"k2", "k3", "k1"}, + }, + { + resp: map[string]interface{}{"k1": 1, "k2": 2, "k3": 1}, + want: []interface{}{1.0, 1.0, 2.0}, + wantKeys: []string{"k1", "k3", "k2"}, + }, + { + resp: map[string]interface{}{"k1": 1, "k2": 1, "k3": 2}, + want: []interface{}{1.0, 1.0, 2.0}, + wantKeys: []string{"k1", "k2", "k3"}, + }, + { + resp: map[string]interface{}{"k1": 2, "k2": 1, "k3": 1}, + want: []interface{}{1.0, 1.0, 2.0}, + wantKeys: []string{"k2", "k3", "k1"}, + }, + { + resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": "baz"}, + want: []interface{}{"bar", "baz", "foo"}, + wantKeys: []string{"k2", "k3", "k1"}, + }, + { + resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": 10}, + want: []interface{}{10.0, "bar", "foo"}, + wantKeys: []string{"k3", "k2", "k1"}, + }, + { + resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": nil}, + want: []interface{}{nil, "bar", "foo"}, + wantKeys: []string{"k3", "k2", "k1"}, + }, + { + resp: map[string]interface{}{"k1": 5, "k2": "bar", "k3": nil}, + want: []interface{}{nil, 5.0, "bar"}, + wantKeys: []string{"k3", "k1", "k2"}, + }, + { + resp: map[string]interface{}{ + "k1": true, "k2": 0, "k3": "foo", "k4": "foo", "k5": false, + "k6": map[string]interface{}{"k1": true}, + }, + want: []interface{}{false, true, 0.0, "foo", "foo", map[string]interface{}{"k1": true}}, + wantKeys: []string{"k5", "k1", "k2", "k3", "k4", "k6"}, + }, + { + resp: map[string]interface{}{ + "k1": true, "k2": 0, "k3": "foo", "k4": "foo", "k5": false, + "k6": map[string]interface{}{"k1": true}, "k7": nil, + "k8": map[string]interface{}{"k0": true}, + }, + want: []interface{}{ + nil, false, true, 0.0, "foo", "foo", + map[string]interface{}{"k1": true}, map[string]interface{}{"k0": true}, + }, + wantKeys: []string{"k7", "k5", "k1", "k2", "k3", "k4", "k6", "k8"}, + }, +} + +func TestChildQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + cases := []string{ + "messages", "messages/", "/messages", + } + var reqs []*testReq + for _, tc := range cases { + var got map[string]interface{} + if err := testref.OrderByChild(tc).Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("OrderByChild(%q) = %v; want = %v", tc, got, want) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"messages\""}, + }) + } + + checkAllRequests(t, mock.Reqs, reqs) +} + +func TestNestedChildQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := testref.OrderByChild("messages/ratings").Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("OrderByChild(%q) = %v; want = %v", "messages/ratings", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"messages/ratings\""}, + }) +} + +func TestChildQueryWithParams(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q := testref.OrderByChild("messages").StartAt("m4").EndAt("m50").LimitToFirst(10) + var got map[string]interface{} + if err := q.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("OrderByChild() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{ + "orderBy": "\"messages\"", + "startAt": "\"m4\"", + "endAt": "\"m50\"", + "limitToFirst": "10", + }, + }) +} + +func TestInvalidOrderByChild(t *testing.T) { + mock := &mockServer{Resp: "test"} + srv := mock.Start(client) + defer srv.Close() + + r := client.NewRef("/") + cases := []string{ + "", "/", "foo$", "foo.", "foo#", "foo]", + "foo[", "$key", "$value", "$priority", + } + for _, tc := range cases { + var got string + if err := r.OrderByChild(tc).Get(context.Background(), &got); got != "" || err == nil { + t.Errorf("OrderByChild(%q) = (%q, %v); want = (%q, error)", tc, got, err, "") + } + } + if len(mock.Reqs) != 0 { + t.Errorf("OrderByChild() = %v; want = empty", mock.Reqs) + } +} + +func TestKeyQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := testref.OrderByKey().Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("OrderByKey() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"$key\""}, + }) +} + +func TestValueQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := testref.OrderByValue().Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("OrderByValue() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"$value\""}, + }) +} + +func TestLimitFirstQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := testref.OrderByChild("messages").LimitToFirst(10).Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("LimitToFirst() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"limitToFirst": "10", "orderBy": "\"messages\""}, + }) +} + +func TestLimitLastQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := testref.OrderByChild("messages").LimitToLast(10).Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"limitToLast": "10", "orderBy": "\"messages\""}, + }) +} + +func TestInvalidLimitQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q := testref.OrderByChild("messages") + cases := []struct { + name string + q *Query + }{ + {"BothLimits", q.LimitToFirst(10).LimitToLast(10)}, + {"NegativeFirst", q.LimitToFirst(-10)}, + {"NegativeLast", q.LimitToLast(-10)}, + } + for _, tc := range cases { + var got map[string]interface{} + if err := tc.q.Get(context.Background(), &got); got != nil || err == nil { + t.Errorf("OrderByChild(%q) = (%v, %v); want = (nil, error)", tc.name, got, err) + } + if len(mock.Reqs) != 0 { + t.Errorf("OrderByChild(%q): %v; want: empty", tc.name, mock.Reqs) + } + } +} + +func TestStartAtQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := testref.OrderByChild("messages").StartAt(10).Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("StartAt() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"startAt": "10", "orderBy": "\"messages\""}, + }) +} + +func TestEndAtQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := testref.OrderByChild("messages").EndAt(10).Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("EndAt() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"endAt": "10", "orderBy": "\"messages\""}, + }) +} + +func TestEqualToQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := testref.OrderByChild("messages").EqualTo(10).Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("EqualTo() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"equalTo": "10", "orderBy": "\"messages\""}, + }) +} + +func TestInvalidFilterQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q := testref.OrderByChild("messages") + cases := []struct { + name string + q *Query + }{ + {"InvalidStartAt", q.StartAt(func() {})}, + {"InvalidEndAt", q.EndAt(func() {})}, + {"InvalidEqualTo", q.EqualTo(func() {})}, + } + for _, tc := range cases { + var got map[string]interface{} + if err := tc.q.Get(context.Background(), &got); got != nil || err == nil { + t.Errorf("OrderByChild(%q) = (%v, %v); want = (nil, error)", tc.name, got, err) + } + if len(mock.Reqs) != 0 { + t.Errorf("OrdderByChild(%q) = %v; want = empty", tc.name, mock.Reqs) + } + } +} + +func TestAllParamsQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q := testref.OrderByChild("messages").LimitToFirst(100).StartAt("bar").EndAt("foo") + var got map[string]interface{} + if err := q.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("OrderByChild(AllParams) = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{ + "limitToFirst": "100", + "startAt": "\"bar\"", + "endAt": "\"foo\"", + "orderBy": "\"messages\"", + }, + }) +} + +func TestChildQueryGetOrdered(t *testing.T) { + mock := &mockServer{Resp: sortableKeysResp} + srv := mock.Start(client) + defer srv.Close() + + cases := []struct { + child string + want []string + }{ + {"name", []string{"alice", "bob", "charlie", "dave", "ernie"}}, + {"age", []string{"ernie", "charlie", "bob", "dave", "alice"}}, + {"nonexisting", []string{"alice", "bob", "charlie", "dave", "ernie"}}, + } + + var reqs []*testReq + for idx, tc := range cases { + result, err := testref.OrderByChild(tc.child).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": fmt.Sprintf("%q", tc.child)}, + }) + + var gotKeys, gotVals []string + for _, r := range result { + var p person + if err := r.Unmarshal(&p); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, p.Name) + } + if !reflect.DeepEqual(tc.want, gotKeys) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, tc.child, gotKeys, tc.want) + } + if !reflect.DeepEqual(tc.want, gotVals) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, tc.child, gotVals, tc.want) + } + } + checkAllRequests(t, mock.Reqs, reqs) +} + +func TestImmediateChildQueryGetOrdered(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + type parsedMap struct { + Child interface{} `json:"child"` + } + + var reqs []*testReq + for idx, tc := range sortableValuesResp { + resp := map[string]interface{}{} + for k, v := range tc.resp { + resp[k] = map[string]interface{}{"child": v} + } + mock.Resp = resp + + result, err := testref.OrderByChild("child").GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"child\""}, + }) + + var gotKeys []string + var gotVals []interface{} + for _, r := range result { + var p parsedMap + if err := r.Unmarshal(&p); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, p.Child) + } + if !reflect.DeepEqual(tc.wantKeys, gotKeys) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, "child", gotKeys, tc.wantKeys) + } + if !reflect.DeepEqual(tc.want, gotVals) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, "child", gotVals, tc.want) + } + } + checkAllRequests(t, mock.Reqs, reqs) +} + +func TestNestedChildQueryGetOrdered(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + type grandChild struct { + GrandChild interface{} `json:"grandchild"` + } + type parsedMap struct { + Child grandChild `json:"child"` + } + + var reqs []*testReq + for idx, tc := range sortableValuesResp { + resp := map[string]interface{}{} + for k, v := range tc.resp { + resp[k] = map[string]interface{}{"child": map[string]interface{}{"grandchild": v}} + } + mock.Resp = resp + + q := testref.OrderByChild("child/grandchild") + result, err := q.GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"child/grandchild\""}, + }) + + var gotKeys []string + var gotVals []interface{} + for _, r := range result { + var p parsedMap + if err := r.Unmarshal(&p); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, p.Child.GrandChild) + } + if !reflect.DeepEqual(tc.wantKeys, gotKeys) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, "child/grandchild", gotKeys, tc.wantKeys) + } + if !reflect.DeepEqual(tc.want, gotVals) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, "child/grandchild", gotVals, tc.want) + } + } + checkAllRequests(t, mock.Reqs, reqs) +} + +func TestKeyQueryGetOrdered(t *testing.T) { + mock := &mockServer{Resp: sortableKeysResp} + srv := mock.Start(client) + defer srv.Close() + + result, err := testref.OrderByKey().GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + req := &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"$key\""}, + } + + var gotKeys, gotVals []string + for _, r := range result { + var p person + if err := r.Unmarshal(&p); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, p.Name) + } + + want := []string{"alice", "bob", "charlie", "dave", "ernie"} + if !reflect.DeepEqual(want, gotKeys) { + t.Errorf("GetOrdered(key) = %v; want = %v", gotKeys, want) + } + if !reflect.DeepEqual(want, gotVals) { + t.Errorf("GetOrdered(key) = %v; want = %v", gotVals, want) + } + checkOnlyRequest(t, mock.Reqs, req) +} + +func TestValueQueryGetOrdered(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + var reqs []*testReq + for idx, tc := range sortableValuesResp { + mock.Resp = tc.resp + + result, err := testref.OrderByValue().GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"$value\""}, + }) + + var gotKeys []string + var gotVals []interface{} + for _, r := range result { + var v interface{} + if err := r.Unmarshal(&v); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, v) + } + + if !reflect.DeepEqual(tc.wantKeys, gotKeys) { + t.Errorf("[%d] GetOrdered(value) = %v; want = %v", idx, gotKeys, tc.wantKeys) + } + if !reflect.DeepEqual(tc.want, gotVals) { + t.Errorf("[%d] GetOrdered(value) = %v; want = %v", idx, gotVals, tc.want) + } + } + checkAllRequests(t, mock.Reqs, reqs) +} + +func TestValueQueryGetOrderedWithList(t *testing.T) { + cases := []struct { + resp []interface{} + want []interface{} + wantKeys []string + }{ + { + resp: []interface{}{1, 2, 3}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"0", "1", "2"}, + }, + { + resp: []interface{}{3, 2, 1}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"2", "1", "0"}, + }, + { + resp: []interface{}{1, 3, 2}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"0", "2", "1"}, + }, + { + resp: []interface{}{1, 3, 3}, + want: []interface{}{1.0, 3.0, 3.0}, + wantKeys: []string{"0", "1", "2"}, + }, + { + resp: []interface{}{1, 2, 1}, + want: []interface{}{1.0, 1.0, 2.0}, + wantKeys: []string{"0", "2", "1"}, + }, + { + resp: []interface{}{"foo", "bar", "baz"}, + want: []interface{}{"bar", "baz", "foo"}, + wantKeys: []string{"1", "2", "0"}, + }, + { + resp: []interface{}{"foo", 1, false, nil, 0, true}, + want: []interface{}{nil, false, true, 0.0, 1.0, "foo"}, + wantKeys: []string{"3", "2", "5", "4", "1", "0"}, + }, + } + + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + var reqs []*testReq + for _, tc := range cases { + mock.Resp = tc.resp + + result, err := testref.OrderByValue().GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"$value\""}, + }) + + var gotKeys []string + var gotVals []interface{} + for _, r := range result { + var v interface{} + if err := r.Unmarshal(&v); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, v) + } + + if !reflect.DeepEqual(tc.wantKeys, gotKeys) { + t.Errorf("GetOrdered(value) = %v; want = %v", gotKeys, tc.wantKeys) + } + if !reflect.DeepEqual(tc.want, gotVals) { + t.Errorf("GetOrdered(value) = %v; want = %v", gotVals, tc.want) + } + } + checkAllRequests(t, mock.Reqs, reqs) +} + +func TestGetOrderedWithNilResult(t *testing.T) { + mock := &mockServer{Resp: nil} + srv := mock.Start(client) + defer srv.Close() + + result, err := testref.OrderByChild("child").GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + if result != nil { + t.Errorf("GetOrdered(value) = %v; want = nil", result) + } +} + +func TestGetOrderedWithLeafNode(t *testing.T) { + mock := &mockServer{Resp: "foo"} + srv := mock.Start(client) + defer srv.Close() + + result, err := testref.OrderByChild("child").GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(result) != 1 { + t.Fatalf("GetOrdered(chid) = %d; want = 1", len(result)) + } + if result[0].Key() != "0" { + t.Errorf("GetOrdered(value).Key() = %v; want = %q", result[0].Key(), 0) + } + + var v interface{} + if err := result[0].Unmarshal(&v); err != nil { + t.Fatal(err) + } + if v != "foo" { + t.Errorf("GetOrdered(value) = %v; want = %v", v, "foo") + } +} + +func TestQueryHttpError(t *testing.T) { + mock := &mockServer{Resp: map[string]string{"error": "test error"}, Status: 500} + srv := mock.Start(client) + defer srv.Close() + + want := "http error status: 500; reason: test error" + result, err := testref.OrderByChild("child").GetOrdered(context.Background()) + if err == nil || err.Error() != want { + t.Errorf("GetOrdered() = %v; want = %v", err, want) + } + if result != nil { + t.Errorf("GetOrdered() = %v; want = nil", result) + } +} diff --git a/db/ref.go b/db/ref.go new file mode 100644 index 00000000..8fbadf84 --- /dev/null +++ b/db/ref.go @@ -0,0 +1,262 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "firebase.google.com/go/internal" + + "golang.org/x/net/context" +) + +// txnRetires is the maximum number of times a transaction is retried before giving up. Transaction +// retries are triggered by concurrent conflicting updates to the same database location. +const txnRetries = 25 + +// Ref represents a node in the Firebase Realtime Database. +type Ref struct { + Key string + Path string + + segs []string + client *Client +} + +// TransactionNode represents the value of a node within the scope of a transaction. +type TransactionNode interface { + Unmarshal(v interface{}) error +} + +type transactionNodeImpl struct { + Raw []byte +} + +func (t *transactionNodeImpl) Unmarshal(v interface{}) error { + return json.Unmarshal(t.Raw, v) +} + +// Parent returns a reference to the parent of the current node. +// +// If the current reference points to the root of the database, Parent returns nil. +func (r *Ref) Parent() *Ref { + l := len(r.segs) + if l > 0 { + path := strings.Join(r.segs[:l-1], "/") + return r.client.NewRef(path) + } + return nil +} + +// Child returns a reference to the specified child node. +func (r *Ref) Child(path string) *Ref { + fp := fmt.Sprintf("%s/%s", r.Path, path) + return r.client.NewRef(fp) +} + +// Get retrieves the value at the current database location, and stores it in the value pointed to +// by v. +// +// Data deserialization is performed using https://golang.org/pkg/encoding/json/#Unmarshal, and +// therefore v has the same requirements as the json package. Specifically, it must be a pointer, +// and must not be nil. +func (r *Ref) Get(ctx context.Context, v interface{}) error { + resp, err := r.send(ctx, "GET") + if err != nil { + return err + } + return resp.Unmarshal(http.StatusOK, v) +} + +// GetWithETag retrieves the value at the current database location, along with its ETag. +func (r *Ref) GetWithETag(ctx context.Context, v interface{}) (string, error) { + resp, err := r.send(ctx, "GET", internal.WithHeader("X-Firebase-ETag", "true")) + if err != nil { + return "", err + } else if err := resp.Unmarshal(http.StatusOK, v); err != nil { + return "", err + } + return resp.Header.Get("Etag"), nil +} + +// GetShallow performs a shallow read on the current database location. +// +// Shallow reads do not retrieve the child nodes of the current reference. +func (r *Ref) GetShallow(ctx context.Context, v interface{}) error { + resp, err := r.send(ctx, "GET", internal.WithQueryParam("shallow", "true")) + if err != nil { + return err + } + return resp.Unmarshal(http.StatusOK, v) +} + +// GetIfChanged retrieves the value and ETag of the current database location only if the specified +// ETag does not match. +// +// If the specified ETag does not match, returns true along with the latest ETag of the database +// location. The value of the database location will be stored in v just like a regular Get() call. +// If the etag matches, returns false along with the same ETag passed into the function. No data +// will be stored in v in this case. +func (r *Ref) GetIfChanged(ctx context.Context, etag string, v interface{}) (bool, string, error) { + resp, err := r.send(ctx, "GET", internal.WithHeader("If-None-Match", etag)) + if err != nil { + return false, "", err + } + if resp.Status == http.StatusNotModified { + return false, etag, nil + } + if err := resp.Unmarshal(http.StatusOK, v); err != nil { + return false, "", err + } + return true, resp.Header.Get("ETag"), nil +} + +// Set stores the value v in the current database node. +// +// Set uses https://golang.org/pkg/encoding/json/#Marshal to serialize values into JSON. Therefore +// v has the same requirements as the json package. Values like functions and channels cannot be +// saved into Realtime Database. +func (r *Ref) Set(ctx context.Context, v interface{}) error { + resp, err := r.sendWithBody(ctx, "PUT", v, internal.WithQueryParam("print", "silent")) + if err != nil { + return err + } + return resp.CheckStatus(http.StatusNoContent) +} + +// SetIfUnchanged conditionally sets the data at this location to the given value. +// +// Sets the data at this location to v only if the specified ETag matches. Returns true if the +// value is written. Returns false if no changes are made to the database. +func (r *Ref) SetIfUnchanged(ctx context.Context, etag string, v interface{}) (bool, error) { + resp, err := r.sendWithBody(ctx, "PUT", v, internal.WithHeader("If-Match", etag)) + if err != nil { + return false, err + } + if resp.Status == http.StatusPreconditionFailed { + return false, nil + } + if err := resp.CheckStatus(http.StatusOK); err != nil { + return false, err + } + return true, nil +} + +// Push creates a new child node at the current location, and returns a reference to it. +// +// If v is not nil, it will be set as the initial value of the new child node. If v is nil, the +// new child node will be created with empty string as the value. +func (r *Ref) Push(ctx context.Context, v interface{}) (*Ref, error) { + if v == nil { + v = "" + } + resp, err := r.sendWithBody(ctx, "POST", v) + if err != nil { + return nil, err + } + var d struct { + Name string `json:"name"` + } + if err := resp.Unmarshal(http.StatusOK, &d); err != nil { + return nil, err + } + return r.Child(d.Name), nil +} + +// Update modifies the specified child keys of the current location to the provided values. +func (r *Ref) Update(ctx context.Context, v map[string]interface{}) error { + if len(v) == 0 { + return fmt.Errorf("value argument must be a non-empty map") + } + resp, err := r.sendWithBody(ctx, "PATCH", v, internal.WithQueryParam("print", "silent")) + if err != nil { + return err + } + return resp.CheckStatus(http.StatusNoContent) +} + +// UpdateFn represents a function type that can be passed into Transaction(). +type UpdateFn func(TransactionNode) (interface{}, error) + +// Transaction atomically modifies the data at this location. +// +// Unlike a normal Set(), which just overwrites the data regardless of its previous state, +// Transaction() is used to modify the existing value to a new value, ensuring there are no +// conflicts with other clients simultaneously writing to the same location. +// +// This is accomplished by passing an update function which is used to transform the current value +// of this reference into a new value. If another client writes to this location before the new +// value is successfully saved, the update function is called again with the new current value, and +// the write will be retried. In case of repeated failures, this method will retry the transaction up +// to 25 times before giving up and returning an error. +// +// The update function may also force an early abort by returning an error instead of returning a +// value. +func (r *Ref) Transaction(ctx context.Context, fn UpdateFn) error { + resp, err := r.send(ctx, "GET", internal.WithHeader("X-Firebase-ETag", "true")) + if err != nil { + return err + } else if err := resp.CheckStatus(http.StatusOK); err != nil { + return err + } + etag := resp.Header.Get("Etag") + + for i := 0; i < txnRetries; i++ { + new, err := fn(&transactionNodeImpl{resp.Body}) + if err != nil { + return err + } + resp, err = r.sendWithBody(ctx, "PUT", new, internal.WithHeader("If-Match", etag)) + if err != nil { + return err + } + if resp.Status == http.StatusOK { + return nil + } else if err := resp.CheckStatus(http.StatusPreconditionFailed); err != nil { + return err + } + etag = resp.Header.Get("ETag") + } + return fmt.Errorf("transaction aborted after failed retries") +} + +// Delete removes this node from the database. +func (r *Ref) Delete(ctx context.Context) error { + resp, err := r.send(ctx, "DELETE") + if err != nil { + return err + } + return resp.CheckStatus(http.StatusOK) +} + +func (r *Ref) send( + ctx context.Context, + method string, + opts ...internal.HTTPOption) (*internal.Response, error) { + + return r.client.send(ctx, method, r.Path, nil, opts...) +} + +func (r *Ref) sendWithBody( + ctx context.Context, + method string, + body interface{}, + opts ...internal.HTTPOption) (*internal.Response, error) { + + return r.client.send(ctx, method, r.Path, internal.NewJSONEntity(body), opts...) +} diff --git a/db/ref_test.go b/db/ref_test.go new file mode 100644 index 00000000..93e348d0 --- /dev/null +++ b/db/ref_test.go @@ -0,0 +1,729 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "fmt" + "net/http" + "reflect" + "testing" + + "golang.org/x/net/context" +) + +type refOp func(r *Ref) error + +var testOps = []struct { + name string + resp interface{} + op refOp +}{ + { + "Get()", + "test", + func(r *Ref) error { + var got string + return r.Get(context.Background(), &got) + }, + }, + { + "GetWithETag()", + "test", + func(r *Ref) error { + var got string + _, err := r.GetWithETag(context.Background(), &got) + return err + }, + }, + { + "GetShallow()", + "test", + func(r *Ref) error { + var got string + return r.GetShallow(context.Background(), &got) + }, + }, + { + "GetIfChanged()", + "test", + func(r *Ref) error { + var got string + _, _, err := r.GetIfChanged(context.Background(), "etag", &got) + return err + }, + }, + { + "Set()", + nil, + func(r *Ref) error { + return r.Set(context.Background(), "foo") + }, + }, + { + "SetIfUnchanged()", + nil, + func(r *Ref) error { + _, err := r.SetIfUnchanged(context.Background(), "etag", "foo") + return err + }, + }, + { + "Push()", + map[string]interface{}{"name": "test"}, + func(r *Ref) error { + _, err := r.Push(context.Background(), "foo") + return err + }, + }, + { + "Update()", + nil, + func(r *Ref) error { + return r.Update(context.Background(), map[string]interface{}{"foo": "bar"}) + }, + }, + { + "Delete()", + nil, + func(r *Ref) error { + return r.Delete(context.Background()) + }, + }, + { + "Transaction()", + nil, + func(r *Ref) error { + fn := func(t TransactionNode) (interface{}, error) { + var v interface{} + if err := t.Unmarshal(&v); err != nil { + return nil, err + } + return v, nil + } + return r.Transaction(context.Background(), fn) + }, + }, +} + +func TestGet(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + cases := []interface{}{ + nil, float64(1), true, "foo", + map[string]interface{}{"name": "Peter Parker", "age": float64(17)}, + } + var want []*testReq + for _, tc := range cases { + mock.Resp = tc + var got interface{} + if err := testref.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tc, got) { + t.Errorf("Get() = %v; want = %v", got, tc) + } + want = append(want, &testReq{Method: "GET", Path: "/peter.json"}) + } + checkAllRequests(t, mock.Reqs, want) +} + +func TestInvalidGet(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + got := func() {} + if err := testref.Get(context.Background(), &got); err == nil { + t.Errorf("Get(func) = nil; want error") + } + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) +} + +func TestGetWithStruct(t *testing.T) { + want := person{Name: "Peter Parker", Age: 17} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got person + if err := testref.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if want != got { + t.Errorf("Get(struct) = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) +} + +func TestGetShallow(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + cases := []interface{}{ + nil, float64(1), true, "foo", + map[string]interface{}{"name": "Peter Parker", "age": float64(17)}, + map[string]interface{}{"name": "Peter Parker", "nestedChild": true}, + } + wantQuery := map[string]string{"shallow": "true"} + var want []*testReq + for _, tc := range cases { + mock.Resp = tc + var got interface{} + if err := testref.GetShallow(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tc, got) { + t.Errorf("GetShallow() = %v; want = %v", got, tc) + } + want = append(want, &testReq{Method: "GET", Path: "/peter.json", Query: wantQuery}) + } + checkAllRequests(t, mock.Reqs, want) +} + +func TestGetWithETag(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{ + Resp: want, + Header: map[string]string{"ETag": "mock-etag"}, + } + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + etag, err := testref.GetWithETag(context.Background(), &got) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("GetWithETag() = %v; want = %v", got, want) + } + if etag != "mock-etag" { + t.Errorf("GetWithETag() = %q; want = %q", etag, "mock-etag") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }) +} + +func TestGetIfChanged(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{ + Resp: want, + Header: map[string]string{"ETag": "new-etag"}, + } + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + ok, etag, err := testref.GetIfChanged(context.Background(), "old-etag", &got) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Errorf("GetIfChanged() = %v; want = %v", ok, true) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("GetIfChanged() = %v; want = %v", got, want) + } + if etag != "new-etag" { + t.Errorf("GetIfChanged() = %q; want = %q", etag, "new-etag") + } + + mock.Status = http.StatusNotModified + mock.Resp = nil + var got2 map[string]interface{} + ok, etag, err = testref.GetIfChanged(context.Background(), "new-etag", &got2) + if err != nil { + t.Fatal(err) + } + if ok { + t.Errorf("GetIfChanged() = %v; want = %v", ok, false) + } + if got2 != nil { + t.Errorf("GetIfChanged() = %v; want nil", got2) + } + if etag != "new-etag" { + t.Errorf("GetIfChanged() = %q; want = %q", etag, "new-etag") + } + + checkAllRequests(t, mock.Reqs, []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"If-None-Match": []string{"old-etag"}}, + }, + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"If-None-Match": []string{"new-etag"}}, + }, + }) +} + +func TestWelformedHttpError(t *testing.T) { + mock := &mockServer{Resp: map[string]string{"error": "test error"}, Status: 500} + srv := mock.Start(client) + defer srv.Close() + + want := "http error status: 500; reason: test error" + for _, tc := range testOps { + err := tc.op(testref) + if err == nil || err.Error() != want { + t.Errorf("%s = %v; want = %v", tc.name, err, want) + } + } + + if len(mock.Reqs) != len(testOps) { + t.Errorf("Requests = %d; want = %d", len(mock.Reqs), len(testOps)) + } +} + +func TestUnexpectedHttpError(t *testing.T) { + mock := &mockServer{Resp: "unexpected error", Status: 500} + srv := mock.Start(client) + defer srv.Close() + + want := "http error status: 500; reason: \"unexpected error\"" + for _, tc := range testOps { + err := tc.op(testref) + if err == nil || err.Error() != want { + t.Errorf("%s = %v; want = %v", tc.name, err, want) + } + } + + if len(mock.Reqs) != len(testOps) { + t.Errorf("Requests = %d; want = %d", len(mock.Reqs), len(testOps)) + } +} + +func TestInvalidPath(t *testing.T) { + mock := &mockServer{Resp: "test"} + srv := mock.Start(client) + defer srv.Close() + + cases := []string{ + "foo$", "foo.", "foo#", "foo]", "foo[", + } + for _, tc := range cases { + r := client.NewRef(tc) + for _, o := range testOps { + err := o.op(r) + if err == nil { + t.Errorf("%s = nil; want = error", o.name) + } + } + } + + if len(mock.Reqs) != 0 { + t.Errorf("Requests = %v; want = empty", mock.Reqs) + } +} + +func TestInvalidChildPath(t *testing.T) { + mock := &mockServer{Resp: "test"} + srv := mock.Start(client) + defer srv.Close() + + cases := []string{ + "foo$", "foo.", "foo#", "foo]", "foo[", + } + for _, tc := range cases { + r := testref.Child(tc) + for _, o := range testOps { + err := o.op(r) + if err == nil { + t.Errorf("%s = nil; want = error", o.name) + } + } + } + + if len(mock.Reqs) != 0 { + t.Errorf("Requests = %v; want = empty", mock.Reqs) + } +} + +func TestSet(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + cases := []interface{}{ + 1, + true, + "foo", + map[string]interface{}{"name": "Peter Parker", "age": float64(17)}, + &person{"Peter Parker", 17}, + } + var want []*testReq + for _, tc := range cases { + if err := testref.Set(context.Background(), tc); err != nil { + t.Fatal(err) + } + want = append(want, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(tc), + Query: map[string]string{"print": "silent"}, + }) + } + checkAllRequests(t, mock.Reqs, want) +} + +func TestInvalidSet(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + cases := []interface{}{ + func() {}, + make(chan int), + } + for _, tc := range cases { + if err := testref.Set(context.Background(), tc); err == nil { + t.Errorf("Set(%v) = nil; want = error", tc) + } + } + if len(mock.Reqs) != 0 { + t.Errorf("Set() = %v; want = empty", mock.Reqs) + } +} + +func TestSetIfUnchanged(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + want := &person{"Peter Parker", 17} + ok, err := testref.SetIfUnchanged(context.Background(), "mock-etag", &want) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Errorf("SetIfUnchanged() = %v; want = %v", ok, true) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(want), + Header: http.Header{"If-Match": []string{"mock-etag"}}, + }) +} + +func TestSetIfUnchangedError(t *testing.T) { + mock := &mockServer{ + Status: http.StatusPreconditionFailed, + Resp: &person{"Tony Stark", 39}, + } + srv := mock.Start(client) + defer srv.Close() + + want := &person{"Peter Parker", 17} + ok, err := testref.SetIfUnchanged(context.Background(), "mock-etag", &want) + if err != nil { + t.Fatal(err) + } + if ok { + t.Errorf("SetIfUnchanged() = %v; want = %v", ok, false) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(want), + Header: http.Header{"If-Match": []string{"mock-etag"}}, + }) +} + +func TestPush(t *testing.T) { + mock := &mockServer{Resp: map[string]string{"name": "new_key"}} + srv := mock.Start(client) + defer srv.Close() + + child, err := testref.Push(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + if child.Key != "new_key" { + t.Errorf("Push() = %q; want = %q", child.Key, "new_key") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "POST", + Path: "/peter.json", + Body: serialize(""), + }) +} + +func TestPushWithValue(t *testing.T) { + mock := &mockServer{Resp: map[string]string{"name": "new_key"}} + srv := mock.Start(client) + defer srv.Close() + + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + child, err := testref.Push(context.Background(), want) + if err != nil { + t.Fatal(err) + } + + if child.Key != "new_key" { + t.Errorf("Push() = %q; want = %q", child.Key, "new_key") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "POST", + Path: "/peter.json", + Body: serialize(want), + }) +} + +func TestUpdate(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + if err := testref.Update(context.Background(), want); err != nil { + t.Fatal(err) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PATCH", + Path: "/peter.json", + Body: serialize(want), + Query: map[string]string{"print": "silent"}, + }) +} + +func TestInvalidUpdate(t *testing.T) { + cases := []map[string]interface{}{ + nil, + make(map[string]interface{}), + map[string]interface{}{"foo": func() {}}, + } + for _, tc := range cases { + if err := testref.Update(context.Background(), tc); err == nil { + t.Errorf("Update(%v) = nil; want error", tc) + } + } +} + +func TestTransaction(t *testing.T) { + mock := &mockServer{ + Resp: &person{"Peter Parker", 17}, + Header: map[string]string{"ETag": "mock-etag"}, + } + srv := mock.Start(client) + defer srv.Close() + + var fn UpdateFn = func(t TransactionNode) (interface{}, error) { + var p person + if err := t.Unmarshal(&p); err != nil { + return nil, err + } + p.Age++ + return &p, nil + } + if err := testref.Transaction(context.Background(), fn); err != nil { + t.Fatal(err) + } + checkAllRequests(t, mock.Reqs, []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }, + &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 18, + }), + Header: http.Header{"If-Match": []string{"mock-etag"}}, + }, + }) +} + +func TestTransactionRetry(t *testing.T) { + mock := &mockServer{ + Resp: &person{"Peter Parker", 17}, + Header: map[string]string{"ETag": "mock-etag1"}, + } + srv := mock.Start(client) + defer srv.Close() + + cnt := 0 + var fn UpdateFn = func(t TransactionNode) (interface{}, error) { + if cnt == 0 { + mock.Status = http.StatusPreconditionFailed + mock.Header = map[string]string{"ETag": "mock-etag2"} + mock.Resp = &person{"Peter Parker", 19} + } else if cnt == 1 { + mock.Status = http.StatusOK + } + cnt++ + var p person + if err := t.Unmarshal(&p); err != nil { + return nil, err + } + p.Age++ + return &p, nil + } + if err := testref.Transaction(context.Background(), fn); err != nil { + t.Fatal(err) + } + if cnt != 2 { + t.Errorf("Transaction() retries = %d; want = %d", cnt, 2) + } + checkAllRequests(t, mock.Reqs, []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }, + &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 18, + }), + Header: http.Header{"If-Match": []string{"mock-etag1"}}, + }, + &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 20, + }), + Header: http.Header{"If-Match": []string{"mock-etag2"}}, + }, + }) +} + +func TestTransactionError(t *testing.T) { + mock := &mockServer{ + Resp: &person{"Peter Parker", 17}, + Header: map[string]string{"ETag": "mock-etag1"}, + } + srv := mock.Start(client) + defer srv.Close() + + cnt := 0 + want := "user error" + var fn UpdateFn = func(t TransactionNode) (interface{}, error) { + if cnt == 0 { + mock.Status = http.StatusPreconditionFailed + mock.Header = map[string]string{"ETag": "mock-etag2"} + mock.Resp = &person{"Peter Parker", 19} + } else if cnt == 1 { + return nil, fmt.Errorf(want) + } + cnt++ + var p person + if err := t.Unmarshal(&p); err != nil { + return nil, err + } + p.Age++ + return &p, nil + } + if err := testref.Transaction(context.Background(), fn); err == nil || err.Error() != want { + t.Errorf("Transaction() = %v; want = %q", err, want) + } + if cnt != 1 { + t.Errorf("Transaction() retries = %d; want = %d", cnt, 1) + } + checkAllRequests(t, mock.Reqs, []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }, + &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 18, + }), + Header: http.Header{"If-Match": []string{"mock-etag1"}}, + }, + }) +} + +func TestTransactionAbort(t *testing.T) { + mock := &mockServer{ + Resp: &person{"Peter Parker", 17}, + Header: map[string]string{"ETag": "mock-etag1"}, + } + srv := mock.Start(client) + defer srv.Close() + + cnt := 0 + var fn UpdateFn = func(t TransactionNode) (interface{}, error) { + if cnt == 0 { + mock.Status = http.StatusPreconditionFailed + mock.Header = map[string]string{"ETag": "mock-etag1"} + } + cnt++ + var p person + if err := t.Unmarshal(&p); err != nil { + return nil, err + } + p.Age++ + return &p, nil + } + err := testref.Transaction(context.Background(), fn) + if err == nil { + t.Errorf("Transaction() = nil; want error") + } + wanted := []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }, + } + for i := 0; i < txnRetries; i++ { + wanted = append(wanted, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 18, + }), + Header: http.Header{"If-Match": []string{"mock-etag1"}}, + }) + } + checkAllRequests(t, mock.Reqs, wanted) +} + +func TestDelete(t *testing.T) { + mock := &mockServer{Resp: "null"} + srv := mock.Start(client) + defer srv.Close() + + if err := testref.Delete(context.Background()); err != nil { + t.Fatal(err) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "DELETE", + Path: "/peter.json", + }) +} diff --git a/firebase.go b/firebase.go index b533c12d..0c6718da 100644 --- a/firebase.go +++ b/firebase.go @@ -27,6 +27,7 @@ import ( "cloud.google.com/go/firestore" "firebase.google.com/go/auth" + "firebase.google.com/go/db" "firebase.google.com/go/iid" "firebase.google.com/go/internal" "firebase.google.com/go/messaging" @@ -37,15 +38,7 @@ import ( "google.golang.org/api/transport" ) -var firebaseScopes = []string{ - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/datastore", - "https://www.googleapis.com/auth/devstorage.full_control", - "https://www.googleapis.com/auth/firebase", - "https://www.googleapis.com/auth/identitytoolkit", - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/firebase.messaging", -} +var defaultAuthOverrides = make(map[string]interface{}) // Version of the Firebase Go Admin SDK. const Version = "2.5.0" @@ -55,7 +48,9 @@ const firebaseEnvName = "FIREBASE_CONFIG" // An App holds configuration and state common to all Firebase services that are exposed from the SDK. type App struct { + authOverride map[string]interface{} creds *google.DefaultCredentials + dbURL string projectID string storageBucket string opts []option.ClientOption @@ -63,8 +58,10 @@ type App struct { // Config represents the configuration used to initialize an App. type Config struct { - ProjectID string `json:"projectId"` - StorageBucket string `json:"storageBucket"` + AuthOverride *map[string]interface{} `json:"databaseAuthVariableOverride"` + DatabaseURL string `json:"databaseURL"` + ProjectID string `json:"projectId"` + StorageBucket string `json:"storageBucket"` } // Auth returns an instance of auth.Client. @@ -78,6 +75,17 @@ func (a *App) Auth(ctx context.Context) (*auth.Client, error) { return auth.NewClient(ctx, conf) } +// Database returns an instance of db.Client. +func (a *App) Database(ctx context.Context) (*db.Client, error) { + conf := &internal.DatabaseConfig{ + AuthOverride: a.authOverride, + URL: a.dbURL, + Opts: a.opts, + Version: Version, + } + return db.NewClient(ctx, conf) +} + // Storage returns a new instance of storage.Client. func (a *App) Storage(ctx context.Context) (*storage.Client, error) { conf := &internal.StorageConfig{ @@ -124,7 +132,7 @@ func (a *App) Messaging(ctx context.Context) (*messaging.Client, error) { // `FIREBASE_CONFIG` environment variable. If the value in it starts with a `{` it is parsed as a // JSON object, otherwise it is assumed to be the name of the JSON file containing the options. func NewApp(ctx context.Context, config *Config, opts ...option.ClientOption) (*App, error) { - o := []option.ClientOption{option.WithScopes(firebaseScopes...)} + o := []option.ClientOption{option.WithScopes(internal.FirebaseScopes...)} o = append(o, opts...) creds, err := transport.Creds(ctx, o...) if err != nil { @@ -145,8 +153,15 @@ func NewApp(ctx context.Context, config *Config, opts ...option.ClientOption) (* pid = os.Getenv("GCLOUD_PROJECT") } + ao := defaultAuthOverrides + if config.AuthOverride != nil { + ao = *config.AuthOverride + } + return &App{ + authOverride: ao, creds: creds, + dbURL: config.DatabaseURL, projectID: pid, storageBucket: config.StorageBucket, opts: o, @@ -170,6 +185,19 @@ func getConfigDefaults() (*Config, error) { return nil, err } } - err := json.Unmarshal(dat, fbc) - return fbc, err + if err := json.Unmarshal(dat, fbc); err != nil { + return nil, err + } + + // Some special handling necessary for db auth overrides + var m map[string]interface{} + if err := json.Unmarshal(dat, &m); err != nil { + return nil, err + } + if ao, ok := m["databaseAuthVariableOverride"]; ok && ao == nil { + // Auth overrides are explicitly set to null + var nullMap map[string]interface{} + fbc.AuthOverride = &nullMap + } + return fbc, nil } diff --git a/firebase_test.go b/firebase_test.go index fc3d50d3..bbf1a4c3 100644 --- a/firebase_test.go +++ b/firebase_test.go @@ -22,6 +22,7 @@ import ( "net/http" "net/http/httptest" "os" + "reflect" "strconv" "strings" "testing" @@ -227,6 +228,48 @@ func TestAuth(t *testing.T) { } } +func TestDatabase(t *testing.T) { + ctx := context.Background() + conf := &Config{DatabaseURL: "https://mock-db.firebaseio.com"} + app, err := NewApp(ctx, conf, option.WithCredentialsFile("testdata/service_account.json")) + if err != nil { + t.Fatal(err) + } + + if app.authOverride == nil || len(app.authOverride) != 0 { + t.Errorf("AuthOverrides = %v; want = empty map", app.authOverride) + } + if c, err := app.Database(ctx); c == nil || err != nil { + t.Errorf("Database() = (%v, %v); want (db, nil)", c, err) + } +} + +func TestDatabaseAuthOverrides(t *testing.T) { + cases := []map[string]interface{}{ + nil, + map[string]interface{}{}, + map[string]interface{}{"uid": "user1"}, + } + for _, tc := range cases { + ctx := context.Background() + conf := &Config{ + AuthOverride: &tc, + DatabaseURL: "https://mock-db.firebaseio.com", + } + app, err := NewApp(ctx, conf, option.WithCredentialsFile("testdata/service_account.json")) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(app.authOverride, tc) { + t.Errorf("AuthOverrides = %v; want = %v", app.authOverride, tc) + } + if c, err := app.Database(ctx); c == nil || err != nil { + t.Errorf("Database() = (%v, %v); want (db, nil)", c, err) + } + } +} + func TestStorage(t *testing.T) { ctx := context.Background() app, err := NewApp(ctx, nil, option.WithCredentialsFile("testdata/service_account.json")) @@ -360,7 +403,10 @@ func TestVersion(t *testing.T) { } } } + func TestAutoInit(t *testing.T) { + var nullMap map[string]interface{} + uidMap := map[string]interface{}{"uid": "test"} tests := []struct { name string optionsConfig string @@ -378,6 +424,7 @@ func TestAutoInit(t *testing.T) { "testdata/firebase_config.json", nil, &Config{ + DatabaseURL: "https://auto-init.database.url", ProjectID: "auto-init-project-id", StorageBucket: "auto-init.storage.bucket", }, @@ -385,11 +432,13 @@ func TestAutoInit(t *testing.T) { { "", `{ + "databaseURL": "https://auto-init.database.url", "projectId": "auto-init-project-id", "storageBucket": "auto-init.storage.bucket" }`, nil, &Config{ + DatabaseURL: "https://auto-init.database.url", ProjectID: "auto-init-project-id", StorageBucket: "auto-init.storage.bucket", }, @@ -456,6 +505,34 @@ func TestAutoInit(t *testing.T) { StorageBucket: "auto-init.storage.bucket", }, }, + { + "", + `{ + "databaseURL": "https://auto-init.database.url", + "projectId": "auto-init-project-id", + "databaseAuthVariableOverride": null + }`, + nil, + &Config{ + DatabaseURL: "https://auto-init.database.url", + ProjectID: "auto-init-project-id", + AuthOverride: &nullMap, + }, + }, + { + "", + `{ + "databaseURL": "https://auto-init.database.url", + "projectId": "auto-init-project-id", + "databaseAuthVariableOverride": {"uid": "test"} + }`, + nil, + &Config{ + DatabaseURL: "https://auto-init.database.url", + ProjectID: "auto-init-project-id", + AuthOverride: &uidMap, + }, + }, } credOld := overwriteEnv(credEnvVar, "testdata/service_account.json") @@ -523,6 +600,16 @@ func (t *testTokenSource) Token() (*oauth2.Token, error) { } func compareConfig(got *App, want *Config, t *testing.T) { + if got.dbURL != want.DatabaseURL { + t.Errorf("app.dbURL = %q; want = %q", got.dbURL, want.DatabaseURL) + } + if want.AuthOverride != nil { + if !reflect.DeepEqual(got.authOverride, *want.AuthOverride) { + t.Errorf("app.ao = %#v; want = %#v", got.authOverride, *want.AuthOverride) + } + } else if !reflect.DeepEqual(got.authOverride, defaultAuthOverrides) { + t.Errorf("app.ao = %#v; want = nil", got.authOverride) + } if got.projectID != want.ProjectID { t.Errorf("app.projectID = %q; want = %q", got.projectID, want.ProjectID) } diff --git a/integration/auth/auth_test.go b/integration/auth/auth_test.go index 33d8cf2c..2b8d6bd6 100644 --- a/integration/auth/auth_test.go +++ b/integration/auth/auth_test.go @@ -44,7 +44,7 @@ func TestMain(m *testing.M) { } ctx := context.Background() - app, err := internal.NewTestApp(ctx) + app, err := internal.NewTestApp(ctx, nil) if err != nil { log.Fatalln(err) } diff --git a/integration/db/db_test.go b/integration/db/db_test.go new file mode 100644 index 00000000..0754d5bf --- /dev/null +++ b/integration/db/db_test.go @@ -0,0 +1,709 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package db contains integration tests for the firebase.google.com/go/db package. +package db + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "io/ioutil" + "log" + "net/http" + "os" + "reflect" + "testing" + + "golang.org/x/net/context" + + "firebase.google.com/go" + "firebase.google.com/go/db" + "firebase.google.com/go/integration/internal" +) + +var client *db.Client +var aoClient *db.Client +var guestClient *db.Client + +var ref *db.Ref +var users *db.Ref +var dinos *db.Ref + +var testData map[string]interface{} +var parsedTestData map[string]Dinosaur + +const permDenied = "http error status: 401; reason: Permission denied" + +func TestMain(m *testing.M) { + flag.Parse() + if testing.Short() { + log.Println("skipping database integration tests in short mode.") + os.Exit(0) + } + + pid, err := internal.ProjectID() + if err != nil { + log.Fatalln(err) + } + + client, err = initClient(pid) + if err != nil { + log.Fatalln(err) + } + + aoClient, err = initOverrideClient(pid) + if err != nil { + log.Fatalln(err) + } + + guestClient, err = initGuestClient(pid) + if err != nil { + log.Fatalln(err) + } + + ref = client.NewRef("_adminsdk/go/dinodb") + dinos = ref.Child("dinosaurs") + users = ref.Parent().Child("users") + + initRules() + initData() + + os.Exit(m.Run()) +} + +func initClient(pid string) (*db.Client, error) { + ctx := context.Background() + app, err := internal.NewTestApp(ctx, &firebase.Config{ + DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), + }) + if err != nil { + return nil, err + } + + return app.Database(ctx) +} + +func initOverrideClient(pid string) (*db.Client, error) { + ctx := context.Background() + ao := map[string]interface{}{"uid": "user1"} + app, err := internal.NewTestApp(ctx, &firebase.Config{ + DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), + AuthOverride: &ao, + }) + if err != nil { + return nil, err + } + + return app.Database(ctx) +} + +func initGuestClient(pid string) (*db.Client, error) { + ctx := context.Background() + var nullMap map[string]interface{} + app, err := internal.NewTestApp(ctx, &firebase.Config{ + DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), + AuthOverride: &nullMap, + }) + if err != nil { + return nil, err + } + + return app.Database(ctx) +} + +func initRules() { + b, err := ioutil.ReadFile(internal.Resource("dinosaurs_index.json")) + if err != nil { + log.Fatalln(err) + } + + pid, err := internal.ProjectID() + if err != nil { + log.Fatalln(err) + } + + url := fmt.Sprintf("https://%s.firebaseio.com/.settings/rules.json", pid) + req, err := http.NewRequest("PUT", url, bytes.NewBuffer(b)) + if err != nil { + log.Fatalln(err) + } + req.Header.Set("Content-Type", "application/json") + + hc, err := internal.NewHTTPClient(context.Background()) + if err != nil { + log.Fatalln(err) + } + resp, err := hc.Do(req) + if err != nil { + log.Fatalln(err) + } + defer resp.Body.Close() + + b, err = ioutil.ReadAll(resp.Body) + if err != nil { + log.Fatalln(err) + } else if resp.StatusCode != http.StatusOK { + log.Fatalln("failed to update rules:", string(b)) + } +} + +func initData() { + b, err := ioutil.ReadFile(internal.Resource("dinosaurs.json")) + if err != nil { + log.Fatalln(err) + } + if err = json.Unmarshal(b, &testData); err != nil { + log.Fatalln(err) + } + + b, err = json.Marshal(testData["dinosaurs"]) + if err != nil { + log.Fatalln(err) + } + if err = json.Unmarshal(b, &parsedTestData); err != nil { + log.Fatalln(err) + } + + if err = ref.Set(context.Background(), testData); err != nil { + log.Fatalln(err) + } +} + +func TestRef(t *testing.T) { + if ref.Key != "dinodb" { + t.Errorf("Key = %q; want = %q", ref.Key, "dinodb") + } + if ref.Path != "/_adminsdk/go/dinodb" { + t.Errorf("Path = %q; want = %q", ref.Path, "/_adminsdk/go/dinodb") + } +} + +func TestChild(t *testing.T) { + c := ref.Child("dinosaurs") + if c.Key != "dinosaurs" { + t.Errorf("Key = %q; want = %q", c.Key, "dinosaurs") + } + if c.Path != "/_adminsdk/go/dinodb/dinosaurs" { + t.Errorf("Path = %q; want = %q", c.Path, "/_adminsdk/go/dinodb/dinosaurs") + } +} + +func TestParent(t *testing.T) { + p := ref.Parent() + if p.Key != "go" { + t.Errorf("Key = %q; want = %q", p.Key, "go") + } + if p.Path != "/_adminsdk/go" { + t.Errorf("Path = %q; want = %q", p.Path, "/_adminsdk/go") + } +} + +func TestGet(t *testing.T) { + var m map[string]interface{} + if err := ref.Get(context.Background(), &m); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(testData, m) { + t.Errorf("Get() = %v; want = %v", m, testData) + } +} + +func TestGetWithETag(t *testing.T) { + var m map[string]interface{} + etag, err := ref.GetWithETag(context.Background(), &m) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(testData, m) { + t.Errorf("GetWithETag() = %v; want = %v", m, testData) + } + if etag == "" { + t.Errorf("GetWithETag() = \"\"; want non-empty") + } +} + +func TestGetShallow(t *testing.T) { + var m map[string]interface{} + if err := ref.GetShallow(context.Background(), &m); err != nil { + t.Fatal(err) + } + want := map[string]interface{}{} + for k := range testData { + want[k] = true + } + if !reflect.DeepEqual(want, m) { + t.Errorf("GetShallow() = %v; want = %v", m, want) + } +} + +func TestGetIfChanged(t *testing.T) { + var m map[string]interface{} + ok, etag, err := ref.GetIfChanged(context.Background(), "wrong-etag", &m) + if err != nil { + t.Fatal(err) + } + if !ok || etag == "" { + t.Errorf("GetIfChanged() = (%v, %q); want = (%v, %q)", ok, etag, true, "non-empty") + } + if !reflect.DeepEqual(testData, m) { + t.Errorf("GetWithETag() = %v; want = %v", m, testData) + } + + var m2 map[string]interface{} + ok, etag2, err := ref.GetIfChanged(context.Background(), etag, &m2) + if err != nil { + t.Fatal(err) + } + if ok || etag != etag2 { + t.Errorf("GetIfChanged() = (%v, %q); want = (%v, %q)", ok, etag2, false, etag) + } + if len(m2) != 0 { + t.Errorf("GetWithETag() = %v; want empty", m) + } +} + +func TestGetChildValue(t *testing.T) { + c := ref.Child("dinosaurs") + var m map[string]interface{} + if err := c.Get(context.Background(), &m); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(testData["dinosaurs"], m) { + t.Errorf("Get() = %v; want = %v", m, testData["dinosaurs"]) + } +} + +func TestGetGrandChildValue(t *testing.T) { + c := ref.Child("dinosaurs/lambeosaurus") + var got Dinosaur + if err := c.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + want := parsedTestData["lambeosaurus"] + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestGetNonExistingChild(t *testing.T) { + c := ref.Child("non_existing") + var i interface{} + if err := c.Get(context.Background(), &i); err != nil { + t.Fatal(err) + } + if i != nil { + t.Errorf("Get() = %v; want nil", i) + } +} + +func TestPush(t *testing.T) { + u, err := users.Push(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + if u.Path != "/_adminsdk/go/users/"+u.Key { + t.Errorf("Push() = %q; want = %q", u.Path, "/_adminsdk/go/users/"+u.Key) + } + + var i interface{} + if err := u.Get(context.Background(), &i); err != nil { + t.Fatal(err) + } + if i != "" { + t.Errorf("Get() = %v; want empty string", i) + } +} + +func TestPushWithValue(t *testing.T) { + want := User{"Luis Alvarez", 1911} + u, err := users.Push(context.Background(), &want) + if err != nil { + t.Fatal(err) + } + if u.Path != "/_adminsdk/go/users/"+u.Key { + t.Errorf("Push() = %q; want = %q", u.Path, "/_adminsdk/go/users/"+u.Key) + } + + var got User + if err := u.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if want != got { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestSetPrimitiveValue(t *testing.T) { + u, err := users.Push(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + if err := u.Set(context.Background(), "value"); err != nil { + t.Fatal(err) + } + var got string + if err := u.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if got != "value" { + t.Errorf("Get() = %q; want = %q", got, "value") + } +} + +func TestSetComplexValue(t *testing.T) { + u, err := users.Push(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + want := User{"Mary Anning", 1799} + if err := u.Set(context.Background(), &want); err != nil { + t.Fatal(err) + } + var got User + if err := u.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if got != want { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestUpdateChildren(t *testing.T) { + u, err := users.Push(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + want := map[string]interface{}{ + "name": "Robert Bakker", + "since": float64(1945), + } + if err := u.Update(context.Background(), want); err != nil { + t.Fatal(err) + } + var got map[string]interface{} + if err := u.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestUpdateChildrenWithExistingValue(t *testing.T) { + u, err := users.Push(context.Background(), map[string]interface{}{ + "name": "Edwin Colbert", + "since": float64(1900), + }) + if err != nil { + t.Fatal(err) + } + + update := map[string]interface{}{"since": float64(1905)} + if err := u.Update(context.Background(), update); err != nil { + t.Fatal(err) + } + var got map[string]interface{} + if err := u.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + want := map[string]interface{}{ + "name": "Edwin Colbert", + "since": float64(1905), + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestUpdateNestedChildren(t *testing.T) { + edward, err := users.Push(context.Background(), map[string]interface{}{ + "name": "Edward Cope", "since": float64(1800), + }) + if err != nil { + t.Fatal(err) + } + jack, err := users.Push(context.Background(), map[string]interface{}{ + "name": "Jack Horner", "since": float64(1940), + }) + if err != nil { + t.Fatal(err) + } + delta := map[string]interface{}{ + fmt.Sprintf("%s/since", edward.Key): 1840, + fmt.Sprintf("%s/since", jack.Key): 1946, + } + if err := users.Update(context.Background(), delta); err != nil { + t.Fatal(err) + } + var got map[string]interface{} + if err := edward.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + want := map[string]interface{}{"name": "Edward Cope", "since": float64(1840)} + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + + if err := jack.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + want = map[string]interface{}{"name": "Jack Horner", "since": float64(1946)} + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestSetIfChanged(t *testing.T) { + edward, err := users.Push(context.Background(), &User{"Edward Cope", 1800}) + if err != nil { + t.Fatal(err) + } + + update := User{"Jack Horner", 1940} + ok, err := edward.SetIfUnchanged(context.Background(), "invalid-etag", &update) + if err != nil { + t.Fatal(err) + } + if ok { + t.Errorf("SetIfUnchanged() = %v; want = %v", ok, false) + } + + var u User + etag, err := edward.GetWithETag(context.Background(), &u) + if err != nil { + t.Fatal(err) + } + ok, err = edward.SetIfUnchanged(context.Background(), etag, &update) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Errorf("SetIfUnchanged() = %v; want = %v", ok, true) + } + + if err := edward.Get(context.Background(), &u); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(update, u) { + t.Errorf("Get() = %v; want = %v", u, update) + } +} + +func TestTransaction(t *testing.T) { + u, err := users.Push(context.Background(), &User{Name: "Richard"}) + if err != nil { + t.Fatal(err) + } + fn := func(t db.TransactionNode) (interface{}, error) { + var user User + if err := t.Unmarshal(&user); err != nil { + return nil, err + } + user.Name = "Richard Owen" + user.Since = 1804 + return &user, nil + } + if err := u.Transaction(context.Background(), fn); err != nil { + t.Fatal(err) + } + var got User + if err := u.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + want := User{"Richard Owen", 1804} + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestTransactionScalar(t *testing.T) { + cnt := users.Child("count") + if err := cnt.Set(context.Background(), 42); err != nil { + t.Fatal(err) + } + fn := func(t db.TransactionNode) (interface{}, error) { + var snap float64 + if err := t.Unmarshal(&snap); err != nil { + return nil, err + } + return snap + 1, nil + } + if err := cnt.Transaction(context.Background(), fn); err != nil { + t.Fatal(err) + } + var got float64 + if err := cnt.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if got != 43.0 { + t.Errorf("Get() = %v; want = %v", got, 43.0) + } +} + +func TestDelete(t *testing.T) { + u, err := users.Push(context.Background(), "foo") + if err != nil { + t.Fatal(err) + } + var got string + if err := u.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if got != "foo" { + t.Errorf("Get() = %q; want = %q", got, "foo") + } + if err := u.Delete(context.Background()); err != nil { + t.Fatal(err) + } + + var got2 string + if err := u.Get(context.Background(), &got2); err != nil { + t.Fatal(err) + } + if got2 != "" { + t.Errorf("Get() = %q; want = %q", got2, "") + } +} + +func TestNoAccess(t *testing.T) { + r := aoClient.NewRef(protectedRef(t, "_adminsdk/go/admin")) + var got string + if err := r.Get(context.Background(), &got); err == nil || got != "" { + t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } + if err := r.Set(context.Background(), "update"); err == nil { + t.Errorf("Set() = nil; want = error") + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } +} + +func TestReadAccess(t *testing.T) { + r := aoClient.NewRef(protectedRef(t, "_adminsdk/go/protected/user2")) + var got string + if err := r.Get(context.Background(), &got); err != nil || got != "test" { + t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") + } + if err := r.Set(context.Background(), "update"); err == nil { + t.Errorf("Set() = nil; want = error") + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } +} + +func TestReadWriteAccess(t *testing.T) { + r := aoClient.NewRef(protectedRef(t, "_adminsdk/go/protected/user1")) + var got string + if err := r.Get(context.Background(), &got); err != nil || got != "test" { + t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") + } + if err := r.Set(context.Background(), "update"); err != nil { + t.Errorf("Set() = %v; want = nil", err) + } +} + +func TestQueryAccess(t *testing.T) { + r := aoClient.NewRef("_adminsdk/go/protected") + got := make(map[string]interface{}) + if err := r.OrderByKey().LimitToFirst(2).Get(context.Background(), &got); err == nil { + t.Errorf("OrderByQuery() = nil; want = error") + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } +} + +func TestGuestAccess(t *testing.T) { + r := guestClient.NewRef(protectedRef(t, "_adminsdk/go/public")) + var got string + if err := r.Get(context.Background(), &got); err != nil || got != "test" { + t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") + } + if err := r.Set(context.Background(), "update"); err == nil { + t.Errorf("Set() = nil; want = error") + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } + + got = "" + r = guestClient.NewRef("_adminsdk/go") + if err := r.Get(context.Background(), &got); err == nil || got != "" { + t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } + + c := r.Child("protected/user2") + if err := c.Get(context.Background(), &got); err == nil || got != "" { + t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } + + c = r.Child("admin") + if err := c.Get(context.Background(), &got); err == nil || got != "" { + t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } +} + +func TestWithContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + var m map[string]interface{} + if err := ref.Get(ctx, &m); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(testData, m) { + t.Errorf("Get() = %v; want = %v", m, testData) + } + + cancel() + m = nil + if err := ref.Get(ctx, &m); len(m) != 0 || err == nil { + t.Errorf("Get() = (%v, %v); want = (empty, error)", m, err) + } +} + +func protectedRef(t *testing.T, p string) string { + r := client.NewRef(p) + if err := r.Set(context.Background(), "test"); err != nil { + t.Fatal(err) + } + return p +} + +type Dinosaur struct { + Appeared int `json:"appeared"` + Height float64 `json:"height"` + Length float64 `json:"length"` + Order string `json:"order"` + Vanished int `json:"vanished"` + Weight int `json:"weight"` + Ratings Ratings `json:"ratings"` +} + +type Ratings struct { + Pos int `json:"pos"` +} + +type User struct { + Name string `json:"name"` + Since int `json:"since"` +} diff --git a/integration/db/query_test.go b/integration/db/query_test.go new file mode 100644 index 00000000..6573d915 --- /dev/null +++ b/integration/db/query_test.go @@ -0,0 +1,266 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "testing" + + "firebase.google.com/go/db" + + "reflect" + + "golang.org/x/net/context" +) + +var heightSorted = []string{ + "linhenykus", "pterodactyl", "lambeosaurus", + "triceratops", "stegosaurus", "bruhathkayosaurus", +} + +func TestLimitToFirst(t *testing.T) { + for _, tc := range []int{2, 10} { + results, err := dinos.OrderByChild("height").LimitToFirst(tc).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + wl := min(tc, len(heightSorted)) + want := heightSorted[:wl] + if len(results) != wl { + t.Errorf("LimitToFirst() = %d; want = %d", len(results), wl) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + compareValues(t, results) + } +} + +func TestLimitToLast(t *testing.T) { + for _, tc := range []int{2, 10} { + results, err := dinos.OrderByChild("height").LimitToLast(tc).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + wl := min(tc, len(heightSorted)) + want := heightSorted[len(heightSorted)-wl:] + if len(results) != wl { + t.Errorf("LimitToLast() = %d; want = %d", len(results), wl) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + compareValues(t, results) + } +} + +func TestStartAt(t *testing.T) { + results, err := dinos.OrderByChild("height").StartAt(3.5).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + want := heightSorted[len(heightSorted)-2:] + if len(results) != len(want) { + t.Errorf("StartAt() = %d; want = %d", len(results), len(want)) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + compareValues(t, results) +} + +func TestEndAt(t *testing.T) { + results, err := dinos.OrderByChild("height").EndAt(3.5).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + want := heightSorted[:4] + if len(results) != len(want) { + t.Errorf("StartAt() = %d; want = %d", len(results), len(want)) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + compareValues(t, results) +} + +func TestStartAndEndAt(t *testing.T) { + results, err := dinos.OrderByChild("height").StartAt(2.5).EndAt(5).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + want := heightSorted[len(heightSorted)-3 : len(heightSorted)-1] + if len(results) != len(want) { + t.Errorf("StartAt(), EndAt() = %d; want = %d", len(results), len(want)) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + compareValues(t, results) +} + +func TestEqualTo(t *testing.T) { + results, err := dinos.OrderByChild("height").EqualTo(0.6).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + want := heightSorted[:2] + if len(results) != len(want) { + t.Errorf("EqualTo() = %d; want = %d", len(results), len(want)) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + compareValues(t, results) +} + +func TestOrderByNestedChild(t *testing.T) { + results, err := dinos.OrderByChild("ratings/pos").StartAt(4).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + want := []string{"pterodactyl", "stegosaurus", "triceratops"} + if len(results) != len(want) { + t.Errorf("OrderByChild(ratings/pos) = %d; want = %d", len(results), len(want)) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + compareValues(t, results) +} + +func TestOrderByKey(t *testing.T) { + results, err := dinos.OrderByKey().LimitToFirst(2).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + want := []string{"bruhathkayosaurus", "lambeosaurus"} + if len(results) != len(want) { + t.Errorf("OrderByKey() = %d; want = %d", len(results), len(want)) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + compareValues(t, results) +} + +func TestOrderByValue(t *testing.T) { + scores := ref.Child("scores") + results, err := scores.OrderByValue().LimitToLast(2).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + want := []string{"linhenykus", "pterodactyl"} + if len(results) != len(want) { + t.Errorf("OrderByValue() = %d; want = %d", len(results), len(want)) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + wantScores := []int{80, 93} + for i, r := range results { + var val int + if err := r.Unmarshal(&val); err != nil { + t.Fatalf("queryNode.Unmarshal() = %v", err) + } + if val != wantScores[i] { + t.Errorf("queryNode.Unmarshal() = %d; want = %d", val, wantScores[i]) + } + } +} + +func TestQueryWithContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + q := dinos.OrderByKey().LimitToFirst(2) + var m map[string]Dinosaur + if err := q.Get(ctx, &m); err != nil { + t.Fatal(err) + } + + want := []string{"bruhathkayosaurus", "lambeosaurus"} + if len(m) != len(want) { + t.Errorf("OrderByKey() = %d; want = %d", len(m), len(want)) + } + + cancel() + m = nil + if err := q.Get(ctx, &m); len(m) != 0 || err == nil { + t.Errorf("Get() = (%v, %v); want = (empty, error)", m, err) + } +} + +func TestUnorderedQuery(t *testing.T) { + var m map[string]Dinosaur + if err := dinos.OrderByChild("height"). + StartAt(2.5). + EndAt(5). + Get(context.Background(), &m); err != nil { + t.Fatal(err) + } + + want := heightSorted[len(heightSorted)-3 : len(heightSorted)-1] + if len(m) != len(want) { + t.Errorf("Get() = %d; want = %d", len(m), len(want)) + } + for i, w := range want { + if _, ok := m[w]; !ok { + t.Errorf("[%d] result[%q] not present", i, w) + } + } +} + +func min(i, j int) int { + if i < j { + return i + } + return j +} + +func getNames(results []db.QueryNode) []string { + s := make([]string, len(results)) + for i, v := range results { + s[i] = v.Key() + } + return s +} + +func compareValues(t *testing.T, results []db.QueryNode) { + for _, r := range results { + var d Dinosaur + if err := r.Unmarshal(&d); err != nil { + t.Fatalf("queryNode.Unmarshal(%q) = %v", r.Key(), err) + } + if !reflect.DeepEqual(d, parsedTestData[r.Key()]) { + t.Errorf("queryNode.Unmarshal(%q) = %v; want = %v", r.Key(), d, parsedTestData[r.Key()]) + } + } +} diff --git a/integration/firestore/firestore_test.go b/integration/firestore/firestore_test.go index 6c367205..1b861d92 100644 --- a/integration/firestore/firestore_test.go +++ b/integration/firestore/firestore_test.go @@ -29,7 +29,7 @@ func TestFirestore(t *testing.T) { return } ctx := context.Background() - app, err := internal.NewTestApp(ctx) + app, err := internal.NewTestApp(ctx, nil) if err != nil { t.Fatal(err) } diff --git a/integration/iid/iid_test.go b/integration/iid/iid_test.go index 55cf5620..2b1b1c9c 100644 --- a/integration/iid/iid_test.go +++ b/integration/iid/iid_test.go @@ -36,7 +36,7 @@ func TestMain(m *testing.M) { } ctx := context.Background() - app, err := internal.NewTestApp(ctx) + app, err := internal.NewTestApp(ctx, nil) if err != nil { log.Fatalln(err) } diff --git a/integration/internal/internal.go b/integration/internal/internal.go index 1fe112a3..a5cd7af6 100644 --- a/integration/internal/internal.go +++ b/integration/internal/internal.go @@ -20,11 +20,14 @@ import ( "encoding/json" "go/build" "io/ioutil" + "net/http" "path/filepath" "strings" firebase "firebase.google.com/go" + "firebase.google.com/go/internal" "google.golang.org/api/option" + "google.golang.org/api/transport" ) const certPath = "integration_cert.json" @@ -41,15 +44,8 @@ func Resource(name string) string { // NewTestApp looks for a service account JSON file named integration_cert.json // in the testdata directory. This file is used to initialize the newly created // App instance. -func NewTestApp(ctx context.Context) (*firebase.App, error) { - pid, err := ProjectID() - if err != nil { - return nil, err - } - config := &firebase.Config{ - StorageBucket: pid + ".appspot.com", - } - return firebase.NewApp(ctx, config, option.WithCredentialsFile(Resource(certPath))) +func NewTestApp(ctx context.Context, conf *firebase.Config) (*firebase.App, error) { + return firebase.NewApp(ctx, conf, option.WithCredentialsFile(Resource(certPath))) } // APIKey fetches a Firebase API key for integration tests. @@ -78,3 +74,14 @@ func ProjectID() (string, error) { } return serviceAccount.ProjectID, nil } + +// NewHTTPClient creates an HTTP client for making authorized requests during tests. +func NewHTTPClient(ctx context.Context, opts ...option.ClientOption) (*http.Client, error) { + opts = append( + opts, + option.WithCredentialsFile(Resource(certPath)), + option.WithScopes(internal.FirebaseScopes...), + ) + hc, _, err := transport.NewHTTPClient(ctx, opts...) + return hc, err +} diff --git a/integration/messaging/messaging_test.go b/integration/messaging/messaging_test.go index 231aab34..4b8ef6d7 100644 --- a/integration/messaging/messaging_test.go +++ b/integration/messaging/messaging_test.go @@ -44,7 +44,7 @@ func TestMain(m *testing.M) { } ctx := context.Background() - app, err := internal.NewTestApp(ctx) + app, err := internal.NewTestApp(ctx, nil) if err != nil { log.Fatalln(err) } diff --git a/integration/storage/storage_test.go b/integration/storage/storage_test.go index b6fc2301..b5a205d5 100644 --- a/integration/storage/storage_test.go +++ b/integration/storage/storage_test.go @@ -23,6 +23,8 @@ import ( "os" "testing" + "firebase.google.com/go" + gcs "cloud.google.com/go/storage" "firebase.google.com/go/integration/internal" "firebase.google.com/go/storage" @@ -38,8 +40,15 @@ func TestMain(m *testing.M) { os.Exit(0) } + pid, err := internal.ProjectID() + if err != nil { + log.Fatalln(err) + } + ctx = context.Background() - app, err := internal.NewTestApp(ctx) + app, err := internal.NewTestApp(ctx, &firebase.Config{ + StorageBucket: fmt.Sprintf("%s.appspot.com", pid), + }) if err != nil { log.Fatalln(err) } diff --git a/internal/internal.go b/internal/internal.go index 225edc9e..bc4f41d1 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -21,6 +21,16 @@ import ( "google.golang.org/api/option" ) +// FirebaseScopes is the set of OAuth2 scopes used by the Admin SDK. +var FirebaseScopes = []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/datastore", + "https://www.googleapis.com/auth/devstorage.full_control", + "https://www.googleapis.com/auth/firebase", + "https://www.googleapis.com/auth/identitytoolkit", + "https://www.googleapis.com/auth/userinfo.email", +} + // AuthConfig represents the configuration of Firebase Auth service. type AuthConfig struct { Opts []option.ClientOption @@ -35,6 +45,14 @@ type InstanceIDConfig struct { ProjectID string } +// DatabaseConfig represents the configuration of Firebase Database service. +type DatabaseConfig struct { + Opts []option.ClientOption + URL string + Version string + AuthOverride map[string]interface{} +} + // StorageConfig represents the configuration of Google Cloud Storage service. type StorageConfig struct { Opts []option.ClientOption diff --git a/snippets/auth.go b/snippets/auth.go index d9548e94..9fb739ba 100644 --- a/snippets/auth.go +++ b/snippets/auth.go @@ -93,10 +93,8 @@ func verifyIDToken(app *firebase.App, idToken string) *auth.Token { // https://firebase.google.com/docs/auth/admin/manage-sessions // ================================================================== -func revokeRefreshTokens(app *firebase.App, uid string) { - +func revokeRefreshTokens(ctx context.Context, app *firebase.App, uid string) { // [START revoke_tokens_golang] - ctx := context.Background() client, err := app.Auth(ctx) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) @@ -114,8 +112,7 @@ func revokeRefreshTokens(app *firebase.App, uid string) { // [END revoke_tokens_golang] } -func verifyIDTokenAndCheckRevoked(app *firebase.App, idToken string) *auth.Token { - ctx := context.Background() +func verifyIDTokenAndCheckRevoked(ctx context.Context, app *firebase.App, idToken string) *auth.Token { // [START verify_id_token_and_check_revoked_golang] client, err := app.Auth(ctx) if err != nil { @@ -144,7 +141,7 @@ func getUser(ctx context.Context, app *firebase.App) *auth.UserRecord { // [START get_user_golang] // Get an auth client from the firebase.App - client, err := app.Auth(context.Background()) + client, err := app.Auth(ctx) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) } @@ -192,7 +189,7 @@ func createUser(ctx context.Context, client *auth.Client) *auth.UserRecord { DisplayName("John Doe"). PhotoURL("http://www.example.com/12345678/photo.png"). Disabled(false) - u, err := client.CreateUser(context.Background(), params) + u, err := client.CreateUser(ctx, params) if err != nil { log.Fatalf("error creating user: %v\n", err) } @@ -208,7 +205,7 @@ func createUserWithUID(ctx context.Context, client *auth.Client) *auth.UserRecor UID(uid). Email("user@example.com"). PhoneNumber("+15555550100") - u, err := client.CreateUser(context.Background(), params) + u, err := client.CreateUser(ctx, params) if err != nil { log.Fatalf("error creating user: %v\n", err) } @@ -228,7 +225,7 @@ func updateUser(ctx context.Context, client *auth.Client) { DisplayName("John Doe"). PhotoURL("http://www.example.com/12345678/photo.png"). Disabled(true) - u, err := client.UpdateUser(context.Background(), uid, params) + u, err := client.UpdateUser(ctx, uid, params) if err != nil { log.Fatalf("error updating user: %v\n", err) } @@ -239,7 +236,7 @@ func updateUser(ctx context.Context, client *auth.Client) { func deleteUser(ctx context.Context, client *auth.Client) { uid := "d" // [START delete_user_golang] - err := client.DeleteUser(context.Background(), uid) + err := client.DeleteUser(ctx, uid) if err != nil { log.Fatalf("error deleting user: %v\n", err) } @@ -251,14 +248,14 @@ func customClaimsSet(ctx context.Context, app *firebase.App) { uid := "uid" // [START set_custom_user_claims_golang] // Get an auth client from the firebase.App - client, err := app.Auth(context.Background()) + client, err := app.Auth(ctx) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) } // Set admin privilege on the user corresponding to uid. claims := map[string]interface{}{"admin": true} - err = client.SetCustomUserClaims(context.Background(), uid, claims) + err = client.SetCustomUserClaims(ctx, uid, claims) if err != nil { log.Fatalf("error setting custom claims %v\n", err) } @@ -350,7 +347,7 @@ func customClaimsIncremental(ctx context.Context, client *auth.Client) { func listUsers(ctx context.Context, client *auth.Client) { // [START list_all_users_golang] // Note, behind the scenes, the Users() iterator will retrive 1000 Users at a time through the API - iter := client.Users(context.Background(), "") + iter := client.Users(ctx, "") for { user, err := iter.Next() if err == iterator.Done { @@ -365,7 +362,7 @@ func listUsers(ctx context.Context, client *auth.Client) { // Iterating by pages 100 users at a time. // Note that using both the Next() function on an iterator and the NextPage() // on a Pager wrapping that same iterator will result in an error. - pager := iterator.NewPager(client.Users(context.Background(), ""), 100, "") + pager := iterator.NewPager(client.Users(ctx, ""), 100, "") for { var users []*auth.ExportedUserRecord nextPageToken, err := pager.NextPage(&users) diff --git a/snippets/db.go b/snippets/db.go new file mode 100644 index 00000000..8e0bea71 --- /dev/null +++ b/snippets/db.go @@ -0,0 +1,528 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package snippets + +// [START authenticate_db_imports] +import ( + "context" + "fmt" + "log" + + "firebase.google.com/go/db" + + "firebase.google.com/go" + "google.golang.org/api/option" +) + +// [END authenticate_db_imports] + +func authenticateWithAdminPrivileges() { + // [START authenticate_with_admin_privileges] + ctx := context.Background() + conf := &firebase.Config{ + DatabaseURL: "https://databaseName.firebaseio.com", + } + // Fetch the service account key JSON file contents + opt := option.WithCredentialsFile("path/to/serviceAccountKey.json") + + // Initialize the app with a service account, granting admin privileges + app, err := firebase.NewApp(ctx, conf, opt) + if err != nil { + log.Fatalln("Error initializing app:", err) + } + + client, err := app.Database(ctx) + if err != nil { + log.Fatalln("Error initializing database client:", err) + } + + // As an admin, the app has access to read and write all data, regradless of Security Rules + ref := client.NewRef("restricted_access/secret_document") + var data map[string]interface{} + if err := ref.Get(ctx, &data); err != nil { + log.Fatalln("Error reading from database:", err) + } + fmt.Println(data) + // [END authenticate_with_admin_privileges] +} + +func authenticateWithLimitedPrivileges() { + // [START authenticate_with_limited_privileges] + ctx := context.Background() + // Initialize the app with a custom auth variable, limiting the server's access + ao := map[string]interface{}{"uid": "my-service-worker"} + conf := &firebase.Config{ + DatabaseURL: "https://databaseName.firebaseio.com", + AuthOverride: &ao, + } + + // Fetch the service account key JSON file contents + opt := option.WithCredentialsFile("path/to/serviceAccountKey.json") + + app, err := firebase.NewApp(ctx, conf, opt) + if err != nil { + log.Fatalln("Error initializing app:", err) + } + + client, err := app.Database(ctx) + if err != nil { + log.Fatalln("Error initializing database client:", err) + } + + // The app only has access as defined in the Security Rules + ref := client.NewRef("/some_resource") + var data map[string]interface{} + if err := ref.Get(ctx, &data); err != nil { + log.Fatalln("Error reading from database:", err) + } + fmt.Println(data) + // [END authenticate_with_limited_privileges] +} + +func authenticateWithGuestPrivileges() { + // [START authenticate_with_guest_privileges] + ctx := context.Background() + // Initialize the app with a nil auth variable, limiting the server's access + var nilMap map[string]interface{} + conf := &firebase.Config{ + DatabaseURL: "https://databaseName.firebaseio.com", + AuthOverride: &nilMap, + } + + // Fetch the service account key JSON file contents + opt := option.WithCredentialsFile("path/to/serviceAccountKey.json") + + app, err := firebase.NewApp(ctx, conf, opt) + if err != nil { + log.Fatalln("Error initializing app:", err) + } + + client, err := app.Database(ctx) + if err != nil { + log.Fatalln("Error initializing database client:", err) + } + + // The app only has access to public data as defined in the Security Rules + ref := client.NewRef("/some_resource") + var data map[string]interface{} + if err := ref.Get(ctx, &data); err != nil { + log.Fatalln("Error reading from database:", err) + } + fmt.Println(data) + // [END authenticate_with_guest_privileges] +} + +func getReference(ctx context.Context, app *firebase.App) { + // [START get_reference] + // Create a database client from App. + client, err := app.Database(ctx) + if err != nil { + log.Fatalln("Error initializing database client:", err) + } + + // Get a database reference to our blog. + ref := client.NewRef("server/saving-data/fireblog") + // [END get_reference] + fmt.Println(ref.Path) +} + +// [START user_type] + +// User is a json-serializable type. +type User struct { + DateOfBirth string `json:"date_of_birth,omitempty"` + FullName string `json:"full_name,omitempty"` + Nickname string `json:"nickname,omitempty"` +} + +// [END user_type] + +func setValue(ctx context.Context, ref *db.Ref) { + // [START set_value] + usersRef := ref.Child("users") + err := usersRef.Set(ctx, map[string]*User{ + "alanisawesome": &User{ + DateOfBirth: "June 23, 1912", + FullName: "Alan Turing", + }, + "gracehop": &User{ + DateOfBirth: "December 9, 1906", + FullName: "Grace Hopper", + }, + }) + if err != nil { + log.Fatalln("Error setting value:", err) + } + // [END set_value] +} + +func setChildValue(ctx context.Context, usersRef *db.Ref) { + // [START set_child_value] + if err := usersRef.Child("alanisawesome").Set(ctx, &User{ + DateOfBirth: "June 23, 1912", + FullName: "Alan Turing", + }); err != nil { + log.Fatalln("Error setting value:", err) + } + + if err := usersRef.Child("gracehop").Set(ctx, &User{ + DateOfBirth: "December 9, 1906", + FullName: "Grace Hopper", + }); err != nil { + log.Fatalln("Error setting value:", err) + } + // [END set_child_value] +} + +func updateChild(ctx context.Context, usersRef *db.Ref) { + // [START update_child] + hopperRef := usersRef.Child("gracehop") + if err := hopperRef.Update(ctx, map[string]interface{}{ + "nickname": "Amazing Grace", + }); err != nil { + log.Fatalln("Error updating child:", err) + } + // [END update_child] +} + +func updateChildren(ctx context.Context, usersRef *db.Ref) { + // [START update_children] + if err := usersRef.Update(ctx, map[string]interface{}{ + "alanisawesome/nickname": "Alan The Machine", + "gracehop/nickname": "Amazing Grace", + }); err != nil { + log.Fatalln("Error updating children:", err) + } + // [END update_children] +} + +func overwriteValue(ctx context.Context, usersRef *db.Ref) { + // [START overwrite_value] + if err := usersRef.Update(ctx, map[string]interface{}{ + "alanisawesome": &User{Nickname: "Alan The Machine"}, + "gracehop": &User{Nickname: "Amazing Grace"}, + }); err != nil { + log.Fatalln("Error updating children:", err) + } + // [END overwrite_value] +} + +// [START post_type] + +// Post is a json-serializable type. +type Post struct { + Author string `json:"author,omitempty"` + Title string `json:"title,omitempty"` +} + +// [END post_type] + +func pushValue(ctx context.Context, ref *db.Ref) { + // [START push_value] + postsRef := ref.Child("posts") + + newPostRef, err := postsRef.Push(ctx, nil) + if err != nil { + log.Fatalln("Error pushing child node:", err) + } + + if err := newPostRef.Set(ctx, &Post{ + Author: "gracehop", + Title: "Announcing COBOL, a New Programming Language", + }); err != nil { + log.Fatalln("Error setting value:", err) + } + + // We can also chain the two calls together + if _, err := postsRef.Push(ctx, &Post{ + Author: "alanisawesome", + Title: "The Turing Machine", + }); err != nil { + log.Fatalln("Error pushing child node:", err) + } + // [END push_value] +} + +func pushAndSetValue(ctx context.Context, postsRef *db.Ref) { + // [START push_and_set_value] + if _, err := postsRef.Push(ctx, &Post{ + Author: "gracehop", + Title: "Announcing COBOL, a New Programming Language", + }); err != nil { + log.Fatalln("Error pushing child node:", err) + } + // [END push_and_set_value] +} + +func pushKey(ctx context.Context, postsRef *db.Ref) { + // [START push_key] + // Generate a reference to a new location and add some data using Push() + newPostRef, err := postsRef.Push(ctx, nil) + if err != nil { + log.Fatalln("Error pushing child node:", err) + } + + // Get the unique key generated by Push() + postID := newPostRef.Key + // [END push_key] + fmt.Println(postID) +} + +func transaction(ctx context.Context, client *db.Client) { + // [START transaction] + fn := func(t db.TransactionNode) (interface{}, error) { + var currentValue int + if err := t.Unmarshal(¤tValue); err != nil { + return nil, err + } + return currentValue + 1, nil + } + + ref := client.NewRef("server/saving-data/fireblog/posts/-JRHTHaIs-jNPLXOQivY/upvotes") + if err := ref.Transaction(ctx, fn); err != nil { + log.Fatalln("Transaction failed to commit:", err) + } + // [END transaction] +} + +func readValue(ctx context.Context, app *firebase.App) { + // [START read_value] + // Create a database client from App. + client, err := app.Database(ctx) + if err != nil { + log.Fatalln("Error initializing database client:", err) + } + + // Get a database reference to our posts + ref := client.NewRef("server/saving-data/fireblog/posts") + + // Read the data at the posts reference (this is a blocking operation) + var post Post + if err := ref.Get(ctx, &post); err != nil { + log.Fatalln("Error reading value:", err) + } + // [END read_value] + fmt.Println(ref.Path) +} + +// [START dinosaur_type] + +// Dinosaur is a json-serializable type. +type Dinosaur struct { + Height int `json:"height"` + Width int `json:"width"` +} + +// [END dinosaur_type] + +func orderByChild(ctx context.Context, client *db.Client) { + // [START order_by_child] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("height").GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + var d Dinosaur + if err := r.Unmarshal(&d); err != nil { + log.Fatalln("Error unmarshaling result:", err) + } + fmt.Printf("%s was %d meteres tall", r.Key(), d.Height) + } + // [END order_by_child] +} + +func orderByNestedChild(ctx context.Context, client *db.Client) { + // [START order_by_nested_child] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("dimensions/height").GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + var d Dinosaur + if err := r.Unmarshal(&d); err != nil { + log.Fatalln("Error unmarshaling result:", err) + } + fmt.Printf("%s was %d meteres tall", r.Key(), d.Height) + } + // [END order_by_nested_child] +} + +func orderByKey(ctx context.Context, client *db.Client) { + // [START order_by_key] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByKey().GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + snapshot := make([]Dinosaur, len(results)) + for i, r := range results { + var d Dinosaur + if err := r.Unmarshal(&d); err != nil { + log.Fatalln("Error unmarshaling result:", err) + } + snapshot[i] = d + } + fmt.Println(snapshot) + // [END order_by_key] +} + +func orderByValue(ctx context.Context, client *db.Client) { + // [START order_by_value] + ref := client.NewRef("scores") + + results, err := ref.OrderByValue().GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + var score int + if err := r.Unmarshal(&score); err != nil { + log.Fatalln("Error unmarshaling result:", err) + } + fmt.Printf("The %s dinosaur's score is %d\n", r.Key(), score) + } + // [END order_by_value] +} + +func limitToLast(ctx context.Context, client *db.Client) { + // [START limit_query_1] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("weight").LimitToLast(2).GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END limit_query_1] +} + +func limitToFirst(ctx context.Context, client *db.Client) { + // [START limit_query_2] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("height").LimitToFirst(2).GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END limit_query_2] +} + +func limitWithValueOrder(ctx context.Context, client *db.Client) { + // [START limit_query_3] + ref := client.NewRef("scores") + + results, err := ref.OrderByValue().LimitToLast(3).GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + var score int + if err := r.Unmarshal(&score); err != nil { + log.Fatalln("Error unmarshaling result:", err) + } + fmt.Printf("The %s dinosaur's score is %d\n", r.Key(), score) + } + // [END limit_query_3] +} + +func startAt(ctx context.Context, client *db.Client) { + // [START range_query_1] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("height").StartAt(3).GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END range_query_1] +} + +func endAt(ctx context.Context, client *db.Client) { + // [START range_query_2] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByKey().EndAt("pterodactyl").GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END range_query_2] +} + +func startAndEndAt(ctx context.Context, client *db.Client) { + // [START range_query_3] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByKey().StartAt("b").EndAt("b\uf8ff").GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END range_query_3] +} + +func equalTo(ctx context.Context, client *db.Client) { + // [START range_query_4] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("height").EqualTo(25).GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END range_query_4] +} + +func complexQuery(ctx context.Context, client *db.Client) { + // [START complex_query] + ref := client.NewRef("dinosaurs") + + var favDinoHeight int + if err := ref.Child("stegosaurus").Child("height").Get(ctx, &favDinoHeight); err != nil { + log.Fatalln("Error querying database:", err) + } + + query := ref.OrderByChild("height").EndAt(favDinoHeight).LimitToLast(2) + results, err := query.GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + if len(results) == 2 { + // Data is ordered by increasing height, so we want the first entry. + // Second entry is stegosarus. + fmt.Printf("The dinosaur just shorter than the stegosaurus is %s\n", results[0].Key()) + } else { + fmt.Println("The stegosaurus is the shortest dino") + } + // [END complex_query] +} diff --git a/testdata/dinosaurs.json b/testdata/dinosaurs.json new file mode 100644 index 00000000..9d7afaab --- /dev/null +++ b/testdata/dinosaurs.json @@ -0,0 +1,78 @@ +{ + "dinosaurs": { + "bruhathkayosaurus": { + "appeared": -70000000, + "height": 25, + "length": 44, + "order": "saurischia", + "vanished": -70000000, + "weight": 135000, + "ratings": { + "pos": 1 + } + }, + "lambeosaurus": { + "appeared": -76000000, + "height": 2.1, + "length": 12.5, + "order": "ornithischia", + "vanished": -75000000, + "weight": 5000, + "ratings": { + "pos": 2 + } + }, + "linhenykus": { + "appeared": -85000000, + "height": 0.6, + "length": 1, + "order": "theropoda", + "vanished": -75000000, + "weight": 3, + "ratings": { + "pos": 3 + } + }, + "pterodactyl": { + "appeared": -150000000, + "height": 0.6, + "length": 0.8, + "order": "pterosauria", + "vanished": -148500000, + "weight": 2, + "ratings": { + "pos": 4 + } + }, + "stegosaurus": { + "appeared": -155000000, + "height": 4, + "length": 9, + "order": "ornithischia", + "vanished": -150000000, + "weight": 2500, + "ratings": { + "pos": 5 + } + }, + "triceratops": { + "appeared": -68000000, + "height": 3, + "length": 8, + "order": "ornithischia", + "vanished": -66000000, + "weight": 11000, + "ratings": { + "pos": 6 + } + } + }, + "scores": { + "bruhathkayosaurus": 55, + "lambeosaurus": 21, + "linhenykus": 80, + "pterodactyl": 93, + "stegosaurus": 5, + "triceratops": 22 + } +} diff --git a/testdata/dinosaurs_index.json b/testdata/dinosaurs_index.json new file mode 100644 index 00000000..bf4a2551 --- /dev/null +++ b/testdata/dinosaurs_index.json @@ -0,0 +1,29 @@ +{ + "rules": { + "_adminsdk": { + "go": { + "dinodb": { + "dinosaurs": { + ".indexOn": ["height", "ratings/pos"] + }, + "scores": { + ".indexOn": ".value" + } + }, + "protected": { + "$uid": { + ".read": "auth != null", + ".write": "$uid === auth.uid" + } + }, + "admin": { + ".read": "false", + ".write": "false" + }, + "public": { + ".read": "true" + } + } + } + } +} \ No newline at end of file diff --git a/testdata/firebase_config.json b/testdata/firebase_config.json index e9a3b5bc..772da62d 100644 --- a/testdata/firebase_config.json +++ b/testdata/firebase_config.json @@ -1,4 +1,5 @@ { + "databaseURL": "https://auto-init.database.url", "projectId": "auto-init-project-id", "storageBucket": "auto-init.storage.bucket" } From 3a49e807af383d1459ae3fa7ebac088fc3c1f588 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 27 Feb 2018 17:55:23 -0800 Subject: [PATCH 15/27] Handling FCM canonical error codes (#103) --- CHANGELOG.md | 2 ++ messaging/messaging.go | 15 +++++++++++---- messaging/messaging_test.go | 4 ++++ 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b2583f7..8e0ff764 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # Unreleased +- [changed] Improved error handling in FCM by mapping more server-side + errors to client-side error codes. - [added] Added the `db` package for interacting with the Firebase database. # v2.5.0 diff --git a/messaging/messaging.go b/messaging/messaging.go index a3d71e97..231e7212 100644 --- a/messaging/messaging.go +++ b/messaging/messaging.go @@ -41,13 +41,20 @@ var ( topicNamePattern = regexp.MustCompile("^(/topics/)?(private/)?[a-zA-Z0-9-_.~%]+$") fcmErrorCodes = map[string]string{ + // FCM v1 canonical error codes + "NOT_FOUND": "app instance has been unregistered; code: registration-token-not-registered", + "PERMISSION_DENIED": "sender id does not match regisration token; code: mismatched-credential", + "RESOURCE_EXHAUSTED": "messaging service quota exceeded; code: message-rate-exceeded", + "UNAUTHENTICATED": "apns certificate or auth key was invalid; code: invalid-apns-credentials", + + // FCM v1 new error codes + "APNS_AUTH_ERROR": "apns certificate or auth key was invalid; code: invalid-apns-credentials", + "INTERNAL": "back servers encountered an unknown internl error; code: internal-error", "INVALID_ARGUMENT": "request contains an invalid argument; code: invalid-argument", - "UNREGISTERED": "app instance has been unregistered; code: registration-token-not-registered", - "SENDER_ID_MISMATCH": "sender id does not match regisration token; code: authentication-error", + "SENDER_ID_MISMATCH": "sender id does not match regisration token; code: mismatched-credential", "QUOTA_EXCEEDED": "messaging service quota exceeded; code: message-rate-exceeded", - "APNS_AUTH_ERROR": "apns certificate or auth key was invalid; code: authentication-error", "UNAVAILABLE": "backend servers are temporarily unavailable; code: server-unavailable", - "INTERNAL": "back servers encountered an unknown internl error; code: internal-error", + "UNREGISTERED": "app instance has been unregistered; code: registration-token-not-registered", } iidErrorCodes = map[string]string{ diff --git a/messaging/messaging_test.go b/messaging/messaging_test.go index 53b0650a..79ef498b 100644 --- a/messaging/messaging_test.go +++ b/messaging/messaging_test.go @@ -632,6 +632,10 @@ func TestSendError(t *testing.T) { resp: "{\"error\": {\"status\": \"INVALID_ARGUMENT\", \"message\": \"test error\"}}", want: "http error status: 500; reason: request contains an invalid argument; code: invalid-argument", }, + { + resp: "{\"error\": {\"status\": \"NOT_FOUND\", \"message\": \"test error\"}}", + want: "http error status: 500; reason: app instance has been unregistered; code: registration-token-not-registered", + }, { resp: "not json", want: "http error status: 500; reason: server responded with an unknown error; response: not json", From f03d5a6c96b7fa6d1be0c58b6f1f85373656dd2b Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 28 Feb 2018 10:57:50 -0800 Subject: [PATCH 16/27] Formatting test file with gofmt (#104) --- auth/user_mgt_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/auth/user_mgt_test.go b/auth/user_mgt_test.go index 6a95312a..3f298ed6 100644 --- a/auth/user_mgt_test.go +++ b/auth/user_mgt_test.go @@ -167,9 +167,9 @@ func TestListUsers(t *testing.T) { defer s.Close() want := []*ExportedUserRecord{ - &ExportedUserRecord{UserRecord: testUser, PasswordHash: "passwordhash1", PasswordSalt: "salt1"}, - &ExportedUserRecord{UserRecord: testUser, PasswordHash: "passwordhash2", PasswordSalt: "salt2"}, - &ExportedUserRecord{UserRecord: testUser, PasswordHash: "passwordhash3", PasswordSalt: "salt3"}, + {UserRecord: testUser, PasswordHash: "passwordhash1", PasswordSalt: "salt1"}, + {UserRecord: testUser, PasswordHash: "passwordhash2", PasswordSalt: "salt2"}, + {UserRecord: testUser, PasswordHash: "passwordhash3", PasswordSalt: "salt3"}, } testIterator := func(iter *UserIterator, token string, req map[string]interface{}) { @@ -574,9 +574,9 @@ func TestInvalidSetCustomClaims(t *testing.T) { func TestSetCustomClaims(t *testing.T) { cases := []map[string]interface{}{ nil, - map[string]interface{}{}, - map[string]interface{}{"admin": true}, - map[string]interface{}{"admin": true, "package": "gold"}, + {}, + {"admin": true}, + {"admin": true, "package": "gold"}, } resp := `{ From f0be2f40493a1fb1884af057d0d29dfbdc2f4bb8 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 28 Feb 2018 11:36:46 -0800 Subject: [PATCH 17/27] Bumped version to 2.6.0 (#105) --- CHANGELOG.md | 4 ++++ firebase.go | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e0ff764..5852ed27 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Unreleased +- + +# v2.6.0 + - [changed] Improved error handling in FCM by mapping more server-side errors to client-side error codes. - [added] Added the `db` package for interacting with the Firebase database. diff --git a/firebase.go b/firebase.go index 0c6718da..0e34c058 100644 --- a/firebase.go +++ b/firebase.go @@ -41,7 +41,7 @@ import ( var defaultAuthOverrides = make(map[string]interface{}) // Version of the Firebase Go Admin SDK. -const Version = "2.5.0" +const Version = "2.6.0" // firebaseEnvName is the name of the environment variable with the Config. const firebaseEnvName = "FIREBASE_CONFIG" From 04299faa819412cfd53a5a1226fdd9157fb59164 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 28 Feb 2018 18:07:34 -0800 Subject: [PATCH 18/27] Formatting (simplification) changes (#107) --- db/db_test.go | 2 +- db/query.go | 2 +- db/ref_test.go | 22 +++++++++++----------- firebase_test.go | 4 ++-- snippets/db.go | 4 ++-- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/db/db_test.go b/db/db_test.go index 01234504..e811c9df 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -107,7 +107,7 @@ func TestNewClient(t *testing.T) { func TestNewClientAuthOverrides(t *testing.T) { cases := []map[string]interface{}{ nil, - map[string]interface{}{"uid": "user1"}, + {"uid": "user1"}, } for _, tc := range cases { c, err := NewClient(context.Background(), &internal.DatabaseConfig{ diff --git a/db/query.go b/db/query.go index c6013483..ca377c12 100644 --- a/db/query.go +++ b/db/query.go @@ -306,7 +306,7 @@ func (q *queryNodeImpl) Key() string { if q.CompKey.Str != nil { return *q.CompKey.Str } - // Numeric keys in queryNodeImpl are always array indices, and can be safely coverted into int. + // Numeric keys in queryNodeImpl are always array indices, and can be safely converted into int. return strconv.Itoa(int(*q.CompKey.Num)) } diff --git a/db/ref_test.go b/db/ref_test.go index 93e348d0..8c489467 100644 --- a/db/ref_test.go +++ b/db/ref_test.go @@ -265,12 +265,12 @@ func TestGetIfChanged(t *testing.T) { } checkAllRequests(t, mock.Reqs, []*testReq{ - &testReq{ + { Method: "GET", Path: "/peter.json", Header: http.Header{"If-None-Match": []string{"old-etag"}}, }, - &testReq{ + { Method: "GET", Path: "/peter.json", Header: http.Header{"If-None-Match": []string{"new-etag"}}, @@ -513,7 +513,7 @@ func TestInvalidUpdate(t *testing.T) { cases := []map[string]interface{}{ nil, make(map[string]interface{}), - map[string]interface{}{"foo": func() {}}, + {"foo": func() {}}, } for _, tc := range cases { if err := testref.Update(context.Background(), tc); err == nil { @@ -542,12 +542,12 @@ func TestTransaction(t *testing.T) { t.Fatal(err) } checkAllRequests(t, mock.Reqs, []*testReq{ - &testReq{ + { Method: "GET", Path: "/peter.json", Header: http.Header{"X-Firebase-ETag": []string{"true"}}, }, - &testReq{ + { Method: "PUT", Path: "/peter.json", Body: serialize(map[string]interface{}{ @@ -591,12 +591,12 @@ func TestTransactionRetry(t *testing.T) { t.Errorf("Transaction() retries = %d; want = %d", cnt, 2) } checkAllRequests(t, mock.Reqs, []*testReq{ - &testReq{ + { Method: "GET", Path: "/peter.json", Header: http.Header{"X-Firebase-ETag": []string{"true"}}, }, - &testReq{ + { Method: "PUT", Path: "/peter.json", Body: serialize(map[string]interface{}{ @@ -605,7 +605,7 @@ func TestTransactionRetry(t *testing.T) { }), Header: http.Header{"If-Match": []string{"mock-etag1"}}, }, - &testReq{ + { Method: "PUT", Path: "/peter.json", Body: serialize(map[string]interface{}{ @@ -650,12 +650,12 @@ func TestTransactionError(t *testing.T) { t.Errorf("Transaction() retries = %d; want = %d", cnt, 1) } checkAllRequests(t, mock.Reqs, []*testReq{ - &testReq{ + { Method: "GET", Path: "/peter.json", Header: http.Header{"X-Firebase-ETag": []string{"true"}}, }, - &testReq{ + { Method: "PUT", Path: "/peter.json", Body: serialize(map[string]interface{}{ @@ -694,7 +694,7 @@ func TestTransactionAbort(t *testing.T) { t.Errorf("Transaction() = nil; want error") } wanted := []*testReq{ - &testReq{ + { Method: "GET", Path: "/peter.json", Header: http.Header{"X-Firebase-ETag": []string{"true"}}, diff --git a/firebase_test.go b/firebase_test.go index bbf1a4c3..4781b378 100644 --- a/firebase_test.go +++ b/firebase_test.go @@ -247,8 +247,8 @@ func TestDatabase(t *testing.T) { func TestDatabaseAuthOverrides(t *testing.T) { cases := []map[string]interface{}{ nil, - map[string]interface{}{}, - map[string]interface{}{"uid": "user1"}, + {}, + {"uid": "user1"}, } for _, tc := range cases { ctx := context.Background() diff --git a/snippets/db.go b/snippets/db.go index 8e0bea71..fc9bd852 100644 --- a/snippets/db.go +++ b/snippets/db.go @@ -153,11 +153,11 @@ func setValue(ctx context.Context, ref *db.Ref) { // [START set_value] usersRef := ref.Child("users") err := usersRef.Set(ctx, map[string]*User{ - "alanisawesome": &User{ + "alanisawesome": { DateOfBirth: "June 23, 1912", FullName: "Alan Turing", }, - "gracehop": &User{ + "gracehop": { DateOfBirth: "December 9, 1906", FullName: "Grace Hopper", }, From 07d34840308dab44630ae87b7024f2e552c2f2ba Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 1 Mar 2018 11:47:46 -0800 Subject: [PATCH 19/27] Checking for unformatted files in CI (#108) * Checking for unformatted files in CI * Adding newline at eof --- .travis.gofmt.sh | 6 ++++++ .travis.yml | 1 + 2 files changed, 7 insertions(+) create mode 100755 .travis.gofmt.sh diff --git a/.travis.gofmt.sh b/.travis.gofmt.sh new file mode 100755 index 00000000..e33451d7 --- /dev/null +++ b/.travis.gofmt.sh @@ -0,0 +1,6 @@ +#!/bin/bash +if [[ ! -z "$(gofmt -l -s .)" ]]; then + echo "Go code is not formatted:" + gofmt -d -s . + exit 1 +fi diff --git a/.travis.yml b/.travis.yml index 53f411b4..0ff8cd68 100644 --- a/.travis.yml +++ b/.travis.yml @@ -26,5 +26,6 @@ install: script: - golint -set_exit_status $(go list ./...) + - ./.travis.gofmt.sh - go test -v -race -test.short ./... # Run tests with the race detector. - go vet -v ./... # Run Go static analyzer. From 0bbfc6c600e848de0b2deaa4c18d063cf4de71ea Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 9 Mar 2018 14:09:43 -0800 Subject: [PATCH 20/27] Document Minimum Go Version (#111) --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 3ea293a2..b3ddcdf5 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,11 @@ Please refer to the [CONTRIBUTING page](./CONTRIBUTING.md) for more information about how you can contribute to this project. We welcome bug reports, feature requests, code review feedback, and also pull requests. +## Supported Go Versions + +We support Go v1.7 and higher. +[Continuous integration](https://travis-ci.org/firebase/firebase-admin-go) system +tests the code on Go v1.7 through v1.10. ## Documentation From cf5cb07c6d2702eaf8f8f434c0d3233f89b4981c Mon Sep 17 00:00:00 2001 From: Michal Jemala Date: Mon, 12 Mar 2018 19:22:25 +0100 Subject: [PATCH 21/27] Fix invalid endpoint URL for topic unsubscribe (#114) --- messaging/messaging.go | 2 +- messaging/messaging_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/messaging/messaging.go b/messaging/messaging.go index 231e7212..aaeeee5c 100644 --- a/messaging/messaging.go +++ b/messaging/messaging.go @@ -364,7 +364,7 @@ func (c *Client) UnsubscribeFromTopic(ctx context.Context, tokens []string, topi req := &iidRequest{ Topic: topic, Tokens: tokens, - op: iidSubscribe, + op: iidUnsubscribe, } return c.makeTopicManagementRequest(ctx, req) } diff --git a/messaging/messaging_test.go b/messaging/messaging_test.go index 79ef498b..4bf9d89c 100644 --- a/messaging/messaging_test.go +++ b/messaging/messaging_test.go @@ -730,7 +730,7 @@ func TestUnsubscribe(t *testing.T) { if err != nil { t.Fatal(err) } - checkIIDRequest(t, b, tr, iidSubscribe) + checkIIDRequest(t, b, tr, iidUnsubscribe) checkTopicMgtResponse(t, resp) } From 31b566df208880c45d5598e2c43c0087d7a46431 Mon Sep 17 00:00:00 2001 From: avishalom Date: Wed, 14 Mar 2018 15:05:16 -0400 Subject: [PATCH 22/27] Fix error message for missing user (#113) --- auth/user_mgt.go | 2 +- auth/user_mgt_test.go | 25 ++++++++++++++++--------- integration/auth/user_mgt_test.go | 5 ++++- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/auth/user_mgt.go b/auth/user_mgt.go index 551753ea..0c1af368 100644 --- a/auth/user_mgt.go +++ b/auth/user_mgt.go @@ -567,7 +567,7 @@ func (c *Client) getUser(ctx context.Context, request *identitytoolkit.Identityt return nil, err } if len(resp.Users) == 0 { - return nil, fmt.Errorf("cannot find user from params: %v", request) + return nil, fmt.Errorf("cannot find user given params: id:%v, phone:%v, email: %v", request.LocalId, request.PhoneNumber, request.Email) } eu, err := makeExportedUser(resp.Users[0]) diff --git a/auth/user_mgt_test.go b/auth/user_mgt_test.go index 3f298ed6..fedb0baf 100644 --- a/auth/user_mgt_test.go +++ b/auth/user_mgt_test.go @@ -148,17 +148,24 @@ func TestGetNonExistingUser(t *testing.T) { s := echoServer([]byte(resp), t) defer s.Close() - user, err := s.Client.GetUser(context.Background(), "ignored_id") - if user != nil || err == nil { - t.Errorf("GetUser(non-existing) = (%v, %v); want = (nil, error)", user, err) + want := "cannot find user given params: id:[%s], phone:[%s], email: [%s]" + + we := fmt.Sprintf(want, "id-nonexisting", "", "") + user, err := s.Client.GetUser(context.Background(), "id-nonexisting") + if user != nil || err == nil || err.Error() != we { + t.Errorf("GetUser(non-existing) = (%v, %q); want = (nil, %q)", user, err, we) } - user, err = s.Client.GetUserByEmail(context.Background(), "test@email.com") - if user != nil || err == nil { - t.Errorf("GetUserByEmail(non-existing) = (%v, %v); want = (nil, error)", user, err) + + we = fmt.Sprintf(want, "", "", "foo@bar.nonexisting") + user, err = s.Client.GetUserByEmail(context.Background(), "foo@bar.nonexisting") + if user != nil || err == nil || err.Error() != we { + t.Errorf("GetUserByEmail(non-existing) = (%v, %q); want = (nil, %q)", user, err, we) } - user, err = s.Client.GetUserByPhoneNumber(context.Background(), "+1234567890") - if user != nil || err == nil { - t.Errorf("GetUserPhoneNumber(non-existing) = (%v, %v); want = (nil, error)", user, err) + + we = fmt.Sprintf(want, "", "+12345678901", "") + user, err = s.Client.GetUserByPhoneNumber(context.Background(), "+12345678901") + if user != nil || err == nil || err.Error() != we { + t.Errorf("GetUserPhoneNumber(non-existing) = (%v, %q); want = (nil, %q)", user, err, we) } } diff --git a/integration/auth/user_mgt_test.go b/integration/auth/user_mgt_test.go index 4121d62c..6a5aeb04 100644 --- a/integration/auth/user_mgt_test.go +++ b/integration/auth/user_mgt_test.go @@ -40,6 +40,8 @@ func TestUserManagement(t *testing.T) { }{ {"Create test users", testCreateUsers}, {"Get user", testGetUser}, + {"Get user by phone", testGetUserByPhoneNumber}, + {"Get user by email", testGetUserByEmail}, {"Iterate users", testUserIterator}, {"Paged iteration", testPager}, {"Disable user account", testDisableUser}, @@ -96,7 +98,8 @@ func testCreateUsers(t *testing.T) { UID(uid). Email(uid + "email@test.com"). DisplayName("display_name"). - Password("password") + Password("password"). + PhoneNumber("+12223334444") if u, err = client.CreateUser(context.Background(), params); err != nil { t.Fatal(err) From a3ce7c86fc21259687ab73763849f1621fec5c9f Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 14 Mar 2018 21:52:16 -0700 Subject: [PATCH 23/27] Update CHANGELOG.md (#117) --- CHANGELOG.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5852ed27..7e6ddb29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,11 @@ # Unreleased -- +- [changed] Fixed a bug in the + [`UnsubscribeFromTopic()`](https://godoc.org/firebase.google.com/go/messaging#Client.UnsubscribeFromTopic) + function. +- [changed] Improved the error message returned by `GetUser()`, + `GetUserByEmail()` and `GetUserByPhoneNumber()` APIs in + [`auth`](https://godoc.org/firebase.google.com/go/auth) package. # v2.6.0 From 9bd56f912a10f295f0a05a08f533856c452f6bc8 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 15 Mar 2018 07:24:04 -0700 Subject: [PATCH 24/27] Removing unused member from auth.Client (#118) --- auth/auth.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 98822fef..dbb42fd2 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -62,7 +62,6 @@ type Token struct { // Client facilitates generating custom JWT tokens for Firebase clients, and verifying ID tokens issued // by Firebase backend services. type Client struct { - hc *internal.HTTPClient is *identitytoolkit.Service ks keySource projectID string @@ -123,7 +122,6 @@ func NewClient(ctx context.Context, c *internal.AuthConfig) (*Client, error) { } return &Client{ - hc: &internal.HTTPClient{Client: hc}, is: is, ks: newHTTPKeySource(googleCertURL, hc), projectID: c.ProjectID, From 0c849e992d639c1dec1df768f94670b75880ff28 Mon Sep 17 00:00:00 2001 From: Tyler Bui-Palsulich <26876514+tbpg@users.noreply.github.com> Date: Thu, 15 Mar 2018 14:30:30 -0400 Subject: [PATCH 25/27] Support Go 1.6 (#120) * all: use golang.org/x/net/context * internal: use ctxhttp to use /x/ context The 1.6 Request type doesn't have WithContext. * all: don't use subtests to keep 1.6 compatibility * integration: use float64 for fields with exp value Values like -7e+07 cannot be parsed into ints in Go 1.6. So, use floats instead. * integration/messaging: use t.Fatal not log.Fatal * travis: add 1.6.x * changelog: mention addition of 1.6 support * readme: mention go version support --- .travis.yml | 1 + CHANGELOG.md | 1 + README.md | 2 + auth/auth.go | 3 +- auth/auth_appengine.go | 2 +- auth/auth_std.go | 2 +- auth/auth_test.go | 3 +- auth/user_mgt.go | 3 +- auth/user_mgt_test.go | 3 +- firebase.go | 3 +- firebase_test.go | 32 ++++++-------- iid/iid.go | 3 +- iid/iid_test.go | 3 +- integration/auth/auth_test.go | 3 +- integration/auth/user_mgt_test.go | 10 ++--- integration/db/db_test.go | 4 +- integration/firestore/firestore_test.go | 3 +- integration/iid/iid_test.go | 3 +- integration/internal/internal.go | 3 +- integration/messaging/messaging_test.go | 5 ++- integration/storage/storage_test.go | 3 +- internal/http_client.go | 7 ++- internal/http_client_test.go | 3 +- messaging/messaging.go | 3 +- messaging/messaging_test.go | 57 +++++++++++-------------- snippets/auth.go | 3 +- snippets/db.go | 3 +- snippets/init.go | 3 +- snippets/messaging.go | 3 +- snippets/storage.go | 3 +- storage/storage.go | 3 +- storage/storage_test.go | 3 +- 32 files changed, 101 insertions(+), 85 deletions(-) diff --git a/.travis.yml b/.travis.yml index 0ff8cd68..ab496972 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,7 @@ language: go go: + - 1.6.x - 1.7.x - 1.8.x - 1.9.x diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e6ddb29..5a29b807 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ # Unreleased +- [added] Added support for Go 1.6. - [changed] Fixed a bug in the [`UnsubscribeFromTopic()`](https://godoc.org/firebase.google.com/go/messaging#Client.UnsubscribeFromTopic) function. diff --git a/README.md b/README.md index b3ddcdf5..364aab9c 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,8 @@ Admin Go SDK enables access to Firebase services from privileged environments (such as servers or cloud) in Go. Currently this SDK provides Firebase custom authentication support. +Go versions >= 1.6 are supported. + For more information, visit the [Firebase Admin SDK setup guide](https://firebase.google.com/docs/admin/setup/). diff --git a/auth/auth.go b/auth/auth.go index dbb42fd2..2576165e 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -16,7 +16,6 @@ package auth import ( - "context" "crypto/rsa" "crypto/x509" "encoding/json" @@ -25,6 +24,8 @@ import ( "fmt" "strings" + "golang.org/x/net/context" + "firebase.google.com/go/internal" "google.golang.org/api/identitytoolkit/v3" "google.golang.org/api/transport" diff --git a/auth/auth_appengine.go b/auth/auth_appengine.go index 5e05cdb1..351f61c1 100644 --- a/auth/auth_appengine.go +++ b/auth/auth_appengine.go @@ -17,7 +17,7 @@ package auth import ( - "context" + "golang.org/x/net/context" "google.golang.org/appengine" ) diff --git a/auth/auth_std.go b/auth/auth_std.go index f593a7cc..2055af38 100644 --- a/auth/auth_std.go +++ b/auth/auth_std.go @@ -16,7 +16,7 @@ package auth -import "context" +import "golang.org/x/net/context" func newSigner(ctx context.Context) (signer, error) { return serviceAcctSigner{}, nil diff --git a/auth/auth_test.go b/auth/auth_test.go index 2676f6c4..6aea0d3a 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -15,7 +15,6 @@ package auth import ( - "context" "encoding/json" "errors" "fmt" @@ -26,6 +25,8 @@ import ( "testing" "time" + "golang.org/x/net/context" + "golang.org/x/oauth2/google" "google.golang.org/api/option" diff --git a/auth/user_mgt.go b/auth/user_mgt.go index 0c1af368..e6d7386b 100644 --- a/auth/user_mgt.go +++ b/auth/user_mgt.go @@ -15,7 +15,6 @@ package auth import ( - "context" "encoding/json" "fmt" "net/http" @@ -24,6 +23,8 @@ import ( "strings" "time" + "golang.org/x/net/context" + "google.golang.org/api/identitytoolkit/v3" "google.golang.org/api/iterator" ) diff --git a/auth/user_mgt_test.go b/auth/user_mgt_test.go index fedb0baf..30072d4f 100644 --- a/auth/user_mgt_test.go +++ b/auth/user_mgt_test.go @@ -16,7 +16,6 @@ package auth import ( "bytes" - "context" "encoding/json" "fmt" "io/ioutil" @@ -27,6 +26,8 @@ import ( "testing" "time" + "golang.org/x/net/context" + "firebase.google.com/go/internal" "golang.org/x/oauth2" diff --git a/firebase.go b/firebase.go index 0e34c058..2e0291ec 100644 --- a/firebase.go +++ b/firebase.go @@ -18,12 +18,13 @@ package firebase import ( - "context" "encoding/json" "errors" "io/ioutil" "os" + "golang.org/x/net/context" + "cloud.google.com/go/firestore" "firebase.google.com/go/auth" diff --git a/firebase_test.go b/firebase_test.go index 4781b378..8ba2b762 100644 --- a/firebase_test.go +++ b/firebase_test.go @@ -15,8 +15,6 @@ package firebase import ( - "context" - "fmt" "io/ioutil" "log" "net/http" @@ -28,6 +26,8 @@ import ( "testing" "time" + "golang.org/x/net/context" + "golang.org/x/oauth2/google" "google.golang.org/api/transport" @@ -539,15 +539,13 @@ func TestAutoInit(t *testing.T) { defer reinstateEnv(credEnvVar, credOld) for _, test := range tests { - t.Run(fmt.Sprintf("NewApp(%s)", test.name), func(t *testing.T) { - overwriteEnv(firebaseEnvName, test.optionsConfig) - app, err := NewApp(context.Background(), test.initOptions) - if err != nil { - t.Error(err) - } else { - compareConfig(app, test.wantOptions, t) - } - }) + overwriteEnv(firebaseEnvName, test.optionsConfig) + app, err := NewApp(context.Background(), test.initOptions) + if err != nil { + t.Errorf("NewApp(%s): %v", test.name, err) + } else { + compareConfig(app, test.wantOptions, t) + } } } @@ -577,13 +575,11 @@ func TestAutoInitInvalidFiles(t *testing.T) { defer reinstateEnv(credEnvVar, credOld) for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - overwriteEnv(firebaseEnvName, test.filename) - _, err := NewApp(context.Background(), nil) - if err == nil || err.Error() != test.wantError { - t.Errorf("got error = %s; want = %s", err, test.wantError) - } - }) + overwriteEnv(firebaseEnvName, test.filename) + _, err := NewApp(context.Background(), nil) + if err == nil || err.Error() != test.wantError { + t.Errorf("%s got error = %s; want = %s", test.name, err, test.wantError) + } } } diff --git a/iid/iid.go b/iid/iid.go index b282db40..566833e6 100644 --- a/iid/iid.go +++ b/iid/iid.go @@ -16,11 +16,12 @@ package iid import ( - "context" "errors" "fmt" "net/http" + "golang.org/x/net/context" + "google.golang.org/api/transport" "firebase.google.com/go/internal" diff --git a/iid/iid_test.go b/iid/iid_test.go index b3e69638..4d0e33aa 100644 --- a/iid/iid_test.go +++ b/iid/iid_test.go @@ -15,12 +15,13 @@ package iid import ( - "context" "fmt" "net/http" "net/http/httptest" "testing" + "golang.org/x/net/context" + "google.golang.org/api/option" "firebase.google.com/go/internal" diff --git a/integration/auth/auth_test.go b/integration/auth/auth_test.go index 2b8d6bd6..b84ff126 100644 --- a/integration/auth/auth_test.go +++ b/integration/auth/auth_test.go @@ -17,7 +17,6 @@ package auth import ( "bytes" - "context" "encoding/json" "flag" "fmt" @@ -28,6 +27,8 @@ import ( "testing" "time" + "golang.org/x/net/context" + "firebase.google.com/go/auth" "firebase.google.com/go/integration/internal" ) diff --git a/integration/auth/user_mgt_test.go b/integration/auth/user_mgt_test.go index 6a5aeb04..dc9ceba0 100644 --- a/integration/auth/user_mgt_test.go +++ b/integration/auth/user_mgt_test.go @@ -16,12 +16,13 @@ package auth import ( - "context" "fmt" "reflect" "testing" "time" + "golang.org/x/net/context" + "google.golang.org/api/iterator" "firebase.google.com/go/auth" @@ -52,11 +53,10 @@ func TestUserManagement(t *testing.T) { {"Delete test users", testDeleteUsers}, } // The tests are meant to be run in sequence. A failure in creating the users - // should be fatal so non of the other tests run. However calling Fatal from a - // subtest does not prevent the other subtests from running, hence we check the - // success of each subtest before proceeding. + // should be fatal so none of the other tests run. for _, run := range orderedRuns { - if ok := t.Run(run.name, run.testFunc); !ok { + run.testFunc(t) + if t.Failed() { t.Fatalf("Failed run %v", run.name) } } diff --git a/integration/db/db_test.go b/integration/db/db_test.go index 0754d5bf..e5f61f04 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -690,11 +690,11 @@ func protectedRef(t *testing.T, p string) string { } type Dinosaur struct { - Appeared int `json:"appeared"` + Appeared float64 `json:"appeared"` Height float64 `json:"height"` Length float64 `json:"length"` Order string `json:"order"` - Vanished int `json:"vanished"` + Vanished float64 `json:"vanished"` Weight int `json:"weight"` Ratings Ratings `json:"ratings"` } diff --git a/integration/firestore/firestore_test.go b/integration/firestore/firestore_test.go index 1b861d92..c89e68e6 100644 --- a/integration/firestore/firestore_test.go +++ b/integration/firestore/firestore_test.go @@ -15,11 +15,12 @@ package firestore import ( - "context" "log" "reflect" "testing" + "golang.org/x/net/context" + "firebase.google.com/go/integration/internal" ) diff --git a/integration/iid/iid_test.go b/integration/iid/iid_test.go index 2b1b1c9c..82ecd0ac 100644 --- a/integration/iid/iid_test.go +++ b/integration/iid/iid_test.go @@ -16,12 +16,13 @@ package iid import ( - "context" "flag" "log" "os" "testing" + "golang.org/x/net/context" + "firebase.google.com/go/iid" "firebase.google.com/go/integration/internal" ) diff --git a/integration/internal/internal.go b/integration/internal/internal.go index a5cd7af6..497065eb 100644 --- a/integration/internal/internal.go +++ b/integration/internal/internal.go @@ -16,7 +16,6 @@ package internal import ( - "context" "encoding/json" "go/build" "io/ioutil" @@ -24,6 +23,8 @@ import ( "path/filepath" "strings" + "golang.org/x/net/context" + firebase "firebase.google.com/go" "firebase.google.com/go/internal" "google.golang.org/api/option" diff --git a/integration/messaging/messaging_test.go b/integration/messaging/messaging_test.go index 4b8ef6d7..10818ddd 100644 --- a/integration/messaging/messaging_test.go +++ b/integration/messaging/messaging_test.go @@ -15,13 +15,14 @@ package messaging import ( - "context" "flag" "log" "os" "regexp" "testing" + "golang.org/x/net/context" + "firebase.google.com/go/integration/internal" "firebase.google.com/go/messaging" ) @@ -88,7 +89,7 @@ func TestSend(t *testing.T) { } name, err := client.SendDryRun(context.Background(), msg) if err != nil { - log.Fatalln(err) + t.Fatal(err) } const pattern = "^projects/.*/messages/.*$" if !regexp.MustCompile(pattern).MatchString(name) { diff --git a/integration/storage/storage_test.go b/integration/storage/storage_test.go index b5a205d5..c47b29a8 100644 --- a/integration/storage/storage_test.go +++ b/integration/storage/storage_test.go @@ -15,7 +15,6 @@ package storage import ( - "context" "flag" "fmt" "io/ioutil" @@ -23,6 +22,8 @@ import ( "os" "testing" + "golang.org/x/net/context" + "firebase.google.com/go" gcs "cloud.google.com/go/storage" diff --git a/internal/http_client.go b/internal/http_client.go index 984e8a1d..7f3df67a 100644 --- a/internal/http_client.go +++ b/internal/http_client.go @@ -16,12 +16,15 @@ package internal import ( "bytes" - "context" "encoding/json" "fmt" "io" "io/ioutil" "net/http" + + "golang.org/x/net/context/ctxhttp" + + "golang.org/x/net/context" ) // HTTPClient is a convenient API to make HTTP calls. @@ -43,7 +46,7 @@ func (c *HTTPClient) Do(ctx context.Context, r *Request) (*Response, error) { return nil, err } - resp, err := c.Client.Do(req.WithContext(ctx)) + resp, err := ctxhttp.Do(ctx, c.Client, req) if err != nil { return nil, err } diff --git a/internal/http_client_test.go b/internal/http_client_test.go index 14729d17..bdac7474 100644 --- a/internal/http_client_test.go +++ b/internal/http_client_test.go @@ -14,13 +14,14 @@ package internal import ( - "context" "encoding/json" "io/ioutil" "net/http" "net/http/httptest" "reflect" "testing" + + "golang.org/x/net/context" ) var cases = []struct { diff --git a/messaging/messaging.go b/messaging/messaging.go index aaeeee5c..fb6fa5dd 100644 --- a/messaging/messaging.go +++ b/messaging/messaging.go @@ -17,7 +17,6 @@ package messaging import ( - "context" "encoding/json" "errors" "fmt" @@ -26,6 +25,8 @@ import ( "strings" "time" + "golang.org/x/net/context" + "firebase.google.com/go/internal" "google.golang.org/api/transport" ) diff --git a/messaging/messaging_test.go b/messaging/messaging_test.go index 4bf9d89c..1d6d3ad5 100644 --- a/messaging/messaging_test.go +++ b/messaging/messaging_test.go @@ -15,7 +15,6 @@ package messaging import ( - "context" "encoding/json" "io/ioutil" "net/http" @@ -25,6 +24,8 @@ import ( "testing" "time" + "golang.org/x/net/context" + "firebase.google.com/go/internal" "google.golang.org/api/option" ) @@ -565,13 +566,11 @@ func TestSend(t *testing.T) { client.fcmEndpoint = ts.URL for _, tc := range validMessages { - t.Run(tc.name, func(t *testing.T) { - name, err := client.Send(ctx, tc.req) - if name != testMessageID || err != nil { - t.Errorf("Send() = (%q, %v); want = (%q, nil)", name, err, testMessageID) - } - checkFCMRequest(t, b, tr, tc.want, false) - }) + name, err := client.Send(ctx, tc.req) + if name != testMessageID || err != nil { + t.Errorf("Send(%s) = (%q, %v); want = (%q, nil)", tc.name, name, err, testMessageID) + } + checkFCMRequest(t, b, tr, tc.want, false) } } @@ -594,13 +593,11 @@ func TestSendDryRun(t *testing.T) { client.fcmEndpoint = ts.URL for _, tc := range validMessages { - t.Run(tc.name, func(t *testing.T) { - name, err := client.SendDryRun(ctx, tc.req) - if name != testMessageID || err != nil { - t.Errorf("SendDryRun() = (%q, %v); want = (%q, nil)", name, err, testMessageID) - } - checkFCMRequest(t, b, tr, tc.want, true) - }) + name, err := client.SendDryRun(ctx, tc.req) + if name != testMessageID || err != nil { + t.Errorf("SendDryRun(%s) = (%q, %v); want = (%q, nil)", tc.name, name, err, testMessageID) + } + checkFCMRequest(t, b, tr, tc.want, true) } } @@ -657,12 +654,10 @@ func TestInvalidMessage(t *testing.T) { t.Fatal(err) } for _, tc := range invalidMessages { - t.Run(tc.name, func(t *testing.T) { - name, err := client.Send(ctx, tc.req) - if err == nil || err.Error() != tc.want { - t.Errorf("Send() = (%q, %v); want = (%q, %q)", name, err, "", tc.want) - } - }) + name, err := client.Send(ctx, tc.req) + if err == nil || err.Error() != tc.want { + t.Errorf("Send(%s) = (%q, %v); want = (%q, %q)", tc.name, name, err, "", tc.want) + } } } @@ -699,12 +694,10 @@ func TestInvalidSubscribe(t *testing.T) { t.Fatal(err) } for _, tc := range invalidTopicMgtArgs { - t.Run(tc.name, func(t *testing.T) { - name, err := client.SubscribeToTopic(ctx, tc.tokens, tc.topic) - if err == nil || err.Error() != tc.want { - t.Errorf("SubscribeToTopic() = (%q, %v); want = (%q, %q)", name, err, "", tc.want) - } - }) + name, err := client.SubscribeToTopic(ctx, tc.tokens, tc.topic) + if err == nil || err.Error() != tc.want { + t.Errorf("SubscribeToTopic(%s) = (%q, %v); want = (%q, %q)", tc.name, name, err, "", tc.want) + } } } @@ -741,12 +734,10 @@ func TestInvalidUnsubscribe(t *testing.T) { t.Fatal(err) } for _, tc := range invalidTopicMgtArgs { - t.Run(tc.name, func(t *testing.T) { - name, err := client.UnsubscribeFromTopic(ctx, tc.tokens, tc.topic) - if err == nil || err.Error() != tc.want { - t.Errorf("UnsubscribeFromTopic() = (%q, %v); want = (%q, %q)", name, err, "", tc.want) - } - }) + name, err := client.UnsubscribeFromTopic(ctx, tc.tokens, tc.topic) + if err == nil || err.Error() != tc.want { + t.Errorf("UnsubscribeFromTopic(%s) = (%q, %v); want = (%q, %q)", tc.name, name, err, "", tc.want) + } } } diff --git a/snippets/auth.go b/snippets/auth.go index 9fb739ba..b80fa418 100644 --- a/snippets/auth.go +++ b/snippets/auth.go @@ -15,9 +15,10 @@ package snippets import ( - "context" "log" + "golang.org/x/net/context" + firebase "firebase.google.com/go" "firebase.google.com/go/auth" "google.golang.org/api/iterator" diff --git a/snippets/db.go b/snippets/db.go index fc9bd852..c853e12a 100644 --- a/snippets/db.go +++ b/snippets/db.go @@ -16,10 +16,11 @@ package snippets // [START authenticate_db_imports] import ( - "context" "fmt" "log" + "golang.org/x/net/context" + "firebase.google.com/go/db" "firebase.google.com/go" diff --git a/snippets/init.go b/snippets/init.go index 8af16a12..2513ae9d 100644 --- a/snippets/init.go +++ b/snippets/init.go @@ -16,9 +16,10 @@ package snippets // [START admin_import_golang] import ( - "context" "log" + "golang.org/x/net/context" + firebase "firebase.google.com/go" "firebase.google.com/go/auth" diff --git a/snippets/messaging.go b/snippets/messaging.go index 18f6e462..b3ce67d4 100644 --- a/snippets/messaging.go +++ b/snippets/messaging.go @@ -15,11 +15,12 @@ package snippets import ( - "context" "fmt" "log" "time" + "golang.org/x/net/context" + "firebase.google.com/go" "firebase.google.com/go/messaging" ) diff --git a/snippets/storage.go b/snippets/storage.go index a39538a3..169d56e4 100644 --- a/snippets/storage.go +++ b/snippets/storage.go @@ -15,9 +15,10 @@ package snippets import ( - "context" "log" + "golang.org/x/net/context" + firebase "firebase.google.com/go" "google.golang.org/api/option" ) diff --git a/storage/storage.go b/storage/storage.go index 878e2175..dbcd1303 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -16,9 +16,10 @@ package storage import ( - "context" "errors" + "golang.org/x/net/context" + "cloud.google.com/go/storage" "firebase.google.com/go/internal" ) diff --git a/storage/storage_test.go b/storage/storage_test.go index 7a77e60c..eff97a1b 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -15,9 +15,10 @@ package storage import ( - "context" "testing" + "golang.org/x/net/context" + "google.golang.org/api/option" "firebase.google.com/go/internal" From eae7451fa8e2ccab552ddb08c69b7c033f539518 Mon Sep 17 00:00:00 2001 From: avishalom Date: Thu, 15 Mar 2018 15:03:01 -0400 Subject: [PATCH 26/27] Bumped version to 2.6.1 (#121) --- README.md | 8 +++----- firebase.go | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 364aab9c..e91d1188 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,6 @@ Admin Go SDK enables access to Firebase services from privileged environments (such as servers or cloud) in Go. Currently this SDK provides Firebase custom authentication support. -Go versions >= 1.6 are supported. - For more information, visit the [Firebase Admin SDK setup guide](https://firebase.google.com/docs/admin/setup/). @@ -39,13 +37,13 @@ go get firebase.google.com/go Please refer to the [CONTRIBUTING page](./CONTRIBUTING.md) for more information about how you can contribute to this project. We welcome bug reports, feature -requests, code review feedback, and also pull requests. +requests, code review feedback, and also pull requests. ## Supported Go Versions -We support Go v1.7 and higher. +We support Go v1.6 and higher. [Continuous integration](https://travis-ci.org/firebase/firebase-admin-go) system -tests the code on Go v1.7 through v1.10. +tests the code on Go v1.6 through v1.10. ## Documentation diff --git a/firebase.go b/firebase.go index 2e0291ec..c341dfce 100644 --- a/firebase.go +++ b/firebase.go @@ -42,7 +42,7 @@ import ( var defaultAuthOverrides = make(map[string]interface{}) // Version of the Firebase Go Admin SDK. -const Version = "2.6.0" +const Version = "2.6.1" // firebaseEnvName is the name of the environment variable with the Config. const firebaseEnvName = "FIREBASE_CONFIG" From 98f50a28993bd10db4accddc2568cb88c577b808 Mon Sep 17 00:00:00 2001 From: avishalom Date: Thu, 15 Mar 2018 16:19:51 -0400 Subject: [PATCH 27/27] Changlog updates (#123) --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a29b807..893bb9ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # Unreleased +# v2.6.1 + - [added] Added support for Go 1.6. - [changed] Fixed a bug in the [`UnsubscribeFromTopic()`](https://godoc.org/firebase.google.com/go/messaging#Client.UnsubscribeFromTopic)