From ee3069abfc98d3ba0f99e92f1145195cd879f2e7 Mon Sep 17 00:00:00 2001 From: Dave Henderson Date: Sun, 14 Jun 2020 15:02:00 -0400 Subject: New RSA encrypt/decrypt functions, and new base64.DecodeBytes function Signed-off-by: Dave Henderson --- funcs/base64.go | 6 ++++++ funcs/base64_test.go | 10 +++++++--- funcs/crypto.go | 36 ++++++++++++++++++++++++++++++++++++ funcs/crypto_test.go | 36 ++++++++++++++++++++++++++++++++++++ 4 files changed, 85 insertions(+), 3 deletions(-) (limited to 'funcs') diff --git a/funcs/base64.go b/funcs/base64.go index ac7cf3ee..da9b315e 100644 --- a/funcs/base64.go +++ b/funcs/base64.go @@ -38,6 +38,12 @@ func (f *Base64Funcs) Decode(in interface{}) (string, error) { return string(out), err } +// DecodeBytes - +func (f *Base64Funcs) DecodeBytes(in interface{}) ([]byte, error) { + out, err := base64.Decode(conv.ToString(in)) + return out, err +} + type byter interface { Bytes() []byte } diff --git a/funcs/base64_test.go b/funcs/base64_test.go index a559e17c..6cfe2b9e 100644 --- a/funcs/base64_test.go +++ b/funcs/base64_test.go @@ -15,7 +15,13 @@ func TestBase64Encode(t *testing.T) { func TestBase64Decode(t *testing.T) { bf := &Base64Funcs{} assert.Equal(t, "foobar", must(bf.Decode("Zm9vYmFy"))) - // assert.Equal(t, "", bf.Decode(nil)) +} + +func TestBase64DecodeBytes(t *testing.T) { + bf := &Base64Funcs{} + out, err := bf.DecodeBytes("Zm9vYmFy") + assert.NoError(t, err) + assert.Equal(t, "foobar", string(out)) } func TestToBytes(t *testing.T) { @@ -24,8 +30,6 @@ func TestToBytes(t *testing.T) { buf := &bytes.Buffer{} buf.WriteString("hi") assert.Equal(t, []byte("hi"), toBytes(buf)) - assert.Equal(t, []byte{}, toBytes(nil)) - assert.Equal(t, []byte("42"), toBytes(42)) } diff --git a/funcs/crypto.go b/funcs/crypto.go index d1bd7c9d..59516ded 100644 --- a/funcs/crypto.go +++ b/funcs/crypto.go @@ -131,3 +131,39 @@ func (f *CryptoFuncs) Bcrypt(args ...interface{}) (string, error) { hash, err := bcrypt.GenerateFromPassword([]byte(input), cost) return string(hash), err } + +// RSAEncrypt - +func (f *CryptoFuncs) RSAEncrypt(key string, in interface{}) ([]byte, error) { + msg := toBytes(in) + return crypto.RSAEncrypt(key, msg) +} + +// RSADecrypt - +func (f *CryptoFuncs) RSADecrypt(key string, in []byte) (string, error) { + out, err := crypto.RSADecrypt(key, in) + return string(out), err +} + +// RSADecryptBytes - +func (f *CryptoFuncs) RSADecryptBytes(key string, in []byte) ([]byte, error) { + out, err := crypto.RSADecrypt(key, in) + return out, err +} + +// RSAGenerateKey - +func (f *CryptoFuncs) RSAGenerateKey(args ...interface{}) (string, error) { + bits := 4096 + if len(args) == 1 { + bits = conv.ToInt(args[0]) + } else if len(args) > 1 { + return "", fmt.Errorf("wrong number of args: want 0 or 1, got %d", len(args)) + } + out, err := crypto.RSAGenerateKey(bits) + return string(out), err +} + +// RSADerivePublicKey - +func (f *CryptoFuncs) RSADerivePublicKey(privateKey string) (string, error) { + out, err := crypto.RSADerivePublicKey([]byte(privateKey)) + return string(out), err +} diff --git a/funcs/crypto_test.go b/funcs/crypto_test.go index c2e14d19..4d2b483f 100644 --- a/funcs/crypto_test.go +++ b/funcs/crypto_test.go @@ -64,3 +64,39 @@ func TestBcrypt(t *testing.T) { _, err = c.Bcrypt() assert.Error(t, err) } + +func TestRSAGenerateKey(t *testing.T) { + c := CryptoNS() + _, err := c.RSAGenerateKey(0) + assert.Error(t, err) + + _, err = c.RSAGenerateKey(0, "foo", true) + assert.Error(t, err) + + key, err := c.RSAGenerateKey(12) + assert.NoError(t, err) + assert.True(t, strings.HasPrefix(key, + "-----BEGIN RSA PRIVATE KEY-----")) + assert.True(t, strings.HasSuffix(key, + "-----END RSA PRIVATE KEY-----\n")) +} + +func TestRSACrypt(t *testing.T) { + c := CryptoNS() + key, err := c.RSAGenerateKey() + assert.NoError(t, err) + pub, err := c.RSADerivePublicKey(key) + assert.NoError(t, err) + + in := "hello world" + enc, err := c.RSAEncrypt(pub, in) + assert.NoError(t, err) + + dec, err := c.RSADecrypt(key, enc) + assert.NoError(t, err) + assert.Equal(t, in, dec) + + b, err := c.RSADecryptBytes(key, enc) + assert.NoError(t, err) + assert.Equal(t, dec, string(b)) +} -- cgit v1.2.3