mirror of https://github.com/rhasspy/piper
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
209 lines
6.5 KiB
Python
209 lines
6.5 KiB
Python
import json
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Iterable, List, Optional, Sequence, Union
|
|
|
|
import torch
|
|
from torch import FloatTensor, LongTensor
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
# Module-level logger, registered under the name "vits.dataset".
_LOGGER = logging.getLogger("vits.dataset")
|
|
|
|
|
|
@dataclass
class Utterance:
    """A single dataset entry described by ids and file paths.

    No tensors are held here; the audio and spectrogram are referenced by
    path only.
    """

    # Phoneme ids of the utterance.
    phoneme_ids: List[int]

    # Path of the file holding the audio tensor.
    audio_norm_path: Path

    # Path of the file holding the spectrogram tensor.
    audio_spec_path: Path

    # Speaker id, or None when the entry does not carry one.
    speaker_id: Optional[int] = None

    # Text of the utterance, or None when the entry does not carry it.
    text: Optional[str] = None
|
|
|
|
|
|
@dataclass
class UtteranceTensors:
    """Tensor form of a single utterance."""

    # Phoneme ids as an integer tensor.
    phoneme_ids: LongTensor

    # Spectrogram tensor; spec_length reads the size of its dimension 1.
    spectrogram: FloatTensor

    # Audio tensor.
    audio_norm: FloatTensor

    # Speaker id tensor, or None when no speaker id is available.
    speaker_id: Optional[LongTensor] = None

    # Text of the utterance, or None when not available.
    text: Optional[str] = None

    @property
    def spec_length(self) -> int:
        """Return the size of the spectrogram along dimension 1."""
        return self.spectrogram.size(1)
|
|
|
|
|
|
@dataclass
class Batch:
    """Padded tensors for a group of utterances, with their true lengths."""

    # Padded phoneme ids, one row per utterance.
    phoneme_ids: LongTensor

    # Unpadded phoneme count of each utterance.
    phoneme_lengths: LongTensor

    # Padded spectrograms, one entry per utterance.
    spectrograms: FloatTensor

    # Unpadded spectrogram length of each utterance.
    spectrogram_lengths: LongTensor

    # Padded audio, one entry per utterance.
    audios: FloatTensor

    # Unpadded audio length of each utterance.
    audio_lengths: LongTensor

    # Speaker id of each utterance, or None when speaker ids are not used.
    speaker_ids: Optional[LongTensor] = None
|
|
|
|
|
|
# @dataclass
# class LarynxDatasetSettings:
#     sample_rate: int
#     is_multispeaker: bool
#     espeak_voice: Optional[str] = None
#     phoneme_map: Dict[str, Optional[List[str]]] = field(default_factory=dict)
#     phoneme_id_map: Dict[str, List[int]] = DEFAULT_PHONEME_ID_MAP
|
|
|
|
|
|
class LarynxDataset(Dataset):
    """
    Dataset of utterances read from one or more text files.

    Each non-empty line of a dataset file is parsed as a JSON object.

    Dataset format:

    * phoneme_ids (required)
    * audio_norm_path (required)
    * audio_spec_path (required)
    * speaker_id (optional)
    * text (optional)
    * phonemes (optional, not read by this class)
    * audio_path (optional, not read by this class)
    """

    def __init__(
        self,
        dataset_paths: List[Union[str, Path]],  # settings: LarynxDatasetSettings
    ):
        """Load every utterance from each file in ``dataset_paths``."""
        # self.settings = settings
        self.utterances: List[Utterance] = []

        for dataset_path in dataset_paths:
            dataset_path = Path(dataset_path)
            _LOGGER.debug("Loading dataset: %s", dataset_path)
            self.utterances.extend(LarynxDataset.load_dataset(dataset_path))

    def __len__(self):
        """Return the number of loaded utterances."""
        return len(self.utterances)

    def __getitem__(self, idx) -> UtteranceTensors:
        """Build the tensors for the utterance at ``idx``.

        The audio and spectrogram are read from disk on every call.
        """
        utt = self.utterances[idx]

        # NOTE(review): torch.load deserializes with pickle unless restricted
        # (e.g. weights_only=True in newer PyTorch releases); only use dataset
        # files whose referenced tensor files are trusted.
        return UtteranceTensors(
            phoneme_ids=LongTensor(utt.phoneme_ids),
            audio_norm=torch.load(utt.audio_norm_path),
            spectrogram=torch.load(utt.audio_spec_path),
            speaker_id=(
                LongTensor([utt.speaker_id]) if utt.speaker_id is not None else None
            ),
            text=utt.text,
        )

    @staticmethod
    def load_dataset(dataset_path: Path) -> Iterable[Utterance]:
        """Yield one Utterance per non-empty line of ``dataset_path``.

        This is a generator, so the file is only opened once iteration starts.
        Lines that fail to load are logged with their 1-based line number and
        skipped (deliberate best-effort), so one bad line does not abort the
        whole dataset.
        """
        with open(dataset_path, "r", encoding="utf-8") as dataset_file:
            for line_num, line in enumerate(dataset_file, start=1):
                line = line.strip()
                if not line:
                    continue

                try:
                    yield LarynxDataset.load_utterance(line)
                except Exception:
                    _LOGGER.exception(
                        "Error on line %s of %s: %s",
                        line_num,
                        dataset_path,
                        line,
                    )

    @staticmethod
    def load_utterance(line: str) -> Utterance:
        """Parse one JSON line into an Utterance.

        Raises whatever ``json.loads`` raises for invalid JSON, and KeyError
        when one of the required keys is missing.
        """
        utt_dict = json.loads(line)
        return Utterance(
            phoneme_ids=utt_dict["phoneme_ids"],
            audio_norm_path=Path(utt_dict["audio_norm_path"]),
            audio_spec_path=Path(utt_dict["audio_spec_path"]),
            speaker_id=utt_dict.get("speaker_id"),
            text=utt_dict.get("text"),
        )
|
|
|
|
|
|
class UtteranceCollate:
    """Callable that pads a sequence of UtteranceTensors into one Batch."""

    def __init__(self, is_multispeaker: bool, segment_size: int):
        """Store the collate settings.

        is_multispeaker: when True, every utterance must carry a speaker id
            and the batch gets a speaker id tensor.
        segment_size: lower bound for the padded audio length.
        """
        self.is_multispeaker = is_multispeaker
        self.segment_size = segment_size

    def __call__(self, utterances: Sequence[UtteranceTensors]) -> Batch:
        """Pad ``utterances`` to common lengths and return them as a Batch.

        Rows of the batch are ordered by decreasing spectrogram length.
        Padding positions are zero. ``speaker_ids`` is None unless
        ``is_multispeaker`` is set.

        Raises AssertionError when ``utterances`` is empty, or when
        ``is_multispeaker`` is set and an utterance has no speaker id.
        """
        num_utterances = len(utterances)
        assert num_utterances > 0, "No utterances"

        max_phonemes_length = 0
        max_spec_length = 0
        max_audio_length = 0

        num_mels = 0

        # Determine lengths
        for utt in utterances:
            assert utt.spectrogram is not None
            assert utt.audio_norm is not None

            phoneme_length = utt.phoneme_ids.size(0)
            spec_length = utt.spectrogram.size(1)
            audio_length = utt.audio_norm.size(1)

            max_phonemes_length = max(max_phonemes_length, phoneme_length)
            max_spec_length = max(max_spec_length, spec_length)
            max_audio_length = max(max_audio_length, audio_length)

            # Taken from the last utterance seen; assumes every spectrogram
            # has the same size along dimension 0.
            num_mels = utt.spectrogram.size(0)
            if self.is_multispeaker:
                assert utt.speaker_id is not None, "Missing speaker id"

        # Padded audio is never shorter than the segment size
        max_audio_length = max(max_audio_length, self.segment_size)

        # Create padded tensors (allocated uninitialized, then zeroed)
        phonemes_padded = LongTensor(num_utterances, max_phonemes_length)
        spec_padded = FloatTensor(num_utterances, num_mels, max_spec_length)
        audio_padded = FloatTensor(num_utterances, 1, max_audio_length)

        phonemes_padded.zero_()
        spec_padded.zero_()
        audio_padded.zero_()

        # Every entry of the length tensors is written in the loop below
        phoneme_lengths = LongTensor(num_utterances)
        spec_lengths = LongTensor(num_utterances)
        audio_lengths = LongTensor(num_utterances)

        speaker_ids: Optional[LongTensor] = None
        if self.is_multispeaker:
            speaker_ids = LongTensor(num_utterances)

        # Sort by decreasing spectrogram length
        sorted_utterances = sorted(
            utterances, key=lambda u: u.spectrogram.size(1), reverse=True
        )
        for utt_idx, utt in enumerate(sorted_utterances):
            phoneme_length = utt.phoneme_ids.size(0)
            spec_length = utt.spectrogram.size(1)
            audio_length = utt.audio_norm.size(1)

            phonemes_padded[utt_idx, :phoneme_length] = utt.phoneme_ids
            phoneme_lengths[utt_idx] = phoneme_length

            spec_padded[utt_idx, :, :spec_length] = utt.spectrogram
            spec_lengths[utt_idx] = spec_length

            audio_padded[utt_idx, :, :audio_length] = utt.audio_norm
            audio_lengths[utt_idx] = audio_length

            # Only record speaker ids when collating for multiple speakers.
            # A stray speaker id on an utterance is ignored otherwise instead
            # of failing; presence was already checked in the first loop.
            if speaker_ids is not None:
                speaker_ids[utt_idx] = utt.speaker_id

        return Batch(
            phoneme_ids=phonemes_padded,
            phoneme_lengths=phoneme_lengths,
            spectrograms=spec_padded,
            spectrogram_lengths=spec_lengths,
            audios=audio_padded,
            audio_lengths=audio_lengths,
            speaker_ids=speaker_ids,
        )
|