summary | refs | log | tree | commit | diff
path: root/cmd/ask_pass_test.go
blob: f9d828f6890da052c0180418020c485b6c185909 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
package main

import (
	"bytes"
	"context"
	"fmt"
	"net"
	"os"
	"strings"
	"testing"

	"github.com/argoproj/argo-cd/v2/reposerver/askpass"
	"github.com/spf13/cobra"
	"github.com/stretchr/testify/assert"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/test/bufconn"
)

// bufSize is the buffer size (1 MiB) handed to bufconn.Listen when the
// in-memory listener is created.
const bufSize = 1 << 20

// lis is the bufconn listener shared by the test fixtures: init creates it and
// serves the mock gRPC server on it, and bufDialer dials it for the client side.
var lis *bufconn.Listener

// init creates the package-level bufconn listener, registers the mock ask-pass
// service on a fresh gRPC server and starts serving on that listener in a
// background goroutine.
func init() {
	lis = bufconn.Listen(bufSize)

	server := grpc.NewServer()
	askpass.RegisterAskPassServiceServer(server, new(mockAskPassServer))

	// The error returned by Serve is intentionally discarded.
	go func() { _ = server.Serve(lis) }()
}

// mockAskPassServer is the ask-pass gRPC service implementation used by the
// tests in this file. It embeds askpass.UnimplementedAskPassServiceServer and
// overrides GetCredentials to return a fixed username and password.
type mockAskPassServer struct {
	askpass.UnimplementedAskPassServiceServer
}

// GetCredentials ignores the request and always answers with the fixed
// credentials "testuser" / "testpassword" and a nil error.
func (m *mockAskPassServer) GetCredentials(ctx context.Context, req *askpass.CredentialsRequest) (*askpass.CredentialsResponse, error) {
	fixed := &askpass.CredentialsResponse{
		Username: "testuser",
		Password: "testpassword",
	}
	return fixed, nil
}

// bufDialer is a gRPC context dialer that ignores its context and address
// arguments and always connects to the package-level bufconn listener.
func bufDialer(_ context.Context, _ string) (net.Conn, error) {
	conn, err := lis.Dial()
	return conn, err
}

// NewTestCommand builds the command returned by NewAskPassCommand and replaces
// its Run handler with one that reaches the gRPC service through bufDialer
// (i.e. the package-level bufconn listener) rather than a real endpoint.
//
// The handler expects exactly one argument and reads the nonce from the
// askpass.ASKPASS_NONCE_ENV environment variable. It then requests credentials
// and prints the username or the password to the command's stdout, depending on
// whether the argument starts with "Username" or "Password". Every failure is
// written to the command's stderr, after which the handler simply returns.
func NewTestCommand() *cobra.Command {
	command := NewAskPassCommand()
	command.Run = func(c *cobra.Command, args []string) {
		// report writes a formatted diagnostic to the command's error stream.
		report := func(format string, a ...any) {
			fmt.Fprintf(c.ErrOrStderr(), format, a...)
		}

		ctx := c.Context()
		if argc := len(args); argc != 1 {
			report("expected 1 argument, got %d\n", argc)
			return
		}

		nonce := os.Getenv(askpass.ASKPASS_NONCE_ENV)
		if nonce == "" {
			report("%s is not set\n", askpass.ASKPASS_NONCE_ENV)
			return
		}

		dialOpts := []grpc.DialOption{
			grpc.WithContextDialer(bufDialer),
			grpc.WithTransportCredentials(insecure.NewCredentials()),
		}
		// nolint:staticcheck
		conn, err := grpc.DialContext(ctx, "bufnet", dialOpts...)
		if err != nil {
			report("failed to connect: %v\n", err)
			return
		}
		defer conn.Close()

		request := &askpass.CredentialsRequest{Nonce: nonce}
		creds, err := askpass.NewAskPassServiceClient(conn).GetCredentials(ctx, request)
		if err != nil {
			report("failed to get credentials: %v\n", err)
			return
		}

		// The argument is matched by prefix to decide which credential to print.
		switch prompt := args[0]; {
		case strings.HasPrefix(prompt, "Username"):
			fmt.Fprintln(c.OutOrStdout(), creds.Username)
		case strings.HasPrefix(prompt, "Password"):
			fmt.Fprintln(c.OutOrStdout(), creds.Password)
		default:
			report("unknown credential type '%s'\n", prompt)
		}
	}
	return command
}

// TestNewAskPassCommand runs the command returned by NewTestCommand against a
// table of argument/nonce combinations and checks what it writes to stdout and
// stderr.
//
// For each case:
//   - expectedOut, when non-empty, must equal the trimmed stdout;
//   - expectedErr, when non-empty, must be contained in stderr; otherwise
//     Execute must return no error.
func TestNewAskPassCommand(t *testing.T) {
	testCases := []struct {
		name        string
		args        []string
		envNonce    string
		expectedOut string
		expectedErr string
	}{
		{
			name: "no arguments",
			// Kept as a non-nil empty slice: cobra falls back to os.Args when
			// the slice passed to SetArgs is nil.
			args:        []string{},
			envNonce:    "testnonce",
			expectedErr: "expected 1 argument, got 0",
		},
		{
			name:        "missing nonce",
			args:        []string{"Username"},
			expectedErr: fmt.Sprintf("%s is not set", askpass.ASKPASS_NONCE_ENV),
		},
		{
			name:        "valid username request",
			args:        []string{"Username"},
			envNonce:    "testnonce",
			expectedOut: "testuser",
		},
		{
			name:        "valid password request",
			args:        []string{"Password"},
			envNonce:    "testnonce",
			expectedOut: "testpassword",
		},
		{
			name:        "unknown credential type",
			args:        []string{"Unknown"},
			envNonce:    "testnonce",
			expectedErr: "unknown credential type 'Unknown'",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// t.Setenv scopes the nonce to this subtest and restores the previous
			// value when it finishes, so the rest of the process environment is
			// left intact for other tests. The command only checks the variable
			// for the empty string, so setting it to "" exercises the "not set"
			// path.
			t.Setenv(askpass.ASKPASS_NONCE_ENV, tc.envNonce)

			var stdout, stderr bytes.Buffer
			command := NewTestCommand()
			command.SetArgs(tc.args)
			command.SetOut(&stdout)
			command.SetErr(&stderr)

			err := command.Execute()

			if tc.expectedOut != "" {
				assert.Equal(t, tc.expectedOut, strings.TrimSpace(stdout.String()))
			}

			if tc.expectedErr != "" {
				assert.Contains(t, stderr.String(), tc.expectedErr)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}