diff --git a/encryption.go b/encryption.go index b168cb4..da7484a 100644 --- a/encryption.go +++ b/encryption.go @@ -10,6 +10,7 @@ import ( "encoding/json" "errors" "fmt" + "sync/atomic" "github.com/golang-jwt/jwt" "gopkg.in/square/go-jose.v2" @@ -48,7 +49,7 @@ const DefaultKey = `MIIEowIBAAKCAQEAtI1Jf2zmfwLzpAjVarORtjKtmCHQtgNxqWDdVNVa` + type Keypair struct { PrivateKey *rsa.PrivateKey PublicKey *rsa.PublicKey - Kid string + Kid atomic.Value } // NewKeypair makes a Keypair off the provided rsa.PrivateKey or returns @@ -98,8 +99,14 @@ func DefaultKeypair() (*Keypair, error) { // If not manually set, computes the JWT headers' `kid` func (k *Keypair) KeyID() (string, error) { - if k.Kid != "" { - return k.Kid, nil + var kid string + existingKid := k.Kid.Load() + if existingKid != nil { + kid = existingKid.(string) + } + + if kid != "" { + return kid, nil } publicKeyDERBytes, err := x509.MarshalPKIXPublicKey(k.PublicKey) @@ -113,9 +120,10 @@ func (k *Keypair) KeyID() (string, error) { } publicKeyDERHash := hasher.Sum(nil) - k.Kid = base64.RawURLEncoding.EncodeToString(publicKeyDERHash) + newKeyID := base64.RawURLEncoding.EncodeToString(publicKeyDERHash) + k.Kid.Store(newKeyID) - return k.Kid, nil + return newKeyID, nil } // JWKS is the JSON JWKS representation of the rsa.PublicKey diff --git a/encryption_test.go b/encryption_test.go index 5dce607..5c14327 100644 --- a/encryption_test.go +++ b/encryption_test.go @@ -75,12 +75,12 @@ func TestKeypair_SignJWTVerifyJWT(t *testing.T) { assert.Equal(t, audience, claims["aud"]) assert.Equal(t, issuer, claims["iss"]) - alice.Kid = "WRONG" + alice.Kid.Store("WRONG") _, err = alice.VerifyJWT(tokenStr) assert.Error(t, err) const customKid = "USER_DEFINED" - bob.Kid = customKid + bob.Kid.Store(customKid) kidTokenStr, err := bob.SignJWT(standardClaims) assert.NoError(t, err) diff --git a/session.go b/session.go index a245caa..e87f1f0 100644 --- a/session.go +++ b/session.go @@ -3,6 +3,7 @@ package mockoidc import ( "errors" "strings" + "sync" "time" "github.com/golang-jwt/jwt" @@ -21,6 +22,7 @@ type Session struct { // SessionStore manages our Session objects type SessionStore struct { + sync.RWMutex Store map[string]*Session CodeQueue *CodeQueue } @@ -55,14 +57,18 @@ func (ss *SessionStore) NewSession(scope string, nonce string, user User, codeCh CodeChallenge: codeChallenge, CodeChallengeMethod: codeChallengeMethod, } + ss.Lock() ss.Store[sessionID] = session + ss.Unlock() return session, nil } // GetSessionByID looks up the Session func (ss *SessionStore) GetSessionByID(id string) (*Session, error) { + ss.RLock() session, ok := ss.Store[id] + ss.RUnlock() if !ok { return nil, errors.New("session not found") } diff --git a/session_test.go b/session_test.go index 54cd10b..ed44334 100644 --- a/session_test.go +++ b/session_test.go @@ -43,7 +43,9 @@ func TestSessionStore_NewSession(t *testing.T) { assert.NoError(t, err) assert.Equal(t, session.Scopes, []string{"openid", "email", "profile"}) assert.Equal(t, len(ss.Store), 1) + ss.RLock() assert.Equal(t, ss.Store[session.SessionID], session) + ss.RUnlock() assert.Equal(t, session.CodeChallenge, "sum") assert.Equal(t, session.CodeChallengeMethod, "S256") }