Merge pull request #1673 from smallstep/herman/wire-template-transform

Add OIDC token template transformation
pull/1670/head
Herman Slatman 5 months ago committed by GitHub
commit 17578b57f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -10,9 +10,9 @@ import (
"encoding/asn1"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"io"
"math/big"
"net/http"
"net/http/httptest"
"net/url"
@ -28,8 +28,12 @@ import (
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/authority/provisioner/wire"
nosqlDB "github.com/smallstep/nosql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/minica"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util"
)
const (
@ -50,11 +54,42 @@ func newWireProvisionerWithOptions(t *testing.T, options *provisioner.Options) *
return a
}
// TODO(hs): replace with test CA server + acmez based test client for
// more realistic integration test?
func TestWireIntegration(t *testing.T) {
fakeKey := `-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
-----END PUBLIC KEY-----`
accessTokenSignerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
require.NoError(t, err)
accessTokenSignerPEMBlock, err := pemutil.Serialize(accessTokenSignerJWK.Public().Key)
require.NoError(t, err)
accessTokenSignerPEMBytes := pem.EncodeToMemory(accessTokenSignerPEMBlock)
accessTokenSigner, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(accessTokenSignerJWK.Algorithm),
Key: accessTokenSignerJWK,
}, new(jose.SignerOptions))
require.NoError(t, err)
oidcTokenSignerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
require.NoError(t, err)
oidcTokenSigner, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(oidcTokenSignerJWK.Algorithm),
Key: oidcTokenSignerJWK,
}, new(jose.SignerOptions))
require.NoError(t, err)
prov := newWireProvisionerWithOptions(t, &provisioner.Options{
X509: &provisioner.X509Options{
Template: `{
"subject": {
"organization": "WireTest",
"commonName": {{ toJson .Oidc.name }}
},
"uris": [{{ toJson .Oidc.handle }}, {{ toJson .Dpop.sub }}],
"keyUsage": ["digitalSignature"],
"extKeyUsage": ["clientAuth"]
}`,
},
Wire: &wire.Options{
OIDC: &wire.OIDCOptions{
Provider: &wire.Provider{
@ -71,12 +106,13 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
SkipClientIDCheck: true,
SkipExpiryCheck: true,
SkipIssuerCheck: true,
InsecureSkipSignatureCheck: true,
InsecureSkipSignatureCheck: true, // NOTE: this skips actual token verification
Now: time.Now,
},
TransformTemplate: "",
},
DPOP: &wire.DPOPOptions{
SigningKey: []byte(fakeKey),
SigningKey: accessTokenSignerPEMBytes,
},
},
})
@ -113,6 +149,12 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
ed25519PrivKey, ok := jwk.Key.(ed25519.PrivateKey)
require.True(t, ok)
dpopSigner, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
Key: jwk,
}, new(jose.SignerOptions))
require.NoError(t, err)
ed25519PubKey, ok := ed25519PrivKey.Public().(ed25519.PublicKey)
require.True(t, ok)
@ -256,7 +298,102 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("chID", challenge.ID)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: nil})
var payload []byte
switch challenge.Type {
case acme.WIREDPOP01:
dpopBytes, err := json.Marshal(struct {
jose.Claims
Challenge string `json:"chal,omitempty"`
Handle string `json:"handle,omitempty"`
}{
Claims: jose.Claims{
Subject: "wireapp://lJGYPz0ZRq2kvc_XpdaDlA!ed416ce8ecdd9fad@example.com",
},
Challenge: "token",
Handle: "wireapp://%40alice.smith.qa@example.com",
})
require.NoError(t, err)
dpop, err := dpopSigner.Sign(dpopBytes)
require.NoError(t, err)
proof, err := dpop.CompactSerialize()
require.NoError(t, err)
tokenBytes, err := json.Marshal(struct {
jose.Claims
Challenge string `json:"chal,omitempty"`
Cnf struct {
Kid string `json:"kid,omitempty"`
} `json:"cnf"`
Proof string `json:"proof,omitempty"`
ClientID string `json:"client_id"`
APIVersion int `json:"api_version"`
Scope string `json:"scope"`
}{
Claims: jose.Claims{
Issuer: "http://issuer.example.com",
Audience: []string{"test"},
Expiry: jose.NewNumericDate(time.Now().Add(1 * time.Minute)),
},
Challenge: "token",
Cnf: struct {
Kid string `json:"kid,omitempty"`
}{
Kid: jwk.KeyID,
},
Proof: proof,
ClientID: "wireapp://lJGYPz0ZRq2kvc_XpdaDlA!ed416ce8ecdd9fad@example.com",
APIVersion: 5,
Scope: "wire_client_id",
})
require.NoError(t, err)
signed, err := accessTokenSigner.Sign(tokenBytes)
require.NoError(t, err)
accessToken, err := signed.CompactSerialize()
require.NoError(t, err)
p, err := json.Marshal(struct {
AccessToken string `json:"access_token"`
}{
AccessToken: accessToken,
})
require.NoError(t, err)
payload = p
case acme.WIREOIDC01:
keyAuth, err := acme.KeyAuthorization("token", jwk)
require.NoError(t, err)
tokenBytes, err := json.Marshal(struct {
jose.Claims
Name string `json:"name,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"`
}{
Claims: jose.Claims{
Issuer: "https://issuer.example.com",
Audience: []string{"test"},
Expiry: jose.NewNumericDate(time.Now().Add(1 * time.Minute)),
},
Name: "Alice Smith",
PreferredUsername: "wireapp://%40alice_wire@wire.com",
})
require.NoError(t, err)
signed, err := oidcTokenSigner.Sign(tokenBytes)
require.NoError(t, err)
idToken, err := signed.CompactSerialize()
require.NoError(t, err)
p, err := json.Marshal(struct {
IDToken string `json:"id_token"`
KeyAuth string `json:"keyauth"`
}{
IDToken: idToken,
KeyAuth: keyAuth,
})
require.NoError(t, err)
payload = p
default:
require.Fail(t, "unexpected challenge payload type")
}
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payload})
req := httptest.NewRequest(http.MethodGet, "https://random.local/", http.NoBody).WithContext(ctx)
w := httptest.NewRecorder()
@ -297,6 +434,16 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
updatedAz := updateAz(ctx, az)
for _, challenge := range updatedAz.Challenges {
t.Log("updated challenge:", challenge.ID, challenge.Status)
switch challenge.Type {
case acme.WIREOIDC01:
err = db.CreateOidcToken(ctx, order.ID, map[string]any{"name": "Smith, Alice M (QA)", "handle": "wireapp://%40alice.smith.qa@example.com"})
require.NoError(t, err)
case acme.WIREDPOP01:
err = db.CreateDpopToken(ctx, order.ID, map[string]any{"sub": "wireapp://lJGYPz0ZRq2kvc_XpdaDlA!ed416ce8ecdd9fad@example.com"})
require.NoError(t, err)
default:
require.Fail(t, "unexpected challenge type")
}
}
}
@ -328,11 +475,36 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
// finalize order
finalizedOrder := func(ctx context.Context) (finalizedOrder *acme.Order) {
ca, err := minica.New(minica.WithName("WireTestCA"))
require.NoError(t, err)
mockMustAuthority(t, &mockCASigner{
signer: func(*x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) {
return []*x509.Certificate{
{SerialNumber: big.NewInt(2)},
}, nil
signer: func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
var (
certOptions []x509util.Option
)
for _, op := range extraOpts {
if k, ok := op.(provisioner.CertificateOptions); ok {
certOptions = append(certOptions, k.Options(signOpts)...)
}
}
x509utilTemplate, err := x509util.NewCertificate(csr, certOptions...)
require.NoError(t, err)
template := x509utilTemplate.GetCertificate()
require.NotNil(t, template)
cert, err := ca.Sign(template)
require.NoError(t, err)
u1, err := url.Parse("wireapp://%40alice.smith.qa@example.com")
require.NoError(t, err)
u2, err := url.Parse("wireapp://lJGYPz0ZRq2kvc_XpdaDlA%21ed416ce8ecdd9fad@example.com")
require.NoError(t, err)
assert.Equal(t, []*url.URL{u1, u2}, cert.URIs)
assert.Equal(t, "Smith, Alice M (QA)", cert.Subject.CommonName)
return []*x509.Certificate{cert, ca.Intermediate}, nil
},
})
@ -369,12 +541,6 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
frRaw, err := json.Marshal(fr)
require.NoError(t, err)
// TODO(hs): move these to a more appropriate place and/or provide more realistic value
err = db.CreateDpopToken(ctx, order.ID, map[string]any{"fake-dpop": "dpop-value"})
require.NoError(t, err)
err = db.CreateOidcToken(ctx, order.ID, map[string]any{"fake-oidc": "oidc-value"})
require.NoError(t, err)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: frRaw})
chiCtx := chi.NewRouteContext()

@ -25,6 +25,7 @@ import (
"strings"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/fxamacker/cbor/v2"
"github.com/google/go-tpm/legacy/tpm2"
"github.com/smallstep/go-attestation/attest"
@ -36,6 +37,7 @@ import (
"github.com/smallstep/certificates/acme/wire"
"github.com/smallstep/certificates/authority/provisioner"
wireprovisioner "github.com/smallstep/certificates/authority/provisioner/wire"
)
type ChallengeType string
@ -113,7 +115,7 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey,
case WIREDPOP01:
return wireDPOP01Validate(ctx, ch, db, jwk, payload)
default:
return NewErrorISE("unexpected challenge type '%s'", ch.Type)
return NewErrorISE("unexpected challenge type %q", ch.Type)
}
}
@ -360,14 +362,18 @@ type wireOidcPayload struct {
func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, payload []byte) error {
prov, ok := ProvisionerFromContext(ctx)
if !ok {
return NewErrorISE("no provisioner provided")
return NewErrorISE("missing provisioner")
}
var oidcPayload wireOidcPayload
err := json.Unmarshal(payload, &oidcPayload)
if err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorRejectedIdentifierType, err,
"error unmarshalling Wire challenge payload"))
return WrapError(ErrorMalformedType, err, "error unmarshalling Wire OIDC challenge payload")
}
wireID, err := wire.ParseID([]byte(ch.Value))
if err != nil {
return WrapErrorISE(err, "error unmarshalling challenge data")
}
wireOptions, err := prov.GetOptions().GetWireOptions()
@ -375,10 +381,21 @@ func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO
return WrapErrorISE(err, "failed getting Wire options")
}
// TODO(hs): move this into validation below?
expectedKeyAuth, err := KeyAuthorization(ch.Token, jwk)
if err != nil {
return WrapErrorISE(err, "error determining key authorization")
}
if expectedKeyAuth != oidcPayload.KeyAuth {
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
"keyAuthorization does not match; expected %q, but got %q", expectedKeyAuth, oidcPayload.KeyAuth))
}
oidcOptions := wireOptions.GetOIDCOptions()
idToken, err := oidcOptions.GetProvider(ctx).Verifier(oidcOptions.GetConfig()).Verify(ctx, oidcPayload.IDToken)
verifier := oidcOptions.GetProvider(ctx).Verifier(oidcOptions.GetConfig())
idToken, err := verifier.Verify(ctx, oidcPayload.IDToken)
if err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorRejectedIdentifierType, err,
return storeError(ctx, db, ch, true, WrapError(ErrorRejectedIdentifierType, err,
"error verifying ID token signature"))
}
@ -390,26 +407,13 @@ func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO
KeyAuth string `json:"keyauth"` // TODO(hs): use this property instead of the one in the payload after https://github.com/wireapp/rusty-jwt-tools/tree/fix/keyauth is done
}
if err := idToken.Claims(&claims); err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorRejectedIdentifierType, err,
return storeError(ctx, db, ch, true, WrapError(ErrorRejectedIdentifierType, err,
"error retrieving claims from ID token"))
}
wireID, err := wire.ParseID([]byte(ch.Value))
transformedIDToken, err := validateWireOIDCClaims(oidcOptions, idToken, wireID)
if err != nil {
return WrapErrorISE(err, "error unmarshalling challenge data")
}
expectedKeyAuth, err := KeyAuthorization(ch.Token, jwk)
if err != nil {
return err
}
if expectedKeyAuth != oidcPayload.KeyAuth {
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
"keyAuthorization does not match; expected %q, but got %q", expectedKeyAuth, oidcPayload.KeyAuth))
}
if wireID.Name != claims.Name || wireID.Handle != claims.Handle {
return storeError(ctx, db, ch, false, NewError(ErrorRejectedIdentifierType, "claims in OIDC ID token don't match"))
return storeError(ctx, db, ch, true, WrapError(ErrorRejectedIdentifierType, err, "claims in OIDC ID token don't match"))
}
// Update and store the challenge.
@ -421,31 +425,51 @@ func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO
return WrapErrorISE(err, "error updating challenge")
}
parsedIDToken, err := jose.ParseSigned(oidcPayload.IDToken)
if err != nil {
return WrapErrorISE(err, "invalid OIDC ID token")
}
oidcToken := make(map[string]interface{})
if err := parsedIDToken.UnsafeClaimsWithoutVerification(&oidcToken); err != nil {
return WrapErrorISE(err, "failed parsing OIDC id token")
}
orders, err := db.GetAllOrdersByAccountID(ctx, ch.AccountID)
if err != nil {
return WrapErrorISE(err, "could not find current order by account id")
return WrapErrorISE(err, "could not retrieve current order by account id")
}
if len(orders) == 0 {
return NewErrorISE("there are not enough orders for this account for this custom OIDC challenge")
}
order := orders[len(orders)-1]
if err := db.CreateOidcToken(ctx, order, oidcToken); err != nil {
if err := db.CreateOidcToken(ctx, order, transformedIDToken); err != nil {
return WrapErrorISE(err, "failed storing OIDC id token")
}
return nil
}
func validateWireOIDCClaims(o *wireprovisioner.OIDCOptions, token *oidc.IDToken, wireID wire.ID) (map[string]any, error) {
var m map[string]any
if err := token.Claims(&m); err != nil {
return nil, fmt.Errorf("failed extracting OIDC ID token claims: %w", err)
}
transformed, err := o.Transform(m)
if err != nil {
return nil, fmt.Errorf("failed transforming OIDC ID token: %w", err)
}
name, ok := transformed["name"]
if !ok {
return nil, fmt.Errorf("transformed OIDC ID token does not contain 'name'")
}
if wireID.Name != name {
return nil, fmt.Errorf("invalid 'name' %q after transformation", name)
}
handle, ok := transformed["handle"]
if !ok {
return nil, fmt.Errorf("transformed OIDC ID token does not contain 'handle'")
}
if wireID.Handle != handle {
return nil, fmt.Errorf("invalid 'handle' %q after transformation", handle)
}
return transformed, nil
}
type wireDpopPayload struct {
// AccessToken is the token generated by wire-server
AccessToken string `json:"access_token"`
@ -459,8 +483,7 @@ func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, accountJWK *j
var dpopPayload wireDpopPayload
if err := json.Unmarshal(payload, &dpopPayload); err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorRejectedIdentifierType, err,
"error unmarshalling Wire challenge payload"))
return WrapError(ErrorMalformedType, err, "error unmarshalling Wire DPoP challenge payload")
}
wireID, err := wire.ParseID([]byte(ch.Value))
@ -496,7 +519,8 @@ func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, accountJWK *j
}
_, dpop, err := parseAndVerifyWireAccessToken(params)
if err != nil {
return WrapErrorISE(err, "failed validating token")
return storeError(ctx, db, ch, true, WrapError(ErrorRejectedIdentifierType, err,
"failed validating Wire access token"))
}
// Update and store the challenge.

@ -31,17 +31,16 @@ import (
"time"
"github.com/fxamacker/cbor/v2"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
wireprovisioner "github.com/smallstep/certificates/authority/provisioner/wire"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/minica"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/acme/wire"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockClient struct {
@ -199,6 +198,25 @@ func mustAttestYubikey(t *testing.T, _, keyAuthorization string, serial int) ([]
return payload, leaf, ca.Root
}
func newWireProvisionerWithOptions(t *testing.T, options *provisioner.Options) *provisioner.ACME {
t.Helper()
prov := &provisioner.ACME{
Type: "ACME",
Name: "acme",
Options: options,
Challenges: []provisioner.ACMEChallenge{
provisioner.WIREOIDC_01,
provisioner.WIREDPOP_01,
},
}
if err := prov.Init(provisioner.Config{
Claims: config.GlobalProvisionerClaims,
}); err != nil {
t.Fatal(err)
}
return prov
}
func Test_storeError(t *testing.T) {
type test struct {
ch *Challenge
@ -399,6 +417,9 @@ func TestKeyAuthorization(t *testing.T) {
}
func TestChallenge_Validate(t *testing.T) {
fakeKey := `-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
-----END PUBLIC KEY-----`
type test struct {
ch *Challenge
vc Client
@ -433,7 +454,7 @@ func TestChallenge_Validate(t *testing.T) {
}
return test{
ch: ch,
err: NewErrorISE("unexpected challenge type 'foo'"),
err: NewErrorISE(`unexpected challenge type "foo"`),
}
},
"fail/http-01": func(t *testing.T) test {
@ -856,6 +877,256 @@ func TestChallenge_Validate(t *testing.T) {
},
}
},
"ok/wire-oidc-01": func(t *testing.T) test {
jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
signerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
require.NoError(t, err)
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(signerJWK.Algorithm),
Key: signerJWK,
}, new(jose.SignerOptions))
require.NoError(t, err)
srv := mustJWKServer(t, signerJWK.Public())
tokenBytes, err := json.Marshal(struct {
jose.Claims
Name string `json:"name,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"`
}{
Claims: jose.Claims{
Issuer: srv.URL,
Audience: []string{"test"},
Expiry: jose.NewNumericDate(time.Now().Add(1 * time.Minute)),
},
Name: "Alice Smith",
PreferredUsername: "wireapp://%40alice_wire@wire.com",
})
require.NoError(t, err)
signed, err := signer.Sign(tokenBytes)
require.NoError(t, err)
idToken, err := signed.CompactSerialize()
require.NoError(t, err)
payload, err := json.Marshal(struct {
IDToken string `json:"id_token"`
KeyAuth string `json:"keyauth"`
}{
IDToken: idToken,
KeyAuth: keyAuth,
})
require.NoError(t, err)
valueBytes, err := json.Marshal(struct {
Name string `json:"name,omitempty"`
Domain string `json:"domain,omitempty"`
ClientID string `json:"client-id,omitempty"`
Handle string `json:"handle,omitempty"`
}{
Name: "Alice Smith",
Domain: "wire.com",
ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com",
Handle: "wireapp://%40alice_wire@wire.com",
})
require.NoError(t, err)
ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{
Wire: &wireprovisioner.Options{
OIDC: &wireprovisioner.OIDCOptions{
Provider: &wireprovisioner.Provider{
IssuerURL: srv.URL,
JWKSURL: srv.URL + "/keys",
},
Config: &wireprovisioner.Config{
ClientID: "test",
SignatureAlgorithms: []string{"ES256"},
SkipClientIDCheck: false,
SkipExpiryCheck: false,
SkipIssuerCheck: false,
InsecureSkipSignatureCheck: false,
Now: time.Now,
},
TransformTemplate: "",
},
DPOP: &wireprovisioner.DPOPOptions{
SigningKey: []byte(fakeKey),
},
},
}))
return test{
ch: &Challenge{
ID: "chID",
AuthorizationID: "azID",
AccountID: "accID",
Token: "token",
Type: "wire-oidc-01",
Status: StatusPending,
Value: string(valueBytes),
},
srv: srv,
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
},
MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) {
assert.Equal(t, "accID", accountID)
return []string{"orderID"}, nil
},
MockCreateOidcToken: func(ctx context.Context, orderID string, idToken map[string]interface{}) error {
assert.Equal(t, "orderID", orderID)
assert.Equal(t, "Alice Smith", idToken["name"].(string))
assert.Equal(t, "wireapp://%40alice_wire@wire.com", idToken["handle"].(string))
return nil
},
},
}
},
"ok/wire-dpop-01": func(t *testing.T) test {
jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
_ = keyAuth // TODO(hs): keyAuth (not) required for DPoP? Or needs to be added to validation?
dpopSigner, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
Key: jwk,
}, new(jose.SignerOptions))
require.NoError(t, err)
signerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
require.NoError(t, err)
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(signerJWK.Algorithm),
Key: signerJWK,
}, new(jose.SignerOptions))
require.NoError(t, err)
signerPEMBlock, err := pemutil.Serialize(signerJWK.Public().Key)
require.NoError(t, err)
signerPEMBytes := pem.EncodeToMemory(signerPEMBlock)
dpopBytes, err := json.Marshal(struct {
jose.Claims
Challenge string `json:"chal,omitempty"`
Handle string `json:"handle,omitempty"`
}{
Claims: jose.Claims{
Subject: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com",
},
Challenge: "token",
Handle: "wireapp://%40alice_wire@wire.com",
})
require.NoError(t, err)
dpop, err := dpopSigner.Sign(dpopBytes)
require.NoError(t, err)
proof, err := dpop.CompactSerialize()
require.NoError(t, err)
tokenBytes, err := json.Marshal(struct {
jose.Claims
Challenge string `json:"chal,omitempty"`
Cnf struct {
Kid string `json:"kid,omitempty"`
} `json:"cnf"`
Proof string `json:"proof,omitempty"`
ClientID string `json:"client_id"`
APIVersion int `json:"api_version"`
Scope string `json:"scope"`
}{
Claims: jose.Claims{
Issuer: "http://issuer.example.com",
Audience: []string{"test"},
Expiry: jose.NewNumericDate(time.Now().Add(1 * time.Minute)),
},
Challenge: "token",
Cnf: struct {
Kid string `json:"kid,omitempty"`
}{
Kid: jwk.KeyID,
},
Proof: proof,
ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com",
APIVersion: 5,
Scope: "wire_client_id",
})
require.NoError(t, err)
signed, err := signer.Sign(tokenBytes)
require.NoError(t, err)
accessToken, err := signed.CompactSerialize()
require.NoError(t, err)
payload, err := json.Marshal(struct {
AccessToken string `json:"access_token"`
}{
AccessToken: accessToken,
})
require.NoError(t, err)
valueBytes, err := json.Marshal(struct {
Name string `json:"name,omitempty"`
Domain string `json:"domain,omitempty"`
ClientID string `json:"client-id,omitempty"`
Handle string `json:"handle,omitempty"`
}{
Name: "Alice Smith",
Domain: "wire.com",
ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com",
Handle: "wireapp://%40alice_wire@wire.com",
})
require.NoError(t, err)
ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{
Wire: &wireprovisioner.Options{
OIDC: &wireprovisioner.OIDCOptions{
Provider: &wireprovisioner.Provider{
IssuerURL: "http://issuerexample.com",
},
Config: &wireprovisioner.Config{
ClientID: "test",
SignatureAlgorithms: []string{"ES256"},
SkipClientIDCheck: false,
SkipExpiryCheck: false,
SkipIssuerCheck: false,
InsecureSkipSignatureCheck: false,
Now: time.Now,
},
TransformTemplate: "",
},
DPOP: &wireprovisioner.DPOPOptions{
SigningKey: signerPEMBytes,
},
},
}))
return test{
ch: &Challenge{
ID: "chID",
AuthorizationID: "azID",
AccountID: "accID",
Token: "token",
Type: "wire-dpop-01",
Status: StatusPending,
Value: string(valueBytes),
},
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
},
MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) {
assert.Equal(t, "accID", accountID)
return []string{"orderID"}, nil
},
MockCreateDpopToken: func(ctx context.Context, orderID string, dpop map[string]interface{}) error {
assert.Equal(t, "orderID", orderID)
assert.Equal(t, "token", dpop["chal"].(string))
assert.Equal(t, "wireapp://%40alice_wire@wire.com", dpop["handle"].(string))
assert.Equal(t, "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", dpop["sub"].(string))
return nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
@ -870,25 +1141,63 @@ func TestChallenge_Validate(t *testing.T) {
ctx = context.Background()
}
ctx = NewClientContext(ctx, tc.vc)
if err := tc.ch.Validate(ctx, tc.db, tc.jwk, tc.payload); err != nil {
if assert.Error(t, tc.err) {
var k *Error
if errors.As(err, &k) {
assert.Equal(t, tc.err.Type, k.Type)
assert.Equal(t, tc.err.Detail, k.Detail)
assert.Equal(t, tc.err.Status, k.Status)
assert.Equal(t, tc.err.Err.Error(), k.Err.Error())
} else {
assert.Fail(t, "unexpected error type")
}
err := tc.ch.Validate(ctx, tc.db, tc.jwk, tc.payload)
if tc.err != nil {
var k *Error
if errors.As(err, &k) {
assert.Equal(t, tc.err.Type, k.Type)
assert.Equal(t, tc.err.Detail, k.Detail)
assert.Equal(t, tc.err.Status, k.Status)
assert.Equal(t, tc.err.Err.Error(), k.Err.Error())
} else {
assert.Fail(t, "unexpected error type")
}
} else {
assert.Nil(t, tc.err)
return
}
assert.NoError(t, err)
})
}
}
func mustJWKServer(t *testing.T, pub jose.JSONWebKey) *httptest.Server {
t.Helper()
mux := http.NewServeMux()
server := httptest.NewServer(mux)
b, err := json.Marshal(struct {
Keys []jose.JSONWebKey `json:"keys,omitempty"`
}{
Keys: []jose.JSONWebKey{pub},
})
require.NoError(t, err)
jwks := string(b)
wellKnown := fmt.Sprintf(`{
"issuer": "%[1]s",
"authorization_endpoint": "%[1]s/auth",
"token_endpoint": "%[1]s/token",
"jwks_uri": "%[1]s/keys",
"userinfo_endpoint": "%[1]s/userinfo",
"id_token_signing_alg_values_supported": ["ES256"]
}`, server.URL)
mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, req *http.Request) {
_, err := io.WriteString(w, wellKnown)
if err != nil {
w.WriteHeader(500)
}
})
mux.HandleFunc("/keys", func(w http.ResponseWriter, req *http.Request) {
_, err := io.WriteString(w, jwks)
if err != nil {
w.WriteHeader(500)
}
})
t.Cleanup(server.Close)
return server
}
type errReader int
func (errReader) Read([]byte) (int, error) {
@ -4304,70 +4613,3 @@ func createSubjectAltNameExtension(dnsNames, emailAddresses x509util.MultiString
Value: rawBytes,
}, nil
}
func Test_parseAndVerifyWireAccessToken(t *testing.T) {
key := `
-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEAB2IYqBWXAouDt3WcCZgCM3t9gumMEKMlgMsGenSu+fA=
-----END PUBLIC KEY-----`
publicKey, err := pemutil.Parse([]byte(key))
require.NoError(t, err)
issuer := "http://wire.com:19983/clients/7a41cf5b79683410/access-token"
wireID := wire.ID{
ClientID: "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com",
Handle: "wireapp://%40alice_wire@wire.com",
}
token := `eyJhbGciOiJFZERTQSIsInR5cCI6ImF0K2p3dCIsImp3ayI6eyJrdHkiOiJPS1AiLCJjcnYiOiJFZDI1NTE5IiwieCI6IkIySVlxQldYQW91RHQzV2NDWmdDTTN0OWd1bU1FS01sZ01zR2VuU3UtZkEifX0.eyJpYXQiOjE3MDQ5ODUyMDUsImV4cCI6MTcwNDk4OTE2NSwibmJmIjoxNzA0OTg1MjA1LCJpc3MiOiJodHRwOi8vd2lyZS5jb206MTk5ODMvY2xpZW50cy83YTQxY2Y1Yjc5NjgzNDEwL2FjY2Vzcy10b2tlbiIsInN1YiI6IndpcmVhcHA6Ly9ndVZYNXhlRlMzZVRhdG1YQkl5QTRBITdhNDFjZjViNzk2ODM0MTBAd2lyZS5jb20iLCJhdWQiOiJodHRwOi8vd2lyZS5jb206MTk5ODMvY2xpZW50cy83YTQxY2Y1Yjc5NjgzNDEwL2FjY2Vzcy10b2tlbiIsImp0aSI6IjQyYzQ2ZDRjLWU1MTAtNDE3NS05ZmI1LWQwNTVlMTI1YTQ5ZCIsIm5vbmNlIjoiVUVKeVIyZHFPRWh6WkZKRVlXSkJhVGt5T0RORVlURTJhRXMwZEhJeGNFYyIsImNoYWwiOiJiWFVHTnBVZmNSeDNFaEIzNHhQM3k2MmFRWm9HWlM2aiIsImNuZiI6eyJraWQiOiJvTVdmTkRKUXNJNWNQbFhONVVvQk5uY0t0YzRmMmRxMnZ3Q2pqWHNxdzdRIn0sInByb29mIjoiZXlKaGJHY2lPaUpGWkVSVFFTSXNJblI1Y0NJNkltUndiM0FyYW5kMElpd2lhbmRySWpwN0ltdDBlU0k2SWs5TFVDSXNJbU55ZGlJNklrVmtNalUxTVRraUxDSjRJam9pTVV3eFpVZ3lZVFpCWjFaMmVsUndOVnBoYkV0U1puRTJjRlpRVDNSRmFrazNhRGhVVUhwQ1dVWm5UU0o5ZlEuZXlKcFlYUWlPakUzTURRNU9EVXlNRFVzSW1WNGNDSTZNVGN3TkRrNU1qUXdOU3dpYm1KbUlqb3hOekEwT1RnMU1qQTFMQ0p6ZFdJaU9pSjNhWEpsWVhCd09pOHZaM1ZXV0RWNFpVWlRNMlZVWVhSdFdFSkplVUUwUVNFM1lUUXhZMlkxWWpjNU5qZ3pOREV3UUhkcGNtVXVZMjl0SWl3aWFuUnBJam9pTldVMk5qZzBZMkl0Tm1JME9DMDBOamhrTFdJd09URXRabVl3TkdKbFpEWmxZekpsSWl3aWJtOXVZMlVpT2lKVlJVcDVVakprY1U5RmFIcGFSa3BGV1ZkS1FtRlVhM2xQUkU1RldWUkZNbUZGY3pCa1NFbDRZMFZqSWl3aWFIUnRJam9pVUU5VFZDSXNJbWgwZFNJNkltaDBkSEE2THk5M2FYSmxMbU52YlRveE9UazRNeTlqYkdsbGJuUnpMemRoTkRGalpqVmlOemsyT0RNME1UQXZZV05qWlhOekxYUnZhMlZ1SWl3aVkyaGhiQ0k2SW1KWVZVZE9jRlZtWTFKNE0wVm9Rak0wZUZBemVUWXlZVkZhYjBkYVV6WnFJaXdpYUdGdVpHeGxJam9pZDJseVpXRndjRG92THlVME1HRnNhV05sWDNkcGNtVkFkMmx5WlM1amIyMGlMQ0owWldGdElqb2lkMmx5WlNKOS52bkN1T2JURFRLVFhCYXpyX3Z2X0xyZDBZT1Rac2xteHQtM2xKNWZKSU9iRVRidUVCTGlEaS1JVWZHcFJHTm1Dbm9IZjVocHNsWW5HeFMzSjloUmVDZyIsImNsaWVudF9pZCI6IndpcmVhcHA6Ly9ndVZYNXhlRlMzZVRhdG1YQkl5QTRBITdhNDFjZjViNzk2ODM0MTBAd2lyZS5jb20iLCJhcGlfdmVyc2lvbiI6NSwic2NvcGUiOiJ3aXJlX2NsaWVudF9pZCJ9.uCVYhmvCJm7nM1NxJQKl_XZJcSqm9eFmNmbRJkA5Wpsw70ZF1YANYC9nQ91QgsnuAbaRZMJiJt3P8ZntR2ozDQ`
ch := &Challenge{
Token: "bXUGNpUfcRx3EhB34xP3y62aQZoGZS6j",
}
issuedAtUnix, err := strconv.ParseInt("1704985205", 10, 64)
require.NoError(t, err)
issuedAt := time.Unix(issuedAtUnix, 0)
jwkBytes := []byte(`{"crv": "Ed25519", "kty": "OKP", "x": "1L1eH2a6AgVvzTp5ZalKRfq6pVPOtEjI7h8TPzBYFgM"}`)
var accountJWK jose.JSONWebKey
json.Unmarshal(jwkBytes, &accountJWK)
rawKid, err := accountJWK.Thumbprint(crypto.SHA256)
require.NoError(t, err)
accountJWK.KeyID = base64.RawURLEncoding.EncodeToString(rawKid)
at, dpop, err := parseAndVerifyWireAccessToken(wireVerifyParams{
token: token,
tokenKey: publicKey,
dpopKey: accountJWK.Public(),
dpopKeyID: accountJWK.KeyID,
issuer: issuer,
wireID: wireID,
chToken: ch.Token,
t: issuedAt.Add(1 * time.Minute), // set validation time to be one minute after issuance
})
if assert.NoError(t, err) {
// token assertions
assert.Equal(t, "42c46d4c-e510-4175-9fb5-d055e125a49d", at.ID)
assert.Equal(t, "http://wire.com:19983/clients/7a41cf5b79683410/access-token", at.Issuer)
assert.Equal(t, "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com", at.Subject)
assert.Contains(t, at.Audience, "http://wire.com:19983/clients/7a41cf5b79683410/access-token")
assert.Equal(t, "bXUGNpUfcRx3EhB34xP3y62aQZoGZS6j", at.Challenge)
assert.Equal(t, "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com", at.ClientID)
assert.Equal(t, 5, at.APIVersion)
assert.Equal(t, "wire_client_id", at.Scope)
if assert.NotNil(t, at.Cnf) {
assert.Equal(t, "oMWfNDJQsI5cPlXN5UoBNncKtc4f2dq2vwCjjXsqw7Q", at.Cnf.Kid)
}
// dpop proof assertions
dt := *dpop
assert.Equal(t, "bXUGNpUfcRx3EhB34xP3y62aQZoGZS6j", dt["chal"].(string))
assert.Equal(t, "wireapp://%40alice_wire@wire.com", dt["handle"].(string))
assert.Equal(t, "POST", dt["htm"].(string))
assert.Equal(t, "http://wire.com:19983/clients/7a41cf5b79683410/access-token", dt["htu"].(string))
assert.Equal(t, "5e6684cb-6b48-468d-b091-ff04bed6ec2e", dt["jti"].(string))
assert.Equal(t, "UEJyR2dqOEhzZFJEYWJBaTkyODNEYTE2aEs0dHIxcEc", dt["nonce"].(string))
assert.Equal(t, "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com", dt["sub"].(string))
assert.Equal(t, "wire", dt["team"].(string))
}
}

File diff suppressed because it is too large Load Diff

@ -21,14 +21,14 @@ type dbDpopToken struct {
func (db *DB) getDBDpopToken(_ context.Context, orderID string) (*dbDpopToken, error) {
b, err := db.db.Get(wireDpopTokenTable, []byte(orderID))
if nosql.IsErrNotFound(err) {
return nil, acme.NewError(acme.ErrorMalformedType, "dpop %s not found", orderID)
return nil, acme.NewError(acme.ErrorMalformedType, "dpop token %q not found", orderID)
} else if err != nil {
return nil, errors.Wrapf(err, "error loading dpop %s", orderID)
return nil, errors.Wrapf(err, "error loading dpop %q", orderID)
}
d := new(dbDpopToken)
if err := json.Unmarshal(b, d); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling dpop %s into dbDpopToken", orderID)
return nil, errors.Wrapf(err, "error unmarshaling dpop %q into dbDpopToken", orderID)
}
return d, nil
}
@ -50,7 +50,7 @@ func (db *DB) GetDpopToken(ctx context.Context, orderID string) (map[string]any,
func (db *DB) CreateDpopToken(ctx context.Context, orderID string, dpop map[string]any) error {
content, err := json.Marshal(dpop)
if err != nil {
return err
return fmt.Errorf("failed marshaling dpop token: %w", err)
}
now := clock.Now()
@ -75,13 +75,13 @@ type dbOidcToken struct {
func (db *DB) getDBOidcToken(_ context.Context, orderID string) (*dbOidcToken, error) {
b, err := db.db.Get(wireOidcTokenTable, []byte(orderID))
if nosql.IsErrNotFound(err) {
return nil, acme.NewError(acme.ErrorMalformedType, "oidc token %s not found", orderID)
return nil, acme.NewError(acme.ErrorMalformedType, "oidc token %q not found", orderID)
} else if err != nil {
return nil, errors.Wrapf(err, "error loading oidc token %s", orderID)
return nil, errors.Wrapf(err, "error loading oidc token %q", orderID)
}
o := new(dbOidcToken)
if err := json.Unmarshal(b, o); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling oidc token %s into dbOidcToken", orderID)
return nil, errors.Wrapf(err, "error unmarshaling oidc token %q into dbOidcToken", orderID)
}
return o, nil
}
@ -103,7 +103,7 @@ func (db *DB) GetOidcToken(ctx context.Context, orderID string) (map[string]any,
func (db *DB) CreateOidcToken(ctx context.Context, orderID string, idToken map[string]any) error {
content, err := json.Marshal(idToken)
if err != nil {
return err
return fmt.Errorf("failed marshaling oidc token: %w", err)
}
now := clock.Now()

@ -0,0 +1,394 @@
package nosql
import (
"context"
"encoding/json"
"errors"
"testing"
"time"
"github.com/smallstep/certificates/acme"
certificatesdb "github.com/smallstep/certificates/db"
"github.com/smallstep/nosql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDB_GetDpopToken(t *testing.T) {
type test struct {
db *DB
orderID string
expected map[string]any
expectedErr error
}
var tests = map[string]func(t *testing.T) test{
"fail/acme-not-found": func(t *testing.T) test {
dir := t.TempDir()
db, err := nosql.New("badgerv2", dir)
require.NoError(t, err)
return test{
db: &DB{
db: db,
},
orderID: "orderID",
expectedErr: &acme.Error{
Type: "urn:ietf:params:acme:error:malformed",
Status: 400,
Detail: "The request message was malformed",
Err: errors.New(`dpop token "orderID" not found`),
},
}
},
"fail/unmarshal-error": func(t *testing.T) test {
dir := t.TempDir()
db, err := nosql.New("badgerv2", dir)
require.NoError(t, err)
token := dbDpopToken{
ID: "orderID",
Content: []byte("{}"),
CreatedAt: time.Now(),
}
b, err := json.Marshal(token)
require.NoError(t, err)
err = db.Set(wireDpopTokenTable, []byte("orderID"), b[1:]) // start at index 1; corrupt JSON data
require.NoError(t, err)
return test{
db: &DB{
db: db,
},
orderID: "orderID",
expectedErr: errors.New(`error unmarshaling dpop "orderID" into dbDpopToken: invalid character ':' after top-level value`),
}
},
"fail/db.Get": func(t *testing.T) test {
db := &certificatesdb.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equal(t, wireDpopTokenTable, bucket)
assert.Equal(t, []byte("orderID"), key)
return nil, errors.New("fail")
},
}
return test{
db: &DB{
db: db,
},
orderID: "orderID",
expectedErr: errors.New(`error loading dpop "orderID": fail`),
}
},
"ok": func(t *testing.T) test {
dir := t.TempDir()
db, err := nosql.New("badgerv2", dir)
require.NoError(t, err)
token := dbDpopToken{
ID: "orderID",
Content: []byte(`{"sub": "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com"}`),
CreatedAt: time.Now(),
}
b, err := json.Marshal(token)
require.NoError(t, err)
err = db.Set(wireDpopTokenTable, []byte("orderID"), b)
require.NoError(t, err)
return test{
db: &DB{
db: db,
},
orderID: "orderID",
expected: map[string]any{
"sub": "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com",
},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
got, err := tc.db.GetDpopToken(context.Background(), tc.orderID)
if tc.expectedErr != nil {
assert.EqualError(t, err, tc.expectedErr.Error())
ae := &acme.Error{}
if errors.As(err, &ae) {
ee := &acme.Error{}
require.True(t, errors.As(tc.expectedErr, &ee))
assert.Equal(t, ee.Detail, ae.Detail)
assert.Equal(t, ee.Type, ae.Type)
assert.Equal(t, ee.Status, ae.Status)
}
assert.Nil(t, got)
return
}
assert.NoError(t, err)
assert.Equal(t, tc.expected, got)
})
}
}
func TestDB_CreateDpopToken(t *testing.T) {
type test struct {
db *DB
orderID string
dpop map[string]any
expectedErr error
}
var tests = map[string]func(t *testing.T) test{
"fail/db.Save": func(t *testing.T) test {
db := &certificatesdb.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equal(t, wireDpopTokenTable, bucket)
assert.Equal(t, []byte("orderID"), key)
return nil, false, errors.New("fail")
},
}
return test{
db: &DB{
db: db,
},
orderID: "orderID",
dpop: map[string]any{
"sub": "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com",
},
expectedErr: errors.New("failed saving dpop token: error saving acme dpop: fail"),
}
},
"ok": func(t *testing.T) test {
dir := t.TempDir()
db, err := nosql.New("badgerv2", dir)
require.NoError(t, err)
return test{
db: &DB{
db: db,
},
orderID: "orderID",
dpop: map[string]any{
"sub": "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com",
},
}
},
"ok/nil": func(t *testing.T) test {
dir := t.TempDir()
db, err := nosql.New("badgerv2", dir)
require.NoError(t, err)
return test{
db: &DB{
db: db,
},
orderID: "orderID",
dpop: nil,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
err := tc.db.CreateDpopToken(context.Background(), tc.orderID, tc.dpop)
if tc.expectedErr != nil {
assert.EqualError(t, err, tc.expectedErr.Error())
return
}
assert.NoError(t, err)
dpop, err := tc.db.getDBDpopToken(context.Background(), tc.orderID)
require.NoError(t, err)
assert.Equal(t, tc.orderID, dpop.ID)
var m map[string]any
err = json.Unmarshal(dpop.Content, &m)
require.NoError(t, err)
assert.Equal(t, tc.dpop, m)
})
}
}
func TestDB_GetOidcToken(t *testing.T) {
type test struct {
db *DB
orderID string
expected map[string]any
expectedErr error
}
var tests = map[string]func(t *testing.T) test{
"fail/acme-not-found": func(t *testing.T) test {
dir := t.TempDir()
db, err := nosql.New("badgerv2", dir)
require.NoError(t, err)
return test{
db: &DB{
db: db,
},
orderID: "orderID",
expectedErr: &acme.Error{
Type: "urn:ietf:params:acme:error:malformed",
Status: 400,
Detail: "The request message was malformed",
Err: errors.New(`oidc token "orderID" not found`),
},
}
},
"fail/unmarshal-error": func(t *testing.T) test {
dir := t.TempDir()
db, err := nosql.New("badgerv2", dir)
require.NoError(t, err)
token := dbOidcToken{
ID: "orderID",
Content: []byte("{}"),
CreatedAt: time.Now(),
}
b, err := json.Marshal(token)
require.NoError(t, err)
err = db.Set(wireOidcTokenTable, []byte("orderID"), b[1:]) // start at index 1; corrupt JSON data
require.NoError(t, err)
return test{
db: &DB{
db: db,
},
orderID: "orderID",
expectedErr: errors.New(`error unmarshaling oidc token "orderID" into dbOidcToken: invalid character ':' after top-level value`),
}
},
"fail/db.Get": func(t *testing.T) test {
db := &certificatesdb.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equal(t, wireOidcTokenTable, bucket)
assert.Equal(t, []byte("orderID"), key)
return nil, errors.New("fail")
},
}
return test{
db: &DB{
db: db,
},
orderID: "orderID",
expectedErr: errors.New(`error loading oidc token "orderID": fail`),
}
},
"ok": func(t *testing.T) test {
dir := t.TempDir()
db, err := nosql.New("badgerv2", dir)
require.NoError(t, err)
token := dbOidcToken{
ID: "orderID",
Content: []byte(`{"name": "Alice Smith", "handle": "@alice.smith"}`),
CreatedAt: time.Now(),
}
b, err := json.Marshal(token)
require.NoError(t, err)
err = db.Set(wireOidcTokenTable, []byte("orderID"), b)
require.NoError(t, err)
return test{
db: &DB{
db: db,
},
orderID: "orderID",
expected: map[string]any{
"name": "Alice Smith",
"handle": "@alice.smith",
},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
got, err := tc.db.GetOidcToken(context.Background(), tc.orderID)
if tc.expectedErr != nil {
assert.EqualError(t, err, tc.expectedErr.Error())
ae := &acme.Error{}
if errors.As(err, &ae) {
ee := &acme.Error{}
require.True(t, errors.As(tc.expectedErr, &ee))
assert.Equal(t, ee.Detail, ae.Detail)
assert.Equal(t, ee.Type, ae.Type)
assert.Equal(t, ee.Status, ae.Status)
}
assert.Nil(t, got)
return
}
assert.NoError(t, err)
assert.Equal(t, tc.expected, got)
})
}
}
func TestDB_CreateOidcToken(t *testing.T) {
type test struct {
db *DB
orderID string
oidc map[string]any
expectedErr error
}
var tests = map[string]func(t *testing.T) test{
"fail/db.Save": func(t *testing.T) test {
db := &certificatesdb.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equal(t, wireOidcTokenTable, bucket)
assert.Equal(t, []byte("orderID"), key)
return nil, false, errors.New("fail")
},
}
return test{
db: &DB{
db: db,
},
orderID: "orderID",
oidc: map[string]any{
"name": "Alice Smith",
"handle": "@alice.smith",
},
expectedErr: errors.New("failed saving oidc token: error saving acme oidc: fail"),
}
},
"ok": func(t *testing.T) test {
dir := t.TempDir()
db, err := nosql.New("badgerv2", dir)
require.NoError(t, err)
return test{
db: &DB{
db: db,
},
orderID: "orderID",
oidc: map[string]any{
"name": "Alice Smith",
"handle": "@alice.smith",
},
}
},
"ok/nil": func(t *testing.T) test {
dir := t.TempDir()
db, err := nosql.New("badgerv2", dir)
require.NoError(t, err)
return test{
db: &DB{
db: db,
},
orderID: "orderID",
oidc: nil,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
err := tc.db.CreateOidcToken(context.Background(), tc.orderID, tc.oidc)
if tc.expectedErr != nil {
assert.EqualError(t, err, tc.expectedErr.Error())
return
}
assert.NoError(t, err)
oidc, err := tc.db.getDBOidcToken(context.Background(), tc.orderID)
require.NoError(t, err)
assert.Equal(t, tc.orderID, oidc.ID)
var m map[string]any
err = json.Unmarshal(oidc.Content, &m)
require.NoError(t, err)
assert.Equal(t, tc.oidc, m)
})
}
}

@ -0,0 +1,58 @@
package wire
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
func TestParseID(t *testing.T) {
ok := `{"name": "Alice Smith", "domain": "wire.com", "client-id": "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", "handle": "wireapp://%40alice_wire@wire.com"}`
tests := []struct {
name string
data []byte
wantWireID ID
expectedErr error
}{
{name: "ok", data: []byte(ok), wantWireID: ID{Name: "Alice Smith", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotWireID, err := ParseID(tt.data)
if tt.expectedErr != nil {
assert.EqualError(t, err, tt.expectedErr.Error())
return
}
assert.NoError(t, err)
assert.Equal(t, tt.wantWireID, gotWireID)
})
}
}
func TestParseClientID(t *testing.T) {
tests := []struct {
name string
clientID string
want ClientID
expectedErr error
}{
{name: "ok", clientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", want: ClientID{Scheme: "wireapp", Username: "CzbfFjDOQrenCbDxVmgnFw", DeviceID: "594930e9d50bb175", Domain: "wire.com"}},
{name: "fail/uri", clientID: "bla", expectedErr: errors.New(`invalid Wire client ID URI "bla": error parsing bla: scheme is missing`)},
{name: "fail/scheme", clientID: "not-wireapp://bla.com", expectedErr: errors.New(`invalid Wire client ID scheme "not-wireapp"; expected "wireapp"`)},
{name: "fail/username", clientID: "wireapp://user@wire.com", expectedErr: errors.New(`invalid Wire client ID username "user"`)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseClientID(tt.clientID)
if tt.expectedErr != nil {
assert.EqualError(t, err, tt.expectedErr.Error())
return
}
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}

@ -57,6 +57,9 @@ func (o *Options) GetSSHOptions() *SSHOptions {
// GetWireOptions returns the SSH options.
func (o *Options) GetWireOptions() (*wire.Options, error) {
if o == nil {
return nil, errors.New("no options available")
}
if o.Wire == nil {
return nil, errors.New("no Wire options available")
}
if err := o.Wire.Validate(); err != nil {

@ -3,6 +3,7 @@ package wire
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/url"
@ -10,6 +11,7 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
"go.step.sm/crypto/x509util"
)
type Provider struct {
@ -32,11 +34,13 @@ type Config struct {
}
type OIDCOptions struct {
Provider *Provider `json:"provider,omitempty"`
Config *Config `json:"config,omitempty"`
Provider *Provider `json:"provider,omitempty"`
Config *Config `json:"config,omitempty"`
TransformTemplate string `json:"transform,omitempty"`
oidcProviderConfig *oidc.ProviderConfig
target *template.Template
transform *template.Template
}
func (o *OIDCOptions) GetProvider(ctx context.Context) *oidc.Provider {
@ -62,6 +66,8 @@ func (o *OIDCOptions) GetConfig() *oidc.Config {
}
}
const defaultTemplate = `{"name": "{{ .name }}", "handle": "{{ .preferred_username }}"}`
func (o *OIDCOptions) validateAndInitialize() (err error) {
if o.Provider == nil {
return errors.New("provider not set")
@ -80,9 +86,22 @@ func (o *OIDCOptions) validateAndInitialize() (err error) {
return fmt.Errorf("failed parsing OIDC template: %w", err)
}
o.transform, err = parseTransform(o.TransformTemplate)
if err != nil {
return fmt.Errorf("failed parsing OIDC transformation template: %w", err)
}
return nil
}
func parseTransform(transformTemplate string) (*template.Template, error) {
if transformTemplate == "" {
transformTemplate = defaultTemplate
}
return template.New("transform").Funcs(x509util.GetFuncMap()).Parse(transformTemplate)
}
func (o *OIDCOptions) EvaluateTarget(deviceID string) (string, error) {
buf := new(bytes.Buffer)
if err := o.target.Execute(buf, struct{ DeviceID string }{DeviceID: deviceID}); err != nil {
@ -91,6 +110,28 @@ func (o *OIDCOptions) EvaluateTarget(deviceID string) (string, error) {
return buf.String(), nil
}
func (o *OIDCOptions) Transform(v map[string]any) (map[string]any, error) {
if o.transform == nil || v == nil {
return v, nil
}
// TODO(hs): add support for extracting error message from template "fail" function?
buf := new(bytes.Buffer)
if err := o.transform.Execute(buf, v); err != nil {
return nil, fmt.Errorf("failed executing OIDC transformation: %w", err)
}
var r map[string]any
if err := json.Unmarshal(buf.Bytes(), &r); err != nil {
return nil, fmt.Errorf("failed unmarshaling transformed OIDC token: %w", err)
}
// add original claims if not yet in the transformed result
for key, value := range v {
if _, ok := r[key]; !ok {
r[key] = value
}
}
return r, nil
}
func toOIDCProviderConfig(in *Provider) (*oidc.ProviderConfig, error) {
issuerURL, err := url.Parse(in.IssuerURL)
if err != nil {

@ -0,0 +1,123 @@
package wire
import (
"testing"
"text/template"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOIDCOptions_Transform(t *testing.T) {
defaultTransform, err := parseTransform(``)
require.NoError(t, err)
swapTransform, err := parseTransform(`{"name": "{{ .preferred_username }}", "handle": "{{ .name }}"}`)
require.NoError(t, err)
funcTransform, err := parseTransform(`{"name": "{{ .name }}", "handle": "{{ first .usernames }}"}`)
require.NoError(t, err)
type fields struct {
transform *template.Template
}
type args struct {
v map[string]any
}
tests := []struct {
name string
fields fields
args args
want map[string]any
expectedErr error
}{
{
name: "ok/no-transform",
fields: fields{
transform: nil,
},
args: args{
v: map[string]any{
"name": "Example",
"preferred_username": "Preferred",
},
},
want: map[string]any{
"name": "Example",
"preferred_username": "Preferred",
},
},
{
name: "ok/empty-data",
fields: fields{
transform: nil,
},
args: args{
v: map[string]any{},
},
want: map[string]any{},
},
{
name: "ok/default-transform",
fields: fields{
transform: defaultTransform,
},
args: args{
v: map[string]any{
"name": "Example",
"preferred_username": "Preferred",
},
},
want: map[string]any{
"name": "Example",
"handle": "Preferred",
"preferred_username": "Preferred",
},
},
{
name: "ok/swap-transform",
fields: fields{
transform: swapTransform,
},
args: args{
v: map[string]any{
"name": "Example",
"preferred_username": "Preferred",
},
},
want: map[string]any{
"name": "Preferred",
"handle": "Example",
"preferred_username": "Preferred",
},
},
{
name: "ok/transform-with-functions",
fields: fields{
transform: funcTransform,
},
args: args{
v: map[string]any{
"name": "Example",
"usernames": []string{"name-1", "name-2", "name-3"},
},
},
want: map[string]any{
"name": "Example",
"handle": "name-1",
"usernames": []string{"name-1", "name-2", "name-3"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := &OIDCOptions{
transform: tt.fields.transform,
}
got, err := o.Transform(tt.args.v)
if tt.expectedErr != nil {
assert.Error(t, err)
return
}
assert.Equal(t, tt.want, got)
})
}
}

@ -83,6 +83,22 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
},
expectedErr: errors.New(`failed initializing OIDC options: failed parsing OIDC template: template: DeviceID:1: unexpected "}" in command`),
},
{
name: "fail/invalid-transform-template",
fields: fields{
OIDC: &OIDCOptions{
Provider: &Provider{
IssuerURL: "https://example.com",
},
Config: &Config{},
TransformTemplate: "{{}",
},
DPOP: &DPOPOptions{
SigningKey: key,
},
},
expectedErr: errors.New(`failed initializing OIDC options: failed parsing OIDC transformation template: template: transform:1: unexpected "}" in command`),
},
{
name: "fail/no-dpop-options",
fields: fields{

Loading…
Cancel
Save