Merge pull request #65 from smallstep/cloud-identities

Cloud identities
pull/77/head^2 v0.11.0-rc.1
Mariano Cano 5 years ago committed by GitHub
commit 578beec25d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -58,14 +58,7 @@ func (a *Authority) authorizeToken(ott string) (provisioner.Interface, error) {
}
// Store the token to protect against reuse.
var reuseKey string
switch p.GetType() {
case provisioner.TypeJWK:
reuseKey = claims.ID
case provisioner.TypeOIDC:
reuseKey = claims.Nonce
}
if reuseKey != "" {
if reuseKey, err := p.GetTokenID(ott); err == nil {
ok, err := a.db.UseToken(reuseKey, ott)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeToken: failed when checking if token already used"),

@ -0,0 +1,427 @@
package provisioner
import (
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"io/ioutil"
"net"
"net/http"
"strings"
"time"
"github.com/pkg/errors"
"github.com/smallstep/cli/jose"
)
// awsIssuer is the string used as issuer in the generated tokens.
const awsIssuer = "ec2.amazonaws.com"
// awsIdentityURL is the url used to retrieve the instance identity document.
const awsIdentityURL = "http://169.254.169.254/latest/dynamic/instance-identity/document"
// awsSignatureURL is the url used to retrieve the instance identity signature.
const awsSignatureURL = "http://169.254.169.254/latest/dynamic/instance-identity/signature"
// awsCertificate is the certificate used to validate the instance identity
// signature.
const awsCertificate = `-----BEGIN CERTIFICATE-----
MIIDIjCCAougAwIBAgIJAKnL4UEDMN/FMA0GCSqGSIb3DQEBBQUAMGoxCzAJBgNV
BAYTAlVTMRMwEQYDVQQIEwpXYXNoaW5ndG9uMRAwDgYDVQQHEwdTZWF0dGxlMRgw
FgYDVQQKEw9BbWF6b24uY29tIEluYy4xGjAYBgNVBAMTEWVjMi5hbWF6b25hd3Mu
Y29tMB4XDTE0MDYwNTE0MjgwMloXDTI0MDYwNTE0MjgwMlowajELMAkGA1UEBhMC
VVMxEzARBgNVBAgTCldhc2hpbmd0b24xEDAOBgNVBAcTB1NlYXR0bGUxGDAWBgNV
BAoTD0FtYXpvbi5jb20gSW5jLjEaMBgGA1UEAxMRZWMyLmFtYXpvbmF3cy5jb20w
gZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAIe9GN//SRK2knbjySG0ho3yqQM3
e2TDhWO8D2e8+XZqck754gFSo99AbT2RmXClambI7xsYHZFapbELC4H91ycihvrD
jbST1ZjkLQgga0NE1q43eS68ZeTDccScXQSNivSlzJZS8HJZjgqzBlXjZftjtdJL
XeE4hwvo0sD4f3j9AgMBAAGjgc8wgcwwHQYDVR0OBBYEFCXWzAgVyrbwnFncFFIs
77VBdlE4MIGcBgNVHSMEgZQwgZGAFCXWzAgVyrbwnFncFFIs77VBdlE4oW6kbDBq
MQswCQYDVQQGEwJVUzETMBEGA1UECBMKV2FzaGluZ3RvbjEQMA4GA1UEBxMHU2Vh
dHRsZTEYMBYGA1UEChMPQW1hem9uLmNvbSBJbmMuMRowGAYDVQQDExFlYzIuYW1h
em9uYXdzLmNvbYIJAKnL4UEDMN/FMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEF
BQADgYEAFYcz1OgEhQBXIwIdsgCOS8vEtiJYF+j9uO6jz7VOmJqO+pRlAbRlvY8T
C1haGgSI/A1uZUKs/Zfnph0oEI0/hu1IIJ/SKBDtN5lvmZ/IzbOPIJWirlsllQIQ
7zvWbGd9c9+Rm3p04oTvhup99la7kZqevJK0QRdD/6NpCKsqP/0=
-----END CERTIFICATE-----`
// awsSignatureAlgorithm is the signature algorithm used to verify the identity
// document signature.
const awsSignatureAlgorithm = x509.SHA256WithRSA
type awsConfig struct {
identityURL string
signatureURL string
certificate *x509.Certificate
signatureAlgorithm x509.SignatureAlgorithm
}
func newAWSConfig() (*awsConfig, error) {
block, _ := pem.Decode([]byte(awsCertificate))
if block == nil || block.Type != "CERTIFICATE" {
return nil, errors.New("error decoding AWS certificate")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, errors.Wrap(err, "error parsing AWS certificate")
}
return &awsConfig{
identityURL: awsIdentityURL,
signatureURL: awsSignatureURL,
certificate: cert,
signatureAlgorithm: awsSignatureAlgorithm,
}, nil
}
type awsPayload struct {
jose.Claims
Amazon awsAmazonPayload `json:"amazon"`
SANs []string `json:"sans"`
document awsInstanceIdentityDocument
}
type awsAmazonPayload struct {
Document []byte `json:"document"`
Signature []byte `json:"signature"`
}
type awsInstanceIdentityDocument struct {
AccountID string `json:"accountId"`
Architecture string `json:"architecture"`
AvailabilityZone string `json:"availabilityZone"`
BillingProducts []string `json:"billingProducts"`
DevpayProductCodes []string `json:"devpayProductCodes"`
ImageID string `json:"imageId"`
InstanceID string `json:"instanceId"`
InstanceType string `json:"instanceType"`
KernelID string `json:"kernelId"`
PendingTime time.Time `json:"pendingTime"`
PrivateIP string `json:"privateIp"`
RamdiskID string `json:"ramdiskId"`
Region string `json:"region"`
Version string `json:"version"`
}
// AWS is the provisioner that supports identity tokens created from the Amazon
// Web Services Instance Identity Documents.
//
// If DisableCustomSANs is true, only the internal DNS and IP will be added as a
// SAN. By default it will accept any SAN in the CSR.
//
// If DisableTrustOnFirstUse is true, multiple sign request for this provisioner
// with the same instance will be accepted. By default only the first request
// will be accepted.
//
// If InstanceAge is set, only the instances with a pendingTime within the given
// period will be accepted.
//
// Amazon Identity docs are available at
// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html
type AWS struct {
Type string `json:"type"`
Name string `json:"name"`
Accounts []string `json:"accounts"`
DisableCustomSANs bool `json:"disableCustomSANs"`
DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"`
InstanceAge Duration `json:"instanceAge,omitempty"`
Claims *Claims `json:"claims,omitempty"`
claimer *Claimer
config *awsConfig
audiences Audiences
}
// GetID returns the provisioner unique identifier.
func (p *AWS) GetID() string {
return "aws/" + p.Name
}
// GetTokenID returns the identifier of the token.
func (p *AWS) GetTokenID(token string) (string, error) {
payload, err := p.authorizeToken(token)
if err != nil {
return "", err
}
// If TOFU is disabled create an ID for the token, so it cannot be reused.
// The timestamps, document and signatures should be mostly unique.
if p.DisableTrustOnFirstUse {
sum := sha256.Sum256([]byte(token))
return strings.ToLower(hex.EncodeToString(sum[:])), nil
}
return payload.ID, nil
}
// GetName returns the name of the provisioner.
func (p *AWS) GetName() string {
return p.Name
}
// GetType returns the type of provisioner.
func (p *AWS) GetType() Type {
return TypeAWS
}
// GetEncryptedKey is not available in an AWS provisioner.
func (p *AWS) GetEncryptedKey() (kid string, key string, ok bool) {
return "", "", false
}
// GetIdentityToken retrieves the identity document and it's signature and
// generates a token with them.
func (p *AWS) GetIdentityToken(caURL string) (string, error) {
// Initialize the config if this method is used from the cli.
if err := p.assertConfig(); err != nil {
return "", err
}
var idoc awsInstanceIdentityDocument
doc, err := p.readURL(p.config.identityURL)
if err != nil {
return "", errors.Wrap(err, "error retrieving identity document, are you in an AWS VM?")
}
if err := json.Unmarshal(doc, &idoc); err != nil {
return "", errors.Wrap(err, "error unmarshaling identity document")
}
sig, err := p.readURL(p.config.signatureURL)
if err != nil {
return "", errors.Wrap(err, "error retrieving identity document signature, are you in an AWS VM?")
}
signature, err := base64.StdEncoding.DecodeString(string(sig))
if err != nil {
return "", errors.Wrap(err, "error decoding identity document signature")
}
if err := p.checkSignature(doc, signature); err != nil {
return "", err
}
audience, err := generateSignAudience(caURL, p.GetID())
if err != nil {
return "", err
}
// Create unique ID for Trust On First Use (TOFU). Only the first instance
// per provisioner is allowed as we don't have a way to trust the given
// sans.
unique := fmt.Sprintf("%s.%s", p.GetID(), idoc.InstanceID)
sum := sha256.Sum256([]byte(unique))
// Create a JWT from the identity document
signer, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.HS256, Key: signature},
new(jose.SignerOptions).WithType("JWT"),
)
if err != nil {
return "", errors.Wrap(err, "error creating signer")
}
now := time.Now()
payload := awsPayload{
Claims: jose.Claims{
Issuer: awsIssuer,
Subject: idoc.InstanceID,
Audience: []string{audience},
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
NotBefore: jose.NewNumericDate(now),
IssuedAt: jose.NewNumericDate(now),
ID: strings.ToLower(hex.EncodeToString(sum[:])),
},
Amazon: awsAmazonPayload{
Document: doc,
Signature: signature,
},
}
tok, err := jose.Signed(signer).Claims(payload).CompactSerialize()
if err != nil {
return "", errors.Wrap(err, "error serialiazing token")
}
return tok, nil
}
// Init validates and initializes the AWS provisioner.
func (p *AWS) Init(config Config) (err error) {
switch {
case p.Type == "":
return errors.New("provisioner type cannot be empty")
case p.Name == "":
return errors.New("provisioner name cannot be empty")
case p.InstanceAge.Value() < 0:
return errors.New("provisioner instanceAge cannot be negative")
}
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Add default config
if p.config, err = newAWSConfig(); err != nil {
return err
}
p.audiences = config.Audiences.WithFragment(p.GetID())
return nil
}
// AuthorizeSign validates the given token and returns the sign options that
// will be used on certificate creation.
func (p *AWS) AuthorizeSign(token string) ([]SignOption, error) {
payload, err := p.authorizeToken(token)
if err != nil {
return nil, err
}
doc := payload.document
// Enforce default DNS and IP if configured.
// By default we'll accept the SANs in the CSR.
// There's no way to trust them other than TOFU.
var so []SignOption
if p.DisableCustomSANs {
so = append(so, dnsNamesValidator([]string{
fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region),
}))
so = append(so, ipAddressesValidator([]net.IP{
net.ParseIP(doc.PrivateIP),
}))
}
return append(so,
commonNameValidator(doc.InstanceID),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID),
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
), nil
}
// AuthorizeRenewal returns an error if the renewal is disabled.
func (p *AWS) AuthorizeRenewal(cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
}
return nil
}
// AuthorizeRevoke returns an error because revoke is not supported on AWS
// provisioners.
func (p *AWS) AuthorizeRevoke(token string) error {
return errors.New("revoke is not supported on a AWS provisioner")
}
// assertConfig initializes the config if it has not been initialized
func (p *AWS) assertConfig() (err error) {
if p.config != nil {
return
}
p.config, err = newAWSConfig()
return err
}
// checkSignature returns an error if the signature is not valid.
func (p *AWS) checkSignature(signed, signature []byte) error {
if err := p.config.certificate.CheckSignature(p.config.signatureAlgorithm, signed, signature); err != nil {
return errors.Wrap(err, "error validating identity document signature")
}
return nil
}
// readURL does a GET request to the given url and returns the body. It's not
// using pkg/errors to avoid verbose errors, the caller should use it and write
// the appropriate error.
func (p *AWS) readURL(url string) ([]byte, error) {
r, err := http.Get(url)
if err != nil {
return nil, err
}
defer r.Body.Close()
b, err := ioutil.ReadAll(r.Body)
if err != nil {
return nil, err
}
return b, nil
}
// authorizeToken performs common jwt authorization actions and returns the
// claims for case specific downstream parsing.
// e.g. a Sign request will auth/validate different fields than a Revoke request.
func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrapf(err, "error parsing token")
}
if len(jwt.Headers) == 0 {
return nil, errors.New("error parsing token: header is missing")
}
var unsafeClaims awsPayload
if err := jwt.UnsafeClaimsWithoutVerification(&unsafeClaims); err != nil {
return nil, errors.Wrap(err, "error unmarshaling claims")
}
var payload awsPayload
if err := jwt.Claims(unsafeClaims.Amazon.Signature, &payload); err != nil {
return nil, errors.Wrap(err, "error verifying claims")
}
// Validate identity document signature
if err := p.checkSignature(payload.Amazon.Document, payload.Amazon.Signature); err != nil {
return nil, err
}
var doc awsInstanceIdentityDocument
if err := json.Unmarshal(payload.Amazon.Document, &doc); err != nil {
return nil, errors.Wrap(err, "error unmarshaling identity document")
}
switch {
case doc.AccountID == "":
return nil, errors.New("identity document accountId cannot be empty")
case doc.InstanceID == "":
return nil, errors.New("identity document instanceId cannot be empty")
case doc.PrivateIP == "":
return nil, errors.New("identity document privateIp cannot be empty")
case doc.Region == "":
return nil, errors.New("identity document region cannot be empty")
}
// According to "rfc7519 JSON Web Token" acceptable skew should be no
// more than a few minutes.
now := time.Now().UTC()
if err = payload.ValidateWithLeeway(jose.Expected{
Issuer: awsIssuer,
Subject: doc.InstanceID,
Time: now,
}, time.Minute); err != nil {
return nil, errors.Wrapf(err, "invalid token")
}
// validate audiences with the defaults
if !matchesAudience(payload.Audience, p.audiences.Sign) {
fmt.Println(payload.Audience, "vs", p.audiences.Sign)
return nil, errors.New("invalid token: invalid audience claim (aud)")
}
// validate accounts
if len(p.Accounts) > 0 {
var found bool
for _, sa := range p.Accounts {
if sa == doc.AccountID {
found = true
break
}
}
if !found {
return nil, errors.New("invalid identity document: accountId is not valid")
}
}
// validate instance age
if d := p.InstanceAge.Value(); d > 0 {
if now.Sub(doc.PendingTime) > d {
return nil, errors.New("identity document pendingTime is too old")
}
}
payload.document = doc
return &payload, nil
}

@ -0,0 +1,389 @@
package provisioner
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"fmt"
"net/url"
"strings"
"testing"
"time"
"github.com/smallstep/assert"
"github.com/smallstep/cli/jose"
)
func TestAWS_Getters(t *testing.T) {
p, err := generateAWS()
assert.FatalError(t, err)
aud := "aws/" + p.Name
if got := p.GetID(); got != aud {
t.Errorf("AWS.GetID() = %v, want %v", got, aud)
}
if got := p.GetName(); got != p.Name {
t.Errorf("AWS.GetName() = %v, want %v", got, p.Name)
}
if got := p.GetType(); got != TypeAWS {
t.Errorf("AWS.GetType() = %v, want %v", got, TypeAWS)
}
kid, key, ok := p.GetEncryptedKey()
if kid != "" || key != "" || ok == true {
t.Errorf("AWS.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
kid, key, ok, "", "", false)
}
}
func TestAWS_GetTokenID(t *testing.T) {
p1, srv, err := generateAWSWithServer()
assert.FatalError(t, err)
defer srv.Close()
p2, err := generateAWS()
assert.FatalError(t, err)
p2.Accounts = p1.Accounts
p2.config = p1.config
p2.DisableTrustOnFirstUse = true
t1, err := p1.GetIdentityToken("https://ca.smallstep.com")
assert.FatalError(t, err)
_, claims, err := parseAWSToken(t1)
assert.FatalError(t, err)
sum := sha256.Sum256([]byte(fmt.Sprintf("%s.%s", p1.GetID(), claims.document.InstanceID)))
w1 := strings.ToLower(hex.EncodeToString(sum[:]))
t2, err := p2.GetIdentityToken("https://ca.smallstep.com")
assert.FatalError(t, err)
sum = sha256.Sum256([]byte(t2))
w2 := strings.ToLower(hex.EncodeToString(sum[:]))
type args struct {
token string
}
tests := []struct {
name string
aws *AWS
args args
want string
wantErr bool
}{
{"ok", p1, args{t1}, w1, false},
{"ok no TOFU", p2, args{t2}, w2, false},
{"fail", p1, args{"bad-token"}, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.aws.GetTokenID(tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("AWS.GetTokenID() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("AWS.GetTokenID() = %v, want %v", got, tt.want)
}
})
}
}
func TestAWS_GetIdentityToken(t *testing.T) {
p1, srv, err := generateAWSWithServer()
assert.FatalError(t, err)
defer srv.Close()
p2, err := generateAWS()
assert.FatalError(t, err)
p2.Accounts = p1.Accounts
p2.config.identityURL = srv.URL + "/bad-document"
p2.config.signatureURL = p1.config.signatureURL
p3, err := generateAWS()
assert.FatalError(t, err)
p3.Accounts = p1.Accounts
p3.config.signatureURL = srv.URL
p3.config.identityURL = p1.config.identityURL
p4, err := generateAWS()
assert.FatalError(t, err)
p4.Accounts = p1.Accounts
p4.config.signatureURL = srv.URL + "/bad-signature"
p4.config.identityURL = p1.config.identityURL
caURL := "https://ca.smallstep.com"
u, err := url.Parse(caURL)
assert.FatalError(t, err)
type args struct {
caURL string
}
tests := []struct {
name string
aws *AWS
args args
wantErr bool
}{
{"ok", p1, args{caURL}, false},
{"fail ca url", p1, args{"://ca.smallstep.com"}, true},
{"fail identityURL", p2, args{caURL}, true},
{"fail signatureURL", p3, args{caURL}, true},
{"fail signature", p4, args{caURL}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.aws.GetIdentityToken(tt.args.caURL)
if (err != nil) != tt.wantErr {
t.Errorf("AWS.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr == false {
_, c, err := parseAWSToken(got)
if assert.NoError(t, err) {
assert.Equals(t, awsIssuer, c.Issuer)
assert.Equals(t, c.document.InstanceID, c.Subject)
assert.Equals(t, jose.Audience{u.ResolveReference(&url.URL{Path: "/1.0/sign", Fragment: tt.aws.GetID()}).String()}, c.Audience)
assert.Equals(t, tt.aws.Accounts[0], c.document.AccountID)
err = tt.aws.config.certificate.CheckSignature(
tt.aws.config.signatureAlgorithm, c.Amazon.Document, c.Amazon.Signature)
assert.NoError(t, err)
}
}
})
}
}
func TestAWS_Init(t *testing.T) {
config := Config{
Claims: globalProvisionerClaims,
}
badClaims := &Claims{
DefaultTLSDur: &Duration{0},
}
zero := Duration{Duration: 0}
type fields struct {
Type string
Name string
Accounts []string
DisableCustomSANs bool
DisableTrustOnFirstUse bool
InstanceAge Duration
Claims *Claims
}
type args struct {
config Config
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{"AWS", "name", []string{"account"}, false, false, zero, nil}, args{config}, false},
{"ok", fields{"AWS", "name", []string{"account"}, true, true, Duration{Duration: 1 * time.Minute}, nil}, args{config}, false},
{"fail type ", fields{"", "name", []string{"account"}, false, false, zero, nil}, args{config}, true},
{"fail name", fields{"AWS", "", []string{"account"}, false, false, zero, nil}, args{config}, true},
{"bad instance age", fields{"AWS", "name", []string{"account"}, false, false, Duration{Duration: -1 * time.Minute}, nil}, args{config}, true},
{"fail claims", fields{"AWS", "name", []string{"account"}, false, false, zero, badClaims}, args{config}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &AWS{
Type: tt.fields.Type,
Name: tt.fields.Name,
Accounts: tt.fields.Accounts,
DisableCustomSANs: tt.fields.DisableCustomSANs,
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
InstanceAge: tt.fields.InstanceAge,
Claims: tt.fields.Claims,
}
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
t.Errorf("AWS.Init() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestAWS_AuthorizeSign(t *testing.T) {
p1, srv, err := generateAWSWithServer()
assert.FatalError(t, err)
defer srv.Close()
p2, err := generateAWS()
assert.FatalError(t, err)
p2.Accounts = p1.Accounts
p2.config = p1.config
p2.DisableCustomSANs = true
p2.InstanceAge = Duration{1 * time.Minute}
p3, err := generateAWS()
assert.FatalError(t, err)
p3.config = p1.config
t1, err := p1.GetIdentityToken("https://ca.smallstep.com")
assert.FatalError(t, err)
t2, err := p2.GetIdentityToken("https://ca.smallstep.com")
assert.FatalError(t, err)
t3, err := p3.GetIdentityToken("https://ca.smallstep.com")
assert.FatalError(t, err)
block, _ := pem.Decode([]byte(awsTestKey))
if block == nil || block.Type != "RSA PRIVATE KEY" {
t.Fatal("error decoding AWS key")
}
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
assert.FatalError(t, err)
badKey, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err)
t4, err := generateAWSToken(
"instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
failSubject, err := generateAWSToken(
"bad-subject", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
failIssuer, err := generateAWSToken(
"instance-id", "bad-issuer", p1.GetID(), p1.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
failAudience, err := generateAWSToken(
"instance-id", awsIssuer, "bad-audience", p1.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
failAccount, err := generateAWSToken(
"instance-id", awsIssuer, p1.GetID(), "", "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
failInstanceID, err := generateAWSToken(
"instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
failPrivateIP, err := generateAWSToken(
"instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id",
"", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
failRegion, err := generateAWSToken(
"instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id",
"127.0.0.1", "", time.Now(), key)
assert.FatalError(t, err)
failExp, err := generateAWSToken(
"instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now().Add(-360*time.Second), key)
assert.FatalError(t, err)
failNbf, err := generateAWSToken(
"instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now().Add(360*time.Second), key)
assert.FatalError(t, err)
failKey, err := generateAWSToken(
"instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now(), badKey)
assert.FatalError(t, err)
failInstanceAge, err := generateAWSToken(
"instance-id", awsIssuer, p2.GetID(), p2.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now().Add(-1*time.Minute), key)
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
aws *AWS
args args
wantLen int
wantErr bool
}{
{"ok", p1, args{t1}, 4, false},
{"ok", p2, args{t2}, 6, false},
{"ok", p1, args{t4}, 4, false},
{"fail account", p3, args{t3}, 0, true},
{"fail token", p1, args{"token"}, 0, true},
{"fail subject", p1, args{failSubject}, 0, true},
{"fail issuer", p1, args{failIssuer}, 0, true},
{"fail audience", p1, args{failAudience}, 0, true},
{"fail account", p1, args{failAccount}, 0, true},
{"fail instanceID", p1, args{failInstanceID}, 0, true},
{"fail privateIP", p1, args{failPrivateIP}, 0, true},
{"fail region", p1, args{failRegion}, 0, true},
{"fail exp", p1, args{failExp}, 0, true},
{"fail nbf", p1, args{failNbf}, 0, true},
{"fail key", p1, args{failKey}, 0, true},
{"fail instance age", p2, args{failInstanceAge}, 0, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.aws.AuthorizeSign(tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
}
assert.Len(t, tt.wantLen, got)
})
}
}
func TestAWS_AuthorizeRenewal(t *testing.T) {
p1, err := generateAWS()
assert.FatalError(t, err)
p2, err := generateAWS()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
aws *AWS
args args
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.aws.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestAWS_AuthorizeRevoke(t *testing.T) {
p1, srv, err := generateAWSWithServer()
assert.FatalError(t, err)
defer srv.Close()
t1, err := p1.GetIdentityToken("https://ca.smallstep.com")
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
aws *AWS
args args
wantErr bool
}{
{"ok", p1, args{t1}, true}, // revoke is disabled
{"fail", p1, args{"token"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.aws.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

@ -0,0 +1,303 @@
package provisioner
import (
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"encoding/json"
"io/ioutil"
"net/http"
"regexp"
"strings"
"time"
"github.com/pkg/errors"
"github.com/smallstep/cli/jose"
)
// azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens.
const azureOIDCBaseURL = "https://login.microsoftonline.com"
// azureIdentityTokenURL is the URL to get the identity token for an instance.
const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F"
// azureDefaultAudience is the default audience used.
const azureDefaultAudience = "https://management.azure.com/"
// azureXMSMirIDRegExp is the regular expression used to parse the xms_mirid claim.
// Using case insensitive as resourceGroups appears as resourcegroups.
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.Compute/virtualMachines/([^/]+)$`)
type azureConfig struct {
oidcDiscoveryURL string
identityTokenURL string
}
func newAzureConfig(tenantID string) *azureConfig {
return &azureConfig{
oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration",
identityTokenURL: azureIdentityTokenURL,
}
}
type azureIdentityToken struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
ExpiresIn int64 `json:"expires_in,string"`
ExpiresOn int64 `json:"expires_on,string"`
ExtExpiresIn int64 `json:"ext_expires_in,string"`
NotBefore int64 `json:"not_before,string"`
Resource string `json:"resource"`
TokenType string `json:"token_type"`
}
type azurePayload struct {
jose.Claims
AppID string `json:"appid"`
AppIDAcr string `json:"appidacr"`
IdentityProvider string `json:"idp"`
ObjectID string `json:"oid"`
TenantID string `json:"tid"`
Version string `json:"ver"`
XMSMirID string `json:"xms_mirid"`
}
// Azure is the provisioner that supports identity tokens created from the
// Microsoft Azure Instance Metadata service.
//
// The default audience is "https://management.azure.com/".
//
// If DisableCustomSANs is true, only the internal DNS and IP will be added as a
// SAN. By default it will accept any SAN in the CSR.
//
// If DisableTrustOnFirstUse is true, multiple sign request for this provisioner
// with the same instance will be accepted. By default only the first request
// will be accepted.
//
// Microsoft Azure identity docs are available at
// https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token
// and https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service
type Azure struct {
Type string `json:"type"`
Name string `json:"name"`
TenantID string `json:"tenantId"`
ResourceGroups []string `json:"resourceGroups"`
Audience string `json:"audience,omitempty"`
DisableCustomSANs bool `json:"disableCustomSANs"`
DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"`
Claims *Claims `json:"claims,omitempty"`
claimer *Claimer
config *azureConfig
oidcConfig openIDConfiguration
keyStore *keyStore
}
// GetID returns the provisioner unique identifier.
func (p *Azure) GetID() string {
return p.TenantID
}
// GetTokenID returns the identifier of the token. The default value for Azure
// the SHA256 of "xms_mirid", but if DisableTrustOnFirstUse is set to true, then
// it will be the token kid.
func (p *Azure) GetTokenID(token string) (string, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return "", errors.Wrap(err, "error parsing token")
}
// Get claims w/out verification. We need to look up the provisioner
// key in order to verify the claims and we need the issuer from the claims
// before we can look up the provisioner.
var claims azurePayload
if err = jwt.UnsafeClaimsWithoutVerification(&claims); err != nil {
return "", errors.Wrap(err, "error verifying claims")
}
// If TOFU is disabled create return the token kid
if p.DisableTrustOnFirstUse {
return claims.ID, nil
}
sum := sha256.Sum256([]byte(claims.XMSMirID))
return strings.ToLower(hex.EncodeToString(sum[:])), nil
}
// GetName returns the name of the provisioner.
func (p *Azure) GetName() string {
return p.Name
}
// GetType returns the type of provisioner.
func (p *Azure) GetType() Type {
return TypeAzure
}
// GetEncryptedKey is not available in an Azure provisioner.
func (p *Azure) GetEncryptedKey() (kid string, key string, ok bool) {
return "", "", false
}
// GetIdentityToken retrieves from the metadata service the identity token and
// returns it.
func (p *Azure) GetIdentityToken() (string, error) {
// Initialize the config if this method is used from the cli.
p.assertConfig()
req, err := http.NewRequest("GET", p.config.identityTokenURL, http.NoBody)
if err != nil {
return "", errors.Wrap(err, "error creating request")
}
req.Header.Set("Metadata", "true")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", errors.Wrap(err, "error getting identity token, are you in a Azure VM?")
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", errors.Wrap(err, "error reading identity token response")
}
if resp.StatusCode >= 400 {
return "", errors.Errorf("error getting identity token: status=%d, response=%s", resp.StatusCode, b)
}
var identityToken azureIdentityToken
if err := json.Unmarshal(b, &identityToken); err != nil {
return "", errors.Wrap(err, "error unmarshaling identity token response")
}
return identityToken.AccessToken, nil
}
// Init validates and initializes the Azure provisioner.
func (p *Azure) Init(config Config) (err error) {
switch {
case p.Type == "":
return errors.New("provisioner type cannot be empty")
case p.Name == "":
return errors.New("provisioner name cannot be empty")
case p.TenantID == "":
return errors.New("provisioner tenantId cannot be empty")
case p.Audience == "": // use default audience
p.Audience = azureDefaultAudience
}
// Initialize config
p.assertConfig()
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Decode and validate openid-configuration endpoint
if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
return err
}
if err := p.oidcConfig.Validate(); err != nil {
return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL)
}
// Get JWK key set
if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil {
return err
}
return nil
}
// AuthorizeSign validates the given token and returns the sign options that
// will be used on certificate creation.
func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrapf(err, "error parsing token")
}
if len(jwt.Headers) == 0 {
return nil, errors.New("error parsing token: header is missing")
}
var found bool
var claims azurePayload
keys := p.keyStore.Get(jwt.Headers[0].KeyID)
for _, key := range keys {
if err := jwt.Claims(key.Public(), &claims); err == nil {
found = true
break
}
}
if !found {
return nil, errors.New("cannot validate token")
}
if err := claims.ValidateWithLeeway(jose.Expected{
Audience: []string{p.Audience},
Issuer: p.oidcConfig.Issuer,
Time: time.Now(),
}, 1*time.Minute); err != nil {
return nil, errors.Wrap(err, "failed to validate payload")
}
// Validate TenantID
if claims.TenantID != p.TenantID {
return nil, errors.New("validation failed: invalid tenant id claim (tid)")
}
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
if len(re) != 4 {
return nil, errors.Errorf("error parsing xms_mirid claim: %s", claims.XMSMirID)
}
group, name := re[2], re[3]
// Filter by resource group
if len(p.ResourceGroups) > 0 {
var found bool
for _, g := range p.ResourceGroups {
if g == group {
found = true
break
}
}
if !found {
return nil, errors.New("validation failed: invalid resource group")
}
}
// Enforce default DNS if configured.
// By default we'll accept the SANs in the CSR.
// There's no way to trust them other than TOFU.
var so []SignOption
if p.DisableCustomSANs {
// name will work only inside the virtual network
so = append(so, dnsNamesValidator([]string{name}))
}
return append(so,
commonNameValidator(name),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID),
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
), nil
}
// AuthorizeRenewal returns an error if the renewal is disabled.
func (p *Azure) AuthorizeRenewal(cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
}
return nil
}
// AuthorizeRevoke returns an error because revoke is not supported on Azure
// provisioners.
func (p *Azure) AuthorizeRevoke(token string) error {
return errors.New("revoke is not supported on a Azure provisioner")
}
// assertConfig initializes the config if it has not been initialized
func (p *Azure) assertConfig() {
if p.config == nil {
p.config = newAzureConfig(p.TenantID)
}
}

@ -0,0 +1,384 @@
package provisioner
import (
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/smallstep/assert"
)
func TestAzure_Getters(t *testing.T) {
p, err := generateAzure()
assert.FatalError(t, err)
if got := p.GetID(); got != p.TenantID {
t.Errorf("Azure.GetID() = %v, want %v", got, p.TenantID)
}
if got := p.GetName(); got != p.Name {
t.Errorf("Azure.GetName() = %v, want %v", got, p.Name)
}
if got := p.GetType(); got != TypeAzure {
t.Errorf("Azure.GetType() = %v, want %v", got, TypeAzure)
}
kid, key, ok := p.GetEncryptedKey()
if kid != "" || key != "" || ok == true {
t.Errorf("Azure.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
kid, key, ok, "", "", false)
}
}
func TestAzure_GetTokenID(t *testing.T) {
p1, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
p2, err := generateAzure()
assert.FatalError(t, err)
p2.TenantID = p1.TenantID
p2.config = p1.config
p2.oidcConfig = p1.oidcConfig
p2.keyStore = p1.keyStore
p2.DisableTrustOnFirstUse = true
t1, err := p1.GetIdentityToken()
assert.FatalError(t, err)
t2, err := p2.GetIdentityToken()
assert.FatalError(t, err)
sum := sha256.Sum256([]byte("/subscriptions/subscriptionID/resourceGroups/resourceGroup/providers/Microsoft.Compute/virtualMachines/virtualMachine"))
w1 := strings.ToLower(hex.EncodeToString(sum[:]))
type args struct {
token string
}
tests := []struct {
name string
azure *Azure
args args
want string
wantErr bool
}{
{"ok", p1, args{t1}, w1, false},
{"ok no TOFU", p2, args{t2}, "the-jti", false},
{"fail token", p1, args{"bad-token"}, "", true},
{"fail claims", p1, args{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.ey.fooo"}, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.azure.GetTokenID(tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("Azure.GetTokenID() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Azure.GetTokenID() = %v, want %v", got, tt.want)
}
})
}
}
func TestAzure_GetIdentityToken(t *testing.T) {
p1, err := generateAzure()
assert.FatalError(t, err)
t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/bad-request":
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
case "/bad-json":
w.Write([]byte(t1))
default:
w.Header().Add("Content-Type", "application/json")
w.Write([]byte(fmt.Sprintf(`{"access_token":"%s"}`, t1)))
}
}))
defer srv.Close()
tests := []struct {
name string
azure *Azure
identityTokenURL string
want string
wantErr bool
}{
{"ok", p1, srv.URL, t1, false},
{"fail request", p1, srv.URL + "/bad-request", "", true},
{"fail unmarshal", p1, srv.URL + "/bad-json", "", true},
{"fail url", p1, "://ca.smallstep.com", "", true},
{"fail connect", p1, "foobarzar", "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.azure.config.identityTokenURL = tt.identityTokenURL
got, err := tt.azure.GetIdentityToken()
if (err != nil) != tt.wantErr {
t.Errorf("Azure.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Azure.GetIdentityToken() = %v, want %v", got, tt.want)
}
})
}
}
func TestAzure_Init(t *testing.T) {
p1, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
config := Config{
Claims: globalProvisionerClaims,
}
badClaims := &Claims{
DefaultTLSDur: &Duration{0},
}
badDiscoveryURL := &azureConfig{
oidcDiscoveryURL: srv.URL + "/error",
identityTokenURL: p1.config.identityTokenURL,
}
badJWKURL := &azureConfig{
oidcDiscoveryURL: srv.URL + "/openid-configuration-fail-jwk",
identityTokenURL: p1.config.identityTokenURL,
}
badAzureConfig := &azureConfig{
oidcDiscoveryURL: srv.URL + "/openid-configuration-no-issuer",
identityTokenURL: p1.config.identityTokenURL,
}
type fields struct {
Type string
Name string
TenantID string
Claims *Claims
config *azureConfig
}
type args struct {
config Config
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{p1.Type, p1.Name, p1.TenantID, nil, p1.config}, args{config}, false},
{"ok with config", fields{p1.Type, p1.Name, p1.TenantID, nil, p1.config}, args{config}, false},
{"fail type", fields{"", p1.Name, p1.TenantID, nil, p1.config}, args{config}, true},
{"fail name", fields{p1.Type, "", p1.TenantID, nil, p1.config}, args{config}, true},
{"fail tenant id", fields{p1.Type, p1.Name, "", nil, p1.config}, args{config}, true},
{"fail claims", fields{p1.Type, p1.Name, p1.TenantID, badClaims, p1.config}, args{config}, true},
{"fail discovery URL", fields{p1.Type, p1.Name, p1.TenantID, nil, badDiscoveryURL}, args{config}, true},
{"fail JWK URL", fields{p1.Type, p1.Name, p1.TenantID, nil, badJWKURL}, args{config}, true},
{"fail config Validate", fields{p1.Type, p1.Name, p1.TenantID, nil, badAzureConfig}, args{config}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Azure{
Type: tt.fields.Type,
Name: tt.fields.Name,
TenantID: tt.fields.TenantID,
Claims: tt.fields.Claims,
config: tt.fields.config,
}
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
t.Errorf("Azure.Init() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestAzure_AuthorizeSign(t *testing.T) {
p1, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
p2, err := generateAzure()
assert.FatalError(t, err)
p2.TenantID = p1.TenantID
p2.ResourceGroups = []string{"resourceGroup"}
p2.config = p1.config
p2.oidcConfig = p1.oidcConfig
p2.keyStore = p1.keyStore
p2.DisableCustomSANs = true
p3, err := generateAzure()
assert.FatalError(t, err)
p3.config = p1.config
p3.oidcConfig = p1.oidcConfig
p3.keyStore = p1.keyStore
p4, err := generateAzure()
assert.FatalError(t, err)
p4.TenantID = p1.TenantID
p4.ResourceGroups = []string{"foobarzar"}
p4.config = p1.config
p4.oidcConfig = p1.oidcConfig
p4.keyStore = p1.keyStore
badKey, err := generateJSONWebKey()
assert.FatalError(t, err)
t1, err := p1.GetIdentityToken()
assert.FatalError(t, err)
t2, err := p2.GetIdentityToken()
assert.FatalError(t, err)
t3, err := p3.GetIdentityToken()
assert.FatalError(t, err)
t4, err := p4.GetIdentityToken()
assert.FatalError(t, err)
t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience",
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
time.Now(), badKey)
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
azure *Azure
args args
wantLen int
wantErr bool
}{
{"ok", p1, args{t1}, 4, false},
{"ok", p2, args{t2}, 5, false},
{"ok", p1, args{t11}, 4, false},
{"fail tenant", p3, args{t3}, 0, true},
{"fail resource group", p4, args{t4}, 0, true},
{"fail token", p1, args{"token"}, 0, true},
{"fail issuer", p1, args{failIssuer}, 0, true},
{"fail audience", p1, args{failAudience}, 0, true},
{"fail exp", p1, args{failExp}, 0, true},
{"fail nbf", p1, args{failNbf}, 0, true},
{"fail key", p1, args{failKey}, 0, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.azure.AuthorizeSign(tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
}
assert.Len(t, tt.wantLen, got)
})
}
}
func TestAzure_AuthorizeRenewal(t *testing.T) {
p1, err := generateAzure()
assert.FatalError(t, err)
p2, err := generateAzure()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
azure *Azure
args args
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.azure.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestAzure_AuthorizeRevoke(t *testing.T) {
az, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
token, err := az.GetIdentityToken()
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
azure *Azure
args args
wantErr bool
}{
{"ok token", az, args{token}, true}, // revoke is disabled
{"bad token", az, args{"bad token"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.azure.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestAzure_assertConfig(t *testing.T) {
p1, err := generateAzure()
assert.FatalError(t, err)
p2, err := generateAzure()
assert.FatalError(t, err)
p2.config = nil
tests := []struct {
name string
azure *Azure
}{
{"ok with config", p1},
{"ok no config", p2},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.azure.assertConfig()
})
}
}

@ -33,6 +33,14 @@ func (p provisionerSlice) Len() int { return len(p) }
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// loadByTokenPayload is a payload used to extract the id used to load the
// provisioner.
type loadByTokenPayload struct {
jose.Claims
AuthorizedParty string `json:"azp"` // OIDC client id
TenantID string `json:"tid"` // Microsoft Azure tenant id
}
// Collection is a memory map of provisioners.
type Collection struct {
byID *sync.Map
@ -58,25 +66,48 @@ func (c *Collection) Load(id string) (Interface, bool) {
// LoadByToken parses the token claims and loads the provisioner associated.
func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) {
var audiences []string
// Get all audiences with the given fragment
fragment := extractFragment(claims.Audience)
if fragment == "" {
audiences = c.audiences.All()
} else {
audiences = c.audiences.WithFragment(fragment).All()
}
// match with server audiences
if matchesAudience(claims.Audience, c.audiences.All()) {
if matchesAudience(claims.Audience, audiences) {
// Use fragment to get provisioner name (GCP, AWS)
if fragment != "" {
return c.Load(fragment)
}
// If matches with stored audiences it will be a JWT token (default), and
// the id would be <issuer>:<kid>.
return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID)
}
// The ID will be just the clientID stored in azp or aud.
var payload openIDPayload
// The ID will be just the clientID stored in azp, aud or tid.
var payload loadByTokenPayload
if err := token.UnsafeClaimsWithoutVerification(&payload); err != nil {
return nil, false
}
// audience is required
// Audience is required
if len(payload.Audience) == 0 {
return nil, false
}
// Try with azp (OIDC)
if len(payload.AuthorizedParty) > 0 {
return c.Load(payload.AuthorizedParty)
if p, ok := c.Load(payload.AuthorizedParty); ok {
return p, ok
}
}
// Try with tid (Azure)
if payload.TenantID != "" {
if p, ok := c.Load(payload.TenantID); ok {
return p, ok
}
}
// Fallback to aud
return c.Load(payload.Audience[0])
}
@ -89,10 +120,16 @@ func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool)
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
return nil, false
}
if provisioner.Type == int(TypeJWK) {
switch Type(provisioner.Type) {
case TypeJWK:
return c.Load(string(provisioner.Name) + ":" + string(provisioner.CredentialID))
case TypeAWS:
return c.Load("aws/" + string(provisioner.Name))
case TypeGCP:
return c.Load("gcp/" + string(provisioner.Name))
default:
return c.Load(string(provisioner.CredentialID))
}
return c.Load(string(provisioner.CredentialID))
}
}
@ -210,3 +247,13 @@ func stripPort(rawurl string) string {
u.Host = u.Hostname()
return u.String()
}
// extractFragment extracts the first fragment of an audience url.
func extractFragment(audience []string) string {
for _, s := range audience {
if u, err := url.Parse(s); err == nil && u.Fragment != "" {
return u.Fragment
}
}
return ""
}

@ -12,6 +12,16 @@ type Duration struct {
time.Duration
}
// NewDuration parses a duration string and returns a Duration type or an error
// if the given string is not a duration.
func NewDuration(s string) (*Duration, error) {
d, err := time.ParseDuration(s)
if err != nil {
return nil, errors.Wrapf(err, "error parsing %s as duration", s)
}
return &Duration{Duration: d}, nil
}
// MarshalJSON parses a duration string and sets it to the duration.
//
// A duration string is a possibly signed sequence of decimal numbers, each with
@ -29,7 +39,7 @@ func (d *Duration) MarshalJSON() ([]byte, error) {
func (d *Duration) UnmarshalJSON(data []byte) (err error) {
var (
s string
_d time.Duration
dd time.Duration
)
if d == nil {
return errors.New("duration cannot be nil")
@ -37,9 +47,17 @@ func (d *Duration) UnmarshalJSON(data []byte) (err error) {
if err = json.Unmarshal(data, &s); err != nil {
return errors.Wrapf(err, "error unmarshaling %s", data)
}
if _d, err = time.ParseDuration(s); err != nil {
if dd, err = time.ParseDuration(s); err != nil {
return errors.Wrapf(err, "error parsing %s as duration", s)
}
d.Duration = _d
d.Duration = dd
return
}
// Value returns 0 if the duration is null, the inner duration otherwise.
func (d *Duration) Value() time.Duration {
if d == nil {
return 0
}
return d.Duration
}

@ -6,6 +6,35 @@ import (
"time"
)
func TestNewDuration(t *testing.T) {
type args struct {
s string
}
tests := []struct {
name string
args args
want *Duration
wantErr bool
}{
{"ok", args{"1h2m3s"}, &Duration{Duration: 3723 * time.Second}, false},
{"fail empty", args{""}, nil, true},
{"fail number", args{"123"}, nil, true},
{"fail string", args{"1hour"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewDuration(tt.args.s)
if (err != nil) != tt.wantErr {
t.Errorf("NewDuration() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewDuration() = %v, want %v", got, tt.want)
}
})
}
}
func TestDuration_UnmarshalJSON(t *testing.T) {
type args struct {
data []byte
@ -59,3 +88,24 @@ func TestDuration_MarshalJSON(t *testing.T) {
})
}
}
func TestDuration_Value(t *testing.T) {
var dur *Duration
tests := []struct {
name string
duration *Duration
want time.Duration
}{
{"ok", &Duration{Duration: 1 * time.Minute}, 1 * time.Minute},
{"ok new", new(Duration), 0},
{"ok nil", nil, 0},
{"ok nil var", dur, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.duration.Value(); got != tt.want {
t.Errorf("Duration.Value() = %v, want %v", got, tt.want)
}
})
}
}

@ -0,0 +1,343 @@
package provisioner
import (
"bytes"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
"github.com/pkg/errors"
"github.com/smallstep/cli/jose"
)
// gcpCertsURL is the url that serves Google OAuth2 public keys.
const gcpCertsURL = "https://www.googleapis.com/oauth2/v3/certs"
// gcpIdentityURL is the base url for the identity document in GCP.
const gcpIdentityURL = "http://metadata/computeMetadata/v1/instance/service-accounts/default/identity"
// gcpPayload extends jwt.Claims with custom GCP attributes.
type gcpPayload struct {
jose.Claims
AuthorizedParty string `json:"azp"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Google gcpGooglePayload `json:"google"`
}
type gcpGooglePayload struct {
ComputeEngine gcpComputeEnginePayload `json:"compute_engine"`
}
type gcpComputeEnginePayload struct {
InstanceID string `json:"instance_id"`
InstanceName string `json:"instance_name"`
InstanceCreationTimestamp *jose.NumericDate `json:"instance_creation_timestamp"`
ProjectID string `json:"project_id"`
ProjectNumber int64 `json:"project_number"`
Zone string `json:"zone"`
LicenseID []string `json:"license_id"`
}
type gcpConfig struct {
CertsURL string
IdentityURL string
}
func newGCPConfig() *gcpConfig {
return &gcpConfig{
CertsURL: gcpCertsURL,
IdentityURL: gcpIdentityURL,
}
}
// GCP is the provisioner that supports identity tokens created by the Google
// Cloud Platform metadata API.
//
// If DisableCustomSANs is true, only the internal DNS and IP will be added as a
// SAN. By default it will accept any SAN in the CSR.
//
// If DisableTrustOnFirstUse is true, multiple sign request for this provisioner
// with the same instance will be accepted. By default only the first request
// will be accepted.
//
// If InstanceAge is set, only the instances with an instance_creation_timestamp
// within the given period will be accepted.
//
// Google Identity docs are available at
// https://cloud.google.com/compute/docs/instances/verifying-instance-identity
type GCP struct {
Type string `json:"type"`
Name string `json:"name"`
ServiceAccounts []string `json:"serviceAccounts"`
ProjectIDs []string `json:"projectIDs"`
DisableCustomSANs bool `json:"disableCustomSANs"`
DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"`
InstanceAge Duration `json:"instanceAge,omitempty"`
Claims *Claims `json:"claims,omitempty"`
claimer *Claimer
config *gcpConfig
keyStore *keyStore
audiences Audiences
}
// GetID returns the provisioner unique identifier. The name should uniquely
// identify any GCP provisioner.
func (p *GCP) GetID() string {
return "gcp/" + p.Name
}
// GetTokenID returns the identifier of the token. The default value for GCP the
// SHA256 of "provisioner_id.instance_id", but if DisableTrustOnFirstUse is set
// to true, then it will be the SHA256 of the token.
func (p *GCP) GetTokenID(token string) (string, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return "", errors.Wrap(err, "error parsing token")
}
// If TOFU is disabled create an ID for the token, so it cannot be reused.
if p.DisableTrustOnFirstUse {
sum := sha256.Sum256([]byte(token))
return strings.ToLower(hex.EncodeToString(sum[:])), nil
}
// Get claims w/out verification.
var claims gcpPayload
if err = jwt.UnsafeClaimsWithoutVerification(&claims); err != nil {
return "", errors.Wrap(err, "error verifying claims")
}
// Create unique ID for Trust On First Use (TOFU). Only the first instance
// per provisioner is allowed as we don't have a way to trust the given
// sans.
unique := fmt.Sprintf("%s.%s", p.GetID(), claims.Google.ComputeEngine.InstanceID)
sum := sha256.Sum256([]byte(unique))
return strings.ToLower(hex.EncodeToString(sum[:])), nil
}
// GetName returns the name of the provisioner.
func (p *GCP) GetName() string {
return p.Name
}
// GetType returns the type of provisioner.
func (p *GCP) GetType() Type {
return TypeGCP
}
// GetEncryptedKey is not available in a GCP provisioner.
func (p *GCP) GetEncryptedKey() (kid string, key string, ok bool) {
return "", "", false
}
// GetIdentityURL returns the url that generates the GCP token.
func (p *GCP) GetIdentityURL(audience string) string {
// Initialize config if required
p.assertConfig()
q := url.Values{}
q.Add("audience", audience)
q.Add("format", "full")
q.Add("licenses", "FALSE")
return fmt.Sprintf("%s?%s", p.config.IdentityURL, q.Encode())
}
// GetIdentityToken does an HTTP request to the identity url.
func (p *GCP) GetIdentityToken(caURL string) (string, error) {
audience, err := generateSignAudience(caURL, p.GetID())
if err != nil {
return "", err
}
req, err := http.NewRequest("GET", p.GetIdentityURL(audience), http.NoBody)
if err != nil {
return "", errors.Wrap(err, "error creating identity request")
}
req.Header.Set("Metadata-Flavor", "Google")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", errors.Wrap(err, "error doing identity request, are you in a GCP VM?")
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", errors.Wrap(err, "error on identity request")
}
if resp.StatusCode >= 400 {
return "", errors.Errorf("error on identity request: status=%d, response=%s", resp.StatusCode, b)
}
return string(bytes.TrimSpace(b)), nil
}
// Init validates and initializes the GCP provisioner.
func (p *GCP) Init(config Config) error {
var err error
switch {
case p.Type == "":
return errors.New("provisioner type cannot be empty")
case p.Name == "":
return errors.New("provisioner name cannot be empty")
case p.InstanceAge.Value() < 0:
return errors.New("provisioner instanceAge cannot be negative")
}
// Initialize config
p.assertConfig()
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Initialize key store
p.keyStore, err = newKeyStore(p.config.CertsURL)
if err != nil {
return err
}
p.audiences = config.Audiences.WithFragment(p.GetID())
return nil
}
// AuthorizeSign validates the given token and returns the sign options that
// will be used on certificate creation.
func (p *GCP) AuthorizeSign(token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token)
if err != nil {
return nil, err
}
ce := claims.Google.ComputeEngine
// Enforce default DNS if configured.
// By default we we'll accept the SANs in the CSR.
// There's no way to trust them other than TOFU.
var so []SignOption
if p.DisableCustomSANs {
so = append(so, dnsNamesValidator([]string{
fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID),
fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID),
}))
}
return append(so,
commonNameValidator(ce.InstanceName),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject),
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
), nil
}
// AuthorizeRenewal returns an error if the renewal is disabled.
func (p *GCP) AuthorizeRenewal(cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
}
return nil
}
// AuthorizeRevoke returns an error because revoke is not supported on GCP
// provisioners.
func (p *GCP) AuthorizeRevoke(token string) error {
return errors.New("revoke is not supported on a GCP provisioner")
}
// assertConfig initializes the config if it has not been initialized.
func (p *GCP) assertConfig() {
if p.config == nil {
p.config = newGCPConfig()
}
}
// authorizeToken performs common jwt authorization actions and returns the
// claims for case specific downstream parsing.
// e.g. a Sign request will auth/validate different fields than a Revoke request.
func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrapf(err, "error parsing token")
}
if len(jwt.Headers) == 0 {
return nil, errors.New("error parsing token: header is missing")
}
var found bool
var claims gcpPayload
kid := jwt.Headers[0].KeyID
keys := p.keyStore.Get(kid)
for _, key := range keys {
if err := jwt.Claims(key.Public(), &claims); err == nil {
found = true
break
}
}
if !found {
return nil, errors.Errorf("failed to validate payload: cannot find key for kid %s", kid)
}
// According to "rfc7519 JSON Web Token" acceptable skew should be no
// more than a few minutes.
now := time.Now().UTC()
if err = claims.ValidateWithLeeway(jose.Expected{
Issuer: "https://accounts.google.com",
Time: now,
}, time.Minute); err != nil {
return nil, errors.Wrapf(err, "invalid token")
}
// validate audiences with the defaults
if !matchesAudience(claims.Audience, p.audiences.Sign) {
return nil, errors.New("invalid token: invalid audience claim (aud)")
}
// validate subject (service account)
if len(p.ServiceAccounts) > 0 {
var found bool
for _, sa := range p.ServiceAccounts {
if sa == claims.Subject || sa == claims.Email {
found = true
break
}
}
if !found {
return nil, errors.New("invalid token: invalid subject claim")
}
}
// validate projects
if len(p.ProjectIDs) > 0 {
var found bool
for _, pi := range p.ProjectIDs {
if pi == claims.Google.ComputeEngine.ProjectID {
found = true
break
}
}
if !found {
return nil, errors.New("invalid token: invalid project id")
}
}
// validate instance age
if d := p.InstanceAge.Value(); d > 0 {
if now.Sub(claims.Google.ComputeEngine.InstanceCreationTimestamp.Time()) > d {
return nil, errors.New("token google.compute_engine.instance_creation_timestamp is too old")
}
}
switch {
case claims.Google.ComputeEngine.InstanceID == "":
return nil, errors.New("token google.compute_engine.instance_id cannot be empty")
case claims.Google.ComputeEngine.InstanceName == "":
return nil, errors.New("token google.compute_engine.instance_name cannot be empty")
case claims.Google.ComputeEngine.ProjectID == "":
return nil, errors.New("token google.compute_engine.project_id cannot be empty")
case claims.Google.ComputeEngine.Zone == "":
return nil, errors.New("token google.compute_engine.zone cannot be empty")
}
return &claims, nil
}

@ -0,0 +1,404 @@
package provisioner
import (
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/smallstep/assert"
)
func TestGCP_Getters(t *testing.T) {
p, err := generateGCP()
assert.FatalError(t, err)
id := "gcp/" + p.Name
if got := p.GetID(); got != id {
t.Errorf("GCP.GetID() = %v, want %v", got, id)
}
if got := p.GetName(); got != p.Name {
t.Errorf("GCP.GetName() = %v, want %v", got, p.Name)
}
if got := p.GetType(); got != TypeGCP {
t.Errorf("GCP.GetType() = %v, want %v", got, TypeGCP)
}
kid, key, ok := p.GetEncryptedKey()
if kid != "" || key != "" || ok == true {
t.Errorf("GCP.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
kid, key, ok, "", "", false)
}
aud := "https://ca.smallstep.com/1.0/sign#" + url.QueryEscape(id)
expected := fmt.Sprintf("http://metadata/computeMetadata/v1/instance/service-accounts/default/identity?audience=%s&format=full&licenses=FALSE", url.QueryEscape(aud))
if got := p.GetIdentityURL(aud); got != expected {
t.Errorf("GCP.GetIdentityURL() = %v, want %v", got, expected)
}
}
func TestGCP_GetTokenID(t *testing.T) {
p1, err := generateGCP()
assert.FatalError(t, err)
p1.Name = "name"
p2, err := generateGCP()
assert.FatalError(t, err)
p2.DisableTrustOnFirstUse = true
now := time.Now()
t1, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", "gcp/name",
"instance-id", "instance-name", "project-id", "zone",
now, &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
t2, err := generateGCPToken(p2.ServiceAccounts[0],
"https://accounts.google.com", p2.GetID(),
"instance-id", "instance-name", "project-id", "zone",
now, &p2.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
sum := sha256.Sum256([]byte("gcp/name.instance-id"))
want1 := strings.ToLower(hex.EncodeToString(sum[:]))
sum = sha256.Sum256([]byte(t2))
want2 := strings.ToLower(hex.EncodeToString(sum[:]))
type args struct {
token string
}
tests := []struct {
name string
gcp *GCP
args args
want string
wantErr bool
}{
{"ok", p1, args{t1}, want1, false},
{"ok", p2, args{t2}, want2, false},
{"fail token", p1, args{"token"}, "", true},
{"fail claims", p1, args{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.ey.fooo"}, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.gcp.GetTokenID(tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("GCP.GetTokenID() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("GCP.GetTokenID() = %v, want %v", got, tt.want)
}
})
}
}
func TestGCP_GetIdentityToken(t *testing.T) {
p1, err := generateGCP()
assert.FatalError(t, err)
t1, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", p1.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/bad-request":
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
default:
w.Write([]byte(t1))
}
}))
defer srv.Close()
type args struct {
caURL string
}
tests := []struct {
name string
gcp *GCP
args args
identityURL string
want string
wantErr bool
}{
{"ok", p1, args{"https://ca"}, srv.URL, t1, false},
{"fail ca url", p1, args{"://ca"}, srv.URL, "", true},
{"fail request", p1, args{"https://ca"}, srv.URL + "/bad-request", "", true},
{"fail url", p1, args{"https://ca"}, "://ca.smallstep.com", "", true},
{"fail connect", p1, args{"https://ca"}, "foobarzar", "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.gcp.config.IdentityURL = tt.identityURL
got, err := tt.gcp.GetIdentityToken(tt.args.caURL)
t.Log(err)
if (err != nil) != tt.wantErr {
t.Errorf("GCP.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("GCP.GetIdentityToken() = %v, want %v", got, tt.want)
}
})
}
}
func TestGCP_Init(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
config := Config{
Claims: globalProvisionerClaims,
}
badClaims := &Claims{
DefaultTLSDur: &Duration{0},
}
zero := Duration{Duration: 0}
type fields struct {
Type string
Name string
ServiceAccounts []string
InstanceAge Duration
Claims *Claims
}
type args struct {
config Config
certsURL string
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{"GCP", "name", nil, zero, nil}, args{config, srv.URL}, false},
{"ok", fields{"GCP", "name", []string{"service-account"}, zero, nil}, args{config, srv.URL}, false},
{"ok", fields{"GCP", "name", []string{"service-account"}, Duration{Duration: 1 * time.Minute}, nil}, args{config, srv.URL}, false},
{"bad type", fields{"", "name", nil, zero, nil}, args{config, srv.URL}, true},
{"bad name", fields{"GCP", "", nil, zero, nil}, args{config, srv.URL}, true},
{"bad duration", fields{"GCP", "name", nil, Duration{Duration: -1 * time.Minute}, nil}, args{config, srv.URL}, true},
{"bad claims", fields{"GCP", "name", nil, zero, badClaims}, args{config, srv.URL}, true},
{"bad certs", fields{"GCP", "name", nil, zero, nil}, args{config, srv.URL + "/error"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &GCP{
Type: tt.fields.Type,
Name: tt.fields.Name,
ServiceAccounts: tt.fields.ServiceAccounts,
InstanceAge: tt.fields.InstanceAge,
Claims: tt.fields.Claims,
config: &gcpConfig{
CertsURL: tt.args.certsURL,
IdentityURL: gcpIdentityURL,
},
}
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
t.Errorf("GCP.Init() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestGCP_AuthorizeSign(t *testing.T) {
p1, err := generateGCP()
assert.FatalError(t, err)
p2, err := generateGCP()
assert.FatalError(t, err)
p2.DisableCustomSANs = true
p3, err := generateGCP()
assert.FatalError(t, err)
p3.ProjectIDs = []string{"other-project-id"}
p3.ServiceAccounts = []string{"foo@developer.gserviceaccount.com"}
p3.InstanceAge = Duration{1 * time.Minute}
aKey, err := generateJSONWebKey()
assert.FatalError(t, err)
t1, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", p1.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
t2, err := generateGCPToken(p2.ServiceAccounts[0],
"https://accounts.google.com", p2.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p2.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
t3, err := generateGCPToken(p3.ServiceAccounts[0],
"https://accounts.google.com", p3.GetID(),
"instance-id", "instance-name", "other-project-id", "zone",
time.Now(), &p3.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failKey, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", p1.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), aKey)
assert.FatalError(t, err)
failIss, err := generateGCPToken(p1.ServiceAccounts[0],
"https://foo.bar.zar", p1.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failAud, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", "gcp:foo",
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failExp, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", p1.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failNbf, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", p1.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failServiceAccount, err := generateGCPToken("foo",
"https://accounts.google.com", p1.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failInvalidProjectID, err := generateGCPToken(p3.ServiceAccounts[0],
"https://accounts.google.com", p3.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p3.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failInvalidInstanceAge, err := generateGCPToken(p3.ServiceAccounts[0],
"https://accounts.google.com", p3.GetID(),
"instance-id", "instance-name", "other-project-id", "zone",
time.Now().Add(-1*time.Minute), &p3.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failInstanceID, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", p1.GetID(),
"", "instance-name", "project-id", "zone",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failInstanceName, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", p1.GetID(),
"instance-id", "", "project-id", "zone",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failProjectID, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", p1.GetID(),
"instance-id", "instance-name", "", "zone",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failZone, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", p1.GetID(),
"instance-id", "instance-name", "project-id", "",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
gcp *GCP
args args
wantLen int
wantErr bool
}{
{"ok", p1, args{t1}, 4, false},
{"ok", p2, args{t2}, 5, false},
{"ok", p3, args{t3}, 4, false},
{"fail token", p1, args{"token"}, 0, true},
{"fail key", p1, args{failKey}, 0, true},
{"fail iss", p1, args{failIss}, 0, true},
{"fail aud", p1, args{failAud}, 0, true},
{"fail exp", p1, args{failExp}, 0, true},
{"fail nbf", p1, args{failNbf}, 0, true},
{"fail service account", p1, args{failServiceAccount}, 0, true},
{"fail invalid project id", p3, args{failInvalidProjectID}, 0, true},
{"fail invalid instance age", p3, args{failInvalidInstanceAge}, 0, true},
{"fail instance id", p1, args{failInstanceID}, 0, true},
{"fail instance name", p1, args{failInstanceName}, 0, true},
{"fail project id", p1, args{failProjectID}, 0, true},
{"fail zone", p1, args{failZone}, 0, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.gcp.AuthorizeSign(tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
}
assert.Len(t, tt.wantLen, got)
})
}
}
func TestGCP_AuthorizeRenewal(t *testing.T) {
p1, err := generateGCP()
assert.FatalError(t, err)
p2, err := generateGCP()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
prov *GCP
args args
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestGCP_AuthorizeRevoke(t *testing.T) {
p1, err := generateGCP()
assert.FatalError(t, err)
t1, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", p1.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
gcp *GCP
args args
wantErr bool
}{
{"ok", p1, args{t1}, true}, // revoke is disabled
{"fail", p1, args{"token"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.gcp.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

@ -33,7 +33,6 @@ func (p *JWK) GetID() string {
return p.Name + ":" + p.Key.KeyID
}
//
// GetTokenID returns the identifier of the token.
func (p *JWK) GetTokenID(ott string) (string, error) {
// Validate payload

@ -278,7 +278,6 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.prov.AuthorizeSign(tt.args.token)
if (err != nil) != tt.wantErr {
fmt.Println(tt)
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return
}
@ -386,47 +385,6 @@ func TestOIDC_AuthorizeRenewal(t *testing.T) {
}
}
/*
func TestOIDC_AuthorizeRevoke(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
// Create test provisioners
p1, err := generateOIDC()
assert.FatalError(t, err)
// Update configuration endpoints and initialize
config := Config{Claims: globalProvisionerClaims}
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
assert.FatalError(t, p1.Init(config))
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
prov *OIDC
args args
wantErr bool
}{
{"disabled", p1, args{t1}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
*/
func Test_sanitizeEmail(t *testing.T) {
tests := []struct {
name string

@ -3,6 +3,7 @@ package provisioner
import (
"crypto/x509"
"encoding/json"
"net/url"
"strings"
"github.com/pkg/errors"
@ -28,21 +29,59 @@ type Audiences struct {
}
// All returns all supported audiences across all request types in one list.
func (a *Audiences) All() []string {
func (a Audiences) All() []string {
return append(a.Sign, a.Revoke...)
}
// WithFragment returns a copy of audiences where the url audiences contains the
// given fragment.
func (a Audiences) WithFragment(fragment string) Audiences {
ret := Audiences{
Sign: make([]string, len(a.Sign)),
Revoke: make([]string, len(a.Revoke)),
}
for i, s := range a.Sign {
if u, err := url.Parse(s); err == nil {
ret.Sign[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
} else {
ret.Sign[i] = s
}
}
for i, s := range a.Revoke {
if u, err := url.Parse(s); err == nil {
ret.Revoke[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
} else {
ret.Revoke[i] = s
}
}
return ret
}
// generateSignAudience generates a sign audience with the format
// https://<host>/1.0/sign#provisionerID
func generateSignAudience(caURL string, provisionerID string) (string, error) {
u, err := url.Parse(caURL)
if err != nil {
return "", errors.Wrapf(err, "error parsing %s", caURL)
}
return u.ResolveReference(&url.URL{Path: "/1.0/sign", Fragment: provisionerID}).String(), nil
}
// Type indicates the provisioner Type.
type Type int
const (
noopType Type = 0
// TypeJWK is used to indicate the JWK provisioners.
TypeJWK Type = 1
// TypeOIDC is used to indicate the OIDC provisioners.
TypeOIDC Type = 2
// TypeGCP is used to indicate the GCP provisioners.
TypeGCP Type = 3
// TypeAWS is used to indicate the AWS provisioners.
TypeAWS Type = 4
// TypeAzure is used to indicate the Azure provisioners.
TypeAzure Type = 5
// RevokeAudienceKey is the key for the 'revoke' audiences in the audiences map.
RevokeAudienceKey = "revoke"
@ -50,6 +89,24 @@ const (
SignAudienceKey = "sign"
)
// String returns the string representation of the type.
func (t Type) String() string {
switch t {
case TypeJWK:
return "JWK"
case TypeOIDC:
return "OIDC"
case TypeGCP:
return "GCP"
case TypeAWS:
return "AWS"
case TypeAzure:
return "Azure"
default:
return ""
}
}
// Config defines the default parameters used in the initialization of
// provisioners.
type Config struct {
@ -86,6 +143,12 @@ func (l *List) UnmarshalJSON(data []byte) error {
p = &JWK{}
case "oidc":
p = &OIDC{}
case "gcp":
p = &GCP{}
case "aws":
p = &AWS{}
case "azure":
p = &Azure{}
default:
return errors.Errorf("provisioner type %s not supported", typ.Type)
}

@ -0,0 +1,26 @@
package provisioner
import "testing"
func TestType_String(t *testing.T) {
tests := []struct {
name string
t Type
want string
}{
{"JWK", TypeJWK, "JWK"},
{"OIDC", TypeOIDC, "OIDC"},
{"AWS", TypeAWS, "AWS"},
{"Azure", TypeAzure, "Azure"},
{"GCP", TypeGCP, "GCP"},
{"noop", noopType, ""},
{"notFound", 1000, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.t.String(); got != tt.want {
t.Errorf("Type.String() = %v, want %v", got, tt.want)
}
})
}
}

@ -2,12 +2,19 @@ package provisioner
import (
"crypto"
"crypto/rand"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"net/http"
"net/http/httptest"
"time"
"github.com/pkg/errors"
"github.com/smallstep/cli/crypto/randutil"
"github.com/smallstep/cli/jose"
)
@ -17,6 +24,37 @@ var testAudiences = Audiences{
Revoke: []string{"https://ca.smallstep.com/revoke", "https://ca.smallstep.com/1.0/revoke"},
}
const awsTestCertificate = `-----BEGIN CERTIFICATE-----
MIICFTCCAX6gAwIBAgIRAKmbVVYAl/1XEqRfF3eJ97MwDQYJKoZIhvcNAQELBQAw
GDEWMBQGA1UEAxMNQVdTIFRlc3QgQ2VydDAeFw0xOTA0MjQyMjU3MzlaFw0yOTA0
MjEyMjU3MzlaMBgxFjAUBgNVBAMTDUFXUyBUZXN0IENlcnQwgZ8wDQYJKoZIhvcN
AQEBBQADgY0AMIGJAoGBAOHMmMXwbXN90SoRl/xXAcJs5TacaVYJ5iNAVWM5KYyF
+JwqYuJp/umLztFUi0oX0luu3EzD4KurVeUJSzZjTFTX1d/NX6hA45+bvdSUOcgV
UghO+2uhBZ4SNFxFRZ7SKvoWIN195l5bVX6/60Eo6+kUCKCkyxW4V/ksWzdXjHnf
AgMBAAGjXzBdMA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAGAQH/AgEBMB0G
A1UdDgQWBBRHfLOjEddK/CWCIHNg8Oc/oJa1IzAYBgNVHREEETAPgg1BV1MgVGVz
dCBDZXJ0MA0GCSqGSIb3DQEBCwUAA4GBAKNCiVM9eGb9dW2xNyHaHAmmy7ERB2OJ
7oXHfLjooOavk9lU/Gs2jfX/JSBa84+DzWg9ShmCNLti8CxU/dhzXW7jE/5CcdTa
DCA6B3Yl5TmfG9+D9dtFqRB2CiMgNcsJJE5Dc6pDwBIiSj/MkE0AaGVQmSwn6Cb6
vX1TAxqeWJHq
-----END CERTIFICATE-----`
const awsTestKey = `-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQDhzJjF8G1zfdEqEZf8VwHCbOU2nGlWCeYjQFVjOSmMhficKmLi
af7pi87RVItKF9JbrtxMw+Crq1XlCUs2Y0xU19XfzV+oQOOfm73UlDnIFVIITvtr
oQWeEjRcRUWe0ir6FiDdfeZeW1V+v+tBKOvpFAigpMsVuFf5LFs3V4x53wIDAQAB
AoGADZQFF9oWatyFCHeYYSdGRs/PlNIhD3h262XB/L6CPh4MTi/KVH01RAwROstP
uPvnvXWtb7xTtV8PQj+l0zZzb4W/DLCSBdoRwpuNXyffUCtbI22jPupTsVu+ENWR
3x7HHzoZYjU45ADSTMxEtwD7/zyNgpRKjIA2HYpkt+fI27ECQQD5/AOr9/yQD73x
cquF+FWahWgDL25YeMwdfe1HfpUxUxd9kJJKieB8E2BtBAv9XNguxIBpf7VlAKsF
NFhdfWFHAkEA5zuX8vqDecSzyNNEQd3tugxt1pGOXNesHzuPbdlw3ppN9Rbd93an
uU2TaAvTjr/3EkxulYNRmHs+RSVK54+uqQJAKWurhBQMAibJlzcj2ofiTz8pk9WJ
GBmz4HMcHMuJlumoq8KHqtgbnRNs18Ni5TE8FMu0Z0ak3L52l98rgRokQwJBAJS8
9KTLF79AFBVeME3eH4jJbe3TeyulX4ZHnZ8fe0b1IqhAqU8A+CpuCB+pW9A7Ewam
O4vZCKd4vzljH6eL+OECQHHxhYoTW7lFpKGnUDG9fPZ3eYzWpgka6w1vvBk10BAu
6fbwppM9pQ7DPMg7V6YGEjjT0gX9B9TttfHxGhvtZNQ=
-----END RSA PRIVATE KEY-----`
func must(args ...interface{}) []interface{} {
if l := len(args); l > 0 && args[l-1] != nil {
if err, ok := args[l-1].(error); ok {
@ -163,6 +201,234 @@ func generateOIDC() (*OIDC, error) {
}, nil
}
func generateGCP() (*GCP, error) {
name, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
serviceAccount, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
jwk, err := generateJSONWebKey()
if err != nil {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
return &GCP{
Type: "GCP",
Name: name,
ServiceAccounts: []string{serviceAccount},
Claims: &globalProvisionerClaims,
claimer: claimer,
config: newGCPConfig(),
keyStore: &keyStore{
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour),
},
audiences: testAudiences.WithFragment("gcp/" + name),
}, nil
}
func generateAWS() (*AWS, error) {
name, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
accountID, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
block, _ := pem.Decode([]byte(awsTestCertificate))
if block == nil || block.Type != "CERTIFICATE" {
return nil, errors.New("error decoding AWS certificate")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, errors.Wrap(err, "error parsing AWS certificate")
}
return &AWS{
Type: "AWS",
Name: name,
Accounts: []string{accountID},
Claims: &globalProvisionerClaims,
claimer: claimer,
config: &awsConfig{
identityURL: awsIdentityURL,
signatureURL: awsSignatureURL,
certificate: cert,
signatureAlgorithm: awsSignatureAlgorithm,
},
audiences: testAudiences.WithFragment("aws/" + name),
}, nil
}
func generateAWSWithServer() (*AWS, *httptest.Server, error) {
aws, err := generateAWS()
if err != nil {
return nil, nil, err
}
block, _ := pem.Decode([]byte(awsTestKey))
if block == nil || block.Type != "RSA PRIVATE KEY" {
return nil, nil, errors.New("error decoding AWS key")
}
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, nil, errors.Wrap(err, "error parsing AWS private key")
}
instanceID, err := randutil.Alphanumeric(10)
if err != nil {
return nil, nil, err
}
imageID, err := randutil.Alphanumeric(10)
if err != nil {
return nil, nil, err
}
doc, err := json.MarshalIndent(awsInstanceIdentityDocument{
AccountID: aws.Accounts[0],
Architecture: "x86_64",
AvailabilityZone: "us-west-2b",
ImageID: imageID,
InstanceID: instanceID,
InstanceType: "t2.micro",
PendingTime: time.Now(),
PrivateIP: "127.0.0.1",
Region: "us-west-1",
Version: "2017-09-30",
}, "", " ")
if err != nil {
return nil, nil, err
}
sum := sha256.Sum256(doc)
signature, err := key.Sign(rand.Reader, sum[:], crypto.SHA256)
if err != nil {
return nil, nil, errors.Wrap(err, "error signing document")
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/latest/dynamic/instance-identity/document":
w.Write(doc)
case "/latest/dynamic/instance-identity/signature":
w.Write([]byte(base64.StdEncoding.EncodeToString(signature)))
case "/bad-document":
w.Write([]byte("{}"))
case "/bad-signature":
w.Write([]byte("YmFkLXNpZ25hdHVyZQo="))
default:
http.NotFound(w, r)
}
}))
aws.config.identityURL = srv.URL + "/latest/dynamic/instance-identity/document"
aws.config.signatureURL = srv.URL + "/latest/dynamic/instance-identity/signature"
return aws, srv, nil
}
func generateAzure() (*Azure, error) {
name, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
tenantID, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
jwk, err := generateJSONWebKey()
if err != nil {
return nil, err
}
return &Azure{
Type: "Azure",
Name: name,
TenantID: tenantID,
Audience: azureDefaultAudience,
Claims: &globalProvisionerClaims,
claimer: claimer,
config: newAzureConfig(tenantID),
oidcConfig: openIDConfiguration{
Issuer: "https://sts.windows.net/" + tenantID + "/",
JWKSetURI: "https://login.microsoftonline.com/common/discovery/keys",
},
keyStore: &keyStore{
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour),
},
}, nil
}
func generateAzureWithServer() (*Azure, *httptest.Server, error) {
az, err := generateAzure()
if err != nil {
return nil, nil, err
}
writeJSON := func(w http.ResponseWriter, v interface{}) {
b, err := json.Marshal(v)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(b)
}
getPublic := func(ks jose.JSONWebKeySet) jose.JSONWebKeySet {
var ret jose.JSONWebKeySet
for _, k := range ks.Keys {
ret.Keys = append(ret.Keys, k.Public())
}
return ret
}
issuer := "https://sts.windows.net/" + az.TenantID + "/"
srv := httptest.NewUnstartedServer(nil)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/error":
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
case "/" + az.TenantID + "/.well-known/openid-configuration":
writeJSON(w, openIDConfiguration{Issuer: issuer, JWKSetURI: srv.URL + "/jwks_uri"})
case "/openid-configuration-no-issuer":
writeJSON(w, openIDConfiguration{Issuer: "", JWKSetURI: srv.URL + "/jwks_uri"})
case "/openid-configuration-fail-jwk":
writeJSON(w, openIDConfiguration{Issuer: issuer, JWKSetURI: srv.URL + "/error"})
case "/random":
keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
w.Header().Add("Cache-Control", "max-age=5")
writeJSON(w, getPublic(keySet))
case "/private":
writeJSON(w, az.keyStore.keySet)
case "/jwks_uri":
w.Header().Add("Cache-Control", "max-age=5")
writeJSON(w, getPublic(az.keyStore.keySet))
case "/metadata/identity/oauth2/token":
tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", time.Now(), &az.keyStore.keySet.Keys[0])
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
} else {
writeJSON(w, azureIdentityToken{
AccessToken: tok,
})
}
default:
http.NotFound(w, r)
}
})
srv.Start()
az.config.oidcDiscoveryURL = srv.URL + "/" + az.TenantID + "/.well-known/openid-configuration"
az.config.identityTokenURL = srv.URL + "/metadata/identity/oauth2/token"
return az, srv, nil
}
func generateCollection(nJWK, nOIDC int) (*Collection, error) {
col := NewCollection(testAudiences)
for i := 0; i < nJWK; i++ {
@ -220,6 +486,127 @@ func generateToken(sub, iss, aud string, email string, sans []string, iat time.T
return jose.Signed(sig).Claims(claims).CompactSerialize()
}
func generateGCPToken(sub, iss, aud, instanceID, instanceName, projectID, zone string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
)
if err != nil {
return "", err
}
aud, err = generateSignAudience("https://ca.smallstep.com", aud)
if err != nil {
return "", err
}
claims := gcpPayload{
Claims: jose.Claims{
Subject: sub,
Issuer: iss,
IssuedAt: jose.NewNumericDate(iat),
NotBefore: jose.NewNumericDate(iat),
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
Audience: []string{aud},
},
AuthorizedParty: sub,
Email: "foo@developer.gserviceaccount.com",
EmailVerified: true,
Google: gcpGooglePayload{
ComputeEngine: gcpComputeEnginePayload{
InstanceID: instanceID,
InstanceName: instanceName,
InstanceCreationTimestamp: jose.NewNumericDate(iat),
ProjectID: projectID,
ProjectNumber: 1234567890,
Zone: zone,
},
},
}
return jose.Signed(sig).Claims(claims).CompactSerialize()
}
func generateAWSToken(sub, iss, aud, accountID, instanceID, privateIP, region string, iat time.Time, key crypto.Signer) (string, error) {
doc, err := json.MarshalIndent(awsInstanceIdentityDocument{
AccountID: accountID,
Architecture: "x86_64",
AvailabilityZone: "us-west-2b",
ImageID: "ami-123123",
InstanceID: instanceID,
InstanceType: "t2.micro",
PendingTime: iat,
PrivateIP: privateIP,
Region: region,
Version: "2017-09-30",
}, "", " ")
if err != nil {
return "", err
}
sum := sha256.Sum256(doc)
signature, err := key.Sign(rand.Reader, sum[:], crypto.SHA256)
if err != nil {
return "", errors.Wrap(err, "error signing document")
}
sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.HS256, Key: signature},
new(jose.SignerOptions).WithType("JWT"),
)
if err != nil {
return "", err
}
aud, err = generateSignAudience("https://ca.smallstep.com", aud)
if err != nil {
return "", err
}
claims := awsPayload{
Claims: jose.Claims{
Subject: sub,
Issuer: iss,
IssuedAt: jose.NewNumericDate(iat),
NotBefore: jose.NewNumericDate(iat),
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
Audience: []string{aud},
},
Amazon: awsAmazonPayload{
Document: doc,
Signature: signature,
},
}
return jose.Signed(sig).Claims(claims).CompactSerialize()
}
func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, virtualMachine string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
)
if err != nil {
return "", err
}
claims := azurePayload{
Claims: jose.Claims{
Subject: sub,
Issuer: iss,
IssuedAt: jose.NewNumericDate(iat),
NotBefore: jose.NewNumericDate(iat),
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
Audience: []string{aud},
ID: "the-jti",
},
AppID: "the-appid",
AppIDAcr: "the-appidacr",
IdentityProvider: "the-idp",
ObjectID: "the-oid",
TenantID: tenantID,
Version: "the-version",
XMSMirID: fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, virtualMachine),
}
return jose.Signed(sig).Claims(claims).CompactSerialize()
}
func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) {
tok, err := jose.ParseSigned(token)
if err != nil {
@ -232,6 +619,23 @@ func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) {
return tok, claims, nil
}
func parseAWSToken(token string) (*jose.JSONWebToken, *awsPayload, error) {
tok, err := jose.ParseSigned(token)
if err != nil {
return nil, nil, err
}
claims := new(awsPayload)
if err := tok.UnsafeClaimsWithoutVerification(claims); err != nil {
return nil, nil, err
}
var doc awsInstanceIdentityDocument
if err := json.Unmarshal(claims.Amazon.Document, &doc); err != nil {
return nil, nil, errors.Wrap(err, "error unmarshaling identity document")
}
claims.document = doc
return tok, claims, nil
}
func generateJWKServer(n int) *httptest.Server {
hits := struct {
Hits int `json:"hits"`

@ -27,6 +27,9 @@ Provisioners are people or code that are registered with the CA and authorized
to issue "provisioning tokens". Provisioning tokens are single use tokens that
can be used to authenticate with the CA and get a certificate.
See [provisioners.md](provisioners.md) for more information on the supported
provisioners and its options.
## Initializing PKI and configuring the Certificate Authority
To initialize a PKI and configure the Step Certificate Authority run:

@ -0,0 +1,311 @@
# Provisioners
Provisioners are people or code that are registered with the CA and authorized
to issue "provisioning tokens". Provisioning tokens are single-use tokens that
can be used to authenticate with the CA and get a certificate.
## JWK
JWK is the default provisioner type. It uses public-key cryptography sign and
validate a JSON Web Token (JWT).
The [step](https://github.com/smallstep/cli) CLI tool will create a JWK
provisioner when `step ca init` is used, and it also contains commands to add
(`step ca provisioner add`) and remove (`step ca provisioner remove`) JWK
provisioners.
In the ca.json configuration file, a complete JWK provisioner example looks like:
```json
{
"type": "JWK",
"name": "you@smallstep.com",
"key": {
"use": "sig",
"kty": "EC",
"kid": "NPM_9Gz_omTqchS6Xx9Yfvs-EuxkYo6VAk4sL7gyyM4",
"crv": "P-256",
"alg": "ES256",
"x": "bBI5AkO9lwvDuWGfOr0F6ttXC-ZRzJo8kKn5wTzRJXI",
"y": "rcfaqE-EEZgs34Q9SSH3f9Ua5a8dKopXNfEzDD8KRlU"
},
"encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiTlV6MjlEb3hKMVdOaFI3dUNjaGdYZyJ9.YN7xhz6RAbz_9bcuXoymBOj8bOg23ETAdmSCRyHpxGekkV0q3STYYg.vo1oBnZsZjgRu5Ln.Xop8AvZ74h_im2jxeaq-hYYWnaK_eF7MGr4xcZGodMUxp-hGPqS85oWkyprkQLYt1-jXTURfpejtmPeB4-sxgj7OFxMYYus84BdkG9BZgSBmMN9SqZItOv4pqg_NwQA0bv9g9A_e-N6QUFanxuYQsEPX_-IwWBDbNKyN9bXbpEQa0FKNVsTvFahGzOxQngXipi265VADkh8MJLjYerplKIbNeOJJbLd9CbS9fceLvQUNr3ACGgAejSaWmeNUVqbho1lY4882iS8QVx1VzjluTXlAMdSUUDHArHEihz008kCyF0YfvNdGebyEDLvTmF6KkhqMpsWn3zASYBidc9k._ch9BtvRRhcLD838itIQlw",
"claims": {
"minTLSCertDuration": "5m",
"maxTLSCertDuration": "24h",
"defaultTLSCertDuration": "24h",
"disableRenewal": false
}
}
```
* `type` (mandatory): for a JWK provisioner it must be `JWK`, this field is case
insensitive.
* `name` (mandatory): identifies the provisioner, a good practice is to
use an email address or a descriptive string that allows the identification of
the owner, but it can be any non-empty string.
* `key` (mandatory): is the JWK (JSON Web Key) representation of a public key
used to validate a signed token.
* `encryptedKey` (recommended): is the encrypted private key used to sign a
token. It's a JWE compact string containing the JWK representation of the
private key.
We can use [step](https://github.com/smallstep/cli) to see the private key
encrypted with the password `asdf`:
```sh
$ echo ey...lw | step crypto jwe decrypt | jq
Please enter the password to decrypt the content encryption key:
{
"use": "sig",
"kty": "EC",
"kid": "NPM_9Gz_omTqchS6Xx9Yfvs-EuxkYo6VAk4sL7gyyM4",
"crv": "P-256",
"alg": "ES256",
"x": "bBI5AkO9lwvDuWGfOr0F6ttXC-ZRzJo8kKn5wTzRJXI",
"y": "rcfaqE-EEZgs34Q9SSH3f9Ua5a8dKopXNfEzDD8KRlU",
"d": "rsjCCM_2FQ-uk7nywBEQHl84oaPo4mTpYDgXAu63igE"
}
```
If the ca.json does not contain the encryptedKey, the private key must be
provided using the `--key` flag of the `step ca token` to be able to sign the
token.
* `claims` (optional): overwrites the default claims set in the authority.
You can set one or more of the following claims:
* `minTLSCertDuration`: do not allow certificates with a duration less than
this value.
* `maxTLSCertDuration`: do not allow certificates with a duration greater than
this value.
* `defaultTLSCertDuration`: if no certificate validity period is specified,
use this value.
* `disableIssuedAtCheck`: disable a check verifying that provisioning tokens
must be issued after the CA has booted. This claim is one prevention against
token reuse. The default value is `false`. Do not change this unless you
know what you are doing.
## OIDC
An OIDC provisioner allows a user to get a certificate after authenticating
himself with an OAuth OpenID Connect identity provider. The ID token provided
will be used on the CA authentication, and by default, the certificate will only
have the user's email as a Subject Alternative Name (SAN) Extension.
One of the most common providers and the one we'll use in the following example
is G-Suite.
```json
{
"type": "OIDC",
"name": "Google",
"clientID": "1087160488420-8qt7bavg3qesdhs6it824mhnfgcfe8il.apps.googleusercontent.com",
"clientSecret": "udTrOT3gzrO7W9fDPgZQLfYJ",
"configurationEndpoint": "https://accounts.google.com/.well-known/openid-configuration",
"admins": ["you@smallstep.com"],
"domains": ["smallstep.com"],
"claims": {
"maxTLSCertDuration": "8h",
"defaultTLSCertDuration": "2h",
"disableRenewal": true
}
}
```
* `type` (mandatory): indicates the provisioner type and must be `OIDC`.
* `name` (mandatory): a string used to identify the provider when the CLI is
used.
* `clientID` (mandatory): the client id provided by the identity provider used
to initialize the authentication flow.
* `clientSecret` (mandatory): the client secret provided by the identity
provider used to get the id token. Some identity providers might use an empty
string as a secret.
* `configurationEndpoing` (mandatory): is the HTTP address used by the CA to get
the OpenID Connect configuration and public keys used to validate the tokens.
* `admins` (optional): is the list of emails that will be able to get
certificates with custom SANs. If a user is not an admin, it will only be able
to get a certificate with its email in it.
* `domains` (optional): is the list of domains valid. If provided only the
emails with the provided domains will be able to authenticate.
* `claims` (optional): overwrites the default claims set in the authority, see
the [JWK](#jwk) section for all the options.
## Provisioners for Cloud Identities
[Step certificates](https://github.com/smallstep/certificates) can grant
certificates to code running in a machine without any other authentication than
the one provided by the cloud. Usually, this is implemented with some kind of
signed document, but the information contained on them might not be enough to
generate a certificate. Due to this limitation, the cloud identities use by
default a trust model called Trust On First Use (TOFU).
The Trust On First Use model allows the use of more permissive CSRs that can
have custom SANs that cannot be validated. But it comes with the limitation that
you can only grant a certificate once. After this first grant, the same machine
will need to renew the certificate using mTLS, and the CA will block any other
attempt to grant a certificate to that instance.
### AWS
The AWS provisioner allows granting a certificate to an Amazon EC2 instance
using the [Instance Identity Documents](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html)
The [step](https://github.com/smallstep/cli) CLI will generate a custom JWT
token containing the instance identity document and its signature and the CA
will grant a certificate after validating it.
In the ca.json, an AWS provisioner looks like:
```json
{
"type": "AWS",
"name": "Amazon Web Services",
"accounts": ["1234567890"],
"disableCustomSANs": false,
"disableTrustOnFirstUse": false,
"instanceAge": "1h",
"claims": {
"maxTLSCertDuration": "2160h",
"defaultTLSCertDuration": "2160h"
}
}
```
* `type` (mandatory): indicates the provisioner type and must be `AWS`.
* `name` (mandatory): a string used to identify the provider when the CLI is
used.
* `accounts` (optional): the list of AWS account numbers that are allowed to use
this provisioner. If none is specified, all AWS accounts will be valid.
* `disableCustomSANs` (optional): by default custom SANs are valid, but if this
option is set to true only the SANs available in the instance identity
document will be valid, these are the private IP and the DNS
`ip-<private-ip>.<region>.compute.internal`.
* `disableTrustOnFirstUse` (optional): by default only one certificate will be
granted per instance, but if the option is set to true this limit is not set
and different tokens can be used to get different certificates.
* `instanceAge` (optional): the maximum age of an instance to grant a
certificate. The instance age is a string using the duration format.
* `claims` (optional): overwrites the default claims set in the authority, see
the [JWK](#jwk) section for all the options.
### GCP
The GCP provisioner grants certificates to Google Compute Engine instance using
its [identity](https://cloud.google.com/compute/docs/instances/verifying-instance-identity)
token. The CA will validate the JWT and grant a certificate.
In the ca.json, a GCP provisioner looks like:
```json
{
"type": "GCP",
"name": "Google Cloud",
"serviceAccounts": ["1234567890"],
"projectIDs": ["project-id"],
"disableCustomSANs": false,
"disableTrustOnFirstUse": false,
"instanceAge": "1h",
"claims": {
"maxTLSCertDuration": "2160h",
"defaultTLSCertDuration": "2160h"
}
}
```
* `type` (mandatory): indicates the provisioner type and must be `GCP`.
* `name` (mandatory): a string used to identify the provider when the CLI is
used.
* `serviceAccounts` (optional): the list of service account numbers that are
allowed to use this provisioner. If none is specified, all service accounts
will be valid.
* `projectIDs` (optional): the list of project identifiers that are allowed to
use this provisioner. If non is specified all project will be valid.
* `disableCustomSANs` (optional): by default custom SANs are valid, but if this
option is set to true only the SANs available in the instance identity
document will be valid, these are the DNS
`<instance-name>.c.<project-id>.internal` and
`<instance-name>.<zone>.c.<project-id>.internal`
* `disableTrustOnFirstUse` (optional): by default only one certificate will be
granted per instance, but if the option is set to true this limit is not set
and different tokens can be used to get different certificates.
* `instanceAge` (optional): the maximum age of an instance to grant a
certificate. The instance age is a string using the duration format.
* `claims` (optional): overwrites the default claims set in the authority, see
the [JWK](#jwk) section for all the options.
### Azure
The Azure provisioner grants certificates to Microsoft Azure instances using
the [managed identities tokens](https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token).
The CA will validate the JWT and grant a certificate.
In the ca.json, an Azure provisioner looks like:
```json
{
"type": "Azure",
"name": "Microsoft Azure",
"tenantId": "b17c217c-84db-43f0-babd-e06a71083cda",
"resourceGroups": ["backend", "accounting"],
"audience": "https://management.azure.com/",
"disableCustomSANs": false,
"disableTrustOnFirstUse": false,
"claims": {
"maxTLSCertDuration": "2160h",
"defaultTLSCertDuration": "2160h"
}
}
```
* `type` (mandatory): indicates the provisioner type and must be `Azure`.
* `name` (mandatory): a string used to identify the provider when the CLI is
used.
* `tenantId` (mandatory): the Azure account tenant id for this provisioner. This
id is the Directory ID available in the Azure Active Directory properties.
* `audience` (optional): defaults to `https://management.azure.com/` but it can
be changed if necessary.
* `resourceGroups` (optional): the list of resource group names that are allowed
to use this provisioner. If none is specified, all resource groups will be
valid.
* `disableCustomSANs` (optional): by default custom SANs are valid, but if this
option is set to true only the SANs available in the token will be valid, in
Azure only the virtual machine name is available.
* `disableTrustOnFirstUse` (optional): by default only one certificate will be
granted per instance, but if the option is set to true this limit is not set
and different tokens can be used to get different certificates.
* `claims` (optional): overwrites the default claims set in the authority, see
the [JWK](#jwk) section for all the options.
Loading…
Cancel
Save