diff --git a/test/test_download.py b/test/test_download.py index 43b39c36b..fd7752cdd 100755 --- a/test/test_download.py +++ b/test/test_download.py @@ -10,10 +10,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import collections import hashlib -import http.client import json -import socket -import urllib.error from test.helper import ( assertGreaterEqual, @@ -29,6 +26,7 @@ from test.helper import ( import yt_dlp.YoutubeDL # isort: split from yt_dlp.extractor import get_info_extractor +from yt_dlp.networking.exceptions import HTTPError, TransportError from yt_dlp.utils import ( DownloadError, ExtractorError, @@ -162,8 +160,7 @@ def generator(test_case, tname): force_generic_extractor=params.get('force_generic_extractor', False)) except (DownloadError, ExtractorError) as err: # Check if the exception is not a network related one - if (err.exc_info[0] not in (urllib.error.URLError, socket.timeout, UnavailableVideoError, http.client.BadStatusLine) - or (err.exc_info[0] == urllib.error.HTTPError and err.exc_info[1].code == 503)): + if not isinstance(err.exc_info[1], (TransportError, UnavailableVideoError)) or (isinstance(err.exc_info[1], HTTPError) and err.exc_info[1].code == 503): err.msg = f'{getattr(err, "msg", err)} ({tname})' raise @@ -249,7 +246,7 @@ def generator(test_case, tname): # extractor returns full results even with extract_flat res_tcs = [{'info_dict': e} for e in res_dict['entries']] try_rm_tcs_files(res_tcs) - + ydl.close() return test_template diff --git a/test/test_networking.py b/test/test_networking.py index e4e66dce1..147a4ff49 100644 --- a/test/test_networking.py +++ b/test/test_networking.py @@ -3,32 +3,74 @@ # Allow direct execution import os import sys -import unittest + +import pytest sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import functools import gzip +import http.client import http.cookiejar import http.server +import inspect import io import pathlib +import random import ssl import tempfile import threading +import time import urllib.error import urllib.request +import warnings import zlib +from email.message import Message +from http.cookiejar import CookieJar -from test.helper import http_server_port -from yt_dlp import YoutubeDL +from test.helper import FakeYDL, http_server_port from yt_dlp.dependencies import brotli -from yt_dlp.utils import sanitized_Request, urlencode_postdata - -from .helper import FakeYDL +from yt_dlp.networking import ( + HEADRequest, + PUTRequest, + Request, + RequestDirector, + RequestHandler, + Response, +) +from yt_dlp.networking._urllib import UrllibRH +from yt_dlp.networking.common import _REQUEST_HANDLERS +from yt_dlp.networking.exceptions import ( + CertificateVerifyError, + HTTPError, + IncompleteRead, + NoSupportingHandlers, + RequestError, + SSLError, + TransportError, + UnsupportedRequest, +) +from yt_dlp.utils._utils import _YDLLogger as FakeLogger +from yt_dlp.utils.networking import HTTPHeaderDict TEST_DIR = os.path.dirname(os.path.abspath(__file__)) +def _build_proxy_handler(name): + class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler): + proxy_name = name + + def log_message(self, format, *args): + pass + + def do_GET(self): + self.send_response(200) + self.send_header('Content-Type', 'text/plain; charset=utf-8') + self.end_headers() + self.wfile.write('{self.proxy_name}: {self.path}'.format(self=self).encode()) + return HTTPTestRequestHandler + + class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler): protocol_version = 'HTTP/1.1' @@ -36,7 +78,7 @@ class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler): pass def _headers(self): - payload = str(self.headers).encode('utf-8') + payload = str(self.headers).encode() self.send_response(200) self.send_header('Content-Type', 'application/json') self.send_header('Content-Length', str(len(payload))) @@ -70,7 +112,7 @@ class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler): return self.rfile.read(int(self.headers['Content-Length'])) def do_POST(self): - data = self._read_data() + data = self._read_data() + str(self.headers).encode() if self.path.startswith('/redirect_'): self._redirect() elif self.path.startswith('/method'): @@ -89,7 +131,7 @@ class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler): self._status(404) def do_PUT(self): - data = self._read_data() + data = self._read_data() + str(self.headers).encode() if self.path.startswith('/redirect_'): self._redirect() elif self.path.startswith('/method'): @@ -102,7 +144,7 @@ class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler): payload = b'