diff options
| author | Dave Henderson <dhenderson@gmail.com> | 2024-06-16 15:16:54 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-06-16 19:16:54 +0000 |
| commit | c275254c45353aeadf6adf85858856d0c36d624f (patch) | |
| tree | e64ac9cef3f839fc7a38077bc98e26115a3803b5 /internal | |
| parent | a33f6d9fe0f145afdd06ef5b3431e3c2fe281c16 (diff) | |
chore(api)!: Error instead of returning 0 on invalid inputs to conv.* functions (#2104)
Signed-off-by: Dave Henderson <dhenderson@gmail.com>
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/funcs/coll.go | 10 | ||||
| -rw-r--r-- | internal/funcs/conv.go | 30 | ||||
| -rw-r--r-- | internal/funcs/crypto.go | 49 | ||||
| -rw-r--r-- | internal/funcs/math.go | 261 | ||||
| -rw-r--r-- | internal/funcs/math_test.go | 204 | ||||
| -rw-r--r-- | internal/funcs/net.go | 19 | ||||
| -rw-r--r-- | internal/funcs/random.go | 83 | ||||
| -rw-r--r-- | internal/funcs/regexp.go | 19 | ||||
| -rw-r--r-- | internal/funcs/strings.go | 56 | ||||
| -rw-r--r-- | internal/funcs/time.go | 54 | ||||
| -rw-r--r-- | internal/tests/integration/math_test.go | 5 |
11 files changed, 630 insertions, 160 deletions
diff --git a/internal/funcs/coll.go b/internal/funcs/coll.go index 9e08aeac..945f77d3 100644 --- a/internal/funcs/coll.go +++ b/internal/funcs/coll.go @@ -153,12 +153,20 @@ func (CollFuncs) Flatten(args ...interface{}) ([]interface{}, error) { if len(args) == 0 || len(args) > 2 { return nil, fmt.Errorf("wrong number of args: wanted 1 or 2, got %d", len(args)) } + list := args[0] + + var err error depth := -1 if len(args) == 2 { - depth = conv.ToInt(args[0]) + depth, err = conv.ToInt(list) + if err != nil { + return nil, fmt.Errorf("wrong depth type: must be int, got %T (%+v)", list, list) + } + list = args[1] } + return coll.Flatten(list, depth) } diff --git a/internal/funcs/conv.go b/internal/funcs/conv.go index 5bc5cab9..67fd3b9f 100644 --- a/internal/funcs/conv.go +++ b/internal/funcs/conv.go @@ -3,6 +3,7 @@ package funcs import ( "context" "net/url" + "strconv" "text/template" "github.com/hairyhenderson/gomplate/v4/conv" @@ -52,23 +53,23 @@ func (ConvFuncs) Join(in interface{}, sep string) (string, error) { } // ParseInt - -func (ConvFuncs) ParseInt(s interface{}, base, bitSize int) int64 { - return conv.MustParseInt(conv.ToString(s), base, bitSize) +func (ConvFuncs) ParseInt(s interface{}, base, bitSize int) (int64, error) { + return strconv.ParseInt(conv.ToString(s), base, bitSize) } // ParseFloat - -func (ConvFuncs) ParseFloat(s interface{}, bitSize int) float64 { - return conv.MustParseFloat(conv.ToString(s), bitSize) +func (ConvFuncs) ParseFloat(s interface{}, bitSize int) (float64, error) { + return strconv.ParseFloat(conv.ToString(s), bitSize) } // ParseUint - -func (ConvFuncs) ParseUint(s interface{}, base, bitSize int) uint64 { - return conv.MustParseUint(conv.ToString(s), base, bitSize) +func (ConvFuncs) ParseUint(s interface{}, base, bitSize int) (uint64, error) { + return strconv.ParseUint(conv.ToString(s), base, bitSize) } // Atoi - -func (ConvFuncs) Atoi(s interface{}) int { - return conv.MustAtoi(conv.ToString(s)) +func (ConvFuncs) Atoi(s interface{}) (int, error) { + return strconv.Atoi(conv.ToString(s)) } // URL - @@ -77,32 +78,32 @@ func (ConvFuncs) URL(s interface{}) (*url.URL, error) { } // ToInt64 - -func (ConvFuncs) ToInt64(in interface{}) int64 { +func (ConvFuncs) ToInt64(in interface{}) (int64, error) { return conv.ToInt64(in) } // ToInt - -func (ConvFuncs) ToInt(in interface{}) int { +func (ConvFuncs) ToInt(in interface{}) (int, error) { return conv.ToInt(in) } // ToInt64s - -func (ConvFuncs) ToInt64s(in ...interface{}) []int64 { +func (ConvFuncs) ToInt64s(in ...interface{}) ([]int64, error) { return conv.ToInt64s(in...) } // ToInts - -func (ConvFuncs) ToInts(in ...interface{}) []int { +func (ConvFuncs) ToInts(in ...interface{}) ([]int, error) { return conv.ToInts(in...) } // ToFloat64 - -func (ConvFuncs) ToFloat64(in interface{}) float64 { +func (ConvFuncs) ToFloat64(in interface{}) (float64, error) { return conv.ToFloat64(in) } // ToFloat64s - -func (ConvFuncs) ToFloat64s(in ...interface{}) []float64 { +func (ConvFuncs) ToFloat64s(in ...interface{}) ([]float64, error) { return conv.ToFloat64s(in...) } @@ -121,5 +122,6 @@ func (ConvFuncs) Default(def, in interface{}) interface{} { if truth, ok := template.IsTrue(in); truth && ok { return in } + return def } diff --git a/internal/funcs/crypto.go b/internal/funcs/crypto.go index fab02a7a..8807d518 100644 --- a/internal/funcs/crypto.go +++ b/internal/funcs/crypto.go @@ -49,8 +49,16 @@ func (CryptoFuncs) PBKDF2(password, salt, iter, keylen interface{}, hashFunc ... } pw := toBytes(password) s := toBytes(salt) - i := conv.ToInt(iter) - kl := conv.ToInt(keylen) + + i, err := conv.ToInt(iter) + if err != nil { + return "", fmt.Errorf("iter must be an integer: %w", err) + } + + kl, err := conv.ToInt(keylen) + if err != nil { + return "", fmt.Errorf("keylen must be an integer: %w", err) + } dk, err := crypto.PBKDF2(pw, s, i, kl, h) return fmt.Sprintf("%02x", dk), err @@ -167,17 +175,24 @@ func (CryptoFuncs) SHA512_256Bytes(input interface{}) ([]byte, error) { // Bcrypt - func (CryptoFuncs) Bcrypt(args ...interface{}) (string, error) { input := "" + + var err error cost := bcrypt.DefaultCost - if len(args) == 0 { - return "", fmt.Errorf("bcrypt requires at least an 'input' value") - } - if len(args) == 1 { + + switch len(args) { + case 1: input = conv.ToString(args[0]) - } - if len(args) == 2 { - cost = conv.ToInt(args[0]) + case 2: + cost, err = conv.ToInt(args[0]) + if err != nil { + return "", fmt.Errorf("bcrypt cost must be an integer: %w", err) + } + input = conv.ToString(args[1]) + default: + return "", fmt.Errorf("wrong number of args: want 1 or 2, got %d", len(args)) } + hash, err := bcrypt.GenerateFromPassword([]byte(input), cost) return string(hash), err } @@ -215,15 +230,21 @@ func (f *CryptoFuncs) RSADecryptBytes(key string, in []byte) ([]byte, error) { // RSAGenerateKey - // Experimental! func (f *CryptoFuncs) RSAGenerateKey(args ...interface{}) (string, error) { - if err := checkExperimental(f.ctx); err != nil { + err := checkExperimental(f.ctx) + if err != nil { return "", err } + bits := 4096 if len(args) == 1 { - bits = conv.ToInt(args[0]) + bits, err = conv.ToInt(args[0]) + if err != nil { + return "", fmt.Errorf("bits must be an integer: %w", err) + } } else if len(args) > 1 { return "", fmt.Errorf("wrong number of args: want 0 or 1, got %d", len(args)) } + out, err := crypto.RSAGenerateKey(bits) return string(out), err } @@ -370,7 +391,11 @@ func parseAESArgs(key string, args ...interface{}) ([]byte, []byte, error) { case 1: msg = toBytes(args[0]) case 2: - keyBits = conv.ToInt(args[0]) + var err error + keyBits, err = conv.ToInt(args[0]) + if err != nil { + return nil, nil, fmt.Errorf("keyBits must be an integer: %w", err) + } msg = toBytes(args[1]) default: return nil, nil, fmt.Errorf("wrong number of args: want 2 or 3, got %d", len(args)) diff --git a/internal/funcs/math.go b/internal/funcs/math.go index 3881e057..694586a7 100644 --- a/internal/funcs/math.go +++ b/internal/funcs/math.go @@ -79,79 +79,158 @@ func (f MathFuncs) IsNum(n interface{}) bool { } // Abs - -func (f MathFuncs) Abs(n interface{}) interface{} { - m := gmath.Abs(conv.ToFloat64(n)) +func (f MathFuncs) Abs(n interface{}) (interface{}, error) { + fn, err := conv.ToFloat64(n) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + m := gmath.Abs(fn) if f.IsInt(n) { return conv.ToInt64(m) } - return m + + return m, nil } // Add - -func (f MathFuncs) Add(n ...interface{}) interface{} { +func (f MathFuncs) Add(n ...interface{}) (interface{}, error) { if f.containsFloat(n...) { - nums := conv.ToFloat64s(n...) + nums, err := conv.ToFloat64s(n...) + if err != nil { + return nil, fmt.Errorf("expected number inputs: %w", err) + } + var x float64 for _, v := range nums { x += v } - return x + + return x, nil + } + + nums, err := conv.ToInt64s(n...) + if err != nil { + return nil, fmt.Errorf("expected number inputs: %w", err) } - nums := conv.ToInt64s(n...) + var x int64 for _, v := range nums { x += v } - return x + + return x, nil } // Mul - -func (f MathFuncs) Mul(n ...interface{}) interface{} { +func (f MathFuncs) Mul(n ...interface{}) (interface{}, error) { if f.containsFloat(n...) { - nums := conv.ToFloat64s(n...) + nums, err := conv.ToFloat64s(n...) + if err != nil { + return nil, fmt.Errorf("expected number inputs: %w", err) + } + x := 1. for _, v := range nums { x *= v } - return x + + return x, nil + } + + nums, err := conv.ToInt64s(n...) + if err != nil { + return nil, fmt.Errorf("expected number inputs: %w", err) } - nums := conv.ToInt64s(n...) + x := int64(1) for _, v := range nums { x *= v } - return x + + return x, nil } // Sub - -func (f MathFuncs) Sub(a, b interface{}) interface{} { +func (f MathFuncs) Sub(a, b interface{}) (interface{}, error) { if f.containsFloat(a, b) { - return conv.ToFloat64(a) - conv.ToFloat64(b) + fa, err := conv.ToFloat64(a) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + fb, err := conv.ToFloat64(b) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + return fa - fb, nil + } + + ia, err := conv.ToInt64(a) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + ib, err := conv.ToInt64(b) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) } - return conv.ToInt64(a) - conv.ToInt64(b) + + return ia - ib, nil } // Div - func (f MathFuncs) Div(a, b interface{}) (interface{}, error) { - divisor := conv.ToFloat64(a) - dividend := conv.ToFloat64(b) + divisor, err := conv.ToFloat64(a) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + dividend, err := conv.ToFloat64(b) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + if dividend == 0 { return 0, fmt.Errorf("error: division by 0") } + return divisor / dividend, nil } // Rem - -func (f MathFuncs) Rem(a, b interface{}) interface{} { - return conv.ToInt64(a) % conv.ToInt64(b) +func (f MathFuncs) Rem(a, b interface{}) (interface{}, error) { + ia, err := conv.ToInt64(a) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + ib, err := conv.ToInt64(b) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + return ia % ib, nil } // Pow - -func (f MathFuncs) Pow(a, b interface{}) interface{} { - r := gmath.Pow(conv.ToFloat64(a), conv.ToFloat64(b)) +func (f MathFuncs) Pow(a, b interface{}) (interface{}, error) { + fa, err := conv.ToFloat64(a) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + fb, err := conv.ToFloat64(b) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + r := gmath.Pow(fa, fb) if f.IsFloat(a) { - return r + return r, nil } + return conv.ToInt64(r) } @@ -161,53 +240,116 @@ func (f MathFuncs) Seq(n ...interface{}) ([]int64, error) { start := int64(1) end := int64(0) step := int64(1) - if len(n) == 0 { - return nil, fmt.Errorf("math.Seq must be given at least an 'end' value") - } - if len(n) == 1 { - end = conv.ToInt64(n[0]) - } - if len(n) == 2 { - start = conv.ToInt64(n[0]) - end = conv.ToInt64(n[1]) - } - if len(n) == 3 { - start = conv.ToInt64(n[0]) - end = conv.ToInt64(n[1]) - step = conv.ToInt64(n[2]) + + var err error + + switch len(n) { + case 1: + end, err = conv.ToInt64(n[0]) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + case 2: + start, err = conv.ToInt64(n[0]) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + end, err = conv.ToInt64(n[1]) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + case 3: + start, err = conv.ToInt64(n[0]) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + end, err = conv.ToInt64(n[1]) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + step, err = conv.ToInt64(n[2]) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + default: + return nil, fmt.Errorf("expected 1, 2, or 3 arguments, got %d", len(n)) } - return math.Seq(conv.ToInt64(start), conv.ToInt64(end), conv.ToInt64(step)), nil + + return math.Seq(start, end, step), nil } // Max - func (f MathFuncs) Max(a interface{}, b ...interface{}) (interface{}, error) { if f.IsFloat(a) || f.containsFloat(b...) { - m := conv.ToFloat64(a) - for _, n := range conv.ToFloat64s(b...) { + m, err := conv.ToFloat64(a) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + floats, err := conv.ToFloat64s(b...) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + for _, n := range floats { m = gmath.Max(m, n) } + return m, nil } - m := conv.ToInt64(a) - for _, n := range conv.ToInt64s(b...) { + + m, err := conv.ToInt64(a) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + nums, err := conv.ToInt64s(b...) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + for _, n := range nums { if n > m { m = n } } + return m, nil } // Min - func (f MathFuncs) Min(a interface{}, b ...interface{}) (interface{}, error) { if f.IsFloat(a) || f.containsFloat(b...) { - m := conv.ToFloat64(a) - for _, n := range conv.ToFloat64s(b...) { + m, err := conv.ToFloat64(a) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + floats, err := conv.ToFloat64s(b...) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + for _, n := range floats { m = gmath.Min(m, n) } return m, nil } - m := conv.ToInt64(a) - for _, n := range conv.ToInt64s(b...) { + + m, err := conv.ToInt64(a) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + nums, err := conv.ToInt64s(b...) + if err != nil { + return nil, fmt.Errorf("expected a number: %w", err) + } + + for _, n := range nums { if n < m { m = n } @@ -216,16 +358,31 @@ func (f MathFuncs) Min(a interface{}, b ...interface{}) (interface{}, error) { } // Ceil - -func (f MathFuncs) Ceil(n interface{}) interface{} { - return gmath.Ceil(conv.ToFloat64(n)) +func (f MathFuncs) Ceil(n interface{}) (interface{}, error) { + in, err := conv.ToFloat64(n) + if err != nil { + return nil, fmt.Errorf("n must be a number: %w", err) + } + + return gmath.Ceil(in), nil } // Floor - -func (f MathFuncs) Floor(n interface{}) interface{} { - return gmath.Floor(conv.ToFloat64(n)) +func (f MathFuncs) Floor(n interface{}) (interface{}, error) { + in, err := conv.ToFloat64(n) + if err != nil { + return nil, fmt.Errorf("n must be a number: %w", err) + } + + return gmath.Floor(in), nil } // Round - -func (f MathFuncs) Round(n interface{}) interface{} { - return gmath.Round(conv.ToFloat64(n)) +func (f MathFuncs) Round(n interface{}) (interface{}, error) { + in, err := conv.ToFloat64(n) + if err != nil { + return nil, fmt.Errorf("n must be a number: %w", err) + } + + return gmath.Round(in), nil } diff --git a/internal/funcs/math_test.go b/internal/funcs/math_test.go index 50b05784..8df794fc 100644 --- a/internal/funcs/math_test.go +++ b/internal/funcs/math_test.go @@ -32,33 +32,78 @@ func TestAdd(t *testing.T) { t.Parallel() m := MathFuncs{} - assert.Equal(t, int64(12), m.Add(1, 1, 2, 3, 5)) - assert.Equal(t, int64(2), m.Add(1, 1)) - assert.Equal(t, int64(1), m.Add(1)) - assert.Equal(t, int64(0), m.Add(-5, 5)) - assert.InEpsilon(t, float64(5.1), m.Add(4.9, "0.2"), 1e-12) + + actual, err := m.Add(1, 1, 2, 3, 5) + require.NoError(t, err) + assert.Equal(t, int64(12), actual) + + actual, err = m.Add(1, 1) + require.NoError(t, err) + assert.Equal(t, int64(2), actual) + + actual, err = m.Add(1) + require.NoError(t, err) + assert.Equal(t, int64(1), actual) + + actual, err = m.Add(-5, 5) + require.NoError(t, err) + assert.Equal(t, int64(0), actual) + + actual, err = m.Add(4.9, "0.2") + require.NoError(t, err) + assert.InEpsilon(t, float64(5.1), actual, 1e-12) } func TestMul(t *testing.T) { t.Parallel() m := MathFuncs{} - assert.Equal(t, int64(30), m.Mul(1, 1, 2, 3, 5)) - assert.Equal(t, int64(1), m.Mul(1, 1)) - assert.Equal(t, int64(1), m.Mul(1)) - assert.Equal(t, int64(-25), m.Mul("-5", 5)) - assert.Equal(t, int64(28), m.Mul(14, "2")) - assert.InEpsilon(t, float64(0.5), m.Mul("-1", -0.5), 1e-12) + + actual, err := m.Mul(1, 1, 2, 3, 5) + require.NoError(t, err) + assert.Equal(t, int64(30), actual) + + actual, err = m.Mul(1, 1) + require.NoError(t, err) + assert.Equal(t, int64(1), actual) + + actual, err = m.Mul(1) + require.NoError(t, err) + assert.Equal(t, int64(1), actual) + + actual, err = m.Mul("-5", 5) + require.NoError(t, err) + assert.Equal(t, int64(-25), actual) + + actual, err = m.Mul(14, "2") + require.NoError(t, err) + assert.Equal(t, int64(28), actual) + + actual, err = m.Mul("-1", -0.5) + require.NoError(t, err) + assert.InEpsilon(t, float64(0.5), actual, 1e-12) } func TestSub(t *testing.T) { t.Parallel() m := MathFuncs{} - assert.Equal(t, int64(0), m.Sub(1, 1)) - assert.Equal(t, int64(-10), m.Sub(-5, 5)) - assert.Equal(t, int64(-41), m.Sub(true, "42")) - assert.InEpsilon(t, -5.3, m.Sub(10, 15.3), 1e-12) + + actual, err := m.Sub(1, 1) + require.NoError(t, err) + assert.Equal(t, int64(0), actual) + + actual, err = m.Sub(-5, 5) + require.NoError(t, err) + assert.Equal(t, int64(-10), actual) + + actual, err = m.Sub(true, "42") + require.NoError(t, err) + assert.Equal(t, int64(-41), actual) + + actual, err = m.Sub(10, 15.3) + require.NoError(t, err) + assert.InEpsilon(t, -5.3, actual, 1e-12) } func mustDiv(a, b interface{}) interface{} { @@ -86,16 +131,28 @@ func TestRem(t *testing.T) { t.Parallel() m := MathFuncs{} - assert.Equal(t, int64(0), m.Rem(1, 1)) - assert.Equal(t, int64(2), m.Rem(5, 3.0)) + + actual, err := m.Rem(1, 1) + require.NoError(t, err) + assert.Equal(t, int64(0), actual) + + actual, err = m.Rem(5, 3.0) + require.NoError(t, err) + assert.Equal(t, int64(2), actual) } func TestPow(t *testing.T) { t.Parallel() m := MathFuncs{} - assert.Equal(t, int64(4), m.Pow(2, "2")) - assert.InEpsilon(t, 2.25, m.Pow(1.5, 2), 1e-12) + + actual, err := m.Pow(2, "2") + require.NoError(t, err) + assert.Equal(t, int64(4), actual) + + actual, err = m.Pow(1.5, 2) + require.NoError(t, err) + assert.InEpsilon(t, 2.25, actual, 1e-12) } func mustSeq(t *testing.T, n ...interface{}) []int64 { @@ -161,6 +218,7 @@ func TestIsIntFloatNum(t *testing.T) { {nil, false, false}, {true, false, false}, } + m := MathFuncs{} for _, tt := range tests { tt := tt @@ -178,6 +236,7 @@ func BenchmarkIsFloat(b *testing.B) { data := []interface{}{ 0, 1, -1, uint(42), uint8(255), uint16(42), uint32(42), uint64(42), int(42), int8(127), int16(42), int32(42), int64(42), float32(18.3), float64(18.3), 1.5, -18.6, "42", "052", "0xff", "-42", "-0", "3.14", "-3.14", "0.00", "NaN", "-Inf", "+Inf", "", "foo", nil, true, } + m := MathFuncs{} for _, n := range data { n := n @@ -197,9 +256,7 @@ func TestMax(t *testing.T) { expected interface{} n []interface{} }{ - {int64(0), []interface{}{nil}}, {int64(0), []interface{}{0}}, - {int64(0), []interface{}{"not a number"}}, {int64(1), []interface{}{1}}, {int64(-1), []interface{}{-1}}, {int64(1), []interface{}{-1, 0, 1}}, @@ -220,6 +277,19 @@ func TestMax(t *testing.T) { assert.Equal(t, d.expected, actual) }) } + + t.Run("error cases", func(t *testing.T) { + t.Parallel() + + _, err := m.Max("foo") + require.Error(t, err) + + _, err = m.Max(nil) + require.Error(t, err) + + _, err = m.Max("") + require.Error(t, err) + }) } func TestMin(t *testing.T) { @@ -230,9 +300,7 @@ func TestMin(t *testing.T) { expected interface{} n []interface{} }{ - {int64(0), []interface{}{nil}}, {int64(0), []interface{}{0}}, - {int64(0), []interface{}{"not a number"}}, {int64(1), []interface{}{1}}, {int64(-1), []interface{}{-1}}, {int64(-1), []interface{}{-1, 0, 1}}, @@ -250,9 +318,23 @@ func TestMin(t *testing.T) { } else { actual, _ = m.Min(d.n[0], d.n[1:]...) } + assert.Equal(t, d.expected, actual) }) } + + t.Run("error cases", func(t *testing.T) { + t.Parallel() + + _, err := m.Min("foo") + require.Error(t, err) + + _, err = m.Min(nil) + require.Error(t, err) + + _, err = m.Min("") + require.Error(t, err) + }) } func TestContainsFloat(t *testing.T) { @@ -297,8 +379,6 @@ func TestCeil(t *testing.T) { n interface{} a float64 }{ - {"", 0.}, - {nil, 0.}, {"Inf", gmath.Inf(1)}, {0, 0.}, {4.99, 5.}, @@ -310,9 +390,24 @@ func TestCeil(t *testing.T) { t.Run(fmt.Sprintf("%v==%v", d.n, d.a), func(t *testing.T) { t.Parallel() - assert.InDelta(t, d.a, m.Ceil(d.n), 1e-12) + actual, err := m.Ceil(d.n) + require.NoError(t, err) + assert.InDelta(t, d.a, actual, 1e-12) }) } + + t.Run("error cases", func(t *testing.T) { + t.Parallel() + + _, err := m.Ceil("foo") + require.Error(t, err) + + _, err = m.Ceil(nil) + require.Error(t, err) + + _, err = m.Ceil("") + require.Error(t, err) + }) } func TestFloor(t *testing.T) { @@ -323,8 +418,6 @@ func TestFloor(t *testing.T) { n interface{} a float64 }{ - {"", 0.}, - {nil, 0.}, {"Inf", gmath.Inf(1)}, {0, 0.}, {4.99, 4.}, @@ -336,9 +429,24 @@ func TestFloor(t *testing.T) { t.Run(fmt.Sprintf("%v==%v", d.n, d.a), func(t *testing.T) { t.Parallel() - assert.InDelta(t, d.a, m.Floor(d.n), 1e-12) + actual, err := m.Floor(d.n) + require.NoError(t, err) + assert.InDelta(t, d.a, actual, 1e-12) }) } + + t.Run("error cases", func(t *testing.T) { + t.Parallel() + + _, err := m.Floor("foo") + require.Error(t, err) + + _, err = m.Floor(nil) + require.Error(t, err) + + _, err = m.Floor("") + require.Error(t, err) + }) } func TestRound(t *testing.T) { @@ -349,8 +457,6 @@ func TestRound(t *testing.T) { n interface{} a float64 }{ - {"", 0.}, - {nil, 0.}, {"Inf", gmath.Inf(1)}, {0, 0.}, {4.99, 5}, @@ -366,9 +472,24 @@ func TestRound(t *testing.T) { t.Run(fmt.Sprintf("%v==%v", d.n, d.a), func(t *testing.T) { t.Parallel() - assert.InDelta(t, d.a, m.Round(d.n), 1e-12) + actual, err := m.Round(d.n) + require.NoError(t, err) + assert.InDelta(t, d.a, actual, 1e-12) }) } + + t.Run("error cases", func(t *testing.T) { + t.Parallel() + + _, err := m.Round("foo") + require.Error(t, err) + + _, err = m.Round(nil) + require.Error(t, err) + + _, err = m.Round("") + require.Error(t, err) + }) } func TestAbs(t *testing.T) { @@ -379,8 +500,6 @@ func TestAbs(t *testing.T) { n interface{} a interface{} }{ - {"", 0.}, - {nil, 0.}, {"-Inf", gmath.Inf(1)}, {0, int64(0)}, {0., 0.}, @@ -395,7 +514,22 @@ func TestAbs(t *testing.T) { t.Run(fmt.Sprintf("%#v==%v", d.n, d.a), func(t *testing.T) { t.Parallel() - assert.Equal(t, d.a, m.Abs(d.n)) + actual, err := m.Abs(d.n) + require.NoError(t, err) + assert.Equal(t, d.a, actual) }) } + + t.Run("error cases", func(t *testing.T) { + t.Parallel() + + _, err := m.Abs("foo") + require.Error(t, err) + + _, err = m.Abs(nil) + require.Error(t, err) + + _, err = m.Abs("") + require.Error(t, err) + }) } diff --git a/internal/funcs/net.go b/internal/funcs/net.go index bf8a2674..9f69579a 100644 --- a/internal/funcs/net.go +++ b/internal/funcs/net.go @@ -132,7 +132,12 @@ func (f *NetFuncs) CIDRHost(hostnum interface{}, prefix interface{}) (netip.Addr return netip.Addr{}, err } - ip, err := cidr.HostBig(network, big.NewInt(conv.ToInt64(hostnum))) + n, err := conv.ToInt64(hostnum) + if err != nil { + return netip.Addr{}, fmt.Errorf("expected a number: %w", err) + } + + ip, err := cidr.HostBig(network, big.NewInt(n)) return ip, err } @@ -175,7 +180,11 @@ func (f *NetFuncs) CIDRSubnets(newbits interface{}, prefix interface{}) ([]netip return nil, err } - nBits := conv.ToInt(newbits) + nBits, err := conv.ToInt(newbits) + if err != nil { + return nil, fmt.Errorf("newbits must be a number: %w", err) + } + if nBits < 1 { return nil, fmt.Errorf("must extend prefix by at least one bit") } @@ -208,7 +217,11 @@ func (f *NetFuncs) CIDRSubnetSizes(args ...interface{}) ([]netip.Prefix, error) if err != nil { return nil, err } - newbits := conv.ToInts(args[:len(args)-1]...) + + newbits, err := conv.ToInts(args[:len(args)-1]...) + if err != nil { + return nil, fmt.Errorf("newbits must be numbers: %w", err) + } startPrefixLen := network.Bits() firstLength := newbits[0] diff --git a/internal/funcs/random.go b/internal/funcs/random.go index 84bad8d7..0f18ae92 100644 --- a/internal/funcs/random.go +++ b/internal/funcs/random.go @@ -26,25 +26,45 @@ type RandomFuncs struct { // ASCII - func (RandomFuncs) ASCII(count interface{}) (string, error) { - return random.StringBounds(conv.ToInt(count), ' ', '~') + n, err := conv.ToInt(count) + if err != nil { + return "", fmt.Errorf("count must be an integer: %w", err) + } + + return random.StringBounds(n, ' ', '~') } // Alpha - func (RandomFuncs) Alpha(count interface{}) (string, error) { - return random.StringRE(conv.ToInt(count), "[[:alpha:]]") + n, err := conv.ToInt(count) + if err != nil { + return "", fmt.Errorf("count must be an integer: %w", err) + } + + return random.StringRE(n, "[[:alpha:]]") } // AlphaNum - func (RandomFuncs) AlphaNum(count interface{}) (string, error) { - return random.StringRE(conv.ToInt(count), "[[:alnum:]]") + n, err := conv.ToInt(count) + if err != nil { + return "", fmt.Errorf("count must be an integer: %w", err) + } + + return random.StringRE(n, "[[:alnum:]]") } // String - -func (RandomFuncs) String(count interface{}, args ...interface{}) (s string, err error) { - c := conv.ToInt(count) +func (RandomFuncs) String(count interface{}, args ...interface{}) (string, error) { + c, err := conv.ToInt(count) + if err != nil { + return "", fmt.Errorf("count must be an integer: %w", err) + } + if c == 0 { return "", fmt.Errorf("count must be greater than 0") } + m := "" switch len(args) { case 0: @@ -59,8 +79,17 @@ func (RandomFuncs) String(count interface{}, args ...interface{}) (s string, err return "", err } } else { - l = rune(conv.ToInt(args[0])) - u = rune(conv.ToInt(args[1])) + nl, err := conv.ToInt(args[0]) + if err != nil { + return "", fmt.Errorf("lower must be an integer: %w", err) + } + + nu, err := conv.ToInt(args[1]) + if err != nil { + return "", fmt.Errorf("upper must be an integer: %w", err) + } + + l, u = rune(nl), rune(nu) } return random.StringBounds(c, l, u) @@ -114,14 +143,28 @@ func (RandomFuncs) Item(items interface{}) (interface{}, error) { func (RandomFuncs) Number(args ...interface{}) (int64, error) { var min, max int64 min, max = 0, 100 + + var err error + switch len(args) { case 0: case 1: - max = conv.ToInt64(args[0]) + max, err = conv.ToInt64(args[0]) + if err != nil { + return 0, fmt.Errorf("max must be a number: %w", err) + } case 2: - min = conv.ToInt64(args[0]) - max = conv.ToInt64(args[1]) + min, err = conv.ToInt64(args[0]) + if err != nil { + return 0, fmt.Errorf("min must be a number: %w", err) + } + + max, err = conv.ToInt64(args[1]) + if err != nil { + return 0, fmt.Errorf("max must be a number: %w", err) + } } + return random.Number(min, max) } @@ -129,13 +172,27 @@ func (RandomFuncs) Number(args ...interface{}) (int64, error) { func (RandomFuncs) Float(args ...interface{}) (float64, error) { var min, max float64 min, max = 0, 1.0 + + var err error + switch len(args) { case 0: case 1: - max = conv.ToFloat64(args[0]) + max, err = conv.ToFloat64(args[0]) + if err != nil { + return 0, fmt.Errorf("max must be a number: %w", err) + } case 2: - min = conv.ToFloat64(args[0]) - max = conv.ToFloat64(args[1]) + min, err = conv.ToFloat64(args[0]) + if err != nil { + return 0, fmt.Errorf("min must be a number: %w", err) + } + + max, err = conv.ToFloat64(args[1]) + if err != nil { + return 0, fmt.Errorf("max must be a number: %w", err) + } } + return random.Float(min, max) } diff --git a/internal/funcs/regexp.go b/internal/funcs/regexp.go index 49bafd1c..d7f6a685 100644 --- a/internal/funcs/regexp.go +++ b/internal/funcs/regexp.go @@ -31,6 +31,7 @@ func (ReFuncs) FindAll(args ...interface{}) ([]string, error) { re := "" n := 0 input := "" + switch len(args) { case 2: n = -1 @@ -38,11 +39,18 @@ func (ReFuncs) FindAll(args ...interface{}) ([]string, error) { input = conv.ToString(args[1]) case 3: re = conv.ToString(args[0]) - n = conv.ToInt(args[1]) + + var err error + n, err = conv.ToInt(args[1]) + if err != nil { + return nil, fmt.Errorf("n must be an integer: %w", err) + } + input = conv.ToString(args[2]) default: return nil, fmt.Errorf("wrong number of args: want 2 or 3, got %d", len(args)) } + return regexp.FindAll(re, n, input) } @@ -75,16 +83,23 @@ func (ReFuncs) Split(args ...interface{}) ([]string, error) { re := "" n := -1 input := "" + switch len(args) { case 2: re = conv.ToString(args[0]) input = conv.ToString(args[1]) case 3: re = conv.ToString(args[0]) - n = conv.ToInt(args[1]) + var err error + n, err = conv.ToInt(args[1]) + if err != nil { + return nil, fmt.Errorf("n must be an integer: %w", err) + } + input = conv.ToString(args[2]) default: return nil, fmt.Errorf("wrong number of args: want 2 or 3, got %d", len(args)) } + return regexp.Split(re, n, input) } diff --git a/internal/funcs/strings.go b/internal/funcs/strings.go index 53869c7b..cb440868 100644 --- a/internal/funcs/strings.go +++ b/internal/funcs/strings.go @@ -115,23 +115,39 @@ func (f *StringFuncs) oldTrim(s, cutset string) string { func (StringFuncs) Abbrev(args ...interface{}) (string, error) { str := "" offset := 0 - maxWidth := 0 - if len(args) < 2 { - return "", fmt.Errorf("abbrev requires a 'maxWidth' and 'input' argument") - } - if len(args) == 2 { - maxWidth = conv.ToInt(args[0]) + width := 0 + + var err error + + switch len(args) { + case 2: + width, err = conv.ToInt(args[0]) + if err != nil { + return "", fmt.Errorf("width must be an integer: %w", err) + } + str = conv.ToString(args[1]) - } - if len(args) == 3 { - offset = conv.ToInt(args[0]) - maxWidth = conv.ToInt(args[1]) + case 3: + offset, err = conv.ToInt(args[0]) + if err != nil { + return "", fmt.Errorf("offset must be an integer: %w", err) + } + + width, err = conv.ToInt(args[1]) + if err != nil { + return "", fmt.Errorf("width must be an integer: %w", err) + } + str = conv.ToString(args[2]) + default: + return "", fmt.Errorf("abbrev requires a 'width' and 'input' argument") } - if len(str) <= maxWidth { + + if len(str) <= width { return str, nil } - return goutils.AbbreviateFull(str, offset, maxWidth) + + return goutils.AbbreviateFull(str, offset, width) } // ReplaceAll - @@ -342,13 +358,25 @@ func (StringFuncs) WordWrap(args ...interface{}) (string, error) { case string: opts.LBSeq = a default: - opts.Width = uint(conv.ToInt(a)) + n, err := conv.ToInt(args[0]) + if err != nil { + return "", fmt.Errorf("expected width to be a number: %w", err) + } + + opts.Width = uint(n) } } + if len(args) == 3 { - opts.Width = uint(conv.ToInt(args[0])) + n, err := conv.ToInt(args[0]) + if err != nil { + return "", fmt.Errorf("expected width to be a number: %w", err) + } + + opts.Width = uint(n) opts.LBSeq = conv.ToString(args[1]) } + return gompstrings.WordWrap(in, opts), nil } diff --git a/internal/funcs/time.go b/internal/funcs/time.go index deaddc2e..532335c3 100644 --- a/internal/funcs/time.go +++ b/internal/funcs/time.go @@ -104,33 +104,63 @@ func (TimeFuncs) Unix(in interface{}) (gotime.Time, error) { } // Nanosecond - -func (TimeFuncs) Nanosecond(n interface{}) gotime.Duration { - return gotime.Nanosecond * gotime.Duration(conv.ToInt64(n)) +func (TimeFuncs) Nanosecond(n interface{}) (gotime.Duration, error) { + in, err := conv.ToInt64(n) + if err != nil { + return 0, fmt.Errorf("expected a number: %w", err) + } + + return gotime.Nanosecond * gotime.Duration(in), nil } // Microsecond - -func (TimeFuncs) Microsecond(n interface{}) gotime.Duration { - return gotime.Microsecond * gotime.Duration(conv.ToInt64(n)) +func (TimeFuncs) Microsecond(n interface{}) (gotime.Duration, error) { + in, err := conv.ToInt64(n) + if err != nil { + return 0, fmt.Errorf("expected a number: %w", err) + } + + return gotime.Microsecond * gotime.Duration(in), nil } // Millisecond - -func (TimeFuncs) Millisecond(n interface{}) gotime.Duration { - return gotime.Millisecond * gotime.Duration(conv.ToInt64(n)) +func (TimeFuncs) Millisecond(n interface{}) (gotime.Duration, error) { + in, err := conv.ToInt64(n) + if err != nil { + return 0, fmt.Errorf("expected a number: %w", err) + } + + return gotime.Millisecond * gotime.Duration(in), nil } // Second - -func (TimeFuncs) Second(n interface{}) gotime.Duration { - return gotime.Second * gotime.Duration(conv.ToInt64(n)) +func (TimeFuncs) Second(n interface{}) (gotime.Duration, error) { + in, err := conv.ToInt64(n) + if err != nil { + return 0, fmt.Errorf("expected a number: %w", err) + } + + return gotime.Second * gotime.Duration(in), nil } // Minute - -func (TimeFuncs) Minute(n interface{}) gotime.Duration { - return gotime.Minute * gotime.Duration(conv.ToInt64(n)) +func (TimeFuncs) Minute(n interface{}) (gotime.Duration, error) { + in, err := conv.ToInt64(n) + if err != nil { + return 0, fmt.Errorf("expected a number: %w", err) + } + + return gotime.Minute * gotime.Duration(in), nil } // Hour - -func (TimeFuncs) Hour(n interface{}) gotime.Duration { - return gotime.Hour * gotime.Duration(conv.ToInt64(n)) +func (TimeFuncs) Hour(n interface{}) (gotime.Duration, error) { + in, err := conv.ToInt64(n) + if err != nil { + return 0, fmt.Errorf("expected a number: %w", err) + } + + return gotime.Hour * gotime.Duration(in), nil } // ParseDuration - diff --git a/internal/tests/integration/math_test.go b/internal/tests/integration/math_test.go index 769cae46..b3289b28 100644 --- a/internal/tests/integration/math_test.go +++ b/internal/tests/integration/math_test.go @@ -13,8 +13,9 @@ func TestMath(t *testing.T) { inOutTest(t, `{{ math.Pow 8 4 }} {{ pow 2 2 }}`, "4096 4") inOutTest(t, `{{ math.Seq 0 }}, {{ seq 0 3 }}, {{ seq -5 -10 2 }}`, `[1 0], [0 1 2 3], [-5 -7 -9]`) - inOutTest(t, `{{ math.Round 0.99 }}, {{ math.Round "foo" }}, {{math.Round 3.5}}`, - `1, 0, 4`) + inOutTest(t, `{{ math.Round 0.99 }}, {{math.Round 3.5}}`, `1, 4`) inOutTest(t, `{{ math.Max -0 "+Inf" "NaN" }}, {{ math.Max 3.4 3.401 3.399 }}`, `+Inf, 3.401`) + + inOutContainsError(t, `{{ math.Round "foo" }}`, `could not convert \"foo\"`) } |
