summaryrefslogtreecommitdiff
path: root/vault
diff options
context:
space:
mode:
authorDrew MacInnis <Drew.MacInnis@qlik.com>2016-11-17 22:43:06 -0500
committerDrew MacInnis <Drew.MacInnis@qlik.com>2016-11-19 12:12:20 -0500
commitdbdd898f16a873d785f95d78a2d2a1a999f9882f (patch)
treecef6c89ce0d8f04df8500ee546551617a6eddf7c /vault
parent3d0c34fd61d758e9b61cdbfc6a950b9e378f807f (diff)
Handle vault redirects
Diffstat (limited to 'vault')
-rw-r--r--vault/app-id_strategy.go22
-rw-r--r--vault/client.go49
-rw-r--r--vault/http.go47
-rw-r--r--vault/http_test.go68
4 files changed, 163 insertions, 23 deletions
diff --git a/vault/app-id_strategy.go b/vault/app-id_strategy.go
index 2d8b04b5..06603f0b 100644
--- a/vault/app-id_strategy.go
+++ b/vault/app-id_strategy.go
@@ -28,20 +28,34 @@ func NewAppIDAuthStrategy() *AppIDAuthStrategy {
return nil
}
-// GetToken - log in to the app-id auth backend and return the client token
-func (a *AppIDAuthStrategy) GetToken(addr *url.URL) (string, error) {
+// GetHTTPClient configures the HTTP client with a timeout
+func (a *AppIDAuthStrategy) GetHTTPClient() *http.Client {
if a.hc == nil {
a.hc = &http.Client{Timeout: time.Second * 5}
}
- client := a.hc
+ return a.hc
+}
+
+// SetToken is a no-op for AppIDAuthStrategy as a token hasn't been acquired yet
+func (a *AppIDAuthStrategy) SetToken(req *http.Request) {
+ // no-op
+}
+// Do wraps http.Client.Do
+func (a *AppIDAuthStrategy) Do(req *http.Request) (*http.Response, error) {
+ hc := a.GetHTTPClient()
+ return hc.Do(req)
+}
+
+// GetToken - log in to the app-id auth backend and return the client token
+func (a *AppIDAuthStrategy) GetToken(addr *url.URL) (string, error) {
buf := new(bytes.Buffer)
json.NewEncoder(buf).Encode(&a)
u := &url.URL{}
*u = *addr
u.Path = "/v1/auth/app-id/login"
- res, err := client.Post(u.String(), "application/json; charset=utf-8", buf)
+ res, err := requestAndFollow(a, "POST", u, buf.Bytes())
if err != nil {
return "", err
}
diff --git a/vault/client.go b/vault/client.go
index 3a8cad65..7e80fa27 100644
--- a/vault/client.go
+++ b/vault/client.go
@@ -54,6 +54,31 @@ func getAuthStrategy() AuthStrategy {
return nil
}
+// GetHTTPClient returns a client configured w/X-Vault-Token header
+func (c *Client) GetHTTPClient() *http.Client {
+ if c.hc == nil {
+ c.hc = &http.Client{
+ Timeout: time.Second * 5,
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ c.SetToken(req)
+ return nil
+ },
+ }
+ }
+ return c.hc
+}
+
+// SetToken adds an X-Vault-Token header to the request
+func (c *Client) SetToken(req *http.Request) {
+ req.Header.Set("X-Vault-Token", c.token)
+}
+
+// Do wraps http.Client.Do
+func (c *Client) Do(req *http.Request) (*http.Response, error) {
+ hc := c.GetHTTPClient()
+ return hc.Do(req)
+}
+
// Login - log in to Vault with the discovered auth backend and save the token
func (c *Client) Login() error {
token, err := c.Auth.GetToken(c.Addr)
@@ -72,17 +97,12 @@ func (c *Client) RevokeToken() {
return
}
- if c.hc == nil {
- c.hc = &http.Client{Timeout: time.Second * 5}
- }
-
u := &url.URL{}
*u = *c.Addr
u.Path = "/v1/auth/token/revoke-self"
- req, _ := http.NewRequest("POST", u.String(), nil)
- req.Header.Set("X-Vault-Token", c.token)
- res, err := c.hc.Do(req)
+ res, err := requestAndFollow(c, "POST", u, nil)
+
if err != nil {
log.Println("Error while revoking Vault Token", err)
}
@@ -94,24 +114,15 @@ func (c *Client) RevokeToken() {
func (c *Client) Read(path string) ([]byte, error) {
path = normalizeURLPath(path)
- if c.hc == nil {
- c.hc = &http.Client{Timeout: time.Second * 5}
- }
u := &url.URL{}
*u = *c.Addr
u.Path = "/v1" + path
- req, err := http.NewRequest("GET", u.String(), nil)
- if err != nil {
- return nil, err
- }
- req.Header.Set("X-Vault-Token", c.token)
- res, err := c.hc.Do(req)
+ res, err := requestAndFollow(c, "GET", u, nil)
if err != nil {
return nil, err
}
-
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
@@ -119,7 +130,7 @@ func (c *Client) Read(path string) ([]byte, error) {
}
if res.StatusCode != 200 {
- err = fmt.Errorf("Unexpected HTTP status %d on Read from %s: %s", res.StatusCode, u, body)
+ err = fmt.Errorf("Unexpected HTTP status %d on Read from %s: %s", res.StatusCode, path, body)
return nil, err
}
@@ -131,7 +142,7 @@ func (c *Client) Read(path string) ([]byte, error) {
}
if _, ok := response["data"]; !ok {
- return nil, fmt.Errorf("Unexpected HTTP body on Read for %s: %s", u, body)
+ return nil, fmt.Errorf("Unexpected HTTP body on Read for %s: %s", path, body)
}
return json.Marshal(response["data"])
diff --git a/vault/http.go b/vault/http.go
new file mode 100644
index 00000000..f8335796
--- /dev/null
+++ b/vault/http.go
@@ -0,0 +1,47 @@
+package vault
+
+import (
+ "bytes"
+ "net/http"
+ "net/url"
+)
+
+// httpClient
+type httpClient interface {
+ GetHTTPClient() *http.Client
+ SetToken(req *http.Request)
+ Do(req *http.Request) (*http.Response, error)
+}
+
+func requestAndFollow(hc httpClient, method string, u *url.URL, body []byte) (*http.Response, error) {
+ var res *http.Response
+ var err error
+ for attempts := 0; attempts < 2; attempts++ {
+ reader := bytes.NewReader(body)
+ req, err := http.NewRequest(method, u.String(), reader)
+
+ if err != nil {
+ return nil, err
+ }
+ hc.SetToken(req)
+ if method == "POST" {
+ req.Header.Set("Content-Type", "application/json; charset=utf-8")
+ }
+
+ res, err = hc.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ if res.StatusCode == http.StatusTemporaryRedirect {
+ res.Body.Close()
+ location, errLocation := res.Location()
+ if errLocation != nil {
+ return nil, errLocation
+ }
+ u.Host = location.Host
+ } else {
+ break
+ }
+ }
+ return res, err
+}
diff --git a/vault/http_test.go b/vault/http_test.go
new file mode 100644
index 00000000..e2ee4429
--- /dev/null
+++ b/vault/http_test.go
@@ -0,0 +1,68 @@
+package vault
+
+import (
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+type testClient struct{}
+
+func (tc *testClient) GetHTTPClient() *http.Client {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ reqStr := fmt.Sprintf("%s %s", r.Method, r.URL)
+ switch reqStr {
+ case "POST http://vaultA:8500/v1/foo":
+ w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("Location", "http://vaultB:8500/v1/foo")
+ w.WriteHeader(http.StatusTemporaryRedirect)
+ fmt.Fprintln(w, "{}")
+ case "POST http://vaultB:8500/v1/foo":
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ fmt.Fprintln(w, "{}")
+ default:
+ w.WriteHeader(http.StatusInternalServerError)
+ fmt.Fprintf(w, "{ 'message': 'Unexpected request: %s'}", reqStr)
+ }
+ }))
+ return &http.Client{
+ Transport: &http.Transport{
+ Proxy: func(req *http.Request) (*url.URL, error) {
+ return url.Parse(server.URL)
+ },
+ },
+ }
+}
+
+func (tc *testClient) SetToken(req *http.Request) {
+ req.Header.Set("X-Vault-Token", "dead-beef-cafe-babe")
+}
+
+func (tc *testClient) Do(req *http.Request) (*http.Response, error) {
+ hc := tc.GetHTTPClient()
+ return hc.Do(req)
+}
+
+func TestRequestAndFollow_GetWithRedirect(t *testing.T) {
+ tc := &testClient{}
+ u, _ := url.Parse("http://vaultA:8500/v1/foo")
+
+ res, err := requestAndFollow(tc, "POST", u, nil)
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, res.StatusCode)
+
+}
+
+func TestRequestAndFollow_GetNoRedirect(t *testing.T) {
+ tc := &testClient{}
+ u, _ := url.Parse("http://vaultB:8500/v1/foo")
+
+ res, err := requestAndFollow(tc, "POST", u, nil)
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, res.StatusCode)
+}