|
16 | 16 | package web |
17 | 17 |
|
18 | 18 | import ( |
| 19 | + "context" |
19 | 20 | "crypto/tls" |
20 | 21 | "crypto/x509" |
21 | 22 | "errors" |
@@ -382,16 +383,14 @@ func (test *TestInputs) Test(t *testing.T) { |
382 | 383 | w.Write([]byte("Hello World!")) |
383 | 384 | }), |
384 | 385 | } |
385 | | - defer func() { |
386 | | - server.Close() |
387 | | - }() |
| 386 | + t.Cleanup(func() { server.Close() }) |
388 | 387 | go func() { |
389 | 388 | defer func() { |
390 | 389 | if recover() != nil { |
391 | 390 | recordConnectionError(errors.New("Panic starting server")) |
392 | 391 | } |
393 | 392 | }() |
394 | | - err := Listen(server, test.YAMLConfigPath, testlogger) |
| 393 | + err := ListenAndServe(server, test.YAMLConfigPath, testlogger) |
395 | 394 | recordConnectionError(err) |
396 | 395 | }() |
397 | 396 |
|
@@ -587,3 +586,67 @@ func TestUsers(t *testing.T) { |
587 | 586 | t.Run(testInputs.Name, testInputs.Test) |
588 | 587 | } |
589 | 588 | } |
| 589 | + |
| 590 | +// TestBasicAuthCache validates that the cache is working by calling a password |
| 591 | +// protected endpoint multiple times. |
| 592 | +func TestBasicAuthCache(t *testing.T) { |
| 593 | + server := &http.Server{ |
| 594 | + Addr: port, |
| 595 | + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 596 | + w.Write([]byte("Hello World!")) |
| 597 | + }), |
| 598 | + } |
| 599 | + |
| 600 | + done := make(chan struct{}) |
| 601 | + t.Cleanup(func() { |
| 602 | + if err := server.Shutdown(context.Background()); err != nil { |
| 603 | + t.Fatal(err) |
| 604 | + } |
| 605 | + <-done |
| 606 | + }) |
| 607 | + |
| 608 | + go func() { |
| 609 | + ListenAndServe(server, "testdata/tls_config_users_noTLS.good.yml", testlogger) |
| 610 | + close(done) |
| 611 | + }() |
| 612 | + |
| 613 | + login := func(username, password string, code int) { |
| 614 | + client := &http.Client{} |
| 615 | + req, err := http.NewRequest("GET", "http://localhost"+port, nil) |
| 616 | + if err != nil { |
| 617 | + t.Fatal(err) |
| 618 | + } |
| 619 | + req.SetBasicAuth(username, password) |
| 620 | + r, err := client.Do(req) |
| 621 | + if err != nil { |
| 622 | + t.Fatal(err) |
| 623 | + } |
| 624 | + if r.StatusCode != code { |
| 625 | + t.Fatalf("bad return code, expected %d, got %d", code, r.StatusCode) |
| 626 | + } |
| 627 | + } |
| 628 | + |
| 629 | + // Initial logins, checking that it just works. |
| 630 | + login("alice", "alice123", 200) |
| 631 | + login("alice", "alice1234", 401) |
| 632 | + |
| 633 | + var ( |
| 634 | + start = make(chan struct{}) |
| 635 | + wg sync.WaitGroup |
| 636 | + ) |
| 637 | + wg.Add(200) |
| 638 | + for i := 0; i < 100; i++ { |
| 639 | + go func() { |
| 640 | + <-start |
| 641 | + login("alice", "alice123", 200) |
| 642 | + wg.Done() |
| 643 | + }() |
| 644 | + go func() { |
| 645 | + <-start |
| 646 | + login("alice", "alice1234", 401) |
| 647 | + wg.Done() |
| 648 | + }() |
| 649 | + } |
| 650 | + close(start) |
| 651 | + wg.Wait() |
| 652 | +} |
0 commit comments