From ef0a86342d993108a1eef2044b143cda25525734 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Sat, 21 Oct 2017 16:32:05 -0700 Subject: [PATCH 01/58] Experimental RTDB code --- db/db.go | 121 +++++++++++++++++++++++ db/db_test.go | 231 +++++++++++++++++++++++++++++++++++++++++++ internal/internal.go | 5 + 3 files changed, 357 insertions(+) create mode 100644 db/db.go create mode 100644 db/db_test.go diff --git a/db/db.go b/db/db.go new file mode 100644 index 00000000..eeac758c --- /dev/null +++ b/db/db.go @@ -0,0 +1,121 @@ +package db + +import ( + "fmt" + "net/http" + "strings" + + firebase "firebase.google.com/go" + "firebase.google.com/go/internal" + + "net/url" + + "io/ioutil" + + "encoding/json" + + "golang.org/x/net/context" + "google.golang.org/api/option" + "google.golang.org/api/transport" +) + +const invalidChars = "[].#$" + +var userAgent = fmt.Sprintf("Firebase/HTTP/%s/AdminGo", firebase.Version) + +type Client struct { + hc *http.Client + baseURL string +} + +func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) { + o := []option.ClientOption{option.WithUserAgent(userAgent)} + o = append(o, c.Opts...) + + hc, _, err := transport.NewHTTPClient(ctx, o...) + if err != nil { + return nil, err + } + if c.BaseURL == "" { + return nil, fmt.Errorf("database url not specified") + } + url, err := url.Parse(c.BaseURL) + if err != nil { + return nil, err + } else if url.Scheme != "https" { + return nil, fmt.Errorf("invalid database URL (incorrect scheme): %q", c.BaseURL) + } else if !strings.HasSuffix(url.Host, ".firebaseio.com") { + return nil, fmt.Errorf("invalid database URL (incorrest host): %q", c.BaseURL) + } + return &Client{ + hc: hc, + baseURL: fmt.Sprintf("https://%s", url.Host), + }, nil +} + +func (c *Client) NewRef(path string) (*Ref, error) { + if strings.ContainsAny(path, invalidChars) { + return nil, fmt.Errorf("path %q contains one or more invalid characters", path) + } + var segs []string + for _, s := range strings.Split(path, "/") { + if s != "" { + segs = append(segs, s) + } + } + + key := "" + if len(segs) > 0 { + key = segs[len(segs)-1] + } + + return &Ref{ + client: c, + segs: segs, + Key: key, + Path: "/" + strings.Join(segs, "/"), + }, nil +} + +func (c *Client) sendRequest(method string, path string) (*http.Response, error) { + url := fmt.Sprintf("%s%s%s", c.baseURL, path, ".json") + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return c.hc.Do(req) +} + +type Ref struct { + client *Client + segs []string + Key string + Path string +} + +func (r *Ref) Parent() *Ref { + l := len(r.segs) + if l > 0 { + path := strings.Join(r.segs[:l-1], "/") + parent, _ := r.client.NewRef(path) + return parent + } + return nil +} + +func (r *Ref) Get(v interface{}) error { + resp, err := r.client.sendRequest("GET", r.Path) + if err != nil { + return err + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + if err := json.Unmarshal(b, v); err != nil { + return err + } + return nil +} diff --git a/db/db_test.go b/db/db_test.go new file mode 100644 index 00000000..38143990 --- /dev/null +++ b/db/db_test.go @@ -0,0 +1,231 @@ +package db + +import ( + "net/http" + "net/http/httptest" + "testing" + + "golang.org/x/net/context" + "golang.org/x/oauth2" + + "encoding/json" + + "reflect" + + "firebase.google.com/go/internal" + "google.golang.org/api/option" +) + +const testURL = "https://test-db.firebaseio.com" + +var testOpts = []option.ClientOption{ + option.WithTokenSource(&mockTokenSource{"mock-token"}), +} + +func TestNewClient(t *testing.T) { + c, err := NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + BaseURL: testURL, + }) + if err != nil { + t.Fatal(err) + } + if c.baseURL != testURL { + t.Errorf("BaseURL = %q; want: %q", c.baseURL, testURL) + } else if c.hc == nil { + t.Errorf("http.Client = nil; want non-nil") + } +} + +func TestNewClientError(t *testing.T) { + cases := []string{ + "", + "foo", + "http://db.firebaseio.com", + "https://firebase.google.com", + } + for _, tc := range cases { + c, err := NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + BaseURL: tc, + }) + if c != nil || err == nil { + t.Errorf("NewClient() = (%v, %v); want = (nil, error)", c, err) + } + } +} + +func TestNewRef(t *testing.T) { + c := newTestClient(t) + cases := []struct { + Path string + WantPath string + WantKey string + }{ + {"", "/", ""}, + {"/", "/", ""}, + {"foo", "/foo", "foo"}, + {"/foo", "/foo", "foo"}, + {"foo/bar", "/foo/bar", "bar"}, + {"/foo/bar", "/foo/bar", "bar"}, + {"/foo/bar/", "/foo/bar", "bar"}, + } + for _, tc := range cases { + r, err := c.NewRef(tc.Path) + if err != nil { + t.Fatal(err) + } + if r.client == nil { + t.Errorf("Client = nil; want = %v", c) + } else if r.Path != tc.WantPath { + t.Errorf("Path = %q; want = %q", r.Path, tc.WantPath) + } else if r.Key != tc.WantKey { + t.Errorf("Key = %q; want = %q", r.Key, tc.WantKey) + } + } +} + +func TestParent(t *testing.T) { + c := newTestClient(t) + cases := []struct { + Path string + HasParent bool + Want string + }{ + {"", false, ""}, + {"/", false, ""}, + {"foo", true, ""}, + {"/foo", true, ""}, + {"foo/bar", true, "foo"}, + {"/foo/bar", true, "foo"}, + {"/foo/bar/", true, "foo"}, + } + for _, tc := range cases { + r, err := c.NewRef(tc.Path) + if err != nil { + t.Fatal(err) + } + + r = r.Parent() + if tc.HasParent { + if r == nil { + t.Fatalf("Parent = nil; want = %q", tc.Want) + } else if r.client == nil { + t.Errorf("Client = nil; want = %v", c) + } else if r.Key != tc.Want { + t.Errorf("Key = %q; want = %q", r.Key, tc.Want) + } + } else if r != nil { + t.Fatalf("Parent = %v; want = nil", r) + } + } +} + +func TestGet(t *testing.T) { + want := map[string]interface{}{ + "name": "Peter Parker", + "age": float64(17), + } + c := newTestClient(t) + mock, err := newMockServer(want) + if err != nil { + t.Fatal(err) + } + defer mock.Srv.Close() + c.baseURL = mock.Srv.URL + + ref, err := c.NewRef("peter") + if err != nil { + t.Fatal(err) + } + + var got map[string]interface{} + if err := ref.Get(&got); err != nil { + t.Fatal(err) + } else if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + checkRequests(t, mock.Req, 1) +} + +func TestGetWithStruct(t *testing.T) { + want := person{Name: "Peter Parker", Age: 17} + c := newTestClient(t) + mock, err := newMockServer(want) + if err != nil { + t.Fatal(err) + } + defer mock.Srv.Close() + c.baseURL = mock.Srv.URL + + ref, err := c.NewRef("peter") + if err != nil { + t.Fatal(err) + } + + var got person + if err := ref.Get(&got); err != nil { + t.Fatal(err) + } else if want != got { + t.Errorf("Get() = %v; want = %v", got, want) + } + checkRequests(t, mock.Req, 1) +} + +func newTestClient(t *testing.T) *Client { + c, err := NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + BaseURL: testURL, + }) + if err != nil { + t.Fatal(err) + } + return c +} + +func checkRequests(t *testing.T, req []*http.Request, num int) { + if len(req) != num { + t.Errorf("Request Count = %d; want = %d", len(req), num) + } + for _, r := range req { + if h := r.Header.Get("Authorization"); h != "Bearer mock-token" { + t.Errorf("Authorization = %q; want = %q", h, "Bearer mock-token") + } else if h := r.Header.Get("User-Agent"); h != userAgent { + t.Errorf("User-Agent = %q; want = %q", h, userAgent) + } + } +} + +func newMockServer(v interface{}) (*mockServer, error) { + b, err := json.Marshal(v) + if err != nil { + return nil, err + } + + mock := &mockServer{} + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mock.Req = append(mock.Req, r) + w.Header().Set("Content-Type", "application/json") + w.Write(b) + }) + mock.Srv = httptest.NewServer(handler) + return mock, nil +} + +type mockServer struct { + Req []*http.Request + Srv *httptest.Server +} + +type mockTokenSource struct { + AccessToken string +} + +type person struct { + Name string `json:"name"` + Age int32 `json:"age"` +} + +func (ts *mockTokenSource) Token() (*oauth2.Token, error) { + return &oauth2.Token{AccessToken: ts.AccessToken}, nil +} diff --git a/internal/internal.go b/internal/internal.go index 8513391a..bb8073e8 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -28,6 +28,11 @@ type AuthConfig struct { ProjectID string } +type DatabaseConfig struct { + Opts []option.ClientOption + BaseURL string +} + // StorageConfig represents the configuration of Google Cloud Storage service. type StorageConfig struct { Opts []option.ClientOption From becb82f4b84b7362b7709d5df36d9f184da665cb Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Sat, 21 Oct 2017 21:45:24 -0700 Subject: [PATCH 02/58] Added ref.Set() --- db/db.go | 103 +++++++++++++++++++++++++++++----- db/db_test.go | 151 ++++++++++++++++++++++++++++++++------------------ 2 files changed, 185 insertions(+), 69 deletions(-) diff --git a/db/db.go b/db/db.go index eeac758c..300eafd4 100644 --- a/db/db.go +++ b/db/db.go @@ -1,7 +1,23 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package db contains functions for accessing the Firebase Realtime Database. package db import ( "fmt" + "io" "net/http" "strings" @@ -14,6 +30,10 @@ import ( "encoding/json" + "runtime" + + "bytes" + "golang.org/x/net/context" "google.golang.org/api/option" "google.golang.org/api/transport" @@ -21,8 +41,9 @@ import ( const invalidChars = "[].#$" -var userAgent = fmt.Sprintf("Firebase/HTTP/%s/AdminGo", firebase.Version) +var userAgent = fmt.Sprintf("Firebase/HTTP/%s/%s/AdminGo", firebase.Version, runtime.Version()) +// Client is the interface for the Firebase Realtime Database service. type Client struct { hc *http.Client baseURL string @@ -70,27 +91,78 @@ func (c *Client) NewRef(path string) (*Ref, error) { } return &Ref{ - client: c, - segs: segs, Key: key, Path: "/" + strings.Join(segs, "/"), + client: c, + segs: segs, }, nil } -func (c *Client) sendRequest(method string, path string) (*http.Response, error) { - url := fmt.Sprintf("%s%s%s", c.baseURL, path, ".json") - req, err := http.NewRequest("GET", url, nil) +func (c *Client) send(r *request) (*response, error) { + url := fmt.Sprintf("%s%s%s", c.baseURL, r.Path, ".json") + + var data io.Reader + if r.Body != nil { + b, err := json.Marshal(r.Body) + if err != nil { + return nil, err + } + data = bytes.NewBuffer(b) + } + + req, err := http.NewRequest(r.Method, url, data) + if err != nil { + return nil, err + } + if data != nil { + req.Header.Add("Content-Type", "application/json") + } + resp, err := c.hc.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, err } - return c.hc.Do(req) + return &response{Status: resp.StatusCode, Body: b}, nil +} + +type request struct { + Method string + Path string + Body interface{} +} + +type response struct { + Status int + Body []byte +} + +func (r *response) CheckStatus(want int) error { + if r.Status != want { + return fmt.Errorf("http error: %d; body: %s", r.Status, string(r.Body)) + } + return nil +} + +func (r *response) CheckAndParse(want int, v interface{}) error { + if err := r.CheckStatus(want); err != nil { + return err + } else if err := json.Unmarshal(r.Body, v); err != nil { + return err + } + return nil } type Ref struct { + Key string + Path string + client *Client segs []string - Key string - Path string } func (r *Ref) Parent() *Ref { @@ -104,17 +176,20 @@ func (r *Ref) Parent() *Ref { } func (r *Ref) Get(v interface{}) error { - resp, err := r.client.sendRequest("GET", r.Path) + resp, err := r.client.send(&request{Method: "GET", Path: r.Path}) if err != nil { return err + } else if err := resp.CheckAndParse(http.StatusOK, v); err != nil { + return err } - defer resp.Body.Close() + return nil +} - b, err := ioutil.ReadAll(resp.Body) +func (r *Ref) Set(v interface{}) error { + resp, err := r.client.send(&request{Method: "PUT", Path: r.Path, Body: v}) if err != nil { return err - } - if err := json.Unmarshal(b, v); err != nil { + } else if err := resp.CheckStatus(http.StatusOK); err != nil { return err } return nil diff --git a/db/db_test.go b/db/db_test.go index 38143990..87c5b9a1 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -1,8 +1,10 @@ package db import ( + "log" "net/http" "net/http/httptest" + "os" "testing" "golang.org/x/net/context" @@ -22,6 +24,18 @@ var testOpts = []option.ClientOption{ option.WithTokenSource(&mockTokenSource{"mock-token"}), } +var client *Client + +func TestMain(m *testing.M) { + var err error + conf := &internal.DatabaseConfig{Opts: testOpts, BaseURL: testURL} + client, err = NewClient(context.Background(), conf) + if err != nil { + log.Fatalln(err) + } + os.Exit(m.Run()) +} + func TestNewClient(t *testing.T) { c, err := NewClient(context.Background(), &internal.DatabaseConfig{ Opts: testOpts, @@ -56,7 +70,6 @@ func TestNewClientError(t *testing.T) { } func TestNewRef(t *testing.T) { - c := newTestClient(t) cases := []struct { Path string WantPath string @@ -71,12 +84,12 @@ func TestNewRef(t *testing.T) { {"/foo/bar/", "/foo/bar", "bar"}, } for _, tc := range cases { - r, err := c.NewRef(tc.Path) + r, err := client.NewRef(tc.Path) if err != nil { t.Fatal(err) } if r.client == nil { - t.Errorf("Client = nil; want = %v", c) + t.Errorf("Client = nil; want = %v", client) } else if r.Path != tc.WantPath { t.Errorf("Path = %q; want = %q", r.Path, tc.WantPath) } else if r.Key != tc.WantKey { @@ -86,7 +99,6 @@ func TestNewRef(t *testing.T) { } func TestParent(t *testing.T) { - c := newTestClient(t) cases := []struct { Path string HasParent bool @@ -101,7 +113,7 @@ func TestParent(t *testing.T) { {"/foo/bar/", true, "foo"}, } for _, tc := range cases { - r, err := c.NewRef(tc.Path) + r, err := client.NewRef(tc.Path) if err != nil { t.Fatal(err) } @@ -111,7 +123,7 @@ func TestParent(t *testing.T) { if r == nil { t.Fatalf("Parent = nil; want = %q", tc.Want) } else if r.client == nil { - t.Errorf("Client = nil; want = %v", c) + t.Errorf("Client = nil; want = %v", client) } else if r.Key != tc.Want { t.Errorf("Key = %q; want = %q", r.Key, tc.Want) } @@ -122,19 +134,12 @@ func TestParent(t *testing.T) { } func TestGet(t *testing.T) { - want := map[string]interface{}{ - "name": "Peter Parker", - "age": float64(17), - } - c := newTestClient(t) - mock, err := newMockServer(want) - if err != nil { - t.Fatal(err) - } - defer mock.Srv.Close() - c.baseURL = mock.Srv.URL + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() - ref, err := c.NewRef("peter") + ref, err := client.NewRef("peter") if err != nil { t.Fatal(err) } @@ -145,20 +150,18 @@ func TestGet(t *testing.T) { } else if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } - checkRequests(t, mock.Req, 1) + + checkRequestDefaults(t, mock.Req, 1) + checkRequest(t, mock.Req[0], "GET", "/peter.json") } func TestGetWithStruct(t *testing.T) { want := person{Name: "Peter Parker", Age: 17} - c := newTestClient(t) - mock, err := newMockServer(want) - if err != nil { - t.Fatal(err) - } - defer mock.Srv.Close() - c.baseURL = mock.Srv.URL + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() - ref, err := c.NewRef("peter") + ref, err := client.NewRef("peter") if err != nil { t.Fatal(err) } @@ -169,63 +172,101 @@ func TestGetWithStruct(t *testing.T) { } else if want != got { t.Errorf("Get() = %v; want = %v", got, want) } - checkRequests(t, mock.Req, 1) + + checkRequestDefaults(t, mock.Req, 1) + checkRequest(t, mock.Req[0], "GET", "/peter.json") } -func newTestClient(t *testing.T) *Client { - c, err := NewClient(context.Background(), &internal.DatabaseConfig{ - Opts: testOpts, - BaseURL: testURL, - }) +func TestSet(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + ref, err := client.NewRef("peter") + if err != nil { + t.Fatal(err) + } + + if err := ref.Set(&want); err != nil { + t.Fatal(err) + } + + checkRequestDefaults(t, mock.Req, 1) + checkRequest(t, mock.Req[0], "PUT", "/peter.json") +} + +func TestSetWithStruct(t *testing.T) { + want := &person{"Peter Parker", 17} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + ref, err := client.NewRef("peter") if err != nil { t.Fatal(err) } - return c + + if err := ref.Set(&want); err != nil { + t.Fatal(err) + } + + checkRequestDefaults(t, mock.Req, 1) + checkRequest(t, mock.Req[0], "PUT", "/peter.json") } -func checkRequests(t *testing.T, req []*http.Request, num int) { +func checkRequestDefaults(t *testing.T, req []*http.Request, num int) { if len(req) != num { t.Errorf("Request Count = %d; want = %d", len(req), num) } for _, r := range req { if h := r.Header.Get("Authorization"); h != "Bearer mock-token" { t.Errorf("Authorization = %q; want = %q", h, "Bearer mock-token") - } else if h := r.Header.Get("User-Agent"); h != userAgent { + } + if h := r.Header.Get("User-Agent"); h != userAgent { t.Errorf("User-Agent = %q; want = %q", h, userAgent) } } } -func newMockServer(v interface{}) (*mockServer, error) { - b, err := json.Marshal(v) - if err != nil { - return nil, err +func checkRequest(t *testing.T, r *http.Request, method, url string) { + if r.Method != method { + t.Errorf("Method = %q; want = %q", r.Method, method) + } + if r.RequestURI != url { + t.Errorf("URL = %q; want = %q", r.RequestURI, url) } - - mock := &mockServer{} - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - mock.Req = append(mock.Req, r) - w.Header().Set("Content-Type", "application/json") - w.Write(b) - }) - mock.Srv = httptest.NewServer(handler) - return mock, nil } type mockServer struct { - Req []*http.Request - Srv *httptest.Server + Resp interface{} + Req []*http.Request + srv *httptest.Server +} + +func (s *mockServer) Start(c *Client) *httptest.Server { + if s.srv == nil { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.Req = append(s.Req, r) + b, _ := json.Marshal(s.Resp) + w.Header().Set("Content-Type", "application/json") + w.Write(b) + }) + s.srv = httptest.NewServer(handler) + c.baseURL = s.srv.URL + } + return s.srv } type mockTokenSource struct { AccessToken string } +func (ts *mockTokenSource) Token() (*oauth2.Token, error) { + return &oauth2.Token{AccessToken: ts.AccessToken}, nil +} + type person struct { Name string `json:"name"` Age int32 `json:"age"` } - -func (ts *mockTokenSource) Token() (*oauth2.Token, error) { - return &oauth2.Token{AccessToken: ts.AccessToken}, nil -} From 2d4e34276493d93dc14a1c89f87f70943c73f544 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Sat, 21 Oct 2017 23:39:19 -0700 Subject: [PATCH 03/58] Added Push(), Update(), Remove() and tests --- db/db.go | 133 ++++++++++++++++++++++------- db/db_test.go | 229 ++++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 288 insertions(+), 74 deletions(-) diff --git a/db/db.go b/db/db.go index 300eafd4..54887cc2 100644 --- a/db/db.go +++ b/db/db.go @@ -109,14 +109,19 @@ func (c *Client) send(r *request) (*response, error) { } data = bytes.NewBuffer(b) } - req, err := http.NewRequest(r.Method, url, data) if err != nil { return nil, err - } - if data != nil { + } else if data != nil { req.Header.Add("Content-Type", "application/json") } + + q := req.URL.Query() + for k, v := range r.Query { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + resp, err := c.hc.Do(req) if err != nil { return nil, err @@ -130,33 +135,6 @@ func (c *Client) send(r *request) (*response, error) { return &response{Status: resp.StatusCode, Body: b}, nil } -type request struct { - Method string - Path string - Body interface{} -} - -type response struct { - Status int - Body []byte -} - -func (r *response) CheckStatus(want int) error { - if r.Status != want { - return fmt.Errorf("http error: %d; body: %s", r.Status, string(r.Body)) - } - return nil -} - -func (r *response) CheckAndParse(want int, v interface{}) error { - if err := r.CheckStatus(want); err != nil { - return err - } else if err := json.Unmarshal(r.Body, v); err != nil { - return err - } - return nil -} - type Ref struct { Key string Path string @@ -175,6 +153,14 @@ func (r *Ref) Parent() *Ref { return nil } +func (r *Ref) Child(path string) (*Ref, error) { + if strings.HasPrefix(path, "/") { + return nil, fmt.Errorf("child path must not start with %q", "/") + } + fp := fmt.Sprintf("%s/%s", r.Path, path) + return r.client.NewRef(fp) +} + func (r *Ref) Get(v interface{}) error { resp, err := r.client.send(&request{Method: "GET", Path: r.Path}) if err != nil { @@ -186,7 +172,54 @@ func (r *Ref) Get(v interface{}) error { } func (r *Ref) Set(v interface{}) error { - resp, err := r.client.send(&request{Method: "PUT", Path: r.Path, Body: v}) + resp, err := r.client.send(&request{ + Method: "PUT", + Path: r.Path, + Body: v, + Query: map[string]string{"print": "silent"}, + }) + if err != nil { + return err + } else if err := resp.CheckStatus(http.StatusNoContent); err != nil { + return err + } + return nil +} + +func (r *Ref) Push(v interface{}) (*Ref, error) { + resp, err := r.client.send(&request{Method: "POST", Path: r.Path, Body: v}) + if err != nil { + return nil, err + } + var d struct { + Name string `json:"name"` + } + if err := resp.CheckAndParse(http.StatusOK, &d); err != nil { + return nil, err + } + return r.Child(d.Name) +} + +func (r *Ref) Update(v map[string]interface{}) error { + if len(v) == 0 { + return fmt.Errorf("value argument must be a non-empty map") + } + resp, err := r.client.send(&request{ + Method: "PATCH", + Path: r.Path, + Body: v, + Query: map[string]string{"print": "silent"}, + }) + if err != nil { + return err + } else if err := resp.CheckStatus(http.StatusNoContent); err != nil { + return err + } + return nil +} + +func (r *Ref) Remove() error { + resp, err := r.client.send(&request{Method: "DELETE", Path: r.Path}) if err != nil { return err } else if err := resp.CheckStatus(http.StatusOK); err != nil { @@ -194,3 +227,41 @@ func (r *Ref) Set(v interface{}) error { } return nil } + +type request struct { + Method string + Path string + Body interface{} + Query map[string]string +} + +type response struct { + Status int + Body []byte +} + +func (r *response) CheckStatus(want int) error { + if r.Status == want { + return nil + } + var b struct { + Error string `json:"error"` + } + json.Unmarshal(r.Body, &b) + var msg string + if b.Error != "" { + msg = fmt.Sprintf("http error status: %d; reason: %s", r.Status, b.Error) + } else { + msg = fmt.Sprintf("http error status: %d; message: %s", r.Status, string(r.Body)) + } + return fmt.Errorf(msg) +} + +func (r *response) CheckAndParse(want int, v interface{}) error { + if err := r.CheckStatus(want); err != nil { + return err + } else if err := json.Unmarshal(r.Body, v); err != nil { + return err + } + return nil +} diff --git a/db/db_test.go b/db/db_test.go index 87c5b9a1..92865f40 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -14,6 +14,8 @@ import ( "reflect" + "io/ioutil" + "firebase.google.com/go/internal" "google.golang.org/api/option" ) @@ -25,6 +27,7 @@ var testOpts = []option.ClientOption{ } var client *Client +var ref *Ref func TestMain(m *testing.M) { var err error @@ -33,6 +36,11 @@ func TestMain(m *testing.M) { if err != nil { log.Fatalln(err) } + + ref, err = client.NewRef("peter") + if err != nil { + log.Fatalln(err) + } os.Exit(m.Run()) } @@ -139,20 +147,13 @@ func TestGet(t *testing.T) { srv := mock.Start(client) defer srv.Close() - ref, err := client.NewRef("peter") - if err != nil { - t.Fatal(err) - } - var got map[string]interface{} if err := ref.Get(&got); err != nil { t.Fatal(err) } else if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } - - checkRequestDefaults(t, mock.Req, 1) - checkRequest(t, mock.Req[0], "GET", "/peter.json") + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) } func TestGetWithStruct(t *testing.T) { @@ -161,93 +162,235 @@ func TestGetWithStruct(t *testing.T) { srv := mock.Start(client) defer srv.Close() - ref, err := client.NewRef("peter") - if err != nil { - t.Fatal(err) - } - var got person if err := ref.Get(&got); err != nil { t.Fatal(err) } else if want != got { t.Errorf("Get() = %v; want = %v", got, want) } + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) +} + +func TestWerlformedHttpError(t *testing.T) { + mock := &mockServer{Resp: map[string]string{"error": "test error"}, Status: 500} + srv := mock.Start(client) + defer srv.Close() + + var got person + err := ref.Get(&got) + want := "http error status: 500; reason: test error" + if err == nil || err.Error() != want { + t.Errorf("Get() = %v; want = %v", err, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) +} + +func TestUnexpectedHttpError(t *testing.T) { + mock := &mockServer{Resp: "unexpected error", Status: 500} + srv := mock.Start(client) + defer srv.Close() - checkRequestDefaults(t, mock.Req, 1) - checkRequest(t, mock.Req[0], "GET", "/peter.json") + var got person + err := ref.Get(&got) + want := "http error status: 500; message: \"unexpected error\"" + if err == nil || err.Error() != want { + t.Errorf("Get() = %v; want = %v", err, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) } func TestSet(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} - mock := &mockServer{Resp: want} + if err := ref.Set(want); err != nil { + t.Fatal(err) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PUT", + Path: "/peter.json?print=silent", + Body: serialize(want), + }) +} + +func TestSetWithStruct(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + want := &person{"Peter Parker", 17} + if err := ref.Set(&want); err != nil { + t.Fatal(err) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PUT", + Path: "/peter.json?print=silent", + Body: serialize(want), + }) +} + +func TestPush(t *testing.T) { + mock := &mockServer{Resp: map[string]string{"name": "new_key"}} srv := mock.Start(client) defer srv.Close() - ref, err := client.NewRef("peter") + child, err := ref.Push(nil) if err != nil { t.Fatal(err) } - if err := ref.Set(&want); err != nil { + if child.Key != "new_key" { + t.Errorf("Push() = %q; want = %q", child.Key, "new_key") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "POST", + Path: "/peter.json", + }) +} + +func TestPushWithValue(t *testing.T) { + mock := &mockServer{Resp: map[string]string{"name": "new_key"}} + srv := mock.Start(client) + defer srv.Close() + + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + child, err := ref.Push(want) + if err != nil { t.Fatal(err) } - checkRequestDefaults(t, mock.Req, 1) - checkRequest(t, mock.Req[0], "PUT", "/peter.json") + if child.Key != "new_key" { + t.Errorf("Push() = %q; want = %q", child.Key, "new_key") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "POST", + Path: "/peter.json", + Body: serialize(want), + }) } -func TestSetWithStruct(t *testing.T) { - want := &person{"Peter Parker", 17} +func TestUpdate(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() - ref, err := client.NewRef("peter") - if err != nil { + if err := ref.Update(want); err != nil { t.Fatal(err) } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PATCH", + Path: "/peter.json?print=silent", + Body: serialize(want), + }) +} - if err := ref.Set(&want); err != nil { +func TestInvalidUpdate(t *testing.T) { + if err := ref.Update(nil); err == nil { + t.Errorf("Update(nil) = nil; want error") + } + + m := make(map[string]interface{}) + if err := ref.Update(m); err == nil { + t.Errorf("Update(map{}) = nil; want error") + } +} + +func TestRemove(t *testing.T) { + mock := &mockServer{Resp: "null"} + srv := mock.Start(client) + defer srv.Close() + + if err := ref.Remove(); err != nil { t.Fatal(err) } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "DELETE", + Path: "/peter.json", + }) +} - checkRequestDefaults(t, mock.Req, 1) - checkRequest(t, mock.Req[0], "PUT", "/peter.json") +func checkOnlyRequest(t *testing.T, got []*testReq, want *testReq) { + checkAllRequests(t, got, []*testReq{want}) } -func checkRequestDefaults(t *testing.T, req []*http.Request, num int) { - if len(req) != num { - t.Errorf("Request Count = %d; want = %d", len(req), num) +func checkAllRequests(t *testing.T, got []*testReq, want []*testReq) { + if len(got) != len(want) { + t.Errorf("Request Count = %d; want = %d", len(got), len(want)) } - for _, r := range req { + for i, r := range got { if h := r.Header.Get("Authorization"); h != "Bearer mock-token" { t.Errorf("Authorization = %q; want = %q", h, "Bearer mock-token") } if h := r.Header.Get("User-Agent"); h != userAgent { t.Errorf("User-Agent = %q; want = %q", h, userAgent) } - } -} -func checkRequest(t *testing.T, r *http.Request, method, url string) { - if r.Method != method { - t.Errorf("Method = %q; want = %q", r.Method, method) - } - if r.RequestURI != url { - t.Errorf("URL = %q; want = %q", r.RequestURI, url) + w := want[i] + if r.Method != w.Method { + t.Errorf("Method = %q; want = %q", r.Method, w.Method) + } + if r.Path != w.Path { + t.Errorf("URL = %q; want = %q", r.Path, w.Path) + } + if w.Body != nil { + if !reflect.DeepEqual(r.Body, w.Body) { + t.Errorf("Body = %v; want = %v", string(r.Body), string(w.Body)) + } + } else if len(r.Body) != 0 { + t.Errorf("Body = %v; want empty", r.Body) + } } } type mockServer struct { - Resp interface{} - Req []*http.Request - srv *httptest.Server + Resp interface{} + Status int + Reqs []*testReq + srv *httptest.Server +} + +type testReq struct { + Method string + Path string + Header http.Header + Body []byte +} + +func serialize(v interface{}) []byte { + b, _ := json.Marshal(v) + return b +} + +func newTestReq(r *http.Request) (*testReq, error) { + defer r.Body.Close() + b, err := ioutil.ReadAll(r.Body) + if err != nil { + return nil, err + } + return &testReq{ + Method: r.Method, + Path: r.RequestURI, + Header: r.Header, + Body: b, + }, nil } func (s *mockServer) Start(c *Client) *httptest.Server { if s.srv == nil { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - s.Req = append(s.Req, r) + tr, _ := newTestReq(r) + s.Reqs = append(s.Reqs, tr) + + print := r.URL.Query().Get("print") + if s.Status != 0 { + w.WriteHeader(s.Status) + } else if print == "silent" { + w.WriteHeader(http.StatusNoContent) + return + } b, _ := json.Marshal(s.Resp) w.Header().Set("Content-Type", "application/json") w.Write(b) From 6bf42cfa92ae71b1867848dec4fb7e9ca93f1523 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Sun, 22 Oct 2017 03:27:29 -0700 Subject: [PATCH 04/58] Adding Transaction() support --- db/db.go | 114 +++------------------ db/db_test.go | 267 ++++++++++++++++++++++++++++++++++++++++++-------- db/ref.go | 178 +++++++++++++++++++++++++++++++++ 3 files changed, 420 insertions(+), 139 deletions(-) create mode 100644 db/ref.go diff --git a/db/db.go b/db/db.go index 54887cc2..24666dab 100644 --- a/db/db.go +++ b/db/db.go @@ -101,21 +101,26 @@ func (c *Client) NewRef(path string) (*Ref, error) { func (c *Client) send(r *request) (*response, error) { url := fmt.Sprintf("%s%s%s", c.baseURL, r.Path, ".json") - var data io.Reader + var body io.Reader if r.Body != nil { b, err := json.Marshal(r.Body) if err != nil { return nil, err } - data = bytes.NewBuffer(b) + body = bytes.NewBuffer(b) } - req, err := http.NewRequest(r.Method, url, data) + + req, err := http.NewRequest(r.Method, url, body) if err != nil { return nil, err - } else if data != nil { + } else if body != nil { req.Header.Add("Content-Type", "application/json") } + for k, v := range r.Header { + req.Header.Add(k, v) + } + q := req.URL.Query() for k, v := range r.Query { q.Add(k, v) @@ -132,100 +137,11 @@ func (c *Client) send(r *request) (*response, error) { if err != nil { return nil, err } - return &response{Status: resp.StatusCode, Body: b}, nil -} - -type Ref struct { - Key string - Path string - - client *Client - segs []string -} - -func (r *Ref) Parent() *Ref { - l := len(r.segs) - if l > 0 { - path := strings.Join(r.segs[:l-1], "/") - parent, _ := r.client.NewRef(path) - return parent - } - return nil -} - -func (r *Ref) Child(path string) (*Ref, error) { - if strings.HasPrefix(path, "/") { - return nil, fmt.Errorf("child path must not start with %q", "/") - } - fp := fmt.Sprintf("%s/%s", r.Path, path) - return r.client.NewRef(fp) -} - -func (r *Ref) Get(v interface{}) error { - resp, err := r.client.send(&request{Method: "GET", Path: r.Path}) - if err != nil { - return err - } else if err := resp.CheckAndParse(http.StatusOK, v); err != nil { - return err - } - return nil -} - -func (r *Ref) Set(v interface{}) error { - resp, err := r.client.send(&request{ - Method: "PUT", - Path: r.Path, - Body: v, - Query: map[string]string{"print": "silent"}, - }) - if err != nil { - return err - } else if err := resp.CheckStatus(http.StatusNoContent); err != nil { - return err - } - return nil -} - -func (r *Ref) Push(v interface{}) (*Ref, error) { - resp, err := r.client.send(&request{Method: "POST", Path: r.Path, Body: v}) - if err != nil { - return nil, err - } - var d struct { - Name string `json:"name"` - } - if err := resp.CheckAndParse(http.StatusOK, &d); err != nil { - return nil, err - } - return r.Child(d.Name) -} - -func (r *Ref) Update(v map[string]interface{}) error { - if len(v) == 0 { - return fmt.Errorf("value argument must be a non-empty map") - } - resp, err := r.client.send(&request{ - Method: "PATCH", - Path: r.Path, - Body: v, - Query: map[string]string{"print": "silent"}, - }) - if err != nil { - return err - } else if err := resp.CheckStatus(http.StatusNoContent); err != nil { - return err - } - return nil -} - -func (r *Ref) Remove() error { - resp, err := r.client.send(&request{Method: "DELETE", Path: r.Path}) - if err != nil { - return err - } else if err := resp.CheckStatus(http.StatusOK); err != nil { - return err - } - return nil + return &response{ + Status: resp.StatusCode, + Body: b, + Header: resp.Header, + }, nil } type request struct { @@ -233,10 +149,12 @@ type request struct { Path string Body interface{} Query map[string]string + Header map[string]string } type response struct { Status int + Header http.Header Body []byte } diff --git a/db/db_test.go b/db/db_test.go index 92865f40..cdddf1f0 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -98,9 +98,11 @@ func TestNewRef(t *testing.T) { } if r.client == nil { t.Errorf("Client = nil; want = %v", client) - } else if r.Path != tc.WantPath { + } + if r.Path != tc.WantPath { t.Errorf("Path = %q; want = %q", r.Path, tc.WantPath) - } else if r.Key != tc.WantKey { + } + if r.Key != tc.WantKey { t.Errorf("Key = %q; want = %q", r.Key, tc.WantKey) } } @@ -130,9 +132,11 @@ func TestParent(t *testing.T) { if tc.HasParent { if r == nil { t.Fatalf("Parent = nil; want = %q", tc.Want) - } else if r.client == nil { + } + if r.client == nil { t.Errorf("Client = nil; want = %v", client) - } else if r.Key != tc.Want { + } + if r.Key != tc.Want { t.Errorf("Key = %q; want = %q", r.Key, tc.Want) } } else if r != nil { @@ -150,7 +154,8 @@ func TestGet(t *testing.T) { var got map[string]interface{} if err := ref.Get(&got); err != nil { t.Fatal(err) - } else if !reflect.DeepEqual(want, got) { + } + if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) @@ -165,12 +170,40 @@ func TestGetWithStruct(t *testing.T) { var got person if err := ref.Get(&got); err != nil { t.Fatal(err) - } else if want != got { + } + if want != got { t.Errorf("Get() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) } +func TestGetWithETag(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{ + Resp: want, + Header: map[string]string{"ETag": "mock-etag"}, + } + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + etag, err := ref.GetWithETag(&got) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + if etag != "mock-etag" { + t.Errorf("ETag = %q; want = %q", etag, "mock-etag") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }) +} + func TestWerlformedHttpError(t *testing.T) { mock := &mockServer{Resp: map[string]string{"error": "test error"}, Status: 500} srv := mock.Start(client) @@ -231,6 +264,51 @@ func TestSetWithStruct(t *testing.T) { }) } +func TestSetIfUnchanged(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + want := &person{"Peter Parker", 17} + ok, err := ref.SetIfUnchanged("mock-etag", &want) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Errorf("SetIfUnchanged() = %v; want = %v", ok, true) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(want), + Header: http.Header{"If-Match": []string{"mock-etag"}}, + }) +} + +func TestSetIfUnchangedError(t *testing.T) { + mock := &mockServer{ + Status: http.StatusPreconditionFailed, + Resp: &person{"Tony Stark", 39}, + } + srv := mock.Start(client) + defer srv.Close() + + want := &person{"Peter Parker", 17} + ok, err := ref.SetIfUnchanged("mock-etag", &want) + if err != nil { + t.Fatal(err) + } + if ok { + t.Errorf("SetIfUnchanged() = %v; want = %v", ok, false) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(want), + Header: http.Header{"If-Match": []string{"mock-etag"}}, + }) +} + func TestPush(t *testing.T) { mock := &mockServer{Resp: map[string]string{"name": "new_key"}} srv := mock.Start(client) @@ -298,6 +376,92 @@ func TestInvalidUpdate(t *testing.T) { } } +func TestTransaction(t *testing.T) { + mock := &mockServer{ + Resp: &person{"Peter Parker", 17}, + Header: map[string]string{"ETag": "mock-etag"}, + } + srv := mock.Start(client) + defer srv.Close() + + var fn UpdateFn = func(i interface{}) (interface{}, error) { + p := i.(map[string]interface{}) + p["age"] = p["age"].(float64) + 1.0 + return p, nil + } + if err := ref.Transaction(fn); err != nil { + t.Fatal(err) + } + checkAllRequests(t, mock.Reqs, []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }, + &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 18, + }), + Header: http.Header{"If-Match": []string{"mock-etag"}}, + }, + }) +} + +func TestTransactionRetry(t *testing.T) { + mock := &mockServer{ + Resp: &person{"Peter Parker", 17}, + Header: map[string]string{"ETag": "mock-etag1"}, + } + srv := mock.Start(client) + defer srv.Close() + + cnt := 0 + var fn UpdateFn = func(i interface{}) (interface{}, error) { + if cnt == 0 { + mock.Status = http.StatusPreconditionFailed + mock.Header = map[string]string{"ETag": "mock-etag2"} + mock.Resp = &person{"Peter Parker", 19} + } else if cnt == 1 { + mock.Status = http.StatusOK + } + cnt++ + p := i.(map[string]interface{}) + p["age"] = p["age"].(float64) + 1.0 + return p, nil + } + if err := ref.Transaction(fn); err != nil { + t.Fatal(err) + } + checkAllRequests(t, mock.Reqs, []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }, + &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 18, + }), + Header: http.Header{"If-Match": []string{"mock-etag1"}}, + }, + &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 20, + }), + Header: http.Header{"If-Match": []string{"mock-etag2"}}, + }, + }) +} + func TestRemove(t *testing.T) { mock := &mockServer{Resp: "null"} srv := mock.Start(client) @@ -321,32 +485,48 @@ func checkAllRequests(t *testing.T, got []*testReq, want []*testReq) { t.Errorf("Request Count = %d; want = %d", len(got), len(want)) } for i, r := range got { - if h := r.Header.Get("Authorization"); h != "Bearer mock-token" { - t.Errorf("Authorization = %q; want = %q", h, "Bearer mock-token") - } - if h := r.Header.Get("User-Agent"); h != userAgent { - t.Errorf("User-Agent = %q; want = %q", h, userAgent) - } + checkRequest(t, r, want[i]) + } +} + +func checkRequest(t *testing.T, got *testReq, want *testReq) { + if h := got.Header.Get("Authorization"); h != "Bearer mock-token" { + t.Errorf("Authorization = %q; want = %q", h, "Bearer mock-token") + } + if h := got.Header.Get("User-Agent"); h != userAgent { + t.Errorf("User-Agent = %q; want = %q", h, userAgent) + } - w := want[i] - if r.Method != w.Method { - t.Errorf("Method = %q; want = %q", r.Method, w.Method) + if got.Method != want.Method { + t.Errorf("Method = %q; want = %q", got.Method, want.Method) + } + if got.Path != want.Path { + t.Errorf("URL = %q; want = %q", got.Path, want.Path) + } + for k, v := range want.Header { + if got.Header.Get(k) != v[0] { + t.Errorf("Header(%q) = %q; want = %q", k, got.Header.Get(k), v[0]) } - if r.Path != w.Path { - t.Errorf("URL = %q; want = %q", r.Path, w.Path) + } + if want.Body != nil { + var wi, gi interface{} + if err := json.Unmarshal(want.Body, &wi); err != nil { + t.Fatal(err) } - if w.Body != nil { - if !reflect.DeepEqual(r.Body, w.Body) { - t.Errorf("Body = %v; want = %v", string(r.Body), string(w.Body)) - } - } else if len(r.Body) != 0 { - t.Errorf("Body = %v; want empty", r.Body) + if err := json.Unmarshal(got.Body, &gi); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(gi, wi) { + t.Errorf("Body = %v; want = %v", gi, wi) } + } else if len(got.Body) != 0 { + t.Errorf("Body = %v; want empty", got.Body) } } type mockServer struct { Resp interface{} + Header map[string]string Status int Reqs []*testReq srv *httptest.Server @@ -379,25 +559,30 @@ func newTestReq(r *http.Request) (*testReq, error) { } func (s *mockServer) Start(c *Client) *httptest.Server { - if s.srv == nil { - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tr, _ := newTestReq(r) - s.Reqs = append(s.Reqs, tr) - - print := r.URL.Query().Get("print") - if s.Status != 0 { - w.WriteHeader(s.Status) - } else if print == "silent" { - w.WriteHeader(http.StatusNoContent) - return - } - b, _ := json.Marshal(s.Resp) - w.Header().Set("Content-Type", "application/json") - w.Write(b) - }) - s.srv = httptest.NewServer(handler) - c.baseURL = s.srv.URL + if s.srv != nil { + return s.srv } + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tr, _ := newTestReq(r) + s.Reqs = append(s.Reqs, tr) + + for k, v := range s.Header { + w.Header().Set(k, v) + } + + print := r.URL.Query().Get("print") + if s.Status != 0 { + w.WriteHeader(s.Status) + } else if print == "silent" { + w.WriteHeader(http.StatusNoContent) + return + } + b, _ := json.Marshal(s.Resp) + w.Header().Set("Content-Type", "application/json") + w.Write(b) + }) + s.srv = httptest.NewServer(handler) + c.baseURL = s.srv.URL return s.srv } diff --git a/db/ref.go b/db/ref.go new file mode 100644 index 00000000..67cd47a3 --- /dev/null +++ b/db/ref.go @@ -0,0 +1,178 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" +) + +type Ref struct { + Key string + Path string + + client *Client + segs []string +} + +func (r *Ref) Parent() *Ref { + l := len(r.segs) + if l > 0 { + path := strings.Join(r.segs[:l-1], "/") + parent, _ := r.client.NewRef(path) + return parent + } + return nil +} + +func (r *Ref) Child(path string) (*Ref, error) { + if strings.HasPrefix(path, "/") { + return nil, fmt.Errorf("child path must not start with %q", "/") + } + fp := fmt.Sprintf("%s/%s", r.Path, path) + return r.client.NewRef(fp) +} + +func (r *Ref) Get(v interface{}) error { + resp, err := r.client.send(&request{Method: "GET", Path: r.Path}) + if err != nil { + return err + } + return resp.CheckAndParse(http.StatusOK, v) +} + +func (r *Ref) GetWithETag(v interface{}) (string, error) { + resp, err := r.client.send(&request{ + Method: "GET", + Path: r.Path, + Header: map[string]string{"X-Firebase-ETag": "true"}, + }) + if err != nil { + return "", err + } else if err := resp.CheckAndParse(http.StatusOK, v); err != nil { + return "", err + } + return resp.Header.Get("Etag"), nil +} + +func (r *Ref) Set(v interface{}) error { + resp, err := r.client.send(&request{ + Method: "PUT", + Path: r.Path, + Body: v, + Query: map[string]string{"print": "silent"}, + }) + if err != nil { + return err + } + return resp.CheckStatus(http.StatusNoContent) +} + +func (r *Ref) SetIfUnchanged(etag string, v interface{}) (bool, error) { + ok, _, err := r.compareAndSet(etag, v) + return ok, err +} + +func (r *Ref) Push(v interface{}) (*Ref, error) { + resp, err := r.client.send(&request{ + Method: "POST", + Path: r.Path, + Body: v, + }) + if err != nil { + return nil, err + } + var d struct { + Name string `json:"name"` + } + if err := resp.CheckAndParse(http.StatusOK, &d); err != nil { + return nil, err + } + return r.Child(d.Name) +} + +func (r *Ref) Update(v map[string]interface{}) error { + if len(v) == 0 { + return fmt.Errorf("value argument must be a non-empty map") + } + resp, err := r.client.send(&request{ + Method: "PATCH", + Path: r.Path, + Body: v, + Query: map[string]string{"print": "silent"}, + }) + if err != nil { + return err + } + return resp.CheckStatus(http.StatusNoContent) +} + +type UpdateFn func(interface{}) (interface{}, error) + +func (r *Ref) Transaction(fn UpdateFn) error { + var curr interface{} + etag, err := r.GetWithETag(&curr) + if err != nil { + return err + } + + for i := 0; i < 20; i++ { + new, err := fn(curr) + if err != nil { + return err + } + + ok, b, err := r.compareAndSet(etag, new) + if err != nil { + return err + } else if ok { + break + } else if err := json.Unmarshal(b, &curr); err != nil { + return err + } + } + return nil +} + +func (r *Ref) Remove() error { + resp, err := r.client.send(&request{ + Method: "DELETE", + Path: r.Path, + }) + if err != nil { + return err + } + return resp.CheckStatus(http.StatusOK) +} + +func (r *Ref) compareAndSet(etag string, new interface{}) (bool, []byte, error) { + resp, err := r.client.send(&request{ + Method: "PUT", + Path: r.Path, + Body: new, + Header: map[string]string{"If-Match": etag}, + }) + if err != nil { + return false, nil, err + } + if resp.Status == http.StatusPreconditionFailed { + return false, resp.Body, nil + } else if err := resp.CheckStatus(http.StatusOK); err != nil { + return false, nil, err + } + return true, nil, nil +} From 0a21f4e3458e8a90c90605b4253bb99c5471fc3b Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Sun, 22 Oct 2017 12:21:44 -0700 Subject: [PATCH 05/58] Fixed Transaction() API --- db/db_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++ db/ref.go | 53 ++++++++++++++++++++++++--------------------------- 2 files changed, 72 insertions(+), 28 deletions(-) diff --git a/db/db_test.go b/db/db_test.go index cdddf1f0..2d7c8fe0 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -435,6 +435,9 @@ func TestTransactionRetry(t *testing.T) { if err := ref.Transaction(fn); err != nil { t.Fatal(err) } + if cnt != 2 { + t.Errorf("Retry Count = %d; want = %d", cnt, 2) + } checkAllRequests(t, mock.Reqs, []*testReq{ &testReq{ Method: "GET", @@ -462,6 +465,50 @@ func TestTransactionRetry(t *testing.T) { }) } +func TestTransactionAbort(t *testing.T) { + mock := &mockServer{ + Resp: &person{"Peter Parker", 17}, + Header: map[string]string{"ETag": "mock-etag1"}, + } + srv := mock.Start(client) + defer srv.Close() + + cnt := 0 + var fn UpdateFn = func(i interface{}) (interface{}, error) { + if cnt == 0 { + mock.Status = http.StatusPreconditionFailed + mock.Header = map[string]string{"ETag": "mock-etag1"} + } + cnt++ + p := i.(map[string]interface{}) + p["age"] = p["age"].(float64) + 1.0 + return p, nil + } + err := ref.Transaction(fn) + if err == nil { + t.Errorf("Transaction() = nil; want error") + } + wanted := []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }, + } + for i := 0; i < 20; i++ { + wanted = append(wanted, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 18, + }), + Header: http.Header{"If-Match": []string{"mock-etag1"}}, + }) + } + checkAllRequests(t, mock.Reqs, wanted) +} + func TestRemove(t *testing.T) { mock := &mockServer{Resp: "null"} srv := mock.Start(client) diff --git a/db/ref.go b/db/ref.go index 67cd47a3..47d61873 100644 --- a/db/ref.go +++ b/db/ref.go @@ -15,7 +15,6 @@ package db import ( - "encoding/json" "fmt" "net/http" "strings" @@ -83,8 +82,20 @@ func (r *Ref) Set(v interface{}) error { } func (r *Ref) SetIfUnchanged(etag string, v interface{}) (bool, error) { - ok, _, err := r.compareAndSet(etag, v) - return ok, err + resp, err := r.client.send(&request{ + Method: "PUT", + Path: r.Path, + Body: v, + Header: map[string]string{"If-Match": etag}, + }) + if err != nil { + return false, err + } else if err := resp.CheckStatus(http.StatusOK); err == nil { + return true, nil + } else if err := resp.CheckStatus(http.StatusPreconditionFailed); err != nil { + return false, err + } + return false, nil } func (r *Ref) Push(v interface{}) (*Ref, error) { @@ -136,16 +147,20 @@ func (r *Ref) Transaction(fn UpdateFn) error { return err } - ok, b, err := r.compareAndSet(etag, new) - if err != nil { - return err - } else if ok { - break - } else if err := json.Unmarshal(b, &curr); err != nil { + resp, err := r.client.send(&request{ + Method: "PUT", + Path: r.Path, + Body: new, + Header: map[string]string{"If-Match": etag}, + }) + if err := resp.CheckStatus(http.StatusOK); err == nil { + return nil + } else if err := resp.CheckAndParse(http.StatusPreconditionFailed, &curr); err != nil { return err } + etag = resp.Header.Get("ETag") } - return nil + return fmt.Errorf("transaction aborted after failed retries") } func (r *Ref) Remove() error { @@ -158,21 +173,3 @@ func (r *Ref) Remove() error { } return resp.CheckStatus(http.StatusOK) } - -func (r *Ref) compareAndSet(etag string, new interface{}) (bool, []byte, error) { - resp, err := r.client.send(&request{ - Method: "PUT", - Path: r.Path, - Body: new, - Header: map[string]string{"If-Match": etag}, - }) - if err != nil { - return false, nil, err - } - if resp.Status == http.StatusPreconditionFailed { - return false, resp.Body, nil - } else if err := resp.CheckStatus(http.StatusOK); err != nil { - return false, nil, err - } - return true, nil, nil -} From 0fcbbce387af8d1eb8a36eb6418e2c25bf0476dd Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Sun, 22 Oct 2017 13:05:47 -0700 Subject: [PATCH 06/58] Code cleanup --- db/db.go | 93 ---------------------------------------------- db/db_test.go | 29 ++++++++------- db/http_client.go | 95 +++++++++++++++++++++++++++++++++++++++++++++++ db/ref.go | 48 ++++-------------------- 4 files changed, 119 insertions(+), 146 deletions(-) create mode 100644 db/http_client.go diff --git a/db/db.go b/db/db.go index 24666dab..9df3e589 100644 --- a/db/db.go +++ b/db/db.go @@ -17,7 +17,6 @@ package db import ( "fmt" - "io" "net/http" "strings" @@ -26,14 +25,8 @@ import ( "net/url" - "io/ioutil" - - "encoding/json" - "runtime" - "bytes" - "golang.org/x/net/context" "google.golang.org/api/option" "google.golang.org/api/transport" @@ -97,89 +90,3 @@ func (c *Client) NewRef(path string) (*Ref, error) { segs: segs, }, nil } - -func (c *Client) send(r *request) (*response, error) { - url := fmt.Sprintf("%s%s%s", c.baseURL, r.Path, ".json") - - var body io.Reader - if r.Body != nil { - b, err := json.Marshal(r.Body) - if err != nil { - return nil, err - } - body = bytes.NewBuffer(b) - } - - req, err := http.NewRequest(r.Method, url, body) - if err != nil { - return nil, err - } else if body != nil { - req.Header.Add("Content-Type", "application/json") - } - - for k, v := range r.Header { - req.Header.Add(k, v) - } - - q := req.URL.Query() - for k, v := range r.Query { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - - resp, err := c.hc.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, err - } - return &response{ - Status: resp.StatusCode, - Body: b, - Header: resp.Header, - }, nil -} - -type request struct { - Method string - Path string - Body interface{} - Query map[string]string - Header map[string]string -} - -type response struct { - Status int - Header http.Header - Body []byte -} - -func (r *response) CheckStatus(want int) error { - if r.Status == want { - return nil - } - var b struct { - Error string `json:"error"` - } - json.Unmarshal(r.Body, &b) - var msg string - if b.Error != "" { - msg = fmt.Sprintf("http error status: %d; reason: %s", r.Status, b.Error) - } else { - msg = fmt.Sprintf("http error status: %d; message: %s", r.Status, string(r.Body)) - } - return fmt.Errorf(msg) -} - -func (r *response) CheckAndParse(want int, v interface{}) error { - if err := r.CheckStatus(want); err != nil { - return err - } else if err := json.Unmarshal(r.Body, v); err != nil { - return err - } - return nil -} diff --git a/db/db_test.go b/db/db_test.go index 2d7c8fe0..21d20833 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -556,6 +556,9 @@ func checkRequest(t *testing.T, got *testReq, want *testReq) { } } if want.Body != nil { + if h := got.Header.Get("Content-Type"); h != "application/json" { + t.Errorf("User-Agent = %q; want = %q", h, "application/json") + } var wi, gi interface{} if err := json.Unmarshal(want.Body, &wi); err != nil { t.Fatal(err) @@ -571,14 +574,6 @@ func checkRequest(t *testing.T, got *testReq, want *testReq) { } } -type mockServer struct { - Resp interface{} - Header map[string]string - Status int - Reqs []*testReq - srv *httptest.Server -} - type testReq struct { Method string Path string @@ -586,11 +581,6 @@ type testReq struct { Body []byte } -func serialize(v interface{}) []byte { - b, _ := json.Marshal(v) - return b -} - func newTestReq(r *http.Request) (*testReq, error) { defer r.Body.Close() b, err := ioutil.ReadAll(r.Body) @@ -605,6 +595,14 @@ func newTestReq(r *http.Request) (*testReq, error) { }, nil } +type mockServer struct { + Resp interface{} + Header map[string]string + Status int + Reqs []*testReq + srv *httptest.Server +} + func (s *mockServer) Start(c *Client) *httptest.Server { if s.srv != nil { return s.srv @@ -645,3 +643,8 @@ type person struct { Name string `json:"name"` Age int32 `json:"age"` } + +func serialize(v interface{}) []byte { + b, _ := json.Marshal(v) + return b +} diff --git a/db/http_client.go b/db/http_client.go new file mode 100644 index 00000000..70d331e8 --- /dev/null +++ b/db/http_client.go @@ -0,0 +1,95 @@ +package db + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" +) + +func (r *Ref) send(method string, body interface{}, opts ...httpOption) (*response, error) { + url := fmt.Sprintf("%s%s%s", r.client.baseURL, r.Path, ".json") + var data io.Reader + if body != nil { + b, err := json.Marshal(body) + if err != nil { + return nil, err + } + data = bytes.NewBuffer(b) + opts = append(opts, withHeader("Content-Type", "application/json")) + } + + req, err := http.NewRequest(method, url, data) + if err != nil { + return nil, err + } + for _, o := range opts { + o(req) + } + + resp, err := r.client.hc.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + return &response{ + Status: resp.StatusCode, + Body: b, + Header: resp.Header, + }, nil +} + +type response struct { + Status int + Header http.Header + Body []byte +} + +func (r *response) CheckStatus(want int) error { + if r.Status == want { + return nil + } + var b struct { + Error string `json:"error"` + } + json.Unmarshal(r.Body, &b) + var msg string + if b.Error != "" { + msg = fmt.Sprintf("http error status: %d; reason: %s", r.Status, b.Error) + } else { + msg = fmt.Sprintf("http error status: %d; message: %s", r.Status, string(r.Body)) + } + return fmt.Errorf(msg) +} + +func (r *response) CheckAndParse(want int, v interface{}) error { + if err := r.CheckStatus(want); err != nil { + return err + } else if err := json.Unmarshal(r.Body, v); err != nil { + return err + } + return nil +} + +type httpOption func(*http.Request) + +func withHeader(key, value string) httpOption { + return func(r *http.Request) { + r.Header.Set(key, value) + } +} + +func withQueryParam(key, value string) httpOption { + return func(r *http.Request) { + q := r.URL.Query() + q.Add(key, value) + r.URL.RawQuery = q.Encode() + } +} diff --git a/db/ref.go b/db/ref.go index 47d61873..d943f010 100644 --- a/db/ref.go +++ b/db/ref.go @@ -47,7 +47,7 @@ func (r *Ref) Child(path string) (*Ref, error) { } func (r *Ref) Get(v interface{}) error { - resp, err := r.client.send(&request{Method: "GET", Path: r.Path}) + resp, err := r.send("GET", nil) if err != nil { return err } @@ -55,11 +55,7 @@ func (r *Ref) Get(v interface{}) error { } func (r *Ref) GetWithETag(v interface{}) (string, error) { - resp, err := r.client.send(&request{ - Method: "GET", - Path: r.Path, - Header: map[string]string{"X-Firebase-ETag": "true"}, - }) + resp, err := r.send("GET", nil, withHeader("X-Firebase-ETag", "true")) if err != nil { return "", err } else if err := resp.CheckAndParse(http.StatusOK, v); err != nil { @@ -69,12 +65,7 @@ func (r *Ref) GetWithETag(v interface{}) (string, error) { } func (r *Ref) Set(v interface{}) error { - resp, err := r.client.send(&request{ - Method: "PUT", - Path: r.Path, - Body: v, - Query: map[string]string{"print": "silent"}, - }) + resp, err := r.send("PUT", v, withQueryParam("print", "silent")) if err != nil { return err } @@ -82,12 +73,7 @@ func (r *Ref) Set(v interface{}) error { } func (r *Ref) SetIfUnchanged(etag string, v interface{}) (bool, error) { - resp, err := r.client.send(&request{ - Method: "PUT", - Path: r.Path, - Body: v, - Header: map[string]string{"If-Match": etag}, - }) + resp, err := r.send("PUT", v, withHeader("If-Match", etag)) if err != nil { return false, err } else if err := resp.CheckStatus(http.StatusOK); err == nil { @@ -99,11 +85,7 @@ func (r *Ref) SetIfUnchanged(etag string, v interface{}) (bool, error) { } func (r *Ref) Push(v interface{}) (*Ref, error) { - resp, err := r.client.send(&request{ - Method: "POST", - Path: r.Path, - Body: v, - }) + resp, err := r.send("POST", v) if err != nil { return nil, err } @@ -120,12 +102,7 @@ func (r *Ref) Update(v map[string]interface{}) error { if len(v) == 0 { return fmt.Errorf("value argument must be a non-empty map") } - resp, err := r.client.send(&request{ - Method: "PATCH", - Path: r.Path, - Body: v, - Query: map[string]string{"print": "silent"}, - }) + resp, err := r.send("PATCH", v, withQueryParam("print", "silent")) if err != nil { return err } @@ -146,13 +123,7 @@ func (r *Ref) Transaction(fn UpdateFn) error { if err != nil { return err } - - resp, err := r.client.send(&request{ - Method: "PUT", - Path: r.Path, - Body: new, - Header: map[string]string{"If-Match": etag}, - }) + resp, err := r.send("PUT", new, withHeader("If-Match", etag)) if err := resp.CheckStatus(http.StatusOK); err == nil { return nil } else if err := resp.CheckAndParse(http.StatusPreconditionFailed, &curr); err != nil { @@ -164,10 +135,7 @@ func (r *Ref) Transaction(fn UpdateFn) error { } func (r *Ref) Remove() error { - resp, err := r.client.send(&request{ - Method: "DELETE", - Path: r.Path, - }) + resp, err := r.send("DELETE", nil) if err != nil { return err } From 65870f71d0975c37b0fba703ac8e7a43223e2391 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Sun, 22 Oct 2017 15:28:27 -0700 Subject: [PATCH 07/58] Implemented Query() API --- db/db.go | 24 +++-- db/db_test.go | 31 +++++- db/http_client.go | 10 ++ db/query.go | 142 ++++++++++++++++++++++++++ db/query_test.go | 246 ++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 441 insertions(+), 12 deletions(-) create mode 100644 db/query.go create mode 100644 db/query_test.go diff --git a/db/db.go b/db/db.go index 9df3e589..efb34b20 100644 --- a/db/db.go +++ b/db/db.go @@ -68,14 +68,9 @@ func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) } func (c *Client) NewRef(path string) (*Ref, error) { - if strings.ContainsAny(path, invalidChars) { - return nil, fmt.Errorf("path %q contains one or more invalid characters", path) - } - var segs []string - for _, s := range strings.Split(path, "/") { - if s != "" { - segs = append(segs, s) - } + segs, err := parsePath(path) + if err != nil { + return nil, err } key := "" @@ -90,3 +85,16 @@ func (c *Client) NewRef(path string) (*Ref, error) { segs: segs, }, nil } + +func parsePath(path string) ([]string, error) { + if strings.ContainsAny(path, invalidChars) { + return nil, fmt.Errorf("path %q contains one or more invalid characters", path) + } + var segs []string + for _, s := range strings.Split(path, "/") { + if s != "" { + segs = append(segs, s) + } + } + return segs, nil +} diff --git a/db/db_test.go b/db/db_test.go index 21d20833..9da3a7bf 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -16,6 +16,8 @@ import ( "io/ioutil" + "net/url" + "firebase.google.com/go/internal" "google.golang.org/api/option" ) @@ -536,7 +538,7 @@ func checkAllRequests(t *testing.T, got []*testReq, want []*testReq) { } } -func checkRequest(t *testing.T, got *testReq, want *testReq) { +func checkRequest(t *testing.T, got, want *testReq) { if h := got.Header.Get("Authorization"); h != "Bearer mock-token" { t.Errorf("Authorization = %q; want = %q", h, "Bearer mock-token") } @@ -547,9 +549,8 @@ func checkRequest(t *testing.T, got *testReq, want *testReq) { if got.Method != want.Method { t.Errorf("Method = %q; want = %q", got.Method, want.Method) } - if got.Path != want.Path { - t.Errorf("URL = %q; want = %q", got.Path, want.Path) - } + + checkURL(t, got.Path, want.Path) for k, v := range want.Header { if got.Header.Get(k) != v[0] { t.Errorf("Header(%q) = %q; want = %q", k, got.Header.Get(k), v[0]) @@ -574,6 +575,28 @@ func checkRequest(t *testing.T, got *testReq, want *testReq) { } } +func checkURL(t *testing.T, ug, uw string) { + got, err := url.ParseRequestURI(ug) + if err != nil { + t.Fatal(err) + } + want, err := url.ParseRequestURI(uw) + if err != nil { + t.Fatal(err) + } + if got.Path != want.Path { + t.Errorf("Path = %q; want = %q", got, want) + } + if len(got.Query()) != len(want.Query()) { + t.Errorf("QueryParams = %v; want = %v", got.Query(), want.Query()) + } + for k, v := range want.Query() { + if !reflect.DeepEqual(v, got.Query()[k]) { + t.Errorf("QueryParam(%v) = %v; want = %v", k, got.Query()[k], v) + } + } +} + type testReq struct { Method string Path string diff --git a/db/http_client.go b/db/http_client.go index 70d331e8..fc66911b 100644 --- a/db/http_client.go +++ b/db/http_client.go @@ -93,3 +93,13 @@ func withQueryParam(key, value string) httpOption { r.URL.RawQuery = q.Encode() } } + +func withQueryParams(qp queryParams) httpOption { + return func(r *http.Request) { + q := r.URL.Query() + for k, v := range qp { + q.Add(k, v) + } + r.URL.RawQuery = q.Encode() + } +} diff --git a/db/query.go b/db/query.go new file mode 100644 index 00000000..8faea3fc --- /dev/null +++ b/db/query.go @@ -0,0 +1,142 @@ +package db + +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" +) + +var reservedFilters = map[string]bool{ + "$key": true, + "$value": true, + "$priority": true, +} + +type Query struct { + ref *Ref + qp queryParams +} + +func (q *Query) Get(v interface{}) error { + resp, err := q.ref.send("GET", nil, withQueryParams(q.qp)) + if err != nil { + return err + } + return resp.CheckAndParse(http.StatusOK, v) +} + +type QueryOption interface { + apply(qp queryParams) error +} + +func WithLimitToFirst(lim int) QueryOption { + return &limitParam{"limitToFirst", lim} +} + +func WithLimitToLast(lim int) QueryOption { + return &limitParam{"limitToLast", lim} +} + +func WithStartAt(v interface{}) QueryOption { + return &filterParam{"startAt", v} +} + +func WithEndAt(v interface{}) QueryOption { + return &filterParam{"endAt", v} +} + +func WithEqualTo(v interface{}) QueryOption { + return &filterParam{"equalTo", v} +} + +func (r *Ref) OrderByChild(child string, opts ...QueryOption) (*Query, error) { + if child == "" { + return nil, fmt.Errorf("child path must be a non-empty string") + } + if _, ok := reservedFilters[child]; ok { + return nil, fmt.Errorf("invalid child path: %s", child) + } + segs, err := parsePath(child) + if err != nil { + return nil, err + } + opts = append(opts, orderByParam(strings.Join(segs, "/"))) + return newQuery(r, opts) +} + +func (r *Ref) OrderByKey(opts ...QueryOption) (*Query, error) { + opts = append(opts, orderByParam("$key")) + return newQuery(r, opts) +} + +func (r *Ref) OrderByValue(opts ...QueryOption) (*Query, error) { + opts = append(opts, orderByParam("$value")) + return newQuery(r, opts) +} + +func newQuery(r *Ref, opts []QueryOption) (*Query, error) { + qp := make(queryParams) + for _, o := range opts { + if err := o.apply(qp); err != nil { + return nil, err + } + } + return &Query{ref: r, qp: qp}, nil +} + +type queryParams map[string]string + +type orderByParam string + +func (p orderByParam) apply(qp queryParams) error { + b, err := json.Marshal(p) + if err != nil { + return err + } + qp["orderBy"] = string(b) + return nil +} + +type limitParam struct { + key string + val int +} + +func (p *limitParam) apply(qp queryParams) error { + if p.val < 0 { + return fmt.Errorf("limit parameters must not be negative: %d", p.val) + } else if p.val == 0 { + return nil + } + + qp[p.key] = strconv.Itoa(p.val) + cnt := 0 + for _, k := range []string{"limitToFirst", "limitToLast"} { + if _, ok := qp[k]; ok { + cnt++ + } + } + if cnt == 2 { + return fmt.Errorf("cannot set both limit parameters") + } + return nil +} + +type filterParam struct { + key string + val interface{} +} + +func (p *filterParam) apply(qp queryParams) error { + if p.val == nil { + return nil + } + b, err := json.Marshal(p.val) + if err != nil { + return err + } + qp[p.key] = string(b) + return nil +} diff --git a/db/query_test.go b/db/query_test.go new file mode 100644 index 00000000..14cf22ea --- /dev/null +++ b/db/query_test.go @@ -0,0 +1,246 @@ +package db + +import ( + "reflect" + "testing" +) + +func TestChildQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q, err := ref.OrderByChild("messages") + if err != nil { + t.Fatal(err) + } + + var got map[string]interface{} + if err := q.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + p := "/peter.json?orderBy=\"messages\"" + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: p}) +} + +func TestNestedChildQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q, err := ref.OrderByChild("messages/ratings") + if err != nil { + t.Fatal(err) + } + + var got map[string]interface{} + if err := q.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + p := "/peter.json?orderBy=\"messages/ratings\"" + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: p}) +} + +func TestChildQueryWithParams(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + opts := []QueryOption{ + WithStartAt("m4"), + WithEndAt("m50"), + WithLimitToFirst(10), + } + q, err := ref.OrderByChild("messages", opts...) + if err != nil { + t.Fatal(err) + } + + var got map[string]interface{} + if err := q.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + p := "/peter.json?startAt=\"m4\"&endAt=\"m50\"&limitToFirst=10&orderBy=\"messages\"" + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: p}) +} + +func TestKeyQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q, err := ref.OrderByKey() + if err != nil { + t.Fatal(err) + } + + var got map[string]interface{} + if err := q.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json?orderBy=\"$key\""}) +} + +func TestValueQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q, err := ref.OrderByValue() + if err != nil { + t.Fatal(err) + } + + var got map[string]interface{} + if err := q.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json?orderBy=\"$value\""}) +} + +func TestLimitFirstQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q, err := ref.OrderByChild("messages", WithLimitToFirst(10)) + if err != nil { + t.Fatal(err) + } + + var got map[string]interface{} + if err := q.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + p := "/peter.json?limitToFirst=10&orderBy=\"messages\"" + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: p}) +} + +func TestLimitLastQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q, err := ref.OrderByChild("messages", WithLimitToLast(10)) + if err != nil { + t.Fatal(err) + } + + var got map[string]interface{} + if err := q.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + p := "/peter.json?limitToLast=10&orderBy=\"messages\"" + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: p}) +} + +func TestInvalidLimitQuery(t *testing.T) { + q, err := ref.OrderByChild("messages", WithLimitToFirst(10), WithLimitToLast(10)) + if q != nil || err == nil { + t.Errorf("Query(first=10, last=10) = (%v, %v); want (nil, error)", q, err) + } + + q, err = ref.OrderByChild("messages", WithLimitToFirst(-10)) + if q != nil || err == nil { + t.Errorf("Query(first=-10) = (%v, %v); want (nil, error)", q, err) + } + + q, err = ref.OrderByChild("messages", WithLimitToLast(-10)) + if q != nil || err == nil { + t.Errorf("Query(last=-10) = (%v, %v); want (nil, error)", q, err) + } +} + +func TestStartAtQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q, err := ref.OrderByChild("messages", WithStartAt(10)) + if err != nil { + t.Fatal(err) + } + + var got map[string]interface{} + if err := q.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + p := "/peter.json?startAt=10&orderBy=\"messages\"" + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: p}) +} + +func TestEndAtQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q, err := ref.OrderByChild("messages", WithEndAt(10)) + if err != nil { + t.Fatal(err) + } + + var got map[string]interface{} + if err := q.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + p := "/peter.json?endAt=10&orderBy=\"messages\"" + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: p}) +} + +func TestEqualToQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q, err := ref.OrderByChild("messages", WithEqualTo(10)) + if err != nil { + t.Fatal(err) + } + + var got map[string]interface{} + if err := q.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + p := "/peter.json?equalTo=10&orderBy=\"messages\"" + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: p}) +} From cad2bfee69535e95ceff46ef67bc3a0899823ab4 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Sun, 22 Oct 2017 18:19:50 -0700 Subject: [PATCH 08/58] Added GetIfChanged() and integration tests --- db/db.go | 8 +- db/db_test.go | 66 +++++++++- db/ref.go | 12 ++ firebase.go | 14 ++ integration/db/db_test.go | 214 +++++++++++++++++++++++++++++++ integration/internal/internal.go | 5 +- internal/internal.go | 1 + testdata/dinosaurs.json | 78 +++++++++++ 8 files changed, 389 insertions(+), 9 deletions(-) create mode 100644 integration/db/db_test.go create mode 100644 testdata/dinosaurs.json diff --git a/db/db.go b/db/db.go index efb34b20..9bd960ce 100644 --- a/db/db.go +++ b/db/db.go @@ -18,23 +18,20 @@ package db import ( "fmt" "net/http" + "runtime" "strings" - firebase "firebase.google.com/go" "firebase.google.com/go/internal" "net/url" - "runtime" - "golang.org/x/net/context" "google.golang.org/api/option" "google.golang.org/api/transport" ) const invalidChars = "[].#$" - -var userAgent = fmt.Sprintf("Firebase/HTTP/%s/%s/AdminGo", firebase.Version, runtime.Version()) +const userAgent = "Firebase/HTTP/%s/%s/AdminGo" // Client is the interface for the Firebase Realtime Database service. type Client struct { @@ -43,6 +40,7 @@ type Client struct { } func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) { + userAgent := fmt.Sprintf(userAgent, c.Version, runtime.Version()) o := []option.ClientOption{option.WithUserAgent(userAgent)} o = append(o, c.Opts...) diff --git a/db/db_test.go b/db/db_test.go index 9da3a7bf..1bce4571 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -1,10 +1,12 @@ package db import ( + "fmt" "log" "net/http" "net/http/httptest" "os" + "runtime" "testing" "golang.org/x/net/context" @@ -24,6 +26,7 @@ import ( const testURL = "https://test-db.firebaseio.com" +var testUserAgent string var testOpts = []option.ClientOption{ option.WithTokenSource(&mockTokenSource{"mock-token"}), } @@ -33,7 +36,7 @@ var ref *Ref func TestMain(m *testing.M) { var err error - conf := &internal.DatabaseConfig{Opts: testOpts, BaseURL: testURL} + conf := &internal.DatabaseConfig{Opts: testOpts, BaseURL: testURL, Version: "1.2.3"} client, err = NewClient(context.Background(), conf) if err != nil { log.Fatalln(err) @@ -43,6 +46,8 @@ func TestMain(m *testing.M) { if err != nil { log.Fatalln(err) } + + testUserAgent = fmt.Sprintf(userAgent, "1.2.3", runtime.Version()) os.Exit(m.Run()) } @@ -206,6 +211,61 @@ func TestGetWithETag(t *testing.T) { }) } +func TestGetIfChanged(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{ + Resp: want, + Header: map[string]string{"ETag": "new-etag"}, + } + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + ok, etag, err := ref.GetIfChanged("old-etag", &got) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Errorf("Get() = %v; want = %v", ok, true) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + if etag != "new-etag" { + t.Errorf("ETag = %q; want = %q", etag, "new-etag") + } + + mock.Status = http.StatusNotModified + mock.Resp = nil + var got2 map[string]interface{} + ok, etag, err = ref.GetIfChanged("new-etag", &got2) + if err != nil { + t.Fatal(err) + } + if ok { + t.Errorf("Get() = %v; want = %v", ok, false) + } + if got2 != nil { + t.Errorf("Get() = %v; want nil", got2) + } + if etag != "new-etag" { + t.Errorf("ETag = %q; want = %q", etag, "new-etag") + } + + checkAllRequests(t, mock.Reqs, []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"If-None-Match": []string{"old-etag"}}, + }, + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"If-None-Match": []string{"new-etag"}}, + }, + }) +} + func TestWerlformedHttpError(t *testing.T) { mock := &mockServer{Resp: map[string]string{"error": "test error"}, Status: 500} srv := mock.Start(client) @@ -542,8 +602,8 @@ func checkRequest(t *testing.T, got, want *testReq) { if h := got.Header.Get("Authorization"); h != "Bearer mock-token" { t.Errorf("Authorization = %q; want = %q", h, "Bearer mock-token") } - if h := got.Header.Get("User-Agent"); h != userAgent { - t.Errorf("User-Agent = %q; want = %q", h, userAgent) + if h := got.Header.Get("User-Agent"); h != testUserAgent { + t.Errorf("User-Agent = %q; want = %q", h, testUserAgent) } if got.Method != want.Method { diff --git a/db/ref.go b/db/ref.go index d943f010..1766656c 100644 --- a/db/ref.go +++ b/db/ref.go @@ -64,6 +64,18 @@ func (r *Ref) GetWithETag(v interface{}) (string, error) { return resp.Header.Get("Etag"), nil } +func (r *Ref) GetIfChanged(etag string, v interface{}) (bool, string, error) { + resp, err := r.send("GET", nil, withHeader("If-None-Match", etag)) + if err != nil { + return false, "", err + } else if err := resp.CheckAndParse(http.StatusOK, v); err == nil { + return true, resp.Header.Get("ETag"), nil + } else if err := resp.CheckStatus(http.StatusNotModified); err != nil { + return false, "", err + } + return false, etag, nil +} + func (r *Ref) Set(v interface{}) error { resp, err := r.send("PUT", v, withQueryParam("print", "silent")) if err != nil { diff --git a/firebase.go b/firebase.go index 27f8fdf0..e6c131dd 100644 --- a/firebase.go +++ b/firebase.go @@ -19,6 +19,7 @@ package firebase import ( "firebase.google.com/go/auth" + "firebase.google.com/go/db" "firebase.google.com/go/internal" "firebase.google.com/go/storage" @@ -42,6 +43,7 @@ const Version = "2.0.0" // An App holds configuration and state common to all Firebase services that are exposed from the SDK. type App struct { creds *google.DefaultCredentials + databaseURL string projectID string storageBucket string opts []option.ClientOption @@ -49,6 +51,7 @@ type App struct { // Config represents the configuration used to initialize an App. type Config struct { + DatabaseURL string ProjectID string StorageBucket string } @@ -63,6 +66,16 @@ func (a *App) Auth(ctx context.Context) (*auth.Client, error) { return auth.NewClient(ctx, conf) } +// Database returns an instance of db.Client. +func (a *App) Database(ctx context.Context) (*db.Client, error) { + conf := &internal.DatabaseConfig{ + BaseURL: a.databaseURL, + Opts: a.opts, + Version: Version, + } + return db.NewClient(ctx, conf) +} + // Storage returns a new instance of storage.Client. func (a *App) Storage(ctx context.Context) (*storage.Client, error) { conf := &internal.StorageConfig{ @@ -101,6 +114,7 @@ func NewApp(ctx context.Context, config *Config, opts ...option.ClientOption) (* return &App{ creds: creds, + databaseURL: config.DatabaseURL, projectID: pid, storageBucket: config.StorageBucket, opts: o, diff --git a/integration/db/db_test.go b/integration/db/db_test.go new file mode 100644 index 00000000..0fd854f9 --- /dev/null +++ b/integration/db/db_test.go @@ -0,0 +1,214 @@ +package db + +import ( + "context" + "flag" + "log" + "os" + "testing" + + "io/ioutil" + + "encoding/json" + + "reflect" + + "firebase.google.com/go/db" + "firebase.google.com/go/integration/internal" +) + +var client *db.Client +var ref *db.Ref +var testData map[string]interface{} +var parsedTestData map[string]Dinosaur + +func TestMain(m *testing.M) { + flag.Parse() + if testing.Short() { + log.Println("skipping database integration tests in short mode.") + os.Exit(0) + } + + ctx := context.Background() + app, err := internal.NewTestApp(ctx) + if err != nil { + log.Fatalln(err) + } + + client, err = app.Database(ctx) + if err != nil { + log.Fatalln(err) + } + + ref, err = client.NewRef("_adminsdk/go/dinodb") + if err != nil { + log.Fatalln(err) + } + setup() + + os.Exit(m.Run()) +} + +func setup() { + b, err := ioutil.ReadFile(internal.Resource("dinosaurs.json")) + if err != nil { + log.Fatalln(err) + } + if err = json.Unmarshal(b, &testData); err != nil { + log.Fatalln(err) + } + + b, err = json.Marshal(testData["dinosaurs"]) + if err != nil { + log.Fatalln(err) + } + if err = json.Unmarshal(b, &parsedTestData); err != nil { + log.Fatalln(err) + } + + if err = ref.Set(testData); err != nil { + log.Fatalln(err) + } +} + +func TestRef(t *testing.T) { + if ref.Key != "dinodb" { + t.Errorf("Key = %q; want = %q", ref.Key, "dinodb") + } + if ref.Path != "/_adminsdk/go/dinodb" { + t.Errorf("Path = %q; want = %q", ref.Path, "/_adminsdk/go/dinodb") + } +} + +func TestChild(t *testing.T) { + c, err := ref.Child("dinosaurs") + if err != nil { + t.Fatal(err) + } + if c.Key != "dinosaurs" { + t.Errorf("Key = %q; want = %q", c.Key, "dinosaurs") + } + if c.Path != "/_adminsdk/go/dinodb/dinosaurs" { + t.Errorf("Path = %q; want = %q", c.Path, "/_adminsdk/go/dinodb/dinosaurs") + } +} + +func TestParent(t *testing.T) { + p := ref.Parent() + if p.Key != "go" { + t.Errorf("Key = %q; want = %q", p.Key, "go") + } + if p.Path != "/_adminsdk/go" { + t.Errorf("Path = %q; want = %q", p.Path, "/_adminsdk/go") + } +} + +func TestGet(t *testing.T) { + var m map[string]interface{} + if err := ref.Get(&m); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(testData, m) { + t.Errorf("Get() = %v; want = %v", m, testData) + } +} + +func TestGetWithETag(t *testing.T) { + var m map[string]interface{} + etag, err := ref.GetWithETag(&m) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(testData, m) { + t.Errorf("GetWithETag() = %v; want = %v", m, testData) + } + if etag == "" { + t.Errorf("GetWithETag() = \"\"; want non-empty") + } +} + +func TestGetIfChanged(t *testing.T) { + var m map[string]interface{} + ok, etag, err := ref.GetIfChanged("wrong-etag", &m) + if err != nil { + t.Fatal(err) + } + if !ok || etag == "" { + t.Errorf("GetIfChanged() = (%v, %q); want = (%v, %q)", ok, etag, true, "non-empty") + } + if !reflect.DeepEqual(testData, m) { + t.Errorf("GetWithETag() = %v; want = %v", m, testData) + } + + var m2 map[string]interface{} + ok, etag2, err := ref.GetIfChanged(etag, &m2) + if err != nil { + t.Fatal(err) + } + if ok || etag != etag2 { + t.Errorf("GetIfChanged() = (%v, %q); want = (%v, %q)", ok, etag2, false, etag) + } + if len(m2) != 0 { + t.Errorf("GetWithETag() = %v; want empty", m) + } +} + +func TestGetChildValue(t *testing.T) { + c, err := ref.Child("dinosaurs") + if err != nil { + t.Fatal(err) + } + + var m map[string]interface{} + if err := c.Get(&m); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(testData["dinosaurs"], m) { + t.Errorf("Get() = %v; want = %v", m, testData["dinosaurs"]) + } +} + +func TestGetGrandChildValue(t *testing.T) { + c, err := ref.Child("dinosaurs/lambeosaurus") + if err != nil { + t.Fatal(err) + } + + var got Dinosaur + if err := c.Get(&got); err != nil { + t.Fatal(err) + } + want := parsedTestData["lambeosaurus"] + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestGetNonExistingChild(t *testing.T) { + c, err := ref.Child("non_existing") + if err != nil { + t.Fatal(err) + } + + var i interface{} + if err := c.Get(&i); err != nil { + t.Fatal(err) + } + if i != nil { + t.Errorf("Get() = %v; want nil", i) + } +} + +type Dinosaur struct { + Appeared int `json:"appeared"` + Height float64 `json:"height"` + Length float64 `json:"length"` + Order string `json:"order"` + Vanished int `json:"vanished"` + Weight int `json:"weight"` + Ratings Ratings `json:"ratings"` +} + +type Ratings struct { + Pos int `json:"pos"` +} diff --git a/integration/internal/internal.go b/integration/internal/internal.go index bc52a16a..a40c9801 100644 --- a/integration/internal/internal.go +++ b/integration/internal/internal.go @@ -24,6 +24,8 @@ import ( "golang.org/x/net/context" + "fmt" + firebase "firebase.google.com/go" "google.golang.org/api/option" ) @@ -48,7 +50,8 @@ func NewTestApp(ctx context.Context) (*firebase.App, error) { return nil, err } config := &firebase.Config{ - StorageBucket: pid + ".appspot.com", + DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), + StorageBucket: fmt.Sprintf("%s.appspot.com", pid), } return firebase.NewApp(ctx, config, option.WithCredentialsFile(Resource(certPath))) } diff --git a/internal/internal.go b/internal/internal.go index bb8073e8..ba1ce0ab 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -31,6 +31,7 @@ type AuthConfig struct { type DatabaseConfig struct { Opts []option.ClientOption BaseURL string + Version string } // StorageConfig represents the configuration of Google Cloud Storage service. diff --git a/testdata/dinosaurs.json b/testdata/dinosaurs.json new file mode 100644 index 00000000..29ca1936 --- /dev/null +++ b/testdata/dinosaurs.json @@ -0,0 +1,78 @@ +{ + "dinosaurs": { + "bruhathkayosaurus": { + "appeared": -70000000, + "height": 25, + "length": 44, + "order": "saurischia", + "vanished": -70000000, + "weight": 135000, + "ratings": { + "pos": 1 + } + }, + "lambeosaurus": { + "appeared": -76000000, + "height": 2.1, + "length": 12.5, + "order": "ornithischia", + "vanished": -75000000, + "weight": 5000, + "ratings": { + "pos": 2 + } + }, + "linhenykus": { + "appeared": -85000000, + "height": 0.6, + "length": 1, + "order": "theropoda", + "vanished": -75000000, + "weight": 3, + "ratings": { + "pos": 3 + } + }, + "pterodactyl": { + "appeared": -150000000, + "height": 0.6, + "length": 0.8, + "order": "pterosauria", + "vanished": -148500000, + "weight": 2, + "ratings": { + "pos": 4 + } + }, + "stegosaurus": { + "appeared": -155000000, + "height": 4, + "length": 9, + "order": "ornithischia", + "vanished": -150000000, + "weight": 2500, + "ratings": { + "pos": 5 + } + }, + "triceratops": { + "appeared": -68000000, + "height": 3, + "length": 8, + "order": "ornithischia", + "vanished": -66000000, + "weight": 11000, + "ratings": { + "pos": 6 + } + } + }, + "scores": { + "bruhathkayosaurus": 55, + "lambeosaurus": 21, + "linhenykus": 80, + "pterodactyl": 93, + "stegosaurus": 5, + "triceratops": 22 + } +} \ No newline at end of file From 40c06911e18d1f37ffe9276fbdbc84eee8f22a79 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Sun, 22 Oct 2017 19:20:29 -0700 Subject: [PATCH 09/58] More integration tests --- db/ref.go | 3 ++ integration/db/db_test.go | 84 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/db/ref.go b/db/ref.go index 1766656c..6c6ff1bd 100644 --- a/db/ref.go +++ b/db/ref.go @@ -97,6 +97,9 @@ func (r *Ref) SetIfUnchanged(etag string, v interface{}) (bool, error) { } func (r *Ref) Push(v interface{}) (*Ref, error) { + if v == nil { + v = "" + } resp, err := r.send("POST", v) if err != nil { return nil, err diff --git a/integration/db/db_test.go b/integration/db/db_test.go index 0fd854f9..a6333b08 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -19,6 +19,7 @@ import ( var client *db.Client var ref *db.Ref +var users *db.Ref var testData map[string]interface{} var parsedTestData map[string]Dinosaur @@ -44,6 +45,11 @@ func TestMain(m *testing.M) { if err != nil { log.Fatalln(err) } + + users, err = ref.Parent().Child("users") + if err != nil { + log.Fatalln(err) + } setup() os.Exit(m.Run()) @@ -199,6 +205,79 @@ func TestGetNonExistingChild(t *testing.T) { } } +func TestPush(t *testing.T) { + u, err := users.Push(nil) + if err != nil { + t.Fatal(err) + } + if u.Path != "/_adminsdk/go/users/"+u.Key { + t.Errorf("Push() = %q; want = %q", u.Path, "/_adminsdk/go/users/"+u.Key) + } + + var i interface{} + if err := u.Get(&i); err != nil { + t.Fatal(err) + } + if i != "" { + t.Errorf("Get() = %v; want empty string", i) + } +} + +func TestPushWithValue(t *testing.T) { + want := User{"Luis Alvarez", 1911} + u, err := users.Push(&want) + if err != nil { + t.Fatal(err) + } + if u.Path != "/_adminsdk/go/users/"+u.Key { + t.Errorf("Push() = %q; want = %q", u.Path, "/_adminsdk/go/users/"+u.Key) + } + + var got User + if err := u.Get(&got); err != nil { + t.Fatal(err) + } + if want != got { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestPrimitiveValue(t *testing.T) { + u, err := users.Push(nil) + if err != nil { + t.Fatal(err) + } + if err := u.Set("value"); err != nil { + t.Fatal(err) + } + var got string + if err := u.Get(&got); err != nil { + t.Fatal(err) + } + if got != "value" { + t.Errorf("Get() = %q; want = %q", got, "value") + } +} + +func TestComplexValue(t *testing.T) { + u, err := users.Push(nil) + if err != nil { + t.Fatal(err) + } + + want := User{"Mary Anning", 1799} + if err := u.Set(&want); err != nil { + t.Fatal(err) + } + var got User + if err := u.Get(&got); err != nil { + t.Fatal(err) + } + if got != want { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + type Dinosaur struct { Appeared int `json:"appeared"` Height float64 `json:"height"` @@ -212,3 +291,8 @@ type Dinosaur struct { type Ratings struct { Pos int `json:"pos"` } + +type User struct { + Name string `json:"name"` + Since int `json:"sine"` +} From 0e85fa22be7b3db2cb6861e5f5de35b1b93febb9 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Sun, 22 Oct 2017 19:24:36 -0700 Subject: [PATCH 10/58] Updated unit test --- db/db_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/db/db_test.go b/db/db_test.go index 1bce4571..4fd4cc65 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -387,6 +387,7 @@ func TestPush(t *testing.T) { checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "POST", Path: "/peter.json", + Body: serialize(""), }) } From 20fdbf3b10f2371aeef9ba1ace8246f94c33b349 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Sun, 22 Oct 2017 21:53:38 -0700 Subject: [PATCH 11/58] More integration tests --- db/db_test.go | 4 +- db/ref.go | 2 +- integration/db/db_test.go | 198 +++++++++++++++++++++++++++++++++++++- 3 files changed, 198 insertions(+), 6 deletions(-) diff --git a/db/db_test.go b/db/db_test.go index 4fd4cc65..f936d36b 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -572,12 +572,12 @@ func TestTransactionAbort(t *testing.T) { checkAllRequests(t, mock.Reqs, wanted) } -func TestRemove(t *testing.T) { +func TestDelete(t *testing.T) { mock := &mockServer{Resp: "null"} srv := mock.Start(client) defer srv.Close() - if err := ref.Remove(); err != nil { + if err := ref.Delete(); err != nil { t.Fatal(err) } checkOnlyRequest(t, mock.Reqs, &testReq{ diff --git a/db/ref.go b/db/ref.go index 6c6ff1bd..774503ff 100644 --- a/db/ref.go +++ b/db/ref.go @@ -149,7 +149,7 @@ func (r *Ref) Transaction(fn UpdateFn) error { return fmt.Errorf("transaction aborted after failed retries") } -func (r *Ref) Remove() error { +func (r *Ref) Delete() error { resp, err := r.send("DELETE", nil) if err != nil { return err diff --git a/integration/db/db_test.go b/integration/db/db_test.go index a6333b08..9e08f2e9 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -3,6 +3,7 @@ package db import ( "context" "flag" + "fmt" "log" "os" "testing" @@ -242,7 +243,7 @@ func TestPushWithValue(t *testing.T) { } } -func TestPrimitiveValue(t *testing.T) { +func TestSetPrimitiveValue(t *testing.T) { u, err := users.Push(nil) if err != nil { t.Fatal(err) @@ -259,7 +260,7 @@ func TestPrimitiveValue(t *testing.T) { } } -func TestComplexValue(t *testing.T) { +func TestSetComplexValue(t *testing.T) { u, err := users.Push(nil) if err != nil { t.Fatal(err) @@ -278,6 +279,197 @@ func TestComplexValue(t *testing.T) { } } +func TestUpdateChildren(t *testing.T) { + u, err := users.Push(nil) + if err != nil { + t.Fatal(err) + } + + want := map[string]interface{}{ + "name": "Robert Bakker", + "since": float64(1945), + } + if err := u.Update(want); err != nil { + t.Fatal(err) + } + var got map[string]interface{} + if err := u.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestUpdateChildrenWithExistingValue(t *testing.T) { + u, err := users.Push(map[string]interface{}{ + "name": "Edwin Colbert", + "since": float64(1900), + }) + if err != nil { + t.Fatal(err) + } + + update := map[string]interface{}{"since": float64(1905)} + if err := u.Update(update); err != nil { + t.Fatal(err) + } + var got map[string]interface{} + if err := u.Get(&got); err != nil { + t.Fatal(err) + } + want := map[string]interface{}{ + "name": "Edwin Colbert", + "since": float64(1905), + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestUpdateNestedChildren(t *testing.T) { + edward, err := users.Push(map[string]interface{}{"name": "Edward Cope", "since": float64(1800)}) + if err != nil { + t.Fatal(err) + } + jack, err := users.Push(map[string]interface{}{"name": "Jack Horner", "since": float64(1940)}) + if err != nil { + t.Fatal(err) + } + delta := map[string]interface{}{ + fmt.Sprintf("%s/since", edward.Key): 1840, + fmt.Sprintf("%s/since", jack.Key): 1946, + } + if err := users.Update(delta); err != nil { + t.Fatal(err) + } + var got map[string]interface{} + if err := edward.Get(&got); err != nil { + t.Fatal(err) + } + want := map[string]interface{}{"name": "Edward Cope", "since": float64(1840)} + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + + if err := jack.Get(&got); err != nil { + t.Fatal(err) + } + want = map[string]interface{}{"name": "Jack Horner", "since": float64(1946)} + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestSetIfChanged(t *testing.T) { + edward, err := users.Push(&User{"Edward Cope", 1800}) + if err != nil { + t.Fatal(err) + } + + update := User{"Jack Horner", 1940} + ok, err := edward.SetIfUnchanged("invalid-etag", &update) + if err != nil { + t.Fatal(err) + } + if ok { + t.Errorf("SetIfUnchanged() = %v; want = %v", ok, false) + } + + var u User + etag, err := edward.GetWithETag(&u) + if err != nil { + t.Fatal(err) + } + ok, err = edward.SetIfUnchanged(etag, &update) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Errorf("SetIfUnchanged() = %v; want = %v", ok, true) + } + + if err := edward.Get(&u); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(update, u) { + t.Errorf("Get() = %v; want = %v", u, update) + } +} + +func TestTransaction(t *testing.T) { + u, err := users.Push(&User{Name: "Richard"}) + if err != nil { + t.Fatal(err) + } + fn := func(curr interface{}) (interface{}, error) { + snap := curr.(map[string]interface{}) + snap["name"] = "Richard Owen" + snap["since"] = 1804 + return snap, nil + } + if err := u.Transaction(fn); err != nil { + t.Fatal(err) + } + var got User + if err := u.Get(&got); err != nil { + t.Fatal(err) + } + want := User{"Richard Owen", 1804} + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestTransactionScalar(t *testing.T) { + cnt, err := users.Child("count") + if err != nil { + t.Fatal(err) + } + if err := cnt.Set(42); err != nil { + t.Fatal(err) + } + fn := func(curr interface{}) (interface{}, error) { + snap := curr.(float64) + return snap + 1, nil + } + if err := cnt.Transaction(fn); err != nil { + t.Fatal(err) + } + var got float64 + if err := cnt.Get(&got); err != nil { + t.Fatal(err) + } + if got != 43.0 { + t.Errorf("Get() = %v; want = %v", got, 43.0) + } +} + +func TestDelete(t *testing.T) { + u, err := users.Push("foo") + if err != nil { + t.Fatal(err) + } + var got string + if err := u.Get(&got); err != nil { + t.Fatal(err) + } + if got != "foo" { + t.Errorf("Get() = %q; want = %q", got, "foo") + } + if err := u.Delete(); err != nil { + t.Fatal(err) + } + + var got2 string + if err := u.Get(&got2); err != nil { + t.Fatal(err) + } + if got2 != "" { + t.Errorf("Get() = %q; want = %q", got2, "") + } +} + type Dinosaur struct { Appeared int `json:"appeared"` Height float64 `json:"height"` @@ -294,5 +486,5 @@ type Ratings struct { type User struct { Name string `json:"name"` - Since int `json:"sine"` + Since int `json:"since"` } From ef547a6b5420cb7f0a156e9ab7577aa19528ab75 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Mon, 23 Oct 2017 00:50:21 -0700 Subject: [PATCH 12/58] Integration tests for queries --- firebase.go | 8 +---- integration/db/db_test.go | 57 ++++++++++++++++++++++++++++++-- integration/db/query_test.go | 49 +++++++++++++++++++++++++++ integration/internal/internal.go | 13 ++++++++ internal/internal.go | 6 ++++ testdata/dinosaurs_index.json | 29 ++++++++++++++++ 6 files changed, 152 insertions(+), 10 deletions(-) create mode 100644 integration/db/query_test.go create mode 100644 testdata/dinosaurs_index.json diff --git a/firebase.go b/firebase.go index e6c131dd..6ae0ca52 100644 --- a/firebase.go +++ b/firebase.go @@ -31,12 +31,6 @@ import ( "google.golang.org/api/transport" ) -var firebaseScopes = []string{ - "https://www.googleapis.com/auth/devstorage.full_control", - "https://www.googleapis.com/auth/firebase", - "https://www.googleapis.com/auth/userinfo.email", -} - // Version of the Firebase Go Admin SDK. const Version = "2.0.0" @@ -91,7 +85,7 @@ func (a *App) Storage(ctx context.Context) (*storage.Client, error) { // oauth2.TokenSource) the App will be authenticated using that credential. Otherwise, NewApp attempts to // authenticate the App with Google application default credentials. func NewApp(ctx context.Context, config *Config, opts ...option.ClientOption) (*App, error) { - o := []option.ClientOption{option.WithScopes(firebaseScopes...)} + o := []option.ClientOption{option.WithScopes(internal.FirebaseScopes...)} o = append(o, opts...) creds, err := transport.Creds(ctx, o...) diff --git a/integration/db/db_test.go b/integration/db/db_test.go index 9e08f2e9..53b26b4e 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -5,6 +5,7 @@ import ( "flag" "fmt" "log" + "net/http" "os" "testing" @@ -14,6 +15,8 @@ import ( "reflect" + "bytes" + "firebase.google.com/go/db" "firebase.google.com/go/integration/internal" ) @@ -21,6 +24,7 @@ import ( var client *db.Client var ref *db.Ref var users *db.Ref +var dinos *db.Ref var testData map[string]interface{} var parsedTestData map[string]Dinosaur @@ -42,21 +46,68 @@ func TestMain(m *testing.M) { log.Fatalln(err) } + initRefs() + initRules() + initData() + + os.Exit(m.Run()) +} + +func initRefs() { + var err error ref, err = client.NewRef("_adminsdk/go/dinodb") if err != nil { log.Fatalln(err) } + dinos, err = ref.Child("dinosaurs") + if err != nil { + log.Fatalln(err) + } + users, err = ref.Parent().Child("users") if err != nil { log.Fatalln(err) } - setup() +} - os.Exit(m.Run()) +func initRules() { + b, err := ioutil.ReadFile(internal.Resource("dinosaurs_index.json")) + if err != nil { + log.Fatalln(err) + } + + pid, err := internal.ProjectID() + if err != nil { + log.Fatalln(err) + } + + url := fmt.Sprintf("https://%s.firebaseio.com/.settings/rules.json", pid) + req, err := http.NewRequest("PUT", url, bytes.NewBuffer(b)) + if err != nil { + log.Fatalln(err) + } + req.Header.Set("Content-Type", "application/json") + + hc, err := internal.NewHTTPClient(context.Background()) + if err != nil { + log.Fatalln(err) + } + resp, err := hc.Do(req) + if err != nil { + log.Fatalln(err) + } + defer resp.Body.Close() + + b, err = ioutil.ReadAll(resp.Body) + if err != nil { + log.Fatalln(err) + } else if resp.StatusCode != http.StatusOK { + log.Fatalln("failed to update rules: %q", string(b)) + } } -func setup() { +func initData() { b, err := ioutil.ReadFile(internal.Resource("dinosaurs.json")) if err != nil { log.Fatalln(err) diff --git a/integration/db/query_test.go b/integration/db/query_test.go new file mode 100644 index 00000000..9d106599 --- /dev/null +++ b/integration/db/query_test.go @@ -0,0 +1,49 @@ +package db + +import "testing" +import "firebase.google.com/go/db" + +var heightSorted = []string{ + "linhenykus", "pterodactyl", "lambeosaurus", + "triceratops", "stegosaurus", "bruhathkayosaurus", +} + +func TestLimitToFirst(t *testing.T) { + q, err := dinos.OrderByChild("height", db.WithLimitToFirst(2)) + if err != nil { + t.Fatal(err) + } + var m map[string]Dinosaur + if err := q.Get(&m); err != nil { + t.Fatal(err) + } + want := heightSorted[:2] + if len(m) != 2 { + t.Errorf("WithLimitToFirst() = %v; want = %v", m, want) + } + for _, d := range want { + if _, ok := m[d]; !ok { + t.Errorf("WithLimitToFirst() = %v; want key %q", m, d) + } + } +} + +func TestLimitToLast(t *testing.T) { + q, err := dinos.OrderByChild("height", db.WithLimitToLast(2)) + if err != nil { + t.Fatal(err) + } + var m map[string]Dinosaur + if err := q.Get(&m); err != nil { + t.Fatal(err) + } + want := heightSorted[len(heightSorted)-2:] + if len(m) != 2 { + t.Errorf("WithLimitToLast() = %v; want = %v", m, want) + } + for _, d := range want { + if _, ok := m[d]; !ok { + t.Errorf("WithLimitToLast() = %v; want key %q", m, d) + } + } +} diff --git a/integration/internal/internal.go b/integration/internal/internal.go index a40c9801..f210d92d 100644 --- a/integration/internal/internal.go +++ b/integration/internal/internal.go @@ -19,6 +19,7 @@ import ( "encoding/json" "go/build" "io/ioutil" + "net/http" "path/filepath" "strings" @@ -27,7 +28,9 @@ import ( "fmt" firebase "firebase.google.com/go" + "firebase.google.com/go/internal" "google.golang.org/api/option" + "google.golang.org/api/transport" ) const certPath = "integration_cert.json" @@ -82,3 +85,13 @@ func ProjectID() (string, error) { } return serviceAccount.ProjectID, nil } + +func NewHTTPClient(ctx context.Context, opts ...option.ClientOption) (*http.Client, error) { + opts = append( + opts, + option.WithCredentialsFile(Resource(certPath)), + option.WithScopes(internal.FirebaseScopes...), + ) + hc, _, err := transport.NewHTTPClient(ctx, opts...) + return hc, err +} diff --git a/internal/internal.go b/internal/internal.go index ba1ce0ab..b1c03ff5 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -21,6 +21,12 @@ import ( "golang.org/x/oauth2/google" ) +var FirebaseScopes = []string{ + "https://www.googleapis.com/auth/devstorage.full_control", + "https://www.googleapis.com/auth/firebase", + "https://www.googleapis.com/auth/userinfo.email", +} + // AuthConfig represents the configuration of Firebase Auth service. type AuthConfig struct { Opts []option.ClientOption diff --git a/testdata/dinosaurs_index.json b/testdata/dinosaurs_index.json new file mode 100644 index 00000000..bf4a2551 --- /dev/null +++ b/testdata/dinosaurs_index.json @@ -0,0 +1,29 @@ +{ + "rules": { + "_adminsdk": { + "go": { + "dinodb": { + "dinosaurs": { + ".indexOn": ["height", "ratings/pos"] + }, + "scores": { + ".indexOn": ".value" + } + }, + "protected": { + "$uid": { + ".read": "auth != null", + ".write": "$uid === auth.uid" + } + }, + "admin": { + ".read": "false", + ".write": "false" + }, + "public": { + ".read": "true" + } + } + } + } +} \ No newline at end of file From b4e127df9df2242366f2a09ae62bf841ce8ca5a7 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Tue, 24 Oct 2017 00:38:28 -0700 Subject: [PATCH 13/58] Auth override support and more tests --- db/auth_override_test.go | 115 ++++++++++++++++++++++ db/db.go | 18 ++++ db/db_test.go | 80 +++++++++------ db/http_client.go | 1 + db/query_test.go | 103 +++++++++++++++---- db/ref.go | 1 + integration/db/query_test.go | 185 +++++++++++++++++++++++++++++++++-- internal/internal.go | 7 +- 8 files changed, 450 insertions(+), 60 deletions(-) create mode 100644 db/auth_override_test.go diff --git a/db/auth_override_test.go b/db/auth_override_test.go new file mode 100644 index 00000000..49a7f20e --- /dev/null +++ b/db/auth_override_test.go @@ -0,0 +1,115 @@ +package db + +import ( + "testing" +) + +func TestAuthOverrideGet(t *testing.T) { + mock := &mockServer{Resp: "data"} + srv := mock.Start(aoClient) + defer srv.Close() + + ref, err := aoClient.NewRef("peter") + if err != nil { + t.Fatal(err) + } + + var got string + if err := ref.Get(&got); err != nil { + t.Fatal(err) + } + if got != "data" { + t.Errorf("Get() = %q; want = %q", got, "data") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"auth_variable_override": testAuthOverrides}, + }) +} + +func TestAuthOverrideSet(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(aoClient) + defer srv.Close() + + ref, err := aoClient.NewRef("peter") + if err != nil { + t.Fatal(err) + } + + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + if err := ref.Set(want); err != nil { + t.Fatal(err) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PUT", + Body: serialize(want), + Path: "/peter.json", + Query: map[string]string{"auth_variable_override": testAuthOverrides, "print": "silent"}, + }) +} + +func TestAuthOverrideQuery(t *testing.T) { + mock := &mockServer{Resp: "data"} + srv := mock.Start(aoClient) + defer srv.Close() + + ref, err := aoClient.NewRef("peter") + if err != nil { + t.Fatal(err) + } + + q, err := ref.OrderByChild("foo") + if err != nil { + t.Fatal(err) + } + var got string + if err := q.Get(&got); err != nil { + t.Fatal(err) + } + if got != "data" { + t.Errorf("OrderByChild() = %q; want = %q", got, "data") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{ + "auth_variable_override": testAuthOverrides, + "orderBy": "\"foo\"", + }, + }) +} + +func TestAuthOverrideRangeQuery(t *testing.T) { + mock := &mockServer{Resp: "data"} + srv := mock.Start(aoClient) + defer srv.Close() + + ref, err := aoClient.NewRef("peter") + if err != nil { + t.Fatal(err) + } + + q, err := ref.OrderByChild("foo", WithStartAt(1), WithEndAt(10)) + if err != nil { + t.Fatal(err) + } + var got string + if err := q.Get(&got); err != nil { + t.Fatal(err) + } + if got != "data" { + t.Errorf("OrderByChild() = %q; want = %q", got, "data") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{ + "auth_variable_override": testAuthOverrides, + "orderBy": "\"foo\"", + "startAt": "1", + "endAt": "10", + }, + }) +} diff --git a/db/db.go b/db/db.go index 9bd960ce..b96a8e8a 100644 --- a/db/db.go +++ b/db/db.go @@ -25,6 +25,8 @@ import ( "net/url" + "encoding/json" + "golang.org/x/net/context" "google.golang.org/api/option" "google.golang.org/api/transport" @@ -37,6 +39,7 @@ const userAgent = "Firebase/HTTP/%s/%s/AdminGo" type Client struct { hc *http.Client baseURL string + ao string } func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) { @@ -59,9 +62,18 @@ func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) } else if !strings.HasSuffix(url.Host, ".firebaseio.com") { return nil, fmt.Errorf("invalid database URL (incorrest host): %q", c.BaseURL) } + + var ao []byte + if c.AuthOverrides != nil { + ao, err = json.Marshal(c.AuthOverrides) + if err != nil { + return nil, err + } + } return &Client{ hc: hc, baseURL: fmt.Sprintf("https://%s", url.Host), + ao: string(ao), }, nil } @@ -76,11 +88,17 @@ func (c *Client) NewRef(path string) (*Ref, error) { key = segs[len(segs)-1] } + var opts []httpOption + if c.ao != "" { + opts = append(opts, withQueryParam("auth_variable_override", c.ao)) + } + return &Ref{ Key: key, Path: "/" + strings.Join(segs, "/"), client: c, segs: segs, + opts: opts, }, nil } diff --git a/db/db_test.go b/db/db_test.go index f936d36b..551f0f61 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -27,21 +27,43 @@ import ( const testURL = "https://test-db.firebaseio.com" var testUserAgent string +var testAuthOverrides string var testOpts = []option.ClientOption{ option.WithTokenSource(&mockTokenSource{"mock-token"}), } var client *Client +var aoClient *Client var ref *Ref func TestMain(m *testing.M) { var err error - conf := &internal.DatabaseConfig{Opts: testOpts, BaseURL: testURL, Version: "1.2.3"} - client, err = NewClient(context.Background(), conf) + client, err = NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + BaseURL: testURL, + Version: "1.2.3", + }) + if err != nil { + log.Fatalln(err) + } + + ao := map[string]interface{}{"uid": "user1"} + aoClient, err = NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + BaseURL: testURL, + Version: "1.2.3", + AuthOverrides: ao, + }) if err != nil { log.Fatalln(err) } + b, err := json.Marshal(ao) + if err != nil { + log.Fatalln(err) + } + testAuthOverrides = string(b) + ref, err = client.NewRef("peter") if err != nil { log.Fatalln(err) @@ -305,8 +327,9 @@ func TestSet(t *testing.T) { } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "PUT", - Path: "/peter.json?print=silent", + Path: "/peter.json", Body: serialize(want), + Query: map[string]string{"print": "silent"}, }) } @@ -321,8 +344,9 @@ func TestSetWithStruct(t *testing.T) { } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "PUT", - Path: "/peter.json?print=silent", + Path: "/peter.json", Body: serialize(want), + Query: map[string]string{"print": "silent"}, }) } @@ -423,8 +447,9 @@ func TestUpdate(t *testing.T) { } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "PATCH", - Path: "/peter.json?print=silent", + Path: "/peter.json", Body: serialize(want), + Query: map[string]string{"print": "silent"}, }) } @@ -611,7 +636,14 @@ func checkRequest(t *testing.T, got, want *testReq) { t.Errorf("Method = %q; want = %q", got.Method, want.Method) } - checkURL(t, got.Path, want.Path) + if got.Path != want.Path { + t.Errorf("Path = %q; want = %q", got.Path, want.Path) + } + for k, v := range want.Query { + if got.Query[k] != v { + t.Errorf("QueryParam(%v) = %v; want = %v", k, got.Query[k], v) + } + } for k, v := range want.Header { if got.Header.Get(k) != v[0] { t.Errorf("Header(%q) = %q; want = %q", k, got.Header.Get(k), v[0]) @@ -636,33 +668,12 @@ func checkRequest(t *testing.T, got, want *testReq) { } } -func checkURL(t *testing.T, ug, uw string) { - got, err := url.ParseRequestURI(ug) - if err != nil { - t.Fatal(err) - } - want, err := url.ParseRequestURI(uw) - if err != nil { - t.Fatal(err) - } - if got.Path != want.Path { - t.Errorf("Path = %q; want = %q", got, want) - } - if len(got.Query()) != len(want.Query()) { - t.Errorf("QueryParams = %v; want = %v", got.Query(), want.Query()) - } - for k, v := range want.Query() { - if !reflect.DeepEqual(v, got.Query()[k]) { - t.Errorf("QueryParam(%v) = %v; want = %v", k, got.Query()[k], v) - } - } -} - type testReq struct { Method string Path string Header http.Header Body []byte + Query map[string]string } func newTestReq(r *http.Request) (*testReq, error) { @@ -671,11 +682,22 @@ func newTestReq(r *http.Request) (*testReq, error) { if err != nil { return nil, err } + + u, err := url.Parse(r.RequestURI) + if err != nil { + return nil, err + } + + query := make(map[string]string) + for k, v := range u.Query() { + query[k] = v[0] + } return &testReq{ Method: r.Method, - Path: r.RequestURI, + Path: u.Path, Header: r.Header, Body: b, + Query: query, }, nil } diff --git a/db/http_client.go b/db/http_client.go index fc66911b..fcc55f14 100644 --- a/db/http_client.go +++ b/db/http_client.go @@ -25,6 +25,7 @@ func (r *Ref) send(method string, body interface{}, opts ...httpOption) (*respon if err != nil { return nil, err } + opts = append(opts, r.opts...) for _, o := range opts { o(req) } diff --git a/db/query_test.go b/db/query_test.go index 14cf22ea..8812128a 100644 --- a/db/query_test.go +++ b/db/query_test.go @@ -23,8 +23,11 @@ func TestChildQuery(t *testing.T) { if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } - p := "/peter.json?orderBy=\"messages\"" - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: p}) + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"messages\""}, + }) } func TestNestedChildQuery(t *testing.T) { @@ -45,8 +48,11 @@ func TestNestedChildQuery(t *testing.T) { if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } - p := "/peter.json?orderBy=\"messages/ratings\"" - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: p}) + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"messages/ratings\""}, + }) } func TestChildQueryWithParams(t *testing.T) { @@ -72,8 +78,16 @@ func TestChildQueryWithParams(t *testing.T) { if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } - p := "/peter.json?startAt=\"m4\"&endAt=\"m50\"&limitToFirst=10&orderBy=\"messages\"" - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: p}) + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{ + "orderBy": "\"messages\"", + "startAt": "\"m4\"", + "endAt": "\"m50\"", + "limitToFirst": "10", + }, + }) } func TestKeyQuery(t *testing.T) { @@ -94,7 +108,11 @@ func TestKeyQuery(t *testing.T) { if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json?orderBy=\"$key\""}) + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"$key\""}, + }) } func TestValueQuery(t *testing.T) { @@ -115,7 +133,11 @@ func TestValueQuery(t *testing.T) { if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json?orderBy=\"$value\""}) + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"$value\""}, + }) } func TestLimitFirstQuery(t *testing.T) { @@ -136,8 +158,11 @@ func TestLimitFirstQuery(t *testing.T) { if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } - p := "/peter.json?limitToFirst=10&orderBy=\"messages\"" - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: p}) + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"limitToFirst": "10", "orderBy": "\"messages\""}, + }) } func TestLimitLastQuery(t *testing.T) { @@ -158,8 +183,11 @@ func TestLimitLastQuery(t *testing.T) { if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } - p := "/peter.json?limitToLast=10&orderBy=\"messages\"" - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: p}) + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"limitToLast": "10", "orderBy": "\"messages\""}, + }) } func TestInvalidLimitQuery(t *testing.T) { @@ -197,8 +225,11 @@ func TestStartAtQuery(t *testing.T) { if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } - p := "/peter.json?startAt=10&orderBy=\"messages\"" - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: p}) + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"startAt": "10", "orderBy": "\"messages\""}, + }) } func TestEndAtQuery(t *testing.T) { @@ -219,8 +250,41 @@ func TestEndAtQuery(t *testing.T) { if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } - p := "/peter.json?endAt=10&orderBy=\"messages\"" - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: p}) + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"endAt": "10", "orderBy": "\"messages\""}, + }) +} + +func TestAllParamsQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q, err := ref.OrderByChild("messages", WithLimitToFirst(100), WithStartAt("bar"), WithEndAt("foo")) + if err != nil { + t.Fatal(err) + } + + var got map[string]interface{} + if err := q.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{ + "limitToFirst": "100", + "startAt": "\"bar\"", + "endAt": "\"foo\"", + "orderBy": "\"messages\"", + }, + }) } func TestEqualToQuery(t *testing.T) { @@ -241,6 +305,9 @@ func TestEqualToQuery(t *testing.T) { if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } - p := "/peter.json?equalTo=10&orderBy=\"messages\"" - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: p}) + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"equalTo": "10", "orderBy": "\"messages\""}, + }) } diff --git a/db/ref.go b/db/ref.go index 774503ff..c078fe7e 100644 --- a/db/ref.go +++ b/db/ref.go @@ -26,6 +26,7 @@ type Ref struct { client *Client segs []string + opts []httpOption } func (r *Ref) Parent() *Ref { diff --git a/integration/db/query_test.go b/integration/db/query_test.go index 9d106599..4e08cdf2 100644 --- a/integration/db/query_test.go +++ b/integration/db/query_test.go @@ -9,7 +9,61 @@ var heightSorted = []string{ } func TestLimitToFirst(t *testing.T) { - q, err := dinos.OrderByChild("height", db.WithLimitToFirst(2)) + for _, tc := range []int{2, 10} { + q, err := dinos.OrderByChild("height", db.WithLimitToFirst(tc)) + if err != nil { + t.Fatal(err) + } + var m map[string]Dinosaur + if err := q.Get(&m); err != nil { + t.Fatal(err) + } + + wl := tc + if len(heightSorted) < wl { + wl = len(heightSorted) + } + want := heightSorted[:wl] + if len(m) != len(want) { + t.Errorf("WithLimitToFirst() = %v; want = %v", m, want) + } + for _, d := range want { + if _, ok := m[d]; !ok { + t.Errorf("WithLimitToFirst() = %v; want key %q", m, d) + } + } + } +} + +func TestLimitToLast(t *testing.T) { + for _, tc := range []int{2, 10} { + q, err := dinos.OrderByChild("height", db.WithLimitToLast(tc)) + if err != nil { + t.Fatal(err) + } + var m map[string]Dinosaur + if err := q.Get(&m); err != nil { + t.Fatal(err) + } + + wl := tc + if len(heightSorted) < wl { + wl = len(heightSorted) + } + want := heightSorted[len(heightSorted)-wl:] + if len(m) != len(want) { + t.Errorf("WithLimitToLast() = %v; want = %v", m, want) + } + for _, d := range want { + if _, ok := m[d]; !ok { + t.Errorf("WithLimitToLast() = %v; want key %q", m, d) + } + } + } +} + +func TestStartAt(t *testing.T) { + q, err := dinos.OrderByChild("height", db.WithStartAt(3.5)) if err != nil { t.Fatal(err) } @@ -17,19 +71,83 @@ func TestLimitToFirst(t *testing.T) { if err := q.Get(&m); err != nil { t.Fatal(err) } + + want := heightSorted[len(heightSorted)-2:] + if len(m) != len(want) { + t.Errorf("WithStartAt() = %v; want = %v", m, want) + } + for _, d := range want { + if _, ok := m[d]; !ok { + t.Errorf("WithStartAt() = %v; want key %q", m, d) + } + } +} + +func TestEndAt(t *testing.T) { + q, err := dinos.OrderByChild("height", db.WithEndAt(3.5)) + if err != nil { + t.Fatal(err) + } + var m map[string]Dinosaur + if err := q.Get(&m); err != nil { + t.Fatal(err) + } + + want := heightSorted[:4] + if len(m) != len(want) { + t.Errorf("WithStartAt() = %v; want = %v", m, want) + } + for _, d := range want { + if _, ok := m[d]; !ok { + t.Errorf("WithStartAt() = %v; want key %q", m, d) + } + } +} + +func TestStartAndEndAt(t *testing.T) { + q, err := dinos.OrderByChild("height", db.WithStartAt(2.5), db.WithEndAt(5)) + if err != nil { + t.Fatal(err) + } + var m map[string]Dinosaur + if err := q.Get(&m); err != nil { + t.Fatal(err) + } + + want := heightSorted[len(heightSorted)-3 : len(heightSorted)-1] + if len(m) != len(want) { + t.Errorf("WithStartAt(), WithEndAt() = %v; want = %v", m, want) + } + for _, d := range want { + if _, ok := m[d]; !ok { + t.Errorf("WithStartAt(), WithEndAt() = %v; want key %q", m, d) + } + } +} + +func TestEqualTo(t *testing.T) { + q, err := dinos.OrderByChild("height", db.WithEqualTo(0.6)) + if err != nil { + t.Fatal(err) + } + var m map[string]Dinosaur + if err := q.Get(&m); err != nil { + t.Fatal(err) + } + want := heightSorted[:2] - if len(m) != 2 { - t.Errorf("WithLimitToFirst() = %v; want = %v", m, want) + if len(m) != len(want) { + t.Errorf("WithEqualTo() = %v; want = %v", m, want) } for _, d := range want { if _, ok := m[d]; !ok { - t.Errorf("WithLimitToFirst() = %v; want key %q", m, d) + t.Errorf("WithEqualTo() = %v; want key %q", m, d) } } } -func TestLimitToLast(t *testing.T) { - q, err := dinos.OrderByChild("height", db.WithLimitToLast(2)) +func TestOrderByNestedChild(t *testing.T) { + q, err := dinos.OrderByChild("ratings/pos", db.WithStartAt(4)) if err != nil { t.Fatal(err) } @@ -37,13 +155,60 @@ func TestLimitToLast(t *testing.T) { if err := q.Get(&m); err != nil { t.Fatal(err) } - want := heightSorted[len(heightSorted)-2:] - if len(m) != 2 { - t.Errorf("WithLimitToLast() = %v; want = %v", m, want) + + want := []string{"pterodactyl", "stegosaurus", "triceratops"} + if len(m) != len(want) { + t.Errorf("OrderByChild(ratings/pos) = %v; want = %v", m, want) + } + for _, d := range want { + if _, ok := m[d]; !ok { + t.Errorf("OrderByChild(ratings/pos) = %v; want key %q", m, d) + } + } +} + +func TestOrderByKey(t *testing.T) { + q, err := dinos.OrderByKey(db.WithLimitToFirst(2)) + if err != nil { + t.Fatal(err) + } + var m map[string]Dinosaur + if err := q.Get(&m); err != nil { + t.Fatal(err) + } + + want := []string{"bruhathkayosaurus", "lambeosaurus"} + if len(m) != len(want) { + t.Errorf("OrderByKey() = %v; want = %v", m, want) + } + for _, d := range want { + if _, ok := m[d]; !ok { + t.Errorf("OrderByKey() = %v; want key %q", m, d) + } + } +} + +func TestOrderByValue(t *testing.T) { + scores, err := ref.Child("scores") + if err != nil { + t.Fatal(err) + } + q, err := scores.OrderByValue(db.WithLimitToLast(2)) + if err != nil { + t.Fatal(err) + } + var m map[string]int + if err := q.Get(&m); err != nil { + t.Fatal(err) + } + + want := []string{"pterodactyl", "linhenykus"} + if len(m) != len(want) { + t.Errorf("OrderByValue() = %v; want = %v", m, want) } for _, d := range want { if _, ok := m[d]; !ok { - t.Errorf("WithLimitToLast() = %v; want key %q", m, d) + t.Errorf("OrderByValue() = %v; want key %q", m, d) } } } diff --git a/internal/internal.go b/internal/internal.go index b1c03ff5..0d1d97c7 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -35,9 +35,10 @@ type AuthConfig struct { } type DatabaseConfig struct { - Opts []option.ClientOption - BaseURL string - Version string + Opts []option.ClientOption + BaseURL string + Version string + AuthOverrides map[string]interface{} } // StorageConfig represents the configuration of Google Cloud Storage service. From 6af57b06006411379dd331402937ae02da07b2a0 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 24 Oct 2017 18:09:16 -0700 Subject: [PATCH 14/58] More test cases; AuthOverride support in App --- db/db_test.go | 459 +++------------------------- db/ref.go | 6 +- db/ref_test.go | 444 +++++++++++++++++++++++++++ firebase.go | 14 +- firebase_test.go | 13 + integration/auth/auth_test.go | 2 +- integration/db/db_test.go | 120 +++++++- integration/internal/internal.go | 14 +- integration/storage/storage_test.go | 11 +- 9 files changed, 638 insertions(+), 445 deletions(-) create mode 100644 db/ref_test.go diff --git a/db/db_test.go b/db/db_test.go index 551f0f61..4b462d7a 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -137,6 +137,16 @@ func TestNewRef(t *testing.T) { } } +func TestInvalidNewRef(t *testing.T) { + cases := []string{"foo#", "foo.", "foo$", "foo[", "foo]"} + for _, tc := range cases { + r, err := client.NewRef(tc) + if r != nil || err == nil { + t.Errorf("NewRef(%q) = (%v, %v); want = (nil, err)", tc, r, err) + } + } +} + func TestParent(t *testing.T) { cases := []struct { Path string @@ -174,441 +184,48 @@ func TestParent(t *testing.T) { } } -func TestGet(t *testing.T) { - want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} - mock := &mockServer{Resp: want} - srv := mock.Start(client) - defer srv.Close() - - var got map[string]interface{} - if err := ref.Get(&got); err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) - } - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) -} - -func TestGetWithStruct(t *testing.T) { - want := person{Name: "Peter Parker", Age: 17} - mock := &mockServer{Resp: want} - srv := mock.Start(client) - defer srv.Close() - - var got person - if err := ref.Get(&got); err != nil { - t.Fatal(err) - } - if want != got { - t.Errorf("Get() = %v; want = %v", got, want) - } - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) -} - -func TestGetWithETag(t *testing.T) { - want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} - mock := &mockServer{ - Resp: want, - Header: map[string]string{"ETag": "mock-etag"}, - } - srv := mock.Start(client) - defer srv.Close() - - var got map[string]interface{} - etag, err := ref.GetWithETag(&got) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) - } - if etag != "mock-etag" { - t.Errorf("ETag = %q; want = %q", etag, "mock-etag") - } - checkOnlyRequest(t, mock.Reqs, &testReq{ - Method: "GET", - Path: "/peter.json", - Header: http.Header{"X-Firebase-ETag": []string{"true"}}, - }) -} - -func TestGetIfChanged(t *testing.T) { - want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} - mock := &mockServer{ - Resp: want, - Header: map[string]string{"ETag": "new-etag"}, - } - srv := mock.Start(client) - defer srv.Close() - - var got map[string]interface{} - ok, etag, err := ref.GetIfChanged("old-etag", &got) - if err != nil { - t.Fatal(err) - } - if !ok { - t.Errorf("Get() = %v; want = %v", ok, true) - } - if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) - } - if etag != "new-etag" { - t.Errorf("ETag = %q; want = %q", etag, "new-etag") - } - - mock.Status = http.StatusNotModified - mock.Resp = nil - var got2 map[string]interface{} - ok, etag, err = ref.GetIfChanged("new-etag", &got2) - if err != nil { - t.Fatal(err) - } - if ok { - t.Errorf("Get() = %v; want = %v", ok, false) - } - if got2 != nil { - t.Errorf("Get() = %v; want nil", got2) - } - if etag != "new-etag" { - t.Errorf("ETag = %q; want = %q", etag, "new-etag") - } - - checkAllRequests(t, mock.Reqs, []*testReq{ - &testReq{ - Method: "GET", - Path: "/peter.json", - Header: http.Header{"If-None-Match": []string{"old-etag"}}, - }, - &testReq{ - Method: "GET", - Path: "/peter.json", - Header: http.Header{"If-None-Match": []string{"new-etag"}}, - }, - }) -} - -func TestWerlformedHttpError(t *testing.T) { - mock := &mockServer{Resp: map[string]string{"error": "test error"}, Status: 500} - srv := mock.Start(client) - defer srv.Close() - - var got person - err := ref.Get(&got) - want := "http error status: 500; reason: test error" - if err == nil || err.Error() != want { - t.Errorf("Get() = %v; want = %v", err, want) - } - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) -} - -func TestUnexpectedHttpError(t *testing.T) { - mock := &mockServer{Resp: "unexpected error", Status: 500} - srv := mock.Start(client) - defer srv.Close() - - var got person - err := ref.Get(&got) - want := "http error status: 500; message: \"unexpected error\"" - if err == nil || err.Error() != want { - t.Errorf("Get() = %v; want = %v", err, want) - } - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) -} - -func TestSet(t *testing.T) { - mock := &mockServer{} - srv := mock.Start(client) - defer srv.Close() - - want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} - if err := ref.Set(want); err != nil { - t.Fatal(err) - } - checkOnlyRequest(t, mock.Reqs, &testReq{ - Method: "PUT", - Path: "/peter.json", - Body: serialize(want), - Query: map[string]string{"print": "silent"}, - }) -} - -func TestSetWithStruct(t *testing.T) { - mock := &mockServer{} - srv := mock.Start(client) - defer srv.Close() - - want := &person{"Peter Parker", 17} - if err := ref.Set(&want); err != nil { - t.Fatal(err) - } - checkOnlyRequest(t, mock.Reqs, &testReq{ - Method: "PUT", - Path: "/peter.json", - Body: serialize(want), - Query: map[string]string{"print": "silent"}, - }) -} - -func TestSetIfUnchanged(t *testing.T) { - mock := &mockServer{} - srv := mock.Start(client) - defer srv.Close() - - want := &person{"Peter Parker", 17} - ok, err := ref.SetIfUnchanged("mock-etag", &want) - if err != nil { - t.Fatal(err) - } - if !ok { - t.Errorf("SetIfUnchanged() = %v; want = %v", ok, true) - } - checkOnlyRequest(t, mock.Reqs, &testReq{ - Method: "PUT", - Path: "/peter.json", - Body: serialize(want), - Header: http.Header{"If-Match": []string{"mock-etag"}}, - }) -} - -func TestSetIfUnchangedError(t *testing.T) { - mock := &mockServer{ - Status: http.StatusPreconditionFailed, - Resp: &person{"Tony Stark", 39}, - } - srv := mock.Start(client) - defer srv.Close() - - want := &person{"Peter Parker", 17} - ok, err := ref.SetIfUnchanged("mock-etag", &want) +func TestChild(t *testing.T) { + r, err := client.NewRef("/test") if err != nil { t.Fatal(err) } - if ok { - t.Errorf("SetIfUnchanged() = %v; want = %v", ok, false) - } - checkOnlyRequest(t, mock.Reqs, &testReq{ - Method: "PUT", - Path: "/peter.json", - Body: serialize(want), - Header: http.Header{"If-Match": []string{"mock-etag"}}, - }) -} - -func TestPush(t *testing.T) { - mock := &mockServer{Resp: map[string]string{"name": "new_key"}} - srv := mock.Start(client) - defer srv.Close() - child, err := ref.Push(nil) - if err != nil { - t.Fatal(err) + cases := []struct { + Path string + Want string + Parent string + }{ + {"foo", "/test/foo", "/test"}, + {"foo/bar", "/test/foo/bar", "/test/foo"}, + {"foo/bar/", "/test/foo/bar", "/test/foo"}, } - - if child.Key != "new_key" { - t.Errorf("Push() = %q; want = %q", child.Key, "new_key") + for _, tc := range cases { + c, err := r.Child(tc.Path) + if err != nil { + t.Fatal(err) + } + if c.Path != tc.Want { + t.Errorf("Child(%q) = %q; want = %q", tc.Path, c.Path, tc.Want) + } + if c.Parent().Path != tc.Parent { + t.Errorf("Child().Parent() = %q; want = %q", c.Parent().Path, tc.Parent) + } } - checkOnlyRequest(t, mock.Reqs, &testReq{ - Method: "POST", - Path: "/peter.json", - Body: serialize(""), - }) } -func TestPushWithValue(t *testing.T) { - mock := &mockServer{Resp: map[string]string{"name": "new_key"}} - srv := mock.Start(client) - defer srv.Close() - - want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} - child, err := ref.Push(want) +func TestInvalidChild(t *testing.T) { + r, err := client.NewRef("/test") if err != nil { t.Fatal(err) } - if child.Key != "new_key" { - t.Errorf("Push() = %q; want = %q", child.Key, "new_key") - } - checkOnlyRequest(t, mock.Reqs, &testReq{ - Method: "POST", - Path: "/peter.json", - Body: serialize(want), - }) -} - -func TestUpdate(t *testing.T) { - want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} - mock := &mockServer{Resp: want} - srv := mock.Start(client) - defer srv.Close() - - if err := ref.Update(want); err != nil { - t.Fatal(err) - } - checkOnlyRequest(t, mock.Reqs, &testReq{ - Method: "PATCH", - Path: "/peter.json", - Body: serialize(want), - Query: map[string]string{"print": "silent"}, - }) -} - -func TestInvalidUpdate(t *testing.T) { - if err := ref.Update(nil); err == nil { - t.Errorf("Update(nil) = nil; want error") - } - - m := make(map[string]interface{}) - if err := ref.Update(m); err == nil { - t.Errorf("Update(map{}) = nil; want error") - } -} - -func TestTransaction(t *testing.T) { - mock := &mockServer{ - Resp: &person{"Peter Parker", 17}, - Header: map[string]string{"ETag": "mock-etag"}, - } - srv := mock.Start(client) - defer srv.Close() - - var fn UpdateFn = func(i interface{}) (interface{}, error) { - p := i.(map[string]interface{}) - p["age"] = p["age"].(float64) + 1.0 - return p, nil - } - if err := ref.Transaction(fn); err != nil { - t.Fatal(err) - } - checkAllRequests(t, mock.Reqs, []*testReq{ - &testReq{ - Method: "GET", - Path: "/peter.json", - Header: http.Header{"X-Firebase-ETag": []string{"true"}}, - }, - &testReq{ - Method: "PUT", - Path: "/peter.json", - Body: serialize(map[string]interface{}{ - "name": "Peter Parker", - "age": 18, - }), - Header: http.Header{"If-Match": []string{"mock-etag"}}, - }, - }) -} - -func TestTransactionRetry(t *testing.T) { - mock := &mockServer{ - Resp: &person{"Peter Parker", 17}, - Header: map[string]string{"ETag": "mock-etag1"}, - } - srv := mock.Start(client) - defer srv.Close() - - cnt := 0 - var fn UpdateFn = func(i interface{}) (interface{}, error) { - if cnt == 0 { - mock.Status = http.StatusPreconditionFailed - mock.Header = map[string]string{"ETag": "mock-etag2"} - mock.Resp = &person{"Peter Parker", 19} - } else if cnt == 1 { - mock.Status = http.StatusOK - } - cnt++ - p := i.(map[string]interface{}) - p["age"] = p["age"].(float64) + 1.0 - return p, nil - } - if err := ref.Transaction(fn); err != nil { - t.Fatal(err) - } - if cnt != 2 { - t.Errorf("Retry Count = %d; want = %d", cnt, 2) - } - checkAllRequests(t, mock.Reqs, []*testReq{ - &testReq{ - Method: "GET", - Path: "/peter.json", - Header: http.Header{"X-Firebase-ETag": []string{"true"}}, - }, - &testReq{ - Method: "PUT", - Path: "/peter.json", - Body: serialize(map[string]interface{}{ - "name": "Peter Parker", - "age": 18, - }), - Header: http.Header{"If-Match": []string{"mock-etag1"}}, - }, - &testReq{ - Method: "PUT", - Path: "/peter.json", - Body: serialize(map[string]interface{}{ - "name": "Peter Parker", - "age": 20, - }), - Header: http.Header{"If-Match": []string{"mock-etag2"}}, - }, - }) -} - -func TestTransactionAbort(t *testing.T) { - mock := &mockServer{ - Resp: &person{"Peter Parker", 17}, - Header: map[string]string{"ETag": "mock-etag1"}, - } - srv := mock.Start(client) - defer srv.Close() - - cnt := 0 - var fn UpdateFn = func(i interface{}) (interface{}, error) { - if cnt == 0 { - mock.Status = http.StatusPreconditionFailed - mock.Header = map[string]string{"ETag": "mock-etag1"} + cases := []string{"", "/", "/foo", "foo#bar"} + for _, tc := range cases { + c, err := r.Child(tc) + if c != nil || err == nil { + t.Errorf("Child(%q) = (%v, %v); want = (nil, err)", tc, c, err) } - cnt++ - p := i.(map[string]interface{}) - p["age"] = p["age"].(float64) + 1.0 - return p, nil - } - err := ref.Transaction(fn) - if err == nil { - t.Errorf("Transaction() = nil; want error") - } - wanted := []*testReq{ - &testReq{ - Method: "GET", - Path: "/peter.json", - Header: http.Header{"X-Firebase-ETag": []string{"true"}}, - }, } - for i := 0; i < 20; i++ { - wanted = append(wanted, &testReq{ - Method: "PUT", - Path: "/peter.json", - Body: serialize(map[string]interface{}{ - "name": "Peter Parker", - "age": 18, - }), - Header: http.Header{"If-Match": []string{"mock-etag1"}}, - }) - } - checkAllRequests(t, mock.Reqs, wanted) -} - -func TestDelete(t *testing.T) { - mock := &mockServer{Resp: "null"} - srv := mock.Start(client) - defer srv.Close() - - if err := ref.Delete(); err != nil { - t.Fatal(err) - } - checkOnlyRequest(t, mock.Reqs, &testReq{ - Method: "DELETE", - Path: "/peter.json", - }) } func checkOnlyRequest(t *testing.T, got []*testReq, want *testReq) { diff --git a/db/ref.go b/db/ref.go index c078fe7e..92267cbd 100644 --- a/db/ref.go +++ b/db/ref.go @@ -40,8 +40,10 @@ func (r *Ref) Parent() *Ref { } func (r *Ref) Child(path string) (*Ref, error) { - if strings.HasPrefix(path, "/") { - return nil, fmt.Errorf("child path must not start with %q", "/") + if path == "" { + return nil, fmt.Errorf("child path must not be empty") + } else if strings.HasPrefix(path, "/") { + return nil, fmt.Errorf("invalid child path with '/' prefix: %q", path) } fp := fmt.Sprintf("%s/%s", r.Path, path) return r.client.NewRef(fp) diff --git a/db/ref_test.go b/db/ref_test.go new file mode 100644 index 00000000..fa1a6da6 --- /dev/null +++ b/db/ref_test.go @@ -0,0 +1,444 @@ +package db + +import ( + "net/http" + "reflect" + "testing" +) + +func TestGet(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := ref.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) +} + +func TestGetWithStruct(t *testing.T) { + want := person{Name: "Peter Parker", Age: 17} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got person + if err := ref.Get(&got); err != nil { + t.Fatal(err) + } + if want != got { + t.Errorf("Get() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) +} + +func TestGetWithETag(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{ + Resp: want, + Header: map[string]string{"ETag": "mock-etag"}, + } + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + etag, err := ref.GetWithETag(&got) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + if etag != "mock-etag" { + t.Errorf("ETag = %q; want = %q", etag, "mock-etag") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }) +} + +func TestGetIfChanged(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{ + Resp: want, + Header: map[string]string{"ETag": "new-etag"}, + } + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + ok, etag, err := ref.GetIfChanged("old-etag", &got) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Errorf("Get() = %v; want = %v", ok, true) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + if etag != "new-etag" { + t.Errorf("ETag = %q; want = %q", etag, "new-etag") + } + + mock.Status = http.StatusNotModified + mock.Resp = nil + var got2 map[string]interface{} + ok, etag, err = ref.GetIfChanged("new-etag", &got2) + if err != nil { + t.Fatal(err) + } + if ok { + t.Errorf("Get() = %v; want = %v", ok, false) + } + if got2 != nil { + t.Errorf("Get() = %v; want nil", got2) + } + if etag != "new-etag" { + t.Errorf("ETag = %q; want = %q", etag, "new-etag") + } + + checkAllRequests(t, mock.Reqs, []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"If-None-Match": []string{"old-etag"}}, + }, + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"If-None-Match": []string{"new-etag"}}, + }, + }) +} + +func TestWerlformedHttpError(t *testing.T) { + mock := &mockServer{Resp: map[string]string{"error": "test error"}, Status: 500} + srv := mock.Start(client) + defer srv.Close() + + var got person + err := ref.Get(&got) + want := "http error status: 500; reason: test error" + if err == nil || err.Error() != want { + t.Errorf("Get() = %v; want = %v", err, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) +} + +func TestUnexpectedHttpError(t *testing.T) { + mock := &mockServer{Resp: "unexpected error", Status: 500} + srv := mock.Start(client) + defer srv.Close() + + var got person + err := ref.Get(&got) + want := "http error status: 500; message: \"unexpected error\"" + if err == nil || err.Error() != want { + t.Errorf("Get() = %v; want = %v", err, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) +} + +func TestSet(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + if err := ref.Set(want); err != nil { + t.Fatal(err) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(want), + Query: map[string]string{"print": "silent"}, + }) +} + +func TestSetWithStruct(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + want := &person{"Peter Parker", 17} + if err := ref.Set(&want); err != nil { + t.Fatal(err) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(want), + Query: map[string]string{"print": "silent"}, + }) +} + +func TestSetIfUnchanged(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + want := &person{"Peter Parker", 17} + ok, err := ref.SetIfUnchanged("mock-etag", &want) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Errorf("SetIfUnchanged() = %v; want = %v", ok, true) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(want), + Header: http.Header{"If-Match": []string{"mock-etag"}}, + }) +} + +func TestSetIfUnchangedError(t *testing.T) { + mock := &mockServer{ + Status: http.StatusPreconditionFailed, + Resp: &person{"Tony Stark", 39}, + } + srv := mock.Start(client) + defer srv.Close() + + want := &person{"Peter Parker", 17} + ok, err := ref.SetIfUnchanged("mock-etag", &want) + if err != nil { + t.Fatal(err) + } + if ok { + t.Errorf("SetIfUnchanged() = %v; want = %v", ok, false) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(want), + Header: http.Header{"If-Match": []string{"mock-etag"}}, + }) +} + +func TestPush(t *testing.T) { + mock := &mockServer{Resp: map[string]string{"name": "new_key"}} + srv := mock.Start(client) + defer srv.Close() + + child, err := ref.Push(nil) + if err != nil { + t.Fatal(err) + } + + if child.Key != "new_key" { + t.Errorf("Push() = %q; want = %q", child.Key, "new_key") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "POST", + Path: "/peter.json", + Body: serialize(""), + }) +} + +func TestPushWithValue(t *testing.T) { + mock := &mockServer{Resp: map[string]string{"name": "new_key"}} + srv := mock.Start(client) + defer srv.Close() + + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + child, err := ref.Push(want) + if err != nil { + t.Fatal(err) + } + + if child.Key != "new_key" { + t.Errorf("Push() = %q; want = %q", child.Key, "new_key") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "POST", + Path: "/peter.json", + Body: serialize(want), + }) +} + +func TestUpdate(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + if err := ref.Update(want); err != nil { + t.Fatal(err) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PATCH", + Path: "/peter.json", + Body: serialize(want), + Query: map[string]string{"print": "silent"}, + }) +} + +func TestInvalidUpdate(t *testing.T) { + if err := ref.Update(nil); err == nil { + t.Errorf("Update(nil) = nil; want error") + } + + m := make(map[string]interface{}) + if err := ref.Update(m); err == nil { + t.Errorf("Update(map{}) = nil; want error") + } +} + +func TestTransaction(t *testing.T) { + mock := &mockServer{ + Resp: &person{"Peter Parker", 17}, + Header: map[string]string{"ETag": "mock-etag"}, + } + srv := mock.Start(client) + defer srv.Close() + + var fn UpdateFn = func(i interface{}) (interface{}, error) { + p := i.(map[string]interface{}) + p["age"] = p["age"].(float64) + 1.0 + return p, nil + } + if err := ref.Transaction(fn); err != nil { + t.Fatal(err) + } + checkAllRequests(t, mock.Reqs, []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }, + &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 18, + }), + Header: http.Header{"If-Match": []string{"mock-etag"}}, + }, + }) +} + +func TestTransactionRetry(t *testing.T) { + mock := &mockServer{ + Resp: &person{"Peter Parker", 17}, + Header: map[string]string{"ETag": "mock-etag1"}, + } + srv := mock.Start(client) + defer srv.Close() + + cnt := 0 + var fn UpdateFn = func(i interface{}) (interface{}, error) { + if cnt == 0 { + mock.Status = http.StatusPreconditionFailed + mock.Header = map[string]string{"ETag": "mock-etag2"} + mock.Resp = &person{"Peter Parker", 19} + } else if cnt == 1 { + mock.Status = http.StatusOK + } + cnt++ + p := i.(map[string]interface{}) + p["age"] = p["age"].(float64) + 1.0 + return p, nil + } + if err := ref.Transaction(fn); err != nil { + t.Fatal(err) + } + if cnt != 2 { + t.Errorf("Retry Count = %d; want = %d", cnt, 2) + } + checkAllRequests(t, mock.Reqs, []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }, + &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 18, + }), + Header: http.Header{"If-Match": []string{"mock-etag1"}}, + }, + &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 20, + }), + Header: http.Header{"If-Match": []string{"mock-etag2"}}, + }, + }) +} + +func TestTransactionAbort(t *testing.T) { + mock := &mockServer{ + Resp: &person{"Peter Parker", 17}, + Header: map[string]string{"ETag": "mock-etag1"}, + } + srv := mock.Start(client) + defer srv.Close() + + cnt := 0 + var fn UpdateFn = func(i interface{}) (interface{}, error) { + if cnt == 0 { + mock.Status = http.StatusPreconditionFailed + mock.Header = map[string]string{"ETag": "mock-etag1"} + } + cnt++ + p := i.(map[string]interface{}) + p["age"] = p["age"].(float64) + 1.0 + return p, nil + } + err := ref.Transaction(fn) + if err == nil { + t.Errorf("Transaction() = nil; want error") + } + wanted := []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }, + } + for i := 0; i < 20; i++ { + wanted = append(wanted, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 18, + }), + Header: http.Header{"If-Match": []string{"mock-etag1"}}, + }) + } + checkAllRequests(t, mock.Reqs, wanted) +} + +func TestDelete(t *testing.T) { + mock := &mockServer{Resp: "null"} + srv := mock.Start(client) + defer srv.Close() + + if err := ref.Delete(); err != nil { + t.Fatal(err) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "DELETE", + Path: "/peter.json", + }) +} diff --git a/firebase.go b/firebase.go index 6ae0ca52..68c3337d 100644 --- a/firebase.go +++ b/firebase.go @@ -36,8 +36,9 @@ const Version = "2.0.0" // An App holds configuration and state common to all Firebase services that are exposed from the SDK. type App struct { + authOverrides map[string]interface{} creds *google.DefaultCredentials - databaseURL string + dbURL string projectID string storageBucket string opts []option.ClientOption @@ -45,6 +46,7 @@ type App struct { // Config represents the configuration used to initialize an App. type Config struct { + AuthOverrides map[string]interface{} DatabaseURL string ProjectID string StorageBucket string @@ -63,9 +65,10 @@ func (a *App) Auth(ctx context.Context) (*auth.Client, error) { // Database returns an instance of db.Client. func (a *App) Database(ctx context.Context) (*db.Client, error) { conf := &internal.DatabaseConfig{ - BaseURL: a.databaseURL, - Opts: a.opts, - Version: Version, + AuthOverrides: a.authOverrides, + BaseURL: a.dbURL, + Opts: a.opts, + Version: Version, } return db.NewClient(ctx, conf) } @@ -107,8 +110,9 @@ func NewApp(ctx context.Context, config *Config, opts ...option.ClientOption) (* } return &App{ + authOverrides: config.AuthOverrides, creds: creds, - databaseURL: config.DatabaseURL, + dbURL: config.DatabaseURL, projectID: pid, storageBucket: config.StorageBucket, opts: o, diff --git a/firebase_test.go b/firebase_test.go index fb411279..919518aa 100644 --- a/firebase_test.go +++ b/firebase_test.go @@ -216,6 +216,19 @@ func TestAuth(t *testing.T) { } } +func TestDatabase(t *testing.T) { + ctx := context.Background() + conf := &Config{DatabaseURL: "https://mock-db.firebaseio.com"} + app, err := NewApp(ctx, conf, option.WithCredentialsFile("testdata/service_account.json")) + if err != nil { + t.Fatal(err) + } + + if c, err := app.Database(ctx); c == nil || err != nil { + t.Errorf("Database() = (%v, %v); want (db, nil)", c, err) + } +} + func TestStorage(t *testing.T) { ctx := context.Background() app, err := NewApp(ctx, nil, option.WithCredentialsFile("testdata/service_account.json")) diff --git a/integration/auth/auth_test.go b/integration/auth/auth_test.go index 11ea9463..44daa3a4 100644 --- a/integration/auth/auth_test.go +++ b/integration/auth/auth_test.go @@ -44,7 +44,7 @@ func TestMain(m *testing.M) { } ctx := context.Background() - app, err := internal.NewTestApp(ctx) + app, err := internal.NewTestApp(ctx, nil) if err != nil { log.Fatalln(err) } diff --git a/integration/db/db_test.go b/integration/db/db_test.go index 53b26b4e..7dc7f60e 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -9,6 +9,8 @@ import ( "os" "testing" + "firebase.google.com/go" + "io/ioutil" "encoding/json" @@ -22,12 +24,18 @@ import ( ) var client *db.Client +var aoClient *db.Client +var guestClient *db.Client + var ref *db.Ref var users *db.Ref var dinos *db.Ref + var testData map[string]interface{} var parsedTestData map[string]Dinosaur +const permDenied = "http error status: 401; reason: Permission denied" + func TestMain(m *testing.M) { flag.Parse() if testing.Short() { @@ -35,13 +43,17 @@ func TestMain(m *testing.M) { os.Exit(0) } - ctx := context.Background() - app, err := internal.NewTestApp(ctx) + pid, err := internal.ProjectID() if err != nil { log.Fatalln(err) } - client, err = app.Database(ctx) + client, err = initClient(pid) + if err != nil { + log.Fatalln(err) + } + + aoClient, err = initOverrideClient(pid) if err != nil { log.Fatalln(err) } @@ -53,6 +65,31 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } +func initClient(pid string) (*db.Client, error) { + ctx := context.Background() + app, err := internal.NewTestApp(ctx, &firebase.Config{ + DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), + }) + if err != nil { + return nil, err + } + + return app.Database(ctx) +} + +func initOverrideClient(pid string) (*db.Client, error) { + ctx := context.Background() + app, err := internal.NewTestApp(ctx, &firebase.Config{ + DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), + AuthOverrides: map[string]interface{}{"uid": "user1"}, + }) + if err != nil { + return nil, err + } + + return app.Database(ctx) +} + func initRefs() { var err error ref, err = client.NewRef("_adminsdk/go/dinodb") @@ -521,6 +558,83 @@ func TestDelete(t *testing.T) { } } +func TestNoAccess(t *testing.T) { + r, err := aoClient.NewRef(protectedRef(t, "_adminsdk/go/admin")) + if err != nil { + t.Fatal(err) + } + var got string + if err := r.Get(&got); err == nil || got != "" { + t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } + if err := r.Set("update"); err == nil { + t.Errorf("Set() = nil; want = error") + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } +} + +func TestReadAccess(t *testing.T) { + r, err := aoClient.NewRef(protectedRef(t, "_adminsdk/go/protected/user2")) + if err != nil { + t.Fatal(err) + } + var got string + if err := r.Get(&got); err != nil || got != "test" { + t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") + } + if err := r.Set("update"); err == nil { + t.Errorf("Set() = nil; want = error") + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } +} + +func TestReadWriteAccess(t *testing.T) { + r, err := aoClient.NewRef(protectedRef(t, "_adminsdk/go/protected/user1")) + if err != nil { + t.Fatal(err) + } + var got string + if err := r.Get(&got); err != nil || got != "test" { + t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") + } + if err := r.Set("update"); err != nil { + t.Errorf("Set() = %v; want = nil", err) + } +} + +func TestQueryAccess(t *testing.T) { + r, err := aoClient.NewRef("_adminsdk/go/protected") + if err != nil { + t.Fatal(err) + } + + q, err := r.OrderByKey(db.WithLimitToFirst(2)) + if err != nil { + t.Fatal(err) + } + got := make(map[string]interface{}) + if err := q.Get(&got); err == nil { + t.Errorf("OrderByQuery() = nil; want = error") + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } +} + +func protectedRef(t *testing.T, p string) string { + r, err := client.NewRef(p) + if err != nil { + t.Fatal(err) + } + if err := r.Set("test"); err != nil { + t.Fatal(err) + } + return p +} + type Dinosaur struct { Appeared int `json:"appeared"` Height float64 `json:"height"` diff --git a/integration/internal/internal.go b/integration/internal/internal.go index f210d92d..1d8ea9b3 100644 --- a/integration/internal/internal.go +++ b/integration/internal/internal.go @@ -25,8 +25,6 @@ import ( "golang.org/x/net/context" - "fmt" - firebase "firebase.google.com/go" "firebase.google.com/go/internal" "google.golang.org/api/option" @@ -47,16 +45,8 @@ func Resource(name string) string { // NewTestApp looks for a service account JSON file named integration_cert.json // in the testdata directory. This file is used to initialize the newly created // App instance. -func NewTestApp(ctx context.Context) (*firebase.App, error) { - pid, err := ProjectID() - if err != nil { - return nil, err - } - config := &firebase.Config{ - DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), - StorageBucket: fmt.Sprintf("%s.appspot.com", pid), - } - return firebase.NewApp(ctx, config, option.WithCredentialsFile(Resource(certPath))) +func NewTestApp(ctx context.Context, conf *firebase.Config) (*firebase.App, error) { + return firebase.NewApp(ctx, conf, option.WithCredentialsFile(Resource(certPath))) } // APIKey fetches a Firebase API key for integration tests. diff --git a/integration/storage/storage_test.go b/integration/storage/storage_test.go index 078e61c6..d56bb470 100644 --- a/integration/storage/storage_test.go +++ b/integration/storage/storage_test.go @@ -22,6 +22,8 @@ import ( "os" "testing" + "firebase.google.com/go" + gcs "cloud.google.com/go/storage" "firebase.google.com/go/integration/internal" "firebase.google.com/go/storage" @@ -38,8 +40,15 @@ func TestMain(m *testing.M) { os.Exit(0) } + pid, err := internal.ProjectID() + if err != nil { + log.Fatalln(err) + } + ctx = context.Background() - app, err := internal.NewTestApp(ctx) + app, err := internal.NewTestApp(ctx, &firebase.Config{ + StorageBucket: fmt.Sprintf("%s.appspot.com", pid), + }) if err != nil { log.Fatalln(err) } From da6fc75b3272650b4809868d2a39175edf81fa24 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Tue, 24 Oct 2017 22:10:47 -0700 Subject: [PATCH 15/58] Implemented AuthOverride support; Added tests --- db/db.go | 8 +++-- db/db_test.go | 49 ++++++++++++++++++++++++--- firebase.go | 19 +++++++---- firebase_test.go | 32 ++++++++++++++++++ integration/db/db_test.go | 69 ++++++++++++++++++++++++++++++++++++++- internal/internal.go | 8 ++--- 6 files changed, 166 insertions(+), 19 deletions(-) diff --git a/db/db.go b/db/db.go index b96a8e8a..d0fefe65 100644 --- a/db/db.go +++ b/db/db.go @@ -42,6 +42,10 @@ type Client struct { ao string } +type AuthOverrides struct { + Map map[string]interface{} +} + func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) { userAgent := fmt.Sprintf(userAgent, c.Version, runtime.Version()) o := []option.ClientOption{option.WithUserAgent(userAgent)} @@ -64,8 +68,8 @@ func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) } var ao []byte - if c.AuthOverrides != nil { - ao, err = json.Marshal(c.AuthOverrides) + if c.AO == nil || len(c.AO) > 0 { + ao, err = json.Marshal(c.AO) if err != nil { return nil, err } diff --git a/db/db_test.go b/db/db_test.go index 4b462d7a..b73851d5 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -42,6 +42,7 @@ func TestMain(m *testing.M) { Opts: testOpts, BaseURL: testURL, Version: "1.2.3", + AO: map[string]interface{}{}, }) if err != nil { log.Fatalln(err) @@ -49,10 +50,10 @@ func TestMain(m *testing.M) { ao := map[string]interface{}{"uid": "user1"} aoClient, err = NewClient(context.Background(), &internal.DatabaseConfig{ - Opts: testOpts, - BaseURL: testURL, - Version: "1.2.3", - AuthOverrides: ao, + Opts: testOpts, + BaseURL: testURL, + Version: "1.2.3", + AO: ao, }) if err != nil { log.Fatalln(err) @@ -77,15 +78,50 @@ func TestNewClient(t *testing.T) { c, err := NewClient(context.Background(), &internal.DatabaseConfig{ Opts: testOpts, BaseURL: testURL, + AO: make(map[string]interface{}), }) if err != nil { t.Fatal(err) } if c.baseURL != testURL { t.Errorf("BaseURL = %q; want: %q", c.baseURL, testURL) - } else if c.hc == nil { + } + if c.hc == nil { t.Errorf("http.Client = nil; want non-nil") } + if c.ao != "" { + t.Errorf("AuthOverrides = %q; want %q", c.ao, "") + } +} + +func TestNewClientAuthOverrides(t *testing.T) { + cases := []map[string]interface{}{ + nil, + map[string]interface{}{"uid": "user1"}, + } + for _, tc := range cases { + c, err := NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + BaseURL: testURL, + AO: tc, + }) + if err != nil { + t.Fatal(err) + } + if c.baseURL != testURL { + t.Errorf("BaseURL = %q; want: %q", c.baseURL, testURL) + } + if c.hc == nil { + t.Errorf("http.Client = nil; want non-nil") + } + b, err := json.Marshal(tc) + if err != nil { + t.Fatal(err) + } + if c.ao != string(b) { + t.Errorf("AuthOverrides = %q; want %q", c.ao, string(b)) + } + } } func TestNewClientError(t *testing.T) { @@ -256,6 +292,9 @@ func checkRequest(t *testing.T, got, want *testReq) { if got.Path != want.Path { t.Errorf("Path = %q; want = %q", got.Path, want.Path) } + if len(want.Query) != len(got.Query) { + t.Errorf("QueryParam = %v; want = %v", got.Query, want.Query) + } for k, v := range want.Query { if got.Query[k] != v { t.Errorf("QueryParam(%v) = %v; want = %v", k, got.Query[k], v) diff --git a/firebase.go b/firebase.go index 68c3337d..8e0c6c84 100644 --- a/firebase.go +++ b/firebase.go @@ -36,7 +36,7 @@ const Version = "2.0.0" // An App holds configuration and state common to all Firebase services that are exposed from the SDK. type App struct { - authOverrides map[string]interface{} + ao map[string]interface{} creds *google.DefaultCredentials dbURL string projectID string @@ -46,7 +46,7 @@ type App struct { // Config represents the configuration used to initialize an App. type Config struct { - AuthOverrides map[string]interface{} + AuthOverrides *db.AuthOverrides DatabaseURL string ProjectID string StorageBucket string @@ -65,10 +65,10 @@ func (a *App) Auth(ctx context.Context) (*auth.Client, error) { // Database returns an instance of db.Client. func (a *App) Database(ctx context.Context) (*db.Client, error) { conf := &internal.DatabaseConfig{ - AuthOverrides: a.authOverrides, - BaseURL: a.dbURL, - Opts: a.opts, - Version: Version, + AO: a.ao, + BaseURL: a.dbURL, + Opts: a.opts, + Version: Version, } return db.NewClient(ctx, conf) } @@ -109,8 +109,13 @@ func NewApp(ctx context.Context, config *Config, opts ...option.ClientOption) (* pid = os.Getenv("GCLOUD_PROJECT") } + ao := make(map[string]interface{}) + if config.AuthOverrides != nil { + ao = config.AuthOverrides.Map + } + return &App{ - authOverrides: config.AuthOverrides, + ao: ao, creds: creds, dbURL: config.DatabaseURL, projectID: pid, diff --git a/firebase_test.go b/firebase_test.go index 919518aa..dc028a1a 100644 --- a/firebase_test.go +++ b/firebase_test.go @@ -30,6 +30,9 @@ import ( "encoding/json" + "reflect" + + "firebase.google.com/go/db" "golang.org/x/net/context" "golang.org/x/oauth2" "google.golang.org/api/option" @@ -224,11 +227,40 @@ func TestDatabase(t *testing.T) { t.Fatal(err) } + if app.ao == nil || len(app.ao) != 0 { + t.Errorf("AuthOverrides = %v; want = empty map", app.ao) + } if c, err := app.Database(ctx); c == nil || err != nil { t.Errorf("Database() = (%v, %v); want (db, nil)", c, err) } } +func TestDatabaseAuthOverrides(t *testing.T) { + cases := []map[string]interface{}{ + nil, + map[string]interface{}{}, + map[string]interface{}{"uid": "user1"}, + } + for _, tc := range cases { + ctx := context.Background() + conf := &Config{ + AuthOverrides: &db.AuthOverrides{tc}, + DatabaseURL: "https://mock-db.firebaseio.com", + } + app, err := NewApp(ctx, conf, option.WithCredentialsFile("testdata/service_account.json")) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(app.ao, tc) { + t.Errorf("AuthOverrides = %v; want = %v", app.ao, tc) + } + if c, err := app.Database(ctx); c == nil || err != nil { + t.Errorf("Database() = (%v, %v); want (db, nil)", c, err) + } + } +} + func TestStorage(t *testing.T) { ctx := context.Background() app, err := NewApp(ctx, nil, option.WithCredentialsFile("testdata/service_account.json")) diff --git a/integration/db/db_test.go b/integration/db/db_test.go index 7dc7f60e..e75d34d4 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -58,6 +58,11 @@ func TestMain(m *testing.M) { log.Fatalln(err) } + guestClient, err = initGuestClient(pid) + if err != nil { + log.Fatalln(err) + } + initRefs() initRules() initData() @@ -78,10 +83,25 @@ func initClient(pid string) (*db.Client, error) { } func initOverrideClient(pid string) (*db.Client, error) { + ctx := context.Background() + app, err := internal.NewTestApp(ctx, &firebase.Config{ + DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), + AuthOverrides: &db.AuthOverrides{ + Map: map[string]interface{}{"uid": "user1"}, + }, + }) + if err != nil { + return nil, err + } + + return app.Database(ctx) +} + +func initGuestClient(pid string) (*db.Client, error) { ctx := context.Background() app, err := internal.NewTestApp(ctx, &firebase.Config{ DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), - AuthOverrides: map[string]interface{}{"uid": "user1"}, + AuthOverrides: &db.AuthOverrides{}, }) if err != nil { return nil, err @@ -624,6 +644,53 @@ func TestQueryAccess(t *testing.T) { } } +func TestGuestAccess(t *testing.T) { + r, err := guestClient.NewRef(protectedRef(t, "_adminsdk/go/public")) + if err != nil { + t.Fatal(err) + } + var got string + if err := r.Get(&got); err != nil || got != "test" { + t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") + } + if err := r.Set("update"); err == nil { + t.Errorf("Set() = nil; want = error") + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } + + got = "" + r, err = guestClient.NewRef("_adminsdk/go") + if err != nil { + t.Fatal(err) + } + if err := r.Get(&got); err == nil || got != "" { + t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } + + c, err := r.Child("protected/user2") + if err != nil { + t.Fatal(err) + } + if err := c.Get(&got); err == nil || got != "" { + t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } + + c, err = r.Child("admin") + if err != nil { + t.Fatal(err) + } + if err := c.Get(&got); err == nil || got != "" { + t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } +} + func protectedRef(t *testing.T, p string) string { r, err := client.NewRef(p) if err != nil { diff --git a/internal/internal.go b/internal/internal.go index 0d1d97c7..262f356b 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -35,10 +35,10 @@ type AuthConfig struct { } type DatabaseConfig struct { - Opts []option.ClientOption - BaseURL string - Version string - AuthOverrides map[string]interface{} + Opts []option.ClientOption + BaseURL string + Version string + AO map[string]interface{} } // StorageConfig represents the configuration of Google Cloud Storage service. From 641dbaad9dce3e0788945e9f5d0d2373067dbc29 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Thu, 26 Oct 2017 23:02:41 -0700 Subject: [PATCH 16/58] Implementing the new API --- db/auth_override_test.go | 36 ++--------- db/db.go | 73 +++++++++++----------- db/db_test.go | 85 +++++++++++-------------- db/http_client.go | 34 +++++++--- db/query.go | 81 ++++++++++++++++-------- db/query_test.go | 117 ++++++++++++----------------------- db/ref.go | 24 +++---- integration/db/db_test.go | 97 ++++++----------------------- integration/db/query_test.go | 50 +++------------ 9 files changed, 239 insertions(+), 358 deletions(-) diff --git a/db/auth_override_test.go b/db/auth_override_test.go index 49a7f20e..9d295bc7 100644 --- a/db/auth_override_test.go +++ b/db/auth_override_test.go @@ -9,11 +9,7 @@ func TestAuthOverrideGet(t *testing.T) { srv := mock.Start(aoClient) defer srv.Close() - ref, err := aoClient.NewRef("peter") - if err != nil { - t.Fatal(err) - } - + ref := aoClient.NewRef("peter") var got string if err := ref.Get(&got); err != nil { t.Fatal(err) @@ -33,11 +29,7 @@ func TestAuthOverrideSet(t *testing.T) { srv := mock.Start(aoClient) defer srv.Close() - ref, err := aoClient.NewRef("peter") - if err != nil { - t.Fatal(err) - } - + ref := aoClient.NewRef("peter") want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} if err := ref.Set(want); err != nil { t.Fatal(err) @@ -55,17 +47,9 @@ func TestAuthOverrideQuery(t *testing.T) { srv := mock.Start(aoClient) defer srv.Close() - ref, err := aoClient.NewRef("peter") - if err != nil { - t.Fatal(err) - } - - q, err := ref.OrderByChild("foo") - if err != nil { - t.Fatal(err) - } + ref := aoClient.NewRef("peter") var got string - if err := q.Get(&got); err != nil { + if err := ref.OrderByChild("foo").Get(&got); err != nil { t.Fatal(err) } if got != "data" { @@ -86,17 +70,9 @@ func TestAuthOverrideRangeQuery(t *testing.T) { srv := mock.Start(aoClient) defer srv.Close() - ref, err := aoClient.NewRef("peter") - if err != nil { - t.Fatal(err) - } - - q, err := ref.OrderByChild("foo", WithStartAt(1), WithEndAt(10)) - if err != nil { - t.Fatal(err) - } + ref := aoClient.NewRef("peter") var got string - if err := q.Get(&got); err != nil { + if err := ref.OrderByChild("foo", WithStartAt(1), WithEndAt(10)).Get(&got); err != nil { t.Fatal(err) } if got != "data" { diff --git a/db/db.go b/db/db.go index d0fefe65..0a8ec86d 100644 --- a/db/db.go +++ b/db/db.go @@ -16,6 +16,7 @@ package db import ( + "encoding/json" "fmt" "net/http" "runtime" @@ -25,45 +26,39 @@ import ( "net/url" - "encoding/json" - "golang.org/x/net/context" "google.golang.org/api/option" "google.golang.org/api/transport" ) const invalidChars = "[].#$" -const userAgent = "Firebase/HTTP/%s/%s/AdminGo" +const userAgentFormat = "Firebase/HTTP/%s/%s/AdminGo" // Client is the interface for the Firebase Realtime Database service. type Client struct { - hc *http.Client - baseURL string - ao string -} - -type AuthOverrides struct { - Map map[string]interface{} + hc *http.Client + url string + ao string } func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) { - userAgent := fmt.Sprintf(userAgent, c.Version, runtime.Version()) - o := []option.ClientOption{option.WithUserAgent(userAgent)} - o = append(o, c.Opts...) - - hc, _, err := transport.NewHTTPClient(ctx, o...) + opts := append([]option.ClientOption{}, c.Opts...) + ua := fmt.Sprintf(userAgentFormat, c.Version, runtime.Version()) + opts = append(opts, option.WithUserAgent(ua)) + hc, _, err := transport.NewHTTPClient(ctx, opts...) if err != nil { return nil, err } + if c.BaseURL == "" { return nil, fmt.Errorf("database url not specified") } - url, err := url.Parse(c.BaseURL) + p, err := url.Parse(c.BaseURL) if err != nil { return nil, err - } else if url.Scheme != "https" { + } else if p.Scheme != "https" { return nil, fmt.Errorf("invalid database URL (incorrect scheme): %q", c.BaseURL) - } else if !strings.HasSuffix(url.Host, ".firebaseio.com") { + } else if !strings.HasSuffix(p.Host, ".firebaseio.com") { return nil, fmt.Errorf("invalid database URL (incorrest host): %q", c.BaseURL) } @@ -74,47 +69,51 @@ func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) return nil, err } } + return &Client{ - hc: hc, - baseURL: fmt.Sprintf("https://%s", url.Host), - ao: string(ao), + hc: hc, + url: fmt.Sprintf("https://%s", p.Host), + ao: string(ao), }, nil } -func (c *Client) NewRef(path string) (*Ref, error) { - segs, err := parsePath(path) - if err != nil { - return nil, err +func newHTTPOptions(m map[string]interface{}) ([]httpOption, error) { + var opts []httpOption + if m == nil || len(m) > 0 { + ao, err := json.Marshal(m) + if err != nil { + return nil, err + } + opts = append(opts, withQueryParam("auth_variable_override", string(ao))) } + return opts, nil +} + +type AuthOverrides struct { + Map map[string]interface{} +} +func (c *Client) NewRef(path string) *Ref { + segs := parsePath(path) key := "" if len(segs) > 0 { key = segs[len(segs)-1] } - var opts []httpOption - if c.ao != "" { - opts = append(opts, withQueryParam("auth_variable_override", c.ao)) - } - return &Ref{ Key: key, Path: "/" + strings.Join(segs, "/"), client: c, segs: segs, - opts: opts, - }, nil + } } -func parsePath(path string) ([]string, error) { - if strings.ContainsAny(path, invalidChars) { - return nil, fmt.Errorf("path %q contains one or more invalid characters", path) - } +func parsePath(path string) []string { var segs []string for _, s := range strings.Split(path, "/") { if s != "" { segs = append(segs, s) } } - return segs, nil + return segs } diff --git a/db/db_test.go b/db/db_test.go index b73851d5..934ee8f4 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -65,12 +65,8 @@ func TestMain(m *testing.M) { } testAuthOverrides = string(b) - ref, err = client.NewRef("peter") - if err != nil { - log.Fatalln(err) - } - - testUserAgent = fmt.Sprintf(userAgent, "1.2.3", runtime.Version()) + ref = client.NewRef("peter") + testUserAgent = fmt.Sprintf(userAgentFormat, "1.2.3", runtime.Version()) os.Exit(m.Run()) } @@ -83,8 +79,8 @@ func TestNewClient(t *testing.T) { if err != nil { t.Fatal(err) } - if c.baseURL != testURL { - t.Errorf("BaseURL = %q; want: %q", c.baseURL, testURL) + if c.url != testURL { + t.Errorf("BaseURL = %q; want: %q", c.url, testURL) } if c.hc == nil { t.Errorf("http.Client = nil; want non-nil") @@ -108,8 +104,8 @@ func TestNewClientAuthOverrides(t *testing.T) { if err != nil { t.Fatal(err) } - if c.baseURL != testURL { - t.Errorf("BaseURL = %q; want: %q", c.baseURL, testURL) + if c.url != testURL { + t.Errorf("BaseURL = %q; want: %q", c.url, testURL) } if c.hc == nil { t.Errorf("http.Client = nil; want non-nil") @@ -157,10 +153,7 @@ func TestNewRef(t *testing.T) { {"/foo/bar/", "/foo/bar", "bar"}, } for _, tc := range cases { - r, err := client.NewRef(tc.Path) - if err != nil { - t.Fatal(err) - } + r := client.NewRef(tc.Path) if r.client == nil { t.Errorf("Client = nil; want = %v", client) } @@ -173,16 +166,6 @@ func TestNewRef(t *testing.T) { } } -func TestInvalidNewRef(t *testing.T) { - cases := []string{"foo#", "foo.", "foo$", "foo[", "foo]"} - for _, tc := range cases { - r, err := client.NewRef(tc) - if r != nil || err == nil { - t.Errorf("NewRef(%q) = (%v, %v); want = (nil, err)", tc, r, err) - } - } -} - func TestParent(t *testing.T) { cases := []struct { Path string @@ -198,12 +181,7 @@ func TestParent(t *testing.T) { {"/foo/bar/", true, "foo"}, } for _, tc := range cases { - r, err := client.NewRef(tc.Path) - if err != nil { - t.Fatal(err) - } - - r = r.Parent() + r := client.NewRef(tc.Path).Parent() if tc.HasParent { if r == nil { t.Fatalf("Parent = nil; want = %q", tc.Want) @@ -221,25 +199,28 @@ func TestParent(t *testing.T) { } func TestChild(t *testing.T) { - r, err := client.NewRef("/test") - if err != nil { - t.Fatal(err) - } - + r := client.NewRef("/test") cases := []struct { Path string Want string Parent string }{ + {"", "/test", "/"}, {"foo", "/test/foo", "/test"}, + {"/foo", "/test/foo", "/test"}, + {"foo/", "/test/foo", "/test"}, + {"/foo/", "/test/foo", "/test"}, + {"//foo//", "/test/foo", "/test"}, {"foo/bar", "/test/foo/bar", "/test/foo"}, + {"/foo/bar", "/test/foo/bar", "/test/foo"}, {"foo/bar/", "/test/foo/bar", "/test/foo"}, + {"/foo/bar/", "/test/foo/bar", "/test/foo"}, + {"//foo/bar", "/test/foo/bar", "/test/foo"}, + {"foo//bar/", "/test/foo/bar", "/test/foo"}, + {"foo/bar//", "/test/foo/bar", "/test/foo"}, } for _, tc := range cases { - c, err := r.Child(tc.Path) - if err != nil { - t.Fatal(err) - } + c := r.Child(tc.Path) if c.Path != tc.Want { t.Errorf("Child(%q) = %q; want = %q", tc.Path, c.Path, tc.Want) } @@ -249,19 +230,25 @@ func TestChild(t *testing.T) { } } -func TestInvalidChild(t *testing.T) { - r, err := client.NewRef("/test") - if err != nil { - t.Fatal(err) - } +func TestInvalidPath(t *testing.T) { + mock := &mockServer{Resp: "test"} + srv := mock.Start(client) + defer srv.Close() - cases := []string{"", "/", "/foo", "foo#bar"} + cases := []string{ + "foo$", "foo.", "foo#", "foo]", "foo[", + } for _, tc := range cases { - c, err := r.Child(tc) - if c != nil || err == nil { - t.Errorf("Child(%q) = (%v, %v); want = (nil, err)", tc, c, err) + r := client.NewRef(tc) + var got string + if err := r.Get(&got); got != "" || err == nil { + t.Errorf("Get() = (%q, %v); want = (%q, error)", got, err, "") } } + + if len(mock.Reqs) != 0 { + t.Errorf("Requests: %v; want: empty", mock.Reqs) + } } func checkOnlyRequest(t *testing.T, got []*testReq, want *testReq) { @@ -389,7 +376,7 @@ func (s *mockServer) Start(c *Client) *httptest.Server { w.Write(b) }) s.srv = httptest.NewServer(handler) - c.baseURL = s.srv.URL + c.url = s.srv.URL return s.srv } diff --git a/db/http_client.go b/db/http_client.go index fcc55f14..88fc2132 100644 --- a/db/http_client.go +++ b/db/http_client.go @@ -2,6 +2,7 @@ package db import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -9,8 +10,13 @@ import ( "net/http" ) -func (r *Ref) send(method string, body interface{}, opts ...httpOption) (*response, error) { - url := fmt.Sprintf("%s%s%s", r.client.baseURL, r.Path, ".json") +func (c *Client) send(method, path string, body interface{}, options ...httpOption) (*response, error) { + var opts []httpOption + if c.ao != "" { + opts = append(opts, withQueryParam("auth_variable_override", c.ao)) + } + opts = append(opts, options...) + var data io.Reader if body != nil { b, err := json.Marshal(body) @@ -21,16 +27,17 @@ func (r *Ref) send(method string, body interface{}, opts ...httpOption) (*respon opts = append(opts, withHeader("Content-Type", "application/json")) } + url := fmt.Sprintf("%s%s%s", c.url, path, ".json") req, err := http.NewRequest(method, url, data) if err != nil { return nil, err } - opts = append(opts, r.opts...) + for _, o := range opts { - o(req) + req = o(req) } - resp, err := r.client.hc.Do(req) + resp, err := c.hc.Do(req) if err != nil { return nil, err } @@ -79,28 +86,37 @@ func (r *response) CheckAndParse(want int, v interface{}) error { return nil } -type httpOption func(*http.Request) +type httpOption func(*http.Request) *http.Request func withHeader(key, value string) httpOption { - return func(r *http.Request) { + return func(r *http.Request) *http.Request { r.Header.Set(key, value) + return r } } func withQueryParam(key, value string) httpOption { - return func(r *http.Request) { + return func(r *http.Request) *http.Request { q := r.URL.Query() q.Add(key, value) r.URL.RawQuery = q.Encode() + return r } } func withQueryParams(qp queryParams) httpOption { - return func(r *http.Request) { + return func(r *http.Request) *http.Request { q := r.URL.Query() for k, v := range qp { q.Add(k, v) } r.URL.RawQuery = q.Encode() + return r + } +} + +func withContext(ctx context.Context) httpOption { + return func(r *http.Request) *http.Request { + return r.WithContext(ctx) } } diff --git a/db/query.go b/db/query.go index 8faea3fc..5c70315f 100644 --- a/db/query.go +++ b/db/query.go @@ -1,6 +1,7 @@ package db import ( + "context" "encoding/json" "fmt" "net/http" @@ -14,19 +15,44 @@ var reservedFilters = map[string]bool{ "$priority": true, } -type Query struct { - ref *Ref - qp queryParams +type Query interface { + Get(v interface{}) error + WithContext(ctx context.Context) Query } -func (q *Query) Get(v interface{}) error { - resp, err := q.ref.send("GET", nil, withQueryParams(q.qp)) +type queryImpl struct { + ctx context.Context + ref *Ref + opts []QueryOption +} + +func (q *queryImpl) Get(v interface{}) error { + qp := make(queryParams) + for _, o := range q.opts { + if err := o.apply(qp); err != nil { + return err + } + } + + opts := []httpOption{withQueryParams(qp)} + if q.ctx != nil { + opts = append(opts, withContext(q.ctx)) + } + resp, err := q.ref.send("GET", nil, opts...) if err != nil { return err } return resp.CheckAndParse(http.StatusOK, v) } +func (q *queryImpl) WithContext(ctx context.Context) Query { + return &queryImpl{ + ctx: ctx, + ref: q.ref, + opts: q.opts, + } +} + type QueryOption interface { apply(qp queryParams) error } @@ -51,43 +77,44 @@ func WithEqualTo(v interface{}) QueryOption { return &filterParam{"equalTo", v} } -func (r *Ref) OrderByChild(child string, opts ...QueryOption) (*Query, error) { - if child == "" { - return nil, fmt.Errorf("child path must be a non-empty string") - } - if _, ok := reservedFilters[child]; ok { - return nil, fmt.Errorf("invalid child path: %s", child) - } - segs, err := parsePath(child) - if err != nil { - return nil, err - } - opts = append(opts, orderByParam(strings.Join(segs, "/"))) +func (r *Ref) OrderByChild(child string, opts ...QueryOption) Query { + opts = append(opts, orderByChild(child)) return newQuery(r, opts) } -func (r *Ref) OrderByKey(opts ...QueryOption) (*Query, error) { +func (r *Ref) OrderByKey(opts ...QueryOption) Query { opts = append(opts, orderByParam("$key")) return newQuery(r, opts) } -func (r *Ref) OrderByValue(opts ...QueryOption) (*Query, error) { +func (r *Ref) OrderByValue(opts ...QueryOption) Query { opts = append(opts, orderByParam("$value")) return newQuery(r, opts) } -func newQuery(r *Ref, opts []QueryOption) (*Query, error) { - qp := make(queryParams) - for _, o := range opts { - if err := o.apply(qp); err != nil { - return nil, err - } - } - return &Query{ref: r, qp: qp}, nil +func newQuery(r *Ref, opts []QueryOption) Query { + return &queryImpl{ref: r, opts: opts} } type queryParams map[string]string +type orderByChild string + +func (p orderByChild) apply(qp queryParams) error { + if p == "" { + return fmt.Errorf("empty child path") + } else if strings.ContainsAny(string(p), invalidChars) { + return fmt.Errorf("invalid child path with illegal characters: %q", p) + } + segs := parsePath(string(p)) + b, err := json.Marshal(strings.Join(segs, "/")) + if err != nil { + return nil + } + qp["orderBy"] = string(b) + return nil +} + type orderByParam string func (p orderByParam) apply(qp queryParams) error { diff --git a/db/query_test.go b/db/query_test.go index 8812128a..3442b27d 100644 --- a/db/query_test.go +++ b/db/query_test.go @@ -11,13 +11,8 @@ func TestChildQuery(t *testing.T) { srv := mock.Start(client) defer srv.Close() - q, err := ref.OrderByChild("messages") - if err != nil { - t.Fatal(err) - } - var got map[string]interface{} - if err := q.Get(&got); err != nil { + if err := ref.OrderByChild("messages").Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -36,13 +31,8 @@ func TestNestedChildQuery(t *testing.T) { srv := mock.Start(client) defer srv.Close() - q, err := ref.OrderByChild("messages/ratings") - if err != nil { - t.Fatal(err) - } - var got map[string]interface{} - if err := q.Get(&got); err != nil { + if err := ref.OrderByChild("messages/ratings").Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -66,13 +56,8 @@ func TestChildQueryWithParams(t *testing.T) { WithEndAt("m50"), WithLimitToFirst(10), } - q, err := ref.OrderByChild("messages", opts...) - if err != nil { - t.Fatal(err) - } - var got map[string]interface{} - if err := q.Get(&got); err != nil { + if err := ref.OrderByChild("messages", opts...).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -90,19 +75,34 @@ func TestChildQueryWithParams(t *testing.T) { }) } +func TestInvalidChildPath(t *testing.T) { + mock := &mockServer{Resp: "test"} + srv := mock.Start(client) + defer srv.Close() + + r := client.NewRef("/") + cases := []string{ + "foo$", "foo.", "foo#", "foo]", "foo[", + } + for _, tc := range cases { + var got string + if err := r.OrderByChild(tc).Get(&got); got != "" || err == nil { + t.Errorf("Get() = (%q, %v); want = (%q, error)", got, err, "") + } + } + if len(mock.Reqs) != 0 { + t.Errorf("Requests: %v; want: empty", mock.Reqs) + } +} + func TestKeyQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() - q, err := ref.OrderByKey() - if err != nil { - t.Fatal(err) - } - var got map[string]interface{} - if err := q.Get(&got); err != nil { + if err := ref.OrderByKey().Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -121,13 +121,8 @@ func TestValueQuery(t *testing.T) { srv := mock.Start(client) defer srv.Close() - q, err := ref.OrderByValue() - if err != nil { - t.Fatal(err) - } - var got map[string]interface{} - if err := q.Get(&got); err != nil { + if err := ref.OrderByValue().Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -146,13 +141,8 @@ func TestLimitFirstQuery(t *testing.T) { srv := mock.Start(client) defer srv.Close() - q, err := ref.OrderByChild("messages", WithLimitToFirst(10)) - if err != nil { - t.Fatal(err) - } - var got map[string]interface{} - if err := q.Get(&got); err != nil { + if err := ref.OrderByChild("messages", WithLimitToFirst(10)).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -171,13 +161,8 @@ func TestLimitLastQuery(t *testing.T) { srv := mock.Start(client) defer srv.Close() - q, err := ref.OrderByChild("messages", WithLimitToLast(10)) - if err != nil { - t.Fatal(err) - } - var got map[string]interface{} - if err := q.Get(&got); err != nil { + if err := ref.OrderByChild("messages", WithLimitToLast(10)).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -191,19 +176,18 @@ func TestLimitLastQuery(t *testing.T) { } func TestInvalidLimitQuery(t *testing.T) { - q, err := ref.OrderByChild("messages", WithLimitToFirst(10), WithLimitToLast(10)) - if q != nil || err == nil { - t.Errorf("Query(first=10, last=10) = (%v, %v); want (nil, error)", q, err) - } + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() - q, err = ref.OrderByChild("messages", WithLimitToFirst(-10)) - if q != nil || err == nil { - t.Errorf("Query(first=-10) = (%v, %v); want (nil, error)", q, err) + var got map[string]interface{} + q := ref.OrderByChild("messages", WithLimitToFirst(10), WithLimitToLast(10)) + if err := q.Get(&got); got != nil || err == nil { + t.Errorf("Get() = (%v, %v); want = (nil, error)", got, err) } - - q, err = ref.OrderByChild("messages", WithLimitToLast(-10)) - if q != nil || err == nil { - t.Errorf("Query(last=-10) = (%v, %v); want (nil, error)", q, err) + if len(mock.Reqs) != 0 { + t.Errorf("Requests: %v; want: empty", mock.Reqs) } } @@ -213,13 +197,8 @@ func TestStartAtQuery(t *testing.T) { srv := mock.Start(client) defer srv.Close() - q, err := ref.OrderByChild("messages", WithStartAt(10)) - if err != nil { - t.Fatal(err) - } - var got map[string]interface{} - if err := q.Get(&got); err != nil { + if err := ref.OrderByChild("messages", WithStartAt(10)).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -238,13 +217,8 @@ func TestEndAtQuery(t *testing.T) { srv := mock.Start(client) defer srv.Close() - q, err := ref.OrderByChild("messages", WithEndAt(10)) - if err != nil { - t.Fatal(err) - } - var got map[string]interface{} - if err := q.Get(&got); err != nil { + if err := ref.OrderByChild("messages", WithEndAt(10)).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -263,11 +237,7 @@ func TestAllParamsQuery(t *testing.T) { srv := mock.Start(client) defer srv.Close() - q, err := ref.OrderByChild("messages", WithLimitToFirst(100), WithStartAt("bar"), WithEndAt("foo")) - if err != nil { - t.Fatal(err) - } - + q := ref.OrderByChild("messages", WithLimitToFirst(100), WithStartAt("bar"), WithEndAt("foo")) var got map[string]interface{} if err := q.Get(&got); err != nil { t.Fatal(err) @@ -293,13 +263,8 @@ func TestEqualToQuery(t *testing.T) { srv := mock.Start(client) defer srv.Close() - q, err := ref.OrderByChild("messages", WithEqualTo(10)) - if err != nil { - t.Fatal(err) - } - var got map[string]interface{} - if err := q.Get(&got); err != nil { + if err := ref.OrderByChild("messages", WithEqualTo(10)).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { diff --git a/db/ref.go b/db/ref.go index 92267cbd..b00d8729 100644 --- a/db/ref.go +++ b/db/ref.go @@ -15,6 +15,7 @@ package db import ( + "context" "fmt" "net/http" "strings" @@ -24,27 +25,21 @@ type Ref struct { Key string Path string - client *Client segs []string - opts []httpOption + client *Client + ctx context.Context } func (r *Ref) Parent() *Ref { l := len(r.segs) if l > 0 { path := strings.Join(r.segs[:l-1], "/") - parent, _ := r.client.NewRef(path) - return parent + return r.client.NewRef(path) } return nil } -func (r *Ref) Child(path string) (*Ref, error) { - if path == "" { - return nil, fmt.Errorf("child path must not be empty") - } else if strings.HasPrefix(path, "/") { - return nil, fmt.Errorf("invalid child path with '/' prefix: %q", path) - } +func (r *Ref) Child(path string) *Ref { fp := fmt.Sprintf("%s/%s", r.Path, path) return r.client.NewRef(fp) } @@ -113,7 +108,7 @@ func (r *Ref) Push(v interface{}) (*Ref, error) { if err := resp.CheckAndParse(http.StatusOK, &d); err != nil { return nil, err } - return r.Child(d.Name) + return r.Child(d.Name), nil } func (r *Ref) Update(v map[string]interface{}) error { @@ -159,3 +154,10 @@ func (r *Ref) Delete() error { } return resp.CheckStatus(http.StatusOK) } + +func (r *Ref) send(method string, body interface{}, opts ...httpOption) (*response, error) { + if strings.ContainsAny(r.Path, invalidChars) { + return nil, fmt.Errorf("invalid path with illegal characters: %q", r.Path) + } + return r.client.send(method, r.Path, body, opts...) +} diff --git a/integration/db/db_test.go b/integration/db/db_test.go index e75d34d4..c73ec208 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -111,21 +111,9 @@ func initGuestClient(pid string) (*db.Client, error) { } func initRefs() { - var err error - ref, err = client.NewRef("_adminsdk/go/dinodb") - if err != nil { - log.Fatalln(err) - } - - dinos, err = ref.Child("dinosaurs") - if err != nil { - log.Fatalln(err) - } - - users, err = ref.Parent().Child("users") - if err != nil { - log.Fatalln(err) - } + ref = client.NewRef("_adminsdk/go/dinodb") + dinos = ref.Child("dinosaurs") + users = ref.Parent().Child("users") } func initRules() { @@ -196,10 +184,7 @@ func TestRef(t *testing.T) { } func TestChild(t *testing.T) { - c, err := ref.Child("dinosaurs") - if err != nil { - t.Fatal(err) - } + c := ref.Child("dinosaurs") if c.Key != "dinosaurs" { t.Errorf("Key = %q; want = %q", c.Key, "dinosaurs") } @@ -269,11 +254,7 @@ func TestGetIfChanged(t *testing.T) { } func TestGetChildValue(t *testing.T) { - c, err := ref.Child("dinosaurs") - if err != nil { - t.Fatal(err) - } - + c := ref.Child("dinosaurs") var m map[string]interface{} if err := c.Get(&m); err != nil { t.Fatal(err) @@ -284,11 +265,7 @@ func TestGetChildValue(t *testing.T) { } func TestGetGrandChildValue(t *testing.T) { - c, err := ref.Child("dinosaurs/lambeosaurus") - if err != nil { - t.Fatal(err) - } - + c := ref.Child("dinosaurs/lambeosaurus") var got Dinosaur if err := c.Get(&got); err != nil { t.Fatal(err) @@ -300,11 +277,7 @@ func TestGetGrandChildValue(t *testing.T) { } func TestGetNonExistingChild(t *testing.T) { - c, err := ref.Child("non_existing") - if err != nil { - t.Fatal(err) - } - + c := ref.Child("non_existing") var i interface{} if err := c.Get(&i); err != nil { t.Fatal(err) @@ -530,10 +503,7 @@ func TestTransaction(t *testing.T) { } func TestTransactionScalar(t *testing.T) { - cnt, err := users.Child("count") - if err != nil { - t.Fatal(err) - } + cnt := users.Child("count") if err := cnt.Set(42); err != nil { t.Fatal(err) } @@ -579,10 +549,7 @@ func TestDelete(t *testing.T) { } func TestNoAccess(t *testing.T) { - r, err := aoClient.NewRef(protectedRef(t, "_adminsdk/go/admin")) - if err != nil { - t.Fatal(err) - } + r := aoClient.NewRef(protectedRef(t, "_adminsdk/go/admin")) var got string if err := r.Get(&got); err == nil || got != "" { t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) @@ -597,10 +564,7 @@ func TestNoAccess(t *testing.T) { } func TestReadAccess(t *testing.T) { - r, err := aoClient.NewRef(protectedRef(t, "_adminsdk/go/protected/user2")) - if err != nil { - t.Fatal(err) - } + r := aoClient.NewRef(protectedRef(t, "_adminsdk/go/protected/user2")) var got string if err := r.Get(&got); err != nil || got != "test" { t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") @@ -613,10 +577,7 @@ func TestReadAccess(t *testing.T) { } func TestReadWriteAccess(t *testing.T) { - r, err := aoClient.NewRef(protectedRef(t, "_adminsdk/go/protected/user1")) - if err != nil { - t.Fatal(err) - } + r := aoClient.NewRef(protectedRef(t, "_adminsdk/go/protected/user1")) var got string if err := r.Get(&got); err != nil || got != "test" { t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") @@ -627,15 +588,8 @@ func TestReadWriteAccess(t *testing.T) { } func TestQueryAccess(t *testing.T) { - r, err := aoClient.NewRef("_adminsdk/go/protected") - if err != nil { - t.Fatal(err) - } - - q, err := r.OrderByKey(db.WithLimitToFirst(2)) - if err != nil { - t.Fatal(err) - } + r := aoClient.NewRef("_adminsdk/go/protected") + q := r.OrderByKey(db.WithLimitToFirst(2)) got := make(map[string]interface{}) if err := q.Get(&got); err == nil { t.Errorf("OrderByQuery() = nil; want = error") @@ -645,10 +599,7 @@ func TestQueryAccess(t *testing.T) { } func TestGuestAccess(t *testing.T) { - r, err := guestClient.NewRef(protectedRef(t, "_adminsdk/go/public")) - if err != nil { - t.Fatal(err) - } + r := guestClient.NewRef(protectedRef(t, "_adminsdk/go/public")) var got string if err := r.Get(&got); err != nil || got != "test" { t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") @@ -660,30 +611,21 @@ func TestGuestAccess(t *testing.T) { } got = "" - r, err = guestClient.NewRef("_adminsdk/go") - if err != nil { - t.Fatal(err) - } + r = guestClient.NewRef("_adminsdk/go") if err := r.Get(&got); err == nil || got != "" { t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) } else if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) } - c, err := r.Child("protected/user2") - if err != nil { - t.Fatal(err) - } + c := r.Child("protected/user2") if err := c.Get(&got); err == nil || got != "" { t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) } else if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) } - c, err = r.Child("admin") - if err != nil { - t.Fatal(err) - } + c = r.Child("admin") if err := c.Get(&got); err == nil || got != "" { t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) } else if err.Error() != permDenied { @@ -692,10 +634,7 @@ func TestGuestAccess(t *testing.T) { } func protectedRef(t *testing.T, p string) string { - r, err := client.NewRef(p) - if err != nil { - t.Fatal(err) - } + r := client.NewRef(p) if err := r.Set("test"); err != nil { t.Fatal(err) } diff --git a/integration/db/query_test.go b/integration/db/query_test.go index 4e08cdf2..6f02d61e 100644 --- a/integration/db/query_test.go +++ b/integration/db/query_test.go @@ -10,10 +10,7 @@ var heightSorted = []string{ func TestLimitToFirst(t *testing.T) { for _, tc := range []int{2, 10} { - q, err := dinos.OrderByChild("height", db.WithLimitToFirst(tc)) - if err != nil { - t.Fatal(err) - } + q := dinos.OrderByChild("height", db.WithLimitToFirst(tc)) var m map[string]Dinosaur if err := q.Get(&m); err != nil { t.Fatal(err) @@ -37,10 +34,7 @@ func TestLimitToFirst(t *testing.T) { func TestLimitToLast(t *testing.T) { for _, tc := range []int{2, 10} { - q, err := dinos.OrderByChild("height", db.WithLimitToLast(tc)) - if err != nil { - t.Fatal(err) - } + q := dinos.OrderByChild("height", db.WithLimitToLast(tc)) var m map[string]Dinosaur if err := q.Get(&m); err != nil { t.Fatal(err) @@ -63,10 +57,7 @@ func TestLimitToLast(t *testing.T) { } func TestStartAt(t *testing.T) { - q, err := dinos.OrderByChild("height", db.WithStartAt(3.5)) - if err != nil { - t.Fatal(err) - } + q := dinos.OrderByChild("height", db.WithStartAt(3.5)) var m map[string]Dinosaur if err := q.Get(&m); err != nil { t.Fatal(err) @@ -84,10 +75,7 @@ func TestStartAt(t *testing.T) { } func TestEndAt(t *testing.T) { - q, err := dinos.OrderByChild("height", db.WithEndAt(3.5)) - if err != nil { - t.Fatal(err) - } + q := dinos.OrderByChild("height", db.WithEndAt(3.5)) var m map[string]Dinosaur if err := q.Get(&m); err != nil { t.Fatal(err) @@ -105,10 +93,7 @@ func TestEndAt(t *testing.T) { } func TestStartAndEndAt(t *testing.T) { - q, err := dinos.OrderByChild("height", db.WithStartAt(2.5), db.WithEndAt(5)) - if err != nil { - t.Fatal(err) - } + q := dinos.OrderByChild("height", db.WithStartAt(2.5), db.WithEndAt(5)) var m map[string]Dinosaur if err := q.Get(&m); err != nil { t.Fatal(err) @@ -126,10 +111,7 @@ func TestStartAndEndAt(t *testing.T) { } func TestEqualTo(t *testing.T) { - q, err := dinos.OrderByChild("height", db.WithEqualTo(0.6)) - if err != nil { - t.Fatal(err) - } + q := dinos.OrderByChild("height", db.WithEqualTo(0.6)) var m map[string]Dinosaur if err := q.Get(&m); err != nil { t.Fatal(err) @@ -147,10 +129,7 @@ func TestEqualTo(t *testing.T) { } func TestOrderByNestedChild(t *testing.T) { - q, err := dinos.OrderByChild("ratings/pos", db.WithStartAt(4)) - if err != nil { - t.Fatal(err) - } + q := dinos.OrderByChild("ratings/pos", db.WithStartAt(4)) var m map[string]Dinosaur if err := q.Get(&m); err != nil { t.Fatal(err) @@ -168,10 +147,7 @@ func TestOrderByNestedChild(t *testing.T) { } func TestOrderByKey(t *testing.T) { - q, err := dinos.OrderByKey(db.WithLimitToFirst(2)) - if err != nil { - t.Fatal(err) - } + q := dinos.OrderByKey(db.WithLimitToFirst(2)) var m map[string]Dinosaur if err := q.Get(&m); err != nil { t.Fatal(err) @@ -189,14 +165,8 @@ func TestOrderByKey(t *testing.T) { } func TestOrderByValue(t *testing.T) { - scores, err := ref.Child("scores") - if err != nil { - t.Fatal(err) - } - q, err := scores.OrderByValue(db.WithLimitToLast(2)) - if err != nil { - t.Fatal(err) - } + scores := ref.Child("scores") + q := scores.OrderByValue(db.WithLimitToLast(2)) var m map[string]int if err := q.Get(&m); err != nil { t.Fatal(err) From f82a3e063675e4bf4ff96f424103893aeaad47f5 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Thu, 26 Oct 2017 23:24:14 -0700 Subject: [PATCH 17/58] More code cleanup --- db/http_client.go | 2 +- db/query.go | 15 ++++----------- db/ref.go | 10 ++++++++++ 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/db/http_client.go b/db/http_client.go index 88fc2132..33dd38a9 100644 --- a/db/http_client.go +++ b/db/http_client.go @@ -15,7 +15,6 @@ func (c *Client) send(method, path string, body interface{}, options ...httpOpti if c.ao != "" { opts = append(opts, withQueryParam("auth_variable_override", c.ao)) } - opts = append(opts, options...) var data io.Reader if body != nil { @@ -33,6 +32,7 @@ func (c *Client) send(method, path string, body interface{}, options ...httpOpti return nil, err } + opts = append(opts, options...) for _, o := range opts { req = o(req) } diff --git a/db/query.go b/db/query.go index 5c70315f..d9a05dc0 100644 --- a/db/query.go +++ b/db/query.go @@ -9,12 +9,6 @@ import ( "strings" ) -var reservedFilters = map[string]bool{ - "$key": true, - "$value": true, - "$priority": true, -} - type Query interface { Get(v interface{}) error WithContext(ctx context.Context) Query @@ -46,11 +40,10 @@ func (q *queryImpl) Get(v interface{}) error { } func (q *queryImpl) WithContext(ctx context.Context) Query { - return &queryImpl{ - ctx: ctx, - ref: q.ref, - opts: q.opts, - } + q2 := new(queryImpl) + *q2 = *q + q2.ctx = ctx + return q2 } type QueryOption interface { diff --git a/db/ref.go b/db/ref.go index b00d8729..f0a8f460 100644 --- a/db/ref.go +++ b/db/ref.go @@ -52,6 +52,13 @@ func (r *Ref) Get(v interface{}) error { return resp.CheckAndParse(http.StatusOK, v) } +func (r *Ref) WithContext(ctx context.Context) Query { + r2 := new(Ref) + *r2 = *r + r2.ctx = ctx + return r2 +} + func (r *Ref) GetWithETag(v interface{}) (string, error) { resp, err := r.send("GET", nil, withHeader("X-Firebase-ETag", "true")) if err != nil { @@ -159,5 +166,8 @@ func (r *Ref) send(method string, body interface{}, opts ...httpOption) (*respon if strings.ContainsAny(r.Path, invalidChars) { return nil, fmt.Errorf("invalid path with illegal characters: %q", r.Path) } + if r.ctx != nil { + opts = append([]httpOption{withContext(r.ctx)}, opts...) + } return r.client.send(method, r.Path, body, opts...) } From 6ecb8de5c7b7794c60f19bee1e321ea41d64b753 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Fri, 27 Oct 2017 01:38:28 -0700 Subject: [PATCH 18/58] Code clean up --- db/db.go | 21 +++------------------ db/db_test.go | 20 ++++++++++---------- db/query.go | 13 +++---------- firebase.go | 2 +- internal/internal.go | 2 +- 5 files changed, 18 insertions(+), 40 deletions(-) diff --git a/db/db.go b/db/db.go index 0a8ec86d..35b6395a 100644 --- a/db/db.go +++ b/db/db.go @@ -50,16 +50,13 @@ func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) return nil, err } - if c.BaseURL == "" { - return nil, fmt.Errorf("database url not specified") - } - p, err := url.Parse(c.BaseURL) + p, err := url.Parse(c.URL) if err != nil { return nil, err } else if p.Scheme != "https" { - return nil, fmt.Errorf("invalid database URL (incorrect scheme): %q", c.BaseURL) + return nil, fmt.Errorf("invalid database URL (incorrect scheme): %q", c.URL) } else if !strings.HasSuffix(p.Host, ".firebaseio.com") { - return nil, fmt.Errorf("invalid database URL (incorrest host): %q", c.BaseURL) + return nil, fmt.Errorf("invalid database URL (incorrest host): %q", c.URL) } var ao []byte @@ -77,18 +74,6 @@ func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) }, nil } -func newHTTPOptions(m map[string]interface{}) ([]httpOption, error) { - var opts []httpOption - if m == nil || len(m) > 0 { - ao, err := json.Marshal(m) - if err != nil { - return nil, err - } - opts = append(opts, withQueryParam("auth_variable_override", string(ao))) - } - return opts, nil -} - type AuthOverrides struct { Map map[string]interface{} } diff --git a/db/db_test.go b/db/db_test.go index 934ee8f4..aea74816 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -40,7 +40,7 @@ func TestMain(m *testing.M) { var err error client, err = NewClient(context.Background(), &internal.DatabaseConfig{ Opts: testOpts, - BaseURL: testURL, + URL: testURL, Version: "1.2.3", AO: map[string]interface{}{}, }) @@ -51,7 +51,7 @@ func TestMain(m *testing.M) { ao := map[string]interface{}{"uid": "user1"} aoClient, err = NewClient(context.Background(), &internal.DatabaseConfig{ Opts: testOpts, - BaseURL: testURL, + URL: testURL, Version: "1.2.3", AO: ao, }) @@ -72,9 +72,9 @@ func TestMain(m *testing.M) { func TestNewClient(t *testing.T) { c, err := NewClient(context.Background(), &internal.DatabaseConfig{ - Opts: testOpts, - BaseURL: testURL, - AO: make(map[string]interface{}), + Opts: testOpts, + URL: testURL, + AO: make(map[string]interface{}), }) if err != nil { t.Fatal(err) @@ -97,9 +97,9 @@ func TestNewClientAuthOverrides(t *testing.T) { } for _, tc := range cases { c, err := NewClient(context.Background(), &internal.DatabaseConfig{ - Opts: testOpts, - BaseURL: testURL, - AO: tc, + Opts: testOpts, + URL: testURL, + AO: tc, }) if err != nil { t.Fatal(err) @@ -129,8 +129,8 @@ func TestNewClientError(t *testing.T) { } for _, tc := range cases { c, err := NewClient(context.Background(), &internal.DatabaseConfig{ - Opts: testOpts, - BaseURL: tc, + Opts: testOpts, + URL: tc, }) if c != nil || err == nil { t.Errorf("NewClient() = (%v, %v); want = (nil, error)", c, err) diff --git a/db/query.go b/db/query.go index d9a05dc0..d7338aa7 100644 --- a/db/query.go +++ b/db/query.go @@ -71,22 +71,15 @@ func WithEqualTo(v interface{}) QueryOption { } func (r *Ref) OrderByChild(child string, opts ...QueryOption) Query { - opts = append(opts, orderByChild(child)) - return newQuery(r, opts) + return &queryImpl{ref: r, opts: append(opts, orderByChild(child))} } func (r *Ref) OrderByKey(opts ...QueryOption) Query { - opts = append(opts, orderByParam("$key")) - return newQuery(r, opts) + return &queryImpl{ref: r, opts: append(opts, orderByParam("$key"))} } func (r *Ref) OrderByValue(opts ...QueryOption) Query { - opts = append(opts, orderByParam("$value")) - return newQuery(r, opts) -} - -func newQuery(r *Ref, opts []QueryOption) Query { - return &queryImpl{ref: r, opts: opts} + return &queryImpl{ref: r, opts: append(opts, orderByParam("$value"))} } type queryParams map[string]string diff --git a/firebase.go b/firebase.go index 8e0c6c84..b32a69a5 100644 --- a/firebase.go +++ b/firebase.go @@ -66,7 +66,7 @@ func (a *App) Auth(ctx context.Context) (*auth.Client, error) { func (a *App) Database(ctx context.Context) (*db.Client, error) { conf := &internal.DatabaseConfig{ AO: a.ao, - BaseURL: a.dbURL, + URL: a.dbURL, Opts: a.opts, Version: Version, } diff --git a/internal/internal.go b/internal/internal.go index 262f356b..5370c9e5 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -36,7 +36,7 @@ type AuthConfig struct { type DatabaseConfig struct { Opts []option.ClientOption - BaseURL string + URL string Version string AO map[string]interface{} } From deb4eaca9bc89bc6ee48d7bea7fe64eb059b1eb8 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 27 Oct 2017 12:13:10 -0700 Subject: [PATCH 19/58] Refactored the http client code --- db/db.go | 1 - db/http_client.go | 62 +++++++++++++++++++++++++++++------------------ db/query.go | 35 +++++++++++++++++--------- db/ref.go | 36 +++++++++++++++------------ 4 files changed, 81 insertions(+), 53 deletions(-) diff --git a/db/db.go b/db/db.go index 35b6395a..0372d1a3 100644 --- a/db/db.go +++ b/db/db.go @@ -31,7 +31,6 @@ import ( "google.golang.org/api/transport" ) -const invalidChars = "[].#$" const userAgentFormat = "Firebase/HTTP/%s/%s/AdminGo" // Client is the interface for the Firebase Realtime Database service. diff --git a/db/http_client.go b/db/http_client.go index 33dd38a9..c20be5a5 100644 --- a/db/http_client.go +++ b/db/http_client.go @@ -8,17 +8,28 @@ import ( "io" "io/ioutil" "net/http" + "strings" ) -func (c *Client) send(method, path string, body interface{}, options ...httpOption) (*response, error) { - var opts []httpOption - if c.ao != "" { - opts = append(opts, withQueryParam("auth_variable_override", c.ao)) +const invalidChars = "[].#$" +const authVarOverride = "auth_variable_override" + +type request struct { + Method string + Path string + Body interface{} + Opts []httpOption +} + +func (c *Client) send(ctx context.Context, r *request) (*response, error) { + if strings.ContainsAny(r.Path, invalidChars) { + return nil, fmt.Errorf("invalid path with illegal characters: %q", r.Path) } + var opts []httpOption var data io.Reader - if body != nil { - b, err := json.Marshal(body) + if r.Body != nil { + b, err := json.Marshal(r.Body) if err != nil { return nil, err } @@ -26,18 +37,30 @@ func (c *Client) send(method, path string, body interface{}, options ...httpOpti opts = append(opts, withHeader("Content-Type", "application/json")) } - url := fmt.Sprintf("%s%s%s", c.url, path, ".json") - req, err := http.NewRequest(method, url, data) + url := fmt.Sprintf("%s%s.json", c.url, r.Path) + req, err := http.NewRequest(r.Method, url, data) if err != nil { return nil, err } - opts = append(opts, options...) + if ctx != nil { + req = req.WithContext(ctx) + } + + if c.ao != "" { + opts = append(opts, withQueryParam(authVarOverride, c.ao)) + } + opts = append(opts, r.Opts...) + + return doSend(c.hc, req, opts...) +} + +func doSend(hc *http.Client, req *http.Request, opts ...httpOption) (*response, error) { for _, o := range opts { - req = o(req) + o(req) } - resp, err := c.hc.Do(req) + resp, err := hc.Do(req) if err != nil { return nil, err } @@ -86,37 +109,28 @@ func (r *response) CheckAndParse(want int, v interface{}) error { return nil } -type httpOption func(*http.Request) *http.Request +type httpOption func(*http.Request) func withHeader(key, value string) httpOption { - return func(r *http.Request) *http.Request { + return func(r *http.Request) { r.Header.Set(key, value) - return r } } func withQueryParam(key, value string) httpOption { - return func(r *http.Request) *http.Request { + return func(r *http.Request) { q := r.URL.Query() q.Add(key, value) r.URL.RawQuery = q.Encode() - return r } } func withQueryParams(qp queryParams) httpOption { - return func(r *http.Request) *http.Request { + return func(r *http.Request) { q := r.URL.Query() for k, v := range qp { q.Add(k, v) } r.URL.RawQuery = q.Encode() - return r - } -} - -func withContext(ctx context.Context) httpOption { - return func(r *http.Request) *http.Request { - return r.WithContext(ctx) } } diff --git a/db/query.go b/db/query.go index d7338aa7..99a05aef 100644 --- a/db/query.go +++ b/db/query.go @@ -15,24 +15,26 @@ type Query interface { } type queryImpl struct { - ctx context.Context - ref *Ref - opts []QueryOption + Ctx context.Context + Client *Client + Path string + Opts []QueryOption } func (q *queryImpl) Get(v interface{}) error { qp := make(queryParams) - for _, o := range q.opts { + for _, o := range q.Opts { if err := o.apply(qp); err != nil { return err } } - opts := []httpOption{withQueryParams(qp)} - if q.ctx != nil { - opts = append(opts, withContext(q.ctx)) + req := &request{ + Method: "GET", + Path: q.Path, + Opts: []httpOption{withQueryParams(qp)}, } - resp, err := q.ref.send("GET", nil, opts...) + resp, err := q.Client.send(q.Ctx, req) if err != nil { return err } @@ -42,7 +44,7 @@ func (q *queryImpl) Get(v interface{}) error { func (q *queryImpl) WithContext(ctx context.Context) Query { q2 := new(queryImpl) *q2 = *q - q2.ctx = ctx + q2.Ctx = ctx return q2 } @@ -71,15 +73,24 @@ func WithEqualTo(v interface{}) QueryOption { } func (r *Ref) OrderByChild(child string, opts ...QueryOption) Query { - return &queryImpl{ref: r, opts: append(opts, orderByChild(child))} + return newQuery(r, orderByChild(child), opts...) } func (r *Ref) OrderByKey(opts ...QueryOption) Query { - return &queryImpl{ref: r, opts: append(opts, orderByParam("$key"))} + return newQuery(r, orderByParam("$key"), opts...) } func (r *Ref) OrderByValue(opts ...QueryOption) Query { - return &queryImpl{ref: r, opts: append(opts, orderByParam("$value"))} + return newQuery(r, orderByParam("$value"), opts...) +} + +func newQuery(r *Ref, orderBy QueryOption, opts ...QueryOption) Query { + return &queryImpl{ + Ctx: r.ctx, + Client: r.client, + Path: r.Path, + Opts: append(opts, orderBy), + } } type queryParams map[string]string diff --git a/db/ref.go b/db/ref.go index f0a8f460..1838aabf 100644 --- a/db/ref.go +++ b/db/ref.go @@ -45,7 +45,7 @@ func (r *Ref) Child(path string) *Ref { } func (r *Ref) Get(v interface{}) error { - resp, err := r.send("GET", nil) + resp, err := r.send("GET") if err != nil { return err } @@ -60,7 +60,7 @@ func (r *Ref) WithContext(ctx context.Context) Query { } func (r *Ref) GetWithETag(v interface{}) (string, error) { - resp, err := r.send("GET", nil, withHeader("X-Firebase-ETag", "true")) + resp, err := r.send("GET", withHeader("X-Firebase-ETag", "true")) if err != nil { return "", err } else if err := resp.CheckAndParse(http.StatusOK, v); err != nil { @@ -70,7 +70,7 @@ func (r *Ref) GetWithETag(v interface{}) (string, error) { } func (r *Ref) GetIfChanged(etag string, v interface{}) (bool, string, error) { - resp, err := r.send("GET", nil, withHeader("If-None-Match", etag)) + resp, err := r.send("GET", withHeader("If-None-Match", etag)) if err != nil { return false, "", err } else if err := resp.CheckAndParse(http.StatusOK, v); err == nil { @@ -82,7 +82,7 @@ func (r *Ref) GetIfChanged(etag string, v interface{}) (bool, string, error) { } func (r *Ref) Set(v interface{}) error { - resp, err := r.send("PUT", v, withQueryParam("print", "silent")) + resp, err := r.sendWithBody("PUT", v, withQueryParam("print", "silent")) if err != nil { return err } @@ -90,7 +90,7 @@ func (r *Ref) Set(v interface{}) error { } func (r *Ref) SetIfUnchanged(etag string, v interface{}) (bool, error) { - resp, err := r.send("PUT", v, withHeader("If-Match", etag)) + resp, err := r.sendWithBody("PUT", v, withHeader("If-Match", etag)) if err != nil { return false, err } else if err := resp.CheckStatus(http.StatusOK); err == nil { @@ -105,7 +105,7 @@ func (r *Ref) Push(v interface{}) (*Ref, error) { if v == nil { v = "" } - resp, err := r.send("POST", v) + resp, err := r.sendWithBody("POST", v) if err != nil { return nil, err } @@ -122,7 +122,7 @@ func (r *Ref) Update(v map[string]interface{}) error { if len(v) == 0 { return fmt.Errorf("value argument must be a non-empty map") } - resp, err := r.send("PATCH", v, withQueryParam("print", "silent")) + resp, err := r.sendWithBody("PATCH", v, withQueryParam("print", "silent")) if err != nil { return err } @@ -143,7 +143,7 @@ func (r *Ref) Transaction(fn UpdateFn) error { if err != nil { return err } - resp, err := r.send("PUT", new, withHeader("If-Match", etag)) + resp, err := r.sendWithBody("PUT", new, withHeader("If-Match", etag)) if err := resp.CheckStatus(http.StatusOK); err == nil { return nil } else if err := resp.CheckAndParse(http.StatusPreconditionFailed, &curr); err != nil { @@ -155,19 +155,23 @@ func (r *Ref) Transaction(fn UpdateFn) error { } func (r *Ref) Delete() error { - resp, err := r.send("DELETE", nil) + resp, err := r.send("DELETE") if err != nil { return err } return resp.CheckStatus(http.StatusOK) } -func (r *Ref) send(method string, body interface{}, opts ...httpOption) (*response, error) { - if strings.ContainsAny(r.Path, invalidChars) { - return nil, fmt.Errorf("invalid path with illegal characters: %q", r.Path) - } - if r.ctx != nil { - opts = append([]httpOption{withContext(r.ctx)}, opts...) +func (r *Ref) send(method string, opts ...httpOption) (*response, error) { + return r.sendWithBody(method, nil, opts...) +} + +func (r *Ref) sendWithBody(method string, body interface{}, opts ...httpOption) (*response, error) { + req := &request{ + Method: method, + Body: body, + Path: r.Path, + Opts: opts, } - return r.client.send(method, r.Path, body, opts...) + return r.client.send(r.ctx, req) } From caeaafe1eb914dd7250bd04e6677a915425c5d7b Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 27 Oct 2017 14:14:41 -0700 Subject: [PATCH 20/58] More tests --- db/db_test.go | 30 +++++++- db/http_client.go | 3 +- db/query.go | 28 +++++--- db/query_test.go | 135 +++++++++++++++++++++++++++++++---- db/ref.go | 5 +- db/ref_test.go | 75 +++++++++++++------ integration/db/db_test.go | 28 ++++++-- integration/db/query_test.go | 33 ++++++++- 8 files changed, 280 insertions(+), 57 deletions(-) diff --git a/db/db_test.go b/db/db_test.go index aea74816..e0cae556 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -34,7 +34,7 @@ var testOpts = []option.ClientOption{ var client *Client var aoClient *Client -var ref *Ref +var testref *Ref func TestMain(m *testing.M) { var err error @@ -65,7 +65,7 @@ func TestMain(m *testing.M) { } testAuthOverrides = string(b) - ref = client.NewRef("peter") + testref = client.NewRef("peter") testUserAgent = fmt.Sprintf(userAgentFormat, "1.2.3", runtime.Version()) os.Exit(m.Run()) } @@ -155,7 +155,10 @@ func TestNewRef(t *testing.T) { for _, tc := range cases { r := client.NewRef(tc.Path) if r.client == nil { - t.Errorf("Client = nil; want = %v", client) + t.Errorf("Client = nil; want = %v", r.client) + } + if r.ctx != nil { + t.Errorf("Ctx = %v; want nil", r.ctx) } if r.Path != tc.WantPath { t.Errorf("Path = %q; want = %q", r.Path, tc.WantPath) @@ -251,6 +254,27 @@ func TestInvalidPath(t *testing.T) { } } +func TestInvalidChildPath(t *testing.T) { + mock := &mockServer{Resp: "test"} + srv := mock.Start(client) + defer srv.Close() + + cases := []string{ + "foo$", "foo.", "foo#", "foo]", "foo[", + } + r := client.NewRef("test") + for _, tc := range cases { + var got string + if err := r.Child(tc).Get(&got); got != "" || err == nil { + t.Errorf("Get() = (%q, %v); want = (%q, error)", got, err, "") + } + } + + if len(mock.Reqs) != 0 { + t.Errorf("Requests: %v; want: empty", mock.Reqs) + } +} + func checkOnlyRequest(t *testing.T, got []*testReq, want *testReq) { checkAllRequests(t, got, []*testReq{want}) } diff --git a/db/http_client.go b/db/http_client.go index c20be5a5..a64ee997 100644 --- a/db/http_client.go +++ b/db/http_client.go @@ -2,13 +2,14 @@ package db import ( "bytes" - "context" "encoding/json" "fmt" "io" "io/ioutil" "net/http" "strings" + + "golang.org/x/net/context" ) const invalidChars = "[].#$" diff --git a/db/query.go b/db/query.go index 99a05aef..34359eda 100644 --- a/db/query.go +++ b/db/query.go @@ -1,12 +1,13 @@ package db import ( - "context" "encoding/json" "fmt" "net/http" "strconv" "strings" + + "golang.org/x/net/context" ) type Query interface { @@ -18,11 +19,15 @@ type queryImpl struct { Ctx context.Context Client *Client Path string + OB orderBy Opts []QueryOption } func (q *queryImpl) Get(v interface{}) error { qp := make(queryParams) + if err := q.OB.apply(qp); err != nil { + return err + } for _, o := range q.Opts { if err := o.apply(qp); err != nil { return err @@ -48,6 +53,12 @@ func (q *queryImpl) WithContext(ctx context.Context) Query { return q2 } +type queryParams map[string]string + +type orderBy interface { + apply(qp queryParams) error +} + type QueryOption interface { apply(qp queryParams) error } @@ -77,24 +88,23 @@ func (r *Ref) OrderByChild(child string, opts ...QueryOption) Query { } func (r *Ref) OrderByKey(opts ...QueryOption) Query { - return newQuery(r, orderByParam("$key"), opts...) + return newQuery(r, orderByProperty("$key"), opts...) } func (r *Ref) OrderByValue(opts ...QueryOption) Query { - return newQuery(r, orderByParam("$value"), opts...) + return newQuery(r, orderByProperty("$value"), opts...) } -func newQuery(r *Ref, orderBy QueryOption, opts ...QueryOption) Query { +func newQuery(r *Ref, ob orderBy, opts ...QueryOption) Query { return &queryImpl{ Ctx: r.ctx, Client: r.client, Path: r.Path, - Opts: append(opts, orderBy), + OB: ob, + Opts: opts, } } -type queryParams map[string]string - type orderByChild string func (p orderByChild) apply(qp queryParams) error { @@ -112,9 +122,9 @@ func (p orderByChild) apply(qp queryParams) error { return nil } -type orderByParam string +type orderByProperty string -func (p orderByParam) apply(qp queryParams) error { +func (p orderByProperty) apply(qp queryParams) error { b, err := json.Marshal(p) if err != nil { return err diff --git a/db/query_test.go b/db/query_test.go index 3442b27d..eb84cb67 100644 --- a/db/query_test.go +++ b/db/query_test.go @@ -1,10 +1,119 @@ package db import ( + "context" "reflect" "testing" ) +func TestQueryWithContext(t *testing.T) { + q := client.NewRef("peter").OrderByChild("messages") + if q.(*queryImpl).Ctx != nil { + t.Errorf("Ctx = %v; want nil", q.(*queryImpl).Ctx) + } + + ctx, cancel := context.WithCancel(context.Background()) + q = q.WithContext(ctx) + if q.(*queryImpl).Ctx != ctx { + t.Errorf("Ctx = %v; want %v", q.(*queryImpl).Ctx, ctx) + } + + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := q.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"messages\""}, + }) + + cancel() + got = nil + if err := q.Get(&got); len(got) != 0 || err == nil { + t.Errorf("Get() = (%v, %v); want = (empty, error)", got, err) + } +} + +func TestQueryFromRefWithContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + q := client.NewRef("peter").WithContext(ctx).OrderByChild("messages") + if q.(*queryImpl).Ctx != ctx { + t.Errorf("Ctx = %v; want %v", q.(*queryImpl).Ctx, ctx) + } + + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := q.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"messages\""}, + }) + + cancel() + got = nil + if err := q.Get(&got); len(got) != 0 || err == nil { + t.Errorf("Get() = (%v, %v); want = (empty, error)", got, err) + } +} + +func TestQueryWithContextPrecedence(t *testing.T) { + ctx1 := context.Background() + ctx2, cancel := context.WithCancel(ctx1) + + r := client.NewRef("peter").WithContext(ctx1) + q := r.OrderByChild("messages").WithContext(ctx2) + if q.(*queryImpl).Ctx != ctx2 { + t.Errorf("Ctx = %v; want %v", q.(*queryImpl).Ctx, ctx2) + } + + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := q.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"messages\""}, + }) + + cancel() + got = nil + if err := q.Get(&got); len(got) != 0 || err == nil { + t.Errorf("Get() = (%v, %v); want = (empty, error)", got, err) + } + + if err := r.Get(&got); !reflect.DeepEqual(got, want) || err != nil { + t.Errorf("Get() = (%v, %v); want = (%v, nil)", got, err, want) + } +} + func TestChildQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} @@ -12,7 +121,7 @@ func TestChildQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := ref.OrderByChild("messages").Get(&got); err != nil { + if err := testref.OrderByChild("messages").Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -32,7 +141,7 @@ func TestNestedChildQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := ref.OrderByChild("messages/ratings").Get(&got); err != nil { + if err := testref.OrderByChild("messages/ratings").Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -57,7 +166,7 @@ func TestChildQueryWithParams(t *testing.T) { WithLimitToFirst(10), } var got map[string]interface{} - if err := ref.OrderByChild("messages", opts...).Get(&got); err != nil { + if err := testref.OrderByChild("messages", opts...).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -75,7 +184,7 @@ func TestChildQueryWithParams(t *testing.T) { }) } -func TestInvalidChildPath(t *testing.T) { +func TestInvalidOrderByChild(t *testing.T) { mock := &mockServer{Resp: "test"} srv := mock.Start(client) defer srv.Close() @@ -102,7 +211,7 @@ func TestKeyQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := ref.OrderByKey().Get(&got); err != nil { + if err := testref.OrderByKey().Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -122,7 +231,7 @@ func TestValueQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := ref.OrderByValue().Get(&got); err != nil { + if err := testref.OrderByValue().Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -142,7 +251,7 @@ func TestLimitFirstQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := ref.OrderByChild("messages", WithLimitToFirst(10)).Get(&got); err != nil { + if err := testref.OrderByChild("messages", WithLimitToFirst(10)).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -162,7 +271,7 @@ func TestLimitLastQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := ref.OrderByChild("messages", WithLimitToLast(10)).Get(&got); err != nil { + if err := testref.OrderByChild("messages", WithLimitToLast(10)).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -182,7 +291,7 @@ func TestInvalidLimitQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - q := ref.OrderByChild("messages", WithLimitToFirst(10), WithLimitToLast(10)) + q := testref.OrderByChild("messages", WithLimitToFirst(10), WithLimitToLast(10)) if err := q.Get(&got); got != nil || err == nil { t.Errorf("Get() = (%v, %v); want = (nil, error)", got, err) } @@ -198,7 +307,7 @@ func TestStartAtQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := ref.OrderByChild("messages", WithStartAt(10)).Get(&got); err != nil { + if err := testref.OrderByChild("messages", WithStartAt(10)).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -218,7 +327,7 @@ func TestEndAtQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := ref.OrderByChild("messages", WithEndAt(10)).Get(&got); err != nil { + if err := testref.OrderByChild("messages", WithEndAt(10)).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -237,7 +346,7 @@ func TestAllParamsQuery(t *testing.T) { srv := mock.Start(client) defer srv.Close() - q := ref.OrderByChild("messages", WithLimitToFirst(100), WithStartAt("bar"), WithEndAt("foo")) + q := testref.OrderByChild("messages", WithLimitToFirst(100), WithStartAt("bar"), WithEndAt("foo")) var got map[string]interface{} if err := q.Get(&got); err != nil { t.Fatal(err) @@ -264,7 +373,7 @@ func TestEqualToQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := ref.OrderByChild("messages", WithEqualTo(10)).Get(&got); err != nil { + if err := testref.OrderByChild("messages", WithEqualTo(10)).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { diff --git a/db/ref.go b/db/ref.go index 1838aabf..4b83eebf 100644 --- a/db/ref.go +++ b/db/ref.go @@ -15,10 +15,11 @@ package db import ( - "context" "fmt" "net/http" "strings" + + "golang.org/x/net/context" ) type Ref struct { @@ -52,7 +53,7 @@ func (r *Ref) Get(v interface{}) error { return resp.CheckAndParse(http.StatusOK, v) } -func (r *Ref) WithContext(ctx context.Context) Query { +func (r *Ref) WithContext(ctx context.Context) *Ref { r2 := new(Ref) *r2 = *r r2.ctx = ctx diff --git a/db/ref_test.go b/db/ref_test.go index fa1a6da6..4b3f61bc 100644 --- a/db/ref_test.go +++ b/db/ref_test.go @@ -4,8 +4,43 @@ import ( "net/http" "reflect" "testing" + + "golang.org/x/net/context" ) +func TestRefWithContext(t *testing.T) { + r := client.NewRef("peter") + if r.ctx != nil { + t.Errorf("Ctx = %v; want nil", r.ctx) + } + + ctx, cancel := context.WithCancel(context.Background()) + r = r.WithContext(ctx) + if r.ctx != ctx { + t.Errorf("Ctx = %v; want %v", r.ctx, ctx) + } + + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := r.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) + + cancel() + got = nil + if err := r.Get(&got); len(got) != 0 || err == nil { + t.Errorf("Get() = (%v, %v); want = (empty, error)", got, err) + } +} + func TestGet(t *testing.T) { want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} mock := &mockServer{Resp: want} @@ -13,7 +48,7 @@ func TestGet(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := ref.Get(&got); err != nil { + if err := testref.Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -29,7 +64,7 @@ func TestGetWithStruct(t *testing.T) { defer srv.Close() var got person - if err := ref.Get(&got); err != nil { + if err := testref.Get(&got); err != nil { t.Fatal(err) } if want != got { @@ -48,7 +83,7 @@ func TestGetWithETag(t *testing.T) { defer srv.Close() var got map[string]interface{} - etag, err := ref.GetWithETag(&got) + etag, err := testref.GetWithETag(&got) if err != nil { t.Fatal(err) } @@ -75,7 +110,7 @@ func TestGetIfChanged(t *testing.T) { defer srv.Close() var got map[string]interface{} - ok, etag, err := ref.GetIfChanged("old-etag", &got) + ok, etag, err := testref.GetIfChanged("old-etag", &got) if err != nil { t.Fatal(err) } @@ -92,7 +127,7 @@ func TestGetIfChanged(t *testing.T) { mock.Status = http.StatusNotModified mock.Resp = nil var got2 map[string]interface{} - ok, etag, err = ref.GetIfChanged("new-etag", &got2) + ok, etag, err = testref.GetIfChanged("new-etag", &got2) if err != nil { t.Fatal(err) } @@ -126,7 +161,7 @@ func TestWerlformedHttpError(t *testing.T) { defer srv.Close() var got person - err := ref.Get(&got) + err := testref.Get(&got) want := "http error status: 500; reason: test error" if err == nil || err.Error() != want { t.Errorf("Get() = %v; want = %v", err, want) @@ -140,7 +175,7 @@ func TestUnexpectedHttpError(t *testing.T) { defer srv.Close() var got person - err := ref.Get(&got) + err := testref.Get(&got) want := "http error status: 500; message: \"unexpected error\"" if err == nil || err.Error() != want { t.Errorf("Get() = %v; want = %v", err, want) @@ -154,7 +189,7 @@ func TestSet(t *testing.T) { defer srv.Close() want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} - if err := ref.Set(want); err != nil { + if err := testref.Set(want); err != nil { t.Fatal(err) } checkOnlyRequest(t, mock.Reqs, &testReq{ @@ -171,7 +206,7 @@ func TestSetWithStruct(t *testing.T) { defer srv.Close() want := &person{"Peter Parker", 17} - if err := ref.Set(&want); err != nil { + if err := testref.Set(&want); err != nil { t.Fatal(err) } checkOnlyRequest(t, mock.Reqs, &testReq{ @@ -188,7 +223,7 @@ func TestSetIfUnchanged(t *testing.T) { defer srv.Close() want := &person{"Peter Parker", 17} - ok, err := ref.SetIfUnchanged("mock-etag", &want) + ok, err := testref.SetIfUnchanged("mock-etag", &want) if err != nil { t.Fatal(err) } @@ -212,7 +247,7 @@ func TestSetIfUnchangedError(t *testing.T) { defer srv.Close() want := &person{"Peter Parker", 17} - ok, err := ref.SetIfUnchanged("mock-etag", &want) + ok, err := testref.SetIfUnchanged("mock-etag", &want) if err != nil { t.Fatal(err) } @@ -232,7 +267,7 @@ func TestPush(t *testing.T) { srv := mock.Start(client) defer srv.Close() - child, err := ref.Push(nil) + child, err := testref.Push(nil) if err != nil { t.Fatal(err) } @@ -253,7 +288,7 @@ func TestPushWithValue(t *testing.T) { defer srv.Close() want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} - child, err := ref.Push(want) + child, err := testref.Push(want) if err != nil { t.Fatal(err) } @@ -274,7 +309,7 @@ func TestUpdate(t *testing.T) { srv := mock.Start(client) defer srv.Close() - if err := ref.Update(want); err != nil { + if err := testref.Update(want); err != nil { t.Fatal(err) } checkOnlyRequest(t, mock.Reqs, &testReq{ @@ -286,12 +321,12 @@ func TestUpdate(t *testing.T) { } func TestInvalidUpdate(t *testing.T) { - if err := ref.Update(nil); err == nil { + if err := testref.Update(nil); err == nil { t.Errorf("Update(nil) = nil; want error") } m := make(map[string]interface{}) - if err := ref.Update(m); err == nil { + if err := testref.Update(m); err == nil { t.Errorf("Update(map{}) = nil; want error") } } @@ -309,7 +344,7 @@ func TestTransaction(t *testing.T) { p["age"] = p["age"].(float64) + 1.0 return p, nil } - if err := ref.Transaction(fn); err != nil { + if err := testref.Transaction(fn); err != nil { t.Fatal(err) } checkAllRequests(t, mock.Reqs, []*testReq{ @@ -352,7 +387,7 @@ func TestTransactionRetry(t *testing.T) { p["age"] = p["age"].(float64) + 1.0 return p, nil } - if err := ref.Transaction(fn); err != nil { + if err := testref.Transaction(fn); err != nil { t.Fatal(err) } if cnt != 2 { @@ -404,7 +439,7 @@ func TestTransactionAbort(t *testing.T) { p["age"] = p["age"].(float64) + 1.0 return p, nil } - err := ref.Transaction(fn) + err := testref.Transaction(fn) if err == nil { t.Errorf("Transaction() = nil; want error") } @@ -434,7 +469,7 @@ func TestDelete(t *testing.T) { srv := mock.Start(client) defer srv.Close() - if err := ref.Delete(); err != nil { + if err := testref.Delete(); err != nil { t.Fatal(err) } checkOnlyRequest(t, mock.Reqs, &testReq{ diff --git a/integration/db/db_test.go b/integration/db/db_test.go index c73ec208..6b0445fd 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -63,7 +63,10 @@ func TestMain(m *testing.M) { log.Fatalln(err) } - initRefs() + ref = client.NewRef("_adminsdk/go/dinodb") + dinos = ref.Child("dinosaurs") + users = ref.Parent().Child("users") + initRules() initData() @@ -110,12 +113,6 @@ func initGuestClient(pid string) (*db.Client, error) { return app.Database(ctx) } -func initRefs() { - ref = client.NewRef("_adminsdk/go/dinodb") - dinos = ref.Child("dinosaurs") - users = ref.Parent().Child("users") -} - func initRules() { b, err := ioutil.ReadFile(internal.Resource("dinosaurs_index.json")) if err != nil { @@ -633,6 +630,23 @@ func TestGuestAccess(t *testing.T) { } } +func TestWithContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + var m map[string]interface{} + if err := ref.WithContext(ctx).Get(&m); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(testData, m) { + t.Errorf("Get() = %v; want = %v", m, testData) + } + + cancel() + m = nil + if err := ref.WithContext(ctx).Get(&m); len(m) != 0 || err == nil { + t.Errorf("Get() = (%v, %v); want = (empty, error)", m, err) + } +} + func protectedRef(t *testing.T, p string) string { r := client.NewRef(p) if err := r.Set("test"); err != nil { diff --git a/integration/db/query_test.go b/integration/db/query_test.go index 6f02d61e..003865db 100644 --- a/integration/db/query_test.go +++ b/integration/db/query_test.go @@ -1,7 +1,11 @@ package db -import "testing" -import "firebase.google.com/go/db" +import ( + "context" + "testing" + + "firebase.google.com/go/db" +) var heightSorted = []string{ "linhenykus", "pterodactyl", "lambeosaurus", @@ -182,3 +186,28 @@ func TestOrderByValue(t *testing.T) { } } } + +func TestQueryWithContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + q := dinos.OrderByKey(db.WithLimitToFirst(2)).WithContext(ctx) + var m map[string]Dinosaur + if err := q.Get(&m); err != nil { + t.Fatal(err) + } + + want := []string{"bruhathkayosaurus", "lambeosaurus"} + if len(m) != len(want) { + t.Errorf("OrderByKey() = %v; want = %v", m, want) + } + for _, d := range want { + if _, ok := m[d]; !ok { + t.Errorf("OrderByKey() = %v; want key %q", m, d) + } + } + + cancel() + m = nil + if err := q.Get(&m); len(m) != 0 || err == nil { + t.Errorf("Get() = (%v, %v); want = (empty, error)", m, err) + } +} From bf092a2fa63d5a2380a6e47804bfd721ba5b43af Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 27 Oct 2017 16:31:58 -0700 Subject: [PATCH 21/58] Boosted test coverage to 97% --- db/auth_override_test.go | 2 +- db/db.go | 2 +- db/db_test.go | 55 ++------- db/http_client.go | 2 +- db/query.go | 226 +++++++++++++++++----------------- db/query_test.go | 129 ++++++++++++-------- db/ref.go | 4 +- db/ref_test.go | 227 ++++++++++++++++++++++++++++++----- integration/db/db_test.go | 3 +- integration/db/query_test.go | 31 ++--- 10 files changed, 420 insertions(+), 261 deletions(-) diff --git a/db/auth_override_test.go b/db/auth_override_test.go index 9d295bc7..401f9c3f 100644 --- a/db/auth_override_test.go +++ b/db/auth_override_test.go @@ -72,7 +72,7 @@ func TestAuthOverrideRangeQuery(t *testing.T) { ref := aoClient.NewRef("peter") var got string - if err := ref.OrderByChild("foo", WithStartAt(1), WithEndAt(10)).Get(&got); err != nil { + if err := ref.OrderByChild("foo").WithStartAt(1).WithEndAt(10).Get(&got); err != nil { t.Fatal(err) } if got != "data" { diff --git a/db/db.go b/db/db.go index 0372d1a3..f3fc4081 100644 --- a/db/db.go +++ b/db/db.go @@ -49,7 +49,7 @@ func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) return nil, err } - p, err := url.Parse(c.URL) + p, err := url.ParseRequestURI(c.URL) if err != nil { return nil, err } else if p.Scheme != "https" { diff --git a/db/db_test.go b/db/db_test.go index e0cae556..c63f2544 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -120,7 +120,7 @@ func TestNewClientAuthOverrides(t *testing.T) { } } -func TestNewClientError(t *testing.T) { +func TestInvalidURL(t *testing.T) { cases := []string{ "", "foo", @@ -138,6 +138,17 @@ func TestNewClientError(t *testing.T) { } } +func TestInvalidAuthOverride(t *testing.T) { + c, err := NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + URL: testURL, + AO: map[string]interface{}{"uid": func() {}}, + }) + if c != nil || err == nil { + t.Errorf("NewClient() = (%v, %v); want = (nil, error)", c, err) + } +} + func TestNewRef(t *testing.T) { cases := []struct { Path string @@ -233,48 +244,6 @@ func TestChild(t *testing.T) { } } -func TestInvalidPath(t *testing.T) { - mock := &mockServer{Resp: "test"} - srv := mock.Start(client) - defer srv.Close() - - cases := []string{ - "foo$", "foo.", "foo#", "foo]", "foo[", - } - for _, tc := range cases { - r := client.NewRef(tc) - var got string - if err := r.Get(&got); got != "" || err == nil { - t.Errorf("Get() = (%q, %v); want = (%q, error)", got, err, "") - } - } - - if len(mock.Reqs) != 0 { - t.Errorf("Requests: %v; want: empty", mock.Reqs) - } -} - -func TestInvalidChildPath(t *testing.T) { - mock := &mockServer{Resp: "test"} - srv := mock.Start(client) - defer srv.Close() - - cases := []string{ - "foo$", "foo.", "foo#", "foo]", "foo[", - } - r := client.NewRef("test") - for _, tc := range cases { - var got string - if err := r.Child(tc).Get(&got); got != "" || err == nil { - t.Errorf("Get() = (%q, %v); want = (%q, error)", got, err, "") - } - } - - if len(mock.Reqs) != 0 { - t.Errorf("Requests: %v; want: empty", mock.Reqs) - } -} - func checkOnlyRequest(t *testing.T, got []*testReq, want *testReq) { checkAllRequests(t, got, []*testReq{want}) } diff --git a/db/http_client.go b/db/http_client.go index a64ee997..0a2b2051 100644 --- a/db/http_client.go +++ b/db/http_client.go @@ -126,7 +126,7 @@ func withQueryParam(key, value string) httpOption { } } -func withQueryParams(qp queryParams) httpOption { +func withQueryParams(qp map[string]string) httpOption { return func(r *http.Request) { q := r.URL.Query() for k, v := range qp { diff --git a/db/query.go b/db/query.go index 34359eda..16265b65 100644 --- a/db/query.go +++ b/db/query.go @@ -10,167 +10,163 @@ import ( "golang.org/x/net/context" ) -type Query interface { - Get(v interface{}) error - WithContext(ctx context.Context) Query +type Query struct { + ctx context.Context + client *Client + path string + ob orderBy + limFirst, limLast int + start, end, equalTo interface{} } -type queryImpl struct { - Ctx context.Context - Client *Client - Path string - OB orderBy - Opts []QueryOption +func (q *Query) WithStartAt(v interface{}) *Query { + q2 := new(Query) + *q2 = *q + q2.start = v + return q2 } -func (q *queryImpl) Get(v interface{}) error { - qp := make(queryParams) - if err := q.OB.apply(qp); err != nil { - return err - } - for _, o := range q.Opts { - if err := o.apply(qp); err != nil { - return err - } - } - - req := &request{ - Method: "GET", - Path: q.Path, - Opts: []httpOption{withQueryParams(qp)}, - } - resp, err := q.Client.send(q.Ctx, req) - if err != nil { - return err - } - return resp.CheckAndParse(http.StatusOK, v) +func (q *Query) WithEndAt(v interface{}) *Query { + q2 := new(Query) + *q2 = *q + q2.end = v + return q2 } -func (q *queryImpl) WithContext(ctx context.Context) Query { - q2 := new(queryImpl) +func (q *Query) WithEqualTo(v interface{}) *Query { + q2 := new(Query) *q2 = *q - q2.Ctx = ctx + q2.equalTo = v return q2 } -type queryParams map[string]string - -type orderBy interface { - apply(qp queryParams) error +func (q *Query) WithLimitToFirst(lim int) *Query { + q2 := new(Query) + *q2 = *q + q2.limFirst = lim + return q2 } -type QueryOption interface { - apply(qp queryParams) error +func (q *Query) WithLimitToLast(lim int) *Query { + q2 := new(Query) + *q2 = *q + q2.limLast = lim + return q2 } -func WithLimitToFirst(lim int) QueryOption { - return &limitParam{"limitToFirst", lim} +func (q *Query) WithContext(ctx context.Context) *Query { + q2 := new(Query) + *q2 = *q + q2.ctx = ctx + return q2 } -func WithLimitToLast(lim int) QueryOption { - return &limitParam{"limitToLast", lim} -} +func (q *Query) Get(v interface{}) error { + qp := make(map[string]string) + ob, err := q.ob.encode() + if err != nil { + return err + } + qp["orderBy"] = ob + + if q.limFirst > 0 && q.limLast > 0 { + return fmt.Errorf("cannot set both limit parameter: first = %d, last = %d", q.limFirst, q.limLast) + } else if q.limFirst < 0 { + return fmt.Errorf("limit first cannot be negative: %d", q.limFirst) + } else if q.limLast < 0 { + return fmt.Errorf("limit last cannot be negative: %d", q.limLast) + } -func WithStartAt(v interface{}) QueryOption { - return &filterParam{"startAt", v} -} + if q.limFirst > 0 { + qp["limitToFirst"] = strconv.Itoa(q.limFirst) + } else if q.limLast > 0 { + qp["limitToLast"] = strconv.Itoa(q.limLast) + } -func WithEndAt(v interface{}) QueryOption { - return &filterParam{"endAt", v} -} + if err := encodeFilter("startAt", q.start, qp); err != nil { + return err + } + if err := encodeFilter("endAt", q.end, qp); err != nil { + return err + } + if err := encodeFilter("equalTo", q.equalTo, qp); err != nil { + return err + } -func WithEqualTo(v interface{}) QueryOption { - return &filterParam{"equalTo", v} + req := &request{ + Method: "GET", + Path: q.path, + Opts: []httpOption{withQueryParams(qp)}, + } + resp, err := q.client.send(q.ctx, req) + if err != nil { + return err + } + return resp.CheckAndParse(http.StatusOK, v) } -func (r *Ref) OrderByChild(child string, opts ...QueryOption) Query { - return newQuery(r, orderByChild(child), opts...) +func (r *Ref) OrderByChild(child string) *Query { + return newQuery(r, orderByChild(child)) } -func (r *Ref) OrderByKey(opts ...QueryOption) Query { - return newQuery(r, orderByProperty("$key"), opts...) +func (r *Ref) OrderByKey() *Query { + return newQuery(r, orderByProperty("$key")) } -func (r *Ref) OrderByValue(opts ...QueryOption) Query { - return newQuery(r, orderByProperty("$value"), opts...) +func (r *Ref) OrderByValue() *Query { + return newQuery(r, orderByProperty("$value")) } -func newQuery(r *Ref, ob orderBy, opts ...QueryOption) Query { - return &queryImpl{ - Ctx: r.ctx, - Client: r.client, - Path: r.Path, - OB: ob, - Opts: opts, +func newQuery(r *Ref, ob orderBy) *Query { + return &Query{ + ctx: r.ctx, + client: r.client, + path: r.Path, + ob: ob, } } -type orderByChild string - -func (p orderByChild) apply(qp queryParams) error { - if p == "" { - return fmt.Errorf("empty child path") - } else if strings.ContainsAny(string(p), invalidChars) { - return fmt.Errorf("invalid child path with illegal characters: %q", p) - } - segs := parsePath(string(p)) - b, err := json.Marshal(strings.Join(segs, "/")) - if err != nil { +func encodeFilter(key string, val interface{}, m map[string]string) error { + if val == nil { return nil } - qp["orderBy"] = string(b) - return nil -} - -type orderByProperty string - -func (p orderByProperty) apply(qp queryParams) error { - b, err := json.Marshal(p) + b, err := json.Marshal(val) if err != nil { return err } - qp["orderBy"] = string(b) + m[key] = string(b) return nil } -type limitParam struct { - key string - val int +type orderBy interface { + encode() (string, error) } -func (p *limitParam) apply(qp queryParams) error { - if p.val < 0 { - return fmt.Errorf("limit parameters must not be negative: %d", p.val) - } else if p.val == 0 { - return nil - } +type orderByChild string - qp[p.key] = strconv.Itoa(p.val) - cnt := 0 - for _, k := range []string{"limitToFirst", "limitToLast"} { - if _, ok := qp[k]; ok { - cnt++ - } +func (p orderByChild) encode() (string, error) { + if p == "" { + return "", fmt.Errorf("empty child path") + } else if strings.ContainsAny(string(p), invalidChars) { + return "", fmt.Errorf("invalid child path with illegal characters: %q", p) } - if cnt == 2 { - return fmt.Errorf("cannot set both limit parameters") + segs := parsePath(string(p)) + if len(segs) == 0 { + return "", fmt.Errorf("invalid child path: %q", p) } - return nil + b, err := json.Marshal(strings.Join(segs, "/")) + if err != nil { + return "", nil + } + return string(b), nil } -type filterParam struct { - key string - val interface{} -} +type orderByProperty string -func (p *filterParam) apply(qp queryParams) error { - if p.val == nil { - return nil - } - b, err := json.Marshal(p.val) +func (p orderByProperty) encode() (string, error) { + b, err := json.Marshal(p) if err != nil { - return err + return "", err } - qp[p.key] = string(b) - return nil + return string(b), nil } diff --git a/db/query_test.go b/db/query_test.go index eb84cb67..c31dbb82 100644 --- a/db/query_test.go +++ b/db/query_test.go @@ -8,14 +8,14 @@ import ( func TestQueryWithContext(t *testing.T) { q := client.NewRef("peter").OrderByChild("messages") - if q.(*queryImpl).Ctx != nil { - t.Errorf("Ctx = %v; want nil", q.(*queryImpl).Ctx) + if q.ctx != nil { + t.Errorf("Ctx = %v; want nil", q.ctx) } ctx, cancel := context.WithCancel(context.Background()) q = q.WithContext(ctx) - if q.(*queryImpl).Ctx != ctx { - t.Errorf("Ctx = %v; want %v", q.(*queryImpl).Ctx, ctx) + if q.ctx != ctx { + t.Errorf("Ctx = %v; want %v", q.ctx, ctx) } want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} @@ -46,8 +46,8 @@ func TestQueryWithContext(t *testing.T) { func TestQueryFromRefWithContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) q := client.NewRef("peter").WithContext(ctx).OrderByChild("messages") - if q.(*queryImpl).Ctx != ctx { - t.Errorf("Ctx = %v; want %v", q.(*queryImpl).Ctx, ctx) + if q.ctx != ctx { + t.Errorf("Ctx = %v; want %v", q.ctx, ctx) } want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} @@ -81,8 +81,8 @@ func TestQueryWithContextPrecedence(t *testing.T) { r := client.NewRef("peter").WithContext(ctx1) q := r.OrderByChild("messages").WithContext(ctx2) - if q.(*queryImpl).Ctx != ctx2 { - t.Errorf("Ctx = %v; want %v", q.(*queryImpl).Ctx, ctx2) + if q.ctx != ctx2 { + t.Errorf("Ctx = %v; want %v", q.ctx, ctx2) } want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} @@ -120,18 +120,26 @@ func TestChildQuery(t *testing.T) { srv := mock.Start(client) defer srv.Close() - var got map[string]interface{} - if err := testref.OrderByChild("messages").Get(&got); err != nil { - t.Fatal(err) + cases := []string{ + "messages", "messages/", "/messages", } - if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + var reqs []*testReq + for _, tc := range cases { + var got map[string]interface{} + if err := testref.OrderByChild(tc).Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"messages\""}, + }) } - checkOnlyRequest(t, mock.Reqs, &testReq{ - Method: "GET", - Path: "/peter.json", - Query: map[string]string{"orderBy": "\"messages\""}, - }) + + checkAllRequests(t, mock.Reqs, reqs) } func TestNestedChildQuery(t *testing.T) { @@ -160,13 +168,9 @@ func TestChildQueryWithParams(t *testing.T) { srv := mock.Start(client) defer srv.Close() - opts := []QueryOption{ - WithStartAt("m4"), - WithEndAt("m50"), - WithLimitToFirst(10), - } + q := testref.OrderByChild("messages").WithStartAt("m4").WithEndAt("m50").WithLimitToFirst(10) var got map[string]interface{} - if err := testref.OrderByChild("messages", opts...).Get(&got); err != nil { + if err := q.Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -191,7 +195,8 @@ func TestInvalidOrderByChild(t *testing.T) { r := client.NewRef("/") cases := []string{ - "foo$", "foo.", "foo#", "foo]", "foo[", + "", "/", "foo$", "foo.", "foo#", "foo]", + "foo[", "$key", "$value", "$priority", } for _, tc := range cases { var got string @@ -251,7 +256,7 @@ func TestLimitFirstQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByChild("messages", WithLimitToFirst(10)).Get(&got); err != nil { + if err := testref.OrderByChild("messages").WithLimitToFirst(10).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -271,7 +276,7 @@ func TestLimitLastQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByChild("messages", WithLimitToLast(10)).Get(&got); err != nil { + if err := testref.OrderByChild("messages").WithLimitToLast(10).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -290,13 +295,20 @@ func TestInvalidLimitQuery(t *testing.T) { srv := mock.Start(client) defer srv.Close() - var got map[string]interface{} - q := testref.OrderByChild("messages", WithLimitToFirst(10), WithLimitToLast(10)) - if err := q.Get(&got); got != nil || err == nil { - t.Errorf("Get() = (%v, %v); want = (nil, error)", got, err) + q := testref.OrderByChild("messages") + cases := []*Query{ + q.WithLimitToFirst(10).WithLimitToLast(10), + q.WithLimitToFirst(-10), + q.WithLimitToLast(-10), } - if len(mock.Reqs) != 0 { - t.Errorf("Requests: %v; want: empty", mock.Reqs) + for _, tc := range cases { + var got map[string]interface{} + if err := tc.Get(&got); got != nil || err == nil { + t.Errorf("Get() = (%v, %v); want = (nil, error)", got, err) + } + if len(mock.Reqs) != 0 { + t.Errorf("Requests: %v; want: empty", mock.Reqs) + } } } @@ -307,7 +319,7 @@ func TestStartAtQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByChild("messages", WithStartAt(10)).Get(&got); err != nil { + if err := testref.OrderByChild("messages").WithStartAt(10).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -327,7 +339,7 @@ func TestEndAtQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByChild("messages", WithEndAt(10)).Get(&got); err != nil { + if err := testref.OrderByChild("messages").WithEndAt(10).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -340,15 +352,14 @@ func TestEndAtQuery(t *testing.T) { }) } -func TestAllParamsQuery(t *testing.T) { +func TestEqualToQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() - q := testref.OrderByChild("messages", WithLimitToFirst(100), WithStartAt("bar"), WithEndAt("foo")) var got map[string]interface{} - if err := q.Get(&got); err != nil { + if err := testref.OrderByChild("messages").WithEqualTo(10).Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -357,23 +368,42 @@ func TestAllParamsQuery(t *testing.T) { checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", Path: "/peter.json", - Query: map[string]string{ - "limitToFirst": "100", - "startAt": "\"bar\"", - "endAt": "\"foo\"", - "orderBy": "\"messages\"", - }, + Query: map[string]string{"equalTo": "10", "orderBy": "\"messages\""}, }) } -func TestEqualToQuery(t *testing.T) { +func TestInvalidFilterQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() + q := testref.OrderByChild("messages") + cases := []*Query{ + q.WithStartAt(func() {}), + q.WithEndAt(func() {}), + q.WithEqualTo(func() {}), + } + for _, tc := range cases { + var got map[string]interface{} + if err := tc.Get(&got); got != nil || err == nil { + t.Errorf("Get() = (%v, %v); want = (nil, error)", got, err) + } + if len(mock.Reqs) != 0 { + t.Errorf("Requests: %v; want: empty", mock.Reqs) + } + } +} + +func TestAllParamsQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q := testref.OrderByChild("messages").WithLimitToFirst(100).WithStartAt("bar").WithEndAt("foo") var got map[string]interface{} - if err := testref.OrderByChild("messages", WithEqualTo(10)).Get(&got); err != nil { + if err := q.Get(&got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -382,6 +412,11 @@ func TestEqualToQuery(t *testing.T) { checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", Path: "/peter.json", - Query: map[string]string{"equalTo": "10", "orderBy": "\"messages\""}, + Query: map[string]string{ + "limitToFirst": "100", + "startAt": "\"bar\"", + "endAt": "\"foo\"", + "orderBy": "\"messages\"", + }, }) } diff --git a/db/ref.go b/db/ref.go index 4b83eebf..96314212 100644 --- a/db/ref.go +++ b/db/ref.go @@ -145,7 +145,9 @@ func (r *Ref) Transaction(fn UpdateFn) error { return err } resp, err := r.sendWithBody("PUT", new, withHeader("If-Match", etag)) - if err := resp.CheckStatus(http.StatusOK); err == nil { + if err != nil { + return err + } else if err := resp.CheckStatus(http.StatusOK); err == nil { return nil } else if err := resp.CheckAndParse(http.StatusPreconditionFailed, &curr); err != nil { return err diff --git a/db/ref_test.go b/db/ref_test.go index 4b3f61bc..84e54ae6 100644 --- a/db/ref_test.go +++ b/db/ref_test.go @@ -1,6 +1,7 @@ package db import ( + "fmt" "net/http" "reflect" "testing" @@ -8,6 +9,48 @@ import ( "golang.org/x/net/context" ) +type refOp func(r *Ref) error + +var testOps = []refOp{ + func(r *Ref) error { + var got string + return r.Get(&got) + }, + func(r *Ref) error { + var got string + _, err := r.GetWithETag(&got) + return err + }, + func(r *Ref) error { + var got string + _, _, err := r.GetIfChanged("etag", &got) + return err + }, + func(r *Ref) error { + return r.Set("foo") + }, + func(r *Ref) error { + _, err := r.SetIfUnchanged("etag", "foo") + return err + }, + func(r *Ref) error { + _, err := r.Push("foo") + return err + }, + func(r *Ref) error { + return r.Update(map[string]interface{}{"foo": "bar"}) + }, + func(r *Ref) error { + return r.Delete() + }, + func(r *Ref) error { + fn := func(v interface{}) (interface{}, error) { + return v, nil + } + return r.Transaction(fn) + }, +} + func TestRefWithContext(t *testing.T) { r := client.NewRef("peter") if r.ctx != nil { @@ -57,6 +100,19 @@ func TestGet(t *testing.T) { checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) } +func TestInvalidGet(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + got := func() {} + if err := testref.Get(&got); err == nil { + t.Errorf("Get() = nil; want error") + } + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) +} + func TestGetWithStruct(t *testing.T) { want := person{Name: "Peter Parker", Age: 17} mock := &mockServer{Resp: want} @@ -160,13 +216,17 @@ func TestWerlformedHttpError(t *testing.T) { srv := mock.Start(client) defer srv.Close() - var got person - err := testref.Get(&got) want := "http error status: 500; reason: test error" - if err == nil || err.Error() != want { - t.Errorf("Get() = %v; want = %v", err, want) + for _, tc := range testOps { + err := tc(testref) + if err == nil || err.Error() != want { + t.Errorf("Get() = %v; want = %v", err, want) + } + } + + if len(mock.Reqs) != len(testOps) { + t.Errorf("Requests = %d; want = %d", len(mock.Reqs), len(testOps)) } - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) } func TestUnexpectedHttpError(t *testing.T) { @@ -174,13 +234,63 @@ func TestUnexpectedHttpError(t *testing.T) { srv := mock.Start(client) defer srv.Close() - var got person - err := testref.Get(&got) want := "http error status: 500; message: \"unexpected error\"" - if err == nil || err.Error() != want { - t.Errorf("Get() = %v; want = %v", err, want) + for _, tc := range testOps { + err := tc(testref) + if err == nil || err.Error() != want { + t.Errorf("Get() = %v; want = %v", err, want) + } + } + + if len(mock.Reqs) != len(testOps) { + t.Errorf("Requests = %d; want = %d", len(mock.Reqs), len(testOps)) + } +} + +func TestInvalidPath(t *testing.T) { + mock := &mockServer{Resp: "test"} + srv := mock.Start(client) + defer srv.Close() + + cases := []string{ + "foo$", "foo.", "foo#", "foo]", "foo[", + } + for _, tc := range cases { + r := client.NewRef(tc) + for _, op := range testOps { + err := op(r) + if err == nil { + t.Errorf("Get() = nil; want = error") + } + } + } + + if len(mock.Reqs) != 0 { + t.Errorf("Requests: %v; want: empty", mock.Reqs) + } +} + +func TestInvalidChildPath(t *testing.T) { + mock := &mockServer{Resp: "test"} + srv := mock.Start(client) + defer srv.Close() + + cases := []string{ + "foo$", "foo.", "foo#", "foo]", "foo[", + } + for _, tc := range cases { + r := testref.Child(tc) + for _, op := range testOps { + err := op(r) + if err == nil { + t.Errorf("Get() = nil; want = error") + } + } + } + + if len(mock.Reqs) != 0 { + t.Errorf("Requests: %v; want: empty", mock.Reqs) } - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) } func TestSet(t *testing.T) { @@ -188,33 +298,45 @@ func TestSet(t *testing.T) { srv := mock.Start(client) defer srv.Close() - want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} - if err := testref.Set(want); err != nil { - t.Fatal(err) + cases := []interface{}{ + 1, + true, + "foo", + map[string]interface{}{"name": "Peter Parker", "age": float64(17)}, + &person{"Peter Parker", 17}, + } + var want []*testReq + for _, tc := range cases { + if err := testref.Set(tc); err != nil { + t.Fatal(err) + } + want = append(want, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(tc), + Query: map[string]string{"print": "silent"}, + }) } - checkOnlyRequest(t, mock.Reqs, &testReq{ - Method: "PUT", - Path: "/peter.json", - Body: serialize(want), - Query: map[string]string{"print": "silent"}, - }) + checkAllRequests(t, mock.Reqs, want) } -func TestSetWithStruct(t *testing.T) { +func TestInvalidSet(t *testing.T) { mock := &mockServer{} srv := mock.Start(client) defer srv.Close() - want := &person{"Peter Parker", 17} - if err := testref.Set(&want); err != nil { - t.Fatal(err) + cases := []interface{}{ + func() {}, + make(chan int), + } + for _, tc := range cases { + if err := testref.Set(tc); err == nil { + t.Errorf("Set() = nil; want error") + } + } + if len(mock.Reqs) != 0 { + t.Errorf("Requests = %v; want = empty", mock.Reqs) } - checkOnlyRequest(t, mock.Reqs, &testReq{ - Method: "PUT", - Path: "/peter.json", - Body: serialize(want), - Query: map[string]string{"print": "silent"}, - }) } func TestSetIfUnchanged(t *testing.T) { @@ -420,6 +542,53 @@ func TestTransactionRetry(t *testing.T) { }) } +func TestTransactionError(t *testing.T) { + mock := &mockServer{ + Resp: &person{"Peter Parker", 17}, + Header: map[string]string{"ETag": "mock-etag1"}, + } + srv := mock.Start(client) + defer srv.Close() + + cnt := 0 + want := "user error" + var fn UpdateFn = func(i interface{}) (interface{}, error) { + if cnt == 0 { + mock.Status = http.StatusPreconditionFailed + mock.Header = map[string]string{"ETag": "mock-etag2"} + mock.Resp = &person{"Peter Parker", 19} + } else if cnt == 1 { + return nil, fmt.Errorf(want) + } + cnt++ + p := i.(map[string]interface{}) + p["age"] = p["age"].(float64) + 1.0 + return p, nil + } + if err := testref.Transaction(fn); err == nil || err.Error() != want { + t.Errorf("Transaction() = %v; want = %q", err, want) + } + if cnt != 1 { + t.Errorf("Retry Count = %d; want = %d", cnt, 1) + } + checkAllRequests(t, mock.Reqs, []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }, + &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 18, + }), + Header: http.Header{"If-Match": []string{"mock-etag1"}}, + }, + }) +} + func TestTransactionAbort(t *testing.T) { mock := &mockServer{ Resp: &person{"Peter Parker", 17}, diff --git a/integration/db/db_test.go b/integration/db/db_test.go index 6b0445fd..f3b140f2 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -586,9 +586,8 @@ func TestReadWriteAccess(t *testing.T) { func TestQueryAccess(t *testing.T) { r := aoClient.NewRef("_adminsdk/go/protected") - q := r.OrderByKey(db.WithLimitToFirst(2)) got := make(map[string]interface{}) - if err := q.Get(&got); err == nil { + if err := r.OrderByKey().WithLimitToFirst(2).Get(&got); err == nil { t.Errorf("OrderByQuery() = nil; want = error") } else if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) diff --git a/integration/db/query_test.go b/integration/db/query_test.go index 003865db..d25583af 100644 --- a/integration/db/query_test.go +++ b/integration/db/query_test.go @@ -3,8 +3,6 @@ package db import ( "context" "testing" - - "firebase.google.com/go/db" ) var heightSorted = []string{ @@ -14,9 +12,8 @@ var heightSorted = []string{ func TestLimitToFirst(t *testing.T) { for _, tc := range []int{2, 10} { - q := dinos.OrderByChild("height", db.WithLimitToFirst(tc)) var m map[string]Dinosaur - if err := q.Get(&m); err != nil { + if err := dinos.OrderByChild("height").WithLimitToFirst(tc).Get(&m); err != nil { t.Fatal(err) } @@ -38,9 +35,8 @@ func TestLimitToFirst(t *testing.T) { func TestLimitToLast(t *testing.T) { for _, tc := range []int{2, 10} { - q := dinos.OrderByChild("height", db.WithLimitToLast(tc)) var m map[string]Dinosaur - if err := q.Get(&m); err != nil { + if err := dinos.OrderByChild("height").WithLimitToLast(tc).Get(&m); err != nil { t.Fatal(err) } @@ -61,9 +57,8 @@ func TestLimitToLast(t *testing.T) { } func TestStartAt(t *testing.T) { - q := dinos.OrderByChild("height", db.WithStartAt(3.5)) var m map[string]Dinosaur - if err := q.Get(&m); err != nil { + if err := dinos.OrderByChild("height").WithStartAt(3.5).Get(&m); err != nil { t.Fatal(err) } @@ -79,9 +74,8 @@ func TestStartAt(t *testing.T) { } func TestEndAt(t *testing.T) { - q := dinos.OrderByChild("height", db.WithEndAt(3.5)) var m map[string]Dinosaur - if err := q.Get(&m); err != nil { + if err := dinos.OrderByChild("height").WithEndAt(3.5).Get(&m); err != nil { t.Fatal(err) } @@ -97,9 +91,8 @@ func TestEndAt(t *testing.T) { } func TestStartAndEndAt(t *testing.T) { - q := dinos.OrderByChild("height", db.WithStartAt(2.5), db.WithEndAt(5)) var m map[string]Dinosaur - if err := q.Get(&m); err != nil { + if err := dinos.OrderByChild("height").WithStartAt(2.5).WithEndAt(5).Get(&m); err != nil { t.Fatal(err) } @@ -115,9 +108,8 @@ func TestStartAndEndAt(t *testing.T) { } func TestEqualTo(t *testing.T) { - q := dinos.OrderByChild("height", db.WithEqualTo(0.6)) var m map[string]Dinosaur - if err := q.Get(&m); err != nil { + if err := dinos.OrderByChild("height").WithEqualTo(0.6).Get(&m); err != nil { t.Fatal(err) } @@ -133,9 +125,8 @@ func TestEqualTo(t *testing.T) { } func TestOrderByNestedChild(t *testing.T) { - q := dinos.OrderByChild("ratings/pos", db.WithStartAt(4)) var m map[string]Dinosaur - if err := q.Get(&m); err != nil { + if err := dinos.OrderByChild("ratings/pos").WithStartAt(4).Get(&m); err != nil { t.Fatal(err) } @@ -151,9 +142,8 @@ func TestOrderByNestedChild(t *testing.T) { } func TestOrderByKey(t *testing.T) { - q := dinos.OrderByKey(db.WithLimitToFirst(2)) var m map[string]Dinosaur - if err := q.Get(&m); err != nil { + if err := dinos.OrderByKey().WithLimitToFirst(2).Get(&m); err != nil { t.Fatal(err) } @@ -170,9 +160,8 @@ func TestOrderByKey(t *testing.T) { func TestOrderByValue(t *testing.T) { scores := ref.Child("scores") - q := scores.OrderByValue(db.WithLimitToLast(2)) var m map[string]int - if err := q.Get(&m); err != nil { + if err := scores.OrderByValue().WithLimitToLast(2).Get(&m); err != nil { t.Fatal(err) } @@ -189,7 +178,7 @@ func TestOrderByValue(t *testing.T) { func TestQueryWithContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - q := dinos.OrderByKey(db.WithLimitToFirst(2)).WithContext(ctx) + q := dinos.OrderByKey().WithLimitToFirst(2).WithContext(ctx) var m map[string]Dinosaur if err := q.Get(&m); err != nil { t.Fatal(err) From 8450363bf44f6a54b5d5392437ce462a5f97418e Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 27 Oct 2017 17:40:55 -0700 Subject: [PATCH 22/58] Better error messages in tests; Added license headers --- db/auth_override_test.go | 20 +++++++-- db/db_test.go | 46 +++++++++++++------- db/http_client.go | 14 ++++++ db/query.go | 14 ++++++ db/query_test.go | 94 ++++++++++++++++++++++++---------------- db/ref_test.go | 62 ++++++++++++++++---------- 6 files changed, 170 insertions(+), 80 deletions(-) diff --git a/db/auth_override_test.go b/db/auth_override_test.go index 401f9c3f..7b78b02c 100644 --- a/db/auth_override_test.go +++ b/db/auth_override_test.go @@ -1,3 +1,17 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package db import ( @@ -15,7 +29,7 @@ func TestAuthOverrideGet(t *testing.T) { t.Fatal(err) } if got != "data" { - t.Errorf("Get() = %q; want = %q", got, "data") + t.Errorf("Ref(AuthOverride).Get() = %q; want = %q", got, "data") } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -53,7 +67,7 @@ func TestAuthOverrideQuery(t *testing.T) { t.Fatal(err) } if got != "data" { - t.Errorf("OrderByChild() = %q; want = %q", got, "data") + t.Errorf("Ref(AuthOverride).OrderByChild() = %q; want = %q", got, "data") } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -76,7 +90,7 @@ func TestAuthOverrideRangeQuery(t *testing.T) { t.Fatal(err) } if got != "data" { - t.Errorf("OrderByChild() = %q; want = %q", got, "data") + t.Errorf("Ref(AuthOverride).OrderByChild() = %q; want = %q", got, "data") } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", diff --git a/db/db_test.go b/db/db_test.go index c63f2544..4e0ee288 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -1,3 +1,17 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package db import ( @@ -80,13 +94,13 @@ func TestNewClient(t *testing.T) { t.Fatal(err) } if c.url != testURL { - t.Errorf("BaseURL = %q; want: %q", c.url, testURL) + t.Errorf("NewClient().url = %q; want = %q", c.url, testURL) } if c.hc == nil { - t.Errorf("http.Client = nil; want non-nil") + t.Errorf("NewClient().hc = nil; want non-nil") } if c.ao != "" { - t.Errorf("AuthOverrides = %q; want %q", c.ao, "") + t.Errorf("NewClient().ao = %q; want = %q", c.ao, "") } } @@ -105,17 +119,17 @@ func TestNewClientAuthOverrides(t *testing.T) { t.Fatal(err) } if c.url != testURL { - t.Errorf("BaseURL = %q; want: %q", c.url, testURL) + t.Errorf("NewClient(%v).url = %q; want = %q", tc, c.url, testURL) } if c.hc == nil { - t.Errorf("http.Client = nil; want non-nil") + t.Errorf("NewClient(%v).hc = nil; want non-nil", tc) } b, err := json.Marshal(tc) if err != nil { t.Fatal(err) } if c.ao != string(b) { - t.Errorf("AuthOverrides = %q; want %q", c.ao, string(b)) + t.Errorf("NewClient(%v).ao = %q; want = %q", tc, c.ao, string(b)) } } } @@ -133,7 +147,7 @@ func TestInvalidURL(t *testing.T) { URL: tc, }) if c != nil || err == nil { - t.Errorf("NewClient() = (%v, %v); want = (nil, error)", c, err) + t.Errorf("NewClient(%q) = (%v, %v); want = (nil, error)", tc, c, err) } } } @@ -166,16 +180,16 @@ func TestNewRef(t *testing.T) { for _, tc := range cases { r := client.NewRef(tc.Path) if r.client == nil { - t.Errorf("Client = nil; want = %v", r.client) + t.Errorf("NewRef(%q).client = nil; want = %v", tc.Path, r.client) } if r.ctx != nil { - t.Errorf("Ctx = %v; want nil", r.ctx) + t.Errorf("NewRef(%q).ctx = %v; want nil", tc.Path, r.ctx) } if r.Path != tc.WantPath { - t.Errorf("Path = %q; want = %q", r.Path, tc.WantPath) + t.Errorf("NewRef(%q).Path = %q; want = %q", tc.Path, r.Path, tc.WantPath) } if r.Key != tc.WantKey { - t.Errorf("Key = %q; want = %q", r.Key, tc.WantKey) + t.Errorf("NewRef(%q).Key = %q; want = %q", tc.Path, r.Key, tc.WantKey) } } } @@ -198,16 +212,16 @@ func TestParent(t *testing.T) { r := client.NewRef(tc.Path).Parent() if tc.HasParent { if r == nil { - t.Fatalf("Parent = nil; want = %q", tc.Want) + t.Fatalf("Parent(%q) = nil; want = Ref(%q)", tc.Path, tc.Want) } if r.client == nil { - t.Errorf("Client = nil; want = %v", client) + t.Errorf("Parent(%q).client = nil; want = %v", tc.Path, client) } if r.Key != tc.Want { - t.Errorf("Key = %q; want = %q", r.Key, tc.Want) + t.Errorf("Parent(%q).Key = %q; want = %q", tc.Path, r.Key, tc.Want) } } else if r != nil { - t.Fatalf("Parent = %v; want = nil", r) + t.Fatalf("Parent(%q) = %v; want = nil", tc.Path, r) } } } @@ -239,7 +253,7 @@ func TestChild(t *testing.T) { t.Errorf("Child(%q) = %q; want = %q", tc.Path, c.Path, tc.Want) } if c.Parent().Path != tc.Parent { - t.Errorf("Child().Parent() = %q; want = %q", c.Parent().Path, tc.Parent) + t.Errorf("Child(%q).Parent() = %q; want = %q", tc.Path, c.Parent().Path, tc.Parent) } } } diff --git a/db/http_client.go b/db/http_client.go index 0a2b2051..95999827 100644 --- a/db/http_client.go +++ b/db/http_client.go @@ -1,3 +1,17 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package db import ( diff --git a/db/query.go b/db/query.go index 16265b65..585367d0 100644 --- a/db/query.go +++ b/db/query.go @@ -1,3 +1,17 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package db import ( diff --git a/db/query_test.go b/db/query_test.go index c31dbb82..a2f17d8b 100644 --- a/db/query_test.go +++ b/db/query_test.go @@ -1,3 +1,16 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. package db import ( @@ -9,13 +22,13 @@ import ( func TestQueryWithContext(t *testing.T) { q := client.NewRef("peter").OrderByChild("messages") if q.ctx != nil { - t.Errorf("Ctx = %v; want nil", q.ctx) + t.Errorf("query = %v; want nil", q.ctx) } ctx, cancel := context.WithCancel(context.Background()) q = q.WithContext(ctx) if q.ctx != ctx { - t.Errorf("Ctx = %v; want %v", q.ctx, ctx) + t.Errorf("query.WithContext() = %v; want %v", q.ctx, ctx) } want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} @@ -28,7 +41,7 @@ func TestQueryWithContext(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("query.WithContext() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -39,7 +52,7 @@ func TestQueryWithContext(t *testing.T) { cancel() got = nil if err := q.Get(&got); len(got) != 0 || err == nil { - t.Errorf("Get() = (%v, %v); want = (empty, error)", got, err) + t.Errorf("query.WithContext() = (%v, %v); want = (empty, error)", got, err) } } @@ -47,7 +60,7 @@ func TestQueryFromRefWithContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) q := client.NewRef("peter").WithContext(ctx).OrderByChild("messages") if q.ctx != ctx { - t.Errorf("Ctx = %v; want %v", q.ctx, ctx) + t.Errorf("ref.WithContext().query = %v; want %v", q.ctx, ctx) } want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} @@ -60,7 +73,7 @@ func TestQueryFromRefWithContext(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("ref.WithContext().query = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -71,7 +84,7 @@ func TestQueryFromRefWithContext(t *testing.T) { cancel() got = nil if err := q.Get(&got); len(got) != 0 || err == nil { - t.Errorf("Get() = (%v, %v); want = (empty, error)", got, err) + t.Errorf("ref.WithContext().query = (%v, %v); want = (empty, error)", got, err) } } @@ -95,7 +108,7 @@ func TestQueryWithContextPrecedence(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("ref.WithContext().query.WithContext() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -106,11 +119,10 @@ func TestQueryWithContextPrecedence(t *testing.T) { cancel() got = nil if err := q.Get(&got); len(got) != 0 || err == nil { - t.Errorf("Get() = (%v, %v); want = (empty, error)", got, err) + t.Errorf("ref.WithContext().query.WithContext() = (%v, %v); want = (empty, error)", got, err) } - if err := r.Get(&got); !reflect.DeepEqual(got, want) || err != nil { - t.Errorf("Get() = (%v, %v); want = (%v, nil)", got, err, want) + t.Errorf("ref.WithContext() = (%v, %v); want = (%v, nil)", got, err, want) } } @@ -130,7 +142,7 @@ func TestChildQuery(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("OrderByChild(%q) = %v; want = %v", tc, got, want) } reqs = append(reqs, &testReq{ Method: "GET", @@ -153,7 +165,7 @@ func TestNestedChildQuery(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("OrderByChild(%q) = %v; want = %v", "messages/ratings", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -174,7 +186,7 @@ func TestChildQueryWithParams(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("OrderByChild() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -201,11 +213,11 @@ func TestInvalidOrderByChild(t *testing.T) { for _, tc := range cases { var got string if err := r.OrderByChild(tc).Get(&got); got != "" || err == nil { - t.Errorf("Get() = (%q, %v); want = (%q, error)", got, err, "") + t.Errorf("OrderByChild(%q) = (%q, %v); want = (%q, error)", tc, got, err, "") } } if len(mock.Reqs) != 0 { - t.Errorf("Requests: %v; want: empty", mock.Reqs) + t.Errorf("OrderByChild() = %v; want = empty", mock.Reqs) } } @@ -220,7 +232,7 @@ func TestKeyQuery(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("OrderByKey() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -240,7 +252,7 @@ func TestValueQuery(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("OrderByValue() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -260,7 +272,7 @@ func TestLimitFirstQuery(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("query.WithLimitToFirst() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -280,7 +292,7 @@ func TestLimitLastQuery(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("query.WithLimitToLast() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -296,18 +308,21 @@ func TestInvalidLimitQuery(t *testing.T) { defer srv.Close() q := testref.OrderByChild("messages") - cases := []*Query{ - q.WithLimitToFirst(10).WithLimitToLast(10), - q.WithLimitToFirst(-10), - q.WithLimitToLast(-10), + cases := []struct { + name string + q *Query + }{ + {"BothLimits", q.WithLimitToFirst(10).WithLimitToLast(10)}, + {"NegativeFirst", q.WithLimitToFirst(-10)}, + {"NegativeLast", q.WithLimitToLast(-10)}, } for _, tc := range cases { var got map[string]interface{} - if err := tc.Get(&got); got != nil || err == nil { - t.Errorf("Get() = (%v, %v); want = (nil, error)", got, err) + if err := tc.q.Get(&got); got != nil || err == nil { + t.Errorf("OrderByChild(%q) = (%v, %v); want = (nil, error)", tc.name, got, err) } if len(mock.Reqs) != 0 { - t.Errorf("Requests: %v; want: empty", mock.Reqs) + t.Errorf("OrderByChild(%q): %v; want: empty", tc.name, mock.Reqs) } } } @@ -323,7 +338,7 @@ func TestStartAtQuery(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("WithStartAt() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -343,7 +358,7 @@ func TestEndAtQuery(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("WithEndAt() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -363,7 +378,7 @@ func TestEqualToQuery(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("WithEqualTo() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -379,18 +394,21 @@ func TestInvalidFilterQuery(t *testing.T) { defer srv.Close() q := testref.OrderByChild("messages") - cases := []*Query{ - q.WithStartAt(func() {}), - q.WithEndAt(func() {}), - q.WithEqualTo(func() {}), + cases := []struct { + name string + q *Query + }{ + {"InvalidStartAt", q.WithStartAt(func() {})}, + {"InvalidEndAt", q.WithEndAt(func() {})}, + {"InvalidEqualTo", q.WithEqualTo(func() {})}, } for _, tc := range cases { var got map[string]interface{} - if err := tc.Get(&got); got != nil || err == nil { - t.Errorf("Get() = (%v, %v); want = (nil, error)", got, err) + if err := tc.q.Get(&got); got != nil || err == nil { + t.Errorf("OrderByChild(%q) = (%v, %v); want = (nil, error)", tc.name, got, err) } if len(mock.Reqs) != 0 { - t.Errorf("Requests: %v; want: empty", mock.Reqs) + t.Errorf("OrdderByChild(%q) = %v; want = empty", tc.name, mock.Reqs) } } } @@ -407,7 +425,7 @@ func TestAllParamsQuery(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("OrderByChild(AllParams) = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", diff --git a/db/ref_test.go b/db/ref_test.go index 84e54ae6..1d7cfe57 100644 --- a/db/ref_test.go +++ b/db/ref_test.go @@ -1,3 +1,17 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package db import ( @@ -60,7 +74,7 @@ func TestRefWithContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) r = r.WithContext(ctx) if r.ctx != ctx { - t.Errorf("Ctx = %v; want %v", r.ctx, ctx) + t.Errorf("WithContext().Ctx = %v; want = %v", r.ctx, ctx) } want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} @@ -73,14 +87,14 @@ func TestRefWithContext(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("WithContext().Get() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) cancel() got = nil if err := r.Get(&got); len(got) != 0 || err == nil { - t.Errorf("Get() = (%v, %v); want = (empty, error)", got, err) + t.Errorf("WithContext().Get() = (%v, %v); want = (empty, error)", got, err) } } @@ -108,7 +122,7 @@ func TestInvalidGet(t *testing.T) { got := func() {} if err := testref.Get(&got); err == nil { - t.Errorf("Get() = nil; want error") + t.Errorf("Get(func) = nil; want error") } checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) } @@ -124,7 +138,7 @@ func TestGetWithStruct(t *testing.T) { t.Fatal(err) } if want != got { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("Get(struct) = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) } @@ -144,10 +158,10 @@ func TestGetWithETag(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("GetWithETag() = %v; want = %v", got, want) } if etag != "mock-etag" { - t.Errorf("ETag = %q; want = %q", etag, "mock-etag") + t.Errorf("GetWithETag() = %q; want = %q", etag, "mock-etag") } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -171,13 +185,13 @@ func TestGetIfChanged(t *testing.T) { t.Fatal(err) } if !ok { - t.Errorf("Get() = %v; want = %v", ok, true) + t.Errorf("GetIfChanged() = %v; want = %v", ok, true) } if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + t.Errorf("GetIfChanged() = %v; want = %v", got, want) } if etag != "new-etag" { - t.Errorf("ETag = %q; want = %q", etag, "new-etag") + t.Errorf("GetIfChanged() = %q; want = %q", etag, "new-etag") } mock.Status = http.StatusNotModified @@ -188,13 +202,13 @@ func TestGetIfChanged(t *testing.T) { t.Fatal(err) } if ok { - t.Errorf("Get() = %v; want = %v", ok, false) + t.Errorf("GetIfChanged() = %v; want = %v", ok, false) } if got2 != nil { - t.Errorf("Get() = %v; want nil", got2) + t.Errorf("GetIfChanged() = %v; want nil", got2) } if etag != "new-etag" { - t.Errorf("ETag = %q; want = %q", etag, "new-etag") + t.Errorf("GetIfChanged() = %q; want = %q", etag, "new-etag") } checkAllRequests(t, mock.Reqs, []*testReq{ @@ -331,11 +345,11 @@ func TestInvalidSet(t *testing.T) { } for _, tc := range cases { if err := testref.Set(tc); err == nil { - t.Errorf("Set() = nil; want error") + t.Errorf("Set(%v) = nil; want error", tc) } } if len(mock.Reqs) != 0 { - t.Errorf("Requests = %v; want = empty", mock.Reqs) + t.Errorf("Set() = %v; want = empty", mock.Reqs) } } @@ -443,13 +457,15 @@ func TestUpdate(t *testing.T) { } func TestInvalidUpdate(t *testing.T) { - if err := testref.Update(nil); err == nil { - t.Errorf("Update(nil) = nil; want error") + cases := []map[string]interface{}{ + nil, + make(map[string]interface{}), + map[string]interface{}{"foo": func() {}}, } - - m := make(map[string]interface{}) - if err := testref.Update(m); err == nil { - t.Errorf("Update(map{}) = nil; want error") + for _, tc := range cases { + if err := testref.Update(tc); err == nil { + t.Errorf("Update(%v) = nil; want error", tc) + } } } @@ -513,7 +529,7 @@ func TestTransactionRetry(t *testing.T) { t.Fatal(err) } if cnt != 2 { - t.Errorf("Retry Count = %d; want = %d", cnt, 2) + t.Errorf("Transaction() retries = %d; want = %d", cnt, 2) } checkAllRequests(t, mock.Reqs, []*testReq{ &testReq{ @@ -569,7 +585,7 @@ func TestTransactionError(t *testing.T) { t.Errorf("Transaction() = %v; want = %q", err, want) } if cnt != 1 { - t.Errorf("Retry Count = %d; want = %d", cnt, 1) + t.Errorf("Transaction() retries = %d; want = %d", cnt, 1) } checkAllRequests(t, mock.Reqs, []*testReq{ &testReq{ From 993f63970dc6b904af7d9226794c5c1c2478e4c3 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Fri, 27 Oct 2017 21:59:17 -0700 Subject: [PATCH 23/58] Added documentatioon and cleaned up tests --- auth/auth.go | 2 +- db/db.go | 17 ++++- db/db_test.go | 7 +- db/ref.go | 55 +++++++++++++++- db/ref_test.go | 134 ++++++++++++++++++++++++-------------- firebase.go | 6 +- integration/db/db_test.go | 6 +- 7 files changed, 167 insertions(+), 60 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 26aa14d3..c7010f06 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -73,7 +73,7 @@ type signer interface { // NewClient creates a new instance of the Firebase Auth Client. // // This function can only be invoked from within the SDK. Client applications should access the -// the Auth service through firebase.App. +// Auth service through firebase.App. func NewClient(ctx context.Context, c *internal.AuthConfig) (*Client, error) { var ( err error diff --git a/db/db.go b/db/db.go index f3fc4081..03bf7eeb 100644 --- a/db/db.go +++ b/db/db.go @@ -40,6 +40,10 @@ type Client struct { ao string } +// NewClient creates a new instance of the Firebase Database Client. +// +// This function can only be invoked from within the SDK. Client applications should access the +// Database service through firebase.App. func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) { opts := append([]option.ClientOption{}, c.Opts...) ua := fmt.Sprintf(userAgentFormat, c.Version, runtime.Version()) @@ -73,10 +77,21 @@ func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) }, nil } -type AuthOverrides struct { +// AuthOverride regulates how Firebase security rules are enforced on database invocations. +// +// By default, the database calls made by the Admin SDK have administrative privileges, thereby +// allowing them to completely bypass all Firebase security rules. This behavior can be overridden +// by setting an AuthOverride. When specified, the AuthOverride value will become visible to the +// database server during security rule evaluation. Specifically, this value will be accessible +// via the auth variable of the security rules. +// +// Refer to https://firebase.google.com/docs/database/admin/start#authenticate-with-limited-privileges +// for more details and code samples. +type AuthOverride struct { Map map[string]interface{} } +// NewRef returns a new database reference representing the node at the specified path. func (c *Client) NewRef(path string) *Ref { segs := parsePath(path) key := "" diff --git a/db/db_test.go b/db/db_test.go index 4e0ee288..62226074 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -265,9 +265,10 @@ func checkOnlyRequest(t *testing.T, got []*testReq, want *testReq) { func checkAllRequests(t *testing.T, got []*testReq, want []*testReq) { if len(got) != len(want) { t.Errorf("Request Count = %d; want = %d", len(got), len(want)) - } - for i, r := range got { - checkRequest(t, r, want[i]) + } else { + for i, r := range got { + checkRequest(t, r, want[i]) + } } } diff --git a/db/ref.go b/db/ref.go index 96314212..a91baccb 100644 --- a/db/ref.go +++ b/db/ref.go @@ -22,6 +22,9 @@ import ( "golang.org/x/net/context" ) +const txnRetries = 25 + +// Ref represents a node in the Firebase Realtime Database. type Ref struct { Key string Path string @@ -31,6 +34,9 @@ type Ref struct { ctx context.Context } +// Parent returns a reference to the parent of the current node. +// +// If the current reference points to the root of the database, Parent returns nil. func (r *Ref) Parent() *Ref { l := len(r.segs) if l > 0 { @@ -40,11 +46,18 @@ func (r *Ref) Parent() *Ref { return nil } +// Child returns a reference to the specified child node. func (r *Ref) Child(path string) *Ref { fp := fmt.Sprintf("%s/%s", r.Path, path) return r.client.NewRef(fp) } +// Get retrieves the value at the current database location, and stores it in the value pointed to +// by v. +// +// Data deserialization is performed using https://golang.org/pkg/encoding/json/#Unmarshal, and +// therefore v has the same requirements as the json package. Specifically, it must be a pointer, +// and it must not be nil. func (r *Ref) Get(v interface{}) error { resp, err := r.send("GET") if err != nil { @@ -53,6 +66,9 @@ func (r *Ref) Get(v interface{}) error { return resp.CheckAndParse(http.StatusOK, v) } +// WithContext returns a shallow copy of this Ref with its context changed to ctx. +// +// The resulting Ref will use ctx for all subsequent RPC calls. func (r *Ref) WithContext(ctx context.Context) *Ref { r2 := new(Ref) *r2 = *r @@ -60,6 +76,7 @@ func (r *Ref) WithContext(ctx context.Context) *Ref { return r2 } +// GetWithETag retrieves the value at the current database location, along with its ETag. func (r *Ref) GetWithETag(v interface{}) (string, error) { resp, err := r.send("GET", withHeader("X-Firebase-ETag", "true")) if err != nil { @@ -70,6 +87,13 @@ func (r *Ref) GetWithETag(v interface{}) (string, error) { return resp.Header.Get("Etag"), nil } +// GetIfChanged retrieves the value and ETag of the current database location only if the specified +// ETag does not match. +// +// If the specified ETag does not match, returns true along with the latest ETag of the database +// location. The value of the database location will be stored in v just like a regular Get() call. +// If the etag matches, returns false along with the same ETag passed into the function. No data +// will be stored in v in this case. func (r *Ref) GetIfChanged(etag string, v interface{}) (bool, string, error) { resp, err := r.send("GET", withHeader("If-None-Match", etag)) if err != nil { @@ -82,6 +106,11 @@ func (r *Ref) GetIfChanged(etag string, v interface{}) (bool, string, error) { return false, etag, nil } +// Set stores the value v in the current database node. +// +// Set uses https://golang.org/pkg/encoding/json/#Marshal to serialize values into JSON. Therefore +// v has the same requirements as the json package. Values like functions and channels cannot be +// saved into Realtime Database. func (r *Ref) Set(v interface{}) error { resp, err := r.sendWithBody("PUT", v, withQueryParam("print", "silent")) if err != nil { @@ -90,6 +119,10 @@ func (r *Ref) Set(v interface{}) error { return resp.CheckStatus(http.StatusNoContent) } +// SetIfUnchanged conditionally sets the data at this location to the given value. +// +// Sets the data at this location to v only if the specified ETag matches. Returns true if the +// value is written. Returns false if no changes are made to the database. func (r *Ref) SetIfUnchanged(etag string, v interface{}) (bool, error) { resp, err := r.sendWithBody("PUT", v, withHeader("If-Match", etag)) if err != nil { @@ -102,6 +135,10 @@ func (r *Ref) SetIfUnchanged(etag string, v interface{}) (bool, error) { return false, nil } +// Push creates a new child node at the current location, and returns a reference to it. +// +// If v is not nil, it will be set as the initial value of the new child node. If v is nil, the +// new child node will be created with empty string as the value. func (r *Ref) Push(v interface{}) (*Ref, error) { if v == nil { v = "" @@ -119,6 +156,7 @@ func (r *Ref) Push(v interface{}) (*Ref, error) { return r.Child(d.Name), nil } +// Update modifies the specified child keys of the current location to the provided values. func (r *Ref) Update(v map[string]interface{}) error { if len(v) == 0 { return fmt.Errorf("value argument must be a non-empty map") @@ -132,6 +170,20 @@ func (r *Ref) Update(v map[string]interface{}) error { type UpdateFn func(interface{}) (interface{}, error) +// Transaction atomically modifies the data at this location. +// +// Unlike a normal Set(), which just overwrites the data regardless of its previous state, +// Transaction() is used to modify the existing value to a new value, ensuring there are no +// conflicts with other clients simultaneously writing to the same location. +// +// This is accomplished by passing an update function which is used to transform the current value +// of this reference into a new value. If another client writes to this location before the new +// value is successfully saved, the update function is called again with the new current value, and +// the write will be retried. In case of repeated failures, this method will retry the transaction up +// to 25 times before giving up and returning an error. +// +// The update function may also force an early abort by returning an error instead of returning a +// value. func (r *Ref) Transaction(fn UpdateFn) error { var curr interface{} etag, err := r.GetWithETag(&curr) @@ -139,7 +191,7 @@ func (r *Ref) Transaction(fn UpdateFn) error { return err } - for i := 0; i < 20; i++ { + for i := 0; i < txnRetries; i++ { new, err := fn(curr) if err != nil { return err @@ -157,6 +209,7 @@ func (r *Ref) Transaction(fn UpdateFn) error { return fmt.Errorf("transaction aborted after failed retries") } +// Delete removes this node from the database. func (r *Ref) Delete() error { resp, err := r.send("DELETE") if err != nil { diff --git a/db/ref_test.go b/db/ref_test.go index 1d7cfe57..eed6f9fb 100644 --- a/db/ref_test.go +++ b/db/ref_test.go @@ -25,43 +25,73 @@ import ( type refOp func(r *Ref) error -var testOps = []refOp{ - func(r *Ref) error { - var got string - return r.Get(&got) +var testOps = []struct { + name string + op refOp +}{ + { + "Get()", + func(r *Ref) error { + var got string + return r.Get(&got) + }, }, - func(r *Ref) error { - var got string - _, err := r.GetWithETag(&got) - return err + { + "GetWithETag()", + func(r *Ref) error { + var got string + _, err := r.GetWithETag(&got) + return err + }, }, - func(r *Ref) error { - var got string - _, _, err := r.GetIfChanged("etag", &got) - return err + { + "GetIfChanged()", + func(r *Ref) error { + var got string + _, _, err := r.GetIfChanged("etag", &got) + return err + }, }, - func(r *Ref) error { - return r.Set("foo") + { + "Set()", + func(r *Ref) error { + return r.Set("foo") + }, }, - func(r *Ref) error { - _, err := r.SetIfUnchanged("etag", "foo") - return err + { + "SetIfUnchanged()", + func(r *Ref) error { + _, err := r.SetIfUnchanged("etag", "foo") + return err + }, }, - func(r *Ref) error { - _, err := r.Push("foo") - return err + { + "Push()", + func(r *Ref) error { + _, err := r.Push("foo") + return err + }, }, - func(r *Ref) error { - return r.Update(map[string]interface{}{"foo": "bar"}) + { + "Update()", + func(r *Ref) error { + return r.Update(map[string]interface{}{"foo": "bar"}) + }, }, - func(r *Ref) error { - return r.Delete() + { + "Delete()", + func(r *Ref) error { + return r.Delete() + }, }, - func(r *Ref) error { - fn := func(v interface{}) (interface{}, error) { - return v, nil - } - return r.Transaction(fn) + { + "Transaction()", + func(r *Ref) error { + fn := func(v interface{}) (interface{}, error) { + return v, nil + } + return r.Transaction(fn) + }, }, } @@ -99,19 +129,27 @@ func TestRefWithContext(t *testing.T) { } func TestGet(t *testing.T) { - want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} - mock := &mockServer{Resp: want} + mock := &mockServer{} srv := mock.Start(client) defer srv.Close() - var got map[string]interface{} - if err := testref.Get(&got); err != nil { - t.Fatal(err) + cases := []interface{}{ + nil, float64(1), true, "foo", + map[string]interface{}{"name": "Peter Parker", "age": float64(17)}, } - if !reflect.DeepEqual(want, got) { - t.Errorf("Get() = %v; want = %v", got, want) + var want []*testReq + for _, tc := range cases { + mock.Resp = tc + var got interface{} + if err := testref.Get(&got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tc, got) { + t.Errorf("Get() = %v; want = %v", got, tc) + } + want = append(want, &testReq{Method: "GET", Path: "/peter.json"}) } - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) + checkAllRequests(t, mock.Reqs, want) } func TestInvalidGet(t *testing.T) { @@ -232,9 +270,9 @@ func TestWerlformedHttpError(t *testing.T) { want := "http error status: 500; reason: test error" for _, tc := range testOps { - err := tc(testref) + err := tc.op(testref) if err == nil || err.Error() != want { - t.Errorf("Get() = %v; want = %v", err, want) + t.Errorf("%s = %v; want = %v", tc.name, err, want) } } @@ -250,9 +288,9 @@ func TestUnexpectedHttpError(t *testing.T) { want := "http error status: 500; message: \"unexpected error\"" for _, tc := range testOps { - err := tc(testref) + err := tc.op(testref) if err == nil || err.Error() != want { - t.Errorf("Get() = %v; want = %v", err, want) + t.Errorf("%s = %v; want = %v", tc.name, err, want) } } @@ -271,10 +309,10 @@ func TestInvalidPath(t *testing.T) { } for _, tc := range cases { r := client.NewRef(tc) - for _, op := range testOps { - err := op(r) + for _, o := range testOps { + err := o.op(r) if err == nil { - t.Errorf("Get() = nil; want = error") + t.Errorf("%s = nil; want = error", o.name) } } } @@ -294,10 +332,10 @@ func TestInvalidChildPath(t *testing.T) { } for _, tc := range cases { r := testref.Child(tc) - for _, op := range testOps { - err := op(r) + for _, o := range testOps { + err := o.op(r) if err == nil { - t.Errorf("Get() = nil; want = error") + t.Errorf("%s = nil; want = error", o.name) } } } @@ -635,7 +673,7 @@ func TestTransactionAbort(t *testing.T) { Header: http.Header{"X-Firebase-ETag": []string{"true"}}, }, } - for i := 0; i < 20; i++ { + for i := 0; i < txnRetries; i++ { wanted = append(wanted, &testReq{ Method: "PUT", Path: "/peter.json", diff --git a/firebase.go b/firebase.go index b32a69a5..6df1f82c 100644 --- a/firebase.go +++ b/firebase.go @@ -46,7 +46,7 @@ type App struct { // Config represents the configuration used to initialize an App. type Config struct { - AuthOverrides *db.AuthOverrides + AuthOverride *db.AuthOverride DatabaseURL string ProjectID string StorageBucket string @@ -110,8 +110,8 @@ func NewApp(ctx context.Context, config *Config, opts ...option.ClientOption) (* } ao := make(map[string]interface{}) - if config.AuthOverrides != nil { - ao = config.AuthOverrides.Map + if config.AuthOverride != nil { + ao = config.AuthOverride.Map } return &App{ diff --git a/integration/db/db_test.go b/integration/db/db_test.go index f3b140f2..d2c50347 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -89,7 +89,7 @@ func initOverrideClient(pid string) (*db.Client, error) { ctx := context.Background() app, err := internal.NewTestApp(ctx, &firebase.Config{ DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), - AuthOverrides: &db.AuthOverrides{ + AuthOverride: &db.AuthOverride{ Map: map[string]interface{}{"uid": "user1"}, }, }) @@ -103,8 +103,8 @@ func initOverrideClient(pid string) (*db.Client, error) { func initGuestClient(pid string) (*db.Client, error) { ctx := context.Background() app, err := internal.NewTestApp(ctx, &firebase.Config{ - DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), - AuthOverrides: &db.AuthOverrides{}, + DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), + AuthOverride: &db.AuthOverride{}, }) if err != nil { return nil, err From 63a5ace92f9445417c3aed7bb0d7e60c7f6c572a Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Fri, 27 Oct 2017 22:02:36 -0700 Subject: [PATCH 24/58] Fixing a build break --- firebase_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firebase_test.go b/firebase_test.go index dc028a1a..f80eae35 100644 --- a/firebase_test.go +++ b/firebase_test.go @@ -244,8 +244,8 @@ func TestDatabaseAuthOverrides(t *testing.T) { for _, tc := range cases { ctx := context.Background() conf := &Config{ - AuthOverrides: &db.AuthOverrides{tc}, - DatabaseURL: "https://mock-db.firebaseio.com", + AuthOverride: &db.AuthOverride{tc}, + DatabaseURL: "https://mock-db.firebaseio.com", } app, err := NewApp(ctx, conf, option.WithCredentialsFile("testdata/service_account.json")) if err != nil { From 835adaa4065eab3ace685564aa8db22ea5af05cb Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Fri, 27 Oct 2017 22:16:37 -0700 Subject: [PATCH 25/58] Finishing up documentation --- db/query.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/db/query.go b/db/query.go index 585367d0..746c8d7c 100644 --- a/db/query.go +++ b/db/query.go @@ -24,6 +24,14 @@ import ( "golang.org/x/net/context" ) +// Query represents a complex query that can be executed on a Ref. +// +// Complex queries can consist of up to 2 components: a required ordering constraint, and an +// optional filtering constraint. At the server, data is first sorted according to the given +// ordering constraint (e.g. order by child). Then the filtering constraint (e.g. limit, range) is +// applied on the sorted data to produce the final result. Despite the ordering constraint, the +// final result is returned by the server as an unordered collection. Therefore the values read +// from a Query instance are not ordered. type Query struct { ctx context.Context client *Client @@ -33,6 +41,9 @@ type Query struct { start, end, equalTo interface{} } +// WithStartAt returns a shallow copy of the Query with v set as a lower bound of a range query. +// +// The resulting Query will only return child nodes with a value greater than or equal to v. func (q *Query) WithStartAt(v interface{}) *Query { q2 := new(Query) *q2 = *q @@ -40,6 +51,9 @@ func (q *Query) WithStartAt(v interface{}) *Query { return q2 } +// WithEndAt returns a shallow copy of the Query with v set as a upper bound of a range query. +// +// The resulting Query will only return child nodes with a value less than or equal to v. func (q *Query) WithEndAt(v interface{}) *Query { q2 := new(Query) *q2 = *q @@ -47,6 +61,9 @@ func (q *Query) WithEndAt(v interface{}) *Query { return q2 } +// WithEqualTo returns a shallow copy of the Query with v set as an equals constraint. +// +// The resulting Query will only return child nodes whose values equal to v. func (q *Query) WithEqualTo(v interface{}) *Query { q2 := new(Query) *q2 = *q @@ -54,20 +71,27 @@ func (q *Query) WithEqualTo(v interface{}) *Query { return q2 } -func (q *Query) WithLimitToFirst(lim int) *Query { +// WithLimitToFirst returns a shallow copy of the Query, which is anchored to the first n +// elements of the window. +func (q *Query) WithLimitToFirst(n int) *Query { q2 := new(Query) *q2 = *q - q2.limFirst = lim + q2.limFirst = n return q2 } -func (q *Query) WithLimitToLast(lim int) *Query { +// WithLimitToLast returns a shallow copy of the Query, which is anchored to the last n +// elements of the window. +func (q *Query) WithLimitToLast(n int) *Query { q2 := new(Query) *q2 = *q - q2.limLast = lim + q2.limLast = n return q2 } +// WithContext returns a shallow copy of this Query with its context changed to ctx. +// +// The resulting Query will use ctx for all subsequent RPC calls. func (q *Query) WithContext(ctx context.Context) *Query { q2 := new(Query) *q2 = *q @@ -75,6 +99,9 @@ func (q *Query) WithContext(ctx context.Context) *Query { return q2 } +// Get executes the Query and populates v with the results. +// +// Results will not be stored in any particular order in v. func (q *Query) Get(v interface{}) error { qp := make(map[string]string) ob, err := q.ob.encode() @@ -119,14 +146,29 @@ func (q *Query) Get(v interface{}) error { return resp.CheckAndParse(http.StatusOK, v) } +// OrderByChild returns a Query that orders data by child values before applying filters. +// +// Returned Query can be used to set additional parameters, and execute complex database queries +// (e.g. limit queries, range queries). If r has a context associated with it, the resulting Query +// will inherit it. func (r *Ref) OrderByChild(child string) *Query { return newQuery(r, orderByChild(child)) } +// OrderByKey returns a Query that orders data by key before applying filters. +// +// Returned Query can be used to set additional parameters, and execute complex database queries +// (e.g. limit queries, range queries). If r has a context associated with it, the resulting Query +// will inherit it. func (r *Ref) OrderByKey() *Query { return newQuery(r, orderByProperty("$key")) } +// OrderByValue returns a Query that orders data by value before applying filters. +// +// Returned Query can be used to set additional parameters, and execute complex database queries +// (e.g. limit queries, range queries). If r has a context associated with it, the resulting Query +// will inherit it. func (r *Ref) OrderByValue() *Query { return newQuery(r, orderByProperty("$value")) } From 57263adf552b0e0ca33ff53f31b50acea1fe90d1 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Fri, 27 Oct 2017 23:56:36 -0700 Subject: [PATCH 26/58] More test cases --- db/query_test.go | 3 +- db/ref_test.go | 71 +++++++++++++++++++++------------------ integration/db/db_test.go | 16 ++++----- 3 files changed, 46 insertions(+), 44 deletions(-) diff --git a/db/query_test.go b/db/query_test.go index a2f17d8b..c3b58598 100644 --- a/db/query_test.go +++ b/db/query_test.go @@ -14,9 +14,10 @@ package db import ( - "context" "reflect" "testing" + + "golang.org/x/net/context" ) func TestQueryWithContext(t *testing.T) { diff --git a/db/ref_test.go b/db/ref_test.go index eed6f9fb..b8f6949b 100644 --- a/db/ref_test.go +++ b/db/ref_test.go @@ -27,10 +27,12 @@ type refOp func(r *Ref) error var testOps = []struct { name string + resp interface{} op refOp }{ { "Get()", + "test", func(r *Ref) error { var got string return r.Get(&got) @@ -38,6 +40,7 @@ var testOps = []struct { }, { "GetWithETag()", + "test", func(r *Ref) error { var got string _, err := r.GetWithETag(&got) @@ -46,6 +49,7 @@ var testOps = []struct { }, { "GetIfChanged()", + "test", func(r *Ref) error { var got string _, _, err := r.GetIfChanged("etag", &got) @@ -54,12 +58,14 @@ var testOps = []struct { }, { "Set()", + nil, func(r *Ref) error { return r.Set("foo") }, }, { "SetIfUnchanged()", + nil, func(r *Ref) error { _, err := r.SetIfUnchanged("etag", "foo") return err @@ -67,6 +73,7 @@ var testOps = []struct { }, { "Push()", + map[string]interface{}{"name": "test"}, func(r *Ref) error { _, err := r.Push("foo") return err @@ -74,18 +81,21 @@ var testOps = []struct { }, { "Update()", + nil, func(r *Ref) error { return r.Update(map[string]interface{}{"foo": "bar"}) }, }, { "Delete()", + nil, func(r *Ref) error { return r.Delete() }, }, { "Transaction()", + nil, func(r *Ref) error { fn := func(v interface{}) (interface{}, error) { return v, nil @@ -95,39 +105,6 @@ var testOps = []struct { }, } -func TestRefWithContext(t *testing.T) { - r := client.NewRef("peter") - if r.ctx != nil { - t.Errorf("Ctx = %v; want nil", r.ctx) - } - - ctx, cancel := context.WithCancel(context.Background()) - r = r.WithContext(ctx) - if r.ctx != ctx { - t.Errorf("WithContext().Ctx = %v; want = %v", r.ctx, ctx) - } - - want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} - mock := &mockServer{Resp: want} - srv := mock.Start(client) - defer srv.Close() - - var got map[string]interface{} - if err := r.Get(&got); err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(want, got) { - t.Errorf("WithContext().Get() = %v; want = %v", got, want) - } - checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) - - cancel() - got = nil - if err := r.Get(&got); len(got) != 0 || err == nil { - t.Errorf("WithContext().Get() = (%v, %v); want = (empty, error)", got, err) - } -} - func TestGet(t *testing.T) { mock := &mockServer{} srv := mock.Start(client) @@ -700,3 +677,31 @@ func TestDelete(t *testing.T) { Path: "/peter.json", }) } + +func TestWithContext(t *testing.T) { + if testref.ctx != nil { + t.Errorf("Ctx = %v; want nil", testref.ctx) + } + + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + for _, tc := range testOps { + mock.Resp = tc.resp + ctx, cancel := context.WithCancel(context.Background()) + r := testref.WithContext(ctx) + if r.ctx != ctx { + t.Errorf("WithContext().ctx = %v; want = %v", r.ctx, ctx) + } + if err := tc.op(r); err != nil { + t.Errorf("%s %v", tc.name, err) + t.Fatal(err) + } + + cancel() + if err := tc.op(r); err == nil { + t.Errorf("WithContext().%s = nil; want = error", tc.name) + } + } +} diff --git a/integration/db/db_test.go b/integration/db/db_test.go index d2c50347..d4f0627f 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -1,24 +1,20 @@ package db import ( - "context" + "bytes" + "encoding/json" "flag" "fmt" + "io/ioutil" "log" "net/http" "os" - "testing" - - "firebase.google.com/go" - - "io/ioutil" - - "encoding/json" - "reflect" + "testing" - "bytes" + "golang.org/x/net/context" + "firebase.google.com/go" "firebase.google.com/go/db" "firebase.google.com/go/integration/internal" ) From e557b9af53292ba71037e178c980a57897f54ad6 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 30 Oct 2017 13:08:46 -0700 Subject: [PATCH 27/58] Implemented a reusable HTTP client API --- internal/http_client.go | 138 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 internal/http_client.go diff --git a/internal/http_client.go b/internal/http_client.go new file mode 100644 index 00000000..b03b83df --- /dev/null +++ b/internal/http_client.go @@ -0,0 +1,138 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package internal contains functionality that is only accessible from within the Admin SDK. +package internal + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" +) + +type Request struct { + Method string + URL string + Body interface{} + Opts []HTTPOption +} + +func (r *Request) Send(ctx context.Context, hc *http.Client) (*Response, error) { + req, err := r.newHTTPRequest() + if err != nil { + return nil, err + } + + resp, err := hc.Do(req.WithContext(ctx)) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + return &Response{ + Status: resp.StatusCode, + Body: b, + Header: resp.Header, + }, nil +} + +func (r *Request) newHTTPRequest() (*http.Request, error) { + var opts []HTTPOption + var data io.Reader + if r.Body != nil { + b, err := json.Marshal(r.Body) + if err != nil { + return nil, err + } + data = bytes.NewBuffer(b) + opts = append(opts, WithHeader("Content-Type", "application/json")) + } + + req, err := http.NewRequest(r.Method, r.URL, data) + if err != nil { + return nil, err + } + + opts = append(opts, r.Opts...) + for _, o := range opts { + o(req) + } + return req, nil +} + +type Response struct { + Status int + Header http.Header + Body []byte +} + +func (r *Response) CheckStatus(want int, ep ErrorParser) error { + if r.Status == want { + return nil + } + + var msg string + if ep != nil { + msg = ep(r) + } + if msg == "" { + msg = string(r.Body) + } + return fmt.Errorf("http error status: %d; reason: %s", r.Status, msg) +} + +func (r *Response) Unmarshal(want int, ep ErrorParser, v interface{}) error { + if err := r.CheckStatus(want, ep); err != nil { + return err + } else if err := json.Unmarshal(r.Body, v); err != nil { + return err + } + return nil +} + +type ErrorParser func(r *Response) string + +type HTTPOption func(*http.Request) + +func WithHeader(key, value string) HTTPOption { + return func(r *http.Request) { + r.Header.Set(key, value) + } +} + +func WithQueryParam(key, value string) HTTPOption { + return func(r *http.Request) { + q := r.URL.Query() + q.Add(key, value) + r.URL.RawQuery = q.Encode() + } +} + +func WithQueryParams(qp map[string]string) HTTPOption { + return func(r *http.Request) { + q := r.URL.Query() + for k, v := range qp { + q.Add(k, v) + } + r.URL.RawQuery = q.Encode() + } +} From f69b5adaf4d3a67968081ce4aee49132d9fb34b8 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 30 Oct 2017 15:32:24 -0700 Subject: [PATCH 28/58] Added test cases --- internal/http_client.go | 34 ++++- internal/http_client_test.go | 259 +++++++++++++++++++++++++++++++++++ 2 files changed, 292 insertions(+), 1 deletion(-) create mode 100644 internal/http_client_test.go diff --git a/internal/http_client.go b/internal/http_client.go index b03b83df..c4ac1c9d 100644 --- a/internal/http_client.go +++ b/internal/http_client.go @@ -25,6 +25,12 @@ import ( "net/http" ) +// Null represents JSON null value. +var Null struct{} = jsonNull{} + +type jsonNull struct{} + +// Request contains all the parameters required to construct an outgoing HTTP request. type Request struct { Method string URL string @@ -32,6 +38,10 @@ type Request struct { Opts []HTTPOption } +// Send executes the current Request using the given context and HTTP client. +// +// If the Body is not nil, it is serialized into a JSON string. To send JSON null as the body, use +// the internal.Null variable. func (r *Request) Send(ctx context.Context, hc *http.Client) (*Response, error) { req, err := r.newHTTPRequest() if err != nil { @@ -59,7 +69,13 @@ func (r *Request) newHTTPRequest() (*http.Request, error) { var opts []HTTPOption var data io.Reader if r.Body != nil { - b, err := json.Marshal(r.Body) + var body interface{} + if r.Body == Null { + body = nil + } else { + body = r.Body + } + b, err := json.Marshal(body) if err != nil { return nil, err } @@ -79,12 +95,17 @@ func (r *Request) newHTTPRequest() (*http.Request, error) { return req, nil } +// Response contains information extracted from an HTTP response. type Response struct { Status int Header http.Header Body []byte } +// CheckStatus checks whether the Response status code has the given HTTP status code. +// +// Returns an error if the status code does not match. If an ErroParser is specified, uses that to +// construct the returned error message. Otherwise includes the full response body in the error. func (r *Response) CheckStatus(want int, ep ErrorParser) error { if r.Status == want { return nil @@ -100,6 +121,11 @@ func (r *Response) CheckStatus(want int, ep ErrorParser) error { return fmt.Errorf("http error status: %d; reason: %s", r.Status, msg) } +// Unmarshal checks if the Response has the given HTTP status code, and if so unmarshals the +// response body into the variable pointed by v. +// +// Unmarshal uses https://golang.org/pkg/encoding/json/#Unmarshal internally, and hence v has the +// same requirements as the json package. func (r *Response) Unmarshal(want int, ep ErrorParser, v interface{}) error { if err := r.CheckStatus(want, ep); err != nil { return err @@ -109,16 +135,20 @@ func (r *Response) Unmarshal(want int, ep ErrorParser, v interface{}) error { return nil } +// ErrorParser is a function that is used to construct custom error messages. type ErrorParser func(r *Response) string +// HTTPOption is an additional parameter that can be specified to customize an outgoing request. type HTTPOption func(*http.Request) +// WithHeader creates an HTTPOption that will set an HTTP header on the request. func WithHeader(key, value string) HTTPOption { return func(r *http.Request) { r.Header.Set(key, value) } } +// WithQueryParam creates an HTTPOption that will set a query parameter on the request. func WithQueryParam(key, value string) HTTPOption { return func(r *http.Request) { q := r.URL.Query() @@ -127,6 +157,8 @@ func WithQueryParam(key, value string) HTTPOption { } } +// WithQueryParams creates an HTTPOption that will set all the entries of qp as query parameters +// on the request. func WithQueryParams(qp map[string]string) HTTPOption { return func(r *http.Request) { q := r.URL.Query() diff --git a/internal/http_client_test.go b/internal/http_client_test.go new file mode 100644 index 00000000..c7ddda74 --- /dev/null +++ b/internal/http_client_test.go @@ -0,0 +1,259 @@ +package internal + +import ( + "context" + "encoding/json" + "io/ioutil" + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +var cases = []struct { + req *Request + method string + body interface{} + headers map[string]string + query map[string]string +}{ + { + req: &Request{ + Method: "GET", + }, + method: "GET", + }, + { + req: &Request{ + Method: "GET", + Opts: []HTTPOption{ + WithHeader("Test-Header", "value1"), + WithQueryParam("testParam", "value2"), + }, + }, + method: "GET", + headers: map[string]string{"Test-Header": "value1"}, + query: map[string]string{"testParam": "value2"}, + }, + { + req: &Request{ + Method: "POST", + Body: map[string]string{"foo": "bar"}, + Opts: []HTTPOption{ + WithHeader("Test-Header", "value1"), + WithQueryParam("testParam1", "value2"), + WithQueryParam("testParam2", "value3"), + }, + }, + method: "POST", + body: map[string]string{"foo": "bar"}, + headers: map[string]string{"Test-Header": "value1"}, + query: map[string]string{"testParam1": "value2", "testParam2": "value3"}, + }, + { + req: &Request{ + Method: "POST", + Body: "body", + Opts: []HTTPOption{ + WithHeader("Test-Header", "value1"), + WithQueryParams(map[string]string{"testParam1": "value2", "testParam2": "value3"}), + }, + }, + method: "POST", + body: "body", + headers: map[string]string{"Test-Header": "value1"}, + query: map[string]string{"testParam1": "value2", "testParam2": "value3"}, + }, + { + req: &Request{ + Method: "PUT", + Body: Null, + Opts: []HTTPOption{ + WithHeader("Test-Header", "value1"), + WithQueryParams(map[string]string{"testParam1": "value2", "testParam2": "value3"}), + }, + }, + method: "PUT", + body: Null, + headers: map[string]string{"Test-Header": "value1"}, + query: map[string]string{"testParam1": "value2", "testParam2": "value3"}, + }, +} + +func TestSend(t *testing.T) { + want := map[string]interface{}{ + "key1": "value1", + "key2": float64(100), + } + b, err := json.Marshal(want) + if err != nil { + t.Fatal(err) + } + + idx := 0 + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + want := cases[idx] + if r.Method != want.method { + t.Errorf("[%d] Method = %q; want = %q", idx, r.Method, want.method) + } + for k, v := range want.headers { + h := r.Header.Get(k) + if h != v { + t.Errorf("[%d] Header(%q) = %q; want = %q", idx, k, h, v) + } + } + if want.query == nil { + if r.URL.Query().Encode() != "" { + t.Errorf("[%d] Query = %v; want = empty", idx, r.URL.Query().Encode()) + } + } + for k, v := range want.query { + q := r.URL.Query().Get(k) + if q != v { + t.Errorf("[%d] Query(%q) = %q; want = %q", idx, k, q, v) + } + } + if want.body != nil { + h := r.Header.Get("Content-Type") + if h != "application/json" { + t.Errorf("[%d] Content-Type = %q; want = %q", idx, h, "application/json") + } + + var wb []byte + if want.body == Null { + wb = []byte("null") + } else { + wb, err = json.Marshal(want.body) + if err != nil { + t.Fatal(err) + } + } + gb, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(wb, gb) { + t.Errorf("[%d] Body = %q; want = %q", idx, string(gb), string(wb)) + } + } + + idx++ + w.Header().Set("Content-Type", "application/json") + w.Write(b) + }) + server := httptest.NewServer(handler) + defer server.Close() + + for _, tc := range cases { + tc.req.URL = server.URL + resp, err := tc.req.Send(context.Background(), http.DefaultClient) + if err != nil { + t.Fatal(err) + } + if err := resp.CheckStatus(http.StatusOK, nil); err != nil { + t.Errorf("CheckStatus() = %v; want nil", err) + } + if err := resp.CheckStatus(http.StatusCreated, nil); err == nil { + t.Errorf("CheckStatus() = nil; want error") + } + + var got map[string]interface{} + if err := resp.Unmarshal(http.StatusOK, nil, &got); err != nil { + t.Errorf("Unmarshal() = %v; want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("Body = %v; want = %v", got, want) + } + } +} + +func TestErrorParser(t *testing.T) { + data := map[string]interface{}{ + "error": "test error", + } + b, err := json.Marshal(data) + if err != nil { + t.Fatal(err) + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Header().Set("Content-Type", "application/json") + w.Write(b) + }) + server := httptest.NewServer(handler) + defer server.Close() + + req := &Request{ + Method: "GET", + URL: server.URL, + } + resp, err := req.Send(context.Background(), http.DefaultClient) + if err != nil { + t.Fatal(err) + } + + ep := func(r *Response) string { + var b struct { + Error string `json:"error"` + } + if err := json.Unmarshal(r.Body, &b); err != nil { + return "" + } + return b.Error + } + + want := "http error status: 500; reason: test error" + if err := resp.CheckStatus(http.StatusOK, ep); err.Error() != want { + t.Errorf("CheckStatus() = %q; want = %q", err.Error(), want) + } + var got map[string]interface{} + if err := resp.Unmarshal(http.StatusOK, ep, &got); err.Error() != want { + t.Errorf("CheckStatus() = %q; want = %q", err.Error(), want) + } + if got != nil { + t.Errorf("Body = %v; want = nil", got) + } +} + +func TestInvalidURL(t *testing.T) { + req := &Request{ + Method: "GET", + URL: "http://localhost:250/mock.url", + } + _, err := req.Send(context.Background(), http.DefaultClient) + if err == nil { + t.Errorf("Send() = nil; want error") + } +} + +func TestUnmarshalError(t *testing.T) { + data := map[string]interface{}{ + "foo": "bar", + } + b, err := json.Marshal(data) + if err != nil { + t.Fatal(err) + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(b) + }) + server := httptest.NewServer(handler) + defer server.Close() + + req := &Request{ + Method: "GET", + URL: server.URL, + } + resp, err := req.Send(context.Background(), http.DefaultClient) + if err != nil { + t.Fatal(err) + } + + var got func() + if err := resp.Unmarshal(http.StatusOK, nil, &got); err == nil { + t.Errorf("Unmarshal() = nil; want error") + } +} From d5a5fae43eb6770728df581f2bc1a04869f964f5 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 30 Oct 2017 15:42:03 -0700 Subject: [PATCH 29/58] Comment clean up --- internal/http_client.go | 1 - internal/http_client_test.go | 13 +++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/internal/http_client.go b/internal/http_client.go index c4ac1c9d..d19316f8 100644 --- a/internal/http_client.go +++ b/internal/http_client.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package internal contains functionality that is only accessible from within the Admin SDK. package internal import ( diff --git a/internal/http_client_test.go b/internal/http_client_test.go index c7ddda74..66019573 100644 --- a/internal/http_client_test.go +++ b/internal/http_client_test.go @@ -1,3 +1,16 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. package internal import ( From 73bfd8f0122027c8c2a1d82789da6b75e02ace49 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 30 Oct 2017 16:52:53 -0700 Subject: [PATCH 30/58] Using the shared http client API --- db/auth_override_test.go | 9 +-- db/db.go | 29 ++++++++ db/db_test.go | 3 - db/http_client.go | 151 --------------------------------------- db/query.go | 27 +++---- db/query_test.go | 135 ++++------------------------------ db/ref.go | 95 ++++++++++++------------ db/ref_test.go | 86 ++++++++-------------- 8 files changed, 129 insertions(+), 406 deletions(-) delete mode 100644 db/http_client.go diff --git a/db/auth_override_test.go b/db/auth_override_test.go index 7b78b02c..c696acd8 100644 --- a/db/auth_override_test.go +++ b/db/auth_override_test.go @@ -15,6 +15,7 @@ package db import ( + "context" "testing" ) @@ -25,7 +26,7 @@ func TestAuthOverrideGet(t *testing.T) { ref := aoClient.NewRef("peter") var got string - if err := ref.Get(&got); err != nil { + if err := ref.Get(context.Background(), &got); err != nil { t.Fatal(err) } if got != "data" { @@ -45,7 +46,7 @@ func TestAuthOverrideSet(t *testing.T) { ref := aoClient.NewRef("peter") want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} - if err := ref.Set(want); err != nil { + if err := ref.Set(context.Background(), want); err != nil { t.Fatal(err) } checkOnlyRequest(t, mock.Reqs, &testReq{ @@ -63,7 +64,7 @@ func TestAuthOverrideQuery(t *testing.T) { ref := aoClient.NewRef("peter") var got string - if err := ref.OrderByChild("foo").Get(&got); err != nil { + if err := ref.OrderByChild("foo").Get(context.Background(), &got); err != nil { t.Fatal(err) } if got != "data" { @@ -86,7 +87,7 @@ func TestAuthOverrideRangeQuery(t *testing.T) { ref := aoClient.NewRef("peter") var got string - if err := ref.OrderByChild("foo").WithStartAt(1).WithEndAt(10).Get(&got); err != nil { + if err := ref.OrderByChild("foo").WithStartAt(1).WithEndAt(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if got != "data" { diff --git a/db/db.go b/db/db.go index 03bf7eeb..12c26230 100644 --- a/db/db.go +++ b/db/db.go @@ -32,6 +32,18 @@ import ( ) const userAgentFormat = "Firebase/HTTP/%s/%s/AdminGo" +const invalidChars = "[].#$" +const authVarOverride = "auth_variable_override" + +var errParser = func(r *internal.Response) string { + var b struct { + Error string `json:"error"` + } + if err := json.Unmarshal(r.Body, &b); err != nil { + return "" + } + return b.Error +} // Client is the interface for the Firebase Realtime Database service. type Client struct { @@ -107,6 +119,23 @@ func (c *Client) NewRef(path string) *Ref { } } +func (c *Client) newHTTPRequest(method, path string, body interface{}, opts ...internal.HTTPOption) (*internal.Request, error) { + if strings.ContainsAny(path, invalidChars) { + return nil, fmt.Errorf("invalid path with illegal characters: %q", path) + } + + if c.ao != "" { + opts = append(opts, internal.WithQueryParam(authVarOverride, c.ao)) + } + url := fmt.Sprintf("%s%s.json", c.url, path) + return &internal.Request{ + Method: method, + URL: url, + Body: body, + Opts: opts, + }, nil +} + func parsePath(path string) []string { var segs []string for _, s := range strings.Split(path, "/") { diff --git a/db/db_test.go b/db/db_test.go index 62226074..e48e80ed 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -182,9 +182,6 @@ func TestNewRef(t *testing.T) { if r.client == nil { t.Errorf("NewRef(%q).client = nil; want = %v", tc.Path, r.client) } - if r.ctx != nil { - t.Errorf("NewRef(%q).ctx = %v; want nil", tc.Path, r.ctx) - } if r.Path != tc.WantPath { t.Errorf("NewRef(%q).Path = %q; want = %q", tc.Path, r.Path, tc.WantPath) } diff --git a/db/http_client.go b/db/http_client.go deleted file mode 100644 index 95999827..00000000 --- a/db/http_client.go +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright 2017 Google Inc. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package db - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "io/ioutil" - "net/http" - "strings" - - "golang.org/x/net/context" -) - -const invalidChars = "[].#$" -const authVarOverride = "auth_variable_override" - -type request struct { - Method string - Path string - Body interface{} - Opts []httpOption -} - -func (c *Client) send(ctx context.Context, r *request) (*response, error) { - if strings.ContainsAny(r.Path, invalidChars) { - return nil, fmt.Errorf("invalid path with illegal characters: %q", r.Path) - } - - var opts []httpOption - var data io.Reader - if r.Body != nil { - b, err := json.Marshal(r.Body) - if err != nil { - return nil, err - } - data = bytes.NewBuffer(b) - opts = append(opts, withHeader("Content-Type", "application/json")) - } - - url := fmt.Sprintf("%s%s.json", c.url, r.Path) - req, err := http.NewRequest(r.Method, url, data) - if err != nil { - return nil, err - } - - if ctx != nil { - req = req.WithContext(ctx) - } - - if c.ao != "" { - opts = append(opts, withQueryParam(authVarOverride, c.ao)) - } - opts = append(opts, r.Opts...) - - return doSend(c.hc, req, opts...) -} - -func doSend(hc *http.Client, req *http.Request, opts ...httpOption) (*response, error) { - for _, o := range opts { - o(req) - } - - resp, err := hc.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, err - } - return &response{ - Status: resp.StatusCode, - Body: b, - Header: resp.Header, - }, nil -} - -type response struct { - Status int - Header http.Header - Body []byte -} - -func (r *response) CheckStatus(want int) error { - if r.Status == want { - return nil - } - var b struct { - Error string `json:"error"` - } - json.Unmarshal(r.Body, &b) - var msg string - if b.Error != "" { - msg = fmt.Sprintf("http error status: %d; reason: %s", r.Status, b.Error) - } else { - msg = fmt.Sprintf("http error status: %d; message: %s", r.Status, string(r.Body)) - } - return fmt.Errorf(msg) -} - -func (r *response) CheckAndParse(want int, v interface{}) error { - if err := r.CheckStatus(want); err != nil { - return err - } else if err := json.Unmarshal(r.Body, v); err != nil { - return err - } - return nil -} - -type httpOption func(*http.Request) - -func withHeader(key, value string) httpOption { - return func(r *http.Request) { - r.Header.Set(key, value) - } -} - -func withQueryParam(key, value string) httpOption { - return func(r *http.Request) { - q := r.URL.Query() - q.Add(key, value) - r.URL.RawQuery = q.Encode() - } -} - -func withQueryParams(qp map[string]string) httpOption { - return func(r *http.Request) { - q := r.URL.Query() - for k, v := range qp { - q.Add(k, v) - } - r.URL.RawQuery = q.Encode() - } -} diff --git a/db/query.go b/db/query.go index 746c8d7c..2121d8df 100644 --- a/db/query.go +++ b/db/query.go @@ -21,6 +21,8 @@ import ( "strconv" "strings" + "firebase.google.com/go/internal" + "golang.org/x/net/context" ) @@ -33,7 +35,6 @@ import ( // final result is returned by the server as an unordered collection. Therefore the values read // from a Query instance are not ordered. type Query struct { - ctx context.Context client *Client path string ob orderBy @@ -89,20 +90,10 @@ func (q *Query) WithLimitToLast(n int) *Query { return q2 } -// WithContext returns a shallow copy of this Query with its context changed to ctx. -// -// The resulting Query will use ctx for all subsequent RPC calls. -func (q *Query) WithContext(ctx context.Context) *Query { - q2 := new(Query) - *q2 = *q - q2.ctx = ctx - return q2 -} - // Get executes the Query and populates v with the results. // // Results will not be stored in any particular order in v. -func (q *Query) Get(v interface{}) error { +func (q *Query) Get(ctx context.Context, v interface{}) error { qp := make(map[string]string) ob, err := q.ob.encode() if err != nil { @@ -134,16 +125,15 @@ func (q *Query) Get(v interface{}) error { return err } - req := &request{ - Method: "GET", - Path: q.path, - Opts: []httpOption{withQueryParams(qp)}, + req, err := q.client.newHTTPRequest("GET", q.path, nil, internal.WithQueryParams(qp)) + if err != nil { + return err } - resp, err := q.client.send(q.ctx, req) + resp, err := req.Send(ctx, q.client.hc) if err != nil { return err } - return resp.CheckAndParse(http.StatusOK, v) + return resp.Unmarshal(http.StatusOK, errParser, v) } // OrderByChild returns a Query that orders data by child values before applying filters. @@ -175,7 +165,6 @@ func (r *Ref) OrderByValue() *Query { func newQuery(r *Ref, ob orderBy) *Query { return &Query{ - ctx: r.ctx, client: r.client, path: r.Path, ob: ob, diff --git a/db/query_test.go b/db/query_test.go index c3b58598..a67ff472 100644 --- a/db/query_test.go +++ b/db/query_test.go @@ -20,113 +20,6 @@ import ( "golang.org/x/net/context" ) -func TestQueryWithContext(t *testing.T) { - q := client.NewRef("peter").OrderByChild("messages") - if q.ctx != nil { - t.Errorf("query = %v; want nil", q.ctx) - } - - ctx, cancel := context.WithCancel(context.Background()) - q = q.WithContext(ctx) - if q.ctx != ctx { - t.Errorf("query.WithContext() = %v; want %v", q.ctx, ctx) - } - - want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} - mock := &mockServer{Resp: want} - srv := mock.Start(client) - defer srv.Close() - - var got map[string]interface{} - if err := q.Get(&got); err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(want, got) { - t.Errorf("query.WithContext() = %v; want = %v", got, want) - } - checkOnlyRequest(t, mock.Reqs, &testReq{ - Method: "GET", - Path: "/peter.json", - Query: map[string]string{"orderBy": "\"messages\""}, - }) - - cancel() - got = nil - if err := q.Get(&got); len(got) != 0 || err == nil { - t.Errorf("query.WithContext() = (%v, %v); want = (empty, error)", got, err) - } -} - -func TestQueryFromRefWithContext(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - q := client.NewRef("peter").WithContext(ctx).OrderByChild("messages") - if q.ctx != ctx { - t.Errorf("ref.WithContext().query = %v; want %v", q.ctx, ctx) - } - - want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} - mock := &mockServer{Resp: want} - srv := mock.Start(client) - defer srv.Close() - - var got map[string]interface{} - if err := q.Get(&got); err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(want, got) { - t.Errorf("ref.WithContext().query = %v; want = %v", got, want) - } - checkOnlyRequest(t, mock.Reqs, &testReq{ - Method: "GET", - Path: "/peter.json", - Query: map[string]string{"orderBy": "\"messages\""}, - }) - - cancel() - got = nil - if err := q.Get(&got); len(got) != 0 || err == nil { - t.Errorf("ref.WithContext().query = (%v, %v); want = (empty, error)", got, err) - } -} - -func TestQueryWithContextPrecedence(t *testing.T) { - ctx1 := context.Background() - ctx2, cancel := context.WithCancel(ctx1) - - r := client.NewRef("peter").WithContext(ctx1) - q := r.OrderByChild("messages").WithContext(ctx2) - if q.ctx != ctx2 { - t.Errorf("Ctx = %v; want %v", q.ctx, ctx2) - } - - want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} - mock := &mockServer{Resp: want} - srv := mock.Start(client) - defer srv.Close() - - var got map[string]interface{} - if err := q.Get(&got); err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(want, got) { - t.Errorf("ref.WithContext().query.WithContext() = %v; want = %v", got, want) - } - checkOnlyRequest(t, mock.Reqs, &testReq{ - Method: "GET", - Path: "/peter.json", - Query: map[string]string{"orderBy": "\"messages\""}, - }) - - cancel() - got = nil - if err := q.Get(&got); len(got) != 0 || err == nil { - t.Errorf("ref.WithContext().query.WithContext() = (%v, %v); want = (empty, error)", got, err) - } - if err := r.Get(&got); !reflect.DeepEqual(got, want) || err != nil { - t.Errorf("ref.WithContext() = (%v, %v); want = (%v, nil)", got, err, want) - } -} - func TestChildQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} @@ -139,7 +32,7 @@ func TestChildQuery(t *testing.T) { var reqs []*testReq for _, tc := range cases { var got map[string]interface{} - if err := testref.OrderByChild(tc).Get(&got); err != nil { + if err := testref.OrderByChild(tc).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -162,7 +55,7 @@ func TestNestedChildQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByChild("messages/ratings").Get(&got); err != nil { + if err := testref.OrderByChild("messages/ratings").Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -183,7 +76,7 @@ func TestChildQueryWithParams(t *testing.T) { q := testref.OrderByChild("messages").WithStartAt("m4").WithEndAt("m50").WithLimitToFirst(10) var got map[string]interface{} - if err := q.Get(&got); err != nil { + if err := q.Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -213,7 +106,7 @@ func TestInvalidOrderByChild(t *testing.T) { } for _, tc := range cases { var got string - if err := r.OrderByChild(tc).Get(&got); got != "" || err == nil { + if err := r.OrderByChild(tc).Get(context.Background(), &got); got != "" || err == nil { t.Errorf("OrderByChild(%q) = (%q, %v); want = (%q, error)", tc, got, err, "") } } @@ -229,7 +122,7 @@ func TestKeyQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByKey().Get(&got); err != nil { + if err := testref.OrderByKey().Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -249,7 +142,7 @@ func TestValueQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByValue().Get(&got); err != nil { + if err := testref.OrderByValue().Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -269,7 +162,7 @@ func TestLimitFirstQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByChild("messages").WithLimitToFirst(10).Get(&got); err != nil { + if err := testref.OrderByChild("messages").WithLimitToFirst(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -289,7 +182,7 @@ func TestLimitLastQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByChild("messages").WithLimitToLast(10).Get(&got); err != nil { + if err := testref.OrderByChild("messages").WithLimitToLast(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -319,7 +212,7 @@ func TestInvalidLimitQuery(t *testing.T) { } for _, tc := range cases { var got map[string]interface{} - if err := tc.q.Get(&got); got != nil || err == nil { + if err := tc.q.Get(context.Background(), &got); got != nil || err == nil { t.Errorf("OrderByChild(%q) = (%v, %v); want = (nil, error)", tc.name, got, err) } if len(mock.Reqs) != 0 { @@ -335,7 +228,7 @@ func TestStartAtQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByChild("messages").WithStartAt(10).Get(&got); err != nil { + if err := testref.OrderByChild("messages").WithStartAt(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -355,7 +248,7 @@ func TestEndAtQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByChild("messages").WithEndAt(10).Get(&got); err != nil { + if err := testref.OrderByChild("messages").WithEndAt(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -375,7 +268,7 @@ func TestEqualToQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByChild("messages").WithEqualTo(10).Get(&got); err != nil { + if err := testref.OrderByChild("messages").WithEqualTo(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -405,7 +298,7 @@ func TestInvalidFilterQuery(t *testing.T) { } for _, tc := range cases { var got map[string]interface{} - if err := tc.q.Get(&got); got != nil || err == nil { + if err := tc.q.Get(context.Background(), &got); got != nil || err == nil { t.Errorf("OrderByChild(%q) = (%v, %v); want = (nil, error)", tc.name, got, err) } if len(mock.Reqs) != 0 { @@ -422,7 +315,7 @@ func TestAllParamsQuery(t *testing.T) { q := testref.OrderByChild("messages").WithLimitToFirst(100).WithStartAt("bar").WithEndAt("foo") var got map[string]interface{} - if err := q.Get(&got); err != nil { + if err := q.Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { diff --git a/db/ref.go b/db/ref.go index a91baccb..c04e3aa1 100644 --- a/db/ref.go +++ b/db/ref.go @@ -19,6 +19,8 @@ import ( "net/http" "strings" + "firebase.google.com/go/internal" + "golang.org/x/net/context" ) @@ -31,7 +33,6 @@ type Ref struct { segs []string client *Client - ctx context.Context } // Parent returns a reference to the parent of the current node. @@ -58,30 +59,20 @@ func (r *Ref) Child(path string) *Ref { // Data deserialization is performed using https://golang.org/pkg/encoding/json/#Unmarshal, and // therefore v has the same requirements as the json package. Specifically, it must be a pointer, // and it must not be nil. -func (r *Ref) Get(v interface{}) error { - resp, err := r.send("GET") +func (r *Ref) Get(ctx context.Context, v interface{}) error { + resp, err := r.send(ctx, "GET") if err != nil { return err } - return resp.CheckAndParse(http.StatusOK, v) -} - -// WithContext returns a shallow copy of this Ref with its context changed to ctx. -// -// The resulting Ref will use ctx for all subsequent RPC calls. -func (r *Ref) WithContext(ctx context.Context) *Ref { - r2 := new(Ref) - *r2 = *r - r2.ctx = ctx - return r2 + return resp.Unmarshal(http.StatusOK, errParser, v) } // GetWithETag retrieves the value at the current database location, along with its ETag. -func (r *Ref) GetWithETag(v interface{}) (string, error) { - resp, err := r.send("GET", withHeader("X-Firebase-ETag", "true")) +func (r *Ref) GetWithETag(ctx context.Context, v interface{}) (string, error) { + resp, err := r.send(ctx, "GET", internal.WithHeader("X-Firebase-ETag", "true")) if err != nil { return "", err - } else if err := resp.CheckAndParse(http.StatusOK, v); err != nil { + } else if err := resp.Unmarshal(http.StatusOK, errParser, v); err != nil { return "", err } return resp.Header.Get("Etag"), nil @@ -94,13 +85,13 @@ func (r *Ref) GetWithETag(v interface{}) (string, error) { // location. The value of the database location will be stored in v just like a regular Get() call. // If the etag matches, returns false along with the same ETag passed into the function. No data // will be stored in v in this case. -func (r *Ref) GetIfChanged(etag string, v interface{}) (bool, string, error) { - resp, err := r.send("GET", withHeader("If-None-Match", etag)) +func (r *Ref) GetIfChanged(ctx context.Context, etag string, v interface{}) (bool, string, error) { + resp, err := r.send(ctx, "GET", internal.WithHeader("If-None-Match", etag)) if err != nil { return false, "", err - } else if err := resp.CheckAndParse(http.StatusOK, v); err == nil { + } else if err := resp.Unmarshal(http.StatusOK, errParser, v); err == nil { return true, resp.Header.Get("ETag"), nil - } else if err := resp.CheckStatus(http.StatusNotModified); err != nil { + } else if err := resp.CheckStatus(http.StatusNotModified, errParser); err != nil { return false, "", err } return false, etag, nil @@ -111,25 +102,25 @@ func (r *Ref) GetIfChanged(etag string, v interface{}) (bool, string, error) { // Set uses https://golang.org/pkg/encoding/json/#Marshal to serialize values into JSON. Therefore // v has the same requirements as the json package. Values like functions and channels cannot be // saved into Realtime Database. -func (r *Ref) Set(v interface{}) error { - resp, err := r.sendWithBody("PUT", v, withQueryParam("print", "silent")) +func (r *Ref) Set(ctx context.Context, v interface{}) error { + resp, err := r.sendWithBody(ctx, "PUT", v, internal.WithQueryParam("print", "silent")) if err != nil { return err } - return resp.CheckStatus(http.StatusNoContent) + return resp.CheckStatus(http.StatusNoContent, errParser) } // SetIfUnchanged conditionally sets the data at this location to the given value. // // Sets the data at this location to v only if the specified ETag matches. Returns true if the // value is written. Returns false if no changes are made to the database. -func (r *Ref) SetIfUnchanged(etag string, v interface{}) (bool, error) { - resp, err := r.sendWithBody("PUT", v, withHeader("If-Match", etag)) +func (r *Ref) SetIfUnchanged(ctx context.Context, etag string, v interface{}) (bool, error) { + resp, err := r.sendWithBody(ctx, "PUT", v, internal.WithHeader("If-Match", etag)) if err != nil { return false, err - } else if err := resp.CheckStatus(http.StatusOK); err == nil { + } else if err := resp.CheckStatus(http.StatusOK, errParser); err == nil { return true, nil - } else if err := resp.CheckStatus(http.StatusPreconditionFailed); err != nil { + } else if err := resp.CheckStatus(http.StatusPreconditionFailed, errParser); err != nil { return false, err } return false, nil @@ -139,33 +130,33 @@ func (r *Ref) SetIfUnchanged(etag string, v interface{}) (bool, error) { // // If v is not nil, it will be set as the initial value of the new child node. If v is nil, the // new child node will be created with empty string as the value. -func (r *Ref) Push(v interface{}) (*Ref, error) { +func (r *Ref) Push(ctx context.Context, v interface{}) (*Ref, error) { if v == nil { v = "" } - resp, err := r.sendWithBody("POST", v) + resp, err := r.sendWithBody(ctx, "POST", v) if err != nil { return nil, err } var d struct { Name string `json:"name"` } - if err := resp.CheckAndParse(http.StatusOK, &d); err != nil { + if err := resp.Unmarshal(http.StatusOK, errParser, &d); err != nil { return nil, err } return r.Child(d.Name), nil } // Update modifies the specified child keys of the current location to the provided values. -func (r *Ref) Update(v map[string]interface{}) error { +func (r *Ref) Update(ctx context.Context, v map[string]interface{}) error { if len(v) == 0 { return fmt.Errorf("value argument must be a non-empty map") } - resp, err := r.sendWithBody("PATCH", v, withQueryParam("print", "silent")) + resp, err := r.sendWithBody(ctx, "PATCH", v, internal.WithQueryParam("print", "silent")) if err != nil { return err } - return resp.CheckStatus(http.StatusNoContent) + return resp.CheckStatus(http.StatusNoContent, errParser) } type UpdateFn func(interface{}) (interface{}, error) @@ -184,9 +175,9 @@ type UpdateFn func(interface{}) (interface{}, error) // // The update function may also force an early abort by returning an error instead of returning a // value. -func (r *Ref) Transaction(fn UpdateFn) error { +func (r *Ref) Transaction(ctx context.Context, fn UpdateFn) error { var curr interface{} - etag, err := r.GetWithETag(&curr) + etag, err := r.GetWithETag(ctx, &curr) if err != nil { return err } @@ -196,12 +187,12 @@ func (r *Ref) Transaction(fn UpdateFn) error { if err != nil { return err } - resp, err := r.sendWithBody("PUT", new, withHeader("If-Match", etag)) + resp, err := r.sendWithBody(ctx, "PUT", new, internal.WithHeader("If-Match", etag)) if err != nil { return err - } else if err := resp.CheckStatus(http.StatusOK); err == nil { + } else if err := resp.CheckStatus(http.StatusOK, errParser); err == nil { return nil - } else if err := resp.CheckAndParse(http.StatusPreconditionFailed, &curr); err != nil { + } else if err := resp.Unmarshal(http.StatusPreconditionFailed, errParser, &curr); err != nil { return err } etag = resp.Header.Get("ETag") @@ -210,24 +201,26 @@ func (r *Ref) Transaction(fn UpdateFn) error { } // Delete removes this node from the database. -func (r *Ref) Delete() error { - resp, err := r.send("DELETE") +func (r *Ref) Delete(ctx context.Context) error { + resp, err := r.send(ctx, "DELETE") if err != nil { return err } - return resp.CheckStatus(http.StatusOK) + return resp.CheckStatus(http.StatusOK, errParser) } -func (r *Ref) send(method string, opts ...httpOption) (*response, error) { - return r.sendWithBody(method, nil, opts...) +func (r *Ref) send( + ctx context.Context, method string, + opts ...internal.HTTPOption) (*internal.Response, error) { + return r.sendWithBody(ctx, method, nil, opts...) } -func (r *Ref) sendWithBody(method string, body interface{}, opts ...httpOption) (*response, error) { - req := &request{ - Method: method, - Body: body, - Path: r.Path, - Opts: opts, +func (r *Ref) sendWithBody( + ctx context.Context, method string, body interface{}, + opts ...internal.HTTPOption) (*internal.Response, error) { + req, err := r.client.newHTTPRequest(method, r.Path, body, opts...) + if err != nil { + return nil, err } - return r.client.send(r.ctx, req) + return req.Send(ctx, r.client.hc) } diff --git a/db/ref_test.go b/db/ref_test.go index b8f6949b..685a8e47 100644 --- a/db/ref_test.go +++ b/db/ref_test.go @@ -35,7 +35,7 @@ var testOps = []struct { "test", func(r *Ref) error { var got string - return r.Get(&got) + return r.Get(context.Background(), &got) }, }, { @@ -43,7 +43,7 @@ var testOps = []struct { "test", func(r *Ref) error { var got string - _, err := r.GetWithETag(&got) + _, err := r.GetWithETag(context.Background(), &got) return err }, }, @@ -52,7 +52,7 @@ var testOps = []struct { "test", func(r *Ref) error { var got string - _, _, err := r.GetIfChanged("etag", &got) + _, _, err := r.GetIfChanged(context.Background(), "etag", &got) return err }, }, @@ -60,14 +60,14 @@ var testOps = []struct { "Set()", nil, func(r *Ref) error { - return r.Set("foo") + return r.Set(context.Background(), "foo") }, }, { "SetIfUnchanged()", nil, func(r *Ref) error { - _, err := r.SetIfUnchanged("etag", "foo") + _, err := r.SetIfUnchanged(context.Background(), "etag", "foo") return err }, }, @@ -75,7 +75,7 @@ var testOps = []struct { "Push()", map[string]interface{}{"name": "test"}, func(r *Ref) error { - _, err := r.Push("foo") + _, err := r.Push(context.Background(), "foo") return err }, }, @@ -83,14 +83,14 @@ var testOps = []struct { "Update()", nil, func(r *Ref) error { - return r.Update(map[string]interface{}{"foo": "bar"}) + return r.Update(context.Background(), map[string]interface{}{"foo": "bar"}) }, }, { "Delete()", nil, func(r *Ref) error { - return r.Delete() + return r.Delete(context.Background()) }, }, { @@ -100,7 +100,7 @@ var testOps = []struct { fn := func(v interface{}) (interface{}, error) { return v, nil } - return r.Transaction(fn) + return r.Transaction(context.Background(), fn) }, }, } @@ -118,7 +118,7 @@ func TestGet(t *testing.T) { for _, tc := range cases { mock.Resp = tc var got interface{} - if err := testref.Get(&got); err != nil { + if err := testref.Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(tc, got) { @@ -136,7 +136,7 @@ func TestInvalidGet(t *testing.T) { defer srv.Close() got := func() {} - if err := testref.Get(&got); err == nil { + if err := testref.Get(context.Background(), &got); err == nil { t.Errorf("Get(func) = nil; want error") } checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) @@ -149,7 +149,7 @@ func TestGetWithStruct(t *testing.T) { defer srv.Close() var got person - if err := testref.Get(&got); err != nil { + if err := testref.Get(context.Background(), &got); err != nil { t.Fatal(err) } if want != got { @@ -168,7 +168,7 @@ func TestGetWithETag(t *testing.T) { defer srv.Close() var got map[string]interface{} - etag, err := testref.GetWithETag(&got) + etag, err := testref.GetWithETag(context.Background(), &got) if err != nil { t.Fatal(err) } @@ -195,7 +195,7 @@ func TestGetIfChanged(t *testing.T) { defer srv.Close() var got map[string]interface{} - ok, etag, err := testref.GetIfChanged("old-etag", &got) + ok, etag, err := testref.GetIfChanged(context.Background(), "old-etag", &got) if err != nil { t.Fatal(err) } @@ -212,7 +212,7 @@ func TestGetIfChanged(t *testing.T) { mock.Status = http.StatusNotModified mock.Resp = nil var got2 map[string]interface{} - ok, etag, err = testref.GetIfChanged("new-etag", &got2) + ok, etag, err = testref.GetIfChanged(context.Background(), "new-etag", &got2) if err != nil { t.Fatal(err) } @@ -263,7 +263,7 @@ func TestUnexpectedHttpError(t *testing.T) { srv := mock.Start(client) defer srv.Close() - want := "http error status: 500; message: \"unexpected error\"" + want := "http error status: 500; reason: \"unexpected error\"" for _, tc := range testOps { err := tc.op(testref) if err == nil || err.Error() != want { @@ -336,7 +336,7 @@ func TestSet(t *testing.T) { } var want []*testReq for _, tc := range cases { - if err := testref.Set(tc); err != nil { + if err := testref.Set(context.Background(), tc); err != nil { t.Fatal(err) } want = append(want, &testReq{ @@ -359,7 +359,7 @@ func TestInvalidSet(t *testing.T) { make(chan int), } for _, tc := range cases { - if err := testref.Set(tc); err == nil { + if err := testref.Set(context.Background(), tc); err == nil { t.Errorf("Set(%v) = nil; want error", tc) } } @@ -374,7 +374,7 @@ func TestSetIfUnchanged(t *testing.T) { defer srv.Close() want := &person{"Peter Parker", 17} - ok, err := testref.SetIfUnchanged("mock-etag", &want) + ok, err := testref.SetIfUnchanged(context.Background(), "mock-etag", &want) if err != nil { t.Fatal(err) } @@ -398,7 +398,7 @@ func TestSetIfUnchangedError(t *testing.T) { defer srv.Close() want := &person{"Peter Parker", 17} - ok, err := testref.SetIfUnchanged("mock-etag", &want) + ok, err := testref.SetIfUnchanged(context.Background(), "mock-etag", &want) if err != nil { t.Fatal(err) } @@ -418,7 +418,7 @@ func TestPush(t *testing.T) { srv := mock.Start(client) defer srv.Close() - child, err := testref.Push(nil) + child, err := testref.Push(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -439,7 +439,7 @@ func TestPushWithValue(t *testing.T) { defer srv.Close() want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} - child, err := testref.Push(want) + child, err := testref.Push(context.Background(), want) if err != nil { t.Fatal(err) } @@ -460,7 +460,7 @@ func TestUpdate(t *testing.T) { srv := mock.Start(client) defer srv.Close() - if err := testref.Update(want); err != nil { + if err := testref.Update(context.Background(), want); err != nil { t.Fatal(err) } checkOnlyRequest(t, mock.Reqs, &testReq{ @@ -478,7 +478,7 @@ func TestInvalidUpdate(t *testing.T) { map[string]interface{}{"foo": func() {}}, } for _, tc := range cases { - if err := testref.Update(tc); err == nil { + if err := testref.Update(context.Background(), tc); err == nil { t.Errorf("Update(%v) = nil; want error", tc) } } @@ -497,7 +497,7 @@ func TestTransaction(t *testing.T) { p["age"] = p["age"].(float64) + 1.0 return p, nil } - if err := testref.Transaction(fn); err != nil { + if err := testref.Transaction(context.Background(), fn); err != nil { t.Fatal(err) } checkAllRequests(t, mock.Reqs, []*testReq{ @@ -540,7 +540,7 @@ func TestTransactionRetry(t *testing.T) { p["age"] = p["age"].(float64) + 1.0 return p, nil } - if err := testref.Transaction(fn); err != nil { + if err := testref.Transaction(context.Background(), fn); err != nil { t.Fatal(err) } if cnt != 2 { @@ -596,7 +596,7 @@ func TestTransactionError(t *testing.T) { p["age"] = p["age"].(float64) + 1.0 return p, nil } - if err := testref.Transaction(fn); err == nil || err.Error() != want { + if err := testref.Transaction(context.Background(), fn); err == nil || err.Error() != want { t.Errorf("Transaction() = %v; want = %q", err, want) } if cnt != 1 { @@ -639,7 +639,7 @@ func TestTransactionAbort(t *testing.T) { p["age"] = p["age"].(float64) + 1.0 return p, nil } - err := testref.Transaction(fn) + err := testref.Transaction(context.Background(), fn) if err == nil { t.Errorf("Transaction() = nil; want error") } @@ -669,7 +669,7 @@ func TestDelete(t *testing.T) { srv := mock.Start(client) defer srv.Close() - if err := testref.Delete(); err != nil { + if err := testref.Delete(context.Background()); err != nil { t.Fatal(err) } checkOnlyRequest(t, mock.Reqs, &testReq{ @@ -677,31 +677,3 @@ func TestDelete(t *testing.T) { Path: "/peter.json", }) } - -func TestWithContext(t *testing.T) { - if testref.ctx != nil { - t.Errorf("Ctx = %v; want nil", testref.ctx) - } - - mock := &mockServer{} - srv := mock.Start(client) - defer srv.Close() - - for _, tc := range testOps { - mock.Resp = tc.resp - ctx, cancel := context.WithCancel(context.Background()) - r := testref.WithContext(ctx) - if r.ctx != ctx { - t.Errorf("WithContext().ctx = %v; want = %v", r.ctx, ctx) - } - if err := tc.op(r); err != nil { - t.Errorf("%s %v", tc.name, err) - t.Fatal(err) - } - - cancel() - if err := tc.op(r); err == nil { - t.Errorf("WithContext().%s = nil; want = error", tc.name) - } - } -} From 11455d87583586924f287cd01f57b51bc6f03bed Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 30 Oct 2017 18:32:33 -0700 Subject: [PATCH 31/58] Simplified the usage by adding HTTPClient --- internal/http_client.go | 41 +++++++++++++++------------ internal/http_client_test.go | 54 ++++++++++++++++++------------------ 2 files changed, 50 insertions(+), 45 deletions(-) diff --git a/internal/http_client.go b/internal/http_client.go index d19316f8..df1805ba 100644 --- a/internal/http_client.go +++ b/internal/http_client.go @@ -29,25 +29,20 @@ var Null struct{} = jsonNull{} type jsonNull struct{} -// Request contains all the parameters required to construct an outgoing HTTP request. -type Request struct { - Method string - URL string - Body interface{} - Opts []HTTPOption +// HTTPClient can be used to send and receive JSON messages over HTTP. +type HTTPClient struct { + HC *http.Client + EP ErrorParser } -// Send executes the current Request using the given context and HTTP client. -// -// If the Body is not nil, it is serialized into a JSON string. To send JSON null as the body, use -// the internal.Null variable. -func (r *Request) Send(ctx context.Context, hc *http.Client) (*Response, error) { +// Do executes the given Request, and returns a Response. +func (c *HTTPClient) Do(ctx context.Context, r *Request) (*Response, error) { req, err := r.newHTTPRequest() if err != nil { return nil, err } - resp, err := hc.Do(req.WithContext(ctx)) + resp, err := c.HC.Do(req.WithContext(ctx)) if err != nil { return nil, err } @@ -61,9 +56,18 @@ func (r *Request) Send(ctx context.Context, hc *http.Client) (*Response, error) Status: resp.StatusCode, Body: b, Header: resp.Header, + ep: c.EP, }, nil } +// Request contains all the parameters required to construct an outgoing HTTP request. +type Request struct { + Method string + URL string + Body interface{} + Opts []HTTPOption +} + func (r *Request) newHTTPRequest() (*http.Request, error) { var opts []HTTPOption var data io.Reader @@ -99,20 +103,21 @@ type Response struct { Status int Header http.Header Body []byte + ep ErrorParser } // CheckStatus checks whether the Response status code has the given HTTP status code. // // Returns an error if the status code does not match. If an ErroParser is specified, uses that to // construct the returned error message. Otherwise includes the full response body in the error. -func (r *Response) CheckStatus(want int, ep ErrorParser) error { +func (r *Response) CheckStatus(want int) error { if r.Status == want { return nil } var msg string - if ep != nil { - msg = ep(r) + if r.ep != nil { + msg = r.ep(r.Body) } if msg == "" { msg = string(r.Body) @@ -125,8 +130,8 @@ func (r *Response) CheckStatus(want int, ep ErrorParser) error { // // Unmarshal uses https://golang.org/pkg/encoding/json/#Unmarshal internally, and hence v has the // same requirements as the json package. -func (r *Response) Unmarshal(want int, ep ErrorParser, v interface{}) error { - if err := r.CheckStatus(want, ep); err != nil { +func (r *Response) Unmarshal(want int, v interface{}) error { + if err := r.CheckStatus(want); err != nil { return err } else if err := json.Unmarshal(r.Body, v); err != nil { return err @@ -135,7 +140,7 @@ func (r *Response) Unmarshal(want int, ep ErrorParser, v interface{}) error { } // ErrorParser is a function that is used to construct custom error messages. -type ErrorParser func(r *Response) string +type ErrorParser func([]byte) string // HTTPOption is an additional parameter that can be specified to customize an outgoing request. type HTTPOption func(*http.Request) diff --git a/internal/http_client_test.go b/internal/http_client_test.go index 66019573..03a16e73 100644 --- a/internal/http_client_test.go +++ b/internal/http_client_test.go @@ -93,7 +93,7 @@ var cases = []struct { }, } -func TestSend(t *testing.T) { +func TestHTTPClient(t *testing.T) { want := map[string]interface{}{ "key1": "value1", "key2": float64(100), @@ -157,21 +157,22 @@ func TestSend(t *testing.T) { server := httptest.NewServer(handler) defer server.Close() + client := &HTTPClient{HC: http.DefaultClient} for _, tc := range cases { tc.req.URL = server.URL - resp, err := tc.req.Send(context.Background(), http.DefaultClient) + resp, err := client.Do(context.Background(), tc.req) if err != nil { t.Fatal(err) } - if err := resp.CheckStatus(http.StatusOK, nil); err != nil { + if err := resp.CheckStatus(http.StatusOK); err != nil { t.Errorf("CheckStatus() = %v; want nil", err) } - if err := resp.CheckStatus(http.StatusCreated, nil); err == nil { + if err := resp.CheckStatus(http.StatusCreated); err == nil { t.Errorf("CheckStatus() = nil; want error") } var got map[string]interface{} - if err := resp.Unmarshal(http.StatusOK, nil, &got); err != nil { + if err := resp.Unmarshal(http.StatusOK, &got); err != nil { t.Errorf("Unmarshal() = %v; want nil", err) } if !reflect.DeepEqual(got, want) { @@ -197,31 +198,31 @@ func TestErrorParser(t *testing.T) { server := httptest.NewServer(handler) defer server.Close() - req := &Request{ - Method: "GET", - URL: server.URL, - } - resp, err := req.Send(context.Background(), http.DefaultClient) - if err != nil { - t.Fatal(err) - } - - ep := func(r *Response) string { - var b struct { + ep := func(b []byte) string { + var p struct { Error string `json:"error"` } - if err := json.Unmarshal(r.Body, &b); err != nil { + if err := json.Unmarshal(b, &p); err != nil { return "" } - return b.Error + return p.Error + } + client := &HTTPClient{ + HC: http.DefaultClient, + EP: ep, + } + req := &Request{Method: "GET", URL: server.URL} + resp, err := client.Do(context.Background(), req) + if err != nil { + t.Fatal(err) } want := "http error status: 500; reason: test error" - if err := resp.CheckStatus(http.StatusOK, ep); err.Error() != want { + if err := resp.CheckStatus(http.StatusOK); err.Error() != want { t.Errorf("CheckStatus() = %q; want = %q", err.Error(), want) } var got map[string]interface{} - if err := resp.Unmarshal(http.StatusOK, ep, &got); err.Error() != want { + if err := resp.Unmarshal(http.StatusOK, &got); err.Error() != want { t.Errorf("CheckStatus() = %q; want = %q", err.Error(), want) } if got != nil { @@ -234,7 +235,8 @@ func TestInvalidURL(t *testing.T) { Method: "GET", URL: "http://localhost:250/mock.url", } - _, err := req.Send(context.Background(), http.DefaultClient) + client := &HTTPClient{HC: http.DefaultClient} + _, err := client.Do(context.Background(), req) if err == nil { t.Errorf("Send() = nil; want error") } @@ -256,17 +258,15 @@ func TestUnmarshalError(t *testing.T) { server := httptest.NewServer(handler) defer server.Close() - req := &Request{ - Method: "GET", - URL: server.URL, - } - resp, err := req.Send(context.Background(), http.DefaultClient) + req := &Request{Method: "GET", URL: server.URL} + client := &HTTPClient{HC: http.DefaultClient} + resp, err := client.Do(context.Background(), req) if err != nil { t.Fatal(err) } var got func() - if err := resp.Unmarshal(http.StatusOK, nil, &got); err == nil { + if err := resp.Unmarshal(http.StatusOK, &got); err == nil { t.Errorf("Unmarshal() = nil; want error") } } From d2f90f257e052f7c9a96e6214dd48dc1fbe85bed Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 30 Oct 2017 18:52:46 -0700 Subject: [PATCH 32/58] using the new client API --- db/db.go | 33 +++++----- db/query.go | 8 +-- db/ref.go | 30 ++++----- integration/db/db_test.go | 122 ++++++++++++++++++----------------- integration/db/query_test.go | 39 +++++++---- 5 files changed, 122 insertions(+), 110 deletions(-) diff --git a/db/db.go b/db/db.go index 12c26230..90579df3 100644 --- a/db/db.go +++ b/db/db.go @@ -18,7 +18,6 @@ package db import ( "encoding/json" "fmt" - "net/http" "runtime" "strings" @@ -35,19 +34,9 @@ const userAgentFormat = "Firebase/HTTP/%s/%s/AdminGo" const invalidChars = "[].#$" const authVarOverride = "auth_variable_override" -var errParser = func(r *internal.Response) string { - var b struct { - Error string `json:"error"` - } - if err := json.Unmarshal(r.Body, &b); err != nil { - return "" - } - return b.Error -} - // Client is the interface for the Firebase Realtime Database service. type Client struct { - hc *http.Client + hc *internal.HTTPClient url string ao string } @@ -82,8 +71,17 @@ func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) } } + errParser := func(b []byte) string { + var p struct { + Error string `json:"error"` + } + if err := json.Unmarshal(b, &p); err != nil { + return "" + } + return p.Error + } return &Client{ - hc: hc, + hc: &internal.HTTPClient{HC: hc, EP: errParser}, url: fmt.Sprintf("https://%s", p.Host), ao: string(ao), }, nil @@ -119,7 +117,9 @@ func (c *Client) NewRef(path string) *Ref { } } -func (c *Client) newHTTPRequest(method, path string, body interface{}, opts ...internal.HTTPOption) (*internal.Request, error) { +func (c *Client) send( + ctx context.Context, method, path string, body interface{}, + opts ...internal.HTTPOption) (*internal.Response, error) { if strings.ContainsAny(path, invalidChars) { return nil, fmt.Errorf("invalid path with illegal characters: %q", path) } @@ -128,12 +128,13 @@ func (c *Client) newHTTPRequest(method, path string, body interface{}, opts ...i opts = append(opts, internal.WithQueryParam(authVarOverride, c.ao)) } url := fmt.Sprintf("%s%s.json", c.url, path) - return &internal.Request{ + req := &internal.Request{ Method: method, URL: url, Body: body, Opts: opts, - }, nil + } + return c.hc.Do(ctx, req) } func parsePath(path string) []string { diff --git a/db/query.go b/db/query.go index 2121d8df..6b4245c7 100644 --- a/db/query.go +++ b/db/query.go @@ -125,15 +125,11 @@ func (q *Query) Get(ctx context.Context, v interface{}) error { return err } - req, err := q.client.newHTTPRequest("GET", q.path, nil, internal.WithQueryParams(qp)) + resp, err := q.client.send(ctx, "GET", q.path, nil, internal.WithQueryParams(qp)) if err != nil { return err } - resp, err := req.Send(ctx, q.client.hc) - if err != nil { - return err - } - return resp.Unmarshal(http.StatusOK, errParser, v) + return resp.Unmarshal(http.StatusOK, v) } // OrderByChild returns a Query that orders data by child values before applying filters. diff --git a/db/ref.go b/db/ref.go index c04e3aa1..9bd45826 100644 --- a/db/ref.go +++ b/db/ref.go @@ -64,7 +64,7 @@ func (r *Ref) Get(ctx context.Context, v interface{}) error { if err != nil { return err } - return resp.Unmarshal(http.StatusOK, errParser, v) + return resp.Unmarshal(http.StatusOK, v) } // GetWithETag retrieves the value at the current database location, along with its ETag. @@ -72,7 +72,7 @@ func (r *Ref) GetWithETag(ctx context.Context, v interface{}) (string, error) { resp, err := r.send(ctx, "GET", internal.WithHeader("X-Firebase-ETag", "true")) if err != nil { return "", err - } else if err := resp.Unmarshal(http.StatusOK, errParser, v); err != nil { + } else if err := resp.Unmarshal(http.StatusOK, v); err != nil { return "", err } return resp.Header.Get("Etag"), nil @@ -89,9 +89,9 @@ func (r *Ref) GetIfChanged(ctx context.Context, etag string, v interface{}) (boo resp, err := r.send(ctx, "GET", internal.WithHeader("If-None-Match", etag)) if err != nil { return false, "", err - } else if err := resp.Unmarshal(http.StatusOK, errParser, v); err == nil { + } else if err := resp.Unmarshal(http.StatusOK, v); err == nil { return true, resp.Header.Get("ETag"), nil - } else if err := resp.CheckStatus(http.StatusNotModified, errParser); err != nil { + } else if err := resp.CheckStatus(http.StatusNotModified); err != nil { return false, "", err } return false, etag, nil @@ -107,7 +107,7 @@ func (r *Ref) Set(ctx context.Context, v interface{}) error { if err != nil { return err } - return resp.CheckStatus(http.StatusNoContent, errParser) + return resp.CheckStatus(http.StatusNoContent) } // SetIfUnchanged conditionally sets the data at this location to the given value. @@ -118,9 +118,9 @@ func (r *Ref) SetIfUnchanged(ctx context.Context, etag string, v interface{}) (b resp, err := r.sendWithBody(ctx, "PUT", v, internal.WithHeader("If-Match", etag)) if err != nil { return false, err - } else if err := resp.CheckStatus(http.StatusOK, errParser); err == nil { + } else if err := resp.CheckStatus(http.StatusOK); err == nil { return true, nil - } else if err := resp.CheckStatus(http.StatusPreconditionFailed, errParser); err != nil { + } else if err := resp.CheckStatus(http.StatusPreconditionFailed); err != nil { return false, err } return false, nil @@ -141,7 +141,7 @@ func (r *Ref) Push(ctx context.Context, v interface{}) (*Ref, error) { var d struct { Name string `json:"name"` } - if err := resp.Unmarshal(http.StatusOK, errParser, &d); err != nil { + if err := resp.Unmarshal(http.StatusOK, &d); err != nil { return nil, err } return r.Child(d.Name), nil @@ -156,7 +156,7 @@ func (r *Ref) Update(ctx context.Context, v map[string]interface{}) error { if err != nil { return err } - return resp.CheckStatus(http.StatusNoContent, errParser) + return resp.CheckStatus(http.StatusNoContent) } type UpdateFn func(interface{}) (interface{}, error) @@ -190,9 +190,9 @@ func (r *Ref) Transaction(ctx context.Context, fn UpdateFn) error { resp, err := r.sendWithBody(ctx, "PUT", new, internal.WithHeader("If-Match", etag)) if err != nil { return err - } else if err := resp.CheckStatus(http.StatusOK, errParser); err == nil { + } else if err := resp.CheckStatus(http.StatusOK); err == nil { return nil - } else if err := resp.Unmarshal(http.StatusPreconditionFailed, errParser, &curr); err != nil { + } else if err := resp.Unmarshal(http.StatusPreconditionFailed, &curr); err != nil { return err } etag = resp.Header.Get("ETag") @@ -206,7 +206,7 @@ func (r *Ref) Delete(ctx context.Context) error { if err != nil { return err } - return resp.CheckStatus(http.StatusOK, errParser) + return resp.CheckStatus(http.StatusOK) } func (r *Ref) send( @@ -218,9 +218,5 @@ func (r *Ref) send( func (r *Ref) sendWithBody( ctx context.Context, method string, body interface{}, opts ...internal.HTTPOption) (*internal.Response, error) { - req, err := r.client.newHTTPRequest(method, r.Path, body, opts...) - if err != nil { - return nil, err - } - return req.Send(ctx, r.client.hc) + return r.client.send(ctx, method, r.Path, body, opts...) } diff --git a/integration/db/db_test.go b/integration/db/db_test.go index d4f0627f..65eb3f8b 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -162,7 +162,7 @@ func initData() { log.Fatalln(err) } - if err = ref.Set(testData); err != nil { + if err = ref.Set(context.Background(), testData); err != nil { log.Fatalln(err) } } @@ -198,7 +198,7 @@ func TestParent(t *testing.T) { func TestGet(t *testing.T) { var m map[string]interface{} - if err := ref.Get(&m); err != nil { + if err := ref.Get(context.Background(), &m); err != nil { t.Fatal(err) } if !reflect.DeepEqual(testData, m) { @@ -208,7 +208,7 @@ func TestGet(t *testing.T) { func TestGetWithETag(t *testing.T) { var m map[string]interface{} - etag, err := ref.GetWithETag(&m) + etag, err := ref.GetWithETag(context.Background(), &m) if err != nil { t.Fatal(err) } @@ -222,7 +222,7 @@ func TestGetWithETag(t *testing.T) { func TestGetIfChanged(t *testing.T) { var m map[string]interface{} - ok, etag, err := ref.GetIfChanged("wrong-etag", &m) + ok, etag, err := ref.GetIfChanged(context.Background(), "wrong-etag", &m) if err != nil { t.Fatal(err) } @@ -234,7 +234,7 @@ func TestGetIfChanged(t *testing.T) { } var m2 map[string]interface{} - ok, etag2, err := ref.GetIfChanged(etag, &m2) + ok, etag2, err := ref.GetIfChanged(context.Background(), etag, &m2) if err != nil { t.Fatal(err) } @@ -249,7 +249,7 @@ func TestGetIfChanged(t *testing.T) { func TestGetChildValue(t *testing.T) { c := ref.Child("dinosaurs") var m map[string]interface{} - if err := c.Get(&m); err != nil { + if err := c.Get(context.Background(), &m); err != nil { t.Fatal(err) } if !reflect.DeepEqual(testData["dinosaurs"], m) { @@ -260,7 +260,7 @@ func TestGetChildValue(t *testing.T) { func TestGetGrandChildValue(t *testing.T) { c := ref.Child("dinosaurs/lambeosaurus") var got Dinosaur - if err := c.Get(&got); err != nil { + if err := c.Get(context.Background(), &got); err != nil { t.Fatal(err) } want := parsedTestData["lambeosaurus"] @@ -272,7 +272,7 @@ func TestGetGrandChildValue(t *testing.T) { func TestGetNonExistingChild(t *testing.T) { c := ref.Child("non_existing") var i interface{} - if err := c.Get(&i); err != nil { + if err := c.Get(context.Background(), &i); err != nil { t.Fatal(err) } if i != nil { @@ -281,7 +281,7 @@ func TestGetNonExistingChild(t *testing.T) { } func TestPush(t *testing.T) { - u, err := users.Push(nil) + u, err := users.Push(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -290,7 +290,7 @@ func TestPush(t *testing.T) { } var i interface{} - if err := u.Get(&i); err != nil { + if err := u.Get(context.Background(), &i); err != nil { t.Fatal(err) } if i != "" { @@ -300,7 +300,7 @@ func TestPush(t *testing.T) { func TestPushWithValue(t *testing.T) { want := User{"Luis Alvarez", 1911} - u, err := users.Push(&want) + u, err := users.Push(context.Background(), &want) if err != nil { t.Fatal(err) } @@ -309,7 +309,7 @@ func TestPushWithValue(t *testing.T) { } var got User - if err := u.Get(&got); err != nil { + if err := u.Get(context.Background(), &got); err != nil { t.Fatal(err) } if want != got { @@ -318,15 +318,15 @@ func TestPushWithValue(t *testing.T) { } func TestSetPrimitiveValue(t *testing.T) { - u, err := users.Push(nil) + u, err := users.Push(context.Background(), nil) if err != nil { t.Fatal(err) } - if err := u.Set("value"); err != nil { + if err := u.Set(context.Background(), "value"); err != nil { t.Fatal(err) } var got string - if err := u.Get(&got); err != nil { + if err := u.Get(context.Background(), &got); err != nil { t.Fatal(err) } if got != "value" { @@ -335,17 +335,17 @@ func TestSetPrimitiveValue(t *testing.T) { } func TestSetComplexValue(t *testing.T) { - u, err := users.Push(nil) + u, err := users.Push(context.Background(), nil) if err != nil { t.Fatal(err) } want := User{"Mary Anning", 1799} - if err := u.Set(&want); err != nil { + if err := u.Set(context.Background(), &want); err != nil { t.Fatal(err) } var got User - if err := u.Get(&got); err != nil { + if err := u.Get(context.Background(), &got); err != nil { t.Fatal(err) } if got != want { @@ -354,7 +354,7 @@ func TestSetComplexValue(t *testing.T) { } func TestUpdateChildren(t *testing.T) { - u, err := users.Push(nil) + u, err := users.Push(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -363,11 +363,11 @@ func TestUpdateChildren(t *testing.T) { "name": "Robert Bakker", "since": float64(1945), } - if err := u.Update(want); err != nil { + if err := u.Update(context.Background(), want); err != nil { t.Fatal(err) } var got map[string]interface{} - if err := u.Get(&got); err != nil { + if err := u.Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { @@ -376,7 +376,7 @@ func TestUpdateChildren(t *testing.T) { } func TestUpdateChildrenWithExistingValue(t *testing.T) { - u, err := users.Push(map[string]interface{}{ + u, err := users.Push(context.Background(), map[string]interface{}{ "name": "Edwin Colbert", "since": float64(1900), }) @@ -385,11 +385,11 @@ func TestUpdateChildrenWithExistingValue(t *testing.T) { } update := map[string]interface{}{"since": float64(1905)} - if err := u.Update(update); err != nil { + if err := u.Update(context.Background(), update); err != nil { t.Fatal(err) } var got map[string]interface{} - if err := u.Get(&got); err != nil { + if err := u.Get(context.Background(), &got); err != nil { t.Fatal(err) } want := map[string]interface{}{ @@ -402,11 +402,15 @@ func TestUpdateChildrenWithExistingValue(t *testing.T) { } func TestUpdateNestedChildren(t *testing.T) { - edward, err := users.Push(map[string]interface{}{"name": "Edward Cope", "since": float64(1800)}) + edward, err := users.Push(context.Background(), map[string]interface{}{ + "name": "Edward Cope", "since": float64(1800), + }) if err != nil { t.Fatal(err) } - jack, err := users.Push(map[string]interface{}{"name": "Jack Horner", "since": float64(1940)}) + jack, err := users.Push(context.Background(), map[string]interface{}{ + "name": "Jack Horner", "since": float64(1940), + }) if err != nil { t.Fatal(err) } @@ -414,11 +418,11 @@ func TestUpdateNestedChildren(t *testing.T) { fmt.Sprintf("%s/since", edward.Key): 1840, fmt.Sprintf("%s/since", jack.Key): 1946, } - if err := users.Update(delta); err != nil { + if err := users.Update(context.Background(), delta); err != nil { t.Fatal(err) } var got map[string]interface{} - if err := edward.Get(&got); err != nil { + if err := edward.Get(context.Background(), &got); err != nil { t.Fatal(err) } want := map[string]interface{}{"name": "Edward Cope", "since": float64(1840)} @@ -426,7 +430,7 @@ func TestUpdateNestedChildren(t *testing.T) { t.Errorf("Get() = %v; want = %v", got, want) } - if err := jack.Get(&got); err != nil { + if err := jack.Get(context.Background(), &got); err != nil { t.Fatal(err) } want = map[string]interface{}{"name": "Jack Horner", "since": float64(1946)} @@ -436,13 +440,13 @@ func TestUpdateNestedChildren(t *testing.T) { } func TestSetIfChanged(t *testing.T) { - edward, err := users.Push(&User{"Edward Cope", 1800}) + edward, err := users.Push(context.Background(), &User{"Edward Cope", 1800}) if err != nil { t.Fatal(err) } update := User{"Jack Horner", 1940} - ok, err := edward.SetIfUnchanged("invalid-etag", &update) + ok, err := edward.SetIfUnchanged(context.Background(), "invalid-etag", &update) if err != nil { t.Fatal(err) } @@ -451,11 +455,11 @@ func TestSetIfChanged(t *testing.T) { } var u User - etag, err := edward.GetWithETag(&u) + etag, err := edward.GetWithETag(context.Background(), &u) if err != nil { t.Fatal(err) } - ok, err = edward.SetIfUnchanged(etag, &update) + ok, err = edward.SetIfUnchanged(context.Background(), etag, &update) if err != nil { t.Fatal(err) } @@ -463,7 +467,7 @@ func TestSetIfChanged(t *testing.T) { t.Errorf("SetIfUnchanged() = %v; want = %v", ok, true) } - if err := edward.Get(&u); err != nil { + if err := edward.Get(context.Background(), &u); err != nil { t.Fatal(err) } if !reflect.DeepEqual(update, u) { @@ -472,7 +476,7 @@ func TestSetIfChanged(t *testing.T) { } func TestTransaction(t *testing.T) { - u, err := users.Push(&User{Name: "Richard"}) + u, err := users.Push(context.Background(), &User{Name: "Richard"}) if err != nil { t.Fatal(err) } @@ -482,11 +486,11 @@ func TestTransaction(t *testing.T) { snap["since"] = 1804 return snap, nil } - if err := u.Transaction(fn); err != nil { + if err := u.Transaction(context.Background(), fn); err != nil { t.Fatal(err) } var got User - if err := u.Get(&got); err != nil { + if err := u.Get(context.Background(), &got); err != nil { t.Fatal(err) } want := User{"Richard Owen", 1804} @@ -497,18 +501,18 @@ func TestTransaction(t *testing.T) { func TestTransactionScalar(t *testing.T) { cnt := users.Child("count") - if err := cnt.Set(42); err != nil { + if err := cnt.Set(context.Background(), 42); err != nil { t.Fatal(err) } fn := func(curr interface{}) (interface{}, error) { snap := curr.(float64) return snap + 1, nil } - if err := cnt.Transaction(fn); err != nil { + if err := cnt.Transaction(context.Background(), fn); err != nil { t.Fatal(err) } var got float64 - if err := cnt.Get(&got); err != nil { + if err := cnt.Get(context.Background(), &got); err != nil { t.Fatal(err) } if got != 43.0 { @@ -517,23 +521,23 @@ func TestTransactionScalar(t *testing.T) { } func TestDelete(t *testing.T) { - u, err := users.Push("foo") + u, err := users.Push(context.Background(), "foo") if err != nil { t.Fatal(err) } var got string - if err := u.Get(&got); err != nil { + if err := u.Get(context.Background(), &got); err != nil { t.Fatal(err) } if got != "foo" { t.Errorf("Get() = %q; want = %q", got, "foo") } - if err := u.Delete(); err != nil { + if err := u.Delete(context.Background()); err != nil { t.Fatal(err) } var got2 string - if err := u.Get(&got2); err != nil { + if err := u.Get(context.Background(), &got2); err != nil { t.Fatal(err) } if got2 != "" { @@ -544,12 +548,12 @@ func TestDelete(t *testing.T) { func TestNoAccess(t *testing.T) { r := aoClient.NewRef(protectedRef(t, "_adminsdk/go/admin")) var got string - if err := r.Get(&got); err == nil || got != "" { + if err := r.Get(context.Background(), &got); err == nil || got != "" { t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) } else if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) } - if err := r.Set("update"); err == nil { + if err := r.Set(context.Background(), "update"); err == nil { t.Errorf("Set() = nil; want = error") } else if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) @@ -559,10 +563,10 @@ func TestNoAccess(t *testing.T) { func TestReadAccess(t *testing.T) { r := aoClient.NewRef(protectedRef(t, "_adminsdk/go/protected/user2")) var got string - if err := r.Get(&got); err != nil || got != "test" { + if err := r.Get(context.Background(), &got); err != nil || got != "test" { t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") } - if err := r.Set("update"); err == nil { + if err := r.Set(context.Background(), "update"); err == nil { t.Errorf("Set() = nil; want = error") } else if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) @@ -572,10 +576,10 @@ func TestReadAccess(t *testing.T) { func TestReadWriteAccess(t *testing.T) { r := aoClient.NewRef(protectedRef(t, "_adminsdk/go/protected/user1")) var got string - if err := r.Get(&got); err != nil || got != "test" { + if err := r.Get(context.Background(), &got); err != nil || got != "test" { t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") } - if err := r.Set("update"); err != nil { + if err := r.Set(context.Background(), "update"); err != nil { t.Errorf("Set() = %v; want = nil", err) } } @@ -583,7 +587,7 @@ func TestReadWriteAccess(t *testing.T) { func TestQueryAccess(t *testing.T) { r := aoClient.NewRef("_adminsdk/go/protected") got := make(map[string]interface{}) - if err := r.OrderByKey().WithLimitToFirst(2).Get(&got); err == nil { + if err := r.OrderByKey().WithLimitToFirst(2).Get(context.Background(), &got); err == nil { t.Errorf("OrderByQuery() = nil; want = error") } else if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) @@ -593,10 +597,10 @@ func TestQueryAccess(t *testing.T) { func TestGuestAccess(t *testing.T) { r := guestClient.NewRef(protectedRef(t, "_adminsdk/go/public")) var got string - if err := r.Get(&got); err != nil || got != "test" { + if err := r.Get(context.Background(), &got); err != nil || got != "test" { t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") } - if err := r.Set("update"); err == nil { + if err := r.Set(context.Background(), "update"); err == nil { t.Errorf("Set() = nil; want = error") } else if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) @@ -604,21 +608,21 @@ func TestGuestAccess(t *testing.T) { got = "" r = guestClient.NewRef("_adminsdk/go") - if err := r.Get(&got); err == nil || got != "" { + if err := r.Get(context.Background(), &got); err == nil || got != "" { t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) } else if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) } c := r.Child("protected/user2") - if err := c.Get(&got); err == nil || got != "" { + if err := c.Get(context.Background(), &got); err == nil || got != "" { t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) } else if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) } c = r.Child("admin") - if err := c.Get(&got); err == nil || got != "" { + if err := c.Get(context.Background(), &got); err == nil || got != "" { t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) } else if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) @@ -628,7 +632,7 @@ func TestGuestAccess(t *testing.T) { func TestWithContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) var m map[string]interface{} - if err := ref.WithContext(ctx).Get(&m); err != nil { + if err := ref.Get(ctx, &m); err != nil { t.Fatal(err) } if !reflect.DeepEqual(testData, m) { @@ -637,14 +641,14 @@ func TestWithContext(t *testing.T) { cancel() m = nil - if err := ref.WithContext(ctx).Get(&m); len(m) != 0 || err == nil { + if err := ref.Get(ctx, &m); len(m) != 0 || err == nil { t.Errorf("Get() = (%v, %v); want = (empty, error)", m, err) } } func protectedRef(t *testing.T, p string) string { r := client.NewRef(p) - if err := r.Set("test"); err != nil { + if err := r.Set(context.Background(), "test"); err != nil { t.Fatal(err) } return p diff --git a/integration/db/query_test.go b/integration/db/query_test.go index d25583af..8a6808a6 100644 --- a/integration/db/query_test.go +++ b/integration/db/query_test.go @@ -13,7 +13,9 @@ var heightSorted = []string{ func TestLimitToFirst(t *testing.T) { for _, tc := range []int{2, 10} { var m map[string]Dinosaur - if err := dinos.OrderByChild("height").WithLimitToFirst(tc).Get(&m); err != nil { + if err := dinos.OrderByChild("height"). + WithLimitToFirst(tc). + Get(context.Background(), &m); err != nil { t.Fatal(err) } @@ -36,7 +38,9 @@ func TestLimitToFirst(t *testing.T) { func TestLimitToLast(t *testing.T) { for _, tc := range []int{2, 10} { var m map[string]Dinosaur - if err := dinos.OrderByChild("height").WithLimitToLast(tc).Get(&m); err != nil { + if err := dinos.OrderByChild("height"). + WithLimitToLast(tc). + Get(context.Background(), &m); err != nil { t.Fatal(err) } @@ -58,7 +62,9 @@ func TestLimitToLast(t *testing.T) { func TestStartAt(t *testing.T) { var m map[string]Dinosaur - if err := dinos.OrderByChild("height").WithStartAt(3.5).Get(&m); err != nil { + if err := dinos.OrderByChild("height"). + WithStartAt(3.5). + Get(context.Background(), &m); err != nil { t.Fatal(err) } @@ -75,7 +81,9 @@ func TestStartAt(t *testing.T) { func TestEndAt(t *testing.T) { var m map[string]Dinosaur - if err := dinos.OrderByChild("height").WithEndAt(3.5).Get(&m); err != nil { + if err := dinos.OrderByChild("height"). + WithEndAt(3.5). + Get(context.Background(), &m); err != nil { t.Fatal(err) } @@ -92,7 +100,10 @@ func TestEndAt(t *testing.T) { func TestStartAndEndAt(t *testing.T) { var m map[string]Dinosaur - if err := dinos.OrderByChild("height").WithStartAt(2.5).WithEndAt(5).Get(&m); err != nil { + if err := dinos.OrderByChild("height"). + WithStartAt(2.5). + WithEndAt(5). + Get(context.Background(), &m); err != nil { t.Fatal(err) } @@ -109,7 +120,9 @@ func TestStartAndEndAt(t *testing.T) { func TestEqualTo(t *testing.T) { var m map[string]Dinosaur - if err := dinos.OrderByChild("height").WithEqualTo(0.6).Get(&m); err != nil { + if err := dinos.OrderByChild("height"). + WithEqualTo(0.6). + Get(context.Background(), &m); err != nil { t.Fatal(err) } @@ -126,7 +139,9 @@ func TestEqualTo(t *testing.T) { func TestOrderByNestedChild(t *testing.T) { var m map[string]Dinosaur - if err := dinos.OrderByChild("ratings/pos").WithStartAt(4).Get(&m); err != nil { + if err := dinos.OrderByChild("ratings/pos"). + WithStartAt(4). + Get(context.Background(), &m); err != nil { t.Fatal(err) } @@ -143,7 +158,7 @@ func TestOrderByNestedChild(t *testing.T) { func TestOrderByKey(t *testing.T) { var m map[string]Dinosaur - if err := dinos.OrderByKey().WithLimitToFirst(2).Get(&m); err != nil { + if err := dinos.OrderByKey().WithLimitToFirst(2).Get(context.Background(), &m); err != nil { t.Fatal(err) } @@ -161,7 +176,7 @@ func TestOrderByKey(t *testing.T) { func TestOrderByValue(t *testing.T) { scores := ref.Child("scores") var m map[string]int - if err := scores.OrderByValue().WithLimitToLast(2).Get(&m); err != nil { + if err := scores.OrderByValue().WithLimitToLast(2).Get(context.Background(), &m); err != nil { t.Fatal(err) } @@ -178,9 +193,9 @@ func TestOrderByValue(t *testing.T) { func TestQueryWithContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - q := dinos.OrderByKey().WithLimitToFirst(2).WithContext(ctx) + q := dinos.OrderByKey().WithLimitToFirst(2) var m map[string]Dinosaur - if err := q.Get(&m); err != nil { + if err := q.Get(ctx, &m); err != nil { t.Fatal(err) } @@ -196,7 +211,7 @@ func TestQueryWithContext(t *testing.T) { cancel() m = nil - if err := q.Get(&m); len(m) != 0 || err == nil { + if err := q.Get(ctx, &m); len(m) != 0 || err == nil { t.Errorf("Get() = (%v, %v); want = (empty, error)", m, err) } } From 6c1d6eec8fbe3009a3170bd132c546b6483025d6 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 30 Oct 2017 18:57:38 -0700 Subject: [PATCH 33/58] Using the old ctx import --- internal/http_client.go | 3 ++- internal/http_client_test.go | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/internal/http_client.go b/internal/http_client.go index df1805ba..a3aec901 100644 --- a/internal/http_client.go +++ b/internal/http_client.go @@ -16,12 +16,13 @@ package internal import ( "bytes" - "context" "encoding/json" "fmt" "io" "io/ioutil" "net/http" + + "golang.org/x/net/context" ) // Null represents JSON null value. diff --git a/internal/http_client_test.go b/internal/http_client_test.go index 03a16e73..0dd483d7 100644 --- a/internal/http_client_test.go +++ b/internal/http_client_test.go @@ -14,13 +14,14 @@ package internal import ( - "context" "encoding/json" "io/ioutil" "net/http" "net/http/httptest" "reflect" "testing" + + "golang.org/x/net/context" ) var cases = []struct { From 6ddf0f6daea6f017a612a759f5a59dcdfbc0cb58 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Mon, 30 Oct 2017 21:22:32 -0700 Subject: [PATCH 34/58] Using the old context import --- auth/auth_std.go | 2 +- db/auth_override_test.go | 3 ++- integration/db/query_test.go | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/auth/auth_std.go b/auth/auth_std.go index f593a7cc..2055af38 100644 --- a/auth/auth_std.go +++ b/auth/auth_std.go @@ -16,7 +16,7 @@ package auth -import "context" +import "golang.org/x/net/context" func newSigner(ctx context.Context) (signer, error) { return serviceAcctSigner{}, nil diff --git a/db/auth_override_test.go b/db/auth_override_test.go index c696acd8..f7202a72 100644 --- a/db/auth_override_test.go +++ b/db/auth_override_test.go @@ -15,8 +15,9 @@ package db import ( - "context" "testing" + + "golang.org/x/net/context" ) func TestAuthOverrideGet(t *testing.T) { diff --git a/integration/db/query_test.go b/integration/db/query_test.go index 8a6808a6..37f4141f 100644 --- a/integration/db/query_test.go +++ b/integration/db/query_test.go @@ -1,8 +1,9 @@ package db import ( - "context" "testing" + + "golang.org/x/net/context" ) var heightSorted = []string{ From 6bcb4896f1e38066571d0b6956602ee6fe240df5 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Mon, 30 Oct 2017 21:53:46 -0700 Subject: [PATCH 35/58] Refactored db code --- db/db.go | 17 +++++++-------- db/query.go | 63 +++++++++++++++++++++++++++++++---------------------- db/ref.go | 15 ++++++++++--- 3 files changed, 57 insertions(+), 38 deletions(-) diff --git a/db/db.go b/db/db.go index 90579df3..6561675f 100644 --- a/db/db.go +++ b/db/db.go @@ -117,24 +117,23 @@ func (c *Client) NewRef(path string) *Ref { } } -func (c *Client) send( - ctx context.Context, method, path string, body interface{}, - opts ...internal.HTTPOption) (*internal.Response, error) { +func (c *Client) newRequest( + method, path string, + body interface{}, + opts ...internal.HTTPOption) (*internal.Request, error) { + if strings.ContainsAny(path, invalidChars) { return nil, fmt.Errorf("invalid path with illegal characters: %q", path) } - if c.ao != "" { opts = append(opts, internal.WithQueryParam(authVarOverride, c.ao)) } - url := fmt.Sprintf("%s%s.json", c.url, path) - req := &internal.Request{ + return &internal.Request{ Method: method, - URL: url, + URL: fmt.Sprintf("%s%s.json", c.url, path), Body: body, Opts: opts, - } - return c.hc.Do(ctx, req) + }, nil } func parsePath(path string) []string { diff --git a/db/query.go b/db/query.go index 6b4245c7..2890fbd1 100644 --- a/db/query.go +++ b/db/query.go @@ -95,37 +95,15 @@ func (q *Query) WithLimitToLast(n int) *Query { // Results will not be stored in any particular order in v. func (q *Query) Get(ctx context.Context, v interface{}) error { qp := make(map[string]string) - ob, err := q.ob.encode() - if err != nil { + if err := initQueryParams(q, qp); err != nil { return err } - qp["orderBy"] = ob - if q.limFirst > 0 && q.limLast > 0 { - return fmt.Errorf("cannot set both limit parameter: first = %d, last = %d", q.limFirst, q.limLast) - } else if q.limFirst < 0 { - return fmt.Errorf("limit first cannot be negative: %d", q.limFirst) - } else if q.limLast < 0 { - return fmt.Errorf("limit last cannot be negative: %d", q.limLast) - } - - if q.limFirst > 0 { - qp["limitToFirst"] = strconv.Itoa(q.limFirst) - } else if q.limLast > 0 { - qp["limitToLast"] = strconv.Itoa(q.limLast) - } - - if err := encodeFilter("startAt", q.start, qp); err != nil { - return err - } - if err := encodeFilter("endAt", q.end, qp); err != nil { - return err - } - if err := encodeFilter("equalTo", q.equalTo, qp); err != nil { + req, err := q.client.newRequest("GET", q.path, nil, internal.WithQueryParams(qp)) + if err != nil { return err } - - resp, err := q.client.send(ctx, "GET", q.path, nil, internal.WithQueryParams(qp)) + resp, err := q.client.hc.Do(ctx, req) if err != nil { return err } @@ -167,6 +145,39 @@ func newQuery(r *Ref, ob orderBy) *Query { } } +func initQueryParams(q *Query, qp map[string]string) error { + ob, err := q.ob.encode() + if err != nil { + return err + } + qp["orderBy"] = ob + + if q.limFirst > 0 && q.limLast > 0 { + return fmt.Errorf("cannot set both limit parameter: first = %d, last = %d", q.limFirst, q.limLast) + } else if q.limFirst < 0 { + return fmt.Errorf("limit first cannot be negative: %d", q.limFirst) + } else if q.limLast < 0 { + return fmt.Errorf("limit last cannot be negative: %d", q.limLast) + } + + if q.limFirst > 0 { + qp["limitToFirst"] = strconv.Itoa(q.limFirst) + } else if q.limLast > 0 { + qp["limitToLast"] = strconv.Itoa(q.limLast) + } + + if err := encodeFilter("startAt", q.start, qp); err != nil { + return err + } + if err := encodeFilter("endAt", q.end, qp); err != nil { + return err + } + if err := encodeFilter("equalTo", q.equalTo, qp); err != nil { + return err + } + return nil +} + func encodeFilter(key string, val interface{}, m map[string]string) error { if val == nil { return nil diff --git a/db/ref.go b/db/ref.go index 9bd45826..5a2a15f9 100644 --- a/db/ref.go +++ b/db/ref.go @@ -210,13 +210,22 @@ func (r *Ref) Delete(ctx context.Context) error { } func (r *Ref) send( - ctx context.Context, method string, + ctx context.Context, + method string, opts ...internal.HTTPOption) (*internal.Response, error) { + return r.sendWithBody(ctx, method, nil, opts...) } func (r *Ref) sendWithBody( - ctx context.Context, method string, body interface{}, + ctx context.Context, + method string, + body interface{}, opts ...internal.HTTPOption) (*internal.Response, error) { - return r.client.send(ctx, method, r.Path, body, opts...) + + req, err := r.client.newRequest(method, r.Path, body, opts...) + if err != nil { + return nil, err + } + return r.client.hc.Do(ctx, req) } From 9c5dc029896602a29a0bec6a7b94808fb5c8c715 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Mon, 30 Oct 2017 22:43:47 -0700 Subject: [PATCH 36/58] More refactoring --- db/db.go | 45 +++++++++++++++++++++++++++++---------------- db/query.go | 9 +++++---- db/ref.go | 10 ++++++---- 3 files changed, 40 insertions(+), 24 deletions(-) diff --git a/db/db.go b/db/db.go index 6561675f..c9ebb25f 100644 --- a/db/db.go +++ b/db/db.go @@ -117,23 +117,12 @@ func (c *Client) NewRef(path string) *Ref { } } -func (c *Client) newRequest( - method, path string, - body interface{}, - opts ...internal.HTTPOption) (*internal.Request, error) { - - if strings.ContainsAny(path, invalidChars) { - return nil, fmt.Errorf("invalid path with illegal characters: %q", path) - } - if c.ao != "" { - opts = append(opts, internal.WithQueryParam(authVarOverride, c.ao)) +func (c *Client) send(ctx context.Context, r *request) (*internal.Response, error) { + req, err := r.NewInternalRequest(c) + if err != nil { + return nil, err } - return &internal.Request{ - Method: method, - URL: fmt.Sprintf("%s%s.json", c.url, path), - Body: body, - Opts: opts, - }, nil + return c.hc.Do(ctx, req) } func parsePath(path string) []string { @@ -145,3 +134,27 @@ func parsePath(path string) []string { } return segs } + +type request struct { + Method, Path string + Body interface{} + Opts []internal.HTTPOption +} + +func (r *request) NewInternalRequest(c *Client) (*internal.Request, error) { + if strings.ContainsAny(r.Path, invalidChars) { + return nil, fmt.Errorf("invalid path with illegal characters: %q", r.Path) + } + + var opts []internal.HTTPOption + opts = append(opts, r.Opts...) + if c.ao != "" { + opts = append(opts, internal.WithQueryParam(authVarOverride, c.ao)) + } + return &internal.Request{ + Method: r.Method, + URL: fmt.Sprintf("%s%s.json", c.url, r.Path), + Body: r.Body, + Opts: opts, + }, nil +} diff --git a/db/query.go b/db/query.go index 2890fbd1..dfacb117 100644 --- a/db/query.go +++ b/db/query.go @@ -99,11 +99,12 @@ func (q *Query) Get(ctx context.Context, v interface{}) error { return err } - req, err := q.client.newRequest("GET", q.path, nil, internal.WithQueryParams(qp)) - if err != nil { - return err + req := &request{ + Method: "GET", + Path: q.path, + Opts: []internal.HTTPOption{internal.WithQueryParams(qp)}, } - resp, err := q.client.hc.Do(ctx, req) + resp, err := q.client.send(ctx, req) if err != nil { return err } diff --git a/db/ref.go b/db/ref.go index 5a2a15f9..050bb8ce 100644 --- a/db/ref.go +++ b/db/ref.go @@ -223,9 +223,11 @@ func (r *Ref) sendWithBody( body interface{}, opts ...internal.HTTPOption) (*internal.Response, error) { - req, err := r.client.newRequest(method, r.Path, body, opts...) - if err != nil { - return nil, err + req := &request{ + Method: method, + Path: r.Path, + Body: body, + Opts: opts, } - return r.client.hc.Do(ctx, req) + return r.client.send(ctx, req) } From 98a2e3c6763a0df5295d0428d3845af06584a96c Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 31 Oct 2017 11:09:50 -0700 Subject: [PATCH 37/58] Support for arbitrary entity types in the request --- internal/http_client.go | 40 +++++++++++++++++++++++------------- internal/http_client_test.go | 27 ++++++++---------------- 2 files changed, 35 insertions(+), 32 deletions(-) diff --git a/internal/http_client.go b/internal/http_client.go index a3aec901..0a4c4a1e 100644 --- a/internal/http_client.go +++ b/internal/http_client.go @@ -25,11 +25,6 @@ import ( "golang.org/x/net/context" ) -// Null represents JSON null value. -var Null struct{} = jsonNull{} - -type jsonNull struct{} - // HTTPClient can be used to send and receive JSON messages over HTTP. type HTTPClient struct { HC *http.Client @@ -65,7 +60,7 @@ func (c *HTTPClient) Do(ctx context.Context, r *Request) (*Response, error) { type Request struct { Method string URL string - Body interface{} + Body HTTPEntity Opts []HTTPOption } @@ -73,18 +68,12 @@ func (r *Request) newHTTPRequest() (*http.Request, error) { var opts []HTTPOption var data io.Reader if r.Body != nil { - var body interface{} - if r.Body == Null { - body = nil - } else { - body = r.Body - } - b, err := json.Marshal(body) + b, err := r.Body.Bytes() if err != nil { return nil, err } data = bytes.NewBuffer(b) - opts = append(opts, WithHeader("Content-Type", "application/json")) + opts = append(opts, WithHeader("Content-Type", r.Body.Mime())) } req, err := http.NewRequest(r.Method, r.URL, data) @@ -99,6 +88,29 @@ func (r *Request) newHTTPRequest() (*http.Request, error) { return req, nil } +// HTTPEntity represents a payload that can be included in an outgoing HTTP request. +type HTTPEntity interface { + Bytes() ([]byte, error) + Mime() string +} + +type jsonEntity struct { + Val interface{} +} + +// NewJSONEntity creates a new HTTPEntity that will be serialized into JSON. +func NewJSONEntity(v interface{}) HTTPEntity { + return &jsonEntity{Val: v} +} + +func (e *jsonEntity) Bytes() ([]byte, error) { + return json.Marshal(e.Val) +} + +func (e *jsonEntity) Mime() string { + return "application/json" +} + // Response contains information extracted from an HTTP response. type Response struct { Status int diff --git a/internal/http_client_test.go b/internal/http_client_test.go index 0dd483d7..b9f85136 100644 --- a/internal/http_client_test.go +++ b/internal/http_client_test.go @@ -27,7 +27,7 @@ import ( var cases = []struct { req *Request method string - body interface{} + body string headers map[string]string query map[string]string }{ @@ -52,7 +52,7 @@ var cases = []struct { { req: &Request{ Method: "POST", - Body: map[string]string{"foo": "bar"}, + Body: NewJSONEntity(map[string]string{"foo": "bar"}), Opts: []HTTPOption{ WithHeader("Test-Header", "value1"), WithQueryParam("testParam1", "value2"), @@ -60,35 +60,35 @@ var cases = []struct { }, }, method: "POST", - body: map[string]string{"foo": "bar"}, + body: "{\"foo\":\"bar\"}", headers: map[string]string{"Test-Header": "value1"}, query: map[string]string{"testParam1": "value2", "testParam2": "value3"}, }, { req: &Request{ Method: "POST", - Body: "body", + Body: NewJSONEntity("body"), Opts: []HTTPOption{ WithHeader("Test-Header", "value1"), WithQueryParams(map[string]string{"testParam1": "value2", "testParam2": "value3"}), }, }, method: "POST", - body: "body", + body: "\"body\"", headers: map[string]string{"Test-Header": "value1"}, query: map[string]string{"testParam1": "value2", "testParam2": "value3"}, }, { req: &Request{ Method: "PUT", - Body: Null, + Body: NewJSONEntity(nil), Opts: []HTTPOption{ WithHeader("Test-Header", "value1"), WithQueryParams(map[string]string{"testParam1": "value2", "testParam2": "value3"}), }, }, method: "PUT", - body: Null, + body: "null", headers: map[string]string{"Test-Header": "value1"}, query: map[string]string{"testParam1": "value2", "testParam2": "value3"}, }, @@ -127,21 +127,12 @@ func TestHTTPClient(t *testing.T) { t.Errorf("[%d] Query(%q) = %q; want = %q", idx, k, q, v) } } - if want.body != nil { + if want.body != "" { h := r.Header.Get("Content-Type") if h != "application/json" { t.Errorf("[%d] Content-Type = %q; want = %q", idx, h, "application/json") } - - var wb []byte - if want.body == Null { - wb = []byte("null") - } else { - wb, err = json.Marshal(want.body) - if err != nil { - t.Fatal(err) - } - } + wb := []byte(want.body) gb, err := ioutil.ReadAll(r.Body) if err != nil { t.Fatal(err) From 9cc76cbf097aebb993e9324c70f537f49cc331ab Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 31 Oct 2017 11:46:52 -0700 Subject: [PATCH 38/58] Renamed fields; Added documentation --- internal/http_client.go | 34 ++++++++++++++++++++-------------- internal/http_client_test.go | 10 +++++----- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/internal/http_client.go b/internal/http_client.go index 0a4c4a1e..1821313e 100644 --- a/internal/http_client.go +++ b/internal/http_client.go @@ -25,10 +25,16 @@ import ( "golang.org/x/net/context" ) -// HTTPClient can be used to send and receive JSON messages over HTTP. +// HTTPClient is a convenient API to make HTTP calls. +// +// This API handles some of the repetitive tasks such as entity serialization and deserialization +// involved in making HTTP calls. It provides a convenient mechanism to set headers and query +// parameters on outgoing requests, while enforcing that an explicit context is used per request. +// Responses returned by HTTPClient can be easily parsed as JSON, and provide a simple mechanism to +// extract error details. type HTTPClient struct { - HC *http.Client - EP ErrorParser + Client *http.Client + ErrParser ErrorParser } // Do executes the given Request, and returns a Response. @@ -38,7 +44,7 @@ func (c *HTTPClient) Do(ctx context.Context, r *Request) (*Response, error) { return nil, err } - resp, err := c.HC.Do(req.WithContext(ctx)) + resp, err := c.Client.Do(req.WithContext(ctx)) if err != nil { return nil, err } @@ -49,10 +55,10 @@ func (c *HTTPClient) Do(ctx context.Context, r *Request) (*Response, error) { return nil, err } return &Response{ - Status: resp.StatusCode, - Body: b, - Header: resp.Header, - ep: c.EP, + Status: resp.StatusCode, + Body: b, + Header: resp.Header, + errParser: c.ErrParser, }, nil } @@ -113,10 +119,10 @@ func (e *jsonEntity) Mime() string { // Response contains information extracted from an HTTP response. type Response struct { - Status int - Header http.Header - Body []byte - ep ErrorParser + Status int + Header http.Header + Body []byte + errParser ErrorParser } // CheckStatus checks whether the Response status code has the given HTTP status code. @@ -129,8 +135,8 @@ func (r *Response) CheckStatus(want int) error { } var msg string - if r.ep != nil { - msg = r.ep(r.Body) + if r.errParser != nil { + msg = r.errParser(r.Body) } if msg == "" { msg = string(r.Body) diff --git a/internal/http_client_test.go b/internal/http_client_test.go index b9f85136..8df2b10c 100644 --- a/internal/http_client_test.go +++ b/internal/http_client_test.go @@ -149,7 +149,7 @@ func TestHTTPClient(t *testing.T) { server := httptest.NewServer(handler) defer server.Close() - client := &HTTPClient{HC: http.DefaultClient} + client := &HTTPClient{Client: http.DefaultClient} for _, tc := range cases { tc.req.URL = server.URL resp, err := client.Do(context.Background(), tc.req) @@ -200,8 +200,8 @@ func TestErrorParser(t *testing.T) { return p.Error } client := &HTTPClient{ - HC: http.DefaultClient, - EP: ep, + Client: http.DefaultClient, + ErrParser: ep, } req := &Request{Method: "GET", URL: server.URL} resp, err := client.Do(context.Background(), req) @@ -227,7 +227,7 @@ func TestInvalidURL(t *testing.T) { Method: "GET", URL: "http://localhost:250/mock.url", } - client := &HTTPClient{HC: http.DefaultClient} + client := &HTTPClient{Client: http.DefaultClient} _, err := client.Do(context.Background(), req) if err == nil { t.Errorf("Send() = nil; want error") @@ -251,7 +251,7 @@ func TestUnmarshalError(t *testing.T) { defer server.Close() req := &Request{Method: "GET", URL: server.URL} - client := &HTTPClient{HC: http.DefaultClient} + client := &HTTPClient{Client: http.DefaultClient} resp, err := client.Do(context.Background(), req) if err != nil { t.Fatal(err) From f96fc129f3d36ad87cac952a9afd5671d8e536b7 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 31 Oct 2017 14:18:28 -0700 Subject: [PATCH 39/58] Removing a redundant else case --- internal/http_client.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/http_client.go b/internal/http_client.go index 1821313e..9216ac9c 100644 --- a/internal/http_client.go +++ b/internal/http_client.go @@ -152,7 +152,8 @@ func (r *Response) CheckStatus(want int) error { func (r *Response) Unmarshal(want int, v interface{}) error { if err := r.CheckStatus(want); err != nil { return err - } else if err := json.Unmarshal(r.Body, v); err != nil { + } + if err := json.Unmarshal(r.Body, v); err != nil { return err } return nil From be94fad5947b72b4e088f24621f424f0ca573ac1 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 31 Oct 2017 16:51:42 -0700 Subject: [PATCH 40/58] Code readability improvements --- db/db.go | 21 ++++++++++++--------- db/query.go | 2 +- db/ref.go | 31 ++++++++++++++++++++----------- 3 files changed, 33 insertions(+), 21 deletions(-) diff --git a/db/db.go b/db/db.go index c9ebb25f..af7cb0dd 100644 --- a/db/db.go +++ b/db/db.go @@ -71,7 +71,7 @@ func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) } } - errParser := func(b []byte) string { + ep := func(b []byte) string { var p struct { Error string `json:"error"` } @@ -81,7 +81,7 @@ func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) return p.Error } return &Client{ - hc: &internal.HTTPClient{HC: hc, EP: errParser}, + hc: &internal.HTTPClient{Client: hc, ErrParser: ep}, url: fmt.Sprintf("https://%s", p.Host), ao: string(ao), }, nil @@ -117,8 +117,8 @@ func (c *Client) NewRef(path string) *Ref { } } -func (c *Client) send(ctx context.Context, r *request) (*internal.Response, error) { - req, err := r.NewInternalRequest(c) +func (c *Client) send(ctx context.Context, r *dbReq) (*internal.Response, error) { + req, err := r.NewHTTPRequest(c) if err != nil { return nil, err } @@ -135,26 +135,29 @@ func parsePath(path string) []string { return segs } -type request struct { +type dbReq struct { Method, Path string Body interface{} Opts []internal.HTTPOption } -func (r *request) NewInternalRequest(c *Client) (*internal.Request, error) { +func (r *dbReq) NewHTTPRequest(c *Client) (*internal.Request, error) { if strings.ContainsAny(r.Path, invalidChars) { return nil, fmt.Errorf("invalid path with illegal characters: %q", r.Path) } var opts []internal.HTTPOption - opts = append(opts, r.Opts...) if c.ao != "" { opts = append(opts, internal.WithQueryParam(authVarOverride, c.ao)) } + var b internal.HTTPEntity + if r.Body != nil { + b = internal.NewJSONEntity(r.Body) + } return &internal.Request{ Method: r.Method, URL: fmt.Sprintf("%s%s.json", c.url, r.Path), - Body: r.Body, - Opts: opts, + Body: b, + Opts: append(opts, r.Opts...), }, nil } diff --git a/db/query.go b/db/query.go index dfacb117..4ca73bde 100644 --- a/db/query.go +++ b/db/query.go @@ -99,7 +99,7 @@ func (q *Query) Get(ctx context.Context, v interface{}) error { return err } - req := &request{ + req := &dbReq{ Method: "GET", Path: q.path, Opts: []internal.HTTPOption{internal.WithQueryParams(qp)}, diff --git a/db/ref.go b/db/ref.go index 050bb8ce..554e156d 100644 --- a/db/ref.go +++ b/db/ref.go @@ -89,12 +89,17 @@ func (r *Ref) GetIfChanged(ctx context.Context, etag string, v interface{}) (boo resp, err := r.send(ctx, "GET", internal.WithHeader("If-None-Match", etag)) if err != nil { return false, "", err - } else if err := resp.Unmarshal(http.StatusOK, v); err == nil { - return true, resp.Header.Get("ETag"), nil - } else if err := resp.CheckStatus(http.StatusNotModified); err != nil { + } + if resp.Status == http.StatusNotModified { + return false, etag, nil + } + if err := resp.CheckStatus(http.StatusOK); err != nil { + return false, "", err + } + if err := resp.Unmarshal(http.StatusOK, v); err != nil { return false, "", err } - return false, etag, nil + return true, resp.Header.Get("ETag"), nil } // Set stores the value v in the current database node. @@ -118,12 +123,14 @@ func (r *Ref) SetIfUnchanged(ctx context.Context, etag string, v interface{}) (b resp, err := r.sendWithBody(ctx, "PUT", v, internal.WithHeader("If-Match", etag)) if err != nil { return false, err - } else if err := resp.CheckStatus(http.StatusOK); err == nil { - return true, nil - } else if err := resp.CheckStatus(http.StatusPreconditionFailed); err != nil { + } + if resp.Status == http.StatusPreconditionFailed { + return false, nil + } + if err := resp.CheckStatus(http.StatusOK); err != nil { return false, err } - return false, nil + return true, nil } // Push creates a new child node at the current location, and returns a reference to it. @@ -190,9 +197,11 @@ func (r *Ref) Transaction(ctx context.Context, fn UpdateFn) error { resp, err := r.sendWithBody(ctx, "PUT", new, internal.WithHeader("If-Match", etag)) if err != nil { return err - } else if err := resp.CheckStatus(http.StatusOK); err == nil { + } + if err := resp.CheckStatus(http.StatusOK); err == nil { return nil - } else if err := resp.Unmarshal(http.StatusPreconditionFailed, &curr); err != nil { + } + if err := resp.Unmarshal(http.StatusPreconditionFailed, &curr); err != nil { return err } etag = resp.Header.Get("ETag") @@ -223,7 +232,7 @@ func (r *Ref) sendWithBody( body interface{}, opts ...internal.HTTPOption) (*internal.Response, error) { - req := &request{ + req := &dbReq{ Method: method, Path: r.Path, Body: body, From f08b19164e5e837f50f0ccd8331af3f2448050c4 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 31 Oct 2017 17:17:28 -0700 Subject: [PATCH 41/58] Cleaned up the RTDB HTTP client code --- db/db.go | 49 +++++++++++++++++-------------------------------- db/query.go | 8 +------- db/ref.go | 16 ++++------------ 3 files changed, 22 insertions(+), 51 deletions(-) diff --git a/db/db.go b/db/db.go index af7cb0dd..a7aaeb7d 100644 --- a/db/db.go +++ b/db/db.go @@ -117,12 +117,24 @@ func (c *Client) NewRef(path string) *Ref { } } -func (c *Client) send(ctx context.Context, r *dbReq) (*internal.Response, error) { - req, err := r.NewHTTPRequest(c) - if err != nil { - return nil, err +func (c *Client) send( + ctx context.Context, + method, path string, + body internal.HTTPEntity, + opts ...internal.HTTPOption) (*internal.Response, error) { + + if strings.ContainsAny(path, invalidChars) { + return nil, fmt.Errorf("invalid path with illegal characters: %q", path) + } + if c.ao != "" { + opts = append(opts, internal.WithQueryParam(authVarOverride, c.ao)) } - return c.hc.Do(ctx, req) + return c.hc.Do(ctx, &internal.Request{ + Method: method, + URL: fmt.Sprintf("%s%s.json", c.url, path), + Body: body, + Opts: opts, + }) } func parsePath(path string) []string { @@ -134,30 +146,3 @@ func parsePath(path string) []string { } return segs } - -type dbReq struct { - Method, Path string - Body interface{} - Opts []internal.HTTPOption -} - -func (r *dbReq) NewHTTPRequest(c *Client) (*internal.Request, error) { - if strings.ContainsAny(r.Path, invalidChars) { - return nil, fmt.Errorf("invalid path with illegal characters: %q", r.Path) - } - - var opts []internal.HTTPOption - if c.ao != "" { - opts = append(opts, internal.WithQueryParam(authVarOverride, c.ao)) - } - var b internal.HTTPEntity - if r.Body != nil { - b = internal.NewJSONEntity(r.Body) - } - return &internal.Request{ - Method: r.Method, - URL: fmt.Sprintf("%s%s.json", c.url, r.Path), - Body: b, - Opts: append(opts, r.Opts...), - }, nil -} diff --git a/db/query.go b/db/query.go index 4ca73bde..ab31aaa1 100644 --- a/db/query.go +++ b/db/query.go @@ -98,13 +98,7 @@ func (q *Query) Get(ctx context.Context, v interface{}) error { if err := initQueryParams(q, qp); err != nil { return err } - - req := &dbReq{ - Method: "GET", - Path: q.path, - Opts: []internal.HTTPOption{internal.WithQueryParams(qp)}, - } - resp, err := q.client.send(ctx, req) + resp, err := q.client.send(ctx, "GET", q.path, nil, internal.WithQueryParams(qp)) if err != nil { return err } diff --git a/db/ref.go b/db/ref.go index 554e156d..3936ed01 100644 --- a/db/ref.go +++ b/db/ref.go @@ -93,9 +93,6 @@ func (r *Ref) GetIfChanged(ctx context.Context, etag string, v interface{}) (boo if resp.Status == http.StatusNotModified { return false, etag, nil } - if err := resp.CheckStatus(http.StatusOK); err != nil { - return false, "", err - } if err := resp.Unmarshal(http.StatusOK, v); err != nil { return false, "", err } @@ -166,6 +163,7 @@ func (r *Ref) Update(ctx context.Context, v map[string]interface{}) error { return resp.CheckStatus(http.StatusNoContent) } +// UpdateFn represents a function type that can be passed into Transaction(). type UpdateFn func(interface{}) (interface{}, error) // Transaction atomically modifies the data at this location. @@ -198,7 +196,7 @@ func (r *Ref) Transaction(ctx context.Context, fn UpdateFn) error { if err != nil { return err } - if err := resp.CheckStatus(http.StatusOK); err == nil { + if resp.Status == http.StatusOK { return nil } if err := resp.Unmarshal(http.StatusPreconditionFailed, &curr); err != nil { @@ -223,7 +221,7 @@ func (r *Ref) send( method string, opts ...internal.HTTPOption) (*internal.Response, error) { - return r.sendWithBody(ctx, method, nil, opts...) + return r.client.send(ctx, method, r.Path, nil, opts...) } func (r *Ref) sendWithBody( @@ -232,11 +230,5 @@ func (r *Ref) sendWithBody( body interface{}, opts ...internal.HTTPOption) (*internal.Response, error) { - req := &dbReq{ - Method: method, - Path: r.Path, - Body: body, - Opts: opts, - } - return r.client.send(ctx, req) + return r.client.send(ctx, method, r.Path, internal.NewJSONEntity(body), opts...) } From 681a529dc1915a1a1edbee2d3e8238361aebfef0 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Tue, 13 Feb 2018 13:20:43 -0800 Subject: [PATCH 42/58] Added shallow reads support; Added the new txn API --- db/ref.go | 40 +++++++++++++++---- db/ref_test.go | 84 +++++++++++++++++++++++++++++++-------- firebase.go | 8 ++-- integration/db/db_test.go | 34 ++++++++++++---- 4 files changed, 131 insertions(+), 35 deletions(-) diff --git a/db/ref.go b/db/ref.go index 3936ed01..bf3bb707 100644 --- a/db/ref.go +++ b/db/ref.go @@ -15,6 +15,7 @@ package db import ( + "encoding/json" "fmt" "net/http" "strings" @@ -35,6 +36,19 @@ type Ref struct { client *Client } +// TransactionNode represents the value of a node within the scope of a transaction. +type TransactionNode interface { + Unmarshal(v interface{}) error +} + +type transactionNodeImpl struct { + Raw []byte +} + +func (t *transactionNodeImpl) Unmarshal(v interface{}) error { + return json.Unmarshal(t.Raw, v) +} + // Parent returns a reference to the parent of the current node. // // If the current reference points to the root of the database, Parent returns nil. @@ -78,6 +92,17 @@ func (r *Ref) GetWithETag(ctx context.Context, v interface{}) (string, error) { return resp.Header.Get("Etag"), nil } +// GetShallow performs a shallow read on the current database location. +// +// Shallow reads do not retrieve the child nodes of the current reference. +func (r *Ref) GetShallow(ctx context.Context, v interface{}) error { + resp, err := r.send(ctx, "GET", internal.WithQueryParam("shallow", "true")) + if err != nil { + return err + } + return resp.Unmarshal(http.StatusOK, v) +} + // GetIfChanged retrieves the value and ETag of the current database location only if the specified // ETag does not match. // @@ -164,7 +189,7 @@ func (r *Ref) Update(ctx context.Context, v map[string]interface{}) error { } // UpdateFn represents a function type that can be passed into Transaction(). -type UpdateFn func(interface{}) (interface{}, error) +type UpdateFn func(TransactionNode) (interface{}, error) // Transaction atomically modifies the data at this location. // @@ -181,25 +206,26 @@ type UpdateFn func(interface{}) (interface{}, error) // The update function may also force an early abort by returning an error instead of returning a // value. func (r *Ref) Transaction(ctx context.Context, fn UpdateFn) error { - var curr interface{} - etag, err := r.GetWithETag(ctx, &curr) + resp, err := r.send(ctx, "GET", internal.WithHeader("X-Firebase-ETag", "true")) if err != nil { return err + } else if err := resp.CheckStatus(http.StatusOK); err != nil { + return err } + etag := resp.Header.Get("Etag") for i := 0; i < txnRetries; i++ { - new, err := fn(curr) + new, err := fn(&transactionNodeImpl{resp.Body}) if err != nil { return err } - resp, err := r.sendWithBody(ctx, "PUT", new, internal.WithHeader("If-Match", etag)) + resp, err = r.sendWithBody(ctx, "PUT", new, internal.WithHeader("If-Match", etag)) if err != nil { return err } if resp.Status == http.StatusOK { return nil - } - if err := resp.Unmarshal(http.StatusPreconditionFailed, &curr); err != nil { + } else if err := resp.CheckStatus(http.StatusPreconditionFailed); err != nil { return err } etag = resp.Header.Get("ETag") diff --git a/db/ref_test.go b/db/ref_test.go index 685a8e47..52a6e66d 100644 --- a/db/ref_test.go +++ b/db/ref_test.go @@ -47,6 +47,14 @@ var testOps = []struct { return err }, }, + { + "GetShallow()", + "test", + func(r *Ref) error { + var got string + return r.GetShallow(context.Background(), &got) + }, + }, { "GetIfChanged()", "test", @@ -97,7 +105,11 @@ var testOps = []struct { "Transaction()", nil, func(r *Ref) error { - fn := func(v interface{}) (interface{}, error) { + fn := func(t TransactionNode) (interface{}, error) { + var v interface{} + if err := t.Unmarshal(&v); err != nil { + return nil, err + } return v, nil } return r.Transaction(context.Background(), fn) @@ -158,6 +170,32 @@ func TestGetWithStruct(t *testing.T) { checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) } +func TestGetShallow(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + cases := []interface{}{ + nil, float64(1), true, "foo", + map[string]interface{}{"name": "Peter Parker", "age": float64(17)}, + map[string]interface{}{"name": "Peter Parker", "nestedChild": true}, + } + wantQuery := map[string]string{"shallow": "true"} + var want []*testReq + for _, tc := range cases { + mock.Resp = tc + var got interface{} + if err := testref.GetShallow(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tc, got) { + t.Errorf("Get() = %v; want = %v", got, tc) + } + want = append(want, &testReq{Method: "GET", Path: "/peter.json", Query: wantQuery}) + } + checkAllRequests(t, mock.Reqs, want) +} + func TestGetWithETag(t *testing.T) { want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} mock := &mockServer{ @@ -492,10 +530,13 @@ func TestTransaction(t *testing.T) { srv := mock.Start(client) defer srv.Close() - var fn UpdateFn = func(i interface{}) (interface{}, error) { - p := i.(map[string]interface{}) - p["age"] = p["age"].(float64) + 1.0 - return p, nil + var fn UpdateFn = func(t TransactionNode) (interface{}, error) { + var p person + if err := t.Unmarshal(&p); err != nil { + return nil, err + } + p.Age++ + return &p, nil } if err := testref.Transaction(context.Background(), fn); err != nil { t.Fatal(err) @@ -527,7 +568,7 @@ func TestTransactionRetry(t *testing.T) { defer srv.Close() cnt := 0 - var fn UpdateFn = func(i interface{}) (interface{}, error) { + var fn UpdateFn = func(t TransactionNode) (interface{}, error) { if cnt == 0 { mock.Status = http.StatusPreconditionFailed mock.Header = map[string]string{"ETag": "mock-etag2"} @@ -536,9 +577,12 @@ func TestTransactionRetry(t *testing.T) { mock.Status = http.StatusOK } cnt++ - p := i.(map[string]interface{}) - p["age"] = p["age"].(float64) + 1.0 - return p, nil + var p person + if err := t.Unmarshal(&p); err != nil { + return nil, err + } + p.Age++ + return &p, nil } if err := testref.Transaction(context.Background(), fn); err != nil { t.Fatal(err) @@ -583,7 +627,7 @@ func TestTransactionError(t *testing.T) { cnt := 0 want := "user error" - var fn UpdateFn = func(i interface{}) (interface{}, error) { + var fn UpdateFn = func(t TransactionNode) (interface{}, error) { if cnt == 0 { mock.Status = http.StatusPreconditionFailed mock.Header = map[string]string{"ETag": "mock-etag2"} @@ -592,9 +636,12 @@ func TestTransactionError(t *testing.T) { return nil, fmt.Errorf(want) } cnt++ - p := i.(map[string]interface{}) - p["age"] = p["age"].(float64) + 1.0 - return p, nil + var p person + if err := t.Unmarshal(&p); err != nil { + return nil, err + } + p.Age++ + return &p, nil } if err := testref.Transaction(context.Background(), fn); err == nil || err.Error() != want { t.Errorf("Transaction() = %v; want = %q", err, want) @@ -629,15 +676,18 @@ func TestTransactionAbort(t *testing.T) { defer srv.Close() cnt := 0 - var fn UpdateFn = func(i interface{}) (interface{}, error) { + var fn UpdateFn = func(t TransactionNode) (interface{}, error) { if cnt == 0 { mock.Status = http.StatusPreconditionFailed mock.Header = map[string]string{"ETag": "mock-etag1"} } cnt++ - p := i.(map[string]interface{}) - p["age"] = p["age"].(float64) + 1.0 - return p, nil + var p person + if err := t.Unmarshal(&p); err != nil { + return nil, err + } + p.Age++ + return &p, nil } err := testref.Transaction(context.Background(), fn) if err == nil { diff --git a/firebase.go b/firebase.go index 44da937e..8419f86a 100644 --- a/firebase.go +++ b/firebase.go @@ -66,10 +66,10 @@ type App struct { // Config represents the configuration used to initialize an App. type Config struct { - AuthOverride *db.AuthOverride - DatabaseURL string `json:"databaseURL"` - ProjectID string `json:"projectId"` - StorageBucket string `json:"storageBucket"` + AuthOverride *db.AuthOverride `json:"databaseAuthVariableOverride"` + DatabaseURL string `json:"databaseURL"` + ProjectID string `json:"projectId"` + StorageBucket string `json:"storageBucket"` } // Auth returns an instance of auth.Client. diff --git a/integration/db/db_test.go b/integration/db/db_test.go index 65eb3f8b..902fba16 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -220,6 +220,20 @@ func TestGetWithETag(t *testing.T) { } } +func TestGetShallow(t *testing.T) { + var m map[string]interface{} + if err := ref.GetShallow(context.Background(), &m); err != nil { + t.Fatal(err) + } + want := map[string]interface{}{} + for k := range testData { + want[k] = true + } + if !reflect.DeepEqual(want, m) { + t.Errorf("GetShallow() = %v; want = %v", m, want) + } +} + func TestGetIfChanged(t *testing.T) { var m map[string]interface{} ok, etag, err := ref.GetIfChanged(context.Background(), "wrong-etag", &m) @@ -480,11 +494,14 @@ func TestTransaction(t *testing.T) { if err != nil { t.Fatal(err) } - fn := func(curr interface{}) (interface{}, error) { - snap := curr.(map[string]interface{}) - snap["name"] = "Richard Owen" - snap["since"] = 1804 - return snap, nil + fn := func(t db.TransactionNode) (interface{}, error) { + var user User + if err := t.Unmarshal(&user); err != nil { + return nil, err + } + user.Name = "Richard Owen" + user.Since = 1804 + return &user, nil } if err := u.Transaction(context.Background(), fn); err != nil { t.Fatal(err) @@ -504,8 +521,11 @@ func TestTransactionScalar(t *testing.T) { if err := cnt.Set(context.Background(), 42); err != nil { t.Fatal(err) } - fn := func(curr interface{}) (interface{}, error) { - snap := curr.(float64) + fn := func(t db.TransactionNode) (interface{}, error) { + var snap float64 + if err := t.Unmarshal(&snap); err != nil { + return nil, err + } return snap + 1, nil } if err := cnt.Transaction(context.Background(), fn); err != nil { From 0dac62f84ba55aaf4f6638f994fd23dbb0fd58e8 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Tue, 13 Feb 2018 17:33:27 -0800 Subject: [PATCH 43/58] Implementing GetOrdered() for queries --- db/query.go | 195 ++++++++++++++++++++++++++++++++++++++++++++++- db/query_test.go | 74 ++++++++++++++++++ 2 files changed, 265 insertions(+), 4 deletions(-) diff --git a/db/query.go b/db/query.go index ab31aaa1..f76badf6 100644 --- a/db/query.go +++ b/db/query.go @@ -18,6 +18,8 @@ import ( "encoding/json" "fmt" "net/http" + "reflect" + "sort" "strconv" "strings" @@ -105,6 +107,40 @@ func (q *Query) Get(ctx context.Context, v interface{}) error { return resp.Unmarshal(http.StatusOK, v) } +// GetOrdered executes the Query and provides the results as an ordered list. +// +// v must be a pointer to an array or a slice. +func (q *Query) GetOrdered(ctx context.Context, v interface{}) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr { + return fmt.Errorf("value must be a pointer") + } + if rv.Elem().Kind() != reflect.Slice && rv.Elem().Kind() != reflect.Array { + return fmt.Errorf("value must be a pointer to an array or a slice") + } + + var temp interface{} + if err := q.Get(ctx, &temp); err != nil { + return err + } + + sr, err := newSortableResult(temp, q.ob) + if err != nil { + return err + } + sort.Sort(sr) + + var values []interface{} + for _, val := range sr { + values = append(values, val.Value) + } + b, err := json.Marshal(values) + if err != nil { + return err + } + return json.Unmarshal(b, v) +} + // OrderByChild returns a Query that orders data by child values before applying filters. // // Returned Query can be used to set additional parameters, and execute complex database queries @@ -167,10 +203,7 @@ func initQueryParams(q *Query, qp map[string]string) error { if err := encodeFilter("endAt", q.end, qp); err != nil { return err } - if err := encodeFilter("equalTo", q.equalTo, qp); err != nil { - return err - } - return nil + return encodeFilter("equalTo", q.equalTo, qp) } func encodeFilter(key string, val interface{}, m map[string]string) error { @@ -217,3 +250,157 @@ func (p orderByProperty) encode() (string, error) { } return string(b), nil } + +const ( + typeNull = 0 + typeBoolFalse = 1 + typeBoolTrue = 2 + typeNumeric = 3 + typeString = 4 + typeObject = 5 +) + +type comparableKey struct { + Num *float64 + Str *string +} + +func (k *comparableKey) Val() interface{} { + if k.Str != nil { + return *k.Str + } + return *k.Num +} + +func (k *comparableKey) Compare(o *comparableKey) int { + if k.Str != nil && o.Str != nil { + return strings.Compare(*k.Str, *o.Str) + } else if k.Num != nil && o.Num != nil { + if *k.Num < *o.Num { + return -1 + } else if *k.Num == *o.Num { + return 0 + } + return 1 + } else if k.Num != nil { + return -1 + } + return 1 +} + +func newComparableKey(v interface{}) *comparableKey { + if s, ok := v.(string); ok { + return &comparableKey{Str: &s} + } + if i, ok := v.(int); ok { + f := float64(i) + return &comparableKey{Num: &f} + } + + f := v.(float64) + return &comparableKey{Num: &f} +} + +type sortableResult []*sortEntry + +func (s sortableResult) Len() int { + return len(s) +} + +func (s sortableResult) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +func (s sortableResult) Less(i, j int) bool { + a, b := s[i], s[j] + var aKey, bKey *comparableKey + if a.IndexType == b.IndexType { + if (a.IndexType == typeNumeric || a.IndexType == typeString) && a.Index != b.Index { + aKey, bKey = newComparableKey(a.Index), newComparableKey(b.Index) + } else { + aKey, bKey = newComparableKey(a.Key), newComparableKey(b.Key) + } + } else { + aKey, bKey = newComparableKey(a.IndexType), newComparableKey(b.IndexType) + } + + return aKey.Compare(bKey) < 0 +} + +func newSortableResult(values interface{}, order orderBy) (sortableResult, error) { + var entries sortableResult + if m, ok := values.(map[string]interface{}); ok { + for key, val := range m { + entries = append(entries, newSortEntry(key, val, order)) + } + } else if l, ok := values.([]interface{}); ok { + for key, val := range l { + entries = append(entries, newSortEntry(key, val, order)) + } + } else { + return nil, fmt.Errorf("sorting not supported for the result") + } + return entries, nil +} + +type sortEntry struct { + Key *comparableKey + Value interface{} + Index interface{} + IndexType int +} + +func newSortEntry(key, val interface{}, order orderBy) *sortEntry { + var index interface{} + if prop, ok := order.(orderByProperty); ok { + if prop == "$value" { + index = val + } else { + index = key + } + } else { + path := order.(orderByChild) + index = extractChildValue(val, string(path)) + } + return &sortEntry{ + Key: newComparableKey(key), + Value: val, + Index: index, + IndexType: getIndexType(index), + } +} + +func extractChildValue(val interface{}, path string) interface{} { + segments := parsePath(path) + curr := val + for _, s := range segments { + if curr == nil { + return nil + } + + currMap, ok := curr.(map[string]interface{}) + if !ok { + return nil + } + if curr, ok = currMap[s]; !ok { + return nil + } + } + return curr +} + +func getIndexType(index interface{}) int { + if index == nil { + return typeNull + } else if b, ok := index.(bool); ok { + if b { + return typeBoolTrue + } + return typeBoolFalse + } else if _, ok := index.(float64); ok { + return typeNumeric + } else if _, ok := index.(string); ok { + return typeString + } + return typeObject +} diff --git a/db/query_test.go b/db/query_test.go index a67ff472..0ae1a753 100644 --- a/db/query_test.go +++ b/db/query_test.go @@ -14,12 +14,21 @@ package db import ( + "fmt" "reflect" "testing" "golang.org/x/net/context" ) +var sortableResp = map[string]interface{}{ + "bob": person{Name: "bob", Age: 20}, + "alice": person{Name: "alice", Age: 30}, + "charlie": person{Name: "charlie", Age: 15}, + "dave": person{Name: "dave", Age: 25}, + "ernie": person{Name: "ernie"}, +} + func TestChildQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} @@ -332,3 +341,68 @@ func TestAllParamsQuery(t *testing.T) { }, }) } + +func TestOrderedChildQuery(t *testing.T) { + mock := &mockServer{Resp: sortableResp} + srv := mock.Start(client) + defer srv.Close() + + cases := []struct { + child string + want []string + }{ + {"age", []string{"ernie", "charlie", "bob", "dave", "alice"}}, + {"name", []string{"alice", "bob", "charlie", "dave", "ernie"}}, + } + + var reqs []*testReq + for _, tc := range cases { + var result []person + if err := testref.OrderByChild(tc.child).GetOrdered(context.Background(), &result); err != nil { + t.Fatal(err) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": fmt.Sprintf("%q", tc.child)}, + }) + + var got []string + for _, r := range result { + got = append(got, r.Name) + } + if !reflect.DeepEqual(tc.want, got) { + t.Errorf("GetOrdered(child: %q) = %v; want = %v", "age", got, tc.want) + } + } + + checkAllRequests(t, mock.Reqs, reqs) +} + +func TestOrderedKeyQuery(t *testing.T) { + mock := &mockServer{Resp: sortableResp} + srv := mock.Start(client) + defer srv.Close() + + var result []person + if err := testref.OrderByKey().GetOrdered(context.Background(), &result); err != nil { + t.Fatal(err) + } + req := &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"$key\""}, + } + + var got []string + for _, r := range result { + got = append(got, r.Name) + } + + want := []string{"alice", "bob", "charlie", "dave", "ernie"} + if !reflect.DeepEqual(want, got) { + t.Errorf("GetOrdered(child: %q) = %v; want = %v", "age", got, want) + } + + checkOnlyRequest(t, mock.Reqs, req) +} From 01ad19851df70ef84c4da45035d6cff84046926d Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Tue, 13 Feb 2018 18:49:45 -0800 Subject: [PATCH 44/58] Adding more sorting tests --- db/query.go | 2 +- db/query_test.go | 80 +++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/db/query.go b/db/query.go index f76badf6..04030b73 100644 --- a/db/query.go +++ b/db/query.go @@ -318,7 +318,7 @@ func (s sortableResult) Less(i, j int) bool { if (a.IndexType == typeNumeric || a.IndexType == typeString) && a.Index != b.Index { aKey, bKey = newComparableKey(a.Index), newComparableKey(b.Index) } else { - aKey, bKey = newComparableKey(a.Key), newComparableKey(b.Key) + aKey, bKey = newComparableKey(a.Key.Val()), newComparableKey(b.Key.Val()) } } else { aKey, bKey = newComparableKey(a.IndexType), newComparableKey(b.IndexType) diff --git a/db/query_test.go b/db/query_test.go index 0ae1a753..6925f398 100644 --- a/db/query_test.go +++ b/db/query_test.go @@ -401,8 +401,86 @@ func TestOrderedKeyQuery(t *testing.T) { want := []string{"alice", "bob", "charlie", "dave", "ernie"} if !reflect.DeepEqual(want, got) { - t.Errorf("GetOrdered(child: %q) = %v; want = %v", "age", got, want) + t.Errorf("GetOrdered(key) = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, req) } + +func TestOrderedValueQuery(t *testing.T) { + cases := []struct { + resp map[string]interface{} + want []interface{} + }{ + { + resp: map[string]interface{}{"k1": 1, "k2": 2, "k3": 3}, + want: []interface{}{1.0, 2.0, 3.0}, + }, + { + resp: map[string]interface{}{"k1": 3, "k2": 2, "k3": 1}, + want: []interface{}{1.0, 2.0, 3.0}, + }, + { + resp: map[string]interface{}{"k1": 3, "k2": 1, "k3": 2}, + want: []interface{}{1.0, 2.0, 3.0}, + }, + { + resp: map[string]interface{}{"k1": 1, "k2": 2, "k3": 1}, + want: []interface{}{1.0, 1.0, 2.0}, + }, + { + resp: map[string]interface{}{"k1": 1, "k2": 1, "k3": 2}, + want: []interface{}{1.0, 1.0, 2.0}, + }, + { + resp: map[string]interface{}{"k1": 2, "k2": 1, "k3": 1}, + want: []interface{}{1.0, 1.0, 2.0}, + }, + { + resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": "baz"}, + want: []interface{}{"bar", "baz", "foo"}, + }, + { + resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": 10}, + want: []interface{}{10.0, "bar", "foo"}, + }, + { + resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": nil}, + want: []interface{}{nil, "bar", "foo"}, + }, + { + resp: map[string]interface{}{"k1": 5, "k2": "bar", "k3": nil}, + want: []interface{}{nil, 5.0, "bar"}, + }, + { + resp: map[string]interface{}{ + "k1": true, "k2": 0, "k3": "foo", "k4": "foo", "k5": false, + "k6": map[string]interface{}{"k1": true}, + }, + want: []interface{}{false, true, 0.0, "foo", "foo", map[string]interface{}{"k1": true}}, + }, + } + + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + var reqs []*testReq + for _, tc := range cases { + mock.Resp = tc.resp + + var got []interface{} + if err := testref.OrderByValue().GetOrdered(context.Background(), &got); err != nil { + t.Fatal(err) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"$value\""}, + }) + + if !reflect.DeepEqual(tc.want, got) { + t.Errorf("GetOrdered(value) = %v; want = %v", got, tc.want) + } + } +} From 3cd8ea558128cb3abf534a23a4e902ee2b6cc26b Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Tue, 13 Feb 2018 22:39:22 -0800 Subject: [PATCH 45/58] Added Query ordering tests --- db/query.go | 2 +- db/query_test.go | 229 +++++++++++++++++++++++++++++++++++++---------- 2 files changed, 181 insertions(+), 50 deletions(-) diff --git a/db/query.go b/db/query.go index 04030b73..e81a6154 100644 --- a/db/query.go +++ b/db/query.go @@ -318,7 +318,7 @@ func (s sortableResult) Less(i, j int) bool { if (a.IndexType == typeNumeric || a.IndexType == typeString) && a.Index != b.Index { aKey, bKey = newComparableKey(a.Index), newComparableKey(b.Index) } else { - aKey, bKey = newComparableKey(a.Key.Val()), newComparableKey(b.Key.Val()) + aKey, bKey = a.Key, b.Key } } else { aKey, bKey = newComparableKey(a.IndexType), newComparableKey(b.IndexType) diff --git a/db/query_test.go b/db/query_test.go index 6925f398..0fd6bc56 100644 --- a/db/query_test.go +++ b/db/query_test.go @@ -14,14 +14,13 @@ package db import ( - "fmt" "reflect" "testing" "golang.org/x/net/context" ) -var sortableResp = map[string]interface{}{ +var sortableKeysResp = map[string]interface{}{ "bob": person{Name: "bob", Age: 20}, "alice": person{Name: "alice", Age: 30}, "charlie": person{Name: "charlie", Age: 15}, @@ -29,6 +28,70 @@ var sortableResp = map[string]interface{}{ "ernie": person{Name: "ernie"}, } +var sortableValuesResp = []struct { + resp map[string]interface{} + want []interface{} +}{ + { + resp: map[string]interface{}{"k1": 1, "k2": 2, "k3": 3}, + want: []interface{}{1.0, 2.0, 3.0}, + }, + { + resp: map[string]interface{}{"k1": 3, "k2": 2, "k3": 1}, + want: []interface{}{1.0, 2.0, 3.0}, + }, + { + resp: map[string]interface{}{"k1": 3, "k2": 1, "k3": 2}, + want: []interface{}{1.0, 2.0, 3.0}, + }, + { + resp: map[string]interface{}{"k1": 1, "k2": 2, "k3": 1}, + want: []interface{}{1.0, 1.0, 2.0}, + }, + { + resp: map[string]interface{}{"k1": 1, "k2": 1, "k3": 2}, + want: []interface{}{1.0, 1.0, 2.0}, + }, + { + resp: map[string]interface{}{"k1": 2, "k2": 1, "k3": 1}, + want: []interface{}{1.0, 1.0, 2.0}, + }, + { + resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": "baz"}, + want: []interface{}{"bar", "baz", "foo"}, + }, + { + resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": 10}, + want: []interface{}{10.0, "bar", "foo"}, + }, + { + resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": nil}, + want: []interface{}{nil, "bar", "foo"}, + }, + { + resp: map[string]interface{}{"k1": 5, "k2": "bar", "k3": nil}, + want: []interface{}{nil, 5.0, "bar"}, + }, + { + resp: map[string]interface{}{ + "k1": true, "k2": 0, "k3": "foo", "k4": "foo", "k5": false, + "k6": map[string]interface{}{"k1": true}, + }, + want: []interface{}{false, true, 0.0, "foo", "foo", map[string]interface{}{"k1": true}}, + }, + { + resp: map[string]interface{}{ + "k1": true, "k2": 0, "k3": "foo", "k4": "foo", "k5": false, + "k6": map[string]interface{}{"k1": true}, "k7": nil, + "k8": map[string]interface{}{"k0": true}, + }, + want: []interface{}{ + nil, false, true, 0.0, "foo", "foo", + map[string]interface{}{"k1": true}, map[string]interface{}{"k0": true}, + }, + }, +} + func TestChildQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} @@ -342,45 +405,107 @@ func TestAllParamsQuery(t *testing.T) { }) } -func TestOrderedChildQuery(t *testing.T) { - mock := &mockServer{Resp: sortableResp} +func TestInvalidGetOrdered(t *testing.T) { + q := testref.OrderByKey() + + var i interface{} + want := "value must be a pointer" + err := q.GetOrdered(context.Background(), i) + if err == nil || err.Error() != want { + t.Errorf("GetOrdered(interface) = %v; want = %v", err, want) + } + + want = "value must be a pointer to an array or a slice" + err = q.GetOrdered(context.Background(), &i) + if err == nil || err.Error() != want { + t.Errorf("GetOrdered(interface) = %v; want = %v", err, want) + } +} + +func TestChildQueryGetOrdered(t *testing.T) { + mock := &mockServer{} srv := mock.Start(client) defer srv.Close() - cases := []struct { - child string - want []string - }{ - {"age", []string{"ernie", "charlie", "bob", "dave", "alice"}}, - {"name", []string{"alice", "bob", "charlie", "dave", "ernie"}}, + type parsedMap struct { + Child interface{} `json:"child"` } var reqs []*testReq - for _, tc := range cases { - var result []person - if err := testref.OrderByChild(tc.child).GetOrdered(context.Background(), &result); err != nil { + for _, tc := range sortableValuesResp { + resp := map[string]interface{}{} + for k, v := range tc.resp { + resp[k] = map[string]interface{}{"child": v} + } + mock.Resp = resp + + var result []parsedMap + if err := testref.OrderByChild("child").GetOrdered(context.Background(), &result); err != nil { + t.Fatal(err) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"child\""}, + }) + + var got []interface{} + for _, r := range result { + got = append(got, r.Child) + } + if !reflect.DeepEqual(tc.want, got) { + t.Errorf("GetOrdered(child: %q) = %v; want = %v", "child", got, tc.want) + } + } + + checkAllRequests(t, mock.Reqs, reqs) +} + +func TestGrandChildQueryGetOrdered(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + type grandChild struct { + GrandChild interface{} `json:"grandchild"` + } + type parsedMap struct { + Child grandChild `json:"child"` + } + + var reqs []*testReq + for _, tc := range sortableValuesResp { + resp := map[string]interface{}{} + for k, v := range tc.resp { + resp[k] = map[string]interface{}{"child": map[string]interface{}{"grandchild": v}} + } + mock.Resp = resp + + var result []parsedMap + q := testref.OrderByChild("child/grandchild") + if err := q.GetOrdered(context.Background(), &result); err != nil { t.Fatal(err) } reqs = append(reqs, &testReq{ Method: "GET", Path: "/peter.json", - Query: map[string]string{"orderBy": fmt.Sprintf("%q", tc.child)}, + Query: map[string]string{"orderBy": "\"child/grandchild\""}, }) - var got []string + var got []interface{} for _, r := range result { - got = append(got, r.Name) + got = append(got, r.Child.GrandChild) } if !reflect.DeepEqual(tc.want, got) { - t.Errorf("GetOrdered(child: %q) = %v; want = %v", "age", got, tc.want) + t.Errorf("GetOrdered(child: %q) = %v; want = %v", "child/grandchild", got, tc.want) } } checkAllRequests(t, mock.Reqs, reqs) } -func TestOrderedKeyQuery(t *testing.T) { - mock := &mockServer{Resp: sortableResp} +func TestKeyQueryOrderedGet(t *testing.T) { + mock := &mockServer{Resp: sortableKeysResp} srv := mock.Start(client) defer srv.Close() @@ -407,57 +532,63 @@ func TestOrderedKeyQuery(t *testing.T) { checkOnlyRequest(t, mock.Reqs, req) } -func TestOrderedValueQuery(t *testing.T) { +func TestValueQueryGetOrdered(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + var reqs []*testReq + for _, tc := range sortableValuesResp { + mock.Resp = tc.resp + + var got []interface{} + if err := testref.OrderByValue().GetOrdered(context.Background(), &got); err != nil { + t.Fatal(err) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"$value\""}, + }) + + if !reflect.DeepEqual(tc.want, got) { + t.Errorf("GetOrdered(value) = %v; want = %v", got, tc.want) + } + } +} + +func TestValueQueryOrderedGetWithList(t *testing.T) { cases := []struct { - resp map[string]interface{} + resp []interface{} want []interface{} }{ { - resp: map[string]interface{}{"k1": 1, "k2": 2, "k3": 3}, + resp: []interface{}{1, 2, 3}, want: []interface{}{1.0, 2.0, 3.0}, }, { - resp: map[string]interface{}{"k1": 3, "k2": 2, "k3": 1}, + resp: []interface{}{3, 2, 1}, want: []interface{}{1.0, 2.0, 3.0}, }, { - resp: map[string]interface{}{"k1": 3, "k2": 1, "k3": 2}, + resp: []interface{}{1, 3, 2}, want: []interface{}{1.0, 2.0, 3.0}, }, { - resp: map[string]interface{}{"k1": 1, "k2": 2, "k3": 1}, - want: []interface{}{1.0, 1.0, 2.0}, + resp: []interface{}{1, 3, 3}, + want: []interface{}{1.0, 3.0, 3.0}, }, { - resp: map[string]interface{}{"k1": 1, "k2": 1, "k3": 2}, + resp: []interface{}{1, 2, 1}, want: []interface{}{1.0, 1.0, 2.0}, }, { - resp: map[string]interface{}{"k1": 2, "k2": 1, "k3": 1}, - want: []interface{}{1.0, 1.0, 2.0}, - }, - { - resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": "baz"}, + resp: []interface{}{"foo", "bar", "baz"}, want: []interface{}{"bar", "baz", "foo"}, }, { - resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": 10}, - want: []interface{}{10.0, "bar", "foo"}, - }, - { - resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": nil}, - want: []interface{}{nil, "bar", "foo"}, - }, - { - resp: map[string]interface{}{"k1": 5, "k2": "bar", "k3": nil}, - want: []interface{}{nil, 5.0, "bar"}, - }, - { - resp: map[string]interface{}{ - "k1": true, "k2": 0, "k3": "foo", "k4": "foo", "k5": false, - "k6": map[string]interface{}{"k1": true}, - }, - want: []interface{}{false, true, 0.0, "foo", "foo", map[string]interface{}{"k1": true}}, + resp: []interface{}{"foo", 1, false, nil, 0, true}, + want: []interface{}{nil, false, true, 0.0, 1.0, "foo"}, }, } From b8b9a8415372654b97cc48412836602a4d4c7722 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Wed, 14 Feb 2018 10:16:18 -0800 Subject: [PATCH 46/58] Fixing some lint errors and compilation errors --- integration/firestore/firestore_test.go | 2 +- integration/iid/iid_test.go | 2 +- integration/internal/internal.go | 1 + integration/messaging/messaging_test.go | 2 +- internal/internal.go | 5 +++++ 5 files changed, 9 insertions(+), 3 deletions(-) diff --git a/integration/firestore/firestore_test.go b/integration/firestore/firestore_test.go index 6e7b4e28..cf392e45 100644 --- a/integration/firestore/firestore_test.go +++ b/integration/firestore/firestore_test.go @@ -30,7 +30,7 @@ func TestFirestore(t *testing.T) { return } ctx := context.Background() - app, err := internal.NewTestApp(ctx) + app, err := internal.NewTestApp(ctx, nil) if err != nil { t.Fatal(err) } diff --git a/integration/iid/iid_test.go b/integration/iid/iid_test.go index 9be5dce0..a4de6343 100644 --- a/integration/iid/iid_test.go +++ b/integration/iid/iid_test.go @@ -37,7 +37,7 @@ func TestMain(m *testing.M) { } ctx := context.Background() - app, err := internal.NewTestApp(ctx) + app, err := internal.NewTestApp(ctx, nil) if err != nil { log.Fatalln(err) } diff --git a/integration/internal/internal.go b/integration/internal/internal.go index 1d8ea9b3..497065eb 100644 --- a/integration/internal/internal.go +++ b/integration/internal/internal.go @@ -76,6 +76,7 @@ func ProjectID() (string, error) { return serviceAccount.ProjectID, nil } +// NewHTTPClient creates an HTTP client for making authorized requests during tests. func NewHTTPClient(ctx context.Context, opts ...option.ClientOption) (*http.Client, error) { opts = append( opts, diff --git a/integration/messaging/messaging_test.go b/integration/messaging/messaging_test.go index d7bb0693..b41bf87e 100644 --- a/integration/messaging/messaging_test.go +++ b/integration/messaging/messaging_test.go @@ -45,7 +45,7 @@ func TestMain(m *testing.M) { } ctx := context.Background() - app, err := internal.NewTestApp(ctx) + app, err := internal.NewTestApp(ctx, nil) if err != nil { log.Fatalln(err) } diff --git a/internal/internal.go b/internal/internal.go index 8db1b87c..c520a3ee 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -21,9 +21,13 @@ import ( "google.golang.org/api/option" ) +// FirebaseScopes is the set of OAuth2 scopes used by the Admin SDK. var FirebaseScopes = []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/datastore", "https://www.googleapis.com/auth/devstorage.full_control", "https://www.googleapis.com/auth/firebase", + "https://www.googleapis.com/auth/identitytoolkit", "https://www.googleapis.com/auth/userinfo.email", } @@ -41,6 +45,7 @@ type InstanceIDConfig struct { ProjectID string } +// DatabaseConfig represents the configuration of Firebase Database service. type DatabaseConfig struct { Opts []option.ClientOption URL string From 5fe7d33d71bda3e181945a49bf4f8e0705d3f109 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Wed, 14 Feb 2018 12:27:55 -0800 Subject: [PATCH 47/58] Removing unused function --- db/query.go | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/db/query.go b/db/query.go index e81a6154..ca7dbfc6 100644 --- a/db/query.go +++ b/db/query.go @@ -260,18 +260,12 @@ const ( typeObject = 5 ) +// comparableKey is union type of numeric values and strings. type comparableKey struct { Num *float64 Str *string } -func (k *comparableKey) Val() interface{} { - if k.Str != nil { - return *k.Str - } - return *k.Num -} - func (k *comparableKey) Compare(o *comparableKey) int { if k.Str != nil && o.Str != nil { return strings.Compare(*k.Str, *o.Str) From a65e3e43e497e250909c2bf280be7c0be21eb31c Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Wed, 14 Feb 2018 14:41:12 -0800 Subject: [PATCH 48/58] Cleaned up unit tests for db --- db/query.go | 83 +++++++++++++++++++++++++++--------------------- db/query_test.go | 44 +++++++++++++++++++++++-- 2 files changed, 87 insertions(+), 40 deletions(-) diff --git a/db/query.go b/db/query.go index ca7dbfc6..b898d53c 100644 --- a/db/query.go +++ b/db/query.go @@ -124,7 +124,7 @@ func (q *Query) GetOrdered(ctx context.Context, v interface{}) error { return err } - sr, err := newSortableResult(temp, q.ob) + sr, err := newSortableQueryResult(temp, q.ob) if err != nil { return err } @@ -251,6 +251,7 @@ func (p orderByProperty) encode() (string, error) { return string(b), nil } +// Firebase type ordering: https://firebase.google.com/docs/database/rest/retrieve-data#section-rest-ordered-data const ( typeNull = 0 typeBoolFalse = 1 @@ -260,7 +261,7 @@ const ( typeObject = 5 ) -// comparableKey is union type of numeric values and strings. +// comparableKey is a union type of numeric values and strings. type comparableKey struct { Num *float64 Str *string @@ -277,6 +278,7 @@ func (k *comparableKey) Compare(o *comparableKey) int { } return 1 } else if k.Num != nil { + // numeric keys appear before string keys return -1 } return 1 @@ -295,41 +297,71 @@ func newComparableKey(v interface{}) *comparableKey { return &comparableKey{Num: &f} } -type sortableResult []*sortEntry +type queryResult struct { + Key *comparableKey + Value interface{} + Index interface{} + IndexType int +} -func (s sortableResult) Len() int { +func newQueryResult(key, val interface{}, order orderBy) *queryResult { + var index interface{} + if prop, ok := order.(orderByProperty); ok { + if prop == "$value" { + index = val + } else { + index = key + } + } else { + path := order.(orderByChild) + index = extractChildValue(val, string(path)) + } + return &queryResult{ + Key: newComparableKey(key), + Value: val, + Index: index, + IndexType: getIndexType(index), + } +} + +type sortableQueryResult []*queryResult + +func (s sortableQueryResult) Len() int { return len(s) } -func (s sortableResult) Swap(i, j int) { +func (s sortableQueryResult) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -func (s sortableResult) Less(i, j int) bool { +func (s sortableQueryResult) Less(i, j int) bool { a, b := s[i], s[j] var aKey, bKey *comparableKey if a.IndexType == b.IndexType { + // If the indices have the same type and are comparable (i.e. numeric or string), compare + // them directly. Otherwise, compare the keys. if (a.IndexType == typeNumeric || a.IndexType == typeString) && a.Index != b.Index { aKey, bKey = newComparableKey(a.Index), newComparableKey(b.Index) } else { aKey, bKey = a.Key, b.Key } } else { + // If the indices are of different types, use the type ordering of Firebase. aKey, bKey = newComparableKey(a.IndexType), newComparableKey(b.IndexType) } return aKey.Compare(bKey) < 0 } -func newSortableResult(values interface{}, order orderBy) (sortableResult, error) { - var entries sortableResult +func newSortableQueryResult(values interface{}, order orderBy) (sortableQueryResult, error) { + var entries sortableQueryResult if m, ok := values.(map[string]interface{}); ok { for key, val := range m { - entries = append(entries, newSortEntry(key, val, order)) + entries = append(entries, newQueryResult(key, val, order)) } } else if l, ok := values.([]interface{}); ok { for key, val := range l { - entries = append(entries, newSortEntry(key, val, order)) + entries = append(entries, newQueryResult(key, val, order)) } } else { return nil, fmt.Errorf("sorting not supported for the result") @@ -337,33 +369,10 @@ func newSortableResult(values interface{}, order orderBy) (sortableResult, error return entries, nil } -type sortEntry struct { - Key *comparableKey - Value interface{} - Index interface{} - IndexType int -} - -func newSortEntry(key, val interface{}, order orderBy) *sortEntry { - var index interface{} - if prop, ok := order.(orderByProperty); ok { - if prop == "$value" { - index = val - } else { - index = key - } - } else { - path := order.(orderByChild) - index = extractChildValue(val, string(path)) - } - return &sortEntry{ - Key: newComparableKey(key), - Value: val, - Index: index, - IndexType: getIndexType(index), - } -} - +// extractChildValue retrieves the value at path from val. +// +// If the given path does not exist in val, or val does not support child path traversal, +// extractChildValue returns nil. func extractChildValue(val interface{}, path string) interface{} { segments := parsePath(path) curr := val diff --git a/db/query_test.go b/db/query_test.go index 0fd6bc56..acf34e92 100644 --- a/db/query_test.go +++ b/db/query_test.go @@ -14,6 +14,7 @@ package db import ( + "fmt" "reflect" "testing" @@ -423,6 +424,43 @@ func TestInvalidGetOrdered(t *testing.T) { } func TestChildQueryGetOrdered(t *testing.T) { + mock := &mockServer{Resp: sortableKeysResp} + srv := mock.Start(client) + defer srv.Close() + + cases := []struct { + child string + want []string + }{ + {"name", []string{"alice", "bob", "charlie", "dave", "ernie"}}, + {"age", []string{"ernie", "charlie", "bob", "dave", "alice"}}, + } + + var reqs []*testReq + for _, tc := range cases { + var result []person + if err := testref.OrderByChild(tc.child).GetOrdered(context.Background(), &result); err != nil { + t.Fatal(err) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": fmt.Sprintf("%q", tc.child)}, + }) + + var got []string + for _, r := range result { + got = append(got, r.Name) + } + if !reflect.DeepEqual(tc.want, got) { + t.Errorf("GetOrdered(child: %q) = %v; want = %v", tc.child, got, tc.want) + } + } + + checkAllRequests(t, mock.Reqs, reqs) +} + +func TestImmediateChildQueryGetOrdered(t *testing.T) { mock := &mockServer{} srv := mock.Start(client) defer srv.Close() @@ -461,7 +499,7 @@ func TestChildQueryGetOrdered(t *testing.T) { checkAllRequests(t, mock.Reqs, reqs) } -func TestGrandChildQueryGetOrdered(t *testing.T) { +func TestNestedChildQueryGetOrdered(t *testing.T) { mock := &mockServer{} srv := mock.Start(client) defer srv.Close() @@ -504,7 +542,7 @@ func TestGrandChildQueryGetOrdered(t *testing.T) { checkAllRequests(t, mock.Reqs, reqs) } -func TestKeyQueryOrderedGet(t *testing.T) { +func TestKeyQueryGetOrdered(t *testing.T) { mock := &mockServer{Resp: sortableKeysResp} srv := mock.Start(client) defer srv.Close() @@ -557,7 +595,7 @@ func TestValueQueryGetOrdered(t *testing.T) { } } -func TestValueQueryOrderedGetWithList(t *testing.T) { +func TestValueQueryGetOrderedWithList(t *testing.T) { cases := []struct { resp []interface{} want []interface{} From d7343081ec6fb164743a5ca262088cd95e7d4d6f Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Wed, 14 Feb 2018 15:39:14 -0800 Subject: [PATCH 49/58] Updated query impl and tests --- db/query.go | 9 ++++++--- db/query_test.go | 22 ++++++++++++++-------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/db/query.go b/db/query.go index b898d53c..9c03fcea 100644 --- a/db/query.go +++ b/db/query.go @@ -112,11 +112,11 @@ func (q *Query) Get(ctx context.Context, v interface{}) error { // v must be a pointer to an array or a slice. func (q *Query) GetOrdered(ctx context.Context, v interface{}) error { rv := reflect.ValueOf(v) - if rv.Kind() != reflect.Ptr { - return fmt.Errorf("value must be a pointer") + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("nil or not a pointer") } if rv.Elem().Kind() != reflect.Slice && rv.Elem().Kind() != reflect.Array { - return fmt.Errorf("value must be a pointer to an array or a slice") + return fmt.Errorf("non-array non-slice pointer") } var temp interface{} @@ -288,6 +288,9 @@ func newComparableKey(v interface{}) *comparableKey { if s, ok := v.(string); ok { return &comparableKey{Str: &s} } + + // Numeric values could be int (in the case of array indices and type constants), or float64 (if + // the value was received as json). if i, ok := v.(int); ok { f := float64(i) return &comparableKey{Num: &f} diff --git a/db/query_test.go b/db/query_test.go index acf34e92..db34882c 100644 --- a/db/query_test.go +++ b/db/query_test.go @@ -409,15 +409,21 @@ func TestAllParamsQuery(t *testing.T) { func TestInvalidGetOrdered(t *testing.T) { q := testref.OrderByKey() - var i interface{} - want := "value must be a pointer" - err := q.GetOrdered(context.Background(), i) + want := "nil or not a pointer" + var p *[]person // nil + err := q.GetOrdered(context.Background(), p) if err == nil || err.Error() != want { t.Errorf("GetOrdered(interface) = %v; want = %v", err, want) } - want = "value must be a pointer to an array or a slice" - err = q.GetOrdered(context.Background(), &i) + var i interface{} // not a pointer + err = q.GetOrdered(context.Background(), i) + if err == nil || err.Error() != want { + t.Errorf("GetOrdered(interface) = %v; want = %v", err, want) + } + + want = "non-array non-slice pointer" + err = q.GetOrdered(context.Background(), &i) // pointer to a non-array value if err == nil || err.Error() != want { t.Errorf("GetOrdered(interface) = %v; want = %v", err, want) } @@ -438,8 +444,8 @@ func TestChildQueryGetOrdered(t *testing.T) { var reqs []*testReq for _, tc := range cases { - var result []person - if err := testref.OrderByChild(tc.child).GetOrdered(context.Background(), &result); err != nil { + var result *[]person + if err := testref.OrderByChild(tc.child).GetOrdered(context.Background(), result); err != nil { t.Fatal(err) } reqs = append(reqs, &testReq{ @@ -449,7 +455,7 @@ func TestChildQueryGetOrdered(t *testing.T) { }) var got []string - for _, r := range result { + for _, r := range *result { got = append(got, r.Name) } if !reflect.DeepEqual(tc.want, got) { From 0007b6b80e3ffd49c6dfa248cce9248e401a09af Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Wed, 14 Feb 2018 16:18:36 -0800 Subject: [PATCH 50/58] Added integration tests for ordered queries --- db/query_test.go | 6 +- integration/db/db_test.go | 14 +++ integration/db/query_test.go | 185 +++++++++++++++++++++-------------- 3 files changed, 130 insertions(+), 75 deletions(-) diff --git a/db/query_test.go b/db/query_test.go index db34882c..2e4977d3 100644 --- a/db/query_test.go +++ b/db/query_test.go @@ -444,8 +444,8 @@ func TestChildQueryGetOrdered(t *testing.T) { var reqs []*testReq for _, tc := range cases { - var result *[]person - if err := testref.OrderByChild(tc.child).GetOrdered(context.Background(), result); err != nil { + var result []person + if err := testref.OrderByChild(tc.child).GetOrdered(context.Background(), &result); err != nil { t.Fatal(err) } reqs = append(reqs, &testReq{ @@ -455,7 +455,7 @@ func TestChildQueryGetOrdered(t *testing.T) { }) var got []string - for _, r := range *result { + for _, r := range result { got = append(got, r.Name) } if !reflect.DeepEqual(tc.want, got) { diff --git a/integration/db/db_test.go b/integration/db/db_test.go index 902fba16..a337cf4c 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -1,3 +1,17 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package db import ( diff --git a/integration/db/query_test.go b/integration/db/query_test.go index 37f4141f..96ae98e8 100644 --- a/integration/db/query_test.go +++ b/integration/db/query_test.go @@ -1,3 +1,17 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package db import ( @@ -11,26 +25,30 @@ var heightSorted = []string{ "triceratops", "stegosaurus", "bruhathkayosaurus", } +func min(i, j int) int { + if i < j { + return i + } + return j +} + func TestLimitToFirst(t *testing.T) { for _, tc := range []int{2, 10} { - var m map[string]Dinosaur + var d []Dinosaur if err := dinos.OrderByChild("height"). WithLimitToFirst(tc). - Get(context.Background(), &m); err != nil { + GetOrdered(context.Background(), &d); err != nil { t.Fatal(err) } - wl := tc - if len(heightSorted) < wl { - wl = len(heightSorted) - } + wl := min(tc, len(heightSorted)) want := heightSorted[:wl] - if len(m) != len(want) { - t.Errorf("WithLimitToFirst() = %v; want = %v", m, want) + if len(d) != wl { + t.Errorf("WithLimitToFirst() = %v; want = %v", d, want) } - for _, d := range want { - if _, ok := m[d]; !ok { - t.Errorf("WithLimitToFirst() = %v; want key %q", m, d) + for i, w := range want { + if d[i] != parsedTestData[w] { + t.Errorf("[%d] WithLimitToFirst() = %v; want = %v", i, d[i], parsedTestData[w]) } } } @@ -38,156 +56,159 @@ func TestLimitToFirst(t *testing.T) { func TestLimitToLast(t *testing.T) { for _, tc := range []int{2, 10} { - var m map[string]Dinosaur + var d []Dinosaur if err := dinos.OrderByChild("height"). WithLimitToLast(tc). - Get(context.Background(), &m); err != nil { + GetOrdered(context.Background(), &d); err != nil { t.Fatal(err) } - wl := tc - if len(heightSorted) < wl { - wl = len(heightSorted) - } + wl := min(tc, len(heightSorted)) want := heightSorted[len(heightSorted)-wl:] - if len(m) != len(want) { - t.Errorf("WithLimitToLast() = %v; want = %v", m, want) + if len(d) != wl { + t.Errorf("WithLimitToLast() = %v; want = %v", d, want) } - for _, d := range want { - if _, ok := m[d]; !ok { - t.Errorf("WithLimitToLast() = %v; want key %q", m, d) + for i, w := range want { + if d[i] != parsedTestData[w] { + t.Errorf("[%d] WithLimitToLast() = %v; want = %v", i, d[i], parsedTestData[w]) } } } } func TestStartAt(t *testing.T) { - var m map[string]Dinosaur + var d []Dinosaur if err := dinos.OrderByChild("height"). WithStartAt(3.5). - Get(context.Background(), &m); err != nil { + GetOrdered(context.Background(), &d); err != nil { t.Fatal(err) } want := heightSorted[len(heightSorted)-2:] - if len(m) != len(want) { - t.Errorf("WithStartAt() = %v; want = %v", m, want) + if len(d) != len(want) { + t.Errorf("WithStartAt() = %v; want = %v", d, want) } - for _, d := range want { - if _, ok := m[d]; !ok { - t.Errorf("WithStartAt() = %v; want key %q", m, d) + for i, w := range want { + if d[i] != parsedTestData[w] { + t.Errorf("[%d] WithStartAt() = %v; want = %v", i, d[i], parsedTestData[w]) } } } func TestEndAt(t *testing.T) { - var m map[string]Dinosaur + var d []Dinosaur if err := dinos.OrderByChild("height"). WithEndAt(3.5). - Get(context.Background(), &m); err != nil { + GetOrdered(context.Background(), &d); err != nil { t.Fatal(err) } want := heightSorted[:4] - if len(m) != len(want) { - t.Errorf("WithStartAt() = %v; want = %v", m, want) + if len(d) != len(want) { + t.Errorf("WithStartAt() = %v; want = %v", d, want) } - for _, d := range want { - if _, ok := m[d]; !ok { - t.Errorf("WithStartAt() = %v; want key %q", m, d) + for i, w := range want { + if d[i] != parsedTestData[w] { + t.Errorf("[%d] WithEndAt() = %v; want = %v", i, d[i], parsedTestData[w]) } } } func TestStartAndEndAt(t *testing.T) { - var m map[string]Dinosaur + var d []Dinosaur if err := dinos.OrderByChild("height"). WithStartAt(2.5). WithEndAt(5). - Get(context.Background(), &m); err != nil { + GetOrdered(context.Background(), &d); err != nil { t.Fatal(err) } want := heightSorted[len(heightSorted)-3 : len(heightSorted)-1] - if len(m) != len(want) { - t.Errorf("WithStartAt(), WithEndAt() = %v; want = %v", m, want) + if len(d) != len(want) { + t.Errorf("WithStartAt(), WithEndAt() = %v; want = %v", d, want) } - for _, d := range want { - if _, ok := m[d]; !ok { - t.Errorf("WithStartAt(), WithEndAt() = %v; want key %q", m, d) + for i, w := range want { + if d[i] != parsedTestData[w] { + t.Errorf("[%d] WithStartAt(), WithEndAt() = %v; want = %v", i, d[i], parsedTestData[w]) } } } func TestEqualTo(t *testing.T) { - var m map[string]Dinosaur + var d []Dinosaur if err := dinos.OrderByChild("height"). WithEqualTo(0.6). - Get(context.Background(), &m); err != nil { + GetOrdered(context.Background(), &d); err != nil { t.Fatal(err) } want := heightSorted[:2] - if len(m) != len(want) { - t.Errorf("WithEqualTo() = %v; want = %v", m, want) + if len(d) != len(want) { + t.Errorf("WithEqualTo() = %v; want = %v", d, want) } - for _, d := range want { - if _, ok := m[d]; !ok { - t.Errorf("WithEqualTo() = %v; want key %q", m, d) + for i, w := range want { + if d[i] != parsedTestData[w] { + t.Errorf("[%d] WithEqualTo() = %v; want = %v", i, d[i], parsedTestData[w]) } } } func TestOrderByNestedChild(t *testing.T) { - var m map[string]Dinosaur + var d []Dinosaur if err := dinos.OrderByChild("ratings/pos"). WithStartAt(4). - Get(context.Background(), &m); err != nil { + GetOrdered(context.Background(), &d); err != nil { t.Fatal(err) } want := []string{"pterodactyl", "stegosaurus", "triceratops"} - if len(m) != len(want) { - t.Errorf("OrderByChild(ratings/pos) = %v; want = %v", m, want) + if len(d) != len(want) { + t.Errorf("OrderByChild(ratings/pos) = %v; want = %v", d, want) } - for _, d := range want { - if _, ok := m[d]; !ok { - t.Errorf("OrderByChild(ratings/pos) = %v; want key %q", m, d) + for i, w := range want { + if d[i] != parsedTestData[w] { + t.Errorf("[%d] OrderByChild(ratings/pos) = %v; want = %v", i, d[i], parsedTestData[w]) } } } func TestOrderByKey(t *testing.T) { - var m map[string]Dinosaur - if err := dinos.OrderByKey().WithLimitToFirst(2).Get(context.Background(), &m); err != nil { + var d []Dinosaur + if err := dinos.OrderByKey(). + WithLimitToFirst(2). + GetOrdered(context.Background(), &d); err != nil { t.Fatal(err) } want := []string{"bruhathkayosaurus", "lambeosaurus"} - if len(m) != len(want) { - t.Errorf("OrderByKey() = %v; want = %v", m, want) + if len(d) != len(want) { + t.Errorf("OrderByKey() = %v; want = %v", d, want) } - for _, d := range want { - if _, ok := m[d]; !ok { - t.Errorf("OrderByKey() = %v; want key %q", m, d) + for i, w := range want { + if d[i] != parsedTestData[w] { + t.Errorf("[%d] OrderByKey() = %v; want = %v", i, d[i], parsedTestData[w]) } } } func TestOrderByValue(t *testing.T) { scores := ref.Child("scores") - var m map[string]int - if err := scores.OrderByValue().WithLimitToLast(2).Get(context.Background(), &m); err != nil { + var s []int + if err := scores.OrderByValue(). + WithLimitToLast(2). + GetOrdered(context.Background(), &s); err != nil { t.Fatal(err) } - want := []string{"pterodactyl", "linhenykus"} - if len(m) != len(want) { - t.Errorf("OrderByValue() = %v; want = %v", m, want) + want := []string{"linhenykus", "pterodactyl"} + if len(s) != len(want) { + t.Errorf("OrderByValue() = %v; want = %v", s, want) } - for _, d := range want { - if _, ok := m[d]; !ok { - t.Errorf("OrderByValue() = %v; want key %q", m, d) + scoresData := testData["scores"].(map[string]interface{}) + for i, w := range want { + ws := int(scoresData[w].(float64)) + if s[i] != ws { + t.Errorf("[%d] OrderByValue() = %v; want = %v", i, s[i], ws) } } } @@ -216,3 +237,23 @@ func TestQueryWithContext(t *testing.T) { t.Errorf("Get() = (%v, %v); want = (empty, error)", m, err) } } + +func TestUnorderedQuery(t *testing.T) { + var m map[string]Dinosaur + if err := dinos.OrderByChild("height"). + WithStartAt(2.5). + WithEndAt(5). + Get(context.Background(), &m); err != nil { + t.Fatal(err) + } + + want := heightSorted[len(heightSorted)-3 : len(heightSorted)-1] + if len(m) != len(want) { + t.Errorf("WithStartAt(), WithEndAt() = %v; want = %v", m, want) + } + for _, w := range want { + if _, ok := m[w]; !ok { + t.Errorf("WithStartAt(), WithEndAt() = %v; want key = %v", m, w) + } + } +} From f0f206363677282513671a0a744230ac8418fd88 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Wed, 14 Feb 2018 16:43:08 -0800 Subject: [PATCH 51/58] Removed With*() from query functions --- db/auth_override_test.go | 2 +- db/query.go | 31 +++++++++++++-------- db/query_test.go | 36 ++++++++++++------------ db/ref.go | 2 +- integration/db/db_test.go | 2 +- integration/db/query_test.go | 54 ++++++++++++++++++------------------ 6 files changed, 67 insertions(+), 60 deletions(-) diff --git a/db/auth_override_test.go b/db/auth_override_test.go index f7202a72..9b05609e 100644 --- a/db/auth_override_test.go +++ b/db/auth_override_test.go @@ -88,7 +88,7 @@ func TestAuthOverrideRangeQuery(t *testing.T) { ref := aoClient.NewRef("peter") var got string - if err := ref.OrderByChild("foo").WithStartAt(1).WithEndAt(10).Get(context.Background(), &got); err != nil { + if err := ref.OrderByChild("foo").StartAt(1).EndAt(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if got != "data" { diff --git a/db/query.go b/db/query.go index 9c03fcea..b488524b 100644 --- a/db/query.go +++ b/db/query.go @@ -44,48 +44,48 @@ type Query struct { start, end, equalTo interface{} } -// WithStartAt returns a shallow copy of the Query with v set as a lower bound of a range query. +// StartAt returns a shallow copy of the Query with v set as a lower bound of a range query. // // The resulting Query will only return child nodes with a value greater than or equal to v. -func (q *Query) WithStartAt(v interface{}) *Query { +func (q *Query) StartAt(v interface{}) *Query { q2 := new(Query) *q2 = *q q2.start = v return q2 } -// WithEndAt returns a shallow copy of the Query with v set as a upper bound of a range query. +// EndAt returns a shallow copy of the Query with v set as a upper bound of a range query. // // The resulting Query will only return child nodes with a value less than or equal to v. -func (q *Query) WithEndAt(v interface{}) *Query { +func (q *Query) EndAt(v interface{}) *Query { q2 := new(Query) *q2 = *q q2.end = v return q2 } -// WithEqualTo returns a shallow copy of the Query with v set as an equals constraint. +// EqualTo returns a shallow copy of the Query with v set as an equals constraint. // // The resulting Query will only return child nodes whose values equal to v. -func (q *Query) WithEqualTo(v interface{}) *Query { +func (q *Query) EqualTo(v interface{}) *Query { q2 := new(Query) *q2 = *q q2.equalTo = v return q2 } -// WithLimitToFirst returns a shallow copy of the Query, which is anchored to the first n +// LimitToFirst returns a shallow copy of the Query, which is anchored to the first n // elements of the window. -func (q *Query) WithLimitToFirst(n int) *Query { +func (q *Query) LimitToFirst(n int) *Query { q2 := new(Query) *q2 = *q q2.limFirst = n return q2 } -// WithLimitToLast returns a shallow copy of the Query, which is anchored to the last n +// LimitToLast returns a shallow copy of the Query, which is anchored to the last n // elements of the window. -func (q *Query) WithLimitToLast(n int) *Query { +func (q *Query) LimitToLast(n int) *Query { q2 := new(Query) *q2 = *q q2.limLast = n @@ -94,7 +94,12 @@ func (q *Query) WithLimitToLast(n int) *Query { // Get executes the Query and populates v with the results. // -// Results will not be stored in any particular order in v. +// Data deserialization is performed using https://golang.org/pkg/encoding/json/#Unmarshal, and +// therefore v has the same requirements as the json package. Specifically, it must be a pointer, +// and must not be nil. +// +// Despite the ordering constraint of the Query, results are not stored in any particular order +// in v. Use GetOrdered() to obtain ordered results. func (q *Query) Get(ctx context.Context, v interface{}) error { qp := make(map[string]string) if err := initQueryParams(q, qp); err != nil { @@ -109,7 +114,9 @@ func (q *Query) Get(ctx context.Context, v interface{}) error { // GetOrdered executes the Query and provides the results as an ordered list. // -// v must be a pointer to an array or a slice. +// v must be a pointer to an array or a slice. Only the child values returned by the query are +// unmarshalled into v. Top-level keys are not returned. Although if the Query was created using +// OrderByKey(), the returned values will still be ordered based on their keys. func (q *Query) GetOrdered(ctx context.Context, v interface{}) error { rv := reflect.ValueOf(v) if rv.Kind() != reflect.Ptr || rv.IsNil() { diff --git a/db/query_test.go b/db/query_test.go index 2e4977d3..e6d44caa 100644 --- a/db/query_test.go +++ b/db/query_test.go @@ -147,7 +147,7 @@ func TestChildQueryWithParams(t *testing.T) { srv := mock.Start(client) defer srv.Close() - q := testref.OrderByChild("messages").WithStartAt("m4").WithEndAt("m50").WithLimitToFirst(10) + q := testref.OrderByChild("messages").StartAt("m4").EndAt("m50").LimitToFirst(10) var got map[string]interface{} if err := q.Get(context.Background(), &got); err != nil { t.Fatal(err) @@ -235,11 +235,11 @@ func TestLimitFirstQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByChild("messages").WithLimitToFirst(10).Get(context.Background(), &got); err != nil { + if err := testref.OrderByChild("messages").LimitToFirst(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("query.WithLimitToFirst() = %v; want = %v", got, want) + t.Errorf("LimitToFirst() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -255,11 +255,11 @@ func TestLimitLastQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByChild("messages").WithLimitToLast(10).Get(context.Background(), &got); err != nil { + if err := testref.OrderByChild("messages").LimitToLast(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("query.WithLimitToLast() = %v; want = %v", got, want) + t.Errorf("LimitToLast() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -279,9 +279,9 @@ func TestInvalidLimitQuery(t *testing.T) { name string q *Query }{ - {"BothLimits", q.WithLimitToFirst(10).WithLimitToLast(10)}, - {"NegativeFirst", q.WithLimitToFirst(-10)}, - {"NegativeLast", q.WithLimitToLast(-10)}, + {"BothLimits", q.LimitToFirst(10).LimitToLast(10)}, + {"NegativeFirst", q.LimitToFirst(-10)}, + {"NegativeLast", q.LimitToLast(-10)}, } for _, tc := range cases { var got map[string]interface{} @@ -301,11 +301,11 @@ func TestStartAtQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByChild("messages").WithStartAt(10).Get(context.Background(), &got); err != nil { + if err := testref.OrderByChild("messages").StartAt(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("WithStartAt() = %v; want = %v", got, want) + t.Errorf("StartAt() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -321,11 +321,11 @@ func TestEndAtQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByChild("messages").WithEndAt(10).Get(context.Background(), &got); err != nil { + if err := testref.OrderByChild("messages").EndAt(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("WithEndAt() = %v; want = %v", got, want) + t.Errorf("EndAt() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -341,11 +341,11 @@ func TestEqualToQuery(t *testing.T) { defer srv.Close() var got map[string]interface{} - if err := testref.OrderByChild("messages").WithEqualTo(10).Get(context.Background(), &got); err != nil { + if err := testref.OrderByChild("messages").EqualTo(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { - t.Errorf("WithEqualTo() = %v; want = %v", got, want) + t.Errorf("EqualTo() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", @@ -365,9 +365,9 @@ func TestInvalidFilterQuery(t *testing.T) { name string q *Query }{ - {"InvalidStartAt", q.WithStartAt(func() {})}, - {"InvalidEndAt", q.WithEndAt(func() {})}, - {"InvalidEqualTo", q.WithEqualTo(func() {})}, + {"InvalidStartAt", q.StartAt(func() {})}, + {"InvalidEndAt", q.EndAt(func() {})}, + {"InvalidEqualTo", q.EqualTo(func() {})}, } for _, tc := range cases { var got map[string]interface{} @@ -386,7 +386,7 @@ func TestAllParamsQuery(t *testing.T) { srv := mock.Start(client) defer srv.Close() - q := testref.OrderByChild("messages").WithLimitToFirst(100).WithStartAt("bar").WithEndAt("foo") + q := testref.OrderByChild("messages").LimitToFirst(100).StartAt("bar").EndAt("foo") var got map[string]interface{} if err := q.Get(context.Background(), &got); err != nil { t.Fatal(err) diff --git a/db/ref.go b/db/ref.go index bf3bb707..cff60ee5 100644 --- a/db/ref.go +++ b/db/ref.go @@ -72,7 +72,7 @@ func (r *Ref) Child(path string) *Ref { // // Data deserialization is performed using https://golang.org/pkg/encoding/json/#Unmarshal, and // therefore v has the same requirements as the json package. Specifically, it must be a pointer, -// and it must not be nil. +// and must not be nil. func (r *Ref) Get(ctx context.Context, v interface{}) error { resp, err := r.send(ctx, "GET") if err != nil { diff --git a/integration/db/db_test.go b/integration/db/db_test.go index a337cf4c..7725bd32 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -621,7 +621,7 @@ func TestReadWriteAccess(t *testing.T) { func TestQueryAccess(t *testing.T) { r := aoClient.NewRef("_adminsdk/go/protected") got := make(map[string]interface{}) - if err := r.OrderByKey().WithLimitToFirst(2).Get(context.Background(), &got); err == nil { + if err := r.OrderByKey().LimitToFirst(2).Get(context.Background(), &got); err == nil { t.Errorf("OrderByQuery() = nil; want = error") } else if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) diff --git a/integration/db/query_test.go b/integration/db/query_test.go index 96ae98e8..b5b740d6 100644 --- a/integration/db/query_test.go +++ b/integration/db/query_test.go @@ -36,7 +36,7 @@ func TestLimitToFirst(t *testing.T) { for _, tc := range []int{2, 10} { var d []Dinosaur if err := dinos.OrderByChild("height"). - WithLimitToFirst(tc). + LimitToFirst(tc). GetOrdered(context.Background(), &d); err != nil { t.Fatal(err) } @@ -44,11 +44,11 @@ func TestLimitToFirst(t *testing.T) { wl := min(tc, len(heightSorted)) want := heightSorted[:wl] if len(d) != wl { - t.Errorf("WithLimitToFirst() = %v; want = %v", d, want) + t.Errorf("LimitToFirst() = %v; want = %v", d, want) } for i, w := range want { if d[i] != parsedTestData[w] { - t.Errorf("[%d] WithLimitToFirst() = %v; want = %v", i, d[i], parsedTestData[w]) + t.Errorf("[%d] LimitToFirst() = %v; want = %v", i, d[i], parsedTestData[w]) } } } @@ -58,7 +58,7 @@ func TestLimitToLast(t *testing.T) { for _, tc := range []int{2, 10} { var d []Dinosaur if err := dinos.OrderByChild("height"). - WithLimitToLast(tc). + LimitToLast(tc). GetOrdered(context.Background(), &d); err != nil { t.Fatal(err) } @@ -66,11 +66,11 @@ func TestLimitToLast(t *testing.T) { wl := min(tc, len(heightSorted)) want := heightSorted[len(heightSorted)-wl:] if len(d) != wl { - t.Errorf("WithLimitToLast() = %v; want = %v", d, want) + t.Errorf("LimitToLast() = %v; want = %v", d, want) } for i, w := range want { if d[i] != parsedTestData[w] { - t.Errorf("[%d] WithLimitToLast() = %v; want = %v", i, d[i], parsedTestData[w]) + t.Errorf("[%d] LimitToLast() = %v; want = %v", i, d[i], parsedTestData[w]) } } } @@ -79,18 +79,18 @@ func TestLimitToLast(t *testing.T) { func TestStartAt(t *testing.T) { var d []Dinosaur if err := dinos.OrderByChild("height"). - WithStartAt(3.5). + StartAt(3.5). GetOrdered(context.Background(), &d); err != nil { t.Fatal(err) } want := heightSorted[len(heightSorted)-2:] if len(d) != len(want) { - t.Errorf("WithStartAt() = %v; want = %v", d, want) + t.Errorf("StartAt() = %v; want = %v", d, want) } for i, w := range want { if d[i] != parsedTestData[w] { - t.Errorf("[%d] WithStartAt() = %v; want = %v", i, d[i], parsedTestData[w]) + t.Errorf("[%d] StartAt() = %v; want = %v", i, d[i], parsedTestData[w]) } } } @@ -98,18 +98,18 @@ func TestStartAt(t *testing.T) { func TestEndAt(t *testing.T) { var d []Dinosaur if err := dinos.OrderByChild("height"). - WithEndAt(3.5). + EndAt(3.5). GetOrdered(context.Background(), &d); err != nil { t.Fatal(err) } want := heightSorted[:4] if len(d) != len(want) { - t.Errorf("WithStartAt() = %v; want = %v", d, want) + t.Errorf("StartAt() = %v; want = %v", d, want) } for i, w := range want { if d[i] != parsedTestData[w] { - t.Errorf("[%d] WithEndAt() = %v; want = %v", i, d[i], parsedTestData[w]) + t.Errorf("[%d] EndAt() = %v; want = %v", i, d[i], parsedTestData[w]) } } } @@ -117,19 +117,19 @@ func TestEndAt(t *testing.T) { func TestStartAndEndAt(t *testing.T) { var d []Dinosaur if err := dinos.OrderByChild("height"). - WithStartAt(2.5). - WithEndAt(5). + StartAt(2.5). + EndAt(5). GetOrdered(context.Background(), &d); err != nil { t.Fatal(err) } want := heightSorted[len(heightSorted)-3 : len(heightSorted)-1] if len(d) != len(want) { - t.Errorf("WithStartAt(), WithEndAt() = %v; want = %v", d, want) + t.Errorf("StartAt(), EndAt() = %v; want = %v", d, want) } for i, w := range want { if d[i] != parsedTestData[w] { - t.Errorf("[%d] WithStartAt(), WithEndAt() = %v; want = %v", i, d[i], parsedTestData[w]) + t.Errorf("[%d] StartAt(), EndAt() = %v; want = %v", i, d[i], parsedTestData[w]) } } } @@ -137,18 +137,18 @@ func TestStartAndEndAt(t *testing.T) { func TestEqualTo(t *testing.T) { var d []Dinosaur if err := dinos.OrderByChild("height"). - WithEqualTo(0.6). + EqualTo(0.6). GetOrdered(context.Background(), &d); err != nil { t.Fatal(err) } want := heightSorted[:2] if len(d) != len(want) { - t.Errorf("WithEqualTo() = %v; want = %v", d, want) + t.Errorf("EqualTo() = %v; want = %v", d, want) } for i, w := range want { if d[i] != parsedTestData[w] { - t.Errorf("[%d] WithEqualTo() = %v; want = %v", i, d[i], parsedTestData[w]) + t.Errorf("[%d] EqualTo() = %v; want = %v", i, d[i], parsedTestData[w]) } } } @@ -156,7 +156,7 @@ func TestEqualTo(t *testing.T) { func TestOrderByNestedChild(t *testing.T) { var d []Dinosaur if err := dinos.OrderByChild("ratings/pos"). - WithStartAt(4). + StartAt(4). GetOrdered(context.Background(), &d); err != nil { t.Fatal(err) } @@ -175,7 +175,7 @@ func TestOrderByNestedChild(t *testing.T) { func TestOrderByKey(t *testing.T) { var d []Dinosaur if err := dinos.OrderByKey(). - WithLimitToFirst(2). + LimitToFirst(2). GetOrdered(context.Background(), &d); err != nil { t.Fatal(err) } @@ -195,7 +195,7 @@ func TestOrderByValue(t *testing.T) { scores := ref.Child("scores") var s []int if err := scores.OrderByValue(). - WithLimitToLast(2). + LimitToLast(2). GetOrdered(context.Background(), &s); err != nil { t.Fatal(err) } @@ -215,7 +215,7 @@ func TestOrderByValue(t *testing.T) { func TestQueryWithContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - q := dinos.OrderByKey().WithLimitToFirst(2) + q := dinos.OrderByKey().LimitToFirst(2) var m map[string]Dinosaur if err := q.Get(ctx, &m); err != nil { t.Fatal(err) @@ -241,19 +241,19 @@ func TestQueryWithContext(t *testing.T) { func TestUnorderedQuery(t *testing.T) { var m map[string]Dinosaur if err := dinos.OrderByChild("height"). - WithStartAt(2.5). - WithEndAt(5). + StartAt(2.5). + EndAt(5). Get(context.Background(), &m); err != nil { t.Fatal(err) } want := heightSorted[len(heightSorted)-3 : len(heightSorted)-1] if len(m) != len(want) { - t.Errorf("WithStartAt(), WithEndAt() = %v; want = %v", m, want) + t.Errorf("StartAt(), EndAt() = %v; want = %v", m, want) } for _, w := range want { if _, ok := m[w]; !ok { - t.Errorf("WithStartAt(), WithEndAt() = %v; want key = %v", m, w) + t.Errorf("StartAt(), EndAt() = %v; want key = %v", m, w) } } } From 0e095a032ef74d56ec325eab0077fe44a18f8647 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Wed, 14 Feb 2018 17:39:42 -0800 Subject: [PATCH 52/58] Updated change log; Added more tests --- CHANGELOG.md | 2 +- db/auth_override_test.go | 2 +- db/db.go | 2 +- db/db_test.go | 2 +- db/query.go | 14 ++++----- db/query_test.go | 55 ++++++++++++++++++++++++++++++++---- db/ref.go | 2 +- db/ref_test.go | 10 +++---- integration/db/db_test.go | 1 + integration/db/query_test.go | 35 ++++++++++------------- 10 files changed, 83 insertions(+), 42 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3060d302..9b2583f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Unreleased -- +- [added] Added the `db` package for interacting with the Firebase database. # v2.5.0 diff --git a/db/auth_override_test.go b/db/auth_override_test.go index 9b05609e..86cbeef2 100644 --- a/db/auth_override_test.go +++ b/db/auth_override_test.go @@ -1,4 +1,4 @@ -// Copyright 2017 Google Inc. All Rights Reserved. +// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/db/db.go b/db/db.go index a7aaeb7d..81f16cea 100644 --- a/db/db.go +++ b/db/db.go @@ -1,4 +1,4 @@ -// Copyright 2017 Google Inc. All Rights Reserved. +// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/db/db_test.go b/db/db_test.go index e48e80ed..0ec601cf 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -1,4 +1,4 @@ -// Copyright 2017 Google Inc. All Rights Reserved. +// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/db/query.go b/db/query.go index b488524b..08536ccd 100644 --- a/db/query.go +++ b/db/query.go @@ -1,4 +1,4 @@ -// Copyright 2017 Google Inc. All Rights Reserved. +// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -131,10 +131,10 @@ func (q *Query) GetOrdered(ctx context.Context, v interface{}) error { return err } - sr, err := newSortableQueryResult(temp, q.ob) - if err != nil { - return err + if temp == nil { + return nil } + sr := newSortableQueryResult(temp, q.ob) sort.Sort(sr) var values []interface{} @@ -363,7 +363,7 @@ func (s sortableQueryResult) Less(i, j int) bool { return aKey.Compare(bKey) < 0 } -func newSortableQueryResult(values interface{}, order orderBy) (sortableQueryResult, error) { +func newSortableQueryResult(values interface{}, order orderBy) sortableQueryResult { var entries sortableQueryResult if m, ok := values.(map[string]interface{}); ok { for key, val := range m { @@ -374,9 +374,9 @@ func newSortableQueryResult(values interface{}, order orderBy) (sortableQueryRes entries = append(entries, newQueryResult(key, val, order)) } } else { - return nil, fmt.Errorf("sorting not supported for the result") + entries = append(entries, newQueryResult(0, values, order)) } - return entries, nil + return entries } // extractChildValue retrieves the value at path from val. diff --git a/db/query_test.go b/db/query_test.go index e6d44caa..234764e5 100644 --- a/db/query_test.go +++ b/db/query_test.go @@ -1,4 +1,4 @@ -// Copyright 2017 Google Inc. All Rights Reserved. +// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -440,6 +440,7 @@ func TestChildQueryGetOrdered(t *testing.T) { }{ {"name", []string{"alice", "bob", "charlie", "dave", "ernie"}}, {"age", []string{"ernie", "charlie", "bob", "dave", "alice"}}, + {"nonexisting", []string{"alice", "bob", "charlie", "dave", "ernie"}}, } var reqs []*testReq @@ -462,7 +463,6 @@ func TestChildQueryGetOrdered(t *testing.T) { t.Errorf("GetOrdered(child: %q) = %v; want = %v", tc.child, got, tc.want) } } - checkAllRequests(t, mock.Reqs, reqs) } @@ -501,7 +501,6 @@ func TestImmediateChildQueryGetOrdered(t *testing.T) { t.Errorf("GetOrdered(child: %q) = %v; want = %v", "child", got, tc.want) } } - checkAllRequests(t, mock.Reqs, reqs) } @@ -544,7 +543,6 @@ func TestNestedChildQueryGetOrdered(t *testing.T) { t.Errorf("GetOrdered(child: %q) = %v; want = %v", "child/grandchild", got, tc.want) } } - checkAllRequests(t, mock.Reqs, reqs) } @@ -572,7 +570,6 @@ func TestKeyQueryGetOrdered(t *testing.T) { if !reflect.DeepEqual(want, got) { t.Errorf("GetOrdered(key) = %v; want = %v", got, want) } - checkOnlyRequest(t, mock.Reqs, req) } @@ -599,6 +596,7 @@ func TestValueQueryGetOrdered(t *testing.T) { t.Errorf("GetOrdered(value) = %v; want = %v", got, tc.want) } } + checkAllRequests(t, mock.Reqs, reqs) } func TestValueQueryGetOrderedWithList(t *testing.T) { @@ -658,4 +656,51 @@ func TestValueQueryGetOrderedWithList(t *testing.T) { t.Errorf("GetOrdered(value) = %v; want = %v", got, tc.want) } } + checkAllRequests(t, mock.Reqs, reqs) +} + +func TestGetOrderedWithNilResult(t *testing.T) { + mock := &mockServer{Resp: nil} + srv := mock.Start(client) + defer srv.Close() + + var got []interface{} + if err := testref.OrderByChild("child").GetOrdered(context.Background(), &got); err != nil { + t.Fatal(err) + } + if got != nil { + t.Errorf("GetOrdered(value) = %v; want = nil", got) + } +} + +func TestGetOrderedWithLeafNode(t *testing.T) { + mock := &mockServer{Resp: "foo"} + srv := mock.Start(client) + defer srv.Close() + + var got []interface{} + if err := testref.OrderByChild("child").GetOrdered(context.Background(), &got); err != nil { + t.Fatal(err) + } + + want := []interface{}{"foo"} + if !reflect.DeepEqual(want, got) { + t.Errorf("GetOrdered(value) = %v; want = %v", got, want) + } +} + +func TestQueryHttpError(t *testing.T) { + mock := &mockServer{Resp: map[string]string{"error": "test error"}, Status: 500} + srv := mock.Start(client) + defer srv.Close() + + want := "http error status: 500; reason: test error" + var got []string + err := testref.OrderByChild("child").GetOrdered(context.Background(), &got) + if err == nil || err.Error() != want { + t.Errorf("GetOrdered() = %v; want = %v", err, want) + } + if got != nil { + t.Errorf("GetOrdered() = %v; want = nil", got) + } } diff --git a/db/ref.go b/db/ref.go index cff60ee5..ea5c697e 100644 --- a/db/ref.go +++ b/db/ref.go @@ -1,4 +1,4 @@ -// Copyright 2017 Google Inc. All Rights Reserved. +// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/db/ref_test.go b/db/ref_test.go index 52a6e66d..7786b5b5 100644 --- a/db/ref_test.go +++ b/db/ref_test.go @@ -1,4 +1,4 @@ -// Copyright 2017 Google Inc. All Rights Reserved. +// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -189,7 +189,7 @@ func TestGetShallow(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(tc, got) { - t.Errorf("Get() = %v; want = %v", got, tc) + t.Errorf("GetShallow() = %v; want = %v", got, tc) } want = append(want, &testReq{Method: "GET", Path: "/peter.json", Query: wantQuery}) } @@ -333,7 +333,7 @@ func TestInvalidPath(t *testing.T) { } if len(mock.Reqs) != 0 { - t.Errorf("Requests: %v; want: empty", mock.Reqs) + t.Errorf("Requests = %v; want = empty", mock.Reqs) } } @@ -356,7 +356,7 @@ func TestInvalidChildPath(t *testing.T) { } if len(mock.Reqs) != 0 { - t.Errorf("Requests: %v; want: empty", mock.Reqs) + t.Errorf("Requests = %v; want = empty", mock.Reqs) } } @@ -398,7 +398,7 @@ func TestInvalidSet(t *testing.T) { } for _, tc := range cases { if err := testref.Set(context.Background(), tc); err == nil { - t.Errorf("Set(%v) = nil; want error", tc) + t.Errorf("Set(%v) = nil; want = error", tc) } } if len(mock.Reqs) != 0 { diff --git a/integration/db/db_test.go b/integration/db/db_test.go index 7725bd32..d201fe63 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package db contains integration tests for the firebase.google.com/go/db package. package db import ( diff --git a/integration/db/query_test.go b/integration/db/query_test.go index b5b740d6..72b2e507 100644 --- a/integration/db/query_test.go +++ b/integration/db/query_test.go @@ -1,4 +1,4 @@ -// Copyright 2017 Google Inc. All Rights Reserved. +// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -44,7 +44,7 @@ func TestLimitToFirst(t *testing.T) { wl := min(tc, len(heightSorted)) want := heightSorted[:wl] if len(d) != wl { - t.Errorf("LimitToFirst() = %v; want = %v", d, want) + t.Errorf("LimitToFirst() = %d; want = %d", len(d), wl) } for i, w := range want { if d[i] != parsedTestData[w] { @@ -66,7 +66,7 @@ func TestLimitToLast(t *testing.T) { wl := min(tc, len(heightSorted)) want := heightSorted[len(heightSorted)-wl:] if len(d) != wl { - t.Errorf("LimitToLast() = %v; want = %v", d, want) + t.Errorf("LimitToLast() = %d; want = %d", len(d), wl) } for i, w := range want { if d[i] != parsedTestData[w] { @@ -86,7 +86,7 @@ func TestStartAt(t *testing.T) { want := heightSorted[len(heightSorted)-2:] if len(d) != len(want) { - t.Errorf("StartAt() = %v; want = %v", d, want) + t.Errorf("StartAt() = %d; want = %d", len(d), len(want)) } for i, w := range want { if d[i] != parsedTestData[w] { @@ -105,7 +105,7 @@ func TestEndAt(t *testing.T) { want := heightSorted[:4] if len(d) != len(want) { - t.Errorf("StartAt() = %v; want = %v", d, want) + t.Errorf("StartAt() = %d; want = %d", len(d), len(want)) } for i, w := range want { if d[i] != parsedTestData[w] { @@ -125,7 +125,7 @@ func TestStartAndEndAt(t *testing.T) { want := heightSorted[len(heightSorted)-3 : len(heightSorted)-1] if len(d) != len(want) { - t.Errorf("StartAt(), EndAt() = %v; want = %v", d, want) + t.Errorf("StartAt(), EndAt() = %d; want = %d", len(d), len(want)) } for i, w := range want { if d[i] != parsedTestData[w] { @@ -144,7 +144,7 @@ func TestEqualTo(t *testing.T) { want := heightSorted[:2] if len(d) != len(want) { - t.Errorf("EqualTo() = %v; want = %v", d, want) + t.Errorf("EqualTo() = %d; want = %d", len(d), len(want)) } for i, w := range want { if d[i] != parsedTestData[w] { @@ -163,7 +163,7 @@ func TestOrderByNestedChild(t *testing.T) { want := []string{"pterodactyl", "stegosaurus", "triceratops"} if len(d) != len(want) { - t.Errorf("OrderByChild(ratings/pos) = %v; want = %v", d, want) + t.Errorf("OrderByChild(ratings/pos) = %d; want = %d", len(d), len(want)) } for i, w := range want { if d[i] != parsedTestData[w] { @@ -182,7 +182,7 @@ func TestOrderByKey(t *testing.T) { want := []string{"bruhathkayosaurus", "lambeosaurus"} if len(d) != len(want) { - t.Errorf("OrderByKey() = %v; want = %v", d, want) + t.Errorf("OrderByKey() = %d; want = %d", len(d), len(want)) } for i, w := range want { if d[i] != parsedTestData[w] { @@ -202,13 +202,13 @@ func TestOrderByValue(t *testing.T) { want := []string{"linhenykus", "pterodactyl"} if len(s) != len(want) { - t.Errorf("OrderByValue() = %v; want = %v", s, want) + t.Errorf("OrderByValue() = %d; want = %d", len(s), len(want)) } scoresData := testData["scores"].(map[string]interface{}) for i, w := range want { ws := int(scoresData[w].(float64)) if s[i] != ws { - t.Errorf("[%d] OrderByValue() = %v; want = %v", i, s[i], ws) + t.Errorf("[%d] OrderByValue() = %d; want = %d", i, s[i], ws) } } } @@ -223,12 +223,7 @@ func TestQueryWithContext(t *testing.T) { want := []string{"bruhathkayosaurus", "lambeosaurus"} if len(m) != len(want) { - t.Errorf("OrderByKey() = %v; want = %v", m, want) - } - for _, d := range want { - if _, ok := m[d]; !ok { - t.Errorf("OrderByKey() = %v; want key %q", m, d) - } + t.Errorf("OrderByKey() = %d; want = %d", len(m), len(want)) } cancel() @@ -249,11 +244,11 @@ func TestUnorderedQuery(t *testing.T) { want := heightSorted[len(heightSorted)-3 : len(heightSorted)-1] if len(m) != len(want) { - t.Errorf("StartAt(), EndAt() = %v; want = %v", m, want) + t.Errorf("Get() = %d; want = %d", len(m), len(want)) } - for _, w := range want { + for i, w := range want { if _, ok := m[w]; !ok { - t.Errorf("StartAt(), EndAt() = %v; want key = %v", m, w) + t.Errorf("[%d] result[%q] not present", i, w) } } } From 0bbe37289aad09d2379f1c59ed4d93dce4d5b133 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Wed, 14 Feb 2018 18:43:40 -0800 Subject: [PATCH 53/58] Support for database url in auto init --- firebase_test.go | 6 ++++++ testdata/firebase_config.json | 1 + 2 files changed, 7 insertions(+) diff --git a/firebase_test.go b/firebase_test.go index 792c69e3..874b5d8e 100644 --- a/firebase_test.go +++ b/firebase_test.go @@ -423,6 +423,7 @@ func TestAutoInit(t *testing.T) { "testdata/firebase_config.json", nil, &Config{ + DatabaseURL: "https://auto-init.database.url", ProjectID: "auto-init-project-id", StorageBucket: "auto-init.storage.bucket", }, @@ -430,11 +431,13 @@ func TestAutoInit(t *testing.T) { { "", `{ + "databaseURL": "https://auto-init.database.url", "projectId": "auto-init-project-id", "storageBucket": "auto-init.storage.bucket" }`, nil, &Config{ + DatabaseURL: "https://auto-init.database.url", ProjectID: "auto-init-project-id", StorageBucket: "auto-init.storage.bucket", }, @@ -568,6 +571,9 @@ func (t *testTokenSource) Token() (*oauth2.Token, error) { } func compareConfig(got *App, want *Config, t *testing.T) { + if got.dbURL != want.DatabaseURL { + t.Errorf("app.dbURL = %q; want = %q", got.dbURL, want.DatabaseURL) + } if got.projectID != want.ProjectID { t.Errorf("app.projectID = %q; want = %q", got.projectID, want.ProjectID) } diff --git a/testdata/firebase_config.json b/testdata/firebase_config.json index e9a3b5bc..772da62d 100644 --- a/testdata/firebase_config.json +++ b/testdata/firebase_config.json @@ -1,4 +1,5 @@ { + "databaseURL": "https://auto-init.database.url", "projectId": "auto-init-project-id", "storageBucket": "auto-init.storage.bucket" } From f4cf85a5575c6765050fc42938c5371273ab923d Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Thu, 15 Feb 2018 15:05:22 -0800 Subject: [PATCH 54/58] Support for loading auth overrides from env --- firebase.go | 39 +++++++++++++++++++++---------------- firebase_test.go | 41 +++++++++++++++++++++++++++++++++++++-- integration/db/db_test.go | 10 +++++----- 3 files changed, 66 insertions(+), 24 deletions(-) diff --git a/firebase.go b/firebase.go index d95eb437..d06a8483 100644 --- a/firebase.go +++ b/firebase.go @@ -38,15 +38,7 @@ import ( "google.golang.org/api/transport" ) -var firebaseScopes = []string{ - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/datastore", - "https://www.googleapis.com/auth/devstorage.full_control", - "https://www.googleapis.com/auth/firebase", - "https://www.googleapis.com/auth/identitytoolkit", - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/firebase.messaging", -} +var defaultAuthOverrides = make(map[string]interface{}) // Version of the Firebase Go Admin SDK. const Version = "2.5.0" @@ -66,10 +58,10 @@ type App struct { // Config represents the configuration used to initialize an App. type Config struct { - AuthOverride *db.AuthOverride `json:"databaseAuthVariableOverride"` - DatabaseURL string `json:"databaseURL"` - ProjectID string `json:"projectId"` - StorageBucket string `json:"storageBucket"` + AuthOverride *map[string]interface{} `json:"databaseAuthVariableOverride"` + DatabaseURL string `json:"databaseURL"` + ProjectID string `json:"projectId"` + StorageBucket string `json:"storageBucket"` } // Auth returns an instance of auth.Client. @@ -161,9 +153,9 @@ func NewApp(ctx context.Context, config *Config, opts ...option.ClientOption) (* pid = os.Getenv("GCLOUD_PROJECT") } - ao := make(map[string]interface{}) + ao := defaultAuthOverrides if config.AuthOverride != nil { - ao = config.AuthOverride.Map + ao = *config.AuthOverride } return &App{ @@ -193,6 +185,19 @@ func getConfigDefaults() (*Config, error) { return nil, err } } - err := json.Unmarshal(dat, fbc) - return fbc, err + if err := json.Unmarshal(dat, fbc); err != nil { + return nil, err + } + + // Some special handling necessary for db auth overrides + var m map[string]interface{} + if err := json.Unmarshal(dat, &m); err != nil { + return nil, err + } + if ao, ok := m["databaseAuthVariableOverride"]; ok && ao == nil { + // Auth overrides are explicitly set to null + var nullMap map[string]interface{} + fbc.AuthOverride = &nullMap + } + return fbc, nil } diff --git a/firebase_test.go b/firebase_test.go index 874b5d8e..93087c10 100644 --- a/firebase_test.go +++ b/firebase_test.go @@ -34,7 +34,6 @@ import ( "reflect" - "firebase.google.com/go/db" "golang.org/x/net/context" "golang.org/x/oauth2" "google.golang.org/api/option" @@ -255,7 +254,7 @@ func TestDatabaseAuthOverrides(t *testing.T) { for _, tc := range cases { ctx := context.Background() conf := &Config{ - AuthOverride: &db.AuthOverride{tc}, + AuthOverride: &tc, DatabaseURL: "https://mock-db.firebaseio.com", } app, err := NewApp(ctx, conf, option.WithCredentialsFile("testdata/service_account.json")) @@ -405,7 +404,10 @@ func TestVersion(t *testing.T) { } } } + func TestAutoInit(t *testing.T) { + var nullMap map[string]interface{} + uidMap := map[string]interface{}{"uid": "test"} tests := []struct { name string optionsConfig string @@ -504,6 +506,34 @@ func TestAutoInit(t *testing.T) { StorageBucket: "auto-init.storage.bucket", }, }, + { + "", + `{ + "databaseURL": "https://auto-init.database.url", + "projectId": "auto-init-project-id", + "databaseAuthVariableOverride": null + }`, + nil, + &Config{ + DatabaseURL: "https://auto-init.database.url", + ProjectID: "auto-init-project-id", + AuthOverride: &nullMap, + }, + }, + { + "", + `{ + "databaseURL": "https://auto-init.database.url", + "projectId": "auto-init-project-id", + "databaseAuthVariableOverride": {"uid": "test"} + }`, + nil, + &Config{ + DatabaseURL: "https://auto-init.database.url", + ProjectID: "auto-init-project-id", + AuthOverride: &uidMap, + }, + }, } credOld := overwriteEnv(credEnvVar, "testdata/service_account.json") @@ -574,6 +604,13 @@ func compareConfig(got *App, want *Config, t *testing.T) { if got.dbURL != want.DatabaseURL { t.Errorf("app.dbURL = %q; want = %q", got.dbURL, want.DatabaseURL) } + if want.AuthOverride != nil { + if !reflect.DeepEqual(got.ao, *want.AuthOverride) { + t.Errorf("app.ao = %#v; want = %#v", got.ao, *want.AuthOverride) + } + } else if !reflect.DeepEqual(got.ao, defaultAuthOverrides) { + t.Errorf("app.ao = %#v; want = nil", got.ao) + } if got.projectID != want.ProjectID { t.Errorf("app.projectID = %q; want = %q", got.projectID, want.ProjectID) } diff --git a/integration/db/db_test.go b/integration/db/db_test.go index d201fe63..893ca180 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -98,11 +98,10 @@ func initClient(pid string) (*db.Client, error) { func initOverrideClient(pid string) (*db.Client, error) { ctx := context.Background() + ao := map[string]interface{}{"uid": "user1"} app, err := internal.NewTestApp(ctx, &firebase.Config{ - DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), - AuthOverride: &db.AuthOverride{ - Map: map[string]interface{}{"uid": "user1"}, - }, + DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), + AuthOverride: &ao, }) if err != nil { return nil, err @@ -113,9 +112,10 @@ func initOverrideClient(pid string) (*db.Client, error) { func initGuestClient(pid string) (*db.Client, error) { ctx := context.Background() + var nullMap map[string]interface{} app, err := internal.NewTestApp(ctx, &firebase.Config{ DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), - AuthOverride: &db.AuthOverride{}, + AuthOverride: &nullMap, }) if err != nil { return nil, err From 14b35912d320159cd2f6f2b4eff5435de75f615f Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Thu, 15 Feb 2018 15:16:40 -0800 Subject: [PATCH 55/58] Removed db.AuthOverride type --- db/db.go | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/db/db.go b/db/db.go index 81f16cea..b0a97ef1 100644 --- a/db/db.go +++ b/db/db.go @@ -87,20 +87,6 @@ func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) }, nil } -// AuthOverride regulates how Firebase security rules are enforced on database invocations. -// -// By default, the database calls made by the Admin SDK have administrative privileges, thereby -// allowing them to completely bypass all Firebase security rules. This behavior can be overridden -// by setting an AuthOverride. When specified, the AuthOverride value will become visible to the -// database server during security rule evaluation. Specifically, this value will be accessible -// via the auth variable of the security rules. -// -// Refer to https://firebase.google.com/docs/database/admin/start#authenticate-with-limited-privileges -// for more details and code samples. -type AuthOverride struct { - Map map[string]interface{} -} - // NewRef returns a new database reference representing the node at the specified path. func (c *Client) NewRef(path string) *Ref { segs := parsePath(path) From c38f24e9f8a0730bab7d1d066f0c5bf33e5d0a0b Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Fri, 16 Feb 2018 11:42:23 -0800 Subject: [PATCH 56/58] Renamed ao to authOverride everywhere; Other code review nits --- db/db.go | 24 +++++++++++------------ db/db_test.go | 42 ++++++++++++++++++++--------------------- db/query.go | 18 +++++++++--------- db/ref.go | 2 ++ db/ref_test.go | 2 +- firebase.go | 12 ++++++------ firebase_test.go | 16 ++++++++-------- internal/internal.go | 8 ++++---- testdata/dinosaurs.json | 2 +- 9 files changed, 64 insertions(+), 62 deletions(-) diff --git a/db/db.go b/db/db.go index b0a97ef1..6bed3922 100644 --- a/db/db.go +++ b/db/db.go @@ -36,9 +36,9 @@ const authVarOverride = "auth_variable_override" // Client is the interface for the Firebase Realtime Database service. type Client struct { - hc *internal.HTTPClient - url string - ao string + hc *internal.HTTPClient + url string + authOverride string } // NewClient creates a new instance of the Firebase Database Client. @@ -58,14 +58,14 @@ func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) if err != nil { return nil, err } else if p.Scheme != "https" { - return nil, fmt.Errorf("invalid database URL (incorrect scheme): %q", c.URL) + return nil, fmt.Errorf("invalid database URL: %q; want scheme: %q", c.URL, "https") } else if !strings.HasSuffix(p.Host, ".firebaseio.com") { - return nil, fmt.Errorf("invalid database URL (incorrest host): %q", c.URL) + return nil, fmt.Errorf("invalid database URL: %q; want host: %q", c.URL, "firebaseio.com") } var ao []byte - if c.AO == nil || len(c.AO) > 0 { - ao, err = json.Marshal(c.AO) + if c.AuthOverride == nil || len(c.AuthOverride) > 0 { + ao, err = json.Marshal(c.AuthOverride) if err != nil { return nil, err } @@ -81,9 +81,9 @@ func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) return p.Error } return &Client{ - hc: &internal.HTTPClient{Client: hc, ErrParser: ep}, - url: fmt.Sprintf("https://%s", p.Host), - ao: string(ao), + hc: &internal.HTTPClient{Client: hc, ErrParser: ep}, + url: fmt.Sprintf("https://%s", p.Host), + authOverride: string(ao), }, nil } @@ -112,8 +112,8 @@ func (c *Client) send( if strings.ContainsAny(path, invalidChars) { return nil, fmt.Errorf("invalid path with illegal characters: %q", path) } - if c.ao != "" { - opts = append(opts, internal.WithQueryParam(authVarOverride, c.ao)) + if c.authOverride != "" { + opts = append(opts, internal.WithQueryParam(authVarOverride, c.authOverride)) } return c.hc.Do(ctx, &internal.Request{ Method: method, diff --git a/db/db_test.go b/db/db_test.go index 0ec601cf..01234504 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -53,10 +53,10 @@ var testref *Ref func TestMain(m *testing.M) { var err error client, err = NewClient(context.Background(), &internal.DatabaseConfig{ - Opts: testOpts, - URL: testURL, - Version: "1.2.3", - AO: map[string]interface{}{}, + Opts: testOpts, + URL: testURL, + Version: "1.2.3", + AuthOverride: map[string]interface{}{}, }) if err != nil { log.Fatalln(err) @@ -64,10 +64,10 @@ func TestMain(m *testing.M) { ao := map[string]interface{}{"uid": "user1"} aoClient, err = NewClient(context.Background(), &internal.DatabaseConfig{ - Opts: testOpts, - URL: testURL, - Version: "1.2.3", - AO: ao, + Opts: testOpts, + URL: testURL, + Version: "1.2.3", + AuthOverride: ao, }) if err != nil { log.Fatalln(err) @@ -86,9 +86,9 @@ func TestMain(m *testing.M) { func TestNewClient(t *testing.T) { c, err := NewClient(context.Background(), &internal.DatabaseConfig{ - Opts: testOpts, - URL: testURL, - AO: make(map[string]interface{}), + Opts: testOpts, + URL: testURL, + AuthOverride: make(map[string]interface{}), }) if err != nil { t.Fatal(err) @@ -99,8 +99,8 @@ func TestNewClient(t *testing.T) { if c.hc == nil { t.Errorf("NewClient().hc = nil; want non-nil") } - if c.ao != "" { - t.Errorf("NewClient().ao = %q; want = %q", c.ao, "") + if c.authOverride != "" { + t.Errorf("NewClient().ao = %q; want = %q", c.authOverride, "") } } @@ -111,9 +111,9 @@ func TestNewClientAuthOverrides(t *testing.T) { } for _, tc := range cases { c, err := NewClient(context.Background(), &internal.DatabaseConfig{ - Opts: testOpts, - URL: testURL, - AO: tc, + Opts: testOpts, + URL: testURL, + AuthOverride: tc, }) if err != nil { t.Fatal(err) @@ -128,8 +128,8 @@ func TestNewClientAuthOverrides(t *testing.T) { if err != nil { t.Fatal(err) } - if c.ao != string(b) { - t.Errorf("NewClient(%v).ao = %q; want = %q", tc, c.ao, string(b)) + if c.authOverride != string(b) { + t.Errorf("NewClient(%v).ao = %q; want = %q", tc, c.authOverride, string(b)) } } } @@ -154,9 +154,9 @@ func TestInvalidURL(t *testing.T) { func TestInvalidAuthOverride(t *testing.T) { c, err := NewClient(context.Background(), &internal.DatabaseConfig{ - Opts: testOpts, - URL: testURL, - AO: map[string]interface{}{"uid": func() {}}, + Opts: testOpts, + URL: testURL, + AuthOverride: map[string]interface{}{"uid": func() {}}, }) if c != nil || err == nil { t.Errorf("NewClient() = (%v, %v); want = (nil, error)", c, err) diff --git a/db/query.go b/db/query.go index 08536ccd..c4753a35 100644 --- a/db/query.go +++ b/db/query.go @@ -39,7 +39,7 @@ import ( type Query struct { client *Client path string - ob orderBy + order orderBy limFirst, limLast int start, end, equalTo interface{} } @@ -48,7 +48,7 @@ type Query struct { // // The resulting Query will only return child nodes with a value greater than or equal to v. func (q *Query) StartAt(v interface{}) *Query { - q2 := new(Query) + q2 := &Query{} *q2 = *q q2.start = v return q2 @@ -58,7 +58,7 @@ func (q *Query) StartAt(v interface{}) *Query { // // The resulting Query will only return child nodes with a value less than or equal to v. func (q *Query) EndAt(v interface{}) *Query { - q2 := new(Query) + q2 := &Query{} *q2 = *q q2.end = v return q2 @@ -68,7 +68,7 @@ func (q *Query) EndAt(v interface{}) *Query { // // The resulting Query will only return child nodes whose values equal to v. func (q *Query) EqualTo(v interface{}) *Query { - q2 := new(Query) + q2 := &Query{} *q2 = *q q2.equalTo = v return q2 @@ -77,7 +77,7 @@ func (q *Query) EqualTo(v interface{}) *Query { // LimitToFirst returns a shallow copy of the Query, which is anchored to the first n // elements of the window. func (q *Query) LimitToFirst(n int) *Query { - q2 := new(Query) + q2 := &Query{} *q2 = *q q2.limFirst = n return q2 @@ -86,7 +86,7 @@ func (q *Query) LimitToFirst(n int) *Query { // LimitToLast returns a shallow copy of the Query, which is anchored to the last n // elements of the window. func (q *Query) LimitToLast(n int) *Query { - q2 := new(Query) + q2 := &Query{} *q2 = *q q2.limLast = n return q2 @@ -134,7 +134,7 @@ func (q *Query) GetOrdered(ctx context.Context, v interface{}) error { if temp == nil { return nil } - sr := newSortableQueryResult(temp, q.ob) + sr := newSortableQueryResult(temp, q.order) sort.Sort(sr) var values []interface{} @@ -179,12 +179,12 @@ func newQuery(r *Ref, ob orderBy) *Query { return &Query{ client: r.client, path: r.Path, - ob: ob, + order: ob, } } func initQueryParams(q *Query, qp map[string]string) error { - ob, err := q.ob.encode() + ob, err := q.order.encode() if err != nil { return err } diff --git a/db/ref.go b/db/ref.go index ea5c697e..8fbadf84 100644 --- a/db/ref.go +++ b/db/ref.go @@ -25,6 +25,8 @@ import ( "golang.org/x/net/context" ) +// txnRetires is the maximum number of times a transaction is retried before giving up. Transaction +// retries are triggered by concurrent conflicting updates to the same database location. const txnRetries = 25 // Ref represents a node in the Firebase Realtime Database. diff --git a/db/ref_test.go b/db/ref_test.go index 7786b5b5..93e348d0 100644 --- a/db/ref_test.go +++ b/db/ref_test.go @@ -278,7 +278,7 @@ func TestGetIfChanged(t *testing.T) { }) } -func TestWerlformedHttpError(t *testing.T) { +func TestWelformedHttpError(t *testing.T) { mock := &mockServer{Resp: map[string]string{"error": "test error"}, Status: 500} srv := mock.Start(client) defer srv.Close() diff --git a/firebase.go b/firebase.go index d06a8483..37f7f6f1 100644 --- a/firebase.go +++ b/firebase.go @@ -48,7 +48,7 @@ const firebaseEnvName = "FIREBASE_CONFIG" // An App holds configuration and state common to all Firebase services that are exposed from the SDK. type App struct { - ao map[string]interface{} + authOverride map[string]interface{} creds *google.DefaultCredentials dbURL string projectID string @@ -78,10 +78,10 @@ func (a *App) Auth(ctx context.Context) (*auth.Client, error) { // Database returns an instance of db.Client. func (a *App) Database(ctx context.Context) (*db.Client, error) { conf := &internal.DatabaseConfig{ - AO: a.ao, - URL: a.dbURL, - Opts: a.opts, - Version: Version, + AuthOverride: a.authOverride, + URL: a.dbURL, + Opts: a.opts, + Version: Version, } return db.NewClient(ctx, conf) } @@ -159,7 +159,7 @@ func NewApp(ctx context.Context, config *Config, opts ...option.ClientOption) (* } return &App{ - ao: ao, + authOverride: ao, creds: creds, dbURL: config.DatabaseURL, projectID: pid, diff --git a/firebase_test.go b/firebase_test.go index 93087c10..a5b2db84 100644 --- a/firebase_test.go +++ b/firebase_test.go @@ -237,8 +237,8 @@ func TestDatabase(t *testing.T) { t.Fatal(err) } - if app.ao == nil || len(app.ao) != 0 { - t.Errorf("AuthOverrides = %v; want = empty map", app.ao) + if app.authOverride == nil || len(app.authOverride) != 0 { + t.Errorf("AuthOverrides = %v; want = empty map", app.authOverride) } if c, err := app.Database(ctx); c == nil || err != nil { t.Errorf("Database() = (%v, %v); want (db, nil)", c, err) @@ -262,8 +262,8 @@ func TestDatabaseAuthOverrides(t *testing.T) { t.Fatal(err) } - if !reflect.DeepEqual(app.ao, tc) { - t.Errorf("AuthOverrides = %v; want = %v", app.ao, tc) + if !reflect.DeepEqual(app.authOverride, tc) { + t.Errorf("AuthOverrides = %v; want = %v", app.authOverride, tc) } if c, err := app.Database(ctx); c == nil || err != nil { t.Errorf("Database() = (%v, %v); want (db, nil)", c, err) @@ -605,11 +605,11 @@ func compareConfig(got *App, want *Config, t *testing.T) { t.Errorf("app.dbURL = %q; want = %q", got.dbURL, want.DatabaseURL) } if want.AuthOverride != nil { - if !reflect.DeepEqual(got.ao, *want.AuthOverride) { - t.Errorf("app.ao = %#v; want = %#v", got.ao, *want.AuthOverride) + if !reflect.DeepEqual(got.authOverride, *want.AuthOverride) { + t.Errorf("app.ao = %#v; want = %#v", got.authOverride, *want.AuthOverride) } - } else if !reflect.DeepEqual(got.ao, defaultAuthOverrides) { - t.Errorf("app.ao = %#v; want = nil", got.ao) + } else if !reflect.DeepEqual(got.authOverride, defaultAuthOverrides) { + t.Errorf("app.ao = %#v; want = nil", got.authOverride) } if got.projectID != want.ProjectID { t.Errorf("app.projectID = %q; want = %q", got.projectID, want.ProjectID) diff --git a/internal/internal.go b/internal/internal.go index c520a3ee..bc4f41d1 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -47,10 +47,10 @@ type InstanceIDConfig struct { // DatabaseConfig represents the configuration of Firebase Database service. type DatabaseConfig struct { - Opts []option.ClientOption - URL string - Version string - AO map[string]interface{} + Opts []option.ClientOption + URL string + Version string + AuthOverride map[string]interface{} } // StorageConfig represents the configuration of Google Cloud Storage service. diff --git a/testdata/dinosaurs.json b/testdata/dinosaurs.json index 29ca1936..9d7afaab 100644 --- a/testdata/dinosaurs.json +++ b/testdata/dinosaurs.json @@ -75,4 +75,4 @@ "stegosaurus": 5, "triceratops": 22 } -} \ No newline at end of file +} From db9249ee92e0cdeac21ceef7f52fae3262541b62 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 27 Feb 2018 12:07:10 -0800 Subject: [PATCH 57/58] Introducing the QueryNode interface to handle ordered query results (#100) --- db/query.go | 90 +++++------ db/query_test.go | 290 +++++++++++++++++++++-------------- integration/db/query_test.go | 210 +++++++++++++------------ 3 files changed, 337 insertions(+), 253 deletions(-) diff --git a/db/query.go b/db/query.go index c4753a35..c6013483 100644 --- a/db/query.go +++ b/db/query.go @@ -18,7 +18,6 @@ import ( "encoding/json" "fmt" "net/http" - "reflect" "sort" "strconv" "strings" @@ -28,6 +27,12 @@ import ( "golang.org/x/net/context" ) +// QueryNode represents a data node retrieved from an ordered query. +type QueryNode interface { + Key() string + Unmarshal(v interface{}) error +} + // Query represents a complex query that can be executed on a Ref. // // Complex queries can consist of up to 2 components: a required ordering constraint, and an @@ -112,40 +117,23 @@ func (q *Query) Get(ctx context.Context, v interface{}) error { return resp.Unmarshal(http.StatusOK, v) } -// GetOrdered executes the Query and provides the results as an ordered list. -// -// v must be a pointer to an array or a slice. Only the child values returned by the query are -// unmarshalled into v. Top-level keys are not returned. Although if the Query was created using -// OrderByKey(), the returned values will still be ordered based on their keys. -func (q *Query) GetOrdered(ctx context.Context, v interface{}) error { - rv := reflect.ValueOf(v) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return fmt.Errorf("nil or not a pointer") - } - if rv.Elem().Kind() != reflect.Slice && rv.Elem().Kind() != reflect.Array { - return fmt.Errorf("non-array non-slice pointer") - } - +// GetOrdered executes the Query and returns the results as an ordered slice. +func (q *Query) GetOrdered(ctx context.Context) ([]QueryNode, error) { var temp interface{} if err := q.Get(ctx, &temp); err != nil { - return err + return nil, err } - if temp == nil { - return nil + return nil, nil } - sr := newSortableQueryResult(temp, q.order) - sort.Sort(sr) - var values []interface{} - for _, val := range sr { - values = append(values, val.Value) - } - b, err := json.Marshal(values) - if err != nil { - return err + sn := newSortableNodes(temp, q.order) + sort.Sort(sn) + result := make([]QueryNode, len(sn)) + for i, v := range sn { + result[i] = v } - return json.Unmarshal(b, v) + return result, nil } // OrderByChild returns a Query that orders data by child values before applying filters. @@ -307,14 +295,30 @@ func newComparableKey(v interface{}) *comparableKey { return &comparableKey{Num: &f} } -type queryResult struct { - Key *comparableKey +type queryNodeImpl struct { + CompKey *comparableKey Value interface{} Index interface{} IndexType int } -func newQueryResult(key, val interface{}, order orderBy) *queryResult { +func (q *queryNodeImpl) Key() string { + if q.CompKey.Str != nil { + return *q.CompKey.Str + } + // Numeric keys in queryNodeImpl are always array indices, and can be safely coverted into int. + return strconv.Itoa(int(*q.CompKey.Num)) +} + +func (q *queryNodeImpl) Unmarshal(v interface{}) error { + b, err := json.Marshal(q.Value) + if err != nil { + return err + } + return json.Unmarshal(b, v) +} + +func newQueryNode(key, val interface{}, order orderBy) *queryNodeImpl { var index interface{} if prop, ok := order.(orderByProperty); ok { if prop == "$value" { @@ -326,25 +330,25 @@ func newQueryResult(key, val interface{}, order orderBy) *queryResult { path := order.(orderByChild) index = extractChildValue(val, string(path)) } - return &queryResult{ - Key: newComparableKey(key), + return &queryNodeImpl{ + CompKey: newComparableKey(key), Value: val, Index: index, IndexType: getIndexType(index), } } -type sortableQueryResult []*queryResult +type sortableNodes []*queryNodeImpl -func (s sortableQueryResult) Len() int { +func (s sortableNodes) Len() int { return len(s) } -func (s sortableQueryResult) Swap(i, j int) { +func (s sortableNodes) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -func (s sortableQueryResult) Less(i, j int) bool { +func (s sortableNodes) Less(i, j int) bool { a, b := s[i], s[j] var aKey, bKey *comparableKey if a.IndexType == b.IndexType { @@ -353,7 +357,7 @@ func (s sortableQueryResult) Less(i, j int) bool { if (a.IndexType == typeNumeric || a.IndexType == typeString) && a.Index != b.Index { aKey, bKey = newComparableKey(a.Index), newComparableKey(b.Index) } else { - aKey, bKey = a.Key, b.Key + aKey, bKey = a.CompKey, b.CompKey } } else { // If the indices are of different types, use the type ordering of Firebase. @@ -363,18 +367,18 @@ func (s sortableQueryResult) Less(i, j int) bool { return aKey.Compare(bKey) < 0 } -func newSortableQueryResult(values interface{}, order orderBy) sortableQueryResult { - var entries sortableQueryResult +func newSortableNodes(values interface{}, order orderBy) sortableNodes { + var entries sortableNodes if m, ok := values.(map[string]interface{}); ok { for key, val := range m { - entries = append(entries, newQueryResult(key, val, order)) + entries = append(entries, newQueryNode(key, val, order)) } } else if l, ok := values.([]interface{}); ok { for key, val := range l { - entries = append(entries, newQueryResult(key, val, order)) + entries = append(entries, newQueryNode(key, val, order)) } } else { - entries = append(entries, newQueryResult(0, values, order)) + entries = append(entries, newQueryNode(0, values, order)) } return entries } diff --git a/db/query_test.go b/db/query_test.go index 234764e5..4473daff 100644 --- a/db/query_test.go +++ b/db/query_test.go @@ -30,55 +30,67 @@ var sortableKeysResp = map[string]interface{}{ } var sortableValuesResp = []struct { - resp map[string]interface{} - want []interface{} + resp map[string]interface{} + want []interface{} + wantKeys []string }{ { - resp: map[string]interface{}{"k1": 1, "k2": 2, "k3": 3}, - want: []interface{}{1.0, 2.0, 3.0}, + resp: map[string]interface{}{"k1": 1, "k2": 2, "k3": 3}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"k1", "k2", "k3"}, }, { - resp: map[string]interface{}{"k1": 3, "k2": 2, "k3": 1}, - want: []interface{}{1.0, 2.0, 3.0}, + resp: map[string]interface{}{"k1": 3, "k2": 2, "k3": 1}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"k3", "k2", "k1"}, }, { - resp: map[string]interface{}{"k1": 3, "k2": 1, "k3": 2}, - want: []interface{}{1.0, 2.0, 3.0}, + resp: map[string]interface{}{"k1": 3, "k2": 1, "k3": 2}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"k2", "k3", "k1"}, }, { - resp: map[string]interface{}{"k1": 1, "k2": 2, "k3": 1}, - want: []interface{}{1.0, 1.0, 2.0}, + resp: map[string]interface{}{"k1": 1, "k2": 2, "k3": 1}, + want: []interface{}{1.0, 1.0, 2.0}, + wantKeys: []string{"k1", "k3", "k2"}, }, { - resp: map[string]interface{}{"k1": 1, "k2": 1, "k3": 2}, - want: []interface{}{1.0, 1.0, 2.0}, + resp: map[string]interface{}{"k1": 1, "k2": 1, "k3": 2}, + want: []interface{}{1.0, 1.0, 2.0}, + wantKeys: []string{"k1", "k2", "k3"}, }, { - resp: map[string]interface{}{"k1": 2, "k2": 1, "k3": 1}, - want: []interface{}{1.0, 1.0, 2.0}, + resp: map[string]interface{}{"k1": 2, "k2": 1, "k3": 1}, + want: []interface{}{1.0, 1.0, 2.0}, + wantKeys: []string{"k2", "k3", "k1"}, }, { - resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": "baz"}, - want: []interface{}{"bar", "baz", "foo"}, + resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": "baz"}, + want: []interface{}{"bar", "baz", "foo"}, + wantKeys: []string{"k2", "k3", "k1"}, }, { - resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": 10}, - want: []interface{}{10.0, "bar", "foo"}, + resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": 10}, + want: []interface{}{10.0, "bar", "foo"}, + wantKeys: []string{"k3", "k2", "k1"}, }, { - resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": nil}, - want: []interface{}{nil, "bar", "foo"}, + resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": nil}, + want: []interface{}{nil, "bar", "foo"}, + wantKeys: []string{"k3", "k2", "k1"}, }, { - resp: map[string]interface{}{"k1": 5, "k2": "bar", "k3": nil}, - want: []interface{}{nil, 5.0, "bar"}, + resp: map[string]interface{}{"k1": 5, "k2": "bar", "k3": nil}, + want: []interface{}{nil, 5.0, "bar"}, + wantKeys: []string{"k3", "k1", "k2"}, }, { resp: map[string]interface{}{ "k1": true, "k2": 0, "k3": "foo", "k4": "foo", "k5": false, "k6": map[string]interface{}{"k1": true}, }, - want: []interface{}{false, true, 0.0, "foo", "foo", map[string]interface{}{"k1": true}}, + want: []interface{}{false, true, 0.0, "foo", "foo", map[string]interface{}{"k1": true}}, + wantKeys: []string{"k5", "k1", "k2", "k3", "k4", "k6"}, }, { resp: map[string]interface{}{ @@ -90,6 +102,7 @@ var sortableValuesResp = []struct { nil, false, true, 0.0, "foo", "foo", map[string]interface{}{"k1": true}, map[string]interface{}{"k0": true}, }, + wantKeys: []string{"k7", "k5", "k1", "k2", "k3", "k4", "k6", "k8"}, }, } @@ -406,29 +419,6 @@ func TestAllParamsQuery(t *testing.T) { }) } -func TestInvalidGetOrdered(t *testing.T) { - q := testref.OrderByKey() - - want := "nil or not a pointer" - var p *[]person // nil - err := q.GetOrdered(context.Background(), p) - if err == nil || err.Error() != want { - t.Errorf("GetOrdered(interface) = %v; want = %v", err, want) - } - - var i interface{} // not a pointer - err = q.GetOrdered(context.Background(), i) - if err == nil || err.Error() != want { - t.Errorf("GetOrdered(interface) = %v; want = %v", err, want) - } - - want = "non-array non-slice pointer" - err = q.GetOrdered(context.Background(), &i) // pointer to a non-array value - if err == nil || err.Error() != want { - t.Errorf("GetOrdered(interface) = %v; want = %v", err, want) - } -} - func TestChildQueryGetOrdered(t *testing.T) { mock := &mockServer{Resp: sortableKeysResp} srv := mock.Start(client) @@ -444,9 +434,9 @@ func TestChildQueryGetOrdered(t *testing.T) { } var reqs []*testReq - for _, tc := range cases { - var result []person - if err := testref.OrderByChild(tc.child).GetOrdered(context.Background(), &result); err != nil { + for idx, tc := range cases { + result, err := testref.OrderByChild(tc.child).GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } reqs = append(reqs, &testReq{ @@ -455,12 +445,20 @@ func TestChildQueryGetOrdered(t *testing.T) { Query: map[string]string{"orderBy": fmt.Sprintf("%q", tc.child)}, }) - var got []string + var gotKeys, gotVals []string for _, r := range result { - got = append(got, r.Name) + var p person + if err := r.Unmarshal(&p); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, p.Name) } - if !reflect.DeepEqual(tc.want, got) { - t.Errorf("GetOrdered(child: %q) = %v; want = %v", tc.child, got, tc.want) + if !reflect.DeepEqual(tc.want, gotKeys) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, tc.child, gotKeys, tc.want) + } + if !reflect.DeepEqual(tc.want, gotVals) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, tc.child, gotVals, tc.want) } } checkAllRequests(t, mock.Reqs, reqs) @@ -476,15 +474,15 @@ func TestImmediateChildQueryGetOrdered(t *testing.T) { } var reqs []*testReq - for _, tc := range sortableValuesResp { + for idx, tc := range sortableValuesResp { resp := map[string]interface{}{} for k, v := range tc.resp { resp[k] = map[string]interface{}{"child": v} } mock.Resp = resp - var result []parsedMap - if err := testref.OrderByChild("child").GetOrdered(context.Background(), &result); err != nil { + result, err := testref.OrderByChild("child").GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } reqs = append(reqs, &testReq{ @@ -493,12 +491,21 @@ func TestImmediateChildQueryGetOrdered(t *testing.T) { Query: map[string]string{"orderBy": "\"child\""}, }) - var got []interface{} + var gotKeys []string + var gotVals []interface{} for _, r := range result { - got = append(got, r.Child) + var p parsedMap + if err := r.Unmarshal(&p); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, p.Child) + } + if !reflect.DeepEqual(tc.wantKeys, gotKeys) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, "child", gotKeys, tc.wantKeys) } - if !reflect.DeepEqual(tc.want, got) { - t.Errorf("GetOrdered(child: %q) = %v; want = %v", "child", got, tc.want) + if !reflect.DeepEqual(tc.want, gotVals) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, "child", gotVals, tc.want) } } checkAllRequests(t, mock.Reqs, reqs) @@ -517,16 +524,16 @@ func TestNestedChildQueryGetOrdered(t *testing.T) { } var reqs []*testReq - for _, tc := range sortableValuesResp { + for idx, tc := range sortableValuesResp { resp := map[string]interface{}{} for k, v := range tc.resp { resp[k] = map[string]interface{}{"child": map[string]interface{}{"grandchild": v}} } mock.Resp = resp - var result []parsedMap q := testref.OrderByChild("child/grandchild") - if err := q.GetOrdered(context.Background(), &result); err != nil { + result, err := q.GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } reqs = append(reqs, &testReq{ @@ -535,12 +542,21 @@ func TestNestedChildQueryGetOrdered(t *testing.T) { Query: map[string]string{"orderBy": "\"child/grandchild\""}, }) - var got []interface{} + var gotKeys []string + var gotVals []interface{} for _, r := range result { - got = append(got, r.Child.GrandChild) + var p parsedMap + if err := r.Unmarshal(&p); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, p.Child.GrandChild) } - if !reflect.DeepEqual(tc.want, got) { - t.Errorf("GetOrdered(child: %q) = %v; want = %v", "child/grandchild", got, tc.want) + if !reflect.DeepEqual(tc.wantKeys, gotKeys) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, "child/grandchild", gotKeys, tc.wantKeys) + } + if !reflect.DeepEqual(tc.want, gotVals) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, "child/grandchild", gotVals, tc.want) } } checkAllRequests(t, mock.Reqs, reqs) @@ -551,8 +567,8 @@ func TestKeyQueryGetOrdered(t *testing.T) { srv := mock.Start(client) defer srv.Close() - var result []person - if err := testref.OrderByKey().GetOrdered(context.Background(), &result); err != nil { + result, err := testref.OrderByKey().GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } req := &testReq{ @@ -561,14 +577,22 @@ func TestKeyQueryGetOrdered(t *testing.T) { Query: map[string]string{"orderBy": "\"$key\""}, } - var got []string + var gotKeys, gotVals []string for _, r := range result { - got = append(got, r.Name) + var p person + if err := r.Unmarshal(&p); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, p.Name) } want := []string{"alice", "bob", "charlie", "dave", "ernie"} - if !reflect.DeepEqual(want, got) { - t.Errorf("GetOrdered(key) = %v; want = %v", got, want) + if !reflect.DeepEqual(want, gotKeys) { + t.Errorf("GetOrdered(key) = %v; want = %v", gotKeys, want) + } + if !reflect.DeepEqual(want, gotVals) { + t.Errorf("GetOrdered(key) = %v; want = %v", gotVals, want) } checkOnlyRequest(t, mock.Reqs, req) } @@ -579,11 +603,11 @@ func TestValueQueryGetOrdered(t *testing.T) { defer srv.Close() var reqs []*testReq - for _, tc := range sortableValuesResp { + for idx, tc := range sortableValuesResp { mock.Resp = tc.resp - var got []interface{} - if err := testref.OrderByValue().GetOrdered(context.Background(), &got); err != nil { + result, err := testref.OrderByValue().GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } reqs = append(reqs, &testReq{ @@ -592,8 +616,22 @@ func TestValueQueryGetOrdered(t *testing.T) { Query: map[string]string{"orderBy": "\"$value\""}, }) - if !reflect.DeepEqual(tc.want, got) { - t.Errorf("GetOrdered(value) = %v; want = %v", got, tc.want) + var gotKeys []string + var gotVals []interface{} + for _, r := range result { + var v interface{} + if err := r.Unmarshal(&v); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, v) + } + + if !reflect.DeepEqual(tc.wantKeys, gotKeys) { + t.Errorf("[%d] GetOrdered(value) = %v; want = %v", idx, gotKeys, tc.wantKeys) + } + if !reflect.DeepEqual(tc.want, gotVals) { + t.Errorf("[%d] GetOrdered(value) = %v; want = %v", idx, gotVals, tc.want) } } checkAllRequests(t, mock.Reqs, reqs) @@ -601,36 +639,44 @@ func TestValueQueryGetOrdered(t *testing.T) { func TestValueQueryGetOrderedWithList(t *testing.T) { cases := []struct { - resp []interface{} - want []interface{} + resp []interface{} + want []interface{} + wantKeys []string }{ { - resp: []interface{}{1, 2, 3}, - want: []interface{}{1.0, 2.0, 3.0}, + resp: []interface{}{1, 2, 3}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"0", "1", "2"}, }, { - resp: []interface{}{3, 2, 1}, - want: []interface{}{1.0, 2.0, 3.0}, + resp: []interface{}{3, 2, 1}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"2", "1", "0"}, }, { - resp: []interface{}{1, 3, 2}, - want: []interface{}{1.0, 2.0, 3.0}, + resp: []interface{}{1, 3, 2}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"0", "2", "1"}, }, { - resp: []interface{}{1, 3, 3}, - want: []interface{}{1.0, 3.0, 3.0}, + resp: []interface{}{1, 3, 3}, + want: []interface{}{1.0, 3.0, 3.0}, + wantKeys: []string{"0", "1", "2"}, }, { - resp: []interface{}{1, 2, 1}, - want: []interface{}{1.0, 1.0, 2.0}, + resp: []interface{}{1, 2, 1}, + want: []interface{}{1.0, 1.0, 2.0}, + wantKeys: []string{"0", "2", "1"}, }, { - resp: []interface{}{"foo", "bar", "baz"}, - want: []interface{}{"bar", "baz", "foo"}, + resp: []interface{}{"foo", "bar", "baz"}, + want: []interface{}{"bar", "baz", "foo"}, + wantKeys: []string{"1", "2", "0"}, }, { - resp: []interface{}{"foo", 1, false, nil, 0, true}, - want: []interface{}{nil, false, true, 0.0, 1.0, "foo"}, + resp: []interface{}{"foo", 1, false, nil, 0, true}, + want: []interface{}{nil, false, true, 0.0, 1.0, "foo"}, + wantKeys: []string{"3", "2", "5", "4", "1", "0"}, }, } @@ -642,8 +688,8 @@ func TestValueQueryGetOrderedWithList(t *testing.T) { for _, tc := range cases { mock.Resp = tc.resp - var got []interface{} - if err := testref.OrderByValue().GetOrdered(context.Background(), &got); err != nil { + result, err := testref.OrderByValue().GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } reqs = append(reqs, &testReq{ @@ -652,8 +698,22 @@ func TestValueQueryGetOrderedWithList(t *testing.T) { Query: map[string]string{"orderBy": "\"$value\""}, }) - if !reflect.DeepEqual(tc.want, got) { - t.Errorf("GetOrdered(value) = %v; want = %v", got, tc.want) + var gotKeys []string + var gotVals []interface{} + for _, r := range result { + var v interface{} + if err := r.Unmarshal(&v); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, v) + } + + if !reflect.DeepEqual(tc.wantKeys, gotKeys) { + t.Errorf("GetOrdered(value) = %v; want = %v", gotKeys, tc.wantKeys) + } + if !reflect.DeepEqual(tc.want, gotVals) { + t.Errorf("GetOrdered(value) = %v; want = %v", gotVals, tc.want) } } checkAllRequests(t, mock.Reqs, reqs) @@ -664,12 +724,12 @@ func TestGetOrderedWithNilResult(t *testing.T) { srv := mock.Start(client) defer srv.Close() - var got []interface{} - if err := testref.OrderByChild("child").GetOrdered(context.Background(), &got); err != nil { + result, err := testref.OrderByChild("child").GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } - if got != nil { - t.Errorf("GetOrdered(value) = %v; want = nil", got) + if result != nil { + t.Errorf("GetOrdered(value) = %v; want = nil", result) } } @@ -678,14 +738,23 @@ func TestGetOrderedWithLeafNode(t *testing.T) { srv := mock.Start(client) defer srv.Close() - var got []interface{} - if err := testref.OrderByChild("child").GetOrdered(context.Background(), &got); err != nil { + result, err := testref.OrderByChild("child").GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } + if len(result) != 1 { + t.Fatalf("GetOrdered(chid) = %d; want = 1", len(result)) + } + if result[0].Key() != "0" { + t.Errorf("GetOrdered(value).Key() = %v; want = %q", result[0].Key(), 0) + } - want := []interface{}{"foo"} - if !reflect.DeepEqual(want, got) { - t.Errorf("GetOrdered(value) = %v; want = %v", got, want) + var v interface{} + if err := result[0].Unmarshal(&v); err != nil { + t.Fatal(err) + } + if v != "foo" { + t.Errorf("GetOrdered(value) = %v; want = %v", v, "foo") } } @@ -695,12 +764,11 @@ func TestQueryHttpError(t *testing.T) { defer srv.Close() want := "http error status: 500; reason: test error" - var got []string - err := testref.OrderByChild("child").GetOrdered(context.Background(), &got) + result, err := testref.OrderByChild("child").GetOrdered(context.Background()) if err == nil || err.Error() != want { t.Errorf("GetOrdered() = %v; want = %v", err, want) } - if got != nil { - t.Errorf("GetOrdered() = %v; want = nil", got) + if result != nil { + t.Errorf("GetOrdered() = %v; want = nil", result) } } diff --git a/integration/db/query_test.go b/integration/db/query_test.go index 72b2e507..6573d915 100644 --- a/integration/db/query_test.go +++ b/integration/db/query_test.go @@ -17,6 +17,10 @@ package db import ( "testing" + "firebase.google.com/go/db" + + "reflect" + "golang.org/x/net/context" ) @@ -25,190 +29,171 @@ var heightSorted = []string{ "triceratops", "stegosaurus", "bruhathkayosaurus", } -func min(i, j int) int { - if i < j { - return i - } - return j -} - func TestLimitToFirst(t *testing.T) { for _, tc := range []int{2, 10} { - var d []Dinosaur - if err := dinos.OrderByChild("height"). - LimitToFirst(tc). - GetOrdered(context.Background(), &d); err != nil { + results, err := dinos.OrderByChild("height").LimitToFirst(tc).GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } wl := min(tc, len(heightSorted)) want := heightSorted[:wl] - if len(d) != wl { - t.Errorf("LimitToFirst() = %d; want = %d", len(d), wl) + if len(results) != wl { + t.Errorf("LimitToFirst() = %d; want = %d", len(results), wl) } - for i, w := range want { - if d[i] != parsedTestData[w] { - t.Errorf("[%d] LimitToFirst() = %v; want = %v", i, d[i], parsedTestData[w]) - } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) } + compareValues(t, results) } } func TestLimitToLast(t *testing.T) { for _, tc := range []int{2, 10} { - var d []Dinosaur - if err := dinos.OrderByChild("height"). - LimitToLast(tc). - GetOrdered(context.Background(), &d); err != nil { + results, err := dinos.OrderByChild("height").LimitToLast(tc).GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } wl := min(tc, len(heightSorted)) want := heightSorted[len(heightSorted)-wl:] - if len(d) != wl { - t.Errorf("LimitToLast() = %d; want = %d", len(d), wl) + if len(results) != wl { + t.Errorf("LimitToLast() = %d; want = %d", len(results), wl) } - for i, w := range want { - if d[i] != parsedTestData[w] { - t.Errorf("[%d] LimitToLast() = %v; want = %v", i, d[i], parsedTestData[w]) - } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) } + compareValues(t, results) } } func TestStartAt(t *testing.T) { - var d []Dinosaur - if err := dinos.OrderByChild("height"). - StartAt(3.5). - GetOrdered(context.Background(), &d); err != nil { + results, err := dinos.OrderByChild("height").StartAt(3.5).GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } want := heightSorted[len(heightSorted)-2:] - if len(d) != len(want) { - t.Errorf("StartAt() = %d; want = %d", len(d), len(want)) + if len(results) != len(want) { + t.Errorf("StartAt() = %d; want = %d", len(results), len(want)) } - for i, w := range want { - if d[i] != parsedTestData[w] { - t.Errorf("[%d] StartAt() = %v; want = %v", i, d[i], parsedTestData[w]) - } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) } + compareValues(t, results) } func TestEndAt(t *testing.T) { - var d []Dinosaur - if err := dinos.OrderByChild("height"). - EndAt(3.5). - GetOrdered(context.Background(), &d); err != nil { + results, err := dinos.OrderByChild("height").EndAt(3.5).GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } want := heightSorted[:4] - if len(d) != len(want) { - t.Errorf("StartAt() = %d; want = %d", len(d), len(want)) + if len(results) != len(want) { + t.Errorf("StartAt() = %d; want = %d", len(results), len(want)) } - for i, w := range want { - if d[i] != parsedTestData[w] { - t.Errorf("[%d] EndAt() = %v; want = %v", i, d[i], parsedTestData[w]) - } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) } + compareValues(t, results) } func TestStartAndEndAt(t *testing.T) { - var d []Dinosaur - if err := dinos.OrderByChild("height"). - StartAt(2.5). - EndAt(5). - GetOrdered(context.Background(), &d); err != nil { + results, err := dinos.OrderByChild("height").StartAt(2.5).EndAt(5).GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } want := heightSorted[len(heightSorted)-3 : len(heightSorted)-1] - if len(d) != len(want) { - t.Errorf("StartAt(), EndAt() = %d; want = %d", len(d), len(want)) + if len(results) != len(want) { + t.Errorf("StartAt(), EndAt() = %d; want = %d", len(results), len(want)) } - for i, w := range want { - if d[i] != parsedTestData[w] { - t.Errorf("[%d] StartAt(), EndAt() = %v; want = %v", i, d[i], parsedTestData[w]) - } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) } + compareValues(t, results) } func TestEqualTo(t *testing.T) { - var d []Dinosaur - if err := dinos.OrderByChild("height"). - EqualTo(0.6). - GetOrdered(context.Background(), &d); err != nil { + results, err := dinos.OrderByChild("height").EqualTo(0.6).GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } want := heightSorted[:2] - if len(d) != len(want) { - t.Errorf("EqualTo() = %d; want = %d", len(d), len(want)) + if len(results) != len(want) { + t.Errorf("EqualTo() = %d; want = %d", len(results), len(want)) } - for i, w := range want { - if d[i] != parsedTestData[w] { - t.Errorf("[%d] EqualTo() = %v; want = %v", i, d[i], parsedTestData[w]) - } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) } + compareValues(t, results) } func TestOrderByNestedChild(t *testing.T) { - var d []Dinosaur - if err := dinos.OrderByChild("ratings/pos"). - StartAt(4). - GetOrdered(context.Background(), &d); err != nil { + results, err := dinos.OrderByChild("ratings/pos").StartAt(4).GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } want := []string{"pterodactyl", "stegosaurus", "triceratops"} - if len(d) != len(want) { - t.Errorf("OrderByChild(ratings/pos) = %d; want = %d", len(d), len(want)) + if len(results) != len(want) { + t.Errorf("OrderByChild(ratings/pos) = %d; want = %d", len(results), len(want)) } - for i, w := range want { - if d[i] != parsedTestData[w] { - t.Errorf("[%d] OrderByChild(ratings/pos) = %v; want = %v", i, d[i], parsedTestData[w]) - } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) } + compareValues(t, results) } func TestOrderByKey(t *testing.T) { - var d []Dinosaur - if err := dinos.OrderByKey(). - LimitToFirst(2). - GetOrdered(context.Background(), &d); err != nil { + results, err := dinos.OrderByKey().LimitToFirst(2).GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } want := []string{"bruhathkayosaurus", "lambeosaurus"} - if len(d) != len(want) { - t.Errorf("OrderByKey() = %d; want = %d", len(d), len(want)) + if len(results) != len(want) { + t.Errorf("OrderByKey() = %d; want = %d", len(results), len(want)) } - for i, w := range want { - if d[i] != parsedTestData[w] { - t.Errorf("[%d] OrderByKey() = %v; want = %v", i, d[i], parsedTestData[w]) - } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) } + compareValues(t, results) } func TestOrderByValue(t *testing.T) { scores := ref.Child("scores") - var s []int - if err := scores.OrderByValue(). - LimitToLast(2). - GetOrdered(context.Background(), &s); err != nil { + results, err := scores.OrderByValue().LimitToLast(2).GetOrdered(context.Background()) + if err != nil { t.Fatal(err) } want := []string{"linhenykus", "pterodactyl"} - if len(s) != len(want) { - t.Errorf("OrderByValue() = %d; want = %d", len(s), len(want)) + if len(results) != len(want) { + t.Errorf("OrderByValue() = %d; want = %d", len(results), len(want)) } - scoresData := testData["scores"].(map[string]interface{}) - for i, w := range want { - ws := int(scoresData[w].(float64)) - if s[i] != ws { - t.Errorf("[%d] OrderByValue() = %d; want = %d", i, s[i], ws) + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + wantScores := []int{80, 93} + for i, r := range results { + var val int + if err := r.Unmarshal(&val); err != nil { + t.Fatalf("queryNode.Unmarshal() = %v", err) + } + if val != wantScores[i] { + t.Errorf("queryNode.Unmarshal() = %d; want = %d", val, wantScores[i]) } } } @@ -252,3 +237,30 @@ func TestUnorderedQuery(t *testing.T) { } } } + +func min(i, j int) int { + if i < j { + return i + } + return j +} + +func getNames(results []db.QueryNode) []string { + s := make([]string, len(results)) + for i, v := range results { + s[i] = v.Key() + } + return s +} + +func compareValues(t *testing.T, results []db.QueryNode) { + for _, r := range results { + var d Dinosaur + if err := r.Unmarshal(&d); err != nil { + t.Fatalf("queryNode.Unmarshal(%q) = %v", r.Key(), err) + } + if !reflect.DeepEqual(d, parsedTestData[r.Key()]) { + t.Errorf("queryNode.Unmarshal(%q) = %v; want = %v", r.Key(), d, parsedTestData[r.Key()]) + } + } +} From bbcdae23d18f59301ea31634d61733276907779b Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 27 Feb 2018 13:56:00 -0800 Subject: [PATCH 58/58] Database Sample Snippets (#102) * Adding database snippets * Adding query snippets * Added complex query samples * Updated variable name * Fixing a typo * Fixing query example * Updated DB snippets to use GetOrdered() * Removing unnecessary placeholders in Fatalln() calls * Removing unnecessary placeholders in Fatalln() calls --- integration/db/db_test.go | 2 +- snippets/auth.go | 25 +- snippets/db.go | 528 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 540 insertions(+), 15 deletions(-) create mode 100644 snippets/db.go diff --git a/integration/db/db_test.go b/integration/db/db_test.go index 893ca180..0754d5bf 100644 --- a/integration/db/db_test.go +++ b/integration/db/db_test.go @@ -156,7 +156,7 @@ func initRules() { if err != nil { log.Fatalln(err) } else if resp.StatusCode != http.StatusOK { - log.Fatalln("failed to update rules: %q", string(b)) + log.Fatalln("failed to update rules:", string(b)) } } diff --git a/snippets/auth.go b/snippets/auth.go index d9548e94..9fb739ba 100644 --- a/snippets/auth.go +++ b/snippets/auth.go @@ -93,10 +93,8 @@ func verifyIDToken(app *firebase.App, idToken string) *auth.Token { // https://firebase.google.com/docs/auth/admin/manage-sessions // ================================================================== -func revokeRefreshTokens(app *firebase.App, uid string) { - +func revokeRefreshTokens(ctx context.Context, app *firebase.App, uid string) { // [START revoke_tokens_golang] - ctx := context.Background() client, err := app.Auth(ctx) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) @@ -114,8 +112,7 @@ func revokeRefreshTokens(app *firebase.App, uid string) { // [END revoke_tokens_golang] } -func verifyIDTokenAndCheckRevoked(app *firebase.App, idToken string) *auth.Token { - ctx := context.Background() +func verifyIDTokenAndCheckRevoked(ctx context.Context, app *firebase.App, idToken string) *auth.Token { // [START verify_id_token_and_check_revoked_golang] client, err := app.Auth(ctx) if err != nil { @@ -144,7 +141,7 @@ func getUser(ctx context.Context, app *firebase.App) *auth.UserRecord { // [START get_user_golang] // Get an auth client from the firebase.App - client, err := app.Auth(context.Background()) + client, err := app.Auth(ctx) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) } @@ -192,7 +189,7 @@ func createUser(ctx context.Context, client *auth.Client) *auth.UserRecord { DisplayName("John Doe"). PhotoURL("http://www.example.com/12345678/photo.png"). Disabled(false) - u, err := client.CreateUser(context.Background(), params) + u, err := client.CreateUser(ctx, params) if err != nil { log.Fatalf("error creating user: %v\n", err) } @@ -208,7 +205,7 @@ func createUserWithUID(ctx context.Context, client *auth.Client) *auth.UserRecor UID(uid). Email("user@example.com"). PhoneNumber("+15555550100") - u, err := client.CreateUser(context.Background(), params) + u, err := client.CreateUser(ctx, params) if err != nil { log.Fatalf("error creating user: %v\n", err) } @@ -228,7 +225,7 @@ func updateUser(ctx context.Context, client *auth.Client) { DisplayName("John Doe"). PhotoURL("http://www.example.com/12345678/photo.png"). Disabled(true) - u, err := client.UpdateUser(context.Background(), uid, params) + u, err := client.UpdateUser(ctx, uid, params) if err != nil { log.Fatalf("error updating user: %v\n", err) } @@ -239,7 +236,7 @@ func updateUser(ctx context.Context, client *auth.Client) { func deleteUser(ctx context.Context, client *auth.Client) { uid := "d" // [START delete_user_golang] - err := client.DeleteUser(context.Background(), uid) + err := client.DeleteUser(ctx, uid) if err != nil { log.Fatalf("error deleting user: %v\n", err) } @@ -251,14 +248,14 @@ func customClaimsSet(ctx context.Context, app *firebase.App) { uid := "uid" // [START set_custom_user_claims_golang] // Get an auth client from the firebase.App - client, err := app.Auth(context.Background()) + client, err := app.Auth(ctx) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) } // Set admin privilege on the user corresponding to uid. claims := map[string]interface{}{"admin": true} - err = client.SetCustomUserClaims(context.Background(), uid, claims) + err = client.SetCustomUserClaims(ctx, uid, claims) if err != nil { log.Fatalf("error setting custom claims %v\n", err) } @@ -350,7 +347,7 @@ func customClaimsIncremental(ctx context.Context, client *auth.Client) { func listUsers(ctx context.Context, client *auth.Client) { // [START list_all_users_golang] // Note, behind the scenes, the Users() iterator will retrive 1000 Users at a time through the API - iter := client.Users(context.Background(), "") + iter := client.Users(ctx, "") for { user, err := iter.Next() if err == iterator.Done { @@ -365,7 +362,7 @@ func listUsers(ctx context.Context, client *auth.Client) { // Iterating by pages 100 users at a time. // Note that using both the Next() function on an iterator and the NextPage() // on a Pager wrapping that same iterator will result in an error. - pager := iterator.NewPager(client.Users(context.Background(), ""), 100, "") + pager := iterator.NewPager(client.Users(ctx, ""), 100, "") for { var users []*auth.ExportedUserRecord nextPageToken, err := pager.NextPage(&users) diff --git a/snippets/db.go b/snippets/db.go new file mode 100644 index 00000000..8e0bea71 --- /dev/null +++ b/snippets/db.go @@ -0,0 +1,528 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package snippets + +// [START authenticate_db_imports] +import ( + "context" + "fmt" + "log" + + "firebase.google.com/go/db" + + "firebase.google.com/go" + "google.golang.org/api/option" +) + +// [END authenticate_db_imports] + +func authenticateWithAdminPrivileges() { + // [START authenticate_with_admin_privileges] + ctx := context.Background() + conf := &firebase.Config{ + DatabaseURL: "https://databaseName.firebaseio.com", + } + // Fetch the service account key JSON file contents + opt := option.WithCredentialsFile("path/to/serviceAccountKey.json") + + // Initialize the app with a service account, granting admin privileges + app, err := firebase.NewApp(ctx, conf, opt) + if err != nil { + log.Fatalln("Error initializing app:", err) + } + + client, err := app.Database(ctx) + if err != nil { + log.Fatalln("Error initializing database client:", err) + } + + // As an admin, the app has access to read and write all data, regradless of Security Rules + ref := client.NewRef("restricted_access/secret_document") + var data map[string]interface{} + if err := ref.Get(ctx, &data); err != nil { + log.Fatalln("Error reading from database:", err) + } + fmt.Println(data) + // [END authenticate_with_admin_privileges] +} + +func authenticateWithLimitedPrivileges() { + // [START authenticate_with_limited_privileges] + ctx := context.Background() + // Initialize the app with a custom auth variable, limiting the server's access + ao := map[string]interface{}{"uid": "my-service-worker"} + conf := &firebase.Config{ + DatabaseURL: "https://databaseName.firebaseio.com", + AuthOverride: &ao, + } + + // Fetch the service account key JSON file contents + opt := option.WithCredentialsFile("path/to/serviceAccountKey.json") + + app, err := firebase.NewApp(ctx, conf, opt) + if err != nil { + log.Fatalln("Error initializing app:", err) + } + + client, err := app.Database(ctx) + if err != nil { + log.Fatalln("Error initializing database client:", err) + } + + // The app only has access as defined in the Security Rules + ref := client.NewRef("/some_resource") + var data map[string]interface{} + if err := ref.Get(ctx, &data); err != nil { + log.Fatalln("Error reading from database:", err) + } + fmt.Println(data) + // [END authenticate_with_limited_privileges] +} + +func authenticateWithGuestPrivileges() { + // [START authenticate_with_guest_privileges] + ctx := context.Background() + // Initialize the app with a nil auth variable, limiting the server's access + var nilMap map[string]interface{} + conf := &firebase.Config{ + DatabaseURL: "https://databaseName.firebaseio.com", + AuthOverride: &nilMap, + } + + // Fetch the service account key JSON file contents + opt := option.WithCredentialsFile("path/to/serviceAccountKey.json") + + app, err := firebase.NewApp(ctx, conf, opt) + if err != nil { + log.Fatalln("Error initializing app:", err) + } + + client, err := app.Database(ctx) + if err != nil { + log.Fatalln("Error initializing database client:", err) + } + + // The app only has access to public data as defined in the Security Rules + ref := client.NewRef("/some_resource") + var data map[string]interface{} + if err := ref.Get(ctx, &data); err != nil { + log.Fatalln("Error reading from database:", err) + } + fmt.Println(data) + // [END authenticate_with_guest_privileges] +} + +func getReference(ctx context.Context, app *firebase.App) { + // [START get_reference] + // Create a database client from App. + client, err := app.Database(ctx) + if err != nil { + log.Fatalln("Error initializing database client:", err) + } + + // Get a database reference to our blog. + ref := client.NewRef("server/saving-data/fireblog") + // [END get_reference] + fmt.Println(ref.Path) +} + +// [START user_type] + +// User is a json-serializable type. +type User struct { + DateOfBirth string `json:"date_of_birth,omitempty"` + FullName string `json:"full_name,omitempty"` + Nickname string `json:"nickname,omitempty"` +} + +// [END user_type] + +func setValue(ctx context.Context, ref *db.Ref) { + // [START set_value] + usersRef := ref.Child("users") + err := usersRef.Set(ctx, map[string]*User{ + "alanisawesome": &User{ + DateOfBirth: "June 23, 1912", + FullName: "Alan Turing", + }, + "gracehop": &User{ + DateOfBirth: "December 9, 1906", + FullName: "Grace Hopper", + }, + }) + if err != nil { + log.Fatalln("Error setting value:", err) + } + // [END set_value] +} + +func setChildValue(ctx context.Context, usersRef *db.Ref) { + // [START set_child_value] + if err := usersRef.Child("alanisawesome").Set(ctx, &User{ + DateOfBirth: "June 23, 1912", + FullName: "Alan Turing", + }); err != nil { + log.Fatalln("Error setting value:", err) + } + + if err := usersRef.Child("gracehop").Set(ctx, &User{ + DateOfBirth: "December 9, 1906", + FullName: "Grace Hopper", + }); err != nil { + log.Fatalln("Error setting value:", err) + } + // [END set_child_value] +} + +func updateChild(ctx context.Context, usersRef *db.Ref) { + // [START update_child] + hopperRef := usersRef.Child("gracehop") + if err := hopperRef.Update(ctx, map[string]interface{}{ + "nickname": "Amazing Grace", + }); err != nil { + log.Fatalln("Error updating child:", err) + } + // [END update_child] +} + +func updateChildren(ctx context.Context, usersRef *db.Ref) { + // [START update_children] + if err := usersRef.Update(ctx, map[string]interface{}{ + "alanisawesome/nickname": "Alan The Machine", + "gracehop/nickname": "Amazing Grace", + }); err != nil { + log.Fatalln("Error updating children:", err) + } + // [END update_children] +} + +func overwriteValue(ctx context.Context, usersRef *db.Ref) { + // [START overwrite_value] + if err := usersRef.Update(ctx, map[string]interface{}{ + "alanisawesome": &User{Nickname: "Alan The Machine"}, + "gracehop": &User{Nickname: "Amazing Grace"}, + }); err != nil { + log.Fatalln("Error updating children:", err) + } + // [END overwrite_value] +} + +// [START post_type] + +// Post is a json-serializable type. +type Post struct { + Author string `json:"author,omitempty"` + Title string `json:"title,omitempty"` +} + +// [END post_type] + +func pushValue(ctx context.Context, ref *db.Ref) { + // [START push_value] + postsRef := ref.Child("posts") + + newPostRef, err := postsRef.Push(ctx, nil) + if err != nil { + log.Fatalln("Error pushing child node:", err) + } + + if err := newPostRef.Set(ctx, &Post{ + Author: "gracehop", + Title: "Announcing COBOL, a New Programming Language", + }); err != nil { + log.Fatalln("Error setting value:", err) + } + + // We can also chain the two calls together + if _, err := postsRef.Push(ctx, &Post{ + Author: "alanisawesome", + Title: "The Turing Machine", + }); err != nil { + log.Fatalln("Error pushing child node:", err) + } + // [END push_value] +} + +func pushAndSetValue(ctx context.Context, postsRef *db.Ref) { + // [START push_and_set_value] + if _, err := postsRef.Push(ctx, &Post{ + Author: "gracehop", + Title: "Announcing COBOL, a New Programming Language", + }); err != nil { + log.Fatalln("Error pushing child node:", err) + } + // [END push_and_set_value] +} + +func pushKey(ctx context.Context, postsRef *db.Ref) { + // [START push_key] + // Generate a reference to a new location and add some data using Push() + newPostRef, err := postsRef.Push(ctx, nil) + if err != nil { + log.Fatalln("Error pushing child node:", err) + } + + // Get the unique key generated by Push() + postID := newPostRef.Key + // [END push_key] + fmt.Println(postID) +} + +func transaction(ctx context.Context, client *db.Client) { + // [START transaction] + fn := func(t db.TransactionNode) (interface{}, error) { + var currentValue int + if err := t.Unmarshal(¤tValue); err != nil { + return nil, err + } + return currentValue + 1, nil + } + + ref := client.NewRef("server/saving-data/fireblog/posts/-JRHTHaIs-jNPLXOQivY/upvotes") + if err := ref.Transaction(ctx, fn); err != nil { + log.Fatalln("Transaction failed to commit:", err) + } + // [END transaction] +} + +func readValue(ctx context.Context, app *firebase.App) { + // [START read_value] + // Create a database client from App. + client, err := app.Database(ctx) + if err != nil { + log.Fatalln("Error initializing database client:", err) + } + + // Get a database reference to our posts + ref := client.NewRef("server/saving-data/fireblog/posts") + + // Read the data at the posts reference (this is a blocking operation) + var post Post + if err := ref.Get(ctx, &post); err != nil { + log.Fatalln("Error reading value:", err) + } + // [END read_value] + fmt.Println(ref.Path) +} + +// [START dinosaur_type] + +// Dinosaur is a json-serializable type. +type Dinosaur struct { + Height int `json:"height"` + Width int `json:"width"` +} + +// [END dinosaur_type] + +func orderByChild(ctx context.Context, client *db.Client) { + // [START order_by_child] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("height").GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + var d Dinosaur + if err := r.Unmarshal(&d); err != nil { + log.Fatalln("Error unmarshaling result:", err) + } + fmt.Printf("%s was %d meteres tall", r.Key(), d.Height) + } + // [END order_by_child] +} + +func orderByNestedChild(ctx context.Context, client *db.Client) { + // [START order_by_nested_child] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("dimensions/height").GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + var d Dinosaur + if err := r.Unmarshal(&d); err != nil { + log.Fatalln("Error unmarshaling result:", err) + } + fmt.Printf("%s was %d meteres tall", r.Key(), d.Height) + } + // [END order_by_nested_child] +} + +func orderByKey(ctx context.Context, client *db.Client) { + // [START order_by_key] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByKey().GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + snapshot := make([]Dinosaur, len(results)) + for i, r := range results { + var d Dinosaur + if err := r.Unmarshal(&d); err != nil { + log.Fatalln("Error unmarshaling result:", err) + } + snapshot[i] = d + } + fmt.Println(snapshot) + // [END order_by_key] +} + +func orderByValue(ctx context.Context, client *db.Client) { + // [START order_by_value] + ref := client.NewRef("scores") + + results, err := ref.OrderByValue().GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + var score int + if err := r.Unmarshal(&score); err != nil { + log.Fatalln("Error unmarshaling result:", err) + } + fmt.Printf("The %s dinosaur's score is %d\n", r.Key(), score) + } + // [END order_by_value] +} + +func limitToLast(ctx context.Context, client *db.Client) { + // [START limit_query_1] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("weight").LimitToLast(2).GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END limit_query_1] +} + +func limitToFirst(ctx context.Context, client *db.Client) { + // [START limit_query_2] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("height").LimitToFirst(2).GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END limit_query_2] +} + +func limitWithValueOrder(ctx context.Context, client *db.Client) { + // [START limit_query_3] + ref := client.NewRef("scores") + + results, err := ref.OrderByValue().LimitToLast(3).GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + var score int + if err := r.Unmarshal(&score); err != nil { + log.Fatalln("Error unmarshaling result:", err) + } + fmt.Printf("The %s dinosaur's score is %d\n", r.Key(), score) + } + // [END limit_query_3] +} + +func startAt(ctx context.Context, client *db.Client) { + // [START range_query_1] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("height").StartAt(3).GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END range_query_1] +} + +func endAt(ctx context.Context, client *db.Client) { + // [START range_query_2] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByKey().EndAt("pterodactyl").GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END range_query_2] +} + +func startAndEndAt(ctx context.Context, client *db.Client) { + // [START range_query_3] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByKey().StartAt("b").EndAt("b\uf8ff").GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END range_query_3] +} + +func equalTo(ctx context.Context, client *db.Client) { + // [START range_query_4] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("height").EqualTo(25).GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END range_query_4] +} + +func complexQuery(ctx context.Context, client *db.Client) { + // [START complex_query] + ref := client.NewRef("dinosaurs") + + var favDinoHeight int + if err := ref.Child("stegosaurus").Child("height").Get(ctx, &favDinoHeight); err != nil { + log.Fatalln("Error querying database:", err) + } + + query := ref.OrderByChild("height").EndAt(favDinoHeight).LimitToLast(2) + results, err := query.GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + if len(results) == 2 { + // Data is ordered by increasing height, so we want the first entry. + // Second entry is stegosarus. + fmt.Printf("The dinosaur just shorter than the stegosaurus is %s\n", results[0].Key()) + } else { + fmt.Println("The stegosaurus is the shortest dino") + } + // [END complex_query] +}