diff --git a/test/conftest.py b/test/conftest.py
index 2fbc269e1..decd2c85c 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,4 +1,3 @@
-import functools
import inspect
import pytest
@@ -10,7 +9,9 @@ from yt_dlp.utils._utils import _YDLLogger as FakeLogger
@pytest.fixture
def handler(request):
- RH_KEY = request.param
+ RH_KEY = getattr(request, 'param', None)
+ if not RH_KEY:
+ return
if inspect.isclass(RH_KEY) and issubclass(RH_KEY, RequestHandler):
handler = RH_KEY
elif RH_KEY in _REQUEST_HANDLERS:
@@ -18,9 +19,46 @@ def handler(request):
else:
pytest.skip(f'{RH_KEY} request handler is not available')
- return functools.partial(handler, logger=FakeLogger)
+ class HandlerWrapper(handler):
+ RH_KEY = handler.RH_KEY
+ def __init__(self, *args, **kwargs):
+ super().__init__(logger=FakeLogger, *args, **kwargs)
-def validate_and_send(rh, req):
- rh.validate(req)
- return rh.send(req)
+ return HandlerWrapper
+
+
+@pytest.fixture(autouse=True)
+def skip_handler(request, handler):
+    """Skip the test when it carries a ``skip_handler`` marker naming the active handler.
+
+    usage: pytest.mark.skip_handler('my_handler', 'reason')
+
+    The ``handler`` fixture returns ``None`` for tests that are not parametrized
+    with a handler; there is nothing to match the marker against in that case.
+    """
+    if not handler:
+        return
+    for marker in request.node.iter_markers('skip_handler'):
+        if marker.args[0] == handler.RH_KEY:
+            # The reason is optional; fall back to an empty skip message
+            pytest.skip(marker.args[1] if len(marker.args) > 1 else '')
+
+
+@pytest.fixture(autouse=True)
+def skip_handler_if(request, handler):
+    """Skip the test for a named handler when the marker's condition holds.
+
+    usage: pytest.mark.skip_handler_if('my_handler', lambda request: True, 'reason')
+
+    The ``handler`` fixture returns ``None`` for tests that are not parametrized
+    with a handler; there is nothing to match the marker against in that case.
+    """
+    if not handler:
+        return
+    for marker in request.node.iter_markers('skip_handler_if'):
+        # The condition is only evaluated for the handler the marker names
+        if marker.args[0] == handler.RH_KEY and marker.args[1](request):
+            pytest.skip(marker.args[2] if len(marker.args) > 2 else '')
+
+
+@pytest.fixture(autouse=True)
+def skip_handlers_if(request, handler):
+    """Skip the test for any handler for which the marker's condition is true.
+
+    usage: pytest.mark.skip_handlers_if(lambda request, handler: True, 'reason')
+    """
+    if not handler:
+        # No handler is active, so no condition can apply
+        return
+    for mark in request.node.iter_markers('skip_handlers_if'):
+        condition = mark.args[0]
+        if condition(request, handler):
+            reason = mark.args[1] if len(mark.args) > 1 else ''
+            pytest.skip(reason)
+
+
+def pytest_configure(config):
+    """Register the custom ``skip_handler*`` markers with pytest.
+
+    The signatures in the descriptions mirror the usage documented on the
+    corresponding autouse fixtures.
+    """
+    config.addinivalue_line(
+        "markers", "skip_handler(handler, reason): skip test for the given handler",
+    )
+    config.addinivalue_line(
+        "markers", "skip_handler_if(handler, condition, reason): skip test for the given handler if condition is true",
+    )
+    config.addinivalue_line(
+        "markers", "skip_handlers_if(condition, reason): skip test for handlers when the condition is true",
+    )
diff --git a/test/helper.py b/test/helper.py
index 7760fd8d7..e7473120d 100644
--- a/test/helper.py
+++ b/test/helper.py
@@ -338,3 +338,8 @@ def http_server_port(httpd):
def verify_address_availability(address):
if find_available_port(address) is None:
pytest.skip(f'Unable to bind to source address {address} (address may not exist)')
+
+
+def validate_and_send(rh, req):
+    """Validate ``req`` with the request handler ``rh``, then send it and return the result."""
+    rh.validate(req)
+    response = rh.send(req)
+    return response
diff --git a/test/test_http_proxy.py b/test/test_http_proxy.py
new file mode 100644
index 000000000..6e27bc6d5
--- /dev/null
+++ b/test/test_http_proxy.py
@@ -0,0 +1,430 @@
+import abc
+import base64
+import contextlib
+import functools
+import json
+import os
+import random
+import ssl
+import threading
+from http.server import BaseHTTPRequestHandler
+from socketserver import BaseRequestHandler, ThreadingTCPServer
+
+import pytest
+
+from test.helper import http_server_port, verify_address_availability
+from test.test_networking import TEST_DIR
+from test.test_socks import IPv6ThreadingTCPServer
+from yt_dlp.dependencies import urllib3
+from yt_dlp.networking import Request
+from yt_dlp.networking.exceptions import HTTPError, ProxyError, SSLError
+
+
+class HTTPProxyAuthMixin:
+ """Mixin adding HTTP Basic proxy-authentication checks to a request handler."""
+
+ def proxy_auth_error(self):
+ """Reply with 407 and a Basic ``Proxy-Authenticate`` challenge; always returns False."""
+ self.send_response(407)
+ self.send_header('Proxy-Authenticate', 'Basic realm="test http proxy"')
+ self.end_headers()
+ return False
+
+ def do_proxy_auth(self, username, password):
+ """Check the ``Proxy-Authorization`` header against the expected credentials.
+
+ Returns True when no credentials are expected (both arguments are None) or when
+ the Basic credentials match; otherwise sends a 407 response and returns False.
+ """
+ if username is None and password is None:
+ return True
+
+ proxy_auth_header = self.headers.get('Proxy-Authorization', None)
+ if proxy_auth_header is None:
+ return self.proxy_auth_error()
+
+ if not proxy_auth_header.startswith('Basic '):
+ return self.proxy_auth_error()
+
+ # Drop the leading 'Basic ' scheme prefix, leaving the base64 credentials
+ auth = proxy_auth_header[6:]
+
+ try:
+ auth_username, auth_password = base64.b64decode(auth).decode().split(':', 1)
+ except Exception:
+ # Any decoding or format problem is treated as failed authentication
+ return self.proxy_auth_error()
+
+ # An expected value of None is compared as the empty string
+ if auth_username != (username or '') or auth_password != (password or ''):
+ return self.proxy_auth_error()
+ return True
+
+
+class HTTPProxyHandler(BaseHTTPRequestHandler, HTTPProxyAuthMixin):
+    """HTTP handler that answers ``GET .../proxy_info`` with a JSON description of the request.
+
+    ``proxy_info`` may be supplied by a caller; it is then returned instead of
+    the locally built description. ``request_handler`` is accepted and ignored.
+    """
+
+    def __init__(self, *args, proxy_info=None, username=None, password=None, request_handler=None, **kwargs):
+        # Set attributes before delegating: the base class constructor processes the request
+        self.username = username
+        self.password = password
+        self.proxy_info = proxy_info
+        super().__init__(*args, **kwargs)
+
+    def do_GET(self):
+        """Serve the proxy_info JSON (200), or 404 for any other path; 407 on failed proxy auth."""
+        if not self.do_proxy_auth(self.username, self.password):
+            self.server.close_request(self.request)
+            return
+        if self.path.endswith('/proxy_info'):
+            # Encode before measuring so Content-Length is the number of bytes actually written
+            payload = json.dumps(self.proxy_info or {
+                'client_address': self.client_address,
+                'connect': False,
+                'connect_host': None,
+                'connect_port': None,
+                'headers': dict(self.headers),
+                'path': self.path,
+                'proxy': ':'.join(str(y) for y in self.connection.getsockname()),
+            }).encode()
+            self.send_response(200)
+            self.send_header('Content-Type', 'application/json; charset=utf-8')
+            self.send_header('Content-Length', str(len(payload)))
+            self.end_headers()
+            self.wfile.write(payload)
+        else:
+            self.send_response(404)
+            self.end_headers()
+
+        self.server.close_request(self.request)
+
+
+# The server-capable SSLTransport is only defined when urllib3 is available; otherwise the name is None.
+if urllib3:
+ import urllib3.util.ssltransport
+
+ class SSLTransport(urllib3.util.ssltransport.SSLTransport):
+ """
+ Modified version of urllib3 SSLTransport to support server side SSL
+
+ This allows us to chain multiple TLS connections.
+ """
+ def __init__(self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True, server_side=False):
+ # In-memory BIOs handed to wrap_bio() below
+ self.incoming = ssl.MemoryBIO()
+ self.outgoing = ssl.MemoryBIO()
+
+ self.suppress_ragged_eofs = suppress_ragged_eofs
+ self.socket = socket
+
+ # server_side is forwarded so the TLS object can act as the server end
+ self.sslobj = ssl_context.wrap_bio(
+ self.incoming,
+ self.outgoing,
+ server_hostname=server_hostname,
+ server_side=server_side
+ )
+ # Drive the handshake through the inherited _ssl_io_loop helper
+ self._ssl_io_loop(self.sslobj.do_handshake)
+
+ # _io_refs is read from and written to the wrapped socket instead of this object
+ @property
+ def _io_refs(self):
+ return self.socket._io_refs
+
+ @_io_refs.setter
+ def _io_refs(self, value):
+ self.socket._io_refs = value
+
+ def shutdown(self, *args, **kwargs):
+ # Forward shutdown() to the wrapped socket
+ self.socket.shutdown(*args, **kwargs)
+else:
+ SSLTransport = None
+
+
+class HTTPSProxyHandler(HTTPProxyHandler):
+ """HTTPProxyHandler that wraps the connection in server-side TLS using ``testcert.pem``."""
+ def __init__(self, request, *args, **kwargs):
+ certfn = os.path.join(TEST_DIR, 'testcert.pem')
+ sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ sslctx.load_cert_chain(certfn, None)
+ # A connection that is already TLS gets a second TLS layer via SSLTransport.
+ # NOTE(review): SSLTransport is None without urllib3, so this branch would fail then;
+ # presumably only reached from tests that require urllib3 - confirm.
+ if isinstance(request, ssl.SSLSocket):
+ request = SSLTransport(request, ssl_context=sslctx, server_side=True)
+ else:
+ request = sslctx.wrap_socket(request, server_side=True)
+ super().__init__(request, *args, **kwargs)
+
+
+class WebSocketProxyHandler(BaseRequestHandler):
+ """Websocket server endpoint that replies to a 'proxy_info' message with the stored proxy_info as JSON."""
+ def __init__(self, *args, proxy_info=None, **kwargs):
+ # Stored before delegating to the base class constructor
+ self.proxy_info = proxy_info
+ super().__init__(*args, **kwargs)
+
+ def handle(self):
+ """Perform the websocket handshake on the request socket and serve incoming messages."""
+ # Imported here rather than at module level (presumably because websockets is optional - confirm)
+ import websockets.sync.server
+ protocol = websockets.ServerProtocol()
+ connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0)
+ connection.handshake()
+ for message in connection:
+ if message == 'proxy_info':
+ connection.send(json.dumps(self.proxy_info))
+ connection.close()
+
+
+class WebSocketSecureProxyHandler(WebSocketProxyHandler):
+ """WebSocketProxyHandler that wraps the connection in server-side TLS using ``testcert.pem``."""
+ def __init__(self, request, *args, **kwargs):
+ certfn = os.path.join(TEST_DIR, 'testcert.pem')
+ sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ sslctx.load_cert_chain(certfn, None)
+ # SSLTransport is used whenever it is defined (urllib3 present); otherwise fall back to wrap_socket
+ if SSLTransport:
+ request = SSLTransport(request, ssl_context=sslctx, server_side=True)
+ else:
+ request = sslctx.wrap_socket(request, server_side=True)
+ super().__init__(request, *args, **kwargs)
+
+
+class HTTPConnectProxyHandler(BaseHTTPRequestHandler, HTTPProxyAuthMixin):
+    """HTTP proxy endpoint implementing only the CONNECT method.
+
+    After replying 200 it invokes ``request_handler`` on the same client socket,
+    passing a ``proxy_info`` dict that describes the CONNECT request.
+    """
+    protocol_version = 'HTTP/1.1'
+    default_request_version = 'HTTP/1.1'
+
+    def __init__(self, *args, username=None, password=None, request_handler=None, **kwargs):
+        # Set attributes before delegating: the base class constructor processes the request
+        self.username = username
+        self.password = password
+        self.request_handler = request_handler
+        super().__init__(*args, **kwargs)
+
+    def do_CONNECT(self):
+        """Authenticate, acknowledge the tunnel, then hand the socket to ``request_handler``."""
+        if not self.do_proxy_auth(self.username, self.password):
+            self.server.close_request(self.request)
+            return
+        self.send_response(200)
+        self.end_headers()
+        # Split on the last colon only, so a bracketed IPv6 target like "[::1]:443" parses correctly
+        connect_host, _, connect_port = self.path.rpartition(':')
+        proxy_info = {
+            'client_address': self.client_address,
+            'connect': True,
+            'connect_host': connect_host,
+            'connect_port': int(connect_port),
+            'headers': dict(self.headers),
+            'path': self.path,
+            'proxy': ':'.join(str(y) for y in self.connection.getsockname()),
+        }
+        self.request_handler(self.request, self.client_address, self.server, proxy_info=proxy_info)
+        self.server.close_request(self.request)
+
+
+class HTTPSConnectProxyHandler(HTTPConnectProxyHandler):
+ """HTTPConnectProxyHandler whose client connection is wrapped in server-side TLS using ``testcert.pem``."""
+ def __init__(self, request, *args, **kwargs):
+ certfn = os.path.join(TEST_DIR, 'testcert.pem')
+ sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ sslctx.load_cert_chain(certfn, None)
+ request = sslctx.wrap_socket(request, server_side=True)
+ # NOTE(review): despite its name this stores the TLS-wrapped socket, because it is
+ # assigned after wrap_socket() - confirm that this is the intended socket to close.
+ self._original_request = request
+ super().__init__(request, *args, **kwargs)
+
+ def do_CONNECT(self):
+ """Run the parent CONNECT handling, then also close the socket saved in ``__init__``."""
+ super().do_CONNECT()
+ self.server.close_request(self._original_request)
+
+
+@contextlib.contextmanager
+def proxy_server(proxy_server_class, request_handler, bind_ip=None, **proxy_server_kwargs):
+    """Run a threaded proxy server for the duration of the ``with`` block.
+
+    ``proxy_server_class`` is instantiated per connection with
+    ``request_handler`` and ``proxy_server_kwargs`` as keyword arguments.
+    The server binds to ``bind_ip`` (default ``127.0.0.1``) on an ephemeral port.
+
+    Yields the listening address as ``host:port``; the host is bracketed when
+    the bind address is not a dotted IPv4 address.
+    """
+    server = server_thread = None
+    try:
+        bind_address = bind_ip or '127.0.0.1'
+        is_ipv4 = '.' in bind_address
+        server_type = ThreadingTCPServer if is_ipv4 else IPv6ThreadingTCPServer
+        server = server_type(
+            (bind_address, 0), functools.partial(proxy_server_class, request_handler=request_handler, **proxy_server_kwargs))
+        server_port = http_server_port(server)
+        server_thread = threading.Thread(target=server.serve_forever)
+        server_thread.daemon = True
+        server_thread.start()
+        if is_ipv4:
+            yield f'{bind_address}:{server_port}'
+        else:
+            yield f'[{bind_address}]:{server_port}'
+    finally:
+        # Setup can fail part-way (e.g. the bind fails); only tear down what was created,
+        # so the original exception is not replaced by an AttributeError on None.
+        if server_thread is not None:
+            server.shutdown()
+        if server is not None:
+            server.server_close()
+        if server_thread is not None:
+            server_thread.join(2.0)
+
+
+class HTTPProxyTestContext(abc.ABC):
+ """Base class for per-scheme proxy test contexts."""
+ # Handler class passed to proxy_server() as ``request_handler``
+ REQUEST_HANDLER_CLASS = None
+ # URL scheme of the request; the tests also use it as the key of the ``proxies`` mapping
+ REQUEST_PROTO = None
+
+ def http_server(self, server_class, *args, **kwargs):
+ """Start ``server_class`` through ``proxy_server`` with this context's request handler class."""
+ return proxy_server(server_class, self.REQUEST_HANDLER_CLASS, *args, **kwargs)
+
+ @abc.abstractmethod
+ def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs) -> dict:
+ """return a dict of proxy_info"""
+
+
+class HTTPProxyHTTPTestContext(HTTPProxyTestContext):
+ # Standard HTTP Proxy for http requests
+ REQUEST_HANDLER_CLASS = HTTPProxyHandler
+ REQUEST_PROTO = 'http'
+
+ def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
+ """Request ``/proxy_info`` over http through ``handler`` and return the decoded JSON body."""
+ # Target defaults to 127.0.0.1:40000 when no domain/port is given
+ request = Request(f'http://{target_domain or "127.0.0.1"}:{target_port or "40000"}/proxy_info', **req_kwargs)
+ handler.validate(request)
+ return json.loads(handler.send(request).read().decode())
+
+
+class HTTPProxyHTTPSTestContext(HTTPProxyTestContext):
+ # HTTP Connect proxy, for https requests
+ REQUEST_HANDLER_CLASS = HTTPSProxyHandler
+ REQUEST_PROTO = 'https'
+
+ def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
+ """Request ``/proxy_info`` over https through ``handler`` and return the decoded JSON body."""
+ # Target defaults to 127.0.0.1:40000 when no domain/port is given
+ request = Request(f'https://{target_domain or "127.0.0.1"}:{target_port or "40000"}/proxy_info', **req_kwargs)
+ handler.validate(request)
+ return json.loads(handler.send(request).read().decode())
+
+
+class HTTPProxyWebSocketTestContext(HTTPProxyTestContext):
+    """Proxy test context for ``ws://`` targets served by ``WebSocketProxyHandler``."""
+    REQUEST_HANDLER_CLASS = WebSocketProxyHandler
+    REQUEST_PROTO = 'ws'
+
+    def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
+        """Open a websocket through ``handler``, ask for ``proxy_info`` and return it decoded.
+
+        The target defaults to 127.0.0.1:40000 when no domain/port is given.
+        """
+        request = Request(f'{self.REQUEST_PROTO}://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs)
+        handler.validate(request)
+        ws = handler.send(request)
+        try:
+            ws.send('proxy_info')
+            proxy_info = ws.recv()
+        finally:
+            # Close the connection even when the exchange fails, so the socket is not leaked
+            ws.close()
+        return json.loads(proxy_info)
+
+
+class HTTPProxyWebSocketSecureTestContext(HTTPProxyWebSocketTestContext):
+ """Same as the websocket context, but for ``wss://`` targets served by ``WebSocketSecureProxyHandler``."""
+ REQUEST_HANDLER_CLASS = WebSocketSecureProxyHandler
+ REQUEST_PROTO = 'wss'
+
+
+# Maps the ``ctx`` fixture parameter (the request scheme) to its test-context class
+CTX_MAP = {
+ 'http': HTTPProxyHTTPTestContext,
+ 'https': HTTPProxyHTTPSTestContext,
+ 'ws': HTTPProxyWebSocketTestContext,
+ 'wss': HTTPProxyWebSocketSecureTestContext,
+}
+
+
+@pytest.fixture(scope='module')
+def ctx(request):
+    """Instantiate the test context registered in ``CTX_MAP`` for the fixture parameter."""
+    context_class = CTX_MAP[request.param]
+    return context_class()
+
+
+@pytest.mark.parametrize(
+ 'handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
+@pytest.mark.parametrize('ctx', ['http'], indirect=True) # pure http proxy can only support http
+class TestHTTPProxy:
+ """Tests for http requests sent through a non-CONNECT proxy (plain or TLS-wrapped)."""
+ def test_http_no_auth(self, handler, ctx):
+ """Without credentials the proxy is used directly and no Proxy-Authorization header is sent."""
+ with ctx.http_server(HTTPProxyHandler) as server_address:
+ with handler(proxies={ctx.REQUEST_PROTO: f'http://{server_address}'}) as rh:
+ proxy_info = ctx.proxy_info_request(rh)
+ assert proxy_info['proxy'] == server_address
+ assert proxy_info['connect'] is False
+ assert 'Proxy-Authorization' not in proxy_info['headers']
+
+ def test_http_auth(self, handler, ctx):
+ """Credentials in the proxy URL are sent as a Proxy-Authorization header."""
+ with ctx.http_server(HTTPProxyHandler, username='test', password='test') as server_address:
+ with handler(proxies={ctx.REQUEST_PROTO: f'http://test:test@{server_address}'}) as rh:
+ proxy_info = ctx.proxy_info_request(rh)
+ assert proxy_info['proxy'] == server_address
+ assert 'Proxy-Authorization' in proxy_info['headers']
+
+ def test_http_bad_auth(self, handler, ctx):
+ """A wrong proxy password surfaces as an HTTPError with status 407."""
+ with ctx.http_server(HTTPProxyHandler, username='test', password='test') as server_address:
+ with handler(proxies={ctx.REQUEST_PROTO: f'http://test:bad@{server_address}'}) as rh:
+ with pytest.raises(HTTPError) as exc_info:
+ ctx.proxy_info_request(rh)
+ assert exc_info.value.response.status == 407
+ exc_info.value.response.close()
+
+ def test_http_source_address(self, handler, ctx):
+ """The proxy sees the connection coming from the configured source address."""
+ with ctx.http_server(HTTPProxyHandler) as server_address:
+ source_address = f'127.0.0.{random.randint(5, 255)}'
+ # Skips the test when the chosen loopback address cannot be bound
+ verify_address_availability(source_address)
+ with handler(proxies={ctx.REQUEST_PROTO: f'http://{server_address}'},
+ source_address=source_address) as rh:
+ proxy_info = ctx.proxy_info_request(rh)
+ assert proxy_info['proxy'] == server_address
+ assert proxy_info['client_address'][0] == source_address
+
+ @pytest.mark.skip_handler('Urllib', 'urllib does not support https proxies')
+ def test_https(self, handler, ctx):
+ """An https:// proxy URL works for http requests when certificate verification is off."""
+ with ctx.http_server(HTTPSProxyHandler) as server_address:
+ with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'https://{server_address}'}) as rh:
+ proxy_info = ctx.proxy_info_request(rh)
+ assert proxy_info['proxy'] == server_address
+ assert proxy_info['connect'] is False
+ assert 'Proxy-Authorization' not in proxy_info['headers']
+
+ @pytest.mark.skip_handler('Urllib', 'urllib does not support https proxies')
+ def test_https_verify_failed(self, handler, ctx):
+ """With verification on, the test certificate of the https proxy must be rejected."""
+ with ctx.http_server(HTTPSProxyHandler) as server_address:
+ with handler(verify=True, proxies={ctx.REQUEST_PROTO: f'https://{server_address}'}) as rh:
+ # Accept SSLError as may not be feasible to tell if it is proxy or request error.
+ # note: if request proto also does ssl verification, this may also be the error of the request.
+ # Until we can support passing custom cacerts to handlers, we cannot properly test this for all cases.
+ with pytest.raises((ProxyError, SSLError)):
+ ctx.proxy_info_request(rh)
+
+ def test_http_with_idn(self, handler, ctx):
+ """A non-ASCII target domain reaches the proxy punycode-encoded in the path and Host header."""
+ with ctx.http_server(HTTPProxyHandler) as server_address:
+ with handler(proxies={ctx.REQUEST_PROTO: f'http://{server_address}'}) as rh:
+ proxy_info = ctx.proxy_info_request(rh, target_domain='中文.tw')
+ assert proxy_info['proxy'] == server_address
+ assert proxy_info['path'].startswith('http://xn--fiq228c.tw')
+ assert proxy_info['headers']['Host'].split(':', 1)[0] == 'xn--fiq228c.tw'
+
+
+@pytest.mark.parametrize(
+ 'handler,ctx', [
+ ('Requests', 'https'),
+ ('CurlCFFI', 'https'),
+ ('Websockets', 'ws'),
+ ('Websockets', 'wss')
+ ], indirect=True)
+class TestHTTPConnectProxy:
+ """Tests for https/ws/wss requests tunnelled through an HTTP(S) CONNECT proxy."""
+ def test_http_connect_no_auth(self, handler, ctx):
+ """Without credentials the request is tunnelled via CONNECT and no Proxy-Authorization is sent."""
+ with ctx.http_server(HTTPConnectProxyHandler) as server_address:
+ with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'http://{server_address}'}) as rh:
+ proxy_info = ctx.proxy_info_request(rh)
+ assert proxy_info['proxy'] == server_address
+ assert proxy_info['connect'] is True
+ assert 'Proxy-Authorization' not in proxy_info['headers']
+
+ def test_http_connect_auth(self, handler, ctx):
+ """Credentials in the proxy URL are sent as Proxy-Authorization on the CONNECT request."""
+ with ctx.http_server(HTTPConnectProxyHandler, username='test', password='test') as server_address:
+ with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'http://test:test@{server_address}'}) as rh:
+ proxy_info = ctx.proxy_info_request(rh)
+ assert proxy_info['proxy'] == server_address
+ assert 'Proxy-Authorization' in proxy_info['headers']
+
+ @pytest.mark.skip_handler(
+ 'Requests',
+ 'bug in urllib3 causes unclosed socket: https://github.com/urllib3/urllib3/issues/3374'
+ )
+ def test_http_connect_bad_auth(self, handler, ctx):
+ """A wrong proxy password on CONNECT surfaces as a ProxyError."""
+ with ctx.http_server(HTTPConnectProxyHandler, username='test', password='test') as server_address:
+ with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'http://test:bad@{server_address}'}) as rh:
+ with pytest.raises(ProxyError):
+ ctx.proxy_info_request(rh)
+
+ def test_http_connect_source_address(self, handler, ctx):
+ """The proxy sees the connection coming from the configured source address."""
+ with ctx.http_server(HTTPConnectProxyHandler) as server_address:
+ source_address = f'127.0.0.{random.randint(5, 255)}'
+ # Skips the test when the chosen loopback address cannot be bound
+ verify_address_availability(source_address)
+ with handler(proxies={ctx.REQUEST_PROTO: f'http://{server_address}'},
+ source_address=source_address,
+ verify=False) as rh:
+ proxy_info = ctx.proxy_info_request(rh)
+ assert proxy_info['proxy'] == server_address
+ assert proxy_info['client_address'][0] == source_address
+
+ @pytest.mark.skipif(urllib3 is None, reason='requires urllib3 to test')
+ def test_https_connect_proxy(self, handler, ctx):
+ """CONNECT tunnelling also works when the proxy itself is reached over TLS."""
+ with ctx.http_server(HTTPSConnectProxyHandler) as server_address:
+ with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'https://{server_address}'}) as rh:
+ proxy_info = ctx.proxy_info_request(rh)
+ assert proxy_info['proxy'] == server_address
+ assert proxy_info['connect'] is True
+ assert 'Proxy-Authorization' not in proxy_info['headers']
+
+ @pytest.mark.skipif(urllib3 is None, reason='requires urllib3 to test')
+ def test_https_connect_verify_failed(self, handler, ctx):
+ """With verification on, the test certificate of the https CONNECT proxy must be rejected."""
+ with ctx.http_server(HTTPSConnectProxyHandler) as server_address:
+ with handler(verify=True, proxies={ctx.REQUEST_PROTO: f'https://{server_address}'}) as rh:
+ # Accept SSLError as may not be feasible to tell if it is proxy or request error.
+ # note: if request proto also does ssl verification, this may also be the error of the request.
+ # Until we can support passing custom cacerts to handlers, we cannot properly test this for all cases.
+ with pytest.raises((ProxyError, SSLError)):
+ ctx.proxy_info_request(rh)
+
+ @pytest.mark.skipif(urllib3 is None, reason='requires urllib3 to test')
+ def test_https_connect_proxy_auth(self, handler, ctx):
+ """Proxy credentials are sent on CONNECT when the proxy is reached over TLS."""
+ with ctx.http_server(HTTPSConnectProxyHandler, username='test', password='test') as server_address:
+ with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'https://test:test@{server_address}'}) as rh:
+ proxy_info = ctx.proxy_info_request(rh)
+ assert proxy_info['proxy'] == server_address
+ assert 'Proxy-Authorization' in proxy_info['headers']
diff --git a/test/test_networking.py b/test/test_networking.py
index b50f70d08..f5a9a95eb 100644
--- a/test/test_networking.py
+++ b/test/test_networking.py
@@ -6,6 +6,8 @@ import sys
import pytest
+from yt_dlp.networking.common import Features
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import gzip
@@ -27,8 +29,12 @@ import zlib
from email.message import Message
from http.cookiejar import CookieJar
-from test.conftest import validate_and_send
-from test.helper import FakeYDL, http_server_port, verify_address_availability
+from test.helper import (
+ FakeYDL,
+ http_server_port,
+ validate_and_send,
+ verify_address_availability,
+)
from yt_dlp.cookies import YoutubeDLCookieJar
from yt_dlp.dependencies import brotli, curl_cffi, requests, urllib3
from yt_dlp.networking import (
@@ -62,21 +68,6 @@ from yt_dlp.utils.networking import HTTPHeaderDict, std_headers
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
-def _build_proxy_handler(name):
- class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler):
- proxy_name = name
-
- def log_message(self, format, *args):
- pass
-
- def do_GET(self):
- self.send_response(200)
- self.send_header('Content-Type', 'text/plain; charset=utf-8')
- self.end_headers()
- self.wfile.write(f'{self.proxy_name}: {self.path}'.encode())
- return HTTPTestRequestHandler
-
-
class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1'
default_request_version = 'HTTP/1.1'
@@ -317,8 +308,9 @@ class TestRequestHandlerBase:
cls.https_server_thread.start()
+@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
class TestHTTPRequestHandler(TestRequestHandlerBase):
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
+
def test_verify_cert(self, handler):
with handler() as rh:
with pytest.raises(CertificateVerifyError):
@@ -329,7 +321,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert r.status == 200
r.close()
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_ssl_error(self, handler):
# HTTPS server with too old TLS version
# XXX: is there a better way to test this than to create a new server?
@@ -347,7 +338,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
validate_and_send(rh, Request(f'https://127.0.0.1:{https_port}/headers'))
assert not issubclass(exc_info.type, CertificateVerifyError)
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_percent_encode(self, handler):
with handler() as rh:
# Unicode characters should be encoded with uppercase percent-encoding
@@ -359,7 +349,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert res.status == 200
res.close()
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
@pytest.mark.parametrize('path', [
'/a/b/./../../headers',
'/redirect_dotsegments',
@@ -375,15 +364,13 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert res.url == f'http://127.0.0.1:{self.http_port}/headers'
res.close()
- # Not supported by CurlCFFI (non-standard)
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests'], indirect=True)
+ @pytest.mark.skip_handler('CurlCFFI', 'not supported by curl-cffi (non-standard)')
def test_unicode_path_redirection(self, handler):
with handler() as rh:
r = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/302-non-ascii-redirect'))
assert r.url == f'http://127.0.0.1:{self.http_port}/%E4%B8%AD%E6%96%87.html'
r.close()
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_raise_http_error(self, handler):
with handler() as rh:
for bad_status in (400, 500, 599, 302):
@@ -393,7 +380,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
# Should not raise an error
validate_and_send(rh, Request('http://127.0.0.1:%d/gen_200' % self.http_port)).close()
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_response_url(self, handler):
with handler() as rh:
# Response url should be that of the last url in redirect chain
@@ -405,7 +391,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
res2.close()
# Covers some basic cases we expect some level of consistency between request handlers for
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
@pytest.mark.parametrize('redirect_status,method,expected', [
# A 303 must either use GET or HEAD for subsequent request
(303, 'POST', ('', 'GET', False)),
@@ -447,7 +432,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert expected[1] == res.headers.get('method')
assert expected[2] == ('content-length' in headers.decode().lower())
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_request_cookie_header(self, handler):
# We should accept a Cookie header being passed as in normal headers and handle it appropriately.
with handler() as rh:
@@ -480,19 +464,16 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert b'cookie: test=ytdlp' not in data.lower()
assert b'cookie: test=test3' in data.lower()
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_redirect_loop(self, handler):
with handler() as rh:
with pytest.raises(HTTPError, match='redirect loop'):
validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/redirect_loop'))
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_incompleteread(self, handler):
with handler(timeout=2) as rh:
with pytest.raises(IncompleteRead, match='13 bytes read, 234221 more expected'):
validate_and_send(rh, Request('http://127.0.0.1:%d/incompleteread' % self.http_port)).read()
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_cookies(self, handler):
cookiejar = YoutubeDLCookieJar()
cookiejar.set_cookie(http.cookiejar.Cookie(
@@ -509,7 +490,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
rh, Request(f'http://127.0.0.1:{self.http_port}/headers', extensions={'cookiejar': cookiejar})).read()
assert b'cookie: test=ytdlp' in data.lower()
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_headers(self, handler):
with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
@@ -525,7 +505,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert b'test2: test2' not in data
assert b'test3: test3' in data
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_read_timeout(self, handler):
with handler() as rh:
# Default timeout is 20 seconds, so this should go through
@@ -541,7 +520,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
validate_and_send(
rh, Request(f'http://127.0.0.1:{self.http_port}/timeout_1', extensions={'timeout': 4}))
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_connect_timeout(self, handler):
# nothing should be listening on this port
connect_timeout_url = 'http://10.255.255.255'
@@ -560,7 +538,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
rh, Request(connect_timeout_url, extensions={'timeout': 0.01}))
assert 0.01 <= time.time() - now < 20
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_source_address(self, handler):
source_address = f'127.0.0.{random.randint(5, 255)}'
# on some systems these loopback addresses we need for testing may not be available
@@ -572,13 +549,13 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert source_address == data
# Not supported by CurlCFFI
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests'], indirect=True)
+ @pytest.mark.skip_handler('CurlCFFI', 'not supported by curl-cffi')
def test_gzip_trailing_garbage(self, handler):
with handler() as rh:
data = validate_and_send(rh, Request(f'http://localhost:{self.http_port}/trailing_garbage')).read().decode()
assert data == ''
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests'], indirect=True)
+ @pytest.mark.skip_handler('CurlCFFI', 'not applicable to curl-cffi')
@pytest.mark.skipif(not brotli, reason='brotli support is not installed')
def test_brotli(self, handler):
with handler() as rh:
@@ -589,7 +566,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert res.headers.get('Content-Encoding') == 'br'
assert res.read() == b''
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_deflate(self, handler):
with handler() as rh:
res = validate_and_send(
@@ -599,7 +575,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert res.headers.get('Content-Encoding') == 'deflate'
assert res.read() == b''
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_gzip(self, handler):
with handler() as rh:
res = validate_and_send(
@@ -609,7 +584,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert res.headers.get('Content-Encoding') == 'gzip'
assert res.read() == b''
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_multiple_encodings(self, handler):
with handler() as rh:
for pair in ('gzip,deflate', 'deflate, gzip', 'gzip, gzip', 'deflate, deflate'):
@@ -620,8 +594,7 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert res.headers.get('Content-Encoding') == pair
assert res.read() == b''
- # Not supported by curl_cffi
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests'], indirect=True)
+ @pytest.mark.skip_handler('CurlCFFI', 'not supported by curl-cffi')
def test_unsupported_encoding(self, handler):
with handler() as rh:
res = validate_and_send(
@@ -631,7 +604,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert res.headers.get('Content-Encoding') == 'unsupported'
assert res.read() == b'raw'
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_read(self, handler):
with handler() as rh:
res = validate_and_send(
@@ -642,83 +614,48 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert res.read().decode().endswith('\n\n')
assert res.read() == b''
+ def test_request_disable_proxy(self, handler):
+ """A per-request proxy of None must override the proxy configured on the handler."""
+ # Falls back to the 'http' scheme when the handler declares no supported proxy schemes
+ for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['http']:
+ # Given the handler is configured with a proxy
+ # NOTE(review): 10.255.255.255 is presumably unreachable, so a 200 implies the proxy was bypassed - confirm
+ with handler(proxies={'http': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
+ # When a proxy is explicitly set to None for the request
+ res = validate_and_send(
+ rh, Request(f'http://127.0.0.1:{self.http_port}/headers', proxies={'http': None}))
+ # Then no proxy should be used
+ res.close()
+ assert res.status == 200
-class TestHTTPProxy(TestRequestHandlerBase):
- # Note: this only tests http urls over non-CONNECT proxy
- @classmethod
- def setup_class(cls):
- super().setup_class()
- # HTTP Proxy server
- cls.proxy = http.server.ThreadingHTTPServer(
- ('127.0.0.1', 0), _build_proxy_handler('normal'))
- cls.proxy_port = http_server_port(cls.proxy)
- cls.proxy_thread = threading.Thread(target=cls.proxy.serve_forever)
- cls.proxy_thread.daemon = True
- cls.proxy_thread.start()
-
- # Geo proxy server
- cls.geo_proxy = http.server.ThreadingHTTPServer(
- ('127.0.0.1', 0), _build_proxy_handler('geo'))
- cls.geo_port = http_server_port(cls.geo_proxy)
- cls.geo_proxy_thread = threading.Thread(target=cls.geo_proxy.serve_forever)
- cls.geo_proxy_thread.daemon = True
- cls.geo_proxy_thread.start()
-
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
- def test_http_proxy(self, handler):
- http_proxy = f'http://127.0.0.1:{self.proxy_port}'
- geo_proxy = f'http://127.0.0.1:{self.geo_port}'
-
- # Test global http proxy
- # Test per request http proxy
- # Test per request http proxy disables proxy
- url = 'http://foo.com/bar'
-
- # Global HTTP proxy
- with handler(proxies={'http': http_proxy}) as rh:
- res = validate_and_send(rh, Request(url)).read().decode()
- assert res == f'normal: {url}'
-
- # Per request proxy overrides global
- res = validate_and_send(rh, Request(url, proxies={'http': geo_proxy})).read().decode()
- assert res == f'geo: {url}'
-
- # and setting to None disables all proxies for that request
- real_url = f'http://127.0.0.1:{self.http_port}/headers'
- res = validate_and_send(
- rh, Request(real_url, proxies={'http': None})).read().decode()
- assert res != f'normal: {real_url}'
- assert 'Accept' in res
-
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
+ @pytest.mark.skip_handlers_if(
+ lambda _, handler: Features.NO_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support NO_PROXY')
def test_noproxy(self, handler):
- with handler(proxies={'proxy': f'http://127.0.0.1:{self.proxy_port}'}) as rh:
- # NO_PROXY
- for no_proxy in (f'127.0.0.1:{self.http_port}', '127.0.0.1', 'localhost'):
- nop_response = validate_and_send(
- rh, Request(f'http://127.0.0.1:{self.http_port}/headers', proxies={'no': no_proxy})).read().decode(
- 'utf-8')
- assert 'Accept' in nop_response
-
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
+ for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['http']:
+ # Given the handler is configured with a proxy
+ with handler(proxies={'http': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
+ for no_proxy in (f'127.0.0.1:{self.http_port}', '127.0.0.1', 'localhost'):
+ # When the request's no-proxy list includes the request URL host
+ nop_response = validate_and_send(
+ rh, Request(f'http://127.0.0.1:{self.http_port}/headers', proxies={'no': no_proxy}))
+ # Then the proxy should not be used
+ assert nop_response.status == 200
+ nop_response.close()
+
+ @pytest.mark.skip_handlers_if(
+ lambda _, handler: Features.ALL_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support ALL_PROXY')
def test_allproxy(self, handler):
- url = 'http://foo.com/bar'
- with handler() as rh:
- response = validate_and_send(rh, Request(url, proxies={'all': f'http://127.0.0.1:{self.proxy_port}'})).read().decode(
- 'utf-8')
- assert response == f'normal: {url}'
+ # This is a bit of a hacky test, but it should be enough to check whether the handler is using the proxy.
+ # 0.1s might not be enough of a timeout if proxy is not used in all cases, but should still get failures.
+ with handler(proxies={'all': 'http://10.255.255.255'}, timeout=0.1) as rh:
+ with pytest.raises(TransportError):
+ validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/headers')).close()
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
- def test_http_proxy_with_idn(self, handler):
- with handler(proxies={
- 'http': f'http://127.0.0.1:{self.proxy_port}',
- }) as rh:
- url = 'http://中文.tw/'
- response = rh.send(Request(url)).read().decode()
- # b'xn--fiq228c' is '中文'.encode('idna')
- assert response == 'normal: http://xn--fiq228c.tw/'
+ with handler(timeout=0.1) as rh:
+ with pytest.raises(TransportError):
+ validate_and_send(
+ rh, Request(
+ f'http://127.0.0.1:{self.http_port}/headers', proxies={'all': 'http://10.255.255.255'})).close()
+@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
class TestClientCertificate:
@classmethod
def setup_class(cls):
@@ -745,27 +682,23 @@ class TestClientCertificate:
) as rh:
validate_and_send(rh, Request(f'https://127.0.0.1:{self.port}/video.html')).read().decode()
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_certificate_combined_nopass(self, handler):
self._run_test(handler, client_cert={
'client_certificate': os.path.join(self.certdir, 'clientwithkey.crt'),
})
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_certificate_nocombined_nopass(self, handler):
self._run_test(handler, client_cert={
'client_certificate': os.path.join(self.certdir, 'client.crt'),
'client_certificate_key': os.path.join(self.certdir, 'client.key'),
})
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_certificate_combined_pass(self, handler):
self._run_test(handler, client_cert={
'client_certificate': os.path.join(self.certdir, 'clientwithencryptedkey.crt'),
'client_certificate_password': 'foobar',
})
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_certificate_nocombined_pass(self, handler):
self._run_test(handler, client_cert={
'client_certificate': os.path.join(self.certdir, 'client.crt'),
@@ -805,8 +738,8 @@ class TestRequestHandlerMisc:
assert len(logging_handlers) == before_count
+@pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
class TestUrllibRequestHandler(TestRequestHandlerBase):
- @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
def test_file_urls(self, handler):
# See https://github.com/ytdl-org/youtube-dl/issues/8227
tf = tempfile.NamedTemporaryFile(delete=False)
@@ -828,7 +761,6 @@ class TestUrllibRequestHandler(TestRequestHandlerBase):
os.unlink(tf.name)
- @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
def test_http_error_returns_content(self, handler):
# urllib HTTPError will try close the underlying response if reference to the HTTPError object is lost
def get_response():
@@ -841,7 +773,6 @@ class TestUrllibRequestHandler(TestRequestHandlerBase):
assert get_response().read() == b''
- @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
def test_verify_cert_error_text(self, handler):
# Check the output of the error message
with handler() as rh:
@@ -851,7 +782,6 @@ class TestUrllibRequestHandler(TestRequestHandlerBase):
):
validate_and_send(rh, Request(f'https://127.0.0.1:{self.https_port}/headers'))
- @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
@pytest.mark.parametrize('req,match,version_check', [
# https://github.com/python/cpython/blob/987b712b4aeeece336eed24fcc87a950a756c3e2/Lib/http/client.py#L1256
# bpo-39603: Check implemented in 3.7.9+, 3.8.5+
@@ -1183,7 +1113,7 @@ class TestRequestHandlerValidation:
]
PROXY_SCHEME_TESTS = [
- # scheme, expected to fail
+ # proxy scheme, expected to fail
('Urllib', 'http', [
('http', False),
('https', UnsupportedRequest),
@@ -1209,30 +1139,41 @@ class TestRequestHandlerValidation:
('socks5', False),
('socks5h', False),
]),
+ ('Websockets', 'ws', [
+ ('http', False),
+ ('https', False),
+ ('socks4', False),
+ ('socks4a', False),
+ ('socks5', False),
+ ('socks5h', False),
+ ]),
(NoCheckRH, 'http', [('http', False)]),
(HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]),
- ('Websockets', 'ws', [('http', UnsupportedRequest)]),
(NoCheckRH, 'http', [('http', False)]),
(HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]),
]
PROXY_KEY_TESTS = [
- # key, expected to fail
- ('Urllib', [
- ('all', False),
- ('unrelated', False),
+ # proxy key, proxy scheme, expected to fail
+ ('Urllib', 'http', [
+ ('all', 'http', False),
+ ('unrelated', 'http', False),
]),
- ('Requests', [
- ('all', False),
- ('unrelated', False),
+ ('Requests', 'http', [
+ ('all', 'http', False),
+ ('unrelated', 'http', False),
]),
- ('CurlCFFI', [
- ('all', False),
- ('unrelated', False),
+ ('CurlCFFI', 'http', [
+ ('all', 'http', False),
+ ('unrelated', 'http', False),
+ ]),
+ ('Websockets', 'ws', [
+ ('all', 'socks5', False),
+ ('unrelated', 'socks5', False),
]),
- (NoCheckRH, [('all', False)]),
- (HTTPSupportedRH, [('all', UnsupportedRequest)]),
- (HTTPSupportedRH, [('no', UnsupportedRequest)]),
+ (NoCheckRH, 'http', [('all', 'http', False)]),
+ (HTTPSupportedRH, 'http', [('all', 'http', UnsupportedRequest)]),
+ (HTTPSupportedRH, 'http', [('no', 'http', UnsupportedRequest)]),
]
EXTENSION_TESTS = [
@@ -1274,28 +1215,54 @@ class TestRequestHandlerValidation:
]),
]
+    @pytest.mark.parametrize('handler,fail,scheme', [
+        ('Urllib', False, 'http'),
+        ('Requests', False, 'http'),
+        ('CurlCFFI', False, 'http'),
+        ('Websockets', False, 'ws'),
+    ], indirect=['handler'])
+    def test_no_proxy(self, handler, fail, scheme):
+        """Validate the 'no' proxy key, given once on the request and once via handler kwargs."""
+        url = f'{scheme}://example.com'
+        no_proxy = '127.0.0.1,github.com'
+        # Proxies attached to the request itself
+        run_validation(handler, fail, Request(url, proxies={'no': no_proxy}))
+        # Proxies passed through run_validation's keyword arguments
+        run_validation(handler, fail, Request(url), proxies={'no': no_proxy})
+
+    @pytest.mark.parametrize('handler,scheme', [
+        ('Urllib', 'http'),
+        (HTTPSupportedRH, 'http'),
+        ('Requests', 'http'),
+        ('CurlCFFI', 'http'),
+        ('Websockets', 'ws'),
+    ], indirect=['handler'])
+    def test_empty_proxy(self, handler, scheme):
+        """A proxy entry of None for the URL scheme is expected to pass validation."""
+        url = f'{scheme}://'
+        # None given on the request
+        run_validation(handler, False, Request(url, proxies={scheme: None}))
+        # None given through run_validation's keyword arguments
+        run_validation(handler, False, Request(url), proxies={scheme: None})
+
+    @pytest.mark.parametrize('proxy_url', ['//example.com', 'example.com', '127.0.0.1', '/a/b/c'])
+    @pytest.mark.parametrize('handler,scheme', [
+        ('Urllib', 'http'),
+        (HTTPSupportedRH, 'http'),
+        ('Requests', 'http'),
+        ('CurlCFFI', 'http'),
+        ('Websockets', 'ws'),
+    ], indirect=['handler'])
+    def test_invalid_proxy_url(self, handler, scheme, proxy_url):
+        """Each listed proxy URL (none of which has a scheme) is expected to fail with UnsupportedRequest."""
+        request = Request(f'{scheme}://', proxies={scheme: proxy_url})
+        run_validation(handler, UnsupportedRequest, request)
+
@pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [
(handler_tests[0], scheme, fail, handler_kwargs)
for handler_tests in URL_SCHEME_TESTS
for scheme, fail, handler_kwargs in handler_tests[1]
-
], indirect=['handler'])
def test_url_scheme(self, handler, scheme, fail, handler_kwargs):
run_validation(handler, fail, Request(f'{scheme}://'), **(handler_kwargs or {}))
- @pytest.mark.parametrize('handler,fail', [('Urllib', False), ('Requests', False), ('CurlCFFI', False)], indirect=['handler'])
- def test_no_proxy(self, handler, fail):
- run_validation(handler, fail, Request('http://', proxies={'no': '127.0.0.1,github.com'}))
- run_validation(handler, fail, Request('http://'), proxies={'no': '127.0.0.1,github.com'})
-
- @pytest.mark.parametrize('handler,proxy_key,fail', [
- (handler_tests[0], proxy_key, fail)
+ @pytest.mark.parametrize('handler,scheme,proxy_key,proxy_scheme,fail', [
+ (handler_tests[0], handler_tests[1], proxy_key, proxy_scheme, fail)
for handler_tests in PROXY_KEY_TESTS
- for proxy_key, fail in handler_tests[1]
+ for proxy_key, proxy_scheme, fail in handler_tests[2]
], indirect=['handler'])
- def test_proxy_key(self, handler, proxy_key, fail):
- run_validation(handler, fail, Request('http://', proxies={proxy_key: 'http://example.com'}))
- run_validation(handler, fail, Request('http://'), proxies={proxy_key: 'http://example.com'})
+ def test_proxy_key(self, handler, scheme, proxy_key, proxy_scheme, fail):
+ run_validation(handler, fail, Request(f'{scheme}://', proxies={proxy_key: f'{proxy_scheme}://example.com'}))
+ run_validation(handler, fail, Request(f'{scheme}://'), proxies={proxy_key: f'{proxy_scheme}://example.com'})
@pytest.mark.parametrize('handler,req_scheme,scheme,fail', [
(handler_tests[0], handler_tests[1], scheme, fail)
@@ -1306,16 +1273,6 @@ class TestRequestHandlerValidation:
run_validation(handler, fail, Request(f'{req_scheme}://', proxies={req_scheme: f'{scheme}://example.com'}))
run_validation(handler, fail, Request(f'{req_scheme}://'), proxies={req_scheme: f'{scheme}://example.com'})
- @pytest.mark.parametrize('handler', ['Urllib', HTTPSupportedRH, 'Requests', 'CurlCFFI'], indirect=True)
- def test_empty_proxy(self, handler):
- run_validation(handler, False, Request('http://', proxies={'http': None}))
- run_validation(handler, False, Request('http://'), proxies={'http': None})
-
- @pytest.mark.parametrize('proxy_url', ['//example.com', 'example.com', '127.0.0.1', '/a/b/c'])
- @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
- def test_invalid_proxy_url(self, handler, proxy_url):
- run_validation(handler, UnsupportedRequest, Request('http://', proxies={'http': proxy_url}))
-
@pytest.mark.parametrize('handler,scheme,extensions,fail', [
(handler_tests[0], handler_tests[1], extensions, fail)
for handler_tests in EXTENSION_TESTS
diff --git a/test/test_socks.py b/test/test_socks.py
index 43d612d85..20237dc76 100644
--- a/test/test_socks.py
+++ b/test/test_socks.py
@@ -216,7 +216,9 @@ class SocksWebSocketTestRequestHandler(SocksTestRequestHandler):
protocol = websockets.ServerProtocol()
connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0)
connection.handshake()
- connection.send(json.dumps(self.socks_info))
+ for message in connection:
+ if message == 'socks_info':
+ connection.send(json.dumps(self.socks_info))
connection.close()
diff --git a/test/test_websockets.py b/test/test_websockets.py
index b294b0932..bc9f2187a 100644
--- a/test/test_websockets.py
+++ b/test/test_websockets.py
@@ -7,6 +7,7 @@ import sys
import pytest
from test.helper import verify_address_availability
+from yt_dlp.networking.common import Features
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -18,7 +19,7 @@ import random
import ssl
import threading
-from yt_dlp import socks
+from yt_dlp import socks, traverse_obj
from yt_dlp.cookies import YoutubeDLCookieJar
from yt_dlp.dependencies import websockets
from yt_dlp.networking import Request
@@ -114,6 +115,7 @@ def ws_validate_and_send(rh, req):
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
+@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsSocketRequestHandlerConformance:
@classmethod
def setup_class(cls):
@@ -129,7 +131,6 @@ class TestWebsSocketRequestHandlerConformance:
cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'
- @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_basic_websockets(self, handler):
with handler() as rh:
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
@@ -141,7 +142,6 @@ class TestWebsSocketRequestHandlerConformance:
# https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
@pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
- @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_send_types(self, handler, msg, opcode):
with handler() as rh:
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
@@ -149,7 +149,6 @@ class TestWebsSocketRequestHandlerConformance:
assert int(ws.recv()) == opcode
ws.close()
- @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_verify_cert(self, handler):
with handler() as rh:
with pytest.raises(CertificateVerifyError):
@@ -160,14 +159,12 @@ class TestWebsSocketRequestHandlerConformance:
assert ws.status == 101
ws.close()
- @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_ssl_error(self, handler):
with handler(verify=False) as rh:
with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
ws_validate_and_send(rh, Request(self.bad_wss_host))
assert not issubclass(exc_info.type, CertificateVerifyError)
- @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('path,expected', [
# Unicode characters should be encoded with uppercase percent-encoding
('/中文', '/%E4%B8%AD%E6%96%87'),
@@ -182,7 +179,6 @@ class TestWebsSocketRequestHandlerConformance:
assert ws.status == 101
ws.close()
- @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_remove_dot_segments(self, handler):
with handler() as rh:
# This isn't a comprehensive test,
@@ -195,7 +191,6 @@ class TestWebsSocketRequestHandlerConformance:
# We are restricted to known HTTP status codes in http.HTTPStatus
# Redirects are not supported for websockets
- @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
def test_raise_http_error(self, handler, status):
with handler() as rh:
@@ -203,7 +198,6 @@ class TestWebsSocketRequestHandlerConformance:
ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
assert exc_info.value.status == status
- @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('params,extensions', [
({'timeout': sys.float_info.min}, {}),
({}, {'timeout': sys.float_info.min}),
@@ -213,7 +207,6 @@ class TestWebsSocketRequestHandlerConformance:
with pytest.raises(TransportError):
ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
- @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_cookies(self, handler):
cookiejar = YoutubeDLCookieJar()
cookiejar.set_cookie(http.cookiejar.Cookie(
@@ -239,7 +232,6 @@ class TestWebsSocketRequestHandlerConformance:
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
ws.close()
- @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_source_address(self, handler):
source_address = f'127.0.0.{random.randint(5, 255)}'
verify_address_availability(source_address)
@@ -249,7 +241,6 @@ class TestWebsSocketRequestHandlerConformance:
assert source_address == ws.recv()
ws.close()
- @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_response_url(self, handler):
with handler() as rh:
url = f'{self.ws_base_url}/something'
@@ -257,7 +248,6 @@ class TestWebsSocketRequestHandlerConformance:
assert ws.url == url
ws.close()
- @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_request_headers(self, handler):
with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
# Global Headers
@@ -293,7 +283,6 @@ class TestWebsSocketRequestHandlerConformance:
'client_certificate_password': 'foobar',
}
))
- @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_mtls(self, handler, client_cert):
with handler(
# Disable client-side validation of unacceptable self-signed testcert.pem
@@ -303,6 +292,44 @@ class TestWebsSocketRequestHandlerConformance:
) as rh:
ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
+    def test_request_disable_proxy(self, handler):
+        """A per-request proxy of None must override the proxy configured on the handler."""
+        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
+            # Given handler is configured with a proxy under the 'ws' key
+            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
+                # When that same key is explicitly set to None for the request.
+                # (Using a different key such as 'http' would not exercise the None
+                # entry for the scheme the handler-level proxy was registered under.)
+                ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'ws': None}))
+                # Then no proxy should be used
+                assert ws.status == 101
+                ws.close()
+
+    @pytest.mark.skip_handlers_if(
+        lambda _, handler: Features.NO_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support NO_PROXY')
+    def test_noproxy(self, handler):
+        """Hosts named under the request's 'no' proxy key must be connected to without the proxy."""
+        bypass_entries = (f'127.0.0.1:{self.ws_port}', '127.0.0.1', 'localhost')
+        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
+            # Given the handler is configured with a proxy
+            configured_proxy = f'{proxy_proto}://10.255.255.255'
+            with handler(proxies={'ws': configured_proxy}, timeout=5) as rh:
+                for no_proxy in bypass_entries:
+                    # When the request's no-proxy list includes the request URL host
+                    request = Request(self.ws_base_url, proxies={'no': no_proxy})
+                    ws = ws_validate_and_send(rh, request)
+                    # Then the proxy should not be used
+                    assert ws.status == 101
+                    ws.close()
+
+    @pytest.mark.skip_handlers_if(
+        lambda _, handler: Features.ALL_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support ALL_PROXY')
+    def test_allproxy(self, handler):
+        """The 'all' proxy key must be honoured both as a handler default and per request."""
+        supported_proto = traverse_obj(handler._SUPPORTED_PROXY_SCHEMES, 0, default='ws')
+        proxy_url = f'{supported_proto}://10.255.255.255'
+        # This is a bit of a hacky test, but it should be enough to check whether the handler is using the proxy.
+        # 0.1s might not be enough of a timeout if proxy is not used in all cases, but should still get failures.
+
+        # 'all' proxy configured on the handler
+        with handler(proxies={'all': proxy_url}, timeout=0.1) as rh, pytest.raises(TransportError):
+            ws_validate_and_send(rh, Request(self.ws_base_url)).close()
+
+        # 'all' proxy configured on the request
+        with handler(timeout=0.1) as rh, pytest.raises(TransportError):
+            request = Request(self.ws_base_url, proxies={'all': proxy_url})
+            ws_validate_and_send(rh, request).close()
+
def create_fake_ws_connection(raised):
import websockets.sync.client
diff --git a/yt_dlp/YoutubeDL.py b/yt_dlp/YoutubeDL.py
index 9f730d038..d9fed50ff 100644
--- a/yt_dlp/YoutubeDL.py
+++ b/yt_dlp/YoutubeDL.py
@@ -4140,15 +4140,15 @@ class YoutubeDL:
'Use --enable-file-urls to enable at your own risk.', cause=ue) from ue
if (
'unsupported proxy type: "https"' in ue.msg.lower()
- and 'requests' not in self._request_director.handlers
- and 'curl_cffi' not in self._request_director.handlers
+ and 'Requests' not in self._request_director.handlers
+ and 'CurlCFFI' not in self._request_director.handlers
):
raise RequestError(
'To use an HTTPS proxy for this request, one of the following dependencies needs to be installed: requests, curl_cffi')
elif (
re.match(r'unsupported url scheme: "wss?"', ue.msg.lower())
- and 'websockets' not in self._request_director.handlers
+ and 'Websockets' not in self._request_director.handlers
):
raise RequestError(
'This request requires WebSocket support. '
diff --git a/yt_dlp/networking/_curlcffi.py b/yt_dlp/networking/_curlcffi.py
index 39d1f70fb..d2afce089 100644
--- a/yt_dlp/networking/_curlcffi.py
+++ b/yt_dlp/networking/_curlcffi.py
@@ -21,7 +21,7 @@ from .exceptions import (
TransportError,
)
from .impersonate import ImpersonateRequestHandler, ImpersonateTarget
-from ..dependencies import curl_cffi
+from ..dependencies import curl_cffi, certifi
from ..utils import int_or_none
if curl_cffi is None:
@@ -156,6 +156,13 @@ class CurlCFFIRH(ImpersonateRequestHandler, InstanceStoreMixin):
# See: https://curl.se/libcurl/c/CURLOPT_HTTPPROXYTUNNEL.html
session.curl.setopt(CurlOpt.HTTPPROXYTUNNEL, 1)
+ # curl_cffi does not currently set these for proxies
+ session.curl.setopt(CurlOpt.PROXY_CAINFO, certifi.where())
+
+ if not self.verify:
+ session.curl.setopt(CurlOpt.PROXY_SSL_VERIFYPEER, 0)
+ session.curl.setopt(CurlOpt.PROXY_SSL_VERIFYHOST, 0)
+
headers = self._get_impersonate_headers(request)
if self._client_cert:
@@ -203,7 +210,10 @@ class CurlCFFIRH(ImpersonateRequestHandler, InstanceStoreMixin):
max_redirects_exceeded = True
curl_response = e.response
- elif e.code == CurlECode.PROXY:
+ elif (
+ e.code == CurlECode.PROXY
+ or (e.code == CurlECode.RECV_ERROR and 'Received HTTP code 407 from proxy after CONNECT' in str(e))
+ ):
raise ProxyError(cause=e) from e
else:
raise TransportError(cause=e) from e
diff --git a/yt_dlp/networking/_websockets.py b/yt_dlp/networking/_websockets.py
index 6e235b0c6..ab75b2868 100644
--- a/yt_dlp/networking/_websockets.py
+++ b/yt_dlp/networking/_websockets.py
@@ -1,10 +1,13 @@
from __future__ import annotations
+import base64
import contextlib
import io
import logging
import ssl
import sys
+import urllib.parse
+from http.client import HTTPConnection, HTTPResponse
from ._helper import (
create_connection,
@@ -20,12 +23,14 @@ from .exceptions import (
RequestError,
SSLError,
TransportError,
+ UnsupportedRequest,
)
from .websocket import WebSocketRequestHandler, WebSocketResponse
from ..compat import functools
-from ..dependencies import websockets
+from ..dependencies import urllib3, websockets
from ..socks import ProxyError as SocksProxyError
from ..utils import int_or_none
+from ..utils.networking import HTTPHeaderDict
if not websockets:
raise ImportError('websockets is not installed')
@@ -36,6 +41,11 @@ websockets_version = tuple(map(int_or_none, websockets.version.version.split('.'
if websockets_version < (12, 0):
raise ImportError('Only websockets>=12.0 is supported')
+# urllib3 is optional here: its SSLTransport is only needed for WSS over HTTPS proxies,
+# and only urllib3 >= 1.26.17 is treated as supported.
+urllib3_version = tuple(int_or_none(x, default=0) for x in urllib3.__version__.split('.')) if urllib3 else None
+urllib3_supported = bool(urllib3_version and urllib3_version >= (1, 26, 17))
+
import websockets.sync.client
from websockets.uri import parse_uri
@@ -98,7 +108,7 @@ class WebsocketsRH(WebSocketRequestHandler):
https://github.com/python-websockets/websockets
"""
_SUPPORTED_URL_SCHEMES = ('wss', 'ws')
- _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
+ _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h', 'http', 'https')
_SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
RH_NAME = 'websockets'
@@ -108,12 +118,23 @@ class WebsocketsRH(WebSocketRequestHandler):
for name in ('websockets.client', 'websockets.server'):
logger = logging.getLogger(name)
handler = logging.StreamHandler(stream=sys.stdout)
- handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
+ handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: [{name}] %(message)s'))
self.__logging_handlers[name] = handler
logger.addHandler(handler)
if self.verbose:
logger.setLevel(logging.DEBUG)
+    def _validate(self, request):
+        """
+        Run the base validation, then reject wss:// requests routed through an
+        https:// proxy when no supported urllib3 is available.
+
+        Raises UnsupportedRequest in that case.
+        """
+        super()._validate(request)
+        proxy = select_proxy(request.url, self._get_proxies(request))
+        if not proxy:
+            return
+
+        # wss:// through an https:// proxy is the only combination gated on urllib3
+        needs_tls_in_tls = (
+            urllib.parse.urlparse(proxy).scheme.lower() == 'https'
+            and urllib.parse.urlparse(request.url).scheme.lower() == 'wss'
+        )
+        if needs_tls_in_tls and not urllib3_supported:
+            raise UnsupportedRequest('WSS over HTTPS proxies requires a supported version of urllib3')
+
def _check_extensions(self, extensions):
super()._check_extensions(extensions)
extensions.pop('timeout', None)
@@ -125,6 +146,38 @@ class WebsocketsRH(WebSocketRequestHandler):
for name, handler in self.__logging_handlers.items():
logging.getLogger(name).removeHandler(handler)
+    def _make_sock(self, proxy, url, timeout):
+        """
+        Create the connected socket that the websocket handshake will run over.
+
+        @param proxy:   Proxy URL selected for this request, or a falsy value for a direct connection.
+        @param url:     The websocket URL being requested.
+        @param timeout: Timeout passed on to the connection helpers.
+        @returns        A socket connected to the target host, either directly,
+                        through a SOCKS proxy, or through an HTTP(S) CONNECT tunnel.
+        """
+        create_conn_kwargs = {
+            'source_address': (self.source_address, 0) if self.source_address else None,
+            'timeout': timeout
+        }
+        parsed_url = parse_uri(url)
+
+        if not proxy:
+            return create_connection(
+                address=(parsed_url.host, parsed_url.port),
+                **create_conn_kwargs
+            )
+
+        # Only parse the proxy once we know there is one; urlparse(None) yields a bytes result
+        parsed_proxy_url = urllib.parse.urlparse(proxy)
+
+        if parsed_proxy_url.scheme.startswith('socks'):
+            socks_proxy_options = make_socks_proxy_opts(proxy)
+            return create_connection(
+                address=(socks_proxy_options['addr'], socks_proxy_options['port']),
+                _create_socket_func=functools.partial(
+                    create_socks_proxy_socket, (parsed_url.host, parsed_url.port), socks_proxy_options),
+                **create_conn_kwargs
+            )
+
+        if parsed_proxy_url.scheme in ('http', 'https'):
+            # Userinfo in a URL is percent-encoded; decode it before it is used for Basic auth
+            username, password = parsed_proxy_url.username, parsed_proxy_url.password
+            return create_http_connect_conn(
+                proxy_url=proxy,
+                url=url,
+                timeout=timeout,
+                ssl_context=self._make_sslcontext() if parsed_proxy_url.scheme == 'https' else None,
+                source_address=self.source_address,
+                username=urllib.parse.unquote(username) if username is not None else None,
+                password=urllib.parse.unquote(password) if password is not None else None,
+            )
+
+        # NOTE(review): a proxy with any other scheme falls back to a direct connection here;
+        # presumably request validation rejects such proxies before this point - confirm.
+        return create_connection(
+            address=(parsed_url.host, parsed_url.port),
+            **create_conn_kwargs
+        )
+
def _send(self, request):
timeout = self._calculate_timeout(request)
headers = self._merge_headers(request.headers)
@@ -134,33 +187,22 @@ class WebsocketsRH(WebSocketRequestHandler):
if cookie_header:
headers['cookie'] = cookie_header
- wsuri = parse_uri(request.url)
- create_conn_kwargs = {
- 'source_address': (self.source_address, 0) if self.source_address else None,
- 'timeout': timeout
- }
proxy = select_proxy(request.url, self._get_proxies(request))
- try:
- if proxy:
- socks_proxy_options = make_socks_proxy_opts(proxy)
- sock = create_connection(
- address=(socks_proxy_options['addr'], socks_proxy_options['port']),
- _create_socket_func=functools.partial(
- create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
- **create_conn_kwargs
- )
+
+ ssl_context = None
+ if parse_uri(request.url).secure:
+ if WebsocketsSSLContext is not None:
+ ssl_context = WebsocketsSSLContext(self._make_sslcontext())
else:
- sock = create_connection(
- address=(wsuri.host, wsuri.port),
- **create_conn_kwargs
- )
+ ssl_context = self._make_sslcontext()
+ try:
conn = websockets.sync.client.connect(
- sock=sock,
+ sock=self._make_sock(proxy, request.url, timeout),
uri=request.url,
additional_headers=headers,
open_timeout=timeout,
user_agent_header=None,
- ssl_context=self._make_sslcontext() if wsuri.secure else None,
+ ssl_context=ssl_context,
close_timeout=0, # not ideal, but prevents yt-dlp hanging
)
return WebsocketsResponseAdapter(conn, url=request.url)
@@ -185,3 +227,98 @@ class WebsocketsRH(WebSocketRequestHandler):
) from e
except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e:
raise TransportError(cause=e) from e
+
+
+class NoCloseHTTPResponse(HTTPResponse):
+    """HTTPResponse variant that clears will_close for non-chunked responses without a known length."""
+
+    def begin(self):
+        super().begin()
+        # Revert the default behavior of closing the connection after reading the response
+        keep_open = not self._check_close() and not self.chunked and self.length is None
+        if keep_open:
+            self.will_close = False
+
+
+if not urllib3_supported:
+    # Without a supported urllib3 there is no SSLTransport to build on
+    WebsocketsSSLTransport = None
+else:
+    from urllib3.util.ssltransport import SSLTransport
+
+    class WebsocketsSSLTransport(SSLTransport):
+        """
+        urllib3 SSLTransport extended with the extra socket operations websockets uses
+        """
+        def setsockopt(self, *opt_args, **opt_kwargs):
+            # Socket options are applied to the underlying socket
+            self.socket.setsockopt(*opt_args, **opt_kwargs)
+
+        def shutdown(self, *shutdown_args, **shutdown_kwargs):
+            # Unwrap this TLS layer first, then shut down the underlying socket
+            self.unwrap()
+            self.socket.shutdown(*shutdown_args, **shutdown_kwargs)
+
+
+class WebsocketsSSLContext:
+    """
+    Dummy SSL Context for websockets which returns a WebsocketsSSLTransport instance
+    for wrap socket when using TLS-in-TLS.
+
+    Sockets that are not already TLS-wrapped are passed to the wrapped ssl.SSLContext.
+    """
+    def __init__(self, ssl_context: ssl.SSLContext):
+        self.ssl_context = ssl_context
+
+    def wrap_socket(self, sock, server_hostname=None):
+        """
+        Wrap `sock` for TLS, using WebsocketsSSLTransport when `sock` is already an ssl.SSLSocket.
+
+        Raises RequestError if TLS-in-TLS is needed but WebsocketsSSLTransport is unavailable.
+        """
+        if not isinstance(sock, ssl.SSLSocket):
+            return self.ssl_context.wrap_socket(sock, server_hostname=server_hostname)
+
+        # WebsocketsSSLTransport is None when no supported urllib3 is installed;
+        # fail with a clear error rather than "'NoneType' object is not callable"
+        if WebsocketsSSLTransport is None:
+            raise RequestError('TLS-in-TLS requires a supported version of urllib3')
+        return WebsocketsSSLTransport(sock, self.ssl_context, server_hostname=server_hostname)
+
+
+def create_http_connect_conn(
+    proxy_url,
+    url,
+    timeout=None,
+    ssl_context=None,
+    source_address=None,
+    username=None,
+    password=None,
+):
+    """
+    Open a tunnel to the host of `url` through an HTTP(S) proxy using CONNECT.
+
+    @param proxy_url:      URL of the proxy; its hostname and port are connected to.
+    @param url:            Websocket URL whose host and port are requested in the CONNECT.
+    @param timeout:        Timeout for the proxy connection.
+    @param ssl_context:    If given, the proxy connection is wrapped with it before the CONNECT is sent.
+    @param source_address: Optional local address to bind to.
+    @param username:       Optional username for Basic proxy authentication.
+    @param password:       Optional password for Basic proxy authentication.
+    @returns               The connected socket once the proxy answers 200.
+    @raises ProxyError     If the proxy cannot be reached, answers unintelligibly, or answers 407.
+    @raises HTTPError      For any other non-200 response to the CONNECT.
+    """
+    from http.client import HTTPException
+
+    proxy_headers = HTTPHeaderDict()
+
+    if username is not None or password is not None:
+        proxy_headers['Proxy-Authorization'] = 'Basic ' + base64.b64encode(
+            f'{username or ""}:{password or ""}'.encode('utf-8')).decode('utf-8')
+
+    proxy_url_parsed = urllib.parse.urlparse(proxy_url)
+    request_url_parsed = parse_uri(url)
+
+    # HTTPConnection falls back to port 80; an https:// proxy given without a port should use 443
+    proxy_port = proxy_url_parsed.port
+    if proxy_port is None and proxy_url_parsed.scheme == 'https':
+        proxy_port = 443
+
+    conn = HTTPConnection(proxy_url_parsed.hostname, port=proxy_port, timeout=timeout)
+    conn.response_class = NoCloseHTTPResponse
+
+    if hasattr(conn, '_create_connection'):
+        conn._create_connection = create_connection
+
+    if source_address is not None:
+        conn.source_address = (source_address, 0)
+
+    try:
+        conn.connect()
+        if ssl_context:
+            conn.sock = ssl_context.wrap_socket(conn.sock, server_hostname=proxy_url_parsed.hostname)
+        conn.request(
+            method='CONNECT',
+            url=f'{request_url_parsed.host}:{request_url_parsed.port}',
+            headers=proxy_headers)
+        response = conn.getresponse()
+    except (OSError, HTTPException) as e:
+        # HTTPException covers protocol-level failures (e.g. a malformed status line)
+        # that are not OSErrors; without it they would escape and leave conn open
+        conn.close()
+        raise ProxyError('Unable to connect to proxy', cause=e) from e
+
+    if response.status == 200:
+        return conn.sock
+
+    conn.close()
+    if response.status == 407:
+        raise ProxyError('Got HTTP Error 407 with CONNECT: Proxy Authentication Required')
+
+    res_adapter = Response(
+        fp=io.BytesIO(b''),
+        url=proxy_url, headers=response.headers,
+        status=response.status,
+        reason=response.reason)
+    raise HTTPError(response=res_adapter)
diff --git a/yt_dlp/networking/websocket.py b/yt_dlp/networking/websocket.py
index 0e7e73c9e..d407cadad 100644
--- a/yt_dlp/networking/websocket.py
+++ b/yt_dlp/networking/websocket.py
@@ -1,8 +1,9 @@
from __future__ import annotations
import abc
+import urllib.parse
-from .common import RequestHandler, Response
+from .common import RequestHandler, Response, register_preference
class WebSocketResponse(Response):
@@ -21,3 +22,10 @@ class WebSocketResponse(Response):
class WebSocketRequestHandler(RequestHandler, abc.ABC):
pass
+
+
+@register_preference(WebSocketRequestHandler)
+def websocket_preference(_, request):
+    """Return a preference of 200 for ws:// and wss:// request URLs, and 0 for anything else."""
+    request_scheme = urllib.parse.urlparse(request.url).scheme
+    return 200 if request_scheme in ('ws', 'wss') else 0