Merge pull request #1673 from smallstep/herman/wire-template-transform
Add OIDC token template transformationpull/1670/head
commit
17578b57f2
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,394 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
certificatesdb "github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDB_GetDpopToken(t *testing.T) {
|
||||
type test struct {
|
||||
db *DB
|
||||
orderID string
|
||||
expected map[string]any
|
||||
expectedErr error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/acme-not-found": func(t *testing.T) test {
|
||||
dir := t.TempDir()
|
||||
db, err := nosql.New("badgerv2", dir)
|
||||
require.NoError(t, err)
|
||||
return test{
|
||||
db: &DB{
|
||||
db: db,
|
||||
},
|
||||
orderID: "orderID",
|
||||
expectedErr: &acme.Error{
|
||||
Type: "urn:ietf:params:acme:error:malformed",
|
||||
Status: 400,
|
||||
Detail: "The request message was malformed",
|
||||
Err: errors.New(`dpop token "orderID" not found`),
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
dir := t.TempDir()
|
||||
db, err := nosql.New("badgerv2", dir)
|
||||
require.NoError(t, err)
|
||||
token := dbDpopToken{
|
||||
ID: "orderID",
|
||||
Content: []byte("{}"),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
b, err := json.Marshal(token)
|
||||
require.NoError(t, err)
|
||||
err = db.Set(wireDpopTokenTable, []byte("orderID"), b[1:]) // start at index 1; corrupt JSON data
|
||||
require.NoError(t, err)
|
||||
return test{
|
||||
db: &DB{
|
||||
db: db,
|
||||
},
|
||||
orderID: "orderID",
|
||||
expectedErr: errors.New(`error unmarshaling dpop "orderID" into dbDpopToken: invalid character ':' after top-level value`),
|
||||
}
|
||||
},
|
||||
"fail/db.Get": func(t *testing.T) test {
|
||||
db := &certificatesdb.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equal(t, wireDpopTokenTable, bucket)
|
||||
assert.Equal(t, []byte("orderID"), key)
|
||||
return nil, errors.New("fail")
|
||||
},
|
||||
}
|
||||
return test{
|
||||
db: &DB{
|
||||
db: db,
|
||||
},
|
||||
orderID: "orderID",
|
||||
expectedErr: errors.New(`error loading dpop "orderID": fail`),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
dir := t.TempDir()
|
||||
db, err := nosql.New("badgerv2", dir)
|
||||
require.NoError(t, err)
|
||||
token := dbDpopToken{
|
||||
ID: "orderID",
|
||||
Content: []byte(`{"sub": "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com"}`),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
b, err := json.Marshal(token)
|
||||
require.NoError(t, err)
|
||||
err = db.Set(wireDpopTokenTable, []byte("orderID"), b)
|
||||
require.NoError(t, err)
|
||||
return test{
|
||||
db: &DB{
|
||||
db: db,
|
||||
},
|
||||
orderID: "orderID",
|
||||
expected: map[string]any{
|
||||
"sub": "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
got, err := tc.db.GetDpopToken(context.Background(), tc.orderID)
|
||||
if tc.expectedErr != nil {
|
||||
assert.EqualError(t, err, tc.expectedErr.Error())
|
||||
ae := &acme.Error{}
|
||||
if errors.As(err, &ae) {
|
||||
ee := &acme.Error{}
|
||||
require.True(t, errors.As(tc.expectedErr, &ee))
|
||||
assert.Equal(t, ee.Detail, ae.Detail)
|
||||
assert.Equal(t, ee.Type, ae.Type)
|
||||
assert.Equal(t, ee.Status, ae.Status)
|
||||
}
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_CreateDpopToken(t *testing.T) {
|
||||
type test struct {
|
||||
db *DB
|
||||
orderID string
|
||||
dpop map[string]any
|
||||
expectedErr error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.Save": func(t *testing.T) test {
|
||||
db := &certificatesdb.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equal(t, wireDpopTokenTable, bucket)
|
||||
assert.Equal(t, []byte("orderID"), key)
|
||||
return nil, false, errors.New("fail")
|
||||
},
|
||||
}
|
||||
return test{
|
||||
db: &DB{
|
||||
db: db,
|
||||
},
|
||||
orderID: "orderID",
|
||||
dpop: map[string]any{
|
||||
"sub": "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com",
|
||||
},
|
||||
expectedErr: errors.New("failed saving dpop token: error saving acme dpop: fail"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
dir := t.TempDir()
|
||||
db, err := nosql.New("badgerv2", dir)
|
||||
require.NoError(t, err)
|
||||
return test{
|
||||
db: &DB{
|
||||
db: db,
|
||||
},
|
||||
orderID: "orderID",
|
||||
dpop: map[string]any{
|
||||
"sub": "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com",
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/nil": func(t *testing.T) test {
|
||||
dir := t.TempDir()
|
||||
db, err := nosql.New("badgerv2", dir)
|
||||
require.NoError(t, err)
|
||||
return test{
|
||||
db: &DB{
|
||||
db: db,
|
||||
},
|
||||
orderID: "orderID",
|
||||
dpop: nil,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := tc.db.CreateDpopToken(context.Background(), tc.orderID, tc.dpop)
|
||||
if tc.expectedErr != nil {
|
||||
assert.EqualError(t, err, tc.expectedErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
dpop, err := tc.db.getDBDpopToken(context.Background(), tc.orderID)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.orderID, dpop.ID)
|
||||
var m map[string]any
|
||||
err = json.Unmarshal(dpop.Content, &m)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.dpop, m)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_GetOidcToken(t *testing.T) {
|
||||
type test struct {
|
||||
db *DB
|
||||
orderID string
|
||||
expected map[string]any
|
||||
expectedErr error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/acme-not-found": func(t *testing.T) test {
|
||||
dir := t.TempDir()
|
||||
db, err := nosql.New("badgerv2", dir)
|
||||
require.NoError(t, err)
|
||||
return test{
|
||||
db: &DB{
|
||||
db: db,
|
||||
},
|
||||
orderID: "orderID",
|
||||
expectedErr: &acme.Error{
|
||||
Type: "urn:ietf:params:acme:error:malformed",
|
||||
Status: 400,
|
||||
Detail: "The request message was malformed",
|
||||
Err: errors.New(`oidc token "orderID" not found`),
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
dir := t.TempDir()
|
||||
db, err := nosql.New("badgerv2", dir)
|
||||
require.NoError(t, err)
|
||||
token := dbOidcToken{
|
||||
ID: "orderID",
|
||||
Content: []byte("{}"),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
b, err := json.Marshal(token)
|
||||
require.NoError(t, err)
|
||||
err = db.Set(wireOidcTokenTable, []byte("orderID"), b[1:]) // start at index 1; corrupt JSON data
|
||||
require.NoError(t, err)
|
||||
return test{
|
||||
db: &DB{
|
||||
db: db,
|
||||
},
|
||||
orderID: "orderID",
|
||||
expectedErr: errors.New(`error unmarshaling oidc token "orderID" into dbOidcToken: invalid character ':' after top-level value`),
|
||||
}
|
||||
},
|
||||
"fail/db.Get": func(t *testing.T) test {
|
||||
db := &certificatesdb.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equal(t, wireOidcTokenTable, bucket)
|
||||
assert.Equal(t, []byte("orderID"), key)
|
||||
return nil, errors.New("fail")
|
||||
},
|
||||
}
|
||||
return test{
|
||||
db: &DB{
|
||||
db: db,
|
||||
},
|
||||
orderID: "orderID",
|
||||
expectedErr: errors.New(`error loading oidc token "orderID": fail`),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
dir := t.TempDir()
|
||||
db, err := nosql.New("badgerv2", dir)
|
||||
require.NoError(t, err)
|
||||
token := dbOidcToken{
|
||||
ID: "orderID",
|
||||
Content: []byte(`{"name": "Alice Smith", "handle": "@alice.smith"}`),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
b, err := json.Marshal(token)
|
||||
require.NoError(t, err)
|
||||
err = db.Set(wireOidcTokenTable, []byte("orderID"), b)
|
||||
require.NoError(t, err)
|
||||
return test{
|
||||
db: &DB{
|
||||
db: db,
|
||||
},
|
||||
orderID: "orderID",
|
||||
expected: map[string]any{
|
||||
"name": "Alice Smith",
|
||||
"handle": "@alice.smith",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
got, err := tc.db.GetOidcToken(context.Background(), tc.orderID)
|
||||
if tc.expectedErr != nil {
|
||||
assert.EqualError(t, err, tc.expectedErr.Error())
|
||||
ae := &acme.Error{}
|
||||
if errors.As(err, &ae) {
|
||||
ee := &acme.Error{}
|
||||
require.True(t, errors.As(tc.expectedErr, &ee))
|
||||
assert.Equal(t, ee.Detail, ae.Detail)
|
||||
assert.Equal(t, ee.Type, ae.Type)
|
||||
assert.Equal(t, ee.Status, ae.Status)
|
||||
}
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_CreateOidcToken(t *testing.T) {
|
||||
type test struct {
|
||||
db *DB
|
||||
orderID string
|
||||
oidc map[string]any
|
||||
expectedErr error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.Save": func(t *testing.T) test {
|
||||
db := &certificatesdb.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equal(t, wireOidcTokenTable, bucket)
|
||||
assert.Equal(t, []byte("orderID"), key)
|
||||
return nil, false, errors.New("fail")
|
||||
},
|
||||
}
|
||||
return test{
|
||||
db: &DB{
|
||||
db: db,
|
||||
},
|
||||
orderID: "orderID",
|
||||
oidc: map[string]any{
|
||||
"name": "Alice Smith",
|
||||
"handle": "@alice.smith",
|
||||
},
|
||||
expectedErr: errors.New("failed saving oidc token: error saving acme oidc: fail"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
dir := t.TempDir()
|
||||
db, err := nosql.New("badgerv2", dir)
|
||||
require.NoError(t, err)
|
||||
return test{
|
||||
db: &DB{
|
||||
db: db,
|
||||
},
|
||||
orderID: "orderID",
|
||||
oidc: map[string]any{
|
||||
"name": "Alice Smith",
|
||||
"handle": "@alice.smith",
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/nil": func(t *testing.T) test {
|
||||
dir := t.TempDir()
|
||||
db, err := nosql.New("badgerv2", dir)
|
||||
require.NoError(t, err)
|
||||
return test{
|
||||
db: &DB{
|
||||
db: db,
|
||||
},
|
||||
orderID: "orderID",
|
||||
oidc: nil,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := tc.db.CreateOidcToken(context.Background(), tc.orderID, tc.oidc)
|
||||
if tc.expectedErr != nil {
|
||||
assert.EqualError(t, err, tc.expectedErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
oidc, err := tc.db.getDBOidcToken(context.Background(), tc.orderID)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.orderID, oidc.ID)
|
||||
var m map[string]any
|
||||
err = json.Unmarshal(oidc.Content, &m)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.oidc, m)
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,58 @@
|
||||
package wire
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseID(t *testing.T) {
|
||||
ok := `{"name": "Alice Smith", "domain": "wire.com", "client-id": "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", "handle": "wireapp://%40alice_wire@wire.com"}`
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
wantWireID ID
|
||||
expectedErr error
|
||||
}{
|
||||
{name: "ok", data: []byte(ok), wantWireID: ID{Name: "Alice Smith", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com"}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotWireID, err := ParseID(tt.data)
|
||||
if tt.expectedErr != nil {
|
||||
assert.EqualError(t, err, tt.expectedErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantWireID, gotWireID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClientID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientID string
|
||||
want ClientID
|
||||
expectedErr error
|
||||
}{
|
||||
{name: "ok", clientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", want: ClientID{Scheme: "wireapp", Username: "CzbfFjDOQrenCbDxVmgnFw", DeviceID: "594930e9d50bb175", Domain: "wire.com"}},
|
||||
{name: "fail/uri", clientID: "bla", expectedErr: errors.New(`invalid Wire client ID URI "bla": error parsing bla: scheme is missing`)},
|
||||
{name: "fail/scheme", clientID: "not-wireapp://bla.com", expectedErr: errors.New(`invalid Wire client ID scheme "not-wireapp"; expected "wireapp"`)},
|
||||
{name: "fail/username", clientID: "wireapp://user@wire.com", expectedErr: errors.New(`invalid Wire client ID username "user"`)},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseClientID(tt.clientID)
|
||||
if tt.expectedErr != nil {
|
||||
assert.EqualError(t, err, tt.expectedErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,123 @@
|
||||
package wire
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"text/template"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOIDCOptions_Transform(t *testing.T) {
|
||||
defaultTransform, err := parseTransform(``)
|
||||
require.NoError(t, err)
|
||||
swapTransform, err := parseTransform(`{"name": "{{ .preferred_username }}", "handle": "{{ .name }}"}`)
|
||||
require.NoError(t, err)
|
||||
funcTransform, err := parseTransform(`{"name": "{{ .name }}", "handle": "{{ first .usernames }}"}`)
|
||||
require.NoError(t, err)
|
||||
type fields struct {
|
||||
transform *template.Template
|
||||
}
|
||||
type args struct {
|
||||
v map[string]any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want map[string]any
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "ok/no-transform",
|
||||
fields: fields{
|
||||
transform: nil,
|
||||
},
|
||||
args: args{
|
||||
v: map[string]any{
|
||||
"name": "Example",
|
||||
"preferred_username": "Preferred",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"name": "Example",
|
||||
"preferred_username": "Preferred",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok/empty-data",
|
||||
fields: fields{
|
||||
transform: nil,
|
||||
},
|
||||
args: args{
|
||||
v: map[string]any{},
|
||||
},
|
||||
want: map[string]any{},
|
||||
},
|
||||
{
|
||||
name: "ok/default-transform",
|
||||
fields: fields{
|
||||
transform: defaultTransform,
|
||||
},
|
||||
args: args{
|
||||
v: map[string]any{
|
||||
"name": "Example",
|
||||
"preferred_username": "Preferred",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"name": "Example",
|
||||
"handle": "Preferred",
|
||||
"preferred_username": "Preferred",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok/swap-transform",
|
||||
fields: fields{
|
||||
transform: swapTransform,
|
||||
},
|
||||
args: args{
|
||||
v: map[string]any{
|
||||
"name": "Example",
|
||||
"preferred_username": "Preferred",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"name": "Preferred",
|
||||
"handle": "Example",
|
||||
"preferred_username": "Preferred",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok/transform-with-functions",
|
||||
fields: fields{
|
||||
transform: funcTransform,
|
||||
},
|
||||
args: args{
|
||||
v: map[string]any{
|
||||
"name": "Example",
|
||||
"usernames": []string{"name-1", "name-2", "name-3"},
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"name": "Example",
|
||||
"handle": "name-1",
|
||||
"usernames": []string{"name-1", "name-2", "name-3"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
o := &OIDCOptions{
|
||||
transform: tt.fields.transform,
|
||||
}
|
||||
got, err := o.Transform(tt.args.v)
|
||||
if tt.expectedErr != nil {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue