Unify LuaSocket usage (#7405)

* Add a new socketutil module with a few helper functions that allow us to:
  * Always use a sane User-Agent (previously, only Wikipedia did so)
  * Set timeouts in an almost sane manner. Doing it explicitly prevents an interaction with KOSync that does crazy stuff I don't even want to try to understand.
* Unified said timeouts based on the request's intended usage (except for Wikipedia, which already had meaningful timeout values).
* Stopped using LuaSec directly, LuaSocket defers to LuaSec sanely on its own. Everything now transparently supports HTTPS without code duplication.
pull/7412/head
NiLuJe 3 years ago committed by GitHub
parent 89c0578c8d
commit 2f9db25969
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,10 +1,9 @@
local DocumentRegistry = require("document/documentregistry") local DocumentRegistry = require("document/documentregistry")
local JSON = require("json") local JSON = require("json")
local http = require('socket.http') local http = require("socket.http")
local https = require('ssl.https') local ltn12 = require("ltn12")
local ltn12 = require('ltn12') local socket = require("socket")
local socket = require('socket') local socketutil = require("socketutil")
local url = require('socket.url')
local _ = require("gettext") local _ = require("gettext")
local DropBoxApi = { local DropBoxApi = {
@ -15,17 +14,18 @@ local API_LIST_FOLDER = "https://api.dropboxapi.com/2/files/list_folder"
local API_DOWNLOAD_FILE = "https://content.dropboxapi.com/2/files/download" local API_DOWNLOAD_FILE = "https://content.dropboxapi.com/2/files/download"
function DropBoxApi:fetchInfo(token) function DropBoxApi:fetchInfo(token)
local request, sink = {}, {} local sink = {}
local parsed = url.parse(API_URL_INFO) socketutil:set_timeout()
request['url'] = API_URL_INFO local request = {
request['method'] = 'POST' url = API_URL_INFO,
local headers = { ["Authorization"] = "Bearer ".. token } method = "POST",
request['headers'] = headers headers = {
request['sink'] = ltn12.sink.table(sink) ["Authorization"] = "Bearer " .. token,
http.TIMEOUT = 5 },
https.TIMEOUT = 5 sink = ltn12.sink.table(sink),
local httpRequest = parsed.scheme == 'http' and http.request or https.request }
local headers_request = socket.skip(1, httpRequest(request)) local headers_request = socket.skip(1, http.request(request))
socketutil:reset_timeout()
local result_response = table.concat(sink) local result_response = table.concat(sink)
if headers_request == nil then if headers_request == nil then
return nil return nil
@ -39,23 +39,24 @@ function DropBoxApi:fetchInfo(token)
end end
function DropBoxApi:fetchListFolders(path, token) function DropBoxApi:fetchListFolders(path, token)
local request, sink = {}, {}
if path == nil or path == "/" then path = "" end if path == nil or path == "/" then path = "" end
local parsed = url.parse(API_LIST_FOLDER)
request['url'] = API_LIST_FOLDER
request['method'] = 'POST'
local data = "{\"path\": \"" .. path .. "\",\"recursive\": false,\"include_media_info\": false,".. local data = "{\"path\": \"" .. path .. "\",\"recursive\": false,\"include_media_info\": false,"..
"\"include_deleted\": false,\"include_has_explicit_shared_members\": false}" "\"include_deleted\": false,\"include_has_explicit_shared_members\": false}"
local headers = { ["Authorization"] = "Bearer ".. token, local sink = {}
["Content-Type"] = "application/json" , socketutil:set_timeout()
["Content-Length"] = #data} local request = {
request['headers'] = headers url = API_LIST_FOLDER,
request['source'] = ltn12.source.string(data) method = "POST",
request['sink'] = ltn12.sink.table(sink) headers = {
http.TIMEOUT = 5 ["Authorization"] = "Bearer ".. token,
https.TIMEOUT = 5 ["Content-Type"] = "application/json",
local httpRequest = parsed.scheme == 'http' and http.request or https.request ["Content-Length"] = #data,
local headers_request = socket.skip(1, httpRequest(request)) },
source = ltn12.source.string(data),
sink = ltn12.sink.table(sink),
}
local headers_request = socket.skip(1, http.request(request))
socketutil:reset_timeout()
if headers_request == nil then if headers_request == nil then
return nil return nil
end end
@ -73,20 +74,18 @@ function DropBoxApi:fetchListFolders(path, token)
end end
function DropBoxApi:downloadFile(path, token, local_path) function DropBoxApi:downloadFile(path, token, local_path)
local parsed = url.parse(API_DOWNLOAD_FILE)
local url_api = API_DOWNLOAD_FILE
local data1 = "{\"path\": \"" .. path .. "\"}" local data1 = "{\"path\": \"" .. path .. "\"}"
local headers = { ["Authorization"] = "Bearer ".. token, socketutil:set_timeout(socketutil.FILE_BLOCK_TIMEOUT, socketutil.FILE_TOTAL_TIMEOUT)
["Dropbox-API-Arg"] = data1} local code_return = socket.skip(1, http.request{
http.TIMEOUT = 5 url = API_DOWNLOAD_FILE,
https.TIMEOUT = 5 method = "GET",
local httpRequest = parsed.scheme == 'http' and http.request or https.request headers = {
local _, code_return, _ = httpRequest{ ["Authorization"] = "Bearer ".. token,
url = url_api, ["Dropbox-API-Arg"] = data1,
method = 'GET', },
headers = headers, sink = ltn12.sink.file(io.open(local_path, "w")),
sink = ltn12.sink.file(io.open(local_path, "w")) })
} socketutil:reset_timeout()
return code_return return code_return
end end

@ -1,11 +1,9 @@
local DocumentRegistry = require("document/documentregistry") local DocumentRegistry = require("document/documentregistry")
local FFIUtil = require("ffi/util") local FFIUtil = require("ffi/util")
local http = require('socket.http') local http = require("socket.http")
local https = require('ssl.https') local ltn12 = require("ltn12")
local ltn12 = require('ltn12') local socket = require("socket")
local mime = require('mime') local socketutil = require("socketutil")
local socket = require('socket')
local url = require('socket.url')
local util = require("util") local util = require("util")
local _ = require("gettext") local _ = require("gettext")
@ -77,23 +75,24 @@ function WebDavApi:listFolder(address, user, pass, folder_path)
webdav_url = webdav_url .. "/" webdav_url = webdav_url .. "/"
end end
local request, sink = {}, {} local sink = {}
local parsed = url.parse(webdav_url)
local data = [[<?xml version="1.0"?><a:propfind xmlns:a="DAV:"><a:prop><a:resourcetype/></a:prop></a:propfind>]] local data = [[<?xml version="1.0"?><a:propfind xmlns:a="DAV:"><a:prop><a:resourcetype/></a:prop></a:propfind>]]
local auth = string.format("%s:%s", user, pass) socketutil:set_timeout()
local headers = { ["Authorization"] = "Basic " .. mime.b64( auth ), local request = {
["Content-Type"] = "application/xml", url = webdav_url,
["Depth"] = "1", method = "PROPFIND",
["Content-Length"] = #data} headers = {
request["url"] = webdav_url ["Content-Type"] = "application/xml",
request["method"] = "PROPFIND" ["Depth"] = "1",
request["headers"] = headers ["Content-Length"] = #data,
request["source"] = ltn12.source.string(data) },
request["sink"] = ltn12.sink.table(sink) username = user,
http.TIMEOUT = 5 password = pass,
https.TIMEOUT = 5 source = ltn12.source.string(data),
local httpRequest = parsed.scheme == "http" and http.request or https.request sink = ltn12.sink.table(sink),
local headers_request = socket.skip(1, httpRequest(request)) }
local headers_request = socket.skip(1, http.request(request))
socketutil:reset_timeout()
if headers_request == nil then if headers_request == nil then
return nil return nil
end end
@ -152,18 +151,15 @@ function WebDavApi:listFolder(address, user, pass, folder_path)
end end
function WebDavApi:downloadFile(file_url, user, pass, local_path) function WebDavApi:downloadFile(file_url, user, pass, local_path)
local parsed = url.parse(file_url) socketutil:set_timeout(socketutil.FILE_BLOCK_TIMEOUT, socketutil.FILE_TOTAL_TIMEOUT)
local auth = string.format("%s:%s", user, pass) local code_return = socket.skip(1, http.request{
local headers = { ["Authorization"] = "Basic " .. mime.b64( auth ) } url = file_url,
http.TIMEOUT = 5 method = "GET",
https.TIMEOUT = 5 sink = ltn12.sink.file(io.open(local_path, "w")),
local httpRequest = parsed.scheme == "http" and http.request or https.request username = user,
local _, code_return, _ = httpRequest{ password = pass,
url = file_url, })
method = "GET", socketutil:reset_timeout()
headers = headers,
sink = ltn12.sink.file(io.open(local_path, "w"))
}
return code_return return code_return
end end

@ -965,25 +965,21 @@ end
function ReaderDictionary:downloadDictionary(dict, download_location, continue) function ReaderDictionary:downloadDictionary(dict, download_location, continue)
continue = continue or false continue = continue or false
local socket = require("socket") local socket = require("socket")
local socketutil = require("socketutil")
local http = socket.http local http = socket.http
local https = require("ssl.https")
local ltn12 = require("ltn12") local ltn12 = require("ltn12")
local url = socket.url
local parsed = url.parse(dict.url)
local httpRequest = parsed.scheme == "http" and http.request or https.request
if not continue then if not continue then
local file_size local file_size
--local r, c, h = httpRequest { -- Skip body & code args
local dummy, headers, dummy = socket.skip(1, httpRequest{ socketutil:set_timeout()
method = "HEAD", local headers = socket.skip(2, http.request{
url = dict.url, method = "HEAD",
url = dict.url,
--redirect = true, --redirect = true,
}) })
--logger.dbg(status) socketutil:reset_timeout()
--logger.dbg(headers) --logger.dbg(headers)
--logger.dbg(code)
file_size = headers and headers["content-length"] file_size = headers and headers["content-length"]
UIManager:show(ConfirmBox:new{ UIManager:show(ConfirmBox:new{
@ -1004,10 +1000,12 @@ function ReaderDictionary:downloadDictionary(dict, download_location, continue)
end) end)
end end
local dummy, c, dummy = httpRequest{ socketutil:set_timeout(socketutil.FILE_BLOCK_TIMEOUT, socketutil.FILE_TOTAL_TIMEOUT)
url = dict.url, local c = socket.skip(1, http.request{
sink = ltn12.sink.file(io.open(download_location, "w")), url = dict.url,
} sink = ltn12.sink.file(io.open(download_location, "w")),
})
socketutil:reset_timeout()
if c == 200 then if c == 200 then
logger.dbg("file downloaded to", download_location) logger.dbg("file downloaded to", download_location)
else else

@ -0,0 +1,141 @@
--[[--
This module contains miscellaneous helper functions specific to our usage of LuaSocket/LuaSec.
]]
local Version = require("version")
local http = require("socket.http")
local https = require("ssl.https")
local ltn12 = require("ltn12")
local socket = require("socket")
local socketutil = {
-- Init to the default LuaSocket/LuaSec values
block_timeout = 60,
total_timeout = -1,
}
--- Builds a sensible UserAgent that fits Wikipedia's UA policy <https://meta.wikimedia.org/wiki/User-Agent_policy>
local socket_ua = http.USERAGENT
socketutil.USER_AGENT = "KOReader/" .. Version:getShortVersion() .. " (https://koreader.rocks/) " .. socket_ua:gsub(" ", "/")
-- Monkey-patch it in LuaSocket, as it already takes care of inserting the appropriate header to its requests.
http.USERAGENT = socketutil.USER_AGENT
--- Common timeout values
-- Large content
socketutil.LARGE_BLOCK_TIMEOUT = 10
socketutil.LARGE_TOTAL_TIMEOUT = 30
-- File downloads
socketutil.FILE_BLOCK_TIMEOUT = 15
socketutil.FILE_TOTAL_TIMEOUT = 60
-- Upstream defaults
socketutil.DEFAULT_BLOCK_TIMEOUT = 60
socketutil.DEFAULT_TOTAL_TIMEOUT = -1
--- Update the timeout values.
-- Note that this only affects socket polling,
-- c.f., LuaSocket's timeout_getretry @ src/timeout.c & usage in src/usocket.c
-- Moreover, the timeout is actually *reset* between polls (via timeout_markstart, e.g. in buffer_meth_receive).
-- So, in practice, this timeout only helps *very* bad connections (on one end or the other),
-- and you'd be hard-pressed to ever hit the *total* timeout, since the starting point is reset extremely often.
-- In our case, we want to enforce an *actual* limit on how much time we're willing to block for, start to finish.
-- We do that via the custom sinks below, which will start ticking as soon as the first chunk of data is received.
-- To simplify, in most cases, the socket timeout matters *before* we receive data,
-- and the sink timeout *once* we've started receiving data (at which point the socket timeout is reset every chunk).
-- In practice, that means you don't want to set block_timeout too low,
-- as that's what the socket timeout will end up using most of the time.
-- Note that name resolution happens earlier and one level lower (e.g., glibc),
-- so name resolution delays will fall outside of these timeouts.
function socketutil:set_timeout(block_timeout, total_timeout)
self.block_timeout = block_timeout or 5
self.total_timeout = total_timeout or 15
-- Also update the actual LuaSocket & LuaSec constants, because:
-- 1. LuaSocket's `open` does a `settimeout` *after* create with this constant
-- 2. KOSync updates it to a stupidly low value
http.TIMEOUT = self.block_timeout
https.TIMEOUT = self.block_timeout
end
--- Reset timeout values to LuaSocket defaults.
function socketutil:reset_timeout()
self.block_timeout = self.DEFAULT_BLOCK_TIMEOUT
self.total_timeout = self.DEFAULT_TOTAL_TIMEOUT
http.TIMEOUT = self.block_timeout
https.TIMEOUT = self.block_timeout
end
--- Monkey-patch LuaSocket's `socket.tcp` in order to honor tighter timeouts, to avoid blocking the UI for too long.
-- NOTE: While we could use a custom `create` function for HTTP LuaSocket `request`s,
-- with HTTPS, the way LuaSocket/LuaSec handles those is much more finicky,
-- because LuaSocket's adjustrequest function (in http.lua) passes the adjusted nreqt table to it,
-- but only when it does the automagic scheme handling, not when it's set by the caller :/.
-- And LuaSec's own `request` function overload *forbids* setting create, because of similar shenanigans...
-- TL;DR: Just monkey-patching socket.tcp directly will affect both HTTP & HTTPS
-- without us having to maintain a tweaked version of LuaSec's `https.tcp` function...
local real_socket_tcp = socket.tcp
function socketutil.tcp()
-- Based on https://stackoverflow.com/a/6021774
local req_sock = real_socket_tcp()
req_sock:settimeout(socketutil.block_timeout, "b")
req_sock:settimeout(socketutil.total_timeout, "t")
return req_sock
end
socket.tcp = socketutil.tcp
--- Various timeout return codes
socketutil.TIMEOUT_CODE = "timeout" -- from LuaSocket's io.c
socketutil.SSL_HANDSHAKE_CODE = "wantread" -- from LuaSec's ssl.c
socketutil.SINK_TIMEOUT_CODE = "sink timeout" -- from our own socketutil
-- NOTE: Use os.time() for simplicity's sake (we don't really need subsecond precision).
-- LuaSocket itself is already using gettimeofday anyway (although it does the maths, like ffi/util's getTimestamp).
-- Proper etiquette would have everyone using clock_gettime(CLOCK_MONOTONIC) for this kind of stuff,
-- but it's a tad more annoying to use because it's stuffed in librt in old glibc versions,
-- and I have no idea what macOS & Android do with it (but it is POSIX). Plus, win32.
--- Custom version of `ltn12.sink.table` that honors total_timeout
function socketutil.table_sink(t)
if socketutil.total_timeout < 0 then
return ltn12.sink.table(t)
end
local start_ts = os.time()
t = t or {}
local f = function(chunk, err)
if chunk then
if os.time() - start_ts > socketutil.total_timeout then
return nil, socketutil.SINK_TIMEOUT_CODE
end
table.insert(t, chunk)
end
return 1
end
return f, t
end
--- Custom version of `ltn12.sink.file` that honors total_timeout
function socketutil.file_sink(handle, io_err)
if socketutil.total_timeout < 0 then
return ltn12.sink.file(handle, io_err)
end
if handle then
local start_ts = os.time()
return function(chunk, err)
if not chunk then
handle:close()
return 1
else
if os.time() - start_ts > socketutil.total_timeout then
handle:close()
return nil, socketutil.SINK_TIMEOUT_CODE
end
return handle:write(chunk)
end
end
else
return nil, io_err or "unable to open file"
end
end
return socketutil

@ -166,6 +166,8 @@ end
function OTAManager:checkUpdate() function OTAManager:checkUpdate()
local http = require("socket.http") local http = require("socket.http")
local ltn12 = require("ltn12") local ltn12 = require("ltn12")
local socket = require("socket")
local socketutil = require("socketutil")
local update_file = (self:getOTAType() == "link") and self:getLinkFilename() or self:getZsyncFilename() local update_file = (self:getOTAType() == "link") and self:getLinkFilename() or self:getZsyncFilename()
@ -173,11 +175,14 @@ function OTAManager:checkUpdate()
local local_update_file = ota_dir .. update_file local local_update_file = ota_dir .. update_file
-- download zsync file from OTA server -- download zsync file from OTA server
logger.dbg("downloading update file", ota_update_file) logger.dbg("downloading update file", ota_update_file)
local _, c, _ = http.request{ socketutil:set_timeout()
url = ota_update_file, local code, _, status = socket.skip(1, http.request{
sink = ltn12.sink.file(io.open(local_update_file, "w"))} url = ota_update_file,
if c ~= 200 then sink = ltn12.sink.file(io.open(local_update_file, "w")),
logger.warn("cannot find update file", c) })
socketutil:reset_timeout()
if code ~= 200 then
logger.warn("cannot find update file:", status or code or "network unreachable")
return return
end end
-- parse OTA package version -- parse OTA package version

@ -284,13 +284,12 @@ Returns decoded JSON table from translate server.
@treturn string result, or nil @treturn string result, or nil
--]] --]]
function Translator:loadPage(text, target_lang, source_lang) function Translator:loadPage(text, target_lang, source_lang)
local socket = require('socket') local socket = require("socket")
local url = require('socket.url') local socketutil = require("socketutil")
local http = require('socket.http') local url = require("socket.url")
local https = require('ssl.https') local http = require("socket.http")
local ltn12 = require('ltn12') local ltn12 = require("ltn12")
local request, sink = {}, {}
local query = "" local query = ""
self.trans_params.tl = target_lang self.trans_params.tl = target_lang
self.trans_params.sl = source_lang self.trans_params.sl = source_lang
@ -308,25 +307,24 @@ function Translator:loadPage(text, target_lang, source_lang)
parsed.query = query .. "q=" .. url.escape(text) parsed.query = query .. "q=" .. url.escape(text)
-- HTTP request -- HTTP request
request['url'] = url.build(parsed) local sink = {}
socketutil:set_timeout()
local request = {
url = url.build(parsed),
method = "GET",
sink = ltn12.sink.table(sink),
}
logger.dbg("Calling", request.url) logger.dbg("Calling", request.url)
request['method'] = 'GET' -- Skip first argument (body, goes to the sink)
request['sink'] = ltn12.sink.table(sink) local code, headers, status = socket.skip(1, http.request(request))
-- We may try to set a common User-Agent if it happens we're 403 Forbidden socketutil:reset_timeout()
-- request['headers'] = {
-- ["User-Agent"] = "Mozilla/5.0 (Windows NT 10.0; Win64; x64)",
-- }
http.TIMEOUT, https.TIMEOUT = 10, 10
local httpRequest = parsed.scheme == 'http' and http.request or https.request
-- first argument returned by skip is code
local _, headers, status = socket.skip(1, httpRequest(request))
-- raise error message when network is unavailable -- raise error message when network is unavailable
if headers == nil then if headers == nil then
error("Network is unreachable") error("Network is unreachable")
end end
if status ~= "HTTP/1.1 200 OK" then if code ~= 200 then
logger.warn("translator HTTP status not okay:", status) logger.warn("translator HTTP status not okay:", status)
return return
end end

@ -95,84 +95,40 @@ function Wikipedia:getWikiServer(lang)
return string.format(self.wiki_server, lang or self.default_lang) return string.format(self.wiki_server, lang or self.default_lang)
end end
-- Say who we are to Wikipedia (see https://meta.wikimedia.org/wiki/User-Agent_policy)
local USER_AGENT = T("KOReader/%1 (https://koreader.rocks/) %2",
(lfs.attributes("git-rev", "mode") == "file" and io.open("git-rev", "r"):read("*line") or "devel"),
require('socket.http').USERAGENT:gsub(" ", "/") )
-- Codes that getUrlContent may get from requester.request()
local TIMEOUT_CODE = "timeout" -- from socket.lua
local MAXTIME_CODE = "maxtime reached" -- from sink_table_with_maxtime
-- Sink that stores into a table, aborting if maxtime has elapsed
local function sink_table_with_maxtime(t, maxtime)
-- Start counting as soon as this sink is created
local start_secs, start_usecs = ffiutil.gettime()
local starttime = start_secs + start_usecs/1000000
t = t or {}
local f = function(chunk, err)
local secs, usecs = ffiutil.gettime()
if secs + usecs/1000000 - starttime > maxtime then
return nil, MAXTIME_CODE
end
if chunk then table.insert(t, chunk) end
return 1
end
return f, t
end
-- Get URL content -- Get URL content
local function getUrlContent(url, timeout, maxtime) local function getUrlContent(url, timeout, maxtime)
local socket = require('socket') local http = require("socket.http")
local ltn12 = require('ltn12') local ltn12 = require("ltn12")
local http = require('socket.http') local socket = require("socket")
local https = require('ssl.https') local socketutil = require("socketutil")
local socket_url = require("socket.url")
local requester
if url:sub(1,7) == "http://" then local parsed = socket_url.parse(url)
requester = http if parsed.scheme ~= "http" and parsed.scheme ~= "https" then
elseif url:sub(1,8) == "https://" then
requester = https
else
return false, "Unsupported protocol" return false, "Unsupported protocol"
end end
if not timeout then timeout = 10 end if not timeout then timeout = 10 end
-- timeout needs to be set to 'http', even if we use 'https'
http.TIMEOUT, https.TIMEOUT = timeout, timeout
local request = {}
local sink = {} local sink = {}
request['url'] = url socketutil:set_timeout(timeout, maxtime or 30)
request['method'] = 'GET' local request = {
request['headers'] = { url = url,
["User-Agent"] = USER_AGENT, method = "GET",
sink = maxtime and socketutil.table_sink(sink) or ltn12.sink.table(sink),
} }
-- 'timeout' delay works on socket, and is triggered when
-- that time has passed trying to connect, or after connection
-- when no data has been read for this time.
-- On a slow connection, it may not be triggered (as we could read
-- 1 byte every 1 second, not triggering any timeout).
-- 'maxtime' can be provided to overcome that, and we start counting
-- as soon as the first content byte is received (but it is checked
-- for only when data is received).
-- Setting 'maxtime' and 'timeout' gives more chance to abort the request when
-- it takes too much time (in the worst case: in timeout+maxtime seconds).
-- But time taken by DNS lookup cannot easily be accounted for, so
-- a request may (when dns lookup takes time) exceed timeout and maxtime...
if maxtime then
request['sink'] = sink_table_with_maxtime(sink, maxtime)
else
request['sink'] = ltn12.sink.table(sink)
end
local code, headers, status = socket.skip(1, requester.request(request)) local code, headers, status = socket.skip(1, http.request(request))
socketutil:reset_timeout()
local content = table.concat(sink) -- empty or content accumulated till now local content = table.concat(sink) -- empty or content accumulated till now
-- logger.dbg("code:", code) -- logger.dbg("code:", code)
-- logger.dbg("headers:", headers) -- logger.dbg("headers:", headers)
-- logger.dbg("status:", status) -- logger.dbg("status:", status)
-- logger.dbg("#content:", #content) -- logger.dbg("#content:", #content)
if code == TIMEOUT_CODE or code == MAXTIME_CODE then if code == socketutil.TIMEOUT_CODE or
code == socketutil.SSL_HANDSHAKE_CODE or
code == socketutil.SINK_TIMEOUT_CODE
then
logger.warn("request interrupted:", code) logger.warn("request interrupted:", code)
return false, code return false, code
end end
@ -212,7 +168,7 @@ local WIKIPEDIA_IMAGES = 4
-- return decoded JSON table from Wikipedia -- return decoded JSON table from Wikipedia
--]] --]]
function Wikipedia:loadPage(text, lang, page_type, plain) function Wikipedia:loadPage(text, lang, page_type, plain)
local url = require('socket.url') local url = require("socket.url")
local query = "" local query = ""
local parsed = url.parse(self:getWikiServer(lang)) local parsed = url.parse(self:getWikiServer(lang))
parsed.path = self.wiki_path parsed.path = self.wiki_path

@ -1,6 +1,7 @@
local json = require("json")
local http = require("socket.http") local http = require("socket.http")
local json = require("json")
local ltn12 = require("ltn12") local ltn12 = require("ltn12")
local socketutil = require("socketutil")
local JoplinClient = { local JoplinClient = {
server_ip = "localhost", server_ip = "localhost",
@ -19,16 +20,18 @@ function JoplinClient:_makeRequest(url, method, request_body)
local sink = {} local sink = {}
local request_body_json = json.encode(request_body) local request_body_json = json.encode(request_body)
local source = ltn12.source.string(request_body_json) local source = ltn12.source.string(request_body_json)
socketutil:set_timeout(socketutil.LARGE_BLOCK_TIMEOUT, socketutil.LARGE_TOTAL_TIMEOUT)
http.request{ http.request{
url = url, url = url,
method = method, method = method,
sink = ltn12.sink.table(sink), sink = ltn12.sink.table(sink),
source = source, source = source,
headers = { headers = {
["Content-Length"] = #request_body_json, ["Content-Length"] = #request_body_json,
["Content-Type"] = "application/json" ["Content-Type"] = "application/json"
} },
} }
socketutil:reset_timeout()
if not sink[1] then if not sink[1] then
error("No response from Joplin Server") error("No response from Joplin Server")

@ -2,10 +2,10 @@ local InputContainer = require("ui/widget/container/inputcontainer")
local GoodreadsBook = require("goodreadsbook") local GoodreadsBook = require("goodreadsbook")
local InfoMessage = require("ui/widget/infomessage") local InfoMessage = require("ui/widget/infomessage")
local UIManager = require("ui/uimanager") local UIManager = require("ui/uimanager")
local url = require('socket.url') local http = require("socket.http")
local socket = require('socket') local ltn12 = require("ltn12")
local https = require('ssl.https') local socket = require("socket")
local ltn12 = require('ltn12') local socketutil = require("socketutil")
local _ = require("gettext") local _ = require("gettext")
local GoodreadsApi = InputContainer:new { local GoodreadsApi = InputContainer:new {
@ -42,14 +42,15 @@ local function genIdUrl(id, userApi)
end end
function GoodreadsApi:fetchXml(s_url) function GoodreadsApi:fetchXml(s_url)
local request, sink = {}, {} local sink = {}
local parsed = url.parse(s_url) socketutil:set_timeout()
request['url'] = s_url local request = {
request['method'] = 'GET' url = s_url,
request['sink'] = ltn12.sink.table(sink) method = "GET",
https.TIMEOUT = 5 sink = ltn12.sink.table(sink),
local httpsRequest = parsed.scheme == 'https' and https.request }
local headers = socket.skip(1, httpsRequest(request)) local headers = socket.skip(2, http.request(request))
socketutil:reset_timeout()
if headers == nil then if headers == nil then
return nil return nil
end end

@ -19,7 +19,7 @@ local TextWidget = require("ui/widget/textwidget")
local UIManager = require("ui/uimanager") local UIManager = require("ui/uimanager")
local VerticalGroup = require("ui/widget/verticalgroup") local VerticalGroup = require("ui/widget/verticalgroup")
local VerticalSpan = require("ui/widget/verticalspan") local VerticalSpan = require("ui/widget/verticalspan")
local https = require('ssl.https') local https = require("ssl.https")
local _ = require("gettext") local _ = require("gettext")
local Screen = require("device").screen local Screen = require("device").screen
local T = require("ffi/util").template local T = require("ffi/util").template

@ -1,11 +1,11 @@
local Version = require("version") local Version = require("version")
local ffiutil = require("ffi/util") local ffiutil = require("ffi/util")
local http = require("socket.http") local http = require("socket.http")
local https = require("ssl.https")
local logger = require("logger") local logger = require("logger")
local ltn12 = require("ltn12") local ltn12 = require("ltn12")
local socket = require("socket") local socket = require("socket")
local socket_url = require("socket.url") local socket_url = require("socket.url")
local socketutil = require("socketutil")
local _ = require("gettext") local _ = require("gettext")
local T = ffiutil.template local T = ffiutil.template
@ -20,10 +20,6 @@ local EpubDownloadBackend = {
} }
local max_redirects = 5; --prevent infinite redirects local max_redirects = 5; --prevent infinite redirects
-- Codes that getUrlContent may get from requester.request()
local TIMEOUT_CODE = "timeout" -- from socket.lua
local MAXTIME_CODE = "maxtime reached" -- from sink_table_with_maxtime
-- filter HTML using CSS selector -- filter HTML using CSS selector
local function filter(text, element) local function filter(text, element)
local htmlparser = require("htmlparser") local htmlparser = require("htmlparser")
@ -72,23 +68,6 @@ local function filter(text, element)
return "<!DOCTYPE html><html><head></head><body>" .. filtered .. "</body></html>" return "<!DOCTYPE html><html><head></head><body>" .. filtered .. "</body></html>"
end end
-- Sink that stores into a table, aborting if maxtime has elapsed
local function sink_table_with_maxtime(t, maxtime)
-- Start counting as soon as this sink is created
local start_secs, start_usecs = ffiutil.gettime()
local starttime = start_secs + start_usecs/1000000
t = t or {}
local f = function(chunk, err)
local secs, usecs = ffiutil.gettime()
if secs + usecs/1000000 - starttime > maxtime then
return nil, MAXTIME_CODE
end
if chunk then table.insert(t, chunk) end
return 1
end
return f, t
end
-- Get URL content -- Get URL content
local function getUrlContent(url, timeout, maxtime, redirectCount) local function getUrlContent(url, timeout, maxtime, redirectCount)
logger.dbg("getUrlContent(", url, ",", timeout, ",", maxtime, ",", redirectCount, ")") logger.dbg("getUrlContent(", url, ",", timeout, ",", maxtime, ",", redirectCount, ")")
@ -100,35 +79,19 @@ local function getUrlContent(url, timeout, maxtime, redirectCount)
if not timeout then timeout = 10 end if not timeout then timeout = 10 end
logger.dbg("timeout:", timeout) logger.dbg("timeout:", timeout)
-- timeout needs to be set to "http", even if we use "https"
--http.TIMEOUT, https.TIMEOUT = timeout, timeout
-- "timeout" delay works on socket, and is triggered when local sink = {}
-- that time has passed trying to connect, or after connection
-- when no data has been read for this time.
-- On a slow connection, it may not be triggered (as we could read
-- 1 byte every 1 second, not triggering any timeout).
-- "maxtime" can be provided to overcome that, and we start counting
-- as soon as the first content byte is received (but it is checked
-- for only when data is received).
-- Setting "maxtime" and "timeout" gives more chance to abort the request when
-- it takes too much time (in the worst case: in timeout+maxtime seconds).
-- But time taken by DNS lookup cannot easily be accounted for, so
-- a request may (when dns lookup takes time) exceed timeout and maxtime...
local request, sink = {}, {}
if maxtime then
request.sink = sink_table_with_maxtime(sink, maxtime)
else
request.sink = ltn12.sink.table(sink)
end
request.url = url
request.method = "GET"
local parsed = socket_url.parse(url) local parsed = socket_url.parse(url)
socketutil:set_timeout(timeout, maxtime or 30)
local httpRequest = parsed.scheme == "http" and http.request or https.request local request = {
url = url,
method = "GET",
sink = maxtime and socketutil.table_sink(sink) or ltn12.sink.table(sink),
}
logger.dbg("request:", request) logger.dbg("request:", request)
local code, headers, status = socket.skip(1, httpRequest(request)) local code, headers, status = socket.skip(1, http.request(request))
logger.dbg("After httpRequest") socketutil:reset_timeout()
logger.dbg("After http.request")
local content = table.concat(sink) -- empty or content accumulated till now local content = table.concat(sink) -- empty or content accumulated till now
logger.dbg("type(code):", type(code)) logger.dbg("type(code):", type(code))
logger.dbg("code:", code) logger.dbg("code:", code)
@ -136,7 +99,10 @@ local function getUrlContent(url, timeout, maxtime, redirectCount)
logger.dbg("status:", status) logger.dbg("status:", status)
logger.dbg("#content:", #content) logger.dbg("#content:", #content)
if code == TIMEOUT_CODE or code == MAXTIME_CODE then if code == socketutil.TIMEOUT_CODE or
code == socketutil.SSL_HANDSHAKE_CODE or
code == socketutil.SINK_TIMEOUT_CODE
then
logger.warn("request interrupted:", code) logger.warn("request interrupted:", code)
return false, code return false, code
end end

@ -1,9 +1,8 @@
local http = require("socket.http") local http = require("socket.http")
local https = require("ssl.https")
local logger = require("logger") local logger = require("logger")
local ltn12 = require("ltn12") local ltn12 = require("ltn12")
local socket = require('socket') local socket = require("socket")
local socket_url = require("socket.url") local socketutil = require("socketutil")
local InternalDownloadBackend = {} local InternalDownloadBackend = {}
local max_redirects = 5; --prevent infinite redirects local max_redirects = 5; --prevent infinite redirects
@ -15,13 +14,14 @@ function InternalDownloadBackend:getResponseAsString(url, redirectCount)
error("InternalDownloadBackend: reached max redirects: ", redirectCount) error("InternalDownloadBackend: reached max redirects: ", redirectCount)
end end
logger.dbg("InternalDownloadBackend: url :", url) logger.dbg("InternalDownloadBackend: url :", url)
local request, sink = {}, {} local sink = {}
request['sink'] = ltn12.sink.table(sink) socketutil:set_timeout(socketutil.LARGE_BLOCK_TIMEOUT, socketutil.LARGE_TOTAL_TIMEOUT)
request['url'] = url local request = {
local parsed = socket_url.parse(url) url = url,
sink = ltn12.sink.table(sink),
local httpRequest = parsed.scheme == 'http' and http.request or https.request; }
local code, headers, status = socket.skip(1, httpRequest(request)) local code, headers, status = socket.skip(1, http.request(request))
socketutil:reset_timeout()
if code ~= 200 then if code ~= 200 then
logger.dbg("InternalDownloadBackend: HTTP response code <> 200. Response status: ", status) logger.dbg("InternalDownloadBackend: HTTP response code <> 200. Response status: ", status)

@ -13,12 +13,13 @@ local NetworkMgr = require("ui/network/manager")
local OPDSParser = require("opdsparser") local OPDSParser = require("opdsparser")
local Screen = require("device").screen local Screen = require("device").screen
local UIManager = require("ui/uimanager") local UIManager = require("ui/uimanager")
local http = require('socket.http') local http = require("socket.http")
local lfs = require("libs/libkoreader-lfs") local lfs = require("libs/libkoreader-lfs")
local logger = require("logger") local logger = require("logger")
local ltn12 = require('ltn12') local ltn12 = require("ltn12")
local socket = require('socket') local socket = require("socket")
local url = require('socket.url') local socketutil = require("socketutil")
local url = require("socket.url")
local util = require("util") local util = require("util")
local _ = require("gettext") local _ = require("gettext")
local T = require("ffi/util").template local T = require("ffi/util").template
@ -268,19 +269,21 @@ end
function OPDSBrowser:fetchFeed(item_url, username, password, method) function OPDSBrowser:fetchFeed(item_url, username, password, method)
local sink = {} local sink = {}
socketutil:set_timeout(socketutil.LARGE_BLOCK_TIMEOUT, socketutil.LARGE_TOTAL_TIMEOUT)
local request = { local request = {
url = item_url, url = item_url,
method = method and method or "GET", method = method and method or "GET",
-- Explicitly specify that we don't support compressed content. Some servers will still break RFC2616 14.3 and send crap instead. -- Explicitly specify that we don't support compressed content. Some servers will still break RFC2616 14.3 and send crap instead.
headers = { ["Accept-Encoding"] = "identity", }, headers = {
["Accept-Encoding"] = "identity",
},
sink = ltn12.sink.table(sink), sink = ltn12.sink.table(sink),
username = username, username = username,
password = password password = password,
} }
logger.info("Request:", request) logger.info("Request:", request)
http.TIMEOUT = 10 local code, headers = socket.skip(1, http.request(request))
local httpRequest = http.request socketutil:reset_timeout()
local code, headers = socket.skip(1, httpRequest(request))
-- raise error message when network is unavailable -- raise error message when network is unavailable
if headers == nil then if headers == nil then
error(code) error(code)
@ -560,26 +563,20 @@ function OPDSBrowser:downloadFile(item, filetype, remote_url)
UIManager:scheduleIn(1, function() UIManager:scheduleIn(1, function()
logger.dbg("Downloading file", local_path, "from", remote_url) logger.dbg("Downloading file", local_path, "from", remote_url)
local parsed = url.parse(remote_url) local parsed = url.parse(remote_url)
http.TIMEOUT = 20
local dummy, code, headers local code, headers
if parsed.scheme == "http" or parsed.scheme == "https" then
if parsed.scheme == "http" then socketutil:set_timeout(socketutil.FILE_BLOCK_TIMEOUT, socketutil.FILE_TOTAL_TIMEOUT)
dummy, code, headers = http.request { code, headers = socket.skip(1, http.request {
url = remote_url,
headers = { ["Accept-Encoding"] = "identity", },
sink = ltn12.sink.file(io.open(local_path, "w")),
user = item.username,
password = item.password
}
elseif parsed.scheme == "https" then
dummy, code, headers = http.request {
url = remote_url, url = remote_url,
headers = { ["Accept-Encoding"] = "identity", }, headers = {
["Accept-Encoding"] = "identity",
},
sink = ltn12.sink.file(io.open(local_path, "w")), sink = ltn12.sink.file(io.open(local_path, "w")),
user = item.username, user = item.username,
password = item.password password = item.password,
} })
socketutil:reset_timeout()
else else
UIManager:show(InfoMessage:new { UIManager:show(InfoMessage:new {
text = T(_("Invalid protocol:\n%1"), parsed.scheme), text = T(_("Invalid protocol:\n%1"), parsed.scheme),

@ -26,6 +26,7 @@ local http = require("socket.http")
local logger = require("logger") local logger = require("logger")
local ltn12 = require("ltn12") local ltn12 = require("ltn12")
local socket = require("socket") local socket = require("socket")
local socketutil = require("socketutil")
local util = require("util") local util = require("util")
local _ = require("gettext") local _ = require("gettext")
local T = FFIUtil.template local T = FFIUtil.template
@ -379,7 +380,7 @@ function Wallabag:getBearerToken()
["Content-type"] = "application/json", ["Content-type"] = "application/json",
["Accept"] = "application/json, */*", ["Accept"] = "application/json, */*",
["Content-Length"] = tostring(#bodyJSON), ["Content-Length"] = tostring(#bodyJSON),
} }
local result = self:callAPI("POST", login_url, headers, bodyJSON, "") local result = self:callAPI("POST", login_url, headers, bodyJSON, "")
if result then if result then
@ -539,21 +540,25 @@ end
-- filepath: downloads the file if provided, returns JSON otherwise -- filepath: downloads the file if provided, returns JSON otherwise
---- @todo separate call to internal API from the download on external server ---- @todo separate call to internal API from the download on external server
function Wallabag:callAPI(method, apiurl, headers, body, filepath, quiet) function Wallabag:callAPI(method, apiurl, headers, body, filepath, quiet)
local request, sink = {}, {} local sink = {}
local request = {}
-- Is it an API call, or a regular file direct download? -- Is it an API call, or a regular file direct download?
if apiurl:sub(1, 1) == "/" then if apiurl:sub(1, 1) == "/" then
-- API call to our server, has the form "/random/api/call" -- API call to our server, has the form "/random/api/call"
request.url = self.server_url .. apiurl request.url = self.server_url .. apiurl
if headers == nil then if headers == nil then
headers = { ["Authorization"] = "Bearer " .. self.access_token, } headers = {
["Authorization"] = "Bearer " .. self.access_token,
}
end end
else else
-- regular url link to a foreign server -- regular url link to a foreign server
local file_url = apiurl local file_url = apiurl
request.url = file_url request.url = file_url
if headers == nil then if headers == nil then
headers = {} -- no need for a token here -- no need for a token here
headers = {}
end end
end end
@ -570,9 +575,9 @@ function Wallabag:callAPI(method, apiurl, headers, body, filepath, quiet)
logger.dbg("Wallabag: URL ", request.url) logger.dbg("Wallabag: URL ", request.url)
logger.dbg("Wallabag: method ", method) logger.dbg("Wallabag: method ", method)
http.TIMEOUT = 30 socketutil:set_timeout(socketutil.LARGE_BLOCK_TIMEOUT, socketutil.LARGE_TOTAL_TIMEOUT)
local httpRequest = http.request local code, resp_headers = socket.skip(1, http.request(request))
local code, resp_headers = socket.skip(1, httpRequest(request)) socketutil:reset_timeout()
-- raise error message when network is unavailable -- raise error message when network is unavailable
if resp_headers == nil then if resp_headers == nil then
logger.dbg("Wallabag: Server error: ", code) logger.dbg("Wallabag: Server error: ", code)

Loading…
Cancel
Save