
hifigan

HIFIGAN

Bases: Module

Source code in speechain/module/vocoder/hifigan.py
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

# MRF (multi-receptive-field fusion block) is defined elsewhere in this module.


class HIFIGAN(nn.Module):
    def __init__(
        self,
        in_channels=80,
        upsample_initial_channel=512,
        upsample_rates=(8, 8, 2, 2),
        upsample_kernel_sizes=(16, 16, 4, 4),
    ):
        super().__init__()

        # Pre-net: map the n_mels input channels to the initial channel width.
        self.pre_net = nn.Conv1d(in_channels, upsample_initial_channel, 7, padding=3)

        # Channel width is halved after each upsampling stage.
        ups_in_channels = [
            upsample_initial_channel // (2**i) for i in range(len(upsample_rates))
        ]
        ups_out_channels = [
            upsample_initial_channel // (2 ** (i + 1))
            for i in range(len(upsample_rates))
        ]

        # One transposed convolution (upsampler) and one MRF block per stage.
        self.upsamples = nn.ModuleList()
        self.mrfs = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.upsamples.append(
                nn.ConvTranspose1d(
                    ups_in_channels[i],
                    ups_out_channels[i],
                    k,
                    stride=u,
                    # This padding keeps the output length at exactly stride * input length.
                    padding=(k - u) // 2,
                )
            )
            self.mrfs.append(MRF(ups_out_channels[i]))

        # Post-net: project to a single waveform channel squashed into [-1, 1].
        self.post_net = nn.Sequential(
            nn.LeakyReLU(0.1),
            nn.Conv1d(ups_out_channels[-1], 1, 7, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        # x: (batch, n_mels, time)
        x = self.pre_net(x)
        for up, mrf in zip(self.upsamples, self.mrfs):
            x = F.leaky_relu(x, 0.1)
            x = up(x)
            x = mrf(x)
        x = self.post_net(x)
        # (batch, 1, time') -> (batch, time')
        return x.squeeze(1)

    def decode_batch(self, feats):
        """Convert mel-spectrograms to waveforms
        Args:
            feats: (batch, time, n_mels)
        Returns:
            waveforms: (batch, time')
        """
        self.eval()
        with torch.no_grad():
            x = feats.transpose(1, 2)  # (B, n_mels, T)
            return self.forward(x)

    @classmethod
    def from_hparams(cls, source, savedir=None, run_opts=None):
        """Load pretrained model
        Args:
            source: Model identifier (e.g. "speechbrain/tts-hifigan-ljspeech")
            savedir: Directory to save model weights
            run_opts: Runtime options including device
        Returns:
            model: Loaded HiFiGAN model
        """
        model = cls()

        # Resolve the target device (default to CPU when not specified).
        device = run_opts["device"] if run_opts and "device" in run_opts else "cpu"
        model = model.to(device)

        # Load pretrained weights from savedir if present.
        # Note: `source` is an identifier only; weights are not downloaded here.
        if savedir:
            weights_path = os.path.join(savedir, "generator.pth")
            if os.path.exists(weights_path):
                model.load_state_dict(torch.load(weights_path, map_location=device))

        return model
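
Example usage (a minimal sketch, not taken from the library): the generator maps a channel-first mel-spectrogram of shape (batch, n_mels, time) to a waveform whose length is the input length times the product of upsample_rates (8 * 8 * 2 * 2 = 256 with the defaults, assuming MRF preserves the time dimension as in standard HiFi-GAN). The import path is assumed from the source file location shown above.

import torch

from speechain.module.vocoder.hifigan import HIFIGAN  # assumed import path

vocoder = HIFIGAN()  # defaults: 80 mel bins, 512 initial channels, upsample rates (8, 8, 2, 2)

mel = torch.randn(2, 80, 100)  # (batch, n_mels, time), channel-first as expected by forward()
with torch.no_grad():
    wav = vocoder(mel)  # (batch, time * 256)

print(wav.shape)  # torch.Size([2, 25600])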

decode_batch(feats)

Convert mel-spectrograms to waveforms.

Args:
    feats: (batch, time, n_mels)

Returns:
    waveforms: (batch, time')

Source code in speechain/module/vocoder/hifigan.py
def decode_batch(self, feats):
    """Convert mel-spectrograms to waveforms
    Args:
        feats: (batch, time, n_mels)
    Returns:
        waveforms: (batch, time')
    """
    self.eval()
    with torch.no_grad():
        x = feats.transpose(1, 2)  # (B, n_mels, T)
        return self.forward(x)
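
Example (a minimal sketch; the mel tensor below is random and purely illustrative). Note that decode_batch expects channel-last input of shape (batch, time, n_mels), unlike forward():

import torch

from speechain.module.vocoder.hifigan import HIFIGAN  # assumed import path

vocoder = HIFIGAN()              # untrained weights; see from_hparams below for loading
mel = torch.randn(4, 120, 80)    # (batch, time, n_mels)
wav = vocoder.decode_batch(mel)  # (batch, time * 256) -> torch.Size([4, 30720])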

from_hparams(source, savedir=None, run_opts=None) classmethod

Load a pretrained model.

Args:
    source: Model identifier (e.g. "speechbrain/tts-hifigan-ljspeech")
    savedir: Directory to save model weights
    run_opts: Runtime options including device

Returns:
    model: Loaded HiFiGAN model

Source code in speechain/module/vocoder/hifigan.py
@classmethod
def from_hparams(cls, source, savedir=None, run_opts=None):
    """Load pretrained model
    Args:
        source: Model identifier (e.g. "speechbrain/tts-hifigan-ljspeech")
        savedir: Directory to save model weights
        run_opts: Runtime options including device
    Returns:
        model: Loaded HiFiGAN model
    """
    model = cls()

    # Resolve the target device (default to CPU when not specified).
    device = run_opts["device"] if run_opts and "device" in run_opts else "cpu"
    model = model.to(device)

    # Load pretrained weights from savedir if present.
    # Note: `source` is an identifier only; weights are not downloaded here.
    if savedir:
        weights_path = os.path.join(savedir, "generator.pth")
        if os.path.exists(weights_path):
            model.load_state_dict(torch.load(weights_path, map_location=device))

    return model
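
Example (a minimal sketch, assuming a local savedir that already contains a compatible generator.pth; as written, from_hparams does not download weights from source, so the paths below are hypothetical):

import torch

from speechain.module.vocoder.hifigan import HIFIGAN  # assumed import path

vocoder = HIFIGAN.from_hparams(
    source="speechbrain/tts-hifigan-ljspeech",         # identifier only, not downloaded here
    savedir="pretrained_models/tts-hifigan-ljspeech",  # hypothetical directory holding generator.pth
    run_opts={"device": "cpu"},
)

mel = torch.randn(1, 200, 80)    # (batch, time, n_mels)
wav = vocoder.decode_batch(mel)  # (batch, time * 256)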