diff --git a/authority/config/config.go b/authority/config/config.go index 75c32994..589b5bbf 100644 --- a/authority/config/config.go +++ b/authority/config/config.go @@ -270,28 +270,36 @@ func (c *Config) GetAudiences() provisioner.Audiences { for _, name := range c.DNSNames { audiences.Sign = append(audiences.Sign, - fmt.Sprintf("https://%s/1.0/sign", name), - fmt.Sprintf("https://%s/sign", name), - fmt.Sprintf("https://%s/1.0/ssh/sign", name), - fmt.Sprintf("https://%s/ssh/sign", name)) + fmt.Sprintf("https://%s/1.0/sign", toHostname(name)), + fmt.Sprintf("https://%s/sign", toHostname(name)), + fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)), + fmt.Sprintf("https://%s/ssh/sign", toHostname(name))) audiences.Revoke = append(audiences.Revoke, - fmt.Sprintf("https://%s/1.0/revoke", name), - fmt.Sprintf("https://%s/revoke", name)) + fmt.Sprintf("https://%s/1.0/revoke", toHostname(name)), + fmt.Sprintf("https://%s/revoke", toHostname(name))) audiences.SSHSign = append(audiences.SSHSign, - fmt.Sprintf("https://%s/1.0/ssh/sign", name), - fmt.Sprintf("https://%s/ssh/sign", name), - fmt.Sprintf("https://%s/1.0/sign", name), - fmt.Sprintf("https://%s/sign", name)) + fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)), + fmt.Sprintf("https://%s/ssh/sign", toHostname(name)), + fmt.Sprintf("https://%s/1.0/sign", toHostname(name)), + fmt.Sprintf("https://%s/sign", toHostname(name))) audiences.SSHRevoke = append(audiences.SSHRevoke, - fmt.Sprintf("https://%s/1.0/ssh/revoke", name), - fmt.Sprintf("https://%s/ssh/revoke", name)) + fmt.Sprintf("https://%s/1.0/ssh/revoke", toHostname(name)), + fmt.Sprintf("https://%s/ssh/revoke", toHostname(name))) audiences.SSHRenew = append(audiences.SSHRenew, - fmt.Sprintf("https://%s/1.0/ssh/renew", name), - fmt.Sprintf("https://%s/ssh/renew", name)) + fmt.Sprintf("https://%s/1.0/ssh/renew", toHostname(name)), + fmt.Sprintf("https://%s/ssh/renew", toHostname(name))) audiences.SSHRekey = append(audiences.SSHRekey, - fmt.Sprintf("https://%s/1.0/ssh/rekey", name), - fmt.Sprintf("https://%s/ssh/rekey", name)) + fmt.Sprintf("https://%s/1.0/ssh/rekey", toHostname(name)), + fmt.Sprintf("https://%s/ssh/rekey", toHostname(name))) } return audiences } + +func toHostname(name string) string { + // ensure an IPv6 address is represented with square brackets when used as hostname + if ip := net.ParseIP(name); ip != nil && ip.To4() == nil { + name = "[" + name + "]" + } + return name +} diff --git a/authority/config/config_test.go b/authority/config/config_test.go index a5b60513..b921be13 100644 --- a/authority/config/config_test.go +++ b/authority/config/config_test.go @@ -7,9 +7,8 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/crypto/jose" - _ "github.com/smallstep/certificates/cas" + "go.step.sm/crypto/jose" ) func TestConfigValidate(t *testing.T) { @@ -298,3 +297,23 @@ func TestAuthConfigValidate(t *testing.T) { }) } } + +func Test_toHostname(t *testing.T) { + tests := []struct { + name string + want string + }{ + {name: "localhost", want: "localhost"}, + {name: "ca.smallstep.com", want: "ca.smallstep.com"}, + {name: "127.0.0.1", want: "127.0.0.1"}, + {name: "::1", want: "[::1]"}, + {name: "[::1]", want: "[::1]"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := toHostname(tt.name); got != tt.want { + t.Errorf("toHostname() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/ca/provisioner_test.go b/ca/provisioner_test.go index 01b54d17..39193f3f 100644 --- a/ca/provisioner_test.go +++ b/ca/provisioner_test.go @@ -200,6 +200,102 @@ func TestProvisioner_Token(t *testing.T) { } } +func TestProvisioner_IPv6Token(t *testing.T) { + p := getTestProvisioner(t, "https://[::1]:9000") + sha := "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7" + + type fields struct { + name string + kid string + fingerprint string + jwk *jose.JSONWebKey + tokenLifetime time.Duration + } + type args struct { + subject string + sans []string + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", nil}, false}, + {"ok-with-san", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", []string{"foo.smallstep.com"}}, false}, + {"ok-with-sans", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", []string{"foo.smallstep.com", "127.0.0.1"}}, false}, + {"fail-no-subject", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"", []string{"foo.smallstep.com"}}, true}, + {"fail-no-key", fields{p.name, p.kid, sha, &jose.JSONWebKey{}, p.tokenLifetime}, args{"subject", nil}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Provisioner{ + name: tt.fields.name, + kid: tt.fields.kid, + audience: "https://[::1]:9000/1.0/sign", + fingerprint: tt.fields.fingerprint, + jwk: tt.fields.jwk, + tokenLifetime: tt.fields.tokenLifetime, + } + got, err := p.Token(tt.args.subject, tt.args.sans...) + if (err != nil) != tt.wantErr { + t.Errorf("Provisioner.Token() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr == false { + jwt, err := jose.ParseSigned(got) + if err != nil { + t.Error(err) + return + } + var claims jose.Claims + if err := jwt.Claims(tt.fields.jwk.Public(), &claims); err != nil { + t.Error(err) + return + } + if err := claims.ValidateWithLeeway(jose.Expected{ + Audience: []string{"https://[::1]:9000/1.0/sign"}, + Issuer: tt.fields.name, + Subject: tt.args.subject, + Time: time.Now().UTC(), + }, time.Minute); err != nil { + t.Error(err) + return + } + lifetime := claims.Expiry.Time().Sub(claims.NotBefore.Time()) + if lifetime != tt.fields.tokenLifetime { + t.Errorf("Claims token life time = %s, want %s", lifetime, tt.fields.tokenLifetime) + } + allClaims := make(map[string]interface{}) + if err := jwt.Claims(tt.fields.jwk.Public(), &allClaims); err != nil { + t.Error(err) + return + } + if v, ok := allClaims["sha"].(string); !ok || v != sha { + t.Errorf("Claim sha = %s, want %s", v, sha) + } + if len(tt.args.sans) == 0 { + if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, []interface{}{tt.args.subject}) { + t.Errorf("Claim sans = %s, want %s", v, []interface{}{tt.args.subject}) + } + } else { + want := []interface{}{} + for _, s := range tt.args.sans { + want = append(want, s) + } + if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, want) { + t.Errorf("Claim sans = %s, want %s", v, want) + } + } + if v, ok := allClaims["jti"].(string); !ok || v == "" { + t.Errorf("Claim jti = %s, want not blank", v) + } + } + }) + } +} + func TestProvisioner_SSHToken(t *testing.T) { p := getTestProvisioner(t, "https://127.0.0.1:9000") sha := "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7"