sb_util

SpeechBrainWrapper

Bases: object

A wrapper class for the vocoder forward function of the speechbrain package.

This wrapper is not implemented as a Module because we don't want it to be in the computational graph of a TTS model.

Before wrapping: feat -> vocoder -> wav

After wrapping: feat, feat_len -> SpeechBrainWrapper(vocoder) -> wav, wav_len
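The interface change above can be sketched with a stand-in vocoder. `DummyVocoder`, its `hop` value, the helper `wrap_call`, and the tensor sizes are all hypothetical and exist only for illustration; a real speechbrain vocoder may return a differently shaped output. The sketch only shows how a `feat, feat_len` pair turns into a `wav, wav_len` pair.

```python
import torch


class DummyVocoder:
    """Hypothetical stand-in: maps (batch, channels, time) features to a (batch, time * hop) waveform."""

    def __init__(self, hop: int = 4):
        self.hop = hop

    def decode_batch(self, feat: torch.Tensor) -> torch.Tensor:
        batch, _, frames = feat.shape
        # A constant waveform makes the zeroed padding easy to see.
        return torch.ones(batch, frames * self.hop)


def wrap_call(vocoder, feat: torch.Tensor, feat_len: torch.Tensor):
    """Simplified sketch of the wrapped call: feat, feat_len -> wav, wav_len."""
    # (batch, time, channels) -> (batch, channels, time) for the vocoder,
    # then add a trailing channel axis to the waveform.
    wav = vocoder.decode_batch(feat.transpose(-2, -1)).unsqueeze(-1)
    # Scale each feature length by the samples-per-frame ratio.
    wav_len = (feat_len * (wav.size(1) / feat.size(1))).long()
    # Zero out everything past each utterance's estimated length.
    for i in range(len(wav_len)):
        wav[i][wav_len[i]:] = 0
    return wav[:, : wav_len.max()], wav_len


feat = torch.randn(2, 10, 80)     # (batch, time, channels), hypothetical sizes
feat_len = torch.tensor([10, 6])  # the second utterance is padded from 6 to 10 frames
wav, wav_len = wrap_call(DummyVocoder(hop=4), feat, feat_len)
```

With these assumed sizes, `wav` has shape `(2, 40, 1)` and the second utterance is silent beyond its estimated 24 samples.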

Source code in speechain/utilbox/sb_util.py
class SpeechBrainWrapper(object):
    """A wrapper class for the vocoder forward function of the speechbrain package.

    This wrapper is not implemented as a Module because we don't want it to be in the computational graph of a TTS model.

    Before wrapping:
        feat -> vocoder -> wav
    After wrapping:
        feat, feat_len -> SpeechBrainWrapper(vocoder) -> wav, wav_len
    """

    def __init__(self, vocoder: HiFiGAN):
        self.vocoder = vocoder

    def __call__(self, feat: torch.Tensor, feat_len: torch.Tensor):
        # feat is (batch, time, channels); it is transposed to (batch, channels, time) right before the vocoder call below
        # Check for NaN/Inf in the features first: non-finite values indicate a problem with the upstream TTS model
        if not torch.isfinite(feat).all():
            import warnings

            warnings.warn(
                "Non-finite values (NaN/Inf) detected in mel-spectrogram features! "
                "This indicates numerical instability in the TTS model. "
                "The model needs to be retrained with proper gradient clipping and loss scaling."
            )
            # Replace NaN and +/-Inf with zero as a last resort to avoid crashing
            feat = torch.nan_to_num(feat, nan=0.0, posinf=0.0, neginf=0.0)

        wav = self.vocoder.decode_batch(feat.transpose(-2, -1))
        # wav output is (batch, time) after decode_batch
        # add channel dimension back: (batch, time) -> (batch, time, 1)
        wav = wav.unsqueeze(-1)
        # the lengths of the shorter utterances in the batch are estimated by their feature lengths
        wav_len = (feat_len * (wav.size(1) / feat.size(1))).long()
        # make sure that the redundant parts are set to silence
        for i in range(len(wav_len)):
            wav[i][wav_len[i] :] = 0
        return wav[:, : wav_len.max()], wav_len
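The length estimate in `__call__` scales each feature length by the ratio of output samples to padded input frames, then truncates to an integer. A worked sketch of that arithmetic in plain Python follows; `estimate_wav_len` is an illustrative helper that is not part of the module, and the batch sizes are made up.

```python
def estimate_wav_len(feat_len, n_frames, n_samples):
    """Plain-integer version of wav_len = (feat_len * (wav.size(1) / feat.size(1))).long()."""
    ratio = n_samples / n_frames  # samples produced per feature frame
    # int() truncates toward zero, like .long() does for these non-negative values
    return [int(length * ratio) for length in feat_len]


# Hypothetical batch: 100 padded frames decoded into 25600 samples (256 samples per frame).
wav_len = estimate_wav_len([100, 73, 40], n_frames=100, n_samples=25600)
```

Because the ratio comes from the padded batch dimensions, the per-utterance lengths are estimates rather than exact sample counts, which is why the wrapper then zeroes everything past each estimated length.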