Merge branch 'master' into crl-support
# Conflicts: # authority/config/config.gopull/731/head
commit
d2483f3a70
@ -1,12 +0,0 @@
|
||||
name: Pull Request Labeler
|
||||
on:
|
||||
pull_request_target
|
||||
|
||||
jobs:
|
||||
label:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v3.0.2
|
||||
with:
|
||||
repo-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
|
@ -0,0 +1,29 @@
|
||||
name: Add Issues and PRs to Triage
|
||||
|
||||
on:
|
||||
issues:
|
||||
types:
|
||||
- opened
|
||||
pull_request_target:
|
||||
types:
|
||||
- opened
|
||||
|
||||
jobs:
|
||||
|
||||
label:
|
||||
name: Label PR
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event_name == 'pull_request_target'
|
||||
steps:
|
||||
- uses: actions/labeler@v3.0.2
|
||||
with:
|
||||
repo-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
|
||||
add-to-project:
|
||||
name: Add to Triage Project
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/add-to-project@v0.3.0
|
||||
with:
|
||||
project-url: https://github.com/orgs/smallstep/projects/94
|
||||
github-token: ${{ secrets.TRIAGE_PAT }}
|
@ -0,0 +1,8 @@
|
||||
We appreciate any effort to discover and disclose security vulnerabilities responsibly.
|
||||
|
||||
If you would like to report a vulnerability in one of our projects, or have security concerns regarding Smallstep software, please email security@smallstep.com.
|
||||
|
||||
In order for us to best respond to your report, please include any of the following:
|
||||
* Steps to reproduce or proof-of-concept
|
||||
* Any relevant tools, including versions used
|
||||
* Tool output
|
@ -1,137 +0,0 @@
|
||||
package apiv1
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// KeyManager is the interface implemented by all the KMS.
|
||||
type KeyManager interface {
|
||||
GetPublicKey(req *GetPublicKeyRequest) (crypto.PublicKey, error)
|
||||
CreateKey(req *CreateKeyRequest) (*CreateKeyResponse, error)
|
||||
CreateSigner(req *CreateSignerRequest) (crypto.Signer, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Decrypter is an interface implemented by KMSes that are used
|
||||
// in operations that require decryption
|
||||
type Decrypter interface {
|
||||
CreateDecrypter(req *CreateDecrypterRequest) (crypto.Decrypter, error)
|
||||
}
|
||||
|
||||
// CertificateManager is the interface implemented by the KMS that can load and
|
||||
// store x509.Certificates.
|
||||
type CertificateManager interface {
|
||||
LoadCertificate(req *LoadCertificateRequest) (*x509.Certificate, error)
|
||||
StoreCertificate(req *StoreCertificateRequest) error
|
||||
}
|
||||
|
||||
// ValidateName is an interface that KeyManager can implement to validate a
|
||||
// given name or URI.
|
||||
type NameValidator interface {
|
||||
ValidateName(s string) error
|
||||
}
|
||||
|
||||
// ErrNotImplemented is the type of error returned if an operation is not
|
||||
// implemented.
|
||||
type ErrNotImplemented struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e ErrNotImplemented) Error() string {
|
||||
if e.Message != "" {
|
||||
return e.Message
|
||||
}
|
||||
return "not implemented"
|
||||
}
|
||||
|
||||
// ErrAlreadyExists is the type of error returned if a key already exists. This
|
||||
// is currently only implmented on pkcs11.
|
||||
type ErrAlreadyExists struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e ErrAlreadyExists) Error() string {
|
||||
if e.Message != "" {
|
||||
return e.Message
|
||||
}
|
||||
return "key already exists"
|
||||
}
|
||||
|
||||
// Type represents the KMS type used.
|
||||
type Type string
|
||||
|
||||
const (
|
||||
// DefaultKMS is a KMS implementation using software.
|
||||
DefaultKMS Type = ""
|
||||
// SoftKMS is a KMS implementation using software.
|
||||
SoftKMS Type = "softkms"
|
||||
// CloudKMS is a KMS implementation using Google's Cloud KMS.
|
||||
CloudKMS Type = "cloudkms"
|
||||
// AmazonKMS is a KMS implementation using Amazon AWS KMS.
|
||||
AmazonKMS Type = "awskms"
|
||||
// PKCS11 is a KMS implementation using the PKCS11 standard.
|
||||
PKCS11 Type = "pkcs11"
|
||||
// YubiKey is a KMS implementation using a YubiKey PIV.
|
||||
YubiKey Type = "yubikey"
|
||||
// SSHAgentKMS is a KMS implementation using ssh-agent to access keys.
|
||||
SSHAgentKMS Type = "sshagentkms"
|
||||
// AzureKMS is a KMS implementation using Azure Key Vault.
|
||||
AzureKMS Type = "azurekms"
|
||||
)
|
||||
|
||||
// Options are the KMS options. They represent the kms object in the ca.json.
|
||||
type Options struct {
|
||||
// The type of the KMS to use.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Path to the credentials file used in CloudKMS and AmazonKMS.
|
||||
CredentialsFile string `json:"credentialsFile,omitempty"`
|
||||
|
||||
// URI is based on the PKCS #11 URI Scheme defined in
|
||||
// https://tools.ietf.org/html/rfc7512 and represents the configuration used
|
||||
// to connect to the KMS.
|
||||
//
|
||||
// Used by: pkcs11
|
||||
URI string `json:"uri,omitempty"`
|
||||
|
||||
// Pin used to access the PKCS11 module. It can be defined in the URI using
|
||||
// the pin-value or pin-source properties.
|
||||
Pin string `json:"pin,omitempty"`
|
||||
|
||||
// ManagementKey used in YubiKeys. Default management key is the hexadecimal
|
||||
// string 010203040506070801020304050607080102030405060708:
|
||||
// []byte{
|
||||
// 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
|
||||
// 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
|
||||
// 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
|
||||
// }
|
||||
ManagementKey string `json:"managementKey,omitempty"`
|
||||
|
||||
// Region to use in AmazonKMS.
|
||||
Region string `json:"region,omitempty"`
|
||||
|
||||
// Profile to use in AmazonKMS.
|
||||
Profile string `json:"profile,omitempty"`
|
||||
}
|
||||
|
||||
// Validate checks the fields in Options.
|
||||
func (o *Options) Validate() error {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch Type(strings.ToLower(o.Type)) {
|
||||
case DefaultKMS, SoftKMS: // Go crypto based kms.
|
||||
case CloudKMS, AmazonKMS, AzureKMS: // Cloud based kms.
|
||||
case YubiKey, PKCS11: // Hardware based kms.
|
||||
case SSHAgentKMS: // Others
|
||||
default:
|
||||
return errors.Errorf("unsupported kms type %s", o.Type)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -1,76 +0,0 @@
|
||||
package apiv1
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOptions_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
options *Options
|
||||
wantErr bool
|
||||
}{
|
||||
{"nil", nil, false},
|
||||
{"softkms", &Options{Type: "softkms"}, false},
|
||||
{"cloudkms", &Options{Type: "cloudkms"}, false},
|
||||
{"awskms", &Options{Type: "awskms"}, false},
|
||||
{"sshagentkms", &Options{Type: "sshagentkms"}, false},
|
||||
{"pkcs11", &Options{Type: "pkcs11"}, false},
|
||||
{"unsupported", &Options{Type: "unsupported"}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.options.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrNotImplemented_Error(t *testing.T) {
|
||||
type fields struct {
|
||||
msg string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want string
|
||||
}{
|
||||
{"default", fields{}, "not implemented"},
|
||||
{"custom", fields{"custom message: not implemented"}, "custom message: not implemented"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := ErrNotImplemented{
|
||||
Message: tt.fields.msg,
|
||||
}
|
||||
if got := e.Error(); got != tt.want {
|
||||
t.Errorf("ErrNotImplemented.Error() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrAlreadyExists_Error(t *testing.T) {
|
||||
type fields struct {
|
||||
msg string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want string
|
||||
}{
|
||||
{"default", fields{}, "key already exists"},
|
||||
{"custom", fields{"custom message: key already exists"}, "custom message: key already exists"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := ErrAlreadyExists{
|
||||
Message: tt.fields.msg,
|
||||
}
|
||||
if got := e.Error(); got != tt.want {
|
||||
t.Errorf("ErrAlreadyExists.Error() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,27 +0,0 @@
|
||||
package apiv1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var registry = new(sync.Map)
|
||||
|
||||
// KeyManagerNewFunc is the type that represents the method to initialize a new
|
||||
// KeyManager.
|
||||
type KeyManagerNewFunc func(ctx context.Context, opts Options) (KeyManager, error)
|
||||
|
||||
// Register adds to the registry a method to create a KeyManager of type t.
|
||||
func Register(t Type, fn KeyManagerNewFunc) {
|
||||
registry.Store(t, fn)
|
||||
}
|
||||
|
||||
// LoadKeyManagerNewFunc returns the function initialize a KayManager.
|
||||
func LoadKeyManagerNewFunc(t Type) (KeyManagerNewFunc, bool) {
|
||||
v, ok := registry.Load(t)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
fn, ok := v.(KeyManagerNewFunc)
|
||||
return fn, ok
|
||||
}
|
@ -1,167 +0,0 @@
|
||||
package apiv1
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ProtectionLevel specifies on some KMS how cryptographic operations are
|
||||
// performed.
|
||||
type ProtectionLevel int
|
||||
|
||||
const (
|
||||
// Protection level not specified.
|
||||
UnspecifiedProtectionLevel ProtectionLevel = iota
|
||||
// Crypto operations are performed in software.
|
||||
Software
|
||||
// Crypto operations are performed in a Hardware Security Module.
|
||||
HSM
|
||||
)
|
||||
|
||||
// String returns a string representation of p.
|
||||
func (p ProtectionLevel) String() string {
|
||||
switch p {
|
||||
case UnspecifiedProtectionLevel:
|
||||
return "unspecified"
|
||||
case Software:
|
||||
return "software"
|
||||
case HSM:
|
||||
return "hsm"
|
||||
default:
|
||||
return fmt.Sprintf("unknown(%d)", p)
|
||||
}
|
||||
}
|
||||
|
||||
// SignatureAlgorithm used for cryptographic signing.
|
||||
type SignatureAlgorithm int
|
||||
|
||||
const (
|
||||
// Not specified.
|
||||
UnspecifiedSignAlgorithm SignatureAlgorithm = iota
|
||||
// RSASSA-PKCS1-v1_5 key and a SHA256 digest.
|
||||
SHA256WithRSA
|
||||
// RSASSA-PKCS1-v1_5 key and a SHA384 digest.
|
||||
SHA384WithRSA
|
||||
// RSASSA-PKCS1-v1_5 key and a SHA512 digest.
|
||||
SHA512WithRSA
|
||||
// RSASSA-PSS key with a SHA256 digest.
|
||||
SHA256WithRSAPSS
|
||||
// RSASSA-PSS key with a SHA384 digest.
|
||||
SHA384WithRSAPSS
|
||||
// RSASSA-PSS key with a SHA512 digest.
|
||||
SHA512WithRSAPSS
|
||||
// ECDSA on the NIST P-256 curve with a SHA256 digest.
|
||||
ECDSAWithSHA256
|
||||
// ECDSA on the NIST P-384 curve with a SHA384 digest.
|
||||
ECDSAWithSHA384
|
||||
// ECDSA on the NIST P-521 curve with a SHA512 digest.
|
||||
ECDSAWithSHA512
|
||||
// EdDSA on Curve25519 with a SHA512 digest.
|
||||
PureEd25519
|
||||
)
|
||||
|
||||
// String returns a string representation of s.
|
||||
func (s SignatureAlgorithm) String() string {
|
||||
switch s {
|
||||
case UnspecifiedSignAlgorithm:
|
||||
return "unspecified"
|
||||
case SHA256WithRSA:
|
||||
return "SHA256-RSA"
|
||||
case SHA384WithRSA:
|
||||
return "SHA384-RSA"
|
||||
case SHA512WithRSA:
|
||||
return "SHA512-RSA"
|
||||
case SHA256WithRSAPSS:
|
||||
return "SHA256-RSAPSS"
|
||||
case SHA384WithRSAPSS:
|
||||
return "SHA384-RSAPSS"
|
||||
case SHA512WithRSAPSS:
|
||||
return "SHA512-RSAPSS"
|
||||
case ECDSAWithSHA256:
|
||||
return "ECDSA-SHA256"
|
||||
case ECDSAWithSHA384:
|
||||
return "ECDSA-SHA384"
|
||||
case ECDSAWithSHA512:
|
||||
return "ECDSA-SHA512"
|
||||
case PureEd25519:
|
||||
return "Ed25519"
|
||||
default:
|
||||
return fmt.Sprintf("unknown(%d)", s)
|
||||
}
|
||||
}
|
||||
|
||||
// GetPublicKeyRequest is the parameter used in the kms.GetPublicKey method.
|
||||
type GetPublicKeyRequest struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
// CreateKeyRequest is the parameter used in the kms.CreateKey method.
|
||||
type CreateKeyRequest struct {
|
||||
// Name represents the key name or label used to identify a key.
|
||||
//
|
||||
// Used by: awskms, cloudkms, azurekms, pkcs11, yubikey.
|
||||
Name string
|
||||
|
||||
// SignatureAlgorithm represents the type of key to create.
|
||||
SignatureAlgorithm SignatureAlgorithm
|
||||
|
||||
// Bits is the number of bits on RSA keys.
|
||||
Bits int
|
||||
|
||||
// ProtectionLevel specifies how cryptographic operations are performed.
|
||||
// Used by: cloudkms, azurekms.
|
||||
ProtectionLevel ProtectionLevel
|
||||
|
||||
// Extractable defines if the new key may be exported from the HSM under a
|
||||
// wrap key. On pkcs11 sets the CKA_EXTRACTABLE bit.
|
||||
//
|
||||
// Used by: pkcs11
|
||||
Extractable bool
|
||||
}
|
||||
|
||||
// CreateKeyResponse is the response value of the kms.CreateKey method.
|
||||
type CreateKeyResponse struct {
|
||||
Name string
|
||||
PublicKey crypto.PublicKey
|
||||
PrivateKey crypto.PrivateKey
|
||||
CreateSignerRequest CreateSignerRequest
|
||||
}
|
||||
|
||||
// CreateSignerRequest is the parameter used in the kms.CreateSigner method.
|
||||
type CreateSignerRequest struct {
|
||||
Signer crypto.Signer
|
||||
SigningKey string
|
||||
SigningKeyPEM []byte
|
||||
TokenLabel string
|
||||
PublicKey string
|
||||
PublicKeyPEM []byte
|
||||
Password []byte
|
||||
}
|
||||
|
||||
// CreateDecrypterRequest is the parameter used in the kms.Decrypt method.
|
||||
type CreateDecrypterRequest struct {
|
||||
Decrypter crypto.Decrypter
|
||||
DecryptionKey string
|
||||
DecryptionKeyPEM []byte
|
||||
Password []byte
|
||||
}
|
||||
|
||||
// LoadCertificateRequest is the parameter used in the LoadCertificate method of
|
||||
// a CertificateManager.
|
||||
type LoadCertificateRequest struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
// StoreCertificateRequest is the parameter used in the StoreCertificate method
|
||||
// of a CertificateManager.
|
||||
type StoreCertificateRequest struct {
|
||||
Name string
|
||||
Certificate *x509.Certificate
|
||||
|
||||
// Extractable defines if the new certificate may be exported from the HSM
|
||||
// under a wrap key. On pkcs11 sets the CKA_EXTRACTABLE bit.
|
||||
//
|
||||
// Used by: pkcs11
|
||||
Extractable bool
|
||||
}
|
@ -1,51 +0,0 @@
|
||||
package apiv1
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestProtectionLevel_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
p ProtectionLevel
|
||||
want string
|
||||
}{
|
||||
{"unspecified", UnspecifiedProtectionLevel, "unspecified"},
|
||||
{"software", Software, "software"},
|
||||
{"hsm", HSM, "hsm"},
|
||||
{"unknown", ProtectionLevel(100), "unknown(100)"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.p.String(); got != tt.want {
|
||||
t.Errorf("ProtectionLevel.String() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignatureAlgorithm_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s SignatureAlgorithm
|
||||
want string
|
||||
}{
|
||||
{"UnspecifiedSignAlgorithm", UnspecifiedSignAlgorithm, "unspecified"},
|
||||
{"SHA256WithRSA", SHA256WithRSA, "SHA256-RSA"},
|
||||
{"SHA384WithRSA", SHA384WithRSA, "SHA384-RSA"},
|
||||
{"SHA512WithRSA", SHA512WithRSA, "SHA512-RSA"},
|
||||
{"SHA256WithRSAPSS", SHA256WithRSAPSS, "SHA256-RSAPSS"},
|
||||
{"SHA384WithRSAPSS", SHA384WithRSAPSS, "SHA384-RSAPSS"},
|
||||
{"SHA512WithRSAPSS", SHA512WithRSAPSS, "SHA512-RSAPSS"},
|
||||
{"ECDSAWithSHA256", ECDSAWithSHA256, "ECDSA-SHA256"},
|
||||
{"ECDSAWithSHA384", ECDSAWithSHA384, "ECDSA-SHA384"},
|
||||
{"ECDSAWithSHA512", ECDSAWithSHA512, "ECDSA-SHA512"},
|
||||
{"PureEd25519", PureEd25519, "Ed25519"},
|
||||
{"unknown", SignatureAlgorithm(100), "unknown(100)"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.s.String(); got != tt.want {
|
||||
t.Errorf("SignatureAlgorithm.String() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,267 +0,0 @@
|
||||
package awskms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/kms/apiv1"
|
||||
"github.com/smallstep/certificates/kms/uri"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
||||
// Scheme is the scheme used in uris.
|
||||
const Scheme = "awskms"
|
||||
|
||||
// KMS implements a KMS using AWS Key Management Service.
|
||||
type KMS struct {
|
||||
session *session.Session
|
||||
service KeyManagementClient
|
||||
}
|
||||
|
||||
// KeyManagementClient defines the methods on KeyManagementClient that this
|
||||
// package will use. This interface will be used for unit testing.
|
||||
type KeyManagementClient interface {
|
||||
GetPublicKeyWithContext(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error)
|
||||
CreateKeyWithContext(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error)
|
||||
CreateAliasWithContext(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error)
|
||||
SignWithContext(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error)
|
||||
}
|
||||
|
||||
// customerMasterKeySpecMapping is a mapping between the step signature algorithm,
|
||||
// and bits for RSA keys, with awskms CustomerMasterKeySpec.
|
||||
var customerMasterKeySpecMapping = map[apiv1.SignatureAlgorithm]interface{}{
|
||||
apiv1.UnspecifiedSignAlgorithm: kms.CustomerMasterKeySpecEccNistP256,
|
||||
apiv1.SHA256WithRSA: map[int]string{
|
||||
0: kms.CustomerMasterKeySpecRsa3072,
|
||||
2048: kms.CustomerMasterKeySpecRsa2048,
|
||||
3072: kms.CustomerMasterKeySpecRsa3072,
|
||||
4096: kms.CustomerMasterKeySpecRsa4096,
|
||||
},
|
||||
apiv1.SHA512WithRSA: map[int]string{
|
||||
0: kms.CustomerMasterKeySpecRsa4096,
|
||||
4096: kms.CustomerMasterKeySpecRsa4096,
|
||||
},
|
||||
apiv1.SHA256WithRSAPSS: map[int]string{
|
||||
0: kms.CustomerMasterKeySpecRsa3072,
|
||||
2048: kms.CustomerMasterKeySpecRsa2048,
|
||||
3072: kms.CustomerMasterKeySpecRsa3072,
|
||||
4096: kms.CustomerMasterKeySpecRsa4096,
|
||||
},
|
||||
apiv1.SHA512WithRSAPSS: map[int]string{
|
||||
0: kms.CustomerMasterKeySpecRsa4096,
|
||||
4096: kms.CustomerMasterKeySpecRsa4096,
|
||||
},
|
||||
apiv1.ECDSAWithSHA256: kms.CustomerMasterKeySpecEccNistP256,
|
||||
apiv1.ECDSAWithSHA384: kms.CustomerMasterKeySpecEccNistP384,
|
||||
apiv1.ECDSAWithSHA512: kms.CustomerMasterKeySpecEccNistP521,
|
||||
}
|
||||
|
||||
// New creates a new AWSKMS. By default, sessions will be created using the
|
||||
// credentials in `~/.aws/credentials`, but this can be overridden using the
|
||||
// CredentialsFile option, the Region and Profile can also be configured as
|
||||
// options.
|
||||
//
|
||||
// AWS sessions can also be configured with environment variables, see docs at
|
||||
// https://docs.aws.amazon.com/sdk-for-go/api/aws/session/ for all the options.
|
||||
func New(ctx context.Context, opts apiv1.Options) (*KMS, error) {
|
||||
var o session.Options
|
||||
|
||||
if opts.URI != "" {
|
||||
u, err := uri.ParseWithScheme(Scheme, opts.URI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o.Profile = u.Get("profile")
|
||||
if v := u.Get("region"); v != "" {
|
||||
o.Config.Region = new(string)
|
||||
*o.Config.Region = v
|
||||
}
|
||||
if f := u.Get("credentials-file"); f != "" {
|
||||
o.SharedConfigFiles = []string{f}
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated way to set configuration parameters.
|
||||
if opts.Region != "" {
|
||||
o.Config.Region = &opts.Region
|
||||
}
|
||||
if opts.Profile != "" {
|
||||
o.Profile = opts.Profile
|
||||
}
|
||||
if opts.CredentialsFile != "" {
|
||||
o.SharedConfigFiles = []string{opts.CredentialsFile}
|
||||
}
|
||||
|
||||
sess, err := session.NewSessionWithOptions(o)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error creating AWS session")
|
||||
}
|
||||
|
||||
return &KMS{
|
||||
session: sess,
|
||||
service: kms.New(sess),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
apiv1.Register(apiv1.AmazonKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
|
||||
return New(ctx, opts)
|
||||
})
|
||||
}
|
||||
|
||||
// GetPublicKey returns a public key from KMS.
|
||||
func (k *KMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
|
||||
if req.Name == "" {
|
||||
return nil, errors.New("getPublicKey 'name' cannot be empty")
|
||||
}
|
||||
keyID, err := parseKeyID(req.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
resp, err := k.service.GetPublicKeyWithContext(ctx, &kms.GetPublicKeyInput{
|
||||
KeyId: &keyID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "awskms GetPublicKeyWithContext failed")
|
||||
}
|
||||
|
||||
return pemutil.ParseDER(resp.PublicKey)
|
||||
}
|
||||
|
||||
// CreateKey generates a new key in KMS and returns the public key version
|
||||
// of it.
|
||||
func (k *KMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
|
||||
if req.Name == "" {
|
||||
return nil, errors.New("createKeyRequest 'name' cannot be empty")
|
||||
}
|
||||
|
||||
keySpec, err := getCustomerMasterKeySpecMapping(req.SignatureAlgorithm, req.Bits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tag := new(kms.Tag)
|
||||
tag.SetTagKey("name")
|
||||
tag.SetTagValue(req.Name)
|
||||
|
||||
input := &kms.CreateKeyInput{
|
||||
Description: &req.Name,
|
||||
CustomerMasterKeySpec: &keySpec,
|
||||
Tags: []*kms.Tag{tag},
|
||||
}
|
||||
input.SetKeyUsage(kms.KeyUsageTypeSignVerify)
|
||||
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
resp, err := k.service.CreateKeyWithContext(ctx, input)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "awskms CreateKeyWithContext failed")
|
||||
}
|
||||
if err := k.createKeyAlias(*resp.KeyMetadata.KeyId, req.Name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create uri for key
|
||||
name := uri.New("awskms", url.Values{
|
||||
"key-id": []string{*resp.KeyMetadata.KeyId},
|
||||
}).String()
|
||||
|
||||
publicKey, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{
|
||||
Name: name,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Names uses Amazon Resource Name
|
||||
// https://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html
|
||||
return &apiv1.CreateKeyResponse{
|
||||
Name: name,
|
||||
PublicKey: publicKey,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: name,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (k *KMS) createKeyAlias(keyID, alias string) error {
|
||||
alias = "alias/" + alias + "-" + keyID[:8]
|
||||
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
_, err := k.service.CreateAliasWithContext(ctx, &kms.CreateAliasInput{
|
||||
AliasName: &alias,
|
||||
TargetKeyId: &keyID,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "awskms CreateAliasWithContext failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateSigner creates a new crypto.Signer with a previously configured key.
|
||||
func (k *KMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
|
||||
if req.SigningKey == "" {
|
||||
return nil, errors.New("createSigner 'signingKey' cannot be empty")
|
||||
}
|
||||
return NewSigner(k.service, req.SigningKey)
|
||||
}
|
||||
|
||||
// Close closes the connection of the KMS client.
|
||||
func (k *KMS) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func defaultContext() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), 15*time.Second)
|
||||
}
|
||||
|
||||
// parseKeyID extracts the key-id from an uri.
|
||||
func parseKeyID(name string) (string, error) {
|
||||
name = strings.ToLower(name)
|
||||
if strings.HasPrefix(name, "awskms:") || strings.HasPrefix(name, "aws:") {
|
||||
u, err := uri.Parse(name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if k := u.Get("key-id"); k != "" {
|
||||
return k, nil
|
||||
}
|
||||
return "", errors.Errorf("failed to get key-id from %s", name)
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func getCustomerMasterKeySpecMapping(alg apiv1.SignatureAlgorithm, bits int) (string, error) {
|
||||
v, ok := customerMasterKeySpecMapping[alg]
|
||||
if !ok {
|
||||
return "", errors.Errorf("awskms does not support signature algorithm '%s'", alg)
|
||||
}
|
||||
|
||||
switch v := v.(type) {
|
||||
case string:
|
||||
return v, nil
|
||||
case map[int]string:
|
||||
s, ok := v[bits]
|
||||
if !ok {
|
||||
return "", errors.Errorf("awskms does not support signature algorithm '%s' with '%d' bits", alg, bits)
|
||||
}
|
||||
return s, nil
|
||||
default:
|
||||
return "", errors.Errorf("unexpected error: this should not happen")
|
||||
}
|
||||
}
|
@ -1,364 +0,0 @@
|
||||
package awskms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
"github.com/smallstep/certificates/kms/apiv1"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
sess, err := session.NewSessionWithOptions(session.Options{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expected := &KMS{
|
||||
session: sess,
|
||||
service: kms.New(sess),
|
||||
}
|
||||
|
||||
// This will force an error in the session creation.
|
||||
// It does not fail with missing credentials.
|
||||
forceError := func(t *testing.T) {
|
||||
key := "AWS_CA_BUNDLE"
|
||||
value := os.Getenv(key)
|
||||
os.Setenv(key, filepath.Join(os.TempDir(), "missing-ca.crt"))
|
||||
t.Cleanup(func() {
|
||||
if value == "" {
|
||||
os.Unsetenv(key)
|
||||
} else {
|
||||
os.Setenv(key, value)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
opts apiv1.Options
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *KMS
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{ctx, apiv1.Options{}}, expected, false},
|
||||
{"ok with options", args{ctx, apiv1.Options{
|
||||
Region: "us-east-1",
|
||||
Profile: "smallstep",
|
||||
CredentialsFile: "~/aws/credentials",
|
||||
}}, expected, false},
|
||||
{"ok with uri", args{ctx, apiv1.Options{
|
||||
URI: "awskms:region=us-east-1;profile=smallstep;credentials-file=/var/run/aws/credentials",
|
||||
}}, expected, false},
|
||||
{"fail", args{ctx, apiv1.Options{}}, nil, true},
|
||||
{"fail uri", args{ctx, apiv1.Options{
|
||||
URI: "pkcs11:region=us-east-1;profile=smallstep;credentials-file=/var/run/aws/credentials",
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Force an error in the session loading
|
||||
if tt.wantErr {
|
||||
forceError(t)
|
||||
}
|
||||
|
||||
got, err := New(tt.args.ctx, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("New() = %#v, want %#v", got, tt.want)
|
||||
}
|
||||
} else {
|
||||
if got.session == nil || got.service == nil {
|
||||
t.Errorf("New() = %#v, want %#v", got, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKMS_GetPublicKey(t *testing.T) {
|
||||
okClient := getOKClient()
|
||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
session *session.Session
|
||||
service KeyManagementClient
|
||||
}
|
||||
type args struct {
|
||||
req *apiv1.GetPublicKeyRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want crypto.PublicKey
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
}}, key, false},
|
||||
{"fail empty", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{}}, nil, true},
|
||||
{"fail name", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "awskms:key-id=",
|
||||
}}, nil, true},
|
||||
{"fail getPublicKey", fields{nil, &MockClient{
|
||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
||||
return nil, fmt.Errorf("an error")
|
||||
},
|
||||
}}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
}}, nil, true},
|
||||
{"fail not der", fields{nil, &MockClient{
|
||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
||||
return &kms.GetPublicKeyOutput{
|
||||
KeyId: input.KeyId,
|
||||
PublicKey: []byte(publicKey),
|
||||
}, nil
|
||||
},
|
||||
}}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KMS{
|
||||
session: tt.fields.session,
|
||||
service: tt.fields.service,
|
||||
}
|
||||
got, err := k.GetPublicKey(tt.args.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("KMS.GetPublicKey() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKMS_CreateKey(t *testing.T) {
|
||||
okClient := getOKClient()
|
||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
session *session.Session
|
||||
service KeyManagementClient
|
||||
}
|
||||
type args struct {
|
||||
req *apiv1.CreateKeyRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *apiv1.CreateKeyResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "root",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
PublicKey: key,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
},
|
||||
}, false},
|
||||
{"ok rsa", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "root",
|
||||
SignatureAlgorithm: apiv1.SHA256WithRSA,
|
||||
Bits: 2048,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
PublicKey: key,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
},
|
||||
}, false},
|
||||
{"fail empty", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{}}, nil, true},
|
||||
{"fail unsupported alg", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "root",
|
||||
SignatureAlgorithm: apiv1.PureEd25519,
|
||||
}}, nil, true},
|
||||
{"fail unsupported bits", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "root",
|
||||
SignatureAlgorithm: apiv1.SHA256WithRSA,
|
||||
Bits: 1234,
|
||||
}}, nil, true},
|
||||
{"fail createKey", fields{nil, &MockClient{
|
||||
createKeyWithContext: func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) {
|
||||
return nil, fmt.Errorf("an error")
|
||||
},
|
||||
createAliasWithContext: okClient.createAliasWithContext,
|
||||
getPublicKeyWithContext: okClient.getPublicKeyWithContext,
|
||||
}}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "root",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
}}, nil, true},
|
||||
{"fail createAlias", fields{nil, &MockClient{
|
||||
createKeyWithContext: okClient.createKeyWithContext,
|
||||
createAliasWithContext: func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) {
|
||||
return nil, fmt.Errorf("an error")
|
||||
},
|
||||
getPublicKeyWithContext: okClient.getPublicKeyWithContext,
|
||||
}}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "root",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
}}, nil, true},
|
||||
{"fail getPublicKey", fields{nil, &MockClient{
|
||||
createKeyWithContext: okClient.createKeyWithContext,
|
||||
createAliasWithContext: okClient.createAliasWithContext,
|
||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
||||
return nil, fmt.Errorf("an error")
|
||||
},
|
||||
}}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "root",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KMS{
|
||||
session: tt.fields.session,
|
||||
service: tt.fields.service,
|
||||
}
|
||||
got, err := k.CreateKey(tt.args.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KMS.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("KMS.CreateKey() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKMS_CreateSigner(t *testing.T) {
|
||||
client := getOKClient()
|
||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
session *session.Session
|
||||
service KeyManagementClient
|
||||
}
|
||||
type args struct {
|
||||
req *apiv1.CreateSignerRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want crypto.Signer
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{nil, client}, args{&apiv1.CreateSignerRequest{
|
||||
SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
}}, &Signer{
|
||||
service: client,
|
||||
keyID: "be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
publicKey: key,
|
||||
}, false},
|
||||
{"fail empty", fields{nil, client}, args{&apiv1.CreateSignerRequest{}}, nil, true},
|
||||
{"fail preload", fields{nil, client}, args{&apiv1.CreateSignerRequest{}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KMS{
|
||||
session: tt.fields.session,
|
||||
service: tt.fields.service,
|
||||
}
|
||||
got, err := k.CreateSigner(tt.args.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("KMS.CreateSigner() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKMS_Close(t *testing.T) {
|
||||
type fields struct {
|
||||
session *session.Session
|
||||
service KeyManagementClient
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{nil, getOKClient()}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KMS{
|
||||
session: tt.fields.session,
|
||||
service: tt.fields.service,
|
||||
}
|
||||
if err := k.Close(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("KMS.Close() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseKeyID(t *testing.T) {
|
||||
type args struct {
|
||||
name string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok uri", args{"awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
|
||||
{"ok key id", args{"be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
|
||||
{"ok arn", args{"arn:aws:kms:us-east-1:123456789:key/be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "arn:aws:kms:us-east-1:123456789:key/be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
|
||||
{"fail parse", args{"awskms:key-id=%ZZ"}, "", true},
|
||||
{"fail empty key", args{"awskms:key-id="}, "", true},
|
||||
{"fail missing", args{"awskms:foo=bar"}, "", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := parseKeyID(tt.args.name)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseKeyID() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("parseKeyID() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,72 +0,0 @@
|
||||
package awskms
|
||||
|
||||
import (
|
||||
"encoding/pem"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
)
|
||||
|
||||
type MockClient struct {
|
||||
getPublicKeyWithContext func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error)
|
||||
createKeyWithContext func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error)
|
||||
createAliasWithContext func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error)
|
||||
signWithContext func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error)
|
||||
}
|
||||
|
||||
func (m *MockClient) GetPublicKeyWithContext(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
||||
return m.getPublicKeyWithContext(ctx, input, opts...)
|
||||
}
|
||||
|
||||
func (m *MockClient) CreateKeyWithContext(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) {
|
||||
return m.createKeyWithContext(ctx, input, opts...)
|
||||
}
|
||||
|
||||
func (m *MockClient) CreateAliasWithContext(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) {
|
||||
return m.createAliasWithContext(ctx, input, opts...)
|
||||
}
|
||||
|
||||
func (m *MockClient) SignWithContext(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) {
|
||||
return m.signWithContext(ctx, input, opts...)
|
||||
}
|
||||
|
||||
const (
|
||||
publicKey = `-----BEGIN PUBLIC KEY-----
|
||||
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE8XWlIWkOThxNjGbZLYUgRHmsvCrW
|
||||
KF+HLktPfPTIK3lGd1k4849WQs59XIN+LXZQ6b2eRBEBKAHEyQus8UU7gw==
|
||||
-----END PUBLIC KEY-----`
|
||||
keyID = "be468355-ca7a-40d9-a28b-8ae1c4c7f936"
|
||||
)
|
||||
|
||||
var signature = []byte{
|
||||
0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24,
|
||||
0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55,
|
||||
}
|
||||
|
||||
func getOKClient() *MockClient {
|
||||
return &MockClient{
|
||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
||||
block, _ := pem.Decode([]byte(publicKey))
|
||||
return &kms.GetPublicKeyOutput{
|
||||
KeyId: input.KeyId,
|
||||
PublicKey: block.Bytes,
|
||||
}, nil
|
||||
},
|
||||
createKeyWithContext: func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) {
|
||||
md := new(kms.KeyMetadata)
|
||||
md.SetKeyId(keyID)
|
||||
return &kms.CreateKeyOutput{
|
||||
KeyMetadata: md,
|
||||
}, nil
|
||||
},
|
||||
createAliasWithContext: func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) {
|
||||
return &kms.CreateAliasOutput{}, nil
|
||||
},
|
||||
signWithContext: func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) {
|
||||
return &kms.SignOutput{
|
||||
Signature: signature,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
}
|
@ -1,122 +0,0 @@
|
||||
package awskms
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"io"
|
||||
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
"github.com/pkg/errors"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
||||
// Signer implements a crypto.Signer using the AWS KMS.
|
||||
type Signer struct {
|
||||
service KeyManagementClient
|
||||
keyID string
|
||||
publicKey crypto.PublicKey
|
||||
}
|
||||
|
||||
// NewSigner creates a new signer using a key in the AWS KMS.
|
||||
func NewSigner(svc KeyManagementClient, signingKey string) (*Signer, error) {
|
||||
keyID, err := parseKeyID(signingKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Make sure that the key exists.
|
||||
signer := &Signer{
|
||||
service: svc,
|
||||
keyID: keyID,
|
||||
}
|
||||
if err := signer.preloadKey(keyID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
func (s *Signer) preloadKey(keyID string) error {
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
resp, err := s.service.GetPublicKeyWithContext(ctx, &kms.GetPublicKeyInput{
|
||||
KeyId: &keyID,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "awskms GetPublicKeyWithContext failed")
|
||||
}
|
||||
|
||||
s.publicKey, err = pemutil.ParseDER(resp.PublicKey)
|
||||
return err
|
||||
}
|
||||
|
||||
// Public returns the public key of this signer or an error.
|
||||
func (s *Signer) Public() crypto.PublicKey {
|
||||
return s.publicKey
|
||||
}
|
||||
|
||||
// Sign signs digest with the private key stored in the AWS KMS.
|
||||
func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
||||
alg, err := getSigningAlgorithm(s.Public(), opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := &kms.SignInput{
|
||||
KeyId: &s.keyID,
|
||||
SigningAlgorithm: &alg,
|
||||
Message: digest,
|
||||
}
|
||||
req.SetMessageType("DIGEST")
|
||||
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
resp, err := s.service.SignWithContext(ctx, req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "awsKMS SignWithContext failed")
|
||||
}
|
||||
|
||||
return resp.Signature, nil
|
||||
}
|
||||
|
||||
func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (string, error) {
|
||||
switch key.(type) {
|
||||
case *rsa.PublicKey:
|
||||
_, isPSS := opts.(*rsa.PSSOptions)
|
||||
switch h := opts.HashFunc(); h {
|
||||
case crypto.SHA256:
|
||||
if isPSS {
|
||||
return kms.SigningAlgorithmSpecRsassaPssSha256, nil
|
||||
}
|
||||
return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha256, nil
|
||||
case crypto.SHA384:
|
||||
if isPSS {
|
||||
return kms.SigningAlgorithmSpecRsassaPssSha384, nil
|
||||
}
|
||||
return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha384, nil
|
||||
case crypto.SHA512:
|
||||
if isPSS {
|
||||
return kms.SigningAlgorithmSpecRsassaPssSha512, nil
|
||||
}
|
||||
return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha512, nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported hash function %v", h)
|
||||
}
|
||||
case *ecdsa.PublicKey:
|
||||
switch h := opts.HashFunc(); h {
|
||||
case crypto.SHA256:
|
||||
return kms.SigningAlgorithmSpecEcdsaSha256, nil
|
||||
case crypto.SHA384:
|
||||
return kms.SigningAlgorithmSpecEcdsaSha384, nil
|
||||
case crypto.SHA512:
|
||||
return kms.SigningAlgorithmSpecEcdsaSha512, nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported hash function %v", h)
|
||||
}
|
||||
default:
|
||||
return "", errors.Errorf("unsupported key type %T", key)
|
||||
}
|
||||
}
|
@ -1,191 +0,0 @@
|
||||
package awskms
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
||||
func TestNewSigner(t *testing.T) {
|
||||
okClient := getOKClient()
|
||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type args struct {
|
||||
svc KeyManagementClient
|
||||
signingKey string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *Signer
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{okClient, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, &Signer{
|
||||
service: okClient,
|
||||
keyID: "be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
publicKey: key,
|
||||
}, false},
|
||||
{"fail parse", args{okClient, "awskms:key-id="}, nil, true},
|
||||
{"fail preload", args{&MockClient{
|
||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
||||
return nil, fmt.Errorf("an error")
|
||||
},
|
||||
}, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, nil, true},
|
||||
{"fail preload not der", args{&MockClient{
|
||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
||||
return &kms.GetPublicKeyOutput{
|
||||
KeyId: input.KeyId,
|
||||
PublicKey: []byte(publicKey),
|
||||
}, nil
|
||||
},
|
||||
}, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewSigner(tt.args.svc, tt.args.signingKey)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("NewSigner() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSigner_Public(t *testing.T) {
|
||||
okClient := getOKClient()
|
||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
service KeyManagementClient
|
||||
keyID string
|
||||
publicKey crypto.PublicKey
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want crypto.PublicKey
|
||||
}{
|
||||
{"ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, key},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Signer{
|
||||
service: tt.fields.service,
|
||||
keyID: tt.fields.keyID,
|
||||
publicKey: tt.fields.publicKey,
|
||||
}
|
||||
if got := s.Public(); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Signer.Public() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSigner_Sign(t *testing.T) {
|
||||
okClient := getOKClient()
|
||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
service KeyManagementClient
|
||||
keyID string
|
||||
publicKey crypto.PublicKey
|
||||
}
|
||||
type args struct {
|
||||
rand io.Reader
|
||||
digest []byte
|
||||
opts crypto.SignerOpts
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.SHA256}, signature, false},
|
||||
{"fail alg", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.MD5}, nil, true},
|
||||
{"fail key", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", []byte("key")}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true},
|
||||
{"fail sign", fields{&MockClient{
|
||||
signWithContext: func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) {
|
||||
return nil, fmt.Errorf("an error")
|
||||
},
|
||||
}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Signer{
|
||||
service: tt.fields.service,
|
||||
keyID: tt.fields.keyID,
|
||||
publicKey: tt.fields.publicKey,
|
||||
}
|
||||
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Signer.Sign() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getSigningAlgorithm(t *testing.T) {
|
||||
type args struct {
|
||||
key crypto.PublicKey
|
||||
opts crypto.SignerOpts
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"rsa+sha256", args{&rsa.PublicKey{}, crypto.SHA256}, "RSASSA_PKCS1_V1_5_SHA_256", false},
|
||||
{"rsa+sha384", args{&rsa.PublicKey{}, crypto.SHA384}, "RSASSA_PKCS1_V1_5_SHA_384", false},
|
||||
{"rsa+sha512", args{&rsa.PublicKey{}, crypto.SHA512}, "RSASSA_PKCS1_V1_5_SHA_512", false},
|
||||
{"pssrsa+sha256", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA256.HashFunc()}}, "RSASSA_PSS_SHA_256", false},
|
||||
{"pssrsa+sha384", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA384.HashFunc()}}, "RSASSA_PSS_SHA_384", false},
|
||||
{"pssrsa+sha512", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA512.HashFunc()}}, "RSASSA_PSS_SHA_512", false},
|
||||
{"P256", args{&ecdsa.PublicKey{}, crypto.SHA256}, "ECDSA_SHA_256", false},
|
||||
{"P384", args{&ecdsa.PublicKey{}, crypto.SHA384}, "ECDSA_SHA_384", false},
|
||||
{"P521", args{&ecdsa.PublicKey{}, crypto.SHA512}, "ECDSA_SHA_512", false},
|
||||
{"fail type", args{[]byte("key"), crypto.SHA256}, "", true},
|
||||
{"fail rsa alg", args{&rsa.PublicKey{}, crypto.MD5}, "", true},
|
||||
{"fail ecdsa alg", args{&ecdsa.PublicKey{}, crypto.MD5}, "", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := getSigningAlgorithm(tt.args.key, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("getSigningAlgorithm() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("getSigningAlgorithm() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,81 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/smallstep/certificates/kms/azurekms (interfaces: KeyVaultClient)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
keyvault "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// KeyVaultClient is a mock of KeyVaultClient interface
|
||||
type KeyVaultClient struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *KeyVaultClientMockRecorder
|
||||
}
|
||||
|
||||
// KeyVaultClientMockRecorder is the mock recorder for KeyVaultClient
|
||||
type KeyVaultClientMockRecorder struct {
|
||||
mock *KeyVaultClient
|
||||
}
|
||||
|
||||
// NewKeyVaultClient creates a new mock instance
|
||||
func NewKeyVaultClient(ctrl *gomock.Controller) *KeyVaultClient {
|
||||
mock := &KeyVaultClient{ctrl: ctrl}
|
||||
mock.recorder = &KeyVaultClientMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *KeyVaultClient) EXPECT() *KeyVaultClientMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CreateKey mocks base method
|
||||
func (m *KeyVaultClient) CreateKey(arg0 context.Context, arg1, arg2 string, arg3 keyvault.KeyCreateParameters) (keyvault.KeyBundle, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateKey", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(keyvault.KeyBundle)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CreateKey indicates an expected call of CreateKey
|
||||
func (mr *KeyVaultClientMockRecorder) CreateKey(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateKey", reflect.TypeOf((*KeyVaultClient)(nil).CreateKey), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// GetKey mocks base method
|
||||
func (m *KeyVaultClient) GetKey(arg0 context.Context, arg1, arg2, arg3 string) (keyvault.KeyBundle, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetKey", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(keyvault.KeyBundle)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetKey indicates an expected call of GetKey
|
||||
func (mr *KeyVaultClientMockRecorder) GetKey(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKey", reflect.TypeOf((*KeyVaultClient)(nil).GetKey), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// Sign mocks base method
|
||||
func (m *KeyVaultClient) Sign(arg0 context.Context, arg1, arg2, arg3 string, arg4 keyvault.KeySignParameters) (keyvault.KeyOperationResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Sign", arg0, arg1, arg2, arg3, arg4)
|
||||
ret0, _ := ret[0].(keyvault.KeyOperationResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Sign indicates an expected call of Sign
|
||||
func (mr *KeyVaultClientMockRecorder) Sign(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sign", reflect.TypeOf((*KeyVaultClient)(nil).Sign), arg0, arg1, arg2, arg3, arg4)
|
||||
}
|
@ -1,342 +0,0 @@
|
||||
package azurekms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/Azure/go-autorest/autorest/azure/auth"
|
||||
"github.com/Azure/go-autorest/autorest/date"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/kms/apiv1"
|
||||
"github.com/smallstep/certificates/kms/uri"
|
||||
)
|
||||
|
||||
func init() {
|
||||
apiv1.Register(apiv1.AzureKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
|
||||
return New(ctx, opts)
|
||||
})
|
||||
}
|
||||
|
||||
// Scheme is the scheme used for the Azure Key Vault uris.
|
||||
const Scheme = "azurekms"
|
||||
|
||||
// keyIDRegexp is the regular expression that Key Vault uses on the kid. We can
|
||||
// extract the vault, name and version of the key.
|
||||
var keyIDRegexp = regexp.MustCompile(`^https://([0-9a-zA-Z-]+)\.vault\.azure\.net/keys/([0-9a-zA-Z-]+)/([0-9a-zA-Z-]+)$`)
|
||||
|
||||
var (
|
||||
valueTrue = true
|
||||
value2048 int32 = 2048
|
||||
value3072 int32 = 3072
|
||||
value4096 int32 = 4096
|
||||
)
|
||||
|
||||
var now = func() time.Time {
|
||||
return time.Now().UTC()
|
||||
}
|
||||
|
||||
type keyType struct {
|
||||
Kty keyvault.JSONWebKeyType
|
||||
Curve keyvault.JSONWebKeyCurveName
|
||||
}
|
||||
|
||||
func (k keyType) KeyType(pl apiv1.ProtectionLevel) keyvault.JSONWebKeyType {
|
||||
switch k.Kty {
|
||||
case keyvault.EC:
|
||||
if pl == apiv1.HSM {
|
||||
return keyvault.ECHSM
|
||||
}
|
||||
return k.Kty
|
||||
case keyvault.RSA:
|
||||
if pl == apiv1.HSM {
|
||||
return keyvault.RSAHSM
|
||||
}
|
||||
return k.Kty
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]keyType{
|
||||
apiv1.UnspecifiedSignAlgorithm: {
|
||||
Kty: keyvault.EC,
|
||||
Curve: keyvault.P256,
|
||||
},
|
||||
apiv1.SHA256WithRSA: {
|
||||
Kty: keyvault.RSA,
|
||||
},
|
||||
apiv1.SHA384WithRSA: {
|
||||
Kty: keyvault.RSA,
|
||||
},
|
||||
apiv1.SHA512WithRSA: {
|
||||
Kty: keyvault.RSA,
|
||||
},
|
||||
apiv1.SHA256WithRSAPSS: {
|
||||
Kty: keyvault.RSA,
|
||||
},
|
||||
apiv1.SHA384WithRSAPSS: {
|
||||
Kty: keyvault.RSA,
|
||||
},
|
||||
apiv1.SHA512WithRSAPSS: {
|
||||
Kty: keyvault.RSA,
|
||||
},
|
||||
apiv1.ECDSAWithSHA256: {
|
||||
Kty: keyvault.EC,
|
||||
Curve: keyvault.P256,
|
||||
},
|
||||
apiv1.ECDSAWithSHA384: {
|
||||
Kty: keyvault.EC,
|
||||
Curve: keyvault.P384,
|
||||
},
|
||||
apiv1.ECDSAWithSHA512: {
|
||||
Kty: keyvault.EC,
|
||||
Curve: keyvault.P521,
|
||||
},
|
||||
}
|
||||
|
||||
// vaultResource is the value the client will use as audience.
|
||||
const vaultResource = "https://vault.azure.net"
|
||||
|
||||
// KeyVaultClient is the interface implemented by keyvault.BaseClient. It will
|
||||
// be used for testing purposes.
|
||||
type KeyVaultClient interface {
|
||||
GetKey(ctx context.Context, vaultBaseURL string, keyName string, keyVersion string) (keyvault.KeyBundle, error)
|
||||
CreateKey(ctx context.Context, vaultBaseURL string, keyName string, parameters keyvault.KeyCreateParameters) (keyvault.KeyBundle, error)
|
||||
Sign(ctx context.Context, vaultBaseURL string, keyName string, keyVersion string, parameters keyvault.KeySignParameters) (keyvault.KeyOperationResult, error)
|
||||
}
|
||||
|
||||
// KeyVault implements a KMS using Azure Key Vault.
|
||||
//
|
||||
// The URI format used in Azure Key Vault is the following:
|
||||
//
|
||||
// - azurekms:name=key-name;vault=vault-name
|
||||
// - azurekms:name=key-name;vault=vault-name?version=key-version
|
||||
// - azurekms:name=key-name;vault=vault-name?hsm=true
|
||||
//
|
||||
// The scheme is "azurekms"; "name" is the key name; "vault" is the key vault
|
||||
// name where the key is located; "version" is an optional parameter that
|
||||
// defines the version of they key, if version is not given, the latest one will
|
||||
// be used; "hsm" defines if an HSM want to be used for this key, this is
|
||||
// specially useful when this is used from `step`.
|
||||
//
|
||||
// TODO(mariano): The implementation is using /services/keyvault/v7.1/keyvault
|
||||
// package, at some point Azure might create a keyvault client with all the
|
||||
// functionality in /sdk/keyvault, we should migrate to that once available.
|
||||
type KeyVault struct {
|
||||
baseClient KeyVaultClient
|
||||
defaults DefaultOptions
|
||||
}
|
||||
|
||||
// DefaultOptions are custom options that can be passed as defaults using the
|
||||
// URI in apiv1.Options.
|
||||
type DefaultOptions struct {
|
||||
Vault string
|
||||
ProtectionLevel apiv1.ProtectionLevel
|
||||
}
|
||||
|
||||
var createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
||||
baseClient := keyvault.New()
|
||||
|
||||
// With an URI, try to log in only using client credentials in the URI.
|
||||
// Client credentials requires:
|
||||
// - client-id
|
||||
// - client-secret
|
||||
// - tenant-id
|
||||
// And optionally the aad-endpoint to support custom clouds:
|
||||
// - aad-endpoint (defaults to https://login.microsoftonline.com/)
|
||||
if opts.URI != "" {
|
||||
u, err := uri.ParseWithScheme(Scheme, opts.URI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Required options
|
||||
clientID := u.Get("client-id")
|
||||
clientSecret := u.Get("client-secret")
|
||||
tenantID := u.Get("tenant-id")
|
||||
// optional
|
||||
aadEndpoint := u.Get("aad-endpoint")
|
||||
|
||||
if clientID != "" && clientSecret != "" && tenantID != "" {
|
||||
s := auth.EnvironmentSettings{
|
||||
Values: map[string]string{
|
||||
auth.ClientID: clientID,
|
||||
auth.ClientSecret: clientSecret,
|
||||
auth.TenantID: tenantID,
|
||||
auth.Resource: vaultResource,
|
||||
},
|
||||
Environment: azure.PublicCloud,
|
||||
}
|
||||
if aadEndpoint != "" {
|
||||
s.Environment.ActiveDirectoryEndpoint = aadEndpoint
|
||||
}
|
||||
baseClient.Authorizer, err = s.GetAuthorizer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return baseClient, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Attempt to authorize with the following methods:
|
||||
// 1. Environment variables.
|
||||
// - Client credentials
|
||||
// - Client certificate
|
||||
// - Username and password
|
||||
// - MSI
|
||||
// 2. Using Azure CLI 2.0 on local development.
|
||||
authorizer, err := auth.NewAuthorizerFromEnvironmentWithResource(vaultResource)
|
||||
if err != nil {
|
||||
authorizer, err = auth.NewAuthorizerFromCLIWithResource(vaultResource)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error getting authorizer for key vault")
|
||||
}
|
||||
}
|
||||
baseClient.Authorizer = authorizer
|
||||
return &baseClient, nil
|
||||
}
|
||||
|
||||
// New initializes a new KMS implemented using Azure Key Vault.
|
||||
func New(ctx context.Context, opts apiv1.Options) (*KeyVault, error) {
|
||||
baseClient, err := createClient(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// step and step-ca do not need and URI, but having a default vault and
|
||||
// protection level is useful if this package is used as an api
|
||||
var defaults DefaultOptions
|
||||
if opts.URI != "" {
|
||||
u, err := uri.ParseWithScheme(Scheme, opts.URI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defaults.Vault = u.Get("vault")
|
||||
if u.GetBool("hsm") {
|
||||
defaults.ProtectionLevel = apiv1.HSM
|
||||
}
|
||||
}
|
||||
|
||||
return &KeyVault{
|
||||
baseClient: baseClient,
|
||||
defaults: defaults,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetPublicKey loads a public key from Azure Key Vault by its resource name.
|
||||
func (k *KeyVault) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
|
||||
if req.Name == "" {
|
||||
return nil, errors.New("getPublicKeyRequest 'name' cannot be empty")
|
||||
}
|
||||
|
||||
vault, name, version, _, err := parseKeyName(req.Name, k.defaults)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
resp, err := k.baseClient.GetKey(ctx, vaultBaseURL(vault), name, version)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "keyVault GetKey failed")
|
||||
}
|
||||
|
||||
return convertKey(resp.Key)
|
||||
}
|
||||
|
||||
// CreateKey creates a asymmetric key in Azure Key Vault.
|
||||
func (k *KeyVault) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
|
||||
if req.Name == "" {
|
||||
return nil, errors.New("createKeyRequest 'name' cannot be empty")
|
||||
}
|
||||
|
||||
vault, name, _, hsm, err := parseKeyName(req.Name, k.defaults)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Override protection level to HSM only if it's not specified, and is given
|
||||
// in the uri.
|
||||
protectionLevel := req.ProtectionLevel
|
||||
if protectionLevel == apiv1.UnspecifiedProtectionLevel && hsm {
|
||||
protectionLevel = apiv1.HSM
|
||||
}
|
||||
|
||||
kt, ok := signatureAlgorithmMapping[req.SignatureAlgorithm]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("keyVault does not support signature algorithm '%s'", req.SignatureAlgorithm)
|
||||
}
|
||||
var keySize *int32
|
||||
if kt.Kty == keyvault.RSA || kt.Kty == keyvault.RSAHSM {
|
||||
switch req.Bits {
|
||||
case 2048:
|
||||
keySize = &value2048
|
||||
case 0, 3072:
|
||||
keySize = &value3072
|
||||
case 4096:
|
||||
keySize = &value4096
|
||||
default:
|
||||
return nil, errors.Errorf("keyVault does not support key size %d", req.Bits)
|
||||
}
|
||||
}
|
||||
|
||||
created := date.UnixTime(now())
|
||||
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
resp, err := k.baseClient.CreateKey(ctx, vaultBaseURL(vault), name, keyvault.KeyCreateParameters{
|
||||
Kty: kt.KeyType(protectionLevel),
|
||||
KeySize: keySize,
|
||||
Curve: kt.Curve,
|
||||
KeyOps: &[]keyvault.JSONWebKeyOperation{
|
||||
keyvault.Sign, keyvault.Verify,
|
||||
},
|
||||
KeyAttributes: &keyvault.KeyAttributes{
|
||||
Enabled: &valueTrue,
|
||||
Created: &created,
|
||||
NotBefore: &created,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "keyVault CreateKey failed")
|
||||
}
|
||||
|
||||
publicKey, err := convertKey(resp.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyURI := getKeyName(vault, name, resp)
|
||||
return &apiv1.CreateKeyResponse{
|
||||
Name: keyURI,
|
||||
PublicKey: publicKey,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: keyURI,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateSigner returns a crypto.Signer from a previously created asymmetric key.
|
||||
func (k *KeyVault) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
|
||||
if req.SigningKey == "" {
|
||||
return nil, errors.New("createSignerRequest 'signingKey' cannot be empty")
|
||||
}
|
||||
return NewSigner(k.baseClient, req.SigningKey, k.defaults)
|
||||
}
|
||||
|
||||
// Close closes the client connection to the Azure Key Vault. This is a noop.
|
||||
func (k *KeyVault) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateName validates that the given string is a valid URI.
|
||||
func (k *KeyVault) ValidateName(s string) error {
|
||||
_, _, _, _, err := parseKeyName(s, k.defaults)
|
||||
return err
|
||||
}
|
@ -1,653 +0,0 @@
|
||||
//go:generate mockgen -package mock -mock_names=KeyVaultClient=KeyVaultClient -destination internal/mock/key_vault_client.go github.com/smallstep/certificates/kms/azurekms KeyVaultClient
|
||||
package azurekms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||
"github.com/Azure/go-autorest/autorest/date"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/smallstep/certificates/kms/apiv1"
|
||||
"github.com/smallstep/certificates/kms/azurekms/internal/mock"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
var errTest = fmt.Errorf("test error")
|
||||
|
||||
func mockNow(t *testing.T) time.Time {
|
||||
old := now
|
||||
t0 := time.Unix(1234567890, 123).UTC()
|
||||
now = func() time.Time {
|
||||
return t0
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
now = old
|
||||
})
|
||||
return t0
|
||||
}
|
||||
|
||||
func mockClient(t *testing.T) *mock.KeyVaultClient {
|
||||
t.Helper()
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(func() {
|
||||
ctrl.Finish()
|
||||
})
|
||||
return mock.NewKeyVaultClient(ctrl)
|
||||
}
|
||||
|
||||
func createJWK(t *testing.T, pub crypto.PublicKey) *keyvault.JSONWebKey {
|
||||
t.Helper()
|
||||
b, err := json.Marshal(&jose.JSONWebKey{
|
||||
Key: pub,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
key := new(keyvault.JSONWebKey)
|
||||
if err := json.Unmarshal(b, key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func Test_now(t *testing.T) {
|
||||
t0 := now()
|
||||
if loc := t0.Location(); loc != time.UTC {
|
||||
t.Errorf("now() Location = %v, want %v", loc, time.UTC)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
client := mockClient(t)
|
||||
old := createClient
|
||||
t.Cleanup(func() {
|
||||
createClient = old
|
||||
})
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
opts apiv1.Options
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func()
|
||||
args args
|
||||
want *KeyVault
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", func() {
|
||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
||||
return client, nil
|
||||
}
|
||||
}, args{context.Background(), apiv1.Options{}}, &KeyVault{
|
||||
baseClient: client,
|
||||
}, false},
|
||||
{"ok with vault", func() {
|
||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
||||
return client, nil
|
||||
}
|
||||
}, args{context.Background(), apiv1.Options{
|
||||
URI: "azurekms:vault=my-vault",
|
||||
}}, &KeyVault{
|
||||
baseClient: client,
|
||||
defaults: DefaultOptions{
|
||||
Vault: "my-vault",
|
||||
ProtectionLevel: apiv1.UnspecifiedProtectionLevel,
|
||||
},
|
||||
}, false},
|
||||
{"ok with vault + hsm", func() {
|
||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
||||
return client, nil
|
||||
}
|
||||
}, args{context.Background(), apiv1.Options{
|
||||
URI: "azurekms:vault=my-vault;hsm=true",
|
||||
}}, &KeyVault{
|
||||
baseClient: client,
|
||||
defaults: DefaultOptions{
|
||||
Vault: "my-vault",
|
||||
ProtectionLevel: apiv1.HSM,
|
||||
},
|
||||
}, false},
|
||||
{"fail", func() {
|
||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
||||
return nil, errTest
|
||||
}
|
||||
}, args{context.Background(), apiv1.Options{}}, nil, true},
|
||||
{"fail uri", func() {
|
||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
||||
return client, nil
|
||||
}
|
||||
}, args{context.Background(), apiv1.Options{
|
||||
URI: "kms:vault=my-vault;hsm=true",
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setup()
|
||||
got, err := New(tt.args.ctx, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("New() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyVault_createClient(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
opts apiv1.Options
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
skip bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{context.Background(), apiv1.Options{}}, true, false},
|
||||
{"ok with uri", args{context.Background(), apiv1.Options{
|
||||
URI: "azurekms:client-id=id;client-secret=secret;tenant-id=id",
|
||||
}}, false, false},
|
||||
{"ok with uri+aad", args{context.Background(), apiv1.Options{
|
||||
URI: "azurekms:client-id=id;client-secret=secret;tenant-id=id;aad-enpoint=https%3A%2F%2Flogin.microsoftonline.us%2F",
|
||||
}}, false, false},
|
||||
{"ok with uri no config", args{context.Background(), apiv1.Options{
|
||||
URI: "azurekms:",
|
||||
}}, true, false},
|
||||
{"fail uri", args{context.Background(), apiv1.Options{
|
||||
URI: "kms:client-id=id;client-secret=secret;tenant-id=id",
|
||||
}}, false, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.skip {
|
||||
t.SkipNow()
|
||||
}
|
||||
_, err := createClient(tt.args.ctx, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyVault_GetPublicKey(t *testing.T) {
|
||||
key, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pub := key.Public()
|
||||
jwk := createJWK(t, pub)
|
||||
|
||||
client := mockClient(t)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{
|
||||
Key: jwk,
|
||||
}, nil)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{
|
||||
Key: jwk,
|
||||
}, nil)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest)
|
||||
|
||||
type fields struct {
|
||||
baseClient KeyVaultClient
|
||||
}
|
||||
type args struct {
|
||||
req *apiv1.GetPublicKeyRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want crypto.PublicKey
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
}}, pub, false},
|
||||
{"ok with version", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key?version=my-version",
|
||||
}}, pub, false},
|
||||
{"fail GetKey", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=not-found?version=my-version",
|
||||
}}, nil, true},
|
||||
{"fail empty", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "",
|
||||
}}, nil, true},
|
||||
{"fail vault", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "azurekms:vault=;name=not-found?version=my-version",
|
||||
}}, nil, true},
|
||||
{"fail id", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "azurekms:vault=;name=?version=my-version",
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KeyVault{
|
||||
baseClient: tt.fields.baseClient,
|
||||
}
|
||||
got, err := k.GetPublicKey(tt.args.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KeyVault.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("KeyVault.GetPublicKey() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyVault_CreateKey(t *testing.T) {
|
||||
ecKey, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rsaKey, err := keyutil.GenerateSigner("RSA", "", 2048)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ecPub := ecKey.Public()
|
||||
rsaPub := rsaKey.Public()
|
||||
ecJWK := createJWK(t, ecPub)
|
||||
rsaJWK := createJWK(t, rsaPub)
|
||||
|
||||
t0 := date.UnixTime(mockNow(t))
|
||||
client := mockClient(t)
|
||||
|
||||
expects := []struct {
|
||||
Name string
|
||||
Kty keyvault.JSONWebKeyType
|
||||
KeySize *int32
|
||||
Curve keyvault.JSONWebKeyCurveName
|
||||
Key *keyvault.JSONWebKey
|
||||
}{
|
||||
{"P-256", keyvault.EC, nil, keyvault.P256, ecJWK},
|
||||
{"P-256 HSM", keyvault.ECHSM, nil, keyvault.P256, ecJWK},
|
||||
{"P-256 HSM (uri)", keyvault.ECHSM, nil, keyvault.P256, ecJWK},
|
||||
{"P-256 Default", keyvault.EC, nil, keyvault.P256, ecJWK},
|
||||
{"P-384", keyvault.EC, nil, keyvault.P384, ecJWK},
|
||||
{"P-521", keyvault.EC, nil, keyvault.P521, ecJWK},
|
||||
{"RSA 0", keyvault.RSA, &value3072, "", rsaJWK},
|
||||
{"RSA 0 HSM", keyvault.RSAHSM, &value3072, "", rsaJWK},
|
||||
{"RSA 0 HSM (uri)", keyvault.RSAHSM, &value3072, "", rsaJWK},
|
||||
{"RSA 2048", keyvault.RSA, &value2048, "", rsaJWK},
|
||||
{"RSA 3072", keyvault.RSA, &value3072, "", rsaJWK},
|
||||
{"RSA 4096", keyvault.RSA, &value4096, "", rsaJWK},
|
||||
}
|
||||
|
||||
for _, e := range expects {
|
||||
client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", keyvault.KeyCreateParameters{
|
||||
Kty: e.Kty,
|
||||
KeySize: e.KeySize,
|
||||
Curve: e.Curve,
|
||||
KeyOps: &[]keyvault.JSONWebKeyOperation{
|
||||
keyvault.Sign, keyvault.Verify,
|
||||
},
|
||||
KeyAttributes: &keyvault.KeyAttributes{
|
||||
Enabled: &valueTrue,
|
||||
Created: &t0,
|
||||
NotBefore: &t0,
|
||||
},
|
||||
}).Return(keyvault.KeyBundle{
|
||||
Key: e.Key,
|
||||
}, nil)
|
||||
}
|
||||
client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", gomock.Any()).Return(keyvault.KeyBundle{}, errTest)
|
||||
client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", gomock.Any()).Return(keyvault.KeyBundle{
|
||||
Key: nil,
|
||||
}, nil)
|
||||
|
||||
type fields struct {
|
||||
baseClient KeyVaultClient
|
||||
}
|
||||
type args struct {
|
||||
req *apiv1.CreateKeyRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *apiv1.CreateKeyResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok P-256", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
ProtectionLevel: apiv1.Software,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: ecPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok P-256 HSM", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
ProtectionLevel: apiv1.HSM,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: ecPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok P-256 HSM (uri)", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key?hsm=true",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: ecPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok P-256 Default", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: ecPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok P-384", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA384,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: ecPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok P-521", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA512,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: ecPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok RSA 0", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
Bits: 0,
|
||||
SignatureAlgorithm: apiv1.SHA256WithRSA,
|
||||
ProtectionLevel: apiv1.Software,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: rsaPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok RSA 0 HSM", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
Bits: 0,
|
||||
SignatureAlgorithm: apiv1.SHA256WithRSAPSS,
|
||||
ProtectionLevel: apiv1.HSM,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: rsaPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok RSA 0 HSM (uri)", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key;hsm=true",
|
||||
Bits: 0,
|
||||
SignatureAlgorithm: apiv1.SHA256WithRSAPSS,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: rsaPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok RSA 2048", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
Bits: 2048,
|
||||
SignatureAlgorithm: apiv1.SHA384WithRSA,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: rsaPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok RSA 3072", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
Bits: 3072,
|
||||
SignatureAlgorithm: apiv1.SHA512WithRSA,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: rsaPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok RSA 4096", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
Bits: 4096,
|
||||
SignatureAlgorithm: apiv1.SHA512WithRSAPSS,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: rsaPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"fail createKey", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=not-found",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
}}, nil, true},
|
||||
{"fail convertKey", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=not-found",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
}}, nil, true},
|
||||
{"fail name", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "",
|
||||
}}, nil, true},
|
||||
{"fail vault", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=;name=not-found?version=my-version",
|
||||
}}, nil, true},
|
||||
{"fail id", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=?version=my-version",
|
||||
}}, nil, true},
|
||||
{"fail SignatureAlgorithm", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=not-found",
|
||||
SignatureAlgorithm: apiv1.PureEd25519,
|
||||
}}, nil, true},
|
||||
{"fail bit size", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=not-found",
|
||||
SignatureAlgorithm: apiv1.SHA384WithRSAPSS,
|
||||
Bits: 1024,
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KeyVault{
|
||||
baseClient: tt.fields.baseClient,
|
||||
}
|
||||
got, err := k.CreateKey(tt.args.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KeyVault.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("KeyVault.CreateKey() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyVault_CreateSigner(t *testing.T) {
|
||||
key, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pub := key.Public()
|
||||
jwk := createJWK(t, pub)
|
||||
|
||||
client := mockClient(t)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{
|
||||
Key: jwk,
|
||||
}, nil)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{
|
||||
Key: jwk,
|
||||
}, nil)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest)
|
||||
|
||||
type fields struct {
|
||||
baseClient KeyVaultClient
|
||||
}
|
||||
type args struct {
|
||||
req *apiv1.CreateSignerRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want crypto.Signer
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{client}, args{&apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:vault=my-vault;name=my-key",
|
||||
}}, &Signer{
|
||||
client: client,
|
||||
vaultBaseURL: "https://my-vault.vault.azure.net/",
|
||||
name: "my-key",
|
||||
version: "",
|
||||
publicKey: pub,
|
||||
}, false},
|
||||
{"ok with version", fields{client}, args{&apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:vault=my-vault;name=my-key;version=my-version",
|
||||
}}, &Signer{
|
||||
client: client,
|
||||
vaultBaseURL: "https://my-vault.vault.azure.net/",
|
||||
name: "my-key",
|
||||
version: "my-version",
|
||||
publicKey: pub,
|
||||
}, false},
|
||||
{"fail GetKey", fields{client}, args{&apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:vault=my-vault;name=not-found;version=my-version",
|
||||
}}, nil, true},
|
||||
{"fail SigningKey", fields{client}, args{&apiv1.CreateSignerRequest{
|
||||
SigningKey: "",
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KeyVault{
|
||||
baseClient: tt.fields.baseClient,
|
||||
}
|
||||
got, err := k.CreateSigner(tt.args.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KeyVault.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("KeyVault.CreateSigner() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyVault_Close(t *testing.T) {
|
||||
client := mockClient(t)
|
||||
type fields struct {
|
||||
baseClient KeyVaultClient
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{client}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KeyVault{
|
||||
baseClient: tt.fields.baseClient,
|
||||
}
|
||||
if err := k.Close(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("KeyVault.Close() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_keyType_KeyType(t *testing.T) {
|
||||
type fields struct {
|
||||
Kty keyvault.JSONWebKeyType
|
||||
Curve keyvault.JSONWebKeyCurveName
|
||||
}
|
||||
type args struct {
|
||||
pl apiv1.ProtectionLevel
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want keyvault.JSONWebKeyType
|
||||
}{
|
||||
{"ec", fields{keyvault.EC, keyvault.P256}, args{apiv1.UnspecifiedProtectionLevel}, keyvault.EC},
|
||||
{"ec software", fields{keyvault.EC, keyvault.P384}, args{apiv1.Software}, keyvault.EC},
|
||||
{"ec hsm", fields{keyvault.EC, keyvault.P521}, args{apiv1.HSM}, keyvault.ECHSM},
|
||||
{"rsa", fields{keyvault.RSA, keyvault.P256}, args{apiv1.UnspecifiedProtectionLevel}, keyvault.RSA},
|
||||
{"rsa software", fields{keyvault.RSA, ""}, args{apiv1.Software}, keyvault.RSA},
|
||||
{"rsa hsm", fields{keyvault.RSA, ""}, args{apiv1.HSM}, keyvault.RSAHSM},
|
||||
{"empty", fields{"FOO", ""}, args{apiv1.UnspecifiedProtectionLevel}, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := keyType{
|
||||
Kty: tt.fields.Kty,
|
||||
Curve: tt.fields.Curve,
|
||||
}
|
||||
if got := k.KeyType(tt.args.pl); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("keyType.KeyType() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyVault_ValidateName(t *testing.T) {
|
||||
type args struct {
|
||||
s string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{"azurekms:name=my-key;vault=my-vault"}, false},
|
||||
{"ok hsm", args{"azurekms:name=my-key;vault=my-vault?hsm=true"}, false},
|
||||
{"fail scheme", args{"azure:name=my-key;vault=my-vault"}, true},
|
||||
{"fail parse uri", args{"azurekms:name=%ZZ;vault=my-vault"}, true},
|
||||
{"fail no name", args{"azurekms:vault=my-vault"}, true},
|
||||
{"fail no vault", args{"azurekms:name=my-key"}, true},
|
||||
{"fail empty", args{""}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KeyVault{}
|
||||
if err := k.ValidateName(tt.args.s); (err != nil) != tt.wantErr {
|
||||
t.Errorf("KeyVault.ValidateName() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,182 +0,0 @@
|
||||
package azurekms
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/cryptobyte"
|
||||
"golang.org/x/crypto/cryptobyte/asn1"
|
||||
)
|
||||
|
||||
// Signer implements a crypto.Signer using the AWS KMS.
|
||||
type Signer struct {
|
||||
client KeyVaultClient
|
||||
vaultBaseURL string
|
||||
name string
|
||||
version string
|
||||
publicKey crypto.PublicKey
|
||||
}
|
||||
|
||||
// NewSigner creates a new signer using a key in the AWS KMS.
|
||||
func NewSigner(client KeyVaultClient, signingKey string, defaults DefaultOptions) (crypto.Signer, error) {
|
||||
vault, name, version, _, err := parseKeyName(signingKey, defaults)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Make sure that the key exists.
|
||||
signer := &Signer{
|
||||
client: client,
|
||||
vaultBaseURL: vaultBaseURL(vault),
|
||||
name: name,
|
||||
version: version,
|
||||
}
|
||||
if err := signer.preloadKey(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
func (s *Signer) preloadKey() error {
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
resp, err := s.client.GetKey(ctx, s.vaultBaseURL, s.name, s.version)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "keyVault GetKey failed")
|
||||
}
|
||||
|
||||
s.publicKey, err = convertKey(resp.Key)
|
||||
return err
|
||||
}
|
||||
|
||||
// Public returns the public key of this signer or an error.
|
||||
func (s *Signer) Public() crypto.PublicKey {
|
||||
return s.publicKey
|
||||
}
|
||||
|
||||
// Sign signs digest with the private key stored in the AWS KMS.
|
||||
func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
||||
alg, err := getSigningAlgorithm(s.Public(), opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b64 := base64.RawURLEncoding.EncodeToString(digest)
|
||||
|
||||
// Sign with retry if the key is not ready
|
||||
resp, err := s.signWithRetry(alg, b64, 3)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "keyVault Sign failed")
|
||||
}
|
||||
|
||||
sig, err := base64.RawURLEncoding.DecodeString(*resp.Result)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error decoding keyVault Sign result")
|
||||
}
|
||||
|
||||
var octetSize int
|
||||
switch alg {
|
||||
case keyvault.ES256:
|
||||
octetSize = 32 // 256-bit, concat(R,S) = 64 bytes
|
||||
case keyvault.ES384:
|
||||
octetSize = 48 // 384-bit, concat(R,S) = 96 bytes
|
||||
case keyvault.ES512:
|
||||
octetSize = 66 // 528-bit, concat(R,S) = 132 bytes
|
||||
default:
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
// Convert to asn1
|
||||
if len(sig) != octetSize*2 {
|
||||
return nil, errors.Errorf("keyVault Sign failed: unexpected signature length")
|
||||
}
|
||||
var b cryptobyte.Builder
|
||||
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
||||
b.AddASN1BigInt(new(big.Int).SetBytes(sig[:octetSize])) // R
|
||||
b.AddASN1BigInt(new(big.Int).SetBytes(sig[octetSize:])) // S
|
||||
})
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func (s *Signer) signWithRetry(alg keyvault.JSONWebKeySignatureAlgorithm, b64 string, retryAttempts int) (keyvault.KeyOperationResult, error) {
|
||||
retry:
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
resp, err := s.client.Sign(ctx, s.vaultBaseURL, s.name, s.version, keyvault.KeySignParameters{
|
||||
Algorithm: alg,
|
||||
Value: &b64,
|
||||
})
|
||||
if err != nil && retryAttempts > 0 {
|
||||
var requestError *azure.RequestError
|
||||
if errors.As(err, &requestError) {
|
||||
if se := requestError.ServiceError; se != nil && se.InnerError != nil {
|
||||
code, ok := se.InnerError["code"].(string)
|
||||
if ok && code == "KeyNotYetValid" {
|
||||
time.Sleep(time.Second / time.Duration(retryAttempts))
|
||||
retryAttempts--
|
||||
goto retry
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (keyvault.JSONWebKeySignatureAlgorithm, error) {
|
||||
switch key.(type) {
|
||||
case *rsa.PublicKey:
|
||||
hashFunc := opts.HashFunc()
|
||||
pss, isPSS := opts.(*rsa.PSSOptions)
|
||||
// Random salt lengths are not supported
|
||||
if isPSS &&
|
||||
pss.SaltLength != rsa.PSSSaltLengthAuto &&
|
||||
pss.SaltLength != rsa.PSSSaltLengthEqualsHash &&
|
||||
pss.SaltLength != hashFunc.Size() {
|
||||
return "", errors.Errorf("unsupported RSA-PSS salt length %d", pss.SaltLength)
|
||||
}
|
||||
|
||||
switch h := hashFunc; h {
|
||||
case crypto.SHA256:
|
||||
if isPSS {
|
||||
return keyvault.PS256, nil
|
||||
}
|
||||
return keyvault.RS256, nil
|
||||
case crypto.SHA384:
|
||||
if isPSS {
|
||||
return keyvault.PS384, nil
|
||||
}
|
||||
return keyvault.RS384, nil
|
||||
case crypto.SHA512:
|
||||
if isPSS {
|
||||
return keyvault.PS512, nil
|
||||
}
|
||||
return keyvault.RS512, nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported hash function %v", h)
|
||||
}
|
||||
case *ecdsa.PublicKey:
|
||||
switch h := opts.HashFunc(); h {
|
||||
case crypto.SHA256:
|
||||
return keyvault.ES256, nil
|
||||
case crypto.SHA384:
|
||||
return keyvault.ES384, nil
|
||||
case crypto.SHA512:
|
||||
return keyvault.ES512, nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported hash function %v", h)
|
||||
}
|
||||
default:
|
||||
return "", errors.Errorf("unsupported key type %T", key)
|
||||
}
|
||||
}
|
@ -1,493 +0,0 @@
|
||||
package azurekms
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||
"github.com/Azure/go-autorest/autorest"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/smallstep/certificates/kms/apiv1"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"golang.org/x/crypto/cryptobyte"
|
||||
"golang.org/x/crypto/cryptobyte/asn1"
|
||||
)
|
||||
|
||||
func TestNewSigner(t *testing.T) {
|
||||
key, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pub := key.Public()
|
||||
jwk := createJWK(t, pub)
|
||||
|
||||
client := mockClient(t)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{
|
||||
Key: jwk,
|
||||
}, nil)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{
|
||||
Key: jwk,
|
||||
}, nil)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{
|
||||
Key: jwk,
|
||||
}, nil)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest)
|
||||
|
||||
var noOptions DefaultOptions
|
||||
type args struct {
|
||||
client KeyVaultClient
|
||||
signingKey string
|
||||
defaults DefaultOptions
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want crypto.Signer
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{client, "azurekms:vault=my-vault;name=my-key", noOptions}, &Signer{
|
||||
client: client,
|
||||
vaultBaseURL: "https://my-vault.vault.azure.net/",
|
||||
name: "my-key",
|
||||
version: "",
|
||||
publicKey: pub,
|
||||
}, false},
|
||||
{"ok with version", args{client, "azurekms:name=my-key;vault=my-vault?version=my-version", noOptions}, &Signer{
|
||||
client: client,
|
||||
vaultBaseURL: "https://my-vault.vault.azure.net/",
|
||||
name: "my-key",
|
||||
version: "my-version",
|
||||
publicKey: pub,
|
||||
}, false},
|
||||
{"ok with options", args{client, "azurekms:name=my-key?version=my-version", DefaultOptions{Vault: "my-vault", ProtectionLevel: apiv1.HSM}}, &Signer{
|
||||
client: client,
|
||||
vaultBaseURL: "https://my-vault.vault.azure.net/",
|
||||
name: "my-key",
|
||||
version: "my-version",
|
||||
publicKey: pub,
|
||||
}, false},
|
||||
{"fail GetKey", args{client, "azurekms:name=not-found;vault=my-vault?version=my-version", noOptions}, nil, true},
|
||||
{"fail vault", args{client, "azurekms:name=not-found;vault=", noOptions}, nil, true},
|
||||
{"fail id", args{client, "azurekms:name=;vault=my-vault?version=my-version", noOptions}, nil, true},
|
||||
{"fail scheme", args{client, "kms:name=not-found;vault=my-vault?version=my-version", noOptions}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewSigner(tt.args.client, tt.args.signingKey, tt.args.defaults)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("NewSigner() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSigner_Public(t *testing.T) {
|
||||
key, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pub := key.Public()
|
||||
|
||||
type fields struct {
|
||||
publicKey crypto.PublicKey
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want crypto.PublicKey
|
||||
}{
|
||||
{"ok", fields{pub}, pub},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Signer{
|
||||
publicKey: tt.fields.publicKey,
|
||||
}
|
||||
if got := s.Public(); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Signer.Public() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSigner_Sign(t *testing.T) {
|
||||
sign := func(kty, crv string, bits int, opts crypto.SignerOpts) (crypto.PublicKey, []byte, string, []byte) {
|
||||
key, err := keyutil.GenerateSigner(kty, crv, bits)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
h := opts.HashFunc().New()
|
||||
h.Write([]byte("random-data"))
|
||||
sum := h.Sum(nil)
|
||||
|
||||
var sig, resultSig []byte
|
||||
if priv, ok := key.(*ecdsa.PrivateKey); ok {
|
||||
r, s, err := ecdsa.Sign(rand.Reader, priv, sum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
curveBits := priv.Params().BitSize
|
||||
keyBytes := curveBits / 8
|
||||
if curveBits%8 > 0 {
|
||||
keyBytes++
|
||||
}
|
||||
rBytes := r.Bytes()
|
||||
rBytesPadded := make([]byte, keyBytes)
|
||||
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
|
||||
|
||||
sBytes := s.Bytes()
|
||||
sBytesPadded := make([]byte, keyBytes)
|
||||
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
|
||||
// nolint:gocritic
|
||||
resultSig = append(rBytesPadded, sBytesPadded...)
|
||||
|
||||
var b cryptobyte.Builder
|
||||
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
||||
b.AddASN1BigInt(r)
|
||||
b.AddASN1BigInt(s)
|
||||
})
|
||||
sig, err = b.Bytes()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
sig, err = key.Sign(rand.Reader, sum, opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resultSig = sig
|
||||
}
|
||||
|
||||
return key.Public(), h.Sum(nil), base64.RawURLEncoding.EncodeToString(resultSig), sig
|
||||
}
|
||||
|
||||
p256, p256Digest, p256ResultSig, p256Sig := sign("EC", "P-256", 0, crypto.SHA256)
|
||||
p384, p384Digest, p386ResultSig, p384Sig := sign("EC", "P-384", 0, crypto.SHA384)
|
||||
p521, p521Digest, p521ResultSig, p521Sig := sign("EC", "P-521", 0, crypto.SHA512)
|
||||
rsaSHA256, rsaSHA256Digest, rsaSHA256ResultSig, rsaSHA256Sig := sign("RSA", "", 2048, crypto.SHA256)
|
||||
rsaSHA384, rsaSHA384Digest, rsaSHA384ResultSig, rsaSHA384Sig := sign("RSA", "", 2048, crypto.SHA384)
|
||||
rsaSHA512, rsaSHA512Digest, rsaSHA512ResultSig, rsaSHA512Sig := sign("RSA", "", 2048, crypto.SHA512)
|
||||
rsaPSSSHA256, rsaPSSSHA256Digest, rsaPSSSHA256ResultSig, rsaPSSSHA256Sig := sign("RSA", "", 2048, &rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA256,
|
||||
})
|
||||
rsaPSSSHA384, rsaPSSSHA384Digest, rsaPSSSHA384ResultSig, rsaPSSSHA384Sig := sign("RSA", "", 2048, &rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA512,
|
||||
})
|
||||
rsaPSSSHA512, rsaPSSSHA512Digest, rsaPSSSHA512ResultSig, rsaPSSSHA512Sig := sign("RSA", "", 2048, &rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA512,
|
||||
})
|
||||
|
||||
ed25519Key, err := keyutil.GenerateSigner("OKP", "Ed25519", 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := mockClient(t)
|
||||
expects := []struct {
|
||||
name string
|
||||
keyVersion string
|
||||
alg keyvault.JSONWebKeySignatureAlgorithm
|
||||
digest []byte
|
||||
result keyvault.KeyOperationResult
|
||||
err error
|
||||
}{
|
||||
{"P-256", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{
|
||||
Result: &p256ResultSig,
|
||||
}, nil},
|
||||
{"P-384", "my-version", keyvault.ES384, p384Digest, keyvault.KeyOperationResult{
|
||||
Result: &p386ResultSig,
|
||||
}, nil},
|
||||
{"P-521", "my-version", keyvault.ES512, p521Digest, keyvault.KeyOperationResult{
|
||||
Result: &p521ResultSig,
|
||||
}, nil},
|
||||
{"RSA SHA256", "", keyvault.RS256, rsaSHA256Digest, keyvault.KeyOperationResult{
|
||||
Result: &rsaSHA256ResultSig,
|
||||
}, nil},
|
||||
{"RSA SHA384", "", keyvault.RS384, rsaSHA384Digest, keyvault.KeyOperationResult{
|
||||
Result: &rsaSHA384ResultSig,
|
||||
}, nil},
|
||||
{"RSA SHA512", "", keyvault.RS512, rsaSHA512Digest, keyvault.KeyOperationResult{
|
||||
Result: &rsaSHA512ResultSig,
|
||||
}, nil},
|
||||
{"RSA-PSS SHA256", "", keyvault.PS256, rsaPSSSHA256Digest, keyvault.KeyOperationResult{
|
||||
Result: &rsaPSSSHA256ResultSig,
|
||||
}, nil},
|
||||
{"RSA-PSS SHA384", "", keyvault.PS384, rsaPSSSHA384Digest, keyvault.KeyOperationResult{
|
||||
Result: &rsaPSSSHA384ResultSig,
|
||||
}, nil},
|
||||
{"RSA-PSS SHA512", "", keyvault.PS512, rsaPSSSHA512Digest, keyvault.KeyOperationResult{
|
||||
Result: &rsaPSSSHA512ResultSig,
|
||||
}, nil},
|
||||
// Errors
|
||||
{"fail Sign", "", keyvault.RS256, rsaSHA256Digest, keyvault.KeyOperationResult{}, errTest},
|
||||
{"fail sign length", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{
|
||||
Result: &rsaSHA256ResultSig,
|
||||
}, nil},
|
||||
{"fail base64", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{
|
||||
Result: func() *string {
|
||||
v := "😎"
|
||||
return &v
|
||||
}(),
|
||||
}, nil},
|
||||
}
|
||||
for _, e := range expects {
|
||||
value := base64.RawURLEncoding.EncodeToString(e.digest)
|
||||
client.EXPECT().Sign(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", e.keyVersion, keyvault.KeySignParameters{
|
||||
Algorithm: e.alg,
|
||||
Value: &value,
|
||||
}).Return(e.result, e.err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
client KeyVaultClient
|
||||
vaultBaseURL string
|
||||
name string
|
||||
version string
|
||||
publicKey crypto.PublicKey
|
||||
}
|
||||
type args struct {
|
||||
rand io.Reader
|
||||
digest []byte
|
||||
opts crypto.SignerOpts
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok P-256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
|
||||
rand.Reader, p256Digest, crypto.SHA256,
|
||||
}, p256Sig, false},
|
||||
{"ok P-384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "my-version", p384}, args{
|
||||
rand.Reader, p384Digest, crypto.SHA384,
|
||||
}, p384Sig, false},
|
||||
{"ok P-521", fields{client, "https://my-vault.vault.azure.net/", "my-key", "my-version", p521}, args{
|
||||
rand.Reader, p521Digest, crypto.SHA512,
|
||||
}, p521Sig, false},
|
||||
{"ok RSA SHA256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{
|
||||
rand.Reader, rsaSHA256Digest, crypto.SHA256,
|
||||
}, rsaSHA256Sig, false},
|
||||
{"ok RSA SHA384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA384}, args{
|
||||
rand.Reader, rsaSHA384Digest, crypto.SHA384,
|
||||
}, rsaSHA384Sig, false},
|
||||
{"ok RSA SHA512", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA512}, args{
|
||||
rand.Reader, rsaSHA512Digest, crypto.SHA512,
|
||||
}, rsaSHA512Sig, false},
|
||||
{"ok RSA-PSS SHA256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA256}, args{
|
||||
rand.Reader, rsaPSSSHA256Digest, &rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA256,
|
||||
},
|
||||
}, rsaPSSSHA256Sig, false},
|
||||
{"ok RSA-PSS SHA384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA384}, args{
|
||||
rand.Reader, rsaPSSSHA384Digest, &rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthEqualsHash,
|
||||
Hash: crypto.SHA384,
|
||||
},
|
||||
}, rsaPSSSHA384Sig, false},
|
||||
{"ok RSA-PSS SHA512", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA512}, args{
|
||||
rand.Reader, rsaPSSSHA512Digest, &rsa.PSSOptions{
|
||||
SaltLength: 64,
|
||||
Hash: crypto.SHA512,
|
||||
},
|
||||
}, rsaPSSSHA512Sig, false},
|
||||
{"fail Sign", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{
|
||||
rand.Reader, rsaSHA256Digest, crypto.SHA256,
|
||||
}, nil, true},
|
||||
{"fail sign length", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
|
||||
rand.Reader, p256Digest, crypto.SHA256,
|
||||
}, nil, true},
|
||||
{"fail base64", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
|
||||
rand.Reader, p256Digest, crypto.SHA256,
|
||||
}, nil, true},
|
||||
{"fail RSA-PSS salt length", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA256}, args{
|
||||
rand.Reader, rsaPSSSHA256Digest, &rsa.PSSOptions{
|
||||
SaltLength: 64,
|
||||
Hash: crypto.SHA256,
|
||||
},
|
||||
}, nil, true},
|
||||
{"fail RSA Hash", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{
|
||||
rand.Reader, rsaSHA256Digest, crypto.SHA1,
|
||||
}, nil, true},
|
||||
{"fail ECDSA Hash", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
|
||||
rand.Reader, p256Digest, crypto.MD5,
|
||||
}, nil, true},
|
||||
{"fail Ed25519", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", ed25519Key}, args{
|
||||
rand.Reader, []byte("message"), crypto.Hash(0),
|
||||
}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Signer{
|
||||
client: tt.fields.client,
|
||||
vaultBaseURL: tt.fields.vaultBaseURL,
|
||||
name: tt.fields.name,
|
||||
version: tt.fields.version,
|
||||
publicKey: tt.fields.publicKey,
|
||||
}
|
||||
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Signer.Sign() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSigner_Sign_signWithRetry(t *testing.T) {
|
||||
sign := func(kty, crv string, bits int, opts crypto.SignerOpts) (crypto.PublicKey, []byte, string, []byte) {
|
||||
key, err := keyutil.GenerateSigner(kty, crv, bits)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
h := opts.HashFunc().New()
|
||||
h.Write([]byte("random-data"))
|
||||
sum := h.Sum(nil)
|
||||
|
||||
var sig, resultSig []byte
|
||||
if priv, ok := key.(*ecdsa.PrivateKey); ok {
|
||||
r, s, err := ecdsa.Sign(rand.Reader, priv, sum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
curveBits := priv.Params().BitSize
|
||||
keyBytes := curveBits / 8
|
||||
if curveBits%8 > 0 {
|
||||
keyBytes++
|
||||
}
|
||||
rBytes := r.Bytes()
|
||||
rBytesPadded := make([]byte, keyBytes)
|
||||
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
|
||||
|
||||
sBytes := s.Bytes()
|
||||
sBytesPadded := make([]byte, keyBytes)
|
||||
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
|
||||
// nolint:gocritic
|
||||
resultSig = append(rBytesPadded, sBytesPadded...)
|
||||
|
||||
var b cryptobyte.Builder
|
||||
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
||||
b.AddASN1BigInt(r)
|
||||
b.AddASN1BigInt(s)
|
||||
})
|
||||
sig, err = b.Bytes()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
sig, err = key.Sign(rand.Reader, sum, opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resultSig = sig
|
||||
}
|
||||
|
||||
return key.Public(), h.Sum(nil), base64.RawURLEncoding.EncodeToString(resultSig), sig
|
||||
}
|
||||
|
||||
p256, p256Digest, p256ResultSig, p256Sig := sign("EC", "P-256", 0, crypto.SHA256)
|
||||
okResult := keyvault.KeyOperationResult{
|
||||
Result: &p256ResultSig,
|
||||
}
|
||||
failResult := keyvault.KeyOperationResult{}
|
||||
retryError := autorest.DetailedError{
|
||||
Original: &azure.RequestError{
|
||||
ServiceError: &azure.ServiceError{
|
||||
InnerError: map[string]interface{}{
|
||||
"code": "KeyNotYetValid",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
client := mockClient(t)
|
||||
expects := []struct {
|
||||
name string
|
||||
keyVersion string
|
||||
alg keyvault.JSONWebKeySignatureAlgorithm
|
||||
digest []byte
|
||||
result keyvault.KeyOperationResult
|
||||
err error
|
||||
}{
|
||||
{"ok 1", "", keyvault.ES256, p256Digest, failResult, retryError},
|
||||
{"ok 2", "", keyvault.ES256, p256Digest, failResult, retryError},
|
||||
{"ok 3", "", keyvault.ES256, p256Digest, failResult, retryError},
|
||||
{"ok 4", "", keyvault.ES256, p256Digest, okResult, nil},
|
||||
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
|
||||
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
|
||||
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
|
||||
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
|
||||
}
|
||||
for _, e := range expects {
|
||||
value := base64.RawURLEncoding.EncodeToString(e.digest)
|
||||
client.EXPECT().Sign(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", e.keyVersion, keyvault.KeySignParameters{
|
||||
Algorithm: e.alg,
|
||||
Value: &value,
|
||||
}).Return(e.result, e.err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
client KeyVaultClient
|
||||
vaultBaseURL string
|
||||
name string
|
||||
version string
|
||||
publicKey crypto.PublicKey
|
||||
}
|
||||
type args struct {
|
||||
rand io.Reader
|
||||
digest []byte
|
||||
opts crypto.SignerOpts
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
|
||||
rand.Reader, p256Digest, crypto.SHA256,
|
||||
}, p256Sig, false},
|
||||
{"fail", fields{client, "https://my-vault.vault.azure.net/", "my-key", "fail-version", p256}, args{
|
||||
rand.Reader, p256Digest, crypto.SHA256,
|
||||
}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Signer{
|
||||
client: tt.fields.client,
|
||||
vaultBaseURL: tt.fields.vaultBaseURL,
|
||||
name: tt.fields.name,
|
||||
version: tt.fields.version,
|
||||
publicKey: tt.fields.publicKey,
|
||||
}
|
||||
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Signer.Sign() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,98 +0,0 @@
|
||||
package azurekms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/kms/apiv1"
|
||||
"github.com/smallstep/certificates/kms/uri"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
// defaultContext returns the default context used in requests to azure.
|
||||
func defaultContext() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), 15*time.Second)
|
||||
}
|
||||
|
||||
// getKeyName returns the uri of the key vault key.
|
||||
func getKeyName(vault, name string, bundle keyvault.KeyBundle) string {
|
||||
if bundle.Key != nil && bundle.Key.Kid != nil {
|
||||
sm := keyIDRegexp.FindAllStringSubmatch(*bundle.Key.Kid, 1)
|
||||
if len(sm) == 1 && len(sm[0]) == 4 {
|
||||
m := sm[0]
|
||||
u := uri.New(Scheme, url.Values{
|
||||
"vault": []string{m[1]},
|
||||
"name": []string{m[2]},
|
||||
})
|
||||
u.RawQuery = url.Values{"version": []string{m[3]}}.Encode()
|
||||
return u.String()
|
||||
}
|
||||
}
|
||||
// Fallback to URI without id.
|
||||
return uri.New(Scheme, url.Values{
|
||||
"vault": []string{vault},
|
||||
"name": []string{name},
|
||||
}).String()
|
||||
}
|
||||
|
||||
// parseKeyName returns the key vault, name and version from URIs like:
|
||||
//
|
||||
// - azurekms:vault=key-vault;name=key-name
|
||||
// - azurekms:vault=key-vault;name=key-name?version=key-id
|
||||
// - azurekms:vault=key-vault;name=key-name?version=key-id&hsm=true
|
||||
//
|
||||
// The key-id defines the version of the key, if it is not passed the latest
|
||||
// version will be used.
|
||||
//
|
||||
// HSM can also be passed to define the protection level if this is not given in
|
||||
// CreateQuery.
|
||||
func parseKeyName(rawURI string, defaults DefaultOptions) (vault, name, version string, hsm bool, err error) {
|
||||
var u *uri.URI
|
||||
|
||||
u, err = uri.ParseWithScheme(Scheme, rawURI)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if name = u.Get("name"); name == "" {
|
||||
err = errors.Errorf("key uri %s is not valid: name is missing", rawURI)
|
||||
return
|
||||
}
|
||||
if vault = u.Get("vault"); vault == "" {
|
||||
if defaults.Vault == "" {
|
||||
name = ""
|
||||
err = errors.Errorf("key uri %s is not valid: vault is missing", rawURI)
|
||||
return
|
||||
}
|
||||
vault = defaults.Vault
|
||||
}
|
||||
if u.Get("hsm") == "" {
|
||||
hsm = (defaults.ProtectionLevel == apiv1.HSM)
|
||||
} else {
|
||||
hsm = u.GetBool("hsm")
|
||||
}
|
||||
|
||||
version = u.Get("version")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func vaultBaseURL(vault string) string {
|
||||
return "https://" + vault + ".vault.azure.net/"
|
||||
}
|
||||
|
||||
func convertKey(key *keyvault.JSONWebKey) (crypto.PublicKey, error) {
|
||||
b, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error marshaling key")
|
||||
}
|
||||
var jwk jose.JSONWebKey
|
||||
if err := jwk.UnmarshalJSON(b); err != nil {
|
||||
return nil, errors.Wrap(err, "error unmarshaling key")
|
||||
}
|
||||
return jwk.Key, nil
|
||||
}
|
@ -1,96 +0,0 @@
|
||||
package azurekms
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||
"github.com/smallstep/certificates/kms/apiv1"
|
||||
)
|
||||
|
||||
func Test_getKeyName(t *testing.T) {
|
||||
getBundle := func(kid string) keyvault.KeyBundle {
|
||||
return keyvault.KeyBundle{
|
||||
Key: &keyvault.JSONWebKey{
|
||||
Kid: &kid,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type args struct {
|
||||
vault string
|
||||
name string
|
||||
bundle keyvault.KeyBundle
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{"ok", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-key/my-version")}, "azurekms:name=my-key;vault=my-vault?version=my-version"},
|
||||
{"ok default", args{"my-vault", "my-key", getBundle("https://my-vault.foo.net/keys/my-key/my-version")}, "azurekms:name=my-key;vault=my-vault"},
|
||||
{"ok too short", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-version")}, "azurekms:name=my-key;vault=my-vault"},
|
||||
{"ok too long", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-key/my-version/sign")}, "azurekms:name=my-key;vault=my-vault"},
|
||||
{"ok nil key", args{"my-vault", "my-key", keyvault.KeyBundle{}}, "azurekms:name=my-key;vault=my-vault"},
|
||||
{"ok nil kid", args{"my-vault", "my-key", keyvault.KeyBundle{Key: &keyvault.JSONWebKey{}}}, "azurekms:name=my-key;vault=my-vault"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := getKeyName(tt.args.vault, tt.args.name, tt.args.bundle); got != tt.want {
|
||||
t.Errorf("getKeyName() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseKeyName(t *testing.T) {
|
||||
var noOptions DefaultOptions
|
||||
type args struct {
|
||||
rawURI string
|
||||
defaults DefaultOptions
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantVault string
|
||||
wantName string
|
||||
wantVersion string
|
||||
wantHsm bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{"azurekms:name=my-key;vault=my-vault?version=my-version", noOptions}, "my-vault", "my-key", "my-version", false, false},
|
||||
{"ok opaque version", args{"azurekms:name=my-key;vault=my-vault;version=my-version", noOptions}, "my-vault", "my-key", "my-version", false, false},
|
||||
{"ok no version", args{"azurekms:name=my-key;vault=my-vault", noOptions}, "my-vault", "my-key", "", false, false},
|
||||
{"ok hsm", args{"azurekms:name=my-key;vault=my-vault?hsm=true", noOptions}, "my-vault", "my-key", "", true, false},
|
||||
{"ok hsm false", args{"azurekms:name=my-key;vault=my-vault?hsm=false", noOptions}, "my-vault", "my-key", "", false, false},
|
||||
{"ok default vault", args{"azurekms:name=my-key?version=my-version", DefaultOptions{Vault: "my-vault"}}, "my-vault", "my-key", "my-version", false, false},
|
||||
{"ok default hsm", args{"azurekms:name=my-key;vault=my-vault?version=my-version", DefaultOptions{Vault: "other-vault", ProtectionLevel: apiv1.HSM}}, "my-vault", "my-key", "my-version", true, false},
|
||||
{"fail scheme", args{"azure:name=my-key;vault=my-vault", noOptions}, "", "", "", false, true},
|
||||
{"fail parse uri", args{"azurekms:name=%ZZ;vault=my-vault", noOptions}, "", "", "", false, true},
|
||||
{"fail no name", args{"azurekms:vault=my-vault", noOptions}, "", "", "", false, true},
|
||||
{"fail empty name", args{"azurekms:name=;vault=my-vault", noOptions}, "", "", "", false, true},
|
||||
{"fail no vault", args{"azurekms:name=my-key", noOptions}, "", "", "", false, true},
|
||||
{"fail empty vault", args{"azurekms:name=my-key;vault=", noOptions}, "", "", "", false, true},
|
||||
{"fail empty", args{"", noOptions}, "", "", "", false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotVault, gotName, gotVersion, gotHsm, err := parseKeyName(tt.args.rawURI, tt.args.defaults)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseKeyName() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if gotVault != tt.wantVault {
|
||||
t.Errorf("parseKeyName() gotVault = %v, want %v", gotVault, tt.wantVault)
|
||||
}
|
||||
if gotName != tt.wantName {
|
||||
t.Errorf("parseKeyName() gotName = %v, want %v", gotName, tt.wantName)
|
||||
}
|
||||
if gotVersion != tt.wantVersion {
|
||||
t.Errorf("parseKeyName() gotVersion = %v, want %v", gotVersion, tt.wantVersion)
|
||||
}
|
||||
if gotHsm != tt.wantHsm {
|
||||
t.Errorf("parseKeyName() gotHsm = %v, want %v", gotHsm, tt.wantHsm)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,348 +0,0 @@
|
||||
package cloudkms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
cloudkms "cloud.google.com/go/kms/apiv1"
|
||||
gax "github.com/googleapis/gax-go/v2"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/kms/apiv1"
|
||||
"github.com/smallstep/certificates/kms/uri"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
"google.golang.org/api/option"
|
||||
kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1"
|
||||
)
|
||||
|
||||
// Scheme is the scheme used in uris.
|
||||
const Scheme = "cloudkms"
|
||||
|
||||
const pendingGenerationRetries = 10
|
||||
|
||||
// protectionLevelMapping maps step protection levels with cloud kms ones.
|
||||
var protectionLevelMapping = map[apiv1.ProtectionLevel]kmspb.ProtectionLevel{
|
||||
apiv1.UnspecifiedProtectionLevel: kmspb.ProtectionLevel_PROTECTION_LEVEL_UNSPECIFIED,
|
||||
apiv1.Software: kmspb.ProtectionLevel_SOFTWARE,
|
||||
apiv1.HSM: kmspb.ProtectionLevel_HSM,
|
||||
}
|
||||
|
||||
// signatureAlgorithmMapping is a mapping between the step signature algorithm,
|
||||
// and bits for RSA keys, with cloud kms one.
|
||||
//
|
||||
// Cloud KMS does not support SHA384WithRSA, SHA384WithRSAPSS, SHA384WithRSAPSS,
|
||||
// ECDSAWithSHA512, and PureEd25519.
|
||||
var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]interface{}{
|
||||
apiv1.UnspecifiedSignAlgorithm: kmspb.CryptoKeyVersion_CRYPTO_KEY_VERSION_ALGORITHM_UNSPECIFIED,
|
||||
apiv1.SHA256WithRSA: map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm{
|
||||
0: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_3072_SHA256,
|
||||
2048: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_2048_SHA256,
|
||||
3072: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_3072_SHA256,
|
||||
4096: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256,
|
||||
},
|
||||
apiv1.SHA512WithRSA: map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm{
|
||||
0: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512,
|
||||
4096: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512,
|
||||
},
|
||||
apiv1.SHA256WithRSAPSS: map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm{
|
||||
0: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256,
|
||||
2048: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_2048_SHA256,
|
||||
3072: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256,
|
||||
4096: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA256,
|
||||
},
|
||||
apiv1.SHA512WithRSAPSS: map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm{
|
||||
0: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA512,
|
||||
4096: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA512,
|
||||
},
|
||||
apiv1.ECDSAWithSHA256: kmspb.CryptoKeyVersion_EC_SIGN_P256_SHA256,
|
||||
apiv1.ECDSAWithSHA384: kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384,
|
||||
}
|
||||
|
||||
var cryptoKeyVersionMapping = map[kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm]x509.SignatureAlgorithm{
|
||||
kmspb.CryptoKeyVersion_EC_SIGN_P256_SHA256: x509.ECDSAWithSHA256,
|
||||
kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384: x509.ECDSAWithSHA384,
|
||||
kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_2048_SHA256: x509.SHA256WithRSA,
|
||||
kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_3072_SHA256: x509.SHA256WithRSA,
|
||||
kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256: x509.SHA256WithRSA,
|
||||
kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512: x509.SHA512WithRSA,
|
||||
kmspb.CryptoKeyVersion_RSA_SIGN_PSS_2048_SHA256: x509.SHA256WithRSAPSS,
|
||||
kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256: x509.SHA256WithRSAPSS,
|
||||
kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA256: x509.SHA256WithRSAPSS,
|
||||
kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA512: x509.SHA512WithRSAPSS,
|
||||
}
|
||||
|
||||
// KeyManagementClient defines the methods on KeyManagementClient that this
|
||||
// package will use. This interface will be used for unit testing.
|
||||
type KeyManagementClient interface {
|
||||
Close() error
|
||||
GetPublicKey(context.Context, *kmspb.GetPublicKeyRequest, ...gax.CallOption) (*kmspb.PublicKey, error)
|
||||
AsymmetricSign(context.Context, *kmspb.AsymmetricSignRequest, ...gax.CallOption) (*kmspb.AsymmetricSignResponse, error)
|
||||
CreateCryptoKey(context.Context, *kmspb.CreateCryptoKeyRequest, ...gax.CallOption) (*kmspb.CryptoKey, error)
|
||||
GetKeyRing(context.Context, *kmspb.GetKeyRingRequest, ...gax.CallOption) (*kmspb.KeyRing, error)
|
||||
CreateKeyRing(context.Context, *kmspb.CreateKeyRingRequest, ...gax.CallOption) (*kmspb.KeyRing, error)
|
||||
CreateCryptoKeyVersion(ctx context.Context, req *kmspb.CreateCryptoKeyVersionRequest, opts ...gax.CallOption) (*kmspb.CryptoKeyVersion, error)
|
||||
}
|
||||
|
||||
var newKeyManagementClient = func(ctx context.Context, opts ...option.ClientOption) (KeyManagementClient, error) {
|
||||
return cloudkms.NewKeyManagementClient(ctx, opts...)
|
||||
}
|
||||
|
||||
// CloudKMS implements a KMS using Google's Cloud apiv1.
|
||||
type CloudKMS struct {
|
||||
client KeyManagementClient
|
||||
}
|
||||
|
||||
// New creates a new CloudKMS configured with a new client.
|
||||
func New(ctx context.Context, opts apiv1.Options) (*CloudKMS, error) {
|
||||
var cloudOpts []option.ClientOption
|
||||
|
||||
if opts.URI != "" {
|
||||
u, err := uri.ParseWithScheme(Scheme, opts.URI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if f := u.Get("credentials-file"); f != "" {
|
||||
cloudOpts = append(cloudOpts, option.WithCredentialsFile(f))
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated way to set configuration parameters.
|
||||
if opts.CredentialsFile != "" {
|
||||
cloudOpts = append(cloudOpts, option.WithCredentialsFile(opts.CredentialsFile))
|
||||
}
|
||||
|
||||
client, err := newKeyManagementClient(ctx, cloudOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &CloudKMS{
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
apiv1.Register(apiv1.CloudKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
|
||||
return New(ctx, opts)
|
||||
})
|
||||
}
|
||||
|
||||
// NewCloudKMS creates a CloudKMS with a given client.
|
||||
func NewCloudKMS(client KeyManagementClient) *CloudKMS {
|
||||
return &CloudKMS{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the connection of the Cloud KMS client.
|
||||
func (k *CloudKMS) Close() error {
|
||||
if err := k.client.Close(); err != nil {
|
||||
return errors.Wrap(err, "cloudKMS Close failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateSigner returns a new cloudkms signer configured with the given signing
|
||||
// key name.
|
||||
func (k *CloudKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
|
||||
if req.SigningKey == "" {
|
||||
return nil, errors.New("signing key cannot be empty")
|
||||
}
|
||||
return NewSigner(k.client, req.SigningKey)
|
||||
}
|
||||
|
||||
// CreateKey creates in Google's Cloud KMS a new asymmetric key for signing.
|
||||
func (k *CloudKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
|
||||
if req.Name == "" {
|
||||
return nil, errors.New("createKeyRequest 'name' cannot be empty")
|
||||
}
|
||||
|
||||
protectionLevel, ok := protectionLevelMapping[req.ProtectionLevel]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("cloudKMS does not support protection level '%s'", req.ProtectionLevel)
|
||||
}
|
||||
|
||||
var signatureAlgorithm kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm
|
||||
v, ok := signatureAlgorithmMapping[req.SignatureAlgorithm]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("cloudKMS does not support signature algorithm '%s'", req.SignatureAlgorithm)
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm:
|
||||
signatureAlgorithm = v
|
||||
case map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm:
|
||||
if signatureAlgorithm, ok = v[req.Bits]; !ok {
|
||||
return nil, errors.Errorf("cloudKMS does not support signature algorithm '%s' with '%d' bits", req.SignatureAlgorithm, req.Bits)
|
||||
}
|
||||
default:
|
||||
return nil, errors.Errorf("unexpected error: this should not happen")
|
||||
}
|
||||
|
||||
var crytoKeyName string
|
||||
|
||||
// Split `projects/PROJECT_ID/locations/global/keyRings/RING_ID/cryptoKeys/KEY_ID`
|
||||
// to `projects/PROJECT_ID/locations/global/keyRings/RING_ID` and `KEY_ID`.
|
||||
keyRing, keyID := Parent(req.Name)
|
||||
if err := k.createKeyRingIfNeeded(keyRing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
// Create private key in CloudKMS.
|
||||
response, err := k.client.CreateCryptoKey(ctx, &kmspb.CreateCryptoKeyRequest{
|
||||
Parent: keyRing,
|
||||
CryptoKeyId: keyID,
|
||||
CryptoKey: &kmspb.CryptoKey{
|
||||
Purpose: kmspb.CryptoKey_ASYMMETRIC_SIGN,
|
||||
VersionTemplate: &kmspb.CryptoKeyVersionTemplate{
|
||||
ProtectionLevel: protectionLevel,
|
||||
Algorithm: signatureAlgorithm,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
if status.Code(err) != codes.AlreadyExists {
|
||||
return nil, errors.Wrap(err, "cloudKMS CreateCryptoKey failed")
|
||||
}
|
||||
// Create a new version if the key already exists.
|
||||
//
|
||||
// Note that it will have the same purpose, protection level and
|
||||
// algorithm than as previous one.
|
||||
req := &kmspb.CreateCryptoKeyVersionRequest{
|
||||
Parent: req.Name,
|
||||
CryptoKeyVersion: &kmspb.CryptoKeyVersion{
|
||||
State: kmspb.CryptoKeyVersion_ENABLED,
|
||||
},
|
||||
}
|
||||
response, err := k.client.CreateCryptoKeyVersion(ctx, req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cloudKMS CreateCryptoKeyVersion failed")
|
||||
}
|
||||
crytoKeyName = response.Name
|
||||
} else {
|
||||
crytoKeyName = response.Name + "/cryptoKeyVersions/1"
|
||||
}
|
||||
|
||||
// Sleep deterministically to avoid retries because of PENDING_GENERATING.
|
||||
// One second is often enough.
|
||||
if protectionLevel == kmspb.ProtectionLevel_HSM {
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
// Retrieve public key to add it to the response.
|
||||
pk, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{
|
||||
Name: crytoKeyName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cloudKMS GetPublicKey failed")
|
||||
}
|
||||
|
||||
return &apiv1.CreateKeyResponse{
|
||||
Name: crytoKeyName,
|
||||
PublicKey: pk,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: crytoKeyName,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (k *CloudKMS) createKeyRingIfNeeded(name string) error {
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
_, err := k.client.GetKeyRing(ctx, &kmspb.GetKeyRingRequest{
|
||||
Name: name,
|
||||
})
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
parent, child := Parent(name)
|
||||
_, err = k.client.CreateKeyRing(ctx, &kmspb.CreateKeyRingRequest{
|
||||
Parent: parent,
|
||||
KeyRingId: child,
|
||||
})
|
||||
if err != nil && status.Code(err) != codes.AlreadyExists {
|
||||
return errors.Wrap(err, "cloudKMS CreateKeyRing failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPublicKey gets from Google's Cloud KMS a public key by name. Key names
|
||||
// follow the pattern:
|
||||
//
|
||||
// projects/([^/]+)/locations/([a-zA-Z0-9_-]{1,63})/keyRings/([a-zA-Z0-9_-]{1,63})/cryptoKeys/([a-zA-Z0-9_-]{1,63})/cryptoKeyVersions/([a-zA-Z0-9_-]{1,63})
|
||||
func (k *CloudKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
|
||||
if req.Name == "" {
|
||||
return nil, errors.New("createKeyRequest 'name' cannot be empty")
|
||||
}
|
||||
|
||||
response, err := k.getPublicKeyWithRetries(req.Name, pendingGenerationRetries)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cloudKMS GetPublicKey failed")
|
||||
}
|
||||
|
||||
pk, err := pemutil.ParseKey([]byte(response.Pem))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
// getPublicKeyWithRetries retries the request if the error is
|
||||
// FailedPrecondition, caused because the key is in the PENDING_GENERATION
|
||||
// status.
|
||||
func (k *CloudKMS) getPublicKeyWithRetries(name string, retries int) (response *kmspb.PublicKey, err error) {
|
||||
workFn := func() (*kmspb.PublicKey, error) {
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
return k.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{
|
||||
Name: name,
|
||||
})
|
||||
}
|
||||
for i := 0; i < retries; i++ {
|
||||
if response, err = workFn(); err == nil {
|
||||
return
|
||||
}
|
||||
if status.Code(err) == codes.FailedPrecondition {
|
||||
log.Println("Waiting for key generation ...")
|
||||
time.Sleep(time.Duration(i+1) * time.Second)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func defaultContext() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), 15*time.Second)
|
||||
}
|
||||
|
||||
// Parent splits a string in the format `key/value/key2/value2` in a parent and
|
||||
// child, for the previous string it will return `key/value` and `value2`.
|
||||
func Parent(name string) (string, string) {
|
||||
a, b := parent(name)
|
||||
a, _ = parent(a)
|
||||
return a, b
|
||||
}
|
||||
|
||||
func parent(name string) (string, string) {
|
||||
i := strings.LastIndex(name, "/")
|
||||
switch i {
|
||||
case -1:
|
||||
return "", name
|
||||
case 0:
|
||||
return "", name[i+1:]
|
||||
default:
|
||||
return name[:i], name[i+1:]
|
||||
}
|
||||
}
|
@ -1,464 +0,0 @@
|
||||
package cloudkms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
gax "github.com/googleapis/gax-go/v2"
|
||||
"github.com/smallstep/certificates/kms/apiv1"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
"google.golang.org/api/option"
|
||||
kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestParent(t *testing.T) {
|
||||
type args struct {
|
||||
name string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
want1 string
|
||||
}{
|
||||
{"zero", args{"child"}, "", "child"},
|
||||
{"one", args{"parent/child"}, "", "child"},
|
||||
{"two", args{"grandparent/parent/child"}, "grandparent", "child"},
|
||||
{"three", args{"great-grandparent/grandparent/parent/child"}, "great-grandparent/grandparent", "child"},
|
||||
{"empty", args{""}, "", ""},
|
||||
{"root", args{"/"}, "", ""},
|
||||
{"child", args{"/child"}, "", "child"},
|
||||
{"parent", args{"parent/"}, "", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := Parent(tt.args.name)
|
||||
if got != tt.want {
|
||||
t.Errorf("Parent() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Parent() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
tmp := newKeyManagementClient
|
||||
t.Cleanup(func() {
|
||||
newKeyManagementClient = tmp
|
||||
})
|
||||
newKeyManagementClient = func(ctx context.Context, opts ...option.ClientOption) (KeyManagementClient, error) {
|
||||
if len(opts) > 0 {
|
||||
return nil, fmt.Errorf("test error")
|
||||
}
|
||||
return &MockClient{}, nil
|
||||
}
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
opts apiv1.Options
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *CloudKMS
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{context.Background(), apiv1.Options{}}, &CloudKMS{client: &MockClient{}}, false},
|
||||
{"ok with uri", args{context.Background(), apiv1.Options{URI: "cloudkms:"}}, &CloudKMS{client: &MockClient{}}, false},
|
||||
{"fail credentials", args{context.Background(), apiv1.Options{CredentialsFile: "testdata/missing"}}, nil, true},
|
||||
{"fail with uri", args{context.Background(), apiv1.Options{URI: "cloudkms:credentials-file=testdata/missing"}}, nil, true},
|
||||
{"fail schema", args{context.Background(), apiv1.Options{URI: "pkcs11:"}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := New(tt.args.ctx, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("New() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_real(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
opts apiv1.Options
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *CloudKMS
|
||||
wantErr bool
|
||||
}{
|
||||
{"fail credentials", args{context.Background(), apiv1.Options{CredentialsFile: "testdata/missing"}}, nil, true},
|
||||
{"fail with uri", args{context.Background(), apiv1.Options{URI: "cloudkms:credentials-file=testdata/missing"}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := New(tt.args.ctx, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("New() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCloudKMS(t *testing.T) {
|
||||
type args struct {
|
||||
client KeyManagementClient
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *CloudKMS
|
||||
}{
|
||||
{"ok", args{&MockClient{}}, &CloudKMS{&MockClient{}}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := NewCloudKMS(tt.args.client); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("NewCloudKMS() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudKMS_Close(t *testing.T) {
|
||||
type fields struct {
|
||||
client KeyManagementClient
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{&MockClient{close: func() error { return nil }}}, false},
|
||||
{"fail", fields{&MockClient{close: func() error { return fmt.Errorf("an error") }}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &CloudKMS{
|
||||
client: tt.fields.client,
|
||||
}
|
||||
if err := k.Close(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("CloudKMS.Close() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudKMS_CreateSigner(t *testing.T) {
|
||||
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
|
||||
pemBytes, err := os.ReadFile("testdata/pub.pem")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pk, err := pemutil.ParseKey(pemBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
client KeyManagementClient
|
||||
}
|
||||
type args struct {
|
||||
req *apiv1.CreateSignerRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want crypto.Signer
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{&MockClient{
|
||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
||||
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
|
||||
},
|
||||
}}, args{&apiv1.CreateSignerRequest{SigningKey: keyName}}, &Signer{client: &MockClient{}, signingKey: keyName, publicKey: pk}, false},
|
||||
{"fail", fields{&MockClient{
|
||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
||||
return nil, fmt.Errorf("test error")
|
||||
},
|
||||
}}, args{&apiv1.CreateSignerRequest{SigningKey: ""}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &CloudKMS{
|
||||
client: tt.fields.client,
|
||||
}
|
||||
got, err := k.CreateSigner(tt.args.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CloudKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if signer, ok := got.(*Signer); ok {
|
||||
signer.client = &MockClient{}
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("CloudKMS.CreateSigner() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudKMS_CreateKey(t *testing.T) {
|
||||
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c"
|
||||
testError := fmt.Errorf("an error")
|
||||
alreadyExists := status.Error(codes.AlreadyExists, "already exists")
|
||||
|
||||
pemBytes, err := os.ReadFile("testdata/pub.pem")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pk, err := pemutil.ParseKey(pemBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var retries int
|
||||
type fields struct {
|
||||
client KeyManagementClient
|
||||
}
|
||||
type args struct {
|
||||
req *apiv1.CreateKeyRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *apiv1.CreateKeyResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{
|
||||
&MockClient{
|
||||
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
||||
return &kmspb.KeyRing{}, nil
|
||||
},
|
||||
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
|
||||
return &kmspb.CryptoKey{Name: keyName}, nil
|
||||
},
|
||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
||||
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
|
||||
},
|
||||
}},
|
||||
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
|
||||
&apiv1.CreateKeyResponse{Name: keyName + "/cryptoKeyVersions/1", PublicKey: pk, CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: keyName + "/cryptoKeyVersions/1"}}, false},
|
||||
{"ok new key ring", fields{
|
||||
&MockClient{
|
||||
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
||||
return nil, testError
|
||||
},
|
||||
createKeyRing: func(_ context.Context, _ *kmspb.CreateKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
||||
return nil, alreadyExists
|
||||
},
|
||||
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
|
||||
return &kmspb.CryptoKey{Name: keyName}, nil
|
||||
},
|
||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
||||
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
|
||||
},
|
||||
}},
|
||||
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.Software, SignatureAlgorithm: apiv1.SHA256WithRSA, Bits: 3072}},
|
||||
&apiv1.CreateKeyResponse{Name: keyName + "/cryptoKeyVersions/1", PublicKey: pk, CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: keyName + "/cryptoKeyVersions/1"}}, false},
|
||||
{"ok new key version", fields{
|
||||
&MockClient{
|
||||
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
||||
return &kmspb.KeyRing{}, nil
|
||||
},
|
||||
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
|
||||
return nil, alreadyExists
|
||||
},
|
||||
createCryptoKeyVersion: func(_ context.Context, _ *kmspb.CreateCryptoKeyVersionRequest, _ ...gax.CallOption) (*kmspb.CryptoKeyVersion, error) {
|
||||
return &kmspb.CryptoKeyVersion{Name: keyName + "/cryptoKeyVersions/2"}, nil
|
||||
},
|
||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
||||
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
|
||||
},
|
||||
}},
|
||||
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
|
||||
&apiv1.CreateKeyResponse{Name: keyName + "/cryptoKeyVersions/2", PublicKey: pk, CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: keyName + "/cryptoKeyVersions/2"}}, false},
|
||||
{"ok with retries", fields{
|
||||
&MockClient{
|
||||
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
||||
return &kmspb.KeyRing{}, nil
|
||||
},
|
||||
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
|
||||
return &kmspb.CryptoKey{Name: keyName}, nil
|
||||
},
|
||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
||||
if retries != 2 {
|
||||
retries++
|
||||
return nil, status.Error(codes.FailedPrecondition, "key is not enabled, current state is: PENDING_GENERATION")
|
||||
}
|
||||
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
|
||||
},
|
||||
}},
|
||||
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
|
||||
&apiv1.CreateKeyResponse{Name: keyName + "/cryptoKeyVersions/1", PublicKey: pk, CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: keyName + "/cryptoKeyVersions/1"}}, false},
|
||||
{"fail name", fields{&MockClient{}}, args{&apiv1.CreateKeyRequest{}}, nil, true},
|
||||
{"fail protection level", fields{&MockClient{}}, args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.ProtectionLevel(100)}}, nil, true},
|
||||
{"fail signature algorithm", fields{&MockClient{}}, args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.Software, SignatureAlgorithm: apiv1.SignatureAlgorithm(100)}}, nil, true},
|
||||
{"fail number of bits", fields{&MockClient{}}, args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.Software, SignatureAlgorithm: apiv1.SHA256WithRSA, Bits: 1024}},
|
||||
nil, true},
|
||||
{"fail create key ring", fields{
|
||||
&MockClient{
|
||||
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
||||
return nil, testError
|
||||
},
|
||||
createKeyRing: func(_ context.Context, _ *kmspb.CreateKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
||||
return nil, testError
|
||||
},
|
||||
}},
|
||||
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
|
||||
nil, true},
|
||||
{"fail create key", fields{
|
||||
&MockClient{
|
||||
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
||||
return &kmspb.KeyRing{}, nil
|
||||
},
|
||||
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
|
||||
return nil, testError
|
||||
},
|
||||
}},
|
||||
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
|
||||
nil, true},
|
||||
{"fail create key version", fields{
|
||||
&MockClient{
|
||||
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
||||
return &kmspb.KeyRing{}, nil
|
||||
},
|
||||
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
|
||||
return nil, alreadyExists
|
||||
},
|
||||
createCryptoKeyVersion: func(_ context.Context, _ *kmspb.CreateCryptoKeyVersionRequest, _ ...gax.CallOption) (*kmspb.CryptoKeyVersion, error) {
|
||||
return nil, testError
|
||||
},
|
||||
}},
|
||||
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
|
||||
nil, true},
|
||||
{"fail get public key", fields{
|
||||
&MockClient{
|
||||
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
||||
return &kmspb.KeyRing{}, nil
|
||||
},
|
||||
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
|
||||
return &kmspb.CryptoKey{Name: keyName}, nil
|
||||
},
|
||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
||||
return nil, testError
|
||||
},
|
||||
}},
|
||||
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
|
||||
nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &CloudKMS{
|
||||
client: tt.fields.client,
|
||||
}
|
||||
got, err := k.CreateKey(tt.args.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CloudKMS.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("CloudKMS.CreateKey() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudKMS_GetPublicKey(t *testing.T) {
|
||||
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
|
||||
testError := fmt.Errorf("an error")
|
||||
|
||||
pemBytes, err := os.ReadFile("testdata/pub.pem")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pk, err := pemutil.ParseKey(pemBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var retries int
|
||||
type fields struct {
|
||||
client KeyManagementClient
|
||||
}
|
||||
type args struct {
|
||||
req *apiv1.GetPublicKeyRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want crypto.PublicKey
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{
|
||||
&MockClient{
|
||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
||||
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
|
||||
},
|
||||
}},
|
||||
args{&apiv1.GetPublicKeyRequest{Name: keyName}}, pk, false},
|
||||
{"ok with retries", fields{
|
||||
&MockClient{
|
||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
||||
if retries != 2 {
|
||||
retries++
|
||||
return nil, status.Error(codes.FailedPrecondition, "key is not enabled, current state is: PENDING_GENERATION")
|
||||
}
|
||||
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
|
||||
},
|
||||
}},
|
||||
args{&apiv1.GetPublicKeyRequest{Name: keyName}}, pk, false},
|
||||
{"fail name", fields{&MockClient{}}, args{&apiv1.GetPublicKeyRequest{}}, nil, true},
|
||||
{"fail get public key", fields{
|
||||
&MockClient{
|
||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
||||
return nil, testError
|
||||
},
|
||||
}},
|
||||
args{&apiv1.GetPublicKeyRequest{Name: keyName}}, nil, true},
|
||||
{"fail parse pem", fields{
|
||||
&MockClient{
|
||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
||||
return &kmspb.PublicKey{Pem: string("bad pem")}, nil
|
||||
},
|
||||
}},
|
||||
args{&apiv1.GetPublicKeyRequest{Name: keyName}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &CloudKMS{
|
||||
client: tt.fields.client,
|
||||
}
|
||||
got, err := k.GetPublicKey(tt.args.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CloudKMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("CloudKMS.GetPublicKey() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue