testing work in progress.
parent
83848e9cd3
commit
54d86ca1c1
@ -0,0 +1,184 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
// func Test_newSortedProvisioners(t *testing.T) {
|
||||
// provisioners := make(List, 20)
|
||||
// for i := range provisioners {
|
||||
// provisioners[i] = generateProvisioner(t)
|
||||
// }
|
||||
|
||||
// ps, err := newSortedProvisioners(provisioners)
|
||||
// assert.FatalError(t, err)
|
||||
// prev := ""
|
||||
// for i, p := range ps {
|
||||
// if p.uid < prev {
|
||||
// t.Errorf("%s should be less that %s", p.uid, prev)
|
||||
// }
|
||||
// if p.provisioner.Key.KeyID != provisioners[i].Key.KeyID {
|
||||
// t.Errorf("provisioner order is not the same: %s != %s", p.provisioner.Key.KeyID, provisioners[i].Key.KeyID)
|
||||
// }
|
||||
// prev = p.uid
|
||||
// }
|
||||
// }
|
||||
|
||||
// func Test_provisionerSlice_Find(t *testing.T) {
|
||||
// trim := func(s string) string {
|
||||
// return strings.TrimLeft(s, "0")
|
||||
// }
|
||||
// provisioners := make([]*Provisioner, 20)
|
||||
// for i := range provisioners {
|
||||
// provisioners[i] = generateProvisioner(t)
|
||||
// }
|
||||
// ps, err := newSortedProvisioners(provisioners)
|
||||
// assert.FatalError(t, err)
|
||||
|
||||
// type args struct {
|
||||
// cursor string
|
||||
// limit int
|
||||
// }
|
||||
// tests := []struct {
|
||||
// name string
|
||||
// p provisionerSlice
|
||||
// args args
|
||||
// want []*JWK
|
||||
// want1 string
|
||||
// }{
|
||||
// {"all", ps, args{"", DefaultProvisionersMax}, provisioners[0:20], ""},
|
||||
// {"0 to 19", ps, args{"", 20}, provisioners[0:20], ""},
|
||||
// {"0 to 9", ps, args{"", 10}, provisioners[0:10], trim(ps[10].uid)},
|
||||
// {"9 to 19", ps, args{trim(ps[10].uid), 10}, provisioners[10:20], ""},
|
||||
// {"1", ps, args{trim(ps[1].uid), 1}, provisioners[1:2], trim(ps[2].uid)},
|
||||
// {"1 to 5", ps, args{trim(ps[1].uid), 4}, provisioners[1:5], trim(ps[5].uid)},
|
||||
// {"defaultLimit", ps, args{"", 0}, provisioners[0:20], ""},
|
||||
// {"overTheLimit", ps, args{"", DefaultProvisionersMax + 1}, provisioners[0:20], ""},
|
||||
// }
|
||||
// for _, tt := range tests {
|
||||
// t.Run(tt.name, func(t *testing.T) {
|
||||
// got, got1 := tt.p.Find(tt.args.cursor, tt.args.limit)
|
||||
// if !reflect.DeepEqual(got, tt.want) {
|
||||
// t.Errorf("provisionerSlice.Find() got = %v, want %v", got, tt.want)
|
||||
// }
|
||||
// if got1 != tt.want1 {
|
||||
// t.Errorf("provisionerSlice.Find() got1 = %v, want %v", got1, tt.want1)
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
func TestCollection_Load(t *testing.T) {
|
||||
p, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
byID := new(sync.Map)
|
||||
byID.Store(p.GetID(), p)
|
||||
byID.Store("string", "a-string")
|
||||
|
||||
type fields struct {
|
||||
byID *sync.Map
|
||||
}
|
||||
type args struct {
|
||||
id string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want Interface
|
||||
want1 bool
|
||||
}{
|
||||
{"ok", fields{byID}, args{p.GetID()}, p, true},
|
||||
{"fail", fields{byID}, args{"fail"}, nil, false},
|
||||
{"invalid", fields{byID}, args{"string"}, nil, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Collection{
|
||||
byID: tt.fields.byID,
|
||||
}
|
||||
got, got1 := c.Load(tt.args.id)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Collection.Load() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Collection.Load() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_LoadByToken(t *testing.T) {
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p3, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
byID := new(sync.Map)
|
||||
byID.Store(p1.GetID(), p1)
|
||||
byID.Store(p2.GetID(), p2)
|
||||
byID.Store(p3.GetID(), p3)
|
||||
byID.Store("string", "a-string")
|
||||
|
||||
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
token, err := generateSimpleToken(p1.Name, testAudiences[0], jwk)
|
||||
assert.FatalError(t, err)
|
||||
t1, c1, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
jwk, err = decryptJSONWebKey(p2.EncryptedKey)
|
||||
token, err = generateSimpleToken(p2.Name, testAudiences[1], jwk)
|
||||
assert.FatalError(t, err)
|
||||
t2, c2, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
t3, c3, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type fields struct {
|
||||
byID *sync.Map
|
||||
audiences []string
|
||||
}
|
||||
type args struct {
|
||||
token *jose.JSONWebToken
|
||||
claims *jose.Claims
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want Interface
|
||||
want1 bool
|
||||
}{
|
||||
{"ok1", fields{byID, testAudiences}, args{t1, c1}, p1, true},
|
||||
{"ok2", fields{byID, testAudiences}, args{t2, c2}, p2, true},
|
||||
{"ok3", fields{byID, testAudiences}, args{t3, c3}, p3, true},
|
||||
{"fail", fields{byID, []string{"https://foo"}}, args{t1, c1}, nil, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Collection{
|
||||
byID: tt.fields.byID,
|
||||
audiences: tt.fields.audiences,
|
||||
}
|
||||
got, got1 := c.LoadByToken(tt.args.token, tt.args.claims)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Collection.LoadByToken() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Collection.LoadByToken() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,4 +1,6 @@
|
||||
package authority
|
||||
// +build ignore
|
||||
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509/pkix"
|
@ -0,0 +1,11 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBhzCCASygAwIBAgIRANJiwPnM38wWznkJGOcIyIYwCgYIKoZIzj0EAwIwITEf
|
||||
MB0GA1UEAxMWU21hbGxzdGVwIFRlc3QgUm9vdCBDQTAeFw0xODA5MjcxODE4MDla
|
||||
Fw0yODA5MjQxODE4MDlaMCExHzAdBgNVBAMTFlNtYWxsc3RlcCBUZXN0IFJvb3Qg
|
||||
Q0EwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAS15w7dx9zPjCnQ7+RlRkvUXQJN
|
||||
Fjk5Hg5K9nCoiiNQQhcQMw63/pXQxHNsugiMshcN59XJC8195KJPm25nXN8co0Uw
|
||||
QzAOBgNVHQ8BAf8EBAMCAaYwEgYDVR0TAQH/BAgwBgEB/wIBATAdBgNVHQ4EFgQU
|
||||
B2BAXUSPZbFjnY6VzbApV48Tn3owCgYIKoZIzj0EAwIDSQAwRgIhAJRTVmc2xW8c
|
||||
ESx4oIp2d/OX9KBZzpcNi9fHnnJCS0FXAiEA7OpFb2+b8KBzg1c02x21PS7pHoET
|
||||
/A8LXNH4M06A7vE=
|
||||
-----END CERTIFICATE-----
|
@ -0,0 +1,208 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"github.com/smallstep/cli/token"
|
||||
"github.com/smallstep/cli/token/provision"
|
||||
)
|
||||
|
||||
var testAudiences = []string{
|
||||
"https://ca.smallstep.com/sign",
|
||||
"https://ca.smallsteomcom/1.0/sign",
|
||||
}
|
||||
|
||||
func generateJSONWebKey() (*jose.JSONWebKey, error) {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fp, err := jwk.Thumbprint(crypto.SHA256)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwk.KeyID = string(hex.EncodeToString(fp))
|
||||
return jwk, nil
|
||||
}
|
||||
|
||||
func encryptJSONWebKey(jwk *jose.JSONWebKey) (*jose.JSONWebEncryption, error) {
|
||||
b, err := json.Marshal(jwk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
salt, err := randutil.Salt(jose.PBKDF2SaltSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opts := new(jose.EncrypterOptions)
|
||||
opts.WithContentType(jose.ContentType("jwk+json"))
|
||||
recipient := jose.Recipient{
|
||||
Algorithm: jose.PBES2_HS256_A128KW,
|
||||
Key: []byte("password"),
|
||||
PBES2Count: jose.PBKDF2Iterations,
|
||||
PBES2Salt: salt,
|
||||
}
|
||||
encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encrypter.Encrypt(b)
|
||||
}
|
||||
|
||||
func decryptJSONWebKey(key string) (*jose.JSONWebKey, error) {
|
||||
enc, err := jose.ParseEncrypted(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b, err := enc.Decrypt([]byte("password"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwk := new(jose.JSONWebKey)
|
||||
if err := json.Unmarshal(b, jwk); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jwk, nil
|
||||
}
|
||||
|
||||
func generateJWK() (*JWK, error) {
|
||||
name, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwk, err := generateJSONWebKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwe, err := encryptJSONWebKey(jwk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
public := jwk.Public()
|
||||
encrypted, err := jwe.CompactSerialize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &JWK{
|
||||
Name: name,
|
||||
Type: "JWK",
|
||||
Key: &public,
|
||||
EncryptedKey: encrypted,
|
||||
audiences: testAudiences,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func generateOIDC() (*OIDC, error) {
|
||||
name, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientID, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
issuer, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwk, err := generateJSONWebKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OIDC{
|
||||
Name: name,
|
||||
Type: "OIDC",
|
||||
ClientID: clientID,
|
||||
ConfigurationEndpoint: "https://example.com/.well-known/openid-configuration",
|
||||
configuration: openIDConfiguration{
|
||||
Issuer: issuer,
|
||||
JWKSetURI: "https://example.com/.well-known/jwks",
|
||||
},
|
||||
keyStore: &keyStore{
|
||||
keys: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
||||
expiry: time.Now().Add(24 * time.Hour),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func generateCollection(nJWK, nOIDC int) (*Collection, error) {
|
||||
col := NewCollection(testAudiences)
|
||||
for i := 0; i < nJWK; i++ {
|
||||
p, err := generateJWK()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
col.Store(p)
|
||||
}
|
||||
for i := 0; i < nOIDC; i++ {
|
||||
p, err := generateOIDC()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
col.Store(p)
|
||||
}
|
||||
return col, nil
|
||||
}
|
||||
|
||||
func generateSimpleToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) {
|
||||
now := time.Now()
|
||||
return generateToken("the-sub", []string{"test.smallstep.com"}, jwk.KeyID, iss, aud, "testdata/root_ca.crt", now, now.Add(5*time.Minute), jwk)
|
||||
}
|
||||
|
||||
func generateToken(sub string, sans []string, kid, iss, aud, root string, notBefore, notAfter time.Time, jwk *jose.JSONWebKey) (string, error) {
|
||||
// A random jwt id will be used to identify duplicated tokens
|
||||
jwtID, err := randutil.Hex(64) // 256 bits
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
tokOptions := []token.Options{
|
||||
token.WithJWTID(jwtID),
|
||||
token.WithKid(kid),
|
||||
token.WithIssuer(iss),
|
||||
token.WithAudience(aud),
|
||||
}
|
||||
if len(root) > 0 {
|
||||
tokOptions = append(tokOptions, token.WithRootCA(root))
|
||||
}
|
||||
|
||||
// If there are no SANs then add the 'subject' (common-name) as the only SAN.
|
||||
if len(sans) == 0 {
|
||||
sans = []string{sub}
|
||||
}
|
||||
|
||||
tokOptions = append(tokOptions, token.WithSANS(sans))
|
||||
if !notBefore.IsZero() || !notAfter.IsZero() {
|
||||
if notBefore.IsZero() {
|
||||
notBefore = time.Now()
|
||||
}
|
||||
if notAfter.IsZero() {
|
||||
notAfter = notBefore.Add(token.DefaultValidity)
|
||||
}
|
||||
tokOptions = append(tokOptions, token.WithValidity(notBefore, notAfter))
|
||||
}
|
||||
|
||||
tok, err := provision.New(sub, tokOptions...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tok.SignedString(jwk.Algorithm, jwk.Key)
|
||||
}
|
||||
|
||||
func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) {
|
||||
tok, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
claims := new(jose.Claims)
|
||||
if err := tok.UnsafeClaimsWithoutVerification(claims); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return tok, claims, nil
|
||||
}
|
Loading…
Reference in New Issue