Skip to content

spk_util

extract_spk_feat(spk2wav_dict, gpu_id, spk_emb_model, save_path=None, batch_size=10)

Extract speaker features using a specified speaker embedding model and save them.

Parameters:

Name Type Description Default
spk2wav_dict Dict

A dictionary mapping each speaker ID to a dictionary that maps unique waveform IDs to waveform file paths.

required
gpu_id int

The GPU device ID to use. Set to -1 for CPU.

required
spk_emb_model str

The speaker embedding model to use (either 'ecapa' or 'xvector').

required
save_path str

The path to save the extracted speaker features. If not given, the extracted features will be stored in memory. Defaults to None.

None
batch_size int

The batch size for processing. Defaults to 10.

10

Returns:

Type Description
Tuple[Dict, Dict]

A tuple of two dictionaries: (1) a dictionary mapping unique waveform IDs to the corresponding extracted speaker features; (2) a dictionary mapping speaker IDs to the corresponding average speaker features.

Source code in speechain/utilbox/spk_util.py
def extract_spk_feat(
    spk2wav_dict: Dict[str, Dict[str, str]],
    gpu_id: int,
    spk_emb_model: str,
    save_path: str = None,
    batch_size: int = 10,
) -> Tuple[Dict, Dict]:
    """Extract speaker features using a specified speaker embedding model and save them.

    Waveforms are processed speaker by speaker, so a batch never mixes speakers.
    Waveforms sampled above 16 kHz are resampled down to 16 kHz before being fed
    to the embedding model.

    Args:
        spk2wav_dict (Dict[str, Dict[str, str]]):
            A dictionary mapping each speaker ID to a dictionary that maps unique
            waveform IDs to waveform file paths.
        gpu_id (int):
            The GPU device ID to use. Set to -1 for CPU.
        spk_emb_model (str):
            The speaker embedding model to use (either 'ecapa' or 'xvector').
        save_path (str, optional):
            The path to save the extracted speaker features. If not given, the extracted features will be stored
            in memory. Defaults to None.
        batch_size (int, optional):
            The batch size for processing. Defaults to 10.

    Returns:
        Tuple[Dict, Dict]:
            - A dictionary mapping unique waveform IDs to the corresponding extracted speaker features
              (or to whatever ``save_data_by_format`` returns for them when ``save_path`` is given).
            - A dictionary mapping speaker IDs to the corresponding average speaker features.

    Raises:
        ValueError: If ``spk_emb_model`` is neither 'ecapa' nor 'xvector'.
        RuntimeError: If a waveform has a sampling rate below 16 kHz.
    """
    target_sample_rate = 16000

    # Validate the model name first so an invalid request fails before any other work is done.
    model_repos = {
        "ecapa": "spkrec-ecapa-voxceleb",
        "xvector": "spkrec-xvect-voxceleb",
    }
    if spk_emb_model not in model_repos:
        raise ValueError(
            f"Unknown speaker embedding model ({spk_emb_model})! "
            f"Currently, spk_emb_model should be one of ['ecapa', 'xvector']."
        )

    def load_wav(wav_path):
        """Read a waveform, move it to the target device and resample it to 16 kHz if needed."""
        wav, sample_rate = read_data_by_path(
            wav_path, return_tensor=True, return_sample_rate=True
        )
        wav = wav.squeeze(-1).to(device)
        if sample_rate < target_sample_rate:
            raise RuntimeError(
                f"The sampling rate of {wav_path} is {sample_rate}Hz, which is lower than "
                f"the required {target_sample_rate}Hz. Upsampling is not supported."
            )
        if sample_rate > target_sample_rate:
            # one resampler per source sampling rate is cached and reused
            if sample_rate not in resamplers:
                resamplers[sample_rate] = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=target_sample_rate
                ).to(device)
            wav = resamplers[sample_rate](wav)
        return wav

    def load_feat(wav_idx):
        """Return the speaker feature of a waveform, reading it back when the stored entry is a string."""
        feat = idx2spk_feat[wav_idx]
        # string entries are treated as paths of the dumped features
        return read_data_by_path(feat) if isinstance(feat, str) else feat

    def proc_curr_batch():
        """Process the current batch of waveforms and extract speaker features."""
        idx_list, wav_list = [i[0] for i in curr_batch], [i[1] for i in curr_batch]
        wav_len = torch.LongTensor([w.size(0) for w in wav_list]).to(device)
        max_wav_len = wav_len.max().item()

        # Pad the waveforms with zeros into a single matrix
        wav_matrix = torch.zeros((wav_len.size(0), max_wav_len), device=device)
        for i, wav in enumerate(wav_list):
            wav_matrix[i][: wav_len[i]] = wav

        # the lengths are passed relative to the longest waveform in the batch
        spk_feat = speechbrain_model.encode_batch(
            wavs=wav_matrix, wav_lens=wav_len / max_wav_len
        )
        spk_feat_np = [to_cpu(s_f, tgt="numpy") for s_f in spk_feat]
        if save_path is None:
            idx2spk_feat.update(dict(zip(idx_list, spk_feat_np)))
        else:
            idx2spk_feat.update(
                save_data_by_format(
                    file_format="npy",
                    save_path=save_path,
                    group_ids=spk_id,
                    file_name_list=idx_list,
                    file_content_list=spk_feat_np,
                )
            )

        # refresh the current batch
        return []

    # initialize the speaker embedding model and downloading path for speechbrain API
    download_dir = parse_path_args("datasets/spk_emb_models")
    repo_name = model_repos[spk_emb_model]
    device = f"cuda:{gpu_id}" if gpu_id >= 0 else "cpu"
    speechbrain_model = EncoderClassifier.from_hparams(
        source=f"speechbrain/{repo_name}",
        savedir=os.path.join(download_dir, repo_name),
        run_opts=dict(device=device),
    )

    idx2spk_feat, spk2aver_spk_feat, resamplers = {}, {}, {}
    # loop each speaker
    for spk_id, wav_dict in tqdm(spk2wav_dict.items()):
        # a batch contains only waveforms for a single speaker
        curr_batch = []
        # loop each waveform file for each speaker
        for wav_idx, wav_path in wav_dict.items():
            curr_batch.append([wav_idx, load_wav(wav_path)])
            # Process the batch if it meets the given size
            if len(curr_batch) == batch_size:
                curr_batch = proc_curr_batch()

        # Process the remaining incomplete batch
        if curr_batch:
            curr_batch = proc_curr_batch()

        # collect the speaker embeddings of this speaker and check the saved data
        spk_feat_list, failed_idx_list = [], []
        for wav_idx in wav_dict:
            try:
                spk_feat_list.append(load_feat(wav_idx))
            except ValueError:
                # record the waveform index whose speaker embedding dumping failed
                failed_idx_list.append(wav_idx)

        # keep looping until no failed waveforms remain
        # NOTE: there is no retry limit, so a permanently unreadable dump keeps this loop running.
        while failed_idx_list:
            # reprocess all the failed waveforms in a single batch
            curr_batch = [
                [wav_idx, load_wav(wav_dict[wav_idx])] for wav_idx in failed_idx_list
            ]
            curr_batch = proc_curr_batch()

            # Build a new list instead of removing items from the list being iterated,
            # otherwise some failed waveforms would be skipped.
            still_failed = []
            for wav_idx in failed_idx_list:
                try:
                    spk_feat_list.append(load_feat(wav_idx))
                except ValueError:
                    # the waveform remains in the failed list if the error happens again
                    still_failed.append(wav_idx)
            failed_idx_list = still_failed

        aver_spk_feat = np.mean(spk_feat_list, axis=0)
        # Save the average speaker features in memory or to disk
        if save_path is None:
            spk2aver_spk_feat[spk_id] = aver_spk_feat
        else:
            spk2aver_spk_feat.update(
                save_data_by_format(
                    file_format="npy",
                    save_path=save_path,
                    group_ids=spk_id,
                    file_name_list=spk_id,
                    file_content_list=aver_spk_feat,
                )
            )
    return idx2spk_feat, spk2aver_spk_feat