diff --git a/evp.go b/evp.go index da71d9b3..81595bae 100644 --- a/evp.go +++ b/evp.go @@ -60,23 +60,21 @@ func cryptoHashToMD(ch crypto.Hash) (md C.GO_EVP_MD_PTR) { } cacheMD.Store(ch, md) }() - // SupportsHash returns false for MD5 and MD5SHA1 because we don't - // provide a hash.Hash implementation for them. Yet, they can + // SupportsHash returns false for MD5SHA1 because we don't + // provide a hash.Hash implementation for it. Yet, it can // still be used when signing/verifying with an RSA key. - switch ch { - case crypto.MD5: - return C.go_openssl_EVP_md5() - case crypto.MD5SHA1: + if ch == crypto.MD5SHA1 { if vMajor == 1 && vMinor == 0 { return C.go_openssl_EVP_md5_sha1_backport() } else { return C.go_openssl_EVP_md5_sha1() } } - if !SupportsHash(ch) { - return nil - } switch ch { + case crypto.MD4: + return C.go_openssl_EVP_md4() + case crypto.MD5: + return C.go_openssl_EVP_md5() case crypto.SHA1: return C.go_openssl_EVP_sha1() case crypto.SHA224: @@ -88,13 +86,21 @@ func cryptoHashToMD(ch crypto.Hash) (md C.GO_EVP_MD_PTR) { case crypto.SHA512: return C.go_openssl_EVP_sha512() case crypto.SHA3_224: - return C.go_openssl_EVP_sha3_224() + if version1_1_1_or_above() { + return C.go_openssl_EVP_sha3_224() + } case crypto.SHA3_256: - return C.go_openssl_EVP_sha3_256() + if version1_1_1_or_above() { + return C.go_openssl_EVP_sha3_256() + } case crypto.SHA3_384: - return C.go_openssl_EVP_sha3_384() + if version1_1_1_or_above() { + return C.go_openssl_EVP_sha3_384() + } case crypto.SHA3_512: - return C.go_openssl_EVP_sha3_512() + if version1_1_1_or_above() { + return C.go_openssl_EVP_sha3_512() + } } return nil } diff --git a/goopenssl.h b/goopenssl.h index dc2ce35c..d9efb1ba 100644 --- a/goopenssl.h +++ b/goopenssl.h @@ -66,14 +66,14 @@ FOR_ALL_OPENSSL_FUNCTIONS #undef DEFINEFUNC_RENAMED_1_1 #undef DEFINEFUNC_RENAMED_3_0 -// go_sha_sum copies ctx into ctx2 and calls EVP_DigestFinal using ctx2. +// go_hash_sum copies ctx into ctx2 and calls EVP_DigestFinal using ctx2. // This is necessary because Go hash.Hash mandates that Sum has no effect // on the underlying stream. In particular it is OK to Sum, then Write more, // then Sum again, and the second Sum acts as if the first didn't happen. // It is written in C because Sum() tend to be in the hot path, // and doing one cgo call instead of two is a significant performance win. static inline int -go_sha_sum(GO_EVP_MD_CTX_PTR ctx, GO_EVP_MD_CTX_PTR ctx2, unsigned char *out) +go_hash_sum(GO_EVP_MD_CTX_PTR ctx, GO_EVP_MD_CTX_PTR ctx2, unsigned char *out) { if (go_openssl_EVP_MD_CTX_copy(ctx2, ctx) != 1) return 0; diff --git a/sha.go b/hash.go similarity index 82% rename from sha.go rename to hash.go index c15e1d77..646b4ce2 100644 --- a/sha.go +++ b/hash.go @@ -24,40 +24,54 @@ import ( // and applying a noescape along the way. // This is all to preserve compatibility with the allocation behavior of the non-openssl implementations. -func shaX(ch crypto.Hash, p []byte, sum []byte) bool { +func hashOneShot(ch crypto.Hash, p []byte, sum []byte) bool { return C.go_openssl_EVP_Digest(unsafe.Pointer(&*addr(p)), C.size_t(len(p)), (*C.uchar)(unsafe.Pointer(&*addr(sum))), nil, cryptoHashToMD(ch), nil) != 0 } +func MD4(p []byte) (sum [16]byte) { + if !hashOneShot(crypto.MD4, p, sum[:]) { + panic("openssl: MD4 failed") + } + return +} + +func MD5(p []byte) (sum [16]byte) { + if !hashOneShot(crypto.MD5, p, sum[:]) { + panic("openssl: MD5 failed") + } + return +} + func SHA1(p []byte) (sum [20]byte) { - if !shaX(crypto.SHA1, p, sum[:]) { + if !hashOneShot(crypto.SHA1, p, sum[:]) { panic("openssl: SHA1 failed") } return } func SHA224(p []byte) (sum [28]byte) { - if !shaX(crypto.SHA224, p, sum[:]) { + if !hashOneShot(crypto.SHA224, p, sum[:]) { panic("openssl: SHA224 failed") } return } func SHA256(p []byte) (sum [32]byte) { - if !shaX(crypto.SHA256, p, sum[:]) { + if !hashOneShot(crypto.SHA256, p, sum[:]) { panic("openssl: SHA256 failed") } return } func SHA384(p []byte) (sum [48]byte) { - if !shaX(crypto.SHA384, p, sum[:]) { + if !hashOneShot(crypto.SHA384, p, sum[:]) { panic("openssl: SHA384 failed") } return } func SHA512(p []byte) (sum [64]byte) { - if !shaX(crypto.SHA512, p, sum[:]) { + if !hashOneShot(crypto.SHA512, p, sum[:]) { panic("openssl: SHA512 failed") } return @@ -65,40 +79,32 @@ func SHA512(p []byte) (sum [64]byte) { // SupportsHash returns true if a hash.Hash implementation is supported for h. func SupportsHash(h crypto.Hash) bool { - switch h { - case crypto.SHA1, crypto.SHA224, crypto.SHA256, crypto.SHA384, crypto.SHA512: - return true - case crypto.SHA3_224, crypto.SHA3_256, crypto.SHA3_384, crypto.SHA3_512: - return vMajor > 1 || - (vMajor >= 1 && vMinor > 1) || - (vMajor >= 1 && vMinor >= 1 && vPatch >= 1) - } - return false + return cryptoHashToMD(h) != nil } func SHA3_224(p []byte) (sum [28]byte) { - if !shaX(crypto.SHA3_224, p, sum[:]) { + if !hashOneShot(crypto.SHA3_224, p, sum[:]) { panic("openssl: SHA3_224 failed") } return } func SHA3_256(p []byte) (sum [32]byte) { - if !shaX(crypto.SHA3_256, p, sum[:]) { + if !hashOneShot(crypto.SHA3_256, p, sum[:]) { panic("openssl: SHA3_256 failed") } return } func SHA3_384(p []byte) (sum [48]byte) { - if !shaX(crypto.SHA3_384, p, sum[:]) { + if !hashOneShot(crypto.SHA3_384, p, sum[:]) { panic("openssl: SHA3_384 failed") } return } func SHA3_512(p []byte) (sum [64]byte) { - if !shaX(crypto.SHA3_512, p, sum[:]) { + if !hashOneShot(crypto.SHA3_512, p, sum[:]) { panic("openssl: SHA3_512 failed") } return @@ -183,17 +189,17 @@ func (h *evpHash) BlockSize() int { } func (h *evpHash) sum(out []byte) { - if C.go_sha_sum(h.ctx, h.ctx2, base(out)) != 1 { - panic(newOpenSSLError("go_sha_sum")) + if C.go_hash_sum(h.ctx, h.ctx2, base(out)) != 1 { + panic(newOpenSSLError("go_hash_sum")) } runtime.KeepAlive(h) } -// shaState returns a pointer to the internal sha structure. +// hashState returns a pointer to the internal hash structure. // // The EVP_MD_CTX memory layout has changed in OpenSSL 3 // and the property holding the internal structure is no longer md_data but algctx. -func (h *evpHash) shaState() unsafe.Pointer { +func (h *evpHash) hashState() unsafe.Pointer { switch vMajor { case 1: // https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/crypto/evp/evp_local.h#L12. @@ -217,6 +223,97 @@ func (h *evpHash) shaState() unsafe.Pointer { } } +// NewMD4 returns a new MD4 hash. +// The returned hash doesn't implement encoding.BinaryMarshaler and +// encoding.BinaryUnmarshaler. +func NewMD4() hash.Hash { + return &md4Hash{ + evpHash: newEvpHash(crypto.MD4, 16, 64), + } +} + +type md4Hash struct { + *evpHash + out [16]byte +} + +func (h *md4Hash) Sum(in []byte) []byte { + h.sum(h.out[:]) + return append(in, h.out[:]...) +} + +// NewMD5 returns a new MD5 hash. +func NewMD5() hash.Hash { + return &md5Hash{ + evpHash: newEvpHash(crypto.MD5, 16, 64), + } +} + +// md5State layout is taken from +// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/md5.h#L33. +type md5State struct { + h [4]uint32 + nl, nh uint32 + x [64]byte + nx uint32 +} + +type md5Hash struct { + *evpHash + out [16]byte +} + +func (h *md5Hash) Sum(in []byte) []byte { + h.sum(h.out[:]) + return append(in, h.out[:]...) +} + +const ( + md5Magic = "md5\x01" + md5MarshaledSize = len(md5Magic) + 4*4 + 64 + 8 +) + +func (h *md5Hash) MarshalBinary() ([]byte, error) { + d := (*md5State)(h.hashState()) + if d == nil { + return nil, errors.New("crypto/md5: can't retrieve hash state") + } + b := make([]byte, 0, md5MarshaledSize) + b = append(b, md5Magic...) + b = appendUint32(b, d.h[0]) + b = appendUint32(b, d.h[1]) + b = appendUint32(b, d.h[2]) + b = appendUint32(b, d.h[3]) + b = append(b, d.x[:d.nx]...) + b = b[:len(b)+len(d.x)-int(d.nx)] // already zero + b = appendUint64(b, uint64(d.nl)>>3|uint64(d.nh)<<29) + return b, nil +} + +func (h *md5Hash) UnmarshalBinary(b []byte) error { + if len(b) < len(md5Magic) || string(b[:len(md5Magic)]) != md5Magic { + return errors.New("crypto/md5: invalid hash state identifier") + } + if len(b) != md5MarshaledSize { + return errors.New("crypto/md5: invalid hash state size") + } + d := (*md5State)(h.hashState()) + if d == nil { + return errors.New("crypto/md5: can't retrieve hash state") + } + b = b[len(md5Magic):] + b, d.h[0] = consumeUint32(b) + b, d.h[1] = consumeUint32(b) + b, d.h[2] = consumeUint32(b) + b, d.h[3] = consumeUint32(b) + b = b[copy(d.x[:], b):] + _, n := consumeUint64(b) + d.nl = uint32(n << 3) + d.nh = uint32(n >> 29) + d.nx = uint32(n) % 64 + return nil +} + // NewSHA1 returns a new SHA1 hash. func NewSHA1() hash.Hash { return &sha1Hash{ @@ -249,7 +346,7 @@ const ( ) func (h *sha1Hash) MarshalBinary() ([]byte, error) { - d := (*sha1State)(h.shaState()) + d := (*sha1State)(h.hashState()) if d == nil { return nil, errors.New("crypto/sha1: can't retrieve hash state") } @@ -273,7 +370,7 @@ func (h *sha1Hash) UnmarshalBinary(b []byte) error { if len(b) != sha1MarshaledSize { return errors.New("crypto/sha1: invalid hash state size") } - d := (*sha1State)(h.shaState()) + d := (*sha1State)(h.hashState()) if d == nil { return errors.New("crypto/sha1: can't retrieve hash state") } @@ -341,7 +438,7 @@ type sha256State struct { } func (h *sha224Hash) MarshalBinary() ([]byte, error) { - d := (*sha256State)(h.shaState()) + d := (*sha256State)(h.hashState()) if d == nil { return nil, errors.New("crypto/sha256: can't retrieve hash state") } @@ -362,7 +459,7 @@ func (h *sha224Hash) MarshalBinary() ([]byte, error) { } func (h *sha256Hash) MarshalBinary() ([]byte, error) { - d := (*sha256State)(h.shaState()) + d := (*sha256State)(h.hashState()) if d == nil { return nil, errors.New("crypto/sha256: can't retrieve hash state") } @@ -389,7 +486,7 @@ func (h *sha224Hash) UnmarshalBinary(b []byte) error { if len(b) != marshaledSize256 { return errors.New("crypto/sha256: invalid hash state size") } - d := (*sha256State)(h.shaState()) + d := (*sha256State)(h.hashState()) if d == nil { return errors.New("crypto/sha256: can't retrieve hash state") } @@ -417,7 +514,7 @@ func (h *sha256Hash) UnmarshalBinary(b []byte) error { if len(b) != marshaledSize256 { return errors.New("crypto/sha256: invalid hash state size") } - d := (*sha256State)(h.shaState()) + d := (*sha256State)(h.hashState()) if d == nil { return errors.New("crypto/sha256: can't retrieve hash state") } @@ -490,7 +587,7 @@ const ( ) func (h *sha384Hash) MarshalBinary() ([]byte, error) { - d := (*sha512State)(h.shaState()) + d := (*sha512State)(h.hashState()) if d == nil { return nil, errors.New("crypto/sha512: can't retrieve hash state") } @@ -511,7 +608,7 @@ func (h *sha384Hash) MarshalBinary() ([]byte, error) { } func (h *sha512Hash) MarshalBinary() ([]byte, error) { - d := (*sha512State)(h.shaState()) + d := (*sha512State)(h.hashState()) if d == nil { return nil, errors.New("crypto/sha512: can't retrieve hash state") } @@ -541,7 +638,7 @@ func (h *sha384Hash) UnmarshalBinary(b []byte) error { if len(b) != marshaledSize512 { return errors.New("crypto/sha512: invalid hash state size") } - d := (*sha512State)(h.shaState()) + d := (*sha512State)(h.hashState()) if d == nil { return errors.New("crypto/sha512: can't retrieve hash state") } @@ -572,7 +669,7 @@ func (h *sha512Hash) UnmarshalBinary(b []byte) error { if len(b) != marshaledSize512 { return errors.New("crypto/sha512: invalid hash state size") } - d := (*sha512State)(h.shaState()) + d := (*sha512State)(h.hashState()) if d == nil { return errors.New("crypto/sha512: can't retrieve hash state") } diff --git a/sha_test.go b/hash_test.go similarity index 86% rename from sha_test.go rename to hash_test.go index f76d3a36..7244038a 100644 --- a/sha_test.go +++ b/hash_test.go @@ -6,7 +6,6 @@ import ( "encoding" "hash" "io" - "strings" "testing" "github.com/golang-fips/openssl/v2" @@ -14,6 +13,10 @@ import ( func cryptoToHash(h crypto.Hash) func() hash.Hash { switch h { + case crypto.MD4: + return openssl.NewMD4 + case crypto.MD5: + return openssl.NewMD5 case crypto.SHA1: return openssl.NewSHA1 case crypto.SHA224: @@ -36,27 +39,32 @@ func cryptoToHash(h crypto.Hash) func() hash.Hash { return nil } -func TestSha(t *testing.T) { +func TestHash(t *testing.T) { msg := []byte("testing") - var tests = []crypto.Hash{ - crypto.SHA1, - crypto.SHA224, - crypto.SHA256, - crypto.SHA384, - crypto.SHA512, - crypto.SHA3_224, - crypto.SHA3_256, - crypto.SHA3_384, - crypto.SHA3_512, + var tests = []struct { + h crypto.Hash + hasMarshaler bool + }{ + {crypto.MD4, false}, + {crypto.MD5, true}, + {crypto.SHA1, true}, + {crypto.SHA224, true}, + {crypto.SHA256, true}, + {crypto.SHA384, true}, + {crypto.SHA512, true}, + {crypto.SHA3_224, false}, + {crypto.SHA3_256, false}, + {crypto.SHA3_384, false}, + {crypto.SHA3_512, false}, } for _, tt := range tests { tt := tt - t.Run(tt.String(), func(t *testing.T) { + t.Run(tt.h.String(), func(t *testing.T) { t.Parallel() - if !openssl.SupportsHash(tt) { + if !openssl.SupportsHash(tt.h) { t.Skip("skipping: not supported") } - h := cryptoToHash(tt)() + h := cryptoToHash(tt.h)() initSum := h.Sum(nil) n, err := h.Write(msg) if err != nil { @@ -72,12 +80,12 @@ func TestSha(t *testing.T) { if bytes.Equal(sum, initSum) { t.Error("Write didn't change internal hash state") } - if !strings.HasPrefix(tt.String(), "SHA3-") { + if tt.hasMarshaler { state, err := h.(encoding.BinaryMarshaler).MarshalBinary() if err != nil { t.Errorf("could not marshal: %v", err) } - h2 := cryptoToHash(tt)() + h2 := cryptoToHash(tt.h)() if err := h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil { t.Errorf("could not unmarshal: %v", err) } @@ -111,7 +119,7 @@ func TestSha(t *testing.T) { } } -func TestSHA_OneShot(t *testing.T) { +func TestHash_OneShot(t *testing.T) { msg := []byte("testing") var tests = []struct { h crypto.Hash diff --git a/hkdf.go b/hkdf.go index 18b92059..0d33e34a 100644 --- a/hkdf.go +++ b/hkdf.go @@ -13,9 +13,7 @@ import ( ) func SupportsHKDF() bool { - return vMajor > 1 || - (vMajor >= 1 && vMinor > 1) || - (vMajor >= 1 && vMinor >= 1 && vPatch >= 1) + return version1_1_1_or_above() } func newHKDF(h func() hash.Hash, mode C.int) (*hkdf, error) { diff --git a/openssl.go b/openssl.go index 5c9324d8..14b1a81e 100644 --- a/openssl.go +++ b/openssl.go @@ -406,3 +406,7 @@ func bnToBinPad(bn C.GO_BIGNUM_PTR, to []byte) error { func CheckLeaks() { C.go_openssl_do_leak_check() } + +func version1_1_1_or_above() bool { + return vMajor > 1 || (vMajor >= 1 && vMinor > 1) || (vMajor >= 1 && vMinor >= 1 && vPatch >= 1) +} diff --git a/shims.h b/shims.h index 47d8724e..d2ab2f52 100644 --- a/shims.h +++ b/shims.h @@ -216,6 +216,7 @@ DEFINEFUNC_LEGACY_1_0(int, SHA1_Init, (GO_SHA_CTX_PTR c), (c)) \ DEFINEFUNC_LEGACY_1_0(int, SHA1_Update, (GO_SHA_CTX_PTR c, const void *data, size_t len), (c, data, len)) \ DEFINEFUNC_LEGACY_1_0(int, SHA1_Final, (unsigned char *md, GO_SHA_CTX_PTR c), (md, c)) \ DEFINEFUNC_1_1(const GO_EVP_MD_PTR, EVP_md5_sha1, (void), ()) \ +DEFINEFUNC(const GO_EVP_MD_PTR, EVP_md4, (void), ()) \ DEFINEFUNC(const GO_EVP_MD_PTR, EVP_md5, (void), ()) \ DEFINEFUNC(const GO_EVP_MD_PTR, EVP_sha1, (void), ()) \ DEFINEFUNC(const GO_EVP_MD_PTR, EVP_sha224, (void), ()) \