Move the flow component to the decoder. RTF of encoder is now 0.01

pull/255/head
mush42 7 months ago
parent a691f14689
commit 7f66948f40

@ -55,9 +55,7 @@ class VitsEncoder(nn.Module):
) # [b, t', t], [b, t, d] -> [b, d, t']
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
# Should we move the `flow` to the decoder?
z = gen.flow(z_p, y_mask, g=g, reverse=True)
return z, y_mask, g
return z_p, y_mask, g
class VitsDecoder(nn.Module):
@ -66,6 +64,7 @@ class VitsDecoder(nn.Module):
self.gen = gen
def forward(self, z, y_mask, g=None):
z = self.gen.flow(z, y_mask, g=g, reverse=True)
output = self.gen.dec((z * y_mask), g=g)
return output

Loading…
Cancel
Save