Skip to content

speaker

EncoderClassifier

Bases: `torch.nn.Module`

Source code in speechain/module/encoder/speaker.py
class EncoderClassifier(nn.Module):
    """Speaker-embedding encoder producing L2-normalized 192-dim embeddings.

    Two architectures are selectable via ``model_type``:
    ``"ecapa"`` (ECAPA-style convolutional stack) or ``"xvector"``
    (TDNN x-vector-style stack). Both consume 80-dim features and emit
    192-dim embeddings.
    """

    def __init__(self, model_type="ecapa"):
        """Build the encoder.

        Args:
            model_type: ``"ecapa"`` or ``"xvector"``. Any other value
                raises ``ValueError`` (from ``_create_model``).
        """
        super().__init__()
        self.model_type = model_type
        self.model = self._create_model()

    def _create_model(self):
        """Dispatch to the architecture builder for ``self.model_type``.

        Raises:
            ValueError: if ``self.model_type`` is not a known architecture.
        """
        if self.model_type == "ecapa":
            return self._create_ecapa()
        elif self.model_type == "xvector":
            return self._create_xvector()
        else:
            raise ValueError(f"Unknown model type: {self.model_type}")

    def _create_ecapa(self):
        """Build the ECAPA-style stack: conv front-end, one Res2/SE block,
        temporal average pooling, then projection to 192 dims.

        NOTE: ``Res2Block`` and ``SEModule`` are project-level modules not
        defined in this file.
        """
        channels = 512
        model = nn.Sequential(
            nn.Conv1d(80, channels, 7, padding=3),  # 80-dim features in
            nn.BatchNorm1d(channels),
            nn.ReLU(),
            nn.Sequential(
                Res2Block(channels), SEModule(channels), nn.BatchNorm1d(channels)
            ),
            nn.AdaptiveAvgPool1d(1),  # average over time -> (B, C, 1)
            nn.Flatten(),
            nn.Linear(channels, 192),  # embedding dimension
        )
        return model

    def _create_xvector(self):
        """Build the x-vector-style TDNN stack: three conv layers, temporal
        average pooling, and a two-layer projection to 192 dims."""
        model = nn.Sequential(
            nn.Conv1d(80, 512, 5, padding=2),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Conv1d(512, 512, 3, padding=1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Conv1d(512, 512, 3, padding=1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 192),
        )
        return model

    def encode_batch(self, wavs, wav_lens=None):
        """Encode a batch of feature sequences into speaker embeddings.

        Args:
            wavs: tensor of shape (batch, time, 80) — transposed internally
                to (batch, 80, time) for the Conv1d front-end.
            wav_lens: unused here; kept for interface compatibility.

        Returns:
            L2-normalized embeddings of shape (batch, 192).

        Side effect: puts the module into eval mode.
        """
        self.eval()
        with torch.no_grad():
            x = wavs.transpose(1, 2)
            embeddings = self.model(x)
            return F.normalize(embeddings, p=2, dim=1)

    @classmethod
    def from_hparams(cls, source, savedir=None, run_opts=None):
        """Load a (possibly pretrained) model.

        Args:
            source: model identifier string; an "ecapa" substring selects
                the ECAPA architecture, otherwise x-vector.
            savedir: optional directory expected to contain "encoder.pth";
                weights are loaded only if that file exists.
            run_opts: optional dict; a "device" entry moves the model there
                and is used as ``map_location`` when loading weights.

        Returns:
            An ``EncoderClassifier`` instance.
        """
        model = cls(model_type="ecapa" if "ecapa" in source else "xvector")

        # Fix: the original indexed run_opts["device"] unconditionally when
        # loading weights, crashing when run_opts is None (the default) or
        # lacks a "device" key. Resolve the device once, up front.
        device = run_opts.get("device") if run_opts else None
        if device is not None:
            model = model.to(device)

        if savedir:
            weights_path = os.path.join(savedir, "encoder.pth")
            if os.path.exists(weights_path):
                state = torch.load(weights_path, map_location=device or "cpu")
                model.load_state_dict(state)

        return model

from_hparams(source, savedir=None, run_opts=None) classmethod

Load pretrained model.

Source code in speechain/module/encoder/speaker.py
@classmethod
def from_hparams(cls, source, savedir=None, run_opts=None):
    """Load a (possibly pretrained) model.

    Args:
        source: model identifier string; an "ecapa" substring selects the
            ECAPA architecture, otherwise x-vector.
        savedir: optional directory expected to contain "encoder.pth";
            weights are loaded only if that file exists.
        run_opts: optional dict; a "device" entry moves the model there and
            is used as ``map_location`` when loading weights.

    Returns:
        An instance of ``cls``.
    """
    model = cls(model_type="ecapa" if "ecapa" in source else "xvector")

    # Fix: the original indexed run_opts["device"] unconditionally when
    # loading weights, crashing when run_opts is None (the default) or
    # lacks a "device" key. Resolve the device once, up front.
    device = run_opts.get("device") if run_opts else None
    if device is not None:
        model = model.to(device)

    if savedir:
        weights_path = os.path.join(savedir, "encoder.pth")
        if os.path.exists(weights_path):
            state = torch.load(weights_path, map_location=device or "cpu")
            model.load_state_dict(state)

    return model