piper/src/python/larynx_train/vits/dataset.py

import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, Optional, Sequence, Union

import torch
from torch import FloatTensor, LongTensor
from torch.utils.data import Dataset

_LOGGER = logging.getLogger("vits.dataset")


@dataclass
class Utterance:
    phoneme_ids: List[int]
    audio_norm_path: Path
    audio_spec_path: Path
    speaker_id: Optional[int] = None
    text: Optional[str] = None


@dataclass
class UtteranceTensors:
    phoneme_ids: LongTensor
    spectrogram: FloatTensor  # shape: (num_mels, num_frames)
    audio_norm: FloatTensor  # shape: (1, num_samples)
    speaker_id: Optional[LongTensor] = None
    text: Optional[str] = None

    @property
    def spec_length(self) -> int:
        return self.spectrogram.size(1)


@dataclass
class Batch:
    phoneme_ids: LongTensor
    phoneme_lengths: LongTensor
    spectrograms: FloatTensor
    spectrogram_lengths: LongTensor
    audios: FloatTensor
    audio_lengths: LongTensor
    speaker_ids: Optional[LongTensor] = None


# @dataclass
# class LarynxDatasetSettings:
#     sample_rate: int
#     is_multispeaker: bool
#     espeak_voice: Optional[str] = None
#     phoneme_map: Dict[str, Optional[List[str]]] = field(default_factory=dict)
#     phoneme_id_map: Dict[str, List[int]] = DEFAULT_PHONEME_ID_MAP


class LarynxDataset(Dataset):
    """
    Dataset format:
    * phoneme_ids (required)
    * audio_norm_path (required)
    * audio_spec_path (required)
    * text (optional)
    * phonemes (optional)
    * audio_path (optional)
    """

    def __init__(
        self,
        dataset_paths: List[Union[str, Path]],
        # settings: LarynxDatasetSettings
    ):
        # self.settings = settings
        self.utterances: List[Utterance] = []

        for dataset_path in dataset_paths:
            dataset_path = Path(dataset_path)
            _LOGGER.debug("Loading dataset: %s", dataset_path)
            self.utterances.extend(LarynxDataset.load_dataset(dataset_path))

    def __len__(self):
        return len(self.utterances)

    def __getitem__(self, idx) -> UtteranceTensors:
        utt = self.utterances[idx]

        # audio_norm and spectrogram are pre-computed tensors loaded from disk
        return UtteranceTensors(
            phoneme_ids=LongTensor(utt.phoneme_ids),
            audio_norm=torch.load(utt.audio_norm_path),
            spectrogram=torch.load(utt.audio_spec_path),
            speaker_id=LongTensor([utt.speaker_id])
            if utt.speaker_id is not None
            else None,
            text=utt.text,
        )

    @staticmethod
    def load_dataset(dataset_path: Path) -> Iterable[Utterance]:
        with open(dataset_path, "r", encoding="utf-8") as dataset_file:
            for line_idx, line in enumerate(dataset_file):
                line = line.strip()
                if not line:
                    continue

                try:
                    yield LarynxDataset.load_utterance(line)
                except Exception:
                    # Log and skip malformed lines instead of aborting the load
                    _LOGGER.exception(
                        "Error on line %s of %s: %s",
                        line_idx + 1,
                        dataset_path,
                        line,
                    )

    @staticmethod
    def load_utterance(line: str) -> Utterance:
        utt_dict = json.loads(line)
        return Utterance(
            phoneme_ids=utt_dict["phoneme_ids"],
            audio_norm_path=Path(utt_dict["audio_norm_path"]),
            audio_spec_path=Path(utt_dict["audio_spec_path"]),
            speaker_id=utt_dict.get("speaker_id"),
            text=utt_dict.get("text"),
        )


class UtteranceCollate:
    def __init__(self, is_multispeaker: bool, segment_size: int):
        self.is_multispeaker = is_multispeaker
        self.segment_size = segment_size

    def __call__(self, utterances: Sequence[UtteranceTensors]) -> Batch:
        num_utterances = len(utterances)
        assert num_utterances > 0, "No utterances"

        max_phonemes_length = 0
        max_spec_length = 0
        max_audio_length = 0
        num_mels = 0

        # Determine lengths
        for utt_idx, utt in enumerate(utterances):
            assert utt.spectrogram is not None
            assert utt.audio_norm is not None

            phoneme_length = utt.phoneme_ids.size(0)
            spec_length = utt.spectrogram.size(1)
            audio_length = utt.audio_norm.size(1)

            max_phonemes_length = max(max_phonemes_length, phoneme_length)
            max_spec_length = max(max_spec_length, spec_length)
            max_audio_length = max(max_audio_length, audio_length)

            num_mels = utt.spectrogram.size(0)

            if self.is_multispeaker:
                assert utt.speaker_id is not None, "Missing speaker id"

        # Audio cannot be smaller than segment size (8192)
        max_audio_length = max(max_audio_length, self.segment_size)

        # Create padded tensors
        phonemes_padded = LongTensor(num_utterances, max_phonemes_length)
        spec_padded = FloatTensor(num_utterances, num_mels, max_spec_length)
        audio_padded = FloatTensor(num_utterances, 1, max_audio_length)

        phonemes_padded.zero_()
        spec_padded.zero_()
        audio_padded.zero_()

        phoneme_lengths = LongTensor(num_utterances)
        spec_lengths = LongTensor(num_utterances)
        audio_lengths = LongTensor(num_utterances)

        speaker_ids: Optional[LongTensor] = None
        if self.is_multispeaker:
            speaker_ids = LongTensor(num_utterances)

        # Sort by decreasing spectrogram length
        sorted_utterances = sorted(
            utterances, key=lambda u: u.spectrogram.size(1), reverse=True
        )

        for utt_idx, utt in enumerate(sorted_utterances):
            phoneme_length = utt.phoneme_ids.size(0)
            spec_length = utt.spectrogram.size(1)
            audio_length = utt.audio_norm.size(1)

            phonemes_padded[utt_idx, :phoneme_length] = utt.phoneme_ids
            phoneme_lengths[utt_idx] = phoneme_length

            spec_padded[utt_idx, :, :spec_length] = utt.spectrogram
            spec_lengths[utt_idx] = spec_length

            audio_padded[utt_idx, :, :audio_length] = utt.audio_norm
            audio_lengths[utt_idx] = audio_length

            if utt.speaker_id is not None:
                assert speaker_ids is not None
                speaker_ids[utt_idx] = utt.speaker_id

        return Batch(
            phoneme_ids=phonemes_padded,
            phoneme_lengths=phoneme_lengths,
            spectrograms=spec_padded,
            spectrogram_lengths=spec_lengths,
            audios=audio_padded,
            audio_lengths=audio_lengths,
            speaker_ids=speaker_ids,
        )
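

# A minimal usage sketch, not part of the original module: it shows how the
# dataset and collate function might be wired into a DataLoader. The file
# name "train_dataset.jsonl", batch_size=16, and segment_size=8192 are
# illustrative assumptions (8192 matches the segment-size comment above);
# the real training script may configure these differently.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = LarynxDataset(["train_dataset.jsonl"])
    collate_fn = UtteranceCollate(is_multispeaker=False, segment_size=8192)
    loader = DataLoader(
        dataset, batch_size=16, shuffle=True, collate_fn=collate_fn
    )

    for batch in loader:
        # Each Batch holds zero-padded tensors plus the true lengths:
        #   batch.phoneme_ids:  (B, max_phoneme_len)
        #   batch.spectrograms: (B, num_mels, max_spec_len)
        #   batch.audios:       (B, 1, max_audio_len)
        print(batch.phoneme_lengths, batch.spectrogram_lengths)
        break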