class HIFIGAN(nn.Module):
    """HiFi-GAN-style generator mapping spectrogram features to a waveform.

    Layout built in ``__init__``: a kernel-7 input convolution (``pre_net``),
    then one ``LeakyReLU(0.1) -> ConvTranspose1d -> MRF`` stage per entry of
    ``upsample_rates`` (each stage halves the channel count), and finally
    ``LeakyReLU(0.1) -> 1-channel Conv1d -> Tanh`` (``post_net``), so output
    samples lie in [-1, 1].

    Args:
        in_channels: Number of input feature channels (the ``n_mels`` axis
            that ``decode_batch`` moves to dim 1).
        upsample_initial_channel: Channel count produced by ``pre_net``;
            stage ``i`` outputs ``upsample_initial_channel // 2**(i + 1)``.
        upsample_rates: Stride of each transposed-convolution stage.
        upsample_kernel_sizes: Kernel size of each stage; must have the same
            length as ``upsample_rates``.

    Raises:
        ValueError: If ``upsample_rates`` and ``upsample_kernel_sizes`` differ
            in length.
    """

    def __init__(
        self,
        in_channels=80,
        upsample_initial_channel=512,
        upsample_rates=(8, 8, 2, 2),
        upsample_kernel_sizes=(16, 16, 4, 4),
    ):
        super().__init__()
        if len(upsample_rates) != len(upsample_kernel_sizes):
            # zip() below would silently drop the extra stages, leaving
            # post_net sized for a channel count the last stage never emits.
            raise ValueError(
                "upsample_rates and upsample_kernel_sizes must have the same "
                f"length, got {len(upsample_rates)} and "
                f"{len(upsample_kernel_sizes)}"
            )
        self.pre_net = nn.Conv1d(in_channels, upsample_initial_channel, 7, padding=3)
        self.upsamples = nn.ModuleList()
        self.mrfs = nn.ModuleList()
        # Channel count entering post_net; stays at the pre_net width if there
        # are no upsampling stages.
        out_ch = upsample_initial_channel
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            in_ch = upsample_initial_channel // (2**i)
            out_ch = upsample_initial_channel // (2 ** (i + 1))
            # With padding=(k - u) // 2 and (k - u) even, ConvTranspose1d maps
            # length L to exactly L * u.
            self.upsamples.append(
                nn.ConvTranspose1d(
                    in_ch,
                    out_ch,
                    k,
                    stride=u,
                    padding=(k - u) // 2,
                )
            )
            # MRF is defined elsewhere in this file; presumably a
            # multi-receptive-field block preserving channels/length — confirm.
            self.mrfs.append(MRF(out_ch))
        self.post_net = nn.Sequential(
            nn.LeakyReLU(0.1),
            nn.Conv1d(out_ch, 1, 7, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate waveforms from channel-first features.

        Args:
            x: Tensor of shape (batch, in_channels, time).

        Returns:
            Tensor of shape (batch, time'), where time' is time scaled by the
            product of ``upsample_rates`` (when every ``k - u`` is even).
        """
        x = self.pre_net(x)
        for up, mrf in zip(self.upsamples, self.mrfs):
            x = F.leaky_relu(x, 0.1)
            x = up(x)
            x = mrf(x)
        x = self.post_net(x)
        # post_net emits a single channel: (B, 1, T') -> (B, T').
        return x.squeeze(1)

    def decode_batch(self, feats):
        """Convert mel-spectrograms to waveforms without tracking gradients.

        Note: this switches the module to eval mode and leaves it there.

        Args:
            feats: Tensor of shape (batch, time, n_mels).

        Returns:
            waveforms: Tensor of shape (batch, time').
        """
        self.eval()
        with torch.no_grad():
            x = feats.transpose(1, 2)  # (B, n_mels, T)
            return self.forward(x)

    @classmethod
    def from_hparams(cls, source, savedir=None, run_opts=None):
        """Build a model and load pretrained weights from ``savedir`` if present.

        Args:
            source: Model identifier (e.g. "speechbrain/tts-hifigan-ljspeech").
                Not used to locate weights; it only appears in the warning
                emitted when no weights are loaded.
            savedir: Directory expected to contain ``generator.pth``.
            run_opts: Optional dict of runtime options; its ``"device"`` entry,
                if present, selects the device for the model and checkpoint.

        Returns:
            model: The HiFiGAN model, with pretrained weights if they were
            found, otherwise randomly initialised (a warning is emitted).
        """
        import warnings

        model = cls()
        device = (run_opts or {}).get("device")
        if device is not None:
            model = model.to(device)

        weights_path = os.path.join(savedir, "generator.pth") if savedir else None
        if weights_path is not None and os.path.exists(weights_path):
            # map_location must be valid even when no device was requested;
            # CPU lets GPU-saved checkpoints load on CPU-only hosts, and
            # load_state_dict copies values onto the model's own device.
            # NOTE(review): torch.load unpickles the file — only load trusted
            # checkpoints (consider weights_only=True on torch >= 1.13).
            state_dict = torch.load(weights_path, map_location=device or "cpu")
            model.load_state_dict(state_dict)
        else:
            where = weights_path if weights_path else "(no savedir given)"
            warnings.warn(
                f"HIFIGAN.from_hparams({source!r}): no pretrained weights found "
                f"at {where}; returning a randomly initialised model.",
                stacklevel=2,
            )
        return model