diff --git a/authority/provisioner/duration.go b/authority/provisioner/duration.go index 68ffbb0a..d18a81e9 100644 --- a/authority/provisioner/duration.go +++ b/authority/provisioner/duration.go @@ -12,6 +12,16 @@ type Duration struct { time.Duration } +// NewDuration parses a duration string and returns a Duration type or and error +// if the given string is not a duration. +func NewDuration(s string) (*Duration, error) { + d, err := time.ParseDuration(s) + if err != nil { + return nil, errors.Wrapf(err, "error parsing %s as duration", s) + } + return &Duration{Duration: d}, nil +} + // MarshalJSON parses a duration string and sets it to the duration. // // A duration string is a possibly signed sequence of decimal numbers, each with @@ -29,7 +39,7 @@ func (d *Duration) MarshalJSON() ([]byte, error) { func (d *Duration) UnmarshalJSON(data []byte) (err error) { var ( s string - _d time.Duration + dd time.Duration ) if d == nil { return errors.New("duration cannot be nil") @@ -37,10 +47,10 @@ func (d *Duration) UnmarshalJSON(data []byte) (err error) { if err = json.Unmarshal(data, &s); err != nil { return errors.Wrapf(err, "error unmarshaling %s", data) } - if _d, err = time.ParseDuration(s); err != nil { + if dd, err = time.ParseDuration(s); err != nil { return errors.Wrapf(err, "error parsing %s as duration", s) } - d.Duration = _d + d.Duration = dd return } diff --git a/authority/provisioner/duration_test.go b/authority/provisioner/duration_test.go index faf5e7f4..828064cc 100644 --- a/authority/provisioner/duration_test.go +++ b/authority/provisioner/duration_test.go @@ -6,6 +6,35 @@ import ( "time" ) +func TestNewDuration(t *testing.T) { + type args struct { + s string + } + tests := []struct { + name string + args args + want *Duration + wantErr bool + }{ + {"ok", args{"1h2m3s"}, &Duration{Duration: 3723 * time.Second}, false}, + {"fail empty", args{""}, nil, true}, + {"fail number", args{"123"}, nil, true}, + {"fail string", args{"1hour"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewDuration(tt.args.s) + if (err != nil) != tt.wantErr { + t.Errorf("NewDuration() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewDuration() = %v, want %v", got, tt.want) + } + }) + } +} + func TestDuration_UnmarshalJSON(t *testing.T) { type args struct { data []byte