mirror of https://github.com/oxen-io/lokinet
use lokimq workers instead of llarp:🧵:ThreadPool
parent
30b158b906
commit
f4971a88fd
@ -1,24 +0,0 @@
|
||||
#ifndef LLARP_LOGGER_H
|
||||
#define LLARP_LOGGER_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C"
|
||||
{
|
||||
enum LogLevel
|
||||
{
|
||||
eLogDebug,
|
||||
eLogInfo,
|
||||
eLogWarn,
|
||||
eLogError,
|
||||
eLogNone
|
||||
};
|
||||
|
||||
void
|
||||
cSetLogLevel(enum LogLevel lvl);
|
||||
|
||||
void
|
||||
cSetLogNodeName(const char* name);
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
@ -1,331 +0,0 @@
|
||||
#include <util/thread/thread_pool.hpp>
|
||||
|
||||
#include <util/thread/threading.hpp>
|
||||
|
||||
namespace llarp
|
||||
{
|
||||
namespace thread
|
||||
{
|
||||
void
|
||||
ThreadPool::join()
|
||||
{
|
||||
for (auto& t : m_threads)
|
||||
{
|
||||
if (t.joinable())
|
||||
{
|
||||
t.join();
|
||||
}
|
||||
}
|
||||
|
||||
m_createdThreads = 0;
|
||||
}
|
||||
|
||||
void
|
||||
ThreadPool::runJobs()
|
||||
{
|
||||
while (m_status.load(std::memory_order_relaxed) == Status::Run)
|
||||
{
|
||||
auto functor = m_queue.tryPopFront();
|
||||
|
||||
if (functor)
|
||||
{
|
||||
(*functor)();
|
||||
}
|
||||
else
|
||||
{
|
||||
m_idleThreads++;
|
||||
|
||||
if (m_status == Status::Run && m_queue.empty())
|
||||
{
|
||||
m_semaphore.wait();
|
||||
}
|
||||
|
||||
m_idleThreads.fetch_sub(1, std::memory_order_relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ThreadPool::drainQueue()
|
||||
{
|
||||
while (m_status.load(std::memory_order_relaxed) == Status::Drain)
|
||||
{
|
||||
auto functor = m_queue.tryPopFront();
|
||||
|
||||
if (!functor)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
(*functor)();
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ThreadPool::waitThreads()
|
||||
{
|
||||
std::unique_lock lock{m_gateMutex};
|
||||
m_numThreadsCV.wait(lock, [this] { return allThreadsReady(); });
|
||||
}
|
||||
|
||||
void
|
||||
ThreadPool::releaseThreads()
|
||||
{
|
||||
{
|
||||
std::lock_guard lock{m_gateMutex};
|
||||
m_numThreadsReady = 0;
|
||||
++m_gateCount;
|
||||
}
|
||||
m_gateCV.notify_all();
|
||||
}
|
||||
|
||||
void
|
||||
ThreadPool::interrupt()
|
||||
{
|
||||
std::lock_guard lock{m_gateMutex};
|
||||
|
||||
size_t count = m_idleThreads;
|
||||
|
||||
for (size_t i = 0; i < count; ++i)
|
||||
{
|
||||
m_semaphore.notify();
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ThreadPool::worker()
|
||||
{
|
||||
// Lock will be valid until the end of the statement
|
||||
size_t gateCount = (std::lock_guard{m_gateMutex}, m_gateCount);
|
||||
|
||||
util::SetThreadName(m_name);
|
||||
|
||||
for (;;)
|
||||
{
|
||||
{
|
||||
std::unique_lock lock{m_gateMutex};
|
||||
++m_numThreadsReady;
|
||||
m_numThreadsCV.notify_one();
|
||||
|
||||
m_gateCV.wait(lock, [&] { return gateCount != m_gateCount; });
|
||||
|
||||
gateCount = m_gateCount;
|
||||
}
|
||||
|
||||
Status status = m_status.load(std::memory_order_relaxed);
|
||||
|
||||
// Can't use a switch here as we want to load and fall through.
|
||||
|
||||
if (status == Status::Run)
|
||||
{
|
||||
runJobs();
|
||||
status = m_status;
|
||||
}
|
||||
|
||||
if (status == Status::Drain)
|
||||
{
|
||||
drainQueue();
|
||||
}
|
||||
else if (status == Status::Suspend)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(status == Status::Stop);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool
|
||||
ThreadPool::spawn()
|
||||
{
|
||||
try
|
||||
{
|
||||
m_threads.at(m_createdThreads) = std::thread(std::bind(&ThreadPool::worker, this));
|
||||
++m_createdThreads;
|
||||
return true;
|
||||
}
|
||||
catch (const std::system_error&)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
ThreadPool::ThreadPool(size_t numThreads, size_t maxJobs, std::string_view name)
|
||||
: m_queue(maxJobs)
|
||||
, m_semaphore(0)
|
||||
, m_idleThreads(0)
|
||||
, m_status(Status::Stop)
|
||||
, m_gateCount(0)
|
||||
, m_numThreadsReady(0)
|
||||
, m_name(name)
|
||||
, m_threads(numThreads)
|
||||
, m_createdThreads(0)
|
||||
{
|
||||
assert(numThreads != 0);
|
||||
assert(maxJobs != 0);
|
||||
disable();
|
||||
}
|
||||
|
||||
ThreadPool::~ThreadPool()
|
||||
{
|
||||
shutdown();
|
||||
}
|
||||
|
||||
bool
|
||||
ThreadPool::addJob(const Job& job)
|
||||
{
|
||||
assert(job);
|
||||
|
||||
QueueReturn ret = m_queue.pushBack(job);
|
||||
|
||||
if (ret == QueueReturn::Success && m_idleThreads > 0)
|
||||
{
|
||||
m_semaphore.notify();
|
||||
}
|
||||
|
||||
return ret == QueueReturn::Success;
|
||||
}
|
||||
bool
|
||||
ThreadPool::addJob(Job&& job)
|
||||
{
|
||||
assert(job);
|
||||
QueueReturn ret = m_queue.pushBack(std::move(job));
|
||||
|
||||
if (ret == QueueReturn::Success && m_idleThreads > 0)
|
||||
{
|
||||
m_semaphore.notify();
|
||||
}
|
||||
|
||||
return ret == QueueReturn::Success;
|
||||
}
|
||||
|
||||
bool
|
||||
ThreadPool::tryAddJob(const Job& job)
|
||||
{
|
||||
assert(job);
|
||||
QueueReturn ret = m_queue.tryPushBack(job);
|
||||
|
||||
if (ret == QueueReturn::Success && m_idleThreads > 0)
|
||||
{
|
||||
m_semaphore.notify();
|
||||
}
|
||||
|
||||
return ret == QueueReturn::Success;
|
||||
}
|
||||
|
||||
bool
|
||||
ThreadPool::tryAddJob(Job&& job)
|
||||
{
|
||||
assert(job);
|
||||
QueueReturn ret = m_queue.tryPushBack(std::move(job));
|
||||
|
||||
if (ret == QueueReturn::Success && m_idleThreads > 0)
|
||||
{
|
||||
m_semaphore.notify();
|
||||
}
|
||||
|
||||
return ret == QueueReturn::Success;
|
||||
}
|
||||
|
||||
void
|
||||
ThreadPool::drain()
|
||||
{
|
||||
util::Lock lock(m_mutex);
|
||||
|
||||
if (m_status.load(std::memory_order_relaxed) == Status::Run)
|
||||
{
|
||||
m_status = Status::Drain;
|
||||
|
||||
interrupt();
|
||||
waitThreads();
|
||||
|
||||
m_status = Status::Run;
|
||||
|
||||
releaseThreads();
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ThreadPool::shutdown()
|
||||
{
|
||||
util::Lock lock(m_mutex);
|
||||
|
||||
if (m_status.load(std::memory_order_relaxed) == Status::Run)
|
||||
{
|
||||
m_queue.disable();
|
||||
m_status = Status::Stop;
|
||||
|
||||
interrupt();
|
||||
m_queue.removeAll();
|
||||
|
||||
join();
|
||||
}
|
||||
}
|
||||
|
||||
bool
|
||||
ThreadPool::start()
|
||||
{
|
||||
util::Lock lock(m_mutex);
|
||||
|
||||
if (m_status.load(std::memory_order_relaxed) != Status::Stop)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
for (auto it = (m_threads.begin() + m_createdThreads); it != m_threads.end(); ++it)
|
||||
{
|
||||
if (!spawn())
|
||||
{
|
||||
releaseThreads();
|
||||
|
||||
join();
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
waitThreads();
|
||||
|
||||
m_queue.enable();
|
||||
m_status = Status::Run;
|
||||
|
||||
// `releaseThreads` has a release barrier so workers don't return from
|
||||
// wait and not see the above store.
|
||||
|
||||
releaseThreads();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void
|
||||
ThreadPool::stop()
|
||||
{
|
||||
util::Lock lock(m_mutex);
|
||||
|
||||
if (m_status.load(std::memory_order_relaxed) == Status::Run)
|
||||
{
|
||||
m_queue.disable();
|
||||
m_status = Status::Drain;
|
||||
|
||||
// `interrupt` has an acquire barrier (locks a mutex), so nothing will
|
||||
// be executed before the above store to `status`.
|
||||
interrupt();
|
||||
|
||||
waitThreads();
|
||||
|
||||
m_status = Status::Stop;
|
||||
|
||||
// `releaseThreads` has a release barrier so workers don't return from
|
||||
// wait and not see the above store.
|
||||
|
||||
releaseThreads();
|
||||
|
||||
join();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace thread
|
||||
} // namespace llarp
|
@ -1,216 +0,0 @@
|
||||
#ifndef LLARP_THREAD_POOL_HPP
|
||||
#define LLARP_THREAD_POOL_HPP
|
||||
|
||||
#include <util/thread/queue.hpp>
|
||||
#include <util/thread/threading.hpp>
|
||||
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <string_view>
|
||||
|
||||
namespace llarp
|
||||
{
|
||||
namespace thread
|
||||
{
|
||||
class ThreadPool
|
||||
{
|
||||
// Provide an efficient fixed size threadpool. The following attributes
|
||||
// of the threadpool are fixed at construction time:
|
||||
// - the max number of pending jobs
|
||||
// - the number of threads
|
||||
public:
|
||||
using Job = std::function<void()>;
|
||||
using JobQueue = Queue<Job>;
|
||||
|
||||
enum class Status
|
||||
{
|
||||
Stop,
|
||||
Run,
|
||||
Suspend,
|
||||
Drain
|
||||
};
|
||||
|
||||
private:
|
||||
JobQueue m_queue; // The job queue
|
||||
util::Semaphore m_semaphore; // The semaphore for the queue.
|
||||
|
||||
std::atomic_size_t m_idleThreads; // Number of idle threads
|
||||
|
||||
util::Mutex m_mutex;
|
||||
|
||||
std::atomic<Status> m_status;
|
||||
|
||||
size_t m_gateCount GUARDED_BY(m_gateMutex);
|
||||
size_t m_numThreadsReady GUARDED_BY(m_gateMutex); // Threads ready to go through the gate.
|
||||
|
||||
std::mutex m_gateMutex;
|
||||
std::condition_variable m_gateCV;
|
||||
std::condition_variable m_numThreadsCV;
|
||||
|
||||
std::string m_name;
|
||||
std::vector<std::thread> m_threads;
|
||||
size_t m_createdThreads;
|
||||
|
||||
void
|
||||
join();
|
||||
|
||||
void
|
||||
runJobs();
|
||||
|
||||
void
|
||||
drainQueue();
|
||||
|
||||
void
|
||||
waitThreads();
|
||||
|
||||
void
|
||||
releaseThreads();
|
||||
|
||||
void
|
||||
interrupt();
|
||||
|
||||
void
|
||||
worker();
|
||||
|
||||
bool
|
||||
spawn();
|
||||
|
||||
bool
|
||||
allThreadsReady() const REQUIRES_SHARED(m_gateMutex)
|
||||
{
|
||||
return m_numThreadsReady == m_threads.size();
|
||||
}
|
||||
|
||||
public:
|
||||
ThreadPool(size_t numThreads, size_t maxJobs, std::string_view name);
|
||||
|
||||
~ThreadPool();
|
||||
|
||||
// Disable the threadpool. Calls to `addJob` and `tryAddJob` will fail.
|
||||
// Jobs currently in the pool will not be affected.
|
||||
void
|
||||
disable();
|
||||
|
||||
void
|
||||
enable();
|
||||
|
||||
// Add a job to the bool. Note this call will block if the underlying
|
||||
// queue is full.
|
||||
// Returns false if the queue is currently disabled.
|
||||
bool
|
||||
addJob(const Job& job);
|
||||
bool
|
||||
addJob(Job&& job);
|
||||
|
||||
// Try to add a job to the pool. If the queue is full, or the queue is
|
||||
// disabled, return false.
|
||||
// This call will not block.
|
||||
bool
|
||||
tryAddJob(const Job& job);
|
||||
bool
|
||||
tryAddJob(Job&& job);
|
||||
|
||||
// Wait until all current jobs are complete.
|
||||
// If any jobs are submitted during this time, they **may** or **may not**
|
||||
// run.
|
||||
void
|
||||
drain();
|
||||
|
||||
// Disable this pool, and cancel all pending jobs. After all currently
|
||||
// running jobs are complete, join with the threads in the pool.
|
||||
void
|
||||
shutdown();
|
||||
|
||||
// Start this threadpool by spawning `threadCount()` threads.
|
||||
bool
|
||||
start();
|
||||
|
||||
// Disable queueing on this threadpool and wait until all pending jobs
|
||||
// have finished.
|
||||
void
|
||||
stop();
|
||||
|
||||
bool
|
||||
enabled() const;
|
||||
|
||||
bool
|
||||
started() const;
|
||||
|
||||
size_t
|
||||
activeThreadCount() const;
|
||||
|
||||
// Current number of queued jobs
|
||||
size_t
|
||||
jobCount() const;
|
||||
|
||||
// Number of threads passed in the constructor
|
||||
size_t
|
||||
threadCount() const;
|
||||
|
||||
// Number of threads currently started in the threadpool
|
||||
size_t
|
||||
startedThreadCount() const;
|
||||
|
||||
// Max number of queued jobs
|
||||
size_t
|
||||
capacity() const;
|
||||
};
|
||||
|
||||
inline void
|
||||
ThreadPool::disable()
|
||||
{
|
||||
m_queue.disable();
|
||||
}
|
||||
|
||||
inline void
|
||||
ThreadPool::enable()
|
||||
{
|
||||
m_queue.enable();
|
||||
}
|
||||
|
||||
inline bool
|
||||
ThreadPool::enabled() const
|
||||
{
|
||||
return m_queue.enabled();
|
||||
}
|
||||
|
||||
inline size_t
|
||||
ThreadPool::activeThreadCount() const
|
||||
{
|
||||
if (m_threads.size() == m_createdThreads)
|
||||
{
|
||||
return m_threads.size() - m_idleThreads.load(std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
inline size_t
|
||||
ThreadPool::threadCount() const
|
||||
{
|
||||
return m_threads.size();
|
||||
}
|
||||
|
||||
inline size_t
|
||||
ThreadPool::startedThreadCount() const
|
||||
{
|
||||
return m_createdThreads;
|
||||
}
|
||||
|
||||
inline size_t
|
||||
ThreadPool::jobCount() const
|
||||
{
|
||||
return m_queue.size();
|
||||
}
|
||||
|
||||
inline size_t
|
||||
ThreadPool::capacity() const
|
||||
{
|
||||
return m_queue.capacity();
|
||||
}
|
||||
} // namespace thread
|
||||
} // namespace llarp
|
||||
|
||||
#endif
|
@ -1,89 +0,0 @@
|
||||
#include <util/logging/logger.hpp>
|
||||
#include <util/time.hpp>
|
||||
#include <util/thread/threadpool.h>
|
||||
#include <util/thread/thread_pool.hpp>
|
||||
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <queue>
|
||||
|
||||
struct llarp_threadpool*
|
||||
llarp_init_threadpool(int workers, const char* name, size_t queueLength)
|
||||
{
|
||||
if (workers <= 0)
|
||||
workers = 1;
|
||||
return new llarp_threadpool(workers, name, queueLength);
|
||||
}
|
||||
|
||||
void
|
||||
llarp_threadpool_join(struct llarp_threadpool* pool)
|
||||
{
|
||||
llarp::LogDebug("threadpool join");
|
||||
if (pool->impl)
|
||||
pool->impl->stop();
|
||||
pool->impl.reset();
|
||||
}
|
||||
|
||||
void
|
||||
llarp_threadpool_start(struct llarp_threadpool* pool)
|
||||
{
|
||||
if (pool->impl)
|
||||
pool->impl->start();
|
||||
}
|
||||
|
||||
void
|
||||
llarp_threadpool_stop(struct llarp_threadpool* pool)
|
||||
{
|
||||
llarp::LogDebug("threadpool stop");
|
||||
if (pool->impl)
|
||||
pool->impl->disable();
|
||||
}
|
||||
|
||||
bool
|
||||
llarp_threadpool_queue_job(struct llarp_threadpool* pool, struct llarp_thread_job job)
|
||||
{
|
||||
return llarp_threadpool_queue_job(pool, std::bind(job.work, job.user));
|
||||
}
|
||||
|
||||
bool
|
||||
llarp_threadpool_queue_job(struct llarp_threadpool* pool, std::function<void(void)> func)
|
||||
{
|
||||
return pool->impl && pool->impl->addJob(func);
|
||||
}
|
||||
|
||||
void
|
||||
llarp_threadpool_tick(struct llarp_threadpool* pool)
|
||||
{
|
||||
if (pool->impl)
|
||||
{
|
||||
pool->impl->drain();
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
llarp_free_threadpool(struct llarp_threadpool** pool)
|
||||
{
|
||||
if (*pool)
|
||||
{
|
||||
delete *pool;
|
||||
}
|
||||
*pool = nullptr;
|
||||
}
|
||||
|
||||
size_t
|
||||
llarp_threadpool::size() const
|
||||
{
|
||||
return impl ? impl->capacity() : 0;
|
||||
}
|
||||
|
||||
size_t
|
||||
llarp_threadpool::pendingJobs() const
|
||||
{
|
||||
return impl ? impl->jobCount() : 0;
|
||||
}
|
||||
|
||||
size_t
|
||||
llarp_threadpool::numThreads() const
|
||||
{
|
||||
return impl ? impl->activeThreadCount() : 0;
|
||||
}
|
@ -1,91 +0,0 @@
|
||||
#ifndef LLARP_THREADPOOL_H
|
||||
#define LLARP_THREADPOOL_H
|
||||
|
||||
#include <util/thread/queue.hpp>
|
||||
#include <util/thread/thread_pool.hpp>
|
||||
#include <util/thread/threading.hpp>
|
||||
#include <util/thread/annotations.hpp>
|
||||
#include <util/types.hpp>
|
||||
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string_view>
|
||||
|
||||
struct llarp_threadpool;
|
||||
|
||||
#ifdef __cplusplus
|
||||
struct llarp_threadpool
|
||||
{
|
||||
std::unique_ptr<llarp::thread::ThreadPool> impl;
|
||||
|
||||
llarp_threadpool(int workers, std::string_view name, size_t queueLength = size_t{1024 * 8})
|
||||
: impl(std::make_unique<llarp::thread::ThreadPool>(
|
||||
workers, std::max(queueLength, size_t{32}), name))
|
||||
{
|
||||
}
|
||||
|
||||
size_t
|
||||
size() const;
|
||||
|
||||
size_t
|
||||
pendingJobs() const;
|
||||
|
||||
size_t
|
||||
numThreads() const;
|
||||
|
||||
/// see if this thread is full given lookahead amount
|
||||
bool
|
||||
LooksFull(size_t lookahead) const
|
||||
{
|
||||
return (pendingJobs() + lookahead) >= size();
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
struct llarp_threadpool*
|
||||
llarp_init_threadpool(int workers, const char* name, size_t queueLength);
|
||||
|
||||
void
|
||||
llarp_free_threadpool(struct llarp_threadpool** tp);
|
||||
|
||||
using llarp_thread_work_func = void (*)(void*);
|
||||
|
||||
/** job to be done in worker thread */
|
||||
struct llarp_thread_job
|
||||
{
|
||||
#ifdef __cplusplus
|
||||
/** user data to pass to work function */
|
||||
void* user{nullptr};
|
||||
/** called in threadpool worker thread */
|
||||
llarp_thread_work_func work{nullptr};
|
||||
|
||||
llarp_thread_job(void* u, llarp_thread_work_func w) : user(u), work(w)
|
||||
{
|
||||
}
|
||||
|
||||
llarp_thread_job() = default;
|
||||
#else
|
||||
void* user;
|
||||
llarp_thread_work_func work;
|
||||
#endif
|
||||
};
|
||||
|
||||
void
|
||||
llarp_threadpool_tick(struct llarp_threadpool* tp);
|
||||
|
||||
bool
|
||||
llarp_threadpool_queue_job(struct llarp_threadpool* tp, struct llarp_thread_job j);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
bool
|
||||
llarp_threadpool_queue_job(struct llarp_threadpool* tp, std::function<void(void)> func);
|
||||
|
||||
#endif
|
||||
|
||||
void
|
||||
llarp_threadpool_start(struct llarp_threadpool* tp);
|
||||
void
|
||||
llarp_threadpool_stop(struct llarp_threadpool* tp);
|
||||
|
||||
#endif
|
@ -0,0 +1,278 @@
|
||||
#include <catch2/catch.hpp>
|
||||
#include <crypto/crypto.hpp>
|
||||
#include <crypto/crypto_libsodium.hpp>
|
||||
#include <string_view>
|
||||
|
||||
#include <router_contact.hpp>
|
||||
#include <iwp/iwp.hpp>
|
||||
#include <util/meta/memfn.hpp>
|
||||
#include <messages/link_message_parser.hpp>
|
||||
#include <messages/discard.hpp>
|
||||
#include <util/time.hpp>
|
||||
|
||||
#undef LOG_TAG
|
||||
#define LOG_TAG "test_iwp_session.cpp"
|
||||
|
||||
namespace iwp = llarp::iwp;
|
||||
namespace util = llarp::util;
|
||||
|
||||
/// make an iwp link
|
||||
template <bool inbound, typename... Args_t>
|
||||
static llarp::LinkLayer_ptr
|
||||
make_link(Args_t... args)
|
||||
{
|
||||
if (inbound)
|
||||
return iwp::NewInboundLink(args...);
|
||||
else
|
||||
return iwp::NewOutboundLink(args...);
|
||||
}
|
||||
using Logic_ptr = std::shared_ptr<llarp::Logic>;
|
||||
|
||||
/// a single iwp link with associated keys and members to make unit tests work
|
||||
struct IWPLinkContext
|
||||
{
|
||||
llarp::RouterContact rc;
|
||||
llarp::IpAddress localAddr;
|
||||
llarp::LinkLayer_ptr link;
|
||||
std::shared_ptr<llarp::KeyManager> keyManager;
|
||||
llarp::LinkMessageParser m_Parser;
|
||||
llarp_ev_loop_ptr m_Loop;
|
||||
/// is the test done on this context ?
|
||||
bool gucci = false;
|
||||
|
||||
IWPLinkContext(std::string_view addr, llarp_ev_loop_ptr loop)
|
||||
: localAddr{std::move(addr)}
|
||||
, keyManager{std::make_shared<llarp::KeyManager>()}
|
||||
, m_Parser{nullptr}
|
||||
, m_Loop{std::move(loop)}
|
||||
{
|
||||
// generate keys
|
||||
llarp::CryptoManager::instance()->identity_keygen(keyManager->identityKey);
|
||||
llarp::CryptoManager::instance()->encryption_keygen(keyManager->encryptionKey);
|
||||
llarp::CryptoManager::instance()->encryption_keygen(keyManager->transportKey);
|
||||
|
||||
// set keys in rc
|
||||
rc.pubkey = keyManager->identityKey.toPublic();
|
||||
rc.enckey = keyManager->encryptionKey.toPublic();
|
||||
}
|
||||
|
||||
bool
|
||||
HandleMessage(llarp::ILinkSession* from, const llarp_buffer_t& buf)
|
||||
{
|
||||
return m_Parser.ProcessFrom(from, buf);
|
||||
}
|
||||
|
||||
/// initialize link
|
||||
template <bool inbound>
|
||||
void
|
||||
InitLink(std::function<void(llarp::ILinkSession*)> established)
|
||||
{
|
||||
link = make_link<inbound>(
|
||||
keyManager,
|
||||
// getrc
|
||||
[&]() -> const llarp::RouterContact& { return rc; },
|
||||
// link message handler
|
||||
util::memFn(&IWPLinkContext::HandleMessage, this),
|
||||
// sign buffer
|
||||
[&](llarp::Signature& sig, const llarp_buffer_t& buf) {
|
||||
REQUIRE(llarp::CryptoManager::instance()->sign(sig, keyManager->identityKey, buf));
|
||||
return true;
|
||||
},
|
||||
// established handler
|
||||
[established](llarp::ILinkSession* s) {
|
||||
REQUIRE(s != nullptr);
|
||||
established(s);
|
||||
return true;
|
||||
},
|
||||
// renegotiate handler
|
||||
[](llarp::RouterContact newrc, llarp::RouterContact oldrc) {
|
||||
REQUIRE(newrc.pubkey == oldrc.pubkey);
|
||||
return true;
|
||||
},
|
||||
// timeout handler
|
||||
[&](llarp::ILinkSession*) {
|
||||
llarp_ev_loop_stop(m_Loop);
|
||||
REQUIRE(false);
|
||||
},
|
||||
// session closed handler
|
||||
[](llarp::RouterID) {},
|
||||
// pump done handler
|
||||
[]() {},
|
||||
// do work function
|
||||
[l = m_Loop](llarp::Work_t work) { l->call_after_delay(1ms, work); });
|
||||
REQUIRE(link->Configure(
|
||||
m_Loop, llarp::net::LoopbackInterfaceName(), AF_INET, *localAddr.getPort()));
|
||||
|
||||
if (inbound)
|
||||
{
|
||||
// only add address info on the recipiant's rc
|
||||
rc.addrs.emplace_back();
|
||||
REQUIRE(link->GetOurAddressInfo(rc.addrs.back()));
|
||||
}
|
||||
// sign rc
|
||||
REQUIRE(rc.Sign(keyManager->identityKey));
|
||||
REQUIRE(keyManager != nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
using Context_ptr = std::shared_ptr<IWPLinkContext>;
|
||||
|
||||
/// run an iwp unit test after setup
|
||||
/// call take 2 parameters, test and a timeout
|
||||
///
|
||||
/// test is a callable that takes 5 arguments:
|
||||
/// 0) std::function<Logic_ptr(void)> that starts the iwp links and gives a logic to call with
|
||||
/// 1) std::function<void(void)> that ends the unit test if we are done
|
||||
/// 2) std::function<void(void)> that ends the unit test right now as a success
|
||||
/// 3) client iwp link context (shared_ptr)
|
||||
/// 4) relay iwp link context (shared_ptr)
|
||||
///
|
||||
/// timeout is a std::chrono::duration that tells the driver how long to run the unit test for
|
||||
/// before it should assume failure of unit test
|
||||
template <typename Func_t, typename Duration_t = std::chrono::milliseconds>
|
||||
void
|
||||
RunIWPTest(Func_t test, Duration_t timeout = 1s)
|
||||
{
|
||||
// shut up logs
|
||||
llarp::LogSilencer shutup;
|
||||
|
||||
// set up event loop
|
||||
auto logic = std::make_shared<llarp::Logic>();
|
||||
auto loop = llarp_make_ev_loop();
|
||||
loop->set_logic(logic);
|
||||
|
||||
llarp::LogContext::Instance().Initialize(
|
||||
llarp::eLogDebug, llarp::LogType::File, "stdout", "unit test", [loop](auto work) {
|
||||
loop->call_soon(work);
|
||||
});
|
||||
|
||||
// turn off bogon blocking
|
||||
auto oldBlockBogons = llarp::RouterContact::BlockBogons;
|
||||
llarp::RouterContact::BlockBogons = false;
|
||||
|
||||
// set up cryptography
|
||||
llarp::sodium::CryptoLibSodium crypto{};
|
||||
llarp::CryptoManager manager{&crypto};
|
||||
|
||||
// set up client
|
||||
auto initiator = std::make_shared<IWPLinkContext>("127.0.0.1:3001", loop);
|
||||
// set up server
|
||||
auto recipiant = std::make_shared<IWPLinkContext>("127.0.0.1:3002", loop);
|
||||
|
||||
// function for ending unit test on success
|
||||
auto endIfDone = [initiator, recipiant, loop, logic]() {
|
||||
if (initiator->gucci and recipiant->gucci)
|
||||
{
|
||||
LogicCall(logic, [loop]() { llarp_ev_loop_stop(loop); });
|
||||
}
|
||||
};
|
||||
// function to start test and give logic to unit test
|
||||
auto start = [initiator, recipiant, logic]() {
|
||||
REQUIRE(initiator->link->Start(logic));
|
||||
REQUIRE(recipiant->link->Start(logic));
|
||||
return logic;
|
||||
};
|
||||
|
||||
// function to end test immediately
|
||||
auto endTest = [logic, loop]() { LogicCall(logic, [loop]() { llarp_ev_loop_stop(loop); }); };
|
||||
|
||||
loop->call_after_delay(
|
||||
std::chrono::duration_cast<llarp_time_t>(timeout), []() { REQUIRE(false); });
|
||||
test(start, endIfDone, endTest, initiator, recipiant);
|
||||
llarp_ev_loop_run_single_process(loop, logic);
|
||||
llarp::RouterContact::BlockBogons = oldBlockBogons;
|
||||
}
|
||||
|
||||
/// ensure clients can connect to relays
|
||||
TEST_CASE("IWP handshake", "[iwp]")
|
||||
{
|
||||
RunIWPTest([](std::function<Logic_ptr(void)> start,
|
||||
std::function<void(void)> endIfDone,
|
||||
[[maybe_unused]] std::function<void(void)> endTestNow,
|
||||
Context_ptr alice,
|
||||
Context_ptr bob) {
|
||||
// set up initiator
|
||||
alice->InitLink<false>([=](auto remote) {
|
||||
REQUIRE(remote->GetRemoteRC() == bob->rc);
|
||||
alice->gucci = true;
|
||||
endIfDone();
|
||||
});
|
||||
// set up recipiant
|
||||
bob->InitLink<true>([=](auto remote) {
|
||||
REQUIRE(remote->GetRemoteRC() == alice->rc);
|
||||
bob->gucci = true;
|
||||
endIfDone();
|
||||
});
|
||||
// start unit test
|
||||
auto logic = start();
|
||||
// try establishing a session
|
||||
LogicCall(logic, [link = alice->link, rc = bob->rc]() { REQUIRE(link->TryEstablishTo(rc)); });
|
||||
});
|
||||
}
|
||||
|
||||
/// ensure relays cannot connect to clients
|
||||
TEST_CASE("IWP handshake reverse", "[iwp]")
|
||||
{
|
||||
RunIWPTest([](std::function<Logic_ptr(void)> start,
|
||||
[[maybe_unused]] std::function<void(void)> endIfDone,
|
||||
std::function<void(void)> endTestNow,
|
||||
Context_ptr alice,
|
||||
Context_ptr bob) {
|
||||
alice->InitLink<false>([](auto) {});
|
||||
bob->InitLink<true>([](auto) {});
|
||||
// start unit test
|
||||
auto logic = start();
|
||||
// try establishing a session in the wrong direction
|
||||
LogicCall(logic, [logic, link = bob->link, rc = alice->rc, endTestNow]() {
|
||||
REQUIRE(not link->TryEstablishTo(rc));
|
||||
endTestNow();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/// ensure iwp can send messages between sessions
|
||||
TEST_CASE("IWP send messages", "[iwp]")
|
||||
{
|
||||
RunIWPTest([](std::function<Logic_ptr(void)> start,
|
||||
std::function<void(void)> endIfDone,
|
||||
std::function<void(void)> endTestNow,
|
||||
Context_ptr alice,
|
||||
Context_ptr bob) {
|
||||
constexpr int aliceNumSend = 128;
|
||||
int aliceNumSent = 0;
|
||||
// when alice makes a session to bob send `aliceNumSend` messages to him
|
||||
alice->InitLink<false>([endIfDone, alice, &aliceNumSent](auto session) {
|
||||
for (auto index = 0; index < aliceNumSend; index++)
|
||||
{
|
||||
alice->m_Loop->call_soon([session, endIfDone, alice, &aliceNumSent]() {
|
||||
// generate a discard message that is 512 bytes long
|
||||
llarp::DiscardMessage msg;
|
||||
std::vector<byte_t> msgBuff(512);
|
||||
llarp_buffer_t buf(msgBuff);
|
||||
// add random padding
|
||||
llarp::CryptoManager::instance()->randomize(buf);
|
||||
// encode the discard message
|
||||
msg.BEncode(&buf);
|
||||
// send the message
|
||||
session->SendMessageBuffer(msgBuff, [endIfDone, alice, &aliceNumSent](auto status) {
|
||||
if (status == llarp::ILinkSession::DeliveryStatus::eDeliverySuccess)
|
||||
{
|
||||
// on successful transmit increment the number we sent
|
||||
aliceNumSent++;
|
||||
}
|
||||
// if we sent all the messages sucessfully we end the unit test
|
||||
alice->gucci = aliceNumSent == aliceNumSend;
|
||||
endIfDone();
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
bob->InitLink<true>([bob](auto) { bob->gucci = true; });
|
||||
// start unit test
|
||||
auto logic = start();
|
||||
// try establishing a session from alice to bob
|
||||
LogicCall(logic, [logic, link = alice->link, rc = bob->rc, endTestNow]() {
|
||||
REQUIRE(link->TryEstablishTo(rc));
|
||||
});
|
||||
});
|
||||
}
|
@ -1,329 +0,0 @@
|
||||
#include <crypto/crypto_libsodium.hpp>
|
||||
#include <ev/ev.h>
|
||||
#include <iwp/iwp.hpp>
|
||||
#include <llarp_test.hpp>
|
||||
#include <iwp/iwp.hpp>
|
||||
#include <memory>
|
||||
#include <messages/link_intro.hpp>
|
||||
#include <messages/discard.hpp>
|
||||
|
||||
#include <test_util.hpp>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using namespace ::llarp;
|
||||
using namespace ::testing;
|
||||
|
||||
struct LinkLayerTest : public test::LlarpTest<llarp::sodium::CryptoLibSodium>
|
||||
{
|
||||
static constexpr uint16_t AlicePort = 41163;
|
||||
static constexpr uint16_t BobPort = 8088;
|
||||
|
||||
struct Context
|
||||
{
|
||||
Context()
|
||||
{
|
||||
keyManager = std::make_shared<KeyManager>();
|
||||
|
||||
SecretKey signingKey;
|
||||
CryptoManager::instance()->identity_keygen(signingKey);
|
||||
keyManager->identityKey = signingKey;
|
||||
|
||||
SecretKey encryptionKey;
|
||||
CryptoManager::instance()->encryption_keygen(encryptionKey);
|
||||
keyManager->encryptionKey = encryptionKey;
|
||||
|
||||
SecretKey transportKey;
|
||||
CryptoManager::instance()->encryption_keygen(transportKey);
|
||||
keyManager->transportKey = transportKey;
|
||||
|
||||
rc.pubkey = signingKey.toPublic();
|
||||
rc.enckey = encryptionKey.toPublic();
|
||||
}
|
||||
|
||||
std::shared_ptr<thread::ThreadPool> worker;
|
||||
|
||||
std::shared_ptr<KeyManager> keyManager;
|
||||
|
||||
RouterContact rc;
|
||||
|
||||
bool madeSession = false;
|
||||
bool gotLIM = false;
|
||||
|
||||
bool
|
||||
IsGucci() const
|
||||
{
|
||||
return gotLIM && madeSession;
|
||||
}
|
||||
|
||||
void
|
||||
Setup()
|
||||
{
|
||||
worker = std::make_shared<thread::ThreadPool>(1, 128, "test-worker");
|
||||
worker->start();
|
||||
}
|
||||
|
||||
const RouterContact&
|
||||
GetRC() const
|
||||
{
|
||||
return rc;
|
||||
}
|
||||
|
||||
RouterID
|
||||
GetRouterID() const
|
||||
{
|
||||
return rc.pubkey;
|
||||
}
|
||||
|
||||
std::shared_ptr<ILinkLayer> link;
|
||||
|
||||
static std::string
|
||||
localLoopBack()
|
||||
{
|
||||
#if defined(__FreeBSD__) || defined(__OpenBSD__) || defined(__NetBSD__) || (__APPLE__ && __MACH__) \
|
||||
|| (__sun)
|
||||
return "lo0";
|
||||
#else
|
||||
return "lo";
|
||||
#endif
|
||||
}
|
||||
|
||||
bool
|
||||
Start(std::shared_ptr<Logic> logic, llarp_ev_loop_ptr loop, uint16_t port)
|
||||
{
|
||||
if (!link)
|
||||
return false;
|
||||
if (!link->Configure(loop, localLoopBack(), AF_INET, port))
|
||||
return false;
|
||||
/*
|
||||
* TODO: ephemeral key management
|
||||
if(!link->GenEphemeralKeys())
|
||||
return false;
|
||||
*/
|
||||
rc.addrs.emplace_back();
|
||||
if (!link->GetOurAddressInfo(rc.addrs[0]))
|
||||
return false;
|
||||
if (!rc.Sign(keyManager->identityKey))
|
||||
return false;
|
||||
return link->Start(logic, worker);
|
||||
}
|
||||
|
||||
void
|
||||
Stop()
|
||||
{
|
||||
if (link)
|
||||
link->Stop();
|
||||
if (worker)
|
||||
{
|
||||
worker->drain();
|
||||
worker->stop();
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
TearDown()
|
||||
{
|
||||
link.reset();
|
||||
worker.reset();
|
||||
}
|
||||
};
|
||||
|
||||
Context Alice;
|
||||
Context Bob;
|
||||
|
||||
bool success = false;
|
||||
const bool shouldDebug = false;
|
||||
|
||||
llarp_ev_loop_ptr netLoop;
|
||||
std::shared_ptr<Logic> m_logic;
|
||||
|
||||
llarp_time_t oldRCLifetime;
|
||||
llarp::LogLevel oldLevel;
|
||||
|
||||
LinkLayerTest() : netLoop(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
void
|
||||
SetUp()
|
||||
{
|
||||
oldLevel = llarp::LogContext::Instance().curLevel;
|
||||
if (shouldDebug)
|
||||
llarp::SetLogLevel(eLogTrace);
|
||||
oldRCLifetime = RouterContact::Lifetime;
|
||||
RouterContact::BlockBogons = false;
|
||||
RouterContact::Lifetime = 500ms;
|
||||
netLoop = llarp_make_ev_loop();
|
||||
m_logic.reset(new Logic());
|
||||
netLoop->set_logic(m_logic);
|
||||
Alice.Setup();
|
||||
Bob.Setup();
|
||||
}
|
||||
|
||||
void
|
||||
TearDown()
|
||||
{
|
||||
Alice.TearDown();
|
||||
Bob.TearDown();
|
||||
m_logic.reset();
|
||||
netLoop.reset();
|
||||
RouterContact::BlockBogons = true;
|
||||
RouterContact::Lifetime = oldRCLifetime;
|
||||
llarp::SetLogLevel(oldLevel);
|
||||
}
|
||||
|
||||
void
|
||||
RunMainloop()
|
||||
{
|
||||
m_logic->call_later(5s, std::bind(&LinkLayerTest::Stop, this));
|
||||
llarp_ev_loop_run_single_process(netLoop, m_logic);
|
||||
}
|
||||
|
||||
void
|
||||
Stop()
|
||||
{
|
||||
Alice.Stop();
|
||||
Bob.Stop();
|
||||
llarp_ev_loop_stop(netLoop);
|
||||
m_logic->stop();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(LinkLayerTest, TestIWP)
|
||||
{
|
||||
#ifdef WIN32
|
||||
GTEST_SKIP();
|
||||
#else
|
||||
auto sendDiscardMessage = [](ILinkSession* s, auto callback) -> bool {
|
||||
// send discard message in reply to complete unit test
|
||||
std::vector<byte_t> tmp(32);
|
||||
llarp_buffer_t otherBuf(tmp);
|
||||
DiscardMessage discard;
|
||||
if (!discard.BEncode(&otherBuf))
|
||||
return false;
|
||||
return s->SendMessageBuffer(std::move(tmp), callback);
|
||||
};
|
||||
Alice.link = iwp::NewInboundLink(
|
||||
// KeyManager
|
||||
Alice.keyManager,
|
||||
|
||||
// GetRCFunc
|
||||
[&]() -> const RouterContact& { return Alice.GetRC(); },
|
||||
|
||||
// LinkMessageHandler
|
||||
[&](ILinkSession* s, const llarp_buffer_t& buf) -> bool {
|
||||
llarp_buffer_t copy(buf.base, buf.sz);
|
||||
if (not Alice.gotLIM)
|
||||
{
|
||||
LinkIntroMessage msg;
|
||||
if (msg.BDecode(©))
|
||||
{
|
||||
Alice.gotLIM = s->GotLIM(&msg);
|
||||
}
|
||||
}
|
||||
return Alice.gotLIM;
|
||||
},
|
||||
|
||||
// SignBufferFunc
|
||||
[&](Signature& sig, const llarp_buffer_t& buf) -> bool {
|
||||
return m_crypto.sign(sig, Alice.keyManager->identityKey, buf);
|
||||
},
|
||||
|
||||
// SessionEstablishedHandler
|
||||
[&, this](ILinkSession* s) -> bool {
|
||||
const auto rc = s->GetRemoteRC();
|
||||
if (rc.pubkey != Bob.GetRC().pubkey)
|
||||
return false;
|
||||
LogInfo("alice established with bob");
|
||||
Alice.madeSession = true;
|
||||
sendDiscardMessage(s, [&](auto status) {
|
||||
success = status == llarp::ILinkSession::DeliveryStatus::eDeliverySuccess;
|
||||
LogInfo("message sent to bob suceess=", success);
|
||||
this->Stop();
|
||||
});
|
||||
return true;
|
||||
},
|
||||
|
||||
// SessionRenegotiateHandler
|
||||
[&](RouterContact, RouterContact) -> bool { return true; },
|
||||
|
||||
// TimeoutHandler
|
||||
[&](ILinkSession* session) {
|
||||
ASSERT_FALSE(session->IsEstablished());
|
||||
Stop();
|
||||
},
|
||||
|
||||
// SessionClosedHandler
|
||||
[&](RouterID router) { ASSERT_EQ(router, Alice.GetRouterID()); },
|
||||
|
||||
// PumpDoneHandler
|
||||
[]() {});
|
||||
|
||||
Bob.link = iwp::NewInboundLink(
|
||||
// KeyManager
|
||||
Bob.keyManager,
|
||||
|
||||
// GetRCFunc
|
||||
[&]() -> const RouterContact& { return Bob.GetRC(); },
|
||||
|
||||
// LinkMessageHandler
|
||||
[&](ILinkSession* s, const llarp_buffer_t& buf) -> bool {
|
||||
llarp_buffer_t copy(buf.base, buf.sz);
|
||||
if (not Bob.gotLIM)
|
||||
{
|
||||
LinkIntroMessage msg;
|
||||
if (msg.BDecode(©))
|
||||
{
|
||||
Bob.gotLIM = s->GotLIM(&msg);
|
||||
}
|
||||
return Bob.gotLIM;
|
||||
}
|
||||
DiscardMessage discard;
|
||||
if (discard.BDecode(©))
|
||||
{
|
||||
LogInfo("bog got discard message from alice");
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
|
||||
// SignBufferFunc
|
||||
[&](Signature& sig, const llarp_buffer_t& buf) -> bool {
|
||||
return m_crypto.sign(sig, Bob.keyManager->identityKey, buf);
|
||||
},
|
||||
|
||||
// SessionEstablishedHandler
|
||||
[&](ILinkSession* s) -> bool {
|
||||
if (s->GetRemoteRC().pubkey != Alice.GetRC().pubkey)
|
||||
return false;
|
||||
LogInfo("bob established with alice");
|
||||
Bob.madeSession = true;
|
||||
|
||||
return true;
|
||||
},
|
||||
|
||||
// SessionRenegotiateHandler
|
||||
[&](RouterContact newrc, RouterContact oldrc) -> bool {
|
||||
return newrc.pubkey == oldrc.pubkey;
|
||||
},
|
||||
|
||||
// TimeoutHandler
|
||||
[&](ILinkSession* session) { ASSERT_FALSE(session->IsEstablished()); },
|
||||
|
||||
// SessionClosedHandler
|
||||
[&](RouterID router) { ASSERT_EQ(router, Alice.GetRouterID()); },
|
||||
|
||||
// PumpDoneHandler
|
||||
[]() {});
|
||||
|
||||
ASSERT_TRUE(Alice.Start(m_logic, netLoop, AlicePort));
|
||||
ASSERT_TRUE(Bob.Start(m_logic, netLoop, BobPort));
|
||||
|
||||
LogicCall(m_logic, [&]() { ASSERT_TRUE(Alice.link->TryEstablishTo(Bob.GetRC())); });
|
||||
|
||||
RunMainloop();
|
||||
ASSERT_TRUE(Alice.IsGucci());
|
||||
ASSERT_TRUE(Bob.IsGucci());
|
||||
ASSERT_TRUE(success);
|
||||
#endif
|
||||
};
|
@ -1,456 +0,0 @@
|
||||
#include <util/thread/thread_pool.hpp>
|
||||
#include <util/thread/threading.hpp>
|
||||
#include <util/thread/barrier.hpp>
|
||||
|
||||
#include <condition_variable>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using namespace llarp;
|
||||
using namespace llarp::thread;
|
||||
|
||||
using LockGuard = std::unique_lock< std::mutex >;
|
||||
|
||||
class PoolArgs
|
||||
{
|
||||
public:
|
||||
std::mutex& mutex;
|
||||
std::condition_variable& start;
|
||||
std::condition_variable& stop;
|
||||
volatile size_t count;
|
||||
volatile size_t startSignal;
|
||||
volatile size_t stopSignal;
|
||||
};
|
||||
|
||||
class BarrierArgs
|
||||
{
|
||||
public:
|
||||
util::Barrier& startBarrier;
|
||||
util::Barrier& stopBarrier;
|
||||
|
||||
std::atomic_size_t count;
|
||||
};
|
||||
|
||||
class BasicWorkArgs
|
||||
{
|
||||
public:
|
||||
std::atomic_size_t count;
|
||||
};
|
||||
|
||||
void
|
||||
simpleFunction(PoolArgs& args)
|
||||
{
|
||||
LockGuard lock(args.mutex);
|
||||
++args.count;
|
||||
++args.startSignal;
|
||||
args.start.notify_one();
|
||||
|
||||
args.stop.wait(lock, [&]() { return args.stopSignal; });
|
||||
}
|
||||
|
||||
void
|
||||
incrementFunction(PoolArgs& args)
|
||||
{
|
||||
LockGuard lock(args.mutex);
|
||||
++args.count;
|
||||
++args.startSignal;
|
||||
args.start.notify_one();
|
||||
}
|
||||
|
||||
void
|
||||
barrierFunction(BarrierArgs& args)
|
||||
{
|
||||
args.startBarrier.Block();
|
||||
args.count++;
|
||||
args.stopBarrier.Block();
|
||||
}
|
||||
|
||||
void
|
||||
basicWork(BasicWorkArgs& args)
|
||||
{
|
||||
args.count++;
|
||||
}
|
||||
|
||||
void
|
||||
recurse(util::Barrier& barrier, std::atomic_size_t& counter, ThreadPool& pool,
|
||||
size_t depthLimit)
|
||||
{
|
||||
ASSERT_LE(0u, counter);
|
||||
ASSERT_GT(depthLimit, counter);
|
||||
|
||||
if(++counter != depthLimit)
|
||||
{
|
||||
ASSERT_TRUE(
|
||||
pool.addJob(std::bind(recurse, std::ref(barrier), std::ref(counter),
|
||||
std::ref(pool), depthLimit)));
|
||||
}
|
||||
|
||||
barrier.Block();
|
||||
}
|
||||
|
||||
class DestructiveObject
|
||||
{
|
||||
private:
|
||||
util::Barrier& barrier;
|
||||
ThreadPool& pool;
|
||||
|
||||
public:
|
||||
DestructiveObject(util::Barrier& b, ThreadPool& p) : barrier(b), pool(p)
|
||||
{
|
||||
}
|
||||
|
||||
~DestructiveObject()
|
||||
{
|
||||
auto job = std::bind(&util::Barrier::Block, &barrier);
|
||||
pool.addJob(job);
|
||||
}
|
||||
};
|
||||
|
||||
void
|
||||
destructiveJob(DestructiveObject* obj)
|
||||
{
|
||||
delete obj;
|
||||
}
|
||||
|
||||
TEST(TestThreadPool, breathing)
|
||||
{
|
||||
static constexpr size_t threads = 10;
|
||||
static constexpr size_t capacity = 50;
|
||||
|
||||
ThreadPool pool(threads, capacity, "breathing");
|
||||
|
||||
ASSERT_EQ(0u, pool.startedThreadCount());
|
||||
ASSERT_EQ(capacity, pool.capacity());
|
||||
ASSERT_EQ(0u, pool.jobCount());
|
||||
|
||||
ASSERT_TRUE(pool.start());
|
||||
|
||||
ASSERT_EQ(threads, pool.startedThreadCount());
|
||||
ASSERT_EQ(capacity, pool.capacity());
|
||||
ASSERT_EQ(0u, pool.jobCount());
|
||||
|
||||
pool.drain();
|
||||
}
|
||||
|
||||
struct AccessorsData
|
||||
{
|
||||
size_t threads;
|
||||
size_t capacity;
|
||||
};
|
||||
|
||||
std::ostream&
|
||||
operator<<(std::ostream& os, AccessorsData d)
|
||||
{
|
||||
os << "[ threads = " << d.threads << " capacity = " << d.capacity << " ]";
|
||||
return os;
|
||||
}
|
||||
|
||||
class Accessors : public ::testing::TestWithParam< AccessorsData >
|
||||
{
|
||||
};
|
||||
|
||||
TEST_P(Accessors, accessors)
|
||||
{
|
||||
auto d = GetParam();
|
||||
|
||||
ThreadPool pool(d.threads, d.capacity, "accessors");
|
||||
|
||||
ASSERT_EQ(d.threads, pool.threadCount());
|
||||
ASSERT_EQ(d.capacity, pool.capacity());
|
||||
ASSERT_EQ(0u, pool.startedThreadCount());
|
||||
}
|
||||
|
||||
static const AccessorsData accessorsData[] = {
|
||||
{10, 50}, {1, 1}, {50, 100}, {2, 22}, {100, 200}};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestThreadPool, Accessors,
|
||||
::testing::ValuesIn(accessorsData));
|
||||
|
||||
struct ClosingData
|
||||
{
|
||||
size_t threads;
|
||||
size_t capacity;
|
||||
};
|
||||
|
||||
std::ostream&
|
||||
operator<<(std::ostream& os, ClosingData d)
|
||||
{
|
||||
os << "[ threads = " << d.threads << " capacity = " << d.capacity << " ]";
|
||||
return os;
|
||||
}
|
||||
|
||||
class Closing : public ::testing::TestWithParam< ClosingData >
|
||||
{
|
||||
};
|
||||
|
||||
TEST_P(Closing, drain)
|
||||
{
|
||||
auto d = GetParam();
|
||||
|
||||
std::mutex mutex;
|
||||
std::condition_variable start;
|
||||
std::condition_variable stop;
|
||||
|
||||
PoolArgs args{mutex, start, stop, 0, 0, 0};
|
||||
|
||||
ThreadPool pool(d.threads, d.capacity, "drain");
|
||||
|
||||
ASSERT_EQ(d.threads, pool.threadCount());
|
||||
ASSERT_EQ(d.capacity, pool.capacity());
|
||||
ASSERT_EQ(0u, pool.startedThreadCount());
|
||||
|
||||
auto simpleJob = std::bind(simpleFunction, std::ref(args));
|
||||
|
||||
ASSERT_FALSE(pool.addJob(simpleJob));
|
||||
|
||||
ASSERT_TRUE(pool.start());
|
||||
ASSERT_EQ(0u, pool.jobCount());
|
||||
|
||||
LockGuard lock(mutex);
|
||||
|
||||
for(size_t i = 0; i < d.threads; ++i)
|
||||
{
|
||||
args.startSignal = 0;
|
||||
args.stopSignal = 0;
|
||||
ASSERT_TRUE(pool.addJob(simpleJob));
|
||||
|
||||
start.wait(lock, [&]() { return args.startSignal; });
|
||||
}
|
||||
|
||||
args.stopSignal++;
|
||||
|
||||
lock.unlock();
|
||||
|
||||
stop.notify_all();
|
||||
|
||||
pool.drain();
|
||||
|
||||
ASSERT_EQ(d.threads, pool.startedThreadCount());
|
||||
ASSERT_EQ(0u, pool.jobCount());
|
||||
}
|
||||
|
||||
TEST_P(Closing, stop)
|
||||
{
|
||||
auto d = GetParam();
|
||||
|
||||
ThreadPool pool(d.threads, d.capacity, "stop");
|
||||
|
||||
std::mutex mutex;
|
||||
std::condition_variable start;
|
||||
std::condition_variable stop;
|
||||
|
||||
PoolArgs args{mutex, start, stop, 0, 0, 0};
|
||||
|
||||
ASSERT_EQ(d.threads, pool.threadCount());
|
||||
ASSERT_EQ(d.capacity, pool.capacity());
|
||||
ASSERT_EQ(0u, pool.startedThreadCount());
|
||||
|
||||
auto simpleJob = std::bind(simpleFunction, std::ref(args));
|
||||
|
||||
ASSERT_FALSE(pool.addJob(simpleJob));
|
||||
|
||||
ASSERT_TRUE(pool.start());
|
||||
ASSERT_EQ(0u, pool.jobCount());
|
||||
|
||||
LockGuard lock(mutex);
|
||||
|
||||
for(size_t i = 0; i < d.capacity; ++i)
|
||||
{
|
||||
args.startSignal = 0;
|
||||
args.stopSignal = 0;
|
||||
ASSERT_TRUE(pool.addJob(simpleJob));
|
||||
|
||||
while(i < d.threads && !args.startSignal)
|
||||
{
|
||||
start.wait(lock);
|
||||
}
|
||||
}
|
||||
|
||||
args.stopSignal++;
|
||||
|
||||
lock.unlock();
|
||||
|
||||
stop.notify_all();
|
||||
|
||||
pool.stop();
|
||||
|
||||
ASSERT_EQ(d.capacity, args.count);
|
||||
ASSERT_EQ(0u, pool.startedThreadCount());
|
||||
ASSERT_EQ(0u, pool.activeThreadCount());
|
||||
ASSERT_EQ(0u, pool.jobCount());
|
||||
}
|
||||
|
||||
TEST_P(Closing, shutdown)
|
||||
{
|
||||
auto d = GetParam();
|
||||
|
||||
ThreadPool pool(d.threads, d.capacity, "shutdown");
|
||||
|
||||
std::mutex mutex;
|
||||
std::condition_variable start;
|
||||
std::condition_variable stop;
|
||||
|
||||
PoolArgs args{mutex, start, stop, 0, 0, 0};
|
||||
|
||||
ASSERT_EQ(d.threads, pool.threadCount());
|
||||
ASSERT_EQ(d.capacity, pool.capacity());
|
||||
ASSERT_EQ(0u, pool.startedThreadCount());
|
||||
|
||||
auto simpleJob = std::bind(simpleFunction, std::ref(args));
|
||||
|
||||
ASSERT_FALSE(pool.addJob(simpleJob));
|
||||
|
||||
ASSERT_TRUE(pool.start());
|
||||
ASSERT_EQ(0u, pool.jobCount());
|
||||
|
||||
LockGuard lock(mutex);
|
||||
|
||||
for(size_t i = 0; i < d.capacity; ++i)
|
||||
{
|
||||
args.startSignal = 0;
|
||||
args.stopSignal = 0;
|
||||
ASSERT_TRUE(pool.addJob(simpleJob));
|
||||
|
||||
while(i < d.threads && !args.startSignal)
|
||||
{
|
||||
start.wait(lock);
|
||||
}
|
||||
}
|
||||
|
||||
ASSERT_EQ(d.threads, pool.startedThreadCount());
|
||||
ASSERT_EQ(d.capacity - d.threads, pool.jobCount());
|
||||
|
||||
auto incrementJob = std::bind(incrementFunction, std::ref(args));
|
||||
|
||||
for(size_t i = 0; i < d.threads; ++i)
|
||||
{
|
||||
ASSERT_TRUE(pool.addJob(incrementJob));
|
||||
}
|
||||
|
||||
args.stopSignal++;
|
||||
stop.notify_all();
|
||||
|
||||
lock.unlock();
|
||||
|
||||
pool.shutdown();
|
||||
|
||||
ASSERT_EQ(0u, pool.startedThreadCount());
|
||||
ASSERT_EQ(0u, pool.activeThreadCount());
|
||||
ASSERT_EQ(0u, pool.jobCount());
|
||||
}
|
||||
|
||||
ClosingData closingData[] = {{1, 1}, {2, 2}, {10, 10},
|
||||
{10, 50}, {50, 75}, {25, 80}};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestThreadPool, Closing,
|
||||
::testing::ValuesIn(closingData));
|
||||
|
||||
struct TryAddData
|
||||
{
|
||||
size_t threads;
|
||||
size_t capacity;
|
||||
};
|
||||
|
||||
std::ostream&
|
||||
operator<<(std::ostream& os, TryAddData d)
|
||||
{
|
||||
os << "[ threads = " << d.threads << " capacity = " << d.capacity << " ]";
|
||||
return os;
|
||||
}
|
||||
|
||||
class TryAdd : public ::testing::TestWithParam< TryAddData >
|
||||
{
|
||||
};
|
||||
|
||||
TEST_P(TryAdd, noblocking)
|
||||
{
|
||||
// Verify that tryAdd does not block.
|
||||
// Fill the queue, then verify `tryAddJob` does not block.
|
||||
auto d = GetParam();
|
||||
|
||||
ThreadPool pool(d.threads, d.capacity, "noblocking");
|
||||
|
||||
util::Barrier startBarrier(d.threads + 1);
|
||||
util::Barrier stopBarrier(d.threads + 1);
|
||||
|
||||
BarrierArgs args{startBarrier, stopBarrier, {0}};
|
||||
|
||||
auto simpleJob = std::bind(barrierFunction, std::ref(args));
|
||||
|
||||
ASSERT_FALSE(pool.tryAddJob(simpleJob));
|
||||
|
||||
ASSERT_TRUE(pool.start());
|
||||
|
||||
for(size_t i = 0; i < d.threads; ++i)
|
||||
{
|
||||
ASSERT_TRUE(pool.tryAddJob(simpleJob));
|
||||
}
|
||||
|
||||
// Wait for everything to start.
|
||||
startBarrier.Block();
|
||||
|
||||
// and that we emptied the queue.
|
||||
ASSERT_EQ(0u, pool.jobCount());
|
||||
|
||||
BasicWorkArgs basicWorkArgs = {{0}};
|
||||
|
||||
auto workJob = std::bind(basicWork, std::ref(basicWorkArgs));
|
||||
|
||||
for(size_t i = 0; i < d.capacity; ++i)
|
||||
{
|
||||
ASSERT_TRUE(pool.tryAddJob(workJob));
|
||||
}
|
||||
|
||||
// queue should now be full
|
||||
ASSERT_FALSE(pool.tryAddJob(workJob));
|
||||
|
||||
// and finish
|
||||
stopBarrier.Block();
|
||||
}
|
||||
|
||||
TEST(TestThreadPool, recurseJob)
|
||||
{
|
||||
// Verify we can enqueue a job onto the threadpool from a thread which is
|
||||
// currently executing a threadpool job.
|
||||
|
||||
static constexpr size_t threads = 10;
|
||||
static constexpr size_t depth = 10;
|
||||
static constexpr size_t capacity = 100;
|
||||
|
||||
util::Barrier barrier(threads + 1);
|
||||
std::atomic_size_t counter{0};
|
||||
|
||||
ThreadPool pool(threads, capacity, "recurse");
|
||||
|
||||
pool.start();
|
||||
|
||||
ASSERT_TRUE(pool.addJob(std::bind(recurse, std::ref(barrier),
|
||||
std::ref(counter), std::ref(pool), depth)));
|
||||
|
||||
barrier.Block();
|
||||
ASSERT_EQ(depth, counter);
|
||||
}
|
||||
|
||||
TEST(TestThreadPool, destructors)
|
||||
{
|
||||
// Verify that functors have their destructors called outside of threadpool
|
||||
// locks.
|
||||
|
||||
static constexpr size_t threads = 1;
|
||||
static constexpr size_t capacity = 100;
|
||||
|
||||
ThreadPool pool(threads, capacity, "destructors");
|
||||
|
||||
pool.start();
|
||||
|
||||
util::Barrier barrier(threads + 1);
|
||||
|
||||
{
|
||||
DestructiveObject* obj = new DestructiveObject(barrier, pool);
|
||||
ASSERT_TRUE(pool.addJob(std::bind(destructiveJob, obj)));
|
||||
}
|
||||
|
||||
barrier.Block();
|
||||
}
|
Loading…
Reference in New Issue