diff --git a/go.mod b/go.mod index 94051e1..9fb140c 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module cattlecloud.net/go/webtools go 1.26 require ( + cattlecloud.net/go/forms v1.1.0 cattlecloud.net/go/scope v1.2.1 github.com/golang-jwt/jwt/v5 v5.3.1 github.com/hashicorp/go-set/v3 v3.0.1 @@ -11,4 +12,7 @@ require ( github.com/shoenig/test v1.13.2 ) -require github.com/google/go-cmp v0.7.0 // indirect +require ( + github.com/google/go-cmp v0.7.0 // indirect + github.com/shoenig/lang v0.0.7 // indirect +) diff --git a/go.sum b/go.sum index ec1f785..98618b7 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +cattlecloud.net/go/forms v1.1.0 h1:ZTIyg4AMG9IY87tfNHmRnRgy8X4ZPiui1JYwNT3562s= +cattlecloud.net/go/forms v1.1.0/go.mod h1:B8PMCKE7VtjbvirdXJpvxZIDdU8F/ybaKgNsOqZMkCw= cattlecloud.net/go/scope v1.2.1 h1:kCiA2lE6/qdMXL56rT3ZjkjFH63rwJMq1fCarE2x1F0= cattlecloud.net/go/scope v1.2.1/go.mod h1:YGE0XO+qTS84e0nxPDA97WmiMxnjknMQ7WOUWYNzy9Y= github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= @@ -10,5 +12,7 @@ github.com/mileusna/useragent v1.3.5 h1:SJM5NzBmh/hO+4LGeATKpaEX9+b4vcGg2qXGLiNG github.com/mileusna/useragent v1.3.5/go.mod h1:3d8TOmwL/5I8pJjyVDteHtgDGcefrFUX4ccGOMKNYYc= github.com/shoenig/go-conceal v0.5.6 h1:K2j8Ql6U4YrBxCRaNF/AnuYaeG8dmf2HcApc7nEdmpk= github.com/shoenig/go-conceal v0.5.6/go.mod h1:rP6ts7GI3lTWQu0gZBWN/aLR1YrdqvrAZbT8cxzxd2A= +github.com/shoenig/lang v0.0.7 h1:0F7/U1ria0edQPYf0e4zX+hJ2Wxo4UPss2fydWkqvCw= +github.com/shoenig/lang v0.0.7/go.mod h1:DStvcG5yPYr/xBBcTEaousm+Pqjn9ozAKfyqWwfhj34= github.com/shoenig/test v1.13.2 h1:SaGxHxg7xkRuKuNtuFmHf0LgNGaAgcBT7HN4WHCKfqU= github.com/shoenig/test v1.13.2/go.mod h1:MKmiRyEeuFl8y9PCoThaRDgYQZeWBhRQlH99poXz5LI= diff --git a/middles/oauth/exchange.go b/middles/oauth/exchange.go new file mode 100644 index 0000000..30284ff --- /dev/null +++ b/middles/oauth/exchange.go @@ -0,0 +1,49 @@ +package oauth + +import ( + "errors" + "net/http" + "regexp" + + "cattlecloud.net/go/forms" + "github.com/shoenig/go-conceal" +) + +func ParseIDP(r *http.Request) string { + var idp string + forms.MustParse(r, forms.Schema{"idp": forms.String(&idp)}) + return idp +} + +func ParseCodeState(r *http.Request) (string, string, error) { + var ( + code string + state string + fail string + ) + + if err := forms.Parse(r, forms.Schema{ + "code": forms.String(&code), + "state": forms.String(&state), + "error": forms.StringOr(&fail, ""), + }); err != nil { + return "", "", err + } + + if fail == "" { + return code, state, nil + } + + return "", "", errors.New(fail) +} + +var nonceRe = regexp.MustCompile(`nonce=([a-f0-9-]{36})`) + +func ParseNonce(state string) *conceal.Text { + // should be in the form nonce=; return empty string if not + results := nonceRe.FindStringSubmatch(state) + if len(results) != 2 { + return conceal.New("") + } + return conceal.New(results[1]) +} diff --git a/middles/oauth/exchange_test.go b/middles/oauth/exchange_test.go new file mode 100644 index 0000000..6f7591d --- /dev/null +++ b/middles/oauth/exchange_test.go @@ -0,0 +1,131 @@ +package oauth + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/shoenig/test/must" +) + +func TestParseIDP(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/?idp=google", nil) + result := ParseIDP(req) + must.Eq(t, "google", result) + }) + + t.Run("missing idp panics", func(t *testing.T) { + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) + must.Panic(t, func() { + ParseIDP(req) + }) + }) +} + +func TestParseCodeState(t *testing.T) { + t.Parallel() + + t.Run("success with code and state", func(t *testing.T) { + req := httptest.NewRequestWithContext( + t.Context(), http.MethodGet, + "/?code=auth-code-123&state=state-token-456", nil, + ) + code, state, err := ParseCodeState(req) + must.NoError(t, err) + must.Eq(t, "auth-code-123", code) + must.Eq(t, "state-token-456", state) + }) + + t.Run("error returned from provider", func(t *testing.T) { + req := httptest.NewRequestWithContext( + t.Context(), http.MethodGet, + "/?error=access_denied&code=&state=", nil, + ) + code, state, err := ParseCodeState(req) + must.Error(t, err) + must.Eq(t, "access_denied", err.Error()) + must.Eq(t, "", code) + must.Eq(t, "", state) + }) + + t.Run("error from provider without code or state", func(t *testing.T) { + req := httptest.NewRequestWithContext( + t.Context(), http.MethodGet, + "/?error=access_denied", nil, + ) + code, state, err := ParseCodeState(req) + must.Error(t, err) + must.Eq(t, "", code) + must.Eq(t, "", state) + }) + + t.Run("missing code", func(t *testing.T) { + req := httptest.NewRequestWithContext( + t.Context(), http.MethodGet, + "/?state=state-token-456", nil, + ) + code, state, err := ParseCodeState(req) + must.Error(t, err) + must.Eq(t, "", code) + must.Eq(t, "", state) + }) + + t.Run("missing state", func(t *testing.T) { + req := httptest.NewRequestWithContext( + t.Context(), http.MethodGet, + "/?code=auth-code-123", nil, + ) + code, state, err := ParseCodeState(req) + must.Error(t, err) + must.Eq(t, "", code) + must.Eq(t, "", state) + }) +} + +func TestParseNonce(t *testing.T) { + t.Parallel() + + t.Run("valid nonce in state", func(t *testing.T) { + state := "nonce=12345678-1234-1234-1234-123456789abc" + result := ParseNonce(state) + must.Eq(t, "12345678-1234-1234-1234-123456789abc", result.Unveil()) + }) + + t.Run("no nonce in state", func(t *testing.T) { + state := "something without nonce" + result := ParseNonce(state) + must.Eq(t, "", result.Unveil()) + }) + + t.Run("malformed nonce too short", func(t *testing.T) { + state := "nonce=short" + result := ParseNonce(state) + must.Eq(t, "", result.Unveil()) + }) + + t.Run("empty state", func(t *testing.T) { + result := ParseNonce("") + must.Eq(t, "", result.Unveil()) + }) + + t.Run("nonce with extra params", func(t *testing.T) { + state := "nonce=abcdef01-2345-6789-abcd-ef0123456789;extra=stuff" + result := ParseNonce(state) + must.Eq(t, "abcdef01-2345-6789-abcd-ef0123456789", result.Unveil()) + }) + + t.Run("uppercase hex rejected", func(t *testing.T) { + state := "nonce=ABCDEF01-2345-6789-ABCD-EF0123456789" + result := ParseNonce(state) + must.Eq(t, "", result.Unveil()) + }) + + t.Run("nonce at end of longer state", func(t *testing.T) { + state := "prefix_data;nonce=deadbeef-cafe-babe-0123-456789abcdef" + result := ParseNonce(state) + must.Eq(t, "deadbeef-cafe-babe-0123-456789abcdef", result.Unveil()) + }) +}