diff --git a/core/codec/aesecb.go b/core/codec/aesecb.go index 1ec99a04c..0ee472ec4 100644 --- a/core/codec/aesecb.go +++ b/core/codec/aesecb.go @@ -6,8 +6,6 @@ import ( "crypto/cipher" "encoding/base64" "errors" - - "github.com/zeromicro/go-zero/core/logx" ) // ErrPaddingSize indicates bad padding size. @@ -27,7 +25,8 @@ func newECB(b cipher.Block) *ecb { type ecbEncrypter ecb -// NewECBEncrypter returns an ECB encrypter. +// Deprecated: NewECBEncrypter returns an ECB encrypter. +// ECB mode is insecure for multi-block data. Use AES-GCM instead. func NewECBEncrypter(b cipher.Block) cipher.BlockMode { return (*ecbEncrypter)(newECB(b)) } @@ -39,12 +38,10 @@ func (x *ecbEncrypter) BlockSize() int { return x.blockSize } // the block size. Dst and src must overlap entirely or not at all. func (x *ecbEncrypter) CryptBlocks(dst, src []byte) { if len(src)%x.blockSize != 0 { - logx.Error("crypto/cipher: input not full blocks") - return + panic("crypto/cipher: input not full blocks") } if len(dst) < len(src) { - logx.Error("crypto/cipher: output smaller than input") - return + panic("crypto/cipher: output smaller than input") } for len(src) > 0 { @@ -56,7 +53,8 @@ func (x *ecbEncrypter) CryptBlocks(dst, src []byte) { type ecbDecrypter ecb -// NewECBDecrypter returns an ECB decrypter. +// Deprecated: NewECBDecrypter returns an ECB decrypter. +// ECB mode is insecure for multi-block data. Use AES-GCM instead. func NewECBDecrypter(b cipher.Block) cipher.BlockMode { return (*ecbDecrypter)(newECB(b)) } @@ -70,12 +68,10 @@ func (x *ecbDecrypter) BlockSize() int { // the block size. Dst and src must overlap entirely or not at all. func (x *ecbDecrypter) CryptBlocks(dst, src []byte) { if len(src)%x.blockSize != 0 { - logx.Error("crypto/cipher: input not full blocks") - return + panic("crypto/cipher: input not full blocks") } if len(dst) < len(src) { - logx.Error("crypto/cipher: output smaller than input") - return + panic("crypto/cipher: output smaller than input") } for len(src) > 0 { @@ -85,14 +81,18 @@ func (x *ecbDecrypter) CryptBlocks(dst, src []byte) { } } -// EcbDecrypt decrypts src with the given key. +// Deprecated: EcbDecrypt decrypts src with the given key. +// ECB mode is insecure for multi-block data. Use AES-GCM instead. func EcbDecrypt(key, src []byte) ([]byte, error) { block, err := aes.NewCipher(key) if err != nil { - logx.Errorf("Decrypt key error: % x", key) return nil, err } + if len(src)%block.BlockSize() != 0 { + return nil, ErrPaddingSize + } + decrypter := NewECBDecrypter(block) decrypted := make([]byte, len(src)) decrypter.CryptBlocks(decrypted, src) @@ -100,8 +100,9 @@ func EcbDecrypt(key, src []byte) ([]byte, error) { return pkcs5Unpadding(decrypted, decrypter.BlockSize()) } -// EcbDecryptBase64 decrypts base64 encoded src with the given base64 encoded key. +// Deprecated: EcbDecryptBase64 decrypts base64 encoded src with the given base64 encoded key. // The returned string is also base64 encoded. +// ECB mode is insecure for multi-block data. Use AES-GCM instead. func EcbDecryptBase64(key, src string) (string, error) { keyBytes, err := getKeyBytes(key) if err != nil { @@ -121,11 +122,11 @@ func EcbDecryptBase64(key, src string) (string, error) { return base64.StdEncoding.EncodeToString(decryptedBytes), nil } -// EcbEncrypt encrypts src with the given key. +// Deprecated: EcbEncrypt encrypts src with the given key. +// ECB mode is insecure for multi-block data. Use AES-GCM instead. func EcbEncrypt(key, src []byte) ([]byte, error) { block, err := aes.NewCipher(key) if err != nil { - logx.Errorf("Encrypt key error: % x", key) return nil, err } @@ -137,8 +138,9 @@ func EcbEncrypt(key, src []byte) ([]byte, error) { return crypted, nil } -// EcbEncryptBase64 encrypts base64 encoded src with the given base64 encoded key. +// Deprecated: EcbEncryptBase64 encrypts base64 encoded src with the given base64 encoded key. // The returned string is also base64 encoded. +// ECB mode is insecure for multi-block data. Use AES-GCM instead. func EcbEncryptBase64(key, src string) (string, error) { keyBytes, err := getKeyBytes(key) if err != nil { @@ -179,10 +181,20 @@ func pkcs5Padding(ciphertext []byte, blockSize int) []byte { func pkcs5Unpadding(src []byte, blockSize int) ([]byte, error) { length := len(src) - unpadding := int(src[length-1]) - if unpadding >= length || unpadding > blockSize { + if length == 0 { return nil, ErrPaddingSize } + unpadding := int(src[length-1]) + if unpadding < 1 || unpadding > blockSize || unpadding > length { + return nil, ErrPaddingSize + } + + for _, b := range src[length-unpadding:] { + if int(b) != unpadding { + return nil, ErrPaddingSize + } + } + return src[:length-unpadding], nil } diff --git a/core/codec/aesecb_test.go b/core/codec/aesecb_test.go index a1117f3ab..39cd9abbb 100644 --- a/core/codec/aesecb_test.go +++ b/core/codec/aesecb_test.go @@ -28,8 +28,8 @@ func TestAesEcb(t *testing.T) { _, err = EcbDecrypt(badKey2, dst) assert.NotNil(t, err) _, err = EcbDecrypt(key, val) - // not enough block, just nil - assert.Nil(t, err) + // not a multiple of block size + assert.NotNil(t, err) src, err := EcbDecrypt(key, dst) assert.Nil(t, err) assert.Equal(t, val, src) @@ -41,33 +41,28 @@ func TestAesEcb(t *testing.T) { assert.Equal(t, 16, decrypter.BlockSize()) dst = make([]byte, 8) - encrypter.CryptBlocks(dst, val) - for _, b := range dst { - assert.Equal(t, byte(0), b) - } + assert.Panics(t, func() { + encrypter.CryptBlocks(dst, val) + }) dst = make([]byte, 8) - encrypter.CryptBlocks(dst, valLong) - for _, b := range dst { - assert.Equal(t, byte(0), b) - } + assert.Panics(t, func() { + encrypter.CryptBlocks(dst, valLong) + }) dst = make([]byte, 8) - decrypter.CryptBlocks(dst, val) - for _, b := range dst { - assert.Equal(t, byte(0), b) - } + assert.Panics(t, func() { + decrypter.CryptBlocks(dst, val) + }) dst = make([]byte, 8) - decrypter.CryptBlocks(dst, valLong) - for _, b := range dst { - assert.Equal(t, byte(0), b) - } + assert.Panics(t, func() { + decrypter.CryptBlocks(dst, valLong) + }) _, err = EcbEncryptBase64("cTR0N3dDKkYtSmFOZFJnVWpYbjJyNXU4eC9BP0QK", "aGVsbG93b3JsZGxvbmcuLgo=") assert.Error(t, err) } - func TestAesEcbBase64(t *testing.T) { const ( val = "hello" @@ -98,3 +93,44 @@ func TestAesEcbBase64(t *testing.T) { assert.Nil(t, err) assert.Equal(t, val, string(b)) } + +func TestPkcs5UnpaddingEmptyInput(t *testing.T) { + _, err := pkcs5Unpadding([]byte{}, 16) + assert.Equal(t, ErrPaddingSize, err) +} + +func TestPkcs5UnpaddingMalformedPadding(t *testing.T) { + // Valid PKCS5 padding of 3: last 3 bytes should all be 0x03 + // Here we corrupt one padding byte + malformed := []byte{0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x02, 0x03, 0x03} + _, err := pkcs5Unpadding(malformed, 16) + assert.Equal(t, ErrPaddingSize, err) + + // All padding bytes correct + valid := []byte{0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x03, 0x03, 0x03} + result, err := pkcs5Unpadding(valid, 16) + assert.NoError(t, err) + assert.Equal(t, valid[:13], result) +} + +func TestPkcs5UnpaddingInvalidPaddingValue(t *testing.T) { + // padding value = 0 (< 1) + _, err := pkcs5Unpadding([]byte{0x41, 0x00}, 16) + assert.Equal(t, ErrPaddingSize, err) + + // padding value > blockSize + _, err = pkcs5Unpadding([]byte{0x41, 0x41, 0x41, 0x41, 17}, 4) + assert.Equal(t, ErrPaddingSize, err) + + // padding value > length + _, err = pkcs5Unpadding([]byte{0x41, 0x03}, 16) + assert.Equal(t, ErrPaddingSize, err) +} + +func TestEcbDecryptEmptyInput(t *testing.T) { + key := []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D") + _, err := EcbDecrypt(key, []byte{}) + assert.Equal(t, ErrPaddingSize, err) +} diff --git a/core/codec/dh.go b/core/codec/dh.go index 63dd2d9c9..4f7f8c3fe 100644 --- a/core/codec/dh.go +++ b/core/codec/dh.go @@ -35,7 +35,7 @@ func ComputeKey(pubKey, priKey *big.Int) (*big.Int, error) { return nil, ErrInvalidPubKey } - if pubKey.Sign() <= 0 && p.Cmp(pubKey) <= 0 { + if pubKey.Sign() <= 0 || p.Cmp(pubKey) <= 0 { return nil, ErrPubKeyOutOfBound } diff --git a/core/codec/dh_test.go b/core/codec/dh_test.go index 9f788b445..48f1e4160 100644 --- a/core/codec/dh_test.go +++ b/core/codec/dh_test.go @@ -94,3 +94,32 @@ func TestDHOnErrors(t *testing.T) { assert.NotNil(t, NewPublicKey([]byte(""))) } + +func TestDHPubKeyBoundary(t *testing.T) { + key, err := GenerateKey() + assert.Nil(t, err) + + // pubKey = 0 should be rejected + _, err = ComputeKey(big.NewInt(0), key.PriKey) + assert.ErrorIs(t, err, ErrPubKeyOutOfBound) + + // pubKey = -1 should be rejected + _, err = ComputeKey(big.NewInt(-1), key.PriKey) + assert.ErrorIs(t, err, ErrPubKeyOutOfBound) + + // pubKey = p should be rejected + _, err = ComputeKey(new(big.Int).Set(p), key.PriKey) + assert.ErrorIs(t, err, ErrPubKeyOutOfBound) + + // pubKey = p+1 should be rejected + _, err = ComputeKey(new(big.Int).Add(p, big.NewInt(1)), key.PriKey) + assert.ErrorIs(t, err, ErrPubKeyOutOfBound) + + // pubKey = 1 should be accepted + _, err = ComputeKey(big.NewInt(1), key.PriKey) + assert.NoError(t, err) + + // pubKey = p-1 should be accepted + _, err = ComputeKey(new(big.Int).Sub(p, big.NewInt(1)), key.PriKey) + assert.NoError(t, err) +} diff --git a/core/codec/rsa.go b/core/codec/rsa.go index 5a54b8383..830f18466 100644 --- a/core/codec/rsa.go +++ b/core/codec/rsa.go @@ -3,6 +3,7 @@ package codec import ( "crypto/rand" "crypto/rsa" + "crypto/sha256" "crypto/x509" "encoding/base64" "encoding/pem" @@ -46,7 +47,9 @@ type ( } ) -// NewRsaDecrypter returns a RsaDecrypter with the given file. +// Deprecated: NewRsaDecrypter returns a RsaDecrypter with the given file. +// PKCS#1 v1.5 padding is vulnerable to padding oracle attacks. +// Use NewRsaOAEPDecrypter instead. func NewRsaDecrypter(file string) (RsaDecrypter, error) { content, err := os.ReadFile(file) if err != nil { @@ -90,7 +93,9 @@ func (r *rsaDecrypter) DecryptBase64(input string) ([]byte, error) { return r.Decrypt(base64Decoded) } -// NewRsaEncrypter returns a RsaEncrypter with the given key. +// Deprecated: NewRsaEncrypter returns a RsaEncrypter with the given key. +// PKCS#1 v1.5 padding is vulnerable to padding oracle attacks. +// Use NewRsaOAEPEncrypter instead. func NewRsaEncrypter(key []byte) (RsaEncrypter, error) { block, _ := pem.Decode(key) if block == nil { @@ -154,3 +159,90 @@ func rsaDecryptBlock(privateKey *rsa.PrivateKey, block []byte) ([]byte, error) { func rsaEncryptBlock(publicKey *rsa.PublicKey, msg []byte) ([]byte, error) { return rsa.EncryptPKCS1v15(rand.Reader, publicKey, msg) } + +// NewRsaOAEPDecrypter returns a RsaDecrypter using OAEP with SHA-256. +func NewRsaOAEPDecrypter(file string) (RsaDecrypter, error) { + content, err := os.ReadFile(file) + if err != nil { + return nil, err + } + + block, _ := pem.Decode(content) + if block == nil { + return nil, ErrPrivateKey + } + + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + + return &rsaOAEPDecrypter{ + rsaBase: rsaBase{ + bytesLimit: privateKey.N.BitLen() >> 3, + }, + privateKey: privateKey, + }, nil +} + +// NewRsaOAEPEncrypter returns a RsaEncrypter using OAEP with SHA-256. +func NewRsaOAEPEncrypter(key []byte) (RsaEncrypter, error) { + block, _ := pem.Decode(key) + if block == nil { + return nil, ErrPublicKey + } + + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, err + } + + switch pubKey := pub.(type) { + case *rsa.PublicKey: + // OAEP overhead: 2*hash_size + 2 + hashSize := sha256.New().Size() + return &rsaOAEPEncrypter{ + rsaBase: rsaBase{ + bytesLimit: (pubKey.N.BitLen() >> 3) - 2*hashSize - 2, + }, + publicKey: pubKey, + }, nil + default: + return nil, ErrNotRsaKey + } +} + +type rsaOAEPDecrypter struct { + rsaBase + privateKey *rsa.PrivateKey +} + +func (r *rsaOAEPDecrypter) Decrypt(input []byte) ([]byte, error) { + return r.crypt(input, func(block []byte) ([]byte, error) { + return rsa.DecryptOAEP(sha256.New(), rand.Reader, r.privateKey, block, nil) + }) +} + +func (r *rsaOAEPDecrypter) DecryptBase64(input string) ([]byte, error) { + if len(input) == 0 { + return nil, nil + } + + base64Decoded, err := base64.StdEncoding.DecodeString(input) + if err != nil { + return nil, err + } + + return r.Decrypt(base64Decoded) +} + +type rsaOAEPEncrypter struct { + rsaBase + publicKey *rsa.PublicKey +} + +func (r *rsaOAEPEncrypter) Encrypt(input []byte) ([]byte, error) { + return r.crypt(input, func(block []byte) ([]byte, error) { + return rsa.EncryptOAEP(sha256.New(), rand.Reader, r.publicKey, block, nil) + }) +} diff --git a/core/codec/rsa_test.go b/core/codec/rsa_test.go index 68ce64353..16b49446a 100644 --- a/core/codec/rsa_test.go +++ b/core/codec/rsa_test.go @@ -1,7 +1,12 @@ package codec import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" "encoding/base64" + "encoding/pem" "os" "testing" @@ -58,3 +63,78 @@ func TestBadPubKey(t *testing.T) { _, err := NewRsaEncrypter([]byte("foo")) assert.Equal(t, ErrPublicKey, err) } + +func TestOAEPCryption(t *testing.T) { + enc, err := NewRsaOAEPEncrypter([]byte(pubKey)) + assert.Nil(t, err) + ret, err := enc.Encrypt([]byte(testBody)) + assert.Nil(t, err) + + file, err := fs.TempFilenameWithText(priKey) + assert.Nil(t, err) + defer os.Remove(file) + dec, err := NewRsaOAEPDecrypter(file) + assert.Nil(t, err) + actual, err := dec.Decrypt(ret) + assert.Nil(t, err) + assert.Equal(t, testBody, string(actual)) + + actual, err = dec.DecryptBase64(base64.StdEncoding.EncodeToString(ret)) + assert.Nil(t, err) + assert.Equal(t, testBody, string(actual)) + + // empty input + actual, err = dec.DecryptBase64("") + assert.Nil(t, err) + assert.Nil(t, actual) +} + +func TestOAEPBadKeys(t *testing.T) { + _, err := NewRsaOAEPEncrypter([]byte("bad")) + assert.Equal(t, ErrPublicKey, err) + + _, err = NewRsaOAEPDecrypter("nonexistent") + assert.Error(t, err) + + // valid PEM but invalid private key content + badPem, err := fs.TempFilenameWithText("-----BEGIN RSA PRIVATE KEY-----\nYmFk\n-----END RSA PRIVATE KEY-----") + assert.Nil(t, err) + defer os.Remove(badPem) + _, err = NewRsaOAEPDecrypter(badPem) + assert.Error(t, err) + + // not PEM content at all + notPem, err := fs.TempFilenameWithText("not a pem file") + assert.Nil(t, err) + defer os.Remove(notPem) + _, err = NewRsaOAEPDecrypter(notPem) + assert.Equal(t, ErrPrivateKey, err) +} + +func TestOAEPEncrypterParseError(t *testing.T) { + // valid PEM block but invalid public key content + badPub := []byte("-----BEGIN PUBLIC KEY-----\nYmFk\n-----END PUBLIC KEY-----") + _, err := NewRsaOAEPEncrypter(badPub) + assert.Error(t, err) +} + +func TestOAEPEncrypterNonRsaKey(t *testing.T) { + ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.Nil(t, err) + derBytes, err := x509.MarshalPKIXPublicKey(&ecKey.PublicKey) + assert.Nil(t, err) + ecPem := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: derBytes}) + _, err = NewRsaOAEPEncrypter(ecPem) + assert.Equal(t, ErrNotRsaKey, err) +} + +func TestOAEPDecryptBase64Error(t *testing.T) { + file, err := fs.TempFilenameWithText(priKey) + assert.Nil(t, err) + defer os.Remove(file) + dec, err := NewRsaOAEPDecrypter(file) + assert.Nil(t, err) + + _, err = dec.DecryptBase64("not-valid-base64!!!") + assert.Error(t, err) +}