diff --git a/test/test_mp4parser.py b/test/test_mp4parser.py
new file mode 100644
index 000000000..dbaa61f95
--- /dev/null
+++ b/test/test_mp4parser.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python
+
+# Allow direct execution
+import os
+import sys
+import unittest
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import io
+
+from yt_dlp.postprocessor.mp4direct import (
+    parse_mp4_boxes,
+    write_mp4_boxes,
+)
+
+TEST_SEQUENCE = [
+    ('test', b'123456'),
+    ('trak', b''),
+    ('helo', b'abcdef'),
+    ('1984', b'1q84'),
+    ('moov', b''),
+    ('keys', b'2022'),
+    (None, 'moov'),
+    ('topp', b'1991'),
+    (None, 'trak'),
+]
+
+# on-file representation of the above sequence
+TEST_BYTES = b'\x00\x00\x00\x0etest123456\x00\x00\x00Btrak\x00\x00\x00\x0eheloabcdef\x00\x00\x00\x0c19841q84\x00\x00\x00\x14moov\x00\x00\x00\x0ckeys2022\x00\x00\x00\x0ctopp1991'
+
+
+class TestMP4Parser(unittest.TestCase):
+    def test_write_sequence(self):
+        with io.BytesIO() as w:
+            write_mp4_boxes(w, TEST_SEQUENCE)
+            bs = w.getvalue()
+        self.assertEqual(TEST_BYTES, bs)
+
+    def test_read_bytes(self):
+        with io.BytesIO(TEST_BYTES) as r:
+            result = list(parse_mp4_boxes(r))
+        self.assertListEqual(TEST_SEQUENCE, result)
+
+    def test_mismatched_box_end(self):
+        with io.BytesIO() as w, self.assertRaises(AssertionError):
+            write_mp4_boxes(w, [
+                ('moov', b''),
+                ('trak', b''),
+                (None, 'moov'),
+                (None, 'trak'),
+            ])
diff --git a/yt_dlp/YoutubeDL.py b/yt_dlp/YoutubeDL.py
index 9f730d038..ec5765563 100644
--- a/yt_dlp/YoutubeDL.py
+++ b/yt_dlp/YoutubeDL.py
@@ -57,6 +57,7 @@ from .postprocessor import (
     FFmpegPostProcessor,
     FFmpegVideoConvertorPP,
     MoveFilesAfterDownloadPP,
+    MP4FixupTimestampPP,
     get_postprocessor,
 )
 from .postprocessor.ffmpeg import resolve_mapping as resolve_recode_mapping
@@ -3518,8 +3519,11 @@ class YoutubeDL:
                              and (info_dict.get('is_live') or info_dict.get('is_dash_periods')),
                              'Possible duplicate MOOV atoms', FFmpegFixupDuplicateMoovPP)

+                is_fmp4 = info_dict.get('protocol') == 'websocket_frag' and info_dict.get('container') == 'fmp4'
                 ffmpeg_fixup(downloader == 'web_socket_fragment', 'Malformed timestamps detected', FFmpegFixupTimestampPP)
                 ffmpeg_fixup(downloader == 'web_socket_fragment', 'Malformed duration detected', FFmpegFixupDurationPP)
+                ffmpeg_fixup(downloader == 'web_socket_to_file' and is_fmp4, 'Malformed timestamps detected', MP4FixupTimestampPP)
+                ffmpeg_fixup(downloader == 'web_socket_to_file' and is_fmp4, 'Possible duplicate MOOV atoms', FFmpegFixupDuplicateMoovPP)

             fixup()
             try:
diff --git a/yt_dlp/downloader/__init__.py b/yt_dlp/downloader/__init__.py
index 51a9f28f0..b99374754 100644
--- a/yt_dlp/downloader/__init__.py
+++ b/yt_dlp/downloader/__init__.py
@@ -33,7 +33,7 @@ from .mhtml import MhtmlFD
 from .niconico import NiconicoDmcFD, NiconicoLiveFD
 from .rtmp import RtmpFD
 from .rtsp import RtspFD
-from .websocket import WebSocketFragmentFD
+from .websocket import WebSocketFragmentFD, WebSocketToFileFD
 from .youtube_live_chat import YoutubeLiveChatFD

 PROTOCOL_MAP = {
@@ -121,6 +121,9 @@ def _get_suitable_downloader(info_dict, protocol, params, default):
         elif params.get('hls_prefer_native') is False:
             return FFmpegFD

+    if protocol == 'websocket_frag' and info_dict.get('container') == 'fmp4' and external_downloader != 'ffmpeg':
+        return WebSocketToFileFD
+
     return PROTOCOL_MAP.get(protocol, default)
diff --git a/yt_dlp/downloader/websocket.py b/yt_dlp/downloader/websocket.py
index 6837ff1da..68bfef4d2 100644
--- a/yt_dlp/downloader/websocket.py
+++ b/yt_dlp/downloader/websocket.py
@@ -1,31 +1,41 @@
 import asyncio
 import contextlib
 import os
-import signal
 import threading
+import time

 from .common import FileDownloader
 from .external import FFmpegFD
 from ..dependencies import websockets


-class FFmpegSinkFD(FileDownloader):
+class _WebSocketFD(FileDownloader):
+    async def connect(self, stdin, info_dict):
+        try:
+            await self.real_connection(stdin, info_dict)
+        except OSError:
+            pass
+        finally:
+            with contextlib.suppress(OSError):
+                stdin.flush()
+                stdin.close()
+
+    async def real_connection(self, sink, info_dict):
+        async with websockets.connect(info_dict['url'], extra_headers=info_dict.get('http_headers', {})) as ws:
+            while True:
+                recv = await ws.recv()
+                if isinstance(recv, str):
+                    recv = recv.encode('utf8')
+                sink.write(recv)
+
+
+class WebSocketFragmentFD(_WebSocketFD):
     """ A sink to ffmpeg for downloading fragments in any form """

     def real_download(self, filename, info_dict):
         info_copy = info_dict.copy()
         info_copy['url'] = '-'
-
-        async def call_conn(proc, stdin):
-            try:
-                await self.real_connection(stdin, info_dict)
-            except OSError:
-                pass
-            finally:
-                with contextlib.suppress(OSError):
-                    stdin.flush()
-                    stdin.close()
-                os.kill(os.getpid(), signal.SIGINT)
+        connect = self.connect

         class FFmpegStdinFD(FFmpegFD):
             @classmethod
@@ -33,21 +43,51 @@ class FFmpegSinkFD(FileDownloader):
                 return FFmpegFD.get_basename()

             def on_process_started(self, proc, stdin):
-                thread = threading.Thread(target=asyncio.run, daemon=True, args=(call_conn(proc, stdin), ))
+                thread = threading.Thread(target=asyncio.run, daemon=True, args=(connect(stdin, info_dict), ))
                 thread.start()

         return FFmpegStdinFD(self.ydl, self.params or {}).download(filename, info_copy)

-    async def real_connection(self, sink, info_dict):
-        """ Override this in subclasses """
-        raise NotImplementedError('This method must be implemented by subclasses')

+class WebSocketToFileFD(_WebSocketFD):
+    """ A sink to a file for downloading fragments in any form """
+    def real_download(self, filename, info_dict):
+        tempname = self.temp_name(filename)
+        try:
+            with open(tempname, 'wb') as w:
+                started = time.time()
+                status = {
+                    'filename': info_dict.get('_filename'),
+                    'status': 'downloading',
+                    'elapsed': 0,
+                    'downloaded_bytes': 0,
+                }
+                self._hook_progress(status, info_dict)

-class WebSocketFragmentFD(FFmpegSinkFD):
-    async def real_connection(self, sink, info_dict):
-        async with websockets.connect(info_dict['url'], extra_headers=info_dict.get('http_headers', {})) as ws:
-            while True:
-                recv = await ws.recv()
-                if isinstance(recv, str):
-                    recv = recv.encode('utf8')
-                sink.write(recv)
+                thread = threading.Thread(target=asyncio.run, daemon=True, args=(self.connect(w, info_dict), ))
+                thread.start()
+                time_and_size, avg_len = [], 10
+                while thread.is_alive():
+                    time.sleep(0.1)
+
+                    downloaded, curr = w.tell(), time.time()
+                    # rolling average of speed over the last few samples (borrowed from the ffmpeg downloader)
+                    time_and_size.append((downloaded, curr))
+                    time_and_size = time_and_size[-avg_len:]
+                    if len(time_and_size) > 1:
+                        earliest, latest = time_and_size[0], time_and_size[-1]
+                        average_speed = (latest[0] - earliest[0]) / (latest[1] - earliest[1])
+                    else:
+                        average_speed = None
+
+                    status.update({
+                        'downloaded_bytes': downloaded,
+                        'speed': average_speed,
+                        'elapsed': curr - started,
+                    })
+                    self._hook_progress(status, info_dict)
+        except KeyboardInterrupt:
+            pass
+        finally:
+            os.replace(tempname, filename)
+        return True
diff --git a/yt_dlp/extractor/twitcasting.py b/yt_dlp/extractor/twitcasting.py
index 28ea16cc2..3fce53164 100644
--- a/yt_dlp/extractor/twitcasting.py
+++ b/yt_dlp/extractor/twitcasting.py
@@ -195,6 +195,7 @@ class TwitCastingIE(InfoExtractor):
                 'source_preference': -10,
                 # TwitCasting simply sends moof atom directly over WS
                 'protocol': 'websocket_frag',
+                'container': 'fmp4',
             })

         infodict = {
diff --git a/yt_dlp/postprocessor/__init__.py b/yt_dlp/postprocessor/__init__.py
index bfe9df733..c847b0713 100644
--- a/yt_dlp/postprocessor/__init__.py
+++ b/yt_dlp/postprocessor/__init__.py
@@ -30,6 +30,7 @@ from .metadataparser import (
 )
 from .modify_chapters import ModifyChaptersPP
 from .movefilesafterdownload import MoveFilesAfterDownloadPP
+from .mp4direct import MP4FixupTimestampPP
 from .sponskrub import SponSkrubPP
 from .sponsorblock import SponsorBlockPP
 from .xattrpp import XAttrMetadataPP
diff --git a/yt_dlp/postprocessor/mp4direct.py b/yt_dlp/postprocessor/mp4direct.py
new file mode 100644
index 000000000..70e50907a
--- /dev/null
+++ b/yt_dlp/postprocessor/mp4direct.py
@@ -0,0 +1,277 @@
+import os
+import struct
+
+from io import BytesIO, RawIOBase
+from math import inf
+from typing import Tuple
+
+from .common import PostProcessor
+from ..utils import prepend_extension
+
+
+class LengthLimiter(RawIOBase):
+    """
+    A bytes IO that limits the number of bytes that can be read.
+    """
+
+    def __init__(self, r: RawIOBase, size: int):
+        self.r = r
+        self.remaining = size
+
+    def read(self, sz: int = None) -> bytes:
+        if self.remaining == 0:
+            return b''
+        if sz in (-1, None):
+            sz = self.remaining
+        sz = min(sz, self.remaining)
+        ret = self.r.read(sz)
+        if ret:
+            self.remaining -= len(ret)
+        return ret
+
+    def readall(self) -> bytes:
+        if self.remaining == 0:
+            return b''
+        # read() already decrements self.remaining
+        return self.read(self.remaining)
+
+    def readable(self) -> bool:
+        return bool(self.remaining)
+
+
+def read_harder(r, size):
+    """
+    Try to read from the stream, retrying on short reads.
+
+    @params r    byte stream to read
+    @params size number of bytes to read in total
+    """
+
+    retry = 0
+    buf = b''
+    while len(buf) < size and retry < 3:
+        ret = r.read(size - len(buf))
+        if not ret:
+            retry += 1
+            continue
+        retry = 0
+        buf += ret
+
+    return buf
+
+
+def pack_be32(value: int) -> bytes:
+    """ Pack value to 4-byte-long bytes in the big-endian byte order """
+    return struct.pack('>I', value)
+
+
+def pack_be64(value: int) -> bytes:
+    """ Pack value to 8-byte-long bytes in the big-endian byte order """
+    return struct.pack('>Q', value)
+
+
+def unpack_be32(value: bytes) -> int:
+    """ Convert 4-byte-long bytes in the big-endian byte order, to an integer value """
+    return struct.unpack('>I', value)[0]
+
+
+def unpack_be64(value: bytes) -> int:
+    """ Convert 8-byte-long bytes in the big-endian byte order, to an integer value """
+    return struct.unpack('>Q', value)[0]
+
+
+def unpack_ver_flags(value: bytes) -> Tuple[int, int]:
+    """
+    Unpack a 4-byte-long value into version and flags.
+    @returns (version, flags)
+    """
+
+    ver, up_flag, down_flag = struct.unpack('>BBH', value)
+    return ver, (up_flag << 16 | down_flag)
+
+
+# https://github.com/gpac/mp4box.js/blob/4e1bc23724d2603754971abc00c2bd5aede7be60/src/box.js#L13-L40
+MP4_CONTAINER_BOXES = ('moov', 'trak', 'edts', 'mdia', 'minf', 'dinf', 'stbl', 'mvex', 'moof', 'traf', 'vttc', 'tref', 'iref', 'mfra', 'meco', 'hnti', 'hinf', 'strk', 'strd', 'sinf', 'rinf', 'schi', 'trgr', 'udta', 'iprp', 'ipco')
+""" List of boxes that nest other boxes """
+
+
+def parse_mp4_boxes(r: RawIOBase):
+    """
+    Parses an ISO BMFF stream (which MP4 follows) and yields its boxes as a sequence.
+    This does not interpret the content of the boxes.
+
+    Sequence details:
+    ('atom', b'blablabla'): a box, with content (not a container box)
+    ('atom', b''):          possibly a container box (check MP4_CONTAINER_BOXES) or a really empty box
+    (None, 'atom'):         end of a container box
+
+    Example:              Path:
+    ('test', b'123456')   /test
+    ('moov', b'')         /moov           (start of container box)
+    ('helo', b'abcdef')   /moov/helo
+    ('1984', b'1q84')     /moov/1984
+    ('trak', b'')         /moov/trak      (start of container box)
+    ('keys', b'2022')     /moov/trak/keys
+    (None , 'trak')       /moov/trak      (end of container box)
+    ('topp', b'1991')     /moov/topp
+    (None , 'moov')       /moov           (end of container box)
+    """
+
+    while True:
+        size_b = read_harder(r, 4)
+        if not size_b:
+            break
+        type_b = r.read(4)
+        # the box size is a 32-bit big-endian integer (e.g. 00 00 00 20 == 32 bytes)
+        box_size = unpack_be32(size_b)
+        type_s = type_b.decode()
+        if type_s in MP4_CONTAINER_BOXES:
+            yield (type_s, b'')
+            yield from parse_mp4_boxes(LengthLimiter(r, box_size - 8))
+            yield (None, type_s)
+            continue
+        # subtract 8 bytes for the size and type fields that were already read
+        full_body = read_harder(r, box_size - 8)
+        yield (type_s, full_body)
+
+
+def write_mp4_boxes(w: RawIOBase, box_iter):
+    """
+    Writes an ISO BMFF file from a given sequence to a given writer.
+    The iterator passed in must follow the protocol described in parse_mp4_boxes.
+    """
+
+    stack = [
+        (None, w),  # parent box type, IO
+    ]
+    for btype, content in box_iter:
+        if btype in MP4_CONTAINER_BOXES:
+            bio = BytesIO()
+            stack.append((btype, bio))
+            continue
+        elif btype is None:
+            assert stack[-1][0] == content
+            btype, bio = stack.pop()
+            content = bio.getvalue()
+
+        wt = stack[-1][1]
+        wt.write(pack_be32(len(content) + 8))
+        wt.write(btype.encode()[:4])
+        wt.write(content)
+
+
+class MP4FixupTimestampPP(PostProcessor):
+
+    @property
+    def available(self):
+        return True
+
+    def analyze_mp4(self, filepath):
+        """ returns (baseMediaDecodeTime offset, sample duration cutoff) """
+        smallest_bmdt, known_sdur = inf, set()
+        with open(filepath, 'rb') as r:
+            for btype, content in parse_mp4_boxes(r):
+                if btype == 'tfdt':
+                    version, _ = unpack_ver_flags(content[0:4])
+                    # baseMediaDecodeTime always comes first in the box body
+                    if version == 0:
+                        bmdt = unpack_be32(content[4:8])
+                    else:
+                        bmdt = unpack_be64(content[4:12])
+                    if bmdt == 0:
+                        continue
+                    smallest_bmdt = min(bmdt, smallest_bmdt)
+                elif btype == 'tfhd':
+                    version, flags = unpack_ver_flags(content[0:4])
+                    if not flags & 0x08:
+                        # this box does not contain "sample duration"
+                        continue
+                    # https://github.com/gpac/mp4box.js/blob/4e1bc23724d2603754971abc00c2bd5aede7be60/src/box.js#L203-L209
+                    # https://github.com/gpac/mp4box.js/blob/4e1bc23724d2603754971abc00c2bd5aede7be60/src/parsing/tfhd.js
+                    sdur_start = 8  # version/flags + track id
+                    if flags & 0x01:
+                        sdur_start += 8
+                    if flags & 0x02:
+                        sdur_start += 4
+                    # the next 4 bytes are "sample duration"
+                    sample_dur = unpack_be32(content[sdur_start:sdur_start + 4])
+                    known_sdur.add(sample_dur)
+
+        maximum_sdur = max(known_sdur, default=inf)  # no sample durations found -> leave the cutoff at inf
+        for multiplier in (0.7, 0.8, 0.9, 0.95):
+            sdur_cutoff = maximum_sdur * multiplier
+            if len(set(x for x in known_sdur if x > sdur_cutoff)) < 3:
+                break
+        else:
+            sdur_cutoff = inf
+
+        return smallest_bmdt, sdur_cutoff
+
+    @staticmethod
+    def transform(r, bmdt_offset, sdur_cutoff):
+        for btype, content in r:
+            if btype == 'tfdt':
+                version, _ = unpack_ver_flags(content[0:4])
+                if version == 0:
+                    bmdt = unpack_be32(content[4:8])
+                else:
+                    bmdt = unpack_be64(content[4:12])
+                if bmdt == 0:
+                    yield (btype, content)
+                    continue
+                # calculate the new baseMediaDecodeTime
+                bmdt = max(0, bmdt - bmdt_offset)
+                # pack everything again and emit the updated box
+                if version == 0:
+                    bmdt_b = pack_be32(bmdt)
+                else:
+                    bmdt_b = pack_be64(bmdt)
+                yield ('tfdt', content[0:4] + bmdt_b + content[8 + version * 4:])
+                continue
+            elif btype == 'tfhd':
+                version, flags = unpack_ver_flags(content[0:4])
+                if not flags & 0x08:
+                    yield (btype, content)
+                    continue
+                sdur_start = 8
+                if flags & 0x01:
+                    sdur_start += 8
+                if flags & 0x02:
+                    sdur_start += 4
+                sample_dur = unpack_be32(content[sdur_start:sdur_start + 4])
+                if sample_dur > sdur_cutoff:
+                    sample_dur = 0
+                sd_b = pack_be32(sample_dur)
+                yield ('tfhd', content[:sdur_start] + sd_b + content[sdur_start + 4:])
+                continue
+            yield (btype, content)
+
+    def modify_mp4(self, src, dst, bmdt_offset, sdur_cutoff):
+        with open(src, 'rb') as r, open(dst, 'wb') as w:
+            write_mp4_boxes(w, self.transform(parse_mp4_boxes(r), bmdt_offset, sdur_cutoff))
+
+    def run(self, information):
+        filename = information['filepath']
+        temp_filename = prepend_extension(filename, 'temp')
+
+        self.write_debug('Analyzing MP4')
+        bmdt_offset, sdur_cutoff = self.analyze_mp4(filename)
+        working = inf not in (bmdt_offset, sdur_cutoff)
+        # if either value is Infinity, analysis failed:
+        # the baseMediaDecodeTime offset is what shifts the PTS,
+        # and the sample duration cutoff is what fixes the duration of each segment
+        self.write_debug(f'baseMediaDecodeTime offset = {bmdt_offset}, sample duration cutoff = {sdur_cutoff}')
+        if bmdt_offset == inf:
+            # safeguard
+            bmdt_offset = 0
+        self.modify_mp4(filename, temp_filename, bmdt_offset, sdur_cutoff)
+        if working:
+            self.to_screen('Duration of the file has been fixed')
+        else:
+            self.report_warning(f'Failed to fix duration of the file. (baseMediaDecodeTime offset = {bmdt_offset}, sample duration cutoff = {sdur_cutoff})')
+
+        os.replace(temp_filename, filename)
+
+        return [], information
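Note for reviewers (not part of the patch): a minimal sketch of how the box-level API added in yt_dlp/postprocessor/mp4direct.py can be driven outside the post-processor, e.g. to inspect the box tree of a downloaded fragmented MP4. It assumes this patch is applied and yt_dlp is importable; the input path is whatever file you pass on the command line.

#!/usr/bin/env python
# Illustrative sketch only: print the box tree of an MP4/fMP4 file
# using parse_mp4_boxes() introduced by this patch.
import sys

from yt_dlp.postprocessor.mp4direct import MP4_CONTAINER_BOXES, parse_mp4_boxes

path = []
with open(sys.argv[1], 'rb') as r:  # e.g. a TwitCasting download saved by WebSocketToFileFD
    for btype, content in parse_mp4_boxes(r):
        if btype is None:
            # (None, 'atom') closes the most recently opened container box
            path.pop()
            continue
        print('/'.join(path + [btype]), f'({len(content)} bytes)')
        if btype in MP4_CONTAINER_BOXES:
            path.append(btype)  # descend into the container

The parse/write round trip itself is already exercised by test/test_mp4parser.py above.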