scheduler: add batch support for dialing with cache

pull/127/head
Aloïs Micard 3 years ago
parent 829afcbb6a
commit 7820820fa9
No known key found for this signature in database
GPG Key ID: 1A0EB82F071F5EFE

@ -119,7 +119,7 @@ func (state *State) handleTimeoutURLEvent(subscriber event.Subscriber, msg event
cacheKey := u.Hostname()
count, err := state.hostnameCache.GetInt64(cacheKey)
if err != nil && err != cache.ErrNIL {
if err != nil {
return err
}
count++

@ -96,7 +96,7 @@ func TestHandleTimeoutURLEventNoDispatch(t *testing.T) {
configClientMock.EXPECT().GetForbiddenHostnames().Return([]configapi.ForbiddenHostname{}, nil)
configClientMock.EXPECT().GetBlackListThreshold().Return(configapi.BlackListThreshold{Threshold: 10}, nil)
hostnameCacheMock.EXPECT().GetInt64("down-example.onion").Return(int64(0), cache.ErrNIL)
hostnameCacheMock.EXPECT().GetInt64("down-example.onion").Return(int64(0), nil)
hostnameCacheMock.EXPECT().SetInt64("down-example.onion", int64(1), cache.NoTTL).Return(nil)
s := State{configClient: configClientMock, hostnameCache: hostnameCacheMock, httpClient: httpClientMock}

@ -3,15 +3,12 @@ package cache
//go:generate mockgen -destination=../cache_mock/cache_mock.go -package=cache_mock . Cache
import (
"errors"
"time"
)
var (
// NoTTL defines an entry that lives forever
NoTTL = time.Duration(0)
// ErrNIL is returned when there's no value for given key
ErrNIL = errors.New("value is nil")
)
// Cache represents a KV database
@ -21,4 +18,7 @@ type Cache interface {
GetInt64(key string) (int64, error)
SetInt64(key string, value int64, TTL time.Duration) error
GetManyInt64(keys []string) (map[string]int64, error)
SetManyInt64(values map[string]int64, TTL time.Duration) error
}

@ -25,11 +25,11 @@ func NewRedisCache(URI string, keyPrefix string) (Cache, error) {
func (rc *redisCache) GetBytes(key string) ([]byte, error) {
val, err := rc.client.Get(context.Background(), rc.getKey(key)).Bytes()
if err == redis.Nil {
err = ErrNIL
if err != nil && err != redis.Nil {
return nil, err
}
return val, err
return val, nil
}
func (rc *redisCache) SetBytes(key string, value []byte, TTL time.Duration) error {
@ -38,17 +38,60 @@ func (rc *redisCache) SetBytes(key string, value []byte, TTL time.Duration) erro
func (rc *redisCache) GetInt64(key string) (int64, error) {
val, err := rc.client.Get(context.Background(), rc.getKey(key)).Int64()
if err == redis.Nil {
err = ErrNIL
if err != nil && err != redis.Nil {
return 0, err
}
return val, err
return val, nil
}
// SetInt64 stores value under key with the given TTL. The key is passed
// through rc.getKey before being sent to the client, and the error reported
// by the resulting Set command (if any) is returned.
func (rc *redisCache) SetInt64(key string, value int64, TTL time.Duration) error {
	ctx := context.Background()
	status := rc.client.Set(ctx, rc.getKey(key), value, TTL)
	return status.Err()
}
// GetManyInt64 reads the int64 values stored under keys through a pipeline.
// Keys whose lookup yields redis.Nil are left out of the returned map; any
// other error aborts the call and is returned with a nil map.
func (rc *redisCache) GetManyInt64(keys []string) (map[string]int64, error) {
	ctx := context.Background()
	pipe := rc.client.Pipeline()

	// Queue one GET per key and remember the pending command by its
	// caller-facing key so its result can be read back after Exec.
	pending := map[string]*redis.StringCmd{}
	for _, k := range keys {
		pending[k] = pipe.Get(ctx, rc.getKey(k))
	}

	// Submit the queued commands. redis.Nil is tolerated here; it is dealt
	// with per command below.
	_, execErr := pipe.Exec(ctx)
	if execErr != nil && execErr != redis.Nil {
		return nil, execErr
	}

	// Collect the results, keeping only the keys that produced a value.
	found := map[string]int64{}
	for _, k := range keys {
		n, cmdErr := pending[k].Int64()
		switch {
		case cmdErr == nil:
			found[k] = n
		case cmdErr != redis.Nil:
			// Anything other than redis.Nil is a real failure.
			return nil, cmdErr
		}
	}

	return found, nil
}
// SetManyInt64 stores every key/value pair of values with the given TTL.
// The Set commands are queued on a TxPipeline and submitted with a single
// Exec call, whose error (if any) is returned.
func (rc *redisCache) SetManyInt64(values map[string]int64, TTL time.Duration) error {
	ctx := context.Background()
	tx := rc.client.TxPipeline()

	for k, v := range values {
		tx.Set(ctx, rc.getKey(k), v, TTL)
	}

	if _, err := tx.Exec(ctx); err != nil {
		return err
	}
	return nil
}
func (rc *redisCache) getKey(key string) string {
if rc.keyPrefix == "" {
return key

@ -135,7 +135,12 @@ func (state *State) setConfiguration(w http.ResponseWriter, r *http.Request) {
func setDefaultValues(configCache cache.Cache, values map[string]string) error {
for key, value := range values {
if _, err := configCache.GetBytes(key); err == cache.ErrNIL {
b, err := configCache.GetBytes(key)
if err != nil {
return err
}
if b == nil {
if err := configCache.SetBytes(key, []byte(value), cache.NoTTL); err != nil {
return fmt.Errorf("error while setting default value of %s: %s", key, err)
}

@ -92,16 +92,38 @@ func (state *State) handleNewResourceEvent(subscriber event.Subscriber, msg even
return fmt.Errorf("error while extracting URLs")
}
// Load values in batch
urlCache, err := state.urlCache.GetManyInt64(urls)
if err != nil {
return err
}
for _, u := range urls {
if err := state.processURL(u, subscriber); err != nil {
if err := state.processURL(u, subscriber, urlCache); err != nil {
log.Err(err).Msg("error while processing URL")
}
}
// Update URL cache
delay, err := state.configClient.GetRefreshDelay()
if err != nil {
return err
}
ttl := delay.Delay
if ttl == -1 {
ttl = cache.NoTTL
}
// Update values in batch
if err := state.urlCache.SetManyInt64(urlCache, ttl); err != nil {
return err
}
return nil
}
func (state *State) processURL(rawURL string, pub event.Publisher) error {
func (state *State) processURL(rawURL string, pub event.Publisher, urlCache map[string]int64) error {
u, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("error while parsing URL: %s", err)
@ -157,30 +179,13 @@ func (state *State) processURL(rawURL string, pub event.Publisher) error {
}
// Check if URL should be scheduled
count, err := state.urlCache.GetInt64(rawURL)
if err != nil && err != cache.ErrNIL {
return err
}
if count > 0 {
if urlCache[rawURL] > 0 {
return fmt.Errorf("%s %w", u, errAlreadyScheduled)
}
log.Debug().Stringer("url", u).Msg("URL should be scheduled")
// Update URL cache
delay, err := state.configClient.GetRefreshDelay()
if err != nil {
return err
}
ttl := delay.Delay
if ttl == -1 {
ttl = cache.NoTTL
}
if err := state.urlCache.SetInt64(rawURL, count+1, ttl); err != nil {
return err
}
urlCache[rawURL]++
if err := pub.PublishEvent(&event.NewURLEvent{URL: rawURL}); err != nil {
return fmt.Errorf("error while publishing URL: %s", err)

@ -13,7 +13,6 @@ import (
"github.com/creekorful/trandoshan/internal/test"
"github.com/golang/mock/gomock"
"testing"
"time"
)
func TestState_Name(t *testing.T) {
@ -66,7 +65,7 @@ func TestProcessURL_NotDotOnion(t *testing.T) {
for _, url := range urls {
state := State{}
if err := state.processURL(url, nil); !errors.Is(err, errNotOnionHostname) {
if err := state.processURL(url, nil, nil); !errors.Is(err, errNotOnionHostname) {
t.Fail()
}
}
@ -80,7 +79,7 @@ func TestProcessURL_ProtocolForbidden(t *testing.T) {
for _, url := range urls {
state := State{}
if err := state.processURL(url, nil); !errors.Is(err, errProtocolNotAllowed) {
if err := state.processURL(url, nil, nil); !errors.Is(err, errProtocolNotAllowed) {
t.Fail()
}
}
@ -98,7 +97,7 @@ func TestProcessURL_ExtensionForbidden(t *testing.T) {
configClientMock.EXPECT().GetAllowedMimeTypes().Return([]client.MimeType{{Extensions: []string{"html", "php"}}}, nil)
state := State{configClient: configClientMock}
if err := state.processURL(url, nil); !errors.Is(err, errExtensionNotAllowed) {
if err := state.processURL(url, nil, nil); !errors.Is(err, errExtensionNotAllowed) {
t.Fail()
}
}
@ -139,7 +138,7 @@ func TestProcessURL_HostnameForbidden(t *testing.T) {
configClientMock.EXPECT().GetForbiddenHostnames().Return(tst.forbiddenHostnames, nil)
state := State{configClient: configClientMock}
if err := state.processURL(tst.url, nil); !errors.Is(err, errHostnameNotAllowed) {
if err := state.processURL(tst.url, nil, nil); !errors.Is(err, errHostnameNotAllowed) {
t.Fail()
}
}
@ -149,42 +148,95 @@ func TestProcessURL_AlreadyScheduled(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
urlCacheMock := cache_mock.NewMockCache(mockCtrl)
configClientMock := client_mock.NewMockClient(mockCtrl)
urlCacheMock.EXPECT().GetInt64("https://facebookcorewwi.onion/test.php?id=12").Return(int64(1), nil)
configClientMock.EXPECT().GetAllowedMimeTypes().Return([]client.MimeType{{Extensions: []string{"html", "php"}}}, nil)
configClientMock.EXPECT().GetForbiddenHostnames().Return([]client.ForbiddenHostname{}, nil)
state := State{urlCache: urlCacheMock, configClient: configClientMock}
if err := state.processURL("https://facebookcorewwi.onion/test.php?id=12", nil); !errors.Is(err, errAlreadyScheduled) {
urlCache := map[string]int64{"https://facebookcorewwi.onion/test.php?id=12": 1}
state := State{configClient: configClientMock}
if err := state.processURL("https://facebookcorewwi.onion/test.php?id=12", nil, urlCache); !errors.Is(err, errAlreadyScheduled) {
t.Fail()
}
}
func TestHandleNewResourceEvent(t *testing.T) {
func TestProcessURL(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
urlCacheMock := cache_mock.NewMockCache(mockCtrl)
configClientMock := client_mock.NewMockClient(mockCtrl)
pubMock := event_mock.NewMockPublisher(mockCtrl)
urls := []string{"https://example.onion/index.php", "http://google.onion/admin.secret/login.html",
"https://example.onion", "https://www.facebookcorewwwi.onion/recover.now/initiate?ars=facebook_login"}
// pre fill cache
urlCache := map[string]int64{}
for _, url := range urls {
urlCacheMock.EXPECT().GetInt64(url).Return(int64(0), cache.ErrNIL)
configClientMock.EXPECT().GetAllowedMimeTypes().Return([]client.MimeType{{Extensions: []string{"html", "php"}}}, nil)
configClientMock.EXPECT().GetForbiddenHostnames().Return([]client.ForbiddenHostname{}, nil)
configClientMock.EXPECT().GetRefreshDelay().Return(client.RefreshDelay{Delay: 10 * time.Hour}, nil)
urlCacheMock.EXPECT().SetInt64(url, int64(1), time.Duration(10*time.Hour)).Return(nil)
pubMock.EXPECT().PublishEvent(&event.NewURLEvent{URL: url}).Return(nil)
state := State{urlCache: urlCacheMock, configClient: configClientMock}
if err := state.processURL(url, pubMock); err != nil {
state := State{configClient: configClientMock}
if err := state.processURL(url, pubMock, urlCache); err != nil {
t.Fail()
}
if val, exist := urlCache[url]; !exist || val != 1 {
t.Fail()
}
}
}
// TestHandleNewResourceEvent exercises handleNewResourceEvent with mocked
// subscriber, URL cache and config client. It expects the URL cache to be
// read once through GetManyInt64, exactly one NewURLEvent to be published,
// and the cache to be written back once through SetManyInt64.
func TestHandleNewResourceEvent(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	subscriberMock := event_mock.NewMockSubscriber(mockCtrl)
	urlCacheMock := cache_mock.NewMockCache(mockCtrl)
	configClientMock := client_mock.NewMockClient(mockCtrl)

	// The incoming resource body references four URLs.
	msg := event.RawMessage{}
	subscriberMock.EXPECT().
		Read(&msg, &event.NewResourceEvent{}).
		SetArg(1, event.NewResourceEvent{
			URL: "https://l.facebookcorewwwi.onion/test.php",
			Body: `
<a href=\"https://facebook.onion/test.php?id=1\">This is a little test</a>.
Check out https://google.onion. This is an image https://example.onion/test.png
This domain is blacklisted: https://m.fbi.onion/test.php
`,
		}).
		Return(nil)

	// All extracted URLs are looked up in one batch; only google.onion is
	// reported as already present in the cache.
	urlCacheMock.EXPECT().
		GetManyInt64([]string{"https://facebook.onion/test.php?id=1", "https://google.onion", "https://example.onion/test.png", "https://m.fbi.onion/test.php"}).
		Return(map[string]int64{
			"https://google.onion": 1,
		}, nil)

	// NOTE(review): presumably one call per extracted URL — confirm against processURL.
	configClientMock.EXPECT().GetAllowedMimeTypes().
		Times(4).
		Return([]client.MimeType{{Extensions: []string{"php"}}}, nil)
	// NOTE(review): one call fewer than above; presumably one URL is rejected
	// before the hostname check — confirm against processURL.
	configClientMock.EXPECT().GetForbiddenHostnames().
		Times(3).
		Return([]client.ForbiddenHostname{
			{Hostname: "fbi.onion"},
		}, nil)

	// A delay of -1 is expected to be translated to cache.NoTTL (see the
	// SetManyInt64 expectation below).
	configClientMock.EXPECT().GetRefreshDelay().Return(client.RefreshDelay{Delay: -1}, nil)

	// Only the facebook.onion URL is expected to be published.
	subscriberMock.EXPECT().PublishEvent(&event.NewURLEvent{
		URL: "https://facebook.onion/test.php?id=1",
	})

	// The batch write carries the pre-existing entry plus the newly scheduled URL.
	urlCacheMock.EXPECT().SetManyInt64(map[string]int64{
		"https://google.onion":                 1,
		"https://facebook.onion/test.php?id=1": 1,
	}, cache.NoTTL).Return(nil)

	s := State{urlCache: urlCacheMock, configClient: configClientMock}
	if err := s.handleNewResourceEvent(subscriberMock, msg); err != nil {
		// Report the error so a failure is diagnosable from the test output.
		t.Errorf("handleNewResourceEvent() returned unexpected error: %v", err)
	}
}

Loading…
Cancel
Save