Unify LuaSocket usage (#7405)

* Add a new socketutil module with a few helper functions that allow us to:
  * Always use a sane User-Agent (previously, only Wikipedia did so)
  * Set timeouts in an almost sane manner. Doing it explicitly prevents an interaction with KOSync that does crazy stuff I don't even want to try to understand.
* Unified said timeouts based on the request's intended usage (except for Wikipedia, which already had meaningful timeout values).
* Stopped using LuaSec directly, LuaSocket defers to LuaSec sanely on its own. Everything now transparently supports HTTPS without code duplication.
pull/7412/head
NiLuJe 3 years ago committed by GitHub
parent 89c0578c8d
commit 2f9db25969
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,10 +1,9 @@
local DocumentRegistry = require("document/documentregistry")
local JSON = require("json")
local http = require('socket.http')
local https = require('ssl.https')
local ltn12 = require('ltn12')
local socket = require('socket')
local url = require('socket.url')
local http = require("socket.http")
local ltn12 = require("ltn12")
local socket = require("socket")
local socketutil = require("socketutil")
local _ = require("gettext")
local DropBoxApi = {
@ -15,17 +14,18 @@ local API_LIST_FOLDER = "https://api.dropboxapi.com/2/files/list_folder"
local API_DOWNLOAD_FILE = "https://content.dropboxapi.com/2/files/download"
function DropBoxApi:fetchInfo(token)
local request, sink = {}, {}
local parsed = url.parse(API_URL_INFO)
request['url'] = API_URL_INFO
request['method'] = 'POST'
local headers = { ["Authorization"] = "Bearer ".. token }
request['headers'] = headers
request['sink'] = ltn12.sink.table(sink)
http.TIMEOUT = 5
https.TIMEOUT = 5
local httpRequest = parsed.scheme == 'http' and http.request or https.request
local headers_request = socket.skip(1, httpRequest(request))
local sink = {}
socketutil:set_timeout()
local request = {
url = API_URL_INFO,
method = "POST",
headers = {
["Authorization"] = "Bearer " .. token,
},
sink = ltn12.sink.table(sink),
}
local headers_request = socket.skip(1, http.request(request))
socketutil:reset_timeout()
local result_response = table.concat(sink)
if headers_request == nil then
return nil
@ -39,23 +39,24 @@ function DropBoxApi:fetchInfo(token)
end
function DropBoxApi:fetchListFolders(path, token)
local request, sink = {}, {}
if path == nil or path == "/" then path = "" end
local parsed = url.parse(API_LIST_FOLDER)
request['url'] = API_LIST_FOLDER
request['method'] = 'POST'
local data = "{\"path\": \"" .. path .. "\",\"recursive\": false,\"include_media_info\": false,"..
"\"include_deleted\": false,\"include_has_explicit_shared_members\": false}"
local headers = { ["Authorization"] = "Bearer ".. token,
["Content-Type"] = "application/json" ,
["Content-Length"] = #data}
request['headers'] = headers
request['source'] = ltn12.source.string(data)
request['sink'] = ltn12.sink.table(sink)
http.TIMEOUT = 5
https.TIMEOUT = 5
local httpRequest = parsed.scheme == 'http' and http.request or https.request
local headers_request = socket.skip(1, httpRequest(request))
local sink = {}
socketutil:set_timeout()
local request = {
url = API_LIST_FOLDER,
method = "POST",
headers = {
["Authorization"] = "Bearer ".. token,
["Content-Type"] = "application/json",
["Content-Length"] = #data,
},
source = ltn12.source.string(data),
sink = ltn12.sink.table(sink),
}
local headers_request = socket.skip(1, http.request(request))
socketutil:reset_timeout()
if headers_request == nil then
return nil
end
@ -73,20 +74,18 @@ function DropBoxApi:fetchListFolders(path, token)
end
function DropBoxApi:downloadFile(path, token, local_path)
local parsed = url.parse(API_DOWNLOAD_FILE)
local url_api = API_DOWNLOAD_FILE
local data1 = "{\"path\": \"" .. path .. "\"}"
local headers = { ["Authorization"] = "Bearer ".. token,
["Dropbox-API-Arg"] = data1}
http.TIMEOUT = 5
https.TIMEOUT = 5
local httpRequest = parsed.scheme == 'http' and http.request or https.request
local _, code_return, _ = httpRequest{
url = url_api,
method = 'GET',
headers = headers,
sink = ltn12.sink.file(io.open(local_path, "w"))
}
socketutil:set_timeout(socketutil.FILE_BLOCK_TIMEOUT, socketutil.FILE_TOTAL_TIMEOUT)
local code_return = socket.skip(1, http.request{
url = API_DOWNLOAD_FILE,
method = "GET",
headers = {
["Authorization"] = "Bearer ".. token,
["Dropbox-API-Arg"] = data1,
},
sink = ltn12.sink.file(io.open(local_path, "w")),
})
socketutil:reset_timeout()
return code_return
end

@ -1,11 +1,9 @@
local DocumentRegistry = require("document/documentregistry")
local FFIUtil = require("ffi/util")
local http = require('socket.http')
local https = require('ssl.https')
local ltn12 = require('ltn12')
local mime = require('mime')
local socket = require('socket')
local url = require('socket.url')
local http = require("socket.http")
local ltn12 = require("ltn12")
local socket = require("socket")
local socketutil = require("socketutil")
local util = require("util")
local _ = require("gettext")
@ -77,23 +75,24 @@ function WebDavApi:listFolder(address, user, pass, folder_path)
webdav_url = webdav_url .. "/"
end
local request, sink = {}, {}
local parsed = url.parse(webdav_url)
local sink = {}
local data = [[<?xml version="1.0"?><a:propfind xmlns:a="DAV:"><a:prop><a:resourcetype/></a:prop></a:propfind>]]
local auth = string.format("%s:%s", user, pass)
local headers = { ["Authorization"] = "Basic " .. mime.b64( auth ),
["Content-Type"] = "application/xml",
["Depth"] = "1",
["Content-Length"] = #data}
request["url"] = webdav_url
request["method"] = "PROPFIND"
request["headers"] = headers
request["source"] = ltn12.source.string(data)
request["sink"] = ltn12.sink.table(sink)
http.TIMEOUT = 5
https.TIMEOUT = 5
local httpRequest = parsed.scheme == "http" and http.request or https.request
local headers_request = socket.skip(1, httpRequest(request))
socketutil:set_timeout()
local request = {
url = webdav_url,
method = "PROPFIND",
headers = {
["Content-Type"] = "application/xml",
["Depth"] = "1",
["Content-Length"] = #data,
},
username = user,
password = pass,
source = ltn12.source.string(data),
sink = ltn12.sink.table(sink),
}
local headers_request = socket.skip(1, http.request(request))
socketutil:reset_timeout()
if headers_request == nil then
return nil
end
@ -152,18 +151,15 @@ function WebDavApi:listFolder(address, user, pass, folder_path)
end
function WebDavApi:downloadFile(file_url, user, pass, local_path)
local parsed = url.parse(file_url)
local auth = string.format("%s:%s", user, pass)
local headers = { ["Authorization"] = "Basic " .. mime.b64( auth ) }
http.TIMEOUT = 5
https.TIMEOUT = 5
local httpRequest = parsed.scheme == "http" and http.request or https.request
local _, code_return, _ = httpRequest{
url = file_url,
method = "GET",
headers = headers,
sink = ltn12.sink.file(io.open(local_path, "w"))
}
socketutil:set_timeout(socketutil.FILE_BLOCK_TIMEOUT, socketutil.FILE_TOTAL_TIMEOUT)
local code_return = socket.skip(1, http.request{
url = file_url,
method = "GET",
sink = ltn12.sink.file(io.open(local_path, "w")),
username = user,
password = pass,
})
socketutil:reset_timeout()
return code_return
end

@ -965,25 +965,21 @@ end
function ReaderDictionary:downloadDictionary(dict, download_location, continue)
continue = continue or false
local socket = require("socket")
local socketutil = require("socketutil")
local http = socket.http
local https = require("ssl.https")
local ltn12 = require("ltn12")
local url = socket.url
local parsed = url.parse(dict.url)
local httpRequest = parsed.scheme == "http" and http.request or https.request
if not continue then
local file_size
--local r, c, h = httpRequest {
local dummy, headers, dummy = socket.skip(1, httpRequest{
method = "HEAD",
url = dict.url,
-- Skip body & code args
socketutil:set_timeout()
local headers = socket.skip(2, http.request{
method = "HEAD",
url = dict.url,
--redirect = true,
})
--logger.dbg(status)
socketutil:reset_timeout()
--logger.dbg(headers)
--logger.dbg(code)
file_size = headers and headers["content-length"]
UIManager:show(ConfirmBox:new{
@ -1004,10 +1000,12 @@ function ReaderDictionary:downloadDictionary(dict, download_location, continue)
end)
end
local dummy, c, dummy = httpRequest{
url = dict.url,
sink = ltn12.sink.file(io.open(download_location, "w")),
}
socketutil:set_timeout(socketutil.FILE_BLOCK_TIMEOUT, socketutil.FILE_TOTAL_TIMEOUT)
local c = socket.skip(1, http.request{
url = dict.url,
sink = ltn12.sink.file(io.open(download_location, "w")),
})
socketutil:reset_timeout()
if c == 200 then
logger.dbg("file downloaded to", download_location)
else

@ -0,0 +1,141 @@
--[[--
This module contains miscellaneous helper functions specific to our usage of LuaSocket/LuaSec.
]]
local Version = require("version")
local http = require("socket.http")
local https = require("ssl.https")
local ltn12 = require("ltn12")
local socket = require("socket")
local socketutil = {
-- Init to the default LuaSocket/LuaSec values
block_timeout = 60,
total_timeout = -1,
}
--- Builds a sensible UserAgent that fits Wikipedia's UA policy <https://meta.wikimedia.org/wiki/User-Agent_policy>
local socket_ua = http.USERAGENT
socketutil.USER_AGENT = "KOReader/" .. Version:getShortVersion() .. " (https://koreader.rocks/) " .. socket_ua:gsub(" ", "/")
-- Monkey-patch it in LuaSocket, as it already takes care of inserting the appropriate header to its requests.
http.USERAGENT = socketutil.USER_AGENT
--- Common timeout values
-- Large content
socketutil.LARGE_BLOCK_TIMEOUT = 10
socketutil.LARGE_TOTAL_TIMEOUT = 30
-- File downloads
socketutil.FILE_BLOCK_TIMEOUT = 15
socketutil.FILE_TOTAL_TIMEOUT = 60
-- Upstream defaults
socketutil.DEFAULT_BLOCK_TIMEOUT = 60
socketutil.DEFAULT_TOTAL_TIMEOUT = -1
--- Update the timeout values.
-- Note that this only affects socket polling,
-- c.f., LuaSocket's timeout_getretry @ src/timeout.c & usage in src/usocket.c
-- Moreover, the timeout is actually *reset* between polls (via timeout_markstart, e.g. in buffer_meth_receive).
-- So, in practice, this timeout only helps *very* bad connections (on one end or the other),
-- and you'd be hard-pressed to ever hit the *total* timeout, since the starting point is reset extremely often.
-- In our case, we want to enforce an *actual* limit on how much time we're willing to block for, start to finish.
-- We do that via the custom sinks below, which will start ticking as soon as the first chunk of data is received.
-- To simplify, in most cases, the socket timeout matters *before* we receive data,
-- and the sink timeout *once* we've started receiving data (at which point the socket timeout is reset every chunk).
-- In practice, that means you don't want to set block_timeout too low,
-- as that's what the socket timeout will end up using most of the time.
-- Note that name resolution happens earlier and one level lower (e.g., glibc),
-- so name resolution delays will fall outside of these timeouts.
function socketutil:set_timeout(block_timeout, total_timeout)
self.block_timeout = block_timeout or 5
self.total_timeout = total_timeout or 15
-- Also update the actual LuaSocket & LuaSec constants, because:
-- 1. LuaSocket's `open` does a `settimeout` *after* create with this constant
-- 2. KOSync updates it to a stupidly low value
http.TIMEOUT = self.block_timeout
https.TIMEOUT = self.block_timeout
end
--- Reset timeout values to LuaSocket defaults.
function socketutil:reset_timeout()
self.block_timeout = self.DEFAULT_BLOCK_TIMEOUT
self.total_timeout = self.DEFAULT_TOTAL_TIMEOUT
http.TIMEOUT = self.block_timeout
https.TIMEOUT = self.block_timeout
end
--- Monkey-patch LuaSocket's `socket.tcp` in order to honor tighter timeouts, to avoid blocking the UI for too long.
-- NOTE: While we could use a custom `create` function for HTTP LuaSocket `request`s,
-- with HTTPS, the way LuaSocket/LuaSec handles those is much more finicky,
-- because LuaSocket's adjustrequest function (in http.lua) passes the adjusted nreqt table to it,
-- but only when it does the automagic scheme handling, not when it's set by the caller :/.
-- And LuaSec's own `request` function overload *forbids* setting create, because of similar shenanigans...
-- TL;DR: Just monkey-patching socket.tcp directly will affect both HTTP & HTTPS
-- without us having to maintain a tweaked version of LuaSec's `https.tcp` function...
local real_socket_tcp = socket.tcp
function socketutil.tcp()
-- Based on https://stackoverflow.com/a/6021774
local req_sock = real_socket_tcp()
req_sock:settimeout(socketutil.block_timeout, "b")
req_sock:settimeout(socketutil.total_timeout, "t")
return req_sock
end
socket.tcp = socketutil.tcp
--- Various timeout return codes
socketutil.TIMEOUT_CODE = "timeout" -- from LuaSocket's io.c
socketutil.SSL_HANDSHAKE_CODE = "wantread" -- from LuaSec's ssl.c
socketutil.SINK_TIMEOUT_CODE = "sink timeout" -- from our own socketutil
-- NOTE: Use os.time() for simplicity's sake (we don't really need subsecond precision).
-- LuaSocket itself is already using gettimeofday anyway (although it does the maths, like ffi/util's getTimestamp).
-- Proper etiquette would have everyone using clock_gettime(CLOCK_MONOTONIC) for this kind of stuff,
-- but it's a tad more annoying to use because it's stuffed in librt in old glibc versions,
-- and I have no idea what macOS & Android do with it (but it is POSIX). Plus, win32.
--- Custom version of `ltn12.sink.table` that honors total_timeout
function socketutil.table_sink(t)
if socketutil.total_timeout < 0 then
return ltn12.sink.table(t)
end
local start_ts = os.time()
t = t or {}
local f = function(chunk, err)
if chunk then
if os.time() - start_ts > socketutil.total_timeout then
return nil, socketutil.SINK_TIMEOUT_CODE
end
table.insert(t, chunk)
end
return 1
end
return f, t
end
--- Custom version of `ltn12.sink.file` that honors total_timeout
function socketutil.file_sink(handle, io_err)
if socketutil.total_timeout < 0 then
return ltn12.sink.file(handle, io_err)
end
if handle then
local start_ts = os.time()
return function(chunk, err)
if not chunk then
handle:close()
return 1
else
if os.time() - start_ts > socketutil.total_timeout then
handle:close()
return nil, socketutil.SINK_TIMEOUT_CODE
end
return handle:write(chunk)
end
end
else
return nil, io_err or "unable to open file"
end
end
return socketutil

@ -166,6 +166,8 @@ end
function OTAManager:checkUpdate()
local http = require("socket.http")
local ltn12 = require("ltn12")
local socket = require("socket")
local socketutil = require("socketutil")
local update_file = (self:getOTAType() == "link") and self:getLinkFilename() or self:getZsyncFilename()
@ -173,11 +175,14 @@ function OTAManager:checkUpdate()
local local_update_file = ota_dir .. update_file
-- download zsync file from OTA server
logger.dbg("downloading update file", ota_update_file)
local _, c, _ = http.request{
url = ota_update_file,
sink = ltn12.sink.file(io.open(local_update_file, "w"))}
if c ~= 200 then
logger.warn("cannot find update file", c)
socketutil:set_timeout()
local code, _, status = socket.skip(1, http.request{
url = ota_update_file,
sink = ltn12.sink.file(io.open(local_update_file, "w")),
})
socketutil:reset_timeout()
if code ~= 200 then
logger.warn("cannot find update file:", status or code or "network unreachable")
return
end
-- parse OTA package version

@ -284,13 +284,12 @@ Returns decoded JSON table from translate server.
@treturn string result, or nil
--]]
function Translator:loadPage(text, target_lang, source_lang)
local socket = require('socket')
local url = require('socket.url')
local http = require('socket.http')
local https = require('ssl.https')
local ltn12 = require('ltn12')
local socket = require("socket")
local socketutil = require("socketutil")
local url = require("socket.url")
local http = require("socket.http")
local ltn12 = require("ltn12")
local request, sink = {}, {}
local query = ""
self.trans_params.tl = target_lang
self.trans_params.sl = source_lang
@ -308,25 +307,24 @@ function Translator:loadPage(text, target_lang, source_lang)
parsed.query = query .. "q=" .. url.escape(text)
-- HTTP request
request['url'] = url.build(parsed)
local sink = {}
socketutil:set_timeout()
local request = {
url = url.build(parsed),
method = "GET",
sink = ltn12.sink.table(sink),
}
logger.dbg("Calling", request.url)
request['method'] = 'GET'
request['sink'] = ltn12.sink.table(sink)
-- We may try to set a common User-Agent if it happens we're 403 Forbidden
-- request['headers'] = {
-- ["User-Agent"] = "Mozilla/5.0 (Windows NT 10.0; Win64; x64)",
-- }
http.TIMEOUT, https.TIMEOUT = 10, 10
local httpRequest = parsed.scheme == 'http' and http.request or https.request
-- first argument returned by skip is code
local _, headers, status = socket.skip(1, httpRequest(request))
-- Skip first argument (body, goes to the sink)
local code, headers, status = socket.skip(1, http.request(request))
socketutil:reset_timeout()
-- raise error message when network is unavailable
if headers == nil then
error("Network is unreachable")
end
if status ~= "HTTP/1.1 200 OK" then
if code ~= 200 then
logger.warn("translator HTTP status not okay:", status)
return
end

@ -95,84 +95,40 @@ function Wikipedia:getWikiServer(lang)
return string.format(self.wiki_server, lang or self.default_lang)
end
-- Say who we are to Wikipedia (see https://meta.wikimedia.org/wiki/User-Agent_policy)
local USER_AGENT = T("KOReader/%1 (https://koreader.rocks/) %2",
(lfs.attributes("git-rev", "mode") == "file" and io.open("git-rev", "r"):read("*line") or "devel"),
require('socket.http').USERAGENT:gsub(" ", "/") )
-- Codes that getUrlContent may get from requester.request()
local TIMEOUT_CODE = "timeout" -- from socket.lua
local MAXTIME_CODE = "maxtime reached" -- from sink_table_with_maxtime
-- Sink that stores into a table, aborting if maxtime has elapsed
local function sink_table_with_maxtime(t, maxtime)
-- Start counting as soon as this sink is created
local start_secs, start_usecs = ffiutil.gettime()
local starttime = start_secs + start_usecs/1000000
t = t or {}
local f = function(chunk, err)
local secs, usecs = ffiutil.gettime()
if secs + usecs/1000000 - starttime > maxtime then
return nil, MAXTIME_CODE
end
if chunk then table.insert(t, chunk) end
return 1
end
return f, t
end
-- Get URL content
local function getUrlContent(url, timeout, maxtime)
local socket = require('socket')
local ltn12 = require('ltn12')
local http = require('socket.http')
local https = require('ssl.https')
local requester
if url:sub(1,7) == "http://" then
requester = http
elseif url:sub(1,8) == "https://" then
requester = https
else
local http = require("socket.http")
local ltn12 = require("ltn12")
local socket = require("socket")
local socketutil = require("socketutil")
local socket_url = require("socket.url")
local parsed = socket_url.parse(url)
if parsed.scheme ~= "http" and parsed.scheme ~= "https" then
return false, "Unsupported protocol"
end
if not timeout then timeout = 10 end
-- timeout needs to be set to 'http', even if we use 'https'
http.TIMEOUT, https.TIMEOUT = timeout, timeout
local request = {}
local sink = {}
request['url'] = url
request['method'] = 'GET'
request['headers'] = {
["User-Agent"] = USER_AGENT,
socketutil:set_timeout(timeout, maxtime or 30)
local request = {
url = url,
method = "GET",
sink = maxtime and socketutil.table_sink(sink) or ltn12.sink.table(sink),
}
-- 'timeout' delay works on socket, and is triggered when
-- that time has passed trying to connect, or after connection
-- when no data has been read for this time.
-- On a slow connection, it may not be triggered (as we could read
-- 1 byte every 1 second, not triggering any timeout).
-- 'maxtime' can be provided to overcome that, and we start counting
-- as soon as the first content byte is received (but it is checked
-- for only when data is received).
-- Setting 'maxtime' and 'timeout' gives more chance to abort the request when
-- it takes too much time (in the worst case: in timeout+maxtime seconds).
-- But time taken by DNS lookup cannot easily be accounted for, so
-- a request may (when dns lookup takes time) exceed timeout and maxtime...
if maxtime then
request['sink'] = sink_table_with_maxtime(sink, maxtime)
else
request['sink'] = ltn12.sink.table(sink)
end
local code, headers, status = socket.skip(1, requester.request(request))
local code, headers, status = socket.skip(1, http.request(request))
socketutil:reset_timeout()
local content = table.concat(sink) -- empty or content accumulated till now
-- logger.dbg("code:", code)
-- logger.dbg("headers:", headers)
-- logger.dbg("status:", status)
-- logger.dbg("#content:", #content)
if code == TIMEOUT_CODE or code == MAXTIME_CODE then
if code == socketutil.TIMEOUT_CODE or
code == socketutil.SSL_HANDSHAKE_CODE or
code == socketutil.SINK_TIMEOUT_CODE
then
logger.warn("request interrupted:", code)
return false, code
end
@ -212,7 +168,7 @@ local WIKIPEDIA_IMAGES = 4
-- return decoded JSON table from Wikipedia
--]]
function Wikipedia:loadPage(text, lang, page_type, plain)
local url = require('socket.url')
local url = require("socket.url")
local query = ""
local parsed = url.parse(self:getWikiServer(lang))
parsed.path = self.wiki_path

@ -1,6 +1,7 @@
local json = require("json")
local http = require("socket.http")
local json = require("json")
local ltn12 = require("ltn12")
local socketutil = require("socketutil")
local JoplinClient = {
server_ip = "localhost",
@ -19,16 +20,18 @@ function JoplinClient:_makeRequest(url, method, request_body)
local sink = {}
local request_body_json = json.encode(request_body)
local source = ltn12.source.string(request_body_json)
socketutil:set_timeout(socketutil.LARGE_BLOCK_TIMEOUT, socketutil.LARGE_TOTAL_TIMEOUT)
http.request{
url = url,
method = method,
sink = ltn12.sink.table(sink),
source = source,
url = url,
method = method,
sink = ltn12.sink.table(sink),
source = source,
headers = {
["Content-Length"] = #request_body_json,
["Content-Type"] = "application/json"
}
},
}
socketutil:reset_timeout()
if not sink[1] then
error("No response from Joplin Server")

@ -2,10 +2,10 @@ local InputContainer = require("ui/widget/container/inputcontainer")
local GoodreadsBook = require("goodreadsbook")
local InfoMessage = require("ui/widget/infomessage")
local UIManager = require("ui/uimanager")
local url = require('socket.url')
local socket = require('socket')
local https = require('ssl.https')
local ltn12 = require('ltn12')
local http = require("socket.http")
local ltn12 = require("ltn12")
local socket = require("socket")
local socketutil = require("socketutil")
local _ = require("gettext")
local GoodreadsApi = InputContainer:new {
@ -42,14 +42,15 @@ local function genIdUrl(id, userApi)
end
function GoodreadsApi:fetchXml(s_url)
local request, sink = {}, {}
local parsed = url.parse(s_url)
request['url'] = s_url
request['method'] = 'GET'
request['sink'] = ltn12.sink.table(sink)
https.TIMEOUT = 5
local httpsRequest = parsed.scheme == 'https' and https.request
local headers = socket.skip(1, httpsRequest(request))
local sink = {}
socketutil:set_timeout()
local request = {
url = s_url,
method = "GET",
sink = ltn12.sink.table(sink),
}
local headers = socket.skip(2, http.request(request))
socketutil:reset_timeout()
if headers == nil then
return nil
end

@ -19,7 +19,7 @@ local TextWidget = require("ui/widget/textwidget")
local UIManager = require("ui/uimanager")
local VerticalGroup = require("ui/widget/verticalgroup")
local VerticalSpan = require("ui/widget/verticalspan")
local https = require('ssl.https')
local https = require("ssl.https")
local _ = require("gettext")
local Screen = require("device").screen
local T = require("ffi/util").template

@ -1,11 +1,11 @@
local Version = require("version")
local ffiutil = require("ffi/util")
local http = require("socket.http")
local https = require("ssl.https")
local logger = require("logger")
local ltn12 = require("ltn12")
local socket = require("socket")
local socket_url = require("socket.url")
local socketutil = require("socketutil")
local _ = require("gettext")
local T = ffiutil.template
@ -20,10 +20,6 @@ local EpubDownloadBackend = {
}
local max_redirects = 5; --prevent infinite redirects
-- Codes that getUrlContent may get from requester.request()
local TIMEOUT_CODE = "timeout" -- from socket.lua
local MAXTIME_CODE = "maxtime reached" -- from sink_table_with_maxtime
-- filter HTML using CSS selector
local function filter(text, element)
local htmlparser = require("htmlparser")
@ -72,23 +68,6 @@ local function filter(text, element)
return "<!DOCTYPE html><html><head></head><body>" .. filtered .. "</body></html>"
end
-- Sink that stores into a table, aborting if maxtime has elapsed
local function sink_table_with_maxtime(t, maxtime)
-- Start counting as soon as this sink is created
local start_secs, start_usecs = ffiutil.gettime()
local starttime = start_secs + start_usecs/1000000
t = t or {}
local f = function(chunk, err)
local secs, usecs = ffiutil.gettime()
if secs + usecs/1000000 - starttime > maxtime then
return nil, MAXTIME_CODE
end
if chunk then table.insert(t, chunk) end
return 1
end
return f, t
end
-- Get URL content
local function getUrlContent(url, timeout, maxtime, redirectCount)
logger.dbg("getUrlContent(", url, ",", timeout, ",", maxtime, ",", redirectCount, ")")
@ -100,35 +79,19 @@ local function getUrlContent(url, timeout, maxtime, redirectCount)
if not timeout then timeout = 10 end
logger.dbg("timeout:", timeout)
-- timeout needs to be set to "http", even if we use "https"
--http.TIMEOUT, https.TIMEOUT = timeout, timeout
-- "timeout" delay works on socket, and is triggered when
-- that time has passed trying to connect, or after connection
-- when no data has been read for this time.
-- On a slow connection, it may not be triggered (as we could read
-- 1 byte every 1 second, not triggering any timeout).
-- "maxtime" can be provided to overcome that, and we start counting
-- as soon as the first content byte is received (but it is checked
-- for only when data is received).
-- Setting "maxtime" and "timeout" gives more chance to abort the request when
-- it takes too much time (in the worst case: in timeout+maxtime seconds).
-- But time taken by DNS lookup cannot easily be accounted for, so
-- a request may (when dns lookup takes time) exceed timeout and maxtime...
local request, sink = {}, {}
if maxtime then
request.sink = sink_table_with_maxtime(sink, maxtime)
else
request.sink = ltn12.sink.table(sink)
end
request.url = url
request.method = "GET"
local sink = {}
local parsed = socket_url.parse(url)
local httpRequest = parsed.scheme == "http" and http.request or https.request
socketutil:set_timeout(timeout, maxtime or 30)
local request = {
url = url,
method = "GET",
sink = maxtime and socketutil.table_sink(sink) or ltn12.sink.table(sink),
}
logger.dbg("request:", request)
local code, headers, status = socket.skip(1, httpRequest(request))
logger.dbg("After httpRequest")
local code, headers, status = socket.skip(1, http.request(request))
socketutil:reset_timeout()
logger.dbg("After http.request")
local content = table.concat(sink) -- empty or content accumulated till now
logger.dbg("type(code):", type(code))
logger.dbg("code:", code)
@ -136,7 +99,10 @@ local function getUrlContent(url, timeout, maxtime, redirectCount)
logger.dbg("status:", status)
logger.dbg("#content:", #content)
if code == TIMEOUT_CODE or code == MAXTIME_CODE then
if code == socketutil.TIMEOUT_CODE or
code == socketutil.SSL_HANDSHAKE_CODE or
code == socketutil.SINK_TIMEOUT_CODE
then
logger.warn("request interrupted:", code)
return false, code
end

@ -1,9 +1,8 @@
local http = require("socket.http")
local https = require("ssl.https")
local logger = require("logger")
local ltn12 = require("ltn12")
local socket = require('socket')
local socket_url = require("socket.url")
local socket = require("socket")
local socketutil = require("socketutil")
local InternalDownloadBackend = {}
local max_redirects = 5; --prevent infinite redirects
@ -15,13 +14,14 @@ function InternalDownloadBackend:getResponseAsString(url, redirectCount)
error("InternalDownloadBackend: reached max redirects: ", redirectCount)
end
logger.dbg("InternalDownloadBackend: url :", url)
local request, sink = {}, {}
request['sink'] = ltn12.sink.table(sink)
request['url'] = url
local parsed = socket_url.parse(url)
local httpRequest = parsed.scheme == 'http' and http.request or https.request;
local code, headers, status = socket.skip(1, httpRequest(request))
local sink = {}
socketutil:set_timeout(socketutil.LARGE_BLOCK_TIMEOUT, socketutil.LARGE_TOTAL_TIMEOUT)
local request = {
url = url,
sink = ltn12.sink.table(sink),
}
local code, headers, status = socket.skip(1, http.request(request))
socketutil:reset_timeout()
if code ~= 200 then
logger.dbg("InternalDownloadBackend: HTTP response code <> 200. Response status: ", status)

@ -13,12 +13,13 @@ local NetworkMgr = require("ui/network/manager")
local OPDSParser = require("opdsparser")
local Screen = require("device").screen
local UIManager = require("ui/uimanager")
local http = require('socket.http')
local http = require("socket.http")
local lfs = require("libs/libkoreader-lfs")
local logger = require("logger")
local ltn12 = require('ltn12')
local socket = require('socket')
local url = require('socket.url')
local ltn12 = require("ltn12")
local socket = require("socket")
local socketutil = require("socketutil")
local url = require("socket.url")
local util = require("util")
local _ = require("gettext")
local T = require("ffi/util").template
@ -268,19 +269,21 @@ end
function OPDSBrowser:fetchFeed(item_url, username, password, method)
local sink = {}
socketutil:set_timeout(socketutil.LARGE_BLOCK_TIMEOUT, socketutil.LARGE_TOTAL_TIMEOUT)
local request = {
url = item_url,
method = method and method or "GET",
-- Explicitly specify that we don't support compressed content. Some servers will still break RFC2616 14.3 and send crap instead.
headers = { ["Accept-Encoding"] = "identity", },
headers = {
["Accept-Encoding"] = "identity",
},
sink = ltn12.sink.table(sink),
username = username,
password = password
password = password,
}
logger.info("Request:", request)
http.TIMEOUT = 10
local httpRequest = http.request
local code, headers = socket.skip(1, httpRequest(request))
local code, headers = socket.skip(1, http.request(request))
socketutil:reset_timeout()
-- raise error message when network is unavailable
if headers == nil then
error(code)
@ -560,26 +563,20 @@ function OPDSBrowser:downloadFile(item, filetype, remote_url)
UIManager:scheduleIn(1, function()
logger.dbg("Downloading file", local_path, "from", remote_url)
local parsed = url.parse(remote_url)
http.TIMEOUT = 20
local dummy, code, headers
if parsed.scheme == "http" then
dummy, code, headers = http.request {
url = remote_url,
headers = { ["Accept-Encoding"] = "identity", },
sink = ltn12.sink.file(io.open(local_path, "w")),
user = item.username,
password = item.password
}
elseif parsed.scheme == "https" then
dummy, code, headers = http.request {
local code, headers
if parsed.scheme == "http" or parsed.scheme == "https" then
socketutil:set_timeout(socketutil.FILE_BLOCK_TIMEOUT, socketutil.FILE_TOTAL_TIMEOUT)
code, headers = socket.skip(1, http.request {
url = remote_url,
headers = { ["Accept-Encoding"] = "identity", },
headers = {
["Accept-Encoding"] = "identity",
},
sink = ltn12.sink.file(io.open(local_path, "w")),
user = item.username,
password = item.password
}
password = item.password,
})
socketutil:reset_timeout()
else
UIManager:show(InfoMessage:new {
text = T(_("Invalid protocol:\n%1"), parsed.scheme),

@ -26,6 +26,7 @@ local http = require("socket.http")
local logger = require("logger")
local ltn12 = require("ltn12")
local socket = require("socket")
local socketutil = require("socketutil")
local util = require("util")
local _ = require("gettext")
local T = FFIUtil.template
@ -379,7 +380,7 @@ function Wallabag:getBearerToken()
["Content-type"] = "application/json",
["Accept"] = "application/json, */*",
["Content-Length"] = tostring(#bodyJSON),
}
}
local result = self:callAPI("POST", login_url, headers, bodyJSON, "")
if result then
@ -539,21 +540,25 @@ end
-- filepath: downloads the file if provided, returns JSON otherwise
---- @todo separate call to internal API from the download on external server
function Wallabag:callAPI(method, apiurl, headers, body, filepath, quiet)
local request, sink = {}, {}
local sink = {}
local request = {}
-- Is it an API call, or a regular file direct download?
if apiurl:sub(1, 1) == "/" then
-- API call to our server, has the form "/random/api/call"
request.url = self.server_url .. apiurl
if headers == nil then
headers = { ["Authorization"] = "Bearer " .. self.access_token, }
headers = {
["Authorization"] = "Bearer " .. self.access_token,
}
end
else
-- regular url link to a foreign server
local file_url = apiurl
request.url = file_url
if headers == nil then
headers = {} -- no need for a token here
-- no need for a token here
headers = {}
end
end
@ -570,9 +575,9 @@ function Wallabag:callAPI(method, apiurl, headers, body, filepath, quiet)
logger.dbg("Wallabag: URL ", request.url)
logger.dbg("Wallabag: method ", method)
http.TIMEOUT = 30
local httpRequest = http.request
local code, resp_headers = socket.skip(1, httpRequest(request))
socketutil:set_timeout(socketutil.LARGE_BLOCK_TIMEOUT, socketutil.LARGE_TOTAL_TIMEOUT)
local code, resp_headers = socket.skip(1, http.request(request))
socketutil:reset_timeout()
-- raise error message when network is unavailable
if resp_headers == nil then
logger.dbg("Wallabag: Server error: ", code)

Loading…
Cancel
Save