diff options
Diffstat (limited to 'internal/funcs')
45 files changed, 6161 insertions, 0 deletions
diff --git a/internal/funcs/aws.go b/internal/funcs/aws.go new file mode 100644 index 00000000..ef95d612 --- /dev/null +++ b/internal/funcs/aws.go @@ -0,0 +1,152 @@ +package funcs + +import ( + "context" + "sync" + + "github.com/hairyhenderson/gomplate/v4/aws" + "github.com/hairyhenderson/gomplate/v4/conv" +) + +// AWSNS - the aws namespace +// +// Deprecated: don't use +// +//nolint:golint +func AWSNS() *Funcs { + return &Funcs{} +} + +// AWSFuncs - +// +// Deprecated: use [CreateAWSFuncs] instead +func AWSFuncs(f map[string]interface{}) { + f2 := CreateAWSFuncs(context.Background()) + for k, v := range f2 { + f[k] = v + } +} + +// CreateAWSFuncs - +func CreateAWSFuncs(ctx context.Context) map[string]interface{} { + f := map[string]interface{}{} + + ns := &Funcs{ + ctx: ctx, + awsopts: aws.GetClientOptions(), + } + + f["aws"] = func() interface{} { return ns } + + // global aliases - for backwards compatibility + f["ec2meta"] = ns.EC2Meta + f["ec2dynamic"] = ns.EC2Dynamic + f["ec2tag"] = ns.EC2Tag + f["ec2tags"] = ns.EC2Tags + f["ec2region"] = ns.EC2Region + return f +} + +// Funcs - +type Funcs struct { + ctx context.Context + + meta *aws.Ec2Meta + info *aws.Ec2Info + kms *aws.KMS + sts *aws.STS + metaInit sync.Once + infoInit sync.Once + kmsInit sync.Once + stsInit sync.Once + awsopts aws.ClientOptions +} + +// EC2Region - +func (a *Funcs) EC2Region(def ...string) (string, error) { + a.metaInit.Do(a.initMeta) + return a.meta.Region(def...) +} + +// EC2Meta - +func (a *Funcs) EC2Meta(key string, def ...string) (string, error) { + a.metaInit.Do(a.initMeta) + return a.meta.Meta(key, def...) +} + +// EC2Dynamic - +func (a *Funcs) EC2Dynamic(key string, def ...string) (string, error) { + a.metaInit.Do(a.initMeta) + return a.meta.Dynamic(key, def...) +} + +// EC2Tag - +func (a *Funcs) EC2Tag(tag string, def ...string) (string, error) { + a.infoInit.Do(a.initInfo) + return a.info.Tag(tag, def...) +} + +// EC2Tag - +func (a *Funcs) EC2Tags() (map[string]string, error) { + a.infoInit.Do(a.initInfo) + return a.info.Tags() +} + +// KMSEncrypt - +func (a *Funcs) KMSEncrypt(keyID, plaintext interface{}) (string, error) { + a.kmsInit.Do(a.initKMS) + return a.kms.Encrypt(conv.ToString(keyID), conv.ToString(plaintext)) +} + +// KMSDecrypt - +func (a *Funcs) KMSDecrypt(ciphertext interface{}) (string, error) { + a.kmsInit.Do(a.initKMS) + return a.kms.Decrypt(conv.ToString(ciphertext)) +} + +// UserID - Gets the unique identifier of the calling entity. The exact value +// depends on the type of entity making the call. The values returned are those +// listed in the aws:userid column in the Principal table +// (http://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_variables.html#principaltable) +// found on the Policy Variables reference page in the IAM User Guide. +func (a *Funcs) UserID() (string, error) { + a.stsInit.Do(a.initSTS) + return a.sts.UserID() +} + +// Account - Gets the AWS account ID number of the account that owns or +// contains the calling entity. +func (a *Funcs) Account() (string, error) { + a.stsInit.Do(a.initSTS) + return a.sts.Account() +} + +// ARN - Gets the AWS ARN associated with the calling entity +func (a *Funcs) ARN() (string, error) { + a.stsInit.Do(a.initSTS) + return a.sts.Arn() +} + +func (a *Funcs) initMeta() { + if a.meta == nil { + a.meta = aws.NewEc2Meta(a.awsopts) + } +} + +func (a *Funcs) initInfo() { + if a.info == nil { + a.info = aws.NewEc2Info(a.awsopts) + } +} + +func (a *Funcs) initKMS() { + if a.kms == nil { + a.kms = aws.NewKMS(a.awsopts) + } +} + +func (a *Funcs) initSTS() { + if a.sts == nil { + a.sts = aws.NewSTS(a.awsopts) + } +} diff --git a/internal/funcs/aws_test.go b/internal/funcs/aws_test.go new file mode 100644 index 00000000..2a1efc5d --- /dev/null +++ b/internal/funcs/aws_test.go @@ -0,0 +1,46 @@ +package funcs + +import ( + "context" + "strconv" + "testing" + + "github.com/hairyhenderson/gomplate/v4/aws" + "github.com/stretchr/testify/assert" +) + +func TestCreateAWSFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateAWSFuncs(ctx) + actual := fmap["aws"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*Funcs).ctx) + }) + } +} + +func TestAWSFuncs(t *testing.T) { + t.Parallel() + + m := aws.NewDummyEc2Meta() + i := aws.NewDummyEc2Info(m) + af := &Funcs{meta: m, info: i} + assert.Equal(t, "unknown", must(af.EC2Region())) + assert.Equal(t, "", must(af.EC2Meta("foo"))) + assert.Equal(t, "", must(af.EC2Tag("foo"))) + assert.Equal(t, "unknown", must(af.EC2Region())) +} + +func must(r interface{}, err error) interface{} { + if err != nil { + panic(err) + } + return r +} diff --git a/internal/funcs/base64.go b/internal/funcs/base64.go new file mode 100644 index 00000000..b25a9091 --- /dev/null +++ b/internal/funcs/base64.go @@ -0,0 +1,77 @@ +package funcs + +import ( + "context" + + "github.com/hairyhenderson/gomplate/v4/base64" + "github.com/hairyhenderson/gomplate/v4/conv" +) + +// Base64NS - the base64 namespace +// +// Deprecated: don't use +func Base64NS() *Base64Funcs { + return &Base64Funcs{} +} + +// AddBase64Funcs - +// +// Deprecated: use [CreateBase64Funcs] instead +func AddBase64Funcs(f map[string]interface{}) { + for k, v := range CreateBase64Funcs(context.Background()) { + f[k] = v + } +} + +// CreateBase64Funcs - +func CreateBase64Funcs(ctx context.Context) map[string]interface{} { + f := map[string]interface{}{} + + ns := &Base64Funcs{ctx} + f["base64"] = func() interface{} { return ns } + + return f +} + +// Base64Funcs - +type Base64Funcs struct { + ctx context.Context +} + +// Encode - +func (Base64Funcs) Encode(in interface{}) (string, error) { + b := toBytes(in) + return base64.Encode(b) +} + +// Decode - +func (Base64Funcs) Decode(in interface{}) (string, error) { + out, err := base64.Decode(conv.ToString(in)) + return string(out), err +} + +// DecodeBytes - +func (Base64Funcs) DecodeBytes(in interface{}) ([]byte, error) { + out, err := base64.Decode(conv.ToString(in)) + return out, err +} + +type byter interface { + Bytes() []byte +} + +func toBytes(in interface{}) []byte { + if in == nil { + return []byte{} + } + if s, ok := in.([]byte); ok { + return s + } + if s, ok := in.(byter); ok { + return s.Bytes() + } + if s, ok := in.(string); ok { + return []byte(s) + } + return []byte(conv.ToString(in)) +} diff --git a/internal/funcs/base64_test.go b/internal/funcs/base64_test.go new file mode 100644 index 00000000..f0cca886 --- /dev/null +++ b/internal/funcs/base64_test.go @@ -0,0 +1,79 @@ +package funcs + +import ( + "bytes" + "context" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCreateBase64Funcs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateBase64Funcs(ctx) + actual := fmap["base64"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*Base64Funcs).ctx) + }) + } +} + +func TestBase64Encode(t *testing.T) { + t.Parallel() + + bf := &Base64Funcs{} + assert.Equal(t, "Zm9vYmFy", must(bf.Encode("foobar"))) +} + +func TestBase64Decode(t *testing.T) { + t.Parallel() + + bf := &Base64Funcs{} + assert.Equal(t, "foobar", must(bf.Decode("Zm9vYmFy"))) +} + +func TestBase64DecodeBytes(t *testing.T) { + t.Parallel() + + bf := &Base64Funcs{} + out, err := bf.DecodeBytes("Zm9vYmFy") + require.NoError(t, err) + assert.Equal(t, "foobar", string(out)) +} + +func TestToBytes(t *testing.T) { + t.Parallel() + + assert.Equal(t, []byte{0, 1, 2, 3}, toBytes([]byte{0, 1, 2, 3})) + + buf := &bytes.Buffer{} + buf.WriteString("hi") + assert.Equal(t, []byte("hi"), toBytes(buf)) + assert.Equal(t, []byte{}, toBytes(nil)) + assert.Equal(t, []byte("42"), toBytes(42)) +} + +func BenchmarkToBytes(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + buf := &bytes.Buffer{} + buf.WriteString("hi") + bin := []byte{0, 1, 2, 3} + b.StartTimer() + + toBytes(bin) + + toBytes(buf) + toBytes(nil) + toBytes(42) + } +} diff --git a/internal/funcs/coll.go b/internal/funcs/coll.go new file mode 100644 index 00000000..c8cb4dcc --- /dev/null +++ b/internal/funcs/coll.go @@ -0,0 +1,218 @@ +package funcs + +import ( + "context" + "fmt" + "reflect" + + "github.com/hairyhenderson/gomplate/v4/conv" + "github.com/hairyhenderson/gomplate/v4/internal/deprecated" + "github.com/hairyhenderson/gomplate/v4/internal/texttemplate" + + "github.com/hairyhenderson/gomplate/v4/coll" +) + +// CollNS - +// +// Deprecated: don't use +func CollNS() *CollFuncs { + return &CollFuncs{} +} + +// AddCollFuncs - +// +// Deprecated: use CreateCollFuncs instead +func AddCollFuncs(f map[string]interface{}) { + for k, v := range CreateCollFuncs(context.Background()) { + f[k] = v + } +} + +// CreateCollFuncs - +func CreateCollFuncs(ctx context.Context) map[string]interface{} { + f := map[string]interface{}{} + + ns := &CollFuncs{ctx} + f["coll"] = func() interface{} { return ns } + + f["has"] = ns.Has + f["slice"] = ns.deprecatedSlice + f["dict"] = ns.Dict + f["keys"] = ns.Keys + f["values"] = ns.Values + f["append"] = ns.Append + f["prepend"] = ns.Prepend + f["uniq"] = ns.Uniq + f["reverse"] = ns.Reverse + f["merge"] = ns.Merge + f["sort"] = ns.Sort + f["jsonpath"] = ns.JSONPath + f["jq"] = ns.JQ + f["flatten"] = ns.Flatten + return f +} + +// CollFuncs - +type CollFuncs struct { + ctx context.Context +} + +// Slice - +func (CollFuncs) Slice(args ...interface{}) []interface{} { + return coll.Slice(args...) +} + +// deprecatedSlice - +// Deprecated: use coll.Slice instead +func (f *CollFuncs) deprecatedSlice(args ...interface{}) []interface{} { + deprecated.WarnDeprecated(f.ctx, "the 'slice' alias for coll.Slice is deprecated - use coll.Slice instead") + return coll.Slice(args...) +} + +// GoSlice - same as text/template's 'slice' function +func (CollFuncs) GoSlice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) { + return texttemplate.GoSlice(item, indexes...) +} + +// Has - +func (CollFuncs) Has(in interface{}, key string) bool { + return coll.Has(in, key) +} + +// Index returns the result of indexing the last argument with the preceding +// index keys. This is similar to the `index` built-in template function, but +// the arguments are ordered differently for pipeline compatibility. Also, this +// function is more strict, and will return an error when the value doesn't +// contain the given key. +func (CollFuncs) Index(args ...interface{}) (interface{}, error) { + if len(args) < 2 { + return nil, fmt.Errorf("wrong number of args: wanted at least 2, got %d", len(args)) + } + + item := args[len(args)-1] + indexes := args[:len(args)-1] + + return coll.Index(item, indexes...) +} + +// Dict - +func (CollFuncs) Dict(in ...interface{}) (map[string]interface{}, error) { + return coll.Dict(in...) +} + +// Keys - +func (CollFuncs) Keys(in ...map[string]interface{}) ([]string, error) { + return coll.Keys(in...) +} + +// Values - +func (CollFuncs) Values(in ...map[string]interface{}) ([]interface{}, error) { + return coll.Values(in...) +} + +// Append - +func (CollFuncs) Append(v interface{}, list interface{}) ([]interface{}, error) { + return coll.Append(v, list) +} + +// Prepend - +func (CollFuncs) Prepend(v interface{}, list interface{}) ([]interface{}, error) { + return coll.Prepend(v, list) +} + +// Uniq - +func (CollFuncs) Uniq(in interface{}) ([]interface{}, error) { + return coll.Uniq(in) +} + +// Reverse - +func (CollFuncs) Reverse(in interface{}) ([]interface{}, error) { + return coll.Reverse(in) +} + +// Merge - +func (CollFuncs) Merge(dst map[string]interface{}, src ...map[string]interface{}) (map[string]interface{}, error) { + return coll.Merge(dst, src...) +} + +// Sort - +func (CollFuncs) Sort(args ...interface{}) ([]interface{}, error) { + var ( + key string + list interface{} + ) + if len(args) == 0 || len(args) > 2 { + return nil, fmt.Errorf("wrong number of args: wanted 1 or 2, got %d", len(args)) + } + if len(args) == 1 { + list = args[0] + } + if len(args) == 2 { + key = conv.ToString(args[0]) + list = args[1] + } + return coll.Sort(key, list) +} + +// JSONPath - +func (CollFuncs) JSONPath(p string, in interface{}) (interface{}, error) { + return coll.JSONPath(p, in) +} + +// JQ - +func (f *CollFuncs) JQ(jqExpr string, in interface{}) (interface{}, error) { + return coll.JQ(f.ctx, jqExpr, in) +} + +// Flatten - +func (CollFuncs) Flatten(args ...interface{}) ([]interface{}, error) { + if len(args) == 0 || len(args) > 2 { + return nil, fmt.Errorf("wrong number of args: wanted 1 or 2, got %d", len(args)) + } + list := args[0] + depth := -1 + if len(args) == 2 { + depth = conv.ToInt(args[0]) + list = args[1] + } + return coll.Flatten(list, depth) +} + +func pickOmitArgs(args ...interface{}) (map[string]interface{}, []string, error) { + if len(args) <= 1 { + return nil, nil, fmt.Errorf("wrong number of args: wanted 2 or more, got %d", len(args)) + } + + m, ok := args[len(args)-1].(map[string]interface{}) + if !ok { + return nil, nil, fmt.Errorf("wrong map type: must be map[string]interface{}, got %T", args[len(args)-1]) + } + + keys := make([]string, len(args)-1) + for i, v := range args[0 : len(args)-1] { + k, ok := v.(string) + if !ok { + return nil, nil, fmt.Errorf("wrong key type: must be string, got %T (%+v)", args[i], args[i]) + } + keys[i] = k + } + return m, keys, nil +} + +// Pick - +func (CollFuncs) Pick(args ...interface{}) (map[string]interface{}, error) { + m, keys, err := pickOmitArgs(args...) + if err != nil { + return nil, err + } + return coll.Pick(m, keys...), nil +} + +// Omit - +func (CollFuncs) Omit(args ...interface{}) (map[string]interface{}, error) { + m, keys, err := pickOmitArgs(args...) + if err != nil { + return nil, err + } + return coll.Omit(m, keys...), nil +} diff --git a/internal/funcs/coll_test.go b/internal/funcs/coll_test.go new file mode 100644 index 00000000..71900a59 --- /dev/null +++ b/internal/funcs/coll_test.go @@ -0,0 +1,179 @@ +package funcs + +import ( + "context" + "reflect" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCreateCollFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateCollFuncs(ctx) + actual := fmap["coll"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*CollFuncs).ctx) + }) + } +} + +func TestFlatten(t *testing.T) { + t.Parallel() + + c := CollFuncs{} + + _, err := c.Flatten() + assert.Error(t, err) + + _, err = c.Flatten(42) + assert.Error(t, err) + + out, err := c.Flatten([]interface{}{1, []interface{}{[]int{2}, 3}}) + require.NoError(t, err) + assert.EqualValues(t, []interface{}{1, 2, 3}, out) + + out, err = c.Flatten(1, []interface{}{1, []interface{}{[]int{2}, 3}}) + require.NoError(t, err) + assert.EqualValues(t, []interface{}{1, []int{2}, 3}, out) +} + +func TestPick(t *testing.T) { + t.Parallel() + + c := &CollFuncs{} + + _, err := c.Pick() + assert.Error(t, err) + + _, err = c.Pick("") + assert.Error(t, err) + + _, err = c.Pick("foo", nil) + assert.Error(t, err) + + _, err = c.Pick("foo", "bar") + assert.Error(t, err) + + _, err = c.Pick(map[string]interface{}{}, "foo", "bar", map[string]interface{}{}) + assert.Error(t, err) + + in := map[string]interface{}{ + "foo": "bar", + "bar": true, + "": "baz", + } + out, err := c.Pick("baz", in) + require.NoError(t, err) + assert.EqualValues(t, map[string]interface{}{}, out) + + expected := map[string]interface{}{ + "foo": "bar", + "bar": true, + } + out, err = c.Pick("foo", "bar", in) + require.NoError(t, err) + assert.EqualValues(t, expected, out) + + expected = map[string]interface{}{ + "": "baz", + } + out, err = c.Pick("", in) + require.NoError(t, err) + assert.EqualValues(t, expected, out) + + out, err = c.Pick("foo", "bar", "", in) + require.NoError(t, err) + assert.EqualValues(t, in, out) +} + +func TestOmit(t *testing.T) { + t.Parallel() + + c := &CollFuncs{} + + _, err := c.Omit() + assert.Error(t, err) + + _, err = c.Omit("") + assert.Error(t, err) + + _, err = c.Omit("foo", nil) + assert.Error(t, err) + + _, err = c.Omit("foo", "bar") + assert.Error(t, err) + + _, err = c.Omit(map[string]interface{}{}, "foo", "bar", map[string]interface{}{}) + assert.Error(t, err) + + in := map[string]interface{}{ + "foo": "bar", + "bar": true, + "": "baz", + } + out, err := c.Omit("baz", in) + require.NoError(t, err) + assert.EqualValues(t, in, out) + + expected := map[string]interface{}{ + "foo": "bar", + "bar": true, + } + out, err = c.Omit("", in) + require.NoError(t, err) + assert.EqualValues(t, expected, out) + + expected = map[string]interface{}{ + "": "baz", + } + out, err = c.Omit("foo", "bar", in) + require.NoError(t, err) + assert.EqualValues(t, expected, out) + + out, err = c.Omit("foo", "bar", "", in) + require.NoError(t, err) + assert.EqualValues(t, map[string]interface{}{}, out) +} + +func TestGoSlice(t *testing.T) { + t.Parallel() + + c := &CollFuncs{} + + in := reflect.ValueOf(nil) + _, err := c.GoSlice(in) + assert.Error(t, err) + + in = reflect.ValueOf(42) + _, err = c.GoSlice(in) + assert.Error(t, err) + + // invalid index type + in = reflect.ValueOf([]interface{}{1}) + _, err = c.GoSlice(in, reflect.ValueOf([]interface{}{[]int{2}})) + assert.Error(t, err) + + // valid slice, no slicing + in = reflect.ValueOf([]int{1}) + out, err := c.GoSlice(in) + require.NoError(t, err) + assert.Equal(t, reflect.TypeOf([]int{}), out.Type()) + assert.EqualValues(t, []int{1}, out.Interface()) + + // valid slice, slicing + in = reflect.ValueOf([]string{"foo", "bar", "baz"}) + out, err = c.GoSlice(in, reflect.ValueOf(1), reflect.ValueOf(3)) + require.NoError(t, err) + assert.Equal(t, reflect.TypeOf([]string{}), out.Type()) + assert.EqualValues(t, []string{"bar", "baz"}, out.Interface()) +} diff --git a/internal/funcs/conv.go b/internal/funcs/conv.go new file mode 100644 index 00000000..3361879d --- /dev/null +++ b/internal/funcs/conv.go @@ -0,0 +1,166 @@ +package funcs + +import ( + "context" + "net/url" + "text/template" + + "github.com/hairyhenderson/gomplate/v4/coll" + "github.com/hairyhenderson/gomplate/v4/conv" + "github.com/hairyhenderson/gomplate/v4/internal/deprecated" +) + +// ConvNS - +// +// Deprecated: don't use +func ConvNS() *ConvFuncs { + return &ConvFuncs{} +} + +// AddConvFuncs - +// +// Deprecated: use [CreateConvFuncs] instead +func AddConvFuncs(f map[string]interface{}) { + for k, v := range CreateConvFuncs(context.Background()) { + f[k] = v + } +} + +// CreateConvFuncs - +func CreateConvFuncs(ctx context.Context) map[string]interface{} { + ns := &ConvFuncs{ctx} + + f := map[string]interface{}{} + f["conv"] = func() interface{} { return ns } + + f["urlParse"] = ns.URL + f["bool"] = ns.Bool + f["join"] = ns.Join + f["default"] = ns.Default + return f +} + +// ConvFuncs - +type ConvFuncs struct { + ctx context.Context +} + +// Bool - +// +// Deprecated: use [ToBool] instead +func (f *ConvFuncs) Bool(s interface{}) bool { + deprecated.WarnDeprecated(f.ctx, "conv.Bool is deprecated - use conv.ToBool instead") + return conv.Bool(conv.ToString(s)) +} + +// ToBool - +func (ConvFuncs) ToBool(in interface{}) bool { + return conv.ToBool(in) +} + +// ToBools - +func (ConvFuncs) ToBools(in ...interface{}) []bool { + return conv.ToBools(in...) +} + +// Slice - +// +// Deprecated: use [CollFuncs.Slice] instead +func (f *ConvFuncs) Slice(args ...interface{}) []interface{} { + deprecated.WarnDeprecated(f.ctx, "conv.Slice is deprecated - use coll.Slice instead") + return coll.Slice(args...) +} + +// Join - +func (ConvFuncs) Join(in interface{}, sep string) (string, error) { + return conv.Join(in, sep) +} + +// Has - +// +// Deprecated: use [CollFuncs.Has] instead +func (f *ConvFuncs) Has(in interface{}, key string) bool { + deprecated.WarnDeprecated(f.ctx, "conv.Has is deprecated - use coll.Has instead") + return coll.Has(in, key) +} + +// ParseInt - +func (ConvFuncs) ParseInt(s interface{}, base, bitSize int) int64 { + return conv.MustParseInt(conv.ToString(s), base, bitSize) +} + +// ParseFloat - +func (ConvFuncs) ParseFloat(s interface{}, bitSize int) float64 { + return conv.MustParseFloat(conv.ToString(s), bitSize) +} + +// ParseUint - +func (ConvFuncs) ParseUint(s interface{}, base, bitSize int) uint64 { + return conv.MustParseUint(conv.ToString(s), base, bitSize) +} + +// Atoi - +func (ConvFuncs) Atoi(s interface{}) int { + return conv.MustAtoi(conv.ToString(s)) +} + +// URL - +func (ConvFuncs) URL(s interface{}) (*url.URL, error) { + return url.Parse(conv.ToString(s)) +} + +// ToInt64 - +func (ConvFuncs) ToInt64(in interface{}) int64 { + return conv.ToInt64(in) +} + +// ToInt - +func (ConvFuncs) ToInt(in interface{}) int { + return conv.ToInt(in) +} + +// ToInt64s - +func (ConvFuncs) ToInt64s(in ...interface{}) []int64 { + return conv.ToInt64s(in...) +} + +// ToInts - +func (ConvFuncs) ToInts(in ...interface{}) []int { + return conv.ToInts(in...) +} + +// ToFloat64 - +func (ConvFuncs) ToFloat64(in interface{}) float64 { + return conv.ToFloat64(in) +} + +// ToFloat64s - +func (ConvFuncs) ToFloat64s(in ...interface{}) []float64 { + return conv.ToFloat64s(in...) +} + +// ToString - +func (ConvFuncs) ToString(in interface{}) string { + return conv.ToString(in) +} + +// ToStrings - +func (ConvFuncs) ToStrings(in ...interface{}) []string { + return conv.ToStrings(in...) +} + +// Default - +func (ConvFuncs) Default(def, in interface{}) interface{} { + if truth, ok := template.IsTrue(in); truth && ok { + return in + } + return def +} + +// Dict - +// +// Deprecated: use [CollFuncs.Dict] instead +func (f *ConvFuncs) Dict(in ...interface{}) (map[string]interface{}, error) { + deprecated.WarnDeprecated(f.ctx, "conv.Dict is deprecated - use coll.Dict instead") + return coll.Dict(in...) +} diff --git a/internal/funcs/conv_test.go b/internal/funcs/conv_test.go new file mode 100644 index 00000000..b20f013c --- /dev/null +++ b/internal/funcs/conv_test.go @@ -0,0 +1,63 @@ +package funcs + +import ( + "context" + "fmt" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCreateConvFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateConvFuncs(ctx) + actual := fmap["conv"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*ConvFuncs).ctx) + }) + } +} + +func TestDefault(t *testing.T) { + t.Parallel() + + s := struct{}{} + c := &ConvFuncs{} + def := "DEFAULT" + data := []struct { + val interface{} + empty bool + }{ + {0, true}, + {1, false}, + {nil, true}, + {"", true}, + {"foo", false}, + {[]string{}, true}, + {[]string{"foo"}, false}, + {[]string{""}, false}, + {c, false}, + {s, false}, + } + + for _, d := range data { + d := d + t.Run(fmt.Sprintf("%T/%#v empty==%v", d.val, d.val, d.empty), func(t *testing.T) { + t.Parallel() + + if d.empty { + assert.Equal(t, def, c.Default(def, d.val)) + } else { + assert.Equal(t, d.val, c.Default(def, d.val)) + } + }) + } +} diff --git a/internal/funcs/crypto.go b/internal/funcs/crypto.go new file mode 100644 index 00000000..e3628fad --- /dev/null +++ b/internal/funcs/crypto.go @@ -0,0 +1,399 @@ +package funcs + +import ( + "context" + gcrypto "crypto" + "crypto/elliptic" + "crypto/sha1" //nolint: gosec + "crypto/sha256" + "crypto/sha512" + "encoding/base64" + "encoding/hex" + "fmt" + "strings" + "unicode/utf8" + + "golang.org/x/crypto/bcrypt" + + "github.com/hairyhenderson/gomplate/v4/conv" + "github.com/hairyhenderson/gomplate/v4/crypto" +) + +// CryptoNS - the crypto namespace +// +// Deprecated: don't use +func CryptoNS() *CryptoFuncs { + return &CryptoFuncs{} +} + +// AddCryptoFuncs - +// +// Deprecated: use [CreateCryptoFuncs] instead +func AddCryptoFuncs(f map[string]interface{}) { + for k, v := range CreateCryptoFuncs(context.Background()) { + f[k] = v + } +} + +// CreateCryptoFuncs - +func CreateCryptoFuncs(ctx context.Context) map[string]interface{} { + f := map[string]interface{}{} + + ns := &CryptoFuncs{ctx} + + f["crypto"] = func() interface{} { return ns } + return f +} + +// CryptoFuncs - +type CryptoFuncs struct { + ctx context.Context +} + +// PBKDF2 - Run the Password-Based Key Derivation Function #2 as defined in +// RFC 2898 (PKCS #5 v2.0). This function outputs the binary result in hex +// format. +func (CryptoFuncs) PBKDF2(password, salt, iter, keylen interface{}, hashFunc ...string) (k string, err error) { + var h gcrypto.Hash + if len(hashFunc) == 0 { + h = gcrypto.SHA1 + } else { + h, err = crypto.StrToHash(hashFunc[0]) + if err != nil { + return "", err + } + } + pw := toBytes(password) + s := toBytes(salt) + i := conv.ToInt(iter) + kl := conv.ToInt(keylen) + + dk, err := crypto.PBKDF2(pw, s, i, kl, h) + return fmt.Sprintf("%02x", dk), err +} + +// WPAPSK - Convert an ASCII passphrase to WPA PSK for a given SSID +func (f CryptoFuncs) WPAPSK(ssid, password interface{}) (string, error) { + return f.PBKDF2(password, ssid, 4096, 32) +} + +// SHA1 - Note: SHA-1 is cryptographically broken and should not be used for secure applications. +func (f CryptoFuncs) SHA1(input interface{}) string { + out, _ := f.SHA1Bytes(input) + return fmt.Sprintf("%02x", out) +} + +// SHA224 - +func (f CryptoFuncs) SHA224(input interface{}) string { + out, _ := f.SHA224Bytes(input) + return fmt.Sprintf("%02x", out) +} + +// SHA256 - +func (f CryptoFuncs) SHA256(input interface{}) string { + out, _ := f.SHA256Bytes(input) + return fmt.Sprintf("%02x", out) +} + +// SHA384 - +func (f CryptoFuncs) SHA384(input interface{}) string { + out, _ := f.SHA384Bytes(input) + return fmt.Sprintf("%02x", out) +} + +// SHA512 - +func (f CryptoFuncs) SHA512(input interface{}) string { + out, _ := f.SHA512Bytes(input) + return fmt.Sprintf("%02x", out) +} + +// SHA512_224 - +// +//nolint:revive,stylecheck +func (f CryptoFuncs) SHA512_224(input interface{}) string { + out, _ := f.SHA512_224Bytes(input) + return fmt.Sprintf("%02x", out) +} + +// SHA512_256 - +// +//nolint:revive,stylecheck +func (f CryptoFuncs) SHA512_256(input interface{}) string { + out, _ := f.SHA512_256Bytes(input) + return fmt.Sprintf("%02x", out) +} + +// SHA1 - Note: SHA-1 is cryptographically broken and should not be used for secure applications. +func (CryptoFuncs) SHA1Bytes(input interface{}) ([]byte, error) { + //nolint:gosec + b := sha1.Sum(toBytes(input)) + out := make([]byte, len(b)) + copy(out, b[:]) + return out, nil +} + +// SHA224 - +func (CryptoFuncs) SHA224Bytes(input interface{}) ([]byte, error) { + b := sha256.Sum224(toBytes(input)) + out := make([]byte, len(b)) + copy(out, b[:]) + return out, nil +} + +// SHA256 - +func (CryptoFuncs) SHA256Bytes(input interface{}) ([]byte, error) { + b := sha256.Sum256(toBytes(input)) + out := make([]byte, len(b)) + copy(out, b[:]) + return out, nil +} + +// SHA384 - +func (CryptoFuncs) SHA384Bytes(input interface{}) ([]byte, error) { + b := sha512.Sum384(toBytes(input)) + out := make([]byte, len(b)) + copy(out, b[:]) + return out, nil +} + +// SHA512 - +func (CryptoFuncs) SHA512Bytes(input interface{}) ([]byte, error) { + b := sha512.Sum512(toBytes(input)) + out := make([]byte, len(b)) + copy(out, b[:]) + return out, nil +} + +// SHA512_224 - +func (CryptoFuncs) SHA512_224Bytes(input interface{}) ([]byte, error) { + b := sha512.Sum512_224(toBytes(input)) + out := make([]byte, len(b)) + copy(out, b[:]) + return out, nil +} + +// SHA512_256 - +func (CryptoFuncs) SHA512_256Bytes(input interface{}) ([]byte, error) { + b := sha512.Sum512_256(toBytes(input)) + out := make([]byte, len(b)) + copy(out, b[:]) + return out, nil +} + +// Bcrypt - +func (CryptoFuncs) Bcrypt(args ...interface{}) (string, error) { + input := "" + cost := bcrypt.DefaultCost + if len(args) == 0 { + return "", fmt.Errorf("bcrypt requires at least an 'input' value") + } + if len(args) == 1 { + input = conv.ToString(args[0]) + } + if len(args) == 2 { + cost = conv.ToInt(args[0]) + input = conv.ToString(args[1]) + } + hash, err := bcrypt.GenerateFromPassword([]byte(input), cost) + return string(hash), err +} + +// RSAEncrypt - +// Experimental! +func (f *CryptoFuncs) RSAEncrypt(key string, in interface{}) ([]byte, error) { + if err := checkExperimental(f.ctx); err != nil { + return nil, err + } + msg := toBytes(in) + return crypto.RSAEncrypt(key, msg) +} + +// RSADecrypt - +// Experimental! +func (f *CryptoFuncs) RSADecrypt(key string, in []byte) (string, error) { + if err := checkExperimental(f.ctx); err != nil { + return "", err + } + out, err := crypto.RSADecrypt(key, in) + return string(out), err +} + +// RSADecryptBytes - +// Experimental! +func (f *CryptoFuncs) RSADecryptBytes(key string, in []byte) ([]byte, error) { + if err := checkExperimental(f.ctx); err != nil { + return nil, err + } + out, err := crypto.RSADecrypt(key, in) + return out, err +} + +// RSAGenerateKey - +// Experimental! +func (f *CryptoFuncs) RSAGenerateKey(args ...interface{}) (string, error) { + if err := checkExperimental(f.ctx); err != nil { + return "", err + } + bits := 4096 + if len(args) == 1 { + bits = conv.ToInt(args[0]) + } else if len(args) > 1 { + return "", fmt.Errorf("wrong number of args: want 0 or 1, got %d", len(args)) + } + out, err := crypto.RSAGenerateKey(bits) + return string(out), err +} + +// RSADerivePublicKey - +// Experimental! +func (f *CryptoFuncs) RSADerivePublicKey(privateKey string) (string, error) { + if err := checkExperimental(f.ctx); err != nil { + return "", err + } + out, err := crypto.RSADerivePublicKey([]byte(privateKey)) + return string(out), err +} + +// ECDSAGenerateKey - +// Experimental! +func (f *CryptoFuncs) ECDSAGenerateKey(args ...interface{}) (string, error) { + if err := checkExperimental(f.ctx); err != nil { + return "", err + } + + curve := elliptic.P256() + if len(args) == 1 { + c := conv.ToString(args[0]) + c = strings.ToUpper(c) + c = strings.ReplaceAll(c, "-", "") + var ok bool + curve, ok = crypto.Curves(c) + if !ok { + return "", fmt.Errorf("unknown curve: %s", c) + } + } else if len(args) > 1 { + return "", fmt.Errorf("wrong number of args: want 0 or 1, got %d", len(args)) + } + + out, err := crypto.ECDSAGenerateKey(curve) + return string(out), err +} + +// ECDSADerivePublicKey - +// Experimental! +func (f *CryptoFuncs) ECDSADerivePublicKey(privateKey string) (string, error) { + if err := checkExperimental(f.ctx); err != nil { + return "", err + } + + out, err := crypto.ECDSADerivePublicKey([]byte(privateKey)) + return string(out), err +} + +// Ed25519GenerateKey - +// Experimental! +func (f *CryptoFuncs) Ed25519GenerateKey() (string, error) { + if err := checkExperimental(f.ctx); err != nil { + return "", err + } + out, err := crypto.Ed25519GenerateKey() + return string(out), err +} + +// Ed25519GenerateKeyFromSeed - +// Experimental! +func (f *CryptoFuncs) Ed25519GenerateKeyFromSeed(encoding, seed string) (string, error) { + if err := checkExperimental(f.ctx); err != nil { + return "", err + } + if !utf8.ValidString(seed) { + return "", fmt.Errorf("given seed is not valid UTF-8") // Don't print out seed (private). + } + var seedB []byte + var err error + switch encoding { + case "base64": + seedB, err = base64.StdEncoding.DecodeString(seed) + case "hex": + seedB, err = hex.DecodeString(seed) + default: + return "", fmt.Errorf("invalid encoding given: %s - only 'hex' or 'base64' are valid options", encoding) + } + if err != nil { + return "", fmt.Errorf("could not decode given seed: %w", err) + } + out, err := crypto.Ed25519GenerateKeyFromSeed(seedB) + return string(out), err +} + +// Ed25519DerivePublicKey - +// Experimental! +func (f *CryptoFuncs) Ed25519DerivePublicKey(privateKey string) (string, error) { + if err := checkExperimental(f.ctx); err != nil { + return "", err + } + out, err := crypto.Ed25519DerivePublicKey([]byte(privateKey)) + return string(out), err +} + +// EncryptAES - +// Experimental! +func (f *CryptoFuncs) EncryptAES(key string, args ...interface{}) ([]byte, error) { + if err := checkExperimental(f.ctx); err != nil { + return nil, err + } + + k, msg, err := parseAESArgs(key, args...) + if err != nil { + return nil, err + } + + return crypto.EncryptAESCBC(k, msg) +} + +// DecryptAES - +// Experimental! +func (f *CryptoFuncs) DecryptAES(key string, args ...interface{}) (string, error) { + if err := checkExperimental(f.ctx); err != nil { + return "", err + } + + out, err := f.DecryptAESBytes(key, args...) + return conv.ToString(out), err +} + +// DecryptAESBytes - +// Experimental! +func (f *CryptoFuncs) DecryptAESBytes(key string, args ...interface{}) ([]byte, error) { + if err := checkExperimental(f.ctx); err != nil { + return nil, err + } + + k, msg, err := parseAESArgs(key, args...) + if err != nil { + return nil, err + } + + return crypto.DecryptAESCBC(k, msg) +} + +func parseAESArgs(key string, args ...interface{}) ([]byte, []byte, error) { + keyBits := 256 // default to AES-256-CBC + + var msg []byte + + switch len(args) { + case 1: + msg = toBytes(args[0]) + case 2: + keyBits = conv.ToInt(args[0]) + msg = toBytes(args[1]) + default: + return nil, nil, fmt.Errorf("wrong number of args: want 2 or 3, got %d", len(args)) + } + + k := make([]byte, keyBits/8) + copy(k, []byte(key)) + + return k, msg, nil +} diff --git a/internal/funcs/crypto_test.go b/internal/funcs/crypto_test.go new file mode 100644 index 00000000..4d1b5660 --- /dev/null +++ b/internal/funcs/crypto_test.go @@ -0,0 +1,278 @@ +package funcs + +import ( + "context" + "encoding/base64" + "strconv" + "strings" + "testing" + + "github.com/hairyhenderson/gomplate/v4/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCreateCryptoFuncs(t *testing.T) { + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateCryptoFuncs(ctx) + actual := fmap["crypto"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*CryptoFuncs).ctx) + }) + } +} + +func testCryptoNS() *CryptoFuncs { + return &CryptoFuncs{ctx: config.SetExperimental(context.Background())} +} + +func TestPBKDF2(t *testing.T) { + t.Parallel() + + c := testCryptoNS() + dk, err := c.PBKDF2("password", []byte("IEEE"), "4096", 32) + assert.Equal(t, "f42c6fc52df0ebef9ebb4b90b38a5f902e83fe1b135a70e23aed762e9710a12e", dk) + require.NoError(t, err) + + dk, err = c.PBKDF2([]byte("password"), "IEEE", 4096, "64", "SHA-512") + assert.Equal(t, "c16f4cb6d03e23614399dee5e7f676fb1da0eb9471b6a74a6c5bc934c6ec7d2ab7028fbb1000b1beb97f17646045d8144792352f6676d13b20a4c03754903d7e", dk) + require.NoError(t, err) + + _, err = c.PBKDF2(nil, nil, nil, nil, "bogus") + assert.Error(t, err) +} + +func TestWPAPSK(t *testing.T) { + t.Parallel() + + c := testCryptoNS() + dk, err := c.WPAPSK("password", "MySSID") + assert.Equal(t, "3a98def84b11644a17ebcc9b17955d2360ce8b8a85b8a78413fc551d722a84e7", dk) + require.NoError(t, err) +} + +func TestSHA(t *testing.T) { + t.Parallel() + + in := "abc" + sha1 := "a9993e364706816aba3e25717850c26c9cd0d89d" + sha224 := "23097d223405d8228642a477bda255b32aadbce4bda0b3f7e36c9da7" + sha256 := "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad" + sha384 := "cb00753f45a35e8bb5a03d699ac65007272c32ab0eded1631a8b605a43ff5bed8086072ba1e7cc2358baeca134c825a7" + sha512 := "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f" + sha512_224 := "4634270f707b6a54daae7530460842e20e37ed265ceee9a43e8924aa" + sha512_256 := "53048e2681941ef99b2e29b76b4c7dabe4c2d0c634fc6d46e0e2f13107e7af23" + c := testCryptoNS() + assert.Equal(t, sha1, c.SHA1(in)) + assert.Equal(t, sha224, c.SHA224(in)) + assert.Equal(t, sha256, c.SHA256(in)) + assert.Equal(t, sha384, c.SHA384(in)) + assert.Equal(t, sha512, c.SHA512(in)) + assert.Equal(t, sha512_224, c.SHA512_224(in)) + assert.Equal(t, sha512_256, c.SHA512_256(in)) +} + +func TestBcrypt(t *testing.T) { + t.Parallel() + + if testing.Short() { + t.Skip("skipping slow test") + } + + in := "foo" + c := testCryptoNS() + + t.Run("no arg default", func(t *testing.T) { + t.Parallel() + + actual, err := c.Bcrypt(in) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(actual, "$2a$10$")) + }) + + t.Run("cost less than min", func(t *testing.T) { + t.Parallel() + + actual, err := c.Bcrypt(0, in) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(actual, "$2a$10$")) + }) + + t.Run("cost equal to min", func(t *testing.T) { + t.Parallel() + + actual, err := c.Bcrypt(4, in) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(actual, "$2a$04$")) + }) + + t.Run("no args errors", func(t *testing.T) { + t.Parallel() + + _, err := c.Bcrypt() + assert.Error(t, err) + }) +} + +func TestRSAGenerateKey(t *testing.T) { + t.Parallel() + + c := testCryptoNS() + _, err := c.RSAGenerateKey(0) + assert.Error(t, err) + + _, err = c.RSAGenerateKey(0, "foo", true) + assert.Error(t, err) + + key, err := c.RSAGenerateKey(2048) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(key, + "-----BEGIN RSA PRIVATE KEY-----")) + assert.True(t, strings.HasSuffix(key, + "-----END RSA PRIVATE KEY-----\n")) +} + +func TestECDSAGenerateKey(t *testing.T) { + c := testCryptoNS() + _, err := c.ECDSAGenerateKey("") + assert.Error(t, err) + + _, err = c.ECDSAGenerateKey(0, "P-999", true) + assert.Error(t, err) + + key, err := c.ECDSAGenerateKey("P-256") + require.NoError(t, err) + assert.True(t, strings.HasPrefix(key, + "-----BEGIN EC PRIVATE KEY-----")) + assert.True(t, strings.HasSuffix(key, + "-----END EC PRIVATE KEY-----\n")) +} + +func TestECDSADerivePublicKey(t *testing.T) { + c := testCryptoNS() + + _, err := c.ECDSADerivePublicKey("") + assert.Error(t, err) + + key, _ := c.ECDSAGenerateKey("P-256") + pub, err := c.ECDSADerivePublicKey(key) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(pub, + "-----BEGIN PUBLIC KEY-----")) + assert.True(t, strings.HasSuffix(pub, + "-----END PUBLIC KEY-----\n")) +} + +func TestEd25519GenerateKey(t *testing.T) { + c := testCryptoNS() + key, err := c.Ed25519GenerateKey() + require.NoError(t, err) + + assert.True(t, strings.HasPrefix(key, + "-----BEGIN PRIVATE KEY-----")) + assert.True(t, strings.HasSuffix(key, + "-----END PRIVATE KEY-----\n")) +} + +func TestEd25519GenerateKeyFromSeed(t *testing.T) { + c := testCryptoNS() + enc := "" + seed := "" + _, err := c.Ed25519GenerateKeyFromSeed(enc, seed) + assert.Error(t, err) + + enc = "base64" + seed = "0000000000000000000000000000000" // 31 bytes, instead of wanted 32. + _, err = c.Ed25519GenerateKeyFromSeed(enc, seed) + assert.Error(t, err) + + seed += "0" // 32 bytes. + b64seed := base64.StdEncoding.EncodeToString([]byte(seed)) + key, err := c.Ed25519GenerateKeyFromSeed(enc, b64seed) + require.NoError(t, err) + + assert.True(t, strings.HasPrefix(key, + "-----BEGIN PRIVATE KEY-----")) + assert.True(t, strings.HasSuffix(key, + "-----END PRIVATE KEY-----\n")) +} + +func TestEd25519DerivePublicKey(t *testing.T) { + c := testCryptoNS() + + _, err := c.Ed25519DerivePublicKey("") + assert.Error(t, err) + + key, _ := c.Ed25519GenerateKey() + pub, err := c.Ed25519DerivePublicKey(key) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(pub, + "-----BEGIN PUBLIC KEY-----")) + assert.True(t, strings.HasSuffix(pub, + "-----END PUBLIC KEY-----\n")) +} + +func TestRSACrypt(t *testing.T) { + t.Parallel() + + if testing.Short() { + t.Skip("skipping slow test") + } + + c := testCryptoNS() + key, err := c.RSAGenerateKey() + require.NoError(t, err) + pub, err := c.RSADerivePublicKey(key) + require.NoError(t, err) + + in := "hello world" + enc, err := c.RSAEncrypt(pub, in) + require.NoError(t, err) + + dec, err := c.RSADecrypt(key, enc) + require.NoError(t, err) + assert.Equal(t, in, dec) + + b, err := c.RSADecryptBytes(key, enc) + require.NoError(t, err) + assert.Equal(t, dec, string(b)) +} + +func TestAESCrypt(t *testing.T) { + c := testCryptoNS() + key := "0123456789012345" + in := "hello world" + + _, err := c.EncryptAES(key, 1, 2, 3, 4) + assert.Error(t, err) + + _, err = c.DecryptAES(key, 1, 2, 3, 4) + assert.Error(t, err) + + enc, err := c.EncryptAES(key, in) + require.NoError(t, err) + + dec, err := c.DecryptAES(key, enc) + require.NoError(t, err) + assert.Equal(t, in, dec) + + b, err := c.DecryptAESBytes(key, enc) + require.NoError(t, err) + assert.Equal(t, dec, string(b)) + + enc, err = c.EncryptAES(key, 128, in) + require.NoError(t, err) + + dec, err = c.DecryptAES(key, 128, enc) + require.NoError(t, err) + assert.Equal(t, in, dec) + + b, err = c.DecryptAESBytes(key, 128, enc) + require.NoError(t, err) + assert.Equal(t, dec, string(b)) +} diff --git a/internal/funcs/data.go b/internal/funcs/data.go new file mode 100644 index 00000000..32f5b2da --- /dev/null +++ b/internal/funcs/data.go @@ -0,0 +1,140 @@ +package funcs + +import ( + "context" + + "github.com/hairyhenderson/gomplate/v4/conv" + "github.com/hairyhenderson/gomplate/v4/data" + "github.com/hairyhenderson/gomplate/v4/internal/parsers" +) + +// DataNS - +// +// Deprecated: don't use +func DataNS() *DataFuncs { + return &DataFuncs{} +} + +// AddDataFuncs - +// +// Deprecated: use [CreateDataFuncs] instead +func AddDataFuncs(f map[string]interface{}, d *data.Data) { + for k, v := range CreateDataFuncs(context.Background(), d) { + f[k] = v + } +} + +// CreateDataFuncs - +// +//nolint:staticcheck +func CreateDataFuncs(ctx context.Context, d *data.Data) map[string]interface{} { + f := map[string]interface{}{} + f["datasource"] = d.Datasource + f["ds"] = d.Datasource + f["datasourceExists"] = d.DatasourceExists + f["datasourceReachable"] = d.DatasourceReachable + f["defineDatasource"] = d.DefineDatasource + f["include"] = d.Include + f["listDatasources"] = d.ListDatasources + + ns := &DataFuncs{ctx} + + f["data"] = func() interface{} { return ns } + + f["json"] = ns.JSON + f["jsonArray"] = ns.JSONArray + f["yaml"] = ns.YAML + f["yamlArray"] = ns.YAMLArray + f["toml"] = ns.TOML + f["csv"] = ns.CSV + f["csvByRow"] = ns.CSVByRow + f["csvByColumn"] = ns.CSVByColumn + f["cue"] = ns.CUE + f["toJSON"] = ns.ToJSON + f["toJSONPretty"] = ns.ToJSONPretty + f["toYAML"] = ns.ToYAML + f["toTOML"] = ns.ToTOML + f["toCSV"] = ns.ToCSV + f["toCUE"] = ns.ToCUE + return f +} + +// DataFuncs - +type DataFuncs struct { + ctx context.Context +} + +// JSON - +func (f *DataFuncs) JSON(in interface{}) (map[string]interface{}, error) { + return parsers.JSON(conv.ToString(in)) +} + +// JSONArray - +func (f *DataFuncs) JSONArray(in interface{}) ([]interface{}, error) { + return parsers.JSONArray(conv.ToString(in)) +} + +// YAML - +func (f *DataFuncs) YAML(in interface{}) (map[string]interface{}, error) { + return parsers.YAML(conv.ToString(in)) +} + +// YAMLArray - +func (f *DataFuncs) YAMLArray(in interface{}) ([]interface{}, error) { + return parsers.YAMLArray(conv.ToString(in)) +} + +// TOML - +func (f *DataFuncs) TOML(in interface{}) (interface{}, error) { + return parsers.TOML(conv.ToString(in)) +} + +// CSV - +func (f *DataFuncs) CSV(args ...string) ([][]string, error) { + return parsers.CSV(args...) +} + +// CSVByRow - +func (f *DataFuncs) CSVByRow(args ...string) (rows []map[string]string, err error) { + return parsers.CSVByRow(args...) +} + +// CSVByColumn - +func (f *DataFuncs) CSVByColumn(args ...string) (cols map[string][]string, err error) { + return parsers.CSVByColumn(args...) +} + +// CUE - +func (f *DataFuncs) CUE(in interface{}) (interface{}, error) { + return parsers.CUE(conv.ToString(in)) +} + +// ToCSV - +func (f *DataFuncs) ToCSV(args ...interface{}) (string, error) { + return parsers.ToCSV(args...) +} + +// ToCUE - +func (f *DataFuncs) ToCUE(in interface{}) (string, error) { + return parsers.ToCUE(in) +} + +// ToJSON - +func (f *DataFuncs) ToJSON(in interface{}) (string, error) { + return parsers.ToJSON(in) +} + +// ToJSONPretty - +func (f *DataFuncs) ToJSONPretty(indent string, in interface{}) (string, error) { + return parsers.ToJSONPretty(indent, in) +} + +// ToYAML - +func (f *DataFuncs) ToYAML(in interface{}) (string, error) { + return parsers.ToYAML(in) +} + +// ToTOML - +func (f *DataFuncs) ToTOML(in interface{}) (string, error) { + return parsers.ToTOML(in) +} diff --git a/internal/funcs/data_test.go b/internal/funcs/data_test.go new file mode 100644 index 00000000..c536beb3 --- /dev/null +++ b/internal/funcs/data_test.go @@ -0,0 +1,26 @@ +package funcs + +import ( + "context" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCreateDataFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateDataFuncs(ctx, nil) + actual := fmap["data"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*DataFuncs).ctx) + }) + } +} diff --git a/internal/funcs/env.go b/internal/funcs/env.go new file mode 100644 index 00000000..0df63b3e --- /dev/null +++ b/internal/funcs/env.go @@ -0,0 +1,49 @@ +package funcs + +import ( + "context" + + "github.com/hairyhenderson/gomplate/v4/conv" + "github.com/hairyhenderson/gomplate/v4/env" +) + +// EnvNS - the Env namespace +// +// Deprecated: don't use +func EnvNS() *EnvFuncs { + return &EnvFuncs{} +} + +// AddEnvFuncs - +// +// Deprecated: use [CreateEnvFuncs] instead +func AddEnvFuncs(f map[string]interface{}) { + for k, v := range CreateEnvFuncs(context.Background()) { + f[k] = v + } +} + +// CreateEnvFuncs - +func CreateEnvFuncs(ctx context.Context) map[string]interface{} { + ns := &EnvFuncs{ctx} + + return map[string]interface{}{ + "env": func() interface{} { return ns }, + "getenv": ns.Getenv, + } +} + +// EnvFuncs - +type EnvFuncs struct { + ctx context.Context +} + +// Getenv - +func (EnvFuncs) Getenv(key interface{}, def ...string) string { + return env.Getenv(conv.ToString(key), def...) +} + +// ExpandEnv - +func (EnvFuncs) ExpandEnv(s interface{}) string { + return env.ExpandEnv(conv.ToString(s)) +} diff --git a/internal/funcs/env_test.go b/internal/funcs/env_test.go new file mode 100644 index 00000000..a6ee43ac --- /dev/null +++ b/internal/funcs/env_test.go @@ -0,0 +1,37 @@ +package funcs + +import ( + "context" + "os" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCreateEnvFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateEnvFuncs(ctx) + actual := fmap["env"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*EnvFuncs).ctx) + }) + } +} + +func TestEnvGetenv(t *testing.T) { + t.Parallel() + + ef := &EnvFuncs{} + expected := os.Getenv("USER") + assert.Equal(t, expected, ef.Getenv("USER")) + + assert.Equal(t, "foo", ef.Getenv("bogusenvvar", "foo")) +} diff --git a/internal/funcs/file.go b/internal/funcs/file.go new file mode 100644 index 00000000..32af47cf --- /dev/null +++ b/internal/funcs/file.go @@ -0,0 +1,128 @@ +package funcs + +import ( + "context" + "io/fs" + "path/filepath" + + osfs "github.com/hack-pad/hackpadfs/os" + "github.com/hairyhenderson/gomplate/v4/conv" + "github.com/hairyhenderson/gomplate/v4/internal/datafs" + "github.com/hairyhenderson/gomplate/v4/internal/iohelpers" +) + +// FileNS - the File namespace +// +// Deprecated: don't use +func FileNS() *FileFuncs { + return &FileFuncs{} +} + +// AddFileFuncs - +// +// Deprecated: use [CreateFileFuncs] instead +func AddFileFuncs(f map[string]interface{}) { + for k, v := range CreateFileFuncs(context.Background()) { + f[k] = v + } +} + +// CreateFileFuncs - +func CreateFileFuncs(ctx context.Context) map[string]interface{} { + fsys, err := datafs.FSysForPath(ctx, "/") + if err != nil { + fsys = datafs.WrapWdFS(osfs.NewFS()) + } + + ns := &FileFuncs{ + ctx: ctx, + fs: fsys, + } + + return map[string]interface{}{ + "file": func() interface{} { return ns }, + } +} + +// FileFuncs - +type FileFuncs struct { + ctx context.Context + fs fs.FS +} + +// Read - +func (f *FileFuncs) Read(path interface{}) (string, error) { + b, err := fs.ReadFile(f.fs, conv.ToString(path)) + return string(b), err +} + +// Stat - +func (f *FileFuncs) Stat(path interface{}) (fs.FileInfo, error) { + return fs.Stat(f.fs, conv.ToString(path)) +} + +// Exists - +func (f *FileFuncs) Exists(path interface{}) bool { + _, err := f.Stat(conv.ToString(path)) + return err == nil +} + +// IsDir - +func (f *FileFuncs) IsDir(path interface{}) bool { + i, err := f.Stat(conv.ToString(path)) + return err == nil && i.IsDir() +} + +// ReadDir - +func (f *FileFuncs) ReadDir(path interface{}) ([]string, error) { + des, err := fs.ReadDir(f.fs, conv.ToString(path)) + if err != nil { + return nil, err + } + + names := make([]string, len(des)) + for i, de := range des { + names[i] = de.Name() + } + + return names, nil +} + +// Walk - +func (f *FileFuncs) Walk(path interface{}) ([]string, error) { + files := make([]string, 0) + err := fs.WalkDir(f.fs, conv.ToString(path), func(subpath string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + // fs.WalkDir always uses slash-separated paths, even on Windows. We + // need to convert them to the OS-specific separator as that was the + // previous behavior. + subpath = filepath.FromSlash(subpath) + + files = append(files, subpath) + return nil + }) + return files, err +} + +// Write - +func (f *FileFuncs) Write(path interface{}, data interface{}) (s string, err error) { + type byteser interface{ Bytes() []byte } + + var content []byte + fname := conv.ToString(path) + + if b, ok := data.([]byte); ok { + content = b + } else if b, ok := data.(byteser); ok { + content = b.Bytes() + } else { + content = []byte(conv.ToString(data)) + } + + err = iohelpers.WriteFile(f.fs, fname, content) + + return "", err +} diff --git a/internal/funcs/file_test.go b/internal/funcs/file_test.go new file mode 100644 index 00000000..d9fe5ccc --- /dev/null +++ b/internal/funcs/file_test.go @@ -0,0 +1,177 @@ +package funcs + +import ( + "bytes" + "context" + "io/fs" + "os" + "path/filepath" + "strconv" + "testing" + "testing/fstest" + + "github.com/hack-pad/hackpadfs" + osfs "github.com/hack-pad/hackpadfs/os" + "github.com/hairyhenderson/gomplate/v4/internal/datafs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + tfs "gotest.tools/v3/fs" +) + +func TestCreateFileFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateFileFuncs(ctx) + actual := fmap["file"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*FileFuncs).ctx) + }) + } +} + +func TestFileExists(t *testing.T) { + t.Parallel() + + fsys := fstest.MapFS{ + "tmp": &fstest.MapFile{Mode: fs.ModeDir | 0o777}, + "tmp/foo": &fstest.MapFile{Data: []byte("foo")}, + } + ff := &FileFuncs{fs: datafs.WrapWdFS(fsys)} + + assert.True(t, ff.Exists("/tmp/foo")) + assert.False(t, ff.Exists("/tmp/bar")) +} + +func TestFileIsDir(t *testing.T) { + t.Parallel() + + fsys := fstest.MapFS{ + "tmp": &fstest.MapFile{Mode: fs.ModeDir | 0o777}, + "tmp/foo": &fstest.MapFile{Data: []byte("foo")}, + } + + ff := &FileFuncs{fs: datafs.WrapWdFS(fsys)} + + assert.True(t, ff.IsDir("/tmp")) + assert.False(t, ff.IsDir("/tmp/foo")) +} + +func TestFileWalk(t *testing.T) { + t.Parallel() + + fsys := fstest.MapFS{ + "tmp": &fstest.MapFile{Mode: fs.ModeDir | 0o777}, + "tmp/bar": &fstest.MapFile{Mode: fs.ModeDir | 0o777}, + "tmp/bar/baz": &fstest.MapFile{Mode: fs.ModeDir | 0o777}, + "tmp/bar/baz/foo": &fstest.MapFile{Data: []byte("foo")}, + } + + ff := &FileFuncs{fs: datafs.WrapWdFS(fsys)} + + expectedLists := [][]string{{"tmp"}, {"tmp", "bar"}, {"tmp", "bar", "baz"}, {"tmp", "bar", "baz", "foo"}} + expectedPaths := make([]string, 0) + for _, path := range expectedLists { + expectedPaths = append(expectedPaths, string(filepath.Separator)+filepath.Join(path...)) + } + + actualPaths, err := ff.Walk(string(filepath.Separator) + "tmp") + + require.NoError(t, err) + assert.Equal(t, expectedPaths, actualPaths) +} + +func TestReadDir(t *testing.T) { + fsys := fs.FS(fstest.MapFS{ + "tmp": &fstest.MapFile{Mode: fs.ModeDir | 0o777}, + "tmp/foo": &fstest.MapFile{Data: []byte("foo")}, + "tmp/bar": &fstest.MapFile{Data: []byte("bar")}, + "tmp/baz": &fstest.MapFile{Data: []byte("baz")}, + "tmp/qux": &fstest.MapFile{Mode: fs.ModeDir | 0o777}, + "tmp/qux/quux": &fstest.MapFile{Data: []byte("quux")}, + }) + + fsys = datafs.WrapWdFS(fsys) + + ff := &FileFuncs{ + ctx: context.Background(), + fs: fsys, + } + + actual, err := ff.ReadDir("/tmp") + require.NoError(t, err) + assert.Equal(t, []string{"bar", "baz", "foo", "qux"}, actual) + + _, err = ff.ReadDir("/tmp/foo") + assert.Error(t, err) +} + +func TestWrite(t *testing.T) { + oldwd, _ := os.Getwd() + defer os.Chdir(oldwd) + + rootDir := tfs.NewDir(t, "gomplate-test") + t.Cleanup(rootDir.Remove) + + // we want to use a real filesystem here, so we can test interactions with + // the current working directory + fsys := datafs.WrapWdFS(osfs.NewFS()) + + f := &FileFuncs{ + ctx: context.Background(), + fs: fsys, + } + + newwd := rootDir.Join("the", "path", "we", "want") + badwd := rootDir.Join("some", "other", "dir") + hackpadfs.MkdirAll(fsys, newwd, 0o755) + hackpadfs.MkdirAll(fsys, badwd, 0o755) + newwd, _ = filepath.EvalSymlinks(newwd) + badwd, _ = filepath.EvalSymlinks(badwd) + + err := os.Chdir(newwd) + require.NoError(t, err) + + _, err = f.Write("/foo", []byte("Hello world")) + assert.Error(t, err) + + rel, err := filepath.Rel(newwd, badwd) + require.NoError(t, err) + _, err = f.Write(rel, []byte("Hello world")) + assert.Error(t, err) + + foopath := filepath.Join(newwd, "foo") + _, err = f.Write(foopath, []byte("Hello world")) + require.NoError(t, err) + + out, err := fs.ReadFile(fsys, foopath) + require.NoError(t, err) + assert.Equal(t, "Hello world", string(out)) + + _, err = f.Write(foopath, []byte("truncate")) + require.NoError(t, err) + + out, err = fs.ReadFile(fsys, foopath) + require.NoError(t, err) + assert.Equal(t, "truncate", string(out)) + + foopath = filepath.Join(newwd, "nonexistant", "subdir", "foo") + _, err = f.Write(foopath, "Hello subdirranean world!") + require.NoError(t, err) + + out, err = fs.ReadFile(fsys, foopath) + require.NoError(t, err) + assert.Equal(t, "Hello subdirranean world!", string(out)) + + _, err = f.Write(foopath, bytes.NewBufferString("Hello from a byte buffer!")) + require.NoError(t, err) + + out, err = fs.ReadFile(fsys, foopath) + require.NoError(t, err) + assert.Equal(t, "Hello from a byte buffer!", string(out)) +} diff --git a/internal/funcs/filepath.go b/internal/funcs/filepath.go new file mode 100644 index 00000000..726a1d7c --- /dev/null +++ b/internal/funcs/filepath.go @@ -0,0 +1,100 @@ +package funcs + +import ( + "context" + "path/filepath" + + "github.com/hairyhenderson/gomplate/v4/conv" +) + +// FilePathNS - the Path namespace +// +// Deprecated: don't use +func FilePathNS() *FilePathFuncs { + return &FilePathFuncs{} +} + +// AddFilePathFuncs - +// +// Deprecated: use [CreateFilePathFuncs] instead +func AddFilePathFuncs(f map[string]interface{}) { + for k, v := range CreateFilePathFuncs(context.Background()) { + f[k] = v + } +} + +// CreateFilePathFuncs - +func CreateFilePathFuncs(ctx context.Context) map[string]interface{} { + ns := &FilePathFuncs{ctx} + + return map[string]interface{}{ + "filepath": func() interface{} { return ns }, + } +} + +// FilePathFuncs - +type FilePathFuncs struct { + ctx context.Context +} + +// Base - +func (f *FilePathFuncs) Base(in interface{}) string { + return filepath.Base(conv.ToString(in)) +} + +// Clean - +func (f *FilePathFuncs) Clean(in interface{}) string { + return filepath.Clean(conv.ToString(in)) +} + +// Dir - +func (f *FilePathFuncs) Dir(in interface{}) string { + return filepath.Dir(conv.ToString(in)) +} + +// Ext - +func (f *FilePathFuncs) Ext(in interface{}) string { + return filepath.Ext(conv.ToString(in)) +} + +// FromSlash - +func (f *FilePathFuncs) FromSlash(in interface{}) string { + return filepath.FromSlash(conv.ToString(in)) +} + +// IsAbs - +func (f *FilePathFuncs) IsAbs(in interface{}) bool { + return filepath.IsAbs(conv.ToString(in)) +} + +// Join - +func (f *FilePathFuncs) Join(elem ...interface{}) string { + s := conv.ToStrings(elem...) + return filepath.Join(s...) +} + +// Match - +func (f *FilePathFuncs) Match(pattern, name interface{}) (matched bool, err error) { + return filepath.Match(conv.ToString(pattern), conv.ToString(name)) +} + +// Rel - +func (f *FilePathFuncs) Rel(basepath, targpath interface{}) (string, error) { + return filepath.Rel(conv.ToString(basepath), conv.ToString(targpath)) +} + +// Split - +func (f *FilePathFuncs) Split(in interface{}) []string { + dir, file := filepath.Split(conv.ToString(in)) + return []string{dir, file} +} + +// ToSlash - +func (f *FilePathFuncs) ToSlash(in interface{}) string { + return filepath.ToSlash(conv.ToString(in)) +} + +// VolumeName - +func (f *FilePathFuncs) VolumeName(in interface{}) string { + return filepath.VolumeName(conv.ToString(in)) +} diff --git a/internal/funcs/filepath_test.go b/internal/funcs/filepath_test.go new file mode 100644 index 00000000..0c4bbe80 --- /dev/null +++ b/internal/funcs/filepath_test.go @@ -0,0 +1,26 @@ +package funcs + +import ( + "context" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCreateFilePathFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateFilePathFuncs(ctx) + actual := fmap["filepath"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*FilePathFuncs).ctx) + }) + } +} diff --git a/internal/funcs/filepath_unix_test.go b/internal/funcs/filepath_unix_test.go new file mode 100644 index 00000000..381e7432 --- /dev/null +++ b/internal/funcs/filepath_unix_test.go @@ -0,0 +1,33 @@ +//go:build !windows +// +build !windows + +package funcs + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFilePathFuncs(t *testing.T) { + t.Parallel() + + f := &FilePathFuncs{} + assert.Equal(t, "bar", f.Base("foo/bar")) + assert.Equal(t, "bar", f.Base("/foo/bar")) + + assert.Equal(t, "/foo/baz", f.Clean("/foo/bar/../baz")) + assert.Equal(t, "foo", f.Dir("foo/bar")) + assert.Equal(t, ".txt", f.Ext("/foo/bar/baz.txt")) + assert.False(t, f.IsAbs("foo/bar")) + assert.True(t, f.IsAbs("/foo/bar")) + assert.Equal(t, "foo/bar/qux", f.Join("foo", "bar", "baz", "..", "qux")) + m, _ := f.Match("*.txt", "foo.json") + assert.False(t, m) + m, _ = f.Match("*.txt", "foo.txt") + assert.True(t, m) + r, _ := f.Rel("/foo/bar", "/foo/bar/baz") + assert.Equal(t, "baz", r) + assert.Equal(t, []string{"/foo/bar/", "baz"}, f.Split("/foo/bar/baz")) + assert.Equal(t, "", f.VolumeName("/foo/bar")) +} diff --git a/internal/funcs/filepath_windows_test.go b/internal/funcs/filepath_windows_test.go new file mode 100644 index 00000000..190dfedb --- /dev/null +++ b/internal/funcs/filepath_windows_test.go @@ -0,0 +1,35 @@ +//go:build windows +// +build windows + +package funcs + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFilePathFuncs(t *testing.T) { + t.Parallel() + + f := &FilePathFuncs{} + assert.Equal(t, "bar", f.Base(`foo\bar`)) + assert.Equal(t, "bar", f.Base("C:/foo/bar")) + assert.Equal(t, "bar", f.Base(`C:\foo\bar`)) + + assert.Equal(t, `C:\foo\baz`, f.Clean(`C:\foo\bar\..\baz`)) + assert.Equal(t, "foo", f.Dir(`foo\bar`)) + assert.Equal(t, ".txt", f.Ext(`C:\foo\bar\baz.txt`)) + assert.False(t, f.IsAbs(`foo\bar`)) + assert.True(t, f.IsAbs(`C:\foo\bar`)) + assert.False(t, f.IsAbs(`\foo\bar`)) + assert.Equal(t, `foo\bar\qux`, f.Join("foo", "bar", "baz", "..", "qux")) + m, _ := f.Match("*.txt", "foo.json") + assert.False(t, m) + m, _ = f.Match("*.txt", "foo.txt") + assert.True(t, m) + r, _ := f.Rel(`C:\foo\bar`, `C:\foo\bar\baz`) + assert.Equal(t, "baz", r) + assert.Equal(t, []string{`C:\foo\bar\`, "baz"}, f.Split(`C:\foo\bar\baz`)) + assert.Equal(t, "D:", f.VolumeName(`D:\foo\bar`)) +} diff --git a/internal/funcs/funcs.go b/internal/funcs/funcs.go new file mode 100644 index 00000000..f5549e61 --- /dev/null +++ b/internal/funcs/funcs.go @@ -0,0 +1,15 @@ +package funcs + +import ( + "context" + "fmt" + + "github.com/hairyhenderson/gomplate/v4/internal/config" +) + +func checkExperimental(ctx context.Context) error { + if !config.ExperimentalEnabled(ctx) { + return fmt.Errorf("experimental function, but experimental mode not enabled") + } + return nil +} diff --git a/internal/funcs/gcp.go b/internal/funcs/gcp.go new file mode 100644 index 00000000..040a451b --- /dev/null +++ b/internal/funcs/gcp.go @@ -0,0 +1,52 @@ +package funcs + +import ( + "context" + "sync" + + "github.com/hairyhenderson/gomplate/v4/gcp" +) + +// GCPNS - the gcp namespace +// +// Deprecated: don't use +func GCPNS() *GcpFuncs { + return &GcpFuncs{gcpopts: gcp.GetClientOptions()} +} + +// AddGCPFuncs - +// +// Deprecated: use [CreateGCPFuncs] instead +func AddGCPFuncs(f map[string]interface{}) { + for k, v := range CreateGCPFuncs(context.Background()) { + f[k] = v + } +} + +// CreateGCPFuncs - +func CreateGCPFuncs(ctx context.Context) map[string]interface{} { + ns := &GcpFuncs{ + ctx: ctx, + gcpopts: gcp.GetClientOptions(), + } + return map[string]interface{}{ + "gcp": func() interface{} { return ns }, + } +} + +// GcpFuncs - +type GcpFuncs struct { + ctx context.Context + + meta *gcp.MetaClient + gcpopts gcp.ClientOptions +} + +// Meta - +func (a *GcpFuncs) Meta(key string, def ...string) (string, error) { + a.meta = sync.OnceValue[*gcp.MetaClient](func() *gcp.MetaClient { + return gcp.NewMetaClient(a.ctx, a.gcpopts) + })() + + return a.meta.Meta(key, def...) +} diff --git a/internal/funcs/gcp_test.go b/internal/funcs/gcp_test.go new file mode 100644 index 00000000..91ab54fb --- /dev/null +++ b/internal/funcs/gcp_test.go @@ -0,0 +1,26 @@ +package funcs + +import ( + "context" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCreateGCPFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateGCPFuncs(ctx) + actual := fmap["gcp"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*GcpFuncs).ctx) + }) + } +} diff --git a/internal/funcs/math.go b/internal/funcs/math.go new file mode 100644 index 00000000..68316101 --- /dev/null +++ b/internal/funcs/math.go @@ -0,0 +1,247 @@ +package funcs + +import ( + "context" + "fmt" + gmath "math" + "strconv" + + "github.com/hairyhenderson/gomplate/v4/conv" + + "github.com/hairyhenderson/gomplate/v4/math" +) + +// MathNS - the math namespace +// +// Deprecated: don't use +func MathNS() *MathFuncs { + return &MathFuncs{} +} + +// AddMathFuncs - +// +// Deprecated: use [CreateMathFuncs] instead +func AddMathFuncs(f map[string]interface{}) { + for k, v := range CreateMathFuncs(context.Background()) { + f[k] = v + } +} + +// CreateMathFuncs - +func CreateMathFuncs(ctx context.Context) map[string]interface{} { + f := map[string]interface{}{} + + ns := &MathFuncs{ctx} + f["math"] = func() interface{} { return ns } + + f["add"] = ns.Add + f["sub"] = ns.Sub + f["mul"] = ns.Mul + f["div"] = ns.Div + f["rem"] = ns.Rem + f["pow"] = ns.Pow + f["seq"] = ns.Seq + return f +} + +// MathFuncs - +type MathFuncs struct { + ctx context.Context +} + +// IsInt - +func (f MathFuncs) IsInt(n interface{}) bool { + switch i := n.(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return true + case string: + _, err := strconv.ParseInt(i, 0, 64) + return err == nil + } + return false +} + +// IsFloat - +func (f MathFuncs) IsFloat(n interface{}) bool { + switch i := n.(type) { + case float32, float64: + return true + case string: + _, err := strconv.ParseFloat(i, 64) + if err != nil { + return false + } + if f.IsInt(i) { + return false + } + return true + } + return false +} + +func (f MathFuncs) containsFloat(n ...interface{}) bool { + c := false + for _, v := range n { + if f.IsFloat(v) { + return true + } + } + return c +} + +// IsNum - +func (f MathFuncs) IsNum(n interface{}) bool { + return f.IsInt(n) || f.IsFloat(n) +} + +// Abs - +func (f MathFuncs) Abs(n interface{}) interface{} { + m := gmath.Abs(conv.ToFloat64(n)) + if f.IsInt(n) { + return conv.ToInt64(m) + } + return m +} + +// Add - +func (f MathFuncs) Add(n ...interface{}) interface{} { + if f.containsFloat(n...) { + nums := conv.ToFloat64s(n...) + var x float64 + for _, v := range nums { + x += v + } + return x + } + nums := conv.ToInt64s(n...) + var x int64 + for _, v := range nums { + x += v + } + return x +} + +// Mul - +func (f MathFuncs) Mul(n ...interface{}) interface{} { + if f.containsFloat(n...) { + nums := conv.ToFloat64s(n...) + x := 1. + for _, v := range nums { + x *= v + } + return x + } + nums := conv.ToInt64s(n...) + x := int64(1) + for _, v := range nums { + x *= v + } + return x +} + +// Sub - +func (f MathFuncs) Sub(a, b interface{}) interface{} { + if f.containsFloat(a, b) { + return conv.ToFloat64(a) - conv.ToFloat64(b) + } + return conv.ToInt64(a) - conv.ToInt64(b) +} + +// Div - +func (f MathFuncs) Div(a, b interface{}) (interface{}, error) { + divisor := conv.ToFloat64(a) + dividend := conv.ToFloat64(b) + if dividend == 0 { + return 0, fmt.Errorf("error: division by 0") + } + return divisor / dividend, nil +} + +// Rem - +func (f MathFuncs) Rem(a, b interface{}) interface{} { + return conv.ToInt64(a) % conv.ToInt64(b) +} + +// Pow - +func (f MathFuncs) Pow(a, b interface{}) interface{} { + r := gmath.Pow(conv.ToFloat64(a), conv.ToFloat64(b)) + if f.IsFloat(a) { + return r + } + return conv.ToInt64(r) +} + +// Seq - return a sequence from `start` to `end`, in steps of `step` +// start and step are optional, and default to 1. +func (f MathFuncs) Seq(n ...interface{}) ([]int64, error) { + start := int64(1) + end := int64(0) + step := int64(1) + if len(n) == 0 { + return nil, fmt.Errorf("math.Seq must be given at least an 'end' value") + } + if len(n) == 1 { + end = conv.ToInt64(n[0]) + } + if len(n) == 2 { + start = conv.ToInt64(n[0]) + end = conv.ToInt64(n[1]) + } + if len(n) == 3 { + start = conv.ToInt64(n[0]) + end = conv.ToInt64(n[1]) + step = conv.ToInt64(n[2]) + } + return math.Seq(conv.ToInt64(start), conv.ToInt64(end), conv.ToInt64(step)), nil +} + +// Max - +func (f MathFuncs) Max(a interface{}, b ...interface{}) (interface{}, error) { + if f.IsFloat(a) || f.containsFloat(b...) { + m := conv.ToFloat64(a) + for _, n := range conv.ToFloat64s(b...) { + m = gmath.Max(m, n) + } + return m, nil + } + m := conv.ToInt64(a) + for _, n := range conv.ToInt64s(b...) { + if n > m { + m = n + } + } + return m, nil +} + +// Min - +func (f MathFuncs) Min(a interface{}, b ...interface{}) (interface{}, error) { + if f.IsFloat(a) || f.containsFloat(b...) { + m := conv.ToFloat64(a) + for _, n := range conv.ToFloat64s(b...) { + m = gmath.Min(m, n) + } + return m, nil + } + m := conv.ToInt64(a) + for _, n := range conv.ToInt64s(b...) { + if n < m { + m = n + } + } + return m, nil +} + +// Ceil - +func (f MathFuncs) Ceil(n interface{}) interface{} { + return gmath.Ceil(conv.ToFloat64(n)) +} + +// Floor - +func (f MathFuncs) Floor(n interface{}) interface{} { + return gmath.Floor(conv.ToFloat64(n)) +} + +// Round - +func (f MathFuncs) Round(n interface{}) interface{} { + return gmath.Round(conv.ToFloat64(n)) +} diff --git a/internal/funcs/math_test.go b/internal/funcs/math_test.go new file mode 100644 index 00000000..45681e21 --- /dev/null +++ b/internal/funcs/math_test.go @@ -0,0 +1,400 @@ +package funcs + +import ( + "context" + "fmt" + gmath "math" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCreateMathFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateMathFuncs(ctx) + actual := fmap["math"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*MathFuncs).ctx) + }) + } +} + +func TestAdd(t *testing.T) { + t.Parallel() + + m := MathFuncs{} + assert.Equal(t, int64(12), m.Add(1, 1, 2, 3, 5)) + assert.Equal(t, int64(2), m.Add(1, 1)) + assert.Equal(t, int64(1), m.Add(1)) + assert.Equal(t, int64(0), m.Add(-5, 5)) + assert.InDelta(t, float64(5.1), m.Add(4.9, "0.2"), 0.000000001) +} + +func TestMul(t *testing.T) { + t.Parallel() + + m := MathFuncs{} + assert.Equal(t, int64(30), m.Mul(1, 1, 2, 3, 5)) + assert.Equal(t, int64(1), m.Mul(1, 1)) + assert.Equal(t, int64(1), m.Mul(1)) + assert.Equal(t, int64(-25), m.Mul("-5", 5)) + assert.Equal(t, int64(28), m.Mul(14, "2")) + assert.Equal(t, float64(0.5), m.Mul("-1", -0.5)) +} + +func TestSub(t *testing.T) { + t.Parallel() + + m := MathFuncs{} + assert.Equal(t, int64(0), m.Sub(1, 1)) + assert.Equal(t, int64(-10), m.Sub(-5, 5)) + assert.Equal(t, int64(-41), m.Sub(true, "42")) + assert.InDelta(t, -5.3, m.Sub(10, 15.3), 0.000000000000001) +} + +func mustDiv(a, b interface{}) interface{} { + m := MathFuncs{} + r, err := m.Div(a, b) + if err != nil { + return -1 + } + return r +} + +func TestDiv(t *testing.T) { + t.Parallel() + + m := MathFuncs{} + _, err := m.Div(1, 0) + assert.Error(t, err) + assert.Equal(t, 1., mustDiv(1, 1)) + assert.Equal(t, -1., mustDiv(-5, 5)) + assert.Equal(t, 1./42, mustDiv(true, "42")) + assert.InDelta(t, 0.5, mustDiv(1, 2), 1e-12) +} + +func TestRem(t *testing.T) { + t.Parallel() + + m := MathFuncs{} + assert.Equal(t, int64(0), m.Rem(1, 1)) + assert.Equal(t, int64(2), m.Rem(5, 3.0)) +} + +func TestPow(t *testing.T) { + t.Parallel() + + m := MathFuncs{} + assert.Equal(t, int64(4), m.Pow(2, "2")) + assert.Equal(t, 2.25, m.Pow(1.5, 2)) +} + +func mustSeq(t *testing.T, n ...interface{}) []int64 { + m := MathFuncs{} + s, err := m.Seq(n...) + if err != nil { + t.Fatal(err) + } + return s +} + +func TestSeq(t *testing.T) { + t.Parallel() + + m := MathFuncs{} + assert.EqualValues(t, []int64{0, 1, 2, 3}, mustSeq(t, 0, 3)) + assert.EqualValues(t, []int64{1, 0}, mustSeq(t, 0)) + assert.EqualValues(t, []int64{0, 2, 4}, mustSeq(t, 0, 4, 2)) + assert.EqualValues(t, []int64{0, 2, 4}, mustSeq(t, 0, 5, 2)) + assert.EqualValues(t, []int64{0}, mustSeq(t, 0, 5, 8)) + _, err := m.Seq() + assert.Error(t, err) +} + +func TestIsIntFloatNum(t *testing.T) { + t.Parallel() + + tests := []struct { + in interface{} + isInt bool + isFloat bool + }{ + {0, true, false}, + {1, true, false}, + {-1, true, false}, + {uint(42), true, false}, + {uint8(255), true, false}, + {uint16(42), true, false}, + {uint32(42), true, false}, + {uint64(42), true, false}, + {int(42), true, false}, + {int8(127), true, false}, + {int16(42), true, false}, + {int32(42), true, false}, + {int64(42), true, false}, + {float32(18.3), false, true}, + {float64(18.3), false, true}, + {1.5, false, true}, + {-18.6, false, true}, + {"42", true, false}, + {"052", true, false}, + {"0xff", true, false}, + {"-42", true, false}, + {"-0", true, false}, + {"3.14", false, true}, + {"-3.14", false, true}, + {"0.00", false, true}, + {"NaN", false, true}, + {"-Inf", false, true}, + {"+Inf", false, true}, + {"", false, false}, + {"foo", false, false}, + {nil, false, false}, + {true, false, false}, + } + m := MathFuncs{} + for _, tt := range tests { + tt := tt + t.Run(fmt.Sprintf("%T(%#v)", tt.in, tt.in), func(t *testing.T) { + t.Parallel() + + assert.Equal(t, tt.isInt, m.IsInt(tt.in)) + assert.Equal(t, tt.isFloat, m.IsFloat(tt.in)) + assert.Equal(t, tt.isInt || tt.isFloat, m.IsNum(tt.in)) + }) + } +} + +func BenchmarkIsFloat(b *testing.B) { + data := []interface{}{ + 0, 1, -1, uint(42), uint8(255), uint16(42), uint32(42), uint64(42), int(42), int8(127), int16(42), int32(42), int64(42), float32(18.3), float64(18.3), 1.5, -18.6, "42", "052", "0xff", "-42", "-0", "3.14", "-3.14", "0.00", "NaN", "-Inf", "+Inf", "", "foo", nil, true, + } + m := MathFuncs{} + for _, n := range data { + n := n + b.Run(fmt.Sprintf("%T(%v)", n, n), func(b *testing.B) { + for i := 0; i < b.N; i++ { + m.IsFloat(n) + } + }) + } +} + +func TestMax(t *testing.T) { + t.Parallel() + + m := MathFuncs{} + data := []struct { + expected interface{} + n []interface{} + }{ + {int64(0), []interface{}{nil}}, + {int64(0), []interface{}{0}}, + {int64(0), []interface{}{"not a number"}}, + {int64(1), []interface{}{1}}, + {int64(-1), []interface{}{-1}}, + {int64(1), []interface{}{-1, 0, 1}}, + {3.9, []interface{}{3.14, 3, 3.9}}, + {int64(255), []interface{}{"14", "0xff", -5}}, + } + for _, d := range data { + d := d + t.Run(fmt.Sprintf("%v==%v", d.n, d.expected), func(t *testing.T) { + t.Parallel() + + var actual interface{} + if len(d.n) == 1 { + actual, _ = m.Max(d.n[0]) + } else { + actual, _ = m.Max(d.n[0], d.n[1:]...) + } + assert.Equal(t, d.expected, actual) + }) + } +} + +func TestMin(t *testing.T) { + t.Parallel() + + m := MathFuncs{} + data := []struct { + expected interface{} + n []interface{} + }{ + {int64(0), []interface{}{nil}}, + {int64(0), []interface{}{0}}, + {int64(0), []interface{}{"not a number"}}, + {int64(1), []interface{}{1}}, + {int64(-1), []interface{}{-1}}, + {int64(-1), []interface{}{-1, 0, 1}}, + {3., []interface{}{3.14, 3, 3.9}}, + {int64(-5), []interface{}{"14", "0xff", -5}}, + } + for _, d := range data { + d := d + t.Run(fmt.Sprintf("%v==%v", d.n, d.expected), func(t *testing.T) { + t.Parallel() + + var actual interface{} + if len(d.n) == 1 { + actual, _ = m.Min(d.n[0]) + } else { + actual, _ = m.Min(d.n[0], d.n[1:]...) + } + assert.Equal(t, d.expected, actual) + }) + } +} + +func TestContainsFloat(t *testing.T) { + t.Parallel() + + m := MathFuncs{} + data := []struct { + n []interface{} + expected bool + }{ + {[]interface{}{nil}, false}, + {[]interface{}{0}, false}, + {[]interface{}{"not a number"}, false}, + {[]interface{}{1}, false}, + {[]interface{}{-1}, false}, + {[]interface{}{-1, 0, 1}, false}, + {[]interface{}{3.14, 3, 3.9}, true}, + {[]interface{}{"14", "0xff", -5}, false}, + {[]interface{}{"14.8", "0xff", -5}, true}, + {[]interface{}{"-Inf", 2}, true}, + {[]interface{}{"NaN"}, true}, + } + for _, d := range data { + d := d + t.Run(fmt.Sprintf("%v==%v", d.n, d.expected), func(t *testing.T) { + t.Parallel() + + if d.expected { + assert.True(t, m.containsFloat(d.n...)) + } else { + assert.False(t, m.containsFloat(d.n...)) + } + }) + } +} + +func TestCeil(t *testing.T) { + t.Parallel() + + m := MathFuncs{} + data := []struct { + n interface{} + a float64 + }{ + {"", 0.}, + {nil, 0.}, + {"Inf", gmath.Inf(1)}, + {0, 0.}, + {4.99, 5.}, + {42.1, 43}, + {-1.9, -1}, + } + for _, d := range data { + d := d + t.Run(fmt.Sprintf("%v==%v", d.n, d.a), func(t *testing.T) { + t.Parallel() + + assert.InDelta(t, d.a, m.Ceil(d.n), 1e-12) + }) + } +} + +func TestFloor(t *testing.T) { + t.Parallel() + + m := MathFuncs{} + data := []struct { + n interface{} + a float64 + }{ + {"", 0.}, + {nil, 0.}, + {"Inf", gmath.Inf(1)}, + {0, 0.}, + {4.99, 4.}, + {42.1, 42}, + {-1.9, -2.}, + } + for _, d := range data { + d := d + t.Run(fmt.Sprintf("%v==%v", d.n, d.a), func(t *testing.T) { + t.Parallel() + + assert.InDelta(t, d.a, m.Floor(d.n), 1e-12) + }) + } +} + +func TestRound(t *testing.T) { + t.Parallel() + + m := MathFuncs{} + data := []struct { + n interface{} + a float64 + }{ + {"", 0.}, + {nil, 0.}, + {"Inf", gmath.Inf(1)}, + {0, 0.}, + {4.99, 5}, + {42.1, 42}, + {-1.9, -2.}, + {3.5, 4}, + {-3.5, -4}, + {4.5, 5}, + {-4.5, -5}, + } + for _, d := range data { + d := d + t.Run(fmt.Sprintf("%v==%v", d.n, d.a), func(t *testing.T) { + t.Parallel() + + assert.InDelta(t, d.a, m.Round(d.n), 1e-12) + }) + } +} + +func TestAbs(t *testing.T) { + t.Parallel() + + m := MathFuncs{} + data := []struct { + n interface{} + a interface{} + }{ + {"", 0.}, + {nil, 0.}, + {"-Inf", gmath.Inf(1)}, + {0, int64(0)}, + {0., 0.}, + {gmath.Copysign(0, -1), 0.}, + {3.14, 3.14}, + {-1.9, 1.9}, + {2, int64(2)}, + {-2, int64(2)}, + } + for _, d := range data { + d := d + t.Run(fmt.Sprintf("%#v==%v", d.n, d.a), func(t *testing.T) { + t.Parallel() + + assert.Equal(t, d.a, m.Abs(d.n)) + }) + } +} diff --git a/internal/funcs/net.go b/internal/funcs/net.go new file mode 100644 index 00000000..3342bd27 --- /dev/null +++ b/internal/funcs/net.go @@ -0,0 +1,300 @@ +package funcs + +import ( + "context" + "fmt" + "math/big" + stdnet "net" + "net/netip" + + "github.com/hairyhenderson/gomplate/v4/conv" + "github.com/hairyhenderson/gomplate/v4/internal/cidr" + "github.com/hairyhenderson/gomplate/v4/internal/deprecated" + "github.com/hairyhenderson/gomplate/v4/net" + "go4.org/netipx" + "inet.af/netaddr" +) + +// NetNS - the net namespace +// +// Deprecated: don't use +func NetNS() *NetFuncs { + return &NetFuncs{} +} + +// AddNetFuncs - +// +// Deprecated: use [CreateNetFuncs] instead +func AddNetFuncs(f map[string]interface{}) { + for k, v := range CreateNetFuncs(context.Background()) { + f[k] = v + } +} + +// CreateNetFuncs - +func CreateNetFuncs(ctx context.Context) map[string]interface{} { + ns := &NetFuncs{ctx} + return map[string]interface{}{ + "net": func() interface{} { return ns }, + } +} + +// NetFuncs - +type NetFuncs struct { + ctx context.Context +} + +// LookupIP - +func (f NetFuncs) LookupIP(name interface{}) (string, error) { + return net.LookupIP(conv.ToString(name)) +} + +// LookupIPs - +func (f NetFuncs) LookupIPs(name interface{}) ([]string, error) { + return net.LookupIPs(conv.ToString(name)) +} + +// LookupCNAME - +func (f NetFuncs) LookupCNAME(name interface{}) (string, error) { + return net.LookupCNAME(conv.ToString(name)) +} + +// LookupSRV - +func (f NetFuncs) LookupSRV(name interface{}) (*stdnet.SRV, error) { + return net.LookupSRV(conv.ToString(name)) +} + +// LookupSRVs - +func (f NetFuncs) LookupSRVs(name interface{}) ([]*stdnet.SRV, error) { + return net.LookupSRVs(conv.ToString(name)) +} + +// LookupTXT - +func (f NetFuncs) LookupTXT(name interface{}) ([]string, error) { + return net.LookupTXT(conv.ToString(name)) +} + +// ParseIP - +// +// Deprecated: use [ParseAddr] instead +func (f *NetFuncs) ParseIP(ip interface{}) (netaddr.IP, error) { + deprecated.WarnDeprecated(f.ctx, "net.ParseIP is deprecated - use net.ParseAddr instead") + return netaddr.ParseIP(conv.ToString(ip)) +} + +// ParseIPPrefix - +// +// Deprecated: use [ParsePrefix] instead +func (f *NetFuncs) ParseIPPrefix(ipprefix interface{}) (netaddr.IPPrefix, error) { + deprecated.WarnDeprecated(f.ctx, "net.ParseIPPrefix is deprecated - use net.ParsePrefix instead") + return netaddr.ParseIPPrefix(conv.ToString(ipprefix)) +} + +// ParseIPRange - +// +// Deprecated: use [ParseRange] instead +func (f *NetFuncs) ParseIPRange(iprange interface{}) (netaddr.IPRange, error) { + deprecated.WarnDeprecated(f.ctx, "net.ParseIPRange is deprecated - use net.ParseRange instead") + return netaddr.ParseIPRange(conv.ToString(iprange)) +} + +// ParseAddr - +func (f NetFuncs) ParseAddr(ip interface{}) (netip.Addr, error) { + return netip.ParseAddr(conv.ToString(ip)) +} + +// ParsePrefix - +func (f NetFuncs) ParsePrefix(ipprefix interface{}) (netip.Prefix, error) { + return netip.ParsePrefix(conv.ToString(ipprefix)) +} + +// ParseRange - +// +// Experimental: this API may change in the future +func (f NetFuncs) ParseRange(iprange interface{}) (netipx.IPRange, error) { + return netipx.ParseIPRange(conv.ToString(iprange)) +} + +// func (f *NetFuncs) parseStdnetIPNet(prefix interface{}) (*stdnet.IPNet, error) { +// switch p := prefix.(type) { +// case *stdnet.IPNet: +// return p, nil +// case netaddr.IPPrefix: +// deprecated.WarnDeprecated(f.ctx, +// "support for netaddr.IPPrefix is deprecated - use net.ParsePrefix to produce a netip.Prefix instead") +// return p.Masked().IPNet(), nil +// case netip.Prefix: +// net := &stdnet.IPNet{ +// IP: p.Masked().Addr().AsSlice(), +// Mask: stdnet.CIDRMask(p.Bits(), p.Addr().BitLen()), +// } +// return net, nil +// default: +// _, network, err := stdnet.ParseCIDR(conv.ToString(prefix)) +// return network, err +// } +// } + +// TODO: look at using this instead of parseStdnetIPNet +func (f *NetFuncs) parseNetipPrefix(prefix interface{}) (netip.Prefix, error) { + switch p := prefix.(type) { + case *stdnet.IPNet: + return f.ipPrefixFromIPNet(p), nil + case netaddr.IPPrefix: + deprecated.WarnDeprecated(f.ctx, + "support for netaddr.IPPrefix is deprecated - use net.ParsePrefix to produce a netip.Prefix instead") + return f.ipPrefixFromIPNet(p.Masked().IPNet()), nil + case netip.Prefix: + return p, nil + default: + return netip.ParsePrefix(conv.ToString(prefix)) + } +} + +// func (f NetFuncs) ipFromNetIP(n stdnet.IP) netip.Addr { +// ip, _ := netip.AddrFromSlice(n) +// return ip +// } + +func (f NetFuncs) ipPrefixFromIPNet(n *stdnet.IPNet) netip.Prefix { + ip, _ := netip.AddrFromSlice(n.IP) + ones, _ := n.Mask.Size() + return netip.PrefixFrom(ip, ones) +} + +// CIDRHost - +// Experimental! +func (f *NetFuncs) CIDRHost(hostnum interface{}, prefix interface{}) (netip.Addr, error) { + if err := checkExperimental(f.ctx); err != nil { + return netip.Addr{}, err + } + + network, err := f.parseNetipPrefix(prefix) + if err != nil { + return netip.Addr{}, err + } + + ip, err := cidr.HostBig(network, big.NewInt(conv.ToInt64(hostnum))) + + return ip, err +} + +// CIDRNetmask - +// Experimental! +func (f *NetFuncs) CIDRNetmask(prefix interface{}) (netip.Addr, error) { + if err := checkExperimental(f.ctx); err != nil { + return netip.Addr{}, err + } + + p, err := f.parseNetipPrefix(prefix) + if err != nil { + return netip.Addr{}, err + } + + // fill an appropriately sized byte slice with as many 1s as prefix bits + b := make([]byte, p.Addr().BitLen()/8) + for i := 0; i < p.Bits(); i++ { + b[i/8] |= 1 << uint(7-i%8) + } + + m, ok := netip.AddrFromSlice(b) + if !ok { + return netip.Addr{}, fmt.Errorf("invalid netmask") + } + + return m, nil +} + +// CIDRSubnets - +// Experimental! +func (f *NetFuncs) CIDRSubnets(newbits interface{}, prefix interface{}) ([]netip.Prefix, error) { + if err := checkExperimental(f.ctx); err != nil { + return nil, err + } + + network, err := f.parseNetipPrefix(prefix) + if err != nil { + return nil, err + } + + nBits := conv.ToInt(newbits) + if nBits < 1 { + return nil, fmt.Errorf("must extend prefix by at least one bit") + } + + maxNetNum := int64(1 << uint64(nBits)) + retValues := make([]netip.Prefix, maxNetNum) + for i := int64(0); i < maxNetNum; i++ { + subnet, err := cidr.SubnetBig(network, nBits, big.NewInt(i)) + if err != nil { + return nil, err + } + retValues[i] = subnet + } + + return retValues, nil +} + +// CIDRSubnetSizes - +// Experimental! +func (f *NetFuncs) CIDRSubnetSizes(args ...interface{}) ([]netip.Prefix, error) { + if err := checkExperimental(f.ctx); err != nil { + return nil, err + } + + if len(args) < 2 { + return nil, fmt.Errorf("wrong number of args: want 2 or more, got %d", len(args)) + } + + network, err := f.parseNetipPrefix(args[len(args)-1]) + if err != nil { + return nil, err + } + newbits := conv.ToInts(args[:len(args)-1]...) + + startPrefixLen := network.Bits() + firstLength := newbits[0] + + firstLength += startPrefixLen + retValues := make([]netip.Prefix, len(newbits)) + + current, _ := cidr.PreviousSubnet(network, firstLength) + + for i, length := range newbits { + if length < 1 { + return nil, fmt.Errorf("must extend prefix by at least one bit") + } + // For portability with 32-bit systems where the subnet number + // will be a 32-bit int, we only allow extension of 32 bits in + // one call even if we're running on a 64-bit machine. + // (Of course, this is significant only for IPv6.) + if length > 32 { + return nil, fmt.Errorf("may not extend prefix by more than 32 bits") + } + + length += startPrefixLen + if length > network.Addr().BitLen() { + protocol := "IP" + switch { + case network.Addr().Is4(): + protocol = "IPv4" + case network.Addr().Is6(): + protocol = "IPv6" + } + return nil, fmt.Errorf("would extend prefix to %d bits, which is too long for an %s address", length, protocol) + } + + next, rollover := cidr.NextSubnet(current, length) + if rollover || !network.Contains(next.Addr()) { + // If we run out of suffix bits in the base CIDR prefix then + // NextSubnet will start incrementing the prefix bits, which + // we don't allow because it would then allocate addresses + // outside of the caller's given prefix. + return nil, fmt.Errorf("not enough remaining address space for a subnet with a prefix of %d bits after %s", length, current.String()) + } + current = next + retValues[i] = current + } + + return retValues, nil +} diff --git a/internal/funcs/net_test.go b/internal/funcs/net_test.go new file mode 100644 index 00000000..7c3afffe --- /dev/null +++ b/internal/funcs/net_test.go @@ -0,0 +1,267 @@ +package funcs + +import ( + "context" + stdnet "net" + "net/netip" + "strconv" + "testing" + + "github.com/hairyhenderson/gomplate/v4/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "inet.af/netaddr" +) + +func TestCreateNetFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateNetFuncs(ctx) + actual := fmap["net"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*NetFuncs).ctx) + }) + } +} + +func TestNetLookupIP(t *testing.T) { + t.Parallel() + + n := NetFuncs{} + assert.Equal(t, "127.0.0.1", must(n.LookupIP("localhost"))) +} + +func TestParseIP(t *testing.T) { + t.Parallel() + + n := testNetNS() + _, err := n.ParseIP("not an IP") + assert.Error(t, err) + + ip, err := n.ParseIP("2001:470:20::2") + require.NoError(t, err) + assert.Equal(t, netaddr.IPFrom16([16]byte{ + 0x20, 0x01, 0x04, 0x70, + 0, 0x20, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0x02, + }), ip) +} + +func TestParseIPPrefix(t *testing.T) { + t.Parallel() + + n := testNetNS() + _, err := n.ParseIPPrefix("not an IP") + assert.Error(t, err) + + _, err = n.ParseIPPrefix("1.1.1.1") + assert.Error(t, err) + + ipprefix, err := n.ParseIPPrefix("192.168.0.2/28") + require.NoError(t, err) + assert.Equal(t, "192.168.0.0/28", ipprefix.Masked().String()) +} + +func TestParseIPRange(t *testing.T) { + t.Parallel() + + n := testNetNS() + _, err := n.ParseIPRange("not an IP") + assert.Error(t, err) + + _, err = n.ParseIPRange("1.1.1.1") + assert.Error(t, err) + + iprange, err := n.ParseIPRange("192.168.0.2-192.168.23.255") + require.NoError(t, err) + assert.Equal(t, "192.168.0.2-192.168.23.255", iprange.String()) +} + +func TestParseAddr(t *testing.T) { + t.Parallel() + + n := testNetNS() + _, err := n.ParseAddr("not an IP") + assert.Error(t, err) + + ip, err := n.ParseAddr("2001:470:20::2") + require.NoError(t, err) + assert.Equal(t, netip.AddrFrom16([16]byte{ + 0x20, 0x01, 0x04, 0x70, + 0, 0x20, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0x02, + }), ip) +} + +func TestParsePrefix(t *testing.T) { + t.Parallel() + + n := testNetNS() + _, err := n.ParsePrefix("not an IP") + assert.Error(t, err) + + _, err = n.ParsePrefix("1.1.1.1") + assert.Error(t, err) + + ipprefix, err := n.ParsePrefix("192.168.0.2/28") + require.NoError(t, err) + assert.Equal(t, "192.168.0.0/28", ipprefix.Masked().String()) +} + +func TestParseRange(t *testing.T) { + t.Parallel() + + n := testNetNS() + _, err := n.ParseRange("not an IP") + assert.Error(t, err) + + _, err = n.ParseRange("1.1.1.1") + assert.Error(t, err) + + iprange, err := n.ParseRange("192.168.0.2-192.168.23.255") + require.NoError(t, err) + assert.Equal(t, "192.168.0.2-192.168.23.255", iprange.String()) +} + +func testNetNS() *NetFuncs { + return &NetFuncs{ctx: config.SetExperimental(context.Background())} +} + +func TestCIDRHost(t *testing.T) { + n := testNetNS() + + // net.IPNet + _, netIP, _ := stdnet.ParseCIDR("10.12.127.0/20") + + ip, err := n.CIDRHost(16, netIP) + require.NoError(t, err) + assert.Equal(t, "10.12.112.16", ip.String()) + + ip, err = n.CIDRHost(268, netIP) + require.NoError(t, err) + assert.Equal(t, "10.12.113.12", ip.String()) + + _, netIP, _ = stdnet.ParseCIDR("fd00:fd12:3456:7890:00a2::/72") + ip, err = n.CIDRHost(34, netIP) + require.NoError(t, err) + assert.Equal(t, "fd00:fd12:3456:7890::22", ip.String()) + + // inet.af/netaddr.IPPrefix + ipPrefix, _ := n.ParseIPPrefix("10.12.127.0/20") + + ip, err = n.CIDRHost(16, ipPrefix) + require.NoError(t, err) + assert.Equal(t, "10.12.112.16", ip.String()) + + ip, err = n.CIDRHost(268, ipPrefix) + require.NoError(t, err) + assert.Equal(t, "10.12.113.12", ip.String()) + + ipPrefix, _ = n.ParseIPPrefix("fd00:fd12:3456:7890:00a2::/72") + ip, err = n.CIDRHost(34, ipPrefix) + require.NoError(t, err) + assert.Equal(t, "fd00:fd12:3456:7890::22", ip.String()) + + // net/netip.Prefix + prefix := netip.MustParsePrefix("10.12.127.0/20") + + ip, err = n.CIDRHost(16, prefix) + require.NoError(t, err) + assert.Equal(t, "10.12.112.16", ip.String()) + + ip, err = n.CIDRHost(268, prefix) + require.NoError(t, err) + assert.Equal(t, "10.12.113.12", ip.String()) + + prefix = netip.MustParsePrefix("fd00:fd12:3456:7890:00a2::/72") + ip, err = n.CIDRHost(34, prefix) + require.NoError(t, err) + assert.Equal(t, "fd00:fd12:3456:7890::22", ip.String()) +} + +func TestCIDRNetmask(t *testing.T) { + n := testNetNS() + + ip, err := n.CIDRNetmask("10.0.0.0/12") + require.NoError(t, err) + assert.Equal(t, "255.240.0.0", ip.String()) + + ip, err = n.CIDRNetmask("fd00:fd12:3456:7890:00a2::/72") + require.NoError(t, err) + assert.Equal(t, "ffff:ffff:ffff:ffff:ff00::", ip.String()) +} + +func TestCIDRSubnets(t *testing.T) { + n := testNetNS() + network := netip.MustParsePrefix("10.0.0.0/16") + + subnets, err := n.CIDRSubnets(-1, network) + assert.Error(t, err) + assert.Nil(t, subnets) + + subnets, err = n.CIDRSubnets(2, network) + require.NoError(t, err) + assert.Len(t, subnets, 4) + assert.Equal(t, "10.0.0.0/18", subnets[0].String()) + assert.Equal(t, "10.0.64.0/18", subnets[1].String()) + assert.Equal(t, "10.0.128.0/18", subnets[2].String()) + assert.Equal(t, "10.0.192.0/18", subnets[3].String()) +} + +func TestCIDRSubnetSizes(t *testing.T) { + n := testNetNS() + + subnets, err := n.CIDRSubnetSizes(netip.MustParsePrefix("10.1.0.0/16")) + assert.Error(t, err) + assert.Nil(t, subnets) + + subnets, err = n.CIDRSubnetSizes(32, netip.MustParsePrefix("10.1.0.0/16")) + assert.Error(t, err) + assert.Nil(t, subnets) + + subnets, err = n.CIDRSubnetSizes(127, netip.MustParsePrefix("ffff::/48")) + assert.Error(t, err) + assert.Nil(t, subnets) + + subnets, err = n.CIDRSubnetSizes(-1, netip.MustParsePrefix("10.1.0.0/16")) + assert.Error(t, err) + assert.Nil(t, subnets) + + network := netip.MustParsePrefix("8000::/1") + subnets, err = n.CIDRSubnetSizes(1, 2, 2, network) + require.NoError(t, err) + assert.Len(t, subnets, 3) + assert.Equal(t, "8000::/2", subnets[0].String()) + assert.Equal(t, "c000::/3", subnets[1].String()) + assert.Equal(t, "e000::/3", subnets[2].String()) + + network = netip.MustParsePrefix("10.1.0.0/16") + subnets, err = n.CIDRSubnetSizes(4, 4, 8, 4, network) + require.NoError(t, err) + assert.Len(t, subnets, 4) + assert.Equal(t, "10.1.0.0/20", subnets[0].String()) + assert.Equal(t, "10.1.16.0/20", subnets[1].String()) + assert.Equal(t, "10.1.32.0/24", subnets[2].String()) + assert.Equal(t, "10.1.48.0/20", subnets[3].String()) + + network = netip.MustParsePrefix("2016:1234:5678:9abc:ffff:ffff:ffff:cafe/64") + subnets, err = n.CIDRSubnetSizes(2, 2, 3, 3, 6, 6, 8, 10, network) + require.NoError(t, err) + assert.Len(t, subnets, 8) + assert.Equal(t, "2016:1234:5678:9abc::/66", subnets[0].String()) + assert.Equal(t, "2016:1234:5678:9abc:4000::/66", subnets[1].String()) + assert.Equal(t, "2016:1234:5678:9abc:8000::/67", subnets[2].String()) + assert.Equal(t, "2016:1234:5678:9abc:a000::/67", subnets[3].String()) + assert.Equal(t, "2016:1234:5678:9abc:c000::/70", subnets[4].String()) + assert.Equal(t, "2016:1234:5678:9abc:c400::/70", subnets[5].String()) + assert.Equal(t, "2016:1234:5678:9abc:c800::/72", subnets[6].String()) + assert.Equal(t, "2016:1234:5678:9abc:c900::/74", subnets[7].String()) +} diff --git a/internal/funcs/path.go b/internal/funcs/path.go new file mode 100644 index 00000000..324f188a --- /dev/null +++ b/internal/funcs/path.go @@ -0,0 +1,79 @@ +package funcs + +import ( + "context" + "path" + + "github.com/hairyhenderson/gomplate/v4/conv" +) + +// PathNS - the Path namespace +// +// Deprecated: don't use +func PathNS() *PathFuncs { + return &PathFuncs{} +} + +// AddPathFuncs - +// +// Deprecated: use [CreatePathFuncs] instead +func AddPathFuncs(f map[string]interface{}) { + for k, v := range CreatePathFuncs(context.Background()) { + f[k] = v + } +} + +// CreatePathFuncs - +func CreatePathFuncs(ctx context.Context) map[string]interface{} { + ns := &PathFuncs{ctx} + return map[string]interface{}{ + "path": func() interface{} { return ns }, + } +} + +// PathFuncs - +type PathFuncs struct { + ctx context.Context +} + +// Base - +func (PathFuncs) Base(in interface{}) string { + return path.Base(conv.ToString(in)) +} + +// Clean - +func (PathFuncs) Clean(in interface{}) string { + return path.Clean(conv.ToString(in)) +} + +// Dir - +func (PathFuncs) Dir(in interface{}) string { + return path.Dir(conv.ToString(in)) +} + +// Ext - +func (PathFuncs) Ext(in interface{}) string { + return path.Ext(conv.ToString(in)) +} + +// IsAbs - +func (PathFuncs) IsAbs(in interface{}) bool { + return path.IsAbs(conv.ToString(in)) +} + +// Join - +func (PathFuncs) Join(elem ...interface{}) string { + s := conv.ToStrings(elem...) + return path.Join(s...) +} + +// Match - +func (PathFuncs) Match(pattern, name interface{}) (matched bool, err error) { + return path.Match(conv.ToString(pattern), conv.ToString(name)) +} + +// Split - +func (PathFuncs) Split(in interface{}) []string { + dir, file := path.Split(conv.ToString(in)) + return []string{dir, file} +} diff --git a/internal/funcs/path_test.go b/internal/funcs/path_test.go new file mode 100644 index 00000000..a314933b --- /dev/null +++ b/internal/funcs/path_test.go @@ -0,0 +1,46 @@ +package funcs + +import ( + "context" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCreatePathFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreatePathFuncs(ctx) + actual := fmap["path"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*PathFuncs).ctx) + }) + } +} + +func TestPathFuncs(t *testing.T) { + t.Parallel() + + p := PathFuncs{} + assert.Equal(t, "bar", p.Base("foo/bar")) + assert.Equal(t, "bar", p.Base("/foo/bar")) + + assert.Equal(t, "/foo/baz", p.Clean("/foo/bar/../baz")) + assert.Equal(t, "foo", p.Dir("foo/bar")) + assert.Equal(t, ".txt", p.Ext("/foo/bar/baz.txt")) + assert.False(t, false, p.IsAbs("foo/bar")) + assert.True(t, p.IsAbs("/foo/bar")) + assert.Equal(t, "foo/bar/qux", p.Join("foo", "bar", "baz", "..", "qux")) + m, _ := p.Match("*.txt", "foo.json") + assert.False(t, m) + m, _ = p.Match("*.txt", "foo.txt") + assert.True(t, m) + assert.Equal(t, []string{"/foo/bar/", "baz"}, p.Split("/foo/bar/baz")) +} diff --git a/internal/funcs/random.go b/internal/funcs/random.go new file mode 100644 index 00000000..6aa863be --- /dev/null +++ b/internal/funcs/random.go @@ -0,0 +1,157 @@ +package funcs + +import ( + "context" + "fmt" + "strconv" + "unicode/utf8" + + "github.com/hairyhenderson/gomplate/v4/conv" + iconv "github.com/hairyhenderson/gomplate/v4/internal/conv" + "github.com/hairyhenderson/gomplate/v4/random" +) + +// RandomNS - +// +// Deprecated: don't use +func RandomNS() *RandomFuncs { + return &RandomFuncs{} +} + +// AddRandomFuncs - +// +// Deprecated: use [CreateRandomFuncs] instead +func AddRandomFuncs(f map[string]interface{}) { + for k, v := range CreateRandomFuncs(context.Background()) { + f[k] = v + } +} + +// CreateRandomFuncs - +func CreateRandomFuncs(ctx context.Context) map[string]interface{} { + ns := &RandomFuncs{ctx} + return map[string]interface{}{ + "random": func() interface{} { return ns }, + } +} + +// RandomFuncs - +type RandomFuncs struct { + ctx context.Context +} + +// ASCII - +func (RandomFuncs) ASCII(count interface{}) (string, error) { + return random.StringBounds(conv.ToInt(count), ' ', '~') +} + +// Alpha - +func (RandomFuncs) Alpha(count interface{}) (string, error) { + return random.StringRE(conv.ToInt(count), "[[:alpha:]]") +} + +// AlphaNum - +func (RandomFuncs) AlphaNum(count interface{}) (string, error) { + return random.StringRE(conv.ToInt(count), "[[:alnum:]]") +} + +// String - +func (RandomFuncs) String(count interface{}, args ...interface{}) (s string, err error) { + c := conv.ToInt(count) + if c == 0 { + return "", fmt.Errorf("count must be greater than 0") + } + m := "" + switch len(args) { + case 0: + m = "" + case 1: + m = conv.ToString(args[0]) + case 2: + var l, u rune + if isString(args[0]) && isString(args[1]) { + l, u, err = toCodePoints(args[0].(string), args[1].(string)) + if err != nil { + return "", err + } + } else { + l = rune(conv.ToInt(args[0])) + u = rune(conv.ToInt(args[1])) + } + + return random.StringBounds(c, l, u) + } + + return random.StringRE(c, m) +} + +func isString(s interface{}) bool { + switch s.(type) { + case string: + return true + default: + return false + } +} + +var rlen = utf8.RuneCountInString + +func toCodePoints(l, u string) (rune, rune, error) { + // no way are these representing valid printable codepoints - we'll treat + // them as runes + if rlen(l) == 1 && rlen(u) == 1 { + lower, _ := utf8.DecodeRuneInString(l) + upper, _ := utf8.DecodeRuneInString(u) + return lower, upper, nil + } + + li, err := strconv.ParseInt(l, 0, 32) + if err != nil { + return 0, 0, err + } + ui, err := strconv.ParseInt(u, 0, 32) + if err != nil { + return 0, 0, err + } + + return rune(li), rune(ui), nil +} + +// Item - +func (RandomFuncs) Item(items interface{}) (interface{}, error) { + i, err := iconv.InterfaceSlice(items) + if err != nil { + return nil, err + } + return random.Item(i) +} + +// Number - +func (RandomFuncs) Number(args ...interface{}) (int64, error) { + var min, max int64 + min, max = 0, 100 + switch len(args) { + case 0: + case 1: + max = conv.ToInt64(args[0]) + case 2: + min = conv.ToInt64(args[0]) + max = conv.ToInt64(args[1]) + } + return random.Number(min, max) +} + +// Float - +func (RandomFuncs) Float(args ...interface{}) (float64, error) { + var min, max float64 + min, max = 0, 1.0 + switch len(args) { + case 0: + case 1: + max = conv.ToFloat64(args[0]) + case 2: + min = conv.ToFloat64(args[0]) + max = conv.ToFloat64(args[1]) + } + return random.Float(min, max) +} diff --git a/internal/funcs/random_test.go b/internal/funcs/random_test.go new file mode 100644 index 00000000..7f166ffc --- /dev/null +++ b/internal/funcs/random_test.go @@ -0,0 +1,223 @@ +package funcs + +import ( + "context" + "strconv" + "testing" + "unicode/utf8" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCreateRandomFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateRandomFuncs(ctx) + actual := fmap["random"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*RandomFuncs).ctx) + }) + } +} + +func TestASCII(t *testing.T) { + t.Parallel() + + f := RandomFuncs{} + s, err := f.ASCII(0) + require.NoError(t, err) + assert.Empty(t, s) + + s, err = f.ASCII(100) + require.NoError(t, err) + assert.Len(t, s, 100) + assert.Regexp(t, "^[[:print:]]*$", s) +} + +func TestAlpha(t *testing.T) { + t.Parallel() + + if testing.Short() { + t.Skip("skipping slow test") + } + + f := RandomFuncs{} + s, err := f.Alpha(0) + require.NoError(t, err) + assert.Empty(t, s) + + s, err = f.Alpha(100) + require.NoError(t, err) + assert.Len(t, s, 100) + assert.Regexp(t, "^[[:alpha:]]*$", s) +} + +func TestAlphaNum(t *testing.T) { + t.Parallel() + + if testing.Short() { + t.Skip("skipping slow test") + } + + f := RandomFuncs{} + s, err := f.AlphaNum(0) + require.NoError(t, err) + assert.Empty(t, s) + + s, err = f.AlphaNum(100) + require.NoError(t, err) + assert.Len(t, s, 100) + assert.Regexp(t, "^[[:alnum:]]*$", s) +} + +func TestToCodePoints(t *testing.T) { + t.Parallel() + + l, u, err := toCodePoints("a", "b") + require.NoError(t, err) + assert.Equal(t, 'a', l) + assert.Equal(t, 'b', u) + + _, _, err = toCodePoints("foo", "bar") + assert.Error(t, err) + + _, _, err = toCodePoints("0755", "bar") + assert.Error(t, err) + + l, u, err = toCodePoints("0xD700", "0x0001FFFF") + require.NoError(t, err) + assert.Equal(t, '\ud700', l) + assert.Equal(t, '\U0001ffff', u) + + l, u, err = toCodePoints("0011", "0777") + require.NoError(t, err) + assert.Equal(t, rune(0o011), l) + assert.Equal(t, rune(0o777), u) + + l, u, err = toCodePoints("♬", "♟") + require.NoError(t, err) + assert.Equal(t, rune(0x266C), l) + assert.Equal(t, '♟', u) +} + +func TestString(t *testing.T) { + t.Parallel() + + if testing.Short() { + t.Skip("skipping slow test") + } + + f := RandomFuncs{} + out, err := f.String(1) + require.NoError(t, err) + assert.Len(t, out, 1) + + out, err = f.String(42) + require.NoError(t, err) + assert.Len(t, out, 42) + + _, err = f.String(0) + assert.Error(t, err) + + out, err = f.String(8, "[a-z]") + require.NoError(t, err) + assert.Regexp(t, "^[a-z]{8}$", out) + + out, err = f.String(10, 0x23, 0x26) + require.NoError(t, err) + assert.Regexp(t, "^[#$%&]{10}$", out) + + out, err = f.String(8, '\U0001f062', '\U0001f093') + require.NoError(t, err) + assert.Regexp(t, "^[🁢-🂓]{8}$", out) + + out, err = f.String(8, '\U0001f062', '\U0001f093') + require.NoError(t, err) + assert.Regexp(t, "^[🁢-🂓]{8}$", out) + + out, err = f.String(8, "♚", "♟") + require.NoError(t, err) + assert.Regexp(t, "^[♚-♟]{8}$", out) + + out, err = f.String(100, "♠", "♣") + require.NoError(t, err) + assert.Equal(t, 100, utf8.RuneCountInString(out)) + assert.Regexp(t, "^[♠-♣]{100}$", out) +} + +func TestItem(t *testing.T) { + t.Parallel() + + f := RandomFuncs{} + _, err := f.Item(nil) + assert.Error(t, err) + + _, err = f.Item("foo") + assert.Error(t, err) + + i, err := f.Item([]string{"foo"}) + require.NoError(t, err) + assert.Equal(t, "foo", i) + + in := []string{"foo", "bar"} + got := "" + for j := 0; j < 10; j++ { + i, err = f.Item(in) + require.NoError(t, err) + got += i.(string) + } + assert.NotEqual(t, "foofoofoofoofoofoofoofoofoofoo", got) + assert.NotEqual(t, "barbarbarbarbarbarbarbarbarbar", got) +} + +func TestNumber(t *testing.T) { + t.Parallel() + + f := RandomFuncs{} + n, err := f.Number() + require.NoError(t, err) + assert.True(t, 0 <= n && n <= 100, n) + + _, err = f.Number(-1) + assert.Error(t, err) + + n, err = f.Number(0) + require.NoError(t, err) + assert.Equal(t, int64(0), n) + + n, err = f.Number(9, 9) + require.NoError(t, err) + assert.Equal(t, int64(9), n) + + n, err = f.Number(-10, -10) + require.NoError(t, err) + assert.Equal(t, int64(-10), n) +} + +func TestFloat(t *testing.T) { + t.Parallel() + + f := RandomFuncs{} + n, err := f.Float() + require.NoError(t, err) + assert.InDelta(t, 0.5, n, 0.5) + + n, err = f.Float(0.5) + require.NoError(t, err) + assert.InDelta(t, 0.25, n, 0.25) + + n, err = f.Float(490, 500) + require.NoError(t, err) + assert.InDelta(t, 495, n, 5) + + n, err = f.Float(-500, 500) + require.NoError(t, err) + assert.InDelta(t, 0, n, 500) +} diff --git a/internal/funcs/regexp.go b/internal/funcs/regexp.go new file mode 100644 index 00000000..66cbe4c4 --- /dev/null +++ b/internal/funcs/regexp.go @@ -0,0 +1,106 @@ +package funcs + +import ( + "context" + "fmt" + + "github.com/hairyhenderson/gomplate/v4/conv" + "github.com/hairyhenderson/gomplate/v4/regexp" +) + +// ReNS - +// +// Deprecated: don't use +func ReNS() *ReFuncs { + return &ReFuncs{} +} + +// AddReFuncs - +// +// Deprecated: use [CreateReFuncs] instead +func AddReFuncs(f map[string]interface{}) { + for k, v := range CreateReFuncs(context.Background()) { + f[k] = v + } +} + +// CreateReFuncs - +func CreateReFuncs(ctx context.Context) map[string]interface{} { + ns := &ReFuncs{ctx} + return map[string]interface{}{ + "regexp": func() interface{} { return ns }, + } +} + +// ReFuncs - +type ReFuncs struct { + ctx context.Context +} + +// Find - +func (ReFuncs) Find(re, input interface{}) (string, error) { + return regexp.Find(conv.ToString(re), conv.ToString(input)) +} + +// FindAll - +func (ReFuncs) FindAll(args ...interface{}) ([]string, error) { + re := "" + n := 0 + input := "" + switch len(args) { + case 2: + n = -1 + re = conv.ToString(args[0]) + input = conv.ToString(args[1]) + case 3: + re = conv.ToString(args[0]) + n = conv.ToInt(args[1]) + input = conv.ToString(args[2]) + default: + return nil, fmt.Errorf("wrong number of args: want 2 or 3, got %d", len(args)) + } + return regexp.FindAll(re, n, input) +} + +// Match - +func (ReFuncs) Match(re, input interface{}) bool { + return regexp.Match(conv.ToString(re), conv.ToString(input)) +} + +// QuoteMeta - +func (ReFuncs) QuoteMeta(in interface{}) string { + return regexp.QuoteMeta(conv.ToString(in)) +} + +// Replace - +func (ReFuncs) Replace(re, replacement, input interface{}) string { + return regexp.Replace(conv.ToString(re), + conv.ToString(replacement), + conv.ToString(input)) +} + +// ReplaceLiteral - +func (ReFuncs) ReplaceLiteral(re, replacement, input interface{}) (string, error) { + return regexp.ReplaceLiteral(conv.ToString(re), + conv.ToString(replacement), + conv.ToString(input)) +} + +// Split - +func (ReFuncs) Split(args ...interface{}) ([]string, error) { + re := "" + n := -1 + input := "" + switch len(args) { + case 2: + re = conv.ToString(args[0]) + input = conv.ToString(args[1]) + case 3: + re = conv.ToString(args[0]) + n = conv.ToInt(args[1]) + input = conv.ToString(args[2]) + default: + return nil, fmt.Errorf("wrong number of args: want 2 or 3, got %d", len(args)) + } + return regexp.Split(re, n, input) +} diff --git a/internal/funcs/regexp_test.go b/internal/funcs/regexp_test.go new file mode 100644 index 00000000..181c74df --- /dev/null +++ b/internal/funcs/regexp_test.go @@ -0,0 +1,146 @@ +package funcs + +import ( + "context" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCreateReFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateReFuncs(ctx) + actual := fmap["regexp"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*ReFuncs).ctx) + }) + } +} + +func TestReplace(t *testing.T) { + t.Parallel() + + re := &ReFuncs{} + assert.Equal(t, "hello world", re.Replace("i", "ello", "hi world")) +} + +func TestMatch(t *testing.T) { + t.Parallel() + + re := &ReFuncs{} + assert.True(t, re.Match(`i\ `, "hi world")) +} + +func TestFind(t *testing.T) { + t.Parallel() + + re := &ReFuncs{} + f, err := re.Find(`[a-z]+`, `foo bar baz`) + require.NoError(t, err) + assert.Equal(t, "foo", f) + + _, err = re.Find(`[a-`, "") + assert.Error(t, err) + + f, err = re.Find("4", 42) + require.NoError(t, err) + assert.Equal(t, "4", f) + + f, err = re.Find(false, 42) + require.NoError(t, err) + assert.Equal(t, "", f) +} + +func TestFindAll(t *testing.T) { + t.Parallel() + + re := &ReFuncs{} + f, err := re.FindAll(`[a-z]+`, `foo bar baz`) + require.NoError(t, err) + assert.EqualValues(t, []string{"foo", "bar", "baz"}, f) + + f, err = re.FindAll(`[a-z]+`, -1, `foo bar baz`) + require.NoError(t, err) + assert.EqualValues(t, []string{"foo", "bar", "baz"}, f) + + _, err = re.FindAll(`[a-`, "") + assert.Error(t, err) + + _, err = re.FindAll("") + assert.Error(t, err) + + _, err = re.FindAll("", "", "", "") + assert.Error(t, err) + + f, err = re.FindAll(`[a-z]+`, 0, `foo bar baz`) + require.NoError(t, err) + assert.Nil(t, f) + + f, err = re.FindAll(`[a-z]+`, 2, `foo bar baz`) + require.NoError(t, err) + assert.EqualValues(t, []string{"foo", "bar"}, f) + + f, err = re.FindAll(`[a-z]+`, 14, `foo bar baz`) + require.NoError(t, err) + assert.EqualValues(t, []string{"foo", "bar", "baz"}, f) + + f, err = re.FindAll(`qux`, `foo bar baz`) + require.NoError(t, err) + assert.Nil(t, f) +} + +func TestSplit(t *testing.T) { + t.Parallel() + + re := &ReFuncs{} + f, err := re.Split(` `, `foo bar baz`) + require.NoError(t, err) + assert.EqualValues(t, []string{"foo", "bar", "baz"}, f) + + f, err = re.Split(`\s+`, -1, `foo bar baz`) + require.NoError(t, err) + assert.EqualValues(t, []string{"foo", "bar", "baz"}, f) + + _, err = re.Split(`[a-`, "") + assert.Error(t, err) + + _, err = re.Split("") + assert.Error(t, err) + + _, err = re.Split("", "", "", "") + assert.Error(t, err) + + f, err = re.Split(` `, 0, `foo bar baz`) + require.NoError(t, err) + assert.Nil(t, f) + + f, err = re.Split(`\s+`, 2, `foo bar baz`) + require.NoError(t, err) + assert.EqualValues(t, []string{"foo", "bar baz"}, f) + + f, err = re.Split(`\s`, 14, `foo bar baz`) + require.NoError(t, err) + assert.EqualValues(t, []string{"foo", "", "bar", "baz"}, f) + + f, err = re.Split(`[\s,.]`, 14, `foo bar.baz,qux`) + require.NoError(t, err) + assert.EqualValues(t, []string{"foo", "bar", "baz", "qux"}, f) +} + +func TestReplaceLiteral(t *testing.T) { + t.Parallel() + + re := &ReFuncs{} + r, err := re.ReplaceLiteral("i", "ello$1", "hi world") + require.NoError(t, err) + assert.Equal(t, "hello$1 world", r) +} diff --git a/internal/funcs/semver.go b/internal/funcs/semver.go new file mode 100644 index 00000000..0212c998 --- /dev/null +++ b/internal/funcs/semver.go @@ -0,0 +1,40 @@ +package funcs + +import ( + "context" + + "github.com/Masterminds/semver/v3" +) + +// CreateSemverFuncs - +func CreateSemverFuncs(ctx context.Context) map[string]interface{} { + ns := &SemverFuncs{ctx} + return map[string]interface{}{ + "semver": func() interface{} { return ns }, + } +} + +// SemverFuncs - +type SemverFuncs struct { + ctx context.Context +} + +// Semver - +func (SemverFuncs) Semver(version string) (*semver.Version, error) { + return semver.NewVersion(version) +} + +// CheckConstraint - +func (SemverFuncs) CheckConstraint(constraint, in string) (bool, error) { + c, err := semver.NewConstraint(constraint) + if err != nil { + return false, err + } + + v, err := semver.NewVersion(in) + if err != nil { + return false, err + } + + return c.Check(v), nil +} diff --git a/internal/funcs/semver_test.go b/internal/funcs/semver_test.go new file mode 100644 index 00000000..f10aad5c --- /dev/null +++ b/internal/funcs/semver_test.go @@ -0,0 +1,61 @@ +package funcs + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSemverFuncs_MatchConstraint(t *testing.T) { + tests := []struct { + name string + constraint string + in string + want bool + wantErr bool + }{ + { + name: "mached constraint", + constraint: ">=1.0.0", + in: "v1.1.1", + want: true, + wantErr: false, + }, + { + name: "not matched constraint", + constraint: "<1.0.0", + in: "v1.1.1", + want: false, + wantErr: false, + }, + { + name: "wrong constraint", + constraint: "abc", + in: "v1.1.1", + want: false, + wantErr: true, + }, + { + name: "wrong in", + constraint: ">1.0.0", + in: "va.b.c", + want: false, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := SemverFuncs{ + ctx: context.Background(), + } + got, err := s.CheckConstraint(tt.constraint, tt.in) + if tt.wantErr { + assert.Errorf(t, err, "SemverFuncs.CheckConstraint() error = %v, wantErr %v", err, tt.wantErr) + } else { + assert.NoErrorf(t, err, "SemverFuncs.CheckConstraint() error = %v, wantErr %v", err, tt.wantErr) + assert.Equal(t, tt.want, got) + } + }) + } +} diff --git a/internal/funcs/sockaddr.go b/internal/funcs/sockaddr.go new file mode 100644 index 00000000..e41fb162 --- /dev/null +++ b/internal/funcs/sockaddr.go @@ -0,0 +1,130 @@ +package funcs + +import ( + "context" + + "github.com/hashicorp/go-sockaddr" + "github.com/hashicorp/go-sockaddr/template" +) + +// SockaddrNS - the sockaddr namespace +// +// Deprecated: don't use +func SockaddrNS() *SockaddrFuncs { + return &SockaddrFuncs{} +} + +// AddSockaddrFuncs - +// +// Deprecated: use [CreateSockaddrFuncs] instead +func AddSockaddrFuncs(f map[string]interface{}) { + f["sockaddr"] = SockaddrNS +} + +// CreateSockaddrFuncs - +func CreateSockaddrFuncs(ctx context.Context) map[string]interface{} { + ns := &SockaddrFuncs{ctx} + return map[string]interface{}{ + "sockaddr": func() interface{} { return ns }, + } +} + +// SockaddrFuncs - +type SockaddrFuncs struct { + ctx context.Context +} + +// GetAllInterfaces - +func (SockaddrFuncs) GetAllInterfaces() (sockaddr.IfAddrs, error) { + return sockaddr.GetAllInterfaces() +} + +// GetDefaultInterfaces - +func (SockaddrFuncs) GetDefaultInterfaces() (sockaddr.IfAddrs, error) { + return sockaddr.GetDefaultInterfaces() +} + +// GetPrivateInterfaces - +func (SockaddrFuncs) GetPrivateInterfaces() (sockaddr.IfAddrs, error) { + return sockaddr.GetPrivateInterfaces() +} + +// GetPublicInterfaces - +func (SockaddrFuncs) GetPublicInterfaces() (sockaddr.IfAddrs, error) { + return sockaddr.GetPublicInterfaces() +} + +// Sort - +func (SockaddrFuncs) Sort(selectorParam string, inputIfAddrs sockaddr.IfAddrs) (sockaddr.IfAddrs, error) { + return sockaddr.SortIfBy(selectorParam, inputIfAddrs) +} + +// Exclude - +func (SockaddrFuncs) Exclude(selectorName, selectorParam string, inputIfAddrs sockaddr.IfAddrs) (sockaddr.IfAddrs, error) { + return sockaddr.ExcludeIfs(selectorName, selectorParam, inputIfAddrs) +} + +// Include - +func (SockaddrFuncs) Include(selectorName, selectorParam string, inputIfAddrs sockaddr.IfAddrs) (sockaddr.IfAddrs, error) { + return sockaddr.IncludeIfs(selectorName, selectorParam, inputIfAddrs) +} + +// Attr - +func (SockaddrFuncs) Attr(selectorName string, ifAddrsRaw interface{}) (string, error) { + return template.Attr(selectorName, ifAddrsRaw) +} + +// Join - +func (SockaddrFuncs) Join(selectorName, joinString string, inputIfAddrs sockaddr.IfAddrs) (string, error) { + return sockaddr.JoinIfAddrs(selectorName, joinString, inputIfAddrs) +} + +// Limit - +func (SockaddrFuncs) Limit(lim uint, in sockaddr.IfAddrs) (sockaddr.IfAddrs, error) { + return sockaddr.LimitIfAddrs(lim, in) +} + +// Offset - +func (SockaddrFuncs) Offset(off int, in sockaddr.IfAddrs) (sockaddr.IfAddrs, error) { + return sockaddr.OffsetIfAddrs(off, in) +} + +// Unique - +func (SockaddrFuncs) Unique(selectorName string, inputIfAddrs sockaddr.IfAddrs) (sockaddr.IfAddrs, error) { + return sockaddr.UniqueIfAddrsBy(selectorName, inputIfAddrs) +} + +// Math - +func (SockaddrFuncs) Math(operation, value string, inputIfAddrs sockaddr.IfAddrs) (sockaddr.IfAddrs, error) { + return sockaddr.IfAddrsMath(operation, value, inputIfAddrs) +} + +// GetPrivateIP - +func (SockaddrFuncs) GetPrivateIP() (string, error) { + return sockaddr.GetPrivateIP() +} + +// GetPrivateIPs - +func (SockaddrFuncs) GetPrivateIPs() (string, error) { + return sockaddr.GetPrivateIPs() +} + +// GetPublicIP - +func (SockaddrFuncs) GetPublicIP() (string, error) { + return sockaddr.GetPublicIP() +} + +// GetPublicIPs - +func (SockaddrFuncs) GetPublicIPs() (string, error) { + return sockaddr.GetPublicIPs() +} + +// GetInterfaceIP - +func (SockaddrFuncs) GetInterfaceIP(namedIfRE string) (string, error) { + return sockaddr.GetInterfaceIP(namedIfRE) +} + +// GetInterfaceIPs - +func (SockaddrFuncs) GetInterfaceIPs(namedIfRE string) (string, error) { + return sockaddr.GetInterfaceIPs(namedIfRE) +} diff --git a/internal/funcs/sockaddr_test.go b/internal/funcs/sockaddr_test.go new file mode 100644 index 00000000..caf86bf8 --- /dev/null +++ b/internal/funcs/sockaddr_test.go @@ -0,0 +1,26 @@ +package funcs + +import ( + "context" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCreateSockaddrFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateSockaddrFuncs(ctx) + actual := fmap["sockaddr"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*SockaddrFuncs).ctx) + }) + } +} diff --git a/internal/funcs/strings.go b/internal/funcs/strings.go new file mode 100644 index 00000000..f95e6bf5 --- /dev/null +++ b/internal/funcs/strings.go @@ -0,0 +1,370 @@ +package funcs + +// Namespace strings contains mostly wrappers of equivalently-named +// functions in the standard library `strings` package, with +// differences in argument order where it makes pipelining +// in templates easier. + +import ( + "context" + "fmt" + "reflect" + "strings" + "unicode/utf8" + + "github.com/Masterminds/goutils" + "github.com/hairyhenderson/gomplate/v4/conv" + "github.com/hairyhenderson/gomplate/v4/internal/deprecated" + gompstrings "github.com/hairyhenderson/gomplate/v4/strings" + + "github.com/gosimple/slug" + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +// StrNS - +// +// Deprecated: don't use +func StrNS() *StringFuncs { + return &StringFuncs{} +} + +// AddStringFuncs - +// +// Deprecated: use [CreateStringFuncs] instead +func AddStringFuncs(f map[string]interface{}) { + for k, v := range CreateStringFuncs(context.Background()) { + f[k] = v + } +} + +// CreateStringFuncs - +func CreateStringFuncs(ctx context.Context) map[string]interface{} { + f := map[string]interface{}{} + + ns := &StringFuncs{ctx, language.Und} + f["strings"] = func() interface{} { return ns } + + f["replaceAll"] = ns.ReplaceAll + f["title"] = ns.Title + f["toUpper"] = ns.ToUpper + f["toLower"] = ns.ToLower + f["trimSpace"] = ns.TrimSpace + f["indent"] = ns.Indent + f["quote"] = ns.Quote + f["shellQuote"] = ns.ShellQuote + f["squote"] = ns.Squote + + // these are legacy aliases with non-pipelinable arg order + f["contains"] = ns.oldContains + f["hasPrefix"] = ns.oldHasPrefix + f["hasSuffix"] = ns.oldHasSuffix + f["split"] = ns.oldSplit + f["splitN"] = ns.oldSplitN + f["trim"] = ns.oldTrim + + return f +} + +// StringFuncs - +type StringFuncs struct { + ctx context.Context + + // tag - the selected BCP 47 language tag. Currently gomplate only supports + // Und (undetermined) + tag language.Tag +} + +// ---- legacy aliases with non-pipelinable arg order + +// oldContains - +// +// Deprecated: use [strings.Contains] instead +func (f *StringFuncs) oldContains(s, substr string) bool { + deprecated.WarnDeprecated(f.ctx, "contains is deprecated - use strings.Contains instead") + return strings.Contains(s, substr) +} + +// oldHasPrefix - +// +// Deprecated: use [strings.HasPrefix] instead +func (f *StringFuncs) oldHasPrefix(s, prefix string) bool { + deprecated.WarnDeprecated(f.ctx, "hasPrefix is deprecated - use strings.HasPrefix instead") + return strings.HasPrefix(s, prefix) +} + +// oldHasSuffix - +// +// Deprecated: use [strings.HasSuffix] instead +func (f *StringFuncs) oldHasSuffix(s, suffix string) bool { + deprecated.WarnDeprecated(f.ctx, "hasSuffix is deprecated - use strings.HasSuffix instead") + return strings.HasSuffix(s, suffix) +} + +// oldSplit - +// +// Deprecated: use [strings.Split] instead +func (f *StringFuncs) oldSplit(s, sep string) []string { + deprecated.WarnDeprecated(f.ctx, "split is deprecated - use strings.Split instead") + return strings.Split(s, sep) +} + +// oldSplitN - +// +// Deprecated: use [strings.SplitN] instead +func (f *StringFuncs) oldSplitN(s, sep string, n int) []string { + deprecated.WarnDeprecated(f.ctx, "splitN is deprecated - use strings.SplitN instead") + return strings.SplitN(s, sep, n) +} + +// oldTrim - +// +// Deprecated: use [strings.Trim] instead +func (f *StringFuncs) oldTrim(s, cutset string) string { + deprecated.WarnDeprecated(f.ctx, "trim is deprecated - use strings.Trim instead") + return strings.Trim(s, cutset) +} + +// ---- + +// Abbrev - +func (StringFuncs) Abbrev(args ...interface{}) (string, error) { + str := "" + offset := 0 + maxWidth := 0 + if len(args) < 2 { + return "", fmt.Errorf("abbrev requires a 'maxWidth' and 'input' argument") + } + if len(args) == 2 { + maxWidth = conv.ToInt(args[0]) + str = conv.ToString(args[1]) + } + if len(args) == 3 { + offset = conv.ToInt(args[0]) + maxWidth = conv.ToInt(args[1]) + str = conv.ToString(args[2]) + } + if len(str) <= maxWidth { + return str, nil + } + return goutils.AbbreviateFull(str, offset, maxWidth) +} + +// ReplaceAll - +func (StringFuncs) ReplaceAll(old, new string, s interface{}) string { + return strings.ReplaceAll(conv.ToString(s), old, new) +} + +// Contains - +func (StringFuncs) Contains(substr string, s interface{}) bool { + return strings.Contains(conv.ToString(s), substr) +} + +// HasPrefix - +func (StringFuncs) HasPrefix(prefix string, s interface{}) bool { + return strings.HasPrefix(conv.ToString(s), prefix) +} + +// HasSuffix - +func (StringFuncs) HasSuffix(suffix string, s interface{}) bool { + return strings.HasSuffix(conv.ToString(s), suffix) +} + +// Repeat - +func (StringFuncs) Repeat(count int, s interface{}) (string, error) { + if count < 0 { + return "", fmt.Errorf("negative count %d", count) + } + str := conv.ToString(s) + if count > 0 && len(str)*count/count != len(str) { + return "", fmt.Errorf("count %d too long: causes overflow", count) + } + return strings.Repeat(str, count), nil +} + +// SkipLines - +func (StringFuncs) SkipLines(skip int, in string) (string, error) { + return gompstrings.SkipLines(skip, in) +} + +// Sort - +// +// Deprecated: use [CollFuncs.Sort] instead +func (f *StringFuncs) Sort(list interface{}) ([]string, error) { + deprecated.WarnDeprecated(f.ctx, "strings.Sort is deprecated - use coll.Sort instead") + + switch v := list.(type) { + case []string: + return gompstrings.Sort(v), nil + case []interface{}: + l := len(v) + b := make([]string, len(v)) + for i := 0; i < l; i++ { + b[i] = conv.ToString(v[i]) + } + return gompstrings.Sort(b), nil + default: + return nil, fmt.Errorf("wrong type for value; expected []string; got %T", list) + } +} + +// Split - +func (StringFuncs) Split(sep string, s interface{}) []string { + return strings.Split(conv.ToString(s), sep) +} + +// SplitN - +func (StringFuncs) SplitN(sep string, n int, s interface{}) []string { + return strings.SplitN(conv.ToString(s), sep, n) +} + +// Trim - +func (StringFuncs) Trim(cutset string, s interface{}) string { + return strings.Trim(conv.ToString(s), cutset) +} + +// TrimPrefix - +func (StringFuncs) TrimPrefix(cutset string, s interface{}) string { + return strings.TrimPrefix(conv.ToString(s), cutset) +} + +// TrimSuffix - +func (StringFuncs) TrimSuffix(cutset string, s interface{}) string { + return strings.TrimSuffix(conv.ToString(s), cutset) +} + +// Title - +func (f *StringFuncs) Title(s interface{}) string { + return cases.Title(f.tag, cases.NoLower).String(conv.ToString(s)) +} + +// ToUpper - +func (f *StringFuncs) ToUpper(s interface{}) string { + return cases.Upper(f.tag).String(conv.ToString(s)) +} + +// ToLower - +func (f *StringFuncs) ToLower(s interface{}) string { + return cases.Lower(f.tag).String(conv.ToString(s)) +} + +// TrimSpace - +func (StringFuncs) TrimSpace(s interface{}) string { + return strings.TrimSpace(conv.ToString(s)) +} + +// Trunc - +func (StringFuncs) Trunc(length int, s interface{}) string { + return gompstrings.Trunc(length, conv.ToString(s)) +} + +// Indent - +func (StringFuncs) Indent(args ...interface{}) (string, error) { + input := conv.ToString(args[len(args)-1]) + indent := " " + width := 1 + var ok bool + switch len(args) { + case 2: + indent, ok = args[0].(string) + if !ok { + width, ok = args[0].(int) + if !ok { + return "", fmt.Errorf("indent: invalid arguments") + } + indent = " " + } + case 3: + width, ok = args[0].(int) + if !ok { + return "", fmt.Errorf("indent: invalid arguments") + } + indent, ok = args[1].(string) + if !ok { + return "", fmt.Errorf("indent: invalid arguments") + } + } + return gompstrings.Indent(width, indent, input), nil +} + +// Slug - +func (StringFuncs) Slug(in interface{}) string { + return slug.Make(conv.ToString(in)) +} + +// Quote - +func (StringFuncs) Quote(in interface{}) string { + return fmt.Sprintf("%q", conv.ToString(in)) +} + +// ShellQuote - +func (StringFuncs) ShellQuote(in interface{}) string { + val := reflect.ValueOf(in) + switch val.Kind() { + case reflect.Array, reflect.Slice: + var sb strings.Builder + max := val.Len() + for n := 0; n < max; n++ { + sb.WriteString(gompstrings.ShellQuote(conv.ToString(val.Index(n)))) + if n+1 != max { + sb.WriteRune(' ') + } + } + return sb.String() + } + return gompstrings.ShellQuote(conv.ToString(in)) +} + +// Squote - +func (StringFuncs) Squote(in interface{}) string { + s := conv.ToString(in) + s = strings.ReplaceAll(s, `'`, `''`) + return fmt.Sprintf("'%s'", s) +} + +// SnakeCase - +func (StringFuncs) SnakeCase(in interface{}) (string, error) { + return gompstrings.SnakeCase(conv.ToString(in)), nil +} + +// CamelCase - +func (StringFuncs) CamelCase(in interface{}) (string, error) { + return gompstrings.CamelCase(conv.ToString(in)), nil +} + +// KebabCase - +func (StringFuncs) KebabCase(in interface{}) (string, error) { + return gompstrings.KebabCase(conv.ToString(in)), nil +} + +// WordWrap - +func (StringFuncs) WordWrap(args ...interface{}) (string, error) { + if len(args) == 0 || len(args) > 3 { + return "", fmt.Errorf("expected 1, 2, or 3 args, got %d", len(args)) + } + in := conv.ToString(args[len(args)-1]) + + opts := gompstrings.WordWrapOpts{} + if len(args) == 2 { + switch a := (args[0]).(type) { + case string: + opts.LBSeq = a + default: + opts.Width = uint(conv.ToInt(a)) + } + } + if len(args) == 3 { + opts.Width = uint(conv.ToInt(args[0])) + opts.LBSeq = conv.ToString(args[1]) + } + return gompstrings.WordWrap(in, opts), nil +} + +// RuneCount - like len(s), but for runes +func (StringFuncs) RuneCount(args ...interface{}) (int, error) { + s := "" + for _, arg := range args { + s += conv.ToString(arg) + } + return utf8.RuneCountInString(s), nil +} diff --git a/internal/funcs/strings_test.go b/internal/funcs/strings_test.go new file mode 100644 index 00000000..fdc64a47 --- /dev/null +++ b/internal/funcs/strings_test.go @@ -0,0 +1,258 @@ +package funcs + +import ( + "context" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCreateStringFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateStringFuncs(ctx) + actual := fmap["strings"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*StringFuncs).ctx) + }) + } +} + +func TestReplaceAll(t *testing.T) { + t.Parallel() + + sf := &StringFuncs{} + + assert.Equal(t, "Replaced", + sf.ReplaceAll("Orig", "Replaced", "Orig")) + assert.Equal(t, "ReplacedReplaced", + sf.ReplaceAll("Orig", "Replaced", "OrigOrig")) +} + +func TestIndent(t *testing.T) { + t.Parallel() + + sf := &StringFuncs{} + + testdata := []struct { + out string + args []interface{} + }{ + {" foo\n bar\n baz", []interface{}{"foo\nbar\nbaz"}}, + {" foo\n bar\n baz", []interface{}{" ", "foo\nbar\nbaz"}}, + {"---foo\n---bar\n---baz", []interface{}{3, "-", "foo\nbar\nbaz"}}, + {" foo\n bar\n baz", []interface{}{3, "foo\nbar\nbaz"}}, + } + + for _, d := range testdata { + out, err := sf.Indent(d.args...) + require.NoError(t, err) + assert.Equal(t, d.out, out) + } +} + +func TestTrimPrefix(t *testing.T) { + t.Parallel() + + sf := &StringFuncs{} + + assert.Equal(t, "Bar", + sf.TrimPrefix("Foo", "FooBar")) +} + +func TestTitle(t *testing.T) { + sf := &StringFuncs{} + testdata := []struct { + in interface{} + out string + }{ + {``, ``}, + {`foo`, `Foo`}, + {`foo bar`, `Foo Bar`}, + {`ljoo džar`, `Ljoo Džar`}, + {`foo bar᳇baz`, `Foo Bar᳇Baz`}, // ᳇ should be treated as punctuation + {`foo,bar&baz`, `Foo,Bar&Baz`}, + {`FOO`, `FOO`}, + {`bar FOO`, `Bar FOO`}, + } + + for _, d := range testdata { + up := sf.Title(d.in) + assert.Equal(t, d.out, up) + } +} + +func TestTrunc(t *testing.T) { + t.Parallel() + + sf := &StringFuncs{} + assert.Equal(t, "", sf.Trunc(5, "")) + assert.Equal(t, "", sf.Trunc(0, nil)) + assert.Equal(t, "123", sf.Trunc(3, 123456789)) + assert.Equal(t, "hello, world", sf.Trunc(-1, "hello, world")) +} + +func TestAbbrev(t *testing.T) { + t.Parallel() + + sf := &StringFuncs{} + _, err := sf.Abbrev() + assert.Error(t, err) + + _, err = sf.Abbrev("foo") + assert.Error(t, err) + + s, err := sf.Abbrev(3, "foo") + require.NoError(t, err) + assert.Equal(t, "foo", s) + + s, err = sf.Abbrev(2, 6, "foobar") + require.NoError(t, err) + assert.Equal(t, "foobar", s) + + s, err = sf.Abbrev(6, 9, "foobarbazquxquux") + require.NoError(t, err) + assert.Equal(t, "...baz...", s) +} + +func TestSlug(t *testing.T) { + sf := &StringFuncs{} + s := sf.Slug(nil) + assert.Equal(t, "nil", s) + + s = sf.Slug(0) + assert.Equal(t, "0", s) + + s = sf.Slug(1.85e-5) + assert.Equal(t, "1-85e-05", s) + + s = sf.Slug("Hello, World!") + assert.Equal(t, "hello-world", s) + + s = sf.Slug("foo@example.com") + assert.Equal(t, "fooatexample-com", s) + + s = sf.Slug("rock & roll!") + assert.Equal(t, "rock-and-roll", s) + + s = sf.Slug("foo@example.com") + assert.Equal(t, "fooatexample-com", s) + + s = sf.Slug(`100%`) + assert.Equal(t, "100", s) +} + +func TestSort(t *testing.T) { + t.Parallel() + sf := &StringFuncs{ctx: context.Background()} + + in := []string{"foo", "bar", "baz"} + out := []string{"bar", "baz", "foo"} + assert.Equal(t, out, must(sf.Sort(in))) + + assert.Equal(t, out, must(sf.Sort([]interface{}{"foo", "bar", "baz"}))) +} + +func TestQuote(t *testing.T) { + t.Parallel() + + sf := &StringFuncs{} + testdata := []struct { + in interface{} + out string + }{ + {``, `""`}, + {`foo`, `"foo"`}, + {nil, `"nil"`}, + {123.4, `"123.4"`}, + {`hello "world"`, `"hello \"world\""`}, + {`it's its`, `"it's its"`}, + } + + for _, d := range testdata { + assert.Equal(t, d.out, sf.Quote(d.in)) + } +} + +func TestShellQuote(t *testing.T) { + t.Parallel() + + sf := &StringFuncs{} + testdata := []struct { + in interface{} + out string + }{ + // conventional cases are covered in gompstrings.ShellQuote() tests + // we cover only cases that require type conversion or array/slice combining here + {nil, `'nil'`}, + {123.4, `'123.4'`}, + // array and slice cases + {[]string{}, ``}, + {[]string{"", ""}, `'' ''`}, + {[...]string{"one'two", "three four"}, `'one'"'"'two' 'three four'`}, + {[]string{"one'two", "three four"}, `'one'"'"'two' 'three four'`}, + } + + for _, d := range testdata { + assert.Equal(t, d.out, sf.ShellQuote(d.in)) + } +} + +func TestSquote(t *testing.T) { + t.Parallel() + + sf := &StringFuncs{} + testdata := []struct { + in interface{} + out string + }{ + {``, `''`}, + {`foo`, `'foo'`}, + {nil, `'nil'`}, + {123.4, `'123.4'`}, + {`hello "world"`, `'hello "world"'`}, + {`it's its`, `'it''s its'`}, + } + + for _, d := range testdata { + assert.Equal(t, d.out, sf.Squote(d.in)) + } +} + +func TestRuneCount(t *testing.T) { + t.Parallel() + + sf := &StringFuncs{} + + n, err := sf.RuneCount("") + require.NoError(t, err) + assert.Equal(t, 0, n) + + n, err = sf.RuneCount("foo") + require.NoError(t, err) + assert.Equal(t, 3, n) + + n, err = sf.RuneCount("foo", "bar") + require.NoError(t, err) + assert.Equal(t, 6, n) + + n, err = sf.RuneCount(42, true) + require.NoError(t, err) + assert.Equal(t, 6, n) + + n, err = sf.RuneCount("😂\U0001F602") + require.NoError(t, err) + assert.Equal(t, 2, n) + + n, err = sf.RuneCount("\U0001F600", 3.14) + require.NoError(t, err) + assert.Equal(t, 5, n) +} diff --git a/internal/funcs/test.go b/internal/funcs/test.go new file mode 100644 index 00000000..875cf631 --- /dev/null +++ b/internal/funcs/test.go @@ -0,0 +1,120 @@ +package funcs + +import ( + "context" + "fmt" + "reflect" + + "github.com/hairyhenderson/gomplate/v4/conv" + "github.com/hairyhenderson/gomplate/v4/test" +) + +// TestNS - +// +// Deprecated: don't use +func TestNS() *TestFuncs { + return &TestFuncs{} +} + +// AddTestFuncs - +// +// Deprecated: use [CreateTestFuncs] instead +func AddTestFuncs(f map[string]interface{}) { + for k, v := range CreateTestFuncs(context.Background()) { + f[k] = v + } +} + +// CreateTestFuncs - +func CreateTestFuncs(ctx context.Context) map[string]interface{} { + f := map[string]interface{}{} + + ns := &TestFuncs{ctx} + f["test"] = func() interface{} { return ns } + + f["assert"] = ns.Assert + f["fail"] = ns.Fail + f["required"] = ns.Required + f["ternary"] = ns.Ternary + f["kind"] = ns.Kind + f["isKind"] = ns.IsKind + return f +} + +// TestFuncs - +type TestFuncs struct { + ctx context.Context +} + +// Assert - +func (TestFuncs) Assert(args ...interface{}) (string, error) { + input := conv.ToBool(args[len(args)-1]) + switch len(args) { + case 1: + return test.Assert(input, "") + case 2: + message, ok := args[0].(string) + if !ok { + return "", fmt.Errorf("at <1>: expected string; found %T", args[0]) + } + return test.Assert(input, message) + default: + return "", fmt.Errorf("wrong number of args: want 1 or 2, got %d", len(args)) + } +} + +// Fail - +func (TestFuncs) Fail(args ...interface{}) (string, error) { + switch len(args) { + case 0: + return "", test.Fail("") + case 1: + return "", test.Fail(conv.ToString(args[0])) + default: + return "", fmt.Errorf("wrong number of args: want 0 or 1, got %d", len(args)) + } +} + +// Required - +func (TestFuncs) Required(args ...interface{}) (interface{}, error) { + switch len(args) { + case 1: + return test.Required("", args[0]) + case 2: + message, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("at <1>: expected string; found %T", args[0]) + } + return test.Required(message, args[1]) + default: + return nil, fmt.Errorf("wrong number of args: want 1 or 2, got %d", len(args)) + } +} + +// Ternary - +func (TestFuncs) Ternary(tval, fval, b interface{}) interface{} { + if conv.ToBool(b) { + return tval + } + return fval +} + +// Kind - return the kind of the argument +func (TestFuncs) Kind(arg interface{}) string { + return reflect.ValueOf(arg).Kind().String() +} + +// IsKind - return whether or not the argument is of the given kind +func (f TestFuncs) IsKind(kind string, arg interface{}) bool { + k := f.Kind(arg) + if kind == "number" { + switch k { + case "int", "int8", "int16", "int32", "int64", + "uint", "uint8", "uint16", "uint32", "uint64", "uintptr", + "float32", "float64", + "complex64", "complex128": + kind = k + } + } + return k == kind +} diff --git a/internal/funcs/test_test.go b/internal/funcs/test_test.go new file mode 100644 index 00000000..26909813 --- /dev/null +++ b/internal/funcs/test_test.go @@ -0,0 +1,174 @@ +package funcs + +import ( + "context" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCreateTestFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateTestFuncs(ctx) + actual := fmap["test"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*TestFuncs).ctx) + }) + } +} + +func TestAssert(t *testing.T) { + t.Parallel() + + f := TestNS() + _, err := f.Assert(false) + assert.Error(t, err) + + _, err = f.Assert(true) + require.NoError(t, err) + + _, err = f.Assert("foo", true) + require.NoError(t, err) + + _, err = f.Assert("foo", "false") + assert.EqualError(t, err, "assertion failed: foo") +} + +func TestRequired(t *testing.T) { + t.Parallel() + + f := TestNS() + errMsg := "can not render template: a required value was not set" + v, err := f.Required("") + assert.Error(t, err) + assert.EqualError(t, err, errMsg) + assert.Nil(t, v) + + v, err = f.Required(nil) + assert.Error(t, err) + assert.EqualError(t, err, errMsg) + assert.Nil(t, v) + + errMsg = "hello world" + v, err = f.Required(errMsg, nil) + assert.Error(t, err) + assert.EqualError(t, err, errMsg) + assert.Nil(t, v) + + v, err = f.Required(42, nil) + assert.Error(t, err) + assert.EqualError(t, err, "at <1>: expected string; found int") + assert.Nil(t, v) + + v, err = f.Required() + assert.Error(t, err) + assert.EqualError(t, err, "wrong number of args: want 1 or 2, got 0") + assert.Nil(t, v) + + v, err = f.Required("", 2, 3) + assert.Error(t, err) + assert.EqualError(t, err, "wrong number of args: want 1 or 2, got 3") + assert.Nil(t, v) + + v, err = f.Required(0) + require.NoError(t, err) + assert.Equal(t, v, 0) + + v, err = f.Required("foo") + require.NoError(t, err) + assert.Equal(t, v, "foo") +} + +func TestTernary(t *testing.T) { + t.Parallel() + + f := TestNS() + testdata := []struct { + tval, fval, b interface{} + expected interface{} + }{ + {"foo", 42, false, 42}, + {"foo", 42, "yes", "foo"}, + {false, true, true, false}, + } + for _, d := range testdata { + assert.Equal(t, d.expected, f.Ternary(d.tval, d.fval, d.b)) + } +} + +func TestKind(t *testing.T) { + t.Parallel() + + f := TestNS() + testdata := []struct { + arg interface{} + expected string + }{ + {"foo", "string"}, + {nil, "invalid"}, + {false, "bool"}, + {[]string{"foo", "bar"}, "slice"}, + {map[string]string{"foo": "bar"}, "map"}, + {42, "int"}, + {42.0, "float64"}, + {uint(42), "uint"}, + {struct{}{}, "struct"}, + } + for _, d := range testdata { + assert.Equal(t, d.expected, f.Kind(d.arg)) + } +} + +func TestIsKind(t *testing.T) { + t.Parallel() + + f := TestNS() + truedata := []struct { + arg interface{} + kind string + }{ + {"foo", "string"}, + {nil, "invalid"}, + {false, "bool"}, + {[]string{"foo", "bar"}, "slice"}, + {map[string]string{"foo": "bar"}, "map"}, + {42, "int"}, + {42.0, "float64"}, + {uint(42), "uint"}, + {struct{}{}, "struct"}, + {42.0, "number"}, + {42, "number"}, + {uint32(64000), "number"}, + {complex128(64000), "number"}, + } + for _, d := range truedata { + assert.True(t, f.IsKind(d.kind, d.arg)) + } + + falsedata := []struct { + arg interface{} + kind string + }{ + {"foo", "bool"}, + {nil, "struct"}, + {false, "string"}, + {[]string{"foo", "bar"}, "map"}, + {map[string]string{"foo": "bar"}, "int"}, + {42, "int64"}, + {42.0, "float32"}, + {uint(42), "int"}, + {struct{}{}, "interface"}, + } + for _, d := range falsedata { + assert.False(t, f.IsKind(d.kind, d.arg)) + } +} diff --git a/internal/funcs/time.go b/internal/funcs/time.go new file mode 100644 index 00000000..ca02a35a --- /dev/null +++ b/internal/funcs/time.go @@ -0,0 +1,229 @@ +package funcs + +import ( + "context" + "fmt" + "strconv" + "strings" + gotime "time" + + "github.com/hairyhenderson/gomplate/v4/conv" + "github.com/hairyhenderson/gomplate/v4/env" + "github.com/hairyhenderson/gomplate/v4/time" +) + +// TimeNS - +// +// Deprecated: don't use +func TimeNS() *TimeFuncs { + return &TimeFuncs{ + ANSIC: gotime.ANSIC, + UnixDate: gotime.UnixDate, + RubyDate: gotime.RubyDate, + RFC822: gotime.RFC822, + RFC822Z: gotime.RFC822Z, + RFC850: gotime.RFC850, + RFC1123: gotime.RFC1123, + RFC1123Z: gotime.RFC1123Z, + RFC3339: gotime.RFC3339, + RFC3339Nano: gotime.RFC3339Nano, + Kitchen: gotime.Kitchen, + Stamp: gotime.Stamp, + StampMilli: gotime.StampMilli, + StampMicro: gotime.StampMicro, + StampNano: gotime.StampNano, + } +} + +// AddTimeFuncs - +// +// Deprecated: use [CreateTimeFuncs] instead +func AddTimeFuncs(f map[string]interface{}) { + for k, v := range CreateTimeFuncs(context.Background()) { + f[k] = v + } +} + +// CreateTimeFuncs - +func CreateTimeFuncs(ctx context.Context) map[string]interface{} { + ns := &TimeFuncs{ + ctx: ctx, + ANSIC: gotime.ANSIC, + UnixDate: gotime.UnixDate, + RubyDate: gotime.RubyDate, + RFC822: gotime.RFC822, + RFC822Z: gotime.RFC822Z, + RFC850: gotime.RFC850, + RFC1123: gotime.RFC1123, + RFC1123Z: gotime.RFC1123Z, + RFC3339: gotime.RFC3339, + RFC3339Nano: gotime.RFC3339Nano, + Kitchen: gotime.Kitchen, + Stamp: gotime.Stamp, + StampMilli: gotime.StampMilli, + StampMicro: gotime.StampMicro, + StampNano: gotime.StampNano, + } + + return map[string]interface{}{ + "time": func() interface{} { return ns }, + } +} + +// TimeFuncs - +type TimeFuncs struct { + ctx context.Context + ANSIC string + UnixDate string + RubyDate string + RFC822 string + RFC822Z string + RFC850 string + RFC1123 string + RFC1123Z string + RFC3339 string + RFC3339Nano string + Kitchen string + Stamp string + StampMilli string + StampMicro string + StampNano string +} + +// ZoneName - return the local system's time zone's name +func (TimeFuncs) ZoneName() string { + return time.ZoneName() +} + +// ZoneOffset - return the local system's time zone's name +func (TimeFuncs) ZoneOffset() int { + return time.ZoneOffset() +} + +// Parse - +func (TimeFuncs) Parse(layout string, value interface{}) (gotime.Time, error) { + return gotime.Parse(layout, conv.ToString(value)) +} + +// ParseLocal - +func (f TimeFuncs) ParseLocal(layout string, value interface{}) (gotime.Time, error) { + tz := env.Getenv("TZ", "Local") + return f.ParseInLocation(layout, tz, value) +} + +// ParseInLocation - +func (TimeFuncs) ParseInLocation(layout, location string, value interface{}) (gotime.Time, error) { + loc, err := gotime.LoadLocation(location) + if err != nil { + return gotime.Time{}, err + } + return gotime.ParseInLocation(layout, conv.ToString(value), loc) +} + +// Now - +func (TimeFuncs) Now() gotime.Time { + return gotime.Now() +} + +// Unix - convert UNIX time (in seconds since the UNIX epoch) into a time.Time for further processing +// Takes a string or number (int or float) +func (TimeFuncs) Unix(in interface{}) (gotime.Time, error) { + sec, nsec, err := parseNum(in) + if err != nil { + return gotime.Time{}, err + } + return gotime.Unix(sec, nsec), nil +} + +// Nanosecond - +func (TimeFuncs) Nanosecond(n interface{}) gotime.Duration { + return gotime.Nanosecond * gotime.Duration(conv.ToInt64(n)) +} + +// Microsecond - +func (TimeFuncs) Microsecond(n interface{}) gotime.Duration { + return gotime.Microsecond * gotime.Duration(conv.ToInt64(n)) +} + +// Millisecond - +func (TimeFuncs) Millisecond(n interface{}) gotime.Duration { + return gotime.Millisecond * gotime.Duration(conv.ToInt64(n)) +} + +// Second - +func (TimeFuncs) Second(n interface{}) gotime.Duration { + return gotime.Second * gotime.Duration(conv.ToInt64(n)) +} + +// Minute - +func (TimeFuncs) Minute(n interface{}) gotime.Duration { + return gotime.Minute * gotime.Duration(conv.ToInt64(n)) +} + +// Hour - +func (TimeFuncs) Hour(n interface{}) gotime.Duration { + return gotime.Hour * gotime.Duration(conv.ToInt64(n)) +} + +// ParseDuration - +func (TimeFuncs) ParseDuration(n interface{}) (gotime.Duration, error) { + return gotime.ParseDuration(conv.ToString(n)) +} + +// Since - +func (TimeFuncs) Since(n gotime.Time) gotime.Duration { + return gotime.Since(n) +} + +// Until - +func (TimeFuncs) Until(n gotime.Time) gotime.Duration { + return gotime.Until(n) +} + +// convert a number input to a pair of int64s, representing the integer portion and the decimal remainder +// this can handle a string as well as any integer or float type +// precision is at the "nano" level (i.e. 1e+9) +func parseNum(in interface{}) (integral int64, fractional int64, err error) { + if s, ok := in.(string); ok { + ss := strings.Split(s, ".") + if len(ss) > 2 { + return 0, 0, fmt.Errorf("can not parse '%s' as a number - too many decimal points", s) + } + if len(ss) == 1 { + integral, err := strconv.ParseInt(s, 0, 64) + return integral, 0, err + } + integral, err := strconv.ParseInt(ss[0], 0, 64) + if err != nil { + return integral, 0, err + } + fractional, err = strconv.ParseInt(padRight(ss[1], "0", 9), 0, 64) + return integral, fractional, err + } + if s, ok := in.(fmt.Stringer); ok { + return parseNum(s.String()) + } + if i, ok := in.(int); ok { + return int64(i), 0, nil + } + if u, ok := in.(uint64); ok { + return int64(u), 0, nil + } + if f, ok := in.(float64); ok { + return 0, 0, fmt.Errorf("can not parse floating point number (%f) - use a string instead", f) + } + if in == nil { + return 0, 0, nil + } + return 0, 0, nil +} + +// pads a number with zeroes +func padRight(in, pad string, length int) string { + for { + in += pad + if len(in) > length { + return in[0:length] + } + } +} diff --git a/internal/funcs/time_test.go b/internal/funcs/time_test.go new file mode 100644 index 00000000..8b2d6e4a --- /dev/null +++ b/internal/funcs/time_test.go @@ -0,0 +1,82 @@ +package funcs + +import ( + "context" + "math" + "math/big" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCreateTimeFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateTimeFuncs(ctx) + actual := fmap["time"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*TimeFuncs).ctx) + }) + } +} + +func TestParseNum(t *testing.T) { + t.Parallel() + + i, f, _ := parseNum("42") + assert.Equal(t, int64(42), i) + assert.Equal(t, int64(0), f) + + i, f, _ = parseNum(42) + assert.Equal(t, int64(42), i) + assert.Equal(t, int64(0), f) + + i, f, _ = parseNum(big.NewInt(42)) + assert.Equal(t, int64(42), i) + assert.Equal(t, int64(0), f) + + i, f, _ = parseNum(big.NewFloat(42.0)) + assert.Equal(t, int64(42), i) + assert.Equal(t, int64(0), f) + + i, f, _ = parseNum(uint64(math.MaxInt64)) + assert.Equal(t, int64(uint64(math.MaxInt64)), i) + assert.Equal(t, int64(0), f) + + i, f, _ = parseNum("9223372036854775807.999999999") + assert.Equal(t, int64(9223372036854775807), i) + assert.Equal(t, int64(999999999), f) + + i, f, _ = parseNum("999999999999999.123456789123") + assert.Equal(t, int64(999999999999999), i) + assert.Equal(t, int64(123456789), f) + + i, f, _ = parseNum("123456.789") + assert.Equal(t, int64(123456), i) + assert.Equal(t, int64(789000000), f) + + _, _, err := parseNum("bogus.9223372036854775807") + assert.Error(t, err) + + _, _, err = parseNum("bogus") + assert.Error(t, err) + + _, _, err = parseNum("1.2.3") + assert.Error(t, err) + + _, _, err = parseNum(1.1) + assert.Error(t, err) + + i, f, err = parseNum(nil) + assert.Zero(t, i) + assert.Zero(t, f) + require.NoError(t, err) +} diff --git a/internal/funcs/uuid.go b/internal/funcs/uuid.go new file mode 100644 index 00000000..2ad9d874 --- /dev/null +++ b/internal/funcs/uuid.go @@ -0,0 +1,83 @@ +package funcs + +import ( + "context" + + "github.com/hairyhenderson/gomplate/v4/conv" + + "github.com/google/uuid" +) + +// UUIDNS - +// +// Deprecated: don't use +func UUIDNS() *UUIDFuncs { + return &UUIDFuncs{} +} + +// AddUUIDFuncs - +// +// Deprecated: use [CreateUUIDFuncs] instead +func AddUUIDFuncs(f map[string]interface{}) { + for k, v := range CreateUUIDFuncs(context.Background()) { + f[k] = v + } +} + +// CreateUUIDFuncs - +func CreateUUIDFuncs(ctx context.Context) map[string]interface{} { + ns := &UUIDFuncs{ctx} + return map[string]interface{}{ + "uuid": func() interface{} { return ns }, + } +} + +// UUIDFuncs - +type UUIDFuncs struct { + ctx context.Context +} + +// V1 - return a version 1 UUID (based on the current MAC Address and the +// current date/time). Use V4 instead in most cases. +func (UUIDFuncs) V1() (string, error) { + u, err := uuid.NewUUID() + if err != nil { + return "", err + } + return u.String(), nil +} + +// V4 - return a version 4 (random) UUID +func (UUIDFuncs) V4() (string, error) { + u, err := uuid.NewRandom() + if err != nil { + return "", err + } + return u.String(), nil +} + +// Nil - +func (UUIDFuncs) Nil() (string, error) { + return uuid.Nil.String(), nil +} + +// IsValid - checks if the given UUID is in the correct format. It does not +// validate whether the version or variant are correct. +func (f UUIDFuncs) IsValid(in interface{}) (bool, error) { + _, err := f.Parse(in) + return err == nil, nil +} + +// Parse - parse a UUID for further manipulation or inspection. +// +// Both the standard UUID forms of xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx and +// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx are decoded as well as the +// Microsoft encoding {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx} and the raw hex +// encoding: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx. +func (UUIDFuncs) Parse(in interface{}) (uuid.UUID, error) { + u, err := uuid.Parse(conv.ToString(in)) + if err != nil { + return uuid.Nil, err + } + return u, err +} diff --git a/internal/funcs/uuid_test.go b/internal/funcs/uuid_test.go new file mode 100644 index 00000000..c21981ea --- /dev/null +++ b/internal/funcs/uuid_test.go @@ -0,0 +1,116 @@ +package funcs + +import ( + "context" + "net/url" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCreateUUIDFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateUUIDFuncs(ctx) + actual := fmap["uuid"].(func() interface{}) + + assert.Equal(t, ctx, actual().(*UUIDFuncs).ctx) + }) + } +} + +const ( + uuidV1Pattern = "^[[:xdigit:]]{8}-[[:xdigit:]]{4}-1[[:xdigit:]]{3}-[89ab][[:xdigit:]]{3}-[[:xdigit:]]{12}$" + uuidV4Pattern = "^[[:xdigit:]]{8}-[[:xdigit:]]{4}-4[[:xdigit:]]{3}-[89ab][[:xdigit:]]{3}-[[:xdigit:]]{12}$" +) + +func TestV1(t *testing.T) { + t.Parallel() + + u := UUIDNS() + i, err := u.V1() + require.NoError(t, err) + assert.Regexp(t, uuidV1Pattern, i) +} + +func TestV4(t *testing.T) { + t.Parallel() + + u := UUIDNS() + i, err := u.V4() + require.NoError(t, err) + assert.Regexp(t, uuidV4Pattern, i) +} + +func TestNil(t *testing.T) { + t.Parallel() + + u := UUIDNS() + i, err := u.Nil() + require.NoError(t, err) + assert.Equal(t, "00000000-0000-0000-0000-000000000000", i) +} + +func TestIsValid(t *testing.T) { + t.Parallel() + + u := UUIDNS() + in := interface{}(false) + i, err := u.IsValid(in) + require.NoError(t, err) + assert.False(t, i) + + in = 12345 + i, err = u.IsValid(in) + require.NoError(t, err) + assert.False(t, i) + + testdata := []interface{}{ + "123456781234123412341234567890ab", + "12345678-1234-1234-1234-1234567890ab", + "urn:uuid:12345678-1234-1234-1234-1234567890ab", + "{12345678-1234-1234-1234-1234567890ab}", + } + + for _, d := range testdata { + i, err = u.IsValid(d) + require.NoError(t, err) + assert.True(t, i) + } +} + +func TestParse(t *testing.T) { + t.Parallel() + + u := UUIDNS() + in := interface{}(false) + _, err := u.Parse(in) + assert.Error(t, err) + + in = 12345 + _, err = u.Parse(in) + assert.Error(t, err) + + in = "12345678-1234-1234-1234-1234567890ab" + testdata := []interface{}{ + "123456781234123412341234567890ab", + "12345678-1234-1234-1234-1234567890ab", + "urn:uuid:12345678-1234-1234-1234-1234567890ab", + must(url.Parse("urn:uuid:12345678-1234-1234-1234-1234567890ab")), + "{12345678-1234-1234-1234-1234567890ab}", + } + + for _, d := range testdata { + uid, err := u.Parse(d) + require.NoError(t, err) + assert.Equal(t, in, uid.String()) + } +} |
