diff --git a/daemon/lokinet-vpn.cpp b/daemon/lokinet-vpn.cpp index 4791455c6..cf8bfc9bd 100644 --- a/daemon/lokinet-vpn.cpp +++ b/daemon/lokinet-vpn.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -50,6 +51,50 @@ OMQ_Request( return std::nullopt; } +namespace +{ + template + constexpr bool is_optional = false; + template + constexpr bool is_optional> = true; + + // Extracts a value from a cxxopts result and assigns it into `value` if present. The value can + // either be a plain value or a std::optional. If not present, `value` is not touched. + template + void + extract_option(const cxxopts::ParseResult& r, const std::string& name, T& value) + { + if (r.count(name)) + { + if constexpr (is_optional) + value = r[name].as(); + else + value = r[name].as(); + } + } + + // Takes a code, prints a message, and returns the code. Intended use is: + // return exit_error(1, "blah: {}", 42); + // from within main(). + template + [[nodiscard]] int + exit_error(int code, const std::string& format, T&&... args) + { + fmt::print(format, std::forward(args)...); + fmt::print("\n"); + return code; + } + + // Same as above, but with code omitted (uses exit code 1) + template + [[nodiscard]] int + exit_error(const std::string& format, T&&... args) + { + return exit_error(1, format, std::forward(args)...); + } + +} // namespace + int main(int argc, char* argv[]) { @@ -95,57 +140,36 @@ main(int argc, char* argv[]) { logLevel = oxenmq::LogLevel::debug; } - if (result.count("rpc") > 0) - { - rpcURL = oxenmq::address(result["rpc"].as()); - } - if (result.count("exit") > 0) - { - exitAddress = result["exit"].as(); - } goUp = result.count("up") > 0; goDown = result.count("down") > 0; printStatus = result.count("status") > 0; killDaemon = result.count("kill") > 0; - if (result.count("endpoint") > 0) - { - endpoint = result["endpoint"].as(); - } - if (result.count("token") > 0) - { - token = result["token"].as(); - } - if (result.count("auth") > 0) - { - token = result["auth"].as(); - } - if (result.count("range") > 0) - { - range = result["range"].as(); - } + extract_option(result, "rpc", rpcURL); + extract_option(result, "exit", exitAddress); + extract_option(result, "endpoint", endpoint); + extract_option(result, "token", token); + extract_option(result, "auth", token); + extract_option(result, "range", range); } catch (const cxxopts::option_not_exists_exception& ex) { - std::cerr << ex.what(); - std::cout << opts.help() << std::endl; - return 1; + return exit_error(2, "{}\n{}", ex.what(), opts.help()); } catch (std::exception& ex) { - std::cout << ex.what() << std::endl; - return 1; - } - if ((not goUp) and (not goDown) and (not printStatus) and (not killDaemon)) - { - std::cout << opts.help() << std::endl; - return 1; + return exit_error(2, "{}", ex.what()); } + + int num_commands = goUp + goDown + printStatus + killDaemon; + + if (num_commands == 0) + return exit_error(3, "One of --up/--down/--status/--kill must be specified"); + if (num_commands != 1) + return exit_error(3, "Only one of --up/--down/--status/--kill may be specified"); + if (goUp and exitAddress.empty()) - { - std::cout << "no exit address provided" << std::endl; - return 1; - } + return exit_error("no exit address provided"); oxenmq::OxenMQ omq{ [](oxenmq::LogLevel lvl, const char* file, int line, std::string msg) { @@ -173,23 +197,16 @@ main(int argc, char* argv[]) if (killDaemon) { - const auto maybe = OMQ_Request(lmq, connID, "llarp.halt"); - if (not maybe.has_value()) - { - std::cout << "call to llarp.admin.die failed" << std::endl; - return 1; - } + if (not OMQ_Request(omq, connID, "llarp.halt")) + return exit_error("call to llarp.halt failed"); return 0; } if (printStatus) { - const auto maybe_status = OMQ_Request(lmq, connID, "llarp.status"); - if (not maybe_status.has_value()) - { - std::cout << "call to llarp.status failed" << std::endl; - return 1; - } + const auto maybe_status = OMQ_Request(omq, connID, "llarp.status"); + if (not maybe_status) + return exit_error("call to llarp.status failed"); try { @@ -209,8 +226,7 @@ main(int argc, char* argv[]) } catch (std::exception& ex) { - std::cout << "failed to parse result: " << ex.what() << std::endl; - return 1; + return exit_error("failed to parse result: {}", ex.what()); } return 0; } @@ -220,18 +236,14 @@ main(int argc, char* argv[]) if (range) opts["range"] = *range; - auto maybe_result = OMQ_Request(omq, connID, "llarp.exit", opts); + auto maybe_result = OMQ_Request(omq, connID, "llarp.exit", std::move(opts)); - if (not maybe_result.has_value()) - { - std::cout << "could not add exit" << std::endl; - return 1; - } + if (not maybe_result) + return exit_error("could not add exit"); - if (maybe_result->contains("error") and maybe_result->at("error").is_string()) + if (auto err_it = maybe_result->find("error"); err_it != maybe_result->end()) { - std::cout << maybe_result->at("error").get() << std::endl; - return 1; + return exit_error("{}", err_it->get()); } } if (goDown) @@ -239,7 +251,8 @@ main(int argc, char* argv[]) nlohmann::json opts{{"unmap", true}}; if (range) opts["range"] = *range; - OMQ_Request(omq, connID, "llarp.exit", std::move(opts)); + if (not OMQ_Request(omq, connID, "llarp.exit", std::move(opts))) + return exit_error("failed to unmap exit"); } return 0;