Skip to content

speaker

EncoderClassifier

Bases: Module

Source code in speechain/module/encoder/speaker.py
class EncoderClassifier(nn.Module):
    """Speaker-embedding encoder operating on raw waveforms.

    A ``MelSpectrogramFrontend`` converts waveforms into log-mel features and
    a convolutional backbone (selected by ``model_type``) maps them to a
    192-dimensional embedding.

    Args:
        model_type: Backbone to build, either ``"ecapa"`` or ``"xvector"``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """

    def __init__(self, model_type="ecapa"):
        super().__init__()
        self.model_type = model_type
        self.frontend = MelSpectrogramFrontend()
        self.model = self._create_model()

    def _create_model(self):
        """Build the backbone that matches ``self.model_type``.

        Returns:
            The ``nn.Sequential`` backbone.

        Raises:
            ValueError: If ``self.model_type`` is not recognised.
        """
        if self.model_type == "ecapa":
            return self._create_ecapa()
        elif self.model_type == "xvector":
            return self._create_xvector()
        else:
            raise ValueError(f"Unknown model type: {self.model_type}")

    def _create_ecapa(self):
        """Build the ``"ecapa"`` backbone.

        The layer order must stay fixed: ``nn.Sequential`` names parameters by
        position, so reordering would break loading of saved state dicts.

        Returns:
            An ``nn.Sequential`` taking 80-channel input and emitting 192 features.
        """
        channels = 512
        model = nn.Sequential(
            nn.Conv1d(80, channels, 7, padding=3),
            nn.BatchNorm1d(channels),
            nn.ReLU(),
            nn.Sequential(
                Res2Block(channels), SEModule(channels), nn.BatchNorm1d(channels)
            ),
            # Average over the time axis so variable-length inputs give one vector.
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(channels, 192),
        )
        return model

    def _create_xvector(self):
        """Build the ``"xvector"`` backbone.

        The layer order must stay fixed: ``nn.Sequential`` names parameters by
        position, so reordering would break loading of saved state dicts.

        Returns:
            An ``nn.Sequential`` taking 80-channel input and emitting 192 features.
        """
        model = nn.Sequential(
            nn.Conv1d(80, 512, 5, padding=2),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Conv1d(512, 512, 3, padding=1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Conv1d(512, 512, 3, padding=1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            # Average over the time axis so variable-length inputs give one vector.
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 192),
        )
        return model

    def encode_batch(self, wavs, wav_lens=None):
        """Extract speaker embeddings from raw waveforms.

        Note that this switches the module to evaluation mode via
        ``self.eval()`` and does not restore the previous mode afterwards.

        Args:
            wavs: (batch, time) raw waveform tensor
            wav_lens: relative lengths (optional, not used currently)

        Returns:
            (batch, embedding_dim) normalized speaker embeddings
        """
        self.eval()
        with torch.no_grad():
            # Convert waveforms to mel-spectrograms
            # wavs: (batch, time) -> mel: (batch, n_mels, time)
            mel = self.frontend(wavs)
            # Model expects (batch, n_mels, time)
            embeddings = self.model(mel)
            # L2-normalise each embedding along the feature dimension.
            return F.normalize(embeddings, p=2, dim=1)

    @classmethod
    def from_hparams(cls, source, savedir=None, run_opts=None):
        """Build a model and optionally load weights saved under ``savedir``.

        Args:
            source: Model identifier; the ``"ecapa"`` backbone is chosen when
                this string contains ``"ecapa"``, otherwise ``"xvector"``.
            savedir: Optional directory expected to contain ``encoder.pth``.
                If the directory is not given or the file does not exist, the
                model keeps its freshly initialised weights.
            run_opts: Optional dict; when it has a ``"device"`` entry the
                model is moved to that device.

        Returns:
            The constructed (and possibly weight-loaded) model.
        """
        model = cls(model_type="ecapa" if "ecapa" in source else "xvector")

        device = None
        if run_opts and "device" in run_opts:
            device = run_opts["device"]
            model = model.to(device)

        if savedir:
            weights_path = os.path.join(savedir, "encoder.pth")
            if os.path.exists(weights_path):
                # run_opts may be None or carry no "device"; indexing it here
                # unconditionally used to raise. Without an explicit device the
                # model was never moved, so map the checkpoint to CPU.
                map_location = device if device is not None else "cpu"
                # NOTE: torch.load unpickles the file; only load checkpoints
                # from trusted locations.
                model.load_state_dict(
                    torch.load(weights_path, map_location=map_location)
                )

        return model

encode_batch(wavs, wav_lens=None)

Extract speaker embeddings from raw waveforms.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `wavs` | | (batch, time) raw waveform tensor | required |
| `wav_lens` | | relative lengths (optional, not used currently) | `None` |

Returns:

| Type | Description |
| --- | --- |
| | (batch, embedding_dim) normalized speaker embeddings |

Source code in speechain/module/encoder/speaker.py
def encode_batch(self, wavs, wav_lens=None):
    """Compute L2-normalised speaker embeddings for a batch of waveforms.

    The module is put into evaluation mode and the whole computation runs
    with gradient tracking disabled.

    Args:
        wavs: (batch, time) raw waveform tensor
        wav_lens: relative lengths; accepted but currently ignored

    Returns:
        (batch, embedding_dim) embeddings normalised along dimension 1
    """
    self.eval()
    with torch.no_grad():
        # Waveforms -> log-mel features, expected as (batch, n_mels, time).
        features = self.frontend(wavs)
        # Backbone reduces the features to one vector per utterance.
        raw_embeddings = self.model(features)
        unit_embeddings = F.normalize(raw_embeddings, p=2, dim=1)
    return unit_embeddings

from_hparams(source, savedir=None, run_opts=None) classmethod

Load pretrained model.

Source code in speechain/module/encoder/speaker.py
@classmethod
def from_hparams(cls, source, savedir=None, run_opts=None):
    """Build a model and optionally load weights saved under ``savedir``.

    Args:
        source: Model identifier; the ``"ecapa"`` backbone is chosen when
            this string contains ``"ecapa"``, otherwise ``"xvector"``.
        savedir: Optional directory expected to contain ``encoder.pth``.
            If the directory is not given or the file does not exist, the
            model keeps its freshly initialised weights.
        run_opts: Optional dict; when it has a ``"device"`` entry the
            model is moved to that device.

    Returns:
        The constructed (and possibly weight-loaded) model.
    """
    model = cls(model_type="ecapa" if "ecapa" in source else "xvector")

    device = None
    if run_opts and "device" in run_opts:
        device = run_opts["device"]
        model = model.to(device)

    if savedir:
        weights_path = os.path.join(savedir, "encoder.pth")
        if os.path.exists(weights_path):
            # run_opts may be None or carry no "device"; indexing it here
            # unconditionally used to raise. Without an explicit device the
            # model was never moved, so map the checkpoint to CPU.
            map_location = device if device is not None else "cpu"
            # NOTE: torch.load unpickles the file; only load checkpoints
            # from trusted locations.
            model.load_state_dict(
                torch.load(weights_path, map_location=map_location)
            )

    return model

MelSpectrogramFrontend

Bases: Module

Mel-spectrogram frontend for processing raw waveforms.

Source code in speechain/module/encoder/speaker.py
class MelSpectrogramFrontend(nn.Module):
    """Turn raw waveforms into log-compressed mel-spectrograms.

    Args:
        sample_rate: Sample rate passed to the mel-spectrogram transform.
        n_mels: Number of mel filterbank channels.
        n_fft: FFT size.
        hop_length: Hop between successive frames, in samples.
        win_length: Analysis window length, in samples.
    """

    def __init__(
        self, sample_rate=16000, n_mels=80, n_fft=400, hop_length=160, win_length=400
    ):
        super().__init__()
        transform_kwargs = {
            "sample_rate": sample_rate,
            "n_fft": n_fft,
            "hop_length": hop_length,
            "win_length": win_length,
            "n_mels": n_mels,
        }
        self.mel_spec = torchaudio.transforms.MelSpectrogram(**transform_kwargs)

    def forward(self, x):
        """Return the log-mel spectrogram of ``x``.

        Args:
            x: (batch, time) waveform tensor

        Returns:
            (batch, n_mels, time) log-mel features
        """
        # The small offset keeps the logarithm finite for zero-energy bins.
        return torch.log(self.mel_spec(x) + 1e-6)