diff --git a/web/cache.go b/web/cache.go new file mode 100644 index 00000000..9c567214 --- /dev/null +++ b/web/cache.go @@ -0,0 +1,91 @@ +// Copyright 2021 The Prometheus Authors +// This code is partly borrowed from Caddy: +// Copyright 2015 Matthew Holt and The Caddy Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + weakrand "math/rand" + "sync" + "time" +) + +var cacheSize = 100 + +func init() { + weakrand.Seed(time.Now().UnixNano()) +} + +type cache struct { + cache map[string]bool + mtx sync.Mutex +} + +// newCache returns a cache that contains a mapping of plaintext passwords +// to their hashes (with random eviction). This can greatly improve the +// performance of traffic-heavy servers that use secure password hashing +// algorithms, with the downside that plaintext passwords will be stored in +// memory for a longer time (this should not be a problem as long as your +// machine is not compromised, at which point all bets are off, since basicauth +// necessitates plaintext passwords being received over the wire anyway). +func newCache(size int) *cache { + return &cache{ + cache: make(map[string]bool, size), + } +} + +func (c *cache) get(key string) (bool, bool) { + c.mtx.Lock() + defer c.mtx.Unlock() + v, ok := c.cache[key] + return v, ok +} + +func (c *cache) set(key string, value bool) { + c.mtx.Lock() + defer c.mtx.Unlock() + c.makeRoom() + c.cache[key] = value +} + +func (c *cache) makeRoom() { + if len(c.cache) < cacheSize { + return + } + // We delete more than just 1 entry so that we don't have + // to do this on every request; assuming the capacity of + // the cache is on a long tail, we can save a lot of CPU + // time by doing a whole bunch of deletions now and then + // we won't have to do them again for a while. + numToDelete := len(c.cache) / 10 + if numToDelete < 1 { + numToDelete = 1 + } + for deleted := 0; deleted <= numToDelete; deleted++ { + // Go maps are "nondeterministic" not actually random, + // so although we could just chop off the "front" of the + // map with less code, this is a heavily skewed eviction + // strategy; generating random numbers is cheap and + // ensures a much better distribution. + rnd := weakrand.Intn(len(c.cache)) + i := 0 + for key := range c.cache { + if i == rnd { + delete(c.cache, key) + break + } + i++ + } + } +} diff --git a/web/cache_test.go b/web/cache_test.go new file mode 100644 index 00000000..eb165298 --- /dev/null +++ b/web/cache_test.go @@ -0,0 +1,37 @@ +// Copyright 2021 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "fmt" + "testing" +) + +// TestCacheSize validates that makeRoom function caps the size of the cache +// appropriately. +func TestCacheSize(t *testing.T) { + cache := newCache(100) + expectedSize := 0 + for i := 0; i < 200; i++ { + cache.set(fmt.Sprintf("foo%d", i), true) + expectedSize++ + if expectedSize > 100 { + expectedSize = 90 + } + + if gotSize := len(cache.cache); gotSize != expectedSize { + t.Fatalf("iter %d: cache size invalid: expected %d, got %d", i, expectedSize, gotSize) + } + } +} diff --git a/web/tls_config.go b/web/tls_config.go index 2f244a51..3b2a959d 100644 --- a/web/tls_config.go +++ b/web/tls_config.go @@ -201,17 +201,19 @@ func Serve(l net.Listener, server *http.Server, tlsConfigPath string, logger log if server.Handler != nil { handler = server.Handler } - server.Handler = &userAuthRoundtrip{ - tlsConfigPath: tlsConfigPath, - logger: logger, - handler: handler, - } c, err := getConfig(tlsConfigPath) if err != nil { return err } + server.Handler = &userAuthRoundtrip{ + tlsConfigPath: tlsConfigPath, + logger: logger, + handler: handler, + cache: newCache(len(c.Users)), + } + config, err := ConfigToTLSConfig(&c.TLSConfig) switch err { case nil: diff --git a/web/tls_config_test.go b/web/tls_config_test.go index 4b3e3c36..80641a14 100644 --- a/web/tls_config_test.go +++ b/web/tls_config_test.go @@ -382,16 +382,14 @@ func (test *TestInputs) Test(t *testing.T) { w.Write([]byte("Hello World!")) }), } - defer func() { - server.Close() - }() + t.Cleanup(func() { server.Close() }) go func() { defer func() { if recover() != nil { recordConnectionError(errors.New("Panic starting server")) } }() - err := Listen(server, test.YAMLConfigPath, testlogger) + err := ListenAndServe(server, test.YAMLConfigPath, testlogger) recordConnectionError(err) }() diff --git a/web/users.go b/web/users.go index 7b9cd6a2..3f0d132d 100644 --- a/web/users.go +++ b/web/users.go @@ -1,4 +1,6 @@ // Copyright 2020 The Prometheus Authors +// This code is partly borrowed from Caddy: +// Copyright 2015 Matthew Holt and The Caddy Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -14,7 +16,9 @@ package web import ( + "encoding/hex" "net/http" + "sync" "github.com/go-kit/kit/log" "golang.org/x/crypto/bcrypt" @@ -40,6 +44,10 @@ type userAuthRoundtrip struct { tlsConfigPath string handler http.Handler logger log.Logger + cache *cache + // bcryptMtx is there to ensure that bcrypt.CompareHashAndPassword is run + // only once in parallel as this is CPU expansive. + bcryptMtx sync.Mutex } func (u *userAuthRoundtrip) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -58,7 +66,20 @@ func (u *userAuthRoundtrip) ServeHTTP(w http.ResponseWriter, r *http.Request) { user, pass, auth := r.BasicAuth() if auth { if hashedPassword, ok := c.Users[user]; ok { - if err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(pass)); err == nil { + cacheKey := hex.EncodeToString(append(append([]byte(user), []byte(hashedPassword)...), []byte(pass)...)) + authOk, ok := u.cache.get(cacheKey) + + if !ok { + // This user, hashedPassword, password is not cached. + u.bcryptMtx.Lock() + err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(pass)) + u.bcryptMtx.Unlock() + + authOk = err == nil + u.cache.set(cacheKey, authOk) + } + + if authOk { u.handler.ServeHTTP(w, r) return } diff --git a/web/users_test.go b/web/users_test.go new file mode 100644 index 00000000..6a46917b --- /dev/null +++ b/web/users_test.go @@ -0,0 +1,85 @@ +// Copyright 2021 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "context" + "net/http" + "sync" + "testing" +) + +// TestBasicAuthCache validates that the cache is working by calling a password +// protected endpoint multiple times. +func TestBasicAuthCache(t *testing.T) { + server := &http.Server{ + Addr: port, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello World!")) + }), + } + + done := make(chan struct{}) + t.Cleanup(func() { + if err := server.Shutdown(context.Background()); err != nil { + t.Fatal(err) + } + <-done + }) + + go func() { + ListenAndServe(server, "testdata/tls_config_users_noTLS.good.yml", testlogger) + close(done) + }() + + login := func(username, password string, code int) { + client := &http.Client{} + req, err := http.NewRequest("GET", "http://localhost"+port, nil) + if err != nil { + t.Fatal(err) + } + req.SetBasicAuth(username, password) + r, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + if r.StatusCode != code { + t.Fatalf("bad return code, expected %d, got %d", code, r.StatusCode) + } + } + + // Initial logins, checking that it just works. + login("alice", "alice123", 200) + login("alice", "alice1234", 401) + + var ( + start = make(chan struct{}) + wg sync.WaitGroup + ) + wg.Add(300) + for i := 0; i < 150; i++ { + go func() { + <-start + login("alice", "alice123", 200) + wg.Done() + }() + go func() { + <-start + login("alice", "alice1234", 401) + wg.Done() + }() + } + close(start) + wg.Wait() +}