sb_util

SpeechBrainWrapper

Bases: object

A wrapper class for the vocoder forward function of the speechbrain package.

This wrapper is not implemented as a Module because we don't want it to be in the computational graph of a TTS model.

Before wrapping: feat -> vocoder -> wav

After wrapping: feat, feat_len -> SpeechBrainWrapper(vocoder) -> wav, wav_len

Source code in speechain/utilbox/sb_util.py
class SpeechBrainWrapper(object):
    """A wrapper class for the vocoder forward function of the speechbrain package.

    This wrapper is not implemented as a Module because we don't want it to be in the computational graph of a TTS model.

    Before wrapping:
        feat -> vocoder -> wav
    After wrapping:
        feat, feat_len -> SpeechBrainWrapper(vocoder) -> wav, wav_len
    """

    def __init__(self, vocoder: HIFIGAN):
        self.vocoder = vocoder

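    def __call__(self, feat: torch.Tensor, feat_len: torch.Tensor):
        # swap the last two dimensions before decoding, then swap them back
        # so that the time axis of the waveform is dimension 1
        wav = self.vocoder.decode_batch(feat.transpose(-2, -1)).transpose(-2, -1)
        # the waveform length of each utterance, including the shorter (padded) ones,
        # is estimated by scaling its feature length by the samples-per-frame ratio
        wav_len = (feat_len * (wav.size(1) / feat.size(1))).long()
        # set the samples beyond each estimated length to silence (zeros)
        for i in range(len(wav_len)):
            wav[i][wav_len[i] :] = 0
        # trim the batch to the longest estimated length
        return wav[:, : wav_len.max()], wav_len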
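
The length estimation and silence masking can be tried without a pretrained model. The following is a minimal sketch, not part of speechain: `DummyVocoder` is a hypothetical stand-in that is assumed to mimic the vocoder's `decode_batch` by taking a (batch, n_mels, n_frames) tensor and returning a (batch, 1, n_samples) tensor, and its `hop` parameter is invented for illustration. The wrapper logic is repeated from the listing above, without the `HIFIGAN` type annotation, so that the example runs on its own.

```python
import torch


class DummyVocoder:
    """Hypothetical stand-in for the real vocoder: emits `hop` samples per frame."""

    def __init__(self, hop: int = 4):
        self.hop = hop

    def decode_batch(self, spectrogram: torch.Tensor) -> torch.Tensor:
        # Assumed shapes: (batch, n_mels, n_frames) -> (batch, 1, n_frames * hop)
        batch, _, n_frames = spectrogram.shape
        return torch.ones(batch, 1, n_frames * self.hop)


class SpeechBrainWrapper:
    """Same logic as the listing above, repeated so the sketch is self-contained."""

    def __init__(self, vocoder):
        self.vocoder = vocoder

    def __call__(self, feat: torch.Tensor, feat_len: torch.Tensor):
        wav = self.vocoder.decode_batch(feat.transpose(-2, -1)).transpose(-2, -1)
        # Estimate each waveform length from its feature length
        wav_len = (feat_len * (wav.size(1) / feat.size(1))).long()
        # Silence everything past the estimated length
        for i in range(len(wav_len)):
            wav[i][wav_len[i]:] = 0
        return wav[:, : wav_len.max()], wav_len


# Two utterances padded to 10 frames of 80-dim features; the second has only 6 real frames
feat = torch.randn(2, 10, 80)
feat_len = torch.tensor([10, 6])

wav, wav_len = SpeechBrainWrapper(DummyVocoder(hop=4))(feat, feat_len)
print(wav.shape)         # torch.Size([2, 40, 1])
print(wav_len.tolist())  # [40, 24]
```

With 4 samples per frame, the 6-frame utterance is assigned 24 samples, and the remaining 16 samples of its row are zeroed.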