# piper/src/python/larynx_train/vits/lightning.py

import logging
from pathlib import Path
from typing import List, Optional, Tuple, Union

import pytorch_lightning as pl
import torch
from torch import autocast
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split

from .commons import slice_segments
from .dataset import Batch, LarynxDataset, UtteranceCollate
from .losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from .mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from .models import MultiPeriodDiscriminator, SynthesizerTrn

_LOGGER = logging.getLogger("vits.lightning")


class VitsModel(pl.LightningModule):
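    """VITS text-to-speech model (generator + multi-period discriminator) as a PyTorch Lightning module."""
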
    def __init__(
        self,
        num_symbols: int,
        num_speakers: int,
        # audio
        resblock="2",
        resblock_kernel_sizes=(3, 5, 7),
        resblock_dilation_sizes=(
            (1, 2),
            (2, 6),
            (3, 12),
        ),
        upsample_rates=(8, 8, 4),
        upsample_initial_channel=256,
        upsample_kernel_sizes=(16, 16, 8),
        # mel
        filter_length: int = 1024,
        hop_length: int = 256,
        win_length: int = 1024,
        mel_channels: int = 80,
        sample_rate: int = 22050,
        sample_bytes: int = 2,
        channels: int = 1,
        mel_fmin: float = 0.0,
        mel_fmax: Optional[float] = None,
        # model
        inter_channels: int = 192,
        hidden_channels: int = 192,
        filter_channels: int = 768,
        n_heads: int = 2,
        n_layers: int = 6,
        kernel_size: int = 3,
        p_dropout: float = 0.1,
        n_layers_q: int = 3,
        use_spectral_norm: bool = False,
        gin_channels: int = 0,
        use_sdp: bool = True,
        segment_size: int = 8192,
        # training
        dataset: Optional[List[Union[str, Path]]] = None,
        learning_rate: float = 2e-4,
        betas: Tuple[float, float] = (0.8, 0.99),
        eps: float = 1e-9,
        batch_size: int = 1,
        lr_decay: float = 0.999875,
        init_lr_ratio: float = 1.0,
        warmup_epochs: int = 0,
        c_mel: int = 45,
        c_kl: float = 1.0,
        grad_clip: Optional[float] = None,
        num_workers: int = 1,
        seed: int = 1234,
        num_test_examples: int = 5,
        validation_split: float = 0.1,
        **kwargs
    ):
        super().__init__()
        self.save_hyperparameters()

        if (self.hparams.num_speakers > 1) and (self.hparams.gin_channels <= 0):
            # Default gin_channels for multi-speaker model
            self.hparams.gin_channels = 512

        # Set up models
        self.model_g = SynthesizerTrn(
            n_vocab=self.hparams.num_symbols,
            spec_channels=self.hparams.filter_length // 2 + 1,
            segment_size=self.hparams.segment_size // self.hparams.hop_length,
            inter_channels=self.hparams.inter_channels,
            hidden_channels=self.hparams.hidden_channels,
            filter_channels=self.hparams.filter_channels,
            n_heads=self.hparams.n_heads,
            n_layers=self.hparams.n_layers,
            kernel_size=self.hparams.kernel_size,
            p_dropout=self.hparams.p_dropout,
            resblock=self.hparams.resblock,
            resblock_kernel_sizes=self.hparams.resblock_kernel_sizes,
            resblock_dilation_sizes=self.hparams.resblock_dilation_sizes,
            upsample_rates=self.hparams.upsample_rates,
            upsample_initial_channel=self.hparams.upsample_initial_channel,
            upsample_kernel_sizes=self.hparams.upsample_kernel_sizes,
            n_speakers=self.hparams.num_speakers,
            gin_channels=self.hparams.gin_channels,
            use_sdp=self.hparams.use_sdp,
        )
        self.model_d = MultiPeriodDiscriminator(
            use_spectral_norm=self.hparams.use_spectral_norm
        )

        # Dataset splits
        self._train_dataset: Optional[Dataset] = None
        self._val_dataset: Optional[Dataset] = None
        self._test_dataset: Optional[Dataset] = None
        self._load_datasets(validation_split, num_test_examples)

        # State kept between training optimizers
        self._y = None
        self._y_hat = None

    def _load_datasets(self, validation_split: float, num_test_examples: int):
        full_dataset = LarynxDataset(self.hparams.dataset)
        valid_set_size = int(len(full_dataset) * validation_split)
        train_set_size = len(full_dataset) - valid_set_size - num_test_examples

        self._train_dataset, self._test_dataset, self._val_dataset = random_split(
            full_dataset, [train_set_size, num_test_examples, valid_set_size]
        )

    def forward(self, text, text_lengths, scales, sid=None):
        """Synthesize audio from phoneme ids; scales = [noise_scale, length_scale, noise_scale_w]."""
        noise_scale = scales[0]
        length_scale = scales[1]
        noise_scale_w = scales[2]
        audio, *_ = self.model_g.infer(
            text,
            text_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            sid=sid,
        )

        return audio

    def train_dataloader(self):
        return DataLoader(
            self._train_dataset,
            collate_fn=UtteranceCollate(
                is_multispeaker=self.hparams.num_speakers > 1,
                segment_size=self.hparams.segment_size,
            ),
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
        )

    def val_dataloader(self):
        return DataLoader(
            self._val_dataset,
            collate_fn=UtteranceCollate(
                is_multispeaker=self.hparams.num_speakers > 1,
                segment_size=self.hparams.segment_size,
            ),
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
        )

    def test_dataloader(self):
        return DataLoader(
            self._test_dataset,
            collate_fn=UtteranceCollate(
                is_multispeaker=self.hparams.num_speakers > 1,
                segment_size=self.hparams.segment_size,
            ),
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
        )

    def training_step(self, batch: Batch, batch_idx: int, optimizer_idx: int):
        # Alternate between generator (optimizer 0) and discriminator (optimizer 1)
        if optimizer_idx == 0:
            return self.training_step_g(batch)

        if optimizer_idx == 1:
            return self.training_step_d(batch)

    def training_step_g(self, batch: Batch):
        """Generator step: duration, mel, KL, feature-matching, and adversarial losses."""
        x, x_lengths, y, _, spec, spec_lengths, speaker_ids = (
            batch.phoneme_ids,
            batch.phoneme_lengths,
            batch.audios,
            batch.audio_lengths,
            batch.spectrograms,
            batch.spectrogram_lengths,
            batch.speaker_ids if batch.speaker_ids is not None else None,
        )
        (
            y_hat,
            l_length,
            _attn,
            ids_slice,
            _x_mask,
            z_mask,
            (_z, z_p, m_p, logs_p, _m_q, logs_q),
        ) = self.model_g(x, x_lengths, spec, spec_lengths, speaker_ids)
        self._y_hat = y_hat

        mel = spec_to_mel_torch(
            spec,
            self.hparams.filter_length,
            self.hparams.mel_channels,
            self.hparams.sample_rate,
            self.hparams.mel_fmin,
            self.hparams.mel_fmax,
        )
        y_mel = slice_segments(
            mel,
            ids_slice,
            self.hparams.segment_size // self.hparams.hop_length,
        )
        y_hat_mel = mel_spectrogram_torch(
            y_hat.squeeze(1),
            self.hparams.filter_length,
            self.hparams.mel_channels,
            self.hparams.sample_rate,
            self.hparams.hop_length,
            self.hparams.win_length,
            self.hparams.mel_fmin,
            self.hparams.mel_fmax,
        )
        y = slice_segments(
            y,
            ids_slice * self.hparams.hop_length,
            self.hparams.segment_size,
        )  # slice

        # Save for training_step_d
        self._y = y

        _y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.model_d(y, y_hat)
        with autocast(self.device.type, enabled=False):
            # Generator loss
            loss_dur = torch.sum(l_length.float())
            loss_mel = F.l1_loss(y_mel, y_hat_mel) * self.hparams.c_mel
            loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * self.hparams.c_kl
            loss_fm = feature_loss(fmap_r, fmap_g)
            loss_gen, _losses_gen = generator_loss(y_d_hat_g)
            loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl
            self.log("loss_gen_all", loss_gen_all)

            return loss_gen_all

    def training_step_d(self, batch: Batch):
        """Discriminator step: score real audio vs. the audio generated in the preceding generator step."""
        # From training_step_g
        y = self._y
        y_hat = self._y_hat
        y_d_hat_r, y_d_hat_g, _, _ = self.model_d(y, y_hat.detach())
        with autocast(self.device.type, enabled=False):
            # Discriminator
            loss_disc, _losses_disc_r, _losses_disc_g = discriminator_loss(
                y_d_hat_r, y_d_hat_g
            )
            loss_disc_all = loss_disc
            self.log("loss_disc_all", loss_disc_all)

            return loss_disc_all

    def validation_step(self, batch: Batch, batch_idx: int):
        """Compute the generator loss on a validation batch and log synthesized audio for test utterances."""
        val_loss = self.training_step_g(batch)
        self.log("val_loss", val_loss)

        # Generate audio examples
        for utt_idx, test_utt in enumerate(self._test_dataset):
            text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
            text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
            scales = [0.667, 1.0, 0.8]
            test_audio = self(text, text_lengths, scales).detach()

            # Scale to make louder in [-1, 1]
            test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max())))

            tag = test_utt.text or str(utt_idx)
            self.logger.experiment.add_audio(
                tag, test_audio, sample_rate=self.hparams.sample_rate
            )

        return val_loss

    def configure_optimizers(self):
        """Separate AdamW optimizers for generator and discriminator, each with exponential LR decay."""
        optimizers = [
            torch.optim.AdamW(
                self.model_g.parameters(),
                lr=self.hparams.learning_rate,
                betas=self.hparams.betas,
                eps=self.hparams.eps,
            ),
            torch.optim.AdamW(
                self.model_d.parameters(),
                lr=self.hparams.learning_rate,
                betas=self.hparams.betas,
                eps=self.hparams.eps,
            ),
        ]
        schedulers = [
            torch.optim.lr_scheduler.ExponentialLR(
                optimizers[0], gamma=self.hparams.lr_decay
            ),
            torch.optim.lr_scheduler.ExponentialLR(
                optimizers[1], gamma=self.hparams.lr_decay
            ),
        ]

        return optimizers, schedulers

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("VitsModel")
        parser.add_argument("--batch-size", type=int, required=True)
        parser.add_argument("--validation-split", type=float, default=0.1)
        parser.add_argument("--num-test-examples", type=int, default=5)
        #
        parser.add_argument("--hidden-channels", type=int, default=192)
        parser.add_argument("--inter-channels", type=int, default=192)
        parser.add_argument("--filter-channels", type=int, default=768)
        parser.add_argument("--n-layers", type=int, default=6)
        parser.add_argument("--n-heads", type=int, default=2)
        #
        return parent_parser
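

# ---------------------------------------------------------------------------
# Illustrative usage sketch (a minimal example, not the package's actual
# training entry point): roughly how VitsModel can be driven by a
# pytorch_lightning.Trainer. The symbol count, dataset path, batch size, and
# trainer settings below are assumptions made for the example only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = VitsModel(
        num_symbols=256,  # assumed size of the phoneme id table
        num_speakers=1,
        dataset=["/path/to/preprocessed/dataset.jsonl"],  # hypothetical path
        batch_size=16,
    )
    trainer = pl.Trainer(max_epochs=1000)
    # Dataloaders come from the module's own train/val/test_dataloader methods.
    trainer.fit(model)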