Implement API authentication

Also split source code into new architecture + start writing tests
pull/30/head
Aloïs Micard 4 years ago
parent c2e501d0c2
commit fd32c66774
No known key found for this signature in database
GPG Key ID: 1A0EB82F071F5EFE

1
.gitignore vendored

@ -1 +1,2 @@
.idea/
**/**_mock.go

@ -5,6 +5,8 @@ go 1.14
require (
github.com/PuerkitoBio/purell v1.1.1
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/golang/mock v1.4.4
github.com/golang/protobuf v1.4.2 // indirect
github.com/labstack/echo/v4 v4.1.16
github.com/nats-io/nats-server/v2 v2.1.8 // indirect
@ -15,5 +17,6 @@ require (
github.com/urfave/cli/v2 v2.2.0
github.com/valyala/fasthttp v1.9.0
github.com/xhit/go-str2duration/v2 v2.0.0
golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59
mvdan.cc/xurls/v2 v2.1.0
)

@ -18,7 +18,10 @@ github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHqu
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1 h1:G5FRp8JnTd7RQH5kemVNlMeyXQAztQ3mOWV95KxsXH8=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc=
github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
@ -149,6 +152,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74 h1:4cFkmztxtMslUX2SctSl+blCyXfpzhGOy9LhKAqSMA4=
golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=

@ -1,43 +1,18 @@
package api
import (
"context"
"encoding/base64"
"encoding/json"
"github.com/creekorful/trandoshan/api"
"github.com/creekorful/trandoshan/internal/messaging"
"github.com/creekorful/trandoshan/internal/util/logging"
natsutil "github.com/creekorful/trandoshan/internal/util/nats"
logging2 "github.com/creekorful/trandoshan/internal/logging"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/nats-io/nats.go"
"github.com/olivere/elastic/v7"
"github.com/rs/zerolog/log"
"github.com/urfave/cli/v2"
"net/http"
"strconv"
"time"
)
var (
resourcesIndex = "resources"
defaultPaginationSize = 50
maxPaginationSize = 100
)
type pagination struct {
page int
size int
}
// Represent a resource in elasticsearch
type resourceIndex struct {
URL string `json:"url"`
Body string `json:"body"`
Title string `json:"title"`
Time time.Time `json:"time"`
}
// GetApp return the api app
func GetApp() *cli.App {
return &cli.App{
@ -45,7 +20,7 @@ func GetApp() *cli.App {
Version: "0.4.0",
Usage: "Trandoshan API component",
Flags: []cli.Flag{
logging.GetLogFlag(),
logging2.GetLogFlag(),
&cli.StringFlag{
Name: "nats-uri",
Usage: "URI to the NATS server",
@ -67,7 +42,7 @@ func GetApp() *cli.App {
}
func execute(c *cli.Context) error {
logging.ConfigureLogger(c)
logging2.ConfigureLogger(c)
e := echo.New()
e.HideBanner = true
@ -77,264 +52,25 @@ func execute(c *cli.Context) error {
log.Debug().Str("uri", c.String("elasticsearch-uri")).Msg("Using Elasticsearch server")
log.Debug().Str("uri", c.String("nats-uri")).Msg("Using NATS server")
// Connect to the NATS server
nc, err := nats.Connect(c.String("nats-uri"))
if err != nil {
log.Err(err).Str("uri", c.String("nats-uri")).Msg("Error while connecting to NATS server")
return err
}
defer nc.Close()
signingKey := []byte(c.String("signing-key"))
// Create Elasticsearch client
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
es, err := elastic.DialContext(ctx,
elastic.SetURL(c.String("elasticsearch-uri")),
elastic.SetSniff(false),
elastic.SetHealthcheck(false),
)
// Create the service
svc, err := newService(c, signingKey)
if err != nil {
log.Err(err).Msg("Error while creating ES client")
return err
}
// Setup ES one for all
if err := setupElasticSearch(ctx, es); err != nil {
log.Err(err).Msg("Unable to start API")
return err
}
// Setup middlewares
e.Use(middleware.JWT([]byte(c.String("signing-key"))))
e.Use(middleware.JWT(signingKey))
// Add endpoints
e.GET("/v1/resources", searchResources(es))
e.POST("/v1/resources", addResource(es))
e.POST("/v1/urls", scheduleURL(nc))
e.POST("/v1/sessions", authenticate())
e.GET("/v1/resources", searchResourcesEndpoint(svc))
e.POST("/v1/resources", addResourceEndpoint(svc))
e.POST("/v1/urls", scheduleURLEndpoint(svc))
e.POST("/v1/sessions", authenticateEndpoint(svc))
log.Info().Msg("Successfully initialized tdsh-api. Waiting for requests")
return e.Start(":8080")
}
func searchResources(es *elastic.Client) echo.HandlerFunc {
return func(c echo.Context) error {
withBody := false
if c.QueryParam("with-body") == "true" {
withBody = true
}
startDate := time.Time{}
if val := c.QueryParam("start-date"); val != "" {
d, err := time.Parse(time.RFC3339, val)
if err == nil {
startDate = d
}
}
endDate := time.Time{}
if val := c.QueryParam("end-date"); val != "" {
d, err := time.Parse(time.RFC3339, val)
if err == nil {
endDate = d
}
}
// First of all base64decode the URL
b64URL := c.QueryParam("url")
b, err := base64.URLEncoding.DecodeString(b64URL)
if err != nil {
log.Err(err).Msg("Error while decoding URL")
return c.NoContent(http.StatusUnprocessableEntity)
}
// Acquire pagination
p := readPagination(c)
from := (p.page - 1) * p.size
// Build up search query
query := buildSearchQuery(string(b), c.QueryParam("keyword"), startDate, endDate)
// Get total count
totalCount, err := es.Count(resourcesIndex).Query(query).Do(context.Background())
if err != nil {
log.Err(err).Msg("Error while counting on ES")
return c.NoContent(http.StatusInternalServerError)
}
// Perform the search request.
res, err := es.Search().
Index(resourcesIndex).
Query(query).
From(from).
Size(p.size).
Do(context.Background())
if err != nil {
log.Err(err).Msg("Error while searching on ES")
return c.NoContent(http.StatusInternalServerError)
}
var resources []api.ResourceDto
for _, hit := range res.Hits.Hits {
var resource api.ResourceDto
if err := json.Unmarshal(hit.Source, &resource); err != nil {
log.Warn().Str("err", err.Error()).Msg("Error while un-marshaling resource")
continue
}
// Remove body if not wanted
if !withBody {
resource.Body = ""
}
resources = append(resources, resource)
}
// Write pagination
writePagination(c, p, totalCount)
return c.JSON(http.StatusOK, resources)
}
}
func addResource(es *elastic.Client) echo.HandlerFunc {
return func(c echo.Context) error {
var resourceDto api.ResourceDto
if err := json.NewDecoder(c.Request().Body).Decode(&resourceDto); err != nil {
log.Err(err).Msg("Error while un-marshaling resource")
return c.NoContent(http.StatusUnprocessableEntity)
}
log.Debug().Str("url", resourceDto.URL).Msg("Saving resource")
// Create Elasticsearch document
doc := resourceIndex{
URL: resourceDto.URL,
Body: resourceDto.Body,
Title: resourceDto.Title,
Time: resourceDto.Time,
}
_, err := es.Index().
Index(resourcesIndex).
BodyJson(doc).
Do(context.Background())
if err != nil {
log.Err(err).Msg("Error while creating ES document")
return err
}
log.Debug().Str("url", resourceDto.URL).Msg("Successfully saved resource")
return c.JSON(http.StatusCreated, resourceDto)
}
}
func buildSearchQuery(url, keyword string, startDate, endDate time.Time) elastic.Query {
var queries []elastic.Query
if url != "" {
log.Trace().Str("url", url).Msg("SearchQuery: Setting url")
queries = append(queries, elastic.NewTermQuery("url", url))
}
if keyword != "" {
log.Trace().Str("body", keyword).Msg("SearchQuery: Setting body")
queries = append(queries, elastic.NewTermQuery("body", keyword))
}
if !startDate.IsZero() || !endDate.IsZero() {
timeQuery := elastic.NewRangeQuery("time")
if !startDate.IsZero() {
log.Trace().Str("startDate", startDate.Format(time.RFC3339)).Msg("SearchQuery: Setting startDate")
timeQuery.Gte(startDate.Format(time.RFC3339))
}
if !endDate.IsZero() {
log.Trace().Str("endDate", endDate.Format(time.RFC3339)).Msg("SearchQuery: Setting endDate")
timeQuery.Lte(endDate.Format(time.RFC3339))
}
queries = append(queries, timeQuery)
}
// Handle specific case
if len(queries) == 0 {
return elastic.NewMatchAllQuery()
}
if len(queries) == 1 {
return queries[0]
}
// otherwise AND combine them
return elastic.NewBoolQuery().Must(queries...)
}
func scheduleURL(nc *nats.Conn) echo.HandlerFunc {
return func(c echo.Context) error {
var url string
if err := json.NewDecoder(c.Request().Body).Decode(&url); err != nil {
log.Err(err).Msg("Error while un-marshaling URL")
return c.NoContent(http.StatusUnprocessableEntity)
}
// Publish the URL
if err := natsutil.PublishMsg(nc, &messaging.URLFoundMsg{URL: url}); err != nil {
log.Err(err).Msg("Unable to publish URL")
return c.NoContent(http.StatusInternalServerError)
}
log.Debug().Str("url", url).Msg("Successfully published URL")
return nil
}
}
func authenticate() echo.HandlerFunc {
return func(c echo.Context) error {
return nil // TODO
}
}
func setupElasticSearch(ctx context.Context, es *elastic.Client) error {
// Setup index if doesn't exist
exist, err := es.IndexExists(resourcesIndex).Do(ctx)
if err != nil {
log.Err(err).Str("index", resourcesIndex).Msg("Error while checking if index exist")
return err
}
if !exist {
log.Debug().Str("index", resourcesIndex).Msg("Creating missing index")
if _, err := es.CreateIndex(resourcesIndex).Do(ctx); err != nil {
log.Err(err).Str("index", resourcesIndex).Msg("Error while creating index")
return err
}
} else {
log.Debug().Msg("index exist")
}
return nil
}
func readPagination(c echo.Context) pagination {
paginationPage, err := strconv.Atoi(c.QueryParam(api.PaginationPageQueryParam))
if err != nil {
paginationPage = 1
}
paginationSize, err := strconv.Atoi(c.QueryParam(api.PaginationSizeQueryParam))
if err != nil {
paginationSize = defaultPaginationSize
}
// Prevent too much results from being returned
if paginationSize > maxPaginationSize {
paginationSize = maxPaginationSize
}
return pagination{
page: paginationPage,
size: paginationSize,
}
}
func writePagination(c echo.Context, p pagination, totalCount int64) {
c.Response().Header().Set(api.PaginationPageHeader, strconv.Itoa(p.page))
c.Response().Header().Set(api.PaginationSizeHeader, strconv.Itoa(p.size))
c.Response().Header().Set(api.PaginationCountHeader, strconv.FormatInt(totalCount, 10))
}

@ -0,0 +1,155 @@
package api
import (
"encoding/base64"
"github.com/creekorful/trandoshan/api"
"github.com/creekorful/trandoshan/internal/database"
"github.com/labstack/echo/v4"
"net/http"
"strconv"
"strings"
"time"
)
func searchResourcesEndpoint(s service) echo.HandlerFunc {
return func(c echo.Context) error {
searchParams, err := newSearchParams(c)
if err != nil {
return err
}
resources, total, err := s.searchResources(searchParams)
if err != nil {
return err
}
writePagination(c, searchParams, total)
return c.JSON(http.StatusOK, resources)
}
}
func addResourceEndpoint(s service) echo.HandlerFunc {
return func(c echo.Context) error {
var res api.ResourceDto
if err := c.Bind(&res); err != nil {
return err
}
res, err := s.addResource(res)
if err != nil {
return err
}
return c.JSON(http.StatusCreated, res)
}
}
func scheduleURLEndpoint(s service) echo.HandlerFunc {
return func(c echo.Context) error {
var url string
if err := c.Bind(&url); err != nil {
return err
}
return s.scheduleURL(url)
}
}
func authenticateEndpoint(s service) echo.HandlerFunc {
return func(c echo.Context) error {
// Validate provided credentials
var credentials api.CredentialsDto
if err := c.Bind(&credentials); err != nil {
return err
}
token, err := s.authenticate(credentials)
if err != nil {
return err
}
return c.JSON(http.StatusOK, token)
}
}
func readPagination(c echo.Context) (int, int) {
paginationPage, err := strconv.Atoi(c.QueryParam(api.PaginationPageQueryParam))
if err != nil {
paginationPage = 1
}
paginationSize, err := strconv.Atoi(c.QueryParam(api.PaginationSizeQueryParam))
if err != nil {
paginationSize = defaultPaginationSize
}
// Prevent too much results from being returned
if paginationSize > maxPaginationSize {
paginationSize = maxPaginationSize
}
return paginationPage, paginationSize
}
func writePagination(c echo.Context, s *database.ResSearchParams, totalCount int64) {
c.Response().Header().Set(api.PaginationPageHeader, strconv.Itoa(s.PageNumber))
c.Response().Header().Set(api.PaginationSizeHeader, strconv.Itoa(s.PageSize))
c.Response().Header().Set(api.PaginationCountHeader, strconv.FormatInt(totalCount, 10))
}
func newSearchParams(c echo.Context) (*database.ResSearchParams, error) {
params := &database.ResSearchParams{}
params.Keyword = c.QueryParam("keyword")
if c.QueryParam("with-body") == "true" {
params.WithBody = true
}
// extract raw query params (unescaped to keep + sign when parsing date)
rawQueryParams := getRawQueryParam(c.QueryString())
if val, exist := rawQueryParams["start-date"]; exist {
d, err := time.Parse(time.RFC3339, val)
if err == nil {
params.StartDate = d
} else {
return nil, err
}
}
if val, exist := rawQueryParams["end-date"]; exist {
d, err := time.Parse(time.RFC3339, val)
if err == nil {
params.EndDate = d
} else {
return nil, err
}
}
// First of all base64decode the URL
b64URL := c.QueryParam("url")
b, err := base64.URLEncoding.DecodeString(b64URL)
if err != nil {
return nil, err
}
params.URL = string(b)
// Acquire pagination
page, size := readPagination(c)
params.PageNumber = page
params.PageSize = size
return params, nil
}
func getRawQueryParam(url string) map[string]string {
val := map[string]string{}
parts := strings.Split(url, "&")
for _, part := range parts {
p := strings.Split(part, "=")
val[p[0]] = p[1]
}
return val
}

@ -0,0 +1,62 @@
package api
import (
"fmt"
"github.com/labstack/echo/v4"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestNewSearchParams(t *testing.T) {
e := echo.New()
startDate := time.Now()
target := fmt.Sprintf("/resources?with-body=true&pagination-page=1&keyword=keyword&url=dXJs&start-date=%s", startDate.Format(time.RFC3339))
req := httptest.NewRequest(http.MethodPost, target, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
params, err := newSearchParams(c)
if err != nil {
t.Errorf("error while parsing search params: %s", err)
t.FailNow()
}
if !params.WithBody {
t.Errorf("wrong withBody: %v", params.WithBody)
}
if params.PageSize != 50 {
t.Errorf("wrong pagination-size: %d", params.PageSize)
}
if params.PageNumber != 1 {
t.Errorf("wrong pagination-page: %d", params.PageNumber)
}
if params.Keyword != "keyword" {
t.Errorf("wrong keyword: %s", params.Keyword)
}
if params.StartDate.Year() != startDate.Year() {
t.Errorf("wrong start-date (year)")
}
if params.StartDate.Month() != startDate.Month() {
t.Errorf("wrong start-date (month)")
}
if params.StartDate.Day() != startDate.Day() {
t.Errorf("wrong start-date (day)")
}
if params.StartDate.Hour() != startDate.Hour() {
t.Errorf("wrong start-date (hour)")
}
if params.StartDate.Minute() != startDate.Minute() {
t.Errorf("wrong start-date (minute)")
}
if params.StartDate.Second() != startDate.Second() {
t.Errorf("wrong start-date (second)")
}
if params.URL != "url" {
t.Errorf("wrong url: %s", params.URL)
}
}

@ -0,0 +1,143 @@
package api
import (
"github.com/creekorful/trandoshan/api"
"github.com/creekorful/trandoshan/internal/database"
"github.com/creekorful/trandoshan/internal/messaging"
"github.com/dgrijalva/jwt-go"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"github.com/urfave/cli/v2"
"golang.org/x/crypto/bcrypt"
"net/http"
)
type service interface {
searchResources(params *database.ResSearchParams) ([]api.ResourceDto, int64, error)
addResource(res api.ResourceDto) (api.ResourceDto, error)
scheduleURL(url string) error
authenticate(credentials api.CredentialsDto) (string, error)
close()
}
type svc struct {
users map[string][]byte
signingKey []byte
db database.Database
pub messaging.Publisher
}
func newService(c *cli.Context, signingKey []byte) (service, error) {
// Connect to the NATS server
pub, err := messaging.NewPublisher(c.String("nats-uri"))
if err != nil {
log.Err(err).Str("uri", c.String("nats-uri")).Msg("Error while connecting to NATS server")
return nil, err
}
// Create Elasticsearch client
db, err := database.NewElasticDB(c.String("elasticsearch-uri"))
if err != nil {
log.Err(err).Msg("Error while connecting to the database")
return nil, err
}
return &svc{
db: db,
users: map[string][]byte{},
signingKey: signingKey,
pub: pub,
}, nil
}
func (s *svc) searchResources(params *database.ResSearchParams) ([]api.ResourceDto, int64, error) {
totalCount, err := s.db.CountResources(params)
if err != nil {
log.Err(err).Msg("Error while counting on ES")
return nil, 0, err
}
res, err := s.db.SearchResources(params)
if err != nil {
log.Err(err).Msg("Error while searching on ES")
return nil, 0, err
}
var resources []api.ResourceDto
for _, r := range res {
resources = append(resources, api.ResourceDto{
URL: r.URL,
Body: r.Body,
Title: r.Title,
Time: r.Time,
})
}
return resources, totalCount, nil
}
func (s *svc) addResource(res api.ResourceDto) (api.ResourceDto, error) {
log.Debug().Str("url", res.URL).Msg("Saving resource")
// Create Elasticsearch document
doc := database.ResourceIdx{
URL: res.URL,
Body: res.Body,
Title: res.Title,
Time: res.Time,
}
if err := s.db.AddResource(doc); err != nil {
log.Err(err).Msg("Error while adding resource")
return api.ResourceDto{}, err
}
log.Debug().Str("url", res.URL).Msg("Successfully saved resource")
return res, nil
}
func (s *svc) scheduleURL(url string) error {
// Publish the URL
if err := s.pub.PublishMsg(&messaging.URLFoundMsg{URL: url}); err != nil {
log.Err(err).Msg("Unable to publish URL")
return err
}
log.Debug().Str("url", url).Msg("Successfully published URL")
return nil
}
func (s *svc) authenticate(credentials api.CredentialsDto) (string, error) {
if credentials.Username == "" || credentials.Password == "" {
log.Warn().Msg("Invalid credentials supplied")
return "", echo.NewHTTPError(http.StatusUnprocessableEntity)
}
// Try to find the user
pass, exists := s.users[credentials.Username]
if !exists {
log.Warn().Str("username", credentials.Username).Msg("No user found")
return "", echo.NewHTTPError(http.StatusUnprocessableEntity)
}
// Validate provided password
if err := bcrypt.CompareHashAndPassword(pass, []byte(credentials.Password)); err != nil {
log.Warn().Str("username", credentials.Username).Msg("Invalid password")
return "", echo.NewHTTPError(http.StatusUnauthorized)
}
log.Debug().Str("username", credentials.Username).Msg("Successfully logged-in")
// Build JWT token
claims := jwt.MapClaims{
"username": credentials.Username,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
// Sign JWT token
return token.SignedString(s.signingKey)
}
func (s *svc) close() {
s.pub.Close()
}

@ -0,0 +1,154 @@
package api
import (
"github.com/creekorful/trandoshan/api"
"github.com/creekorful/trandoshan/internal/database"
"github.com/creekorful/trandoshan/internal/database_mock"
"github.com/creekorful/trandoshan/internal/messaging"
"github.com/creekorful/trandoshan/internal/messaging_mock"
"github.com/dgrijalva/jwt-go"
"github.com/golang/mock/gomock"
"testing"
"time"
)
func TestSearchResources(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
params := &database.ResSearchParams{Keyword: "example"}
dbMock := database_mock.NewMockDatabase(mockCtrl)
dbMock.EXPECT().CountResources(params).Return(int64(150), nil)
dbMock.EXPECT().SearchResources(params).Return([]database.ResourceIdx{
{
URL: "example-1.onion",
Body: "Example 1",
Title: "Example 1",
Time: time.Time{},
},
{
URL: "example-2.onion",
Body: "Example 2",
Title: "Example 2",
Time: time.Time{},
},
}, nil)
s := svc{db: dbMock}
res, count, err := s.searchResources(params)
if err != nil {
t.FailNow()
}
if count != 150 {
t.Error()
}
if len(res) != 2 {
t.Error()
}
}
func TestAddResource(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
dbMock := database_mock.NewMockDatabase(mockCtrl)
dbMock.EXPECT().AddResource(database.ResourceIdx{
URL: "example.onion",
Body: "TheBody",
Title: "Example",
Time: time.Time{},
})
s := svc{db: dbMock}
res, err := s.addResource(api.ResourceDto{
URL: "example.onion",
Body: "TheBody",
Title: "Example",
Time: time.Time{},
})
if err != nil {
t.FailNow()
}
if res.URL != "example.onion" {
t.FailNow()
}
if res.Body != "TheBody" {
t.FailNow()
}
if res.Title != "Example" {
t.FailNow()
}
if !res.Time.IsZero() {
t.FailNow()
}
}
func TestScheduleURL(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
pubMock := messaging_mock.NewMockPublisher(mockCtrl)
s := svc{pub: pubMock}
pubMock.EXPECT().PublishMsg(&messaging.URLFoundMsg{URL: "example.onion"})
if err := s.scheduleURL("example.onion"); err != nil {
t.FailNow()
}
}
func TestAuthenticateInvalidCredentials(t *testing.T) {
s := svc{}
if _, err := s.authenticate(api.CredentialsDto{}); err == nil {
t.FailNow()
}
}
func TestAuthenticateWrongCredentials(t *testing.T) {
s := svc{users: map[string][]byte{"creekorful": []byte("")}}
if _, err := s.authenticate(api.CredentialsDto{Username: "johndoe", Password: "test"}); err == nil {
t.FailNow()
}
if _, err := s.authenticate(api.CredentialsDto{Username: "creekorful", Password: "tes"}); err == nil {
t.FailNow()
}
}
func TestAuthenticate(t *testing.T) {
s := svc{
users: map[string][]byte{
"creekorful": []byte("$2a$10$aLX2t8JsTOoy9iRLBNm.RuPMmcA8MCXijuzhLvUwUbSlh.C/D2eLm")},
signingKey: []byte("secret"),
}
tokenStr, err := s.authenticate(api.CredentialsDto{Username: "creekorful", Password: "test"})
if err != nil {
t.FailNow()
}
claims := jwt.MapClaims{}
token, err := jwt.ParseWithClaims(tokenStr, claims, func(token *jwt.Token) (interface{}, error) {
return []byte("secret"), nil
})
if err != nil {
t.Error(err)
t.FailNow()
}
if token.Header["alg"] != jwt.SigningMethodHS256.Alg() {
t.Errorf("Invalid alg: %s", token.Header["alg"])
}
if claims["username"] != "creekorful" {
t.Errorf("Invalid username: %s", claims["username"])
}
}

@ -3,9 +3,8 @@ package crawler
import (
"crypto/tls"
"fmt"
logging2 "github.com/creekorful/trandoshan/internal/logging"
"github.com/creekorful/trandoshan/internal/messaging"
"github.com/creekorful/trandoshan/internal/util/logging"
natsutil "github.com/creekorful/trandoshan/internal/util/nats"
"github.com/nats-io/nats.go"
"github.com/rs/zerolog/log"
"github.com/urfave/cli/v2"
@ -24,7 +23,7 @@ func GetApp() *cli.App {
Version: "0.4.0",
Usage: "Trandoshan crawler component",
Flags: []cli.Flag{
logging.GetLogFlag(),
logging2.GetLogFlag(),
&cli.StringFlag{
Name: "nats-uri",
Usage: "URI to the NATS server",
@ -51,7 +50,7 @@ func GetApp() *cli.App {
}
func execute(ctx *cli.Context) error {
logging.ConfigureLogger(ctx)
logging2.ConfigureLogger(ctx)
log.Info().Str("ver", ctx.App.Version).Msg("Starting tdsh-crawler")
@ -71,7 +70,7 @@ func execute(ctx *cli.Context) error {
}
// Create the NATS subscriber
sub, err := natsutil.NewSubscriber(ctx.String("nats-uri"))
sub, err := messaging.NewSubscriber(ctx.String("nats-uri"))
if err != nil {
return err
}
@ -87,10 +86,10 @@ func execute(ctx *cli.Context) error {
return nil
}
func handleMessage(httpClient *fasthttp.Client, allowedContentTypes []string) natsutil.MsgHandler {
return func(nc *nats.Conn, msg *nats.Msg) error {
func handleMessage(httpClient *fasthttp.Client, allowedContentTypes []string) messaging.MsgHandler {
return func(sub messaging.Subscriber, msg *nats.Msg) error {
var urlMsg messaging.URLTodoMsg
if err := natsutil.ReadMsg(msg, &urlMsg); err != nil {
if err := sub.ReadMsg(msg, &urlMsg); err != nil {
return err
}
@ -105,7 +104,7 @@ func handleMessage(httpClient *fasthttp.Client, allowedContentTypes []string) na
URL: urlMsg.URL,
Body: body,
}
if err := natsutil.PublishMsg(nc, &res); err != nil {
if err := sub.PublishMsg(&res); err != nil {
log.Err(err).Msg("Error while publishing resource body")
}

@ -0,0 +1,182 @@
package database
import (
"context"
"encoding/json"
"github.com/olivere/elastic/v7"
"github.com/rs/zerolog/log"
"time"
)
//go:generate mockgen -source database.go -destination=../database_mock/database_mock.go -package=database_mock
var resourcesIndex = "resources"
// ResourceIdx represent a resource as stored in elasticsearch
type ResourceIdx struct {
URL string `json:"url"`
Body string `json:"body"`
Title string `json:"title"`
Time time.Time `json:"time"`
}
// ResSearchParams is the search params used
type ResSearchParams struct {
URL string
Keyword string
StartDate time.Time
EndDate time.Time
WithBody bool
PageSize int
PageNumber int
}
// Database is the interface used to abstract communication
// with the persistence unit
type Database interface {
SearchResources(params *ResSearchParams) ([]ResourceIdx, error)
CountResources(params *ResSearchParams) (int64, error)
AddResource(res ResourceIdx) error
}
type elasticSearchDB struct {
client *elastic.Client
}
// NewElasticDB create a new Database based on ES instance
func NewElasticDB(uri string) (Database, error) {
// Create Elasticsearch client
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
ec, err := elastic.DialContext(ctx,
elastic.SetURL(uri),
elastic.SetSniff(false),
elastic.SetHealthcheck(false),
)
if err != nil {
log.Err(err).Msg("Error while creating ES client")
return nil, err
}
if err := setupElasticSearch(ctx, ec); err != nil {
return nil, err
}
return &elasticSearchDB{
client: ec,
}, nil
}
func (e *elasticSearchDB) SearchResources(params *ResSearchParams) ([]ResourceIdx, error) {
q := buildSearchQuery(params)
from := (params.PageNumber - 1) * params.PageSize
res, err := e.client.Search().
Index(resourcesIndex).
Query(q).
From(from).
Size(params.PageSize).
Do(context.Background())
if err != nil {
log.Err(err).Msg("Error while searching on ES")
return nil, err
}
var resources []ResourceIdx
for _, hit := range res.Hits.Hits {
var resource ResourceIdx
if err := json.Unmarshal(hit.Source, &resource); err != nil {
log.Warn().Str("err", err.Error()).Msg("Error while un-marshaling resource")
continue
}
// Remove body if not wanted
if !params.WithBody {
resource.Body = ""
}
resources = append(resources, resource)
}
return resources, nil
}
func (e *elasticSearchDB) CountResources(params *ResSearchParams) (int64, error) {
q := buildSearchQuery(params)
count, err := e.client.Count(resourcesIndex).Query(q).Do(context.Background())
if err != nil {
return 0, err
}
return count, nil
}
func (e *elasticSearchDB) AddResource(res ResourceIdx) error {
_, err := e.client.Index().
Index(resourcesIndex).
BodyJson(res).
Do(context.Background())
return err
}
func buildSearchQuery(params *ResSearchParams) elastic.Query {
var queries []elastic.Query
if params.URL != "" {
log.Trace().Str("url", params.URL).Msg("SearchQuery: Setting url")
queries = append(queries, elastic.NewTermQuery("url", params.URL))
}
if params.Keyword != "" {
log.Trace().Str("body", params.Keyword).Msg("SearchQuery: Setting body")
queries = append(queries, elastic.NewTermQuery("body", params.Keyword))
}
if !params.StartDate.IsZero() || !params.EndDate.IsZero() {
timeQuery := elastic.NewRangeQuery("time")
if !params.StartDate.IsZero() {
log.Trace().
Str("startDate", params.StartDate.Format(time.RFC3339)).
Msg("SearchQuery: Setting startDate")
timeQuery.Gte(params.StartDate.Format(time.RFC3339))
}
if !params.EndDate.IsZero() {
log.Trace().
Str("endDate", params.EndDate.Format(time.RFC3339)).
Msg("SearchQuery: Setting endDate")
timeQuery.Lte(params.EndDate.Format(time.RFC3339))
}
queries = append(queries, timeQuery)
}
// Handle specific case
if len(queries) == 0 {
return elastic.NewMatchAllQuery()
}
if len(queries) == 1 {
return queries[0]
}
// otherwise AND combine them
return elastic.NewBoolQuery().Must(queries...)
}
func setupElasticSearch(ctx context.Context, es *elastic.Client) error {
// Setup index if doesn't exist
exist, err := es.IndexExists(resourcesIndex).Do(ctx)
if err != nil {
log.Err(err).Str("index", resourcesIndex).Msg("Error while checking if index exist")
return err
}
if !exist {
log.Debug().Str("index", resourcesIndex).Msg("Creating missing index")
if _, err := es.CreateIndex(resourcesIndex).Do(ctx); err != nil {
log.Err(err).Str("index", resourcesIndex).Msg("Error while creating index")
return err
}
} else {
log.Debug().Msg("index exist")
}
return nil
}

@ -4,9 +4,8 @@ import (
"fmt"
"github.com/PuerkitoBio/purell"
"github.com/creekorful/trandoshan/api"
logging2 "github.com/creekorful/trandoshan/internal/logging"
"github.com/creekorful/trandoshan/internal/messaging"
"github.com/creekorful/trandoshan/internal/util/logging"
natsutil "github.com/creekorful/trandoshan/internal/util/nats"
"github.com/nats-io/nats.go"
"github.com/rs/zerolog/log"
"github.com/urfave/cli/v2"
@ -27,7 +26,7 @@ func GetApp() *cli.App {
Version: "0.4.0",
Usage: "Trandoshan extractor component",
Flags: []cli.Flag{
logging.GetLogFlag(),
logging2.GetLogFlag(),
&cli.StringFlag{
Name: "nats-uri",
Usage: "URI to the NATS server",
@ -44,7 +43,7 @@ func GetApp() *cli.App {
}
func execute(ctx *cli.Context) error {
logging.ConfigureLogger(ctx)
logging2.ConfigureLogger(ctx)
log.Info().Str("ver", ctx.App.Version).Msg("Starting tdsh-extractor")
@ -55,7 +54,7 @@ func execute(ctx *cli.Context) error {
apiClient := api.NewClient(ctx.String("api-uri"))
// Create the NATS subscriber
sub, err := natsutil.NewSubscriber(ctx.String("nats-uri"))
sub, err := messaging.NewSubscriber(ctx.String("nats-uri"))
if err != nil {
return err
}
@ -71,10 +70,10 @@ func execute(ctx *cli.Context) error {
return nil
}
func handleMessage(apiClient api.Client) natsutil.MsgHandler {
return func(nc *nats.Conn, msg *nats.Msg) error {
func handleMessage(apiClient api.Client) messaging.MsgHandler {
return func(sub messaging.Subscriber, msg *nats.Msg) error {
var resMsg messaging.NewResourceMsg
if err := natsutil.ReadMsg(msg, &resMsg); err != nil {
if err := sub.ReadMsg(msg, &resMsg); err != nil {
log.Err(err).Msg("Error while reading message")
return err
}
@ -101,7 +100,7 @@ func handleMessage(apiClient api.Client) natsutil.MsgHandler {
Str("url", url).
Msg("Publishing found URL")
if err := natsutil.PublishMsg(nc, &messaging.URLFoundMsg{URL: url}); err != nil {
if err := sub.PublishMsg(&messaging.URLFoundMsg{URL: url}); err != nil {
log.Warn().
Str("url", url).
Str("err", err.Error()).

@ -1,4 +1,4 @@
package nats
package messaging
// Msg represent a message send-able trough NATS
type Msg interface {

@ -0,0 +1,48 @@
package messaging
import (
"encoding/json"
"fmt"
"github.com/nats-io/nats.go"
)
//go:generate mockgen -source publisher.go -destination=../messaging_mock/publisher_mock.go -package=messaging_mock
// Publisher is something that push msg to an event queue
type Publisher interface {
PublishMsg(msg Msg) error
Close()
}
type publisher struct {
nc *nats.Conn
}
// NewPublisher create a new Publisher instance
func NewPublisher(natsURI string) (Publisher, error) {
nc, err := nats.Connect(natsURI)
if err != nil {
return nil, err
}
return &publisher{
nc: nc,
}, nil
}
func (p *publisher) PublishMsg(msg Msg) error {
return publishJSON(p.nc, msg.Subject(), msg)
}
func (p *publisher) Close() {
p.nc.Close()
}
func publishJSON(nc *nats.Conn, subject string, msg interface{}) error {
msgBytes, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("error while encoding message: %s", err)
}
return nc.Publish(subject, msgBytes)
}

@ -0,0 +1,78 @@
package messaging
import (
"context"
"encoding/json"
"fmt"
"github.com/nats-io/nats.go"
)
// MsgHandler represent an handler for a NATS subscriber
type MsgHandler func(s Subscriber, msg *nats.Msg) error
// Subscriber is something that read msg from an event queue
type Subscriber interface {
Publisher
ReadMsg(natsMsg *nats.Msg, msg Msg) error
QueueSubscribe(subject, queue string, handler MsgHandler) error
Close()
}
// Subscriber represent a NATS subscriber
type subscriber struct {
nc *nats.Conn
}
// NewSubscriber create a new subscriber and connect it to given NATS server
func NewSubscriber(address string) (Subscriber, error) {
nc, err := nats.Connect(address)
if err != nil {
return nil, err
}
return &subscriber{
nc: nc,
}, nil
}
func (s *subscriber) ReadMsg(natsMsg *nats.Msg, msg Msg) error {
return readJSON(natsMsg, msg)
}
func (s *subscriber) QueueSubscribe(subject, queue string, handler MsgHandler) error {
// Create the subscriber
sub, err := s.nc.QueueSubscribeSync(subject, queue)
if err != nil {
return err
}
for {
// Read incoming message
msg, err := sub.NextMsgWithContext(context.Background())
if err != nil {
continue
}
// ... And process it
if err := handler(s, msg); err != nil {
continue
}
}
}
func (s *subscriber) PublishMsg(msg Msg) error {
return publishJSON(s.nc, msg.Subject(), msg)
}
func (s *subscriber) Close() {
s.nc.Close()
}
func readJSON(msg *nats.Msg, body interface{}) error {
if err := json.Unmarshal(msg.Data, body); err != nil {
return fmt.Errorf("error while decoding message: %s", err)
}
return nil
}

@ -4,9 +4,8 @@ import (
"encoding/base64"
"fmt"
"github.com/creekorful/trandoshan/api"
logging2 "github.com/creekorful/trandoshan/internal/logging"
"github.com/creekorful/trandoshan/internal/messaging"
"github.com/creekorful/trandoshan/internal/util/logging"
natsutil "github.com/creekorful/trandoshan/internal/util/nats"
"github.com/nats-io/nats.go"
"github.com/rs/zerolog/log"
"github.com/urfave/cli/v2"
@ -23,7 +22,7 @@ func GetApp() *cli.App {
Version: "0.4.0",
Usage: "Trandoshan scheduler component",
Flags: []cli.Flag{
logging.GetLogFlag(),
logging2.GetLogFlag(),
&cli.StringFlag{
Name: "nats-uri",
Usage: "URI to the NATS server",
@ -44,7 +43,7 @@ func GetApp() *cli.App {
}
func execute(ctx *cli.Context) error {
logging.ConfigureLogger(ctx)
logging2.ConfigureLogger(ctx)
log.Info().Str("ver", ctx.App.Version).Msg("Starting tdsh-scheduler")
@ -62,7 +61,7 @@ func execute(ctx *cli.Context) error {
apiClient := api.NewClient(ctx.String("api-uri"))
// Create the NATS subscriber
sub, err := natsutil.NewSubscriber(ctx.String("nats-uri"))
sub, err := messaging.NewSubscriber(ctx.String("nats-uri"))
if err != nil {
return err
}
@ -77,10 +76,10 @@ func execute(ctx *cli.Context) error {
return nil
}
func handleMessage(apiClient api.Client, refreshDelay time.Duration) natsutil.MsgHandler {
return func(nc *nats.Conn, msg *nats.Msg) error {
func handleMessage(apiClient api.Client, refreshDelay time.Duration) messaging.MsgHandler {
return func(sub messaging.Subscriber, msg *nats.Msg) error {
var urlMsg messaging.URLFoundMsg
if err := natsutil.ReadJSON(msg, &urlMsg); err != nil {
if err := sub.ReadMsg(msg, &urlMsg); err != nil {
return err
}
@ -115,7 +114,7 @@ func handleMessage(apiClient api.Client, refreshDelay time.Duration) natsutil.Ms
// No matches: schedule!
if len(urls) == 0 {
log.Debug().Stringer("url", u).Msg("URL should be scheduled")
if err := natsutil.PublishMsg(nc, &messaging.URLTodoMsg{URL: urlMsg.URL}); err != nil {
if err := sub.PublishMsg(&messaging.URLTodoMsg{URL: urlMsg.URL}); err != nil {
return fmt.Errorf("error while publishing URL: %s", err)
}
} else {

@ -3,7 +3,7 @@ package trandoshanctl
import (
"fmt"
"github.com/creekorful/trandoshan/api"
"github.com/creekorful/trandoshan/internal/util/logging"
logging2 "github.com/creekorful/trandoshan/internal/logging"
"github.com/olekukonko/tablewriter"
"github.com/rs/zerolog/log"
"github.com/urfave/cli/v2"
@ -18,7 +18,7 @@ func GetApp() *cli.App {
Version: "0.4.0",
Usage: "Trandoshan CLI",
Flags: []cli.Flag{
logging.GetLogFlag(),
logging2.GetLogFlag(),
&cli.StringFlag{
Name: "api-uri",
Usage: "URI to the API server",
@ -44,7 +44,7 @@ func GetApp() *cli.App {
}
func before(ctx *cli.Context) error {
logging.ConfigureLogger(ctx)
logging2.ConfigureLogger(ctx)
return nil
}

@ -1,36 +0,0 @@
package nats
import (
"encoding/json"
"fmt"
"github.com/nats-io/nats.go"
)
// PublishMsg publish given Msg
func PublishMsg(nc *nats.Conn, msg Msg) error {
return PublishJSON(nc, msg.Subject(), msg)
}
// ReadMsg read message from given connection
func ReadMsg(nc *nats.Msg, msg Msg) error {
return ReadJSON(nc, msg)
}
// PublishJSON publish given message serialized in json with given subject
func PublishJSON(nc *nats.Conn, subject string, msg interface{}) error {
msgBytes, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("error while encoding message: %s", err)
}
return nc.Publish(subject, msgBytes)
}
// ReadJSON read given encoded json message and deserialize into into given structure
func ReadJSON(msg *nats.Msg, body interface{}) error {
if err := json.Unmarshal(msg.Data, body); err != nil {
return fmt.Errorf("error while decoding message: %s", err)
}
return nil
}

@ -1,56 +0,0 @@
package nats
import (
"context"
"github.com/nats-io/nats.go"
"github.com/rs/zerolog/log"
)
// MsgHandler represent an handler for a NATS subscriber
type MsgHandler func(nc *nats.Conn, msg *nats.Msg) error
// Subscriber represent a NATS subscriber
type Subscriber struct {
nc *nats.Conn
}
// NewSubscriber create a new subscriber and connect it to given NATS server
func NewSubscriber(address string) (*Subscriber, error) {
nc, err := nats.Connect(address)
if err != nil {
return nil, err
}
return &Subscriber{
nc: nc,
}, nil
}
// QueueSubscribe subscribe to given subject, with given queue
func (qs *Subscriber) QueueSubscribe(subject, queue string, handler MsgHandler) error {
// Create the subscriber
sub, err := qs.nc.QueueSubscribeSync(subject, queue)
if err != nil {
return err
}
for {
// Read incoming message
msg, err := sub.NextMsgWithContext(context.Background())
if err != nil {
log.Warn().Str("err", err.Error()).Msg("Skipping current message because of error")
continue
}
// ... And process it
if err := handler(qs.nc, msg); err != nil {
log.Warn().Str("error", err.Error()).Msg("Skipping current message because of error")
continue
}
}
}
// Close terminate the connection to the NATS server
func (qs *Subscriber) Close() {
qs.nc.Close()
}
Loading…
Cancel
Save