diff --git a/db/db.go b/db/db.go index 602e3623..eccaf801 100644 --- a/db/db.go +++ b/db/db.go @@ -243,9 +243,7 @@ type ProvisionerData struct { // authorized the certificate. func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error { leaf := chain[0] - if err := db.StoreCertificate(leaf); err != nil { - return err - } + serialNumber := []byte(leaf.SerialNumber.String()) data := &CertificateData{} if p != nil { data.Provisioner = &ProvisionerData{ @@ -254,13 +252,16 @@ func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Cert Type: p.GetType().String(), } } - b, err := json.Marshal(data) if err != nil { return errors.Wrap(err, "error marshaling json") } - if err := db.Set(certsDataTable, []byte(leaf.SerialNumber.String()), b); err != nil { - return errors.Wrap(err, "database Set error") + // Add certificate and certificate data in one transaction. + tx := new(database.Tx) + tx.Set(certsTable, serialNumber, leaf.Raw) + tx.Set(certsDataTable, serialNumber, b) + if err := db.Update(tx); err != nil { + return errors.Wrap(err, "database Update error") } return nil } diff --git a/db/db_test.go b/db/db_test.go index d7c58c9c..b4515a5b 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -188,53 +188,36 @@ func TestDB_StoreCertificateChain(t *testing.T) { wantErr bool }{ {"ok", fields{&MockNoSQLDB{ - MSet: func(bucket, key, value []byte) error { - switch string(bucket) { - case "x509_certs": - assert.Equals(t, key, []byte("1234")) - assert.Equals(t, value, []byte("the certificate")) - case "x509_certs_data": - assert.Equals(t, key, []byte("1234")) - assert.Equals(t, value, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`)) - default: - t.Errorf("unexpected bucket %s", bucket) + MUpdate: func(tx *database.Tx) error { + if len(tx.Operations) != 2 { + t.Fatal("unexpected number of operations") } + assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket) + assert.Equals(t, []byte("1234"), tx.Operations[0].Key) + assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value) + assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket) + assert.Equals(t, []byte("1234"), tx.Operations[1].Key) + assert.Equals(t, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`), tx.Operations[1].Value) return nil }, }, true}, args{p, chain}, false}, {"ok no provisioner", fields{&MockNoSQLDB{ - MSet: func(bucket, key, value []byte) error { - switch string(bucket) { - case "x509_certs": - assert.Equals(t, key, []byte("1234")) - assert.Equals(t, value, []byte("the certificate")) - case "x509_certs_data": - assert.Equals(t, key, []byte("1234")) - assert.Equals(t, value, []byte(`{}`)) - default: - t.Errorf("unexpected bucket %s", bucket) + MUpdate: func(tx *database.Tx) error { + if len(tx.Operations) != 2 { + t.Fatal("unexpected number of operations") } + assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket) + assert.Equals(t, []byte("1234"), tx.Operations[0].Key) + assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value) + assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket) + assert.Equals(t, []byte("1234"), tx.Operations[1].Key) + assert.Equals(t, []byte(`{}`), tx.Operations[1].Value) return nil }, }, true}, args{nil, chain}, false}, {"fail store certificate", fields{&MockNoSQLDB{ - MSet: func(bucket, key, value []byte) error { - switch string(bucket) { - case "x509_certs": - return errors.New("test error") - default: - return nil - } - }, - }, true}, args{p, chain}, true}, - {"fail store provisioner", fields{&MockNoSQLDB{ - MSet: func(bucket, key, value []byte) error { - switch string(bucket) { - case "x509_certs_data": - return errors.New("test error") - default: - return nil - } + MUpdate: func(tx *database.Tx) error { + return errors.New("test error") }, }, true}, args{p, chain}, true}, }