From 1da91051d43152c23e4df62f2082e44300986fdc Mon Sep 17 00:00:00 2001 From: Dave Henderson Date: Sun, 17 Nov 2024 10:55:34 -0500 Subject: feat(fs): Support Vault AWS IAM auth (#2264) Signed-off-by: Dave Henderson --- internal/datafs/fsys.go | 2 + internal/datafs/vaultauth.go | 35 +- .../integration/datasources_vault_ec2_test.go | 61 +--- .../integration/datasources_vault_iam_test.go | 56 ++++ .../tests/integration/datasources_vault_test.go | 2 +- internal/tests/integration/test_ec2_utils.go | 251 -------------- internal/tests/integration/test_ec2_utils_test.go | 370 +++++++++++++++++++++ 7 files changed, 460 insertions(+), 317 deletions(-) create mode 100644 internal/tests/integration/datasources_vault_iam_test.go delete mode 100644 internal/tests/integration/test_ec2_utils.go create mode 100644 internal/tests/integration/test_ec2_utils_test.go diff --git a/internal/datafs/fsys.go b/internal/datafs/fsys.go index f7894b32..03abdf6d 100644 --- a/internal/datafs/fsys.go +++ b/internal/datafs/fsys.go @@ -101,6 +101,8 @@ func FSysForPath(ctx context.Context, path string) (fs.FS, error) { fsys = vaultauth.WithAuthMethod(compositeVaultAuthMethod(fileFsys), fsys) } + fsys = fsimpl.WithContextFS(ctx, fsys) + return fsys, nil } diff --git a/internal/datafs/vaultauth.go b/internal/datafs/vaultauth.go index 3ad733a0..83b9daa5 100644 --- a/internal/datafs/vaultauth.go +++ b/internal/datafs/vaultauth.go @@ -20,15 +20,12 @@ func compositeVaultAuthMethod(envFsys fs.FS) api.AuthMethod { return vaultauth.CompositeAuthMethod( vaultauth.EnvAuthMethod(), envEC2AuthAdapter(envFsys), + envIAMAuthAdapter(envFsys), ) } -// func CompositeVaultAuthMethod() api.AuthMethod { -// return compositeVaultAuthMethod(WrapWdFS(osfs.NewFS())) -// } - // envEC2AuthAdapter builds an AWS EC2 authentication method from environment -// variables, for use only with [CompositeVaultAuthMethod] +// variables, for use only with [compositeVaultAuthMethod] func envEC2AuthAdapter(envFS fs.FS) api.AuthMethod { mountPath := GetenvFsys(envFS, "VAULT_AUTH_AWS_MOUNT", "aws") @@ -61,8 +58,34 @@ func envEC2AuthAdapter(envFS fs.FS) api.AuthMethod { return &ec2AuthNonceWriter{AWSAuth: awsauth, nonce: nonce, output: output} } +// envIAMAuthAdapter builds an AWS IAM authentication method from environment +// variables, for use only with [compositeVaultAuthMethod] +func envIAMAuthAdapter(envFS fs.FS) api.AuthMethod { + mountPath := GetenvFsys(envFS, "VAULT_AUTH_AWS_MOUNT", "aws") + role := GetenvFsys(envFS, "VAULT_AUTH_AWS_ROLE") + + // temporary workaround while we wait to deprecate AWS_META_ENDPOINT + if endpoint := os.Getenv("AWS_META_ENDPOINT"); endpoint != "" { + deprecated.WarnDeprecated(context.Background(), "Use AWS_EC2_METADATA_SERVICE_ENDPOINT instead of AWS_META_ENDPOINT") + if os.Getenv("AWS_EC2_METADATA_SERVICE_ENDPOINT") == "" { + os.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", endpoint) + } + } + + awsauth, err := aws.NewAWSAuth( + aws.WithIAMAuth(), + aws.WithMountPath(mountPath), + aws.WithRole(role), + ) + if err != nil { + return nil + } + + return awsauth +} + // ec2AuthNonceWriter - wraps an AWSAuth, and writes the nonce to the nonce -// output file +// output file - only for ec2 auth type ec2AuthNonceWriter struct { *aws.AWSAuth nonce string diff --git a/internal/tests/integration/datasources_vault_ec2_test.go b/internal/tests/integration/datasources_vault_ec2_test.go index c00a8d43..4b57020d 100644 --- a/internal/tests/integration/datasources_vault_ec2_test.go +++ b/internal/tests/integration/datasources_vault_ec2_test.go @@ -4,71 +4,14 @@ package integration import ( - "encoding/pem" - "io" - "net/http" - "net/http/httptest" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gotest.tools/v3/fs" ) -func setupDatasourcesVaultEc2Test(t *testing.T) (*fs.Dir, *vaultClient, *httptest.Server, []byte) { - t.Helper() - - priv, der, _ := certificateGenerate() - cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) - - mux := http.NewServeMux() - mux.HandleFunc("/latest/dynamic/instance-identity/pkcs7", pkcsHandler(priv, der)) - mux.HandleFunc("/latest/dynamic/instance-identity/document", instanceDocumentHandler) - mux.HandleFunc("/latest/api/token", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var b []byte - if r.Body != nil { - var err error - b, err = io.ReadAll(r.Body) - if !assert.NoError(t, err) { - w.WriteHeader(http.StatusInternalServerError) - return - } - defer r.Body.Close() - } - t.Logf("IMDS Token request: %s %s: %s", r.Method, r.URL, b) - - w.Write([]byte("testtoken")) - })) - mux.HandleFunc("/latest/meta-data/instance-id", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - t.Logf("IMDS request: %s %s", r.Method, r.URL) - w.Write([]byte("i-00000000")) - })) - mux.HandleFunc("/sts/", stsHandler) - mux.HandleFunc("/ec2/", ec2Handler) - mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - t.Logf("unhandled request: %s %s", r.Method, r.URL) - w.WriteHeader(http.StatusNotFound) - })) - - srv := httptest.NewServer(mux) - t.Cleanup(srv.Close) - - tmpDir, v := startVault(t) - - err := v.vc.Sys().PutPolicy("writepol", `path "*" { - policy = "write" -}`) - require.NoError(t, err) - err = v.vc.Sys().PutPolicy("readpol", `path "*" { - policy = "read" -}`) - require.NoError(t, err) - - return tmpDir, v, srv, cert -} - func TestDatasources_VaultEc2(t *testing.T) { - tmpDir, v, srv, cert := setupDatasourcesVaultEc2Test(t) + accountID, user := "1", "Test" + tmpDir, v, srv, cert := setupDatasourcesVaultAWSTest(t, accountID, user) v.vc.Logical().Write("secret/foo", map[string]interface{}{"value": "bar"}) defer v.vc.Logical().Delete("secret/foo") diff --git a/internal/tests/integration/datasources_vault_iam_test.go b/internal/tests/integration/datasources_vault_iam_test.go new file mode 100644 index 00000000..0cee8d3d --- /dev/null +++ b/internal/tests/integration/datasources_vault_iam_test.go @@ -0,0 +1,56 @@ +//go:build !windows +// +build !windows + +package integration + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDatasources_VaultIAM(t *testing.T) { + accountID := "000000000000" + user := "foo" + + tmpDir, v, srv, _ := setupDatasourcesVaultAWSTest(t, accountID, user) + + v.vc.Logical().Write("secret/foo", map[string]interface{}{"value": "bar"}) + defer v.vc.Logical().Delete("secret/foo") + + err := v.vc.Sys().EnableAuth("aws", "aws", "") + require.NoError(t, err) + defer v.vc.Sys().DisableAuth("aws") + + endpoint := srv.URL + + accessKeyID := "secret" + secretAccessKey := "access" + + _, err = v.vc.Logical().Write("auth/aws/config/client", map[string]interface{}{ + "access_key": accessKeyID, + "secret_key": secretAccessKey, + "endpoint": endpoint, + "iam_endpoint": endpoint + "/iam", + "sts_endpoint": endpoint + "/sts", + "sts_region": "us-east-1", + }) + require.NoError(t, err) + + _, err = v.vc.Logical().Write("auth/aws/role/foo", map[string]interface{}{ + "auth_type": "iam", + "bound_iam_principal_arn": "arn:aws:iam::" + accountID + ":*", + "policies": "readpol", + "max_ttl": "5m", + }) + require.NoError(t, err) + + o, e, err := cmd(t, "-d", "vault=vault:///secret/", + "-i", `{{(ds "vault" "foo").value}}`). + withEnv("HOME", tmpDir.Join("home")). + withEnv("VAULT_ADDR", "http://"+v.addr). + withEnv("AWS_ACCESS_KEY_ID", accessKeyID). + withEnv("AWS_SECRET_ACCESS_KEY", secretAccessKey). + run() + assertSuccess(t, o, e, err, "bar") +} diff --git a/internal/tests/integration/datasources_vault_test.go b/internal/tests/integration/datasources_vault_test.go index 5fd0124c..74a2512b 100644 --- a/internal/tests/integration/datasources_vault_test.go +++ b/internal/tests/integration/datasources_vault_test.go @@ -69,7 +69,7 @@ func startVault(t *testing.T) (*fs.Dir, *vaultClient) { "-dev", "-dev-root-token-id="+vaultRootToken, "-dev-kv-v1", // default to v1, so we can test v1 and v2 - "-log-level=err", + "-log-level=info", "-dev-listen-address="+vaultAddr, "-config="+tmpDir.Join("config.json"), ) diff --git a/internal/tests/integration/test_ec2_utils.go b/internal/tests/integration/test_ec2_utils.go deleted file mode 100644 index f7ad1eaa..00000000 --- a/internal/tests/integration/test_ec2_utils.go +++ /dev/null @@ -1,251 +0,0 @@ -package integration - -import ( - "bytes" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "log" - "math/big" - "net/http" - "time" - - "github.com/fullsailor/pkcs7" -) - -const instanceDocument = `{ - "devpayProductCodes" : null, - "availabilityZone" : "xx-test-1b", - "privateIp" : "10.1.2.3", - "version" : "2010-08-31", - "instanceId" : "i-00000000000000000", - "billingProducts" : null, - "instanceType" : "t2.micro", - "accountId" : "1", - "imageId" : "ami-00000000", - "pendingTime" : "2000-00-01T0:00:00Z", - "architecture" : "x86_64", - "kernelId" : null, - "ramdiskId" : null, - "region" : "xx-test-1" -}` - -func instanceDocumentHandler(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/json") - _, err := w.Write([]byte(instanceDocument)) - if err != nil { - w.WriteHeader(500) - } -} - -func certificateGenerate() (priv *rsa.PrivateKey, derBytes []byte, err error) { - priv, err = rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - log.Fatalf("failed to generate private key: %s", err) - } - - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - log.Fatalf("failed to generate serial number: %s", err) - } - - template := x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - Organization: []string{"Test"}, - }, - NotBefore: time.Now().Add(-24 * time.Hour), - NotAfter: time.Now().Add(365 * 24 * time.Hour), - } - - derBytes, err = x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) - if err != nil { - log.Fatalf("Failed to create certificate: %s", err) - } - - return priv, derBytes, err -} - -func pkcsHandler(priv *rsa.PrivateKey, derBytes []byte) func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, _ *http.Request) { - cert, err := x509.ParseCertificate(derBytes) - if err != nil { - log.Fatalf("Cannot decode certificate: %s", err) - } - - // Initialize a SignedData struct with content to be signed - signedData, err := pkcs7.NewSignedData([]byte(instanceDocument)) - if err != nil { - log.Fatalf("Cannot initialize signed data: %s", err) - } - - // Add the signing cert and private key - if err = signedData.AddSigner(cert, priv, pkcs7.SignerInfoConfig{}); err != nil { - log.Fatalf("Cannot add signer: %s", err) - } - - // Finish() to obtain the signature bytes - detachedSignature, err := signedData.Finish() - if err != nil { - log.Fatalf("Cannot finish signing data: %s", err) - } - - encoded := pem.EncodeToMemory(&pem.Block{Type: "PKCS7", Bytes: detachedSignature}) - - encoded = bytes.TrimPrefix(encoded, []byte("-----BEGIN PKCS7-----\n")) - encoded = bytes.TrimSuffix(encoded, []byte("\n-----END PKCS7-----\n")) - - w.Header().Set("Content-Type", "text/plain") - _, err = w.Write(encoded) - if err != nil { - w.WriteHeader(500) - } - } -} - -func stsHandler(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/xml") - _, err := w.Write([]byte(` - - arn:aws:iam::1:user/Test - AKIAI44QH8DHBEXAMPLE - 1 - - - 01234567-89ab-cdef-0123-456789abcdef - -`)) - if err != nil { - w.WriteHeader(500) - } -} - -func ec2Handler(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/xml") - _, err := w.Write([]byte(` - 8f7724cf-496f-496e-8fe3-example - - - r-1234567890abcdef0 - 123456789012 - - - - i-00000000000000000 - ami-00000000 - - 16 - running - - ip-192-168-1-88.eu-west-1.compute.internal - ec2-54-194-252-215.eu-west-1.compute.amazonaws.com - - my_keypair - 0 - - t2.micro - 2015-12-22T10:44:05.000Z - - eu-west-1c - - default - - - disabled - - subnet-56f5f633 - vpc-11112222 - 192.168.1.88 - 54.194.252.215 - true - - - sg-e4076980 - SecurityGroup1 - - - x86_64 - ebs - /dev/xvda - - - /dev/xvda - - vol-1234567890abcdef0 - attached - 2015-12-22T10:44:09.000Z - true - - - - hvm - xMcwG14507example - - - Name - Server_1 - - - xen - - - eni-551ba033 - subnet-56f5f633 - vpc-11112222 - Primary network interface - 123456789012 - in-use - 02:dd:2c:5e:01:69 - 192.168.1.88 - ip-192-168-1-88.eu-west-1.compute.internal - true - - - sg-e4076980 - SecurityGroup1 - - - - eni-attach-39697adc - 0 - attached - 2015-12-22T10:44:05.000Z - true - - - 54.194.252.215 - ec2-54-194-252-215.eu-west-1.compute.amazonaws.com - amazon - - - - 192.168.1.88 - ip-192-168-1-88.eu-west-1.compute.internal - true - - 54.194.252.215 - ec2-54-194-252-215.eu-west-1.compute.amazonaws.com - amazon - - - - - - 2001:db8:1234:1a2b::123 - - - - - false - - - - -`)) - if err != nil { - w.WriteHeader(500) - } -} diff --git a/internal/tests/integration/test_ec2_utils_test.go b/internal/tests/integration/test_ec2_utils_test.go new file mode 100644 index 00000000..29b92350 --- /dev/null +++ b/internal/tests/integration/test_ec2_utils_test.go @@ -0,0 +1,370 @@ +//go:build !windows +// +build !windows + +package integration + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "io" + "log" + "math/big" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/fullsailor/pkcs7" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gotest.tools/v3/fs" +) + +const instanceDocument = `{ + "devpayProductCodes" : null, + "availabilityZone" : "xx-test-1b", + "privateIp" : "10.1.2.3", + "version" : "2010-08-31", + "instanceId" : "i-00000000000000000", + "billingProducts" : null, + "instanceType" : "t2.micro", + "accountId" : "1", + "imageId" : "ami-00000000", + "pendingTime" : "2000-00-01T0:00:00Z", + "architecture" : "x86_64", + "kernelId" : null, + "ramdiskId" : null, + "region" : "xx-test-1" +}` + +func instanceDocumentHandler(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, err := w.Write([]byte(instanceDocument)) + if err != nil { + w.WriteHeader(500) + } +} + +func certificateGenerate() (priv *rsa.PrivateKey, derBytes []byte, err error) { + priv, err = rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + log.Fatalf("failed to generate private key: %s", err) + } + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + log.Fatalf("failed to generate serial number: %s", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Test"}, + }, + NotBefore: time.Now().Add(-24 * time.Hour), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + } + + derBytes, err = x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + log.Fatalf("Failed to create certificate: %s", err) + } + + return priv, derBytes, err +} + +func pkcsHandler(priv *rsa.PrivateKey, derBytes []byte) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, _ *http.Request) { + cert, err := x509.ParseCertificate(derBytes) + if err != nil { + log.Fatalf("Cannot decode certificate: %s", err) + } + + // Initialize a SignedData struct with content to be signed + signedData, err := pkcs7.NewSignedData([]byte(instanceDocument)) + if err != nil { + log.Fatalf("Cannot initialize signed data: %s", err) + } + + // Add the signing cert and private key + if err = signedData.AddSigner(cert, priv, pkcs7.SignerInfoConfig{}); err != nil { + log.Fatalf("Cannot add signer: %s", err) + } + + // Finish() to obtain the signature bytes + detachedSignature, err := signedData.Finish() + if err != nil { + log.Fatalf("Cannot finish signing data: %s", err) + } + + encoded := pem.EncodeToMemory(&pem.Block{Type: "PKCS7", Bytes: detachedSignature}) + + encoded = bytes.TrimPrefix(encoded, []byte("-----BEGIN PKCS7-----\n")) + encoded = bytes.TrimSuffix(encoded, []byte("\n-----END PKCS7-----\n")) + + w.Header().Set("Content-Type", "text/plain") + _, err = w.Write(encoded) + if err != nil { + w.WriteHeader(500) + } + } +} + +func stsHandler(t *testing.T, accountID, user string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + + form, _ := url.ParseQuery(string(body)) + + // action must be GetCallerIdentity + assert.Equal(t, "GetCallerIdentity", form.Get("Action")) + + w.Header().Set("Content-Type", "text/xml") + _, err := w.Write([]byte(fmt.Sprintf(` + + + arn:aws:iam::%[1]s:user/%[2]s + AKIAI44QH8DHBEXAMPLE + %[1]s + + + 01234567-89ab-cdef-0123-456789abcdef + + `, accountID, user))) + if err != nil { + t.Errorf("failed to write response: %s", err) + w.WriteHeader(http.StatusInternalServerError) + } + assert.NoError(t, err) + }) +} + +func ec2Handler(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/xml") + _, err := w.Write([]byte(` + 8f7724cf-496f-496e-8fe3-example + + + r-1234567890abcdef0 + 123456789012 + + + + i-00000000000000000 + ami-00000000 + + 16 + running + + ip-192-168-1-88.eu-west-1.compute.internal + ec2-54-194-252-215.eu-west-1.compute.amazonaws.com + + my_keypair + 0 + + t2.micro + 2015-12-22T10:44:05.000Z + + eu-west-1c + + default + + + disabled + + subnet-56f5f633 + vpc-11112222 + 192.168.1.88 + 54.194.252.215 + true + + + sg-e4076980 + SecurityGroup1 + + + x86_64 + ebs + /dev/xvda + + + /dev/xvda + + vol-1234567890abcdef0 + attached + 2015-12-22T10:44:09.000Z + true + + + + hvm + xMcwG14507example + + + Name + Server_1 + + + xen + + + eni-551ba033 + subnet-56f5f633 + vpc-11112222 + Primary network interface + 123456789012 + in-use + 02:dd:2c:5e:01:69 + 192.168.1.88 + ip-192-168-1-88.eu-west-1.compute.internal + true + + + sg-e4076980 + SecurityGroup1 + + + + eni-attach-39697adc + 0 + attached + 2015-12-22T10:44:05.000Z + true + + + 54.194.252.215 + ec2-54-194-252-215.eu-west-1.compute.amazonaws.com + amazon + + + + 192.168.1.88 + ip-192-168-1-88.eu-west-1.compute.internal + true + + 54.194.252.215 + ec2-54-194-252-215.eu-west-1.compute.amazonaws.com + amazon + + + + + + 2001:db8:1234:1a2b::123 + + + + + false + + + + +`)) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + } +} + +func iamGetUserHandler(t *testing.T, accountID string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + form, _ := url.ParseQuery(string(body)) + + // action must be GetUser + assert.Equal(t, "GetUser", form.Get("Action")) + + w.Header().Set("Content-Type", "text/xml") + _, err := w.Write([]byte(fmt.Sprintf(` + + + + / + %[1]s + m3o9qmhhl9dnjlh2fflg + arn:aws:iam::%[2]s:user/%[1]s + 2024-07-21T17:21:27.259000Z + + + + 3d0e2445-64ea-4bfb-9244-30d810773f9e + + `, form.Get("UserName"), accountID))) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + assert.NoError(t, err) + }) +} + +func setupDatasourcesVaultAWSTest(t *testing.T, accountID, user string) (*fs.Dir, *vaultClient, *httptest.Server, []byte) { + t.Helper() + + priv, der, _ := certificateGenerate() + cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + + mux := http.NewServeMux() + mux.HandleFunc("/latest/dynamic/instance-identity/pkcs7", pkcsHandler(priv, der)) + mux.HandleFunc("/latest/dynamic/instance-identity/document", instanceDocumentHandler) + mux.HandleFunc("/latest/api/token", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var b []byte + if r.Body != nil { + var err error + b, err = io.ReadAll(r.Body) + if !assert.NoError(t, err) { + w.WriteHeader(http.StatusInternalServerError) + return + } + defer r.Body.Close() + } + t.Logf("IMDS Token request: %s %s: %s", r.Method, r.URL, b) + + w.Write([]byte("testtoken")) + })) + mux.HandleFunc("/latest/meta-data/instance-id", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("IMDS request: %s %s", r.Method, r.URL) + w.Write([]byte("i-00000000")) + })) + mux.Handle("/sts/", stsHandler(t, accountID, user)) + mux.Handle("/iam/", iamGetUserHandler(t, accountID)) + mux.HandleFunc("/ec2/", ec2Handler) + mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("unhandled request: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + })) + + // Vault sends requests to "/sts///" for some reason, and the ServeMux + // responds by redirecting to "/sts/" which Vault rejects. So we need to + // handle the extra slashes in a middleware first. + stripSlashes := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for strings.HasSuffix(r.URL.Path, "//") { + r.URL.Path = r.URL.Path[:len(r.URL.Path)-1] + } + mux.ServeHTTP(w, r) + }) + + srv := httptest.NewServer(stripSlashes) + t.Cleanup(srv.Close) + + tmpDir, v := startVault(t) + + err := v.vc.Sys().PutPolicy("writepol", `path "*" { + policy = "write" +}`) + require.NoError(t, err) + err = v.vc.Sys().PutPolicy("readpol", `path "*" { + policy = "read" +}`) + require.NoError(t, err) + + return tmpDir, v, srv, cert +} -- cgit v1.2.3