@ -1,10 +1,12 @@
package nosql
import (
"bytes"
"context"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"testing"
"time"
@ -14,7 +16,6 @@ import (
"github.com/smallstep/certificates/db"
"github.com/smallstep/nosql"
nosqldb "github.com/smallstep/nosql/database"
"go.step.sm/crypto/pemutil"
)
@ -31,7 +32,6 @@ func TestDB_CreateCertificate(t *testing.T) {
err error
_id * string
}
countOfCmpAndSwapCalls := 0
var tests = map [ string ] func ( t * testing . T ) test {
"fail/cmpAndSwap-error" : func ( t * testing . T ) test {
cert := & acme . Certificate {
@ -76,7 +76,10 @@ func TestDB_CreateCertificate(t *testing.T) {
return test {
db : & db . MockNoSQLDB {
MCmpAndSwap : func ( bucket , key , old , nu [ ] byte ) ( [ ] byte , bool , error ) {
if countOfCmpAndSwapCalls == 0 {
if ! bytes . Equal ( bucket , certTable ) && ! bytes . Equal ( bucket , certBySerialTable ) {
t . Fail ( )
}
if bytes . Equal ( bucket , certTable ) {
* idPtr = string ( key )
assert . Equals ( t , bucket , certTable )
assert . Equals ( t , key , [ ] byte ( cert . ID ) )
@ -90,7 +93,7 @@ func TestDB_CreateCertificate(t *testing.T) {
assert . True ( t , clock . Now ( ) . Add ( - time . Minute ) . Before ( dbc . CreatedAt ) )
assert . True ( t , clock . Now ( ) . Add ( time . Minute ) . After ( dbc . CreatedAt ) )
}
if countOfCmpAndSwapCalls == 1 {
if bytes. Equal ( bucket , certBySerialTable ) {
assert . Equals ( t , bucket , certBySerialTable )
assert . Equals ( t , key , [ ] byte ( cert . Leaf . SerialNumber . String ( ) ) )
assert . Equals ( t , old , nil )
@ -103,8 +106,6 @@ func TestDB_CreateCertificate(t *testing.T) {
* idPtr = cert . ID
}
countOfCmpAndSwapCalls ++
return nil , true , nil
} ,
} ,
@ -335,3 +336,135 @@ func Test_parseBundle(t *testing.T) {
} )
}
}
func TestDB_GetCertificateBySerial ( t * testing . T ) {
leaf , err := pemutil . ReadCertificate ( "../../../authority/testdata/certs/foo.crt" )
assert . FatalError ( t , err )
inter , err := pemutil . ReadCertificate ( "../../../authority/testdata/certs/intermediate_ca.crt" )
assert . FatalError ( t , err )
root , err := pemutil . ReadCertificate ( "../../../authority/testdata/certs/root_ca.crt" )
assert . FatalError ( t , err )
certID := "certID"
serial := ""
type test struct {
db nosql . DB
err error
acmeErr * acme . Error
}
var tests = map [ string ] func ( t * testing . T ) test {
"fail/not-found" : func ( t * testing . T ) test {
return test {
db : & db . MockNoSQLDB {
MGet : func ( bucket , key [ ] byte ) ( [ ] byte , error ) {
if bytes . Equal ( bucket , certBySerialTable ) {
return nil , nosqldb . ErrNotFound
}
return nil , errors . New ( "wrong table" )
} ,
} ,
acmeErr : acme . NewError ( acme . ErrorMalformedType , "certificate with serial %s not found" , serial ) ,
}
} ,
"fail/db-error" : func ( t * testing . T ) test {
return test {
db : & db . MockNoSQLDB {
MGet : func ( bucket , key [ ] byte ) ( [ ] byte , error ) {
if bytes . Equal ( bucket , certBySerialTable ) {
return nil , errors . New ( "force" )
}
return nil , errors . New ( "wrong table" )
} ,
} ,
err : fmt . Errorf ( "error loading certificate ID for serial %s" , serial ) ,
}
} ,
"fail/unmarshal-dbSerial" : func ( t * testing . T ) test {
return test {
db : & db . MockNoSQLDB {
MGet : func ( bucket , key [ ] byte ) ( [ ] byte , error ) {
if bytes . Equal ( bucket , certBySerialTable ) {
return [ ] byte ( ` { "serial":malformed!} ` ) , nil
}
return nil , errors . New ( "wrong table" )
} ,
} ,
err : fmt . Errorf ( "error unmarshaling certificate with serial %s" , serial ) ,
}
} ,
"ok" : func ( t * testing . T ) test {
return test {
db : & db . MockNoSQLDB {
MGet : func ( bucket , key [ ] byte ) ( [ ] byte , error ) {
if bytes . Equal ( bucket , certBySerialTable ) {
certSerial := dbSerial {
Serial : serial ,
CertificateID : certID ,
}
b , err := json . Marshal ( certSerial )
assert . FatalError ( t , err )
return b , nil
}
if bytes . Equal ( bucket , certTable ) {
cert := dbCert {
ID : certID ,
AccountID : "accountID" ,
OrderID : "orderID" ,
Leaf : pem . EncodeToMemory ( & pem . Block {
Type : "CERTIFICATE" ,
Bytes : leaf . Raw ,
} ) ,
Intermediates : append ( pem . EncodeToMemory ( & pem . Block {
Type : "CERTIFICATE" ,
Bytes : inter . Raw ,
} ) , pem . EncodeToMemory ( & pem . Block {
Type : "CERTIFICATE" ,
Bytes : root . Raw ,
} ) ... ) ,
CreatedAt : clock . Now ( ) ,
}
b , err := json . Marshal ( cert )
assert . FatalError ( t , err )
return b , nil
}
return nil , errors . New ( "wrong table" )
} ,
} ,
}
} ,
}
for name , prep := range tests {
tc := prep ( t )
t . Run ( name , func ( t * testing . T ) {
d := DB { db : tc . db }
cert , err := d . GetCertificateBySerial ( context . Background ( ) , serial )
if err != nil {
switch k := err . ( type ) {
case * acme . Error :
if assert . NotNil ( t , tc . acmeErr ) {
assert . Equals ( t , k . Type , tc . acmeErr . Type )
assert . Equals ( t , k . Detail , tc . acmeErr . Detail )
assert . Equals ( t , k . Status , tc . acmeErr . Status )
assert . Equals ( t , k . Err . Error ( ) , tc . acmeErr . Err . Error ( ) )
assert . Equals ( t , k . Detail , tc . acmeErr . Detail )
}
default :
if assert . NotNil ( t , tc . err ) {
assert . HasPrefix ( t , err . Error ( ) , tc . err . Error ( ) )
}
}
} else if assert . Nil ( t , tc . err ) {
assert . Equals ( t , cert . ID , certID )
assert . Equals ( t , cert . AccountID , "accountID" )
assert . Equals ( t , cert . OrderID , "orderID" )
assert . Equals ( t , cert . Leaf , leaf )
assert . Equals ( t , cert . Intermediates , [ ] * x509 . Certificate { inter , root } )
}
} )
}
}