From 0b6f829b1dfda15d3c1d7d1fbe4ea6102c26dd24 Mon Sep 17 00:00:00 2001 From: Simon Sawicki Date: Wed, 6 Dec 2023 21:46:45 +0100 Subject: [PATCH] [utils] `traverse_obj`: Move `is_user_input` into output template (#8673) Authored by: Grub4K --- test/test_utils.py | 17 ----------------- yt_dlp/YoutubeDL.py | 14 ++++++++++++-- yt_dlp/utils/traversal.py | 19 ++++++------------- 3 files changed, 18 insertions(+), 32 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 77040f29c..100f11788 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -2317,23 +2317,6 @@ Line 1 self.assertEqual(traverse_obj({}, (0, slice(1)), traverse_string=True), [], msg='branching should result in list if `traverse_string`') - # Test is_user_input behavior - _IS_USER_INPUT_DATA = {'range8': list(range(8))} - self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'), - is_user_input=True), 3, - msg='allow for string indexing if `is_user_input`') - self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'), - is_user_input=True), tuple(range(8))[3:], - msg='allow for string slice if `is_user_input`') - self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'), - is_user_input=True), tuple(range(8))[:4:2], - msg='allow step in string slice if `is_user_input`') - self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'), - is_user_input=True), range(8), - msg='`:` should be treated as `...` if `is_user_input`') - with self.assertRaises(TypeError, msg='too many params should result in error'): - traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), is_user_input=True) - # Test re.Match as input obj mobj = re.fullmatch(r'0(12)(?P3)(4)?', '0123') self.assertEqual(traverse_obj(mobj, ...), [x for x in mobj.groups() if x is not None], diff --git a/yt_dlp/YoutubeDL.py b/yt_dlp/YoutubeDL.py index 29dd76186..0c07866e4 100644 --- a/yt_dlp/YoutubeDL.py +++ b/yt_dlp/YoutubeDL.py @@ -1201,6 +1201,15 @@ class YoutubeDL: (?:\|(?P.*?))? )$''') + def _from_user_input(field): + if field == ':': + return ... + elif ':' in field: + return slice(*map(int_or_none, field.split(':'))) + elif int_or_none(field) is not None: + return int(field) + return field + def _traverse_infodict(fields): fields = [f for x in re.split(r'\.({.+?})\.?', fields) for f in ([x] if x.startswith('{') else x.split('.'))] @@ -1210,11 +1219,12 @@ class YoutubeDL: for i, f in enumerate(fields): if not f.startswith('{'): + fields[i] = _from_user_input(f) continue assert f.endswith('}'), f'No closing brace for {f} in {fields}' - fields[i] = {k: k.split('.') for k in f[1:-1].split(',')} + fields[i] = {k: list(map(_from_user_input, k.split('.'))) for k in f[1:-1].split(',')} - return traverse_obj(info_dict, fields, is_user_input=True, traverse_string=True) + return traverse_obj(info_dict, fields, traverse_string=True) def get_value(mdict): # Object traversal diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py index 462c3ba5d..ff5703198 100644 --- a/yt_dlp/utils/traversal.py +++ b/yt_dlp/utils/traversal.py @@ -8,7 +8,7 @@ from ._utils import ( IDENTITY, NO_DEFAULT, LazyList, - int_or_none, + deprecation_warning, is_iterable_like, try_call, variadic, @@ -17,7 +17,7 @@ from ._utils import ( def traverse_obj( obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True, - casesense=True, is_user_input=False, traverse_string=False): + casesense=True, is_user_input=NO_DEFAULT, traverse_string=False): """ Safely traverse nested `dict`s and `Iterable`s @@ -63,10 +63,8 @@ def traverse_obj( @param get_all If `False`, return the first matching result, otherwise all matching ones. @param casesense If `False`, consider string dictionary keys as case insensitive. - The following are only meant to be used by YoutubeDL.prepare_outtmpl and are not part of the API + `traverse_string` is only meant to be used by YoutubeDL.prepare_outtmpl and is not part of the API - @param is_user_input Whether the keys are generated from user input. - If `True` strings get converted to `int`/`slice` if needed. @param traverse_string Whether to traverse into objects as strings. If `True`, any non-compatible object will first be converted into a string and then traversed into. @@ -80,6 +78,9 @@ def traverse_obj( If no `default` is given and the last path branches, a `list` of results is always returned. If a path ends on a `dict` that result will always be a `dict`. """ + if is_user_input is not NO_DEFAULT: + deprecation_warning('The is_user_input parameter is deprecated and no longer works') + casefold = lambda k: k.casefold() if isinstance(k, str) else k if isinstance(expected_type, type): @@ -195,14 +196,6 @@ def traverse_obj( key = None for last, key in lazy_last(variadic(path, (str, bytes, dict, set))): - if is_user_input and isinstance(key, str): - if key == ':': - key = ... - elif ':' in key: - key = slice(*map(int_or_none, key.split(':'))) - elif int_or_none(key) is not None: - key = int(key) - if not casesense and isinstance(key, str): key = key.casefold()